use crate::algorithm::sparse_linalg::SparseLinAlgAlgorithms;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, LinalgOps, ReduceOps, ScalarOps, UnaryOps};
use crate::runtime::Runtime;
use crate::sparse::{CsrData, SparseOps};
use crate::tensor::Tensor;
use super::super::helpers::{
INVARIANT_SUBSPACE_TOL, REORTH_TOL, accumulate_basis_combination, vector_dot, vector_norm,
};
use super::super::types::{SparseEigComplexResult, SparseEigOptions};
use super::dense_eig::{hessenberg_eig, select_eigenvalues_complex};
/// Computes up to `k` eigenvalues (and associated vectors) of the square sparse
/// matrix `a` with a restarted Arnoldi iteration.
///
/// Outline of what the body does:
/// 1. Validates that `a` is square, that `1 <= k <= n`, and that the value
///    dtype is `F32` or `F64`.
/// 2. Picks the Krylov dimension `ncv` (`options.ncv`, defaulting to
///    `min(2k + 1, n)`), then clamps it into `[k + 1, n]`.
/// 3. Starts from the normalized all-ones vector and extends the basis with
///    two Gram-Schmidt passes per new vector, recording the projection
///    coefficients as columns of an upper-Hessenberg matrix.
/// 4. Passes the dense projected matrix to `hessenberg_eig`, selects `k`
///    eigenvalues according to `options.which`, and counts a selected pair as
///    converged when `|h_{m+1,m} * last component of its vector| < options.tol`.
/// 5. Returns via `build_result` once `k` pairs converged or the basis spans
///    the whole space (`m >= n`); otherwise calls `thick_restart` and repeats,
///    at most `options.max_iter` times.
///
/// # Errors
/// * `Error::ShapeMismatch` if `a` is not square.
/// * `Error::Internal` if `k` is outside `[1, n]`, if no Arnoldi column could
///   be produced, or if `options.max_iter` restart cycles finish without
///   convergence.
/// * `Error::UnsupportedDType` for value dtypes other than `F32`/`F64`.
/// * Any error propagated from the tensor/sparse operations.
pub fn arnoldi_eig_impl<R, C>(
    client: &C,
    a: &CsrData<R>,
    k: usize,
    options: SparseEigOptions,
) -> Result<SparseEigComplexResult<R>>
where
    R: Runtime<DType = DType>,
    R::Client: SparseOps<R>,
    C: SparseLinAlgAlgorithms<R>
        + SparseOps<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + ReduceOps<R>
        + ScalarOps<R>
        + LinalgOps<R>,
{
    let [nrows, ncols] = a.shape;
    if nrows != ncols {
        return Err(Error::ShapeMismatch {
            expected: vec![nrows, nrows],
            got: vec![nrows, ncols],
        });
    }
    let n = nrows;
    if k == 0 || k > n {
        return Err(Error::Internal(format!(
            "k must be in [1, n], got k={k}, n={n}"
        )));
    }
    let device = a.values().device();
    let dtype = a.values().dtype();
    if !matches!(dtype, DType::F32 | DType::F64) {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "arnoldi_eig",
        });
    }

    // Krylov subspace size: user value or 2k+1, always within [k+1, n].
    let ncv = options.ncv.unwrap_or_else(|| (2 * k + 1).min(n));
    let ncv = ncv.max(k + 1).min(n);

    // Deterministic start vector: ones / sqrt(n), i.e. unit 2-norm.
    let ones = Tensor::<R>::ones(&[n], dtype, device);
    let scale = 1.0 / (n as f64).sqrt();
    let v0 = client.mul_scalar(&ones, scale)?;

    let mut v_basis: Vec<Tensor<R>> = Vec::with_capacity(ncv + 1);
    // h_columns[j] holds rows 0..=j+1 of column j of the Hessenberg matrix.
    let mut h_columns: Vec<Vec<f64>> = Vec::with_capacity(ncv);
    v_basis.push(v0);

    for restart_iter in 0..options.max_iter {
        // After a restart the first `h_columns.len()` columns already exist.
        let start_j = h_columns.len();
        for j in start_j..ncv {
            // `thick_restart` only appends a continuation vector when its
            // residual is above INVARIANT_SUBSPACE_TOL, and it may even leave
            // the basis empty. Without this guard `v_basis[j]` would panic in
            // those cases; stopping here mirrors the invariant-subspace break
            // further down.
            if j >= v_basis.len() {
                break;
            }
            let vj = &v_basis[j];
            let mut w = a.spmv(vj)?;
            let mut h_col = vec![0.0f64; j + 2];

            // First Gram-Schmidt pass against the current basis.
            for i in 0..=j {
                let h_ij = vector_dot(client, &v_basis[i], &w)?;
                h_col[i] = h_ij;
                if h_ij.abs() > REORTH_TOL {
                    let scaled = client.mul_scalar(&v_basis[i], h_ij)?;
                    w = client.sub(&w, &scaled)?;
                }
            }
            // Second pass (re-orthogonalization); corrections are folded back
            // into the recorded coefficients.
            for i in 0..=j {
                let h_corr = vector_dot(client, &v_basis[i], &w)?;
                if h_corr.abs() > REORTH_TOL {
                    h_col[i] += h_corr;
                    let scaled = client.mul_scalar(&v_basis[i], h_corr)?;
                    w = client.sub(&w, &scaled)?;
                }
            }

            let h_next = vector_norm(client, &w)?;
            h_col[j + 1] = h_next;
            h_columns.push(h_col);
            if h_next < INVARIANT_SUBSPACE_TOL {
                // Residual vanished: the basis cannot be extended further.
                break;
            }
            let v_next = client.mul_scalar(&w, 1.0 / h_next)?;
            if v_basis.len() <= j + 1 {
                v_basis.push(v_next);
            } else {
                v_basis[j + 1] = v_next;
            }
        }

        let m = h_columns.len();
        if m == 0 {
            return Err(Error::Internal("Arnoldi produced no vectors".to_string()));
        }

        // Dense m x m copy of the projected matrix; the (m+1)-th row entry of
        // the last column is dropped here and used separately as `last_h`.
        let mut h_dense = vec![vec![0.0f64; m]; m];
        for (j, column) in h_columns.iter().enumerate() {
            for (i, &value) in column.iter().take(m).enumerate() {
                h_dense[i][j] = value;
            }
        }
        let (eig_real, eig_imag, schur_vecs) = hessenberg_eig(&h_dense, m);
        let indices = select_eigenvalues_complex(&eig_real, &eig_imag, k.min(m), &options.which);

        // Sub-diagonal entry below the last column (0.0 if it was not stored).
        let last_h = h_columns[m - 1].get(m).copied().unwrap_or(0.0);

        // Residual estimate per selected pair: |last_h * last vector component|.
        let mut nconv = 0;
        for &idx in &indices {
            if idx < schur_vecs.len() && !schur_vecs[idx].is_empty() {
                let last_comp = schur_vecs[idx].last().copied().unwrap_or(0.0);
                let error = (last_h * last_comp).abs();
                if error < options.tol {
                    nconv += 1;
                }
            }
        }

        if nconv >= k || m >= n {
            return build_result(
                client,
                &v_basis,
                &eig_real,
                &eig_imag,
                &schur_vecs,
                &indices,
                k,
                m,
                n,
                dtype,
                device,
                restart_iter + 1,
                nconv >= k,
                nconv,
            );
        }

        thick_restart(
            client,
            &mut v_basis,
            &mut h_columns,
            a,
            &schur_vecs,
            &indices,
            k,
            m,
            n,
            dtype,
            device,
        )?;
    }

    Err(Error::Internal(format!(
        "Arnoldi eigensolver did not converge after {} restart cycles",
        options.max_iter
    )))
}
/// Packs the selected eigenvalues and their vectors into a
/// `SparseEigComplexResult`.
///
/// For each of the first `min(k, indices.len())` entries of `indices`, the
/// real and imaginary eigenvalue parts are copied out and a length-`n` vector
/// is produced by `accumulate_basis_combination` from the leading basis
/// vectors and the matching `schur_vecs` row (presumably the linear
/// combination `sum_i y[i] * v_basis[i]` — confirm against the helper). An
/// index with no `schur_vecs` entry yields a zero vector instead.
///
/// `iterations`, `converged` and `nconv` are forwarded unchanged.
// The argument list mirrors the solver state handed over by the caller.
#[allow(clippy::too_many_arguments)]
fn build_result<R, C>(
    client: &C,
    v_basis: &[Tensor<R>],
    eig_real: &[f64],
    eig_imag: &[f64],
    schur_vecs: &[Vec<f64>],
    indices: &[usize],
    k: usize,
    m: usize,
    n: usize,
    dtype: DType,
    device: &R::Device,
    iterations: usize,
    converged: bool,
    nconv: usize,
) -> Result<SparseEigComplexResult<R>>
where
    R: Runtime<DType = DType>,
    C: BinaryOps<R> + ScalarOps<R>,
{
    let count = k.min(indices.len());
    let mut real_parts = Vec::with_capacity(count);
    let mut imag_parts = Vec::with_capacity(count);
    let mut vectors: Vec<Tensor<R>> = Vec::with_capacity(count);

    for idx in indices.iter().take(count).copied() {
        real_parts.push(eig_real[idx]);
        imag_parts.push(eig_imag[idx]);

        let vector = match schur_vecs.get(idx) {
            Some(coeffs) => {
                // Never use more basis vectors than the subspace dimension `m`.
                let span = coeffs.len().min(m);
                accumulate_basis_combination(client, &v_basis[..span], coeffs, n, dtype, device)?
            }
            None => Tensor::<R>::zeros(&[n], dtype, device),
        };
        vectors.push(vector);
    }

    let eigenvalues_real = Tensor::<R>::from_slice(&real_parts, &[count], device);
    let eigenvalues_imag = Tensor::<R>::from_slice(&imag_parts, &[count], device);

    // Each vector occupies one contiguous run of `n` entries in `packed`.
    // NOTE(review): the buffer is then labelled with shape [n, count]; confirm
    // this matches the element order `Tensor::from_slice` expects.
    // NOTE(review): values are read back as f64 irrespective of `dtype` —
    // confirm `to_vec` supports that for F32 tensors.
    let mut packed = vec![0.0f64; n * count];
    for (col, vector) in vectors.iter().enumerate() {
        let values: Vec<f64> = vector.to_vec();
        let begin = col * n;
        let end = (col + 1) * n;
        packed[begin..end].copy_from_slice(&values);
    }
    let eigenvectors = Tensor::<R>::from_slice(&packed, &[n, count], device);

    Ok(SparseEigComplexResult {
        eigenvalues_real,
        eigenvalues_imag,
        eigenvectors,
        iterations,
        converged,
        nconv,
    })
}
/// Replaces the Arnoldi basis with the selected vectors and rebuilds the
/// leading Hessenberg columns so the outer iteration can continue from them.
///
/// Steps, in order:
/// 1. For each of the first `min(k, indices.len())` selected indices that has
///    a `schur_vecs` entry, form a vector with `accumulate_basis_combination`
///    over the leading basis vectors and normalize it when its norm exceeds
///    `REORTH_TOL`.
/// 2. Orthonormalize those vectors in sequence: each one is projected against
///    all earlier (already processed) vectors and renormalized.
/// 3. Overwrite `v_basis` with the result, clear `h_columns`, and for every
///    kept vector `v_j` store the dot products of `A v_j` with `v_0..=v_j`
///    plus one sub-diagonal entry. For the last kept vector the sub-diagonal
///    entry is the norm of the residual of `A v_j`; when that norm exceeds
///    `INVARIANT_SUBSPACE_TOL` the normalized residual is appended to
///    `v_basis` as the next basis vector.
///
/// On return `h_columns.len()` equals the number of kept vectors, and
/// `v_basis` has either that many entries or one more.
// The argument list mirrors the solver state handed over by the caller.
#[allow(clippy::too_many_arguments)]
fn thick_restart<R, C>(
    client: &C,
    v_basis: &mut Vec<Tensor<R>>,
    h_columns: &mut Vec<Vec<f64>>,
    a: &CsrData<R>,
    schur_vecs: &[Vec<f64>],
    indices: &[usize],
    k: usize,
    m: usize,
    n: usize,
    dtype: DType,
    device: &R::Device,
) -> Result<()>
where
    R: Runtime<DType = DType>,
    R::Client: SparseOps<R>,
    C: SparseLinAlgAlgorithms<R>
        + SparseOps<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + ReduceOps<R>
        + ScalarOps<R>,
{
    // Step 1: build (and normalize) the vectors that survive the restart.
    let keep = k.min(indices.len());
    let mut kept_vectors: Vec<Tensor<R>> = Vec::with_capacity(keep);
    for idx in indices.iter().take(keep).copied() {
        if let Some(coeffs) = schur_vecs.get(idx) {
            let span = coeffs.len().min(m);
            let combined =
                accumulate_basis_combination(client, &v_basis[..span], coeffs, n, dtype, device)?;
            let length = vector_norm(client, &combined)?;
            let unit = if length > REORTH_TOL {
                client.mul_scalar(&combined, 1.0 / length)?
            } else {
                // Too small to normalize safely; keep as-is.
                combined
            };
            kept_vectors.push(unit);
        }
    }

    // Step 2: sequential Gram-Schmidt against the already processed vectors.
    for pos in 1..kept_vectors.len() {
        let (settled, pending) = kept_vectors.split_at_mut(pos);
        let mut candidate = pending[0].clone();
        for earlier in settled.iter() {
            let overlap = vector_dot(client, earlier, &candidate)?;
            if overlap.abs() > REORTH_TOL {
                let projection = client.mul_scalar(earlier, overlap)?;
                candidate = client.sub(&candidate, &projection)?;
            }
        }
        let length = vector_norm(client, &candidate)?;
        if length > REORTH_TOL {
            candidate = client.mul_scalar(&candidate, 1.0 / length)?;
        }
        pending[0] = candidate;
    }

    // Step 3: install the new basis and recompute the projected columns.
    *v_basis = kept_vectors;
    h_columns.clear();
    let kept = v_basis.len();
    for j in 0..kept {
        let image = a.spmv(&v_basis[j])?;
        let mut column = vec![0.0f64; j + 2];
        for (entry, basis_vec) in column.iter_mut().zip(v_basis.iter().take(j + 1)) {
            *entry = vector_dot(client, basis_vec, &image)?;
        }

        if j + 1 == kept {
            // Last kept vector: the sub-diagonal entry is the residual norm,
            // and the residual direction (if non-negligible) continues the basis.
            let mut residual = image;
            for (&coeff, basis_vec) in column.iter().zip(v_basis.iter().take(j + 1)) {
                if coeff.abs() > REORTH_TOL {
                    let projection = client.mul_scalar(basis_vec, coeff)?;
                    residual = client.sub(&residual, &projection)?;
                }
            }
            let residual_norm = vector_norm(client, &residual)?;
            column[j + 1] = residual_norm;
            if residual_norm > INVARIANT_SUBSPACE_TOL {
                let continuation = client.mul_scalar(&residual, 1.0 / residual_norm)?;
                v_basis.push(continuation);
            }
        } else {
            // Interior column: sub-diagonal entry is the projection onto the
            // next kept vector.
            column[j + 1] = vector_dot(client, &v_basis[j + 1], &image)?;
        }
        h_columns.push(column);
    }
    Ok(())
}