#![allow(non_snake_case)]
extern crate ndarray;
extern crate num_traits;
extern crate num_complex;
use num_traits::{Num, Zero, One};
use num_traits::Float;
use num_complex::Complex;
use std::ops::{Add, Sub, Mul, Div};
use ndarray::{RcArray, Ix1, Ix2};
use ndarray::{rcarr1, rcarr2};
use ndarray::LinalgScalar;
/// Vector type used throughout this file: an `RcArray` with one axis (`Ix1`).
pub type Col<A> = RcArray<A, Ix1>;
/// Matrix type used throughout this file: an `RcArray` with two axes (`Ix2`).
pub type Mat<A> = RcArray<A, Ix2>;
/// Marker trait bundling the bounds of an algebraic ring: a cloneable type
/// with additive and multiplicative identities that is closed under `+`,
/// `-` and `*`.
///
/// The trait declares no methods of its own.
pub trait Ring:
    Clone + Zero + One + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self>
{
}

// Blanket implementation: any type meeting the bounds is a `Ring`.
impl<T> Ring for T
where
    T: Clone + Zero + One + Add<Output = T> + Sub<Output = T> + Mul<Output = T>,
{
}
/// Marker trait for a `Ring` that is also closed under division.
///
/// The trait declares no methods of its own.
pub trait Field: Ring + Div<Output = Self> {}

// Blanket implementation: any `Ring` with a matching `Div` is a `Field`.
impl<T> Field for T where T: Ring + Div<Output = T> {}
/// Scalar operations required by the solvers in this file, for element
/// types that may be either real or complex.
///
/// The defaults describe a real type: `is_complex` is `false` and
/// `conjugate` is the identity. Only `sqrt_real` must be supplied.
pub trait ComplexField: LinalgScalar {
    /// Reports whether the implementing type is a complex number type.
    #[inline]
    fn is_complex() -> bool {
        false
    }

    /// Returns the conjugate of the value; by default the value itself.
    #[inline]
    fn conjugate(self) -> Self {
        self
    }

    /// Square root used for the diagonal of the Cholesky factor.
    ///
    /// The real implementations in this file take the square root of the
    /// value; the `Complex` implementation takes it of the real part only.
    fn sqrt_real(self) -> Self;
}
// `f32` is a real scalar: the trait defaults for `conjugate` and
// `is_complex` apply, so only the square root is provided.
impl ComplexField for f32 {
    #[inline]
    fn sqrt_real(self) -> Self {
        f32::sqrt(self)
    }
}
// `f64` is a real scalar: the trait defaults for `conjugate` and
// `is_complex` apply, so only the square root is provided.
impl ComplexField for f64 {
    #[inline]
    fn sqrt_real(self) -> Self {
        f64::sqrt(self)
    }
}
// Complex scalars override every default of `ComplexField`.
impl<A> ComplexField for Complex<A>
where
    A: LinalgScalar + Float + Num,
{
    /// Complex numbers always report `true`.
    #[inline]
    fn is_complex() -> bool {
        true
    }

    /// Complex conjugate, delegated to `Complex::conj`.
    #[inline]
    fn conjugate(self) -> Self {
        self.conj()
    }

    /// Square root of the real part; the imaginary part of the input is
    /// ignored and the imaginary part of the result is zero.
    fn sqrt_real(self) -> Self {
        let root = self.re.sqrt();
        Complex::new(root, A::zero())
    }
}
fn main() {
chol();
subst();
lst_squares();
}
/// Checks `cholesky` against three precomputed factors, each compared with
/// `all_close` at a tolerance of 0.001.
fn chol() {
    // Case 1: a 2x2 input. The first matrix is built and discarded; the
    // matrix actually factorised equals its transpose times itself.
    {
        let _ = rcarr2(&[[1., 2.], [3., 4.]]);
        let gram = rcarr2(&[[10., 14.], [14., 20.]]);
        let chol = cholesky(gram);
        let ans = rcarr2(&[
            [3.16227770, 0.00000000],
            [4.42718887, 0.63245525],
        ]);
        assert!(ans.all_close(&chol, 0.001));
    }

    // Case 2: a 3x3 `f32` input formed as (transpose of grid) * grid, where
    // `grid` holds the values 0..=8 laid out as a 3x3 matrix.
    {
        let grid = RcArray::linspace(0f32, 8., 9).reshape((3, 3));
        let mut grid_t = grid.clone();
        grid_t.swap_axes(0, 1);
        let product = grid_t.dot(&grid).into_shared();
        println!("bpd=\n{:?}", product);
        let chol = cholesky(product);
        println!("chol=\n{:.8?}", chol);
        let ans = rcarr2(&[
            [6.70820379, 0.00000000, 0.00000000],
            [8.04984474, 1.09544373, 0.00000000],
            [9.39148617, 2.19088745, 0.00000000],
        ]);
        assert!(ans.all_close(&chol, 0.001));
    }

    // Case 3: a fixed symmetric 3x3 input.
    {
        let symmetric = rcarr2(&[
            [0.05201001, 0.22982409, 0.1014132],
            [0.22982409, 1.105822, 0.37946544],
            [0.1014132, 0.37946544, 1.16199134],
        ]);
        let chol = cholesky(symmetric);
        let ans = rcarr2(&[
            [0.22805704, 0., 0.],
            [1.00774829, 0.30044197, 0.],
            [0.44468348, -0.2285419, 0.95499557],
        ]);
        assert!(ans.all_close(&chol, 0.001));
    }
}
/// Checks `subst_fw` on a fixed lower-triangular 3x3 system with right-hand
/// side `[1, 2, 3]`, comparing against a precomputed `f32` solution at a
/// tolerance of 0.001.
fn subst() {
    // Expected solution; its explicit `f32` element type fixes the element
    // type of the whole check.
    let ans = rcarr1::<f32>(&[4.384868, -8.050947, -0.827078]);

    // Lower-triangular coefficient matrix (zeros above the diagonal).
    let lll = rcarr2(&[
        [0.22805704, 0., 0.],
        [1.00774829, 0.30044197, 0.],
        [0.44468348, -0.2285419, 0.95499557],
    ]);

    assert!(ans.all_close(&subst_fw(&lll, &rcarr1(&[1., 2., 3.])), 0.001));
}
/// Checks `least_squares` on a fixed 4x2 system, comparing the returned
/// 2-element solution against precomputed values at a tolerance of 0.001.
fn lst_squares() {
    // Four observations (rows) of two unknowns (columns).
    let design = rcarr2(&[
        [2., 3.],
        [-2., -1.],
        [1., 5.],
        [-1., 2.],
    ]);
    let observed = rcarr1(&[1., -1., 2., 1.]);

    let x_lstsq = least_squares(&design, &observed);

    let ans = rcarr1(&[0.070632, 0.390335]);
    assert!(x_lstsq.all_close(&ans, 0.001));
}
/// Solves the least-squares problem for `a x = b` through the normal
/// equations `aᴴ a x = aᴴ b`.
///
/// Steps, in order:
/// 1. form `aᴴ` (transpose, conjugated when `A` is complex);
/// 2. factor `aᴴ a` with `cholesky` into a lower-triangular `L`;
/// 3. solve `L z = aᴴ b` with `subst_fw`;
/// 4. solve `Lᴴ x = z` with `subst_bw` and return `x`.
///
/// No rank or shape validation is done here beyond what the called
/// routines perform themselves.
pub fn least_squares<A: ComplexField>(a: &Mat<A>, b: &Col<A>) -> Col<A> {
    let complex = <A as ComplexField>::is_complex();

    // Conjugate transpose of `a`; for real scalars a plain transpose.
    let mut a_h = a.clone();
    a_h.swap_axes(0, 1);
    if complex {
        for entry in a_h.iter_mut() {
            *entry = entry.conjugate();
        }
    }

    let normal = a_h.dot(a).into_shared();
    let mut factor = cholesky(normal);
    let rhs = a_h.dot(b).into_shared();
    let z = subst_fw(&factor, &rhs);

    // Turn L into Lᴴ. Only entries strictly below the diagonal are
    // conjugated: the diagonal comes from `sqrt_real`, and `cholesky`
    // zeroes everything above it.
    if complex {
        let (rows, _) = factor.dim();
        for r in 1..rows {
            for c in 0..r {
                factor[[r, c]] = factor[[r, c]].conjugate();
            }
        }
    }
    factor.swap_axes(0, 1);

    subst_bw(&factor, &z)
}
/// Computes a lower-triangular Cholesky factor `L` of `a`, row by row, so
/// that `a = L Lᴴ` for a suitable (Hermitian positive-definite) input.
///
/// The input is consumed and overwritten in place. Only its diagonal and
/// the entries below it are read; every entry above the diagonal of the
/// result is set to zero. Positive-definiteness is not checked.
///
/// # Panics
///
/// Panics if `a` is not square.
pub fn cholesky<A: ComplexField>(a: Mat<A>) -> Mat<A> {
    let zero = A::zero();
    let (m, n) = a.dim();
    assert!(m == n);

    let mut lower = a;
    for i in 0..m {
        // Below the diagonal:
        // L[i][j] = (a[i][j] - sum_{k<j} L[i][k] * conj(L[j][k])) / L[j][j]
        for j in 0..i {
            let mut acc = zero;
            for k in 0..j {
                acc = acc + lower[[i, k]] * lower[[j, k]].conjugate();
            }
            lower[[i, j]] = (lower[[i, j]] - acc) / lower[[j, j]];
        }

        // Diagonal:
        // L[i][i] = sqrt(a[i][i] - sum_{k<i} L[i][k] * conj(L[i][k]))
        let mut diag_acc = zero;
        for k in 0..i {
            let entry = lower[[i, k]];
            diag_acc = diag_acc + entry * entry.conjugate();
        }
        lower[[i, i]] = (lower[[i, i]] - diag_acc).sqrt_real();

        // Above the diagonal: clear whatever the input held there.
        for j in (i + 1)..n {
            lower[[i, j]] = zero;
        }
    }
    lower
}
/// Forward substitution: solves `l x = b` for `x`, treating `l` as lower
/// triangular.
///
/// Only the diagonal of `l` and the entries below it are read.
///
/// # Panics
///
/// Panics if `l` is not square or if its size differs from `b.len()`.
pub fn subst_fw<A: Copy + Field>(l: &Mat<A>, b: &Col<A>) -> Col<A> {
    let (m, n) = l.dim();
    assert!(m == n);
    assert!(m == b.len());

    let mut x: Col<A> = Col::zeros(m);
    for i in 0..m {
        // b[i] minus the contributions of the unknowns already solved,
        // subtracted in increasing column order.
        let residual = (0..i).fold(b[i], |acc, j| acc - l[[i, j]] * x[j]);
        x[i] = residual / l[[i, i]];
    }
    x
}
/// Backward substitution: solves `u x = b` for `x`, treating `u` as upper
/// triangular.
///
/// Rows are processed from the last to the first. Only the diagonal of `u`
/// and the entries above it are read.
///
/// # Panics
///
/// Panics if `u` is not square or if its size differs from `b.len()`.
pub fn subst_bw<A: Copy + Field>(u: &Mat<A>, b: &Col<A>) -> Col<A> {
    let (m, n) = u.dim();
    assert!(m == n);
    assert!(m == b.len());
    let mut x = Col::zeros(m);
    for i in (0..m).rev() {
        // b_ux_sum accumulates b[i] minus the already-solved unknowns.
        let mut b_ux_sum = b[i];
        // Start right of the diagonal: x[i] is not solved yet, so the term
        // u[i][i] * x[i] does not belong in the sum. Including it only
        // worked while x[i] still held its initial zero, and produced NaN
        // for a non-finite diagonal entry.
        for j in i + 1..m {
            b_ux_sum = b_ux_sum - u[[i, j]] * x[j];
        }
        x[i] = b_ux_sum / u[[i, i]];
    }
    x
}