use oxicuda_blas::types::{GpuFloat, Layout, MatrixDesc, MatrixDescMut, Transpose};
use oxicuda_memory::DeviceBuffer;
use oxicuda_rand::{RngEngine, RngGenerator};
use crate::dense::qr;
use crate::dense::svd;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Target rank used by [`RandomizedSvdConfig::default`].
const DEFAULT_RANK: usize = 10;
/// Extra random samples drawn beyond the target rank by default.
const DEFAULT_OVERSAMPLING: usize = 5;
/// Number of power-iteration rounds performed by default.
const DEFAULT_POWER_ITERATIONS: usize = 1;
/// Tuning parameters for [`randomized_svd`].
#[derive(Debug, Clone)]
pub struct RandomizedSvdConfig {
    /// Requested number of singular values. The result may hold fewer: the sketch width is
    /// clamped to `min(m, n)` and negligible trailing singular values are dropped.
    pub rank: usize,
    /// Extra random samples drawn beyond `rank` when building the sketch.
    pub oversampling: usize,
    /// Number of power-iteration rounds applied to the sketch before projection.
    pub power_iterations: usize,
    /// RNG engine used to draw the Gaussian test matrix.
    pub rng_engine: RngEngine,
    /// Seed handed to the RNG when generating the Gaussian test matrix.
    pub seed: u64,
}
impl Default for RandomizedSvdConfig {
    /// Uses `DEFAULT_RANK`, `DEFAULT_OVERSAMPLING` and `DEFAULT_POWER_ITERATIONS`, the
    /// Philox engine, and seed 42.
    fn default() -> Self {
        const DEFAULT_SEED: u64 = 42;
        Self {
            seed: DEFAULT_SEED,
            rng_engine: RngEngine::Philox,
            power_iterations: DEFAULT_POWER_ITERATIONS,
            oversampling: DEFAULT_OVERSAMPLING,
            rank: DEFAULT_RANK,
        }
    }
}
impl RandomizedSvdConfig {
pub fn with_rank(rank: usize) -> Self {
Self {
rank,
..Self::default()
}
}
pub fn oversampling(mut self, p: usize) -> Self {
self.oversampling = p;
self
}
pub fn power_iterations(mut self, q: usize) -> Self {
self.power_iterations = q;
self
}
pub fn seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
pub fn sampling_dim(&self) -> usize {
self.rank + self.oversampling
}
}
/// Output of [`randomized_svd`], shaped as `A ≈ U · diag(sigma) · Vt` truncated to `rank`.
pub struct RandomizedSvdResult<T: GpuFloat> {
    /// Device buffer of `m * rank` elements intended for the left singular vectors.
    /// NOTE(review): `randomized_svd` currently fills this with zeros.
    pub u: DeviceBuffer<T>,
    /// Host copy of the retained singular values (`rank` entries).
    pub sigma: Vec<T>,
    /// Device buffer of `rank * n` elements intended for the transposed right singular
    /// vectors. NOTE(review): `randomized_svd` currently fills this with zeros.
    pub vt: DeviceBuffer<T>,
    /// Number of retained singular values; equals `sigma.len()`.
    pub rank: usize,
}
impl<T: GpuFloat> std::fmt::Debug for RandomizedSvdResult<T> {
    /// Prints `sigma`, `rank` and the lengths of the device buffers. Device memory contents
    /// are not read.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RandomizedSvdResult");
        out.field("sigma", &self.sigma);
        out.field("rank", &self.rank);
        out.field("u_len", &self.u.len());
        out.field("vt_len", &self.vt.len());
        out.finish()
    }
}
/// Computes an approximate truncated SVD `A ≈ U · diag(sigma) · Vt` of the `m x n`
/// column-major device matrix `a` with a randomized range finder.
///
/// The algorithm runs in four steps:
/// - It draws an `n x l` Gaussian test matrix `Ω`, where `l = rank + oversampling` is
///   clamped to `min(m, n)`, and forms `Y = A · Ω`.
/// - It applies `power_iterations` rounds of `Y ← A · qr(Aᵀ · Y)`.
/// - It factors `Y` and projects `B = Yᵀ · A`, an `l x n` matrix.
/// - It takes the thin SVD of `B`.
///
/// The singular values are truncated to `min(rank, l)` and negligible trailing values are
/// dropped.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] in three cases:
/// - `m` or `n` is zero.
/// - `a` holds fewer than `m * n` elements.
/// - `rank + oversampling` is zero.
///
/// Errors from device allocation, RNG, GEMM, QR and the small SVD are propagated.
///
/// # Known limitations
///
/// NOTE(review): the output of `qr::qr_factorize` (which also produces `tau`) is used
/// directly as the basis. If it leaves `geqrf`-style packed Householder reflectors rather
/// than an explicit orthonormal `Q`, then `B = Yᵀ · A` is not a true projection. Confirm
/// this and form `Q` explicitly.
///
/// NOTE(review): the returned `u` and `vt` are zero-filled buffers of the documented sizes
/// (`m * rank` and `rank * n`). The small-SVD factors are not yet lifted back
/// (`U = Q · U_hat`, leading rows of `Vt_hat`). Only `sigma` and `rank` hold computed data.
pub fn randomized_svd<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &DeviceBuffer<T>,
    m: u32,
    n: u32,
    config: &RandomizedSvdConfig,
) -> SolverResult<RandomizedSvdResult<T>> {
    if m == 0 || n == 0 {
        return Err(SolverError::DimensionMismatch(
            "randomized_svd: matrix dimensions must be positive".into(),
        ));
    }
    let required = m as usize * n as usize;
    if a.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "randomized_svd: buffer too small ({} < {required})",
            a.len()
        )));
    }

    let k = config.rank;
    // Saturate so oversized settings get clamped below instead of overflowing.
    let l = k.saturating_add(config.oversampling);
    if l == 0 {
        return Err(SolverError::DimensionMismatch(
            "randomized_svd: rank + oversampling must be positive".into(),
        ));
    }
    // The sketch cannot be wider than min(m, n), and the rank cannot exceed the sketch.
    let l = l.min(m.min(n) as usize);
    let effective_rank = k.min(l);
    // Lossless: `l <= min(m, n)`, which is a `u32`.
    let l32 = l as u32;

    // Range finder: Y = A · Ω.
    let omega = generate_gaussian_matrix::<T>(handle, n as usize, l, config)?;
    let mut y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
    gemm_multiply::<T>(
        handle,
        Transpose::NoTrans,
        Transpose::NoTrans,
        m,
        l32,
        n,
        a,
        m,
        &omega,
        n,
        &mut y,
        m,
    )?;

    // Power iterations: each round applies A · Aᵀ to sharpen the spectral decay.
    for _ in 0..config.power_iterations {
        let mut y_hat = DeviceBuffer::<T>::zeroed(n as usize * l)?;
        gemm_multiply::<T>(
            handle,
            Transpose::Trans,
            Transpose::NoTrans,
            n,
            l32,
            m,
            a,
            m,
            &y,
            m,
            &mut y_hat,
            n,
        )?;
        let mut tau_hat = DeviceBuffer::<T>::zeroed(l)?;
        qr::qr_factorize(handle, &mut y_hat, n, l32, n, &mut tau_hat)?;
        y = DeviceBuffer::<T>::zeroed(m as usize * l)?;
        gemm_multiply::<T>(
            handle,
            Transpose::NoTrans,
            Transpose::NoTrans,
            m,
            l32,
            n,
            a,
            m,
            &y_hat,
            n,
            &mut y,
            m,
        )?;
    }

    let mut tau = DeviceBuffer::<T>::zeroed(l)?;
    qr::qr_factorize(handle, &mut y, m, l32, m, &mut tau)?;

    // Projection: B = Yᵀ · A, an l x n matrix.
    let mut b_matrix = DeviceBuffer::<T>::zeroed(l * n as usize)?;
    gemm_multiply::<T>(
        handle,
        Transpose::Trans,
        Transpose::NoTrans,
        l32,
        n,
        m,
        &y,
        m,
        a,
        m,
        &mut b_matrix,
        l32,
    )?;

    let svd_result = svd::svd(handle, &mut b_matrix, l32, n, l32, svd::SvdJob::Thin)?;
    let sigma = truncate_to_rank(&svd_result.singular_values, effective_rank);
    let actual_rank = sigma.len();

    // NOTE(review): zero placeholders; see "Known limitations" in the doc comment above.
    let u = DeviceBuffer::<T>::zeroed(m as usize * actual_rank)?;
    let vt = DeviceBuffer::<T>::zeroed(actual_rank * n as usize)?;
    Ok(RandomizedSvdResult {
        u,
        sigma,
        vt,
        rank: actual_rank,
    })
}
/// Allocates a `rows * cols` device buffer filled with standard normal samples
/// (mean 0, standard deviation 1). The samples come from `config.rng_engine`, seeded with
/// `config.seed`.
///
/// # Errors
///
/// - Returns [`SolverError::InternalError`] when `T` is neither `f32` nor `f64`, the only
///   element types the RNG can fill.
/// - Propagates device allocation failures.
/// - Maps RNG creation and generation failures to [`SolverError::InternalError`].
fn generate_gaussian_matrix<T: GpuFloat>(
    handle: &SolverHandle,
    rows: usize,
    cols: usize,
    config: &RandomizedSvdConfig,
) -> SolverResult<DeviceBuffer<T>> {
    // The RNG only writes concrete `f32`/`f64` buffers. Identify `T` up front so the
    // samples can be generated directly into the buffer that is returned.
    let element = std::any::type_name::<T>();
    let is_f32 = element == std::any::type_name::<f32>();
    let is_f64 = element == std::any::type_name::<f64>();
    if !is_f32 && !is_f64 {
        return Err(SolverError::InternalError(format!(
            "randomized_svd: Gaussian sampling is not supported for element type {element}"
        )));
    }

    let total = rows * cols;
    let mut buffer = DeviceBuffer::<T>::zeroed(total)?;
    let mut rng = RngGenerator::new(config.rng_engine, config.seed, handle.context())
        .map_err(|e| SolverError::InternalError(format!("RNG creation failed: {e}")))?;

    if is_f32 {
        // SAFETY: `T`'s type name is the primitive name "f32" (user-defined types are
        // path-qualified), so `T` is `f32` and `DeviceBuffer<T>` *is* `DeviceBuffer<f32>`.
        // This is an identity cast of a unique borrow that ends before `buffer` is moved.
        let samples: &mut DeviceBuffer<f32> =
            unsafe { &mut *(&mut buffer as *mut DeviceBuffer<T>).cast::<DeviceBuffer<f32>>() };
        rng.generate_normal_f32(samples, 0.0, 1.0)
            .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;
    } else {
        // SAFETY: as above, the type-name check established that `T` is `f64`.
        let samples: &mut DeviceBuffer<f64> =
            unsafe { &mut *(&mut buffer as *mut DeviceBuffer<T>).cast::<DeviceBuffer<f64>>() };
        rng.generate_normal_f64(samples, 0.0, 1.0)
            .map_err(|e| SolverError::InternalError(format!("RNG generation failed: {e}")))?;
    }
    Ok(buffer)
}
#[allow(clippy::too_many_arguments)]
fn gemm_multiply<T: GpuFloat>(
handle: &SolverHandle,
trans_a: Transpose,
trans_b: Transpose,
_m: u32,
n: u32,
k: u32,
a: &DeviceBuffer<T>,
lda: u32,
b: &DeviceBuffer<T>,
ldb: u32,
c: &mut DeviceBuffer<T>,
ldc: u32,
) -> SolverResult<()> {
let a_desc = MatrixDesc::<T>::from_raw(a.as_device_ptr(), lda, k, lda, Layout::ColMajor);
let b_desc = MatrixDesc::<T>::from_raw(b.as_device_ptr(), ldb, n, ldb, Layout::ColMajor);
let mut c_desc = MatrixDescMut::<T>::from_raw(c.as_device_ptr(), ldc, n, ldc, Layout::ColMajor);
oxicuda_blas::level3::gemm_api::gemm(
handle.blas(),
trans_a,
trans_b,
T::gpu_one(),
&a_desc,
&b_desc,
T::gpu_zero(),
&mut c_desc,
)?;
Ok(())
}
/// Keeps at most `max_rank` leading singular values, then drops trailing entries whose
/// magnitude is at or below a relative cutoff.
///
/// The cutoff is `first * 1e-7` for 4-byte element types. Every other size is treated as
/// `f64`, with a cutoff of `first * 1e-14`. A non-empty input always keeps at least one
/// value.
///
/// Magnitudes are compared as IEEE-754 bit patterns with the sign bit cleared. For non-NaN
/// values this orders by absolute value without needing arithmetic on `T`. Only the tail
/// is trimmed, so this assumes `singular_values` is sorted in descending order — TODO
/// confirm against `svd::svd`.
fn truncate_to_rank<T: GpuFloat>(singular_values: &[T], max_rank: usize) -> Vec<T> {
    let mut kept: Vec<T> = singular_values.iter().copied().take(max_rank).collect();
    let Some(&leading) = kept.first() else {
        return kept;
    };

    let single = T::SIZE == 4;
    let magnitude_bits = |x: T| -> u64 {
        if single {
            u64::from((x.to_bits_u64() as u32) & 0x7FFF_FFFF)
        } else {
            x.to_bits_u64() & 0x7FFF_FFFF_FFFF_FFFF
        }
    };

    let cutoff_bits = if single {
        let scaled = f32::from_bits(leading.to_bits_u64() as u32) * 1e-7;
        u64::from(scaled.to_bits())
    } else {
        (f64::from_bits(leading.to_bits_u64()) * 1e-14).to_bits()
    };
    let cutoff = magnitude_bits(T::from_bits_u64(cutoff_bits));

    while kept.len() > 1 && matches!(kept.last(), Some(&tail) if magnitude_bits(tail) <= cutoff) {
        kept.pop();
    }
    kept
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_default() {
        let config = RandomizedSvdConfig::default();
        assert_eq!(config.rank, DEFAULT_RANK);
        assert_eq!(config.oversampling, DEFAULT_OVERSAMPLING);
        assert_eq!(config.power_iterations, DEFAULT_POWER_ITERATIONS);
        assert_eq!(config.seed, 42);
    }

    #[test]
    fn config_builder() {
        let config = RandomizedSvdConfig::with_rank(20)
            .oversampling(10)
            .power_iterations(2)
            .seed(123);
        assert_eq!(config.rank, 20);
        assert_eq!(config.oversampling, 10);
        assert_eq!(config.power_iterations, 2);
        assert_eq!(config.seed, 123);
    }

    #[test]
    fn config_sampling_dim() {
        let config = RandomizedSvdConfig::with_rank(15).oversampling(5);
        assert_eq!(config.sampling_dim(), 20);
    }

    #[test]
    fn truncate_to_rank_basic() {
        let sigma: Vec<f64> = vec![5.0, 3.0, 1.0, 0.5, 0.001];
        // Values are copied bit-for-bit, so exact comparison is appropriate.
        assert_eq!(truncate_to_rank(&sigma, 3), vec![5.0, 3.0, 1.0]);
    }

    #[test]
    fn truncate_to_rank_removes_zeros() {
        let sigma: Vec<f64> = vec![5.0, 3.0, 0.0, 0.0];
        // Trailing values at or below 5.0 * 1e-14 are dropped.
        assert_eq!(truncate_to_rank(&sigma, 4), vec![5.0, 3.0]);
    }

    #[test]
    fn truncate_to_rank_empty() {
        let sigma: Vec<f64> = Vec::new();
        assert!(truncate_to_rank(&sigma, 5).is_empty());
    }

    #[test]
    fn truncate_to_rank_zero_max_rank() {
        let sigma: Vec<f64> = vec![3.0, 2.0];
        assert!(truncate_to_rank(&sigma, 0).is_empty());
    }

    #[test]
    fn truncate_to_rank_f32() {
        let sigma: Vec<f32> = vec![10.0, 5.0, 2.0, 0.0];
        assert_eq!(truncate_to_rank(&sigma, 3), vec![10.0, 5.0, 2.0]);
    }

    #[test]
    fn truncate_to_rank_f32_relative_cutoff() {
        let sigma: Vec<f32> = vec![1.0, 1e-6, 1e-8];
        assert_eq!(truncate_to_rank(&sigma, 3), vec![1.0, 1e-6]);
    }

    #[test]
    fn truncate_to_rank_f64_relative_cutoff() {
        let sigma: Vec<f64> = vec![1.0, 1e-10, 1e-15];
        assert_eq!(truncate_to_rank(&sigma, 3), vec![1.0, 1e-10]);
    }

    #[test]
    fn truncate_to_rank_keeps_leading_value() {
        let sigma: Vec<f64> = vec![0.0, 0.0, 0.0];
        assert_eq!(truncate_to_rank(&sigma, 3), vec![0.0]);
    }

    #[test]
    fn truncate_to_rank_max_smaller() {
        let sigma: Vec<f64> = vec![10.0, 5.0, 2.0, 1.0];
        assert_eq!(truncate_to_rank(&sigma, 2), vec![10.0, 5.0]);
    }

    #[test]
    fn config_rng_engine_default() {
        let config = RandomizedSvdConfig::default();
        assert!(matches!(config.rng_engine, RngEngine::Philox));
    }
}