#![allow(dead_code)]
use std::sync::Arc;
use oxicuda_blas::GpuFloat;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::ir::PtxType;
use oxicuda_ptx::prelude::*;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
/// Narrows an `f64` into `T` through its bit representation: values destined
/// for a 4-byte `T` are rounded to `f32` first, wider types keep full `f64` bits.
fn from_f64_to_t<T: GpuFloat>(val: f64) -> T {
    let bits = match T::SIZE {
        4 => u64::from((val as f32).to_bits()),
        _ => val.to_bits(),
    };
    T::from_bits_u64(bits)
}
/// Widens `T` to `f64` via its bit representation; an 8-byte `T` is
/// reinterpreted directly, anything else is decoded as an `f32` first.
fn t_to_f64<T: GpuFloat>(val: T) -> f64 {
    let bits = val.to_bits_u64();
    match T::SIZE {
        8 => f64::from_bits(bits),
        _ => f64::from(f32::from_bits(bits as u32)),
    }
}
/// Matrices with both dimensions at or below this size take the one-kernel
/// Jacobi path; anything larger goes through bidiagonalization + QR.
const JACOBI_SVD_THRESHOLD: u32 = 32;
/// Upper bound on Jacobi sweeps passed to the kernel as a launch argument.
const JACOBI_MAX_SWEEPS: u32 = 100;
/// Relative convergence tolerance used by the host-side bidiagonal QR
/// iteration (despite the name, it is not used by the Jacobi path here).
const JACOBI_TOL: f64 = 1e-14;
/// Maximum implicit-QR iterations before the bidiagonal SVD gives up.
const BIDIAG_QR_MAX_ITER: u32 = 200;
/// Selects which singular vectors to compute alongside the singular values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SvdJob {
/// Full factors: m x m `U` and n x n `V^T`.
All,
/// Economy factors: m x k `U` and k x n `V^T`, with k = min(m, n).
Thin,
/// Singular values only; `u` and `vt` are `None` in the result.
SingularValuesOnly,
}
/// Host-side output of [`svd`].
#[derive(Debug, Clone)]
pub struct SvdResult<T: GpuFloat> {
/// min(m, n) singular values (ordering is whatever the solver produced).
pub singular_values: Vec<T>,
/// Left singular vectors, column-major; `None` for `SvdJob::SingularValuesOnly`.
pub u: Option<Vec<T>>,
/// Transposed right singular vectors; `None` for `SvdJob::SingularValuesOnly`.
pub vt: Option<Vec<T>>,
/// LAPACK-style status code; 0 on success.
pub info: i32,
}
/// Computes the singular value decomposition of an m x n column-major matrix
/// stored on the device with leading dimension `lda`.
///
/// Small problems (both dimensions <= `JACOBI_SVD_THRESHOLD`) run a single
/// Jacobi kernel; larger ones are bidiagonalized on the device and finished
/// with a host-side QR iteration. The device buffer `a` serves as working
/// storage for the factorization kernels.
///
/// # Errors
/// `SolverError::DimensionMismatch` when `lda < m` or when `a` holds fewer
/// than `lda * n` elements.
pub fn svd<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    job: SvdJob,
) -> SolverResult<SvdResult<T>> {
    // Degenerate shapes: nothing to factor, return empty results.
    if m == 0 || n == 0 {
        let wants_vectors = job != SvdJob::SingularValuesOnly;
        return Ok(SvdResult {
            singular_values: Vec::new(),
            u: if wants_vectors { Some(Vec::new()) } else { None },
            vt: if wants_vectors { Some(Vec::new()) } else { None },
            info: 0,
        });
    }
    if lda < m {
        return Err(SolverError::DimensionMismatch(format!(
            "svd: lda ({lda}) must be >= m ({m})"
        )));
    }
    let required = n as usize * lda as usize;
    if a.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "svd: buffer too small ({} < {required})",
            a.len()
        )));
    }
    // Dispatch on problem size.
    let fits_jacobi = m <= JACOBI_SVD_THRESHOLD && n <= JACOBI_SVD_THRESHOLD;
    if fits_jacobi {
        jacobi_svd(handle, a, m, n, lda, job)
    } else {
        bidiag_svd(handle, a, m, n, lda, job)
    }
}
/// One-kernel Jacobi SVD for small matrices (m, n <= `JACOBI_SVD_THRESHOLD`).
///
/// Emits and launches a single-block kernel that runs up to
/// `JACOBI_MAX_SWEEPS` sweeps with A (m x n) and V (n x n) in shared memory,
/// then copies the result back to assemble singular values and, depending on
/// `job`, singular vectors on the host.
fn jacobi_svd<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    job: SvdJob,
) -> SolverResult<SvdResult<T>> {
    let k = m.min(n);
    // Workspace: an n x n V accumulator plus one extra element (flag slot).
    let v_size = n as usize * n as usize * T::SIZE;
    let ws_needed = v_size + T::SIZE;
    handle.ensure_workspace(ws_needed)?;
    let sm = handle.sm_version();
    let ptx = emit_jacobi_svd::<T>(sm, m, n)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &jacobi_svd_name::<T>(m, n))?;
    // Shared memory holds a copy of A (m x n) plus V (n x n).
    let shared_bytes = (m * n + n * n) * T::size_u32();
    let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);
    let args = (a.as_device_ptr(), lda, m, n, JACOBI_MAX_SWEEPS);
    // Fix: the source contained `¶ms` (mojibake of `&params`), which does
    // not compile.
    kernel.launch(&params, handle.stream(), &args)?;
    let singular_values = extract_singular_values::<T>(a, m, n, lda, k)?;
    let (u_out, vt_out) = match job {
        SvdJob::SingularValuesOnly => (None, None),
        SvdJob::Thin => {
            let u_vec = extract_u_thin::<T>(a, m, n, lda, k)?;
            // NOTE(review): V^T is a zero placeholder — nothing here reads V
            // back from the kernel. Confirm before relying on right vectors.
            let vt_vec = vec![T::gpu_zero(); k as usize * n as usize];
            (Some(u_vec), Some(vt_vec))
        }
        SvdJob::All => {
            let u_vec = extract_u_full::<T>(a, m, lda, k)?;
            // NOTE(review): zero placeholder, same caveat as the thin path.
            let vt_vec = vec![T::gpu_zero(); n as usize * n as usize];
            (Some(u_vec), Some(vt_vec))
        }
    };
    Ok(SvdResult {
        singular_values,
        u: u_out,
        vt: vt_out,
        info: 0,
    })
}
/// Copies the device matrix to the host and returns the Euclidean norm of
/// each of the first `k` columns (column-major, stride `lda`), accumulating
/// in `f64` regardless of `T`.
fn extract_singular_values<T: GpuFloat>(
    a: &DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    k: u32,
) -> SolverResult<Vec<T>> {
    let elems = lda as usize * n as usize;
    let mut host = vec![T::gpu_zero(); elems];
    a.copy_to_host(&mut host).map_err(|e| {
        SolverError::InternalError(format!("extract_singular_values copy_to_host failed: {e}"))
    })?;
    let rows = m as usize;
    let stride = lda as usize;
    let values = (0..k as usize)
        .map(|col| {
            let base = col * stride;
            let norm_sq: f64 = host[base..base + rows]
                .iter()
                .map(|&x| {
                    let v = t_to_f64(x);
                    v * v
                })
                .sum();
            from_f64_to_t(norm_sq.sqrt())
        })
        .collect();
    Ok(values)
}
/// Builds the thin m x k left factor by column-normalizing the first `k`
/// columns of the device matrix. Columns with norm <= 1e-300 are zeroed
/// instead of divided (guards against 0/0).
fn extract_u_thin<T: GpuFloat>(
    a: &DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    k: u32,
) -> SolverResult<Vec<T>> {
    let elems = lda as usize * n as usize;
    let mut host = vec![T::gpu_zero(); elems];
    a.copy_to_host(&mut host).map_err(|e| {
        SolverError::InternalError(format!("extract_u_thin copy_to_host failed: {e}"))
    })?;
    let rows = m as usize;
    let cols = k as usize;
    let stride = lda as usize;
    let mut u = vec![T::gpu_zero(); rows * cols];
    for (col, dst) in u.chunks_mut(rows).enumerate().take(cols) {
        let src = &host[col * stride..col * stride + rows];
        let norm = src
            .iter()
            .map(|&x| {
                let v = t_to_f64(x);
                v * v
            })
            .sum::<f64>()
            .sqrt();
        let scale = if norm > 1e-300 { 1.0 / norm } else { 0.0 };
        for (out, &x) in dst.iter_mut().zip(src) {
            *out = from_f64_to_t(t_to_f64(x) * scale);
        }
    }
    Ok(u)
}
/// Builds the full m x m left factor: the first `k` columns are the
/// normalized columns of the device matrix, the remaining m - k columns are
/// unit basis vectors.
/// NOTE(review): the padding columns are plain basis vectors, not an
/// orthonormal complement of the leading block.
fn extract_u_full<T: GpuFloat>(
    a: &DeviceBuffer<T>,
    m: u32,
    lda: u32,
    k: u32,
) -> SolverResult<Vec<T>> {
    // Only the leading k columns of the device matrix are read here.
    let elems = lda as usize * k as usize;
    let mut host = vec![T::gpu_zero(); elems];
    a.copy_to_host(&mut host).map_err(|e| {
        SolverError::InternalError(format!("extract_u_full copy_to_host failed: {e}"))
    })?;
    let rows = m as usize;
    let cols = k as usize;
    let stride = lda as usize;
    let mut u = vec![T::gpu_zero(); rows * rows];
    for col in 0..cols {
        let src = &host[col * stride..col * stride + rows];
        let norm = src
            .iter()
            .map(|&x| {
                let v = t_to_f64(x);
                v * v
            })
            .sum::<f64>()
            .sqrt();
        let scale = if norm > 1e-300 { 1.0 / norm } else { 0.0 };
        for (row, &x) in src.iter().enumerate() {
            u[col * rows + row] = from_f64_to_t(t_to_f64(x) * scale);
        }
    }
    // Pad trailing columns so U stays square.
    for col in cols..rows {
        u[col * rows + col] = T::gpu_one();
    }
    Ok(u)
}
/// SVD for matrices above the Jacobi threshold: Householder bidiagonalization
/// on the device, followed by implicit-shift QR iteration on the host.
fn bidiag_svd<T: GpuFloat>(
handle: &mut SolverHandle,
a: &mut DeviceBuffer<T>,
m: u32,
n: u32,
lda: u32,
job: SvdJob,
) -> SolverResult<SvdResult<T>> {
let k = m.min(n);
// Workspace budget: Householder scalars (tauq/taup) plus the host-mirrored
// f64 diagonal and superdiagonal of the bidiagonal factor.
let tauq_size = k as usize * T::SIZE;
let taup_size = k as usize * T::SIZE;
let diag_size = k as usize * std::mem::size_of::<f64>();
let super_diag_size = k.saturating_sub(1) as usize * std::mem::size_of::<f64>();
let ws_needed = tauq_size + taup_size + diag_size + super_diag_size;
handle.ensure_workspace(ws_needed)?;
let mut tauq = DeviceBuffer::<T>::zeroed(k as usize)?;
let mut taup = DeviceBuffer::<T>::zeroed(k as usize)?;
// Reduce A to bidiagonal form on the device, then pull the band to the host.
bidiagonalize(handle, a, m, n, lda, &mut tauq, &mut taup)?;
let mut d = vec![0.0_f64; k as usize];
let mut e = vec![0.0_f64; k.saturating_sub(1) as usize];
extract_bidiagonal::<T>(a, m, n, lda, &mut d, &mut e)?;
// k x k accumulators for the bidiagonal factor's singular vectors; they are
// identity-initialized inside bidiagonal_svd_qr. Skipped for values-only.
let mut u_bidiag = if job != SvdJob::SingularValuesOnly {
Some(vec![0.0_f64; k as usize * k as usize])
} else {
None
};
let mut vt_bidiag = if job != SvdJob::SingularValuesOnly {
Some(vec![0.0_f64; k as usize * k as usize])
} else {
None
};
let converged = bidiagonal_svd_qr(
&mut d,
&mut e,
u_bidiag.as_deref_mut(),
vt_bidiag.as_deref_mut(),
k,
)?;
if !converged {
// Report the leftover superdiagonal norm as the residual.
return Err(SolverError::ConvergenceFailure {
iterations: BIDIAG_QR_MAX_ITER,
residual: e.iter().map(|v| v * v).sum::<f64>().sqrt(),
});
}
// QR iteration may leave signed diagonal entries; singular values are
// their magnitudes.
let singular_values: Vec<T> = d.iter().map(|&val| from_f64_to_t(val.abs())).collect();
let (u_out, vt_out) = match job {
SvdJob::SingularValuesOnly => (None, None),
SvdJob::Thin => {
let u_vec =
reconstruct_u_thin::<T>(handle, a, m, n, lda, &tauq, u_bidiag.as_deref(), k)?;
let vt_vec =
reconstruct_vt_thin::<T>(handle, a, m, n, lda, &taup, vt_bidiag.as_deref(), k)?;
(Some(u_vec), Some(vt_vec))
}
SvdJob::All => {
let u_vec =
reconstruct_u_full::<T>(handle, a, m, n, lda, &tauq, u_bidiag.as_deref(), k)?;
let vt_vec =
reconstruct_vt_full::<T>(handle, a, m, n, lda, &taup, vt_bidiag.as_deref(), k)?;
(Some(u_vec), Some(vt_vec))
}
};
Ok(SvdResult {
singular_values,
u: u_out,
vt: vt_out,
info: 0,
})
}
/// Reduces `a` (m x n, column-major, leading dimension `lda`) toward
/// bidiagonal form with k = min(m, n) Householder steps, one kernel launch
/// per column. `tauq` / `taup` receive the left / right reflector scalars.
fn bidiagonalize<T: GpuFloat>(
    handle: &SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    tauq: &mut DeviceBuffer<T>,
    taup: &mut DeviceBuffer<T>,
) -> SolverResult<()> {
    let k = m.min(n);
    let sm = handle.sm_version();
    let ptx = emit_bidiag_step::<T>(sm)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &bidiag_step_name::<T>())?;
    for i in 0..k {
        // Trailing submatrix that step i operates on.
        let rows_below = m - i;
        let cols_right = n.saturating_sub(i + 1);
        // Shared memory: the column being reflected plus the row to its right.
        let shared_bytes = (rows_below + cols_right) * T::size_u32();
        let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);
        // Byte offset of A[i, i] in column-major layout; 64-bit math avoids
        // overflow for large lda.
        let a_offset = (i as u64 + i as u64 * lda as u64) * T::SIZE as u64;
        let tauq_offset = i as u64 * T::SIZE as u64;
        let taup_offset = i as u64 * T::SIZE as u64;
        let args = (
            a.as_device_ptr() + a_offset,
            tauq.as_device_ptr() + tauq_offset,
            taup.as_device_ptr() + taup_offset,
            rows_below,
            cols_right,
            lda,
        );
        // Fix: the source contained `¶ms` (mojibake of `&params`), which
        // does not compile.
        kernel.launch(&params, handle.stream(), &args)?;
    }
    Ok(())
}
/// Reads the bidiagonal band out of the factored device matrix into `f64`
/// host arrays: `d[i] = A[i, i]` and `e[i] = A[i, i + 1]` (column-major,
/// stride `lda`).
fn extract_bidiagonal<T: GpuFloat>(
    a: &DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    d: &mut [f64],
    e: &mut [f64],
) -> SolverResult<()> {
    let k = m.min(n) as usize;
    let stride = lda as usize;
    let mut host = vec![T::gpu_zero(); stride * n as usize];
    a.copy_to_host(&mut host).map_err(|e_err| {
        SolverError::InternalError(format!("extract_bidiagonal copy_to_host failed: {e_err}"))
    })?;
    // Diagonal: element (i, i) lives at column i, row i.
    for i in 0..k {
        d[i] = t_to_f64(host[i * stride + i]);
    }
    // Superdiagonal: element (i, i + 1) lives at column i + 1, row i.
    if k > 0 {
        for i in 0..k - 1 {
            e[i] = t_to_f64(host[(i + 1) * stride + i]);
        }
    }
    Ok(())
}
/// Implicit-shift QR iteration on a bidiagonal matrix (diagonal `d`,
/// superdiagonal `e`), driving the superdiagonal to zero.
///
/// Returns `Ok(true)` once every superdiagonal entry is negligible (relative
/// test against its diagonal neighbors) within `BIDIAG_QR_MAX_ITER`
/// iterations; `Ok(false)` otherwise. On success `d` holds the (unsorted,
/// possibly signed) singular values.
///
/// When provided, `u` / `vt` are reset to the k x k identity.
/// NOTE(review): the QR steps never accumulate rotations into `u`/`vt`, so
/// they remain identity matrices — confirm before using them as the
/// bidiagonal factor's singular vectors.
fn bidiagonal_svd_qr(
    d: &mut [f64],
    e: &mut [f64],
    u: Option<&mut [f64]>,
    vt: Option<&mut [f64]>,
    k: u32,
) -> SolverResult<bool> {
    let n = k as usize;
    if n == 0 {
        return Ok(true);
    }
    // (Removed two dead "probe" loops that only indexed u/vt diagonals and
    // discarded the values — the initialization below performs the same
    // indexing, so bounds behavior is unchanged.)
    if let Some(u_mat) = u {
        u_mat.fill(0.0);
        for i in 0..n {
            u_mat[i * n + i] = 1.0;
        }
    }
    if let Some(vt_mat) = vt {
        vt_mat.fill(0.0);
        for i in 0..n {
            vt_mat[i * n + i] = 1.0;
        }
    }
    let tol = JACOBI_TOL;
    for _iter in 0..BIDIAG_QR_MAX_ITER {
        // Deflate converged entries from the bottom of the matrix.
        let mut q = n.saturating_sub(1);
        while q > 0 && e[q - 1].abs() <= tol * (d[q - 1].abs() + d[q].abs()) {
            e[q - 1] = 0.0;
            q -= 1;
        }
        if q == 0 {
            return Ok(true);
        }
        // Find the start p of the unreduced trailing block [p, q].
        let mut p = q - 1;
        while p > 0 && e[p - 1].abs() > tol * (d[p - 1].abs() + d[p].abs()) {
            p -= 1;
        }
        bidiagonal_qr_step(d, e, p, q);
    }
    // Iteration budget exhausted: accept only if the remaining off-diagonal
    // mass is tiny (note this final check is absolute, not relative).
    let off_norm: f64 = e.iter().map(|v| v * v).sum::<f64>().sqrt();
    Ok(off_norm <= tol)
}
/// One implicit-shift QR sweep ("bulge chase") over the bidiagonal block
/// spanning indices `start..=end` of (`d`, `e`); requires `end >= start + 1`.
///
/// The shift `mu` is the eigenvalue of the trailing 2x2 block of B^T * B
/// closest to its bottom-right entry. Each loop iteration applies a right
/// Givens rotation followed by a left one, pushing the off-band bulge one
/// position down. Only `d` and `e` are modified — rotations are not
/// accumulated into any singular-vector matrices.
fn bidiagonal_qr_step(d: &mut [f64], e: &mut [f64], start: usize, end: usize) {
// Trailing 2x2 block of T = B^T B: [[t11, t12], [t12, t22]].
let dm1 = d[end - 1];
let dm = d[end];
let em1 = e[end - 1];
let t11 = dm1 * dm1
+ if end >= 2 {
e[end - 2] * e[end - 2]
} else {
0.0
};
let t12 = dm1 * em1;
let t22 = dm * dm + em1 * em1;
// Shift: eigenvalue of the 2x2 block nearest t22.
// NOTE(review): if `delta` and `t12` are both zero this evaluates 0/0 ->
// NaN; confirm callers never reach here with a zero trailing block.
let delta = (t11 - t22) * 0.5;
let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
let mu = t22 - t12 * t12 / (delta + sign_delta * (delta * delta + t12 * t12).sqrt());
// Seed the chase with the first column of T - mu*I.
let mut y = d[start] * d[start] - mu;
let mut z = d[start] * e[start];
for k in start..end {
// Right rotation: annihilates z against y (acts on columns k, k+1).
let (cs, sn) = givens_rotation(y, z);
if k > start {
e[k - 1] = cs * e[k - 1] + sn * z;
}
let tmp_d = cs * d[k] + sn * e[k];
e[k] = -sn * d[k] + cs * e[k];
d[k] = tmp_d;
let tmp_z = sn * d[k + 1];
d[k + 1] *= cs;
y = d[k];
z = tmp_z;
// Left rotation: restores bidiagonal shape (acts on rows k, k+1).
let (cs2, sn2) = givens_rotation(y, z);
d[k] = cs2 * d[k] + sn2 * tmp_z;
let tmp_e = cs2 * e[k] + sn2 * d[k + 1];
d[k + 1] = -sn2 * e[k] + cs2 * d[k + 1];
e[k] = tmp_e;
if k + 1 < end {
// Carry the bulge into the next column pair.
y = e[k];
z = sn2 * e[k + 1];
e[k + 1] *= cs2;
}
}
}
/// Returns (cs, sn) such that `-sn * a + cs * b == 0`, i.e. the Givens
/// rotation that zeroes `b` against `a`. Near-zero inputs (|x| < 1e-300)
/// short-circuit to axis-aligned rotations to avoid division blow-up.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    const TINY: f64 = 1e-300;
    if b.abs() < TINY {
        (1.0, 0.0)
    } else if a.abs() < TINY {
        let sn = if b >= 0.0 { 1.0 } else { -1.0 };
        (0.0, sn)
    } else {
        let norm = (a * a + b * b).sqrt();
        (a / norm, b / norm)
    }
}
/// Placeholder for forming the thin m x k left factor by applying the
/// Householder reflectors stored in `a`/`tauq` to `u_bidiag`.
/// TODO(review): currently returns an all-zero m x k matrix — singular
/// vectors from the bidiagonal path are not implemented yet.
#[allow(clippy::too_many_arguments)]
fn reconstruct_u_thin<T: GpuFloat>(
_handle: &SolverHandle,
_a: &DeviceBuffer<T>,
m: u32,
_n: u32,
_lda: u32,
_tauq: &DeviceBuffer<T>,
_u_bidiag: Option<&[f64]>,
k: u32,
) -> SolverResult<Vec<T>> {
Ok(vec![T::gpu_zero(); m as usize * k as usize])
}
/// Placeholder for forming the full m x m left factor from the Householder
/// reflectors in `a`/`tauq` and the accumulated `u_bidiag`.
/// TODO(review): currently returns an all-zero m x m matrix.
#[allow(clippy::too_many_arguments)]
fn reconstruct_u_full<T: GpuFloat>(
_handle: &SolverHandle,
_a: &DeviceBuffer<T>,
m: u32,
_n: u32,
_lda: u32,
_tauq: &DeviceBuffer<T>,
_u_bidiag: Option<&[f64]>,
_k: u32,
) -> SolverResult<Vec<T>> {
Ok(vec![T::gpu_zero(); m as usize * m as usize])
}
/// Placeholder for forming the thin k x n transposed right factor from the
/// reflectors in `a`/`taup` and the accumulated `vt_bidiag`.
/// TODO(review): currently returns an all-zero k x n matrix.
#[allow(clippy::too_many_arguments)]
fn reconstruct_vt_thin<T: GpuFloat>(
_handle: &SolverHandle,
_a: &DeviceBuffer<T>,
_m: u32,
n: u32,
_lda: u32,
_taup: &DeviceBuffer<T>,
_vt_bidiag: Option<&[f64]>,
k: u32,
) -> SolverResult<Vec<T>> {
Ok(vec![T::gpu_zero(); k as usize * n as usize])
}
/// Placeholder for forming the full n x n transposed right factor from the
/// reflectors in `a`/`taup` and the accumulated `vt_bidiag`.
/// TODO(review): currently returns an all-zero n x n matrix.
#[allow(clippy::too_many_arguments)]
fn reconstruct_vt_full<T: GpuFloat>(
_handle: &SolverHandle,
_a: &DeviceBuffer<T>,
_m: u32,
n: u32,
_lda: u32,
_taup: &DeviceBuffer<T>,
_vt_bidiag: Option<&[f64]>,
_k: u32,
) -> SolverResult<Vec<T>> {
Ok(vec![T::gpu_zero(); n as usize * n as usize])
}
/// Kernel entry-point name for the Jacobi SVD specialized to `T` and (m, n),
/// e.g. `solver_jacobi_svd_f32_16x16`.
fn jacobi_svd_name<T: GpuFloat>(m: u32, n: u32) -> String {
    let ty = T::NAME;
    format!("solver_jacobi_svd_{ty}_{m}x{n}")
}
/// Kernel entry-point name for one bidiagonalization step specialized to `T`,
/// e.g. `solver_bidiag_step_f64`.
fn bidiag_step_name<T: GpuFloat>() -> String {
    let ty = T::NAME;
    format!("solver_bidiag_step_{ty}")
}
/// Emits PTX for the Jacobi SVD kernel specialized to `T` and (m, n).
///
/// NOTE(review): the kernel body is currently a stub — it loads its
/// parameters and returns immediately without touching the matrix.
fn emit_jacobi_svd<T: GpuFloat>(sm: SmVersion, m: u32, n: u32) -> SolverResult<String> {
let name = jacobi_svd_name::<T>(m, n);
let float_ty = T::PTX_TYPE;
let ptx = KernelBuilder::new(&name)
.target(sm)
.max_threads_per_block(SOLVER_BLOCK_SIZE)
.param("a_ptr", PtxType::U64)
.param("lda", PtxType::U32)
.param("m", PtxType::U32)
.param("n", PtxType::U32)
.param("max_sweeps", PtxType::U32)
.body(move |b| {
let tid = b.thread_id_x();
let m_reg = b.load_param_u32("m");
let n_reg = b.load_param_u32("n");
let lda_reg = b.load_param_u32("lda");
let a_ptr = b.load_param_u64("a_ptr");
// Placeholder: silence unused bindings until the sweep logic lands.
let _ = (tid, m_reg, n_reg, lda_reg, a_ptr, float_ty);
b.ret();
})
.build()?;
Ok(ptx)
}
/// Emits PTX for one Householder bidiagonalization step specialized to `T`.
///
/// NOTE(review): like `emit_jacobi_svd`, the body is currently a stub that
/// loads its parameters and returns without performing the reflection.
fn emit_bidiag_step<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
let name = bidiag_step_name::<T>();
let float_ty = T::PTX_TYPE;
let ptx = KernelBuilder::new(&name)
.target(sm)
.max_threads_per_block(SOLVER_BLOCK_SIZE)
.param("a_ptr", PtxType::U64)
.param("tauq_ptr", PtxType::U64)
.param("taup_ptr", PtxType::U64)
.param("rows_below", PtxType::U32)
.param("cols_right", PtxType::U32)
.param("lda", PtxType::U32)
.body(move |b| {
let tid = b.thread_id_x();
let rows_below = b.load_param_u32("rows_below");
let cols_right = b.load_param_u32("cols_right");
let lda = b.load_param_u32("lda");
// Placeholder: silence unused bindings until the reflector lands.
let _ = (tid, rows_below, cols_right, lda, float_ty);
b.ret();
})
.build()?;
Ok(ptx)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn svd_job_equality() {
        assert_eq!(SvdJob::All, SvdJob::All);
        assert_ne!(SvdJob::All, SvdJob::Thin);
        assert_ne!(SvdJob::Thin, SvdJob::SingularValuesOnly);
    }

    #[test]
    fn svd_result_construction() {
        // Values-only result: vectors absent, info clean.
        let res = SvdResult::<f64> {
            singular_values: vec![3.0, 2.0, 1.0],
            u: None,
            vt: None,
            info: 0,
        };
        assert_eq!(res.singular_values.len(), 3);
        assert_eq!(res.info, 0);
    }

    #[test]
    fn svd_result_with_vectors() {
        let res = SvdResult::<f32> {
            singular_values: vec![5.0, 3.0],
            u: Some(vec![1.0; 6]),
            vt: Some(vec![1.0; 6]),
            info: 0,
        };
        assert!(res.u.is_some());
        assert!(res.vt.is_some());
    }

    #[test]
    fn givens_rotation_basic() {
        // Rotating (3, 4) should land on (5, 0).
        let (cs, sn) = givens_rotation(3.0, 4.0);
        assert!((cs * 3.0 + sn * 4.0 - 5.0).abs() < 1e-10);
        assert!((-sn * 3.0 + cs * 4.0).abs() < 1e-10);
    }

    #[test]
    fn givens_rotation_zero_b() {
        assert!((givens_rotation(5.0, 0.0).0 - 1.0).abs() < 1e-15);
        assert!(givens_rotation(5.0, 0.0).1.abs() < 1e-15);
    }

    #[test]
    fn givens_rotation_zero_a() {
        let (cs, sn) = givens_rotation(0.0, 3.0);
        assert!(cs.abs() < 1e-15);
        assert!((sn - 1.0).abs() < 1e-15);
    }

    #[test]
    fn jacobi_svd_name_format() {
        let kernel_name = jacobi_svd_name::<f32>(16, 16);
        assert!(kernel_name.contains("f32") && kernel_name.contains("16x16"));
    }

    #[test]
    fn bidiag_step_name_format() {
        assert!(bidiag_step_name::<f64>().contains("f64"));
    }

    #[test]
    fn bidiagonal_svd_qr_trivial() {
        // Already diagonal: converges immediately.
        let mut d = vec![3.0, 2.0, 1.0];
        let mut e = vec![0.0, 0.0];
        assert!(matches!(
            bidiagonal_svd_qr(&mut d, &mut e, None, None, 3),
            Ok(true)
        ));
    }

    #[test]
    fn bidiagonal_svd_qr_with_superdiag() {
        let mut d = vec![4.0, 3.0];
        let mut e = vec![1.0];
        let mut u = vec![0.0; 4];
        let mut vt = vec![0.0; 4];
        assert!(bidiagonal_svd_qr(&mut d, &mut e, Some(&mut u), Some(&mut vt), 2).is_ok());
    }

    #[test]
    fn bidiagonal_svd_qr_empty() {
        // Zero-size problem is vacuously converged.
        let mut d: Vec<f64> = Vec::new();
        let mut e: Vec<f64> = Vec::new();
        assert!(matches!(
            bidiagonal_svd_qr(&mut d, &mut e, None, None, 0),
            Ok(true)
        ));
    }

    #[test]
    fn jacobi_threshold() {
        assert!(JACOBI_SVD_THRESHOLD > 0);
        assert!(JACOBI_SVD_THRESHOLD <= 64);
    }

    #[test]
    fn svd_backward_error_2x2() {
        let sigma = [3.0_f64, 2.0];
        assert!(
            sigma[0] >= sigma[1],
            "singular values must be in descending order"
        );
        let a_recon = [[sigma[0], 0.0], [0.0, sigma[1]]];
        let a_orig = [[3.0_f64, 0.0], [0.0, 2.0_f64]];
        // Frobenius norm of the reconstruction error.
        let err = a_recon
            .iter()
            .zip(&a_orig)
            .flat_map(|(rr, ro)| rr.iter().zip(ro).map(|(x, y)| (x - y) * (x - y)))
            .sum::<f64>()
            .sqrt();
        assert!(err < 1e-14, "SVD backward error {err} must be < 1e-14");
    }
}