use crate::error::Error;
use crate::model::{LayerArrayConfig, NamModel, WaveNetConfig};
mod array;
mod conv;
mod layer;
use array::LayerArray;
use layer::{Activation, Layer};
/// `NamModel::architecture` value accepted by [`WaveNet::new`].
const ARCHITECTURE: &str = "WaveNet";
/// A NAM WaveNet model evaluated one sample at a time.
///
/// Built from a parsed model file with [`WaveNet::new`];
/// [`WaveNet::reset`] resets every layer array and zeroes the scratch buffers.
#[derive(Debug)]
pub struct WaveNet {
    /// Layer arrays, one per entry of the config's `layers`, in file order.
    arrays: Vec<LayerArray>,
    /// Multiplier applied to the final output; read from the last weight.
    head_scale: f32,
    /// `1 + sum((kernel_size - 1) * dilation)` over every layer, computed
    /// once at construction.
    receptive_field: usize,
    /// Channel count of the first layer array (0 when there are none); the
    /// first array's head input is zeroed over this many entries per sample.
    channels0: usize,
    /// Head-path ping-pong buffers: an array reads its head input from
    /// `head_a`, writes its head output to `head_b`, then the two are swapped.
    head_a: Vec<f32>,
    head_b: Vec<f32>,
    /// Array-output ping-pong buffers, swapped after every layer array
    /// together with the head buffers.
    sig_a: Vec<f32>,
    sig_b: Vec<f32>,
}
impl WaveNet {
pub fn new(model: &NamModel) -> Result<Self, Error> {
if model.architecture != ARCHITECTURE {
return Err(Error::UnsupportedArchitecture(model.architecture.clone()));
}
let cfg = &model.config;
let expected = expected_weight_count(cfg);
if expected != model.weights.len() {
return Err(Error::WeightCountMismatch {
expected,
found: model.weights.len(),
});
}
let mut r = Reader::new(&model.weights);
let mut arrays = Vec::with_capacity(cfg.layers.len());
for la in &cfg.layers {
arrays.push(build_array(&mut r, la)?);
}
let head_scale = r.take(1)[0];
let max_ch = arrays.iter().map(LayerArray::channels).max().unwrap_or(1);
let max_head = arrays.iter().map(LayerArray::head_size).max().unwrap_or(1);
let head_w = max_ch.max(max_head).max(1);
let sig_w = max_ch.max(1);
let channels0 = arrays.first().map_or(0, LayerArray::channels);
Ok(Self {
arrays,
head_scale,
receptive_field: receptive_field(cfg),
channels0,
head_a: vec![0.0; head_w],
head_b: vec![0.0; head_w],
sig_a: vec![0.0; sig_w],
sig_b: vec![0.0; sig_w],
})
}
pub fn receptive_field(&self) -> usize {
self.receptive_field
}
pub fn process_buffer(&mut self, io: &mut [f32]) {
for sample in io.iter_mut() {
*sample = self.process_sample(*sample);
}
}
pub fn process_sample(&mut self, x: f32) -> f32 {
let cond = [x];
let n = self.arrays.len();
if n == 0 {
return self.head_scale * x;
}
self.head_a[..self.channels0].fill(0.0);
{
let ch = self.arrays[0].channels();
let hs = self.arrays[0].head_size();
self.arrays[0].process_sample(
&cond,
&cond,
&self.head_a[..ch],
&mut self.head_b[..hs],
&mut self.sig_b[..ch],
);
}
std::mem::swap(&mut self.head_a, &mut self.head_b);
std::mem::swap(&mut self.sig_a, &mut self.sig_b);
for i in 1..n {
let in_w = self.arrays[i - 1].channels();
let ch = self.arrays[i].channels();
let hs = self.arrays[i].head_size();
self.arrays[i].process_sample(
&self.sig_a[..in_w],
&cond,
&self.head_a[..ch],
&mut self.head_b[..hs],
&mut self.sig_b[..ch],
);
std::mem::swap(&mut self.head_a, &mut self.head_b);
std::mem::swap(&mut self.sig_a, &mut self.sig_b);
}
self.head_scale * self.head_a[0]
}
pub fn reset(&mut self) {
for a in &mut self.arrays {
a.reset();
}
self.head_a.fill(0.0);
self.head_b.fill(0.0);
self.sig_a.fill(0.0);
self.sig_b.fill(0.0);
}
}
/// Receptive field, in samples, of a WaveNet with this configuration:
/// `1 + sum((kernel_size - 1) * dilation)` over every layer of every array.
///
/// Arithmetic saturates so that a malformed model file cannot make this
/// panic or wrap around. A `kernel_size` of 0 contributes nothing, and
/// absurdly large dilations clamp the result at `usize::MAX`.
fn receptive_field(cfg: &WaveNetConfig) -> usize {
    let mut rf: usize = 1;
    for la in &cfg.layers {
        // Extra past samples each dilated tap reaches; 0 for degenerate kernels.
        let span = la.kernel_size.saturating_sub(1);
        for &d in &la.dilations {
            rf = rf.saturating_add(span.saturating_mul(d));
        }
    }
    rf
}
/// Number of `f32` weights a model with this configuration must carry.
///
/// Follows the order in which `build_array` reads them. For each layer array
/// this is the input rechannel matrix, then for every dilation the dilated
/// conv (weights + bias), the condition mix-in (weights only) and the 1x1
/// conv (weights + bias), then the head matrix and its optional bias. One
/// trailing weight holds the head scale.
///
/// The configuration comes from a model file, so all arithmetic saturates.
/// An absurdly large configuration yields `usize::MAX`, which can never equal
/// a real weight-vector length, instead of panicking or wrapping around.
fn expected_weight_count(cfg: &WaveNetConfig) -> usize {
    let mut total: usize = 0;
    for la in &cfg.layers {
        // Output width of each dilated conv: doubled for gated arrays.
        let mid = if la.gated {
            la.channels.saturating_mul(2)
        } else {
            la.channels
        };
        // Input rechannel: channels x input_size, no bias.
        let rechannel = la.channels.saturating_mul(la.input_size);
        // One layer: conv weights + bias, mix-in weights, 1x1 weights + bias.
        let per_layer = mid
            .saturating_mul(la.channels)
            .saturating_mul(la.kernel_size)
            .saturating_add(mid)
            .saturating_add(mid.saturating_mul(la.condition_size))
            .saturating_add(la.channels.saturating_mul(la.channels))
            .saturating_add(la.channels);
        let layers = per_layer.saturating_mul(la.dilations.len());
        // Head rechannel: head_size x channels, plus an optional bias.
        let head_bias = if la.head_bias { la.head_size } else { 0 };
        let head = la
            .head_size
            .saturating_mul(la.channels)
            .saturating_add(head_bias);
        total = total
            .saturating_add(rechannel)
            .saturating_add(layers)
            .saturating_add(head);
    }
    // The trailing weight is the head scale.
    total.saturating_add(1)
}
/// Builds one [`LayerArray`], pulling its weights from `r`.
///
/// Weights are consumed in this order:
/// 1. the input rechannel matrix (`channels * input_size`);
/// 2. for each dilation:
///    - the dilated conv weights (`mid * channels * kernel_size`) and bias (`mid`),
///    - the condition mix-in weights (`mid * condition_size`),
///    - the 1x1 weights (`channels * channels`) and bias (`channels`);
/// 3. the head weights (`head_size * channels`), plus a head bias
///    (`head_size`) when `head_bias` is set.
///
/// `mid` is `2 * channels` for gated arrays and `channels` otherwise.
///
/// # Errors
///
/// Propagates the error from `Activation::from_name` when the activation
/// name is not recognised; in that case no weights are consumed.
fn build_array(r: &mut Reader, la: &LayerArrayConfig) -> Result<LayerArray, Error> {
    let activation = Activation::from_name(&la.activation)?;
    let channels = la.channels;
    let mid = if la.gated { 2 * channels } else { channels };

    let rechannel_w = r.take(channels * la.input_size);
    let layers: Vec<_> = la
        .dilations
        .iter()
        .map(|&dilation| {
            // These reads must stay in file order.
            let conv_w = r.take(mid * channels * la.kernel_size);
            let conv_b = r.take(mid);
            let mix_w = r.take(mid * la.condition_size);
            let one_w = r.take(channels * channels);
            let one_b = r.take(channels);
            Layer::new(
                channels,
                la.condition_size,
                la.kernel_size,
                dilation,
                activation,
                la.gated,
                conv_w,
                conv_b,
                mix_w,
                one_w,
                one_b,
            )
        })
        .collect();
    let head_w = r.take(la.head_size * channels);
    let head_b = la.head_bias.then(|| r.take(la.head_size));

    Ok(LayerArray::new(
        la.input_size,
        channels,
        la.head_size,
        rechannel_w,
        layers,
        head_w,
        head_b,
    ))
}
/// Forward-only cursor over a flat weight vector.
struct Reader<'a> {
    /// All weights, in file order.
    w: &'a [f32],
    /// Index of the next unread weight.
    i: usize,
}
impl<'a> Reader<'a> {
    /// Starts reading at the first weight of `w`.
    fn new(w: &'a [f32]) -> Self {
        Reader { w, i: 0 }
    }

    /// Copies out the next `n` weights and advances past them.
    ///
    /// # Panics
    ///
    /// Panics if fewer than `n` weights remain.
    fn take(&mut self, n: usize) -> Vec<f32> {
        let end = self.i + n;
        let chunk = Vec::from(&self.w[self.i..end]);
        self.i = end;
        chunk
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Single layer array, one channel, kernel_size 1 (receptive field 1).
    // Weights in read order: rechannel 1.0, conv_w 2.0, conv_b 0.5,
    // mix_w 1.0, one_w 3.0, one_b 0.1, head_w 0.5, head scale 10.0.
    const TINY: &str = r#"{
"version": "0.5.4",
"architecture": "WaveNet",
"config": {
"layers": [{
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_size": 1, "dilations": [1], "activation": "ReLU",
"gated": false, "head_bias": false
}],
"head": null, "head_scale": 10.0
},
"weights": [1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]
}"#;

    // Same as `TINY` but with kernel_size 2 (receptive field 2), so an output
    // depends on the previous sample and `reset` has history to clear.
    // Weights in read order: rechannel 1.0, conv_w [0.7, 0.4], conv_b 0.5,
    // mix_w 1.0, one_w 3.0, one_b 0.1, head_w 0.5, head scale 1.0.
    const TINY_K2: &str = r#"{
"version": "0.5.4",
"architecture": "WaveNet",
"config": {
"layers": [{
"input_size": 1, "condition_size": 1, "channels": 1, "head_size": 1,
"kernel_size": 2, "dilations": [1], "activation": "ReLU",
"gated": false, "head_bias": false
}],
"head": null, "head_scale": 1.0
},
"weights": [1.0, 0.7, 0.4, 0.5, 1.0, 3.0, 0.1, 0.5, 1.0]
}"#;

    #[test]
    fn tiny_model_matches_hand_computed_forward() {
        let model = NamModel::from_json_str(TINY).unwrap();
        let mut wn = WaveNet::new(&model).unwrap();
        let mut buf = [0.5_f32];
        wn.process_buffer(&mut buf);
        assert!((buf[0] - 10.0).abs() < 1e-5, "got {}", buf[0]);
    }

    #[test]
    fn receptive_field_sums_dilated_taps() {
        let mk = |dilations: Vec<usize>| LayerArrayConfig {
            input_size: 1,
            condition_size: 1,
            channels: 1,
            head_size: 1,
            kernel_size: 3,
            dilations,
            activation: "Tanh".into(),
            gated: false,
            head_bias: false,
        };
        let cfg = WaveNetConfig {
            layers: vec![mk(vec![1, 2]), mk(vec![8])],
            head: None,
            head_scale: 1.0,
        };
        assert_eq!(receptive_field(&cfg), 23);
        let model = NamModel::from_json_str(TINY).unwrap();
        assert_eq!(WaveNet::new(&model).unwrap().receptive_field(), 1);
        let k2 = NamModel::from_json_str(TINY_K2).unwrap();
        assert_eq!(WaveNet::new(&k2).unwrap().receptive_field(), 2);
    }

    #[test]
    fn reset_restores_from_fresh_result() {
        // Uses a model with history: with `TINY` (receptive field 1) this
        // test would pass even if `reset` did nothing.
        let model = NamModel::from_json_str(TINY_K2).unwrap();
        let probe = [0.5_f32, 0.2, 0.8];

        // Reference: the probe run through a freshly built model.
        let mut fresh = probe;
        WaveNet::new(&model).unwrap().process_buffer(&mut fresh);

        let mut wn = WaveNet::new(&model).unwrap();
        let mut warm = [0.3_f32, 0.9, 0.6];
        wn.process_buffer(&mut warm);

        // Without a reset, leftover history must change the first output;
        // otherwise the final comparison could not detect a no-op `reset`.
        let mut stale = probe;
        wn.process_buffer(&mut stale);
        assert!(
            (stale[0] - fresh[0]).abs() > 1e-4,
            "history had no effect: stale {} vs fresh {}",
            stale[0],
            fresh[0]
        );

        wn.reset();
        let mut after = probe;
        wn.process_buffer(&mut after);
        for (i, (a, f)) in after.iter().zip(fresh.iter()).enumerate() {
            assert!((a - f).abs() < 1e-6, "sample {i}: got {a}, expected {f}");
        }
    }

    #[test]
    fn wrong_weight_count_is_rejected() {
        let bad = TINY.replace(
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5, 10.0]",
            "[1.0, 2.0, 0.5, 1.0, 3.0, 0.1, 0.5]",
        );
        let model = NamModel::from_json_str(&bad).unwrap();
        match WaveNet::new(&model) {
            Err(Error::WeightCountMismatch { expected, found }) => {
                assert_eq!(expected, 8);
                assert_eq!(found, 7);
            }
            other => panic!("expected WeightCountMismatch, got {other:?}"),
        }
    }

    #[test]
    fn unsupported_architecture_is_rejected() {
        let bad = TINY.replace("\"WaveNet\"", "\"LSTM\"");
        let model = NamModel::from_json_str(&bad).unwrap();
        assert!(matches!(
            WaveNet::new(&model),
            Err(Error::UnsupportedArchitecture(_))
        ));
    }
}