use numra_core::Scalar;
/// A digital filter stored as a cascade of second-order sections (biquads).
///
/// Each entry of `sections` holds one section's coefficients in the order
/// `[b0, b1, b2, a0, a1, a2]`: three numerator (`b`) coefficients followed by
/// three denominator (`a`) coefficients. [`sosfilt`] divides each section by
/// its `a0` before filtering, so `a0` need not be 1.
#[derive(Clone, Debug)]
pub struct SosFilter<S: Scalar> {
    /// Per-section coefficients, laid out as `[b0, b1, b2, a0, a1, a2]`.
    pub sections: Vec<[S; 6]>,
}

impl<S: Scalar> SosFilter<S> {
    /// Builds a filter from the given sections, taking ownership of the vector.
    ///
    /// No validation is performed on the coefficients.
    pub fn new(sections: Vec<[S; 6]>) -> Self {
        SosFilter { sections }
    }

    /// Returns the number of second-order sections in the cascade.
    pub fn n_sections(&self) -> usize {
        self.sections.len()
    }

    /// Returns the nominal filter order, counting two per section.
    ///
    /// The coefficients are not inspected. A section whose second-order
    /// coefficients are zero still counts as order two, so this is an upper
    /// bound on the true order.
    pub fn order(&self) -> usize {
        2 * self.n_sections()
    }
}
/// Runs `x` through the cascade of second-order sections in `sos`.
///
/// Each section `[b0, b1, b2, a0, a1, a2]` is first normalised by `a0`. It is
/// then applied over the whole signal in transposed direct form II:
///
/// ```text
/// y[k] = b0*x[k] + z1
/// z1   = b1*x[k] - a1*y[k] + z2
/// z2   = b2*x[k] - a2*y[k]
/// ```
///
/// The output of one section feeds the next. Each section's state (`z1`,
/// `z2`) starts at zero, so this is the causal response with no initial
/// conditions.
///
/// Returns an unmodified copy of `x` when `x` is empty or `sos` has no
/// sections.
pub fn sosfilt<S: Scalar>(sos: &SosFilter<S>, x: &[S]) -> Vec<S> {
    let mut signal = x.to_vec();
    if signal.is_empty() || sos.sections.is_empty() {
        return signal;
    }

    for &[num0, num1, num2, den0, den1, den2] in &sos.sections {
        // Scale the section so its leading denominator coefficient becomes 1.
        let (num0, num1, num2) = (num0 / den0, num1 / den0, num2 / den0);
        let (den1, den2) = (den1 / den0, den2 / den0);

        let (mut z1, mut z2) = (S::ZERO, S::ZERO);
        for sample in signal.iter_mut() {
            let input = *sample;
            let output = num0 * input + z1;
            // z1 must be updated from the *previous* z2, so update z2 last.
            z1 = num1 * input - den1 * output + z2;
            z2 = num2 * input - den2 * output;
            *sample = output;
        }
    }

    signal
}
/// Zero-phase filtering: applies `sos` to `x` forward and then backward.
///
/// The signal is first extended at both ends by odd reflection about its
/// endpoints: `2*x[0] - x[i]` on the left and `2*x[n-1] - x[i]` on the right.
/// The padded signal is filtered with [`sosfilt`], reversed, filtered again,
/// and reversed back. Finally the padding is trimmed, so the output has the
/// same length as `x`. Because the second pass runs in reverse, the phase
/// shifts of the two passes cancel.
///
/// The pad length is `6 * sos.n_sections()`, clamped to `x.len() - 1` so the
/// reflection never reads outside `x`. Short signals therefore get a shorter
/// pad but are still filtered in both directions.
///
/// Both passes start from a zero filter state (as [`sosfilt`] does). No
/// steady-state initial conditions are applied, so samples near the edges can
/// show start-up transients. NOTE(review): SciPy's `sosfiltfilt` seeds each
/// pass with `sosfilt_zi`; consider adding that if edge accuracy matters.
///
/// Returns an empty vector for empty input.
pub fn filtfilt<S: Scalar>(sos: &SosFilter<S>, x: &[S]) -> Vec<S> {
    let n = x.len();
    // Only the empty case needs special handling (it would underflow `n - 1`).
    // Every non-empty length works with the clamped pad below, so short inputs
    // still get the forward-backward treatment rather than a single causal pass.
    if n == 0 {
        return Vec::new();
    }

    let padlen = (6 * sos.n_sections()).min(n - 1);
    let two = S::from_f64(2.0);
    let (first, last) = (x[0], x[n - 1]);

    let mut padded = Vec::with_capacity(n + 2 * padlen);
    // Left extension: 2*x[0] - x[padlen], ..., 2*x[0] - x[1].
    padded.extend(x[1..padlen + 1].iter().rev().map(|&xi| two * first - xi));
    padded.extend_from_slice(x);
    // Right extension: 2*x[n-1] - x[n-2], ..., 2*x[n-1] - x[n-1-padlen].
    padded.extend(
        x[n - 1 - padlen..n - 1]
            .iter()
            .rev()
            .map(|&xi| two * last - xi),
    );

    let mut forward = sosfilt(sos, &padded);
    forward.reverse();
    let mut y = sosfilt(sos, &forward);
    y.reverse();

    // Drop the padding in place: keep y[padlen..padlen + n].
    y.truncate(padlen + n);
    y.drain(..padlen);
    y
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` has the same length as `expected` and matches it
    /// element-wise to within `tol`.
    ///
    /// A bare `zip` comparison would pass vacuously if `actual` were shorter
    /// (or empty), so the length is checked first.
    fn assert_all_close(expected: &[f64], actual: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (e, a)) in expected.iter().zip(actual).enumerate() {
            assert!((e - a).abs() < tol, "index {i}: expected {e}, got {a}");
        }
    }

    #[test]
    fn test_sosfilt_identity() {
        let sos = SosFilter::new(vec![[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]]);
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = sosfilt(&sos, &x);
        assert_all_close(&x, &y, 1e-12);
    }

    #[test]
    fn test_sosfilt_first_order() {
        // y[k] = 0.5*x[k] + 0.5*y[k-1]; the impulse response is 0.5^(k+1).
        let sos = SosFilter::new(vec![[0.5, 0.0, 0.0, 1.0, -0.5, 0.0]]);
        let x = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let y = sosfilt(&sos, &x);
        let expected: Vec<f64> = (1..=8).map(|k| 0.5f64.powi(k)).collect();
        assert_all_close(&expected, &y, 1e-12);
    }

    #[test]
    fn test_sosfilt_empty() {
        let sos = SosFilter::<f64>::new(vec![[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]]);
        let y = sosfilt(&sos, &[]);
        assert!(y.is_empty());
    }

    #[test]
    fn test_sosfilt_no_sections_is_identity() {
        let sos = SosFilter::<f64>::new(Vec::new());
        let x = [1.0, -2.0, 3.5];
        assert_eq!(sosfilt(&sos, &x), x.to_vec());
        assert_eq!(sos.n_sections(), 0);
        assert_eq!(sos.order(), 0);
    }

    #[test]
    fn test_filtfilt_preserves_length() {
        let sos = SosFilter::new(vec![[0.5, 0.5, 0.0, 1.0, -0.5, 0.0]]);
        let x: Vec<f64> = (0..100).map(|i| (i as f64 * 0.1).sin()).collect();
        let y = filtfilt(&sos, &x);
        assert_eq!(y.len(), x.len());
    }

    #[test]
    fn test_filtfilt_empty() {
        let sos = SosFilter::<f64>::new(vec![[0.5, 0.5, 0.0, 1.0, -0.5, 0.0]]);
        assert!(filtfilt(&sos, &[]).is_empty());
    }

    #[test]
    fn test_filtfilt_identity_various_lengths() {
        let sos = SosFilter::new(vec![[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]]);
        for len in [1, 2, 3, 4, 10] {
            let x: Vec<f64> = (0..len).map(|i| i as f64 * 1.5 - 2.0).collect();
            assert_all_close(&x, &filtfilt(&sos, &x), 1e-12);
        }
    }

    #[test]
    fn test_filtfilt_dc_passthrough() {
        use crate::filter_design::butter;
        let sos = butter(4, 10.0, 100.0).unwrap();
        let x = [3.0; 100];
        let y = filtfilt(&sos, &x);
        assert_eq!(y.len(), x.len());
        for &yi in &y[20..80] {
            assert!((yi - 3.0).abs() < 0.01, "expected ~3.0, got {yi}");
        }
    }

    #[test]
    fn test_sosfilt_cascade() {
        let sos = SosFilter::new(vec![
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
        ]);
        assert_eq!(sos.n_sections(), 2);
        assert_eq!(sos.order(), 4);
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = sosfilt(&sos, &x);
        assert_all_close(&x, &y, 1e-12);
    }
}