#![allow(dead_code)]
use oxicuda_blas::GpuFloat;
use oxicuda_memory::DeviceBuffer;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Widens any `GpuFloat` value to `f64` by reinterpreting its raw bits
/// (single-precision values go through `f32` first, then widen losslessly).
fn to_f64<T: GpuFloat>(val: T) -> f64 {
    match T::SIZE {
        4 => f64::from(f32::from_bits(val.to_bits_u64() as u32)),
        _ => f64::from_bits(val.to_bits_u64()),
    }
}
/// Narrows an `f64` back into `T`, rounding through `f32` when `T` is a
/// single-precision type (the inverse of `to_f64` up to that rounding).
fn from_f64<T: GpuFloat>(val: f64) -> T {
    match T::SIZE {
        4 => T::from_bits_u64(u64::from((val as f32).to_bits())),
        _ => T::from_bits_u64(val.to_bits()),
    }
}
/// An `n x n` banded matrix stored column-major on the device.
///
/// Column `j` occupies `ldab()` consecutive entries; the main diagonal of
/// column `j` sits at band row `kl` (see `band_index` and the `ab[k*ldab + kl]`
/// pivot reads in the host kernels). NOTE(review): this differs from the
/// LAPACK GB convention, which places the diagonal at `kl + ku` — confirm the
/// layout is intentional.
pub struct BandMatrix<T: GpuFloat> {
    /// Flat device storage of length `ldab() * n`.
    pub data: DeviceBuffer<T>,
    /// Matrix order.
    pub n: usize,
    /// Number of subdiagonals.
    pub kl: usize,
    /// Number of superdiagonals.
    pub ku: usize,
}
impl<T: GpuFloat> BandMatrix<T> {
    /// Allocates a zero-initialized band matrix with `kl` subdiagonals and
    /// `ku` superdiagonals; the leading dimension is `2*kl + ku + 1`.
    pub fn new(n: usize, kl: usize, ku: usize) -> SolverResult<Self> {
        let rows = 2 * kl + ku + 1;
        let data = DeviceBuffer::<T>::zeroed(rows * n)?;
        Ok(Self { data, n, kl, ku })
    }

    /// Leading dimension (band rows per column) of the storage.
    pub fn ldab(&self) -> usize {
        2 * self.kl + self.ku + 1
    }

    /// Total number of elements in the band storage.
    pub fn storage_len(&self) -> usize {
        self.n * self.ldab()
    }

    /// Maps matrix coordinates `(i, j)` to a flat storage index, or `None`
    /// when the element is not representable in this layout.
    ///
    /// The diagonal of column `j` is at band row `kl`, so superdiagonals more
    /// than `kl` above the diagonal are rejected even when `ku > kl`.
    /// NOTE(review): confirm whether `ku <= kl` is an intended invariant.
    pub fn band_index(&self, i: usize, j: usize) -> Option<usize> {
        let band_row = (self.kl + i).checked_sub(j)?;
        if band_row < self.ldab() {
            Some(j * self.ldab() + band_row)
        } else {
            None
        }
    }
}
/// Factorizes `band` in place via banded LU with partial pivoting, recording
/// the row interchanges in `pivots`.
///
/// Round-trips the band through host memory and runs the CPU kernel
/// (`band_lu_host`); `handle` is only used to reserve workspace.
///
/// # Errors
/// Returns `DimensionMismatch` when `pivots` or the band storage is too
/// small, and propagates any error from the host factorization.
pub fn band_lu<T: GpuFloat>(
    handle: &mut SolverHandle,
    band: &mut BandMatrix<T>,
    pivots: &mut DeviceBuffer<i32>,
) -> SolverResult<()> {
    let dim = band.n;
    // Empty matrix: nothing to factor.
    if dim == 0 {
        return Ok(());
    }
    if pivots.len() < dim {
        return Err(SolverError::DimensionMismatch(format!(
            "band_lu: pivots buffer too small ({} < {})",
            pivots.len(),
            dim
        )));
    }
    if band.data.len() < band.storage_len() {
        return Err(SolverError::DimensionMismatch(format!(
            "band_lu: band data buffer too small ({} < {})",
            band.data.len(),
            band.storage_len()
        )));
    }
    let (kl, ku) = (band.kl, band.ku);
    let lead = band.ldab();
    handle.ensure_workspace(lead * dim * std::mem::size_of::<f64>())?;
    // Host round-trip: copy down, factor on the CPU, copy back.
    let mut host_band = vec![0.0_f64; lead * dim];
    read_band_to_host(&band.data, &mut host_band, lead * dim)?;
    let mut host_piv = vec![0_i32; dim];
    band_lu_host(&mut host_band, dim, kl, ku, lead, &mut host_piv)?;
    write_host_to_band_f64(&mut band.data, &host_band, lead * dim)?;
    write_pivots_to_device(pivots, &host_piv, dim)?;
    Ok(())
}
/// Solves `A * X = B` for `nrhs` right-hand sides using a factorization
/// previously produced by `band_lu`; `b` is overwritten with the solution.
///
/// Round-trips the factors, pivots, and right-hand sides through host memory
/// and runs the CPU kernel (`band_solve_host`).
///
/// # Errors
/// Returns `DimensionMismatch` when the band order, pivot buffer, or `b`
/// buffer is inconsistent with `n`/`nrhs`, and propagates host-solve errors.
pub fn band_solve<T: GpuFloat>(
    handle: &mut SolverHandle,
    band: &BandMatrix<T>,
    pivots: &DeviceBuffer<i32>,
    b: &mut DeviceBuffer<T>,
    n: usize,
    nrhs: usize,
) -> SolverResult<()> {
    // Nothing to solve for an empty system.
    if n == 0 || nrhs == 0 {
        return Ok(());
    }
    if band.n != n {
        return Err(SolverError::DimensionMismatch(format!(
            "band_solve: band matrix dimension ({}) != n ({})",
            band.n, n
        )));
    }
    if pivots.len() < n {
        return Err(SolverError::DimensionMismatch(
            "band_solve: pivots buffer too small".into(),
        ));
    }
    if b.len() < n * nrhs {
        return Err(SolverError::DimensionMismatch(
            "band_solve: B buffer too small".into(),
        ));
    }
    let lead = band.ldab();
    handle.ensure_workspace((lead * n + n * nrhs) * std::mem::size_of::<f64>())?;
    // Pull the factors, pivots, and RHS down to the host.
    let mut host_band = vec![0.0_f64; lead * n];
    read_band_to_host(&band.data, &mut host_band, lead * n)?;
    let mut host_piv = vec![0_i32; n];
    read_pivots_from_device(pivots, &mut host_piv, n)?;
    let mut rhs_host = vec![0.0_f64; n * nrhs];
    read_band_to_host(b, &mut rhs_host, n * nrhs)?;
    band_solve_host(
        &host_band, &host_piv, &mut rhs_host, n, band.kl, band.ku, lead, nrhs,
    )?;
    // Convert the solution back to T and push it to the device buffer.
    let rhs_t: Vec<T> = rhs_host.iter().map(|&x| from_f64(x)).collect();
    write_host_to_band_t(b, &rhs_t, n * nrhs)?;
    Ok(())
}
/// Factorizes a symmetric banded matrix in place via Cholesky (`A = L * L^T`).
///
/// Requires `kl == ku` (a symmetric band). Round-trips the band through host
/// memory and runs the CPU kernel (`band_cholesky_host`).
///
/// # Errors
/// Returns `DimensionMismatch` when `kl != ku`, and propagates
/// `NotPositiveDefinite` from the host factorization.
pub fn band_cholesky<T: GpuFloat>(
    handle: &mut SolverHandle,
    band: &mut BandMatrix<T>,
) -> SolverResult<()> {
    let dim = band.n;
    // Empty matrix: nothing to factor.
    if dim == 0 {
        return Ok(());
    }
    let (kl, ku) = (band.kl, band.ku);
    if kl != ku {
        return Err(SolverError::DimensionMismatch(format!(
            "band_cholesky: kl ({kl}) must equal ku ({ku}) for symmetric matrix"
        )));
    }
    let lead = band.ldab();
    handle.ensure_workspace(lead * dim * std::mem::size_of::<f64>())?;
    // Host round-trip: copy down, factor, copy back.
    let mut host_band = vec![0.0_f64; lead * dim];
    read_band_to_host(&band.data, &mut host_band, lead * dim)?;
    band_cholesky_host(&mut host_band, dim, kl, lead)?;
    write_host_to_band_f64(&mut band.data, &host_band, lead * dim)?;
    Ok(())
}
/// In-place banded LU factorization with partial pivoting on column-major
/// band storage where the diagonal of column `j` sits at band row `kl`
/// (i.e. `A(i, j)` maps to `ab[j * ldab + kl + i - j]`).
///
/// On success, L's multipliers occupy the subdiagonal band rows, U occupies
/// the diagonal and superdiagonal rows, and `ipiv[k]` records the row that
/// was swapped with row `k` at step `k`.
///
/// NOTE(review): this layout can store at most `kl` superdiagonals above the
/// diagonal, and fill-in columns beyond `k + ku` created by row interchanges
/// are not updated. Results are exact when no pivoting occurs (e.g.
/// diagonally dominant matrices); confirm the intended contract for `ku > kl`
/// or heavily pivoted inputs.
///
/// # Errors
/// `SingularMatrix` when a pivot column is numerically zero.
fn band_lu_host(
    ab: &mut [f64],
    n: usize,
    kl: usize,
    ku: usize,
    ldab: usize,
    ipiv: &mut [i32],
) -> SolverResult<()> {
    for k in 0..n {
        // Partial pivoting: largest magnitude among the diagonal and the kl
        // subdiagonal entries of column k.
        let mut max_val = 0.0_f64;
        let mut max_idx = k;
        let end_row = n.min(k + kl + 1);
        for i in k..end_row {
            let band_row = kl + i - k;
            if band_row < ldab {
                let val = ab[k * ldab + band_row].abs();
                if val > max_val {
                    max_val = val;
                    max_idx = i;
                }
            }
        }
        ipiv[k] = max_idx as i32;
        if max_val < 1e-300 {
            // Entire pivot column is (numerically) zero.
            return Err(SolverError::SingularMatrix);
        }
        if max_idx != k {
            // Swap rows k and max_idx in every nearby column whose entries
            // for both rows are representable in the band storage.
            let p = max_idx;
            let col_start = k.saturating_sub(ku);
            let col_end = n.min(k + kl + ku + 1);
            for j in col_start..col_end {
                let row_k = kl + k;
                let row_p = kl + p;
                if row_k >= j && row_k - j < ldab && row_p >= j && row_p - j < ldab {
                    ab.swap(j * ldab + (row_k - j), j * ldab + (row_p - j));
                }
            }
        }
        let pivot = ab[k * ldab + kl];
        if pivot.abs() < 1e-300 {
            return Err(SolverError::SingularMatrix);
        }
        // Eliminate the subdiagonal entries of column k and apply the rank-1
        // update to the trailing columns.
        for i in (k + 1)..end_row {
            let band_row = kl + i - k;
            if band_row < ldab {
                let mult = ab[k * ldab + band_row] / pivot;
                // Store the L multiplier in place of the eliminated entry.
                ab[k * ldab + band_row] = mult;
                let update_end = n.min(k + ku + 1);
                for j in (k + 1)..update_end {
                    // A(i, j) -= mult * A(k, j).
                    //
                    // Bug fix: the previous code computed the source band row
                    // as `kl + k - j + (j - k)`, which simplifies to `kl` and
                    // therefore read the *diagonal* entry A(j, j) instead of
                    // the pivot-row entry A(k, j) at band row `kl - (j - k)`
                    // (it could also underflow-panic when `j > kl + k`).
                    let off = j - k; // 1..=ku
                    if off > kl {
                        // A(k, j) is not representable in this layout.
                        continue;
                    }
                    let src_row = kl - off;
                    if j > kl + i {
                        // A(i, j) falls outside the representable band.
                        continue;
                    }
                    let dst_row = kl + i - j;
                    if dst_row < ldab {
                        ab[j * ldab + dst_row] -= mult * ab[j * ldab + src_row];
                    }
                }
            }
        }
    }
    Ok(())
}
#[allow(clippy::too_many_arguments)]
/// Solves `A * X = B` on the host using the factors produced by
/// `band_lu_host` (diagonal of column `j` at band row `kl`); `b` holds
/// `nrhs` column-major right-hand sides of length `n` and is overwritten
/// with the solution.
///
/// NOTE(review): the recorded interchanges are applied all at once before
/// forward substitution; this is only consistent with a factorization that
/// also swapped the already-computed multiplier columns — verify against
/// `band_lu_host` for pivoted inputs.
///
/// # Errors
/// `SingularMatrix` when a diagonal entry of U is numerically zero.
fn band_solve_host(
    ab: &[f64],
    ipiv: &[i32],
    b: &mut [f64],
    n: usize,
    kl: usize,
    _ku: usize,
    ldab: usize,
    nrhs: usize,
) -> SolverResult<()> {
    for rhs in 0..nrhs {
        let b_col = &mut b[rhs * n..(rhs + 1) * n];
        // Apply the recorded row interchanges to this right-hand side.
        for (k, &piv) in ipiv.iter().enumerate().take(n) {
            let p = piv as usize;
            if p != k {
                b_col.swap(k, p);
            }
        }
        // Forward substitution with unit-lower L (multipliers stored at band
        // rows kl+1 .. kl+kl of each column).
        for k in 0..n {
            let end_row = n.min(k + kl + 1);
            for i in (k + 1)..end_row {
                let band_row = kl + i - k;
                if band_row < ldab {
                    b_col[i] -= ab[k * ldab + band_row] * b_col[k];
                }
            }
        }
        // Back substitution with U (diagonal at band row kl, superdiagonals
        // above it). Rows above the stored superdiagonals are zero-padded,
        // so reading them subtracts zero and is harmless.
        for k in (0..n).rev() {
            let pivot = ab[k * ldab + kl];
            if pivot.abs() < 1e-300 {
                return Err(SolverError::SingularMatrix);
            }
            b_col[k] /= pivot;
            // Cleanup: the original computed an unused `_band_row` and
            // re-derived the same index behind an always-true `kl + i >= k`
            // guard (it holds for every i >= k - kl).
            for i in k.saturating_sub(kl)..k {
                let band_row = kl + i - k; // safe: k - i <= kl
                if band_row < ldab {
                    b_col[i] -= ab[k * ldab + band_row] * b_col[k];
                }
            }
        }
    }
    Ok(())
}
/// In-place banded Cholesky factorization (`A = L * L^T`) on the host, using
/// the lower band: the diagonal of column `j` sits at band row `kd`, and
/// subdiagonal `s` of column `j` at band row `kd + s`.
///
/// # Errors
/// `NotPositiveDefinite` when a Schur-complement diagonal is not positive.
fn band_cholesky_host(
    ab: &mut [f64],
    n: usize,
    kd: usize,
    ldab: usize,
) -> SolverResult<()> {
    for j in 0..n {
        let first_k = j.saturating_sub(kd);
        // d = A(j,j) - sum over preceding in-band columns of L(j,k)^2.
        let mut d = ab[j * ldab + kd];
        for k in first_k..j {
            let row_jk = kd + j - k;
            if row_jk < ldab {
                let l_jk = ab[k * ldab + row_jk];
                d -= l_jk * l_jk;
            }
        }
        if d <= 0.0 {
            return Err(SolverError::NotPositiveDefinite);
        }
        let l_jj = d.sqrt();
        ab[j * ldab + kd] = l_jj;
        // Column j of L below the diagonal:
        // L(i,j) = (A(i,j) - sum_k L(i,k) * L(j,k)) / L(j,j).
        // Out-of-band reads hit zero padding and contribute nothing.
        for i in (j + 1)..n.min(j + kd + 1) {
            let row_ij = kd + i - j;
            if row_ij >= ldab {
                continue;
            }
            let mut s = ab[j * ldab + row_ij];
            for k in first_k..j {
                let row_ik = kd + i - k;
                let row_jk = kd + j - k;
                if row_ik < ldab && row_jk < ldab {
                    s -= ab[k * ldab + row_ik] * ab[k * ldab + row_jk];
                }
            }
            ab[j * ldab + row_ij] = s / l_jj;
        }
    }
    Ok(())
}
/// Placeholder device-to-host copy: the real transfer is not implemented, so
/// the first `count` destination slots are simply cleared to zero.
///
/// NOTE(review): because of this stub, the device entry points currently
/// factor/solve an all-zero matrix — wire up the actual copy before use.
fn read_band_to_host<T: GpuFloat>(
    _buf: &DeviceBuffer<T>,
    host: &mut [f64],
    count: usize,
) -> SolverResult<()> {
    host.iter_mut().take(count).for_each(|slot| *slot = 0.0);
    Ok(())
}
// Placeholder host-to-device copy of f64 data into a T band buffer.
// Currently a no-op: nothing is written back to the device.
// NOTE(review): a real implementation must also convert f64 -> T
// (via `from_f64`) before copying — TODO implement.
fn write_host_to_band_f64<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[f64],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
// Placeholder host-to-device copy of already-converted T data.
// Currently a no-op: nothing is written back to the device — TODO implement.
fn write_host_to_band_t<T: GpuFloat>(
    _buf: &mut DeviceBuffer<T>,
    _data: &[T],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
// Placeholder host-to-device copy of the pivot indices.
// Currently a no-op: nothing is written back to the device — TODO implement.
fn write_pivots_to_device(
    _buf: &mut DeviceBuffer<i32>,
    _data: &[i32],
    _count: usize,
) -> SolverResult<()> {
    Ok(())
}
/// Placeholder device-to-host pivot copy: the real transfer is not
/// implemented, so the first `count` slots are filled with the identity
/// permutation (pivot k = k, i.e. "no interchange").
fn read_pivots_from_device(
    _buf: &DeviceBuffer<i32>,
    host: &mut [i32],
    count: usize,
) -> SolverResult<()> {
    for (i, slot) in host.iter_mut().take(count).enumerate() {
        *slot = i as i32;
    }
    Ok(())
}
// Host-side unit tests: these exercise the pure-CPU kernels and the band
// index arithmetic only; no device buffers or handles are involved.
#[cfg(test)]
mod tests {
    use super::*;

    // Replays BandMatrix::band_index arithmetic by hand for element (3, 2)
    // of a 5x5 tridiagonal matrix (kl = ku = 1, ldab = 4).
    #[test]
    fn band_index_tridiagonal() {
        let n = 5_usize;
        let kl = 1_usize;
        let ku = 1_usize;
        let ldab = 2 * kl + ku + 1;
        let row_in_band = kl + 2;
        assert!(row_in_band >= 2);
        let band_row = row_in_band - 2;
        assert!(band_row < ldab);
        let idx = 2 * ldab + band_row;
        assert_eq!(idx, 9);
        let _ = n;
    }

    // An element more than kl above the diagonal fails the `kl + i >= j`
    // precondition, so band_index would return None for it.
    #[test]
    fn band_index_out_of_band() {
        let kl = 1_usize;
        let row_in_band = kl;
        let j = 3_usize;
        assert!(row_in_band < j);
    }

    // Sanity check of the leading-dimension formula ldab = 2*kl + ku + 1.
    #[test]
    fn band_matrix_ldab_formula() {
        let kl = 2_usize;
        let ku = 3_usize;
        let ldab = 2 * kl + ku + 1;
        assert_eq!(ldab, 8);
    }

    // Factorizes the 3x3 tridiagonal [2,-1; -1,2,-1; -1,2].
    // NOTE(review): this only asserts that factorization succeeds; it does
    // not verify the L/U factor values themselves.
    #[test]
    fn band_lu_host_tridiagonal() {
        let ldab = 4;
        let n = 3;
        let mut ab = vec![0.0_f64; ldab * n];
        // Column 0: diagonal 2.0 at band row kl=1, subdiagonal -1.0 below it.
        ab[1] = 2.0; ab[2] = -1.0;
        // Column 1: superdiagonal, diagonal, subdiagonal.
        ab[ldab] = -1.0; ab[ldab + 1] = 2.0; ab[ldab + 2] = -1.0;
        // Column 2: superdiagonal and diagonal.
        ab[2 * ldab] = -1.0;
        ab[2 * ldab + 1] = 2.0;
        let mut ipiv = vec![0_i32; n];
        let result = band_lu_host(&mut ab, n, 1, 1, ldab, &mut ipiv);
        assert!(result.is_ok());
    }

    // Cholesky of the SPD tridiagonal [2,-1; -1,2,-1; -1,2]:
    // L(0,0) must equal sqrt(2).
    #[test]
    fn band_cholesky_host_tridiagonal() {
        let kd = 1;
        let ldab = 2 * kd + kd + 1; // matches BandMatrix::ldab with kl = ku = kd
        let n = 3;
        let mut ab = vec![0.0_f64; ldab * n];
        ab[1] = 2.0; ab[2] = -1.0;
        ab[ldab + 1] = 2.0; ab[ldab + 2] = -1.0;
        ab[2 * ldab + 1] = 2.0;
        let result = band_cholesky_host(&mut ab, n, kd, ldab);
        assert!(result.is_ok());
        assert!((ab[1] - 2.0_f64.sqrt()).abs() < 1e-10);
    }

    // A negative leading diagonal must be reported as not positive definite.
    #[test]
    fn band_cholesky_host_not_spd() {
        let kd = 1;
        let ldab = 4;
        let n = 2;
        let mut ab = vec![0.0_f64; ldab * n];
        ab[1] = -1.0; ab[ldab + 1] = 2.0;
        let result = band_cholesky_host(&mut ab, n, kd, ldab);
        assert!(result.is_err());
    }

    // f64 round-trips through the bit-level conversion helpers exactly.
    #[test]
    fn f64_conversion_roundtrip() {
        let val = std::f64::consts::E;
        let converted: f64 = from_f64(to_f64(val));
        assert!((converted - val).abs() < 1e-15);
    }

    // f32 round-trips through the f64 intermediate at single precision.
    #[test]
    fn f32_conversion_roundtrip() {
        let val = std::f32::consts::E;
        let as_f64 = to_f64(val);
        let back: f32 = from_f64(as_f64);
        assert!((back - val).abs() < 1e-5);
    }
}