1use anofox_ml_core::{FitUnsupervised, Result, RustMlError};
7use faer::linalg::solvers::Svd;
8use faer::Mat;
9use ndarray::Array2;
10use rand::rngs::StdRng;
11use rand::{Rng, SeedableRng};
12
/// Strategy used to build the starting factors before the iterative updates run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NmfInit {
    /// Fill both factors with seeded pseudo-random values.
    Random,
    /// Derive both factors from the singular value decomposition of the input.
    Nndsvd,
}
20
/// Configuration for a non-negative matrix factorization run.
#[derive(Debug, Clone)]
pub struct Nmf {
    /// Number of components `k`; `fit` rejects 0 and values above `min(n_rows, n_cols)`.
    pub n_components: usize,
    /// Upper bound on the number of update iterations performed by `fit`.
    pub max_iter: usize,
    /// Relative change in reconstruction error below which `fit` stops early.
    pub tol: f64,
    /// Seed for the random number generator used by `NmfInit::Random`.
    pub seed: u64,
    /// Strategy used to build the starting factors.
    pub init: NmfInit,
}
29
30impl Nmf {
31 pub fn new(n_components: usize) -> Self {
32 Self {
33 n_components,
34 max_iter: 200,
35 tol: 1e-4,
36 seed: 0,
37 init: NmfInit::Nndsvd,
38 }
39 }
40 pub fn with_init(mut self, init: NmfInit) -> Self {
41 self.init = init;
42 self
43 }
44}
45
46fn nndsvd_init(x: &Array2<f64>, k: usize) -> Result<(Array2<f64>, Array2<f64>)> {
54 let n = x.nrows();
55 let d = x.ncols();
56 let mat = Mat::<f64>::from_fn(n, d, |i, j| x[[i, j]]);
57 let svd = Svd::new(mat.as_ref())
58 .map_err(|e| RustMlError::InvalidParameter(format!("NNDSVD SVD failed: {e:?}")))?;
59 let u = svd.U();
60 let s = svd.S();
61 let v = svd.V();
62 let r = s.column_vector().nrows().min(k);
63
64 let mut w = Array2::<f64>::zeros((n, k));
65 let mut h = Array2::<f64>::zeros((k, d));
66
67 let s0 = s.column_vector()[0].max(1e-12);
69 let mut u0_pos_norm = 0.0_f64;
70 for i in 0..n {
71 u0_pos_norm += u[(i, 0)].max(0.0).powi(2);
72 }
73 u0_pos_norm = u0_pos_norm.sqrt();
74 let mut v0_pos_norm = 0.0_f64;
75 for j in 0..d {
76 v0_pos_norm += v[(j, 0)].max(0.0).powi(2);
77 }
78 v0_pos_norm = v0_pos_norm.sqrt();
79 let (u_sign, v_sign) =
81 if u0_pos_norm * v0_pos_norm >= (u0_pos_norm * v0_pos_norm).max(1e-12) / 2.0 {
82 (1.0, 1.0)
83 } else {
84 (-1.0, -1.0)
85 };
86 let lead_scale = s0.sqrt();
87 for i in 0..n {
88 w[[i, 0]] = (u_sign * u[(i, 0)]).max(0.0) * lead_scale;
89 }
90 for j in 0..d {
91 h[[0, j]] = (v_sign * v[(j, 0)]).max(0.0) * lead_scale;
92 }
93
94 for c in 1..r {
96 let sigma = s.column_vector()[c].max(1e-12);
97 let mut up = vec![0.0_f64; n];
99 let mut un = vec![0.0_f64; n];
100 let mut up_norm = 0.0_f64;
101 let mut un_norm = 0.0_f64;
102 for i in 0..n {
103 let val = u[(i, c)];
104 if val > 0.0 {
105 up[i] = val;
106 up_norm += val * val;
107 } else {
108 un[i] = -val;
109 un_norm += val * val;
110 }
111 }
112 up_norm = up_norm.sqrt();
113 un_norm = un_norm.sqrt();
114 let mut vp = vec![0.0_f64; d];
115 let mut vn = vec![0.0_f64; d];
116 let mut vp_norm = 0.0_f64;
117 let mut vn_norm = 0.0_f64;
118 for j in 0..d {
119 let val = v[(j, c)];
120 if val > 0.0 {
121 vp[j] = val;
122 vp_norm += val * val;
123 } else {
124 vn[j] = -val;
125 vn_norm += val * val;
126 }
127 }
128 vp_norm = vp_norm.sqrt();
129 vn_norm = vn_norm.sqrt();
130 let pos = up_norm * vp_norm;
133 let neg = un_norm * vn_norm;
134 let scale = sigma.sqrt() * (pos.max(neg)).sqrt();
135 if pos >= neg {
136 let nrm_u = up_norm.max(1e-12);
137 let nrm_v = vp_norm.max(1e-12);
138 for i in 0..n {
139 w[[i, c]] = up[i] / nrm_u * scale;
140 }
141 for j in 0..d {
142 h[[c, j]] = vp[j] / nrm_v * scale;
143 }
144 } else {
145 let nrm_u = un_norm.max(1e-12);
146 let nrm_v = vn_norm.max(1e-12);
147 for i in 0..n {
148 w[[i, c]] = un[i] / nrm_u * scale;
149 }
150 for j in 0..d {
151 h[[c, j]] = vn[j] / nrm_v * scale;
152 }
153 }
154 }
155 let eps = 1e-6;
157 for v in w.iter_mut() {
158 if *v < eps {
159 *v = eps;
160 }
161 }
162 for v in h.iter_mut() {
163 if *v < eps {
164 *v = eps;
165 }
166 }
167 Ok((w, h))
168}
169
/// Outcome of fitting an `Nmf` model.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedNmf {
    /// The learned `H` factor, with one row per component and one column per input column.
    pub components: Array2<f64>,
    /// Reconstruction error reported by `fit`: the Frobenius norm of `X - W * H`.
    pub reconstruction_err: f64,
    /// Number of update iterations `fit` performed.
    pub n_iter: usize,
}
178
179impl FitUnsupervised<f64> for Nmf {
180 type Fitted = FittedNmf;
181
182 fn fit(&self, x: &Array2<f64>) -> Result<Self::Fitted> {
183 let n = x.nrows();
184 let d = x.ncols();
185 let k = self.n_components;
186 if n == 0 || d == 0 {
187 return Err(RustMlError::EmptyInput("empty input".into()));
188 }
189 if k == 0 || k > d.min(n) {
190 return Err(RustMlError::InvalidParameter(format!(
191 "n_components must be in 1..={}",
192 d.min(n)
193 )));
194 }
195 for v in x.iter() {
197 if *v < 0.0 {
198 return Err(RustMlError::InvalidParameter("NMF requires X >= 0".into()));
199 }
200 }
201
202 let (mut w, mut h) = match self.init {
203 NmfInit::Nndsvd => nndsvd_init(x, k)?,
204 NmfInit::Random => {
205 let mut rng = StdRng::seed_from_u64(self.seed);
206 let scale = (x.mean().unwrap_or(0.0).max(0.0) / k as f64)
207 .sqrt()
208 .max(1e-6);
209 let w = Array2::<f64>::from_shape_fn((n, k), |_| rng.gen::<f64>() * scale + 1e-6);
210 let h = Array2::<f64>::from_shape_fn((k, d), |_| rng.gen::<f64>() * scale + 1e-6);
211 (w, h)
212 }
213 };
214
215 let mut prev_err = f64::INFINITY;
216 let mut n_iter = 0;
217 for iter in 0..self.max_iter {
218 n_iter = iter + 1;
219
220 let wt_x = w.t().dot(x);
222 let wt_w = w.t().dot(&w);
223 let wt_w_h = wt_w.dot(&h);
224 for a in 0..k {
225 for b in 0..d {
226 h[[a, b]] *= wt_x[[a, b]] / wt_w_h[[a, b]].max(1e-12);
227 }
228 }
229 let h_ht = h.dot(&h.t());
231 let x_ht = x.dot(&h.t());
232 let w_h_ht = w.dot(&h_ht);
233 for r in 0..n {
234 for a in 0..k {
235 w[[r, a]] *= x_ht[[r, a]] / w_h_ht[[r, a]].max(1e-12);
236 }
237 }
238
239 let recon = w.dot(&h);
241 let mut err = 0.0;
242 for r in 0..n {
243 for c in 0..d {
244 let dv = x[[r, c]] - recon[[r, c]];
245 err += dv * dv;
246 }
247 }
248 err = err.sqrt();
249 if (prev_err - err).abs() / prev_err.max(1e-12) < self.tol {
250 prev_err = err;
251 break;
252 }
253 prev_err = err;
254 }
255
256 Ok(FittedNmf {
257 components: h,
258 reconstruction_err: prev_err,
259 n_iter,
260 })
261 }
262}
263
264impl FittedNmf {
265 pub fn transform(&self, x: &Array2<f64>, max_iter: usize) -> Result<Array2<f64>> {
267 let h = &self.components;
268 let n = x.nrows();
269 let k = h.nrows();
270 let mut rng = StdRng::seed_from_u64(7);
271 let scale = (x.mean().unwrap_or(0.0).max(0.0) / k as f64)
272 .sqrt()
273 .max(1e-6);
274 let mut w = Array2::<f64>::from_shape_fn((n, k), |_| rng.gen::<f64>() * scale + 1e-6);
275 let h_ht = h.dot(&h.t());
276 let x_ht = x.dot(&h.t());
277 for _ in 0..max_iter {
278 let w_h_ht = w.dot(&h_ht);
279 for r in 0..n {
280 for a in 0..k {
281 w[[r, a]] *= x_ht[[r, a]] / w_h_ht[[r, a]].max(1e-12);
282 }
283 }
284 }
285 Ok(w)
286 }
287
288 pub fn reconstruction_err(&self) -> f64 {
289 self.reconstruction_err
290 }
291 pub fn n_iter(&self) -> usize {
292 self.n_iter
293 }
294 pub fn components(&self) -> &Array2<f64> {
295 &self.components
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// An exactly rank-2 non-negative matrix must be reconstructed to within 5% relative
    /// Frobenius error from the fitted components and the transformed activations.
    #[test]
    fn test_nmf_recovers_low_rank() {
        let w_true = array![[1.0_f64, 0.0], [2.0, 0.5], [0.0, 1.0], [0.3, 2.0]];
        let h_true = array![[1.0_f64, 2.0, 3.0], [0.5, 1.5, 0.5]];
        let x = w_true.dot(&h_true);

        let fitted = Nmf::new(2).fit(&x).unwrap();
        let w = fitted.transform(&x, 200).unwrap();
        let approx = w.dot(&fitted.components);

        let err = x
            .iter()
            .zip(approx.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>();
        let rel = err.sqrt() / x.iter().map(|v| v * v).sum::<f64>().sqrt();
        assert!(rel < 0.05, "rel reconstruction error too large: {rel}");
    }

    /// Inputs with a negative entry are outside the model's domain and must be rejected.
    #[test]
    fn test_nmf_rejects_negative_input() {
        let x = array![[1.0_f64, -1.0], [0.5, 2.0]];
        assert!(Nmf::new(1).fit(&x).is_err());
    }
}