use oxicuda_blas::GpuFloat;
use oxicuda_memory::DeviceBuffer;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Selects which matrix norm the routines in this module compute.
///
/// The variants are consumed by `compute_matrix_norm`, which addresses the
/// matrix element in row `i`, column `j` as `a[j * lda + i]`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormType {
    /// 1-norm: the largest sum of absolute values taken down a column.
    One,
    /// Infinity-norm: the largest sum of absolute values taken along a row.
    Infinity,
}
/// Estimates the condition number of the `n x n` matrix held in `a`.
///
/// The value returned is `norm(A) * ainv_norm_estimate`, where `norm(A)` is
/// the norm selected by `norm_type` as computed by `compute_matrix_norm`.
///
/// **Current limitation:** `ainv_norm_estimate` is a hard-coded placeholder of
/// `1.0`, so the result is presently just the chosen norm of `A` and not a
/// true condition-number estimate.
///
/// # Parameters
///
/// * `handle` - solver handle, forwarded to `compute_matrix_norm`.
/// * `a` - device buffer holding the matrix; element `(i, j)` is read from
///   index `j * lda + i`.
/// * `n` - matrix order; must be non-zero.
/// * `lda` - leading dimension of `a`; must be at least `n`.
/// * `norm_type` - which norm to use.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `n == 0`, when `lda < n`,
/// when `n * lda` does not fit in `usize`, or when `a` holds fewer than
/// `n * lda` elements. Any error from `compute_matrix_norm` is propagated.
#[allow(dead_code)]
pub fn condition_number_estimate<T: GpuFloat>(
    handle: &SolverHandle,
    a: &DeviceBuffer<T>,
    n: u32,
    lda: u32,
    norm_type: NormType,
) -> SolverResult<f64> {
    if n == 0 {
        return Err(SolverError::DimensionMismatch(
            "condition_number_estimate: n must be > 0".into(),
        ));
    }

    // A leading dimension smaller than the matrix order would make the
    // `j * lda + i` indexing in the norm computation run past the end of the
    // host copy and panic; reject it up front as a dimension error.
    if lda < n {
        return Err(SolverError::DimensionMismatch(format!(
            "condition_number_estimate: lda ({lda}) must be >= n ({n})"
        )));
    }

    // Checked so that a narrow `usize` cannot silently wrap and let an
    // undersized buffer slip past the length check below.
    let required = (n as usize)
        .checked_mul(lda as usize)
        .ok_or_else(|| {
            SolverError::DimensionMismatch(format!(
                "condition_number_estimate: n * lda overflows usize ({n} * {lda})"
            ))
        })?;

    if a.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "condition_number_estimate: buffer too small ({} < {})",
            a.len(),
            required
        )));
    }

    let a_norm = compute_matrix_norm::<T>(handle, a, n, lda, norm_type)?;

    // TODO: placeholder. No estimate of the inverse's norm is computed yet,
    // so the product below is simply `a_norm`.
    let ainv_norm_estimate = 1.0_f64;

    Ok(a_norm * ainv_norm_estimate)
}
/// Widens a `GpuFloat` value to `f64` by reinterpreting its raw bit pattern.
///
/// When `T::SIZE` is 8 the bits are taken as an `f64` directly; otherwise the
/// low 32 bits are taken as an `f32` and then widened losslessly.
///
/// NOTE(review): every type whose `SIZE` is not 8 goes through the `f32`
/// path. That is only correct if such types are 4-byte IEEE single precision;
/// confirm against the set of `GpuFloat` implementors.
fn t_to_f64<T: GpuFloat>(val: T) -> f64 {
    let bits = val.to_bits_u64();

    if T::SIZE == 8 {
        return f64::from_bits(bits);
    }

    // Deliberate truncation: only the low 32 bits are used for the f32 path.
    let single = f32::from_bits(bits as u32);
    f64::from(single)
}
/// Computes the 1-norm or infinity-norm of the leading `n x n` block of `a`.
///
/// The first `lda * n` elements are copied to the host and the norm is
/// evaluated there in `f64`. Element `(i, j)` (row `i`, column `j`) is read
/// from index `j * lda + i`.
///
/// * [`NormType::One`] yields the maximum over columns of the sum of absolute
///   values in that column.
/// * [`NormType::Infinity`] yields the maximum over rows of the sum of
///   absolute values in that row.
///
/// If any entry of the block is NaN the result is NaN, rather than the NaN
/// being silently skipped by the maximum. For `n == 0` the result is `0.0`.
///
/// `_handle` is currently unused.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `lda < n`, and
/// [`SolverError::InternalError`] when the device-to-host copy fails.
fn compute_matrix_norm<T: GpuFloat>(
    _handle: &SolverHandle,
    a: &DeviceBuffer<T>,
    n: u32,
    lda: u32,
    norm_type: NormType,
) -> SolverResult<f64> {
    // With `lda < n` the column slices below would extend past the host
    // buffer; report it as a dimension error instead of panicking.
    if lda < n {
        return Err(SolverError::DimensionMismatch(format!(
            "compute_matrix_norm: lda ({lda}) must be >= n ({n})"
        )));
    }

    let n_usize = n as usize;
    let lda_usize = lda as usize;
    let total = lda_usize * n_usize;

    // NOTE(review): the host buffer holds exactly `lda * n` elements while the
    // caller only guarantees `a.len() >= lda * n`; confirm `copy_to_host`
    // accepts a destination shorter than the device buffer.
    let mut host = vec![T::gpu_zero(); total];
    a.copy_to_host(&mut host).map_err(|e| {
        SolverError::InternalError(format!("compute_matrix_norm copy_to_host failed: {e}"))
    })?;

    // The `n` meaningful entries of column `j`; the remaining `lda - n`
    // entries of each column are padding and are ignored.
    let column = |j: usize| &host[j * lda_usize..j * lda_usize + n_usize];

    // `f64::max` returns the non-NaN operand, which would hide NaN sums.
    // Propagate NaN so an invalid matrix cannot produce a finite norm.
    let nan_max = |acc: f64, x: f64| {
        if acc.is_nan() || x.is_nan() {
            f64::NAN
        } else {
            acc.max(x)
        }
    };

    let norm = match norm_type {
        NormType::One => (0..n_usize)
            .map(|j| column(j).iter().map(|&v| t_to_f64(v).abs()).sum::<f64>())
            .fold(0.0_f64, nan_max),
        NormType::Infinity => {
            // Accumulate all row sums in a single sequential sweep over the
            // columns instead of striding through memory once per row. Each
            // row still adds its terms in ascending column order.
            let mut row_sums = vec![0.0_f64; n_usize];
            for j in 0..n_usize {
                for (acc, &v) in row_sums.iter_mut().zip(column(j)) {
                    *acc += t_to_f64(v).abs();
                }
            }
            row_sums.into_iter().fold(0.0_f64, nan_max)
        }
    };

    Ok(norm)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A variant equals itself and differs from the other variant.
    #[test]
    fn norm_type_equality() {
        assert_eq!(NormType::One, NormType::One);
        assert_ne!(NormType::One, NormType::Infinity);
    }

    /// The derived `Debug` output of `Infinity` names the variant.
    #[test]
    fn norm_type_debug() {
        let variant = NormType::Infinity;
        let s = format!("{variant:?}");
        assert!(s.contains("Infinity"));
    }

    /// An `f64` input comes back unchanged (within 1e-15).
    #[test]
    fn t_to_f64_for_f64_identity() {
        use std::f64::consts::PI;

        let val = PI;
        let converted = t_to_f64(val);
        let drift = (converted - val).abs();
        assert!(
            drift < 1e-15,
            "t_to_f64 for f64 must be identity, got {converted} expected {val}"
        );
    }

    /// An `f32` input is widened to the matching `f64` (within 1e-6).
    #[test]
    fn t_to_f64_for_f32_widening() {
        use std::f32::consts::E;

        let val = E;
        let expected = f64::from(val);
        let converted = t_to_f64(val);
        let drift = (converted - expected).abs();
        assert!(
            drift < 1e-6,
            "t_to_f64 for f32 must widen correctly, got {converted} expected {expected}"
        );
    }

    /// Zero converts to zero for both supported widths.
    #[test]
    fn t_to_f64_zero() {
        assert_eq!(t_to_f64(0.0_f64), 0.0_f64);
        assert_eq!(t_to_f64(0.0_f32), 0.0_f64);
    }

    /// Negative values keep their sign and magnitude for both widths.
    #[test]
    fn t_to_f64_negative() {
        // Double-precision input.
        let val = -42.0_f64;
        assert!((t_to_f64(val) - (-42.0_f64)).abs() < 1e-15);

        // Single-precision input.
        let result = t_to_f64(-1.5_f32);
        let drift = (result - (-1.5_f64)).abs();
        assert!(drift < 1e-6, "t_to_f64(-1.5f32) = {result}, expected -1.5");
    }

    /// The two variants never compare equal.
    #[test]
    fn norm_type_variants_distinct() {
        assert_ne!(
            NormType::One,
            NormType::Infinity,
            "NormType variants must be distinct"
        );
    }

    /// A copied value compares equal to its source.
    #[test]
    fn norm_type_clone() {
        let source = NormType::Infinity;
        let duplicate = source;
        assert_eq!(source, duplicate);
    }

    /// The derived `Debug` output of `One` names the variant.
    #[test]
    fn norm_type_one_debug() {
        let variant = NormType::One;
        let s = format!("{variant:?}");
        assert!(
            s.contains("One"),
            "NormType::One debug must contain 'One', got '{s}'"
        );
    }
}