use super::activation::Activation;
use super::conv::{Conv1d, MAX_BLOCK};
/// Head network run after the layer stack: an ordered chain of
/// `(activation, conv)` stages, where each stage applies its activation
/// in place to its input and then runs the convolution.
///
/// All working memory is allocated once in `new`; the processing methods
/// only slice into the buffers held here.
#[derive(Debug, Clone)]
pub(super) struct PostStackHead {
/// Stages in execution order; the activation runs before its paired conv.
convs: Vec<(Activation, Conv1d)>,
/// Number of values `process_sample` expects in its input row.
in_channels: usize,
/// Number of values per sample exposed in the returned output slice.
out_channels: usize,
/// Scratch buffer written by stages 0, 2, 4, ... (ping-pongs with `scratch_b`).
scratch_a: Vec<f32>,
/// Scratch buffer written by stages 1, 3, 5, ... (ping-pongs with `scratch_a`).
scratch_b: Vec<f32>,
/// Staging row that `process_sample` copies its input into, so the
/// in-place activation of the first stage never touches the caller's slice.
sample_in: Vec<f32>,
}
impl PostStackHead {
    /// Builds a head from `convs` (run in the given order) and allocates all
    /// scratch memory up front so the processing methods never allocate.
    ///
    /// `in_channels` is the width of the row fed to [`Self::process_sample`];
    /// `out_channels` is the per-sample width of the returned output.
    pub(super) fn new(
        convs: Vec<(Activation, Conv1d)>,
        in_channels: usize,
        out_channels: usize,
    ) -> Self {
        // The scratch buffers are sliced by every stage's `in_ch * n` and
        // `out_ch * n`, and finally by `out_channels * n`. Size them for the
        // widest of all of those (not just the conv outputs) so a width
        // mismatch cannot turn into a panic that only shows up for blocks
        // close to `MAX_BLOCK`.
        let widest = convs
            .iter()
            .map(|(_, conv)| conv.in_ch().max(conv.out_ch()))
            .fold(in_channels.max(out_channels).max(1), usize::max);
        let scratch_len = widest * MAX_BLOCK;
        Self {
            convs,
            in_channels,
            out_channels,
            scratch_a: vec![0.0; scratch_len],
            scratch_b: vec![0.0; scratch_len],
            sample_in: vec![0.0; in_channels.max(1)],
        }
    }

    /// Width of the input row expected by [`Self::process_sample`].
    pub(super) fn in_channels(&self) -> usize {
        self.in_channels
    }

    /// Per-sample width of the output slices returned by the processing methods.
    pub(super) fn out_channels(&self) -> usize {
        self.out_channels
    }

    /// Receptive field of the chain: `1` plus `kernel() - 1` for every stage.
    ///
    /// Only the kernel size of each conv is taken into account here.
    #[cfg(test)]
    pub(super) fn receptive_field(&self) -> usize {
        1 + self
            .convs
            .iter()
            .map(|(_, conv)| conv.kernel() - 1)
            .sum::<usize>()
    }

    /// Runs the head on a single sample.
    ///
    /// `work` must contain exactly `in_channels` values (checked in debug
    /// builds; a length mismatch also panics in the copy below). The caller's
    /// slice is left untouched: it is copied into an internal staging row
    /// before the in-place activation of the first stage runs.
    ///
    /// Returns the first `out_channels` values of the final scratch buffer.
    pub(super) fn process_sample(&mut self, work: &[f32]) -> &[f32] {
        debug_assert_eq!(work.len(), self.in_channels);
        self.sample_in[..self.in_channels].copy_from_slice(work);

        // Temporarily move the staging row out of `self` so it can be passed
        // to `process_block` alongside `&mut self`. `take` leaves an empty
        // `Vec` behind, so nothing is allocated here.
        let mut row = std::mem::take(&mut self.sample_in);
        self.process_block(&mut row, 1);
        self.sample_in = row;

        let len = self.out_channels;
        self.output_slice(len)
    }

    /// Runs the head on a block of `n` samples (`n <= MAX_BLOCK`, checked in
    /// debug builds).
    ///
    /// `work` is modified in place: the first stage's activation is applied
    /// to `work[..in_ch * n]` before its conv reads it, so `work` must hold at
    /// least that many values. Later stages ping-pong between the two scratch
    /// buffers, each activating its source buffer in place before convolving
    /// into the other one.
    ///
    /// Returns the first `out_channels * n` values of whichever scratch buffer
    /// the last stage wrote. With no stages at all nothing is written and the
    /// returned slice is the untouched (zero-initialised) second scratch buffer.
    pub(super) fn process_block(&mut self, work: &mut [f32], n: usize) -> &[f32] {
        debug_assert!(n <= MAX_BLOCK);

        for (i, (act, conv)) in self.convs.iter_mut().enumerate() {
            let in_len = conv.in_ch() * n;
            let out_len = conv.out_ch() * n;

            // Stage 0 reads the caller's buffer and writes `scratch_a`; after
            // that, odd stages go a -> b and even stages go b -> a.
            let (src, dst): (&mut [f32], &mut [f32]) = if i == 0 {
                (&mut work[..], &mut self.scratch_a[..])
            } else if i % 2 == 1 {
                (&mut self.scratch_a[..], &mut self.scratch_b[..])
            } else {
                (&mut self.scratch_b[..], &mut self.scratch_a[..])
            };

            let input = &mut src[..in_len];
            for v in input.iter_mut() {
                *v = act.apply(*v);
            }
            conv.process_block(&*input, &mut dst[..out_len], n);
        }

        let len = self.out_channels * n;
        self.output_slice(len)
    }

    /// Resets every conv in the chain via its own `reset`.
    pub(super) fn reset(&mut self) {
        for (_, conv) in self.convs.iter_mut() {
            conv.reset();
        }
    }

    /// First `len` values of the buffer holding the final stage's output:
    /// `scratch_a` for an odd number of stages, `scratch_b` otherwise.
    fn output_slice(&self, len: usize) -> &[f32] {
        if self.convs.len() % 2 == 1 {
            &self.scratch_a[..len]
        } else {
            &self.scratch_b[..len]
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wavenet::activation::Activation;

    /// Fails with the observed value when `got` is not within `1e-6` of `want`.
    fn assert_close(got: f32, want: f32) {
        assert!((got - want).abs() < 1e-6, "got {}", got);
    }

    /// One stage: ReLU followed by a single-tap conv computing `3 * x + 0.5`.
    #[test]
    fn single_conv_head_applies_activation_then_conv() {
        let stage = (
            Activation::Relu,
            Conv1d::new(1, 1, 1, 1, vec![3.0], Some(vec![0.5])),
        );
        let mut head = PostStackHead::new(vec![stage], 1, 1);

        // Positive input passes through the ReLU: 3 * 2 + 0.5.
        let positive = [2.0_f32];
        assert_close(head.process_sample(&positive)[0], 6.5);

        // Negative input is clamped to zero first, leaving only the bias.
        let negative = [-1.0_f32];
        assert_close(head.process_sample(&negative)[0], 0.5);

        assert_eq!(head.out_channels(), 1);
    }

    /// Two stages: relu(3) * 2 + 0 = 6, then relu(6) * 1 + 1 = 7.
    #[test]
    fn two_conv_head_chains_activation_conv_activation_conv() {
        let doubling = (
            Activation::Relu,
            Conv1d::new(1, 1, 1, 1, vec![2.0], Some(vec![0.0])),
        );
        let plus_one = (
            Activation::Relu,
            Conv1d::new(1, 1, 1, 1, vec![1.0], Some(vec![1.0])),
        );
        let mut head = PostStackHead::new(vec![doubling, plus_one], 1, 1);

        let input = [3.0_f32];
        assert_close(head.process_sample(&input)[0], 7.0);
    }

    /// Kernels of 16 and 1 contribute 15 and 0 on top of the base of 1.
    #[test]
    fn receptive_field_sums_kernel_minus_one() {
        let wide = (
            Activation::Relu,
            Conv1d::new(1, 1, 16, 1, vec![0.0; 16], Some(vec![0.0])),
        );
        let pointwise = (
            Activation::Relu,
            Conv1d::new(1, 1, 1, 1, vec![0.0], Some(vec![0.0])),
        );
        let head = PostStackHead::new(vec![wide, pointwise], 1, 1);

        assert_eq!(head.receptive_field(), 16);
    }
}