use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::{Debug, Display};
use std::ops::{AddAssign, SubAssign};
use super::types::{classify_stability, ConditionReport, StabilityDiagnostics, StabilityLevel};
use crate::error::{InterpolateError, InterpolateResult};
pub fn assess_matrix_condition<F>(matrix: &ArrayView2<F>) -> InterpolateResult<ConditionReport<F>>
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign + 'static,
{
if matrix.nrows() != matrix.ncols() {
return Err(InterpolateError::ShapeMismatch {
expected: "square matrix".to_string(),
actual: format!("{}x{}", matrix.nrows(), matrix.ncols()),
object: "condition assessment".to_string(),
});
}
let n = matrix.nrows();
if n == 0 {
return Err(InterpolateError::InvalidInput {
message: "Cannot assess condition of empty matrix".to_string(),
});
}
let mut diagnostics = StabilityDiagnostics {
is_symmetric: check_symmetry(matrix),
..Default::default()
};
let condition_number = estimate_condition_number(matrix, &mut diagnostics)?;
let stability_level = classify_stability(condition_number);
let is_well_conditioned = matches!(
stability_level,
StabilityLevel::Excellent | StabilityLevel::Good
);
let recommended_regularization = if !is_well_conditioned {
Some(suggest_regularization(condition_number, &diagnostics))
} else {
None
};
Ok(ConditionReport {
condition_number,
is_well_conditioned,
recommended_regularization,
stability_level,
diagnostics,
})
}
pub fn check_symmetry<F>(matrix: &ArrayView2<F>) -> bool
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign,
{
if matrix.nrows() != matrix.ncols() {
return false;
}
let n = matrix.nrows();
let tolerance =
super::types::machine_epsilon::<F>() * F::from(100.0).unwrap_or_else(|| F::one());
for i in 0..n {
for j in 0..n {
if (matrix[(i, j)] - matrix[(j, i)]).abs() > tolerance {
return false;
}
}
}
true
}
pub fn estimate_condition_number<F>(
matrix: &ArrayView2<F>,
diagnostics: &mut StabilityDiagnostics<F>,
) -> InterpolateResult<F>
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign,
{
let n = matrix.nrows();
if n == 2 {
estimate_condition_2x2(matrix, diagnostics)
} else if n <= 100 {
estimate_condition_svd(matrix, diagnostics)
} else {
estimate_condition_norm_based(matrix)
}
}
pub fn estimate_condition_2x2<F>(
matrix: &ArrayView2<F>,
diagnostics: &mut StabilityDiagnostics<F>,
) -> InterpolateResult<F>
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign,
{
let a = matrix[(0, 0)];
let b = matrix[(0, 1)];
let c = matrix[(1, 0)];
let d = matrix[(1, 1)];
let det = a * d - b * c;
let eps = super::types::machine_epsilon::<F>()
* F::from(1000.0).expect("Failed to convert constant to float");
if det.abs() < eps {
return Ok(F::infinity());
}
let trace = a + d;
let discriminant =
trace * trace - F::from(4.0).expect("Failed to convert constant to float") * det;
if discriminant < F::zero() {
let frobenius_norm = (a * a + b * b + c * c + d * d).sqrt();
let frobenius_norm_inv = F::one() / det.abs() * (d * d + b * b + c * c + a * a).sqrt();
return Ok(frobenius_norm * frobenius_norm_inv);
}
let sqrt_discriminant = discriminant.sqrt();
let lambda1 =
(trace + sqrt_discriminant) / F::from(2.0).expect("Failed to convert constant to float");
let lambda2 =
(trace - sqrt_discriminant) / F::from(2.0).expect("Failed to convert constant to float");
let max_singular = lambda1.abs().max(lambda2.abs());
let min_singular = lambda1.abs().min(lambda2.abs());
if min_singular < eps {
return Ok(F::infinity());
}
diagnostics.max_singular_value = Some(max_singular);
diagnostics.min_singular_value = Some(min_singular);
diagnostics.estimated_rank = Some(if min_singular > eps { 2 } else { 1 });
Ok(max_singular / min_singular)
}
pub fn estimate_condition_svd<F>(
matrix: &ArrayView2<F>,
diagnostics: &mut StabilityDiagnostics<F>,
) -> InterpolateResult<F>
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign,
{
let n = matrix.nrows();
let max_singular = power_iteration_max_eigenvalue(matrix)?;
let min_singular = inverse_power_iteration_min_eigenvalue(matrix)?;
diagnostics.max_singular_value = Some(max_singular);
diagnostics.min_singular_value = Some(min_singular);
let rank_threshold = super::types::machine_epsilon::<F>()
* max_singular
* F::from(n).expect("Failed to convert to float");
diagnostics.estimated_rank = Some(if min_singular > rank_threshold {
n
} else {
n - 1
});
if diagnostics.is_symmetric {
diagnostics.is_positive_definite = Some(min_singular > F::zero());
}
if min_singular > F::zero() {
Ok(max_singular / min_singular)
} else {
Ok(F::infinity())
}
}
pub fn estimate_condition_norm_based<F>(matrix: &ArrayView2<F>) -> InterpolateResult<F>
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign,
{
let norm = matrix_1_norm(matrix);
let inv_norm = estimate_inverse_norm(matrix)?;
Ok(norm * inv_norm)
}
fn power_iteration_max_eigenvalue<F>(matrix: &ArrayView2<F>) -> InterpolateResult<F>
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign,
{
let n = matrix.nrows();
let max_iterations = 50;
let tolerance = F::from(1e-6).unwrap_or_else(|| super::types::machine_epsilon::<F>());
let mut v = Array1::from_elem(
n,
F::one() / F::from(n).expect("Failed to convert to float"),
);
let mut eigenvalue = F::zero();
for _ in 0..max_iterations {
let mut av = Array1::zeros(n);
for i in 0..n {
for j in 0..n {
av[i] += matrix[(i, j)] * v[j];
}
}
let mut atav = Array1::zeros(n);
for i in 0..n {
for j in 0..n {
atav[i] += matrix[(j, i)] * av[j];
}
}
let numerator = v
.iter()
.zip(atav.iter())
.map(|(x, y)| *x * *y)
.fold(F::zero(), |acc, x| acc + x);
let denominator = v.iter().map(|x| *x * *x).fold(F::zero(), |acc, x| acc + x);
if denominator < super::types::machine_epsilon::<F>() {
return Err(InterpolateError::NumericalInstability {
message: "Power iteration failed: zero denominator".to_string(),
});
}
let new_eigenvalue = numerator / denominator;
if (new_eigenvalue - eigenvalue).abs() < tolerance * eigenvalue.abs() {
return Ok(new_eigenvalue.sqrt());
}
eigenvalue = new_eigenvalue;
let norm = denominator.sqrt();
for x in v.iter_mut() {
*x = *x / norm;
}
}
Ok(eigenvalue.sqrt())
}
fn inverse_power_iteration_min_eigenvalue<F>(matrix: &ArrayView2<F>) -> InterpolateResult<F>
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign,
{
let n = matrix.nrows();
let mut min_diag = F::infinity();
for i in 0..n {
if matrix[(i, i)].abs() < min_diag {
min_diag = matrix[(i, i)].abs();
}
}
Ok(min_diag.max(super::types::machine_epsilon::<F>()))
}
fn matrix_1_norm<F>(matrix: &ArrayView2<F>) -> F
where
F: Float + FromPrimitive + AddAssign,
{
let mut max_col_sum = F::zero();
for j in 0..matrix.ncols() {
let mut col_sum = F::zero();
for i in 0..matrix.nrows() {
col_sum += matrix[(i, j)].abs();
}
if col_sum > max_col_sum {
max_col_sum = col_sum;
}
}
max_col_sum
}
fn estimate_inverse_norm<F>(matrix: &ArrayView2<F>) -> InterpolateResult<F>
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign,
{
let n = matrix.nrows();
let mut min_diag = F::infinity();
for i in 0..n {
let diag_val = matrix[(i, i)].abs();
if diag_val > F::zero() && diag_val < min_diag {
min_diag = diag_val;
}
}
if min_diag == F::infinity() || min_diag <= super::types::machine_epsilon::<F>() {
Ok(F::infinity())
} else {
Ok(F::one() / min_diag)
}
}
pub fn suggest_regularization<F>(condition_number: F, diagnostics: &StabilityDiagnostics<F>) -> F
where
F: Float + FromPrimitive + Debug + Display + AddAssign + SubAssign,
{
let machine_eps = diagnostics.machine_epsilon;
let base_regularization = if condition_number.is_infinite() {
machine_eps
* F::from(1e6)
.unwrap_or_else(|| F::from(1e6).expect("Failed to convert constant to float"))
} else {
machine_eps * condition_number.sqrt()
};
if let Some(min_sv) = diagnostics.min_singular_value {
if let Some(max_sv) = diagnostics.max_singular_value {
let geometric_mean = (min_sv * max_sv).sqrt();
return base_regularization.min(geometric_mean * machine_eps);
}
}
base_regularization
}
pub fn check_diagonal_dominance<F>(matrix: &ArrayView2<F>) -> bool
where
F: Float + FromPrimitive + AddAssign,
{
let n = matrix.nrows();
for i in 0..n {
let diag_val = matrix[(i, i)].abs();
let mut off_diag_sum = F::zero();
for j in 0..n {
if i != j {
off_diag_sum += matrix[(i, j)].abs();
}
}
if diag_val <= off_diag_sum {
return false;
}
}
true
}
pub fn count_zero_diagonal_elements<F>(matrix: &ArrayView2<F>) -> usize
where
F: Float + FromPrimitive,
{
let n = matrix.nrows();
let tolerance =
super::types::machine_epsilon::<F>() * F::from(100.0).unwrap_or_else(|| F::one());
let mut count = 0;
for i in 0..n {
if matrix[(i, i)].abs() < tolerance {
count += 1;
}
}
count
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
#[test]
fn test_assess_identity_matrix() {
let matrix = Array2::<f64>::eye(3);
let report = assess_matrix_condition(&matrix.view()).expect("Operation failed");
assert!(report.is_well_conditioned);
assert_eq!(report.stability_level, StabilityLevel::Excellent);
assert!(report.condition_number > 0.0);
assert!(report.condition_number < 10.0); }
#[test]
fn test_symmetry_check() {
let symmetric =
Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 2.0, 3.0]).expect("Operation failed");
let asymmetric =
Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("Operation failed");
assert!(check_symmetry(&symmetric.view()));
assert!(!check_symmetry(&asymmetric.view()));
}
#[test]
fn test_diagonal_dominance() {
let dominant =
Array2::from_shape_vec((2, 2), vec![3.0, 1.0, 1.0, 3.0]).expect("Operation failed");
let non_dominant =
Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 2.0, 1.0]).expect("Operation failed");
assert!(check_diagonal_dominance(&dominant.view()));
assert!(!check_diagonal_dominance(&non_dominant.view()));
}
#[test]
fn test_zero_diagonal_count() {
let matrix =
Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 0.0, 6.0, 7.0, 8.0, 0.0])
.expect("Operation failed");
assert_eq!(count_zero_diagonal_elements(&matrix.view()), 2);
}
#[test]
fn test_matrix_1_norm() {
let matrix =
Array2::from_shape_vec((2, 2), vec![1.0, -2.0, 3.0, -4.0]).expect("Operation failed");
let norm = matrix_1_norm(&matrix.view());
assert!((norm - 6.0).abs() < 1e-10);
}
#[test]
fn test_power_iteration() {
let matrix = Array2::<f64>::eye(3) * 2.0; let eigenvalue = power_iteration_max_eigenvalue(&matrix.view()).expect("Operation failed");
assert!((eigenvalue - 2.0).abs() < 1e-6);
}
}