1use crate::internal_prelude::*;
2use crate::{assert, get_global_parallelism};
3use alloc::vec;
4use alloc::vec::Vec;
5use dyn_stack::MemBuffer;
6use faer_traits::{ComplexConj, math_utils};
7use linalg::svd::ComputeSvdVectors;
8
9pub use linalg::cholesky::ldlt::factor::LdltError;
10pub use linalg::cholesky::llt::factor::LltError;
11pub use linalg::evd::EvdError;
12pub use linalg::gevd::{GevdError, SelfAdjointGevdError};
13pub use linalg::svd::SvdError;
14
/// Core trait exposing the shape (row and column counts) of a matrix decomposition.
pub trait ShapeCore {
    /// Returns the number of rows of the decomposed matrix.
    fn nrows(&self) -> usize;
    /// Returns the number of columns of the decomposed matrix.
    fn ncols(&self) -> usize;
}
22
/// Low-level in-place solve interface for decompositions of square matrices.
pub trait SolveCore<T: ComplexField>: ShapeCore {
    /// Solves the system `op(A) X = rhs` in place, where `op` optionally conjugates
    /// the decomposed matrix depending on `conj`.
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
    /// Solves the transposed system `op(A)ᵀ X = rhs` in place, where `op` optionally
    /// conjugates the decomposed matrix depending on `conj`.
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
}
/// Low-level in-place least-squares solve interface.
pub trait SolveLstsqCore<T: ComplexField>: ShapeCore {
    /// Solves the least-squares problem for `op(A) X ≈ rhs` in place, where `op`
    /// optionally conjugates the decomposed matrix depending on `conj`.
    /// NOTE(review): the solution presumably occupies the top `ncols` rows of `rhs`
    /// (see `SolveLstsq::solve_lstsq`, which truncates) — confirm against implementors.
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
}
/// Dense decomposition interface: reconstruction of the original matrix and explicit inverse.
pub trait DenseSolveCore<T: ComplexField>: SolveCore<T> {
    /// Recomputes the original matrix from the stored decomposition factors.
    fn reconstruct(&self) -> Mat<T>;
    /// Computes the inverse of the original matrix from the stored decomposition factors.
    fn inverse(&self) -> Mat<T>;
}
46
// Forwarding impl: a shared reference to a decomposition is itself a decomposition.
impl<S: ?Sized + ShapeCore> ShapeCore for &S {
    #[inline]
    fn nrows(&self) -> usize {
        (**self).nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        (**self).ncols()
    }
}
58
// Forwarding impl: delegate solving through a shared reference.
impl<T: ComplexField, S: ?Sized + SolveCore<T>> SolveCore<T> for &S {
    #[inline]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        (**self).solve_in_place_with_conj(conj, rhs)
    }

    #[inline]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        (**self).solve_transpose_in_place_with_conj(conj, rhs)
    }
}
70
// Forwarding impl: delegate least-squares solving through a shared reference.
impl<T: ComplexField, S: ?Sized + SolveLstsqCore<T>> SolveLstsqCore<T> for &S {
    #[inline]
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        (**self).solve_lstsq_in_place_with_conj(conj, rhs)
    }
}
77
// Forwarding impl: delegate reconstruction/inversion through a shared reference.
impl<T: ComplexField, S: ?Sized + DenseSolveCore<T>> DenseSolveCore<T> for &S {
    #[inline]
    fn reconstruct(&self) -> Mat<T> {
        (**self).reconstruct()
    }

    #[inline]
    fn inverse(&self) -> Mat<T> {
        (**self).inverse()
    }
}
89
/// High-level solve interface, provided for all [`SolveCore`] implementors via a blanket impl.
///
/// The `solve*` family solves `A X = rhs` (and its conjugate/transpose/adjoint variants) with
/// the right-hand side on the right; the `rsolve*` family solves `X A = lhs` with the unknown
/// on the left, implemented by transposing the system.
pub trait Solve<T: ComplexField>: SolveCore<T> {
    /// Solves `A X = rhs` in place, overwriting `rhs` with the solution.
    #[track_caller]
    #[inline]
    fn solve_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
        self.solve_in_place_with_conj(Conj::No, { rhs }.as_mat_mut().as_dyn_cols_mut());
    }
    /// Solves `conj(A) X = rhs` in place, overwriting `rhs` with the solution.
    #[track_caller]
    #[inline]
    fn solve_conjugate_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
        self.solve_in_place_with_conj(Conj::Yes, { rhs }.as_mat_mut().as_dyn_cols_mut());
    }

    /// Solves `Aᵀ X = rhs` in place, overwriting `rhs` with the solution.
    #[track_caller]
    #[inline]
    fn solve_transpose_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
        self.solve_transpose_in_place_with_conj(Conj::No, { rhs }.as_mat_mut().as_dyn_cols_mut());
    }
    /// Solves `Aᴴ X = rhs` in place (conjugate transpose), overwriting `rhs` with the solution.
    #[track_caller]
    #[inline]
    fn solve_adjoint_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
        self.solve_transpose_in_place_with_conj(Conj::Yes, { rhs }.as_mat_mut().as_dyn_cols_mut());
    }

    /// Solves `X A = lhs` in place, overwriting `lhs` with the solution.
    // X A = lhs  ⇔  Aᵀ Xᵀ = lhsᵀ, hence the transpose solve on the transposed buffer.
    #[track_caller]
    #[inline]
    fn rsolve_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
        self.solve_transpose_in_place_with_conj(Conj::No, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
    }
    /// Solves `X conj(A) = lhs` in place, overwriting `lhs` with the solution.
    #[track_caller]
    #[inline]
    fn rsolve_conjugate_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
        self.solve_transpose_in_place_with_conj(Conj::Yes, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
    }

    /// Solves `X Aᵀ = lhs` in place, overwriting `lhs` with the solution.
    #[track_caller]
    #[inline]
    fn rsolve_transpose_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
        self.solve_in_place_with_conj(Conj::No, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
    }
    /// Solves `X Aᴴ = lhs` in place, overwriting `lhs` with the solution.
    #[track_caller]
    #[inline]
    fn rsolve_adjoint_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
        self.solve_in_place_with_conj(Conj::Yes, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
    }

    /// Solves `A X = rhs` and returns the solution as a new matrix.
    #[track_caller]
    #[inline]
    fn solve<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
        let rhs = rhs.as_mat_ref();
        let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
        out.as_mat_mut().copy_from(rhs);
        self.solve_in_place(&mut out);
        out
    }
    /// Solves `conj(A) X = rhs` and returns the solution as a new matrix.
    #[track_caller]
    #[inline]
    fn solve_conjugate<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
        let rhs = rhs.as_mat_ref();
        let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
        out.as_mat_mut().copy_from(rhs);
        self.solve_conjugate_in_place(&mut out);
        out
    }

    /// Solves `Aᵀ X = rhs` and returns the solution as a new matrix.
    #[track_caller]
    #[inline]
    fn solve_transpose<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
        let rhs = rhs.as_mat_ref();
        let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
        out.as_mat_mut().copy_from(rhs);
        self.solve_transpose_in_place(&mut out);
        out
    }
    /// Solves `Aᴴ X = rhs` and returns the solution as a new matrix.
    #[track_caller]
    #[inline]
    fn solve_adjoint<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
        let rhs = rhs.as_mat_ref();
        let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
        out.as_mat_mut().copy_from(rhs);
        self.solve_adjoint_in_place(&mut out);
        out
    }

    /// Solves `X A = lhs` and returns the solution as a new matrix.
    #[track_caller]
    #[inline]
    fn rsolve<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
        let lhs = lhs.as_mat_ref();
        let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
        out.as_mat_mut().copy_from(lhs);
        self.rsolve_in_place(&mut out);
        out
    }
    /// Solves `X conj(A) = lhs` and returns the solution as a new matrix.
    #[track_caller]
    #[inline]
    fn rsolve_conjugate<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
        let lhs = lhs.as_mat_ref();
        let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
        out.as_mat_mut().copy_from(lhs);
        self.rsolve_conjugate_in_place(&mut out);
        out
    }

    /// Solves `X Aᵀ = lhs` and returns the solution as a new matrix.
    #[track_caller]
    #[inline]
    fn rsolve_transpose<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
        let lhs = lhs.as_mat_ref();
        let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
        out.as_mat_mut().copy_from(lhs);
        self.rsolve_transpose_in_place(&mut out);
        out
    }
    /// Solves `X Aᴴ = lhs` and returns the solution as a new matrix.
    #[track_caller]
    #[inline]
    fn rsolve_adjoint<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
        let lhs = lhs.as_mat_ref();
        let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
        out.as_mat_mut().copy_from(lhs);
        self.rsolve_adjoint_in_place(&mut out);
        out
    }
}
228
// Convenience constructors: build the various decompositions directly from a matrix reference.
impl<C: Conjugate, Inner: for<'short> Reborrow<'short, Target = mat::Ref<'short, C>>> mat::generic::Mat<Inner> {
    /// Returns the LU decomposition of `self` with partial (row) pivoting.
    #[track_caller]
    pub fn partial_piv_lu(&self) -> PartialPivLu<C::Canonical> {
        PartialPivLu::new(self.rb())
    }

    /// Returns the LU decomposition of `self` with full (row and column) pivoting.
    #[track_caller]
    pub fn full_piv_lu(&self) -> FullPivLu<C::Canonical> {
        FullPivLu::new(self.rb())
    }

    /// Returns the QR decomposition of `self`.
    #[track_caller]
    pub fn qr(&self) -> Qr<C::Canonical> {
        Qr::new(self.rb())
    }

    /// Returns the QR decomposition of `self` with column pivoting.
    #[track_caller]
    pub fn col_piv_qr(&self) -> ColPivQr<C::Canonical> {
        ColPivQr::new(self.rb())
    }

    /// Returns the full singular value decomposition of `self`, or an error if the
    /// SVD iteration fails to converge.
    #[track_caller]
    pub fn svd(&self) -> Result<Svd<C::Canonical>, SvdError> {
        Svd::new(self.rb())
    }

    /// Returns the thin singular value decomposition of `self` (singular vectors
    /// truncated to `min(nrows, ncols)` columns), or an error on non-convergence.
    #[track_caller]
    pub fn thin_svd(&self) -> Result<Svd<C::Canonical>, SvdError> {
        Svd::new_thin(self.rb())
    }

    /// Returns the Cholesky (LLᴴ) decomposition of `self`, reading only the triangular
    /// half given by `side`. Fails if the matrix is not positive definite.
    #[track_caller]
    pub fn llt(&self, side: Side) -> Result<Llt<C::Canonical>, LltError> {
        Llt::new(self.rb(), side)
    }

    /// Returns the LDLᴴ decomposition of `self`, reading only the triangular half
    /// given by `side`.
    #[track_caller]
    pub fn ldlt(&self, side: Side) -> Result<Ldlt<C::Canonical>, LdltError> {
        Ldlt::new(self.rb(), side)
    }

    /// Returns the Bunch–Kaufman-style LBLᴴ decomposition of `self`, reading only the
    /// triangular half given by `side`.
    #[track_caller]
    pub fn lblt(&self, side: Side) -> Lblt<C::Canonical> {
        Lblt::new(self.rb(), side)
    }

    /// Returns the eigendecomposition of `self`, assumed self-adjoint; only the
    /// triangular half given by `side` is read.
    #[track_caller]
    pub fn self_adjoint_eigen(&self, side: Side) -> Result<SelfAdjointEigen<C::Canonical>, EvdError> {
        SelfAdjointEigen::new(self.rb(), side)
    }

    /// Returns the (real) eigenvalues of `self`, assumed self-adjoint; only the
    /// triangular half given by `side` is read.
    #[track_caller]
    pub fn self_adjoint_eigenvalues(&self, side: Side) -> Result<Vec<Real<C>>, EvdError> {
        // Inner function monomorphized over the canonical scalar type.
        #[track_caller]
        pub fn imp<T: ComplexField>(mut A: MatRef<'_, T>, side: Side) -> Result<Vec<T::Real>, EvdError> {
            assert!(A.nrows() == A.ncols());
            if side == Side::Upper {
                // Transposing maps the stored upper half onto the lower half read by the
                // backend; eigenvalues of a self-adjoint matrix are real, so the missing
                // conjugation does not affect the result.
                A = A.transpose();
            }
            let par = get_global_parallelism();
            let n = A.nrows();

            let mut s = Diag::<T>::zeros(n);

            // Eigenvalues only: no eigenvector output, matching scratch request below.
            linalg::evd::self_adjoint_evd(
                A,
                s.as_mut(),
                None,
                par,
                MemStack::new(&mut MemBuffer::new(linalg::evd::self_adjoint_evd_scratch::<T>(
                    n,
                    linalg::evd::ComputeEigenvectors::No,
                    par,
                    default(),
                ))),
                default(),
            )?;

            Ok(s.column_vector().iter().map(|x| real(x)).collect())
        }

        imp(self.rb().canonical(), side)
    }

    /// Returns the singular values of `self`, or an error if the SVD iteration fails
    /// to converge.
    #[track_caller]
    pub fn singular_values(&self) -> Result<Vec<Real<C>>, SvdError> {
        // Inner function monomorphized over the canonical scalar type.
        pub fn imp<T: ComplexField>(A: MatRef<'_, T>) -> Result<Vec<T::Real>, SvdError> {
            let par = get_global_parallelism();
            let m = A.nrows();
            let n = A.ncols();

            let mut s = Diag::<T>::zeros(Ord::min(m, n));

            // Singular values only: neither U nor V is requested.
            linalg::svd::svd(
                A,
                s.as_mut(),
                None,
                None,
                par,
                MemStack::new(&mut MemBuffer::new(linalg::svd::svd_scratch::<T>(
                    m,
                    n,
                    linalg::svd::ComputeSvdVectors::No,
                    linalg::svd::ComputeSvdVectors::No,
                    par,
                    default(),
                ))),
                default(),
            )?;

            Ok(s.column_vector().iter().map(|x| real(x)).collect())
        }

        imp(self.rb().canonical())
    }
}
367
impl<C: Conjugate> MatRef<'_, C> {
    /// Dispatches the eigendecomposition over the three possible scalar representations of
    /// `C`: real, canonical complex, or conjugated complex.
    // NOTE(review): the `coerce` calls reinterpret the matrix reference's scalar type;
    // soundness relies on the surrounding `const` branch guaranteeing the type equality —
    // confirm against `crate::hacks::coerce`.
    #[track_caller]
    fn eigen_imp(&self) -> Result<Eigen<Real<C>>, EvdError> {
        if const { C::Canonical::IS_REAL } {
            Eigen::new_from_real(unsafe { crate::hacks::coerce(*self) })
        } else if const { C::IS_CANONICAL } {
            Eigen::new(unsafe { crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(*self) })
        } else {
            Eigen::new(unsafe { crate::hacks::coerce::<_, MatRef<'_, ComplexConj<Real<C>>>>(*self) })
        }
    }

    /// Dispatches the generalized eigendecomposition of the pencil `(self, B)` over the three
    /// possible scalar representations of `C`, mirroring [`Self::eigen_imp`].
    #[track_caller]
    fn gen_eigen_imp(&self, B: MatRef<'_, C>) -> Result<GeneralizedEigen<Real<C>>, GevdError> {
        if const { C::Canonical::IS_REAL } {
            GeneralizedEigen::new_from_real(unsafe { crate::hacks::coerce(*self) }, unsafe { crate::hacks::coerce(B) })
        } else if const { C::IS_CANONICAL } {
            GeneralizedEigen::new(unsafe { crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(*self) }, unsafe {
                crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(B)
            })
        } else {
            GeneralizedEigen::new(unsafe { crate::hacks::coerce::<_, MatRef<'_, ComplexConj<Real<C>>>>(*self) }, unsafe {
                crate::hacks::coerce::<_, MatRef<'_, ComplexConj<Real<C>>>>(B)
            })
        }
    }

    /// Computes the eigenvalues of `self` as complex numbers, choosing the real or complex
    /// backend based on the canonical scalar type.
    #[track_caller]
    fn eigenvalues_imp(&self) -> Result<Vec<Complex<Real<C>>>, EvdError> {
        let par = get_global_parallelism();

        if const { C::Canonical::IS_REAL } {
            let A = unsafe { crate::hacks::coerce::<_, MatRef<'_, Real<C>>>(*self) };
            assert!(A.nrows() == A.ncols());
            let n = A.nrows();

            // The real backend returns eigenvalues split into real and imaginary parts.
            let mut s_re = Diag::<Real<C>>::zeros(n);
            let mut s_im = Diag::<Real<C>>::zeros(n);

            linalg::evd::evd_real(
                A,
                s_re.as_mut(),
                s_im.as_mut(),
                None,
                None,
                par,
                MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<Real<C>>(
                    n,
                    linalg::evd::ComputeEigenvectors::No,
                    linalg::evd::ComputeEigenvectors::No,
                    par,
                    default(),
                ))),
                default(),
            )?;

            Ok(s_re
                .column_vector()
                .iter()
                .zip(s_im.column_vector().iter())
                .map(|(re, im)| Complex::new(re.clone(), im.clone()))
                .collect())
        } else {
            let A = unsafe { crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(self.canonical()) };
            assert!(A.nrows() == A.ncols());
            let n = A.nrows();

            let mut s = Diag::<Complex<Real<C>>>::zeros(n);

            linalg::evd::evd_cplx(
                A,
                s.as_mut(),
                None,
                None,
                par,
                MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<Complex<Real<C>>>(
                    n,
                    linalg::evd::ComputeEigenvectors::No,
                    linalg::evd::ComputeEigenvectors::No,
                    par,
                    default(),
                ))),
                default(),
            )?;

            if const { C::IS_CANONICAL } {
                Ok(s.column_vector().iter().cloned().collect())
            } else {
                // `self` stores the conjugate of the canonical matrix, and the eigenvalues
                // of conj(A) are the conjugates of the eigenvalues of A.
                Ok(s.column_vector().iter().map(conj).collect())
            }
        }
    }
}
461
// Public eigendecomposition entry points, delegating to the `MatRef` helpers above.
impl<T: Conjugate, Inner: for<'short> Reborrow<'short, Target = mat::Ref<'short, T>>> mat::generic::Mat<Inner> {
    /// Returns the generalized eigendecomposition of the matrix pencil `(self, B)`.
    #[track_caller]
    pub fn generalized_eigen(&self, B: impl AsMatRef<T = T, Rows = usize, Cols = usize>) -> Result<GeneralizedEigen<Real<T>>, GevdError> {
        self.rb().gen_eigen_imp(B.as_mat_ref())
    }

    /// Returns the eigendecomposition of `self`.
    #[track_caller]
    pub fn eigen(&self) -> Result<Eigen<Real<T>>, EvdError> {
        self.rb().eigen_imp()
    }

    /// Returns the eigenvalues of `self` as complex numbers.
    #[track_caller]
    pub fn eigenvalues(&self) -> Result<Vec<Complex<Real<T>>>, EvdError> {
        self.rb().eigenvalues_imp()
    }
}
481
/// High-level least-squares solve interface, provided for all [`SolveLstsqCore`]
/// implementors via a blanket impl.
pub trait SolveLstsq<T: ComplexField>: SolveLstsqCore<T> {
    /// Solves the least-squares problem for `A X ≈ rhs` in place; the solution is written
    /// to the top rows of `rhs`.
    #[track_caller]
    #[inline]
    fn solve_lstsq_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
        self.solve_lstsq_in_place_with_conj(Conj::No, { rhs }.as_mat_mut().as_dyn_cols_mut());
    }

    /// Solves the least-squares problem for `conj(A) X ≈ rhs` in place; the solution is
    /// written to the top rows of `rhs`.
    #[track_caller]
    #[inline]
    fn solve_conjugate_lstsq_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
        self.solve_lstsq_in_place_with_conj(Conj::Yes, { rhs }.as_mat_mut().as_dyn_cols_mut());
    }

    /// Solves the least-squares problem for `A X ≈ rhs` and returns the solution,
    /// truncated to `self.ncols()` rows.
    #[track_caller]
    #[inline]
    fn solve_lstsq<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
        let rhs = rhs.as_mat_ref();
        let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
        out.as_mat_mut().copy_from(rhs);
        self.solve_lstsq_in_place(&mut out);
        // Keep only the solution rows; the in-place solve used the extra rows as workspace.
        out.truncate(self.ncols(), rhs.ncols());
        out
    }
    /// Solves the least-squares problem for `conj(A) X ≈ rhs` and returns the solution,
    /// truncated to `self.ncols()` rows.
    #[track_caller]
    #[inline]
    fn solve_conjugate_lstsq<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
        let rhs = rhs.as_mat_ref();
        let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
        out.as_mat_mut().copy_from(rhs);
        self.solve_conjugate_lstsq_in_place(&mut out);
        out.truncate(self.ncols(), rhs.ncols());
        out
    }
}
/// Marker extension trait for dense decompositions; see [`DenseSolveCore`].
pub trait DenseSolve<T: ComplexField>: DenseSolveCore<T> {}
523
// Blanket impls: every `*Core` implementor automatically gets the corresponding
// high-level extension trait.
impl<T: ComplexField, S: ?Sized + SolveCore<T>> Solve<T> for S {}
impl<T: ComplexField, S: ?Sized + SolveLstsqCore<T>> SolveLstsq<T> for S {}
impl<T: ComplexField, S: ?Sized + DenseSolveCore<T>> DenseSolve<T> for S {}
527
/// Cholesky (LLᴴ) factorization of a self-adjoint positive definite matrix.
#[derive(Clone, Debug)]
pub struct Llt<T> {
    // Lower triangular factor; strict upper triangle is zeroed.
    L: Mat<T>,
}
533
/// LDLᴴ factorization of a self-adjoint matrix.
#[derive(Clone, Debug)]
pub struct Ldlt<T> {
    // Unit lower triangular factor; diagonal filled with ones, strict upper zeroed.
    L: Mat<T>,
    // Diagonal factor D.
    D: Diag<T>,
}
540
/// Pivoted LBLᴴ factorization of a self-adjoint matrix, with block diagonal `B`
/// stored as a main diagonal plus a subdiagonal.
#[derive(Clone, Debug)]
pub struct Lblt<T> {
    // Unit lower triangular factor.
    L: Mat<T>,
    // Main diagonal of the block diagonal factor B.
    B_diag: Diag<T>,
    // Subdiagonal of the block diagonal factor B (nonzero entries mark 2×2 blocks).
    B_subdiag: Diag<T>,
    // Row/column permutation applied during pivoting.
    P: Perm<usize>,
}
549
/// LU factorization with partial (row) pivoting: `P A = L U`.
#[derive(Clone, Debug)]
pub struct PartialPivLu<T> {
    // Unit lower triangular factor.
    L: Mat<T>,
    // Upper triangular factor.
    U: Mat<T>,
    // Row permutation.
    P: Perm<usize>,
}
557
/// LU factorization with full (row and column) pivoting: `P A Q⁻¹ = L U`.
#[derive(Clone, Debug)]
pub struct FullPivLu<T> {
    // Unit lower triangular factor.
    L: Mat<T>,
    // Upper triangular factor.
    U: Mat<T>,
    // Row permutation.
    P: Perm<usize>,
    // Column permutation.
    Q: Perm<usize>,
}
566
/// QR factorization, with `Q` stored implicitly as a block Householder sequence.
#[derive(Clone, Debug)]
pub struct Qr<T> {
    // Householder reflector basis (unit lower trapezoidal).
    Q_basis: Mat<T>,
    // Block Householder coefficients (one block_size × size matrix).
    Q_coeff: Mat<T>,
    // Upper triangular/trapezoidal factor.
    R: Mat<T>,
}
574
/// QR factorization with column pivoting, with `Q` stored implicitly as a block
/// Householder sequence.
#[derive(Clone, Debug)]
pub struct ColPivQr<T> {
    // Householder reflector basis (unit lower trapezoidal).
    Q_basis: Mat<T>,
    // Block Householder coefficients.
    Q_coeff: Mat<T>,
    // Upper triangular/trapezoidal factor.
    R: Mat<T>,
    // Column permutation.
    P: Perm<usize>,
}
583
/// Singular value decomposition `A = U S Vᴴ` (full or thin, depending on construction).
#[derive(Clone, Debug)]
pub struct Svd<T> {
    // Left singular vectors.
    U: Mat<T>,
    // Right singular vectors.
    V: Mat<T>,
    // Singular values.
    S: Diag<T>,
}
591
/// Eigendecomposition `A = U S Uᴴ` of a self-adjoint matrix.
#[derive(Clone, Debug)]
pub struct SelfAdjointEigen<T> {
    // Eigenvectors, one per column.
    U: Mat<T>,
    // Eigenvalues.
    S: Diag<T>,
}
598
/// Eigendecomposition of a general square matrix; eigenvalues and eigenvectors are
/// complex even when the input is real.
#[derive(Clone, Debug)]
pub struct Eigen<T> {
    // Eigenvectors, one per column.
    U: Mat<Complex<T>>,
    // Eigenvalues.
    S: Diag<Complex<T>>,
}
605
/// Generalized eigendecomposition of a matrix pencil `(A, B)`; generalized eigenvalues
/// are the ratios `S_a[i] / S_b[i]`.
#[derive(Clone, Debug)]
pub struct GeneralizedEigen<T> {
    // Generalized eigenvectors, one per column.
    U: Mat<Complex<T>>,
    // Eigenvalue numerators.
    S_a: Diag<Complex<T>>,
    // Eigenvalue denominators.
    S_b: Diag<Complex<T>>,
}
613
impl<T: ComplexField> Llt<T> {
    /// Computes the Cholesky factorization of the self-adjoint matrix `A`, reading only
    /// the triangular half given by `side`.
    ///
    /// # Panics
    /// Panics if `A` is not square.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Result<Self, LltError> {
        assert!(all(A.nrows() == A.ncols()));
        let n = A.nrows();

        let mut L = Mat::zeros(n, n);
        match side {
            Side::Lower => L.copy_from_triangular_lower(A),
            // The adjoint maps the stored upper half onto the lower half read below.
            Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
        }

        Self::new_imp(L)
    }

    // Factorizes in place; `L` initially holds the lower triangle of the input.
    #[track_caller]
    fn new_imp(mut L: Mat<T>) -> Result<Self, LltError> {
        let par = get_global_parallelism();

        let n = L.nrows();

        let mut mem = MemBuffer::new(linalg::cholesky::llt::factor::cholesky_in_place_scratch::<T>(n, par, default()));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::llt::factor::cholesky_in_place(L.as_mut(), Default::default(), par, stack, default())?;
        // Clear the strict upper triangle so `L` is a clean lower triangular matrix.
        z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());

        Ok(Self { L })
    }

    /// Returns the lower triangular factor `L`.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }
}
650
impl<T: ComplexField> Ldlt<T> {
    /// Computes the LDLᴴ factorization of the self-adjoint matrix `A`, reading only the
    /// triangular half given by `side`.
    ///
    /// # Panics
    /// Panics if `A` is not square.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Result<Self, LdltError> {
        assert!(all(A.nrows() == A.ncols()));
        let n = A.nrows();

        let mut L = Mat::zeros(n, n);
        match side {
            Side::Lower => L.copy_from_triangular_lower(A),
            // The adjoint maps the stored upper half onto the lower half read below.
            Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
        }

        Self::new_imp(L)
    }

    // Factorizes in place; the backend leaves D on the diagonal of the packed factor.
    #[track_caller]
    fn new_imp(mut L: Mat<T>) -> Result<Self, LdltError> {
        let par = get_global_parallelism();

        let n = L.nrows();
        let mut D = Diag::zeros(n);

        let mut mem = MemBuffer::new(linalg::cholesky::ldlt::factor::cholesky_in_place_scratch::<T>(n, par, default()));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::ldlt::factor::cholesky_in_place(L.as_mut(), Default::default(), par, stack, default())?;

        // Extract D from the packed factor, then normalize L to unit diagonal and
        // clear its strict upper triangle.
        D.copy_from(L.diagonal());
        L.diagonal_mut().fill(one());
        z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());

        Ok(Self { L, D })
    }

    /// Returns the unit lower triangular factor `L`.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }

    /// Returns the diagonal factor `D`.
    pub fn D(&self) -> DiagRef<'_, T> {
        self.D.as_ref()
    }
}
696
impl<T: ComplexField> Lblt<T> {
    /// Computes the pivoted LBLᴴ factorization of the self-adjoint matrix `A`, reading
    /// only the triangular half given by `side`.
    ///
    /// # Panics
    /// Panics if `A` is not square.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Self {
        assert!(all(A.nrows() == A.ncols()));
        let n = A.nrows();

        let mut L = Mat::zeros(n, n);
        match side {
            Side::Lower => L.copy_from_triangular_lower(A),
            // The adjoint maps the stored upper half onto the lower half read below.
            Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
        }
        Self::new_imp(L)
    }

    // Factorizes in place; the backend reports the block diagonal via the matrix
    // diagonal and the separate subdiagonal output.
    #[track_caller]
    fn new_imp(mut L: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let n = L.nrows();

        let mut diag = Diag::zeros(n);
        let mut subdiag = Diag::zeros(n);
        let mut perm_fwd = vec![0usize; n];
        let mut perm_bwd = vec![0usize; n];

        let mut mem = MemBuffer::new(linalg::cholesky::lblt::factor::cholesky_in_place_scratch::<usize, T>(n, par, default()));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::lblt::factor::cholesky_in_place(L.as_mut(), subdiag.as_mut(), &mut perm_fwd, &mut perm_bwd, par, stack, default());

        // Extract B's main diagonal, then normalize L to unit diagonal and clear its
        // strict upper triangle.
        diag.copy_from(L.diagonal());
        L.diagonal_mut().fill(one());
        z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());

        Self {
            L,
            B_diag: diag,
            B_subdiag: subdiag,
            // SAFETY: NOTE(review) — relies on the factorization filling `perm_fwd`/`perm_bwd`
            // with mutually inverse permutations; confirm against the `lblt::factor` contract.
            P: unsafe { Perm::new_unchecked(perm_fwd.into_boxed_slice(), perm_bwd.into_boxed_slice()) },
        }
    }

    /// Returns the unit lower triangular factor `L`.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }

    /// Returns the main diagonal of the block diagonal factor `B`.
    pub fn B_diag(&self) -> DiagRef<'_, T> {
        self.B_diag.as_ref()
    }

    /// Returns the subdiagonal of the block diagonal factor `B`.
    pub fn B_subdiag(&self) -> DiagRef<'_, T> {
        self.B_subdiag.as_ref()
    }

    /// Returns the pivoting permutation `P`.
    pub fn P(&self) -> PermRef<'_, usize> {
        self.P.as_ref()
    }
}
760
761fn split_LU<T: ComplexField>(LU: Mat<T>) -> (Mat<T>, Mat<T>) {
762 let (m, n) = LU.shape();
763 let size = Ord::min(m, n);
764
765 let (L, U) = if m >= n {
766 let mut L = LU;
767 let mut U = Mat::zeros(size, size);
768
769 U.copy_from_triangular_upper(L.get(..size, ..size));
770
771 z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());
772 L.diagonal_mut().fill(one());
773
774 (L, U)
775 } else {
776 let mut U = LU;
777 let mut L = Mat::zeros(size, size);
778
779 L.copy_from_strict_triangular_lower(U.get(..size, ..size));
780
781 z!(&mut U).for_each_triangular_lower(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());
782 L.diagonal_mut().fill(one());
783
784 (L, U)
785 };
786 (L, U)
787}
788
impl<T: ComplexField> PartialPivLu<T> {
    /// Computes the LU factorization of `A` with partial (row) pivoting.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
        let LU = A.to_owned();
        Self::new_imp(LU)
    }

    // Factorizes the packed matrix in place, then unpacks it into L and U.
    #[track_caller]
    fn new_imp(mut LU: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let (m, n) = LU.shape();
        let mut row_perm_fwd = vec![0usize; m];
        let mut row_perm_bwd = vec![0usize; m];

        linalg::lu::partial_pivoting::factor::lu_in_place(
            LU.as_mut(),
            &mut row_perm_fwd,
            &mut row_perm_bwd,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, T>(m, n, par, default()),
            )),
            default(),
        );

        let (L, U) = split_LU(LU);

        Self {
            L,
            U,
            // SAFETY: NOTE(review) — relies on the factorization filling the forward/backward
            // arrays with mutually inverse permutations; confirm against `lu_in_place`'s contract.
            P: unsafe { Perm::new_unchecked(row_perm_fwd.into_boxed_slice(), row_perm_bwd.into_boxed_slice()) },
        }
    }

    /// Returns the unit lower triangular factor `L`.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }

    /// Returns the upper triangular factor `U`.
    pub fn U(&self) -> MatRef<'_, T> {
        self.U.as_ref()
    }

    /// Returns the row permutation `P`.
    pub fn P(&self) -> PermRef<'_, usize> {
        self.P.as_ref()
    }
}
840
impl<T: ComplexField> FullPivLu<T> {
    /// Computes the LU factorization of `A` with full (row and column) pivoting.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
        let LU = A.to_owned();
        Self::new_imp(LU)
    }

    // Factorizes the packed matrix in place, then unpacks it into L and U.
    #[track_caller]
    fn new_imp(mut LU: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let (m, n) = LU.shape();
        let mut row_perm_fwd = vec![0usize; m];
        let mut row_perm_bwd = vec![0usize; m];
        let mut col_perm_fwd = vec![0usize; n];
        let mut col_perm_bwd = vec![0usize; n];

        linalg::lu::full_pivoting::factor::lu_in_place(
            LU.as_mut(),
            &mut row_perm_fwd,
            &mut row_perm_bwd,
            &mut col_perm_fwd,
            &mut col_perm_bwd,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::factor::lu_in_place_scratch::<usize, T>(
                m,
                n,
                par,
                default(),
            ))),
            default(),
        );

        let (L, U) = split_LU(LU);

        Self {
            L,
            U,
            // SAFETY: NOTE(review) — relies on the factorization filling the forward/backward
            // arrays with mutually inverse permutations; confirm against `lu_in_place`'s contract.
            P: unsafe { Perm::new_unchecked(row_perm_fwd.into_boxed_slice(), row_perm_bwd.into_boxed_slice()) },
            Q: unsafe { Perm::new_unchecked(col_perm_fwd.into_boxed_slice(), col_perm_bwd.into_boxed_slice()) },
        }
    }

    /// Returns the unit lower triangular factor `L`.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }

    /// Returns the upper triangular factor `U`.
    pub fn U(&self) -> MatRef<'_, T> {
        self.U.as_ref()
    }

    /// Returns the row permutation `P`.
    pub fn P(&self) -> PermRef<'_, usize> {
        self.P.as_ref()
    }

    /// Returns the column permutation `Q`.
    pub fn Q(&self) -> PermRef<'_, usize> {
        self.Q.as_ref()
    }
}
905
impl<T: ComplexField> Qr<T> {
    /// Computes the QR factorization of `A`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
        let QR = A.to_owned();
        Self::new_imp(QR)
    }

    // Factorizes the packed matrix in place; the Householder basis and R share the
    // storage and are separated afterwards.
    #[track_caller]
    fn new_imp(mut QR: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let (m, n) = QR.shape();
        let size = Ord::min(m, n);

        let block_size = linalg::qr::no_pivoting::factor::recommended_block_size::<T>(m, n);
        let mut Q_coeff = Mat::zeros(block_size, size);

        linalg::qr::no_pivoting::factor::qr_in_place(
            QR.as_mut(),
            Q_coeff.as_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::factor::qr_in_place_scratch::<T>(
                m,
                n,
                block_size,
                par,
                default(),
            ))),
            default(),
        );

        // The packed QR has the same unit-lower / upper layout as a packed LU.
        let (Q_basis, R) = split_LU(QR);

        Self { Q_basis, Q_coeff, R }
    }

    /// Returns the Householder reflector basis of `Q`.
    pub fn Q_basis(&self) -> MatRef<'_, T> {
        self.Q_basis.as_ref()
    }

    /// Returns the block Householder coefficients of `Q`.
    pub fn Q_coeff(&self) -> MatRef<'_, T> {
        self.Q_coeff.as_ref()
    }

    /// Returns the upper triangular/trapezoidal factor `R`.
    pub fn R(&self) -> MatRef<'_, T> {
        self.R.as_ref()
    }

    /// Returns the top `min(nrows, ncols)` rows of `R` (the thin factor).
    pub fn thin_R(&self) -> MatRef<'_, T> {
        let size = Ord::min(self.nrows(), self.ncols());
        self.R.get(..size, ..)
    }

    /// Materializes the full orthogonal/unitary factor `Q` by applying the stored
    /// Householder sequence to the identity.
    pub fn compute_Q(&self) -> Mat<T> {
        let mut Q = Mat::identity(self.nrows(), self.nrows());
        let par = get_global_parallelism();
        linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            Conj::No,
            Q.rb_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
                    self.nrows(),
                    self.Q_coeff.nrows(),
                    self.nrows(),
                ),
            )),
        );
        Q
    }

    /// Materializes the thin factor `Q` (first `min(nrows, ncols)` columns).
    pub fn compute_thin_Q(&self) -> Mat<T> {
        let size = Ord::min(self.nrows(), self.ncols());
        let mut Q = Mat::identity(self.nrows(), size);
        let par = get_global_parallelism();
        linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            Conj::No,
            Q.rb_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(self.nrows(), self.Q_coeff.nrows(), size),
            )),
        );
        Q
    }
}
1003
impl<T: ComplexField> ColPivQr<T> {
    /// Computes the QR factorization of `A` with column pivoting.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
        let QR = A.to_owned();
        Self::new_imp(QR)
    }

    // Factorizes the packed matrix in place; the Householder basis and R share the
    // storage and are separated afterwards.
    #[track_caller]
    fn new_imp(mut QR: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let (m, n) = QR.shape();
        let size = Ord::min(m, n);

        let mut col_perm_fwd = vec![0usize; n];
        let mut col_perm_bwd = vec![0usize; n];

        let block_size = linalg::qr::no_pivoting::factor::recommended_block_size::<T>(m, n);
        let mut Q_coeff = Mat::zeros(block_size, size);

        linalg::qr::col_pivoting::factor::qr_in_place(
            QR.as_mut(),
            Q_coeff.as_mut(),
            &mut col_perm_fwd,
            &mut col_perm_bwd,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::factor::qr_in_place_scratch::<usize, T>(
                m,
                n,
                block_size,
                par,
                default(),
            ))),
            default(),
        );

        // The packed QR has the same unit-lower / upper layout as a packed LU.
        let (Q_basis, R) = split_LU(QR);

        Self {
            Q_basis,
            Q_coeff,
            R,
            // SAFETY: NOTE(review) — relies on the factorization filling the forward/backward
            // arrays with mutually inverse permutations; confirm against `qr_in_place`'s contract.
            P: unsafe { Perm::new_unchecked(col_perm_fwd.into_boxed_slice(), col_perm_bwd.into_boxed_slice()) },
        }
    }

    /// Returns the Householder reflector basis of `Q`.
    pub fn Q_basis(&self) -> MatRef<'_, T> {
        self.Q_basis.as_ref()
    }

    /// Returns the block Householder coefficients of `Q`.
    pub fn Q_coeff(&self) -> MatRef<'_, T> {
        self.Q_coeff.as_ref()
    }

    /// Returns the upper triangular/trapezoidal factor `R`.
    pub fn R(&self) -> MatRef<'_, T> {
        self.R.as_ref()
    }

    /// Returns the top `min(nrows, ncols)` rows of `R` (the thin factor).
    pub fn thin_R(&self) -> MatRef<'_, T> {
        let size = Ord::min(self.nrows(), self.ncols());
        self.R.get(..size, ..)
    }

    /// Materializes the full orthogonal/unitary factor `Q` by applying the stored
    /// Householder sequence to the identity.
    pub fn compute_Q(&self) -> Mat<T> {
        let mut Q = Mat::identity(self.nrows(), self.nrows());
        let par = get_global_parallelism();
        linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            Conj::No,
            Q.rb_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
                    self.nrows(),
                    self.Q_coeff.nrows(),
                    self.nrows(),
                ),
            )),
        );
        Q
    }

    /// Materializes the thin factor `Q` (first `min(nrows, ncols)` columns).
    pub fn compute_thin_Q(&self) -> Mat<T> {
        let size = Ord::min(self.nrows(), self.ncols());
        let mut Q = Mat::identity(self.nrows(), size);
        let par = get_global_parallelism();
        linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            Conj::No,
            Q.rb_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(self.nrows(), self.Q_coeff.nrows(), size),
            )),
        );
        Q
    }

    /// Returns the column permutation `P`.
    pub fn P(&self) -> PermRef<'_, usize> {
        self.P.as_ref()
    }
}
1116
impl<T: ComplexField> Svd<T> {
    /// Computes the full singular value decomposition of `A`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Result<Self, SvdError> {
        Self::new_imp(A.canonical(), Conj::get::<C>(), false)
    }

    /// Computes the thin singular value decomposition of `A` (singular vectors
    /// truncated to `min(nrows, ncols)` columns).
    #[track_caller]
    pub fn new_thin<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Result<Self, SvdError> {
        Self::new_imp(A.canonical(), Conj::get::<C>(), true)
    }

    // Shared implementation. `conj` records whether the caller's matrix is the
    // conjugate of the canonical one passed here.
    #[track_caller]
    fn new_imp(A: MatRef<'_, T>, conj: Conj, thin: bool) -> Result<Self, SvdError> {
        let par = get_global_parallelism();

        let (m, n) = A.shape();
        let size = Ord::min(m, n);

        let mut U = Mat::zeros(m, if thin { size } else { m });
        let mut V = Mat::zeros(n, if thin { size } else { n });
        let mut S = Diag::zeros(size);

        let compute = if thin { ComputeSvdVectors::Thin } else { ComputeSvdVectors::Full };

        linalg::svd::svd(
            A,
            S.as_mut(),
            Some(U.as_mut()),
            Some(V.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::svd::svd_scratch::<T>(m, n, compute, compute, par, default()))),
            default(),
        )?;

        // If the input was conjugated relative to the canonical matrix, conjugate the
        // singular vectors back: conj(A) = conj(U) S conj(V)ᴴ'; singular values stay real.
        if conj == Conj::Yes {
            for c in U.col_iter_mut() {
                for x in c.iter_mut() {
                    *x = math_utils::conj(x);
                }
            }
            for c in V.col_iter_mut() {
                for x in c.iter_mut() {
                    *x = math_utils::conj(x);
                }
            }
        }

        Ok(Self { U, V, S })
    }

    /// Returns the left singular vectors `U`.
    pub fn U(&self) -> MatRef<'_, T> {
        self.U.as_ref()
    }

    /// Returns the right singular vectors `V`.
    pub fn V(&self) -> MatRef<'_, T> {
        self.V.as_ref()
    }

    /// Returns the singular values `S`.
    pub fn S(&self) -> DiagRef<'_, T> {
        self.S.as_ref()
    }

    /// Computes the Moore–Penrose pseudoinverse from the stored SVD.
    pub fn pseudoinverse(&self) -> Mat<T> {
        let U = self.U();
        let V = self.V();
        let S = self.S();
        let par = get_global_parallelism();
        let stack = &mut MemBuffer::new(linalg::svd::pseudoinverse_from_svd_scratch::<T>(self.nrows(), self.ncols(), par));
        let mut pinv = Mat::zeros(self.ncols(), self.nrows());
        linalg::svd::pseudoinverse_from_svd(pinv.rb_mut(), S, U, V, par, MemStack::new(stack));
        pinv
    }
}
1196
impl<T: ComplexField> SelfAdjointEigen<T> {
    /// Computes the eigendecomposition of the self-adjoint matrix `A`, reading only the
    /// triangular half given by `side`.
    ///
    /// # Panics
    /// Panics if `A` is not square.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Result<Self, EvdError> {
        assert!(A.nrows() == A.ncols());

        match side {
            Side::Lower => Self::new_imp(A.canonical(), Conj::get::<C>()),
            // Taking the adjoint maps the stored upper half onto the lower half read by
            // the backend, and flips the conjugation state accordingly.
            Side::Upper => Self::new_imp(A.adjoint().canonical(), Conj::get::<C::Conj>()),
        }
    }

    // Shared implementation. `conj` records whether the caller's matrix is the
    // conjugate of the canonical one passed here.
    #[track_caller]
    fn new_imp(A: MatRef<'_, T>, conj: Conj) -> Result<Self, EvdError> {
        let par = get_global_parallelism();

        let n = A.nrows();

        let mut U = Mat::zeros(n, n);
        let mut S = Diag::zeros(n);

        linalg::evd::self_adjoint_evd(
            A,
            S.as_mut(),
            Some(U.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::evd::self_adjoint_evd_scratch::<T>(
                n,
                linalg::evd::ComputeEigenvectors::Yes,
                par,
                default(),
            ))),
            default(),
        )?;

        // Conjugating a self-adjoint matrix conjugates its eigenvectors while the
        // (real) eigenvalues are unchanged.
        if conj == Conj::Yes {
            for c in U.col_iter_mut() {
                for x in c.iter_mut() {
                    *x = math_utils::conj(x);
                }
            }
        }

        Ok(Self { U, S })
    }

    /// Returns the eigenvectors, one per column.
    pub fn U(&self) -> MatRef<'_, T> {
        self.U.as_ref()
    }

    /// Returns the eigenvalues.
    pub fn S(&self) -> DiagRef<'_, T> {
        self.S.as_ref()
    }

    /// Computes the Moore–Penrose pseudoinverse from the stored eigendecomposition.
    pub fn pseudoinverse(&self) -> Mat<T> {
        let U = self.U();
        let S = self.S();
        let par = get_global_parallelism();
        let stack = &mut MemBuffer::new(linalg::evd::pseudoinverse_from_self_adjoint_evd_scratch::<T>(self.nrows(), par));
        let mut pinv = Mat::zeros(self.ncols(), self.nrows());
        linalg::evd::pseudoinverse_from_self_adjoint_evd(pinv.rb_mut(), S, U, par, MemStack::new(stack));
        pinv
    }
}
1264
/// Unpacks the output of the real eigendecomposition backend into complex form.
///
/// The real backend stores a complex-conjugate eigenvalue pair as two consecutive
/// columns holding the real and imaginary parts of one eigenvector, with the pair's
/// eigenvalue split across `S_re`/`S_im`; a real eigenvalue occupies a single column.
// NOTE(review): the `j + 1` accesses assume a nonzero `S_im[j]` always starts a
// two-column pair (so the last column can't begin a pair) — this is the packing
// convention of `evd_real`; confirm against its documentation.
fn real_to_cplx<T: RealField>(
    mut U: MatMut<'_, Complex<T>>,
    mut S: DiagMut<'_, Complex<T>>,
    U_real: MatRef<'_, T>,
    S_re: DiagRef<'_, T>,
    S_im: DiagRef<'_, T>,
) {
    // Square matrices: `n` is used for both the column walk and the row copies.
    let n = U.ncols();

    let mut j = 0;
    while j < n {
        if S_im[j] == zero() {
            // Real eigenvalue: promote the column with a zero imaginary part.
            S[j] = Complex::new(S_re[j].clone(), zero());

            for i in 0..n {
                U[(i, j)] = Complex::new(U_real[(i, j)].clone(), zero());
            }

            j += 1;
        } else {
            // Conjugate pair: columns j and j+1 hold Re(v) and Im(v); emit v and conj(v)
            // with eigenvalues λ and conj(λ).
            S[j] = Complex::new(S_re[j].clone(), S_im[j].clone());
            S[j + 1] = Complex::new(S_re[j].clone(), neg(&S_im[j]));

            for i in 0..n {
                U[(i, j)] = Complex::new(U_real[(i, j)].clone(), U_real[(i, j + 1)].clone());
                U[(i, j + 1)] = Complex::new(U_real[(i, j)].clone(), neg(&U_real[(i, j + 1)]));
            }

            j += 2;
        }
    }
}
1297
impl<T: RealField> Eigen<T> {
    /// Computes the eigendecomposition of the (generally non-symmetric) complex matrix `A`.
    ///
    /// # Panics
    /// Panics if `A` is not square.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = Complex<T>>>(A: MatRef<'_, C>) -> Result<Self, EvdError> {
        assert!(A.nrows() == A.ncols());
        Self::new_imp(A.canonical(), Conj::get::<C>())
    }

    /// Computes the eigendecomposition of a real matrix, using real arithmetic
    /// internally and converting the packed output to complex form.
    ///
    /// # Panics
    /// Panics if `A` is not square.
    #[track_caller]
    pub fn new_from_real(A: MatRef<'_, T>) -> Result<Self, EvdError> {
        assert!(A.nrows() == A.ncols());

        let par = get_global_parallelism();

        let n = A.nrows();

        // real storage: eigenvalues split into re/im parts, eigenvectors packed
        // (conjugate pairs occupy two adjacent columns)
        let mut U_real = Mat::zeros(n, n);
        let mut S_re = Diag::zeros(n);
        let mut S_im = Diag::zeros(n);

        linalg::evd::evd_real(
            A,
            S_re.as_mut(),
            S_im.as_mut(),
            None,
            Some(U_real.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<T>(
                n,
                // left eigenvectors not requested (matches the `None` argument above)
                linalg::evd::ComputeEigenvectors::No,
                linalg::evd::ComputeEigenvectors::Yes,
                par,
                default(),
            ))),
            default(),
        )?;

        let mut U = Mat::zeros(n, n);
        let mut S = Diag::zeros(n);

        real_to_cplx(U.as_mut(), S.as_mut(), U_real.as_ref(), S_re.as_ref(), S_im.as_ref());

        Ok(Self { U, S })
    }

    /// Complex implementation; `conj` records whether the caller's matrix was
    /// implicitly conjugated relative to `A`.
    fn new_imp(A: MatRef<'_, Complex<T>>, conj: Conj) -> Result<Self, EvdError> {
        let par = get_global_parallelism();

        let n = A.nrows();

        let mut U = Mat::zeros(n, n);
        let mut S = Diag::zeros(n);

        linalg::evd::evd_cplx(
            A,
            S.as_mut(),
            None,
            Some(U.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<Complex<T>>(
                n,
                linalg::evd::ComputeEigenvectors::No,
                linalg::evd::ComputeEigenvectors::Yes,
                par,
                default(),
            ))),
            default(),
        )?;

        // eig(conj(A)) = conj(eig(A)): conjugate both eigenvalues and eigenvectors
        if conj == Conj::Yes {
            zip!(&mut U).for_each(|unzip!(c)| *c = math_utils::conj(c));
            zip!(&mut S).for_each(|unzip!(c)| *c = math_utils::conj(c));
        }

        Ok(Self { U, S })
    }

    /// Returns the eigenvector matrix.
    pub fn U(&self) -> MatRef<'_, Complex<T>> {
        self.U.as_ref()
    }

    /// Returns the eigenvalues as a diagonal matrix.
    pub fn S(&self) -> DiagRef<'_, Complex<T>> {
        self.S.as_ref()
    }
}
1386
impl<T: RealField> GeneralizedEigen<T> {
    /// Computes the generalized eigendecomposition of the pencil `(A, B)`.
    ///
    /// # Panics
    /// Panics if `A` and `B` are not square with matching dimensions.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = Complex<T>>>(A: MatRef<'_, C>, B: MatRef<'_, C>) -> Result<Self, GevdError> {
        let n = A.nrows();
        assert!(all(A.nrows() == n, A.ncols() == n, B.nrows() == n, B.ncols() == n));
        Self::new_imp(A.canonical(), B.canonical(), Conj::get::<C>())
    }

    /// Computes the generalized eigendecomposition of a real pencil, using real
    /// arithmetic internally and converting the packed output to complex form.
    ///
    /// # Panics
    /// Panics if `A` and `B` are not square with matching dimensions.
    #[track_caller]
    pub fn new_from_real(A: MatRef<'_, T>, B: MatRef<'_, T>) -> Result<Self, GevdError> {
        let n = A.nrows();
        assert!(all(A.nrows() == n, A.ncols() == n, B.nrows() == n, B.ncols() == n));

        let par = get_global_parallelism();

        let mut U_real = Mat::zeros(n, n);
        let mut S_re = Diag::zeros(n);
        let mut S_im = Diag::zeros(n);
        let mut S_b = Diag::zeros(n);
        // the backend factorizes in place, so work on owned copies
        let A = &mut A.cloned();
        let B = &mut B.cloned();

        linalg::gevd::gevd_real(
            A.as_mut(),
            B.as_mut(),
            S_re.as_mut(),
            S_im.as_mut(),
            S_b.as_mut(),
            None,
            Some(U_real.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::gevd::gevd_scratch::<T>(
                n,
                linalg::evd::ComputeEigenvectors::No,
                linalg::evd::ComputeEigenvectors::Yes,
                par,
                default(),
            ))),
            default(),
        )?;

        let mut U = Mat::zeros(n, n);
        let mut S_a = Diag::zeros(n);
        // the `B`-side eigenvalue factors are real; lift them to complex directly
        // (shadows the real `S_b`)
        let S_b = zip!(&S_b).map(|unzip!(x)| Complex::new(x.clone(), zero()));

        real_to_cplx(U.as_mut(), S_a.as_mut(), U_real.as_ref(), S_re.as_ref(), S_im.as_ref());

        Ok(Self { U, S_a, S_b })
    }

    /// Complex implementation; `conj` records whether the caller's matrices were
    /// implicitly conjugated.
    fn new_imp(A: MatRef<'_, Complex<T>>, B: MatRef<'_, Complex<T>>, conj: Conj) -> Result<Self, GevdError> {
        let par = get_global_parallelism();

        let n = A.nrows();

        let mut U = Mat::zeros(n, n);
        let mut S_a = Diag::zeros(n);
        let mut S_b = Diag::zeros(n);
        // the backend factorizes in place, so work on owned copies
        let A = &mut A.cloned();
        let B = &mut B.cloned();

        linalg::gevd::gevd_cplx(
            A.as_mut(),
            B.as_mut(),
            S_a.as_mut(),
            S_b.as_mut(),
            None,
            Some(U.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::gevd::gevd_scratch::<Complex<T>>(
                n,
                linalg::evd::ComputeEigenvectors::No,
                linalg::evd::ComputeEigenvectors::Yes,
                par,
                default(),
            ))),
            default(),
        )?;

        // conjugating the input pencil conjugates eigenvectors and both eigenvalue factors
        if conj == Conj::Yes {
            zip!(&mut U).for_each(|unzip!(c)| *c = math_utils::conj(c));
            zip!(&mut S_a).for_each(|unzip!(c)| *c = math_utils::conj(c));
            zip!(&mut S_b).for_each(|unzip!(c)| *c = math_utils::conj(c));
        }

        Ok(Self { U, S_a, S_b })
    }

    /// Returns the generalized eigenvector matrix.
    pub fn U(&self) -> MatRef<'_, Complex<T>> {
        self.U.as_ref()
    }

    /// Returns the `A`-side eigenvalue factors (eigenvalues are `S_a[i] / S_b[i]`).
    pub fn S_a(&self) -> DiagRef<'_, Complex<T>> {
        self.S_a.as_ref()
    }

    /// Returns the `B`-side eigenvalue factors.
    pub fn S_b(&self) -> DiagRef<'_, Complex<T>> {
        self.S_b.as_ref()
    }
}
1492
// the shape of the decomposed matrix is that of its (square) Cholesky factor
impl<T: ComplexField> ShapeCore for Llt<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.L().ncols()
    }
}
// the shape of the decomposed matrix is that of its (square) `L` factor
impl<T: ComplexField> ShapeCore for Ldlt<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.L().ncols()
    }
}
// the shape of the decomposed matrix is that of its (square) `L` factor
impl<T: ComplexField> ShapeCore for Lblt<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.L().ncols()
    }
}
// for `A = P L U`: row count comes from `L`, column count from `U`
impl<T: ComplexField> ShapeCore for PartialPivLu<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.U().ncols()
    }
}
// for `A = P L U Q`: row count comes from `L`, column count from `U`
impl<T: ComplexField> ShapeCore for FullPivLu<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.U().ncols()
    }
}
// for `A = Q R`: row count comes from the `Q` basis, column count from `R`
impl<T: ComplexField> ShapeCore for Qr<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.Q_basis().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.R().ncols()
    }
}
// for `A P^-1 = Q R`: row count comes from the `Q` basis, column count from `R`
impl<T: ComplexField> ShapeCore for ColPivQr<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.Q_basis().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.R().ncols()
    }
}
// for `A = U S V^H`: `U` has `nrows(A)` rows and `V` has `ncols(A)` rows
impl<T: ComplexField> ShapeCore for Svd<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.U().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.V().nrows()
    }
}
// the decomposed matrix is square, so both dimensions come from `U`'s row count
impl<T: ComplexField> ShapeCore for SelfAdjointEigen<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.U().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.U().nrows()
    }
}
// the decomposed matrix is square, so both dimensions come from `U`'s row count
impl<T: RealField> ShapeCore for Eigen<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.U().nrows()
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.U().nrows()
    }
}
1603
impl<T: ComplexField> SolveCore<T> for Llt<T> {
    /// Solves `A x = rhs` in place, where `A = L L^H` (optionally conjugated).
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        let mut mem = MemBuffer::new(linalg::cholesky::llt::solve::solve_in_place_scratch::<T>(
            self.L.nrows(),
            rhs.ncols(),
            par,
        ));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::llt::solve::solve_in_place_with_conj(self.L.as_ref(), conj, rhs, par, stack);
    }

    /// Solves `A^T x = rhs` in place. Since `A` is self-adjoint, `A^T = conj(A)`,
    /// so this is the ordinary solve with the conjugation flag flipped.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        let mut mem = MemBuffer::new(linalg::cholesky::llt::solve::solve_in_place_scratch::<T>(
            self.L.nrows(),
            rhs.ncols(),
            par,
        ));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::llt::solve::solve_in_place_with_conj(self.L.as_ref(), conj.compose(Conj::Yes), rhs, par, stack);
    }
}
1633
// NOTE: this body is expanded by the `#[math]` proc macro, which resolves
// `real`/`from_real`/`conj` to the scalar math helpers for `T`
#[math]
fn make_self_adjoint<T: ComplexField>(mut A: MatMut<'_, T>) {
    // makes a square matrix exactly self-adjoint, treating the lower triangle as
    // the source of truth: the diagonal is forced real and the strict upper
    // triangle is overwritten with the conjugated mirror of the lower triangle
    assert!(A.nrows() == A.ncols());
    let n = A.nrows();
    for j in 0..n {
        A[(j, j)] = from_real(real(A[(j, j)]));
        for i in 0..j {
            // i < j: upper-triangle entry gets conj of its lower-triangle mirror
            A[(i, j)] = conj(A[(j, i)]);
        }
    }
}
1645
impl<T: ComplexField> DenseSolveCore<T> for Llt<T> {
    /// Recomputes `A = L L^H` as a dense matrix.
    #[track_caller]
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();

        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);

        let mut mem = MemBuffer::new(linalg::cholesky::llt::reconstruct::reconstruct_scratch::<T>(n, par));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::llt::reconstruct::reconstruct(out.as_mut(), self.L(), par, stack);

        // the backend fills only one triangle; mirror it to get the full matrix
        make_self_adjoint(out.as_mut());
        out
    }

    /// Computes `A^-1` as a dense matrix.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();

        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);

        let mut mem = MemBuffer::new(linalg::cholesky::llt::inverse::inverse_scratch::<T>(n, par));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::llt::inverse::inverse(out.as_mut(), self.L(), par, stack);

        // the backend fills only one triangle; mirror it to get the full matrix
        make_self_adjoint(out.as_mut());
        out
    }
}
1679
1680impl<T: ComplexField> SolveCore<T> for Ldlt<T> {
1681 #[track_caller]
1682 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1683 let par = get_global_parallelism();
1684
1685 let mut mem = MemBuffer::new(linalg::cholesky::ldlt::solve::solve_in_place_scratch::<T>(
1686 self.L.nrows(),
1687 rhs.ncols(),
1688 par,
1689 ));
1690 let stack = MemStack::new(&mut mem);
1691
1692 linalg::cholesky::ldlt::solve::solve_in_place_with_conj(self.L.as_ref(), self.D.as_ref(), conj, rhs, par, stack);
1693 }
1694
1695 #[track_caller]
1696 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1697 let par = get_global_parallelism();
1698
1699 let mut mem = MemBuffer::new(linalg::cholesky::ldlt::solve::solve_in_place_scratch::<T>(
1700 self.L.nrows(),
1701 rhs.ncols(),
1702 par,
1703 ));
1704 let stack = MemStack::new(&mut mem);
1705
1706 linalg::cholesky::ldlt::solve::solve_in_place_with_conj(self.L(), self.D(), conj.compose(Conj::Yes), rhs, par, stack);
1707 }
1708}
1709
impl<T: ComplexField> DenseSolveCore<T> for Ldlt<T> {
    /// Recomputes `A = L D L^H` as a dense matrix.
    #[track_caller]
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();

        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);

        let mut mem = MemBuffer::new(linalg::cholesky::ldlt::reconstruct::reconstruct_scratch::<T>(n, par));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::ldlt::reconstruct::reconstruct(out.as_mut(), self.L(), self.D(), par, stack);

        // the backend fills only one triangle; mirror it to get the full matrix
        make_self_adjoint(out.as_mut());
        out
    }

    /// Computes `A^-1` as a dense matrix.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();

        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);

        let mut mem = MemBuffer::new(linalg::cholesky::ldlt::inverse::inverse_scratch::<T>(n, par));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::ldlt::inverse::inverse(out.as_mut(), self.L(), self.D(), par, stack);

        // the backend fills only one triangle; mirror it to get the full matrix
        make_self_adjoint(out.as_mut());
        out
    }
}
1743
1744impl<T: ComplexField> SolveCore<T> for Lblt<T> {
1745 #[track_caller]
1746 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1747 let par = get_global_parallelism();
1748
1749 let mut mem = MemBuffer::new(linalg::cholesky::lblt::solve::solve_in_place_scratch::<usize, T>(
1750 self.L.nrows(),
1751 rhs.ncols(),
1752 par,
1753 ));
1754 let stack = MemStack::new(&mut mem);
1755
1756 linalg::cholesky::lblt::solve::solve_in_place_with_conj(self.L.as_ref(), self.B_diag(), self.B_subdiag(), conj, self.P(), rhs, par, stack);
1757 }
1758
1759 #[track_caller]
1760 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1761 let par = get_global_parallelism();
1762
1763 let mut mem = MemBuffer::new(linalg::cholesky::lblt::solve::solve_in_place_scratch::<usize, T>(
1764 self.L.nrows(),
1765 rhs.ncols(),
1766 par,
1767 ));
1768 let stack = MemStack::new(&mut mem);
1769
1770 linalg::cholesky::lblt::solve::solve_in_place_with_conj(
1771 self.L(),
1772 self.B_diag(),
1773 self.B_subdiag(),
1774 conj.compose(Conj::Yes),
1775 self.P(),
1776 rhs,
1777 par,
1778 stack,
1779 );
1780 }
1781}
1782
impl<T: ComplexField> DenseSolveCore<T> for Lblt<T> {
    /// Recomputes `A` from its pivoted `L B L^H` factorization as a dense matrix.
    #[track_caller]
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();

        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);

        let mut mem = MemBuffer::new(linalg::cholesky::lblt::reconstruct::reconstruct_scratch::<usize, T>(n, par));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::lblt::reconstruct::reconstruct(out.as_mut(), self.L(), self.B_diag(), self.B_subdiag(), self.P(), par, stack);

        // the backend fills only one triangle; mirror it to get the full matrix
        make_self_adjoint(out.as_mut());
        out
    }

    /// Computes `A^-1` as a dense matrix.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();

        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);

        let mut mem = MemBuffer::new(linalg::cholesky::lblt::inverse::inverse_scratch::<usize, T>(n, par));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::lblt::inverse::inverse(out.as_mut(), self.L(), self.B_diag(), self.B_subdiag(), self.P(), par, stack);

        // the backend fills only one triangle; mirror it to get the full matrix
        make_self_adjoint(out.as_mut());
        out
    }
}
1816
impl<T: ComplexField> SolveCore<T> for PartialPivLu<T> {
    /// Solves `A x = rhs` in place using the `P L U` factors (optionally conjugated).
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));

        let k = rhs.ncols();

        linalg::lu::partial_pivoting::solve::solve_in_place_with_conj(
            self.L(),
            self.U(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::partial_pivoting::solve::solve_in_place_scratch::<usize, T>(self.nrows(), k, par),
            )),
        );
    }

    /// Solves `A^T x = rhs` in place (optionally conjugated).
    #[track_caller]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));

        let k = rhs.ncols();

        linalg::lu::partial_pivoting::solve::solve_transpose_in_place_with_conj(
            self.L(),
            self.U(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::partial_pivoting::solve::solve_transpose_in_place_scratch::<usize, T>(self.nrows(), k, par),
            )),
        );
    }
}
1860
impl<T: ComplexField> DenseSolveCore<T> for PartialPivLu<T> {
    /// Recomputes `A = P^-1 L U` as a dense matrix.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();

        let mut out = Mat::zeros(m, n);

        linalg::lu::partial_pivoting::reconstruct::reconstruct(
            out.as_mut(),
            self.L(),
            self.U(),
            self.P(),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::lu::partial_pivoting::reconstruct::reconstruct_scratch::<
                usize,
                T,
            >(m, n, par))),
        );

        out
    }

    /// Computes `A^-1` as a dense matrix.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();

        assert!(self.nrows() == self.ncols());

        let n = self.ncols();

        let mut out = Mat::zeros(n, n);

        linalg::lu::partial_pivoting::inverse::inverse(
            out.as_mut(),
            self.L(),
            self.U(),
            self.P(),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::lu::partial_pivoting::inverse::inverse_scratch::<usize, T>(
                n, par,
            ))),
        );

        out
    }
}
1908
impl<T: ComplexField> SolveCore<T> for FullPivLu<T> {
    /// Solves `A x = rhs` in place using the `P L U Q` factors (optionally conjugated).
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));

        let k = rhs.ncols();

        linalg::lu::full_pivoting::solve::solve_in_place_with_conj(
            self.L(),
            self.U(),
            self.P(),
            self.Q(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::solve::solve_in_place_scratch::<usize, T>(
                self.nrows(),
                k,
                par,
            ))),
        );
    }

    /// Solves `A^T x = rhs` in place (optionally conjugated).
    #[track_caller]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));

        let k = rhs.ncols();

        linalg::lu::full_pivoting::solve::solve_transpose_in_place_with_conj(
            self.L(),
            self.U(),
            self.P(),
            self.Q(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::solve::solve_transpose_in_place_scratch::<
                usize,
                T,
            >(self.nrows(), k, par))),
        );
    }
}
1957
impl<T: ComplexField> DenseSolveCore<T> for FullPivLu<T> {
    /// Recomputes `A` from the `P L U Q` factors as a dense matrix.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();

        let mut out = Mat::zeros(m, n);

        linalg::lu::full_pivoting::reconstruct::reconstruct(
            out.as_mut(),
            self.L(),
            self.U(),
            self.P(),
            self.Q(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::full_pivoting::reconstruct::reconstruct_scratch::<usize, T>(m, n, par),
            )),
        );

        out
    }

    /// Computes `A^-1` as a dense matrix.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();

        assert!(self.nrows() == self.ncols());

        let n = self.ncols();

        let mut out = Mat::zeros(n, n);

        linalg::lu::full_pivoting::inverse::inverse(
            out.as_mut(),
            self.L(),
            self.U(),
            self.P(),
            self.Q(),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::inverse::inverse_scratch::<usize, T>(
                n, par,
            ))),
        );

        out
    }
}
2006
impl<T: ComplexField> SolveCore<T> for Qr<T> {
    /// Solves `A x = rhs` in place using the QR factors (optionally conjugated).
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));

        let n = self.nrows();
        // Householder blocking factor used when applying `Q`
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();

        linalg::qr::no_pivoting::solve::solve_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::solve::solve_in_place_scratch::<T>(
                n, block_size, k, par,
            ))),
        );
    }

    /// Solves `A^T x = rhs` in place (optionally conjugated).
    #[track_caller]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));

        let n = self.nrows();
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();

        linalg::qr::no_pivoting::solve::solve_transpose_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::no_pivoting::solve::solve_transpose_in_place_scratch::<T>(n, block_size, k, par),
            )),
        );
    }
}
2054
impl<T: ComplexField> SolveLstsqCore<T> for Qr<T> {
    /// Solves the least-squares problem `min |A x - rhs|` in place; the solution is
    /// written to the top `ncols` rows of `rhs`.
    ///
    /// # Panics
    /// Panics if `rhs` has a mismatched row count or the matrix has fewer rows
    /// than columns.
    #[track_caller]
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == rhs.nrows(), self.nrows() >= self.ncols(),));

        let m = self.nrows();
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();

        linalg::qr::no_pivoting::solve::solve_lstsq_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::solve::solve_lstsq_in_place_scratch::<T>(
                m, n, block_size, k, par,
            ))),
        );
    }
}
2080
impl<T: ComplexField> DenseSolveCore<T> for Qr<T> {
    /// Recomputes `A = Q R` as a dense matrix.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();

        let mut out = Mat::zeros(m, n);

        linalg::qr::no_pivoting::reconstruct::reconstruct(
            out.as_mut(),
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::reconstruct::reconstruct_scratch::<T>(
                m, n, block_size, par,
            ))),
        );

        out
    }

    /// Computes `A^-1` as a dense matrix.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        assert!(self.nrows() == self.ncols());

        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();

        let mut out = Mat::zeros(n, n);

        linalg::qr::no_pivoting::inverse::inverse(
            out.as_mut(),
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::inverse::inverse_scratch::<T>(
                n, block_size, par,
            ))),
        );

        out
    }
}
2127
impl<T: ComplexField> SolveCore<T> for ColPivQr<T> {
    /// Solves `A x = rhs` in place using the column-pivoted QR factors
    /// (optionally conjugated).
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));

        let n = self.nrows();
        // Householder blocking factor used when applying `Q`
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();

        linalg::qr::col_pivoting::solve::solve_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_in_place_scratch::<usize, T>(
                n, block_size, k, par,
            ))),
        );
    }

    /// Solves `A^T x = rhs` in place (optionally conjugated).
    #[track_caller]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));

        let n = self.nrows();
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();

        linalg::qr::col_pivoting::solve::solve_transpose_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_transpose_in_place_scratch::<
                usize,
                T,
            >(n, block_size, k, par))),
        );
    }
}
2178
impl<T: ComplexField> SolveLstsqCore<T> for ColPivQr<T> {
    /// Solves the least-squares problem `min |A x - rhs|` in place; the solution is
    /// written to the top `ncols` rows of `rhs`.
    ///
    /// # Panics
    /// Panics if `rhs` has a mismatched row count or the matrix has fewer rows
    /// than columns.
    #[track_caller]
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == rhs.nrows(), self.nrows() >= self.ncols(),));

        let m = self.nrows();
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();

        linalg::qr::col_pivoting::solve::solve_lstsq_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_lstsq_in_place_scratch::<
                usize,
                T,
            >(m, n, block_size, k, par))),
        );
    }
}
2206
impl<T: ComplexField> DenseSolveCore<T> for ColPivQr<T> {
    /// Recomputes `A` from the pivoted QR factors as a dense matrix.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();

        let mut out = Mat::zeros(m, n);

        linalg::qr::col_pivoting::reconstruct::reconstruct(
            out.as_mut(),
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::col_pivoting::reconstruct::reconstruct_scratch::<usize, T>(m, n, block_size, par),
            )),
        );

        out
    }

    /// Computes `A^-1` as a dense matrix.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        assert!(self.nrows() == self.ncols());

        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();

        let mut out = Mat::zeros(n, n);

        linalg::qr::col_pivoting::inverse::inverse(
            out.as_mut(),
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::inverse::inverse_scratch::<usize, T>(
                n, block_size, par,
            ))),
        );

        out
    }
}
2255
impl<T: ComplexField> SolveCore<T> for Svd<T> {
    /// Solves `A x = rhs` in place via `A^-1 = V S^-1 U^H`.
    ///
    /// NOTE(review): each singular value is inverted with no zero guard, so a
    /// singular matrix produces non-finite results rather than a pseudoinverse solve.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));

        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);

        // tmp = U^H * rhs (transpose + flipped conj = adjoint)
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.U().transpose(),
            conj.compose(Conj::Yes),
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S^-1 * tmp (singular values are real)
        for j in 0..k {
            for i in 0..n {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // rhs = V * tmp
        linalg::matmul::matmul_with_conj(rhs.as_mut(), Accum::Replace, self.V(), conj, tmp.as_ref(), Conj::No, one(), par);
    }

    /// Solves `A^T x = rhs` in place via `A^-T = conj(U) S^-1 V^T`.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));

        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);

        // tmp = V^T * rhs
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.V().transpose(),
            conj,
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S^-1 * tmp
        for j in 0..k {
            for i in 0..n {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // rhs = conj(U) * tmp
        linalg::matmul::matmul_with_conj(
            rhs.as_mut(),
            Accum::Replace,
            self.U(),
            conj.compose(Conj::Yes),
            tmp.as_ref(),
            Conj::No,
            one(),
            par,
        );
    }
}
2330
impl<T: ComplexField> SolveLstsqCore<T> for Svd<T> {
    /// Solves the least-squares problem `min |A x - rhs|` in place via the thin SVD:
    /// `x = V S^-1 U^H rhs`, written to the top `ncols` rows of `rhs`.
    ///
    /// # Panics
    /// Panics if `rhs` has a mismatched row count or the matrix has fewer rows
    /// than columns.
    #[track_caller]
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == rhs.nrows(), self.nrows() >= self.ncols(),));

        let m = self.nrows();
        let n = self.ncols();

        let size = Ord::min(m, n);

        // thin factors: only the leading `size` columns contribute
        let U = self.U().get(.., ..size);
        let V = self.V().get(.., ..size);

        let k = rhs.ncols();

        let mut tmp = Mat::zeros(size, k);

        // tmp = U^H * rhs
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            U.transpose(),
            conj.compose(Conj::Yes),
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S^-1 * tmp
        for j in 0..k {
            for i in 0..size {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // solution occupies the top `size` rows of rhs
        linalg::matmul::matmul_with_conj(rhs.get_mut(..size, ..), Accum::Replace, V, conj, tmp.as_ref(), Conj::No, one(), par);
    }
}
2371
impl<T: ComplexField> DenseSolveCore<T> for Svd<T> {
    /// Recomputes `A = U S V^H` as a dense matrix from the thin factors.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();

        let size = Ord::min(m, n);

        let U = self.U().get(.., ..size);
        let V = self.V().get(.., ..size);
        let S = self.S();

        // UxS = U * S (scale each column by its singular value)
        let mut UxS = Mat::zeros(m, size);
        for j in 0..size {
            let s = real(&S[j]);
            for i in 0..m {
                UxS[(i, j)] = mul_real(&U[(i, j)], &s);
            }
        }

        let mut out = Mat::zeros(m, n);

        linalg::matmul::matmul(out.as_mut(), Accum::Replace, UxS.as_ref(), V.adjoint(), one(), par);

        out
    }

    /// Computes `A^-1 = V S^-1 U^H` as a dense matrix.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();

        assert!(self.nrows() == self.ncols());
        let n = self.nrows();

        let U = self.U();
        let V = self.V();
        let S = self.S();

        // VxS = V * S^-1 (no zero guard; singular input yields non-finite output)
        let mut VxS = Mat::zeros(n, n);
        for j in 0..n {
            let s = recip(&real(&S[j]));

            for i in 0..n {
                VxS[(i, j)] = mul_real(&V[(i, j)], &s);
            }
        }

        let mut out = Mat::zeros(n, n);

        linalg::matmul::matmul(out.as_mut(), Accum::Replace, VxS.as_ref(), U.adjoint(), one(), par);

        out
    }
}
2426
impl<T: ComplexField> SolveCore<T> for SelfAdjointEigen<T> {
    /// Solves `A x = rhs` in place via `A^-1 = U S^-1 U^H` (eigenvalues of a
    /// self-adjoint matrix are real).
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));

        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);

        // tmp = U^H * rhs (transpose + flipped conj = adjoint)
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.U().transpose(),
            conj.compose(Conj::Yes),
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S^-1 * tmp (no zero guard: a zero eigenvalue yields non-finite output)
        for j in 0..k {
            for i in 0..n {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // rhs = U * tmp
        linalg::matmul::matmul_with_conj(rhs.as_mut(), Accum::Replace, self.U(), conj, tmp.as_ref(), Conj::No, one(), par);
    }

    /// Solves `A^T x = rhs` in place via `A^-T = conj(U) S^-1 U^T`.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));

        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);

        // tmp = U^T * rhs
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.U().transpose(),
            conj,
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S^-1 * tmp
        for j in 0..k {
            for i in 0..n {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // rhs = conj(U) * tmp
        linalg::matmul::matmul_with_conj(
            rhs.as_mut(),
            Accum::Replace,
            self.U(),
            conj.compose(Conj::Yes),
            tmp.as_ref(),
            Conj::No,
            one(),
            par,
        );
    }
}
2501
impl<T: ComplexField> DenseSolveCore<T> for SelfAdjointEigen<T> {
    /// Recomputes `A = U S U^H` as a dense matrix.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();

        let size = Ord::min(m, n);

        // `V` deliberately aliases `U`: the decomposition is `U S U^H`
        let U = self.U().get(.., ..size);
        let V = self.U().get(.., ..size);
        let S = self.S();

        // UxS = U * S (eigenvalues are real)
        let mut UxS = Mat::zeros(m, size);
        for j in 0..size {
            let s = real(&S[j]);
            for i in 0..m {
                UxS[(i, j)] = mul_real(&U[(i, j)], &s);
            }
        }

        let mut out = Mat::zeros(m, n);

        linalg::matmul::matmul(out.as_mut(), Accum::Replace, UxS.as_ref(), V.adjoint(), one(), par);

        out
    }

    /// Computes `A^-1 = U S^-1 U^H` as a dense matrix.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();

        assert!(self.nrows() == self.ncols());
        let n = self.nrows();

        // `V` deliberately aliases `U`: the decomposition is `U S U^H`
        let U = self.U();
        let V = self.U();
        let S = self.S();

        // VxS = U * S^-1 (no zero guard: a zero eigenvalue yields non-finite output)
        let mut VxS = Mat::zeros(n, n);
        for j in 0..n {
            let s = recip(&real(&S[j]));

            for i in 0..n {
                VxS[(i, j)] = mul_real(&V[(i, j)], &s);
            }
        }

        let mut out = Mat::zeros(n, n);

        linalg::matmul::matmul(out.as_mut(), Accum::Replace, VxS.as_ref(), U.adjoint(), one(), par);

        out
    }
}
2555
#[cfg(test)]
mod tests {
	//! Randomized round-trip tests for the high-level decomposition and
	//! solver APIs, using fixed RNG seeds so failures are reproducible.

	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;

	/// Verifies that `A_dec` solves left-sided (`op(A) X = R`) and
	/// right-sided (`X op(A) = L`) systems for all four
	/// conjugation/transposition variants, against the original matrix `A`.
	#[track_caller]
	fn test_solver(A: MatRef<'_, c64>, A_dec: impl SolveCore<c64>) {
		#[track_caller]
		fn test_solver_imp(A: MatRef<'_, c64>, A_dec: &dyn SolveCore<c64>) {
			let rng = &mut StdRng::seed_from_u64(0xC0FFEE);

			let n = A.nrows();
			// NOTE(review): the `lhs ~ rhs` form of `assert!` presumably picks up
			// this local `approx_eq` comparator by name — confirm in `utils::approx`.
			let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

			// number of right-hand-side columns (rows, for the `rsolve` cases)
			let k = 3;

			// random right-hand side for the left-sided solves
			let ref R = CwiseMatDistribution {
				nrows: n,
				ncols: k,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			// random left-hand side for the right-sided solves
			let ref L = CwiseMatDistribution {
				nrows: k,
				ncols: n,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			// left-sided: op(A) X = R
			assert!(A * A_dec.solve(R) ~ R);
			assert!(A.conjugate() * A_dec.solve_conjugate(R) ~ R);
			assert!(A.transpose() * A_dec.solve_transpose(R) ~ R);
			assert!(A.adjoint() * A_dec.solve_adjoint(R) ~ R);

			// right-sided: X op(A) = L
			assert!(A_dec.rsolve(L) * A ~ L);
			assert!(A_dec.rsolve_conjugate(L) * A.conjugate() ~ L);
			assert!(A_dec.rsolve_transpose(L) * A.transpose() ~ L);
			assert!(A_dec.rsolve_adjoint(L) * A.adjoint() ~ L);
		}

		// funnel through `&dyn` so every decomposition type reuses a single
		// monomorphized test body
		test_solver_imp(A, &A_dec)
	}

	/// Runs the solver round-trip on every dense decomposition, restricting
	/// LLT/LDLT to a positive-definite input and LBLT/self-adjoint eigen to a
	/// Hermitian one.
	#[test]
	fn test_all_solvers() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let ref A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = A.rb();

		test_solver(A, A.partial_piv_lu());
		test_solver(A, A.full_piv_lu());
		test_solver(A, A.qr());
		test_solver(A, A.col_piv_qr());
		test_solver(A, A.svd().unwrap());

		{
			// A Aᴴ is Hermitian positive-definite, as required by LLT/LDLT
			let ref A = A * A.adjoint();
			let A = A.rb();
			test_solver(A, A.llt(Side::Lower).unwrap());
			test_solver(A, A.ldlt(Side::Lower).unwrap());
		}

		{
			// A + Aᴴ is Hermitian (possibly indefinite)
			let ref A = A + A.adjoint();
			let A = A.rb();
			test_solver(A, A.lblt(Side::Lower));
			test_solver(A, A.self_adjoint_eigen(Side::Lower).unwrap());
		}
	}

	/// Eigendecomposition of a random complex matrix: checks the residual
	/// `A U = U S` and that `eigenvalues()` agrees with `S`.
	#[test]
	fn test_eigen_cplx() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let n = A.nrows();
		// comparator used implicitly by the `~` assertions below
		let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

		{
			let evd = A.eigen().unwrap();
			let e = A.eigenvalues().unwrap();
			assert!(&A * evd.U() ~ evd.U() * evd.S());
			assert!(evd.S().column_vector() ~ ColRef::from_slice(&e));
		}
		{
			// same checks through the conjugate code path
			let evd = A.conjugate().eigen().unwrap();
			let e = A.conjugate().eigenvalues().unwrap();
			assert!(A.conjugate() * evd.U() ~ evd.U() * evd.S());
			assert!(evd.S().column_vector() ~ ColRef::from_slice(&e));
		}
	}

	/// Generalized eigendecomposition of a random complex pencil `(A, B)`:
	/// checks `A U = B U diag(S_a / S_b)`.
	#[test]
	fn test_geigen_cplx() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let n = A.nrows();
		// comparator used implicitly by the `~` assertions below
		let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

		{
			let evd = A.generalized_eigen(&B).unwrap();
			// the generalized eigenvalues are the ratios S_a[i] / S_b[i]
			let e = zip!(evd.S_a(), evd.S_b()).map(|unzip!(a, b)| a / b);
			assert!(&A * evd.U() ~ &B * evd.U() * e);
		}

		{
			// same checks through the conjugate code path
			let evd = A.conjugate().generalized_eigen(B.conjugate()).unwrap();
			let e = zip!(evd.S_a(), evd.S_b()).map(|unzip!(a, b)| a / b);
			assert!(A.conjugate() * evd.U() ~ B.conjugate() * evd.U() * e);
		}
	}

	/// Eigendecomposition of a random real matrix: the eigenvectors and
	/// eigenvalues are complex, so `A` is promoted to `c64` before checking
	/// the residual `A U = U S`.
	#[test]
	fn test_eigen_real() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: StandardNormal,
		}
		.rand::<Mat<f64>>(rng);

		let n = A.nrows();
		// comparator used implicitly by the `~` assertions below
		let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

		let evd = A.eigen().unwrap();
		let e = A.eigenvalues().unwrap();

		// promote to complex so the products below type-check
		let A = Mat::from_fn(A.nrows(), A.ncols(), |i, j| c64::from(A[(i, j)]));

		assert!(&A * evd.U() ~ evd.U() * evd.S());
		assert!(evd.S().column_vector() ~ ColRef::from_slice(&e));
	}

	/// Generalized eigendecomposition of a random real pencil `(A, B)`,
	/// verified against complex promotions of both matrices.
	#[test]
	fn test_geigen_real() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: StandardNormal,
		}
		.rand::<Mat<f64>>(rng);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: StandardNormal,
		}
		.rand::<Mat<f64>>(rng);

		let n = A.nrows();
		// comparator used implicitly by the `~` assertion below
		let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

		// complex promotions of A and B for the residual check
		let Ac = zip!(&A).map(|unzip!(x)| c64::new(*x, 0.0));
		let Bc = zip!(&B).map(|unzip!(x)| c64::new(*x, 0.0));

		{
			let evd = A.generalized_eigen(&B).unwrap();
			let e = zip!(evd.S_a(), evd.S_b()).map(|unzip!(a, b)| a / b);
			assert!(&Ac * evd.U() ~ &Bc * evd.U() * e);
		}
	}

	/// Least-squares solve through the SVD of a tall (5×3) matrix with an
	/// exactly-representable solution: the first 3 rows of the solved
	/// right-hand side must equal `X_true`.
	#[test]
	fn test_svd_solver_for_rectangular_matrix() {
		#[rustfmt::skip]
		let A = crate::mat![
			[4., 5., 7.],
			[8., 8., 2.],
			[4., 0., 9.],
			[2., 6., 2.],
			[0., 6., 0.],
		];
		#[rustfmt::skip]
		let B = crate::mat![
			[105., 49.],
			[ 98., 54.],
			[113., 35.],
			[ 46., 34.],
			[ 12., 24.],
		];

		#[rustfmt::skip]
		let X_true= crate::mat![
			[8., 2.],
			[2., 4.],
			[9., 3.],
		];

		// comparator used implicitly by the `~` assertion below
		let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (A.nrows() as f64));
		let svd = A.svd().unwrap();
		let mut X = B.cloned();
		// in-place least-squares: the top `ncols` rows of X become the solution
		svd.solve_lstsq_in_place_with_conj(crate::Conj::No, X.as_mat_mut());
		assert!(X.get(..X_true.nrows(),..) ~ X_true);
	}
}