1use crate::error::{IoError, IoResult};
10use scirs2_core::ndarray::{Array1, Array2, Axis};
11use scirs2_core::random::thread_rng;
12
13pub struct FastICA {
18 n_components: usize,
20 max_iter: usize,
22 tolerance: f32,
24 nonlinearity: Nonlinearity,
26}
27
/// Contrast function `g` (together with its derivative `g'`) used by the
/// FastICA fixed-point update.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Nonlinearity {
    /// `g(u) = tanh(u)`, the derivative of `log cosh(u)`; the default chosen by `FastICA::new`.
    LogCosh,
    /// `g(u) = u * exp(-u^2 / 2)`.
    Exp,
    /// `g(u) = u^3`.
    Cube,
}
38
39impl FastICA {
40 pub fn new(n_components: usize, max_iter: Option<usize>, tolerance: Option<f32>) -> Self {
47 Self {
48 n_components,
49 max_iter: max_iter.unwrap_or(200),
50 tolerance: tolerance.unwrap_or(1e-4),
51 nonlinearity: Nonlinearity::LogCosh,
52 }
53 }
54
55 pub fn with_nonlinearity(mut self, nonlinearity: Nonlinearity) -> Self {
57 self.nonlinearity = nonlinearity;
58 self
59 }
60
61 pub fn fit_transform(&self, mixed: &Array2<f32>) -> IoResult<(Array2<f32>, Array2<f32>)> {
70 let (n_samples, n_signals) = mixed.dim();
71
72 if n_samples < 2 || n_signals < 2 {
73 return Err(IoError::SignalError(
74 "Need at least 2 samples and 2 signals".into(),
75 ));
76 }
77
78 if self.n_components > n_signals {
79 return Err(IoError::SignalError(
80 "n_components cannot exceed n_signals".into(),
81 ));
82 }
83
84 let mean = mixed
86 .mean_axis(Axis(0))
87 .expect("Mean axis computation must succeed");
88 let centered = mixed - &mean.view().insert_axis(Axis(0));
89
90 let (whitened, whitening_matrix) = Self::whiten(¢ered)?;
92
93 let unmixing = self.fastica_core(&whitened)?;
95
96 let sources = whitened.dot(&unmixing.t());
98
99 let full_unmixing = unmixing.dot(&whitening_matrix);
101
102 Ok((sources, full_unmixing))
103 }
104
105 fn fastica_core(&self, whitened: &Array2<f32>) -> IoResult<Array2<f32>> {
107 let n_components = self.n_components;
108 let mut rng = thread_rng();
109
110 let mut w = Array2::from_shape_fn((n_components, whitened.ncols()), |_| {
112 rng.gen_range(-1.0..1.0)
113 });
114
115 Self::gram_schmidt(&mut w);
117
118 for _iter in 0..self.max_iter {
120 let w_old = w.clone();
121
122 for i in 0..n_components {
124 let mut w_i = w.row(i).to_owned();
125
126 let wx = whitened.dot(&w_i);
128 let (g_wx, gp_wx) = self.apply_nonlinearity(&wx);
129
130 let eg = whitened.t().dot(&g_wx) / whitened.nrows() as f32;
131 let egp = gp_wx.mean().unwrap();
132
133 w_i = eg - &w_i * egp;
135
136 for (j, val) in w_i.iter().enumerate() {
138 w[[i, j]] = *val;
139 }
140
141 for j in 0..i {
143 let w_j = w.row(j).to_owned();
144 let dot = w_i.dot(&w_j);
145 w_i = w_i - &w_j * dot;
146 }
147
148 let norm = w_i.iter().map(|x| x * x).sum::<f32>().sqrt();
150 if norm > 1e-10 {
151 w_i /= norm;
152 }
153
154 for (j, val) in w_i.iter().enumerate() {
156 w[[i, j]] = *val;
157 }
158 }
159
160 let mut max_diff = 0.0f32;
162 for i in 0..n_components {
163 for j in 0..w.ncols() {
164 let diff = (w[[i, j]] - w_old[[i, j]]).abs();
165 max_diff = max_diff.max(diff);
166 }
167 }
168
169 if max_diff < self.tolerance {
170 break;
171 }
172 }
173
174 Ok(w)
175 }
176
177 fn apply_nonlinearity(&self, x: &Array1<f32>) -> (Array1<f32>, Array1<f32>) {
179 match self.nonlinearity {
180 Nonlinearity::LogCosh => {
181 let alpha = 1.0;
182 let g = x.mapv(|v| (alpha * v).tanh());
183 let gp = x.mapv(|v| alpha * (1.0 - (alpha * v).tanh().powi(2)));
184 (g, gp)
185 }
186 Nonlinearity::Exp => {
187 let g = x.mapv(|v| v * (-v * v / 2.0).exp());
188 let gp = x.mapv(|v| (1.0 - v * v) * (-v * v / 2.0).exp());
189 (g, gp)
190 }
191 Nonlinearity::Cube => {
192 let g = x.mapv(|v| v.powi(3));
193 let gp = x.mapv(|v| 3.0 * v * v);
194 (g, gp)
195 }
196 }
197 }
198
199 fn whiten(data: &Array2<f32>) -> IoResult<(Array2<f32>, Array2<f32>)> {
201 let n_samples = data.nrows();
202
203 let cov = data.t().dot(data) / n_samples as f32;
205
206 let (eigenvalues, eigenvectors) = Self::simple_eig(&cov)?;
209
210 let mut whitening = eigenvectors.t().to_owned();
212 for i in 0..eigenvalues.len() {
213 let scale = 1.0 / (eigenvalues[i].max(1e-10).sqrt());
214 for j in 0..whitening.ncols() {
215 whitening[[i, j]] *= scale;
216 }
217 }
218
219 let whitened = data.dot(&whitening.t());
221
222 Ok((whitened, whitening))
223 }
224
225 fn simple_eig(matrix: &Array2<f32>) -> IoResult<(Vec<f32>, Array2<f32>)> {
227 let n = matrix.nrows();
228 let mut eigenvalues = Vec::new();
229 let mut eigenvectors = Array2::zeros((n, n));
230 let mut remaining = matrix.clone();
231
232 for k in 0..n {
233 let mut v = Array1::from_shape_fn(n, |_| thread_rng().gen_range(-1.0..1.0));
235 v = &v / v.iter().map(|x| x * x).sum::<f32>().sqrt();
236
237 for _ in 0..100 {
238 let v_new = remaining.dot(&v);
239 let norm = v_new.iter().map(|x| x * x).sum::<f32>().sqrt();
240 if norm < 1e-10 {
241 break;
242 }
243 v = &v_new / norm;
244 }
245
246 let av = remaining.dot(&v);
248 let eigenvalue = av.dot(&v);
249 eigenvalues.push(eigenvalue);
250
251 for i in 0..n {
253 eigenvectors[[i, k]] = v[i];
254 }
255
256 let vv = v
258 .clone()
259 .insert_axis(Axis(1))
260 .dot(&v.clone().insert_axis(Axis(0)));
261 remaining = &remaining - &(&vv * eigenvalue);
262 }
263
264 Ok((eigenvalues, eigenvectors))
265 }
266
267 fn gram_schmidt(matrix: &mut Array2<f32>) {
269 let n_rows = matrix.nrows();
270
271 for i in 0..n_rows {
272 for j in 0..i {
274 let dot: f32 = (0..matrix.ncols())
275 .map(|k| matrix[[i, k]] * matrix[[j, k]])
276 .sum();
277
278 for k in 0..matrix.ncols() {
279 matrix[[i, k]] -= dot * matrix[[j, k]];
280 }
281 }
282
283 let norm: f32 = (0..matrix.ncols())
285 .map(|k| matrix[[i, k]] * matrix[[i, k]])
286 .sum::<f32>()
287 .sqrt();
288
289 if norm > 1e-10 {
290 for k in 0..matrix.ncols() {
291 matrix[[i, k]] /= norm;
292 }
293 }
294 }
295 }
296}
297
/// Non-negative matrix factorization `V ≈ W · H` using multiplicative updates.
#[derive(Debug, Clone)]
pub struct NMF {
    /// Inner dimension of the factorization (columns of `W`, rows of `H`).
    n_components: usize,
    /// Maximum number of update iterations.
    max_iter: usize,
    /// `fit_transform` stops once the Frobenius reconstruction error changes by
    /// less than this between iterations.
    tolerance: f32,
}
310
311impl NMF {
312 pub fn new(n_components: usize, max_iter: Option<usize>, tolerance: Option<f32>) -> Self {
319 Self {
320 n_components,
321 max_iter: max_iter.unwrap_or(200),
322 tolerance: tolerance.unwrap_or(1e-4),
323 }
324 }
325
326 pub fn fit_transform(&self, v: &Array2<f32>) -> IoResult<(Array2<f32>, Array2<f32>)> {
335 let (n_samples, n_features) = v.dim();
336
337 if n_samples < 1 || n_features < 1 {
338 return Err(IoError::SignalError("Empty matrix".into()));
339 }
340
341 if v.iter().any(|&x| x < 0.0) {
343 return Err(IoError::SignalError("Matrix must be non-negative".into()));
344 }
345
346 let mut rng = thread_rng();
348 let mut w =
349 Array2::from_shape_fn((n_samples, self.n_components), |_| rng.gen_range(0.0..1.0));
350 let mut h =
351 Array2::from_shape_fn((self.n_components, n_features), |_| rng.gen_range(0.0..1.0));
352
353 let eps = 1e-10;
354 let mut prev_error = f32::MAX;
355
356 for _iter in 0..self.max_iter {
358 let wt_v = w.t().dot(v);
360 let wt_w_h = w.t().dot(&w).dot(&h);
361
362 for i in 0..h.nrows() {
363 for j in 0..h.ncols() {
364 h[[i, j]] *= wt_v[[i, j]] / (wt_w_h[[i, j]] + eps);
365 }
366 }
367
368 let v_ht = v.dot(&h.t());
370 let w_h_ht = w.dot(&h).dot(&h.t());
371
372 for i in 0..w.nrows() {
373 for j in 0..w.ncols() {
374 w[[i, j]] *= v_ht[[i, j]] / (w_h_ht[[i, j]] + eps);
375 }
376 }
377
378 let wh = w.dot(&h);
380 let error: f32 = v
381 .iter()
382 .zip(wh.iter())
383 .map(|(&a, &b)| (a - b).powi(2))
384 .sum::<f32>()
385 .sqrt();
386
387 if (prev_error - error).abs() < self.tolerance {
389 break;
390 }
391 prev_error = error;
392 }
393
394 Ok((w, h))
395 }
396
397 pub fn transform(&self, v: &Array2<f32>, h: &Array2<f32>) -> IoResult<Array2<f32>> {
403 let (n_samples, _n_features) = v.dim();
404 let mut rng = thread_rng();
405
406 let mut w =
408 Array2::from_shape_fn((n_samples, self.n_components), |_| rng.gen_range(0.0..1.0));
409
410 let eps = 1e-10;
411
412 for _ in 0..self.max_iter {
414 let v_ht = v.dot(&h.t());
415 let w_h_ht = w.dot(h).dot(&h.t());
416
417 for i in 0..w.nrows() {
418 for j in 0..w.ncols() {
419 w[[i, j]] *= v_ht[[i, j]] / (w_h_ht[[i, j]] + eps);
420 }
421 }
422 }
423
424 Ok(w)
425 }
426}
427
/// Principal component analysis via eigen-decomposition of the sample covariance.
#[derive(Debug, Clone)]
pub struct PCA {
    /// Number of leading principal components to keep.
    n_components: usize,
}
436
437impl PCA {
438 pub fn new(n_components: usize) -> Self {
440 Self { n_components }
441 }
442
443 pub fn fit_transform(
453 &self,
454 data: &Array2<f32>,
455 ) -> IoResult<(Array2<f32>, Array2<f32>, Vec<f32>)> {
456 let (n_samples, n_features) = data.dim();
457
458 if self.n_components > n_features {
459 return Err(IoError::SignalError(
460 "n_components cannot exceed n_features".into(),
461 ));
462 }
463
464 let mean = data
466 .mean_axis(Axis(0))
467 .expect("Mean axis computation must succeed");
468 let centered = data - &mean.view().insert_axis(Axis(0));
469
470 let cov = centered.t().dot(¢ered) / n_samples as f32;
472
473 let (eigenvalues, eigenvectors) = FastICA::simple_eig(&cov)?;
475
476 let mut indices: Vec<usize> = (0..eigenvalues.len()).collect();
478 indices.sort_by(|&a, &b| {
479 eigenvalues[b]
480 .partial_cmp(&eigenvalues[a])
481 .unwrap_or(std::cmp::Ordering::Equal)
482 });
483
484 let mut components = Array2::zeros((self.n_components, n_features));
486 let mut explained_var = Vec::new();
487
488 for i in 0..self.n_components {
489 let idx = indices[i];
490 explained_var.push(eigenvalues[idx]);
491
492 for j in 0..n_features {
493 components[[i, j]] = eigenvectors[[j, idx]];
494 }
495 }
496
497 let transformed = centered.dot(&components.t());
499
500 Ok((transformed, components, explained_var))
501 }
502
503 pub fn transform(&self, data: &Array2<f32>, components: &Array2<f32>) -> Array2<f32> {
505 let mean = data
507 .mean_axis(Axis(0))
508 .expect("Mean axis computation must succeed");
509 let centered = data - &mean.view().insert_axis(Axis(0));
510
511 centered.dot(&components.t())
513 }
514
515 pub fn inverse_transform(
517 &self,
518 transformed: &Array2<f32>,
519 components: &Array2<f32>,
520 ) -> Array2<f32> {
521 transformed.dot(components)
522 }
523}
524
/// Blind source separation based on a single time-lagged covariance.
#[derive(Debug, Clone)]
pub struct TemporalDecorrelation {
    /// Lag, in samples, of the covariance used for separation.
    tau: usize,
}
530
531impl TemporalDecorrelation {
532 pub fn new(tau: usize) -> Self {
534 Self { tau }
535 }
536
537 pub fn separate(&self, mixed: &Array2<f32>) -> IoResult<Array2<f32>> {
541 let (n_samples, n_channels) = mixed.dim();
542
543 if n_samples <= self.tau {
544 return Err(IoError::SignalError("Insufficient samples".into()));
545 }
546
547 let mut cov_delay = Array2::zeros((n_channels, n_channels));
549
550 for i in 0..(n_samples - self.tau) {
551 for j in 0..n_channels {
552 for k in 0..n_channels {
553 cov_delay[[j, k]] += mixed[[i, j]] * mixed[[i + self.tau, k]];
554 }
555 }
556 }
557
558 cov_delay /= (n_samples - self.tau) as f32;
559
560 let (_, eigenvectors) = FastICA::simple_eig(&cov_delay)?;
562
563 let separated = mixed.dot(&eigenvectors);
565
566 Ok(separated)
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, arr2, Array2};

    #[test]
    fn test_fastica_basic() {
        let mixed = arr2(&[[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]]);

        let ica = FastICA::new(2, None, None);
        let result = ica.fit_transform(&mixed);

        assert!(result.is_ok());
        let (sources, unmixing) = result.unwrap();
        assert_eq!(sources.dim(), (4, 2));
        assert_eq!(unmixing.dim(), (2, 2));
    }

    #[test]
    fn test_fastica_rejects_invalid_input() {
        // More components than signals.
        let mixed = arr2(&[[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]);
        assert!(FastICA::new(3, None, None).fit_transform(&mixed).is_err());

        // A single sample cannot be whitened.
        let single_sample = arr2(&[[1.0, 2.0]]);
        assert!(FastICA::new(2, None, None)
            .fit_transform(&single_sample)
            .is_err());
    }

    #[test]
    fn test_fastica_nonlinearity_values() {
        let cube = FastICA::new(2, None, None).with_nonlinearity(Nonlinearity::Cube);
        let (g, gp) = cube.apply_nonlinearity(&arr1(&[0.0f32, 2.0]));
        assert_eq!(g, arr1(&[0.0, 8.0]));
        assert_eq!(gp, arr1(&[0.0, 12.0]));

        // Default LogCosh: tanh(0) = 0, derivative 1 at the origin.
        let (g, gp) = FastICA::new(2, None, None).apply_nonlinearity(&arr1(&[0.0f32]));
        assert_eq!(g[0], 0.0);
        assert_eq!(gp[0], 1.0);

        let exp = FastICA::new(2, None, None).with_nonlinearity(Nonlinearity::Exp);
        let (g, gp) = exp.apply_nonlinearity(&arr1(&[0.0f32]));
        assert_eq!(g[0], 0.0);
        assert_eq!(gp[0], 1.0);
    }

    #[test]
    fn test_nmf_basic() {
        let v = arr2(&[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]);

        let nmf = NMF::new(2, Some(50), None);
        let result = nmf.fit_transform(&v);

        assert!(result.is_ok());
        let (w, h) = result.unwrap();
        assert_eq!(w.dim(), (3, 2));
        assert_eq!(h.dim(), (2, 3));

        assert!(w.iter().all(|&x| x >= 0.0));
        assert!(h.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_nmf_negative_input() {
        let v = arr2(&[[1.0, -2.0], [2.0, 3.0]]);

        let nmf = NMF::new(2, None, None);
        let result = nmf.fit_transform(&v);

        assert!(result.is_err());
    }

    #[test]
    fn test_nmf_empty_input() {
        let empty = Array2::<f32>::zeros((0, 3));
        assert!(NMF::new(2, None, None).fit_transform(&empty).is_err());
    }

    #[test]
    fn test_pca_basic() {
        let data = arr2(&[
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
        ]);

        let pca = PCA::new(2);
        let result = pca.fit_transform(&data);

        assert!(result.is_ok());
        let (transformed, components, explained_var) = result.unwrap();
        assert_eq!(transformed.dim(), (4, 2));
        assert_eq!(components.dim(), (2, 3));
        assert_eq!(explained_var.len(), 2);
    }

    #[test]
    fn test_pca_explained_variance_and_unit_components() {
        // Every centered column is [-1.5, -0.5, 0.5, 1.5], so the covariance is
        // 1.25 * ones(3, 3) with top eigenvalue 3 * 1.25 = 3.75.
        let data = arr2(&[
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
        ]);
        let (_, components, explained_var) = PCA::new(2).fit_transform(&data).unwrap();

        assert!((explained_var[0] - 3.75).abs() < 1e-3);
        assert!(explained_var[0] >= explained_var[1]);
        for row in components.outer_iter() {
            let norm = row.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-3);
        }
    }

    #[test]
    fn test_pca_rejects_too_many_components() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(PCA::new(3).fit_transform(&data).is_err());
    }

    #[test]
    fn test_pca_reconstruction() {
        let data = arr2(&[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]);

        let pca = PCA::new(2);
        let (transformed, components, _) = pca.fit_transform(&data).unwrap();

        let reconstructed = pca.inverse_transform(&transformed, &components);
        assert_eq!(reconstructed.nrows(), 3);
    }

    #[test]
    fn test_temporal_decorrelation() {
        let mixed = arr2(&[[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]]);

        let td = TemporalDecorrelation::new(1);
        let result = td.separate(&mixed);

        assert!(result.is_ok());
        let separated = result.unwrap();
        assert_eq!(separated.dim(), (5, 2));
    }

    #[test]
    fn test_temporal_decorrelation_insufficient_samples() {
        let mixed = arr2(&[[1.0, 2.0], [2.0, 3.0]]);
        assert!(TemporalDecorrelation::new(2).separate(&mixed).is_err());
    }
}