use sonora_simd::SimdBackend;
/// Thin dispatcher for vectorized math kernels.
///
/// Stores the SIMD backend supplied at construction and routes operations
/// (currently the in-place square root) through it.
#[derive(Debug)]
pub(crate) struct VectorMath {
    /// Backend that performs the actual element-wise work.
    backend: SimdBackend,
}
impl VectorMath {
pub(crate) fn new(backend: SimdBackend) -> Self {
Self { backend }
}
pub(crate) fn sqrt(&self, x: &mut [f32]) {
self.backend.elementwise_sqrt(x);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::FFT_LENGTH_BY_2_PLUS_1;

    /// Checks the backend's in-place square root against scalar `f32::sqrt`
    /// on a non-negative ramp filling a buffer of `FFT_LENGTH_BY_2_PLUS_1`
    /// elements.
    #[test]
    fn sqrt_matches_scalar() {
        let vm = VectorMath::new(sonora_simd::detect_backend());

        let mut x = [0.0f32; FFT_LENGTH_BY_2_PLUS_1];
        for (k, v) in x.iter_mut().enumerate() {
            *v = (2.0 / 3.0) * k as f32;
        }

        // `[f32; N]` is `Copy`, so `z` is an independent buffer and `x` keeps
        // the untouched inputs for the scalar reference.
        let mut z = x;
        vm.sqrt(&mut z);

        // Iterate the two buffers in lockstep instead of indexing both by a
        // range counter (clippy::needless_range_loop); compute the reference
        // value once per element.
        for (k, (&got, &input)) in z.iter().zip(x.iter()).enumerate() {
            let expected = input.sqrt();
            assert!(
                (got - expected).abs() < 0.0001,
                "mismatch at {k}: got {}, expected {}",
                got,
                expected
            );
        }
    }

    /// Exercises short lengths, including the empty slice and lengths around
    /// common SIMD lane counts (4, 8, 16), so any remainder/tail handling in
    /// the backend is compared against scalar `f32::sqrt`.
    #[test]
    fn sqrt_matches_scalar_on_short_slices() {
        let vm = VectorMath::new(sonora_simd::detect_backend());
        for len in 0..=17usize {
            let input: Vec<f32> = (0..len).map(|k| 1.5 * k as f32 + 0.25).collect();
            let mut out = input.clone();
            vm.sqrt(&mut out);
            for (k, (&got, &x)) in out.iter().zip(&input).enumerate() {
                let expected = x.sqrt();
                assert!(
                    (got - expected).abs() < 0.0001,
                    "len {len}, mismatch at {k}: got {got}, expected {expected}"
                );
            }
        }
    }
}