1use ferrolearn_core::error::FerroError;
39use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
40use ferrolearn_core::traits::{Fit, Transform};
41use ndarray::{Array1, Array2};
42use num_traits::Float;
43use rand::SeedableRng;
44use rand_distr::{Distribution, Uniform};
45
/// Optimisation algorithm used to refine the `W` and `H` factors during fitting.
///
/// The default variant is [`NMFSolver::MultiplicativeUpdate`], the same solver
/// that [`NMF::new`] selects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum NMFSolver {
    /// Alternating multiplicative update rules for `H` and `W`.
    #[default]
    MultiplicativeUpdate,
    /// Coordinate descent that updates one component of `H`, then of `W`, at a time
    /// and clamps negative solutions to zero.
    CoordinateDescent,
}
58
/// Strategy used to build the initial `W` and `H` factors before the solver runs.
///
/// The default variant is [`NMFInit::Random`], the same strategy that
/// [`NMF::new`] selects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum NMFInit {
    /// Fill both factors with seeded uniform random values in `[0, 1)` plus a tiny offset.
    #[default]
    Random,
    /// Derive the factors from the leading eigenvectors of `XᵀX`
    /// (an SVD-based, non-negative initialisation).
    Nndsvd,
}
67
68#[derive(Debug, Clone)]
78pub struct NMF<F> {
79 n_components: usize,
81 max_iter: usize,
83 tol: f64,
85 solver: NMFSolver,
87 init: NMFInit,
89 random_state: Option<u64>,
91 _marker: std::marker::PhantomData<F>,
92}
93
94impl<F: Float + Send + Sync + 'static> NMF<F> {
95 #[must_use]
100 pub fn new(n_components: usize) -> Self {
101 Self {
102 n_components,
103 max_iter: 200,
104 tol: 1e-4,
105 solver: NMFSolver::MultiplicativeUpdate,
106 init: NMFInit::Random,
107 random_state: None,
108 _marker: std::marker::PhantomData,
109 }
110 }
111
112 #[must_use]
114 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
115 self.max_iter = max_iter;
116 self
117 }
118
119 #[must_use]
121 pub fn with_tol(mut self, tol: f64) -> Self {
122 self.tol = tol;
123 self
124 }
125
126 #[must_use]
128 pub fn with_solver(mut self, solver: NMFSolver) -> Self {
129 self.solver = solver;
130 self
131 }
132
133 #[must_use]
135 pub fn with_init(mut self, init: NMFInit) -> Self {
136 self.init = init;
137 self
138 }
139
140 #[must_use]
142 pub fn with_random_state(mut self, seed: u64) -> Self {
143 self.random_state = Some(seed);
144 self
145 }
146
147 #[must_use]
149 pub fn n_components(&self) -> usize {
150 self.n_components
151 }
152
153 #[must_use]
155 pub fn max_iter(&self) -> usize {
156 self.max_iter
157 }
158
159 #[must_use]
161 pub fn tol(&self) -> f64 {
162 self.tol
163 }
164
165 #[must_use]
167 pub fn solver(&self) -> NMFSolver {
168 self.solver
169 }
170
171 #[must_use]
173 pub fn init(&self) -> NMFInit {
174 self.init
175 }
176
177 #[must_use]
179 pub fn random_state(&self) -> Option<u64> {
180 self.random_state
181 }
182}
183
184#[derive(Debug, Clone)]
193pub struct FittedNMF<F> {
194 components_: Array2<F>,
196 reconstruction_err_: F,
198 n_iter_: usize,
200}
201
202impl<F: Float + Send + Sync + 'static> FittedNMF<F> {
203 #[must_use]
205 pub fn components(&self) -> &Array2<F> {
206 &self.components_
207 }
208
209 #[must_use]
211 pub fn reconstruction_err(&self) -> F {
212 self.reconstruction_err_
213 }
214
215 #[must_use]
217 pub fn n_iter(&self) -> usize {
218 self.n_iter_
219 }
220
221 pub fn inverse_transform(&self, w: &Array2<F>) -> Result<Array2<F>, FerroError> {
230 let n_components = self.components_.nrows();
231 if w.ncols() != n_components {
232 return Err(FerroError::ShapeMismatch {
233 expected: vec![w.nrows(), n_components],
234 actual: vec![w.nrows(), w.ncols()],
235 context: "FittedNMF::inverse_transform".into(),
236 });
237 }
238 Ok(w.dot(&self.components_))
239 }
240}
241
242fn reconstruction_error<F: Float + 'static>(x: &Array2<F>, w: &Array2<F>, h: &Array2<F>) -> F {
248 let wh = w.dot(h);
249 let mut err = F::zero();
250 for (a, b) in x.iter().zip(wh.iter()) {
251 let diff = *a - *b;
252 err = err + diff * diff;
253 }
254 err.sqrt()
255}
256
257fn eps<F: Float>() -> F {
259 F::from(1e-12).unwrap_or_else(F::epsilon)
260}
261
262fn init_random<F: Float>(
264 n_samples: usize,
265 n_features: usize,
266 n_components: usize,
267 seed: u64,
268) -> (Array2<F>, Array2<F>) {
269 let mut rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(seed);
270 let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
271
272 let mut w = Array2::<F>::zeros((n_samples, n_components));
273 for elem in &mut w {
274 *elem = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) + eps::<F>();
275 }
276
277 let mut h = Array2::<F>::zeros((n_components, n_features));
278 for elem in &mut h {
279 *elem = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) + eps::<F>();
280 }
281
282 (w, h)
283}
284
285fn init_nndsvd<F: Float + Send + Sync + 'static>(
290 x: &Array2<F>,
291 n_components: usize,
292 seed: u64,
293) -> Result<(Array2<F>, Array2<F>), FerroError> {
294 let (n_samples, n_features) = x.dim();
295
296 let mut total = F::zero();
298 for &v in x {
299 total = total + v;
300 }
301 let avg = (total / F::from(n_samples * n_features).unwrap())
302 .abs()
303 .sqrt();
304 let avg = if avg < eps::<F>() { F::one() } else { avg };
305
306 let xtx = x.t().dot(x);
308
309 let max_iter = n_features * n_features * 100 + 1000;
311 let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&xtx, max_iter)?;
312
313 let mut indices: Vec<usize> = (0..n_features).collect();
315 indices.sort_by(|&a, &b| {
316 eigenvalues[b]
317 .partial_cmp(&eigenvalues[a])
318 .unwrap_or(std::cmp::Ordering::Equal)
319 });
320
321 let mut h = Array2::<F>::zeros((n_components, n_features));
323 for (k, &idx) in indices.iter().take(n_components).enumerate() {
324 for j in 0..n_features {
325 let val = eigenvectors[[j, idx]];
326 h[[k, j]] = if val > F::zero() { val } else { F::zero() };
327 }
328 let row_sum: F = h.row(k).iter().copied().fold(F::zero(), |a, b| a + b);
330 if row_sum < eps::<F>() {
331 let mut rng: rand::rngs::StdRng =
333 SeedableRng::seed_from_u64(seed.wrapping_add(k as u64));
334 let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
335 for j in 0..n_features {
336 h[[k, j]] = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) * avg;
337 }
338 }
339 }
340
341 let mut w = Array2::<F>::zeros((n_samples, n_components));
344 let ht = h.t();
347 let w_init = x.dot(&ht);
348 for i in 0..n_samples {
349 for k in 0..n_components {
350 let val = w_init[[i, k]];
351 w[[i, k]] = if val > F::zero() { val } else { eps::<F>() };
352 }
353 }
354
355 Ok((w, h))
356}
357
358fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
363 a: &Array2<F>,
364 max_iter: usize,
365) -> Result<(Array1<F>, Array2<F>), FerroError> {
366 let n = a.nrows();
367 if n == 0 {
368 return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
369 }
370 if n == 1 {
371 let eigenvalues = Array1::from_vec(vec![a[[0, 0]]]);
372 let eigenvectors = Array2::from_shape_vec((1, 1), vec![F::one()]).unwrap();
373 return Ok((eigenvalues, eigenvectors));
374 }
375
376 let mut mat = a.to_owned();
377 let mut v = Array2::<F>::zeros((n, n));
378 for i in 0..n {
379 v[[i, i]] = F::one();
380 }
381
382 let tol = F::from(1e-12).unwrap_or_else(F::epsilon);
383
384 for _iteration in 0..max_iter {
385 let mut max_off = F::zero();
386 let mut p = 0;
387 let mut q = 1;
388 for i in 0..n {
389 for j in (i + 1)..n {
390 let val = mat[[i, j]].abs();
391 if val > max_off {
392 max_off = val;
393 p = i;
394 q = j;
395 }
396 }
397 }
398
399 if max_off < tol {
400 let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
401 return Ok((eigenvalues, v));
402 }
403
404 let app = mat[[p, p]];
405 let aqq = mat[[q, q]];
406 let apq = mat[[p, q]];
407
408 let theta = if (app - aqq).abs() < tol {
409 F::from(std::f64::consts::FRAC_PI_4).unwrap_or_else(F::one)
410 } else {
411 let tau = (aqq - app) / (F::from(2.0).unwrap() * apq);
412 let t = if tau >= F::zero() {
413 F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
414 } else {
415 -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
416 };
417 t.atan()
418 };
419
420 let c = theta.cos();
421 let s = theta.sin();
422
423 let mut new_mat = mat.clone();
424 for i in 0..n {
425 if i != p && i != q {
426 let mip = mat[[i, p]];
427 let miq = mat[[i, q]];
428 new_mat[[i, p]] = c * mip - s * miq;
429 new_mat[[p, i]] = new_mat[[i, p]];
430 new_mat[[i, q]] = s * mip + c * miq;
431 new_mat[[q, i]] = new_mat[[i, q]];
432 }
433 }
434
435 new_mat[[p, p]] = c * c * app - F::from(2.0).unwrap() * s * c * apq + s * s * aqq;
436 new_mat[[q, q]] = s * s * app + F::from(2.0).unwrap() * s * c * apq + c * c * aqq;
437 new_mat[[p, q]] = F::zero();
438 new_mat[[q, p]] = F::zero();
439
440 mat = new_mat;
441
442 for i in 0..n {
443 let vip = v[[i, p]];
444 let viq = v[[i, q]];
445 v[[i, p]] = c * vip - s * viq;
446 v[[i, q]] = s * vip + c * viq;
447 }
448 }
449
450 Err(FerroError::ConvergenceFailure {
451 iterations: max_iter,
452 message: "Jacobi eigendecomposition did not converge in NMF NNDSVD init".into(),
453 })
454}
455
456fn solve_multiplicative_update<F: Float + 'static>(
462 x: &Array2<F>,
463 w: &mut Array2<F>,
464 h: &mut Array2<F>,
465 max_iter: usize,
466 tol: f64,
467) -> usize {
468 let tol_f = F::from(tol).unwrap_or_else(F::epsilon);
469 let epsilon = eps::<F>();
470 let mut prev_err = reconstruction_error(x, w, h);
471
472 for iteration in 0..max_iter {
473 let wt = w.t();
475 let numerator_h = wt.dot(x);
476 let denominator_h = wt.dot(&*w).dot(&*h);
477
478 for (h_val, (num, den)) in h
479 .iter_mut()
480 .zip(numerator_h.iter().zip(denominator_h.iter()))
481 {
482 *h_val = *h_val * (*num / (*den + epsilon));
483 }
484
485 let ht = h.t();
487 let numerator_w = x.dot(&ht);
488 let denominator_w = w.dot(&*h).dot(&ht);
489
490 for (w_val, (num, den)) in w
491 .iter_mut()
492 .zip(numerator_w.iter().zip(denominator_w.iter()))
493 {
494 *w_val = *w_val * (*num / (*den + epsilon));
495 }
496
497 let err = reconstruction_error(x, w, h);
499 if (prev_err - err).abs() < tol_f {
500 return iteration + 1;
501 }
502 prev_err = err;
503 }
504
505 max_iter
506}
507
508fn solve_coordinate_descent<F: Float + 'static>(
512 x: &Array2<F>,
513 w: &mut Array2<F>,
514 h: &mut Array2<F>,
515 max_iter: usize,
516 tol: f64,
517) -> usize {
518 let (n_samples, n_features) = x.dim();
519 let n_components = h.nrows();
520 let tol_f = F::from(tol).unwrap_or_else(F::epsilon);
521 let epsilon = eps::<F>();
522 let mut prev_err = reconstruction_error(x, w, h);
523
524 for iteration in 0..max_iter {
525 for k in 0..n_components {
528 let mut wk_norm_sq = F::zero();
529 for i in 0..n_samples {
530 wk_norm_sq = wk_norm_sq + w[[i, k]] * w[[i, k]];
531 }
532
533 if wk_norm_sq < epsilon {
534 continue;
535 }
536
537 for j in 0..n_features {
538 let mut numerator = F::zero();
540 for i in 0..n_samples {
541 let mut wh_ij = F::zero();
542 for kk in 0..n_components {
543 if kk != k {
544 wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
545 }
546 }
547 numerator = numerator + w[[i, k]] * (x[[i, j]] - wh_ij);
548 }
549
550 h[[k, j]] = if numerator > F::zero() {
551 numerator / wk_norm_sq
552 } else {
553 F::zero()
554 };
555 }
556 }
557
558 for k in 0..n_components {
560 let mut hk_norm_sq = F::zero();
561 for j in 0..n_features {
562 hk_norm_sq = hk_norm_sq + h[[k, j]] * h[[k, j]];
563 }
564
565 if hk_norm_sq < epsilon {
566 continue;
567 }
568
569 for i in 0..n_samples {
570 let mut numerator = F::zero();
571 for j in 0..n_features {
572 let mut wh_ij = F::zero();
573 for kk in 0..n_components {
574 if kk != k {
575 wh_ij = wh_ij + w[[i, kk]] * h[[kk, j]];
576 }
577 }
578 numerator = numerator + h[[k, j]] * (x[[i, j]] - wh_ij);
579 }
580
581 w[[i, k]] = if numerator > F::zero() {
582 numerator / hk_norm_sq
583 } else {
584 F::zero()
585 };
586 }
587 }
588
589 let err = reconstruction_error(x, w, h);
591 if (prev_err - err).abs() < tol_f {
592 return iteration + 1;
593 }
594 prev_err = err;
595 }
596
597 max_iter
598}
599
600impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for NMF<F> {
605 type Fitted = FittedNMF<F>;
606 type Error = FerroError;
607
608 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedNMF<F>, FerroError> {
618 let (n_samples, n_features) = x.dim();
619
620 if self.n_components == 0 {
621 return Err(FerroError::InvalidParameter {
622 name: "n_components".into(),
623 reason: "must be at least 1".into(),
624 });
625 }
626 if n_samples == 0 {
627 return Err(FerroError::InsufficientSamples {
628 required: 1,
629 actual: 0,
630 context: "NMF::fit".into(),
631 });
632 }
633 if self.n_components > n_samples.min(n_features) {
634 return Err(FerroError::InvalidParameter {
635 name: "n_components".into(),
636 reason: format!(
637 "n_components ({}) exceeds min(n_samples, n_features) = {}",
638 self.n_components,
639 n_samples.min(n_features)
640 ),
641 });
642 }
643
644 for &val in x {
646 if val < F::zero() {
647 return Err(FerroError::InvalidParameter {
648 name: "X".into(),
649 reason: "NMF requires all entries in X to be non-negative".into(),
650 });
651 }
652 }
653
654 let seed = self.random_state.unwrap_or(0);
655
656 let (mut w, mut h) = match self.init {
658 NMFInit::Random => init_random(n_samples, n_features, self.n_components, seed),
659 NMFInit::Nndsvd => init_nndsvd(x, self.n_components, seed)?,
660 };
661
662 let n_iter = match self.solver {
664 NMFSolver::MultiplicativeUpdate => {
665 solve_multiplicative_update(x, &mut w, &mut h, self.max_iter, self.tol)
666 }
667 NMFSolver::CoordinateDescent => {
668 solve_coordinate_descent(x, &mut w, &mut h, self.max_iter, self.tol)
669 }
670 };
671
672 let reconstruction_err = reconstruction_error(x, &w, &h);
673
674 Ok(FittedNMF {
675 components_: h,
676 reconstruction_err_: reconstruction_err,
677 n_iter_: n_iter,
678 })
679 }
680}
681
682impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedNMF<F> {
683 type Output = Array2<F>;
684 type Error = FerroError;
685
686 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
697 let n_features = self.components_.ncols();
698 if x.ncols() != n_features {
699 return Err(FerroError::ShapeMismatch {
700 expected: vec![x.nrows(), n_features],
701 actual: vec![x.nrows(), x.ncols()],
702 context: "FittedNMF::transform".into(),
703 });
704 }
705
706 for &val in x {
708 if val < F::zero() {
709 return Err(FerroError::InvalidParameter {
710 name: "X".into(),
711 reason: "NMF requires all entries in X to be non-negative".into(),
712 });
713 }
714 }
715
716 let n_samples = x.nrows();
717 let n_components = self.components_.nrows();
718 let epsilon = eps::<F>();
719
720 let mut w = Array2::<F>::zeros((n_samples, n_components));
722 let init_val = F::from(0.1).unwrap_or_else(F::one);
723 w.fill(init_val);
724
725 let h = &self.components_;
727 for _iter in 0..200 {
728 let wt_num = x.dot(&h.t());
729 let wt_den = w.dot(h).dot(&h.t());
730
731 for (w_val, (num, den)) in w.iter_mut().zip(wt_num.iter().zip(wt_den.iter())) {
732 *w_val = *w_val * (*num / (*den + epsilon));
733 }
734 }
735
736 Ok(w)
737 }
738}
739
740impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for NMF<F> {
745 fn fit_pipeline(
753 &self,
754 x: &Array2<F>,
755 _y: &Array1<F>,
756 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
757 let fitted = self.fit(x, &())?;
758 Ok(Box::new(fitted))
759 }
760}
761
762impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedNMF<F> {
763 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
769 self.transform(x)
770 }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// 4×3 strictly positive matrix.
    fn small_dataset() -> Array2<f64> {
        array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ]
    }

    /// 6×4 non-negative matrix containing zeros.
    fn medium_dataset() -> Array2<f64> {
        array![
            [5.0, 3.0, 0.0, 1.0],
            [4.0, 0.0, 0.0, 1.0],
            [1.0, 1.0, 0.0, 5.0],
            [1.0, 0.0, 0.0, 4.0],
            [0.0, 1.0, 5.0, 4.0],
            [0.0, 0.0, 4.0, 3.0],
        ]
    }

    /// Default-configured model with seed 42.
    fn seeded(n_components: usize) -> NMF<f64> {
        NMF::<f64>::new(n_components).with_random_state(42)
    }

    /// Fits a default-configured model with seed 42 on the small data set.
    fn fit_small(n_components: usize) -> FittedNMF<f64> {
        seeded(n_components).fit(&small_dataset(), &()).unwrap()
    }

    /// Asserts that every value is `>= 0`, prefixing failures with `label`.
    fn assert_non_negative<'a>(values: impl IntoIterator<Item = &'a f64>, label: &str) {
        for &val in values {
            assert!(val >= 0.0, "{label} should be non-negative, got {val}");
        }
    }

    #[test]
    fn test_nmf_basic_fit() {
        let model = fit_small(2);
        assert_eq!(model.components().dim(), (2, 3));
    }

    #[test]
    fn test_nmf_components_non_negative() {
        let model = fit_small(2);
        assert_non_negative(model.components(), "component value");
    }

    #[test]
    fn test_nmf_transform_dimensions() {
        let data = small_dataset();
        let model = seeded(2).fit(&data, &()).unwrap();
        let embedded = model.transform(&data).unwrap();
        assert_eq!(embedded.dim(), (4, 2));
    }

    #[test]
    fn test_nmf_transform_non_negative() {
        let data = small_dataset();
        let model = seeded(2).fit(&data, &()).unwrap();
        let embedded = model.transform(&data).unwrap();
        assert_non_negative(&embedded, "W value");
    }

    #[test]
    fn test_nmf_reconstruction_error_decreases() {
        let data = small_dataset();
        let short_run = seeded(2).with_max_iter(10).fit(&data, &()).unwrap();
        let long_run = seeded(2).with_max_iter(200).fit(&data, &()).unwrap();
        assert!(
            long_run.reconstruction_err() <= short_run.reconstruction_err() + 1e-6,
            "more iterations should reduce error: few={}, many={}",
            short_run.reconstruction_err(),
            long_run.reconstruction_err()
        );
    }

    #[test]
    fn test_nmf_reconstruction_error_positive() {
        let model = fit_small(2);
        assert!(model.reconstruction_err() >= 0.0);
    }

    #[test]
    fn test_nmf_coordinate_descent_solver() {
        let data = medium_dataset();
        let model = NMF::<f64>::new(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        assert_eq!(model.components().dim(), (2, 4));
        assert_non_negative(model.components(), "CD component");
    }

    #[test]
    fn test_nmf_nndsvd_init() {
        let data = medium_dataset();
        let model = NMF::<f64>::new(2)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        assert_eq!(model.components().dim(), (2, 4));
        assert_non_negative(model.components(), "NNDSVD component");
    }

    #[test]
    fn test_nmf_cd_with_nndsvd() {
        let data = medium_dataset();
        let model = NMF::<f64>::new(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        assert_eq!(model.components().dim(), (2, 4));
    }

    #[test]
    fn test_nmf_invalid_n_components_zero() {
        let outcome = NMF::<f64>::new(0).fit(&small_dataset(), &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_nmf_invalid_n_components_too_large() {
        // 10 components exceed min(4, 3).
        let outcome = NMF::<f64>::new(10).fit(&small_dataset(), &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_nmf_negative_input_rejected() {
        let data = array![[1.0, -2.0], [3.0, 4.0]];
        let outcome = NMF::<f64>::new(1).fit(&data, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_nmf_transform_shape_mismatch() {
        let model = fit_small(2);
        // Two columns instead of the three the model was fitted on.
        let wrong_width = array![[1.0, 2.0]];
        assert!(model.transform(&wrong_width).is_err());
    }

    #[test]
    fn test_nmf_transform_negative_rejected() {
        let model = fit_small(2);
        let with_negative = array![[1.0, -2.0, 3.0]];
        assert!(model.transform(&with_negative).is_err());
    }

    #[test]
    fn test_nmf_reproducibility() {
        let first = fit_small(2);
        let second = fit_small(2);
        let pairs = first.components().iter().zip(second.components().iter());
        for (a, b) in pairs {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_nmf_single_component() {
        let data = small_dataset();
        let model = seeded(1).fit(&data, &()).unwrap();
        assert_eq!(model.components().nrows(), 1);
        let embedded = model.transform(&data).unwrap();
        assert_eq!(embedded.ncols(), 1);
    }

    #[test]
    fn test_nmf_n_iter_positive() {
        let model = fit_small(2);
        assert!(model.n_iter() > 0);
    }

    #[test]
    fn test_nmf_getters() {
        let config = NMF::<f64>::new(3)
            .with_max_iter(100)
            .with_tol(1e-5)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(99);
        assert_eq!(config.n_components(), 3);
        assert_eq!(config.max_iter(), 100);
        assert_abs_diff_eq!(config.tol(), 1e-5);
        assert_eq!(config.solver(), NMFSolver::CoordinateDescent);
        assert_eq!(config.init(), NMFInit::Nndsvd);
        assert_eq!(config.random_state(), Some(99));
    }

    #[test]
    fn test_nmf_f32() {
        let data: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let model = NMF::<f32>::new(1)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        let embedded = model.transform(&data).unwrap();
        assert_eq!(embedded.ncols(), 1);
    }

    #[test]
    fn test_nmf_zero_entries() {
        let data = array![[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]];
        let model = seeded(2).fit(&data, &()).unwrap();
        assert_eq!(model.components().dim(), (2, 3));
    }

    #[test]
    fn test_nmf_pipeline_integration() {
        use ferrolearn_core::pipeline::{FittedPipelineEstimator, Pipeline, PipelineEstimator};
        use ferrolearn_core::traits::Predict;

        /// Stateless estimator whose prediction is the sum of each input row.
        struct SumEstimator;

        struct FittedSumEstimator;

        impl PipelineEstimator<f64> for SumEstimator {
            fn fit_pipeline(
                &self,
                _x: &Array2<f64>,
                _y: &Array1<f64>,
            ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
                Ok(Box::new(FittedSumEstimator))
            }
        }

        impl FittedPipelineEstimator<f64> for FittedSumEstimator {
            fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
                let row_sums: Vec<f64> = x.rows().into_iter().map(|row| row.sum()).collect();
                Ok(Array1::from_vec(row_sums))
            }
        }

        let pipeline = Pipeline::new()
            .transform_step("nmf", Box::new(seeded(2)))
            .estimator_step("sum", Box::new(SumEstimator));

        let data = small_dataset();
        let targets = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);

        let fitted_pipeline = pipeline.fit(&data, &targets).unwrap();
        let predictions = fitted_pipeline.predict(&data).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_nmf_medium_dataset_mu() {
        let data = medium_dataset();
        let model = NMF::<f64>::new(3)
            .with_solver(NMFSolver::MultiplicativeUpdate)
            .with_random_state(42)
            .with_max_iter(500)
            .fit(&data, &())
            .unwrap();
        assert_eq!(model.components().dim(), (3, 4));
        assert!(
            model.reconstruction_err() < 10.0,
            "reconstruction error too large: {}",
            model.reconstruction_err()
        );
    }

    #[test]
    fn test_nmf_insufficient_samples() {
        let empty = Array2::<f64>::zeros((0, 3));
        assert!(NMF::<f64>::new(1).fit(&empty, &()).is_err());
    }

    #[test]
    fn test_nmf_more_components_lower_error() {
        let data = medium_dataset();
        let rank_one = seeded(1).with_max_iter(300).fit(&data, &()).unwrap();
        let rank_two = seeded(2).with_max_iter(300).fit(&data, &()).unwrap();
        assert!(
            rank_two.reconstruction_err() <= rank_one.reconstruction_err() + 1e-6,
            "more components should reduce error: 1comp={}, 2comp={}",
            rank_one.reconstruction_err(),
            rank_two.reconstruction_err()
        );
    }
}