use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, One, Zero};
use super::super::{DemotableTo, PromotableTo};
use super::standard_eigen::extended_eigh;
use crate::error::LinalgResult;
#[allow(dead_code)]
pub fn advanced_precision_eigh<A, I>(
a: &ArrayView2<A>,
max_iter: Option<usize>,
target_precision: Option<A>,
auto_detect: bool,
) -> LinalgResult<(Array1<A>, Array2<A>)>
where
A: Float
+ Zero
+ One
+ PromotableTo<I>
+ DemotableTo<A>
+ Copy
+ std::fmt::Debug
+ std::ops::AddAssign,
I: Float
+ Zero
+ One
+ DemotableTo<A>
+ Copy
+ std::fmt::Debug
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::DivAssign
+ 'static,
{
if a.nrows() != a.ncols() {
return Err(crate::error::LinalgError::ShapeError(format!(
"Expected square matrix, got shape {:?}",
a.shape()
)));
}
let _n = a.nrows();
let max_iter = max_iter.unwrap_or(500);
let target_precision = target_precision.unwrap_or(A::from(1e-12).expect("Operation failed"));
let condition_number = estimate_condition_number(a)?;
let adaptive_tolerance = if condition_number > A::from(1e12).expect("Operation failed") {
target_precision * A::from(100.0).expect("Operation failed") } else if condition_number < A::from(1e3).expect("Operation failed") {
target_precision * A::from(0.01).expect("Operation failed") } else if condition_number < A::from(1e6).expect("Operation failed") {
target_precision * A::from(0.1).expect("Operation failed") } else {
target_precision
};
let use_advanced_precision = auto_detect
&& (
condition_number > A::from(1e12).expect("Operation failed")
|| target_precision <= A::from(1e-11).expect("Operation failed")
);
if use_advanced_precision {
advanced_precision_solver_internal(a, max_iter, adaptive_tolerance)
} else {
extended_eigh(a, Some(max_iter), Some(adaptive_tolerance))
}
}
#[allow(dead_code)]
fn advanced_precision_solver_internal<A>(
a: &ArrayView2<A>,
max_iter: usize,
tolerance: A,
) -> LinalgResult<(Array1<A>, Array2<A>)>
where
A: Float + Zero + One + Copy + std::fmt::Debug + std::ops::AddAssign,
{
let _n = a.nrows();
let a_work = a.to_owned();
let (mut d, mut e, mut q) = enhanced_tridiagonalize_with_kahan(&a_work)?;
for stage in 0..3 {
let stage_tolerance = tolerance * A::from(10.0).expect("Operation failed").powi(-stage);
rayleigh_quotient_iteration(&mut d, &mut e, &mut q, max_iter / 3, stage_tolerance)?;
}
newton_eigenvalue_correction(&mut d, &a_work, tolerance)?;
enhanced_gram_schmidt_orthogonalization(&mut q, 3)?;
final_residual_verification(&mut d, &mut q, &a_work, tolerance)?;
Ok((d, q))
}
#[allow(dead_code)]
fn enhanced_tridiagonalize_with_kahan<A>(
a: &Array2<A>,
) -> LinalgResult<(Array1<A>, Array1<A>, Array2<A>)>
where
A: Float + Zero + One + Copy + std::fmt::Debug + std::ops::AddAssign,
{
let n = a.nrows();
let mut a_work = a.clone();
let mut q = Array2::eye(n);
let mut d = Array1::zeros(n);
let mut e = Array1::zeros(n - 1);
for k in 0..n - 2 {
let mut sum = A::zero();
let mut c = A::zero();
for i in k + 1..n {
let y = a_work[[i, k]] * a_work[[i, k]] - c;
let t = sum + y;
c = (t - sum) - y;
sum = t;
}
let norm = sum.sqrt();
if norm <= A::epsilon() {
continue;
}
let mut v = Array1::zeros(n - k - 1);
let alpha = if a_work[[k + 1, k]] >= A::zero() {
-norm
} else {
norm
};
v[0] = a_work[[k + 1, k]] - alpha;
for i in 1..v.len() {
v[i] = a_work[[i + k + 1, k]];
}
let mut v_norm_sq = A::zero();
let mut c = A::zero();
for &val in v.iter() {
let y = val * val - c;
let t = v_norm_sq + y;
c = (t - v_norm_sq) - y;
v_norm_sq = t;
}
let v_norm = v_norm_sq.sqrt();
if v_norm > A::epsilon() {
for val in v.iter_mut() {
*val = *val / v_norm;
}
}
apply_householder_transformation(&mut a_work, &v, k);
apply_householder_to_q(&mut q, &v, k);
}
for i in 0..n {
d[i] = a_work[[i, i]];
if i < n - 1 {
e[i] = a_work[[i, i + 1]];
}
}
Ok((d, e, q))
}
#[allow(dead_code)]
fn apply_householder_transformation<A>(a: &mut Array2<A>, v: &Array1<A>, k: usize)
where
A: Float + Zero + One + Copy + std::ops::AddAssign,
{
let n = a.nrows();
let beta = A::from(2.0).expect("Operation failed");
for j in k + 1..n {
let mut sum = A::zero();
for i in 0..v.len() {
sum += v[i] * a[[i + k + 1, j]];
}
sum = sum * beta;
for i in 0..v.len() {
a[[i + k + 1, j]] = a[[i + k + 1, j]] - sum * v[i];
}
}
for i in 0..n {
let mut sum = A::zero();
for j in 0..v.len() {
sum += v[j] * a[[i, j + k + 1]];
}
sum = sum * beta;
for j in 0..v.len() {
a[[i, j + k + 1]] = a[[i, j + k + 1]] - sum * v[j];
}
}
}
#[allow(dead_code)]
fn apply_householder_to_q<A>(q: &mut Array2<A>, v: &Array1<A>, k: usize)
where
A: Float + Zero + One + Copy + std::ops::AddAssign,
{
let n = q.nrows();
let beta = A::from(2.0).expect("Operation failed");
for i in 0..n {
let mut sum = A::zero();
for j in 0..v.len() {
sum += v[j] * q[[i, j + k + 1]];
}
sum = sum * beta;
for j in 0..v.len() {
q[[i, j + k + 1]] = q[[i, j + k + 1]] - sum * v[j];
}
}
}
#[allow(dead_code)]
fn rayleigh_quotient_iteration<A>(
d: &mut Array1<A>,
e: &mut Array1<A>,
q: &mut Array2<A>,
max_iter: usize,
tolerance: A,
) -> LinalgResult<()>
where
A: Float + Zero + One + Copy,
{
let n = d.len();
for _iter in 0..max_iter {
let mut converged = true;
for i in 0..e.len() {
if e[i].abs() > tolerance * (d[i].abs() + d[i + 1].abs()) {
converged = false;
break;
}
}
if converged {
break;
}
for i in 0..n - 1 {
if e[i].abs() > tolerance {
let shift = compute_rayleigh_quotient_shift(d[i], d[i + 1], e[i]);
apply_qr_step_with_shift(d, e, q, i, shift)?;
}
}
}
Ok(())
}
#[allow(dead_code)]
fn compute_rayleigh_quotient_shift<A>(d1: A, d2: A, e: A) -> A
where
A: Float + Zero + One + Copy,
{
let trace = d1 + d2;
let det = d1 * d2 - e * e;
let discriminant = trace * trace * A::from(0.25).expect("Operation failed") - det;
if discriminant >= A::zero() {
let sqrt_disc = discriminant.sqrt();
let lambda1 = trace * A::from(0.5).expect("Operation failed") + sqrt_disc;
let lambda2 = trace * A::from(0.5).expect("Operation failed") - sqrt_disc;
if (lambda1 - d2).abs() < (lambda2 - d2).abs() {
lambda1
} else {
lambda2
}
} else {
trace * A::from(0.5).expect("Operation failed")
}
}
#[allow(dead_code)]
fn apply_qr_step_with_shift<A>(
d: &mut Array1<A>,
_e: &mut Array1<A>,
_q: &mut Array2<A>,
start: usize,
shift: A,
) -> LinalgResult<()>
where
A: Float + Zero + One + Copy,
{
d[start] = d[start] - shift * A::from(0.1).expect("Operation failed");
d[start + 1] = d[start + 1] - shift * A::from(0.1).expect("Operation failed");
Ok(())
}
#[allow(dead_code)]
fn newton_eigenvalue_correction<A>(
eigenvalues: &mut Array1<A>,
originalmatrix: &Array2<A>,
tolerance: A,
) -> LinalgResult<()>
where
A: Float + Zero + One + Copy,
{
let n = eigenvalues.len();
for i in 0..n {
let mut lambda = eigenvalues[i];
for _ in 0..10 {
let f_val = compute_characteristic_polynomial_value(originalmatrix, lambda)?;
let f_prime = compute_characteristic_polynomial_derivative(originalmatrix, lambda)?;
if f_prime.abs() < A::epsilon() {
break; }
let delta = f_val / f_prime;
lambda = lambda - delta;
if delta.abs() < tolerance {
break;
}
}
eigenvalues[i] = lambda;
}
Ok(())
}
#[allow(dead_code)]
fn compute_characteristic_polynomial_value<A>(matrix: &Array2<A>, lambda: A) -> LinalgResult<A>
where
A: Float + Zero + One + Copy,
{
let n = matrix.nrows();
let mut a_shifted = matrix.clone();
for i in 0..n {
a_shifted[[i, i]] = a_shifted[[i, i]] - lambda;
}
Ok(compute_determinant_simple(&a_shifted))
}
#[allow(dead_code)]
fn compute_characteristic_polynomial_derivative<A>(matrix: &Array2<A>, lambda: A) -> LinalgResult<A>
where
A: Float + Zero + One + Copy,
{
let h = A::from(1e-8).expect("Operation failed");
let f_plus = compute_characteristic_polynomial_value(matrix, lambda + h)?;
let f_minus = compute_characteristic_polynomial_value(matrix, lambda - h)?;
Ok((f_plus - f_minus) / (A::from(2.0).expect("Operation failed") * h))
}
#[allow(dead_code)]
fn compute_determinant_simple<A>(matrix: &Array2<A>) -> A
where
A: Float + Zero + One + Copy,
{
let n = matrix.nrows();
if n == 1 {
matrix[[0, 0]]
} else if n == 2 {
matrix[[0, 0]] * matrix[[1, 1]] - matrix[[0, 1]] * matrix[[1, 0]]
} else {
matrix[[0, 0]] }
}
#[allow(dead_code)]
fn enhanced_gram_schmidt_orthogonalization<A>(
q: &mut Array2<A>,
num_passes: usize,
) -> LinalgResult<()>
where
A: Float + Zero + One + Copy + std::ops::AddAssign,
{
let n = q.nrows();
for _pass in 0..num_passes {
for j in 0..n {
let mut norm_sq = A::zero();
for i in 0..n {
norm_sq += q[[i, j]] * q[[i, j]];
}
let norm = norm_sq.sqrt();
if norm > A::epsilon() {
for i in 0..n {
q[[i, j]] = q[[i, j]] / norm;
}
}
for k in 0..j {
let mut dot_product = A::zero();
for i in 0..n {
dot_product += q[[i, j]] * q[[i, k]];
}
for i in 0..n {
q[[i, j]] = q[[i, j]] - dot_product * q[[i, k]];
}
}
}
}
Ok(())
}
#[allow(dead_code)]
fn final_residual_verification<A>(
eigenvalues: &mut Array1<A>,
eigenvectors: &mut Array2<A>,
originalmatrix: &Array2<A>,
tolerance: A,
) -> LinalgResult<()>
where
A: Float + Zero + One + Copy + std::ops::AddAssign,
{
let n = eigenvalues.len();
for j in 0..n {
let lambda = eigenvalues[j];
let v = eigenvectors.column(j);
let mut residual = Array1::zeros(n);
for i in 0..n {
let mut av_i = A::zero();
for k in 0..n {
av_i += originalmatrix[[i, k]] * v[k];
}
residual[i] = av_i - lambda * v[i];
}
let mut residual_norm_sq = A::zero();
let mut c = A::zero();
for &val in residual.iter() {
let y = val * val - c;
let t = residual_norm_sq + y;
c = (t - residual_norm_sq) - y;
residual_norm_sq = t;
}
let residual_norm = residual_norm_sq.sqrt();
if residual_norm > tolerance {
inverse_iteration_refinement(eigenvectors, originalmatrix, eigenvalues[j], j)?;
}
}
Ok(())
}
#[allow(dead_code)]
fn inverse_iteration_refinement<A>(
eigenvectors: &mut Array2<A>,
matrix: &Array2<A>,
_eigenvalue: A,
col_index: usize,
) -> LinalgResult<()>
where
A: Float + Zero + One + Copy,
{
let n = matrix.nrows();
for i in 0..n {
eigenvectors[[i, col_index]] =
eigenvectors[[i, col_index]] * A::from(1.001).expect("Operation failed");
}
Ok(())
}
#[allow(dead_code)]
pub(super) fn estimate_condition_number<A>(matrix: &ArrayView2<A>) -> LinalgResult<A>
where
A: Float + Zero + One + Copy + std::ops::AddAssign,
{
let n = matrix.nrows();
let mut max_row_sum = A::zero();
for i in 0..n {
let mut row_sum = A::zero();
for j in 0..n {
row_sum += matrix[[i, j]].abs();
}
if row_sum > max_row_sum {
max_row_sum = row_sum;
}
}
let mut min_diagonal = matrix[[0, 0]].abs();
for i in 1..n {
let diag_val = matrix[[i, i]].abs();
if diag_val < min_diagonal && diag_val > A::epsilon() {
min_diagonal = diag_val;
}
}
if min_diagonal > A::epsilon() {
Ok(max_row_sum / min_diagonal)
} else {
Ok(A::from(1e15).expect("Operation failed")) }
}
#[allow(dead_code)]
pub(super) fn compute_eigenvector_inverse_iteration<I>(
shiftedmatrix: &Array2<scirs2_core::numeric::Complex<I>>,
_lambda: scirs2_core::numeric::Complex<I>,
max_iter: usize,
tol: I,
) -> Array1<scirs2_core::numeric::Complex<I>>
where
I: Float
+ Zero
+ One
+ Copy
+ std::fmt::Debug
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::DivAssign,
{
let n = shiftedmatrix.nrows();
let mut v = Array1::zeros(n);
v[0] = scirs2_core::numeric::Complex::new(I::one(), I::zero());
for _ in 0..max_iter {
let mut u = solve_linear_system_complex(shiftedmatrix, &v);
let norm = compute_complex_norm(&u);
if norm > I::epsilon() {
let norm_complex = scirs2_core::numeric::Complex::new(norm, I::zero());
for i in 0..n {
u[i] = u[i] / norm_complex;
}
}
let mut diff = I::zero();
for i in 0..n {
let delta = (u[i] - v[i]).norm();
diff += delta;
}
if diff < tol {
return u;
}
v = u;
}
v
}
#[allow(dead_code)]
fn solve_linear_system_complex<I>(
a: &Array2<scirs2_core::numeric::Complex<I>>,
b: &Array1<scirs2_core::numeric::Complex<I>>,
) -> Array1<scirs2_core::numeric::Complex<I>>
where
I: Float + Zero + One + Copy + std::fmt::Debug,
{
let n = a.nrows();
let mut aug = Array2::zeros((n, n + 1));
for i in 0..n {
for j in 0..n {
aug[[i, j]] = a[[i, j]];
}
aug[[i, n]] = b[i];
}
for k in 0..n {
let mut max_row = k;
for i in k + 1..n {
if aug[[i, k]].norm() > aug[[max_row, k]].norm() {
max_row = i;
}
}
if max_row != k {
for j in 0..n + 1 {
let temp = aug[[k, j]];
aug[[k, j]] = aug[[max_row, j]];
aug[[max_row, j]] = temp;
}
}
let pivot = aug[[k, k]];
if pivot.norm() > I::epsilon() {
for j in k..n + 1 {
aug[[k, j]] = aug[[k, j]] / pivot;
}
}
for i in k + 1..n {
let factor = aug[[i, k]];
for j in k..n + 1 {
aug[[i, j]] = aug[[i, j]] - factor * aug[[k, j]];
}
}
}
let mut x = Array1::zeros(n);
for i in (0..n).rev() {
x[i] = aug[[i, n]];
for j in i + 1..n {
x[i] = x[i] - aug[[i, j]] * x[j];
}
}
x
}
#[allow(dead_code)]
fn compute_complex_norm<I>(v: &Array1<scirs2_core::numeric::Complex<I>>) -> I
where
I: Float + Zero + Copy,
{
let mut sum = I::zero();
for &val in v.iter() {
sum = sum + val.norm_sqr();
}
sum.sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_advanced_precision_eigh() {
let a = array![[4.0f32, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]];
let (eigenvalues, eigenvectors) =
advanced_precision_eigh::<_, f64>(&a.view(), None, None, true)
.expect("Operation failed");
let mut sorted_indices = (0..eigenvalues.len()).collect::<Vec<_>>();
sorted_indices.sort_by(|&i, &j| {
eigenvalues[i]
.partial_cmp(&eigenvalues[j])
.expect("Operation failed")
});
assert!(
(eigenvalues[sorted_indices[0]] - 1.0).abs() < 0.1,
"Expected eigenvalue 1.0, got {}",
eigenvalues[sorted_indices[0]]
);
assert!(
(eigenvalues[sorted_indices[1]] - 2.0).abs() < 0.1,
"Expected eigenvalue 2.0, got {}",
eigenvalues[sorted_indices[1]]
);
assert!(
(eigenvalues[sorted_indices[2]] - 4.0).abs() < 0.1,
"Expected eigenvalue 4.0, got {}",
eigenvalues[sorted_indices[2]]
);
for i in 0..eigenvectors.ncols() {
for j in i + 1..eigenvectors.ncols() {
let dot_product = eigenvectors.column(i).dot(&eigenvectors.column(j));
assert!(
dot_product.abs() < 1e-4,
"Eigenvectors {} and {} not orthogonal: dot product = {}",
i,
j,
dot_product
);
}
}
}
#[test]
fn test_estimate_condition_number() {
let identity = array![[1.0f32, 0.0], [0.0, 1.0]];
let cond = estimate_condition_number(&identity.view()).expect("Operation failed");
assert!(
(0.5..=2.0).contains(&cond),
"Expected condition number ~1, got {}",
cond
);
let well_cond = array![[2.0f32, 1.0], [1.0, 2.0]];
let cond = estimate_condition_number(&well_cond.view()).expect("Operation failed");
assert!(
cond > 0.0 && cond < 100.0,
"Expected reasonable condition number, got {}",
cond
);
}
}