Skip to main content

nam_rs/
wavenet.rs

1//! Real-time WaveNet inference.
2//!
3//! [`WaveNet`] is built once from a parsed [`NamModel`] (which may allocate), then
4//! run on the audio thread via [`WaveNet::process_buffer`], which never allocates.
5//! All scratch buffers are pre-allocated in [`WaveNet::new`].
6//!
7//! The forward pass is a port of NAM's WaveNet, built bottom-up from the `conv`,
8//! `layer`, and `array` submodules (each unit-tested) and validated end-to-end
9//! against the reference in `tests/parity.rs`.
10
11use crate::error::Error;
12use crate::model::{LayerArrayConfig, NamModel, WaveNetConfig};
13use crate::reader::Reader;
14
15mod array;
16mod conv;
17mod layer;
18
19use array::LayerArray;
20use conv::MAX_BLOCK;
21use layer::{Activation, Layer};
22
23/// A ready-to-run WaveNet, with all scratch buffers pre-allocated.
24#[derive(Debug)]
25pub struct WaveNet {
26    arrays: Vec<LayerArray>,
27    head_scale: f32,
28    /// Samples of input history the deepest dilated tap reaches back over; equals
29    /// the model's warmup length / processing latency in samples.
30    receptive_field: usize,
31    /// Channel width of the first array (its incoming head is silence this wide).
32    channels0: usize,
33    /// Training/inference sample rate, copied from the source `NamModel`.
34    sample_rate: f64,
35    /// Head signal carried between arrays (two buffers, ping-ponged).
36    head_a: Vec<f32>,
37    head_b: Vec<f32>,
38    /// Layer signal carried between arrays (two buffers, ping-ponged).
39    sig_a: Vec<f32>,
40    sig_b: Vec<f32>,
41    /// Planar `[width][MAX_BLOCK]` block-path twins of the carry buffers, plus a
42    /// scratch copy of the conditioning chunk. Used by [`WaveNet::process_buffer`].
43    head_a_blk: Vec<f32>,
44    head_b_blk: Vec<f32>,
45    sig_a_blk: Vec<f32>,
46    sig_b_blk: Vec<f32>,
47    cond_blk: Vec<f32>,
48}
49
50impl WaveNet {
51    /// Build a runnable model from a parsed `.nam` file.
52    ///
53    /// All allocation happens here. Fails if the architecture is unsupported, an
54    /// activation is unknown, or the flat weight blob does not match the config.
55    pub fn new(model: &NamModel) -> Result<Self, Error> {
56        let cfg = match &model.config {
57            crate::model::ModelConfig::WaveNet(cfg) => cfg,
58            crate::model::ModelConfig::Lstm(_) => {
59                return Err(Error::UnsupportedArchitecture(model.architecture.clone()))
60            }
61        };
62
63        let expected = expected_weight_count(cfg);
64        if expected != model.weights.len() {
65            return Err(Error::WeightCountMismatch {
66                expected,
67                found: model.weights.len(),
68            });
69        }
70
71        let mut r = Reader::new(&model.weights);
72        let mut arrays = Vec::with_capacity(cfg.layers.len());
73        for la in &cfg.layers {
74            arrays.push(build_array(&mut r, la)?);
75        }
76        let head_scale = r.take(1)[0];
77
78        let max_ch = arrays.iter().map(LayerArray::channels).max().unwrap_or(1);
79        let max_head = arrays.iter().map(LayerArray::head_size).max().unwrap_or(1);
80        let head_w = max_ch.max(max_head).max(1);
81        let sig_w = max_ch.max(1);
82        let channels0 = arrays.first().map_or(0, LayerArray::channels);
83
84        Ok(Self {
85            arrays,
86            head_scale,
87            receptive_field: receptive_field(cfg),
88            channels0,
89            sample_rate: model.sample_rate(),
90            head_a: vec![0.0; head_w],
91            head_b: vec![0.0; head_w],
92            sig_a: vec![0.0; sig_w],
93            sig_b: vec![0.0; sig_w],
94            head_a_blk: vec![0.0; head_w * MAX_BLOCK],
95            head_b_blk: vec![0.0; head_w * MAX_BLOCK],
96            sig_a_blk: vec![0.0; sig_w * MAX_BLOCK],
97            sig_b_blk: vec![0.0; sig_w * MAX_BLOCK],
98            cond_blk: vec![0.0; MAX_BLOCK],
99        })
100    }
101
102    /// Receptive field in samples: how far back the deepest dilated tap reaches.
103    ///
104    /// This is the model's warmup length and its processing latency. The first
105    /// `receptive_field()` output samples of a fresh (or freshly [`reset`](Self::reset))
106    /// model are a startup transient computed against zero-filled history, so they
107    /// reflect the streaming zero-init convention (matching NAM Core / NeuralAudio)
108    /// rather than a training-time forward pass that pre-pads the whole input.
109    pub fn receptive_field(&self) -> usize {
110        self.receptive_field
111    }
112
113    /// The model's sample rate (from the source `.nam`, or the NAM default).
114    pub fn sample_rate(&self) -> f64 {
115        self.sample_rate
116    }
117
118    /// Process a buffer of mono samples in place.
119    ///
120    /// Runs the block kernel: each `MAX_BLOCK`-sized chunk is pushed through one
121    /// array (and one layer) at a time, keeping each weight matrix hot across the
122    /// whole chunk. Bit-for-bit equivalent to looping [`Self::process_sample`], and
123    /// it shares the same streaming history, so the two are interchangeable.
124    ///
125    /// **Real-time contract:** no heap allocation, locks, or syscalls. Enforced by
126    /// `tests/rt_safety.rs`.
127    pub fn process_buffer(&mut self, io: &mut [f32]) {
128        if self.arrays.is_empty() {
129            for s in io.iter_mut() {
130                *s *= self.head_scale;
131            }
132            return;
133        }
134        let mut off = 0;
135        while off < io.len() {
136            let n = (io.len() - off).min(MAX_BLOCK);
137            self.process_chunk(&mut io[off..off + n], n);
138            off += n;
139        }
140    }
141
142    /// Run one `n <= MAX_BLOCK` chunk through every array via the planar block path.
143    /// `chunk` is the mono input and is overwritten with the output.
144    fn process_chunk(&mut self, chunk: &mut [f32], n: usize) {
145        // The conditioning is the mono input; copy it out before we overwrite `chunk`.
146        self.cond_blk[..n].copy_from_slice(chunk);
147
148        // First array: input and condition are the mono signal; the incoming head is
149        // silence of the array's channel width.
150        self.head_a_blk[..self.channels0 * n].fill(0.0);
151        {
152            let ch = self.arrays[0].channels();
153            let hs = self.arrays[0].head_size();
154            self.arrays[0].process_block(
155                &self.cond_blk[..n],
156                &self.cond_blk[..n],
157                &self.head_a_blk[..ch * n],
158                &mut self.head_b_blk[..hs * n],
159                &mut self.sig_b_blk[..ch * n],
160                n,
161            );
162        }
163        std::mem::swap(&mut self.head_a_blk, &mut self.head_b_blk);
164        std::mem::swap(&mut self.sig_a_blk, &mut self.sig_b_blk);
165
166        for i in 1..self.arrays.len() {
167            let in_w = self.arrays[i - 1].channels();
168            let ch = self.arrays[i].channels();
169            let hs = self.arrays[i].head_size();
170            self.arrays[i].process_block(
171                &self.sig_a_blk[..in_w * n],
172                &self.cond_blk[..n],
173                &self.head_a_blk[..ch * n],
174                &mut self.head_b_blk[..hs * n],
175                &mut self.sig_b_blk[..ch * n],
176                n,
177            );
178            std::mem::swap(&mut self.head_a_blk, &mut self.head_b_blk);
179            std::mem::swap(&mut self.sig_a_blk, &mut self.sig_b_blk);
180        }
181
182        // After the final swap, head_a_blk holds the last array's head output, whose
183        // head_size is 1 — row 0 is the per-sample head signal.
184        for (t, s) in chunk.iter_mut().enumerate() {
185            *s = self.head_scale * self.head_a_blk[t];
186        }
187    }
188
189    /// Process a single mono sample, returning one output sample.
190    ///
191    /// Equivalent to a one-element [`Self::process_buffer`]; convenient for
192    /// callers that are not buffer-oriented. Allocation-free.
193    pub fn process_sample(&mut self, x: f32) -> f32 {
194        let cond = [x];
195        let n = self.arrays.len();
196        if n == 0 {
197            return self.head_scale * x;
198        }
199
200        // First array: input and condition are the mono sample; the incoming head
201        // is silence of the array's channel width.
202        self.head_a[..self.channels0].fill(0.0);
203        {
204            let ch = self.arrays[0].channels();
205            let hs = self.arrays[0].head_size();
206            self.arrays[0].process_sample(
207                &cond,
208                &cond,
209                &self.head_a[..ch],
210                &mut self.head_b[..hs],
211                &mut self.sig_b[..ch],
212            );
213        }
214        std::mem::swap(&mut self.head_a, &mut self.head_b);
215        std::mem::swap(&mut self.sig_a, &mut self.sig_b);
216
217        for i in 1..n {
218            let in_w = self.arrays[i - 1].channels();
219            let ch = self.arrays[i].channels();
220            let hs = self.arrays[i].head_size();
221            self.arrays[i].process_sample(
222                &self.sig_a[..in_w],
223                &cond,
224                &self.head_a[..ch],
225                &mut self.head_b[..hs],
226                &mut self.sig_b[..ch],
227            );
228            std::mem::swap(&mut self.head_a, &mut self.head_b);
229            std::mem::swap(&mut self.sig_a, &mut self.sig_b);
230        }
231
232        // After the final swap, head_a holds the last array's head output.
233        self.head_scale * self.head_a[0]
234    }
235
236    /// Reset all internal state (ring buffers) to silence.
237    pub fn reset(&mut self) {
238        for a in &mut self.arrays {
239            a.reset();
240        }
241        self.head_a.fill(0.0);
242        self.head_b.fill(0.0);
243        self.sig_a.fill(0.0);
244        self.sig_b.fill(0.0);
245    }
246}
247
248/// Receptive field implied by `config`: `1 + Σ (kernel_size - 1) · dilation` over
249/// every dilated layer in every array. The stacked dilated convs compose additively,
250/// so this is the number of past input samples the final output depends on.
251fn receptive_field(cfg: &WaveNetConfig) -> usize {
252    let mut rf = 1;
253    for la in &cfg.layers {
254        for &d in &la.dilations {
255            rf += (la.kernel_size - 1) * d;
256        }
257    }
258    rf
259}
260
261/// Number of `f32`s `config` implies in the flat weight blob, including the final
262/// `head_scale`.
263fn expected_weight_count(cfg: &WaveNetConfig) -> usize {
264    let mut total = 0;
265    for la in &cfg.layers {
266        let mid = if la.gated {
267            2 * la.channels
268        } else {
269            la.channels
270        };
271        total += la.channels * la.input_size; // rechannel (no bias)
272        let per_layer = mid * la.channels * la.kernel_size // conv weights
273            + mid                                          // conv bias
274            + mid * la.condition_size                      // input mixer (no bias)
275            + la.channels * la.channels                    // 1x1 weights
276            + la.channels; // 1x1 bias
277        total += la.dilations.len() * per_layer;
278        total += la.head_size * la.channels; // head rechannel weights
279        if la.head_bias {
280            total += la.head_size;
281        }
282    }
283    total + 1 // head_scale
284}
285
286fn build_array(r: &mut Reader, la: &LayerArrayConfig) -> Result<LayerArray, Error> {
287    let activation = Activation::from_name(&la.activation)?;
288    let mid = if la.gated {
289        2 * la.channels
290    } else {
291        la.channels
292    };
293
294    let rechannel_w = r.take(la.channels * la.input_size);
295    let mut layers = Vec::with_capacity(la.dilations.len());
296    for &d in &la.dilations {
297        let conv_w = r.take(mid * la.channels * la.kernel_size);
298        let conv_b = r.take(mid);
299        let mix_w = r.take(mid * la.condition_size);
300        let one_w = r.take(la.channels * la.channels);
301        let one_b = r.take(la.channels);
302        layers.push(Layer::new(
303            la.channels,
304            la.condition_size,
305            la.kernel_size,
306            d,
307            activation,
308            la.gated,
309            conv_w,
310            conv_b,
311            mix_w,
312            one_w,
313            one_b,
314        ));
315    }
316    let head_w = r.take(la.head_size * la.channels);
317    let head_b = if la.head_bias {
318        Some(r.take(la.head_size))
319    } else {
320        None
321    };
322
323    Ok(LayerArray::new(
324        la.input_size,
325        la.channels,
326        la.head_size,
327        rechannel_w,
328        layers,
329        head_w,
330        head_b,
331    ))
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    // 1 array, 1 layer, 1 channel, ReLU. Weight order:
    // rechannel=1, conv_w=2, conv_b=0.5, mix_w=1, one_w=3, one_b=0.1,
    // head_rechannel=0.5, head_scale=10.
    const TINY: &str = r#"{
        "version": "0.5.4",
        "architecture": "WaveNet",
        "config": {
            "layers": [{
                "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
                "kernel_size": 1, "dilations": [1], "activation": "ReLU",
                "gated": false, "head_bias": false
            }],
            "head": null, "head_scale": 10.0
        },
        "weights": [1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]
    }"#;

    /// Parse the realistic reference model fixture `tests/fixtures/reference_standard.nam`.
    fn standard_model() -> NamModel {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/reference_standard.nam");
        let json = std::fs::read_to_string(path).expect("read standard fixture");
        NamModel::from_json_str(&json).expect("parse standard fixture")
    }

    #[test]
    fn tiny_model_matches_hand_computed_forward() {
        let model = NamModel::from_json_str(TINY).unwrap();
        let mut wn = WaveNet::new(&model).unwrap();

        // x=0.5, cond=0.5: z = 2*0.5 + 0.5 + 1*0.5 = 2.0 ; relu=2.0
        // head = 0.5*2.0 = 1.0 ; out = head_scale * 1.0 = 10.0
        let mut buf = [0.5_f32];
        wn.process_buffer(&mut buf);
        assert!((buf[0] - 10.0).abs() < 1e-5, "got {}", buf[0]);
    }

    #[test]
    fn tiny_model_process_sample_matches_hand_computed_forward() {
        // Same hand computation as above, through the per-sample path.
        let model = NamModel::from_json_str(TINY).unwrap();
        let mut wn = WaveNet::new(&model).unwrap();
        let out = wn.process_sample(0.5);
        assert!((out - 10.0).abs() < 1e-5, "got {out}");
    }

    #[test]
    fn receptive_field_sums_dilated_taps() {
        // 1 + Σ(k-1)·d. Mirrors the reference model: kernel 3, dilations [1,2] then [8].
        let mk = |dilations: Vec<usize>| LayerArrayConfig {
            input_size: 1,
            condition_size: 1,
            channels: 1,
            head_size: 1,
            kernel_size: 3,
            dilations,
            activation: "Tanh".into(),
            gated: false,
            head_bias: false,
        };
        let cfg = WaveNetConfig {
            layers: vec![mk(vec![1, 2]), mk(vec![8])],
            head: None,
            head_scale: 1.0,
        };
        // (3-1)*1 + (3-1)*2 + (3-1)*8 = 2 + 4 + 16 = 22, + 1 = 23.
        assert_eq!(receptive_field(&cfg), 23);

        // TINY (kernel 1, dilation 1) reaches back over no past samples: rf = 1.
        let model = NamModel::from_json_str(TINY).unwrap();
        assert_eq!(WaveNet::new(&model).unwrap().receptive_field(), 1);
    }

    /// Smoke test only: TINY has a receptive field of 1 and keeps no history, so
    /// this passes even if `reset` were a no-op. See `reset_clears_dilated_history`.
    #[test]
    fn reset_restores_from_fresh_result() {
        let model = NamModel::from_json_str(TINY).unwrap();
        let mut wn = WaveNet::new(&model).unwrap();
        let mut warm = [0.3_f32, -0.7, 0.2];
        wn.process_buffer(&mut warm);
        wn.reset();
        let mut a = [0.5_f32];
        wn.process_buffer(&mut a);
        assert!((a[0] - 10.0).abs() < 1e-5, "got {}", a[0]);
    }

    /// After warming the stateful reference model and calling `reset`, its output
    /// must match a freshly built model on the same probe. Without a working reset,
    /// the warmup samples left in the dilated taps would leak into the probe output.
    #[test]
    fn reset_clears_dilated_history() {
        let model = standard_model();
        let probe: Vec<f32> = (0..64).map(|i| (i as f32 * 0.05).sin() * 0.4).collect();

        let mut fresh = WaveNet::new(&model).unwrap();
        let mut want = probe.clone();
        fresh.process_buffer(&mut want);

        let mut wn = WaveNet::new(&model).unwrap();
        let mut warm: Vec<f32> = (0..256).map(|i| (i as f32 * 0.31).cos() * 0.8).collect();
        wn.process_buffer(&mut warm);
        wn.reset();
        let mut got = probe;
        wn.process_buffer(&mut got);

        for (i, (g, w)) in got.iter().zip(&want).enumerate() {
            assert!(
                (g - w).abs() < 1e-5,
                "sample {i}: after reset {g}, fresh {w}"
            );
        }
    }

    #[test]
    fn wrong_weight_count_is_rejected() {
        let bad = TINY.replace(
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]",
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5]",
        );
        let model = NamModel::from_json_str(&bad).unwrap();
        match WaveNet::new(&model) {
            Err(Error::WeightCountMismatch { expected, found }) => {
                assert_eq!(expected, 8);
                assert_eq!(found, 7);
            }
            other => panic!("expected WeightCountMismatch, got {other:?}"),
        }
    }

    #[test]
    fn wavenet_new_rejects_non_wavenet() {
        let lstm = r#"{
            "version": "0.5.4", "architecture": "LSTM",
            "config": { "input_size": 1, "hidden_size": 4, "num_layers": 1 },
            "weights": [0.0]
        }"#;
        let model = NamModel::from_json_str(lstm).unwrap();
        assert!(matches!(
            WaveNet::new(&model),
            Err(Error::UnsupportedArchitecture(_))
        ));
    }

    /// End-to-end on the realistic standard model: the block `process_buffer` must
    /// equal a per-sample `process_sample` loop over the same signal, including
    /// across `MAX_BLOCK` chunk boundaries. `tests/parity.rs` drives `process_buffer`
    /// (the block path) and pins it to the reference NAM oracle within 1e-5. This
    /// test additionally ties the block path to the per-sample path, so the two are
    /// transitively equivalent and both oracle-correct.
    #[test]
    fn process_buffer_equals_process_sample_loop_on_standard_model() {
        let model = standard_model();

        // A signal longer than MAX_BLOCK so chunking is exercised.
        let len = 2 * MAX_BLOCK + 137;
        let signal: Vec<f32> = (0..len)
            .map(|i| (i as f32 * 0.013).sin() * 0.5 + (i as f32 * 0.27).sin() * 0.2)
            .collect();

        let mut per_sample = WaveNet::new(&model).unwrap();
        let want: Vec<f32> = signal
            .iter()
            .map(|&x| per_sample.process_sample(x))
            .collect();

        let mut block = WaveNet::new(&model).unwrap();
        let mut got = signal.clone();
        block.process_buffer(&mut got);

        for (i, (g, w)) in got.iter().zip(&want).enumerate() {
            assert!(
                (g - w).abs() < 1e-5,
                "sample {i}: block {g}, per-sample {w}"
            );
        }
    }
}