Skip to main content

elata_eeg_models/
calmness.rs

//! Calmness Model
//!
//! Computes a continuous calmness score based on the ratio of
//! relaxation-associated frequencies (alpha, theta) to
//! alertness-associated frequencies (beta).

7use elata_eeg_hal::SampleBuffer;
8use elata_eeg_signal::band_powers;
9
10use crate::model::{Model, ModelOutput};
11
/// Output from the calmness model.
///
/// Carries the raw and smoothed calmness scores together with the band
/// powers they were computed from.
#[derive(Debug, Clone, PartialEq)]
pub struct CalmnessOutput {
    /// Calmness score (0.0 = very alert, 1.0 = very calm)
    pub score: f32,
    /// Smoothed calmness score (less jittery)
    pub smoothed_score: f32,
    /// Alpha/beta ratio (primary indicator)
    pub alpha_beta_ratio: f32,
    /// Theta contribution; the model currently fills this with the same
    /// value as `theta_power`.
    pub theta_level: f32,
    /// Raw alpha band power, averaged across channels
    pub alpha_power: f32,
    /// Raw beta band power, averaged across channels
    pub beta_power: f32,
    /// Raw theta band power, averaged across channels
    pub theta_power: f32,
}

29impl ModelOutput for CalmnessOutput {
30    fn description(&self) -> String {
31        let state = if self.smoothed_score > 0.7 {
32            "very calm"
33        } else if self.smoothed_score > 0.5 {
34            "calm"
35        } else if self.smoothed_score > 0.3 {
36            "neutral"
37        } else {
38            "alert"
39        };
40        format!(
41            "Calmness: {:.0}% ({}) [α/β={:.2}]",
42            self.smoothed_score * 100.0,
43            state,
44            self.alpha_beta_ratio
45        )
46    }
47
48    fn value(&self) -> Option<f32> {
49        Some(self.smoothed_score)
50    }
51
52    fn confidence(&self) -> Option<f32> {
53        // Higher confidence when we have stronger signals
54        let signal_strength = self.alpha_power + self.beta_power + self.theta_power;
55        Some((signal_strength / 1000.0).min(1.0))
56    }
57}
58
/// Calmness analysis model
///
/// Uses the ratio of alpha+theta power to beta power to estimate
/// how calm or alert the user is.
///
/// The model outputs a score from 0.0 (very alert) to 1.0 (very calm).
#[derive(Debug, Clone)]
pub struct CalmnessModel {
    /// Sample rate in samples per second
    sample_rate: u16,
    /// Smoothed score (exponential moving average of the raw score)
    smoothed_score: f32,
    /// Smoothing factor: the weight given to the newest raw score
    smoothing_alpha: f32,
    /// Number of updates folded into `smoothed_score` so far
    update_count: usize,
}

76impl CalmnessModel {
77    /// Create a new calmness model
78    pub fn new(sample_rate: u16) -> Self {
79        Self {
80            sample_rate,
81            smoothed_score: 0.5, // Start neutral
82            smoothing_alpha: 0.1,
83            update_count: 0,
84        }
85    }
86
87    /// Set the smoothing factor (0-1, higher = faster response)
88    pub fn set_smoothing(&mut self, alpha: f32) {
89        self.smoothing_alpha = alpha.clamp(0.01, 1.0);
90    }
91
92    /// Compute calmness metrics from band powers
93    fn compute_calmness(&self, alpha: f32, beta: f32, theta: f32) -> (f32, f32) {
94        // Avoid division by zero
95        let beta_safe = beta.max(1e-6);
96
97        // Alpha/beta ratio is primary indicator
98        // Higher ratio = more relaxed
99        let alpha_beta_ratio = alpha / beta_safe;
100
101        // Include theta for deeper relaxation states
102        let relaxation_power = alpha + theta * 0.5;
103        let alertness_power = beta;
104
105        // Compute ratio (reserved for future use in enhanced model)
106        let _ratio = relaxation_power / (relaxation_power + alertness_power + 1e-6);
107
108        // Map to 0-1 score using sigmoid-like function
109        // Typical alpha/beta ratios: 0.5-3.0
110        let normalized_ratio = (alpha_beta_ratio - 0.5) / 2.5; // Center around 1.0
111        let score = 1.0 / (1.0 + (-normalized_ratio * 4.0).exp());
112
113        (score.clamp(0.0, 1.0), alpha_beta_ratio)
114    }
115
116    /// Compute average band powers across all channels
117    fn compute_band_powers(&self, buffer: &SampleBuffer) -> (f32, f32, f32) {
118        let channel_count = buffer.channel_count;
119        if channel_count == 0 {
120            return (0.0, 0.0, 0.0);
121        }
122
123        let mut total_alpha = 0.0;
124        let mut total_beta = 0.0;
125        let mut total_theta = 0.0;
126
127        for ch in 0..channel_count {
128            let data = buffer.channel_data(ch);
129            let powers = band_powers(data, self.sample_rate as f32);
130            total_alpha += powers.alpha;
131            total_beta += powers.beta;
132            total_theta += powers.theta;
133        }
134
135        let n = channel_count as f32;
136        (total_alpha / n, total_beta / n, total_theta / n)
137    }
138}
139
140impl Model for CalmnessModel {
141    type Output = CalmnessOutput;
142
143    fn name(&self) -> &str {
144        "Calmness Model"
145    }
146
147    fn min_samples(&self) -> usize {
148        // Need at least 1 second of data
149        self.sample_rate as usize
150    }
151
152    fn process(&mut self, buffer: &SampleBuffer) -> Option<Self::Output> {
153        if buffer.sample_count() < self.min_samples() {
154            return None;
155        }
156
157        let (alpha_power, beta_power, theta_power) = self.compute_band_powers(buffer);
158        let (raw_score, alpha_beta_ratio) =
159            self.compute_calmness(alpha_power, beta_power, theta_power);
160
161        // Update smoothed score
162        if self.update_count == 0 {
163            self.smoothed_score = raw_score;
164        } else {
165            self.smoothed_score = self.smoothed_score * (1.0 - self.smoothing_alpha)
166                + raw_score * self.smoothing_alpha;
167        }
168        self.update_count += 1;
169
170        Some(CalmnessOutput {
171            score: raw_score,
172            smoothed_score: self.smoothed_score,
173            alpha_beta_ratio,
174            theta_level: theta_power,
175            alpha_power,
176            beta_power,
177            theta_power,
178        })
179    }
180
181    fn reset(&mut self) {
182        self.smoothed_score = 0.5;
183        self.update_count = 0;
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a single-channel buffer holding `samples` samples of a sum of
    /// sine waves, one per `(frequency, amplitude)` pair.
    fn create_test_buffer_multi_freq(
        sample_rate: u16,
        samples: usize,
        frequencies: &[(f32, f32)], // (freq, amplitude)
    ) -> SampleBuffer {
        use std::f32::consts::PI;

        let mut buffer = SampleBuffer::new(sample_rate, 1);
        let data: Vec<f32> = (0..samples)
            .map(|i| {
                let t = i as f32 / sample_rate as f32;
                frequencies
                    .iter()
                    .map(|(freq, amp)| amp * (2.0 * PI * freq * t).sin())
                    .sum::<f32>()
            })
            .collect();
        buffer.push_interleaved(&data, 0, sample_rate);
        buffer
    }

    #[test]
    fn test_calmness_model_creation() {
        let model = CalmnessModel::new(256);
        assert_eq!(model.sample_rate, 256);
    }

    #[test]
    fn test_high_alpha_is_calm() {
        let mut model = CalmnessModel::new(256);

        // Strong alpha (10 Hz) signal
        let buffer = create_test_buffer_multi_freq(256, 512, &[(10.0, 50.0)]);
        let output = model.process(&buffer).unwrap();

        // Alpha should dominate, and the score should land on the calm side
        assert!(output.alpha_power > output.beta_power);
        assert!(output.score > 0.5);
    }

    #[test]
    fn test_high_beta_is_alert() {
        let mut model = CalmnessModel::new(256);

        // Strong beta (20 Hz) signal
        let buffer = create_test_buffer_multi_freq(256, 512, &[(20.0, 50.0)]);
        let output = model.process(&buffer).unwrap();

        // Beta should dominate both relaxation bands
        assert!(output.beta_power > output.theta_power);
        assert!(output.beta_power > output.alpha_power);
    }

    #[test]
    fn test_alpha_scores_calmer_than_beta() {
        // Fresh models so neither result is affected by smoothing history
        let alpha_buffer = create_test_buffer_multi_freq(256, 512, &[(10.0, 50.0)]);
        let beta_buffer = create_test_buffer_multi_freq(256, 512, &[(20.0, 50.0)]);

        let calm = CalmnessModel::new(256).process(&alpha_buffer).unwrap();
        let alert = CalmnessModel::new(256).process(&beta_buffer).unwrap();

        assert!(calm.alpha_beta_ratio > alert.alpha_beta_ratio);
        assert!(calm.score > alert.score);
    }

    #[test]
    fn test_score_range() {
        let mut model = CalmnessModel::new(256);

        // Mixed signal
        let buffer =
            create_test_buffer_multi_freq(256, 512, &[(10.0, 30.0), (20.0, 30.0), (6.0, 20.0)]);

        let output = model.process(&buffer).unwrap();

        // Score should be in valid range
        assert!((0.0..=1.0).contains(&output.score));
        assert!((0.0..=1.0).contains(&output.smoothed_score));
    }
}