/// Logistic sigmoid: `1 / (1 + e^(-z))`.
///
/// For large-magnitude inputs the result saturates to exactly `0.0` or `1.0`
/// in `f32` arithmetic rather than producing NaN.
#[inline]
fn sigmoid(z: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-z);
    1.0 / denominator
}
/// A single LSTM cell with fixed parameters, its own recurrent state, and
/// preallocated scratch buffers so that stepping it does not allocate.
#[derive(Debug)]
pub(crate) struct LstmCell {
    /// Number of hidden units; the length of `h`, `c`, `h0` and `c0`.
    hidden_size: usize,
    /// Number of values expected in each input vector.
    input_size: usize,
    /// Gate weights, flattened row-major: `4 * hidden_size` rows of
    /// `input_size + hidden_size` columns. Within a row, the first
    /// `input_size` columns apply to the input and the rest to the previous
    /// hidden state. Rows are grouped by gate in the order `i`, `f`, `g`, `o`.
    w: Vec<f32>,
    /// One bias per weight row (`4 * hidden_size` values).
    b: Vec<f32>,
    /// Initial hidden state, restored on reset.
    h0: Vec<f32>,
    /// Initial cell state, restored on reset.
    c0: Vec<f32>,
    /// Current hidden state.
    h: Vec<f32>,
    /// Current cell state.
    c: Vec<f32>,
    /// Scratch: the current input followed by the previous hidden state.
    xh: Vec<f32>,
    /// Scratch: gate pre-activations (`4 * hidden_size` values).
    gates: Vec<f32>,
}
impl LstmCell {
    /// Builds a cell from flattened parameters and an initial state.
    ///
    /// * `w` is read row-major as `4 * hidden_size` rows of
    ///   `input_size + hidden_size` columns. In each row the first
    ///   `input_size` columns multiply the input and the remaining columns
    ///   multiply the previous hidden state. Rows are grouped by gate in the
    ///   order `i`, `f`, `g`, `o`, with `hidden_size` rows per gate.
    /// * `b` holds one bias per weight row.
    /// * `h0` / `c0` are the initial hidden and cell state; they are also the
    ///   values [`reset`](Self::reset) restores.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if any argument's length
    /// does not match the dimensions above. Without debug assertions the
    /// lengths are not checked here.
    pub(crate) fn new(
        input_size: usize,
        hidden_size: usize,
        w: Vec<f32>,
        b: Vec<f32>,
        h0: Vec<f32>,
        c0: Vec<f32>,
    ) -> Self {
        let row_len = input_size + hidden_size;
        let gate_rows = 4 * hidden_size;
        debug_assert_eq!(
            w.len(),
            gate_rows * row_len,
            "w must hold 4 * hidden_size rows of input_size + hidden_size values"
        );
        debug_assert_eq!(b.len(), gate_rows, "b must hold 4 * hidden_size values");
        debug_assert_eq!(h0.len(), hidden_size, "h0 must hold hidden_size values");
        debug_assert_eq!(c0.len(), hidden_size, "c0 must hold hidden_size values");
        Self {
            hidden_size,
            input_size,
            w,
            b,
            // The running state starts as a copy of the initial state, which
            // is kept separately so `reset` can restore it later.
            h: h0.clone(),
            c: c0.clone(),
            h0,
            c0,
            xh: vec![0.0; row_len],
            gates: vec![0.0; gate_rows],
        }
    }

    /// Advances the cell by one step and returns the new hidden state.
    ///
    /// Each gate pre-activation is `bias + row · [x, h_prev]`. The state is
    /// then updated element-wise as `c = f * c + i * g` and `h = o * tanh(c)`,
    /// where `i`, `f`, `o` are sigmoid-activated and `g` is tanh-activated.
    /// The returned slice borrows the cell's hidden state.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the cell's input size.
    pub(crate) fn process(&mut self, x: &[f32]) -> &[f32] {
        let hidden = self.hidden_size;
        let row_len = self.input_size + hidden;
        let gate_rows = 4 * hidden;
        assert_eq!(
            x.len(),
            self.input_size,
            "LstmCell::process: input length must equal input_size"
        );

        // Stage the concatenated vector [x, h_prev] that every row is dotted with.
        let (x_part, h_part) = self.xh.split_at_mut(self.input_size);
        x_part.copy_from_slice(x);
        h_part.copy_from_slice(&self.h);

        // Gate pre-activations. Slicing one row at a time bounds-checks once
        // per row; the fold starts from the bias and adds columns in order,
        // so the accumulation order is fixed.
        let xh = self.xh.as_slice();
        let biases = &self.b[..gate_rows];
        for (g, (gate, &bias)) in self.gates.iter_mut().zip(biases).enumerate() {
            let start = g * row_len;
            let weights = &self.w[start..start + row_len];
            *gate = weights
                .iter()
                .zip(xh)
                .fold(bias, |acc, (&w, &v)| acc + w * v);
        }

        // Split the pre-activations into their four per-gate blocks.
        let (in_pre, rest) = self.gates.split_at(hidden);
        let (forget_pre, rest) = rest.split_at(hidden);
        let (cand_pre, out_pre) = rest.split_at(hidden);
        let pre_activations = in_pre.iter().zip(forget_pre).zip(cand_pre).zip(out_pre);
        let state = self.c[..hidden].iter_mut().zip(self.h.iter_mut());
        for ((((&i_pre, &f_pre), &g_pre), &o_pre), (c, h)) in pre_activations.zip(state) {
            let i = sigmoid(i_pre);
            let f = sigmoid(f_pre);
            let g = g_pre.tanh();
            let o = sigmoid(o_pre);
            *c = f * *c + i * g;
            *h = o * c.tanh();
        }
        &self.h
    }

    /// Restores the hidden and cell state to the initial values given to
    /// [`new`](Self::new).
    pub(crate) fn reset(&mut self) {
        self.h.copy_from_slice(&self.h0);
        self.c.copy_from_slice(&self.c0);
    }

    /// Number of hidden units, i.e. the length of the slice returned by
    /// [`process`](Self::process).
    pub(crate) fn hidden_size(&self) -> usize {
        self.hidden_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One step of a 1-input, 1-unit cell checked against a hand computation.
    #[test]
    fn single_cell_matches_hand_computed_step() {
        // Rows are [w_x, w_h] for the gates i, f, g, o in that order.
        let w = vec![
            1.0, 0.0, // i
            0.0, 0.0, // f
            2.0, 0.0, // g
            0.0, 0.0, // o
        ];
        let b = vec![0.0; 4];
        let mut cell = LstmCell::new(1, 1, w, b, vec![0.0], vec![0.0]);
        let h = cell.process(&[0.5]);
        // i = sigmoid(0.5) ~ 0.6225, f = o = sigmoid(0) = 0.5, g = tanh(1.0) ~ 0.7616
        // c = f * 0 + i * g ~ 0.4741, h = o * tanh(c) ~ 0.2208
        assert!((h[0] - 0.2208).abs() < 1e-3, "got {}", h[0]);
    }

    /// `reset` must bring back the exact `h0` / `c0` passed to `new`.
    ///
    /// The initial state and the recurrent weights are non-zero so that a
    /// reset which zeroes the state, or restores only one of `h` / `c`,
    /// produces a different output and fails the final assertion.
    #[test]
    fn reset_restores_initial_state() {
        let w = vec![
            1.0, 0.5, // i
            0.3, -0.4, // f
            2.0, 0.7, // g
            -0.2, 0.6, // o
        ];
        let b = vec![0.1, -0.1, 0.2, 0.0];
        let mut cell = LstmCell::new(1, 1, w, b, vec![0.3], vec![-0.2]);

        let first = cell.process(&[0.5])[0];
        let second = cell.process(&[0.5])[0];
        // Guard against a vacuous pass: the state must really have moved.
        assert!(
            (first - second).abs() > 1e-3,
            "state did not evolve: {first} vs {second}"
        );

        cell.reset();
        let after = cell.process(&[0.5])[0];
        assert!((first - after).abs() < 1e-7, "{first} vs {after}");
    }
}