#![allow(non_snake_case)]
//use num_traits::Float;
use ndarray::prelude::*;
use ndarray::Zip;
/// Demonstrates the LU decomposition and the triangular / full inverses on a
/// couple of small example matrices, printing every intermediate result.
fn main() {
    let examples: [(&str, Array2<f64>); 2] = [
        (
            "Example 1",
            arr2(&[
                [1.0, 3.0, 5.0],
                [2.0, 4.0, 7.0],
                [1.0, 1.0, 0.0],
            ]),
        ),
        (
            "Example 2",
            arr2(&[
                [11.0, 9.0, 24.0, 2.0],
                [1.0, 5.0, 2.0, 6.0],
                [3.0, 17.0, 18.0, 1.0],
                [2.0, 5.0, 7.0, 1.0],
            ]),
        ),
    ];

    for (name, A) in &examples {
        println!("\n{}:", name);
        println!("A \n {}", A);
        // Size the triangular inverses from the matrix itself rather than a
        // hard-coded constant.
        let n = A.nrows();
        let (L, U, P) = lu_decomp(A);
        println!("L \n {}", L);
        println!("U \n {}", U);
        println!("P \n {}", P);
        println!("linv {:?}", linv(&L, n));
        println!("uinv {:?}", uinv(&U, n));
        let inv = inverse(A);
        println!("inv {:?}", inv);
        // Inverting the inverse should give back A (up to rounding); skip it
        // instead of panicking when A is singular.
        if let Some(inv) = &inv {
            println!("inv inv {:?}", inverse(inv));
        }
    }
}
fn lu_decomp<T: NdFloat>(A: &Array2<T>) -> (Array2<T>, Array2<T>, Array2<T>) {
fn pivot<T: NdFloat>(A: &Array2<T>) -> Array2<T> {
fn swap<T: NdFloat>(A: &mut Array2<T>, ir1: usize, ir2: usize) {
let (.., mut rest) = A.view_mut().split_at(Axis(0), ir1);
let (r0, mut rest) = rest.view_mut().split_at(Axis(0), 1);
let (.., mut rest) = rest.view_mut().split_at(Axis(0), ir2 - ir1 - 1);
let (r1, ..) = rest.view_mut().split_at(Axis(0), 1);
Zip::from(r0).and(r1).for_each(std::mem::swap);
}
let n = A.raw_dim()[0];
let mut P: Array2<T> = Array::eye(n);
for (idx, col) in A.axis_iter(Axis(1)).enumerate() {
// find idx of maximum value in column i
let mut mp = idx;
for i in idx .. n {
if col[mp].abs() < col[i].abs() {
mp = i;
}
}
// swap rows if necessary
if mp != idx {
swap(&mut P, idx, mp);
}
}
P
}
let d = A.raw_dim();
let n = d[0];
assert_eq!(n, d[1], "LU decomposition must take a square matrix.");
let P = pivot(&A);
let pA = P.dot(A);
let mut L: Array2<T> = Array::eye(n);
let mut U: Array2<T> = Array::zeros((n, n));
for c in 0 .. n {
for r in 0 .. n {
let pAs = pA[[r, c]] - U.slice(s![0..r, c]).dot(&L.slice(s![r, 0..r]));
if r < c + 1 { // U
U[[r, c]] = pAs;
} else { // L
L[[r, c]] = (pAs) / U[[c, c]];
}
}
}
(L, U, P)
}
/// Inverts the leading `n x n` block of an upper-triangular matrix.
///
/// Uses column-by-column back substitution from `M · U = I`:
/// `M[i, i] = 1 / U[i, i]` and, for `j < i`,
/// `M[j, i] = -(Σ_{k=j}^{i-1} M[j, k] · U[k, i]) / U[i, i]`.
/// Columns are processed in increasing order because column `i` needs the
/// entries `M[j, k]` with `k < i`, which earlier iterations produced.
/// Entries below the diagonal of `u` are never read.
///
/// # Panics
/// Panics if a diagonal entry of `u` is exactly zero (singular matrix), or if
/// `n` exceeds the dimensions of `u`.
fn uinv(u: &Array2<f64>, n: usize) -> Array2<f64> {
    let mut m: Array2<f64> = Array2::zeros((n, n));
    // Forward order is required; the previous `n - 1 ..= 0` range was empty
    // for n > 1, and it underflowed for n == 0.
    for i in 0..n {
        let d = u[(i, i)];
        if d == 0.0 {
            panic!("uinv: matrix is singular (zero diagonal entry at index {})", i);
        }
        m[(i, i)] = 1.0 / d;
        for j in 0..i {
            let mut acc = 0.0;
            for k in j..i {
                acc += m[(j, k)] * u[(k, i)];
            }
            m[(j, i)] = -acc / d;
        }
    }
    m
}
/// Inverts the leading `n x n` block of a lower-triangular matrix.
///
/// The transpose of a lower-triangular matrix is upper triangular, and
/// `(Lᵀ)⁻¹ = (L⁻¹)ᵀ`, so this delegates to `uinv` and transposes back.
fn linv(l: &Array2<f64>, n: usize) -> Array2<f64> {
    let upper = l.t().to_owned();
    let upper_inv = uinv(&upper, n);
    // Swap the axes of the owned result in place; strides and data match
    // those of a copied transpose view.
    upper_inv.reversed_axes()
}
/// Computes the inverse of a square matrix through its LU decomposition.
///
/// With `P · A = L · U` from `lu_decomp`, the inverse is
/// `A⁻¹ = U⁻¹ · L⁻¹ · P`. The permutation must be applied; omitting it
/// yields `(P · A)⁻¹` instead.
///
/// Returns `None` when `U` has a zero or non-finite diagonal entry, i.e. the
/// matrix is singular (or the factorization broke down numerically).
///
/// # Panics
/// Panics if `s` is not square.
fn inverse(s: &Array2<f64>) -> Option<Array2<f64>> {
    let (n, cols) = s.dim();
    assert_eq!(n, cols, "inverse requires a square matrix");
    let (l, u, p) = lu_decomp(s);
    if u.diag().iter().any(|&d| d == 0.0 || !d.is_finite()) {
        return None;
    }
    let l_inv = linv(&l, n);
    let u_inv = uinv(&u, n);
    Some(u_inv.dot(&l_inv).dot(&p))
}
/*
fn inverse(s: &Array2<f64>) -> Option<Array2<f64>> {
let d = s.raw_dim();
let n = d[0];
assert!(d[0] == d[1]);
let mut inv: Array2<f64> = Array2::zeros((n, n));
let mut e: Array1<f64> = Array1::zeros(n);
let (mut l, mut u, mut p) = lu_decomp(s);
for i in 0 .. n {
e[i] = 1.0;
let col = match solve(&s, e) {
Some(col) => {
for j in 0 .. n {
inv[[j, i]] = col[j];
}
//e = col.apply(&|_| 0.0);
e = Array1::zeros(n);
},
None => return None,
};
}
Some(inv)
}
fn solve(lu: &Array2<f64>, p: &Array2<f64>, b: Array1<f64>) -> Option<Array1<f64>> {
let d = lu.raw_dim();
assert!(d[0] == d[1]);
let b = lu_forward_substitution(lu, p * b);
back_substitution(lu, b)
}
fn lu_forward_substitution<T: Float>(l: &Array2<T>, b: Array1<T>) -> Array1<T> {
let mut x = b.clone();
for i in 0 .. b.len() {
//for (i, row) in lu.row_iter().enumerate().skip(1) {
// Note that at time of writing we need raw_slice here for
// auto-vectorization to kick in
/*
let adjustment = row.raw_slice()
.iter()
.take(i)
.cloned()
.zip(x.iter().cloned())
.fold(T::zero(), |sum, (l, x)| sum + l * x);
*/
x[i] = x[i] - adjustment;
}
x
}
fn back_substitution(u: &Array2<f64>, y: Array1<f64>) -> Array1<f64> {
let n = u.raw_dim()[0];
let mut x = y;
for i in (0 .. n).rev() {
let row = u.row(i);
let divisor = unsafe { u.get_unchecked([i, i]).clone() };
let dot = {
let row_part = &row.raw_slice()[(i + 1) .. n];
let x_part = &x.data()[(i + 1) .. n];
row_part.dot(x_part)
};
x[i] = (x[i] - dot) / divisor;
}
x
}
*/