use crate::error::SparseError;
use crate::krylov::gmres_dr::{dot, gram_schmidt_mgs, norm2};
/// Deflation space built from (approximate) harmonic Ritz vectors of an
/// Arnoldi cycle, for use in deflated-restart GMRES.
///
/// `vectors[i]` and `values[i]` are paired by index. Both are replaced
/// together by [`HarmonicRitzDeflation::extract_from_krylov`] and emptied by
/// [`HarmonicRitzDeflation::clear`].
#[derive(Debug, Clone)]
pub struct HarmonicRitzDeflation {
    /// Upper bound on the number of vectors kept per extraction.
    pub n_deflate: usize,
    /// Deflation vectors in the full problem space (normalised during extraction).
    pub vectors: Vec<Vec<f64>>,
    /// Harmonic Ritz value associated with each entry of `vectors`.
    pub values: Vec<f64>,
}

impl HarmonicRitzDeflation {
    /// Creates an empty deflation space that will keep at most `n_deflate` vectors.
    pub fn new(n_deflate: usize) -> Self {
        Self {
            n_deflate,
            vectors: Vec::new(),
            values: Vec::new(),
        }
    }

    /// Rebuilds the deflation space from one Arnoldi cycle.
    ///
    /// * `h` – Hessenberg coefficients indexed `h[row][col]`. Rows `0..m`
    ///   (truncated/zero-padded to `m` columns) form the square block `H_m`.
    ///   `h[m][m - 1]`, when present, is the trailing subdiagonal entry
    ///   `h_{m+1,m}`.
    /// * `v` – basis vectors; `v[j]` is combined with the j-th coefficient of
    ///   each small-space vector. The full dimension is taken from `v[0]`.
    /// * `m` – number of Arnoldi steps in the cycle.
    ///
    /// At most `min(n_deflate, m)` small-space vectors are lifted as
    /// `y = Σ_j s_j · v[j]`. Lifted vectors with norm ≤ 1e-15 are dropped; the
    /// rest are normalised and then passed through `gram_schmidt_mgs`.
    ///
    /// If `m == 0` or `v` is empty, the method returns early and the previous
    /// deflation space is left untouched.
    pub fn extract_from_krylov(&mut self, h: &[Vec<f64>], v: &[Vec<f64>], m: usize) {
        if m == 0 {
            return;
        }
        let Some(first) = v.first() else { return };
        let n = first.len();
        let n_take = self.n_deflate.min(m);

        // |h_{m+1,m}|; zero when `h` has no (m+1)-th row or it is too short.
        let h_extra = h
            .get(m)
            .and_then(|row| row.get(m - 1))
            .map_or(0.0, |x| x.abs());

        // Square m×m leading block of H; missing entries are zero.
        let hm: Vec<Vec<f64>> = (0..m)
            .map(|i| {
                let mut row = vec![0.0; m];
                if let Some(src) = h.get(i) {
                    let len = m.min(src.len());
                    row[..len].copy_from_slice(&src[..len]);
                }
                row
            })
            .collect();

        let schur_vecs = harmonic_ritz_schur_vecs(&hm, h_extra, n_take);
        let mut new_vecs: Vec<Vec<f64>> = Vec::with_capacity(n_take);
        let mut new_vals: Vec<f64> = Vec::with_capacity(n_take);
        for (eig_val, s) in &schur_vecs {
            if s.len() != m {
                continue;
            }
            // Lift into the full space. Zipping avoids out-of-bounds panics
            // when some basis vector is shorter than `v[0]`.
            let mut y = vec![0.0f64; n];
            for (&sj, vj) in s.iter().zip(v) {
                for (yl, &vjl) in y.iter_mut().zip(vj) {
                    *yl += sj * vjl;
                }
            }
            let nrm = norm2(&y);
            if nrm > 1e-15 {
                y.iter_mut().for_each(|yi| *yi /= nrm);
                new_vecs.push(y);
                new_vals.push(*eig_val);
            }
        }
        gram_schmidt_mgs(&mut new_vecs);
        // NOTE(review): guards the `vectors`/`values` length invariant in case
        // the orthogonaliser discards dependent vectors; confirm its contract
        // in `gmres_dr`.
        new_vals.truncate(new_vecs.len());
        self.vectors = new_vecs;
        self.values = new_vals;
    }

    /// Returns `r` with its component along each deflation vector removed.
    ///
    /// Vectors are processed in turn against the already-updated residual
    /// (modified Gram–Schmidt style). Vectors with squared norm ≤ 1e-300 are
    /// skipped.
    pub fn deflate(&self, r: &[f64]) -> Vec<f64> {
        let mut result = r.to_vec();
        for v in &self.vectors {
            let coeff = dot(&result, v);
            let norm2_v = dot(v, v);
            if norm2_v > 1e-300 {
                let scale = coeff / norm2_v;
                for (ri, vi) in result.iter_mut().zip(v) {
                    *ri -= scale * vi;
                }
            }
        }
        result
    }

    /// Builds a correction `δ = Σ_j c_j · vectors[j]` with
    /// `c_j = ⟨av[j], r⟩ / ‖av[j]‖²`.
    ///
    /// `av[j]` is expected to be the operator applied to `vectors[j]`; this is
    /// the caller's responsibility. Each coefficient is computed independently,
    /// which is the least-squares minimiser of `‖r − Σ c_j av[j]‖` only when the
    /// `av[j]` are mutually orthogonal. Pairs whose `av[j]` has squared norm
    /// below 1e-300 are skipped. The result has length `r.len()`.
    pub fn correction_from_residual(&self, r: &[f64], av: &[Vec<f64>]) -> Vec<f64> {
        let mut delta = vec![0.0f64; r.len()];
        for (u, av_j) in self.vectors.iter().zip(av) {
            let av_norm2 = dot(av_j, av_j);
            if av_norm2 < 1e-300 {
                continue;
            }
            let coeff = dot(av_j, r) / av_norm2;
            // Zip instead of indexing so a short deflation vector cannot panic.
            for (d, &uj) in delta.iter_mut().zip(u) {
                *d += coeff * uj;
            }
        }
        delta
    }

    /// Number of stored deflation vectors.
    pub fn dim(&self) -> usize {
        self.vectors.len()
    }

    /// Removes all deflation vectors and values; `n_deflate` is kept.
    pub fn clear(&mut self) {
        self.vectors.clear();
        self.values.clear();
    }

    /// Returns `sqrt(Σ ⟨r, v⟩² / ‖v‖²) / ‖r‖` over the deflation vectors.
    ///
    /// For mutually orthogonal vectors this is the fraction of `‖r‖` lying in
    /// the deflation space. Returns `0.0` when `‖r‖ < 1e-300`.
    pub fn projection_quality(&self, r: &[f64]) -> f64 {
        let r_norm = norm2(r);
        if r_norm < 1e-300 {
            return 0.0;
        }
        let mut proj = 0.0f64;
        for v in &self.vectors {
            let c = dot(r, v);
            let vn2 = dot(v, v);
            if vn2 > 1e-300 {
                proj += c * c / vn2;
            }
        }
        proj.sqrt() / r_norm
    }
}
/// Approximate harmonic Ritz pairs of an Arnoldi Hessenberg block.
///
/// Builds `G = H_m + h_extra² · f · e_mᵀ` with `f = H_m^{-T} e_m` (Morgan's
/// GMRES-DR formulation). The rank-one term only touches the last column, so
/// `G` stays upper Hessenberg. If `H_m` is numerically singular, `f` falls back
/// to `e_m`. The term is added only when `h_extra > 1e-15`.
///
/// Then runs `30·m` shifted QR sweeps (shift = current `G[m-1][m-1]` for
/// `m ≥ 2`). Returns up to `n_take` pairs `(G_kk, Q[:, k])`, choosing the
/// diagonal entries of smallest magnitude; `Q` is the accumulated orthogonal
/// transform.
///
/// NOTE(review): the returned vectors are Schur-basis columns, not
/// eigenvectors, and the sweep has no deflation or complex-pair handling. For
/// non-normal `G` or complex harmonic Ritz values the result is only approximate.
fn harmonic_ritz_schur_vecs(hm: &[Vec<f64>], h_extra: f64, n_take: usize) -> Vec<(f64, Vec<f64>)> {
    let m = hm.len();
    if m == 0 || n_take == 0 {
        return Vec::new();
    }
    let mut a = hm.to_vec();
    if h_extra > 1e-15 {
        let h2 = h_extra * h_extra;
        let f = solve_transposed_last_unit(&a).unwrap_or_else(|| {
            let mut e = vec![0.0f64; m];
            e[m - 1] = 1.0;
            e
        });
        for (row, fi) in a.iter_mut().zip(&f) {
            if let Some(last) = row.get_mut(m - 1) {
                *last += h2 * fi;
            }
        }
    }
    let n_qr = 30 * m;
    let mut q_total: Vec<Vec<f64>> = (0..m)
        .map(|i| {
            let mut row = vec![0.0f64; m];
            row[i] = 1.0;
            row
        })
        .collect();
    for _ in 0..n_qr {
        let shift = if m >= 2 { a[m - 1][m - 1] } else { 0.0 };
        let (q, r) = hessenberg_qr(&a, shift, m);
        a = dense_mat_mul(&r, &q, m);
        for i in 0..m {
            a[i][i] += shift;
        }
        q_total = dense_mat_mul(&q_total, &q, m);
    }
    let mut eig_pairs: Vec<(f64, usize)> = (0..m).map(|i| (a[i][i].abs(), i)).collect();
    eig_pairs.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap_or(std::cmp::Ordering::Equal));
    let take = n_take.min(m);
    eig_pairs[..take]
        .iter()
        .map(|(_, col)| {
            let eig = a[*col][*col];
            let s: Vec<f64> = (0..m).map(|row| q_total[row][*col]).collect();
            (eig, s)
        })
        .collect()
}

/// Solves `Hᵀ f = e_m` for the square matrix `h` by Gaussian elimination with
/// partial pivoting. Missing entries of ragged rows are read as zero.
///
/// Returns `None` if `h` is empty, a pivot has magnitude below 1e-300, or the
/// solution is not finite.
fn solve_transposed_last_unit(h: &[Vec<f64>]) -> Option<Vec<f64>> {
    let m = h.len();
    if m == 0 {
        return None;
    }
    let entry = |i: usize, j: usize| h.get(i).and_then(|row| row.get(j)).copied().unwrap_or(0.0);
    // Dense copy of Hᵀ.
    let mut mat: Vec<Vec<f64>> = (0..m).map(|i| (0..m).map(|j| entry(j, i)).collect()).collect();
    let mut rhs = vec![0.0f64; m];
    rhs[m - 1] = 1.0;

    for col in 0..m {
        let pivot = (col..m).max_by(|&x, &y| {
            mat[x][col]
                .abs()
                .partial_cmp(&mat[y][col].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        })?;
        if !(mat[pivot][col].abs() >= 1e-300) {
            return None;
        }
        mat.swap(col, pivot);
        rhs.swap(col, pivot);
        let (upper, lower) = mat.split_at_mut(col + 1);
        let prow = &upper[col];
        let prhs = rhs[col];
        for (target, rv) in lower.iter_mut().zip(rhs[col + 1..].iter_mut()) {
            let factor = target[col] / prow[col];
            if factor != 0.0 {
                for (t, &p) in target[col..].iter_mut().zip(&prow[col..]) {
                    *t -= factor * p;
                }
                *rv -= factor * prhs;
            }
        }
    }

    let mut f = vec![0.0f64; m];
    for i in (0..m).rev() {
        let mut acc = rhs[i];
        for (&mik, &fk) in mat[i][i + 1..].iter().zip(&f[i + 1..]) {
            acc -= mik * fk;
        }
        f[i] = acc / mat[i][i];
    }
    if f.iter().all(|x| x.is_finite()) {
        Some(f)
    } else {
        None
    }
}
/// Factorises `A - shift·I = Q·R` with one Givens rotation per subdiagonal entry.
///
/// Reads the first `m` rows of `a` (panicking if there are fewer) and truncates
/// or zero-pads each to `m` columns. Returns `(Q, R)` as `m × m` row-major
/// matrices. `Q` accumulates the transposed rotations.
///
/// Only entries `(k+1, k)` are eliminated. `R` is therefore upper triangular
/// when the shifted input is upper Hessenberg.
fn hessenberg_qr(a: &[Vec<f64>], shift: f64, m: usize) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    // Shifted working copy, reduced to R in place.
    let mut r: Vec<Vec<f64>> = a[..m]
        .iter()
        .enumerate()
        .map(|(i, src)| {
            let mut row = src.clone();
            row.resize(m, 0.0);
            row[i] -= shift;
            row
        })
        .collect();

    let mut q = vec![vec![0.0f64; m]; m];
    for (i, q_row) in q.iter_mut().enumerate() {
        q_row[i] = 1.0;
    }

    for k in 0..m.saturating_sub(1) {
        // Borrow rows k and k+1 of R simultaneously.
        let (upper, lower) = r.split_at_mut(k + 1);
        let (row_k, row_next) = (&mut upper[k], &mut lower[0]);

        let (x, y) = (row_k[k], row_next[k]);
        let denom = (x * x + y * y).sqrt();
        let (c, s) = if denom < 1e-300 {
            (1.0f64, 0.0f64)
        } else {
            (x / denom, y / denom)
        };

        // Apply the rotation to rows k, k+1 of R.
        for (top, bottom) in row_k.iter_mut().zip(row_next.iter_mut()) {
            let (t, b) = (*top, *bottom);
            *top = c * t + s * b;
            *bottom = -s * t + c * b;
        }

        // Accumulate its transpose into columns k, k+1 of Q.
        for q_row in q.iter_mut() {
            let (t, b) = (q_row[k], q_row[k + 1]);
            q_row[k] = c * t + s * b;
            q_row[k + 1] = -s * t + c * b;
        }
    }

    (q, r)
}
/// Returns the `m × m` product `A·B`.
///
/// Entries missing from ragged or short inputs are read as zero. Entries of
/// `A` with magnitude below 1e-300 are skipped entirely.
fn dense_mat_mul(a: &[Vec<f64>], b: &[Vec<f64>], m: usize) -> Vec<Vec<f64>> {
    /// Bounds-tolerant element read: out-of-range positions are zero.
    fn entry(mat: &[Vec<f64>], i: usize, j: usize) -> f64 {
        mat.get(i).and_then(|row| row.get(j)).copied().unwrap_or(0.0)
    }

    (0..m)
        .map(|i| {
            let mut out_row = vec![0.0f64; m];
            for k in 0..m {
                let a_ik = entry(a, i, k);
                if a_ik.abs() < 1e-300 {
                    continue;
                }
                for (j, out) in out_row.iter_mut().enumerate() {
                    *out += a_ik * entry(b, k, j);
                }
            }
            out_row
        })
        .collect()
}
/// Checks that deflating the residual `r = b − A·x` does not increase its 2-norm.
///
/// `matvec` must compute `A·x`. Each step of [`HarmonicRitzDeflation::deflate`]
/// subtracts an orthogonal projection onto a single vector. In exact arithmetic
/// the norm can therefore only shrink, and any increase is rounding error
/// proportional to `‖r‖`. The comparison allows `1e-12·‖r‖` relative slack plus
/// `1e-12` absolute slack, so large residuals do not produce spurious `false`
/// results.
///
/// If `b` and `A·x` differ in length, only the common prefix is used.
pub fn deflation_reduces_residual<F>(
    matvec: F,
    deflation: &HarmonicRitzDeflation,
    b: &[f64],
    x: &[f64],
) -> bool
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let ax = matvec(x);
    let r: Vec<f64> = b.iter().zip(&ax).map(|(bi, axi)| bi - axi).collect();
    let r_norm = norm2(&r);
    let r_deflated = deflation.deflate(&r);
    norm2(&r_deflated) <= r_norm + 1e-12 * (1.0 + r_norm)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed space is empty but remembers its capacity.
    #[test]
    fn test_harmonic_ritz_deflation_construction() {
        let fresh = HarmonicRitzDeflation::new(4);
        assert_eq!(fresh.n_deflate, 4);
        assert_eq!(fresh.dim(), 0);
        assert!(fresh.vectors.is_empty());
        assert!(fresh.values.is_empty());
    }

    /// With no deflation vectors the residual passes through unchanged.
    #[test]
    fn test_deflate_zero_vector() {
        let zeros = vec![0.0f64; 5];
        let out = HarmonicRitzDeflation::new(2).deflate(&zeros);
        assert_eq!(out, zeros);
    }

    /// Deflating against e_0 zeroes the first component only.
    #[test]
    fn test_deflate_reduces_norm() {
        let defl = HarmonicRitzDeflation {
            n_deflate: 1,
            vectors: vec![vec![1.0, 0.0, 0.0, 0.0]],
            values: vec![0.1],
        };
        let ones = vec![1.0f64; 4];
        let projected = defl.deflate(&ones);
        assert!(projected[0].abs() < 1e-14, "r_d[0] = {}", projected[0]);
        for &component in &projected[1..] {
            assert!((component - 1.0).abs() < 1e-14);
        }
        assert!(norm2(&projected) < norm2(&ones));
    }

    /// Extraction from a tiny identity basis yields at most one unit vector.
    #[test]
    fn test_extract_from_krylov_trivial() {
        let hess = vec![vec![2.0, 1.0], vec![0.5, 1.5], vec![0.1, 0.0]];
        let basis: Vec<Vec<f64>> = (0..3)
            .map(|i| (0..3).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
            .collect();
        let mut defl = HarmonicRitzDeflation::new(1);
        defl.extract_from_krylov(&hess, &basis, 2);
        assert!(defl.vectors.len() <= 1);
        for nrm in defl.vectors.iter().map(|vi| norm2(vi)) {
            assert!(
                (nrm - 1.0).abs() < 1e-12,
                "deflation vector not normalised: {}",
                nrm
            );
        }
    }

    /// A residual inside span{e_0, e_1} is fully captured.
    #[test]
    fn test_deflation_projection_quality_orthogonal() {
        let defl = HarmonicRitzDeflation {
            n_deflate: 2,
            vectors: vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]],
            values: vec![0.1, 0.2],
        };
        let residual = vec![3.0, 4.0, 0.0];
        let expected = (3.0_f64 * 3.0 + 4.0 * 4.0).sqrt() / norm2(&residual);
        let q = defl.projection_quality(&residual);
        assert!(
            (q - expected).abs() < 1e-12,
            "q = {:.6}, expected = {:.6}",
            q,
            expected
        );
    }

    /// `clear` drops both vectors and values.
    #[test]
    fn test_deflation_clear() {
        let mut defl = HarmonicRitzDeflation {
            n_deflate: 3,
            vectors: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            values: vec![0.1, 0.2],
        };
        defl.clear();
        assert_eq!(defl.dim(), 0);
        assert!(defl.values.is_empty());
    }

    /// The correction for an eigenvector-aligned residual points along u.
    #[test]
    fn test_correction_from_residual() {
        let n = 4;
        let inv_sqrt2 = 1.0 / 2.0f64.sqrt();
        let u = vec![inv_sqrt2, inv_sqrt2, 0.0, 0.0];
        let lambda = 0.5f64;
        let defl = HarmonicRitzDeflation {
            n_deflate: 1,
            vectors: vec![u.clone()],
            values: vec![lambda],
        };
        let av_list = vec![u.iter().map(|x| lambda * x).collect::<Vec<f64>>()];
        let residual = vec![2.0, 2.0, 0.0, 0.0];
        let correction = defl.correction_from_residual(&residual, &av_list);
        assert_eq!(correction.len(), n);
        let proj_c_on_u = dot(&correction, &u);
        assert!(
            proj_c_on_u.abs() > 1e-10,
            "correction should have non-zero projection on u"
        );
    }
}