use impl_prelude::*;
use lapack::c::{sgels, sgelsd, dgels, dgelsd, cgels, cgelsd, zgels, zgelsd};
/// Result of a least-squares solve.
///
/// `T` is the scalar type of the system and `D` the dimensionality of the
/// solution (`Ix1` for a single right-hand side, `Ix2` for several).
pub struct LeastSquaresSolution<T, D: Dimension> {
/// The computed solution `x`. For the solvers in this file it has one row
/// per column of the coefficient matrix `a`.
pub solution: Array<T, D>,
/// Rank reported for `a`: `min(rows, cols)` from the full-rank solver, or
/// the rank value written by LAPACK in the rank-deficient solver.
pub rank: usize,
}
/// Names the two solver families offered by this module.
///
/// NOTE(review): this enum is not referenced anywhere in the visible code;
/// the variant meanings below are inferred from the method names
/// `compute_multi_degenerate` / `compute_multi_full` and should be confirmed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LeastSquaresType {
    /// Presumably the solver that tolerates a rank-deficient matrix.
    Degenerate,
    /// Presumably the solver that requires a full-rank matrix.
    Full,
}
/// Errors reported by the least-squares solvers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LeastSquaresError {
    /// The coefficient matrix could not be viewed as a LAPACK-compatible
    /// slice, or a right-hand-side vector could not be reshaped into a
    /// single-column matrix.
    BadLayout,
    /// The right-hand side could not be viewed as a slice with the same
    /// memory layout as the coefficient matrix.
    InconsistentLayout,
    /// The row counts of `a` and `b` are incompatible. Carries
    /// `(rows of a, rows of b)`.
    InconsistentDimensions(usize, usize),
    /// The full-rank solver reported a positive LAPACK `info`, i.e. the
    /// matrix does not have full rank.
    Degenerate,
    /// LAPACK rejected an argument. Carries the negated (hence positive)
    /// `info` value, which is the 1-based index of the offending argument.
    IllegalParameter(i32),
}
/// Least-squares solvers for `a * x = b`, implemented once per scalar type.
///
/// Implementors provide the two consuming `*_into` methods; every other
/// method is a convenience wrapper built on top of them.
pub trait LeastSquares: Sized + Clone {
    /// Solves for several right-hand sides (the columns of `b`) assuming `a`
    /// has full rank, consuming both inputs.
    ///
    /// # Errors
    ///
    /// The implementations in this file return
    /// `LeastSquaresError::Degenerate` when LAPACK reports that `a` is not of
    /// full rank, and layout/dimension errors when the inputs cannot be
    /// handed to LAPACK.
    fn compute_multi_full_into<D1, D2>
        (a: ArrayBase<D1, Ix2>,
         b: ArrayBase<D2, Ix2>)
         -> Result<LeastSquaresSolution<Self, Ix2>, LeastSquaresError>
        where D1: DataMut<Elem = Self> + DataOwned<Elem = Self>,
              D2: DataMut<Elem = Self> + DataOwned<Elem = Self>;

    /// Borrowing variant of `compute_multi_full_into`: copies `a` and `b`
    /// and solves on the copies, leaving the inputs untouched.
    fn compute_multi_full<D1, D2>(a: &ArrayBase<D1, Ix2>,
                                  b: &ArrayBase<D2, Ix2>)
                                  -> Result<LeastSquaresSolution<Self, Ix2>, LeastSquaresError>
        where D1: Data<Elem = Self>,
              D2: Data<Elem = Self>
    {
        Self::compute_multi_full_into(a.to_owned(), b.to_owned())
    }

    /// Borrowing variant of `compute_multi_degenerate_into`: copies `a` and
    /// `b` and solves on the copies, leaving the inputs untouched.
    fn compute_multi_degenerate<D1, D2>
        (a: &ArrayBase<D1, Ix2>,
         b: &ArrayBase<D2, Ix2>)
         -> Result<LeastSquaresSolution<Self, Ix2>, LeastSquaresError>
        where D1: Data<Elem = Self>,
              D2: Data<Elem = Self>
    {
        Self::compute_multi_degenerate_into(a.to_owned(), b.to_owned())
    }

    /// Solves for several right-hand sides (the columns of `b`) without
    /// assuming `a` has full rank, consuming both inputs.
    ///
    /// The returned `rank` is the rank determined by the solver.
    fn compute_multi_degenerate_into<D1, D2>
        (a: ArrayBase<D1, Ix2>,
         b: ArrayBase<D2, Ix2>)
         -> Result<LeastSquaresSolution<Self, Ix2>, LeastSquaresError>
        where D1: DataMut<Elem = Self> + DataOwned<Elem = Self>,
              D2: DataMut<Elem = Self> + DataOwned<Elem = Self>;

    /// Solves for several right-hand sides, trying the full-rank solver
    /// first and falling back to the rank-deficient solver only when the
    /// former reports `LeastSquaresError::Degenerate`.
    ///
    /// Any other error from the full-rank attempt is returned unchanged.
    /// Each attempt works on its own copy of `a` and `b`.
    fn compute_multi<D1, D2>(a: &ArrayBase<D1, Ix2>,
                             b: &ArrayBase<D2, Ix2>)
                             -> Result<LeastSquaresSolution<Self, Ix2>, LeastSquaresError>
        where D1: Data<Elem = Self>,
              D2: Data<Elem = Self>
    {
        match Self::compute_multi_full(a, b) {
            Err(LeastSquaresError::Degenerate) => Self::compute_multi_degenerate(a, b),
            other => other,
        }
    }

    /// Solves `a * x = b` for a single right-hand-side vector `b`.
    ///
    /// `b` is copied into a one-column matrix, solved with `compute_multi`,
    /// and the one-column solution is flattened back into a vector.
    ///
    /// # Errors
    ///
    /// Returns `LeastSquaresError::BadLayout` if either reshape fails, and
    /// otherwise whatever `compute_multi` returns.
    fn compute<D1, D2>(a: &ArrayBase<D1, Ix2>,
                       b: &ArrayBase<D2, Ix1>)
                       -> Result<LeastSquaresSolution<Self, Ix1>, LeastSquaresError>
        where D1: Data<Elem = Self>,
              D2: Data<Elem = Self>
    {
        let n = b.dim();
        let b_mat = match b.to_owned().into_shape((n, 1)) {
            Ok(x) => x,
            Err(_) => return Err(LeastSquaresError::BadLayout),
        };

        let res = try!(Self::compute_multi(a, &b_mat));
        let LeastSquaresSolution { solution, rank } = res;

        // The solution has its own row count (one row per column of `a` in
        // the implementations in this file), which differs from the length
        // of `b` whenever `a` is not square. Reshaping to `b`'s length, as
        // was done previously, panicked for every non-square system.
        let sol_len = solution.dim().0;
        match solution.into_shape(sol_len) {
            Ok(flat) => {
                Ok(LeastSquaresSolution {
                    solution: flat,
                    rank: rank,
                })
            }
            Err(_) => Err(LeastSquaresError::BadLayout),
        }
    }
}
/// Returns an owned copy of `b_sol` with exactly `n` rows.
///
/// When `b_sol` has more than `n` rows, only the first `n` are kept.
/// Otherwise the result is a `T::default()`-filled matrix with `n` rows whose
/// leading rows are copied from `b_sol`. The column count is preserved.
fn resize_solution<T: Clone + Default, D>(mut b_sol: ArrayBase<D, Ix2>, n: usize) -> Array<T, Ix2>
    where D: DataMut<Elem = T>
{
    let (rows, cols) = b_sol.dim();

    if rows <= n {
        // Not enough (or exactly enough) rows: copy into a fresh,
        // default-initialised matrix of the requested height.
        let mut padded = Array::default((n, cols));
        padded.slice_mut(s![0..rows as isize, ..]).assign(&b_sol);
        return padded;
    }

    // Too many rows: keep only the leading `n`.
    b_sol.slice_mut(s![0..n as isize, ..]).to_owned()
}
// Implements `LeastSquares` for one scalar type on top of LAPACK.
//
// * `$impl_type`  - scalar type of the matrices.
// * `$sv_type`    - real type used for the singular values.
// * `$full_func`  - LAPACK `*gels` routine (full-rank solver).
// * `$degen_func` - LAPACK `*gelsd` routine (SVD-based solver).
//
// Both solvers overwrite `a` and `b` in place; LAPACK stores the solution in
// the leading rows of `b`, so `b` must provide `max(rows(a), cols(a))` rows.
macro_rules! impl_least_squares {
    ($impl_type:ty, $sv_type:ty, $full_func:ident, $degen_func:ident) => (
        impl LeastSquares for $impl_type {
            fn compute_multi_full_into<D1, D2>(
                mut a: ArrayBase<D1, Ix2>,
                mut b: ArrayBase<D2, Ix2>)
                -> Result<LeastSquaresSolution<Self, Ix2>, LeastSquaresError>
                where D1: DataMut<Elem=Self> + DataOwned<Elem = Self>,
                      D2: DataMut<Elem=Self> + DataOwned<Elem = Self> {
                let a_dim = a.dim();
                let b_dim = b.dim();

                // LAPACK writes the solution (one row per column of `a`)
                // into `b`, so the buffer needs max(m, n) rows.
                let work_rows = cmp::max(a_dim.0, a_dim.1);
                if b_dim.0 == a_dim.0 && b_dim.0 < work_rows {
                    // Underdetermined system: `b` is too short to hold the
                    // solution. Passing it as-is lets LAPACK read and write
                    // past the end of the buffer, so pad it with default
                    // values first. The padded copy is in the default
                    // (row-major) layout; a column-major `a` is then rejected
                    // below with `InconsistentLayout`.
                    return Self::compute_multi_full_into(a, resize_solution(b, work_rows));
                }
                if b_dim.0 != work_rows {
                    return Err(LeastSquaresError::InconsistentDimensions(a_dim.0, b_dim.0));
                }

                let (a_slice, layout, lda) = match slice_and_layout_mut(&mut a) {
                    Some(x) => x,
                    None => return Err(LeastSquaresError::BadLayout)
                };

                // Scoped so the mutable borrow of `b` ends before `b` is
                // moved into `resize_solution` below.
                let info = {
                    let (b_slice, ldb) = match slice_and_layout_matching_mut(&mut b, layout) {
                        Some(x) => x,
                        None => return Err(LeastSquaresError::InconsistentLayout)
                    };

                    $full_func(layout, b'N', a_dim.0 as i32 , a_dim.1 as i32, b_dim.1 as i32,
                               a_slice, lda as i32,
                               b_slice, ldb as i32)
                };

                if info == 0 {
                    Ok(LeastSquaresSolution {
                        // Keep only the rows that hold the solution.
                        solution: resize_solution(b, a_dim.1),
                        // The full-rank solver only succeeds on full-rank input.
                        rank: cmp::min(a_dim.0, a_dim.1)
                    })
                } else if info < 0 {
                    Err(LeastSquaresError::IllegalParameter(-info))
                } else {
                    // info > 0: a diagonal element of the triangular factor
                    // is zero, i.e. `a` is rank deficient.
                    Err(LeastSquaresError::Degenerate)
                }
            }

            fn compute_multi_degenerate_into<D1, D2>(
                mut a: ArrayBase<D1, Ix2>,
                mut b: ArrayBase<D2, Ix2>)
                -> Result<LeastSquaresSolution<Self, Ix2>, LeastSquaresError>
                where D1: DataMut<Elem=Self> + DataOwned<Elem = Self>,
                      D2: DataMut<Elem=Self> + DataOwned<Elem = Self> {
                let a_dim = a.dim();
                let b_dim = b.dim();

                // Same buffer requirement as the full-rank solver: `b` must
                // have max(m, n) rows. Pad underdetermined systems up front.
                let work_rows = cmp::max(a_dim.0, a_dim.1);
                if b_dim.0 == a_dim.0 && b_dim.0 < work_rows {
                    return Self::compute_multi_degenerate_into(a, resize_solution(b, work_rows));
                }
                if b_dim.0 != work_rows {
                    return Err(LeastSquaresError::InconsistentDimensions(a_dim.0, b_dim.0));
                }

                let (a_slice, layout, lda) = match slice_and_layout_mut(&mut a) {
                    Some(x) => x,
                    None => return Err(LeastSquaresError::BadLayout)
                };

                // Output buffers filled by LAPACK: singular values and rank.
                let mut svs: Array<$sv_type, Ix1> = Array::default(cmp::min(a_dim.0, a_dim.1));
                let mut rank: i32 = 0;

                let info = {
                    let (b_slice, ldb) = match slice_and_layout_matching_mut(&mut b, layout) {
                        Some(x) => x,
                        None => return Err(LeastSquaresError::InconsistentLayout)
                    };

                    $degen_func(layout, a_dim.0 as i32 , a_dim.1 as i32, b_dim.1 as i32,
                                a_slice, lda as i32,
                                b_slice, ldb as i32,
                                svs.as_slice_mut().unwrap(), 0.0,
                                &mut rank)
                };

                if info == 0 {
                    Ok(LeastSquaresSolution {
                        solution: resize_solution(b, a_dim.1),
                        rank: rank as usize })
                } else if info < 0 {
                    Err(LeastSquaresError::IllegalParameter(-info))
                } else {
                    // info > 0 is a documented outcome: the SVD iteration
                    // did not converge. No existing `LeastSquaresError`
                    // variant describes this, so keep aborting, but say why
                    // instead of claiming the branch is unreachable.
                    // TODO(review): add a dedicated error variant and return it.
                    panic!("{} failed to converge (info = {})",
                           stringify!($degen_func), info);
                }
            }
        }
    )
}
// Instantiate `LeastSquares` for each supported scalar type. Arguments are:
// scalar type, real type of the singular values, full-rank LAPACK routine,
// rank-deficient LAPACK routine.
impl_least_squares!(f32, f32, sgels, sgelsd);
impl_least_squares!(f64, f64, dgels, dgelsd);
impl_least_squares!(c32, f32, cgels, cgelsd);
impl_least_squares!(c64, f64, zgels, zgelsd);