// bids_eeg/csp.rs
//! Common Spatial Patterns (CSP) for motor imagery BCI.
//!
//! CSP finds spatial filters that maximize the variance ratio between two
//! classes. It's the standard baseline feature extraction for motor imagery
//! BCI systems, typically paired with LDA or SVM classification.
//!
//! This is a pure-Rust implementation — no BLAS/LAPACK dependency.
//!
//! # Algorithm
//!
//! 1. Compute per-class average covariance matrices C₁, C₂
//! 2. Solve the generalized eigenvalue problem: C₁ w = λ (C₁ + C₂) w
//! 3. Select the m eigenvectors with largest/smallest eigenvalues as spatial filters
//! 4. Project data: Z = W^T X, features = log(var(Z))
//!
//! # Example
//!
//! ```no_run
//! use bids_eeg::csp::CSP;
//!
//! let mut csp = CSP::new(3); // 3 pairs = 6 spatial filters
//! // csp.fit(&epochs_class1, &epochs_class2);
//! // let features = csp.transform(&test_epoch);
//! ```

/// Common Spatial Patterns filter.
///
/// Construct with [`CSP::new`], train with [`CSP::fit`], then extract
/// log-variance features with [`CSP::transform`].
#[derive(Debug, Clone)]
pub struct CSP {
    /// Number of spatial filter pairs requested (2*n_components total filters).
    n_components: usize,
    /// Spatial filter matrix W: each row is a spatial filter.
    /// Shape: `[2*n_components][n_channels]`. `None` until `fit` has run.
    filters: Option<Vec<Vec<f64>>>,
    /// Eigenvalues of the whitened class-1 covariance, one per filter row.
    eigenvalues: Vec<f64>,
}

38impl CSP {
39    /// Create a new CSP with `n_components` pairs of spatial filters.
40    ///
41    /// The total number of output features will be `2 * n_components`.
42    #[must_use]
43    pub fn new(n_components: usize) -> Self {
44        Self {
45            n_components,
46            filters: None,
47            eigenvalues: Vec::new(),
48        }
49    }
50
51    /// Fit CSP filters from two classes of epoch data.
52    ///
53    /// Each epoch is `[n_channels][n_samples]`. All epochs must have the
54    /// same number of channels.
55    pub fn fit(&mut self, class1: &[Vec<Vec<f64>>], class2: &[Vec<Vec<f64>>]) {
56        if class1.is_empty() || class2.is_empty() {
57            return;
58        }
59
60        let n_ch = class1[0].len();
61
62        // Compute mean covariance for each class
63        let cov1 = mean_covariance(class1, n_ch);
64        let cov2 = mean_covariance(class2, n_ch);
65
66        // Composite covariance
67        let mut cov_sum = vec![0.0; n_ch * n_ch];
68        for i in 0..cov_sum.len() {
69            cov_sum[i] = cov1[i] + cov2[i];
70        }
71
72        // Solve generalized eigenvalue problem via whitening:
73        // 1. Eigendecompose cov_sum = U D U^T
74        // 2. Whitening: P = D^(-1/2) U^T
75        // 3. S = P C₁ P^T
76        // 4. Eigendecompose S to get the CSP filters
77        let (eig_vals_sum, eig_vecs_sum) = symmetric_eigen(n_ch, &cov_sum);
78
79        // Whitening matrix P: D^(-1/2) * U^T
80        let mut p = vec![0.0; n_ch * n_ch];
81        for i in 0..n_ch {
82            let d = eig_vals_sum[i];
83            let scale = if d > 1e-12 { 1.0 / d.sqrt() } else { 0.0 };
84            for j in 0..n_ch {
85                p[i * n_ch + j] = eig_vecs_sum[j * n_ch + i] * scale; // U^T scaled
86            }
87        }
88
89        // S = P * C₁ * P^T
90        let pc1 = mat_mul(n_ch, &p, &cov1);
91        let p_t = transpose(n_ch, &p);
92        let s = mat_mul(n_ch, &pc1, &p_t);
93
94        // Eigendecompose S
95        let (eig_vals_s, eig_vecs_s) = symmetric_eigen(n_ch, &s);
96
97        // Sort eigenvalues (descending)
98        let mut indices: Vec<usize> = (0..n_ch).collect();
99        indices.sort_by(|&a, &b| {
100            eig_vals_s[b]
101                .partial_cmp(&eig_vals_s[a])
102                .unwrap_or(std::cmp::Ordering::Equal)
103        });
104
105        // Select top and bottom n_components
106        let n = self.n_components.min(n_ch / 2);
107        let selected: Vec<usize> = indices[..n]
108            .iter()
109            .chain(indices[n_ch - n..].iter())
110            .copied()
111            .collect();
112
113        // Compute spatial filters: W = eigvecs_S^T * P
114        let mut filters = Vec::with_capacity(selected.len());
115        let mut eigenvalues = Vec::with_capacity(selected.len());
116
117        for &idx in &selected {
118            let mut w = vec![0.0; n_ch];
119            for j in 0..n_ch {
120                let mut sum = 0.0;
121                for k in 0..n_ch {
122                    sum += eig_vecs_s[k * n_ch + idx] * p[k * n_ch + j];
123                }
124                w[j] = sum;
125            }
126            // Normalize
127            let norm: f64 = w.iter().map(|v| v * v).sum::<f64>().sqrt();
128            if norm > 1e-12 {
129                for v in &mut w {
130                    *v /= norm;
131                }
132            }
133            filters.push(w);
134            eigenvalues.push(eig_vals_s[idx]);
135        }
136
137        self.filters = Some(filters);
138        self.eigenvalues = eigenvalues;
139    }
140
141    /// Transform a single epoch into CSP features.
142    ///
143    /// Returns `log(var(W^T X))` for each spatial filter — one value per component.
144    /// The epoch shape is `[n_channels][n_samples]`.
145    #[must_use]
146    pub fn transform(&self, epoch: &[Vec<f64>]) -> Vec<f64> {
147        let filters = match &self.filters {
148            Some(f) => f,
149            None => return Vec::new(),
150        };
151
152        let n_ch = epoch.len();
153        let n_s = epoch.first().map_or(0, |ch| ch.len());
154
155        filters
156            .iter()
157            .map(|w| {
158                // Project: z[t] = sum_c w[c] * x[c][t]
159                let nc = n_ch.min(w.len());
160                let projected: Vec<f64> = (0..n_s)
161                    .map(|t| (0..nc).map(|c| w[c] * epoch[c][t]).sum::<f64>())
162                    .collect();
163
164                let mean = if n_s > 0 {
165                    projected.iter().sum::<f64>() / n_s as f64
166                } else {
167                    0.0
168                };
169                let var = if n_s > 1 {
170                    projected.iter().map(|z| (z - mean).powi(2)).sum::<f64>() / (n_s - 1) as f64
171                } else {
172                    0.0
173                };
174                if var > 0.0 {
175                    var.ln()
176                } else {
177                    f64::NEG_INFINITY
178                }
179            })
180            .collect()
181    }
182
183    /// Transform multiple epochs into a feature matrix.
184    ///
185    /// Returns `[n_epochs][2*n_components]`.
186    #[must_use]
187    pub fn transform_all(&self, epochs: &[Vec<Vec<f64>>]) -> Vec<Vec<f64>> {
188        epochs.iter().map(|e| self.transform(e)).collect()
189    }
190
191    /// Number of output features (2 × n_components).
192    #[must_use]
193    pub fn n_features(&self) -> usize {
194        self.filters.as_ref().map_or(0, |f| f.len())
195    }
196
197    /// Whether the CSP has been fitted.
198    #[must_use]
199    pub fn is_fitted(&self) -> bool {
200        self.filters.is_some()
201    }
202
203    /// Get the eigenvalues.
204    #[must_use]
205    pub fn eigenvalues(&self) -> &[f64] {
206        &self.eigenvalues
207    }
208}

// ─── Linear algebra helpers (no BLAS needed) ───────────────────────────────────

/// Compute the mean covariance matrix across epochs.
///
/// Each epoch is `[n_channels][n_samples]`; channels beyond `n_ch` and
/// epochs with fewer than 2 usable samples are ignored. The accumulated
/// matrix is normalized by its trace so the per-class covariances share a
/// common scale. Returns a row-major `n_ch * n_ch` buffer (all zeros when
/// no epoch contributed).
fn mean_covariance(epochs: &[Vec<Vec<f64>>], n_ch: usize) -> Vec<f64> {
    let mut cov = vec![0.0; n_ch * n_ch];
    let n_epochs = epochs.len() as f64;

    for epoch in epochs {
        let nc = epoch.len().min(n_ch);
        // Clamp to the shortest channel so a ragged epoch (channels of
        // unequal length) cannot index out of bounds; well-formed epochs
        // are unaffected.
        let ns = epoch[..nc].iter().map(|ch| ch.len()).min().unwrap_or(0);
        if ns < 2 {
            continue;
        }

        // Per-channel means over the usable samples.
        let means: Vec<f64> = (0..nc)
            .map(|c| epoch[c][..ns].iter().sum::<f64>() / ns as f64)
            .collect();

        // Accumulate the unbiased sample covariance: compute the upper
        // triangle, then mirror (the matrix is symmetric).
        for i in 0..nc {
            for j in i..nc {
                let sum: f64 = (0..ns)
                    .map(|t| (epoch[i][t] - means[i]) * (epoch[j][t] - means[j]))
                    .sum();
                let val = sum / (ns - 1) as f64;
                cov[i * n_ch + j] += val / n_epochs;
                if i != j {
                    cov[j * n_ch + i] += val / n_epochs;
                }
            }
        }
    }

    // Normalize by trace: the absolute scale cancels in the CSP
    // eigenproblem, but this keeps the two class matrices comparable.
    let trace: f64 = (0..n_ch).map(|i| cov[i * n_ch + i]).sum();
    if trace > 1e-12 {
        for v in &mut cov {
            *v /= trace;
        }
    }

    cov
}

/// Symmetric eigendecomposition via classical Jacobi iteration.
///
/// `a` is an n×n row-major symmetric matrix. Returns `(eigenvalues,
/// eigenvectors)`: eigenvalues are unsorted, and the eigenvector for
/// `eigenvalues[i]` is column `i` of the returned row-major matrix
/// (element `vecs[r * n + i]`).
///
/// Good enough for small matrices (n_channels typically < 128).
fn symmetric_eigen(n: usize, a: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let mut d = a.to_vec(); // working copy, rotated towards diagonal form
    let mut v = vec![0.0; n * n]; // accumulated rotations (starts as identity)
    for i in 0..n {
        v[i * n + i] = 1.0;
    }

    let max_iter = 100 * n * n;
    for _ in 0..max_iter {
        // Classical Jacobi pivot: the largest off-diagonal element.
        let mut max_val = 0.0;
        let mut p = 0;
        let mut q = 1;
        for i in 0..n {
            for j in i + 1..n {
                let val = d[i * n + j].abs();
                if val > max_val {
                    max_val = val;
                    p = i;
                    q = j;
                }
            }
        }

        // Converged: matrix is numerically diagonal.
        if max_val < 1e-14 {
            break;
        }

        // Rotation angle that annihilates d[p][q].
        let app = d[p * n + p];
        let aqq = d[q * n + q];
        let apq = d[p * n + q];

        let theta = if (app - aqq).abs() < 1e-15 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * (2.0 * apq / (app - aqq)).atan()
        };

        let c = theta.cos();
        let s = theta.sin();

        // Apply the Jacobi rotation to d in place (previously this cloned
        // the whole matrix per rotation — an O(n²) allocation each step).
        // Off-diagonal rows first: row i reads only its own (i,p) and (i,q)
        // entries and writes only row i and rows p/q (excluded from the
        // loop), so the in-place updates cannot interfere.
        for i in 0..n {
            if i != p && i != q {
                let dip = d[i * n + p];
                let diq = d[i * n + q];
                let new_ip = c * dip + s * diq;
                let new_iq = -s * dip + c * diq;
                d[i * n + p] = new_ip;
                d[p * n + i] = new_ip;
                d[i * n + q] = new_iq;
                d[q * n + i] = new_iq;
            }
        }
        // Then the 2×2 pivot block, from the saved pre-rotation values.
        d[p * n + p] = c * c * app + 2.0 * s * c * apq + s * s * aqq;
        d[q * n + q] = s * s * app - 2.0 * s * c * apq + c * c * aqq;
        d[p * n + q] = 0.0;
        d[q * n + p] = 0.0;

        // Accumulate the rotation into the eigenvector matrix.
        for i in 0..n {
            let vip = v[i * n + p];
            let viq = v[i * n + q];
            v[i * n + p] = c * vip + s * viq;
            v[i * n + q] = -s * vip + c * viq;
        }
    }

    let eigenvalues: Vec<f64> = (0..n).map(|i| d[i * n + i]).collect();
    (eigenvalues, v)
}

/// Matrix multiply: C = A * B (all n×n, row-major).
fn mat_mul(n: usize, a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut c = vec![0.0; n * n];
    if n == 0 {
        return c;
    }
    for (a_row, c_row) in a.chunks(n).zip(c.chunks_mut(n)) {
        for (&aik, b_row) in a_row.iter().zip(b.chunks(n)) {
            // Negligible coefficients contribute nothing — skip the row pass.
            if aik.abs() < 1e-15 {
                continue;
            }
            for (cij, &bkj) in c_row.iter_mut().zip(b_row) {
                *cij += aik * bkj;
            }
        }
    }
    c
}

/// Transpose n×n matrix (row-major).
fn transpose(n: usize, a: &[f64]) -> Vec<f64> {
    // Output index idx = j*n + i maps back to input index i*n + j.
    (0..n * n).map(|idx| a[(idx % n) * n + idx / n]).collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build deterministic pseudo-random epochs. `class` scales per-channel
    /// amplitude (and thus variance), making the two classes separable.
    fn make_epochs(n_epochs: usize, n_ch: usize, n_s: usize, class: usize) -> Vec<Vec<Vec<f64>>> {
        // Generate deterministic pseudo-random data with class-dependent variance
        (0..n_epochs)
            .map(|epoch_idx| {
                (0..n_ch)
                    .map(|ch| {
                        (0..n_s)
                            .map(|s| {
                                let seed = (epoch_idx * 1000 + ch * 100 + s + class * 50000) as f64;
                                let val =
                                    (seed * 0.1).sin() * (1.0 + class as f64 * ch as f64 * 0.5);
                                val
                            })
                            .collect()
                    })
                    .collect()
            })
            .collect()
    }

    // Fit + transform round trip: 2 pairs → 4 finite features per epoch.
    #[test]
    fn test_csp_fit_transform() {
        let class1 = make_epochs(20, 4, 100, 0);
        let class2 = make_epochs(20, 4, 100, 1);

        let mut csp = CSP::new(2);
        csp.fit(&class1, &class2);

        assert!(csp.is_fitted());
        assert_eq!(csp.n_features(), 4); // 2 pairs

        let features = csp.transform(&class1[0]);
        assert_eq!(features.len(), 4);
        // Features should be finite
        assert!(features.iter().all(|f| f.is_finite()));
    }

    // Batch transform shape, plus a class-separation sanity check on the
    // first (most discriminative) feature.
    #[test]
    fn test_csp_transform_all() {
        let class1 = make_epochs(10, 3, 50, 0);
        let class2 = make_epochs(10, 3, 50, 1);

        let mut csp = CSP::new(1);
        csp.fit(&class1, &class2);

        let features1 = csp.transform_all(&class1);
        let features2 = csp.transform_all(&class2);

        assert_eq!(features1.len(), 10);
        assert_eq!(features2.len(), 10);
        assert_eq!(features1[0].len(), 2); // 1 pair = 2 features

        // CSP should separate the classes: mean features should differ
        let mean1: f64 = features1.iter().map(|f| f[0]).sum::<f64>() / 10.0;
        let mean2: f64 = features2.iter().map(|f| f[0]).sum::<f64>() / 10.0;
        assert!(
            (mean1 - mean2).abs() > 1e-6,
            "CSP should separate classes: {mean1} vs {mean2}"
        );
    }

    // An unfitted CSP must report not-fitted and produce no features.
    #[test]
    fn test_csp_not_fitted() {
        let csp = CSP::new(2);
        assert!(!csp.is_fitted());
        assert_eq!(csp.transform(&[vec![1.0, 2.0]]).len(), 0);
    }

    // Spot-check the Jacobi eigensolver against a known 2x2 result.
    #[test]
    fn test_symmetric_eigen() {
        // 2x2 symmetric matrix
        let a = vec![2.0, 1.0, 1.0, 3.0];
        let (vals, _vecs) = symmetric_eigen(2, &a);
        let mut sorted = vals.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        // Eigenvalues of [[2,1],[1,3]] are (5±√5)/2 ≈ 1.382, 3.618
        assert!((sorted[0] - 1.382).abs() < 0.01);
        assert!((sorted[1] - 3.618).abs() < 0.01);
    }
}