#![doc = include_str!("../README.md")]
use ndarray::{Array2};
use ndarray_linalg::QR;
use ndarray_linalg::svd::SVD;
use rand::{thread_rng, Rng};
use rand_distr::Normal;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
/// Approximates the singular value decomposition of `input` with a randomized
/// range finder.
///
/// The algorithm runs in four steps:
/// 1. Draw a Gaussian test matrix with `k + p` columns and multiply it into `input`.
/// 2. Orthonormalize that product with a QR decomposition.
/// 3. Project `input` onto the resulting basis.
/// 4. Decompose the small projected matrix exactly with `ndarray-linalg`'s SVD.
///
/// # Arguments
/// * `input` - the `m × n` matrix to decompose.
/// * `k` - number of requested components.
/// * `p` - oversampling columns. `k` and `p` only enter the computation through
///   their sum `k + p`, which is the width of the sketch.
/// * `seed` - `Some(s)` makes the sketch reproducible; `None` seeds a fresh
///   generator from `thread_rng()`.
///
/// # Returns
/// `(u, s, vt)`, where `s` is a square diagonal matrix of singular values.
///
/// The factors are NOT truncated to `k` components:
/// * `u` has one column per basis vector produced by the QR step.
/// * `vt` is returned unsliced from `svd(true, true)`. With ndarray-linalg's full
///   factors it is `n × n`.
///
/// With `r = s.nrows()`, the approximation of `input` is
/// `u[:, ..r] · s · vt[..r, :]`.
///
/// # Panics
/// Panics if the QR or SVD routine reports an error, or if `seed` is `None` and
/// seeding from `thread_rng()` fails.
pub fn rsvd(
    input: &Array2<f64>,
    k: usize,
    p: usize,
    seed: Option<u64>,
) -> (Array2<f64>, Array2<f64>, Array2<f64>) {
    let cols = input.ncols();
    let width = k + p;

    // A caller-supplied seed gives a reproducible sketch; otherwise draw fresh entropy.
    let mut sketch_rng = if let Some(value) = seed {
        ChaCha8Rng::seed_from_u64(value)
    } else {
        ChaCha8Rng::from_rng(thread_rng()).unwrap()
    };

    // Gaussian test matrix Ω (cols × width), filled in row-major order.
    let gaussian = Normal::new(0.0, 1.0).unwrap();
    let samples: Vec<f64> = (0..cols * width)
        .map(|_| sketch_rng.sample(&gaussian))
        .collect();
    let omega = Array2::from_shape_vec((cols, width), samples).unwrap();

    // Orthonormal basis Q for the range of input · Ω.
    let (basis, _) = input.dot(&omega).qr().unwrap();

    // Exact SVD of the small projected matrix Qᵀ · input.
    let projected = basis.t().dot(input);
    let (u_small, sigma, vt) = match projected.svd(true, true).unwrap() {
        (Some(u), s, Some(vt)) => (u, s, vt),
        _ => panic!("SVD failed"),
    };

    // Lift the left factor back to the original row space.
    (basis.dot(&u_small), Array2::from_diag(&sigma), vt)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{s, Array1, Array2, ArrayView2};

    /// Sketch seed passed to `rsvd`. It is independent of the per-case matrix seed,
    /// so the test matrix and the Gaussian sketch come from different random streams.
    const SKETCH_SEED: u64 = 1337;

    /// Largest absolute entry of `a` (0.0 for an empty array).
    fn max_abs(a: &Array2<f64>) -> f64 {
        a.fold(0.0_f64, |acc, &v| acc.max(v.abs()))
    }

    /// Returns true when every row of `a` equals the matching row of `b` up to a
    /// sign flip of the whole row, within `tol` in the max norm.
    ///
    /// Singular vectors are only determined up to sign: each pair `(u_i, v_i)` may
    /// be negated together. The difference is taken per row, so errors of
    /// opposite sign cannot cancel.
    fn rows_match_up_to_sign(a: ArrayView2<f64>, b: ArrayView2<f64>, tol: f64) -> bool {
        assert_eq!(a.dim(), b.dim(), "shape mismatch");
        a.outer_iter().zip(b.outer_iter()).all(|(x, y)| {
            let same = (&x - &y).fold(0.0_f64, |acc, &v| acc.max(v.abs()));
            let flipped = (&x + &y).fold(0.0_f64, |acc, &v| acc.max(v.abs()));
            same.min(flipped) <= tol
        })
    }

    /// Builds a random `m × n` test matrix of rank at most `rank` from uniform
    /// `[0, 1)` samples.
    ///
    /// * When `rank >= min(m, n)`, the entries are sampled directly, which gives
    ///   full rank almost surely.
    /// * Otherwise, the matrix is the product of `m × rank` and `rank × n`
    ///   uniform factors.
    fn random_matrix(m: usize, n: usize, rank: usize, seed: u64) -> Array2<f64> {
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        if rank >= m.min(n) {
            Array2::from_shape_fn((m, n), |_| rng.gen::<f64>())
        } else {
            let left = Array2::from_shape_fn((m, rank), |_| rng.gen::<f64>());
            let right = Array2::from_shape_fn((rank, n), |_| rng.gen::<f64>());
            left.dot(&right)
        }
    }

    /// Checks properties that hold for any sketch, whether or not it captures the
    /// whole range of the input:
    /// * the columns of `u` and the rows of `vt` are orthonormal;
    /// * no approximate singular value exceeds the exact one with the same index.
    ///   This holds because the projected matrix is `Qᵀ A` with `Q` having
    ///   orthonormal columns.
    fn assert_factor_invariants(
        u: &Array2<f64>,
        sigma: &Array2<f64>,
        vt: &Array2<f64>,
        exact: &Array1<f64>,
    ) {
        let gram_u = u.t().dot(u) - Array2::<f64>::eye(u.ncols());
        assert!(max_abs(&gram_u) < 1e-8, "columns of u are not orthonormal");
        let gram_v = vt.dot(&vt.t()) - Array2::<f64>::eye(vt.nrows());
        assert!(max_abs(&gram_v) < 1e-8, "rows of vt are not orthonormal");

        // Small slack for floating-point round-off in both decompositions.
        let slack = 1e-9 * exact[0];
        for (i, (&approx, &reference)) in sigma.diag().iter().zip(exact.iter()).enumerate() {
            assert!(
                approx <= reference + slack,
                "approximate singular value {i} ({approx}) exceeds exact value ({reference})"
            );
        }
    }

    /// Runs `rsvd` on `m × n` matrices with `k` requested components and `p`
    /// oversampling columns.
    ///
    /// The main check uses a matrix whose rank the sketch of width `k + p` can
    /// capture exactly: full rank if `k + p >= min(m, n)`, rank `k` otherwise.
    /// Within `tol`, the following must then agree with an exact SVD:
    /// * the leading singular values, relative to the largest singular value;
    /// * the leading singular vectors, in absolute terms and up to sign;
    /// * the reconstruction `u · s · vt` against the input, relative to the largest
    ///   singular value.
    ///
    /// When `k + p < min(m, n)`, a full-rank matrix is also decomposed. The sketch
    /// cannot capture its whole range, so only the sketch-independent invariants
    /// are checked for it.
    fn test_rsvd(m: usize, n: usize, k: usize, p: usize, seed: u64, tol: f64) {
        let full = m.min(n);
        let rank = if k + p >= full { full } else { k };

        let a = random_matrix(m, n, rank, seed);
        let (u, sigma, vt) = rsvd(&a, k, p, Some(SKETCH_SEED));
        let (Some(u2), s2, Some(vt2)) = a.svd(true, true).unwrap() else {
            panic!("SVD failed");
        };
        assert_factor_invariants(&u, &sigma, &vt, &s2);

        let r = sigma.nrows();
        assert!(r >= rank, "rsvd returned {r} components, expected at least {rank}");
        let scale = s2[0];

        for (i, (&approx, &reference)) in sigma.diag().iter().zip(s2.iter()).take(rank).enumerate() {
            assert!(
                (approx - reference).abs() <= tol * scale,
                "singular value {i}: rsvd {approx} vs exact {reference}"
            );
        }
        assert!(
            rows_match_up_to_sign(vt.slice(s![..rank, ..]), vt2.slice(s![..rank, ..]), tol),
            "right singular vectors differ"
        );
        assert!(
            rows_match_up_to_sign(
                u.slice(s![.., ..rank]).reversed_axes(),
                u2.slice(s![.., ..rank]).reversed_axes(),
                tol
            ),
            "left singular vectors differ"
        );

        // Only the leading r columns of u and rows of vt pair with s.
        let reconstruction = u.slice(s![.., ..r]).dot(&sigma).dot(&vt.slice(s![..r, ..]));
        assert!(
            max_abs(&(reconstruction - &a)) <= tol * scale,
            "u * s * vt does not reconstruct the input"
        );

        if rank < full {
            let dense = random_matrix(m, n, full, seed);
            let (u, sigma, vt) = rsvd(&dense, k, p, Some(SKETCH_SEED));
            let exact = dense.svd(false, false).unwrap().1;
            assert_factor_invariants(&u, &sigma, &vt, &exact);
        }
    }

    #[test]
    fn test_rsvd_2x2() {
        for seed in 0..20 {
            test_rsvd(2, 2, 1, 1, seed, 1e-2);
        }
    }

    #[test]
    fn test_rsvd_3x3() {
        for seed in 0..20 {
            test_rsvd(3, 3, 3, 1, seed, 1e-2);
        }
    }

    #[test]
    fn test_rsvd_5x5() {
        for seed in 0..20 {
            test_rsvd(5, 5, 5, 1, seed, 1e-2);
        }
    }

    #[test]
    fn test_rsvd_10x10_k5() {
        for seed in 0..20 {
            test_rsvd(10, 10, 5, 1, seed, 0.1);
        }
    }

    #[test]
    fn test_rsvd_10x10_k8() {
        for seed in 0..20 {
            test_rsvd(10, 10, 8, 1, seed, 0.1);
        }
    }

    #[test]
    fn test_rsvd_100x100_k10() {
        for seed in 0..20 {
            test_rsvd(100, 100, 10, 1, seed, 1e-2);
        }
    }

    #[test]
    fn test_rsvd_100x10_k10() {
        for seed in 0..20 {
            test_rsvd(100, 10, 10, 1, seed, 1e-2);
        }
    }

    #[test]
    fn test_rsvd_10x100_k10() {
        for seed in 0..20 {
            test_rsvd(10, 100, 10, 1, seed, 1e-2);
        }
    }

    #[test]
    fn test_rsvd_same_seed_same_result() {
        let a = random_matrix(20, 15, 15, 7);
        let (u1, s1, vt1) = rsvd(&a, 5, 2, Some(42));
        let (u2, s2, vt2) = rsvd(&a, 5, 2, Some(42));
        assert!(max_abs(&(u1 - u2)) < 1e-12, "u differs between identical seeds");
        assert!(max_abs(&(s1 - s2)) < 1e-12, "s differs between identical seeds");
        assert!(max_abs(&(vt1 - vt2)) < 1e-12, "vt differs between identical seeds");
    }
}