faer_svd/
lib.rs

1//! The SVD of a matrix $M$ of shape $(m, n)$ is a decomposition into three components $U$, $S$,
2//! and $V$, such that:
3//!
4//! - $U$ has shape $(m, m)$ and is a unitary matrix,
5//! - $V$ has shape $(n, n)$ and is a unitary matrix,
6//! - $S$ has shape $(m, n)$ and is zero everywhere except the main diagonal,
7//! - and finally:
8//!
9//! $$M = U S V^H.$$
10
11#![allow(clippy::type_complexity)]
12#![allow(clippy::too_many_arguments)]
13#![cfg_attr(not(feature = "std"), no_std)]
14
15use bidiag_real_svd::bidiag_real_svd_req;
16use coe::Coerce;
17use core::mem::swap;
18use dyn_stack::{PodStack, SizeOverflow, StackReq};
19use faer_core::{
20    assert,
21    householder::{
22        apply_block_householder_sequence_on_the_left_in_place_req,
23        apply_block_householder_sequence_on_the_left_in_place_with_conj,
24        upgrade_householder_factor,
25    },
26    temp_mat_req, temp_mat_uninit, unzipped,
27    zip::Diag,
28    zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, RealField,
29};
30use num_complex::Complex;
31use reborrow::*;
32
33use crate::bidiag_real_svd::compute_bidiag_real_svd;
34
35#[doc(hidden)]
36pub mod bidiag;
37#[doc(hidden)]
38pub mod bidiag_real_svd;
39#[doc(hidden)]
40pub mod jacobi;
41
// Real matrices whose column count (after orienting them so that `nrows >= ncols`) is at most
// this value are decomposed by `compute_real_svd_small` (a Jacobi SVD, preceded by a QR step for
// non-square inputs). The value is also forwarded to the bidiagonal SVD kernels as their
// `jacobi_fallback_threshold`.
const JACOBI_FALLBACK_THRESHOLD: usize = 4;
// Forwarded to the bidiagonal SVD kernels as their `bidiag_qr_fallback_threshold`; its precise
// effect is defined in `bidiag_real_svd` (presumably the size below which a bidiagonal QR
// iteration is used instead of divide and conquer — TODO confirm).
const BIDIAG_QR_FALLBACK_THRESHOLD: usize = 128;
44
/// Selects which singular vectors of $U$ or $V$ are produced: all of them, only the leading ones,
/// or none.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ComputeVectors {
    /// The singular vectors are not computed.
    No,
    /// Only the leading `min(nrows, ncols)` singular vectors are computed.
    Thin,
    /// Every singular vector is computed, giving a square unitary matrix.
    Full,
}
52
53fn compute_real_svd_small_req<E: Entity>(
54    m: usize,
55    n: usize,
56    compute_u: ComputeVectors,
57    compute_v: ComputeVectors,
58    parallelism: Parallelism,
59) -> Result<StackReq, SizeOverflow> {
60    assert!(m >= n);
61
62    if m == n {
63        return temp_mat_req::<E>(m, n);
64    }
65
66    let _ = compute_v;
67    let householder_blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::<E>(m, n);
68
69    let qr = temp_mat_req::<E>(m, n)?;
70    let householder = temp_mat_req::<E>(householder_blocksize, n)?;
71    let r = temp_mat_req::<E>(n, n)?;
72
73    let compute_qr = faer_qr::no_pivoting::compute::qr_in_place_req::<E>(
74        m,
75        n,
76        householder_blocksize,
77        parallelism,
78        Default::default(),
79    )?;
80
81    let apply_householder = apply_block_householder_sequence_on_the_left_in_place_req::<E>(
82        m,
83        householder_blocksize,
84        match compute_u {
85            ComputeVectors::No => 0,
86            ComputeVectors::Thin => n,
87            ComputeVectors::Full => m,
88        },
89    )?;
90
91    StackReq::try_all_of([
92        qr,
93        householder,
94        StackReq::try_any_of([StackReq::try_all_of([r, compute_qr])?, apply_householder])?,
95    ])
96}
97
98fn compute_svd_big_req<E: Entity>(
99    m: usize,
100    n: usize,
101    compute_u: ComputeVectors,
102    compute_v: ComputeVectors,
103    bidiag_svd_req: fn(
104        n: usize,
105        jacobi_fallback_threshold: usize,
106        compute_u: bool,
107        compute_v: bool,
108        parallelism: Parallelism,
109    ) -> Result<StackReq, SizeOverflow>,
110    parallelism: Parallelism,
111) -> Result<StackReq, SizeOverflow> {
112    assert!(m >= n);
113    let householder_blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::<E>(m, n);
114
115    let bid = temp_mat_req::<E>(m, n)?;
116    let householder_left = temp_mat_req::<E>(householder_blocksize, n)?;
117    let householder_right = temp_mat_req::<E>(householder_blocksize, n - 1)?;
118
119    let compute_bidiag = bidiag::bidiagonalize_in_place_req::<E>(m, n, parallelism)?;
120
121    let diag = StackReq::try_new::<E>(n)?;
122    let subdiag = diag;
123    let compute_ub = compute_v != ComputeVectors::No;
124    let compute_vb = compute_u != ComputeVectors::No;
125    let u_b = temp_mat_req::<E>(if compute_ub { n + 1 } else { 2 }, n + 1)?;
126    let v_b = temp_mat_req::<E>(n, if compute_vb { n } else { 0 })?;
127
128    let compute_bidiag_svd = bidiag_svd_req(
129        n,
130        JACOBI_FALLBACK_THRESHOLD,
131        compute_ub,
132        compute_vb,
133        parallelism,
134    )?;
135    let apply_householder_u = apply_block_householder_sequence_on_the_left_in_place_req::<E>(
136        m,
137        householder_blocksize,
138        match compute_u {
139            ComputeVectors::No => 0,
140            ComputeVectors::Thin => n,
141            ComputeVectors::Full => m,
142        },
143    )?;
144    let apply_householder_v = apply_block_householder_sequence_on_the_left_in_place_req::<E>(
145        n - 1,
146        householder_blocksize,
147        match compute_u {
148            ComputeVectors::No => 0,
149            _ => n,
150        },
151    )?;
152
153    StackReq::try_all_of([
154        bid,
155        householder_left,
156        householder_right,
157        StackReq::try_any_of([
158            compute_bidiag,
159            StackReq::try_all_of([
160                diag,
161                subdiag,
162                u_b,
163                v_b,
164                StackReq::try_any_of([
165                    compute_bidiag_svd,
166                    StackReq::try_all_of([apply_householder_u, apply_householder_v])?,
167                ])?,
168            ])?,
169        ])?,
170    ])
171}
172
/// Computes the SVD of a real matrix with at least as many rows as columns, using a Jacobi SVD.
///
/// If the matrix is square, a scratch copy of it is decomposed directly. Otherwise the matrix is
/// first factored as `matrix = Q R`, and the `n × n` factor `R` is decomposed with the Jacobi SVD.
/// The left singular vectors of `R` are then zero-padded to the shape of `u` and multiplied by
/// `Q`.
///
/// - `s`: receives the singular values, read from the diagonal left by the Jacobi SVD.
/// - `u`, `v`: optional singular vector outputs (`v` is written directly by the Jacobi SVD).
/// - `epsilon`, `zero_threshold`: forwarded to `jacobi::jacobi_svd`.
/// - `stack`: workspace, sized by `compute_real_svd_small_req`.
///
/// # Panics
/// Panics if `matrix.nrows() < matrix.ncols()`.
fn compute_real_svd_small<E: RealField>(
    matrix: MatRef<'_, E>,
    s: MatMut<'_, E>,
    u: Option<MatMut<'_, E>>,
    v: Option<MatMut<'_, E>>,
    epsilon: E,
    zero_threshold: E,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    let mut u = u;
    let mut v = v;

    assert!(matrix.nrows() >= matrix.ncols());

    let m = matrix.nrows();
    let n = matrix.ncols();

    // if the matrix is square, skip the QR
    if m == n {
        // The Jacobi SVD works in place, so decompose a scratch copy of the input.
        let (mut jacobi_mat, _) = temp_mat_uninit::<E>(m, n, stack);
        let mut jacobi_mat = jacobi_mat.as_mut();
        zipped!(jacobi_mat.rb_mut(), matrix)
            .for_each(|unzipped!(mut dst, src)| dst.write(src.read()));

        jacobi::jacobi_svd(
            jacobi_mat.rb_mut(),
            u,
            v,
            jacobi::Skip::None,
            epsilon,
            zero_threshold,
        );
        // The singular values are left on the diagonal of the decomposed matrix.
        zipped!(s, jacobi_mat.rb().diagonal().column_vector().as_2d())
            .for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
        return;
    }

    let householder_blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::<E>(m, n);

    // `qr` and `householder` must outlive the inner scope: they are needed again to apply `Q`.
    let (mut qr, stack) = temp_mat_uninit::<E>(m, n, stack);
    let (mut householder, mut stack) = temp_mat_uninit::<E>(householder_blocksize, n, stack);
    let mut qr = qr.as_mut();
    let mut householder = householder.as_mut();

    {
        let (mut r, mut stack) = temp_mat_uninit::<E>(n, n, stack.rb_mut());
        let mut r = r.as_mut();

        zipped!(qr.rb_mut(), matrix).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));

        // matrix = q * r
        faer_qr::no_pivoting::compute::qr_in_place(
            qr.rb_mut(),
            householder.rb_mut(),
            parallelism,
            stack.rb_mut(),
            Default::default(),
        );
        // Extract `R`: zero the strictly lower part, then copy the upper triangle (including
        // the diagonal) of the leading `n × n` block of the packed factorization.
        zipped!(r.rb_mut())
            .for_each_triangular_lower(Diag::Skip, |unzipped!(mut dst)| dst.write(E::faer_zero()));
        zipped!(r.rb_mut(), qr.rb().submatrix(0, 0, n, n))
            .for_each_triangular_upper(Diag::Include, |unzipped!(mut dst, src)| {
                dst.write(src.read())
            });

        // r = u s v
        jacobi::jacobi_svd(
            r.rb_mut(),
            u.rb_mut().map(|u| u.submatrix_mut(0, 0, n, n)),
            v.rb_mut(),
            jacobi::Skip::None,
            epsilon,
            zero_threshold,
        );
        zipped!(s, r.rb().diagonal().column_vector().as_2d())
            .for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
    }

    // matrix = q u s v
    if let Some(mut u) = u.rb_mut() {
        // Embed the `n × n` left singular vectors of `R` in `u`:
        // - zero the rows below them and the columns to their right;
        // - if `u` is full (`m` columns), put an identity in the trailing `(m - n) × (m - n)`
        //   block;
        // - multiply by `Q`.
        let ncols = u.ncols();
        zipped!(u.rb_mut().submatrix_mut(n, 0, m - n, n))
            .for_each(|unzipped!(mut dst)| dst.write(E::faer_zero()));
        zipped!(u.rb_mut().submatrix_mut(0, n, m, ncols - n))
            .for_each(|unzipped!(mut dst)| dst.write(E::faer_zero()));
        if ncols == m {
            zipped!(u
                .rb_mut()
                .submatrix_mut(n, n, m - n, m - n)
                .diagonal_mut()
                .column_vector_mut()
                .as_2d_mut())
            .for_each(|unzipped!(mut dst)| dst.write(E::faer_one()));
        }

        faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            qr.rb(),
            householder.rb(),
            Conj::No,
            u,
            parallelism,
            stack.rb_mut(),
        );
    }
}
280
/// Computes the SVD of a complex bidiagonal matrix by reducing it to a real bidiagonal SVD.
///
/// `diag` holds the main diagonal and `subdiag` the off-diagonal entries. The steps are:
/// 1. Replace the entries by their moduli.
/// 2. Compute the real bidiagonal SVD with `compute_bidiag_real_svd`.
/// 3. Fold the unit-modulus phase factors `row_mul` / `col_mul` that make the bidiagonal real
///    back into `u` / `v`.
///
/// On return, `diag` holds the values produced by the real kernel, converted to `E`.
/// NOTE(review): the shapes of `u` and `v` are set by the caller (`compute_svd_big` passes an
/// `(n + 1) × (n + 1)` `u` and an `n × n` `v`) and are only partially checked here.
///
/// `stack` must satisfy `bidiag_cplx_svd_req`.
fn compute_bidiag_cplx_svd<E: ComplexField>(
    diag: &mut [E],
    subdiag: &mut [E],
    mut u: Option<MatMut<'_, E>>,
    mut v: Option<MatMut<'_, E>>,
    jacobi_fallback_threshold: usize,
    bidiag_qr_fallback_threshold: usize,
    epsilon: E::Real,
    consider_zero_threshold: E::Real,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    let n = diag.len();
    // Real-valued singular vectors produced by the real kernel; only allocated when requested.
    let (mut u_real, stack) =
        temp_mat_uninit::<E::Real>(n + 1, if u.is_some() { n + 1 } else { 0 }, stack);
    let mut u_real = u_real.as_mut();
    let (mut v_real, stack) = temp_mat_uninit::<E::Real>(n, if v.is_some() { n } else { 0 }, stack);
    let mut v_real = v_real.as_mut();
    let (diag_real, stack) = stack.collect(diag.iter().map(|x| x.faer_abs()));
    let (subdiag_real, stack) = stack.collect(subdiag.iter().map(|x| x.faer_abs()));

    let (col_mul, stack) = stack.make_with(n, |_| E::faer_zero());
    let (row_mul, stack) = stack.make_with(n - 1, |_| E::faer_zero());

    // Returns `x / |x|` (or 1 when `x` is zero). `x` is first divided by its largest component,
    // presumably so that computing `|x|` cannot overflow or underflow.
    let normalized = |x: E| {
        if x == E::faer_zero() {
            E::faer_one()
        } else {
            let re = x.faer_real().faer_abs();
            let im = x.faer_imag().faer_abs();
            let max = if re > im { re } else { im };
            let x = x.faer_scale_real(max.faer_inv());
            x.faer_scale_real(x.faer_abs().faer_inv())
        }
    };

    // Propagate the phases along the bidiagonal. Each column/row factor is chosen from the
    // previous one so that the corresponding entry becomes real.
    let mut col_normalized = normalized(diag[0]).faer_conj();
    col_mul[0] = col_normalized;
    for i in 1..n {
        let row_normalized = normalized(subdiag[i - 1].faer_mul(col_normalized)).faer_conj();
        row_mul[i - 1] = row_normalized.faer_conj();
        col_normalized = normalized(diag[i].faer_mul(row_normalized)).faer_conj();
        col_mul[i] = col_normalized;
    }

    compute_bidiag_real_svd::<E::Real>(
        diag_real,
        subdiag_real,
        u.is_some().then_some(u_real.rb_mut()),
        v.is_some().then_some(v_real.rb_mut()),
        jacobi_fallback_threshold,
        bidiag_qr_fallback_threshold,
        epsilon,
        consider_zero_threshold,
        parallelism,
        stack,
    );

    for i in 0..n {
        diag[i] = E::faer_from_real(diag_real[i]);
    }

    let u_real = u_real.rb();
    let v_real = v_real.rb();

    if let Some(mut u) = u.rb_mut() {
        // Rows `0` and `n` are copied unchanged. Rows `1..n` are scaled by the row phase factors.
        zipped!(u.rb_mut().row_mut(0).as_2d_mut(), u_real.row(0).as_2d())
            .for_each(|unzipped!(mut u, u_real)| u.write(E::faer_from_real(u_real.read())));
        zipped!(u.rb_mut().row_mut(n).as_2d_mut(), u_real.row(n).as_2d())
            .for_each(|unzipped!(mut u, u_real)| u.write(E::faer_from_real(u_real.read())));

        for col_idx in 0..u.ncols() {
            let mut u = u.rb_mut().col_mut(col_idx).subrows_mut(1, n - 1);
            let u_real = u_real.col(col_idx).subrows(1, n - 1);

            assert!(row_mul.len() == n - 1);
            // SAFETY: `u` and `u_real` both have exactly `n - 1` rows (`subrows(1, n - 1)`), and
            // `i` ranges over `0..row_mul.len()`, which equals `n - 1` per the assert above.
            unsafe {
                for (i, &row_mul) in row_mul.iter().enumerate() {
                    u.write_unchecked(i, row_mul.faer_scale_real(u_real.read_unchecked(i)));
                }
            }
        }
    }
    if let Some(mut v) = v.rb_mut() {
        // Every row of `v` is scaled by the corresponding column phase factor.
        for col_idx in 0..v.ncols() {
            let mut v = v.rb_mut().col_mut(col_idx);
            let v_real = v_real.col(col_idx);

            assert!(col_mul.len() == n);
            // SAFETY: `v_real` has `n` rows, and `i < col_mul.len() == n`. This also relies on
            // `v` having at least `n` rows, which holds for the `n × n` `v_b` passed by
            // `compute_svd_big`. NOTE(review): that row count is not asserted here.
            unsafe {
                for (i, &col_mul) in col_mul.iter().enumerate() {
                    v.write_unchecked(i, col_mul.faer_scale_real(v_real.read_unchecked(i)));
                }
            }
        }
    }
}
378
379fn bidiag_cplx_svd_req<E: Entity>(
380    n: usize,
381    jacobi_fallback_threshold: usize,
382    compute_u: bool,
383    compute_v: bool,
384    parallelism: Parallelism,
385) -> Result<StackReq, SizeOverflow> {
386    StackReq::try_all_of([
387        temp_mat_req::<E>(n + 1, if compute_u { n + 1 } else { 0 })?,
388        temp_mat_req::<E>(n, if compute_u { n } else { 0 })?,
389        StackReq::try_new::<E>(n)?,
390        StackReq::try_new::<E>(n)?,
391        StackReq::try_new::<Complex<E>>(n)?,
392        StackReq::try_new::<Complex<E>>(n - 1)?,
393        bidiag_real_svd_req::<E>(
394            n,
395            jacobi_fallback_threshold,
396            compute_u,
397            compute_v,
398            parallelism,
399        )?,
400    ])
401}
402
/// Computes the SVD of a matrix with at least as many rows as columns via bidiagonalization.
///
/// The steps are:
/// 1. Copy the matrix and reduce it in place to upper bidiagonal form with left and right
///    Householder reflections.
/// 2. Pass the conjugated diagonal and superdiagonal to the bidiagonal SVD kernel `bidiag_svd`
///    (divide and conquer, with the Jacobi / bidiagonal-QR fallback thresholds forwarded from
///    this module's constants).
/// 3. Map the kernel's singular vectors back by applying the Householder sequences.
///
/// Because the bidiagonal is conjugated, the kernel's left vectors (`u_b`) produce `v` and its
/// right vectors (`v_b`) produce `u`.
///
/// `s` receives the values left in `diag` by the kernel. `stack` must satisfy
/// `compute_svd_big_req`.
///
/// # Panics
/// Panics if `matrix.nrows() < matrix.ncols()`.
fn compute_svd_big<E: ComplexField>(
    matrix: MatRef<'_, E>,
    mut s: MatMut<'_, E>,
    u: Option<MatMut<'_, E>>,
    v: Option<MatMut<'_, E>>,
    bidiag_svd: fn(
        diag: &mut [E],
        subdiag: &mut [E],
        u: Option<MatMut<'_, E>>,
        v: Option<MatMut<'_, E>>,
        jacobi_fallback_threshold: usize,
        bidiag_qr_fallback_threshold: usize,
        epsilon: E::Real,
        consider_zero_threshold: E::Real,
        parallelism: Parallelism,
        stack: PodStack<'_>,
    ),
    epsilon: E::Real,
    zero_threshold: E::Real,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    let mut stack = stack;

    assert!(matrix.nrows() >= matrix.ncols());

    let m = matrix.nrows();
    let n = matrix.ncols();
    let householder_blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::<E>(m, n);

    let (mut bid, stack) = temp_mat_uninit::<E>(m, n, stack.rb_mut());
    let mut bid = bid.as_mut();
    let (mut householder_left, stack) = temp_mat_uninit::<E>(householder_blocksize, n, stack);
    let mut householder_left = householder_left.as_mut();
    let (mut householder_right, mut stack) =
        temp_mat_uninit::<E>(householder_blocksize, n - 1, stack);
    let mut householder_right = householder_right.as_mut();

    zipped!(bid.rb_mut(), matrix).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));

    // The per-reflector Householder factors are written into row 0 of each factor matrix.
    bidiag::bidiagonalize_in_place(
        bid.rb_mut(),
        householder_left
            .rb_mut()
            .row_mut(0)
            .transpose_mut()
            .as_2d_mut(),
        householder_right
            .rb_mut()
            .row_mut(0)
            .transpose_mut()
            .as_2d_mut(),
        parallelism,
        stack.rb_mut(),
    );

    let bid = bid.into_const();

    // Conjugated main diagonal and superdiagonal. `subdiag` is padded with a trailing zero to
    // length `n`.
    let (diag, stack) = stack.make_with(n, |i| bid.read(i, i).faer_conj());
    let (subdiag, stack) = stack.make_with(n, |i| {
        if i < n - 1 {
            bid.read(i, i + 1).faer_conj()
        } else {
            E::faer_zero()
        }
    });

    // Turn the per-reflector factors (row 0) into block Householder factors: move each factor
    // onto the diagonal of its block, then upgrade the block.
    let mut j_base = 0;
    while j_base < n {
        let bs = Ord::min(householder_blocksize, n - j_base);
        let mut householder = householder_left.rb_mut().submatrix_mut(0, j_base, bs, bs);
        let essentials = bid.submatrix(j_base, j_base, m - j_base, bs);
        for j in 0..bs {
            householder.write(j, j, householder.read(0, j));
        }
        upgrade_householder_factor(householder, essentials, bs, 1, parallelism);
        j_base += bs;
    }
    // Same for the right reflectors, whose essentials are stored in the rows of `bid` to the
    // right of the diagonal (viewed here through a transpose).
    let mut j_base = 0;
    while j_base < n - 1 {
        let bs = Ord::min(householder_blocksize, n - 1 - j_base);
        let mut householder = householder_right.rb_mut().submatrix_mut(0, j_base, bs, bs);
        let full_essentials = bid.submatrix(0, 1, m, n - 1).transpose();
        let essentials = full_essentials.submatrix(j_base, j_base, n - 1 - j_base, bs);
        for j in 0..bs {
            householder.write(j, j, householder.read(0, j));
        }
        upgrade_householder_factor(householder, essentials, bs, 1, parallelism);
        j_base += bs;
    }

    // Kernel outputs: `u_b` is only needed for `v`, and `v_b` only for `u`.
    let (mut u_b, stack) = temp_mat_uninit::<E>(if v.is_some() { n + 1 } else { 0 }, n + 1, stack);
    let mut u_b = u_b.as_mut();
    let (mut v_b, mut stack) = temp_mat_uninit::<E>(n, if u.is_some() { n } else { 0 }, stack);
    let mut v_b = v_b.as_mut();

    bidiag_svd(
        diag,
        subdiag,
        v.is_some().then_some(u_b.rb_mut()),
        u.is_some().then_some(v_b.rb_mut()),
        JACOBI_FALLBACK_THRESHOLD,
        BIDIAG_QR_FALLBACK_THRESHOLD,
        epsilon,
        zero_threshold,
        parallelism,
        stack.rb_mut(),
    );

    for (idx, &diag) in diag.iter().enumerate() {
        s.write(idx, 0, diag);
    }

    if let Some(mut u) = u {
        // Embed `v_b` in `u`:
        // - zero the remaining rows and columns;
        // - put an identity in the trailing square block (empty when `u` is thin);
        // - apply the left Householder sequence.
        let ncols = u.ncols();
        zipped!(
            u.rb_mut().submatrix_mut(0, 0, n, n),
            v_b.rb().submatrix(0, 0, n, n),
        )
        .for_each(|unzipped!(mut dst, src)| dst.write(src.read()));

        zipped!(u.rb_mut().submatrix_mut(n, 0, m - n, ncols))
            .for_each(|unzipped!(mut x)| x.write(E::faer_zero()));
        zipped!(u.rb_mut().submatrix_mut(0, n, n, ncols - n))
            .for_each(|unzipped!(mut x)| x.write(E::faer_zero()));
        zipped!(u
            .rb_mut()
            .submatrix_mut(n, n, ncols - n, ncols - n)
            .diagonal_mut()
            .column_vector_mut()
            .as_2d_mut())
        .for_each(|unzipped!(mut x)| x.write(E::faer_one()));

        apply_block_householder_sequence_on_the_left_in_place_with_conj(
            bid,
            householder_left.rb(),
            Conj::No,
            u,
            parallelism,
            stack.rb_mut(),
        );
    };
    if let Some(mut v) = v {
        zipped!(
            v.rb_mut().submatrix_mut(0, 0, n, n),
            u_b.rb().submatrix(0, 0, n, n),
        )
        .for_each(|unzipped!(mut dst, src)| dst.write(src.read()));

        // Copy the right Householder essentials (strictly lower part of the transposed
        // superdiagonal block) into a contiguous column-major buffer before applying them.
        let (mut bid_col_major, mut stack) =
            faer_core::temp_mat_uninit::<E>(n - 1, m, stack.rb_mut());
        let mut bid_col_major = bid_col_major.as_mut();
        zipped!(
            bid_col_major.rb_mut(),
            bid.submatrix(0, 1, m, n - 1).transpose()
        )
        .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst, src)| {
            dst.write(src.read())
        });

        // The right reflectors act on rows `1..n` of `v`; row 0 is left untouched.
        apply_block_householder_sequence_on_the_left_in_place_with_conj(
            bid_col_major.rb(),
            householder_right.rb(),
            Conj::No,
            v.submatrix_mut(1, 0, n - 1, n),
            parallelism,
            stack.rb_mut(),
        );
    }
}
574
/// Tuning parameters for the singular value decomposition.
///
/// The struct currently has no fields. It is `#[non_exhaustive]` so that options can be added
/// later without breaking callers, who should construct it with `Default::default()`.
#[derive(Debug, Default, Copy, Clone)]
#[non_exhaustive]
pub struct SvdParams {}
578
579/// Computes the size and alignment of required workspace for performing a singular value
580/// decomposition. $U$ and $V$ may be computed fully, partially, or not computed at all.
581pub fn compute_svd_req<E: ComplexField>(
582    nrows: usize,
583    ncols: usize,
584    compute_u: ComputeVectors,
585    compute_v: ComputeVectors,
586    parallelism: Parallelism,
587    params: SvdParams,
588) -> Result<StackReq, SizeOverflow> {
589    let mut nrows = nrows;
590    let mut ncols = ncols;
591    let mut compute_u = compute_u;
592    let mut compute_v = compute_v;
593    let do_transpose = ncols > nrows;
594    if do_transpose {
595        swap(&mut nrows, &mut ncols);
596        swap(&mut compute_u, &mut compute_v);
597    }
598
599    if ncols == 0 {
600        return Ok(StackReq::default());
601    }
602
603    let size = Ord::min(nrows, ncols);
604    let skip_qr = nrows as f64 / ncols as f64 <= 11.0 / 6.0;
605    let (svd_nrows, svd_ncols) = if skip_qr {
606        (nrows, ncols)
607    } else {
608        (size, size)
609    };
610
611    let _ = params;
612    let squareish_svd = if coe::is_same::<E, E::Real>() {
613        if size <= JACOBI_FALLBACK_THRESHOLD {
614            compute_real_svd_small_req::<E>(svd_nrows, svd_ncols, compute_u, compute_v, parallelism)
615        } else {
616            compute_svd_big_req::<E::Real>(
617                svd_nrows,
618                svd_ncols,
619                compute_u,
620                compute_v,
621                bidiag_real_svd_req::<E::Real>,
622                parallelism,
623            )
624        }
625    } else {
626        compute_svd_big_req::<E>(
627            svd_nrows,
628            svd_ncols,
629            compute_u,
630            compute_v,
631            bidiag_cplx_svd_req::<E>,
632            parallelism,
633        )
634    }?;
635
636    if skip_qr {
637        Ok(squareish_svd)
638    } else {
639        let householder_blocksize =
640            faer_qr::no_pivoting::compute::recommended_blocksize::<E>(nrows, ncols);
641
642        StackReq::try_all_of([
643            temp_mat_req::<E>(nrows, ncols)?,
644            temp_mat_req::<E>(householder_blocksize, ncols)?,
645            StackReq::try_any_of([
646                StackReq::try_all_of([
647                    temp_mat_req::<E>(size, size)?,
648                    StackReq::try_any_of([
649                        faer_qr::no_pivoting::compute::qr_in_place_req::<E>(
650                            nrows,
651                            ncols,
652                            householder_blocksize,
653                            parallelism,
654                            Default::default(),
655                        )?,
656                        squareish_svd,
657                    ])?,
658                ])?,
659                apply_block_householder_sequence_on_the_left_in_place_req::<E>(
660                    nrows,
661                    householder_blocksize,
662                    nrows,
663                )?,
664            ])?,
665        ])
666    }
667}
668
669/// Computes the singular value decomposition of `matrix`.
670///
671/// `s` represents the main diagonal of the matrix $S$, and must have size equal to the minimum of
672/// `matrix.nrows()` and `matrix.ncols()`.
673///
674/// For each of `u` and `v`:
675/// - If the argument is `None`, then the corresponding singular vector matrix is not computed.
676/// - If it is `Some(..)`, then it must have a number of rows equal to `matrix.nrows()` for `u`,
677/// and `matrix.ncols()` for `v`.
678/// - The number of columns may be either equal to the number of rows, or it may be equal to the
679/// minimum of `matrix.nrows()` and `matrix.ncols()`, in which case only the singular vectors
680/// corresponding to the provided column storage are computed.
681///
682/// # Panics
683/// Panics if any of the conditions described above is violated, or if the type `E` does not have a
684/// fixed precision at compile time, e.g. a dynamic multiprecision floating point type.
685///
686/// This can also panic if the provided memory in `stack` is insufficient (see [`compute_svd_req`]).
687#[track_caller]
688pub fn compute_svd<E: ComplexField>(
689    matrix: MatRef<'_, E>,
690    s: MatMut<'_, E>,
691    u: Option<MatMut<'_, E>>,
692    v: Option<MatMut<'_, E>>,
693    parallelism: Parallelism,
694    stack: PodStack<'_>,
695    params: SvdParams,
696) {
697    compute_svd_custom_epsilon(
698        matrix,
699        s,
700        u,
701        v,
702        E::Real::faer_epsilon().unwrap(),
703        E::Real::faer_zero_threshold().unwrap(),
704        parallelism,
705        stack,
706        params,
707    );
708}
709
710/// See [`compute_svd`].
711///
712/// This function takes an additional `epsilon` and `zero_threshold` parameters. `epsilon`
713/// represents the precision of the values in the matrix, and `zero_threshold` is the value below
714/// which the precision starts to deteriorate, e.g. due to denormalized numbers.
715///
716/// These values need to be provided manually for types that do not have a known precision at
717/// compile time, e.g. a dynamic multiprecision floating point type.
718#[track_caller]
719pub fn compute_svd_custom_epsilon<E: ComplexField>(
720    matrix: MatRef<'_, E>,
721    s: MatMut<'_, E>,
722    u: Option<MatMut<'_, E>>,
723    v: Option<MatMut<'_, E>>,
724    epsilon: E::Real,
725    zero_threshold: E::Real,
726    parallelism: Parallelism,
727    stack: PodStack<'_>,
728    params: SvdParams,
729) {
730    let size = Ord::min(matrix.nrows(), matrix.ncols());
731    assert!(all(s.nrows() == size, s.ncols() == 1));
732    if let Some(u) = u.rb() {
733        assert!(u.nrows() == matrix.nrows());
734        assert!(u.ncols() == matrix.nrows() || u.ncols() == size);
735    }
736    if let Some(v) = v.rb() {
737        assert!(v.nrows() == matrix.ncols());
738        assert!(v.ncols() == matrix.ncols() || v.ncols() == size);
739    }
740
741    #[cfg(feature = "perf-warn")]
742    match (u.rb(), v.rb()) {
743        (Some(matrix), _) | (_, Some(matrix)) => {
744            if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) {
745                if matrix.col_stride().unsigned_abs() == 1 {
746                    log::warn!(target: "faer_perf", "SVD prefers column-major singular vector matrices. Found row-major matrix.");
747                } else {
748                    log::warn!(target: "faer_perf", "SVD prefers column-major singular vector matrices. Found matrix with generic strides.");
749                }
750            }
751        }
752        _ => {}
753    }
754
755    if !matrix.is_all_finite() {
756        { s }.fill(E::faer_nan());
757        if let Some(mut u) = u {
758            u.fill(E::faer_nan());
759        }
760        if let Some(mut v) = v {
761            v.fill(E::faer_nan());
762        }
763        return;
764    }
765
766    let mut u = u;
767    let mut v = v;
768    let mut matrix = matrix;
769    let do_transpose = matrix.ncols() > matrix.nrows();
770    if do_transpose {
771        matrix = matrix.transpose();
772        swap(&mut u, &mut v);
773    }
774
775    let m = matrix.nrows();
776    let n = matrix.ncols();
777
778    if n == 0 {
779        if let Some(mut u) = u {
780            zipped!(u.rb_mut()).for_each(|unzipped!(mut dst)| dst.write(E::faer_zero()));
781            zipped!(u
782                .submatrix_mut(0, 0, n, n)
783                .diagonal_mut()
784                .column_vector_mut()
785                .as_2d_mut())
786            .for_each(|unzipped!(mut dst)| dst.write(E::faer_one()));
787        }
788
789        return;
790    }
791
792    let _ = params;
793
794    if m as f64 / n as f64 <= 11.0 / 6.0 {
795        squareish_svd(
796            matrix,
797            s,
798            u.rb_mut(),
799            v.rb_mut(),
800            epsilon,
801            zero_threshold,
802            parallelism,
803            stack,
804        );
805    } else {
806        // do a qr first, then do the svd
807        let householder_blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::<E>(m, n);
808
809        let (mut qr, stack) = temp_mat_uninit::<E>(m, n, stack);
810        let mut qr = qr.as_mut();
811        let (mut householder, mut stack) = temp_mat_uninit::<E>(householder_blocksize, n, stack);
812        let mut householder = householder.as_mut();
813
814        {
815            let (mut r, mut stack) = temp_mat_uninit::<E>(n, n, stack.rb_mut());
816            let mut r = r.as_mut();
817
818            zipped!(qr.rb_mut(), matrix).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
819
820            // matrix = q * r
821            faer_qr::no_pivoting::compute::qr_in_place(
822                qr.rb_mut(),
823                householder.rb_mut(),
824                parallelism,
825                stack.rb_mut(),
826                Default::default(),
827            );
828            zipped!(r.rb_mut()).for_each_triangular_lower(Diag::Skip, |unzipped!(mut dst)| {
829                dst.write(E::faer_zero())
830            });
831            zipped!(r.rb_mut(), qr.rb().submatrix(0, 0, n, n))
832                .for_each_triangular_upper(Diag::Include, |unzipped!(mut dst, src)| {
833                    dst.write(src.read())
834                });
835
836            // r = u s v
837            squareish_svd(
838                r.rb(),
839                s,
840                u.rb_mut().map(|u| u.submatrix_mut(0, 0, n, n)),
841                v.rb_mut(),
842                epsilon,
843                zero_threshold,
844                parallelism,
845                stack,
846            );
847        }
848
849        // matrix = q u s v
850        if let Some(mut u) = u.rb_mut() {
851            let ncols = u.ncols();
852            zipped!(u.rb_mut().submatrix_mut(n, 0, m - n, n))
853                .for_each(|unzipped!(mut dst)| dst.write(E::faer_zero()));
854            zipped!(u.rb_mut().submatrix_mut(0, n, m, ncols - n))
855                .for_each(|unzipped!(mut dst)| dst.write(E::faer_zero()));
856            if ncols == m {
857                zipped!(u
858                    .rb_mut()
859                    .submatrix_mut(n, n, m - n, m - n)
860                    .diagonal_mut()
861                    .column_vector_mut()
862                    .as_2d_mut())
863                .for_each(|unzipped!(mut dst)| dst.write(E::faer_one()));
864            }
865
866            faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
867                qr.rb(),
868                householder.rb(),
869                Conj::No,
870                u,
871                parallelism,
872                stack.rb_mut(),
873            );
874        }
875    }
876
877    if do_transpose {
878        // conjugate u and v
879        if let Some(u) = u {
880            zipped!(u).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()))
881        }
882        if let Some(v) = v {
883            zipped!(v).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()))
884        }
885    }
886}
887
888fn squareish_svd<E: ComplexField>(
889    matrix: MatRef<E>,
890    s: MatMut<E>,
891    mut u: Option<MatMut<E>>,
892    mut v: Option<MatMut<E>>,
893    epsilon: E::Real,
894    zero_threshold: E::Real,
895    parallelism: Parallelism,
896    stack: PodStack,
897) {
898    let size = matrix.ncols();
899    if coe::is_same::<E, E::Real>() {
900        if size <= JACOBI_FALLBACK_THRESHOLD {
901            compute_real_svd_small::<E::Real>(
902                matrix.coerce(),
903                s.coerce(),
904                u.rb_mut().map(coe::Coerce::coerce),
905                v.rb_mut().map(coe::Coerce::coerce),
906                coe::coerce_static(epsilon),
907                coe::coerce_static(zero_threshold),
908                parallelism,
909                stack,
910            );
911        } else {
912            compute_svd_big::<E::Real>(
913                matrix.coerce(),
914                s.coerce(),
915                u.rb_mut().map(coe::Coerce::coerce),
916                v.rb_mut().map(coe::Coerce::coerce),
917                compute_bidiag_real_svd::<E::Real>,
918                coe::coerce_static(epsilon),
919                coe::coerce_static(zero_threshold),
920                parallelism,
921                stack,
922            );
923        }
924    } else {
925        compute_svd_big::<E>(
926            matrix.coerce(),
927            s,
928            u,
929            v,
930            compute_bidiag_cplx_svd::<E>,
931            coe::coerce_static(epsilon),
932            coe::coerce_static(zero_threshold),
933            parallelism,
934            stack,
935        );
936    }
937}
938
#[cfg(test)]
mod tests {
    //! Reconstruction and consistency tests for the SVD routines.
    //!
    //! Most tests check that `U * S * V^H` (for real scalars, `V^T`) reproduces the input matrix.
    //! The `*_thin` tests instead check that thin/partial computations agree with the
    //! corresponding leading columns of a full computation.

    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use faer_core::{assert, c32, c64, Mat};

    // Builds a `PodStack` backed by a fresh global buffer of the size described by `$req`
    // (a `Result<StackReq, SizeOverflow>`). Panics if the size computation overflowed.
    macro_rules! make_stack {
        ($req: expr) => {
            ::dyn_stack::PodStack::new(&mut ::dyn_stack::GlobalPodBuffer::new($req.unwrap()))
        };
    }

    /// `compute_svd_big` with the real bidiagonal solver reconstructs random real matrices.
    #[test]
    fn test_real_big() {
        for (m, n) in [(3, 2), (2, 2), (4, 4), (15, 10), (10, 10), (15, 15)] {
            let mat = Mat::from_fn(m, n, |_, _| rand::random::<f64>());
            let size = m.min(n);

            let mut s = Mat::zeros(m, n);
            let mut u = Mat::zeros(m, m);
            let mut v = Mat::zeros(n, n);

            compute_svd_big(
                mat.as_ref(),
                s.as_mut()
                    .submatrix_mut(0, 0, size, size)
                    .diagonal_mut()
                    .column_vector_mut()
                    .as_2d_mut(),
                Some(u.as_mut()),
                Some(v.as_mut()),
                compute_bidiag_real_svd::<f64>,
                f64::EPSILON,
                f64::MIN_POSITIVE,
                Parallelism::None,
                make_stack!(compute_svd_big_req::<f64>(
                    m,
                    n,
                    ComputeVectors::Full,
                    ComputeVectors::Full,
                    bidiag_real_svd_req::<f64>,
                    Parallelism::None,
                )),
            );

            let reconstructed = &u * &s * v.transpose();

            for j in 0..n {
                for i in 0..m {
                    assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                }
            }
        }
    }

    /// `compute_svd_big` reconstructs a real (rectangular) identity matrix.
    #[test]
    fn test_real_identity() {
        for (m, n) in [(15, 10), (10, 10), (15, 15)] {
            let mut mat = Mat::zeros(m, n);
            let size = m.min(n);
            for i in 0..size {
                mat.write(i, i, 1.0);
            }

            let mut s = Mat::zeros(m, n);
            let mut u = Mat::zeros(m, m);
            let mut v = Mat::zeros(n, n);

            compute_svd_big(
                mat.as_ref(),
                s.as_mut()
                    .submatrix_mut(0, 0, size, size)
                    .diagonal_mut()
                    .column_vector_mut()
                    .as_2d_mut(),
                Some(u.as_mut()),
                Some(v.as_mut()),
                compute_bidiag_real_svd::<f64>,
                f64::EPSILON,
                f64::MIN_POSITIVE,
                Parallelism::None,
                make_stack!(compute_svd_big_req::<f64>(
                    m,
                    n,
                    ComputeVectors::Full,
                    ComputeVectors::Full,
                    bidiag_real_svd_req::<f64>,
                    Parallelism::None,
                )),
            );

            let reconstructed = &u * &s * v.transpose();

            for j in 0..n {
                for i in 0..m {
                    assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                }
            }
        }
    }

    /// `compute_svd_big` handles the all-zero real matrix.
    #[test]
    fn test_real_zero() {
        for (m, n) in [(15, 10), (10, 10), (15, 15)] {
            let mat = Mat::zeros(m, n);
            let size = m.min(n);

            let mut s = Mat::zeros(m, n);
            let mut u = Mat::zeros(m, m);
            let mut v = Mat::zeros(n, n);

            compute_svd_big(
                mat.as_ref(),
                s.as_mut()
                    .submatrix_mut(0, 0, size, size)
                    .diagonal_mut()
                    .column_vector_mut()
                    .as_2d_mut(),
                Some(u.as_mut()),
                Some(v.as_mut()),
                compute_bidiag_real_svd::<f64>,
                f64::EPSILON,
                f64::MIN_POSITIVE,
                Parallelism::None,
                make_stack!(compute_svd_big_req::<f64>(
                    m,
                    n,
                    ComputeVectors::Full,
                    ComputeVectors::Full,
                    bidiag_real_svd_req::<f64>,
                    Parallelism::None,
                )),
            );

            let reconstructed = &u * &s * v.transpose();

            for j in 0..n {
                for i in 0..m {
                    assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                }
            }
        }
    }

    /// `compute_real_svd_small` reconstructs random real matrices (requires `m >= n`).
    #[test]
    fn test_real_small() {
        for (m, n) in [(4, 4), (5, 5), (15, 10), (10, 10), (15, 15)] {
            let mat = Mat::from_fn(m, n, |_, _| rand::random::<f64>());
            let size = m.min(n);

            let mut s = Mat::zeros(m, n);
            let mut u = Mat::zeros(m, m);
            let mut v = Mat::zeros(n, n);

            compute_real_svd_small(
                mat.as_ref(),
                s.as_mut()
                    .submatrix_mut(0, 0, size, size)
                    .diagonal_mut()
                    .column_vector_mut()
                    .as_2d_mut(),
                Some(u.as_mut()),
                Some(v.as_mut()),
                f64::EPSILON,
                f64::MIN_POSITIVE,
                Parallelism::None,
                make_stack!(compute_real_svd_small_req::<f64>(
                    m,
                    n,
                    ComputeVectors::Full,
                    ComputeVectors::Full,
                    Parallelism::None,
                )),
            );

            let reconstructed = &u * &s * v.transpose();

            for j in 0..n {
                for i in 0..m {
                    assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                }
            }
        }
    }

    /// The public `compute_svd` entry point reconstructs random `f64` matrices of every shape
    /// up to 19x19, including empty and wide ones.
    #[test]
    fn test_real() {
        for m in 0..20 {
            for n in 0..20 {
                let mat = Mat::from_fn(m, n, |_, _| rand::random::<f64>());
                let size = m.min(n);

                let mut s = Mat::zeros(m, n);
                let mut u = Mat::zeros(m, m);
                let mut v = Mat::zeros(n, n);

                compute_svd(
                    mat.as_ref(),
                    s.as_mut()
                        .submatrix_mut(0, 0, size, size)
                        .diagonal_mut()
                        .column_vector_mut()
                        .as_2d_mut(),
                    Some(u.as_mut()),
                    Some(v.as_mut()),
                    Parallelism::None,
                    make_stack!(compute_svd_req::<f64>(
                        m,
                        n,
                        ComputeVectors::Full,
                        ComputeVectors::Full,
                        Parallelism::None,
                        SvdParams::default(),
                    )),
                    SvdParams::default(),
                );

                let reconstructed = &u * &s * v.transpose();

                for j in 0..n {
                    for i in 0..m {
                        assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                    }
                }
            }
        }
    }

    /// Same as `test_real`, in single precision (hence the looser tolerance).
    #[test]
    fn test_real_f32() {
        for m in 0..20 {
            for n in 0..20 {
                let mat = Mat::from_fn(m, n, |_, _| rand::random::<f32>());
                let size = m.min(n);

                let mut s = Mat::zeros(m, n);
                let mut u = Mat::zeros(m, m);
                let mut v = Mat::zeros(n, n);

                compute_svd(
                    mat.as_ref(),
                    s.as_mut()
                        .submatrix_mut(0, 0, size, size)
                        .diagonal_mut()
                        .column_vector_mut()
                        .as_2d_mut(),
                    Some(u.as_mut()),
                    Some(v.as_mut()),
                    Parallelism::None,
                    make_stack!(compute_svd_req::<f32>(
                        m,
                        n,
                        ComputeVectors::Full,
                        ComputeVectors::Full,
                        Parallelism::None,
                        SvdParams::default(),
                    )),
                    SvdParams::default(),
                );

                let reconstructed = &u * &s * v.transpose();

                for j in 0..n {
                    for i in 0..m {
                        assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-3);
                    }
                }
            }
        }
    }

    /// Every combination of `No`/`Thin`/`Full` for `u` and `v` produces values matching the
    /// leading columns of a full real computation.
    #[test]
    fn test_real_thin() {
        for m in 0..20 {
            for n in 0..20 {
                use ComputeVectors::*;
                for compute_u in [No, Thin, Full] {
                    for compute_v in [No, Thin, Full] {
                        dbg!(m, n, compute_u, compute_v);
                        let mat = Mat::from_fn(m, n, |_, _| rand::random::<f64>());
                        let size = m.min(n);

                        let mut s = Mat::zeros(m, n);
                        let mut u = Mat::zeros(
                            m,
                            match compute_u {
                                No => 0,
                                Thin => size,
                                Full => m,
                            },
                        );
                        let mut v = Mat::zeros(
                            n,
                            match compute_v {
                                No => 0,
                                Thin => size,
                                Full => n,
                            },
                        );

                        compute_svd(
                            mat.as_ref(),
                            s.as_mut()
                                .submatrix_mut(0, 0, size, size)
                                .diagonal_mut()
                                .column_vector_mut()
                                .as_2d_mut(),
                            if compute_u == No {
                                None
                            } else {
                                Some(u.as_mut())
                            },
                            if compute_v == No {
                                None
                            } else {
                                Some(v.as_mut())
                            },
                            Parallelism::None,
                            make_stack!(compute_svd_req::<f64>(
                                m,
                                n,
                                compute_u,
                                compute_v,
                                Parallelism::None,
                                SvdParams::default(),
                            )),
                            SvdParams::default(),
                        );

                        let mut s_target = Mat::zeros(m, n);
                        let mut u_target = Mat::zeros(m, m);
                        let mut v_target = Mat::zeros(n, n);

                        compute_svd(
                            mat.as_ref(),
                            s_target
                                .as_mut()
                                .submatrix_mut(0, 0, size, size)
                                .diagonal_mut()
                                .column_vector_mut()
                                .as_2d_mut(),
                            Some(u_target.as_mut()),
                            Some(v_target.as_mut()),
                            Parallelism::None,
                            make_stack!(compute_svd_req::<f64>(
                                m,
                                n,
                                ComputeVectors::Full,
                                ComputeVectors::Full,
                                Parallelism::None,
                                SvdParams::default(),
                            )),
                            SvdParams::default(),
                        );

                        for j in 0..u.ncols() {
                            for i in 0..u.nrows() {
                                assert_approx_eq!(u.read(i, j), u_target.read(i, j), 1e-10);
                            }
                        }
                        for j in 0..v.ncols() {
                            for i in 0..v.nrows() {
                                assert_approx_eq!(v.read(i, j), v_target.read(i, j), 1e-10);
                            }
                        }
                        for j in 0..s.ncols() {
                            for i in 0..s.nrows() {
                                assert_approx_eq!(s.read(i, j), s_target.read(i, j), 1e-10);
                            }
                        }
                    }
                }
            }
        }
    }

    /// The public `compute_svd` entry point reconstructs random `c64` matrices as `U S V^H`.
    #[test]
    fn test_cplx() {
        for m in 0..20 {
            for n in 0..20 {
                let mat = Mat::from_fn(m, n, |_, _| c64::new(rand::random(), rand::random()));
                let size = m.min(n);

                let mut s = Mat::zeros(m, n);
                let mut u = Mat::zeros(m, m);
                let mut v = Mat::zeros(n, n);

                compute_svd(
                    mat.as_ref(),
                    s.as_mut()
                        .submatrix_mut(0, 0, size, size)
                        .diagonal_mut()
                        .column_vector_mut()
                        .as_2d_mut(),
                    Some(u.as_mut()),
                    Some(v.as_mut()),
                    Parallelism::None,
                    make_stack!(compute_svd_req::<c64>(
                        m,
                        n,
                        ComputeVectors::Full,
                        ComputeVectors::Full,
                        Parallelism::None,
                        SvdParams::default(),
                    )),
                    SvdParams::default(),
                );

                let reconstructed = &u * &s * v.adjoint();

                for j in 0..n {
                    for i in 0..m {
                        assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                    }
                }
            }
        }
    }

    /// Same as `test_cplx`, in single precision (hence the looser tolerance).
    #[test]
    fn test_cplx_f32() {
        for m in 0..20 {
            for n in 0..20 {
                let mat = Mat::from_fn(m, n, |_, _| c32::new(rand::random(), rand::random()));
                let size = m.min(n);

                let mut s = Mat::zeros(m, n);
                let mut u = Mat::zeros(m, m);
                let mut v = Mat::zeros(n, n);

                compute_svd(
                    mat.as_ref(),
                    s.as_mut()
                        .submatrix_mut(0, 0, size, size)
                        .diagonal_mut()
                        .column_vector_mut()
                        .as_2d_mut(),
                    Some(u.as_mut()),
                    Some(v.as_mut()),
                    Parallelism::None,
                    make_stack!(compute_svd_req::<c32>(
                        m,
                        n,
                        ComputeVectors::Full,
                        ComputeVectors::Full,
                        Parallelism::None,
                        SvdParams::default(),
                    )),
                    SvdParams::default(),
                );

                let reconstructed = &u * &s * v.adjoint();

                for j in 0..n {
                    for i in 0..m {
                        assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-3);
                    }
                }
            }
        }
    }

    /// Every combination of `No`/`Thin`/`Full` for `u` and `v` produces values matching the
    /// leading columns of a full complex computation.
    #[test]
    fn test_cplx_thin() {
        for m in 0..20 {
            for n in 0..20 {
                use ComputeVectors::*;
                for compute_u in [No, Thin, Full] {
                    for compute_v in [No, Thin, Full] {
                        dbg!(m, n, compute_u, compute_v);
                        let mat =
                            Mat::from_fn(m, n, |_, _| c64::new(rand::random(), rand::random()));
                        let size = m.min(n);

                        let mut s = Mat::zeros(m, n);
                        let mut u = Mat::zeros(
                            m,
                            match compute_u {
                                No => 0,
                                Thin => size,
                                Full => m,
                            },
                        );
                        let mut v = Mat::zeros(
                            n,
                            match compute_v {
                                No => 0,
                                Thin => size,
                                Full => n,
                            },
                        );

                        compute_svd(
                            mat.as_ref(),
                            s.as_mut()
                                .submatrix_mut(0, 0, size, size)
                                .diagonal_mut()
                                .column_vector_mut()
                                .as_2d_mut(),
                            if compute_u == No {
                                None
                            } else {
                                Some(u.as_mut())
                            },
                            if compute_v == No {
                                None
                            } else {
                                Some(v.as_mut())
                            },
                            Parallelism::None,
                            make_stack!(compute_svd_req::<c64>(
                                m,
                                n,
                                compute_u,
                                compute_v,
                                Parallelism::None,
                                SvdParams::default(),
                            )),
                            SvdParams::default(),
                        );

                        let mut s_target = Mat::zeros(m, n);
                        let mut u_target = Mat::zeros(m, m);
                        let mut v_target = Mat::zeros(n, n);

                        compute_svd(
                            mat.as_ref(),
                            s_target
                                .as_mut()
                                .submatrix_mut(0, 0, size, size)
                                .diagonal_mut()
                                .column_vector_mut()
                                .as_2d_mut(),
                            Some(u_target.as_mut()),
                            Some(v_target.as_mut()),
                            Parallelism::None,
                            make_stack!(compute_svd_req::<c64>(
                                m,
                                n,
                                ComputeVectors::Full,
                                ComputeVectors::Full,
                                Parallelism::None,
                                SvdParams::default(),
                            )),
                            SvdParams::default(),
                        );

                        for j in 0..u.ncols() {
                            for i in 0..u.nrows() {
                                assert_approx_eq!(u.read(i, j), u_target.read(i, j), 1e-10);
                            }
                        }
                        for j in 0..v.ncols() {
                            for i in 0..v.nrows() {
                                assert_approx_eq!(v.read(i, j), v_target.read(i, j), 1e-10);
                            }
                        }
                        for j in 0..s.ncols() {
                            for i in 0..s.nrows() {
                                assert_approx_eq!(s.read(i, j), s_target.read(i, j), 1e-10);
                            }
                        }
                    }
                }
            }
        }
    }

    /// `compute_svd_big` with the complex bidiagonal solver reconstructs a complex
    /// (rectangular) identity matrix.
    #[test]
    fn test_cplx_identity() {
        for (m, n) in [(15, 10), (10, 10), (15, 15)] {
            let mut mat = Mat::zeros(m, n);
            let size = m.min(n);
            for i in 0..size {
                mat.write(i, i, c64::faer_one());
            }

            let mut s = Mat::zeros(m, n);
            let mut u = Mat::zeros(m, m);
            let mut v = Mat::zeros(n, n);

            compute_svd_big(
                mat.as_ref(),
                s.as_mut()
                    .submatrix_mut(0, 0, size, size)
                    .diagonal_mut()
                    .column_vector_mut()
                    .as_2d_mut(),
                Some(u.as_mut()),
                Some(v.as_mut()),
                compute_bidiag_cplx_svd::<c64>,
                f64::EPSILON,
                f64::MIN_POSITIVE,
                Parallelism::None,
                make_stack!(compute_svd_big_req::<c64>(
                    m,
                    n,
                    ComputeVectors::Full,
                    ComputeVectors::Full,
                    bidiag_cplx_svd_req::<f64>,
                    Parallelism::None,
                )),
            );

            // Complex SVD: M = U S V^H, so the conjugate transpose of V is required.
            let reconstructed = &u * &s * v.adjoint();

            for j in 0..n {
                for i in 0..m {
                    assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                }
            }
        }
    }

    /// `compute_svd_big` with the complex bidiagonal solver handles the all-zero complex matrix.
    #[test]
    fn test_cplx_zero() {
        for (m, n) in [(15, 10), (10, 10), (15, 15)] {
            let mat = Mat::zeros(m, n);
            let size = m.min(n);

            let mut s = Mat::zeros(m, n);
            let mut u = Mat::zeros(m, m);
            let mut v = Mat::zeros(n, n);

            compute_svd_big(
                mat.as_ref(),
                s.as_mut()
                    .submatrix_mut(0, 0, size, size)
                    .diagonal_mut()
                    .column_vector_mut()
                    .as_2d_mut(),
                Some(u.as_mut()),
                Some(v.as_mut()),
                // The solver's scalar type drives inference of `E` (and thus of `mat`, `s`, `u`,
                // `v`). It must be `c64` for this test to exercise the complex code path.
                compute_bidiag_cplx_svd::<c64>,
                f64::EPSILON,
                f64::MIN_POSITIVE,
                Parallelism::None,
                make_stack!(compute_svd_big_req::<c64>(
                    m,
                    n,
                    ComputeVectors::Full,
                    ComputeVectors::Full,
                    bidiag_cplx_svd_req::<f64>,
                    Parallelism::None,
                )),
            );

            // Complex SVD: M = U S V^H, so the conjugate transpose of V is required.
            let reconstructed = &u * &s * v.adjoint();

            for j in 0..n {
                for i in 0..m {
                    assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                }
            }
        }
    }

    /// `compute_svd` reconstructs the rank-one all-ones real matrix for a range of shapes.
    #[test]
    fn test_real_ones() {
        for n in [1, 2, 4, 8, 64, 512] {
            for m in [1, 2, 4, 8, 64, 512] {
                let f = |_, _| 1f64;
                let mat = Mat::from_fn(m, n, f);
                let mut s = Mat::zeros(m, n);
                let mut u = Mat::zeros(m, m);
                let mut v = Mat::zeros(n, n);

                compute_svd(
                    mat.as_ref(),
                    s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
                    Some(u.as_mut()),
                    Some(v.as_mut()),
                    faer_core::Parallelism::None,
                    make_stack!(compute_svd_req::<f64>(
                        m,
                        n,
                        ComputeVectors::Full,
                        ComputeVectors::Full,
                        faer_core::Parallelism::None,
                        Default::default(),
                    )),
                    Default::default(),
                );

                let reconstructed = &u * &s * v.transpose();

                for j in 0..n {
                    for i in 0..m {
                        assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                    }
                }
            }
        }
    }

    /// `compute_svd` reconstructs the rank-one all-ones complex matrix for a range of shapes.
    #[test]
    fn test_cplx_ones() {
        for n in [1, 2, 4, 8, 32, 64, 512] {
            for m in [1, 2, 4, 8, 32, 64, 512] {
                let f = |_, _| c64::new(1.0, 0.0);
                let mat = Mat::from_fn(m, n, f);
                let mut s = Mat::zeros(m, n);
                let mut u = Mat::zeros(m, m);
                let mut v = Mat::zeros(n, n);

                compute_svd(
                    mat.as_ref(),
                    s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
                    Some(u.as_mut()),
                    Some(v.as_mut()),
                    faer_core::Parallelism::None,
                    make_stack!(compute_svd_req::<c64>(
                        m,
                        n,
                        ComputeVectors::Full,
                        ComputeVectors::Full,
                        faer_core::Parallelism::None,
                        Default::default(),
                    )),
                    Default::default(),
                );

                // Complex SVD: M = U S V^H, so the conjugate transpose of V is required.
                let reconstructed = &u * &s * v.adjoint();
                for j in 0..n {
                    for i in 0..m {
                        assert_approx_eq!(reconstructed.read(i, j), mat.read(i, j), 1e-10);
                    }
                }
            }
        }
    }
}