1use crate::matrix::DenseMatrix;
8use crate::Scalar;
9use faer::linalg::solvers::{Svd, ThinSvd};
10use faer::{ComplexField, Conjugate, Entity, Mat, RealField, SimpleEntity};
11use numra_core::LinalgError;
12
13pub struct SvdDecomposition<S: Scalar + Entity> {
17 svd: Svd<S>,
18 m: usize,
19 n: usize,
20}
21
22impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> SvdDecomposition<S>
23where
24 S::Real: RealField,
25{
26 pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
28 let m = matrix.rows();
29 let n = matrix.cols();
30 let svd = Svd::new(matrix.as_faer());
31 Ok(Self { svd, m, n })
32 }
33
34 pub fn u(&self) -> DenseMatrix<S> {
36 DenseMatrix::from_faer(self.svd.u().to_owned())
37 }
38
39 pub fn v(&self) -> DenseMatrix<S> {
41 DenseMatrix::from_faer(self.svd.v().to_owned())
42 }
43
44 pub fn singular_values(&self) -> Vec<S> {
46 let s = self.svd.s_diagonal();
47 (0..s.nrows()).map(|i| s.read(i)).collect()
48 }
49
50 pub fn pseudoinverse(&self) -> DenseMatrix<S> {
52 DenseMatrix::from_faer(self.svd.pseudoinverse())
53 }
54
55 pub fn rank(&self, tol: S) -> usize {
57 let s = self.svd.s_diagonal();
58 (0..s.nrows()).filter(|&i| s.read(i).abs() > tol).count()
59 }
60
61 pub fn cond(&self) -> S {
63 let s = self.svd.s_diagonal();
64 let k = s.nrows();
65 if k == 0 {
66 return S::ZERO;
67 }
68 let s_max = s.read(0).abs();
69 let s_min = s.read(k - 1).abs();
70 if s_min == S::ZERO {
71 return S::INFINITY;
72 }
73 s_max / s_min
74 }
75
76 pub fn nrows(&self) -> usize {
78 self.m
79 }
80
81 pub fn ncols(&self) -> usize {
83 self.n
84 }
85}
86
87pub struct ThinSvdDecomposition<S: Scalar + Entity> {
91 svd: ThinSvd<S>,
92 m: usize,
93 n: usize,
94}
95
96impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> ThinSvdDecomposition<S>
97where
98 S::Real: RealField,
99{
100 pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
102 let m = matrix.rows();
103 let n = matrix.cols();
104 let svd = ThinSvd::new(matrix.as_faer());
105 Ok(Self { svd, m, n })
106 }
107
108 pub fn u(&self) -> DenseMatrix<S> {
110 DenseMatrix::from_faer(self.svd.u().to_owned())
111 }
112
113 pub fn v(&self) -> DenseMatrix<S> {
115 DenseMatrix::from_faer(self.svd.v().to_owned())
116 }
117
118 pub fn singular_values(&self) -> Vec<S> {
120 let s = self.svd.s_diagonal();
121 (0..s.nrows()).map(|i| s.read(i)).collect()
122 }
123
124 pub fn pseudoinverse(&self) -> DenseMatrix<S> {
126 DenseMatrix::from_faer(self.svd.pseudoinverse())
127 }
128
129 pub fn rank(&self, tol: S) -> usize {
131 let s = self.svd.s_diagonal();
132 (0..s.nrows()).filter(|&i| s.read(i).abs() > tol).count()
133 }
134
135 pub fn cond(&self) -> S {
137 let s = self.svd.s_diagonal();
138 let k = s.nrows();
139 if k == 0 {
140 return S::ZERO;
141 }
142 let s_max = s.read(0).abs();
143 let s_min = s.read(k - 1).abs();
144 if s_min == S::ZERO {
145 return S::INFINITY;
146 }
147 s_max / s_min
148 }
149
150 pub fn nrows(&self) -> usize {
152 self.m
153 }
154
155 pub fn ncols(&self) -> usize {
157 self.n
158 }
159}
160
161impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> DenseMatrix<S>
163where
164 S::Real: RealField,
165{
166 pub fn svd(&self) -> Result<SvdDecomposition<S>, LinalgError> {
168 SvdDecomposition::new(self)
169 }
170
171 pub fn thin_svd(&self) -> Result<ThinSvdDecomposition<S>, LinalgError> {
173 ThinSvdDecomposition::new(self)
174 }
175
176 pub fn singular_values(&self) -> Vec<S> {
178 let svd = ThinSvd::new(self.as_faer());
179 let s = svd.s_diagonal();
180 (0..s.nrows()).map(|i| s.read(i)).collect()
181 }
182
183 pub fn pinv(&self) -> Result<DenseMatrix<S>, LinalgError> {
185 let svd = ThinSvdDecomposition::new(self)?;
186 Ok(svd.pseudoinverse())
187 }
188
189 pub fn cond(&self) -> S {
191 let svd = ThinSvd::new(self.as_faer());
192 let s = svd.s_diagonal();
193 let k = s.nrows();
194 if k == 0 {
195 return S::ZERO;
196 }
197 let s_max = s.read(0).abs();
198 let s_min = s.read(k - 1).abs();
199 if s_min == S::ZERO {
200 return S::INFINITY;
201 }
202 s_max / s_min
203 }
204
205 pub fn rank(&self, tol: S) -> usize {
207 let svd = ThinSvd::new(self.as_faer());
208 let s = svd.s_diagonal();
209 (0..s.nrows()).filter(|&i| s.read(i).abs() > tol).count()
210 }
211
212 pub fn lstsq(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
214 if b.len() != self.rows() {
215 return Err(LinalgError::DimensionMismatch {
216 expected: (self.rows(), 1),
217 actual: (b.len(), 1),
218 });
219 }
220 let pinv = self.pinv()?;
221 let mut b_mat = Mat::zeros(self.rows(), 1);
223 for (i, &val) in b.iter().enumerate() {
224 b_mat.write(i, 0, val);
225 }
226 let pinv_ref = pinv.as_faer();
227 let result = pinv_ref * b_mat.as_ref();
228 let x: Vec<S> = (0..self.cols()).map(|i| result.read(i, 0)).collect();
229 Ok(x)
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    #[test]
    fn test_svd_diagonal() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(3, 3);
        a.set(0, 0, 3.0);
        a.set(1, 1, 1.0);
        a.set(2, 2, 2.0);

        let svd = SvdDecomposition::new(&a).unwrap();
        let s = svd.singular_values();

        // Singular values come back sorted largest first.
        assert!((s[0] - 3.0).abs() < 1e-10);
        assert!((s[1] - 2.0).abs() < 1e-10);
        assert!((s[2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_svd_rectangular_reconstruction() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(3, 2);
        a.set(0, 0, 1.0);
        a.set(0, 1, 2.0);
        a.set(1, 0, 3.0);
        a.set(1, 1, 4.0);
        a.set(2, 0, 5.0);
        a.set(2, 1, 6.0);

        let svd = SvdDecomposition::new(&a).unwrap();
        let u = svd.u();
        let v = svd.v();
        let s = svd.singular_values();

        // Reconstruct A = U * diag(s) * V^T using only the first k columns.
        let m = a.rows();
        let n = a.cols();
        let k = s.len();
        for i in 0..m {
            for j in 0..n {
                let mut val = 0.0;
                for p in 0..k {
                    val += u.get(i, p) * s[p] * v.get(j, p);
                }
                assert!(
                    (val - a.get(i, j)).abs() < 1e-10,
                    "Reconstruction failed at ({}, {}): {} vs {}",
                    i,
                    j,
                    val,
                    a.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_pseudoinverse() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(3, 2);
        a.set(0, 0, 1.0);
        a.set(0, 1, 2.0);
        a.set(1, 0, 3.0);
        a.set(1, 1, 4.0);
        a.set(2, 0, 5.0);
        a.set(2, 1, 6.0);

        let svd = SvdDecomposition::new(&a).unwrap();
        let pinv = svd.pseudoinverse();

        let m = a.rows();
        let n = a.cols();
        assert_eq!(pinv.rows(), n);
        assert_eq!(pinv.cols(), m);

        // Form A * pinv(A) (m x m) explicitly.
        let mut a_pinv: DenseMatrix<f64> = DenseMatrix::zeros(m, m);
        for i in 0..m {
            for j in 0..m {
                let mut val = 0.0;
                for k in 0..n {
                    val += a.get(i, k) * pinv.get(k, j);
                }
                a_pinv.set(i, j, val);
            }
        }

        // First Moore-Penrose condition: (A * pinv(A)) * A == A.
        for i in 0..m {
            for j in 0..n {
                let mut val = 0.0;
                for k in 0..m {
                    val += a_pinv.get(i, k) * a.get(k, j);
                }
                assert!(
                    (val - a.get(i, j)).abs() < 1e-10,
                    "A * pinv(A) * A != A at ({}, {}): {} vs {}",
                    i,
                    j,
                    val,
                    a.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_thin_and_full_pseudoinverse_agree() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(3, 2);
        a.set(0, 0, 1.0);
        a.set(0, 1, 2.0);
        a.set(1, 0, 3.0);
        a.set(1, 1, 4.0);
        a.set(2, 0, 5.0);
        a.set(2, 1, 6.0);

        let full = SvdDecomposition::new(&a).unwrap().pseudoinverse();
        let thin = a.pinv().unwrap();

        assert_eq!(thin.rows(), full.rows());
        assert_eq!(thin.cols(), full.cols());
        for i in 0..full.rows() {
            for j in 0..full.cols() {
                assert!(
                    (thin.get(i, j) - full.get(i, j)).abs() < 1e-10,
                    "thin/full pinv differ at ({}, {}): {} vs {}",
                    i,
                    j,
                    thin.get(i, j),
                    full.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_condition_number_identity() {
        let a: DenseMatrix<f64> = DenseMatrix::identity(4);
        let svd = SvdDecomposition::new(&a).unwrap();
        assert!((svd.cond() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_rank_deficient() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        a.set(0, 0, 1.0);
        a.set(0, 1, 2.0);
        a.set(1, 0, 2.0);
        a.set(1, 1, 4.0);

        let svd = SvdDecomposition::new(&a).unwrap();
        assert_eq!(svd.rank(1e-10), 1);
    }

    #[test]
    fn test_rank_deficient_thin_and_convenience() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        a.set(0, 0, 1.0);
        a.set(0, 1, 2.0);
        a.set(1, 0, 2.0);
        a.set(1, 1, 4.0);

        assert_eq!(ThinSvdDecomposition::new(&a).unwrap().rank(1e-10), 1);
        assert_eq!(a.svd().unwrap().rank(1e-10), 1);
        assert_eq!(a.thin_svd().unwrap().rank(1e-10), 1);
        assert_eq!(a.rank(1e-10), 1);
    }

    #[test]
    fn test_lstsq() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(3, 2);
        a.set(0, 0, 1.0);
        a.set(0, 1, 1.0);
        a.set(1, 0, 1.0);
        a.set(1, 1, 2.0);
        a.set(2, 0, 1.0);
        a.set(2, 1, 3.0);

        let b = vec![1.0, 2.0, 2.0];
        let x = a.lstsq(&b).unwrap();

        // Normal equations: [[3, 6], [6, 14]] x = [5, 11] => x = [2/3, 1/2].
        assert!((x[0] - 2.0 / 3.0).abs() < 1e-10);
        assert!((x[1] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_lstsq_dimension_mismatch() {
        let a: DenseMatrix<f64> = DenseMatrix::identity(3);
        let result = a.lstsq(&[1.0, 2.0]);
        assert!(matches!(
            result,
            Err(LinalgError::DimensionMismatch {
                expected: (3, 1),
                actual: (2, 1),
            })
        ));
    }

    #[test]
    fn test_thin_svd() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(4, 2);
        a.set(0, 0, 1.0);
        a.set(0, 1, 0.0);
        a.set(1, 0, 0.0);
        a.set(1, 1, 2.0);
        a.set(2, 0, 0.0);
        a.set(2, 1, 0.0);
        a.set(3, 0, 0.0);
        a.set(3, 1, 0.0);

        let svd = ThinSvdDecomposition::new(&a).unwrap();
        let u = svd.u();
        let v = svd.v();
        let s = svd.singular_values();

        // Thin factors: U is m x min(m, n), V is n x min(m, n).
        assert_eq!(u.rows(), 4);
        assert_eq!(u.cols(), 2);
        assert_eq!(v.rows(), 2);
        assert_eq!(v.cols(), 2);

        assert!((s[0] - 2.0).abs() < 1e-10);
        assert!((s[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_decomposition_shape_accessors() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(3, 2);
        a.set(0, 0, 1.0);
        a.set(1, 1, 1.0);

        let full = a.svd().unwrap();
        assert_eq!((full.nrows(), full.ncols()), (3, 2));

        let thin = a.thin_svd().unwrap();
        assert_eq!((thin.nrows(), thin.ncols()), (3, 2));
    }

    #[test]
    fn test_convenience_methods() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        a.set(0, 0, 3.0);
        a.set(1, 1, 1.0);

        let s = a.singular_values();
        assert!((s[0] - 3.0).abs() < 1e-10);
        assert!((s[1] - 1.0).abs() < 1e-10);

        assert!((a.cond() - 3.0).abs() < 1e-10);
        assert_eq!(a.rank(1e-10), 2);
    }

    #[test]
    fn test_svd_f32() {
        let mut a: DenseMatrix<f32> = DenseMatrix::zeros(2, 2);
        a.set(0, 0, 3.0);
        a.set(0, 1, 0.0);
        a.set(1, 0, 0.0);
        a.set(1, 1, 2.0);

        let svd = SvdDecomposition::new(&a).unwrap();
        let s = svd.singular_values();

        assert!((s[0] - 3.0).abs() < 1e-5);
        assert!((s[1] - 2.0).abs() < 1e-5);
    }
}