use std::collections::VecDeque;
/// Second-order IIR ("biquad") filter evaluated in transposed direct form II.
///
/// Coefficients are stored already divided by `a0`, so each call to
/// [`BiquadFilter::process`] realises
/// `y[n] = b0·x[n] + b1·x[n-1] + b2·x[n-2] - a1·y[n-1] - a2·y[n-2]`.
#[derive(Debug, Clone)]
pub struct BiquadFilter {
    b0: f64,
    b1: f64,
    b2: f64,
    a1: f64,
    a2: f64,
    /// First delay register of the transposed direct form II structure.
    z1: f64,
    /// Second delay register.
    z2: f64,
}

impl BiquadFilter {
    /// Creates a filter from coefficients already normalised so that `a0 == 1`.
    ///
    /// The delay registers start at zero.
    pub fn new(b0: f64, b1: f64, b2: f64, a1: f64, a2: f64) -> Self {
        BiquadFilter { b0, b1, b2, a1, a2, z1: 0.0, z2: 0.0 }
    }

    /// Builds a filter from un-normalised coefficients by dividing each one by `a0`.
    fn from_unnormalised(b0: f64, b1: f64, b2: f64, a0: f64, a1: f64, a2: f64) -> Self {
        BiquadFilter::new(b0 / a0, b1 / a0, b2 / a0, a1 / a0, a2 / a0)
    }

    /// Returns `(cos(w0), alpha)` with `w0 = 2π·f0/fs` and `alpha = sin(w0) / (2q)`,
    /// the shared intermediate terms of the notch and bandpass designs below.
    fn cookbook_terms(fs: f64, f0: f64, q: f64) -> (f64, f64) {
        let w0 = 2.0 * std::f64::consts::PI * f0 / fs;
        (w0.cos(), w0.sin() / (2.0 * q))
    }

    /// Notch (band-reject) filter centred on `f0` with quality factor `q`.
    ///
    /// `fs` and `f0` must be in the same unit (e.g. Hz).
    ///
    /// # Panics
    ///
    /// Panics if `fs` or `q` is not strictly positive (including NaN); a
    /// non-positive `q` would place the poles on or outside the unit circle.
    pub fn notch(fs: f64, f0: f64, q: f64) -> Self {
        assert!(fs > 0.0, "sample rate must be positive, got {fs}");
        assert!(q > 0.0, "notch Q must be positive, got {q}");
        let (cos_w0, alpha) = Self::cookbook_terms(fs, f0, q);
        BiquadFilter::from_unnormalised(
            1.0,
            -2.0 * cos_w0,
            1.0,
            1.0 + alpha,
            -2.0 * cos_w0,
            1.0 - alpha,
        )
    }

    /// Bandpass filter (0 dB peak gain) spanning `f_low..f_hi`.
    ///
    /// The centre frequency is the geometric mean of the edges and
    /// `Q = f0 / (f_hi - f_low)`. `fs` and the edges must share a unit.
    ///
    /// # Panics
    ///
    /// Panics unless `fs > 0` and `0 < f_low < f_hi` (NaN fails these checks).
    /// Inverted edges would yield a negative `Q` and an unstable filter; equal
    /// edges an infinite `Q`; a zero/negative lower edge NaN coefficients.
    pub fn bandpass(fs: f64, f_low: f64, f_hi: f64) -> Self {
        assert!(fs > 0.0, "sample rate must be positive, got {fs}");
        assert!(
            f_low > 0.0 && f_hi > f_low,
            "bandpass edges must satisfy 0 < f_low < f_hi, got f_low={f_low}, f_hi={f_hi}"
        );
        let f0 = (f_low * f_hi).sqrt();
        let q = f0 / (f_hi - f_low);
        let (cos_w0, alpha) = Self::cookbook_terms(fs, f0, q);
        BiquadFilter::from_unnormalised(
            alpha,
            0.0,
            -alpha,
            1.0 + alpha,
            -2.0 * cos_w0,
            1.0 - alpha,
        )
    }

    /// Filters one sample and returns the corresponding output sample.
    #[inline]
    pub fn process(&mut self, x: f64) -> f64 {
        let y = self.b0 * x + self.z1;
        self.z1 = self.b1 * x - self.a1 * y + self.z2;
        self.z2 = self.b2 * x - self.a2 * y;
        y
    }

    /// Clears the delay registers; the coefficients are kept.
    pub fn reset(&mut self) {
        self.z1 = 0.0;
        self.z2 = 0.0;
    }
}
/// Sliding window holding at most `capacity` of the most recently pushed samples.
///
/// Once the window is full, each new sample evicts the oldest one.
#[derive(Debug, Clone)]
pub struct EpochBuffer {
    data: VecDeque<f64>,
    capacity: usize,
}

impl EpochBuffer {
    /// Creates an empty window that retains at most `capacity` samples.
    pub fn new(capacity: usize) -> Self {
        // `push` evicts before inserting, so `capacity` slots always suffice.
        EpochBuffer { data: VecDeque::with_capacity(capacity), capacity }
    }

    /// Appends `sample`, evicting the oldest sample first if the window is full.
    ///
    /// A zero-capacity window retains nothing, so `len()` never exceeds `capacity()`.
    pub fn push(&mut self, sample: f64) {
        if self.capacity == 0 {
            return;
        }
        if self.data.len() >= self.capacity {
            self.data.pop_front();
        }
        self.data.push_back(sample);
    }

    /// Number of samples currently held.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if no samples are held.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Maximum number of samples the window retains.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns `true` once the window holds `capacity` samples (always for capacity 0).
    pub fn is_full(&self) -> bool {
        self.data.len() >= self.capacity
    }

    /// Mean of the squared samples, or `0.0` when the window is empty.
    pub fn mean_power(&self) -> f64 {
        if self.data.is_empty() {
            return 0.0;
        }
        self.data.iter().map(|&x| x * x).sum::<f64>() / self.data.len() as f64
    }

    /// Root-mean-square of the held samples (square root of [`EpochBuffer::mean_power`]).
    pub fn rms(&self) -> f64 {
        self.mean_power().sqrt()
    }

    /// Removes all samples; the capacity is unchanged.
    pub fn clear(&mut self) {
        self.data.clear();
    }
}
/// Power value for each of the five bands: delta, theta, alpha, beta, gamma.
#[derive(Debug, Clone, Copy, Default)]
pub struct BandPowers {
    pub delta: f64,
    pub theta: f64,
    pub alpha: f64,
    pub beta: f64,
    pub gamma: f64,
}

impl BandPowers {
    /// Returns a copy with every band divided by the sum of all five bands.
    ///
    /// If that sum is exactly zero, `self` is returned unchanged to avoid
    /// dividing by zero.
    pub fn normalised(&self) -> Self {
        let total = self.delta + self.theta + self.alpha + self.beta + self.gamma;
        if total == 0.0 {
            return *self;
        }
        BandPowers {
            delta: self.delta / total,
            theta: self.theta / total,
            alpha: self.alpha / total,
            beta: self.beta / total,
            gamma: self.gamma / total,
        }
    }

    /// The bands as `[delta, theta, alpha, beta, gamma]`.
    pub fn as_array(&self) -> [f64; 5] {
        [self.delta, self.theta, self.alpha, self.beta, self.gamma]
    }

    /// Largest of the five band values.
    ///
    /// NaN entries are skipped (per [`f64::max`]); if all five are NaN the
    /// result is `f64::NEG_INFINITY`. Folding from negative infinity rather
    /// than `0.0` keeps the result correct when every field is negative.
    pub fn max(&self) -> f64 {
        self.as_array().iter().copied().fold(f64::NEG_INFINITY, f64::max)
    }
}
/// Splits a raw sample stream into five frequency bands and tracks the mean
/// power of each band over a sliding window of `epoch` samples.
///
/// Every raw sample is first passed through a notch filter, then fed to five
/// independent bandpass filters (0.5–4, 4–8, 8–13, 13–30 and 30–50, in the
/// unit of `fs`); each bandpass output is pushed into its own [`EpochBuffer`].
pub struct BandPowerExtractor {
    /// Notch applied to every raw sample before band splitting.
    notch: BiquadFilter,
    delta_bp: BiquadFilter,
    theta_bp: BiquadFilter,
    alpha_bp: BiquadFilter,
    beta_bp: BiquadFilter,
    gamma_bp: BiquadFilter,
    delta_buf: EpochBuffer,
    theta_buf: EpochBuffer,
    alpha_buf: EpochBuffer,
    beta_buf: EpochBuffer,
    gamma_buf: EpochBuffer,
}

impl BandPowerExtractor {
    /// Quality factor of the notch filter built by [`BandPowerExtractor::new`].
    const NOTCH_Q: f64 = 30.0;

    /// Creates an extractor.
    ///
    /// * `fs` – sample rate, in the same unit as the hard-coded band edges (Hz).
    /// * `notch_hz` – centre frequency of the notch applied to raw samples.
    /// * `epoch` – number of most recent filtered samples averaged per band.
    pub fn new(fs: f64, notch_hz: f64, epoch: usize) -> Self {
        BandPowerExtractor {
            notch: BiquadFilter::notch(fs, notch_hz, Self::NOTCH_Q),
            delta_bp: BiquadFilter::bandpass(fs, 0.5, 4.0),
            theta_bp: BiquadFilter::bandpass(fs, 4.0, 8.0),
            alpha_bp: BiquadFilter::bandpass(fs, 8.0, 13.0),
            beta_bp: BiquadFilter::bandpass(fs, 13.0, 30.0),
            gamma_bp: BiquadFilter::bandpass(fs, 30.0, 50.0),
            delta_buf: EpochBuffer::new(epoch),
            theta_buf: EpochBuffer::new(epoch),
            alpha_buf: EpochBuffer::new(epoch),
            beta_buf: EpochBuffer::new(epoch),
            gamma_buf: EpochBuffer::new(epoch),
        }
    }

    /// Filters one raw sample through the notch and every band filter, records
    /// the band outputs, and returns the updated powers (as [`Self::current`]).
    pub fn push(&mut self, raw: i16) -> BandPowers {
        let x = self.notch.process(f64::from(raw));
        self.delta_buf.push(self.delta_bp.process(x));
        self.theta_buf.push(self.theta_bp.process(x));
        self.alpha_buf.push(self.alpha_bp.process(x));
        self.beta_buf.push(self.beta_bp.process(x));
        self.gamma_buf.push(self.gamma_bp.process(x));
        self.current()
    }

    /// Mean power of each band over the samples currently in its window.
    pub fn current(&self) -> BandPowers {
        BandPowers {
            delta: self.delta_buf.mean_power(),
            theta: self.theta_buf.mean_power(),
            alpha: self.alpha_buf.mean_power(),
            beta: self.beta_buf.mean_power(),
            gamma: self.gamma_buf.mean_power(),
        }
    }

    /// Clears all filter state and empties every window, keeping the
    /// configured epoch length.
    pub fn reset(&mut self) {
        for filter in [
            &mut self.notch,
            &mut self.delta_bp,
            &mut self.theta_bp,
            &mut self.alpha_bp,
            &mut self.beta_bp,
            &mut self.gamma_bp,
        ] {
            filter.reset();
        }
        // `clear` empties in place, keeping capacity and the existing allocation.
        for buf in [
            &mut self.delta_buf,
            &mut self.theta_buf,
            &mut self.alpha_buf,
            &mut self.beta_buf,
            &mut self.gamma_buf,
        ] {
            buf.clear();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Band names in the order produced by [`BandPowers::as_array`].
    const BAND_NAMES: [&str; 5] = ["delta", "theta", "alpha", "beta", "gamma"];

    /// `n` samples of a unit-amplitude sine of frequency `freq`, sampled at `fs`.
    fn sine(fs: f64, freq: f64, n: usize) -> Vec<f64> {
        (0..n).map(|i| (2.0 * PI * freq * i as f64 / fs).sin()).collect()
    }

    /// Root-mean-square value of `samples`.
    fn rms(samples: &[f64]) -> f64 {
        (samples.iter().map(|x| x * x).sum::<f64>() / samples.len() as f64).sqrt()
    }

    /// Feeds every sample of `input` through `filter` and collects the outputs.
    fn run(filter: &mut BiquadFilter, input: &[f64]) -> Vec<f64> {
        input.iter().map(|&x| filter.process(x)).collect()
    }

    #[test]
    fn test_notch_attenuates_target() {
        let fs = 512.0;
        let mut f = BiquadFilter::notch(fs, 50.0, 30.0);
        let out = run(&mut f, &sine(fs, 50.0, 1024));
        // Measure only the second half so the start-up transient is excluded.
        let tail = rms(&out[512..]);
        assert!(tail < 0.05, "50 Hz not attenuated: rms={tail}");
    }

    #[test]
    fn test_notch_passes_distant_frequency() {
        let fs = 512.0;
        let mut f = BiquadFilter::notch(fs, 50.0, 30.0);
        let out = run(&mut f, &sine(fs, 10.0, 1024));
        let tail = rms(&out[512..]);
        assert!(tail > 0.5, "10 Hz attenuated by notch: rms={tail}");
    }

    #[test]
    fn test_bandpass_rejects_dc() {
        let mut f = BiquadFilter::bandpass(512.0, 8.0, 13.0);
        let mut out = 0.0;
        for _ in 0..2000 {
            out = f.process(1.0);
        }
        assert!(out.abs() < 0.01, "DC not rejected: {out}");
    }

    #[test]
    fn test_bandpass_passes_centre() {
        let fs = 512.0;
        let mut f = BiquadFilter::bandpass(fs, 8.0, 13.0);
        let out = run(&mut f, &sine(fs, 10.0, 1024));
        let tail = rms(&out[512..]);
        assert!(tail > 0.2, "10 Hz attenuated by alpha bandpass: rms={tail}");
    }

    #[test]
    fn test_bandpass_rejects_out_of_band() {
        let fs = 512.0;
        let mut f = BiquadFilter::bandpass(fs, 8.0, 13.0);
        let out = run(&mut f, &sine(fs, 100.0, 1024));
        let tail = rms(&out[512..]);
        assert!(tail < 0.2, "100 Hz not attenuated: rms={tail}");
    }

    #[test]
    fn test_epoch_buffer_mean_power() {
        let mut buf = EpochBuffer::new(4);
        for v in [1.0_f64, -1.0, 1.0, -1.0] {
            buf.push(v);
        }
        assert!(buf.is_full());
        assert!((buf.mean_power() - 1.0).abs() < 1e-9);
        assert!((buf.rms() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_epoch_buffer_eviction() {
        let mut buf = EpochBuffer::new(3);
        for v in [100.0, 1.0, 1.0, 1.0] {
            buf.push(v);
        }
        assert_eq!(buf.len(), 3);
        assert!((buf.mean_power() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_band_powers_normalised() {
        let bp = BandPowers { delta: 1.0, theta: 1.0, alpha: 1.0, beta: 1.0, gamma: 1.0 };
        for v in bp.normalised().as_array() {
            assert!((v - 0.2).abs() < 1e-9);
        }
    }

    #[test]
    fn test_band_powers_max() {
        let bp = BandPowers { delta: 1.0, theta: 2.0, alpha: 3.0, beta: 4.0, gamma: 5.0 };
        assert!((bp.max() - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_extractor_no_nans() {
        let mut ext = BandPowerExtractor::new(512.0, 50.0, 256);
        for i in 0..1024i16 {
            let bp = ext.push(i.wrapping_mul(100));
            // Check every band, not just a subset.
            for (name, value) in BAND_NAMES.iter().zip(bp.as_array()) {
                assert!(value.is_finite(), "{name} not finite at {i}: {value}");
            }
        }
    }

    #[test]
    fn test_extractor_reset_clears() {
        let mut ext = BandPowerExtractor::new(512.0, 50.0, 256);
        for i in 0..512i16 {
            ext.push(i * 10);
        }
        ext.reset();
        assert_eq!(ext.current().as_array(), [0.0; 5]);
        // A zero input straight after reset must give zero in every band, which
        // also requires the filter delay registers (not just the windows) to be cleared.
        assert_eq!(ext.push(0).as_array(), [0.0; 5]);
    }

    #[test]
    fn test_extractor_alpha_in_alpha_band() {
        let fs = 512.0;
        let mut ext = BandPowerExtractor::new(fs, 50.0, 512);
        let samples: Vec<i16> = (0..1024)
            .map(|i| ((2.0 * PI * 10.0 * i as f64 / fs).sin() * 1000.0) as i16)
            .collect();
        for &s in &samples {
            ext.push(s);
        }
        let bp = ext.current();
        assert!(bp.alpha > bp.delta, "alpha={} should dominate delta={}", bp.alpha, bp.delta);
    }
}