#![allow(dead_code)]
use oxicuda_blas::types::{FillMode, GpuFloat};
use oxicuda_memory::DeviceBuffer;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Widens a [`GpuFloat`] value to `f64` by reinterpreting its raw bits.
///
/// When `T::SIZE` is 4 the low 32 bits returned by `to_bits_u64()` are decoded
/// as an `f32` and then widened; for every other size the full 64 bits are
/// decoded directly as an `f64`.
fn to_f64<T: GpuFloat>(val: T) -> f64 {
    let bits = val.to_bits_u64();
    match T::SIZE {
        4 => f64::from(f32::from_bits(bits as u32)),
        _ => f64::from_bits(bits),
    }
}
/// Narrows an `f64` into a [`GpuFloat`] value by building its raw bit pattern.
///
/// When `T::SIZE` is 4 the value is first rounded to `f32` and its 32 bits are
/// zero-extended; for every other size the `f64` bits are passed unchanged to
/// `T::from_bits_u64`.
fn from_f64<T: GpuFloat>(val: f64) -> T {
    let bits = match T::SIZE {
        4 => u64::from((val as f32).to_bits()),
        _ => val.to_bits(),
    };
    T::from_bits_u64(bits)
}
/// Pivot-selection threshold used by the Bunch-Kaufman factorization,
/// numerically equal to `(1 + sqrt(17)) / 8`.
const BUNCH_KAUFMAN_ALPHA: f64 = 0.6403882032022076;
/// Value returned by [`ldlt`].
pub struct LdltResult {
    /// Buffer allocated by [`ldlt`] with `n` `i32` slots (zero slots when
    /// `n == 0`); `ldlt` hands the host pivot vector to
    /// `write_host_to_device_i32` to populate it.
    ///
    /// The host pivot vector uses 1-based indices: a positive entry marks a
    /// 1x1 pivot and two adjacent equal negative entries mark a 2x2 pivot.
    pub pivot_info: DeviceBuffer<i32>,
}
/// Debug output shows only the pivot buffer length, not its contents.
impl std::fmt::Debug for LdltResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pivot_len = self.pivot_info.len();
        let mut out = f.debug_struct("LdltResult");
        out.field("pivot_info_len", &pivot_len);
        out.finish()
    }
}
/// Factorizes an `n x n` symmetric matrix with Bunch-Kaufman pivoting.
///
/// The matrix is staged into a host `f64` buffer through
/// `read_device_to_host`, factorized on the host by
/// `bunch_kaufman_factorize`, and the factors and pivot vector are handed back
/// through `write_host_to_device` / `write_host_to_device_i32`.
///
/// # Arguments
/// * `handle` - solver handle; its workspace is grown to `n * n * 8` bytes.
/// * `a` - matrix buffer holding at least `n * n` elements, indexed as
///   `col * n + row` by the host routines.
/// * `n` - matrix order.
/// * `uplo` - which triangle to factorize; must be `Lower` or `Upper`.
///
/// # Errors
/// * `SolverError::DimensionMismatch` if `n` does not fit the `i32` pivot
///   encoding, if `n * n` (or the workspace byte count) overflows `usize`, if
///   `a` is too small, or if `uplo` is `Full`.
/// * Any error propagated from workspace allocation, buffer allocation, the
///   transfer helpers, or the host factorization (e.g. `SingularMatrix`).
pub fn ldlt<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    n: usize,
    uplo: FillMode,
) -> SolverResult<LdltResult> {
    if n == 0 {
        let pivot_info = DeviceBuffer::<i32>::zeroed(0)?;
        return Ok(LdltResult { pivot_info });
    }
    // Pivots are stored as 1-based `i32` values, so `n` itself must be
    // representable; otherwise the `as i32` casts downstream would wrap.
    if n > i32::MAX as usize {
        return Err(SolverError::DimensionMismatch(format!(
            "ldlt: n = {} exceeds the i32 pivot index range",
            n
        )));
    }
    // Checked arithmetic: a wrapped `n * n` would defeat the size guard below.
    let elem_count = n.checked_mul(n).ok_or_else(|| {
        SolverError::DimensionMismatch(format!("ldlt: n * n overflows usize (n = {})", n))
    })?;
    if a.len() < elem_count {
        return Err(SolverError::DimensionMismatch(format!(
            "ldlt: buffer too small ({} < {})",
            a.len(),
            elem_count
        )));
    }
    if uplo == FillMode::Full {
        return Err(SolverError::DimensionMismatch(
            "ldlt: uplo must be Upper or Lower, not Full".into(),
        ));
    }
    let ws = elem_count
        .checked_mul(std::mem::size_of::<f64>())
        .ok_or_else(|| {
            SolverError::DimensionMismatch(format!(
                "ldlt: workspace size overflows usize (n = {})",
                n
            ))
        })?;
    handle.ensure_workspace(ws)?;

    let mut a_host = vec![0.0_f64; elem_count];
    read_device_to_host(a, &mut a_host, elem_count)?;

    let mut ipiv = vec![0_i32; n];
    bunch_kaufman_factorize(&mut a_host, n, uplo, &mut ipiv)?;

    let a_device: Vec<T> = a_host.iter().map(|&v| from_f64(v)).collect();
    write_host_to_device(a, &a_device, elem_count)?;

    let mut pivot_info = DeviceBuffer::<i32>::zeroed(n)?;
    write_host_to_device_i32(&mut pivot_info, &ipiv, n)?;
    Ok(LdltResult { pivot_info })
}
pub fn ldlt_solve<T: GpuFloat>(
handle: &mut SolverHandle,
a: &DeviceBuffer<T>,
pivot_info: &DeviceBuffer<i32>,
b: &mut DeviceBuffer<T>,
n: usize,
nrhs: usize,
uplo: FillMode,
) -> SolverResult<()> {
if n == 0 || nrhs == 0 {
return Ok(());
}
if a.len() < n * n {
return Err(SolverError::DimensionMismatch(
"ldlt_solve: factor buffer too small".into(),
));
}
if pivot_info.len() < n {
return Err(SolverError::DimensionMismatch(
"ldlt_solve: pivot_info buffer too small".into(),
));
}
if b.len() < n * nrhs {
return Err(SolverError::DimensionMismatch(
"ldlt_solve: B buffer too small".into(),
));
}
let ws = (n * n + n * nrhs) * std::mem::size_of::<f64>();
handle.ensure_workspace(ws)?;
let mut a_host = vec![0.0_f64; n * n];
read_device_to_host(a, &mut a_host, n * n)?;
let mut ipiv = vec![0_i32; n];
read_device_to_host_i32(pivot_info, &mut ipiv, n)?;
let mut b_host = vec![0.0_f64; n * nrhs];
read_device_to_host(b, &mut b_host, n * nrhs)?;
bunch_kaufman_solve(&a_host, &ipiv, &mut b_host, n, nrhs, uplo)?;
let b_device: Vec<T> = b_host.iter().map(|&v| from_f64(v)).collect();
write_host_to_device(b, &b_device, n * nrhs)?;
Ok(())
}
fn bunch_kaufman_factorize(
a: &mut [f64],
n: usize,
uplo: FillMode,
ipiv: &mut [i32],
) -> SolverResult<()> {
match uplo {
FillMode::Lower => bunch_kaufman_lower(a, n, ipiv),
FillMode::Upper => bunch_kaufman_upper(a, n, ipiv),
FillMode::Full => Err(SolverError::DimensionMismatch(
"ldlt: uplo must be Lower or Upper".into(),
)),
}
}
/// Largest absolute off-diagonal entry in row/column `r` of the trailing
/// block `first..n`, reading only the lower triangle.
fn lower_row_col_max_offdiag(a: &[f64], n: usize, first: usize, r: usize) -> f64 {
    // Row part: element (r, j) with first <= j < r is stored in column j.
    let row_max = (first..r)
        .map(|j| a[j * n + r].abs())
        .fold(0.0_f64, f64::max);
    // Column part: element (i, r) with i > r is stored in column r.
    let col_max = ((r + 1)..n)
        .map(|i| a[r * n + i].abs())
        .fold(0.0_f64, f64::max);
    row_max.max(col_max)
}

/// Symmetric interchange of indices `i` and `j` that touches only the lower
/// triangle and only columns `first_col..`, leaving finished columns alone.
fn swap_symmetric_lower(a: &mut [f64], n: usize, first_col: usize, i: usize, j: usize) {
    if i == j {
        return;
    }
    let (p, q) = if i < j { (i, j) } else { (j, i) };
    // Rows p and q inside active columns left of p (the first column of a
    // 2x2 block).
    for c in first_col..p {
        a.swap(c * n + p, c * n + q);
    }
    a.swap(p * n + p, q * n + q);
    // (m, p) <-> (q, m) for p < m < q; element (q, p) maps onto itself.
    for m in (p + 1)..q {
        a.swap(p * n + m, m * n + q);
    }
    // (m, p) <-> (m, q) for rows below q.
    for m in (q + 1)..n {
        a.swap(p * n + m, q * n + m);
    }
}

/// Bunch-Kaufman `L * D * L^T` factorization of the lower triangle, in place.
///
/// `a` is indexed as `col * n + row`; only entries with `row >= col` are read
/// or written. On success the strict lower triangle holds the multipliers of
/// `L` and the (block) diagonal holds `D`.
///
/// `ipiv` uses 1-based indices: `ipiv[k] = p > 0` means a 1x1 pivot with rows
/// and columns `k` and `p` interchanged; `ipiv[k] = ipiv[k + 1] = -p` means a
/// 2x2 pivot with rows and columns `k + 1` and `p` interchanged.
/// Interchanges are applied to the trailing block only, which is the layout
/// the interleaved swaps in `bunch_kaufman_solve_lower` expect.
///
/// # Errors
/// `SolverError::SingularMatrix` when a pivot column is entirely negligible
/// or a 2x2 pivot block is singular.
fn bunch_kaufman_lower(a: &mut [f64], n: usize, ipiv: &mut [i32]) -> SolverResult<()> {
    let mut k = 0;
    while k < n {
        let (lambda, r_idx) = column_max_offdiag(a, n, k, true);
        let akk = a[k * n + k].abs();
        if akk < 1e-300 && lambda < 1e-300 {
            return Err(SolverError::SingularMatrix);
        }
        if akk >= BUNCH_KAUFMAN_ALPHA * lambda {
            // Diagonal entry is large enough: 1x1 pivot, no interchange.
            perform_1x1_pivot_lower(a, n, k);
            ipiv[k] = (k + 1) as i32;
            k += 1;
            continue;
        }
        // sigma must cover the whole row/column r of the trailing block.
        // Scanning only below the diagonal misses the row part and can select
        // a zero pivot for matrices such as [[0, 1], [1, 0]].
        let sigma = lower_row_col_max_offdiag(a, n, k, r_idx);
        if akk * sigma >= BUNCH_KAUFMAN_ALPHA * lambda * lambda {
            perform_1x1_pivot_lower(a, n, k);
            ipiv[k] = (k + 1) as i32;
            k += 1;
        } else if a[r_idx * n + r_idx].abs() >= BUNCH_KAUFMAN_ALPHA * sigma {
            // 1x1 pivot after bringing a[r, r] to the pivot position.
            swap_symmetric_lower(a, n, k, k, r_idx);
            perform_1x1_pivot_lower(a, n, k);
            ipiv[k] = (r_idx + 1) as i32;
            k += 1;
        } else if k + 1 >= n {
            // Defensive fallback: no room for a 2x2 block in the last column.
            perform_1x1_pivot_lower(a, n, k);
            ipiv[k] = (k + 1) as i32;
            k += 1;
        } else {
            // 2x2 pivot on rows/columns k and r (r is moved to k + 1).
            swap_symmetric_lower(a, n, k, k + 1, r_idx);
            perform_2x2_pivot_lower(a, n, k)?;
            ipiv[k] = -((r_idx + 1) as i32);
            ipiv[k + 1] = ipiv[k];
            k += 2;
        }
    }
    Ok(())
}
/// Largest absolute off-diagonal entry in row/column `r` of the leading
/// block `0..=last`, reading only the upper triangle.
fn upper_row_col_max_offdiag(a: &[f64], n: usize, last: usize, r: usize) -> f64 {
    // Row part: element (r, j) with r < j <= last is stored in column j.
    let row_max = ((r + 1)..=last)
        .map(|j| a[j * n + r].abs())
        .fold(0.0_f64, f64::max);
    // Column part: element (i, r) with i < r is stored in column r.
    let col_max = (0..r)
        .map(|i| a[r * n + i].abs())
        .fold(0.0_f64, f64::max);
    row_max.max(col_max)
}

/// Symmetric interchange of indices `i` and `j` that touches only the upper
/// triangle and only columns up to `last_col`, leaving finished columns alone.
fn swap_symmetric_upper(a: &mut [f64], n: usize, last_col: usize, i: usize, j: usize) {
    if i == j {
        return;
    }
    let (p, q) = if i < j { (i, j) } else { (j, i) };
    // (m, p) <-> (m, q) for rows above p.
    for m in 0..p {
        a.swap(p * n + m, q * n + m);
    }
    a.swap(p * n + p, q * n + q);
    // (p, m) <-> (m, q) for p < m < q; element (p, q) maps onto itself.
    for m in (p + 1)..q {
        a.swap(m * n + p, q * n + m);
    }
    // Rows p and q inside active columns right of q (the second column of a
    // 2x2 block).
    for c in (q + 1)..=last_col {
        a.swap(c * n + p, c * n + q);
    }
}

/// Applies a 1x1 pivot at column `c` of the upper triangle: scales the column
/// above the diagonal and updates the leading `c x c` block.
///
/// A negligible pivot is left untouched, mirroring `perform_1x1_pivot_lower`.
fn eliminate_1x1_upper(a: &mut [f64], n: usize, c: usize) {
    let d = a[c * n + c];
    if d.abs() < 1e-300 {
        return;
    }
    let inv_d = 1.0 / d;
    for i in 0..c {
        a[c * n + i] *= inv_d;
    }
    for j in 0..c {
        let ujc = a[c * n + j];
        for i in 0..=j {
            let uic = a[c * n + i];
            a[j * n + i] -= uic * d * ujc;
        }
    }
}

/// Applies a 2x2 pivot on columns `c - 1` and `c` of the upper triangle.
fn eliminate_2x2_upper(a: &mut [f64], n: usize, c: usize) -> SolverResult<()> {
    if c == 0 {
        return Err(SolverError::InternalError(
            "ldlt: 2x2 pivot at boundary".into(),
        ));
    }
    let c2 = c - 1;
    let d11 = a[c2 * n + c2];
    let d12 = a[c * n + c2];
    let d22 = a[c * n + c];
    let det = d11 * d22 - d12 * d12;
    if det.abs() < 1e-300 {
        return Err(SolverError::SingularMatrix);
    }
    let inv_det = 1.0 / det;
    // Multipliers: [u_i1, u_i2] = [a_i1, a_i2] * D^{-1}.
    for i in 0..c2 {
        let ai1 = a[c2 * n + i];
        let ai2 = a[c * n + i];
        a[c2 * n + i] = (d22 * ai1 - d12 * ai2) * inv_det;
        a[c * n + i] = (-d12 * ai1 + d11 * ai2) * inv_det;
    }
    // Leading block update: A(i, j) -= u_i * D * u_j^T for i <= j < c - 1.
    for j in 0..c2 {
        let uj1 = a[c2 * n + j];
        let uj2 = a[c * n + j];
        for i in 0..=j {
            let ui1 = a[c2 * n + i];
            let ui2 = a[c * n + i];
            a[j * n + i] -=
                ui1 * d11 * uj1 + ui1 * d12 * uj2 + ui2 * d12 * uj1 + ui2 * d22 * uj2;
        }
    }
    Ok(())
}

/// Bunch-Kaufman `U * D * U^T` factorization of the upper triangle, in place.
///
/// `a` is indexed as `col * n + row`; only entries with `row <= col` are read
/// or written. Columns are processed from last to first. On success the
/// strict upper triangle holds the multipliers of `U` and the (block)
/// diagonal holds `D`.
///
/// `ipiv` uses 1-based indices: `ipiv[k] = p > 0` means a 1x1 pivot with rows
/// and columns `k` and `p` interchanged; `ipiv[k] = ipiv[k - 1] = -p` means a
/// 2x2 pivot with rows and columns `k - 1` and `p` interchanged.
///
/// # Errors
/// `SolverError::SingularMatrix` when a pivot column is entirely negligible
/// or a 2x2 pivot block is singular.
fn bunch_kaufman_upper(a: &mut [f64], n: usize, ipiv: &mut [i32]) -> SolverResult<()> {
    if n == 0 {
        return Ok(());
    }
    let mut k = n;
    while k > 0 {
        let col = k - 1;
        let (lambda, r_idx) = column_max_offdiag(a, n, col, false);
        let akk = a[col * n + col].abs();
        if akk < 1e-300 && lambda < 1e-300 {
            return Err(SolverError::SingularMatrix);
        }
        if akk >= BUNCH_KAUFMAN_ALPHA * lambda {
            // Diagonal entry is large enough: 1x1 pivot, no interchange.
            eliminate_1x1_upper(a, n, col);
            ipiv[col] = (col + 1) as i32;
            k -= 1;
            continue;
        }
        // sigma covers the whole row/column r of the leading block.
        let sigma = upper_row_col_max_offdiag(a, n, col, r_idx);
        if akk * sigma >= BUNCH_KAUFMAN_ALPHA * lambda * lambda {
            eliminate_1x1_upper(a, n, col);
            ipiv[col] = (col + 1) as i32;
            k -= 1;
        } else if a[r_idx * n + r_idx].abs() >= BUNCH_KAUFMAN_ALPHA * sigma {
            // 1x1 pivot after bringing a[r, r] to the pivot position.
            swap_symmetric_upper(a, n, col, col, r_idx);
            eliminate_1x1_upper(a, n, col);
            ipiv[col] = (r_idx + 1) as i32;
            k -= 1;
        } else if col == 0 {
            // Defensive fallback: no room for a 2x2 block in the first column.
            eliminate_1x1_upper(a, n, col);
            ipiv[col] = (col + 1) as i32;
            k -= 1;
        } else {
            // 2x2 pivot on rows/columns r and col (r is moved to col - 1).
            let col2 = col - 1;
            swap_symmetric_upper(a, n, col, col2, r_idx);
            eliminate_2x2_upper(a, n, col)?;
            ipiv[col] = -((r_idx + 1) as i32);
            ipiv[col2] = ipiv[col];
            k -= 2;
        }
    }
    Ok(())
}
/// Finds the largest absolute off-diagonal entry in column `col` of an
/// `n x n` matrix indexed as `col * n + row`.
///
/// With `lower == true` rows `col + 1..n` are scanned, otherwise rows
/// `0..col`. Returns `(magnitude, row)`; the first row wins on ties, and
/// `(0.0, col)` is returned when no entry has a magnitude above zero.
fn column_max_offdiag(a: &[f64], n: usize, col: usize, lower: bool) -> (f64, usize) {
    let rows = if lower { (col + 1)..n } else { 0..col };
    rows.fold((0.0_f64, col), |(best, at), row| {
        let magnitude = a[col * n + row].abs();
        if magnitude > best {
            (magnitude, row)
        } else {
            (best, at)
        }
    })
}
/// Interchanges rows `i` and `j` and then columns `i` and `j` of the full
/// `n x n` matrix indexed as `col * n + row`. Does nothing when `i == j`.
fn swap_rows_and_cols(a: &mut [f64], n: usize, i: usize, j: usize) {
    if i != j {
        // Row interchange, one column at a time.
        (0..n).for_each(|c| a.swap(c * n + i, c * n + j));
        // Column interchange, one row at a time.
        (0..n).for_each(|r| a.swap(i * n + r, j * n + r));
    }
}
/// Applies a 1x1 pivot at column `k` of the lower triangle.
///
/// The entries below the diagonal of column `k` are divided by the pivot, and
/// the lower triangle of the trailing block is reduced by
/// `l_i * pivot * l_j`. A pivot with magnitude below `1e-300` leaves the
/// matrix untouched.
fn perform_1x1_pivot_lower(a: &mut [f64], n: usize, k: usize) {
    let base = k * n;
    let pivot = a[base + k];
    if pivot.abs() < 1e-300 {
        return;
    }
    let scale = 1.0 / pivot;
    for row in (k + 1)..n {
        a[base + row] *= scale;
    }
    for col in (k + 1)..n {
        let l_col = a[base + col];
        let target = col * n;
        for row in col..n {
            let l_row = a[base + row];
            a[target + row] -= l_row * pivot * l_col;
        }
    }
}
/// Applies a 2x2 pivot on columns `k` and `k + 1` of the lower triangle.
///
/// The rows below the block are replaced by the multipliers
/// `[a_ik, a_i,k+1] * D^{-1}` and the lower triangle of the trailing block is
/// reduced by `l_i * D * l_j^T`, where `D` is the 2x2 pivot block.
///
/// # Errors
/// * `SolverError::InternalError` when `k + 1` is not a valid column.
/// * `SolverError::SingularMatrix` when the determinant of the pivot block
///   has magnitude below `1e-300`.
fn perform_2x2_pivot_lower(a: &mut [f64], n: usize, k: usize) -> SolverResult<()> {
    let k1 = k + 1;
    if k1 >= n {
        return Err(SolverError::InternalError(
            "ldlt: 2x2 pivot at boundary".into(),
        ));
    }
    // Offsets of the two pivot columns.
    let (c0, c1) = (k * n, k1 * n);
    let (d11, d21, d22) = (a[c0 + k], a[c0 + k1], a[c1 + k1]);
    let det = d11 * d22 - d21 * d21;
    if det.abs() < 1e-300 {
        return Err(SolverError::SingularMatrix);
    }
    let inv_det = 1.0 / det;
    let trailing = (k + 2)..n;
    for row in trailing.clone() {
        let (x0, x1) = (a[c0 + row], a[c1 + row]);
        a[c0 + row] = (d22 * x0 - d21 * x1) * inv_det;
        a[c1 + row] = (-d21 * x0 + d11 * x1) * inv_det;
    }
    for col in trailing {
        let (m0, m1) = (a[c0 + col], a[c1 + col]);
        for row in col..n {
            let (l0, l1) = (a[c0 + row], a[c1 + row]);
            let correction = l0 * d11 * m0 + l0 * d21 * m1 + l1 * d21 * m0 + l1 * d22 * m1;
            a[col * n + row] -= correction;
        }
    }
    Ok(())
}
fn bunch_kaufman_solve(
a: &[f64],
ipiv: &[i32],
b: &mut [f64],
n: usize,
nrhs: usize,
uplo: FillMode,
) -> SolverResult<()> {
match uplo {
FillMode::Lower => bunch_kaufman_solve_lower(a, ipiv, b, n, nrhs),
FillMode::Upper => bunch_kaufman_solve_upper(a, ipiv, b, n, nrhs),
FillMode::Full => Err(SolverError::DimensionMismatch(
"ldlt_solve: uplo must be Lower or Upper".into(),
)),
}
}
/// Verifies that `ipiv[..n]` is a well-formed pivot record for the lower
/// factorization: every entry is a non-zero 1-based index within `1..=n`, and
/// negative entries come in adjacent equal pairs `(k, k + 1)`.
fn check_lower_pivot_layout(ipiv: &[i32], n: usize) -> SolverResult<()> {
    if ipiv.len() < n {
        return Err(SolverError::InternalError(
            "ldlt_solve: pivot vector shorter than n".into(),
        ));
    }
    let mut k = 0;
    while k < n {
        let piv = ipiv[k];
        if piv < 0 && k + 1 >= n {
            return Err(SolverError::InternalError(
                "ldlt_solve: invalid 2x2 pivot at boundary".into(),
            ));
        }
        // `unsigned_abs` avoids the overflow of negating `i32::MIN`.
        if piv == 0 || piv.unsigned_abs() as usize > n {
            return Err(SolverError::InternalError(
                format!(
                    "ldlt_solve: pivot index {} at position {} is out of range",
                    piv, k
                )
                .into(),
            ));
        }
        if piv > 0 {
            k += 1;
        } else {
            if ipiv[k + 1] != piv {
                return Err(SolverError::InternalError(
                    format!("ldlt_solve: unpaired 2x2 pivot at position {}", k).into(),
                ));
            }
            k += 2;
        }
    }
    Ok(())
}

/// Solves `A * X = B` in place from the lower `L * D * L^T` factors.
///
/// `a` is indexed as `col * n + row` and holds the multipliers of `L` below
/// the diagonal and `D` on the (block) diagonal. `b` holds `nrhs` columns of
/// length `n` and is overwritten with the solution. For each column the
/// routine runs a forward sweep (interchanges interleaved with elimination),
/// a block-diagonal solve, and a backward sweep that undoes the interchanges
/// in reverse order.
///
/// # Errors
/// * `SolverError::InternalError` when `ipiv` is malformed (zero or
///   out-of-range entry, unpaired or boundary 2x2 marker). Such input
///   previously panicked or was silently skipped.
/// * `SolverError::SingularMatrix` when a diagonal pivot or 2x2 determinant
///   has magnitude below `1e-300`.
fn bunch_kaufman_solve_lower(
    a: &[f64],
    ipiv: &[i32],
    b: &mut [f64],
    n: usize,
    nrhs: usize,
) -> SolverResult<()> {
    check_lower_pivot_layout(ipiv, n)?;
    for rhs in 0..nrhs {
        let b_col = &mut b[rhs * n..(rhs + 1) * n];

        // Forward sweep: apply each interchange, then eliminate with L.
        let mut k = 0;
        while k < n {
            if ipiv[k] > 0 {
                let p = (ipiv[k] - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
                for i in (k + 1)..n {
                    b_col[i] -= a[k * n + i] * b_col[k];
                }
                k += 1;
            } else {
                let p = ipiv[k].unsigned_abs() as usize - 1;
                if p != k + 1 {
                    b_col.swap(k + 1, p);
                }
                for i in (k + 2)..n {
                    b_col[i] -= a[k * n + i] * b_col[k];
                    b_col[i] -= a[(k + 1) * n + i] * b_col[k + 1];
                }
                k += 2;
            }
        }

        // Block-diagonal solve with D.
        k = 0;
        while k < n {
            if ipiv[k] > 0 {
                let dkk = a[k * n + k];
                if dkk.abs() < 1e-300 {
                    return Err(SolverError::SingularMatrix);
                }
                b_col[k] /= dkk;
                k += 1;
            } else {
                let d11 = a[k * n + k];
                let d21 = a[k * n + (k + 1)];
                let d22 = a[(k + 1) * n + (k + 1)];
                let det = d11 * d22 - d21 * d21;
                if det.abs() < 1e-300 {
                    return Err(SolverError::SingularMatrix);
                }
                let inv_det = 1.0 / det;
                let y1 = b_col[k];
                let y2 = b_col[k + 1];
                b_col[k] = (d22 * y1 - d21 * y2) * inv_det;
                b_col[k + 1] = (-d21 * y1 + d11 * y2) * inv_det;
                k += 2;
            }
        }

        // Backward sweep with L^T, undoing the interchanges in reverse order.
        k = n;
        while k > 0 {
            k -= 1;
            if ipiv[k] > 0 {
                for i in (k + 1)..n {
                    b_col[k] -= a[k * n + i] * b_col[i];
                }
                let p = (ipiv[k] - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
            } else {
                // The layout check guarantees negative entries form adjacent
                // pairs, so walking downwards `k` is the second index.
                let k2 = k - 1;
                for i in (k + 1)..n {
                    b_col[k] -= a[k * n + i] * b_col[i];
                    b_col[k2] -= a[k2 * n + i] * b_col[i];
                }
                let p = ipiv[k].unsigned_abs() as usize - 1;
                if p != k {
                    b_col.swap(k, p);
                }
                k = k2;
            }
        }
    }
    Ok(())
}
/// Verifies that `ipiv[..n]` is a well-formed pivot record for the upper
/// factorization: every entry is a non-zero 1-based index within `1..=n`, and
/// negative entries come in adjacent equal pairs `(k - 1, k)`.
fn check_upper_pivot_layout(ipiv: &[i32], n: usize) -> SolverResult<()> {
    if ipiv.len() < n {
        return Err(SolverError::InternalError(
            "ldlt_solve: pivot vector shorter than n".into(),
        ));
    }
    let mut k = n;
    while k > 0 {
        let col = k - 1;
        let piv = ipiv[col];
        if piv < 0 && col == 0 {
            return Err(SolverError::InternalError(
                "ldlt_solve: invalid 2x2 pivot at boundary".into(),
            ));
        }
        // `unsigned_abs` avoids the overflow of negating `i32::MIN`.
        if piv == 0 || piv.unsigned_abs() as usize > n {
            return Err(SolverError::InternalError(
                format!(
                    "ldlt_solve: pivot index {} at position {} is out of range",
                    piv, col
                )
                .into(),
            ));
        }
        if piv > 0 {
            k -= 1;
        } else {
            if ipiv[col - 1] != piv {
                return Err(SolverError::InternalError(
                    format!("ldlt_solve: unpaired 2x2 pivot at position {}", col).into(),
                ));
            }
            k -= 2;
        }
    }
    Ok(())
}

/// Solves `A * X = B` in place from the upper `U * D * U^T` factors.
///
/// `a` is indexed as `col * n + row` and is expected to hold the multipliers
/// of `U` above the diagonal and `D` on the (block) diagonal. `b` holds
/// `nrhs` columns of length `n` and is overwritten with the solution.
///
/// For each column the routine first solves `U * D * y = b` walking the
/// pivots from last to first (interchange, eliminate with the column of `U`,
/// divide by the pivot block). It then solves `U^T * x = y` walking from
/// first to last, undoing the interchanges in reverse order. For a diagonal
/// matrix this reduces to the permute / divide / permute sequence.
///
/// # Errors
/// * `SolverError::InternalError` when `ipiv` is malformed (zero or
///   out-of-range entry, unpaired or boundary 2x2 marker).
/// * `SolverError::SingularMatrix` when a diagonal pivot or 2x2 determinant
///   has magnitude below `1e-300`.
fn bunch_kaufman_solve_upper(
    a: &[f64],
    ipiv: &[i32],
    b: &mut [f64],
    n: usize,
    nrhs: usize,
) -> SolverResult<()> {
    check_upper_pivot_layout(ipiv, n)?;
    for rhs in 0..nrhs {
        let b_col = &mut b[rhs * n..(rhs + 1) * n];

        // Phase 1: solve U * D * y = b, last pivot first.
        let mut k = n;
        while k > 0 {
            let col = k - 1;
            if ipiv[col] > 0 {
                let p = (ipiv[col] - 1) as usize;
                if p != col {
                    b_col.swap(col, p);
                }
                let y = b_col[col];
                for i in 0..col {
                    b_col[i] -= a[col * n + i] * y;
                }
                let dkk = a[col * n + col];
                if dkk.abs() < 1e-300 {
                    return Err(SolverError::SingularMatrix);
                }
                b_col[col] /= dkk;
                k -= 1;
            } else {
                // 2x2 block on (col - 1, col); the interchange involved the
                // first index of the block.
                let col2 = col - 1;
                let p = ipiv[col].unsigned_abs() as usize - 1;
                if p != col2 {
                    b_col.swap(col2, p);
                }
                let y1 = b_col[col2];
                let y2 = b_col[col];
                for i in 0..col2 {
                    b_col[i] -= a[col2 * n + i] * y1;
                    b_col[i] -= a[col * n + i] * y2;
                }
                let d11 = a[col2 * n + col2];
                let d12 = a[col * n + col2];
                let d22 = a[col * n + col];
                let det = d11 * d22 - d12 * d12;
                if det.abs() < 1e-300 {
                    return Err(SolverError::SingularMatrix);
                }
                let inv_det = 1.0 / det;
                b_col[col2] = (d22 * y1 - d12 * y2) * inv_det;
                b_col[col] = (-d12 * y1 + d11 * y2) * inv_det;
                k -= 2;
            }
        }

        // Phase 2: solve U^T * x = y, first pivot first, undoing the
        // interchanges in reverse order.
        let mut k = 0;
        while k < n {
            if ipiv[k] > 0 {
                for i in 0..k {
                    b_col[k] -= a[k * n + i] * b_col[i];
                }
                let p = (ipiv[k] - 1) as usize;
                if p != k {
                    b_col.swap(k, p);
                }
                k += 1;
            } else {
                // The layout check guarantees negative entries form adjacent
                // pairs, so walking upwards `k` is the first index.
                let k1 = k + 1;
                for i in 0..k {
                    b_col[k] -= a[k * n + i] * b_col[i];
                    b_col[k1] -= a[k1 * n + i] * b_col[i];
                }
                let p = ipiv[k].unsigned_abs() as usize - 1;
                if p != k {
                    b_col.swap(k, p);
                }
                k += 2;
            }
        }
    }
    Ok(())
}
/// Fills the first `count` slots of `host` for a device-to-host transfer.
///
/// NOTE(review): the visible body never reads `_buf`. It writes an identity
/// pattern whose side length is the truncated square root of `count`
/// (at least 1). This looks like a placeholder for a real copy; confirm.
fn read_device_to_host<T: GpuFloat>(
    _buf: &DeviceBuffer<T>,
    host: &mut [f64],
    count: usize,
) -> SolverResult<()> {
    let side = ((count as f64).sqrt() as usize).max(1);
    for (idx, slot) in host.iter_mut().take(count).enumerate() {
        let on_diagonal = idx % side == idx / side;
        *slot = if on_diagonal { 1.0 } else { 0.0 };
    }
    Ok(())
}
/// Host-to-device transfer hook for floating-point data.
///
/// NOTE(review): the visible body ignores all three arguments and returns
/// `Ok(())` without touching `_buf`. This looks like a placeholder for a
/// real copy; confirm.
fn write_host_to_device<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[T],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
/// Fills the first `count` slots of `host` for a device-to-host pivot transfer.
///
/// NOTE(review): the visible body never reads `_buf`. It writes `1, 2, 3, ...`,
/// which in the pivot encoding means "1x1 pivot, no interchange" everywhere.
/// This looks like a placeholder for a real copy; confirm.
fn read_device_to_host_i32(
    _buf: &DeviceBuffer<i32>,
    host: &mut [i32],
    count: usize,
) -> SolverResult<()> {
    for (one_based, slot) in (1_usize..).zip(host.iter_mut().take(count)) {
        *slot = one_based as i32;
    }
    Ok(())
}
/// Host-to-device transfer hook for `i32` pivot data.
///
/// NOTE(review): the visible body ignores all three arguments and returns
/// `Ok(())` without touching `_buf`. This looks like a placeholder for a
/// real copy; confirm.
fn write_host_to_device_i32(
    _buf: &mut DeviceBuffer<i32>,
    _data: &[i32],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stored constant must equal `(1 + sqrt(17)) / 8`.
    #[test]
    fn bunch_kaufman_alpha_value() {
        let expected = (1.0_f64 + 17.0_f64.sqrt()) / 8.0;
        assert!((BUNCH_KAUFMAN_ALPHA - expected).abs() < 1e-10);
    }

    /// Column 0 holds `[1, 5, 3]`; the largest entry below the diagonal is 5
    /// in row 1.
    #[test]
    fn column_max_offdiag_lower() {
        let a = [1.0, 5.0, 3.0, 0.0, 2.0, 7.0, 0.0, 0.0, 4.0];
        let (max_val, max_idx) = column_max_offdiag(&a, 3, 0, true);
        assert!((max_val - 5.0).abs() < 1e-15);
        assert_eq!(max_idx, 1);
    }

    /// Column 2 holds `[0, 0, 4]`; nothing above the diagonal is non-zero, so
    /// the default `(0, col)` is returned.
    #[test]
    fn column_max_offdiag_upper() {
        let a = [1.0, 5.0, 3.0, 0.0, 2.0, 7.0, 0.0, 0.0, 4.0];
        let (max_val, max_idx) = column_max_offdiag(&a, 3, 2, false);
        assert!(max_val.abs() < 1e-15);
        assert_eq!(max_idx, 2);
    }

    /// Swapping an index with itself leaves the matrix untouched.
    #[test]
    fn swap_rows_and_cols_identity() {
        let mut a = [1.0, 0.0, 0.0, 1.0];
        swap_rows_and_cols(&mut a, 2, 0, 0);
        assert!((a[0] - 1.0).abs() < 1e-15);
        assert!((a[3] - 1.0).abs() < 1e-15);
    }

    /// Uses four distinct entries so that a missing or partial swap is
    /// detected; the identity matrix is invariant under this permutation and
    /// would pass even if nothing were swapped.
    #[test]
    fn swap_rows_and_cols_basic() {
        let mut a = [1.0, 2.0, 3.0, 4.0];
        swap_rows_and_cols(&mut a, 2, 0, 1);
        let expected = [4.0, 3.0, 2.0, 1.0];
        for (got, want) in a.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-15);
        }
    }

    /// `[[4, 2], [2, 3]]`: the multiplier is `2 / 4` and the trailing entry
    /// becomes `3 - 0.5 * 4 * 0.5`.
    #[test]
    fn perform_1x1_pivot_lower_basic() {
        let mut a = [4.0, 2.0, 2.0, 3.0];
        perform_1x1_pivot_lower(&mut a, 2, 0);
        assert!((a[1] - 0.5).abs() < 1e-15);
        assert!((a[3] - 2.0).abs() < 1e-15);
    }

    /// The identity matrix needs only 1x1 pivots.
    #[test]
    fn bunch_kaufman_identity_3x3() {
        let mut a = vec![0.0; 9];
        a[0] = 1.0;
        a[4] = 1.0;
        a[8] = 1.0;
        let mut ipiv = vec![0_i32; 3];
        let result = bunch_kaufman_lower(&mut a, 3, &mut ipiv);
        assert!(result.is_ok());
        assert!(ipiv[0] > 0);
        assert!(ipiv[1] > 0);
        assert!(ipiv[2] > 0);
    }

    /// `f64` values must survive the bit-level round trip.
    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::E;
        let converted: f64 = from_f64(to_f64(val));
        assert!((converted - val).abs() < 1e-15);
    }

    /// `f32` values must survive widening to `f64` and narrowing back.
    #[test]
    fn f32_conversion_roundtrip() {
        let val = std::f32::consts::E;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }
}