/// Upper bound on the number of frames handed to `Conv1d::process_block` in one call.
///
/// `Conv1d::new` sizes every row of its staging scratch as `history + MAX_BLOCK`, and
/// `process_block` debug-asserts `n <= MAX_BLOCK`.
pub(super) const MAX_BLOCK: usize = 1024;
/// Streaming dilated 1-D convolution that keeps its own input history between calls.
///
/// Input can be fed one frame at a time (`process_sample`) or as a channel-major block
/// (`process_block`); both advance the same ring-buffer history.
#[derive(Debug, Clone)]
pub(super) struct Conv1d {
/// Number of input channels, i.e. the length of one input frame.
in_ch: usize,
/// Number of output channels, i.e. the length of one output frame.
out_ch: usize,
/// Number of taps per (output, input) channel pair.
kernel: usize,
/// Distance, in frames, between consecutive taps.
dilation: usize,
/// Flat weights of length `out_ch * in_ch * kernel`, indexed as
/// `(o * in_ch + j) * kernel + k`. Tap `kernel - 1` multiplies the newest frame.
weights: Vec<f32>,
/// Optional per-output-channel bias of length `out_ch`; treated as zero when absent.
bias: Option<Vec<f32>>,
/// Circular input history: `ring_len` frames of `in_ch` samples each, stored frame by
/// frame (`ring[slot * in_ch + j]`).
ring: Vec<f32>,
/// Number of frames in `ring`: `(kernel - 1) * dilation + 1`.
ring_len: usize,
/// Slot of the most recently written frame in `ring`.
pos: usize,
/// Scratch for `process_block`: one row per input channel holding the carried-over
/// history followed by the current block. Fully rewritten before it is read.
staged: Vec<f32>,
/// Length of one `staged` row: `(ring_len - 1) + MAX_BLOCK`.
staged_stride: usize,
}
impl Conv1d {
    /// Builds a convolution whose history starts out as all zeros.
    ///
    /// * `in_ch` / `out_ch` - frame widths on the input and output side.
    /// * `kernel` - number of taps; must be at least 1.
    /// * `dilation` - distance, in frames, between consecutive taps.
    /// * `weights` - flat array indexed as `(o * in_ch + j) * kernel + k`, where tap
    ///   `kernel - 1` multiplies the newest frame and tap `0` the oldest one.
    /// * `bias` - optional per-output-channel offset.
    ///
    /// # Panics
    ///
    /// Panics if `weights.len() != out_ch * in_ch * kernel`, if `bias` is present with a
    /// length other than `out_ch`, or if `kernel` is zero.
    pub(super) fn new(
        in_ch: usize,
        out_ch: usize,
        kernel: usize,
        dilation: usize,
        weights: Vec<f32>,
        bias: Option<Vec<f32>>,
    ) -> Self {
        assert_eq!(weights.len(), out_ch * in_ch * kernel, "conv weight count");
        if let Some(b) = &bias {
            assert_eq!(b.len(), out_ch, "conv bias count");
        }
        // `kernel - 1` underflows for a zero-tap kernel. A debug build would panic with an
        // unrelated overflow message; a release build would wrap to `ring_len == 0` and
        // hand back an instance that divides by zero on first use. Reject it up front.
        assert!(kernel >= 1, "conv kernel must have at least one tap");

        // Frames of the past that the oldest tap reaches back to.
        let hist_len = (kernel - 1) * dilation;
        let ring_len = hist_len + 1;
        let staged_stride = hist_len + MAX_BLOCK;
        Self {
            in_ch,
            out_ch,
            kernel,
            dilation,
            weights,
            bias,
            ring: vec![0.0; in_ch * ring_len],
            ring_len,
            // Start on the last slot so the first write wraps around to slot 0.
            pos: ring_len - 1,
            staged: vec![0.0; in_ch * staged_stride],
            staged_stride,
        }
    }

    /// Number of output channels.
    #[cfg(test)]
    pub(super) fn out_ch(&self) -> usize {
        self.out_ch
    }

    /// Pushes one input frame and writes the matching output frame.
    ///
    /// For every output channel `o`:
    /// `out[o] = bias[o] + sum over k, j of w[o][j][k] * x[now - (kernel - 1 - k) * dilation][j]`,
    /// with frames before the start of the stream (or the last `reset`) read as zero.
    ///
    /// `input` must hold `in_ch` samples and `out` must hold `out_ch` samples; both are
    /// checked with `debug_assert_eq!`.
    pub(super) fn process_sample(&mut self, input: &[f32], out: &mut [f32]) {
        debug_assert_eq!(input.len(), self.in_ch);
        debug_assert_eq!(out.len(), self.out_ch);

        // Advance the write head and store the new frame as the newest history entry.
        self.pos = (self.pos + 1) % self.ring_len;
        let head = self.pos * self.in_ch;
        self.ring[head..head + self.in_ch].copy_from_slice(input);

        for (o, y) in out[..self.out_ch].iter_mut().enumerate() {
            let mut acc = self.bias.as_ref().map_or(0.0, |b| b[o]);
            let wo = o * self.in_ch * self.kernel;
            // Taps outermost, channels innermost: `process_block` accumulates in the same
            // order, which keeps the two paths numerically aligned.
            for k in 0..self.kernel {
                let back = (self.kernel - 1 - k) * self.dilation;
                // `back < ring_len`, so adding `ring_len` first keeps the subtraction
                // from underflowing.
                let col = (self.pos + self.ring_len - back) % self.ring_len;
                let frame = &self.ring[col * self.in_ch..(col + 1) * self.in_ch];
                for (j, &x) in frame.iter().enumerate() {
                    acc += self.weights[wo + j * self.kernel + k] * x;
                }
            }
            *y = acc;
        }
    }

    /// Processes `n` consecutive frames in one call.
    ///
    /// Both buffers are channel-major: sample `t` of input channel `j` is
    /// `block_in[j * n + t]`, and sample `t` of output channel `o` is written to
    /// `block_out[o * n + t]`. The history is advanced by `n` frames, so calls can be
    /// freely interleaved with `process_sample`.
    ///
    /// `n` must not exceed [`MAX_BLOCK`], `block_in` must hold `in_ch * n` samples and
    /// `block_out` must hold `out_ch * n` samples; all three are checked with debug
    /// assertions. A call with `n == 0` does nothing.
    pub(super) fn process_block(&mut self, block_in: &[f32], block_out: &mut [f32], n: usize) {
        debug_assert!(n <= MAX_BLOCK);
        debug_assert_eq!(block_in.len(), self.in_ch * n);
        debug_assert_eq!(block_out.len(), self.out_ch * n);
        if n == 0 {
            return;
        }

        let hist_len = self.ring_len - 1;
        let stride = self.staged_stride;

        // Stage each input channel as one contiguous row: carried-over history first,
        // then the new block, so every tap below reads a plain `n`-long slice.
        for j in 0..self.in_ch {
            let row = j * stride;
            for h in 0..hist_len {
                // The newest frame sits at `pos` and the oldest at `pos + 1`. Only the
                // newest `hist_len` frames are needed, and those start at `pos + 2`.
                let col = (self.pos + 2 + h) % self.ring_len;
                self.staged[row + h] = self.ring[col * self.in_ch + j];
            }
            self.staged[row + hist_len..row + hist_len + n]
                .copy_from_slice(&block_in[j * n..(j + 1) * n]);
        }

        // Seed every output row with its bias, then accumulate tap by tap in the same
        // order as `process_sample` (taps outermost, channels innermost).
        for (o, out) in block_out[..self.out_ch * n].chunks_exact_mut(n).enumerate() {
            out.fill(self.bias.as_ref().map_or(0.0, |bias| bias[o]));
            let wo = o * self.in_ch * self.kernel;
            for k in 0..self.kernel {
                let back = (self.kernel - 1 - k) * self.dilation;
                for j in 0..self.in_ch {
                    let w = self.weights[wo + j * self.kernel + k];
                    // Index `hist_len` of a row is block sample 0; step `back` frames
                    // into the past from there.
                    let base = j * stride + hist_len - back;
                    let src = &self.staged[base..base + n];
                    for (y, &x) in out.iter_mut().zip(src) {
                        *y += w * x;
                    }
                }
            }
        }

        // Replay the block into the ring so the next call sees it as history.
        for t in 0..n {
            self.pos = (self.pos + 1) % self.ring_len;
            let head = self.pos * self.in_ch;
            for j in 0..self.in_ch {
                self.ring[head + j] = block_in[j * n + t];
            }
        }
    }

    /// Zeroes the input history and rewinds the write head to its initial slot.
    ///
    /// The staging scratch is left alone because `process_block` rewrites everything it
    /// reads from it.
    pub(super) fn reset(&mut self) {
        self.ring.fill(0.0);
        self.pos = self.ring_len - 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `xs` through `conv` one frame at a time and returns every output frame.
    fn run(conv: &mut Conv1d, xs: &[&[f32]]) -> Vec<Vec<f32>> {
        let mut frame = vec![0.0; conv.out_ch()];
        let mut outputs = Vec::with_capacity(xs.len());
        for x in xs {
            conv.process_sample(x, &mut frame);
            outputs.push(frame.clone());
        }
        outputs
    }

    /// Fixed input/output vectors. The test name suggests they mirror a "namcore"
    /// Conv1d reference; that origin is not visible from this file.
    #[test]
    fn matches_namcore_conv1d_vectors() {
        // Two taps, no bias, no dilation.
        let mut basic = Conv1d::new(1, 1, 2, 1, vec![1.0, 2.0], None);
        let got = run(&mut basic, &[&[1.0], &[2.0], &[3.0], &[4.0]]);
        assert_eq!(got, vec![vec![2.0], vec![5.0], vec![8.0], vec![11.0]]);

        // Only the older tap is non-zero, so the bias shows up alone on the first frame.
        let mut biased = Conv1d::new(1, 1, 2, 1, vec![1.0, 0.0], Some(vec![5.0]));
        let got = run(&mut biased, &[&[2.0], &[3.0]]);
        assert_eq!(got, vec![vec![5.0], vec![7.0]]);

        // Dilation 2: the older tap reads the frame two steps back.
        let mut dil = Conv1d::new(1, 1, 2, 2, vec![1.0, 2.0], None);
        let got = run(&mut dil, &[&[1.0], &[2.0], &[3.0], &[4.0]]);
        assert_eq!(got, vec![vec![2.0], vec![4.0], vec![7.0], vec![10.0]]);
    }

    #[test]
    fn kernel2_dilation1_single_channel() {
        let weights = vec![0.5, 2.0];
        let bias = Some(vec![0.1]);
        let mut conv = Conv1d::new(1, 1, 2, 1, weights, bias);
        let got = run(&mut conv, &[&[1.0], &[2.0], &[3.0]]);
        assert_eq!(got, vec![vec![2.1], vec![4.6], vec![7.1]]);
    }

    #[test]
    fn kernel2_dilation2_skips_a_sample() {
        let weights = vec![0.5, 2.0];
        let mut conv = Conv1d::new(1, 1, 2, 2, weights, None);
        let got = run(&mut conv, &[&[1.0], &[2.0], &[3.0], &[4.0]]);
        assert_eq!(got, vec![vec![2.0], vec![4.0], vec![6.5], vec![9.0]]);
    }

    #[test]
    fn one_by_one_is_a_matmul() {
        // With a single tap the weights act as a 2x2 matrix, rows [1, 2] and [3, 4].
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let bias = Some(vec![10.0, 20.0]);
        let mut conv = Conv1d::new(2, 2, 1, 1, weights, bias);
        let got = run(&mut conv, &[&[1.0, 1.0], &[2.0, 0.0]]);
        assert_eq!(got, vec![vec![13.0, 27.0], vec![12.0, 26.0]]);
    }

    /// Direct single-channel reference: the last tap multiplies the current sample and
    /// anything before the start of `xs` reads as zero.
    fn naive(xs: &[f32], w: &[f32], dilation: usize) -> Vec<f32> {
        let taps = w.len();
        (0..xs.len())
            .map(|t| {
                w.iter()
                    .enumerate()
                    .map(|(tap, &coef)| {
                        let back = (taps - 1 - tap) * dilation;
                        let x = t.checked_sub(back).map_or(0.0, |i| xs[i]);
                        coef * x
                    })
                    .sum()
            })
            .collect()
    }

    #[test]
    fn ring_buffer_matches_naive_over_long_signal() {
        let w = vec![0.3, -1.1, 2.0];
        let mut conv = Conv1d::new(1, 1, 3, 4, w.clone(), None);
        let xs: Vec<f32> = (0..32).map(|i| (i as f32 * 0.37).sin()).collect();

        let mut got: Vec<f32> = Vec::with_capacity(xs.len());
        let mut y = [0.0_f32];
        for &x in &xs {
            conv.process_sample(&[x], &mut y);
            got.push(y[0]);
        }

        let want = naive(&xs, &w, 4);
        for (g, e) in got.iter().zip(&want) {
            assert!((g - e).abs() < 1e-6, "got {g}, want {e}");
        }
    }

    #[test]
    fn process_block_equals_process_sample_loop() {
        // (in_ch, out_ch, kernel, dilation)
        const SHAPES: [(usize, usize, usize, usize); 5] = [
            (1, 1, 1, 1),
            (2, 3, 1, 1),
            (3, 2, 2, 1),
            (2, 2, 3, 4),
            (4, 5, 2, 7),
        ];
        // Number of frames in the test signal; the chunk lengths below add up to it.
        const TOTAL: usize = 200;
        const CHUNKS: [usize; 4] = [50, 1, 99, 50];

        for (in_ch, out_ch, kernel, dilation) in SHAPES {
            // Deterministic weights, bias and input frames.
            let weights: Vec<f32> = (0..out_ch * in_ch * kernel)
                .map(|i| ((i * 37 % 23) as f32 - 11.0) * 0.1)
                .collect();
            let bias: Vec<f32> = (0..out_ch).map(|o| (o as f32 + 1.0) * 0.05).collect();
            let frames: Vec<Vec<f32>> = (0..TOTAL)
                .map(|t| {
                    (0..in_ch)
                        .map(|j| ((t * in_ch + j) as f32 * 0.31).sin())
                        .collect()
                })
                .collect();

            // Reference output: one frame at a time.
            let mut reference = Conv1d::new(
                in_ch,
                out_ch,
                kernel,
                dilation,
                weights.clone(),
                Some(bias.clone()),
            );
            let mut scratch = vec![0.0_f32; out_ch];
            let mut want_all: Vec<Vec<f32>> = Vec::with_capacity(TOTAL);
            for frame in &frames {
                reference.process_sample(frame, &mut scratch);
                want_all.push(scratch.clone());
            }

            // Same signal through the block path, in uneven chunks.
            let mut blocked = Conv1d::new(in_ch, out_ch, kernel, dilation, weights, Some(bias));
            let mut t0 = 0;
            for len in CHUNKS {
                // Transpose the frames of this chunk into channel-major order.
                let mut bin = vec![0.0_f32; in_ch * len];
                for (lt, frame) in frames[t0..t0 + len].iter().enumerate() {
                    for (j, &v) in frame.iter().enumerate() {
                        bin[j * len + lt] = v;
                    }
                }
                let mut bout = vec![0.0_f32; out_ch * len];
                blocked.process_block(&bin, &mut bout, len);

                for (lt, want) in want_all[t0..t0 + len].iter().enumerate() {
                    for (o, &exp) in want.iter().enumerate() {
                        let got = bout[o * len + lt];
                        assert!(
                            (got - exp).abs() < 1e-5,
                            "shape {in_ch}x{out_ch} k{kernel} d{dilation} t{} o{o}: got {got}, want {exp}",
                            t0 + lt
                        );
                    }
                }
                t0 += len;
            }
        }
    }

    #[test]
    fn reset_clears_history() {
        let mut conv = Conv1d::new(1, 1, 2, 1, vec![0.5, 2.0], None);
        // Put non-zero frames into the history, then wipe it.
        let _ = run(&mut conv, &[&[1.0], &[2.0]]);
        conv.reset();
        // After the reset the older tap must read zero again.
        let got = run(&mut conv, &[&[1.0]]);
        assert_eq!(got, vec![vec![2.0]]);
    }
}