use oxicuda_blas::GpuFloat;
use oxicuda_memory::DeviceBuffer;
use crate::dense::lu;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Matrix norm used by `condition_number_estimate` when measuring ‖A‖.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
    /// One-norm: maximum absolute column sum.
    One,
    /// Infinity-norm: maximum absolute row sum.
    Infinity,
}
/// Estimates the condition number κ(A) = ‖A‖ · ‖A⁻¹‖ of a dense `n`×`n`
/// matrix stored column-major on the device with leading dimension `lda`.
///
/// ‖A‖ is computed exactly on the host; ‖A⁻¹‖ is estimated with Hager's
/// algorithm (one LU factorization plus a handful of triangular solves).
///
/// # Errors
/// * `SolverError::DimensionMismatch` — `n == 0`, `lda < n`, or the buffer
///   holds fewer than `lda * n` elements.
/// * `SolverError::InternalError` — a device transfer or the LU
///   factorization fails.
#[allow(dead_code)]
pub fn condition_number_estimate<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &DeviceBuffer<T>,
    n: u32,
    lda: u32,
    norm_type: NormType,
) -> SolverResult<f64> {
    if n == 0 {
        return Err(SolverError::DimensionMismatch(
            "condition_number_estimate: n must be > 0".into(),
        ));
    }
    // lda < n would let the host-side norm computation index past the end of
    // the lda*n staging buffer (max index (n-1)*lda + n-1 >= lda*n) and panic;
    // reject it up front with a proper error instead.
    if lda < n {
        return Err(SolverError::DimensionMismatch(format!(
            "condition_number_estimate: lda ({lda}) must be >= n ({n})"
        )));
    }
    let required = n as usize * lda as usize;
    if a.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "condition_number_estimate: buffer too small ({} < {})",
            a.len(),
            required
        )));
    }
    let a_norm = compute_matrix_norm::<T>(handle, a, n, lda, norm_type)?;
    let ainv_norm_estimate = estimate_inverse_norm_hager::<T>(handle, a, n, lda, norm_type)?;
    Ok(a_norm * ainv_norm_estimate)
}
/// Widens a `GpuFloat` scalar to `f64` via its raw bit pattern.
///
/// The bits are reinterpreted at the width given by `T::SIZE` (8 bytes
/// for a double, otherwise treated as a 4-byte float) and then widened
/// losslessly to `f64`.
fn t_to_f64<T: GpuFloat>(val: T) -> f64 {
    let bits = val.to_bits_u64();
    match T::SIZE {
        8 => f64::from_bits(bits),
        _ => f32::from_bits(bits as u32).into(),
    }
}
/// Computes the exact one- or infinity-norm of the leading n×n block of a
/// column-major device matrix by staging it on the host.
///
/// Entry (i, j) is read from `host[j * lda + i]`; rows between `n` and
/// `lda` are padding and never touched.
fn compute_matrix_norm<T: GpuFloat>(
    _handle: &mut SolverHandle,
    a: &DeviceBuffer<T>,
    n: u32,
    lda: u32,
    norm_type: NormType,
) -> SolverResult<f64> {
    let dim = n as usize;
    let stride = lda as usize;
    let mut host = vec![T::gpu_zero(); stride * dim];
    a.copy_to_host(&mut host).map_err(|e| {
        SolverError::InternalError(format!("compute_matrix_norm copy_to_host failed: {e}"))
    })?;
    // |A[i][j]| as f64, honoring the column-major layout.
    let abs_at = |i: usize, j: usize| t_to_f64(host[j * stride + i]).abs();
    let mut norm = 0.0_f64;
    match norm_type {
        // One-norm: maximum absolute column sum.
        NormType::One => {
            for j in 0..dim {
                let col_sum: f64 = (0..dim).map(|i| abs_at(i, j)).sum();
                norm = norm.max(col_sum);
            }
        }
        // Infinity-norm: maximum absolute row sum.
        NormType::Infinity => {
            for i in 0..dim {
                let row_sum: f64 = (0..dim).map(|j| abs_at(i, j)).sum();
                norm = norm.max(row_sum);
            }
        }
    }
    Ok(norm)
}
/// Converts an `f64` into `T`'s bit representation (round-trips through
/// `f32` when `T::SIZE != 8`). Inverse of `t_to_f64` up to precision.
fn f64_to_t<T: GpuFloat>(v: f64) -> T {
    if T::SIZE == 8 {
        T::from_bits_u64(v.to_bits())
    } else {
        T::from_bits_u64(u64::from((v as f32).to_bits()))
    }
}

/// Estimates ‖A⁻¹‖₁ with Hager's one-norm estimator: factorize A once
/// with LU, then iterate `x → w = A⁻¹x`, `z = A⁻¹ sign(w)`, restarting
/// from the unit vector at argmax |zᵢ| until the estimate stops growing.
///
/// NOTE(review): `_norm_type` is ignored — the estimate is always the
/// one-norm. Exact Hager uses Aᵀ for the second solve; reusing A here is
/// exact only for symmetric matrices — confirm intent with the authors.
///
/// # Errors
/// `SolverError::InternalError` on any transfer/allocation failure or a
/// singular factorization (`info != 0`); propagates LU solver errors.
fn estimate_inverse_norm_hager<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &DeviceBuffer<T>,
    n: u32,
    lda: u32,
    _norm_type: NormType,
) -> SolverResult<f64> {
    let n_usize = n as usize;
    let lda_usize = lda as usize;
    // Hager's estimate converges in very few steps; 5 matches LAPACK practice.
    const MAX_ITER: usize = 5;
    // Stop once the estimate grows by less than ~5% between iterations.
    const CONV_TOL: f64 = 0.95;
    // Stage A on the host, then factor a scratch copy in place so the
    // caller's matrix is left untouched.
    let mut lu_host = vec![T::gpu_zero(); lda_usize * n_usize];
    a.copy_to_host(&mut lu_host).map_err(|e| {
        SolverError::InternalError(format!(
            "estimate_inverse_norm_hager: copy_from_device failed: {e}"
        ))
    })?;
    let mut lu_device = DeviceBuffer::<T>::alloc(n_usize * lda_usize).map_err(|e| {
        SolverError::InternalError(format!("estimate_inverse_norm_hager: alloc LU buffer: {e}"))
    })?;
    lu_device.copy_from_host(&lu_host).map_err(|e| {
        SolverError::InternalError(format!(
            "estimate_inverse_norm_hager: copy to device failed: {e}"
        ))
    })?;
    let mut pivots = DeviceBuffer::<i32>::alloc(n_usize).map_err(|e| {
        SolverError::InternalError(format!("estimate_inverse_norm_hager: alloc pivots: {e}"))
    })?;
    let lu_result = lu::lu_factorize(handle, &mut lu_device, n, lda, &mut pivots)?;
    if lu_result.info != 0 {
        return Err(SolverError::InternalError(format!(
            "estimate_inverse_norm_hager: LU factorization failed (info={})",
            lu_result.info
        )));
    }
    // Device scratch vectors: allocated once and reused every iteration
    // (previously re-allocated inside the loop on every pass).
    let mut w_device = DeviceBuffer::<T>::alloc(n_usize).map_err(|e| {
        SolverError::InternalError(format!("estimate_inverse_norm_hager: alloc w: {e}"))
    })?;
    let mut z_device = DeviceBuffer::<T>::alloc(n_usize).map_err(|e| {
        SolverError::InternalError(format!("estimate_inverse_norm_hager: alloc z: {e}"))
    })?;
    // Canonical Hager starting vector: x = e / n.
    let init_val = 1.0 / (n_usize as f64);
    let mut x = vec![init_val; n_usize];
    let mut best_estimate = 0.0_f64;
    for _iter in 0..MAX_ITER {
        // w = A⁻¹ x; ‖w‖₁ is the candidate estimate for ‖A⁻¹‖₁.
        let mut w_host = x.iter().map(|&v| f64_to_t::<T>(v)).collect::<Vec<_>>();
        w_device.copy_from_host(&w_host).map_err(|e| {
            SolverError::InternalError(format!(
                "estimate_inverse_norm_hager: copy w to device: {e}"
            ))
        })?;
        lu::lu_solve(handle, &lu_device, &pivots, &mut w_device, n, 1)?;
        w_device.copy_to_host(&mut w_host).map_err(|e| {
            SolverError::InternalError(format!(
                "estimate_inverse_norm_hager: copy w from device: {e}"
            ))
        })?;
        let w_norm_1 = w_host.iter().map(|&v| t_to_f64(v).abs()).sum::<f64>();
        // Converged: the estimate stopped growing meaningfully.
        if w_norm_1 <= CONV_TOL * best_estimate {
            best_estimate = w_norm_1;
            break;
        }
        best_estimate = w_norm_1;
        // zeta = sign(w), componentwise, with sign(0) = 0.
        let zeta = w_host
            .iter()
            .map(|&v| {
                let fv = t_to_f64(v);
                if fv > 0.0 {
                    f64_to_t::<T>(1.0)
                } else if fv < 0.0 {
                    f64_to_t::<T>(-1.0)
                } else {
                    T::gpu_zero()
                }
            })
            .collect::<Vec<_>>();
        // z = A⁻¹ zeta, solved in place; move zeta (it is not used again).
        let mut z = zeta;
        z_device.copy_from_host(&z).map_err(|e| {
            SolverError::InternalError(format!(
                "estimate_inverse_norm_hager: copy z to device: {e}"
            ))
        })?;
        lu::lu_solve(handle, &lu_device, &pivots, &mut z_device, n, 1)?;
        z_device.copy_to_host(&mut z).map_err(|e| {
            SolverError::InternalError(format!(
                "estimate_inverse_norm_hager: copy z from device: {e}"
            ))
        })?;
        // j_max = argmax |z_i| (first index on ties), z_inf_norm = ‖z‖_∞.
        let (j_max, z_inf_norm) = z
            .iter()
            .enumerate()
            .map(|(i, &v)| (i, t_to_f64(v).abs()))
            .fold((0, 0.0_f64), |(i_max, max_so_far), (i, norm)| {
                if norm > max_so_far {
                    (i, norm)
                } else {
                    (i_max, max_so_far)
                }
            });
        let z_dot_x = z
            .iter()
            .zip(x.iter())
            .map(|(&zi, &xi)| t_to_f64(zi) * xi)
            .sum::<f64>();
        // ‖z‖_∞ <= zᵀx means no unit vector can improve the estimate.
        if z_inf_norm <= z_dot_x {
            break;
        }
        // Restart from the unit vector e_{j_max}.
        x.fill(0.0);
        x[j_max] = 1.0;
    }
    Ok(best_estimate)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Derived equality distinguishes the two variants.
    #[test]
    fn norm_type_equality() {
        assert_eq!(NormType::One, NormType::One);
        assert_ne!(NormType::One, NormType::Infinity);
    }

    /// Derived `Debug` output names the `Infinity` variant.
    #[test]
    fn norm_type_debug() {
        assert!(format!("{:?}", NormType::Infinity).contains("Infinity"));
    }

    /// An `f64` round-trips through `t_to_f64` unchanged.
    #[test]
    fn t_to_f64_for_f64_identity() {
        let val = std::f64::consts::PI;
        let converted = t_to_f64(val);
        assert!(
            (converted - val).abs() < 1e-15,
            "t_to_f64 for f64 must be identity, got {converted} expected {val}"
        );
    }

    /// Widening an `f32` agrees with `f64::from`.
    #[test]
    fn t_to_f64_for_f32_widening() {
        let val = std::f32::consts::E;
        let expected = f64::from(val);
        let converted = t_to_f64(val);
        assert!(
            (converted - expected).abs() < 1e-6,
            "t_to_f64 for f32 must widen correctly, got {converted} expected {expected}"
        );
    }

    /// Zero maps to zero at both widths.
    #[test]
    fn t_to_f64_zero() {
        assert_eq!(t_to_f64(0.0_f64), 0.0_f64);
        assert_eq!(t_to_f64(0.0_f32), 0.0_f64);
    }

    /// The sign bit survives conversion at both widths.
    #[test]
    fn t_to_f64_negative() {
        assert!((t_to_f64(-42.0_f64) - (-42.0_f64)).abs() < 1e-15);
        let result = t_to_f64(-1.5_f32);
        assert!(
            (result - (-1.5_f64)).abs() < 1e-6,
            "t_to_f64(-1.5f32) = {result}, expected -1.5"
        );
    }

    /// Sanity: the two variants compare unequal.
    #[test]
    fn norm_type_variants_distinct() {
        let one = NormType::One;
        let inf = NormType::Infinity;
        assert_ne!(one, inf, "NormType variants must be distinct");
    }

    /// `Copy` yields a value equal to the original.
    #[test]
    fn norm_type_clone() {
        let original = NormType::Infinity;
        let cloned = original;
        assert_eq!(original, cloned);
    }

    /// Derived `Debug` output names the `One` variant.
    #[test]
    fn norm_type_one_debug() {
        let s = format!("{:?}", NormType::One);
        assert!(
            s.contains("One"),
            "NormType::One debug must contain 'One', got '{s}'"
        );
    }
}