use alloc::vec;
use alloc::vec::Vec;
use crate::math;
use crate::rng::standard_normal;
/// Per-channel 1-D convolution whose taps read past inputs at configurable delays.
///
/// Each of the `d_in` channels has its own `kernel_size` taps and its own ring
/// buffer of past samples. Output channel `d` reads only input channel `d`'s history.
#[derive(Debug, Clone)]
pub struct DelayConv1D {
    /// Tap weights, row-major `[d_in][kernel_size]`.
    weights: Vec<f64>,
    /// Tap delays in time steps, row-major `[d_in][kernel_size]`; rounded to an
    /// integer number of steps when read.
    delays: Vec<f64>,
    /// Ring buffer of past samples, row-major `[d_in][buffer_len]`.
    buffer: Vec<f64>,
    /// Slot (within each channel's ring) that the next sample is written to.
    buf_pos: usize,
    /// Number of independent channels.
    d_in: usize,
    /// Number of taps per channel.
    kernel_size: usize,
    /// Number of samples retained per channel in `buffer`.
    buffer_len: usize,
}
impl DelayConv1D {
    /// Builds a layer with `d_in` channels and `kernel_size` taps per channel.
    ///
    /// Weights are drawn via `standard_normal` from an RNG state seeded with
    /// `seed`, then scaled by `1 / sqrt(kernel_size)`. Tap `k` of every channel
    /// starts with a delay of `k` steps. Each channel keeps a ring buffer of
    /// `2 * max(kernel_size, 1)` past samples, initially all zero.
    pub fn new(d_in: usize, kernel_size: usize, seed: u64) -> Self {
        let mut rng = seed;
        let buffer_len = 2 * kernel_size.max(1);
        let scale = 1.0 / math::sqrt(kernel_size as f64);
        let weights: Vec<f64> = (0..d_in * kernel_size)
            .map(|_| standard_normal(&mut rng) * scale)
            .collect();
        let delays: Vec<f64> = (0..d_in)
            .flat_map(|_| (0..kernel_size).map(|k| k as f64))
            .collect();
        Self {
            weights,
            delays,
            buffer: vec![0.0; d_in * buffer_len],
            buf_pos: 0,
            d_in,
            kernel_size,
            buffer_len,
        }
    }

    /// Records one time step and returns the per-channel convolution output.
    ///
    /// `input[d]` becomes the newest sample of channel `d`. Elements past
    /// `d_in` are ignored. Channels missing from a short `input` are recorded
    /// as `0.0`. The returned vector has `d_in` elements. Afterwards the write
    /// position advances by one slot, wrapping around the ring buffer.
    pub fn forward(&mut self, input: &[f64]) -> Vec<f64> {
        let pos = self.buf_pos;
        // Write every channel, not only those `input` covers. A skipped slot
        // would still hold the sample from `buffer_len` steps ago, and any tap
        // whose delay rounds to zero would read it back as the current value.
        for (d, history) in self.buffer.chunks_exact_mut(self.buffer_len).enumerate() {
            history[pos] = Self::sample(input, d);
        }
        let output = self.compute_output();
        self.buf_pos = (pos + 1) % self.buffer_len;
        output
    }

    /// Returns what [`forward`](Self::forward) would output for `input`,
    /// without modifying the ring buffer or the write position.
    ///
    /// This does not copy the ring buffer. Taps that land on the current write
    /// slot read the sample `forward(input)` would store there.
    pub fn forward_predict(&self, input: &[f64]) -> Vec<f64> {
        self.compute_output_with(Some(input))
    }

    /// Clears the ring buffer and rewinds the write position to slot 0.
    /// Weights and delays are left unchanged.
    pub fn reset(&mut self) {
        self.buffer.fill(0.0);
        self.buf_pos = 0;
    }

    /// Number of channels.
    #[inline]
    pub fn d_in(&self) -> usize {
        self.d_in
    }

    /// Number of taps per channel.
    #[inline]
    pub fn kernel_size(&self) -> usize {
        self.kernel_size
    }

    /// Convolves every channel against the stored ring buffer.
    fn compute_output(&self) -> Vec<f64> {
        self.compute_output_with(None)
    }

    /// Computes `out[d] = sum_k weights[d,k] * x[d, slot(delays[d,k])]` for every
    /// channel `d`, accumulating taps in order `k = 0..kernel_size`.
    ///
    /// With `pending = Some(input)`, reads of the current write slot return the
    /// sample `forward(input)` would store there instead of the buffered value.
    fn compute_output_with(&self, pending: Option<&[f64]>) -> Vec<f64> {
        let k_len = self.kernel_size;
        self.buffer
            .chunks_exact(self.buffer_len)
            .enumerate()
            .map(|(d, history)| {
                let taps = d * k_len..(d + 1) * k_len;
                self.weights[taps.clone()]
                    .iter()
                    .zip(&self.delays[taps])
                    // Explicit left fold from 0.0 keeps the original accumulation order.
                    .fold(0.0, |acc, (&w, &delay)| {
                        let idx = self.delayed_slot(delay);
                        let x = match pending {
                            Some(input) if idx == self.buf_pos => Self::sample(input, d),
                            _ => history[idx],
                        };
                        acc + w * x
                    })
            })
            .collect()
    }

    /// Ring-buffer slot holding the sample `delay` steps before the current write
    /// slot. The delay is rounded with `math::round` and wrapped modulo `buffer_len`.
    fn delayed_slot(&self, delay: f64) -> usize {
        let steps = math::round(delay) as isize;
        (self.buf_pos as isize - steps).rem_euclid(self.buffer_len as isize) as usize
    }

    /// Sample that `forward(input)` records for channel `d`: `input[d]`, or `0.0`
    /// when `input` is shorter than `d + 1`.
    fn sample(input: &[f64], d: usize) -> f64 {
        input.get(d).copied().unwrap_or(0.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor stores the dimensions and allocates one weight and one delay per tap.
    #[test]
    fn delay_conv_new() {
        let conv = DelayConv1D::new(3, 4, 42);
        assert_eq!(conv.d_in(), 3, "d_in should match constructor arg");
        assert_eq!(
            conv.kernel_size(),
            4,
            "kernel_size should match constructor arg"
        );
        assert_eq!(
            conv.weights.len(),
            3 * 4,
            "weights should have d_in * kernel_size elements"
        );
        assert_eq!(
            conv.delays.len(),
            3 * 4,
            "delays should have d_in * kernel_size elements"
        );
    }

    /// Tap `k` of every channel starts with a delay of exactly `k` steps.
    #[test]
    fn delay_conv_delays_initialized_correctly() {
        let conv = DelayConv1D::new(2, 4, 42);
        for d in 0..2 {
            for k in 0..4 {
                let expected = k as f64;
                let actual = conv.delays[d * 4 + k];
                assert!(
                    (actual - expected).abs() < 1e-12,
                    "delay[{d},{k}] should be {expected}, got {actual}"
                );
            }
        }
    }

    /// `forward` returns one value per channel.
    #[test]
    fn delay_conv_forward_output_length() {
        let mut conv = DelayConv1D::new(5, 3, 42);
        let input = [1.0, 2.0, 3.0, 4.0, 5.0];
        let output = conv.forward(&input);
        assert_eq!(output.len(), 5, "output should have d_in elements");
    }

    /// Repeated finite inputs keep producing finite outputs.
    #[test]
    fn delay_conv_forward_finite() {
        let mut conv = DelayConv1D::new(3, 4, 123);
        let input = [1.0, -0.5, 2.0];
        for _ in 0..10 {
            let output = conv.forward(&input);
            for (i, &val) in output.iter().enumerate() {
                assert!(val.is_finite(), "output[{}] = {} should be finite", i, val);
            }
        }
    }

    /// `forward_predict` must leave the ring buffer and write position untouched.
    #[test]
    fn delay_conv_forward_predict_no_state_change() {
        let mut conv = DelayConv1D::new(3, 4, 42);
        let input = [1.0, 2.0, 3.0];
        conv.forward(&input);
        let buf_before = conv.buffer.clone();
        let pos_before = conv.buf_pos;
        let _pred = conv.forward_predict(&[0.5, -0.5, 1.5]);
        assert_eq!(
            conv.buffer, buf_before,
            "buffer should not change after forward_predict"
        );
        assert_eq!(
            conv.buf_pos, pos_before,
            "buf_pos should not change after forward_predict"
        );
    }

    /// `reset` clears history but keeps the learned parameters.
    #[test]
    fn delay_conv_reset() {
        let mut conv = DelayConv1D::new(3, 4, 42);
        for i in 0..10 {
            conv.forward(&[i as f64, (i as f64) * 0.5, -(i as f64)]);
        }
        let weights_before = conv.weights.clone();
        let delays_before = conv.delays.clone();
        conv.reset();
        assert!(
            conv.buffer.iter().all(|&v| v == 0.0),
            "buffer should be all zeros after reset"
        );
        assert_eq!(conv.buf_pos, 0, "buf_pos should be 0 after reset");
        assert_eq!(
            conv.weights, weights_before,
            "weights should be preserved after reset"
        );
        assert_eq!(
            conv.delays, delays_before,
            "delays should be preserved after reset"
        );
    }

    /// Driving past `buffer_len` steps must not break the ring-buffer indexing.
    #[test]
    fn delay_conv_circular_buffer_wraps() {
        let mut conv = DelayConv1D::new(1, 2, 42);
        for i in 0..10 {
            conv.forward(&[i as f64]);
        }
        let output = conv.forward(&[10.0]);
        assert!(
            output[0].is_finite(),
            "output should be finite after buffer wraps"
        );
    }

    /// One-step prediction agrees with the subsequent real step.
    #[test]
    fn delay_conv_forward_predict_matches_forward() {
        let mut conv = DelayConv1D::new(3, 4, 42);
        let x1 = [1.0, 2.0, 3.0];
        conv.forward(&x1);
        let x2 = [0.5, -0.5, 1.5];
        let pred = conv.forward_predict(&x2);
        let actual = conv.forward(&x2);
        for (i, (p, a)) in pred.iter().zip(actual.iter()).enumerate() {
            assert!(
                (p - a).abs() < 1e-12,
                "forward_predict[{i}]={p} should match forward[{i}]={a}"
            );
        }
    }

    /// Every output equals the hand-computed sum `sum_k w[d,k] * x[d, t-k]`.
    /// Samples before step 0 are treated as zero. The run spans several
    /// wraps of the ring buffer (`buffer_len = 2 * K = 6`).
    #[test]
    fn delay_conv_matches_reference_convolution() {
        const STEPS: usize = 15;
        const D_IN: usize = 2;
        const K: usize = 3;
        let mut conv = DelayConv1D::new(D_IN, K, 7);
        let mut history = [[0.0_f64; D_IN]; STEPS];
        for t in 0..STEPS {
            history[t] = [t as f64 * 0.25 - 1.0, (t * t) as f64 * 0.1];
            let output = conv.forward(&history[t]);
            assert_eq!(output.len(), D_IN);
            for d in 0..D_IN {
                let expected: f64 = (0..=t.min(K - 1))
                    .map(|k| conv.weights[d * K + k] * history[t - k][d])
                    .sum();
                assert!(
                    (output[d] - expected).abs() < 1e-9,
                    "step {t}, channel {d}: got {}, expected {expected}",
                    output[d]
                );
            }
        }
    }

    /// Prediction agrees with the real step at every step, including after wraps.
    #[test]
    fn delay_conv_forward_predict_matches_forward_across_wrap() {
        let mut conv = DelayConv1D::new(2, 3, 99);
        for t in 0..20 {
            let x = [t as f64, -(t as f64) * 0.5];
            let pred = conv.forward_predict(&x);
            let actual = conv.forward(&x);
            assert_eq!(pred.len(), actual.len(), "step {t}: length mismatch");
            for (i, (p, a)) in pred.iter().zip(actual.iter()).enumerate() {
                assert!(
                    (p - a).abs() < 1e-12,
                    "step {t}: forward_predict[{i}]={p} should match forward[{i}]={a}"
                );
            }
        }
    }
}