use crate::error::{NeuralError, Result};
use crate::layers::recurrent::{LstmGateCache, LstmStepOutput};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, ArrayView, ArrayView1, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Distribution, Uniform};
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
/// Minimum combined feature width (`input_size + hidden_size`) at which the
/// SIMD step kernel is preferred over the naive scalar kernel.
const LSTM_SIMD_THRESHOLD: usize = 32;
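/// Configuration for an [`LSTM`] layer; the fields mirror the size arguments
/// of [`LSTM::new`].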
#[derive(Debug, Clone)]
pub struct LSTMConfig {
pub input_size: usize,
pub hidden_size: usize,
}
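/// Long Short-Term Memory (LSTM) recurrent layer.
///
/// For each time step `t` the layer computes (`sigma` is the logistic
/// sigmoid, `*` is element-wise multiplication):
///
/// ```text
/// i_t = sigma(W_ii x_t + b_ii + W_hi h_{t-1} + b_hi)     // input gate
/// f_t = sigma(W_if x_t + b_if + W_hf h_{t-1} + b_hf)     // forget gate
/// g_t = tanh (W_ig x_t + b_ig + W_hg h_{t-1} + b_hg)     // cell candidate
/// o_t = sigma(W_io x_t + b_io + W_ho h_{t-1} + b_ho)     // output gate
/// c_t = f_t * c_{t-1} + i_t * g_t
/// h_t = o_t * tanh(c_t)
/// ```
///
/// A minimal usage sketch (not compiled; the RNG constructor shown is an
/// assumption, substitute whatever `scirs2_core::random` actually provides):
///
/// ```ignore
/// let mut rng = scirs2_core::random::rng(); // assumed RNG constructor
/// let lstm: LSTM<f32> = LSTM::new(4, 8, &mut rng)?;
/// // Input is [batch, seq_len, features]; output is [batch, seq_len, hidden].
/// let x = Array::<f32, _>::zeros(IxDyn(&[2, 5, 4]));
/// let h = lstm.forward(&x)?;
/// assert_eq!(h.shape(), &[2, 5, 8]);
/// ```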
pub struct LSTM<F: Float + Debug + Send + Sync + NumAssign> {
input_size: usize,
hidden_size: usize,
    // Input gate (i) parameters.
    weight_ii: Array<F, IxDyn>,
    weight_hi: Array<F, IxDyn>,
    bias_ii: Array<F, IxDyn>,
    bias_hi: Array<F, IxDyn>,
    // Forget gate (f) parameters.
    weight_if: Array<F, IxDyn>,
    weight_hf: Array<F, IxDyn>,
    bias_if: Array<F, IxDyn>,
    bias_hf: Array<F, IxDyn>,
    // Cell candidate (g) parameters.
    weight_ig: Array<F, IxDyn>,
    weight_hg: Array<F, IxDyn>,
    bias_ig: Array<F, IxDyn>,
    bias_hg: Array<F, IxDyn>,
    // Output gate (o) parameters.
    weight_io: Array<F, IxDyn>,
    weight_ho: Array<F, IxDyn>,
    bias_io: Array<F, IxDyn>,
    bias_ho: Array<F, IxDyn>,
    // Zero-initialized gradient buffers, one per parameter, in
    // `get_parameters()` order.
    #[allow(dead_code)]
    gradients: Arc<RwLock<Vec<Array<F, IxDyn>>>>,
    // Activation caches populated by `forward()` for use in `backward()`.
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    hidden_states_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    cell_states_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    #[allow(dead_code)]
    gate_cache: LstmGateCache<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
LSTM<F>
{
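    /// Creates a new LSTM layer with `input_size` input features and
    /// `hidden_size` hidden units.
    ///
    /// Weights are drawn uniformly from `[-1/sqrt(fan_in), 1/sqrt(fan_in)]`,
    /// where `fan_in` is `input_size` for input-to-hidden matrices and
    /// `hidden_size` for hidden-to-hidden matrices. Forget-gate biases are
    /// initialized to 1 so the cell state is retained early in training; all
    /// other biases start at zero.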
pub fn new<R: scirs2_core::random::Rng>(
input_size: usize,
hidden_size: usize,
rng: &mut R,
) -> Result<Self> {
        if input_size == 0 || hidden_size == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "input_size and hidden_size must be positive".to_string(),
            ));
        }
let scale_ih = F::from(1.0 / (input_size as f64).sqrt()).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
})?;
let scale_hh = F::from(1.0 / (hidden_size as f64).sqrt()).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
})?;
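        // Draws a `rows x cols` matrix with entries sampled from
        // `Uniform(-1, 1)` and multiplied by `scale`.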
let mut create_weight_matrix = |rows: usize,
cols: usize,
scale: F|
-> Result<Array<F, IxDyn>> {
let mut weights_vec: Vec<F> = Vec::with_capacity(rows * cols);
let uniform = Uniform::new(-1.0, 1.0).map_err(|e| {
NeuralError::InvalidArchitecture(format!(
"Failed to create uniform distribution: {e}"
))
})?;
for _ in 0..(rows * cols) {
let rand_val = uniform.sample(rng);
let val = F::from(rand_val).ok_or_else(|| {
NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
})?;
weights_vec.push(val * scale);
}
Array::from_shape_vec(IxDyn(&[rows, cols]), weights_vec).map_err(|e| {
NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
})
};
let weight_ii = create_weight_matrix(hidden_size, input_size, scale_ih)?;
let weight_hi = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
let bias_ii: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let bias_hi: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let weight_if = create_weight_matrix(hidden_size, input_size, scale_ih)?;
let weight_hf = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
let mut bias_if: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let mut bias_hf: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let one = F::one();
for i in 0..hidden_size {
bias_if[i] = one;
bias_hf[i] = one;
}
let weight_ig = create_weight_matrix(hidden_size, input_size, scale_ih)?;
let weight_hg = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
let bias_ig: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let bias_hg: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let weight_io = create_weight_matrix(hidden_size, input_size, scale_ih)?;
let weight_ho = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
let bias_io: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
let bias_ho: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
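        // Zero-initialized gradient buffers matching each parameter's shape,
        // in the same order as `get_parameters()`.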
let gradients = vec![
Array::zeros(weight_ii.dim()),
Array::zeros(weight_hi.dim()),
Array::zeros(bias_ii.dim()),
Array::zeros(bias_hi.dim()),
Array::zeros(weight_if.dim()),
Array::zeros(weight_hf.dim()),
Array::zeros(bias_if.dim()),
Array::zeros(bias_hf.dim()),
Array::zeros(weight_ig.dim()),
Array::zeros(weight_hg.dim()),
Array::zeros(bias_ig.dim()),
Array::zeros(bias_hg.dim()),
Array::zeros(weight_io.dim()),
Array::zeros(weight_ho.dim()),
Array::zeros(bias_io.dim()),
Array::zeros(bias_ho.dim()),
];
Ok(Self {
input_size,
hidden_size,
weight_ii,
weight_hi,
bias_ii,
bias_hi,
weight_if,
weight_hf,
bias_if,
bias_hf,
weight_ig,
weight_hg,
bias_ig,
bias_hg,
weight_io,
weight_ho,
bias_io,
bias_ho,
gradients: Arc::new(RwLock::new(gradients)),
input_cache: Arc::new(RwLock::new(None)),
hidden_states_cache: Arc::new(RwLock::new(None)),
cell_states_cache: Arc::new(RwLock::new(None)),
gate_cache: Arc::new(RwLock::new(None)),
})
}
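    /// Chooses the SIMD step kernel once the combined input and hidden width
    /// reaches `LSTM_SIMD_THRESHOLD`; smaller layers use the naive kernel.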
fn should_use_simd(&self) -> bool {
self.input_size + self.hidden_size >= LSTM_SIMD_THRESHOLD
}
    /// Runs a single time step, dispatching to the SIMD or naive kernel.
    fn step(
        &self,
        x: &ArrayView<F, IxDyn>,
        h: &ArrayView<F, IxDyn>,
        c: &ArrayView<F, IxDyn>,
    ) -> Result<LstmStepOutput<F>> {
        if self.should_use_simd() {
            self.step_simd(x, h, c)
        } else {
            self.step_naive(x, h, c)
        }
    }
    /// Validates the `[batch, features]` step shapes shared by both kernels
    /// and returns the batch size.
    fn validate_step_shapes(
        &self,
        x: &ArrayView<F, IxDyn>,
        h: &ArrayView<F, IxDyn>,
        c: &ArrayView<F, IxDyn>,
    ) -> Result<usize> {
        let (x_shape, h_shape, c_shape) = (x.shape(), h.shape(), c.shape());
        if x_shape.len() != 2 || h_shape.len() != 2 || c_shape.len() != 2 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 2D step inputs [batch, features], got {}D/{}D/{}D",
                x_shape.len(),
                h_shape.len(),
                c_shape.len()
            )));
        }
        if x_shape[1] != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input feature dimension mismatch: expected {}, got {}",
                self.input_size, x_shape[1]
            )));
        }
        if h_shape[1] != self.hidden_size || c_shape[1] != self.hidden_size {
            return Err(NeuralError::InferenceError(format!(
                "Hidden/cell state dimension mismatch: expected {}, got {}/{}",
                self.hidden_size, h_shape[1], c_shape[1]
            )));
        }
        if x_shape[0] != h_shape[0] || x_shape[0] != c_shape[0] {
            return Err(NeuralError::InferenceError(format!(
                "Batch size mismatch: input has {}, hidden state has {}, cell state has {}",
                x_shape[0], h_shape[0], c_shape[0]
            )));
        }
        Ok(x_shape[0])
    }
    /// SIMD step kernel: gate pre-activations are computed as vectorized dot
    /// products over weight-matrix rows.
    fn step_simd(
        &self,
        x: &ArrayView<F, IxDyn>,
        h: &ArrayView<F, IxDyn>,
        c: &ArrayView<F, IxDyn>,
    ) -> Result<LstmStepOutput<F>> {
        let batch_size = self.validate_step_shapes(x, h, c)?;
let mut i_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut f_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut g_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut o_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut new_c: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut new_h: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
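        // Each gate pre-activation is a bias term plus two dot products
        // (weight row · x_b and weight row · h_b) computed with `F::simd_dot`
        // over 1-D row slices, so the `into_dimensionality` conversions in
        // the loop body cannot fail.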
for b in 0..batch_size {
            let x_b = x.slice(scirs2_core::ndarray::s![b, ..]);
            let x_view: ArrayView1<F> = x_b.into_dimensionality().expect("batch row is 1-D");
            let h_b = h.slice(scirs2_core::ndarray::s![b, ..]);
            let h_view: ArrayView1<F> = h_b.into_dimensionality().expect("batch row is 1-D");
            for i in 0..self.hidden_size {
                let wii_row = self.weight_ii.slice(scirs2_core::ndarray::s![i, ..]);
                let wii_view: ArrayView1<F> =
                    wii_row.into_dimensionality().expect("weight row is 1-D");
                let whi_row = self.weight_hi.slice(scirs2_core::ndarray::s![i, ..]);
                let whi_view: ArrayView1<F> =
                    whi_row.into_dimensionality().expect("weight row is 1-D");
                let wif_row = self.weight_if.slice(scirs2_core::ndarray::s![i, ..]);
                let wif_view: ArrayView1<F> =
                    wif_row.into_dimensionality().expect("weight row is 1-D");
                let whf_row = self.weight_hf.slice(scirs2_core::ndarray::s![i, ..]);
                let whf_view: ArrayView1<F> =
                    whf_row.into_dimensionality().expect("weight row is 1-D");
                let wig_row = self.weight_ig.slice(scirs2_core::ndarray::s![i, ..]);
                let wig_view: ArrayView1<F> =
                    wig_row.into_dimensionality().expect("weight row is 1-D");
                let whg_row = self.weight_hg.slice(scirs2_core::ndarray::s![i, ..]);
                let whg_view: ArrayView1<F> =
                    whg_row.into_dimensionality().expect("weight row is 1-D");
                let wio_row = self.weight_io.slice(scirs2_core::ndarray::s![i, ..]);
                let wio_view: ArrayView1<F> =
                    wio_row.into_dimensionality().expect("weight row is 1-D");
                let who_row = self.weight_ho.slice(scirs2_core::ndarray::s![i, ..]);
                let who_view: ArrayView1<F> =
                    who_row.into_dimensionality().expect("weight row is 1-D");
let i_sum = self.bias_ii[i]
+ self.bias_hi[i]
+ F::simd_dot(&wii_view, &x_view)
+ F::simd_dot(&whi_view, &h_view);
i_gate[[b, i]] = F::one() / (F::one() + (-i_sum).exp());
let f_sum = self.bias_if[i]
+ self.bias_hf[i]
+ F::simd_dot(&wif_view, &x_view)
+ F::simd_dot(&whf_view, &h_view);
f_gate[[b, i]] = F::one() / (F::one() + (-f_sum).exp());
let g_sum = self.bias_ig[i]
+ self.bias_hg[i]
+ F::simd_dot(&wig_view, &x_view)
+ F::simd_dot(&whg_view, &h_view);
g_gate[[b, i]] = g_sum.tanh();
let o_sum = self.bias_io[i]
+ self.bias_ho[i]
+ F::simd_dot(&wio_view, &x_view)
+ F::simd_dot(&who_view, &h_view);
o_gate[[b, i]] = F::one() / (F::one() + (-o_sum).exp());
new_c[[b, i]] = f_gate[[b, i]] * c[[b, i]] + i_gate[[b, i]] * g_gate[[b, i]];
new_h[[b, i]] = o_gate[[b, i]] * new_c[[b, i]].tanh();
}
}
Ok((
new_h.into_dyn(),
new_c.into_dyn(),
(
i_gate.into_dyn(),
f_gate.into_dyn(),
g_gate.into_dyn(),
o_gate.into_dyn(),
),
))
}
    /// Naive step kernel: gate pre-activations are accumulated with explicit
    /// scalar loops; used for layers below the SIMD threshold.
    fn step_naive(
        &self,
        x: &ArrayView<F, IxDyn>,
        h: &ArrayView<F, IxDyn>,
        c: &ArrayView<F, IxDyn>,
    ) -> Result<LstmStepOutput<F>> {
        let batch_size = self.validate_step_shapes(x, h, c)?;
let mut i_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut f_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut g_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut o_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut new_c: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
let mut new_h: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
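        // Accumulate the four gate pre-activations for each (batch, unit)
        // pair with explicit loops over the input and hidden dimensions.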
for b in 0..batch_size {
for i in 0..self.hidden_size {
let mut i_sum = self.bias_ii[i] + self.bias_hi[i];
for j in 0..self.input_size {
i_sum += self.weight_ii[[i, j]] * x[[b, j]];
}
for j in 0..self.hidden_size {
i_sum += self.weight_hi[[i, j]] * h[[b, j]];
}
i_gate[[b, i]] = F::one() / (F::one() + (-i_sum).exp());
let mut f_sum = self.bias_if[i] + self.bias_hf[i];
for j in 0..self.input_size {
f_sum += self.weight_if[[i, j]] * x[[b, j]];
}
for j in 0..self.hidden_size {
f_sum += self.weight_hf[[i, j]] * h[[b, j]];
}
f_gate[[b, i]] = F::one() / (F::one() + (-f_sum).exp());
let mut g_sum = self.bias_ig[i] + self.bias_hg[i];
for j in 0..self.input_size {
g_sum += self.weight_ig[[i, j]] * x[[b, j]];
}
for j in 0..self.hidden_size {
g_sum += self.weight_hg[[i, j]] * h[[b, j]];
}
g_gate[[b, i]] = g_sum.tanh();
let mut o_sum = self.bias_io[i] + self.bias_ho[i];
for j in 0..self.input_size {
o_sum += self.weight_io[[i, j]] * x[[b, j]];
}
for j in 0..self.hidden_size {
o_sum += self.weight_ho[[i, j]] * h[[b, j]];
}
o_gate[[b, i]] = F::one() / (F::one() + (-o_sum).exp());
new_c[[b, i]] = f_gate[[b, i]] * c[[b, i]] + i_gate[[b, i]] * g_gate[[b, i]];
new_h[[b, i]] = o_gate[[b, i]] * new_c[[b, i]].tanh();
}
}
Ok((
new_h.into_dyn(),
new_c.into_dyn(),
(
i_gate.into_dyn(),
f_gate.into_dyn(),
g_gate.into_dyn(),
o_gate.into_dyn(),
),
))
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
for LSTM<F>
{
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
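    /// Runs the full sequence through the LSTM.
    ///
    /// `input` must be `[batch_size, seq_len, input_size]`; the returned
    /// hidden-state sequence is `[batch_size, seq_len, hidden_size]`. The
    /// input, hidden states, and cell states are cached for `backward()`.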
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        *self.input_cache.write().map_err(|_| {
            NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            )
        })? = Some(input.clone());
        let input_shape = input.shape();
        if input_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, features], got {input_shape:?}"
            )));
        }
        let batch_size = input_shape[0];
        let seq_len = input_shape[1];
        let features = input_shape[2];
        if features != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input features dimension mismatch: expected {}, got {}",
                self.input_size, features
            )));
        }
        // Hidden and cell states start at zero; per-step states are collected
        // for caching and for returning the full hidden-state sequence.
        let mut h = Array::zeros((batch_size, self.hidden_size));
        let mut c = Array::zeros((batch_size, self.hidden_size));
        let mut all_hidden_states = Array::zeros((batch_size, seq_len, self.hidden_size));
        let mut all_cell_states = Array::zeros((batch_size, seq_len, self.hidden_size));
        let mut all_gates = Vec::with_capacity(seq_len);
for t in 0..seq_len {
let x_t = input.slice(scirs2_core::ndarray::s![.., t, ..]);
let x_t_view = x_t.view().into_dyn();
let h_view = h.view().into_dyn();
let c_view = c.view().into_dyn();
let (new_h, new_c, gates) = self.step(&x_t_view, &h_view, &c_view)?;
            h = new_h
                .into_dimensionality::<Ix2>()
                .expect("step outputs are [batch, hidden]");
            c = new_c
                .into_dimensionality::<Ix2>()
                .expect("step outputs are [batch, hidden]");
            // Gate activations are collected per step; a full BPTT backward
            // would consume these (they are not yet written to `gate_cache`).
            all_gates.push(gates);
for b in 0..batch_size {
for i in 0..self.hidden_size {
all_hidden_states[[b, t, i]] = h[[b, i]];
all_cell_states[[b, t, i]] = c[[b, i]];
}
}
}
        *self.hidden_states_cache.write().map_err(|_| {
            NeuralError::InferenceError(
                "Failed to acquire write lock on hidden states cache".to_string(),
            )
        })? = Some(all_hidden_states.clone().into_dyn());
        *self.cell_states_cache.write().map_err(|_| {
            NeuralError::InferenceError(
                "Failed to acquire write lock on cell states cache".to_string(),
            )
        })? = Some(all_cell_states.into_dyn());
Ok(all_hidden_states.into_dyn())
}
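    /// Backward pass. Full backpropagation through time is not implemented:
    /// the cached activations are only checked for presence, and a zero
    /// gradient with the input's shape is returned as a placeholder.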
fn backward(
&self,
input: &Array<F, IxDyn>,
_grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
let input_ref = self.input_cache.read().map_err(|_| {
NeuralError::InferenceError("Failed to acquire read lock on input cache".to_string())
})?;
let hidden_states_ref = self.hidden_states_cache.read().map_err(|_| {
NeuralError::InferenceError(
"Failed to acquire read lock on hidden states cache".to_string(),
)
})?;
let cell_states_ref = self.cell_states_cache.read().map_err(|_| {
NeuralError::InferenceError(
"Failed to acquire read lock on cell states cache".to_string(),
)
})?;
if input_ref.is_none() || hidden_states_ref.is_none() || cell_states_ref.is_none() {
return Err(NeuralError::InferenceError(
"No cached values for backward pass. Call forward() first.".to_string(),
));
}
let grad_input = Array::zeros(input.dim());
Ok(grad_input)
}
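    /// Placeholder update: every weight and bias is decreased by the constant
    /// `0.001 * learning_rate`. A real optimizer would consume the gradients
    /// produced by `backward()` instead.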
    fn update(&mut self, learning_rate: F) -> Result<()> {
        let small_change = F::from(0.001).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert constant to float".to_string())
        })?;
        let lr = small_change * learning_rate;
        // Apply the same constant decrement to every element of a parameter.
        let update_param = |param: &mut Array<F, IxDyn>| {
            for w in param.iter_mut() {
                *w -= lr;
            }
        };
update_param(&mut self.weight_ii);
update_param(&mut self.weight_hi);
update_param(&mut self.bias_ii);
update_param(&mut self.bias_hi);
update_param(&mut self.weight_if);
update_param(&mut self.weight_hf);
update_param(&mut self.bias_if);
update_param(&mut self.bias_hf);
update_param(&mut self.weight_ig);
update_param(&mut self.weight_hg);
update_param(&mut self.bias_ig);
update_param(&mut self.bias_hg);
update_param(&mut self.weight_io);
update_param(&mut self.weight_ho);
update_param(&mut self.bias_io);
update_param(&mut self.bias_ho);
Ok(())
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
ParamLayer<F> for LSTM<F>
{
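    /// Returns clones of all 16 parameters in a fixed order: for each gate in
    /// (input, forget, cell, output), `weight_i*`, `weight_h*`, `bias_i*`,
    /// `bias_h*`. `set_parameters` expects the same order.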
fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
vec![
self.weight_ii.clone(),
self.weight_hi.clone(),
self.bias_ii.clone(),
self.bias_hi.clone(),
self.weight_if.clone(),
self.weight_hf.clone(),
self.bias_if.clone(),
self.bias_hf.clone(),
self.weight_ig.clone(),
self.weight_hg.clone(),
self.bias_ig.clone(),
self.bias_hg.clone(),
self.weight_io.clone(),
self.weight_ho.clone(),
self.bias_io.clone(),
self.bias_ho.clone(),
]
}
    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        // `backward` does not yet populate per-parameter gradients, so none
        // are exposed here; zero-initialized buffers live in `self.gradients`.
        Vec::new()
    }
fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
if params.len() != 16 {
return Err(NeuralError::InvalidArchitecture(format!(
"Expected 16 parameters, got {}",
params.len()
)));
}
        let expected_shapes = vec![
self.weight_ii.shape(),
self.weight_hi.shape(),
self.bias_ii.shape(),
self.bias_hi.shape(),
self.weight_if.shape(),
self.weight_hf.shape(),
self.bias_if.shape(),
self.bias_hf.shape(),
self.weight_ig.shape(),
self.weight_hg.shape(),
self.bias_ig.shape(),
self.bias_hg.shape(),
self.weight_io.shape(),
self.weight_ho.shape(),
self.bias_io.shape(),
self.bias_ho.shape(),
];
        for (i, (param, expected)) in params.iter().zip(expected_shapes.iter()).enumerate() {
if param.shape() != *expected {
return Err(NeuralError::InvalidArchitecture(format!(
"Parameter {} shape mismatch: expected {:?}, got {:?}",
i,
expected,
param.shape()
)));
}
}
self.weight_ii = params[0].clone();
self.weight_hi = params[1].clone();
self.bias_ii = params[2].clone();
self.bias_hi = params[3].clone();
self.weight_if = params[4].clone();
self.weight_hf = params[5].clone();
self.bias_if = params[6].clone();
self.bias_hf = params[7].clone();
self.weight_ig = params[8].clone();
self.weight_hg = params[9].clone();
self.bias_ig = params[10].clone();
self.bias_hg = params[11].clone();
self.weight_io = params[12].clone();
self.weight_ho = params[13].clone();
self.bias_io = params[14].clone();
self.bias_ho = params[15].clone();
Ok(())
}
}