rsvd-faer 0.1.0

Randomized SVD for faer matrices in Rust
Documentation
use faer::linalg::matmul::matmul;
use faer::Parallelism;
use faer::{prelude::*, Mat};
use rand::Rng;
use rand_distr::{Distribution, Normal};

/// Randomized SVD for `faer` matrices.
///
/// Sketches the range of `a` with a Gaussian test matrix, optionally refines
/// the sketch with `q` power iterations, projects `a` onto an orthonormal
/// basis of the sketch and takes an exact thin SVD of the small projected
/// matrix.
///
/// # Arguments
///
/// * `a` - Input matrix of shape `(m, n)`.
/// * `k` - Desired target rank. The sketch has `l = min(k + p, m, n)` columns,
///   so a `k` larger than `l` is clamped to `l`.
/// * `p` - Oversampling parameter.
/// * `q` - Number of power iterations.
/// * `rng` - Random number generator used for Gaussian sampling.
/// * `par` - `faer::Parallelism` mode for matrix multiplication.
///
/// # Example
///
/// ```rust
/// use faer::{Mat, Parallelism};
/// use rand::SeedableRng;
/// use rand_chacha::ChaCha8Rng;
/// use rsvd_faer::rsvd;
///
/// let mut rng = ChaCha8Rng::seed_from_u64(42);
/// let a = Mat::<f64>::from_fn(3, 3, |i, j| {
///     let data = [1.0, 2.0, 3.0, 8.0, 9.0, 4.0, 7.0, 6.0, 5.0];
///     data[i * 3 + j]
/// });
///
/// let (u, s, vt) = rsvd(a.as_ref(), 2, 5, 1, &mut rng, Parallelism::None);
/// assert_eq!(u.nrows(), 3);
/// assert_eq!(u.ncols(), 2);
/// assert_eq!(s.nrows(), 2);
/// assert_eq!(s.ncols(), 1);
/// assert_eq!(vt.nrows(), 2);
/// assert_eq!(vt.ncols(), 3);
/// ```
///
/// # Returns
///
/// A tuple `(u, s, vt)`, where `r = min(k, m, n)`:
///
/// * `u`: matrix of shape `(m, r)`
/// * `s`: matrix of shape `(r, 1)`, containing the top `r` singular values as a column vector
/// * `vt`: matrix of shape `(r, n)`
pub fn rsvd(
    a: MatRef<'_, f64>,
    k: usize,
    p: usize,
    q: usize,
    rng: &mut impl Rng,
    par: Parallelism,
) -> (Mat<f64>, Mat<f64>, Mat<f64>) {
    let m = a.nrows();
    let n = a.ncols();

    // Sketch width: target rank plus oversampling, capped by the matrix dimensions.
    let l = k.saturating_add(p).min(m).min(n);
    // Only `l` singular triplets are available from the sketch; clamping here
    // keeps the slicing at the end of the function in bounds when `k > min(m, n)`.
    let k = k.min(l);

    let normal = Normal::new(0.0, 1.0).expect("standard normal parameters are valid");
    let omega = Mat::<f64>::from_fn(n, l, |_, _| normal.sample(rng));

    let mut y = Mat::<f64>::zeros(m, l);
    let mut z = Mat::<f64>::zeros(n, l);

    // Y = A * Omega
    matmul(y.as_mut(), a, omega.as_ref(), None, 1.0, par);

    for _ in 0..q {
        // Re-orthonormalize before every multiplication. Without this the
        // columns of Y all drift towards the dominant singular direction and
        // the components belonging to smaller singular values are lost to
        // rounding. The spanned subspace is the same in exact arithmetic.
        let y_basis = y.qr().compute_thin_q(); // m x l
        // Z = A^T * orth(Y)
        matmul(z.as_mut(), a.transpose(), y_basis.as_ref(), None, 1.0, par);

        let z_basis = z.qr().compute_thin_q(); // n x l
        // Y = A * orth(Z)
        matmul(y.as_mut(), a, z_basis.as_ref(), None, 1.0, par);
    }

    // Orthonormal basis of the sketched range.
    let q_mat = y.qr().compute_thin_q();

    // B = Q^T * A  (l x n): the input restricted to the sketched subspace.
    let mut b = Mat::<f64>::zeros(l, n);
    matmul(b.as_mut(), q_mat.as_ref().transpose(), a, None, 1.0, par);

    let svd = b.thin_svd();
    let u_tilde = svd.u(); // l x l
    let s_vec = svd.s_diagonal(); // l
    let v_mat = svd.v(); // n x l

    // Lift the left singular vectors of B back to the original space: U = Q * U~.
    let mut u_full = Mat::<f64>::zeros(m, l);
    matmul(u_full.as_mut(), q_mat.as_ref(), u_tilde, None, 1.0, par);

    // Drop the oversampling columns and keep the leading `k` triplets.
    let u = u_full.get(.., ..k).to_owned();
    let s = Mat::<f64>::from_fn(k, 1, |i, _| s_vec.read(i));
    let vt = v_mat.get(.., ..k).transpose().to_owned();

    (u, s, vt)
}

#[cfg(test)]
mod tests {
    use super::*;
    use faer::linalg::matmul::matmul;
    use faer::Parallelism;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::StandardNormal;

    /// Relative slack allowed for rounding when comparing singular values.
    const SIGMA_TOL: f64 = 1e-8;

    /// Builds an `m x n` matrix `U * diag(sigma) * V^T` where `U` and `V` are
    /// the thin-Q factors of Gaussian matrices and
    /// `sigma_i = exp(-decay_rate * i)`.
    ///
    /// The rank is `matrix_rank` capped at `min(m, n)`.
    fn generate_matrix_with_decay(
        m: usize,
        n: usize,
        matrix_rank: usize,
        decay_rate: f64,
        rng: &mut impl Rng,
    ) -> Mat<f64> {
        let actual_rank = matrix_rank.min(m).min(n);

        let x = Mat::<f64>::from_fn(m, actual_rank, |_, _| rng.sample(StandardNormal));
        let y = Mat::<f64>::from_fn(n, actual_rank, |_, _| rng.sample(StandardNormal));

        let u = x.qr().compute_thin_q(); // m x actual_rank
        let v = y.qr().compute_thin_q(); // n x actual_rank

        let mut sigma = Mat::<f64>::zeros(actual_rank, actual_rank);
        for i in 0..actual_rank {
            sigma.write(i, i, f64::exp(-(i as f64) * decay_rate));
        }

        // Step 1: tmp = U * Sigma  (m x actual_rank)
        let mut tmp = Mat::<f64>::zeros(m, actual_rank);
        matmul(
            tmp.as_mut(),
            u.as_ref(),
            sigma.as_ref(),
            None,
            1.0,
            Parallelism::Rayon(0),
        );
        // Step 2: A = tmp * V^T  (m x n)
        let mut a_out = Mat::<f64>::zeros(m, n);
        matmul(
            a_out.as_mut(),
            tmp.as_ref(),
            v.as_ref().transpose(),
            None,
            1.0,
            Parallelism::Rayon(0),
        );
        a_out
    }

    /// Test matrix with exponentially decaying singular values
    /// `sigma_i = exp(-0.5 * i)`.
    fn generate_decaying_matrix(
        m: usize,
        n: usize,
        matrix_rank: usize,
        rng: &mut impl Rng,
    ) -> Mat<f64> {
        generate_matrix_with_decay(m, n, matrix_rank, 0.5, rng)
    }

    /// Runs `rsvd` and a reference thin SVD on the same decaying matrix,
    /// checks the leading singular values against the reference and prints
    /// accuracy and timing.
    fn test_rsvd_vs_full_svd(
        m: usize,
        n: usize,
        matrix_rank: usize,
        rsvd_k: usize,
        p: usize,
        q: usize,
        par: Parallelism,
    ) {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let a = generate_decaying_matrix(m, n, matrix_rank, &mut rng);

        // --- RSVD ---
        let start = std::time::Instant::now();
        let (_, s_r, _) = rsvd(a.as_ref(), rsvd_k, p, q, &mut rng, par);
        let rsvd_time = start.elapsed();

        // --- Full SVD (reference) ---
        let start = std::time::Instant::now();
        let full_svd = a.thin_svd();
        let full_time = start.elapsed();

        let s_full = full_svd.s_diagonal();
        let actual_k = rsvd_k.min(m).min(n);

        assert_eq!(s_r.nrows(), actual_k);
        assert_eq!(s_r.ncols(), 1);

        let rsvd_sum: f64 = (0..actual_k).map(|i| s_r.read(i, 0)).sum();
        let full_sum: f64 = (0..actual_k).map(|i| s_full.read(i)).sum();
        let capture_ratio = rsvd_sum / full_sum;
        let total_relative_error = 1.0 - capture_ratio;

        assert!(
            capture_ratio.is_finite(),
            "capture ratio is not finite: {capture_ratio}"
        );
        // The estimates are singular values of Q^T * A with orthonormal Q, so
        // they cannot exceed the true ones beyond rounding.
        assert!(
            capture_ratio <= 1.0 + SIGMA_TOL,
            "rsvd overestimates the singular values: ratio={capture_ratio}"
        );
        // Loose sanity bound: with this fast decay the sketch must recover at
        // least the dominant part of the spectrum.
        assert!(
            capture_ratio > 0.5,
            "rsvd lost most of the spectrum: ratio={capture_ratio}"
        );

        println!(
            "M={m:4} N={n:5} matrix_rank={matrix_rank:3} rsvd_k={rsvd_k:3} p={p} q={q} | \
             energy_err={:.4}% | \
             RSVD={:.2}ms  SVD={:.2}ms  speedup={:.2}x",
            total_relative_error * 100.0,
            rsvd_time.as_secs_f64() * 1000.0,
            full_time.as_secs_f64() * 1000.0,
            full_time.as_secs_f64() / rsvd_time.as_secs_f64().max(1e-9),
        );
    }

    #[test]
    fn test_rsvd_accuracy() {
        let p = 5;
        let q = 0;

        let rsvd_k = 15;
        let matrix_rank = rsvd_k + 10;

        test_rsvd_vs_full_svd(50, 25, matrix_rank, rsvd_k, p, q, Parallelism::Rayon(0));
        test_rsvd_vs_full_svd(50, 75, matrix_rank, rsvd_k, p, q, Parallelism::Rayon(0));
        test_rsvd_vs_full_svd(50, 150, matrix_rank, rsvd_k, p, q, Parallelism::Rayon(0));
        test_rsvd_vs_full_svd(100, 250, matrix_rank, rsvd_k, p, q, Parallelism::Rayon(0));
        test_rsvd_vs_full_svd(100, 750, matrix_rank, rsvd_k, p, q, Parallelism::Rayon(0));

        let rsvd_k = 25;
        let matrix_rank = rsvd_k + 10;
        test_rsvd_vs_full_svd(120, 1500, matrix_rank, rsvd_k, p, q, Parallelism::Rayon(0));
    }

    #[test]
    fn test_rsvd_power_iteration_study() {
        let p = 5;
        let rsvd_k = 20;
        let matrix_rank = rsvd_k + 15;
        let m = 50;
        let n = 150;

        // Repeating an empty string prints nothing, so use a visible rule character.
        let rule = "=".repeat(80);
        println!("\n{rule}");
        println!("Power iteration study: M={m} N={n} rsvd_k={rsvd_k} matrix_rank={matrix_rank}");
        println!("{rule}");

        for q in 0..=4 {
            test_rsvd_vs_full_svd(m, n, matrix_rank, rsvd_k, p, q, Parallelism::Rayon(0));
        }
    }

    #[test]
    fn test_rsvd_near_degenerate_subspace() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let (m, n, rsvd_k, p) = (50, 100, 10, 5);
        // Very slow decay: sigma_i = exp(-0.05 * i)  => nearly degenerate
        let a = generate_matrix_with_decay(m, n, rsvd_k + 5, 0.05, &mut rng);

        // The reference decomposition does not depend on `q`; compute it once.
        let full_svd = a.thin_svd();
        let s_full = full_svd.s_diagonal();
        let sigma_next = if rsvd_k < s_full.nrows() {
            s_full.read(rsvd_k)
        } else {
            0.0
        };
        let gap = if rsvd_k < s_full.nrows() {
            s_full.read(rsvd_k - 1) / s_full.read(rsvd_k)
        } else {
            f64::INFINITY
        };

        println!("\nNear-degenerate singular value study (slow decay):");
        for q in [0, 1, 2, 3] {
            // Same seed for every `q` so only the power iteration count varies.
            let mut rng2 = ChaCha8Rng::seed_from_u64(42);
            let (_, s_r, _) = rsvd(a.as_ref(), rsvd_k, p, q, &mut rng2, Parallelism::Rayon(0));

            assert_eq!(s_r.nrows(), rsvd_k);
            let sigma_k_est = s_r.read(rsvd_k - 1, 0);
            assert!(
                sigma_k_est.is_finite() && sigma_k_est >= 0.0,
                "invalid singular value estimate for q={q}: {sigma_k_est}"
            );
            // A projected singular value never exceeds the true one beyond rounding.
            assert!(
                sigma_k_est <= s_full.read(rsvd_k - 1) * (1.0 + SIGMA_TOL),
                "sigma_k overestimated for q={q}: {sigma_k_est}"
            );

            println!(
                "  q={q} | \
                 sigma_k={:.4}  sigma_k+1={:.4}  gap_ratio={gap:.4}",
                sigma_k_est, sigma_next,
            );
        }
    }
}