/// Common Spatial Patterns (CSP) model for discriminating two classes of
/// multichannel epochs.
///
/// Build one with [`CSP::new`], learn its filters with [`CSP::fit`], then turn
/// epochs into log-variance features with [`CSP::transform`].
#[derive(Clone, Debug)]
pub struct CSP {
    /// Requested number of filters taken from each end of the eigenvalue
    /// spectrum (capped at half the channel count when fitting).
    n_components: usize,
    /// Learned spatial filters, one vector of channel weights per filter;
    /// `None` until the model has been fitted.
    filters: Option<Vec<Vec<f64>>>,
    /// Eigenvalue associated with each entry of `filters`, in the same order.
    eigenvalues: Vec<f64>,
}
37
38impl CSP {
39 #[must_use]
43 pub fn new(n_components: usize) -> Self {
44 Self {
45 n_components,
46 filters: None,
47 eigenvalues: Vec::new(),
48 }
49 }
50
51 pub fn fit(&mut self, class1: &[Vec<Vec<f64>>], class2: &[Vec<Vec<f64>>]) {
56 if class1.is_empty() || class2.is_empty() {
57 return;
58 }
59
60 let n_ch = class1[0].len();
61
62 let cov1 = mean_covariance(class1, n_ch);
64 let cov2 = mean_covariance(class2, n_ch);
65
66 let mut cov_sum = vec![0.0; n_ch * n_ch];
68 for i in 0..cov_sum.len() {
69 cov_sum[i] = cov1[i] + cov2[i];
70 }
71
72 let (eig_vals_sum, eig_vecs_sum) = symmetric_eigen(n_ch, &cov_sum);
78
79 let mut p = vec![0.0; n_ch * n_ch];
81 for i in 0..n_ch {
82 let d = eig_vals_sum[i];
83 let scale = if d > 1e-12 { 1.0 / d.sqrt() } else { 0.0 };
84 for j in 0..n_ch {
85 p[i * n_ch + j] = eig_vecs_sum[j * n_ch + i] * scale; }
87 }
88
89 let pc1 = mat_mul(n_ch, &p, &cov1);
91 let p_t = transpose(n_ch, &p);
92 let s = mat_mul(n_ch, &pc1, &p_t);
93
94 let (eig_vals_s, eig_vecs_s) = symmetric_eigen(n_ch, &s);
96
97 let mut indices: Vec<usize> = (0..n_ch).collect();
99 indices.sort_by(|&a, &b| {
100 eig_vals_s[b]
101 .partial_cmp(&eig_vals_s[a])
102 .unwrap_or(std::cmp::Ordering::Equal)
103 });
104
105 let n = self.n_components.min(n_ch / 2);
107 let selected: Vec<usize> = indices[..n]
108 .iter()
109 .chain(indices[n_ch - n..].iter())
110 .copied()
111 .collect();
112
113 let mut filters = Vec::with_capacity(selected.len());
115 let mut eigenvalues = Vec::with_capacity(selected.len());
116
117 for &idx in &selected {
118 let mut w = vec![0.0; n_ch];
119 for j in 0..n_ch {
120 let mut sum = 0.0;
121 for k in 0..n_ch {
122 sum += eig_vecs_s[k * n_ch + idx] * p[k * n_ch + j];
123 }
124 w[j] = sum;
125 }
126 let norm: f64 = w.iter().map(|v| v * v).sum::<f64>().sqrt();
128 if norm > 1e-12 {
129 for v in &mut w {
130 *v /= norm;
131 }
132 }
133 filters.push(w);
134 eigenvalues.push(eig_vals_s[idx]);
135 }
136
137 self.filters = Some(filters);
138 self.eigenvalues = eigenvalues;
139 }
140
141 #[must_use]
146 pub fn transform(&self, epoch: &[Vec<f64>]) -> Vec<f64> {
147 let filters = match &self.filters {
148 Some(f) => f,
149 None => return Vec::new(),
150 };
151
152 let n_ch = epoch.len();
153 let n_s = epoch.first().map_or(0, |ch| ch.len());
154
155 filters
156 .iter()
157 .map(|w| {
158 let nc = n_ch.min(w.len());
160 let projected: Vec<f64> = (0..n_s)
161 .map(|t| (0..nc).map(|c| w[c] * epoch[c][t]).sum::<f64>())
162 .collect();
163
164 let mean = if n_s > 0 {
165 projected.iter().sum::<f64>() / n_s as f64
166 } else {
167 0.0
168 };
169 let var = if n_s > 1 {
170 projected.iter().map(|z| (z - mean).powi(2)).sum::<f64>() / (n_s - 1) as f64
171 } else {
172 0.0
173 };
174 if var > 0.0 {
175 var.ln()
176 } else {
177 f64::NEG_INFINITY
178 }
179 })
180 .collect()
181 }
182
183 #[must_use]
187 pub fn transform_all(&self, epochs: &[Vec<Vec<f64>>]) -> Vec<Vec<f64>> {
188 epochs.iter().map(|e| self.transform(e)).collect()
189 }
190
191 #[must_use]
193 pub fn n_features(&self) -> usize {
194 self.filters.as_ref().map_or(0, |f| f.len())
195 }
196
197 #[must_use]
199 pub fn is_fitted(&self) -> bool {
200 self.filters.is_some()
201 }
202
203 #[must_use]
205 pub fn eigenvalues(&self) -> &[f64] {
206 &self.eigenvalues
207 }
208}
209
/// Mean sample covariance of `epochs`, normalised to unit trace.
///
/// Each epoch is indexed as `epoch[channel][sample]`. Only the first `n_ch`
/// channels are used, and only the samples present in every one of those
/// channels: a ragged epoch is truncated to its shortest used channel, and the
/// channel means are taken over that same sample range. Per-epoch covariances
/// use the unbiased `1 / (ns - 1)` estimator. They are averaged over the epochs
/// that contribute (at least two samples); shorter epochs are skipped and are
/// not counted in the average.
///
/// Returns a row-major `n_ch × n_ch` matrix. It is divided by its trace when
/// that trace exceeds `1e-12`; otherwise it is returned unnormalised (all zeros
/// if no epoch contributed).
fn mean_covariance(epochs: &[Vec<Vec<f64>>], n_ch: usize) -> Vec<f64> {
    let mut cov = vec![0.0; n_ch * n_ch];
    let mut n_used = 0usize;

    for epoch in epochs {
        let channels = &epoch[..epoch.len().min(n_ch)];
        // Common sample range across the used channels (0 if there are none).
        let ns = channels.iter().map(Vec::len).min().unwrap_or(0);
        if ns < 2 {
            continue;
        }
        n_used += 1;

        let means: Vec<f64> = channels
            .iter()
            .map(|ch| ch[..ns].iter().sum::<f64>() / ns as f64)
            .collect();

        let nc = channels.len();
        for i in 0..nc {
            // Only the upper triangle is computed; it is mirrored below.
            for j in i..nc {
                let sum: f64 = (0..ns)
                    .map(|t| (channels[i][t] - means[i]) * (channels[j][t] - means[j]))
                    .sum();
                let val = sum / (ns - 1) as f64;
                cov[i * n_ch + j] += val;
                if i != j {
                    cov[j * n_ch + i] += val;
                }
            }
        }
    }

    if n_used > 0 {
        let count = n_used as f64;
        for v in &mut cov {
            *v /= count;
        }
    }

    let trace: f64 = (0..n_ch).map(|i| cov[i * n_ch + i]).sum();
    if trace > 1e-12 {
        for v in &mut cov {
            *v /= trace;
        }
    }

    cov
}
254
/// Eigendecomposition of the symmetric row-major `n × n` matrix `a` using the
/// classical Jacobi method (each step rotates away the largest off-diagonal
/// entry).
///
/// Returns `(eigenvalues, vectors)`. The eigenvalues are unsorted. `vectors` is
/// row-major `n × n` with the eigenvector for `eigenvalues[k]` stored in column
/// `k`. Only the upper triangle is scanned for pivots, so `a` is assumed to be
/// symmetric.
///
/// Iteration stops once every off-diagonal entry is at most `1e-14` times the
/// largest absolute input entry, or after `100 · n²` rotations. Because the
/// tolerance is relative, scaling the input by a constant scales the
/// eigenvalues by the same constant.
fn symmetric_eigen(n: usize, a: &[f64]) -> (Vec<f64>, Vec<f64>) {
    const REL_TOL: f64 = 1e-14;

    let mut d = a.to_vec();
    let mut v = vec![0.0; n * n];
    for i in 0..n {
        v[i * n + i] = 1.0;
    }

    // `f64::max` ignores NaN, so `scale` is a non-negative number or +inf.
    let scale = d
        .iter()
        .take(n * n)
        .fold(0.0_f64, |m, &x| m.max(x.abs()));
    let tol = REL_TOL * scale;

    let max_iter = 100 * n * n;
    for _ in 0..max_iter {
        // Locate the largest off-diagonal entry in the upper triangle.
        let mut max_val = 0.0;
        let mut p = 0;
        let mut q = 1;
        for i in 0..n {
            for j in i + 1..n {
                let val = d[i * n + j].abs();
                if val > max_val {
                    max_val = val;
                    p = i;
                    q = j;
                }
            }
        }

        // `<=` so an all-zero matrix (tol == 0) also terminates at once.
        if max_val <= tol {
            break;
        }

        let app = d[p * n + p];
        let aqq = d[q * n + q];
        let apq = d[p * n + q];

        // tan(2θ) = 2·apq / (app − aqq) zeroes the (p, q) entry, and atan
        // keeps |θ| <= π/4. Only an exact tie needs the explicit angle; a
        // magnitude threshold here would give wrong rotations for small-scale
        // matrices.
        let diff = app - aqq;
        let theta = if diff == 0.0 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * (2.0 * apq / diff).atan()
        };
        let c = theta.cos();
        let s = theta.sin();

        // Apply A ← Rᵀ A R in place. Each iteration reads only column entries
        // (i, p) and (i, q) for i ∉ {p, q}, and no other iteration writes them,
        // so every value is still the pre-rotation one when it is read.
        for i in 0..n {
            if i == p || i == q {
                continue;
            }
            let aip = d[i * n + p];
            let aiq = d[i * n + q];
            let new_ip = c * aip + s * aiq;
            let new_iq = -s * aip + c * aiq;
            d[i * n + p] = new_ip;
            d[p * n + i] = new_ip;
            d[i * n + q] = new_iq;
            d[q * n + i] = new_iq;
        }
        d[p * n + p] = c * c * app + 2.0 * s * c * apq + s * s * aqq;
        d[q * n + q] = s * s * app - 2.0 * s * c * apq + c * c * aqq;
        d[p * n + q] = 0.0;
        d[q * n + p] = 0.0;

        // Accumulate the rotation into the eigenvector matrix: V ← V R.
        for i in 0..n {
            let vip = v[i * n + p];
            let viq = v[i * n + q];
            v[i * n + p] = c * vip + s * viq;
            v[i * n + q] = -s * vip + c * viq;
        }
    }

    let eigenvalues: Vec<f64> = (0..n).map(|i| d[i * n + i]).collect();
    (eigenvalues, v)
}
332
/// Row-major product `A · B` of two `n × n` matrices.
///
/// Entries of `A` that are exactly zero are skipped to save work on sparse
/// operands. Any non-zero entry contributes, however small.
///
/// # Panics
/// Panics if `a` or `b` holds fewer than `n * n` elements.
fn mat_mul(n: usize, a: &[f64], b: &[f64]) -> Vec<f64> {
    // Restrict to the n×n prefix once; this also fails fast on short input.
    let (a, b) = (&a[..n * n], &b[..n * n]);
    let mut c = vec![0.0; n * n];
    for i in 0..n {
        for k in 0..n {
            let aik = a[i * n + k];
            // Exact-zero skip only. A magnitude cutoff would silently drop small
            // terms paired with large entries of B.
            if aik == 0.0 {
                continue;
            }
            for j in 0..n {
                c[i * n + j] += aik * b[k * n + j];
            }
        }
    }
    c
}
349
/// Returns the transpose of the row-major `n × n` matrix `a`: element
/// `(row, col)` of the result is element `(col, row)` of `a`.
fn transpose(n: usize, a: &[f64]) -> Vec<f64> {
    (0..n * n)
        .map(|idx| {
            let (row, col) = (idx / n, idx % n);
            a[col * n + row]
        })
        .collect()
}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic synthetic epochs shaped `[epoch][channel][sample]`.
    ///
    /// Each channel is a phase-shifted sinusoid. For `class == 0` every channel
    /// has amplitude 1; for `class == 1` the amplitude is `1 + 0.5 * channel`.
    fn make_epochs(n_epochs: usize, n_ch: usize, n_s: usize, class: usize) -> Vec<Vec<Vec<f64>>> {
        (0..n_epochs)
            .map(|epoch_idx| {
                (0..n_ch)
                    .map(|ch| {
                        (0..n_s)
                            .map(|s| {
                                let seed = (epoch_idx * 1000 + ch * 100 + s + class * 50000) as f64;
                                (seed * 0.1).sin() * (1.0 + class as f64 * ch as f64 * 0.5)
                            })
                            .collect()
                    })
                    .collect()
            })
            .collect()
    }

    #[test]
    fn test_csp_fit_transform() {
        let class1 = make_epochs(20, 4, 100, 0);
        let class2 = make_epochs(20, 4, 100, 1);

        let mut csp = CSP::new(2);
        csp.fit(&class1, &class2);

        assert!(csp.is_fitted());
        assert_eq!(csp.n_features(), 4);

        let features = csp.transform(&class1[0]);
        assert_eq!(features.len(), 4);
        assert!(features.iter().all(|f| f.is_finite()));
    }

    #[test]
    fn test_csp_transform_all() {
        let class1 = make_epochs(10, 3, 50, 0);
        let class2 = make_epochs(10, 3, 50, 1);

        let mut csp = CSP::new(1);
        csp.fit(&class1, &class2);

        let features1 = csp.transform_all(&class1);
        let features2 = csp.transform_all(&class2);

        assert_eq!(features1.len(), 10);
        assert_eq!(features2.len(), 10);
        assert_eq!(features1[0].len(), 2);

        let mean1: f64 = features1.iter().map(|f| f[0]).sum::<f64>() / 10.0;
        let mean2: f64 = features2.iter().map(|f| f[0]).sum::<f64>() / 10.0;
        assert!(
            (mean1 - mean2).abs() > 1e-6,
            "CSP should separate classes: {mean1} vs {mean2}"
        );
    }

    #[test]
    fn test_csp_not_fitted() {
        let csp = CSP::new(2);
        assert!(!csp.is_fitted());
        assert_eq!(csp.transform(&[vec![1.0, 2.0]]).len(), 0);
    }

    #[test]
    fn test_csp_empty_class_is_noop() {
        let class2 = make_epochs(5, 3, 20, 1);
        let mut csp = CSP::new(1);
        csp.fit(&[], &class2);
        assert!(!csp.is_fitted());
        assert_eq!(csp.n_features(), 0);
        assert!(csp.eigenvalues().is_empty());
    }

    #[test]
    fn test_csp_filters_unit_norm_and_sorted() {
        let class1 = make_epochs(20, 4, 100, 0);
        let class2 = make_epochs(20, 4, 100, 1);

        let mut csp = CSP::new(2);
        csp.fit(&class1, &class2);

        let filters = csp.filters.as_ref().expect("fit stores filters");
        for w in filters {
            let norm = w.iter().map(|v| v * v).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-9, "filter norm {norm}");
        }

        let eig = csp.eigenvalues();
        assert_eq!(eig.len(), csp.n_features());
        assert!(eig.windows(2).all(|pair| pair[0] >= pair[1]), "{eig:?}");
    }

    #[test]
    fn test_mean_covariance_unit_trace_and_symmetric() {
        let epochs = make_epochs(5, 3, 40, 1);
        let cov = mean_covariance(&epochs, 3);
        let trace: f64 = (0..3).map(|i| cov[i * 3 + i]).sum();
        assert!((trace - 1.0).abs() < 1e-12);
        for i in 0..3 {
            for j in 0..3 {
                assert!((cov[i * 3 + j] - cov[j * 3 + i]).abs() < 1e-15);
            }
        }
    }

    #[test]
    fn test_symmetric_eigen() {
        let a = vec![2.0, 1.0, 1.0, 3.0];
        let (vals, _vecs) = symmetric_eigen(2, &a);
        let mut sorted = vals.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert!((sorted[0] - 1.382).abs() < 0.01);
        assert!((sorted[1] - 3.618).abs() < 0.01);
    }

    #[test]
    fn test_symmetric_eigen_reconstructs() {
        let a = [4.0, 1.0, 2.0, 1.0, 3.0, 0.0, 2.0, 0.0, 5.0];
        let (vals, vecs) = symmetric_eigen(3, &a);
        for k in 0..3 {
            for i in 0..3 {
                let av: f64 = (0..3).map(|j| a[i * 3 + j] * vecs[j * 3 + k]).sum();
                assert!((av - vals[k] * vecs[i * 3 + k]).abs() < 1e-9);
            }
        }
    }

    #[test]
    fn test_transpose_and_mat_mul() {
        assert_eq!(transpose(2, &[1.0, 2.0, 3.0, 4.0]), vec![1.0, 3.0, 2.0, 4.0]);
        assert_eq!(
            mat_mul(2, &[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]),
            vec![19.0, 22.0, 43.0, 50.0]
        );
    }
}