use elata_eeg_hal::SampleBuffer;
use elata_eeg_signal::{band_power, bands};
use crate::model::{Model, ModelOutput};
/// Coarse classification of alpha-band power relative to a running baseline.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlphaState {
    /// Power is above the upper threshold.
    High,
    /// Power is below the lower threshold.
    Low,
    /// Power moved back between the thresholds after a `High` or `Low` reading.
    Transitioning,
    /// No classification is available (for example while warming up).
    Unknown,
}

impl AlphaState {
    /// Returns the stable, lowercase label for this state.
    pub fn as_str(&self) -> &'static str {
        let label: &'static str = match *self {
            Self::Unknown => "unknown",
            Self::Transitioning => "transitioning",
            Self::Low => "low",
            Self::High => "high",
        };
        label
    }
}
/// Result of a single alpha-bump detection step.
#[derive(Debug, Clone)]
pub struct AlphaBumpOutput {
    /// Detector state reported for this step.
    pub state: AlphaState,
    /// Alpha-band power measured for this step.
    pub alpha_power: f32,
    /// Baseline power the measurement is compared against.
    pub baseline: f32,
    /// Whether `state` differs from the state reported before this step.
    pub state_changed: bool,
    /// State that was active before the change, if one is recorded.
    pub previous_state: Option<AlphaState>,
}

impl ModelOutput for AlphaBumpOutput {
    /// Formats a one-line summary of the state, power and baseline.
    ///
    /// When `state_changed` is set, the previous state is appended; a missing
    /// `previous_state` is rendered as `"unknown"`.
    fn description(&self) -> String {
        let change = match self.state_changed {
            true => format!(
                " (changed from {})",
                self.previous_state.map_or("unknown", |s| s.as_str())
            ),
            false => String::new(),
        };
        format!(
            "Alpha: {} (power={:.2}, baseline={:.2}){}",
            self.state.as_str(),
            self.alpha_power,
            self.baseline,
            change
        )
    }

    /// Returns the measured alpha power.
    fn value(&self) -> Option<f32> {
        Some(self.alpha_power)
    }

    /// Returns a confidence in `[0.0, 1.0]` that grows with the distance of
    /// `alpha_power / baseline` from `1.0`, reaching `1.0` once the ratio is
    /// at least 0.5 away.
    ///
    /// A baseline below `1e-6`, or a NaN/infinite power or baseline, yields
    /// `0.0`.
    fn confidence(&self) -> Option<f32> {
        // `f32::min` drops a NaN operand, so without this guard a NaN ratio
        // would be reported as full confidence. `>=` is false for NaN, which
        // rejects a NaN baseline as well.
        let baseline_usable = self.baseline >= 1e-6 && self.baseline.is_finite();
        if !baseline_usable || !self.alpha_power.is_finite() {
            return Some(0.0);
        }
        let ratio = self.alpha_power / self.baseline;
        let distance_from_threshold = (ratio - 1.0).abs();
        Some((distance_from_threshold * 2.0).min(1.0))
    }
}
/// Classifies alpha-band power as high or low relative to an exponentially
/// smoothed baseline.
pub struct AlphaBumpDetector {
    /// Sampling rate passed to `band_power`.
    sample_rate: u16,
    /// Exponential moving average of the measured alpha power.
    baseline: f32,
    /// Weight of the newest measurement in the moving average.
    baseline_alpha: f32,
    /// Power/baseline ratio above which the state is `High`; its reciprocal
    /// is the ratio below which the state is `Low`.
    threshold_multiplier: f32,
    /// Most recent state that was not `Unknown` (or `Unknown` if none yet).
    current_state: AlphaState,
    /// Running total of samples seen so far.
    samples_processed: usize,
    /// Number of samples that must be seen before any state is reported.
    warmup_samples: usize,
}

impl AlphaBumpDetector {
    /// Default weight of the newest measurement in the baseline average.
    const DEFAULT_BASELINE_ALPHA: f32 = 0.02;
    /// Default ratio threshold for the `High` state.
    const DEFAULT_THRESHOLD_MULTIPLIER: f32 = 1.5;
    /// Smallest threshold multiplier accepted by `set_threshold`.
    const MIN_THRESHOLD_MULTIPLIER: f32 = 1.1;
    /// Baselines below this are treated as "no usable baseline".
    const MIN_BASELINE: f32 = 1e-6;

    /// Creates a detector for the given sample rate.
    ///
    /// The warm-up length is twice the sample rate (two seconds of data if
    /// the rate is expressed in samples per second).
    pub fn new(sample_rate: u16) -> Self {
        Self {
            sample_rate,
            baseline: 0.0,
            baseline_alpha: Self::DEFAULT_BASELINE_ALPHA,
            threshold_multiplier: Self::DEFAULT_THRESHOLD_MULTIPLIER,
            current_state: AlphaState::Unknown,
            samples_processed: 0,
            warmup_samples: usize::from(sample_rate) * 2,
        }
    }

    /// Sets the ratio threshold, raising it to at least `1.1`.
    ///
    /// A NaN argument also results in `1.1`, because `f32::max` returns the
    /// non-NaN operand.
    pub fn set_threshold(&mut self, multiplier: f32) {
        self.threshold_multiplier = multiplier.max(Self::MIN_THRESHOLD_MULTIPLIER);
    }

    /// Sets the baseline smoothing factor, clamped to `[0.001, 0.5]`.
    ///
    /// A NaN argument is ignored and the current factor is kept.
    pub fn set_baseline_smoothing(&mut self, alpha: f32) {
        // `f32::clamp` passes NaN through, and a NaN factor would turn the
        // baseline into NaN on the next update.
        if alpha.is_nan() {
            return;
        }
        self.baseline_alpha = alpha.clamp(0.001, 0.5);
    }

    /// Returns the alpha-band power averaged over all channels of `buffer`,
    /// or `0.0` when the buffer has no channels.
    fn compute_alpha_power(&self, buffer: &SampleBuffer) -> f32 {
        let channel_count = buffer.channel_count;
        if channel_count == 0 {
            return 0.0;
        }
        let sample_rate = f32::from(self.sample_rate);
        let total_power = (0..channel_count).fold(0.0_f32, |acc, ch| {
            acc + band_power(buffer.channel_data(ch), sample_rate, bands::ALPHA)
        });
        total_power / channel_count as f32
    }

    /// Folds `alpha_power` into the baseline.
    ///
    /// The first measurement (no samples processed yet) seeds the baseline
    /// directly; later ones are blended in with weight `baseline_alpha`.
    fn update_baseline(&mut self, alpha_power: f32) {
        self.baseline = if self.samples_processed == 0 {
            alpha_power
        } else {
            self.baseline * (1.0 - self.baseline_alpha) + alpha_power * self.baseline_alpha
        };
    }

    /// Classifies `alpha_power` against the current baseline.
    ///
    /// Returns `Unknown` while warming up or when the baseline is below
    /// `1e-6`. Between the two thresholds the result is `Transitioning` if
    /// the current state is `High` or `Low`, and `Unknown` otherwise.
    fn determine_state(&self, alpha_power: f32) -> AlphaState {
        let warming_up = self.samples_processed < self.warmup_samples;
        if warming_up || self.baseline < Self::MIN_BASELINE {
            return AlphaState::Unknown;
        }

        let ratio = alpha_power / self.baseline;
        let high_threshold = self.threshold_multiplier;
        let low_threshold = 1.0 / self.threshold_multiplier;

        if ratio > high_threshold {
            return AlphaState::High;
        }
        if ratio < low_threshold {
            return AlphaState::Low;
        }
        // Listed exhaustively so that adding a variant forces a decision here.
        match self.current_state {
            AlphaState::High | AlphaState::Low => AlphaState::Transitioning,
            AlphaState::Transitioning | AlphaState::Unknown => AlphaState::Unknown,
        }
    }
}
impl Model for AlphaBumpDetector {
    type Output = AlphaBumpOutput;

    /// Returns the display name of this model.
    fn name(&self) -> &str {
        "Alpha Bump Detector"
    }

    /// Returns the smallest buffer length `process` accepts: half the sample
    /// rate.
    fn min_samples(&self) -> usize {
        usize::from(self.sample_rate) / 2
    }

    /// Measures alpha power in `buffer`, updates the baseline and the state,
    /// and reports the result.
    ///
    /// Returns `None`, without touching any state, when the buffer holds
    /// fewer than `min_samples()` samples.
    fn process(&mut self, buffer: &SampleBuffer) -> Option<Self::Output> {
        let sample_count = buffer.sample_count();
        if sample_count < self.min_samples() {
            return None;
        }

        let alpha_power = self.compute_alpha_power(buffer);
        // The baseline update must see the old counter: a counter of zero is
        // what makes the first measurement seed the baseline.
        self.update_baseline(alpha_power);
        // Saturate instead of overflowing: the counter is only compared
        // against zero and the warm-up length, and a wrap-around would
        // silently restart warm-up (or panic in debug builds).
        self.samples_processed = self.samples_processed.saturating_add(sample_count);

        let previous = self.current_state;
        let candidate = self.determine_state(alpha_power);
        // A change is only reported between two known states.
        let state_changed = candidate != AlphaState::Unknown
            && previous != AlphaState::Unknown
            && candidate != previous;

        // `Unknown` means "no new classification": keep the last known state.
        if candidate != AlphaState::Unknown {
            self.current_state = candidate;
        }

        let previous_state = if state_changed { Some(previous) } else { None };
        Some(AlphaBumpOutput {
            state: self.current_state,
            alpha_power,
            baseline: self.baseline,
            state_changed,
            previous_state,
        })
    }

    /// Clears the baseline, the state and the sample counter. The threshold
    /// and smoothing settings are left as configured.
    fn reset(&mut self) {
        self.samples_processed = 0;
        self.current_state = AlphaState::Unknown;
        self.baseline = 0.0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sine-wave buffer with amplitude 50; see
    /// `create_test_buffer_with_amplitude` for the details.
    fn create_test_buffer(sample_rate: u16, samples: usize, freq: f32) -> SampleBuffer {
        create_test_buffer_with_amplitude(sample_rate, samples, freq, 50.0)
    }

    /// Builds a buffer via `SampleBuffer::new(sample_rate, 1)` and fills it
    /// with `samples` points of a sine wave at `freq` scaled by `amplitude`.
    fn create_test_buffer_with_amplitude(
        sample_rate: u16,
        samples: usize,
        freq: f32,
        amplitude: f32,
    ) -> SampleBuffer {
        use std::f32::consts::PI;

        let mut wave: Vec<f32> = Vec::with_capacity(samples);
        for i in 0..samples {
            let phase = 2.0 * PI * freq * i as f32 / sample_rate as f32;
            wave.push(phase.sin() * amplitude);
        }

        let mut buffer = SampleBuffer::new(sample_rate, 1);
        buffer.push_interleaved(&wave, 0, sample_rate);
        buffer
    }

    /// A fresh detector stores its sample rate and starts in `Unknown`.
    #[test]
    fn test_detector_creation() {
        let fresh = AlphaBumpDetector::new(256);
        assert_eq!(fresh.current_state, AlphaState::Unknown);
        assert_eq!(fresh.sample_rate, 256);
    }

    /// A 10 Hz sine long enough to be processed yields positive alpha power.
    #[test]
    fn test_alpha_detection() {
        let mut detector = AlphaBumpDetector::new(256);
        let result = detector.process(&create_test_buffer(256, 512, 10.0));
        assert!(result.is_some());
        assert!(result.unwrap().alpha_power > 0.0);
    }

    /// Feeding a baseline, a louder and then a quieter 10 Hz sine walks the
    /// detector through `Unknown` -> `High` -> `Low`.
    #[test]
    fn test_state_transitions_with_synthetic_alpha_power_changes() {
        let sample_rate = 256;
        let mut detector = AlphaBumpDetector::new(sample_rate);
        let make = |amplitude: f32| {
            create_test_buffer_with_amplitude(sample_rate, 512, 10.0, amplitude)
        };

        // The first buffer only seeds the baseline; no state is known yet.
        let first = detector.process(&make(20.0)).expect("warmup output");
        assert_eq!(first.state, AlphaState::Unknown);

        // Entering `High` from `Unknown` is not reported as a change.
        let second = detector.process(&make(60.0)).expect("high output");
        assert_eq!(second.state, AlphaState::High);
        assert!(!second.state_changed);

        // Dropping well below the baseline flips the state and records the
        // previous one.
        let third = detector.process(&make(5.0)).expect("low output");
        assert_eq!(third.state, AlphaState::Low);
        assert!(third.state_changed);
        assert_eq!(third.previous_state, Some(AlphaState::High));
    }
}