1use bidiag::BidiagParams;
14use linalg::qr::no_pivoting::factor::QrParams;
15
16use crate::assert;
17use crate::internal_prelude::*;
18
19pub mod bidiag;
21pub(crate) mod bidiag_svd;
22
/// Selects which singular vectors an SVD routine should produce.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ComputeSvdVectors {
	/// The singular vectors are not computed.
	No,
	/// Only the leading `min(nrows, ncols)` singular vectors are computed.
	Thin,
	/// The complete square matrix of singular vectors is computed.
	Full,
}
33
/// Failure modes of the SVD routines.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SvdError {
	/// The bidiagonal SVD could not be completed. This is also reported when the bidiagonal
	/// input contains non-finite values.
	NoConvergence,
}
40
41#[derive(Debug, Copy, Clone)]
43pub struct SvdParams {
44 pub bidiag: BidiagParams,
46 pub qr: QrParams,
48 pub recursion_threshold: usize,
50 pub qr_ratio_threshold: f64,
52
53 #[doc(hidden)]
54 pub non_exhaustive: NonExhaustive,
55}
56
57impl<T: ComplexField> Auto<T> for SvdParams {
58 fn auto() -> Self {
59 Self {
60 recursion_threshold: 128,
61 qr_ratio_threshold: 11.0 / 6.0,
62
63 bidiag: auto!(T),
64 qr: auto!(T),
65 non_exhaustive: NonExhaustive(()),
66 }
67 }
68}
69
70fn svd_imp_scratch<T: ComplexField>(
71 m: usize,
72 n: usize,
73 compute_u: ComputeSvdVectors,
74 compute_v: ComputeSvdVectors,
75
76 bidiag_svd_scratch: fn(n: usize, compute_u: bool, compute_v: bool, par: Par, params: SvdParams) -> StackReq,
77
78 params: SvdParams,
79
80 par: Par,
81) -> StackReq {
82 assert!(m >= n);
83
84 let householder_blocksize = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
85 let bid = temp_mat_scratch::<T>(m, n);
86 let householder_left = temp_mat_scratch::<T>(householder_blocksize, n);
87 let householder_right = temp_mat_scratch::<T>(householder_blocksize, n);
88
89 let compute_bidiag = bidiag::bidiag_in_place_scratch::<T>(m, n, par, params.bidiag.into());
90 let diag = temp_mat_scratch::<T>(n, 1);
91 let subdiag = diag;
92 let compute_ub = compute_v != ComputeSvdVectors::No;
93 let compute_vb = compute_u != ComputeSvdVectors::No;
94 let u_b = temp_mat_scratch::<T>(if compute_ub { n + 1 } else { 2 }, n + 1);
95 let v_b = temp_mat_scratch::<T>(n, if compute_vb { n } else { 0 });
96
97 let compute_bidiag_svd = bidiag_svd_scratch(n, compute_ub, compute_vb, par, params);
98
99 let apply_householder_u = linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
100 m,
101 householder_blocksize,
102 match compute_u {
103 ComputeSvdVectors::No => 0,
104 ComputeSvdVectors::Thin => n,
105 ComputeSvdVectors::Full => m,
106 },
107 );
108 let apply_householder_v = linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
109 n - 1,
110 householder_blocksize,
111 match compute_v {
112 ComputeSvdVectors::No => 0,
113 _ => n,
114 },
115 );
116
117 StackReq::all_of(&[
118 bid,
119 householder_left,
120 householder_right,
121 StackReq::any_of(&[
122 compute_bidiag,
123 StackReq::all_of(&[
124 diag,
125 subdiag,
126 u_b,
127 v_b,
128 StackReq::any_of(&[compute_bidiag_svd, StackReq::all_of(&[apply_householder_u, apply_householder_v])]),
129 ]),
130 ]),
131 ])
132}
133
134fn bidiag_cplx_svd_scratch<T: ComplexField>(n: usize, compute_u: bool, compute_v: bool, par: Par, params: SvdParams) -> StackReq {
135 StackReq::all_of(&[
136 temp_mat_scratch::<T>(n, 1).array(4),
137 temp_mat_scratch::<T::Real>(n + 1, if compute_u { n + 1 } else { 0 }),
138 temp_mat_scratch::<T::Real>(n, if compute_v { n } else { 0 }),
139 bidiag_real_svd_scratch::<T::Real>(n, compute_u, compute_v, par, params),
140 ])
141}
142
143fn bidiag_real_svd_scratch<T: RealField>(n: usize, compute_u: bool, compute_v: bool, par: Par, params: SvdParams) -> StackReq {
144 if n < params.recursion_threshold {
145 StackReq::EMPTY
146 } else {
147 StackReq::all_of(&[
148 temp_mat_scratch::<T>(2, if compute_u { 0 } else { n + 1 }),
149 bidiag_svd::divide_and_conquer_scratch::<T>(n, params.recursion_threshold, compute_u, compute_v, par),
150 ])
151 }
152}
153
/// Computes the SVD of a complex bidiagonal matrix by way of a real one.
///
/// The real problem is built from the magnitudes `|diag[i]|` and `|subdiag[i]|`. While those are
/// extracted, unit-modulus phase factors are propagated down the bidiagonal. Column factors are
/// stored in `col_mul` and row factors in `row_mul`.
///
/// After the real bidiagonal SVD:
/// - the singular values are written back into `diag`;
/// - the phase factors are folded into the singular vectors.
///
/// As allocated by the callers in this module:
/// - `u`, if present, is `(n + 1) x (n + 1)`, matching the `u_real` temporary below;
/// - `v`, if present, is `n x n`.
///
/// # Errors
/// Propagates any error returned by `compute_bidiag_real_svd`.
#[math]
fn compute_bidiag_cplx_svd<T: ComplexField>(
	mut diag: ColMut<'_, T, usize, ContiguousFwd>,
	subdiag: ColMut<'_, T, usize, ContiguousFwd>,
	mut u: Option<MatMut<'_, T>>,
	mut v: Option<MatMut<'_, T>>,
	params: SvdParams,
	par: Par,
	stack: &mut MemStack,
) -> Result<(), SvdError> {
	let n = diag.nrows();

	// real copies of the bidiagonal, plus the real singular-vector buffers.
	// the vector buffers get zero columns when the corresponding vectors are not requested
	let (mut diag_real, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, 1, stack) };
	let (mut subdiag_real, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, 1, stack) };
	let (mut u_real, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n + 1, if u.is_some() { n + 1 } else { 0 }, stack) };
	let (mut v_real, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, if v.is_some() { n } else { 0 }, stack) };

	// unit-modulus phase factors applied to the columns (`col_mul`) and to rows 1..n (`row_mul`)
	let (mut col_mul, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, 1, stack) };
	let (mut row_mul, stack) = unsafe { temp_mat_uninit::<T, _, _>(n - 1, 1, stack) };

	// only hand the real vector buffers to the real solver when the caller asked for vectors
	let mut u_real = u.rb().map(|_| u_real.as_mat_mut());
	let mut v_real = v.rb().map(|_| v_real.as_mat_mut());

	let mut diag_real = diag_real.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap();
	let mut subdiag_real = subdiag_real.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap();

	let mut col_mul = col_mul.as_mat_mut().col_mut(0);
	let mut row_mul = row_mul.as_mat_mut().col_mut(0);

	// returns `x / |x|`, or one when `x` is zero.
	// `x` is first divided by its largest component, presumably to keep `abs` clear of
	// overflow/underflow
	let normalized = |x: T| {
		if x == zero() {
			one()
		} else {
			let norm1 = max(abs(real(x)), abs(imag(x)));
			let y = x * from_real(recip(norm1));
			y * from_real(recip(abs(y)))
		}
	};

	// walk down the bidiagonal:
	// - each row factor is derived from the previous column factor and subdiagonal entry;
	// - each column factor is derived from that row factor and the next diagonal entry
	let mut col_normalized = normalized(conj(diag[0]));
	col_mul[0] = copy(col_normalized);
	diag_real[0] = abs(diag[0]);
	// the last subdiagonal slot is padding and must be zero for the real solver
	subdiag_real[n - 1] = zero();
	for i in 1..n {
		let row_normalized = normalized(conj(subdiag[i - 1] * col_normalized));
		subdiag_real[i - 1] = abs(subdiag[i - 1]);
		row_mul[i - 1] = conj(row_normalized);

		col_normalized = normalized(conj(diag[i] * row_normalized));
		diag_real[i] = abs(diag[i]);
		col_mul[i] = copy(col_normalized);
	}

	compute_bidiag_real_svd(
		diag_real.rb_mut(),
		subdiag_real.rb_mut(),
		u_real.rb_mut(),
		v_real.rb_mut(),
		params,
		par,
		stack,
	)?;

	// singular values are real; store them back as `T`
	for i in 0..n {
		diag[i] = from_real(diag_real[i]);
	}

	let u_real = u_real.rb();
	let v_real = v_real.rb();

	if let (Some(mut u), Some(u_real)) = (u.rb_mut(), u_real) {
		// rows 0 and n are copied unscaled; rows 1..n get the row phase factors
		z!(u.rb_mut().row_mut(0), u_real.row(0)).for_each(|uz!(u, r)| *u = from_real(*r));
		z!(u.rb_mut().row_mut(n), u_real.row(n)).for_each(|uz!(u, r)| *u = from_real(*r));

		for j in 0..u.ncols() {
			let mut u = u.rb_mut().col_mut(j).subrows_mut(1, n - 1);
			let u_real = u_real.rb().col(j).subrows(1, n - 1);
			z!(u.rb_mut(), u_real, row_mul.rb()).for_each(|uz!(u, re, f)| *u = mul_real(*f, *re));
		}
	}
	if let (Some(mut v), Some(v_real)) = (v.rb_mut(), v_real) {
		// every row of v gets the corresponding column phase factor
		for j in 0..v.ncols() {
			let mut v = v.rb_mut().col_mut(j);
			let v_real = v_real.rb().col(j);
			z!(v.rb_mut(), v_real, col_mul.rb()).for_each(|uz!(v, re, f)| *v = mul_real(*f, *re));
		}
	}

	Ok(())
}
244
245#[math]
246fn compute_bidiag_real_svd<T: RealField>(
247 mut diag: ColMut<'_, T, usize, ContiguousFwd>,
248 mut subdiag: ColMut<'_, T, usize, ContiguousFwd>,
249 mut u: Option<MatMut<'_, T, usize, usize>>,
250 mut v: Option<MatMut<'_, T, usize, usize>>,
251 params: SvdParams,
252 par: Par,
253 stack: &mut MemStack,
254) -> Result<(), SvdError> {
255 let n = diag.nrows();
256 for i in 0..n {
257 if !(is_finite(diag[i]) && is_finite(subdiag[i])) {
258 return Err(SvdError::NoConvergence);
259 }
260 }
261
262 if n < params.recursion_threshold {
263 if let Some(mut u) = u.rb_mut() {
264 u.fill(zero());
265 u.diagonal_mut().fill(one());
266 }
267 if let Some(mut v) = v.rb_mut() {
268 v.fill(zero());
269 v.diagonal_mut().fill(one());
270 }
271
272 bidiag_svd::qr_algorithm(
273 diag.rb_mut(),
274 subdiag.rb_mut(),
275 u.rb_mut().map(|u| u.submatrix_mut(0, 0, n, n)),
276 v.rb_mut(),
277 )?;
278
279 return Ok(());
280 } else {
281 let (mut u2, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(2, if u.is_some() { 0 } else { n + 1 }, stack) };
282
283 bidiag_svd::divide_and_conquer(
284 diag.as_row_shape_mut(n),
285 subdiag.as_row_shape_mut(n),
286 match u {
287 Some(u) => bidiag_svd::MatU::Full(u),
288 None => bidiag_svd::MatU::TwoRowsStorage(u2.as_mat_mut()),
289 },
290 v.map(|m| m.as_shape_mut(n, n)),
291 par,
292 stack,
293 params.recursion_threshold,
294 )
295 }
296}
297
/// SVD of an `m x n` matrix with `m >= n`, computed through bidiagonalization.
///
/// Steps:
/// - `matrix` is copied into `bid` and reduced in place to upper bidiagonal form. Its
///   block-Householder factors are kept in `Hl` (left) and `Hr` (right).
/// - The conjugated diagonal and superdiagonal are handed to `bidiag_svd`. That is the
///   conjugate transpose of the bidiagonal, so the kernel's left vectors (`ub`) become the
///   caller's `v` and its right vectors (`vb`) become the caller's `u`.
/// - The Householder transformations are then applied to lift the vectors back.
///
/// `u` must be `m x n` or `m x m`. `v` must be `n x n`.
///
/// # Errors
/// Propagates any error returned by `bidiag_svd`.
#[math]
fn svd_imp<T: ComplexField>(
	matrix: MatRef<'_, T>,
	s: ColMut<'_, T>,
	u: Option<MatMut<'_, T>>,
	v: Option<MatMut<'_, T>>,
	bidiag_svd: fn(
		diag: ColMut<'_, T, usize, ContiguousFwd>,
		subdiag: ColMut<'_, T, usize, ContiguousFwd>,
		u: Option<MatMut<'_, T, usize, usize>>,
		v: Option<MatMut<'_, T, usize, usize>>,
		params: SvdParams,
		par: Par,
		stack: &mut MemStack,
	) -> Result<(), SvdError>,
	par: Par,
	stack: &mut MemStack,
	params: SvdParams,
) -> Result<(), SvdError> {
	assert!(matrix.nrows() >= matrix.ncols());
	let m = matrix.nrows();
	let n = matrix.ncols();

	let bs = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);

	let (mut bid, stack) = unsafe { temp_mat_uninit::<T, _, _>(m, n, stack) };
	let mut bid = bid.as_mat_mut();

	// block-Householder factors of the left (n reflectors) and right (n - 1 reflectors) transforms
	let (mut Hl, stack) = unsafe { temp_mat_uninit::<T, _, _>(bs, n, stack) };
	let (mut Hr, stack) = unsafe { temp_mat_uninit::<T, _, _>(bs, n - 1, stack) };

	let mut Hl = Hl.as_mat_mut();
	let mut Hr = Hr.as_mat_mut();

	bid.copy_from(matrix);
	bidiag::bidiag_in_place(bid.rb_mut(), Hl.rb_mut(), Hr.rb_mut(), par, stack, params.bidiag.into());

	let (mut diag, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, 1, stack) };
	let (mut subdiag, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, 1, stack) };
	let mut diag = diag.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap();
	let mut subdiag = subdiag.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap();

	// vector buffers for the bidiagonal SVD (roles crossed, see the function docs).
	// they get zero columns when not needed
	let (mut ub, stack) = unsafe { temp_mat_uninit::<T, _, _>(n + 1, if v.is_some() { n + 1 } else { 0 }, stack) };
	let (mut vb, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, if u.is_some() { n } else { 0 }, stack) };

	let mut ub = ub.as_mat_mut();
	let mut vb = vb.as_mat_mut();

	// conjugated diagonal and superdiagonal of the bidiagonal; the last subdiagonal slot is zero padding
	for i in 0..n {
		diag[i] = conj(bid[(i, i)]);
		if i + 1 < n {
			subdiag[i] = conj(bid[(i, i + 1)]);
		} else {
			subdiag[i] = zero();
		}
	}

	bidiag_svd(
		diag.rb_mut(),
		subdiag.rb_mut(),
		v.rb().map(|_| ub.rb_mut()),
		u.rb().map(|_| vb.rb_mut()),
		params,
		par,
		stack,
	)?;

	{ s }.copy_from(diag);

	if let Some(mut u) = u {
		// embed the bidiagonal vectors in the top-left block: [[vb, 0], [0, I]]
		// (the identity block is empty for a thin `u`)
		let ncols = u.ncols();
		u.rb_mut().submatrix_mut(0, 0, n, n).copy_from(vb.rb());
		u.rb_mut().submatrix_mut(n, 0, m - n, ncols).fill(zero());
		u.rb_mut().submatrix_mut(0, n, n, ncols - n).fill(zero());
		u.rb_mut().submatrix_mut(n, n, ncols - n, ncols - n).diagonal_mut().fill(one());

		// apply the left Householder sequence stored in `bid` and `Hl`
		linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(bid.rb(), Hl.rb(), Conj::No, u, par, stack);
	}
	if let Some(mut v) = v {
		v.copy_from(ub.rb().submatrix(0, 0, n, n));

		// mirror the upper triangle (the right reflectors) into the lower triangle so the
		// "on the left" routine can apply them. this overwrites the lower part of `bid`, which
		// the `u` back-transformation above presumably reads, so it must happen after that step
		for j in 1..n {
			for i in 0..j {
				bid[(j, i)] = copy(bid[(i, j)]);
			}
		}

		// the right reflectors act on rows 1..n of v and are applied conjugated
		linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
			bid.rb().submatrix(1, 0, n - 1, n - 1),
			Hr.rb(),
			Conj::Yes,
			v.subrows_mut(1, n - 1),
			par,
			stack,
		);
	}

	Ok(())
}
398
/// Runs `svd_imp` with the bidiagonal SVD kernel matching `T`.
///
/// Real types use `compute_bidiag_real_svd` directly. Complex types go through
/// `compute_bidiag_cplx_svd`, which reduces the problem to a real bidiagonal SVD.
fn compute_squareish_svd<T: ComplexField>(
	matrix: MatRef<'_, T>,
	s: ColMut<'_, T>,
	u: Option<MatMut<'_, T>>,
	v: Option<MatMut<'_, T>>,
	par: Par,
	stack: &mut MemStack,
	params: SvdParams,
) -> Result<(), SvdError> {
	if try_const! { T::IS_REAL } {
		// SAFETY(review): the transmutes rely on `T` and `T::Real` being layout-identical (the
		// same type) whenever `T::IS_REAL` holds — TODO confirm this is guaranteed by `ComplexField`
		svd_imp::<T::Real>(
			unsafe { core::mem::transmute(matrix) },
			unsafe { core::mem::transmute(s) },
			unsafe { core::mem::transmute(u) },
			unsafe { core::mem::transmute(v) },
			compute_bidiag_real_svd::<T::Real>,
			par,
			stack,
			params,
		)
	} else {
		svd_imp::<T>(matrix, s, u, v, compute_bidiag_cplx_svd::<T>, par, stack, params)
	}
}
423
424pub fn svd_scratch<T: ComplexField>(
426 nrows: usize,
427 ncols: usize,
428 compute_u: ComputeSvdVectors,
429 compute_v: ComputeSvdVectors,
430 par: Par,
431 params: Spec<SvdParams, T>,
432) -> StackReq {
433 let params = params.config;
434 let mut m = nrows;
435 let mut n = ncols;
436 let mut compute_u = compute_u;
437 let mut compute_v = compute_v;
438
439 if n > m {
440 core::mem::swap(&mut m, &mut n);
441 core::mem::swap(&mut compute_u, &mut compute_v);
442 }
443
444 if n == 0 {
445 return StackReq::EMPTY;
446 }
447
448 let bidiag_svd_scratch = if try_const! { T::IS_REAL } {
449 bidiag_real_svd_scratch::<T::Real>
450 } else {
451 bidiag_cplx_svd_scratch::<T>
452 };
453
454 if m as f64 / n as f64 <= params.qr_ratio_threshold {
455 svd_imp_scratch::<T>(m, n, compute_u, compute_v, bidiag_svd_scratch, params, par)
456 } else {
457 let bs = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
458 StackReq::all_of(&[
459 temp_mat_scratch::<T>(m, n),
460 temp_mat_scratch::<T>(bs, n),
461 StackReq::any_of(&[
462 StackReq::all_of(&[
463 temp_mat_scratch::<T>(n, n),
464 svd_imp_scratch::<T>(n, n, compute_u, compute_v, bidiag_svd_scratch, params, par),
465 ]),
466 linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
467 m,
468 bs,
469 match compute_u {
470 ComputeSvdVectors::No => 0,
471 ComputeSvdVectors::Thin => n,
472 ComputeSvdVectors::Full => m,
473 },
474 ),
475 ]),
476 ])
477 }
478}
479
480#[math]
485pub fn svd<T: ComplexField>(
486 A: MatRef<'_, T>,
487 s: DiagMut<'_, T>,
488 u: Option<MatMut<'_, T>>,
489 v: Option<MatMut<'_, T>>,
490 par: Par,
491 stack: &mut MemStack,
492 params: Spec<SvdParams, T>,
493) -> Result<(), SvdError> {
494 let params = params.config;
495
496 let (m, n) = A.shape();
497 let size = Ord::min(m, n);
498 assert!(s.dim() == size);
499 let s = s.column_vector_mut();
500
501 if let Some(u) = u.rb() {
502 assert!(all(u.nrows() == A.nrows(), any(u.ncols() == A.nrows(), u.ncols() == size),));
503 }
504 if let Some(v) = v.rb() {
505 assert!(all(v.nrows() == A.ncols(), any(v.ncols() == A.ncols(), v.ncols() == size),));
506 }
507
508 #[cfg(feature = "perf-warn")]
509 match (u.rb(), v.rb()) {
510 (Some(matrix), _) | (_, Some(matrix)) => {
511 if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(QR_WARN) {
512 if matrix.col_stride().unsigned_abs() == 1 {
513 log::warn!(target: "faer_perf", "SVD prefers column-major singular vector matrices. Found row-major matrix.");
514 } else {
515 log::warn!(target: "faer_perf", "SVD prefers column-major singular vector matrices. Found matrix with generic strides.");
516 }
517 }
518 },
519 _ => {},
520 }
521
522 let mut u = u;
523 let mut v = v;
524 let mut matrix = A;
525 let do_transpose = n > m;
526 if do_transpose {
527 matrix = matrix.transpose();
528 core::mem::swap(&mut u, &mut v)
529 }
530
531 let (m, n) = matrix.shape();
532 if n == 0 {
533 if let Some(mut u) = u {
534 u.fill(zero());
535 u.rb_mut().diagonal_mut().fill(one());
536 }
537 return Ok(());
538 }
539
540 if m as f64 / n as f64 <= params.qr_ratio_threshold {
541 compute_squareish_svd(matrix, s, u.rb_mut(), v.rb_mut(), par, stack, params)?;
542 } else {
543 let bs = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
544 let (mut qr, stack) = unsafe { temp_mat_uninit::<T, _, _>(m, n, stack) };
545 let mut qr = qr.as_mat_mut();
546 let (mut householder, stack) = unsafe { temp_mat_uninit::<T, _, _>(bs, n, stack) };
547 let mut householder = householder.as_mat_mut();
548
549 {
550 qr.copy_from(matrix.rb());
551 linalg::qr::no_pivoting::factor::qr_in_place(qr.rb_mut(), householder.rb_mut(), par, stack, params.qr.into());
552 }
553
554 {
555 let (mut r, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
556 let mut r = r.as_mat_mut();
557 z!(r.rb_mut()).for_each_triangular_lower(linalg::zip::Diag::Skip, |uz!(dst)| *dst = zero());
558 z!(r.rb_mut(), qr.rb().submatrix(0, 0, n, n)).for_each_triangular_upper(linalg::zip::Diag::Include, |uz!(dst, src)| *dst = copy(*src));
559
560 compute_squareish_svd(r.rb(), s, u.rb_mut().map(|u| u.submatrix_mut(0, 0, n, n)), v.rb_mut(), par, stack, params)?;
562 }
563
564 if let Some(mut u) = u.rb_mut() {
566 u.rb_mut().subrows_mut(n, m - n).fill(zero());
567 if u.ncols() == m {
568 u.rb_mut().submatrix_mut(n, n, m - n, m - n).diagonal_mut().fill(one());
569 }
570
571 linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
572 qr.rb(),
573 householder.rb(),
574 Conj::No,
575 u.rb_mut(),
576 par,
577 stack,
578 );
579 }
580 }
581
582 if do_transpose {
583 if let Some(u) = u.rb_mut() {
585 z!(u).for_each(|uz!(u)| *u = conj(*u))
586 }
587 if let Some(v) = v.rb_mut() {
588 z!(v).for_each(|uz!(v)| *v = conj(*v))
589 }
590 }
591
592 Ok(())
593}
594
595pub fn pseudoinverse_from_svd_scratch<T: ComplexField>(nrows: usize, ncols: usize, par: Par) -> StackReq {
598 _ = par;
599 let size = Ord::min(nrows, ncols);
600 StackReq::all_of(&[temp_mat_scratch::<T>(nrows, size), temp_mat_scratch::<T>(ncols, size)])
601}
602
603#[math]
605pub fn pseudoinverse_from_svd<T: ComplexField>(
606 pinv: MatMut<'_, T>,
607 s: DiagRef<'_, T>,
608 u: MatRef<'_, T>,
609 v: MatRef<'_, T>,
610 par: Par,
611 stack: &mut MemStack,
612) {
613 pseudoinverse_from_svd_with_tolerance(
614 pinv,
615 s,
616 u,
617 v,
618 zero(),
619 eps::<T::Real>() * from_f64::<T::Real>(Ord::max(u.nrows(), v.nrows()) as f64),
620 par,
621 stack,
622 );
623}
624
625#[math]
628pub fn pseudoinverse_from_svd_with_tolerance<T: ComplexField>(
629 pinv: MatMut<'_, T>,
630 s: DiagRef<'_, T>,
631 u: MatRef<'_, T>,
632 v: MatRef<'_, T>,
633 abs_tol: T::Real,
634 rel_tol: T::Real,
635 par: Par,
636 stack: &mut MemStack,
637) {
638 let mut pinv = pinv;
639 let m = u.nrows();
640 let n = v.nrows();
641 let size = Ord::min(m, n);
642
643 assert!(all(u.nrows() == m, v.nrows() == n, u.ncols() >= size, v.ncols() >= size, s.dim() >= size));
644 let s = s.column_vector();
645 let u = u.get(.., ..size);
646 let v = v.get(.., ..size);
647
648 let smax = s.norm_max();
649 let tol = max(abs_tol, rel_tol * smax);
650
651 let (mut u_trunc, stack) = unsafe { temp_mat_uninit::<T, _, _>(m, size, stack) };
652 let (mut vp_trunc, _) = unsafe { temp_mat_uninit::<T, _, _>(n, size, stack) };
653
654 let mut u_trunc = u_trunc.as_mat_mut();
655 let mut vp_trunc = vp_trunc.as_mat_mut();
656 let mut len = 0;
657
658 for j in 0..n {
659 let x = absmax(s[j]);
660 if x > tol {
661 let p = recip(real(s[j]));
662 u_trunc.rb_mut().col_mut(len).copy_from(u.col(j));
663 z!(vp_trunc.rb_mut().col_mut(len), v.col(j)).for_each(|uz!(dst, src)| *dst = mul_real(*src, p));
664
665 len += 1;
666 }
667 }
668
669 linalg::matmul::matmul(pinv.rb_mut(), Accum::Replace, vp_trunc.rb(), u_trunc.rb().adjoint(), one(), par);
670}
671
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;

	/// Runs `svd` on `mat` for several combinations of requested singular vectors.
	///
	/// - The full and thin decompositions must reconstruct `mat`.
	/// - Runs that request only some of the outputs must reproduce the singular values and
	///   vectors of the thin run.
	#[track_caller]
	fn test_svd<T: ComplexField>(mat: MatRef<'_, T>) {
		let (m, n) = mat.shape();
		// small thresholds so the small test matrices still exercise both special paths:
		// - divide and conquer, for n >= 8;
		// - the QR pre-reduction, for any non-square shape
		let params = Spec::new(SvdParams {
			recursion_threshold: 8,
			qr_ratio_threshold: 1.0,
			..auto!(T)
		});
		use faer_traits::math_utils::*;
		// NOTE(review): the `~` comparisons below presumably pick up this `approx_eq` binding by
		// name — do not rename it
		let approx_eq = CwiseMat(ApproxEq::<T::Real>::eps() * sqrt(&from_f64(8.0 * Ord::max(m, n) as f64)));

		// full U and V
		{
			let mut s = Mat::zeros(m, n);
			let mut u = Mat::zeros(m, m);
			let mut v = Mat::zeros(n, n);

			svd(
				mat.as_ref(),
				s.as_mut().diagonal_mut(),
				Some(u.as_mut()),
				Some(v.as_mut()),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::Full,
					ComputeSvdVectors::Full,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			let reconstructed = &u * &s * v.adjoint();
			assert!(reconstructed ~ mat);
		}

		// thin U and V; these results are the reference for the partial runs below
		let size = Ord::min(m, n);
		let mut s = Mat::zeros(size, size);
		let mut u = Mat::zeros(m, size);
		let mut v = Mat::zeros(n, size);

		{
			svd(
				mat.as_ref(),
				s.as_mut().diagonal_mut(),
				Some(u.as_mut()),
				Some(v.as_mut()),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::Thin,
					ComputeSvdVectors::Thin,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			let reconstructed = &u * &s * v.adjoint();
			assert!(reconstructed ~ mat);
		}
		// only U requested
		{
			let mut s2 = Mat::zeros(size, size);
			let mut u2 = Mat::zeros(m, size);

			svd(
				mat.as_ref(),
				s2.as_mut().diagonal_mut(),
				Some(u2.as_mut()),
				None,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::Thin,
					ComputeSvdVectors::No,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			assert!(s2 ~ s);
			assert!(u2 ~ u);
		}

		// only V requested
		{
			let mut s2 = Mat::zeros(size, size);
			let mut v2 = Mat::zeros(n, size);

			svd(
				mat.as_ref(),
				s2.as_mut().diagonal_mut(),
				None,
				Some(v2.as_mut()),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::No,
					ComputeSvdVectors::Thin,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			assert!(s2 ~ s);
			assert!(v2 ~ v);
		}
		// singular values only
		{
			let mut s2 = Mat::zeros(size, size);

			svd(
				mat.as_ref(),
				s2.as_mut().diagonal_mut(),
				None,
				None,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::No,
					ComputeSvdVectors::No,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			assert!(s2 ~ s);
		}
	}

	/// Random real matrices: square, tall and wide shapes.
	#[test]
	fn test_real() {
		let rng = &mut StdRng::seed_from_u64(1);

		for (m, n) in [
			(3, 2),
			(2, 2),
			(4, 4),
			(15, 10),
			(10, 10),
			(15, 15),
			(50, 50),
			(100, 100),
			(150, 150),
			(150, 20),
			(20, 150),
		] {
			let mat = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: StandardNormal,
			}
			.rand::<Mat<f64>>(rng);

			test_svd(mat.as_ref());
		}
	}

	/// Random complex matrices, including sizes around the recursion threshold used by `test_svd`.
	#[test]
	fn test_cplx() {
		let rng = &mut StdRng::seed_from_u64(1);

		for (m, n) in [
			(1, 1),
			(2, 2),
			(3, 2),
			(2, 2),
			(3, 3),
			(4, 4),
			(15, 10),
			(10, 10),
			(15, 15),
			(16, 16),
			(17, 17),
			(18, 18),
			(19, 19),
			(20, 20),
			(30, 30),
			(50, 50),
			(100, 100),
			(150, 150),
			(150, 20),
			(20, 150),
		] {
			let mat = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			test_svd(mat.as_ref());
		}
	}

	/// Structured inputs (zero, all-ones and identity) for both real and complex scalars.
	#[test]
	fn test_special() {
		for (m, n) in [
			(3, 2),
			(2, 2),
			(4, 4),
			(15, 10),
			(10, 10),
			(15, 15),
			(50, 50),
			(100, 100),
			(150, 150),
			(150, 20),
			(20, 150),
		] {
			test_svd(Mat::<f64>::zeros(m, n).as_ref());
			test_svd(Mat::<c64>::zeros(m, n).as_ref());
			test_svd(Mat::<f64>::full(m, n, 1.0).as_ref());
			test_svd(Mat::<c64>::full(m, n, c64::ONE).as_ref());
			test_svd(Mat::<f64>::identity(m, n).as_ref());
			test_svd(Mat::<c64>::identity(m, n).as_ref());
		}
	}

	/// Regression case on a recorded 20x20 real bidiagonal (values spanning many orders of magnitude).
	///
	/// The real bidiagonal SVD runs through the divide-and-conquer path (recursion threshold 8),
	/// and the last computed singular value must not come out as exactly zero.
	#[test]
	fn test_zink() {
		let diag = [
			-9.931833701529301,
			-10.920807536026027,
			-52.33647796311243,
			2.3685025127736967,
			2.421701994236093,
			-0.5051763005624579,
			-0.04808263896606017,
			-0.003875251886338955,
			-0.0006413264967716465,
			-0.003381944152463707,
			2.981152313236375e-5,
			5.4290648208388795e-6,
			-6.329275972084404e-7,
			-6.879142344209158e-7,
			-5.265228263479126e-9,
			-2.941999902335516e-9,
			-1.3060984997930294e-10,
			7.07516117218088e-12,
			1.8657003929029376e-12,
			-6.216080089659131e-14,
		];
		// the final entry is the zero padding expected by the bidiagonal solver
		let subdiag = [
			-57.8029649868477,
			17.67263066467847,
			8.884153814270894,
			-9.01998231080713,
			-1.028638150814966,
			0.22247719217200435,
			0.016389886745811315,
			-0.004090989452162578,
			0.00036818904090536926,
			-0.0031394146217732367,
			-7.571300829706796e-6,
			3.0045718718618155e-6,
			2.1329796886727743e-6,
			9.259701025627789e-8,
			2.2291214755992877e-9,
			-2.3017207713252894e-9,
			6.807967994979358e-11,
			2.1677299575405587e-12,
			-3.07282771050034e-13,
			0.0,
		];

		let n = diag.len();
		let params = SvdParams {
			recursion_threshold: 8,
			qr_ratio_threshold: 1.0,
			..auto!(f64)
		};

		let mut d = ColRef::from_slice(&diag).to_owned();
		let mut s = ColRef::from_slice(&subdiag).to_owned();
		compute_bidiag_real_svd(
			d.as_mut().try_as_col_major_mut().unwrap(),
			s.as_mut().try_as_col_major_mut().unwrap(),
			None,
			None,
			params,
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(bidiag_real_svd_scratch::<f64>(n, false, false, Par::Seq, params))),
		)
		.unwrap();

		assert!(d[n - 1] != 0.0);
	}
}