Skip to main content

ferrolearn_decomp/
nmf.rs

1//! Non-negative Matrix Factorization (NMF).
2//!
3//! [`NMF`] decomposes a non-negative matrix `X` into two non-negative
4//! factors `W` and `H` such that `X ~ W * H`, where:
5//! - `X` has shape `(n_samples, n_features)`
6//! - `W` has shape `(n_samples, n_components)`
7//! - `H` has shape `(n_components, n_features)`
8//!
9//! # Algorithm
10//!
11//! Two solvers are supported:
12//!
13//! - **Multiplicative Update** (Lee & Seung, 2001): iteratively update `W` and
14//!   `H` using multiplicative rules that guarantee non-negativity.
15//! - **Coordinate Descent**: iteratively solve for each element of `W` and `H`
16//!   using closed-form coordinate-wise updates.
17//!
18//! # Initialization
19//!
20//! - **Random**: initialize `W` and `H` with random non-negative values.
21//! - **NNDSVD**: Non-Negative Double SVD, initializes `W` and `H` from a
22//!   truncated SVD of `X`, setting negative entries to zero.
23//!
24//! # Examples
25//!
26//! ```
27//! use ferrolearn_decomp::NMF;
28//! use ferrolearn_core::traits::{Fit, Transform};
29//! use ndarray::array;
30//!
31//! let nmf = NMF::<f64>::new(2);
32//! let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
33//! let fitted = nmf.fit(&x, &()).unwrap();
34//! let projected = fitted.transform(&x).unwrap();
35//! assert_eq!(projected.ncols(), 2);
36//! ```
37
38use ferrolearn_core::error::FerroError;
39use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
40use ferrolearn_core::traits::{Fit, Transform};
41use ndarray::{Array1, Array2};
42use num_traits::Float;
43use rand::SeedableRng;
44use rand_distr::{Distribution, Uniform};
45
46// ---------------------------------------------------------------------------
47// Configuration enums
48// ---------------------------------------------------------------------------
49
/// The solver algorithm for NMF.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NMFSolver {
    /// Multiplicative update rules (Lee & Seung, 2001).
    ///
    /// Rescales each entry of `W`/`H` by a non-negative ratio, so factors
    /// that start positive stay non-negative throughout.
    MultiplicativeUpdate,
    /// Coordinate descent.
    ///
    /// Solves each entry of `W`/`H` in closed form, clamping at zero to
    /// enforce non-negativity.
    CoordinateDescent,
}
58
/// The initialization strategy for NMF.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NMFInit {
    /// Random non-negative initialization: entries drawn uniformly from
    /// `[0, 1)` plus a tiny positive offset, seeded for reproducibility.
    Random,
    /// Non-Negative Double SVD initialization: seeds `H` from the top
    /// eigenvectors of `X^T X`, with negative entries clamped to zero.
    Nndsvd,
}
67
68// ---------------------------------------------------------------------------
69// NMF (unfitted)
70// ---------------------------------------------------------------------------
71
/// Non-negative Matrix Factorization configuration.
///
/// Holds hyperparameters for the NMF decomposition. Calling [`Fit::fit`]
/// computes the factorization and returns a [`FittedNMF`] that can
/// project new data via [`Transform::transform`].
///
/// Construct with [`NMF::new`] and customize via the `with_*` builder
/// methods.
#[derive(Debug, Clone)]
pub struct NMF<F> {
    /// Number of components to extract.
    n_components: usize,
    /// Maximum number of iterations for the solver.
    max_iter: usize,
    /// Convergence tolerance for the solver (absolute change in the
    /// Frobenius reconstruction error between successive iterations).
    tol: f64,
    /// The solver algorithm to use.
    solver: NMFSolver,
    /// The initialization strategy.
    init: NMFInit,
    /// Optional random seed for reproducibility (fitting falls back to
    /// seed 0 when unset, so results are deterministic by default).
    random_state: Option<u64>,
    /// Ties the element type `F` to the configuration without storing data.
    _marker: std::marker::PhantomData<F>,
}
93
94impl<F: Float + Send + Sync + 'static> NMF<F> {
95    /// Create a new `NMF` that extracts `n_components` components.
96    ///
97    /// Defaults: `max_iter=200`, `tol=1e-4`, solver=`MultiplicativeUpdate`,
98    /// init=`Random`, no random seed.
99    #[must_use]
100    pub fn new(n_components: usize) -> Self {
101        Self {
102            n_components,
103            max_iter: 200,
104            tol: 1e-4,
105            solver: NMFSolver::MultiplicativeUpdate,
106            init: NMFInit::Random,
107            random_state: None,
108            _marker: std::marker::PhantomData,
109        }
110    }
111
112    /// Set the maximum number of iterations.
113    #[must_use]
114    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
115        self.max_iter = max_iter;
116        self
117    }
118
119    /// Set the convergence tolerance.
120    #[must_use]
121    pub fn with_tol(mut self, tol: f64) -> Self {
122        self.tol = tol;
123        self
124    }
125
126    /// Set the solver algorithm.
127    #[must_use]
128    pub fn with_solver(mut self, solver: NMFSolver) -> Self {
129        self.solver = solver;
130        self
131    }
132
133    /// Set the initialization strategy.
134    #[must_use]
135    pub fn with_init(mut self, init: NMFInit) -> Self {
136        self.init = init;
137        self
138    }
139
140    /// Set the random seed for reproducible results.
141    #[must_use]
142    pub fn with_random_state(mut self, seed: u64) -> Self {
143        self.random_state = Some(seed);
144        self
145    }
146
147    /// Return the configured number of components.
148    #[must_use]
149    pub fn n_components(&self) -> usize {
150        self.n_components
151    }
152
153    /// Return the configured maximum iterations.
154    #[must_use]
155    pub fn max_iter(&self) -> usize {
156        self.max_iter
157    }
158
159    /// Return the configured tolerance.
160    #[must_use]
161    pub fn tol(&self) -> f64 {
162        self.tol
163    }
164
165    /// Return the configured solver.
166    #[must_use]
167    pub fn solver(&self) -> NMFSolver {
168        self.solver
169    }
170
171    /// Return the configured initialization strategy.
172    #[must_use]
173    pub fn init(&self) -> NMFInit {
174        self.init
175    }
176
177    /// Return the configured random state, if any.
178    #[must_use]
179    pub fn random_state(&self) -> Option<u64> {
180        self.random_state
181    }
182}
183
184// ---------------------------------------------------------------------------
185// FittedNMF
186// ---------------------------------------------------------------------------
187
/// A fitted NMF model holding the learned components and reconstruction error.
///
/// Created by calling [`Fit::fit`] on an [`NMF`]. Implements
/// [`Transform<Array2<F>>`] to project new data onto the learned components.
/// Only `H` is stored; the sample-side factor `W` is re-derived per call to
/// `transform`.
#[derive(Debug, Clone)]
pub struct FittedNMF<F> {
    /// Learned component matrix H, shape `(n_components, n_features)`.
    components_: Array2<F>,
    /// The Frobenius norm of the reconstruction error at convergence,
    /// computed on the training data.
    reconstruction_err_: F,
    /// Number of iterations performed.
    n_iter_: usize,
}
201
impl<F: Float + Send + Sync + 'static> FittedNMF<F> {
    /// Learned components (H matrix), shape `(n_components, n_features)`.
    #[must_use]
    pub fn components(&self) -> &Array2<F> {
        &self.components_
    }

    /// Frobenius norm of the reconstruction error `||X - W*H||_F`,
    /// measured on the training data at the end of fitting.
    #[must_use]
    pub fn reconstruction_err(&self) -> F {
        self.reconstruction_err_
    }

    /// Number of solver iterations performed during fitting (may be fewer
    /// than `max_iter` if the solver converged early).
    #[must_use]
    pub fn n_iter(&self) -> usize {
        self.n_iter_
    }
}
221
222// ---------------------------------------------------------------------------
223// Internal helpers
224// ---------------------------------------------------------------------------
225
226/// Compute the Frobenius norm of `X - W * H`.
227fn reconstruction_error<F: Float + 'static>(x: &Array2<F>, w: &Array2<F>, h: &Array2<F>) -> F {
228    let wh = w.dot(h);
229    let mut err = F::zero();
230    for (a, b) in x.iter().zip(wh.iter()) {
231        let diff = *a - *b;
232        err = err + diff * diff;
233    }
234    err.sqrt()
235}
236
237/// Small epsilon to prevent division by zero.
238fn eps<F: Float>() -> F {
239    F::from(1e-12).unwrap_or(F::epsilon())
240}
241
242/// Initialize W and H with random non-negative values.
243fn init_random<F: Float>(
244    n_samples: usize,
245    n_features: usize,
246    n_components: usize,
247    seed: u64,
248) -> (Array2<F>, Array2<F>) {
249    let mut rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(seed);
250    let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
251
252    let mut w = Array2::<F>::zeros((n_samples, n_components));
253    for elem in w.iter_mut() {
254        *elem = F::from(uniform.sample(&mut rng)).unwrap_or(F::zero()) + eps::<F>();
255    }
256
257    let mut h = Array2::<F>::zeros((n_components, n_features));
258    for elem in h.iter_mut() {
259        *elem = F::from(uniform.sample(&mut rng)).unwrap_or(F::zero()) + eps::<F>();
260    }
261
262    (w, h)
263}
264
/// NNDSVD initialization: compute a truncated SVD-like initialization.
///
/// Uses a simple approach: compute `X^T X`, eigendecompose, then use the
/// top eigenvectors to initialize H, and solve for W = X * H^+ (pseudoinverse).
///
/// # Errors
///
/// Propagates [`FerroError::ConvergenceFailure`] from the Jacobi
/// eigendecomposition.
fn init_nndsvd<F: Float + Send + Sync + 'static>(
    x: &Array2<F>,
    n_components: usize,
    seed: u64,
) -> Result<(Array2<F>, Array2<F>), FerroError> {
    let (n_samples, n_features) = x.dim();

    // Compute mean of X for scale.
    // `avg` = sqrt(|mean of X|); used only to scale the random fallback rows
    // below so they land at a magnitude comparable to the data.
    let mut total = F::zero();
    for &v in x.iter() {
        total = total + v;
    }
    let avg = (total / F::from(n_samples * n_features).unwrap())
        .abs()
        .sqrt();
    let avg = if avg < eps::<F>() { F::one() } else { avg };

    // Compute X^T X. Its eigenvectors are the right singular vectors of X.
    let xtx = x.t().dot(x);

    // Eigendecompose with Jacobi.
    let max_iter = n_features * n_features * 100 + 1000;
    let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&xtx, max_iter)?;

    // Sort eigenvalues descending.
    let mut indices: Vec<usize> = (0..n_features).collect();
    indices.sort_by(|&a, &b| {
        eigenvalues[b]
            .partial_cmp(&eigenvalues[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Build H from top eigenvectors (as rows), clamp negatives to zero.
    // NOTE(review): only the positive half of each eigenvector is kept.
    // Classical NNDSVD compares the positive and negative halves and keeps
    // the dominant one (eigenvector signs are arbitrary) — confirm the
    // simpler clamp here is intentional; the all-zero fallback below papers
    // over the case where a vector is entirely non-positive.
    let mut h = Array2::<F>::zeros((n_components, n_features));
    for (k, &idx) in indices.iter().take(n_components).enumerate() {
        for j in 0..n_features {
            let val = eigenvectors[[j, idx]];
            h[[k, j]] = if val > F::zero() { val } else { F::zero() };
        }
        // Ensure row is not all zeros.
        let row_sum: F = h.row(k).iter().copied().fold(F::zero(), |a, b| a + b);
        if row_sum < eps::<F>() {
            // Fall back to small random values.
            // A per-row seed offset keeps the fallback deterministic per k.
            let mut rng: rand::rngs::StdRng =
                SeedableRng::seed_from_u64(seed.wrapping_add(k as u64));
            let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
            for j in 0..n_features {
                h[[k, j]] = F::from(uniform.sample(&mut rng)).unwrap_or(F::zero()) * avg;
            }
        }
    }

    // Compute W = X * H^T * (H * H^T)^{-1}, but simpler: use multiplicative
    // update step starting from random W.
    let mut w = Array2::<F>::zeros((n_samples, n_components));
    // Solve W by least squares: W = X * H^T * pinv(H * H^T)
    // For simplicity, initialize W = X * H^T and normalize.
    // NOTE(review): the pseudoinverse step is in fact skipped — W is just
    // the projection X * H^T with non-positive entries floored to eps; the
    // downstream solver then refines it.
    let ht = h.t();
    let w_init = x.dot(&ht);
    for i in 0..n_samples {
        for k in 0..n_components {
            let val = w_init[[i, k]];
            w[[i, k]] = if val > F::zero() { val } else { eps::<F>() };
        }
    }

    Ok((w, h))
}
337
338/// Jacobi eigendecomposition for symmetric matrices.
339///
340/// Returns `(eigenvalues, eigenvectors)` where column `i` of `eigenvectors`
341/// is the eigenvector for `eigenvalues[i]`. Eigenvalues are NOT sorted.
342fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
343    a: &Array2<F>,
344    max_iter: usize,
345) -> Result<(Array1<F>, Array2<F>), FerroError> {
346    let n = a.nrows();
347    if n == 0 {
348        return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
349    }
350    if n == 1 {
351        let eigenvalues = Array1::from_vec(vec![a[[0, 0]]]);
352        let eigenvectors = Array2::from_shape_vec((1, 1), vec![F::one()]).unwrap();
353        return Ok((eigenvalues, eigenvectors));
354    }
355
356    let mut mat = a.to_owned();
357    let mut v = Array2::<F>::zeros((n, n));
358    for i in 0..n {
359        v[[i, i]] = F::one();
360    }
361
362    let tol = F::from(1e-12).unwrap_or(F::epsilon());
363
364    for _iteration in 0..max_iter {
365        let mut max_off = F::zero();
366        let mut p = 0;
367        let mut q = 1;
368        for i in 0..n {
369            for j in (i + 1)..n {
370                let val = mat[[i, j]].abs();
371                if val > max_off {
372                    max_off = val;
373                    p = i;
374                    q = j;
375                }
376            }
377        }
378
379        if max_off < tol {
380            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
381            return Ok((eigenvalues, v));
382        }
383
384        let app = mat[[p, p]];
385        let aqq = mat[[q, q]];
386        let apq = mat[[p, q]];
387
388        let theta = if (app - aqq).abs() < tol {
389            F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::one())
390        } else {
391            let tau = (aqq - app) / (F::from(2.0).unwrap() * apq);
392            let t = if tau >= F::zero() {
393                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
394            } else {
395                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
396            };
397            t.atan()
398        };
399
400        let c = theta.cos();
401        let s = theta.sin();
402
403        let mut new_mat = mat.clone();
404        for i in 0..n {
405            if i != p && i != q {
406                let mip = mat[[i, p]];
407                let miq = mat[[i, q]];
408                new_mat[[i, p]] = c * mip - s * miq;
409                new_mat[[p, i]] = new_mat[[i, p]];
410                new_mat[[i, q]] = s * mip + c * miq;
411                new_mat[[q, i]] = new_mat[[i, q]];
412            }
413        }
414
415        new_mat[[p, p]] = c * c * app - F::from(2.0).unwrap() * s * c * apq + s * s * aqq;
416        new_mat[[q, q]] = s * s * app + F::from(2.0).unwrap() * s * c * apq + c * c * aqq;
417        new_mat[[p, q]] = F::zero();
418        new_mat[[q, p]] = F::zero();
419
420        mat = new_mat;
421
422        for i in 0..n {
423            let vip = v[[i, p]];
424            let viq = v[[i, q]];
425            v[[i, p]] = c * vip - s * viq;
426            v[[i, q]] = s * vip + c * viq;
427        }
428    }
429
430    Err(FerroError::ConvergenceFailure {
431        iterations: max_iter,
432        message: "Jacobi eigendecomposition did not converge in NMF NNDSVD init".into(),
433    })
434}
435
436/// Multiplicative update solver (Lee & Seung, 2001).
437///
438/// Update rules:
439///   W <- W * (X H^T) / (W H H^T + eps)
440///   H <- H * (W^T X) / (W^T W H + eps)
441fn solve_multiplicative_update<F: Float + 'static>(
442    x: &Array2<F>,
443    w: &mut Array2<F>,
444    h: &mut Array2<F>,
445    max_iter: usize,
446    tol: f64,
447) -> usize {
448    let tol_f = F::from(tol).unwrap_or(F::epsilon());
449    let epsilon = eps::<F>();
450    let mut prev_err = reconstruction_error(x, w, h);
451
452    for iteration in 0..max_iter {
453        // Update H: H <- H * (W^T X) / (W^T W H + eps)
454        let wt = w.t();
455        let numerator_h = wt.dot(x);
456        let denominator_h = wt.dot(&*w).dot(&*h);
457
458        for (h_val, (num, den)) in h
459            .iter_mut()
460            .zip(numerator_h.iter().zip(denominator_h.iter()))
461        {
462            *h_val = *h_val * (*num / (*den + epsilon));
463        }
464
465        // Update W: W <- W * (X H^T) / (W H H^T + eps)
466        let ht = h.t();
467        let numerator_w = x.dot(&ht);
468        let denominator_w = w.dot(&*h).dot(&ht);
469
470        for (w_val, (num, den)) in w
471            .iter_mut()
472            .zip(numerator_w.iter().zip(denominator_w.iter()))
473        {
474            *w_val = *w_val * (*num / (*den + epsilon));
475        }
476
477        // Check convergence.
478        let err = reconstruction_error(x, w, h);
479        if (prev_err - err).abs() < tol_f {
480            return iteration + 1;
481        }
482        prev_err = err;
483    }
484
485    max_iter
486}
487
/// Coordinate descent solver.
///
/// Updates each element of H and W by solving a scalar minimization problem.
/// Each pass updates H component-by-component, then W, then compares the
/// change in Frobenius reconstruction error against `tol`. Returns the
/// number of passes performed.
fn solve_coordinate_descent<F: Float + 'static>(
    x: &Array2<F>,
    w: &mut Array2<F>,
    h: &mut Array2<F>,
    max_iter: usize,
    tol: f64,
) -> usize {
    let (n_samples, n_features) = x.dim();
    let n_components = h.nrows();
    let tol_f = F::from(tol).unwrap_or(F::epsilon());
    let epsilon = eps::<F>();
    let mut prev_err = reconstruction_error(x, w, h);

    for iteration in 0..max_iter {
        // Update H: for each k, j, solve for H[k,j]
        // H[k,j] = max(0, (W[:,k]^T * (X[:,j] - W * H[:,j] + W[:,k]*H[k,j])) / (W[:,k]^T W[:,k]))
        for k in 0..n_components {
            let mut wk_norm_sq = F::zero();
            for i in 0..n_samples {
                wk_norm_sq = wk_norm_sq + w[[i, k]] * w[[i, k]];
            }

            // A (near-)zero column of W carries no information about
            // H[k, :]; skip it rather than divide by ~0.
            if wk_norm_sq < epsilon {
                continue;
            }

            for j in 0..n_features {
                // Compute residual + current contribution.
                // `wh_ij` excludes component k, so `numerator` is the
                // correlation of W[:,k] with the residual left over for k.
                let mut numerator = F::zero();
                for i in 0..n_samples {
                    let mut wh_ij = F::zero();
                    for kk in 0..n_components {
                        if kk != k {
                            wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
                        }
                    }
                    numerator = numerator + w[[i, k]] * (x[[i, j]] - wh_ij);
                }

                // Non-negativity: clamp the unconstrained least-squares
                // solution at zero.
                h[[k, j]] = if numerator > F::zero() {
                    numerator / wk_norm_sq
                } else {
                    F::zero()
                };
            }
        }

        // Update W: for each i, k, solve for W[i,k] (mirror of the H step,
        // with the roles of rows and columns swapped).
        for k in 0..n_components {
            let mut hk_norm_sq = F::zero();
            for j in 0..n_features {
                hk_norm_sq = hk_norm_sq + h[[k, j]] * h[[k, j]];
            }

            if hk_norm_sq < epsilon {
                continue;
            }

            for i in 0..n_samples {
                let mut numerator = F::zero();
                for j in 0..n_features {
                    let mut wh_ij = F::zero();
                    for kk in 0..n_components {
                        if kk != k {
                            wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
                        }
                    }
                    numerator = numerator + h[[k, j]] * (x[[i, j]] - wh_ij);
                }

                w[[i, k]] = if numerator > F::zero() {
                    numerator / hk_norm_sq
                } else {
                    F::zero()
                };
            }
        }

        // Check convergence.
        let err = reconstruction_error(x, w, h);
        if (prev_err - err).abs() < tol_f {
            return iteration + 1;
        }
        prev_err = err;
    }

    max_iter
}
579
580// ---------------------------------------------------------------------------
581// Trait implementations
582// ---------------------------------------------------------------------------
583
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for NMF<F> {
    type Fitted = FittedNMF<F>;
    type Error = FerroError;

    /// Fit the NMF model by decomposing `X ~ W * H`.
    ///
    /// Validates the input, initializes `W`/`H` per the configured strategy,
    /// runs the configured solver, and records the final reconstruction
    /// error. Only `H` is retained in the fitted model; `W` for new data is
    /// re-derived by [`Transform::transform`].
    ///
    /// # Errors
    ///
    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds
    ///   the minimum of `n_samples` and `n_features`.
    /// - [`FerroError::InvalidParameter`] if any entry of `X` is negative.
    /// - [`FerroError::InsufficientSamples`] if there are zero samples.
    /// - [`FerroError::ConvergenceFailure`] if NNDSVD initialization fails.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedNMF<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "NMF::fit".into(),
            });
        }
        // A rank-`n_components` factorization requires the rank to fit
        // within both matrix dimensions.
        if self.n_components > n_samples.min(n_features) {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds min(n_samples, n_features) = {}",
                    self.n_components,
                    n_samples.min(n_features)
                ),
            });
        }

        // Check non-negativity.
        // NOTE(review): NaN entries pass this check (`NaN < 0` is false);
        // consider also rejecting non-finite input.
        for &val in x.iter() {
            if val < F::zero() {
                return Err(FerroError::InvalidParameter {
                    name: "X".into(),
                    reason: "NMF requires all entries in X to be non-negative".into(),
                });
            }
        }

        // Unseeded runs fall back to seed 0, so fitting is deterministic
        // by default.
        let seed = self.random_state.unwrap_or(0);

        // Initialize W and H.
        let (mut w, mut h) = match self.init {
            NMFInit::Random => init_random(n_samples, n_features, self.n_components, seed),
            NMFInit::Nndsvd => init_nndsvd(x, self.n_components, seed)?,
        };

        // Solve.
        let n_iter = match self.solver {
            NMFSolver::MultiplicativeUpdate => {
                solve_multiplicative_update(x, &mut w, &mut h, self.max_iter, self.tol)
            }
            NMFSolver::CoordinateDescent => {
                solve_coordinate_descent(x, &mut w, &mut h, self.max_iter, self.tol)
            }
        };

        let reconstruction_err = reconstruction_error(x, &w, &h);

        Ok(FittedNMF {
            components_: h,
            reconstruction_err_: reconstruction_err,
            n_iter_: n_iter,
        })
    }
}
661
662impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedNMF<F> {
663    type Output = Array2<F>;
664    type Error = FerroError;
665
666    /// Project data onto the learned NMF components.
667    ///
668    /// Solves for `W` in `X ~ W * H` using multiplicative updates with
669    /// `H` fixed to the learned components.
670    ///
671    /// # Errors
672    ///
673    /// - [`FerroError::ShapeMismatch`] if the number of columns does not
674    ///   match the number of features seen during fitting.
675    /// - [`FerroError::InvalidParameter`] if any entry of `X` is negative.
676    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
677        let n_features = self.components_.ncols();
678        if x.ncols() != n_features {
679            return Err(FerroError::ShapeMismatch {
680                expected: vec![x.nrows(), n_features],
681                actual: vec![x.nrows(), x.ncols()],
682                context: "FittedNMF::transform".into(),
683            });
684        }
685
686        // Check non-negativity.
687        for &val in x.iter() {
688            if val < F::zero() {
689                return Err(FerroError::InvalidParameter {
690                    name: "X".into(),
691                    reason: "NMF requires all entries in X to be non-negative".into(),
692                });
693            }
694        }
695
696        let n_samples = x.nrows();
697        let n_components = self.components_.nrows();
698        let epsilon = eps::<F>();
699
700        // Initialize W with uniform small values.
701        let mut w = Array2::<F>::zeros((n_samples, n_components));
702        let init_val = F::from(0.1).unwrap_or(F::one());
703        for elem in w.iter_mut() {
704            *elem = init_val;
705        }
706
707        // Run multiplicative updates with H fixed.
708        let h = &self.components_;
709        for _iter in 0..200 {
710            let wt_num = x.dot(&h.t());
711            let wt_den = w.dot(h).dot(&h.t());
712
713            for (w_val, (num, den)) in w.iter_mut().zip(wt_num.iter().zip(wt_den.iter())) {
714                *w_val = *w_val * (*num / (*den + epsilon));
715            }
716        }
717
718        Ok(w)
719    }
720}
721
722// ---------------------------------------------------------------------------
723// Pipeline integration (generic)
724// ---------------------------------------------------------------------------
725
726impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for NMF<F> {
727    /// Fit NMF using the pipeline interface.
728    ///
729    /// The `y` argument is ignored; NMF is unsupervised.
730    ///
731    /// # Errors
732    ///
733    /// Propagates errors from [`Fit::fit`].
734    fn fit_pipeline(
735        &self,
736        x: &Array2<F>,
737        _y: &Array1<F>,
738    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
739        let fitted = self.fit(x, &())?;
740        Ok(Box::new(fitted))
741    }
742}
743
744impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedNMF<F> {
745    /// Transform data using the pipeline interface.
746    ///
747    /// # Errors
748    ///
749    /// Propagates errors from [`Transform::transform`].
750    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
751        self.transform(x)
752    }
753}
754
755// ---------------------------------------------------------------------------
756// Tests
757// ---------------------------------------------------------------------------
758
759#[cfg(test)]
760mod tests {
761    use super::*;
762    use approx::assert_abs_diff_eq;
763    use ndarray::array;
764
    /// Helper: create a small non-negative dataset.
    ///
    /// 4 samples x 3 features, strictly positive; the rows are evenly
    /// spaced, so the matrix has rank 2 and factors well with 2 components.
    fn small_dataset() -> Array2<f64> {
        array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ]
    }
774
    /// Helper: create a larger non-negative dataset.
    ///
    /// 6 samples x 4 features with many exact zeros, exercising the solvers
    /// on sparse-ish block-structured data.
    fn medium_dataset() -> Array2<f64> {
        array![
            [5.0, 3.0, 0.0, 1.0],
            [4.0, 0.0, 0.0, 1.0],
            [1.0, 1.0, 0.0, 5.0],
            [1.0, 0.0, 0.0, 4.0],
            [0.0, 1.0, 5.0, 4.0],
            [0.0, 0.0, 4.0, 3.0],
        ]
    }
786
    #[test]
    fn test_nmf_basic_fit() {
        // H must have shape (n_components, n_features) = (2, 3).
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 3));
    }
794
    #[test]
    fn test_nmf_components_non_negative() {
        // The defining constraint of NMF: every entry of H is >= 0.
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        for &val in fitted.components().iter() {
            assert!(
                val >= 0.0,
                "component value should be non-negative, got {val}"
            );
        }
    }
807
    #[test]
    fn test_nmf_transform_dimensions() {
        // W must have shape (n_samples, n_components) = (4, 2).
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (4, 2));
    }
816
    #[test]
    fn test_nmf_transform_non_negative() {
        // The projected factor W must also be entirely non-negative.
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        let projected = fitted.transform(&x).unwrap();
        for &val in projected.iter() {
            assert!(val >= 0.0, "W value should be non-negative, got {val}");
        }
    }
827
    #[test]
    fn test_nmf_reconstruction_error_decreases() {
        // Same seed and init, different iteration budgets: the longer run
        // must not end with a worse error. The 1e-6 slack absorbs
        // floating-point noise.
        let nmf_few = NMF::<f64>::new(2).with_random_state(42).with_max_iter(10);
        let nmf_many = NMF::<f64>::new(2).with_random_state(42).with_max_iter(200);
        let x = small_dataset();
        let fitted_few = nmf_few.fit(&x, &()).unwrap();
        let fitted_many = nmf_many.fit(&x, &()).unwrap();
        assert!(
            fitted_many.reconstruction_err() <= fitted_few.reconstruction_err() + 1e-6,
            "more iterations should reduce error: few={}, many={}",
            fitted_few.reconstruction_err(),
            fitted_many.reconstruction_err()
        );
    }
842
    #[test]
    fn test_nmf_reconstruction_error_positive() {
        // The reconstruction error is a Frobenius norm, so never negative.
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert!(fitted.reconstruction_err() >= 0.0);
    }
850
    #[test]
    fn test_nmf_coordinate_descent_solver() {
        // The CD solver must produce correctly shaped, non-negative H.
        let nmf = NMF::<f64>::new(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_random_state(42);
        let x = medium_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
        for &val in fitted.components().iter() {
            assert!(val >= 0.0, "CD component should be non-negative, got {val}");
        }
    }
863
    #[test]
    fn test_nmf_nndsvd_init() {
        // NNDSVD initialization must still yield a valid non-negative H.
        let nmf = NMF::<f64>::new(2)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(42);
        let x = medium_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
        for &val in fitted.components().iter() {
            assert!(
                val >= 0.0,
                "NNDSVD component should be non-negative, got {val}"
            );
        }
    }
879
    #[test]
    fn test_nmf_cd_with_nndsvd() {
        // Solver and init options combine independently.
        let nmf = NMF::<f64>::new(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(42);
        let x = medium_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
    }
890
    #[test]
    fn test_nmf_invalid_n_components_zero() {
        // n_components == 0 is rejected at fit time.
        let nmf = NMF::<f64>::new(0);
        let x = small_dataset();
        assert!(nmf.fit(&x, &()).is_err());
    }
897
    #[test]
    fn test_nmf_invalid_n_components_too_large() {
        // 10 components > min(4 samples, 3 features) is rejected.
        let nmf = NMF::<f64>::new(10);
        let x = small_dataset(); // 4x3
        assert!(nmf.fit(&x, &()).is_err());
    }
904
    #[test]
    fn test_nmf_negative_input_rejected() {
        // Any negative entry in X is an error for NMF.
        let nmf = NMF::<f64>::new(1);
        let x = array![[1.0, -2.0], [3.0, 4.0]];
        assert!(nmf.fit(&x, &()).is_err());
    }
911
    #[test]
    fn test_nmf_transform_shape_mismatch() {
        // Fitted on 3 features; transforming 2-feature data must fail.
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        let x_bad = array![[1.0, 2.0]]; // wrong number of features
        assert!(fitted.transform(&x_bad).is_err());
    }
920
921    #[test]
922    fn test_nmf_transform_negative_rejected() {
923        let nmf = NMF::<f64>::new(2).with_random_state(42);
924        let x = small_dataset();
925        let fitted = nmf.fit(&x, &()).unwrap();
926        let x_neg = array![[1.0, -2.0, 3.0]];
927        assert!(fitted.transform(&x_neg).is_err());
928    }
929
930    #[test]
931    fn test_nmf_reproducibility() {
932        let nmf1 = NMF::<f64>::new(2).with_random_state(42);
933        let nmf2 = NMF::<f64>::new(2).with_random_state(42);
934        let x = small_dataset();
935        let fitted1 = nmf1.fit(&x, &()).unwrap();
936        let fitted2 = nmf2.fit(&x, &()).unwrap();
937        for (a, b) in fitted1.components().iter().zip(fitted2.components().iter()) {
938            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
939        }
940    }
941
942    #[test]
943    fn test_nmf_single_component() {
944        let nmf = NMF::<f64>::new(1).with_random_state(42);
945        let x = small_dataset();
946        let fitted = nmf.fit(&x, &()).unwrap();
947        assert_eq!(fitted.components().nrows(), 1);
948        let projected = fitted.transform(&x).unwrap();
949        assert_eq!(projected.ncols(), 1);
950    }
951
952    #[test]
953    fn test_nmf_n_iter_positive() {
954        let nmf = NMF::<f64>::new(2).with_random_state(42);
955        let x = small_dataset();
956        let fitted = nmf.fit(&x, &()).unwrap();
957        assert!(fitted.n_iter() > 0);
958    }
959
960    #[test]
961    fn test_nmf_getters() {
962        let nmf = NMF::<f64>::new(3)
963            .with_max_iter(100)
964            .with_tol(1e-5)
965            .with_solver(NMFSolver::CoordinateDescent)
966            .with_init(NMFInit::Nndsvd)
967            .with_random_state(99);
968        assert_eq!(nmf.n_components(), 3);
969        assert_eq!(nmf.max_iter(), 100);
970        assert_abs_diff_eq!(nmf.tol(), 1e-5);
971        assert_eq!(nmf.solver(), NMFSolver::CoordinateDescent);
972        assert_eq!(nmf.init(), NMFInit::Nndsvd);
973        assert_eq!(nmf.random_state(), Some(99));
974    }
975
976    #[test]
977    fn test_nmf_f32() {
978        let nmf = NMF::<f32>::new(1).with_random_state(42);
979        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
980        let fitted = nmf.fit(&x, &()).unwrap();
981        let projected = fitted.transform(&x).unwrap();
982        assert_eq!(projected.ncols(), 1);
983    }
984
985    #[test]
986    fn test_nmf_zero_entries() {
987        let nmf = NMF::<f64>::new(2).with_random_state(42);
988        let x = array![[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]];
989        let fitted = nmf.fit(&x, &()).unwrap();
990        assert_eq!(fitted.components().dim(), (2, 3));
991    }
992
993    #[test]
994    fn test_nmf_pipeline_integration() {
995        use ferrolearn_core::pipeline::{FittedPipelineEstimator, Pipeline, PipelineEstimator};
996        use ferrolearn_core::traits::Predict;
997
998        struct SumEstimator;
999
1000        impl PipelineEstimator<f64> for SumEstimator {
1001            fn fit_pipeline(
1002                &self,
1003                _x: &Array2<f64>,
1004                _y: &Array1<f64>,
1005            ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
1006                Ok(Box::new(FittedSumEstimator))
1007            }
1008        }
1009
1010        struct FittedSumEstimator;
1011
1012        impl FittedPipelineEstimator<f64> for FittedSumEstimator {
1013            fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
1014                let sums: Vec<f64> = x.rows().into_iter().map(|r| r.sum()).collect();
1015                Ok(Array1::from_vec(sums))
1016            }
1017        }
1018
1019        let pipeline = Pipeline::new()
1020            .transform_step("nmf", Box::new(NMF::<f64>::new(2).with_random_state(42)))
1021            .estimator_step("sum", Box::new(SumEstimator));
1022
1023        let x = small_dataset();
1024        let y = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
1025
1026        let fitted = pipeline.fit(&x, &y).unwrap();
1027        let preds = fitted.predict(&x).unwrap();
1028        assert_eq!(preds.len(), 4);
1029    }
1030
1031    #[test]
1032    fn test_nmf_medium_dataset_mu() {
1033        let nmf = NMF::<f64>::new(3)
1034            .with_solver(NMFSolver::MultiplicativeUpdate)
1035            .with_random_state(42)
1036            .with_max_iter(500);
1037        let x = medium_dataset();
1038        let fitted = nmf.fit(&x, &()).unwrap();
1039        assert_eq!(fitted.components().dim(), (3, 4));
1040        // Reconstruction error should be reasonable.
1041        assert!(
1042            fitted.reconstruction_err() < 10.0,
1043            "reconstruction error too large: {}",
1044            fitted.reconstruction_err()
1045        );
1046    }
1047
1048    #[test]
1049    fn test_nmf_insufficient_samples() {
1050        let nmf = NMF::<f64>::new(1);
1051        let x = Array2::<f64>::zeros((0, 3));
1052        assert!(nmf.fit(&x, &()).is_err());
1053    }
1054
1055    #[test]
1056    fn test_nmf_more_components_lower_error() {
1057        let nmf1 = NMF::<f64>::new(1).with_random_state(42).with_max_iter(300);
1058        let nmf2 = NMF::<f64>::new(2).with_random_state(42).with_max_iter(300);
1059        let x = medium_dataset();
1060        let fitted1 = nmf1.fit(&x, &()).unwrap();
1061        let fitted2 = nmf2.fit(&x, &()).unwrap();
1062        assert!(
1063            fitted2.reconstruction_err() <= fitted1.reconstruction_err() + 1e-6,
1064            "more components should reduce error: 1comp={}, 2comp={}",
1065            fitted1.reconstruction_err(),
1066            fitted2.reconstruction_err()
1067        );
1068    }
1069}