1use oxicuda_blas::types::{GpuFloat, Layout, MatrixDesc, MatrixDescMut, Transpose};
22use oxicuda_memory::DeviceBuffer;
23use oxicuda_rand::{RngEngine, RngGenerator};
24
25use crate::dense::qr;
26use crate::dense::svd;
27use crate::error::{SolverError, SolverResult};
28use crate::handle::SolverHandle;
29
/// Target rank requested by the default configuration.
const DEFAULT_RANK: usize = 10;

/// Extra random sample columns drawn beyond the target rank by default.
const DEFAULT_OVERSAMPLING: usize = 5;

/// Number of power (subspace) iterations performed by default.
const DEFAULT_POWER_ITERATIONS: usize = 1;
42
/// Tuning parameters for [`randomized_svd`].
///
/// Build one with [`RandomizedSvdConfig::with_rank`] plus the chained setters,
/// or start from [`Default`].
#[derive(Debug, Clone)]
pub struct RandomizedSvdConfig {
    /// Number of singular triplets requested; an upper bound on the rank
    /// actually returned.
    pub rank: usize,
    /// Extra random sample columns added to `rank` when sketching the range
    /// of the input matrix.
    pub oversampling: usize,
    /// Number of power (subspace) iterations applied to the sketch.
    pub power_iterations: usize,
    /// Random-number engine used to draw the Gaussian test matrix.
    pub rng_engine: RngEngine,
    /// Seed for the Gaussian test-matrix generator.
    pub seed: u64,
}
61
62impl Default for RandomizedSvdConfig {
63 fn default() -> Self {
64 Self {
65 rank: DEFAULT_RANK,
66 oversampling: DEFAULT_OVERSAMPLING,
67 power_iterations: DEFAULT_POWER_ITERATIONS,
68 rng_engine: RngEngine::Philox,
69 seed: 42,
70 }
71 }
72}
73
74impl RandomizedSvdConfig {
75 pub fn with_rank(rank: usize) -> Self {
77 Self {
78 rank,
79 ..Self::default()
80 }
81 }
82
83 pub fn oversampling(mut self, p: usize) -> Self {
85 self.oversampling = p;
86 self
87 }
88
89 pub fn power_iterations(mut self, q: usize) -> Self {
91 self.power_iterations = q;
92 self
93 }
94
95 pub fn seed(mut self, seed: u64) -> Self {
97 self.seed = seed;
98 self
99 }
100
101 pub fn sampling_dim(&self) -> usize {
103 self.rank + self.oversampling
104 }
105}
106
/// Output of [`randomized_svd`]: the factors of `A ≈ U · diag(sigma) · Vᵀ`.
pub struct RandomizedSvdResult<T: GpuFloat> {
    /// Left singular vectors on the device, `m × rank`, column-major
    /// (leading dimension `m`).
    pub u: DeviceBuffer<T>,
    /// Singular values on the host, taken from the SVD of the projected
    /// matrix after dropping negligible trailing values.
    pub sigma: Vec<T>,
    /// Transposed right singular vectors on the device, `rank × n`,
    /// column-major (leading dimension `rank`).
    pub vt: DeviceBuffer<T>,
    /// Number of singular triplets actually returned (equals `sigma.len()`).
    pub rank: usize,
}
118
119impl<T: GpuFloat> std::fmt::Debug for RandomizedSvdResult<T> {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.debug_struct("RandomizedSvdResult")
122 .field("sigma", &self.sigma)
123 .field("rank", &self.rank)
124 .field("u_len", &self.u.len())
125 .field("vt_len", &self.vt.len())
126 .finish()
127 }
128}
129
130pub fn randomized_svd<T: GpuFloat>(
156 handle: &mut SolverHandle,
157 a: &DeviceBuffer<T>,
158 m: u32,
159 n: u32,
160 config: &RandomizedSvdConfig,
161) -> SolverResult<RandomizedSvdResult<T>> {
162 if m == 0 || n == 0 {
164 return Err(SolverError::DimensionMismatch(
165 "randomized_svd: matrix dimensions must be positive".into(),
166 ));
167 }
168 let required = m as usize * n as usize;
169 if a.len() < required {
170 return Err(SolverError::DimensionMismatch(format!(
171 "randomized_svd: buffer too small ({} < {required})",
172 a.len()
173 )));
174 }
175
176 let k = config.rank;
177 let p = config.oversampling;
178 let l = k + p; if l == 0 {
181 return Err(SolverError::DimensionMismatch(
182 "randomized_svd: rank + oversampling must be positive".into(),
183 ));
184 }
185
186 let min_mn = m.min(n) as usize;
188 let l = l.min(min_mn);
189 let effective_rank = k.min(l);
190
191 let omega = generate_gaussian_matrix::<T>(handle, n as usize, l, config)?;
193
194 let mut y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
196 gemm_multiply::<T>(
197 handle,
198 Transpose::NoTrans,
199 Transpose::NoTrans,
200 m,
201 l as u32,
202 n,
203 a,
204 m,
205 &omega,
206 n,
207 &mut y,
208 m,
209 )?;
210
211 for _q in 0..config.power_iterations {
213 let mut y_hat = DeviceBuffer::<T>::zeroed(n as usize * l)?;
215 gemm_multiply::<T>(
216 handle,
217 Transpose::Trans,
218 Transpose::NoTrans,
219 n,
220 l as u32,
221 m,
222 a,
223 m,
224 &y,
225 m,
226 &mut y_hat,
227 n,
228 )?;
229
230 let mut tau_hat = DeviceBuffer::<T>::zeroed(l)?;
232 qr::qr_factorize(handle, &mut y_hat, n, l as u32, n, &mut tau_hat)?;
233
234 y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
237 gemm_multiply::<T>(
238 handle,
239 Transpose::NoTrans,
240 Transpose::NoTrans,
241 m,
242 l as u32,
243 n,
244 a,
245 m,
246 &y_hat,
247 n,
248 &mut y,
249 m,
250 )?;
251 }
252
253 let mut tau = DeviceBuffer::<T>::zeroed(l)?;
255 qr::qr_factorize(handle, &mut y, m, l as u32, m, &mut tau)?;
256
257 let mut q_explicit = DeviceBuffer::<T>::zeroed(m as usize * m as usize)?;
259 qr::qr_generate_q(handle, &y, &tau, &mut q_explicit, m, l as u32)?;
260
261 let mut b_matrix = DeviceBuffer::<T>::zeroed(l * n as usize)?;
263 gemm_multiply::<T>(
264 handle,
265 Transpose::Trans,
266 Transpose::NoTrans,
267 l as u32,
268 n,
269 m,
270 &q_explicit,
271 m,
272 a,
273 m,
274 &mut b_matrix,
275 l as u32,
276 )?;
277
278 let svd_result = svd::svd(
280 handle,
281 &mut b_matrix,
282 l as u32,
283 n,
284 l as u32,
285 svd::SvdJob::Thin,
286 )?;
287
288 let sigma = truncate_to_rank(&svd_result.singular_values, effective_rank);
291 let actual_rank = sigma.len();
292
293 let u_out = if let Some(ref u_hat) = svd_result.u {
295 let k_hat = svd_result.singular_values.len();
296 let rank_used = actual_rank.min(k_hat);
297
298 let mut u_hat_rank_host = vec![T::gpu_zero(); l * actual_rank];
299 for col in 0..rank_used {
300 for row in 0..l {
301 u_hat_rank_host[col * l + row] = u_hat[col * l + row];
302 }
303 }
304
305 let mut u_hat_rank = DeviceBuffer::<T>::zeroed(l * actual_rank)?;
306 u_hat_rank.copy_from_host(&u_hat_rank_host)?;
307
308 let mut u_final = DeviceBuffer::<T>::zeroed(m as usize * actual_rank)?;
309 gemm_multiply::<T>(
310 handle,
311 Transpose::NoTrans,
312 Transpose::NoTrans,
313 m,
314 actual_rank as u32,
315 l as u32,
316 &q_explicit,
317 m,
318 &u_hat_rank,
319 l as u32,
320 &mut u_final,
321 m,
322 )?;
323 u_final
324 } else {
325 DeviceBuffer::<T>::zeroed(m as usize * actual_rank)?
326 };
327
328 let vt_out = if let Some(ref vt_hat) = svd_result.vt {
330 let n_usize = n as usize;
332 let k_hat = svd_result.singular_values.len();
333 let rank_used = actual_rank.min(k_hat);
334
335 let mut vt_host = vec![T::gpu_zero(); actual_rank * n_usize];
336 for col in 0..n_usize {
337 for row in 0..rank_used {
338 vt_host[col * actual_rank + row] = vt_hat[col * k_hat + row];
339 }
340 }
341
342 let mut vt_final = DeviceBuffer::<T>::zeroed(actual_rank * n_usize)?;
343 vt_final.copy_from_host(&vt_host)?;
344 vt_final
345 } else {
346 DeviceBuffer::<T>::zeroed(actual_rank * n as usize)?
347 };
348
349 Ok(RandomizedSvdResult {
350 u: u_out,
351 sigma,
352 vt: vt_out,
353 rank: actual_rank,
354 })
355}
356
357fn generate_gaussian_matrix<T: GpuFloat>(
363 handle: &SolverHandle,
364 rows: usize,
365 cols: usize,
366 config: &RandomizedSvdConfig,
367) -> SolverResult<DeviceBuffer<T>> {
368 let total = rows * cols;
369 let mut buffer = DeviceBuffer::<T>::zeroed(total)?;
370
371 let mut rng = RngGenerator::new(config.rng_engine, config.seed, handle.context())
373 .map_err(|e| SolverError::InternalError(format!("RNG creation failed: {e}")))?;
374
375 if T::SIZE == 4 {
377 let mut f32_buf = DeviceBuffer::<f32>::zeroed(total)?;
379 rng.generate_normal_f32(&mut f32_buf, 0.0, 1.0)
380 .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;
381
382 let mut host_f32 = vec![0.0_f32; total];
383 f32_buf.copy_to_host(&mut host_f32)?;
384 let host_t: Vec<T> = host_f32
385 .into_iter()
386 .map(|x| T::from_bits_u64(u64::from(x.to_bits())))
387 .collect();
388 buffer.copy_from_host(&host_t)?;
389 } else if T::SIZE == 8 {
390 let mut f64_buf = DeviceBuffer::<f64>::zeroed(total)?;
392 rng.generate_normal_f64(&mut f64_buf, 0.0, 1.0)
393 .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;
394 let mut host_f64 = vec![0.0_f64; total];
395 f64_buf.copy_to_host(&mut host_f64)?;
396 let host_t: Vec<T> = host_f64
397 .into_iter()
398 .map(|x| T::from_bits_u64(x.to_bits()))
399 .collect();
400 buffer.copy_from_host(&host_t)?;
401 } else {
402 return Err(SolverError::InternalError(format!(
403 "generate_gaussian_matrix: unsupported precision size {}",
404 T::SIZE
405 )));
406 }
407
408 Ok(buffer)
409}
410
411#[allow(clippy::too_many_arguments)]
415fn gemm_multiply<T: GpuFloat>(
416 handle: &SolverHandle,
417 trans_a: Transpose,
418 trans_b: Transpose,
419 _m: u32,
420 n: u32,
421 k: u32,
422 a: &DeviceBuffer<T>,
423 lda: u32,
424 b: &DeviceBuffer<T>,
425 ldb: u32,
426 c: &mut DeviceBuffer<T>,
427 ldc: u32,
428) -> SolverResult<()> {
429 let a_desc = MatrixDesc::<T>::from_raw(a.as_device_ptr(), lda, k, lda, Layout::ColMajor);
430 let b_desc = MatrixDesc::<T>::from_raw(b.as_device_ptr(), ldb, n, ldb, Layout::ColMajor);
431 let mut c_desc = MatrixDescMut::<T>::from_raw(c.as_device_ptr(), ldc, n, ldc, Layout::ColMajor);
432
433 oxicuda_blas::level3::gemm_api::gemm(
434 handle.blas(),
435 trans_a,
436 trans_b,
437 T::gpu_one(),
438 &a_desc,
439 &b_desc,
440 T::gpu_zero(),
441 &mut c_desc,
442 )?;
443
444 Ok(())
445}
446
447fn truncate_to_rank<T: GpuFloat>(singular_values: &[T], max_rank: usize) -> Vec<T> {
449 let mut result: Vec<T> = singular_values.iter().take(max_rank).copied().collect();
450
451 if let Some(&first) = result.first() {
454 let threshold_bits = if T::SIZE == 4 {
455 let first_bits = first.to_bits_u64() as u32;
457 let first_f32 = f32::from_bits(first_bits);
458 let thresh = first_f32 * 1e-7;
459 u64::from(thresh.to_bits())
460 } else {
461 let first_f64 = f64::from_bits(first.to_bits_u64());
463 let thresh = first_f64 * 1e-14;
464 thresh.to_bits()
465 };
466 let threshold = T::from_bits_u64(threshold_bits);
467
468 while result.len() > 1 {
470 if let Some(&last) = result.last() {
471 let last_abs_bits = if T::SIZE == 4 {
473 let bits = last.to_bits_u64() as u32;
474 u64::from(bits & 0x7FFF_FFFF)
475 } else {
476 last.to_bits_u64() & 0x7FFF_FFFF_FFFF_FFFF
477 };
478 let threshold_abs_bits = if T::SIZE == 4 {
479 let bits = threshold.to_bits_u64() as u32;
480 u64::from(bits & 0x7FFF_FFFF)
481 } else {
482 threshold.to_bits_u64() & 0x7FFF_FFFF_FFFF_FFFF
483 };
484
485 if last_abs_bits <= threshold_abs_bits {
486 result.pop();
487 } else {
488 break;
489 }
490 } else {
491 break;
492 }
493 }
494 }
495
496 result
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_default() {
        let config = RandomizedSvdConfig::default();
        assert_eq!(config.rank, DEFAULT_RANK);
        assert_eq!(config.oversampling, DEFAULT_OVERSAMPLING);
        assert_eq!(config.power_iterations, DEFAULT_POWER_ITERATIONS);
        assert_eq!(config.seed, 42);
    }

    #[test]
    fn config_builder() {
        let config = RandomizedSvdConfig::with_rank(20)
            .oversampling(10)
            .power_iterations(2)
            .seed(123);
        assert_eq!(config.rank, 20);
        assert_eq!(config.oversampling, 10);
        assert_eq!(config.power_iterations, 2);
        assert_eq!(config.seed, 123);
    }

    #[test]
    fn config_sampling_dim() {
        let config = RandomizedSvdConfig::with_rank(15).oversampling(5);
        assert_eq!(config.sampling_dim(), 20);
    }

    #[test]
    fn truncate_to_rank_basic() {
        let sigma: Vec<f64> = vec![5.0, 3.0, 1.0, 0.5, 0.001];
        let result = truncate_to_rank(&sigma, 3);
        assert_eq!(result.len(), 3);
        assert!((result[0] - 5.0).abs() < 1e-10);
        assert!((result[1] - 3.0).abs() < 1e-10);
        assert!((result[2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn truncate_to_rank_removes_zeros() {
        let sigma: Vec<f64> = vec![5.0, 3.0, 0.0, 0.0];
        let result = truncate_to_rank(&sigma, 4);
        // Exact zeros fall below the 1e-14 relative cutoff and are dropped.
        assert_eq!(result, vec![5.0, 3.0]);
    }

    #[test]
    fn truncate_to_rank_keeps_at_least_one() {
        // Even an all-zero spectrum keeps its leading entry.
        let sigma: Vec<f64> = vec![0.0, 0.0, 0.0];
        assert_eq!(truncate_to_rank(&sigma, 3), vec![0.0]);
    }

    #[test]
    fn truncate_to_rank_zero_max_rank() {
        let sigma: Vec<f64> = vec![1.0, 0.5];
        assert!(truncate_to_rank(&sigma, 0).is_empty());
    }

    #[test]
    fn truncate_to_rank_empty() {
        let sigma: Vec<f64> = Vec::new();
        let result = truncate_to_rank(&sigma, 5);
        assert!(result.is_empty());
    }

    #[test]
    fn truncate_to_rank_f32() {
        let sigma: Vec<f32> = vec![10.0, 5.0, 2.0, 0.0];
        let result = truncate_to_rank(&sigma, 3);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn truncate_to_rank_f32_relative_cutoff() {
        // Single precision uses a 1e-7 relative cutoff: 1e-8 is dropped,
        // 1e-6 is kept.
        let dropped: Vec<f32> = vec![1.0, 1e-8];
        assert_eq!(truncate_to_rank(&dropped, 2), vec![1.0_f32]);
        let kept: Vec<f32> = vec![1.0, 1e-6];
        assert_eq!(truncate_to_rank(&kept, 2).len(), 2);
    }

    #[test]
    fn truncate_to_rank_max_smaller() {
        let sigma: Vec<f64> = vec![10.0, 5.0, 2.0, 1.0];
        let result = truncate_to_rank(&sigma, 2);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn config_rng_engine_default() {
        let config = RandomizedSvdConfig::default();
        assert!(matches!(config.rng_engine, RngEngine::Philox));
    }

    /// Row-major CPU reference GEMM: `C (m × n) = A (m × k) · B (k × n)`.
    fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut c = vec![0.0_f32; m * n];
        for row in 0..m {
            for col in 0..n {
                let mut acc = 0.0_f32;
                for ki in 0..k {
                    acc = f32::mul_add(a[row * k + ki], b[ki * n + col], acc);
                }
                c[row * n + col] = acc;
            }
        }
        c
    }

    /// Validates the CPU reference GEMM (`alpha · A·B + beta · C`) against a
    /// hand-computed 2×3 · 3×2 product. The device `gemm_multiply` needs a
    /// GPU and is not exercised here.
    #[test]
    #[allow(clippy::type_complexity)] // the fn-pointer type spells out the BLAS-like signature
    fn rsvd_gemm_multiply_signature_exists() {
        let reference_gemm: fn(usize, usize, usize, f32, &[f32], &[f32], f32, &[f32]) -> Vec<f32> =
            |m, k, n, alpha, a, b, beta, c| {
                let raw = cpu_matmul_f32(a, b, m, k, n);
                raw.iter()
                    .zip(c.iter())
                    .map(|(&r, &c_val)| alpha * r + beta * c_val)
                    .collect()
            };

        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![7.0_f32, 8.0, 9.0, 10.0, 11.0, 12.0];
        let c_init = vec![0.0_f32; 4];
        let result = reference_gemm(2, 3, 2, 1.0, &a, &b, 0.0, &c_init);
        assert!(
            (result[0] - 58.0).abs() < 1e-4,
            "GEMM C[0,0] expected 58, got {}",
            result[0]
        );
        assert!(
            (result[1] - 64.0).abs() < 1e-4,
            "GEMM C[0,1] expected 64, got {}",
            result[1]
        );
        assert!(
            (result[2] - 139.0).abs() < 1e-4,
            "GEMM C[1,0] expected 139, got {}",
            result[2]
        );
        assert!(
            (result[3] - 154.0).abs() < 1e-4,
            "GEMM C[1,1] expected 154, got {}",
            result[3]
        );
    }

    /// CPU-only proxy for the sketch product `Y = A Ω` (256×128 · 128×16).
    /// It checks that the sketch is non-zero and reports a rough GFLOPS figure.
    /// The iteration count is kept small so the unit-test suite stays fast;
    /// real throughput belongs in a dedicated benchmark.
    #[test]
    fn rsvd_gemm_sketch_throughput_proxy_256x128_rank16() {
        let m = 256_usize;
        let k = 128_usize;
        let r = 16_usize;
        let a: Vec<f32> = (0..m * k)
            .map(|i| ((i as f32 * 1.618_034_f32).fract() - 0.5) * 2.0)
            .collect();

        let omega: Vec<f32> = (0..k * r)
            .map(|i| ((i as f32 * std::f32::consts::E).fract() - 0.5) * 0.5)
            .collect();

        let c_zero = vec![0.0_f32; m * r];

        // Warm-up pass.
        let _ = cpu_matmul_f32(&a, &omega, m, k, r);

        const ITERS: usize = 10;
        let start = std::time::Instant::now();
        let mut sketch = vec![0.0_f32; m * r];
        for _ in 0..ITERS {
            let raw = cpu_matmul_f32(&a, &omega, m, k, r);
            sketch = raw
                .into_iter()
                .zip(c_zero.iter())
                .map(|(r_val, &c_val)| r_val + c_val)
                .collect();
        }
        let elapsed_ns = start.elapsed().as_nanos() as f64;

        let flops_per_gemm = 2.0 * m as f64 * k as f64 * r as f64;
        // FLOPs per nanosecond equals GFLOP/s.
        let gflops = (flops_per_gemm * ITERS as f64) / elapsed_ns;

        println!(
            "rSVD GEMM sketch proxy ({}×{} × {}×{}, {} iters): {:.3} GFLOPS (CPU reference)",
            m, k, k, r, ITERS, gflops
        );

        let sketch_norm: f32 = sketch.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            sketch_norm > 0.01,
            "Sketch must be non-zero, got norm={}",
            sketch_norm
        );
        assert!(
            gflops > 0.0001,
            "GEMM sketch throughput unrealistically low: {:.6} GFLOPS",
            gflops
        );
    }
}
712}