1#![allow(non_snake_case)]
2
3use crate::algebra::*;
4use core::cmp::min;
5use std::iter::zip;
6
/// Selects which LAPACK driver the engine uses when it falls back to BLAS.
///
/// The default is [`SVDEngineAlgorithm::DivideAndConquer`].
// `dead_code` is tolerated because, within this file, `QRDecomposition` is
// only ever selected by the unit tests.
#[allow(dead_code)]
#[derive(Debug, PartialEq, Eq, Copy, Clone, Default)]
pub(crate) enum SVDEngineAlgorithm {
    /// Factor with the `?gesdd` driver (also needs an integer work vector).
    #[default]
    DivideAndConquer,
    /// Factor with the `?gesvd` driver.
    QRDecomposition,
}
14
/// Scratch buffers that the engine keeps between calls so that repeated
/// factorizations and solves can reuse their allocations.
pub(crate) struct SVDBlasWorkVectors<T> {
    /// Floating point workspace.  Passed as the `work` argument of the
    /// LAPACK SVD drivers, and reused as scratch storage by the solve step.
    pub work: Vec<T>,
    /// Integer workspace.  Only passed to the divide-and-conquer driver.
    pub iwork: Vec<i32>,
}
19
20impl<T: FloatT> Default for SVDBlasWorkVectors<T> {
21 fn default() -> Self {
22 let work = vec![T::one()];
26 let iwork = vec![1];
27 Self { work, iwork }
28 }
29}
30
/// Engine computing and storing an economy-size singular value
/// decomposition `A = U * diag(s) * Vt` of an `m x n` matrix.
pub(crate) struct SVDEngine<T> {
    /// Singular values; holds `min(m, n)` entries.
    pub s: Vec<T>,

    /// Left factor, dimension `m x min(m, n)`.
    pub U: Matrix<T>,
    /// Right factor stored transposed, dimension `min(m, n) x n`.
    pub Vt: Matrix<T>,

    /// Work vectors for the BLAS/LAPACK code paths.  Starts out as `None`
    /// and is populated the first time one of those paths needs it.
    pub blas: Option<SVDBlasWorkVectors<T>>,

    /// LAPACK driver used whenever the factorization goes through BLAS.
    pub algorithm: SVDEngineAlgorithm,
}
47
48impl<T> SVDEngine<T>
49where
50 T: FloatT,
51{
52 pub fn new(size: (usize, usize)) -> Self {
53 let (m, n) = size;
54 let s = vec![T::zero(); min(m, n)];
55 let U = Matrix::<T>::zeros((m, min(m, n)));
56 let Vt = Matrix::<T>::zeros((min(m, n), n));
57 let blas = None;
58 let algorithm = SVDEngineAlgorithm::default();
59 Self {
60 s,
61 U,
62 Vt,
63 blas,
64 algorithm,
65 }
66 }
67
68 pub fn resize(&mut self, size: (usize, usize)) {
69 let (m, n) = size;
70 self.s.resize(min(m, n), T::zero());
71 self.U.resize((m, min(m, n)));
72 self.Vt.resize((min(m, n), n));
73 }
74
75 fn checkdim_factor<S>(
76 &mut self,
77 A: &mut DenseStorageMatrix<S, T>,
78 ) -> Result<(), DenseFactorizationError>
79 where
80 S: AsMut<[T]> + AsRef<[T]>,
81 {
82 let (m, n) = A.size();
83
84 if self.U.nrows() != m || self.Vt.ncols() != n {
85 Err(DenseFactorizationError::IncompatibleDimension)
86 } else {
87 Ok(())
88 }
89 }
90
91 fn checkdim_solve<S>(
92 &mut self,
93 B: &mut DenseStorageMatrix<S, T>,
94 ) -> Result<(), DenseFactorizationError>
95 where
96 S: AsMut<[T]> + AsRef<[T]>,
97 {
98 let m = self.U.nrows();
100 let n = self.Vt.ncols();
101
102 if m != n {
109 return Err(DenseFactorizationError::IncompatibleDimension);
110 }
111
112 if B.nrows() != m {
114 return Err(DenseFactorizationError::IncompatibleDimension);
115 }
116 Ok(())
117 }
118}
119
120impl<T> FactorSVD<T> for SVDEngine<T>
121where
122 T: FloatT,
123{
124 fn factor<S>(&mut self, A: &mut DenseStorageMatrix<S, T>) -> Result<(), DenseFactorizationError>
125 where
126 S: AsMut<[T]> + AsRef<[T]>,
127 {
128 self.checkdim_factor(A)?;
129
130 if A.is_square() {
132 match A.nrows() {
133 1 => self.factor1(A),
134 2 => self.factor2(A),
135 3 => self.factor3(A),
136 _ => self.factorblas(A),
137 }
138 } else {
139 self.factorblas(A)
141 }
142 }
143
144 fn solve<S>(&mut self, B: &mut DenseStorageMatrix<S, T>)
145 where
146 S: AsMut<[T]> + AsRef<[T]>,
147 {
148 self.checkdim_solve(B).unwrap();
153
154 self.solveblas(B);
163 }
164}
165
166impl<T> SVDEngine<T>
168where
169 T: FloatT,
170{
171 fn factor1<S>(
172 &mut self,
173 A: &mut DenseStorageMatrix<S, T>,
174 ) -> Result<(), DenseFactorizationError>
175 where
176 S: AsMut<[T]> + AsRef<[T]>,
177 {
178 self.U[(0, 0)] = T::one();
179 self.Vt[(0, 0)] = T::one();
180 self.s[0] = A[(0, 0)];
181
182 if self.s[0] < T::zero() {
183 self.s[0] = -self.s[0];
184 self.U[(0, 0)] = -T::one();
185 };
186 Ok(())
187 }
188}
189
190impl<T> SVDEngine<T>
193where
194 T: FloatT,
195{
196 fn factor2<S>(
197 &mut self,
198 A: &mut DenseStorageMatrix<S, T>,
199 ) -> Result<(), DenseFactorizationError>
200 where
201 S: AsMut<[T]> + AsRef<[T]>,
202 {
203 let mut As = DenseMatrix2::<T>::from(A);
204 let mut Vs = DenseMatrix2::<T>::zeros();
205 let mut Us = DenseMatrix2::<T>::zeros();
206
207 let s = As.svd(&mut Us, &mut Vs);
208 self.s.copy_from_slice(&s);
209 self.U.data.copy_from(&Us.data);
210
211 Vs.transpose_in_place();
213 self.Vt.copy_from_slice(&Vs.data);
214 Ok(())
215 }
216}
217
218impl<T> SVDEngine<T>
221where
222 T: FloatT,
223{
224 fn factor3<S>(
225 &mut self,
226 A: &mut DenseStorageMatrix<S, T>,
227 ) -> Result<(), DenseFactorizationError>
228 where
229 S: AsMut<[T]> + AsRef<[T]>,
230 {
231 let mut As = DenseMatrix3::<T>::from(A);
232 let mut Vs = DenseMatrix3::<T>::zeros();
233 let mut Us = DenseMatrix3::<T>::zeros();
234
235 let s = As.svd(&mut Us, &mut Vs);
236 self.s.copy_from_slice(&s);
237 self.U.data.copy_from(&Us.data);
238
239 Vs.transpose_in_place();
241 self.Vt.copy_from_slice(&Vs.data);
242 Ok(())
243 }
244}
245
246impl<T> SVDEngine<T>
249where
250 T: FloatT,
251{
252 fn factorblas<S>(
253 &mut self,
254 A: &mut DenseStorageMatrix<S, T>,
255 ) -> Result<(), DenseFactorizationError>
256 where
257 S: AsMut<[T]> + AsRef<[T]>,
258 {
259 let m = self.U.nrows();
262 let n = self.Vt.ncols();
263
264 let blaswork = self.blas.get_or_insert_with(SVDBlasWorkVectors::default);
266
267 let job = b'S'; let m = m.try_into().unwrap();
269 let n = n.try_into().unwrap();
270 let a = A.data_mut();
271 let lda = m;
272 let s = &mut self.s; let u = self.U.data_mut(); let ldu = m; let vt = self.Vt.data_mut(); let ldvt = min(m, n); let work = &mut blaswork.work;
278 let mut lwork = -1_i32; let iwork = &mut blaswork.iwork;
280 let info = &mut 0_i32; for i in 0..2 {
283 if self.algorithm == SVDEngineAlgorithm::DivideAndConquer {
288 iwork.resize(8 * min(m, n) as usize, 0);
289 }
290
291 match self.algorithm {
292 SVDEngineAlgorithm::DivideAndConquer => T::xgesdd(
293 job, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, iwork, info,
294 ),
295 SVDEngineAlgorithm::QRDecomposition => T::xgesvd(
296 job, job, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, info,
297 ),
298 }
299 if *info != 0 {
300 return Err(DenseFactorizationError::SVD(*info));
301 }
302
303 if i == 0 {
305 lwork = work[0].to_i32().unwrap();
306 work.resize(lwork as usize, T::zero());
307 }
308 }
309 Ok(())
310 }
311
312 fn solveblas<S>(&mut self, B: &mut DenseStorageMatrix<S, T>)
313 where
314 S: AsMut<[T]> + AsRef<[T]>,
315 {
316 let m = self.U.nrows();
318 let n = self.Vt.ncols();
319 let k = min(m, n); let nrhs = B.ncols();
323
324 let tol = T::epsilon() * self.s[0].abs() * T::from(k).unwrap();
327
328 let blaswork = self.blas.get_or_insert_with(SVDBlasWorkVectors::default);
330
331 blaswork.work.resize(k + k * nrhs, T::zero());
337 let (sinv, workC) = blaswork.work.split_at_mut(k);
338
339 let mut C = BorrowedMatrixMut::from_slice_mut(workC, k, nrhs);
341 C.mul(&self.U.t(), B, T::one(), T::zero());
342
343 zip(sinv.iter_mut(), self.s.iter()).for_each(|(sinv, s)| {
345 if s.abs() > tol {
346 *sinv = T::recip(s.abs());
347 } else {
348 *sinv = T::zero();
349 }
350 });
351
352 for col in 0..nrhs {
353 C.col_slice_mut(col).hadamard(sinv);
354 }
355
356 B.mul(&self.Vt.t(), &C, T::one(), T::zero());
358 }
359}
360
#[cfg(test)]
mod test {
    use super::*;

    /// Every algorithm setting that the tests exercise.
    const METHODS: [SVDEngineAlgorithm; 2] = [
        SVDEngineAlgorithm::DivideAndConquer,
        SVDEngineAlgorithm::QRDecomposition,
    ];

    /// Returns `(A, X, B)` for a 2x2 system satisfying `A * X = B`.
    fn test_solve_data_2x2<T: FloatT>() -> (Matrix<T>, Matrix<T>, Matrix<T>) {
        let A = Matrix::<T>::from(&[
            [(4.0).as_T(), (1.0).as_T()],
            [(1.0).as_T(), (3.0).as_T()],
        ]);

        let X = Matrix::<T>::from(&[
            [(2.0).as_T(), (3.0).as_T()],
            [(1.0).as_T(), (2.0).as_T()],
        ]);

        let B = Matrix::<T>::from(&[
            [(9.0).as_T(), (14.0).as_T()],
            [(5.0).as_T(), (9.0).as_T()],
        ]);

        (A, X, B)
    }

    /// Returns `(A, X, B)` for a 3x3 system satisfying `A * X = B`.
    #[rustfmt::skip]
    fn test_solve_data_3x3<T: FloatT>() -> (Matrix<T>, Matrix<T>, Matrix<T>) {
        let A = Matrix::<T>::from(&[
            [( 8.0).as_T(), (-2.0).as_T(), (4.0).as_T()],
            [(-2.0).as_T(), (12.0).as_T(), (2.0).as_T()],
            [( 4.0).as_T(), ( 2.0).as_T(), (6.0).as_T()],
        ]);

        let X = Matrix::<T>::from(&[
            [(1.0).as_T(), (2.0).as_T()],
            [(3.0).as_T(), (4.0).as_T()],
            [(5.0).as_T(), (6.0).as_T()],
        ]);

        let B = Matrix::<T>::from(&[
            [(22.0).as_T(), (32.0).as_T()],
            [(44.0).as_T(), (56.0).as_T()],
            [(40.0).as_T(), (52.0).as_T()],
        ]);

        (A, X, B)
    }

    /// Returns `(A, X, B)` for a 4x4 system satisfying `A * X = B`.
    #[rustfmt::skip]
    fn test_solve_data_4x4<T: FloatT>() -> (Matrix<T>, Matrix<T>, Matrix<T>) {
        let A = Matrix::<T>::from(&[
            [(10.0).as_T(), (2.0).as_T(), (3.0).as_T(), (1.0).as_T()],
            [( 2.0).as_T(), (8.0).as_T(), (0.0).as_T(), (3.0).as_T()],
            [( 3.0).as_T(), (0.0).as_T(), (6.0).as_T(), (2.0).as_T()],
            [( 1.0).as_T(), (3.0).as_T(), (2.0).as_T(), (9.0).as_T()],
        ]);

        let X = Matrix::<T>::from(&[
            [(1.0).as_T(), (2.0).as_T()],
            [(2.0).as_T(), (3.0).as_T()],
            [(3.0).as_T(), (1.0).as_T()],
            [(4.0).as_T(), (2.0).as_T()],
        ]);

        let B = Matrix::<T>::from(&[
            [(27.0).as_T(), (31.0).as_T()],
            [(30.0).as_T(), (34.0).as_T()],
            [(29.0).as_T(), (16.0).as_T()],
            [(49.0).as_T(), (31.0).as_T()],
        ]);

        (A, X, B)
    }

    /// Factors `A` with every algorithm, solves against `B` and checks the
    /// result against the known solution `X`.
    ///
    /// `tolfn` maps the base tolerance `1e-10` to the tolerance appropriate
    /// for the float type under test.
    fn run_svd_solve_test<T>(A: &Matrix<T>, X: &Matrix<T>, B: &Matrix<T>, tolfn: fn(T) -> T)
    where
        T: FloatT,
    {
        use crate::algebra::VectorMath;

        for &method in METHODS.iter() {
            // Factor and solve operate in place, so work on copies.
            let mut thisA = A.clone();
            let mut thisB = B.clone();

            let mut eng = SVDEngine::<T>::new(thisA.size());
            eng.algorithm = method;

            assert!(eng.factor(&mut thisA).is_ok());
            eng.solve(&mut thisB);

            assert!(thisB.data().norm_inf_diff(X.data()) < tolfn((1e-10).as_T()));
        }
    }

    macro_rules! generate_test_svd_solve {
        ($fxx:ty, $test_name:ident, $tolfn:ident) => {
            #[test]
            fn $test_name() {
                let (A, X, B) = test_solve_data_2x2::<$fxx>();
                run_svd_solve_test(&A, &X, &B, |x| x.$tolfn());

                let (A, X, B) = test_solve_data_3x3::<$fxx>();
                run_svd_solve_test(&A, &X, &B, |x| x.$tolfn());

                let (A, X, B) = test_solve_data_4x4::<$fxx>();
                run_svd_solve_test(&A, &X, &B, |x| x.$tolfn());
            }
        };
    }

    // f32 loosens the base tolerance via `sqrt`; f64 keeps it via `abs`.
    generate_test_svd_solve!(f32, test_svd_solve_f32, sqrt);
    generate_test_svd_solve!(f64, test_svd_solve_f64, abs);

    /// The 2x2 coefficient matrix from the solve data.
    fn test_factor_data_2x2<T: FloatT>() -> Matrix<T> {
        test_solve_data_2x2::<T>().0
    }

    /// The 3x3 coefficient matrix from the solve data.
    fn test_factor_data_3x3<T: FloatT>() -> Matrix<T> {
        test_solve_data_3x3::<T>().0
    }

    /// The 4x4 coefficient matrix from the solve data.
    fn test_factor_data_4x4<T: FloatT>() -> Matrix<T> {
        test_solve_data_4x4::<T>().0
    }

    /// A wide (2x4) matrix for the non-square factorization path.
    #[rustfmt::skip]
    fn test_factor_data_2x4<T: FloatT>() -> Matrix<T> {
        Matrix::<T>::from(&[
            [(10.0).as_T(), (2.0).as_T(), (3.0).as_T(), (1.0).as_T()],
            [( 2.0).as_T(), (8.0).as_T(), (0.0).as_T(), (3.0).as_T()],
        ])
    }

    /// A tall (4x2) matrix for the non-square factorization path.
    #[rustfmt::skip]
    fn test_factor_data_4x2<T: FloatT>() -> Matrix<T> {
        Matrix::<T>::from(&[
            [(10.0).as_T(), (2.0).as_T()],
            [( 2.0).as_T(), (8.0).as_T()],
            [( 3.0).as_T(), (1.0).as_T()],
            [( 0.0).as_T(), (3.0).as_T()],
        ])
    }

    /// True when every element is at least as large as its successor.
    fn is_descending_order<T: FloatT>(s: &[T]) -> bool {
        s.windows(2).all(|pair| pair[0] >= pair[1])
    }

    /// Factors `A` with every algorithm and checks that the singular
    /// values are sorted in descending order and that `U * diag(s) * Vt`
    /// reproduces `A`.
    ///
    /// `tolfn` maps the base tolerance `1e-10` to the tolerance appropriate
    /// for the float type under test.
    fn run_svd_factor_test<T>(A: &mut Matrix<T>, tolfn: fn(T) -> T)
    where
        T: FloatT,
    {
        use crate::algebra::{DenseMatrix, MultiplyGEMM, VectorMath};

        for &method in METHODS.iter() {
            // The BLAS factorization overwrites its input.  Factor a scratch
            // copy so that every algorithm is tested on the intended data,
            // rather than on whatever the previous algorithm left behind.
            let mut Awork = A.clone();

            let mut eng = SVDEngine::<T>::new(A.size());
            eng.algorithm = method;

            assert!(eng.factor(&mut Awork).is_ok());

            // NOTE(review): allocating 1x1 and then resizing looks
            // deliberate (it exercises `Matrix::resize`) -- kept as is.
            let mut M = Matrix::<T>::zeros((1, 1));
            M.resize(A.size());

            let U = &eng.U;
            let s = &eng.s;
            let Vt = &eng.Vt;

            assert!(is_descending_order(s));

            // Us = U * diag(s), i.e. scale column c of U by s[c].
            let mut Us = U.clone();
            for (c, &sigma) in s.iter().enumerate() {
                for r in 0..Us.nrows() {
                    Us[(r, c)] *= sigma;
                }
            }

            // M = (U * diag(s)) * Vt should reproduce the original matrix.
            M.mul(&Us, Vt, T::one(), T::zero());
            assert!(M.data().norm_inf_diff(A.data()) < tolfn((1e-10).as_T()));
        }
    }

    macro_rules! generate_test_svd_factor {
        ($fxx:ty, $test_name:ident, $tolfn:ident) => {
            #[test]
            fn $test_name() {
                let mut A = test_factor_data_2x2::<$fxx>();
                run_svd_factor_test(&mut A, |x| x.$tolfn());

                let mut A = test_factor_data_3x3::<$fxx>();
                run_svd_factor_test(&mut A, |x| x.$tolfn());

                let mut A = test_factor_data_4x4::<$fxx>();
                run_svd_factor_test(&mut A, |x| x.$tolfn());

                let mut A = test_factor_data_2x4::<$fxx>();
                run_svd_factor_test(&mut A, |x| x.$tolfn());

                let mut A = test_factor_data_4x2::<$fxx>();
                run_svd_factor_test(&mut A, |x| x.$tolfn());
            }
        };
    }

    // f32 loosens the base tolerance via `sqrt`; f64 keeps it via `abs`.
    generate_test_svd_factor!(f32, test_svd_factor_f32, sqrt);
    generate_test_svd_factor!(f64, test_svd_factor_f64, abs);
}
594
595
596
#[cfg(all(test, feature = "bench"))]
mod bench {
    use super::*;

    /// Yields every 3x3 matrix whose nine entries are each drawn from a
    /// fixed set of five values (5^9 matrices in total).
    fn svd3_bench_iter() -> impl Iterator<Item = Matrix<f64>> {
        use itertools::iproduct;

        let v = [-4., -2., 0., 1., 5.];

        iproduct!(v, v, v, v, v, v, v, v, v).map(move |(a, b, c, d, e, f, g, h, i)| {
            Matrix::new_from_slice((3, 3), &[a, b, c, d, e, f, g, h, i])
        })
    }

    /// Runs the dedicated 3x3 routine and the BLAS routine over the same
    /// set of matrices, one full pass each.
    #[test]
    fn bench_svd3_vs_blas() {
        let mut eng = SVDEngine::<f64>::new((3, 3));

        // Pass 1: fixed-size 3x3 implementation.
        svd3_bench_iter().for_each(|mut A| eng.factor3(&mut A).unwrap());

        // Pass 2: BLAS implementation.
        svd3_bench_iter().for_each(|mut A| eng.factorblas(&mut A).unwrap());
    }
}
629