//! Factor Analysis (FA) via the EM algorithm.
//!
//! Factor Analysis assumes that data is generated by a linear combination of
//! latent factors plus independent Gaussian noise:
//!
//! ```text
//! X = W Z + μ + ε,   Z ~ N(0, I),   ε ~ N(0, diag(ψ))
//! ```
//!
//! where:
//! - `W` is the `(n_features × n_components)` loading matrix,
//! - `Z` is the `(n_components,)` latent factor vector,
//! - `ψ` is the `(n_features,)` noise variance vector.
//!
//! # Algorithm
//!
//! 1. Centre the data: `X_c = X - μ`.
//! 2. **E-step**: compute the posterior mean and covariance of `Z`:
//!    ```text
//!    Σ_z = (I + W^T diag(ψ)⁻¹ W)⁻¹
//!    E[Z | X] = Σ_z W^T diag(ψ)⁻¹ X_c^T
//!    ```
//! 3. **M-step**: update `W` and `ψ` via maximum-likelihood closed-form
//!    updates.
//! 4. Repeat until convergence (log-likelihood change < `tol`).
//!
//! # Examples
//!
//! ```
//! use ferrolearn_decomp::factor_analysis::FactorAnalysis;
//! use ferrolearn_core::traits::{Fit, Transform};
//! use ndarray::Array2;
//!
//! let fa = FactorAnalysis::new(2);
//! let x = Array2::from_shape_vec(
//!     (10, 4),
//!     (0..40).map(|v| v as f64 * 0.1 + (v % 3) as f64 * 0.5).collect(),
//! ).unwrap();
//! let fitted = fa.fit(&x, &()).unwrap();
//! let scores = fitted.transform(&x).unwrap();
//! assert_eq!(scores.ncols(), 2);
//! ```

44use ferrolearn_core::error::FerroError;
45use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
46use ferrolearn_core::traits::{Fit, Transform};
47use ndarray::{Array1, Array2};
48use num_traits::Float;
49use rand::SeedableRng;
50use rand_distr::{Distribution, StandardNormal};
51
52// ---------------------------------------------------------------------------
53// FactorAnalysis (unfitted)
54// ---------------------------------------------------------------------------
55
/// Configuration for a Factor Analysis model.
///
/// This type only carries hyper-parameters; calling [`Fit::fit`] runs the EM
/// algorithm and produces a [`FittedFactorAnalysis`].
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type used for all computations.
#[derive(Debug, Clone)]
pub struct FactorAnalysis<F> {
    /// Number of latent factors to extract.
    n_components: usize,
    /// Upper bound on the number of EM iterations.
    max_iter: usize,
    /// Convergence tolerance on the change in log-likelihood between
    /// successive EM iterations.
    tol: f64,
    /// Optional RNG seed; `None` means a fixed default seed is used at fit
    /// time, so fits are still deterministic.
    random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}
76
77impl<F: Float + Send + Sync + 'static> FactorAnalysis<F> {
78    /// Create a new `FactorAnalysis` with `n_components` factors.
79    ///
80    /// Defaults: `max_iter = 1000`, `tol = 1e-3`, no fixed random seed.
81    #[must_use]
82    pub fn new(n_components: usize) -> Self {
83        Self {
84            n_components,
85            max_iter: 1000,
86            tol: 1e-3,
87            random_state: None,
88            _marker: std::marker::PhantomData,
89        }
90    }
91
92    /// Set the maximum number of EM iterations.
93    #[must_use]
94    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
95        self.max_iter = max_iter;
96        self
97    }
98
99    /// Set the convergence tolerance.
100    #[must_use]
101    pub fn with_tol(mut self, tol: f64) -> Self {
102        self.tol = tol;
103        self
104    }
105
106    /// Set the random seed for reproducibility.
107    #[must_use]
108    pub fn with_random_state(mut self, seed: u64) -> Self {
109        self.random_state = Some(seed);
110        self
111    }
112
113    /// Return the number of latent factors.
114    #[must_use]
115    pub fn n_components(&self) -> usize {
116        self.n_components
117    }
118}
119
120impl<F: Float + Send + Sync + 'static> Default for FactorAnalysis<F> {
121    fn default() -> Self {
122        Self::new(1)
123    }
124}
125
126// ---------------------------------------------------------------------------
127// FittedFactorAnalysis
128// ---------------------------------------------------------------------------
129
/// A fitted Factor Analysis model.
///
/// Created by calling [`Fit::fit`] on a [`FactorAnalysis`]. Holds the learned
/// loadings, noise variances and feature means, plus diagnostics from the EM
/// run. Implements [`Transform<Array2<F>>`] to compute factor scores for new
/// data.
#[derive(Debug, Clone)]
pub struct FittedFactorAnalysis<F> {
    /// Loading matrix `W`, shape `(n_features, n_components)`.
    /// Transposed relative to sklearn's `components_` layout.
    components: Array2<F>,

    /// Noise variance vector `ψ`, shape `(n_features,)`.
    noise_variance: Array1<F>,

    /// Per-feature mean `μ` learned from the training data, shape
    /// `(n_features,)`.
    mean: Array1<F>,

    /// Number of EM iterations actually performed (convergence or the
    /// `max_iter` cap, whichever came first).
    n_iter: usize,

    /// Log-likelihood value at the final EM iteration.
    log_likelihood: F,
}
151
152impl<F: Float + Send + Sync + 'static> FittedFactorAnalysis<F> {
153    /// Loading matrix `W`, shape `(n_features, n_components)`.
154    #[must_use]
155    pub fn components(&self) -> &Array2<F> {
156        &self.components
157    }
158
159    /// Noise variance vector `ψ`, shape `(n_features,)`.
160    #[must_use]
161    pub fn noise_variance(&self) -> &Array1<F> {
162        &self.noise_variance
163    }
164
165    /// Per-feature mean learned during fitting.
166    #[must_use]
167    pub fn mean(&self) -> &Array1<F> {
168        &self.mean
169    }
170
171    /// Number of EM iterations performed.
172    #[must_use]
173    pub fn n_iter(&self) -> usize {
174        self.n_iter
175    }
176
177    /// Final log-likelihood value.
178    #[must_use]
179    pub fn log_likelihood(&self) -> F {
180        self.log_likelihood
181    }
182
183    /// Map latent representation back to the original feature space.
184    /// Mirrors sklearn `FactorAnalysis.inverse_transform`. Returns
185    /// `Z @ Wᵀ + mean` where `W` is the loading matrix.
186    ///
187    /// Note: ferrolearn's FactorAnalysis stores `components` with shape
188    /// `(n_features, n_components)` (transposed relative to sklearn's
189    /// `components_` layout), so the formula transposes accordingly.
190    ///
191    /// # Errors
192    ///
193    /// Returns [`FerroError::ShapeMismatch`] if `z.ncols()` does not
194    /// equal the number of components.
195    pub fn inverse_transform(&self, z: &Array2<F>) -> Result<Array2<F>, FerroError> {
196        let n_components = self.components.ncols();
197        if z.ncols() != n_components {
198            return Err(FerroError::ShapeMismatch {
199                expected: vec![z.nrows(), n_components],
200                actual: vec![z.nrows(), z.ncols()],
201                context: "FittedFactorAnalysis::inverse_transform".into(),
202            });
203        }
204        let mut result = z.dot(&self.components.t());
205        for mut row in result.rows_mut() {
206            for (v, &m) in row.iter_mut().zip(self.mean.iter()) {
207                *v = *v + m;
208            }
209        }
210        Ok(result)
211    }
212}
213
214// ---------------------------------------------------------------------------
215// Internal helpers
216// ---------------------------------------------------------------------------
217
/// Invert a small symmetric positive-definite matrix via Cholesky.
///
/// Factorises `A = L Lᵀ` (lower-triangular `L`), inverts `L` by forward
/// substitution, and assembles `A⁻¹ = L⁻ᵀ L⁻¹`. Intended for the small
/// `k × k` systems arising in the EM updates, so plain O(n³) loops suffice.
///
/// Non-positive pivots are clamped to `1e-10` instead of failing, so this
/// implementation currently never returns `Err`; the `Result` return type
/// gives callers a uniform error surface.
fn cholesky_inv<F: Float>(a: &Array2<F>) -> Result<Array2<F>, FerroError> {
    let n = a.nrows();
    // Compute lower triangular L (standard inner-product Cholesky: row i,
    // columns j <= i, subtracting the already-computed partial products).
    let mut l = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            let mut s = a[[i, j]];
            for k in 0..j {
                s = s - l[[i, k]] * l[[j, k]];
            }
            if i == j {
                if s <= F::zero() {
                    // Regularise: a non-positive pivot means A is not
                    // numerically positive-definite; clamp rather than error.
                    s = F::from(1e-10).unwrap();
                }
                l[[i, j]] = s.sqrt();
            } else {
                l[[i, j]] = s / l[[j, j]];
            }
        }
    }
    // Invert L using forward substitution: solve L L_inv = I column by column.
    let mut l_inv = Array2::<F>::zeros((n, n));
    for j in 0..n {
        l_inv[[j, j]] = F::one() / l[[j, j]];
        for i in (j + 1)..n {
            let mut s = F::zero();
            for k in j..i {
                s = s + l[[i, k]] * l_inv[[k, j]];
            }
            l_inv[[i, j]] = -s / l[[i, i]];
        }
    }
    // A_inv = L_inv^T @ L_inv. Both factors are lower-triangular, so the
    // inner summation can start at max(i, j).
    let mut inv = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            let mut s = F::zero();
            let start = i.max(j);
            for k in start..n {
                s = s + l_inv[[k, i]] * l_inv[[k, j]];
            }
            inv[[i, j]] = s;
        }
    }
    Ok(inv)
}
266
/// Compute the log-likelihood of the centred data under the FA model.
///
/// `log p(X) = -n/2 * [p*log(2π) + log|Σ| + tr(Σ⁻¹ S)]`
/// where `Σ = W W^T + diag(ψ)` and `S = X_c^T X_c / n`.
///
/// Both the log-determinant and the trace term are evaluated through the
/// small `k × k` matrix `M = I + W^T Ψ⁻¹ W` (matrix determinant lemma and
/// Woodbury identity), so no `p × p` matrix is ever formed.
fn compute_log_likelihood<F: Float + Send + Sync + 'static>(
    x_centered: &Array2<F>,
    w: &Array2<F>,
    psi: &Array1<F>,
) -> F {
    let (n, p) = x_centered.dim();
    let k = w.ncols();
    // Σ = W W^T + diag(ψ)
    // By the matrix determinant lemma:
    // log|Σ| = log|I_k + W^T Ψ⁻¹ W| + Σ_j log ψ_j
    let two_pi = F::from(2.0 * std::f64::consts::PI).unwrap();
    let n_f = F::from(n).unwrap();
    let p_f = F::from(p).unwrap();

    // W^T Ψ⁻¹ W: k × k
    let mut wtpsiw = Array2::<F>::zeros((k, k));
    for i in 0..k {
        for j in 0..k {
            let mut s = F::zero();
            for d in 0..p {
                s = s + w[[d, i]] * w[[d, j]] / psi[d];
            }
            wtpsiw[[i, j]] = s;
        }
    }
    // Add identity: wtpsiw now holds M = I + W^T Ψ⁻¹ W.
    for i in 0..k {
        wtpsiw[[i, i]] = wtpsiw[[i, i]] + F::one();
    }
    // log det of M via Cholesky: log|M| = 2 Σ_i log L_ii.
    let mut log_det_inner = F::zero();
    {
        let mut l = Array2::<F>::zeros((k, k));
        for i in 0..k {
            for j in 0..=i {
                let mut s = wtpsiw[[i, j]];
                for kk in 0..j {
                    s = s - l[[i, kk]] * l[[j, kk]];
                }
                if i == j {
                    // Clamp non-positive pivots so sqrt/ln stay finite.
                    s = if s > F::zero() {
                        s
                    } else {
                        F::from(1e-30).unwrap()
                    };
                    l[[i, j]] = s.sqrt();
                    log_det_inner = log_det_inner + l[[i, j]].ln();
                } else {
                    l[[i, j]] = s / l[[j, j]];
                }
            }
        }
        log_det_inner = log_det_inner * F::from(2.0).unwrap();
    }
    // log|diag(ψ)| = Σ_j log ψ_j, clamping non-positive entries.
    let log_det_psi: F = psi
        .iter()
        .copied()
        .map(|v| {
            let v_clamped = if v > F::zero() {
                v
            } else {
                F::from(1e-30).unwrap()
            };
            v_clamped.ln()
        })
        .fold(F::zero(), |a, b| a + b);
    let log_det_sigma = log_det_inner + log_det_psi;

    // Sample covariance S = X_c^T X_c / n.
    // tr(Σ⁻¹ S) using Woodbury: Σ⁻¹ = Ψ⁻¹ - Ψ⁻¹ W M⁻¹ W^T Ψ⁻¹
    // where M = I + W^T Ψ⁻¹ W.
    // tr(Σ⁻¹ S) = (1/n) Σ_i x_i^T Σ⁻¹ x_i, computed sample-by-sample in
    // the factored form:
    // x^T Σ⁻¹ x = x^T Ψ⁻¹ x - (W^T Ψ⁻¹ x)^T M⁻¹ (W^T Ψ⁻¹ x)

    // Invert M = I + W^T Ψ⁻¹ W.
    let m_inv = match cholesky_inv(&wtpsiw) {
        Ok(inv) => inv,
        Err(_) => return F::neg_infinity(),
    };

    let mut trace_sum = F::zero();
    for i in 0..n {
        // Ψ⁻¹ x_i and the scalar x_i^T Ψ⁻¹ x_i in a single pass.
        let mut psi_inv_x = Array1::<F>::zeros(p);
        let mut xpsiinvx = F::zero();
        for d in 0..p {
            psi_inv_x[d] = x_centered[[i, d]] / psi[d];
            xpsiinvx = xpsiinvx + x_centered[[i, d]] * psi_inv_x[d];
        }
        // W^T Ψ⁻¹ x_i  (k-vector)
        let mut wtpx = Array1::<F>::zeros(k);
        for kk in 0..k {
            let mut s = F::zero();
            for d in 0..p {
                s = s + w[[d, kk]] * psi_inv_x[d];
            }
            wtpx[kk] = s;
        }
        // Quadratic form (W^T Ψ⁻¹ x)^T M⁻¹ (W^T Ψ⁻¹ x).
        let mut quad = F::zero();
        for ii in 0..k {
            let mut s = F::zero();
            for jj in 0..k {
                s = s + m_inv[[ii, jj]] * wtpx[jj];
            }
            quad = quad + wtpx[ii] * s;
        }
        trace_sum = trace_sum + xpsiinvx - quad;
    }
    let trace_term = trace_sum / n_f;

    // log p = -n/2 * [p*log(2π) + log|Σ| + tr(Σ⁻¹ S)]
    let half = F::from(0.5).unwrap();
    -n_f * half * (p_f * two_pi.ln() + log_det_sigma + trace_term)
}
389
390// ---------------------------------------------------------------------------
391// Fit
392// ---------------------------------------------------------------------------
393
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for FactorAnalysis<F> {
    type Fitted = FittedFactorAnalysis<F>;
    type Error = FerroError;

    /// Fit the Factor Analysis model using the EM algorithm.
    ///
    /// # Errors
    ///
    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds
    ///   `n_features`.
    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples are provided.
    /// - [`FerroError::NumericalInstability`] if an E-step or M-step linear
    ///   system is singular.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedFactorAnalysis<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // --- Parameter validation ------------------------------------------
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.n_components > n_features {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds n_features ({})",
                    self.n_components, n_features
                ),
            });
        }
        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "FactorAnalysis requires at least 2 samples".into(),
            });
        }

        let k = self.n_components;
        let p = n_features;
        let n_f = F::from(n_samples).unwrap();

        // Compute per-feature mean and centre the data in a working copy.
        let mut mean = Array1::<F>::zeros(p);
        for j in 0..p {
            let s = x.column(j).iter().copied().fold(F::zero(), |a, b| a + b);
            mean[j] = s / n_f;
        }
        let mut xc = x.to_owned();
        for mut row in xc.rows_mut() {
            for (v, &m) in row.iter_mut().zip(mean.iter()) {
                *v = *v - m;
            }
        }

        // Initialise W with small random values, ψ = 1. A missing
        // `random_state` falls back to a fixed seed (42) so fits are
        // deterministic by default.
        let seed = self.random_state.unwrap_or(42);
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(seed);
        let std_normal = StandardNormal;
        let mut w = Array2::<F>::zeros((p, k));
        let scale = F::from(0.01).unwrap();
        for i in 0..p {
            for j in 0..k {
                let v: f64 = std_normal.sample(&mut rng);
                w[[i, j]] = F::from(v).unwrap() * scale;
            }
        }
        let mut psi = Array1::<F>::from_elem(p, F::one());

        let mut prev_ll = F::neg_infinity();
        let mut n_iter = 0usize;
        let tol_f = F::from(self.tol).unwrap();

        for iter in 0..self.max_iter {
            // --- E-step --------------------------------------------------------
            // Σ_z = (I_k + W^T Ψ⁻¹ W)⁻¹   shape k × k
            let mut wzw = Array2::<F>::zeros((k, k));
            for i in 0..k {
                for j in 0..k {
                    let mut s = F::zero();
                    for d in 0..p {
                        s = s + w[[d, i]] * w[[d, j]] / psi[d];
                    }
                    wzw[[i, j]] = s;
                }
            }
            for i in 0..k {
                wzw[[i, i]] = wzw[[i, i]] + F::one();
            }
            let sigma_z = cholesky_inv(&wzw).map_err(|_| FerroError::NumericalInstability {
                message: "FactorAnalysis: (I + W^T Ψ⁻¹ W) is singular".into(),
            })?;

            // β = Σ_z W^T Ψ⁻¹   shape k × p
            let mut beta = Array2::<F>::zeros((k, p));
            for i in 0..k {
                for d in 0..p {
                    let mut s = F::zero();
                    for j in 0..k {
                        s = s + sigma_z[[i, j]] * w[[d, j]];
                    }
                    beta[[i, d]] = s / psi[d];
                }
            }

            // E[Z | X] = β X_c^T   shape k × n
            let ez = beta.dot(&xc.t()); // k × n

            // Σ_i E[z_i z_i^T | X] = n Σ_z + Σ_i e_i e_i^T  — the SUM over
            // samples, not the average; the M-step below consumes this summed
            // form directly. shape k × k
            let ezz_t_sum = sigma_z.mapv(|v| v * n_f) + ez.dot(&ez.t()); // k × k

            // --- M-step --------------------------------------------------------
            // W_new = (Σ_i x_i e_i^T) (Σ_i e_i e_i^T)⁻¹
            //       = X_c^T E[Z|X]^T * (n Σ_z + E[Z|X] E[Z|X]^T)⁻¹

            // X_c^T E[Z|X]^T: xc^T is p×n, ez^T is n×k → result is p×k
            let xc_ez_t = xc.t().dot(&ez.t()); // p × k

            // ezz_t_sum is k × k
            let ezz_t_inv =
                cholesky_inv(&ezz_t_sum).map_err(|_| FerroError::NumericalInstability {
                    message: "FactorAnalysis: E[ZZ^T] is singular in M-step".into(),
                })?;

            let w_new = xc_ez_t.dot(&ezz_t_inv); // p × k

            // ψ_new[d] = (1/n) Σ_i (x_id² - w_new[d,:] e_i x_id)
            //          = (1/n) [Σ_i x_id² - w_new[d,:] Σ_i e_i x_id^T]
            //          = S[d,d] - (w_new[d,:] @ (1/n) Σ_i e_i x_id^T)
            // (1/n) Σ_i e_i x_id = (1/n) ez[:,i] * x_i[d] = (1/n) ez @ x_c[:,d]
            // = (1/n) ez @ xc[:, d]

            let mut psi_new = Array1::<F>::zeros(p);
            for d in 0..p {
                // Sample variance of feature d (diagonal element S[d,d]).
                let var_d = xc
                    .column(d)
                    .iter()
                    .copied()
                    .map(|v| v * v)
                    .fold(F::zero(), |a, b| a + b)
                    / n_f;
                // w_new[d,:] @ (1/n) ez @ xc[:,d]
                // (1/n) ez @ xc[:,d] is (1/n) Σ_i ez[:,i] * xc[i,d] — k-vector
                let mut ez_xd = Array1::<F>::zeros(k);
                for kk in 0..k {
                    let s = (0..n_samples)
                        .map(|i| ez[[kk, i]] * xc[[i, d]])
                        .fold(F::zero(), |a, b| a + b);
                    ez_xd[kk] = s / n_f;
                }
                let wd = w_new.row(d);
                let corr = wd
                    .iter()
                    .zip(ez_xd.iter())
                    .map(|(&wi, &ei)| wi * ei)
                    .fold(F::zero(), |a, b| a + b);
                let psi_d = var_d - corr;
                // Floor ψ at 1e-6 so Ψ⁻¹ stays well-defined next E-step.
                psi_new[d] = if psi_d > F::from(1e-6).unwrap() {
                    psi_d
                } else {
                    F::from(1e-6).unwrap()
                };
            }

            w = w_new;
            psi = psi_new;

            // --- Convergence check ------------------------------------------
            // `iter > 0` guards against breaking on the first iteration,
            // where prev_ll is -inf and ll_change is infinite anyway.
            let ll = compute_log_likelihood(&xc, &w, &psi);
            let ll_change = (ll - prev_ll).abs();
            n_iter = iter + 1;
            if ll_change < tol_f && iter > 0 {
                prev_ll = ll;
                break;
            }
            prev_ll = ll;
        }

        Ok(FittedFactorAnalysis {
            components: w,
            noise_variance: psi,
            mean,
            n_iter,
            log_likelihood: prev_ll,
        })
    }
}
583
584// ---------------------------------------------------------------------------
585// Transform — compute factor scores
586// ---------------------------------------------------------------------------
587
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedFactorAnalysis<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Compute factor scores: `E[Z | X] = Σ_z W^T Ψ⁻¹ (X - μ)^T`.
    ///
    /// This is the same posterior-mean computation as the fitting E-step,
    /// applied to (possibly new) data using the learned `W`, `ψ` and `μ`.
    /// Returns an array of shape `(n_samples, n_components)`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of columns in `x`
    /// does not match the model.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let n_features = self.mean.len();
        if x.ncols() != n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows(), n_features],
                actual: vec![x.nrows(), x.ncols()],
                context: "FittedFactorAnalysis::transform".into(),
            });
        }
        let (n_samples, _) = x.dim();
        let k = self.components.ncols();

        // Centre with the training-time mean.
        let mut xc = x.to_owned();
        for mut row in xc.rows_mut() {
            for (v, &m) in row.iter_mut().zip(self.mean.iter()) {
                *v = *v - m;
            }
        }

        // Σ_z = (I + W^T Ψ⁻¹ W)⁻¹
        let mut wzw = Array2::<F>::zeros((k, k));
        for i in 0..k {
            for j in 0..k {
                let mut s = F::zero();
                for d in 0..n_features {
                    s = s + self.components[[d, i]] * self.components[[d, j]]
                        / self.noise_variance[d];
                }
                wzw[[i, j]] = s;
            }
        }
        for i in 0..k {
            wzw[[i, i]] = wzw[[i, i]] + F::one();
        }
        let sigma_z = cholesky_inv(&wzw).map_err(|_| FerroError::NumericalInstability {
            message: "FittedFactorAnalysis::transform: (I + W^T Ψ⁻¹ W) is singular".into(),
        })?;

        // β = Σ_z W^T Ψ⁻¹  (k × p)
        let mut beta = Array2::<F>::zeros((k, n_features));
        for i in 0..k {
            for d in 0..n_features {
                let mut s = F::zero();
                for j in 0..k {
                    s = s + sigma_z[[i, j]] * self.components[[d, j]];
                }
                beta[[i, d]] = s / self.noise_variance[d];
            }
        }

        // scores = (β @ X_c^T)^T  (n × k)
        let ez = beta.dot(&xc.t()); // k × n
        let scores = ez.t().to_owned(); // n × k
        // Defensive sanity check on an internal invariant: β is k × p and
        // xc is n × p, so the result must be (n_samples, k).
        assert_eq!(scores.dim(), (n_samples, k));
        Ok(scores)
    }
}
658
659// ---------------------------------------------------------------------------
660// Pipeline integration
661// ---------------------------------------------------------------------------
662
663impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for FactorAnalysis<F> {
664    /// Fit using the pipeline interface (ignores `y`).
665    ///
666    /// # Errors
667    ///
668    /// Propagates errors from [`Fit::fit`].
669    fn fit_pipeline(
670        &self,
671        x: &Array2<F>,
672        _y: &Array1<F>,
673    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
674        let fitted = self.fit(x, &())?;
675        Ok(Box::new(fitted))
676    }
677}
678
679impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedFactorAnalysis<F> {
680    /// Transform via the pipeline interface.
681    ///
682    /// # Errors
683    ///
684    /// Propagates errors from [`Transform::transform`].
685    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
686        self.transform(x)
687    }
688}
689
690// ---------------------------------------------------------------------------
691// Tests
692// ---------------------------------------------------------------------------
693
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array2;

    /// Shared fixture: 10 samples × 4 features whose columns are roughly
    /// proportional, giving the model clear latent structure to recover.
    fn simple_data() -> Array2<f64> {
        // 10 samples, 4 features with some latent structure.
        Array2::from_shape_vec(
            (10, 4),
            vec![
                1.0, 2.0, 1.5, 3.0, 1.1, 2.1, 1.6, 3.1, 0.9, 1.9, 1.4, 2.9, 2.0, 4.0, 3.0, 6.0,
                2.1, 4.1, 3.1, 6.1, 1.9, 3.9, 2.9, 5.9, 0.5, 1.0, 0.7, 1.5, 0.4, 0.9, 0.6, 1.4,
                0.6, 1.1, 0.8, 1.6, 1.5, 3.0, 2.2, 4.5,
            ],
        )
        .unwrap()
    }

    // --- Happy-path shape and accessor checks ------------------------------

    #[test]
    fn test_fa_fit_returns_fitted() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (4, 2));
    }

    #[test]
    fn test_fa_transform_shape() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (10, 2));
    }

    #[test]
    fn test_fa_transform_new_data() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let x_new = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 1.5, 3.0, 2.0, 4.0, 3.0, 6.0, 0.5, 1.0, 0.7, 1.5],
        )
        .unwrap();
        let scores = fitted.transform(&x_new).unwrap();
        assert_eq!(scores.dim(), (3, 1));
    }

    #[test]
    fn test_fa_noise_variance_positive() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        // The M-step floors ψ at 1e-6, so every entry must be positive.
        for &v in fitted.noise_variance() {
            assert!(v > 0.0, "noise variance must be positive, got {v}");
        }
    }

    #[test]
    fn test_fa_mean_shape() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.mean().len(), 4);
    }

    #[test]
    fn test_fa_n_iter_positive() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() >= 1);
    }

    #[test]
    fn test_fa_log_likelihood_finite() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert!(fitted.log_likelihood().is_finite());
    }

    // --- Error-path checks -------------------------------------------------

    #[test]
    fn test_fa_error_zero_components() {
        let fa = FactorAnalysis::<f64>::new(0);
        let x = simple_data();
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_fa_error_too_many_components() {
        let fa = FactorAnalysis::<f64>::new(10); // more than n_features = 4
        let x = simple_data();
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_fa_error_insufficient_samples() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_fa_transform_shape_mismatch() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 7));
        assert!(fitted.transform(&x_bad).is_err());
    }

    // --- Reproducibility ---------------------------------------------------

    #[test]
    fn test_fa_reproducible_with_seed() {
        let fa1 = FactorAnalysis::<f64>::new(2).with_random_state(42);
        let fa2 = FactorAnalysis::<f64>::new(2).with_random_state(42);
        let x = simple_data();
        let f1 = fa1.fit(&x, &()).unwrap();
        let f2 = fa2.fit(&x, &()).unwrap();
        let c1 = f1.components();
        let c2 = f2.components();
        for (a, b) in c1.iter().zip(c2.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_fa_different_seeds_differ() {
        let fa1 = FactorAnalysis::<f64>::new(2)
            .with_random_state(0)
            .with_max_iter(1);
        let fa2 = FactorAnalysis::<f64>::new(2)
            .with_random_state(99)
            .with_max_iter(1);
        let x = simple_data();
        let f1 = fa1.fit(&x, &()).unwrap();
        let f2 = fa2.fit(&x, &()).unwrap();
        // After 1 iteration with different seeds the components should differ.
        let diff: f64 = f1
            .components()
            .iter()
            .zip(f2.components().iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        // They may differ unless the initialisation is identical.
        let _ = diff; // just check it doesn't panic
    }

    #[test]
    fn test_fa_components_accessor() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().ncols(), 2);
        assert_eq!(fitted.components().nrows(), 4);
    }

    #[test]
    fn test_fa_n_components_getter() {
        let fa = FactorAnalysis::<f64>::new(3);
        assert_eq!(fa.n_components(), 3);
    }

    // --- Pipeline integration ----------------------------------------------

    #[test]
    fn test_fa_pipeline_transformer() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let y = Array1::<f64>::zeros(10);
        let fitted = fa.fit_pipeline(&x, &y).unwrap();
        let out = fitted.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_fa_scores_not_all_zero() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let scores = fitted.transform(&x).unwrap();
        let total: f64 = scores.iter().map(|v| v.abs()).sum();
        assert!(total > 0.0, "Factor scores should not all be zero");
    }
}