nam-rs 0.2.0

Pure-Rust, real-time-safe inference for Neural Amp Modeler (NAM) .nam models
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
//! Real-time WaveNet inference.
//!
//! [`WaveNet`] is built once from a parsed [`NamModel`] (which may allocate), then
//! run on the audio thread via [`WaveNet::process_buffer`], which never allocates.
//! All scratch buffers are pre-allocated in [`WaveNet::new`].
//!
//! The forward pass is a port of NAM's WaveNet, built bottom-up from the `conv`,
//! `layer`, and `array` submodules (each unit-tested) and validated end-to-end
//! against the reference in `tests/parity.rs`.

use crate::error::Error;
use crate::model::{LayerArrayConfig, NamModel, WaveNetConfig};
use crate::reader::Reader;

mod array;
mod conv;
mod layer;

use array::LayerArray;
use conv::MAX_BLOCK;
use layer::{Activation, Layer};

/// A ready-to-run WaveNet, with all scratch buffers pre-allocated.
///
/// Built by [`WaveNet::new`], which sizes every buffer below exactly once; the
/// processing methods only slice, fill, and swap them, so they never allocate.
#[derive(Debug)]
pub struct WaveNet {
    /// Layer arrays in processing order; array `i` consumes array `i - 1`'s outputs.
    arrays: Vec<LayerArray>,
    /// Output gain applied to the final head signal (read as the last weight in the blob).
    head_scale: f32,
    /// Samples of input history the deepest dilated tap reaches back over; equals
    /// the model's warmup length / processing latency in samples.
    receptive_field: usize,
    /// Channel width of the first array (its incoming head is silence this wide).
    channels0: usize,
    /// Training/inference sample rate, copied from the source `NamModel`.
    sample_rate: f64,
    /// Head signal carried between arrays (two buffers, ping-ponged).
    head_a: Vec<f32>,
    /// Ping-pong partner of `head_a`.
    head_b: Vec<f32>,
    /// Layer signal carried between arrays (two buffers, ping-ponged).
    sig_a: Vec<f32>,
    /// Ping-pong partner of `sig_a`.
    sig_b: Vec<f32>,
    /// Planar `[width][MAX_BLOCK]` block-path twins of the carry buffers, plus a
    /// scratch copy of the conditioning chunk. Used by [`WaveNet::process_buffer`].
    ///
    /// NOTE(review): each buffer is allocated as `width * MAX_BLOCK`, but
    /// `process_chunk` hands the arrays only the first `width * n` elements, so the
    /// per-chunk layout is presumably `[width][n]` — confirm against
    /// `LayerArray::process_block`.
    head_a_blk: Vec<f32>,
    /// Ping-pong partner of `head_a_blk`.
    head_b_blk: Vec<f32>,
    /// Block-path layer-signal carry (ping-ponged with `sig_b_blk`).
    sig_a_blk: Vec<f32>,
    /// Ping-pong partner of `sig_a_blk`.
    sig_b_blk: Vec<f32>,
    /// Copy of the current mono input chunk; used as the first array's input and
    /// as every array's conditioning signal.
    cond_blk: Vec<f32>,
}

impl WaveNet {
    /// Build a runnable model from a parsed `.nam` file.
    ///
    /// All allocation happens here. The layer arrays are built from the flat weight
    /// blob, and every carry/scratch buffer the processing methods use is sized once.
    ///
    /// # Errors
    ///
    /// - [`Error::UnsupportedArchitecture`] if `model` holds an LSTM config.
    /// - [`Error::WeightCountMismatch`] if the weight blob's length differs from the
    ///   count implied by the config (see `expected_weight_count`).
    /// - Whatever `Activation::from_name` returns for an unknown activation name.
    ///
    /// NOTE(review): inter-array shapes are not validated here. The processing
    /// paths feed the first array a 1-wide input and every array a 1-wide condition
    /// (`input_size == 1` / `condition_size == 1` assumed). They pass array `i` a
    /// layer input `channels[i - 1]` wide and an incoming head `channels[i]` wide,
    /// taken from the buffer array `i - 1` wrote `head_size[i - 1]` wide. Only row 0
    /// of the last array's head is emitted. A config that breaks these still passes
    /// the weight-count check — confirm how `LayerArray` behaves on it.
    pub fn new(model: &NamModel) -> Result<Self, Error> {
        // Only WaveNet configs are runnable here; LSTM is rejected up front.
        let cfg = match &model.config {
            crate::model::ModelConfig::WaveNet(cfg) => cfg,
            crate::model::ModelConfig::Lstm(_) => {
                return Err(Error::UnsupportedArchitecture(model.architecture.clone()))
            }
        };

        // Check the blob size before reading. `build_array` (per array) plus the
        // trailing `head_scale` read consume exactly `expected_weight_count(cfg)`
        // floats, in order.
        let expected = expected_weight_count(cfg);
        if expected != model.weights.len() {
            return Err(Error::WeightCountMismatch {
                expected,
                found: model.weights.len(),
            });
        }

        // Consume the blob sequentially: one layer array per config entry, then
        // the scalar output gain as the very last weight.
        let mut r = Reader::new(&model.weights);
        let mut arrays = Vec::with_capacity(cfg.layers.len());
        for la in &cfg.layers {
            arrays.push(build_array(&mut r, la)?);
        }
        let head_scale = r.take(1)[0];

        // Size the carries for the widest array. The head carry serves both as an
        // array's `channels`-wide incoming head and as its `head_size`-wide output,
        // so it must fit either. `.max(1)` keeps every buffer non-empty.
        let max_ch = arrays.iter().map(LayerArray::channels).max().unwrap_or(1);
        let max_head = arrays.iter().map(LayerArray::head_size).max().unwrap_or(1);
        let head_w = max_ch.max(max_head).max(1);
        let sig_w = max_ch.max(1);
        // Width of the zero-filled head fed to the first array (0 with no arrays).
        let channels0 = arrays.first().map_or(0, LayerArray::channels);

        Ok(Self {
            arrays,
            head_scale,
            receptive_field: receptive_field(cfg),
            channels0,
            sample_rate: model.sample_rate(),
            head_a: vec![0.0; head_w],
            head_b: vec![0.0; head_w],
            sig_a: vec![0.0; sig_w],
            sig_b: vec![0.0; sig_w],
            head_a_blk: vec![0.0; head_w * MAX_BLOCK],
            head_b_blk: vec![0.0; head_w * MAX_BLOCK],
            sig_a_blk: vec![0.0; sig_w * MAX_BLOCK],
            sig_b_blk: vec![0.0; sig_w * MAX_BLOCK],
            cond_blk: vec![0.0; MAX_BLOCK],
        })
    }

    /// Receptive field in samples: how far back the deepest dilated tap reaches.
    ///
    /// This is the model's warmup length and its processing latency. The first
    /// `receptive_field()` output samples of a fresh (or freshly [`reset`](Self::reset))
    /// model are a startup transient computed against zero-filled history, so they
    /// reflect the streaming zero-init convention (matching NAM Core / NeuralAudio)
    /// rather than a training-time forward pass that pre-pads the whole input.
    pub fn receptive_field(&self) -> usize {
        self.receptive_field
    }

    /// The model's sample rate (from the source `.nam`, or the NAM default).
    pub fn sample_rate(&self) -> f64 {
        self.sample_rate
    }

    /// Process a buffer of mono samples in place.
    ///
    /// Runs the block kernel: each `MAX_BLOCK`-sized chunk is pushed through one
    /// array (and one layer) at a time, keeping each weight matrix hot across the
    /// whole chunk. It is intended to match looping [`Self::process_sample`] (the
    /// in-module test checks agreement to within 1e-5). Both paths drive the same
    /// `LayerArray` state, so the two are interchangeable mid-stream.
    ///
    /// A model with no layer arrays degenerates to multiplying by `head_scale`.
    ///
    /// **Real-time contract:** no heap allocation, locks, or syscalls. Enforced by
    /// `tests/rt_safety.rs`.
    pub fn process_buffer(&mut self, io: &mut [f32]) {
        // No layer arrays: the model is a plain gain.
        if self.arrays.is_empty() {
            for s in io.iter_mut() {
                *s *= self.head_scale;
            }
            return;
        }
        // Feed `io` through in consecutive chunks of at most `MAX_BLOCK` samples,
        // the capacity of the pre-allocated `_blk` scratch buffers.
        let mut off = 0;
        while off < io.len() {
            let n = (io.len() - off).min(MAX_BLOCK);
            self.process_chunk(&mut io[off..off + n], n);
            off += n;
        }
    }

    /// Run one `n <= MAX_BLOCK` chunk through every array via the planar block path.
    /// `chunk` is the mono input and is overwritten with the output.
    ///
    /// `chunk.len()` must equal `n`: the initial `copy_from_slice` panics on a length
    /// mismatch, and `n > MAX_BLOCK` would make the scratch-buffer slicing panic.
    fn process_chunk(&mut self, chunk: &mut [f32], n: usize) {
        // The conditioning is the mono input; copy it out before we overwrite `chunk`.
        self.cond_blk[..n].copy_from_slice(chunk);

        // First array: input and condition are the mono signal; the incoming head is
        // silence of the array's channel width.
        self.head_a_blk[..self.channels0 * n].fill(0.0);
        {
            let ch = self.arrays[0].channels();
            let hs = self.arrays[0].head_size();
            self.arrays[0].process_block(
                &self.cond_blk[..n],
                &self.cond_blk[..n],
                &self.head_a_blk[..ch * n],
                &mut self.head_b_blk[..hs * n],
                &mut self.sig_b_blk[..ch * n],
                n,
            );
        }
        // Ping-pong: the outputs just written become the next array's inputs
        // without copying.
        std::mem::swap(&mut self.head_a_blk, &mut self.head_b_blk);
        std::mem::swap(&mut self.sig_a_blk, &mut self.sig_b_blk);

        for i in 1..self.arrays.len() {
            // Array i's layer input is array i-1's `channels`-wide layer output.
            let in_w = self.arrays[i - 1].channels();
            let ch = self.arrays[i].channels();
            let hs = self.arrays[i].head_size();
            self.arrays[i].process_block(
                &self.sig_a_blk[..in_w * n],
                &self.cond_blk[..n],
                &self.head_a_blk[..ch * n],
                &mut self.head_b_blk[..hs * n],
                &mut self.sig_b_blk[..ch * n],
                n,
            );
            std::mem::swap(&mut self.head_a_blk, &mut self.head_b_blk);
            std::mem::swap(&mut self.sig_a_blk, &mut self.sig_b_blk);
        }

        // After the final swap, head_a_blk holds the last array's head output, whose
        // head_size is 1 — row 0 is the per-sample head signal.
        for (t, s) in chunk.iter_mut().enumerate() {
            *s = self.head_scale * self.head_a_blk[t];
        }
    }

    /// Process a single mono sample, returning one output sample.
    ///
    /// Equivalent to a one-element [`Self::process_buffer`]; convenient for
    /// callers that are not buffer-oriented. Allocation-free.
    pub fn process_sample(&mut self, x: f32) -> f32 {
        // The sample is both the first array's input and every array's condition.
        let cond = [x];
        // `n` counts layer arrays here, not samples.
        let n = self.arrays.len();
        if n == 0 {
            return self.head_scale * x;
        }

        // First array: input and condition are the mono sample; the incoming head
        // is silence of the array's channel width.
        self.head_a[..self.channels0].fill(0.0);
        {
            let ch = self.arrays[0].channels();
            let hs = self.arrays[0].head_size();
            self.arrays[0].process_sample(
                &cond,
                &cond,
                &self.head_a[..ch],
                &mut self.head_b[..hs],
                &mut self.sig_b[..ch],
            );
        }
        // Ping-pong: outputs become the next array's inputs without copying.
        std::mem::swap(&mut self.head_a, &mut self.head_b);
        std::mem::swap(&mut self.sig_a, &mut self.sig_b);

        for i in 1..n {
            // Array i's layer input is array i-1's `channels`-wide layer output.
            let in_w = self.arrays[i - 1].channels();
            let ch = self.arrays[i].channels();
            let hs = self.arrays[i].head_size();
            self.arrays[i].process_sample(
                &self.sig_a[..in_w],
                &cond,
                &self.head_a[..ch],
                &mut self.head_b[..hs],
                &mut self.sig_b[..ch],
            );
            std::mem::swap(&mut self.head_a, &mut self.head_b);
            std::mem::swap(&mut self.sig_a, &mut self.sig_b);
        }

        // After the final swap, head_a holds the last array's head output.
        self.head_scale * self.head_a[0]
    }

    /// Reset all internal state (ring buffers) to silence.
    ///
    /// Clears every array's history and the per-sample carry buffers. The
    /// block-path `_blk` scratch buffers are not cleared. Each chunk zero-fills the
    /// first array's incoming head and otherwise reads only data written earlier in
    /// the same chunk, given the shape assumptions noted on [`Self::new`].
    pub fn reset(&mut self) {
        for a in &mut self.arrays {
            a.reset();
        }
        // Defensive: `process_sample` also rewrites these before reading them.
        self.head_a.fill(0.0);
        self.head_b.fill(0.0);
        self.sig_a.fill(0.0);
        self.sig_b.fill(0.0);
    }
}

/// Receptive field implied by `cfg`: `1 + Σ (kernel_size - 1) · dilation` over
/// every dilated layer in every array.
///
/// The stacked dilated convs compose additively, so this is the number of input
/// samples (current plus past) the final output depends on.
///
/// All arithmetic saturates. `cfg` comes from an untrusted `.nam` file, and the
/// dilation *values* are not covered by the weight-count check (only their number
/// is), so a corrupt config must not be able to panic (debug) or wrap (release)
/// here. A `kernel_size` of 0 contributes no taps instead of underflowing.
fn receptive_field(cfg: &WaveNetConfig) -> usize {
    cfg.layers
        .iter()
        .flat_map(|la| {
            // Past samples one layer adds per unit of dilation.
            let taps = la.kernel_size.saturating_sub(1);
            la.dilations.iter().map(move |&d| taps.saturating_mul(d))
        })
        // Start at 1 for the current sample itself.
        .fold(1usize, usize::saturating_add)
}

/// Number of `f32`s `cfg` implies in the flat weight blob, including the final
/// `head_scale`.
///
/// Per array, in blob order:
/// 1. The input rechannel weights (no bias).
/// 2. For each dilated layer: conv weights and bias, the condition input mixer
///    (no bias), and the 1x1 weights and bias.
/// 3. The head rechannel weights and, if `head_bias` is set, its bias.
///
/// Gated arrays double the conv and mixer output width.
///
/// All arithmetic saturates. `cfg` comes from an untrusted `.nam` file, and an
/// absurd shape must neither panic (debug) nor wrap around (release) to a small
/// count that could happen to match the blob. A saturated `usize::MAX` can never
/// equal a real blob length, so the caller rejects such files as a count mismatch.
fn expected_weight_count(cfg: &WaveNetConfig) -> usize {
    /// Overflow-safe sum of a fixed list of counts.
    fn sat_sum(terms: &[usize]) -> usize {
        terms.iter().fold(0, |acc, &t| acc.saturating_add(t))
    }

    let mut total: usize = 0;
    for la in &cfg.layers {
        let ch = la.channels;
        // Gated arrays double the conv / mixer output width.
        let mid = if la.gated { ch.saturating_mul(2) } else { ch };
        let per_layer = sat_sum(&[
            mid.saturating_mul(ch).saturating_mul(la.kernel_size), // conv weights
            mid,                                                   // conv bias
            mid.saturating_mul(la.condition_size),                 // input mixer (no bias)
            ch.saturating_mul(ch),                                 // 1x1 weights
            ch,                                                    // 1x1 bias
        ]);
        let head_bias = if la.head_bias { la.head_size } else { 0 };
        total = sat_sum(&[
            total,
            ch.saturating_mul(la.input_size), // rechannel (no bias)
            la.dilations.len().saturating_mul(per_layer),
            la.head_size.saturating_mul(ch), // head rechannel weights
            head_bias,                       // head rechannel bias
        ]);
    }
    total.saturating_add(1) // head_scale
}

/// Build one [`LayerArray`] by consuming its weights from `r` in blob order.
///
/// Reads, in order:
/// 1. The input rechannel weights (no bias).
/// 2. For each dilation: the conv weights and bias, the condition mixer weights
///    (no bias), and the 1x1 weights and bias.
/// 3. The head rechannel weights and, when `head_bias` is set, the head bias.
///
/// # Errors
///
/// Propagates the error from `Activation::from_name` for an unknown activation.
fn build_array(r: &mut Reader, la: &LayerArrayConfig) -> Result<LayerArray, Error> {
    let activation = Activation::from_name(&la.activation)?;
    let width = la.channels;
    // Gated arrays widen the conv / mixer output to twice the channel count.
    let conv_out = if la.gated { width * 2 } else { width };

    let rechannel_w = r.take(width * la.input_size);
    // Each dilation owns one contiguous run of parameters; iterate in order so the
    // reads stay aligned with the blob layout.
    let layers = la
        .dilations
        .iter()
        .map(|&dilation| {
            let conv_w = r.take(conv_out * width * la.kernel_size);
            let conv_b = r.take(conv_out);
            let mix_w = r.take(conv_out * la.condition_size);
            let one_w = r.take(width * width);
            let one_b = r.take(width);
            Layer::new(
                width,
                la.condition_size,
                la.kernel_size,
                dilation,
                activation,
                la.gated,
                conv_w,
                conv_b,
                mix_w,
                one_w,
                one_b,
            )
        })
        .collect::<Vec<_>>();

    let head_w = r.take(la.head_size * width);
    let head_b = la.head_bias.then(|| r.take(la.head_size));

    Ok(LayerArray::new(
        la.input_size,
        width,
        la.head_size,
        rechannel_w,
        layers,
        head_w,
        head_b,
    ))
}

#[cfg(test)]
mod tests {
    use super::*;

    // 1 array, 1 layer, 1 channel, ReLU. Weight order:
    // rechannel=1, conv_w=2, conv_b=0.5, mix_w=1, one_w=3, one_b=0.1,
    // head_rechannel=0.5, head_scale=10.
    const TINY: &str = r#"{
        "version": "0.5.4",
        "architecture": "WaveNet",
        "config": {
            "layers": [{
                "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
                "kernel_size": 1, "dilations": [1], "activation": "ReLU",
                "gated": false, "head_bias": false
            }],
            "head": null, "head_scale": 10.0
        },
        "weights": [1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]
    }"#;

    /// JSON text of the realistic standard-architecture fixture used by the
    /// end-to-end tests.
    fn standard_json() -> String {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/reference_standard.nam");
        std::fs::read_to_string(path).expect("read standard fixture")
    }

    #[test]
    fn tiny_model_matches_hand_computed_forward() {
        let model = NamModel::from_json_str(TINY).unwrap();
        let mut wn = WaveNet::new(&model).unwrap();

        // x=0.5, cond=0.5: z = 2*0.5 + 0.5 + 1*0.5 = 2.0 ; relu=2.0
        // head = 0.5*2.0 = 1.0 ; out = head_scale * 1.0 = 10.0
        let mut buf = [0.5_f32];
        wn.process_buffer(&mut buf);
        assert!((buf[0] - 10.0).abs() < 1e-5, "got {}", buf[0]);
    }

    #[test]
    fn receptive_field_sums_dilated_taps() {
        // 1 + Σ(k-1)·d. Mirrors the reference model: kernel 3, dilations [1,2] then [8].
        let mk = |dilations: Vec<usize>| LayerArrayConfig {
            input_size: 1,
            condition_size: 1,
            channels: 1,
            head_size: 1,
            kernel_size: 3,
            dilations,
            activation: "Tanh".into(),
            gated: false,
            head_bias: false,
        };
        let cfg = WaveNetConfig {
            layers: vec![mk(vec![1, 2]), mk(vec![8])],
            head: None,
            head_scale: 1.0,
        };
        // (3-1)*1 + (3-1)*2 + (3-1)*8 = 2 + 4 + 16 = 22, + 1 = 23.
        assert_eq!(receptive_field(&cfg), 23);

        // TINY (kernel 1, dilation 1) reaches back over no past samples: rf = 1.
        let model = NamModel::from_json_str(TINY).unwrap();
        assert_eq!(WaveNet::new(&model).unwrap().receptive_field(), 1);
    }

    /// Smoke test only: TINY's receptive field is 1, so its output cannot depend
    /// on history and this passes even if `reset` clears nothing. See
    /// `reset_matches_fresh_model_on_standard_model` for the real check.
    #[test]
    fn reset_restores_from_fresh_result() {
        let model = NamModel::from_json_str(TINY).unwrap();
        let mut wn = WaveNet::new(&model).unwrap();
        let mut warm = [0.3_f32, -0.7, 0.2];
        wn.process_buffer(&mut warm);
        wn.reset();
        let mut a = [0.5_f32];
        wn.process_buffer(&mut a);
        assert!((a[0] - 10.0).abs() < 1e-5, "got {}", a[0]);
    }

    /// `reset` must discard all streaming history. After warming up on one signal
    /// and resetting, the model must reproduce a fresh model's output on the standard
    /// fixture, whose dilated taps reach back over many samples. This is checked for
    /// both the block and the per-sample path.
    #[test]
    fn reset_matches_fresh_model_on_standard_model() {
        let json = standard_json();
        let model = NamModel::from_json_str(&json).expect("parse standard fixture");

        let len = 2 * MAX_BLOCK + 137;
        let signal: Vec<f32> = (0..len)
            .map(|i| (i as f32 * 0.021).sin() * 0.4)
            .collect();
        let warmup: Vec<f32> = (0..len)
            .map(|i| (i as f32 * 0.11).cos() * 0.6 - 0.1)
            .collect();

        // Reference: a fresh model that has never seen any input.
        let mut fresh = WaveNet::new(&model).unwrap();
        assert!(
            fresh.receptive_field() > 1,
            "fixture must carry history for this test to be meaningful"
        );
        let mut want = signal.clone();
        fresh.process_buffer(&mut want);

        let mut wn = WaveNet::new(&model).unwrap();

        // Block path after warm-up + reset.
        let mut scratch = warmup.clone();
        wn.process_buffer(&mut scratch);
        wn.reset();
        let mut got = signal.clone();
        wn.process_buffer(&mut got);
        for (i, (g, w)) in got.iter().zip(&want).enumerate() {
            assert!(
                (g - w).abs() < 1e-6,
                "block sample {i}: after reset {g}, fresh {w}"
            );
        }

        // Per-sample path after another warm-up + reset.
        let mut scratch = warmup;
        wn.process_buffer(&mut scratch);
        wn.reset();
        for (i, (&x, w)) in signal.iter().zip(&want).enumerate() {
            let g = wn.process_sample(x);
            assert!(
                (g - w).abs() < 1e-5,
                "per-sample {i}: after reset {g}, fresh {w}"
            );
        }
    }

    #[test]
    fn wrong_weight_count_is_rejected() {
        let bad = TINY.replace(
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]",
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5]",
        );
        let model = NamModel::from_json_str(&bad).unwrap();
        match WaveNet::new(&model) {
            Err(Error::WeightCountMismatch { expected, found }) => {
                assert_eq!(expected, 8);
                assert_eq!(found, 7);
            }
            other => panic!("expected WeightCountMismatch, got {other:?}"),
        }
    }

    #[test]
    fn wavenet_new_rejects_non_wavenet() {
        let lstm = r#"{
            "version": "0.5.4", "architecture": "LSTM",
            "config": { "input_size": 1, "hidden_size": 4, "num_layers": 1 },
            "weights": [0.0]
        }"#;
        let model = NamModel::from_json_str(lstm).unwrap();
        assert!(matches!(
            WaveNet::new(&model),
            Err(Error::UnsupportedArchitecture(_))
        ));
    }

    /// End-to-end on the realistic standard model: the block `process_buffer` must
    /// equal a per-sample `process_sample` loop over the same signal, including
    /// across `MAX_BLOCK` chunk boundaries. `tests/parity.rs` drives `process_buffer`
    /// (the block path) and pins it to the reference NAM oracle within 1e-5; this
    /// test additionally ties the block path to the per-sample path, so the two are
    /// transitively guaranteed equivalent and both oracle-correct.
    #[test]
    fn process_buffer_equals_process_sample_loop_on_standard_model() {
        let json = standard_json();
        let model = NamModel::from_json_str(&json).expect("parse standard fixture");

        // A signal longer than MAX_BLOCK so chunking is exercised.
        let len = 2 * MAX_BLOCK + 137;
        let signal: Vec<f32> = (0..len)
            .map(|i| (i as f32 * 0.013).sin() * 0.5 + (i as f32 * 0.27).sin() * 0.2)
            .collect();

        let mut per_sample = WaveNet::new(&model).unwrap();
        let want: Vec<f32> = signal
            .iter()
            .map(|&x| per_sample.process_sample(x))
            .collect();

        let mut block = WaveNet::new(&model).unwrap();
        let mut got = signal.clone();
        block.process_buffer(&mut got);

        for (i, (g, w)) in got.iter().zip(&want).enumerate() {
            assert!(
                (g - w).abs() < 1e-5,
                "sample {i}: block {g}, per-sample {w}"
            );
        }
    }
}