use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use crate::error::{LinalgError, LinalgResult};
/// Result of a 1-norm condition-number estimation.
#[derive(Debug, Clone)]
pub struct ConditionEstimate {
    /// Estimated condition number κ₁(A); clamped to be at least 1.0.
    pub estimate: f64,
    /// `true` when the estimate came from a successful LU factorization;
    /// `false` when factorization failed and the value is a sentinel
    /// (e.g. infinity for a near-singular matrix).
    pub reliable: bool,
}
impl ConditionEstimate {
    /// Estimate backed by a successful factorization.
    fn reliable(value: f64) -> Self {
        Self::with_flag(value, true)
    }

    /// Estimate produced when factorization failed; treat as an upper bound.
    fn unreliable(value: f64) -> Self {
        Self::with_flag(value, false)
    }

    /// Shared constructor: a condition number is mathematically >= 1,
    /// so clamp the value to that lower bound.
    fn with_flag(value: f64, reliable: bool) -> Self {
        ConditionEstimate {
            estimate: value.max(1.0),
            reliable,
        }
    }
}
/// Estimate the 1-norm condition number κ₁(A) = ‖A‖₁ · ‖A⁻¹‖₁ of a square matrix.
///
/// ‖A‖₁ is computed exactly; ‖A⁻¹‖₁ is estimated from an LU factorization
/// without forming the inverse.
///
/// # Errors
/// * `ShapeError` if the matrix is not square.
/// * `SingularMatrixError` if the matrix has zero 1-norm.
pub fn estimate_condition_1norm(a: &ArrayView2<f64>) -> LinalgResult<ConditionEstimate> {
    let (m, n) = (a.nrows(), a.ncols());
    if m != n {
        return Err(LinalgError::ShapeError(format!(
            "estimate_condition_1norm requires a square matrix, got ({m}×{n})"
        )));
    }
    // The empty matrix is conventionally perfectly conditioned.
    if n == 0 {
        return Ok(ConditionEstimate::reliable(1.0));
    }
    let norm_a = matrix_1norm(a);
    if norm_a == 0.0 {
        return Err(LinalgError::SingularMatrixError(
            "Matrix has zero 1-norm; condition number is infinite".to_string(),
        ));
    }
    // A failed factorization means the matrix is (numerically) singular:
    // report an infinite, unreliable estimate instead of erroring out.
    match lu_partial_pivot(a) {
        Ok((lu, piv)) => {
            let inv_norm = estimate_inv_1norm(&lu, &piv, n);
            Ok(ConditionEstimate::reliable(norm_a * inv_norm))
        }
        Err(_) => Ok(ConditionEstimate::unreliable(f64::INFINITY)),
    }
}
/// Exact matrix 1-norm: the maximum absolute column sum.
/// Returns 0.0 for a matrix with no columns. NaN column sums are ignored
/// by `f64::max`, matching the original `>` comparison semantics.
fn matrix_1norm(a: &ArrayView2<f64>) -> f64 {
    (0..a.ncols())
        .map(|j| (0..a.nrows()).map(|i| a[[i, j]].abs()).sum::<f64>())
        .fold(0.0f64, f64::max)
}
/// LU decomposition with partial (row) pivoting: P·A = L·U.
///
/// Returns the factors packed into one matrix — unit-diagonal L strictly below
/// the diagonal, U on and above it — together with the pivot sequence, where
/// `piv[k]` is the row index swapped into position `k` at elimination step `k`.
///
/// # Errors
/// `SingularMatrixError` when the best available pivot magnitude falls below
/// `f64::EPSILON * 1e3`.
/// NOTE(review): that tolerance is absolute, not scaled by the matrix norm —
/// confirm it is appropriate for badly scaled inputs.
fn lu_partial_pivot(a: &ArrayView2<f64>) -> LinalgResult<(Array2<f64>, Vec<usize>)> {
    let n = a.nrows();
    let mut lu = a.to_owned();
    let mut piv: Vec<usize> = (0..n).collect();
    for k in 0..n {
        // Find the largest-magnitude entry in column k, rows k..n.
        let mut max_val = lu[[k, k]].abs();
        let mut max_row = k;
        for i in k + 1..n {
            let v = lu[[i, k]].abs();
            if v > max_val {
                max_val = v;
                max_row = i;
            }
        }
        if max_val < f64::EPSILON * 1e3 {
            return Err(LinalgError::SingularMatrixError(
                "Near-singular matrix in LU decomposition".to_string(),
            ));
        }
        // Swap the pivot row into place and record the permutation.
        if max_row != k {
            for j in 0..n {
                let tmp = lu[[k, j]];
                lu[[k, j]] = lu[[max_row, j]];
                lu[[max_row, j]] = tmp;
            }
            piv.swap(k, max_row);
        }
        // Eliminate below the pivot, storing the multipliers in column k.
        let pivot = lu[[k, k]];
        for i in k + 1..n {
            lu[[i, k]] /= pivot;
            // Hoisted out of the j-loop: the multiplier lu[[i, k]] is loop-
            // invariant (the j-loop never touches column k). This matches the
            // f32 factorization path and avoids a redundant read per element.
            let lij = lu[[i, k]];
            for j in k + 1..n {
                lu[[i, j]] -= lij * lu[[k, j]];
            }
        }
    }
    Ok((lu, piv))
}
/// Solve A·x = b given the packed LU factors and pivot sequence from
/// `lu_partial_pivot`. Applies P, then forward-substitutes with the
/// unit-lower-triangular L, then back-substitutes with U.
fn lu_solve(lu: &Array2<f64>, piv: &[usize], b: &Array1<f64>) -> Array1<f64> {
    let n = lu.nrows();
    let mut x = b.to_owned();
    // Apply the row permutation P to the right-hand side, in factorization order.
    for (row, &swapped) in piv.iter().enumerate().take(n) {
        if swapped != row {
            x.swap(row, swapped);
        }
    }
    // Forward substitution: L·y = P·b. L has an implicit unit diagonal,
    // so row 0 needs no work and there is no division.
    for i in 1..n {
        let mut acc = x[i];
        for j in 0..i {
            acc -= lu[[i, j]] * x[j];
        }
        x[i] = acc;
    }
    // Back substitution: U·x = y.
    for i in (0..n).rev() {
        let mut acc = x[i];
        for j in i + 1..n {
            acc -= lu[[i, j]] * x[j];
        }
        x[i] = acc / lu[[i, i]];
    }
    x
}
/// Solve Aᵀ·x = b given the packed LU factors of A.
///
/// Since P·A = L·U, we have Aᵀ = Uᵀ·Lᵀ·P, so we solve Uᵀ·y = b (lower
/// triangular, non-unit diagonal), then Lᵀ·z = y (upper triangular, unit
/// diagonal), and finally undo the permutation by replaying the recorded
/// swaps in reverse order.
fn lu_solve_transpose(lu: &Array2<f64>, piv: &[usize], b: &Array1<f64>) -> Array1<f64> {
    let n = lu.nrows();
    let mut x = b.to_owned();
    // Forward pass with Uᵀ: note the transposed index lu[[j, i]].
    for i in 0..n {
        let mut acc = x[i];
        for j in 0..i {
            acc -= lu[[j, i]] * x[j];
        }
        x[i] = acc / lu[[i, i]];
    }
    // Backward pass with Lᵀ: unit diagonal, so no division.
    for i in (0..n).rev() {
        let mut acc = x[i];
        for j in i + 1..n {
            acc -= lu[[j, i]] * x[j];
        }
        x[i] = acc;
    }
    // Undo P: apply the swaps in the opposite order they were recorded.
    for k in (0..n).rev() {
        let swapped = piv[k];
        if swapped != k {
            x.swap(k, swapped);
        }
    }
    x
}
/// Estimate ‖A⁻¹‖₁ from the LU factorization, Hager/Higham-style (the scheme
/// behind LAPACK's xGECON): repeatedly solve A·y = v and Aᵀ·z = sign(y)/n and
/// use z to steer the probe vector v toward the maximizing unit vector.
/// Each iteration costs only two triangular solves (O(n²)).
/// NOTE(review): the two stopping tests below are a simplified variant of the
/// textbook Hager criterion — confirm against a reference implementation.
fn estimate_inv_1norm(lu: &Array2<f64>, piv: &[usize], n: usize) -> f64 {
    // Small fixed iteration budget; the estimator typically converges in 2-4.
    const MAX_ITER: usize = 5;
    let inv_n = 1.0 / n as f64;
    // Initial probe: the uniform vector v = (1/n, ..., 1/n).
    let mut v = Array1::from_elem(n, inv_n);
    let mut est = 0.0f64;
    let mut est_old = 0.0f64;
    for _iter in 0..MAX_ITER {
        // y = A⁻¹·v; ‖y‖₁ is the running lower-bound estimate of ‖A⁻¹‖₁.
        let y = lu_solve(lu, piv, &v);
        est = y.iter().map(|x| x.abs()).sum::<f64>();
        // Stop once the estimate no longer grows (tiny relative slack).
        // On the first pass est_old is 0, so this only fires when est == 0.
        if est <= est_old * (1.0 + 1e-10) {
            break; }
        est_old = est;
        // w carries the (scaled) sign pattern of y.
        let mut w = Array1::zeros(n);
        for i in 0..n {
            w[i] = if y[i] >= 0.0 { inv_n } else { -inv_n };
        }
        // z = A⁻ᵀ·w indicates which coordinate of the probe to amplify next.
        let z = lu_solve_transpose(lu, piv, &w);
        let max_z = z.iter().map(|x| x.abs()).fold(0.0f64, f64::max);
        // Converged when no single coordinate of z dominates the mass of z
        // weighted by the current probe.
        if max_z
            <= z.iter()
                .zip(v.iter())
                .map(|(zi, vi)| zi.abs() * vi.abs())
                .sum::<f64>()
        {
            break;
        }
        // Next probe: the standard basis vector at the largest |z| entry.
        // partial_cmp falls back to Equal for NaN, so argmax is total.
        let argmax = z
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                a.abs()
                    .partial_cmp(&b.abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
            .unwrap_or(0);
        v = Array1::zeros(n);
        v[argmax] = 1.0;
    }
    est
}
/// Floating-point working precision selected for a solve.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PrecisionLevel {
    /// Half precision. NOTE(review): the visible solver dispatch routes F16
    /// to the f32 path — no actual f16 arithmetic is performed here.
    F16,
    /// Single precision (f32 factorization and solve).
    F32,
    /// Double precision (f64 factorization and solve).
    F64,
    /// "Extended" fallback: f64 solve plus one step of iterative refinement.
    F128Fallback,
}
/// Policy for choosing the working precision of a solve.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum PrecisionPolicy {
    /// Always use the given precision; no condition estimation is performed.
    Fixed(PrecisionLevel),
    /// Pick a precision from the estimated condition number:
    /// below `low_threshold` → F16, below `high_threshold` → F32,
    /// below 1e12 → F64, otherwise the refined fallback.
    Adaptive {
        low_threshold: f64,
        high_threshold: f64,
    },
}
impl Default for PrecisionPolicy {
    /// Defaults to always solving in double precision.
    fn default() -> Self {
        Self::Fixed(PrecisionLevel::F64)
    }
}
/// Map a condition-number estimate to a working precision under `policy`.
///
/// For `Fixed`, the estimate is ignored. For `Adaptive`, thresholds are
/// checked lowest-first; a NaN estimate fails every `<` comparison and thus
/// lands on the most conservative `F128Fallback` level.
pub fn select_precision(cond_est: f64, policy: &PrecisionPolicy) -> PrecisionLevel {
    match policy {
        PrecisionPolicy::Fixed(level) => level.clone(),
        PrecisionPolicy::Adaptive {
            low_threshold,
            high_threshold,
        } => match cond_est {
            c if c < *low_threshold => PrecisionLevel::F16,
            c if c < *high_threshold => PrecisionLevel::F32,
            // Hard ceiling above which f64 itself is considered insufficient.
            c if c < 1e12 => PrecisionLevel::F64,
            _ => PrecisionLevel::F128Fallback,
        },
    }
}
/// Configuration for [`AdaptiveSolver`].
#[derive(Debug, Clone)]
pub struct AdaptiveSolverConfig {
    /// How to pick the working precision for each solve.
    pub policy: PrecisionPolicy,
    /// Iteration budget for condition estimation.
    /// NOTE(review): not read by any code visible in this file — the
    /// estimator uses its own internal MAX_ITER. Confirm intent.
    pub max_cond_iter: usize,
}
impl Default for AdaptiveSolverConfig {
    /// Adaptive policy with F16 below κ≈1e3 and F32 below κ≈1e7.
    fn default() -> Self {
        Self {
            policy: PrecisionPolicy::Adaptive {
                low_threshold: 1e3,
                high_threshold: 1e7,
            },
            max_cond_iter: 5,
        }
    }
}
/// Linear solver that estimates the condition number of the system matrix
/// and dispatches to an appropriately precise solve path.
#[derive(Debug, Clone)]
pub struct AdaptiveSolver {
    // Precision-selection policy and limits used by `solve`.
    config: AdaptiveSolverConfig,
}
/// Outcome of an adaptive solve.
#[derive(Debug, Clone)]
pub struct SolveResult {
    /// Solution vector x with A·x ≈ b, always returned as f64.
    pub solution: Array1<f64>,
    /// The precision level actually used for the computation.
    pub precision_used: PrecisionLevel,
    /// Condition estimate that drove the selection (1.0 for Fixed policies,
    /// where no estimation is performed).
    pub cond_estimate: f64,
}
impl AdaptiveSolver {
    /// Build a solver with an explicit configuration.
    pub fn new(config: AdaptiveSolverConfig) -> Self {
        Self { config }
    }

    /// Convenience constructor using the default adaptive policy.
    pub fn default_adaptive() -> Self {
        Self::new(AdaptiveSolverConfig::default())
    }

    /// Estimate the condition number of `a` and pick a precision level.
    ///
    /// With a `Fixed` policy no estimation is done and the reported
    /// condition value is a placeholder 1.0.
    pub fn estimate_and_select(&self, a: &ArrayView2<f64>) -> LinalgResult<(f64, PrecisionLevel)> {
        if let PrecisionPolicy::Fixed(level) = &self.config.policy {
            return Ok((1.0, level.clone()));
        }
        let ce = estimate_condition_1norm(a)?;
        let level = select_precision(ce.estimate, &self.config.policy);
        Ok((ce.estimate, level))
    }

    /// Solve A·x = b, choosing the working precision from the policy.
    ///
    /// # Errors
    /// `ShapeError` when `a` is not square or `b` has the wrong length;
    /// any error propagated from condition estimation or the solve path.
    pub fn solve(&self, a: &ArrayView2<f64>, b: &Array1<f64>) -> LinalgResult<SolveResult> {
        let (m, n) = (a.nrows(), a.ncols());
        if m != n {
            return Err(LinalgError::ShapeError(format!(
                "AdaptiveSolver::solve requires square matrix, got ({m}×{n})"
            )));
        }
        if b.len() != n {
            return Err(LinalgError::ShapeError(format!(
                "RHS length {} does not match matrix size {n}",
                b.len()
            )));
        }
        let (cond_estimate, precision_used) = self.estimate_and_select(a)?;
        // F16 currently shares the f32 path; F128Fallback is f64 plus one
        // iterative-refinement step.
        let solution = match precision_used {
            PrecisionLevel::F16 | PrecisionLevel::F32 => solve_f32_internal(a, b)?,
            PrecisionLevel::F64 => solve_f64_internal(a, b)?,
            PrecisionLevel::F128Fallback => solve_f64_refined(a, b)?,
            // Guard for future #[non_exhaustive] variants.
            #[allow(unreachable_patterns)]
            _ => solve_f64_internal(a, b)?,
        };
        Ok(SolveResult {
            solution,
            precision_used,
            cond_estimate,
        })
    }
}
/// Solve A·x = b entirely in single precision: cast inputs to f32, factor
/// with partial pivoting in place, forward/back substitute, and cast the
/// result back to f64. Intended for well-conditioned systems where the
/// reduced precision is acceptable.
///
/// # Errors
/// `SingularMatrixError` when a pivot falls below `f32::EPSILON * 1e3`
/// (absolute threshold — same caveat as the f64 factorization).
fn solve_f32_internal(a: &ArrayView2<f64>, b: &Array1<f64>) -> LinalgResult<Array1<f64>> {
    let n = a.nrows();
    // Demote the system to f32; the factorization overwrites lu_f32 and the
    // substitution overwrites b_f32 in place.
    let mut lu_f32: Array2<f32> = Array2::from_shape_fn((n, n), |(i, j)| a[[i, j]] as f32);
    let mut b_f32: Array1<f32> = Array1::from_shape_fn(n, |i| b[i] as f32);
    let mut piv: Vec<usize> = (0..n).collect();
    for k in 0..n {
        // Partial pivoting: largest-magnitude entry in column k, rows k..n.
        let mut max_val = lu_f32[[k, k]].abs();
        let mut max_row = k;
        for i in k + 1..n {
            let v = lu_f32[[i, k]].abs();
            if v > max_val {
                max_val = v;
                max_row = i;
            }
        }
        if max_val < f32::EPSILON * 1e3 {
            return Err(LinalgError::SingularMatrixError(
                "Near-singular matrix in f32 solve".to_string(),
            ));
        }
        // Swap the pivot row into place and record the permutation.
        if max_row != k {
            for j in 0..n {
                let tmp = lu_f32[[k, j]];
                lu_f32[[k, j]] = lu_f32[[max_row, j]];
                lu_f32[[max_row, j]] = tmp;
            }
            piv.swap(k, max_row);
        }
        // Eliminate below the pivot; multipliers are stored in column k.
        let pv = lu_f32[[k, k]];
        for i in k + 1..n {
            lu_f32[[i, k]] /= pv;
            let lij = lu_f32[[i, k]];
            for j in k + 1..n {
                lu_f32[[i, j]] -= lij * lu_f32[[k, j]];
            }
        }
    }
    // Apply the row permutation to the RHS (same order as factorization).
    for (k, &p) in piv.iter().enumerate().take(n) {
        if p != k {
            b_f32.swap(k, p);
        }
    }
    // Forward substitution with unit-lower-triangular L (no division).
    for i in 1..n {
        let mut s = b_f32[i];
        for j in 0..i {
            s -= lu_f32[[i, j]] * b_f32[j];
        }
        b_f32[i] = s;
    }
    // Back substitution with U.
    for i in (0..n).rev() {
        let mut s = b_f32[i];
        for j in i + 1..n {
            s -= lu_f32[[i, j]] * b_f32[j];
        }
        b_f32[i] = s / lu_f32[[i, i]];
    }
    // Promote the solution back to f64 for the caller.
    Ok(Array1::from_shape_fn(n, |i| b_f32[i] as f64))
}
/// Solve A·x = b in double precision via LU with partial pivoting.
///
/// # Errors
/// Propagates `SingularMatrixError` from the factorization.
fn solve_f64_internal(a: &ArrayView2<f64>, b: &Array1<f64>) -> LinalgResult<Array1<f64>> {
    // Destructure directly; the previous version also bound an unused `n`.
    let (lu, piv) = lu_partial_pivot(a)?;
    Ok(lu_solve(&lu, &piv, b))
}
/// Solve A·x = b in f64 with one step of iterative refinement:
/// x = x₀ + A⁻¹·(b − A·x₀). Used as the fallback for very ill-conditioned
/// systems where a plain f64 solve may lose too many digits.
///
/// The LU factorization is computed ONCE and reused for both the initial
/// solve and the correction solve (the previous version factored the same
/// matrix twice, doubling the O(n³) cost for an identical result).
///
/// # Errors
/// Propagates `SingularMatrixError` from the factorization.
fn solve_f64_refined(a: &ArrayView2<f64>, b: &Array1<f64>) -> LinalgResult<Array1<f64>> {
    let n = a.nrows();
    let (lu, piv) = lu_partial_pivot(a)?;
    let x0 = lu_solve(&lu, &piv, b);
    // Residual r = b − A·x₀, accumulated in f64.
    let mut r = b.clone();
    for i in 0..n {
        let mut ax_i = 0.0f64;
        for j in 0..n {
            ax_i += a[[i, j]] * x0[j];
        }
        r[i] -= ax_i;
    }
    // Correction dx = A⁻¹·r, then the refined solution x₀ + dx.
    let dx = lu_solve(&lu, &piv, &r);
    Ok(Array1::from_shape_fn(n, |i| x0[i] + dx[i]))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    // κ₁(I) = 1 exactly; the estimator should stay close and mark it reliable.
    #[test]
    fn test_condition_identity_is_one() {
        let n = 5;
        let id: Array2<f64> = Array2::eye(n);
        let ce = estimate_condition_1norm(&id.view()).expect("Cond estimation failed");
        assert!(
            ce.estimate < 2.0,
            "Identity cond estimate too large: {}",
            ce.estimate
        );
        assert!(ce.reliable, "Identity cond not marked reliable");
    }

    // Uniform scaling does not change the condition number.
    #[test]
    fn test_condition_scaled_identity() {
        let n = 4;
        let a: Array2<f64> = Array2::eye(n) * 100.0;
        let ce = estimate_condition_1norm(&a.view()).expect("Cond estimation failed");
        assert!(
            ce.estimate < 5.0,
            "Scaled identity cond too large: {}",
            ce.estimate
        );
    }

    // Hilbert matrices are a classic ill-conditioned family; the 8×8 case
    // has a very large κ, so the estimate should comfortably exceed 1e6.
    #[test]
    fn test_condition_hilbert_matrix_large() {
        let n = 8usize;
        let h: Array2<f64> = Array2::from_shape_fn((n, n), |(i, j)| 1.0 / (i + j + 1) as f64);
        let ce = estimate_condition_1norm(&h.view()).expect("Cond estimation failed");
        assert!(
            ce.estimate > 1e6,
            "Hilbert(8) cond should be > 1e6, got {}",
            ce.estimate
        );
    }

    // For diag(1,2,4,8) the exact κ₁ is 8; allow generous estimator slack.
    #[test]
    fn test_condition_diagonal_well_conditioned() {
        let d = array![
            [1.0_f64, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 4.0, 0.0],
            [0.0, 0.0, 0.0, 8.0],
        ];
        let ce = estimate_condition_1norm(&d.view()).expect("Cond estimation failed");
        assert!(ce.estimate >= 1.0, "Cond estimate must be >= 1");
        assert!(
            ce.estimate < 200.0,
            "Diagonal cond estimate way off: {}",
            ce.estimate
        );
    }

    // Non-square input must be rejected with a shape error.
    #[test]
    fn test_condition_non_square_errors() {
        let a: Array2<f64> = Array2::zeros((3, 4));
        assert!(estimate_condition_1norm(&a.view()).is_err());
    }

    // Fixed policies ignore the condition estimate entirely.
    #[test]
    fn test_fixed_f16_always_returns_f16() {
        let policy = PrecisionPolicy::Fixed(PrecisionLevel::F16);
        for cond in [1.0, 1e5, 1e15] {
            assert_eq!(select_precision(cond, &policy), PrecisionLevel::F16);
        }
    }

    #[test]
    fn test_fixed_f32_always_returns_f32() {
        let policy = PrecisionPolicy::Fixed(PrecisionLevel::F32);
        for cond in [1.0, 1e5, 1e15] {
            assert_eq!(select_precision(cond, &policy), PrecisionLevel::F32);
        }
    }

    #[test]
    fn test_fixed_f64_always_returns_f64() {
        let policy = PrecisionPolicy::Fixed(PrecisionLevel::F64);
        for cond in [1.0, 1e5, 1e15] {
            assert_eq!(select_precision(cond, &policy), PrecisionLevel::F64);
        }
    }

    // One representative value per adaptive band, including the 1e12 ceiling.
    #[test]
    fn test_adaptive_threshold_dispatch() {
        let policy = PrecisionPolicy::Adaptive {
            low_threshold: 1e2,
            high_threshold: 1e6,
        };
        assert_eq!(select_precision(50.0, &policy), PrecisionLevel::F16);
        assert_eq!(select_precision(1e3, &policy), PrecisionLevel::F32);
        assert_eq!(select_precision(1e7, &policy), PrecisionLevel::F64);
        assert_eq!(
            select_precision(1e13, &policy),
            PrecisionLevel::F128Fallback
        );
    }

    // Thresholds are exclusive: exactly at the boundary picks the higher level.
    #[test]
    fn test_adaptive_boundary_low_threshold() {
        let policy = PrecisionPolicy::Adaptive {
            low_threshold: 100.0,
            high_threshold: 1e6,
        };
        assert_eq!(select_precision(99.9, &policy), PrecisionLevel::F16);
        assert_eq!(select_precision(100.0, &policy), PrecisionLevel::F32);
    }

    #[test]
    fn test_adaptive_boundary_high_threshold() {
        let policy = PrecisionPolicy::Adaptive {
            low_threshold: 1e2,
            high_threshold: 1e6,
        };
        assert_eq!(select_precision(999_999.9, &policy), PrecisionLevel::F32);
        assert_eq!(select_precision(1e6, &policy), PrecisionLevel::F64);
    }

    // A tiny well-conditioned system should route to the reduced-precision
    // path and still hit the known solution (0.1, 0.6) within f32 accuracy.
    #[test]
    fn test_adaptive_solver_well_conditioned_uses_f32() {
        let a = array![[4.0_f64, 1.0], [2.0, 3.0]];
        let b = array![1.0_f64, 2.0];
        let config = AdaptiveSolverConfig {
            policy: PrecisionPolicy::Adaptive {
                low_threshold: 1e3,
                high_threshold: 1e7,
            },
            max_cond_iter: 5,
        };
        let solver = AdaptiveSolver::new(config);
        let result = solver.solve(&a.view(), &b).expect("Solve failed");
        assert!(
            (result.solution[0] - 0.1).abs() < 1e-3,
            "x[0]={} expected≈0.1",
            result.solution[0]
        );
        assert!(
            (result.solution[1] - 0.6).abs() < 1e-3,
            "x[1]={} expected≈0.6",
            result.solution[1]
        );
        assert!(
            matches!(
                result.precision_used,
                PrecisionLevel::F16 | PrecisionLevel::F32
            ),
            "Expected F16/F32 but got {:?}",
            result.precision_used
        );
    }

    // Hilbert(8) is far above the 1e4 threshold, so the solver must pick
    // f64 or the refined fallback, and the residual should stay small.
    #[test]
    fn test_adaptive_solver_ill_conditioned_uses_f64_or_above() {
        let n = 8usize;
        let h: Array2<f64> = Array2::from_shape_fn((n, n), |(i, j)| 1.0 / (i + j + 1) as f64);
        let b = Array1::ones(n);
        let config = AdaptiveSolverConfig {
            policy: PrecisionPolicy::Adaptive {
                low_threshold: 1e2,
                high_threshold: 1e4,
            },
            max_cond_iter: 5,
        };
        let solver = AdaptiveSolver::new(config);
        let result = solver.solve(&h.view(), &b).expect("Hilbert solve failed");
        assert!(
            matches!(
                result.precision_used,
                PrecisionLevel::F64 | PrecisionLevel::F128Fallback
            ),
            "Expected F64/F128Fallback for Hilbert(8) but got {:?} (cond={})",
            result.precision_used,
            result.cond_estimate
        );
        // Check ‖A·x − b‖₂ directly rather than comparing to a known x.
        let mut residual_norm = 0.0f64;
        for i in 0..n {
            let ax_i: f64 = (0..n).map(|j| h[[i, j]] * result.solution[j]).sum();
            residual_norm += (ax_i - b[i]).powi(2);
        }
        assert!(
            residual_norm.sqrt() < 1e-6,
            "Hilbert solve residual too large: {}",
            residual_norm.sqrt()
        );
    }

    // A Fixed(F64) policy must skip estimation and solve exactly in f64.
    #[test]
    fn test_adaptive_solver_fixed_policy() {
        let a = array![[2.0_f64, 0.0], [0.0, 3.0]];
        let b = array![4.0_f64, 9.0];
        let solver = AdaptiveSolver::new(AdaptiveSolverConfig {
            policy: PrecisionPolicy::Fixed(PrecisionLevel::F64),
            max_cond_iter: 5,
        });
        let result = solver.solve(&a.view(), &b).expect("Solve failed");
        assert_eq!(result.precision_used, PrecisionLevel::F64);
        assert!((result.solution[0] - 2.0).abs() < 1e-10);
        assert!((result.solution[1] - 3.0).abs() < 1e-10);
    }

    // Shape validation: non-square matrix is rejected.
    #[test]
    fn test_adaptive_solver_non_square_error() {
        let a: Array2<f64> = Array2::zeros((3, 4));
        let b = Array1::zeros(3);
        let solver = AdaptiveSolver::default_adaptive();
        assert!(solver.solve(&a.view(), &b).is_err());
    }

    // Shape validation: RHS length must match the matrix dimension.
    #[test]
    fn test_adaptive_solver_dimension_mismatch_error() {
        let a: Array2<f64> = Array2::eye(3);
        let b: Array1<f64> = Array1::zeros(4);
        let solver = AdaptiveSolver::default_adaptive();
        assert!(solver.solve(&a.view(), &b).is_err());
    }
}