/// Stores the results of pairwise comparisons between the items of a ring
/// buffer of size `S`, i.e. a symmetric `S x S` matrix without its diagonal.
///
/// Only the upper-right triangle is kept, encoded in a row-major
/// `(S - 1) x (S - 1)` square: cell `(row, col - 1)` holds the comparison
/// between the items at positions `row < col`, where an item with delay `d`
/// sits at position `S - 1 - d`. With this layout, aging every stored
/// comparison by one step is a single contiguous move (see
/// [`SymmetricMatrixBuffer::push`]).
#[derive(Debug, Clone)]
pub(crate) struct SymmetricMatrixBuffer<const S: usize> {
    /// Row-major `(S - 1) x (S - 1)` storage; cells strictly below the
    /// diagonal of this square are never read by `get_value`.
    buf: Vec<f32>,
}

impl<const S: usize> Default for SymmetricMatrixBuffer<S> {
    /// Creates a buffer in which every stored comparison is `0.0`.
    ///
    /// Instantiating this for `S <= 2` is a compile-time error.
    fn default() -> Self {
        const { assert!(S > 2) };
        Self {
            buf: vec![0.0; (S - 1) * (S - 1)],
        }
    }
}

impl<const S: usize> SymmetricMatrixBuffer<S> {
    /// Sets every stored comparison back to `0.0`.
    #[cfg(test)]
    pub(crate) fn reset(&mut self) {
        self.buf.fill(0.0);
    }

    /// Pushes the comparisons between the newest ring-buffer item and the
    /// `S - 1` older ones, aging every previously stored comparison by one step.
    ///
    /// `values[0]` must compare the newest item with the one at delay 1, and
    /// `values[S - 2]` with the one at delay `S - 1` (the oldest).
    ///
    /// # Panics
    ///
    /// Panics if `values.len() != S - 1`; a shorter slice would otherwise leave
    /// stale data in the matrix that persists across later pushes.
    pub(crate) fn push(&mut self, values: &[f32]) {
        assert_eq!(
            values.len(),
            S - 1,
            "push() expects exactly S - 1 comparison values."
        );
        // Move the lower-right (S - 2) x (S - 2) sub-square one row up and one
        // column left; this ages every comparison that is still relevant.
        self.buf.copy_within(S.., 0);
        // The last column of the square (indices S - 2, 2S - 3, ...) receives
        // the newest item's comparisons: its bottom cell pairs with delay 1,
        // its top cell with delay S - 1, hence the reversal.
        let last_column = self.buf[S - 2..].iter_mut().step_by(S - 1).rev();
        for (cell, &value) in last_column.zip(values) {
            *cell = value;
        }
    }

    /// Returns the comparison between the ring-buffer items at delays `delay1`
    /// and `delay2`; the argument order does not matter.
    ///
    /// # Panics
    ///
    /// Panics if `delay1 == delay2` (the diagonal is not stored) or if either
    /// delay is not in `0..S`.
    pub(crate) fn get_value(&self, delay1: usize, delay2: usize) -> f32 {
        assert_ne!(delay1, delay2, "The diagonal cannot be accessed.");
        // Without this check the unsigned subtractions below wrap in release
        // builds and can yield an in-bounds index to an unrelated cell.
        assert!(
            delay1 < S && delay2 < S,
            "Delays must be in 0..{}, got ({}, {}).",
            S,
            delay1,
            delay2
        );
        // The larger delay is the older item, i.e. the smaller position, so it
        // selects the row; this always lands in the upper-right triangle.
        let row = S - 1 - delay1.max(delay2);
        let col = S - 1 - delay1.min(delay2);
        self.buf[row * (S - 1) + (col - 1)]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rnn_vad::ring_buffer::RingBuffer;

    /// Asserts that swapping the two delays passed to `get_value` never
    /// changes the result, for every pair of distinct delays in `0..S`.
    fn check_symmetry<const S: usize>(sym_matrix_buf: &SymmetricMatrixBuffer<S>) {
        // `row` and `col` are delays; each unordered pair (row < col) is
        // visited exactly once.
        let pairs = (0..S - 1).flat_map(|row| (row + 1..S).map(move |col| (row, col)));
        for (row, col) in pairs {
            let forward = sym_matrix_buf.get_value(row, col);
            let backward = sym_matrix_buf.get_value(col, row);
            assert_eq!(forward, backward, "Asymmetry at ({row}, {col})");
        }
    }

    /// Feeds the matrix the way it is meant to be used: each new ring-buffer
    /// item is compared with every older stored item (encoded as
    /// `older * 1000 + newer`), and every delay pair must then read back the
    /// matching encoding.
    #[test]
    fn symmetric_matrix_buffer_use_case() {
        const RING_BUF_SIZE: usize = 10;
        let mut ring_buf = RingBuffer::<1, RING_BUF_SIZE>::default();
        let mut sym_matrix_buf = SymmetricMatrixBuffer::<RING_BUF_SIZE>::default();
        for t in 1..=100u32 {
            let t_f = t as f32;
            ring_buf.push(&[t_f]);
            assert_eq!(ring_buf.get_array_view(0), &[t_f]);

            // Entry `i` compares the newest item with the one at delay `i + 1`.
            let new_comparisons: [f32; RING_BUF_SIZE - 1] =
                std::array::from_fn(|i| ring_buf.get_array_view(i + 1)[0] * 1000.0 + t_f);
            sym_matrix_buf.push(&new_comparisons);
            check_symmetry(&sym_matrix_buf);

            let pairs = (0..RING_BUF_SIZE - 1).flat_map(|delay1| {
                (delay1 + 1..RING_BUF_SIZE).map(move |delay2| (delay1, delay2))
            });
            for (delay1, delay2) in pairs {
                let newer = ring_buf.get_array_view(delay1)[0];
                let older = ring_buf.get_array_view(delay2)[0];
                assert!(older <= newer);
                let expected = older * 1000.0 + newer;
                assert_eq!(
                    sym_matrix_buf.get_value(delay1, delay2),
                    expected,
                    "Mismatch at t={t}, delay1={delay1}, delay2={delay2}"
                );
            }
        }
    }
}