faer/linalg/svd/
mod.rs

1//! low level implementation of the svd of a matrix
2//!
3//! the svd of a matrix $A$ of shape $(m, n)$ is a decomposition into three components $U$, $S$,
4//! and $V$, such that:
5//!
6//! - $U$ has shape $(m, m)$ and is a unitary matrix
7//! - $V$ has shape $(n, n)$ and is a unitary matrix
8//! - $S$ has shape $(m, n)$ and is zero everywhere except the main diagonal
9//! - and finally:
10//!
11//! $$A = U S V^H$$
12
13use bidiag::BidiagParams;
14use linalg::qr::no_pivoting::factor::QrParams;
15
16use crate::assert;
17use crate::internal_prelude::*;
18
19/// bidiagonalization
20pub mod bidiag;
21pub(crate) mod bidiag_svd;
22
/// whether the singular vectors should be computed
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ComputeSvdVectors {
	/// do not compute singular vectors
	No,
	/// compute the first $\min(\text{nrows}, \text{ncols})$ singular vectors
	Thin,
	/// compute the full singular vector matrix, of shape $(\text{nrows}, \text{nrows})$
	/// for $U$ and $(\text{ncols}, \text{ncols})$ for $V$
	Full,
}
33
/// svd error
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SvdError {
	/// reached the maximum number of iterations before the singular values converged;
	/// also returned when the input contains non-finite values
	NoConvergence,
}
40
/// svd tuning parameters
#[derive(Debug, Copy, Clone)]
pub struct SvdParams {
	/// bidiagonalization parameters
	pub bidiag: BidiagParams,
	/// $QR$ parameters
	pub qr: QrParams,
	/// threshold at which the implementation should stop recursing
	pub recursion_threshold: usize,
	/// aspect ratio $\text{nrows} / \text{ncols}$ above which a $QR$ factorization is
	/// performed first, and the svd is computed from the triangular factor
	pub qr_ratio_threshold: f64,

	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
56
57impl<T: ComplexField> Auto<T> for SvdParams {
58	fn auto() -> Self {
59		Self {
60			recursion_threshold: 128,
61			qr_ratio_threshold: 11.0 / 6.0,
62
63			bidiag: auto!(T),
64			qr: auto!(T),
65			non_exhaustive: NonExhaustive(()),
66		}
67	}
68}
69
/// computes the workspace requirement of `svd_imp` for an `m × n` matrix with `m >= n`
///
/// `bidiag_svd_scratch` is the scratch function of the bidiagonal svd backend
/// (real or complex)
fn svd_imp_scratch<T: ComplexField>(
	m: usize,
	n: usize,
	compute_u: ComputeSvdVectors,
	compute_v: ComputeSvdVectors,

	bidiag_svd_scratch: fn(n: usize, compute_u: bool, compute_v: bool, par: Par, params: SvdParams) -> StackReq,

	params: SvdParams,

	par: Par,
) -> StackReq {
	assert!(m >= n);

	let householder_blocksize = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
	// storage for the bidiagonalized matrix and the block householder coefficients
	let bid = temp_mat_scratch::<T>(m, n);
	let householder_left = temp_mat_scratch::<T>(householder_blocksize, n);
	let householder_right = temp_mat_scratch::<T>(householder_blocksize, n);

	let compute_bidiag = bidiag::bidiag_in_place_scratch::<T>(m, n, par, params.bidiag.into());
	let diag = temp_mat_scratch::<T>(n, 1);
	let subdiag = diag;
	// note: the roles of `u` and `v` are deliberately exchanged here — `svd_imp` feeds
	// the backend the conjugated bidiagonal matrix, so `ub` is needed when `v` is
	// requested and `vb` when `u` is requested (see the backend call in `svd_imp`)
	let compute_ub = compute_v != ComputeSvdVectors::No;
	let compute_vb = compute_u != ComputeSvdVectors::No;
	let u_b = temp_mat_scratch::<T>(if compute_ub { n + 1 } else { 2 }, n + 1);
	let v_b = temp_mat_scratch::<T>(n, if compute_vb { n } else { 0 });

	let compute_bidiag_svd = bidiag_svd_scratch(n, compute_ub, compute_vb, par, params);

	// applying the left householder sequence to `u` (full height)
	let apply_householder_u = linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
		m,
		householder_blocksize,
		match compute_u {
			ComputeSvdVectors::No => 0,
			ComputeSvdVectors::Thin => n,
			ComputeSvdVectors::Full => m,
		},
	);
	// applying the right householder sequence to rows `1..n` of `v`
	let apply_householder_v = linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
		n - 1,
		householder_blocksize,
		match compute_v {
			ComputeSvdVectors::No => 0,
			_ => n,
		},
	);

	// the bidiagonalization scratch and the later stages are not live at the same time
	StackReq::all_of(&[
		bid,
		householder_left,
		householder_right,
		StackReq::any_of(&[
			compute_bidiag,
			StackReq::all_of(&[
				diag,
				subdiag,
				u_b,
				v_b,
				StackReq::any_of(&[compute_bidiag_svd, StackReq::all_of(&[apply_householder_u, apply_householder_v])]),
			]),
		]),
	])
}
133
/// computes the workspace requirement of `compute_bidiag_cplx_svd`
fn bidiag_cplx_svd_scratch<T: ComplexField>(n: usize, compute_u: bool, compute_v: bool, par: Par, params: SvdParams) -> StackReq {
	StackReq::all_of(&[
		// diag_real, subdiag_real, col_mul and row_mul — complex-sized storage covers
		// the two real columns as well
		temp_mat_scratch::<T>(n, 1).array(4),
		// real singular vector storage, sized like in `compute_bidiag_cplx_svd`
		temp_mat_scratch::<T::Real>(n + 1, if compute_u { n + 1 } else { 0 }),
		temp_mat_scratch::<T::Real>(n, if compute_v { n } else { 0 }),
		bidiag_real_svd_scratch::<T::Real>(n, compute_u, compute_v, par, params),
	])
}
142
/// computes the workspace requirement of `compute_bidiag_real_svd`
fn bidiag_real_svd_scratch<T: RealField>(n: usize, compute_u: bool, compute_v: bool, par: Par, params: SvdParams) -> StackReq {
	// small problems use the in-place qr algorithm, which needs no extra workspace
	if n < params.recursion_threshold {
		StackReq::EMPTY
	} else {
		StackReq::all_of(&[
			// two-rows storage used in place of `u` when it is not requested
			temp_mat_scratch::<T>(2, if compute_u { 0 } else { n + 1 }),
			bidiag_svd::divide_and_conquer_scratch::<T>(n, params.recursion_threshold, compute_u, compute_v, par),
		])
	}
}
153
/// svd of a complex bidiagonal matrix, computed by reduction to a real bidiagonal svd
///
/// the complex bidiagonal matrix is factored as a real bidiagonal matrix surrounded by
/// diagonal matrices of unit-modulus scalars (`col_mul` and `row_mul`); the real svd is
/// computed, then the scalings are folded back into the rows of `u` and the columns of
/// `v`
#[math]
fn compute_bidiag_cplx_svd<T: ComplexField>(
	mut diag: ColMut<'_, T, usize, ContiguousFwd>,
	subdiag: ColMut<'_, T, usize, ContiguousFwd>,
	mut u: Option<MatMut<'_, T>>,
	mut v: Option<MatMut<'_, T>>,
	params: SvdParams,
	par: Par,
	stack: &mut MemStack,
) -> Result<(), SvdError> {
	let n = diag.nrows();

	// real counterparts of the bidiagonal entries and of the singular vector storage
	let (mut diag_real, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, 1, stack) };
	let (mut subdiag_real, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, 1, stack) };
	let (mut u_real, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n + 1, if u.is_some() { n + 1 } else { 0 }, stack) };
	let (mut v_real, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, if v.is_some() { n } else { 0 }, stack) };

	// unit-modulus factors extracted from the columns and rows of the bidiagonal matrix
	let (mut col_mul, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, 1, stack) };
	let (mut row_mul, stack) = unsafe { temp_mat_uninit::<T, _, _>(n - 1, 1, stack) };

	let mut u_real = u.rb().map(|_| u_real.as_mat_mut());
	let mut v_real = v.rb().map(|_| v_real.as_mat_mut());

	let mut diag_real = diag_real.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap();
	let mut subdiag_real = subdiag_real.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap();

	let mut col_mul = col_mul.as_mat_mut().col_mut(0);
	let mut row_mul = row_mul.as_mat_mut().col_mut(0);

	// x / |x|, prescaled by the largest of the real/imaginary parts to avoid
	// overflow/underflow in the intermediate computation; maps zero to one
	let normalized = |x: T| {
		if x == zero() {
			one()
		} else {
			let norm1 = max(abs(real(x)), abs(imag(x)));
			let y = x * from_real(recip(norm1));
			y * from_real(recip(abs(y)))
		}
	};

	// propagate the phases along the bidiagonal so that every entry becomes its
	// absolute value, recording the per-column and per-row phase corrections
	let mut col_normalized = normalized(conj(diag[0]));
	col_mul[0] = copy(col_normalized);
	diag_real[0] = abs(diag[0]);
	// last subdiagonal entry is structurally zero
	subdiag_real[n - 1] = zero();
	for i in 1..n {
		let row_normalized = normalized(conj(subdiag[i - 1] * col_normalized));
		subdiag_real[i - 1] = abs(subdiag[i - 1]);
		row_mul[i - 1] = conj(row_normalized);

		col_normalized = normalized(conj(diag[i] * row_normalized));
		diag_real[i] = abs(diag[i]);
		col_mul[i] = copy(col_normalized);
	}

	compute_bidiag_real_svd(
		diag_real.rb_mut(),
		subdiag_real.rb_mut(),
		u_real.rb_mut(),
		v_real.rb_mut(),
		params,
		par,
		stack,
	)?;

	// copy the (real, nonnegative) singular values back into the complex output
	for i in 0..n {
		diag[i] = from_real(diag_real[i]);
	}

	let u_real = u_real.rb();
	let v_real = v_real.rb();

	// fold the row scalings back into `u`; rows 0 and n carry no scaling
	if let (Some(mut u), Some(u_real)) = (u.rb_mut(), u_real) {
		z!(u.rb_mut().row_mut(0), u_real.row(0)).for_each(|uz!(u, r)| *u = from_real(*r));
		z!(u.rb_mut().row_mut(n), u_real.row(n)).for_each(|uz!(u, r)| *u = from_real(*r));

		for j in 0..u.ncols() {
			let mut u = u.rb_mut().col_mut(j).subrows_mut(1, n - 1);
			let u_real = u_real.rb().col(j).subrows(1, n - 1);
			z!(u.rb_mut(), u_real, row_mul.rb()).for_each(|uz!(u, re, f)| *u = mul_real(*f, *re));
		}
	}
	// fold the column scalings back into `v`
	if let (Some(mut v), Some(v_real)) = (v.rb_mut(), v_real) {
		for j in 0..v.ncols() {
			let mut v = v.rb_mut().col_mut(j);
			let v_real = v_real.rb().col(j);
			z!(v.rb_mut(), v_real, col_mul.rb()).for_each(|uz!(v, re, f)| *v = mul_real(*f, *re));
		}
	}

	Ok(())
}
244
/// svd of a real bidiagonal matrix: qr algorithm for small sizes, divide-and-conquer
/// otherwise
///
/// `u`, when provided, has `n + 1` rows (the extra row is used by the
/// divide-and-conquer algorithm)
#[math]
fn compute_bidiag_real_svd<T: RealField>(
	mut diag: ColMut<'_, T, usize, ContiguousFwd>,
	mut subdiag: ColMut<'_, T, usize, ContiguousFwd>,
	mut u: Option<MatMut<'_, T, usize, usize>>,
	mut v: Option<MatMut<'_, T, usize, usize>>,
	params: SvdParams,
	par: Par,
	stack: &mut MemStack,
) -> Result<(), SvdError> {
	let n = diag.nrows();
	// bail out early on non-finite input instead of iterating forever
	for i in 0..n {
		if !(is_finite(diag[i]) && is_finite(subdiag[i])) {
			return Err(SvdError::NoConvergence);
		}
	}

	if n < params.recursion_threshold {
		// the qr algorithm accumulates rotations into `u`/`v`, so they start as identity
		if let Some(mut u) = u.rb_mut() {
			u.fill(zero());
			u.diagonal_mut().fill(one());
		}
		if let Some(mut v) = v.rb_mut() {
			v.fill(zero());
			v.diagonal_mut().fill(one());
		}

		bidiag_svd::qr_algorithm(
			diag.rb_mut(),
			subdiag.rb_mut(),
			// only the top-left n × n block of `u` is touched by the qr algorithm
			u.rb_mut().map(|u| u.submatrix_mut(0, 0, n, n)),
			v.rb_mut(),
		)?;

		return Ok(());
	} else {
		// when `u` is not requested, divide-and-conquer still needs two rows of storage
		let (mut u2, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(2, if u.is_some() { 0 } else { n + 1 }, stack) };

		bidiag_svd::divide_and_conquer(
			diag.as_row_shape_mut(n),
			subdiag.as_row_shape_mut(n),
			match u {
				Some(u) => bidiag_svd::MatU::Full(u),
				None => bidiag_svd::MatU::TwoRowsStorage(u2.as_mat_mut()),
			},
			v.map(|m| m.as_shape_mut(n, n)),
			par,
			stack,
			params.recursion_threshold,
		)
	}
}
297
/// bidiag -> divide conquer svd / qr algo
///
/// svd of a general matrix with `nrows >= ncols`: the matrix is reduced to bidiagonal
/// form with householder transforms, `bidiag_svd` is applied to the bidiagonal part,
/// then the householder transforms are applied back onto the singular vectors
#[math]
fn svd_imp<T: ComplexField>(
	matrix: MatRef<'_, T>,
	s: ColMut<'_, T>,
	u: Option<MatMut<'_, T>>,
	v: Option<MatMut<'_, T>>,
	bidiag_svd: fn(
		diag: ColMut<'_, T, usize, ContiguousFwd>,
		subdiag: ColMut<'_, T, usize, ContiguousFwd>,
		u: Option<MatMut<'_, T, usize, usize>>,
		v: Option<MatMut<'_, T, usize, usize>>,
		params: SvdParams,
		par: Par,
		stack: &mut MemStack,
	) -> Result<(), SvdError>,
	par: Par,
	stack: &mut MemStack,
	params: SvdParams,
) -> Result<(), SvdError> {
	assert!(matrix.nrows() >= matrix.ncols());
	let m = matrix.nrows();
	let n = matrix.ncols();

	let bs = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);

	// working copy, overwritten with the bidiagonal entries and householder vectors
	let (mut bid, stack) = unsafe { temp_mat_uninit::<T, _, _>(m, n, stack) };
	let mut bid = bid.as_mat_mut();

	// block householder coefficients for the left and right transform sequences
	let (mut Hl, stack) = unsafe { temp_mat_uninit::<T, _, _>(bs, n, stack) };
	let (mut Hr, stack) = unsafe { temp_mat_uninit::<T, _, _>(bs, n - 1, stack) };

	let mut Hl = Hl.as_mat_mut();
	let mut Hr = Hr.as_mat_mut();

	bid.copy_from(matrix);
	bidiag::bidiag_in_place(bid.rb_mut(), Hl.rb_mut(), Hr.rb_mut(), par, stack, params.bidiag.into());

	let (mut diag, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, 1, stack) };
	let (mut subdiag, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, 1, stack) };
	let mut diag = diag.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap();
	let mut subdiag = subdiag.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap();

	// bidiagonal singular vectors: `ub` is needed to produce `v`, and `vb` to
	// produce `u` (the backend works on the conjugated bidiagonal matrix)
	let (mut ub, stack) = unsafe { temp_mat_uninit::<T, _, _>(n + 1, if v.is_some() { n + 1 } else { 0 }, stack) };
	let (mut vb, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, if u.is_some() { n } else { 0 }, stack) };

	let mut ub = ub.as_mat_mut();
	let mut vb = vb.as_mat_mut();

	// extract the (conjugated) diagonal and superdiagonal of the bidiagonal matrix
	for i in 0..n {
		diag[i] = conj(bid[(i, i)]);
		if i + 1 < n {
			subdiag[i] = conj(bid[(i, i + 1)]);
		} else {
			subdiag[i] = zero();
		}
	}

	// note the swapped roles: `ub` receives the backend's `u` but contributes to `v`,
	// and `vb` receives the backend's `v` but contributes to `u`
	bidiag_svd(
		diag.rb_mut(),
		subdiag.rb_mut(),
		v.rb().map(|_| ub.rb_mut()),
		u.rb().map(|_| vb.rb_mut()),
		params,
		par,
		stack,
	)?;

	{ s }.copy_from(diag);

	if let Some(mut u) = u {
		// embed the n × n bidiagonal vectors into the full-height `u`, padding with
		// zeros (and an identity block in the full case), then apply the left
		// householder sequence
		let ncols = u.ncols();
		u.rb_mut().submatrix_mut(0, 0, n, n).copy_from(vb.rb());
		u.rb_mut().submatrix_mut(n, 0, m - n, ncols).fill(zero());
		u.rb_mut().submatrix_mut(0, n, n, ncols - n).fill(zero());
		u.rb_mut().submatrix_mut(n, n, ncols - n, ncols - n).diagonal_mut().fill(one());

		linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(bid.rb(), Hl.rb(), Conj::No, u, par, stack);
	}
	if let Some(mut v) = v {
		v.copy_from(ub.rb().submatrix(0, 0, n, n));

		// mirror the upper triangle into the lower triangle so that the right
		// householder vectors become addressable column-wise
		for j in 1..n {
			for i in 0..j {
				bid[(j, i)] = copy(bid[(i, j)]);
			}
		}

		// the first row of `v` is untouched by the right householder sequence
		linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
			bid.rb().submatrix(1, 0, n - 1, n - 1),
			Hr.rb(),
			Conj::Yes,
			v.subrows_mut(1, n - 1),
			par,
			stack,
		);
	}

	Ok(())
}
398
/// dispatches to the real or complex bidiagonal svd backend, depending on `T`
///
/// for real `T`, the computation is routed directly through the real backend, skipping
/// the complex-to-real phase reduction
fn compute_squareish_svd<T: ComplexField>(
	matrix: MatRef<'_, T>,
	s: ColMut<'_, T>,
	u: Option<MatMut<'_, T>>,
	v: Option<MatMut<'_, T>>,
	par: Par,
	stack: &mut MemStack,
	params: SvdParams,
) -> Result<(), SvdError> {
	if try_const! { T::IS_REAL } {
		// SAFETY: this branch is only taken when `T::IS_REAL`, so `T` and `T::Real`
		// are the same type and the transmutes only rewrite the type parameter —
		// NOTE(review): assumes `ComplexField` guarantees this invariant; confirm
		svd_imp::<T::Real>(
			unsafe { core::mem::transmute(matrix) },
			unsafe { core::mem::transmute(s) },
			unsafe { core::mem::transmute(u) },
			unsafe { core::mem::transmute(v) },
			compute_bidiag_real_svd::<T::Real>,
			par,
			stack,
			params,
		)
	} else {
		svd_imp::<T>(matrix, s, u, v, compute_bidiag_cplx_svd::<T>, par, stack, params)
	}
}
423
424/// computes the size and alignment of the workspace required to compute a matrix's svd
425pub fn svd_scratch<T: ComplexField>(
426	nrows: usize,
427	ncols: usize,
428	compute_u: ComputeSvdVectors,
429	compute_v: ComputeSvdVectors,
430	par: Par,
431	params: Spec<SvdParams, T>,
432) -> StackReq {
433	let params = params.config;
434	let mut m = nrows;
435	let mut n = ncols;
436	let mut compute_u = compute_u;
437	let mut compute_v = compute_v;
438
439	if n > m {
440		core::mem::swap(&mut m, &mut n);
441		core::mem::swap(&mut compute_u, &mut compute_v);
442	}
443
444	if n == 0 {
445		return StackReq::EMPTY;
446	}
447
448	let bidiag_svd_scratch = if try_const! { T::IS_REAL } {
449		bidiag_real_svd_scratch::<T::Real>
450	} else {
451		bidiag_cplx_svd_scratch::<T>
452	};
453
454	if m as f64 / n as f64 <= params.qr_ratio_threshold {
455		svd_imp_scratch::<T>(m, n, compute_u, compute_v, bidiag_svd_scratch, params, par)
456	} else {
457		let bs = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
458		StackReq::all_of(&[
459			temp_mat_scratch::<T>(m, n),
460			temp_mat_scratch::<T>(bs, n),
461			StackReq::any_of(&[
462				StackReq::all_of(&[
463					temp_mat_scratch::<T>(n, n),
464					svd_imp_scratch::<T>(n, n, compute_u, compute_v, bidiag_svd_scratch, params, par),
465				]),
466				linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
467					m,
468					bs,
469					match compute_u {
470						ComputeSvdVectors::No => 0,
471						ComputeSvdVectors::Thin => n,
472						ComputeSvdVectors::Full => m,
473					},
474				),
475			]),
476		])
477	}
478}
479
/// computes the svd of $A$, with the singular vectors being omitted, thin or full
///
/// the singular values are stored in $S$, and the singular vectors in $U$ and $V$ such
/// that the singular values are sorted in nonincreasing order
#[math]
pub fn svd<T: ComplexField>(
	A: MatRef<'_, T>,
	s: DiagMut<'_, T>,
	u: Option<MatMut<'_, T>>,
	v: Option<MatMut<'_, T>>,
	par: Par,
	stack: &mut MemStack,
	params: Spec<SvdParams, T>,
) -> Result<(), SvdError> {
	let params = params.config;

	let (m, n) = A.shape();
	let size = Ord::min(m, n);
	// `s` must hold exactly min(nrows, ncols) singular values
	assert!(s.dim() == size);
	let s = s.column_vector_mut();

	// `u`, when requested, must be full (nrows × nrows) or thin (nrows × size)
	if let Some(u) = u.rb() {
		assert!(all(u.nrows() == A.nrows(), any(u.ncols() == A.nrows(), u.ncols() == size),));
	}
	// `v`, when requested, must be full (ncols × ncols) or thin (ncols × size)
	if let Some(v) = v.rb() {
		assert!(all(v.nrows() == A.ncols(), any(v.ncols() == A.ncols(), v.ncols() == size),));
	}

	#[cfg(feature = "perf-warn")]
	match (u.rb(), v.rb()) {
		(Some(matrix), _) | (_, Some(matrix)) => {
			if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(QR_WARN) {
				if matrix.col_stride().unsigned_abs() == 1 {
					log::warn!(target: "faer_perf", "SVD prefers column-major singular vector matrices. Found row-major matrix.");
				} else {
					log::warn!(target: "faer_perf", "SVD prefers column-major singular vector matrices. Found matrix with generic strides.");
				}
			}
		},
		_ => {},
	}

	let mut u = u;
	let mut v = v;
	let mut matrix = A;
	// operate on the transpose when the matrix is wide, so that nrows >= ncols; this
	// swaps the roles of `u` and `v`, and requires a final conjugation of both
	let do_transpose = n > m;
	if do_transpose {
		matrix = matrix.transpose();
		core::mem::swap(&mut u, &mut v)
	}

	let (m, n) = matrix.shape();
	// zero columns: `s` and `v` are empty, `u` is the identity
	if n == 0 {
		if let Some(mut u) = u {
			u.fill(zero());
			u.rb_mut().diagonal_mut().fill(one());
		}
		return Ok(());
	}

	if m as f64 / n as f64 <= params.qr_ratio_threshold {
		compute_squareish_svd(matrix, s, u.rb_mut(), v.rb_mut(), par, stack, params)?;
	} else {
		// tall case: factor A = QR first, then compute the svd of the n × n
		// triangular factor
		let bs = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
		let (mut qr, stack) = unsafe { temp_mat_uninit::<T, _, _>(m, n, stack) };
		let mut qr = qr.as_mat_mut();
		let (mut householder, stack) = unsafe { temp_mat_uninit::<T, _, _>(bs, n, stack) };
		let mut householder = householder.as_mat_mut();

		{
			qr.copy_from(matrix.rb());
			linalg::qr::no_pivoting::factor::qr_in_place(qr.rb_mut(), householder.rb_mut(), par, stack, params.qr.into());
		}

		{
			// extract the triangular factor, zeroing the strictly lower part
			let (mut r, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
			let mut r = r.as_mat_mut();
			z!(r.rb_mut()).for_each_triangular_lower(linalg::zip::Diag::Skip, |uz!(dst)| *dst = zero());
			z!(r.rb_mut(), qr.rb().submatrix(0, 0, n, n)).for_each_triangular_upper(linalg::zip::Diag::Include, |uz!(dst, src)| *dst = copy(*src));

			// r = u s v
			compute_squareish_svd(r.rb(), s, u.rb_mut().map(|u| u.submatrix_mut(0, 0, n, n)), v.rb_mut(), par, stack, params)?;
		}

		// matrix = q u s v
		if let Some(mut u) = u.rb_mut() {
			// extend the n × n left singular vectors to full height before applying q
			u.rb_mut().subrows_mut(n, m - n).fill(zero());
			if u.ncols() == m {
				u.rb_mut().submatrix_mut(n, n, m - n, m - n).diagonal_mut().fill(one());
			}

			linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
				qr.rb(),
				householder.rb(),
				Conj::No,
				u.rb_mut(),
				par,
				stack,
			);
		}
	}

	if do_transpose {
		// conjugate u and v
		if let Some(u) = u.rb_mut() {
			z!(u).for_each(|uz!(u)| *u = conj(*u))
		}
		if let Some(v) = v.rb_mut() {
			z!(v).for_each(|uz!(v)| *v = conj(*v))
		}
	}

	Ok(())
}
594
595/// computes the size and alignment of the workspace required to compute a matrix's
596/// pseudoinverse, given the svd
597pub fn pseudoinverse_from_svd_scratch<T: ComplexField>(nrows: usize, ncols: usize, par: Par) -> StackReq {
598	_ = par;
599	let size = Ord::min(nrows, ncols);
600	StackReq::all_of(&[temp_mat_scratch::<T>(nrows, size), temp_mat_scratch::<T>(ncols, size)])
601}
602
603/// computes a self-adjoint matrix's pseudoinverse, given the svd factors $S$, $U$ and $V$
604#[math]
605pub fn pseudoinverse_from_svd<T: ComplexField>(
606	pinv: MatMut<'_, T>,
607	s: DiagRef<'_, T>,
608	u: MatRef<'_, T>,
609	v: MatRef<'_, T>,
610	par: Par,
611	stack: &mut MemStack,
612) {
613	pseudoinverse_from_svd_with_tolerance(
614		pinv,
615		s,
616		u,
617		v,
618		zero(),
619		eps::<T::Real>() * from_f64::<T::Real>(Ord::max(u.nrows(), v.nrows()) as f64),
620		par,
621		stack,
622	);
623}
624
625/// computes a self-adjoint matrix's pseudoinverse, given the svd factors $S$, $U$ and $V$, and
626/// tolerance parameters for determining zero singular values
627#[math]
628pub fn pseudoinverse_from_svd_with_tolerance<T: ComplexField>(
629	pinv: MatMut<'_, T>,
630	s: DiagRef<'_, T>,
631	u: MatRef<'_, T>,
632	v: MatRef<'_, T>,
633	abs_tol: T::Real,
634	rel_tol: T::Real,
635	par: Par,
636	stack: &mut MemStack,
637) {
638	let mut pinv = pinv;
639	let m = u.nrows();
640	let n = v.nrows();
641	let size = Ord::min(m, n);
642
643	assert!(all(u.nrows() == m, v.nrows() == n, u.ncols() >= size, v.ncols() >= size, s.dim() >= size));
644	let s = s.column_vector();
645	let u = u.get(.., ..size);
646	let v = v.get(.., ..size);
647
648	let smax = s.norm_max();
649	let tol = max(abs_tol, rel_tol * smax);
650
651	let (mut u_trunc, stack) = unsafe { temp_mat_uninit::<T, _, _>(m, size, stack) };
652	let (mut vp_trunc, _) = unsafe { temp_mat_uninit::<T, _, _>(n, size, stack) };
653
654	let mut u_trunc = u_trunc.as_mat_mut();
655	let mut vp_trunc = vp_trunc.as_mat_mut();
656	let mut len = 0;
657
658	for j in 0..n {
659		let x = absmax(s[j]);
660		if x > tol {
661			let p = recip(real(s[j]));
662			u_trunc.rb_mut().col_mut(len).copy_from(u.col(j));
663			z!(vp_trunc.rb_mut().col_mut(len), v.col(j)).for_each(|uz!(dst, src)| *dst = mul_real(*src, p));
664
665			len += 1;
666		}
667	}
668
669	linalg::matmul::matmul(pinv.rb_mut(), Accum::Replace, vp_trunc.rb(), u_trunc.rb().adjoint(), one(), par);
670}
671
672#[cfg(test)]
673mod tests {
674	use super::*;
675	use crate::assert;
676	use crate::stats::prelude::*;
677	use crate::utils::approx::*;
678	use dyn_stack::MemBuffer;
679
	/// checks the svd of `mat` in all vector configurations: full, thin, u-only,
	/// v-only and values-only, verifying the reconstruction $A = U S V^H$ and the
	/// agreement of the partial results with the thin decomposition
	#[track_caller]
	fn test_svd<T: ComplexField>(mat: MatRef<'_, T>) {
		let (m, n) = mat.shape();
		// small thresholds so the divide-and-conquer and qr paths are both exercised
		let params = Spec::new(SvdParams {
			recursion_threshold: 8,
			qr_ratio_threshold: 1.0,
			..auto!(T)
		});
		use faer_traits::math_utils::*;
		// comparison tolerance scaled with the problem size
		let approx_eq = CwiseMat(ApproxEq::<T::Real>::eps() * sqrt(&from_f64(8.0 * Ord::max(m, n) as f64)));

		// full svd: reconstruction check
		{
			let mut s = Mat::zeros(m, n);
			let mut u = Mat::zeros(m, m);
			let mut v = Mat::zeros(n, n);

			svd(
				mat.as_ref(),
				s.as_mut().diagonal_mut(),
				Some(u.as_mut()),
				Some(v.as_mut()),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::Full,
					ComputeSvdVectors::Full,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			let reconstructed = &u * &s * v.adjoint();
			assert!(reconstructed ~ mat);
		}

		// thin svd: reconstruction check; `s`, `u`, `v` are kept for the partial runs
		let size = Ord::min(m, n);
		let mut s = Mat::zeros(size, size);
		let mut u = Mat::zeros(m, size);
		let mut v = Mat::zeros(n, size);

		{
			svd(
				mat.as_ref(),
				s.as_mut().diagonal_mut(),
				Some(u.as_mut()),
				Some(v.as_mut()),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::Thin,
					ComputeSvdVectors::Thin,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			let reconstructed = &u * &s * v.adjoint();
			assert!(reconstructed ~ mat);
		}
		// u-only run must agree with the thin decomposition
		{
			let mut s2 = Mat::zeros(size, size);
			let mut u2 = Mat::zeros(m, size);

			svd(
				mat.as_ref(),
				s2.as_mut().diagonal_mut(),
				Some(u2.as_mut()),
				None,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::Thin,
					ComputeSvdVectors::No,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			assert!(s2 ~ s);
			assert!(u2 ~ u);
		}

		// v-only run must agree with the thin decomposition
		{
			let mut s2 = Mat::zeros(size, size);
			let mut v2 = Mat::zeros(n, size);

			svd(
				mat.as_ref(),
				s2.as_mut().diagonal_mut(),
				None,
				Some(v2.as_mut()),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::No,
					ComputeSvdVectors::Thin,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			assert!(s2 ~ s);
			assert!(v2 ~ v);
		}
		// values-only run must agree with the thin decomposition
		{
			let mut s2 = Mat::zeros(size, size);

			svd(
				mat.as_ref(),
				s2.as_mut().diagonal_mut(),
				None,
				None,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(svd_scratch::<T>(
					m,
					n,
					ComputeSvdVectors::No,
					ComputeSvdVectors::No,
					Par::Seq,
					params,
				))),
				params,
			)
			.unwrap();

			assert!(s2 ~ s);
		}
	}
820
	/// random real matrices of various shapes, including tall and wide ones that
	/// exercise the qr-preprocessing path
	#[test]
	fn test_real() {
		let rng = &mut StdRng::seed_from_u64(1);

		for (m, n) in [
			(3, 2),
			(2, 2),
			(4, 4),
			(15, 10),
			(10, 10),
			(15, 15),
			(50, 50),
			(100, 100),
			(150, 150),
			(150, 20),
			(20, 150),
		] {
			let mat = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: StandardNormal,
			}
			.rand::<Mat<f64>>(rng);

			test_svd(mat.as_ref());
		}
	}
848
	/// random complex matrices of various shapes, covering sizes around the recursion
	/// threshold used by `test_svd`
	#[test]
	fn test_cplx() {
		let rng = &mut StdRng::seed_from_u64(1);

		for (m, n) in [
			(1, 1),
			(2, 2),
			(3, 2),
			(2, 2),
			(3, 3),
			(4, 4),
			(15, 10),
			(10, 10),
			(15, 15),
			(16, 16),
			(17, 17),
			(18, 18),
			(19, 19),
			(20, 20),
			(30, 30),
			(50, 50),
			(100, 100),
			(150, 150),
			(150, 20),
			(20, 150),
		] {
			let mat = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			test_svd(mat.as_ref());
		}
	}
885
	/// degenerate inputs: zero, rank-one (all-ones) and identity matrices, in both
	/// real and complex variants
	#[test]
	fn test_special() {
		for (m, n) in [
			(3, 2),
			(2, 2),
			(4, 4),
			(15, 10),
			(10, 10),
			(15, 15),
			(50, 50),
			(100, 100),
			(150, 150),
			(150, 20),
			(20, 150),
		] {
			test_svd(Mat::<f64>::zeros(m, n).as_ref());
			test_svd(Mat::<c64>::zeros(m, n).as_ref());
			test_svd(Mat::<f64>::full(m, n, 1.0).as_ref());
			test_svd(Mat::<c64>::full(m, n, c64::ONE).as_ref());
			test_svd(Mat::<f64>::identity(m, n).as_ref());
			test_svd(Mat::<c64>::identity(m, n).as_ref());
		}
	}
909
	/// regression test: a real bidiagonal matrix whose entries span many orders of
	/// magnitude; checks that the smallest singular value does not collapse to zero
	#[test]
	fn test_zink() {
		let diag = [
			-9.931833701529301,
			-10.920807536026027,
			-52.33647796311243,
			2.3685025127736967,
			2.421701994236093,
			-0.5051763005624579,
			-0.04808263896606017,
			-0.003875251886338955,
			-0.0006413264967716465,
			-0.003381944152463707,
			2.981152313236375e-5,
			5.4290648208388795e-6,
			-6.329275972084404e-7,
			-6.879142344209158e-7,
			-5.265228263479126e-9,
			-2.941999902335516e-9,
			-1.3060984997930294e-10,
			7.07516117218088e-12,
			1.8657003929029376e-12,
			-6.216080089659131e-14,
		];
		let subdiag = [
			-57.8029649868477,
			17.67263066467847,
			8.884153814270894,
			-9.01998231080713,
			-1.028638150814966,
			0.22247719217200435,
			0.016389886745811315,
			-0.004090989452162578,
			0.00036818904090536926,
			-0.0031394146217732367,
			-7.571300829706796e-6,
			3.0045718718618155e-6,
			2.1329796886727743e-6,
			9.259701025627789e-8,
			2.2291214755992877e-9,
			-2.3017207713252894e-9,
			6.807967994979358e-11,
			2.1677299575405587e-12,
			-3.07282771050034e-13,
			0.0,
		];

		let n = diag.len();
		// small recursion threshold so the divide-and-conquer path is taken
		let params = SvdParams {
			recursion_threshold: 8,
			qr_ratio_threshold: 1.0,
			..auto!(f64)
		};

		let mut d = ColRef::from_slice(&diag).to_owned();
		let mut s = ColRef::from_slice(&subdiag).to_owned();
		compute_bidiag_real_svd(
			d.as_mut().try_as_col_major_mut().unwrap(),
			s.as_mut().try_as_col_major_mut().unwrap(),
			None,
			None,
			params,
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(bidiag_real_svd_scratch::<f64>(n, false, false, Par::Seq, params))),
		)
		.unwrap();

		assert!(d[n - 1] != 0.0);
	}
979}