1use ferrolearn_core::error::FerroError;
45use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
46use ferrolearn_core::traits::{Fit, Transform};
47use ndarray::{Array1, Array2};
48use num_traits::Float;
49use rand::SeedableRng;
50use rand_distr::{Distribution, StandardNormal};
51
/// Factor-analysis estimator: models each sample as a linear combination of
/// `n_components` latent factors plus independent per-feature noise.
///
/// Configure it with [`FactorAnalysis::new`] and the `with_*` builders, then
/// call [`Fit::fit`] to obtain a [`FittedFactorAnalysis`].
#[derive(Debug, Clone)]
pub struct FactorAnalysis<F> {
    /// Number of latent factors to extract.
    n_components: usize,
    /// Upper bound on the number of EM iterations run by `fit`.
    max_iter: usize,
    /// Convergence threshold on the absolute change in log-likelihood
    /// between consecutive EM iterations.
    tol: f64,
    /// Seed for the random initialisation of the loadings; `None` makes
    /// `fit` fall back to a fixed default seed.
    random_state: Option<u64>,
    /// Binds the float type `F` to the estimator without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
76
77impl<F: Float + Send + Sync + 'static> FactorAnalysis<F> {
78 #[must_use]
82 pub fn new(n_components: usize) -> Self {
83 Self {
84 n_components,
85 max_iter: 1000,
86 tol: 1e-3,
87 random_state: None,
88 _marker: std::marker::PhantomData,
89 }
90 }
91
92 #[must_use]
94 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
95 self.max_iter = max_iter;
96 self
97 }
98
99 #[must_use]
101 pub fn with_tol(mut self, tol: f64) -> Self {
102 self.tol = tol;
103 self
104 }
105
106 #[must_use]
108 pub fn with_random_state(mut self, seed: u64) -> Self {
109 self.random_state = Some(seed);
110 self
111 }
112
113 #[must_use]
115 pub fn n_components(&self) -> usize {
116 self.n_components
117 }
118}
119
120impl<F: Float + Send + Sync + 'static> Default for FactorAnalysis<F> {
121 fn default() -> Self {
122 Self::new(1)
123 }
124}
125
126#[derive(Debug, Clone)]
135pub struct FittedFactorAnalysis<F> {
136 components: Array2<F>,
138
139 noise_variance: Array1<F>,
141
142 mean: Array1<F>,
144
145 n_iter: usize,
147
148 log_likelihood: F,
150}
151
152impl<F: Float + Send + Sync + 'static> FittedFactorAnalysis<F> {
153 #[must_use]
155 pub fn components(&self) -> &Array2<F> {
156 &self.components
157 }
158
159 #[must_use]
161 pub fn noise_variance(&self) -> &Array1<F> {
162 &self.noise_variance
163 }
164
165 #[must_use]
167 pub fn mean(&self) -> &Array1<F> {
168 &self.mean
169 }
170
171 #[must_use]
173 pub fn n_iter(&self) -> usize {
174 self.n_iter
175 }
176
177 #[must_use]
179 pub fn log_likelihood(&self) -> F {
180 self.log_likelihood
181 }
182}
183
184fn cholesky_inv<F: Float>(a: &Array2<F>) -> Result<Array2<F>, FerroError> {
190 let n = a.nrows();
191 let mut l = Array2::<F>::zeros((n, n));
193 for i in 0..n {
194 for j in 0..=i {
195 let mut s = a[[i, j]];
196 for k in 0..j {
197 s = s - l[[i, k]] * l[[j, k]];
198 }
199 if i == j {
200 if s <= F::zero() {
201 s = F::from(1e-10).unwrap();
203 }
204 l[[i, j]] = s.sqrt();
205 } else {
206 l[[i, j]] = s / l[[j, j]];
207 }
208 }
209 }
210 let mut l_inv = Array2::<F>::zeros((n, n));
212 for j in 0..n {
213 l_inv[[j, j]] = F::one() / l[[j, j]];
214 for i in (j + 1)..n {
215 let mut s = F::zero();
216 for k in j..i {
217 s = s + l[[i, k]] * l_inv[[k, j]];
218 }
219 l_inv[[i, j]] = -s / l[[i, i]];
220 }
221 }
222 let mut inv = Array2::<F>::zeros((n, n));
224 for i in 0..n {
225 for j in 0..n {
226 let mut s = F::zero();
227 let start = i.max(j);
228 for k in start..n {
229 s = s + l_inv[[k, i]] * l_inv[[k, j]];
230 }
231 inv[[i, j]] = s;
232 }
233 }
234 Ok(inv)
235}
236
237fn compute_log_likelihood<F: Float + Send + Sync + 'static>(
242 x_centered: &Array2<F>,
243 w: &Array2<F>,
244 psi: &Array1<F>,
245) -> F {
246 let (n, p) = x_centered.dim();
247 let k = w.ncols();
248 let two_pi = F::from(2.0 * std::f64::consts::PI).unwrap();
252 let n_f = F::from(n).unwrap();
253 let p_f = F::from(p).unwrap();
254
255 let mut wtpsiw = Array2::<F>::zeros((k, k));
257 for i in 0..k {
258 for j in 0..k {
259 let mut s = F::zero();
260 for d in 0..p {
261 s = s + w[[d, i]] * w[[d, j]] / psi[d];
262 }
263 wtpsiw[[i, j]] = s;
264 }
265 }
266 for i in 0..k {
268 wtpsiw[[i, i]] = wtpsiw[[i, i]] + F::one();
269 }
270 let mut log_det_inner = F::zero();
272 {
273 let mut l = Array2::<F>::zeros((k, k));
274 for i in 0..k {
275 for j in 0..=i {
276 let mut s = wtpsiw[[i, j]];
277 for kk in 0..j {
278 s = s - l[[i, kk]] * l[[j, kk]];
279 }
280 if i == j {
281 s = if s > F::zero() {
282 s
283 } else {
284 F::from(1e-30).unwrap()
285 };
286 l[[i, j]] = s.sqrt();
287 log_det_inner = log_det_inner + l[[i, j]].ln();
288 } else {
289 l[[i, j]] = s / l[[j, j]];
290 }
291 }
292 }
293 log_det_inner = log_det_inner * F::from(2.0).unwrap();
294 }
295 let log_det_psi: F = psi
296 .iter()
297 .copied()
298 .map(|v| {
299 let v_clamped = if v > F::zero() {
300 v
301 } else {
302 F::from(1e-30).unwrap()
303 };
304 v_clamped.ln()
305 })
306 .fold(F::zero(), |a, b| a + b);
307 let log_det_sigma = log_det_inner + log_det_psi;
308
309 let m_inv = match cholesky_inv(&wtpsiw) {
320 Ok(inv) => inv,
321 Err(_) => return F::neg_infinity(),
322 };
323
324 let mut trace_sum = F::zero();
325 for i in 0..n {
326 let mut psi_inv_x = Array1::<F>::zeros(p);
328 let mut xpsiinvx = F::zero();
329 for d in 0..p {
330 psi_inv_x[d] = x_centered[[i, d]] / psi[d];
331 xpsiinvx = xpsiinvx + x_centered[[i, d]] * psi_inv_x[d];
332 }
333 let mut wtpx = Array1::<F>::zeros(k);
335 for kk in 0..k {
336 let mut s = F::zero();
337 for d in 0..p {
338 s = s + w[[d, kk]] * psi_inv_x[d];
339 }
340 wtpx[kk] = s;
341 }
342 let mut quad = F::zero();
344 for ii in 0..k {
345 let mut s = F::zero();
346 for jj in 0..k {
347 s = s + m_inv[[ii, jj]] * wtpx[jj];
348 }
349 quad = quad + wtpx[ii] * s;
350 }
351 trace_sum = trace_sum + xpsiinvx - quad;
352 }
353 let trace_term = trace_sum / n_f;
354
355 let half = F::from(0.5).unwrap();
357 -n_f * half * (p_f * two_pi.ln() + log_det_sigma + trace_term)
358}
359
360impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for FactorAnalysis<F> {
365 type Fitted = FittedFactorAnalysis<F>;
366 type Error = FerroError;
367
368 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedFactorAnalysis<F>, FerroError> {
376 let (n_samples, n_features) = x.dim();
377
378 if self.n_components == 0 {
379 return Err(FerroError::InvalidParameter {
380 name: "n_components".into(),
381 reason: "must be at least 1".into(),
382 });
383 }
384 if self.n_components > n_features {
385 return Err(FerroError::InvalidParameter {
386 name: "n_components".into(),
387 reason: format!(
388 "n_components ({}) exceeds n_features ({})",
389 self.n_components, n_features
390 ),
391 });
392 }
393 if n_samples < 2 {
394 return Err(FerroError::InsufficientSamples {
395 required: 2,
396 actual: n_samples,
397 context: "FactorAnalysis requires at least 2 samples".into(),
398 });
399 }
400
401 let k = self.n_components;
402 let p = n_features;
403 let n_f = F::from(n_samples).unwrap();
404
405 let mut mean = Array1::<F>::zeros(p);
407 for j in 0..p {
408 let s = x.column(j).iter().copied().fold(F::zero(), |a, b| a + b);
409 mean[j] = s / n_f;
410 }
411 let mut xc = x.to_owned();
412 for mut row in xc.rows_mut() {
413 for (v, &m) in row.iter_mut().zip(mean.iter()) {
414 *v = *v - m;
415 }
416 }
417
418 let seed = self.random_state.unwrap_or(42);
420 let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(seed);
421 let std_normal = StandardNormal;
422 let mut w = Array2::<F>::zeros((p, k));
423 let scale = F::from(0.01).unwrap();
424 for i in 0..p {
425 for j in 0..k {
426 let v: f64 = std_normal.sample(&mut rng);
427 w[[i, j]] = F::from(v).unwrap() * scale;
428 }
429 }
430 let mut psi = Array1::<F>::from_elem(p, F::one());
431
432 let mut prev_ll = F::neg_infinity();
433 let mut n_iter = 0usize;
434 let tol_f = F::from(self.tol).unwrap();
435
436 for iter in 0..self.max_iter {
437 let mut wzw = Array2::<F>::zeros((k, k));
440 for i in 0..k {
441 for j in 0..k {
442 let mut s = F::zero();
443 for d in 0..p {
444 s = s + w[[d, i]] * w[[d, j]] / psi[d];
445 }
446 wzw[[i, j]] = s;
447 }
448 }
449 for i in 0..k {
450 wzw[[i, i]] = wzw[[i, i]] + F::one();
451 }
452 let sigma_z = cholesky_inv(&wzw).map_err(|_| FerroError::NumericalInstability {
453 message: "FactorAnalysis: (I + W^T Ψ⁻¹ W) is singular".into(),
454 })?;
455
456 let mut beta = Array2::<F>::zeros((k, p));
458 for i in 0..k {
459 for d in 0..p {
460 let mut s = F::zero();
461 for j in 0..k {
462 s = s + sigma_z[[i, j]] * w[[d, j]];
463 }
464 beta[[i, d]] = s / psi[d];
465 }
466 }
467
468 let ez = beta.dot(&xc.t()); let ezz_t_sum = sigma_z.mapv(|v| v * n_f) + ez.dot(&ez.t()); let xc_ez_t = xc.t().dot(&ez.t()); let ezz_t_inv =
485 cholesky_inv(&ezz_t_sum).map_err(|_| FerroError::NumericalInstability {
486 message: "FactorAnalysis: E[ZZ^T] is singular in M-step".into(),
487 })?;
488
489 let w_new = xc_ez_t.dot(&ezz_t_inv); let mut psi_new = Array1::<F>::zeros(p);
498 for d in 0..p {
499 let var_d = xc
501 .column(d)
502 .iter()
503 .copied()
504 .map(|v| v * v)
505 .fold(F::zero(), |a, b| a + b)
506 / n_f;
507 let mut ez_xd = Array1::<F>::zeros(k);
510 for kk in 0..k {
511 let s = (0..n_samples)
512 .map(|i| ez[[kk, i]] * xc[[i, d]])
513 .fold(F::zero(), |a, b| a + b);
514 ez_xd[kk] = s / n_f;
515 }
516 let wd = w_new.row(d);
517 let corr = wd
518 .iter()
519 .zip(ez_xd.iter())
520 .map(|(&wi, &ei)| wi * ei)
521 .fold(F::zero(), |a, b| a + b);
522 let psi_d = var_d - corr;
523 psi_new[d] = if psi_d > F::from(1e-6).unwrap() {
524 psi_d
525 } else {
526 F::from(1e-6).unwrap()
527 };
528 }
529
530 w = w_new;
531 psi = psi_new;
532
533 let ll = compute_log_likelihood(&xc, &w, &psi);
535 let ll_change = (ll - prev_ll).abs();
536 n_iter = iter + 1;
537 if ll_change < tol_f && iter > 0 {
538 prev_ll = ll;
539 break;
540 }
541 prev_ll = ll;
542 }
543
544 Ok(FittedFactorAnalysis {
545 components: w,
546 noise_variance: psi,
547 mean,
548 n_iter,
549 log_likelihood: prev_ll,
550 })
551 }
552}
553
554impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedFactorAnalysis<F> {
559 type Output = Array2<F>;
560 type Error = FerroError;
561
562 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
571 let n_features = self.mean.len();
572 if x.ncols() != n_features {
573 return Err(FerroError::ShapeMismatch {
574 expected: vec![x.nrows(), n_features],
575 actual: vec![x.nrows(), x.ncols()],
576 context: "FittedFactorAnalysis::transform".into(),
577 });
578 }
579 let (n_samples, _) = x.dim();
580 let k = self.components.ncols();
581
582 let mut xc = x.to_owned();
584 for mut row in xc.rows_mut() {
585 for (v, &m) in row.iter_mut().zip(self.mean.iter()) {
586 *v = *v - m;
587 }
588 }
589
590 let mut wzw = Array2::<F>::zeros((k, k));
592 for i in 0..k {
593 for j in 0..k {
594 let mut s = F::zero();
595 for d in 0..n_features {
596 s = s + self.components[[d, i]] * self.components[[d, j]]
597 / self.noise_variance[d];
598 }
599 wzw[[i, j]] = s;
600 }
601 }
602 for i in 0..k {
603 wzw[[i, i]] = wzw[[i, i]] + F::one();
604 }
605 let sigma_z = cholesky_inv(&wzw).map_err(|_| FerroError::NumericalInstability {
606 message: "FittedFactorAnalysis::transform: (I + W^T Ψ⁻¹ W) is singular".into(),
607 })?;
608
609 let mut beta = Array2::<F>::zeros((k, n_features));
611 for i in 0..k {
612 for d in 0..n_features {
613 let mut s = F::zero();
614 for j in 0..k {
615 s = s + sigma_z[[i, j]] * self.components[[d, j]];
616 }
617 beta[[i, d]] = s / self.noise_variance[d];
618 }
619 }
620
621 let ez = beta.dot(&xc.t()); let scores = ez.t().to_owned(); assert_eq!(scores.dim(), (n_samples, k));
625 Ok(scores)
626 }
627}
628
629impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for FactorAnalysis<F> {
634 fn fit_pipeline(
640 &self,
641 x: &Array2<F>,
642 _y: &Array1<F>,
643 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
644 let fitted = self.fit(x, &())?;
645 Ok(Box::new(fitted))
646 }
647}
648
649impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedFactorAnalysis<F> {
650 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
656 self.transform(x)
657 }
658}
659
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array2;

    /// Small 10 × 4 fixture whose columns move together.
    fn simple_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (10, 4),
            vec![
                1.0, 2.0, 1.5, 3.0, 1.1, 2.1, 1.6, 3.1, 0.9, 1.9, 1.4, 2.9, 2.0, 4.0, 3.0, 6.0,
                2.1, 4.1, 3.1, 6.1, 1.9, 3.9, 2.9, 5.9, 0.5, 1.0, 0.7, 1.5, 0.4, 0.9, 0.6, 1.4,
                0.6, 1.1, 0.8, 1.6, 1.5, 3.0, 2.2, 4.5,
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_fa_fit_returns_fitted() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (4, 2));
    }

    #[test]
    fn test_fa_transform_shape() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (10, 2));
    }

    #[test]
    fn test_fa_transform_new_data() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let x_new = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 1.5, 3.0, 2.0, 4.0, 3.0, 6.0, 0.5, 1.0, 0.7, 1.5],
        )
        .unwrap();
        let scores = fitted.transform(&x_new).unwrap();
        assert_eq!(scores.dim(), (3, 1));
    }

    #[test]
    fn test_fa_noise_variance_positive() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        for &v in fitted.noise_variance().iter() {
            assert!(v > 0.0, "noise variance must be positive, got {v}");
        }
    }

    #[test]
    fn test_fa_mean_shape() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.mean().len(), 4);
    }

    #[test]
    fn test_fa_n_iter_positive() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() >= 1);
    }

    #[test]
    fn test_fa_log_likelihood_finite() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert!(fitted.log_likelihood().is_finite());
    }

    #[test]
    fn test_fa_error_zero_components() {
        let fa = FactorAnalysis::<f64>::new(0);
        let x = simple_data();
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_fa_error_too_many_components() {
        let fa = FactorAnalysis::<f64>::new(10);
        let x = simple_data();
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_fa_error_insufficient_samples() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_fa_transform_shape_mismatch() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 7));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fa_reproducible_with_seed() {
        let fa1 = FactorAnalysis::<f64>::new(2).with_random_state(42);
        let fa2 = FactorAnalysis::<f64>::new(2).with_random_state(42);
        let x = simple_data();
        let f1 = fa1.fit(&x, &()).unwrap();
        let f2 = fa2.fit(&x, &()).unwrap();
        let c1 = f1.components();
        let c2 = f2.components();
        for (a, b) in c1.iter().zip(c2.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_fa_different_seeds_differ() {
        // After a single EM step the loadings still depend on the random
        // initialisation, so different seeds must give different components.
        let fa1 = FactorAnalysis::<f64>::new(2)
            .with_random_state(0)
            .with_max_iter(1);
        let fa2 = FactorAnalysis::<f64>::new(2)
            .with_random_state(99)
            .with_max_iter(1);
        let x = simple_data();
        let f1 = fa1.fit(&x, &()).unwrap();
        let f2 = fa2.fit(&x, &()).unwrap();
        let diff: f64 = f1
            .components()
            .iter()
            .zip(f2.components().iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 0.0,
            "different seeds should give different components after one iteration"
        );
    }

    #[test]
    fn test_fa_components_accessor() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().ncols(), 2);
        assert_eq!(fitted.components().nrows(), 4);
    }

    #[test]
    fn test_fa_n_components_getter() {
        let fa = FactorAnalysis::<f64>::new(3);
        assert_eq!(fa.n_components(), 3);
    }

    #[test]
    fn test_fa_pipeline_transformer() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let y = Array1::<f64>::zeros(10);
        let fitted = fa.fit_pipeline(&x, &y).unwrap();
        let out = fitted.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_fa_scores_not_all_zero() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let scores = fitted.transform(&x).unwrap();
        let total: f64 = scores.iter().map(|v| v.abs()).sum();
        assert!(total > 0.0, "Factor scores should not all be zero");
    }
}