use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::random::{Rng, RngExt};
use crate::error::{Result, TransformError};
/// Non-negative matrix factorization: approximates a non-negative data matrix
/// `V` (n_samples × n_features) by the product `W · H`, where `W`
/// (n_samples × n_components) and `H` (n_components × n_features) are kept
/// non-negative.
///
/// Configure with the `with_*` builder methods, then call `fit`,
/// `fit_transform` or `transform`. Fitted factors are exposed through
/// `components()` (`H`) and `coefficients()` (`W`).
#[derive(Debug, Clone)]
pub struct NMF {
/// Number of latent components (inner dimension of `W · H`).
n_components: usize,
/// Initialization scheme: `"random"` or `"nndsvd"`; other values fall back to random.
init: String,
/// Update rule used by `fit`: `"mu"` (multiplicative updates) or `"cd"` (coordinate descent).
solver: String,
/// Beta-divergence parameter stored by `with_beta_loss` (defaults to 2.0).
/// NOTE(review): not read by any solver in this file; the loss minimised is always Frobenius.
beta_loss: f64,
/// Maximum number of solver iterations (also the number of `W` updates in `transform`).
max_iter: usize,
/// Stopping tolerance on the relative change of the Frobenius loss between iterations.
tol: f64,
/// Optional seed set via `with_random_state`.
/// NOTE(review): confirm that the random initialization actually consumes this seed.
random_state: Option<u64>,
/// Overall regularization strength; 0.0 disables regularization.
alpha: f64,
/// Share of `alpha` applied as L1 penalty; the remainder `1 - l1_ratio` is applied as L2.
l1_ratio: f64,
/// Fitted `H` (n_components × n_features); `None` until `fit` succeeds.
components: Option<Array2<f64>>,
/// Fitted `W` (n_samples × n_components); `None` until `fit` succeeds.
coefficients: Option<Array2<f64>>,
/// Frobenius reconstruction error recorded by the last `fit`.
reconstruction_err: Option<f64>,
/// Number of iterations performed by the last `fit`.
n_iter: Option<usize>,
}
impl NMF {
pub fn new(ncomponents: usize) -> Self {
NMF {
n_components: ncomponents,
init: "random".to_string(),
solver: "mu".to_string(),
beta_loss: 2.0, max_iter: 200,
tol: 1e-4,
random_state: None,
alpha: 0.0,
l1_ratio: 0.0,
components: None,
coefficients: None,
reconstruction_err: None,
n_iter: None,
}
}
pub fn with_init(mut self, init: &str) -> Self {
self.init = init.to_string();
self
}
pub fn with_solver(mut self, solver: &str) -> Self {
self.solver = solver.to_string();
self
}
pub fn with_beta_loss(mut self, beta: f64) -> Self {
self.beta_loss = beta;
self
}
pub fn with_max_iter(mut self, maxiter: usize) -> Self {
self.max_iter = maxiter;
self
}
pub fn with_tolerance(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
pub fn with_random_state(mut self, seed: u64) -> Self {
self.random_state = Some(seed);
self
}
pub fn with_regularization(mut self, alpha: f64, l1ratio: f64) -> Self {
self.alpha = alpha;
self.l1_ratio = l1ratio;
self
}
/// Draws the initial factors `W` (n_samples × n_components) and `H`
/// (n_components × n_features).
///
/// Every entry is uniform in `[0, scale)` with
/// `scale = sqrt(mean(V) / n_components)`. `W` is filled first, then `H`,
/// both in row-major order.
///
/// When `random_state` is set, values come from a SplitMix64 stream seeded
/// with it, so repeated fits on the same data are reproducible. Otherwise the
/// generator returned by `scirs2_core::random::rng()` is used.
fn random_initialization(&self, v: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
    let (n_samples, n_features) = v.dim();
    // An empty input has no mean; a zero scale then yields (empty) zero factors
    // instead of panicking.
    let scale = (v.mean().unwrap_or(0.0) / self.n_components as f64).sqrt();

    let mut next_uniform: Box<dyn FnMut() -> f64> = match self.random_state {
        Some(seed) => {
            let mut state = seed;
            Box::new(move || {
                // SplitMix64: tiny, fast and fully deterministic for a given seed.
                state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
                let mut z = state;
                z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
                z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
                z ^= z >> 31;
                // Keep the top 53 bits -> uniform f64 in [0, 1).
                (z >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
            })
        }
        None => {
            let mut rng = scirs2_core::random::rng();
            Box::new(move || rng.random::<f64>())
        }
    };

    let w = Array2::from_shape_simple_fn((n_samples, self.n_components), || {
        next_uniform() * scale
    });
    let h = Array2::from_shape_simple_fn((self.n_components, n_features), || {
        next_uniform() * scale
    });
    (w, h)
}
/// NNDSVD initialization (Boutsidis & Gallopoulos, 2008).
///
/// Computes the SVD `V = U Σ Vᵀ`. For each of the first `n_components`
/// singular triplets `(s_j, u_j, v_j)`, it splits `u_j` and `v_j` into
/// positive and negative parts. It keeps the pair (`x`, `y`) whose norm
/// product `m = ‖x‖·‖y‖` is larger and sets
///
/// `W[:, j] = sqrt(s_j · m) · x / ‖x‖`, `H[j, :] = sqrt(s_j · m) · y / ‖y‖`.
///
/// If both parts of a triplet vanish (`m == 0`), that column/row stays zero.
///
/// # Errors
/// Returns `TransformError::LinalgError` if the SVD fails.
fn nndsvd_initialization(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
    let (n_samples, n_features) = v.dim();
    // NOTE(review): full matrices give an n_samples×n_samples `U`; only the
    // first `n_components` columns are used below.
    let (u, s, vt) = scirs2_linalg::svd::<f64>(&v.view(), true, None)
        .map_err(TransformError::LinalgError)?;
    let mut w = Array2::<f64>::zeros((n_samples, self.n_components));
    let mut h = Array2::<f64>::zeros((self.n_components, n_features));
    for j in 0..self.n_components {
        let x = u.column(j);
        let y = vt.row(j);
        let x_pos = x.mapv(|e| e.max(0.0));
        let x_neg = x.mapv(|e| (-e).max(0.0));
        let y_pos = y.mapv(|e| e.max(0.0));
        let y_neg = y.mapv(|e| (-e).max(0.0));
        let x_pos_norm = x_pos.dot(&x_pos).sqrt();
        let x_neg_norm = x_neg.dot(&x_neg).sqrt();
        let y_pos_norm = y_pos.dot(&y_pos).sqrt();
        let y_neg_norm = y_neg.dot(&y_neg).sqrt();
        let m_pos = x_pos_norm * y_pos_norm;
        let m_neg = x_neg_norm * y_neg_norm;
        let (xs, ys, x_norm, y_norm, sigma) = if m_pos > m_neg {
            (x_pos, y_pos, x_pos_norm, y_pos_norm, m_pos)
        } else {
            (x_neg, y_neg, x_neg_norm, y_neg_norm, m_neg)
        };
        // sigma > 0 guarantees both norms are non-zero, so no 0/0 below.
        if sigma > 0.0 {
            // The clamp guards against a tiny negative singular value from round-off.
            let lambda = (s[j].max(0.0) * sigma).sqrt();
            w.column_mut(j).assign(&xs.mapv(|e| lambda * e / x_norm));
            h.row_mut(j).assign(&ys.mapv(|e| lambda * e / y_norm));
        }
    }
    Ok((w, h))
}
/// Builds the starting `(W, H)` pair according to `init`.
///
/// `"nndsvd"` selects the SVD-based scheme. Every other value, including
/// `"random"`, uses random initialization.
///
/// # Errors
/// Propagates errors from the NNDSVD path.
fn initialize_matrices(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
    if self.init == "nndsvd" {
        return self.nndsvd_initialization(v);
    }
    Ok(self.random_initialization(v))
}
/// Frobenius norm of the residual, `‖V − W·H‖_F`.
fn frobenius_loss(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> f64 {
    let mut residual = v - &w.dot(h);
    // Square element-wise in place, then sum and take the root.
    residual.mapv_inplace(|r| r * r);
    residual.sum().sqrt()
}
/// One multiplicative update of `W` with `H` fixed, with optional
/// elastic-net regularization:
///
/// `W ← W ⊙ (V·Hᵀ) ⊘ (W·(H·Hᵀ) + l2·W + l1 + ε)`
///
/// where `l1 = alpha·l1_ratio` and `l2 = alpha·(1 − l1_ratio)` (each applied
/// only when `alpha > 0` and its share is positive). The result is floored
/// at `ε = 1e-10`, so entries never become exactly zero.
fn update_w(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
    let eps = 1e-10;
    let numerator = v.dot(&h.t());
    // W·(H·Hᵀ) equals (W·H)·Hᵀ but only needs the small k×k Gram matrix
    // instead of materialising the full n_samples×n_features reconstruction.
    let hht = h.dot(&h.t());
    let mut denominator = w.dot(&hht);
    if self.alpha > 0.0 && self.l1_ratio < 1.0 {
        let l2_reg = self.alpha * (1.0 - self.l1_ratio);
        denominator.scaled_add(l2_reg, w);
    }
    if self.alpha > 0.0 && self.l1_ratio > 0.0 {
        denominator += self.alpha * self.l1_ratio;
    }
    denominator += eps;
    let mut w_new = w * &(numerator / &denominator);
    w_new.mapv_inplace(|x| x.max(eps));
    w_new
}
/// One multiplicative update of `H` with `W` fixed, with optional
/// elastic-net regularization:
///
/// `H ← H ⊙ (Wᵀ·V) ⊘ ((Wᵀ·W)·H + l2·H + l1 + ε)`
///
/// where `l1 = alpha·l1_ratio` and `l2 = alpha·(1 − l1_ratio)` (each applied
/// only when `alpha > 0` and its share is positive). The result is floored
/// at `ε = 1e-10`, so entries never become exactly zero.
fn update_h(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
    let eps = 1e-10;
    let numerator = w.t().dot(v);
    // (Wᵀ·W)·H equals Wᵀ·(W·H) but avoids forming the n_samples×n_features
    // product; only a k×k Gram matrix is needed.
    let wtw = w.t().dot(w);
    let mut denominator = wtw.dot(h);
    if self.alpha > 0.0 && self.l1_ratio < 1.0 {
        let l2_reg = self.alpha * (1.0 - self.l1_ratio);
        denominator.scaled_add(l2_reg, h);
    }
    if self.alpha > 0.0 && self.l1_ratio > 0.0 {
        denominator += self.alpha * self.l1_ratio;
    }
    denominator += eps;
    let mut h_new = h * &(numerator / &denominator);
    h_new.mapv_inplace(|x| x.max(eps));
    h_new
}
fn update_w_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
let eps = 1e-10;
let (n_samples, n_components) = w.dim();
let mut w_new = w.clone();
let hht = h.dot(&h.t());
for i in 0..n_samples {
for j in 0..n_components {
let mut numerator = 0.0;
let mut denominator = hht[[j, j]];
for k in 0..h.ncols() {
numerator += v[[i, k]] * h[[j, k]];
}
for k in 0..n_components {
if k != j {
numerator -= w_new[[i, k]] * hht[[k, j]];
}
}
if self.alpha > 0.0 {
if self.l1_ratio > 0.0 {
let l1_penalty = self.alpha * self.l1_ratio;
numerator -= l1_penalty;
}
if self.l1_ratio < 1.0 {
let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
denominator += l2_penalty;
numerator -= l2_penalty * w_new[[i, j]];
}
}
let new_val = if denominator > eps {
(numerator / denominator).max(eps)
} else {
eps
};
w_new[[i, j]] = new_val;
}
}
w_new
}
fn update_h_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
let eps = 1e-10;
let (n_components, n_features) = h.dim();
let mut h_new = h.clone();
let wtw = w.t().dot(w);
for i in 0..n_components {
for j in 0..n_features {
let mut numerator = 0.0;
let mut denominator = wtw[[i, i]];
for k in 0..w.nrows() {
numerator += w[[k, i]] * v[[k, j]];
}
for k in 0..n_components {
if k != i {
numerator -= wtw[[i, k]] * h_new[[k, j]];
}
}
if self.alpha > 0.0 {
if self.l1_ratio > 0.0 {
let l1_penalty = self.alpha * self.l1_ratio;
numerator -= l1_penalty;
}
if self.l1_ratio < 1.0 {
let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
denominator += l2_penalty;
numerator -= l2_penalty * h_new[[i, j]];
}
}
let new_val = if denominator > eps {
(numerator / denominator).max(eps)
} else {
eps
};
h_new[[i, j]] = new_val;
}
}
h_new
}
pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
where
S: Data,
S::Elem: Float + NumCast,
{
for elem in x.iter() {
let val = NumCast::from(*elem).unwrap_or(0.0);
if val < 0.0 {
return Err(TransformError::InvalidInput(
"NMF requires non-negative input data".to_string(),
));
}
}
let v = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
if self.n_components > n_features.min(n_samples) {
return Err(TransformError::InvalidInput(format!(
"n_components={} must be <= min(n_samples={}, n_features={})",
self.n_components, n_samples, n_features
)));
}
let (mut w, mut h) = self.initialize_matrices(&v)?;
let mut prev_error = self.frobenius_loss(&v, &w, &h);
let mut n_iter = 0;
for iter in 0..self.max_iter {
if self.solver == "mu" {
h = self.update_h(&v, &w, &h);
w = self.update_w(&v, &w, &h);
} else if self.solver == "cd" {
h = self.update_h_cd(&v, &w, &h);
w = self.update_w_cd(&v, &w, &h);
} else {
return Err(TransformError::InvalidInput(format!(
"Unknown solver '{}'. Supported solvers: 'mu', 'cd'",
self.solver
)));
}
let error = self.frobenius_loss(&v, &w, &h);
if (prev_error - error).abs() / prev_error.max(1e-10) < self.tol {
n_iter = iter + 1;
break;
}
prev_error = error;
n_iter = iter + 1;
}
self.components = Some(h);
self.coefficients = Some(w);
self.reconstruction_err = Some(prev_error);
self.n_iter = Some(n_iter);
Ok(())
}
/// Projects new samples onto the fitted components.
///
/// Returns `W` (n_samples × n_components) such that `x ≈ W·H` for the stored
/// `H`. `W` starts from `random_initialization` and is refined with
/// `max_iter` multiplicative updates while `H` stays fixed. An input with no
/// elements yields an all-zero `W` of shape (n_samples, n_components).
///
/// # Errors
/// * `TransformationError` if the model has not been fitted.
/// * `InvalidInput` if `x` contains negative or non-finite values, or if its
///   number of columns differs from the number of features seen by `fit`.
pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
where
    S: Data,
    S::Elem: Float + NumCast,
{
    let h = self.components.as_ref().ok_or_else(|| {
        TransformError::TransformationError("NMF model has not been fitted".to_string())
    })?;
    for elem in x.iter() {
        let val: f64 = NumCast::from(*elem).unwrap_or(0.0);
        if val < 0.0 {
            return Err(TransformError::InvalidInput(
                "NMF requires non-negative input data".to_string(),
            ));
        }
        if !val.is_finite() {
            return Err(TransformError::InvalidInput(
                "NMF requires finite input data".to_string(),
            ));
        }
    }
    let v: Array2<f64> = x.mapv(|e| NumCast::from(e).unwrap_or(0.0));
    let (n_samples, n_features) = v.dim();
    // Without this check the first `dot` inside `update_w` would panic.
    if n_features != h.ncols() {
        return Err(TransformError::InvalidInput(format!(
            "X has {} features, but NMF was fitted with {} features",
            n_features,
            h.ncols()
        )));
    }
    // Nothing to project; also avoids taking the mean of an empty array.
    if v.is_empty() {
        return Ok(Array2::zeros((n_samples, h.nrows())));
    }
    // Reuse the fit-time random scheme (and its seeding) for the starting W;
    // the freshly drawn H is not needed because the fitted H is kept fixed.
    let (mut w, _) = self.random_initialization(&v);
    for _ in 0..self.max_iter {
        w = self.update_w(&v, &w, h);
    }
    Ok(w)
}
/// Fits the model on `x` and returns a copy of the learned coefficient
/// matrix `W` (the same matrix later exposed by `coefficients()`).
///
/// # Errors
/// Propagates any error returned by `fit`.
pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
where
    S: Data,
    S::Elem: Float + NumCast,
{
    self.fit(x)?;
    // A successful `fit` always stores the coefficients.
    let coefficients = self.coefficients.as_ref().expect("Operation failed");
    Ok(coefficients.clone())
}
/// Fitted component matrix `H` (n_components × n_features), or `None`
/// before a successful `fit`.
pub fn components(&self) -> Option<&Array2<f64>> {
    Option::as_ref(&self.components)
}
/// Fitted coefficient matrix `W` (n_samples × n_components), or `None`
/// before a successful `fit`.
pub fn coefficients(&self) -> Option<&Array2<f64>> {
    Option::as_ref(&self.coefficients)
}
/// Frobenius reconstruction error recorded by the last `fit`, or `None`
/// if the model has not been fitted.
pub fn reconstruction_error(&self) -> Option<f64> {
    self.reconstruction_err
}
/// Number of iterations run by the last `fit`, or `None` if the model has
/// not been fitted.
pub fn n_iterations(&self) -> Option<usize> {
    self.n_iter
}
/// Maps coefficients back to data space, returning `W · H` for the fitted `H`.
///
/// # Errors
/// * `TransformationError` if the model has not been fitted.
/// * `InvalidInput` if `w` does not have exactly one column per fitted
///   component (previously this panicked inside `dot`).
pub fn inverse_transform(&self, w: &Array2<f64>) -> Result<Array2<f64>> {
    let h = self.components.as_ref().ok_or_else(|| {
        TransformError::TransformationError("NMF model has not been fitted".to_string())
    })?;
    if w.ncols() != h.nrows() {
        return Err(TransformError::InvalidInput(format!(
            "W has {} columns, but NMF was fitted with {} components",
            w.ncols(),
            h.nrows()
        )));
    }
    Ok(w.dot(h))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Small strictly positive 5×4 matrix without the rank-1 structure of the
    /// other fixtures.
    fn generic_positive_matrix() -> Array2<f64> {
        Array::from_shape_vec(
            (5, 4),
            vec![
                1.0, 2.0, 0.5, 3.0, 2.0, 1.0, 1.0, 0.5, 0.3, 4.0, 2.0, 1.0, 3.0, 0.2, 1.0, 2.0, 1.0,
                1.0, 5.0, 0.4,
            ],
        )
        .expect("Operation failed")
    }

    #[test]
    fn test_nmf_basic() {
        let x = Array::from_shape_vec(
            (6, 4),
            vec![
                1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, 4.0, 8.0, 12.0, 16.0,
                5.0, 10.0, 15.0, 20.0, 6.0, 12.0, 18.0, 24.0,
            ],
        )
        .expect("Operation failed");
        let mut nmf = NMF::new(2).with_max_iter(100).with_random_state(42);
        let w = nmf.fit_transform(&x).expect("Operation failed");
        assert_eq!(w.shape(), &[6, 2]);
        for val in w.iter() {
            assert!(*val >= 0.0);
        }
        let h = nmf.components().expect("Operation failed");
        assert_eq!(h.shape(), &[2, 4]);
        for val in h.iter() {
            assert!(*val >= 0.0);
        }
        let x_reconstructed = nmf.inverse_transform(&w).expect("Operation failed");
        assert_eq!(x_reconstructed.shape(), x.shape());
    }

    #[test]
    fn test_nmf_regularization() {
        let x = Array2::<f64>::eye(10) + 0.1;
        let mut nmf = NMF::new(3).with_regularization(0.1, 0.5).with_max_iter(50);
        let result = nmf.fit_transform(&x);
        assert!(result.is_ok());
        let w = result.expect("Operation failed");
        assert_eq!(w.shape(), &[10, 3]);
    }

    #[test]
    fn test_nmf_negative_input() {
        let x = Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, -1.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .expect("Operation failed");
        let mut nmf = NMF::new(2);
        let result = nmf.fit(&x);
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e
                .to_string()
                .contains("NMF requires non-negative input data"));
        }
    }

    #[test]
    fn test_nmf_coordinate_descent() {
        let x = Array::from_shape_vec(
            (6, 4),
            vec![
                1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, 4.0, 8.0, 12.0, 16.0,
                5.0, 10.0, 15.0, 20.0, 6.0, 12.0, 18.0, 24.0,
            ],
        )
        .expect("Operation failed");
        let mut nmf_cd = NMF::new(2)
            .with_solver("cd")
            .with_max_iter(100)
            .with_random_state(42);
        let w_cd = nmf_cd.fit_transform(&x).expect("Operation failed");
        assert_eq!(w_cd.shape(), &[6, 2]);
        for val in w_cd.iter() {
            assert!(*val >= 0.0);
        }
        let h_cd = nmf_cd.components().expect("Operation failed");
        assert_eq!(h_cd.shape(), &[2, 4]);
        for val in h_cd.iter() {
            assert!(*val >= 0.0);
        }
        let x_reconstructed = nmf_cd.inverse_transform(&w_cd).expect("Operation failed");
        assert_eq!(x_reconstructed.shape(), x.shape());
        let mut nmf_mu = NMF::new(2)
            .with_solver("mu")
            .with_max_iter(100)
            .with_random_state(42);
        let _w_mu = nmf_mu.fit_transform(&x).expect("Operation failed");
        assert!(nmf_cd.reconstruction_error().expect("Operation failed") >= 0.0);
        assert!(nmf_mu.reconstruction_error().expect("Operation failed") >= 0.0);
    }

    #[test]
    fn test_nmf_invalid_solver() {
        let x = Array2::<f64>::eye(3) + 0.1;
        let mut nmf = NMF::new(2).with_solver("invalid");
        let result = nmf.fit(&x);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unknown solver"));
    }

    #[test]
    fn test_nmf_nndsvd_init() {
        let x = generic_positive_matrix();
        let mut nmf = NMF::new(3).with_init("nndsvd").with_max_iter(50);
        let w = nmf.fit_transform(&x).expect("Operation failed");
        assert_eq!(w.shape(), &[5, 3]);
        assert!(w.iter().all(|&val| val >= 0.0 && val.is_finite()));
        let h = nmf.components().expect("Operation failed");
        assert_eq!(h.shape(), &[3, 4]);
        assert!(h.iter().all(|&val| val >= 0.0 && val.is_finite()));
        assert!(nmf.n_iterations().expect("Operation failed") <= 50);
    }

    #[test]
    fn test_nmf_transform() {
        let x = generic_positive_matrix();
        let mut nmf = NMF::new(2).with_max_iter(50).with_random_state(7);
        nmf.fit(&x).expect("Operation failed");

        let w = nmf.transform(&x).expect("Operation failed");
        assert_eq!(w.shape(), &[5, 2]);
        assert!(w.iter().all(|&val| val >= 0.0 && val.is_finite()));

        let single = Array::from_shape_vec((1, 4), vec![1.0, 1.0, 1.0, 1.0])
            .expect("Operation failed");
        let w_single = nmf.transform(&single).expect("Operation failed");
        assert_eq!(w_single.shape(), &[1, 2]);
    }

    #[test]
    fn test_nmf_unfitted_model() {
        let x = Array2::<f64>::eye(3);
        let nmf = NMF::new(2);
        assert!(nmf.transform(&x).is_err());
        assert!(nmf.inverse_transform(&x).is_err());
        assert!(nmf.components().is_none());
        assert!(nmf.coefficients().is_none());
        assert!(nmf.reconstruction_error().is_none());
        assert!(nmf.n_iterations().is_none());
    }

    #[test]
    fn test_nmf_reported_error_matches_factors() {
        let x = generic_positive_matrix();
        let mut nmf = NMF::new(2).with_max_iter(200);
        let w = nmf.fit_transform(&x).expect("Operation failed");
        let h = nmf.components().expect("Operation failed");
        let actual = (&x - &w.dot(h)).mapv(|d| d * d).sum().sqrt();
        let reported = nmf.reconstruction_error().expect("Operation failed");
        // Early stopping only triggers once the relative change is below `tol`
        // (1e-4 by default), so the two values must agree well within 1e-3.
        assert!((actual - reported).abs() <= 1e-3 * actual.max(1e-12));
    }
}