rlst/dense/linalg/lapack/singular_value_decomposition.rs

1//! Implementation of the singular value decomposition using LAPACK.
2
3use crate::UnsafeRandom1DAccessByValue;
4use crate::base_types::RlstResult;
5use crate::dense::array::{Array, DynArray};
6use crate::dense::linalg::lapack::interface::gesdd::JobZ;
7use crate::traits::base_operations::Shape;
8use crate::traits::linalg::base::Gemm;
9use crate::traits::linalg::decompositions::SingularValueDecomposition;
10use crate::traits::linalg::lapack::Lapack;
11use crate::traits::rlst_num::RlstScalar;
12
/// Singular value decomposition mode.
///
/// Selects how many singular vectors are computed for an `m x n` matrix
/// `A = U * S * V^H`, where `k = min(m, n)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SvdMode {
    /// Compute full matrices: `U` is `m x m` and `V^H` is `n x n`.
    Full,
    /// Compute compact (economy-size) matrices: `U` is `m x k` and `V^H` is `k x n`.
    Compact,
}
21
22impl<Item, ArrayImpl> SingularValueDecomposition for Array<ArrayImpl, 2>
23where
24    Item: Lapack + Gemm,
25    ArrayImpl: UnsafeRandom1DAccessByValue<Item = Item> + Shape<2>,
26{
27    type Item = Item;
28
29    fn singular_values(&self) -> RlstResult<DynArray<<Self::Item as RlstScalar>::Real, 1>> {
30        let mut a = DynArray::new_from(self);
31        let [m, n] = a.shape();
32        let k = std::cmp::min(m, n);
33
34        let mut s = DynArray::<<Self::Item as RlstScalar>::Real, 1>::from_shape([k]);
35
36        Item::gesdd(
37            JobZ::N,
38            m,
39            n,
40            a.data_mut().unwrap(),
41            m,
42            s.data_mut().unwrap(),
43            None,
44            1,
45            None,
46            1,
47        )?;
48
49        Ok(s)
50    }
51
52    fn svd(
53        &self,
54        mode: SvdMode,
55    ) -> RlstResult<(
56        DynArray<<Self::Item as RlstScalar>::Real, 1>,
57        DynArray<Self::Item, 2>,
58        DynArray<Self::Item, 2>,
59    )> {
60        let mut a = DynArray::new_from(self);
61        let [m, n] = a.shape();
62        let k = std::cmp::min(m, n);
63        let mut s = DynArray::<<Self::Item as RlstScalar>::Real, 1>::from_shape([k]);
64        let (mut u, mut vt, ldvt) = match mode {
65            SvdMode::Full => (
66                DynArray::<Self::Item, 2>::from_shape([m, m]),
67                DynArray::<Self::Item, 2>::from_shape([n, n]),
68                n,
69            ),
70            SvdMode::Compact => (
71                DynArray::<Self::Item, 2>::from_shape([m, k]),
72                DynArray::<Self::Item, 2>::from_shape([k, n]),
73                k,
74            ),
75        };
76
77        let jobz = match mode {
78            SvdMode::Full => JobZ::A,
79            SvdMode::Compact => JobZ::S,
80        };
81
82        Item::gesdd(
83            jobz,
84            m,
85            n,
86            a.data_mut().unwrap(),
87            m,
88            s.data_mut().unwrap(),
89            u.data_mut(),
90            m,
91            vt.data_mut(),
92            ldvt,
93        )?;
94
95        Ok((s, u, vt))
96    }
97}
98
#[cfg(test)]
mod test {

    use super::*;
    use crate::base_types::{c32, c64};
    use crate::dense::array::DynArray;
    use crate::dot;
    use crate::traits::base_operations::*;
    use crate::traits::linalg::SymmEig;
    use itertools::izip;

    use paste::paste;

    /// Generate SVD tests for one scalar type with the given tolerance.
    ///
    /// Checks `singular_values` against the square roots of the eigenvalues of `A^H A`,
    /// and checks that `U * S * V^H` reproduces `A` for tall (10 x 5) and wide (5 x 10)
    /// matrices in both `SvdMode::Compact` and `SvdMode::Full`.
    macro_rules! implement_svd_tests {
        ($scalar:ty, $tol:expr) => {
            paste! {

            /// Place `sigma` on the main diagonal of a freshly allocated matrix of the
            /// given shape, converting each real value into the scalar type.
            fn [<sigma_matrix_$scalar>](
                sigma: DynArray<<$scalar as RlstScalar>::Real, 1>,
                shape: [usize; 2],
            ) -> DynArray<$scalar, 2> {
                let mut mat = DynArray::<$scalar, 2>::from_shape(shape);
                for (entry, value) in izip!(mat.diag_iter_mut(), sigma.iter_value()) {
                    *entry = RlstScalar::from_real(value);
                }
                mat
            }

            #[test]
            fn [<test_singular_values_$scalar>]() {
                let (m, n) = (10, 5);
                let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                mat.fill_from_seed_equally_distributed(0);

                let gram = dot!(mat.r().conj().transpose().eval(), mat.r());

                let computed = mat.singular_values().unwrap();

                // Square roots of the Gram-matrix eigenvalues, reversed along axis 0 so the
                // ordering lines up with the values returned by `singular_values`.
                let expected = gram
                    .eigenvaluesh()
                    .unwrap()
                    .unary_op(|v| <<$scalar as RlstScalar>::Real>::sqrt(v))
                    .reverse_axis(0);

                crate::assert_array_relative_eq!(computed, expected, $tol);
            }

            #[test]
            fn [<test_svd_thin_compact_$scalar>]() {
                let (m, n) = (10, 5);
                let k = std::cmp::min(m, n);
                let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                mat.fill_from_seed_equally_distributed(0);

                let (sigma, u, vt) = mat.svd(SvdMode::Compact).unwrap();
                let sigma = [<sigma_matrix_$scalar>](sigma, [k, k]);

                let reconstructed = dot!(u.r(), dot!(sigma.r(), vt.r()));
                crate::assert_array_relative_eq!(reconstructed, mat, $tol);
            }

            #[test]
            fn [<test_svd_thin_full_$scalar>]() {
                let (m, n) = (10, 5);
                let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                mat.fill_from_seed_equally_distributed(0);

                let (sigma, u, vt) = mat.svd(SvdMode::Full).unwrap();
                // Full mode: U is m x m and V^H is n x n, so S must be m x n.
                let sigma = [<sigma_matrix_$scalar>](sigma, [m, n]);

                let reconstructed = dot!(u.r(), dot!(sigma.r(), vt.r()));
                crate::assert_array_relative_eq!(reconstructed, mat, $tol);
            }

            #[test]
            fn [<test_svd_thick_compact_$scalar>]() {
                let (m, n) = (5, 10);
                let k = std::cmp::min(m, n);
                let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                mat.fill_from_seed_equally_distributed(0);

                let (sigma, u, vt) = mat.svd(SvdMode::Compact).unwrap();
                let sigma = [<sigma_matrix_$scalar>](sigma, [k, k]);

                let reconstructed = dot!(u.r(), dot!(sigma.r(), vt.r()));
                crate::assert_array_relative_eq!(reconstructed, mat, $tol);
            }

            #[test]
            fn [<test_svd_thick_full_$scalar>]() {
                let (m, n) = (5, 10);
                let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                mat.fill_from_seed_equally_distributed(0);

                let (sigma, u, vt) = mat.svd(SvdMode::Full).unwrap();
                // Full mode: U is m x m and V^H is n x n, so S must be m x n.
                let sigma = [<sigma_matrix_$scalar>](sigma, [m, n]);

                let reconstructed = dot!(u.r(), dot!(sigma.r(), vt.r()));
                crate::assert_array_relative_eq!(reconstructed, mat, $tol);
            }

            }
        };
    }

    implement_svd_tests!(f32, 1E-4);
    implement_svd_tests!(f64, 1E-10);
    implement_svd_tests!(c32, 1E-4);
    implement_svd_tests!(c64, 1E-10);

    /// Generate pseudo-inverse tests for one scalar type with the given tolerance.
    ///
    /// For a tall matrix it checks `pinv(A) * A = I`; for a wide matrix, `A * pinv(A) = I`.
    /// In both cases `pinv.apply(x)` must agree with multiplying by `pinv.as_matrix()`.
    macro_rules! implement_pinv_tests {
        ($scalar:ty, $tol:expr) => {
            paste! {

            #[test]
            fn [<test_pseudo_inverse_thin_$scalar>]() {
                let (m, n) = (20, 10);
                let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                mat.fill_from_seed_normally_distributed(0);

                let pinv = mat.pseudo_inverse(None, None).unwrap();
                let pinv_mat = pinv.as_matrix();
                assert_eq!(pinv_mat.shape(), [n, m]);

                // Tall matrix with full column rank: pinv(A) is a left inverse.
                let mut eye = DynArray::<$scalar, 2>::from_shape([n, n]);
                eye.set_identity();
                let product = dot!(pinv_mat.r(), mat.r());
                crate::assert_array_abs_diff_eq!(product, eye, $tol);

                let mut rhs = DynArray::<$scalar, 2>::from_shape([m, 2]);
                rhs.fill_from_seed_equally_distributed(1);
                crate::assert_array_relative_eq!(dot!(pinv_mat.r(), rhs.r()), pinv.apply(&rhs), $tol);
            }

            #[test]
            fn [<test_pseudo_inverse_thick_$scalar>]() {
                let (m, n) = (10, 20);
                let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                mat.fill_from_seed_normally_distributed(0);

                let pinv = mat.pseudo_inverse(None, None).unwrap();
                let pinv_mat = pinv.as_matrix();
                assert_eq!(pinv_mat.shape(), [n, m]);

                // Wide matrix with full row rank: pinv(A) is a right inverse.
                let mut eye = DynArray::<$scalar, 2>::from_shape([m, m]);
                eye.set_identity();
                let product = dot!(mat.r(), pinv_mat.r());
                crate::assert_array_abs_diff_eq!(product, eye, $tol);

                let mut rhs = DynArray::<$scalar, 2>::from_shape([m, 2]);
                rhs.fill_from_seed_equally_distributed(1);
                crate::assert_array_relative_eq!(dot!(pinv_mat.r(), rhs.r()), pinv.apply(&rhs), $tol);
            }

            }
        };
    }

    implement_pinv_tests!(f32, 1E-4);
    implement_pinv_tests!(f64, 1E-10);
    implement_pinv_tests!(c32, 1E-4);
    implement_pinv_tests!(c64, 1E-10);
}