rlst/dense/linalg/lapack/interface/gesvd.rs
1//! Implementation of ?gesvd - SVD factorization
2
3use lapack::{cgesvd, dgesvd, sgesvd, zgesvd};
4
5use crate::base_types::{LapackError, c32, c64};
6use crate::{base_types::LapackResult, traits::rlst_num::RlstScalar};
7
8use crate::dense::linalg::lapack::interface::lapack_return;
9
10use num::{Zero, complex::ComplexFloat};
11
/// Selects which left singular vectors (columns of `U`) [`Gesvd::gesvd`] computes.
///
/// The discriminant of each variant is the ASCII job character handed to LAPACK `?gesvd`
/// (`JOBU`). LAPACK's `'O'` mode (overwriting `a` with `U`) is not exposed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum JobU {
    /// Return all `m` columns of `U`; `u` must hold `ldu * m` entries.
    A = b'A',
    /// Return the first `min(m, n)` columns of `U`; `u` must hold `ldu * min(m, n)` entries.
    S = b'S',
    /// Do not compute `U`; `u` is not referenced and may be `None`.
    N = b'N',
}
23
/// Selects which right singular vectors (rows of `V^T`, or `V^H` for complex types)
/// [`Gesvd::gesvd`] computes.
///
/// The discriminant of each variant is the ASCII job character handed to LAPACK `?gesvd`
/// (`JOBVT`). LAPACK's `'O'` mode (overwriting `a` with `V^T`) is not exposed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum JobVt {
    /// Return all `n` rows of `V^T`; `vt` must hold `ldvt * n` entries, `ldvt >= max(1, n)`.
    A = b'A',
    /// Return the first `min(m, n)` rows of `V^T`; `vt` must hold `ldvt * n` entries,
    /// `ldvt >= max(1, min(m, n))`.
    S = b'S',
    /// Do not compute `V^T`; `vt` is not referenced and may be `None`.
    N = b'N',
}
35
36/// ?gesvd - SVD factorization
37pub trait Gesvd: RlstScalar {
38 /// Perform a singular value decomposition (SVD) of a matrix `a` with dimensions `m` x `n`.
39 /// If either `jobu` or `jobvt` is `JobU::N` or `JobVt::N`, the corresponding singular vectors
40 /// are not computed, and the array u or correspondingly vt is not referenced and can be
41 /// `None`.
42 #[allow(clippy::too_many_arguments)]
43 fn gesvd(
44 jobu: JobU,
45 jobvt: JobVt,
46 m: usize,
47 n: usize,
48 a: &mut [Self],
49 lda: usize,
50 s: &mut [Self::Real],
51 u: Option<&mut [Self]>,
52 ldu: usize,
53 vt: Option<&mut [Self]>,
54 ldvt: usize,
55 ) -> LapackResult<()>;
56}
57
// Implements `Gesvd` for a real scalar type (`f32`/`f64`) on top of the `lapack` crate routine
// `$gesvd` (`sgesvd`/`dgesvd`).
//
// The generated method validates buffer lengths and leading dimensions (panicking on violation),
// performs a LAPACK workspace query (`lwork = -1`), allocates the reported workspace and then
// runs the factorization.
macro_rules! implement_gesvd {
    ($scalar:ty, $gesvd:expr) => {
        impl Gesvd for $scalar {
            fn gesvd(
                jobu: JobU,
                jobvt: JobVt,
                m: usize,
                n: usize,
                a: &mut [Self],
                lda: usize,
                s: &mut [Self::Real],
                u: Option<&mut [Self]>,
                ldu: usize,
                vt: Option<&mut [Self]>,
                ldvt: usize,
            ) -> LapackResult<()> {
                // LAPACK takes 32-bit integers; reject values an `as` cast would silently
                // truncate instead of handing LAPACK wrong dimensions.
                let to_i32 = |value: usize, name: &str| -> i32 {
                    i32::try_from(value).unwrap_or_else(|_| {
                        panic!("Require `{}` {} to fit into a LAPACK integer.", name, value)
                    })
                };

                assert_eq!(
                    a.len(),
                    lda * n,
                    "Require `a.len()` {} == `lda * n` {}.",
                    a.len(),
                    lda * n
                );

                assert!(
                    lda >= std::cmp::max(1, m),
                    "Require `lda` {} >= `max(1, m)` {}.",
                    lda,
                    std::cmp::max(1, m)
                );

                let k = std::cmp::min(m, n);

                assert_eq!(
                    s.len(),
                    k,
                    "Require `s.len()` {} == `min(m, n)` {}.",
                    s.len(),
                    k,
                );

                // Placeholder buffers handed to LAPACK for `JobU::N` / `JobVt::N`, where LAPACK
                // does not reference `u` / `vt`.
                let mut u_temp = Vec::<$scalar>::new();
                let mut vt_temp = Vec::<$scalar>::new();

                let u = match jobu {
                    JobU::A => {
                        let u = u.expect("JobU::A requires u to be Some");
                        assert_eq!(
                            u.len(),
                            ldu * m,
                            "Require `u.len()` {} == `ldu * m` {}.",
                            u.len(),
                            ldu * m
                        );
                        assert!(
                            ldu >= std::cmp::max(1, m),
                            "Require `ldu` {} >= `max(1, m)` {}.",
                            ldu,
                            std::cmp::max(1, m)
                        );
                        u
                    }
                    JobU::S => {
                        let u = u.expect("JobU::S requires u to be Some");
                        assert_eq!(
                            u.len(),
                            ldu * k,
                            "Require `u.len()` {} == `ldu * min(m, n)` {}.",
                            u.len(),
                            ldu * k
                        );
                        assert!(
                            ldu >= std::cmp::max(1, m),
                            "Require `ldu` {} >= `max(1, m)` {}.",
                            ldu,
                            std::cmp::max(1, m)
                        );
                        u
                    }
                    JobU::N => {
                        assert!(ldu >= 1, "Require `ldu` {} >= 1.", ldu);
                        u_temp.as_mut_slice()
                    }
                };

                let vt = match jobvt {
                    JobVt::A => {
                        let vt = vt.expect("JobVt::A requires vt to be Some");
                        assert_eq!(
                            vt.len(),
                            ldvt * n,
                            "Require `vt.len()` {} == `ldvt * n` {}.",
                            vt.len(),
                            ldvt * n
                        );
                        assert!(
                            ldvt >= std::cmp::max(1, n),
                            "Require `ldvt` {} >= `max(1, n)` {}.",
                            ldvt,
                            std::cmp::max(1, n)
                        );
                        vt
                    }
                    JobVt::S => {
                        let vt = vt.expect("JobVt::S requires vt to be Some");
                        assert_eq!(
                            vt.len(),
                            ldvt * n,
                            "Require `vt.len()` {} == `ldvt * n` {}.",
                            vt.len(),
                            ldvt * n
                        );
                        assert!(
                            ldvt >= std::cmp::max(1, k),
                            "Require `ldvt` {} >= `max(1, min(m, n))` {}.",
                            ldvt,
                            std::cmp::max(1, k)
                        );
                        vt
                    }
                    JobVt::N => {
                        assert!(ldvt >= 1, "Require `ldvt` {} >= 1.", ldvt);
                        vt_temp.as_mut_slice()
                    }
                };

                let m_i32 = to_i32(m, "m");
                let n_i32 = to_i32(n, "n");
                let lda_i32 = to_i32(lda, "lda");
                let ldu_i32 = to_i32(ldu, "ldu");
                let ldvt_i32 = to_i32(ldvt, "ldvt");

                let mut info = 0;

                // Workspace query: with `lwork = -1` LAPACK only reports the optimal workspace
                // size in `work[0]`.
                let mut work = vec![<$scalar>::zero(); 1];

                // SAFETY: every slice was length-checked above against the dimensions passed
                // here; placeholder slices are only used in modes where LAPACK ignores them.
                unsafe {
                    $gesvd(
                        jobu as u8,
                        jobvt as u8,
                        m_i32,
                        n_i32,
                        a,
                        lda_i32,
                        s,
                        u,
                        ldu_i32,
                        vt,
                        ldvt_i32,
                        &mut work,
                        -1,
                        &mut info,
                    );
                }

                if info != 0 {
                    return Err(LapackError::LapackInfoCode(info));
                }

                // LAPACK reports the size as a floating-point value; never request less than the
                // documented minimum of one element.
                let lwork = std::cmp::max(1, work[0].re() as i32);

                let mut work = vec![<$scalar>::zero(); lwork as usize];

                // SAFETY: same argument checks as above; `work` holds exactly `lwork` elements.
                unsafe {
                    $gesvd(
                        jobu as u8,
                        jobvt as u8,
                        m_i32,
                        n_i32,
                        a,
                        lda_i32,
                        s,
                        u,
                        ldu_i32,
                        vt,
                        ldvt_i32,
                        &mut work,
                        lwork,
                        &mut info,
                    );
                }

                lapack_return(info, ())
            }
        }
    };
}
241
// Implements `Gesvd` for a complex scalar type (`c32`/`c64`) on top of the `lapack` crate
// routine `$gesvd` (`cgesvd`/`zgesvd`).
//
// Same flow as the real variant (validate, workspace query, factorize), plus the real-valued
// `rwork` buffer of `5 * min(m, n)` entries that the complex LAPACK drivers require.
macro_rules! implement_gesvd_complex {
    ($scalar:ty, $gesvd:expr) => {
        impl Gesvd for $scalar {
            fn gesvd(
                jobu: JobU,
                jobvt: JobVt,
                m: usize,
                n: usize,
                a: &mut [Self],
                lda: usize,
                s: &mut [Self::Real],
                u: Option<&mut [Self]>,
                ldu: usize,
                vt: Option<&mut [Self]>,
                ldvt: usize,
            ) -> LapackResult<()> {
                // LAPACK takes 32-bit integers; reject values an `as` cast would silently
                // truncate instead of handing LAPACK wrong dimensions.
                let to_i32 = |value: usize, name: &str| -> i32 {
                    i32::try_from(value).unwrap_or_else(|_| {
                        panic!("Require `{}` {} to fit into a LAPACK integer.", name, value)
                    })
                };

                assert_eq!(
                    a.len(),
                    lda * n,
                    "Require `a.len()` {} == `lda * n` {}.",
                    a.len(),
                    lda * n
                );

                assert!(
                    lda >= std::cmp::max(1, m),
                    "Require `lda` {} >= `max(1, m)` {}.",
                    lda,
                    std::cmp::max(1, m)
                );

                let k = std::cmp::min(m, n);

                assert_eq!(
                    s.len(),
                    k,
                    "Require `s.len()` {} == `min(m, n)` {}.",
                    s.len(),
                    k,
                );

                // Real workspace required by the complex drivers: `5 * min(m, n)` entries.
                let mut rwork = vec![<<$scalar as RlstScalar>::Real as Zero>::zero(); 5 * k];

                // Placeholder buffers handed to LAPACK for `JobU::N` / `JobVt::N`, where LAPACK
                // does not reference `u` / `vt`.
                let mut u_temp = Vec::<$scalar>::new();
                let mut vt_temp = Vec::<$scalar>::new();

                let u = match jobu {
                    JobU::A => {
                        let u = u.expect("JobU::A requires u to be Some");
                        assert_eq!(
                            u.len(),
                            ldu * m,
                            "Require `u.len()` {} == `ldu * m` {}.",
                            u.len(),
                            ldu * m
                        );
                        assert!(
                            ldu >= std::cmp::max(1, m),
                            "Require `ldu` {} >= `max(1, m)` {}.",
                            ldu,
                            std::cmp::max(1, m)
                        );
                        u
                    }
                    JobU::S => {
                        let u = u.expect("JobU::S requires u to be Some");
                        assert_eq!(
                            u.len(),
                            ldu * k,
                            "Require `u.len()` {} == `ldu * min(m, n)` {}.",
                            u.len(),
                            ldu * k
                        );
                        assert!(
                            ldu >= std::cmp::max(1, m),
                            "Require `ldu` {} >= `max(1, m)` {}.",
                            ldu,
                            std::cmp::max(1, m)
                        );
                        u
                    }
                    JobU::N => {
                        assert!(ldu >= 1, "Require `ldu` {} >= 1.", ldu);
                        u_temp.as_mut_slice()
                    }
                };

                let vt = match jobvt {
                    JobVt::A => {
                        let vt = vt.expect("JobVt::A requires vt to be Some");
                        assert_eq!(
                            vt.len(),
                            ldvt * n,
                            "Require `vt.len()` {} == `ldvt * n` {}.",
                            vt.len(),
                            ldvt * n
                        );
                        assert!(
                            ldvt >= std::cmp::max(1, n),
                            "Require `ldvt` {} >= `max(1, n)` {}.",
                            ldvt,
                            std::cmp::max(1, n)
                        );
                        vt
                    }
                    JobVt::S => {
                        let vt = vt.expect("JobVt::S requires vt to be Some");
                        assert_eq!(
                            vt.len(),
                            ldvt * n,
                            "Require `vt.len()` {} == `ldvt * n` {}.",
                            vt.len(),
                            ldvt * n
                        );
                        assert!(
                            ldvt >= std::cmp::max(1, k),
                            "Require `ldvt` {} >= `max(1, min(m, n))` {}.",
                            ldvt,
                            std::cmp::max(1, k)
                        );
                        vt
                    }
                    JobVt::N => {
                        assert!(ldvt >= 1, "Require `ldvt` {} >= 1.", ldvt);
                        vt_temp.as_mut_slice()
                    }
                };

                let m_i32 = to_i32(m, "m");
                let n_i32 = to_i32(n, "n");
                let lda_i32 = to_i32(lda, "lda");
                let ldu_i32 = to_i32(ldu, "ldu");
                let ldvt_i32 = to_i32(ldvt, "ldvt");

                let mut info = 0;

                // Workspace query: with `lwork = -1` LAPACK only reports the optimal workspace
                // size in `work[0]` (as the real part).
                let mut work = vec![<$scalar>::zero(); 1];

                // SAFETY: every slice was length-checked above against the dimensions passed
                // here; `rwork` holds `5 * min(m, n)` entries; placeholder slices are only used
                // in modes where LAPACK ignores them.
                unsafe {
                    $gesvd(
                        jobu as u8,
                        jobvt as u8,
                        m_i32,
                        n_i32,
                        a,
                        lda_i32,
                        s,
                        u,
                        ldu_i32,
                        vt,
                        ldvt_i32,
                        &mut work,
                        -1,
                        &mut rwork,
                        &mut info,
                    );
                }

                if info != 0 {
                    return Err(LapackError::LapackInfoCode(info));
                }

                // LAPACK reports the size as a floating-point value; never request less than the
                // documented minimum of one element.
                let lwork = std::cmp::max(1, work[0].re() as i32);

                let mut work = vec![<$scalar>::zero(); lwork as usize];

                // SAFETY: same argument checks as above; `work` holds exactly `lwork` elements.
                unsafe {
                    $gesvd(
                        jobu as u8,
                        jobvt as u8,
                        m_i32,
                        n_i32,
                        a,
                        lda_i32,
                        s,
                        u,
                        ldu_i32,
                        vt,
                        ldvt_i32,
                        &mut work,
                        lwork,
                        &mut rwork,
                        &mut info,
                    );
                }

                lapack_return(info, ())
            }
        }
    };
}
429
430implement_gesvd!(f32, sgesvd);
431implement_gesvd!(f64, dgesvd);
432implement_gesvd_complex!(c32, cgesvd);
433implement_gesvd_complex!(c64, zgesvd);