use ariadnetor_core::Scalar;
#[cfg(feature = "arpack")]
use crate::krylov::ArpackParams;
use crate::krylov::LanczosParams;
/// Chooses which eigensolver handles a local eigenproblem and carries that
/// solver's parameters.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum LocalEigensolverParams {
    /// Use the Lanczos solver with these parameters.
    Lanczos(LanczosParams),
    /// Use the ARPACK solver with these parameters.
    ///
    /// This variant only exists when the `arpack` feature is enabled.
    #[cfg(feature = "arpack")]
    Arpack(ArpackParams),
}
impl Default for LocalEigensolverParams {
fn default() -> Self {
Self::Lanczos(LanczosParams::default())
}
}
impl From<LanczosParams> for LocalEigensolverParams {
fn from(p: LanczosParams) -> Self {
Self::Lanczos(p)
}
}
/// Wraps ARPACK parameters in the `Arpack` variant, unchanged.
///
/// Only compiled when the `arpack` feature is enabled.
#[cfg(feature = "arpack")]
impl From<ArpackParams> for LocalEigensolverParams {
    fn from(arpack: ArpackParams) -> Self {
        LocalEigensolverParams::Arpack(arpack)
    }
}
/// Checks that the parameters of the selected solver are usable.
///
/// # Errors
///
/// Returns a static message naming the offending field when:
/// - `max_iter` is zero (either solver);
/// - `tol` is NaN or infinite (either solver);
/// - `tol` is negative for Lanczos, or not strictly positive for ARPACK.
///
/// The checks run in that order, and the first failure is reported. A zero
/// tolerance is therefore accepted for Lanczos but rejected for ARPACK.
pub(crate) fn validate_eigensolver_params(
    params: &LocalEigensolverParams,
) -> Result<(), &'static str> {
    match params {
        LocalEigensolverParams::Lanczos(lanczos) => {
            let tol = lanczos.tol;
            if lanczos.max_iter == 0 {
                Err("lanczos.max_iter must be >= 1")
            } else if !tol.is_finite() {
                Err("lanczos.tol must be finite")
            } else if tol < 0.0 {
                Err("lanczos.tol must be non-negative")
            } else {
                Ok(())
            }
        }
        #[cfg(feature = "arpack")]
        LocalEigensolverParams::Arpack(arpack) => {
            let tol = arpack.tol;
            if arpack.max_iter == 0 {
                Err("arpack.max_iter must be >= 1")
            } else if !tol.is_finite() {
                Err("arpack.tol must be finite")
            } else if tol <= 0.0 {
                Err("arpack.tol must be strictly positive")
            } else {
                Ok(())
            }
        }
    }
}
/// Returns the `tol` field of whichever solver `params` selects.
pub(crate) fn eigensolver_tol(params: &LocalEigensolverParams) -> f64 {
    match *params {
        LocalEigensolverParams::Lanczos(ref lanczos) => lanczos.tol,
        #[cfg(feature = "arpack")]
        LocalEigensolverParams::Arpack(ref arpack) => arpack.tol,
    }
}
/// Scalar bound used by this crate's DMRG code (feature `arpack` disabled).
///
/// With the `arpack` feature off, this is exactly [`Scalar`]. The blanket impl
/// below implements it for every `Scalar`, so it behaves as a trait alias.
#[cfg(not(feature = "arpack"))]
pub trait DmrgScalar: Scalar {}

#[cfg(not(feature = "arpack"))]
impl<T> DmrgScalar for T where T: Scalar {}

/// Scalar bound used by this crate's DMRG code (feature `arpack` enabled).
///
/// With the `arpack` feature on, a type must implement both [`Scalar`] and
/// `crate::krylov::ArpackScalar`. The blanket impl below implements it for
/// every such type, so it behaves as a trait alias.
#[cfg(feature = "arpack")]
pub trait DmrgScalar: Scalar + crate::krylov::ArpackScalar {}

#[cfg(feature = "arpack")]
impl<T> DmrgScalar for T where T: Scalar + crate::krylov::ArpackScalar {}
#[cfg(test)]
mod tests {
    //! Unit tests for solver selection, conversions, tolerance lookup and
    //! parameter validation.
    use super::*;

    #[test]
    fn from_lanczos_preserves_payload() {
        let p = LocalEigensolverParams::from(LanczosParams {
            max_iter: 7,
            tol: 0.5,
            seed: None,
        });
        match p {
            LocalEigensolverParams::Lanczos(q) => {
                assert_eq!(q.max_iter, 7);
                assert_eq!(q.tol, 0.5);
            }
            // Only reachable (and only needed) when a second variant exists.
            #[cfg(feature = "arpack")]
            other => panic!("expected Lanczos variant, got {other:?}"),
        }
    }

    #[test]
    fn default_selects_lanczos() {
        match LocalEigensolverParams::default() {
            LocalEigensolverParams::Lanczos(_) => {}
            #[cfg(feature = "arpack")]
            other => panic!("expected Lanczos variant, got {other:?}"),
        }
    }

    #[test]
    fn eigensolver_tol_reads_lanczos_tol() {
        let p = LocalEigensolverParams::Lanczos(LanczosParams {
            tol: 0.5,
            ..LanczosParams::default()
        });
        assert_eq!(eigensolver_tol(&p), 0.5);
    }

    #[test]
    fn validate_accepts_valid_lanczos_including_zero_tol() {
        // Lanczos treats zero tolerance as valid (only negatives are rejected).
        for &tol in &[0.0, 0.5] {
            let p = LocalEigensolverParams::Lanczos(LanczosParams {
                max_iter: 7,
                tol,
                seed: None,
            });
            assert_eq!(validate_eigensolver_params(&p), Ok(()), "tol = {tol}");
        }
    }

    #[test]
    fn validate_rejects_zero_lanczos_max_iter() {
        let p = LocalEigensolverParams::Lanczos(LanczosParams {
            max_iter: 0,
            tol: 0.5,
            seed: None,
        });
        assert_eq!(
            validate_eigensolver_params(&p),
            Err("lanczos.max_iter must be >= 1")
        );
    }

    #[test]
    fn validate_rejects_bad_lanczos_tol() {
        let cases = [
            (f64::NAN, "lanczos.tol must be finite"),
            (f64::INFINITY, "lanczos.tol must be finite"),
            (f64::NEG_INFINITY, "lanczos.tol must be finite"),
            (-1e-3, "lanczos.tol must be non-negative"),
        ];
        for &(tol, msg) in &cases {
            let p = LocalEigensolverParams::Lanczos(LanczosParams {
                max_iter: 7,
                tol,
                seed: None,
            });
            assert_eq!(validate_eigensolver_params(&p), Err(msg), "tol = {tol}");
        }
    }

    #[cfg(feature = "arpack")]
    #[test]
    fn from_arpack_preserves_payload_and_tol() {
        let p = LocalEigensolverParams::from(ArpackParams {
            tol: 0.5,
            max_iter: 7,
            ncv: None,
        });
        match &p {
            LocalEigensolverParams::Arpack(q) => {
                assert_eq!(q.max_iter, 7);
                assert_eq!(q.tol, 0.5);
            }
            other => panic!("expected Arpack variant, got {other:?}"),
        }
        assert_eq!(eigensolver_tol(&p), 0.5);
    }

    #[cfg(feature = "arpack")]
    #[test]
    fn validate_arpack_params() {
        let valid = LocalEigensolverParams::Arpack(ArpackParams {
            tol: 0.5,
            max_iter: 7,
            ncv: None,
        });
        assert_eq!(validate_eigensolver_params(&valid), Ok(()));

        let no_iter = LocalEigensolverParams::Arpack(ArpackParams {
            tol: 0.5,
            max_iter: 0,
            ncv: None,
        });
        assert_eq!(
            validate_eigensolver_params(&no_iter),
            Err("arpack.max_iter must be >= 1")
        );

        // Unlike Lanczos, ARPACK rejects a zero tolerance.
        let zero_tol = LocalEigensolverParams::Arpack(ArpackParams {
            tol: 0.0,
            max_iter: 7,
            ncv: None,
        });
        assert_eq!(
            validate_eigensolver_params(&zero_tol),
            Err("arpack.tol must be strictly positive")
        );

        let nan_tol = LocalEigensolverParams::Arpack(ArpackParams {
            tol: f64::NAN,
            max_iter: 7,
            ncv: None,
        });
        assert_eq!(
            validate_eigensolver_params(&nan_tol),
            Err("arpack.tol must be finite")
        );
    }
}