use ndarray::*;
use num_traits::{Float, One, Zero};
use crate::convert::*;
use crate::error::*;
use crate::layout::*;
use crate::types::*;
pub use lax::{Pivot, UPLO};
/// Solve linear systems `A * x = b` whose coefficient matrix `A` is Hermitian
/// (or real symmetric).
///
/// Only [`SolveH::solveh_inplace`] must be implemented; `solveh` and `solveh_into`
/// are thin wrappers around it.
pub trait SolveH<A: Scalar> {
    /// Solves `A * x = b`, overwriting `b` with `x` and handing back the same reference.
    fn solveh_inplace<'a, S: DataMut<Elem = A>>(
        &self,
        b: &'a mut ArrayBase<S, Ix1>,
    ) -> Result<&'a mut ArrayBase<S, Ix1>>;

    /// Solves `A * x = b` and returns `x` in a newly allocated array.
    ///
    /// `b` is copied (via `replicate`) before solving, so it is left unchanged.
    fn solveh<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
        let mut solution: Array1<A> = replicate(b);
        self.solveh_inplace(&mut solution)?;
        Ok(solution)
    }

    /// Solves `A * x = b`, consuming `b` and returning it overwritten with `x`.
    fn solveh_into<S: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<S, Ix1>,
    ) -> Result<ArrayBase<S, Ix1>> {
        self.solveh_inplace(&mut b)?;
        Ok(b)
    }
}
/// Bunch–Kaufman factorization of a Hermitian (or real symmetric) matrix, as produced
/// by `A::bk(.., UPLO::Upper, ..)` in [`FactorizeH`] / [`FactorizeHInto`].
///
/// Both fields are public. The solve, inverse and determinant routines in this module
/// assume that `a` and `ipiv` come from the same factorization call.
pub struct BKFactorized<S: Data> {
/// The factored matrix as overwritten by LAPACK (factorized with `UPLO::Upper`).
pub a: ArrayBase<S, Ix2>,
/// Pivot indices returned by LAPACK. A positive entry marks a 1x1 diagonal block;
/// otherwise the entry starts a 2x2 diagonal block.
pub ipiv: Pivot,
}
impl<A, S> SolveH<A> for BKFactorized<S>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
fn solveh_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix1>,
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
assert_eq!(
rhs.len(),
self.a.len_of(Axis(1)),
"The length of `rhs` must be compatible with the shape of the factored matrix.",
);
A::solveh(
self.a.square_layout()?,
UPLO::Upper,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
}
impl<A, S> SolveH<A> for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
fn solveh_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix1>,
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
let f = self.factorizeh()?;
f.solveh_inplace(rhs)
}
}
/// Compute the Bunch–Kaufman factorization of a Hermitian (or real symmetric) matrix
/// without consuming it.
pub trait FactorizeH<S: Data> {
/// Returns the factorization of `self`, taking `self` by shared reference.
fn factorizeh(&self) -> Result<BKFactorized<S>>;
}
/// Compute the Bunch–Kaufman factorization of a Hermitian (or real symmetric) matrix,
/// consuming the input.
pub trait FactorizeHInto<S: Data> {
/// Consumes `self` and returns its factorization.
fn factorizeh_into(self) -> Result<BKFactorized<S>>;
}
/// Factorizes in place: LAPACK overwrites the matrix storage with the factors, which
/// are then moved into the returned [`BKFactorized`].
impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: DataMut<Elem = A>,
{
    fn factorizeh_into(mut self) -> Result<BKFactorized<S>> {
        let layout = self.square_layout()?;
        let pivots = A::bk(layout, UPLO::Upper, self.as_allocated_mut()?)?;
        Ok(BKFactorized {
            ipiv: pivots,
            a: self,
        })
    }
}
/// Factorizes an owned copy of the matrix, leaving `self` untouched.
impl<A, Si> FactorizeH<OwnedRepr<A>> for ArrayBase<Si, Ix2>
where
    A: Scalar + Lapack,
    Si: Data<Elem = A>,
{
    fn factorizeh(&self) -> Result<BKFactorized<OwnedRepr<A>>> {
        // The owned copy is factorized in place by the consuming variant.
        let owned_copy: Array2<A> = replicate(self);
        owned_copy.factorizeh_into()
    }
}
/// Compute the inverse of a Hermitian (or real symmetric) matrix without consuming it.
pub trait InverseH {
/// Type of the computed inverse.
type Output;
/// Returns the inverse of `self`.
fn invh(&self) -> Result<Self::Output>;
}
/// Compute the inverse of a Hermitian (or real symmetric) matrix, consuming the input.
pub trait InverseHInto {
/// Type of the computed inverse.
type Output;
/// Consumes `self` and returns the inverse.
fn invh_into(self) -> Result<Self::Output>;
}
impl<A, S> InverseHInto for BKFactorized<S>
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
{
type Output = ArrayBase<S, Ix2>;
fn invh_into(mut self) -> Result<ArrayBase<S, Ix2>> {
A::invh(
self.a.square_layout()?,
UPLO::Upper,
self.a.as_allocated_mut()?,
&self.ipiv,
)?;
triangular_fill_hermitian(&mut self.a, UPLO::Upper);
Ok(self.a)
}
}
/// Inverts from an existing factorization, leaving `self` untouched.
impl<A, S> InverseH for BKFactorized<S>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    type Output = Array2<A>;

    fn invh(&self) -> Result<Self::Output> {
        // The stored factors are interpreted according to the memory layout of `a`:
        // which `(row, col)` entries hold them depends on C vs. F order. The working
        // copy must therefore keep `a`'s layout. `to_owned` preserves the strides of a
        // contiguous array, whereas a standard-layout copy would make LAPACK read the
        // wrong triangle for factors computed on a column-major matrix.
        let working_copy = BKFactorized {
            a: self.a.to_owned(),
            ipiv: self.ipiv.clone(),
        };
        working_copy.invh_into()
    }
}
/// Factorizes and inverts in place, reusing the matrix storage for the result.
impl<A, S> InverseHInto for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: DataMut<Elem = A>,
{
    type Output = Self;

    fn invh_into(self) -> Result<Self::Output> {
        self.factorizeh_into()
            .and_then(|factorized| factorized.invh_into())
    }
}
impl<A, Si> InverseH for ArrayBase<Si, Ix2>
where
A: Scalar + Lapack,
Si: Data<Elem = A>,
{
type Output = Array2<A>;
fn invh(&self) -> Result<Self::Output> {
let f = self.factorizeh()?;
f.invh_into()
}
}
/// Compute the (real) determinant of a Hermitian (or real symmetric) matrix without
/// consuming it.
pub trait DeterminantH {
/// Element type of the matrix.
type Elem: Scalar;
/// Returns the determinant.
fn deth(&self) -> Result<<Self::Elem as Scalar>::Real>;
/// Returns `(sign, ln|det|)`: the sign and the natural logarithm of the absolute value
/// of the determinant. Useful when the determinant itself would over- or underflow.
fn sln_deth(&self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
}
/// Compute the (real) determinant of a Hermitian (or real symmetric) matrix, consuming
/// the input.
pub trait DeterminantHInto {
/// Element type of the matrix.
type Elem: Scalar;
/// Consumes `self` and returns the determinant.
fn deth_into(self) -> Result<<Self::Elem as Scalar>::Real>;
/// Consumes `self` and returns `(sign, ln|det|)`: the sign and the natural logarithm of
/// the absolute value of the determinant.
fn sln_deth_into(self) -> Result<(<Self::Elem as Scalar>::Real, <Self::Elem as Scalar>::Real)>;
}
/// Computes `(sign, ln|det|)` of a Hermitian (or real symmetric) matrix from its
/// Bunch–Kaufman factorization.
///
/// * `uplo` – the triangle that was passed to LAPACK when factorizing.
/// * `ipiv_iter` – the pivot indices of the factorization, in order. A positive entry
///   marks a 1x1 diagonal block of `D` at `k`. Any other entry starts a 2x2 block
///   covering indices `k` and `k + 1`; the pivot entry for `k + 1` is skipped.
/// * `a` – the factored matrix holding the diagonal blocks of `D`.
///
/// The result is the product of the determinants of the diagonal blocks, returned as a
/// sign and the natural logarithm of the absolute value. Only the real part of the
/// diagonal entries of `D` is read, since `D` is Hermitian and its diagonal is real.
///
/// # Panics
///
/// If `a` is not stored contiguously, or if `ipiv_iter` describes a block that does not
/// fit inside `a`. Indexing is bounds-checked on purpose: `BKFactorized`'s fields are
/// public, so `a` and `ipiv` are not guaranteed to come from the same factorization.
fn bk_sln_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> (A::Real, A::Real)
where
    P: Iterator<Item = i32>,
    S: Data<Elem = A>,
    A: Scalar + Lapack,
{
    let layout = a
        .layout()
        .expect("the factored matrix must be stored contiguously");
    // Which side of the diagonal holds the off-diagonal entry of a 2x2 block, in
    // ndarray's `(row, col)` indexing. This depends on both the memory layout and the
    // triangle LAPACK worked on, and it is the same for every block.
    let off_diag_below = match layout {
        MatrixLayout::C { .. } => matches!(uplo, UPLO::Upper),
        MatrixLayout::F { .. } => matches!(uplo, UPLO::Lower),
    };

    let mut sign = A::Real::one();
    let mut ln_det = A::Real::zero();
    let mut ipiv_enum = ipiv_iter.enumerate();
    while let Some((k, ipiv_k)) = ipiv_enum.next() {
        if ipiv_k > 0 {
            // 1x1 block at k.
            let elem = a[(k, k)].re();
            sign *= elem.signum();
            ln_det += Float::ln(Float::abs(elem));
        } else {
            // 2x2 block at k..=k+1.
            let upper_diag = a[(k, k)].re();
            let lower_diag = a[(k + 1, k + 1)].re();
            let off_diag = if off_diag_below {
                &a[(k + 1, k)]
            } else {
                &a[(k, k + 1)]
            };
            // Hermitian 2x2 block: det = d11 * d22 - |d12|^2.
            let block_det = upper_diag * lower_diag - off_diag.square();
            sign *= block_det.signum();
            ln_det += Float::ln(Float::abs(block_det));
            // The pivot entry for k + 1 belongs to this same block.
            ipiv_enum.next();
        }
    }
    (sign, ln_det)
}
impl<A, S> BKFactorized<S>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
pub fn deth(&self) -> A::Real {
let (sign, ln_det) = self.sln_deth();
sign * Float::exp(ln_det)
}
pub fn sln_deth(&self) -> (A::Real, A::Real) {
bk_sln_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
}
pub fn deth_into(self) -> A::Real {
let (sign, ln_det) = self.sln_deth_into();
sign * Float::exp(ln_det)
}
pub fn sln_deth_into(self) -> (A::Real, A::Real) {
bk_sln_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
}
}
impl<A, S> DeterminantH for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
type Elem = A;
fn deth(&self) -> Result<A::Real> {
let (sign, ln_det) = self.sln_deth()?;
Ok(sign * Float::exp(ln_det))
}
fn sln_deth(&self) -> Result<(A::Real, A::Real)> {
match self.factorizeh() {
Ok(fac) => Ok(fac.sln_deth()),
Err(LinalgError::Lapack(e))
if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
{
Ok((A::Real::zero(), A::Real::neg_infinity()))
}
Err(err) => Err(err),
}
}
}
impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
{
type Elem = A;
fn deth_into(self) -> Result<A::Real> {
let (sign, ln_det) = self.sln_deth_into()?;
Ok(sign * Float::exp(ln_det))
}
fn sln_deth_into(self) -> Result<(A::Real, A::Real)> {
match self.factorizeh_into() {
Ok(fac) => Ok(fac.sln_deth_into()),
Err(LinalgError::Lapack(e))
if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
{
Ok((A::Real::zero(), A::Real::neg_infinity()))
}
Err(err) => Err(err),
}
}
}