use crate::error::{LinalgError, LinalgResult};
/// Output of Householder tridiagonalization: `(diagonal, off-diagonal, Q)`.
type TridiagResult = LinalgResult<(
    Vec<f64>,
    Vec<f64>,
    Vec<Vec<f64>>,
)>;
/// Tuning parameters for the divide-and-conquer eigensolver.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate obtains an
/// instance through [`Default`] and then adjusts individual fields.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct DcConfig {
    /// Convergence tolerance of the secular-equation root finder.
    pub tol: f64,
    /// Iteration budget of the secular-equation root finder.
    pub max_iter: usize,
    /// Relative tolerance used when deflating the merged rank-one problem.
    pub deflation_tol: f64,
}

impl Default for DcConfig {
    /// `tol = 1e-12`, `max_iter = 30`, `deflation_tol = 1e-10`.
    fn default() -> Self {
        DcConfig {
            deflation_tol: 1e-10,
            max_iter: 30,
            tol: 1e-12,
        }
    }
}
/// Computes all eigenvalues and eigenvectors of a symmetric tridiagonal matrix.
///
/// `diag` holds the `n` diagonal entries and `off_diag` the `n - 1`
/// sub/super-diagonal entries. Eigenvalues are returned in ascending order and
/// `evecs[k]` is the unit eigenvector paired with `evals[k]`. An empty `diag`
/// yields empty results.
///
/// # Errors
/// Returns [`LinalgError::DimensionError`] when `off_diag.len() != diag.len() - 1`.
pub fn dc_eig_tridiag(diag: &[f64], off_diag: &[f64]) -> LinalgResult<(Vec<f64>, Vec<Vec<f64>>)> {
    let n = diag.len();
    if n == 0 {
        return Ok((Vec::new(), Vec::new()));
    }
    let expected_off = n - 1;
    if off_diag.len() == expected_off {
        dc_tridiag_impl(diag, off_diag, &DcConfig::default())
    } else {
        Err(LinalgError::DimensionError(format!(
            "off_diag length {} != diag length {} - 1",
            off_diag.len(),
            n
        )))
    }
}
/// Computes all eigenvalues and eigenvectors of a dense symmetric matrix.
///
/// The matrix (given as rows) is reduced to tridiagonal form with Householder
/// reflections. The tridiagonal problem is then solved by divide and conquer,
/// and the eigenvectors are mapped back to the original basis. Symmetry is
/// assumed, not checked. Eigenvalues come back in ascending order, and
/// `evecs[k]` is paired with `evals[k]`.
///
/// # Errors
/// Returns [`LinalgError::DimensionError`] if some row's length differs from
/// the number of rows.
pub fn dc_eig_symmetric(a: &[Vec<f64>]) -> LinalgResult<(Vec<f64>, Vec<Vec<f64>>)> {
    let n = a.len();
    if n == 0 {
        return Ok((Vec::new(), Vec::new()));
    }
    if let Some(bad_row) = a.iter().find(|row| row.len() != n) {
        return Err(LinalgError::DimensionError(format!(
            "Matrix is not square: row has {} elements, expected {n}",
            bad_row.len()
        )));
    }
    let (diag, off_diag, q_house) = householder_tridiagonalize(a)?;
    let (evals, evecs_tri) = dc_tridiag_impl(&diag, &off_diag, &DcConfig::default())?;
    Ok((evals, back_transform_evecs(&q_house, &evecs_tri, n)))
}
/// Reduces a symmetric matrix to tridiagonal form with Householder reflections.
///
/// Returns `(diag, off_diag, q)` with `A ≈ Q T Q^T`. Here `T` is the tridiagonal
/// matrix with main diagonal `diag` and sub/super-diagonal `off_diag`, and `q`
/// is the accumulated orthogonal matrix. An eigenvector `y` of `T` corresponds
/// to the eigenvector `Q y` of `A`.
///
/// Each reflector is built from the lower column entries `a[k+1..][k]`, while
/// `off_diag` is read from the upper triangle. The result is therefore only
/// meaningful for symmetric input, which is not checked here.
fn householder_tridiagonalize(a: &[Vec<f64>]) -> TridiagResult {
    let n = a.len();
    let mut mat: Vec<Vec<f64>> = a.to_vec();
    let mut q = identity_matrix(n);
    // Largest entry magnitude: reference scale for deciding a column is reduced.
    let a_scale = a
        .iter()
        .flat_map(|row| row.iter())
        .fold(0.0_f64, |m, &x| m.max(x.abs()));
    for k in 0..n.saturating_sub(2) {
        let sz = n - k - 1;
        let mut x: Vec<f64> = (0..sz).map(|i| mat[k + 1 + i][k]).collect();
        // Entries below the subdiagonal that are negligible relative to the
        // matrix scale (in particular exactly zero) mean this column is already
        // tridiagonal. Dropping them is a backward-stable perturbation of size
        // eps * max|a_ij|. The test is relative, so small-scale matrices are
        // still reduced.
        if vec_norm(&x[1..]) <= f64::EPSILON * a_scale {
            continue;
        }
        let norm_x = vec_norm(&x);
        // Reflect x onto -sign(x0) * ||x|| * e1. Adding (not subtracting) the
        // norm avoids cancellation in v = x + sign(x0) * ||x|| * e1.
        let sign = if x[0] >= 0.0 { 1.0_f64 } else { -1.0_f64 };
        x[0] += sign * norm_x;
        let norm_v = vec_norm(&x);
        if norm_v == 0.0 {
            // Only reachable if squaring underflowed; leave the column as is.
            continue;
        }
        for xi in &mut x {
            *xi /= norm_v;
        }
        // Two-sided update of the trailing block, A <- H A H with H = I - 2 v v^T,
        // written as A - (v w^T + w v^T) where w = 2 A v - 2 (v^T A v) v.
        let p: Vec<f64> = (0..sz)
            .map(|i| {
                2.0 * mat[k + 1 + i][k + 1..]
                    .iter()
                    .zip(&x)
                    .map(|(aij, vj)| aij * vj)
                    .sum::<f64>()
            })
            .collect();
        let vtp: f64 = x.iter().zip(&p).map(|(vi, pi)| vi * pi).sum();
        let w: Vec<f64> = p.iter().zip(&x).map(|(pi, vi)| pi - vtp * vi).collect();
        for i in 0..sz {
            for j in 0..sz {
                mat[k + 1 + i][k + 1 + j] -= x[i] * w[j] + w[i] * x[j];
            }
        }
        for i in 1..sz {
            mat[k + 1 + i][k] = 0.0;
            mat[k][k + 1 + i] = 0.0;
        }
        let off = -sign * norm_x;
        mat[k + 1][k] = off;
        mat[k][k + 1] = off;
        // Accumulate Q <- Q H; only columns k+1.. are affected.
        for row in q.iter_mut() {
            let qv: f64 = row[k + 1..].iter().zip(&x).map(|(qij, vj)| qij * vj).sum();
            for (qij, &vj) in row[k + 1..].iter_mut().zip(&x) {
                *qij -= 2.0 * qv * vj;
            }
        }
    }
    let diag: Vec<f64> = (0..n).map(|i| mat[i][i]).collect();
    let off_diag: Vec<f64> = (0..n.saturating_sub(1)).map(|i| mat[i][i + 1]).collect();
    Ok((diag, off_diag, q))
}
/// Returns the `n x n` identity matrix stored as a vector of rows.
fn identity_matrix(n: usize) -> Vec<Vec<f64>> {
    (0..n)
        .map(|row| {
            (0..n)
                .map(|col| if row == col { 1.0 } else { 0.0 })
                .collect::<Vec<f64>>()
        })
        .collect()
}
/// Euclidean (L2) norm of `v`.
///
/// Entries are divided by the largest magnitude before squaring. This keeps the
/// result from overflowing to `inf` for entries above ~1e154 and from
/// underflowing to `0` for entries below ~1e-154. If every entry is zero, or
/// the largest magnitude is infinite, the unscaled formula is used instead; it
/// gives `0`, `inf`, or NaN in those cases. NaN entries propagate.
fn vec_norm(v: &[f64]) -> f64 {
    // f64::max ignores NaN, so `scale` is the largest non-NaN magnitude.
    let scale = v.iter().fold(0.0_f64, |m, &x| m.max(x.abs()));
    if scale == 0.0 || scale.is_infinite() {
        return v.iter().map(|x| x * x).sum::<f64>().sqrt();
    }
    let scaled_sq: f64 = v
        .iter()
        .map(|&x| {
            let r = x / scale;
            r * r
        })
        .sum();
    scale * scaled_sq.sqrt()
}
/// Maps eigenvectors of the tridiagonal matrix back to the original basis.
///
/// Every input vector `v` becomes `Q v`, i.e. `out[i] = Σ_{j<n} q[i][j] * v[j]`,
/// accumulated in increasing `j`.
fn back_transform_evecs(q: &[Vec<f64>], evecs_tri: &[Vec<f64>], n: usize) -> Vec<Vec<f64>> {
    let mut transformed = Vec::with_capacity(evecs_tri.len());
    for v in evecs_tri {
        let v = &v[..n];
        let image: Vec<f64> = q[..n]
            .iter()
            .map(|row| {
                row[..n]
                    .iter()
                    .zip(v)
                    .fold(0.0f64, |acc, (&qij, &vj)| acc + qij * vj)
            })
            .collect();
        transformed.push(image);
    }
    transformed
}
/// Recursive divide-and-conquer eigensolver for a symmetric tridiagonal matrix.
///
/// `diag` holds the `n >= 1` diagonal entries and `off_diag` the `n - 1`
/// off-diagonal entries (the public wrappers validate the lengths). Returns the
/// eigenvalues in ascending order and, in the same order, unit eigenvectors
/// (one per inner `Vec`).
///
/// Problems with `n <= 8` go to `qr_tridiag_eigen`. Larger ones are torn at
/// `mid = n / 2`:
///
/// `T = diag(T1, T2) + |beta| * w * w^T`, with `w = sign(beta) * e_{mid-1} + e_mid`,
///
/// where `beta = off_diag[mid - 1]` and `T1`/`T2` have `|beta|` subtracted from
/// the diagonal entries next to the cut. After `T1 = Q1 D1 Q1^T` and
/// `T2 = Q2 D2 Q2^T` are solved recursively, the remaining problem is
/// `D + |beta| * z * z^T` with `z = diag(Q1, Q2)^T w`. It is solved by
/// `merge_rank_one`, and the resulting vectors are multiplied by `diag(Q1, Q2)`.
///
/// NOTE: secular eigenvectors are formed directly as `z_k / (d_k - lambda)`,
/// without the Gu–Eisenstat recomputation of `z` that LAPACK uses. Orthogonality
/// may therefore degrade when a root lies extremely close to a pole.
fn dc_tridiag_impl(
    diag: &[f64],
    off_diag: &[f64],
    config: &DcConfig,
) -> LinalgResult<(Vec<f64>, Vec<Vec<f64>>)> {
    let n = diag.len();
    if n == 1 {
        return Ok((vec![diag[0]], vec![vec![1.0]]));
    }
    if n == 2 {
        return solve_2x2_tridiag(diag[0], diag[1], off_diag[0]);
    }
    if n <= 8 {
        return qr_tridiag_eigen(diag, off_diag, config);
    }

    let mid = n / 2;
    let beta = off_diag[mid - 1];
    let abs_beta = beta.abs();

    // Tear at the cut: remove |beta| from the two adjacent diagonal entries and
    // solve each half recursively.
    let mut diag1 = diag[..mid].to_vec();
    let mut diag2 = diag[mid..].to_vec();
    diag1[mid - 1] -= abs_beta;
    diag2[0] -= abs_beta;
    let (evals1, evecs1) = dc_tridiag_impl(&diag1, &off_diag[..mid - 1], config)?;
    let (evals2, evecs2) = dc_tridiag_impl(&diag2, &off_diag[mid..], config)?;

    let d: Vec<f64> = evals1.iter().chain(&evals2).copied().collect();
    let sign_beta = if beta >= 0.0 { 1.0_f64 } else { -1.0_f64 };
    // z = diag(Q1, Q2)^T w: the last component of each T1 eigenvector (times
    // sign(beta)), followed by the first component of each T2 eigenvector.
    let z: Vec<f64> = evecs1
        .iter()
        .map(|ev| sign_beta * ev[mid - 1])
        .chain(evecs2.iter().map(|ev| ev[0]))
        .collect();

    let pairs = match merge_rank_one(&d, &z, abs_beta, config)? {
        Some(pairs) => pairs,
        // Defensive fallback: the secular solver produced an unexpected root count.
        None => return qr_tridiag_eigen(diag, off_diag, config),
    };

    let mut evals = Vec::with_capacity(n);
    let mut evecs = Vec::with_capacity(n);
    for (lam, coords) in pairs {
        let mut out = vec![0.0f64; n];
        let (top, bottom) = out.split_at_mut(mid);
        add_combination(top, &coords[..mid], &evecs1);
        add_combination(bottom, &coords[mid..], &evecs2);
        evals.push(lam);
        evecs.push(out);
    }
    Ok((evals, evecs))
}

/// Adds `Σ_k coeffs[k] * vectors[k]` into `out`, skipping zero coefficients.
fn add_combination(out: &mut [f64], coeffs: &[f64], vectors: &[Vec<f64>]) {
    for (&ck, vk) in coeffs.iter().zip(vectors) {
        if ck != 0.0 {
            for (o, &x) in out.iter_mut().zip(vk) {
                *o += ck * x;
            }
        }
    }
}

/// Eigen-decomposition of `diag(d) + rho * z * z^T` with `rho >= 0`.
///
/// Eigenvectors are expressed in the coordinates of `diag(d)`. Returns the
/// eigenpairs sorted by ascending eigenvalue, or `Ok(None)` if the secular
/// solver returns a different number of roots than there are active poles.
///
/// Deflation, with `tol = config.deflation_tol`:
/// * if `rho <= eps * max|d_i|` or `z == 0`, the halves are decoupled and every
///   `(d_i, e_i)` is an eigenpair;
/// * a slot with `z_i == 0` or `|z_i| < tol * ||z||` keeps `(d_i, e_i)`;
/// * two remaining poles closer than `tol * (|d_a| + 1)` are combined by a
///   Givens rotation that moves all of their `z` weight onto the first one. The
///   rotated partner keeps eigenvalue `d_k` with a basis vector orthogonal to
///   `z`, so it is an eigenvector up to the (tiny) pole gap.
fn merge_rank_one(
    d: &[f64],
    z: &[f64],
    rho: f64,
    config: &DcConfig,
) -> LinalgResult<Option<Vec<(f64, Vec<f64>)>>> {
    let n = d.len();
    let tol = config.deflation_tol;
    let z_norm = vec_norm(z);
    let d_scale = d.iter().fold(0.0_f64, |m, &x| m.max(x.abs()));
    let decoupled = rho <= f64::EPSILON * d_scale || z_norm == 0.0;

    // basis[k]: current orthonormal basis vector of slot k (initially e_k).
    let mut basis: Vec<Vec<f64>> = (0..n)
        .map(|k| {
            let mut e = vec![0.0f64; n];
            e[k] = 1.0;
            e
        })
        .collect();
    let mut zr = z.to_vec();
    let mut active: Vec<bool> = zr
        .iter()
        .map(|&zi| !decoupled && zi != 0.0 && zi.abs() >= tol * z_norm)
        .collect();

    // Walk the active slots in ascending pole order and rotate each slot whose
    // pole coincides with the previous surviving one into it.
    let mut order: Vec<usize> = (0..n).filter(|&k| active[k]).collect();
    order.sort_by(|&a, &b| d[a].partial_cmp(&d[b]).unwrap_or(std::cmp::Ordering::Equal));
    let mut anchor: Option<usize> = None;
    for k in order {
        match anchor {
            Some(a) if (d[k] - d[a]).abs() < tol * (d[a].abs() + 1.0) => {
                // Both weights are non-zero (active), so r > 0.
                let r = zr[a].hypot(zr[k]);
                let (c, s) = (zr[a] / r, zr[k] / r);
                let ba = std::mem::take(&mut basis[a]);
                let bk = std::mem::take(&mut basis[k]);
                basis[a] = ba.iter().zip(&bk).map(|(&x, &y)| c * x + s * y).collect();
                basis[k] = ba.iter().zip(&bk).map(|(&x, &y)| c * y - s * x).collect();
                zr[a] = r;
                zr[k] = 0.0;
                active[k] = false;
            }
            _ => anchor = Some(k),
        }
    }

    let mut pairs: Vec<(f64, Vec<f64>)> = Vec::with_capacity(n);
    let mut act_idx: Vec<usize> = Vec::new();
    for (k, &is_active) in active.iter().enumerate() {
        if is_active {
            act_idx.push(k);
        } else {
            pairs.push((d[k], std::mem::take(&mut basis[k])));
        }
    }

    if !act_idx.is_empty() {
        let d_act: Vec<f64> = act_idx.iter().map(|&k| d[k]).collect();
        let z_act: Vec<f64> = act_idx.iter().map(|&k| zr[k]).collect();
        let roots = solve_secular(&d_act, &z_act, rho, config)?;
        if roots.len() != act_idx.len() {
            return Ok(None);
        }
        for lam in roots {
            // Eigenvector of the reduced problem: components z_k / (d_k - lam).
            let coeffs: Vec<f64> = d_act
                .iter()
                .zip(&z_act)
                .map(|(&dk, &zk)| {
                    let denom = dk - lam;
                    if denom.abs() < 1e-300 {
                        // lam sits on a pole: let that component dominate.
                        if zk >= 0.0 {
                            1e150_f64
                        } else {
                            -1e150_f64
                        }
                    } else {
                        zk / denom
                    }
                })
                .collect();
            let norm = vec_norm(&coeffs);
            let weights: Vec<f64> = if norm > 1e-300 {
                coeffs.iter().map(|ck| ck / norm).collect()
            } else {
                vec![1.0 / (coeffs.len() as f64).sqrt(); coeffs.len()]
            };
            let mut v = vec![0.0f64; n];
            for (&k, &w) in act_idx.iter().zip(&weights) {
                for (vi, &bi) in v.iter_mut().zip(&basis[k]) {
                    *vi += w * bi;
                }
            }
            pairs.push((lam, v));
        }
    }

    pairs.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap_or(std::cmp::Ordering::Equal));
    Ok(Some(pairs))
}
/// Eigen-decomposition of a small symmetric tridiagonal matrix by the classical
/// Jacobi method. Despite the name, no QR iteration is performed.
///
/// The matrix is expanded to dense form. Each step zeroes the largest
/// off-diagonal entry with a plane rotation and accumulates the rotation into
/// the eigenvector estimates. Iteration stops when every off-diagonal entry is
/// at most `1e-13` times the largest input magnitude (a scale-invariant test),
/// or after `max(32 n^2, 200)` rotations. The rotation budget is measured in
/// rotations, not sweeps; a classical Jacobi sweep is about `n^2 / 2` rotations.
/// Eigenpairs are returned sorted by ascending eigenvalue, one eigenvector per
/// inner `Vec`.
fn qr_tridiag_eigen(
    diag: &[f64],
    off_diag: &[f64],
    _config: &DcConfig,
) -> LinalgResult<(Vec<f64>, Vec<Vec<f64>>)> {
    let n = diag.len();
    if n == 1 {
        return Ok((vec![diag[0]], vec![vec![1.0]]));
    }
    if n == 2 {
        return solve_2x2_tridiag(diag[0], diag[1], off_diag[0]);
    }
    let off_diag = &off_diag[..n - 1];
    let mut mat: Vec<Vec<f64>> = vec![vec![0.0f64; n]; n];
    for (i, &di) in diag.iter().enumerate() {
        mat[i][i] = di;
    }
    for (i, &ei) in off_diag.iter().enumerate() {
        mat[i][i + 1] = ei;
        mat[i + 1][i] = ei;
    }
    // Relative stopping threshold. An absolute one can be unreachable for
    // large-scale matrices (rounding noise ~ eps * ||A||) and too loose for
    // tiny ones. A zero matrix gives tol == 0 and stops immediately.
    let scale = diag
        .iter()
        .chain(off_diag)
        .fold(0.0_f64, |m, &x| m.max(x.abs()));
    let tol = 1e-13 * scale;
    let max_rotations = (32 * n * n).max(200);
    // Row k of `vecs` converges to the k-th eigenvector.
    let mut vecs = identity_matrix(n);
    for _ in 0..max_rotations {
        let mut max_off = 0.0f64;
        let mut p = 0usize;
        let mut q = 1usize;
        for (i, row) in mat.iter().enumerate() {
            for (j, &aij) in row.iter().enumerate().skip(i + 1) {
                if aij.abs() > max_off {
                    max_off = aij.abs();
                    p = i;
                    q = j;
                }
            }
        }
        if max_off <= tol {
            break;
        }
        let app = mat[p][p];
        let aqq = mat[q][q];
        let apq = mat[p][q];
        let theta = (aqq - app) / (2.0 * apq);
        let sign_t = if theta >= 0.0 { 1.0_f64 } else { -1.0_f64 };
        // hypot avoids overflowing theta^2 when the diagonal gap dwarfs apq.
        let t = sign_t / (theta.abs() + theta.hypot(1.0));
        let c = 1.0 / (1.0 + t * t).sqrt();
        let s = t * c;
        mat[p][p] = app - t * apq;
        mat[q][q] = aqq + t * apq;
        mat[p][q] = 0.0;
        mat[q][p] = 0.0;
        // Rotate columns p and q, mirroring into rows p and q to keep symmetry.
        for r in 0..n {
            if r == p || r == q {
                continue;
            }
            let arp = mat[r][p];
            let arq = mat[r][q];
            let new_rp = c * arp - s * arq;
            let new_rq = s * arp + c * arq;
            mat[r][p] = new_rp;
            mat[p][r] = new_rp;
            mat[r][q] = new_rq;
            mat[q][r] = new_rq;
        }
        // Apply the same rotation to eigenvector rows p and q (p < q).
        let (head, tail) = vecs.split_at_mut(q);
        for (vp, vq) in head[p].iter_mut().zip(tail[0].iter_mut()) {
            let old_p = *vp;
            let old_q = *vq;
            *vp = c * old_p - s * old_q;
            *vq = s * old_p + c * old_q;
        }
    }
    let mut pairs: Vec<(f64, Vec<f64>)> = vecs
        .into_iter()
        .enumerate()
        .map(|(i, v)| (mat[i][i], v))
        .collect();
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    let evals: Vec<f64> = pairs.iter().map(|(e, _)| *e).collect();
    let evecs: Vec<Vec<f64>> = pairs.into_iter().map(|(_, v)| v).collect();
    Ok((evals, evecs))
}
/// Closed-form eigen-decomposition of the symmetric 2x2 matrix `[[d0, e], [e, d1]]`.
///
/// Returns `([lam1, lam2], [v1, v2])` with `lam1 <= lam2`. The discriminant is
/// evaluated as `hypot(d0 - d1, 2e)`. This equals `sqrt(tr^2 - 4 det)` but
/// avoids the catastrophic cancellation of that formula when `|e|` is small
/// relative to the diagonal. For example, `(1e8, 1e8 + 1, 0)` would otherwise
/// yield `1e8 + 0.5` twice. `v2` is the vector perpendicular to `v1`, so the
/// pair is orthonormal even for a double eigenvalue (`d0 == d1`, `e == 0`). In
/// that case `eigvec_2x2` would return the same fallback vector for both
/// eigenvalues.
fn solve_2x2_tridiag(d0: f64, d1: f64, e: f64) -> LinalgResult<(Vec<f64>, Vec<Vec<f64>>)> {
    let tr = d0 + d1;
    let disc = (d0 - d1).hypot(2.0 * e);
    let lam1 = (tr - disc) / 2.0;
    let lam2 = (tr + disc) / 2.0;
    let ev1 = eigvec_2x2(d0, d1, e, lam1);
    // Symmetric matrices have orthogonal eigenvectors: rotate v1 by 90 degrees.
    let ev2 = vec![-ev1[1], ev1[0]];
    Ok((vec![lam1, lam2], vec![ev1, ev2]))
}
/// Unit null vector of `[[d0 - lam, e], [e, d1 - lam]]`, i.e. an eigenvector
/// of the symmetric matrix `[[d0, e], [e, d1]]` for eigenvalue `lam`.
///
/// Picks the row of the shifted matrix with the larger norm (ties favour the
/// first row) and returns the unit vector perpendicular to it. If that row's
/// norm is below `1e-300`, returns `[1/√2, 1/√2]`.
fn eigvec_2x2(d0: f64, d1: f64, e: f64, lam: f64) -> Vec<f64> {
    let top_left = d0 - lam;
    let bottom_right = d1 - lam;
    let first_len = (top_left * top_left + e * e).sqrt();
    let second_len = (e * e + bottom_right * bottom_right).sqrt();
    // (rx, ry) is the chosen row; (-ry, rx) is perpendicular to it.
    let (rx, ry, len) = if first_len >= second_len {
        (top_left, e, first_len)
    } else {
        (e, bottom_right, second_len)
    };
    if len < 1e-300 {
        let h = 1.0 / 2.0_f64.sqrt();
        return vec![h, h];
    }
    vec![-ry / len, rx / len]
}
/// Computes every root of the secular equation
/// `1 + beta * Σ_i u_i² / (d_i - λ) = 0`, one root per pole in `d`.
///
/// Poles and weights are first sorted together by pole.
///
/// * `beta > 0`: root `i` is searched in `(d_i, d_{i+1})`, and the last root in
///   `(d_max, d_max + beta * Σ u_i² + 1)`.
/// * `beta < 0`: the first root is searched in `(d_min + beta * Σ u_i² - 1, d_min)`,
///   and root `i` in `(d_{i-1}, d_i)`.
///
/// When `|beta| < 1e-300` the sorted poles are returned as the roots. The
/// result is sorted in ascending order.
///
/// # Errors
/// Propagates any error from the per-interval root finder.
pub fn solve_secular(d: &[f64], u: &[f64], beta: f64, config: &DcConfig) -> LinalgResult<Vec<f64>> {
    let n = d.len();
    if n == 0 {
        return Ok(vec![]);
    }
    let mut by_pole: Vec<(f64, f64)> = d.iter().copied().zip(u.iter().copied()).collect();
    by_pole.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap_or(std::cmp::Ordering::Equal));
    let (d_sorted, u_sorted): (Vec<f64>, Vec<f64>) = by_pole.into_iter().unzip();
    if beta.abs() < 1e-300 {
        return Ok(d_sorted);
    }
    let weight_sum: f64 = u_sorted.iter().map(|ui| ui * ui).sum();
    let mut roots = (0..n)
        .map(|i| {
            let (lo, hi) = if beta > 0.0 && i + 1 < n {
                (d_sorted[i], d_sorted[i + 1])
            } else if beta > 0.0 {
                (d_sorted[n - 1], d_sorted[n - 1] + beta * weight_sum + 1.0)
            } else if i == 0 {
                (d_sorted[0] + beta * weight_sum - 1.0, d_sorted[0])
            } else {
                (d_sorted[i - 1], d_sorted[i])
            };
            find_secular_root(&d_sorted, &u_sorted, beta, lo, hi, config)
        })
        .collect::<LinalgResult<Vec<f64>>>()?;
    roots.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
    Ok(roots)
}
/// Evaluates the secular function `f(λ) = 1 + beta * Σ_i u_i² / (d_i - λ)`.
#[inline]
fn secular_f(d: &[f64], u: &[f64], beta: f64, lam: f64) -> f64 {
    let weighted_sum: f64 = u
        .iter()
        .zip(d)
        .map(|(&weight, &pole)| weight * weight / (pole - lam))
        .sum();
    1.0 + beta * weighted_sum
}
/// Derivative of the secular function: `f'(λ) = beta * Σ_i u_i² / (d_i - λ)²`.
#[inline]
fn secular_df(d: &[f64], u: &[f64], beta: f64, lam: f64) -> f64 {
    let weighted_sum: f64 = u
        .iter()
        .zip(d)
        .map(|(&weight, &pole)| {
            let gap = pole - lam;
            weight * weight / (gap * gap)
        })
        .sum();
    beta * weighted_sum
}
/// Finds the root of `f(λ) = 1 + beta * Σ u_i² / (d_i - λ)` inside `(lo, hi)`.
///
/// `solve_secular` calls this with intervals bounded by consecutive poles (or a
/// pole and an outer bound). On such an interval `f` is strictly monotone
/// (increasing for `beta > 0`, decreasing for `beta < 0`) and changes sign once.
///
/// The search is a safeguarded Newton iteration on `[lo + eps, hi - eps]`;
/// `eps` keeps the evaluations off the poles.
/// * Every evaluation of `f` shrinks the bracket according to its sign.
/// * A Newton step is taken only if it lands strictly inside the bracket;
///   otherwise the bracket is bisected. The iteration therefore always makes
///   progress.
///
/// It stops when `|f| < config.tol`, when the step or the bracket width drops
/// below `config.tol * (|λ| + 1)`, when `f` evaluates to NaN, or after
/// `max(config.max_iter, 100)` iterations. The current estimate is returned in
/// every case.
///
/// # Errors
/// Never fails at present; the `Result` is kept for interface stability.
fn find_secular_root(
    d: &[f64],
    u: &[f64],
    beta: f64,
    lo: f64,
    hi: f64,
    config: &DcConfig,
) -> LinalgResult<f64> {
    // Pure bisection can need several dozen halvings to reach the tolerance, so
    // the iteration budget never drops below this floor.
    const MIN_ITERATIONS: usize = 100;
    let tol = config.tol;
    let max_iter = config.max_iter.max(MIN_ITERATIONS);
    let eps = 1e-14 * (lo.abs() + hi.abs() + 1.0);
    let mut a = lo + eps;
    let mut b = hi - eps;
    if !(a < b) {
        // Degenerate (or NaN) interval: nothing to search.
        return Ok((lo + hi) / 2.0);
    }
    let increasing = beta > 0.0;
    let mut x = 0.5 * (a + b);
    for _ in 0..max_iter {
        let fx = secular_f(d, u, beta, x);
        if fx.is_nan() {
            break;
        }
        if fx.abs() < tol {
            return Ok(x);
        }
        // The root lies on the side where f has the opposite sign of f(x).
        if (fx < 0.0) == increasing {
            a = x;
        } else {
            b = x;
        }
        let dfx = secular_df(d, u, beta, x);
        let newton = x - fx / dfx;
        let x_next = if dfx.abs() > 1e-300 && newton > a && newton < b {
            newton
        } else {
            0.5 * (a + b)
        };
        let step = (x_next - x).abs();
        x = x_next;
        let threshold = tol * (x.abs() + 1.0);
        if step < threshold || b - a < threshold {
            return Ok(x);
        }
    }
    Ok(x)
}
/// Deflates the rank-one secular problem `diag(d) + rho * u * u^T`.
///
/// Two reductions are applied, using the relative tolerance `tol`:
/// 1. index `i` with `|u[i]| < tol * ||u||` is dropped, and `d[i]` becomes a
///    trivial eigenvalue;
/// 2. among the remaining indices, `j` whose pole lies within
///    `tol * (|d[i]| + 1)` of an earlier surviving index `i` is merged into `i`:
///    `u[i]` becomes `hypot(u[i], u[j])` and `d[j]` becomes a trivial eigenvalue.
///
/// Returns `(d_active, u_active, active_idx, trivial_evals)`. `active_idx[k]` is
/// the original index of `d_active[k]`. `trivial_evals` is ordered by original
/// index: `trivial_evals[k] == d[m_k]`, where `m_k` is the k-th (ascending)
/// index not in `active_idx`. Merged `u` values are non-negative. The Givens
/// rotations behind a merge are not returned, so the output determines
/// eigenvalues but not the eigenvectors of merged entries.
///
/// Panics if `u.len() < d.len()`.
pub fn deflate_secular(
    d: &[f64],
    u: &[f64],
    tol: f64,
) -> (Vec<f64>, Vec<f64>, Vec<usize>, Vec<f64>) {
    let n = d.len();
    let u_norm = u.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-300);
    let mut merged_u: Vec<f64> = u[..n].to_vec();
    // Stage 1: negligible coupling weights.
    let mut active: Vec<bool> = merged_u
        .iter()
        .map(|ui| !(ui.abs() < tol * u_norm))
        .collect();
    // Stage 2: merge (nearly) coincident poles into the earliest surviving one.
    for i in 0..n {
        if !active[i] {
            continue;
        }
        for j in i + 1..n {
            if active[j] && (d[j] - d[i]).abs() < tol * (d[i].abs() + 1.0) {
                merged_u[i] = merged_u[i].hypot(merged_u[j]);
                active[j] = false;
            }
        }
    }
    let mut defl_d = Vec::new();
    let mut defl_u = Vec::new();
    let mut active_idx = Vec::new();
    let mut trivial_evals = Vec::new();
    for (i, &is_active) in active.iter().enumerate() {
        if is_active {
            defl_d.push(d[i]);
            defl_u.push(merged_u[i]);
            active_idx.push(i);
        } else {
            trivial_evals.push(d[i]);
        }
    }
    (defl_d, defl_u, active_idx, trivial_evals)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison; NaN never compares equal.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Dot product of two equally long vectors.
    fn dot(x: &[f64], y: &[f64]) -> f64 {
        x.iter().zip(y).map(|(a, b)| a * b).sum()
    }

    /// `true` when every vector has unit norm and every pair is orthogonal, both
    /// within `1e-7`. The tests accept with `<=`, so NaN components make the
    /// check fail instead of slipping past a `> tol` rejection.
    fn check_orthonormal(evecs: &[Vec<f64>]) -> bool {
        evecs.iter().enumerate().all(|(i, vi)| {
            (dot(vi, vi) - 1.0).abs() <= 1e-7
                && evecs[i + 1..].iter().all(|vj| dot(vi, vj).abs() <= 1e-7)
        })
    }

    #[test]
    fn test_3x3_tridiag_known_eigenvalues() {
        let diag = vec![2.0, 2.0, 2.0];
        let off = vec![-1.0, -1.0];
        let (evals, evecs) = dc_eig_tridiag(&diag, &off).expect("DC tridiag failed");
        assert_eq!(evals.len(), 3);
        assert_eq!(evecs.len(), 3);
        let expected = [2.0 - 2.0_f64.sqrt(), 2.0, 2.0 + 2.0_f64.sqrt()];
        for (ev, ex) in evals.iter().zip(expected.iter()) {
            assert!(
                (ev - ex).abs() < 1e-8,
                "Eigenvalue mismatch: got {ev}, expected {ex}"
            );
        }
        assert!(
            check_orthonormal(&evecs),
            "Eigenvectors are not orthonormal"
        );
    }

    #[test]
    fn test_2x2_base_case() {
        let (evals, evecs) = solve_2x2_tridiag(3.0, 3.0, 1.0).expect("2x2 solve failed");
        assert_eq!(evals.len(), 2);
        assert!(
            approx_eq(evals[0], 2.0, 1e-10),
            "Expected 2.0, got {}",
            evals[0]
        );
        assert!(
            approx_eq(evals[1], 4.0, 1e-10),
            "Expected 4.0, got {}",
            evals[1]
        );
        assert!(check_orthonormal(&evecs));
    }

    #[test]
    fn test_dc_eig_symmetric_4x4() {
        let a = vec![
            vec![5.0, 1.0, 0.0, 0.0],
            vec![1.0, 5.0, 1.0, 0.0],
            vec![0.0, 1.0, 5.0, 1.0],
            vec![0.0, 0.0, 1.0, 5.0],
        ];
        let (evals, evecs) = dc_eig_symmetric(&a).expect("DC symmetric failed");
        assert_eq!(evals.len(), 4);
        for (k, &lam) in evals.iter().enumerate() {
            let v = &evecs[k];
            let mut av = [0.0f64; 4];
            for i in 0..4 {
                for j in 0..4 {
                    av[i] += a[i][j] * v[j];
                }
            }
            for i in 0..4 {
                assert!(
                    (av[i] - lam * v[i]).abs() < 1e-6,
                    "Eigenpair {k}: residual at component {i} too large: {} vs {}",
                    av[i],
                    lam * v[i]
                );
            }
        }
        assert!(check_orthonormal(&evecs), "Eigenvectors not orthonormal");
    }

    #[test]
    fn test_deflation() {
        let d = vec![1.0, 2.0, 3.0];
        let u = vec![0.0, 1.0, 0.5];
        let (defl_d, _defl_u, idx, trivials) = deflate_secular(&d, &u, 1e-10);
        assert!(trivials.contains(&1.0), "d[0]=1 should be trivial");
        assert_eq!(defl_d.len(), 2);
        assert!(idx.contains(&1));
        assert!(idx.contains(&2));
    }

    #[test]
    fn test_larger_tridiag() {
        let n = 10;
        let diag: Vec<f64> = vec![2.0; n];
        let off: Vec<f64> = vec![-1.0; n - 1];
        let (evals, evecs) = dc_eig_tridiag(&diag, &off).expect("DC tridiag failed for n=10");
        assert_eq!(evals.len(), n);
        for &ev in &evals {
            assert!(ev > 0.0, "Expected positive eigenvalue, got {ev}");
        }
        for i in 1..n {
            assert!(evals[i] >= evals[i - 1], "Eigenvalues not sorted at {i}");
        }
        assert!(
            check_orthonormal(&evecs),
            "Eigenvectors not orthonormal for n=10"
        );
    }
}