use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
#[allow(dead_code)]
pub fn tridiagonal_eigvalsh<F>(
diagonal: &ArrayView1<F>,
off_diagonal: &ArrayView1<F>,
) -> LinalgResult<Array1<F>>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let n = diagonal.len();
if off_diagonal.len() != n - 1 {
return Err(LinalgError::ShapeError(format!(
"Off-_diagonal length ({}) must be one less than _diagonal length ({})",
off_diagonal.len(),
n
)));
}
let d = diagonal.to_owned();
let mut e = off_diagonal.to_owned();
let mut eigenvalues = d.clone();
let max_iter = 30 * n;
let compute_givens = |a: F, b: F| -> (F, F) {
let r = (a * a + b * b).sqrt();
if r < F::epsilon() {
(F::one(), F::zero())
} else {
(a / r, -b / r)
}
};
let tol = F::epsilon().sqrt()
* eigenvalues
.iter()
.fold(F::zero(), |max, &val| max.max(val.abs()));
for i in 0..n - 1 {
if e[i].abs() < tol {
e[i] = F::zero();
}
}
let mut m = n - 1;
let mut iter_count = 0;
while m > 0 && iter_count < max_iter {
let mut l = m;
while l > 0 {
if e[l - 1].abs() <= tol {
break;
}
l -= 1;
}
if l == m {
m -= 1;
continue;
}
let shift = eigenvalues[m];
let mut g = (eigenvalues[l] - shift) / (F::from(2.0).expect("Operation failed") * e[l]);
let mut r = (F::one() + g * g).sqrt();
if g < F::zero() {
r = -r;
}
g = eigenvalues[l] - shift + e[l] / (g + r);
let mut s = F::one();
let mut c = F::one();
let mut p = F::zero();
for i in l..m {
let f = s * e[i];
let b = c * e[i];
let (c_i, s_i) = compute_givens(g, f);
c = c_i;
s = s_i;
if i > l {
e[i - 1] = r;
}
g = eigenvalues[i + 1] - p;
r = (eigenvalues[i] - g) * s + F::from(2.0).expect("Operation failed") * c * b;
p = s * r;
eigenvalues[i] = g + p;
g = c * r - b;
}
eigenvalues[m] -= p;
e[m - 1] = g;
iter_count += 1;
}
if iter_count >= max_iter {
return Err(LinalgError::ConvergenceError(
"Maximum iterations reached in tridiagonal_eigvalsh".to_string(),
));
}
let mut sorted = eigenvalues.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
for i in 0..n {
eigenvalues[i] = sorted[i];
}
Ok(eigenvalues)
}
/// Computes the eigenvalues and eigenvectors of a real symmetric tridiagonal
/// matrix.
///
/// The matrix is assembled densely from its main diagonal and first
/// off-diagonal and then reduced with the unshifted QR iteration
/// `A_{k+1} = R_k * Q_k`, where `A_k = Q_k * R_k`. The orthogonal factors are
/// accumulated so that the returned eigenvector matrix is the product
/// `Q_1 * Q_2 * ... * Q_k`, whose column `j` belongs to eigenvalue `j`.
///
/// # Arguments
///
/// * `diagonal` - The `n` main-diagonal entries.
/// * `off_diagonal` - The `n - 1` off-diagonal entries; entry `i` couples
///   rows `i` and `i + 1`.
///
/// # Returns
///
/// A pair `(eigenvalues, eigenvectors)` with the eigenvalues in ascending
/// order and the columns of `eigenvectors` permuted to match.
///
/// # Errors
///
/// * `LinalgError::ShapeError` if `off_diagonal.len() + 1 != diagonal.len()`.
///   This includes an empty `diagonal`.
/// * Any error returned by `crate::decomposition::qr`.
/// * `LinalgError::ConvergenceError` if, after 100 iterations, some
///   off-diagonal entry still exceeds `100 * F::epsilon()` in absolute value
///   or the iterate contains NaN.
#[allow(dead_code)]
pub fn tridiagonal_eigh<F>(
    diagonal: &ArrayView1<F>,
    off_diagonal: &ArrayView1<F>,
) -> LinalgResult<(Array1<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    const MAX_ITER: usize = 100;

    let n = diagonal.len();
    // Written as `len + 1 != n` so that an empty diagonal cannot underflow `n - 1`.
    if off_diagonal.len() + 1 != n {
        return Err(LinalgError::ShapeError(format!(
            "Off-diagonal length ({}) must be one less than diagonal length ({})",
            off_diagonal.len(),
            n
        )));
    }

    // Dense copy of the symmetric tridiagonal matrix.
    let mut trimatrix = Array2::<F>::zeros((n, n));
    for (i, &value) in diagonal.iter().enumerate() {
        trimatrix[[i, i]] = value;
    }
    for (i, &value) in off_diagonal.iter().enumerate() {
        trimatrix[[i, i + 1]] = value;
        trimatrix[[i + 1, i]] = value;
    }

    // Absolute threshold below which an off-diagonal entry counts as zero.
    let tol = F::epsilon() * F::from(100.0).expect("Operation failed");

    // Running product of the Q factors; its columns become the eigenvectors.
    let mut eigenvectors = Array2::<F>::eye(n);

    for _ in 0..MAX_ITER {
        let (q, r) = crate::decomposition::qr(&trimatrix.view(), None)?;
        eigenvectors = eigenvectors.dot(&q);
        trimatrix = r.dot(&q);

        // Converged once every off-diagonal entry is negligible. NaN fails
        // both branches, so a poisoned iterate can never reach the sort below.
        let converged = trimatrix.indexed_iter().all(|((i, j), &value)| {
            if i == j {
                !value.is_nan()
            } else {
                value.abs() <= tol
            }
        });
        if !converged {
            continue;
        }

        let diag_values: Vec<F> = (0..n).map(|i| trimatrix[[i, i]]).collect();

        // Permutation that puts the eigenvalues in ascending order.
        let mut order: Vec<usize> = (0..n).collect();
        order.sort_by(|&a, &b| {
            diag_values[a]
                .partial_cmp(&diag_values[b])
                .expect("eigenvalues are NaN-free after the convergence check")
        });

        let mut sorted_eigenvalues = Array1::zeros(n);
        let mut sorted_eigenvectors = Array2::zeros((n, n));
        for (pos, &idx) in order.iter().enumerate() {
            sorted_eigenvalues[pos] = diag_values[idx];
            for row in 0..n {
                sorted_eigenvectors[[row, pos]] = eigenvectors[[row, idx]];
            }
        }
        return Ok((sorted_eigenvalues, sorted_eigenvectors));
    }

    Err(LinalgError::ConvergenceError(
        "QR algorithm did not converge for tridiagonal matrix".to_string(),
    ))
}