clarabel/algebra/dense/blas/
svd.rs

1#![allow(non_snake_case)]
2
3use crate::algebra::*;
4use core::cmp::min;
5use std::iter::zip;
6
/// Selects the LAPACK driver used by `SVDEngine` for matrices that are not
/// handled by one of the closed-form small-matrix kernels.
// dead_code is allowed, presumably because `QRDecomposition` is currently
// only selected from tests.
#[derive(Clone, Copy, Default, Eq, PartialEq)]
#[allow(dead_code)]
pub(crate) enum SVDEngineAlgorithm {
    /// Divide-and-conquer driver (`?gesdd`); the default.
    #[default]
    DivideAndConquer,
    /// QR-iteration driver (`?gesvd`).
    QRDecomposition,
}
14
/// Scratch storage handed to the LAPACK SVD routines.
pub(crate) struct SVDBlasWorkVectors<T> {
    /// Floating-point workspace (LAPACK `work`).
    pub work: Vec<T>,
    /// Integer workspace (LAPACK `iwork`).
    pub iwork: Vec<i32>,
}
19
20impl<T: FloatT> Default for SVDBlasWorkVectors<T> {
21    fn default() -> Self {
22        // must be at least 1 element because the
23        // requiring work size is written into the
24        // first element
25        let work = vec![T::one()];
26        let iwork = vec![1];
27        Self { work, iwork }
28    }
29}
30
31pub(crate) struct SVDEngine<T> {
32    /// Computed singular values
33    pub s: Vec<T>,
34
35    /// Left and right SVD matrices, each containing.
36    /// min(m,n) vectors.  Note right singular vectors
37    /// are stored in transposed form.
38    pub U: Matrix<T>,
39    pub Vt: Matrix<T>,
40
41    // BLAS workspace (allocated vecs only)
42    pub blas: Option<SVDBlasWorkVectors<T>>,
43
44    // BLAS factorization method
45    pub algorithm: SVDEngineAlgorithm,
46}
47
48impl<T> SVDEngine<T>
49where
50    T: FloatT,
51{
52    pub fn new(size: (usize, usize)) -> Self {
53        let (m, n) = size;
54        let s = vec![T::zero(); min(m, n)];
55        let U = Matrix::<T>::zeros((m, min(m, n)));
56        let Vt = Matrix::<T>::zeros((min(m, n), n));
57        let blas = None;
58        let algorithm = SVDEngineAlgorithm::default();
59        Self {
60            s,
61            U,
62            Vt,
63            blas,
64            algorithm,
65        }
66    }
67
68    pub fn resize(&mut self, size: (usize, usize)) {
69        let (m, n) = size;
70        self.s.resize(min(m, n), T::zero());
71        self.U.resize((m, min(m, n)));
72        self.Vt.resize((min(m, n), n));
73    }
74
75    fn checkdim_factor<S>(
76        &mut self,
77        A: &mut DenseStorageMatrix<S, T>,
78    ) -> Result<(), DenseFactorizationError>
79    where
80        S: AsMut<[T]> + AsRef<[T]>,
81    {
82        let (m, n) = A.size();
83
84        if self.U.nrows() != m || self.Vt.ncols() != n {
85            Err(DenseFactorizationError::IncompatibleDimension)
86        } else {
87            Ok(())
88        }
89    }
90
91    fn checkdim_solve<S>(
92        &mut self,
93        B: &mut DenseStorageMatrix<S, T>,
94    ) -> Result<(), DenseFactorizationError>
95    where
96        S: AsMut<[T]> + AsRef<[T]>,
97    {
98        // get the dimensions for the SVD factors
99        let m = self.U.nrows();
100        let n = self.Vt.ncols();
101
102        // this function only implemented for square matrices
103        // because otherwise writing the solution in place
104        // does not make sense.   This is not a good general
105        // implementation, but is only needed at present for a
106        // rank-deficient, symmetric square solves in PSD
107        // completion
108        if m != n {
109            return Err(DenseFactorizationError::IncompatibleDimension);
110        }
111
112        // the number of columns in B
113        if B.nrows() != m {
114            return Err(DenseFactorizationError::IncompatibleDimension);
115        }
116        Ok(())
117    }
118}
119
120impl<T> FactorSVD<T> for SVDEngine<T>
121where
122    T: FloatT,
123{
124    fn factor<S>(&mut self, A: &mut DenseStorageMatrix<S, T>) -> Result<(), DenseFactorizationError>
125    where
126        S: AsMut<[T]> + AsRef<[T]>,
127    {
128        self.checkdim_factor(A)?;
129
130        // all special cases are square
131        if A.is_square() {
132            match A.nrows() {
133                1 => self.factor1(A),
134                2 => self.factor2(A),
135                3 => self.factor3(A),
136                _ => self.factorblas(A),
137            }
138        } else {
139            // non-square matrices
140            self.factorblas(A)
141        }
142    }
143
144    fn solve<S>(&mut self, B: &mut DenseStorageMatrix<S, T>)
145    where
146        S: AsMut<[T]> + AsRef<[T]>,
147    {
148        // just unwrap here.   The only way
149        // to fail is to have a non-square matrix,
150        // which is not of interest in the crate
151        // and should panic if encountered.
152        self.checkdim_solve(B).unwrap();
153
154        // PJG: always use blas to solve, regardless of
155        // dimension.  SVD solve does not happen over cones,
156        // and is only I used (I think) during chordal
157        // decomposition.   Could come back to this for
158        // custom implementation if there is a bottleneck.
159        // NB: note this means we always carry the blas
160        // workspace, even at low dimensions.
161
162        self.solveblas(B);
163    }
164}
165
166// trivial implementation for 1x1 matrices
167impl<T> SVDEngine<T>
168where
169    T: FloatT,
170{
171    fn factor1<S>(
172        &mut self,
173        A: &mut DenseStorageMatrix<S, T>,
174    ) -> Result<(), DenseFactorizationError>
175    where
176        S: AsMut<[T]> + AsRef<[T]>,
177    {
178        self.U[(0, 0)] = T::one();
179        self.Vt[(0, 0)] = T::one();
180        self.s[0] = A[(0, 0)];
181
182        if self.s[0] < T::zero() {
183            self.s[0] = -self.s[0];
184            self.U[(0, 0)] = -T::one();
185        };
186        Ok(())
187    }
188}
189
190// implementation for 2x2 matrices
191
192impl<T> SVDEngine<T>
193where
194    T: FloatT,
195{
196    fn factor2<S>(
197        &mut self,
198        A: &mut DenseStorageMatrix<S, T>,
199    ) -> Result<(), DenseFactorizationError>
200    where
201        S: AsMut<[T]> + AsRef<[T]>,
202    {
203        let mut As = DenseMatrix2::<T>::from(A);
204        let mut Vs = DenseMatrix2::<T>::zeros();
205        let mut Us = DenseMatrix2::<T>::zeros();
206
207        let s = As.svd(&mut Us, &mut Vs);
208        self.s.copy_from_slice(&s);
209        self.U.data.copy_from(&Us.data);
210
211        // Vt is stored in transposed form
212        Vs.transpose_in_place();
213        self.Vt.copy_from_slice(&Vs.data);
214        Ok(())
215    }
216}
217
218// implementation for 3x3 matrices
219
220impl<T> SVDEngine<T>
221where
222    T: FloatT,
223{
224    fn factor3<S>(
225        &mut self,
226        A: &mut DenseStorageMatrix<S, T>,
227    ) -> Result<(), DenseFactorizationError>
228    where
229        S: AsMut<[T]> + AsRef<[T]>,
230    {
231        let mut As = DenseMatrix3::<T>::from(A);
232        let mut Vs = DenseMatrix3::<T>::zeros();
233        let mut Us = DenseMatrix3::<T>::zeros();
234
235        let s = As.svd(&mut Us, &mut Vs);
236        self.s.copy_from_slice(&s);
237        self.U.data.copy_from(&Us.data);
238
239        // Vt is stored in transposed form
240        Vs.transpose_in_place();
241        self.Vt.copy_from_slice(&Vs.data);
242        Ok(())
243    }
244}
245
246// implementation for arbitrary size (square) matrices
247
248impl<T> SVDEngine<T>
249where
250    T: FloatT,
251{
252    fn factorblas<S>(
253        &mut self,
254        A: &mut DenseStorageMatrix<S, T>,
255    ) -> Result<(), DenseFactorizationError>
256    where
257        S: AsMut<[T]> + AsRef<[T]>,
258    {
259        // standard BLAS ?gesdd and/or ?gesvd arguments for economy size SVD.
260
261        let m = self.U.nrows();
262        let n = self.Vt.ncols();
263
264        // unwrap or populate on the first call
265        let blaswork = self.blas.get_or_insert_with(SVDBlasWorkVectors::default);
266
267        let job = b'S'; // compact.
268        let m = m.try_into().unwrap();
269        let n = n.try_into().unwrap();
270        let a = A.data_mut();
271        let lda = m;
272        let s = &mut self.s; // singular values go here
273        let u = self.U.data_mut(); // U data goes here
274        let ldu = m; // leading dim of U
275        let vt = self.Vt.data_mut(); // Vt data goes here
276        let ldvt = min(m, n); // leading dim of Vt
277        let work = &mut blaswork.work;
278        let mut lwork = -1_i32; // -1 => config to request required work size
279        let iwork = &mut blaswork.iwork;
280        let info = &mut 0_i32; // output info
281
282        for i in 0..2 {
283            // iwork is only used for the DivideAndConquer BLAS call
284            // and should always be 8*min(m,n) elements in that case.
285            // This will *not* shrink iwork in the case that the engine's
286            // algorithm is switched back and forth
287            if self.algorithm == SVDEngineAlgorithm::DivideAndConquer {
288                iwork.resize(8 * min(m, n) as usize, 0);
289            }
290
291            match self.algorithm {
292                SVDEngineAlgorithm::DivideAndConquer => T::xgesdd(
293                    job, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, iwork, info,
294                ),
295                SVDEngineAlgorithm::QRDecomposition => T::xgesvd(
296                    job, job, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, info,
297                ),
298            }
299            if *info != 0 {
300                return Err(DenseFactorizationError::SVD(*info));
301            }
302
303            // resize work vector and reset length
304            if i == 0 {
305                lwork = work[0].to_i32().unwrap();
306                work.resize(lwork as usize, T::zero());
307            }
308        }
309        Ok(())
310    }
311
312    fn solveblas<S>(&mut self, B: &mut DenseStorageMatrix<S, T>)
313    where
314        S: AsMut<[T]> + AsRef<[T]>,
315    {
316        // get the dimensions for the SVD factors
317        let m = self.U.nrows();
318        let n = self.Vt.ncols();
319        let k = min(m, n); //number of singular values
320
321        // the number of columns in B
322        let nrhs = B.ncols();
323
324        // compute a tolerance for the singular values
325        // to be considered invertible
326        let tol = T::epsilon() * self.s[0].abs() * T::from(k).unwrap();
327
328        // unwrap or populate on the first call
329        let blaswork = self.blas.get_or_insert_with(SVDBlasWorkVectors::default);
330
331        // will compute B <- Vt * (Σ^-1 * (U^T * B))
332        // we need a workspace that is at least nrhs * k
333        // to hold the product C = U^T * B.  Will also
334        // allocate additional space to hold the inverted
335        // singular values
336        blaswork.work.resize(k + k * nrhs, T::zero());
337        let (sinv, workC) = blaswork.work.split_at_mut(k);
338
339        // C <- U^T * B
340        let mut C = BorrowedMatrixMut::from_slice_mut(workC, k, nrhs);
341        C.mul(&self.U.t(), B, T::one(), T::zero());
342
343        // C <- Σ^-1 * C
344        zip(sinv.iter_mut(), self.s.iter()).for_each(|(sinv, s)| {
345            if s.abs() > tol {
346                *sinv = T::recip(s.abs());
347            } else {
348                *sinv = T::zero();
349            }
350        });
351
352        for col in 0..nrhs {
353            C.col_slice_mut(col).hadamard(sinv);
354        }
355
356        // B <- V * C
357        B.mul(&self.Vt.t(), &C, T::one(), T::zero());
358    }
359}
360
// ---- unit testing ----

#[cfg(test)]
mod test {
    use super::*;

    /// The LAPACK drivers exercised by every test.  Only matrices that are
    /// not handled by the closed-form 1x1/2x2/3x3 kernels are actually
    /// factored with the selected driver.
    const METHODS: [SVDEngineAlgorithm; 2] = [
        SVDEngineAlgorithm::DivideAndConquer,
        SVDEngineAlgorithm::QRDecomposition,
    ];

    /// Symmetric 2x2 system `A * X = B`, returned as `(A, X, B)`.
    fn test_solve_data_2x2<T: FloatT>() -> (Matrix<T>, Matrix<T>, Matrix<T>) {
        let A = Matrix::<T>::from(&[
            [(4.0).as_T(), (1.0).as_T()],
            [(1.0).as_T(), (3.0).as_T()],
        ]);

        // solution matrix X with 2 columns
        let X = Matrix::<T>::from(&[
            [(2.0).as_T(), (3.0).as_T()],
            [(1.0).as_T(), (2.0).as_T()],
        ]);

        // right-hand side B = A * X
        let B = Matrix::<T>::from(&[
            [(9.0).as_T(), (14.0).as_T()],
            [(5.0).as_T(), (9.0).as_T()],
        ]);

        (A, X, B)
    }

    /// Symmetric 3x3 system `A * X = B`, returned as `(A, X, B)`.
    #[rustfmt::skip]
    fn test_solve_data_3x3<T: FloatT>() -> (Matrix<T>, Matrix<T>, Matrix<T>) {
        let A = Matrix::<T>::from(&[
            [(8.0).as_T(),  (-2.0).as_T(), (4.0).as_T()],
            [(-2.0).as_T(), (12.0).as_T(), (2.0).as_T()],
            [(4.0).as_T(),  (2.0).as_T(),  (6.0).as_T()],
        ]);

        // solution matrix X with 2 columns
        let X = Matrix::<T>::from(&[
            [(1.0).as_T(), (2.0).as_T()],
            [(3.0).as_T(), (4.0).as_T()],
            [(5.0).as_T(), (6.0).as_T()],
        ]);

        // right-hand side B = A * X
        let B = Matrix::<T>::from(&[
            [(22.0).as_T(), (32.0).as_T()],
            [(44.0).as_T(), (56.0).as_T()],
            [(40.0).as_T(), (52.0).as_T()],
        ]);

        (A, X, B)
    }

    /// Symmetric 4x4 system `A * X = B`, returned as `(A, X, B)`.
    #[rustfmt::skip]
    fn test_solve_data_4x4<T: FloatT>() -> (Matrix<T>, Matrix<T>, Matrix<T>) {
        let A = Matrix::<T>::from(&[
            [(10.0).as_T(), (2.0).as_T(),  (3.0).as_T(),  (1.0).as_T()],
            [(2.0).as_T(),  (8.0).as_T(),  (0.0).as_T(),  (3.0).as_T()],
            [(3.0).as_T(),  (0.0).as_T(),  (6.0).as_T(),  (2.0).as_T()],
            [(1.0).as_T(),  (3.0).as_T(),  (2.0).as_T(),  (9.0).as_T()],
        ]);

        // solution matrix X with 2 columns
        let X = Matrix::<T>::from(&[
            [(1.0).as_T(), (2.0).as_T()],
            [(2.0).as_T(), (3.0).as_T()],
            [(3.0).as_T(), (1.0).as_T()],
            [(4.0).as_T(), (2.0).as_T()],
        ]);

        // right-hand side B = A * X
        let B = Matrix::<T>::from(&[
            [(27.0).as_T(), (31.0).as_T()],
            [(30.0).as_T(), (34.0).as_T()],
            [(29.0).as_T(), (16.0).as_T()],
            [(49.0).as_T(), (31.0).as_T()],
        ]);

        (A, X, B)
    }

    /// Factors a copy of `A` with each driver, solves in place against a
    /// copy of `B`, and checks that the result matches `X` to within
    /// `tolfn(1e-10)`.
    fn run_svd_solve_test<T>(A: &Matrix<T>, X: &Matrix<T>, B: &Matrix<T>, tolfn: fn(T) -> T)
    where
        T: FloatT,
    {
        use crate::algebra::VectorMath;

        for &method in METHODS.iter() {
            // A and B are modified in place during factor/solve
            let mut thisA = A.clone();
            let mut thisB = B.clone();

            let mut eng = SVDEngine::<T>::new(thisA.size());
            eng.algorithm = method;

            assert!(eng.factor(&mut thisA).is_ok());
            eng.solve(&mut thisB);

            assert!(thisB.data().norm_inf_diff(X.data()) < tolfn(1e-10.as_T()));
        }
    }

    macro_rules! generate_test_svd_solve {
        ($fxx:ty, $test_name:ident, $tolfn:ident) => {
            #[test]
            fn $test_name() {
                let cases = [
                    test_solve_data_2x2::<$fxx>(),
                    test_solve_data_3x3::<$fxx>(),
                    test_solve_data_4x4::<$fxx>(),
                ];
                for (A, X, B) in cases.iter() {
                    run_svd_solve_test(A, X, B, |x| x.$tolfn());
                }
            }
        };
    }

    generate_test_svd_solve!(f32, test_svd_solve_f32, sqrt);
    generate_test_svd_solve!(f64, test_svd_solve_f64, abs);

    fn test_factor_data_2x2<T: FloatT>() -> Matrix<T> {
        test_solve_data_2x2::<T>().0
    }

    fn test_factor_data_3x3<T: FloatT>() -> Matrix<T> {
        test_solve_data_3x3::<T>().0
    }

    fn test_factor_data_4x4<T: FloatT>() -> Matrix<T> {
        test_solve_data_4x4::<T>().0
    }

    #[rustfmt::skip]
    fn test_factor_data_2x4<T: FloatT>() -> Matrix<T> {
        Matrix::<T>::from(&[
            [(10.0).as_T(), (2.0).as_T(),  (3.0).as_T(),  (1.0).as_T()],
            [(2.0).as_T(),  (8.0).as_T(),  (0.0).as_T(),  (3.0).as_T()],
        ])
    }

    #[rustfmt::skip]
    fn test_factor_data_4x2<T: FloatT>() -> Matrix<T> {
        Matrix::<T>::from(&[
            [(10.0).as_T(), (2.0).as_T()],
            [(2.0).as_T(),  (8.0).as_T()],
            [(3.0).as_T(),  (1.0).as_T()],
            [(0.0).as_T(),  (3.0).as_T()],
        ])
    }

    /// `true` if `s` is sorted in non-increasing order.
    fn is_descending_order<T: FloatT>(s: &[T]) -> bool {
        // slice::is_sorted is only available from Rust 1.82
        s.windows(2).all(|w| w[0] >= w[1])
    }

    /// Factors a fresh copy of `A` with each driver and checks two things:
    /// - the singular values come out in descending order;
    /// - `U * diag(s) * Vt` reproduces `A` to within `tolfn(1e-10)`.
    fn run_svd_factor_test<T>(A: &Matrix<T>, tolfn: fn(T) -> T)
    where
        T: FloatT,
    {
        use crate::algebra::{DenseMatrix, MultiplyGEMM, VectorMath};

        for &method in METHODS.iter() {
            // The factorization may overwrite its input, so every driver
            // must start from its own copy of the original data.  Reusing
            // the same matrix would test later drivers on garbage.
            let mut thisA = A.clone();

            let mut eng = SVDEngine::<T>::new(thisA.size());
            eng.algorithm = method;

            assert!(eng.factor(&mut thisA).is_ok());
            assert!(is_descending_order(&eng.s));

            let mut M = Matrix::<T>::zeros((1, 1));
            M.resize(A.size()); // manual resize for test coverage

            // reconstruct A = (U * diag(s)) * Vt
            let mut Us = eng.U.clone();
            for (c, &sc) in eng.s.iter().enumerate() {
                for r in 0..Us.nrows() {
                    Us[(r, c)] *= sc;
                }
            }
            M.mul(&Us, &eng.Vt, T::one(), T::zero());
            assert!(M.data().norm_inf_diff(A.data()) < tolfn((1e-10).as_T()));
        }
    }

    macro_rules! generate_test_svd_factor {
        ($fxx:ty, $test_name:ident, $tolfn:ident) => {
            #[test]
            fn $test_name() {
                let cases = [
                    test_factor_data_2x2::<$fxx>(),
                    test_factor_data_3x3::<$fxx>(),
                    test_factor_data_4x4::<$fxx>(),
                    test_factor_data_2x4::<$fxx>(),
                    test_factor_data_4x2::<$fxx>(),
                ];
                for A in cases.iter() {
                    run_svd_factor_test(A, |x| x.$tolfn());
                }
            }
        };
    }

    generate_test_svd_factor!(f32, test_svd_factor_f32, sqrt);
    generate_test_svd_factor!(f64, test_svd_factor_f64, abs);
}
594
595
596
#[cfg(all(test, feature = "bench"))]
mod bench {
    use super::*;

    /// Yields every 3x3 matrix whose entries come from a fixed set of five
    /// values (5^9 matrices in total).
    fn svd3_bench_iter() -> impl Iterator<Item = Matrix<f64>> {
        use itertools::iproduct;

        const VALS: [f64; 5] = [-4., -2., 0., 1., 5.];

        iproduct!(VALS, VALS, VALS, VALS, VALS, VALS, VALS, VALS, VALS).map(
            |(a, b, c, d, e, f, g, h, i)| {
                let entries = [a, b, c, d, e, f, g, h, i];
                Matrix::new_from_slice((3, 3), &entries)
            },
        )
    }

    /// Runs the closed-form 3x3 kernel and then the general LAPACK path
    /// over identical inputs.  Timing is presumably taken externally by
    /// the bench harness.
    #[test]
    fn bench_svd3_vs_blas() {
        let mut eng = SVDEngine::<f64>::new((3, 3));

        svd3_bench_iter().for_each(|mut mat| eng.factor3(&mut mat).unwrap());
        svd3_bench_iter().for_each(|mut mat| eng.factorblas(&mut mat).unwrap());
    }
}
629