// lair 0.6.0
//
// Linear algebra in Rust
// Documentation
use crate::{lapack, Scalar};
use ndarray::{s, ArrayBase, Axis, Data, DataMut, Ix1, Ix2};

/// Expands stored elementary reflectors into the explicit matrix Q, in place.
///
/// On entry, `a` and `tau` are expected to hold the reflectors in the layout
/// consumed by `lapack::ungqr` (presumably the output of a QR factorization,
/// LAPACK `?geqrf` style — confirm with callers). On return, `a` holds the first
/// `a.ncols()` columns of Q.
///
/// An empty `a` is left untouched and no checks are performed on it.
///
/// # Panics
///
/// Panics if `a` has fewer rows than columns, or if `tau` has more elements
/// than `a` has columns.
// NOTE(review): dead_code is allowed presumably because nothing in the crate
// calls this yet; drop the attribute once it has a caller.
#[allow(dead_code)]
pub fn q_tall<A, SA, ST>(a: &mut ArrayBase<SA, Ix2>, tau: &ArrayBase<ST, Ix1>)
where
    A: Scalar,
    SA: DataMut<Elem = A>,
    ST: Data<Elem = A>,
{
    if !a.is_empty() {
        // Shape preconditions: the matrix must be tall (or square), and there
        // cannot be more reflectors than columns.
        assert!(a.nrows() >= a.ncols());
        assert!(a.ncols() >= tau.len());
        lapack::ungqr(a, tau);
    }
}

/// Expands stored row reflectors into the explicit square unitary factor P^H,
/// in place.
///
/// The steps mirror LAPACK `?ungbr` with `VECT = 'P'` for a square matrix:
///
/// 1. The first row and first column of the result are set to those of the
///    identity.
/// 2. In every other column, the entries above the diagonal are moved one row
///    down, so the reflectors sit in the trailing `(n-1)x(n-1)` block.
/// 3. `lapack::unglq` is applied to that trailing block with the first `n-1`
///    entries of `tau`. The last entry of `tau` is not read.
///
/// NOTE(review): `a`/`tau` presumably come from a bidiagonal reduction
/// (`?gebrd`); confirm with callers.
///
/// An empty `a` is left untouched and no checks are performed on it.
///
/// # Panics
///
/// Panics if `a` is not square, or if `tau.len()` differs from `a.ncols()`.
// NOTE(review): dead_code is allowed presumably because nothing in the crate
// calls this yet; drop the attribute once it has a caller.
#[allow(dead_code)]
pub fn p_square<A, SA, ST>(a: &mut ArrayBase<SA, Ix2>, tau: &ArrayBase<ST, Ix1>)
where
    A: Scalar,
    SA: DataMut<Elem = A>,
    ST: Data<Elem = A>,
{
    if a.is_empty() {
        return;
    }

    let n = a.ncols();
    assert_eq!(a.nrows(), n);
    assert_eq!(n, tau.len());

    // First column becomes e_0.
    a.column_mut(0).fill(A::zero());
    a[[0, 0]] = A::one();

    // For each remaining column j, slide rows 0..j-1 down by one (walking
    // bottom-up so nothing is overwritten before it is copied), then clear
    // row 0. This also zeroes the first row of the result.
    for (j, mut col) in a.axis_iter_mut(Axis(1)).enumerate().skip(1) {
        for k in (0..j - 1).rev() {
            col[k + 1] = col[k];
        }
        col[0] = A::zero();
    }

    // A 1x1 matrix has no trailing block left to expand.
    if n > 1 {
        let mut trailing = a.slice_mut(s![1.., 1..]);
        lapack::unglq(&mut trailing, &tau.slice(s![..-1]));
    }
}

#[cfg(test)]
mod tests {
    use ndarray::{arr1, arr2};
    use num_complex::Complex32;

    /// Shorthand for a single-precision complex literal.
    fn c(re: f32, im: f32) -> Complex32 {
        Complex32::new(re, im)
    }

    /// One reflector stored in a 3x2 matrix expands to the first two columns
    /// of `I - tau * v * v^H`, where `v = [1, -1-2i, 3-2i]` (unit leading entry
    /// plus the sub-diagonal of column 0).
    #[test]
    fn q_tall() {
        let tau = arr1(&[c(1., -1.)]);
        let mut a = arr2(&[
            [c(2., 1.), c(-3., 1.)],
            [c(-1., -2.), c(1., 3.)],
            [c(3., -2.), c(-2., -1.)],
        ]);

        super::q_tall(&mut a, &tau);

        let expected = arr2(&[
            [c(0., 1.), c(-1., -3.)],
            [c(3., 1.), c(-4., 5.)],
            [c(-1., 5.), c(-9., -7.)],
        ]);
        assert_eq!(a, expected);
    }

    /// For a 2x2 input only the first `tau` entry is used. The trailing 1x1
    /// block becomes `1 - conj(tau[0]) = -i`, and the rest is the identity.
    #[test]
    fn p_square() {
        let tau = arr1(&[c(1., -1.), c(-2., 3.)]);
        let mut a = arr2(&[
            [c(2., 1.), c(1., 3.)],
            [c(-1., -2.), c(-2., -1.)],
        ]);

        super::p_square(&mut a, &tau);

        let expected = arr2(&[
            [c(1., 0.), c(0., 0.)],
            [c(0., 0.), c(0., -1.)],
        ]);
        assert_eq!(a, expected);
    }
}