1use ferrolearn_core::error::FerroError;
39use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
40use ferrolearn_core::traits::{Fit, Transform};
41use ndarray::{Array1, Array2};
42use num_traits::Float;
43use rand::SeedableRng;
44use rand_distr::{Distribution, Uniform};
45
/// Optimisation algorithm used by [`NMF`] to refine the factors `W` and `H`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NMFSolver {
    /// Lee–Seung multiplicative updates: every entry is rescaled by a
    /// numerator/denominator ratio each iteration.
    MultiplicativeUpdate,
    /// Cyclic coordinate descent: each entry of `H`, then of `W`, is set to
    /// its exact least-squares minimiser (clipped at zero) with all others fixed.
    CoordinateDescent,
}
58
/// Strategy for producing the starting `W` and `H` before the solver runs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NMFInit {
    /// Entries drawn uniformly from `[0, 1)` with a seeded RNG, plus a tiny
    /// positive offset so no entry starts at exactly zero.
    Random,
    /// NNDSVD-style start built from the leading eigenvectors of `Xᵀ X`;
    /// rows that end up empty are filled with seeded random values.
    Nndsvd,
}
67
68#[derive(Debug, Clone)]
78pub struct NMF<F> {
79 n_components: usize,
81 max_iter: usize,
83 tol: f64,
85 solver: NMFSolver,
87 init: NMFInit,
89 random_state: Option<u64>,
91 _marker: std::marker::PhantomData<F>,
92}
93
94impl<F: Float + Send + Sync + 'static> NMF<F> {
95 #[must_use]
100 pub fn new(n_components: usize) -> Self {
101 Self {
102 n_components,
103 max_iter: 200,
104 tol: 1e-4,
105 solver: NMFSolver::MultiplicativeUpdate,
106 init: NMFInit::Random,
107 random_state: None,
108 _marker: std::marker::PhantomData,
109 }
110 }
111
112 #[must_use]
114 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
115 self.max_iter = max_iter;
116 self
117 }
118
119 #[must_use]
121 pub fn with_tol(mut self, tol: f64) -> Self {
122 self.tol = tol;
123 self
124 }
125
126 #[must_use]
128 pub fn with_solver(mut self, solver: NMFSolver) -> Self {
129 self.solver = solver;
130 self
131 }
132
133 #[must_use]
135 pub fn with_init(mut self, init: NMFInit) -> Self {
136 self.init = init;
137 self
138 }
139
140 #[must_use]
142 pub fn with_random_state(mut self, seed: u64) -> Self {
143 self.random_state = Some(seed);
144 self
145 }
146
147 #[must_use]
149 pub fn n_components(&self) -> usize {
150 self.n_components
151 }
152
153 #[must_use]
155 pub fn max_iter(&self) -> usize {
156 self.max_iter
157 }
158
159 #[must_use]
161 pub fn tol(&self) -> f64 {
162 self.tol
163 }
164
165 #[must_use]
167 pub fn solver(&self) -> NMFSolver {
168 self.solver
169 }
170
171 #[must_use]
173 pub fn init(&self) -> NMFInit {
174 self.init
175 }
176
177 #[must_use]
179 pub fn random_state(&self) -> Option<u64> {
180 self.random_state
181 }
182}
183
184#[derive(Debug, Clone)]
193pub struct FittedNMF<F> {
194 components_: Array2<F>,
196 reconstruction_err_: F,
198 n_iter_: usize,
200}
201
202impl<F: Float + Send + Sync + 'static> FittedNMF<F> {
203 #[must_use]
205 pub fn components(&self) -> &Array2<F> {
206 &self.components_
207 }
208
209 #[must_use]
211 pub fn reconstruction_err(&self) -> F {
212 self.reconstruction_err_
213 }
214
215 #[must_use]
217 pub fn n_iter(&self) -> usize {
218 self.n_iter_
219 }
220}
221
222fn reconstruction_error<F: Float + 'static>(x: &Array2<F>, w: &Array2<F>, h: &Array2<F>) -> F {
228 let wh = w.dot(h);
229 let mut err = F::zero();
230 for (a, b) in x.iter().zip(wh.iter()) {
231 let diff = *a - *b;
232 err = err + diff * diff;
233 }
234 err.sqrt()
235}
236
237fn eps<F: Float>() -> F {
239 F::from(1e-12).unwrap_or(F::epsilon())
240}
241
242fn init_random<F: Float>(
244 n_samples: usize,
245 n_features: usize,
246 n_components: usize,
247 seed: u64,
248) -> (Array2<F>, Array2<F>) {
249 let mut rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(seed);
250 let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
251
252 let mut w = Array2::<F>::zeros((n_samples, n_components));
253 for elem in w.iter_mut() {
254 *elem = F::from(uniform.sample(&mut rng)).unwrap_or(F::zero()) + eps::<F>();
255 }
256
257 let mut h = Array2::<F>::zeros((n_components, n_features));
258 for elem in h.iter_mut() {
259 *elem = F::from(uniform.sample(&mut rng)).unwrap_or(F::zero()) + eps::<F>();
260 }
261
262 (w, h)
263}
264
265fn init_nndsvd<F: Float + Send + Sync + 'static>(
270 x: &Array2<F>,
271 n_components: usize,
272 seed: u64,
273) -> Result<(Array2<F>, Array2<F>), FerroError> {
274 let (n_samples, n_features) = x.dim();
275
276 let mut total = F::zero();
278 for &v in x.iter() {
279 total = total + v;
280 }
281 let avg = (total / F::from(n_samples * n_features).unwrap())
282 .abs()
283 .sqrt();
284 let avg = if avg < eps::<F>() { F::one() } else { avg };
285
286 let xtx = x.t().dot(x);
288
289 let max_iter = n_features * n_features * 100 + 1000;
291 let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&xtx, max_iter)?;
292
293 let mut indices: Vec<usize> = (0..n_features).collect();
295 indices.sort_by(|&a, &b| {
296 eigenvalues[b]
297 .partial_cmp(&eigenvalues[a])
298 .unwrap_or(std::cmp::Ordering::Equal)
299 });
300
301 let mut h = Array2::<F>::zeros((n_components, n_features));
303 for (k, &idx) in indices.iter().take(n_components).enumerate() {
304 for j in 0..n_features {
305 let val = eigenvectors[[j, idx]];
306 h[[k, j]] = if val > F::zero() { val } else { F::zero() };
307 }
308 let row_sum: F = h.row(k).iter().copied().fold(F::zero(), |a, b| a + b);
310 if row_sum < eps::<F>() {
311 let mut rng: rand::rngs::StdRng =
313 SeedableRng::seed_from_u64(seed.wrapping_add(k as u64));
314 let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
315 for j in 0..n_features {
316 h[[k, j]] = F::from(uniform.sample(&mut rng)).unwrap_or(F::zero()) * avg;
317 }
318 }
319 }
320
321 let mut w = Array2::<F>::zeros((n_samples, n_components));
324 let ht = h.t();
327 let w_init = x.dot(&ht);
328 for i in 0..n_samples {
329 for k in 0..n_components {
330 let val = w_init[[i, k]];
331 w[[i, k]] = if val > F::zero() { val } else { eps::<F>() };
332 }
333 }
334
335 Ok((w, h))
336}
337
338fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
343 a: &Array2<F>,
344 max_iter: usize,
345) -> Result<(Array1<F>, Array2<F>), FerroError> {
346 let n = a.nrows();
347 if n == 0 {
348 return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
349 }
350 if n == 1 {
351 let eigenvalues = Array1::from_vec(vec![a[[0, 0]]]);
352 let eigenvectors = Array2::from_shape_vec((1, 1), vec![F::one()]).unwrap();
353 return Ok((eigenvalues, eigenvectors));
354 }
355
356 let mut mat = a.to_owned();
357 let mut v = Array2::<F>::zeros((n, n));
358 for i in 0..n {
359 v[[i, i]] = F::one();
360 }
361
362 let tol = F::from(1e-12).unwrap_or(F::epsilon());
363
364 for _iteration in 0..max_iter {
365 let mut max_off = F::zero();
366 let mut p = 0;
367 let mut q = 1;
368 for i in 0..n {
369 for j in (i + 1)..n {
370 let val = mat[[i, j]].abs();
371 if val > max_off {
372 max_off = val;
373 p = i;
374 q = j;
375 }
376 }
377 }
378
379 if max_off < tol {
380 let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
381 return Ok((eigenvalues, v));
382 }
383
384 let app = mat[[p, p]];
385 let aqq = mat[[q, q]];
386 let apq = mat[[p, q]];
387
388 let theta = if (app - aqq).abs() < tol {
389 F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::one())
390 } else {
391 let tau = (aqq - app) / (F::from(2.0).unwrap() * apq);
392 let t = if tau >= F::zero() {
393 F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
394 } else {
395 -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
396 };
397 t.atan()
398 };
399
400 let c = theta.cos();
401 let s = theta.sin();
402
403 let mut new_mat = mat.clone();
404 for i in 0..n {
405 if i != p && i != q {
406 let mip = mat[[i, p]];
407 let miq = mat[[i, q]];
408 new_mat[[i, p]] = c * mip - s * miq;
409 new_mat[[p, i]] = new_mat[[i, p]];
410 new_mat[[i, q]] = s * mip + c * miq;
411 new_mat[[q, i]] = new_mat[[i, q]];
412 }
413 }
414
415 new_mat[[p, p]] = c * c * app - F::from(2.0).unwrap() * s * c * apq + s * s * aqq;
416 new_mat[[q, q]] = s * s * app + F::from(2.0).unwrap() * s * c * apq + c * c * aqq;
417 new_mat[[p, q]] = F::zero();
418 new_mat[[q, p]] = F::zero();
419
420 mat = new_mat;
421
422 for i in 0..n {
423 let vip = v[[i, p]];
424 let viq = v[[i, q]];
425 v[[i, p]] = c * vip - s * viq;
426 v[[i, q]] = s * vip + c * viq;
427 }
428 }
429
430 Err(FerroError::ConvergenceFailure {
431 iterations: max_iter,
432 message: "Jacobi eigendecomposition did not converge in NMF NNDSVD init".into(),
433 })
434}
435
436fn solve_multiplicative_update<F: Float + 'static>(
442 x: &Array2<F>,
443 w: &mut Array2<F>,
444 h: &mut Array2<F>,
445 max_iter: usize,
446 tol: f64,
447) -> usize {
448 let tol_f = F::from(tol).unwrap_or(F::epsilon());
449 let epsilon = eps::<F>();
450 let mut prev_err = reconstruction_error(x, w, h);
451
452 for iteration in 0..max_iter {
453 let wt = w.t();
455 let numerator_h = wt.dot(x);
456 let denominator_h = wt.dot(&*w).dot(&*h);
457
458 for (h_val, (num, den)) in h
459 .iter_mut()
460 .zip(numerator_h.iter().zip(denominator_h.iter()))
461 {
462 *h_val = *h_val * (*num / (*den + epsilon));
463 }
464
465 let ht = h.t();
467 let numerator_w = x.dot(&ht);
468 let denominator_w = w.dot(&*h).dot(&ht);
469
470 for (w_val, (num, den)) in w
471 .iter_mut()
472 .zip(numerator_w.iter().zip(denominator_w.iter()))
473 {
474 *w_val = *w_val * (*num / (*den + epsilon));
475 }
476
477 let err = reconstruction_error(x, w, h);
479 if (prev_err - err).abs() < tol_f {
480 return iteration + 1;
481 }
482 prev_err = err;
483 }
484
485 max_iter
486}
487
488fn solve_coordinate_descent<F: Float + 'static>(
492 x: &Array2<F>,
493 w: &mut Array2<F>,
494 h: &mut Array2<F>,
495 max_iter: usize,
496 tol: f64,
497) -> usize {
498 let (n_samples, n_features) = x.dim();
499 let n_components = h.nrows();
500 let tol_f = F::from(tol).unwrap_or(F::epsilon());
501 let epsilon = eps::<F>();
502 let mut prev_err = reconstruction_error(x, w, h);
503
504 for iteration in 0..max_iter {
505 for k in 0..n_components {
508 let mut wk_norm_sq = F::zero();
509 for i in 0..n_samples {
510 wk_norm_sq = wk_norm_sq + w[[i, k]] * w[[i, k]];
511 }
512
513 if wk_norm_sq < epsilon {
514 continue;
515 }
516
517 for j in 0..n_features {
518 let mut numerator = F::zero();
520 for i in 0..n_samples {
521 let mut wh_ij = F::zero();
522 for kk in 0..n_components {
523 if kk != k {
524 wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
525 }
526 }
527 numerator = numerator + w[[i, k]] * (x[[i, j]] - wh_ij);
528 }
529
530 h[[k, j]] = if numerator > F::zero() {
531 numerator / wk_norm_sq
532 } else {
533 F::zero()
534 };
535 }
536 }
537
538 for k in 0..n_components {
540 let mut hk_norm_sq = F::zero();
541 for j in 0..n_features {
542 hk_norm_sq = hk_norm_sq + h[[k, j]] * h[[k, j]];
543 }
544
545 if hk_norm_sq < epsilon {
546 continue;
547 }
548
549 for i in 0..n_samples {
550 let mut numerator = F::zero();
551 for j in 0..n_features {
552 let mut wh_ij = F::zero();
553 for kk in 0..n_components {
554 if kk != k {
555 wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
556 }
557 }
558 numerator = numerator + h[[k, j]] * (x[[i, j]] - wh_ij);
559 }
560
561 w[[i, k]] = if numerator > F::zero() {
562 numerator / hk_norm_sq
563 } else {
564 F::zero()
565 };
566 }
567 }
568
569 let err = reconstruction_error(x, w, h);
571 if (prev_err - err).abs() < tol_f {
572 return iteration + 1;
573 }
574 prev_err = err;
575 }
576
577 max_iter
578}
579
580impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for NMF<F> {
585 type Fitted = FittedNMF<F>;
586 type Error = FerroError;
587
588 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedNMF<F>, FerroError> {
598 let (n_samples, n_features) = x.dim();
599
600 if self.n_components == 0 {
601 return Err(FerroError::InvalidParameter {
602 name: "n_components".into(),
603 reason: "must be at least 1".into(),
604 });
605 }
606 if n_samples == 0 {
607 return Err(FerroError::InsufficientSamples {
608 required: 1,
609 actual: 0,
610 context: "NMF::fit".into(),
611 });
612 }
613 if self.n_components > n_samples.min(n_features) {
614 return Err(FerroError::InvalidParameter {
615 name: "n_components".into(),
616 reason: format!(
617 "n_components ({}) exceeds min(n_samples, n_features) = {}",
618 self.n_components,
619 n_samples.min(n_features)
620 ),
621 });
622 }
623
624 for &val in x.iter() {
626 if val < F::zero() {
627 return Err(FerroError::InvalidParameter {
628 name: "X".into(),
629 reason: "NMF requires all entries in X to be non-negative".into(),
630 });
631 }
632 }
633
634 let seed = self.random_state.unwrap_or(0);
635
636 let (mut w, mut h) = match self.init {
638 NMFInit::Random => init_random(n_samples, n_features, self.n_components, seed),
639 NMFInit::Nndsvd => init_nndsvd(x, self.n_components, seed)?,
640 };
641
642 let n_iter = match self.solver {
644 NMFSolver::MultiplicativeUpdate => {
645 solve_multiplicative_update(x, &mut w, &mut h, self.max_iter, self.tol)
646 }
647 NMFSolver::CoordinateDescent => {
648 solve_coordinate_descent(x, &mut w, &mut h, self.max_iter, self.tol)
649 }
650 };
651
652 let reconstruction_err = reconstruction_error(x, &w, &h);
653
654 Ok(FittedNMF {
655 components_: h,
656 reconstruction_err_: reconstruction_err,
657 n_iter_: n_iter,
658 })
659 }
660}
661
662impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedNMF<F> {
663 type Output = Array2<F>;
664 type Error = FerroError;
665
666 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
677 let n_features = self.components_.ncols();
678 if x.ncols() != n_features {
679 return Err(FerroError::ShapeMismatch {
680 expected: vec![x.nrows(), n_features],
681 actual: vec![x.nrows(), x.ncols()],
682 context: "FittedNMF::transform".into(),
683 });
684 }
685
686 for &val in x.iter() {
688 if val < F::zero() {
689 return Err(FerroError::InvalidParameter {
690 name: "X".into(),
691 reason: "NMF requires all entries in X to be non-negative".into(),
692 });
693 }
694 }
695
696 let n_samples = x.nrows();
697 let n_components = self.components_.nrows();
698 let epsilon = eps::<F>();
699
700 let mut w = Array2::<F>::zeros((n_samples, n_components));
702 let init_val = F::from(0.1).unwrap_or(F::one());
703 for elem in w.iter_mut() {
704 *elem = init_val;
705 }
706
707 let h = &self.components_;
709 for _iter in 0..200 {
710 let wt_num = x.dot(&h.t());
711 let wt_den = w.dot(h).dot(&h.t());
712
713 for (w_val, (num, den)) in w.iter_mut().zip(wt_num.iter().zip(wt_den.iter())) {
714 *w_val = *w_val * (*num / (*den + epsilon));
715 }
716 }
717
718 Ok(w)
719 }
720}
721
722impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for NMF<F> {
727 fn fit_pipeline(
735 &self,
736 x: &Array2<F>,
737 _y: &Array1<F>,
738 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
739 let fitted = self.fit(x, &())?;
740 Ok(Box::new(fitted))
741 }
742}
743
744impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedNMF<F> {
745 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
751 self.transform(x)
752 }
753}
754
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// 4 x 3 positive matrix; row `i` is `[1, 2, 3] + 3i · [1, 1, 1]`.
    fn small_dataset() -> Array2<f64> {
        array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ]
    }

    /// 6 x 4 non-negative matrix containing exact zeros.
    fn medium_dataset() -> Array2<f64> {
        array![
            [5.0, 3.0, 0.0, 1.0],
            [4.0, 0.0, 0.0, 1.0],
            [1.0, 1.0, 0.0, 5.0],
            [1.0, 0.0, 0.0, 4.0],
            [0.0, 1.0, 5.0, 4.0],
            [0.0, 0.0, 4.0, 3.0],
        ]
    }

    /// Default-configured estimator with `n` components and seed 42.
    fn seeded(n: usize) -> NMF<f64> {
        NMF::<f64>::new(n).with_random_state(42)
    }

    #[test]
    fn test_nmf_basic_fit() {
        let fitted = seeded(2).fit(&small_dataset(), &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 3));
    }

    #[test]
    fn test_nmf_components_non_negative() {
        let fitted = seeded(2).fit(&small_dataset(), &()).unwrap();
        fitted.components().iter().for_each(|&val| {
            assert!(
                val >= 0.0,
                "component value should be non-negative, got {val}"
            );
        });
    }

    #[test]
    fn test_nmf_transform_dimensions() {
        let x = small_dataset();
        let fitted = seeded(2).fit(&x, &()).unwrap();
        assert_eq!(fitted.transform(&x).unwrap().dim(), (4, 2));
    }

    #[test]
    fn test_nmf_transform_non_negative() {
        let x = small_dataset();
        let projected = seeded(2).fit(&x, &()).unwrap().transform(&x).unwrap();
        for val in projected.iter().copied() {
            assert!(val >= 0.0, "W value should be non-negative, got {val}");
        }
    }

    #[test]
    fn test_nmf_reconstruction_error_decreases() {
        let x = small_dataset();
        let err_after = |iters: usize| {
            seeded(2)
                .with_max_iter(iters)
                .fit(&x, &())
                .unwrap()
                .reconstruction_err()
        };
        let few = err_after(10);
        let many = err_after(200);
        assert!(
            many <= few + 1e-6,
            "more iterations should reduce error: few={}, many={}",
            few,
            many
        );
    }

    #[test]
    fn test_nmf_reconstruction_error_positive() {
        let fitted = seeded(2).fit(&small_dataset(), &()).unwrap();
        assert!(fitted.reconstruction_err() >= 0.0);
    }

    #[test]
    fn test_nmf_coordinate_descent_solver() {
        let fitted = seeded(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .fit(&medium_dataset(), &())
            .unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
        fitted.components().iter().for_each(|&val| {
            assert!(val >= 0.0, "CD component should be non-negative, got {val}");
        });
    }

    #[test]
    fn test_nmf_nndsvd_init() {
        let fitted = seeded(2)
            .with_init(NMFInit::Nndsvd)
            .fit(&medium_dataset(), &())
            .unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
        fitted.components().iter().for_each(|&val| {
            assert!(
                val >= 0.0,
                "NNDSVD component should be non-negative, got {val}"
            );
        });
    }

    #[test]
    fn test_nmf_cd_with_nndsvd() {
        let fitted = seeded(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_init(NMFInit::Nndsvd)
            .fit(&medium_dataset(), &())
            .unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
    }

    #[test]
    fn test_nmf_invalid_n_components_zero() {
        let result = NMF::<f64>::new(0).fit(&small_dataset(), &());
        assert!(result.is_err());
    }

    #[test]
    fn test_nmf_invalid_n_components_too_large() {
        let result = NMF::<f64>::new(10).fit(&small_dataset(), &());
        assert!(result.is_err());
    }

    #[test]
    fn test_nmf_negative_input_rejected() {
        let x = array![[1.0, -2.0], [3.0, 4.0]];
        assert!(NMF::<f64>::new(1).fit(&x, &()).is_err());
    }

    #[test]
    fn test_nmf_transform_shape_mismatch() {
        let fitted = seeded(2).fit(&small_dataset(), &()).unwrap();
        let x_bad = array![[1.0, 2.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_nmf_transform_negative_rejected() {
        let fitted = seeded(2).fit(&small_dataset(), &()).unwrap();
        let x_neg = array![[1.0, -2.0, 3.0]];
        assert!(fitted.transform(&x_neg).is_err());
    }

    #[test]
    fn test_nmf_reproducibility() {
        let x = small_dataset();
        let first = seeded(2).fit(&x, &()).unwrap();
        let second = seeded(2).fit(&x, &()).unwrap();
        for (a, b) in first.components().iter().zip(second.components().iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_nmf_single_component() {
        let x = small_dataset();
        let fitted = seeded(1).fit(&x, &()).unwrap();
        assert_eq!(fitted.components().nrows(), 1);
        assert_eq!(fitted.transform(&x).unwrap().ncols(), 1);
    }

    #[test]
    fn test_nmf_n_iter_positive() {
        let fitted = seeded(2).fit(&small_dataset(), &()).unwrap();
        assert!(fitted.n_iter() > 0);
    }

    #[test]
    fn test_nmf_getters() {
        let nmf = NMF::<f64>::new(3)
            .with_max_iter(100)
            .with_tol(1e-5)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(99);
        assert_eq!(nmf.n_components(), 3);
        assert_eq!(nmf.max_iter(), 100);
        assert_abs_diff_eq!(nmf.tol(), 1e-5);
        assert_eq!(nmf.solver(), NMFSolver::CoordinateDescent);
        assert_eq!(nmf.init(), NMFInit::Nndsvd);
        assert_eq!(nmf.random_state(), Some(99));
    }

    #[test]
    fn test_nmf_f32() {
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = NMF::<f32>::new(1).with_random_state(42).fit(&x, &()).unwrap();
        assert_eq!(fitted.transform(&x).unwrap().ncols(), 1);
    }

    #[test]
    fn test_nmf_zero_entries() {
        let x = array![[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]];
        let fitted = seeded(2).fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 3));
    }

    #[test]
    fn test_nmf_pipeline_integration() {
        use ferrolearn_core::pipeline::{FittedPipelineEstimator, Pipeline, PipelineEstimator};
        use ferrolearn_core::traits::Predict;

        /// Trivial estimator whose prediction is the row sum of its input.
        struct SumEstimator;

        impl PipelineEstimator<f64> for SumEstimator {
            fn fit_pipeline(
                &self,
                _x: &Array2<f64>,
                _y: &Array1<f64>,
            ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
                Ok(Box::new(FittedSumEstimator))
            }
        }

        struct FittedSumEstimator;

        impl FittedPipelineEstimator<f64> for FittedSumEstimator {
            fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
                Ok(x.rows().into_iter().map(|row| row.sum()).collect())
            }
        }

        let pipeline = Pipeline::new()
            .transform_step("nmf", Box::new(seeded(2)))
            .estimator_step("sum", Box::new(SumEstimator));

        let x = small_dataset();
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);

        let fitted = pipeline.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_nmf_medium_dataset_mu() {
        let fitted = seeded(3)
            .with_solver(NMFSolver::MultiplicativeUpdate)
            .with_max_iter(500)
            .fit(&medium_dataset(), &())
            .unwrap();
        assert_eq!(fitted.components().dim(), (3, 4));
        // Loose sanity bound: a 3-component fit of a 6 x 4 matrix should be decent.
        assert!(
            fitted.reconstruction_err() < 10.0,
            "reconstruction error too large: {}",
            fitted.reconstruction_err()
        );
    }

    #[test]
    fn test_nmf_insufficient_samples() {
        let x = Array2::<f64>::zeros((0, 3));
        assert!(NMF::<f64>::new(1).fit(&x, &()).is_err());
    }

    #[test]
    fn test_nmf_more_components_lower_error() {
        let x = medium_dataset();
        let err_with = |n: usize| {
            seeded(n)
                .with_max_iter(300)
                .fit(&x, &())
                .unwrap()
                .reconstruction_err()
        };
        let one = err_with(1);
        let two = err_with(2);
        assert!(
            two <= one + 1e-6,
            "more components should reduce error: 1comp={}, 2comp={}",
            one,
            two
        );
    }
}