use ndarray::*;
use num_traits::{Float, One, Zero};
use super::convert::*;
use super::error::*;
use super::layout::*;
use super::types::*;
pub use lapack_traits::{Pivot, UPLO};
/// Solve linear systems `A x = b` whose matrix `A` is Hermitian.
///
/// Implementors only provide `solveh_inplace`; the borrowing (`solveh`) and
/// consuming (`solveh_into`) variants are built on top of it.
pub trait SolveH<A: Scalar> {
    /// Solves `A x = b` without touching `b`, returning `x` as a new array.
    fn solveh<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
        let mut b = replicate(b);
        self.solveh_inplace(&mut b)?;
        Ok(b)
    }
    /// Solves `A x = b`, consuming `b` and returning it overwritten with `x`.
    fn solveh_into<S: DataMut<Elem = A>>(&self, mut b: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
        self.solveh_inplace(&mut b)?;
        Ok(b)
    }
    /// Solves `A x = b`, overwriting `b` with `x` and handing the same reference back.
    // The parameter is named: anonymous trait-method parameters are deprecated
    // and rejected outright from the 2018 edition on.
    fn solveh_inplace<'a, S: DataMut<Elem = A>>(
        &self,
        b: &'a mut ArrayBase<S, Ix1>,
    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
}
/// Factorized form of a Hermitian matrix, as produced in this module by
/// `factorizeh` / `factorizeh_into` (which call `A::bk` with `UPLO::Upper`).
///
/// Keeping the factorization lets callers solve, invert, or take the
/// determinant without refactorizing.
pub struct BKFactorized<S: Data> {
    /// Matrix storage overwritten in place by `A::bk`. The solve, inverse and
    /// determinant routines here read it back with `UPLO::Upper`.
    pub a: ArrayBase<S, Ix2>,
    /// Pivot data returned by `A::bk`. `bk_det` treats a positive entry as a
    /// 1x1 diagonal block and a non-positive entry as the start of a 2x2 block.
    pub ipiv: Pivot,
}
impl<A, S> SolveH<A> for BKFactorized<S>
where
    A: Scalar,
    S: Data<Elem = A>,
{
    /// Solves `A x = rhs` with the stored factorization, overwriting `rhs` with `x`.
    ///
    /// Right-hand sides that are not contiguous (e.g. strided views) are solved
    /// through a packed temporary and copied back, instead of panicking.
    ///
    /// # Errors
    ///
    /// Propagates layout errors for the stored factor and errors from `A::solveh`.
    ///
    /// # Panics
    ///
    /// If `rhs.len()` differs from the order of the factorized matrix. The
    /// unchecked FFI call would otherwise read/write past the end of `rhs`.
    fn solveh_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
    where
        Sb: DataMut<Elem = A>,
    {
        let layout = self.a.square_layout()?;
        let factor = self.a.as_allocated()?;
        assert_eq!(
            rhs.len(),
            self.a.rows(),
            "solveh: right-hand side length must equal the matrix order"
        );
        if rhs.is_standard_layout() {
            // A standard-layout 1-D array always exposes a contiguous slice.
            let b = rhs
                .as_slice_mut()
                .expect("standard-layout 1-D array is contiguous");
            unsafe { A::solveh(layout, UPLO::Upper, factor, &self.ipiv, b)? };
        } else {
            // LAPACK needs one contiguous buffer: solve on a packed copy, then write back.
            let mut packed: Vec<A> = rhs.iter().cloned().collect();
            unsafe { A::solveh(layout, UPLO::Upper, factor, &self.ipiv, &mut packed)? };
            for (dst, src) in rhs.iter_mut().zip(packed) {
                *dst = src;
            }
        }
        Ok(rhs)
    }
}
impl<A, S> SolveH<A> for ArrayBase<S, Ix2>
where
    A: Scalar,
    S: Data<Elem = A>,
{
    /// Factorizes a copy of this matrix and solves `self * x = rhs` in place.
    ///
    /// The factorization is thrown away afterwards. Call `factorizeh` once and
    /// reuse the result when solving for several right-hand sides.
    fn solveh_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorizeh()?.solveh_inplace(rhs)
    }
}
/// Factorizes a Hermitian matrix while leaving the input untouched.
///
/// `S` is the storage type of the resulting factor. Implementations may need
/// to allocate a working copy.
pub trait FactorizeH<S: Data> {
    /// Returns the factorization of `self`.
    fn factorizeh(&self) -> Result<BKFactorized<S>>;
}
/// Factorizes a Hermitian matrix by consuming it and reusing its storage.
pub trait FactorizeHInto<S: Data> {
    /// Consumes `self` and returns its factorization, stored in the same buffer.
    fn factorizeh_into(self) -> Result<BKFactorized<S>>;
}
impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
where
    A: Scalar,
    S: DataMut<Elem = A>,
{
    /// Runs `A::bk` (with `UPLO::Upper`) directly on this matrix's buffer and
    /// bundles the overwritten storage with the returned pivots.
    ///
    /// # Errors
    ///
    /// Fails if the matrix is not square or not contiguously allocated, or if
    /// `A::bk` reports an error.
    fn factorizeh_into(self) -> Result<BKFactorized<S>> {
        let mut a = self;
        let layout = a.square_layout()?;
        let ipiv = unsafe { A::bk(layout, UPLO::Upper, a.as_allocated_mut()?)? };
        Ok(BKFactorized { a, ipiv })
    }
}
impl<A, Si> FactorizeH<OwnedRepr<A>> for ArrayBase<Si, Ix2>
where
    A: Scalar,
    Si: Data<Elem = A>,
{
    /// Copies the matrix into a fresh owned array and factorizes that copy.
    ///
    /// The factorization overwrites its input, so the copy is handed to the
    /// consuming `factorizeh_into` path and `self` stays intact.
    fn factorizeh(&self) -> Result<BKFactorized<OwnedRepr<A>>> {
        let owned: Array2<A> = replicate(self);
        owned.factorizeh_into()
    }
}
/// Inverse of a Hermitian matrix, computed without consuming `self`.
pub trait InverseH {
    /// Type of the computed inverse.
    type Output;
    /// Returns the inverse of the matrix represented by `self`.
    fn invh(&self) -> Result<Self::Output>;
}
/// Inverse of a Hermitian matrix, computed by consuming `self`.
pub trait InverseHInto {
    /// Type of the computed inverse.
    type Output;
    /// Consumes `self` and returns the inverse of the matrix it represents.
    fn invh_into(self) -> Result<Self::Output>;
}
impl<A, S> InverseHInto for BKFactorized<S>
where
A: Scalar,
S: DataMut<Elem = A>,
{
type Output = ArrayBase<S, Ix2>;
fn invh_into(mut self) -> Result<ArrayBase<S, Ix2>> {
unsafe {
A::invh(
self.a.square_layout()?,
UPLO::Upper,
self.a.as_allocated_mut()?,
&self.ipiv,
)?
};
triangular_fill_hermitian(&mut self.a, UPLO::Upper);
Ok(self.a)
}
}
impl<A, S> InverseH for BKFactorized<S>
where
    A: Scalar,
    S: Data<Elem = A>,
{
    type Output = Array2<A>;

    /// Inverts an owned duplicate of this factorization, leaving `self` reusable.
    fn invh(&self) -> Result<Self::Output> {
        // Inversion works in place, so operate on a copy of both factor and pivots.
        let duplicate = BKFactorized {
            ipiv: self.ipiv.clone(),
            a: replicate(&self.a),
        };
        duplicate.invh_into()
    }
}
impl<A, S> InverseHInto for ArrayBase<S, Ix2>
where
    A: Scalar,
    S: DataMut<Elem = A>,
{
    type Output = Self;

    /// Factorizes and then inverts this matrix, reusing its own storage for both steps.
    fn invh_into(self) -> Result<Self::Output> {
        self.factorizeh_into().and_then(|fac| fac.invh_into())
    }
}
impl<A, Si> InverseH for ArrayBase<Si, Ix2>
where
    A: Scalar,
    Si: Data<Elem = A>,
{
    type Output = Array2<A>;

    /// Factorizes a copy of this matrix and inverts it, returning a new owned array.
    fn invh(&self) -> Result<Self::Output> {
        self.factorizeh().and_then(|fac| fac.invh_into())
    }
}
/// Determinant of a Hermitian matrix, computed without consuming `self`.
///
/// In this module `Output` is `A::Real` for an existing factorization and
/// `Result<A::Real>` for a raw matrix, whose factorization may fail.
pub trait DeterminantH {
    /// Type of the computed determinant (possibly wrapped in `Result`).
    type Output;
    /// Returns the determinant of the matrix represented by `self`.
    fn deth(&self) -> Self::Output;
}
/// Determinant of a Hermitian matrix, computed by consuming `self`.
///
/// In this module `Output` is `A::Real` for an existing factorization and
/// `Result<A::Real>` for a raw matrix, whose factorization may fail.
pub trait DeterminantHInto {
    /// Type of the computed determinant (possibly wrapped in `Result`).
    type Output;
    /// Consumes `self` and returns the determinant of the matrix it represents.
    fn deth_into(self) -> Self::Output;
}
/// Determinant of a Hermitian matrix from its `bk` factorization.
///
/// `a` is the factor matrix as overwritten by `A::bk`, and `ipiv_iter` yields
/// the matching pivot entries:
/// - A positive pivot marks a 1x1 diagonal block of `D`.
/// - A non-positive pivot marks the first row of a 2x2 block; the block's
///   second pivot entry is skipped.
///
/// `uplo` selects which triangle holds a 2x2 block's off-diagonal entry.
/// Assuming LAPACK's `?hetrf`/`?sytrf` storage convention, the symmetric
/// interchanges and the unit-triangular factor do not change the determinant,
/// so only the blocks of `D` contribute.
///
/// The sign and `ln|det|` are accumulated separately, so long products of
/// large or tiny block determinants do not overflow or underflow midway.
///
/// # Panics
///
/// If a pivot refers to an entry outside `a`. Indexing is checked, so a
/// mismatched, hand-built factorization panics instead of reading out of bounds.
fn bk_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> A::Real
where
    P: Iterator<Item = i32>,
    S: Data<Elem = A>,
    A: Scalar,
{
    // Real part of a diagonal entry of `D`. A Hermitian factorization is
    // expected to leave these with zero imaginary part. Check the element
    // itself: checking after `.real()` could never fail.
    let real_diag = |i: usize| -> A::Real {
        let elem = &a[(i, i)];
        debug_assert_eq!(elem.imag(), Zero::zero());
        elem.real()
    };

    let mut sign = A::Real::one();
    let mut ln_det = A::Real::zero();
    let mut pivots = ipiv_iter.enumerate();
    while let Some((k, ipiv_k)) = pivots.next() {
        let block_det = if ipiv_k > 0 {
            real_diag(k)
        } else {
            let upper_diag = real_diag(k);
            let lower_diag = real_diag(k + 1);
            let off_diag = match uplo {
                UPLO::Upper => &a[(k, k + 1)],
                UPLO::Lower => &a[(k + 1, k)],
            };
            // Row k + 1's pivot entry belongs to this same 2x2 block; skip it.
            pivots.next();
            upper_diag * lower_diag - off_diag.abs_sqr()
        };
        sign = sign * block_det.signum();
        ln_det = ln_det + block_det.abs().ln();
    }
    sign * ln_det.exp()
}
impl<A, S> DeterminantH for BKFactorized<S>
where
    A: Scalar,
    S: Data<Elem = A>,
{
    type Output = A::Real;

    /// Reads the determinant off the stored factor, borrowing the pivots.
    fn deth(&self) -> A::Real {
        let pivots = self.ipiv.iter().cloned();
        bk_det(UPLO::Upper, pivots, &self.a)
    }
}
impl<A, S> DeterminantHInto for BKFactorized<S>
where
A: Scalar,
S: Data<Elem = A>,
{
type Output = A::Real;
fn deth_into(self) -> A::Real {
bk_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
}
}
impl<A, S> DeterminantH for ArrayBase<S, Ix2>
where
    A: Scalar,
    S: Data<Elem = A>,
{
    type Output = Result<A::Real>;

    /// Determinant of this matrix, computed via a factorization of a copy.
    ///
    /// A LAPACK error with a positive return code from the factorization is
    /// reported as `Ok(0)`. Presumably this code signals an exactly singular
    /// block of `D`. All other errors are propagated.
    fn deth(&self) -> Result<A::Real> {
        self.factorizeh()
            .map(|fac| fac.deth())
            .or_else(|err| match err {
                LinalgError::Lapack(LapackError { return_code }) if return_code > 0 => Ok(A::Real::zero()),
                other => Err(other),
            })
    }
}
impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
where
    A: Scalar,
    S: DataMut<Elem = A>,
{
    type Output = Result<A::Real>;

    /// Determinant of this matrix, factorizing it in place (consumes `self`).
    ///
    /// A LAPACK error with a positive return code from the factorization is
    /// reported as `Ok(0)`. Presumably this code signals an exactly singular
    /// block of `D`. All other errors are propagated.
    fn deth_into(self) -> Result<A::Real> {
        self.factorizeh_into()
            .map(|fac| fac.deth_into())
            .or_else(|err| match err {
                LinalgError::Lapack(LapackError { return_code }) if return_code > 0 => Ok(A::Real::zero()),
                other => Err(other),
            })
    }
}