mdarray-linalg-faer 0.1.2

Faer backend for mdarray-linalg
Documentation
use dyn_stack::{MemBuffer, MemStack};
use faer_traits::ComplexField;
use mdarray::{DSlice, Layout};
use num_complex::ComplexFloat;

use crate::into_faer_mut;

/// Computes the LU factorization of `a` with partial (row) pivoting using faer,
/// and unpacks the factors into the caller-provided output matrices.
///
/// Let `a` have shape `(m, n)` and `k = min(m, n)`:
///
/// * `a` is factored in place by faer's `lu_in_place`, so its contents are
///   overwritten with the packed factors. The code below reads `L`'s multipliers
///   from the strictly-lower part and `U` from the upper part (diagonal included).
/// * `l_mda` receives `L` in its leading `m x k` block: ones on the diagonal,
///   the multipliers below it, zeros above it.
/// * `u_mda` receives `U` in its leading `k x n` block: the upper-triangular
///   factor, with zeros below the diagonal.
/// * `p_mda` receives an `m x m` permutation matrix. Row `i` has its single one
///   in column `row_perm_fwd[i]`, where `row_perm_fwd` is the forward row
///   permutation reported by faer.
///
/// NOTE(review): under faer's convention for forward permutation arrays this
/// should satisfy `P * A = L * U` (rather than `A = P * L * U`). Confirm this
/// against the faer documentation before relying on it.
///
/// Entries of the outputs outside the blocks listed above are left untouched.
///
/// # Panics
///
/// Panics if an output matrix is too small to hold its factor:
/// * `l_mda` needs at least `m x k`;
/// * `u_mda` needs at least `k x n`;
/// * `p_mda` needs at least `m x m`.
///
/// These checks run before `a` is modified.
pub fn lu_faer<
    T: ComplexFloat + ComplexField + Default + 'static,
    La: Layout,
    Ll: Layout,
    Lu: Layout,
    Lp: Layout,
>(
    a: &mut DSlice<T, 2, La>,
    l_mda: &mut DSlice<T, 2, Ll>,
    u_mda: &mut DSlice<T, 2, Lu>,
    p_mda: &mut DSlice<T, 2, Lp>,
) {
    let (m, n) = *a.shape();
    let min_mn = m.min(n);

    // Validate output shapes up front. A mis-sized output then fails with a clear
    // message instead of an index panic after `a` has already been overwritten.
    let (l_rows, l_cols) = *l_mda.shape();
    assert!(
        l_rows >= m && l_cols >= min_mn,
        "lu_faer: L must be at least {m}x{min_mn}, got {l_rows}x{l_cols}"
    );
    let (u_rows, u_cols) = *u_mda.shape();
    assert!(
        u_rows >= min_mn && u_cols >= n,
        "lu_faer: U must be at least {min_mn}x{n}, got {u_rows}x{u_cols}"
    );
    let (p_rows, p_cols) = *p_mda.shape();
    assert!(
        p_rows >= m && p_cols >= m,
        "lu_faer: P must be at least {m}x{m}, got {p_rows}x{p_cols}"
    );

    let par = faer::get_global_parallelism();
    let mut lu_mat = into_faer_mut(a);

    // faer writes both directions of the row permutation. Only the forward array
    // is used below; the backward one is required as an output buffer.
    let mut row_perm_fwd = vec![0usize; m];
    let mut row_perm_bwd = vec![0usize; m];

    // Factor directly in `a`'s storage to avoid the copy faer would otherwise
    // make internally.
    let scratch_req = faer::linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, T>(
        m,
        n,
        par,
        faer::prelude::default(),
    );
    let mut scratch = MemBuffer::new(scratch_req);
    faer::linalg::lu::partial_pivoting::factor::lu_in_place(
        lu_mat.as_mut(),
        &mut row_perm_fwd,
        &mut row_perm_bwd,
        par,
        MemStack::new(&mut scratch),
        faer::prelude::default(),
    );

    // L (m x k): unit diagonal, multipliers below, zeros above.
    let mut l_faer = into_faer_mut(l_mda);
    for i in 0..m {
        for j in 0..min_mn {
            l_faer[(i, j)] = if i == j {
                T::one()
            } else if i > j {
                lu_mat[(i, j)]
            } else {
                T::zero()
            };
        }
    }

    // U (k x n): upper triangle (diagonal included) from the packed factors, zeros below.
    let mut u_faer = into_faer_mut(u_mda);
    for i in 0..min_mn {
        for j in 0..n {
            u_faer[(i, j)] = if i <= j { lu_mat[(i, j)] } else { T::zero() };
        }
    }

    // P (m x m): clear each row, then place its single one at the column given
    // by the forward permutation.
    let mut p_faer = into_faer_mut(p_mda);
    for (i, &col) in row_perm_fwd.iter().enumerate() {
        for j in 0..m {
            p_faer[(i, j)] = T::zero();
        }
        p_faer[(i, col)] = T::one();
    }
}