use ferrolearn_core::error::FerroError;
use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
use ferrolearn_core::traits::{Fit, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
use rand::SeedableRng;
use rand_distr::{Distribution, StandardNormal};
/// Factor Analysis estimator fitted with expectation–maximisation (EM).
///
/// The model treats each centred sample as `x = W z + ε`, with latent factors
/// `z ~ N(0, I_k)` and independent per-feature noise `ε ~ N(0, diag(ψ))`, so the
/// data covariance is `W Wᵀ + diag(ψ)`. Configure the estimator with the builder
/// methods, then call `fit` to obtain a `FittedFactorAnalysis`.
#[derive(Debug, Clone)]
pub struct FactorAnalysis<F> {
// Number of latent factors `k`; the fitted loading matrix has this many columns.
n_components: usize,
// Upper bound on the number of EM iterations run by `fit`.
max_iter: usize,
// `fit` stops early once the absolute change in log-likelihood between two
// consecutive iterations drops below this value.
tol: f64,
// Seed for the RNG that draws the initial loadings; `fit` falls back to 42 when unset.
random_state: Option<u64>,
// Ties the estimator to the float type `F` without storing a value of it.
_marker: std::marker::PhantomData<F>,
}
impl<F: Float + Send + Sync + 'static> FactorAnalysis<F> {
    /// Creates an estimator that extracts `n_components` latent factors.
    ///
    /// Starts with a limit of 1000 EM iterations, a log-likelihood tolerance of
    /// `1e-3`, and no explicit random seed.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        let (max_iter, tol) = (1000, 1e-3);
        Self {
            n_components,
            max_iter,
            tol,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Returns a copy of the estimator with a different EM iteration limit.
    #[must_use]
    pub fn with_max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Returns a copy of the estimator with a different convergence tolerance
    /// (absolute change in log-likelihood between consecutive iterations).
    #[must_use]
    pub fn with_tol(self, tol: f64) -> Self {
        Self { tol, ..self }
    }

    /// Returns a copy of the estimator whose initial loadings are drawn with `seed`.
    #[must_use]
    pub fn with_random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Number of latent factors this estimator will extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
}
impl<F: Float + Send + Sync + 'static> Default for FactorAnalysis<F> {
fn default() -> Self {
Self::new(1)
}
}
/// Result of fitting a `FactorAnalysis` model.
///
/// Holds the learned loadings, noise variances and training mean. It can project
/// samples onto the latent factors (`transform`) and map factor scores back to
/// feature space (`inverse_transform`).
#[derive(Debug, Clone)]
pub struct FittedFactorAnalysis<F> {
// Loading matrix `W` with shape `(n_features, n_components)`.
components: Array2<F>,
// Per-feature noise variances (the diagonal of `Ψ`), length `n_features`.
// During fitting each entry is floored at `1e-6`.
noise_variance: Array1<F>,
// Per-feature mean of the training data; subtracted before projecting.
mean: Array1<F>,
// Number of EM iterations that `fit` actually ran.
n_iter: usize,
// Log-likelihood of the centred training data, taken from the last EM iteration.
log_likelihood: F,
}
impl<F: Float + Send + Sync + 'static> FittedFactorAnalysis<F> {
    /// Loading matrix `W`, shaped `(n_features, n_components)`.
    #[must_use]
    pub fn components(&self) -> &Array2<F> {
        &self.components
    }

    /// Per-feature noise variances (diagonal of `Ψ`).
    #[must_use]
    pub fn noise_variance(&self) -> &Array1<F> {
        &self.noise_variance
    }

    /// Per-feature mean learned from the training data.
    #[must_use]
    pub fn mean(&self) -> &Array1<F> {
        &self.mean
    }

    /// Number of EM iterations performed during fitting.
    #[must_use]
    pub fn n_iter(&self) -> usize {
        self.n_iter
    }

    /// Log-likelihood of the training data under the fitted parameters.
    #[must_use]
    pub fn log_likelihood(&self) -> F {
        self.log_likelihood
    }

    /// Maps factor scores back to feature space as `z · Wᵀ + mean`.
    ///
    /// `z` must have one column per latent factor. The result has one row per row
    /// of `z` and one column per original feature.
    ///
    /// # Errors
    ///
    /// Returns `FerroError::ShapeMismatch` when `z` does not have exactly
    /// `n_components` columns.
    pub fn inverse_transform(&self, z: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let expected_cols = self.components.ncols();
        let (rows, cols) = z.dim();
        if cols != expected_cols {
            return Err(FerroError::ShapeMismatch {
                expected: vec![rows, expected_cols],
                actual: vec![rows, cols],
                context: "FittedFactorAnalysis::inverse_transform".into(),
            });
        }

        let mut reconstructed = z.dot(&self.components.t());
        // Shift every reconstructed row back by the training mean.
        reconstructed.rows_mut().into_iter().for_each(|mut row| {
            row.iter_mut()
                .zip(self.mean.iter())
                .for_each(|(value, &offset)| *value = *value + offset);
        });
        Ok(reconstructed)
    }
}
fn cholesky_inv<F: Float>(a: &Array2<F>) -> Result<Array2<F>, FerroError> {
let n = a.nrows();
let mut l = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..=i {
let mut s = a[[i, j]];
for k in 0..j {
s = s - l[[i, k]] * l[[j, k]];
}
if i == j {
if s <= F::zero() {
s = F::from(1e-10).unwrap();
}
l[[i, j]] = s.sqrt();
} else {
l[[i, j]] = s / l[[j, j]];
}
}
}
let mut l_inv = Array2::<F>::zeros((n, n));
for j in 0..n {
l_inv[[j, j]] = F::one() / l[[j, j]];
for i in (j + 1)..n {
let mut s = F::zero();
for k in j..i {
s = s + l[[i, k]] * l_inv[[k, j]];
}
l_inv[[i, j]] = -s / l[[i, i]];
}
}
let mut inv = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..n {
let mut s = F::zero();
let start = i.max(j);
for k in start..n {
s = s + l_inv[[k, i]] * l_inv[[k, j]];
}
inv[[i, j]] = s;
}
}
Ok(inv)
}
/// Log-likelihood of centred data under the Factor Analysis model.
///
/// The model is `x ~ N(0, Σ)` with `Σ = W Wᵀ + diag(ψ)`. The `p × p` matrix `Σ` is never
/// formed. Instead the matrix determinant lemma and the Woodbury identity are applied to
/// the `k × k` matrix `M = I + Wᵀ Ψ⁻¹ W`:
///
/// * `log|Σ| = log|M| + Σ_d log ψ_d`
/// * `xᵀ Σ⁻¹ x = xᵀ Ψ⁻¹ x − (Wᵀ Ψ⁻¹ x)ᵀ M⁻¹ (Wᵀ Ψ⁻¹ x)`
///
/// Arguments:
/// * `x_centered`: `n × p` data, already centred.
/// * `w`: `p × k` loading matrix.
/// * `psi`: length-`p` noise variances. Divisions use them as-is, so they are
///   presumably strictly positive; `fit` floors them at `1e-6`.
///
/// Returns `−n/2 · (p·log 2π + log|Σ| + (1/n) Σ_i x_iᵀ Σ⁻¹ x_i)`, with these special cases:
/// * `0` for an empty sample set, because the sum of zero log-densities is zero.
/// * `F::neg_infinity()` if `cholesky_inv` reports an error for `M`.
///
/// In the log-determinant terms only, non-positive Cholesky pivots and non-positive
/// `ψ_d` are clamped to `1e-30`.
fn compute_log_likelihood<F: Float + Send + Sync + 'static>(
    x_centered: &Array2<F>,
    w: &Array2<F>,
    psi: &Array1<F>,
) -> F {
    let (n, p) = x_centered.dim();
    if n == 0 {
        return F::zero();
    }
    let k = w.ncols();
    let two_pi = F::from(2.0 * std::f64::consts::PI).unwrap();
    let n_f = F::from(n).unwrap();
    let p_f = F::from(p).unwrap();

    // M = I + Wᵀ Ψ⁻¹ W  (k × k).
    let mut wtpsiw = Array2::<F>::zeros((k, k));
    for i in 0..k {
        for j in 0..k {
            let mut s = F::zero();
            for d in 0..p {
                s = s + w[[d, i]] * w[[d, j]] / psi[d];
            }
            wtpsiw[[i, j]] = s;
        }
    }
    for i in 0..k {
        wtpsiw[[i, i]] = wtpsiw[[i, i]] + F::one();
    }

    // log|M| = 2 · Σ_i log L_ii, where L is a Cholesky factor of M.
    let mut log_det_inner = F::zero();
    {
        let mut l = Array2::<F>::zeros((k, k));
        for i in 0..k {
            for j in 0..=i {
                let mut s = wtpsiw[[i, j]];
                for kk in 0..j {
                    s = s - l[[i, kk]] * l[[j, kk]];
                }
                if i == j {
                    s = if s > F::zero() {
                        s
                    } else {
                        F::from(1e-30).unwrap()
                    };
                    l[[i, j]] = s.sqrt();
                    log_det_inner = log_det_inner + l[[i, j]].ln();
                } else {
                    l[[i, j]] = s / l[[j, j]];
                }
            }
        }
        log_det_inner = log_det_inner * F::from(2.0).unwrap();
    }

    // log|Ψ| = Σ_d log ψ_d.
    let log_det_psi: F = psi
        .iter()
        .copied()
        .map(|v| {
            let v_clamped = if v > F::zero() {
                v
            } else {
                F::from(1e-30).unwrap()
            };
            v_clamped.ln()
        })
        .fold(F::zero(), |a, b| a + b);
    let log_det_sigma = log_det_inner + log_det_psi;

    let m_inv = match cholesky_inv(&wtpsiw) {
        Ok(inv) => inv,
        Err(_) => return F::neg_infinity(),
    };

    // Scratch buffers reused for every sample; each entry is overwritten before it is read.
    let mut psi_inv_x = Array1::<F>::zeros(p);
    let mut wtpx = Array1::<F>::zeros(k);
    let mut trace_sum = F::zero();
    for i in 0..n {
        // Ψ⁻¹ x and xᵀ Ψ⁻¹ x.
        let mut xpsiinvx = F::zero();
        for d in 0..p {
            psi_inv_x[d] = x_centered[[i, d]] / psi[d];
            xpsiinvx = xpsiinvx + x_centered[[i, d]] * psi_inv_x[d];
        }
        // Wᵀ Ψ⁻¹ x.
        for kk in 0..k {
            let mut s = F::zero();
            for d in 0..p {
                s = s + w[[d, kk]] * psi_inv_x[d];
            }
            wtpx[kk] = s;
        }
        // Woodbury correction (Wᵀ Ψ⁻¹ x)ᵀ M⁻¹ (Wᵀ Ψ⁻¹ x).
        let mut quad = F::zero();
        for ii in 0..k {
            let mut s = F::zero();
            for jj in 0..k {
                s = s + m_inv[[ii, jj]] * wtpx[jj];
            }
            quad = quad + wtpx[ii] * s;
        }
        trace_sum = trace_sum + xpsiinvx - quad;
    }
    let trace_term = trace_sum / n_f;
    let half = F::from(0.5).unwrap();
    -n_f * half * (p_f * two_pi.ln() + log_det_sigma + trace_term)
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for FactorAnalysis<F> {
type Fitted = FittedFactorAnalysis<F>;
type Error = FerroError;
fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedFactorAnalysis<F>, FerroError> {
let (n_samples, n_features) = x.dim();
if self.n_components == 0 {
return Err(FerroError::InvalidParameter {
name: "n_components".into(),
reason: "must be at least 1".into(),
});
}
if self.n_components > n_features {
return Err(FerroError::InvalidParameter {
name: "n_components".into(),
reason: format!(
"n_components ({}) exceeds n_features ({})",
self.n_components, n_features
),
});
}
if n_samples < 2 {
return Err(FerroError::InsufficientSamples {
required: 2,
actual: n_samples,
context: "FactorAnalysis requires at least 2 samples".into(),
});
}
let k = self.n_components;
let p = n_features;
let n_f = F::from(n_samples).unwrap();
let mut mean = Array1::<F>::zeros(p);
for j in 0..p {
let s = x.column(j).iter().copied().fold(F::zero(), |a, b| a + b);
mean[j] = s / n_f;
}
let mut xc = x.to_owned();
for mut row in xc.rows_mut() {
for (v, &m) in row.iter_mut().zip(mean.iter()) {
*v = *v - m;
}
}
let seed = self.random_state.unwrap_or(42);
let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(seed);
let std_normal = StandardNormal;
let mut w = Array2::<F>::zeros((p, k));
let scale = F::from(0.01).unwrap();
for i in 0..p {
for j in 0..k {
let v: f64 = std_normal.sample(&mut rng);
w[[i, j]] = F::from(v).unwrap() * scale;
}
}
let mut psi = Array1::<F>::from_elem(p, F::one());
let mut prev_ll = F::neg_infinity();
let mut n_iter = 0usize;
let tol_f = F::from(self.tol).unwrap();
for iter in 0..self.max_iter {
let mut wzw = Array2::<F>::zeros((k, k));
for i in 0..k {
for j in 0..k {
let mut s = F::zero();
for d in 0..p {
s = s + w[[d, i]] * w[[d, j]] / psi[d];
}
wzw[[i, j]] = s;
}
}
for i in 0..k {
wzw[[i, i]] = wzw[[i, i]] + F::one();
}
let sigma_z = cholesky_inv(&wzw).map_err(|_| FerroError::NumericalInstability {
message: "FactorAnalysis: (I + W^T Ψ⁻¹ W) is singular".into(),
})?;
let mut beta = Array2::<F>::zeros((k, p));
for i in 0..k {
for d in 0..p {
let mut s = F::zero();
for j in 0..k {
s = s + sigma_z[[i, j]] * w[[d, j]];
}
beta[[i, d]] = s / psi[d];
}
}
let ez = beta.dot(&xc.t());
let ezz_t_sum = sigma_z.mapv(|v| v * n_f) + ez.dot(&ez.t());
let xc_ez_t = xc.t().dot(&ez.t());
let ezz_t_inv =
cholesky_inv(&ezz_t_sum).map_err(|_| FerroError::NumericalInstability {
message: "FactorAnalysis: E[ZZ^T] is singular in M-step".into(),
})?;
let w_new = xc_ez_t.dot(&ezz_t_inv);
let mut psi_new = Array1::<F>::zeros(p);
for d in 0..p {
let var_d = xc
.column(d)
.iter()
.copied()
.map(|v| v * v)
.fold(F::zero(), |a, b| a + b)
/ n_f;
let mut ez_xd = Array1::<F>::zeros(k);
for kk in 0..k {
let s = (0..n_samples)
.map(|i| ez[[kk, i]] * xc[[i, d]])
.fold(F::zero(), |a, b| a + b);
ez_xd[kk] = s / n_f;
}
let wd = w_new.row(d);
let corr = wd
.iter()
.zip(ez_xd.iter())
.map(|(&wi, &ei)| wi * ei)
.fold(F::zero(), |a, b| a + b);
let psi_d = var_d - corr;
psi_new[d] = if psi_d > F::from(1e-6).unwrap() {
psi_d
} else {
F::from(1e-6).unwrap()
};
}
w = w_new;
psi = psi_new;
let ll = compute_log_likelihood(&xc, &w, &psi);
let ll_change = (ll - prev_ll).abs();
n_iter = iter + 1;
if ll_change < tol_f && iter > 0 {
prev_ll = ll;
break;
}
prev_ll = ll;
}
Ok(FittedFactorAnalysis {
components: w,
noise_variance: psi,
mean,
n_iter,
log_likelihood: prev_ll,
})
}
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedFactorAnalysis<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Projects samples onto the latent factors and returns their posterior means.
    ///
    /// Each row becomes `E[z | x] = (I + Wᵀ Ψ⁻¹ W)⁻¹ Wᵀ Ψ⁻¹ (x − mean)`. The output has
    /// shape `(n_samples, n_components)`.
    ///
    /// # Errors
    ///
    /// * `FerroError::ShapeMismatch` if `x` does not have `n_features` columns.
    /// * `FerroError::NumericalInstability` if `I + Wᵀ Ψ⁻¹ W` cannot be inverted.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let loadings = &self.components;
        let psi = &self.noise_variance;
        let n_features = self.mean.len();
        let (n_rows, n_cols) = x.dim();
        if n_cols != n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_rows, n_features],
                actual: vec![n_rows, n_cols],
                context: "FittedFactorAnalysis::transform".into(),
            });
        }
        let k = loadings.ncols();

        // Centre the input with the training mean.
        let mut centred = x.to_owned();
        centred.rows_mut().into_iter().for_each(|mut row| {
            row.iter_mut()
                .zip(self.mean.iter())
                .for_each(|(v, &m)| *v = *v - m);
        });

        // Posterior precision M = I + Wᵀ Ψ⁻¹ W.
        let mut precision = Array2::<F>::zeros((k, k));
        for ((i, j), cell) in precision.indexed_iter_mut() {
            let dot = (0..n_features).fold(F::zero(), |acc, d| {
                acc + loadings[[d, i]] * loadings[[d, j]] / psi[d]
            });
            *cell = if i == j { dot + F::one() } else { dot };
        }
        let sigma_z = cholesky_inv(&precision).map_err(|_| FerroError::NumericalInstability {
            message: "FittedFactorAnalysis::transform: (I + W^T Ψ⁻¹ W) is singular".into(),
        })?;

        // beta = Σ_z Wᵀ Ψ⁻¹, shaped (k, n_features).
        let beta = Array2::from_shape_fn((k, n_features), |(i, d)| {
            (0..k).fold(F::zero(), |acc, j| acc + sigma_z[[i, j]] * loadings[[d, j]]) / psi[d]
        });

        // (k × n) posterior means, returned sample-major as (n × k).
        let posterior_means = beta.dot(&centred.t());
        Ok(posterior_means.t().to_owned())
    }
}
impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for FactorAnalysis<F> {
    /// Pipeline adapter: fits on `x`, ignoring the target `_y`, and returns the
    /// fitted model boxed as a pipeline transformer.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        _y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
        self.fit(x, &())
            .map(|fitted| Box::new(fitted) as Box<dyn FittedPipelineTransformer<F>>)
    }
}
impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedFactorAnalysis<F> {
fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
self.transform(x)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array2;

    /// Ten samples with four features, shared by most tests.
    fn simple_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (10, 4),
            vec![
                1.0, 2.0, 1.5, 3.0, 1.1, 2.1, 1.6, 3.1, 0.9, 1.9, 1.4, 2.9, 2.0, 4.0, 3.0, 6.0,
                2.1, 4.1, 3.1, 6.1, 1.9, 3.9, 2.9, 5.9, 0.5, 1.0, 0.7, 1.5, 0.4, 0.9, 0.6, 1.4,
                0.6, 1.1, 0.8, 1.6, 1.5, 3.0, 2.2, 4.5,
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_fa_fit_returns_fitted() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (4, 2));
    }

    #[test]
    fn test_fa_transform_shape() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let scores = fitted.transform(&x).unwrap();
        assert_eq!(scores.dim(), (10, 2));
    }

    #[test]
    fn test_fa_transform_new_data() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let x_new = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 1.5, 3.0, 2.0, 4.0, 3.0, 6.0, 0.5, 1.0, 0.7, 1.5],
        )
        .unwrap();
        let scores = fitted.transform(&x_new).unwrap();
        assert_eq!(scores.dim(), (3, 1));
    }

    #[test]
    fn test_fa_noise_variance_positive() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        for &v in fitted.noise_variance() {
            assert!(v > 0.0, "noise variance must be positive, got {v}");
        }
    }

    #[test]
    fn test_fa_mean_shape() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.mean().len(), 4);
    }

    #[test]
    fn test_fa_n_iter_positive() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() >= 1);
    }

    #[test]
    fn test_fa_log_likelihood_finite() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert!(fitted.log_likelihood().is_finite());
    }

    #[test]
    fn test_fa_error_zero_components() {
        let fa = FactorAnalysis::<f64>::new(0);
        let x = simple_data();
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_fa_error_too_many_components() {
        let fa = FactorAnalysis::<f64>::new(10);
        let x = simple_data();
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_fa_error_insufficient_samples() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(fa.fit(&x, &()).is_err());
    }

    #[test]
    fn test_fa_transform_shape_mismatch() {
        let fa = FactorAnalysis::<f64>::new(1);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 7));
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_fa_reproducible_with_seed() {
        let fa1 = FactorAnalysis::<f64>::new(2).with_random_state(42);
        let fa2 = FactorAnalysis::<f64>::new(2).with_random_state(42);
        let x = simple_data();
        let f1 = fa1.fit(&x, &()).unwrap();
        let f2 = fa2.fit(&x, &()).unwrap();
        let c1 = f1.components();
        let c2 = f2.components();
        for (a, b) in c1.iter().zip(c2.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_fa_different_seeds_differ() {
        // After a single EM step the loadings still depend on the random start,
        // so two different seeds must give different components.
        let fa1 = FactorAnalysis::<f64>::new(2)
            .with_random_state(0)
            .with_max_iter(1);
        let fa2 = FactorAnalysis::<f64>::new(2)
            .with_random_state(99)
            .with_max_iter(1);
        let x = simple_data();
        let f1 = fa1.fit(&x, &()).unwrap();
        let f2 = fa2.fit(&x, &()).unwrap();
        let diff: f64 = f1
            .components()
            .iter()
            .zip(f2.components().iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-6,
            "different seeds should yield different components, total |diff| = {diff}"
        );
    }

    #[test]
    fn test_fa_components_accessor() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().ncols(), 2);
        assert_eq!(fitted.components().nrows(), 4);
    }

    #[test]
    fn test_fa_n_components_getter() {
        let fa = FactorAnalysis::<f64>::new(3);
        assert_eq!(fa.n_components(), 3);
    }

    #[test]
    fn test_fa_pipeline_transformer() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let y = Array1::<f64>::zeros(10);
        let fitted = fa.fit_pipeline(&x, &y).unwrap();
        let out = fitted.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_fa_scores_not_all_zero() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let scores = fitted.transform(&x).unwrap();
        let total: f64 = scores.iter().map(|v| v.abs()).sum();
        assert!(total > 0.0, "Factor scores should not all be zero");
    }

    #[test]
    fn test_fa_inverse_transform_zero_scores_returns_mean() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let z = Array2::<f64>::zeros((3, 2));
        let reconstructed = fitted.inverse_transform(&z).unwrap();
        assert_eq!(reconstructed.dim(), (3, 4));
        for row in reconstructed.rows() {
            for (a, b) in row.iter().zip(fitted.mean().iter()) {
                assert_abs_diff_eq!(*a, *b, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_fa_inverse_transform_shape_mismatch() {
        let fa = FactorAnalysis::<f64>::new(2);
        let x = simple_data();
        let fitted = fa.fit(&x, &()).unwrap();
        let z_bad = Array2::<f64>::zeros((3, 3));
        assert!(fitted.inverse_transform(&z_bad).is_err());
    }

    #[test]
    fn test_cholesky_inv_known_matrix() {
        // [[4, 2], [2, 3]]⁻¹ = (1/8) · [[3, -2], [-2, 4]]
        let a = Array2::<f64>::from_shape_vec((2, 2), vec![4.0, 2.0, 2.0, 3.0]).unwrap();
        let inv = cholesky_inv(&a).unwrap();
        let expected = [[0.375, -0.25], [-0.25, 0.5]];
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(inv[[i, j]], expected[i][j], epsilon = 1e-12);
            }
        }
    }
}