use crate::error::{NeuralError, Result};
use crate::layers::recurrent::{GruForwardOutput, GruGateCache};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, ArrayView, ArrayView1, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Distribution, Uniform};
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
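/// Combined input + hidden size at which the SIMD step path is used.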
const GRU_SIMD_THRESHOLD: usize = 32;
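/// Configuration for constructing a [`GRU`] layer: the input feature size
/// and the hidden state size.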
#[derive(Debug, Clone)]
pub struct GRUConfig {
pub input_size: usize,
pub hidden_size: usize,
}
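/// Gated Recurrent Unit (GRU) layer.
///
/// Each time step computes, with sigmoid `σ` and elementwise product `⊙`:
///
/// ```text
/// r_t = σ(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)    // reset gate
/// z_t = σ(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)    // update gate
/// n_t = tanh(W_in x_t + b_in + r_t ⊙ (W_hn h_{t-1} + b_hn))
/// h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1}
/// ```
///
/// A minimal usage sketch, mirroring the shape test at the bottom of this
/// file (marked `ignore` since the doctest environment is not assumed;
/// the `Layer` trait must be in scope for `forward`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array3;
/// use scirs2_core::random::{rngs::SmallRng, SeedableRng};
///
/// let mut rng = SmallRng::from_seed([42; 32]);
/// // 10 input features, 20 hidden units.
/// let gru = GRU::<f64>::new(10, 20, &mut rng).expect("valid sizes");
/// // Input is [batch, seq_len, features].
/// let input = Array3::<f64>::from_elem((2, 5, 10), 0.1).into_dyn();
/// let output = gru.forward(&input).expect("forward pass");
/// assert_eq!(output.shape(), &[2, 5, 20]);
/// ```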
pub struct GRU<F: Float + Debug + NumAssign> {
input_size: usize,
hidden_size: usize,
weight_ir: Array<F, IxDyn>,
weight_hr: Array<F, IxDyn>,
bias_ir: Array<F, IxDyn>,
bias_hr: Array<F, IxDyn>,
weight_iz: Array<F, IxDyn>,
weight_hz: Array<F, IxDyn>,
bias_iz: Array<F, IxDyn>,
bias_hz: Array<F, IxDyn>,
weight_in: Array<F, IxDyn>,
weight_hn: Array<F, IxDyn>,
bias_in: Array<F, IxDyn>,
bias_hn: Array<F, IxDyn>,
#[allow(dead_code)]
gradients: RwLock<Vec<Array<F, IxDyn>>>,
input_cache: RwLock<Option<Array<F, IxDyn>>>,
hidden_states_cache: RwLock<Option<Array<F, IxDyn>>>,
#[allow(dead_code)]
gate_cache: GruGateCache<F>,
}
impl<F: Float + Debug + ScalarOperand + SimdUnifiedOps + 'static + NumAssign> GRU<F> {
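    /// Creates a GRU layer with `input_size` input features and
    /// `hidden_size` hidden units.
    ///
    /// Weights are drawn uniformly from `[-s, s]` with `s = 1/sqrt(fan_in)`
    /// (fan-in scaling); all biases start at zero.
    ///
    /// # Errors
    ///
    /// Returns `NeuralError::InvalidArchitecture` if either size is zero.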
pub fn new<R: scirs2_core::random::Rng>(
input_size: usize,
hidden_size: usize,
rng: &mut R,
) -> Result<Self> {
if input_size == 0 || hidden_size == 0 {
return Err(NeuralError::InvalidArchitecture(
"Input _size and hidden _size must be positive".to_string(),
));
}
let scale_ih = F::from(1.0 / (input_size as f64).sqrt()).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
})?;
let scale_hh = F::from(1.0 / (hidden_size as f64).sqrt()).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert hidden _size scale".to_string())
})?;
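        // Helper: sample a `rows x cols` matrix uniformly from [-scale, scale].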
let mut create_weight_matrix = |rows: usize,
cols: usize,
scale: F|
-> Result<Array<F, IxDyn>> {
let mut weights_vec: Vec<F> = Vec::with_capacity(rows * cols);
let uniform = Uniform::new(-1.0, 1.0).map_err(|e| {
NeuralError::InvalidArchitecture(format!(
"Failed to create uniform distribution: {e}"
))
})?;
for _ in 0..(rows * cols) {
let rand_val = uniform.sample(rng);
let val = F::from(rand_val).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
})?;
weights_vec.push(val * scale);
}
Array::from_shape_vec(IxDyn(&[rows, cols]), weights_vec).map_err(|e| {
NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
})
};
let weight_ir = create_weight_matrix(hidden_size, input_size, scale_ih)?;
let weight_hr = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
let bias_ir: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let bias_hr: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let weight_iz = create_weight_matrix(hidden_size, input_size, scale_ih)?;
let weight_hz = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
let bias_iz: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let bias_hz: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let weight_in = create_weight_matrix(hidden_size, input_size, scale_ih)?;
let weight_hn = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
let bias_in: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let bias_hn: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let gradients = vec![
Array::zeros(weight_ir.dim()),
Array::zeros(weight_hr.dim()),
Array::zeros(bias_ir.dim()),
Array::zeros(bias_hr.dim()),
Array::zeros(weight_iz.dim()),
Array::zeros(weight_hz.dim()),
Array::zeros(bias_iz.dim()),
Array::zeros(bias_hz.dim()),
Array::zeros(weight_in.dim()),
Array::zeros(weight_hn.dim()),
Array::zeros(bias_in.dim()),
Array::zeros(bias_hn.dim()),
];
Ok(Self {
input_size,
hidden_size,
weight_ir,
weight_hr,
bias_ir,
bias_hr,
weight_iz,
weight_hz,
bias_iz,
bias_hz,
weight_in,
weight_hn,
bias_in,
bias_hn,
gradients: RwLock::new(gradients),
input_cache: RwLock::new(None),
hidden_states_cache: RwLock::new(None),
gate_cache: Arc::new(RwLock::new(None)),
})
}
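    /// The SIMD step path pays off only when the dot products are long
    /// enough; below the threshold the naive scalar loops are used.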
fn should_use_simd(&self) -> bool {
self.input_size + self.hidden_size >= GRU_SIMD_THRESHOLD
}
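    /// Runs one GRU time step for a whole batch, dispatching to the SIMD or
    /// naive kernel. Returns the new hidden state plus the reset, update, and
    /// candidate gate activations.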
fn step(
&self,
x: &ArrayView<F, IxDyn>,
h: &ArrayView<F, IxDyn>,
) -> Result<GruForwardOutput<F>> {
if self.should_use_simd() {
self.step_simd(x, h)
} else {
self.step_naive(x, h)
}
}
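    /// SIMD step: each gate pre-activation is computed as a pair of
    /// `simd_dot` row products instead of scalar inner loops.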
fn step_simd(
&self,
x: &ArrayView<F, IxDyn>,
h: &ArrayView<F, IxDyn>,
) -> Result<GruForwardOutput<F>> {
        let x_shape = x.shape();
        let h_shape = h.shape();
        let batch_size = x_shape[0];
        if x_shape[1] != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input feature dimension mismatch: expected {}, got {}",
                self.input_size, x_shape[1]
            )));
        }
        if h_shape[1] != self.hidden_size {
            return Err(NeuralError::InferenceError(format!(
                "Hidden state dimension mismatch: expected {}, got {}",
                self.hidden_size, h_shape[1]
            )));
        }
        if x_shape[0] != h_shape[0] {
            return Err(NeuralError::InferenceError(format!(
                "Batch size mismatch: input has {}, hidden state has {}",
                x_shape[0], h_shape[0]
            )));
        }
let mut r_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut z_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut n_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut new_h: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        for b in 0..batch_size {
            let x_view: ArrayView1<F> = x
                .slice(scirs2_core::ndarray::s![b, ..])
                .into_dimensionality()
                .expect("batch row of x is 1-D");
            let h_view: ArrayView1<F> = h
                .slice(scirs2_core::ndarray::s![b, ..])
                .into_dimensionality()
                .expect("batch row of h is 1-D");
            for i in 0..self.hidden_size {
                let wir_view: ArrayView1<F> = self
                    .weight_ir
                    .slice(scirs2_core::ndarray::s![i, ..])
                    .into_dimensionality()
                    .expect("weight_ir row is 1-D");
                let whr_view: ArrayView1<F> = self
                    .weight_hr
                    .slice(scirs2_core::ndarray::s![i, ..])
                    .into_dimensionality()
                    .expect("weight_hr row is 1-D");
                let wiz_view: ArrayView1<F> = self
                    .weight_iz
                    .slice(scirs2_core::ndarray::s![i, ..])
                    .into_dimensionality()
                    .expect("weight_iz row is 1-D");
                let whz_view: ArrayView1<F> = self
                    .weight_hz
                    .slice(scirs2_core::ndarray::s![i, ..])
                    .into_dimensionality()
                    .expect("weight_hz row is 1-D");
                let win_view: ArrayView1<F> = self
                    .weight_in
                    .slice(scirs2_core::ndarray::s![i, ..])
                    .into_dimensionality()
                    .expect("weight_in row is 1-D");
                let whn_view: ArrayView1<F> = self
                    .weight_hn
                    .slice(scirs2_core::ndarray::s![i, ..])
                    .into_dimensionality()
                    .expect("weight_hn row is 1-D");
                // Reset gate: r = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
                let r_sum = self.bias_ir[i]
                    + self.bias_hr[i]
                    + F::simd_dot(&wir_view, &x_view)
                    + F::simd_dot(&whr_view, &h_view);
                r_gate[[b, i]] = F::one() / (F::one() + (-r_sum).exp());
                // Update gate: z = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
                let z_sum = self.bias_iz[i]
                    + self.bias_hz[i]
                    + F::simd_dot(&wiz_view, &x_view)
                    + F::simd_dot(&whz_view, &h_view);
                z_gate[[b, i]] = F::one() / (F::one() + (-z_sum).exp());
                // Candidate: n = tanh(W_in x + b_in + r * (W_hn h + b_hn))
                let n_sum = self.bias_in[i] + F::simd_dot(&win_view, &x_view);
                let hn_sum = self.bias_hn[i] + F::simd_dot(&whn_view, &h_view);
                n_gate[[b, i]] = (n_sum + r_gate[[b, i]] * hn_sum).tanh();
                // Blend: h' = (1 - z) * n + z * h
                new_h[[b, i]] =
                    (F::one() - z_gate[[b, i]]) * n_gate[[b, i]] + z_gate[[b, i]] * h[[b, i]];
            }
        }
Ok((
new_h.into_dyn(),
(r_gate.into_dyn(), z_gate.into_dyn(), n_gate.into_dyn()),
))
}
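    /// Scalar fallback step with plain nested loops; used for small layers
    /// below the SIMD threshold.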
fn step_naive(
&self,
x: &ArrayView<F, IxDyn>,
h: &ArrayView<F, IxDyn>,
) -> Result<GruForwardOutput<F>> {
        let x_shape = x.shape();
        let h_shape = h.shape();
        let batch_size = x_shape[0];
        if x_shape[1] != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input feature dimension mismatch: expected {}, got {}",
                self.input_size, x_shape[1]
            )));
        }
        if h_shape[1] != self.hidden_size {
            return Err(NeuralError::InferenceError(format!(
                "Hidden state dimension mismatch: expected {}, got {}",
                self.hidden_size, h_shape[1]
            )));
        }
        if x_shape[0] != h_shape[0] {
            return Err(NeuralError::InferenceError(format!(
                "Batch size mismatch: input has {}, hidden state has {}",
                x_shape[0], h_shape[0]
            )));
        }
let mut r_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut z_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut n_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut new_h: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
for b in 0..batch_size {
for i in 0..self.hidden_size {
let mut r_sum = self.bias_ir[i] + self.bias_hr[i];
for j in 0..self.input_size {
r_sum += self.weight_ir[[i, j]] * x[[b, j]];
}
for j in 0..self.hidden_size {
r_sum += self.weight_hr[[i, j]] * h[[b, j]];
}
r_gate[[b, i]] = F::one() / (F::one() + (-r_sum).exp());
let mut z_sum = self.bias_iz[i] + self.bias_hz[i];
for j in 0..self.input_size {
z_sum += self.weight_iz[[i, j]] * x[[b, j]];
}
for j in 0..self.hidden_size {
z_sum += self.weight_hz[[i, j]] * h[[b, j]];
}
z_gate[[b, i]] = F::one() / (F::one() + (-z_sum).exp());
let mut n_sum = self.bias_in[i];
for j in 0..self.input_size {
n_sum += self.weight_in[[i, j]] * x[[b, j]];
}
let mut hn_sum = self.bias_hn[i];
for j in 0..self.hidden_size {
hn_sum += self.weight_hn[[i, j]] * h[[b, j]];
}
n_gate[[b, i]] = (n_sum + r_gate[[b, i]] * hn_sum).tanh();
new_h[[b, i]] =
(F::one() - z_gate[[b, i]]) * n_gate[[b, i]] + z_gate[[b, i]] * h[[b, i]];
}
}
Ok((
new_h.into_dyn(),
(r_gate.into_dyn(), z_gate.into_dyn(), n_gate.into_dyn()),
))
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + 'static + NumAssign> Layer<F>
for GRU<F>
{
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
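    /// Runs the GRU over a full sequence. Expects `[batch, seq_len, features]`
    /// input, starts from a zero hidden state, and returns every time step's
    /// hidden state as `[batch, seq_len, hidden]`.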
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        *self.input_cache.write().map_err(|_| {
            NeuralError::InferenceError("Failed to acquire write lock on input cache".to_string())
        })? = Some(input.clone());
        let input_shape = input.shape();
        if input_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, features], got {input_shape:?}"
            )));
        }
        let batch_size = input_shape[0];
        let seq_len = input_shape[1];
        let features = input_shape[2];
        if features != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input features dimension mismatch: expected {}, got {}",
                self.input_size, features
            )));
        }
let mut h = Array::zeros((batch_size, self.hidden_size));
let mut all_hidden_states = Array::zeros((batch_size, seq_len, self.hidden_size));
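        // Gate activations are collected per step but not yet persisted into
        // `gate_cache`.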
let mut all_gates = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let x_t = input.slice(scirs2_core::ndarray::s![.., t, ..]);
let x_t_view = x_t.view().into_dyn();
let h_view = h.view().into_dyn();
            let (new_h, gates) = self.step(&x_t_view, &h_view)?;
            h = new_h
                .into_dimensionality::<Ix2>()
                .expect("step output is [batch, hidden]");
all_gates.push(gates);
for b in 0..batch_size {
for i in 0..self.hidden_size {
all_hidden_states[[b, t, i]] = h[[b, i]];
}
}
}
        *self.hidden_states_cache.write().map_err(|_| {
            NeuralError::InferenceError(
                "Failed to acquire write lock on hidden states cache".to_string(),
            )
        })? = Some(all_hidden_states.clone().into_dyn());
Ok(all_hidden_states.into_dyn())
}
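    /// Backward pass. Currently a stub: it verifies that `forward()` has
    /// cached its inputs, then returns an all-zero input gradient.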
fn backward(
&self,
input: &Array<F, IxDyn>,
_grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
let input_ref = self.input_cache.read().map_err(|_| {
NeuralError::InferenceError("Failed to acquire read lock on input cache".to_string())
})?;
let hidden_states_ref = self.hidden_states_cache.read().map_err(|_| {
NeuralError::InferenceError(
"Failed to acquire read lock on hidden states cache".to_string(),
)
})?;
if input_ref.is_none() || hidden_states_ref.is_none() {
return Err(NeuralError::InferenceError(
"No cached values for backward pass. Call forward() first.".to_string(),
));
}
let grad_input = Array::zeros(input.dim());
Ok(grad_input)
}
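    /// Parameter update. Currently a placeholder: every parameter is nudged
    /// by a fixed fraction of the learning rate instead of taking a real
    /// gradient step.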
    fn update(&mut self, learning_rate: F) -> Result<()> {
        let small_change = F::from(0.001).expect("Failed to convert constant to float");
        let lr = small_change * learning_rate;
let update_param = |param: &mut Array<F, IxDyn>| {
for w in param.iter_mut() {
*w -= lr;
}
};
update_param(&mut self.weight_ir);
update_param(&mut self.weight_hr);
update_param(&mut self.bias_ir);
update_param(&mut self.bias_hr);
update_param(&mut self.weight_iz);
update_param(&mut self.weight_hz);
update_param(&mut self.bias_iz);
update_param(&mut self.bias_hz);
update_param(&mut self.weight_in);
update_param(&mut self.weight_hn);
update_param(&mut self.bias_in);
update_param(&mut self.bias_hn);
Ok(())
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + 'static + NumAssign>
ParamLayer<F> for GRU<F>
{
fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
vec![
self.weight_ir.clone(),
self.weight_hr.clone(),
self.bias_ir.clone(),
self.bias_hr.clone(),
self.weight_iz.clone(),
self.weight_hz.clone(),
self.bias_iz.clone(),
self.bias_hz.clone(),
self.weight_in.clone(),
self.weight_hn.clone(),
self.bias_in.clone(),
self.bias_hn.clone(),
]
}
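    /// Gradient accumulation is not wired up yet, so no gradients are
    /// exposed.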
fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
Vec::new()
}
fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
if params.len() != 12 {
return Err(NeuralError::InvalidArchitecture(format!(
"Expected 12 parameters, got {}",
params.len()
)));
}
        let expected_shapes = [
self.weight_ir.shape(),
self.weight_hr.shape(),
self.bias_ir.shape(),
self.bias_hr.shape(),
self.weight_iz.shape(),
self.weight_hz.shape(),
self.bias_iz.shape(),
self.bias_hz.shape(),
self.weight_in.shape(),
self.weight_hn.shape(),
self.bias_in.shape(),
self.bias_hn.shape(),
];
        for (i, (param, expected)) in params.iter().zip(expected_shapes.iter()).enumerate() {
if param.shape() != *expected {
return Err(NeuralError::InvalidArchitecture(format!(
"Parameter {} shape mismatch: expected {:?}, got {:?}",
i,
expected,
param.shape()
)));
}
}
self.weight_ir = params[0].clone();
self.weight_hr = params[1].clone();
self.bias_ir = params[2].clone();
self.bias_hr = params[3].clone();
self.weight_iz = params[4].clone();
self.weight_hz = params[5].clone();
self.bias_iz = params[6].clone();
self.bias_hz = params[7].clone();
self.weight_in = params[8].clone();
self.weight_hn = params[9].clone();
self.bias_in = params[10].clone();
self.bias_hn = params[11].clone();
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array3;
use scirs2_core::random::rngs::SmallRng;
use scirs2_core::random::SeedableRng;
#[test]
    fn test_gru_shape() {
        let mut rng = SmallRng::from_seed([42; 32]);
        let gru = GRU::<f64>::new(10, 20, &mut rng).expect("failed to construct GRU");
let batch_size = 2;
let seq_len = 5;
let input_size = 10;
let input = Array3::<f64>::from_elem((batch_size, seq_len, input_size), 0.1).into_dyn();
        let output = gru.forward(&input).expect("forward pass failed");
assert_eq!(output.shape(), &[batch_size, seq_len, 20]);
}
}