1use crate::UnsafeRandom1DAccessByValue;
4use crate::base_types::RlstResult;
5use crate::dense::array::{Array, DynArray};
6use crate::dense::linalg::lapack::interface::gesdd::JobZ;
7use crate::traits::base_operations::Shape;
8use crate::traits::linalg::base::Gemm;
9use crate::traits::linalg::decompositions::SingularValueDecomposition;
10use crate::traits::linalg::lapack::Lapack;
11use crate::traits::rlst_num::RlstScalar;
12
/// Selects which singular-vector factors `SingularValueDecomposition::svd` returns.
///
/// For an `m x n` input with `k = min(m, n)`, the returned factors are shaped as
/// documented on each variant. The singular-value vector always has length `k`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SvdMode {
    /// Full SVD: `U` is `m x m` and `V^H` is `n x n`.
    Full,
    /// Compact (reduced) SVD: `U` is `m x k` and `V^H` is `k x n`.
    Compact,
}
21
22impl<Item, ArrayImpl> SingularValueDecomposition for Array<ArrayImpl, 2>
23where
24 Item: Lapack + Gemm,
25 ArrayImpl: UnsafeRandom1DAccessByValue<Item = Item> + Shape<2>,
26{
27 type Item = Item;
28
29 fn singular_values(&self) -> RlstResult<DynArray<<Self::Item as RlstScalar>::Real, 1>> {
30 let mut a = DynArray::new_from(self);
31 let [m, n] = a.shape();
32 let k = std::cmp::min(m, n);
33
34 let mut s = DynArray::<<Self::Item as RlstScalar>::Real, 1>::from_shape([k]);
35
36 Item::gesdd(
37 JobZ::N,
38 m,
39 n,
40 a.data_mut().unwrap(),
41 m,
42 s.data_mut().unwrap(),
43 None,
44 1,
45 None,
46 1,
47 )?;
48
49 Ok(s)
50 }
51
52 fn svd(
53 &self,
54 mode: SvdMode,
55 ) -> RlstResult<(
56 DynArray<<Self::Item as RlstScalar>::Real, 1>,
57 DynArray<Self::Item, 2>,
58 DynArray<Self::Item, 2>,
59 )> {
60 let mut a = DynArray::new_from(self);
61 let [m, n] = a.shape();
62 let k = std::cmp::min(m, n);
63 let mut s = DynArray::<<Self::Item as RlstScalar>::Real, 1>::from_shape([k]);
64 let (mut u, mut vt, ldvt) = match mode {
65 SvdMode::Full => (
66 DynArray::<Self::Item, 2>::from_shape([m, m]),
67 DynArray::<Self::Item, 2>::from_shape([n, n]),
68 n,
69 ),
70 SvdMode::Compact => (
71 DynArray::<Self::Item, 2>::from_shape([m, k]),
72 DynArray::<Self::Item, 2>::from_shape([k, n]),
73 k,
74 ),
75 };
76
77 let jobz = match mode {
78 SvdMode::Full => JobZ::A,
79 SvdMode::Compact => JobZ::S,
80 };
81
82 Item::gesdd(
83 jobz,
84 m,
85 n,
86 a.data_mut().unwrap(),
87 m,
88 s.data_mut().unwrap(),
89 u.data_mut(),
90 m,
91 vt.data_mut(),
92 ldvt,
93 )?;
94
95 Ok((s, u, vt))
96 }
97}
98
#[cfg(test)]
mod test {

    use super::*;
    use crate::base_types::{c32, c64};
    use crate::dense::array::DynArray;
    use crate::dot;
    use crate::traits::base_operations::*;
    use crate::traits::linalg::SymmEig;
    use itertools::izip;

    use paste::paste;

    // Generates singular-value and SVD reconstruction tests for the scalar type
    // `$scalar`, comparing results with relative tolerance `$tol`.
    macro_rules! implement_svd_tests {
        ($scalar:ty, $tol:expr) => {
            paste! {

                // The singular values of A must match the square roots of the
                // eigenvalues of the Gram matrix A^H A.
                #[test]
                fn [<test_singular_values_$scalar>]() {
                    let m = 10;
                    let n = 5;
                    let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                    mat.fill_from_seed_equally_distributed(0);

                    let gram = dot!(mat.r().conj().transpose().eval(), mat.r());
                    let computed = mat.singular_values().unwrap();

                    // Reversed along its single axis so its ordering matches `computed`.
                    let reference = gram
                        .eigenvaluesh()
                        .unwrap()
                        .unary_op(|v| <<$scalar as RlstScalar>::Real>::sqrt(v))
                        .reverse_axis(0);

                    crate::assert_array_relative_eq!(computed, reference, $tol);
                }

                // Tall input (m > n), compact mode: U (m x k) * S (k x k) * V^H (k x n) == A.
                #[test]
                fn [<test_svd_thin_compact_$scalar>]() {
                    let m = 10;
                    let n = 5;
                    let k = std::cmp::min(m, n);
                    let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                    mat.fill_from_seed_equally_distributed(0);

                    let (sigma, u, vt) = mat.svd(SvdMode::Compact).unwrap();

                    let mut s_mat = DynArray::<$scalar, 2>::from_shape([k, k]);
                    for (diag_entry, value) in izip!(s_mat.diag_iter_mut(), sigma.iter_value()) {
                        *diag_entry = RlstScalar::from_real(value);
                    }

                    let reconstructed = dot!(u.r(), dot!(s_mat.r(), vt.r()));
                    crate::assert_array_relative_eq!(reconstructed, mat, $tol);
                }

                // Tall input (m > n), full mode: U (m x m) * S (m x n) * V^H (n x n) == A.
                #[test]
                fn [<test_svd_thin_full_$scalar>]() {
                    let m = 10;
                    let n = 5;
                    let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                    mat.fill_from_seed_equally_distributed(0);

                    let (sigma, u, vt) = mat.svd(SvdMode::Full).unwrap();

                    let mut s_mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                    for (diag_entry, value) in izip!(s_mat.diag_iter_mut(), sigma.iter_value()) {
                        *diag_entry = RlstScalar::from_real(value);
                    }

                    let reconstructed = dot!(u.r(), dot!(s_mat.r(), vt.r()));
                    crate::assert_array_relative_eq!(reconstructed, mat, $tol);
                }

                // Wide input (m < n), compact mode: U (m x k) * S (k x k) * V^H (k x n) == A.
                #[test]
                fn [<test_svd_thick_compact_$scalar>]() {
                    let m = 5;
                    let n = 10;
                    let k = std::cmp::min(m, n);
                    let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                    mat.fill_from_seed_equally_distributed(0);

                    let (sigma, u, vt) = mat.svd(SvdMode::Compact).unwrap();

                    let mut s_mat = DynArray::<$scalar, 2>::from_shape([k, k]);
                    for (diag_entry, value) in izip!(s_mat.diag_iter_mut(), sigma.iter_value()) {
                        *diag_entry = RlstScalar::from_real(value);
                    }

                    let reconstructed = dot!(u.r(), dot!(s_mat.r(), vt.r()));
                    crate::assert_array_relative_eq!(reconstructed, mat, $tol);
                }

                // Wide input (m < n), full mode: U (m x m) * S (m x n) * V^H (n x n) == A.
                #[test]
                fn [<test_svd_thick_full_$scalar>]() {
                    let m = 5;
                    let n = 10;
                    let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                    mat.fill_from_seed_equally_distributed(0);

                    let (sigma, u, vt) = mat.svd(SvdMode::Full).unwrap();

                    let mut s_mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                    for (diag_entry, value) in izip!(s_mat.diag_iter_mut(), sigma.iter_value()) {
                        *diag_entry = RlstScalar::from_real(value);
                    }

                    let reconstructed = dot!(u.r(), dot!(s_mat.r(), vt.r()));
                    crate::assert_array_relative_eq!(reconstructed, mat, $tol);
                }

            }
        };
    }

    implement_svd_tests!(f32, 1E-4);
    implement_svd_tests!(f64, 1E-10);
    implement_svd_tests!(c32, 1E-4);
    implement_svd_tests!(c64, 1E-10);

    // Generates pseudo-inverse tests for the scalar type `$scalar`. Each test checks
    // the dense pseudo-inverse against the identity, and checks that applying the
    // pseudo-inverse matches multiplying by its dense form.
    macro_rules! implement_pinv_tests {
        ($scalar:ty, $tol:expr) => {
            paste! {

                // Tall input: pinv(A) * A is the n x n identity.
                #[test]
                fn [<test_pseudo_inverse_thin_$scalar>]() {
                    let m = 20;
                    let n = 10;
                    let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                    mat.fill_from_seed_normally_distributed(0);

                    let pinv = mat.pseudo_inverse(None, None).unwrap();
                    let dense_pinv = pinv.as_matrix();
                    assert_eq!(dense_pinv.shape(), [n, m]);

                    let mut eye = DynArray::<$scalar, 2>::from_shape([n, n]);
                    eye.set_identity();

                    let product = dot!(dense_pinv.r(), mat.r());
                    crate::assert_array_abs_diff_eq!(product, eye, $tol);

                    let mut rhs = DynArray::<$scalar, 2>::from_shape([m, 2]);
                    rhs.fill_from_seed_equally_distributed(1);

                    let direct = dot!(dense_pinv.r(), rhs.r());
                    crate::assert_array_relative_eq!(direct, pinv.apply(&rhs), $tol);
                }

                // Wide input: A * pinv(A) is the m x m identity.
                #[test]
                fn [<test_pseudo_inverse_thick_$scalar>]() {
                    let m = 10;
                    let n = 20;
                    let mut mat = DynArray::<$scalar, 2>::from_shape([m, n]);
                    mat.fill_from_seed_normally_distributed(0);

                    let pinv = mat.pseudo_inverse(None, None).unwrap();
                    let dense_pinv = pinv.as_matrix();
                    assert_eq!(dense_pinv.shape(), [n, m]);

                    let mut eye = DynArray::<$scalar, 2>::from_shape([m, m]);
                    eye.set_identity();

                    let product = dot!(mat.r(), dense_pinv.r());
                    crate::assert_array_abs_diff_eq!(product, eye, $tol);

                    let mut rhs = DynArray::<$scalar, 2>::from_shape([m, 2]);
                    rhs.fill_from_seed_equally_distributed(1);

                    let direct = dot!(dense_pinv.r(), rhs.r());
                    crate::assert_array_relative_eq!(direct, pinv.apply(&rhs), $tol);
                }

            }
        };
    }

    implement_pinv_tests!(f32, 1E-4);
    implement_pinv_tests!(f64, 1E-10);
    implement_pinv_tests!(c32, 1E-4);
    implement_pinv_tests!(c64, 1E-10);
}