// ndarray-inverse 0.1.9
//
// Pure-Rust inverse and determinant traits for ndarray's `Array2`.
#![allow(non_snake_case)]
//use num_traits::Float;
use ndarray::prelude::*;
use ndarray_inverse::*;
use ndarray::Zip;

/// Demo driver. It LU-factorizes a fixed 6×6 matrix and prints the `L`, `U` and
/// `P` factors. It then prints the inverse computed two ways so they can be
/// compared: this file's `inverse` function and the `inv()` method.
///
/// If `lu_decomp` returns `None`, only the header line is printed.
fn main() {
    println!("\nExample 2:");

    let matrix: Array2<f64> = array![
        [1.0, 1.0, 3.0, 4.0, 9.0, 3.0],
        [10.0, 10.0, 1.0, 2.0, 2.0, 5.0],
        [2.0, 9.0, 6.0, 10.0, 10.0, 9.0],
        [10.0, 9.0, 9.0, 7.0, 3.0, 6.0],
        [7.0, 6.0, 6.0, 2.0, 9.0, 5.0],
        [3.0, 8.0, 1.0, 4.0, 1.0, 5.0]
    ];

    let factors = lu_decomp(&matrix);
    if let Some((lower, upper, perm)) = factors {
        println!("L \n {}", lower);
        println!("U \n {}", upper);
        println!("P \n {}", perm);

        // The same quantity computed two ways: the local LU-based routine,
        // then the `inv()` method (not defined in this file; presumably from
        // the `ndarray_inverse` glob import).
        let by_lu = inverse(&matrix);
        println!("inverse \n{:?}", by_lu);
        let by_method = matrix.inv();
        println!("inv \n{:?}", by_method);
    }
}

/// LU decomposition with partial (row) pivoting of a square matrix.
///
/// Returns `Some((L, U, P))` with `P · A = L · U`:
/// - `L` is unit lower-triangular.
/// - `U` is upper-triangular.
/// - `P` is a permutation matrix.
///
/// For column `k`, the pivot is the entry of largest absolute value in rows
/// `k..n` of the *partially eliminated* matrix, chosen during elimination.
/// Choosing pivots from the original `A` up front is not enough: it can hit a
/// zero pivot on a non-singular matrix such as `[[1, 1, 0], [1, 1, 1], [0, 1, 1]]`.
/// For a singular matrix a valid factorization is still returned, and `U` then
/// has at least one zero on its diagonal.
///
/// Returns `None` if the input contains, or elimination produces, a NaN or
/// infinite value.
///
/// # Panics
///
/// Panics if `A` is not square.
fn lu_decomp<T: NdFloat>(A: &Array2<T>) -> Option<(Array2<T>, Array2<T>, Array2<T>)> {
    // Swaps rows `ir1` and `ir2` of `A` in place. The rows must differ,
    // because `multi_slice_mut` rejects overlapping slices.
    fn swap<T: NdFloat>(A: &mut Array2<T>, ir1: usize, ir2: usize) {
        let (r0, r1) = A.multi_slice_mut((s![ir1, ..], s![ir2, ..]));
        Zip::from(r0).and(r1).for_each(std::mem::swap);
    }

    let (n, cols) = A.dim();
    assert_eq!(n, cols, "LU decomposition must take a square matrix.");

    // Elimination runs in place on a copy. Afterwards the upper triangle
    // (including the diagonal) holds U and the strict lower triangle holds
    // the multipliers of L.
    let mut work = A.to_owned();
    // perm[i] is the row of A that ends up in row i of P·A.
    let mut perm: Vec<usize> = (0..n).collect();

    for k in 0..n {
        // Partial pivoting on the current (partially eliminated) column k.
        let mut p = k;
        for i in k + 1..n {
            if work[[p, k]].abs() < work[[i, k]].abs() {
                p = i;
            }
        }
        if p != k {
            // Swap whole rows so already-stored multipliers move with their row.
            swap(&mut work, k, p);
            perm.swap(k, p);
        }

        let pivot = work[[k, k]];
        if pivot == T::zero() {
            // The largest remaining |entry| in this column is zero, so nothing
            // needs eliminating. The multipliers stay 0 and U gets a zero pivot.
            // Any NaNs left below are caught by the finiteness check further down.
            continue;
        }
        for i in k + 1..n {
            let factor = work[[i, k]] / pivot;
            work[[i, k]] = factor;
            for j in k + 1..n {
                let updated = work[[i, j]] - factor * work[[k, j]];
                work[[i, j]] = updated;
            }
        }
    }

    // Non-finite input values stay non-finite through swaps and updates,
    // and overflow during elimination shows up here as well.
    if work.iter().any(|x| !x.is_finite()) {
        return None;
    }

    let mut L: Array2<T> = Array::eye(n);
    let mut U: Array2<T> = Array::zeros((n, n));
    for ((r, c), &v) in work.indexed_iter() {
        if r > c {
            L[[r, c]] = v;
        } else {
            U[[r, c]] = v;
        }
    }

    // Row i of P is row perm[i] of the identity, so (P·A)[i, ..] = A[perm[i], ..].
    let identity: Array2<T> = Array::eye(n);
    let P = identity.select(Axis(0), &perm);

    Some((L, U, P))
}

/// Computes the inverse of a square `f64` matrix from its LU decomposition.
///
/// `lu_decomp` gives `P·s = L·U`, so `s⁻¹ = U⁻¹ · L⁻¹ · P`. Both triangular
/// inverses are formed by forward substitution.
///
/// Returns `None` in three cases:
/// - `lu_decomp` returns `None`.
/// - `U` has a zero on its diagonal, i.e. `s` is singular.
/// - The computed inverse contains a NaN or infinite value (e.g. overflow on a
///   nearly singular input).
///
/// Before these checks, a singular matrix could yield `Some` filled with
/// inf/NaN.
///
/// # Panics
///
/// Panics if `s` is not square.
fn inverse(s: &Array2<f64>) -> Option<Array2<f64>> {
    // Inverts the `n`×`n` lower-triangular matrix `l` by forward substitution.
    // Only the diagonal and strict lower triangle of `l` are read. A zero
    // diagonal entry yields inf/NaN, so callers must rule that out first.
    fn linv(l: &Array2<f64>, n: usize) -> Array2<f64> {
        let mut m: Array2<f64> = Array2::zeros((n, n));

        for i in 0..n {
            m[(i, i)] = 1.0 / l[(i, i)];

            for j in 0..i {
                // m[i, j] = -(sum_{k=j}^{i-1} l[i, k] * m[k, j]) / l[i, i]
                let mut acc = 0.0;
                for k in j..i {
                    acc += l[(i, k)] * m[(k, j)];
                }
                m[(i, j)] = -acc / l[(i, i)];
            }
        }

        m
    }

    // Inverts an upper-triangular matrix using U⁻¹ = ((Uᵀ)⁻¹)ᵀ, where Uᵀ is
    // lower-triangular.
    fn uinv(u: &Array2<f64>, n: usize) -> Array2<f64> {
        linv(&u.t().to_owned(), n).t().to_owned()
    }

    let (n, cols) = s.dim();
    assert_eq!(n, cols, "inverse requires a square matrix.");

    let (l, u, p) = lu_decomp(s)?;

    // A zero pivot in U means `s` is singular; inverting U would divide by zero.
    if u.diag().iter().any(|&pivot| pivot == 0.0) {
        return None;
    }

    let inv = uinv(&u, n).dot(&linv(&l, n)).dot(&p);

    // A nearly singular input can still overflow to inf/NaN.
    if inv.iter().all(|x| x.is_finite()) {
        Some(inv)
    } else {
        None
    }
}