use crate::error::Error;
use crate::model::{LayerArrayConfig, NamModel, WaveNetConfig};
use crate::reader::Reader;
mod array;
mod conv;
mod layer;
use array::LayerArray;
use conv::MAX_BLOCK;
use layer::{Activation, Layer};
/// WaveNet network built from a NAM model file: a chain of layer arrays whose
/// final head output is multiplied by a scalar `head_scale`.
#[derive(Debug)]
pub struct WaveNet {
    /// Layer arrays, run in order; array `i > 0` reads the signal output of array `i - 1`.
    arrays: Vec<LayerArray>,
    /// Output gain; it is the last value in the weight list.
    head_scale: f32,
    /// Receptive field in samples, computed from the config at construction.
    receptive_field: usize,
    /// Channel count of the first layer array (0 if there are none); the head
    /// accumulator is zeroed to this width before the first array runs.
    channels0: usize,
    /// Sample rate taken from the model at construction.
    sample_rate: f64,
    // Per-sample ping-pong buffers: each array reads the `*_a` buffer and writes
    // the `*_b` buffer, then each pair is swapped so the next array reads it.
    head_a: Vec<f32>,
    head_b: Vec<f32>,
    sig_a: Vec<f32>,
    sig_b: Vec<f32>,
    // Block-processing counterparts of the buffers above, each sized for up to
    // `MAX_BLOCK` frames and used the same ping-pong way.
    head_a_blk: Vec<f32>,
    head_b_blk: Vec<f32>,
    sig_a_blk: Vec<f32>,
    sig_b_blk: Vec<f32>,
    // Copy of the current input chunk, used as the conditioning signal for
    // every array (and as the input of the first array).
    cond_blk: Vec<f32>,
}
impl WaveNet {
pub fn new(model: &NamModel) -> Result<Self, Error> {
let cfg = match &model.config {
crate::model::ModelConfig::WaveNet(cfg) => cfg,
crate::model::ModelConfig::Lstm(_) => {
return Err(Error::UnsupportedArchitecture(model.architecture.clone()))
}
};
let expected = expected_weight_count(cfg);
if expected != model.weights.len() {
return Err(Error::WeightCountMismatch {
expected,
found: model.weights.len(),
});
}
let mut r = Reader::new(&model.weights);
let mut arrays = Vec::with_capacity(cfg.layers.len());
for la in &cfg.layers {
arrays.push(build_array(&mut r, la)?);
}
let head_scale = r.take(1)[0];
let max_ch = arrays.iter().map(LayerArray::channels).max().unwrap_or(1);
let max_head = arrays.iter().map(LayerArray::head_size).max().unwrap_or(1);
let head_w = max_ch.max(max_head).max(1);
let sig_w = max_ch.max(1);
let channels0 = arrays.first().map_or(0, LayerArray::channels);
Ok(Self {
arrays,
head_scale,
receptive_field: receptive_field(cfg),
channels0,
sample_rate: model.sample_rate(),
head_a: vec![0.0; head_w],
head_b: vec![0.0; head_w],
sig_a: vec![0.0; sig_w],
sig_b: vec![0.0; sig_w],
head_a_blk: vec![0.0; head_w * MAX_BLOCK],
head_b_blk: vec![0.0; head_w * MAX_BLOCK],
sig_a_blk: vec![0.0; sig_w * MAX_BLOCK],
sig_b_blk: vec![0.0; sig_w * MAX_BLOCK],
cond_blk: vec![0.0; MAX_BLOCK],
})
}
pub fn receptive_field(&self) -> usize {
self.receptive_field
}
pub fn sample_rate(&self) -> f64 {
self.sample_rate
}
pub fn process_buffer(&mut self, io: &mut [f32]) {
if self.arrays.is_empty() {
for s in io.iter_mut() {
*s *= self.head_scale;
}
return;
}
let mut off = 0;
while off < io.len() {
let n = (io.len() - off).min(MAX_BLOCK);
self.process_chunk(&mut io[off..off + n], n);
off += n;
}
}
fn process_chunk(&mut self, chunk: &mut [f32], n: usize) {
self.cond_blk[..n].copy_from_slice(chunk);
self.head_a_blk[..self.channels0 * n].fill(0.0);
{
let ch = self.arrays[0].channels();
let hs = self.arrays[0].head_size();
self.arrays[0].process_block(
&self.cond_blk[..n],
&self.cond_blk[..n],
&self.head_a_blk[..ch * n],
&mut self.head_b_blk[..hs * n],
&mut self.sig_b_blk[..ch * n],
n,
);
}
std::mem::swap(&mut self.head_a_blk, &mut self.head_b_blk);
std::mem::swap(&mut self.sig_a_blk, &mut self.sig_b_blk);
for i in 1..self.arrays.len() {
let in_w = self.arrays[i - 1].channels();
let ch = self.arrays[i].channels();
let hs = self.arrays[i].head_size();
self.arrays[i].process_block(
&self.sig_a_blk[..in_w * n],
&self.cond_blk[..n],
&self.head_a_blk[..ch * n],
&mut self.head_b_blk[..hs * n],
&mut self.sig_b_blk[..ch * n],
n,
);
std::mem::swap(&mut self.head_a_blk, &mut self.head_b_blk);
std::mem::swap(&mut self.sig_a_blk, &mut self.sig_b_blk);
}
for (t, s) in chunk.iter_mut().enumerate() {
*s = self.head_scale * self.head_a_blk[t];
}
}
pub fn process_sample(&mut self, x: f32) -> f32 {
let cond = [x];
let n = self.arrays.len();
if n == 0 {
return self.head_scale * x;
}
self.head_a[..self.channels0].fill(0.0);
{
let ch = self.arrays[0].channels();
let hs = self.arrays[0].head_size();
self.arrays[0].process_sample(
&cond,
&cond,
&self.head_a[..ch],
&mut self.head_b[..hs],
&mut self.sig_b[..ch],
);
}
std::mem::swap(&mut self.head_a, &mut self.head_b);
std::mem::swap(&mut self.sig_a, &mut self.sig_b);
for i in 1..n {
let in_w = self.arrays[i - 1].channels();
let ch = self.arrays[i].channels();
let hs = self.arrays[i].head_size();
self.arrays[i].process_sample(
&self.sig_a[..in_w],
&cond,
&self.head_a[..ch],
&mut self.head_b[..hs],
&mut self.sig_b[..ch],
);
std::mem::swap(&mut self.head_a, &mut self.head_b);
std::mem::swap(&mut self.sig_a, &mut self.sig_b);
}
self.head_scale * self.head_a[0]
}
pub fn reset(&mut self) {
for a in &mut self.arrays {
a.reset();
}
self.head_a.fill(0.0);
self.head_b.fill(0.0);
self.sig_a.fill(0.0);
self.sig_b.fill(0.0);
}
}
/// Receptive field, in samples, of all layer arrays chained together.
///
/// Starts at 1 (the current sample) and adds `(kernel_size - 1) * dilation`
/// for every dilated layer of every array. A malformed `kernel_size` of 0
/// contributes nothing instead of underflowing `usize` (which panics in debug
/// builds and wraps to a huge value in release).
fn receptive_field(cfg: &WaveNetConfig) -> usize {
    1 + cfg
        .layers
        .iter()
        .flat_map(|la| {
            let extra_taps = la.kernel_size.saturating_sub(1);
            la.dilations.iter().map(move |&d| extra_taps * d)
        })
        .sum::<usize>()
}
fn expected_weight_count(cfg: &WaveNetConfig) -> usize {
let mut total = 0;
for la in &cfg.layers {
let mid = if la.gated {
2 * la.channels
} else {
la.channels
};
total += la.channels * la.input_size; let per_layer = mid * la.channels * la.kernel_size + mid + mid * la.condition_size + la.channels * la.channels + la.channels; total += la.dilations.len() * per_layer;
total += la.head_size * la.channels; if la.head_bias {
total += la.head_size;
}
}
total + 1 }
fn build_array(r: &mut Reader, la: &LayerArrayConfig) -> Result<LayerArray, Error> {
let activation = Activation::from_name(&la.activation)?;
let mid = if la.gated {
2 * la.channels
} else {
la.channels
};
let rechannel_w = r.take(la.channels * la.input_size);
let mut layers = Vec::with_capacity(la.dilations.len());
for &d in &la.dilations {
let conv_w = r.take(mid * la.channels * la.kernel_size);
let conv_b = r.take(mid);
let mix_w = r.take(mid * la.condition_size);
let one_w = r.take(la.channels * la.channels);
let one_b = r.take(la.channels);
layers.push(Layer::new(
la.channels,
la.condition_size,
la.kernel_size,
d,
activation,
la.gated,
conv_w,
conv_b,
mix_w,
one_w,
one_b,
));
}
let head_w = r.take(la.head_size * la.channels);
let head_b = if la.head_bias {
Some(r.take(la.head_size))
} else {
None
};
Ok(LayerArray::new(
la.input_size,
la.channels,
la.head_size,
rechannel_w,
layers,
head_w,
head_b,
))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-array, single-channel, `kernel_size: 1` model small enough to
    /// evaluate by hand. It keeps no sample history (receptive field 1).
    const TINY: &str = r#"{
"version": "0.5.4",
"architecture": "WaveNet",
"config": {
"layers": [{
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_size": 1, "dilations": [1], "activation": "ReLU",
"gated": false, "head_bias": false
}],
"head": null, "head_scale": 10.0
},
"weights": [1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]
}"#;

    /// Same layout as `TINY` but with `kernel_size: 2`, so each output also
    /// depends on the previous input sample. This makes stale history
    /// observable, which `TINY` cannot do.
    const TINY_K2: &str = r#"{
"version": "0.5.4",
"architecture": "WaveNet",
"config": {
"layers": [{
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_size": 2, "dilations": [1], "activation": "ReLU",
"gated": false, "head_bias": false
}],
"head": null, "head_scale": 10.0
},
"weights": [1.0, 2.0, 1.5, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]
}"#;

    #[test]
    fn tiny_model_matches_hand_computed_forward() {
        let model = NamModel::from_json_str(TINY).unwrap();
        let mut wn = WaveNet::new(&model).unwrap();
        let mut buf = [0.5_f32];
        wn.process_buffer(&mut buf);
        assert!((buf[0] - 10.0).abs() < 1e-5, "got {}", buf[0]);
    }

    #[test]
    fn receptive_field_sums_dilated_taps() {
        let mk = |dilations: Vec<usize>| LayerArrayConfig {
            input_size: 1,
            condition_size: 1,
            channels: 1,
            head_size: 1,
            kernel_size: 3,
            dilations,
            activation: "Tanh".into(),
            gated: false,
            head_bias: false,
        };
        let cfg = WaveNetConfig {
            layers: vec![mk(vec![1, 2]), mk(vec![8])],
            head: None,
            head_scale: 1.0,
        };
        assert_eq!(receptive_field(&cfg), 23);
        let model = NamModel::from_json_str(TINY).unwrap();
        assert_eq!(WaveNet::new(&model).unwrap().receptive_field(), 1);
    }

    #[test]
    fn reset_restores_from_fresh_result() {
        let model = NamModel::from_json_str(TINY_K2).unwrap();
        let warm_up = [0.3_f32, -0.7, 0.2];
        let probe = [0.5_f32, 0.25];

        let mut fresh = WaveNet::new(&model).unwrap();
        let mut want = probe;
        fresh.process_buffer(&mut want);

        // Without a reset, the warm-up history must change the probe output;
        // otherwise this test could not catch a `reset` that does nothing.
        let mut stale = WaveNet::new(&model).unwrap();
        let mut stale_warm = warm_up;
        stale.process_buffer(&mut stale_warm);
        let mut stale_out = probe;
        stale.process_buffer(&mut stale_out);
        assert!(
            (stale_out[0] - want[0]).abs() > 1e-3,
            "model should depend on history: stale {}, fresh {}",
            stale_out[0],
            want[0]
        );

        let mut wn = WaveNet::new(&model).unwrap();
        let mut warm = warm_up;
        wn.process_buffer(&mut warm);
        wn.reset();
        let mut got = probe;
        wn.process_buffer(&mut got);
        for (i, (g, w)) in got.iter().zip(&want).enumerate() {
            assert!((g - w).abs() < 1e-6, "sample {i}: after reset {g}, fresh {w}");
        }
    }

    #[test]
    fn wrong_weight_count_is_rejected() {
        let bad = TINY.replace(
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]",
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5]",
        );
        let model = NamModel::from_json_str(&bad).unwrap();
        match WaveNet::new(&model) {
            Err(Error::WeightCountMismatch { expected, found }) => {
                assert_eq!(expected, 8);
                assert_eq!(found, 7);
            }
            other => panic!("expected WeightCountMismatch, got {other:?}"),
        }
    }

    #[test]
    fn wavenet_new_rejects_non_wavenet() {
        let lstm = r#"{
"version": "0.5.4", "architecture": "LSTM",
"config": { "input_size": 1, "hidden_size": 4, "num_layers": 1 },
"weights": [0.0]
}"#;
        let model = NamModel::from_json_str(lstm).unwrap();
        assert!(matches!(
            WaveNet::new(&model),
            Err(Error::UnsupportedArchitecture(_))
        ));
    }

    #[test]
    fn process_buffer_equals_process_sample_loop_on_standard_model() {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/reference_standard.nam");
        let json = std::fs::read_to_string(path).expect("read standard fixture");
        let model = NamModel::from_json_str(&json).expect("parse standard fixture");
        let len = 2 * MAX_BLOCK + 137;
        let signal: Vec<f32> = (0..len)
            .map(|i| (i as f32 * 0.013).sin() * 0.5 + (i as f32 * 0.27).sin() * 0.2)
            .collect();
        let mut per_sample = WaveNet::new(&model).unwrap();
        let want: Vec<f32> = signal
            .iter()
            .map(|&x| per_sample.process_sample(x))
            .collect();
        let mut block = WaveNet::new(&model).unwrap();
        let mut got = signal.clone();
        block.process_buffer(&mut got);
        for (i, (g, w)) in got.iter().zip(&want).enumerate() {
            assert!(
                (g - w).abs() < 1e-5,
                "sample {i}: block {g}, per-sample {w}"
            );
        }
    }
}