1use oxicuda_blas::types::{GpuFloat, Layout, MatrixDesc, MatrixDescMut, Transpose};
22use oxicuda_memory::DeviceBuffer;
23use oxicuda_rand::{RngEngine, RngGenerator};
24
25use crate::dense::qr;
26use crate::dense::svd;
27use crate::error::{SolverError, SolverResult};
28use crate::handle::SolverHandle;
29
/// Default target rank `k` used by `RandomizedSvdConfig::default`.
const DEFAULT_RANK: usize = 10;

/// Default oversampling `p`: extra random samples drawn beyond the target rank.
const DEFAULT_OVERSAMPLING: usize = 5;

/// Default number of power iterations `q` applied to the sampled range.
const DEFAULT_POWER_ITERATIONS: usize = 1;
42
/// Tuning parameters for [`randomized_svd`].
///
/// The number of random samples drawn is `rank + oversampling` (see
/// `sampling_dim`); `randomized_svd` clamps that count to `min(m, n)`.
#[derive(Debug, Clone)]
pub struct RandomizedSvdConfig {
    /// Target rank `k`: the maximum number of singular triplets returned.
    pub rank: usize,
    /// Oversampling `p`: extra random samples beyond `rank`.
    pub oversampling: usize,
    /// Number of power iterations `q` (each multiplies by `Aᵀ` and then by `A`).
    pub power_iterations: usize,
    /// Random number generator engine used to draw the Gaussian test matrix.
    pub rng_engine: RngEngine,
    /// Seed passed to the random number generator.
    pub seed: u64,
}
61
62impl Default for RandomizedSvdConfig {
63 fn default() -> Self {
64 Self {
65 rank: DEFAULT_RANK,
66 oversampling: DEFAULT_OVERSAMPLING,
67 power_iterations: DEFAULT_POWER_ITERATIONS,
68 rng_engine: RngEngine::Philox,
69 seed: 42,
70 }
71 }
72}
73
74impl RandomizedSvdConfig {
75 pub fn with_rank(rank: usize) -> Self {
77 Self {
78 rank,
79 ..Self::default()
80 }
81 }
82
83 pub fn oversampling(mut self, p: usize) -> Self {
85 self.oversampling = p;
86 self
87 }
88
89 pub fn power_iterations(mut self, q: usize) -> Self {
91 self.power_iterations = q;
92 self
93 }
94
95 pub fn seed(mut self, seed: u64) -> Self {
97 self.seed = seed;
98 self
99 }
100
101 pub fn sampling_dim(&self) -> usize {
103 self.rank + self.oversampling
104 }
105}
106
/// Output of [`randomized_svd`]: the factors of the intended truncated
/// decomposition `A ≈ U · diag(sigma) · Vt`.
///
/// `Debug` is implemented manually and reports the device buffers by length only.
pub struct RandomizedSvdResult<T: GpuFloat> {
    /// Device buffer holding `m * rank` elements for the left singular vectors
    /// (presumably column-major `m x rank` — confirm against consumers).
    pub u: DeviceBuffer<T>,
    /// Host-side singular values after rank truncation; `sigma.len() == rank`.
    pub sigma: Vec<T>,
    /// Device buffer holding `rank * n` elements for the transposed right singular
    /// vectors (presumably `rank x n`).
    pub vt: DeviceBuffer<T>,
    /// Number of singular values actually retained; may be smaller than the
    /// configured rank after clamping and negligible-value truncation.
    pub rank: usize,
}
118
119impl<T: GpuFloat> std::fmt::Debug for RandomizedSvdResult<T> {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.debug_struct("RandomizedSvdResult")
122 .field("sigma", &self.sigma)
123 .field("rank", &self.rank)
124 .field("u_len", &self.u.len())
125 .field("vt_len", &self.vt.len())
126 .finish()
127 }
128}
129
/// Computes an approximate rank-`config.rank` SVD of the `m x n` matrix `a`
/// with a randomized range finder.
///
/// Steps:
/// 1. Draw an `n x l` Gaussian test matrix `Ω` (`l = rank + oversampling`,
///    clamped to `min(m, n)`).
/// 2. Form `Y = A·Ω`, optionally refined by `config.power_iterations` rounds of
///    `Ŷ = Aᵀ·Y`, `QR(Ŷ)`, `Y = A·Ŷ`.
/// 3. Factor `Y` with QR, form the small `l x n` matrix `B = Yᵀ·A`, and take its
///    thin SVD.
/// 4. Truncate the singular values with `truncate_to_rank`.
///
/// `a` is read as a column-major `m x n` matrix with leading dimension `m`.
///
/// # Errors
///
/// Returns `SolverError::DimensionMismatch` if:
/// * `m` or `n` is zero,
/// * `a` holds fewer than `m * n` elements, or
/// * `rank + oversampling` is zero.
///
/// Errors from device allocation, RNG, GEMM, QR and SVD calls are propagated.
///
/// NOTE(review): the returned `u` and `vt` buffers are currently zero-filled.
/// The orthonormal basis `Q` is never formed explicitly. See the inline notes.
pub fn randomized_svd<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &DeviceBuffer<T>,
    m: u32,
    n: u32,
    config: &RandomizedSvdConfig,
) -> SolverResult<RandomizedSvdResult<T>> {
    if m == 0 || n == 0 {
        return Err(SolverError::DimensionMismatch(
            "randomized_svd: matrix dimensions must be positive".into(),
        ));
    }
    let required = m as usize * n as usize;
    if a.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "randomized_svd: buffer too small ({} < {required})",
            a.len()
        )));
    }

    let k = config.rank;
    let p = config.oversampling;
    let l = k + p;
    if l == 0 {
        return Err(SolverError::DimensionMismatch(
            "randomized_svd: rank + oversampling must be positive".into(),
        ));
    }

    // The sample count can never exceed min(m, n); the returned rank is bounded
    // by both the requested rank and that clamped sample count.
    let min_mn = m.min(n) as usize;
    let l = l.min(min_mn);
    let effective_rank = k.min(l);

    let omega = generate_gaussian_matrix::<T>(handle, n as usize, l, config)?;

    // Y = A·Ω: (m x n)·(n x l) -> m x l.
    let mut y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
    gemm_multiply::<T>(
        handle,
        Transpose::NoTrans,
        Transpose::NoTrans,
        m,
        l as u32,
        n,
        a,
        m,
        &omega,
        n,
        &mut y,
        m,
    )?;

    for _q in 0..config.power_iterations {
        // Ŷ = Aᵀ·Y: (n x m)·(m x l) -> n x l.
        let mut y_hat = DeviceBuffer::<T>::zeroed(n as usize * l)?;
        gemm_multiply::<T>(
            handle,
            Transpose::Trans,
            Transpose::NoTrans,
            n,
            l as u32,
            m,
            a,
            m,
            &y,
            m,
            &mut y_hat,
            n,
        )?;

        // NOTE(review): `qr_factorize` presumably leaves packed R / Householder
        // data in `y_hat` (with scalars in `tau_hat`) rather than an explicit
        // orthonormal Q. `y_hat` is nevertheless multiplied below as if it were Q.
        // Confirm against the qr module.
        let mut tau_hat = DeviceBuffer::<T>::zeroed(l)?;
        qr::qr_factorize(handle, &mut y_hat, n, l as u32, n, &mut tau_hat)?;

        // Y = A·Ŷ: (m x n)·(n x l) -> m x l, into a fresh buffer each round.
        y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
        gemm_multiply::<T>(
            handle,
            Transpose::NoTrans,
            Transpose::NoTrans,
            m,
            l as u32,
            n,
            a,
            m,
            &y_hat,
            n,
            &mut y,
            m,
        )?;
    }

    let mut tau = DeviceBuffer::<T>::zeroed(l)?;
    qr::qr_factorize(handle, &mut y, m, l as u32, m, &mut tau)?;

    // NOTE(review): this buffer is allocated but never filled or used. Q is never
    // formed explicitly, so B below is computed from the factorized contents of
    // `y`. TODO: generate Q from (`y`, `tau`) before forming B.
    let _q_matrix = DeviceBuffer::<T>::zeroed(m as usize * l)?;

    // B = Yᵀ·A: (l x m)·(m x n) -> l x n, leading dimension l.
    let mut b_matrix = DeviceBuffer::<T>::zeroed(l * n as usize)?;
    gemm_multiply::<T>(
        handle,
        Transpose::Trans,
        Transpose::NoTrans,
        l as u32,
        n,
        m,
        &y,
        m,
        a,
        m,
        &mut b_matrix,
        l as u32,
    )?;

    let svd_result = svd::svd(
        handle,
        &mut b_matrix,
        l as u32,
        n,
        l as u32,
        svd::SvdJob::Thin,
    )?;

    let sigma = truncate_to_rank(&svd_result.singular_values, effective_rank);
    let actual_rank = sigma.len();

    // NOTE(review): `u_hat` is discarded and `u_final` is returned zero-filled.
    // The expected U = Q·Û[:, ..actual_rank] product is not computed.
    let u_out = if let Some(ref u_hat) = svd_result.u {
        let u_final = DeviceBuffer::<T>::zeroed(m as usize * actual_rank)?;
        let _ = u_hat;
        u_final
    } else {
        DeviceBuffer::<T>::zeroed(m as usize * actual_rank)?
    };

    // NOTE(review): `vt_hat` is discarded and `vt_final` is returned zero-filled.
    // The first `actual_rank` rows of V̂ᵀ are not copied out.
    let vt_out = if let Some(ref vt_hat) = svd_result.vt {
        let vt_final = DeviceBuffer::<T>::zeroed(actual_rank * n as usize)?;
        let _ = vt_hat;
        vt_final
    } else {
        DeviceBuffer::<T>::zeroed(actual_rank * n as usize)?
    };

    Ok(RandomizedSvdResult {
        u: u_out,
        sigma,
        vt: vt_out,
        rank: actual_rank,
    })
}
324
325fn generate_gaussian_matrix<T: GpuFloat>(
331 handle: &SolverHandle,
332 rows: usize,
333 cols: usize,
334 config: &RandomizedSvdConfig,
335) -> SolverResult<DeviceBuffer<T>> {
336 let total = rows * cols;
337 let buffer = DeviceBuffer::<T>::zeroed(total)?;
338
339 let mut rng = RngGenerator::new(config.rng_engine, config.seed, handle.context())
341 .map_err(|e| SolverError::InternalError(format!("RNG creation failed: {e}")))?;
342
343 if T::SIZE == 4 {
345 let mut f32_buf = DeviceBuffer::<f32>::zeroed(total)?;
351 rng.generate_normal_f32(&mut f32_buf, 0.0, 1.0)
352 .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;
353 } else if T::SIZE == 8 {
356 let mut f64_buf = DeviceBuffer::<f64>::zeroed(total)?;
358 rng.generate_normal_f64(&mut f64_buf, 0.0, 1.0)
359 .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;
360 }
361 Ok(buffer)
364}
365
366#[allow(clippy::too_many_arguments)]
370fn gemm_multiply<T: GpuFloat>(
371 handle: &SolverHandle,
372 trans_a: Transpose,
373 trans_b: Transpose,
374 _m: u32,
375 n: u32,
376 k: u32,
377 a: &DeviceBuffer<T>,
378 lda: u32,
379 b: &DeviceBuffer<T>,
380 ldb: u32,
381 c: &mut DeviceBuffer<T>,
382 ldc: u32,
383) -> SolverResult<()> {
384 let a_desc = MatrixDesc::<T>::from_raw(a.as_device_ptr(), lda, k, lda, Layout::ColMajor);
385 let b_desc = MatrixDesc::<T>::from_raw(b.as_device_ptr(), ldb, n, ldb, Layout::ColMajor);
386 let mut c_desc = MatrixDescMut::<T>::from_raw(c.as_device_ptr(), ldc, n, ldc, Layout::ColMajor);
387
388 oxicuda_blas::level3::gemm_api::gemm(
389 handle.blas(),
390 trans_a,
391 trans_b,
392 T::gpu_one(),
393 &a_desc,
394 &b_desc,
395 T::gpu_zero(),
396 &mut c_desc,
397 )?;
398
399 Ok(())
400}
401
402fn truncate_to_rank<T: GpuFloat>(singular_values: &[T], max_rank: usize) -> Vec<T> {
404 let mut result: Vec<T> = singular_values.iter().take(max_rank).copied().collect();
405
406 if let Some(&first) = result.first() {
409 let threshold_bits = if T::SIZE == 4 {
410 let first_bits = first.to_bits_u64() as u32;
412 let first_f32 = f32::from_bits(first_bits);
413 let thresh = first_f32 * 1e-7;
414 u64::from(thresh.to_bits())
415 } else {
416 let first_f64 = f64::from_bits(first.to_bits_u64());
418 let thresh = first_f64 * 1e-14;
419 thresh.to_bits()
420 };
421 let threshold = T::from_bits_u64(threshold_bits);
422
423 while result.len() > 1 {
425 if let Some(&last) = result.last() {
426 let last_abs_bits = if T::SIZE == 4 {
428 let bits = last.to_bits_u64() as u32;
429 u64::from(bits & 0x7FFF_FFFF)
430 } else {
431 last.to_bits_u64() & 0x7FFF_FFFF_FFFF_FFFF
432 };
433 let threshold_abs_bits = if T::SIZE == 4 {
434 let bits = threshold.to_bits_u64() as u32;
435 u64::from(bits & 0x7FFF_FFFF)
436 } else {
437 threshold.to_bits_u64() & 0x7FFF_FFFF_FFFF_FFFF
438 };
439
440 if last_abs_bits <= threshold_abs_bits {
441 result.pop();
442 } else {
443 break;
444 }
445 } else {
446 break;
447 }
448 }
449 }
450
451 result
452}
453
/// Unit tests for the host-side configuration builder and singular-value truncation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_default() {
        let config = RandomizedSvdConfig::default();
        assert_eq!(config.rank, DEFAULT_RANK);
        assert_eq!(config.oversampling, DEFAULT_OVERSAMPLING);
        assert_eq!(config.power_iterations, DEFAULT_POWER_ITERATIONS);
        assert_eq!(config.seed, 42);
    }

    #[test]
    fn config_builder() {
        let config = RandomizedSvdConfig::with_rank(20)
            .oversampling(10)
            .power_iterations(2)
            .seed(123);
        assert_eq!(config.rank, 20);
        assert_eq!(config.oversampling, 10);
        assert_eq!(config.power_iterations, 2);
        assert_eq!(config.seed, 123);
    }

    #[test]
    fn config_sampling_dim() {
        let config = RandomizedSvdConfig::with_rank(15).oversampling(5);
        assert_eq!(config.sampling_dim(), 20);
    }

    #[test]
    fn truncate_to_rank_basic() {
        let sigma: Vec<f64> = vec![5.0, 3.0, 1.0, 0.5, 0.001];
        let result = truncate_to_rank(&sigma, 3);
        assert_eq!(result.len(), 3);
        assert!((result[0] - 5.0).abs() < 1e-10);
        assert!((result[1] - 3.0).abs() < 1e-10);
        assert!((result[2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn truncate_to_rank_removes_zeros() {
        // Trailing exact zeros fall below `sigma_0 * 1e-14` and must be dropped.
        let sigma: Vec<f64> = vec![5.0, 3.0, 0.0, 0.0];
        let result = truncate_to_rank(&sigma, 4);
        assert_eq!(result, vec![5.0, 3.0]);
    }

    #[test]
    fn truncate_to_rank_keeps_values_above_threshold() {
        // 1e-15 is below the f64 cutoff (1e-14) and is trimmed; 2e-14 survives.
        let sigma: Vec<f64> = vec![1.0, 2e-14, 1e-15];
        assert_eq!(truncate_to_rank(&sigma, 3), vec![1.0, 2e-14]);
    }

    #[test]
    fn truncate_to_rank_keeps_at_least_one() {
        let sigma: Vec<f64> = vec![0.0, 0.0, 0.0];
        assert_eq!(truncate_to_rank(&sigma, 3), vec![0.0]);
    }

    #[test]
    fn truncate_to_rank_empty() {
        let sigma: Vec<f64> = Vec::new();
        let result = truncate_to_rank(&sigma, 5);
        assert!(result.is_empty());
    }

    #[test]
    fn truncate_to_rank_f32() {
        let sigma: Vec<f32> = vec![10.0, 5.0, 2.0, 0.0];
        let result = truncate_to_rank(&sigma, 3);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn truncate_to_rank_f32_removes_zeros() {
        // Exercises the 4-byte threshold path with a trailing zero in range.
        let sigma: Vec<f32> = vec![10.0, 5.0, 2.0, 0.0];
        assert_eq!(truncate_to_rank(&sigma, 4), vec![10.0, 5.0, 2.0]);
    }

    #[test]
    fn truncate_to_rank_max_smaller() {
        let sigma: Vec<f64> = vec![10.0, 5.0, 2.0, 1.0];
        let result = truncate_to_rank(&sigma, 2);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn config_rng_engine_default() {
        let config = RandomizedSvdConfig::default();
        assert!(matches!(config.rng_engine, RngEngine::Philox));
    }
}