Skip to main content

ferrolearn_decomp/
nmf.rs

1//! Non-negative Matrix Factorization (NMF).
2//!
3//! [`NMF`] decomposes a non-negative matrix `X` into two non-negative
4//! factors `W` and `H` such that `X ~ W * H`, where:
5//! - `X` has shape `(n_samples, n_features)`
6//! - `W` has shape `(n_samples, n_components)`
7//! - `H` has shape `(n_components, n_features)`
8//!
9//! # Algorithm
10//!
11//! Two solvers are supported:
12//!
13//! - **Multiplicative Update** (Lee & Seung, 2001): iteratively update `W` and
14//!   `H` using multiplicative rules that guarantee non-negativity.
15//! - **Coordinate Descent**: iteratively solve for each element of `W` and `H`
16//!   using closed-form coordinate-wise updates.
17//!
18//! # Initialization
19//!
20//! - **Random**: initialize `W` and `H` with random non-negative values.
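//! - **NNDSVD**: an NNDSVD-style (Non-Negative Double SVD) scheme. `H` is
//!   built from the leading right singular vectors of `X` (computed as the
//!   eigenvectors of `X^T X`) with negative entries set to zero, and `W` is
//!   initialized as `X H^T`.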
23//!
24//! # Examples
25//!
26//! ```
27//! use ferrolearn_decomp::NMF;
28//! use ferrolearn_core::traits::{Fit, Transform};
29//! use ndarray::array;
30//!
31//! let nmf = NMF::<f64>::new(2);
32//! let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
33//! let fitted = nmf.fit(&x, &()).unwrap();
34//! let projected = fitted.transform(&x).unwrap();
35//! assert_eq!(projected.ncols(), 2);
36//! ```
37
38use ferrolearn_core::error::FerroError;
39use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
40use ferrolearn_core::traits::{Fit, Transform};
41use ndarray::{Array1, Array2};
42use num_traits::Float;
43use rand::SeedableRng;
44use rand_distr::{Distribution, Uniform};
45
46// ---------------------------------------------------------------------------
47// Configuration enums
48// ---------------------------------------------------------------------------
49
/// Which optimization algorithm [`NMF`] uses to fit the factors `W` and `H`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NMFSolver {
    /// Element-wise multiplicative update rules of Lee & Seung (2001).
    MultiplicativeUpdate,
    /// Cyclic coordinate descent over the individual entries of `H` and `W`.
    CoordinateDescent,
}
58
/// How [`NMF`] seeds `W` and `H` before the solver starts iterating.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NMFInit {
    /// Uniform random non-negative starting values.
    Random,
    /// SVD-based (Non-Negative Double SVD style) starting values.
    Nndsvd,
}
67
68// ---------------------------------------------------------------------------
69// NMF (unfitted)
70// ---------------------------------------------------------------------------
71
72/// Non-negative Matrix Factorization configuration.
73///
74/// Holds hyperparameters for the NMF decomposition. Calling [`Fit::fit`]
75/// computes the factorization and returns a [`FittedNMF`] that can
76/// project new data via [`Transform::transform`].
77#[derive(Debug, Clone)]
78pub struct NMF<F> {
79    /// Number of components to extract.
80    n_components: usize,
81    /// Maximum number of iterations for the solver.
82    max_iter: usize,
83    /// Convergence tolerance for the solver.
84    tol: f64,
85    /// The solver algorithm to use.
86    solver: NMFSolver,
87    /// The initialization strategy.
88    init: NMFInit,
89    /// Optional random seed for reproducibility.
90    random_state: Option<u64>,
91    _marker: std::marker::PhantomData<F>,
92}
93
94impl<F: Float + Send + Sync + 'static> NMF<F> {
95    /// Create a new `NMF` that extracts `n_components` components.
96    ///
97    /// Defaults: `max_iter=200`, `tol=1e-4`, solver=`MultiplicativeUpdate`,
98    /// init=`Random`, no random seed.
99    #[must_use]
100    pub fn new(n_components: usize) -> Self {
101        Self {
102            n_components,
103            max_iter: 200,
104            tol: 1e-4,
105            solver: NMFSolver::MultiplicativeUpdate,
106            init: NMFInit::Random,
107            random_state: None,
108            _marker: std::marker::PhantomData,
109        }
110    }
111
112    /// Set the maximum number of iterations.
113    #[must_use]
114    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
115        self.max_iter = max_iter;
116        self
117    }
118
119    /// Set the convergence tolerance.
120    #[must_use]
121    pub fn with_tol(mut self, tol: f64) -> Self {
122        self.tol = tol;
123        self
124    }
125
126    /// Set the solver algorithm.
127    #[must_use]
128    pub fn with_solver(mut self, solver: NMFSolver) -> Self {
129        self.solver = solver;
130        self
131    }
132
133    /// Set the initialization strategy.
134    #[must_use]
135    pub fn with_init(mut self, init: NMFInit) -> Self {
136        self.init = init;
137        self
138    }
139
140    /// Set the random seed for reproducible results.
141    #[must_use]
142    pub fn with_random_state(mut self, seed: u64) -> Self {
143        self.random_state = Some(seed);
144        self
145    }
146
147    /// Return the configured number of components.
148    #[must_use]
149    pub fn n_components(&self) -> usize {
150        self.n_components
151    }
152
153    /// Return the configured maximum iterations.
154    #[must_use]
155    pub fn max_iter(&self) -> usize {
156        self.max_iter
157    }
158
159    /// Return the configured tolerance.
160    #[must_use]
161    pub fn tol(&self) -> f64 {
162        self.tol
163    }
164
165    /// Return the configured solver.
166    #[must_use]
167    pub fn solver(&self) -> NMFSolver {
168        self.solver
169    }
170
171    /// Return the configured initialization strategy.
172    #[must_use]
173    pub fn init(&self) -> NMFInit {
174        self.init
175    }
176
177    /// Return the configured random state, if any.
178    #[must_use]
179    pub fn random_state(&self) -> Option<u64> {
180        self.random_state
181    }
182}
183
184// ---------------------------------------------------------------------------
185// FittedNMF
186// ---------------------------------------------------------------------------
187
188/// A fitted NMF model holding the learned components and reconstruction error.
189///
190/// Created by calling [`Fit::fit`] on an [`NMF`]. Implements
191/// [`Transform<Array2<F>>`] to project new data onto the learned components.
192#[derive(Debug, Clone)]
193pub struct FittedNMF<F> {
194    /// Learned component matrix H, shape `(n_components, n_features)`.
195    components_: Array2<F>,
196    /// The Frobenius norm of the reconstruction error at convergence.
197    reconstruction_err_: F,
198    /// Number of iterations performed.
199    n_iter_: usize,
200}
201
202impl<F: Float + Send + Sync + 'static> FittedNMF<F> {
203    /// Learned components (H matrix), shape `(n_components, n_features)`.
204    #[must_use]
205    pub fn components(&self) -> &Array2<F> {
206        &self.components_
207    }
208
209    /// Frobenius norm of the reconstruction error `||X - W*H||_F`.
210    #[must_use]
211    pub fn reconstruction_err(&self) -> F {
212        self.reconstruction_err_
213    }
214
215    /// Number of iterations performed during fitting.
216    #[must_use]
217    pub fn n_iter(&self) -> usize {
218        self.n_iter_
219    }
220
221    /// Reconstruct the original feature space from the latent representation.
222    /// Mirrors sklearn `NMF.inverse_transform`. Returns `W @ H` where `W`
223    /// is the input transformed matrix and `H = self.components_`.
224    ///
225    /// # Errors
226    ///
227    /// Returns [`FerroError::ShapeMismatch`] if `w.ncols()` does not equal
228    /// the number of components.
229    pub fn inverse_transform(&self, w: &Array2<F>) -> Result<Array2<F>, FerroError> {
230        let n_components = self.components_.nrows();
231        if w.ncols() != n_components {
232            return Err(FerroError::ShapeMismatch {
233                expected: vec![w.nrows(), n_components],
234                actual: vec![w.nrows(), w.ncols()],
235                context: "FittedNMF::inverse_transform".into(),
236            });
237        }
238        Ok(w.dot(&self.components_))
239    }
240}
241
242// ---------------------------------------------------------------------------
243// Internal helpers
244// ---------------------------------------------------------------------------
245
246/// Compute the Frobenius norm of `X - W * H`.
247fn reconstruction_error<F: Float + 'static>(x: &Array2<F>, w: &Array2<F>, h: &Array2<F>) -> F {
248    let wh = w.dot(h);
249    let mut err = F::zero();
250    for (a, b) in x.iter().zip(wh.iter()) {
251        let diff = *a - *b;
252        err = err + diff * diff;
253    }
254    err.sqrt()
255}
256
257/// Small epsilon to prevent division by zero.
258fn eps<F: Float>() -> F {
259    F::from(1e-12).unwrap_or_else(F::epsilon)
260}
261
262/// Initialize W and H with random non-negative values.
263fn init_random<F: Float>(
264    n_samples: usize,
265    n_features: usize,
266    n_components: usize,
267    seed: u64,
268) -> (Array2<F>, Array2<F>) {
269    let mut rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(seed);
270    let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
271
272    let mut w = Array2::<F>::zeros((n_samples, n_components));
273    for elem in &mut w {
274        *elem = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) + eps::<F>();
275    }
276
277    let mut h = Array2::<F>::zeros((n_components, n_features));
278    for elem in &mut h {
279        *elem = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) + eps::<F>();
280    }
281
282    (w, h)
283}
284
285/// NNDSVD initialization: compute a truncated SVD-like initialization.
286///
287/// Uses a simple approach: compute `X^T X`, eigendecompose, then use the
288/// top eigenvectors to initialize H, and solve for W = X * H^+ (pseudoinverse).
289fn init_nndsvd<F: Float + Send + Sync + 'static>(
290    x: &Array2<F>,
291    n_components: usize,
292    seed: u64,
293) -> Result<(Array2<F>, Array2<F>), FerroError> {
294    let (n_samples, n_features) = x.dim();
295
296    // Compute mean of X for scale.
297    let mut total = F::zero();
298    for &v in x {
299        total = total + v;
300    }
301    let avg = (total / F::from(n_samples * n_features).unwrap())
302        .abs()
303        .sqrt();
304    let avg = if avg < eps::<F>() { F::one() } else { avg };
305
306    // Compute X^T X.
307    let xtx = x.t().dot(x);
308
309    // Eigendecompose with Jacobi.
310    let max_iter = n_features * n_features * 100 + 1000;
311    let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&xtx, max_iter)?;
312
313    // Sort eigenvalues descending.
314    let mut indices: Vec<usize> = (0..n_features).collect();
315    indices.sort_by(|&a, &b| {
316        eigenvalues[b]
317            .partial_cmp(&eigenvalues[a])
318            .unwrap_or(std::cmp::Ordering::Equal)
319    });
320
321    // Build H from top eigenvectors (as rows), clamp negatives to zero.
322    let mut h = Array2::<F>::zeros((n_components, n_features));
323    for (k, &idx) in indices.iter().take(n_components).enumerate() {
324        for j in 0..n_features {
325            let val = eigenvectors[[j, idx]];
326            h[[k, j]] = if val > F::zero() { val } else { F::zero() };
327        }
328        // Ensure row is not all zeros.
329        let row_sum: F = h.row(k).iter().copied().fold(F::zero(), |a, b| a + b);
330        if row_sum < eps::<F>() {
331            // Fall back to small random values.
332            let mut rng: rand::rngs::StdRng =
333                SeedableRng::seed_from_u64(seed.wrapping_add(k as u64));
334            let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
335            for j in 0..n_features {
336                h[[k, j]] = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) * avg;
337            }
338        }
339    }
340
341    // Compute W = X * H^T * (H * H^T)^{-1}, but simpler: use multiplicative
342    // update step starting from random W.
343    let mut w = Array2::<F>::zeros((n_samples, n_components));
344    // Solve W by least squares: W = X * H^T * pinv(H * H^T)
345    // For simplicity, initialize W = X * H^T and normalize.
346    let ht = h.t();
347    let w_init = x.dot(&ht);
348    for i in 0..n_samples {
349        for k in 0..n_components {
350            let val = w_init[[i, k]];
351            w[[i, k]] = if val > F::zero() { val } else { eps::<F>() };
352        }
353    }
354
355    Ok((w, h))
356}
357
358/// Jacobi eigendecomposition for symmetric matrices.
359///
360/// Returns `(eigenvalues, eigenvectors)` where column `i` of `eigenvectors`
361/// is the eigenvector for `eigenvalues[i]`. Eigenvalues are NOT sorted.
362fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
363    a: &Array2<F>,
364    max_iter: usize,
365) -> Result<(Array1<F>, Array2<F>), FerroError> {
366    let n = a.nrows();
367    if n == 0 {
368        return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
369    }
370    if n == 1 {
371        let eigenvalues = Array1::from_vec(vec![a[[0, 0]]]);
372        let eigenvectors = Array2::from_shape_vec((1, 1), vec![F::one()]).unwrap();
373        return Ok((eigenvalues, eigenvectors));
374    }
375
376    let mut mat = a.to_owned();
377    let mut v = Array2::<F>::zeros((n, n));
378    for i in 0..n {
379        v[[i, i]] = F::one();
380    }
381
382    let tol = F::from(1e-12).unwrap_or_else(F::epsilon);
383
384    for _iteration in 0..max_iter {
385        let mut max_off = F::zero();
386        let mut p = 0;
387        let mut q = 1;
388        for i in 0..n {
389            for j in (i + 1)..n {
390                let val = mat[[i, j]].abs();
391                if val > max_off {
392                    max_off = val;
393                    p = i;
394                    q = j;
395                }
396            }
397        }
398
399        if max_off < tol {
400            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
401            return Ok((eigenvalues, v));
402        }
403
404        let app = mat[[p, p]];
405        let aqq = mat[[q, q]];
406        let apq = mat[[p, q]];
407
408        let theta = if (app - aqq).abs() < tol {
409            F::from(std::f64::consts::FRAC_PI_4).unwrap_or_else(F::one)
410        } else {
411            let tau = (aqq - app) / (F::from(2.0).unwrap() * apq);
412            let t = if tau >= F::zero() {
413                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
414            } else {
415                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
416            };
417            t.atan()
418        };
419
420        let c = theta.cos();
421        let s = theta.sin();
422
423        let mut new_mat = mat.clone();
424        for i in 0..n {
425            if i != p && i != q {
426                let mip = mat[[i, p]];
427                let miq = mat[[i, q]];
428                new_mat[[i, p]] = c * mip - s * miq;
429                new_mat[[p, i]] = new_mat[[i, p]];
430                new_mat[[i, q]] = s * mip + c * miq;
431                new_mat[[q, i]] = new_mat[[i, q]];
432            }
433        }
434
435        new_mat[[p, p]] = c * c * app - F::from(2.0).unwrap() * s * c * apq + s * s * aqq;
436        new_mat[[q, q]] = s * s * app + F::from(2.0).unwrap() * s * c * apq + c * c * aqq;
437        new_mat[[p, q]] = F::zero();
438        new_mat[[q, p]] = F::zero();
439
440        mat = new_mat;
441
442        for i in 0..n {
443            let vip = v[[i, p]];
444            let viq = v[[i, q]];
445            v[[i, p]] = c * vip - s * viq;
446            v[[i, q]] = s * vip + c * viq;
447        }
448    }
449
450    Err(FerroError::ConvergenceFailure {
451        iterations: max_iter,
452        message: "Jacobi eigendecomposition did not converge in NMF NNDSVD init".into(),
453    })
454}
455
456/// Multiplicative update solver (Lee & Seung, 2001).
457///
458/// Update rules:
459///   W <- W * (X H^T) / (W H H^T + eps)
460///   H <- H * (W^T X) / (W^T W H + eps)
461fn solve_multiplicative_update<F: Float + 'static>(
462    x: &Array2<F>,
463    w: &mut Array2<F>,
464    h: &mut Array2<F>,
465    max_iter: usize,
466    tol: f64,
467) -> usize {
468    let tol_f = F::from(tol).unwrap_or_else(F::epsilon);
469    let epsilon = eps::<F>();
470    let mut prev_err = reconstruction_error(x, w, h);
471
472    for iteration in 0..max_iter {
473        // Update H: H <- H * (W^T X) / (W^T W H + eps)
474        let wt = w.t();
475        let numerator_h = wt.dot(x);
476        let denominator_h = wt.dot(&*w).dot(&*h);
477
478        for (h_val, (num, den)) in h
479            .iter_mut()
480            .zip(numerator_h.iter().zip(denominator_h.iter()))
481        {
482            *h_val = *h_val * (*num / (*den + epsilon));
483        }
484
485        // Update W: W <- W * (X H^T) / (W H H^T + eps)
486        let ht = h.t();
487        let numerator_w = x.dot(&ht);
488        let denominator_w = w.dot(&*h).dot(&ht);
489
490        for (w_val, (num, den)) in w
491            .iter_mut()
492            .zip(numerator_w.iter().zip(denominator_w.iter()))
493        {
494            *w_val = *w_val * (*num / (*den + epsilon));
495        }
496
497        // Check convergence.
498        let err = reconstruction_error(x, w, h);
499        if (prev_err - err).abs() < tol_f {
500            return iteration + 1;
501        }
502        prev_err = err;
503    }
504
505    max_iter
506}
507
508/// Coordinate descent solver.
509///
510/// Updates each element of H and W by solving a scalar minimization problem.
511fn solve_coordinate_descent<F: Float + 'static>(
512    x: &Array2<F>,
513    w: &mut Array2<F>,
514    h: &mut Array2<F>,
515    max_iter: usize,
516    tol: f64,
517) -> usize {
518    let (n_samples, n_features) = x.dim();
519    let n_components = h.nrows();
520    let tol_f = F::from(tol).unwrap_or_else(F::epsilon);
521    let epsilon = eps::<F>();
522    let mut prev_err = reconstruction_error(x, w, h);
523
524    for iteration in 0..max_iter {
525        // Update H: for each k, j, solve for H[k,j]
526        // H[k,j] = max(0, (W[:,k]^T * (X[:,j] - W * H[:,j] + W[:,k]*H[k,j])) / (W[:,k]^T W[:,k]))
527        for k in 0..n_components {
528            let mut wk_norm_sq = F::zero();
529            for i in 0..n_samples {
530                wk_norm_sq = wk_norm_sq + w[[i, k]] * w[[i, k]];
531            }
532
533            if wk_norm_sq < epsilon {
534                continue;
535            }
536
537            for j in 0..n_features {
538                // Compute residual + current contribution.
539                let mut numerator = F::zero();
540                for i in 0..n_samples {
541                    let mut wh_ij = F::zero();
542                    for kk in 0..n_components {
543                        if kk != k {
544                            wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
545                        }
546                    }
547                    numerator = numerator + w[[i, k]] * (x[[i, j]] - wh_ij);
548                }
549
550                h[[k, j]] = if numerator > F::zero() {
551                    numerator / wk_norm_sq
552                } else {
553                    F::zero()
554                };
555            }
556        }
557
558        // Update W: for each i, k, solve for W[i,k]
559        for k in 0..n_components {
560            let mut hk_norm_sq = F::zero();
561            for j in 0..n_features {
562                hk_norm_sq = hk_norm_sq + h[[k, j]] * h[[k, j]];
563            }
564
565            if hk_norm_sq < epsilon {
566                continue;
567            }
568
569            for i in 0..n_samples {
570                let mut numerator = F::zero();
571                for j in 0..n_features {
572                    let mut wh_ij = F::zero();
573                    for kk in 0..n_components {
574                        if kk != k {
575                            wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
576                        }
577                    }
578                    numerator = numerator + h[[k, j]] * (x[[i, j]] - wh_ij);
579                }
580
581                w[[i, k]] = if numerator > F::zero() {
582                    numerator / hk_norm_sq
583                } else {
584                    F::zero()
585                };
586            }
587        }
588
589        // Check convergence.
590        let err = reconstruction_error(x, w, h);
591        if (prev_err - err).abs() < tol_f {
592            return iteration + 1;
593        }
594        prev_err = err;
595    }
596
597    max_iter
598}
599
600// ---------------------------------------------------------------------------
601// Trait implementations
602// ---------------------------------------------------------------------------
603
604impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for NMF<F> {
605    type Fitted = FittedNMF<F>;
606    type Error = FerroError;
607
608    /// Fit the NMF model by decomposing `X ~ W * H`.
609    ///
610    /// # Errors
611    ///
612    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds
613    ///   the minimum of `n_samples` and `n_features`.
614    /// - [`FerroError::InvalidParameter`] if any entry of `X` is negative.
615    /// - [`FerroError::InsufficientSamples`] if there are zero samples.
616    /// - [`FerroError::ConvergenceFailure`] if NNDSVD initialization fails.
617    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedNMF<F>, FerroError> {
618        let (n_samples, n_features) = x.dim();
619
620        if self.n_components == 0 {
621            return Err(FerroError::InvalidParameter {
622                name: "n_components".into(),
623                reason: "must be at least 1".into(),
624            });
625        }
626        if n_samples == 0 {
627            return Err(FerroError::InsufficientSamples {
628                required: 1,
629                actual: 0,
630                context: "NMF::fit".into(),
631            });
632        }
633        if self.n_components > n_samples.min(n_features) {
634            return Err(FerroError::InvalidParameter {
635                name: "n_components".into(),
636                reason: format!(
637                    "n_components ({}) exceeds min(n_samples, n_features) = {}",
638                    self.n_components,
639                    n_samples.min(n_features)
640                ),
641            });
642        }
643
644        // Check non-negativity.
645        for &val in x {
646            if val < F::zero() {
647                return Err(FerroError::InvalidParameter {
648                    name: "X".into(),
649                    reason: "NMF requires all entries in X to be non-negative".into(),
650                });
651            }
652        }
653
654        let seed = self.random_state.unwrap_or(0);
655
656        // Initialize W and H.
657        let (mut w, mut h) = match self.init {
658            NMFInit::Random => init_random(n_samples, n_features, self.n_components, seed),
659            NMFInit::Nndsvd => init_nndsvd(x, self.n_components, seed)?,
660        };
661
662        // Solve.
663        let n_iter = match self.solver {
664            NMFSolver::MultiplicativeUpdate => {
665                solve_multiplicative_update(x, &mut w, &mut h, self.max_iter, self.tol)
666            }
667            NMFSolver::CoordinateDescent => {
668                solve_coordinate_descent(x, &mut w, &mut h, self.max_iter, self.tol)
669            }
670        };
671
672        let reconstruction_err = reconstruction_error(x, &w, &h);
673
674        Ok(FittedNMF {
675            components_: h,
676            reconstruction_err_: reconstruction_err,
677            n_iter_: n_iter,
678        })
679    }
680}
681
682impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedNMF<F> {
683    type Output = Array2<F>;
684    type Error = FerroError;
685
686    /// Project data onto the learned NMF components.
687    ///
688    /// Solves for `W` in `X ~ W * H` using multiplicative updates with
689    /// `H` fixed to the learned components.
690    ///
691    /// # Errors
692    ///
693    /// - [`FerroError::ShapeMismatch`] if the number of columns does not
694    ///   match the number of features seen during fitting.
695    /// - [`FerroError::InvalidParameter`] if any entry of `X` is negative.
696    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
697        let n_features = self.components_.ncols();
698        if x.ncols() != n_features {
699            return Err(FerroError::ShapeMismatch {
700                expected: vec![x.nrows(), n_features],
701                actual: vec![x.nrows(), x.ncols()],
702                context: "FittedNMF::transform".into(),
703            });
704        }
705
706        // Check non-negativity.
707        for &val in x {
708            if val < F::zero() {
709                return Err(FerroError::InvalidParameter {
710                    name: "X".into(),
711                    reason: "NMF requires all entries in X to be non-negative".into(),
712                });
713            }
714        }
715
716        let n_samples = x.nrows();
717        let n_components = self.components_.nrows();
718        let epsilon = eps::<F>();
719
720        // Initialize W with uniform small values.
721        let mut w = Array2::<F>::zeros((n_samples, n_components));
722        let init_val = F::from(0.1).unwrap_or_else(F::one);
723        w.fill(init_val);
724
725        // Run multiplicative updates with H fixed.
726        let h = &self.components_;
727        for _iter in 0..200 {
728            let wt_num = x.dot(&h.t());
729            let wt_den = w.dot(h).dot(&h.t());
730
731            for (w_val, (num, den)) in w.iter_mut().zip(wt_num.iter().zip(wt_den.iter())) {
732                *w_val = *w_val * (*num / (*den + epsilon));
733            }
734        }
735
736        Ok(w)
737    }
738}
739
740// ---------------------------------------------------------------------------
741// Pipeline integration (generic)
742// ---------------------------------------------------------------------------
743
744impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for NMF<F> {
745    /// Fit NMF using the pipeline interface.
746    ///
747    /// The `y` argument is ignored; NMF is unsupervised.
748    ///
749    /// # Errors
750    ///
751    /// Propagates errors from [`Fit::fit`].
752    fn fit_pipeline(
753        &self,
754        x: &Array2<F>,
755        _y: &Array1<F>,
756    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
757        let fitted = self.fit(x, &())?;
758        Ok(Box::new(fitted))
759    }
760}
761
762impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedNMF<F> {
763    /// Transform data using the pipeline interface.
764    ///
765    /// # Errors
766    ///
767    /// Propagates errors from [`Transform::transform`].
768    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
769        self.transform(x)
770    }
771}
772
773// ---------------------------------------------------------------------------
774// Tests
775// ---------------------------------------------------------------------------
776
777#[cfg(test)]
778mod tests {
779    use super::*;
780    use approx::assert_abs_diff_eq;
781    use ndarray::array;
782
783    /// Helper: create a small non-negative dataset.
784    fn small_dataset() -> Array2<f64> {
785        array![
786            [1.0, 2.0, 3.0],
787            [4.0, 5.0, 6.0],
788            [7.0, 8.0, 9.0],
789            [10.0, 11.0, 12.0],
790        ]
791    }
792
793    /// Helper: create a larger non-negative dataset.
794    fn medium_dataset() -> Array2<f64> {
795        array![
796            [5.0, 3.0, 0.0, 1.0],
797            [4.0, 0.0, 0.0, 1.0],
798            [1.0, 1.0, 0.0, 5.0],
799            [1.0, 0.0, 0.0, 4.0],
800            [0.0, 1.0, 5.0, 4.0],
801            [0.0, 0.0, 4.0, 3.0],
802        ]
803    }
804
805    #[test]
806    fn test_nmf_basic_fit() {
807        let nmf = NMF::<f64>::new(2).with_random_state(42);
808        let x = small_dataset();
809        let fitted = nmf.fit(&x, &()).unwrap();
810        assert_eq!(fitted.components().dim(), (2, 3));
811    }
812
813    #[test]
814    fn test_nmf_components_non_negative() {
815        let nmf = NMF::<f64>::new(2).with_random_state(42);
816        let x = small_dataset();
817        let fitted = nmf.fit(&x, &()).unwrap();
818        for &val in fitted.components() {
819            assert!(
820                val >= 0.0,
821                "component value should be non-negative, got {val}"
822            );
823        }
824    }
825
826    #[test]
827    fn test_nmf_transform_dimensions() {
828        let nmf = NMF::<f64>::new(2).with_random_state(42);
829        let x = small_dataset();
830        let fitted = nmf.fit(&x, &()).unwrap();
831        let projected = fitted.transform(&x).unwrap();
832        assert_eq!(projected.dim(), (4, 2));
833    }
834
835    #[test]
836    fn test_nmf_transform_non_negative() {
837        let nmf = NMF::<f64>::new(2).with_random_state(42);
838        let x = small_dataset();
839        let fitted = nmf.fit(&x, &()).unwrap();
840        let projected = fitted.transform(&x).unwrap();
841        for &val in &projected {
842            assert!(val >= 0.0, "W value should be non-negative, got {val}");
843        }
844    }
845
846    #[test]
847    fn test_nmf_reconstruction_error_decreases() {
848        let nmf_few = NMF::<f64>::new(2).with_random_state(42).with_max_iter(10);
849        let nmf_many = NMF::<f64>::new(2).with_random_state(42).with_max_iter(200);
850        let x = small_dataset();
851        let fitted_few = nmf_few.fit(&x, &()).unwrap();
852        let fitted_many = nmf_many.fit(&x, &()).unwrap();
853        assert!(
854            fitted_many.reconstruction_err() <= fitted_few.reconstruction_err() + 1e-6,
855            "more iterations should reduce error: few={}, many={}",
856            fitted_few.reconstruction_err(),
857            fitted_many.reconstruction_err()
858        );
859    }
860
861    #[test]
862    fn test_nmf_reconstruction_error_positive() {
863        let nmf = NMF::<f64>::new(2).with_random_state(42);
864        let x = small_dataset();
865        let fitted = nmf.fit(&x, &()).unwrap();
866        assert!(fitted.reconstruction_err() >= 0.0);
867    }
868
869    #[test]
870    fn test_nmf_coordinate_descent_solver() {
871        let nmf = NMF::<f64>::new(2)
872            .with_solver(NMFSolver::CoordinateDescent)
873            .with_random_state(42);
874        let x = medium_dataset();
875        let fitted = nmf.fit(&x, &()).unwrap();
876        assert_eq!(fitted.components().dim(), (2, 4));
877        for &val in fitted.components() {
878            assert!(val >= 0.0, "CD component should be non-negative, got {val}");
879        }
880    }
881
882    #[test]
883    fn test_nmf_nndsvd_init() {
884        let nmf = NMF::<f64>::new(2)
885            .with_init(NMFInit::Nndsvd)
886            .with_random_state(42);
887        let x = medium_dataset();
888        let fitted = nmf.fit(&x, &()).unwrap();
889        assert_eq!(fitted.components().dim(), (2, 4));
890        for &val in fitted.components() {
891            assert!(
892                val >= 0.0,
893                "NNDSVD component should be non-negative, got {val}"
894            );
895        }
896    }
897
898    #[test]
899    fn test_nmf_cd_with_nndsvd() {
900        let nmf = NMF::<f64>::new(2)
901            .with_solver(NMFSolver::CoordinateDescent)
902            .with_init(NMFInit::Nndsvd)
903            .with_random_state(42);
904        let x = medium_dataset();
905        let fitted = nmf.fit(&x, &()).unwrap();
906        assert_eq!(fitted.components().dim(), (2, 4));
907    }
908
/// Requesting zero components is an invalid configuration, so `fit` must fail.
#[test]
fn test_nmf_invalid_n_components_zero() {
    let x = small_dataset();
    let outcome = NMF::<f64>::new(0).fit(&x, &());
    assert!(outcome.is_err());
}
915
/// Asking for more components (10) than either dimension of the 4x3
/// `small_dataset` must make `fit` return an error.
#[test]
fn test_nmf_invalid_n_components_too_large() {
    let x = small_dataset(); // 4x3
    let too_many = NMF::<f64>::new(10);
    let outcome = too_many.fit(&x, &());
    assert!(outcome.is_err());
}
922
/// NMF requires non-negative input; a single negative entry must be rejected.
#[test]
fn test_nmf_negative_input_rejected() {
    let x = array![[1.0, -2.0], [3.0, 4.0]];
    let result = NMF::<f64>::new(1).fit(&x, &());
    assert!(result.is_err());
}
929
/// A model fitted on the 3-feature `small_dataset` must refuse to transform a
/// row that has only 2 features.
#[test]
fn test_nmf_transform_shape_mismatch() {
    let train = small_dataset();
    let fitted = NMF::<f64>::new(2)
        .with_random_state(42)
        .fit(&train, &())
        .unwrap();
    let x_bad = array![[1.0, 2.0]]; // wrong number of features
    let result = fitted.transform(&x_bad);
    assert!(result.is_err());
}
938
/// `transform` must reject negative input, just as `fit` does.
#[test]
fn test_nmf_transform_negative_rejected() {
    let train = small_dataset();
    let model = NMF::<f64>::new(2).with_random_state(42);
    let fitted = model.fit(&train, &()).unwrap();
    let x_neg = array![[1.0, -2.0, 3.0]];
    let outcome = fitted.transform(&x_neg);
    assert!(outcome.is_err());
}
947
/// Two models configured identically, including the same random seed, must
/// produce identical fits.
///
/// The component shapes are compared before the element-wise loop because
/// `Iterator::zip` stops at the shorter input. Without that check, a shape
/// mismatch would go unnoticed. The reconstruction errors are compared too,
/// as a cheap whole-model check.
#[test]
fn test_nmf_reproducibility() {
    let nmf1 = NMF::<f64>::new(2).with_random_state(42);
    let nmf2 = NMF::<f64>::new(2).with_random_state(42);
    let x = small_dataset();
    let fitted1 = nmf1.fit(&x, &()).unwrap();
    let fitted2 = nmf2.fit(&x, &()).unwrap();
    assert_eq!(fitted1.components().dim(), fitted2.components().dim());
    for (a, b) in fitted1.components().iter().zip(fitted2.components().iter()) {
        assert_abs_diff_eq!(a, b, epsilon = 1e-10);
    }
    assert_abs_diff_eq!(
        fitted1.reconstruction_err(),
        fitted2.reconstruction_err(),
        epsilon = 1e-10
    );
}
959
/// With one component, the learned components have a single row and the
/// projection has a single column.
#[test]
fn test_nmf_single_component() {
    let x = small_dataset();
    let fitted = NMF::<f64>::new(1)
        .with_random_state(42)
        .fit(&x, &())
        .unwrap();
    assert_eq!(fitted.components().nrows(), 1);
    let projected_cols = fitted.transform(&x).unwrap().ncols();
    assert_eq!(projected_cols, 1);
}
969
/// A completed fit must report running at least one solver iteration.
#[test]
fn test_nmf_n_iter_positive() {
    let x = small_dataset();
    let iterations = NMF::<f64>::new(2)
        .with_random_state(42)
        .fit(&x, &())
        .unwrap()
        .n_iter();
    assert!(iterations > 0);
}
977
/// Each builder setter's value must be readable back through its getter.
#[test]
fn test_nmf_getters() {
    let cfg = NMF::<f64>::new(3)
        .with_max_iter(100)
        .with_tol(1e-5)
        .with_solver(NMFSolver::CoordinateDescent)
        .with_init(NMFInit::Nndsvd)
        .with_random_state(99);
    assert_eq!(cfg.random_state(), Some(99));
    assert_eq!(cfg.init(), NMFInit::Nndsvd);
    assert_eq!(cfg.solver(), NMFSolver::CoordinateDescent);
    assert_abs_diff_eq!(cfg.tol(), 1e-5);
    assert_eq!(cfg.max_iter(), 100);
    assert_eq!(cfg.n_components(), 3);
}
993
/// The estimator is generic over the float type, so a single-precision fit
/// followed by a transform must work end to end.
#[test]
fn test_nmf_f32() {
    let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
    let fitted = NMF::<f32>::new(1)
        .with_random_state(42)
        .fit(&x, &())
        .unwrap();
    let n_cols = fitted.transform(&x).unwrap().ncols();
    assert_eq!(n_cols, 1);
}
1002
/// A permutation-matrix input, mostly exact zeros, must still fit and yield a
/// `2 x 3` component matrix.
#[test]
fn test_nmf_zero_entries() {
    let x = array![[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]];
    let model = NMF::<f64>::new(2).with_random_state(42);
    let shape = model.fit(&x, &()).unwrap().components().dim();
    assert_eq!(shape, (2, 3));
}
1010
/// `NMF` must work as a transform step in a `Pipeline` ahead of a final
/// estimator.
///
/// The stand-in estimator sums each transformed row. The test only checks
/// that one prediction comes out per input sample.
#[test]
fn test_nmf_pipeline_integration() {
    use ferrolearn_core::pipeline::{FittedPipelineEstimator, Pipeline, PipelineEstimator};
    use ferrolearn_core::traits::Predict;

    // Unfitted stand-in estimator; "fitting" ignores its inputs entirely.
    struct SumEstimator;

    // Fitted stand-in whose prediction is the row-wise sum of its input.
    struct FittedSumEstimator;

    impl PipelineEstimator<f64> for SumEstimator {
        fn fit_pipeline(
            &self,
            _x: &Array2<f64>,
            _y: &Array1<f64>,
        ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
            Ok(Box::new(FittedSumEstimator))
        }
    }

    impl FittedPipelineEstimator<f64> for FittedSumEstimator {
        fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
            let mut totals = Vec::with_capacity(x.nrows());
            for row in x.rows() {
                totals.push(row.sum());
            }
            Ok(Array1::from_vec(totals))
        }
    }

    let x = small_dataset();
    let y = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);

    let pipeline = Pipeline::new()
        .transform_step("nmf", Box::new(NMF::<f64>::new(2).with_random_state(42)))
        .estimator_step("sum", Box::new(SumEstimator));

    let preds = pipeline.fit(&x, &y).unwrap().predict(&x).unwrap();
    assert_eq!(preds.len(), 4);
}
1048
/// Multiplicative updates with three components and a 500-iteration budget
/// must reach a reconstruction error below 10 on `medium_dataset`.
#[test]
fn test_nmf_medium_dataset_mu() {
    let x = medium_dataset();
    let fitted = NMF::<f64>::new(3)
        .with_solver(NMFSolver::MultiplicativeUpdate)
        .with_random_state(42)
        .with_max_iter(500)
        .fit(&x, &())
        .unwrap();
    assert_eq!(fitted.components().dim(), (3, 4));
    // Reconstruction error should be reasonable.
    let err = fitted.reconstruction_err();
    assert!(err < 10.0, "reconstruction error too large: {}", err);
}
1065
/// An input with zero rows gives nothing to factorize, so `fit` must error.
#[test]
fn test_nmf_insufficient_samples() {
    let empty = Array2::<f64>::zeros((0, 3));
    let outcome = NMF::<f64>::new(1).fit(&empty, &());
    assert!(outcome.is_err());
}
1072
/// With the same seed and iteration budget, a two-component fit should not do
/// worse than a one-component fit.
///
/// The 1e-6 slack absorbs small solver noise, since a two-component
/// factorization can represent any one-component one.
#[test]
fn test_nmf_more_components_lower_error() {
    let x = medium_dataset();
    let fit_with = |k| {
        NMF::<f64>::new(k)
            .with_random_state(42)
            .with_max_iter(300)
            .fit(&x, &())
            .unwrap()
    };
    let fitted1 = fit_with(1);
    let fitted2 = fit_with(2);
    let (err1, err2) = (fitted1.reconstruction_err(), fitted2.reconstruction_err());
    assert!(
        err2 <= err1 + 1e-6,
        "more components should reduce error: 1comp={}, 2comp={}",
        err1,
        err2
    );
}
1087}