use serde::{Deserialize, Serialize};
use crate::error::{MastishkError, validate_dt};
/// Identifies one of the five EEG bands tracked by [`EegState`].
///
/// Each variant corresponds to the [`EegState`] field of the same (lower-case)
/// name and is what [`EegState::dominant_band`] returns.
///
/// The enum is `#[non_exhaustive]`, so code in other crates must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EegBand {
/// Band whose level is stored in [`EegState::delta`].
Delta,
/// Band whose level is stored in [`EegState::theta`].
Theta,
/// Band whose level is stored in [`EegState::alpha`].
Alpha,
/// Band whose level is stored in [`EegState::beta`].
Beta,
/// Band whose level is stored in [`EegState::gamma`].
Gamma,
}
/// Snapshot of the level of each EEG band, one `f32` per [`EegBand`] variant.
///
/// The five values are summed by [`EegState::total_power`] and compared with
/// each other by [`EegState::dominant_band`].
///
/// NOTE(review): units and valid range are not established in this file; the
/// fields are public and unchecked, so any `f32` (including NaN) can be stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EegState {
/// Level of the [`EegBand::Delta`] band.
pub delta: f32,
/// Level of the [`EegBand::Theta`] band.
pub theta: f32,
/// Level of the [`EegBand::Alpha`] band.
pub alpha: f32,
/// Level of the [`EegBand::Beta`] band.
pub beta: f32,
/// Level of the [`EegBand::Gamma`] band.
pub gamma: f32,
}
impl Default for EegState {
    /// Builds the baseline state.
    ///
    /// Alpha carries the largest value (0.5), so a default state reports
    /// `EegBand::Alpha` from `dominant_band`.
    fn default() -> Self {
        let delta = 0.1;
        let theta = 0.15;
        let alpha = 0.5;
        let beta = 0.3;
        let gamma = 0.1;

        Self {
            delta,
            theta,
            alpha,
            beta,
            gamma,
        }
    }
}
impl EegState {
    /// Moves every band exponentially toward the matching band of `target`.
    ///
    /// Each field advances by the fraction `1 - exp(-0.5 * dt)` of its
    /// remaining distance to the target, so the total movement over a given
    /// time span does not depend on how that span is split into ticks.
    ///
    /// After updating, a `tracing` trace event named `"EEG tick"` is emitted
    /// with the new delta, alpha (as `alpha_band`) and beta values.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_dt` when it rejects `dt` (the
    /// exact criteria live in `crate::error`). Validation happens before any
    /// field is touched, so the state is unchanged on error.
    #[inline]
    pub fn tick_toward(&mut self, target: &EegState, dt: f32) -> Result<(), MastishkError> {
        // Rate constant of the exponential approach, per unit of `dt`.
        const APPROACH_RATE: f32 = 0.5;

        validate_dt(dt)?;

        // Called `blend` rather than `alpha` so it cannot be mistaken for the
        // alpha band that is updated a few lines below.
        let blend = 1.0 - (-APPROACH_RATE * dt).exp();

        self.delta += (target.delta - self.delta) * blend;
        self.theta += (target.theta - self.theta) * blend;
        self.alpha += (target.alpha - self.alpha) * blend;
        self.beta += (target.beta - self.beta) * blend;
        self.gamma += (target.gamma - self.gamma) * blend;

        tracing::trace!(
            delta = self.delta,
            alpha_band = self.alpha,
            beta = self.beta,
            "EEG tick"
        );
        Ok(())
    }

    /// Returns the band holding the largest value.
    ///
    /// Bands whose value is NaN are ignored: treating NaN as "equal" to its
    /// neighbours would let a single NaN discard the maximum found so far.
    /// When several bands share the largest value, the one listed last in
    /// delta, theta, alpha, beta, gamma order wins. If every band is NaN the
    /// result falls back to [`EegBand::Alpha`].
    #[inline]
    #[must_use]
    pub fn dominant_band(&self) -> EegBand {
        let bands = [
            (self.delta, EegBand::Delta),
            (self.theta, EegBand::Theta),
            (self.alpha, EegBand::Alpha),
            (self.beta, EegBand::Beta),
            (self.gamma, EegBand::Gamma),
        ];
        bands
            .iter()
            .copied()
            .filter(|(power, _)| !power.is_nan())
            // With NaN filtered out `partial_cmp` always yields `Some`; the
            // `Equal` fallback is kept only so this can never panic.
            .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(core::cmp::Ordering::Equal))
            .map(|(_, band)| band)
            .unwrap_or(EegBand::Alpha)
    }

    /// Returns the sum of all five band values.
    ///
    /// The result is NaN if any band is NaN.
    #[inline]
    #[must_use]
    pub fn total_power(&self) -> f32 {
        self.delta + self.theta + self.alpha + self.beta + self.gamma
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the five band values in declaration order so tests can compare
    /// two states field by field.
    fn band_values(state: &EegState) -> [f32; 5] {
        [
            state.delta,
            state.theta,
            state.alpha,
            state.beta,
            state.gamma,
        ]
    }

    #[test]
    fn test_default_alpha_dominant() {
        let s = EegState::default();
        assert_eq!(s.dominant_band(), EegBand::Alpha);
    }

    #[test]
    fn test_tick_toward_converges() {
        let mut s = EegState::default();
        let target = EegState {
            delta: 0.8,
            theta: 0.1,
            alpha: 0.05,
            beta: 0.02,
            gamma: 0.01,
        };
        for _ in 0..100 {
            s.tick_toward(&target, 1.0).unwrap();
        }
        // Every band must have converged, not just delta.
        for (got, want) in band_values(&s).iter().zip(band_values(&target).iter()) {
            assert!((got - want).abs() < 0.05, "band {got} did not reach {want}");
        }
        assert_eq!(s.dominant_band(), EegBand::Delta);
    }

    #[test]
    fn test_total_power() {
        let s = EegState::default();
        // 0.1 + 0.15 + 0.5 + 0.3 + 0.1
        assert!((s.total_power() - 1.15).abs() < 1e-5);
    }

    #[test]
    fn test_serde_roundtrip() {
        let s = EegState::default();
        let json = serde_json::to_string(&s).unwrap();
        let s2: EegState = serde_json::from_str(&json).unwrap();
        // All five fields must survive the round trip.
        for (before, after) in band_values(&s).iter().zip(band_values(&s2).iter()) {
            assert!((after - before).abs() < f32::EPSILON);
        }
    }
}