// ferrolearn_decomp/factor_analysis.rs

1//! Factor Analysis (FA) via the EM algorithm.
2//!
3//! Factor Analysis assumes that data is generated by a linear combination of
4//! latent factors plus independent Gaussian noise:
5//!
6//! ```text
7//! X = W Z + μ + ε,   Z ~ N(0, I),   ε ~ N(0, diag(ψ))
8//! ```
9//!
10//! where:
11//! - `W` is the `(n_features × n_components)` loading matrix,
12//! - `Z` is the `(n_components,)` latent factor vector,
13//! - `ψ` is the `(n_features,)` noise variance vector.
14//!
15//! # Algorithm
16//!
17//! 1. Centre the data: `X_c = X - μ`.
18//! 2. **E-step**: compute the posterior mean and covariance of `Z`:
19//!    ```text
20//!    Σ_z = (I + W^T diag(ψ)⁻¹ W)⁻¹
21//!    E[Z | X] = Σ_z W^T diag(ψ)⁻¹ X_c^T
22//!    ```
23//! 3. **M-step**: update `W` and `ψ` via maximum-likelihood closed-form
24//!    updates.
25//! 4. Repeat until convergence (log-likelihood change < `tol`).
26//!
27//! # Examples
28//!
29//! ```
30//! use ferrolearn_decomp::factor_analysis::FactorAnalysis;
31//! use ferrolearn_core::traits::{Fit, Transform};
32//! use ndarray::Array2;
33//!
34//! let fa = FactorAnalysis::new(2);
35//! let x = Array2::from_shape_vec(
36//!     (10, 4),
37//!     (0..40).map(|v| v as f64 * 0.1 + (v % 3) as f64 * 0.5).collect(),
38//! ).unwrap();
39//! let fitted = fa.fit(&x, &()).unwrap();
40//! let scores = fitted.transform(&x).unwrap();
41//! assert_eq!(scores.ncols(), 2);
42//! ```
43
44use ferrolearn_core::error::FerroError;
45use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
46use ferrolearn_core::traits::{Fit, Transform};
47use ndarray::{Array1, Array2};
48use num_traits::Float;
49use rand::SeedableRng;
50use rand_distr::{Distribution, StandardNormal};
51
52// ---------------------------------------------------------------------------
53// FactorAnalysis (unfitted)
54// ---------------------------------------------------------------------------
55
/// Factor Analysis configuration.
///
/// Calling [`Fit::fit`] fits the EM algorithm and returns a
/// [`FittedFactorAnalysis`].
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type.
#[derive(Debug, Clone)]
pub struct FactorAnalysis<F> {
    /// Number of latent factors to extract.
    n_components: usize,
    /// Maximum number of EM iterations.
    max_iter: usize,
    /// Convergence tolerance on the absolute log-likelihood change between
    /// consecutive EM iterations.
    tol: f64,
    /// Optional random seed for reproducibility. When `None`, `fit` falls
    /// back to a fixed default seed (42), so unseeded fits are still
    /// deterministic.
    random_state: Option<u64>,
    // Ties the scalar type `F` to the configuration without storing an `F`.
    _marker: std::marker::PhantomData<F>,
}
76
77impl<F: Float + Send + Sync + 'static> FactorAnalysis<F> {
78    /// Create a new `FactorAnalysis` with `n_components` factors.
79    ///
80    /// Defaults: `max_iter = 1000`, `tol = 1e-3`, no fixed random seed.
81    #[must_use]
82    pub fn new(n_components: usize) -> Self {
83        Self {
84            n_components,
85            max_iter: 1000,
86            tol: 1e-3,
87            random_state: None,
88            _marker: std::marker::PhantomData,
89        }
90    }
91
92    /// Set the maximum number of EM iterations.
93    #[must_use]
94    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
95        self.max_iter = max_iter;
96        self
97    }
98
99    /// Set the convergence tolerance.
100    #[must_use]
101    pub fn with_tol(mut self, tol: f64) -> Self {
102        self.tol = tol;
103        self
104    }
105
106    /// Set the random seed for reproducibility.
107    #[must_use]
108    pub fn with_random_state(mut self, seed: u64) -> Self {
109        self.random_state = Some(seed);
110        self
111    }
112
113    /// Return the number of latent factors.
114    #[must_use]
115    pub fn n_components(&self) -> usize {
116        self.n_components
117    }
118}
119
120impl<F: Float + Send + Sync + 'static> Default for FactorAnalysis<F> {
121    fn default() -> Self {
122        Self::new(1)
123    }
124}
125
126// ---------------------------------------------------------------------------
127// FittedFactorAnalysis
128// ---------------------------------------------------------------------------
129
/// A fitted Factor Analysis model.
///
/// Created by calling [`Fit::fit`] on a [`FactorAnalysis`].
/// Implements [`Transform<Array2<F>>`] to compute factor scores for new data.
#[derive(Debug, Clone)]
pub struct FittedFactorAnalysis<F> {
    /// Loading matrix `W`, shape `(n_features, n_components)`.
    components: Array2<F>,

    /// Noise variance vector `ψ`, shape `(n_features,)`. Floored at a small
    /// positive value during fitting, so entries are strictly positive.
    noise_variance: Array1<F>,

    /// Per-feature mean, shape `(n_features,)`, used to centre new data in
    /// `transform`.
    mean: Array1<F>,

    /// Number of EM iterations actually performed.
    n_iter: usize,

    /// Log-likelihood at the last completed EM iteration.
    log_likelihood: F,
}
151
impl<F: Float + Send + Sync + 'static> FittedFactorAnalysis<F> {
    /// Loading matrix `W`, shape `(n_features, n_components)`.
    ///
    /// Column `j` holds the per-feature loadings of latent factor `j`.
    #[must_use]
    pub fn components(&self) -> &Array2<F> {
        &self.components
    }

    /// Noise variance vector `ψ`, shape `(n_features,)`.
    #[must_use]
    pub fn noise_variance(&self) -> &Array1<F> {
        &self.noise_variance
    }

    /// Per-feature mean learned during fitting.
    #[must_use]
    pub fn mean(&self) -> &Array1<F> {
        &self.mean
    }

    /// Number of EM iterations performed (at most `max_iter`).
    #[must_use]
    pub fn n_iter(&self) -> usize {
        self.n_iter
    }

    /// Log-likelihood at the last completed EM iteration.
    #[must_use]
    pub fn log_likelihood(&self) -> F {
        self.log_likelihood
    }
}
183
184// ---------------------------------------------------------------------------
185// Internal helpers
186// ---------------------------------------------------------------------------
187
188/// Invert a small symmetric positive-definite matrix via Cholesky.
189fn cholesky_inv<F: Float>(a: &Array2<F>) -> Result<Array2<F>, FerroError> {
190    let n = a.nrows();
191    // Compute lower triangular L.
192    let mut l = Array2::<F>::zeros((n, n));
193    for i in 0..n {
194        for j in 0..=i {
195            let mut s = a[[i, j]];
196            for k in 0..j {
197                s = s - l[[i, k]] * l[[j, k]];
198            }
199            if i == j {
200                if s <= F::zero() {
201                    // Regularise.
202                    s = F::from(1e-10).unwrap();
203                }
204                l[[i, j]] = s.sqrt();
205            } else {
206                l[[i, j]] = s / l[[j, j]];
207            }
208        }
209    }
210    // Invert L using forward substitution: L L_inv = I.
211    let mut l_inv = Array2::<F>::zeros((n, n));
212    for j in 0..n {
213        l_inv[[j, j]] = F::one() / l[[j, j]];
214        for i in (j + 1)..n {
215            let mut s = F::zero();
216            for k in j..i {
217                s = s + l[[i, k]] * l_inv[[k, j]];
218            }
219            l_inv[[i, j]] = -s / l[[i, i]];
220        }
221    }
222    // A_inv = L_inv^T @ L_inv.
223    let mut inv = Array2::<F>::zeros((n, n));
224    for i in 0..n {
225        for j in 0..n {
226            let mut s = F::zero();
227            let start = i.max(j);
228            for k in start..n {
229                s = s + l_inv[[k, i]] * l_inv[[k, j]];
230            }
231            inv[[i, j]] = s;
232        }
233    }
234    Ok(inv)
235}
236
/// Compute the log-likelihood under the factor analysis model.
///
/// `log p(X) = -n/2 * [p*log(2π) + log|Σ| + tr(Σ⁻¹ S)]`
/// where `Σ = W W^T + diag(ψ)` and `S = X_c^T X_c / n`.
///
/// `x_centered` must already have the per-feature mean removed; `w` is the
/// `(p × k)` loading matrix and `psi` the `(p,)` noise-variance vector.
/// Returns negative infinity if the inner `k × k` system cannot be inverted.
fn compute_log_likelihood<F: Float + Send + Sync + 'static>(
    x_centered: &Array2<F>,
    w: &Array2<F>,
    psi: &Array1<F>,
) -> F {
    let (n, p) = x_centered.dim();
    let k = w.ncols();
    // Σ = W W^T + diag(ψ)
    // We use the Woodbury identity for the log-det and trace so that only
    // k × k (not p × p) systems are ever formed.
    // log|Σ| = log|I_k + W^T Ψ⁻¹ W| + Σ_j log ψ_j
    let two_pi = F::from(2.0 * std::f64::consts::PI).unwrap();
    let n_f = F::from(n).unwrap();
    let p_f = F::from(p).unwrap();

    // W^T Ψ⁻¹ W: k × k
    let mut wtpsiw = Array2::<F>::zeros((k, k));
    for i in 0..k {
        for j in 0..k {
            let mut s = F::zero();
            for d in 0..p {
                s = s + w[[d, i]] * w[[d, j]] / psi[d];
            }
            wtpsiw[[i, j]] = s;
        }
    }
    // Add identity.
    for i in 0..k {
        wtpsiw[[i, i]] = wtpsiw[[i, i]] + F::one();
    }
    // log det of (I + W^T Ψ⁻¹ W) via Cholesky:
    // log|M| = 2 Σ_i log L_ii where M = L L^T.
    let mut log_det_inner = F::zero();
    {
        let mut l = Array2::<F>::zeros((k, k));
        for i in 0..k {
            for j in 0..=i {
                let mut s = wtpsiw[[i, j]];
                for kk in 0..j {
                    s = s - l[[i, kk]] * l[[j, kk]];
                }
                if i == j {
                    // Clamp non-positive pivots so sqrt/ln stay finite.
                    s = if s > F::zero() {
                        s
                    } else {
                        F::from(1e-30).unwrap()
                    };
                    l[[i, j]] = s.sqrt();
                    log_det_inner = log_det_inner + l[[i, j]].ln();
                } else {
                    l[[i, j]] = s / l[[j, j]];
                }
            }
        }
        log_det_inner = log_det_inner * F::from(2.0).unwrap();
    }
    // Σ_j log ψ_j, clamped the same way for safety.
    let log_det_psi: F = psi
        .iter()
        .copied()
        .map(|v| {
            let v_clamped = if v > F::zero() {
                v
            } else {
                F::from(1e-30).unwrap()
            };
            v_clamped.ln()
        })
        .fold(F::zero(), |a, b| a + b);
    let log_det_sigma = log_det_inner + log_det_psi;

    // Sample covariance S = X_c^T X_c / n.
    // tr(Σ⁻¹ S) using Woodbury: Σ⁻¹ = Ψ⁻¹ - Ψ⁻¹ W M⁻¹ W^T Ψ⁻¹
    // where M = I + W^T Ψ⁻¹ W.
    // tr(Σ⁻¹ S) = (1/n) Σ_i x_i^T Σ⁻¹ x_i
    // We compute it directly sample-by-sample for simplicity.
    // For efficiency, we use the factored form:
    // x^T Σ⁻¹ x = x^T Ψ⁻¹ x - (W^T Ψ⁻¹ x)^T M⁻¹ (W^T Ψ⁻¹ x)
    // so only k-vectors appear inside the sample loop.

    // Invert M = I + W^T Ψ⁻¹ W.
    let m_inv = match cholesky_inv(&wtpsiw) {
        Ok(inv) => inv,
        Err(_) => return F::neg_infinity(),
    };

    let mut trace_sum = F::zero();
    for i in 0..n {
        // Ψ⁻¹ x_i, and x_i^T Ψ⁻¹ x_i accumulated in the same pass.
        let mut psi_inv_x = Array1::<F>::zeros(p);
        let mut xpsiinvx = F::zero();
        for d in 0..p {
            psi_inv_x[d] = x_centered[[i, d]] / psi[d];
            xpsiinvx = xpsiinvx + x_centered[[i, d]] * psi_inv_x[d];
        }
        // W^T Ψ⁻¹ x_i  (k-vector)
        let mut wtpx = Array1::<F>::zeros(k);
        for kk in 0..k {
            let mut s = F::zero();
            for d in 0..p {
                s = s + w[[d, kk]] * psi_inv_x[d];
            }
            wtpx[kk] = s;
        }
        // Quadratic form (W^T Ψ⁻¹ x)^T M⁻¹ (W^T Ψ⁻¹ x).
        let mut quad = F::zero();
        for ii in 0..k {
            let mut s = F::zero();
            for jj in 0..k {
                s = s + m_inv[[ii, jj]] * wtpx[jj];
            }
            quad = quad + wtpx[ii] * s;
        }
        trace_sum = trace_sum + xpsiinvx - quad;
    }
    let trace_term = trace_sum / n_f;

    // log p = -n/2 * [p*log(2π) + log|Σ| + tr(Σ⁻¹ S)]
    let half = F::from(0.5).unwrap();
    -n_f * half * (p_f * two_pi.ln() + log_det_sigma + trace_term)
}
359
360// ---------------------------------------------------------------------------
361// Fit
362// ---------------------------------------------------------------------------
363
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for FactorAnalysis<F> {
    type Fitted = FittedFactorAnalysis<F>;
    type Error = FerroError;

    /// Fit the Factor Analysis model using the EM algorithm.
    ///
    /// Centres the data, initialises `W` from a seeded normal distribution
    /// with `ψ = 1`, then alternates E- and M-steps until the absolute
    /// log-likelihood change drops below `tol` or `max_iter` is reached.
    ///
    /// # Errors
    ///
    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds
    ///   `n_features`.
    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples are provided.
    /// - [`FerroError::NumericalInstability`] if an inner `k × k` system
    ///   cannot be inverted during the E- or M-step.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedFactorAnalysis<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // ---- Parameter validation -----------------------------------------
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.n_components > n_features {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds n_features ({})",
                    self.n_components, n_features
                ),
            });
        }
        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "FactorAnalysis requires at least 2 samples".into(),
            });
        }

        let k = self.n_components;
        let p = n_features;
        let n_f = F::from(n_samples).unwrap();

        // Compute mean and centre data.
        let mut mean = Array1::<F>::zeros(p);
        for j in 0..p {
            let s = x.column(j).iter().copied().fold(F::zero(), |a, b| a + b);
            mean[j] = s / n_f;
        }
        let mut xc = x.to_owned();
        for mut row in xc.rows_mut() {
            for (v, &m) in row.iter_mut().zip(mean.iter()) {
                *v = *v - m;
            }
        }

        // Initialise W randomly, ψ = 1. The fixed fallback seed (42) keeps
        // unseeded fits deterministic.
        let seed = self.random_state.unwrap_or(42);
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(seed);
        let std_normal = StandardNormal;
        let mut w = Array2::<F>::zeros((p, k));
        // Small scale keeps the initial loadings close to zero.
        let scale = F::from(0.01).unwrap();
        for i in 0..p {
            for j in 0..k {
                let v: f64 = std_normal.sample(&mut rng);
                w[[i, j]] = F::from(v).unwrap() * scale;
            }
        }
        let mut psi = Array1::<F>::from_elem(p, F::one());

        let mut prev_ll = F::neg_infinity();
        let mut n_iter = 0usize;
        let tol_f = F::from(self.tol).unwrap();

        for iter in 0..self.max_iter {
            // --- E-step --------------------------------------------------------
            // Σ_z = (I_k + W^T Ψ⁻¹ W)⁻¹   shape k × k
            let mut wzw = Array2::<F>::zeros((k, k));
            for i in 0..k {
                for j in 0..k {
                    let mut s = F::zero();
                    for d in 0..p {
                        s = s + w[[d, i]] * w[[d, j]] / psi[d];
                    }
                    wzw[[i, j]] = s;
                }
            }
            for i in 0..k {
                wzw[[i, i]] = wzw[[i, i]] + F::one();
            }
            let sigma_z = cholesky_inv(&wzw).map_err(|_| FerroError::NumericalInstability {
                message: "FactorAnalysis: (I + W^T Ψ⁻¹ W) is singular".into(),
            })?;

            // β = Σ_z W^T Ψ⁻¹   shape k × p
            let mut beta = Array2::<F>::zeros((k, p));
            for i in 0..k {
                for d in 0..p {
                    let mut s = F::zero();
                    for j in 0..k {
                        s = s + sigma_z[[i, j]] * w[[d, j]];
                    }
                    beta[[i, d]] = s / psi[d];
                }
            }

            // E[Z | X] = β X_c^T   shape k × n
            let ez = beta.dot(&xc.t()); // k × n

            // Σ_i E[z_i z_i^T | X] = n * Σ_z + Σ_i e_i e_i^T   shape k × k
            // (kept as the SUM over samples, as required by the M-step below).
            let ezz_t_sum = sigma_z.mapv(|v| v * n_f) + ez.dot(&ez.t()); // k × k

            // --- M-step --------------------------------------------------------
            // W_new = (Σ_i x_i e_i^T) (Σ_i E[z_i z_i^T])⁻¹
            //       = X_c^T E[Z|X]^T * (n Σ_z + E[Z|X] E[Z|X]^T)⁻¹

            // X_c^T E[Z|X]^T: xc^T is p×n, ez^T is n×k → result is p×k
            let xc_ez_t = xc.t().dot(&ez.t()); // p × k

            // ezz_t_sum is k × k
            let ezz_t_inv =
                cholesky_inv(&ezz_t_sum).map_err(|_| FerroError::NumericalInstability {
                    message: "FactorAnalysis: E[ZZ^T] is singular in M-step".into(),
                })?;

            let w_new = xc_ez_t.dot(&ezz_t_inv); // p × k

            // ψ_new[d] = (1/n) Σ_i x_id² - w_new[d,:] · (1/n) Σ_i e_i x_id
            //          = S[d,d] - w_new[d,:] @ ((1/n) ez @ xc[:, d])

            let mut psi_new = Array1::<F>::zeros(p);
            for d in 0..p {
                // Sample variance of feature d (S[d,d], mean already removed).
                let var_d = xc
                    .column(d)
                    .iter()
                    .copied()
                    .map(|v| v * v)
                    .fold(F::zero(), |a, b| a + b)
                    / n_f;
                // (1/n) ez @ xc[:, d] — a k-vector.
                let mut ez_xd = Array1::<F>::zeros(k);
                for kk in 0..k {
                    let s = (0..n_samples)
                        .map(|i| ez[[kk, i]] * xc[[i, d]])
                        .fold(F::zero(), |a, b| a + b);
                    ez_xd[kk] = s / n_f;
                }
                let wd = w_new.row(d);
                let corr = wd
                    .iter()
                    .zip(ez_xd.iter())
                    .map(|(&wi, &ei)| wi * ei)
                    .fold(F::zero(), |a, b| a + b);
                let psi_d = var_d - corr;
                // Floor at 1e-6 so ψ stays strictly positive.
                psi_new[d] = if psi_d > F::from(1e-6).unwrap() {
                    psi_d
                } else {
                    F::from(1e-6).unwrap()
                };
            }

            w = w_new;
            psi = psi_new;

            // --- Convergence check ------------------------------------------
            let ll = compute_log_likelihood(&xc, &w, &psi);
            let ll_change = (ll - prev_ll).abs();
            n_iter = iter + 1;
            // Require at least one completed iteration before converging.
            if ll_change < tol_f && iter > 0 {
                prev_ll = ll;
                break;
            }
            prev_ll = ll;
        }

        Ok(FittedFactorAnalysis {
            components: w,
            noise_variance: psi,
            mean,
            n_iter,
            log_likelihood: prev_ll,
        })
    }
}
553
554// ---------------------------------------------------------------------------
555// Transform — compute factor scores
556// ---------------------------------------------------------------------------
557
558impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedFactorAnalysis<F> {
559    type Output = Array2<F>;
560    type Error = FerroError;
561
562    /// Compute factor scores: `E[Z | X] = Σ_z W^T Ψ⁻¹ (X - μ)^T`.
563    ///
564    /// Returns an array of shape `(n_samples, n_components)`.
565    ///
566    /// # Errors
567    ///
568    /// Returns [`FerroError::ShapeMismatch`] if the number of columns in `x`
569    /// does not match the model.
570    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
571        let n_features = self.mean.len();
572        if x.ncols() != n_features {
573            return Err(FerroError::ShapeMismatch {
574                expected: vec![x.nrows(), n_features],
575                actual: vec![x.nrows(), x.ncols()],
576                context: "FittedFactorAnalysis::transform".into(),
577            });
578        }
579        let (n_samples, _) = x.dim();
580        let k = self.components.ncols();
581
582        // Centre.
583        let mut xc = x.to_owned();
584        for mut row in xc.rows_mut() {
585            for (v, &m) in row.iter_mut().zip(self.mean.iter()) {
586                *v = *v - m;
587            }
588        }
589
590        // Σ_z = (I + W^T Ψ⁻¹ W)⁻¹
591        let mut wzw = Array2::<F>::zeros((k, k));
592        for i in 0..k {
593            for j in 0..k {
594                let mut s = F::zero();
595                for d in 0..n_features {
596                    s = s + self.components[[d, i]] * self.components[[d, j]]
597                        / self.noise_variance[d];
598                }
599                wzw[[i, j]] = s;
600            }
601        }
602        for i in 0..k {
603            wzw[[i, i]] = wzw[[i, i]] + F::one();
604        }
605        let sigma_z = cholesky_inv(&wzw).map_err(|_| FerroError::NumericalInstability {
606            message: "FittedFactorAnalysis::transform: (I + W^T Ψ⁻¹ W) is singular".into(),
607        })?;
608
609        // β = Σ_z W^T Ψ⁻¹  (k × p)
610        let mut beta = Array2::<F>::zeros((k, n_features));
611        for i in 0..k {
612            for d in 0..n_features {
613                let mut s = F::zero();
614                for j in 0..k {
615                    s = s + sigma_z[[i, j]] * self.components[[d, j]];
616                }
617                beta[[i, d]] = s / self.noise_variance[d];
618            }
619        }
620
621        // scores = (β @ X_c^T)^T  (n × k)
622        let ez = beta.dot(&xc.t()); // k × n
623        let scores = ez.t().to_owned(); // n × k
624        assert_eq!(scores.dim(), (n_samples, k));
625        Ok(scores)
626    }
627}
628
629// ---------------------------------------------------------------------------
630// Pipeline integration
631// ---------------------------------------------------------------------------
632
633impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for FactorAnalysis<F> {
634    /// Fit using the pipeline interface (ignores `y`).
635    ///
636    /// # Errors
637    ///
638    /// Propagates errors from [`Fit::fit`].
639    fn fit_pipeline(
640        &self,
641        x: &Array2<F>,
642        _y: &Array1<F>,
643    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
644        let fitted = self.fit(x, &())?;
645        Ok(Box::new(fitted))
646    }
647}
648
impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedFactorAnalysis<F> {
    /// Transform via the pipeline interface; delegates to
    /// [`Transform::transform`].
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Transform::transform`].
    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        self.transform(x)
    }
}
659
660// ---------------------------------------------------------------------------
661// Tests
662// ---------------------------------------------------------------------------
663
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array2;

    /// 10 samples × 4 features with an approximate latent structure.
    fn simple_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (10, 4),
            vec![
                1.0, 2.0, 1.5, 3.0, 1.1, 2.1, 1.6, 3.1, 0.9, 1.9, 1.4, 2.9, 2.0, 4.0, 3.0, 6.0,
                2.1, 4.1, 3.1, 6.1, 1.9, 3.9, 2.9, 5.9, 0.5, 1.0, 0.7, 1.5, 0.4, 0.9, 0.6, 1.4,
                0.6, 1.1, 0.8, 1.6, 1.5, 3.0, 2.2, 4.5,
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_fa_fit_returns_fitted() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(2).fit(&data, &()).unwrap();
        assert_eq!(model.components().dim(), (4, 2));
    }

    #[test]
    fn test_fa_transform_shape() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(2).fit(&data, &()).unwrap();
        let scores = model.transform(&data).unwrap();
        assert_eq!(scores.dim(), (10, 2));
    }

    #[test]
    fn test_fa_transform_new_data() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(1).fit(&data, &()).unwrap();
        let unseen = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 1.5, 3.0, 2.0, 4.0, 3.0, 6.0, 0.5, 1.0, 0.7, 1.5],
        )
        .unwrap();
        let scores = model.transform(&unseen).unwrap();
        assert_eq!(scores.dim(), (3, 1));
    }

    #[test]
    fn test_fa_noise_variance_positive() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(1).fit(&data, &()).unwrap();
        for &v in model.noise_variance().iter() {
            assert!(v > 0.0, "noise variance must be positive, got {v}");
        }
    }

    #[test]
    fn test_fa_mean_shape() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(1).fit(&data, &()).unwrap();
        assert_eq!(model.mean().len(), 4);
    }

    #[test]
    fn test_fa_n_iter_positive() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(1).fit(&data, &()).unwrap();
        assert!(model.n_iter() >= 1);
    }

    #[test]
    fn test_fa_log_likelihood_finite() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(1).fit(&data, &()).unwrap();
        assert!(model.log_likelihood().is_finite());
    }

    #[test]
    fn test_fa_error_zero_components() {
        let data = simple_data();
        assert!(FactorAnalysis::<f64>::new(0).fit(&data, &()).is_err());
    }

    #[test]
    fn test_fa_error_too_many_components() {
        // 10 components > n_features = 4 must be rejected.
        let data = simple_data();
        assert!(FactorAnalysis::<f64>::new(10).fit(&data, &()).is_err());
    }

    #[test]
    fn test_fa_error_insufficient_samples() {
        let single = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(FactorAnalysis::<f64>::new(1).fit(&single, &()).is_err());
    }

    #[test]
    fn test_fa_transform_shape_mismatch() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(1).fit(&data, &()).unwrap();
        let wrong_width = Array2::<f64>::zeros((3, 7));
        assert!(model.transform(&wrong_width).is_err());
    }

    #[test]
    fn test_fa_reproducible_with_seed() {
        let data = simple_data();
        let run_a = FactorAnalysis::<f64>::new(2)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        let run_b = FactorAnalysis::<f64>::new(2)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        for (a, b) in run_a.components().iter().zip(run_b.components().iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_fa_different_seeds_differ() {
        let data = simple_data();
        let run_a = FactorAnalysis::<f64>::new(2)
            .with_random_state(0)
            .with_max_iter(1)
            .fit(&data, &())
            .unwrap();
        let run_b = FactorAnalysis::<f64>::new(2)
            .with_random_state(99)
            .with_max_iter(1)
            .fit(&data, &())
            .unwrap();
        // After 1 iteration with different seeds the components may differ;
        // we only require that computing the difference does not panic.
        let diff: f64 = run_a
            .components()
            .iter()
            .zip(run_b.components().iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        let _ = diff;
    }

    #[test]
    fn test_fa_components_accessor() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(2).fit(&data, &()).unwrap();
        assert_eq!(model.components().ncols(), 2);
        assert_eq!(model.components().nrows(), 4);
    }

    #[test]
    fn test_fa_n_components_getter() {
        assert_eq!(FactorAnalysis::<f64>::new(3).n_components(), 3);
    }

    #[test]
    fn test_fa_pipeline_transformer() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let data = simple_data();
        let targets = Array1::<f64>::zeros(10);
        let fitted = FactorAnalysis::<f64>::new(2)
            .fit_pipeline(&data, &targets)
            .unwrap();
        let out = fitted.transform_pipeline(&data).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_fa_scores_not_all_zero() {
        let data = simple_data();
        let model = FactorAnalysis::<f64>::new(2).fit(&data, &()).unwrap();
        let scores = model.transform(&data).unwrap();
        let magnitude: f64 = scores.iter().map(|v| v.abs()).sum();
        assert!(magnitude > 0.0, "Factor scores should not all be zero");
    }
}
849}