1use crate::error::Error;
12use crate::model::{LayerArrayConfig, NamModel, WaveNetConfig};
13use crate::reader::Reader;
14
15mod array;
16mod conv;
17mod layer;
18
19use array::LayerArray;
20use conv::MAX_BLOCK;
21use layer::{Activation, Layer};
22
23#[derive(Debug)]
25pub struct WaveNet {
26 arrays: Vec<LayerArray>,
27 head_scale: f32,
28 receptive_field: usize,
31 channels0: usize,
33 sample_rate: f64,
35 head_a: Vec<f32>,
37 head_b: Vec<f32>,
38 sig_a: Vec<f32>,
40 sig_b: Vec<f32>,
41 head_a_blk: Vec<f32>,
44 head_b_blk: Vec<f32>,
45 sig_a_blk: Vec<f32>,
46 sig_b_blk: Vec<f32>,
47 cond_blk: Vec<f32>,
48}
49
50impl WaveNet {
51 pub fn new(model: &NamModel) -> Result<Self, Error> {
56 let cfg = match &model.config {
57 crate::model::ModelConfig::WaveNet(cfg) => cfg,
58 crate::model::ModelConfig::Lstm(_) => {
59 return Err(Error::UnsupportedArchitecture(model.architecture.clone()))
60 }
61 };
62
63 let expected = expected_weight_count(cfg);
64 if expected != model.weights.len() {
65 return Err(Error::WeightCountMismatch {
66 expected,
67 found: model.weights.len(),
68 });
69 }
70
71 let mut r = Reader::new(&model.weights);
72 let mut arrays = Vec::with_capacity(cfg.layers.len());
73 for la in &cfg.layers {
74 arrays.push(build_array(&mut r, la)?);
75 }
76 let head_scale = r.take(1)[0];
77
78 let max_ch = arrays.iter().map(LayerArray::channels).max().unwrap_or(1);
79 let max_head = arrays.iter().map(LayerArray::head_size).max().unwrap_or(1);
80 let head_w = max_ch.max(max_head).max(1);
81 let sig_w = max_ch.max(1);
82 let channels0 = arrays.first().map_or(0, LayerArray::channels);
83
84 Ok(Self {
85 arrays,
86 head_scale,
87 receptive_field: receptive_field(cfg),
88 channels0,
89 sample_rate: model.sample_rate(),
90 head_a: vec![0.0; head_w],
91 head_b: vec![0.0; head_w],
92 sig_a: vec![0.0; sig_w],
93 sig_b: vec![0.0; sig_w],
94 head_a_blk: vec![0.0; head_w * MAX_BLOCK],
95 head_b_blk: vec![0.0; head_w * MAX_BLOCK],
96 sig_a_blk: vec![0.0; sig_w * MAX_BLOCK],
97 sig_b_blk: vec![0.0; sig_w * MAX_BLOCK],
98 cond_blk: vec![0.0; MAX_BLOCK],
99 })
100 }
101
102 pub fn receptive_field(&self) -> usize {
110 self.receptive_field
111 }
112
113 pub fn sample_rate(&self) -> f64 {
115 self.sample_rate
116 }
117
118 pub fn process_buffer(&mut self, io: &mut [f32]) {
128 if self.arrays.is_empty() {
129 for s in io.iter_mut() {
130 *s *= self.head_scale;
131 }
132 return;
133 }
134 let mut off = 0;
135 while off < io.len() {
136 let n = (io.len() - off).min(MAX_BLOCK);
137 self.process_chunk(&mut io[off..off + n], n);
138 off += n;
139 }
140 }
141
142 fn process_chunk(&mut self, chunk: &mut [f32], n: usize) {
145 self.cond_blk[..n].copy_from_slice(chunk);
147
148 self.head_a_blk[..self.channels0 * n].fill(0.0);
151 {
152 let ch = self.arrays[0].channels();
153 let hs = self.arrays[0].head_size();
154 self.arrays[0].process_block(
155 &self.cond_blk[..n],
156 &self.cond_blk[..n],
157 &self.head_a_blk[..ch * n],
158 &mut self.head_b_blk[..hs * n],
159 &mut self.sig_b_blk[..ch * n],
160 n,
161 );
162 }
163 std::mem::swap(&mut self.head_a_blk, &mut self.head_b_blk);
164 std::mem::swap(&mut self.sig_a_blk, &mut self.sig_b_blk);
165
166 for i in 1..self.arrays.len() {
167 let in_w = self.arrays[i - 1].channels();
168 let ch = self.arrays[i].channels();
169 let hs = self.arrays[i].head_size();
170 self.arrays[i].process_block(
171 &self.sig_a_blk[..in_w * n],
172 &self.cond_blk[..n],
173 &self.head_a_blk[..ch * n],
174 &mut self.head_b_blk[..hs * n],
175 &mut self.sig_b_blk[..ch * n],
176 n,
177 );
178 std::mem::swap(&mut self.head_a_blk, &mut self.head_b_blk);
179 std::mem::swap(&mut self.sig_a_blk, &mut self.sig_b_blk);
180 }
181
182 for (t, s) in chunk.iter_mut().enumerate() {
185 *s = self.head_scale * self.head_a_blk[t];
186 }
187 }
188
189 pub fn process_sample(&mut self, x: f32) -> f32 {
194 let cond = [x];
195 let n = self.arrays.len();
196 if n == 0 {
197 return self.head_scale * x;
198 }
199
200 self.head_a[..self.channels0].fill(0.0);
203 {
204 let ch = self.arrays[0].channels();
205 let hs = self.arrays[0].head_size();
206 self.arrays[0].process_sample(
207 &cond,
208 &cond,
209 &self.head_a[..ch],
210 &mut self.head_b[..hs],
211 &mut self.sig_b[..ch],
212 );
213 }
214 std::mem::swap(&mut self.head_a, &mut self.head_b);
215 std::mem::swap(&mut self.sig_a, &mut self.sig_b);
216
217 for i in 1..n {
218 let in_w = self.arrays[i - 1].channels();
219 let ch = self.arrays[i].channels();
220 let hs = self.arrays[i].head_size();
221 self.arrays[i].process_sample(
222 &self.sig_a[..in_w],
223 &cond,
224 &self.head_a[..ch],
225 &mut self.head_b[..hs],
226 &mut self.sig_b[..ch],
227 );
228 std::mem::swap(&mut self.head_a, &mut self.head_b);
229 std::mem::swap(&mut self.sig_a, &mut self.sig_b);
230 }
231
232 self.head_scale * self.head_a[0]
234 }
235
236 pub fn reset(&mut self) {
238 for a in &mut self.arrays {
239 a.reset();
240 }
241 self.head_a.fill(0.0);
242 self.head_b.fill(0.0);
243 self.sig_a.fill(0.0);
244 self.sig_b.fill(0.0);
245 }
246}
247
248fn receptive_field(cfg: &WaveNetConfig) -> usize {
252 let mut rf = 1;
253 for la in &cfg.layers {
254 for &d in &la.dilations {
255 rf += (la.kernel_size - 1) * d;
256 }
257 }
258 rf
259}
260
261fn expected_weight_count(cfg: &WaveNetConfig) -> usize {
264 let mut total = 0;
265 for la in &cfg.layers {
266 let mid = if la.gated {
267 2 * la.channels
268 } else {
269 la.channels
270 };
271 total += la.channels * la.input_size; let per_layer = mid * la.channels * la.kernel_size + mid + mid * la.condition_size + la.channels * la.channels + la.channels; total += la.dilations.len() * per_layer;
278 total += la.head_size * la.channels; if la.head_bias {
280 total += la.head_size;
281 }
282 }
283 total + 1 }
285
286fn build_array(r: &mut Reader, la: &LayerArrayConfig) -> Result<LayerArray, Error> {
287 let activation = Activation::from_name(&la.activation)?;
288 let mid = if la.gated {
289 2 * la.channels
290 } else {
291 la.channels
292 };
293
294 let rechannel_w = r.take(la.channels * la.input_size);
295 let mut layers = Vec::with_capacity(la.dilations.len());
296 for &d in &la.dilations {
297 let conv_w = r.take(mid * la.channels * la.kernel_size);
298 let conv_b = r.take(mid);
299 let mix_w = r.take(mid * la.condition_size);
300 let one_w = r.take(la.channels * la.channels);
301 let one_b = r.take(la.channels);
302 layers.push(Layer::new(
303 la.channels,
304 la.condition_size,
305 la.kernel_size,
306 d,
307 activation,
308 la.gated,
309 conv_w,
310 conv_b,
311 mix_w,
312 one_w,
313 one_b,
314 ));
315 }
316 let head_w = r.take(la.head_size * la.channels);
317 let head_b = if la.head_bias {
318 Some(r.take(la.head_size))
319 } else {
320 None
321 };
322
323 Ok(LayerArray::new(
324 la.input_size,
325 la.channels,
326 la.head_size,
327 rechannel_w,
328 layers,
329 head_w,
330 head_b,
331 ))
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest WaveNet the loader accepts: one array, one layer, one channel.
    /// Its eight weights are seven layer-array values followed by the head
    /// scale.
    const TINY: &str = r#"{
        "version": "0.5.4",
        "architecture": "WaveNet",
        "config": {
            "layers": [{
                "input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
                "kernel_size": 1, "dilations": [1], "activation": "ReLU",
                "gated": false, "head_bias": false
            }],
            "head": null, "head_scale": 10.0
        },
        "weights": [1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]
    }"#;

    /// Parses [`TINY`] and builds a network from it, panicking on any failure.
    fn tiny_net() -> WaveNet {
        let model = NamModel::from_json_str(TINY).unwrap();
        WaveNet::new(&model).unwrap()
    }

    /// A single 0.5 input through the tiny model must produce the
    /// hand-computed output of 10.0.
    #[test]
    fn tiny_model_matches_hand_computed_forward() {
        let mut net = tiny_net();

        let mut io = [0.5_f32];
        net.process_buffer(&mut io);
        let out = io[0];
        assert!((out - 10.0).abs() < 1e-5, "got {}", out);
    }

    /// `1 + Σ (kernel_size - 1) * dilation`, both for a hand-built config and
    /// through the public getter.
    #[test]
    fn receptive_field_sums_dilated_taps() {
        fn array_with(dilations: Vec<usize>) -> LayerArrayConfig {
            LayerArrayConfig {
                input_size: 1,
                condition_size: 1,
                channels: 1,
                head_size: 1,
                kernel_size: 3,
                dilations,
                activation: "Tanh".into(),
                gated: false,
                head_bias: false,
            }
        }

        let cfg = WaveNetConfig {
            layers: vec![array_with(vec![1, 2]), array_with(vec![8])],
            head: None,
            head_scale: 1.0,
        };
        // 1 + 2*1 + 2*2 + 2*8
        assert_eq!(receptive_field(&cfg), 23);

        // kernel_size 1 contributes nothing, whatever the dilation.
        assert_eq!(tiny_net().receptive_field(), 1);
    }

    /// After `reset`, a previously used network must answer like a fresh one.
    #[test]
    fn reset_restores_from_fresh_result() {
        let mut net = tiny_net();

        let mut warm_up = [0.3_f32, -0.7, 0.2];
        net.process_buffer(&mut warm_up);
        net.reset();

        let mut io = [0.5_f32];
        net.process_buffer(&mut io);
        let out = io[0];
        assert!((out - 10.0).abs() < 1e-5, "got {}", out);
    }

    /// Dropping one weight must be reported with both counts.
    #[test]
    fn wrong_weight_count_is_rejected() {
        let truncated = TINY.replace(
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]",
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5]",
        );
        let model = NamModel::from_json_str(&truncated).unwrap();

        match WaveNet::new(&model) {
            Err(Error::WeightCountMismatch { expected, found }) => {
                assert_eq!(expected, 8);
                assert_eq!(found, 7);
            }
            other => panic!("expected WeightCountMismatch, got {other:?}"),
        }
    }

    /// A model whose config is not the WaveNet variant cannot be built.
    #[test]
    fn wavenet_new_rejects_non_wavenet() {
        let lstm = r#"{
            "version": "0.5.4", "architecture": "LSTM",
            "config": { "input_size": 1, "hidden_size": 4, "num_layers": 1 },
            "weights": [0.0]
        }"#;
        let model = NamModel::from_json_str(lstm).unwrap();

        let result = WaveNet::new(&model);
        assert!(matches!(result, Err(Error::UnsupportedArchitecture(_))));
    }

    /// The block path and the per-sample path must agree on a real model, for
    /// a signal that spans two full blocks plus a partial one.
    #[test]
    fn process_buffer_equals_process_sample_loop_on_standard_model() {
        let fixture = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/reference_standard.nam");
        let json = std::fs::read_to_string(fixture).expect("read standard fixture");
        let model = NamModel::from_json_str(&json).expect("parse standard fixture");

        let len = 2 * MAX_BLOCK + 137;
        let signal: Vec<f32> = (0..len)
            .map(|i| {
                let t = i as f32;
                (t * 0.013).sin() * 0.5 + (t * 0.27).sin() * 0.2
            })
            .collect();

        // Reference: one sample at a time.
        let mut per_sample = WaveNet::new(&model).unwrap();
        let want: Vec<f32> = signal
            .iter()
            .map(|&x| per_sample.process_sample(x))
            .collect();

        // Under test: the whole buffer at once, on a separate instance.
        let mut block = WaveNet::new(&model).unwrap();
        let mut got = signal.clone();
        block.process_buffer(&mut got);

        for (i, (g, w)) in got.iter().zip(&want).enumerate() {
            assert!(
                (g - w).abs() < 1e-5,
                "sample {i}: block {g}, per-sample {w}"
            );
        }
    }
}