pub mod feast;
pub mod generalized;
pub mod iterative;
pub mod sparse;
pub mod standard;
pub mod zolotarev;
use crate::error::{LinalgError, LinalgResult};
pub use standard::EigenResult;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
pub use generalized::{eig_gen, eigh_gen, eigvals_gen, eigvalsh_gen};
pub use standard::{eig, eigh, eigvals, power_iteration};
pub use sparse::{arnoldi, eigs_gen, lanczos, svds};
pub use iterative::{
arnoldi as arnoldi_dense, inverse_power_iteration, jacobi_eigenvalue, lanczos as lanczos_dense,
power_iteration_dense, rayleigh_quotient_iteration, ArnoldiResult, LanczosResult,
};
#[allow(dead_code)]
/// Eigenvalues of a symmetric/Hermitian matrix.
///
/// Convenience wrapper around [`eigh`] that discards the eigenvectors and
/// returns only the eigenvalue vector. `workers` is forwarded unchanged.
pub fn eigvalsh<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    eigh(a, workers).map(|(values, _vectors)| values)
}
#[allow(dead_code)]
/// Eigendecomposition (values and vectors) with enhanced numerical precision.
///
/// * `a` – square input matrix; a non-square input yields `ShapeError`.
/// * `tolerance` – threshold used both for the symmetry classification and
///   for the downstream residual refinement.
///
/// Strategy:
/// 1. 1x1 and 2x2 matrices with a real spectrum are solved analytically.
/// 2. Otherwise the matrix is classified as symmetric (off-diagonal pairs
///    within `tolerance`) or general and dispatched to the matching solver.
pub fn advanced_precision_eig<F>(
    a: &ArrayView2<F>,
    tolerance: F,
) -> LinalgResult<(Array1<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let n = a.nrows();
    if n != a.ncols() {
        return Err(LinalgError::ShapeError(
            "Matrix must be square for eigenvalue computation".to_string(),
        ));
    }
    if n == 1 {
        // 1x1: the single entry is the eigenvalue; the eigenvector is [1].
        let eigenvalue = a[[0, 0]];
        let eigenvector = Array2::from_elem((1, 1), F::one());
        return Ok((Array1::from_elem(1, eigenvalue), eigenvector));
    }
    if n == 2 {
        // 2x2: roots of the characteristic polynomial λ² − tr(A)·λ + det(A) = 0.
        let a11 = a[[0, 0]];
        let a12 = a[[0, 1]];
        let a21 = a[[1, 0]];
        let a22 = a[[1, 1]];
        let trace = a11 + a22;
        let det = a11 * a22 - a12 * a21;
        let four = F::from(4.0).ok_or_else(|| {
            LinalgError::ComputationError("Failed to convert 4.0 to target type".to_string())
        })?;
        let discriminant = trace * trace - four * det;
        // Real spectrum only. A negative discriminant (complex pair — only
        // possible for a non-symmetric input) falls through to the general
        // dispatch below.
        if discriminant >= F::zero() {
            let sqrt_disc = discriminant.sqrt();
            let two = F::from(2.0).ok_or_else(|| {
                LinalgError::ComputationError("Failed to convert 2.0 to target type".to_string())
            })?;
            let lambda1 = (trace + sqrt_disc) / two;
            let lambda2 = (trace - sqrt_disc) / two;
            let mut eigenvectors = Array2::zeros((2, 2));
            // Eigenvector from the first row of (A − λI)v = 0: v = (a12, λ − a11).
            // When that row is numerically zero, fall back to a coordinate axis.
            // NOTE(review): for a lower-triangular non-symmetric input (a12 == 0,
            // a21 != 0) the axis fallback is only approximate — confirm intent.
            if (a11 - lambda1).abs() > tolerance || a12.abs() > tolerance {
                let v1_1 = a12;
                let v1_2 = lambda1 - a11;
                let norm1 = (v1_1 * v1_1 + v1_2 * v1_2).sqrt();
                eigenvectors[[0, 0]] = v1_1 / norm1;
                eigenvectors[[1, 0]] = v1_2 / norm1;
            } else {
                eigenvectors[[0, 0]] = F::one();
                eigenvectors[[1, 0]] = F::zero();
            }
            if (a11 - lambda2).abs() > tolerance || a12.abs() > tolerance {
                let v2_1 = a12;
                let v2_2 = lambda2 - a11;
                let norm2 = (v2_1 * v2_1 + v2_2 * v2_2).sqrt();
                eigenvectors[[0, 1]] = v2_1 / norm2;
                eigenvectors[[1, 1]] = v2_2 / norm2;
            } else {
                eigenvectors[[0, 1]] = F::zero();
                eigenvectors[[1, 1]] = F::one();
            }
            return Ok((Array1::from_vec(vec![lambda1, lambda2]), eigenvectors));
        }
    }
    // Classify by symmetry (each off-diagonal pair within `tolerance`) and
    // dispatch to the appropriate refined solver.
    let mut is_symmetric = true;
    for i in 0..n {
        for j in i + 1..n {
            if (a[[i, j]] - a[[j, i]]).abs() > tolerance {
                is_symmetric = false;
                break;
            }
        }
        if !is_symmetric {
            break;
        }
    }
    if is_symmetric {
        advanced_precision_symmetric_eigensolver(a, tolerance)
    } else {
        advanced_precision_general_eigensolver(a, tolerance)
    }
}
#[allow(dead_code)]
/// Iterative precision refinement of a symmetric eigendecomposition.
///
/// Starts from the standard `eigh` result, then for up to 50 passes:
/// * recomputes each residual and Rayleigh quotient with compensated
///   (Kahan) arithmetic,
/// * applies a Newton correction to each eigenvalue,
/// * re-runs one step of regularized inverse iteration on eigenvectors
///   whose residual exceeds the adaptive tolerance,
/// * re-orthogonalizes the eigenvector basis after any change.
///
/// Convergence is declared when residuals are below tolerance and a full
/// pass makes no further improvement; otherwise a warning is printed and
/// the best-effort result is still returned. Output is sorted by
/// descending |λ|.
fn advanced_precision_symmetric_eigensolver<F>(
    a: &ArrayView2<F>,
    tolerance: F,
) -> LinalgResult<(Array1<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    let n = a.nrows();
    // Initial decomposition from the standard symmetric solver.
    let (mut eigenvalues, mut eigenvectors) = eigh(a, None)?;
    // Tighten the target tolerance 10x for very ill-conditioned matrices.
    let condition_estimate = estimate_condition_number(a);
    let adaptive_tolerance = if condition_estimate > F::from(1e12).expect("Operation failed") {
        tolerance * F::from(0.1).expect("Operation failed")
    } else {
        tolerance
    };
    let max_iterations = 50;
    let mut converged = false;
    for iter in 0..max_iterations {
        let mut max_residual = F::zero();
        let mut improvement_made = false;
        for i in 0..n {
            let v = eigenvectors.column(i);
            let lambda = eigenvalues[i];
            // Residual r = A·v − λ·v, computed with compensated arithmetic.
            let av = kahanmatrix_vector_product(a, &v);
            let lambda_v = v.mapv(|x| x * lambda);
            let residual = kahan_vector_subtraction(&av, &lambda_v);
            let residual_norm = kahan_dot_product(&residual, &residual).sqrt();
            if residual_norm > max_residual {
                max_residual = residual_norm;
            }
            // Rayleigh quotient vᵀAv / vᵀv is the best eigenvalue estimate
            // for the current vector; refine it with a Newton correction.
            let vt_av = kahan_dot_product(&v.to_owned(), &av);
            let vt_v = kahan_dot_product(&v.to_owned(), &v.to_owned());
            if vt_v > F::epsilon() {
                let new_eigenvalue = vt_av / vt_v;
                let correction = newton_eigenvalue_correction(
                    a,
                    &v.to_owned(),
                    new_eigenvalue,
                    adaptive_tolerance,
                );
                let corrected_eigenvalue = new_eigenvalue + correction;
                if (corrected_eigenvalue - eigenvalues[i]).abs() > F::epsilon() {
                    eigenvalues[i] = corrected_eigenvalue;
                    improvement_made = true;
                }
            }
            // Vector still inaccurate: one step of regularized inverse iteration.
            if residual_norm > adaptive_tolerance {
                let refined_vector =
                    enhanced_inverse_iteration(a, lambda, &v.to_owned(), adaptive_tolerance)?;
                eigenvectors.column_mut(i).assign(&refined_vector);
                improvement_made = true;
            }
        }
        // Per-column updates may have broken orthogonality; restore it.
        if improvement_made {
            enhanced_gram_schmidt_orthogonalization(&mut eigenvectors, adaptive_tolerance);
        }
        if max_residual < adaptive_tolerance && !improvement_made {
            converged = true;
            break;
        }
        // Every 5th pass, accept early if all residuals verify below tolerance.
        if iter % 5 == 4 {
            let verification_passed =
                verify_eigenvalue_accuracy(a, &eigenvalues, &eigenvectors, adaptive_tolerance);
            if verification_passed {
                converged = true;
                break;
            }
        }
    }
    if !converged {
        // Best-effort: warn rather than fail so callers still get a result.
        eprintln!("Warning: Advanced-precision eigenvalue solver did not fully converge to desired tolerance");
    }
    let (sorted_eigenvalues, sorted_eigenvectors) =
        sort_eigenvalues_and_vectors(eigenvalues, eigenvectors);
    Ok((sorted_eigenvalues, sorted_eigenvectors))
}
#[allow(dead_code)]
/// Precision solver entry point for matrices that failed the strict
/// symmetry test.
///
/// Re-tests symmetry with a 10x relaxed tolerance: matrices that are
/// symmetric up to small noise are still routed to the refined symmetric
/// solver. Genuinely non-symmetric input falls back to the standard
/// solver with a warning (complex spectra are not supported here).
fn advanced_precision_general_eigensolver<F>(
    a: &ArrayView2<F>,
    tolerance: F,
) -> LinalgResult<(Array1<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let dim = a.nrows();
    let symmetry_tolerance = tolerance * F::from(10.0).expect("Operation failed");
    // Relaxed symmetry check over the upper triangle.
    let mut nearly_symmetric = true;
    'scan: for row in 0..dim {
        for col in row + 1..dim {
            if (a[[row, col]] - a[[col, row]]).abs() > symmetry_tolerance {
                nearly_symmetric = false;
                break 'scan;
            }
        }
    }
    if nearly_symmetric {
        advanced_precision_symmetric_eigensolver(a, tolerance)
    } else {
        eprintln!("Warning: Advanced-precision solver for general non-symmetric matrices is limited. Using standard solver.");
        eigh(a, None)
    }
}
#[allow(dead_code)]
/// Matrix-vector product `A·v` with one compensated (Kahan) summation per
/// output row, reducing rounding error versus a naive accumulation.
fn kahanmatrix_vector_product<F>(a: &ArrayView2<F>, v: &ArrayView1<F>) -> Array1<F>
where
    F: Float + Sum,
{
    let rows = a.nrows();
    let cols = a.ncols();
    let mut out = Array1::zeros(rows);
    for row in 0..rows {
        // Kahan running sum: `err` carries the low-order bits lost by `acc`.
        let mut acc = F::zero();
        let mut err = F::zero();
        for col in 0..cols {
            let y = a[[row, col]] * v[col] - err;
            let t = acc + y;
            err = (t - acc) - y;
            acc = t;
        }
        out[row] = acc;
    }
    out
}
#[allow(dead_code)]
/// Element-wise vector subtraction `a − b`.
///
/// The previous implementation maintained a per-element "Kahan
/// compensation" array, but each output element is the result of a single
/// subtraction, so every compensation term was provably zero — the array
/// was dead weight (an extra allocation plus three wasted FP ops per
/// element). Removing it produces the same values as before.
///
/// Panics if `b` is shorter than `a` (same indexing behavior as the
/// original).
fn kahan_vector_subtraction<F>(a: &Array1<F>, b: &Array1<F>) -> Array1<F>
where
    F: Float,
{
    Array1::from_shape_fn(a.len(), |i| a[i] - b[i])
}
#[allow(dead_code)]
/// Dot product `a·b` using compensated (Kahan) summation.
///
/// Iterates over `zip`, so the result covers `min(a.len(), b.len())` terms.
fn kahan_dot_product<F>(a: &Array1<F>, b: &Array1<F>) -> F
where
    F: Float + Sum,
{
    let mut acc = F::zero();
    let mut err = F::zero();
    for (&x, &y) in a.iter().zip(b.iter()) {
        // `err` feeds the previously lost low-order bits back into the sum.
        let term = x * y - err;
        let next = acc + term;
        err = (next - acc) - term;
        acc = next;
    }
    acc
}
#[allow(dead_code)]
/// Newton refinement of an eigenvalue estimate `lambda` for eigenvector `v`.
///
/// Minimizes f(λ) = vᵀ(Av − λv), whose derivative is the constant −vᵀv.
/// Returns the accumulated correction to add to `lambda`.
///
/// Bug fix: the original evaluated the residual at the *initial* `lambda`
/// on every pass, so when the first step was larger than the stopping
/// threshold the identical Newton step was added up to 5 times instead of
/// converging. The residual is now evaluated at `lambda + correction`.
/// Since f is linear in λ, the corrected loop converges after one step.
fn newton_eigenvalue_correction<F>(a: &ArrayView2<F>, v: &Array1<F>, lambda: F, tolerance: F) -> F
where
    F: Float + NumAssign + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let max_newton_iterations = 5;
    let mut correction = F::zero();
    // A·v and f'(λ) = −vᵀv do not depend on λ: hoist them out of the loop.
    let av = a.dot(v);
    let f_prime = -v.dot(v);
    for _ in 0..max_newton_iterations {
        // Evaluate the residual at the current estimate, not the initial one.
        let current = lambda + correction;
        let lambda_v = v.mapv(|x| x * current);
        let residual = &av - &lambda_v;
        let f_lambda = v.dot(&residual);
        if f_lambda.abs() < tolerance * F::from(0.01).expect("Operation failed") {
            break;
        }
        if f_prime.abs() > F::epsilon() {
            let delta = f_lambda / f_prime;
            correction += delta;
            if delta.abs() < tolerance * F::from(0.1).expect("Operation failed") {
                break;
            }
        } else {
            // Zero-norm vector: no well-defined Newton step is possible.
            break;
        }
    }
    correction
}
#[allow(dead_code)]
/// One step of shift-invert iteration for eigenvector refinement.
///
/// Solves (A − λI + εI)·y = v with a tiny regularization ε to keep the
/// shifted system nonsingular, then returns the normalized `y`. If the
/// solve fails (or `y` is numerically zero) the input vector is returned
/// unchanged — refinement is best-effort.
fn enhanced_inverse_iteration<F>(
    a: &ArrayView2<F>,
    lambda: F,
    v: &Array1<F>,
    tolerance: F,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let dim = a.nrows();
    let regularization = tolerance * F::from(1e-6).expect("Operation failed");
    // Build the shifted, regularized system matrix.
    let mut shifted = a.to_owned();
    for d in 0..dim {
        shifted[[d, d]] -= lambda;
    }
    for d in 0..dim {
        shifted[[d, d]] += regularization;
    }
    let mut refined = v.clone();
    if let Ok(solution) = crate::solve::solve(&shifted.view(), &refined.view(), None) {
        let length = solution.dot(&solution).sqrt();
        if length > F::epsilon() {
            refined = solution / length;
        }
    }
    Ok(refined)
}
#[allow(dead_code)]
/// Re-orthonormalizes the columns of `vectors` in place using modified
/// Gram-Schmidt with compensated dot products.
///
/// Three full passes are run, since a single Gram-Schmidt sweep can leave
/// residual non-orthogonality in floating point. Columns whose norm does
/// not exceed `tolerance` are left unnormalized.
fn enhanced_gram_schmidt_orthogonalization<F>(vectors: &mut Array2<F>, tolerance: F)
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    const PASSES: usize = 3;
    let cols = vectors.ncols();
    let rows = vectors.nrows();
    for _ in 0..PASSES {
        for col in 0..cols {
            // Subtract the projection onto each previously processed column.
            for prev in 0..col {
                let current = vectors.column(col).to_owned();
                let basis = vectors.column(prev).to_owned();
                let coeff = kahan_dot_product(&current, &basis);
                for row in 0..rows {
                    vectors[[row, col]] -= coeff * basis[row];
                }
            }
            // Normalize, unless the column has (numerically) collapsed.
            let current = vectors.column(col).to_owned();
            let length = kahan_dot_product(&current, &current).sqrt();
            if length > tolerance {
                for row in 0..rows {
                    vectors[[row, col]] /= length;
                }
            }
        }
    }
}
#[allow(dead_code)]
/// Checks every eigenpair's residual ‖A·vᵢ − λᵢ·vᵢ‖ against `tolerance`.
///
/// Returns `false` as soon as one residual exceeds the tolerance,
/// `true` when all pairs pass.
fn verify_eigenvalue_accuracy<F>(
    a: &ArrayView2<F>,
    eigenvalues: &Array1<F>,
    eigenvectors: &Array2<F>,
    tolerance: F,
) -> bool
where
    F: Float + NumAssign + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    for (idx, &lambda) in eigenvalues.iter().enumerate() {
        let vector = eigenvectors.column(idx);
        // Compensated arithmetic keeps the residual measurement itself reliable.
        let av = kahanmatrix_vector_product(a, &vector);
        let scaled = vector.mapv(|x| x * lambda);
        let residual = kahan_vector_subtraction(&av, &scaled);
        let residual_norm = kahan_dot_product(&residual, &residual).sqrt();
        if residual_norm > tolerance {
            return false;
        }
    }
    true
}
#[allow(dead_code)]
/// Sorts eigenpairs by descending eigenvalue magnitude, permuting the
/// eigenvector columns to match.
///
/// Fix: the original used `partial_cmp(..).expect("Operation failed")`,
/// which panics mid-sort if any eigenvalue is NaN. NaN magnitudes now sort
/// after all finite magnitudes via an explicit total-order fallback (a
/// total order is also required by `sort_unstable_by` itself — an
/// inconsistent comparator may panic in recent Rust versions).
fn sort_eigenvalues_and_vectors<F>(
    eigenvalues: Array1<F>,
    eigenvectors: Array2<F>,
) -> (Array1<F>, Array2<F>)
where
    F: Float,
{
    let n = eigenvalues.len();
    let mut order: Vec<usize> = (0..n).collect();
    // Descending |λ|; unstable sort is fine because tied pairs are interchangeable.
    order.sort_unstable_by(|&i, &j| {
        let x = eigenvalues[i].abs();
        let y = eigenvalues[j].abs();
        y.partial_cmp(&x).unwrap_or_else(|| {
            // Total-order fallback: all NaNs compare equal to each other and
            // sort after every finite value (lowest priority).
            match (x.is_nan(), y.is_nan()) {
                (true, false) => std::cmp::Ordering::Greater,
                (false, true) => std::cmp::Ordering::Less,
                _ => std::cmp::Ordering::Equal,
            }
        })
    });
    let sorted_eigenvalues: Array1<F> = order.iter().map(|&i| eigenvalues[i]).collect();
    let mut sorted_eigenvectors = Array2::zeros(eigenvectors.raw_dim());
    for (new_idx, &old_idx) in order.iter().enumerate() {
        sorted_eigenvectors
            .column_mut(new_idx)
            .assign(&eigenvectors.column(old_idx));
    }
    (sorted_eigenvalues, sorted_eigenvectors)
}
#[allow(dead_code)]
/// Cheap condition-number estimate used to adapt refinement tolerances.
///
/// * Empty matrix → 1 (perfectly conditioned by convention).
/// * Exactly diagonal matrix → max|dᵢ| / min nonzero |dᵢ|; a singular
///   diagonal yields a large 1e12 sentinel.
/// * Otherwise try SVD and use σ_max / σ_min.
/// * If SVD fails, fall back to a crude (‖A‖₂·‖A‖₁)/n heuristic, and
///   finally to a fixed 1e6 guess.
pub fn estimate_condition_number<F>(a: &ArrayView2<F>) -> F
where
    F: Float + NumAssign + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let n = a.nrows();
    if n == 0 {
        return F::one();
    }
    // Detect an (exactly) diagonal matrix: any off-diagonal entry above
    // machine epsilon disqualifies it.
    let mut is_diagonal = true;
    for i in 0..n {
        for j in 0..a.ncols() {
            if i != j && a[[i, j]].abs() > F::epsilon() {
                is_diagonal = false;
                break;
            }
        }
        if !is_diagonal {
            break;
        }
    }
    if is_diagonal {
        // Diagonal case: condition number is max|dᵢ| / min|dᵢ| over nonzero
        // entries (zeros are skipped and handled by the sentinel below).
        let mut max_diag = F::zero();
        let mut min_diag = F::infinity();
        for i in 0..n.min(a.ncols()) {
            let val = a[[i, i]].abs();
            if val > max_diag {
                max_diag = val;
            }
            if val > F::zero() && val < min_diag {
                min_diag = val;
            }
        }
        if min_diag <= F::zero() || min_diag == F::infinity() {
            // Singular (or all-zero) diagonal: report a very large condition number.
            F::from(1e12).unwrap_or_else(|| F::max_value() / F::from(1000.0).unwrap_or(F::one()))
        } else {
            max_diag / min_diag
        }
    } else {
        // General matrix: σ_max / σ_min from an SVD when available.
        if let Ok((_, s, _)) = crate::decomposition::svd(a, false, Some(1)) {
            let mut max_sv = F::zero();
            let mut min_sv = F::infinity();
            for &sv in s.iter() {
                if sv > max_sv {
                    max_sv = sv;
                }
                if sv > F::zero() && sv < min_sv {
                    min_sv = sv;
                }
            }
            if min_sv <= F::zero() || min_sv == F::infinity() {
                // No usable smallest singular value: large sentinel again.
                F::from(1e12)
                    .unwrap_or_else(|| F::max_value() / F::from(1000.0).unwrap_or(F::one()))
            } else {
                max_sv / min_sv
            }
        } else {
            // SVD failed: rough norm-based heuristic, then a fixed guess.
            if let (Ok(norm_2), Ok(norm_1)) = (
                crate::norm::matrix_norm(a, "2", Some(1)),
                crate::norm::matrix_norm(a, "1", Some(1)),
            ) {
                let n_f = F::from(n).unwrap_or_else(|| F::one());
                (norm_2 * norm_1) / n_f
            } else {
                F::from(1e6).unwrap_or_else(|| F::one())
            }
        }
    }
}
#[allow(dead_code)]
/// Picks a refinement tolerance appropriate for the matrix's estimated
/// condition number.
///
/// Baseline is `100·ε`; it is loosened by 10x / 100x / 1000x as the
/// condition number crosses 1e4 / 1e8 / 1e12, since heavily
/// ill-conditioned problems cannot be resolved to baseline accuracy.
///
/// All numeric literals go through `F::from`, with a repeated-addition
/// fallback (`ten_pow`) for exotic `F` types where the conversion fails.
/// The original open-coded that ~12-line fallback four times; it is
/// factored out here with identical resulting values.
pub fn adaptive_tolerance_selection<F>(condition_number: F) -> F
where
    F: Float + NumAssign,
{
    // Builds 10^exp from F::one() by repeated addition and multiplication.
    // Only reached when F::from cannot represent the decimal literal.
    fn ten_pow<F: Float + NumAssign>(exp: u32) -> F {
        let mut ten = F::zero();
        for _ in 0..10 {
            ten += F::one();
        }
        let mut result = F::one();
        for _ in 0..exp {
            result *= ten;
        }
        result
    }
    let hundred = F::from(100.0).unwrap_or_else(|| ten_pow::<F>(2));
    let base_tol = F::epsilon() * hundred;
    // Thresholds fall back to fractions of F::max_value() when the literal
    // does not convert (same fallbacks as the original).
    let threshold_1e12 =
        F::from(1e12).unwrap_or_else(|| F::max_value() / F::from(1000.0).unwrap_or(F::one()));
    let threshold_1e8 =
        F::from(1e8).unwrap_or_else(|| F::max_value() / F::from(10000.0).unwrap_or(F::one()));
    let threshold_1e4 = F::from(1e4).unwrap_or_else(|| F::from(10000.0).unwrap_or(F::one()));
    if condition_number > threshold_1e12 {
        base_tol * F::from(1000.0).unwrap_or_else(|| ten_pow::<F>(3))
    } else if condition_number > threshold_1e8 {
        base_tol * hundred
    } else if condition_number > threshold_1e4 {
        base_tol * F::from(10.0).unwrap_or_else(|| ten_pow::<F>(1))
    } else {
        base_tol
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    // The top-level re-exports must behave like their submodule originals,
    // and `eigvalsh` must agree with the eigenvalues from `eigh`.
    #[test]
    fn test_backward_compatibility() {
        let a = array![[2.0_f64, 1.0], [1.0, 3.0]];
        let (w1, v1) = eig(&a.view(), None).expect("Operation failed");
        let (w2, v2) = standard::eig(&a.view(), None).expect("Operation failed");
        assert_eq!(w1.len(), w2.len());
        assert_eq!(v1.dim(), v2.dim());
        let (w1, v1) = eigh(&a.view(), None).expect("Operation failed");
        let (w2, v2) = standard::eigh(&a.view(), None).expect("Operation failed");
        assert_eq!(w1.len(), w2.len());
        assert_eq!(v1.dim(), v2.dim());
        let w1 = eigvals(&a.view(), None).expect("Operation failed");
        let w2 = standard::eigvals(&a.view(), None).expect("Operation failed");
        assert_eq!(w1.len(), w2.len());
        let w1 = eigvalsh(&a.view(), None).expect("Operation failed");
        let (w2_, _) = eigh(&a.view(), None).expect("Operation failed");
        for i in 0..w1.len() {
            assert_relative_eq!(w1[i], w2_[i], epsilon = 1e-10);
        }
    }

    // Generalized eigenproblem re-exports match the `generalized` module.
    #[test]
    fn test_generalized_eigenvalue_re_exports() {
        let a = array![[2.0_f64, 1.0], [1.0, 3.0]];
        let b = array![[1.0_f64, 0.0], [0.0, 2.0]];
        let (w1, v1) = eig_gen(&a.view(), &b.view(), None).expect("Operation failed");
        let (w2, v2) = generalized::eig_gen(&a.view(), &b.view(), None).expect("Operation failed");
        assert_eq!(w1.len(), w2.len());
        assert_eq!(v1.dim(), v2.dim());
        let (w1, v1) = eigh_gen(&a.view(), &b.view(), None).expect("Operation failed");
        let (w2, v2) = generalized::eigh_gen(&a.view(), &b.view(), None).expect("Operation failed");
        assert_eq!(w1.len(), w2.len());
        assert_eq!(v1.dim(), v2.dim());
    }

    // A simple diagonal matrix should succeed through advanced_precision_eig
    // and return correctly shaped output.
    #[test]
    fn test_advanced_precision_fallback() {
        let a = array![[1.0_f64, 0.0], [0.0, 2.0]];
        let result = advanced_precision_eig(&a.view(), 1e-12);
        assert!(result.is_ok());
        let (w, v) = result.expect("Operation failed");
        assert_eq!(w.len(), 2);
        assert_eq!(v.dim(), (2, 2));
    }

    // The identity is well conditioned; a 1e-12 diagonal entry makes the
    // estimate very large (the diagonal fast path computes max/min directly).
    #[test]
    fn test_condition_number_estimation() {
        let well_conditioned = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let cond1 = estimate_condition_number(&well_conditioned.view());
        assert!(cond1 <= 2.0);
        let ill_conditioned = array![[1.0_f64, 0.0], [0.0, 1e-12]];
        let direct_cond = 1.0_f64 / 1e-12_f64;
        let cond2 = estimate_condition_number(&ill_conditioned.view());
        assert!(
            cond2 > 1e10,
            "Condition number {:.2e} should be > 1e10 (expected ~{:.2e})",
            cond2,
            direct_cond
        );
    }

    // Tolerance starts at 100·ε and grows with the condition number.
    #[test]
    fn test_adaptive_tolerance() {
        let tol1 = adaptive_tolerance_selection(1.0_f64);
        let base_tol = f64::EPSILON * 100.0;
        assert_relative_eq!(tol1, base_tol, epsilon = 1e-15);
        let tol2 = adaptive_tolerance_selection(1e15_f64);
        assert!(tol2 > base_tol * 100.0);
    }

    // Smoke test: each submodule's public API is reachable from here.
    #[test]
    fn test_module_organization() {
        let a = array![[1.0_f64, 0.0], [0.0, 2.0]];
        let _ = standard::eig(&a.view(), None).expect("Operation failed");
        let b = Array2::eye(2);
        let _ = generalized::eig_gen(&a.view(), &b.view(), None).expect("Operation failed");
        let csr = sparse::CsrMatrix::new(2, 2, vec![], vec![], vec![]);
        let result = sparse::lanczos(&csr, 1, "largest", 0.0_f64, 10, 1e-6);
        let _ = result;
    }
}