use crate::internal_prelude::*;
use crate::{assert, get_global_parallelism};
use alloc::vec;
use alloc::vec::Vec;
use dyn_stack::MemBuffer;
use faer_traits::ComplexConj;
pub use linalg::cholesky::ldlt::factor::LdltError;
pub use linalg::cholesky::llt::factor::LltError;
pub use linalg::evd::EvdError;
pub use linalg::gevd::{GevdError, SelfAdjointGevdError};
use linalg::svd::ComputeSvdVectors;
pub use linalg::svd::SvdError;
/// Minimal shape interface shared by all decomposition objects: the
/// dimensions of the matrix that was factorized.
pub trait ShapeCore {
/// Returns the number of rows of the factorized matrix.
fn nrows(&self) -> usize;
/// Returns the number of columns of the factorized matrix.
fn ncols(&self) -> usize;
}
/// Low-level linear-solve interface implemented by matrix decompositions.
///
/// The `conj` parameter selects whether the decomposed matrix is implicitly
/// conjugated before solving; the high-level [`Solve`] trait builds its
/// conjugate/adjoint variants on top of this.
pub trait SolveCore<T: ComplexField>: ShapeCore {
/// Solves `A * X = rhs` in place (with `A` optionally conjugated per `conj`),
/// overwriting `rhs` with the solution.
fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
/// Same as [`SolveCore::solve_in_place_with_conj`], but solves with the
/// transpose of the (optionally conjugated) matrix.
fn solve_transpose_in_place_with_conj(
&self,
conj: Conj,
rhs: MatMut<'_, T>,
);
}
/// Low-level least-squares solve interface implemented by decompositions of
/// possibly rectangular matrices.
pub trait SolveLstsqCore<T: ComplexField>: ShapeCore {
/// Solves the least-squares problem for `A * X ≈ rhs` in place (with `A`
/// optionally conjugated per `conj`); the solution is written into the
/// leading rows of `rhs`.
fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
}
/// Extension of [`SolveCore`] for dense decompositions that can also rebuild
/// the original matrix and compute its inverse from the stored factors.
pub trait DenseSolveCore<T: ComplexField>: SolveCore<T> {
/// Recomputes the factorized matrix from its factors.
fn reconstruct(&self) -> Mat<T>;
/// Computes the inverse of the factorized matrix.
fn inverse(&self) -> Mat<T>;
}
// Forwarding impl: a shared reference to a decomposition exposes the same
// shape as the decomposition itself.
impl<S: ?Sized + ShapeCore> ShapeCore for &S {
#[inline]
fn nrows(&self) -> usize {
(**self).nrows()
}
#[inline]
fn ncols(&self) -> usize {
(**self).ncols()
}
}
// Forwarding impl: solving through `&S` delegates to `S`.
impl<T: ComplexField, S: ?Sized + SolveCore<T>> SolveCore<T> for &S {
#[inline]
fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
(**self).solve_in_place_with_conj(conj, rhs)
}
#[inline]
fn solve_transpose_in_place_with_conj(
&self,
conj: Conj,
rhs: MatMut<'_, T>,
) {
(**self).solve_transpose_in_place_with_conj(conj, rhs)
}
}
// Forwarding impl: least-squares solving through `&S` delegates to `S`.
impl<T: ComplexField, S: ?Sized + SolveLstsqCore<T>> SolveLstsqCore<T> for &S {
#[inline]
fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
(**self).solve_lstsq_in_place_with_conj(conj, rhs)
}
}
// Forwarding impl: reconstruction/inversion through `&S` delegates to `S`.
impl<T: ComplexField, S: ?Sized + DenseSolveCore<T>> DenseSolveCore<T> for &S {
#[inline]
fn reconstruct(&self) -> Mat<T> {
(**self).reconstruct()
}
#[inline]
fn inverse(&self) -> Mat<T> {
(**self).inverse()
}
}
/// High-level solve API with default implementations on top of [`SolveCore`].
///
/// `solve*` methods place the decomposed matrix `A` on the left of the
/// unknown (`A * X = rhs`); `rsolve*` methods place it on the right
/// (`X * A = lhs`) and are implemented by transposing the equation.
/// The `*_in_place` variants overwrite their argument with the solution; the
/// allocating variants copy the input into a fresh owned matrix first.
pub trait Solve<T: ComplexField>: SolveCore<T> {
/// Solves `A * X = rhs` in place.
#[track_caller]
#[inline]
fn solve_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
self.solve_in_place_with_conj(
Conj::No,
// `{ rhs }` moves the impl-Trait argument so `as_mat_mut` is called on
// an owned value.
{ rhs }.as_mat_mut().as_dyn_cols_mut(),
);
}
/// Solves `conj(A) * X = rhs` in place.
#[track_caller]
#[inline]
fn solve_conjugate_in_place(
&self,
rhs: impl AsMatMut<T = T, Rows = usize>,
) {
self.solve_in_place_with_conj(
Conj::Yes,
{ rhs }.as_mat_mut().as_dyn_cols_mut(),
);
}
/// Solves `A.transpose() * X = rhs` in place.
#[track_caller]
#[inline]
fn solve_transpose_in_place(
&self,
rhs: impl AsMatMut<T = T, Rows = usize>,
) {
self.solve_transpose_in_place_with_conj(
Conj::No,
{ rhs }.as_mat_mut().as_dyn_cols_mut(),
);
}
/// Solves `A.adjoint() * X = rhs` in place (transpose of the conjugate).
#[track_caller]
#[inline]
fn solve_adjoint_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
self.solve_transpose_in_place_with_conj(
Conj::Yes,
{ rhs }.as_mat_mut().as_dyn_cols_mut(),
);
}
/// Solves `X * A = lhs` in place, via `A.transpose() * X.transpose() =
/// lhs.transpose()`.
#[track_caller]
#[inline]
fn rsolve_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
self.solve_transpose_in_place_with_conj(
Conj::No,
{ lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut(),
);
}
/// Solves `X * conj(A) = lhs` in place.
#[track_caller]
#[inline]
fn rsolve_conjugate_in_place(
&self,
lhs: impl AsMatMut<T = T, Cols = usize>,
) {
self.solve_transpose_in_place_with_conj(
Conj::Yes,
{ lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut(),
);
}
/// Solves `X * A.transpose() = lhs` in place, via `A * X.transpose() =
/// lhs.transpose()`.
#[track_caller]
#[inline]
fn rsolve_transpose_in_place(
&self,
lhs: impl AsMatMut<T = T, Cols = usize>,
) {
self.solve_in_place_with_conj(
Conj::No,
{ lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut(),
);
}
/// Solves `X * A.adjoint() = lhs` in place.
#[track_caller]
#[inline]
fn rsolve_adjoint_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
self.solve_in_place_with_conj(
Conj::Yes,
{ lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut(),
);
}
/// Returns the solution of `A * X = rhs` as a new owned matrix.
#[track_caller]
#[inline]
fn solve<Rhs: AsMatRef<T = T, Rows = usize>>(
&self,
rhs: Rhs,
) -> Rhs::Owned {
let rhs = rhs.as_mat_ref();
let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
out.as_mat_mut().copy_from(rhs);
self.solve_in_place(&mut out);
out
}
/// Returns the solution of `conj(A) * X = rhs` as a new owned matrix.
#[track_caller]
#[inline]
fn solve_conjugate<Rhs: AsMatRef<T = T, Rows = usize>>(
&self,
rhs: Rhs,
) -> Rhs::Owned {
let rhs = rhs.as_mat_ref();
let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
out.as_mat_mut().copy_from(rhs);
self.solve_conjugate_in_place(&mut out);
out
}
/// Returns the solution of `A.transpose() * X = rhs` as a new owned matrix.
#[track_caller]
#[inline]
fn solve_transpose<Rhs: AsMatRef<T = T, Rows = usize>>(
&self,
rhs: Rhs,
) -> Rhs::Owned {
let rhs = rhs.as_mat_ref();
let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
out.as_mat_mut().copy_from(rhs);
self.solve_transpose_in_place(&mut out);
out
}
/// Returns the solution of `A.adjoint() * X = rhs` as a new owned matrix.
#[track_caller]
#[inline]
fn solve_adjoint<Rhs: AsMatRef<T = T, Rows = usize>>(
&self,
rhs: Rhs,
) -> Rhs::Owned {
let rhs = rhs.as_mat_ref();
let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
out.as_mat_mut().copy_from(rhs);
self.solve_adjoint_in_place(&mut out);
out
}
/// Returns the solution of `X * A = lhs` as a new owned matrix.
#[track_caller]
#[inline]
fn rsolve<Lhs: AsMatRef<T = T, Cols = usize>>(
&self,
lhs: Lhs,
) -> Lhs::Owned {
let lhs = lhs.as_mat_ref();
let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
out.as_mat_mut().copy_from(lhs);
self.rsolve_in_place(&mut out);
out
}
/// Returns the solution of `X * conj(A) = lhs` as a new owned matrix.
#[track_caller]
#[inline]
fn rsolve_conjugate<Lhs: AsMatRef<T = T, Cols = usize>>(
&self,
lhs: Lhs,
) -> Lhs::Owned {
let lhs = lhs.as_mat_ref();
let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
out.as_mat_mut().copy_from(lhs);
self.rsolve_conjugate_in_place(&mut out);
out
}
/// Returns the solution of `X * A.transpose() = lhs` as a new owned matrix.
#[track_caller]
#[inline]
fn rsolve_transpose<Lhs: AsMatRef<T = T, Cols = usize>>(
&self,
lhs: Lhs,
) -> Lhs::Owned {
let lhs = lhs.as_mat_ref();
let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
out.as_mat_mut().copy_from(lhs);
self.rsolve_transpose_in_place(&mut out);
out
}
/// Returns the solution of `X * A.adjoint() = lhs` as a new owned matrix.
#[track_caller]
#[inline]
fn rsolve_adjoint<Lhs: AsMatRef<T = T, Cols = usize>>(
&self,
lhs: Lhs,
) -> Lhs::Owned {
let lhs = lhs.as_mat_ref();
let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
out.as_mat_mut().copy_from(lhs);
self.rsolve_adjoint_in_place(&mut out);
out
}
}
// Convenience decomposition entry points on matrices (anything reborrowable
// as a `mat::Ref`). All methods read the matrix through `self.rb()` and use
// the globally configured parallelism.
impl<
C: Conjugate,
Inner: for<'short> Reborrow<'short, Target = mat::Ref<'short, C>>,
> mat::generic::Mat<Inner>
{
/// Treats `self` as a lower triangular matrix and solves `self * X = rhs`
/// in place, forwarding to `linalg::triangular_solve`.
#[track_caller]
pub fn solve_lower_triangular_in_place(
&self,
mut rhs: impl AsMatMut<T = C::Canonical, Rows = usize>,
) {
linalg::triangular_solve::solve_lower_triangular_in_place(
self.rb(),
rhs.as_mat_mut().as_dyn_cols_mut(),
get_global_parallelism(),
);
}
/// Treats `self` as an upper triangular matrix and solves `self * X = rhs`
/// in place.
#[track_caller]
pub fn solve_upper_triangular_in_place(
&self,
mut rhs: impl AsMatMut<T = C::Canonical, Rows = usize>,
) {
linalg::triangular_solve::solve_upper_triangular_in_place(
self.rb(),
rhs.as_mat_mut().as_dyn_cols_mut(),
get_global_parallelism(),
);
}
/// Like [`Self::solve_lower_triangular_in_place`], but the diagonal of
/// `self` is implicitly treated as all ones.
#[track_caller]
pub fn solve_unit_lower_triangular_in_place(
&self,
mut rhs: impl AsMatMut<T = C::Canonical, Rows = usize>,
) {
linalg::triangular_solve::solve_unit_lower_triangular_in_place(
self.rb(),
rhs.as_mat_mut().as_dyn_cols_mut(),
get_global_parallelism(),
);
}
/// Like [`Self::solve_upper_triangular_in_place`], but the diagonal of
/// `self` is implicitly treated as all ones.
#[track_caller]
pub fn solve_unit_upper_triangular_in_place(
&self,
mut rhs: impl AsMatMut<T = C::Canonical, Rows = usize>,
) {
linalg::triangular_solve::solve_unit_upper_triangular_in_place(
self.rb(),
rhs.as_mat_mut().as_dyn_cols_mut(),
get_global_parallelism(),
);
}
/// Computes the LU decomposition with partial (row) pivoting.
#[track_caller]
pub fn partial_piv_lu(&self) -> PartialPivLu<C::Canonical> {
PartialPivLu::new(self.rb())
}
/// Computes the LU decomposition with full (row and column) pivoting.
#[track_caller]
pub fn full_piv_lu(&self) -> FullPivLu<C::Canonical> {
FullPivLu::new(self.rb())
}
/// Computes the QR decomposition without pivoting.
#[track_caller]
pub fn qr(&self) -> Qr<C::Canonical> {
Qr::new(self.rb())
}
/// Computes the QR decomposition with column pivoting.
#[track_caller]
pub fn col_piv_qr(&self) -> ColPivQr<C::Canonical> {
ColPivQr::new(self.rb())
}
/// Computes the full singular value decomposition.
#[track_caller]
pub fn svd(&self) -> Result<Svd<C::Canonical>, SvdError> {
Svd::new(self.rb())
}
/// Computes the thin singular value decomposition (singular vectors
/// truncated to `min(nrows, ncols)` columns).
#[track_caller]
pub fn thin_svd(&self) -> Result<Svd<C::Canonical>, SvdError> {
Svd::new_thin(self.rb())
}
/// Computes the Cholesky (LLT) decomposition, reading only the triangle
/// selected by `side`.
#[track_caller]
pub fn llt(&self, side: Side) -> Result<Llt<C::Canonical>, LltError> {
Llt::new(self.rb(), side)
}
/// Computes the LDLT decomposition, reading only the triangle selected by
/// `side`.
#[track_caller]
pub fn ldlt(&self, side: Side) -> Result<Ldlt<C::Canonical>, LdltError> {
Ldlt::new(self.rb(), side)
}
/// Computes the pivoted LBLT decomposition, reading only the triangle
/// selected by `side`.
#[track_caller]
pub fn lblt(&self, side: Side) -> Lblt<C::Canonical> {
Lblt::new(self.rb(), side)
}
/// Computes the eigendecomposition of `self`, assumed self-adjoint; only
/// the triangle selected by `side` is read.
#[track_caller]
pub fn self_adjoint_eigen(
&self,
side: Side,
) -> Result<SelfAdjointEigen<C::Canonical>, EvdError> {
SelfAdjointEigen::new(self.rb(), side)
}
/// Computes only the eigenvalues of `self`, assumed self-adjoint, reading
/// the triangle selected by `side`.
///
/// # Panics
/// Panics if the matrix is not square.
#[track_caller]
pub fn self_adjoint_eigenvalues(
&self,
side: Side,
) -> Result<Vec<Real<C>>, EvdError> {
// Inner monomorphized helper so the outer generic stays thin.
#[track_caller]
pub fn imp<T: ComplexField>(
mut A: MatRef<'_, T>,
side: Side,
) -> Result<Vec<T::Real>, EvdError> {
assert!(A.nrows() == A.ncols());
// The backend reads the lower triangle; for `Side::Upper` the
// transpose makes the stored upper triangle visible as a lower one.
if side == Side::Upper {
A = A.transpose();
}
let par = get_global_parallelism();
let n = A.nrows();
let mut s = Diag::<T>::zeros(n);
linalg::evd::self_adjoint_evd(
A,
s.as_mut(),
// No eigenvectors requested.
None,
par,
// Scratch memory sized by the matching `_scratch` query.
MemStack::new(&mut MemBuffer::new(
linalg::evd::self_adjoint_evd_scratch::<T>(
n,
linalg::evd::ComputeEigenvectors::No,
par,
default(),
),
)),
default(),
)?;
// Eigenvalues of a self-adjoint matrix are real; keep the real parts.
Ok(s.column_vector().iter().map(|x| x.real()).collect())
}
imp(self.rb().canonical(), side)
}
/// Computes only the singular values of `self`, in the order returned by
/// the backend SVD.
#[track_caller]
pub fn singular_values(&self) -> Result<Vec<Real<C>>, SvdError> {
// Inner monomorphized helper so the outer generic stays thin.
pub fn imp<T: ComplexField>(
A: MatRef<'_, T>,
) -> Result<Vec<T::Real>, SvdError> {
let par = get_global_parallelism();
let m = A.nrows();
let n = A.ncols();
let mut s = Diag::<T>::zeros(Ord::min(m, n));
linalg::svd::svd(
A,
s.as_mut(),
// No singular vectors requested on either side.
None,
None,
par,
MemStack::new(&mut MemBuffer::new(
linalg::svd::svd_scratch::<T>(
m,
n,
linalg::svd::ComputeSvdVectors::No,
linalg::svd::ComputeSvdVectors::No,
par,
default(),
),
)),
default(),
)?;
// Singular values are real by construction.
Ok(s.column_vector().iter().map(|x| x.real()).collect())
}
imp(self.rb().canonical())
}
}
// Private eigendecomposition dispatch helpers. These branch at compile time
// on whether the element type is real and/or in canonical form, and use
// `crate::hacks::coerce` to reinterpret the matrix reference as the concrete
// element type each branch requires.
impl<C: Conjugate> MatRef<'_, C> {
/// Dispatches to the appropriate [`Eigen`] constructor based on the
/// compile-time properties of `C`.
#[track_caller]
fn eigen_imp(&self) -> Result<Eigen<Real<C>>, EvdError> {
if const { C::Canonical::IS_REAL } {
// SAFETY(review): relies on `coerce`'s layout-equivalence contract
// (not visible here) — the branch guarantees `C` is a real type.
Eigen::new_from_real(unsafe { crate::hacks::coerce(*self) })
} else if const { C::IS_CANONICAL } {
// SAFETY(review): `C` is complex and canonical, so it should be
// layout-compatible with `Complex<Real<C>>` — confirm against
// `hacks::coerce`'s documented requirements.
Eigen::new(unsafe {
crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(*self)
})
} else {
// SAFETY(review): non-canonical complex, reinterpreted as the
// conjugated-storage wrapper `ComplexConj`.
Eigen::new(unsafe {
crate::hacks::coerce::<_, MatRef<'_, ComplexConj<Real<C>>>>(
*self,
)
})
}
}
/// Same dispatch as [`Self::eigen_imp`], for the generalized problem with
/// a second matrix `B`.
#[track_caller]
fn gen_eigen_imp(
&self,
B: MatRef<'_, C>,
) -> Result<GeneralizedEigen<Real<C>>, GevdError> {
if const { C::Canonical::IS_REAL } {
GeneralizedEigen::new_from_real(
// SAFETY(review): see `eigen_imp` — real branch.
unsafe { crate::hacks::coerce(*self) },
unsafe { crate::hacks::coerce(B) },
)
} else if const { C::IS_CANONICAL } {
GeneralizedEigen::new(
// SAFETY(review): see `eigen_imp` — canonical complex branch.
unsafe {
crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(
*self,
)
},
unsafe {
crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(B)
},
)
} else {
GeneralizedEigen::new(
// SAFETY(review): see `eigen_imp` — conjugated complex branch.
unsafe {
crate::hacks::coerce::<_, MatRef<'_, ComplexConj<Real<C>>>>(
*self,
)
},
unsafe {
crate::hacks::coerce::<_, MatRef<'_, ComplexConj<Real<C>>>>(
B,
)
},
)
}
}
/// Computes the eigenvalues only, as complex numbers.
///
/// # Panics
/// Panics if the matrix is not square.
#[track_caller]
fn eigenvalues_imp(&self) -> Result<Vec<Complex<Real<C>>>, EvdError> {
let par = get_global_parallelism();
if const { C::Canonical::IS_REAL } {
// Real input: the backend reports eigenvalues as separate
// real/imaginary parts.
let A = unsafe {
crate::hacks::coerce::<_, MatRef<'_, Real<C>>>(*self)
};
assert!(A.nrows() == A.ncols());
let n = A.nrows();
let mut s_re = Diag::<Real<C>>::zeros(n);
let mut s_im = Diag::<Real<C>>::zeros(n);
linalg::evd::evd_real(
A,
s_re.as_mut(),
s_im.as_mut(),
// No eigenvectors on either side.
None,
None,
par,
MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<
Real<C>,
>(
n,
linalg::evd::ComputeEigenvectors::No,
linalg::evd::ComputeEigenvectors::No,
par,
default(),
))),
default(),
)?;
// Pair up the real/imaginary parts into complex values.
Ok(s_re
.column_vector()
.iter()
.zip(s_im.column_vector().iter())
.map(|(re, im)| Complex::new(re.clone(), im.clone()))
.collect())
} else {
// Complex input: compute on the canonical representation.
let A = unsafe {
crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(
self.canonical(),
)
};
assert!(A.nrows() == A.ncols());
let n = A.nrows();
let mut s = Diag::<Complex<Real<C>>>::zeros(n);
linalg::evd::evd_cplx(
A,
s.as_mut(),
None,
None,
par,
MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<
Complex<Real<C>>,
>(
n,
linalg::evd::ComputeEigenvectors::No,
linalg::evd::ComputeEigenvectors::No,
par,
default(),
))),
default(),
)?;
if const { C::IS_CANONICAL } {
Ok(s.column_vector().iter().cloned().collect())
} else {
// The input storage was conjugated, so the eigenvalues of the
// actual matrix are the conjugates of what was computed.
Ok(s.column_vector().iter().map(conj).collect())
}
}
}
}
// Public eigendecomposition entry points, forwarding to the private
// `MatRef` dispatch helpers above.
impl<
T: Conjugate,
Inner: for<'short> Reborrow<'short, Target = mat::Ref<'short, T>>,
> mat::generic::Mat<Inner>
{
/// Computes the generalized eigendecomposition of the pair `(self, B)`.
#[track_caller]
pub fn generalized_eigen(
&self,
B: impl AsMatRef<T = T, Rows = usize, Cols = usize>,
) -> Result<GeneralizedEigen<Real<T>>, GevdError> {
self.rb().gen_eigen_imp(B.as_mat_ref())
}
/// Computes the eigendecomposition of `self`.
#[track_caller]
pub fn eigen(&self) -> Result<Eigen<Real<T>>, EvdError> {
self.rb().eigen_imp()
}
/// Computes only the eigenvalues of `self`, as complex numbers.
#[track_caller]
pub fn eigenvalues(&self) -> Result<Vec<Complex<Real<T>>>, EvdError> {
self.rb().eigenvalues_imp()
}
}
/// High-level least-squares solve API with default implementations on top of
/// [`SolveLstsqCore`].
///
/// The in-place variants leave the solution in the leading rows of `rhs`;
/// the allocating variants additionally truncate the result to
/// `self.ncols()` rows.
pub trait SolveLstsq<T: ComplexField>: SolveLstsqCore<T> {
/// Solves the least-squares problem for `A * X ≈ rhs` in place.
#[track_caller]
#[inline]
fn solve_lstsq_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
self.solve_lstsq_in_place_with_conj(
Conj::No,
// `{ rhs }` moves the impl-Trait argument so `as_mat_mut` is called on
// an owned value.
{ rhs }.as_mat_mut().as_dyn_cols_mut(),
);
}
/// Solves the least-squares problem for `conj(A) * X ≈ rhs` in place.
#[track_caller]
#[inline]
fn solve_conjugate_lstsq_in_place(
&self,
rhs: impl AsMatMut<T = T, Rows = usize>,
) {
self.solve_lstsq_in_place_with_conj(
Conj::Yes,
{ rhs }.as_mat_mut().as_dyn_cols_mut(),
);
}
/// Returns the least-squares solution of `A * X ≈ rhs` as a new owned
/// matrix with `self.ncols()` rows.
#[track_caller]
#[inline]
fn solve_lstsq<Rhs: AsMatRef<T = T, Rows = usize>>(
&self,
rhs: Rhs,
) -> Rhs::Owned {
let rhs = rhs.as_mat_ref();
let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
out.as_mat_mut().copy_from(rhs);
self.solve_lstsq_in_place(&mut out);
// The solution occupies only the leading `ncols(A)` rows.
out.truncate(self.ncols(), rhs.ncols());
out
}
/// Returns the least-squares solution of `conj(A) * X ≈ rhs` as a new
/// owned matrix with `self.ncols()` rows.
#[track_caller]
#[inline]
fn solve_conjugate_lstsq<Rhs: AsMatRef<T = T, Rows = usize>>(
&self,
rhs: Rhs,
) -> Rhs::Owned {
let rhs = rhs.as_mat_ref();
let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
out.as_mat_mut().copy_from(rhs);
self.solve_conjugate_lstsq_in_place(&mut out);
out.truncate(self.ncols(), rhs.ncols());
out
}
}
/// Marker trait for [`DenseSolveCore`] implementors.
pub trait DenseSolve<T: ComplexField>: DenseSolveCore<T> {}
// Blanket impls: any `*Core` implementor automatically gets the matching
// high-level API.
impl<T: ComplexField, S: ?Sized + SolveCore<T>> Solve<T> for S {}
impl<T: ComplexField, S: ?Sized + SolveLstsqCore<T>> SolveLstsq<T> for S {}
impl<T: ComplexField, S: ?Sized + DenseSolveCore<T>> DenseSolve<T> for S {}
/// Cholesky decomposition `A = L * L.adjoint()`.
#[derive(Clone, Debug)]
pub struct Llt<T> {
/// Lower triangular factor (strict upper triangle zeroed).
L: Mat<T>,
}
/// Decomposition `A = L * D * L.adjoint()` with unit-lower `L` and diagonal
/// `D`.
#[derive(Clone, Debug)]
pub struct Ldlt<T> {
/// Unit lower triangular factor (ones on the diagonal, strict upper zeroed).
L: Mat<T>,
/// Diagonal factor.
D: Diag<T>,
}
/// Pivoted LBLT decomposition with a block-diagonal middle factor `B`,
/// stored as a diagonal plus a subdiagonal.
#[derive(Clone, Debug)]
pub struct Lblt<T> {
/// Unit lower triangular factor.
L: Mat<T>,
/// Diagonal of the block factor `B`.
B_diag: Diag<T>,
/// Subdiagonal of `B` — presumably nonzero entries mark 2x2 blocks;
/// confirm against `linalg::cholesky::lblt`.
B_subdiag: Diag<T>,
/// Pivoting permutation.
P: Perm<usize>,
}
/// LU decomposition with partial (row) pivoting.
#[derive(Clone, Debug)]
pub struct PartialPivLu<T> {
/// Unit lower triangular factor.
L: Mat<T>,
/// Upper triangular factor.
U: Mat<T>,
/// Row permutation.
P: Perm<usize>,
}
/// LU decomposition with full (row and column) pivoting.
#[derive(Clone, Debug)]
pub struct FullPivLu<T> {
/// Unit lower triangular factor.
L: Mat<T>,
/// Upper triangular factor.
U: Mat<T>,
/// Row permutation.
P: Perm<usize>,
/// Column permutation.
Q: Perm<usize>,
}
/// QR decomposition without pivoting; `Q` is stored implicitly as a block
/// Householder sequence (`Q_basis`, `Q_coeff`).
#[derive(Clone, Debug)]
pub struct Qr<T> {
/// Householder reflector basis (unit lower trapezoidal).
Q_basis: Mat<T>,
/// Block Householder coefficients (`block_size` rows).
Q_coeff: Mat<T>,
/// Upper triangular factor.
R: Mat<T>,
}
/// QR decomposition with column pivoting.
#[derive(Clone, Debug)]
pub struct ColPivQr<T> {
/// Householder reflector basis (unit lower trapezoidal).
Q_basis: Mat<T>,
/// Block Householder coefficients (`block_size` rows).
Q_coeff: Mat<T>,
/// Upper triangular factor.
R: Mat<T>,
/// Column permutation.
P: Perm<usize>,
}
/// Singular value decomposition `A = U * S * V.adjoint()`.
#[derive(Clone, Debug)]
pub struct Svd<T> {
/// Left singular vectors.
U: Mat<T>,
/// Right singular vectors.
V: Mat<T>,
/// Singular values.
S: Diag<T>,
}
/// Eigendecomposition of a self-adjoint matrix: `A = U * S * U.adjoint()`.
#[derive(Clone, Debug)]
pub struct SelfAdjointEigen<T> {
/// Eigenvectors.
U: Mat<T>,
/// Eigenvalues.
S: Diag<T>,
}
/// General (non-symmetric) eigendecomposition; eigenvectors and eigenvalues
/// are stored as complex even when the input is real.
#[derive(Clone, Debug)]
pub struct Eigen<T> {
/// Eigenvectors.
U: Mat<Complex<T>>,
/// Eigenvalues.
S: Diag<Complex<T>>,
}
/// Generalized eigendecomposition of a matrix pair `(A, B)`.
#[derive(Clone, Debug)]
pub struct GeneralizedEigen<T> {
/// Eigenvectors.
U: Mat<Complex<T>>,
/// Numerators of the eigenvalues — presumably eigenvalues are the ratios
/// `S_a[i] / S_b[i]` (alpha/beta convention); confirm against
/// `linalg::gevd`.
S_a: Diag<Complex<T>>,
/// Denominators of the eigenvalues.
S_b: Diag<Complex<T>>,
}
impl<T: ComplexField> Llt<T> {
/// Computes the Cholesky decomposition of `A`, reading only the triangle
/// selected by `side` (`Upper` reads via the adjoint so the data lands in
/// the lower triangle).
///
/// # Panics
/// Panics if `A` is not square.
#[track_caller]
pub fn new<C: Conjugate<Canonical = T>>(
A: MatRef<'_, C>,
side: Side,
) -> Result<Self, LltError> {
assert!(all(A.nrows() == A.ncols()));
let n = A.nrows();
let mut L = Mat::zeros(n, n);
match side {
Side::Lower => L.copy_from_triangular_lower(A),
Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
}
Self::new_imp(L)
}
/// Factorizes `L` (holding the lower triangle of `A`) in place, then
/// zeroes the strict upper triangle so only the factor remains.
#[track_caller]
fn new_imp(mut L: Mat<T>) -> Result<Self, LltError> {
let par = get_global_parallelism();
let n = L.nrows();
// Scratch memory sized by the matching `_scratch` query.
let mut mem = MemBuffer::new(
linalg::cholesky::llt::factor::cholesky_in_place_scratch::<T>(
n,
par,
default(),
),
);
let stack = MemStack::new(&mut mem);
linalg::cholesky::llt::factor::cholesky_in_place(
L.as_mut(),
Default::default(),
par,
stack,
default(),
)?;
// Clear the strict upper triangle left over from the input copy.
z!(&mut L)
.for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| {
*x = zero()
});
Ok(Self { L })
}
/// Returns the lower triangular Cholesky factor.
pub fn L(&self) -> MatRef<'_, T> {
self.L.as_ref()
}
}
impl<T: ComplexField> Ldlt<T> {
/// Computes the LDLT decomposition of `A`, reading only the triangle
/// selected by `side`.
///
/// # Panics
/// Panics if `A` is not square.
#[track_caller]
pub fn new<C: Conjugate<Canonical = T>>(
A: MatRef<'_, C>,
side: Side,
) -> Result<Self, LdltError> {
assert!(all(A.nrows() == A.ncols()));
let n = A.nrows();
let mut L = Mat::zeros(n, n);
match side {
Side::Lower => L.copy_from_triangular_lower(A),
Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
}
Self::new_imp(L)
}
/// Factorizes in place, then splits the result: the diagonal of the packed
/// factorization becomes `D`, and `L` is normalized to unit lower
/// triangular.
#[track_caller]
fn new_imp(mut L: Mat<T>) -> Result<Self, LdltError> {
let par = get_global_parallelism();
let n = L.nrows();
let mut D = Diag::zeros(n);
let mut mem = MemBuffer::new(
linalg::cholesky::ldlt::factor::cholesky_in_place_scratch::<T>(
n,
par,
default(),
),
);
let stack = MemStack::new(&mut mem);
linalg::cholesky::ldlt::factor::cholesky_in_place(
L.as_mut(),
Default::default(),
par,
stack,
default(),
)?;
// The packed result stores D on the diagonal of L.
D.copy_from(L.diagonal());
L.diagonal_mut().fill(one());
z!(&mut L)
.for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| {
*x = zero()
});
Ok(Self { L, D })
}
/// Returns the unit lower triangular factor.
pub fn L(&self) -> MatRef<'_, T> {
self.L.as_ref()
}
/// Returns the diagonal factor.
pub fn D(&self) -> DiagRef<'_, T> {
self.D.as_ref()
}
}
impl<T: ComplexField> Lblt<T> {
/// Computes the pivoted LBLT decomposition of `A`, reading only the
/// triangle selected by `side`.
///
/// # Panics
/// Panics if `A` is not square.
#[track_caller]
pub fn new<C: Conjugate<Canonical = T>>(
A: MatRef<'_, C>,
side: Side,
) -> Self {
assert!(all(A.nrows() == A.ncols()));
let n = A.nrows();
let mut L = Mat::zeros(n, n);
match side {
Side::Lower => L.copy_from_triangular_lower(A),
Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
}
Self::new_imp(L)
}
/// Factorizes in place; the packed diagonal becomes `B_diag`, the backend
/// fills `subdiag` and the permutation vectors, and `L` is normalized to
/// unit lower triangular.
#[track_caller]
fn new_imp(mut L: Mat<T>) -> Self {
let par = get_global_parallelism();
let n = L.nrows();
let mut diag = Diag::zeros(n);
let mut subdiag = Diag::zeros(n);
let mut perm_fwd = vec![0usize; n];
let mut perm_bwd = vec![0usize; n];
let mut mem = MemBuffer::new(
linalg::cholesky::lblt::factor::cholesky_in_place_scratch::<usize, T>(
n,
par,
default(),
),
);
let stack = MemStack::new(&mut mem);
linalg::cholesky::lblt::factor::cholesky_in_place(
L.as_mut(),
subdiag.as_mut(),
&mut perm_fwd,
&mut perm_bwd,
par,
stack,
default(),
);
// The packed result stores B's diagonal on the diagonal of L.
diag.copy_from(L.diagonal());
L.diagonal_mut().fill(one());
z!(&mut L)
.for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| {
*x = zero()
});
Self {
L,
B_diag: diag,
B_subdiag: subdiag,
// SAFETY(review): `perm_fwd`/`perm_bwd` were filled by the backend
// factorization — assumed to be mutually inverse permutations.
P: unsafe {
Perm::new_unchecked(
perm_fwd.into_boxed_slice(),
perm_bwd.into_boxed_slice(),
)
},
}
}
/// Returns the unit lower triangular factor.
pub fn L(&self) -> MatRef<'_, T> {
self.L.as_ref()
}
/// Returns the diagonal of the block factor `B`.
pub fn B_diag(&self) -> DiagRef<'_, T> {
self.B_diag.as_ref()
}
/// Returns the subdiagonal of the block factor `B`.
pub fn B_subdiag(&self) -> DiagRef<'_, T> {
self.B_subdiag.as_ref()
}
/// Returns the pivoting permutation.
pub fn P(&self) -> PermRef<'_, usize> {
self.P.as_ref()
}
}
/// Splits a packed LU factorization into its unit-lower factor and its
/// upper factor.
///
/// The tall and wide cases differ in which factor reuses the packed storage:
/// for `m >= n` the packed matrix becomes the (m x n) unit-lower factor and
/// the upper factor is a fresh `size x size` copy of the upper triangle;
/// otherwise the packed matrix becomes the (m x n) upper factor and the
/// unit-lower factor is a fresh `size x size` matrix.
fn split_LU<T: ComplexField>(packed: Mat<T>) -> (Mat<T>, Mat<T>) {
    let (nrows, ncols) = packed.shape();
    let size = Ord::min(nrows, ncols);
    if nrows >= ncols {
        // Tall case: the packed storage becomes the lower factor.
        let mut lower = packed;
        let mut upper = Mat::zeros(size, size);
        upper.copy_from_triangular_upper(lower.get(..size, ..size));
        z!(&mut lower)
            .for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| {
                *x = zero()
            });
        lower.diagonal_mut().fill(one());
        (lower, upper)
    } else {
        // Wide case: the packed storage becomes the upper factor.
        let mut upper = packed;
        let mut lower = Mat::zeros(size, size);
        lower.copy_from_strict_triangular_lower(upper.get(..size, ..size));
        z!(&mut upper)
            .for_each_triangular_lower(linalg::zip::Diag::Skip, |uz!(x)| {
                *x = zero()
            });
        lower.diagonal_mut().fill(one());
        (lower, upper)
    }
}
impl<T: ComplexField> PartialPivLu<T> {
/// Computes the LU decomposition of `A` with partial (row) pivoting.
#[track_caller]
pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
let LU = A.to_owned();
Self::new_imp(LU)
}
/// Factorizes the owned copy in place, then unpacks the combined storage
/// into separate `L` and `U` factors via `split_LU`.
#[track_caller]
fn new_imp(mut LU: Mat<T>) -> Self {
let par = get_global_parallelism();
let (m, n) = LU.shape();
let mut row_perm_fwd = vec![0usize; m];
let mut row_perm_bwd = vec![0usize; m];
linalg::lu::partial_pivoting::factor::lu_in_place(
LU.as_mut(),
&mut row_perm_fwd,
&mut row_perm_bwd,
par,
// Scratch memory sized by the matching `_scratch` query.
MemStack::new(&mut MemBuffer::new(
linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<
usize,
T,
>(m, n, par, default()),
)),
default(),
);
let (L, U) = split_LU(LU);
Self {
L,
U,
// SAFETY(review): the permutation vectors were filled by the backend
// factorization — assumed to be mutually inverse permutations.
P: unsafe {
Perm::new_unchecked(
row_perm_fwd.into_boxed_slice(),
row_perm_bwd.into_boxed_slice(),
)
},
}
}
/// Returns the unit lower triangular factor.
pub fn L(&self) -> MatRef<'_, T> {
self.L.as_ref()
}
/// Returns the upper triangular factor.
pub fn U(&self) -> MatRef<'_, T> {
self.U.as_ref()
}
/// Returns the row permutation.
pub fn P(&self) -> PermRef<'_, usize> {
self.P.as_ref()
}
}
impl<T: ComplexField> FullPivLu<T> {
/// Computes the LU decomposition of `A` with full (row and column)
/// pivoting.
#[track_caller]
pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
let LU = A.to_owned();
Self::new_imp(LU)
}
/// Factorizes the owned copy in place, then unpacks the combined storage
/// into separate `L` and `U` factors via `split_LU`.
#[track_caller]
fn new_imp(mut LU: Mat<T>) -> Self {
let par = get_global_parallelism();
let (m, n) = LU.shape();
let mut row_perm_fwd = vec![0usize; m];
let mut row_perm_bwd = vec![0usize; m];
let mut col_perm_fwd = vec![0usize; n];
let mut col_perm_bwd = vec![0usize; n];
linalg::lu::full_pivoting::factor::lu_in_place(
LU.as_mut(),
&mut row_perm_fwd,
&mut row_perm_bwd,
&mut col_perm_fwd,
&mut col_perm_bwd,
par,
// Scratch memory sized by the matching `_scratch` query.
MemStack::new(&mut MemBuffer::new(
linalg::lu::full_pivoting::factor::lu_in_place_scratch::<
usize,
T,
>(m, n, par, default()),
)),
default(),
);
let (L, U) = split_LU(LU);
Self {
L,
U,
// SAFETY(review): permutation vectors filled by the backend —
// assumed to be mutually inverse permutations.
P: unsafe {
Perm::new_unchecked(
row_perm_fwd.into_boxed_slice(),
row_perm_bwd.into_boxed_slice(),
)
},
Q: unsafe {
Perm::new_unchecked(
col_perm_fwd.into_boxed_slice(),
col_perm_bwd.into_boxed_slice(),
)
},
}
}
/// Returns the unit lower triangular factor.
pub fn L(&self) -> MatRef<'_, T> {
self.L.as_ref()
}
/// Returns the upper triangular factor.
pub fn U(&self) -> MatRef<'_, T> {
self.U.as_ref()
}
/// Returns the row permutation.
pub fn P(&self) -> PermRef<'_, usize> {
self.P.as_ref()
}
/// Returns the column permutation.
pub fn Q(&self) -> PermRef<'_, usize> {
self.Q.as_ref()
}
}
impl<T: ComplexField> Qr<T> {
/// Computes the QR decomposition of `A` without pivoting.
#[track_caller]
pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
let QR = A.to_owned();
Self::new_imp(QR)
}
/// Factorizes the owned copy in place. The Householder basis and `R` share
/// the packed storage; `split_LU` separates them (the unit-lower output
/// serves as the reflector basis).
#[track_caller]
fn new_imp(mut QR: Mat<T>) -> Self {
let par = get_global_parallelism();
let (m, n) = QR.shape();
let size = Ord::min(m, n);
let block_size =
linalg::qr::no_pivoting::factor::recommended_block_size::<T>(m, n);
// Block Householder coefficients, one column per reflector.
let mut Q_coeff = Mat::zeros(block_size, size);
linalg::qr::no_pivoting::factor::qr_in_place(
QR.as_mut(),
Q_coeff.as_mut(),
par,
// Scratch memory sized by the matching `_scratch` query.
MemStack::new(&mut MemBuffer::new(
linalg::qr::no_pivoting::factor::qr_in_place_scratch::<T>(
m,
n,
block_size,
par,
default(),
),
)),
default(),
);
let (Q_basis, R) = split_LU(QR);
Self {
Q_basis,
Q_coeff,
R,
}
}
/// Returns the Householder reflector basis.
pub fn Q_basis(&self) -> MatRef<'_, T> {
self.Q_basis.as_ref()
}
/// Returns the block Householder coefficients.
pub fn Q_coeff(&self) -> MatRef<'_, T> {
self.Q_coeff.as_ref()
}
/// Returns the upper triangular factor `R`.
pub fn R(&self) -> MatRef<'_, T> {
self.R.as_ref()
}
/// Returns the top `min(nrows, ncols)` rows of `R` (the thin factor).
pub fn thin_R(&self) -> MatRef<'_, T> {
let size = Ord::min(self.nrows(), self.ncols());
self.R.get(..size, ..)
}
/// Materializes the full `Q` by applying the Householder sequence to the
/// identity.
pub fn compute_Q(&self) -> Mat<T> {
let mut Q = Mat::identity(self.nrows(), self.nrows());
let par = get_global_parallelism();
linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
self.Q_basis(),
self.Q_coeff(),
Conj::No,
Q.rb_mut(),
par,
MemStack::new(&mut MemBuffer::new(
linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
self.nrows(),
self.Q_coeff.nrows(),
self.nrows(),
),
)),
);
Q
}
/// Materializes the thin `Q` (first `min(nrows, ncols)` columns only).
pub fn compute_thin_Q(&self) -> Mat<T> {
let size = Ord::min(self.nrows(), self.ncols());
let mut Q = Mat::identity(self.nrows(), size);
let par = get_global_parallelism();
linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
self.Q_basis(),
self.Q_coeff(),
Conj::No,
Q.rb_mut(),
par,
MemStack::new(&mut MemBuffer::new(
linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(self.nrows(), self.Q_coeff.nrows(), size),
)),
);
Q
}
}
impl<T: ComplexField> ColPivQr<T> {
/// Computes the QR decomposition of `A` with column pivoting.
#[track_caller]
pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
let QR = A.to_owned();
Self::new_imp(QR)
}
/// Factorizes the owned copy in place; `split_LU` separates the packed
/// storage into the reflector basis and `R`.
#[track_caller]
fn new_imp(mut QR: Mat<T>) -> Self {
let par = get_global_parallelism();
let (m, n) = QR.shape();
let size = Ord::min(m, n);
let mut col_perm_fwd = vec![0usize; n];
let mut col_perm_bwd = vec![0usize; n];
// NOTE(review): block size is taken from `no_pivoting` — confirm that
// `col_pivoting` has no dedicated recommendation helper.
let block_size =
linalg::qr::no_pivoting::factor::recommended_block_size::<T>(m, n);
let mut Q_coeff = Mat::zeros(block_size, size);
linalg::qr::col_pivoting::factor::qr_in_place(
QR.as_mut(),
Q_coeff.as_mut(),
&mut col_perm_fwd,
&mut col_perm_bwd,
par,
// Scratch memory sized by the matching `_scratch` query.
MemStack::new(&mut MemBuffer::new(
linalg::qr::col_pivoting::factor::qr_in_place_scratch::<usize, T>(
m,
n,
block_size,
par,
default(),
),
)),
default(),
);
let (Q_basis, R) = split_LU(QR);
Self {
Q_basis,
Q_coeff,
R,
// SAFETY(review): permutation vectors filled by the backend —
// assumed to be mutually inverse permutations.
P: unsafe {
Perm::new_unchecked(
col_perm_fwd.into_boxed_slice(),
col_perm_bwd.into_boxed_slice(),
)
},
}
}
/// Returns the Householder reflector basis.
pub fn Q_basis(&self) -> MatRef<'_, T> {
self.Q_basis.as_ref()
}
/// Returns the block Householder coefficients.
pub fn Q_coeff(&self) -> MatRef<'_, T> {
self.Q_coeff.as_ref()
}
/// Returns the upper triangular factor `R`.
pub fn R(&self) -> MatRef<'_, T> {
self.R.as_ref()
}
/// Returns the top `min(nrows, ncols)` rows of `R` (the thin factor).
pub fn thin_R(&self) -> MatRef<'_, T> {
let size = Ord::min(self.nrows(), self.ncols());
self.R.get(..size, ..)
}
/// Materializes the full `Q` by applying the Householder sequence to the
/// identity.
pub fn compute_Q(&self) -> Mat<T> {
let mut Q = Mat::identity(self.nrows(), self.nrows());
let par = get_global_parallelism();
linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
self.Q_basis(),
self.Q_coeff(),
Conj::No,
Q.rb_mut(),
par,
MemStack::new(&mut MemBuffer::new(
linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
self.nrows(),
self.Q_coeff.nrows(),
self.nrows(),
),
)),
);
Q
}
/// Materializes the thin `Q` (first `min(nrows, ncols)` columns only).
pub fn compute_thin_Q(&self) -> Mat<T> {
let size = Ord::min(self.nrows(), self.ncols());
let mut Q = Mat::identity(self.nrows(), size);
let par = get_global_parallelism();
linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
self.Q_basis(),
self.Q_coeff(),
Conj::No,
Q.rb_mut(),
par,
MemStack::new(&mut MemBuffer::new(
linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(self.nrows(), self.Q_coeff.nrows(), size),
)),
);
Q
}
/// Returns the column permutation.
pub fn P(&self) -> PermRef<'_, usize> {
self.P.as_ref()
}
}
impl<T: ComplexField> Svd<T> {
/// Computes the full singular value decomposition of `A`.
#[track_caller]
pub fn new<C: Conjugate<Canonical = T>>(
A: MatRef<'_, C>,
) -> Result<Self, SvdError> {
Self::new_imp(A.canonical(), Conj::get::<C>(), false)
}
/// Computes the thin singular value decomposition of `A` (singular vector
/// matrices truncated to `min(nrows, ncols)` columns).
#[track_caller]
pub fn new_thin<C: Conjugate<Canonical = T>>(
A: MatRef<'_, C>,
) -> Result<Self, SvdError> {
Self::new_imp(A.canonical(), Conj::get::<C>(), true)
}
/// Runs the backend SVD on the canonical representation. If the original
/// element type was a conjugated wrapper (`conj == Conj::Yes`), `U` and
/// `V` are conjugated afterwards: `conj(A) = conj(U) * S * conj(V).adjoint()`.
#[track_caller]
fn new_imp(
A: MatRef<'_, T>,
conj: Conj,
thin: bool,
) -> Result<Self, SvdError> {
let par = get_global_parallelism();
let (m, n) = A.shape();
let size = Ord::min(m, n);
let mut U = Mat::zeros(m, if thin { size } else { m });
let mut V = Mat::zeros(n, if thin { size } else { n });
let mut S = Diag::zeros(size);
let compute = if thin {
ComputeSvdVectors::Thin
} else {
ComputeSvdVectors::Full
};
linalg::svd::svd(
A,
S.as_mut(),
Some(U.as_mut()),
Some(V.as_mut()),
par,
// Scratch memory sized by the matching `_scratch` query.
MemStack::new(&mut MemBuffer::new(linalg::svd::svd_scratch::<T>(
m,
n,
compute,
compute,
par,
default(),
))),
default(),
)?;
if conj == Conj::Yes {
// Compensate for the implicit conjugation of the input.
for c in U.col_iter_mut() {
for x in c.iter_mut() {
*x = x.conj();
}
}
for c in V.col_iter_mut() {
for x in c.iter_mut() {
*x = x.conj();
}
}
}
Ok(Self { U, V, S })
}
/// Returns the left singular vectors.
pub fn U(&self) -> MatRef<'_, T> {
self.U.as_ref()
}
/// Returns the right singular vectors.
pub fn V(&self) -> MatRef<'_, T> {
self.V.as_ref()
}
/// Returns the singular values.
pub fn S(&self) -> DiagRef<'_, T> {
self.S.as_ref()
}
/// Computes the Moore-Penrose pseudoinverse from the stored factors.
pub fn pseudoinverse(&self) -> Mat<T> {
let U = self.U();
let V = self.V();
let S = self.S();
let par = get_global_parallelism();
let stack = &mut MemBuffer::new(
linalg::svd::pseudoinverse_from_svd_scratch::<T>(
self.nrows(),
self.ncols(),
par,
),
);
// The pseudoinverse has the transposed shape of the original matrix.
let mut pinv = Mat::zeros(self.ncols(), self.nrows());
linalg::svd::pseudoinverse_from_svd(
pinv.rb_mut(),
S,
U,
V,
par,
MemStack::new(stack),
);
pinv
}
}
impl<T: ComplexField> SelfAdjointEigen<T> {
/// Computes the eigendecomposition of `A`, assumed self-adjoint; only the
/// triangle selected by `side` is read. For `Side::Upper`, the adjoint is
/// taken so the stored triangle becomes the lower one, and the conjugation
/// flag is flipped accordingly.
///
/// # Panics
/// Panics if `A` is not square.
#[track_caller]
pub fn new<C: Conjugate<Canonical = T>>(
A: MatRef<'_, C>,
side: Side,
) -> Result<Self, EvdError> {
assert!(A.nrows() == A.ncols());
match side {
Side::Lower => Self::new_imp(A.canonical(), Conj::get::<C>()),
Side::Upper => {
Self::new_imp(A.adjoint().canonical(), Conj::get::<C::Conj>())
},
}
}
/// Runs the backend self-adjoint EVD; if the input storage was conjugated,
/// the eigenvectors are conjugated afterwards (the eigenvalues are real
/// and unaffected).
#[track_caller]
fn new_imp(A: MatRef<'_, T>, conj: Conj) -> Result<Self, EvdError> {
let par = get_global_parallelism();
let n = A.nrows();
let mut U = Mat::zeros(n, n);
let mut S = Diag::zeros(n);
linalg::evd::self_adjoint_evd(
A,
S.as_mut(),
Some(U.as_mut()),
par,
// Scratch memory sized by the matching `_scratch` query.
MemStack::new(&mut MemBuffer::new(
linalg::evd::self_adjoint_evd_scratch::<T>(
n,
linalg::evd::ComputeEigenvectors::Yes,
par,
default(),
),
)),
default(),
)?;
if conj == Conj::Yes {
for c in U.col_iter_mut() {
for x in c.iter_mut() {
*x = x.conj();
}
}
}
Ok(Self { U, S })
}
/// Returns the eigenvectors.
pub fn U(&self) -> MatRef<'_, T> {
self.U.as_ref()
}
/// Returns the eigenvalues.
pub fn S(&self) -> DiagRef<'_, T> {
self.S.as_ref()
}
/// Computes the pseudoinverse from the stored eigendecomposition.
pub fn pseudoinverse(&self) -> Mat<T> {
let U = self.U();
let S = self.S();
let par = get_global_parallelism();
let stack = &mut MemBuffer::new(
linalg::evd::pseudoinverse_from_self_adjoint_evd_scratch::<T>(
self.nrows(),
par,
),
);
let mut pinv = Mat::zeros(self.ncols(), self.nrows());
linalg::evd::pseudoinverse_from_self_adjoint_evd(
pinv.rb_mut(),
S,
U,
par,
MemStack::new(stack),
);
pinv
}
}
/// Expands the packed real-EVD output (`U_real`, `S_re`, `S_im`) into complex
/// eigenvectors `U` and eigenvalues `S`.
///
/// Packing convention (as produced by `evd_real`): if `S_im[j] == 0`, the
/// eigenvalue is real and column `j` of `U_real` is its eigenvector.
/// Otherwise columns `j` and `j + 1` hold the real and imaginary parts of one
/// eigenvector of a complex-conjugate pair, and the pair occupies entries
/// `j` and `j + 1` of the output.
fn real_to_cplx<T: RealField>(
mut U: MatMut<'_, Complex<T>>,
mut S: DiagMut<'_, Complex<T>>,
U_real: MatRef<'_, T>,
S_re: DiagRef<'_, T>,
S_im: DiagRef<'_, T>,
) {
let n = U.ncols();
let mut j = 0;
while j < n {
if S_im[j] == zero() {
// Real eigenvalue: lift the column as-is.
S[j] = Complex::new(S_re[j].clone(), zero());
for i in 0..n {
U[(i, j)] = Complex::new(U_real[(i, j)].clone(), zero());
}
j += 1;
} else {
// Conjugate pair: eigenvalues re ± i*im, eigenvectors
// U_real[:, j] ± i * U_real[:, j + 1].
S[j] = Complex::new(S_re[j].clone(), S_im[j].clone());
S[j + 1] = Complex::new(S_re[j].clone(), -(&S_im[j]));
for i in 0..n {
U[(i, j)] = Complex::new(
U_real[(i, j)].clone(),
U_real[(i, j + 1)].clone(),
);
U[(i, j + 1)] = Complex::new(
U_real[(i, j)].clone(),
-(&U_real[(i, j + 1)]),
);
}
// Consumed two columns.
j += 2;
}
}
}
impl<T: RealField> Eigen<T> {
/// Computes the eigendecomposition of a complex matrix.
///
/// # Panics
/// Panics if `A` is not square.
#[track_caller]
pub fn new<C: Conjugate<Canonical = Complex<T>>>(
A: MatRef<'_, C>,
) -> Result<Self, EvdError> {
assert!(A.nrows() == A.ncols());
Self::new_imp(A.canonical(), Conj::get::<C>())
}
/// Computes the eigendecomposition of a real matrix; the result is stored
/// in complex form by expanding the backend's packed real output.
///
/// # Panics
/// Panics if `A` is not square.
#[track_caller]
pub fn new_from_real(A: MatRef<'_, T>) -> Result<Self, EvdError> {
assert!(A.nrows() == A.ncols());
let par = get_global_parallelism();
let n = A.nrows();
let mut U_real = Mat::zeros(n, n);
let mut S_re = Diag::zeros(n);
let mut S_im = Diag::zeros(n);
linalg::evd::evd_real(
A,
S_re.as_mut(),
S_im.as_mut(),
// Presumably (left, right) eigenvector slots: only the second is
// requested — TODO confirm against `linalg::evd::evd_real`.
None,
Some(U_real.as_mut()),
par,
MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<T>(
n,
linalg::evd::ComputeEigenvectors::No,
linalg::evd::ComputeEigenvectors::Yes,
par,
default(),
))),
default(),
)?;
let mut U = Mat::zeros(n, n);
let mut S = Diag::zeros(n);
// Expand the packed conjugate-pair representation into complex form.
real_to_cplx(
U.as_mut(),
S.as_mut(),
U_real.as_ref(),
S_re.as_ref(),
S_im.as_ref(),
);
Ok(Self { U, S })
}
/// Runs the backend complex EVD; squareness is asserted by the callers.
/// If the input storage was conjugated, both `U` and `S` are conjugated
/// afterwards to describe the actual matrix.
fn new_imp(
A: MatRef<'_, Complex<T>>,
conj: Conj,
) -> Result<Self, EvdError> {
let par = get_global_parallelism();
let n = A.nrows();
let mut U = Mat::zeros(n, n);
let mut S = Diag::zeros(n);
linalg::evd::evd_cplx(
A,
S.as_mut(),
None,
Some(U.as_mut()),
par,
MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<
Complex<T>,
>(
n,
linalg::evd::ComputeEigenvectors::No,
linalg::evd::ComputeEigenvectors::Yes,
par,
default(),
))),
default(),
)?;
if conj == Conj::Yes {
zip!(&mut U).for_each(|unzip!(c)| *c = c.conj());
zip!(&mut S).for_each(|unzip!(c)| *c = c.conj());
}
Ok(Self { U, S })
}
/// Returns the eigenvectors.
pub fn U(&self) -> MatRef<'_, Complex<T>> {
self.U.as_ref()
}
/// Returns the eigenvalues.
pub fn S(&self) -> DiagRef<'_, Complex<T>> {
self.S.as_ref()
}
}
impl<T: RealField> GeneralizedEigen<T> {
/// Computes the generalized eigendecomposition of the complex pair
/// `(A, B)`.
///
/// # Panics
/// Panics if `A` and `B` are not square with equal dimensions.
#[track_caller]
pub fn new<C: Conjugate<Canonical = Complex<T>>>(
A: MatRef<'_, C>,
B: MatRef<'_, C>,
) -> Result<Self, GevdError> {
let n = A.nrows();
assert!(all(
A.nrows() == n,
A.ncols() == n,
B.nrows() == n,
B.ncols() == n
));
Self::new_imp(A.canonical(), B.canonical(), Conj::get::<C>())
}
/// Computes the generalized eigendecomposition of the real pair `(A, B)`;
/// the result is stored in complex form.
///
/// # Panics
/// Panics if `A` and `B` are not square with equal dimensions.
#[track_caller]
pub fn new_from_real(
A: MatRef<'_, T>,
B: MatRef<'_, T>,
) -> Result<Self, GevdError> {
let n = A.nrows();
assert!(all(
A.nrows() == n,
A.ncols() == n,
B.nrows() == n,
B.ncols() == n
));
let par = get_global_parallelism();
let mut U_real = Mat::zeros(n, n);
let mut S_re = Diag::zeros(n);
let mut S_im = Diag::zeros(n);
let mut S_b = Diag::zeros(n);
// The backend factorizes in place, so work on owned copies.
let A = &mut A.cloned();
let B = &mut B.cloned();
linalg::gevd::gevd_real(
A.as_mut(),
B.as_mut(),
S_re.as_mut(),
S_im.as_mut(),
S_b.as_mut(),
// Presumably (left, right) eigenvector slots: only the second is
// requested — TODO confirm against `linalg::gevd::gevd_real`.
None,
Some(U_real.as_mut()),
par,
MemStack::new(&mut MemBuffer::new(
linalg::gevd::gevd_scratch::<T>(
n,
linalg::evd::ComputeEigenvectors::No,
linalg::evd::ComputeEigenvectors::Yes,
par,
default(),
),
)),
default(),
)?;
let mut U = Mat::zeros(n, n);
let mut S_a = Diag::zeros(n);
// Lift the real denominators to complex; the numerators and vectors
// are expanded from the packed conjugate-pair representation.
let S_b = zip!(&S_b).map(|unzip!(x)| Complex::new(x.clone(), zero()));
real_to_cplx(
U.as_mut(),
S_a.as_mut(),
U_real.as_ref(),
S_re.as_ref(),
S_im.as_ref(),
);
Ok(Self { U, S_a, S_b })
}
/// Runs the backend complex GEVD on owned copies of `A` and `B`; if the
/// input storage was conjugated, `U`, `S_a` and `S_b` are conjugated
/// afterwards.
fn new_imp(
A: MatRef<'_, Complex<T>>,
B: MatRef<'_, Complex<T>>,
conj: Conj,
) -> Result<Self, GevdError> {
let par = get_global_parallelism();
let n = A.nrows();
let mut U = Mat::zeros(n, n);
let mut S_a = Diag::zeros(n);
let mut S_b = Diag::zeros(n);
let A = &mut A.cloned();
let B = &mut B.cloned();
linalg::gevd::gevd_cplx(
A.as_mut(),
B.as_mut(),
S_a.as_mut(),
S_b.as_mut(),
None,
Some(U.as_mut()),
par,
MemStack::new(&mut MemBuffer::new(linalg::gevd::gevd_scratch::<
Complex<T>,
>(
n,
linalg::evd::ComputeEigenvectors::No,
linalg::evd::ComputeEigenvectors::Yes,
par,
default(),
))),
default(),
)?;
if conj == Conj::Yes {
zip!(&mut U).for_each(|unzip!(c)| *c = c.conj());
zip!(&mut S_a).for_each(|unzip!(c)| *c = c.conj());
zip!(&mut S_b).for_each(|unzip!(c)| *c = c.conj());
}
Ok(Self { U, S_a, S_b })
}
/// Returns the eigenvectors.
pub fn U(&self) -> MatRef<'_, Complex<T>> {
self.U.as_ref()
}
/// Returns the eigenvalue numerators.
pub fn S_a(&self) -> DiagRef<'_, Complex<T>> {
self.S_a.as_ref()
}
/// Returns the eigenvalue denominators.
pub fn S_b(&self) -> DiagRef<'_, Complex<T>> {
self.S_b.as_ref()
}
}
// The decomposed matrix is square; both dimensions come from the `L` factor.
impl<T: ComplexField> ShapeCore for Llt<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.L().ncols()
    }
}
// The decomposed matrix is square; both dimensions come from the `L` factor.
impl<T: ComplexField> ShapeCore for Ldlt<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.L().ncols()
    }
}
// The decomposed matrix is square; both dimensions come from the `L` factor.
impl<T: ComplexField> ShapeCore for Lblt<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.L().ncols()
    }
}
// Row count is read off the `L` factor, column count off the `U` factor.
impl<T: ComplexField> ShapeCore for PartialPivLu<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.U().ncols()
    }
}
// Row count is read off the `L` factor, column count off the `U` factor.
impl<T: ComplexField> ShapeCore for FullPivLu<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.L().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.U().ncols()
    }
}
// Row count is read off the Householder basis, column count off `R`.
impl<T: ComplexField> ShapeCore for Qr<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.Q_basis().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.R().ncols()
    }
}
// Row count is read off the Householder basis, column count off `R`.
impl<T: ComplexField> ShapeCore for ColPivQr<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.Q_basis().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.R().ncols()
    }
}
// Row count is read off `U`; the column count is defined as the number of
// rows of `V` (one row of `V` per column of the decomposed matrix).
impl<T: ComplexField> ShapeCore for Svd<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.U().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.V().nrows()
    }
}
// The decomposed matrix is square; both dimensions come from `U`'s row count.
impl<T: ComplexField> ShapeCore for SelfAdjointEigen<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.U().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.U().nrows()
    }
}
// The decomposed matrix is square; both dimensions come from `U`'s row count.
impl<T: RealField> ShapeCore for Eigen<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.U().nrows()
    }
    #[inline]
    fn ncols(&self) -> usize {
        self.U().nrows()
    }
}
impl<T: ComplexField> SolveCore<T> for Llt<T> {
    /// Solves `A x = rhs` in place, where `A = L L^H`.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        // scratch buffer sized by the kernel's own scratch query
        let mut mem = MemBuffer::new(
            linalg::cholesky::llt::solve::solve_in_place_scratch::<T>(
                self.L.nrows(),
                rhs.ncols(),
                par,
            ),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::llt::solve::solve_in_place_with_conj(
            self.L.as_ref(),
            conj,
            rhs,
            par,
            stack,
        );
    }
    /// Solves `A^T x = rhs` in place. `A = L L^H` is self-adjoint, so
    /// `A^T = conj(A)`: an extra conjugation is composed into the solve.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(
        &self,
        conj: Conj,
        rhs: MatMut<'_, T>,
    ) {
        let par = get_global_parallelism();
        let mut mem = MemBuffer::new(
            linalg::cholesky::llt::solve::solve_in_place_scratch::<T>(
                self.L.nrows(),
                rhs.ncols(),
                par,
            ),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::llt::solve::solve_in_place_with_conj(
            self.L.as_ref(),
            conj.compose(Conj::Yes),
            rhs,
            par,
            stack,
        );
    }
}
/// Forces the square matrix `A` to be exactly self-adjoint, treating the
/// lower triangle as the source of truth: the diagonal is projected onto the
/// reals and each strictly-upper entry is overwritten with the conjugate of
/// its mirror below the diagonal.
///
/// # Panics
/// Panics if `A` is not square.
fn make_self_adjoint<T: ComplexField>(mut A: MatMut<'_, T>) {
    assert!(A.nrows() == A.ncols());
    let dim = A.nrows();
    for col in 0..dim {
        // a self-adjoint matrix has a real diagonal
        A[(col, col)] = A[(col, col)].as_real();
        // mirror the strictly-lower triangle into the upper one, conjugated
        for row in 0..col {
            A[(row, col)] = A[(col, row)].conj();
        }
    }
}
impl<T: ComplexField> DenseSolveCore<T> for Llt<T> {
    /// Recomputes the original matrix `A = L L^H` from the factorization.
    #[track_caller]
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);
        let mut mem = MemBuffer::new(
            linalg::cholesky::llt::reconstruct::reconstruct_scratch::<T>(
                n, par,
            ),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::llt::reconstruct::reconstruct(
            out.as_mut(),
            self.L(),
            par,
            stack,
        );
        // rebuild the upper triangle from the lower so the result is exactly
        // self-adjoint
        make_self_adjoint(out.as_mut());
        out
    }
    /// Computes `A^{-1}` from the factorization.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);
        let mut mem = MemBuffer::new(
            linalg::cholesky::llt::inverse::inverse_scratch::<T>(n, par),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::llt::inverse::inverse(
            out.as_mut(),
            self.L(),
            par,
            stack,
        );
        // the inverse of a self-adjoint matrix is self-adjoint
        make_self_adjoint(out.as_mut());
        out
    }
}
impl<T: ComplexField> SolveCore<T> for Ldlt<T> {
    /// Solves `A x = rhs` in place, where `A = L D L^H`.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        let mut mem = MemBuffer::new(
            linalg::cholesky::ldlt::solve::solve_in_place_scratch::<T>(
                self.L.nrows(),
                rhs.ncols(),
                par,
            ),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::ldlt::solve::solve_in_place_with_conj(
            self.L.as_ref(),
            self.D.as_ref(),
            conj,
            rhs,
            par,
            stack,
        );
    }
    /// Solves `A^T x = rhs` in place. `A = L D L^H` is self-adjoint, so
    /// `A^T = conj(A)`: an extra conjugation is composed into the solve.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(
        &self,
        conj: Conj,
        rhs: MatMut<'_, T>,
    ) {
        let par = get_global_parallelism();
        let mut mem = MemBuffer::new(
            linalg::cholesky::ldlt::solve::solve_in_place_scratch::<T>(
                self.L.nrows(),
                rhs.ncols(),
                par,
            ),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::ldlt::solve::solve_in_place_with_conj(
            self.L(),
            self.D(),
            conj.compose(Conj::Yes),
            rhs,
            par,
            stack,
        );
    }
}
impl<T: ComplexField> DenseSolveCore<T> for Ldlt<T> {
    /// Recomputes the original matrix `A = L D L^H` from the factorization.
    #[track_caller]
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);
        let mut mem = MemBuffer::new(
            linalg::cholesky::ldlt::reconstruct::reconstruct_scratch::<T>(
                n, par,
            ),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::ldlt::reconstruct::reconstruct(
            out.as_mut(),
            self.L(),
            self.D(),
            par,
            stack,
        );
        // rebuild the upper triangle from the lower so the result is exactly
        // self-adjoint
        make_self_adjoint(out.as_mut());
        out
    }
    /// Computes `A^{-1}` from the factorization.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);
        let mut mem = MemBuffer::new(
            linalg::cholesky::ldlt::inverse::inverse_scratch::<T>(n, par),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::ldlt::inverse::inverse(
            out.as_mut(),
            self.L(),
            self.D(),
            par,
            stack,
        );
        // the inverse of a self-adjoint matrix is self-adjoint
        make_self_adjoint(out.as_mut());
        out
    }
}
impl<T: ComplexField> SolveCore<T> for Lblt<T> {
    /// Solves `A x = rhs` in place, using the pivoted `L B L^H`
    /// factorization (`B` stored as a diagonal plus subdiagonal, with
    /// permutation `P`).
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        let mut mem = MemBuffer::new(
            linalg::cholesky::lblt::solve::solve_in_place_scratch::<usize, T>(
                self.L.nrows(),
                rhs.ncols(),
                par,
            ),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::lblt::solve::solve_in_place_with_conj(
            self.L.as_ref(),
            self.B_diag(),
            self.B_subdiag(),
            conj,
            self.P(),
            rhs,
            par,
            stack,
        );
    }
    /// Solves `A^T x = rhs` in place. The factored matrix is self-adjoint,
    /// so `A^T = conj(A)`: an extra conjugation is composed into the solve.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(
        &self,
        conj: Conj,
        rhs: MatMut<'_, T>,
    ) {
        let par = get_global_parallelism();
        let mut mem = MemBuffer::new(
            linalg::cholesky::lblt::solve::solve_in_place_scratch::<usize, T>(
                self.L.nrows(),
                rhs.ncols(),
                par,
            ),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::lblt::solve::solve_in_place_with_conj(
            self.L(),
            self.B_diag(),
            self.B_subdiag(),
            conj.compose(Conj::Yes),
            self.P(),
            rhs,
            par,
            stack,
        );
    }
}
impl<T: ComplexField> DenseSolveCore<T> for Lblt<T> {
    /// Recomputes the original matrix from the pivoted `L B L^H` factorization.
    #[track_caller]
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);
        let mut mem = MemBuffer::new(
            linalg::cholesky::lblt::reconstruct::reconstruct_scratch::<usize, T>(
                n, par,
            ),
        );
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::lblt::reconstruct::reconstruct(
            out.as_mut(),
            self.L(),
            self.B_diag(),
            self.B_subdiag(),
            self.P(),
            par,
            stack,
        );
        // rebuild the upper triangle from the lower so the result is exactly
        // self-adjoint
        make_self_adjoint(out.as_mut());
        out
    }
    /// Computes `A^{-1}` from the factorization.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let n = self.L.nrows();
        let mut out = Mat::zeros(n, n);
        let mut mem =
            MemBuffer::new(linalg::cholesky::lblt::inverse::inverse_scratch::<
                usize,
                T,
            >(n, par));
        let stack = MemStack::new(&mut mem);
        linalg::cholesky::lblt::inverse::inverse(
            out.as_mut(),
            self.L(),
            self.B_diag(),
            self.B_subdiag(),
            self.P(),
            par,
            stack,
        );
        // the inverse of a self-adjoint matrix is self-adjoint
        make_self_adjoint(out.as_mut());
        out
    }
}
impl<T: ComplexField> SolveCore<T> for PartialPivLu<T> {
    /// Solves `A x = rhs` in place using the `L`, `U` factors and the row
    /// permutation `P`.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.nrows() == rhs.nrows(),
        ));
        let k = rhs.ncols();
        linalg::lu::partial_pivoting::solve::solve_in_place_with_conj(
            self.L(),
            self.U(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::partial_pivoting::solve::solve_in_place_scratch::<
                    usize,
                    T,
                >(self.nrows(), k, par),
            )),
        );
    }
    /// Solves `A^T x = rhs` in place using the dedicated transpose kernel.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(
        &self,
        conj: Conj,
        rhs: MatMut<'_, T>,
    ) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.ncols() == rhs.nrows(),
        ));
        let k = rhs.ncols();
        linalg::lu::partial_pivoting::solve::solve_transpose_in_place_with_conj(
            self.L(),
            self.U(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::partial_pivoting::solve::solve_transpose_in_place_scratch::<usize, T>(self.nrows(), k, par),
            )),
        );
    }
}
impl<T: ComplexField> DenseSolveCore<T> for PartialPivLu<T> {
    /// Recomputes the original matrix from `L`, `U` and the row permutation `P`.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();
        let mut out = Mat::zeros(m, n);
        linalg::lu::partial_pivoting::reconstruct::reconstruct(
            out.as_mut(),
            self.L(),
            self.U(),
            self.P(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::partial_pivoting::reconstruct::reconstruct_scratch::<
                    usize,
                    T,
                >(m, n, par),
            )),
        );
        out
    }
    /// Computes `A^{-1}` from the factorization.
    ///
    /// # Panics
    /// Panics if the decomposed matrix is not square.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        assert!(self.nrows() == self.ncols());
        let n = self.ncols();
        let mut out = Mat::zeros(n, n);
        linalg::lu::partial_pivoting::inverse::inverse(
            out.as_mut(),
            self.L(),
            self.U(),
            self.P(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::partial_pivoting::inverse::inverse_scratch::<
                    usize,
                    T,
                >(n, par),
            )),
        );
        out
    }
}
impl<T: ComplexField> SolveCore<T> for FullPivLu<T> {
    /// Solves `A x = rhs` in place using the `L`, `U` factors with the row
    /// permutation `P` and column permutation `Q`.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.nrows() == rhs.nrows(),
        ));
        let k = rhs.ncols();
        linalg::lu::full_pivoting::solve::solve_in_place_with_conj(
            self.L(),
            self.U(),
            self.P(),
            self.Q(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::full_pivoting::solve::solve_in_place_scratch::<
                    usize,
                    T,
                >(self.nrows(), k, par),
            )),
        );
    }
    /// Solves `A^T x = rhs` in place using the dedicated transpose kernel.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(
        &self,
        conj: Conj,
        rhs: MatMut<'_, T>,
    ) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.ncols() == rhs.nrows(),
        ));
        let k = rhs.ncols();
        linalg::lu::full_pivoting::solve::solve_transpose_in_place_with_conj(
            self.L(),
            self.U(),
            self.P(),
            self.Q(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::solve::solve_transpose_in_place_scratch::<
                usize,
                T,
            >(self.nrows(), k, par))),
        );
    }
}
impl<T: ComplexField> DenseSolveCore<T> for FullPivLu<T> {
    /// Recomputes the original matrix from `L`, `U` and the permutations
    /// `P` (rows) and `Q` (columns).
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();
        let mut out = Mat::zeros(m, n);
        linalg::lu::full_pivoting::reconstruct::reconstruct(
            out.as_mut(),
            self.L(),
            self.U(),
            self.P(),
            self.Q(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::full_pivoting::reconstruct::reconstruct_scratch::<
                    usize,
                    T,
                >(m, n, par),
            )),
        );
        out
    }
    /// Computes `A^{-1}` from the factorization.
    ///
    /// # Panics
    /// Panics if the decomposed matrix is not square.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        assert!(self.nrows() == self.ncols());
        let n = self.ncols();
        let mut out = Mat::zeros(n, n);
        linalg::lu::full_pivoting::inverse::inverse(
            out.as_mut(),
            self.L(),
            self.U(),
            self.P(),
            self.Q(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::full_pivoting::inverse::inverse_scratch::<usize, T>(
                    n, par,
                ),
            )),
        );
        out
    }
}
impl<T: ComplexField> SolveCore<T> for Qr<T> {
    /// Solves `A x = rhs` in place using `A = Q R`, with `Q` kept in factored
    /// Householder form (`Q_basis` + block coefficients `Q_coeff`).
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.nrows() == rhs.nrows(),
        ));
        let n = self.nrows();
        // Householder block size, as stored in the coefficient matrix
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();
        linalg::qr::no_pivoting::solve::solve_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::no_pivoting::solve::solve_in_place_scratch::<T>(
                    n, block_size, k, par,
                ),
            )),
        );
    }
    /// Solves `A^T x = rhs` in place using the dedicated transpose kernel.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(
        &self,
        conj: Conj,
        rhs: MatMut<'_, T>,
    ) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.ncols() == rhs.nrows(),
        ));
        let n = self.nrows();
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();
        linalg::qr::no_pivoting::solve::solve_transpose_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::no_pivoting::solve::solve_transpose_in_place_scratch::<T>(n, block_size, k, par),
            )),
        );
    }
}
impl<T: ComplexField> SolveLstsqCore<T> for Qr<T> {
    /// Least-squares solve in place for a tall system (`nrows >= ncols`).
    ///
    /// # Panics
    /// Panics if `rhs` has a mismatched row count or the matrix is wider than
    /// it is tall.
    #[track_caller]
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == rhs.nrows(),
            self.nrows() >= self.ncols(),
        ));
        let m = self.nrows();
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();
        linalg::qr::no_pivoting::solve::solve_lstsq_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::no_pivoting::solve::solve_lstsq_in_place_scratch::<T>(
                    m, n, block_size, k, par,
                ),
            )),
        );
    }
}
impl<T: ComplexField> DenseSolveCore<T> for Qr<T> {
    /// Recomputes the original matrix as `Q R` from the factored form.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();
        let mut out = Mat::zeros(m, n);
        linalg::qr::no_pivoting::reconstruct::reconstruct(
            out.as_mut(),
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::no_pivoting::reconstruct::reconstruct_scratch::<T>(
                    m, n, block_size, par,
                ),
            )),
        );
        out
    }
    /// Computes `A^{-1}` from the factorization.
    ///
    /// # Panics
    /// Panics if the decomposed matrix is not square.
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        assert!(self.nrows() == self.ncols());
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();
        let mut out = Mat::zeros(n, n);
        linalg::qr::no_pivoting::inverse::inverse(
            out.as_mut(),
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::no_pivoting::inverse::inverse_scratch::<T>(
                    n, block_size, par,
                ),
            )),
        );
        out
    }
}
impl<T: ComplexField> SolveCore<T> for ColPivQr<T> {
    /// Solves `A x = rhs` in place using the column-pivoted QR factorization
    /// (`Q` in factored Householder form, column permutation `P`).
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.nrows() == rhs.nrows(),
        ));
        let n = self.nrows();
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();
        linalg::qr::col_pivoting::solve::solve_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::col_pivoting::solve::solve_in_place_scratch::<
                    usize,
                    T,
                >(n, block_size, k, par),
            )),
        );
    }
    /// Solves `A^T x = rhs` in place using the dedicated transpose kernel.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(
        &self,
        conj: Conj,
        rhs: MatMut<'_, T>,
    ) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.ncols() == rhs.nrows(),
        ));
        let n = self.nrows();
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();
        linalg::qr::col_pivoting::solve::solve_transpose_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_transpose_in_place_scratch::<
                usize,
                T,
            >(n, block_size, k, par))),
        );
    }
}
impl<T: ComplexField> SolveLstsqCore<T> for ColPivQr<T> {
    /// Least-squares solve in place for a tall system (`nrows >= ncols`).
    ///
    /// # Panics
    /// Panics if `rhs` has a mismatched row count or the matrix is wider than
    /// it is tall.
    #[track_caller]
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == rhs.nrows(),
            self.nrows() >= self.ncols(),
        ));
        let m = self.nrows();
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();
        let k = rhs.ncols();
        linalg::qr::col_pivoting::solve::solve_lstsq_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            conj,
            rhs,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::col_pivoting::solve::solve_lstsq_in_place_scratch::<
                    usize,
                    T,
                >(m, n, block_size, k, par),
            )),
        );
    }
}
impl<T: ComplexField> DenseSolveCore<T> for ColPivQr<T> {
    /// Recomputes the original matrix from `Q`, `R` and the column
    /// permutation `P`.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();
        let mut out = Mat::zeros(m, n);
        linalg::qr::col_pivoting::reconstruct::reconstruct(
            out.as_mut(),
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::col_pivoting::reconstruct::reconstruct_scratch::<
                    usize,
                    T,
                >(m, n, block_size, par),
            )),
        );
        out
    }
    /// Computes `A^{-1}` from the factorization.
    ///
    /// # Panics
    /// Panics if the decomposed matrix is not square.
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        assert!(self.nrows() == self.ncols());
        let n = self.ncols();
        let block_size = self.Q_coeff().nrows();
        let mut out = Mat::zeros(n, n);
        linalg::qr::col_pivoting::inverse::inverse(
            out.as_mut(),
            self.Q_basis(),
            self.Q_coeff(),
            self.R(),
            self.P(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::qr::col_pivoting::inverse::inverse_scratch::<usize, T>(
                    n, block_size, par,
                ),
            )),
        );
        out
    }
}
impl<T: ComplexField> SolveCore<T> for Svd<T> {
    /// Solves `A x = rhs` in place using `A = U S V^H`, i.e.
    /// `x = V S^{-1} U^H rhs`.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.nrows() == rhs.nrows(),
        ));
        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);
        // tmp = U^H * rhs (the extra conjugation composed with `conj` turns
        // the transpose into an adjoint)
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.U().transpose(),
            conj.compose(Conj::Yes),
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );
        // scale row i of tmp by 1 / S[i]; the reciprocal depends only on the
        // row, so compute it once per row instead of once per entry
        for i in 0..n {
            let s = recip(&real(&self.S()[i]));
            for j in 0..k {
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }
        // rhs = V * tmp
        linalg::matmul::matmul_with_conj(
            rhs.as_mut(),
            Accum::Replace,
            self.V(),
            conj,
            tmp.as_ref(),
            Conj::No,
            one(),
            par,
        );
    }
    /// Solves `A^T x = rhs` in place, i.e. `x = conj(U) S^{-1} V^T rhs`.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(
        &self,
        conj: Conj,
        rhs: MatMut<'_, T>,
    ) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.ncols() == rhs.nrows(),
        ));
        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);
        // tmp = V^T * rhs
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.V().transpose(),
            conj,
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );
        // scale row i of tmp by 1 / S[i], computed once per row
        for i in 0..n {
            let s = recip(&real(&self.S()[i]));
            for j in 0..k {
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }
        // rhs = conj(U) * tmp
        linalg::matmul::matmul_with_conj(
            rhs.as_mut(),
            Accum::Replace,
            self.U(),
            conj.compose(Conj::Yes),
            tmp.as_ref(),
            Conj::No,
            one(),
            par,
        );
    }
}
impl<T: ComplexField> SolveLstsqCore<T> for Svd<T> {
    /// Least-squares solve in place via the thin SVD: the solution
    /// `V S^{-1} U^H rhs` overwrites the first `ncols` rows of `rhs`
    /// (remaining rows are left untouched).
    ///
    /// # Panics
    /// Panics if `rhs` has a mismatched row count or the matrix is wider than
    /// it is tall.
    #[track_caller]
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == rhs.nrows(),
            self.nrows() >= self.ncols(),
        ));
        let m = self.nrows();
        let n = self.ncols();
        // thin factors: only the leading `size` singular triplets participate
        let size = Ord::min(m, n);
        let U = self.U().get(.., ..size);
        let V = self.V().get(.., ..size);
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(size, k);
        // tmp = U^H * rhs
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            U.transpose(),
            conj.compose(Conj::Yes),
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );
        // scale row i of tmp by 1 / S[i]; the reciprocal depends only on the
        // row, so compute it once per row instead of once per entry
        for i in 0..size {
            let s = recip(&real(&self.S()[i]));
            for j in 0..k {
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }
        // rhs[..size, ..] = V * tmp
        linalg::matmul::matmul_with_conj(
            rhs.get_mut(..size, ..),
            Accum::Replace,
            V,
            conj,
            tmp.as_ref(),
            Conj::No,
            one(),
            par,
        );
    }
}
impl<T: ComplexField> DenseSolveCore<T> for Svd<T> {
    /// Recomputes the original matrix as the thin product `U S V^H`
    /// over `size = min(m, n)` singular triplets.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();
        let size = Ord::min(m, n);
        let U = self.U().get(.., ..size);
        let V = self.V().get(.., ..size);
        let S = self.S();
        // UxS = U * S (column j of U scaled by the j-th singular value)
        let mut UxS = Mat::zeros(m, size);
        for j in 0..size {
            let s = real(&S[j]);
            for i in 0..m {
                UxS[(i, j)] = mul_real(&U[(i, j)], &s);
            }
        }
        let mut out = Mat::zeros(m, n);
        linalg::matmul::matmul(
            out.as_mut(),
            Accum::Replace,
            UxS.as_ref(),
            V.adjoint(),
            one(),
            par,
        );
        out
    }
    /// Computes `A^{-1} = V S^{-1} U^H`.
    ///
    /// # Panics
    /// Panics if the decomposed matrix is not square.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        assert!(self.nrows() == self.ncols());
        let n = self.nrows();
        let U = self.U();
        let V = self.V();
        let S = self.S();
        // VxS = V * S^{-1} (column j of V scaled by 1 / S[j])
        let mut VxS = Mat::zeros(n, n);
        for j in 0..n {
            let s = recip(&real(&S[j]));
            for i in 0..n {
                VxS[(i, j)] = mul_real(&V[(i, j)], &s);
            }
        }
        let mut out = Mat::zeros(n, n);
        linalg::matmul::matmul(
            out.as_mut(),
            Accum::Replace,
            VxS.as_ref(),
            U.adjoint(),
            one(),
            par,
        );
        out
    }
}
impl<T: ComplexField> SolveCore<T> for SelfAdjointEigen<T> {
    /// Solves `A x = rhs` in place using `A = U S U^H`, i.e.
    /// `x = U S^{-1} U^H rhs`.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.nrows() == rhs.nrows(),
        ));
        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);
        // tmp = U^H * rhs (the extra conjugation composed with `conj` turns
        // the transpose into an adjoint)
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.U().transpose(),
            conj.compose(Conj::Yes),
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );
        // scale row i of tmp by the reciprocal eigenvalue; the reciprocal
        // depends only on the row, so compute it once per row instead of once
        // per entry
        for i in 0..n {
            let s = recip(&real(&self.S()[i]));
            for j in 0..k {
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }
        // rhs = U * tmp
        linalg::matmul::matmul_with_conj(
            rhs.as_mut(),
            Accum::Replace,
            self.U(),
            conj,
            tmp.as_ref(),
            Conj::No,
            one(),
            par,
        );
    }
    /// Solves `A^T x = rhs` in place, i.e. `x = conj(U) S^{-1} U^T rhs`.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `rhs` has a mismatched row count.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(
        &self,
        conj: Conj,
        rhs: MatMut<'_, T>,
    ) {
        let par = get_global_parallelism();
        assert!(all(
            self.nrows() == self.ncols(),
            self.ncols() == rhs.nrows(),
        ));
        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);
        // tmp = U^T * rhs
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.U().transpose(),
            conj,
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );
        // scale row i of tmp by the reciprocal eigenvalue, computed once per row
        for i in 0..n {
            let s = recip(&real(&self.S()[i]));
            for j in 0..k {
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }
        // rhs = conj(U) * tmp
        linalg::matmul::matmul_with_conj(
            rhs.as_mut(),
            Accum::Replace,
            self.U(),
            conj.compose(Conj::Yes),
            tmp.as_ref(),
            Conj::No,
            one(),
            par,
        );
    }
}
impl<T: ComplexField> DenseSolveCore<T> for SelfAdjointEigen<T> {
    /// Recomputes the original matrix as `U S U^H` (here `V` is just an
    /// alias for `U`, mirroring the SVD code path).
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();
        // m == n for this decomposition, so `size` is always the full dimension
        let size = Ord::min(m, n);
        let U = self.U().get(.., ..size);
        let V = self.U().get(.., ..size);
        let S = self.S();
        // UxS = U * S (column j of U scaled by the j-th eigenvalue)
        let mut UxS = Mat::zeros(m, size);
        for j in 0..size {
            let s = real(&S[j]);
            for i in 0..m {
                UxS[(i, j)] = mul_real(&U[(i, j)], &s);
            }
        }
        let mut out = Mat::zeros(m, n);
        linalg::matmul::matmul(
            out.as_mut(),
            Accum::Replace,
            UxS.as_ref(),
            V.adjoint(),
            one(),
            par,
        );
        out
    }
    /// Computes `A^{-1} = U S^{-1} U^H`.
    ///
    /// # Panics
    /// Panics if the decomposed matrix is not square.
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();
        assert!(self.nrows() == self.ncols());
        let n = self.nrows();
        let U = self.U();
        let V = self.U();
        let S = self.S();
        // VxS = U * S^{-1} (column j of U scaled by 1 / S[j])
        let mut VxS = Mat::zeros(n, n);
        for j in 0..n {
            let s = recip(&real(&S[j]));
            for i in 0..n {
                VxS[(i, j)] = mul_real(&V[(i, j)], &s);
            }
        }
        let mut out = Mat::zeros(n, n);
        linalg::matmul::matmul(
            out.as_mut(),
            Accum::Replace,
            VxS.as_ref(),
            U.adjoint(),
            one(),
            par,
        );
        out
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::assert;
use crate::stats::prelude::*;
use crate::utils::approx::*;
#[track_caller]
fn test_solver(A: MatRef<'_, c64>, A_dec: impl SolveCore<c64>) {
#[track_caller]
fn test_solver_imp(A: MatRef<'_, c64>, A_dec: &dyn SolveCore<c64>) {
let rng = &mut StdRng::seed_from_u64(0xC0FFEE);
let n = A.nrows();
let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));
let k = 3;
let ref R = CwiseMatDistribution {
nrows: n,
ncols: k,
dist: ComplexDistribution::new(StandardNormal, StandardNormal),
}
.rand::<Mat<c64>>(rng);
let ref L = CwiseMatDistribution {
nrows: k,
ncols: n,
dist: ComplexDistribution::new(StandardNormal, StandardNormal),
}
.rand::<Mat<c64>>(rng);
assert!(A * A_dec.solve(R) ~ R);
assert!(A.conjugate() * A_dec.solve_conjugate(R) ~ R);
assert!(A.transpose() * A_dec.solve_transpose(R) ~ R);
assert!(A.adjoint() * A_dec.solve_adjoint(R) ~ R);
assert!(A_dec.rsolve(L) * A ~ L);
assert!(A_dec.rsolve_conjugate(L) * A.conjugate() ~ L);
assert!(A_dec.rsolve_transpose(L) * A.transpose() ~ L);
assert!(A_dec.rsolve_adjoint(L) * A.adjoint() ~ L);
}
test_solver_imp(A, &A_dec)
}
#[test]
fn test_all_solvers() {
let rng = &mut StdRng::seed_from_u64(0);
let n = 50;
let ref A = CwiseMatDistribution {
nrows: n,
ncols: n,
dist: ComplexDistribution::new(StandardNormal, StandardNormal),
}
.rand::<Mat<c64>>(rng);
let A = A.rb();
test_solver(A, A.partial_piv_lu());
test_solver(A, A.full_piv_lu());
test_solver(A, A.qr());
test_solver(A, A.col_piv_qr());
test_solver(A, A.svd().unwrap());
{
let ref A = A * A.adjoint();
let A = A.rb();
test_solver(A, A.llt(Side::Lower).unwrap());
test_solver(A, A.ldlt(Side::Lower).unwrap());
}
{
let ref A = A + A.adjoint();
let A = A.rb();
test_solver(A, A.lblt(Side::Lower));
test_solver(A, A.self_adjoint_eigen(Side::Lower).unwrap());
}
}
#[test]
fn test_eigen_cplx() {
let rng = &mut StdRng::seed_from_u64(0);
let n = 50;
let A = CwiseMatDistribution {
nrows: n,
ncols: n,
dist: ComplexDistribution::new(StandardNormal, StandardNormal),
}
.rand::<Mat<c64>>(rng);
let n = A.nrows();
let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));
{
let evd = A.eigen().unwrap();
let e = A.eigenvalues().unwrap();
assert!(& A * evd.U() ~ evd.U() * evd.S());
assert!(evd.S().column_vector() ~ ColRef::from_slice(& e));
}
{
let evd = A.conjugate().eigen().unwrap();
let e = A.conjugate().eigenvalues().unwrap();
assert!(A.conjugate() * evd.U() ~ evd.U() * evd.S());
assert!(evd.S().column_vector() ~ ColRef::from_slice(& e));
}
}
#[test]
fn test_geigen_cplx() {
let rng = &mut StdRng::seed_from_u64(0);
let n = 50;
let A = CwiseMatDistribution {
nrows: n,
ncols: n,
dist: ComplexDistribution::new(StandardNormal, StandardNormal),
}
.rand::<Mat<c64>>(rng);
let B = CwiseMatDistribution {
nrows: n,
ncols: n,
dist: ComplexDistribution::new(StandardNormal, StandardNormal),
}
.rand::<Mat<c64>>(rng);
let n = A.nrows();
let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));
{
let evd = A.generalized_eigen(&B).unwrap();
let e = zip!(evd.S_a(), evd.S_b()).map(|unzip!(a, b)| a / b);
assert!(& A * evd.U() ~ & B * evd.U() * e);
}
{
let evd = A.conjugate().generalized_eigen(B.conjugate()).unwrap();
let e = zip!(evd.S_a(), evd.S_b()).map(|unzip!(a, b)| a / b);
assert!(A.conjugate() * evd.U() ~ B.conjugate() * evd.U() * e);
}
}
#[test]
fn test_eigen_real() {
let rng = &mut StdRng::seed_from_u64(0);
let n = 50;
let A = CwiseMatDistribution {
nrows: n,
ncols: n,
dist: StandardNormal,
}
.rand::<Mat<f64>>(rng);
let n = A.nrows();
let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));
let evd = A.eigen().unwrap();
let e = A.eigenvalues().unwrap();
let A = Mat::from_fn(A.nrows(), A.ncols(), |i, j| c64::from(A[(i, j)]));
assert!(& A * evd.U() ~ evd.U() * evd.S());
assert!(evd.S().column_vector() ~ ColRef::from_slice(& e));
}
#[test]
fn test_geigen_real() {
let rng = &mut StdRng::seed_from_u64(0);
let n = 50;
let A = CwiseMatDistribution {
nrows: n,
ncols: n,
dist: StandardNormal,
}
.rand::<Mat<f64>>(rng);
let B = CwiseMatDistribution {
nrows: n,
ncols: n,
dist: StandardNormal,
}
.rand::<Mat<f64>>(rng);
let n = A.nrows();
let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));
let Ac = zip!(&A).map(|unzip!(x)| c64::new(*x, 0.0));
let Bc = zip!(&B).map(|unzip!(x)| c64::new(*x, 0.0));
{
let evd = A.generalized_eigen(&B).unwrap();
let e = zip!(evd.S_a(), evd.S_b()).map(|unzip!(a, b)| a / b);
assert!(& Ac * evd.U() ~ & Bc * evd.U() * e);
}
}
#[test]
fn test_svd_solver_for_rectangular_matrix() {
#[rustfmt::skip]
let A = crate::mat![
[4., 5., 7.], [8., 8., 2.], [4., 0., 9.], [2., 6., 2.], [0., 6., 0.],
];
#[rustfmt::skip]
let B = crate::mat![
[105., 49.], [98., 54.], [113., 35.], [46., 34.], [12., 24.],
];
#[rustfmt::skip]
let X_true = crate::mat![[8., 2.], [2., 4.], [9., 3.],];
let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (A.nrows() as f64));
let svd = A.svd().unwrap();
let mut X = B.cloned();
svd.solve_lstsq_in_place_with_conj(crate::Conj::No, X.as_mat_mut());
assert!(X.get(..X_true.nrows(),..) ~ X_true);
}
}