mdarray_linalg_lapack/svd/
context.rs1use super::simple::gsvd;
11use mdarray_linalg::{get_dims, into_i32};
12
13use mdarray::{DSlice, DTensor, Dense, Layout, tensor};
14
15use super::scalar::{LapackScalar, NeedsRwork};
16use mdarray_linalg::svd::{SVD, SVDDecomp, SVDError};
17use num_complex::ComplexFloat;
18
19use crate::Lapack;
20
21impl<T> SVD<T> for Lapack
22where
23 T: ComplexFloat + Default + LapackScalar + NeedsRwork,
24 T::Real: Into<T>,
25{
26 fn svd<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<SVDDecomp<T>, SVDError> {
28 let (m, n) = get_dims!(a);
29 let min_mn = m.min(n);
30
31 let mut s = tensor![[T::default(); min_mn as usize]; min_mn as usize];
32 let mut u = tensor![[T::default(); m as usize]; m as usize];
33 let mut vt = tensor![[T::default(); n as usize]; n as usize];
34
35 match gsvd(a, &mut s, Some(&mut u), Some(&mut vt), self.svd_config) {
36 Ok(_) => Ok(SVDDecomp { s, u, vt }),
37 Err(e) => Err(e),
38 }
39 }
40
41 fn svd_s<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<DTensor<T, 2>, SVDError> {
43 let (m, n) = get_dims!(a);
44 let min_mn = m.min(n);
45
46 let mut s = tensor![[T::default(); min_mn as usize]; min_mn as usize];
48
49 match gsvd::<L, Dense, Dense, Dense, T>(a, &mut s, None, None, self.svd_config) {
50 Ok(_) => Ok(s),
51 Err(err) => Err(err),
52 }
53 }
54
55 fn svd_overwrite<L: Layout, Ls: Layout, Lu: Layout, Lvt: Layout>(
57 &self,
58 a: &mut DSlice<T, 2, L>,
59 s: &mut DSlice<T, 2, Ls>,
60 u: &mut DSlice<T, 2, Lu>,
61 vt: &mut DSlice<T, 2, Lvt>,
62 ) -> Result<(), SVDError> {
63 gsvd(a, s, Some(u), Some(vt), self.svd_config)
64 }
65
66 fn svd_overwrite_s<L: Layout, Ls: Layout>(
68 &self,
69 a: &mut DSlice<T, 2, L>,
70 s: &mut DSlice<T, 2, Ls>,
71 ) -> Result<(), SVDError> {
72 gsvd::<L, Ls, Dense, Dense, T>(a, s, None, None, self.svd_config)
73 }
74}