pub mod gru;
pub mod lstm;
pub use gru::{GruWeights, gru_cell_forward, gru_sequence_forward};
pub use lstm::{LstmWeights, lstm_cell_forward, lstm_sequence_forward};
/// Direction setting carried by [`RnnConfig`].
///
/// Within this module the variant is only consulted by
/// [`RnnConfig::direction_multiplier`], which maps `Forward` and `Backward`
/// to 1 and `Bidirectional` to 2.
// NOTE(review): the per-variant traversal semantics below are inferred from
// the variant names; confirm against the gru/lstm sequence kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RnnDirection {
/// Single-direction mode (multiplier 1); presumably walks the sequence in
/// input order.
Forward,
/// Single-direction mode (multiplier 1); presumably walks the sequence in
/// reverse order.
Backward,
/// Two-direction mode (multiplier 2); presumably runs both traversals.
Bidirectional,
}
/// Configuration values for an RNN.
///
/// Prefer [`RnnConfig::new`], which validates every field. All fields are
/// public, so a value built with a struct literal (or mutated afterwards)
/// bypasses those checks.
#[derive(Debug, Clone)]
pub struct RnnConfig {
/// Input feature count; `new` rejects zero.
pub input_size: usize,
/// Hidden state size; `new` rejects zero.
pub hidden_size: usize,
/// Layer count; `new` rejects zero.
pub num_layers: usize,
/// Direction setting; determines [`RnnConfig::direction_multiplier`].
pub direction: RnnDirection,
/// Dropout value; `new` requires it to lie in `[0.0, 1.0)`.
// NOTE(review): presumably a drop probability applied between layers;
// confirm where this value is consumed.
pub dropout: f32,
}
impl RnnConfig {
pub fn new(
input_size: usize,
hidden_size: usize,
num_layers: usize,
direction: RnnDirection,
dropout: f32,
) -> Result<Self, crate::error::DnnError> {
if input_size == 0 {
return Err(crate::error::DnnError::InvalidArgument(
"RNN input_size must be non-zero".into(),
));
}
if hidden_size == 0 {
return Err(crate::error::DnnError::InvalidArgument(
"RNN hidden_size must be non-zero".into(),
));
}
if num_layers == 0 {
return Err(crate::error::DnnError::InvalidArgument(
"RNN num_layers must be non-zero".into(),
));
}
if !(0.0..1.0).contains(&dropout) {
return Err(crate::error::DnnError::InvalidArgument(format!(
"RNN dropout must be in [0.0, 1.0), got {dropout}"
)));
}
Ok(Self {
input_size,
hidden_size,
num_layers,
direction,
dropout,
})
}
#[must_use]
pub fn direction_multiplier(&self) -> usize {
match self.direction {
RnnDirection::Forward | RnnDirection::Backward => 1,
RnnDirection::Bidirectional => 2,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal config for `direction` and reports its multiplier,
    /// or `None` if construction fails.
    fn multiplier_for(direction: RnnDirection) -> Option<usize> {
        RnnConfig::new(1, 1, 1, direction, 0.0)
            .ok()
            .map(|cfg| cfg.direction_multiplier())
    }

    #[test]
    fn rnn_config_valid() {
        let outcome = RnnConfig::new(128, 256, 2, RnnDirection::Forward, 0.1);
        assert!(outcome.is_ok());
    }

    #[test]
    fn rnn_config_zero_input_size() {
        assert!(RnnConfig::new(0, 256, 1, RnnDirection::Forward, 0.0).is_err());
    }

    #[test]
    fn rnn_config_zero_hidden_size() {
        assert!(RnnConfig::new(128, 0, 1, RnnDirection::Forward, 0.0).is_err());
    }

    #[test]
    fn rnn_config_zero_layers() {
        assert!(RnnConfig::new(128, 256, 0, RnnDirection::Forward, 0.0).is_err());
    }

    #[test]
    fn rnn_config_invalid_dropout() {
        // 1.0 sits on the excluded upper bound; -0.1 is below the lower bound.
        for rejected in [1.0_f32, -0.1] {
            let outcome = RnnConfig::new(128, 256, 1, RnnDirection::Forward, rejected);
            assert!(outcome.is_err(), "dropout {rejected} should be rejected");
        }
    }

    #[test]
    fn direction_multiplier_values() {
        let expectations = [
            (RnnDirection::Forward, 1),
            (RnnDirection::Backward, 1),
            (RnnDirection::Bidirectional, 2),
        ];
        for (direction, expected) in expectations {
            assert_eq!(
                multiplier_for(direction),
                Some(expected),
                "unexpected multiplier for {direction:?}"
            );
        }
    }

    #[test]
    fn rnn_direction_debug() {
        // Smoke test only: the derived `Debug` impl must format without panicking.
        let _rendered = format!("{:?}", RnnDirection::Bidirectional);
    }
}