1use crate::error::{AprenderError, Result};
43use crate::primitives::{Matrix, Vector};
44
45#[derive(Debug, Clone)]
49pub struct ICA {
50 n_components: usize,
52
53 max_iter: usize,
55
56 tol: f32,
58
59 random_state: Option<u64>,
61
62 whitening_matrix: Option<Matrix<f32>>,
65
66 unmixing_matrix: Option<Matrix<f32>>,
68
69 mean: Option<Vector<f32>>,
71}
72
73impl ICA {
74 #[must_use]
88 pub fn new(n_components: usize) -> Self {
89 Self {
90 n_components,
91 max_iter: 200,
92 tol: 1e-4,
93 random_state: None,
94 whitening_matrix: None,
95 unmixing_matrix: None,
96 mean: None,
97 }
98 }
99
100 #[must_use]
102 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
103 self.max_iter = max_iter;
104 self
105 }
106
107 #[must_use]
109 pub fn with_tolerance(mut self, tol: f32) -> Self {
110 self.tol = tol;
111 self
112 }
113
114 #[must_use]
116 pub fn with_random_state(mut self, seed: u64) -> Self {
117 self.random_state = Some(seed);
118 self
119 }
120
121 #[allow(clippy::similar_names)]
131 pub fn fit(&mut self, x: &Matrix<f32>) -> Result<()> {
132 let n = x.n_rows();
133 let p = x.n_cols();
134
135 if n == 0 || p == 0 {
136 return Err(AprenderError::Other("Data cannot be empty".into()));
137 }
138
139 if self.n_components > p {
140 return Err(AprenderError::Other(format!(
141 "n_components ({}) cannot exceed number of features ({})",
142 self.n_components, p
143 )));
144 }
145
146 let (x_centered, mean) = Self::center_data(x)?;
148 self.mean = Some(mean);
149
150 let (x_whitened, whitening_matrix) = Self::whiten_data(&x_centered, self.n_components)?;
152 self.whitening_matrix = Some(whitening_matrix);
153
154 let unmixing = self.fastica(&x_whitened)?;
156 self.unmixing_matrix = Some(unmixing);
157
158 Ok(())
159 }
160
161 pub fn transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
171 let mean = self
172 .mean
173 .as_ref()
174 .ok_or_else(|| AprenderError::Other("Model not fitted. Call fit() first.".into()))?;
175
176 let whitening = self
177 .whitening_matrix
178 .as_ref()
179 .ok_or_else(|| AprenderError::Other("Model not fitted. Call fit() first.".into()))?;
180
181 let unmixing = self
182 .unmixing_matrix
183 .as_ref()
184 .ok_or_else(|| AprenderError::Other("Model not fitted. Call fit() first.".into()))?;
185
186 let n = x.n_rows();
187 let p = x.n_cols();
188
189 if p != mean.len() {
190 return Err(AprenderError::DimensionMismatch {
191 expected: format!("{} features", mean.len()),
192 actual: format!("{p} features in data"),
193 });
194 }
195
196 let mut x_centered_data = Vec::with_capacity(n * p);
198 for i in 0..n {
199 for j in 0..p {
200 x_centered_data.push(x.get(i, j) - mean[j]);
201 }
202 }
203 let x_centered = Matrix::from_vec(n, p, x_centered_data)
204 .map_err(|e| AprenderError::Other(format!("Centering failed: {e}")))?;
205
206 let x_whitened = x_centered
208 .matmul(whitening)
209 .map_err(|e| AprenderError::Other(format!("Whitening failed: {e}")))?;
210
211 x_whitened
213 .matmul(unmixing)
214 .map_err(|e| AprenderError::Other(format!("Unmixing failed: {e}")))
215 }
216
217 #[allow(clippy::needless_range_loop)]
219 fn center_data(x: &Matrix<f32>) -> Result<(Matrix<f32>, Vector<f32>)> {
220 let n = x.n_rows();
221 let p = x.n_cols();
222
223 let mut means = vec![0.0_f32; p];
225 #[allow(clippy::needless_range_loop)]
226 for j in 0..p {
227 let mut sum = 0.0;
228 for i in 0..n {
229 sum += x.get(i, j);
230 }
231 means[j] = sum / n as f32;
232 }
233
234 let mut centered_data = Vec::with_capacity(n * p);
236 for i in 0..n {
237 for j in 0..p {
238 centered_data.push(x.get(i, j) - means[j]);
239 }
240 }
241
242 let centered = Matrix::from_vec(n, p, centered_data)
243 .map_err(|e| AprenderError::Other(format!("Failed to center data: {e}")))?;
244
245 Ok((centered, Vector::from_vec(means)))
246 }
247
248 #[allow(clippy::similar_names)]
252 #[allow(clippy::needless_range_loop)]
253 fn whiten_data(
254 x_centered: &Matrix<f32>,
255 n_components: usize,
256 ) -> Result<(Matrix<f32>, Matrix<f32>)> {
257 let n = x_centered.n_rows();
258 let p = x_centered.n_cols();
259
260 let xt = x_centered.transpose();
262 let cov = xt
263 .matmul(x_centered)
264 .map_err(|e| AprenderError::Other(format!("Covariance computation failed: {e}")))?;
265
266 let mut cov_data = vec![0.0_f32; p * p];
268 for i in 0..p {
269 for j in 0..p {
270 cov_data[i * p + j] = cov.get(i, j) / n as f32;
271 }
272 }
273 let cov_scaled = Matrix::from_vec(p, p, cov_data)
274 .map_err(|e| AprenderError::Other(format!("Covariance scaling failed: {e}")))?;
275
276 let (eigenvalues, eigenvectors) = Self::eigen_decomposition(&cov_scaled, n_components)?;
278
279 let mut whitening_data = Vec::with_capacity(p * n_components);
282 for j in 0..n_components {
283 let scale = 1.0 / eigenvalues[j].sqrt();
284 for i in 0..p {
285 whitening_data.push(eigenvectors.get(i, j) * scale);
286 }
287 }
288 let whitening_matrix = Matrix::from_vec(p, n_components, whitening_data)
289 .map_err(|e| AprenderError::Other(format!("Whitening matrix creation failed: {e}")))?;
290
291 let x_whitened = x_centered
293 .matmul(&whitening_matrix)
294 .map_err(|e| AprenderError::Other(format!("Data whitening failed: {e}")))?;
295
296 Ok((x_whitened, whitening_matrix))
297 }
298
299 #[allow(clippy::needless_range_loop)]
301 fn eigen_decomposition(matrix: &Matrix<f32>, k: usize) -> Result<(Vec<f32>, Matrix<f32>)> {
302 let n = matrix.n_rows();
303
304 if matrix.n_cols() != n {
305 return Err(AprenderError::Other(
306 "Eigendecomposition requires square matrix".into(),
307 ));
308 }
309
310 let mut eigenvalues = Vec::with_capacity(k);
311 let mut eigenvectors_data = Vec::with_capacity(n * k);
312
313 let mut residual = matrix.clone();
314
315 for _ in 0..k {
316 let (eigenvalue, eigenvector) = Self::power_iteration(&residual, 100)?;
318
319 eigenvalues.push(eigenvalue);
320 eigenvectors_data.extend(eigenvector.as_slice());
321
322 let mut new_residual_data = vec![0.0_f32; n * n];
324 for i in 0..n {
325 for j in 0..n {
326 let deflation = eigenvalue * eigenvector[i] * eigenvector[j];
327 new_residual_data[i * n + j] = residual.get(i, j) - deflation;
328 }
329 }
330 residual = Matrix::from_vec(n, n, new_residual_data)
331 .map_err(|e| AprenderError::Other(format!("Deflation failed: {e}")))?;
332 }
333
334 let eigenvectors = Matrix::from_vec(n, k, eigenvectors_data).map_err(|e| {
335 AprenderError::Other(format!("Eigenvector matrix creation failed: {e}"))
336 })?;
337
338 Ok((eigenvalues, eigenvectors))
339 }
340
341 #[allow(clippy::needless_range_loop)]
343 fn power_iteration(matrix: &Matrix<f32>, max_iter: usize) -> Result<(f32, Vector<f32>)> {
344 let n = matrix.n_rows();
345
346 let mut v = vec![1.0_f32; n];
348 let norm = (v.iter().map(|x| x * x).sum::<f32>()).sqrt();
349 for val in &mut v {
350 *val /= norm;
351 }
352
353 let mut eigenvalue = 0.0;
354
355 for _ in 0..max_iter {
356 let mut v_new = vec![0.0_f32; n];
358 for i in 0..n {
359 let mut sum = 0.0;
360 for j in 0..n {
361 sum += matrix.get(i, j) * v[j];
362 }
363 v_new[i] = sum;
364 }
365
366 let norm = (v_new.iter().map(|x| x * x).sum::<f32>()).sqrt();
368 if norm < 1e-10 {
369 return Err(AprenderError::Other(
370 "Power iteration converged to zero vector".into(),
371 ));
372 }
373
374 for val in &mut v_new {
375 *val /= norm;
376 }
377
378 eigenvalue = norm;
379 v = v_new;
380 }
381
382 Ok((eigenvalue, Vector::from_vec(v)))
383 }
384
385 #[allow(clippy::similar_names)]
389 #[allow(clippy::needless_range_loop)]
390 fn fastica(&self, x_white: &Matrix<f32>) -> Result<Matrix<f32>> {
391 let n = x_white.n_rows();
392 let p = x_white.n_cols(); let mut w_vectors = Vec::with_capacity(p * p);
395
396 for comp in 0..p {
398 let mut w = vec![0.0_f32; p];
400 w[comp % p] = 1.0;
401
402 let norm = (w.iter().map(|x| x * x).sum::<f32>()).sqrt();
404 for val in &mut w {
405 *val /= norm;
406 }
407
408 for _iter in 0..self.max_iter {
410 let mut wtx = vec![0.0_f32; n];
412 for i in 0..n {
413 let mut sum = 0.0;
414 for j in 0..p {
415 sum += w[j] * x_white.get(i, j);
416 }
417 wtx[i] = sum;
418 }
419
420 let mut ex_g = vec![0.0_f32; p];
422 for j in 0..p {
423 let mut sum = 0.0;
424 for i in 0..n {
425 let g = wtx[i].tanh(); sum += x_white.get(i, j) * g;
427 }
428 ex_g[j] = sum / n as f32;
429 }
430
431 let mut eg_prime = 0.0;
433 for i in 0..n {
434 let tanh_val = wtx[i].tanh();
435 eg_prime += 1.0 - tanh_val * tanh_val;
436 }
437 eg_prime /= n as f32;
438
439 let mut w_new = vec![0.0_f32; p];
441 for j in 0..p {
442 w_new[j] = ex_g[j] - eg_prime * w[j];
443 }
444
445 for prev_comp in 0..comp {
447 let mut dot = 0.0;
448 for j in 0..p {
449 dot += w_new[j] * w_vectors[prev_comp * p + j];
450 }
451 for j in 0..p {
452 w_new[j] -= dot * w_vectors[prev_comp * p + j];
453 }
454 }
455
456 let norm = (w_new.iter().map(|x| x * x).sum::<f32>()).sqrt();
458 if norm < 1e-10 {
459 return Err(AprenderError::Other(
460 "FastICA failed: w converged to zero".into(),
461 ));
462 }
463 for val in &mut w_new {
464 *val /= norm;
465 }
466
467 let mut dot = 0.0;
469 for j in 0..p {
470 dot += w[j] * w_new[j];
471 }
472
473 if (1.0 - dot.abs()) < self.tol {
474 w = w_new;
475 break;
476 }
477
478 w = w_new;
479 }
480
481 w_vectors.extend(&w);
483 }
484
485 Matrix::from_vec(p, p, w_vectors)
487 .map_err(|e| AprenderError::Other(format!("Failed to create unmixing matrix: {e}")))
488 }
489}
490
491#[cfg(test)]
492mod tests {
493 use super::*;
494
495 #[test]
496 fn test_ica_basic() {
497 let data = Matrix::from_vec(
499 10,
500 2,
501 vec![
502 1.0, 2.0, 2.0, 1.0, 3.0, 4.0, 4.0, 3.0, 5.0, 6.0, 1.5, 2.5, 2.5, 1.5, 3.5, 4.5,
503 4.5, 3.5, 5.5, 6.5,
504 ],
505 )
506 .expect("Valid matrix");
507
508 let mut ica = ICA::new(2);
509 let result = ica.fit(&data);
510 assert!(result.is_ok(), "ICA should fit");
511
512 let sources = ica.transform(&data).expect("Should transform");
513 assert_eq!(sources.n_rows(), 10);
514 assert_eq!(sources.n_cols(), 2);
515 }
516
517 #[test]
518 fn test_ica_invalid_n_components() {
519 let data = Matrix::from_vec(5, 2, vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0])
520 .expect("Valid matrix");
521
522 let mut ica = ICA::new(3); let result = ica.fit(&data);
524 assert!(result.is_err());
525 }
526
527 #[test]
528 fn test_ica_transform_not_fitted() {
529 let ica = ICA::new(2);
530 let data =
531 Matrix::from_vec(3, 2, vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0]).expect("Valid matrix");
532
533 let result = ica.transform(&data);
534 assert!(result.is_err());
535 }
536
537 #[test]
538 fn test_ica_dimension_mismatch() {
539 let data = Matrix::from_vec(
540 5,
541 3,
542 vec![
543 1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0, 5.0, 6.0, 7.0,
544 ],
545 )
546 .expect("Valid matrix");
547
548 let mut ica = ICA::new(2);
549 ica.fit(&data).expect("Should fit");
550
551 let wrong_data =
553 Matrix::from_vec(3, 2, vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0]).expect("Valid matrix");
554
555 let result = ica.transform(&wrong_data);
556 assert!(result.is_err());
557 }
558
559 #[test]
560 fn test_ica_with_options() {
561 let data = Matrix::from_vec(
563 8,
564 2,
565 vec![
566 1.0, 2.0, 2.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 4.0, 6.0, 5.0, 7.0, 7.0, 8.0, 9.0,
567 ],
568 )
569 .expect("Valid matrix");
570
571 let mut ica = ICA::new(2).with_max_iter(100).with_tolerance(1e-5);
572
573 let result = ica.fit(&data);
574 assert!(result.is_ok());
575 }
576
577 #[test]
578 fn test_center_data() {
579 let data =
580 Matrix::from_vec(3, 2, vec![1.0, 2.0, 2.0, 4.0, 3.0, 6.0]).expect("Valid matrix");
581
582 let (centered, mean) = ICA::center_data(&data).expect("Should center");
583
584 assert_eq!(mean.len(), 2);
585 assert!((mean[0] - 2.0).abs() < 1e-6); assert!((mean[1] - 4.0).abs() < 1e-6); let mut col0_sum = 0.0;
590 let mut col1_sum = 0.0;
591 for i in 0..3 {
592 col0_sum += centered.get(i, 0);
593 col1_sum += centered.get(i, 1);
594 }
595 assert!(col0_sum.abs() < 1e-6);
596 assert!(col1_sum.abs() < 1e-6);
597 }
598
599 #[test]
600 fn test_power_iteration() {
601 let matrix = Matrix::from_vec(2, 2, vec![3.0, 1.0, 1.0, 3.0]).expect("Valid matrix");
603
604 let (eigenvalue, eigenvector) =
605 ICA::power_iteration(&matrix, 100).expect("Should converge");
606
607 assert!((eigenvalue - 4.0).abs() < 0.1, "Eigenvalue should be ~4.0");
609
610 let norm: f32 = eigenvector
612 .as_slice()
613 .iter()
614 .map(|x| x * x)
615 .sum::<f32>()
616 .sqrt();
617 assert!((norm - 1.0).abs() < 1e-6);
618 }
619
620 #[test]
625 fn test_ica_with_random_state() {
626 let data = Matrix::from_vec(
627 8,
628 2,
629 vec![
630 1.0, 2.0, 2.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 4.0, 6.0, 5.0, 7.0, 7.0, 8.0, 9.0,
631 ],
632 )
633 .expect("Valid matrix");
634
635 let mut ica = ICA::new(2).with_random_state(42);
636 let result = ica.fit(&data);
637 assert!(result.is_ok());
638 }
639
640 #[test]
641 fn test_ica_empty_data() {
642 let data = Matrix::from_vec(0, 2, vec![]).expect("Valid empty matrix");
643 let mut ica = ICA::new(2);
644 let result = ica.fit(&data);
645 assert!(result.is_err());
646 }
647
648 #[test]
649 fn test_ica_empty_features() {
650 let data = Matrix::from_vec(5, 0, vec![]).expect("Valid empty matrix");
651 let mut ica = ICA::new(1);
652 let result = ica.fit(&data);
653 assert!(result.is_err());
654 }
655
656 #[test]
657 fn test_ica_single_component() {
658 let data = Matrix::from_vec(
659 6,
660 3,
661 vec![
662 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0, 5.0, 10.0, 15.0, 6.0,
663 12.0, 18.0,
664 ],
665 )
666 .expect("Valid matrix");
667
668 let mut ica = ICA::new(1);
669 let result = ica.fit(&data);
670 assert!(result.is_ok());
671
672 let sources = ica.transform(&data).expect("Should transform");
673 assert_eq!(sources.n_cols(), 1);
674 }
675
676 #[test]
677 fn test_ica_whitening() {
678 let data = Matrix::from_vec(
679 10,
680 2,
681 vec![
682 1.0, 2.0, 2.0, 1.0, 3.0, 4.0, 4.0, 3.0, 5.0, 6.0, 1.5, 2.5, 2.5, 1.5, 3.5, 4.5,
683 4.5, 3.5, 5.5, 6.5,
684 ],
685 )
686 .expect("Valid matrix");
687
688 let (centered, _mean) = ICA::center_data(&data).expect("Should center");
690 let (whitened, _whitening_matrix) = ICA::whiten_data(¢ered, 2).expect("Should whiten");
691
692 assert_eq!(whitened.n_rows(), 10);
693 assert_eq!(whitened.n_cols(), 2);
694 }
695
696 #[test]
697 fn test_ica_eigen_decomposition() {
698 let matrix = Matrix::from_vec(3, 3, vec![4.0, 1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 2.0])
700 .expect("Valid matrix");
701
702 let (eigenvalues, eigenvectors) =
703 ICA::eigen_decomposition(&matrix, 2).expect("Should decompose");
704
705 assert_eq!(eigenvalues.len(), 2);
706 assert_eq!(eigenvectors.n_rows(), 3);
707 assert_eq!(eigenvectors.n_cols(), 2);
708 }
709
710 #[test]
711 fn test_ica_eigen_decomposition_non_square() {
712 let matrix =
713 Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("Valid matrix");
714
715 let result = ICA::eigen_decomposition(&matrix, 2);
716 assert!(result.is_err());
717 }
718
719 #[test]
720 fn test_ica_clone() {
721 let ica = ICA::new(3)
722 .with_max_iter(100)
723 .with_tolerance(1e-5)
724 .with_random_state(42);
725
726 let cloned = ica.clone();
727 assert_eq!(format!("{:?}", ica), format!("{:?}", cloned));
729 }
730
731 #[test]
732 fn test_ica_debug() {
733 let ica = ICA::new(2);
734 let debug_str = format!("{:?}", ica);
735 assert!(debug_str.contains("ICA"));
736 assert!(debug_str.contains("n_components"));
737 }
738
739 #[test]
740 fn test_ica_fit_then_transform_new_data() {
741 let training_data = Matrix::from_vec(
742 10,
743 2,
744 vec![
745 1.0, 2.0, 2.0, 1.0, 3.0, 4.0, 4.0, 3.0, 5.0, 6.0, 1.5, 2.5, 2.5, 1.5, 3.5, 4.5,
746 4.5, 3.5, 5.5, 6.5,
747 ],
748 )
749 .expect("Valid matrix");
750
751 let mut ica = ICA::new(2);
752 ica.fit(&training_data).expect("Should fit");
753
754 let new_data =
756 Matrix::from_vec(5, 2, vec![2.0, 3.0, 3.0, 2.0, 4.0, 5.0, 5.0, 4.0, 6.0, 7.0])
757 .expect("Valid matrix");
758
759 let transformed = ica.transform(&new_data).expect("Should transform");
760 assert_eq!(transformed.n_rows(), 5);
761 assert_eq!(transformed.n_cols(), 2);
762 }
763
764 #[test]
765 fn test_ica_3d_data() {
766 let data = Matrix::from_vec(
769 12,
770 3,
771 vec![
772 1.0, 5.0, 2.0, 4.0, 2.0, 6.0, 3.0, 7.0, 1.0, 6.0, 3.0, 4.0, 2.0, 8.0, 5.0, 5.0, 1.0, 3.0, 1.5,
774 6.0, 2.5, 4.5, 2.5, 5.5, 3.5, 6.5, 1.5, 5.5, 4.5, 4.5, 2.5, 7.5, 6.5, 6.5, 1.5,
775 3.5,
776 ],
777 )
778 .expect("Valid matrix");
779
780 let mut ica = ICA::new(2); let result = ica.fit(&data);
782 assert!(result.is_ok());
783
784 let sources = ica.transform(&data).expect("Should transform");
785 assert_eq!(sources.n_rows(), 12);
786 assert_eq!(sources.n_cols(), 2);
787 }
788
789 #[test]
790 fn test_ica_strict_tolerance() {
791 let data = Matrix::from_vec(
792 8,
793 2,
794 vec![
795 1.0, 2.0, 2.0, 3.0, 3.0, 5.0, 4.0, 6.0, 5.0, 4.0, 6.0, 5.0, 7.0, 7.0, 8.0, 9.0,
796 ],
797 )
798 .expect("Valid matrix");
799
800 let mut ica = ICA::new(2).with_tolerance(1e-8).with_max_iter(500);
802 let result = ica.fit(&data);
803 assert!(result.is_ok());
804 }
805}