use impl_prelude::*;
use permute::{MatrixPermutation, Permutes};
use ndarray as nd;
use lapack::c::{sgetrf, dgetrf, cgetrf, zgetrf, sgetri, dgetri, cgetri, zgetri};
/// Ways in which LU factorization or LU-based inversion can fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LUError {
    /// `slice_and_layout_mut` could not produce a LAPACK-compatible view
    /// of the matrix storage.
    BadLayout,
    /// An operation that needs a square matrix was given a non-square one.
    NotSquare,
    /// LAPACK returned a positive `info`: a zero pivot was met on the
    /// diagonal of `U`.
    Singular,
    /// The dimensions of the inputs do not agree with each other.
    InconsistentDimensions,
    /// LAPACK rejected one of its arguments; the payload is the negated
    /// `info`, i.e. the 1-based position of the offending argument.
    IllegalParameter(i32),
}
/// Packed LU factorization of a matrix, together with its row pivoting.
///
/// `mat` holds both factors in a single array, as left behind by the
/// LAPACK `?getrf` call in this module:
/// - the strict lower triangle is `L`, whose unit diagonal is implicit;
/// - the upper triangle (diagonal included) is `U`.
#[derive(Debug)]
pub struct LUFactors<T: LU> {
    /// Combined storage for the `L` and `U` factors.
    mat: Array<T, Ix2>,
    /// Row permutation built from LAPACK's pivot indices.
    perm: MatrixPermutation,
}
impl<T: LU> LUFactors<T> {
    /// Wraps an already-factored matrix and its LAPACK pivot vector.
    ///
    /// Fails with `LUError::BadLayout` when `slice_and_layout_mut` cannot
    /// obtain a LAPACK-compatible view of the matrix.
    fn from_raw<Matrix>(mat: Matrix, perm: Vec<i32>) -> Result<LUFactors<T>, LUError>
        where Matrix: Into<Array<T, Ix2>>
    {
        let mut owned: Array<T, Ix2> = mat.into();
        // Evaluate in its own statement so the mutable borrow ends here.
        let layout_ok = slice_and_layout_mut(&mut owned).is_some();
        if !layout_ok {
            return Err(LUError::BadLayout);
        }
        let perm = MatrixPermutation::from_ipiv(perm);
        Ok(LUFactors { mat: owned, perm: perm })
    }

    /// Inverts the factored matrix, leaving `self` untouched.
    ///
    /// The factors are copied before inversion.
    pub fn inverse(&self) -> Result<Array<T, Ix2>, LUError> {
        T::compute_inverse(&self.mat, self.perm.ipiv())
    }

    /// Inverts the factored matrix, reusing the storage of `self`.
    pub fn inverse_into(self) -> Result<Array<T, Ix2>, LUError> {
        let LUFactors { mat, perm } = self;
        T::compute_inverse_into(mat, perm.ipiv())
    }

    /// Number of rows of the factored matrix.
    pub fn rows(&self) -> usize {
        self.mat.dim().0
    }

    /// Number of columns of the factored matrix.
    pub fn cols(&self) -> usize {
        self.mat.dim().1
    }

    /// Row permutation recorded during factorization.
    pub fn perm(&self) -> &MatrixPermutation {
        &self.perm
    }

    /// Length of the main diagonal, `min(rows, cols)`.
    fn k(&self) -> usize {
        cmp::min(self.cols(), self.rows())
    }

    /// Returns the `rows x k` lower-triangular factor with a unit diagonal.
    pub fn l(&self) -> Array<T, Ix2> {
        let height = self.rows();
        let diag = self.k();
        let leading_cols = s![.., ..diag as isize];

        let mut lower: Array<T, Ix2> = Array::zeros((height, diag));
        lower.slice_mut(leading_cols).assign(&self.mat.slice(leading_cols));

        // For each of the first `diag` rows:
        // - write the implicit 1 on the diagonal;
        // - clear the entries to its right, which belong to `U` in the
        //   packed storage.
        // Rows past `diag` lie entirely below the diagonal and are kept as-is.
        for (r, mut row) in lower.outer_iter_mut().take(diag).enumerate() {
            row[r] = T::one();
            row.slice_mut(s![r as isize + 1..]).fill(T::zero());
        }
        lower
    }

    /// Returns the `k x cols` upper-triangular factor.
    #[inline]
    pub fn u(&self) -> Array<T, Ix2> {
        let width = self.cols();
        let diag = self.k();
        let leading_rows = s![..diag as isize, ..];

        let mut upper: Array<T, Ix2> = Array::zeros((diag, width));
        upper.slice_mut(leading_rows).assign(&self.mat.slice(leading_rows));

        // `upper` has exactly `diag` rows. In each one, clear the entries
        // left of the diagonal, which belong to `L` in the packed storage.
        for (r, mut row) in upper.outer_iter_mut().enumerate() {
            row.slice_mut(s![..r as isize]).fill(T::zero());
        }
        upper
    }

    /// Multiplies `L * U` and applies the stored row permutation.
    ///
    /// This presumably reproduces the original matrix (A = P * L * U). The
    /// direction of the permutation depends on `MatrixPermutation::permute_into`.
    pub fn reconstruct(&self) -> Array<T, Ix2> {
        let lower = self.l();
        let upper = self.u();
        let product = lower.dot(&upper);
        self.perm.permute_into(product).expect("guarantee that lu is the right size")
    }
}
/// Scalar types that support LU factorization and LU-based inversion.
pub trait LU: nd::LinalgScalar + Permutes {
    /// Factors `a`, consuming it.
    fn compute_into(a: Array<Self, Ix2>) -> Result<LUFactors<Self>, LUError>;

    /// Factors a copy of `a`, leaving the input untouched.
    fn compute<D1: Data>(a: &ArrayBase<D1, Ix2>) -> Result<LUFactors<Self>, LUError>
        where D1: Data<Elem = Self>
    {
        let owned = a.to_owned();
        Self::compute_into(owned)
    }

    /// Replaces `mat` with the inverse computed from its LU factors.
    ///
    /// `mat` is expected to hold the packed factors from `compute_into`,
    /// and `perm` the matching pivot vector.
    fn compute_inverse_into<D1>(mat: ArrayBase<D1, Ix2>, perm: &[i32])
        -> Result<ArrayBase<D1, Ix2>, LUError>
        where D1: DataOwned<Elem = Self> + DataMut<Elem=Self>;

    /// Like `compute_inverse_into`, but operates on an owned copy of `mat`.
    fn compute_inverse<D1>(mat: &ArrayBase<D1, Ix2>, perm: &[i32])
        -> Result<Array<Self, Ix2>, LUError>
        where D1: Data<Elem = Self>
    {
        // Reject non-square input before paying for the copy.
        if mat.rows() == mat.cols() {
            Self::compute_inverse_into(mat.to_owned(), perm)
        } else {
            Err(LUError::NotSquare)
        }
    }
}
// Implements `LU` for one scalar type by delegating to LAPACK.
//
// * `$lu_type`   - the scalar type the trait is implemented for.
// * `$lu_func`   - the LAPACK `?getrf` routine (LU factorization).
// * `$lu_invert` - the LAPACK `?getri` routine (inverse from LU factors).
macro_rules! impl_lu {
    ($lu_type:ty, $lu_func:ident, $lu_invert:ident) => (
        impl LU for $lu_type {
            fn compute_into(mut a: Array<Self, Ix2>) -> Result<LUFactors<Self>, LUError> {
                let dim = a.dim();
                let (info, perm_i) = {
                    // Scope the mutable borrow of `a` so that `a` can be
                    // moved into `LUFactors` afterwards.
                    let (mut slice, layout, lda) = match slice_and_layout_mut(&mut a) {
                        None => return Err(LUError::BadLayout),
                        Some(x) => x,
                    };
                    // One pivot slot per diagonal entry; ?getrf fills them in.
                    let mut perm_i = vec![-1i32; cmp::min(dim.0, dim.1)];
                    let info = $lu_func(layout,
                                        dim.0 as i32,
                                        dim.1 as i32,
                                        &mut slice,
                                        lda as i32,
                                        &mut perm_i);
                    (info, perm_i)
                };
                if info == 0 {
                    LUFactors::from_raw(a, perm_i)
                } else if info < 0 {
                    // LAPACK reports a bad argument as -(its 1-based index).
                    Err(LUError::IllegalParameter(-info))
                } else {
                    // info > 0: an exactly-zero pivot on the diagonal of U.
                    Err(LUError::Singular)
                }
            }

            fn compute_inverse_into<D1>(mut mat: ArrayBase<D1, Ix2>, perm: &[i32])
                                        -> Result<ArrayBase<D1, Ix2>, LUError>
                where D1: DataOwned<Elem = Self> + DataMut<Elem=Self> {
                let dim = mat.dim();
                if dim.0 != dim.1 {
                    return Err(LUError::NotSquare);
                }
                // ?getri reads exactly N pivot indices. The binding hands it
                // only the slice's pointer, so a pivot vector of the wrong
                // length would be read out of bounds; reject it up front.
                if perm.len() != dim.0 {
                    return Err(LUError::InconsistentDimensions);
                }
                let info = {
                    let (mut slice, layout, lda) = match slice_and_layout_mut(&mut mat) {
                        None => return Err(LUError::BadLayout),
                        Some(x) => x,
                    };
                    $lu_invert(layout, dim.0 as i32, &mut slice, lda as i32, perm)
                };
                if info == 0 {
                    Ok(mat)
                } else if info < 0 {
                    Err(LUError::IllegalParameter(-info))
                } else {
                    // info > 0: U has a zero diagonal entry, so no inverse exists.
                    Err(LUError::Singular)
                }
            }
        }
    )
}

// Real single/double precision and complex single/double precision.
impl_lu!(f32, sgetrf, sgetri);
impl_lu!(f64, dgetrf, dgetri);
impl_lu!(c32, cgetrf, cgetri);
impl_lu!(c64, zgetrf, zgetri);