use ariadnetor_core::Scalar;
use ariadnetor_tensor::{ComputeBackendTensorExt, DenseTensor, Host, linear_combine};
use num_complex::{Complex32, Complex64};
use num_traits::{NumCast, One, Zero};
use super::lanczos::LinearOp;
mod sealed {
pub trait Sealed {}
impl Sealed for f32 {}
impl Sealed for f64 {}
impl Sealed for num_complex::Complex<f32> {}
impl Sealed for num_complex::Complex<f64> {}
}
/// Tuning knobs forwarded to the ARPACK driver by `arpack_smallest`.
#[derive(Clone, Debug)]
pub struct ArpackParams {
    /// Convergence tolerance handed to ARPACK.
    ///
    /// It is also the threshold `arpack_smallest` compares the recomputed
    /// residual against. It must be finite and strictly positive.
    pub tol: f64,
    /// Iteration cap forwarded to ARPACK as `Options::max_iter`.
    pub max_iter: usize,
    /// Requested number of basis vectors (ARPACK's `ncv`).
    ///
    /// `None` leaves the choice to the `arpack` crate. Presumably this means
    /// its own default; TODO confirm.
    pub ncv: Option<usize>,
}

impl Default for ArpackParams {
    /// `tol = 1e-10`, `max_iter = 300`, `ncv = None`.
    fn default() -> Self {
        ArpackParams {
            ncv: None,
            max_iter: 300,
            tol: 1e-10,
        }
    }
}
/// Eigenpair and diagnostics produced by a successful `arpack_smallest` call.
#[derive(Clone, Debug)]
pub struct ArpackResult<T: Scalar> {
    /// Real part of the eigenvalue reported by ARPACK.
    pub eigenvalue: T::Real,
    /// Eigenvector as a rank-1 tensor of shape `[dim]`, normalized after ARPACK returns.
    pub eigenvector: DenseTensor<T>,
    /// Iteration count reported by the ARPACK driver.
    pub iters: usize,
    /// Operator applications reported by the ARPACK driver.
    ///
    /// This does not include the extra application used to compute `residual`.
    pub n_matvec: usize,
    /// Norm of `op(v) - eigenvalue * v`, recomputed from the normalized eigenvector.
    pub residual: T::Real,
    /// `true` when `residual <= tol`, an absolute comparison against `ArpackParams::tol`.
    pub converged: bool,
}
/// Failure modes of `arpack_smallest`.
///
/// Parameter validation produces `InvalidParam` directly. Every other
/// variant is translated from the corresponding `arpack::Error` variant.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ArpackError {
    /// A parameter was rejected; the payload describes which one and why.
    ///
    /// It is also used as the fallback for unrecognized `arpack::Error` variants.
    #[error("invalid parameter: {0}")]
    InvalidParam(&'static str),

    /// The ARPACK `*aupd` routine reported a failure with the given `info` code.
    #[error("ARPACK *aupd returned info = {0}")]
    AupdFailed(i32),

    /// The ARPACK `*eupd` routine reported a failure with the given `info` code.
    #[error("ARPACK *eupd returned info = {0}")]
    EupdFailed(i32),

    /// ARPACK requested an `ido` action that the driver does not handle.
    #[error("ARPACK requested unsupported ido = {0}")]
    UnexpectedIdo(i32),

    /// The iteration cap was reached before convergence.
    #[error(
        "ARPACK hit max_iter without convergence: iters = {iters}, \
         nconv = {nconv}, n_matvec = {n_matvec}"
    )]
    MaxIterReached {
        iters: usize,
        nconv: usize,
        n_matvec: usize,
    },

    /// ARPACK reported `info = 3`: no shifts could be applied during a restart.
    #[error(
        "ARPACK could not apply shifts during a restart cycle (info = 3); \
         increase ncv relative to nev: iters = {iters}, nconv = {nconv}, \
         n_matvec = {n_matvec}"
    )]
    NoShiftsApplied {
        iters: usize,
        nconv: usize,
        n_matvec: usize,
    },

    /// ARPACK reported `info = -9999`: the Arnoldi factorization could not be built.
    #[error(
        "ARPACK could not build an Arnoldi factorization (info = -9999); \
         built size {factorization_size}; try increasing max_iter or ncv: \
         iters = {iters}, n_matvec = {n_matvec}"
    )]
    ArnoldiFactorizationFailed {
        iters: usize,
        factorization_size: usize,
        n_matvec: usize,
    },
}
impl From<arpack::Error> for ArpackError {
fn from(e: arpack::Error) -> Self {
match e {
arpack::Error::InvalidParam(m) => ArpackError::InvalidParam(m),
arpack::Error::AupdFailed { info, .. } => ArpackError::AupdFailed(info),
arpack::Error::EupdFailed { info, .. } => ArpackError::EupdFailed(info),
arpack::Error::UnexpectedIdo(i) => ArpackError::UnexpectedIdo(i),
arpack::Error::MaxIterReached {
iters,
nconv,
n_matvec,
} => ArpackError::MaxIterReached {
iters,
nconv,
n_matvec,
},
arpack::Error::NoShiftsApplied {
iters,
nconv,
n_matvec,
} => ArpackError::NoShiftsApplied {
iters,
nconv,
n_matvec,
},
arpack::Error::ArnoldiFactorizationFailed {
iters,
factorization_size,
n_matvec,
} => ArpackError::ArnoldiFactorizationFailed {
iters,
factorization_size,
n_matvec,
},
_ => ArpackError::InvalidParam("unrecognized arpack error variant"),
}
}
}
/// Scalar types that have an ARPACK smallest-eigenpair driver wired up.
///
/// The trait is sealed and implemented only for `f32`, `f64`, `Complex32`
/// and `Complex64`.
pub trait ArpackScalar: Scalar + sealed::Sealed {
    /// Runs the ARPACK driver for `Self` on an `n`-dimensional problem.
    ///
    /// `matvec(x, y)` must write the operator applied to `x` into `y`. Both
    /// slices are presumably of length `n`; confirm against the `arpack` crate.
    fn solve(
        n: usize,
        matvec: &mut dyn FnMut(&[Self], &mut [Self]),
        params: &ArpackParams,
    ) -> Result<arpack::EigSolution<Self>, arpack::Error>;
}
impl ArpackScalar for f32 {
fn solve(
n: usize,
matvec: &mut dyn FnMut(&[Self], &mut [Self]),
params: &ArpackParams,
) -> Result<arpack::EigSolution<Self>, arpack::Error> {
let opts = arpack::symmetric::Options {
tol: params.tol,
max_iter: params.max_iter,
ncv: params.ncv,
};
arpack::symmetric::smallest_eigenpair_f32(n, matvec, &opts)
}
}
impl ArpackScalar for f64 {
fn solve(
n: usize,
matvec: &mut dyn FnMut(&[Self], &mut [Self]),
params: &ArpackParams,
) -> Result<arpack::EigSolution<Self>, arpack::Error> {
let opts = arpack::symmetric::Options {
tol: params.tol,
max_iter: params.max_iter,
ncv: params.ncv,
};
arpack::symmetric::smallest_eigenpair_f64(n, matvec, &opts)
}
}
impl ArpackScalar for Complex32 {
fn solve(
n: usize,
matvec: &mut dyn FnMut(&[Self], &mut [Self]),
params: &ArpackParams,
) -> Result<arpack::EigSolution<Self>, arpack::Error> {
let opts = arpack::arnoldi::Options {
tol: params.tol,
max_iter: params.max_iter,
ncv: params.ncv,
};
arpack::arnoldi::smallest_eigenpair_c32(n, matvec, &opts)
}
}
impl ArpackScalar for Complex64 {
fn solve(
n: usize,
matvec: &mut dyn FnMut(&[Self], &mut [Self]),
params: &ArpackParams,
) -> Result<arpack::EigSolution<Self>, arpack::Error> {
let opts = arpack::arnoldi::Options {
tol: params.tol,
max_iter: params.max_iter,
ncv: params.ncv,
};
arpack::arnoldi::smallest_eigenpair_c64(n, matvec, &opts)
}
}
/// Computes one eigenpair of `op` with ARPACK.
///
/// The pair returned is the one selected by the `arpack` crate's
/// `smallest_eigenpair_*` driver for `T`. Real scalars (`f32`, `f64`) go
/// through `arpack::symmetric`; complex scalars go through `arpack::arnoldi`.
///
/// After ARPACK returns, the eigenvector is normalized. The residual
/// `‖op(v) − λ·v‖` is then recomputed with one extra application of `op`,
/// which is not counted in `n_matvec`. The reported `eigenvalue` is the real
/// part of ARPACK's eigenvalue.
///
/// # Arguments
/// * `op` – linear operator acting on rank-1 tensors of shape `[dim]`.
/// * `dim` – problem dimension; must be at least 1.
/// * `params` – solver settings; `params.tol` must be finite and strictly positive.
///
/// # Errors
/// * [`ArpackError::InvalidParam`] if any of the following holds:
///   - `dim == 0`;
///   - `params.tol` is not finite and strictly positive;
///   - `params.tol` cannot be converted to `T::Real`.
///
///   All of these are checked before any ARPACK work is done.
/// * Any failure reported by the ARPACK driver, converted through
///   `From<arpack::Error> for ArpackError`.
///
/// # Panics
/// * If `op.apply` returns a tensor whose shape is not `[dim]`.
/// * If `linear_combine` rejects the residual computation. This is not
///   expected, because both operands are rank-1 tensors of shape `[dim]`.
///
/// `converged` in the result is an absolute test: `residual <= params.tol`.
/// NOTE(review): ARPACK documents `tol` as a relative accuracy bound, so a
/// successful solve may still report `converged == false` when `|λ|` is
/// large. Confirm that the absolute comparison is intended.
pub fn arpack_smallest<T, Op>(
    op: &Op,
    dim: usize,
    params: &ArpackParams,
) -> Result<ArpackResult<T>, ArpackError>
where
    T: ArpackScalar,
    T::Real: Scalar<Real = T::Real>,
    Op: LinearOp<T>,
{
    // Validate every parameter before the (potentially expensive) solve, and
    // report failures through the same error variant instead of panicking.
    if dim == 0 {
        return Err(ArpackError::InvalidParam("dim must be >= 1"));
    }
    if !params.tol.is_finite() || params.tol <= 0.0 {
        return Err(ArpackError::InvalidParam(
            "params.tol must be finite and strictly positive",
        ));
    }
    let tol_real: T::Real = <T::Real as NumCast>::from(params.tol).ok_or(
        ArpackError::InvalidParam("params.tol is not representable in the scalar's real type"),
    )?;

    let solution = T::solve(
        dim,
        &mut |x_slice, y_slice| {
            // ARPACK hands us raw slices; wrap the input as a [dim] tensor for `op`.
            let x_dense = Host::shared().dense(x_slice.to_vec(), vec![dim]);
            let y_dense = op.apply(&x_dense);
            assert_eq!(
                y_dense.shape(),
                &[dim],
                "LinearOp::apply must return a rank-1 tensor of shape [dim]",
            );
            y_slice.copy_from_slice(y_dense.data_slice());
        },
        params,
    )?;

    let eigenvalue = solution.eigenvalue.re();
    let mut eigenvector = Host::shared().dense(solution.eigenvector, vec![dim]);
    eigenvector.normalize();

    // Independent residual check: r = op(v) - λ v.
    let h_psi_raw = op.apply(&eigenvector);
    assert_eq!(
        h_psi_raw.shape(),
        &[dim],
        "LinearOp::apply must return a rank-1 tensor of shape [dim]",
    );
    // Bring op's output into the eigenvector's storage order before combining.
    let h_psi = if h_psi_raw.order() == eigenvector.order() {
        h_psi_raw
    } else {
        h_psi_raw.reordered(eigenvector.order())
    };
    let lambda_t = T::from_real_imag(eigenvalue, T::Real::zero());
    let neg_lambda = lambda_t.scale_real(-T::Real::one());
    let residual_vec = linear_combine(&[&h_psi, &eigenvector], &[T::one(), neg_lambda])
        .expect("linear_combine on rank-1 tensors of matching shape");
    let residual = residual_vec.norm();
    let converged = residual <= tol_real;

    Ok(ArpackResult {
        eigenvalue,
        eigenvector,
        iters: solution.iters,
        n_matvec: solution.n_matvec,
        residual,
        converged,
    })
}