1use ferrolearn_core::error::FerroError;
45use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
46use ferrolearn_core::traits::{Fit, Transform};
47use ndarray::{Array1, Array2};
48use num_traits::Float;
49use rand::SeedableRng;
50use rand_distr::{Distribution, StandardNormal};
51
/// Configuration for a factor-analysis model that has not been fitted yet.
///
/// `F` is the floating-point element type of the data the model is fitted
/// on; the struct stores no value of that type, only a `PhantomData` marker.
#[derive(Debug, Clone)]
pub struct FactorAnalysis<F> {
    /// Number of latent factors to extract.
    n_components: usize,
    /// Upper bound on the number of EM iterations performed by `fit`.
    max_iter: usize,
    /// Threshold on the absolute change in log-likelihood between two
    /// consecutive iterations below which `fit` stops.
    tol: f64,
    /// Seed for the RNG that initialises the loadings; when `None`, `fit`
    /// falls back to the fixed seed 42.
    random_state: Option<u64>,
    /// Ties the otherwise unused element type `F` to the struct.
    _marker: std::marker::PhantomData<F>,
}
76
77impl<F: Float + Send + Sync + 'static> FactorAnalysis<F> {
78 #[must_use]
82 pub fn new(n_components: usize) -> Self {
83 Self {
84 n_components,
85 max_iter: 1000,
86 tol: 1e-3,
87 random_state: None,
88 _marker: std::marker::PhantomData,
89 }
90 }
91
92 #[must_use]
94 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
95 self.max_iter = max_iter;
96 self
97 }
98
99 #[must_use]
101 pub fn with_tol(mut self, tol: f64) -> Self {
102 self.tol = tol;
103 self
104 }
105
106 #[must_use]
108 pub fn with_random_state(mut self, seed: u64) -> Self {
109 self.random_state = Some(seed);
110 self
111 }
112
113 #[must_use]
115 pub fn n_components(&self) -> usize {
116 self.n_components
117 }
118}
119
120impl<F: Float + Send + Sync + 'static> Default for FactorAnalysis<F> {
121 fn default() -> Self {
122 Self::new(1)
123 }
124}
125
/// Result of fitting a [`FactorAnalysis`] model.
#[derive(Debug, Clone)]
pub struct FittedFactorAnalysis<F> {
    /// Loading matrix `W`; `fit` allocates it as `(n_features, n_components)`.
    components: Array2<F>,

    /// Per-feature noise variances (the diagonal of `Ψ`), one entry per feature.
    noise_variance: Array1<F>,

    /// Per-feature mean of the training data, subtracted before projecting.
    mean: Array1<F>,

    /// Number of EM iterations that were actually run.
    n_iter: usize,

    /// Log-likelihood computed in the last EM iteration.
    log_likelihood: F,
}
151
152impl<F: Float + Send + Sync + 'static> FittedFactorAnalysis<F> {
153 #[must_use]
155 pub fn components(&self) -> &Array2<F> {
156 &self.components
157 }
158
159 #[must_use]
161 pub fn noise_variance(&self) -> &Array1<F> {
162 &self.noise_variance
163 }
164
165 #[must_use]
167 pub fn mean(&self) -> &Array1<F> {
168 &self.mean
169 }
170
171 #[must_use]
173 pub fn n_iter(&self) -> usize {
174 self.n_iter
175 }
176
177 #[must_use]
179 pub fn log_likelihood(&self) -> F {
180 self.log_likelihood
181 }
182
183 pub fn inverse_transform(&self, z: &Array2<F>) -> Result<Array2<F>, FerroError> {
196 let n_components = self.components.ncols();
197 if z.ncols() != n_components {
198 return Err(FerroError::ShapeMismatch {
199 expected: vec![z.nrows(), n_components],
200 actual: vec![z.nrows(), z.ncols()],
201 context: "FittedFactorAnalysis::inverse_transform".into(),
202 });
203 }
204 let mut result = z.dot(&self.components.t());
205 for mut row in result.rows_mut() {
206 for (v, &m) in row.iter_mut().zip(self.mean.iter()) {
207 *v = *v + m;
208 }
209 }
210 Ok(result)
211 }
212}
213
214fn cholesky_inv<F: Float>(a: &Array2<F>) -> Result<Array2<F>, FerroError> {
220 let n = a.nrows();
221 let mut l = Array2::<F>::zeros((n, n));
223 for i in 0..n {
224 for j in 0..=i {
225 let mut s = a[[i, j]];
226 for k in 0..j {
227 s = s - l[[i, k]] * l[[j, k]];
228 }
229 if i == j {
230 if s <= F::zero() {
231 s = F::from(1e-10).unwrap();
233 }
234 l[[i, j]] = s.sqrt();
235 } else {
236 l[[i, j]] = s / l[[j, j]];
237 }
238 }
239 }
240 let mut l_inv = Array2::<F>::zeros((n, n));
242 for j in 0..n {
243 l_inv[[j, j]] = F::one() / l[[j, j]];
244 for i in (j + 1)..n {
245 let mut s = F::zero();
246 for k in j..i {
247 s = s + l[[i, k]] * l_inv[[k, j]];
248 }
249 l_inv[[i, j]] = -s / l[[i, i]];
250 }
251 }
252 let mut inv = Array2::<F>::zeros((n, n));
254 for i in 0..n {
255 for j in 0..n {
256 let mut s = F::zero();
257 let start = i.max(j);
258 for k in start..n {
259 s = s + l_inv[[k, i]] * l_inv[[k, j]];
260 }
261 inv[[i, j]] = s;
262 }
263 }
264 Ok(inv)
265}
266
267fn compute_log_likelihood<F: Float + Send + Sync + 'static>(
272 x_centered: &Array2<F>,
273 w: &Array2<F>,
274 psi: &Array1<F>,
275) -> F {
276 let (n, p) = x_centered.dim();
277 let k = w.ncols();
278 let two_pi = F::from(2.0 * std::f64::consts::PI).unwrap();
282 let n_f = F::from(n).unwrap();
283 let p_f = F::from(p).unwrap();
284
285 let mut wtpsiw = Array2::<F>::zeros((k, k));
287 for i in 0..k {
288 for j in 0..k {
289 let mut s = F::zero();
290 for d in 0..p {
291 s = s + w[[d, i]] * w[[d, j]] / psi[d];
292 }
293 wtpsiw[[i, j]] = s;
294 }
295 }
296 for i in 0..k {
298 wtpsiw[[i, i]] = wtpsiw[[i, i]] + F::one();
299 }
300 let mut log_det_inner = F::zero();
302 {
303 let mut l = Array2::<F>::zeros((k, k));
304 for i in 0..k {
305 for j in 0..=i {
306 let mut s = wtpsiw[[i, j]];
307 for kk in 0..j {
308 s = s - l[[i, kk]] * l[[j, kk]];
309 }
310 if i == j {
311 s = if s > F::zero() {
312 s
313 } else {
314 F::from(1e-30).unwrap()
315 };
316 l[[i, j]] = s.sqrt();
317 log_det_inner = log_det_inner + l[[i, j]].ln();
318 } else {
319 l[[i, j]] = s / l[[j, j]];
320 }
321 }
322 }
323 log_det_inner = log_det_inner * F::from(2.0).unwrap();
324 }
325 let log_det_psi: F = psi
326 .iter()
327 .copied()
328 .map(|v| {
329 let v_clamped = if v > F::zero() {
330 v
331 } else {
332 F::from(1e-30).unwrap()
333 };
334 v_clamped.ln()
335 })
336 .fold(F::zero(), |a, b| a + b);
337 let log_det_sigma = log_det_inner + log_det_psi;
338
339 let m_inv = match cholesky_inv(&wtpsiw) {
350 Ok(inv) => inv,
351 Err(_) => return F::neg_infinity(),
352 };
353
354 let mut trace_sum = F::zero();
355 for i in 0..n {
356 let mut psi_inv_x = Array1::<F>::zeros(p);
358 let mut xpsiinvx = F::zero();
359 for d in 0..p {
360 psi_inv_x[d] = x_centered[[i, d]] / psi[d];
361 xpsiinvx = xpsiinvx + x_centered[[i, d]] * psi_inv_x[d];
362 }
363 let mut wtpx = Array1::<F>::zeros(k);
365 for kk in 0..k {
366 let mut s = F::zero();
367 for d in 0..p {
368 s = s + w[[d, kk]] * psi_inv_x[d];
369 }
370 wtpx[kk] = s;
371 }
372 let mut quad = F::zero();
374 for ii in 0..k {
375 let mut s = F::zero();
376 for jj in 0..k {
377 s = s + m_inv[[ii, jj]] * wtpx[jj];
378 }
379 quad = quad + wtpx[ii] * s;
380 }
381 trace_sum = trace_sum + xpsiinvx - quad;
382 }
383 let trace_term = trace_sum / n_f;
384
385 let half = F::from(0.5).unwrap();
387 -n_f * half * (p_f * two_pi.ln() + log_det_sigma + trace_term)
388}
389
390impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for FactorAnalysis<F> {
395 type Fitted = FittedFactorAnalysis<F>;
396 type Error = FerroError;
397
398 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedFactorAnalysis<F>, FerroError> {
406 let (n_samples, n_features) = x.dim();
407
408 if self.n_components == 0 {
409 return Err(FerroError::InvalidParameter {
410 name: "n_components".into(),
411 reason: "must be at least 1".into(),
412 });
413 }
414 if self.n_components > n_features {
415 return Err(FerroError::InvalidParameter {
416 name: "n_components".into(),
417 reason: format!(
418 "n_components ({}) exceeds n_features ({})",
419 self.n_components, n_features
420 ),
421 });
422 }
423 if n_samples < 2 {
424 return Err(FerroError::InsufficientSamples {
425 required: 2,
426 actual: n_samples,
427 context: "FactorAnalysis requires at least 2 samples".into(),
428 });
429 }
430
431 let k = self.n_components;
432 let p = n_features;
433 let n_f = F::from(n_samples).unwrap();
434
435 let mut mean = Array1::<F>::zeros(p);
437 for j in 0..p {
438 let s = x.column(j).iter().copied().fold(F::zero(), |a, b| a + b);
439 mean[j] = s / n_f;
440 }
441 let mut xc = x.to_owned();
442 for mut row in xc.rows_mut() {
443 for (v, &m) in row.iter_mut().zip(mean.iter()) {
444 *v = *v - m;
445 }
446 }
447
448 let seed = self.random_state.unwrap_or(42);
450 let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(seed);
451 let std_normal = StandardNormal;
452 let mut w = Array2::<F>::zeros((p, k));
453 let scale = F::from(0.01).unwrap();
454 for i in 0..p {
455 for j in 0..k {
456 let v: f64 = std_normal.sample(&mut rng);
457 w[[i, j]] = F::from(v).unwrap() * scale;
458 }
459 }
460 let mut psi = Array1::<F>::from_elem(p, F::one());
461
462 let mut prev_ll = F::neg_infinity();
463 let mut n_iter = 0usize;
464 let tol_f = F::from(self.tol).unwrap();
465
466 for iter in 0..self.max_iter {
467 let mut wzw = Array2::<F>::zeros((k, k));
470 for i in 0..k {
471 for j in 0..k {
472 let mut s = F::zero();
473 for d in 0..p {
474 s = s + w[[d, i]] * w[[d, j]] / psi[d];
475 }
476 wzw[[i, j]] = s;
477 }
478 }
479 for i in 0..k {
480 wzw[[i, i]] = wzw[[i, i]] + F::one();
481 }
482 let sigma_z = cholesky_inv(&wzw).map_err(|_| FerroError::NumericalInstability {
483 message: "FactorAnalysis: (I + W^T Ψ⁻¹ W) is singular".into(),
484 })?;
485
486 let mut beta = Array2::<F>::zeros((k, p));
488 for i in 0..k {
489 for d in 0..p {
490 let mut s = F::zero();
491 for j in 0..k {
492 s = s + sigma_z[[i, j]] * w[[d, j]];
493 }
494 beta[[i, d]] = s / psi[d];
495 }
496 }
497
498 let ez = beta.dot(&xc.t()); let ezz_t_sum = sigma_z.mapv(|v| v * n_f) + ez.dot(&ez.t()); let xc_ez_t = xc.t().dot(&ez.t()); let ezz_t_inv =
515 cholesky_inv(&ezz_t_sum).map_err(|_| FerroError::NumericalInstability {
516 message: "FactorAnalysis: E[ZZ^T] is singular in M-step".into(),
517 })?;
518
519 let w_new = xc_ez_t.dot(&ezz_t_inv); let mut psi_new = Array1::<F>::zeros(p);
528 for d in 0..p {
529 let var_d = xc
531 .column(d)
532 .iter()
533 .copied()
534 .map(|v| v * v)
535 .fold(F::zero(), |a, b| a + b)
536 / n_f;
537 let mut ez_xd = Array1::<F>::zeros(k);
540 for kk in 0..k {
541 let s = (0..n_samples)
542 .map(|i| ez[[kk, i]] * xc[[i, d]])
543 .fold(F::zero(), |a, b| a + b);
544 ez_xd[kk] = s / n_f;
545 }
546 let wd = w_new.row(d);
547 let corr = wd
548 .iter()
549 .zip(ez_xd.iter())
550 .map(|(&wi, &ei)| wi * ei)
551 .fold(F::zero(), |a, b| a + b);
552 let psi_d = var_d - corr;
553 psi_new[d] = if psi_d > F::from(1e-6).unwrap() {
554 psi_d
555 } else {
556 F::from(1e-6).unwrap()
557 };
558 }
559
560 w = w_new;
561 psi = psi_new;
562
563 let ll = compute_log_likelihood(&xc, &w, &psi);
565 let ll_change = (ll - prev_ll).abs();
566 n_iter = iter + 1;
567 if ll_change < tol_f && iter > 0 {
568 prev_ll = ll;
569 break;
570 }
571 prev_ll = ll;
572 }
573
574 Ok(FittedFactorAnalysis {
575 components: w,
576 noise_variance: psi,
577 mean,
578 n_iter,
579 log_likelihood: prev_ll,
580 })
581 }
582}
583
584impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedFactorAnalysis<F> {
589 type Output = Array2<F>;
590 type Error = FerroError;
591
592 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
601 let n_features = self.mean.len();
602 if x.ncols() != n_features {
603 return Err(FerroError::ShapeMismatch {
604 expected: vec![x.nrows(), n_features],
605 actual: vec![x.nrows(), x.ncols()],
606 context: "FittedFactorAnalysis::transform".into(),
607 });
608 }
609 let (n_samples, _) = x.dim();
610 let k = self.components.ncols();
611
612 let mut xc = x.to_owned();
614 for mut row in xc.rows_mut() {
615 for (v, &m) in row.iter_mut().zip(self.mean.iter()) {
616 *v = *v - m;
617 }
618 }
619
620 let mut wzw = Array2::<F>::zeros((k, k));
622 for i in 0..k {
623 for j in 0..k {
624 let mut s = F::zero();
625 for d in 0..n_features {
626 s = s + self.components[[d, i]] * self.components[[d, j]]
627 / self.noise_variance[d];
628 }
629 wzw[[i, j]] = s;
630 }
631 }
632 for i in 0..k {
633 wzw[[i, i]] = wzw[[i, i]] + F::one();
634 }
635 let sigma_z = cholesky_inv(&wzw).map_err(|_| FerroError::NumericalInstability {
636 message: "FittedFactorAnalysis::transform: (I + W^T Ψ⁻¹ W) is singular".into(),
637 })?;
638
639 let mut beta = Array2::<F>::zeros((k, n_features));
641 for i in 0..k {
642 for d in 0..n_features {
643 let mut s = F::zero();
644 for j in 0..k {
645 s = s + sigma_z[[i, j]] * self.components[[d, j]];
646 }
647 beta[[i, d]] = s / self.noise_variance[d];
648 }
649 }
650
651 let ez = beta.dot(&xc.t()); let scores = ez.t().to_owned(); assert_eq!(scores.dim(), (n_samples, k));
655 Ok(scores)
656 }
657}
658
659impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for FactorAnalysis<F> {
664 fn fit_pipeline(
670 &self,
671 x: &Array2<F>,
672 _y: &Array1<F>,
673 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
674 let fitted = self.fit(x, &())?;
675 Ok(Box::new(fitted))
676 }
677}
678
679impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedFactorAnalysis<F> {
680 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
686 self.transform(x)
687 }
688}
689
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array2;

    /// Fixed 10 × 4 data set shared by the tests.
    fn simple_data() -> Array2<f64> {
        let values = vec![
            1.0, 2.0, 1.5, 3.0, 1.1, 2.1, 1.6, 3.1, 0.9, 1.9, 1.4, 2.9, 2.0, 4.0, 3.0, 6.0,
            2.1, 4.1, 3.1, 6.1, 1.9, 3.9, 2.9, 5.9, 0.5, 1.0, 0.7, 1.5, 0.4, 0.9, 0.6, 1.4,
            0.6, 1.1, 0.8, 1.6, 1.5, 3.0, 2.2, 4.5,
        ];
        Array2::from_shape_vec((10, 4), values).unwrap()
    }

    /// Fits a default-configured model with `n_components` factors on `simple_data`.
    fn fit_simple(n_components: usize) -> FittedFactorAnalysis<f64> {
        FactorAnalysis::<f64>::new(n_components)
            .fit(&simple_data(), &())
            .unwrap()
    }

    #[test]
    fn test_fa_fit_returns_fitted() {
        let model = fit_simple(2);
        assert_eq!(model.components().dim(), (4, 2));
    }

    #[test]
    fn test_fa_transform_shape() {
        let model = fit_simple(2);
        let scores = model.transform(&simple_data()).unwrap();
        assert_eq!(scores.dim(), (10, 2));
    }

    #[test]
    fn test_fa_transform_new_data() {
        let model = fit_simple(1);
        let unseen = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 1.5, 3.0, 2.0, 4.0, 3.0, 6.0, 0.5, 1.0, 0.7, 1.5],
        )
        .unwrap();
        let scores = model.transform(&unseen).unwrap();
        assert_eq!(scores.dim(), (3, 1));
    }

    #[test]
    fn test_fa_noise_variance_positive() {
        let model = fit_simple(1);
        for v in model.noise_variance().iter().copied() {
            assert!(v > 0.0, "noise variance must be positive, got {v}");
        }
    }

    #[test]
    fn test_fa_mean_shape() {
        assert_eq!(fit_simple(1).mean().len(), 4);
    }

    #[test]
    fn test_fa_n_iter_positive() {
        assert!(fit_simple(1).n_iter() >= 1);
    }

    #[test]
    fn test_fa_log_likelihood_finite() {
        assert!(fit_simple(1).log_likelihood().is_finite());
    }

    #[test]
    fn test_fa_error_zero_components() {
        let outcome = FactorAnalysis::<f64>::new(0).fit(&simple_data(), &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_fa_error_too_many_components() {
        // Ten components requested, but the data only has four features.
        let outcome = FactorAnalysis::<f64>::new(10).fit(&simple_data(), &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_fa_error_insufficient_samples() {
        let single_row = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let outcome = FactorAnalysis::<f64>::new(1).fit(&single_row, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_fa_transform_shape_mismatch() {
        let model = fit_simple(1);
        let wrong_width = Array2::<f64>::zeros((3, 7));
        assert!(model.transform(&wrong_width).is_err());
    }

    #[test]
    fn test_fa_reproducible_with_seed() {
        let data = simple_data();
        let seeded = || FactorAnalysis::<f64>::new(2).with_random_state(42);
        let first = seeded().fit(&data, &()).unwrap();
        let second = seeded().fit(&data, &()).unwrap();
        for (a, b) in first.components().iter().zip(second.components().iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_fa_different_seeds_differ() {
        let data = simple_data();
        let one_step = |seed: u64| {
            FactorAnalysis::<f64>::new(2)
                .with_random_state(seed)
                .with_max_iter(1)
                .fit(&data, &())
                .unwrap()
        };
        let first = one_step(0);
        let second = one_step(99);
        // Only checks that both fits succeed and the loadings can be
        // compared; the size of the difference is not asserted.
        let _total_abs_diff: f64 = first
            .components()
            .iter()
            .zip(second.components().iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
    }

    #[test]
    fn test_fa_components_accessor() {
        let model = fit_simple(2);
        let loadings = model.components();
        assert_eq!(loadings.ncols(), 2);
        assert_eq!(loadings.nrows(), 4);
    }

    #[test]
    fn test_fa_n_components_getter() {
        assert_eq!(FactorAnalysis::<f64>::new(3).n_components(), 3);
    }

    #[test]
    fn test_fa_pipeline_transformer() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let data = simple_data();
        let targets = Array1::<f64>::zeros(10);
        let step = FactorAnalysis::<f64>::new(2)
            .fit_pipeline(&data, &targets)
            .unwrap();
        let projected = step.transform_pipeline(&data).unwrap();
        assert_eq!(projected.ncols(), 2);
    }

    #[test]
    fn test_fa_scores_not_all_zero() {
        let scores = fit_simple(2).transform(&simple_data()).unwrap();
        let total: f64 = scores.iter().map(|v| v.abs()).sum();
        assert!(total > 0.0, "Factor scores should not all be zero");
    }
}