use super::lanczos::{lanczos, EigenResult, LanczosOptions};
use crate::error::{SparseError, SparseResult};
use crate::sym_csr::SymCsrMatrix;
use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::Float;
use scirs2_core::SparseElement;
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
/// Computes selected eigenpairs of a real symmetric sparse matrix, in the style of
/// SciPy's `scipy.sparse.linalg.eigsh`.
///
/// * `k` — how many eigenpairs to return; falls back to `options.numeigenvalues`.
/// * `which` — selection criterion (`"LA"`, `"SA"`, `"LM"` or `"SM"`); falls back to `"LA"`.
/// * `options` — Lanczos configuration; falls back to `LanczosOptions::default()`.
///
/// # Errors
/// Returns `SparseError::ValueError` when the matrix is not square. Errors from the
/// underlying Lanczos run and from the eigenvalue selection are propagated unchanged.
#[allow(dead_code)]
pub fn eigsh<T>(
    matrix: &SymCsrMatrix<T>,
    k: Option<usize>,
    which: Option<&str>,
    options: Option<LanczosOptions>,
) -> SparseResult<EigenResult<T>>
where
    T: Float
        + Debug
        + Copy
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + std::iter::Sum
        + scirs2_core::simd_ops::SimdUnifiedOps
        + SparseElement
        + PartialOrd
        + Send
        + Sync
        + 'static,
{
    let opts = options.unwrap_or_default();
    let (rows, cols) = matrix.shape();
    if rows == cols {
        let count = k.unwrap_or(opts.numeigenvalues);
        let criterion = which.unwrap_or("LA");
        enhanced_lanczos(matrix, count, criterion, &opts)
    } else {
        Err(SparseError::ValueError(
            "Matrix must be square for eigenvalue computation".to_string(),
        ))
    }
}
#[allow(dead_code)]
pub fn eigsh_shift_invert<T>(
matrix: &SymCsrMatrix<T>,
sigma: T,
k: Option<usize>,
which: Option<&str>,
options: Option<LanczosOptions>,
) -> SparseResult<EigenResult<T>>
where
T: Float
+ Debug
+ Copy
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ std::iter::Sum
+ scirs2_core::simd_ops::SimdUnifiedOps
+ SparseElement
+ PartialOrd
+ Send
+ Sync
+ 'static,
{
let opts = options.unwrap_or_default();
let k = k.unwrap_or(6);
let which = which.unwrap_or("LM");
let (n, m) = matrix.shape();
if n != m {
return Err(SparseError::ValueError(
"Matrix must be square for eigenvalue computation".to_string(),
));
}
let mut shifted_matrix = matrix.clone();
for i in 0..n {
for j in shifted_matrix.indptr[i]..shifted_matrix.indptr[i + 1] {
if shifted_matrix.indices[j] == i {
shifted_matrix.data[j] = shifted_matrix.data[j] - sigma;
break;
}
}
}
let mut shift_opts = opts.clone();
shift_opts.numeigenvalues = k;
let result = lanczos(&shifted_matrix, &shift_opts, None)?;
let mut transformed_eigenvalues = Array1::zeros(result.eigenvalues.len());
for (i, &mu) in result.eigenvalues.iter().enumerate() {
if !SparseElement::is_zero(&mu) {
transformed_eigenvalues[i] = sigma + T::sparse_one() / mu;
} else {
transformed_eigenvalues[i] = sigma;
}
}
Ok(EigenResult {
eigenvalues: transformed_eigenvalues,
eigenvectors: result.eigenvectors,
iterations: result.iterations,
residuals: result.residuals,
converged: result.converged,
})
}
/// Extended shift entry point mirroring SciPy's
/// `eigsh(sigma=..., mode=..., return_eigenvectors=...)` signature.
///
/// Delegates to [`eigsh_shift_invert`] and then honours `return_eigenvectors`.
///
/// * `mode` — accepted for API compatibility only (defaults to `"normal"`). No alternative
///   spectral transformation is implemented, so it has no effect.
/// * `return_eigenvectors` — when `Some(false)`, the result's `eigenvectors` is set to
///   `None`. Defaults to `true`. The solver still computes them; they are only discarded
///   afterwards.
///
/// All other parameters and errors are those of [`eigsh_shift_invert`].
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
pub fn eigsh_shift_invert_enhanced<T>(
    matrix: &SymCsrMatrix<T>,
    sigma: T,
    k: Option<usize>,
    which: Option<&str>,
    mode: Option<&str>,
    return_eigenvectors: Option<bool>,
    options: Option<LanczosOptions>,
) -> SparseResult<EigenResult<T>>
where
    T: Float
        + Debug
        + Copy
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + std::iter::Sum
        + scirs2_core::simd_ops::SimdUnifiedOps
        + SparseElement
        + PartialOrd
        + Send
        + Sync
        + 'static,
{
    // `mode` is not implemented; the default is kept for parity with SciPy's signature.
    let _mode = mode.unwrap_or("normal");
    let mut result = eigsh_shift_invert(matrix, sigma, k, which, options)?;
    // Previously this flag was read and then ignored.
    if !return_eigenvectors.unwrap_or(true) {
        result.eigenvectors = None;
    }
    Ok(result)
}
fn enhanced_lanczos<T>(
matrix: &SymCsrMatrix<T>,
k: usize,
which: &str,
options: &LanczosOptions,
) -> SparseResult<EigenResult<T>>
where
T: Float
+ Debug
+ Copy
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ std::iter::Sum
+ scirs2_core::simd_ops::SimdUnifiedOps
+ SparseElement
+ PartialOrd
+ Send
+ Sync
+ 'static,
{
let n = matrix.shape().0;
let mut enhanced_opts = options.clone();
enhanced_opts.numeigenvalues = k;
enhanced_opts.max_subspace_size = (k * 2 + 10).min(n);
enhanced_opts.tol = enhanced_opts.tol.min(1e-10);
let result = lanczos(matrix, &enhanced_opts, None)?;
process_eigenvalue_selection(result, which, k)
}
/// Picks up to `k` eigenpairs out of a Lanczos result according to `which`. Eigenvalues,
/// residuals and eigenvector columns stay aligned.
///
/// The input eigenvalues are assumed to be ordered largest-first (TODO confirm against
/// `lanczos`). The `"LA"` and `"SA"` branches depend on that ordering:
///
/// * `"LA"` — the first `k` entries, in order.
/// * `"SA"` — the last `k` entries, reversed (smallest first).
/// * `"LM"` — the `k` entries of largest absolute value, largest first.
/// * `"SM"` — the `k` entries of smallest absolute value, smallest first.
///
/// In `"LM"`/`"SM"`, ties keep their input order and NaN magnitudes compare as equal.
/// If fewer than `k` eigenvalues were computed, all of them are returned.
///
/// # Errors
/// Returns `SparseError::ValueError` if `which` is not one of the four criteria above.
fn process_eigenvalue_selection<T>(
    mut result: EigenResult<T>,
    which: &str,
    k: usize,
) -> SparseResult<EigenResult<T>>
where
    T: Float + Debug + Copy,
{
    let n_computed = result.eigenvalues.len();
    let n_requested = k.min(n_computed);
    // Source positions of the selected eigenpairs, in output order.
    let order: Vec<usize> = match which {
        "LA" => (0..n_requested).collect(),
        "SA" => (0..n_computed).rev().take(n_requested).collect(),
        "LM" | "SM" => {
            let eigenvalues = &result.eigenvalues;
            let by_magnitude = |a: &usize, b: &usize| {
                eigenvalues[*a]
                    .abs()
                    .partial_cmp(&eigenvalues[*b].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            };
            let mut indices: Vec<usize> = (0..n_computed).collect();
            // `sort_by` is stable, so equal magnitudes keep their input order.
            if which == "LM" {
                indices.sort_by(|a, b| by_magnitude(b, a));
            } else {
                indices.sort_by(by_magnitude);
            }
            indices.truncate(n_requested);
            indices
        }
        _ => {
            return Err(SparseError::ValueError(format!(
                "Unknown eigenvalue selection criterion: {}. Use 'LA', 'SA', 'LM', or 'SM'",
                which
            )));
        }
    };
    // `select` gathers whole columns, so every eigenvector stays intact. The previous
    // approach pushed column-by-column data into `from_shape_vec((rows, cols), ..)`,
    // which reads it as row-major and scrambled the eigenvectors when >= 2 were selected.
    result.eigenvalues = result
        .eigenvalues
        .select(scirs2_core::ndarray::Axis(0), &order);
    result.residuals = result
        .residuals
        .select(scirs2_core::ndarray::Axis(0), &order);
    if let Some(evecs) = result.eigenvectors.as_mut() {
        *evecs = evecs.select(scirs2_core::ndarray::Axis(1), &order);
    }
    Ok(result)
}

/// Regression tests for `process_eigenvalue_selection`: each selected eigenvector column
/// must be moved as a whole, never transposed or interleaved with another column.
#[cfg(test)]
mod selection_tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Eigenvalues largest-first, with eigenvector entry (i, j) = 10 * i + j so every
    /// column is recognisable after selection.
    fn tagged_result() -> EigenResult<f64> {
        EigenResult {
            eigenvalues: Array1::from_vec(vec![5.0, 1.0, -4.0]),
            eigenvectors: Some(Array2::from_shape_fn((3, 3), |(i, j)| (10 * i + j) as f64)),
            iterations: 1,
            residuals: Array1::from_vec(vec![0.1, 0.2, 0.3]),
            converged: true,
        }
    }

    #[test]
    fn selection_moves_whole_eigenvector_columns() {
        let cases: [(&str, [usize; 2]); 4] = [
            ("LA", [0, 1]),
            ("SA", [2, 1]),
            ("LM", [0, 2]),
            ("SM", [1, 2]),
        ];
        for &(which, cols) in cases.iter() {
            let base = tagged_result();
            let selected =
                process_eigenvalue_selection(base.clone(), which, 2).expect("valid selection");
            let evecs = selected.eigenvectors.as_ref().expect("eigenvectors kept");
            assert_eq!(evecs.dim(), (3, 2), "which = {}", which);
            for (out_col, &src_col) in cols.iter().enumerate() {
                assert_eq!(
                    selected.eigenvalues[out_col],
                    base.eigenvalues[src_col],
                    "which = {}",
                    which
                );
                assert_eq!(
                    selected.residuals[out_col],
                    base.residuals[src_col],
                    "which = {}",
                    which
                );
                for row in 0..3 {
                    assert_eq!(
                        evecs[[row, out_col]],
                        (10 * row + src_col) as f64,
                        "which = {}",
                        which
                    );
                }
            }
        }
    }

    #[test]
    fn selection_rejects_unknown_criterion() {
        assert!(process_eigenvalue_selection(tagged_result(), "XX", 1).is_err());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sym_csr::SymCsrMatrix;

    /// Builds an `n x n` `SymCsrMatrix<f64>` from its stored CSR parts, panicking on bad input.
    fn sym_matrix(
        data: Vec<f64>,
        indptr: Vec<usize>,
        indices: Vec<usize>,
        n: usize,
    ) -> SymCsrMatrix<f64> {
        SymCsrMatrix::new(data, indptr, indices, (n, n)).expect("Operation failed")
    }

    #[test]
    fn test_eigsh_basic() {
        // Stored entries of [[4, 2, 0], [2, 3, 5], [0, 5, 1]].
        let matrix = sym_matrix(
            vec![4.0, 2.0, 3.0, 5.0, 1.0],
            vec![0, 1, 3, 5],
            vec![0, 0, 1, 1, 2],
            3,
        );
        let outcome = eigsh(&matrix, Some(2), Some("LA"), None).expect("Operation failed");
        assert!(!outcome.eigenvalues.is_empty());
        assert!(outcome.eigenvalues.len() <= 2);
        assert!(outcome.eigenvalues[0].is_finite());
    }

    #[test]
    fn test_eigsh_different_which() {
        // Stored entries of [[2, 1], [1, 2]].
        let matrix = sym_matrix(vec![2.0, 1.0, 2.0], vec![0, 1, 3], vec![0, 0, 1], 2);

        let largest = eigsh(&matrix, Some(1), Some("LA"), None).expect("Operation failed");
        assert!(!largest.eigenvalues.is_empty());
        assert!(largest.eigenvalues[0].is_finite());

        let smallest = eigsh(&matrix, Some(1), Some("SA"), None).expect("Operation failed");
        assert!(!smallest.eigenvalues.is_empty());
        assert!(smallest.eigenvalues[0].is_finite());
    }

    #[test]
    fn test_eigsh_shift_invert() {
        // Stored entries of [[4, 1], [1, 3]].
        let matrix = sym_matrix(vec![4.0, 1.0, 3.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let outcome =
            eigsh_shift_invert(&matrix, 2.0, Some(1), None, None).expect("Operation failed");
        assert!(!outcome.eigenvalues.is_empty());
        assert!(outcome.eigenvalues[0].is_finite());
    }

    #[test]
    fn test_process_eigenvalue_selection() {
        let input = EigenResult {
            eigenvalues: Array1::from_vec(vec![5.0, 3.0, 1.0]),
            eigenvectors: None,
            iterations: 10,
            residuals: Array1::from_vec(vec![1e-8, 1e-9, 1e-7]),
            converged: true,
        };

        let top_two =
            process_eigenvalue_selection(input.clone(), "LA", 2).expect("Operation failed");
        assert_eq!(top_two.eigenvalues.len(), 2);
        assert_eq!(top_two.eigenvalues[0], 5.0);
        assert_eq!(top_two.eigenvalues[1], 3.0);

        let bottom_two =
            process_eigenvalue_selection(input.clone(), "SA", 2).expect("Operation failed");
        assert_eq!(bottom_two.eigenvalues.len(), 2);
        assert_eq!(bottom_two.eigenvalues[0], 1.0);
        assert_eq!(bottom_two.eigenvalues[1], 3.0);
    }
}