use ferrolearn_core::error::FerroError;
use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
use ferrolearn_core::traits::{Fit, Transform};
use ndarray::{Array1, Array2};
use num_traits::Float;
use rand::SeedableRng;
use rand_distr::{Distribution, Uniform};
/// Optimisation routine used to refine the `W` and `H` factors during fitting.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NMFSolver {
    /// Alternating multiplicative updates of `H` and then `W`.
    MultiplicativeUpdate,
    /// Component-wise coordinate descent with results clamped at zero.
    CoordinateDescent,
}
/// Strategy used to build the starting `W` and `H` factors.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NMFInit {
    /// Seeded uniform random entries.
    Random,
    /// Deterministic start derived from an eigendecomposition of `XᵀX`.
    Nndsvd,
}
/// Unfitted non-negative matrix factorisation configuration.
///
/// The type parameter `F` is the floating-point element type of the data the
/// estimator will be fitted on; no value of type `F` is stored.
#[derive(Clone, Debug)]
pub struct NMF<F> {
    /// Number of components (rows of `H`) to extract.
    n_components: usize,
    /// Upper bound on solver iterations.
    max_iter: usize,
    /// Solvers stop once the reconstruction error changes by less than this.
    tol: f64,
    /// Optimisation routine applied after initialisation.
    solver: NMFSolver,
    /// How the starting factors are produced.
    init: NMFInit,
    /// Seed for the initialisers; `None` makes `fit` fall back to seed `0`.
    random_state: Option<u64>,
    /// Ties the configuration to the element type without storing one.
    _marker: std::marker::PhantomData<F>,
}
impl<F: Float + Send + Sync + 'static> NMF<F> {
    /// Creates a configuration extracting `n_components` components.
    ///
    /// Defaults: 200 iterations, tolerance `1e-4`, multiplicative-update
    /// solver, random initialisation and no explicit seed.
    #[must_use]
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_iter: 200,
            tol: 1e-4,
            solver: NMFSolver::MultiplicativeUpdate,
            init: NMFInit::Random,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Returns the configuration with the iteration cap replaced.
    #[must_use]
    pub fn with_max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Returns the configuration with the convergence tolerance replaced.
    #[must_use]
    pub fn with_tol(self, tol: f64) -> Self {
        Self { tol, ..self }
    }

    /// Returns the configuration with the solver replaced.
    #[must_use]
    pub fn with_solver(self, solver: NMFSolver) -> Self {
        Self { solver, ..self }
    }

    /// Returns the configuration with the initialisation strategy replaced.
    #[must_use]
    pub fn with_init(self, init: NMFInit) -> Self {
        Self { init, ..self }
    }

    /// Returns the configuration with an explicit seed.
    #[must_use]
    pub fn with_random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Number of components to extract.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }

    /// Maximum number of solver iterations.
    #[must_use]
    pub fn max_iter(&self) -> usize {
        self.max_iter
    }

    /// Convergence tolerance on the change in reconstruction error.
    #[must_use]
    pub fn tol(&self) -> f64 {
        self.tol
    }

    /// Configured solver.
    #[must_use]
    pub fn solver(&self) -> NMFSolver {
        self.solver
    }

    /// Configured initialisation strategy.
    #[must_use]
    pub fn init(&self) -> NMFInit {
        self.init
    }

    /// Configured seed, if one was set.
    #[must_use]
    pub fn random_state(&self) -> Option<u64> {
        self.random_state
    }
}
/// Result of fitting an [`NMF`] configuration.
#[derive(Clone, Debug)]
pub struct FittedNMF<F> {
    /// Learned `H` factor, shaped `(n_components, n_features)`.
    components_: Array2<F>,
    /// Frobenius norm of `X - W·H` measured at the end of fitting.
    reconstruction_err_: F,
    /// Number of solver iterations that were executed.
    n_iter_: usize,
}
impl<F: Float + Send + Sync + 'static> FittedNMF<F> {
    /// Borrow the learned component matrix `H`.
    #[must_use]
    pub fn components(&self) -> &Array2<F> {
        let Self { components_, .. } = self;
        components_
    }

    /// Frobenius reconstruction error recorded when fitting finished.
    #[must_use]
    pub fn reconstruction_err(&self) -> F {
        let Self {
            reconstruction_err_,
            ..
        } = self;
        *reconstruction_err_
    }

    /// Number of solver iterations performed during fitting.
    #[must_use]
    pub fn n_iter(&self) -> usize {
        let Self { n_iter_, .. } = self;
        *n_iter_
    }
}
/// Frobenius norm of the residual `X - W·H`.
fn reconstruction_error<F: Float + 'static>(x: &Array2<F>, w: &Array2<F>, h: &Array2<F>) -> F {
    let product = w.dot(h);
    let sum_of_squares = x
        .iter()
        .zip(product.iter())
        .fold(F::zero(), |acc, (&observed, &estimate)| {
            let residual = observed - estimate;
            acc + residual * residual
        });
    sum_of_squares.sqrt()
}
/// Small positive guard value: `1e-12` converted to `F`, or `F::epsilon()`
/// when that conversion is not possible.
fn eps<F: Float>() -> F {
    match F::from(1e-12) {
        Some(tiny) => tiny,
        None => F::epsilon(),
    }
}
/// Builds starting factors filled with seeded uniform draws from `[0, 1)`,
/// each shifted up by [`eps`] so that no entry is exactly zero.
///
/// Returns `(W, H)` shaped `(n_samples, n_components)` and
/// `(n_components, n_features)`. `W` is filled completely before `H`, so the
/// random stream is consumed in that order.
fn init_random<F: Float>(
    n_samples: usize,
    n_features: usize,
    n_components: usize,
    seed: u64,
) -> (Array2<F>, Array2<F>) {
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
    // One draw, converted to `F` (zero if the conversion fails) and nudged off zero.
    let mut draw = || F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) + eps::<F>();

    let mut w = Array2::<F>::zeros((n_samples, n_components));
    w.iter_mut().for_each(|slot| *slot = draw());

    let mut h = Array2::<F>::zeros((n_components, n_features));
    h.iter_mut().for_each(|slot| *slot = draw());

    (w, h)
}
/// Deterministic initialisation from the leading eigenvectors of `XᵀX`.
///
/// Row `k` of `H` is the non-negative part of the eigenvector with the `k`-th
/// largest eigenvalue. Eigenvectors are only defined up to sign, so each one
/// is first oriented so that its dominant-sign part is the one kept.
/// Without this step an eigenvector that happens to come back with mostly
/// negative entries would be clamped to (almost) all zeros and replaced by
/// random noise.
///
/// `W` is then `X·Hᵀ`, with non-positive entries replaced by [`eps`].
///
/// # Errors
///
/// Propagates the error from [`jacobi_eigen_symmetric`] when the
/// eigendecomposition does not converge.
fn init_nndsvd<F: Float + Send + Sync + 'static>(
    x: &Array2<F>,
    n_components: usize,
    seed: u64,
) -> Result<(Array2<F>, Array2<F>), FerroError> {
    let (n_samples, n_features) = x.dim();
    let zero = F::zero();

    // Scale for the random fallback rows: sqrt(|mean(X)|), or 1 if that is ~0.
    let total = x.iter().fold(zero, |acc, &v| acc + v);
    let avg = (total / F::from(n_samples * n_features).unwrap())
        .abs()
        .sqrt();
    let avg = if avg < eps::<F>() { F::one() } else { avg };

    let xtx = x.t().dot(x);
    let max_iter = n_features * n_features * 100 + 1000;
    let (eigenvalues, eigenvectors) = jacobi_eigen_symmetric(&xtx, max_iter)?;

    // Eigenvector indices ordered by decreasing eigenvalue.
    let mut indices: Vec<usize> = (0..n_features).collect();
    indices.sort_by(|&a, &b| {
        eigenvalues[b]
            .partial_cmp(&eigenvalues[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut h = Array2::<F>::zeros((n_components, n_features));
    for (k, &idx) in indices.iter().take(n_components).enumerate() {
        // Squared norms of the positive and negative parts of the eigenvector.
        let mut pos_sq = zero;
        let mut neg_sq = zero;
        for j in 0..n_features {
            let v = eigenvectors[[j, idx]];
            if v > zero {
                pos_sq = pos_sq + v * v;
            } else {
                neg_sq = neg_sq + v * v;
            }
        }
        // Flip the eigenvector when its negative part carries more weight.
        let sign = if neg_sq > pos_sq { -F::one() } else { F::one() };

        for j in 0..n_features {
            let oriented = sign * eigenvectors[[j, idx]];
            h[[k, j]] = if oriented > zero { oriented } else { zero };
        }

        // Degenerate row: fall back to seeded uniform noise scaled by `avg`.
        let row_sum: F = h.row(k).iter().copied().fold(zero, |a, b| a + b);
        if row_sum < eps::<F>() {
            let mut rng: rand::rngs::StdRng =
                SeedableRng::seed_from_u64(seed.wrapping_add(k as u64));
            let uniform = Uniform::new(0.0f64, 1.0f64).unwrap();
            for j in 0..n_features {
                h[[k, j]] = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero) * avg;
            }
        }
    }

    // W = X·Hᵀ, kept strictly positive so multiplicative updates can move it.
    let mut w = Array2::<F>::zeros((n_samples, n_components));
    let w_init = x.dot(&h.t());
    for i in 0..n_samples {
        for k in 0..n_components {
            let val = w_init[[i, k]];
            w[[i, k]] = if val > zero { val } else { eps::<F>() };
        }
    }
    Ok((w, h))
}
/// Eigendecomposition of a symmetric matrix by Jacobi rotations, pivoting on
/// the largest off-diagonal entry at every step.
///
/// Returns `(eigenvalues, eigenvectors)` where column `i` of the eigenvector
/// matrix belongs to `eigenvalues[i]`. The eigenvalues are returned in the
/// order they end up on the diagonal, not sorted.
///
/// Only the entries the algorithm reads are relevant; the input is assumed to
/// be symmetric and this is not checked.
///
/// # Errors
///
/// Returns [`FerroError::ConvergenceFailure`] when the largest off-diagonal
/// magnitude is still at least `1e-12` after `max_iter` rotations.
fn jacobi_eigen_symmetric<F: Float + Send + Sync + 'static>(
    a: &Array2<F>,
    max_iter: usize,
) -> Result<(Array1<F>, Array2<F>), FerroError> {
    let n = a.nrows();
    if n == 0 {
        return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
    }
    if n == 1 {
        let eigenvalues = Array1::from_vec(vec![a[[0, 0]]]);
        let eigenvectors = Array2::from_elem((1, 1), F::one());
        return Ok((eigenvalues, eigenvectors));
    }

    let mut mat = a.to_owned();
    // Accumulated rotations; starts as the identity.
    let mut v = Array2::<F>::eye(n);
    let tol = F::from(1e-12).unwrap_or_else(F::epsilon);
    let two = F::from(2.0).unwrap();

    for _iteration in 0..max_iter {
        // Locate the largest off-diagonal magnitude in the upper triangle.
        let mut max_off = F::zero();
        let mut p = 0;
        let mut q = 1;
        for i in 0..n {
            for j in (i + 1)..n {
                let val = mat[[i, j]].abs();
                if val > max_off {
                    max_off = val;
                    p = i;
                    q = j;
                }
            }
        }
        if max_off < tol {
            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
            return Ok((eigenvalues, v));
        }

        // Capture the pivot block before anything is overwritten.
        let app = mat[[p, p]];
        let aqq = mat[[q, q]];
        let apq = mat[[p, q]];

        // Rotation angle that zeroes the (p, q) entry.
        let theta = if (app - aqq).abs() < tol {
            F::from(std::f64::consts::FRAC_PI_4).unwrap_or_else(F::one)
        } else {
            let tau = (aqq - app) / (two * apq);
            let t = if tau >= F::zero() {
                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            } else {
                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            };
            t.atan()
        };
        let c = theta.cos();
        let s = theta.sin();

        // Rotate rows/columns p and q in place. Row `i` only reads (i, p) and
        // (i, q), which no earlier row has written, so no scratch copy of the
        // matrix is needed and the result matches a copy-based update exactly.
        for i in 0..n {
            if i == p || i == q {
                continue;
            }
            let mip = mat[[i, p]];
            let miq = mat[[i, q]];
            let rotated_p = c * mip - s * miq;
            let rotated_q = s * mip + c * miq;
            mat[[i, p]] = rotated_p;
            mat[[p, i]] = rotated_p;
            mat[[i, q]] = rotated_q;
            mat[[q, i]] = rotated_q;
        }
        mat[[p, p]] = c * c * app - two * s * c * apq + s * s * aqq;
        mat[[q, q]] = s * s * app + two * s * c * apq + c * c * aqq;
        mat[[p, q]] = F::zero();
        mat[[q, p]] = F::zero();

        // Apply the same rotation to the eigenvector accumulator.
        for i in 0..n {
            let vip = v[[i, p]];
            let viq = v[[i, q]];
            v[[i, p]] = c * vip - s * viq;
            v[[i, q]] = s * vip + c * viq;
        }
    }

    Err(FerroError::ConvergenceFailure {
        iterations: max_iter,
        message: "Jacobi eigendecomposition did not converge in NMF NNDSVD init".into(),
    })
}
/// Refines `w` and `h` in place with alternating multiplicative updates.
///
/// Each sweep rescales `H` by `(WᵀX) / (WᵀWH + eps)` and then `W` by
/// `(XHᵀ) / (WHHᵀ + eps)`, element-wise. Iteration stops as soon as the
/// reconstruction error changes by less than `tol` between sweeps.
///
/// Returns the number of sweeps performed (`max_iter` if the tolerance was
/// never reached).
fn solve_multiplicative_update<F: Float + 'static>(
    x: &Array2<F>,
    w: &mut Array2<F>,
    h: &mut Array2<F>,
    max_iter: usize,
    tol: f64,
) -> usize {
    let threshold = F::from(tol).unwrap_or_else(F::epsilon);
    let guard = eps::<F>();
    let mut last_err = reconstruction_error(x, w, h);

    for sweep in 1..=max_iter {
        // Update H while W is held fixed.
        {
            let w_t = w.t();
            let top = w_t.dot(x);
            let bottom = w_t.dot(&*w).dot(&*h);
            h.iter_mut()
                .zip(top.iter())
                .zip(bottom.iter())
                .for_each(|((entry, &num), &den)| {
                    *entry = *entry * (num / (den + guard));
                });
        }

        // Update W using the freshly updated H.
        {
            let h_t = h.t();
            let top = x.dot(&h_t);
            let bottom = w.dot(&*h).dot(&h_t);
            w.iter_mut()
                .zip(top.iter())
                .zip(bottom.iter())
                .for_each(|((entry, &num), &den)| {
                    *entry = *entry * (num / (den + guard));
                });
        }

        let current_err = reconstruction_error(x, w, h);
        if (last_err - current_err).abs() < threshold {
            return sweep;
        }
        last_err = current_err;
    }

    max_iter
}
/// Refines `w` and `h` in place by component-wise coordinate descent.
///
/// Each sweep first updates every row of `H` and then every column of `W`.
/// An entry is set to its least-squares value against the residual that
/// excludes the current component, clamped at zero. Components whose
/// squared norm is below [`eps`] are skipped. Iteration stops once the
/// reconstruction error changes by less than `tol` between sweeps.
///
/// Returns the number of sweeps performed (`max_iter` if the tolerance was
/// never reached).
fn solve_coordinate_descent<F: Float + 'static>(
    x: &Array2<F>,
    w: &mut Array2<F>,
    h: &mut Array2<F>,
    max_iter: usize,
    tol: f64,
) -> usize {
    /// `Σ_c w[i, c] · h[c, j]` over every component `c` except `skip`.
    fn product_without<T: Float>(
        w: &Array2<T>,
        h: &Array2<T>,
        i: usize,
        j: usize,
        skip: usize,
    ) -> T {
        (0..h.nrows())
            .filter(|&c| c != skip)
            .fold(T::zero(), |acc, c| acc + w[[i, c]] * h[[c, j]])
    }

    let (n_samples, n_features) = x.dim();
    let n_components = h.nrows();
    let threshold = F::from(tol).unwrap_or_else(F::epsilon);
    let guard = eps::<F>();
    let zero = F::zero();
    let mut last_err = reconstruction_error(x, w, h);

    for sweep in 1..=max_iter {
        // Pass 1: rows of H, one component at a time.
        for k in 0..n_components {
            let col_norm_sq = (0..n_samples).fold(zero, |acc, i| acc + w[[i, k]] * w[[i, k]]);
            if col_norm_sq < guard {
                continue;
            }
            for j in 0..n_features {
                let numerator = (0..n_samples).fold(zero, |acc, i| {
                    acc + w[[i, k]] * (x[[i, j]] - product_without(&*w, &*h, i, j, k))
                });
                h[[k, j]] = if numerator > zero {
                    numerator / col_norm_sq
                } else {
                    zero
                };
            }
        }

        // Pass 2: columns of W, one component at a time.
        for k in 0..n_components {
            let row_norm_sq = (0..n_features).fold(zero, |acc, j| acc + h[[k, j]] * h[[k, j]]);
            if row_norm_sq < guard {
                continue;
            }
            for i in 0..n_samples {
                let numerator = (0..n_features).fold(zero, |acc, j| {
                    acc + h[[k, j]] * (x[[i, j]] - product_without(&*w, &*h, i, j, k))
                });
                w[[i, k]] = if numerator > zero {
                    numerator / row_norm_sq
                } else {
                    zero
                };
            }
        }

        let current_err = reconstruction_error(x, w, h);
        if (last_err - current_err).abs() < threshold {
            return sweep;
        }
        last_err = current_err;
    }

    max_iter
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for NMF<F> {
    type Fitted = FittedNMF<F>;
    type Error = FerroError;

    /// Factorises `x ≈ W·H` and returns the fitted model holding `H`.
    ///
    /// When no seed was configured, seed `0` is used, so repeated fits are
    /// deterministic.
    ///
    /// # Errors
    ///
    /// * [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds
    ///   `min(n_samples, n_features)`, or if `x` contains a NaN, infinite or
    ///   negative entry.
    /// * [`FerroError::InsufficientSamples`] if `x` has no rows.
    /// * Any error raised by the NNDSVD initialisation.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedNMF<F>, FerroError> {
        let (n_samples, n_features) = x.dim();
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "NMF::fit".into(),
            });
        }
        if self.n_components > n_samples.min(n_features) {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds min(n_samples, n_features) = {}",
                    self.n_components,
                    n_samples.min(n_features)
                ),
            });
        }
        for &val in x {
            // NaN and +inf are not caught by `val < 0`, and a single such
            // entry would silently turn the whole factorisation into NaN.
            if !val.is_finite() {
                return Err(FerroError::InvalidParameter {
                    name: "X".into(),
                    reason: "NMF requires all entries in X to be finite".into(),
                });
            }
            if val < F::zero() {
                return Err(FerroError::InvalidParameter {
                    name: "X".into(),
                    reason: "NMF requires all entries in X to be non-negative".into(),
                });
            }
        }

        let seed = self.random_state.unwrap_or(0);
        let (mut w, mut h) = match self.init {
            NMFInit::Random => init_random(n_samples, n_features, self.n_components, seed),
            NMFInit::Nndsvd => init_nndsvd(x, self.n_components, seed)?,
        };

        let n_iter = match self.solver {
            NMFSolver::MultiplicativeUpdate => {
                solve_multiplicative_update(x, &mut w, &mut h, self.max_iter, self.tol)
            }
            NMFSolver::CoordinateDescent => {
                solve_coordinate_descent(x, &mut w, &mut h, self.max_iter, self.tol)
            }
        };

        let reconstruction_err = reconstruction_error(x, &w, &h);
        Ok(FittedNMF {
            components_: h,
            reconstruction_err_: reconstruction_err,
            n_iter_: n_iter,
        })
    }
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedNMF<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Computes activations `W` for `x` against the fixed learned components.
    ///
    /// Starts from `W = 0.1` everywhere and applies a fixed number of
    /// multiplicative updates with `H` held constant. The result has shape
    /// `(x.nrows(), n_components)`.
    ///
    /// # Errors
    ///
    /// * [`FerroError::ShapeMismatch`] if `x` does not have as many columns
    ///   as the fitted components.
    /// * [`FerroError::InvalidParameter`] if `x` contains a NaN, infinite or
    ///   negative entry.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        /// Number of multiplicative updates applied to `W`.
        const UPDATE_STEPS: usize = 200;

        let n_features = self.components_.ncols();
        if x.ncols() != n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows(), n_features],
                actual: vec![x.nrows(), x.ncols()],
                context: "FittedNMF::transform".into(),
            });
        }
        for &val in x {
            // NaN and +inf slip past `val < 0` and would poison every activation.
            if !val.is_finite() {
                return Err(FerroError::InvalidParameter {
                    name: "X".into(),
                    reason: "NMF requires all entries in X to be finite".into(),
                });
            }
            if val < F::zero() {
                return Err(FerroError::InvalidParameter {
                    name: "X".into(),
                    reason: "NMF requires all entries in X to be non-negative".into(),
                });
            }
        }

        let n_samples = x.nrows();
        let n_components = self.components_.nrows();
        let epsilon = eps::<F>();
        let h = &self.components_;
        let h_t = h.t();

        // X·Hᵀ depends only on the fixed inputs, so compute it once instead
        // of on every update step.
        let numerator = x.dot(&h_t);

        let init_val = F::from(0.1).unwrap_or_else(F::one);
        let mut w = Array2::<F>::from_elem((n_samples, n_components), init_val);
        for _ in 0..UPDATE_STEPS {
            let denominator = w.dot(h).dot(&h_t);
            for (w_val, (num, den)) in w
                .iter_mut()
                .zip(numerator.iter().zip(denominator.iter()))
            {
                *w_val = *w_val * (*num / (*den + epsilon));
            }
        }
        Ok(w)
    }
}
impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for NMF<F> {
fn fit_pipeline(
&self,
x: &Array2<F>,
_y: &Array1<F>,
) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
let fitted = self.fit(x, &())?;
Ok(Box::new(fitted))
}
}
impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedNMF<F> {
fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
self.transform(x)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Dense, strictly positive 4×3 matrix.
    fn small_dataset() -> Array2<f64> {
        array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ]
    }

    /// Non-negative 6×4 matrix containing zeros.
    fn medium_dataset() -> Array2<f64> {
        array![
            [5.0, 3.0, 0.0, 1.0],
            [4.0, 0.0, 0.0, 1.0],
            [1.0, 1.0, 0.0, 5.0],
            [1.0, 0.0, 0.0, 4.0],
            [0.0, 1.0, 5.0, 4.0],
            [0.0, 0.0, 4.0, 3.0],
        ]
    }

    /// Fits `data` with default solver/init settings and seed 42.
    fn fit_seeded(n_components: usize, data: &Array2<f64>) -> FittedNMF<f64> {
        NMF::<f64>::new(n_components)
            .with_random_state(42)
            .fit(data, &())
            .unwrap()
    }

    #[test]
    fn test_nmf_basic_fit() {
        let model = fit_seeded(2, &small_dataset());
        assert_eq!(model.components().dim(), (2, 3));
    }

    #[test]
    fn test_nmf_components_non_negative() {
        let model = fit_seeded(2, &small_dataset());
        for &val in model.components() {
            assert!(
                val >= 0.0,
                "component value should be non-negative, got {val}"
            );
        }
    }

    #[test]
    fn test_nmf_transform_dimensions() {
        let data = small_dataset();
        let activations = fit_seeded(2, &data).transform(&data).unwrap();
        assert_eq!(activations.dim(), (4, 2));
    }

    #[test]
    fn test_nmf_transform_non_negative() {
        let data = small_dataset();
        let activations = fit_seeded(2, &data).transform(&data).unwrap();
        for &val in &activations {
            assert!(val >= 0.0, "W value should be non-negative, got {val}");
        }
    }

    #[test]
    fn test_nmf_reconstruction_error_decreases() {
        let data = small_dataset();
        let fit_with_budget = |iterations: usize| {
            NMF::<f64>::new(2)
                .with_random_state(42)
                .with_max_iter(iterations)
                .fit(&data, &())
                .unwrap()
        };
        let fitted_few = fit_with_budget(10);
        let fitted_many = fit_with_budget(200);
        assert!(
            fitted_many.reconstruction_err() <= fitted_few.reconstruction_err() + 1e-6,
            "more iterations should reduce error: few={}, many={}",
            fitted_few.reconstruction_err(),
            fitted_many.reconstruction_err()
        );
    }

    #[test]
    fn test_nmf_reconstruction_error_positive() {
        let model = fit_seeded(2, &small_dataset());
        assert!(model.reconstruction_err() >= 0.0);
    }

    #[test]
    fn test_nmf_coordinate_descent_solver() {
        let estimator = NMF::<f64>::new(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_random_state(42);
        let model = estimator.fit(&medium_dataset(), &()).unwrap();
        assert_eq!(model.components().dim(), (2, 4));
        for &val in model.components() {
            assert!(val >= 0.0, "CD component should be non-negative, got {val}");
        }
    }

    #[test]
    fn test_nmf_nndsvd_init() {
        let estimator = NMF::<f64>::new(2)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(42);
        let model = estimator.fit(&medium_dataset(), &()).unwrap();
        assert_eq!(model.components().dim(), (2, 4));
        for &val in model.components() {
            assert!(
                val >= 0.0,
                "NNDSVD component should be non-negative, got {val}"
            );
        }
    }

    #[test]
    fn test_nmf_cd_with_nndsvd() {
        let estimator = NMF::<f64>::new(2)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(42);
        let model = estimator.fit(&medium_dataset(), &()).unwrap();
        assert_eq!(model.components().dim(), (2, 4));
    }

    #[test]
    fn test_nmf_invalid_n_components_zero() {
        let outcome = NMF::<f64>::new(0).fit(&small_dataset(), &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_nmf_invalid_n_components_too_large() {
        let outcome = NMF::<f64>::new(10).fit(&small_dataset(), &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_nmf_negative_input_rejected() {
        let with_negative = array![[1.0, -2.0], [3.0, 4.0]];
        let outcome = NMF::<f64>::new(1).fit(&with_negative, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_nmf_transform_shape_mismatch() {
        let model = fit_seeded(2, &small_dataset());
        let wrong_width = array![[1.0, 2.0]];
        assert!(model.transform(&wrong_width).is_err());
    }

    #[test]
    fn test_nmf_transform_negative_rejected() {
        let model = fit_seeded(2, &small_dataset());
        let with_negative = array![[1.0, -2.0, 3.0]];
        assert!(model.transform(&with_negative).is_err());
    }

    #[test]
    fn test_nmf_reproducibility() {
        let data = small_dataset();
        let first = fit_seeded(2, &data);
        let second = fit_seeded(2, &data);
        first
            .components()
            .iter()
            .zip(second.components().iter())
            .for_each(|(a, b)| {
                assert_abs_diff_eq!(a, b, epsilon = 1e-10);
            });
    }

    #[test]
    fn test_nmf_single_component() {
        let data = small_dataset();
        let model = fit_seeded(1, &data);
        assert_eq!(model.components().nrows(), 1);
        let activations = model.transform(&data).unwrap();
        assert_eq!(activations.ncols(), 1);
    }

    #[test]
    fn test_nmf_n_iter_positive() {
        let model = fit_seeded(2, &small_dataset());
        assert!(model.n_iter() > 0);
    }

    #[test]
    fn test_nmf_getters() {
        let estimator = NMF::<f64>::new(3)
            .with_max_iter(100)
            .with_tol(1e-5)
            .with_solver(NMFSolver::CoordinateDescent)
            .with_init(NMFInit::Nndsvd)
            .with_random_state(99);
        assert_eq!(estimator.n_components(), 3);
        assert_eq!(estimator.max_iter(), 100);
        assert_abs_diff_eq!(estimator.tol(), 1e-5);
        assert_eq!(estimator.solver(), NMFSolver::CoordinateDescent);
        assert_eq!(estimator.init(), NMFInit::Nndsvd);
        assert_eq!(estimator.random_state(), Some(99));
    }

    #[test]
    fn test_nmf_f32() {
        let data: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let model = NMF::<f32>::new(1)
            .with_random_state(42)
            .fit(&data, &())
            .unwrap();
        let activations = model.transform(&data).unwrap();
        assert_eq!(activations.ncols(), 1);
    }

    #[test]
    fn test_nmf_zero_entries() {
        let permutation = array![[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]];
        let model = fit_seeded(2, &permutation);
        assert_eq!(model.components().dim(), (2, 3));
    }

    #[test]
    fn test_nmf_pipeline_integration() {
        use ferrolearn_core::pipeline::{FittedPipelineEstimator, Pipeline, PipelineEstimator};
        use ferrolearn_core::traits::Predict;

        /// Stub estimator whose fitted form predicts the sum of each row.
        struct SumEstimator;

        impl PipelineEstimator<f64> for SumEstimator {
            fn fit_pipeline(
                &self,
                _x: &Array2<f64>,
                _y: &Array1<f64>,
            ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
                Ok(Box::new(FittedSumEstimator))
            }
        }

        struct FittedSumEstimator;

        impl FittedPipelineEstimator<f64> for FittedSumEstimator {
            fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
                let row_sums: Vec<f64> = x.rows().into_iter().map(|row| row.sum()).collect();
                Ok(Array1::from_vec(row_sums))
            }
        }

        let pipeline = Pipeline::new()
            .transform_step("nmf", Box::new(NMF::<f64>::new(2).with_random_state(42)))
            .estimator_step("sum", Box::new(SumEstimator));
        let data = small_dataset();
        let targets = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let fitted_pipeline = pipeline.fit(&data, &targets).unwrap();
        let predictions = fitted_pipeline.predict(&data).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_nmf_medium_dataset_mu() {
        let estimator = NMF::<f64>::new(3)
            .with_solver(NMFSolver::MultiplicativeUpdate)
            .with_random_state(42)
            .with_max_iter(500);
        let model = estimator.fit(&medium_dataset(), &()).unwrap();
        assert_eq!(model.components().dim(), (3, 4));
        assert!(
            model.reconstruction_err() < 10.0,
            "reconstruction error too large: {}",
            model.reconstruction_err()
        );
    }

    #[test]
    fn test_nmf_insufficient_samples() {
        let no_rows = Array2::<f64>::zeros((0, 3));
        assert!(NMF::<f64>::new(1).fit(&no_rows, &()).is_err());
    }

    #[test]
    fn test_nmf_more_components_lower_error() {
        let data = medium_dataset();
        let fit_rank = |rank: usize| {
            NMF::<f64>::new(rank)
                .with_random_state(42)
                .with_max_iter(300)
                .fit(&data, &())
                .unwrap()
        };
        let fitted1 = fit_rank(1);
        let fitted2 = fit_rank(2);
        assert!(
            fitted2.reconstruction_err() <= fitted1.reconstruction_err() + 1e-6,
            "more components should reduce error: 1comp={}, 2comp={}",
            fitted1.reconstruction_err(),
            fitted2.reconstruction_err()
        );
    }
}