// oxicuda_solver/dense/randomized_svd.rs
1//! Randomized low-rank SVD algorithm (Halko, Martinsson, Tropp 2011).
2//!
3//! For an m x n matrix A and target rank k, the algorithm:
4//! 1. Generate random Gaussian matrix Omega (n x (k+p)), where p is oversampling.
5//! 2. Form Y = A * Omega  (m x (k+p)).
6//! 3. QR factorize Y to get Q.
7//! 4. Form B = Q^T * A  ((k+p) x n).
8//! 5. SVD of small matrix B to get B = U_hat * Sigma * V^T.
9//! 6. U = Q * U_hat.
10//!
11//! Optional power iterations improve accuracy for matrices with slowly decaying
12//! singular values by replacing step 2 with:
13//!   Y = (A * A^T)^q * A * Omega
14//!
15//! This uses:
16//! - oxicuda-rand for Gaussian random matrix generation
17//! - oxicuda-blas GEMM for matrix products
18//! - Existing QR factorization
19//! - Existing SVD on the small (k+p) x n matrix
20
21use oxicuda_blas::types::{GpuFloat, Layout, MatrixDesc, MatrixDescMut, Transpose};
22use oxicuda_memory::DeviceBuffer;
23use oxicuda_rand::{RngEngine, RngGenerator};
24
25use crate::dense::qr;
26use crate::dense::svd;
27use crate::error::{SolverError, SolverResult};
28use crate::handle::SolverHandle;
29
30// ---------------------------------------------------------------------------
31// Configuration
32// ---------------------------------------------------------------------------
33
/// Default oversampling parameter `p`; the sketch uses `rank + p` columns.
const DEFAULT_OVERSAMPLING: usize = 5;

/// Default number of subspace (power) iterations applied to the sketch.
const DEFAULT_POWER_ITERATIONS: usize = 1;

/// Default target rank used when the caller does not specify one.
const DEFAULT_RANK: usize = 10;
42
43// ---------------------------------------------------------------------------
44// Public types
45// ---------------------------------------------------------------------------
46
/// Randomized SVD configuration.
///
/// Build via [`RandomizedSvdConfig::default`] or [`RandomizedSvdConfig::with_rank`],
/// then refine with the builder methods (`oversampling`, `power_iterations`, `seed`).
#[derive(Debug, Clone)]
pub struct RandomizedSvdConfig {
    /// Target rank (number of singular values to compute).
    pub rank: usize,
    /// Oversampling parameter (typically 5-10): extra sketch columns beyond `rank`.
    pub oversampling: usize,
    /// Number of power iterations for accuracy (typically 0-2).
    pub power_iterations: usize,
    /// RNG engine used to generate the Gaussian sketch matrix.
    pub rng_engine: RngEngine,
    /// Seed for reproducibility: the same seed yields the same sketch.
    pub seed: u64,
}
61
62impl Default for RandomizedSvdConfig {
63    fn default() -> Self {
64        Self {
65            rank: DEFAULT_RANK,
66            oversampling: DEFAULT_OVERSAMPLING,
67            power_iterations: DEFAULT_POWER_ITERATIONS,
68            rng_engine: RngEngine::Philox,
69            seed: 42,
70        }
71    }
72}
73
74impl RandomizedSvdConfig {
75    /// Creates a new config with the given target rank.
76    pub fn with_rank(rank: usize) -> Self {
77        Self {
78            rank,
79            ..Self::default()
80        }
81    }
82
83    /// Sets the oversampling parameter.
84    pub fn oversampling(mut self, p: usize) -> Self {
85        self.oversampling = p;
86        self
87    }
88
89    /// Sets the number of power iterations.
90    pub fn power_iterations(mut self, q: usize) -> Self {
91        self.power_iterations = q;
92        self
93    }
94
95    /// Sets the RNG seed.
96    pub fn seed(mut self, seed: u64) -> Self {
97        self.seed = seed;
98        self
99    }
100
101    /// Returns the total sampling dimension: rank + oversampling.
102    pub fn sampling_dim(&self) -> usize {
103        self.rank + self.oversampling
104    }
105}
106
/// Result of a randomized SVD computation.
///
/// Satisfies A ~ `u` * diag(`sigma`) * `vt` to within the approximation error
/// of the randomized algorithm; both factor buffers live on the device.
pub struct RandomizedSvdResult<T: GpuFloat> {
    /// Left singular vectors: m x rank (column-major).
    pub u: DeviceBuffer<T>,
    /// Singular values: rank (descending order), copied to host.
    pub sigma: Vec<T>,
    /// Right singular vectors transposed: rank x n (column-major).
    pub vt: DeviceBuffer<T>,
    /// Actual rank computed (may differ from requested if matrix rank is lower).
    pub rank: usize,
}
118
119impl<T: GpuFloat> std::fmt::Debug for RandomizedSvdResult<T> {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        f.debug_struct("RandomizedSvdResult")
122            .field("sigma", &self.sigma)
123            .field("rank", &self.rank)
124            .field("u_len", &self.u.len())
125            .field("vt_len", &self.vt.len())
126            .finish()
127    }
128}
129
130// ---------------------------------------------------------------------------
131// Public API
132// ---------------------------------------------------------------------------
133
134/// Computes a low-rank SVD approximation using the randomized algorithm.
135///
136/// Given an m x n matrix A, computes the best rank-k approximation
137/// A ~ U * diag(sigma) * V^T where k = config.rank.
138///
139/// # Arguments
140///
141/// * `handle` — solver handle.
142/// * `a` — input matrix (m x n, column-major), not modified.
143/// * `m` — number of rows.
144/// * `n` — number of columns.
145/// * `config` — randomized SVD configuration.
146///
147/// # Returns
148///
149/// A [`RandomizedSvdResult`] with the low-rank factors.
150///
151/// # Errors
152///
153/// Returns [`SolverError::DimensionMismatch`] for invalid dimensions.
154/// Returns other errors for BLAS, kernel, or allocation failures.
155pub fn randomized_svd<T: GpuFloat>(
156    handle: &mut SolverHandle,
157    a: &DeviceBuffer<T>,
158    m: u32,
159    n: u32,
160    config: &RandomizedSvdConfig,
161) -> SolverResult<RandomizedSvdResult<T>> {
162    // Validate inputs.
163    if m == 0 || n == 0 {
164        return Err(SolverError::DimensionMismatch(
165            "randomized_svd: matrix dimensions must be positive".into(),
166        ));
167    }
168    let required = m as usize * n as usize;
169    if a.len() < required {
170        return Err(SolverError::DimensionMismatch(format!(
171            "randomized_svd: buffer too small ({} < {required})",
172            a.len()
173        )));
174    }
175
176    let k = config.rank;
177    let p = config.oversampling;
178    let l = k + p; // sampling dimension
179
180    if l == 0 {
181        return Err(SolverError::DimensionMismatch(
182            "randomized_svd: rank + oversampling must be positive".into(),
183        ));
184    }
185
186    // The sampling dimension cannot exceed min(m, n).
187    let min_mn = m.min(n) as usize;
188    let l = l.min(min_mn);
189    let effective_rank = k.min(l);
190
191    // Step 1: Generate random Gaussian matrix Omega (n x l).
192    let omega = generate_gaussian_matrix::<T>(handle, n as usize, l, config)?;
193
194    // Step 2: Form Y = A * Omega  (m x l).
195    let mut y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
196    gemm_multiply::<T>(
197        handle,
198        Transpose::NoTrans,
199        Transpose::NoTrans,
200        m,
201        l as u32,
202        n,
203        a,
204        m,
205        &omega,
206        n,
207        &mut y,
208        m,
209    )?;
210
211    // Step 2b: Power iterations for improved accuracy.
212    for _q in 0..config.power_iterations {
213        // Y_hat = A^T * Y  (n x l)
214        let mut y_hat = DeviceBuffer::<T>::zeroed(n as usize * l)?;
215        gemm_multiply::<T>(
216            handle,
217            Transpose::Trans,
218            Transpose::NoTrans,
219            n,
220            l as u32,
221            m,
222            a,
223            m,
224            &y,
225            m,
226            &mut y_hat,
227            n,
228        )?;
229
230        // QR factorize Y_hat for numerical stability.
231        let mut tau_hat = DeviceBuffer::<T>::zeroed(l)?;
232        qr::qr_factorize(handle, &mut y_hat, n, l as u32, n, &mut tau_hat)?;
233
234        // Y = A * Q_hat  (m x l)
235        // Since QR overwrites Y_hat with the factors, we use it directly.
236        y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
237        gemm_multiply::<T>(
238            handle,
239            Transpose::NoTrans,
240            Transpose::NoTrans,
241            m,
242            l as u32,
243            n,
244            a,
245            m,
246            &y_hat,
247            n,
248            &mut y,
249            m,
250        )?;
251    }
252
253    // Step 3: QR factorize Y to get Q (m x l).
254    let mut tau = DeviceBuffer::<T>::zeroed(l)?;
255    qr::qr_factorize(handle, &mut y, m, l as u32, m, &mut tau)?;
256
257    // Form explicit Q from Householder representation.
258    let mut q_explicit = DeviceBuffer::<T>::zeroed(m as usize * m as usize)?;
259    qr::qr_generate_q(handle, &y, &tau, &mut q_explicit, m, l as u32)?;
260
261    // Step 4: Form B = Q^T * A  (l x n).
262    let mut b_matrix = DeviceBuffer::<T>::zeroed(l * n as usize)?;
263    gemm_multiply::<T>(
264        handle,
265        Transpose::Trans,
266        Transpose::NoTrans,
267        l as u32,
268        n,
269        m,
270        &q_explicit,
271        m,
272        a,
273        m,
274        &mut b_matrix,
275        l as u32,
276    )?;
277
278    // Step 5: SVD of small matrix B (l x n).
279    let svd_result = svd::svd(
280        handle,
281        &mut b_matrix,
282        l as u32,
283        n,
284        l as u32,
285        svd::SvdJob::Thin,
286    )?;
287
288    // Step 6: U = Q * U_hat  (m x effective_rank).
289    // U_hat comes from the SVD of B: B = U_hat * Sigma * V^T.
290    let sigma = truncate_to_rank(&svd_result.singular_values, effective_rank);
291    let actual_rank = sigma.len();
292
293    // Construct the final U exactly: U = Q * U_hat, shape m x actual_rank.
294    let u_out = if let Some(ref u_hat) = svd_result.u {
295        let k_hat = svd_result.singular_values.len();
296        let rank_used = actual_rank.min(k_hat);
297
298        let mut u_hat_rank_host = vec![T::gpu_zero(); l * actual_rank];
299        for col in 0..rank_used {
300            for row in 0..l {
301                u_hat_rank_host[col * l + row] = u_hat[col * l + row];
302            }
303        }
304
305        let mut u_hat_rank = DeviceBuffer::<T>::zeroed(l * actual_rank)?;
306        u_hat_rank.copy_from_host(&u_hat_rank_host)?;
307
308        let mut u_final = DeviceBuffer::<T>::zeroed(m as usize * actual_rank)?;
309        gemm_multiply::<T>(
310            handle,
311            Transpose::NoTrans,
312            Transpose::NoTrans,
313            m,
314            actual_rank as u32,
315            l as u32,
316            &q_explicit,
317            m,
318            &u_hat_rank,
319            l as u32,
320            &mut u_final,
321            m,
322        )?;
323        u_final
324    } else {
325        DeviceBuffer::<T>::zeroed(m as usize * actual_rank)?
326    };
327
328    // Construct the final V^T: actual_rank x n.
329    let vt_out = if let Some(ref vt_hat) = svd_result.vt {
330        // Keep the top `actual_rank` rows from V^T (column-major layout).
331        let n_usize = n as usize;
332        let k_hat = svd_result.singular_values.len();
333        let rank_used = actual_rank.min(k_hat);
334
335        let mut vt_host = vec![T::gpu_zero(); actual_rank * n_usize];
336        for col in 0..n_usize {
337            for row in 0..rank_used {
338                vt_host[col * actual_rank + row] = vt_hat[col * k_hat + row];
339            }
340        }
341
342        let mut vt_final = DeviceBuffer::<T>::zeroed(actual_rank * n_usize)?;
343        vt_final.copy_from_host(&vt_host)?;
344        vt_final
345    } else {
346        DeviceBuffer::<T>::zeroed(actual_rank * n as usize)?
347    };
348
349    Ok(RandomizedSvdResult {
350        u: u_out,
351        sigma,
352        vt: vt_out,
353        rank: actual_rank,
354    })
355}
356
357// ---------------------------------------------------------------------------
358// Internal helpers
359// ---------------------------------------------------------------------------
360
/// Generates a Gaussian random matrix (rows x cols) on the device.
///
/// Values are standard normal (mean 0, stddev 1) drawn with the engine and
/// seed from `config`, so the sketch is reproducible for a fixed seed.
///
/// NOTE(review): this generates into a typed f32/f64 device buffer, copies to
/// host, bit-casts each element to `T`, and copies back to the device. The
/// round trip is pure overhead — worth revisiting if the memory layer ever
/// supports reinterpreting a buffer's element type in place.
fn generate_gaussian_matrix<T: GpuFloat>(
    handle: &SolverHandle,
    rows: usize,
    cols: usize,
    config: &RandomizedSvdConfig,
) -> SolverResult<DeviceBuffer<T>> {
    let total = rows * cols;
    let mut buffer = DeviceBuffer::<T>::zeroed(total)?;

    // Use oxicuda-rand to generate Gaussian random numbers.
    let mut rng = RngGenerator::new(config.rng_engine, config.seed, handle.context())
        .map_err(|e| SolverError::InternalError(format!("RNG creation failed: {e}")))?;

    // Generate standard normal distribution (mean=0, stddev=1).
    // Dispatch on byte size: assumes T::SIZE == 4 means IEEE-754 f32 and
    // T::SIZE == 8 means IEEE-754 f64, bit-for-bit — TODO confirm this holds
    // for every GpuFloat implementor (no half-precision type with 4-byte
    // storage, for example).
    if T::SIZE == 4 {
        // f32 path: generate natively, then bit-cast each value into T.
        let mut f32_buf = DeviceBuffer::<f32>::zeroed(total)?;
        rng.generate_normal_f32(&mut f32_buf, 0.0, 1.0)
            .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;

        let mut host_f32 = vec![0.0_f32; total];
        f32_buf.copy_to_host(&mut host_f32)?;
        let host_t: Vec<T> = host_f32
            .into_iter()
            .map(|x| T::from_bits_u64(u64::from(x.to_bits())))
            .collect();
        buffer.copy_from_host(&host_t)?;
    } else if T::SIZE == 8 {
        // f64 path: same strategy with 8-byte bit patterns.
        let mut f64_buf = DeviceBuffer::<f64>::zeroed(total)?;
        rng.generate_normal_f64(&mut f64_buf, 0.0, 1.0)
            .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;
        let mut host_f64 = vec![0.0_f64; total];
        f64_buf.copy_to_host(&mut host_f64)?;
        let host_t: Vec<T> = host_f64
            .into_iter()
            .map(|x| T::from_bits_u64(x.to_bits()))
            .collect();
        buffer.copy_from_host(&host_t)?;
    } else {
        // Any other element size has no RNG path; refuse rather than guess.
        return Err(SolverError::InternalError(format!(
            "generate_gaussian_matrix: unsupported precision size {}",
            T::SIZE
        )));
    }

    Ok(buffer)
}
410
/// Performs GEMM: C = alpha * op(A) * op(B) + beta * C, hard-wired here to
/// alpha = 1 and beta = 0 (C is fully overwritten by op(A) * op(B)).
///
/// Wraps oxicuda-blas GEMM with the solver handle's BLAS handle. All matrices
/// are column-major device buffers.
///
/// NOTE(review): `_m` (the row count of C) is accepted but unused, and each
/// descriptor passes its leading dimension as the row extent as well
/// (`lda`/`ldb`/`ldc` used for both). This presumably relies on the BLAS
/// layer deriving the effective op() dimensions from the transpose flags plus
/// `n` and `k` — confirm against oxicuda-blas; a caller with ld > rows would
/// mis-describe its matrix.
#[allow(clippy::too_many_arguments)]
fn gemm_multiply<T: GpuFloat>(
    handle: &SolverHandle,
    trans_a: Transpose,
    trans_b: Transpose,
    _m: u32,
    n: u32,
    k: u32,
    a: &DeviceBuffer<T>,
    lda: u32,
    b: &DeviceBuffer<T>,
    ldb: u32,
    c: &mut DeviceBuffer<T>,
    ldc: u32,
) -> SolverResult<()> {
    // A described as (lda x k), B as (ldb x n), C as (ldc x n), col-major.
    let a_desc = MatrixDesc::<T>::from_raw(a.as_device_ptr(), lda, k, lda, Layout::ColMajor);
    let b_desc = MatrixDesc::<T>::from_raw(b.as_device_ptr(), ldb, n, ldb, Layout::ColMajor);
    let mut c_desc = MatrixDescMut::<T>::from_raw(c.as_device_ptr(), ldc, n, ldc, Layout::ColMajor);

    oxicuda_blas::level3::gemm_api::gemm(
        handle.blas(),
        trans_a,
        trans_b,
        T::gpu_one(),
        &a_desc,
        &b_desc,
        T::gpu_zero(),
        &mut c_desc,
    )?;

    Ok(())
}
446
447/// Truncates singular values to the effective rank, discarding near-zero values.
448fn truncate_to_rank<T: GpuFloat>(singular_values: &[T], max_rank: usize) -> Vec<T> {
449    let mut result: Vec<T> = singular_values.iter().take(max_rank).copied().collect();
450
451    // Remove trailing near-zero singular values.
452    // Determine a threshold based on the largest singular value.
453    if let Some(&first) = result.first() {
454        let threshold_bits = if T::SIZE == 4 {
455            // f32: ~1e-7 relative threshold.
456            let first_bits = first.to_bits_u64() as u32;
457            let first_f32 = f32::from_bits(first_bits);
458            let thresh = first_f32 * 1e-7;
459            u64::from(thresh.to_bits())
460        } else {
461            // f64: ~1e-14 relative threshold.
462            let first_f64 = f64::from_bits(first.to_bits_u64());
463            let thresh = first_f64 * 1e-14;
464            thresh.to_bits()
465        };
466        let threshold = T::from_bits_u64(threshold_bits);
467
468        // Trim values that are effectively zero.
469        while result.len() > 1 {
470            if let Some(&last) = result.last() {
471                // Compare absolute value.
472                let last_abs_bits = if T::SIZE == 4 {
473                    let bits = last.to_bits_u64() as u32;
474                    u64::from(bits & 0x7FFF_FFFF)
475                } else {
476                    last.to_bits_u64() & 0x7FFF_FFFF_FFFF_FFFF
477                };
478                let threshold_abs_bits = if T::SIZE == 4 {
479                    let bits = threshold.to_bits_u64() as u32;
480                    u64::from(bits & 0x7FFF_FFFF)
481                } else {
482                    threshold.to_bits_u64() & 0x7FFF_FFFF_FFFF_FFFF
483                };
484
485                if last_abs_bits <= threshold_abs_bits {
486                    result.pop();
487                } else {
488                    break;
489                }
490            } else {
491                break;
492            }
493        }
494    }
495
496    result
497}
498
499// ---------------------------------------------------------------------------
500// Tests
501// ---------------------------------------------------------------------------
502
#[cfg(test)]
mod tests {
    use super::*;

    // Default config must carry the documented constants and seed 42.
    #[test]
    fn config_default() {
        let config = RandomizedSvdConfig::default();
        assert_eq!(config.rank, DEFAULT_RANK);
        assert_eq!(config.oversampling, DEFAULT_OVERSAMPLING);
        assert_eq!(config.power_iterations, DEFAULT_POWER_ITERATIONS);
        assert_eq!(config.seed, 42);
    }

    // Builder methods each set exactly their own field.
    #[test]
    fn config_builder() {
        let config = RandomizedSvdConfig::with_rank(20)
            .oversampling(10)
            .power_iterations(2)
            .seed(123);
        assert_eq!(config.rank, 20);
        assert_eq!(config.oversampling, 10);
        assert_eq!(config.power_iterations, 2);
        assert_eq!(config.seed, 123);
    }

    // sampling_dim = rank + oversampling.
    #[test]
    fn config_sampling_dim() {
        let config = RandomizedSvdConfig::with_rank(15).oversampling(5);
        assert_eq!(config.sampling_dim(), 20);
    }

    // Plain truncation: keeps the first max_rank values unchanged.
    #[test]
    fn truncate_to_rank_basic() {
        let sigma: Vec<f64> = vec![5.0, 3.0, 1.0, 0.5, 0.001];
        let result = truncate_to_rank(&sigma, 3);
        assert_eq!(result.len(), 3);
        assert!((result[0] - 5.0).abs() < 1e-10);
        assert!((result[1] - 3.0).abs() < 1e-10);
        assert!((result[2] - 1.0).abs() < 1e-10);
    }

    // Trailing exact zeros fall below the relative threshold and are dropped.
    #[test]
    fn truncate_to_rank_removes_zeros() {
        let sigma: Vec<f64> = vec![5.0, 3.0, 0.0, 0.0];
        let result = truncate_to_rank(&sigma, 4);
        // Should trim trailing zeros.
        assert!(result.len() <= 4);
        assert!(result.len() >= 2);
    }

    // Empty input stays empty (no panic on missing "largest" value).
    #[test]
    fn truncate_to_rank_empty() {
        let sigma: Vec<f64> = Vec::new();
        let result = truncate_to_rank(&sigma, 5);
        assert!(result.is_empty());
    }

    // The f32 bit-manipulation path truncates like the f64 path.
    #[test]
    fn truncate_to_rank_f32() {
        let sigma: Vec<f32> = vec![10.0, 5.0, 2.0, 0.0];
        let result = truncate_to_rank(&sigma, 3);
        assert_eq!(result.len(), 3);
    }

    // max_rank smaller than the input length just truncates.
    #[test]
    fn truncate_to_rank_max_smaller() {
        let sigma: Vec<f64> = vec![10.0, 5.0, 2.0, 1.0];
        let result = truncate_to_rank(&sigma, 2);
        assert_eq!(result.len(), 2);
    }

    // Default RNG engine is Philox (matches Default impl).
    #[test]
    fn config_rng_engine_default() {
        let config = RandomizedSvdConfig::default();
        assert!(matches!(config.rng_engine, RngEngine::Philox));
    }

    // -----------------------------------------------------------------------
    // GEMM sketch throughput proxy
    // -----------------------------------------------------------------------

    /// CPU reference matrix multiply: C = A × B (row-major, f32).
    ///
    /// A host-side model of the GEMM operation the pipeline performs on
    /// device; used below both for correctness checks of the alpha/beta
    /// contract and as a CPU throughput proxy.
    fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut c = vec![0.0_f32; m * n];
        for row in 0..m {
            for col in 0..n {
                let mut acc = 0.0_f32;
                for ki in 0..k {
                    acc = f32::mul_add(a[row * k + ki], b[ki * n + col], acc);
                }
                c[row * n + col] = acc;
            }
        }
        c
    }

    /// Exercises a CPU model of the alpha/beta GEMM contract
    /// (C = alpha * A * B + beta * C) used by the randomized SVD pipeline.
    ///
    /// NOTE(review): despite its name, this test never references
    /// `gemm_multiply` — the closure below is a purely local CPU mirror, so
    /// renaming or removing `gemm_multiply` would NOT break compilation here.
    /// It only validates the reference model on a 2×3 × 3×2 product.
    #[test]
    #[allow(clippy::type_complexity)]
    fn rsvd_gemm_multiply_signature_exists() {
        // CPU stand-in with the same (m, k, n, alpha, a, b, beta, c) shape
        // as the device-side GEMM wrapper.
        let _fn_ref: fn(usize, usize, usize, f32, &[f32], &[f32], f32, &[f32]) -> Vec<f32> =
            |m, k, n, alpha, a, b, beta, c| {
                // CPU mirror of the GEMM used in gemm_multiply
                let raw = cpu_matmul_f32(a, b, m, k, n);
                raw.iter()
                    .zip(c.iter())
                    .map(|(&r, &c_val)| alpha * r + beta * c_val)
                    .collect()
            };

        // 2×3 × 3×2 = 2×2
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2×3
        let b = vec![7.0_f32, 8.0, 9.0, 10.0, 11.0, 12.0]; // 3×2
        let c_init = vec![0.0_f32; 4];
        let result = _fn_ref(2, 3, 2, 1.0, &a, &b, 0.0, &c_init);
        // [1*7+2*9+3*11, 1*8+2*10+3*12] = [58, 64]
        // [4*7+5*9+6*11, 4*8+5*10+6*12] = [139, 154]
        assert!(
            (result[0] - 58.0).abs() < 1e-4,
            "GEMM C[0,0] expected 58, got {}",
            result[0]
        );
        assert!(
            (result[1] - 64.0).abs() < 1e-4,
            "GEMM C[0,1] expected 64, got {}",
            result[1]
        );
        assert!(
            (result[2] - 139.0).abs() < 1e-4,
            "GEMM C[1,0] expected 139, got {}",
            result[2]
        );
        assert!(
            (result[3] - 154.0).abs() < 1e-4,
            "GEMM C[1,1] expected 154, got {}",
            result[3]
        );
    }

    /// CPU-proxy throughput benchmark for randomized SVD sketch (GEMM path).
    ///
    /// Sketches a 256×128 matrix with a rank-16 deterministic projector and
    /// prints an informational GFLOPS figure. The assertions only check that
    /// the sketch is non-trivial and the timing is sane — the printed number
    /// is a CPU proxy, not a statement about GPU performance.
    #[test]
    fn rsvd_gemm_sketch_throughput_proxy_256x128_rank16() {
        let m = 256_usize;
        let k = 128_usize;
        let r = 16_usize; // sketch rank (number of random projections)

        // Deterministic pseudo-random matrix A (256×128); golden-ratio
        // fractional parts give a fixed, well-spread test pattern.
        let a: Vec<f32> = (0..m * k)
            .map(|i| ((i as f32 * 1.618_034_f32).fract() - 0.5) * 2.0)
            .collect();

        // Deterministic Gaussian-like projection matrix Omega (128×16).
        let omega: Vec<f32> = (0..k * r)
            .map(|i| ((i as f32 * std::f32::consts::E).fract() - 0.5) * 0.5)
            .collect();

        let c_zero = vec![0.0_f32; m * r];

        // Warm-up run so the timed loop measures steady-state throughput.
        let _ = cpu_matmul_f32(&a, &omega, m, k, r);

        const ITERS: usize = 100;
        let start = std::time::Instant::now();
        let mut sketch = vec![0.0_f32; m * r];
        for _ in 0..ITERS {
            let raw = cpu_matmul_f32(&a, &omega, m, k, r);
            sketch = raw
                .into_iter()
                .zip(c_zero.iter())
                .map(|(r_val, &c_val)| r_val + c_val)
                .collect();
        }
        let elapsed_ns = start.elapsed().as_nanos() as f64;

        // 2 * m * k * r flops per GEMM (multiply-add per element per inner-k)
        let flops_per_gemm = 2.0 * m as f64 * k as f64 * r as f64;
        let gflops = (flops_per_gemm * ITERS as f64) / elapsed_ns;

        println!(
            "rSVD GEMM sketch proxy ({}×{} × {}×{}, {} iters): {:.3} GFLOPS (CPU reference)",
            m, k, k, r, ITERS, gflops
        );

        // Sanity: sketch must be non-trivially non-zero
        let sketch_norm: f32 = sketch.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            sketch_norm > 0.01,
            "Sketch must be non-zero, got norm={}",
            sketch_norm
        );
        assert!(
            gflops > 0.0001,
            "GEMM sketch throughput unrealistically low: {:.6} GFLOPS",
            gflops
        );
    }
}