use scirs2_core::ndarray::{Array1, Array2};
/// Errors reported while constructing or evaluating the recurrent cells.
///
/// `PartialEq`/`Eq` are derived so that callers and tests can compare errors
/// by value (for example with `assert_eq!`) instead of pattern matching only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecurrentError {
    /// An array did not have the shape an operation required.
    ShapeMismatch {
        /// Shape the operation required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
    /// The hidden size is unusable; carries the rejected value.
    InvalidHiddenSize(usize),
    /// The input size is unusable; carries the rejected value.
    InvalidInputSize(usize),
    /// A sequence operation was given no time steps.
    EmptySequence,
    /// A sequence had an unsupported length.
    InvalidSequenceLength {
        /// The offending sequence length.
        got: usize,
    },
}

impl std::fmt::Display for RecurrentError {
    /// Writes a short, single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {:?}, got {:?}", expected, got)
            }
            Self::InvalidHiddenSize(s) => write!(f, "invalid hidden_size: {s}"),
            Self::InvalidInputSize(s) => write!(f, "invalid input_size: {s}"),
            // No interpolation needed, so the literal is written directly.
            Self::EmptySequence => f.write_str("input sequence must not be empty"),
            Self::InvalidSequenceLength { got } => {
                write!(f, "invalid sequence length: {got}")
            }
        }
    }
}

impl std::error::Error for RecurrentError {}
/// Logistic sigmoid: `1 / (1 + e^(-x))`.
///
/// Large negative inputs drive the exponential to `+inf`, whose reciprocal
/// is `0.0`; large positive inputs yield `1.0`.
#[inline]
fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + f64::exp(-x);
    denominator.recip()
}
/// Advances a 64-bit linear congruential generator and returns one sample.
///
/// `state` is updated in place with wrapping arithmetic. The new state is
/// mapped to `[0, 1]` by dividing by `u64::MAX`, then stretched to
/// `[-scale, scale]`.
#[inline]
fn lcg_value(state: &mut u64, scale: f64) -> f64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let next = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = next;

    let unit = next as f64 / u64::MAX as f64;
    scale * (2.0 * unit - 1.0)
}
/// Builds a `rows x cols` matrix whose entries are drawn from `lcg_value`.
///
/// Values are generated in flat order and handed to `from_shape_vec` with
/// shape `(rows, cols)`; `state` is advanced once per element. If the shape
/// conversion reports an error, an all-zero matrix of the requested shape is
/// returned instead.
fn lcg_fill_2d(rows: usize, cols: usize, scale: f64, state: &mut u64) -> Array2<f64> {
    let total = rows * cols;
    let mut values: Vec<f64> = Vec::with_capacity(total);
    for _ in 0..total {
        values.push(lcg_value(state, scale));
    }
    match Array2::from_shape_vec((rows, cols), values) {
        Ok(matrix) => matrix,
        Err(_) => Array2::zeros((rows, cols)),
    }
}
/// Builds a vector of `len` samples drawn from `lcg_value`.
///
/// `state` is advanced once per element, in index order.
fn lcg_fill_1d(len: usize, scale: f64, state: &mut u64) -> Array1<f64> {
    let mut values: Vec<f64> = Vec::with_capacity(len);
    for _ in 0..len {
        values.push(lcg_value(state, scale));
    }
    Array1::from_vec(values)
}
/// Vanilla (Elman-style) recurrent cell with a `tanh` activation.
///
/// A single step computes `h' = tanh(W_ih · x + b_ih + W_hh · h + b_hh)`.
#[derive(Debug, Clone)]
pub struct RnnCell {
    /// Length of the input vector accepted by [`RnnCell::forward`].
    pub input_size: usize,
    /// Length of the hidden-state vector.
    pub hidden_size: usize,
    /// Input-to-hidden weights, shape `(hidden_size, input_size)`.
    pub w_ih: Array2<f64>,
    /// Hidden-to-hidden weights, shape `(hidden_size, hidden_size)`.
    pub w_hh: Array2<f64>,
    /// Input-side bias, length `hidden_size`.
    pub b_ih: Array1<f64>,
    /// Hidden-side bias, length `hidden_size`.
    pub b_hh: Array1<f64>,
}

impl RnnCell {
    /// Creates a cell whose parameters come from a fixed-seed LCG.
    ///
    /// The seed is constant, so two cells built with the same sizes hold
    /// identical parameters.
    ///
    /// # Errors
    ///
    /// * [`RecurrentError::InvalidInputSize`] if `input_size` is `0`.
    /// * [`RecurrentError::InvalidHiddenSize`] if `hidden_size` is `0`
    ///   (checked after the input size).
    pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, RecurrentError> {
        const INIT_SCALE: f64 = 0.1;

        match (input_size, hidden_size) {
            (0, _) => return Err(RecurrentError::InvalidInputSize(input_size)),
            (_, 0) => return Err(RecurrentError::InvalidHiddenSize(hidden_size)),
            _ => {}
        }

        let mut rng_state = 0xdeadbeef_12345678_u64;
        // Field initialisers run in the order written, so the shared
        // generator state is consumed as w_ih, w_hh, b_ih, b_hh.
        Ok(Self {
            input_size,
            hidden_size,
            w_ih: lcg_fill_2d(hidden_size, input_size, INIT_SCALE, &mut rng_state),
            w_hh: lcg_fill_2d(hidden_size, hidden_size, INIT_SCALE, &mut rng_state),
            b_ih: lcg_fill_1d(hidden_size, INIT_SCALE, &mut rng_state),
            b_hh: lcg_fill_1d(hidden_size, INIT_SCALE, &mut rng_state),
        })
    }

    /// Builds a cell from caller-supplied parameters.
    ///
    /// `hidden_size` and `input_size` are taken from the row and column
    /// counts of `w_ih`; the remaining arrays are validated against them.
    ///
    /// # Errors
    ///
    /// * [`RecurrentError::InvalidHiddenSize`] if `w_ih` has no rows.
    /// * [`RecurrentError::InvalidInputSize`] if `w_ih` has no columns.
    /// * [`RecurrentError::ShapeMismatch`] if `w_hh` is not
    ///   `(hidden_size, hidden_size)` or a bias is not `hidden_size` long.
    pub fn from_weights(
        w_ih: Array2<f64>,
        w_hh: Array2<f64>,
        b_ih: Array1<f64>,
        b_hh: Array1<f64>,
    ) -> Result<Self, RecurrentError> {
        let (hidden_size, input_size) = w_ih.dim();
        if hidden_size == 0 {
            return Err(RecurrentError::InvalidHiddenSize(hidden_size));
        }
        if input_size == 0 {
            return Err(RecurrentError::InvalidInputSize(input_size));
        }

        let recurrent_dim = w_hh.dim();
        if recurrent_dim != (hidden_size, hidden_size) {
            return Err(RecurrentError::ShapeMismatch {
                expected: vec![hidden_size, hidden_size],
                got: vec![recurrent_dim.0, recurrent_dim.1],
            });
        }

        // Biases are checked in declaration order: b_ih first, then b_hh.
        for bias in [&b_ih, &b_hh] {
            if bias.len() != hidden_size {
                return Err(RecurrentError::ShapeMismatch {
                    expected: vec![hidden_size],
                    got: vec![bias.len()],
                });
            }
        }

        Ok(Self {
            input_size,
            hidden_size,
            w_ih,
            w_hh,
            b_ih,
            b_hh,
        })
    }

    /// Runs one time step and returns the new hidden state.
    ///
    /// # Errors
    ///
    /// [`RecurrentError::ShapeMismatch`] if `input` is not `input_size` long
    /// or `hidden` is not `hidden_size` long (input is checked first).
    pub fn forward(
        &self,
        input: &Array1<f64>,
        hidden: &Array1<f64>,
    ) -> Result<Array1<f64>, RecurrentError> {
        let length_checks = [
            (input.len(), self.input_size),
            (hidden.len(), self.hidden_size),
        ];
        for (got, expected) in length_checks {
            if got != expected {
                return Err(RecurrentError::ShapeMismatch {
                    expected: vec![expected],
                    got: vec![got],
                });
            }
        }

        // Terms are accumulated left to right in the original order so the
        // floating-point result is unchanged.
        let input_term = self.w_ih.dot(input) + &self.b_ih;
        let with_recurrent = input_term + self.w_hh.dot(hidden);
        let pre_activation = with_recurrent + &self.b_hh;
        Ok(pre_activation.mapv(f64::tanh))
    }

    /// Returns an all-zero hidden state of length `hidden_size`.
    pub fn init_hidden(&self) -> Array1<f64> {
        Array1::zeros(self.hidden_size)
    }

    /// Number of scalar parameters: both weight matrices plus both biases.
    pub fn num_parameters(&self) -> usize {
        let h = self.hidden_size;
        [h * self.input_size, h * h, h, h].into_iter().sum()
    }
}
/// Hidden and cell state carried between LSTM time steps.
#[derive(Debug, Clone)]
pub struct LstmState {
    /// Hidden state vector.
    pub h: Array1<f64>,
    /// Cell state vector.
    pub c: Array1<f64>,
}

impl LstmState {
    /// Creates a state whose `h` and `c` are both zero vectors of length
    /// `hidden_size`.
    pub fn zeros(hidden_size: usize) -> Self {
        let blank = Array1::<f64>::zeros(hidden_size);
        Self {
            h: blank.clone(),
            c: blank,
        }
    }
}
#[derive(Debug, Clone)]
pub struct LstmCell {
pub input_size: usize,
pub hidden_size: usize,
pub w_ih: Array2<f64>,
pub w_hh: Array2<f64>,
pub b_ih: Array1<f64>,
pub b_hh: Array1<f64>,
}
impl LstmCell {
pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, RecurrentError> {
if input_size == 0 {
return Err(RecurrentError::InvalidInputSize(input_size));
}
if hidden_size == 0 {
return Err(RecurrentError::InvalidHiddenSize(hidden_size));
}
let scale = 0.1_f64;
let mut state: u64 = 0xfeedface_abcd1234_u64;
let gates = 4;
let w_ih = lcg_fill_2d(gates * hidden_size, input_size, scale, &mut state);
let w_hh = lcg_fill_2d(gates * hidden_size, hidden_size, scale, &mut state);
let b_ih = lcg_fill_1d(gates * hidden_size, scale, &mut state);
let b_hh = lcg_fill_1d(gates * hidden_size, scale, &mut state);
Ok(Self {
input_size,
hidden_size,
w_ih,
w_hh,
b_ih,
b_hh,
})
}
pub fn from_weights(
w_ih: Array2<f64>,
w_hh: Array2<f64>,
b_ih: Array1<f64>,
b_hh: Array1<f64>,
) -> Result<Self, RecurrentError> {
let input_size = w_ih.ncols();
if input_size == 0 {
return Err(RecurrentError::InvalidInputSize(input_size));
}
let combined_rows = w_ih.nrows();
if combined_rows == 0 || !combined_rows.is_multiple_of(4) {
return Err(RecurrentError::ShapeMismatch {
expected: vec![0 , input_size],
got: vec![combined_rows, input_size],
});
}
let hidden_size = combined_rows / 4;
if w_hh.nrows() != combined_rows || w_hh.ncols() != hidden_size {
return Err(RecurrentError::ShapeMismatch {
expected: vec![combined_rows, hidden_size],
got: vec![w_hh.nrows(), w_hh.ncols()],
});
}
if b_ih.len() != combined_rows {
return Err(RecurrentError::ShapeMismatch {
expected: vec![combined_rows],
got: vec![b_ih.len()],
});
}
if b_hh.len() != combined_rows {
return Err(RecurrentError::ShapeMismatch {
expected: vec![combined_rows],
got: vec![b_hh.len()],
});
}
Ok(Self {
input_size,
hidden_size,
w_ih,
w_hh,
b_ih,
b_hh,
})
}
pub fn forward(
&self,
input: &Array1<f64>,
state: &LstmState,
) -> Result<LstmState, RecurrentError> {
if input.len() != self.input_size {
return Err(RecurrentError::ShapeMismatch {
expected: vec![self.input_size],
got: vec![input.len()],
});
}
if state.h.len() != self.hidden_size {
return Err(RecurrentError::ShapeMismatch {
expected: vec![self.hidden_size],
got: vec![state.h.len()],
});
}
if state.c.len() != self.hidden_size {
return Err(RecurrentError::ShapeMismatch {
expected: vec![self.hidden_size],
got: vec![state.c.len()],
});
}
let gates_pre = self.w_ih.dot(input) + &self.b_ih + self.w_hh.dot(&state.h) + &self.b_hh;
let h = self.hidden_size;
let i_pre = gates_pre.slice(scirs2_core::ndarray::s![..h]).to_owned();
let f_pre = gates_pre
.slice(scirs2_core::ndarray::s![h..2 * h])
.to_owned();
let g_pre = gates_pre
.slice(scirs2_core::ndarray::s![2 * h..3 * h])
.to_owned();
let o_pre = gates_pre
.slice(scirs2_core::ndarray::s![3 * h..])
.to_owned();
let i_gate = i_pre.mapv(sigmoid);
let f_gate = f_pre.mapv(sigmoid);
let g_gate = g_pre.mapv(f64::tanh);
let o_gate = o_pre.mapv(sigmoid);
let new_c = &f_gate * &state.c + &i_gate * &g_gate;
let new_h = &o_gate * new_c.mapv(f64::tanh);
Ok(LstmState { h: new_h, c: new_c })
}
pub fn init_state(&self) -> LstmState {
LstmState::zeros(self.hidden_size)
}
pub fn num_parameters(&self) -> usize {
let gates = 4;
gates * self.hidden_size * self.input_size + gates * self.hidden_size * self.hidden_size + gates * self.hidden_size + gates * self.hidden_size }
}
#[derive(Debug, Clone)]
pub struct GruCell {
pub input_size: usize,
pub hidden_size: usize,
pub w_ih: Array2<f64>,
pub w_hh: Array2<f64>,
pub b_ih: Array1<f64>,
pub b_hh: Array1<f64>,
}
impl GruCell {
pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, RecurrentError> {
if input_size == 0 {
return Err(RecurrentError::InvalidInputSize(input_size));
}
if hidden_size == 0 {
return Err(RecurrentError::InvalidHiddenSize(hidden_size));
}
let scale = 0.1_f64;
let mut state: u64 = 0xc0ffee00_87654321_u64;
let gates = 3;
let w_ih = lcg_fill_2d(gates * hidden_size, input_size, scale, &mut state);
let w_hh = lcg_fill_2d(gates * hidden_size, hidden_size, scale, &mut state);
let b_ih = lcg_fill_1d(gates * hidden_size, scale, &mut state);
let b_hh = lcg_fill_1d(gates * hidden_size, scale, &mut state);
Ok(Self {
input_size,
hidden_size,
w_ih,
w_hh,
b_ih,
b_hh,
})
}
pub fn from_weights(
w_ih: Array2<f64>,
w_hh: Array2<f64>,
b_ih: Array1<f64>,
b_hh: Array1<f64>,
) -> Result<Self, RecurrentError> {
let input_size = w_ih.ncols();
if input_size == 0 {
return Err(RecurrentError::InvalidInputSize(input_size));
}
let combined_rows = w_ih.nrows();
if combined_rows == 0 || !combined_rows.is_multiple_of(3) {
return Err(RecurrentError::ShapeMismatch {
expected: vec![0 , input_size],
got: vec![combined_rows, input_size],
});
}
let hidden_size = combined_rows / 3;
if w_hh.nrows() != combined_rows || w_hh.ncols() != hidden_size {
return Err(RecurrentError::ShapeMismatch {
expected: vec![combined_rows, hidden_size],
got: vec![w_hh.nrows(), w_hh.ncols()],
});
}
if b_ih.len() != combined_rows {
return Err(RecurrentError::ShapeMismatch {
expected: vec![combined_rows],
got: vec![b_ih.len()],
});
}
if b_hh.len() != combined_rows {
return Err(RecurrentError::ShapeMismatch {
expected: vec![combined_rows],
got: vec![b_hh.len()],
});
}
Ok(Self {
input_size,
hidden_size,
w_ih,
w_hh,
b_ih,
b_hh,
})
}
pub fn forward(
&self,
input: &Array1<f64>,
hidden: &Array1<f64>,
) -> Result<Array1<f64>, RecurrentError> {
if input.len() != self.input_size {
return Err(RecurrentError::ShapeMismatch {
expected: vec![self.input_size],
got: vec![input.len()],
});
}
if hidden.len() != self.hidden_size {
return Err(RecurrentError::ShapeMismatch {
expected: vec![self.hidden_size],
got: vec![hidden.len()],
});
}
let h = self.hidden_size;
let x_pre = self.w_ih.dot(input) + &self.b_ih;
let h_pre = self.w_hh.dot(hidden) + &self.b_hh;
let r_pre = x_pre.slice(scirs2_core::ndarray::s![..h]).to_owned()
+ h_pre.slice(scirs2_core::ndarray::s![..h]).to_owned();
let z_pre = x_pre.slice(scirs2_core::ndarray::s![h..2 * h]).to_owned()
+ h_pre.slice(scirs2_core::ndarray::s![h..2 * h]).to_owned();
let r_gate = r_pre.mapv(sigmoid);
let z_gate = z_pre.mapv(sigmoid);
let n_x = x_pre.slice(scirs2_core::ndarray::s![2 * h..]).to_owned();
let n_h = h_pre.slice(scirs2_core::ndarray::s![2 * h..]).to_owned();
let n_pre = n_x + &r_gate * n_h;
let n_gate = n_pre.mapv(f64::tanh);
let ones = Array1::<f64>::ones(h);
let new_h = (&ones - &z_gate) * &n_gate + &z_gate * hidden;
Ok(new_h)
}
pub fn init_hidden(&self) -> Array1<f64> {
Array1::zeros(self.hidden_size)
}
pub fn num_parameters(&self) -> usize {
let gates = 3;
gates * self.hidden_size * self.input_size + gates * self.hidden_size * self.hidden_size + gates * self.hidden_size + gates * self.hidden_size }
}
/// Runs `cell` over `inputs`, starting from a zero hidden state.
///
/// Returns the hidden state after every time step, one entry per input.
///
/// # Errors
///
/// * [`RecurrentError::EmptySequence`] if `inputs` is empty.
/// * Any error from [`RnnCell::forward`]; processing stops at the first
///   failing step.
pub fn rnn_sequence(
    cell: &RnnCell,
    inputs: &[Array1<f64>],
) -> Result<Vec<Array1<f64>>, RecurrentError> {
    if inputs.is_empty() {
        return Err(RecurrentError::EmptySequence);
    }

    let mut hidden = cell.init_hidden();
    inputs
        .iter()
        .map(|x| -> Result<Array1<f64>, RecurrentError> {
            hidden = cell.forward(x, &hidden)?;
            Ok(hidden.clone())
        })
        .collect()
}
/// Runs `cell` over `inputs`, starting from a zero state.
///
/// Returns the hidden vector after every time step together with the final
/// [`LstmState`].
///
/// # Errors
///
/// * [`RecurrentError::EmptySequence`] if `inputs` is empty.
/// * Any error from [`LstmCell::forward`]; processing stops at the first
///   failing step.
pub fn lstm_sequence(
    cell: &LstmCell,
    inputs: &[Array1<f64>],
) -> Result<(Vec<Array1<f64>>, LstmState), RecurrentError> {
    if inputs.is_empty() {
        return Err(RecurrentError::EmptySequence);
    }

    let mut state = cell.init_state();
    let hidden_states = inputs
        .iter()
        .map(|x| -> Result<Array1<f64>, RecurrentError> {
            state = cell.forward(x, &state)?;
            Ok(state.h.clone())
        })
        .collect::<Result<Vec<_>, _>>()?;
    Ok((hidden_states, state))
}
/// Runs `cell` over `inputs`, starting from a zero hidden state.
///
/// Returns the hidden state after every time step, one entry per input.
///
/// # Errors
///
/// * [`RecurrentError::EmptySequence`] if `inputs` is empty.
/// * Any error from [`GruCell::forward`]; processing stops at the first
///   failing step.
pub fn gru_sequence(
    cell: &GruCell,
    inputs: &[Array1<f64>],
) -> Result<Vec<Array1<f64>>, RecurrentError> {
    if inputs.is_empty() {
        return Err(RecurrentError::EmptySequence);
    }

    let mut hidden = cell.init_hidden();
    inputs
        .iter()
        .map(|x| -> Result<Array1<f64>, RecurrentError> {
            hidden = cell.forward(x, &hidden)?;
            Ok(hidden.clone())
        })
        .collect()
}
/// Descriptive statistics about a recurrent cell, suitable for logging.
#[derive(Debug, Clone)]
pub struct RecurrentStats {
    /// Free-form label for the cell kind.
    pub cell_type: String,
    /// Input vector length.
    pub input_size: usize,
    /// Hidden-state vector length.
    pub hidden_size: usize,
    /// Total number of scalar parameters.
    pub num_parameters: usize,
    /// Sequence length, when one is known.
    pub sequence_length: Option<usize>,
}

impl RecurrentStats {
    /// Renders the statistics as a single line.
    ///
    /// The sequence length is shown as `seq_len=<n>`, or `seq_len=n/a` when
    /// it is `None`.
    pub fn summary(&self) -> String {
        let seq = self
            .sequence_length
            .map_or_else(|| "seq_len=n/a".to_string(), |t| format!("seq_len={t}"));
        let Self {
            cell_type,
            input_size,
            hidden_size,
            num_parameters,
            ..
        } = self;
        format!(
            "{} | input={} hidden={} params={} {}",
            cell_type, input_size, hidden_size, num_parameters, seq
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Builds `steps` identical input vectors of length `dim` filled with
    /// `value`.
    fn constant_inputs(steps: usize, dim: usize, value: f64) -> Vec<Array1<f64>> {
        vec![Array1::from_elem(dim, value); steps]
    }

    /// Returns `true` when every element of `v` equals zero.
    fn is_all_zero(v: &Array1<f64>) -> bool {
        v.iter().all(|&e| e == 0.0)
    }

    // ---- RnnCell ----

    #[test]
    fn test_rnn_cell_new() {
        assert!(RnnCell::new(4, 8).is_ok(), "RnnCell::new should succeed");
    }

    #[test]
    fn test_rnn_cell_forward_shape() {
        let cell = RnnCell::new(4, 8).expect("construct rnn");
        let next = cell
            .forward(&Array1::zeros(4), &cell.init_hidden())
            .expect("rnn forward");
        assert_eq!(next.len(), 8);
    }

    #[test]
    fn test_rnn_cell_init_hidden() {
        let hidden = RnnCell::new(3, 5).expect("construct rnn").init_hidden();
        assert_eq!(hidden.len(), 5);
        assert!(is_all_zero(&hidden), "init hidden should be zeros");
    }

    #[test]
    fn test_rnn_cell_num_parameters() {
        let (input_size, hidden_size) = (4, 8);
        let cell = RnnCell::new(input_size, hidden_size).expect("construct rnn");
        let weights = hidden_size * input_size + hidden_size * hidden_size;
        let biases = hidden_size + hidden_size;
        assert_eq!(cell.num_parameters(), weights + biases);
    }

    // ---- LstmCell ----

    #[test]
    fn test_lstm_cell_new() {
        assert!(LstmCell::new(4, 8).is_ok(), "LstmCell::new should succeed");
    }

    #[test]
    fn test_lstm_cell_forward_shape() {
        let cell = LstmCell::new(4, 8).expect("construct lstm");
        let next = cell
            .forward(&Array1::zeros(4), &cell.init_state())
            .expect("lstm forward");
        assert_eq!(next.h.len(), 8);
        assert_eq!(next.c.len(), 8);
    }

    #[test]
    fn test_lstm_cell_init_state() {
        let state = LstmCell::new(3, 6).expect("construct lstm").init_state();
        assert_eq!(state.h.len(), 6);
        assert_eq!(state.c.len(), 6);
        assert!(is_all_zero(&state.h));
        assert!(is_all_zero(&state.c));
    }

    #[test]
    fn test_lstm_cell_gate_bounds() {
        let cell = LstmCell::new(4, 8).expect("construct lstm");
        let input = Array1::from_elem(4, 0.5);
        let next = cell
            .forward(&input, &cell.init_state())
            .expect("lstm forward");
        for &v in next.h.iter() {
            assert!(v > -1.0 && v < 1.0, "h element out of (-1,1): {v}");
        }
    }

    #[test]
    fn test_lstm_cell_num_parameters() {
        let (input_size, hidden_size, gates) = (4, 8, 4);
        let cell = LstmCell::new(input_size, hidden_size).expect("construct lstm");
        let weights = gates * hidden_size * input_size + gates * hidden_size * hidden_size;
        let biases = gates * hidden_size + gates * hidden_size;
        assert_eq!(cell.num_parameters(), weights + biases);
    }

    // ---- GruCell ----

    #[test]
    fn test_gru_cell_new() {
        assert!(GruCell::new(4, 8).is_ok(), "GruCell::new should succeed");
    }

    #[test]
    fn test_gru_cell_forward_shape() {
        let cell = GruCell::new(4, 8).expect("construct gru");
        let next = cell
            .forward(&Array1::zeros(4), &cell.init_hidden())
            .expect("gru forward");
        assert_eq!(next.len(), 8);
    }

    #[test]
    fn test_gru_cell_hidden_init_zeros() {
        let hidden = GruCell::new(3, 5).expect("construct gru").init_hidden();
        assert_eq!(hidden.len(), 5);
        assert!(is_all_zero(&hidden));
    }

    #[test]
    fn test_gru_cell_num_parameters() {
        let (input_size, hidden_size, gates) = (4, 8, 3);
        let cell = GruCell::new(input_size, hidden_size).expect("construct gru");
        let weights = gates * hidden_size * input_size + gates * hidden_size * hidden_size;
        let biases = gates * hidden_size + gates * hidden_size;
        assert_eq!(cell.num_parameters(), weights + biases);
    }

    // ---- Sequence helpers ----

    #[test]
    fn test_rnn_sequence_length() {
        let cell = RnnCell::new(4, 8).expect("rnn");
        let inputs = constant_inputs(7, 4, 0.0);
        let out = rnn_sequence(&cell, &inputs).expect("rnn sequence");
        assert_eq!(out.len(), 7, "T inputs → T outputs");
    }

    #[test]
    fn test_rnn_sequence_empty_error() {
        let cell = RnnCell::new(4, 8).expect("rnn");
        let result = rnn_sequence(&cell, &[]);
        assert!(
            matches!(result, Err(RecurrentError::EmptySequence)),
            "expected EmptySequence error"
        );
    }

    #[test]
    fn test_lstm_sequence_length() {
        let cell = LstmCell::new(4, 8).expect("lstm");
        let inputs = constant_inputs(5, 4, 0.0);
        let (hidden_states, _) = lstm_sequence(&cell, &inputs).expect("lstm sequence");
        assert_eq!(hidden_states.len(), 5);
    }

    #[test]
    fn test_lstm_sequence_final_state_nonzero() {
        let cell = LstmCell::new(4, 8).expect("lstm");
        let inputs = constant_inputs(3, 4, 1.0);
        let (_, final_state) = lstm_sequence(&cell, &inputs).expect("lstm sequence");
        let squared_sum: f64 = final_state.h.iter().map(|v| v * v).sum();
        let h_norm = squared_sum.sqrt();
        assert!(
            h_norm > 1e-12,
            "final h should be non-zero for non-zero inputs"
        );
    }

    #[test]
    fn test_gru_sequence_length() {
        let cell = GruCell::new(4, 8).expect("gru");
        let inputs = constant_inputs(6, 4, 0.0);
        let out = gru_sequence(&cell, &inputs).expect("gru sequence");
        assert_eq!(out.len(), 6);
    }

    // ---- RecurrentStats ----

    #[test]
    fn test_recurrent_stats_summary_nonempty() {
        let summary = RecurrentStats {
            cell_type: "LSTM".to_string(),
            input_size: 4,
            hidden_size: 8,
            num_parameters: 416,
            sequence_length: Some(10),
        }
        .summary();
        assert!(!summary.is_empty(), "summary should not be empty");
        assert!(summary.contains("LSTM"));
        assert!(summary.contains("416"));
    }

    // ---- Weight validation ----

    #[test]
    fn test_lstm_cell_from_weights_shape_mismatch() {
        use scirs2_core::ndarray::Array2;
        // w_ih has 8 rows, so hidden_size is 2; a w_hh with 3 columns must
        // therefore be rejected.
        let result = LstmCell::from_weights(
            Array2::zeros((8, 4)),
            Array2::zeros((8, 3)),
            Array1::zeros(8),
            Array1::zeros(8),
        );
        assert!(result.is_err(), "should fail due to w_hh shape mismatch");
    }
}