use crate::error::{Result, TransformError};
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::random::{Rng, RngExt};
/// Positive floor used throughout this module: factors are clamped to at least this
/// value and it is added to denominators to avoid division by zero.
const EPS: f64 = 1.0e-10;
/// Raises every entry of `a` to at least [`EPS`], in place.
///
/// Multiplicative updates can never move an entry away from exactly zero, so the
/// factors are kept strictly positive. `f64::max` returns the non-NaN operand, so
/// NaN entries are also replaced by `EPS`.
fn clip_nonneg(a: &mut Array2<f64>) {
    a.map_inplace(|entry| *entry = (*entry).max(EPS));
}
/// Squared Frobenius distance `Σ (a_ij − b_ij)²`.
///
/// Elements are paired in logical iteration order. If the shapes differ, pairing
/// stops at the shorter array, so callers are expected to pass equally shaped inputs.
fn frob2(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
    let squared_diffs = a.iter().zip(b).map(|(lhs, rhs)| {
        let diff = lhs - rhs;
        diff.powi(2)
    });
    squared_diffs.sum()
}
fn mm(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
let (m, k) = (a.nrows(), a.ncols());
let n = b.ncols();
assert_eq!(b.nrows(), k, "mm: inner dimension mismatch");
let mut c = Array2::<f64>::zeros((m, n));
for i in 0..m {
for l in 0..k {
for j in 0..n {
c[[i, j]] += a[[i, l]] * b[[l, j]];
}
}
}
c
}
fn mm_at_b(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
let k = a.ncols();
let m = a.nrows();
let n = b.ncols();
assert_eq!(b.nrows(), m);
let mut c = Array2::<f64>::zeros((k, n));
for i in 0..k {
for l in 0..m {
for j in 0..n {
c[[i, j]] += a[[l, i]] * b[[l, j]];
}
}
}
c
}
fn mm_a_bt(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
let (m, k) = (a.nrows(), a.ncols());
let n = b.nrows();
assert_eq!(b.ncols(), k);
let mut c = Array2::<f64>::zeros((m, n));
for i in 0..m {
for l in 0..k {
for j in 0..n {
c[[i, j]] += a[[i, l]] * b[[j, l]];
}
}
}
c
}
/// Builds an `(nrows, ncols)` matrix whose entries are `scale · u`, with `u` drawn
/// from `gen_range(0.0..1.0)` on a generator obtained from `scirs2_core::random::rng()`.
///
/// Entries are filled in row-major order. They are non-negative whenever `scale ≥ 0`.
fn rand_nonneg(nrows: usize, ncols: usize, scale: f64) -> Array2<f64> {
    let mut source = scirs2_core::random::rng();
    Array2::from_shape_fn((nrows, ncols), |_| scale * source.gen_range(0.0..1.0_f64))
}
/// Builds an `(nrows, ncols)` matrix whose entries are `scale · u`, with `u` drawn
/// from `gen_range(-1.0..1.0)` on a generator obtained from `scirs2_core::random::rng()`.
///
/// Entries are filled in row-major order. This is used to initialise factors that
/// are allowed to be negative.
fn rand_signed(nrows: usize, ncols: usize, scale: f64) -> Array2<f64> {
    let mut source = scirs2_core::random::rng();
    Array2::from_shape_fn((nrows, ncols), |_| scale * source.gen_range(-1.0..1.0_f64))
}
/// Objective minimised by the multiplicative updates of [`NMF`].
///
/// This is a fieldless enum, so it is cheap to copy and can be compared, hashed
/// and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NmfDivergence {
    /// Squared Euclidean (Frobenius) distance `‖X − WH‖²_F`.
    Frobenius,
    /// Generalised Kullback–Leibler divergence `D(X ‖ WH)`.
    KullbackLeibler,
}
/// Non-negative matrix factorisation `X ≈ W H`, fitted with multiplicative updates.
///
/// For an input `X` of shape `(n, p)`, the factors are `W: (n, k)` and `H: (k, p)`,
/// with `k = n_components`. `X`, `W` and `H` are all non-negative.
#[derive(Debug, Clone)]
pub struct NMF {
    /// Number of latent components `k`.
    pub n_components: usize,
    /// Objective minimised by [`NMF::fit_transform`].
    pub divergence: NmfDivergence,
    /// Maximum number of update sweeps in [`NMF::fit_transform`].
    pub max_iter: usize,
    /// Fitting stops once the relative change of the reconstruction error falls below this.
    pub tol: f64,
    /// Fitted `W` of shape `(n, k)`; `None` until a fit succeeds.
    pub w: Option<Array2<f64>>,
    /// Fitted `H` of shape `(k, p)`; `None` until a fit succeeds.
    pub h: Option<Array2<f64>>,
    /// Frobenius reconstruction error `‖X − WH‖_F` after each sweep of the last fit.
    /// It is recorded for both divergences.
    pub reconstruction_errors: Vec<f64>,
}

impl NMF {
    /// Number of multiplicative updates applied by [`NMF::transform`].
    const TRANSFORM_ITERS: usize = 200;

    /// Creates an unfitted model.
    pub fn new(n_components: usize, divergence: NmfDivergence, max_iter: usize, tol: f64) -> Self {
        Self {
            n_components,
            divergence,
            max_iter,
            tol,
            w: None,
            h: None,
            reconstruction_errors: Vec::new(),
        }
    }

    /// Fits `X ≈ W H` and returns `(W, H)`.
    ///
    /// `W` and `H` start from uniform random values scaled by `sqrt(max(X) / k)`.
    /// They are then refined for at most `max_iter` sweeps of the Lee–Seung updates
    /// for `self.divergence`. Fitting stops early when the relative change of the
    /// Frobenius error drops below `tol`. The factors are stored in `self.w` /
    /// `self.h`, and the per-sweep errors in `self.reconstruction_errors`.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `n_components` is not in `1..=min(n, p)` or if `X`
    /// has a negative entry.
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<(Array2<f64>, Array2<f64>)>
    where
        S: Data<Elem = f64>,
    {
        let x = x.to_owned();
        let (n, p) = x.dim();
        let k = self.n_components;
        if k == 0 || k > n.min(p) {
            return Err(TransformError::InvalidInput(format!(
                "n_components must be in 1..=min({n},{p}), got {k}"
            )));
        }
        if x.iter().any(|&v| v < 0.0) {
            return Err(TransformError::InvalidInput(
                "NMF requires non-negative input matrix".into(),
            ));
        }
        let scale = (x.iter().cloned().fold(0.0_f64, f64::max) / k as f64)
            .sqrt()
            .max(EPS);
        let mut w = rand_nonneg(n, k, scale);
        let mut h = rand_nonneg(k, p, scale);
        // Multiplicative updates cannot move an exact zero, so start strictly positive.
        clip_nonneg(&mut w);
        clip_nonneg(&mut h);
        self.reconstruction_errors.clear();
        let mut prev_err = f64::INFINITY;
        for _ in 0..self.max_iter {
            match self.divergence {
                NmfDivergence::Frobenius => {
                    let wt_x = mm_at_b(&w, &x);
                    let wt_w = mm_at_b(&w, &w);
                    Self::update_h_frobenius(&wt_x, &wt_w, &mut h);
                    Self::update_w_frobenius(&x, &mut w, &h);
                }
                NmfDivergence::KullbackLeibler => Self::update_kl(&x, &mut w, &mut h),
            }
            let err = frob2(&x, &mm(&w, &h)).sqrt();
            self.reconstruction_errors.push(err);
            // On the first sweep `prev_err` is infinite, so `delta` is NaN and the
            // comparison below is false: the first sweep never stops the loop.
            let delta = (prev_err - err).abs() / (prev_err + EPS);
            if delta < self.tol {
                break;
            }
            prev_err = err;
        }
        self.w = Some(w.clone());
        self.h = Some(h.clone());
        Ok((w, h))
    }

    /// Computes coefficients `H_new` with `X_new ≈ W H_new`, keeping the fitted `W` fixed.
    ///
    /// `X_new` must have as many rows as the training `X` (the rows of `W`). It may
    /// have any number of columns `q`, and the result has shape `(k, q)`. This is the
    /// counterpart of [`NMF::inverse_transform`], which maps such an `H` back to `W H`.
    /// The Frobenius multiplicative update is used regardless of `self.divergence`.
    ///
    /// # Errors
    /// - `NotFitted` before a successful fit.
    /// - `DimensionMismatch` if `X_new.nrows() != W.nrows()`.
    /// - `InvalidInput` if `X_new` has a negative entry.
    pub fn transform<S>(&self, x_new: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data<Elem = f64>,
    {
        let w = self
            .w
            .as_ref()
            .ok_or_else(|| TransformError::NotFitted("NMF not fitted".into()))?;
        let x = x_new.to_owned();
        // `mm_at_b(w, &x)` below needs `x` to share W's row dimension.
        if x.nrows() != w.nrows() {
            return Err(TransformError::DimensionMismatch(
                "Feature dimension mismatch in transform".into(),
            ));
        }
        if x.iter().any(|&v| v < 0.0) {
            return Err(TransformError::InvalidInput(
                "NMF requires non-negative input matrix".into(),
            ));
        }
        // Use the fitted shape rather than `n_components`, which may have been mutated.
        let k = w.ncols();
        let p = x.ncols();
        let scale = (x.iter().cloned().fold(0.0_f64, f64::max) / k as f64)
            .sqrt()
            .max(EPS);
        let mut h = rand_nonneg(k, p, scale);
        clip_nonneg(&mut h);
        // Both products are invariant while W is fixed, so compute them once.
        let wt_x = mm_at_b(w, &x);
        let wt_w = mm_at_b(w, w);
        for _ in 0..Self::TRANSFORM_ITERS {
            Self::update_h_frobenius(&wt_x, &wt_w, &mut h);
        }
        Ok(h)
    }

    /// Reconstructs `W H` from coefficients `h` of shape `(k, q)`.
    ///
    /// # Errors
    /// - `NotFitted` before a successful fit.
    /// - `DimensionMismatch` if `h.nrows()` differs from the number of fitted components.
    pub fn inverse_transform(&self, h: &Array2<f64>) -> Result<Array2<f64>> {
        let w = self
            .w
            .as_ref()
            .ok_or_else(|| TransformError::NotFitted("NMF not fitted".into()))?;
        if h.nrows() != w.ncols() {
            return Err(TransformError::DimensionMismatch(
                "Component dimension mismatch in inverse_transform".into(),
            ));
        }
        Ok(mm(w, h))
    }

    /// One Frobenius update `H ← H ∘ WᵀX / (WᵀW H + EPS)`, given precomputed `WᵀX` and `WᵀW`.
    fn update_h_frobenius(wt_x: &Array2<f64>, wt_w: &Array2<f64>, h: &mut Array2<f64>) {
        let wt_wh = mm(wt_w, h);
        for ((i, j), hv) in h.indexed_iter_mut() {
            *hv *= wt_x[[i, j]] / (wt_wh[[i, j]] + EPS);
        }
        clip_nonneg(h);
    }

    /// One Frobenius update `W ← W ∘ X Hᵀ / (W H Hᵀ + EPS)`.
    fn update_w_frobenius(x: &Array2<f64>, w: &mut Array2<f64>, h: &Array2<f64>) {
        let x_ht = mm_a_bt(x, h);
        let hht = mm_a_bt(h, h);
        let w_hht = mm(w, &hht);
        for ((i, j), wv) in w.indexed_iter_mut() {
            *wv *= x_ht[[i, j]] / (w_hht[[i, j]] + EPS);
        }
        clip_nonneg(w);
    }

    /// One sweep of the KL updates: first `H`, then `W` against the refreshed `H`.
    fn update_kl(x: &Array2<f64>, w: &mut Array2<f64>, h: &mut Array2<f64>) {
        // H_ij ← H_ij · Σ_l W_li X_lj/(WH)_lj / Σ_l W_li
        let ratio = Self::kl_ratio(x, w, h);
        let numerator_h = mm_at_b(w, &ratio);
        let sum_w = w.sum_axis(scirs2_core::ndarray::Axis(0));
        for ((i, j), hv) in h.indexed_iter_mut() {
            *hv *= numerator_h[[i, j]] / (sum_w[i] + EPS);
        }
        clip_nonneg(h);
        // W_ij ← W_ij · Σ_l X_il/(WH)_il H_jl / Σ_l H_jl
        let ratio = Self::kl_ratio(x, w, h);
        let numerator_w = mm_a_bt(&ratio, h);
        let sum_h = h.sum_axis(scirs2_core::ndarray::Axis(1));
        for ((i, j), wv) in w.indexed_iter_mut() {
            *wv *= numerator_w[[i, j]] / (sum_h[j] + EPS);
        }
        clip_nonneg(w);
    }

    /// Element-wise `X / (W H + EPS)`.
    fn kl_ratio(x: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
        let wh = mm(w, h);
        Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]] / (wh[[i, j]] + EPS))
    }
}
/// Semi-NMF: `X ≈ W H` with `H ≥ 0` and `W` unconstrained, so `X` may contain
/// negative values.
///
/// For `X` of shape `(n, p)`, the factors are `W: (n, k)` and `H: (k, p)`.
#[derive(Debug, Clone)]
pub struct SemiNMF {
    /// Number of latent components `k`.
    pub n_components: usize,
    /// Maximum number of alternating updates.
    pub max_iter: usize,
    /// Fitting stops once the relative change of the reconstruction error falls below this.
    pub tol: f64,
    /// Fitted signed basis `W` of shape `(n, k)`.
    pub w: Option<Array2<f64>>,
    /// Fitted non-negative coefficients `H` of shape `(k, p)`.
    pub h: Option<Array2<f64>>,
}

impl SemiNMF {
    /// Creates an unfitted model.
    pub fn new(n_components: usize, max_iter: usize, tol: f64) -> Self {
        Self {
            n_components,
            max_iter,
            tol,
            w: None,
            h: None,
        }
    }

    /// Fits the factorisation and returns `(W, H)`.
    ///
    /// Each iteration alternates two updates (Ding, Li & Jordan 2010):
    /// * `H ← H ∘ sqrt(([WᵀX]⁺ + [WᵀW]⁻ H) / ([WᵀX]⁻ + [WᵀW]⁺ H))`, where `A⁺` and
    ///   `A⁻` are the element-wise positive and negative parts of `A`. Because `W` is
    ///   signed, `WᵀW` can have negative entries. Splitting it keeps numerator and
    ///   denominator non-negative, so the ratio can never flip the sign of `H`.
    /// * `W ← X Hᵀ (H Hᵀ + λI)⁻¹`, the ridge-regularised least-squares solution for fixed `H`.
    ///
    /// # Errors
    /// - `InvalidInput` if `n_components` is not in `1..=min(n, p)`.
    /// - `ComputationError` if `H Hᵀ` cannot be inverted.
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<(Array2<f64>, Array2<f64>)>
    where
        S: Data<Elem = f64>,
    {
        let x = x.to_owned();
        let (n, p) = (x.nrows(), x.ncols());
        let k = self.n_components;
        if k == 0 || k > n.min(p) {
            return Err(TransformError::InvalidInput(format!(
                "n_components must be in 1..=min({n},{p}), got {k}"
            )));
        }
        let scale = 1.0 / (k as f64).sqrt();
        let mut w = rand_signed(n, k, scale);
        let mut h = rand_nonneg(k, p, scale);
        // Multiplicative updates cannot move an exact zero, so start strictly positive.
        clip_nonneg(&mut h);
        let mut prev_err = f64::INFINITY;
        for _ in 0..self.max_iter {
            let wt_x = mm_at_b(&w, &x);
            let wt_w = mm_at_b(&w, &w);
            let wt_w_pos_h = mm(&wt_w.mapv(|v| v.max(0.0)), &h);
            let wt_w_neg_h = mm(&wt_w.mapv(|v| (-v).max(0.0)), &h);
            for ((i, j), hv) in h.indexed_iter_mut() {
                let a = wt_x[[i, j]];
                let num = a.max(0.0) + wt_w_neg_h[[i, j]];
                let den = (-a).max(0.0) + wt_w_pos_h[[i, j]];
                *hv *= ((num + EPS) / (den + EPS)).sqrt();
            }
            clip_nonneg(&mut h);
            let x_ht = mm_a_bt(&x, &h);
            let hht = mm_a_bt(&h, &h);
            let hht_inv = pseudo_inv_small(&hht)?;
            w = mm(&x_ht, &hht_inv);
            let err = frob2(&x, &mm(&w, &h)).sqrt();
            let delta = (prev_err - err).abs() / (prev_err + EPS);
            if delta < self.tol {
                break;
            }
            prev_err = err;
        }
        self.w = Some(w.clone());
        self.h = Some(h.clone());
        Ok((w, h))
    }
}
#[derive(Debug, Clone)]
pub struct ConvexNMF {
pub n_components: usize,
pub max_iter: usize,
pub tol: f64,
pub s: Option<Array2<f64>>,
pub h: Option<Array2<f64>>,
}
impl ConvexNMF {
pub fn new(n_components: usize, max_iter: usize, tol: f64) -> Self {
Self {
n_components,
max_iter,
tol,
s: None,
h: None,
}
}
pub fn fit_transform<S2>(&mut self, x: &ArrayBase<S2, Ix2>) -> Result<(Array2<f64>, Array2<f64>)>
where
S2: Data<Elem = f64>,
{
let x = x.to_owned();
let (n, p) = (x.nrows(), x.ncols());
let k = self.n_components;
if k == 0 || k > n {
return Err(TransformError::InvalidInput(format!(
"n_components must be in 1..={n}, got {k}"
)));
}
let scale = 1.0 / (k as f64).sqrt();
let mut s = rand_nonneg(n, k, scale); let mut h = rand_nonneg(k, n, scale);
let mut g = s; let mut h2 = rand_nonneg(k, p, scale);
let k_mat = mm_a_bt(&x, &x);
let mut prev_err = f64::INFINITY;
for _ in 0..self.max_iter {
let kht = mm_a_bt(&k_mat, &h2); let ghht = mm_a_bt(&mm(&k_mat, &g), &h2); let kg = mm(&k_mat, &g); let hht = mm_a_bt(&h2, &h2); let kg_hht = mm(&kg, &hht);
for i in 0..n {
for j in 0..k {
let pos = kht[[i, j]].max(0.0);
let neg = (-kht[[i, j]]).max(0.0);
g[[i, j]] *= (pos + EPS) / (neg + kg_hht[[i, j]] + EPS);
}
}
clip_nonneg(&mut g);
let gtk = mm_at_b(&g, &k_mat); let gtkg = mm(>k, &g); let gtkx = mm(>k, &x); let xg = mm(&x, &g); let xgt_x = mm_at_b(&xg, &x); let xgt_xg = mm_at_b(&xg, &xg); let xgt_xg_h = mm(&xgt_xg, &h2);
for i in 0..k {
for j in 0..p {
let pos = xgt_x[[i, j]].max(0.0);
let neg = (-xgt_x[[i, j]]).max(0.0);
h2[[i, j]] *= (pos + EPS) / (neg + xgt_xg_h[[i, j]] + EPS);
}
}
clip_nonneg(&mut h2);
let x_hat = mm(&mm(&x, &g), &h2);
let err = frob2(&x, &x_hat).sqrt();
let delta = (prev_err - err).abs() / (prev_err + EPS);
if delta < self.tol {
break;
}
prev_err = err;
}
let w = mm(&x, &g); self.s = Some(g);
self.h = Some(h2.clone());
Ok((w, h2))
}
}
/// Robust NMF with an L2,1-norm loss.
///
/// The model works on the transposed data. For `X` of shape `(n, p)` it factorises
/// `Xᵀ ≈ W H` with `W: (p, k)` and `H: (k, n)`. Instead of the squared Frobenius
/// error, it minimises `Σ_j ‖xⱼ − W hⱼ‖₂`, the sum of per-sample (per-column of
/// `Xᵀ`) residual norms.
#[derive(Debug, Clone)]
pub struct RobustNMF {
    /// Number of latent components `k`.
    pub n_components: usize,
    /// Maximum number of re-weighted update sweeps.
    pub max_iter: usize,
    /// Fitting stops once the relative change of the L2,1 objective falls below this.
    pub tol: f64,
    /// Fitted `W` of shape `(p, k)`.
    pub w: Option<Array2<f64>>,
    /// Fitted `H` of shape `(k, n)`.
    pub h: Option<Array2<f64>>,
}

impl RobustNMF {
    /// Creates an unfitted model.
    pub fn new(n_components: usize, max_iter: usize, tol: f64) -> Self {
        Self {
            n_components,
            max_iter,
            tol,
            w: None,
            h: None,
        }
    }

    /// Fits the model and returns `(W, H)` with `W: (p, k)` and `H: (k, n)`.
    ///
    /// Each sweep first weights the samples with `dⱼ = 1 / (2‖xⱼ − W hⱼ‖ + EPS)`,
    /// where `D = diag(d)`. It then applies the weighted multiplicative updates
    /// `H ← H ∘ (WᵀXᵀD) / (WᵀW H D)` and `W ← W ∘ (XᵀD Hᵀ) / (W H D Hᵀ)`.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `n_components` is not in `1..=min(n, p)`.
    pub fn fit_transform<S2>(&mut self, x: &ArrayBase<S2, Ix2>) -> Result<(Array2<f64>, Array2<f64>)>
    where
        S2: Data<Elem = f64>,
    {
        let x_in = x.to_owned();
        let (n, p) = (x_in.nrows(), x_in.ncols());
        let k = self.n_components;
        if k == 0 || k > p.min(n) {
            return Err(TransformError::InvalidInput(format!(
                "n_components must be in 1..=min({n},{p}), got {k}"
            )));
        }
        let scale = 1.0 / (k as f64).sqrt();
        let mut w = rand_nonneg(p, k, scale);
        let mut h = rand_nonneg(k, n, scale);
        let xt = x_in.t().to_owned();
        let mut prev_err = f64::INFINITY;
        for _ in 0..self.max_iter {
            // Per-sample weights from the current residuals.
            let d_diag: Vec<f64> = Self::column_residual_norms(&xt, &mm(&w, &h))
                .into_iter()
                .map(|r| 1.0 / (2.0 * r + EPS))
                .collect();

            // H-step. D multiplies column j on both sides, so it nearly cancels; it is
            // kept so that the update matches the weighted objective exactly.
            let wt_xt = mm_at_b(&w, &xt);
            let wt_wh = mm(&mm_at_b(&w, &w), &h);
            for ((i, j), hv) in h.indexed_iter_mut() {
                let num = wt_xt[[i, j]] * d_diag[j];
                let den = wt_wh[[i, j]] * d_diag[j] + EPS;
                *hv *= num.max(EPS) / den;
            }
            clip_nonneg(&mut h);

            // W-step: with HD = H·D, the numerator is XᵀD Hᵀ = Xᵀ(HD)ᵀ and the
            // denominator is W H D Hᵀ = (W·HD) Hᵀ.
            let hd = Array2::from_shape_fn((k, n), |(i, j)| h[[i, j]] * d_diag[j]);
            let xt_d_ht = mm_a_bt(&xt, &hd);
            let w_hd_ht = mm_a_bt(&mm(&w, &hd), &h);
            for ((i, j), wv) in w.indexed_iter_mut() {
                *wv *= (xt_d_ht[[i, j]] + EPS) / (w_hd_ht[[i, j]] + EPS);
            }
            clip_nonneg(&mut w);

            // L2,1 objective: sum of per-sample residual norms.
            let err: f64 = Self::column_residual_norms(&xt, &mm(&w, &h)).iter().sum();
            let delta = (prev_err - err).abs() / (prev_err + EPS);
            if delta < self.tol {
                break;
            }
            prev_err = err;
        }
        self.w = Some(w.clone());
        self.h = Some(h.clone());
        Ok((w, h))
    }

    /// Euclidean norm of each column of `target − approx` (both of equal shape).
    fn column_residual_norms(target: &Array2<f64>, approx: &Array2<f64>) -> Vec<f64> {
        (0..target.ncols())
            .map(|col| {
                (0..target.nrows())
                    .map(|row| {
                        let r = target[[row, col]] - approx[[row, col]];
                        r * r
                    })
                    .sum::<f64>()
                    .sqrt()
            })
            .collect()
    }
}
/// Two-layer deep NMF: `Xᵀ ≈ W1 W2 H`.
///
/// For `X` of shape `(n, p)`, the factors are `W1: (p, k1)`, `W2: (k1, k2)` and
/// `H: (k2, n)`, all non-negative, with `k1 = n_components_1` and `k2 = n_components_2`.
#[derive(Debug, Clone)]
pub struct DeepNMF {
    /// Width `k1` of the first layer.
    pub n_components_1: usize,
    /// Width `k2` of the second layer (`k2 ≤ k1`).
    pub n_components_2: usize,
    /// Maximum number of outer sweeps over the three factors.
    pub max_iter: usize,
    /// Multiplicative updates applied to each factor per outer sweep.
    pub inner_iter: usize,
    /// Fitting stops once the relative change of the reconstruction error falls below this.
    pub tol: f64,
    /// Fitted `W1` of shape `(p, k1)`.
    pub w1: Option<Array2<f64>>,
    /// Fitted `W2` of shape `(k1, k2)`.
    pub w2: Option<Array2<f64>>,
    /// Fitted `H` of shape `(k2, n)`.
    pub h: Option<Array2<f64>>,
}

impl DeepNMF {
    /// Creates an unfitted model.
    pub fn new(
        n_components_1: usize,
        n_components_2: usize,
        max_iter: usize,
        inner_iter: usize,
        tol: f64,
    ) -> Self {
        Self {
            n_components_1,
            n_components_2,
            max_iter,
            inner_iter,
            tol,
            w1: None,
            w2: None,
            h: None,
        }
    }

    /// Fits `Xᵀ ≈ W1 W2 H` and returns `(W1, W2, H)`.
    ///
    /// Each outer sweep updates `W1`, then `W2`, then `H`, each with `inner_iter`
    /// Lee–Seung multiplicative steps on `‖Xᵀ − W1 W2 H‖²_F` while the other two
    /// factors stay fixed. The `W2` step is `W2 ← W2 ∘ (W1ᵀXᵀHᵀ) / (W1ᵀW1 W2 HHᵀ)`;
    /// the `W1ᵀW1` factor is required for this step to descend the joint objective.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `k1` is not in `1..=min(n, p)` or `k2` is not in `1..=k1`.
    pub fn fit_transform<S2>(
        &mut self,
        x: &ArrayBase<S2, Ix2>,
    ) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>)>
    where
        S2: Data<Elem = f64>,
    {
        let x_in = x.to_owned();
        let (n, p) = (x_in.nrows(), x_in.ncols());
        let k1 = self.n_components_1;
        let k2 = self.n_components_2;
        if k1 == 0 || k1 > p.min(n) {
            return Err(TransformError::InvalidInput(format!(
                "n_components_1 must be in 1..=min({n},{p}), got {k1}"
            )));
        }
        if k2 == 0 || k2 > k1 {
            return Err(TransformError::InvalidInput(format!(
                "n_components_2 must be in 1..={k1}, got {k2}"
            )));
        }
        let xt = x_in.t().to_owned();
        let scale = ((xt.iter().cloned().fold(0.0_f64, f64::max)) / k1 as f64)
            .sqrt()
            .max(EPS);
        let mut w1 = rand_nonneg(p, k1, scale);
        let mut w2 = rand_nonneg(k1, k2, scale / (k1 as f64).sqrt());
        let mut h = rand_nonneg(k2, n, scale / (k1 as f64 * k2 as f64).sqrt());
        let mut prev_err = f64::INFINITY;
        for _ in 0..self.max_iter {
            // W1-step: Xᵀ ≈ W1 Z1 with Z1 = W2 H fixed.
            let z1 = mm(&w2, &h);
            let xt_z1t = mm_a_bt(&xt, &z1);
            let z1_z1t = mm_a_bt(&z1, &z1);
            for _ in 0..self.inner_iter {
                let den = mm(&w1, &z1_z1t);
                Self::scale_by_ratio(&mut w1, &xt_z1t, &den);
            }

            // W2-step: Xᵀ ≈ W1 W2 H with W1 and H fixed.
            let w1t_xt_ht = mm_a_bt(&mm_at_b(&w1, &xt), &h);
            let w1t_w1 = mm_at_b(&w1, &w1);
            let hht = mm_a_bt(&h, &h);
            for _ in 0..self.inner_iter {
                let den = mm(&mm(&w1t_w1, &w2), &hht);
                Self::scale_by_ratio(&mut w2, &w1t_xt_ht, &den);
            }

            // H-step: Xᵀ ≈ (W1 W2) H with the product fixed.
            let w12 = mm(&w1, &w2);
            let w12t_xt = mm_at_b(&w12, &xt);
            let w12t_w12 = mm_at_b(&w12, &w12);
            for _ in 0..self.inner_iter {
                let den = mm(&w12t_w12, &h);
                Self::scale_by_ratio(&mut h, &w12t_xt, &den);
            }

            let x_hat = mm(&mm(&w1, &w2), &h);
            let err = frob2(&xt, &x_hat).sqrt();
            let delta = (prev_err - err).abs() / (prev_err + EPS);
            if delta < self.tol {
                break;
            }
            prev_err = err;
        }
        self.w1 = Some(w1.clone());
        self.w2 = Some(w2.clone());
        self.h = Some(h.clone());
        Ok((w1, w2, h))
    }

    /// Returns the fitted reconstruction `W1 W2 H`, of shape `(p, n)`.
    ///
    /// # Errors
    /// Returns `NotFitted` before a successful fit.
    pub fn reconstruct(&self) -> Result<Array2<f64>> {
        let w1 = self.w1.as_ref().ok_or_else(|| TransformError::NotFitted("DeepNMF not fitted".into()))?;
        let w2 = self.w2.as_ref().ok_or_else(|| TransformError::NotFitted("DeepNMF not fitted".into()))?;
        let h = self.h.as_ref().ok_or_else(|| TransformError::NotFitted("DeepNMF not fitted".into()))?;
        Ok(mm(&mm(w1, w2), h))
    }

    /// Element-wise multiplicative step `m ← m ∘ (num + EPS) / (den + EPS)`, then clamps `m` to `≥ EPS`.
    fn scale_by_ratio(m: &mut Array2<f64>, num: &Array2<f64>, den: &Array2<f64>) {
        for ((i, j), v) in m.indexed_iter_mut() {
            *v *= (num[[i, j]] + EPS) / (den[[i, j]] + EPS);
        }
        clip_nonneg(m);
    }
}
/// Inverts a small square matrix after adding a ridge `λ` to its diagonal.
///
/// `λ = max(trace(a) / k · 1e-6, EPS)`, where `k = a.nrows()`. The ridge nudges
/// near-singular Gram matrices away from singularity before the Gauss–Jordan
/// solve, whose error (if any) is returned unchanged.
fn pseudo_inv_small(a: &Array2<f64>) -> Result<Array2<f64>> {
    let dim = a.nrows();
    let trace: f64 = (0..dim).map(|d| a[[d, d]]).sum();
    let ridge = (trace / dim as f64 * 1e-6).max(EPS);
    let mut regularised = a.to_owned();
    for d in 0..dim {
        regularised[[d, d]] += ridge;
    }
    invert_small_gauss(&regularised)
}
/// Inverts a square matrix by Gauss–Jordan elimination with partial pivoting.
///
/// The elimination works on the augmented matrix `[mat | I]` (of shape `k × 2k`).
/// Once the left half has been reduced to the identity, the right half holds the inverse.
///
/// # Errors
/// Returns `TransformError::ComputationError` when the largest available pivot
/// magnitude in some column is below `EPS`.
fn invert_small_gauss(mat: &Array2<f64>) -> Result<Array2<f64>> {
    let size = mat.nrows();
    let width = 2 * size;
    let mut aug = Array2::<f64>::zeros((size, width));
    for r in 0..size {
        for c in 0..size {
            aug[[r, c]] = mat[[r, c]];
        }
        aug[[r, size + r]] = 1.0;
    }

    for pivot_col in 0..size {
        // Partial pivoting: a later row replaces the current best only when its
        // magnitude is strictly larger, so ties keep the earliest row.
        let (pivot_row, pivot_mag) = ((pivot_col + 1)..size).fold(
            (pivot_col, aug[[pivot_col, pivot_col]].abs()),
            |(best_row, best_mag), r| {
                let mag = aug[[r, pivot_col]].abs();
                if mag > best_mag {
                    (r, mag)
                } else {
                    (best_row, best_mag)
                }
            },
        );
        if pivot_mag < EPS {
            return Err(TransformError::ComputationError("Singular matrix in pseudo_inv".into()));
        }
        if pivot_row != pivot_col {
            for c in 0..width {
                aug.swap([pivot_col, c], [pivot_row, c]);
            }
        }

        // Scale the pivot row so that the pivot becomes 1.
        let pivot = aug[[pivot_col, pivot_col]];
        aug.row_mut(pivot_col).mapv_inplace(|v| v / pivot);

        // Eliminate the pivot column from every other row.
        for r in (0..size).filter(|&r| r != pivot_col) {
            let factor = aug[[r, pivot_col]];
            for c in 0..width {
                let delta = aug[[pivot_col, c]] * factor;
                aug[[r, c]] -= delta;
            }
        }
    }

    Ok(Array2::from_shape_fn((size, size), |(r, c)| aug[[r, size + c]]))
}
/// Summarises how well `W H` approximates `X`.
///
/// Returns `(relative_error, sparsity)`:
/// * `relative_error = ‖X − WH‖_F / (‖X‖_F + EPS)`.
/// * `sparsity` is the fraction of entries of `H` below `1e-6`, or `0.0` when `H` is
///   empty (previously this was `0/0 = NaN`). The threshold is far above the `EPS`
///   floor applied by `clip_nonneg`, so entries clamped there count as zeros.
///
/// # Panics
/// Panics if `w.ncols() != h.nrows()`, or if `W H` does not have the same shape as `X`.
/// A mismatched shape would otherwise be silently truncated when the entries are paired.
pub fn nmf_quality(x: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> (f64, f64) {
    let wh = mm(w, h);
    assert_eq!(x.dim(), wh.dim(), "nmf_quality: W·H shape does not match X");
    let residual_norm = frob2(x, &wh).sqrt();
    let data_norm = x.iter().map(|v| v * v).sum::<f64>().sqrt();
    let frobenius_error = residual_norm / (data_norm + EPS);
    let sparsity = if h.is_empty() {
        0.0
    } else {
        h.iter().filter(|&&v| v < 1e-6).count() as f64 / h.len() as f64
    };
    (frobenius_error, sparsity)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a non-negative `(n, p)` matrix of rank at most `k` as the product of two
    /// random factors whose entries come from `gen_range(0.0..2.0)`.
    fn make_nonneg_data(n: usize, p: usize, k: usize) -> Array2<f64> {
        let mut rng = scirs2_core::random::rng();
        let left: Array2<f64> = Array2::from_shape_fn((n, k), |_| rng.gen_range(0.0..2.0));
        let right: Array2<f64> = Array2::from_shape_fn((k, p), |_| rng.gen_range(0.0..2.0));
        mm(&left, &right)
    }

    #[test]
    fn test_nmf_frobenius() {
        let x = make_nonneg_data(20, 15, 3);
        let mut model = NMF::new(3, NmfDivergence::Frobenius, 200, 1e-4);
        let (w, h) = model.fit_transform(&x).expect("NMF fit failed");
        assert_eq!(w.shape(), &[20, 3]);
        assert_eq!(h.shape(), &[3, 15]);
        assert!(w.iter().all(|&v| v >= 0.0));
        assert!(h.iter().all(|&v| v >= 0.0));
        let (err, _) = nmf_quality(&x, &w, &h);
        assert!(err < 0.5, "Reconstruction error {err} too large");
    }

    #[test]
    fn test_nmf_kl() {
        let x = make_nonneg_data(15, 10, 2);
        let mut model = NMF::new(2, NmfDivergence::KullbackLeibler, 100, 1e-4);
        let (w, h) = model.fit_transform(&x).expect("NMF KL fit failed");
        assert_eq!(w.shape(), &[15, 2]);
        assert_eq!(h.shape(), &[2, 10]);
        assert!(w.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_semi_nmf() {
        // Mixed-sign input: Semi-NMF only constrains H.
        let mut rng = scirs2_core::random::rng();
        let x: Array2<f64> = Array2::from_shape_fn((15, 10), |_| rng.gen_range(-1.0..2.0));
        let mut model = SemiNMF::new(3, 100, 1e-4);
        let (w, h) = model.fit_transform(&x).expect("SemiNMF failed");
        assert_eq!(w.shape(), &[15, 3]);
        assert_eq!(h.shape(), &[3, 10]);
        assert!(h.iter().all(|&v| v >= 0.0), "H must be non-negative");
    }

    #[test]
    fn test_convex_nmf() {
        let x = make_nonneg_data(12, 8, 2);
        let mut model = ConvexNMF::new(2, 50, 1e-4);
        let (w, h) = model.fit_transform(&x).expect("ConvexNMF failed");
        assert_eq!(w.shape(), &[12, 2]);
        assert!(h.iter().all(|&v| v >= 0.0), "H must be non-negative");
        assert!(w.iter().all(|&v| v >= 0.0), "W must be non-negative");
    }

    #[test]
    fn test_robust_nmf() {
        // RobustNMF factorises the transpose, hence the swapped shapes.
        let x = make_nonneg_data(20, 8, 2);
        let mut model = RobustNMF::new(2, 50, 1e-4);
        let (w, h) = model.fit_transform(&x).expect("RobustNMF failed");
        assert_eq!(w.shape(), &[8, 2]);
        assert_eq!(h.shape(), &[2, 20]);
        assert!(w.iter().all(|&v| v >= 0.0));
        assert!(h.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_deep_nmf() {
        let x = make_nonneg_data(20, 10, 3);
        let mut model = DeepNMF::new(4, 2, 30, 5, 1e-4);
        let (w1, w2, h) = model.fit_transform(&x).expect("DeepNMF failed");
        assert_eq!(w1.shape(), &[10, 4]);
        assert_eq!(w2.shape(), &[4, 2]);
        assert_eq!(h.shape(), &[2, 20]);
        assert!(w1.iter().all(|&v| v >= 0.0));
        assert!(w2.iter().all(|&v| v >= 0.0));
        assert!(h.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_nmf_quality() {
        let x = make_nonneg_data(10, 8, 2);
        let mut model = NMF::new(2, NmfDivergence::Frobenius, 100, 1e-4);
        let (w, h) = model.fit_transform(&x).expect("NMF fit failed");
        let (err, sparsity) = nmf_quality(&x, &w, &h);
        assert!(err.is_finite());
        assert!(sparsity >= 0.0 && sparsity <= 1.0);
    }
}