petal-decomposition 0.6.2

Matrix decomposition algorithms including PCA (principal component analysis) and ICA (independent component analysis)
Documentation
use lair::Scalar;
use num_complex::{Complex32, Complex64};
use num_traits::{ToPrimitive, Zero};
use std::cmp;

pub(super) type SvdOutput<A> = (Vec<<A as Scalar>::Real>, Vec<A>, Option<Vec<A>>);

pub trait Lapack: Scalar {
    unsafe fn gelqf(m: i32, n: i32, a: &mut [Self], ld_a: i32) -> Result<Vec<Self>, i32>;

    unsafe fn gesdd(m: i32, n: i32, a: &mut [Self]) -> Result<SvdOutput<Self>, i32>;

    unsafe fn gesvd(jobvt: u8, m: i32, n: i32, a: &mut [Self]) -> Result<SvdOutput<Self>, i32>;

    unsafe fn heev(
        jobz: u8,
        uplo: u8,
        n: i32,
        a: &mut [Self],
        ld_a: i32,
    ) -> Result<Vec<Self::Real>, i32>;

    unsafe fn unglq(
        m: i32,
        n: i32,
        k: i32,
        a: &mut [Self],
        ld_a: i32,
        tau: &[Self],
    ) -> Result<(), i32>;
}

macro_rules! impl_lapack {
    (@real, $scalar:ty, $heev:path, $gelqf:path, $gesdd:path, $gesvd:path, $unglq:path) => {
        impl_lapack!(@body, $scalar, $heev, $gelqf, $gesdd, $gesvd, $unglq, );
    };
    (@complex, $scalar:ty, $heev:path, $gelqf:path, $gesdd:path, $gesvd:path, $unglq:path) => {
        impl_lapack!(@body, $scalar, $heev, $gelqf, $gesdd, $gesvd, $unglq, rwork);
    };
    (@body, $scalar:ty, $heev:path, $gelqf:path, $gesdd:path, $gesvd:path, $unglq:path, $($rwork_ident:ident),*) => {
        impl Lapack for $scalar {
            unsafe fn gelqf(m: i32, n: i32, mut a: &mut [Self], ld_a: i32) -> Result<Vec<Self>, i32> {
                let k = cmp::min(m, n);
                let mut tau = vec_uninit(k as usize);

                let mut info = 0;
                let mut work_size = [Self::zero()];
                $gelqf(m, n, &mut a, ld_a, &mut tau, &mut work_size, -1, &mut info);
                if info != 0 {
                    return Err(info);
                }

                let lwork = work_size[0].to_usize().expect("valid integer");
                let mut work = vec_uninit(lwork);
                $gelqf(m, n, &mut a, ld_a, &mut tau, &mut work, lwork as i32, &mut info);
                if info != 0 {
                    return Err(info);
                }

                Ok(tau)
            }

            unsafe fn gesdd(m: i32, n: i32, mut a: &mut [Self]) -> Result<SvdOutput<Self>, i32> {
                let k = cmp::min(m, n);
                let mut s = vec_uninit(k as usize);
                let mut u = vec_uninit((n * k) as usize);
                let mut vt = vec_uninit((k * m) as usize);

                $(
                let mut $rwork_ident = {
                    let max_dim = cmp::max(m, n) as usize;
                    let min_dim = k as usize;
                    let lwork = cmp::max(5 * min_dim * min_dim + 5 * max_dim + min_dim, 2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim);
                    vec_uninit(lwork as usize)
                };
                )*

                let mut info = 0;
                let mut iwork = vec_uninit(8 * k as usize);
                let mut work_size = [Self::zero()];
                $gesdd(b'S', n, m, &mut a, n, &mut s, &mut u, n, &mut vt, k, &mut work_size, -1, $(&mut $rwork_ident,)* &mut iwork, &mut info);
                if info != 0 {
                    return Err(info);
                }

                let lwork = work_size[0].to_usize().expect("valid integer");
                let mut work = vec_uninit(lwork);
                $gesdd(b'S', n, m, &mut a, n, &mut s, &mut u, n, &mut vt, k, &mut work, lwork as i32, $(&mut $rwork_ident,)* &mut iwork, &mut info);
                if info != 0 {
                    return Err(info);
                }

                Ok((s, vt, Some(u)))
            }

            unsafe fn gesvd(jobvt: u8, m: i32, n: i32, mut a: &mut [Self]) -> Result<SvdOutput<Self>, i32> {
                let k = cmp::min(m, n);
                let mut s = vec_uninit(k as usize);
                let mut u = if jobvt == b'A' {
                    vec_uninit((n * n) as usize)
                } else {
                    Vec::new()
                };
                let mut vt = vec_uninit((m * m) as usize);

                $(
                let mut $rwork_ident = vec_uninit(5 * k as usize);
                )*

                let mut info = 0;
                let mut work_size = [Self::zero()];
                $gesvd(jobvt, b'A', n, m, &mut a, n, &mut s, &mut u, n, &mut vt, m, &mut work_size, -1, $(&mut $rwork_ident,)* &mut info);
                if info != 0 {
                    return Err(info);
                }

                let lwork = work_size[0].to_usize().expect("valid integer");
                let mut work = vec_uninit(lwork);
                $gesvd(jobvt, b'A', n, m, &mut a, n, &mut s, &mut u, n, &mut vt, m, &mut work, lwork as i32, $(&mut $rwork_ident,)* &mut info);
                if info != 0 {
                    return Err(info);
                }

                Ok((s, vt, if jobvt == b'A' { Some(u) } else { None }))
            }

            unsafe fn heev(
                jobz: u8,
                uplo: u8,
                n: i32,
                mut a: &mut [Self],
                ld_a: i32,
            ) -> Result<Vec<Self::Real>, i32> {
                let mut eigs = vec_uninit(n as usize);

                $(
                let mut $rwork_ident = vec_uninit(3 * n as usize - 2 as usize);
                )*

                let mut info = 0;
                let mut work_size = [Self::zero()];
                    $heev(
                        jobz,
                        uplo,
                        n,
                        &mut a,
                        ld_a,
                        &mut eigs,
                        &mut work_size,
                        -1,
                        $(&mut $rwork_ident,)*
                        &mut info,
                    );
                if info != 0 {
                    return Err(info);
                }

                let lwork = work_size[0].to_usize().expect("valid integer");
                let mut work = vec_uninit(lwork);
                    $heev(
                        jobz,
                        uplo,
                        n,
                        &mut a,
                        n,
                        &mut eigs,
                        &mut work,
                        lwork as i32,
                        $(&mut $rwork_ident,)*
                        &mut info,
                    );
                if info != 0 {
                    return Err(info);
                }

                Ok(eigs)
            }

            unsafe fn unglq(m: i32, n: i32, k: i32, mut a: &mut [Self], ld_a: i32, tau: &[Self]) -> Result<(), i32> {
                let mut info = 0;
                let mut work_size = [Self::zero()];
                $unglq(m, n, k, &mut a, ld_a, &tau, &mut work_size, -1, &mut info);
                if info != 0 {
                    return Err(info);
                }

                let lwork = work_size[0].to_usize().expect("valid integer");
                let mut work = vec_uninit(lwork);
                $unglq(m, n, k, &mut a, ld_a, &tau, &mut work, lwork as i32, &mut info);
                if info != 0 {
                    return Err(info);
                }

                Ok(())
            }
        }
    };
}

impl_lapack!(@real, f32, lapack::ssyev, lapack::sgelqf, lapack::sgesdd, lapack::sgesvd, lapack::sorglq);
impl_lapack!(@real, f64, lapack::dsyev, lapack::dgelqf, lapack::dgesdd, lapack::dgesvd, lapack::dorglq);
impl_lapack!(@complex, Complex32, lapack::cheev, lapack::cgelqf, lapack::cgesdd, lapack::cgesvd, lapack::cunglq);
impl_lapack!(@complex, Complex64, lapack::zheev, lapack::zgelqf, lapack::zgesdd, lapack::zgesvd, lapack::zunglq);

/// Creates a vector with uninitialized elements.
///
/// # Safety
///
/// The caller must ensure that the elements of the vector are not used before
/// they are initialized.
#[allow(clippy::uninit_vec)]
unsafe fn vec_uninit<T: Sized>(n: usize) -> Vec<T> {
    let mut v = Vec::with_capacity(n);
    v.set_len(n);
    v
}