use std::fmt::Debug;

use scirs2_core::ndarray::{Array, Array2, ArrayView, ArrayView1, ArrayView2, Axis, Dimension};
use scirs2_core::numeric::Float;
use scirs2_core::random::RngExt;

use crate::error::{NeuralError, Result};
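
/// Return type of [`lstm_cell`]: the next hidden state, the next cell state,
/// and a cache of intermediate gate activations `(i, f, g, o, c_next)` for a
/// backward pass.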
type LSTMForwardReturn<F> = (
Array2<F>,
Array2<F>,
(Array2<F>, Array2<F>, Array2<F>, Array2<F>, Array2<F>),
);
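
/// Return type of [`adam_update`]: the updated parameters together with the
/// updated first- and second-moment estimates, as `(w_new, m_new, v_new)`.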
type AdamUpdateReturn<F, D> = (Array<F, D>, Array<F, D>, Array<F, D>);
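
/// Performs a single LSTM cell forward step.
///
/// Gate pre-activations are computed as
/// `gates = x · w_ihᵀ + b_ih + h_prev · w_hhᵀ + b_hh`, split into input (`i`),
/// forget (`f`), cell candidate (`g`), and output (`o`) blocks, and combined
/// as `c_next = f ⊙ c_prev + i ⊙ g` and `h_next = o ⊙ tanh(c_next)`.
///
/// # Arguments
///
/// * `x` - Input of shape `[batch_size, input_size]`
/// * `h_prev` - Previous hidden state of shape `[batch_size, hidden_size]`
/// * `c_prev` - Previous cell state of shape `[batch_size, hidden_size]`
/// * `w_ih` - Input-to-hidden weights of shape `[4 * hidden_size, input_size]`
/// * `w_hh` - Hidden-to-hidden weights of shape `[4 * hidden_size, hidden_size]`
/// * `b_ih` - Input-to-hidden bias of shape `[4 * hidden_size]`
/// * `b_hh` - Hidden-to-hidden bias of shape `[4 * hidden_size]`
///
/// # Example
///
/// A minimal sketch with zero-initialized tensors, shown only to illustrate
/// the expected shapes:
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let (batch, input, hidden) = (2, 3, 4);
/// let x = Array2::<f64>::zeros((batch, input));
/// let h0 = Array2::<f64>::zeros((batch, hidden));
/// let c0 = Array2::<f64>::zeros((batch, hidden));
/// let w_ih = Array2::<f64>::zeros((4 * hidden, input));
/// let w_hh = Array2::<f64>::zeros((4 * hidden, hidden));
/// let b_ih = Array1::<f64>::zeros(4 * hidden);
/// let b_hh = Array1::<f64>::zeros(4 * hidden);
/// let (h1, c1, _cache) = lstm_cell(
///     &x.view(), &h0.view(), &c0.view(),
///     &w_ih.view(), &w_hh.view(), &b_ih.view(), &b_hh.view(),
/// )?;
/// assert_eq!(h1.shape(), &[batch, hidden]);
/// ```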
pub fn lstm_cell<F>(
x: &ArrayView2<F>,
h_prev: &ArrayView2<F>,
c_prev: &ArrayView2<F>,
w_ih: &ArrayView2<F>,
w_hh: &ArrayView2<F>,
b_ih: &ArrayView1<F>,
b_hh: &ArrayView1<F>,
) -> Result<LSTMForwardReturn<F>>
where
F: Float + Debug,
{
let batch_size = x.shape()[0];
let input_size = x.shape()[1];
let hidden_size = h_prev.shape()[1];
if h_prev.shape()[0] != batch_size {
return Err(NeuralError::ShapeMismatch(format!(
"Hidden state batch size mismatch in lstm_cell: x batch_size={}, h_prev batch_size={}",
batch_size,
h_prev.shape()[0]
)));
}
if c_prev.shape()[0] != batch_size || c_prev.shape()[1] != hidden_size {
return Err(NeuralError::ShapeMismatch(format!(
"Cell state shape mismatch in lstm_cell: c_prev shape {:?}, expected [{}, {}]",
c_prev.shape(),
batch_size,
hidden_size
)));
}
if w_ih.shape()[0] != 4 * hidden_size || w_ih.shape()[1] != input_size {
return Err(NeuralError::ShapeMismatch(
format!("Input-to-hidden weights shape mismatch in lstm_cell: w_ih shape {:?}, expected [{}, {}]",
w_ih.shape(), 4 * hidden_size, input_size)
));
}
if w_hh.shape()[0] != 4 * hidden_size || w_hh.shape()[1] != hidden_size {
return Err(NeuralError::ShapeMismatch(
format!("Hidden-to-hidden weights shape mismatch in lstm_cell: w_hh shape {:?}, expected [{}, {}]",
w_hh.shape(), 4 * hidden_size, hidden_size)
));
}
if b_ih.shape()[0] != 4 * hidden_size {
return Err(NeuralError::ShapeMismatch(format!(
"Input-to-hidden bias shape mismatch in lstm_cell: b_ih shape {:?}, expected [{}]",
b_ih.shape(),
4 * hidden_size
)));
}
if b_hh.shape()[0] != 4 * hidden_size {
return Err(NeuralError::ShapeMismatch(format!(
"Hidden-to-hidden bias shape mismatch in lstm_cell: b_hh shape {:?}, expected [{}]",
b_hh.shape(),
4 * hidden_size
)));
}
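    // Pre-activation gate values, laid out as [i | f | g | o] blocks of
    // hidden_size columns each. First accumulate the input contribution:
    // gates = x · w_ihᵀ + b_ih (naive per-element matrix multiply).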
let mut gates = Array2::<F>::zeros((batch_size, 4 * hidden_size));
for b in 0..batch_size {
for i in 0..(4 * hidden_size) {
let mut sum = b_ih[i];
for j in 0..input_size {
sum = sum + w_ih[[i, j]] * x[[b, j]];
}
gates[[b, i]] = sum;
}
}
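    // Then add the recurrent contribution: gates += h_prev · w_hhᵀ + b_hh.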
for b in 0..batch_size {
for i in 0..(4 * hidden_size) {
let mut sum = b_hh[i];
for j in 0..hidden_size {
sum = sum + w_hh[[i, j]] * h_prev[[b, j]];
}
gates[[b, i]] = gates[[b, i]] + sum;
}
}
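    // Apply the gate nonlinearities: sigmoid for the input, forget, and
    // output gates, tanh for the cell candidate.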
let mut i_gate = Array2::<F>::zeros((batch_size, hidden_size));
let mut f_gate = Array2::<F>::zeros((batch_size, hidden_size));
let mut g_gate = Array2::<F>::zeros((batch_size, hidden_size));
let mut o_gate = Array2::<F>::zeros((batch_size, hidden_size));
for b in 0..batch_size {
for h in 0..hidden_size {
i_gate[[b, h]] = sigmoid(gates[[b, h]]);
f_gate[[b, h]] = sigmoid(gates[[b, h + hidden_size]]);
g_gate[[b, h]] = gates[[b, h + 2 * hidden_size]].tanh();
o_gate[[b, h]] = sigmoid(gates[[b, h + 3 * hidden_size]]);
}
}
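    // New cell state: c_next = f ⊙ c_prev + i ⊙ g.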
let mut c_next = Array2::<F>::zeros((batch_size, hidden_size));
for b in 0..batch_size {
for h in 0..hidden_size {
c_next[[b, h]] = f_gate[[b, h]] * c_prev[[b, h]] + i_gate[[b, h]] * g_gate[[b, h]];
}
}
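    // New hidden state: h_next = o ⊙ tanh(c_next).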
let mut h_next = Array2::<F>::zeros((batch_size, hidden_size));
for b in 0..batch_size {
for h in 0..hidden_size {
h_next[[b, h]] = o_gate[[b, h]] * c_next[[b, h]].tanh();
}
}
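    // Cache the gate activations and the new cell state for a backward pass.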
let cache = (i_gate, f_gate, g_gate, o_gate, c_next.clone());
Ok((h_next, c_next, cache))
}
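
/// Logistic sigmoid: `σ(x) = 1 / (1 + e^(-x))`.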
fn sigmoid<F: Float>(x: F) -> F {
F::one() / (F::one() + (-x).exp())
}
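
/// Applies (inverted) dropout to the input.
///
/// During training, each element is zeroed with probability `dropout_rate`
/// and the survivors are scaled by `1 / (1 - dropout_rate)`, so the expected
/// value of every activation is unchanged and no rescaling is needed at
/// inference time. Outside of training the input is returned unchanged with
/// an all-ones mask.
///
/// Returns `(output, mask)`; the mask is what [`dropout_backward`] consumes.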
pub fn dropout<F, D, R>(
x: &ArrayView<F, D>,
dropout_rate: F,
rng: &mut R,
training: bool,
) -> Result<(Array<F, D>, Array<F, D>)>
where
F: Float + Debug + std::fmt::Display,
D: Dimension,
R: scirs2_core::random::Rng,
{
if dropout_rate < F::zero() || dropout_rate >= F::one() {
return Err(NeuralError::InvalidArgument(format!(
"Dropout rate must be in [0, 1) range, got {}",
dropout_rate
)));
}
let mut mask = Array::ones(x.raw_dim());
let mut output = x.to_owned();
    if training {
        // Inverted dropout: zero each element with probability dropout_rate
        // and scale the survivors by 1 / keep_prob.
        let keep_prob = F::one() - dropout_rate;
        let scale = F::one() / keep_prob;
        for val in mask.iter_mut() {
            let rand_val =
                F::from(rng.random::<f64>()).expect("Failed to convert random sample to float");
            *val = if rand_val < dropout_rate {
                F::zero()
            } else {
                scale
            };
        }
for (o, m) in output.iter_mut().zip(mask.iter()) {
*o = *o * *m;
}
}
Ok((output, mask))
}
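
/// Backward pass for [`dropout`].
///
/// Because the forward mask already contains the `1 / keep_prob` scaling,
/// the gradient is simply `dout ⊙ mask`.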
pub fn dropout_backward<F, D>(
dout: &ArrayView<F, D>,
mask: &ArrayView<F, D>,
dropout_rate: F,
) -> Result<Array<F, D>>
where
F: Float + Debug + std::fmt::Display,
D: Dimension,
{
if dropout_rate < F::zero() || dropout_rate >= F::one() {
return Err(NeuralError::InvalidArgument(format!(
"Dropout rate must be in [0, 1) range, got {}",
dropout_rate
)));
}
let mut dx = dout.to_owned();
for (dx_val, mask_val) in dx.iter_mut().zip(mask.iter()) {
*dx_val = *dx_val * *mask_val;
}
Ok(dx)
}
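
/// Computes `log(softmax(x))` along dimension `dim` in a numerically stable
/// way, via `x - max - ln(Σ exp(x - max))` per lane.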
pub fn log_softmax<F, D>(x: &ArrayView<F, D>, dim: usize) -> Result<Array<F, D>>
where
F: Float + Debug + std::fmt::Display,
D: Dimension,
{
if dim >= x.ndim() {
return Err(NeuralError::InvalidArgument(format!(
"Dimension out of bounds in log_softmax: dim={}, ndim={}",
dim,
x.ndim()
)));
}
    let mut output = x.to_owned();
    for mut slice in output.lanes_mut(Axis(dim)) {
        let max_val = *slice
            .iter()
            .max_by(|a, b| a.partial_cmp(b).expect("NaN encountered in log_softmax"))
            .expect("Empty lane in log_softmax");
        // Stable log-softmax: subtract the log-sum-exp directly rather than
        // round-tripping through exp/ln, which underflows to ln(0) = -inf
        // for inputs far below the lane maximum.
        let mut sum_exp = F::zero();
        for val in slice.iter() {
            sum_exp = sum_exp + (*val - max_val).exp();
        }
        let log_sum_exp = max_val + sum_exp.ln();
        for val in slice.iter_mut() {
            *val = *val - log_sum_exp;
        }
    }
Ok(output)
}
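
/// Performs one Adam optimization step (Kingma & Ba, 2014).
///
/// Updates the exponential moving averages of the gradient (`m`) and squared
/// gradient (`v`), applies bias correction for timestep `t >= 1`, and returns
/// the updated parameters along with the new moment estimates.
///
/// # Example
///
/// A minimal sketch of a single step on a small weight matrix; the
/// hyperparameter values are the commonly used defaults, not a
/// recommendation:
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let w = Array2::<f64>::ones((2, 2));
/// let dw = Array2::<f64>::from_elem((2, 2), 0.1);
/// let m = Array2::<f64>::zeros((2, 2));
/// let v = Array2::<f64>::zeros((2, 2));
/// let (w_new, m_new, v_new) =
///     adam_update(&w.view(), &dw.view(), &m.view(), &v.view(),
///                 1e-3, 0.9, 0.999, 1e-8, 1)?;
/// ```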
#[allow(clippy::too_many_arguments)]
pub fn adam_update<F, D>(
w: &ArrayView<F, D>,
dw: &ArrayView<F, D>,
m: &ArrayView<F, D>,
v: &ArrayView<F, D>,
learning_rate: F,
beta1: F,
beta2: F,
epsilon: F,
t: usize,
) -> Result<AdamUpdateReturn<F, D>>
where
F: Float + Debug + std::fmt::Display,
D: Dimension,
{
if w.shape() != dw.shape() || w.shape() != m.shape() || w.shape() != v.shape() {
return Err(NeuralError::ShapeMismatch(
format!("Shape mismatch in adam_update: w shape {:?}, dw shape {:?}, m shape {:?}, v shape {:?}",
w.shape(), dw.shape(), m.shape(), v.shape())
));
}
if learning_rate <= F::zero() {
return Err(NeuralError::InvalidArgument(format!(
"Learning rate must be positive in adam_update, got {}",
learning_rate
)));
}
if beta1 < F::zero() || beta1 >= F::one() {
return Err(NeuralError::InvalidArgument(format!(
"beta1 must be in [0, 1) range in adam_update, got {}",
beta1
)));
}
if beta2 < F::zero() || beta2 >= F::one() {
return Err(NeuralError::InvalidArgument(format!(
"beta2 must be in [0, 1) range in adam_update, got {}",
beta2
)));
}
    if epsilon <= F::zero() {
        return Err(NeuralError::InvalidArgument(format!(
            "epsilon must be positive in adam_update, got {}",
            epsilon
        )));
    }
    if t == 0 {
        return Err(NeuralError::InvalidArgument(
            "Timestep t must be at least 1 in adam_update; at t = 0 the bias correction divides by zero".to_string(),
        ));
    }
let mut w_new = w.to_owned();
let mut m_new = Array::zeros(w.raw_dim());
let mut v_new = Array::zeros(w.raw_dim());
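    // First moment: exponential moving average of the gradient.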
for (m_val, m_prev, dw_val) in zip3(m_new.iter_mut(), m.iter(), dw.iter()) {
*m_val = beta1 * *m_prev + (F::one() - beta1) * *dw_val;
}
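    // Second moment: exponential moving average of the squared gradient.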
for (v_val, v_prev, dw_val) in zip3(v_new.iter_mut(), v.iter(), dw.iter()) {
*v_val = beta2 * *v_prev + (F::one() - beta2) * (*dw_val * *dw_val);
}
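    // Bias correction: m and v are initialized at zero, so the raw moving
    // averages are biased toward zero for small t.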
let t_f = F::from(t).expect("Failed to convert to float");
let m_hat_factor = F::one() / (F::one() - beta1.powf(t_f));
let v_hat_factor = F::one() / (F::one() - beta2.powf(t_f));
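    // Parameter update: w -= lr · m̂ / (√v̂ + ε).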
for (w_val, m_val, v_val) in zip3(w_new.iter_mut(), m_new.iter(), v_new.iter()) {
let m_hat = *m_val * m_hat_factor;
let v_hat = *v_val * v_hat_factor;
*w_val = *w_val - learning_rate * m_hat / (v_hat.sqrt() + epsilon);
}
Ok((w_new, m_new, v_new))
}
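
/// Zips three iterators into a single iterator of triples.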
fn zip3<I1, I2, I3>(i1: I1, i2: I2, i3: I3) -> impl Iterator<Item = (I1::Item, I2::Item, I3::Item)>
where
I1: IntoIterator,
I2: IntoIterator,
I3: IntoIterator,
{
i1.into_iter()
.zip(i2.into_iter().zip(i3))
.map(|(a, (b, c))| (a, b, c))
}