// neurosky 0.0.1
//
// Rust library and TUI for NeuroSky MindWave EEG headsets via the ThinkGear serial protocol
// Documentation
//! Digital signal processing pipeline for raw EEG.
//!
//! Implements the real-time chain used in professional BCI tools like
//! [neuromore/studio](https://github.com/neuromore/studio):
//!
//! ```text
//! raw ADC  →  notch filter  →  ┌→ delta bandpass → epoch → power ┐
//!             (50 / 60 Hz)     ├→ theta bandpass → epoch → power ├→ BandPowers
//!                              ├→ alpha bandpass → epoch → power │
//!                              ├→ beta  bandpass → epoch → power │
//!                              └→ gamma bandpass → epoch → power ┘
//! ```
//!
//! All filters are 2nd-order biquad IIR in Direct Form II Transposed.
//! Coefficients are computed at construction time from the sample rate and
//! the desired cutoff frequencies using the Audio EQ Cookbook formulas.

use std::collections::VecDeque;

// ── Biquad IIR filter ─────────────────────────────────────────────────────────

/// Second-order (biquad) IIR filter section, evaluated in Direct Form II
/// Transposed.
///
/// The coefficient layout follows the
/// [Audio EQ Cookbook](https://www.w3.org/TR/audio-eq-cookbook/): three
/// feed-forward terms (`b0`, `b1`, `b2`) and two feedback terms (`a1`, `a2`),
/// all already divided by `a0`. `z1` and `z2` are the two delay elements.
#[derive(Debug, Clone)]
pub struct BiquadFilter {
    b0: f64,
    b1: f64,
    b2: f64,
    a1: f64,
    a2: f64,
    z1: f64,
    z2: f64,
}

impl BiquadFilter {
    /// Build a filter from coefficients that are already normalised so that
    /// `a0 = 1`. Both delay elements start at zero.
    pub fn new(b0: f64, b1: f64, b2: f64, a1: f64, a2: f64) -> Self {
        Self { b0, b1, b2, a1, a2, z1: 0.0, z2: 0.0 }
    }

    /// Design a 2nd-order IIR notch (band-reject) section.
    ///
    /// - `fs` – sample rate in Hz
    /// - `f0` – frequency to reject in Hz (50 or 60 for power-line removal)
    /// - `q`  – quality factor; larger values give a narrower notch
    ///          (30 is typical)
    pub fn notch(fs: f64, f0: f64, q: f64) -> Self {
        let omega = std::f64::consts::TAU * f0 / fs;
        let (sin_w, cos_w) = omega.sin_cos();
        let alpha = sin_w / (2.0 * q);
        let norm = 1.0 + alpha;

        // For a notch, b0 == b2 and b1 == a1, so each is computed once.
        let outer = 1.0 / norm;
        let middle = -2.0 * cos_w / norm;
        Self::new(outer, middle, outer, middle, (1.0 - alpha) / norm)
    }

    /// Design a 2nd-order bandpass section with constant 0 dB peak gain.
    ///
    /// The centre frequency is the geometric mean of the two edges and the
    /// quality factor is the centre frequency divided by the bandwidth.
    ///
    /// - `fs`    – sample rate in Hz
    /// - `f_low` – lower −3 dB edge in Hz
    /// - `f_hi`  – upper −3 dB edge in Hz
    pub fn bandpass(fs: f64, f_low: f64, f_hi: f64) -> Self {
        let centre = (f_low * f_hi).sqrt();
        let q = centre / (f_hi - f_low);
        let omega = std::f64::consts::TAU * centre / fs;
        let (sin_w, cos_w) = omega.sin_cos();
        let alpha = sin_w / (2.0 * q);
        let norm = 1.0 + alpha;

        Self::new(
            alpha / norm,
            0.0,
            -alpha / norm,
            -2.0 * cos_w / norm,
            (1.0 - alpha) / norm,
        )
    }

    /// Feed one input sample through the filter and return the output sample.
    #[inline]
    pub fn process(&mut self, x: f64) -> f64 {
        let out = self.b0 * x + self.z1;
        // Both new states are derived from the *previous* `z2` and the fresh
        // output before either field is overwritten.
        let next_z1 = self.b1 * x - self.a1 * out + self.z2;
        let next_z2 = self.b2 * x - self.a2 * out;
        self.z1 = next_z1;
        self.z2 = next_z2;
        out
    }

    /// Hard reset: set both delay elements back to zero.
    pub fn reset(&mut self) {
        self.z1 = 0.0;
        self.z2 = 0.0;
    }
}

// ── Epoch buffer ──────────────────────────────────────────────────────────────

/// Fixed-capacity circular buffer for windowed power estimation.
///
/// Holds at most `capacity` of the most recently pushed samples; once full,
/// each new sample replaces the oldest one.
#[derive(Debug, Clone)]
pub struct EpochBuffer {
    /// Stored samples, oldest at the front.
    data:     VecDeque<f64>,
    /// Maximum number of samples retained.
    capacity: usize,
}

impl EpochBuffer {
    /// Create an empty buffer that retains at most `capacity` samples.
    ///
    /// A `capacity` of zero yields a buffer that never stores anything.
    pub fn new(capacity: usize) -> Self {
        // Eviction happens before insertion, so `capacity` slots suffice.
        EpochBuffer { data: VecDeque::with_capacity(capacity), capacity }
    }

    /// Push a sample, evicting the oldest when full.
    ///
    /// With a capacity of zero the sample is discarded, so `len()` never
    /// exceeds `capacity()`.
    pub fn push(&mut self, sample: f64) {
        if self.capacity == 0 { return; }
        if self.data.len() >= self.capacity { self.data.pop_front(); }
        self.data.push_back(sample);
    }

    /// Number of samples currently stored.
    pub fn len(&self)      -> usize { self.data.len() }
    /// `true` when no samples are stored.
    pub fn is_empty(&self) -> bool  { self.data.is_empty() }
    /// Maximum number of samples the buffer retains.
    pub fn capacity(&self) -> usize { self.capacity }
    /// `true` once the number of stored samples has reached the capacity.
    pub fn is_full(&self)  -> bool  { self.data.len() >= self.capacity }

    /// Mean squared amplitude (power) of the window.
    ///
    /// The average is taken over the samples currently stored, not over the
    /// capacity; an empty buffer reports `0.0`.
    pub fn mean_power(&self) -> f64 {
        if self.data.is_empty() { return 0.0; }
        self.data.iter().map(|&x| x * x).sum::<f64>() / self.data.len() as f64
    }

    /// Root mean square: the square root of [`mean_power`](Self::mean_power).
    pub fn rms(&self) -> f64 { self.mean_power().sqrt() }

    /// Remove every stored sample; the capacity is unchanged.
    pub fn clear(&mut self) { self.data.clear(); }
}

// ── Band powers ───────────────────────────────────────────────────────────────

/// Power in each of the five EEG frequency bands, expressed as mean squared
/// amplitude.
#[derive(Debug, Clone, Copy, Default)]
pub struct BandPowers {
    /// Delta band, 0.5 – 4 Hz.
    pub delta: f64,
    /// Theta band, 4 – 8 Hz.
    pub theta: f64,
    /// Alpha band, 8 – 13 Hz.
    pub alpha: f64,
    /// Beta band, 13 – 30 Hz.
    pub beta: f64,
    /// Gamma band, 30 – 50 Hz.
    pub gamma: f64,
}

impl BandPowers {
    /// Express every band as a fraction of the summed power, so the five
    /// values add up to 1.0.
    ///
    /// When the summed power is exactly zero the values are returned
    /// unchanged, which avoids dividing by zero.
    pub fn normalised(&self) -> Self {
        let Self { delta, theta, alpha, beta, gamma } = *self;
        let total = delta + theta + alpha + beta + gamma;
        if total == 0.0 {
            return *self;
        }
        Self {
            delta: delta / total,
            theta: theta / total,
            alpha: alpha / total,
            beta: beta / total,
            gamma: gamma / total,
        }
    }

    /// The five values in the fixed order `[delta, theta, alpha, beta, gamma]`.
    pub fn as_array(&self) -> [f64; 5] {
        let Self { delta, theta, alpha, beta, gamma } = *self;
        [delta, theta, alpha, beta, gamma]
    }

    /// Largest of the five band powers.
    ///
    /// The search starts from `0.0`, so the result is never below zero.
    pub fn max(&self) -> f64 {
        let mut peak = 0.0_f64;
        for value in self.as_array().iter() {
            peak = peak.max(*value);
        }
        peak
    }
}

// ── Band power extractor ──────────────────────────────────────────────────────

/// Real-time EEG band-power extractor, inspired by the
/// [neuromore/studio](https://github.com/neuromore/studio) signal-processing pipeline.
///
/// # Signal chain
/// ```text
/// raw sample ──► notch filter ──► five parallel bandpass filters
///                                 each feeding a rolling epoch buffer
///                                 → mean squared amplitude = band power
/// ```
///
/// Designed to run at the ThinkGear raw EEG rate (512 Hz).
///
/// Derives `Debug` and `Clone` like the filter and buffer types it is built
/// from, so an extractor can be logged or duplicated together with its state.
#[derive(Debug, Clone)]
pub struct BandPowerExtractor {
    /// Power-line rejection stage applied to every raw sample first.
    notch:     BiquadFilter,
    // One bandpass filter per EEG band, all fed with the notch output.
    delta_bp:  BiquadFilter,
    theta_bp:  BiquadFilter,
    alpha_bp:  BiquadFilter,
    beta_bp:   BiquadFilter,
    gamma_bp:  BiquadFilter,
    // One rolling window per band holding that band's filtered samples.
    delta_buf: EpochBuffer,
    theta_buf: EpochBuffer,
    alpha_buf: EpochBuffer,
    beta_buf:  EpochBuffer,
    gamma_buf: EpochBuffer,
}

impl BandPowerExtractor {
    /// Create a new extractor.
    ///
    /// - `fs`       – sample rate in Hz (512 for ThinkGear raw EEG)
    /// - `notch_hz` – power-line frequency to reject (50 or 60 Hz)
    /// - `epoch`    – window size in samples for power estimation
    ///                (e.g. 512 = 1 second at 512 Hz)
    ///
    /// The notch uses a fixed quality factor of 30; the band edges are
    /// 0.5–4, 4–8, 8–13, 13–30 and 30–50 Hz.
    pub fn new(fs: f64, notch_hz: f64, epoch: usize) -> Self {
        BandPowerExtractor {
            notch:     BiquadFilter::notch(fs, notch_hz, 30.0),
            delta_bp:  BiquadFilter::bandpass(fs,  0.5,  4.0),
            theta_bp:  BiquadFilter::bandpass(fs,  4.0,  8.0),
            alpha_bp:  BiquadFilter::bandpass(fs,  8.0, 13.0),
            beta_bp:   BiquadFilter::bandpass(fs, 13.0, 30.0),
            gamma_bp:  BiquadFilter::bandpass(fs, 30.0, 50.0),
            delta_buf: EpochBuffer::new(epoch),
            theta_buf: EpochBuffer::new(epoch),
            alpha_buf: EpochBuffer::new(epoch),
            beta_buf:  EpochBuffer::new(epoch),
            gamma_buf: EpochBuffer::new(epoch),
        }
    }

    /// Push one raw ADC sample; returns updated [`BandPowers`].
    ///
    /// The sample passes through the notch filter once, and the result is
    /// fed to all five bandpass filters.
    pub fn push(&mut self, raw: i16) -> BandPowers {
        // `i16` → `f64` is lossless, so use the infallible conversion.
        let x = self.notch.process(f64::from(raw));
        self.delta_buf.push(self.delta_bp.process(x));
        self.theta_buf.push(self.theta_bp.process(x));
        self.alpha_buf.push(self.alpha_bp.process(x));
        self.beta_buf .push(self.beta_bp .process(x));
        self.gamma_buf.push(self.gamma_bp.process(x));
        self.current()
    }

    /// Current band powers without feeding a new sample.
    pub fn current(&self) -> BandPowers {
        BandPowers {
            delta: self.delta_buf.mean_power(),
            theta: self.theta_buf.mean_power(),
            alpha: self.alpha_buf.mean_power(),
            beta:  self.beta_buf .mean_power(),
            gamma: self.gamma_buf.mean_power(),
        }
    }

    /// Reset all filter states and epoch buffers.
    ///
    /// The buffers are cleared in place rather than rebuilt, so the window
    /// size is unchanged and nothing is reallocated.
    pub fn reset(&mut self) {
        self.notch.reset();
        self.delta_bp.reset(); self.theta_bp.reset(); self.alpha_bp.reset();
        self.beta_bp.reset();  self.gamma_bp.reset();
        self.delta_buf.clear(); self.theta_buf.clear(); self.alpha_buf.clear();
        self.beta_buf.clear();  self.gamma_buf.clear();
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Band names in the same order as `BandPowers::as_array`.
    const BAND_NAMES: [&str; 5] = ["delta", "theta", "alpha", "beta", "gamma"];

    /// Unit-amplitude sine of `freq` Hz sampled at `fs` Hz, `n` samples long.
    fn sine(fs: f64, freq: f64, n: usize) -> Vec<f64> {
        (0..n).map(|i| (2.0 * PI * freq * i as f64 / fs).sin()).collect()
    }

    /// Root mean square of a slice of samples.
    fn rms(samples: &[f64]) -> f64 {
        (samples.iter().map(|x| x * x).sum::<f64>() / samples.len() as f64).sqrt()
    }

    /// Run `input` through `filter` and return the RMS of the second half of
    /// the output, so the start-up transient is skipped.
    fn steady_state_rms(filter: &mut BiquadFilter, input: &[f64]) -> f64 {
        let out: Vec<f64> = input.iter().map(|&x| filter.process(x)).collect();
        rms(&out[out.len() / 2..])
    }

    #[test]
    fn test_notch_attenuates_target() {
        let fs = 512.0;
        let mut f = BiquadFilter::notch(fs, 50.0, 30.0);
        let level = steady_state_rms(&mut f, &sine(fs, 50.0, 1024));
        assert!(level < 0.05, "50 Hz not attenuated: rms={}", level);
    }

    #[test]
    fn test_notch_passes_distant_frequency() {
        let fs = 512.0;
        let mut f = BiquadFilter::notch(fs, 50.0, 30.0);
        let level = steady_state_rms(&mut f, &sine(fs, 10.0, 1024));
        assert!(level > 0.5, "10 Hz attenuated by notch: rms={}", level);
    }

    #[test]
    fn test_bandpass_rejects_dc() {
        let mut f = BiquadFilter::bandpass(512.0, 8.0, 13.0);
        let mut out = 0.0;
        for _ in 0..2000 { out = f.process(1.0); }
        assert!(out.abs() < 0.01, "DC not rejected: {out}");
    }

    #[test]
    fn test_bandpass_passes_centre() {
        let fs = 512.0;
        let mut f = BiquadFilter::bandpass(fs, 8.0, 13.0);
        let level = steady_state_rms(&mut f, &sine(fs, 10.0, 1024));
        assert!(level > 0.2, "10 Hz attenuated by alpha bandpass: rms={}", level);
    }

    #[test]
    fn test_bandpass_rejects_out_of_band() {
        let fs = 512.0;
        let mut f = BiquadFilter::bandpass(fs, 8.0, 13.0);
        // 100 Hz is well outside alpha band
        let level = steady_state_rms(&mut f, &sine(fs, 100.0, 1024));
        assert!(level < 0.2, "100 Hz not attenuated: rms={}", level);
    }

    #[test]
    fn test_epoch_buffer_mean_power() {
        let mut buf = EpochBuffer::new(4);
        for v in [1.0_f64, -1.0, 1.0, -1.0] { buf.push(v); }
        assert!(buf.is_full());
        assert!((buf.mean_power() - 1.0).abs() < 1e-9);
        assert!((buf.rms()        - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_epoch_buffer_eviction() {
        let mut buf = EpochBuffer::new(3);
        buf.push(100.0); buf.push(1.0); buf.push(1.0); buf.push(1.0);
        assert_eq!(buf.len(), 3);
        // 100.0 should have been evicted
        assert!((buf.mean_power() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_band_powers_normalised() {
        let bp = BandPowers { delta: 1.0, theta: 1.0, alpha: 1.0, beta: 1.0, gamma: 1.0 };
        let n = bp.normalised();
        for v in n.as_array() { assert!((v - 0.2).abs() < 1e-9); }
    }

    #[test]
    fn test_band_powers_max() {
        let bp = BandPowers { delta: 1.0, theta: 2.0, alpha: 3.0, beta: 4.0, gamma: 5.0 };
        assert!((bp.max() - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_extractor_no_nans() {
        let mut ext = BandPowerExtractor::new(512.0, 50.0, 256);
        for i in 0..1024i16 {
            let powers = ext.push(i.wrapping_mul(100)).as_array();
            // Every band must stay finite, not just a sample of them.
            for (name, value) in BAND_NAMES.iter().zip(powers.iter()) {
                assert!(value.is_finite(), "{name} not finite at {i}");
            }
        }
    }

    #[test]
    fn test_extractor_reset_clears() {
        let mut ext = BandPowerExtractor::new(512.0, 50.0, 256);
        for i in 0..512i16 { ext.push(i * 10); }
        ext.reset();
        let powers = ext.current().as_array();
        for (name, value) in BAND_NAMES.iter().zip(powers.iter()) {
            assert_eq!(*value, 0.0, "{name} not cleared by reset");
        }
    }

    #[test]
    fn test_extractor_alpha_in_alpha_band() {
        // Feed a pure 10 Hz sine (alpha band) for 2 seconds and verify
        // alpha power dominates over every other band.
        let fs = 512.0;
        let mut ext = BandPowerExtractor::new(fs, 50.0, 512);
        let samples: Vec<i16> = (0..1024)
            .map(|i| ((2.0 * PI * 10.0 * i as f64 / fs).sin() * 1000.0) as i16)
            .collect();
        for &s in &samples { ext.push(s); }
        let bp = ext.current();
        let others = [
            ("delta", bp.delta),
            ("theta", bp.theta),
            ("beta", bp.beta),
            ("gamma", bp.gamma),
        ];
        for (name, other) in others.iter() {
            assert!(bp.alpha > *other, "alpha={} should dominate {name}={other}", bp.alpha);
        }
    }
}