use oxicuda_blas::types::{GpuFloat, Layout, MatrixDesc, MatrixDescMut, Transpose};
use oxicuda_memory::DeviceBuffer;
use oxicuda_rand::{RngEngine, RngGenerator};
use crate::dense::qr;
use crate::dense::svd;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Default number of extra sketch columns beyond the target rank.
const DEFAULT_OVERSAMPLING: usize = 5;
/// Default number of power (subspace) iterations applied to the sketch.
const DEFAULT_POWER_ITERATIONS: usize = 1;
/// Default target rank of the approximation.
const DEFAULT_RANK: usize = 10;
/// Tuning parameters for [`randomized_svd`].
#[derive(Debug, Clone)]
pub struct RandomizedSvdConfig {
/// Target number of singular triplets to recover.
pub rank: usize,
/// Extra sketch columns beyond `rank` (improves subspace capture).
pub oversampling: usize,
/// Number of power iterations used to sharpen the sketch.
pub power_iterations: usize,
/// RNG engine used to draw the Gaussian test matrix.
pub rng_engine: RngEngine,
/// Seed for the RNG (fixed seed gives reproducible sketches).
pub seed: u64,
}
impl Default for RandomizedSvdConfig {
fn default() -> Self {
Self {
rank: DEFAULT_RANK,
oversampling: DEFAULT_OVERSAMPLING,
power_iterations: DEFAULT_POWER_ITERATIONS,
rng_engine: RngEngine::Philox,
seed: 42,
}
}
}
impl RandomizedSvdConfig {
pub fn with_rank(rank: usize) -> Self {
Self {
rank,
..Self::default()
}
}
pub fn oversampling(mut self, p: usize) -> Self {
self.oversampling = p;
self
}
pub fn power_iterations(mut self, q: usize) -> Self {
self.power_iterations = q;
self
}
pub fn seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
pub fn sampling_dim(&self) -> usize {
self.rank + self.oversampling
}
}
/// Output of [`randomized_svd`]: the thin factors of A ≈ U·diag(sigma)·Vᵀ.
pub struct RandomizedSvdResult<T: GpuFloat> {
/// Left singular vectors, m×rank, column-major, device-resident.
pub u: DeviceBuffer<T>,
/// Singular values on the host, largest first (the truncation logic in
/// `truncate_to_rank` relies on the first entry being the largest).
pub sigma: Vec<T>,
/// Right singular vectors transposed, rank×n, column-major, device-resident.
pub vt: DeviceBuffer<T>,
/// Number of retained singular triplets (equals `sigma.len()`).
pub rank: usize,
}
impl<T: GpuFloat> std::fmt::Debug for RandomizedSvdResult<T> {
    /// Debug output reports buffer lengths instead of device contents, since
    /// the device buffers are opaque from the host side.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RandomizedSvdResult");
        out.field("sigma", &self.sigma);
        out.field("rank", &self.rank);
        out.field("u_len", &self.u.len());
        out.field("vt_len", &self.vt.len());
        out.finish()
    }
}
/// Computes a randomized truncated SVD of the m×n column-major device matrix
/// `a`, returning A ≈ U·diag(sigma)·Vᵀ with at most `config.rank` triplets.
///
/// Sketch-and-solve scheme (Halko–Martinsson–Tropp style):
///   1. Draw a Gaussian test matrix Ω (n×l), l = rank + oversampling, clamped
///      to min(m, n).
///   2. Sketch Y = A·Ω (m×l), optionally sharpened by power iterations.
///   3. QR-factorize Y and materialize the orthonormal basis Q.
///   4. Form the compressed matrix B = Qᵀ·A (l×n) and take its dense thin SVD.
///   5. Lift the left factor back: U = Q·Û; sigma and Vᵀ are truncated to rank.
///
/// # Errors
/// `DimensionMismatch` for zero dimensions, an undersized `a`, or
/// `rank + oversampling == 0`; RNG/QR/SVD/GEMM failures are propagated.
pub fn randomized_svd<T: GpuFloat>(
handle: &mut SolverHandle,
a: &DeviceBuffer<T>,
m: u32,
n: u32,
config: &RandomizedSvdConfig,
) -> SolverResult<RandomizedSvdResult<T>> {
// --- input validation -------------------------------------------------
if m == 0 || n == 0 {
return Err(SolverError::DimensionMismatch(
"randomized_svd: matrix dimensions must be positive".into(),
));
}
let required = m as usize * n as usize;
if a.len() < required {
return Err(SolverError::DimensionMismatch(format!(
"randomized_svd: buffer too small ({} < {required})",
a.len()
)));
}
// Sketch width l = rank + oversampling, clamped to min(m, n); the reported
// rank can therefore never exceed the sketch width.
let k = config.rank;
let p = config.oversampling;
let l = k + p;
if l == 0 {
return Err(SolverError::DimensionMismatch(
"randomized_svd: rank + oversampling must be positive".into(),
));
}
let min_mn = m.min(n) as usize;
let l = l.min(min_mn);
let effective_rank = k.min(l);
// Step 1: Gaussian test matrix Ω (n×l).
let omega = generate_gaussian_matrix::<T>(handle, n as usize, l, config)?;
// Step 2: sketch Y = A·Ω (m×l).
let mut y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
gemm_multiply::<T>(
handle,
Transpose::NoTrans,
Transpose::NoTrans,
m,
l as u32,
n,
a,
m,
&omega,
n,
&mut y,
m,
)?;
// Power iterations: Y ← A·qr(Aᵀ·Y), pushing the sketch toward the dominant
// subspace.
// NOTE(review): after `qr_factorize`, `y_hat` holds the packed factorization
// (Householder vectors below the diagonal plus R) rather than an explicit
// orthonormal Q, yet it is fed straight into the next GEMM. Confirm the
// output convention of `qr_factorize`, or whether a `qr_generate_q` step is
// missing here.
for _q in 0..config.power_iterations {
// Ŷ = Aᵀ·Y (n×l).
let mut y_hat = DeviceBuffer::<T>::zeroed(n as usize * l)?;
gemm_multiply::<T>(
handle,
Transpose::Trans,
Transpose::NoTrans,
n,
l as u32,
m,
a,
m,
&y,
m,
&mut y_hat,
n,
)?;
// Re-orthonormalize Ŷ in place (packed QR, tau holds Householder scalars).
let mut tau_hat = DeviceBuffer::<T>::zeroed(l)?;
qr::qr_factorize(handle, &mut y_hat, n, l as u32, n, &mut tau_hat)?;
// Y = A·Ŷ (m×l) for the next round.
y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
gemm_multiply::<T>(
handle,
Transpose::NoTrans,
Transpose::NoTrans,
m,
l as u32,
n,
a,
m,
&y_hat,
n,
&mut y,
m,
)?;
}
// Step 3: QR of the sketch, then materialize the orthonormal basis Q.
let mut tau = DeviceBuffer::<T>::zeroed(l)?;
qr::qr_factorize(handle, &mut y, m, l as u32, m, &mut tau)?;
// Q is allocated m×m although only its first l columns are generated.
// NOTE(review): confirm whether `qr_generate_q` or the Trans GEMM below
// requires the square storage; otherwise m×l would suffice.
let mut q_explicit = DeviceBuffer::<T>::zeroed(m as usize * m as usize)?;
qr::qr_generate_q(handle, &y, &tau, &mut q_explicit, m, l as u32)?;
// Step 4: B = Qᵀ·A (l×n), the compressed matrix.
let mut b_matrix = DeviceBuffer::<T>::zeroed(l * n as usize)?;
gemm_multiply::<T>(
handle,
Transpose::Trans,
Transpose::NoTrans,
l as u32,
n,
m,
&q_explicit,
m,
a,
m,
&mut b_matrix,
l as u32,
)?;
// Dense thin SVD of the small matrix: B = Û·Σ·Vᵀ.
let svd_result = svd::svd(
handle,
&mut b_matrix,
l as u32,
n,
l as u32,
svd::SvdJob::Thin,
)?;
// Truncate to the requested rank and drop negligible trailing values.
let sigma = truncate_to_rank(&svd_result.singular_values, effective_rank);
let actual_rank = sigma.len();
// Step 5: lift the left factor back to full height: U = Q·Û[:, :rank].
let u_out = if let Some(ref u_hat) = svd_result.u {
let k_hat = svd_result.singular_values.len();
let rank_used = actual_rank.min(k_hat);
// Û is l×k_hat with leading dimension l; copy its leading `rank_used`
// columns into a zero-padded l×actual_rank host staging buffer.
let mut u_hat_rank_host = vec![T::gpu_zero(); l * actual_rank];
for col in 0..rank_used {
for row in 0..l {
u_hat_rank_host[col * l + row] = u_hat[col * l + row];
}
}
let mut u_hat_rank = DeviceBuffer::<T>::zeroed(l * actual_rank)?;
u_hat_rank.copy_from_host(&u_hat_rank_host)?;
// U = Q·Û_rank (m×actual_rank).
let mut u_final = DeviceBuffer::<T>::zeroed(m as usize * actual_rank)?;
gemm_multiply::<T>(
handle,
Transpose::NoTrans,
Transpose::NoTrans,
m,
actual_rank as u32,
l as u32,
&q_explicit,
m,
&u_hat_rank,
l as u32,
&mut u_final,
m,
)?;
u_final
} else {
// SVD produced no U; hand back a zeroed placeholder of the right shape.
DeviceBuffer::<T>::zeroed(m as usize * actual_rank)?
};
// Repack Vᵀ from leading dimension k_hat down to `actual_rank`, keeping only
// its first `actual_rank` rows.
let vt_out = if let Some(ref vt_hat) = svd_result.vt {
let n_usize = n as usize;
let k_hat = svd_result.singular_values.len();
let rank_used = actual_rank.min(k_hat);
let mut vt_host = vec![T::gpu_zero(); actual_rank * n_usize];
for col in 0..n_usize {
for row in 0..rank_used {
vt_host[col * actual_rank + row] = vt_hat[col * k_hat + row];
}
}
let mut vt_final = DeviceBuffer::<T>::zeroed(actual_rank * n_usize)?;
vt_final.copy_from_host(&vt_host)?;
vt_final
} else {
// SVD produced no Vᵀ; hand back a zeroed placeholder of the right shape.
DeviceBuffer::<T>::zeroed(actual_rank * n as usize)?
};
Ok(RandomizedSvdResult {
u: u_out,
sigma,
vt: vt_out,
rank: actual_rank,
})
}
/// Fills a fresh `rows`×`cols` device buffer with standard-normal samples
/// encoded in precision `T` (only f32- and f64-sized types are supported).
fn generate_gaussian_matrix<T: GpuFloat>(
    handle: &SolverHandle,
    rows: usize,
    cols: usize,
    config: &RandomizedSvdConfig,
) -> SolverResult<DeviceBuffer<T>> {
    let count = rows * cols;
    let mut out = DeviceBuffer::<T>::zeroed(count)?;
    let mut rng = RngGenerator::new(config.rng_engine, config.seed, handle.context())
        .map_err(|e| SolverError::InternalError(format!("RNG creation failed: {e}")))?;
    // The RNG only emits native f32/f64 streams, so sample into a native
    // buffer, stage through the host, and re-encode the bit patterns into `T`.
    match T::SIZE {
        4 => {
            let mut native = DeviceBuffer::<f32>::zeroed(count)?;
            rng.generate_normal_f32(&mut native, 0.0, 1.0)
                .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;
            let mut staging = vec![0.0_f32; count];
            native.copy_to_host(&mut staging)?;
            let converted: Vec<T> = staging
                .iter()
                .map(|v| T::from_bits_u64(u64::from(v.to_bits())))
                .collect();
            out.copy_from_host(&converted)?;
        }
        8 => {
            let mut native = DeviceBuffer::<f64>::zeroed(count)?;
            rng.generate_normal_f64(&mut native, 0.0, 1.0)
                .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;
            let mut staging = vec![0.0_f64; count];
            native.copy_to_host(&mut staging)?;
            let converted: Vec<T> = staging
                .iter()
                .map(|v| T::from_bits_u64(v.to_bits()))
                .collect();
            out.copy_from_host(&converted)?;
        }
        size => {
            return Err(SolverError::InternalError(format!(
                "generate_gaussian_matrix: unsupported precision size {}",
                size
            )));
        }
    }
    Ok(out)
}
/// C = op(A)·op(B) with alpha = 1 and beta = 0, via the project's BLAS GEMM.
/// All operands are column-major device buffers; `lda`/`ldb`/`ldc` are the
/// leading dimensions of A, B, and C respectively.
///
/// NOTE(review): `_m` (the row count of C / of op(A)) is unused — the matrix
/// descriptors are built with the leading dimension as the row count and with
/// `k`/`n` as column counts. That matches the NoTrans case, but for a
/// transposed operand the physical column count would normally be `_m`, not
/// `k`. Confirm against the `gemm_api::gemm` dimension convention (the m×m
/// over-allocation of Q in `randomized_svd` may exist to keep the oversized
/// Trans descriptor in bounds).
#[allow(clippy::too_many_arguments)]
fn gemm_multiply<T: GpuFloat>(
handle: &SolverHandle,
trans_a: Transpose,
trans_b: Transpose,
_m: u32,
n: u32,
k: u32,
a: &DeviceBuffer<T>,
lda: u32,
b: &DeviceBuffer<T>,
ldb: u32,
c: &mut DeviceBuffer<T>,
ldc: u32,
) -> SolverResult<()> {
// Descriptor arguments are presumably (ptr, rows, cols, leading dim, layout)
// — TODO confirm against `MatrixDesc::from_raw`'s signature.
let a_desc = MatrixDesc::<T>::from_raw(a.as_device_ptr(), lda, k, lda, Layout::ColMajor);
let b_desc = MatrixDesc::<T>::from_raw(b.as_device_ptr(), ldb, n, ldb, Layout::ColMajor);
let mut c_desc = MatrixDescMut::<T>::from_raw(c.as_device_ptr(), ldc, n, ldc, Layout::ColMajor);
// alpha = 1, beta = 0: C is overwritten with the plain product.
oxicuda_blas::level3::gemm_api::gemm(
handle.blas(),
trans_a,
trans_b,
T::gpu_one(),
&a_desc,
&b_desc,
T::gpu_zero(),
&mut c_desc,
)?;
Ok(())
}
/// Keeps at most `max_rank` leading singular values, then trims trailing
/// values that are negligible relative to the first (largest) one. At least
/// one value is always retained when the input is non-empty.
///
/// Magnitude comparisons are done on IEEE bit patterns with the sign bit
/// masked off, which orders correctly for the non-negative magnitudes
/// involved; the relative cutoff is ~1e-7 of the head value for f32-sized
/// types and ~1e-14 for f64-sized types.
fn truncate_to_rank<T: GpuFloat>(singular_values: &[T], max_rank: usize) -> Vec<T> {
    // Absolute-value bit pattern of `v`, shared by both comparison sides.
    let abs_bits = |v: T| -> u64 {
        if T::SIZE == 4 {
            u64::from((v.to_bits_u64() as u32) & 0x7FFF_FFFF)
        } else {
            v.to_bits_u64() & 0x7FFF_FFFF_FFFF_FFFF
        }
    };
    let mut kept: Vec<T> = singular_values.iter().take(max_rank).copied().collect();
    if let Some(&largest) = kept.first() {
        // Relative threshold derived from the head value in native precision.
        let threshold = if T::SIZE == 4 {
            let head = f32::from_bits(largest.to_bits_u64() as u32);
            T::from_bits_u64(u64::from((head * 1e-7).to_bits()))
        } else {
            let head = f64::from_bits(largest.to_bits_u64());
            T::from_bits_u64((head * 1e-14).to_bits())
        };
        let cutoff = abs_bits(threshold);
        // Pop negligible trailing values, but never empty the list.
        while kept.len() > 1 {
            match kept.last() {
                Some(&tail) if abs_bits(tail) <= cutoff => {
                    kept.pop();
                }
                _ => break,
            }
        }
    }
    kept
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_default() {
        let cfg = RandomizedSvdConfig::default();
        assert_eq!(cfg.rank, DEFAULT_RANK);
        assert_eq!(cfg.oversampling, DEFAULT_OVERSAMPLING);
        assert_eq!(cfg.power_iterations, DEFAULT_POWER_ITERATIONS);
        assert_eq!(cfg.seed, 42);
    }

    #[test]
    fn config_builder() {
        let cfg = RandomizedSvdConfig::with_rank(20)
            .oversampling(10)
            .power_iterations(2)
            .seed(123);
        assert_eq!(cfg.rank, 20);
        assert_eq!(cfg.oversampling, 10);
        assert_eq!(cfg.power_iterations, 2);
        assert_eq!(cfg.seed, 123);
    }

    #[test]
    fn config_sampling_dim() {
        let cfg = RandomizedSvdConfig::with_rank(15).oversampling(5);
        assert_eq!(cfg.sampling_dim(), 20);
    }

    #[test]
    fn truncate_to_rank_basic() {
        let sigma = vec![5.0_f64, 3.0, 1.0, 0.5, 0.001];
        let kept = truncate_to_rank(&sigma, 3);
        assert_eq!(kept.len(), 3);
        assert!((kept[0] - 5.0).abs() < 1e-10);
        assert!((kept[1] - 3.0).abs() < 1e-10);
        assert!((kept[2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn truncate_to_rank_removes_zeros() {
        let sigma = vec![5.0_f64, 3.0, 0.0, 0.0];
        let kept = truncate_to_rank(&sigma, 4);
        assert!(kept.len() <= 4);
        assert!(kept.len() >= 2);
    }

    #[test]
    fn truncate_to_rank_empty() {
        let kept = truncate_to_rank::<f64>(&[], 5);
        assert!(kept.is_empty());
    }

    #[test]
    fn truncate_to_rank_f32() {
        let sigma = vec![10.0_f32, 5.0, 2.0, 0.0];
        assert_eq!(truncate_to_rank(&sigma, 3).len(), 3);
    }

    #[test]
    fn truncate_to_rank_max_smaller() {
        let sigma = vec![10.0_f64, 5.0, 2.0, 1.0];
        assert_eq!(truncate_to_rank(&sigma, 2).len(), 2);
    }

    #[test]
    fn config_rng_engine_default() {
        assert!(matches!(
            RandomizedSvdConfig::default().rng_engine,
            RngEngine::Philox
        ));
    }

    /// Row-major CPU reference matmul: C (m×n) = A (m×k) · B (k×n).
    /// Uses `mul_add` so accumulation matches the original reference exactly.
    fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut c = vec![0.0_f32; m * n];
        for (idx, cell) in c.iter_mut().enumerate() {
            let row = idx / n;
            let col = idx % n;
            let mut acc = 0.0_f32;
            for ki in 0..k {
                acc = f32::mul_add(a[row * k + ki], b[ki * n + col], acc);
            }
            *cell = acc;
        }
        c
    }

    #[test]
    #[allow(clippy::type_complexity)]
    fn rsvd_gemm_multiply_signature_exists() {
        // A GEMM-shaped closure (alpha·A·B + beta·C) over the CPU reference.
        let _fn_ref: fn(usize, usize, usize, f32, &[f32], &[f32], f32, &[f32]) -> Vec<f32> =
            |m, k, n, alpha, a, b, beta, c| {
                cpu_matmul_f32(a, b, m, k, n)
                    .iter()
                    .zip(c.iter())
                    .map(|(&r, &c_val)| alpha * r + beta * c_val)
                    .collect()
            };
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![7.0_f32, 8.0, 9.0, 10.0, 11.0, 12.0];
        let c_init = vec![0.0_f32; 4];
        let result = _fn_ref(2, 3, 2, 1.0, &a, &b, 0.0, &c_init);
        assert!(
            (result[0] - 58.0).abs() < 1e-4,
            "GEMM C[0,0] expected 58, got {}",
            result[0]
        );
        assert!(
            (result[1] - 64.0).abs() < 1e-4,
            "GEMM C[0,1] expected 64, got {}",
            result[1]
        );
        assert!(
            (result[2] - 139.0).abs() < 1e-4,
            "GEMM C[1,0] expected 139, got {}",
            result[2]
        );
        assert!(
            (result[3] - 154.0).abs() < 1e-4,
            "GEMM C[1,1] expected 154, got {}",
            result[3]
        );
    }

    #[test]
    fn rsvd_gemm_sketch_throughput_proxy_256x128_rank16() {
        const ITERS: usize = 100;
        let m = 256_usize;
        let k = 128_usize;
        let r = 16_usize;
        // Deterministic pseudo-random inputs (irrational multiples folded
        // through fract into a bounded range).
        let a: Vec<f32> = (0..m * k)
            .map(|i| ((i as f32 * 1.618_034_f32).fract() - 0.5) * 2.0)
            .collect();
        let omega: Vec<f32> = (0..k * r)
            .map(|i| ((i as f32 * std::f32::consts::E).fract() - 0.5) * 0.5)
            .collect();
        let c_zero = vec![0.0_f32; m * r];
        // Warm-up pass before timing.
        let _ = cpu_matmul_f32(&a, &omega, m, k, r);
        let start = std::time::Instant::now();
        let mut sketch = vec![0.0_f32; m * r];
        for _ in 0..ITERS {
            sketch = cpu_matmul_f32(&a, &omega, m, k, r)
                .into_iter()
                .zip(c_zero.iter())
                .map(|(r_val, &c_val)| r_val + c_val)
                .collect();
        }
        let elapsed_ns = start.elapsed().as_nanos() as f64;
        let flops_per_gemm = 2.0 * m as f64 * k as f64 * r as f64;
        let gflops = (flops_per_gemm * ITERS as f64) / elapsed_ns;
        println!(
            "rSVD GEMM sketch proxy ({}×{} × {}×{}, {} iters): {:.3} GFLOPS (CPU reference)",
            m, k, k, r, ITERS, gflops
        );
        let sketch_norm: f32 = sketch.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            sketch_norm > 0.01,
            "Sketch must be non-zero, got norm={}",
            sketch_norm
        );
        assert!(
            gflops > 0.0001,
            "GEMM sketch throughput unrealistically low: {:.6} GFLOPS",
            gflops
        );
    }
}