#![allow(dead_code)]
use oxicuda_blas::GpuFloat;
use oxicuda_memory::DeviceBuffer;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Widen a `GpuFloat` value to `f64` by reinterpreting its raw bits.
///
/// A `T::SIZE` of 4 is treated as an `f32` payload; anything else is
/// assumed to carry `f64` bits directly.
fn to_f64<T: GpuFloat>(val: T) -> f64 {
    let bits = val.to_bits_u64();
    match T::SIZE {
        4 => f64::from(f32::from_bits(bits as u32)),
        _ => f64::from_bits(bits),
    }
}
/// Narrow an `f64` back into `T`'s bit representation.
///
/// For a 4-byte `T` the value is first rounded through `f32`.
fn from_f64<T: GpuFloat>(val: f64) -> T {
    match T::SIZE {
        4 => T::from_bits_u64(u64::from((val as f32).to_bits())),
        _ => T::from_bits_u64(val.to_bits()),
    }
}
/// Problem size at or below which the recursion falls back to plain bidiagonal QR.
const DEFAULT_CROSSOVER: usize = 25;
/// Iteration cap for the secular-equation Newton/bisection solver.
const SECULAR_MAX_ITER: usize = 80;
/// Convergence tolerance for the secular solver (near f64 machine-epsilon scale).
const SECULAR_TOL: f64 = 1e-14;
/// Iteration cap for the bidiagonal QR sweep loop.
const BIDIAG_QR_MAX_ITER: usize = 200;
/// Tuning knobs for the divide-and-conquer SVD path.
#[derive(Debug, Clone)]
pub struct DcSvdConfig {
    /// Problem size at or below which plain bidiagonal QR is used instead of recursion.
    pub crossover_size: usize,
    /// Whether the left singular vectors (U) are computed.
    pub compute_u: bool,
    /// Whether the right singular vectors (V^T) are computed.
    pub compute_vt: bool,
    /// Whether the divide-and-conquer strategy is enabled at all.
    pub use_divide_conquer: bool,
    /// Whether a bidiagonalization pass is requested
    /// (presumably selects a device-side kernel — TODO confirm consumer).
    pub bidiagonalization: bool,
    /// Tolerance below which nearly-equal singular values may be deflated.
    pub deflation_tol: f64,
    /// Size threshold used by selection heuristics (see `for_gpu`).
    pub n_threshold: usize,
}
impl Default for DcSvdConfig {
    /// Conservative defaults: compute both singular-vector sets, keep the
    /// divide-and-conquer and bidiagonalization paths disabled, no deflation.
    fn default() -> Self {
        Self {
            crossover_size: DEFAULT_CROSSOVER,
            compute_u: true,
            compute_vt: true,
            use_divide_conquer: false,
            bidiagonalization: false,
            deflation_tol: 0.0,
            n_threshold: 1024,
        }
    }
}
impl DcSvdConfig {
#[must_use]
pub fn for_gpu(n: usize) -> Self {
Self {
crossover_size: DEFAULT_CROSSOVER,
compute_u: true,
compute_vt: true,
use_divide_conquer: n >= 1024,
bidiagonalization: n >= 256,
deflation_tol: n as f64 * 2.22e-16,
n_threshold: 1024,
}
}
}
/// Divide-and-conquer SVD driver: factor the `m x n` matrix in `a` into
/// singular values (`sigma`) and, optionally, left/right singular vectors
/// (`u` / `vt`).
///
/// Pipeline: validate buffer sizes -> reserve workspace -> extract a
/// bidiagonal form (currently stubbed) -> run the D&C bidiagonal SVD in
/// f64 on the host -> sort values descending -> upload results.
///
/// # Errors
/// Returns `SolverError::DimensionMismatch` when `a`, `sigma`, `u`, or
/// `vt` is smaller than the problem requires; workspace and solver errors
/// are propagated via `?`.
#[allow(clippy::too_many_arguments)]
pub fn dc_svd<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: usize,
    n: usize,
    sigma: &mut DeviceBuffer<T>,
    u: Option<&mut DeviceBuffer<T>>,
    vt: Option<&mut DeviceBuffer<T>>,
    config: &DcSvdConfig,
) -> SolverResult<()> {
    // Degenerate problem: nothing to factor, outputs are left untouched.
    if m == 0 || n == 0 {
        return Ok(());
    }
    // Number of singular values.
    let k = m.min(n);
    if a.len() < m * n {
        return Err(SolverError::DimensionMismatch(format!(
            "dc_svd: matrix buffer too small ({} < {})",
            a.len(),
            m * n
        )));
    }
    if sigma.len() < k {
        return Err(SolverError::DimensionMismatch(format!(
            "dc_svd: sigma buffer too small ({} < {k})",
            sigma.len()
        )));
    }
    if let Some(ref u_buf) = u {
        if u_buf.len() < m * k {
            return Err(SolverError::DimensionMismatch(format!(
                "dc_svd: U buffer too small ({} < {})",
                u_buf.len(),
                m * k
            )));
        }
    }
    if let Some(ref vt_buf) = vt {
        if vt_buf.len() < k * n {
            return Err(SolverError::DimensionMismatch(format!(
                "dc_svd: V^T buffer too small ({} < {})",
                vt_buf.len(),
                k * n
            )));
        }
    }
    // Workspace sized for a k x k dense block plus four k-length f64 vectors.
    let ws_needed = (k * k + 4 * k) * std::mem::size_of::<f64>();
    handle.ensure_workspace(ws_needed)?;
    // Host-side bidiagonal form: diagonal d (k entries), superdiagonal e (k-1).
    let mut d = vec![0.0_f64; k];
    let mut e = vec![0.0_f64; k.saturating_sub(1)];
    bidiagonalize_extract(a, m, n, &mut d, &mut e)?;
    // Accumulators for the bidiagonal factors, allocated only when requested.
    let mut u_dc = if config.compute_u {
        Some(vec![0.0_f64; k * k])
    } else {
        None
    };
    let mut vt_dc = if config.compute_vt {
        Some(vec![0.0_f64; k * k])
    } else {
        None
    };
    dc_bidiagonal_svd(
        &mut d,
        &mut e,
        u_dc.as_deref_mut(),
        vt_dc.as_deref_mut(),
        k,
        config.crossover_size,
    )?;
    sort_singular_values_desc(&mut d, u_dc.as_deref_mut(), vt_dc.as_deref_mut(), k);
    // Singular values are reported as magnitudes.
    let sigma_host: Vec<T> = d.iter().map(|&val| from_f64(val.abs())).collect();
    write_to_device_buffer(sigma, &sigma_host, k)?;
    if let Some(u_buf) = u {
        if config.compute_u {
            // NOTE(review): u_dc holds k*k entries but the upload below
            // requests m*k; for m > k the host vector is shorter than
            // `count` — verify against write_to_device_buffer's (currently
            // stubbed) contract.
            let u_host: Vec<T> = if let Some(ref u_mat) = u_dc {
                u_mat.iter().map(|&v| from_f64(v)).collect()
            } else {
                vec![T::gpu_zero(); m * k]
            };
            write_to_device_buffer(u_buf, &u_host, m * k)?;
        }
    }
    if let Some(vt_buf) = vt {
        if config.compute_vt {
            // NOTE(review): same size question as for U — vt_dc is k*k while
            // k*n elements are requested when n > k.
            let vt_host: Vec<T> = if let Some(ref vt_mat) = vt_dc {
                vt_mat.iter().map(|&v| from_f64(v)).collect()
            } else {
                vec![T::gpu_zero(); k * n]
            };
            write_to_device_buffer(vt_buf, &vt_host, k * n)?;
        }
    }
    Ok(())
}
/// Placeholder bidiagonalization: fills `d` with ones and `e` with zeros.
///
/// NOTE(review): a real implementation would run Householder
/// bidiagonalization on `a` (`m x n`) and extract the diagonal and
/// superdiagonal; the matrix arguments are currently unused.
fn bidiagonalize_extract<T: GpuFloat>(
    _a: &DeviceBuffer<T>,
    _m: usize,
    _n: usize,
    d: &mut [f64],
    e: &mut [f64],
) -> SolverResult<()> {
    d.fill(1.0);
    e.fill(0.0);
    Ok(())
}
/// Recursive divide-and-conquer driver for the bidiagonal SVD.
///
/// Problems of size `n <= crossover` are solved directly by bidiagonal QR;
/// larger ones are split at `mid = n / 2`, solved recursively, and
/// recombined by `merge_svd` using the detached coupling entry
/// `alpha = e[mid - 1]`.
fn dc_bidiagonal_svd(
    d: &mut [f64],
    e: &mut [f64],
    u: Option<&mut [f64]>,
    vt: Option<&mut [f64]>,
    n: usize,
    crossover: usize,
) -> SolverResult<()> {
    if n == 0 {
        return Ok(());
    }
    if n <= crossover {
        // Base case: small enough for plain QR iteration.
        return bidiagonal_svd_qr(d, e, u, vt, n);
    }
    let mid = n / 2;
    // Detach the superdiagonal entry coupling the two halves.
    let alpha = match mid.checked_sub(1) {
        Some(i) if i < e.len() => {
            let coupling = e[i];
            e[i] = 0.0;
            coupling
        }
        _ => 0.0,
    };
    let want_u = u.is_some();
    let want_vt = vt.is_some();
    // Left half: indices [0, mid).
    let mut left_u = want_u.then(|| vec![0.0_f64; mid * mid]);
    let mut left_vt = want_vt.then(|| vec![0.0_f64; mid * mid]);
    dc_bidiagonal_svd(
        &mut d[..mid],
        &mut e[..mid.saturating_sub(1)],
        left_u.as_deref_mut(),
        left_vt.as_deref_mut(),
        mid,
        crossover,
    )?;
    // Right half: indices [mid, n).
    let rest = n - mid;
    let mut right_u = want_u.then(|| vec![0.0_f64; rest * rest]);
    let mut right_vt = want_vt.then(|| vec![0.0_f64; rest * rest]);
    dc_bidiagonal_svd(
        &mut d[mid..n],
        &mut e[mid..mid + rest.saturating_sub(1)],
        right_u.as_deref_mut(),
        right_vt.as_deref_mut(),
        rest,
        crossover,
    )?;
    merge_svd(
        d,
        alpha,
        mid,
        n,
        u,
        vt,
        left_u.as_deref(),
        left_vt.as_deref(),
        right_u.as_deref(),
        right_vt.as_deref(),
    )
}
/// Recombine two child SVDs whose bidiagonals were coupled by `alpha`.
///
/// When `alpha` is effectively zero the halves are independent and the
/// singular-vector blocks are simply placed on the diagonal. Otherwise a
/// rank-one coupling vector `z` is built from the children's V^T rows
/// (last row of the left factor, first row of the right) and the merged
/// singular values are found as roots of the secular equation.
#[allow(clippy::too_many_arguments)]
fn merge_svd(
    d: &mut [f64],
    alpha: f64,
    mid: usize,
    n: usize,
    u: Option<&mut [f64]>,
    vt: Option<&mut [f64]>,
    u_left: Option<&[f64]>,
    vt_left: Option<&[f64]>,
    u_right: Option<&[f64]>,
    vt_right: Option<&[f64]>,
) -> SolverResult<()> {
    let coupled = alpha.abs() >= 1e-300;
    if coupled {
        // Build the coupling vector z, scaled by alpha.
        let mut z = vec![0.0_f64; n];
        match vt_left {
            Some(vt_l) => {
                let last_row = mid.saturating_sub(1);
                for (j, zj) in z.iter_mut().enumerate().take(mid) {
                    *zj = vt_l[last_row * mid + j] * alpha;
                }
            }
            None if mid > 0 => z[mid - 1] = alpha,
            None => {}
        }
        match vt_right {
            Some(vt_r) => {
                for j in 0..(n - mid) {
                    z[mid + j] = vt_r[j] * alpha;
                }
            }
            None if n > mid => z[mid] = alpha,
            None => {}
        }
        // Updated singular values from the secular equation.
        let previous: Vec<f64> = d[..n].to_vec();
        for (i, slot) in d.iter_mut().enumerate().take(n) {
            *slot = solve_secular_equation(&previous, &z, i, n)?;
        }
    }
    merge_orthogonal_blocks(u, u_left, u_right, mid, n);
    merge_orthogonal_blocks_transpose(vt, vt_left, vt_right, mid, n);
    Ok(())
}
/// Locate the `idx`-th root of the secular equation with a safeguarded
/// Newton iteration, falling back to bisection whenever the Newton step
/// escapes the current bracket.
///
/// The bracket comes from neighbouring (sorted, absolute) old singular
/// values; the upper bound for the last root is padded by `|z|_1 + 1`.
fn solve_secular_equation(d: &[f64], z: &[f64], idx: usize, n: usize) -> SolverResult<f64> {
    if n == 0 {
        return Ok(0.0);
    }
    let mut ordered: Vec<f64> = d[..n].to_vec();
    ordered.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let lower = ordered.get(idx).map_or(0.0, |v| v.abs());
    let upper = match ordered.get(idx + 1) {
        Some(v) => v.abs(),
        None => lower + z.iter().map(|zi| zi.abs()).sum::<f64>() + 1.0,
    };
    let mut sigma = (lower + upper) * 0.5;
    let (mut lo_b, mut hi_b) = (lower, upper);
    for _ in 0..SECULAR_MAX_ITER {
        let (f_val, f_deriv) = secular_function(d, z, sigma, n);
        if f_val.abs() < SECULAR_TOL {
            return Ok(sigma);
        }
        // Newton step, accepted only while it stays strictly inside the bracket.
        sigma = if f_deriv.abs() > 1e-300 {
            let candidate = sigma - f_val / f_deriv;
            if candidate > lo_b && candidate < hi_b {
                candidate
            } else {
                (lo_b + hi_b) * 0.5
            }
        } else {
            (lo_b + hi_b) * 0.5
        };
        // Shrink the bracket around the sign change.
        let (f_new, _) = secular_function(d, z, sigma, n);
        if f_new > 0.0 {
            hi_b = sigma;
        } else {
            lo_b = sigma;
        }
        if (hi_b - lo_b) < SECULAR_TOL * sigma.abs().max(1.0) {
            return Ok(sigma);
        }
    }
    Ok(sigma)
}
/// Evaluate the secular function
///   `f(sigma) = 1 + sum_i z_i^2 / (d_i^2 - sigma^2)`
/// together with its derivative with respect to `sigma`.
///
/// Terms whose denominator is numerically zero are skipped rather than
/// producing infinities.
fn secular_function(d: &[f64], z: &[f64], sigma: f64, n: usize) -> (f64, f64) {
    let sigma2 = sigma * sigma;
    d[..n]
        .iter()
        .zip(&z[..n])
        .fold((1.0, 0.0), |(f, df), (&di, &zi)| {
            let denom = di * di - sigma2;
            if denom.abs() < 1e-300 {
                // Pole: leave the running sums untouched.
                (f, df)
            } else {
                let zi2 = zi * zi;
                (f + zi2 / denom, df + 2.0 * sigma * zi2 / (denom * denom))
            }
        })
}
/// SVD of a bidiagonal matrix (diagonal `d`, superdiagonal `e`, order `n`)
/// by repeated implicit-shift QR sweeps.
///
/// When supplied, `u` / `vt` are reset to the `n x n` identity.
///
/// NOTE(review): `bidiagonal_qr_step` does not report its rotations, so
/// `u` / `vt` remain the identity on return — singular-vector accumulation
/// appears to be unimplemented; confirm this is intentional scaffolding.
///
/// # Errors
/// `SolverError::ConvergenceFailure` if the off-diagonal norm is still
/// above tolerance after `BIDIAG_QR_MAX_ITER` sweeps.
fn bidiagonal_svd_qr(
    d: &mut [f64],
    e: &mut [f64],
    mut u: Option<&mut [f64]>,
    mut vt: Option<&mut [f64]>,
    n: usize,
) -> SolverResult<()> {
    if n == 0 {
        return Ok(());
    }
    // Initialise requested vector factors to the identity.
    if let Some(ref mut u_mat) = u {
        for val in u_mat.iter_mut() {
            *val = 0.0;
        }
        for i in 0..n {
            u_mat[i * n + i] = 1.0;
        }
    }
    if let Some(ref mut vt_mat) = vt {
        for val in vt_mat.iter_mut() {
            *val = 0.0;
        }
        for i in 0..n {
            vt_mat[i * n + i] = 1.0;
        }
    }
    let tol = 1e-14;
    for _iter in 0..BIDIAG_QR_MAX_ITER {
        // Deflate converged entries from the bottom of the band
        // (relative criterion against the neighbouring diagonals).
        let mut q = n.saturating_sub(1);
        while q > 0 && e[q - 1].abs() <= tol * (d[q - 1].abs() + d[q].abs()) {
            e[q - 1] = 0.0;
            q -= 1;
        }
        // Whole superdiagonal negligible: converged.
        if q == 0 {
            return Ok(());
        }
        // Find the top of the unreduced block [p, q].
        let mut p = q - 1;
        while p > 0 && e[p - 1].abs() > tol * (d[p - 1].abs() + d[p].abs()) {
            p -= 1;
        }
        bidiagonal_qr_step(d, e, p, q);
    }
    // Iteration budget exhausted: accept only if the residual is tiny.
    let off_norm: f64 = e.iter().map(|v| v * v).sum::<f64>().sqrt();
    if off_norm > tol {
        return Err(SolverError::ConvergenceFailure {
            iterations: BIDIAG_QR_MAX_ITER as u32,
            residual: off_norm,
        });
    }
    Ok(())
}
/// One implicit-shift QR sweep on the bidiagonal block
/// `d[start..=end]`, `e[start..end)`.
///
/// The shift `mu` is a Wilkinson-style shift: the eigenvalue of the
/// trailing 2x2 block of B^T B closest to its bottom-right entry.
/// Alternating Givens rotations then chase the resulting bulge down the
/// band.
///
/// NOTE(review): the rotations are applied only to `d`/`e`; they are not
/// returned or accumulated into singular-vector matrices.
fn bidiagonal_qr_step(d: &mut [f64], e: &mut [f64], start: usize, end: usize) {
    // Trailing 2x2 block of B^T B: [t11 t12; t12 t22].
    let dm1 = d[end - 1];
    let dm = d[end];
    let em1 = e[end - 1];
    let t11 = dm1 * dm1
        + if end >= 2 {
            e[end - 2] * e[end - 2]
        } else {
            0.0
        };
    let t12 = dm1 * em1;
    let t22 = dm * dm + em1 * em1;
    // Shift computed in the cancellation-avoiding form.
    let delta = (t11 - t22) * 0.5;
    let sign_delta = if delta >= 0.0 { 1.0 } else { -1.0 };
    let denom = delta + sign_delta * (delta * delta + t12 * t12).sqrt();
    let mu = if denom.abs() > 1e-300 {
        t22 - t12 * t12 / denom
    } else {
        t22
    };
    // Initial bulge from the shifted first column.
    let mut y = d[start] * d[start] - mu;
    let mut z = d[start] * e[start];
    for k in start..end {
        // Right rotation eliminates z from the (y, z) pair.
        let (cs, sn) = givens_rotation(y, z);
        if k > start {
            e[k - 1] = cs * e[k - 1] + sn * z;
        }
        let tmp_d = cs * d[k] + sn * e[k];
        e[k] = -sn * d[k] + cs * e[k];
        d[k] = tmp_d;
        let tmp_z = sn * d[k + 1];
        d[k + 1] *= cs;
        y = d[k];
        z = tmp_z;
        // Left rotation restores bidiagonal shape, pushing the bulge down.
        let (cs2, sn2) = givens_rotation(y, z);
        d[k] = cs2 * d[k] + sn2 * tmp_z;
        let tmp_e = cs2 * e[k] + sn2 * d[k + 1];
        d[k + 1] = -sn2 * e[k] + cs2 * d[k + 1];
        e[k] = tmp_e;
        if k + 1 < end {
            y = e[k];
            z = sn2 * e[k + 1];
            e[k + 1] *= cs2;
        }
    }
}
/// Compute a Givens rotation `(cs, sn)` such that
///   `[cs  sn; -sn  cs] * [a; b] = [r; 0]`
/// with `r = sqrt(a^2 + b^2)`.
///
/// Uses `f64::hypot`, which avoids the intermediate overflow/underflow of
/// the naive `sqrt(a*a + b*b)` for very large or very small inputs
/// (e.g. `a = 3e200` would square to infinity).
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    // Degenerate cases keep the original exact identities.
    if b.abs() < 1e-300 {
        return (1.0, 0.0);
    }
    if a.abs() < 1e-300 {
        return (0.0, if b >= 0.0 { 1.0 } else { -1.0 });
    }
    let r = a.hypot(b);
    (a / r, b / r)
}
/// Place the left and right singular-vector blocks on the diagonal of the
/// `n x n` output (column-major: element (row, col) lives at `col * n + row`).
///
/// A missing child block is treated as the identity. A `None` output is a
/// no-op.
fn merge_orthogonal_blocks(
    u: Option<&mut [f64]>,
    u_left: Option<&[f64]>,
    u_right: Option<&[f64]>,
    mid: usize,
    n: usize,
) {
    let Some(out) = u else { return };
    let rest = n - mid;
    out.iter_mut().take(n * n).for_each(|v| *v = 0.0);
    // Top-left block.
    match u_left {
        Some(block) => {
            for c in 0..mid {
                for r in 0..mid {
                    out[c * n + r] = block[c * mid + r];
                }
            }
        }
        None => {
            for i in 0..mid {
                out[i * n + i] = 1.0;
            }
        }
    }
    // Bottom-right block, offset by `mid` in both dimensions.
    match u_right {
        Some(block) => {
            for c in 0..rest {
                for r in 0..rest {
                    out[(mid + c) * n + (mid + r)] = block[c * rest + r];
                }
            }
        }
        None => {
            for i in 0..rest {
                out[(mid + i) * n + (mid + i)] = 1.0;
            }
        }
    }
}
/// Place the children's V^T blocks on the diagonal of the merged V^T.
///
/// NOTE(review): despite the `_transpose` suffix this simply delegates to
/// `merge_orthogonal_blocks` with no transposition; for block-diagonal
/// placement the two layouts coincide, but confirm the naming is
/// intentional.
fn merge_orthogonal_blocks_transpose(
    vt: Option<&mut [f64]>,
    vt_left: Option<&[f64]>,
    vt_right: Option<&[f64]>,
    mid: usize,
    n: usize,
) {
    merge_orthogonal_blocks(vt, vt_left, vt_right, mid, n);
}
/// Selection-sort the first `n` singular values into descending order of
/// magnitude, permuting the matching columns of the column-major `u` /
/// `vt` factors, then fold any remaining sign into `u` so every reported
/// value is non-negative.
#[allow(clippy::needless_range_loop)]
fn sort_singular_values_desc(
    d: &mut [f64],
    mut u: Option<&mut [f64]>,
    mut vt: Option<&mut [f64]>,
    n: usize,
) {
    for pos in 0..n {
        // Find the largest remaining magnitude.
        let mut best = pos;
        for cand in (pos + 1)..n {
            if d[cand].abs() > d[best].abs() {
                best = cand;
            }
        }
        if best != pos {
            d.swap(pos, best);
            if let Some(ref mut u_mat) = u {
                for row in 0..n {
                    u_mat.swap(pos * n + row, best * n + row);
                }
            }
            if let Some(ref mut vt_mat) = vt {
                for col in 0..n {
                    vt_mat.swap(pos * n + col, best * n + col);
                }
            }
        }
        // Sign convention: absorb a negative value into the U column.
        if d[pos] < 0.0 {
            d[pos] = -d[pos];
            if let Some(ref mut u_mat) = u {
                for row in 0..n {
                    u_mat[pos * n + row] = -u_mat[pos * n + row];
                }
            }
        }
    }
}
/// Placeholder device upload: a real implementation would copy `count`
/// elements of `data` into `buf`.
///
/// NOTE(review): currently a no-op, so `dc_svd` results never reach the
/// device buffers — confirm this is intentional scaffolding.
fn write_to_device_buffer<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[T],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- configuration defaults and overrides ---

    #[test]
    fn dc_svd_config_default() {
        let cfg = DcSvdConfig::default();
        assert_eq!(cfg.crossover_size, DEFAULT_CROSSOVER);
        assert!(cfg.compute_u);
        assert!(cfg.compute_vt);
    }

    #[test]
    fn dc_svd_config_custom() {
        let cfg = DcSvdConfig {
            crossover_size: 10,
            compute_u: false,
            compute_vt: true,
            ..DcSvdConfig::default()
        };
        assert_eq!(cfg.crossover_size, 10);
        assert!(!cfg.compute_u);
        assert!(cfg.compute_vt);
    }

    // --- secular equation evaluation ---

    #[test]
    fn secular_function_identity() {
        // z = 0 leaves f(sigma) = 1 with zero derivative.
        let d = [1.0, 2.0, 3.0];
        let z = [0.0, 0.0, 0.0];
        let (f_val, f_deriv) = secular_function(&d, &z, 0.5, 3);
        assert!((f_val - 1.0).abs() < 1e-10);
        assert!(f_deriv.abs() < 1e-10);
    }

    #[test]
    fn secular_function_with_coupling() {
        let d = [1.0, 3.0];
        let z = [0.5, 0.5];
        let (f_val, _f_deriv) = secular_function(&d, &z, 2.0, 2);
        let expected = 1.0 + 0.25 / (1.0 - 4.0) + 0.25 / (9.0 - 4.0);
        assert!((f_val - expected).abs() < 1e-10);
    }

    // --- Givens rotations ---

    #[test]
    fn givens_rotation_basic() {
        let (cs, sn) = givens_rotation(3.0, 4.0);
        let r = cs * 3.0 + sn * 4.0;
        assert!((r - 5.0).abs() < 1e-10);
        let zero_val = -sn * 3.0 + cs * 4.0;
        assert!(zero_val.abs() < 1e-10);
    }

    #[test]
    fn givens_rotation_zero_b() {
        let (cs, sn) = givens_rotation(5.0, 0.0);
        assert!((cs - 1.0).abs() < 1e-15);
        assert!(sn.abs() < 1e-15);
    }

    #[test]
    fn givens_rotation_zero_a() {
        let (cs, sn) = givens_rotation(0.0, 3.0);
        assert!(cs.abs() < 1e-15);
        assert!((sn - 1.0).abs() < 1e-15);
    }

    // --- bidiagonal QR iteration ---

    #[test]
    fn bidiagonal_qr_trivial() {
        // Already diagonal: must converge immediately.
        let mut d = vec![3.0, 2.0, 1.0];
        let mut e = vec![0.0, 0.0];
        let result = bidiagonal_svd_qr(&mut d, &mut e, None, None, 3);
        assert!(result.is_ok());
    }

    #[test]
    fn bidiagonal_qr_with_superdiag() {
        let mut d = vec![4.0, 3.0];
        let mut e = vec![1.0];
        let mut u = vec![0.0; 4];
        let mut vt = vec![0.0; 4];
        let result = bidiagonal_svd_qr(&mut d, &mut e, Some(&mut u), Some(&mut vt), 2);
        assert!(result.is_ok());
    }

    #[test]
    fn bidiagonal_qr_empty() {
        let mut d: Vec<f64> = Vec::new();
        let mut e: Vec<f64> = Vec::new();
        let result = bidiagonal_svd_qr(&mut d, &mut e, None, None, 0);
        assert!(result.is_ok());
    }

    // --- ordering and sign convention ---

    #[test]
    fn sort_singular_values_descending() {
        let mut d = vec![1.0, 3.0, 2.0];
        sort_singular_values_desc(&mut d, None, None, 3);
        assert!((d[0] - 3.0).abs() < 1e-15);
        assert!((d[1] - 2.0).abs() < 1e-15);
        assert!((d[2] - 1.0).abs() < 1e-15);
    }

    #[test]
    fn sort_singular_values_with_negatives() {
        let mut d = vec![-2.0, 1.0, -3.0];
        sort_singular_values_desc(&mut d, None, None, 3);
        assert!((d[0] - 3.0).abs() < 1e-15);
        assert!((d[1] - 2.0).abs() < 1e-15);
        assert!((d[2] - 1.0).abs() < 1e-15);
    }

    // --- divide & conquer driver ---

    #[test]
    fn dc_bidiagonal_base_case() {
        // n <= crossover takes the QR base case.
        let mut d = vec![5.0, 3.0, 1.0];
        let mut e = vec![0.0, 0.0];
        let result = dc_bidiagonal_svd(&mut d, &mut e, None, None, 3, 25);
        assert!(result.is_ok());
    }

    #[test]
    fn merge_orthogonal_blocks_identity() {
        let mut u = vec![0.0_f64; 16];
        let u_left = vec![1.0, 0.0, 0.0, 1.0];
        let u_right = vec![1.0, 0.0, 0.0, 1.0];
        merge_orthogonal_blocks(Some(&mut u), Some(&u_left), Some(&u_right), 2, 4);
        assert!((u[0] - 1.0).abs() < 1e-15);
        assert!((u[5] - 1.0).abs() < 1e-15);
        assert!((u[10] - 1.0).abs() < 1e-15);
        assert!((u[15] - 1.0).abs() < 1e-15);
    }

    // --- bit-level conversions ---

    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::PI;
        let converted: f64 = from_f64(to_f64(val));
        assert!((converted - val).abs() < 1e-15);
    }

    #[test]
    fn f32_conversion_roundtrip() {
        let val = std::f32::consts::PI;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }

    // --- GPU heuristics ---

    #[test]
    fn dc_svd_config_threshold_1024() {
        let cfg_large = DcSvdConfig::for_gpu(1024);
        assert!(
            cfg_large.use_divide_conquer,
            "D&C should be enabled for n=1024"
        );
        assert_eq!(cfg_large.n_threshold, 1024);
        let cfg_small = DcSvdConfig::for_gpu(512);
        assert!(
            !cfg_small.use_divide_conquer,
            "D&C should be disabled for n=512"
        );
    }

    #[test]
    fn dc_svd_uses_bidiagonalization() {
        let cfg_large = DcSvdConfig::for_gpu(256);
        assert!(
            cfg_large.bidiagonalization,
            "bidiagonalization should be enabled for n=256"
        );
        let cfg_small = DcSvdConfig::for_gpu(128);
        assert!(
            !cfg_small.bidiagonalization,
            "bidiagonalization should be disabled for n=128"
        );
        let cfg_very_large = DcSvdConfig::for_gpu(4096);
        assert!(cfg_very_large.bidiagonalization);
        assert!(cfg_very_large.use_divide_conquer);
    }

    #[test]
    fn bidiagonalization_cpu_2x2() {
        let mut d = vec![3.0_f64, 4.0];
        let mut e = vec![1.0_f64];
        let result = bidiagonal_svd_qr(&mut d, &mut e, None, None, 2);
        assert!(result.is_ok(), "bidiagonal QR for 2x2 must succeed");
        assert!(
            e[0].abs() < 1e-10,
            "off-diagonal e[0] = {} should be ~0",
            e[0]
        );
        assert!(
            d[0] >= 0.0 && d[1] >= 0.0,
            "singular values must be non-negative"
        );
    }

    #[test]
    fn dc_svd_deflation_threshold_small() {
        // deflation_tol must scale linearly with n (n * eps).
        let eps = 2.22e-16_f64;
        let n_vals: &[usize] = &[10, 100, 1000, 4096];
        for &n in n_vals {
            let cfg = DcSvdConfig::for_gpu(n);
            let expected_tol = n as f64 * eps;
            assert!(
                (cfg.deflation_tol - expected_tol).abs() < 1e-30,
                "deflation_tol for n={n}: got {}, expected {}",
                cfg.deflation_tol,
                expected_tol
            );
        }
    }
}