use crate::error::{Result, TransformError};
use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::random::{Rng, RngExt};
/// Small positive constant shared by every solver in this file: it is the floor applied to
/// factor entries (`.max(EPS)`), the additive guard in update denominators, and the threshold
/// below which a pivot is treated as zero in `solve_posdef_system`.
const EPS: f64 = 1e-10;
/// Draws random starting factors for a factorization of an `n_rows × n_cols` matrix.
///
/// Returns `(W, H)` where `W` has shape `(n_rows, k)` and `H` has shape `(k, n_cols)`.
/// Every entry is an independent `rng.random::<f64>()` sample multiplied by `scale`.
fn random_wh(n_rows: usize, k: usize, n_cols: usize, scale: f64) -> (Array2<f64>, Array2<f64>) {
    let mut rng = scirs2_core::random::rng();

    // Both arrays are freshly allocated, so `iter_mut` visits them row by row:
    // all of W is drawn first, then all of H.
    let mut w = Array2::<f64>::zeros((n_rows, k));
    for entry in w.iter_mut() {
        *entry = rng.random::<f64>() * scale;
    }

    let mut h = Array2::<f64>::zeros((k, n_cols));
    for entry in h.iter_mut() {
        *entry = rng.random::<f64>() * scale;
    }

    (w, h)
}
/// Frobenius norm of the reconstruction residual `x - w·h`.
fn frob_error(x: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> f64 {
    let residual = x - &w.dot(h);
    // Square each residual entry in place, then reduce.
    let squared = residual.mapv_into(|r| r * r);
    squared.sum().sqrt()
}
/// Configuration for NMF solved with multiplicative updates on the Frobenius reconstruction
/// error (see [`NmfMu::fit`]).
#[derive(Debug, Clone)]
pub struct NmfMu {
/// Inner dimension of the factorization: `W` is `(n, n_components)`, `H` is `(n_components, p)`.
pub n_components: usize,
/// Upper bound on the number of update sweeps (`new` sets 200).
pub max_iter: usize,
/// Stop once the relative change of the reconstruction error between two sweeps drops
/// below this value (`new` sets 1e-4).
pub tol: f64,
}
impl NmfMu {
    /// Creates a solver for `n_components` factors with `max_iter = 200` and `tol = 1e-4`.
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_iter: 200,
            tol: 1e-4,
        }
    }

    /// Replaces the iteration cap.
    pub fn with_max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Replaces the relative-change stopping tolerance.
    pub fn with_tol(self, tol: f64) -> Self {
        Self { tol, ..self }
    }

    /// Factorizes `x_raw ≈ W · H` with multiplicative updates.
    ///
    /// Each sweep rescales `H` by `WᵀX / (WᵀWH + EPS)` and then `W` by `XHᵀ / (WHHᵀ + EPS)`,
    /// flooring every entry at `EPS`.
    ///
    /// # Returns
    /// `(W, H)` with shapes `(n, n_components)` and `(n_components, p)`.
    ///
    /// # Errors
    /// Propagates the errors of `to_f64`, `check_non_negative` and `check_rank`.
    pub fn fit<S>(&self, x_raw: &ArrayBase<S, Ix2>) -> Result<(Array2<f64>, Array2<f64>)>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x = to_f64(x_raw)?;
        check_non_negative(&x)?;
        check_rank(&x, self.n_components)?;

        let rank = self.n_components;
        let (n_rows, n_cols) = x.dim();
        // Scale the random start so that an entry of W·H is of the order of the mean of X.
        let mean = x.mean().unwrap_or(1.0);
        let init_scale = (mean / rank as f64).sqrt().max(EPS);
        let (mut w, mut h) = random_wh(n_rows, rank, n_cols, init_scale);

        let mut last_err = frob_error(&x, &w, &h);
        for _ in 0..self.max_iter {
            // H step: numerator and denominator are fixed before any entry changes.
            let h_numer = w.t().dot(&x);
            let h_denom = w.t().dot(&w.dot(&h));
            for ((r, c), entry) in h.indexed_iter_mut() {
                *entry = (*entry * h_numer[[r, c]] / (h_denom[[r, c]] + EPS)).max(EPS);
            }

            // W step, using the freshly updated H.
            let w_numer = x.dot(&h.t());
            let w_denom = w.dot(&h).dot(&h.t());
            for ((r, c), entry) in w.indexed_iter_mut() {
                *entry = (*entry * w_numer[[r, c]] / (w_denom[[r, c]] + EPS)).max(EPS);
            }

            let current_err = frob_error(&x, &w, &h);
            let rel_change = (last_err - current_err).abs() / last_err.max(EPS);
            if rel_change < self.tol {
                break;
            }
            last_err = current_err;
        }

        Ok((w, h))
    }
}
/// Configuration for NMF solved by alternating regularized least-squares solves whose
/// results are floored at `EPS` (see [`NmfAls::fit`]).
#[derive(Debug, Clone)]
pub struct NmfAls {
/// Inner dimension of the factorization: `W` is `(n, n_components)`, `H` is `(n_components, p)`.
pub n_components: usize,
/// Upper bound on the number of alternating sweeps (`new` sets 300).
pub max_iter: usize,
/// Stop once the relative change of the reconstruction error between two sweeps drops
/// below this value (`new` sets 1e-5).
pub tol: f64,
/// Value added to the diagonal of the normal-equation matrices before each solve
/// (`new` sets 1e-8).
pub reg: f64,
}
impl NmfAls {
    /// Creates a solver for `n_components` factors with `max_iter = 300`, `tol = 1e-5`
    /// and `reg = 1e-8`.
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            max_iter: 300,
            tol: 1e-5,
            reg: 1e-8,
        }
    }

    /// Replaces the iteration cap.
    pub fn with_max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Replaces the relative-change stopping tolerance.
    pub fn with_tol(self, tol: f64) -> Self {
        Self { tol, ..self }
    }

    /// Replaces the diagonal regularization used in both least-squares solves.
    pub fn with_reg(self, reg: f64) -> Self {
        Self { reg, ..self }
    }

    /// Factorizes `x_raw ≈ W · H` by alternating least squares.
    ///
    /// Each sweep recomputes `H` from the current `W` (`als_solve_h`) and then `W` from the
    /// new `H` (`als_solve_w`); both helpers floor their solutions at `EPS`.
    ///
    /// # Returns
    /// `(W, H)` with shapes `(n, n_components)` and `(n_components, p)`.
    ///
    /// # Errors
    /// Propagates the errors of `to_f64`, `check_non_negative` and `check_rank`.
    pub fn fit<S>(&self, x_raw: &ArrayBase<S, Ix2>) -> Result<(Array2<f64>, Array2<f64>)>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x = to_f64(x_raw)?;
        check_non_negative(&x)?;
        check_rank(&x, self.n_components)?;

        let rank = self.n_components;
        let (n_rows, n_cols) = x.dim();
        let mean = x.mean().unwrap_or(1.0);
        let init_scale = (mean / rank as f64).sqrt().max(EPS);
        let (mut w, mut h) = random_wh(n_rows, rank, n_cols, init_scale);

        let mut last_err = frob_error(&x, &w, &h);
        for _ in 0..self.max_iter {
            h = als_solve_h(&x, &w, rank, self.reg);
            w = als_solve_w(&x, &h, n_rows, self.reg);

            let current_err = frob_error(&x, &w, &h);
            let rel_change = (last_err - current_err).abs() / last_err.max(EPS);
            if rel_change < self.tol {
                break;
            }
            last_err = current_err;
        }

        Ok((w, h))
    }
}
/// Least-squares update of `H` for fixed `W`.
///
/// Solves `(WᵀW + reg·I) h = Wᵀx` for each column `x` of the data matrix and floors
/// every solution entry at `EPS`. Returns a `(k, x.ncols())` matrix.
fn als_solve_h(x: &Array2<f64>, w: &Array2<f64>, k: usize, reg: f64) -> Array2<f64> {
    let n_cols = x.ncols();

    // Regularized Gram matrix of W.
    let mut gram = w.t().dot(w);
    for d in 0..k {
        gram[[d, d]] += reg;
    }

    let rhs_all = w.t().dot(x);
    let mut h = Array2::<f64>::zeros((k, n_cols));
    for col in 0..n_cols {
        let rhs: Vec<f64> = rhs_all.column(col).iter().take(k).copied().collect();
        let solution = solve_posdef_system(&gram, &rhs);
        for (i, &value) in solution.iter().enumerate() {
            h[[i, col]] = value.max(EPS);
        }
    }
    h
}
/// Least-squares update of `W` for fixed `H`.
///
/// Solves `(HHᵀ + reg·I) w = H xᵀ` for each of the first `n` rows `x` of the data matrix
/// and floors every solution entry at `EPS`. Returns an `(n, h.nrows())` matrix.
fn als_solve_w(x: &Array2<f64>, h: &Array2<f64>, n: usize, reg: f64) -> Array2<f64> {
    let rank = h.nrows();

    // Regularized Gram matrix of Hᵀ.
    let mut gram = h.dot(&h.t());
    for d in 0..rank {
        gram[[d, d]] += reg;
    }

    let rhs_all = x.dot(&h.t());
    let mut w = Array2::<f64>::zeros((n, rank));
    for row in 0..n {
        let rhs: Vec<f64> = rhs_all.row(row).iter().copied().collect();
        let solution = solve_posdef_system(&gram, &rhs);
        for (j, &value) in solution.iter().enumerate() {
            w[[row, j]] = value.max(EPS);
        }
    }
    w
}
/// Solves the dense linear system `a · x = b` by Gaussian elimination with partial pivoting.
///
/// The system size is taken from `b.len()`; only the leading `b.len() × b.len()` block of
/// `a` is read. Columns whose pivot magnitude is below `EPS` are skipped during elimination,
/// and unknowns whose diagonal entry does not exceed `EPS` in magnitude are set to `0.0`,
/// so a (near-)singular system yields a best-effort answer instead of a division by zero.
fn solve_posdef_system(a: &Array2<f64>, b: &[f64]) -> Vec<f64> {
    let dim = b.len();

    // Augmented matrix [A | b], one Vec per row.
    let mut aug: Vec<Vec<f64>> = Vec::with_capacity(dim);
    for (i, &rhs) in b.iter().enumerate() {
        let mut row = Vec::with_capacity(dim + 1);
        for j in 0..dim {
            row.push(a[[i, j]]);
        }
        row.push(rhs);
        aug.push(row);
    }

    // Forward elimination.
    for col in 0..dim {
        // Partial pivoting: the first row at or below the diagonal with the largest |entry|.
        let mut pivot_row = col;
        let mut best = aug[col][col].abs();
        for (offset, candidate) in aug[col + 1..].iter().enumerate() {
            let magnitude = candidate[col].abs();
            if magnitude > best {
                best = magnitude;
                pivot_row = col + 1 + offset;
            }
        }
        aug.swap(col, pivot_row);

        let pivot = aug[col][col];
        if pivot.abs() < EPS {
            // Nothing usable in this column; leave it for back-substitution to zero out.
            continue;
        }

        let inv_pivot = 1.0 / pivot;
        // Split so the pivot row can be read while the rows below it are modified.
        let (head, tail) = aug.split_at_mut(col + 1);
        let pivot_values = &head[col];
        for lower in tail.iter_mut() {
            let factor = lower[col] * inv_pivot;
            for c in col..=dim {
                lower[c] -= factor * pivot_values[c];
            }
        }
    }

    // Back-substitution, last unknown first.
    let mut solution = vec![0.0f64; dim];
    for i in (0..dim).rev() {
        let row = &aug[i];
        let mut remainder = row[dim];
        for j in (i + 1)..dim {
            remainder -= row[j] * solution[j];
        }
        let diagonal = row[i];
        solution[i] = if diagonal.abs() > EPS {
            remainder / diagonal
        } else {
            0.0
        };
    }
    solution
}
/// Configuration for multiplicative-update NMF with a sparsity penalty on `H`
/// (see [`NmfSparse::fit`]).
#[derive(Debug, Clone)]
pub struct NmfSparse {
/// Inner dimension of the factorization: `W` is `(n, n_components)`, `H` is `(n_components, p)`.
pub n_components: usize,
/// Penalty weight; `fit` adds it to the denominator of every `H` update, which shrinks
/// the entries of `H` (`new` sets 0.1).
pub lambda: f64,
/// Upper bound on the number of update sweeps (`new` sets 200).
pub max_iter: usize,
/// Stop once the relative change of the reconstruction error between two sweeps drops
/// below this value (`new` sets 1e-4).
pub tol: f64,
}
impl NmfSparse {
    /// Creates a solver for `n_components` factors with `lambda = 0.1`, `max_iter = 200`
    /// and `tol = 1e-4`.
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            lambda: 0.1,
            max_iter: 200,
            tol: 1e-4,
        }
    }

    /// Replaces the sparsity penalty weight.
    pub fn with_lambda(self, lambda: f64) -> Self {
        Self { lambda, ..self }
    }

    /// Replaces the iteration cap.
    pub fn with_max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Replaces the relative-change stopping tolerance.
    pub fn with_tol(self, tol: f64) -> Self {
        Self { tol, ..self }
    }

    /// Factorizes `x_raw ≈ W · H` with multiplicative updates and a penalty on `H`.
    ///
    /// The `H` step divides by `WᵀWH + lambda + EPS`; the `W` step is the unpenalized
    /// update `XHᵀ / (WHHᵀ + EPS)`. All entries are floored at `EPS`.
    ///
    /// # Returns
    /// `(W, H)` with shapes `(n, n_components)` and `(n_components, p)`.
    ///
    /// # Errors
    /// Propagates the errors of `to_f64`, `check_non_negative` and `check_rank`.
    pub fn fit<S>(&self, x_raw: &ArrayBase<S, Ix2>) -> Result<(Array2<f64>, Array2<f64>)>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x = to_f64(x_raw)?;
        check_non_negative(&x)?;
        check_rank(&x, self.n_components)?;

        let rank = self.n_components;
        let (n_rows, n_cols) = x.dim();
        let mean = x.mean().unwrap_or(1.0);
        let init_scale = (mean / rank as f64).sqrt().max(EPS);
        let (mut w, mut h) = random_wh(n_rows, rank, n_cols, init_scale);

        let mut last_err = frob_error(&x, &w, &h);
        for _ in 0..self.max_iter {
            // Penalized H step.
            let h_numer = w.t().dot(&x);
            let h_gram = w.t().dot(&w.dot(&h));
            for ((r, c), entry) in h.indexed_iter_mut() {
                let denom = h_gram[[r, c]] + self.lambda + EPS;
                *entry = (*entry * h_numer[[r, c]] / denom).max(EPS);
            }

            // Plain W step, using the freshly updated H.
            let w_numer = x.dot(&h.t());
            let w_denom = w.dot(&h).dot(&h.t());
            for ((r, c), entry) in w.indexed_iter_mut() {
                *entry = (*entry * w_numer[[r, c]] / (w_denom[[r, c]] + EPS)).max(EPS);
            }

            let current_err = frob_error(&x, &w, &h);
            let rel_change = (last_err - current_err).abs() / last_err.max(EPS);
            if rel_change < self.tol {
                break;
            }
            last_err = current_err;
        }

        Ok((w, h))
    }
}
/// Configuration for semi-NMF: `fit` does not call `check_non_negative` on the input and
/// leaves `W` unconstrained in sign, while `H` is kept at or above `EPS`
/// (see [`NmfSemiNmf::fit`]).
#[derive(Debug, Clone)]
pub struct NmfSemiNmf {
/// Inner dimension of the factorization: `W` is `(n, n_components)`, `H` is `(n_components, p)`.
pub n_components: usize,
/// Upper bound on the number of alternating sweeps (`new` sets 200).
pub max_iter: usize,
/// Stop once the relative change of the reconstruction error between two sweeps drops
/// below this value (`new` sets 1e-4).
pub tol: f64,
/// Value added to the diagonal of `H·Hᵀ` before the least-squares solve for `W`
/// (`new` sets 1e-8).
pub reg: f64,
}
impl NmfSemiNmf {
pub fn new(n_components: usize) -> Self {
Self {
n_components,
max_iter: 200,
tol: 1e-4,
reg: 1e-8,
}
}
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
pub fn with_tol(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
pub fn fit<S>(&self, x_raw: &ArrayBase<S, Ix2>) -> Result<(Array2<f64>, Array2<f64>)>
where
S: Data,
S::Elem: Float + NumCast,
{
let x = to_f64(x_raw)?;
check_rank(&x, self.n_components)?;
let (n, p) = x.dim();
let k = self.n_components;
let scale = 0.1f64;
let (mut w, mut h) = random_wh(n, k, p, scale);
{
let mut rng = scirs2_core::random::rng();
for v in w.iter_mut() {
*v = rng.random::<f64>() * 2.0 - 1.0;
}
}
let mut prev_err = frob_error(&x, &w, &h);
for _ in 0..self.max_iter {
let wt_x = w.t().dot(&x); let wt_x_pos = wt_x.mapv(|v| v.max(0.0));
let wt_x_neg = wt_x.mapv(|v| (-v).max(0.0));
let wtw = w.t().dot(&w); let wtwh = wtw.dot(&h);
for i in 0..k {
for j in 0..p {
let num = wt_x_pos[[i, j]] + EPS;
let den = wt_x_neg[[i, j]] + wtwh[[i, j]] + EPS;
h[[i, j]] = (h[[i, j]] * num / den).max(EPS);
}
}
let xht = x.dot(&h.t()); let mut hht = h.dot(&h.t()); for i in 0..k {
hht[[i, i]] += self.reg;
}
for row in 0..n {
let rhs: Vec<f64> = (0..k).map(|j| xht[[row, j]]).collect();
let sol = solve_posdef_system(&hht, &rhs);
for j in 0..k {
w[[row, j]] = sol[j];
}
}
let err = frob_error(&x, &w, &h);
if (prev_err - err).abs() / prev_err.max(EPS) < self.tol {
break;
}
prev_err = err;
}
Ok((w, h))
}
}
/// Configuration for the convex NMF variant in which the returned `W` is computed as
/// `X · G` from a non-negative weight matrix `G` (see [`NmfConvex::fit`]).
#[derive(Debug, Clone)]
pub struct NmfConvex {
/// Number of columns of `G` and `W`, and number of rows of `H`.
pub n_components: usize,
/// Upper bound on the number of update sweeps (`new` sets 200).
pub max_iter: usize,
/// Stop once the relative change of the reconstruction error between two sweeps drops
/// below this value (`new` sets 1e-4).
pub tol: f64,
}
impl NmfConvex {
pub fn new(n_components: usize) -> Self {
Self {
n_components,
max_iter: 200,
tol: 1e-4,
}
}
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
pub fn with_tol(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
pub fn fit<S>(
&self,
x_raw: &ArrayBase<S, Ix2>,
) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>)>
where
S: Data,
S::Elem: Float + NumCast,
{
let x = to_f64(x_raw)?;
check_non_negative(&x)?;
check_rank(&x, self.n_components)?;
let (n, p) = x.dim();
let k = self.n_components;
let mut rng = scirs2_core::random::rng();
let mut g = Array2::<f64>::zeros((n, k));
let mut h = Array2::<f64>::zeros((k, p));
for i in 0..n {
let mut row_sum = 0.0;
for j in 0..k {
g[[i, j]] = rng.random::<f64>() + EPS;
row_sum += g[[i, j]];
}
for j in 0..k {
g[[i, j]] /= row_sum;
}
}
for i in 0..k {
for j in 0..p {
h[[i, j]] = rng.random::<f64>() + EPS;
}
}
let xtx = x.t().dot(&x);
let mut prev_err = {
let w = x.dot(&g);
frob_error(&x, &w, &h)
};
for _ in 0..self.max_iter {
{
let xg = x.dot(&g); let w = &xg;
let r = &x - &w.dot(&h); let xt_r_ht = x.t().dot(&r).dot(&h.t()); let _ = xt_r_ht;
let xxt = x.dot(&x.t()); let xxt_g = xxt.dot(&g); let num_g = xxt.dot(&x.dot(&h.t())); let x_ht = x.dot(&h.t()); let num_g2 = xxt.dot(&x_ht);
let den_g = xxt_g.dot(&h).dot(&h.t());
for i in 0..n {
for j in 0..k {
let num = num_g2[[i, j]];
let den = den_g[[i, j]] + EPS;
if num > 0.0 {
g[[i, j]] = (g[[i, j]] * num / den).max(EPS);
}
}
}
for j in 0..k {
let col_sum: f64 = (0..n).map(|i| g[[i, j]]).sum::<f64>().max(EPS);
for i in 0..n {
g[[i, j]] /= col_sum;
}
}
}
{
let w = x.dot(&g); let wt_x = w.t().dot(&x); let wt_wh = w.t().dot(&w.dot(&h)); for i in 0..k {
for j in 0..p {
h[[i, j]] = (h[[i, j]] * wt_x[[i, j]] / (wt_wh[[i, j]] + EPS)).max(EPS);
}
}
}
let w = x.dot(&g);
let err = frob_error(&x, &w, &h);
if (prev_err - err).abs() / prev_err.max(EPS) < self.tol {
break;
}
prev_err = err;
}
let w = x.dot(&g);
Ok((w, h, g))
}
}
/// Configuration for mini-batch (online) NMF: `H` is learned from accumulated per-batch
/// statistics, and `W` is computed for the full matrix at the end (see [`NmfOnline::fit`]).
#[derive(Debug, Clone)]
pub struct NmfOnline {
/// Inner dimension of the factorization: `W` is `(n, n_components)`, `H` is `(n_components, p)`.
pub n_components: usize,
/// Maximum number of rows per mini-batch (`new` sets 32).
pub batch_size: usize,
/// Number of shuffled passes over all rows (`new` sets 10).
pub n_epochs: usize,
/// Factor applied to the accumulated statistics before each batch is added
/// (`new` sets 0.9; `with_rho` clamps its argument to `[0.5, 1.0]`).
pub rho: f64,
/// Tolerance (`new` sets 1e-4). NOTE(review): `fit` in this file never reads this field.
pub tol: f64,
/// Optional seed set by `with_seed`. NOTE(review): `fit` in this file never reads this
/// field; it always draws from `scirs2_core::random::rng()`.
pub seed: Option<u64>,
}
impl NmfOnline {
pub fn new(n_components: usize) -> Self {
Self {
n_components,
batch_size: 32,
n_epochs: 10,
rho: 0.9,
tol: 1e-4,
seed: None,
}
}
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
self
}
pub fn with_n_epochs(mut self, n_epochs: usize) -> Self {
self.n_epochs = n_epochs;
self
}
pub fn with_rho(mut self, rho: f64) -> Self {
self.rho = rho.clamp(0.5, 1.0);
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
pub fn fit<S>(&self, x_raw: &ArrayBase<S, Ix2>) -> Result<(Array2<f64>, Array2<f64>)>
where
S: Data,
S::Elem: Float + NumCast,
{
let x = to_f64(x_raw)?;
check_non_negative(&x)?;
check_rank(&x, self.n_components)?;
let (n, p) = x.dim();
let k = self.n_components;
let scale = (x.mean().unwrap_or(1.0) / k as f64).sqrt().max(EPS);
let (_, mut h) = random_wh(1, k, p, scale);
let mut a = Array2::<f64>::zeros((k, k)); let mut b = Array2::<f64>::zeros((k, p));
let mut rng = scirs2_core::random::rng();
for _epoch in 0..self.n_epochs {
let mut indices: Vec<usize> = (0..n).collect();
for i in (1..n).rev() {
let j = (rng.random::<f64>() * (i + 1) as f64) as usize;
indices.swap(i, j);
}
let mut start = 0;
while start < n {
let end = (start + self.batch_size).min(n);
let batch_idx = &indices[start..end];
let batch_n = batch_idx.len();
let mut x_batch = Array2::<f64>::zeros((batch_n, p));
for (bi, &gi) in batch_idx.iter().enumerate() {
for j in 0..p {
x_batch[[bi, j]] = x[[gi, j]];
}
}
let mut alpha = Array2::<f64>::zeros((batch_n, k)); for i in 0..batch_n {
for j in 0..k {
alpha[[i, j]] = scale * rng.random::<f64>();
}
}
for _inner in 0..50 {
let xt_batch = x_batch.clone(); let num = xt_batch.dot(&h.t()); let den = alpha.dot(&h).dot(&h.t()); for i in 0..batch_n {
for j in 0..k {
alpha[[i, j]] =
(alpha[[i, j]] * num[[i, j]] / (den[[i, j]] + EPS)).max(EPS);
}
}
}
let alpha_t = alpha.t().to_owned(); let new_a = alpha_t.dot(&alpha); let new_b = alpha_t.dot(&x_batch);
let rho = self.rho;
for i in 0..k {
for j in 0..k {
a[[i, j]] = rho * a[[i, j]] + new_a[[i, j]];
}
for j in 0..p {
b[[i, j]] = rho * b[[i, j]] + new_b[[i, j]];
}
}
for i in 0..k {
for j in 0..p {
let num_h = b[[i, j]];
let den_h = a.row(i).dot(&h.column(j)) + EPS;
if num_h > 0.0 {
h[[i, j]] = (h[[i, j]] * num_h / den_h).max(EPS);
}
}
}
start = end;
}
}
let (mut w, _) = random_wh(n, k, 1, scale); w = Array2::<f64>::zeros((n, k));
{
let mut rng2 = scirs2_core::random::rng();
for i in 0..n {
for j in 0..k {
w[[i, j]] = scale * rng2.random::<f64>();
}
}
}
for _inner in 0..100 {
let num = x.dot(&h.t()); let den = w.dot(&h).dot(&h.t()); for i in 0..n {
for j in 0..k {
w[[i, j]] = (w[[i, j]] * num[[i, j]] / (den[[i, j]] + EPS)).max(EPS);
}
}
}
Ok((w, h))
}
}
/// Summarizes the quality of a factorization `x ≈ w · h`.
///
/// # Returns
/// `(relative_error, sparsity)` where
/// - `relative_error` is `‖x − w·h‖_F / max(‖x‖_F, EPS)`, and
/// - `sparsity` is the fraction of entries of `h` that are below `1e-6`
///   (`0.0` when `h` has no entries).
pub fn nmf_quality(x: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> (f64, f64) {
    let residual = x - &w.dot(h);
    let residual_norm = residual.mapv_into(|r| r * r).sum().sqrt();
    let data_norm = x.mapv(|v| v * v).sum().sqrt().max(EPS);
    let relative_error = residual_norm / data_norm;

    let near_zero = h.iter().filter(|&&v| v < 1e-6).count();
    let sparsity = match h.len() {
        0 => 0.0,
        total => near_zero as f64 / total as f64,
    };

    (relative_error, sparsity)
}
/// Copies a 2-D float array into an owned `Array2<f64>`.
///
/// Elements that `NumCast` cannot convert to `f64` are replaced by `0.0`; the function
/// itself never returns an error.
fn to_f64<S>(x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
where
    S: Data,
    S::Elem: Float + NumCast,
{
    let converted = x.mapv(|value| match <f64 as NumCast>::from(value) {
        Some(as_f64) => as_f64,
        None => 0.0,
    });
    Ok(converted)
}
/// Validates that every entry of `x` is a finite, non-negative number.
///
/// # Errors
/// Returns `TransformError::InvalidInput` when an entry is negative, or when an entry is
/// NaN or infinite. NaN compares false against `0.0`, so a plain `v < 0.0` test would let
/// it through and every later update would be poisoned.
fn check_non_negative(x: &Array2<f64>) -> Result<()> {
    for &v in x.iter() {
        if v < 0.0 {
            return Err(TransformError::InvalidInput(
                "NMF requires a non-negative input matrix".to_string(),
            ));
        }
        if !v.is_finite() {
            return Err(TransformError::InvalidInput(
                "NMF requires a finite input matrix (found NaN or infinity)".to_string(),
            ));
        }
    }
    Ok(())
}
/// Validates that the requested number of components does not exceed either dimension
/// of `x`.
///
/// # Errors
/// Returns `TransformError::InvalidInput` when `k > min(n_samples, n_features)`.
fn check_rank(x: &Array2<f64>, k: usize) -> Result<()> {
    let (n, p) = x.dim();
    if k <= n.min(p) {
        return Ok(());
    }
    Err(TransformError::InvalidInput(format!(
        "n_components={k} must be ≤ min(n_samples={n}, n_features={p})"
    )))
}
// Smoke tests for the NMF variants: they check output shapes, non-negativity of the
// constrained factors, and input validation. Reconstruction accuracy is only asserted
// for `NmfMu` and for `nmf_quality` on an exact factorization.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
// 6×4 rank-one matrix: row `i` (1-based) is `i * [1, 2, 3, 4]`.
fn toy_matrix() -> Array2<f64> {
let data = vec![
1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, 4.0, 8.0, 12.0, 16.0,
5.0, 10.0, 15.0, 20.0, 6.0, 12.0, 18.0, 24.0,
];
Array2::from_shape_vec((6, 4), data).expect("shape ok")
}
// Multiplicative updates: shapes, non-negativity, and a loose bound on the relative error.
#[test]
fn test_nmf_mu_basic() {
let x = toy_matrix();
let nmf = NmfMu::new(2).with_max_iter(300).with_tol(1e-5);
let (w, h) = nmf.fit(&x).expect("NmfMu fit ok");
assert_eq!(w.shape(), &[6, 2]);
assert_eq!(h.shape(), &[2, 4]);
for &v in w.iter() {
assert!(v >= 0.0, "W must be non-negative");
}
for &v in h.iter() {
assert!(v >= 0.0, "H must be non-negative");
}
let (rel_err, _) = nmf_quality(&x, &w, &h);
assert!(rel_err < 0.5, "relative error {rel_err} should be small");
}
// Alternating least squares: shapes and non-negativity only.
#[test]
fn test_nmf_als_basic() {
let x = toy_matrix();
let nmf = NmfAls::new(2).with_max_iter(200);
let (w, h) = nmf.fit(&x).expect("NmfAls fit ok");
assert_eq!(w.shape(), &[6, 2]);
assert_eq!(h.shape(), &[2, 4]);
for &v in w.iter() {
assert!(v >= 0.0);
}
for &v in h.iter() {
assert!(v >= 0.0);
}
}
// Compares the sparsity of H with lambda = 1.0 against lambda = 0.0. The assertion is
// deliberately loose: it also passes when the two sparsities differ by less than 0.2.
#[test]
fn test_nmf_sparse_promotes_sparsity() {
let x = toy_matrix();
let nmf_sparse = NmfSparse::new(2).with_lambda(1.0).with_max_iter(300);
let nmf_base = NmfSparse::new(2).with_lambda(0.0).with_max_iter(300);
let (ws, hs) = nmf_sparse.fit(&x).expect("NmfSparse fit ok");
let (wb, hb) = nmf_base.fit(&x).expect("NmfSparse base fit ok");
let (_, sp_s) = nmf_quality(&x, &ws, &hs);
let (_, sp_b) = nmf_quality(&x, &wb, &hb);
assert!(sp_s >= sp_b || (sp_s - sp_b).abs() < 0.2, "sparse H not sparser");
}
// Semi-NMF: W may be negative, so only H is checked for non-negativity.
#[test]
fn test_nmf_semi_basic() {
let x = toy_matrix();
let nmf = NmfSemiNmf::new(2).with_max_iter(200);
let (w, h) = nmf.fit(&x).expect("NmfSemiNmf fit ok");
assert_eq!(w.shape(), &[6, 2]);
assert_eq!(h.shape(), &[2, 4]);
for &v in h.iter() {
assert!(v >= 0.0, "H must be non-negative");
}
}
// Convex NMF: shapes of all three outputs and non-negativity of H and G.
// NOTE(review): `toy_matrix` is 6×4 while `NmfConvex::fit` evaluates `x.dot(&g)` with `g`
// of shape (6, 2); confirm that this product is valid for a non-square input.
#[test]
fn test_nmf_convex_basic() {
let x = toy_matrix();
let nmf = NmfConvex::new(2).with_max_iter(100);
let (w, h, g) = nmf.fit(&x).expect("NmfConvex fit ok");
assert_eq!(w.shape(), &[6, 2]);
assert_eq!(h.shape(), &[2, 4]);
assert_eq!(g.shape(), &[6, 2]);
for &v in h.iter() {
assert!(v >= 0.0);
}
for &v in g.iter() {
assert!(v >= 0.0);
}
}
// Online NMF with two mini-batches of three rows per epoch: shapes and non-negativity.
#[test]
fn test_nmf_online_basic() {
let x = toy_matrix();
let nmf = NmfOnline::new(2).with_batch_size(3).with_n_epochs(5);
let (w, h) = nmf.fit(&x).expect("NmfOnline fit ok");
assert_eq!(w.shape(), &[6, 2]);
assert_eq!(h.shape(), &[2, 4]);
for &v in w.iter() {
assert!(v >= 0.0);
}
for &v in h.iter() {
assert!(v >= 0.0);
}
}
// `nmf_quality` must report (near) zero relative error when x is exactly w·h.
#[test]
fn test_nmf_quality_perfect() {
let w = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).expect("valid shape");
let h = Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 1.0]).expect("valid shape");
let x = w.dot(&h);
let (err, _sparsity) = nmf_quality(&x, &w, &h);
assert!(err < 1e-10, "perfect reconstruction should give zero error");
}
// A matrix with a negative entry must be rejected before any fitting happens.
#[test]
fn test_nmf_mu_negative_input_rejected() {
let bad = Array2::from_shape_vec((2, 2), vec![1.0, -1.0, 2.0, 3.0]).expect("valid shape");
let nmf = NmfMu::new(1);
assert!(nmf.fit(&bad).is_err());
}
// Requesting more components than min(n_samples, n_features) must be rejected.
#[test]
fn test_nmf_rank_too_large_rejected() {
let small = Array2::<f64>::eye(3);
let nmf = NmfMu::new(10);
assert!(nmf.fit(&small).is_err());
}
}