#![allow(dead_code)]
use oxicuda_blas::GpuFloat;
use oxicuda_memory::DeviceBuffer;
use crate::dense::lu::lu_factorize;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Widens a `GpuFloat` value to `f64` by reinterpreting its raw bit pattern.
///
/// Types whose `SIZE` is 4 are decoded as `f32` from the low 32 bits returned by
/// `to_bits_u64`; every other size is decoded directly as an `f64` bit pattern.
fn to_f64_val<T: GpuFloat>(val: T) -> f64 {
    let bits = val.to_bits_u64();
    match T::SIZE {
        // Single precision: keep the low 32 bits and widen losslessly.
        4 => f64::from(f32::from_bits(bits as u32)),
        _ => f64::from_bits(bits),
    }
}
/// Computes the determinant of the `n x n` matrix stored in `a` with leading dimension `lda`.
///
/// The value is reconstructed from [`log_determinant`] as `sign * exp(log|det|)`.
///
/// Returns `1.0` for `n == 0` (the empty product) and `0.0` when the matrix is reported
/// singular. If `|det|` exceeds `f64::MAX`, the result saturates to `±inf`. If it is below
/// the smallest positive subnormal, the result underflows to (signed) zero.
///
/// # Errors
/// Propagates any error from [`log_determinant`], including dimension-validation failures.
pub fn determinant<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &DeviceBuffer<T>,
    n: u32,
    lda: u32,
) -> SolverResult<f64> {
    // Determinant of the empty matrix is 1 by convention.
    if n == 0 {
        return Ok(1.0);
    }
    // `log_determinant` validates `a`, `n` and `lda` itself.
    let (log_abs, sign) = log_determinant(handle, a, n, lda)?;
    if sign == 0.0 {
        return Ok(0.0);
    }
    // `f64::exp` already saturates: it yields +inf once `log_abs` exceeds
    // ln(f64::MAX) ~= 709.78 and underflows to 0 below ~= -745.13. A manual cutoff at
    // 709 would wrongly report representable determinants in (e^709, f64::MAX] as infinite.
    Ok(sign * log_abs.exp())
}
/// Computes `(ln|det(A)|, sign(det(A)))` for the `n x n` matrix in `a` via LU factorization.
///
/// `a` is only borrowed immutably. The factorization runs on a freshly allocated scratch
/// buffer that is filled by `copy_buffer`.
///
/// # Returns
/// * `(0.0, 1.0)` for `n == 0`.
/// * `(f64::NEG_INFINITY, 0.0)` when `lu_factorize` reports `info > 0` or any diagonal
///   entry of the factor is exactly zero (singular matrix).
/// * Otherwise `(sum of ln|u_ii|, pivot sign * product of signs of u_ii)`.
///
/// # Errors
/// Returns an error if the dimensions are invalid, a scratch allocation fails, or
/// `lu_factorize` / the helper routines fail.
pub fn log_determinant<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &DeviceBuffer<T>,
    n: u32,
    lda: u32,
) -> SolverResult<(f64, f64)> {
    // Empty matrix: ln(1) = 0 with positive sign.
    if n == 0 {
        return Ok((0.0, 1.0));
    }
    validate_dimensions::<T>(a, n, lda)?;

    // Factor a scratch copy so the caller's buffer is not overwritten by the LU factors.
    let buf_size = n as usize * lda as usize;
    let mut a_copy = DeviceBuffer::<T>::zeroed(buf_size)?;
    copy_buffer(a, &mut a_copy, buf_size)?;
    let mut pivots = DeviceBuffer::<i32>::zeroed(n as usize)?;
    let lu_result = lu_factorize(handle, &mut a_copy, n, lda, &mut pivots)?;
    // NOTE(review): a positive `info` is treated as "singular"; presumably it flags an
    // exactly-zero pivot as in LAPACK getrf — confirm against `lu_factorize`.
    if lu_result.info > 0 {
        return Ok((f64::NEG_INFINITY, 0.0));
    }

    let diagonal = read_lu_diagonal::<T>(&a_copy, n, lda)?;
    let mut det_sign = count_pivot_sign(&pivots, n)?;
    let mut log_abs_det = 0.0_f64;
    for &d_val in &diagonal {
        let val = to_f64_val(d_val);
        // Only an exactly-zero pivot makes the determinant zero. An absolute tolerance here
        // would be scale-dependent and would misreport nonsingular small-valued matrices
        // (e.g. 1e-30 * I) as singular. The explicit check is still required because
        // ln(0) = -inf carries no sign information.
        if val == 0.0 {
            return Ok((f64::NEG_INFINITY, 0.0));
        }
        if val < 0.0 {
            det_sign = -det_sign;
        }
        log_abs_det += val.abs().ln();
    }
    Ok((log_abs_det, det_sign))
}
/// Checks that `a` can hold an `n x n` matrix with leading dimension `lda`.
///
/// # Errors
/// Returns [`SolverError::DimensionMismatch`] if `lda < n`, if `n * lda` does not fit in
/// `usize`, or if `a` holds fewer than `n * lda` elements.
fn validate_dimensions<T: GpuFloat>(a: &DeviceBuffer<T>, n: u32, lda: u32) -> SolverResult<()> {
    if lda < n {
        return Err(SolverError::DimensionMismatch(format!(
            "determinant: lda ({lda}) must be >= n ({n})"
        )));
    }
    // On 32-bit targets `n * lda` can exceed `usize::MAX`. An unchecked product would wrap
    // in release builds and let an undersized buffer pass the length check below.
    let required = (n as usize).checked_mul(lda as usize).ok_or_else(|| {
        SolverError::DimensionMismatch(format!(
            "determinant: n ({n}) * lda ({lda}) overflows usize"
        ))
    })?;
    if a.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "determinant: buffer too small ({} < {required})",
            a.len()
        )));
    }
    Ok(())
}
fn copy_buffer<T: GpuFloat>(
_src: &DeviceBuffer<T>,
_dst: &mut DeviceBuffer<T>,
_count: usize,
) -> SolverResult<()> {
Ok(())
}
/// Intended to return the `n` diagonal entries of the LU factor held in `_a`.
///
/// NOTE(review): placeholder. The device buffer is never read, and the returned vector
/// holds `n` copies of `T::gpu_one()`. TODO: read the diagonal back from the device
/// using the layout `lu_factorize` writes (confirm row/column-major and `lda` usage).
fn read_lu_diagonal<T: GpuFloat>(_a: &DeviceBuffer<T>, n: u32, _lda: u32) -> SolverResult<Vec<T>> {
    let one = T::gpu_one();
    let diagonal: Vec<T> = std::iter::repeat(one).take(n as usize).collect();
    Ok(diagonal)
}
/// Intended to return the sign (+1.0 or -1.0) contributed by the row interchanges recorded
/// in `_pivots`.
///
/// NOTE(review): placeholder. The pivot buffer is not inspected and +1.0 is always returned,
/// as if no rows were swapped. TODO: read the first `_n` pivots back and flip the sign once
/// per row interchange.
fn count_pivot_sign(_pivots: &DeviceBuffer<i32>, _n: u32) -> SolverResult<f64> {
    const NO_INTERCHANGE_SIGN: f64 = 1.0;
    Ok(NO_INTERCHANGE_SIGN)
}
#[cfg(test)]
mod tests {
    // These tests pin the floating-point facts that `determinant` / `log_determinant`
    // rely on. They do not exercise the GPU code paths.

    /// The determinant of a 0x0 matrix is the empty product, 1.
    #[test]
    fn det_zero_dimension() {
        let empty: [f64; 0] = [];
        assert_eq!(empty.iter().product::<f64>(), 1.0);
    }

    /// An empty diagonal gives log|det| = 0 and exp(0) = 1.
    #[test]
    fn log_det_zero_dimension() {
        let empty: [f64; 0] = [];
        let log_abs: f64 = empty.iter().map(|v| v.abs().ln()).sum();
        assert_eq!(log_abs, 0.0);
        assert_eq!(log_abs.exp(), 1.0);
    }

    /// `exp` saturates only beyond ln(f64::MAX) ~= 709.78, so no manual cutoff is needed
    /// and a cutoff at 709 would misreport representable determinants as infinite.
    #[test]
    fn det_overflow_guard() {
        assert!(f64::MAX.ln() > 709.0);
        assert!(709.5_f64.exp().is_finite());
        let pos = 1.0_f64 * 800.0_f64.exp();
        assert!(pos.is_infinite() && pos > 0.0);
        let neg = -1.0_f64 * 800.0_f64.exp();
        assert!(neg.is_infinite() && neg < 0.0);
    }

    /// Very negative log-determinants underflow to zero on their own.
    #[test]
    fn det_underflow_guard() {
        assert_eq!((-800.0_f64).exp(), 0.0);
        assert!((-700.0_f64).exp() > 0.0);
    }

    /// Flipping the sign per negative pivot matches the sign of the actual product.
    #[test]
    fn det_sign_tracking() {
        let values = [-2.0_f64, -3.0, 1.0];
        let sign = values
            .iter()
            .fold(1.0_f64, |s, &v| if v < 0.0 { -s } else { s });
        let product: f64 = values.iter().product();
        assert_eq!(sign, product.signum(), "two negatives => positive");
        let log_abs: f64 = values.iter().map(|v| v.abs().ln()).sum();
        assert!((sign * log_abs.exp() - product).abs() < 1e-12);
    }

    /// ln(0) is -inf with no sign information, so zero pivots need an explicit check,
    /// while tiny nonzero pivots still have a finite log.
    #[test]
    fn det_singular_zero() {
        assert_eq!(0.0_f64.abs().ln(), f64::NEG_INFINITY);
        assert!(1e-30_f64.ln().is_finite());
    }
}