use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, ArrayView, ArrayView1, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Rng, RngExt};
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
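/// Combined input and hidden width at or above which the SIMD step path is used.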
const RNN_SIMD_THRESHOLD: usize = 32;
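/// Activation function applied to the recurrent pre-activation at each time step.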
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RecurrentActivation {
Tanh,
Sigmoid,
ReLU,
}
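/// Configuration describing the shape and activation of an [`RNN`] layer.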
#[derive(Debug, Clone)]
pub struct RNNConfig {
pub input_size: usize,
pub hidden_size: usize,
pub activation: RecurrentActivation,
}
impl RecurrentActivation {
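/// Applies the activation to a single scalar value.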
pub fn apply<F: Float>(&self, x: F) -> F {
match self {
RecurrentActivation::Tanh => x.tanh(),
RecurrentActivation::Sigmoid => F::one() / (F::one() + (-x).exp()),
RecurrentActivation::ReLU => {
if x > F::zero() {
x
} else {
F::zero()
}
}
}
}
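/// Applies the activation element-wise to an array.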
#[allow(dead_code)]
pub fn apply_array<F: Float + ScalarOperand>(&self, x: &Array<F, IxDyn>) -> Array<F, IxDyn> {
match self {
RecurrentActivation::Tanh => x.mapv(|v| v.tanh()),
RecurrentActivation::Sigmoid => x.mapv(|v| F::one() / (F::one() + (-v).exp())),
RecurrentActivation::ReLU => x.mapv(|v| if v > F::zero() { v } else { F::zero() }),
}
}
}
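/// A simple recurrent (Elman-style) layer.
///
/// For each time step `t` the layer computes
/// `h_t = activation(weight_ih * x_t + bias_ih + weight_hh * h_{t-1} + bias_hh)`,
/// where `weight_ih` has shape `[hidden_size, input_size]`, `weight_hh` has shape
/// `[hidden_size, hidden_size]`, and both biases have shape `[hidden_size]`.
///
/// Illustrative usage, mirroring the unit test at the bottom of this file
/// (a minimal sketch; the example is not compiled because the surrounding
/// crate layout is assumed):
///
/// ```ignore
/// let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
/// let rnn = RNN::<f64>::new(10, 20, RecurrentActivation::Tanh, &mut rng)?;
/// // Input shape: [batch_size, seq_len, input_size].
/// let input = Array3::<f64>::from_elem((2, 5, 10), 0.1).into_dyn();
/// // Output shape: [batch_size, seq_len, hidden_size].
/// let hidden_states = rnn.forward(&input)?;
/// ```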
pub struct RNN<F: Float + Debug + Send + Sync + NumAssign> {
input_size: usize,
hidden_size: usize,
activation: RecurrentActivation,
weight_ih: Array<F, IxDyn>,
weight_hh: Array<F, IxDyn>,
bias_ih: Array<F, IxDyn>,
bias_hh: Array<F, IxDyn>,
dweight_ih: Array<F, IxDyn>,
dweight_hh: Array<F, IxDyn>,
dbias_ih: Array<F, IxDyn>,
dbias_hh: Array<F, IxDyn>,
input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
hidden_states_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + 'static + NumAssign> RNN<F> {
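/// Creates a new RNN layer.
///
/// `weight_ih` and `weight_hh` are initialized from a uniform distribution on
/// `(-1, 1)` scaled by `1 / sqrt(input_size)` and `1 / sqrt(hidden_size)`
/// respectively; both biases start at zero. Returns an error if either size is
/// zero or a value cannot be converted into `F`.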
pub fn new<R: Rng>(
input_size: usize,
hidden_size: usize,
activation: RecurrentActivation,
rng: &mut R,
) -> Result<Self> {
if input_size == 0 || hidden_size == 0 {
return Err(NeuralError::InvalidArchitecture(
"Input _size and hidden _size must be positive".to_string(),
));
}
let scale_ih = F::from(1.0 / (input_size as f64).sqrt()).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
})?;
let scale_hh = F::from(1.0 / (hidden_size as f64).sqrt()).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert hidden _size scale".to_string())
})?;
let mut weight_ih_vec: Vec<F> = Vec::with_capacity(hidden_size * input_size);
for _ in 0..(hidden_size * input_size) {
let rand_val = rng.random_range(-1.0..1.0);
let val = F::from(rand_val).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
})?;
weight_ih_vec.push(val * scale_ih);
}
let weight_ih = Array::from_shape_vec(IxDyn(&[hidden_size, input_size]), weight_ih_vec)
.map_err(|e| {
NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
})?;
let mut weight_hh_vec: Vec<F> = Vec::with_capacity(hidden_size * hidden_size);
for _ in 0..(hidden_size * hidden_size) {
let rand_val = rng.random_range(-1.0..1.0);
let val = F::from(rand_val).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
})?;
weight_hh_vec.push(val * scale_hh);
}
let weight_hh = Array::from_shape_vec(IxDyn(&[hidden_size, hidden_size]), weight_hh_vec)
.map_err(|e| {
NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
})?;
let bias_ih = Array::zeros(IxDyn(&[hidden_size]));
let bias_hh = Array::zeros(IxDyn(&[hidden_size]));
let dweight_ih = Array::zeros(weight_ih.dim());
let dweight_hh = Array::zeros(weight_hh.dim());
let dbias_ih = Array::zeros(bias_ih.dim());
let dbias_hh = Array::zeros(bias_hh.dim());
Ok(Self {
input_size,
hidden_size,
activation,
weight_ih,
weight_hh,
bias_ih,
bias_hh,
dweight_ih,
dweight_hh,
dbias_ih,
dbias_hh,
input_cache: Arc::new(RwLock::new(None)),
hidden_states_cache: Arc::new(RwLock::new(None)),
})
}
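/// Returns true when `input_size + hidden_size` reaches `RNN_SIMD_THRESHOLD`,
/// in which case the SIMD step implementation is preferred.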
fn should_use_simd(&self) -> bool {
self.input_size + self.hidden_size >= RNN_SIMD_THRESHOLD
}
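/// Performs a single recurrence step on a `[batch_size, features]` input and a
/// `[batch_size, hidden_size]` hidden state, dispatching to the SIMD or naive
/// implementation based on the layer width.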
fn step(&self, x: &ArrayView<F, IxDyn>, h: &ArrayView<F, IxDyn>) -> Result<Array<F, IxDyn>> {
if self.should_use_simd() {
self.step_simd(x, h)
} else {
self.step_naive(x, h)
}
}
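/// SIMD step: each hidden unit is computed as two dot products via `F::simd_dot`
/// (a `weight_ih` row with the input row, a `weight_hh` row with the hidden row),
/// plus the biases, followed by the activation.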
fn step_simd(
&self,
x: &ArrayView<F, IxDyn>,
h: &ArrayView<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
let xshape = x.shape();
let hshape = h.shape();
let batch_size = xshape[0];
if xshape[1] != self.input_size {
return Err(NeuralError::InferenceError(format!(
"Input feature dimension mismatch: expected {}, got {}",
self.input_size, xshape[1]
)));
}
if hshape[1] != self.hidden_size {
return Err(NeuralError::InferenceError(format!(
"Hidden state dimension mismatch: expected {}, got {}",
self.hidden_size, hshape[1]
)));
}
if xshape[0] != hshape[0] {
return Err(NeuralError::InferenceError(format!(
"Batch size mismatch: input has {}, hidden state has {}",
xshape[0], hshape[0]
)));
}
let mut new_h = Array::zeros((batch_size, self.hidden_size));
for b in 0..batch_size {
let x_b = x.slice(scirs2_core::ndarray::s![b, ..]);
let x_view: ArrayView1<F> = x_b.into_dimensionality().expect("input row slice should be 1-D");
let h_b = h.slice(scirs2_core::ndarray::s![b, ..]);
let h_view: ArrayView1<F> = h_b.into_dimensionality().expect("hidden row slice should be 1-D");
for i in 0..self.hidden_size {
let wih_row = self.weight_ih.slice(scirs2_core::ndarray::s![i, ..]);
let wih_view: ArrayView1<F> =
wih_row.into_dimensionality().expect("weight_ih row slice should be 1-D");
let whh_row = self.weight_hh.slice(scirs2_core::ndarray::s![i, ..]);
let whh_view: ArrayView1<F> =
whh_row.into_dimensionality().expect("weight_hh row slice should be 1-D");
let ih_sum = self.bias_ih[i] + F::simd_dot(&wih_view, &x_view);
let hh_sum = self.bias_hh[i] + F::simd_dot(&whh_view, &h_view);
new_h[[b, i]] = self.activation.apply(ih_sum + hh_sum);
}
}
Ok(new_h.into_dyn())
}
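/// Scalar fallback step using explicit loops over the batch, hidden, and
/// feature dimensions; used when the layer is below `RNN_SIMD_THRESHOLD`.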
fn step_naive(
&self,
x: &ArrayView<F, IxDyn>,
h: &ArrayView<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
let xshape = x.shape();
let hshape = h.shape();
let batch_size = xshape[0];
if xshape[1] != self.input_size {
return Err(NeuralError::InferenceError(format!(
"Input feature dimension mismatch: expected {}, got {}",
self.input_size, xshape[1]
)));
}
if hshape[1] != self.hidden_size {
return Err(NeuralError::InferenceError(format!(
"Hidden state dimension mismatch: expected {}, got {}",
self.hidden_size, hshape[1]
)));
}
if xshape[0] != hshape[0] {
return Err(NeuralError::InferenceError(format!(
"Batch size mismatch: input has {}, hidden state has {}",
xshape[0], hshape[0]
)));
}
let mut new_h = Array::zeros((batch_size, self.hidden_size));
for b in 0..batch_size {
for i in 0..self.hidden_size {
let mut ih_sum = self.bias_ih[i];
for j in 0..self.input_size {
ih_sum += self.weight_ih[[i, j]] * x[[b, j]];
}
let mut hh_sum = self.bias_hh[i];
for j in 0..self.hidden_size {
hh_sum += self.weight_hh[[i, j]] * h[[b, j]];
}
new_h[[b, i]] = self.activation.apply(ih_sum + hh_sum);
}
}
Ok(new_h.into_dyn())
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + 'static + NumAssign> Layer<F>
for RNN<F>
{
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
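/// Forward pass over a full sequence.
///
/// Expects input of shape `[batch_size, seq_len, input_size]`, starts from a
/// zero hidden state, and returns all hidden states with shape
/// `[batch_size, seq_len, hidden_size]`. The input and the hidden states are
/// cached for use by `backward`.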
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
if let Ok(mut cache) = self.input_cache.write() {
*cache = Some(input.to_owned());
} else {
return Err(NeuralError::InferenceError(
"Failed to acquire write lock on input cache".to_string(),
));
}
let inputshape = input.shape();
if inputshape.len() != 3 {
return Err(NeuralError::InferenceError(format!(
"Expected 3D input [batch_size, seq_len, features], got {inputshape:?}"
)));
}
let batch_size = inputshape[0];
let seq_len = inputshape[1];
let features = inputshape[2];
if features != self.input_size {
return Err(NeuralError::InferenceError(format!(
"Input features dimension mismatch: expected {}, got {}",
self.input_size, features
)));
}
let mut h = Array::zeros((batch_size, self.hidden_size));
let mut all_hidden_states = Array::zeros((batch_size, seq_len, self.hidden_size));
for t in 0..seq_len {
let x_t = input.slice(scirs2_core::ndarray::s![.., t, ..]);
let x_t_view = x_t.view().into_dyn();
let h_view = h.view().into_dyn();
h = self
.step(&x_t_view, &h_view)?
.into_dimensionality::<Ix2>()
.expect("Operation failed");
for b in 0..batch_size {
for i in 0..self.hidden_size {
all_hidden_states[[b, t, i]] = h[[b, i]];
}
}
}
if let Ok(mut cache) = self.hidden_states_cache.write() {
*cache = Some(all_hidden_states.to_owned().into_dyn());
} else {
return Err(NeuralError::InferenceError(
"Failed to acquire write lock on hidden states cache".to_string(),
));
}
Ok(all_hidden_states.into_dyn())
}
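/// Backward pass (currently a placeholder): verifies that `forward` has cached
/// the input and hidden states, then returns a zero gradient with the shape of
/// `input`. `grad_output` is ignored and the stored parameter gradients are not
/// updated.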
fn backward(
&self,
input: &Array<F, IxDyn>,
_grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
let input_ref = match self.input_cache.read() {
Ok(guard) => guard,
Err(_) => {
return Err(NeuralError::InferenceError(
"Failed to acquire read lock on input cache".to_string(),
))
}
};
let hidden_states_ref = match self.hidden_states_cache.read() {
Ok(guard) => guard,
Err(_) => {
return Err(NeuralError::InferenceError(
"Failed to acquire read lock on hidden states cache".to_string(),
))
}
};
if input_ref.is_none() || hidden_states_ref.is_none() {
return Err(NeuralError::InferenceError(
"No cached values for backward pass. Call forward() first.".to_string(),
));
}
let grad_input = Array::zeros(input.dim());
Ok(grad_input)
}
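/// Simplified parameter update (placeholder): every weight and bias is
/// decremented by `0.001 * learning_rate`, independent of the stored gradients.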
fn update(&mut self, learning_rate: F) -> Result<()> {
let small_change = F::from(0.001).expect("Failed to convert constant to float");
let lr = small_change * learning_rate;
for w in self.weight_ih.iter_mut() {
*w -= lr;
}
for w in self.weight_hh.iter_mut() {
*w -= lr;
}
for b in self.bias_ih.iter_mut() {
*b -= lr;
}
for b in self.bias_hh.iter_mut() {
*b -= lr;
}
Ok(())
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + 'static + NumAssign>
ParamLayer<F> for RNN<F>
{
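/// Returns the parameters in the order `[weight_ih, weight_hh, bias_ih, bias_hh]`,
/// matching the order expected by `set_parameters`.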
fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
vec![
self.weight_ih.clone(),
self.weight_hh.clone(),
self.bias_ih.clone(),
self.bias_hh.clone(),
]
}
fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
vec![
self.dweight_ih.clone(),
self.dweight_hh.clone(),
self.dbias_ih.clone(),
self.dbias_hh.clone(),
]
}
fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
if params.len() != 4 {
return Err(NeuralError::InvalidArchitecture(format!(
"Expected 4 parameters, got {}",
params.len()
)));
}
if params[0].shape() != self.weight_ih.shape() {
return Err(NeuralError::InvalidArchitecture(format!(
"Weight_ih shape mismatch: expected {:?}, got {:?}",
self.weight_ih.shape(),
params[0].shape()
)));
}
if params[1].shape() != self.weight_hh.shape() {
return Err(NeuralError::InvalidArchitecture(format!(
"Weight_hh shape mismatch: expected {:?}, got {:?}",
self.weight_hh.shape(),
params[1].shape()
)));
}
if params[2].shape() != self.bias_ih.shape() {
return Err(NeuralError::InvalidArchitecture(format!(
"Bias_ih shape mismatch: expected {:?}, got {:?}",
self.bias_ih.shape(),
params[2].shape()
)));
}
if params[3].shape() != self.bias_hh.shape() {
return Err(NeuralError::InvalidArchitecture(format!(
"Bias_hh shape mismatch: expected {:?}, got {:?}",
self.bias_hh.shape(),
params[3].shape()
)));
}
self.weight_ih = params[0].clone();
self.weight_hh = params[1].clone();
self.bias_ih = params[2].clone();
self.bias_hh = params[3].clone();
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array3;
use scirs2_core::random::SeedableRng;
#[test]
fn test_rnnshape() {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
let rnn = RNN::<f64>::new(10, 20, RecurrentActivation::Tanh, &mut rng)
.expect("failed to construct RNN");
let batch_size = 2;
let seq_len = 5;
let input_size = 10;
let input = Array3::<f64>::from_elem((batch_size, seq_len, input_size), 0.1).into_dyn();
let output = rnn.forward(&input).expect("forward pass failed");
assert_eq!(output.shape(), &[batch_size, seq_len, 20]);
}
#[test]
fn test_recurrent_activations() {
let tanh = RecurrentActivation::Tanh;
let sigmoid = RecurrentActivation::Sigmoid;
let relu = RecurrentActivation::ReLU;
assert_eq!(tanh.apply(0.0f64), 0.0f64.tanh());
assert_eq!(tanh.apply(1.0f64), 1.0f64.tanh());
assert_eq!(tanh.apply(-1.0f64), (-1.0f64).tanh());
assert_eq!(sigmoid.apply(0.0f64), 0.5f64);
assert!((sigmoid.apply(10.0f64) - 1.0).abs() < 1e-4);
assert!(sigmoid.apply(-10.0f64).abs() < 1e-4);
assert_eq!(relu.apply(1.0f64), 1.0f64);
assert_eq!(relu.apply(-1.0f64), 0.0f64);
assert_eq!(relu.apply(0.0f64), 0.0f64);
}
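// Consistency check between the SIMD and naive step implementations.
// This is an illustrative addition: it exercises the private `step_simd` and
// `step_naive` methods on identical inputs and asserts that they agree up to
// floating-point rounding.
#[test]
fn test_step_simd_matches_naive() {
let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([7; 32]);
let rnn = RNN::<f64>::new(6, 8, RecurrentActivation::Tanh, &mut rng)
.expect("failed to construct RNN");
let batch_size = 3;
let x: Array<f64, IxDyn> = Array::from_elem(IxDyn(&[batch_size, 6]), 0.25);
let h: Array<f64, IxDyn> = Array::from_elem(IxDyn(&[batch_size, 8]), -0.5);
let out_simd = rnn.step_simd(&x.view(), &h.view()).expect("SIMD step failed");
let out_naive = rnn.step_naive(&x.view(), &h.view()).expect("naive step failed");
for (a, b) in out_simd.iter().zip(out_naive.iter()) {
assert!((a - b).abs() < 1e-6);
}
}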
}