#![allow(non_snake_case)]
use libnum::{Num, zero, one, Zero, One};
use libnum::Float;
use libnum::Complex;
use std::ops::{Add, Sub, Mul, Div};
use super::{Array, Ix};
pub type Col<A> = Array<A, Ix>;
pub type Mat<A> = Array<A, (Ix, Ix)>;
pub trait Ring : Clone + Zero + Add<Output=Self> + Sub<Output=Self>
+ One + Mul<Output=Self> { }
impl<A: Clone + Zero + Add<Output=A> + Sub<Output=A> + One + Mul<Output=A>> Ring for A { }
pub trait Field : Ring + Div<Output=Self> { }
impl<A: Ring + Div<Output=A>> Field for A { }
pub trait ComplexField : Copy + Field
{
#[inline]
fn conjugate(self) -> Self { self }
fn sqrt_real(self) -> Self;
#[inline]
fn is_complex() -> bool { false }
}
impl ComplexField for f32
{
#[inline]
fn sqrt_real(self) -> f32 { self.sqrt() }
}
impl ComplexField for f64
{
#[inline]
fn sqrt_real(self) -> f64 { self.sqrt() }
}
#[cfg(not(nocomplex))]
impl<A: Num + Float> ComplexField for Complex<A>
{
#[inline]
fn conjugate(self) -> Complex<A> { self.conj() }
fn sqrt_real(self) -> Complex<A> { Complex::new(self.re.sqrt(), zero()) }
#[inline]
fn is_complex() -> bool { true }
}
pub fn eye<A: Clone + Zero + One>(n: Ix) -> Mat<A>
{
let mut eye = Array::zeros((n, n));
for a_ii in eye.diag_iter_mut() {
*a_ii = one::<A>();
}
eye
}
pub fn least_squares<A: ComplexField>(a: &Mat<A>, b: &Col<A>) -> Col<A>
{
let mut aT = a.clone();
aT.swap_axes(0, 1);
if <A as ComplexField>::is_complex() {
for elt in aT.iter_mut() {
*elt = elt.conjugate();
}
}
let aT_a = aT.mat_mul(a);
let mut L = cholesky(aT_a);
let rhs = aT.mat_mul_col(b);
let z = subst_fw(&L, &rhs);
if <A as ComplexField>::is_complex() {
let (m, _) = L.dim();
for i in 1..m {
for j in 0..i {
let elt = &mut L[(i, j)];
*elt = elt.conjugate();
}
}
}
L.swap_axes(0, 1);
subst_bw(&L, &z)
}
pub fn cholesky<A: ComplexField>(a: Mat<A>) -> Mat<A>
{
let z = zero::<A>();
let (m, n) = a.dim();
assert!(m == n);
let mut L = a;
for i in 0..m {
for j in 0..i {
let mut lik_ljk_sum = z;
{
let Lik = L.row_iter(i);
let Ljk = L.row_iter(j);
for (&lik, &ljk) in Lik.zip(Ljk).take(j as usize) {
lik_ljk_sum = lik_ljk_sum + lik * ljk.conjugate();
}
}
L[(i, j)] = (L[(i, j)] - lik_ljk_sum) / L[(j, j)];
}
let j = i;
let mut ljk_sum = z;
for &ljk in L.row_iter(j).take(j as usize) {
ljk_sum = ljk_sum + ljk * ljk.conjugate();
}
L[(j, j)] = (L[(j, j)] - ljk_sum).sqrt_real();
for j in i + 1..n {
L[(i, j)] = z;
}
}
L
}
fn vec_elem<A: Copy>(elt: A, n: usize) -> Vec<A>
{
let mut v = Vec::with_capacity(n);
for _ in 0..n {
v.push(elt);
}
v
}
pub fn subst_fw<A: Copy + Field>(l: &Mat<A>, b: &Col<A>) -> Col<A>
{
let (m, n) = l.dim();
assert!(m == n);
assert!(m == b.dim());
let mut x = vec_elem(zero::<A>(), m as usize);
for (i, bi) in b.indexed_iter() {
let mut b_lx_sum = *bi;
for (lij, xj) in l.row_iter(i).zip(x.iter()).take(i as usize) {
b_lx_sum = b_lx_sum - (*lij) * (*xj)
}
x[i as usize] = b_lx_sum / l[(i, i)];
}
Array::from_vec(x)
}
pub fn subst_bw<A: Copy + Field>(u: &Mat<A>, b: &Col<A>) -> Col<A>
{
let (m, n) = u.dim();
assert!(m == n);
assert!(m == b.dim());
let mut x = vec_elem(zero::<A>(), m as usize);
for i in (0..m).rev() {
let mut b_ux_sum = b[i];
for (uij, xj) in u.row_iter(i).rev().zip(x.iter().rev()).take((m - i - 1) as usize) {
b_ux_sum = b_ux_sum - (*uij) * (*xj);
}
x[i as usize] = b_ux_sum / u[(i, i)];
}
Array::from_vec(x)
}