use ndarray::*;
use num_traits::{Float, Zero};
use crate::convert::*;
use crate::error::*;
use crate::layout::*;
use crate::opnorm::OperationNorm;
use crate::types::*;
pub use lax::{Pivot, Transpose};
/// Solve linear systems `op(A) x = b` for a single right-hand-side vector `b`.
///
/// Every operation is offered in three flavours:
///
/// * plain (`solve`, `solve_t`, `solve_h`) borrows `b`, copies it and returns the
///   solution in a newly allocated [`Array1`];
/// * `*_into` takes `b` by value, overwrites it with the solution and returns it;
/// * `*_inplace` overwrites the borrowed `b` with the solution and returns the
///   same mutable reference.
///
/// Implementors only provide the `*_inplace` methods; the other two flavours are
/// built on top of them. The `_t` variants solve with the transpose of `A`, the
/// `_h` variants with its conjugate (Hermitian) transpose.
pub trait Solve<A: Scalar> {
    /// Solves `A x = b` and returns `x` in a fresh array; `b` is left untouched.
    fn solve<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
        let mut x: Array1<A> = replicate(b);
        self.solve_inplace(&mut x)?;
        Ok(x)
    }
    /// Solves `A x = b`, reusing the storage of `b` for `x`.
    fn solve_into<S: DataMut<Elem = A>>(
        &self,
        b: ArrayBase<S, Ix1>,
    ) -> Result<ArrayBase<S, Ix1>> {
        let mut x = b;
        self.solve_inplace(&mut x)?;
        Ok(x)
    }
    /// Solves `A x = b`, overwriting `b` with `x`.
    fn solve_inplace<'a, S: DataMut<Elem = A>>(
        &self,
        b: &'a mut ArrayBase<S, Ix1>,
    ) -> Result<&'a mut ArrayBase<S, Ix1>>;

    /// Solves `Aᵀ x = b` and returns `x` in a fresh array; `b` is left untouched.
    fn solve_t<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
        let mut x: Array1<A> = replicate(b);
        self.solve_t_inplace(&mut x)?;
        Ok(x)
    }
    /// Solves `Aᵀ x = b`, reusing the storage of `b` for `x`.
    fn solve_t_into<S: DataMut<Elem = A>>(
        &self,
        b: ArrayBase<S, Ix1>,
    ) -> Result<ArrayBase<S, Ix1>> {
        let mut x = b;
        self.solve_t_inplace(&mut x)?;
        Ok(x)
    }
    /// Solves `Aᵀ x = b`, overwriting `b` with `x`.
    fn solve_t_inplace<'a, S: DataMut<Elem = A>>(
        &self,
        b: &'a mut ArrayBase<S, Ix1>,
    ) -> Result<&'a mut ArrayBase<S, Ix1>>;

    /// Solves `Aᴴ x = b` and returns `x` in a fresh array; `b` is left untouched.
    fn solve_h<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
        let mut x: Array1<A> = replicate(b);
        self.solve_h_inplace(&mut x)?;
        Ok(x)
    }
    /// Solves `Aᴴ x = b`, reusing the storage of `b` for `x`.
    fn solve_h_into<S: DataMut<Elem = A>>(
        &self,
        b: ArrayBase<S, Ix1>,
    ) -> Result<ArrayBase<S, Ix1>> {
        let mut x = b;
        self.solve_h_inplace(&mut x)?;
        Ok(x)
    }
    /// Solves `Aᴴ x = b`, overwriting `b` with `x`.
    fn solve_h_inplace<'a, S: DataMut<Elem = A>>(
        &self,
        b: &'a mut ArrayBase<S, Ix1>,
    ) -> Result<&'a mut ArrayBase<S, Ix1>>;
}
/// LU decomposition of a matrix, as returned by [`Factorize::factorize`] and
/// [`FactorizeInto::factorize_into`].
///
/// Keeping the factors together with the pivots lets callers solve several
/// systems, or compute the inverse, determinant or condition estimate, without
/// factorizing again.
#[derive(Clone)]
pub struct LUFactorized<S>
where
    S: Data + RawDataClone,
{
    /// The matrix as overwritten in place by `A::lu` (packed `L` and `U` factors).
    a: ArrayBase<S, Ix2>,
    /// Pivot indices returned by `A::lu`.
    ipiv: Pivot,
}
impl<A, S> Solve<A> for LUFactorized<S>
where
A: Scalar + Lapack,
S: Data<Elem = A> + RawDataClone,
{
fn solve_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix1>,
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
assert_eq!(
rhs.len(),
self.a.len_of(Axis(1)),
"The length of `rhs` must be compatible with the shape of the factored matrix.",
);
A::solve(
self.a.square_layout()?,
Transpose::No,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
fn solve_t_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix1>,
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
assert_eq!(
rhs.len(),
self.a.len_of(Axis(0)),
"The length of `rhs` must be compatible with the shape of the factored matrix.",
);
A::solve(
self.a.square_layout()?,
Transpose::Transpose,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
fn solve_h_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix1>,
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
assert_eq!(
rhs.len(),
self.a.len_of(Axis(0)),
"The length of `rhs` must be compatible with the shape of the factored matrix.",
);
A::solve(
self.a.square_layout()?,
Transpose::Hermite,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
}
/// Direct solves on a plain matrix.
///
/// Every call LU-factorizes a fresh copy of `self`. When several right-hand
/// sides share one matrix, factorize once and reuse the [`LUFactorized`] value instead.
impl<A, S> Solve<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Factorizes a copy of `self` and solves `self * x = rhs` in place.
    fn solve_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix1>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize()?.solve_inplace(rhs)
    }

    /// Factorizes a copy of `self` and solves `selfᵀ * x = rhs` in place.
    fn solve_t_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix1>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize()?.solve_t_inplace(rhs)
    }

    /// Factorizes a copy of `self` and solves `selfᴴ * x = rhs` in place.
    fn solve_h_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix1>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix1>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize()?.solve_h_inplace(rhs)
    }
}
/// LU decomposition of a borrowed matrix.
pub trait Factorize<S>
where
    S: Data + RawDataClone,
{
    /// Computes the LU factors, returning them in storage of type `S`.
    fn factorize(&self) -> Result<LUFactorized<S>>;
}

/// LU decomposition of a matrix taken by value.
pub trait FactorizeInto<S>
where
    S: Data + RawDataClone,
{
    /// Computes the LU factors, reusing the storage of `self` for them.
    fn factorize_into(self) -> Result<LUFactorized<S>>;
}
impl<A, S> FactorizeInto<S> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: DataMut<Elem = A> + RawDataClone,
{
    /// Runs `A::lu` directly on the matrix's own buffer; no copy is made.
    fn factorize_into(mut self) -> Result<LUFactorized<S>> {
        let layout = self.layout()?;
        let pivots = A::lu(layout, self.as_allocated_mut()?)?;
        Ok(LUFactorized {
            a: self,
            ipiv: pivots,
        })
    }
}
impl<A, Si> Factorize<OwnedRepr<A>> for ArrayBase<Si, Ix2>
where
    A: Scalar + Lapack,
    Si: Data<Elem = A>,
{
    /// Copies `self` into an owned array and runs `A::lu` on the copy.
    fn factorize(&self) -> Result<LUFactorized<OwnedRepr<A>>> {
        let mut work: Array2<A> = replicate(self);
        let layout = work.layout()?;
        let pivots = A::lu(layout, work.as_allocated_mut()?)?;
        Ok(LUFactorized {
            a: work,
            ipiv: pivots,
        })
    }
}
/// Matrix inverse computed from a borrowed value.
pub trait Inverse {
    /// Type holding the inverse.
    type Output;
    /// Returns the inverse, leaving `self` untouched.
    fn inv(&self) -> Result<Self::Output>;
}

/// Matrix inverse computed from a value taken by ownership.
pub trait InverseInto {
    /// Type holding the inverse.
    type Output;
    /// Returns the inverse, consuming `self`.
    fn inv_into(self) -> Result<Self::Output>;
}
impl<A, S> InverseInto for LUFactorized<S>
where
A: Scalar + Lapack,
S: DataMut<Elem = A> + RawDataClone,
{
type Output = ArrayBase<S, Ix2>;
fn inv_into(mut self) -> Result<ArrayBase<S, Ix2>> {
A::inv(
self.a.square_layout()?,
self.a.as_allocated_mut()?,
&self.ipiv,
)?;
Ok(self.a)
}
}
impl<A, S> Inverse for LUFactorized<S>
where
    A: Scalar + Lapack,
    S: Data<Elem = A> + RawDataClone,
{
    type Output = Array2<A>;

    /// Computes the inverse from a copy of the stored factors.
    ///
    /// The copy keeps the memory order of the original storage. The packed factors
    /// were written by `A::lu` for that layout, and `A::inv` interprets the buffer
    /// according to the layout, so the result depends on the copy preserving it.
    fn inv(&self) -> Result<Array2<A>> {
        let owned: Array2<A> = if !self.a.is_standard_layout() {
            // Copy via the transpose, then flip the axes back so the logical
            // matrix is unchanged while the memory order matches the original.
            replicate(&self.a.t()).reversed_axes()
        } else {
            replicate(&self.a)
        };
        let copy = LUFactorized {
            a: owned,
            ipiv: self.ipiv.clone(),
        };
        copy.inv_into()
    }
}
impl<A, S> InverseInto for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: DataMut<Elem = A> + RawDataClone,
{
    type Output = Self;

    /// Factorizes the matrix within its own storage and then overwrites that
    /// storage with the inverse.
    fn inv_into(self) -> Result<Self::Output> {
        self.factorize_into()?.inv_into()
    }
}
impl<A, Si> Inverse for ArrayBase<Si, Ix2>
where
    A: Scalar + Lapack,
    Si: Data<Elem = A>,
{
    type Output = Array2<A>;

    /// Factorizes an owned copy of `self` and turns that copy into the inverse.
    fn inv(&self) -> Result<Self::Output> {
        self.factorize()?.inv_into()
    }
}
/// Determinant of a borrowed matrix (or of its LU factors).
pub trait Determinant<A: Scalar> {
    /// Computes the determinant as `sign * exp(ln_det)` from [`Determinant::sln_det`].
    ///
    /// Because of the final `exp`, this can overflow or underflow even when
    /// `sln_det` returns finite values.
    fn det(&self) -> Result<A> {
        self.sln_det()
            .map(|(sign, ln_det)| sign * A::from_real(Float::exp(ln_det)))
    }

    /// Computes `(sign, ln_det)` with `det = sign * exp(ln_det)`.
    ///
    /// `sign` has unit modulus (`±1` for real scalars); the implementations in
    /// this module report a singular matrix as `(0, -inf)`.
    fn sln_det(&self) -> Result<(A, A::Real)>;
}

/// Determinant of a matrix (or of its LU factors) taken by value.
pub trait DeterminantInto<A: Scalar>: Sized {
    /// Computes the determinant as `sign * exp(ln_det)` from
    /// [`DeterminantInto::sln_det_into`], consuming `self`.
    fn det_into(self) -> Result<A> {
        self.sln_det_into()
            .map(|(sign, ln_det)| sign * A::from_real(Float::exp(ln_det)))
    }

    /// Computes `(sign, ln_det)` with `det = sign * exp(ln_det)`, consuming `self`.
    fn sln_det_into(self) -> Result<(A, A::Real)>;
}
/// Computes `(sign, ln|det|)` from an LU decomposition.
///
/// * `ipiv_iter` yields 1-based pivot indices. Every entry that differs from its
///   own 1-based position is a row interchange, and each interchange flips the
///   sign of the determinant.
/// * `u_diag_iter` yields the diagonal of the upper factor. Its entries are folded
///   into a unit-modulus phase (`elem / |elem|`) and a sum of `ln|elem|`.
///
/// NOTE(review): an exactly-zero diagonal entry would produce NaN in the phase
/// and `-inf` in the log; callers in this module appear to rule that out
/// (`A::lu` reports it as an error) — confirm before reusing elsewhere.
fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real)
where
    A: Scalar + Lapack,
    P: Iterator<Item = i32>,
    U: Iterator<Item = &'a A>,
{
    let swaps = ipiv_iter
        .zip(1_i32..)
        .filter(|&(pivot, row)| pivot != row)
        .count();
    let pivot_sign = if swaps % 2 != 0 { -A::one() } else { A::one() };

    let mut upper_sign = A::one();
    let mut ln_det = A::Real::zero();
    for &elem in u_diag_iter {
        let abs_elem: A::Real = elem.abs();
        upper_sign = upper_sign * elem / A::from_real(abs_elem);
        ln_det = ln_det + Float::ln(abs_elem);
    }

    (pivot_sign * upper_sign, ln_det)
}
impl<A, S> Determinant<A> for LUFactorized<S>
where
    A: Scalar + Lapack,
    S: Data<Elem = A> + RawDataClone,
{
    /// Reads the sign and log-magnitude straight from the pivots and the
    /// diagonal of the stored factors; no further LAPACK call is made.
    fn sln_det(&self) -> Result<(A, A::Real)> {
        self.a.ensure_square()?;
        let pivots = self.ipiv.iter().copied();
        let u_diag = self.a.diag();
        Ok(lu_sln_det(pivots, u_diag.iter()))
    }
}
impl<A, S> DeterminantInto<A> for LUFactorized<S>
where
A: Scalar + Lapack,
S: Data<Elem = A> + RawDataClone,
{
fn sln_det_into(self) -> Result<(A, A::Real)> {
self.a.ensure_square()?;
Ok(lu_sln_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
}
}
impl<A, S> Determinant<A> for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
fn sln_det(&self) -> Result<(A, A::Real)> {
self.ensure_square()?;
match self.factorize() {
Ok(fac) => fac.sln_det(),
Err(LinalgError::Lapack(e))
if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
{
Ok((A::zero(), A::Real::neg_infinity()))
}
Err(err) => Err(err),
}
}
}
impl<A, S> DeterminantInto<A> for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: DataMut<Elem = A> + RawDataClone,
{
fn sln_det_into(self) -> Result<(A, A::Real)> {
self.ensure_square()?;
match self.factorize_into() {
Ok(fac) => fac.sln_det_into(),
Err(LinalgError::Lapack(e))
if matches!(e, lax::error::Error::LapackComputationalFailure { .. }) =>
{
Ok((A::zero(), A::Real::neg_infinity()))
}
Err(err) => Err(err),
}
}
}
/// Estimate of the reciprocal condition number of a borrowed matrix.
pub trait ReciprocalConditionNum<A: Scalar> {
    /// *Estimates* `1 / (‖A‖₁ · ‖A⁻¹‖₁)` with LAPACK's condition estimator.
    ///
    /// Values near `0` signal an ill-conditioned matrix, values near `1` a
    /// well-conditioned one.
    fn rcond(&self) -> Result<A::Real>;
}

/// Estimate of the reciprocal condition number of a matrix taken by value.
pub trait ReciprocalConditionNumInto<A: Scalar> {
    /// *Estimates* `1 / (‖A‖₁ · ‖A⁻¹‖₁)`, consuming `self`.
    fn rcond_into(self) -> Result<A::Real>;
}
impl<A, S> ReciprocalConditionNum<A> for LUFactorized<S>
where
A: Scalar + Lapack,
S: Data<Elem = A> + RawDataClone,
{
fn rcond(&self) -> Result<A::Real> {
Ok(A::rcond(
self.a.layout()?,
self.a.as_allocated()?,
self.a.opnorm_one()?,
)?)
}
}
impl<A, S> ReciprocalConditionNumInto<A> for LUFactorized<S>
where
A: Scalar + Lapack,
S: Data<Elem = A> + RawDataClone,
{
fn rcond_into(self) -> Result<A::Real> {
self.rcond()
}
}
impl<A, S> ReciprocalConditionNum<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Estimates the reciprocal 1-norm condition number of `self`.
    ///
    /// LAPACK's `*gecon` expects `ANORM` to be the norm of the *original* matrix,
    /// while the estimate of `‖A⁻¹‖` is computed from the LU factors. The norm is
    /// therefore taken from `self` before factorizing. The norm of the packed
    /// factors is generally a different number and would distort the result.
    ///
    /// # Errors
    ///
    /// Propagates failures of the norm computation, the LU factorization and
    /// `A::rcond`.
    fn rcond(&self) -> Result<A::Real> {
        let anorm = self.opnorm_one()?;
        let lu = self.factorize()?;
        Ok(A::rcond(lu.a.layout()?, lu.a.as_allocated()?, anorm)?)
    }
}
impl<A, S> ReciprocalConditionNumInto<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: DataMut<Elem = A> + RawDataClone,
{
    /// Estimates the reciprocal 1-norm condition number of `self`, consuming it.
    ///
    /// The norm of the original matrix, which `*gecon` expects as `ANORM`, must be
    /// taken *before* the in-place factorization overwrites `self` with its LU factors.
    ///
    /// # Errors
    ///
    /// Propagates failures of the norm computation, the LU factorization and
    /// `A::rcond`.
    fn rcond_into(self) -> Result<A::Real> {
        let anorm = self.opnorm_one()?;
        let lu = self.factorize_into()?;
        Ok(A::rcond(lu.a.layout()?, lu.a.as_allocated()?, anorm)?)
    }
}