nnmf_nalgebra/
lib.rs

1use std::{
2    iter::Sum,
3    ops::{Div, Sub},
4};
5
6use nalgebra::{
7    allocator::Allocator, ComplexField, Const, DMatrix, DefaultAllocator, Dim, Dyn, Matrix,
8    SMatrix, Scalar,
9};
10use rand::{distributions::Standard, prelude::Distribution};
11
12/// Does non-negative matrix factorization using multiplicative static update rule as defined on the [wikipedia](https://en.wikipedia.org/wiki/Non-negative_matrix_factorization#Algorithms) page. Is generic over the allocation strategy of the Matrix.
13pub fn non_negative_matrix_factorization_generic<T, R: Dim, C: Dim, K: Dim>(
14    matrix: &Matrix<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer>,
15    max_iter: usize,
16    tolerance: T,
17    nrows: R,
18    ncols: C,
19    k: K,
20) -> (
21    Matrix<T, R, K, <DefaultAllocator as Allocator<T, R, K>>::Buffer>,
22    Matrix<T, K, C, <DefaultAllocator as Allocator<T, K, C>>::Buffer>,
23)
24where
25    T: Scalar
26        + ComplexField<RealField = T>
27        + Sub<T>
28        + Clone
29        + Copy
30        + Sum<T>
31        + PartialOrd
32        + Div<T, Output = T>,
33    Standard: Distribution<T>,
34    DefaultAllocator: Allocator<T, R, C>,
35    DefaultAllocator: Allocator<T, R, K>,
36    DefaultAllocator: Allocator<T, K, C>,
37    DefaultAllocator: Allocator<T, K, R>,
38    DefaultAllocator: Allocator<T, C, K>,
39{
40    // Two reduced-dimension vectors we are trying to calculate, each field is initialized to [0, 1)
41    let mut w: Matrix<T, R, K, _> = Matrix::new_random_generic(nrows, k);
42    let mut h: Matrix<T, K, C, _> = Matrix::new_random_generic(k, ncols);
43
44    let mut w_transpose: Matrix<T, K, R, _> = Matrix::zeros_generic(k, nrows);
45    let mut h_transpose: Matrix<T, C, K, _> = Matrix::zeros_generic(ncols, k);
46
47    let mut wh: Matrix<T, R, C, _> = &w * &h;
48
49    let mut wt_v: Matrix<T, K, C, _> = Matrix::zeros_generic(k, ncols);
50    let mut wt_w_h: Matrix<T, K, C, _> = Matrix::zeros_generic(k, ncols);
51
52    let mut v_ht: Matrix<T, R, K, _> = Matrix::zeros_generic(nrows, k);
53    let mut w_h_ht: Matrix<T, R, K, _> = Matrix::zeros_generic(nrows, k);
54
55    // Repeat until convergence
56    for _ in 0..max_iter {
57        // Return if cost is less than tolerance
58        let cost = matrix
59            .iter()
60            .zip(wh.iter())
61            .map(|(a, b)| (*a - *b).powi(2))
62            .sum::<T>();
63
64        if cost < tolerance {
65            break;
66        }
67
68        // Calculate W^T
69        w.transpose_to(&mut w_transpose);
70
71        // let wt_v: Matrix<T, K, C, _> = &w_transpose * matrix;
72        // Numerator = W^T * V
73        w_transpose.mul_to(matrix, &mut wt_v);
74        // Denominator = W^T * W * H
75        w_transpose.mul_to(&wh, &mut wt_w_h);
76
77        // Component-wise update of h
78        h.iter_mut()
79            .zip(wt_v.iter().zip(wt_w_h.iter()))
80            .map(|(h_old, (num, den))| *h_old *= *num / *den)
81            .last()
82            .unwrap();
83
84        // Calculate H^T
85        h.transpose_to(&mut h_transpose);
86
87        // WH = W * H
88        w.mul_to(&h, &mut wh);
89
90        // Numerator = V * H^T
91        matrix.mul_to(&h_transpose, &mut v_ht);
92        // Denominator = W * H * H^T
93        wh.mul_to(&h_transpose, &mut w_h_ht);
94
95        // Component-wise update of w
96        w.iter_mut()
97            .zip(v_ht.iter().zip(w_h_ht.iter()))
98            .map(|(w_old, (num, den))| *w_old *= *num / *den)
99            .last()
100            .unwrap();
101
102        w.mul_to(&h, &mut wh);
103    }
104
105    (w, h)
106}
107
108/// Does non-negative matrix factorization on a statically-sized matrix (SMatrix)
109pub fn non_negative_matrix_factorization_static<T, const R: usize, const C: usize, const K: usize>(
110    matrix: &SMatrix<T, R, C>,
111    max_iter: usize,
112    tolerance: T,
113) -> (SMatrix<T, R, K>, SMatrix<T, K, C>)
114where
115    T: Scalar
116        + ComplexField<RealField = T>
117        + Sub<T>
118        + Clone
119        + Copy
120        + Sum<T>
121        + PartialOrd
122        + Div<T, Output = T>,
123    Standard: Distribution<T>,
124{
125    non_negative_matrix_factorization_generic(
126        matrix, max_iter, tolerance, Const::<R>, Const::<C>, Const::<K>,
127    )
128}
129
130/// Does non-negative matrix factorization on a dynamically-sized matrix (DMatrix)
131pub fn non_negative_matrix_factorization_dyn<T>(
132    matrix: &DMatrix<T>,
133    max_iter: usize,
134    tolerance: T,
135    k: usize,
136) -> (DMatrix<T>, DMatrix<T>)
137where
138    T: Scalar
139        + ComplexField<RealField = T>
140        + Sub<T>
141        + Clone
142        + Copy
143        + Sum<T>
144        + PartialOrd
145        + Div<T, Output = T>,
146    Standard: Distribution<T>,
147{
148    let (nrows, ncols) = matrix.shape();
149    non_negative_matrix_factorization_generic(
150        matrix,
151        max_iter,
152        tolerance,
153        Dyn(nrows),
154        Dyn(ncols),
155        Dyn(k),
156    )
157}
158
#[cfg(test)]
mod tests {
    use nalgebra::{dmatrix, matrix, SMatrix};

    use crate::*;

    /// Factorizes the same 4x5 matrix through both the static and the dynamic entry points
    /// and checks that `W * H` is approximately equal to the input in each case.
    #[test]
    fn test_static_and_dyn() {
        let matrix: SMatrix<f64, 4, 5> = matrix![
            1.0, 2.0, 0.0, 30.0, 1.5;
            0.0, 3.0, 1.0, 30.0, 1.6;
            5.0, 1.0, 10.0, 30.0, 1.7;
            3.0, 1.0, 10.0, 30.0, 1.7
        ];

        let (w, h) = non_negative_matrix_factorization_static::<_, 4, 5, 3>(&matrix, 1000, 0.01);
        println!("{w}\n{h}");
        let prediction = w * h;

        println!("Matrix: {}\n Predic: {}", matrix, prediction);
        assert!(matrix.relative_eq(&prediction, 0.5, 0.5));

        let d_matrix: DMatrix<f64> = dmatrix![
            1.0, 2.0, 0.0, 30.0, 1.5;
            0.0, 3.0, 1.0, 30.0, 1.6;
            5.0, 1.0, 10.0, 30.0, 1.7;
            3.0, 1.0, 10.0, 30.0, 1.7
        ];
        let (w_dyn, h_dyn) = non_negative_matrix_factorization_dyn::<f64>(&d_matrix, 1000, 0.01, 3);
        // Report the dynamic factors and their product, not the static ones from above.
        println!("{w_dyn}\n{h_dyn}");
        let prediction_dyn = &w_dyn * &h_dyn;

        println!("Matrix: {}\n Predic: {}", d_matrix, prediction_dyn);
        assert!(d_matrix.relative_eq(&prediction_dyn, 0.5, 0.5));
    }
}