Skip to main content

rsvd_faer/
lib.rs

1use faer::linalg::matmul::matmul;
2use faer::Parallelism;
3use faer::{prelude::*, Mat};
4use rand::Rng;
5use rand_distr::{Distribution, Normal};
6
7/// Randomized SVD for `faer` matrices.
8///
9/// # Arguments
10///
11/// * `a` - Input matrix of shape `(m, n)`.
12/// * `k` - Desired target rank.
13/// * `p` - Oversampling parameter.
14/// * `q` - Number of power iterations.
15/// * `rng` - Random number generator used for Gaussian sampling.
16/// * `par` - `faer::Parallelism` mode for matrix multiplication.
17///
18/// # Example
19///
20/// ```rust
21/// use faer::{Mat, Parallelism};
22/// use rand::SeedableRng;
23/// use rand_chacha::ChaCha8Rng;
24/// use rsvd_faer::rsvd;
25///
26/// let mut rng = ChaCha8Rng::seed_from_u64(42);
27/// let a = Mat::<f64>::from_fn(3, 3, |i, j| {
28///     let data = [1.0, 2.0, 3.0, 8.0, 9.0, 4.0, 7.0, 6.0, 5.0];
29///     data[i * 3 + j]
30/// });
31///
32/// let (u, s, vt) = rsvd(a.as_ref(), 2, 5, 1, &mut rng, Parallelism::None);
33/// assert_eq!(u.nrows(), 3);
34/// assert_eq!(u.ncols(), 2);
35/// assert_eq!(s.nrows(), 2);
36/// assert_eq!(s.ncols(), 1);
37/// assert_eq!(vt.nrows(), 2);
38/// assert_eq!(vt.ncols(), 3);
39/// ```
40///
41/// # Returns
42///
43/// * A tuple `(u, s, vt)`
44/// * `u`: matrix shape `(m, k)`
45/// * `s`: matrix shape `(k, 1)`, containing the top `k` singular values as a column vector
46/// * `vt`: shape `(k, n)`
47
48pub fn rsvd(
49    a: MatRef<'_, f64>,
50    k: usize,
51    p: usize,
52    q: usize,
53    rng: &mut impl Rng,
54    par: Parallelism,
55) -> (Mat<f64>, Mat<f64>, Mat<f64>) {
56
57    let m = a.nrows();
58    let n = a.ncols();
59    let l = (k + p).min(m).min(n);
60
61    let normal = Normal::new(0.0, 1.0).unwrap();
62    let omega = Mat::<f64>::from_fn(n, l, |_, _| normal.sample(rng));
63
64    let mut y = Mat::<f64>::zeros(m, l);
65    let mut z = Mat::<f64>::zeros(n, l);
66
67    // Y = A * Omega
68    matmul(y.as_mut(), a, omega.as_ref(), None, 1.0, par);
69
70    for _ in 0..q {
71        // Z = A^T * Y
72        matmul(z.as_mut(), a.transpose(), y.as_ref(), None, 1.0, par);
73        // Y = A * Z
74        matmul(y.as_mut(), a, z.as_ref(), None, 1.0, par);
75    }
76
77    let q_mat = y.qr().compute_thin_q();
78
79    let mut b = Mat::<f64>::zeros(l, n);
80    matmul(b.as_mut(), q_mat.as_ref().transpose(), a, None, 1.0, par);
81
82    let svd = b.thin_svd();
83    let u_tilde = svd.u(); // l x l
84    let s_vec = svd.s_diagonal(); // l
85    let v_mat = svd.v(); // n x l
86
87    let mut u_full = Mat::<f64>::zeros(m, l);
88    matmul(u_full.as_mut(), q_mat.as_ref(), u_tilde, None, 1.0, par);
89
90    let u = u_full.get(.., ..k).to_owned();
91    let s = Mat::<f64>::from_fn(k, 1, |i, _| s_vec.read(i));
92    let vt = v_mat.get(.., ..k).transpose().to_owned();
93
94    (u, s, vt)
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use faer::linalg::matmul::matmul;
    use faer::Parallelism;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::StandardNormal;

    /// Builds an `m x n` test matrix `A = U · Σ · Vᵀ` with random orthonormal `U`, `V`
    /// (thin-QR of Gaussian matrices) and singular values `σ_i = exp(-decay · i)` for
    /// `i < min(matrix_rank, m, n)`; all remaining singular values are zero.
    fn generate_matrix_with_decay(
        m: usize,
        n: usize,
        matrix_rank: usize,
        decay: f64,
        rng: &mut impl Rng,
    ) -> Mat<f64> {
        let actual_rank = matrix_rank.min(m).min(n);

        let x = Mat::<f64>::from_fn(m, actual_rank, |_, _| rng.sample(StandardNormal));
        let y = Mat::<f64>::from_fn(n, actual_rank, |_, _| rng.sample(StandardNormal));

        let u = x.qr().compute_thin_q(); // m x actual_rank
        let v = y.qr().compute_thin_q(); // n x actual_rank

        let mut sigma = Mat::<f64>::zeros(actual_rank, actual_rank);
        for i in 0..actual_rank {
            sigma.write(i, i, f64::exp(-(i as f64) * decay));
        }

        // Step 1: tmp = U * Sigma  (m x actual_rank)
        let mut tmp = Mat::<f64>::zeros(m, actual_rank);
        matmul(
            tmp.as_mut(),
            u.as_ref(),
            sigma.as_ref(),
            None,
            1.0,
            Parallelism::Rayon(0),
        );
        // Step 2: A = tmp * V^T  (m x n)
        let mut a_out = Mat::<f64>::zeros(m, n);
        matmul(
            a_out.as_mut(),
            tmp.as_ref(),
            v.as_ref().transpose(),
            None,
            1.0,
            Parallelism::Rayon(0),
        );
        a_out
    }

    /// Test matrix with fast exponential decay: `σ_i = exp(-0.5 · i)`.
    fn generate_decaying_matrix(
        m: usize,
        n: usize,
        matrix_rank: usize,
        rng: &mut impl Rng,
    ) -> Mat<f64> {
        generate_matrix_with_decay(m, n, matrix_rank, 0.5, rng)
    }

    /// Runs `rsvd` and a full thin SVD on the same fast-decaying matrix, prints timings,
    /// and returns the relative shortfall of the rsvd sum of the top singular values
    /// ("energy error"): `1 - Σ ŝ_i / Σ s_i` over the first `rsvd_k` values.
    fn check_rsvd_vs_full_svd(
        m: usize,
        n: usize,
        matrix_rank: usize,
        rsvd_k: usize,
        p: usize,
        q: usize,
        par: Parallelism,
    ) -> f64 {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let a = generate_decaying_matrix(m, n, matrix_rank, &mut rng);

        // --- RSVD ---
        let start = std::time::Instant::now();
        let (_, s_r, _) = rsvd(a.as_ref(), rsvd_k, p, q, &mut rng, par);
        let rsvd_time = start.elapsed();

        // --- Full SVD (reference) ---
        let start = std::time::Instant::now();
        let full_svd = a.thin_svd();
        let full_time = start.elapsed();

        let s_full = full_svd.s_diagonal();
        let actual_k = rsvd_k.min(s_r.nrows()).min(s_full.nrows());

        let rsvd_sum: f64 = (0..actual_k).map(|i| s_r.read(i, 0)).sum();
        let full_sum: f64 = (0..actual_k).map(|i| s_full.read(i)).sum();
        let capture_ratio = rsvd_sum / full_sum;
        let total_relative_error = 1.0 - capture_ratio;

        println!(
            "M={m:4} N={n:5} matrix_rank={matrix_rank:3} rsvd_k={rsvd_k:3} p={p} q={q} | \
             energy_err={:.4}% | \
             RSVD={:.2}ms  SVD={:.2}ms  speedup={:.2}x",
            total_relative_error * 100.0,
            rsvd_time.as_secs_f64() * 1000.0,
            full_time.as_secs_f64() * 1000.0,
            full_time.as_secs_f64() / rsvd_time.as_secs_f64().max(1e-9),
        );

        total_relative_error
    }

    #[test]
    fn test_rsvd_accuracy() {
        let p = 5;
        let q = 0;

        // (m, n, rsvd_k); the generated matrix has rank rsvd_k + 10.
        let cases = [
            (50, 25, 15),
            (50, 75, 15),
            (50, 150, 15),
            (100, 250, 15),
            (100, 750, 15),
            (120, 1500, 25),
        ];

        for (m, n, rsvd_k) in cases {
            let matrix_rank = rsvd_k + 10;
            let err =
                check_rsvd_vs_full_svd(m, n, matrix_rank, rsvd_k, p, q, Parallelism::Rayon(0));
            // Projected singular values never exceed the exact ones (up to rounding), and
            // with oversampling the missed energy on a fast-decaying spectrum is tiny.
            assert!(
                (-1e-10..1e-2).contains(&err),
                "M={m} N={n} rsvd_k={rsvd_k}: energy error {err:e} out of range"
            );
        }
    }

    #[test]
    fn test_rsvd_power_iteration_study() {
        let p = 5;
        let rsvd_k = 20;
        let matrix_rank = rsvd_k + 15;
        let m = 50;
        let n = 150;

        println!("\n{}", "─".repeat(80));
        println!("Power iteration study: M={m} N={n} rsvd_k={rsvd_k} matrix_rank={matrix_rank}");
        println!("{}", "─".repeat(80));

        for q in 0..=4 {
            let err =
                check_rsvd_vs_full_svd(m, n, matrix_rank, rsvd_k, p, q, Parallelism::Rayon(0));
            // Loose sanity bound only (catches NaN / gross failures); the printed
            // energy_err per `q` is what this study reports.
            assert!(
                (-1e-10..1e-1).contains(&err),
                "q={q}: energy error {err:e} out of range"
            );
        }
    }

    #[test]
    fn test_rsvd_near_degenerate_subspace() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let (m, n, rsvd_k, p) = (50, 100, 10, 5);
        // Very slow decay: sigma_i = exp(-0.05 * i)  => nearly degenerate.
        // Rank is rsvd_k + 5 == rsvd_k + p, so the sketch width equals rank(A).
        let a = generate_matrix_with_decay(m, n, rsvd_k + 5, 0.05, &mut rng);

        // The reference decomposition does not depend on `q`; compute it once.
        let full_svd = a.thin_svd();
        let s_full = full_svd.s_diagonal();
        let (sigma_next, gap) = if rsvd_k < s_full.nrows() {
            let next = s_full.read(rsvd_k);
            (next, s_full.read(rsvd_k - 1) / next)
        } else {
            (0.0, f64::INFINITY)
        };

        println!("\nNear-degenerate singular value study (slow decay):");
        for q in [0, 1, 2, 3] {
            let mut rng2 = ChaCha8Rng::seed_from_u64(42);
            let (_, s_r, _) = rsvd(a.as_ref(), rsvd_k, p, q, &mut rng2, Parallelism::Rayon(0));

            println!(
                "  q={q} | \
                 sigma_k={:.4}  sigma_k+1={:.4}  gap_ratio={gap:.4}",
                s_r.read(rsvd_k - 1, 0),
                sigma_next,
            );

            // With the sketch spanning range(A) exactly, even a tiny spectral gap must not
            // prevent recovering the leading singular values to rounding accuracy.
            for i in 0..rsvd_k {
                let (est, exact) = (s_r.read(i, 0), s_full.read(i));
                assert!(
                    (est - exact).abs() <= 1e-6 * exact,
                    "q={q} i={i}: rsvd sigma {est} vs exact {exact}"
                );
            }
        }
    }
}