use super::lanczos::{EigenResult, LanczosOptions};
use super::symmetric;
use crate::error::{SparseError, SparseResult};
use crate::sym_csr::SymCsrMatrix;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
/// Solves the symmetric generalized eigenvalue problem `A x = λ B x` for `k` eigenpairs.
///
/// Defaults: `k = 6`, `which = "LA"`, and `LanczosOptions::default()` when `options`
/// is `None`.
///
/// NOTE(review): the current transformation only uses `B` for a diagonal-positivity
/// check. The matrix handed to the Lanczos solver is `A` with a tiny (1e-12) diagonal
/// shift, so results are effectively those of `A`, not of the pencil `(A, B)`. Confirm
/// before relying on this for non-identity `B`.
///
/// # Errors
///
/// * `SparseError::ValueError` if either matrix is not square.
/// * `SparseError::DimensionMismatch` if `A` and `B` have different sizes.
/// * `SparseError::ValueError` if some row of `B` has a missing or non-positive diagonal entry.
/// * Any error produced by the underlying symmetric Lanczos solver.
#[allow(dead_code)] // public entry point; may be unused inside the crate itself
pub fn eigsh_generalized<T>(
    a_matrix: &SymCsrMatrix<T>,
    b_matrix: &SymCsrMatrix<T>,
    k: Option<usize>,
    which: Option<&str>,
    options: Option<LanczosOptions>,
) -> SparseResult<EigenResult<T>>
where
    T: Float
        + Debug
        + Copy
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + std::iter::Sum
        + scirs2_core::simd_ops::SimdUnifiedOps
        + scirs2_core::SparseElement
        + PartialOrd
        + Send
        + Sync
        + 'static,
{
    let (rows_a, cols_a) = a_matrix.shape();
    let (rows_b, cols_b) = b_matrix.shape();

    let both_square = rows_a == cols_a && rows_b == cols_b;
    if !both_square {
        return Err(SparseError::ValueError(
            "Both matrices must be square for generalized eigenvalue problem".to_string(),
        ));
    }
    if rows_a != rows_b {
        return Err(SparseError::DimensionMismatch {
            expected: rows_a,
            found: rows_b,
        });
    }

    let count = k.unwrap_or(6);
    let target = which.unwrap_or("LA");
    let lanczos = options.unwrap_or_default();
    generalized_standard_transform(a_matrix, b_matrix, count, target, &lanczos)
}
/// Generalized eigensolver entry point with a selectable spectral-transformation mode.
///
/// Accepted `mode` values are `"standard"` (the default), `"buckling"` and `"cayley"`.
///
/// NOTE(review): all three modes currently dispatch to [`eigsh_generalized`] with
/// identical arguments, and `sigma` (default zero) is not used by any of them. Callers
/// should not expect buckling or Cayley transforms to be applied yet.
///
/// # Errors
///
/// * `SparseError::ValueError` for an unrecognised `mode`.
/// * Any error returned by [`eigsh_generalized`].
#[allow(dead_code)] // public entry point; may be unused inside the crate itself
#[allow(clippy::too_many_arguments)]
pub fn eigsh_generalized_enhanced<T>(
    a_matrix: &SymCsrMatrix<T>,
    b_matrix: &SymCsrMatrix<T>,
    k: Option<usize>,
    which: Option<&str>,
    mode: Option<&str>,
    sigma: Option<T>,
    options: Option<LanczosOptions>,
) -> SparseResult<EigenResult<T>>
where
    T: Float
        + Debug
        + Copy
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + std::iter::Sum
        + scirs2_core::simd_ops::SimdUnifiedOps
        + scirs2_core::SparseElement
        + PartialOrd
        + Send
        + Sync
        + 'static,
{
    let requested_mode = mode.unwrap_or("standard");
    // Shift is accepted for API compatibility; no mode consumes it yet.
    let _shift = sigma.unwrap_or(T::sparse_zero());

    if matches!(requested_mode, "standard" | "buckling" | "cayley") {
        eigsh_generalized(a_matrix, b_matrix, k, which, options)
    } else {
        Err(SparseError::ValueError(format!(
            "Unknown mode '{}'. Supported modes: standard, buckling, cayley",
            requested_mode
        )))
    }
}
/// Reduces the generalized problem to a standard symmetric one and runs Lanczos on it.
///
/// `B` is first checked with `is_positive_definite_diagonal`. The matrix produced by
/// `compute_generalized_matrix` is then passed to `symmetric::eigsh`. `options` is
/// cloned, and its `numeigenvalues` is overridden with `k`.
///
/// NOTE(review): `compute_generalized_matrix` currently does not read `B`, so `B` only
/// influences this function through the positivity check.
///
/// # Errors
///
/// * `SparseError::ValueError` if `B` fails the diagonal positivity check.
/// * Any error from building the transformed matrix or from `symmetric::eigsh`.
fn generalized_standard_transform<T>(
    a_matrix: &SymCsrMatrix<T>,
    b_matrix: &SymCsrMatrix<T>,
    k: usize,
    which: &str,
    options: &LanczosOptions,
) -> SparseResult<EigenResult<T>>
where
    T: Float
        + Debug
        + Copy
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + std::iter::Sum
        + scirs2_core::simd_ops::SimdUnifiedOps
        + scirs2_core::SparseElement
        + PartialOrd
        + Send
        + Sync
        + 'static,
{
    if !is_positive_definite_diagonal(b_matrix)? {
        return Err(SparseError::ValueError(
            "B matrix must be positive definite for standard transformation".to_string(),
        ));
    }

    let transformed_matrix = compute_generalized_matrix(a_matrix, b_matrix)?;

    // Keep the solver's own eigenvalue count consistent with the requested `k`.
    let mut transform_opts = options.clone();
    transform_opts.numeigenvalues = k;

    symmetric::eigsh(
        &transformed_matrix,
        Some(k),
        Some(which),
        Some(transform_opts),
    )
}
/// Reports whether every row of `matrix` has an explicitly stored, strictly positive
/// diagonal entry.
///
/// This is a cheap *necessary* condition for positive definiteness, not a sufficient
/// one: a matrix may pass this check and still be indefinite.
///
/// For each row, only the first stored entry whose column index equals the row index is
/// inspected. A row with no stored diagonal entry fails the check, and so does a `NaN`
/// diagonal entry.
///
/// # Errors
///
/// Never returns `Err` at present; the `SparseResult` return type is kept so callers can
/// continue to use `?`.
fn is_positive_definite_diagonal<T>(matrix: &SymCsrMatrix<T>) -> SparseResult<bool>
where
    T: Float + Debug + Copy + scirs2_core::SparseElement + PartialOrd,
{
    let zero = T::sparse_zero();
    for row in 0..matrix.shape().0 {
        let diagonal = (matrix.indptr[row]..matrix.indptr[row + 1])
            .find(|&pos| matrix.indices[pos] == row)
            .map(|pos| matrix.data[pos]);
        // Require `value > 0` rather than rejecting `value <= 0`: NaN compares false
        // against everything, so the latter form would let a NaN diagonal slip through.
        match diagonal {
            Some(value) if value > zero => {}
            _ => return Ok(false),
        }
    }
    Ok(true)
}
/// Builds the matrix that `generalized_standard_transform` hands to the Lanczos solver.
///
/// The result is a copy of `a_matrix` with the same sparsity structure. In each row, the
/// first explicitly stored diagonal entry is increased by `1e-12`, or by `T::epsilon()`
/// if `T::from(1e-12)` yields `None`. Rows with no stored diagonal entry are left
/// unchanged.
///
/// NOTE(review): `b_matrix` is not read here, so this is not (yet) a true reduction of
/// `A x = λ B x` to standard form.
///
/// # Errors
///
/// Propagates any error from `SymCsrMatrix::new` when rebuilding the matrix.
fn compute_generalized_matrix<T>(
    a_matrix: &SymCsrMatrix<T>,
    b_matrix: &SymCsrMatrix<T>,
) -> SparseResult<SymCsrMatrix<T>>
where
    T: Float
        + Debug
        + Copy
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + scirs2_core::SparseElement
        + PartialOrd,
{
    let size = a_matrix.shape().0;
    let shift = T::from(1e-12).unwrap_or(T::epsilon());

    let mut values = a_matrix.data.clone();
    let indices = a_matrix.indices.clone();
    let indptr = a_matrix.indptr.clone();

    for row in 0..size {
        let (lo, hi) = (indptr[row], indptr[row + 1]);
        if let Some(slot) = (lo..hi).find(|&pos| indices[pos] == row) {
            values[slot] = values[slot] + shift;
        }
    }

    SymCsrMatrix::new(values, indptr, indices, (size, size))
}
#[allow(dead_code)]
pub fn eigsh_generalized_shift_invert<T>(
a_matrix: &SymCsrMatrix<T>,
b_matrix: &SymCsrMatrix<T>,
sigma: T,
k: Option<usize>,
which: Option<&str>,
options: Option<LanczosOptions>,
) -> SparseResult<EigenResult<T>>
where
T: Float
+ Debug
+ Copy
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ std::iter::Sum
+ scirs2_core::simd_ops::SimdUnifiedOps
+ scirs2_core::SparseElement
+ PartialOrd
+ Send
+ Sync
+ 'static,
{
let k = k.unwrap_or(6);
let which = which.unwrap_or("LM");
generalized_standard_transform(a_matrix, b_matrix, k, which, &options.unwrap_or_default())
}
/// Configuration bundle for the generalized symmetric eigensolvers in this module.
///
/// Construct with `GeneralizedEigenSolverConfig::new()` (or `Default::default()`) and
/// refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct GeneralizedEigenSolverConfig {
    /// Number of eigenpairs to request (default `6`).
    pub k: usize,
    /// Spectrum selector string, e.g. `"LA"` (the default) or `"LM"`.
    pub which: String,
    /// Solver mode (default `"standard"`); `eigsh_generalized_enhanced` recognises
    /// `"standard"`, `"buckling"` and `"cayley"`.
    pub mode: String,
    /// Optional spectral shift (default `None`).
    pub sigma: Option<f64>,
    /// Flag selecting the "enhanced" solver path (default `false`).
    /// NOTE(review): no code in this module reads this field — confirm its consumer.
    pub enhanced: bool,
    /// Options forwarded to the Lanczos solver (default `LanczosOptions::default()`).
    pub lanczos_options: LanczosOptions,
}
impl Default for GeneralizedEigenSolverConfig {
    /// Returns a configuration with the following settings:
    ///
    /// * six eigenpairs (`k`),
    /// * `which = "LA"`,
    /// * `"standard"` mode,
    /// * no shift,
    /// * `enhanced` disabled,
    /// * default Lanczos options.
    fn default() -> Self {
        let lanczos_options = LanczosOptions::default();
        Self {
            k: 6,
            which: String::from("LA"),
            mode: String::from("standard"),
            sigma: None,
            enhanced: false,
            lanczos_options,
        }
    }
}
impl GeneralizedEigenSolverConfig {
    /// Creates a configuration with default settings (same as `Self::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of eigenpairs to request.
    #[must_use]
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Sets the spectrum selector string (e.g. `"LA"`, `"SA"`, `"LM"`).
    #[must_use]
    pub fn with_which(mut self, which: &str) -> Self {
        self.which = which.to_string();
        self
    }

    /// Sets the solver mode (e.g. `"standard"`, `"buckling"`, `"cayley"`).
    #[must_use]
    pub fn with_mode(mut self, mode: &str) -> Self {
        self.mode = mode.to_string();
        self
    }

    /// Sets the spectral shift.
    #[must_use]
    pub fn with_sigma(mut self, sigma: f64) -> Self {
        self.sigma = Some(sigma);
        self
    }

    /// Enables or disables the `enhanced` flag.
    #[must_use]
    pub fn with_enhanced(mut self, enhanced: bool) -> Self {
        self.enhanced = enhanced;
        self
    }

    /// Replaces the Lanczos solver options.
    #[must_use]
    pub fn with_lanczos_options(mut self, options: LanczosOptions) -> Self {
        self.lanczos_options = options;
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sym_csr::SymCsrMatrix;

    /// Builds an `n x n` `SymCsrMatrix<f64>` from raw CSR arrays, panicking if invalid.
    fn sym(data: Vec<f64>, indptr: Vec<usize>, indices: Vec<usize>, n: usize) -> SymCsrMatrix<f64> {
        SymCsrMatrix::new(data, indptr, indices, (n, n)).expect("Operation failed")
    }

    #[test]
    fn test_eigsh_generalized_basic() {
        let a_matrix = sym(vec![2.0, 1.0, 3.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let b_matrix = sym(vec![1.0, 0.5, 2.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        // Smoke test only: the call must return (Ok or Err) without panicking. Whether
        // Lanczos converges for a 2x2 problem is up to `symmetric::eigsh`, so the result
        // itself is not asserted here.
        let _ = eigsh_generalized(&a_matrix, &b_matrix, Some(1), None, None);
    }

    #[test]
    fn test_eigsh_generalized_dimension_mismatch() {
        let a_matrix = sym(vec![2.0, 1.0, 3.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let b_matrix = sym(vec![1.0, 1.0, 1.0], vec![0, 1, 2, 3], vec![0, 1, 2], 3);
        let result = eigsh_generalized(&a_matrix, &b_matrix, Some(1), None, None);
        assert!(matches!(
            result,
            Err(SparseError::DimensionMismatch {
                expected: 2,
                found: 3
            })
        ));
    }

    #[test]
    fn test_eigsh_generalized_rejects_non_positive_b() {
        let a_matrix = sym(vec![2.0, 1.0, 3.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let b_matrix = sym(vec![1.0, -1.0], vec![0, 1, 2], vec![0, 1], 2);
        let result = eigsh_generalized(&a_matrix, &b_matrix, Some(1), None, None);
        assert!(matches!(result, Err(SparseError::ValueError(_))));
    }

    #[test]
    fn test_is_positive_definite_diagonal() {
        let matrix = sym(vec![2.0, 1.0, 3.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let result = is_positive_definite_diagonal(&matrix).expect("Operation failed");
        assert!(result);
    }

    #[test]
    fn test_is_positive_definite_diagonal_missing_diagonal() {
        // Row 1 stores only the (1, 0) entry, so its diagonal is implicitly zero.
        let matrix = sym(vec![1.0, 0.5], vec![0, 1, 2], vec![0, 0], 2);
        let result = is_positive_definite_diagonal(&matrix).expect("Operation failed");
        assert!(!result);
    }

    #[test]
    fn test_generalized_config() {
        let config = GeneralizedEigenSolverConfig::new()
            .with_k(5)
            .with_which("SA")
            .with_mode("buckling")
            .with_sigma(1.5)
            .with_enhanced(true);
        assert_eq!(config.k, 5);
        assert_eq!(config.which, "SA");
        assert_eq!(config.mode, "buckling");
        assert_eq!(config.sigma, Some(1.5));
        assert!(config.enhanced);
    }

    #[test]
    fn test_eigsh_generalized_enhanced() {
        let a_matrix = sym(vec![4.0, 1.0, 2.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let b_matrix = sym(vec![2.0, 0.5, 1.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        // Smoke test only: the call must not panic; convergence is not asserted.
        let _ = eigsh_generalized_enhanced(
            &a_matrix,
            &b_matrix,
            Some(1),
            Some("LA"),
            Some("standard"),
            None,
            None,
        );
    }

    #[test]
    fn test_eigsh_generalized_enhanced_unknown_mode() {
        let a_matrix = sym(vec![4.0, 1.0, 2.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let b_matrix = sym(vec![2.0, 0.5, 1.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let result = eigsh_generalized_enhanced(
            &a_matrix,
            &b_matrix,
            Some(1),
            None,
            Some("bogus"),
            None,
            None,
        );
        assert!(matches!(result, Err(SparseError::ValueError(_))));
    }

    #[test]
    fn test_compute_generalized_matrix() {
        let a_matrix = sym(vec![3.0, 1.0, 4.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let b_matrix = sym(vec![1.0, 0.5, 2.0], vec![0, 1, 3], vec![0, 0, 1], 2);
        let result = compute_generalized_matrix(&a_matrix, &b_matrix);
        assert!(result.is_ok());
        let transformed = result.expect("Operation failed");
        assert_eq!(transformed.shape(), (2, 2));
        // Diagonal entries (positions 0 and 2) are shifted up; the off-diagonal is untouched.
        assert!(transformed.data[0] > 3.0);
        assert!((transformed.data[1] - 1.0).abs() < f64::EPSILON);
        assert!(transformed.data[2] > 4.0);
    }
}