// nam_rs/model_runtime.rs

//! Architecture-agnostic runtime: [`Model`] dispatches over the `.nam`'s declared
//! architecture so consumers run any supported model without branching.

use crate::error::Error;
use crate::lstm::Lstm;
use crate::model::{ModelConfig, NamModel};
use crate::wavenet::WaveNet;

9/// A runnable NAM model of any supported architecture.
10///
11/// Build with [`Model::from_nam`]; then call [`Model::process_buffer`] on the audio
12/// thread. `#[non_exhaustive]` so future architectures don't break downstream
13/// `match`es.
14#[non_exhaustive]
15#[derive(Debug)]
16pub enum Model {
17    /// A WaveNet model.
18    WaveNet(WaveNet),
19    /// An LSTM model.
20    Lstm(Lstm),
21}
22
23impl Model {
24    /// Build the runtime matching `model.architecture`. All allocation happens here.
25    pub fn from_nam(model: &NamModel) -> Result<Self, Error> {
26        match &model.config {
27            ModelConfig::WaveNet(_) => Ok(Model::WaveNet(WaveNet::new(model)?)),
28            ModelConfig::Lstm(_) => Ok(Model::Lstm(Lstm::new(model)?)),
29        }
30    }
31
32    /// Process a buffer of mono samples in place. Allocation-free.
33    pub fn process_buffer(&mut self, io: &mut [f32]) {
34        match self {
35            Model::WaveNet(w) => w.process_buffer(io),
36            Model::Lstm(l) => l.process_buffer(io),
37        }
38    }
39
40    /// Process a single mono sample. Allocation-free.
41    pub fn process_sample(&mut self, x: f32) -> f32 {
42        match self {
43            Model::WaveNet(w) => w.process_sample(x),
44            Model::Lstm(l) => l.process_sample(x),
45        }
46    }
47
48    /// The model's sample rate.
49    pub fn sample_rate(&self) -> f64 {
50        match self {
51            Model::WaveNet(w) => w.sample_rate(),
52            Model::Lstm(l) => l.sample_rate(),
53        }
54    }
55
56    /// Reset all internal state to the model's initial conditions.
57    pub fn reset(&mut self) {
58        match self {
59            Model::WaveNet(w) => w.reset(),
60            Model::Lstm(l) => l.reset(),
61        }
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal WaveNet fixture: one layer, one channel, kernel size 1, ungated ReLU.
    const TINY_WAVENET: &str = r#"{
        "version": "0.5.4", "architecture": "WaveNet",
        "config": { "layers": [{
            "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
            "kernel_size": 1, "dilations": [1], "activation": "ReLU",
            "gated": false, "head_bias": false
        }], "head": null, "head_scale": 10.0 },
        "weights": [1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]
    }"#;

    /// Minimal LSTM fixture: one layer with a single hidden unit.
    const TINY_LSTM: &str = r#"{
        "version": "0.5.4", "architecture": "LSTM",
        "config": { "input_size": 1, "hidden_size": 1, "num_layers": 1 },
        "weights": [1.0,0.0, 0.0,0.0, 2.0,0.0, 0.0,0.0, 0.0,0.0,0.0,0.0, 0.0, 0.0, 3.0, 0.5]
    }"#;

    /// A WaveNet `.nam` must yield `Model::WaveNet` and produce the expected sample.
    #[test]
    fn from_nam_builds_wavenet() {
        let m = NamModel::from_json_str(TINY_WAVENET).unwrap();
        let mut model = Model::from_nam(&m).unwrap();
        assert!(matches!(model, Model::WaveNet(_)));
        let mut buf = [0.5_f32];
        model.process_buffer(&mut buf);
        assert!((buf[0] - 10.0).abs() < 1e-5, "got {}", buf[0]);
    }

    /// An LSTM `.nam` must yield `Model::Lstm` and produce the expected sample.
    #[test]
    fn from_nam_builds_lstm() {
        let m = NamModel::from_json_str(TINY_LSTM).unwrap();
        let mut model = Model::from_nam(&m).unwrap();
        assert!(matches!(model, Model::Lstm(_)));
        let mut buf = [0.5_f32];
        model.process_buffer(&mut buf);
        assert!((buf[0] - 1.1623).abs() < 1e-3, "got {}", buf[0]);
    }
}