use elata_eeg_hal::SampleBuffer;
use elata_eeg_signal::band_powers;
use crate::model::{Model, ModelOutput};
/// Result of a single calmness-model update.
#[derive(Debug, Clone)]
pub struct CalmnessOutput {
    /// Raw calmness score for this window, clamped to `[0.0, 1.0]`.
    pub score: f32,
    /// Exponential moving average of the raw scores across updates.
    pub smoothed_score: f32,
    /// Alpha power divided by beta power (beta floored at `1e-6`).
    pub alpha_beta_ratio: f32,
    /// Theta band power; currently populated with the same value as `theta_power`.
    pub theta_level: f32,
    /// Alpha band power averaged over channels.
    pub alpha_power: f32,
    /// Beta band power averaged over channels.
    pub beta_power: f32,
    /// Theta band power averaged over channels.
    pub theta_power: f32,
}

impl ModelOutput for CalmnessOutput {
    /// Human-readable summary: smoothed score as a percentage, a state label and the α/β ratio.
    ///
    /// Labels by smoothed score: `> 0.7` "very calm", `> 0.5` "calm", `> 0.3` "neutral",
    /// otherwise "alert".
    ///
    /// NOTE(review): for non-negative band powers the logistic mapping in
    /// `CalmnessModel::compute_calmness` never produces a score below ~0.31 (its value at an
    /// α/β ratio of 0), so the "alert" label is effectively unreachable with these
    /// thresholds — confirm whether the thresholds or the mapping should change.
    fn description(&self) -> String {
        let state = if self.smoothed_score > 0.7 {
            "very calm"
        } else if self.smoothed_score > 0.5 {
            "calm"
        } else if self.smoothed_score > 0.3 {
            "neutral"
        } else {
            "alert"
        };
        format!(
            "Calmness: {:.0}% ({}) [α/β={:.2}]",
            self.smoothed_score * 100.0,
            state,
            self.alpha_beta_ratio
        )
    }

    /// The headline value is the smoothed calmness score.
    fn value(&self) -> Option<f32> {
        Some(self.smoothed_score)
    }

    /// Confidence grows linearly with total band power and saturates at 1.0.
    ///
    /// Always returns `Some` value in `[0.0, 1.0]`. A NaN total (corrupt input) yields
    /// zero confidence rather than full confidence, and negative totals are clamped to zero.
    fn confidence(&self) -> Option<f32> {
        // Total band power at which confidence reaches 1.0.
        const FULL_CONFIDENCE_POWER: f32 = 1000.0;

        let signal_strength = self.alpha_power + self.beta_power + self.theta_power;
        // `f32::min`/`clamp` do not map NaN to 0 on their own (`min` would return 1.0),
        // so reject NaN explicitly.
        if signal_strength.is_nan() {
            return Some(0.0);
        }
        Some((signal_strength / FULL_CONFIDENCE_POWER).clamp(0.0, 1.0))
    }
}
/// Rule-based calmness estimator that scores EEG buffers from their alpha, beta and theta
/// band powers and keeps an exponential moving average of the resulting score.
pub struct CalmnessModel {
    /// Rate passed to `band_powers` for every channel.
    sample_rate: u16,
    /// Exponential moving average of raw scores; 0.5 until the first update.
    smoothed_score: f32,
    /// Weight given to each new raw score in the moving average, within `[0.01, 1.0]`.
    smoothing_alpha: f32,
    /// Number of updates folded into `smoothed_score` since construction or the last reset.
    update_count: usize,
}

impl CalmnessModel {
    /// Creates a model for data sampled at `sample_rate`.
    ///
    /// The smoothed score starts at a neutral 0.5 and the smoothing factor defaults to 0.1.
    pub fn new(sample_rate: u16) -> Self {
        Self {
            sample_rate,
            smoothed_score: 0.5,
            smoothing_alpha: 0.1,
            update_count: 0,
        }
    }

    /// Sets the exponential-moving-average weight given to each new raw score.
    ///
    /// The value is clamped to `[0.01, 1.0]`; 1.0 disables smoothing entirely. A NaN
    /// argument is ignored and the current factor is kept, because storing NaN would make
    /// every later smoothed score NaN.
    pub fn set_smoothing(&mut self, alpha: f32) {
        if alpha.is_nan() {
            return;
        }
        self.smoothing_alpha = alpha.clamp(0.01, 1.0);
    }

    /// Maps band powers to `(raw_score, alpha_beta_ratio)`.
    ///
    /// The score is a logistic function of the alpha/beta ratio: exactly 0.5 at a ratio of
    /// 0.5, rising towards 1.0 as alpha dominates, and ≈0.31 at a ratio of 0. The theta
    /// power is accepted for interface stability but does not currently affect the score.
    fn compute_calmness(&self, alpha: f32, beta: f32, _theta: f32) -> (f32, f32) {
        // Floor beta so silent or flat input cannot divide by zero.
        let alpha_beta_ratio = alpha / beta.max(1e-6);
        let normalized_ratio = (alpha_beta_ratio - 0.5) / 2.5;
        let score = 1.0 / (1.0 + (-normalized_ratio * 4.0).exp());
        (score.clamp(0.0, 1.0), alpha_beta_ratio)
    }

    /// Returns `(alpha, beta, theta)` band powers averaged over every channel of `buffer`.
    ///
    /// A buffer with no channels yields `(0.0, 0.0, 0.0)`.
    fn compute_band_powers(&self, buffer: &SampleBuffer) -> (f32, f32, f32) {
        let channel_count = buffer.channel_count;
        if channel_count == 0 {
            return (0.0, 0.0, 0.0);
        }
        let sample_rate = f32::from(self.sample_rate);
        let (alpha_sum, beta_sum, theta_sum) = (0..channel_count)
            .map(|ch| band_powers(buffer.channel_data(ch), sample_rate))
            .fold((0.0f32, 0.0f32, 0.0f32), |(alpha, beta, theta), powers| {
                (alpha + powers.alpha, beta + powers.beta, theta + powers.theta)
            });
        let n = channel_count as f32;
        (alpha_sum / n, beta_sum / n, theta_sum / n)
    }
}
impl Model for CalmnessModel {
    type Output = CalmnessOutput;

    fn name(&self) -> &str {
        "Calmness Model"
    }

    /// Requires `sample_rate` samples, i.e. one second of data at the configured rate.
    fn min_samples(&self) -> usize {
        usize::from(self.sample_rate)
    }

    /// Scores `buffer` and folds the result into the exponential moving average.
    ///
    /// Returns `None` when the buffer holds fewer than [`Self::min_samples`] samples.
    /// A non-finite raw score (e.g. from NaN samples) is still reported in `score`, but it
    /// is not folded into the moving average. Folding it in would turn every later
    /// `smoothed_score` into NaN.
    fn process(&mut self, buffer: &SampleBuffer) -> Option<Self::Output> {
        if buffer.sample_count() < self.min_samples() {
            return None;
        }

        let (alpha_power, beta_power, theta_power) = self.compute_band_powers(buffer);
        let (raw_score, alpha_beta_ratio) =
            self.compute_calmness(alpha_power, beta_power, theta_power);

        if raw_score.is_finite() {
            self.smoothed_score = if self.update_count == 0 {
                // Seed with the first real score rather than blending with the 0.5 placeholder.
                raw_score
            } else {
                self.smoothed_score * (1.0 - self.smoothing_alpha)
                    + raw_score * self.smoothing_alpha
            };
            // Only ever compared with zero; saturating avoids a debug-build overflow panic.
            self.update_count = self.update_count.saturating_add(1);
        }

        Some(CalmnessOutput {
            score: raw_score,
            smoothed_score: self.smoothed_score,
            alpha_beta_ratio,
            theta_level: theta_power,
            alpha_power,
            beta_power,
            theta_power,
        })
    }

    /// Restores the smoothed score to 0.5 and clears the update counter.
    ///
    /// The smoothing factor chosen via `set_smoothing` is kept.
    fn reset(&mut self) {
        self.smoothed_score = 0.5;
        self.update_count = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-channel buffer whose samples are a sum of sinusoids.
    ///
    /// `frequencies` is a list of `(frequency, amplitude)` pairs.
    fn create_test_buffer_multi_freq(
        sample_rate: u16,
        samples: usize,
        frequencies: &[(f32, f32)],
    ) -> SampleBuffer {
        use std::f32::consts::PI;
        let mut buffer = SampleBuffer::new(sample_rate, 1);
        let data: Vec<f32> = (0..samples)
            .map(|i| {
                let t = i as f32 / sample_rate as f32;
                frequencies
                    .iter()
                    .map(|(freq, amp)| amp * (2.0 * PI * freq * t).sin())
                    .sum()
            })
            .collect();
        buffer.push_interleaved(&data, 0, sample_rate);
        buffer
    }

    #[test]
    fn test_calmness_model_creation() {
        let model = CalmnessModel::new(256);
        assert_eq!(model.sample_rate, 256);
        assert_eq!(model.smoothed_score, 0.5);
        assert_eq!(model.update_count, 0);
    }

    #[test]
    fn test_insufficient_samples_returns_none() {
        let mut model = CalmnessModel::new(256);
        let buffer = create_test_buffer_multi_freq(256, 100, &[(10.0, 50.0)]);
        assert!(model.process(&buffer).is_none());
    }

    #[test]
    fn test_high_alpha_is_calm() {
        let mut model = CalmnessModel::new(256);
        let buffer = create_test_buffer_multi_freq(256, 512, &[(10.0, 50.0)]);
        let output = model.process(&buffer).unwrap();
        assert!(output.alpha_power > output.beta_power);
        // alpha > beta means an α/β ratio above 1, which the logistic maps above 0.5.
        assert!(output.score > 0.5);
    }

    #[test]
    fn test_high_beta_is_alert() {
        let mut model = CalmnessModel::new(256);
        let buffer = create_test_buffer_multi_freq(256, 512, &[(20.0, 50.0)]);
        let output = model.process(&buffer).unwrap();
        assert!(output.beta_power > output.theta_power);
        assert!(output.beta_power > output.alpha_power);
        // Beta-dominated input should land below the neutral midpoint.
        assert!(output.score < 0.5);
    }

    #[test]
    fn test_first_update_seeds_smoothed_score() {
        let mut model = CalmnessModel::new(256);
        let buffer = create_test_buffer_multi_freq(256, 512, &[(10.0, 30.0), (20.0, 30.0)]);
        let output = model.process(&buffer).unwrap();
        assert_eq!(output.smoothed_score, output.score);
    }

    #[test]
    fn test_reset_restores_initial_state() {
        let mut model = CalmnessModel::new(256);
        let buffer = create_test_buffer_multi_freq(256, 512, &[(10.0, 50.0)]);
        model.process(&buffer).unwrap();
        assert_eq!(model.update_count, 1);
        model.reset();
        assert_eq!(model.smoothed_score, 0.5);
        assert_eq!(model.update_count, 0);
    }

    #[test]
    fn test_score_range() {
        let mut model = CalmnessModel::new(256);
        let buffer =
            create_test_buffer_multi_freq(256, 512, &[(10.0, 30.0), (20.0, 30.0), (6.0, 20.0)]);
        let output = model.process(&buffer).unwrap();
        assert!(output.score >= 0.0 && output.score <= 1.0);
        assert!(output.smoothed_score >= 0.0 && output.smoothed_score <= 1.0);
    }
}