mod cell;
use cell::LstmCell;
use crate::error::Error;
use crate::model::{LstmConfig, ModelConfig, NamModel};
use crate::reader::Reader;
#[derive(Debug)]
pub struct Lstm {
cells: Vec<LstmCell>,
head_w: Vec<f32>,
head_b: f32,
sample_rate: f64,
buf_a: Vec<f32>,
buf_b: Vec<f32>,
}
impl Lstm {
pub fn new(model: &NamModel) -> Result<Self, Error> {
let cfg = match &model.config {
ModelConfig::Lstm(cfg) => cfg,
ModelConfig::WaveNet(_) => {
return Err(Error::UnsupportedArchitecture(model.architecture.clone()))
}
};
let expected = expected_weight_count(cfg);
if expected != model.weights.len() {
return Err(Error::WeightCountMismatch {
expected,
found: model.weights.len(),
});
}
let h = cfg.hidden_size;
let mut r = Reader::new(&model.weights);
let mut cells = Vec::with_capacity(cfg.num_layers);
for layer in 0..cfg.num_layers {
let in_dim = if layer == 0 { cfg.input_size } else { h };
let w = r.take(4 * h * (in_dim + h));
let b = r.take(4 * h);
let h0 = r.take(h);
let c0 = r.take(h);
cells.push(LstmCell::new(in_dim, h, w, b, h0, c0));
}
let head_w = r.take(h);
let head_b = r.take(1)[0];
Ok(Self {
cells,
head_w,
head_b,
sample_rate: model.sample_rate(),
buf_a: vec![0.0; h.max(1)],
buf_b: vec![0.0; h.max(1)],
})
}
pub fn process_buffer(&mut self, io: &mut [f32]) {
for s in io.iter_mut() {
*s = self.process_sample(*s);
}
}
pub fn process_sample(&mut self, x: f32) -> f32 {
if self.cells.is_empty() {
return self.head_b;
}
let h = self.cells[0].hidden_size();
let x0 = [x];
let out = self.cells[0].process(&x0);
self.buf_a[..h].copy_from_slice(out);
for i in 1..self.cells.len() {
let out = self.cells[i].process(&self.buf_a[..h]);
self.buf_b[..h].copy_from_slice(out);
std::mem::swap(&mut self.buf_a, &mut self.buf_b);
}
let mut y = self.head_b;
for j in 0..h {
y += self.head_w[j] * self.buf_a[j];
}
y
}
pub fn sample_rate(&self) -> f64 {
self.sample_rate
}
pub fn reset(&mut self) {
for c in &mut self.cells {
c.reset();
}
self.buf_a.fill(0.0);
self.buf_b.fill(0.0);
}
}
fn expected_weight_count(cfg: &LstmConfig) -> usize {
let h = cfg.hidden_size;
let mut total = 0;
for layer in 0..cfg.num_layers {
let in_dim = if layer == 0 { cfg.input_size } else { h };
total += 4 * h * (in_dim + h); total += 4 * h; total += h; total += h; }
total + h + 1 }
#[cfg(test)]
mod tests {
use super::*;
use crate::NamModel;
const TINY_LSTM: &str = r#"{
"version": "0.5.4", "architecture": "LSTM",
"config": { "input_size": 1, "hidden_size": 1, "num_layers": 1 },
"weights": [1.0,0.0, 0.0,0.0, 2.0,0.0, 0.0,0.0, 0.0,0.0,0.0,0.0, 0.0, 0.0, 3.0, 0.5]
}"#;
#[test]
fn tiny_lstm_matches_hand_computed() {
let model = NamModel::from_json_str(TINY_LSTM).unwrap();
let mut net = Lstm::new(&model).unwrap();
let mut buf = [0.5_f32];
net.process_buffer(&mut buf);
assert!((buf[0] - 1.1623).abs() < 1e-3, "got {}", buf[0]);
}
#[test]
fn weight_count_mismatch_is_rejected() {
let bad = TINY_LSTM.replace(", 3.0, 0.5]", ", 3.0]");
let model = NamModel::from_json_str(&bad).unwrap();
assert!(matches!(
Lstm::new(&model),
Err(crate::Error::WeightCountMismatch { .. })
));
}
#[test]
fn reset_restores_fresh_output() {
let model = NamModel::from_json_str(TINY_LSTM).unwrap();
let mut net = Lstm::new(&model).unwrap();
let mut warm = [0.3_f32, -0.7, 0.2];
net.process_buffer(&mut warm);
net.reset();
let mut a = [0.5_f32];
net.process_buffer(&mut a);
assert!((a[0] - 1.1623).abs() < 1e-3, "got {}", a[0]);
}
}