1use std::{
2 iter::Sum,
3 ops::{Div, Sub},
4};
5
6use nalgebra::{
7 allocator::Allocator, ComplexField, Const, DMatrix, DefaultAllocator, Dim, Dyn, Matrix,
8 SMatrix, Scalar,
9};
10use rand::{distributions::Standard, prelude::Distribution};
11
12pub fn non_negative_matrix_factorization_generic<T, R: Dim, C: Dim, K: Dim>(
14 matrix: &Matrix<T, R, C, <DefaultAllocator as Allocator<T, R, C>>::Buffer>,
15 max_iter: usize,
16 tolerance: T,
17 nrows: R,
18 ncols: C,
19 k: K,
20) -> (
21 Matrix<T, R, K, <DefaultAllocator as Allocator<T, R, K>>::Buffer>,
22 Matrix<T, K, C, <DefaultAllocator as Allocator<T, K, C>>::Buffer>,
23)
24where
25 T: Scalar
26 + ComplexField<RealField = T>
27 + Sub<T>
28 + Clone
29 + Copy
30 + Sum<T>
31 + PartialOrd
32 + Div<T, Output = T>,
33 Standard: Distribution<T>,
34 DefaultAllocator: Allocator<T, R, C>,
35 DefaultAllocator: Allocator<T, R, K>,
36 DefaultAllocator: Allocator<T, K, C>,
37 DefaultAllocator: Allocator<T, K, R>,
38 DefaultAllocator: Allocator<T, C, K>,
39{
40 let mut w: Matrix<T, R, K, _> = Matrix::new_random_generic(nrows, k);
42 let mut h: Matrix<T, K, C, _> = Matrix::new_random_generic(k, ncols);
43
44 let mut w_transpose: Matrix<T, K, R, _> = Matrix::zeros_generic(k, nrows);
45 let mut h_transpose: Matrix<T, C, K, _> = Matrix::zeros_generic(ncols, k);
46
47 let mut wh: Matrix<T, R, C, _> = &w * &h;
48
49 let mut wt_v: Matrix<T, K, C, _> = Matrix::zeros_generic(k, ncols);
50 let mut wt_w_h: Matrix<T, K, C, _> = Matrix::zeros_generic(k, ncols);
51
52 let mut v_ht: Matrix<T, R, K, _> = Matrix::zeros_generic(nrows, k);
53 let mut w_h_ht: Matrix<T, R, K, _> = Matrix::zeros_generic(nrows, k);
54
55 for _ in 0..max_iter {
57 let cost = matrix
59 .iter()
60 .zip(wh.iter())
61 .map(|(a, b)| (*a - *b).powi(2))
62 .sum::<T>();
63
64 if cost < tolerance {
65 break;
66 }
67
68 w.transpose_to(&mut w_transpose);
70
71 w_transpose.mul_to(matrix, &mut wt_v);
74 w_transpose.mul_to(&wh, &mut wt_w_h);
76
77 h.iter_mut()
79 .zip(wt_v.iter().zip(wt_w_h.iter()))
80 .map(|(h_old, (num, den))| *h_old *= *num / *den)
81 .last()
82 .unwrap();
83
84 h.transpose_to(&mut h_transpose);
86
87 w.mul_to(&h, &mut wh);
89
90 matrix.mul_to(&h_transpose, &mut v_ht);
92 wh.mul_to(&h_transpose, &mut w_h_ht);
94
95 w.iter_mut()
97 .zip(v_ht.iter().zip(w_h_ht.iter()))
98 .map(|(w_old, (num, den))| *w_old *= *num / *den)
99 .last()
100 .unwrap();
101
102 w.mul_to(&h, &mut wh);
103 }
104
105 (w, h)
106}
107
108pub fn non_negative_matrix_factorization_static<T, const R: usize, const C: usize, const K: usize>(
110 matrix: &SMatrix<T, R, C>,
111 max_iter: usize,
112 tolerance: T,
113) -> (SMatrix<T, R, K>, SMatrix<T, K, C>)
114where
115 T: Scalar
116 + ComplexField<RealField = T>
117 + Sub<T>
118 + Clone
119 + Copy
120 + Sum<T>
121 + PartialOrd
122 + Div<T, Output = T>,
123 Standard: Distribution<T>,
124{
125 non_negative_matrix_factorization_generic(
126 matrix, max_iter, tolerance, Const::<R>, Const::<C>, Const::<K>,
127 )
128}
129
130pub fn non_negative_matrix_factorization_dyn<T>(
132 matrix: &DMatrix<T>,
133 max_iter: usize,
134 tolerance: T,
135 k: usize,
136) -> (DMatrix<T>, DMatrix<T>)
137where
138 T: Scalar
139 + ComplexField<RealField = T>
140 + Sub<T>
141 + Clone
142 + Copy
143 + Sum<T>
144 + PartialOrd
145 + Div<T, Output = T>,
146 Standard: Distribution<T>,
147{
148 let (nrows, ncols) = matrix.shape();
149 non_negative_matrix_factorization_generic(
150 matrix,
151 max_iter,
152 tolerance,
153 Dyn(nrows),
154 Dyn(ncols),
155 Dyn(k),
156 )
157}
158
#[cfg(test)]
mod tests {
    use nalgebra::{dmatrix, matrix, SMatrix};

    use crate::*;

    /// Checks that both the static and the dynamic entry points reconstruct
    /// the same input within a loose tolerance.
    ///
    /// NOTE(review): the factors are randomly initialized, so this check is
    /// probabilistic rather than fully deterministic.
    #[test]
    fn test_static_and_dyn() {
        let matrix: SMatrix<f64, 4, 5> = matrix![
            1.0, 2.0, 0.0, 30.0, 1.5;
            0.0, 3.0, 1.0, 30.0, 1.6;
            5.0, 1.0, 10.0, 30.0, 1.7;
            3.0, 1.0, 10.0, 30.0, 1.7
        ];

        let (w, h) = non_negative_matrix_factorization_static::<_, 4, 5, 3>(&matrix, 1000, 0.01);
        println!("{w}\n{h}");
        let prediction = w * h;

        println!("Matrix: {}\n Predic: {}", matrix, prediction);
        assert!(matrix.relative_eq(&prediction, 0.5, 0.5));

        let d_matrix: DMatrix<f64> = dmatrix![
            1.0, 2.0, 0.0, 30.0, 1.5;
            0.0, 3.0, 1.0, 30.0, 1.6;
            5.0, 1.0, 10.0, 30.0, 1.7;
            3.0, 1.0, 10.0, 30.0, 1.7
        ];
        let (w_dyn, h_dyn) = non_negative_matrix_factorization_dyn::<f64>(&d_matrix, 1000, 0.01, 3);
        // Print and check the dynamic factors, not the static ones from above.
        println!("{w_dyn}\n{h_dyn}");
        let prediction_dyn = &w_dyn * &h_dyn;

        println!("Matrix: {}\n Predic: {}", d_matrix, prediction_dyn);
        assert!(d_matrix.relative_eq(&prediction_dyn, 0.5, 0.5));
    }
}