Skip to main content

anofox_ml_preprocessing/
nmf.rs

//! Non-negative Matrix Factorisation.
//!
//! Mirrors `sklearn.decomposition.NMF` with the multiplicative-update solver
//! (Lee & Seung): finds `W ≥ 0` and `H ≥ 0` such that `X ≈ W H`.
5
6use anofox_ml_core::{FitUnsupervised, Result, RustMlError};
7use faer::linalg::solvers::Svd;
8use faer::Mat;
9use ndarray::Array2;
10use rand::rngs::StdRng;
11use rand::{Rng, SeedableRng};
12
/// Strategy used to seed the factor matrices `W` and `H` before the
/// multiplicative updates start.
///
/// The enum is fieldless, so it is `Copy` and also derives `Eq` / `Hash`,
/// which lets it be used as a map key or inside hashed configuration sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NmfInit {
    /// Draw every entry of `W` and `H` from a scaled uniform distribution,
    /// using an RNG seeded with `Nmf::seed`.
    Random,
    /// NNDSVD: deterministic initialisation derived from the SVD of the input.
    /// This is the strategy selected by `Nmf::new`.
    Nndsvd,
}
20
21#[derive(Debug, Clone)]
22pub struct Nmf {
23    pub n_components: usize,
24    pub max_iter: usize,
25    pub tol: f64,
26    pub seed: u64,
27    pub init: NmfInit,
28}
29
30impl Nmf {
31    pub fn new(n_components: usize) -> Self {
32        Self {
33            n_components,
34            max_iter: 200,
35            tol: 1e-4,
36            seed: 0,
37            init: NmfInit::Nndsvd,
38        }
39    }
40    pub fn with_init(mut self, init: NmfInit) -> Self {
41        self.init = init;
42        self
43    }
44}
45
46/// NNDSVD initialisation (Boutsidis & Gallopoulos 2008).
47///
48/// 1. Compute the truncated SVD `X ≈ U Σ Vᵀ` keeping `k` triplets.
49/// 2. The first component is initialised from the leading singular triplet
50///    with positive sign (sign-fix).
51/// 3. Each subsequent component splits the next singular vectors into their
52///    positive and negative parts and picks whichever has higher norm.
53fn nndsvd_init(x: &Array2<f64>, k: usize) -> Result<(Array2<f64>, Array2<f64>)> {
54    let n = x.nrows();
55    let d = x.ncols();
56    let mat = Mat::<f64>::from_fn(n, d, |i, j| x[[i, j]]);
57    let svd = Svd::new(mat.as_ref())
58        .map_err(|e| RustMlError::InvalidParameter(format!("NNDSVD SVD failed: {e:?}")))?;
59    let u = svd.U();
60    let s = svd.S();
61    let v = svd.V();
62    let r = s.column_vector().nrows().min(k);
63
64    let mut w = Array2::<f64>::zeros((n, k));
65    let mut h = Array2::<f64>::zeros((k, d));
66
67    // First component: leading singular triplet (sign-fixed positive).
68    let s0 = s.column_vector()[0].max(1e-12);
69    let mut u0_pos_norm = 0.0_f64;
70    for i in 0..n {
71        u0_pos_norm += u[(i, 0)].max(0.0).powi(2);
72    }
73    u0_pos_norm = u0_pos_norm.sqrt();
74    let mut v0_pos_norm = 0.0_f64;
75    for j in 0..d {
76        v0_pos_norm += v[(j, 0)].max(0.0).powi(2);
77    }
78    v0_pos_norm = v0_pos_norm.sqrt();
79    // If positive part norm is larger, use it; else flip sign and use negative.
80    let (u_sign, v_sign) =
81        if u0_pos_norm * v0_pos_norm >= (u0_pos_norm * v0_pos_norm).max(1e-12) / 2.0 {
82            (1.0, 1.0)
83        } else {
84            (-1.0, -1.0)
85        };
86    let lead_scale = s0.sqrt();
87    for i in 0..n {
88        w[[i, 0]] = (u_sign * u[(i, 0)]).max(0.0) * lead_scale;
89    }
90    for j in 0..d {
91        h[[0, j]] = (v_sign * v[(j, 0)]).max(0.0) * lead_scale;
92    }
93
94    // Remaining components: split positive and negative parts.
95    for c in 1..r {
96        let sigma = s.column_vector()[c].max(1e-12);
97        // u positive / negative parts.
98        let mut up = vec![0.0_f64; n];
99        let mut un = vec![0.0_f64; n];
100        let mut up_norm = 0.0_f64;
101        let mut un_norm = 0.0_f64;
102        for i in 0..n {
103            let val = u[(i, c)];
104            if val > 0.0 {
105                up[i] = val;
106                up_norm += val * val;
107            } else {
108                un[i] = -val;
109                un_norm += val * val;
110            }
111        }
112        up_norm = up_norm.sqrt();
113        un_norm = un_norm.sqrt();
114        let mut vp = vec![0.0_f64; d];
115        let mut vn = vec![0.0_f64; d];
116        let mut vp_norm = 0.0_f64;
117        let mut vn_norm = 0.0_f64;
118        for j in 0..d {
119            let val = v[(j, c)];
120            if val > 0.0 {
121                vp[j] = val;
122                vp_norm += val * val;
123            } else {
124                vn[j] = -val;
125                vn_norm += val * val;
126            }
127        }
128        vp_norm = vp_norm.sqrt();
129        vn_norm = vn_norm.sqrt();
130        // Take whichever pair (positive/positive vs negative/negative) has
131        // higher Frobenius product norm.
132        let pos = up_norm * vp_norm;
133        let neg = un_norm * vn_norm;
134        let scale = sigma.sqrt() * (pos.max(neg)).sqrt();
135        if pos >= neg {
136            let nrm_u = up_norm.max(1e-12);
137            let nrm_v = vp_norm.max(1e-12);
138            for i in 0..n {
139                w[[i, c]] = up[i] / nrm_u * scale;
140            }
141            for j in 0..d {
142                h[[c, j]] = vp[j] / nrm_v * scale;
143            }
144        } else {
145            let nrm_u = un_norm.max(1e-12);
146            let nrm_v = vn_norm.max(1e-12);
147            for i in 0..n {
148                w[[i, c]] = un[i] / nrm_u * scale;
149            }
150            for j in 0..d {
151                h[[c, j]] = vn[j] / nrm_v * scale;
152            }
153        }
154    }
155    // Floor at a small epsilon (sklearn convention) to avoid zero-locks.
156    let eps = 1e-6;
157    for v in w.iter_mut() {
158        if *v < eps {
159            *v = eps;
160        }
161    }
162    for v in h.iter_mut() {
163        if *v < eps {
164            *v = eps;
165        }
166    }
167    Ok((w, h))
168}
169
170#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
171pub struct FittedNmf {
172    /// Components, shape (n_components, n_features) — sklearn's H_.
173    pub components: Array2<f64>,
174    /// Final reconstruction error (Frobenius).
175    pub reconstruction_err: f64,
176    pub n_iter: usize,
177}
178
179impl FitUnsupervised<f64> for Nmf {
180    type Fitted = FittedNmf;
181
182    fn fit(&self, x: &Array2<f64>) -> Result<Self::Fitted> {
183        let n = x.nrows();
184        let d = x.ncols();
185        let k = self.n_components;
186        if n == 0 || d == 0 {
187            return Err(RustMlError::EmptyInput("empty input".into()));
188        }
189        if k == 0 || k > d.min(n) {
190            return Err(RustMlError::InvalidParameter(format!(
191                "n_components must be in 1..={}",
192                d.min(n)
193            )));
194        }
195        // Require X ≥ 0.
196        for v in x.iter() {
197            if *v < 0.0 {
198                return Err(RustMlError::InvalidParameter("NMF requires X >= 0".into()));
199            }
200        }
201
202        let (mut w, mut h) = match self.init {
203            NmfInit::Nndsvd => nndsvd_init(x, k)?,
204            NmfInit::Random => {
205                let mut rng = StdRng::seed_from_u64(self.seed);
206                let scale = (x.mean().unwrap_or(0.0).max(0.0) / k as f64)
207                    .sqrt()
208                    .max(1e-6);
209                let w = Array2::<f64>::from_shape_fn((n, k), |_| rng.gen::<f64>() * scale + 1e-6);
210                let h = Array2::<f64>::from_shape_fn((k, d), |_| rng.gen::<f64>() * scale + 1e-6);
211                (w, h)
212            }
213        };
214
215        let mut prev_err = f64::INFINITY;
216        let mut n_iter = 0;
217        for iter in 0..self.max_iter {
218            n_iter = iter + 1;
219
220            // H update: H *= (W'X) / (W'W H)
221            let wt_x = w.t().dot(x);
222            let wt_w = w.t().dot(&w);
223            let wt_w_h = wt_w.dot(&h);
224            for a in 0..k {
225                for b in 0..d {
226                    h[[a, b]] *= wt_x[[a, b]] / wt_w_h[[a, b]].max(1e-12);
227                }
228            }
229            // W update: W *= (X H') / (W H H')
230            let h_ht = h.dot(&h.t());
231            let x_ht = x.dot(&h.t());
232            let w_h_ht = w.dot(&h_ht);
233            for r in 0..n {
234                for a in 0..k {
235                    w[[r, a]] *= x_ht[[r, a]] / w_h_ht[[r, a]].max(1e-12);
236                }
237            }
238
239            // Convergence check via reconstruction error.
240            let recon = w.dot(&h);
241            let mut err = 0.0;
242            for r in 0..n {
243                for c in 0..d {
244                    let dv = x[[r, c]] - recon[[r, c]];
245                    err += dv * dv;
246                }
247            }
248            err = err.sqrt();
249            if (prev_err - err).abs() / prev_err.max(1e-12) < self.tol {
250                prev_err = err;
251                break;
252            }
253            prev_err = err;
254        }
255
256        Ok(FittedNmf {
257            components: h,
258            reconstruction_err: prev_err,
259            n_iter,
260        })
261    }
262}
263
264impl FittedNmf {
265    /// Transform new data by solving `min_W >= 0  ||X - W H||²` via MU.
266    pub fn transform(&self, x: &Array2<f64>, max_iter: usize) -> Result<Array2<f64>> {
267        let h = &self.components;
268        let n = x.nrows();
269        let k = h.nrows();
270        let mut rng = StdRng::seed_from_u64(7);
271        let scale = (x.mean().unwrap_or(0.0).max(0.0) / k as f64)
272            .sqrt()
273            .max(1e-6);
274        let mut w = Array2::<f64>::from_shape_fn((n, k), |_| rng.gen::<f64>() * scale + 1e-6);
275        let h_ht = h.dot(&h.t());
276        let x_ht = x.dot(&h.t());
277        for _ in 0..max_iter {
278            let w_h_ht = w.dot(&h_ht);
279            for r in 0..n {
280                for a in 0..k {
281                    w[[r, a]] *= x_ht[[r, a]] / w_h_ht[[r, a]].max(1e-12);
282                }
283            }
284        }
285        Ok(w)
286    }
287
288    pub fn reconstruction_err(&self) -> f64 {
289        self.reconstruction_err
290    }
291    pub fn n_iter(&self) -> usize {
292        self.n_iter
293    }
294    pub fn components(&self) -> &Array2<f64> {
295        &self.components
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Relative Frobenius error `||x - approx|| / ||x||`.
    fn relative_error(x: &Array2<f64>, approx: &Array2<f64>) -> f64 {
        let residual_sq: f64 = x
            .iter()
            .zip(approx.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        let total_sq: f64 = x.iter().map(|v| v * v).sum();
        residual_sq.sqrt() / total_sq.sqrt()
    }

    #[test]
    fn test_nmf_recovers_low_rank() {
        // Construct X = W_true H_true with k=2.
        let w_true = array![[1.0_f64, 0.0], [2.0, 0.5], [0.0, 1.0], [0.3, 2.0]];
        let h_true = array![[1.0_f64, 2.0, 3.0], [0.5, 1.5, 0.5]];
        let x = w_true.dot(&h_true);

        let fitted = Nmf::new(2).fit(&x).unwrap();
        // Re-deriving W for the training data against the learned H should
        // give a product close to X.
        let w = fitted.transform(&x, 200).unwrap();
        let approx = w.dot(&fitted.components);

        let rel = relative_error(&x, &approx);
        assert!(rel < 0.05, "rel reconstruction error too large: {rel}");
    }

    #[test]
    fn test_fit_rejects_negative_input() {
        let x = array![[1.0_f64, -1.0], [0.5, 2.0]];
        assert!(Nmf::new(1).fit(&x).is_err());
    }

    #[test]
    fn test_fit_rejects_invalid_n_components() {
        let x = array![[1.0_f64, 2.0], [0.5, 2.0]];
        // Zero components and more components than min(n, d) are both invalid.
        assert!(Nmf::new(0).fit(&x).is_err());
        assert!(Nmf::new(3).fit(&x).is_err());
    }

    #[test]
    fn test_fit_rejects_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));
        assert!(Nmf::new(1).fit(&x).is_err());
    }
}