1use ferrolearn_core::error::FerroError;
39use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
40use ferrolearn_core::traits::{Fit, Transform};
41use ndarray::{Array1, Array2};
42use num_traits::Float;
43use rand::SeedableRng;
44use rand_distr::{Distribution, Uniform};
45
/// Optimisation algorithm used by [`NMF`] to fit the factorisation `X ≈ W·H`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NMFSolver {
    /// Multiplicative update rules: each entry of `W` and `H` is rescaled by a ratio of
    /// non-negative matrix products, so non-negative factors stay non-negative.
    MultiplicativeUpdate,
    /// Coordinate descent: each entry of `H`, then of `W`, is set to its least-squares
    /// optimum given all other entries, clipped at zero.
    CoordinateDescent,
}
58
/// Strategy used by [`NMF`] to produce the starting `W` and `H`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NMFInit {
    /// Seeded uniform random entries in `[0, 1)` plus a tiny positive offset.
    Random,
    /// Start derived from the leading eigenvectors of `XᵀX`; rows that come out
    /// (numerically) all-zero are refilled with seeded random values.
    Nndsvd,
}
67
68#[derive(Debug, Clone)]
78pub struct NMF<F> {
79 n_components: usize,
81 max_iter: usize,
83 tol: f64,
85 solver: NMFSolver,
87 init: NMFInit,
89 random_state: Option<u64>,
91 _marker: std::marker::PhantomData<F>,
92}
93
94impl<F: Float + Send + Sync + 'static> NMF<F> {
95 #[must_use]
100 pub fn new(n_components: usize) -> Self {
101 Self {
102 n_components,
103 max_iter: 200,
104 tol: 1e-4,
105 solver: NMFSolver::MultiplicativeUpdate,
106 init: NMFInit::Random,
107 random_state: None,
108 _marker: std::marker::PhantomData,
109 }
110 }
111
112 #[must_use]
114 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
115 self.max_iter = max_iter;
116 self
117 }
118
119 #[must_use]
121 pub fn with_tol(mut self, tol: f64) -> Self {
122 self.tol = tol;
123 self
124 }
125
126 #[must_use]
128 pub fn with_solver(mut self, solver: NMFSolver) -> Self {
129 self.solver = solver;
130 self
131 }
132
133 #[must_use]
135 pub fn with_init(mut self, init: NMFInit) -> Self {
136 self.init = init;
137 self
138 }
139
140 #[must_use]
142 pub fn with_random_state(mut self, seed: u64) -> Self {
143 self.random_state = Some(seed);
144 self
145 }
146
147 #[must_use]
149 pub fn n_components(&self) -> usize {
150 self.n_components
151 }
152
153 #[must_use]
155 pub fn max_iter(&self) -> usize {
156 self.max_iter
157 }
158
159 #[must_use]
161 pub fn tol(&self) -> f64 {
162 self.tol
163 }
164
165 #[must_use]
167 pub fn solver(&self) -> NMFSolver {
168 self.solver
169 }
170
171 #[must_use]
173 pub fn init(&self) -> NMFInit {
174 self.init
175 }
176
177 #[must_use]
179 pub fn random_state(&self) -> Option<u64> {
180 self.random_state
181 }
182}
183
184#[derive(Debug, Clone)]
193pub struct FittedNMF<F> {
194 components_: Array2<F>,
196 reconstruction_err_: F,
198 n_iter_: usize,
200}
201
202impl<F: Float + Send + Sync + 'static> FittedNMF<F> {
203 #[must_use]
205 pub fn components(&self) -> &Array2<F> {
206 &self.components_
207 }
208
209 #[must_use]
211 pub fn reconstruction_err(&self) -> F {
212 self.reconstruction_err_
213 }
214
215 #[must_use]
217 pub fn n_iter(&self) -> usize {
218 self.n_iter_
219 }
220}
221
222fn reconstruction_error<F: Float + 'static>(x: &Array2<F>, w: &Array2<F>, h: &Array2<F>) -> F {
228 let wh = w.dot(h);
229 let mut err = F::zero();
230 for (a, b) in x.iter().zip(wh.iter()) {
231 let diff = *a - *b;
232 err = err + diff * diff;
233 }
234 err.sqrt()
235}
236
237fn eps<F: Float>() -> F {
239 F::from(1e-12).unwrap_or_else(F::epsilon)
240}
241
242fn init_random<F: Float>(
244 n_samples: usize,
245 n_features: usize,
246 n_components: usize,
247 seed: u64,
248) -> (Array2<F>, Array2<F>) {
249 let mut rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(seed);
250 let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
251
252 let mut w = Array2::<F>::zeros((n_samples, n_components));
253 for elem in &mut w {
254 *elem = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) + eps::<F>();
255 }
256
257 let mut h = Array2::<F>::zeros((n_components, n_features));
258 for elem in &mut h {
259 *elem = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) + eps::<F>();
260 }
261
262 (w, h)
263}
264
265fn init_nndsvd<F: Float + Send + Sync + 'static>(
270 x: &Array2<F>,
271 n_components: usize,
272 seed: u64,
273) -> Result<(Array2<F>, Array2<F>), FerroError> {
274 let (n_samples, n_features) = x.dim();
275
276 let mut total = F::zero();
278 for &v in x {
279 total = total + v;
280 }
281 let avg = (total / F::from(n_samples * n_features).unwrap())
282 .abs()
283 .sqrt();
284 let avg = if avg < eps::<F>() { F::one() } else { avg };
285
286 let xtx = x.t().dot(x);
288
289 let max_iter = n_features * n_features * 100 + 1000;
291 let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&xtx, max_iter)?;
292
293 let mut indices: Vec<usize> = (0..n_features).collect();
295 indices.sort_by(|&a, &b| {
296 eigenvalues[b]
297 .partial_cmp(&eigenvalues[a])
298 .unwrap_or(std::cmp::Ordering::Equal)
299 });
300
301 let mut h = Array2::<F>::zeros((n_components, n_features));
303 for (k, &idx) in indices.iter().take(n_components).enumerate() {
304 for j in 0..n_features {
305 let val = eigenvectors[[j, idx]];
306 h[[k, j]] = if val > F::zero() { val } else { F::zero() };
307 }
308 let row_sum: F = h.row(k).iter().copied().fold(F::zero(), |a, b| a + b);
310 if row_sum < eps::<F>() {
311 let mut rng: rand::rngs::StdRng =
313 SeedableRng::seed_from_u64(seed.wrapping_add(k as u64));
314 let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
315 for j in 0..n_features {
316 h[[k, j]] = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) * avg;
317 }
318 }
319 }
320
321 let mut w = Array2::<F>::zeros((n_samples, n_components));
324 let ht = h.t();
327 let w_init = x.dot(&ht);
328 for i in 0..n_samples {
329 for k in 0..n_components {
330 let val = w_init[[i, k]];
331 w[[i, k]] = if val > F::zero() { val } else { eps::<F>() };
332 }
333 }
334
335 Ok((w, h))
336}
337
338fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
343 a: &Array2<F>,
344 max_iter: usize,
345) -> Result<(Array1<F>, Array2<F>), FerroError> {
346 let n = a.nrows();
347 if n == 0 {
348 return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
349 }
350 if n == 1 {
351 let eigenvalues = Array1::from_vec(vec![a[[0, 0]]]);
352 let eigenvectors = Array2::from_shape_vec((1, 1), vec![F::one()]).unwrap();
353 return Ok((eigenvalues, eigenvectors));
354 }
355
356 let mut mat = a.to_owned();
357 let mut v = Array2::<F>::zeros((n, n));
358 for i in 0..n {
359 v[[i, i]] = F::one();
360 }
361
362 let tol = F::from(1e-12).unwrap_or_else(F::epsilon);
363
364 for _iteration in 0..max_iter {
365 let mut max_off = F::zero();
366 let mut p = 0;
367 let mut q = 1;
368 for i in 0..n {
369 for j in (i + 1)..n {
370 let val = mat[[i, j]].abs();
371 if val > max_off {
372 max_off = val;
373 p = i;
374 q = j;
375 }
376 }
377 }
378
379 if max_off < tol {
380 let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
381 return Ok((eigenvalues, v));
382 }
383
384 let app = mat[[p, p]];
385 let aqq = mat[[q, q]];
386 let apq = mat[[p, q]];
387
388 let theta = if (app - aqq).abs() < tol {
389 F::from(std::f64::consts::FRAC_PI_4).unwrap_or_else(F::one)
390 } else {
391 let tau = (aqq - app) / (F::from(2.0).unwrap() * apq);
392 let t = if tau >= F::zero() {
393 F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
394 } else {
395 -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
396 };
397 t.atan()
398 };
399
400 let c = theta.cos();
401 let s = theta.sin();
402
403 let mut new_mat = mat.clone();
404 for i in 0..n {
405 if i != p && i != q {
406 let mip = mat[[i, p]];
407 let miq = mat[[i, q]];
408 new_mat[[i, p]] = c * mip - s * miq;
409 new_mat[[p, i]] = new_mat[[i, p]];
410 new_mat[[i, q]] = s * mip + c * miq;
411 new_mat[[q, i]] = new_mat[[i, q]];
412 }
413 }
414
415 new_mat[[p, p]] = c * c * app - F::from(2.0).unwrap() * s * c * apq + s * s * aqq;
416 new_mat[[q, q]] = s * s * app + F::from(2.0).unwrap() * s * c * apq + c * c * aqq;
417 new_mat[[p, q]] = F::zero();
418 new_mat[[q, p]] = F::zero();
419
420 mat = new_mat;
421
422 for i in 0..n {
423 let vip = v[[i, p]];
424 let viq = v[[i, q]];
425 v[[i, p]] = c * vip - s * viq;
426 v[[i, q]] = s * vip + c * viq;
427 }
428 }
429
430 Err(FerroError::ConvergenceFailure {
431 iterations: max_iter,
432 message: "Jacobi eigendecomposition did not converge in NMF NNDSVD init".into(),
433 })
434}
435
436fn solve_multiplicative_update<F: Float + 'static>(
442 x: &Array2<F>,
443 w: &mut Array2<F>,
444 h: &mut Array2<F>,
445 max_iter: usize,
446 tol: f64,
447) -> usize {
448 let tol_f = F::from(tol).unwrap_or_else(F::epsilon);
449 let epsilon = eps::<F>();
450 let mut prev_err = reconstruction_error(x, w, h);
451
452 for iteration in 0..max_iter {
453 let wt = w.t();
455 let numerator_h = wt.dot(x);
456 let denominator_h = wt.dot(&*w).dot(&*h);
457
458 for (h_val, (num, den)) in h
459 .iter_mut()
460 .zip(numerator_h.iter().zip(denominator_h.iter()))
461 {
462 *h_val = *h_val * (*num / (*den + epsilon));
463 }
464
465 let ht = h.t();
467 let numerator_w = x.dot(&ht);
468 let denominator_w = w.dot(&*h).dot(&ht);
469
470 for (w_val, (num, den)) in w
471 .iter_mut()
472 .zip(numerator_w.iter().zip(denominator_w.iter()))
473 {
474 *w_val = *w_val * (*num / (*den + epsilon));
475 }
476
477 let err = reconstruction_error(x, w, h);
479 if (prev_err - err).abs() < tol_f {
480 return iteration + 1;
481 }
482 prev_err = err;
483 }
484
485 max_iter
486}
487
488fn solve_coordinate_descent<F: Float + 'static>(
492 x: &Array2<F>,
493 w: &mut Array2<F>,
494 h: &mut Array2<F>,
495 max_iter: usize,
496 tol: f64,
497) -> usize {
498 let (n_samples, n_features) = x.dim();
499 let n_components = h.nrows();
500 let tol_f = F::from(tol).unwrap_or_else(F::epsilon);
501 let epsilon = eps::<F>();
502 let mut prev_err = reconstruction_error(x, w, h);
503
504 for iteration in 0..max_iter {
505 for k in 0..n_components {
508 let mut wk_norm_sq = F::zero();
509 for i in 0..n_samples {
510 wk_norm_sq = wk_norm_sq + w[[i, k]] * w[[i, k]];
511 }
512
513 if wk_norm_sq < epsilon {
514 continue;
515 }
516
517 for j in 0..n_features {
518 let mut numerator = F::zero();
520 for i in 0..n_samples {
521 let mut wh_ij = F::zero();
522 for kk in 0..n_components {
523 if kk != k {
524 wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
525 }
526 }
527 numerator = numerator + w[[i, k]] * (x[[i, j]] - wh_ij);
528 }
529
530 h[[k, j]] = if numerator > F::zero() {
531 numerator / wk_norm_sq
532 } else {
533 F::zero()
534 };
535 }
536 }
537
538 for k in 0..n_components {
540 let mut hk_norm_sq = F::zero();
541 for j in 0..n_features {
542 hk_norm_sq = hk_norm_sq + h[[k, j]] * h[[k, j]];
543 }
544
545 if hk_norm_sq < epsilon {
546 continue;
547 }
548
549 for i in 0..n_samples {
550 let mut numerator = F::zero();
551 for j in 0..n_features {
552 let mut wh_ij = F::zero();
553 for kk in 0..n_components {
554 if kk != k {
555 wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
556 }
557 }
558 numerator = numerator + h[[k, j]] * (x[[i, j]] - wh_ij);
559 }
560
561 w[[i, k]] = if numerator > F::zero() {
562 numerator / hk_norm_sq
563 } else {
564 F::zero()
565 };
566 }
567 }
568
569 let err = reconstruction_error(x, w, h);
571 if (prev_err - err).abs() < tol_f {
572 return iteration + 1;
573 }
574 prev_err = err;
575 }
576
577 max_iter
578}
579
580impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for NMF<F> {
585 type Fitted = FittedNMF<F>;
586 type Error = FerroError;
587
588 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedNMF<F>, FerroError> {
598 let (n_samples, n_features) = x.dim();
599
600 if self.n_components == 0 {
601 return Err(FerroError::InvalidParameter {
602 name: "n_components".into(),
603 reason: "must be at least 1".into(),
604 });
605 }
606 if n_samples == 0 {
607 return Err(FerroError::InsufficientSamples {
608 required: 1,
609 actual: 0,
610 context: "NMF::fit".into(),
611 });
612 }
613 if self.n_components > n_samples.min(n_features) {
614 return Err(FerroError::InvalidParameter {
615 name: "n_components".into(),
616 reason: format!(
617 "n_components ({}) exceeds min(n_samples, n_features) = {}",
618 self.n_components,
619 n_samples.min(n_features)
620 ),
621 });
622 }
623
624 for &val in x {
626 if val < F::zero() {
627 return Err(FerroError::InvalidParameter {
628 name: "X".into(),
629 reason: "NMF requires all entries in X to be non-negative".into(),
630 });
631 }
632 }
633
634 let seed = self.random_state.unwrap_or(0);
635
636 let (mut w, mut h) = match self.init {
638 NMFInit::Random => init_random(n_samples, n_features, self.n_components, seed),
639 NMFInit::Nndsvd => init_nndsvd(x, self.n_components, seed)?,
640 };
641
642 let n_iter = match self.solver {
644 NMFSolver::MultiplicativeUpdate => {
645 solve_multiplicative_update(x, &mut w, &mut h, self.max_iter, self.tol)
646 }
647 NMFSolver::CoordinateDescent => {
648 solve_coordinate_descent(x, &mut w, &mut h, self.max_iter, self.tol)
649 }
650 };
651
652 let reconstruction_err = reconstruction_error(x, &w, &h);
653
654 Ok(FittedNMF {
655 components_: h,
656 reconstruction_err_: reconstruction_err,
657 n_iter_: n_iter,
658 })
659 }
660}
661
662impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedNMF<F> {
663 type Output = Array2<F>;
664 type Error = FerroError;
665
666 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
677 let n_features = self.components_.ncols();
678 if x.ncols() != n_features {
679 return Err(FerroError::ShapeMismatch {
680 expected: vec![x.nrows(), n_features],
681 actual: vec![x.nrows(), x.ncols()],
682 context: "FittedNMF::transform".into(),
683 });
684 }
685
686 for &val in x {
688 if val < F::zero() {
689 return Err(FerroError::InvalidParameter {
690 name: "X".into(),
691 reason: "NMF requires all entries in X to be non-negative".into(),
692 });
693 }
694 }
695
696 let n_samples = x.nrows();
697 let n_components = self.components_.nrows();
698 let epsilon = eps::<F>();
699
700 let mut w = Array2::<F>::zeros((n_samples, n_components));
702 let init_val = F::from(0.1).unwrap_or_else(F::one);
703 w.fill(init_val);
704
705 let h = &self.components_;
707 for _iter in 0..200 {
708 let wt_num = x.dot(&h.t());
709 let wt_den = w.dot(h).dot(&h.t());
710
711 for (w_val, (num, den)) in w.iter_mut().zip(wt_num.iter().zip(wt_den.iter())) {
712 *w_val = *w_val * (*num / (*den + epsilon));
713 }
714 }
715
716 Ok(w)
717 }
718}
719
720impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for NMF<F> {
725 fn fit_pipeline(
733 &self,
734 x: &Array2<F>,
735 _y: &Array1<F>,
736 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
737 let fitted = self.fit(x, &())?;
738 Ok(Box::new(fitted))
739 }
740}
741
742impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedNMF<F> {
743 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
749 self.transform(x)
750 }
751}
752
/// Unit tests for [`NMF`] and [`FittedNMF`]. They cover output shapes, non-negativity,
/// solver and initialisation variants, parameter validation, reproducibility, and
/// pipeline integration.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // 4×3 non-negative matrix. Each row is the previous one plus 3 in every column,
    // so the matrix has rank 2.
    fn small_dataset() -> Array2<f64> {
        array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ]
    }

    // 6×4 non-negative matrix containing zeros, used for solver/init variants.
    fn medium_dataset() -> Array2<f64> {
        array![
            [5.0, 3.0, 0.0, 1.0],
            [4.0, 0.0, 0.0, 1.0],
            [1.0, 1.0, 0.0, 5.0],
            [1.0, 0.0, 0.0, 4.0],
            [0.0, 1.0, 5.0, 4.0],
            [0.0, 0.0, 4.0, 3.0],
        ]
    }

    // Components have shape (n_components, n_features).
    #[test]
    fn test_nmf_basic_fit() {
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 3));
    }

    // Default solver (multiplicative updates) must keep H non-negative.
    #[test]
    fn test_nmf_components_non_negative() {
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        for &val in fitted.components() {
            assert!(
                val >= 0.0,
                "component value should be non-negative, got {val}"
            );
        }
    }

    // transform returns W with shape (n_samples, n_components).
    #[test]
    fn test_nmf_transform_dimensions() {
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (4, 2));
    }

    // The W produced by transform is non-negative.
    #[test]
    fn test_nmf_transform_non_negative() {
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        let projected = fitted.transform(&x).unwrap();
        for &val in &projected {
            assert!(val >= 0.0, "W value should be non-negative, got {val}");
        }
    }

    // Same seed and init: running more iterations must not increase the final error
    // (small slack for rounding).
    #[test]
    fn test_nmf_reconstruction_error_decreases() {
        let nmf_few = NMF::<f64>::new(2).with_random_state(42).with_max_iter(10);
        let nmf_many = NMF::<f64>::new(2).with_random_state(42).with_max_iter(200);
        let x = small_dataset();
        let fitted_few = nmf_few.fit(&x, &()).unwrap();
        let fitted_many = nmf_many.fit(&x, &()).unwrap();
        assert!(
            fitted_many.reconstruction_err() <= fitted_few.reconstruction_err() + 1e-6,
            "more iterations should reduce error: few={}, many={}",
            fitted_few.reconstruction_err(),
            fitted_many.reconstruction_err()
        );
    }

    // A Frobenius norm is never negative.
    #[test]
    fn test_nmf_reconstruction_error_positive() {
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert!(fitted.reconstruction_err() >= 0.0);
    }

    // Coordinate-descent solver: shape and non-negativity of H.
    #[test]
    fn test_nmf_coordinate_descent_solver() {
        let nmf = NMF::<f64>::new(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_random_state(42);
        let x = medium_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
        for &val in fitted.components() {
            assert!(val >= 0.0, "CD component should be non-negative, got {val}");
        }
    }

    // NNDSVD initialisation with the default solver: shape and non-negativity of H.
    #[test]
    fn test_nmf_nndsvd_init() {
        let nmf = NMF::<f64>::new(2)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(42);
        let x = medium_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
        for &val in fitted.components() {
            assert!(
                val >= 0.0,
                "NNDSVD component should be non-negative, got {val}"
            );
        }
    }

    // Coordinate descent combined with NNDSVD initialisation runs to completion.
    #[test]
    fn test_nmf_cd_with_nndsvd() {
        let nmf = NMF::<f64>::new(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(42);
        let x = medium_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
    }

    // n_components = 0 is rejected.
    #[test]
    fn test_nmf_invalid_n_components_zero() {
        let nmf = NMF::<f64>::new(0);
        let x = small_dataset();
        assert!(nmf.fit(&x, &()).is_err());
    }

    // n_components larger than min(n_samples, n_features) = 3 is rejected.
    #[test]
    fn test_nmf_invalid_n_components_too_large() {
        let nmf = NMF::<f64>::new(10);
        let x = small_dataset();
        assert!(nmf.fit(&x, &()).is_err());
    }

    // fit rejects any negative entry in X.
    #[test]
    fn test_nmf_negative_input_rejected() {
        let nmf = NMF::<f64>::new(1);
        let x = array![[1.0, -2.0], [3.0, 4.0]];
        assert!(nmf.fit(&x, &()).is_err());
    }

    // transform rejects inputs whose column count differs from the training data (3).
    #[test]
    fn test_nmf_transform_shape_mismatch() {
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        let x_bad = array![[1.0, 2.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    // transform rejects any negative entry in X.
    #[test]
    fn test_nmf_transform_negative_rejected() {
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        let x_neg = array![[1.0, -2.0, 3.0]];
        assert!(fitted.transform(&x_neg).is_err());
    }

    // Identical configuration and seed produce identical components.
    #[test]
    fn test_nmf_reproducibility() {
        let nmf1 = NMF::<f64>::new(2).with_random_state(42);
        let nmf2 = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted1 = nmf1.fit(&x, &()).unwrap();
        let fitted2 = nmf2.fit(&x, &()).unwrap();
        for (a, b) in fitted1.components().iter().zip(fitted2.components().iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    // Rank-1 factorisation: one component in, one column out of transform.
    #[test]
    fn test_nmf_single_component() {
        let nmf = NMF::<f64>::new(1).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().nrows(), 1);
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.ncols(), 1);
    }

    // At least one solver iteration is recorded.
    #[test]
    fn test_nmf_n_iter_positive() {
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = small_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() > 0);
    }

    // Every builder setting is reflected by the matching getter.
    #[test]
    fn test_nmf_getters() {
        let nmf = NMF::<f64>::new(3)
            .with_max_iter(100)
            .with_tol(1e-5)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(99);
        assert_eq!(nmf.n_components(), 3);
        assert_eq!(nmf.max_iter(), 100);
        assert_abs_diff_eq!(nmf.tol(), 1e-5);
        assert_eq!(nmf.solver(), NMFSolver::CoordinateDescent);
        assert_eq!(nmf.init(), NMFInit::Nndsvd);
        assert_eq!(nmf.random_state(), Some(99));
    }

    // The estimator is generic over the float type: fit and transform also work with f32.
    #[test]
    fn test_nmf_f32() {
        let nmf = NMF::<f32>::new(1).with_random_state(42);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = nmf.fit(&x, &()).unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.ncols(), 1);
    }

    // A sparse input (permutation matrix, mostly zeros) still fits.
    #[test]
    fn test_nmf_zero_entries() {
        let nmf = NMF::<f64>::new(2).with_random_state(42);
        let x = array![[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]];
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 3));
    }

    // NMF works as a transform step inside a Pipeline that ends in a dummy estimator
    // returning row sums.
    #[test]
    fn test_nmf_pipeline_integration() {
        use ferrolearn_core::pipeline::{FittedPipelineEstimator, Pipeline, PipelineEstimator};
        use ferrolearn_core::traits::Predict;

        // Trivial final estimator: fitting ignores its inputs.
        struct SumEstimator;

        impl PipelineEstimator<f64> for SumEstimator {
            fn fit_pipeline(
                &self,
                _x: &Array2<f64>,
                _y: &Array1<f64>,
            ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
                Ok(Box::new(FittedSumEstimator))
            }
        }

        // Predicts the sum of each (NMF-transformed) row.
        struct FittedSumEstimator;

        impl FittedPipelineEstimator<f64> for FittedSumEstimator {
            fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
                let sums: Vec<f64> = x.rows().into_iter().map(|r| r.sum()).collect();
                Ok(Array1::from_vec(sums))
            }
        }

        let pipeline = Pipeline::new()
            .transform_step("nmf", Box::new(NMF::<f64>::new(2).with_random_state(42)))
            .estimator_step("sum", Box::new(SumEstimator));

        let x = small_dataset();
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);

        let fitted = pipeline.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // Three components on the medium dataset with the MU solver.
    #[test]
    fn test_nmf_medium_dataset_mu() {
        let nmf = NMF::<f64>::new(3)
            .with_solver(NMFSolver::MultiplicativeUpdate)
            .with_random_state(42)
            .with_max_iter(500);
        let x = medium_dataset();
        let fitted = nmf.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (3, 4));
        // Loose sanity bound on the fit quality, not a precise expected value.
        assert!(
            fitted.reconstruction_err() < 10.0,
            "reconstruction error too large: {}",
            fitted.reconstruction_err()
        );
    }

    // An input with zero rows is rejected.
    #[test]
    fn test_nmf_insufficient_samples() {
        let nmf = NMF::<f64>::new(1);
        let x = Array2::<f64>::zeros((0, 3));
        assert!(nmf.fit(&x, &()).is_err());
    }

    // More components should not fit worse (small slack for rounding).
    #[test]
    fn test_nmf_more_components_lower_error() {
        let nmf1 = NMF::<f64>::new(1).with_random_state(42).with_max_iter(300);
        let nmf2 = NMF::<f64>::new(2).with_random_state(42).with_max_iter(300);
        let x = medium_dataset();
        let fitted1 = nmf1.fit(&x, &()).unwrap();
        let fitted2 = nmf2.fit(&x, &()).unwrap();
        assert!(
            fitted2.reconstruction_err() <= fitted1.reconstruction_err() + 1e-6,
            "more components should reduce error: 1comp={}, 2comp={}",
            fitted1.reconstruction_err(),
            fitted2.reconstruction_err()
        );
    }
}