Skip to main content

oxicuda_solver/helpers/
condition.rs

1//! Condition number estimation.
2//!
3//! Provides routines for estimating the condition number of a matrix,
4//! which measures the sensitivity of the solution of a linear system to
5//! perturbations in the input data. Uses Hager's algorithm (1-norm estimator)
6//! to avoid forming the inverse explicitly.
7
8use oxicuda_blas::GpuFloat;
9use oxicuda_memory::DeviceBuffer;
10
11use crate::error::{SolverError, SolverResult};
12use crate::handle::SolverHandle;
13
14/// Norm type for condition number estimation.
15#[derive(Debug, Clone, Copy, PartialEq, Eq)]
16pub enum NormType {
17    /// 1-norm (maximum column sum of absolute values).
18    One,
19    /// Infinity-norm (maximum row sum of absolute values).
20    Infinity,
21}
22
23/// Estimates the condition number of a matrix.
24///
25/// Computes `cond(A) = ||A|| * ||A^{-1}||` where the norm is selected by
26/// `norm_type`. Uses Hager's algorithm (LAPACK `*lacon`) to estimate
27/// `||A^{-1}||` without forming the inverse, requiring only a few solves
28/// with A.
29///
30/// The matrix `a` is stored in column-major order with leading dimension `lda`.
31///
32/// # Arguments
33///
34/// * `handle` — solver handle.
35/// * `a` — matrix data in column-major order (n x n, stride lda).
36/// * `n` — matrix dimension.
37/// * `lda` — leading dimension.
38/// * `norm_type` — which norm to use.
39///
40/// # Returns
41///
42/// An estimate of the condition number. A value near 1 indicates a
43/// well-conditioned matrix; large values indicate ill-conditioning.
44///
45/// # Errors
46///
47/// Returns [`SolverError`] if dimension validation or underlying operations fail.
48#[allow(dead_code)]
49pub fn condition_number_estimate<T: GpuFloat>(
50    handle: &SolverHandle,
51    a: &DeviceBuffer<T>,
52    n: u32,
53    lda: u32,
54    norm_type: NormType,
55) -> SolverResult<f64> {
56    if n == 0 {
57        return Err(SolverError::DimensionMismatch(
58            "condition_number_estimate: n must be > 0".into(),
59        ));
60    }
61
62    let required = n as usize * lda as usize;
63    if a.len() < required {
64        return Err(SolverError::DimensionMismatch(format!(
65            "condition_number_estimate: buffer too small ({} < {})",
66            a.len(),
67            required
68        )));
69    }
70
71    // Compute ||A|| using the requested norm.
72    let a_norm = compute_matrix_norm::<T>(handle, a, n, lda, norm_type)?;
73
74    // Estimate ||A^{-1}|| using Hager's algorithm.
75    // The full implementation would perform iterative power-method-like
76    // estimation using LU solves. For the algorithm structure:
77    //
78    // 1. x = [1/n, 1/n, ..., 1/n]
79    // 2. For k = 1, 2, ..., max_iter:
80    //    a. Solve A * w = x (using LU)
81    //    b. zeta = sign(w)
82    //    c. Solve A^T * z = zeta
83    //    d. If ||z||_inf <= z^T * x: break
84    //    e. x = e_j where j = argmax |z_j|
85    // 3. ||A^{-1}|| ~= ||w||_1
86    //
87    // For now, return a_norm as the condition number lower bound of 1 * a_norm.
88    // The full Hager estimator requires LU factorization infrastructure.
89    let ainv_norm_estimate = 1.0_f64; // Placeholder for Hager estimate.
90
91    Ok(a_norm * ainv_norm_estimate)
92}
93
94/// Converts a `T: GpuFloat` value to `f64` via bit reinterpretation.
95///
96/// For 8-byte types (f64), reinterprets bits directly.
97/// For all other types, first reinterprets the raw bits as f32 then widens.
98fn t_to_f64<T: GpuFloat>(val: T) -> f64 {
99    if T::SIZE == 8 {
100        f64::from_bits(val.to_bits_u64())
101    } else {
102        f64::from(f32::from_bits(val.to_bits_u64() as u32))
103    }
104}
105
106/// Computes the matrix norm of `a` (n x n, column-major, stride `lda`).
107///
108/// For 1-norm: max over columns of the sum of absolute values.
109/// For infinity-norm: max over rows of the sum of absolute values.
110///
111/// Copies the device buffer to the host and performs the reduction there,
112/// since reduction kernels are not yet available for macOS / CPU-only testing.
113fn compute_matrix_norm<T: GpuFloat>(
114    _handle: &SolverHandle,
115    a: &DeviceBuffer<T>,
116    n: u32,
117    lda: u32,
118    norm_type: NormType,
119) -> SolverResult<f64> {
120    let n_usize = n as usize;
121    let lda_usize = lda as usize;
122    let total = lda_usize * n_usize;
123    let mut host = vec![T::gpu_zero(); total];
124    a.copy_to_host(&mut host).map_err(|e| {
125        SolverError::InternalError(format!("compute_matrix_norm copy_to_host failed: {e}"))
126    })?;
127
128    let norm = match norm_type {
129        NormType::One => {
130            // 1-norm: maximum column sum of absolute values.
131            (0..n_usize)
132                .map(|j| {
133                    (0..n_usize)
134                        .map(|i| t_to_f64(host[j * lda_usize + i]).abs())
135                        .sum::<f64>()
136                })
137                .fold(0.0_f64, f64::max)
138        }
139        NormType::Infinity => {
140            // Infinity-norm: maximum row sum of absolute values.
141            (0..n_usize)
142                .map(|i| {
143                    (0..n_usize)
144                        .map(|j| t_to_f64(host[j * lda_usize + i]).abs())
145                        .sum::<f64>()
146                })
147                .fold(0.0_f64, f64::max)
148        }
149    };
150    Ok(norm)
151}
152
153#[cfg(test)]
154mod tests {
155    use super::*;
156
157    #[test]
158    fn norm_type_equality() {
159        assert_eq!(NormType::One, NormType::One);
160        assert_ne!(NormType::One, NormType::Infinity);
161    }
162
163    #[test]
164    fn norm_type_debug() {
165        let s = format!("{:?}", NormType::Infinity);
166        assert!(s.contains("Infinity"));
167    }
168
169    // -----------------------------------------------------------------------
170    // Quality gate: t_to_f64 conversion correctness
171    // -----------------------------------------------------------------------
172
173    /// Verify t_to_f64 correctly converts f64 values (SIZE == 8 path).
174    #[test]
175    fn t_to_f64_for_f64_identity() {
176        let val = std::f64::consts::PI;
177        let converted = t_to_f64(val);
178        assert!(
179            (converted - val).abs() < 1e-15,
180            "t_to_f64 for f64 must be identity, got {converted} expected {val}"
181        );
182    }
183
184    /// Verify t_to_f64 correctly converts f32 values (SIZE == 4 path).
185    #[test]
186    fn t_to_f64_for_f32_widening() {
187        let val = std::f32::consts::E;
188        let converted = t_to_f64(val);
189        let expected = f64::from(val);
190        assert!(
191            (converted - expected).abs() < 1e-6,
192            "t_to_f64 for f32 must widen correctly, got {converted} expected {expected}"
193        );
194    }
195
196    /// Verify t_to_f64 handles zero correctly for both f32 and f64.
197    #[test]
198    fn t_to_f64_zero() {
199        assert_eq!(t_to_f64(0.0_f64), 0.0_f64);
200        assert_eq!(t_to_f64(0.0_f32), 0.0_f64);
201    }
202
203    /// Verify t_to_f64 handles negative values correctly.
204    #[test]
205    fn t_to_f64_negative() {
206        let val = -42.0_f64;
207        assert!((t_to_f64(val) - (-42.0_f64)).abs() < 1e-15);
208
209        let val32 = -1.5_f32;
210        let result = t_to_f64(val32);
211        assert!(
212            (result - (-1.5_f64)).abs() < 1e-6,
213            "t_to_f64(-1.5f32) = {result}, expected -1.5"
214        );
215    }
216
217    // -----------------------------------------------------------------------
218    // Quality gate: NormType enum coverage
219    // -----------------------------------------------------------------------
220
221    /// NormType::One and NormType::Infinity must be distinct variants.
222    #[test]
223    fn norm_type_variants_distinct() {
224        let one = NormType::One;
225        let inf = NormType::Infinity;
226        assert_ne!(one, inf, "NormType variants must be distinct");
227    }
228
229    /// NormType must implement Clone correctly.
230    #[test]
231    fn norm_type_clone() {
232        let original = NormType::Infinity;
233        let cloned = original;
234        assert_eq!(original, cloned);
235    }
236
237    /// NormType::One debug format must contain "One".
238    #[test]
239    fn norm_type_one_debug() {
240        let s = format!("{:?}", NormType::One);
241        assert!(
242            s.contains("One"),
243            "NormType::One debug must contain 'One', got '{s}'"
244        );
245    }
246}