Skip to main content

bids_core/
timeseries.rs

//! Common trait for multichannel time-series data across modalities.
//!
//! Implemented by `EegData`, `MegData`, `NirsData`, and any other type that
//! holds channels × samples arrays. Enables generic processing pipelines
//! that work across EEG, MEG, NIRS, etc.
6
/// Trait for read-only access to multichannel time-series data.
///
/// All electrophysiology data types (`EegData`, `MegData`, `NirsData`)
/// implement this trait, enabling modality-agnostic processing.
///
/// Implementors supply the six required accessors; every other method is a
/// provided method built purely on top of them. Provided methods treat a
/// channel whose [`channel_data`](TimeSeries::channel_data) is `None` as an
/// empty channel.
pub trait TimeSeries {
    /// Number of channels.
    fn n_channels(&self) -> usize;
    /// Number of time samples (for channel 0; channels may differ for multi-rate).
    fn n_samples(&self) -> usize;
    /// Channel names / labels.
    fn channel_names(&self) -> &[String];
    /// Primary sampling rate in Hz.
    fn sampling_rate(&self) -> f64;
    /// Get one channel's data by index, or `None` if there is no such channel.
    fn channel_data(&self, index: usize) -> Option<&[f64]>;
    /// Total duration in seconds.
    fn duration(&self) -> f64;

    // ── Provided methods ────────────────────────────────────────────────

    /// Get one channel's data by name.
    ///
    /// Returns the data of the first channel whose name equals `name`, or
    /// `None` when no channel has that name.
    fn channel_data_by_name(&self, name: &str) -> Option<&[f64]> {
        let idx = self.channel_names().iter().position(|n| n == name)?;
        self.channel_data(idx)
    }

    /// Time array for channel 0: `[0, 1/sr, 2/sr, ...]`.
    ///
    /// Has `n_samples()` entries, each computed as `i / sampling_rate()`.
    fn times(&self) -> Vec<f64> {
        let sr = self.sampling_rate();
        (0..self.n_samples()).map(|i| i as f64 / sr).collect()
    }

    /// Mean value per channel.
    ///
    /// Empty (or missing) channels report a mean of `0.0`.
    fn channel_means(&self) -> Vec<f64> {
        (0..self.n_channels())
            .map(|ch| {
                let d = self.channel_data(ch).unwrap_or(&[]);
                if d.is_empty() {
                    0.0
                } else {
                    d.iter().sum::<f64>() / d.len() as f64
                }
            })
            .collect()
    }

    /// Standard deviation per channel (using pre-computed means).
    ///
    /// Uses the sample estimator (divides by `len - 1`). Channels with fewer
    /// than two samples report `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `means` has no entry for a channel that holds at least two
    /// samples, i.e. `means` is expected to have `n_channels()` entries.
    fn channel_stds_with_means(&self, means: &[f64]) -> Vec<f64> {
        (0..self.n_channels())
            .map(|ch| {
                let d = self.channel_data(ch).unwrap_or(&[]);
                if d.len() < 2 {
                    return 0.0;
                }
                let m = means[ch];
                let var = d.iter().map(|v| (v - m).powi(2)).sum::<f64>() / (d.len() - 1) as f64;
                var.sqrt()
            })
            .collect()
    }

    /// Standard deviation per channel.
    ///
    /// Convenience wrapper that computes the means and forwards to
    /// [`channel_stds_with_means`](TimeSeries::channel_stds_with_means).
    fn channel_stds(&self) -> Vec<f64> {
        let means = self.channel_means();
        self.channel_stds_with_means(&means)
    }

    /// Z-score normalize all channels (zero mean, unit variance).
    /// Returns a new `Vec<Vec<f64>>`.
    ///
    /// A channel whose standard deviation is not greater than `f64::EPSILON`
    /// is only mean-centred (the divisor falls back to `1.0`), which avoids
    /// dividing by zero on constant channels.
    #[must_use]
    fn z_score(&self) -> Vec<Vec<f64>> {
        let means = self.channel_means();
        let stds = self.channel_stds_with_means(&means);
        (0..self.n_channels())
            .map(|ch| {
                let d = self.channel_data(ch).unwrap_or(&[]);
                let m = means[ch];
                let s = if stds[ch] > f64::EPSILON {
                    stds[ch]
                } else {
                    1.0
                };
                d.iter().map(|v| (v - m) / s).collect()
            })
            .collect()
    }

    /// Min-max normalize all channels to [0, 1].
    ///
    /// A channel whose range (`max - min`) is not greater than `f64::EPSILON`
    /// is only shifted by its minimum (the divisor falls back to `1.0`).
    #[must_use]
    fn min_max_normalize(&self) -> Vec<Vec<f64>> {
        (0..self.n_channels())
            .map(|ch| {
                let d = self.channel_data(ch).unwrap_or(&[]);
                let min = d.iter().copied().fold(f64::INFINITY, f64::min);
                let max = d.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                let range = max - min;
                let range = if range > f64::EPSILON { range } else { 1.0 };
                d.iter().map(|v| (v - min) / range).collect()
            })
            .collect()
    }

    // ── ML-oriented methods ─────────────────────────────────────────────

    /// Extract a time window as a channels × samples `Vec<Vec<f64>>`.
    ///
    /// Useful for cutting epochs from continuous data for ML training.
    /// `start_sec` and `end_sec` are in seconds; they are converted to sample
    /// indices by multiplying with the sampling rate and rounding. The
    /// half-open sample range `[start, end)` is clamped to each channel's
    /// length, so a window that runs past the end is truncated rather than
    /// padded.
    ///
    /// Negative or NaN bounds are treated as sample 0, and a reversed window
    /// (`start_sec > end_sec`) yields empty channels instead of panicking.
    #[must_use]
    fn window(&self, start_sec: f64, end_sec: f64) -> Vec<Vec<f64>> {
        let sr = self.sampling_rate();
        // Float → usize `as` casts saturate: negative and NaN values become 0.
        let start = (start_sec * sr).round() as usize;
        let end = (end_sec * sr).round() as usize;
        (0..self.n_channels())
            .map(|ch| {
                let d = self.channel_data(ch).unwrap_or(&[]);
                let e = end.min(d.len());
                // Clamp the start to the (already clamped) end so the slice
                // range is never reversed.
                let s = start.min(e);
                d[s..e].to_vec()
            })
            .collect()
    }

    /// Extract non-overlapping fixed-length epochs.
    ///
    /// Returns a Vec of epochs, each epoch is channels × window_samples.
    /// Drops the last partial epoch if it's shorter than `window_sec`.
    /// Equivalent to `epochs_with_stride(window_sec, window_sec)`.
    #[must_use]
    fn epochs(&self, window_sec: f64) -> Vec<Vec<Vec<f64>>> {
        self.epochs_with_stride(window_sec, window_sec)
    }

    /// Extract epochs with a given stride (allows overlap when stride < window).
    ///
    /// Returns `Vec<epoch>` where each epoch is `Vec<channel_data>`. Epoch `i`
    /// covers `[i * stride_sec, i * stride_sec + window_sec)` and is cut with
    /// [`window`](TimeSeries::window).
    ///
    /// Returns an empty `Vec` when the recording is shorter than `window_sec`,
    /// or when `window_sec` / `stride_sec` is not a finite, strictly positive
    /// number (a zero stride would otherwise request infinitely many epochs).
    #[must_use]
    fn epochs_with_stride(&self, window_sec: f64, stride_sec: f64) -> Vec<Vec<Vec<f64>>> {
        let dur = self.duration();
        let window_ok = window_sec.is_finite() && window_sec > 0.0;
        let stride_ok = stride_sec.is_finite() && stride_sec > 0.0;
        if !window_ok || !stride_ok || dur.is_nan() || dur < window_sec {
            return Vec::new();
        }
        let n = ((dur - window_sec) / stride_sec).floor() as usize + 1;
        (0..n)
            .map(|i| {
                let start = i as f64 * stride_sec;
                self.window(start, start + window_sec)
            })
            .collect()
    }

    /// Flatten channels × samples into a single contiguous `Vec<f64>` (row-major).
    ///
    /// Layout: `[ch0_s0, ch0_s1, ..., ch0_sN, ch1_s0, ..., chM_sN]`.
    /// This is the format expected by most ML frameworks (batch × features).
    /// Channels are appended at their full length; missing channels
    /// contribute nothing.
    #[must_use]
    fn to_flat_vec(&self) -> Vec<f64> {
        let mut out = Vec::with_capacity(self.n_channels() * self.n_samples());
        for ch in 0..self.n_channels() {
            if let Some(d) = self.channel_data(ch) {
                out.extend_from_slice(d);
            }
        }
        out
    }

    /// Get data as a contiguous `Vec<f64>` in column-major order (samples × channels).
    ///
    /// Layout: `[ch0_s0, ch1_s0, ..., chM_s0, ch0_s1, ch1_s1, ..., chM_sN]`.
    /// This matches the layout expected by many time-series models (T × C).
    /// The output always has `n_channels() * n_samples()` entries; positions
    /// for which a channel has no sample are filled with `0.0`.
    #[must_use]
    fn to_column_major(&self) -> Vec<f64> {
        let nc = self.n_channels();
        let ns = self.n_samples();
        // Look every channel up once instead of once per output element.
        let channels: Vec<Option<&[f64]>> = (0..nc).map(|ch| self.channel_data(ch)).collect();
        let mut out = Vec::with_capacity(nc * ns);
        for s in 0..ns {
            out.extend(
                channels
                    .iter()
                    .copied()
                    .map(|d| d.and_then(|d| d.get(s)).copied().unwrap_or(0.0)),
            );
        }
        out
    }

    /// Shape as (n_channels, n_samples) — matches tensor dimension conventions.
    #[must_use]
    fn shape(&self) -> (usize, usize) {
        (self.n_channels(), self.n_samples())
    }

    // ── Feature extraction (inspired by MOABB pipelines) ────────────────

    /// Log-variance per channel — a simple but effective BCI feature.
    ///
    /// Equivalent to MOABB's `LogVariance` transformer.
    /// Returns one value per channel: `ln(var(channel_data))`, where `var` is
    /// the population variance (divides by the sample count, not `len - 1`).
    /// Channels with fewer than two samples or with zero variance report
    /// `f64::NEG_INFINITY`.
    #[must_use]
    fn log_variance(&self) -> Vec<f64> {
        let means = self.channel_means();
        (0..self.n_channels())
            .map(|ch| {
                let d = self.channel_data(ch).unwrap_or(&[]);
                if d.len() < 2 {
                    return f64::NEG_INFINITY;
                }
                let m = means[ch];
                let var = d.iter().map(|v| (v - m).powi(2)).sum::<f64>() / d.len() as f64;
                if var > 0.0 {
                    var.ln()
                } else {
                    f64::NEG_INFINITY
                }
            })
            .collect()
    }

    /// Band power per channel — average power in the signal.
    ///
    /// Returns one value per channel: `mean(x²)`. No filtering is applied
    /// here; empty channels report `0.0`.
    #[must_use]
    fn band_power(&self) -> Vec<f64> {
        (0..self.n_channels())
            .map(|ch| {
                let d = self.channel_data(ch).unwrap_or(&[]);
                if d.is_empty() {
                    return 0.0;
                }
                d.iter().map(|v| v * v).sum::<f64>() / d.len() as f64
            })
            .collect()
    }

    /// RMS (root-mean-square) per channel: the square root of
    /// [`band_power`](TimeSeries::band_power).
    #[must_use]
    fn rms(&self) -> Vec<f64> {
        self.band_power().iter().map(|p| p.sqrt()).collect()
    }

    /// Peak-to-peak amplitude per channel: `max - min`.
    ///
    /// Empty channels report `0.0`.
    #[must_use]
    fn peak_to_peak(&self) -> Vec<f64> {
        (0..self.n_channels())
            .map(|ch| {
                let d = self.channel_data(ch).unwrap_or(&[]);
                if d.is_empty() {
                    return 0.0;
                }
                let min = d.iter().copied().fold(f64::INFINITY, f64::min);
                let max = d.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                max - min
            })
            .collect()
    }

    /// Compute the covariance matrix (channels × channels).
    ///
    /// Returns a flat Vec in row-major order (length = n_channels²).
    /// Used for Riemannian geometry BCI methods (CSP, MDM, etc.).
    ///
    /// Uses the sample estimator with divisor `n_samples() - 1`; when
    /// `n_samples() < 2` the matrix is all zeros. For channels of unequal
    /// length the cross products are summed over the shorter channel only,
    /// while the divisor is still derived from `n_samples()`.
    #[must_use]
    fn covariance_matrix(&self) -> Vec<f64> {
        let nc = self.n_channels();
        let ns = self.n_samples();
        let means = self.channel_means();
        let mut cov = vec![0.0; nc * nc];

        if ns < 2 {
            return cov;
        }

        // Resolve each channel slice once rather than once per (i, j) pair.
        let data: Vec<&[f64]> = (0..nc)
            .map(|ch| self.channel_data(ch).unwrap_or(&[]))
            .collect();
        let denom = (ns - 1) as f64;

        for i in 0..nc {
            // Only the upper triangle is computed; the result is mirrored.
            for j in i..nc {
                let sum: f64 = data[i]
                    .iter()
                    .zip(data[j].iter())
                    .map(|(a, b)| (a - means[i]) * (b - means[j]))
                    .sum();
                let val = sum / denom;
                cov[i * nc + j] = val;
                cov[j * nc + i] = val;
            }
        }
        cov
    }
}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// `true` when `actual` lies within [`TOL`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Bare-bones in-memory `TimeSeries` used as the test fixture.
    struct TestSeries {
        channels: Vec<Vec<f64>>,
        names: Vec<String>,
        sr: f64,
    }

    impl TimeSeries for TestSeries {
        fn n_channels(&self) -> usize {
            self.channels.len()
        }

        fn n_samples(&self) -> usize {
            // Sample count is taken from channel 0; no channels means 0.
            self.channels.first().map_or(0, Vec::len)
        }

        fn channel_names(&self) -> &[String] {
            &self.names
        }

        fn sampling_rate(&self) -> f64 {
            self.sr
        }

        fn channel_data(&self, index: usize) -> Option<&[f64]> {
            self.channels.get(index).map(Vec::as_slice)
        }

        fn duration(&self) -> f64 {
            self.n_samples() as f64 / self.sr
        }
    }

    /// Wrap two channels into a 100 Hz fixture named `Ch1` / `Ch2`.
    fn two_channel_series(first: Vec<f64>, second: Vec<f64>) -> TestSeries {
        TestSeries {
            channels: vec![first, second],
            names: vec!["Ch1".into(), "Ch2".into()],
            sr: 100.0,
        }
    }

    /// Five-sample fixture: `[1..=5]` and `[10, 20, .., 50]`.
    fn make_test_series() -> TestSeries {
        two_channel_series(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![10.0, 20.0, 30.0, 40.0, 50.0],
        )
    }

    /// 100 samples @ 100Hz = 1.0s; channel 1 is twice channel 0.
    fn make_long_series() -> TestSeries {
        let ramp: Vec<f64> = (0..100).map(f64::from).collect();
        let doubled: Vec<f64> = (0..100).map(|i| f64::from(i * 2)).collect();
        two_channel_series(ramp, doubled)
    }

    #[test]
    fn test_times() {
        let t = make_test_series().times();
        assert_eq!(t.len(), 5);
        assert!(close(t[0], 0.0));
        assert!(close(t[1], 0.01));
        assert!(close(t[4], 0.04));
    }

    #[test]
    fn test_channel_means() {
        let mu = make_test_series().channel_means();
        assert!(close(mu[0], 3.0));
        assert!(close(mu[1], 30.0));
    }

    #[test]
    fn test_channel_stds() {
        let sigma = make_test_series().channel_stds();
        // std of [1,2,3,4,5] = sqrt(2.5) ≈ 1.5811
        assert!(close(sigma[0], 1.5811388300841898));
        assert!(close(sigma[1], 15.811388300841896));
    }

    #[test]
    fn test_z_score() {
        let scored = make_test_series().z_score();
        assert_eq!(scored.len(), 2);
        let first = &scored[0];
        assert_eq!(first.len(), 5);
        // After z-score: mean ≈ 0, std ≈ 1
        let mean: f64 = first.iter().sum::<f64>() / first.len() as f64;
        assert!(mean.abs() < 1e-10, "z-score mean = {}", mean);
        let squared_dev: f64 = first.iter().map(|v| (v - mean).powi(2)).sum::<f64>();
        let var: f64 = squared_dev / (first.len() - 1) as f64;
        assert!((var - 1.0).abs() < 1e-10, "z-score variance = {}", var);
    }

    #[test]
    fn test_min_max_normalize() {
        let scaled = make_test_series().min_max_normalize();
        let first = &scaled[0];
        assert!(close(first[0], 0.0)); // min → 0
        assert!(close(first[4], 1.0)); // max → 1
        assert!(close(first[2], 0.5)); // middle → 0.5
    }

    #[test]
    fn test_channel_data_by_name() {
        let series = make_test_series();
        let ch1: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ch2: [f64; 5] = [10.0, 20.0, 30.0, 40.0, 50.0];
        assert_eq!(series.channel_data_by_name("Ch1"), Some(&ch1[..]));
        assert_eq!(series.channel_data_by_name("Ch2"), Some(&ch2[..]));
        assert!(series.channel_data_by_name("Missing").is_none());
    }

    #[test]
    fn test_duration() {
        assert!(close(make_test_series().duration(), 0.05));
    }

    #[test]
    fn test_window() {
        let cut = make_long_series().window(0.0, 0.5);
        assert_eq!(cut.len(), 2); // 2 channels
        assert_eq!(cut[0].len(), 50); // 0.5s × 100Hz
    }

    #[test]
    fn test_epochs() {
        let chunks = make_long_series().epochs(0.5);
        // 1.0s duration / 0.5s window = 2 epochs
        assert_eq!(chunks.len(), 2);
        let first = &chunks[0];
        assert_eq!(first.len(), 2); // 2 channels per epoch
        assert_eq!(first[0].len(), 50); // 50 samples per window
    }

    #[test]
    fn test_epochs_with_stride() {
        let chunks = make_long_series().epochs_with_stride(0.5, 0.25);
        // 1.0s, window=0.5, stride=0.25 → floor((1.0-0.5)/0.25)+1 = 3
        assert_eq!(chunks.len(), 3);
    }

    #[test]
    fn test_to_flat_vec() {
        let flat = make_test_series().to_flat_vec();
        assert_eq!(flat.len(), 10); // 2 channels × 5 samples
        let (head, tail) = flat.split_at(5);
        assert_eq!(head, &[1.0, 2.0, 3.0, 4.0, 5.0]); // ch0
        assert_eq!(tail, &[10.0, 20.0, 30.0, 40.0, 50.0]); // ch1
    }

    #[test]
    fn test_to_column_major() {
        let interleaved = make_test_series().to_column_major();
        assert_eq!(interleaved.len(), 10);
        // First two elements: ch0_s0, ch1_s0
        assert_eq!(interleaved[0], 1.0);
        assert_eq!(interleaved[1], 10.0);
        // Next: ch0_s1, ch1_s1
        assert_eq!(interleaved[2], 2.0);
        assert_eq!(interleaved[3], 20.0);
    }

    #[test]
    fn test_shape() {
        assert_eq!(make_test_series().shape(), (2, 5));
    }

    #[test]
    fn test_log_variance() {
        let log_var = make_test_series().log_variance();
        assert_eq!(log_var.len(), 2);
        // var([1,2,3,4,5]) = 2.0, ln(2.0) ≈ 0.693
        assert!(close(log_var[0], 2.0f64.ln()));
    }

    #[test]
    fn test_band_power() {
        let power = make_test_series().band_power();
        assert_eq!(power.len(), 2);
        // mean([1,4,9,16,25]) = 11.0
        assert!(close(power[0], 11.0));
    }

    #[test]
    fn test_peak_to_peak() {
        let span = make_test_series().peak_to_peak();
        assert!(close(span[0], 4.0)); // 5-1
        assert!(close(span[1], 40.0)); // 50-10
    }

    #[test]
    fn test_covariance_matrix() {
        let cov = make_test_series().covariance_matrix();
        assert_eq!(cov.len(), 4); // 2×2
        // Diagonal should be variance with n-1 denominator
        // var([1,2,3,4,5], ddof=1) = 2.5
        assert!(close(cov[0], 2.5));
        // cov(ch0, ch1) should be positive (both increasing)
        assert!(cov[1] > 0.0);
        // Symmetric
        assert!(close(cov[1], cov[2]));
    }
}