use crate::error::SparseError;
use crate::krylov::gmres_dr::{dot, gram_schmidt_mgs, norm2, solve_least_squares_hessenberg};
/// Tunable parameters read by [`AugmentedKrylov::solve`].
#[derive(Debug, Clone)]
pub struct AugmentedKrylovConfig {
    /// Sizes the per-cycle Arnoldi workspace: `krylov_dim + 1` basis vectors
    /// and a `(krylov_dim + 1) x krylov_dim` Hessenberg matrix are allocated.
    pub krylov_dim: usize,
    /// Convergence tolerance. It is scaled by `‖b‖` when that norm exceeds
    /// `1e-300`; otherwise it is compared against the residual norm directly.
    pub tol: f64,
    /// Budget on the number of operator applications (`matvec` calls). The
    /// solver checks it between stages, so the final count can overshoot it.
    pub max_iter: usize,
    /// Maximum number of outer (restart) cycles.
    pub max_cycles: usize,
}

impl Default for AugmentedKrylovConfig {
    /// Returns the stock settings: a basis size of 20, a tolerance of `1e-10`,
    /// 1000 operator applications and 50 restart cycles.
    fn default() -> Self {
        let krylov_dim = 20;
        let tol = 1e-10;
        let max_iter = 1000;
        let max_cycles = 50;
        Self {
            krylov_dim,
            tol,
            max_iter,
            max_cycles,
        }
    }
}
/// Outcome of a call to [`AugmentedKrylov::solve`].
#[derive(Debug, Clone)]
pub struct AugmentedKrylovResult {
/// The iterate held by the solver when it returned.
pub x: Vec<f64>,
/// Norm of `b - matvec(x)` for the returned `x`, as last computed by the solver.
pub residual_norm: f64,
/// Number of `matvec` invocations counted by the solver (operator
/// applications, not restart cycles).
pub iterations: usize,
/// `true` when `residual_norm` is within the solver's tolerance
/// (`tol * ‖b‖`, or plain `tol` when `‖b‖` does not exceed `1e-300`).
pub converged: bool,
/// Residual norms recorded by the solver in the order they were computed;
/// the last entry equals `residual_norm`.
pub residual_history: Vec<f64>,
/// Vectors derived from the Krylov basis of the last completed cycle. Empty
/// when no cycle completed or when no input augmentation vector was retained.
pub new_augmentation: Vec<Vec<f64>>,
}
/// Matrix-free linear solver that combines a correction in the span of
/// caller-supplied augmentation vectors with restarted Arnoldi cycles.
pub struct AugmentedKrylov {
/// Settings consulted by `solve` (basis size, tolerance, budgets).
config: AugmentedKrylovConfig,
}
impl AugmentedKrylov {
    /// Creates a solver that uses the supplied configuration.
    pub fn new(config: AugmentedKrylovConfig) -> Self {
        Self { config }
    }

    /// Creates a solver configured with [`AugmentedKrylovConfig::default`].
    pub fn with_defaults() -> Self {
        Self::new(AugmentedKrylovConfig::default())
    }

    /// Solves `A x = b`, where `A` is available only through `matvec`.
    ///
    /// The augmentation vectors are first passed through `gram_schmidt_mgs`,
    /// vectors whose norm does not exceed `0.5` afterwards are dropped, and
    /// `matvec` is applied once to each survivor. Every outer cycle then:
    ///
    /// 1. computes the residual `b - A x` and stops if it is within tolerance
    ///    or the `max_iter` budget of operator applications is spent;
    /// 2. if any augmentation vector survived, adds the least-squares optimal
    ///    correction from their span (normal equations solved by
    ///    [`solve_small_spd`]) and re-evaluates the residual;
    /// 3. runs an Arnoldi process with modified Gram-Schmidt started from the
    ///    normalised residual (at most `krylov_dim - 1` steps), solves the
    ///    small Hessenberg least-squares problem and updates `x`.
    ///
    /// # Arguments
    ///
    /// * `matvec` - applies the operator to a vector.
    /// * `b` - right-hand side; its length defines the problem size.
    /// * `x0` - optional initial guess; zeros are used when `None`.
    /// * `augmentation` - vectors whose span is used for the correction in
    ///   step 2; may be empty.
    ///
    /// # Returns
    ///
    /// An [`AugmentedKrylovResult`]. Failing to reach the tolerance is not an
    /// error: it is reported through `converged == false`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `solve_least_squares_hessenberg`.
    ///
    /// # Panics
    ///
    /// May panic on out-of-range slicing if `x0`, a retained augmentation
    /// vector, or a vector returned by `matvec` is shorter than `b`.
    pub fn solve<F>(
        &self,
        matvec: F,
        b: &[f64],
        x0: Option<&[f64]>,
        augmentation: &[Vec<f64>],
    ) -> Result<AugmentedKrylovResult, SparseError>
    where
        F: Fn(&[f64]) -> Vec<f64>,
    {
        let n = b.len();
        let m = self.config.krylov_dim;
        let max_iter = self.config.max_iter;
        let mut x: Vec<f64> = x0.map_or_else(|| vec![0.0f64; n], |v| v.to_vec());

        let b_norm = norm2(b);
        let abs_tol = if b_norm > 1e-300 {
            self.config.tol * b_norm
        } else {
            self.config.tol
        };

        // Residual b - A*at; every call costs exactly one operator application.
        let residual = |at: &[f64]| -> Vec<f64> {
            let a_at = matvec(at);
            b.iter().zip(a_at.iter()).map(|(bi, ai)| bi - ai).collect()
        };

        let mut total_mv = 0usize;
        let mut residual_history = Vec::new();
        let mut last_krylov: Vec<Vec<f64>> = Vec::new();

        let mut aug_orth: Vec<Vec<f64>> = augmentation.to_vec();
        gram_schmidt_mgs(&mut aug_orth);
        aug_orth.retain(|vi| norm2(vi) > 0.5);
        let k_aug = aug_orth.len();

        // Images of the retained augmentation vectors under the operator.
        let aw: Vec<Vec<f64>> = aug_orth.iter().map(|w| matvec(w)).collect();
        total_mv += k_aug;

        // Gram matrix (AW)^T (AW): independent of the iterate, so it is built
        // at most once, on the first cycle that actually needs it.
        let mut gram_cache: Option<Vec<Vec<f64>>> = None;

        // Holds the residual norm of the current `x` when the loop exits at a
        // point where that norm is already known, so it is not recomputed.
        let mut settled: Option<f64> = None;

        for _cycle in 0..self.config.max_cycles {
            let mut r = residual(&x);
            total_mv += 1;
            let mut r_norm = norm2(&r);
            residual_history.push(r_norm);
            if r_norm <= abs_tol || total_mv >= max_iter {
                settled = Some(r_norm);
                break;
            }

            if k_aug > 0 {
                let gram: &Vec<Vec<f64>> = gram_cache.get_or_insert_with(|| {
                    aw.iter()
                        .map(|awi| aw.iter().map(|awj| dot(awi, awj)).collect::<Vec<f64>>())
                        .collect()
                });
                let atr: Vec<f64> = aw.iter().map(|awi| dot(awi, &r)).collect();
                let alpha = solve_small_spd(gram, &atr, k_aug);
                for (coef, w) in alpha.iter().zip(aug_orth.iter()) {
                    for (xi, wi) in x[..n].iter_mut().zip(&w[..n]) {
                        *xi += coef * wi;
                    }
                }

                // `x` changed, so the residual must be refreshed. Without
                // augmentation `x` is untouched and the residual above is
                // still exact, which is why this lives inside the branch.
                r = residual(&x);
                total_mv += 1;
                r_norm = norm2(&r);
                if r_norm <= abs_tol {
                    residual_history.push(r_norm);
                    settled = Some(r_norm);
                    break;
                }
            }

            let mut v: Vec<Vec<f64>> = vec![vec![0.0f64; n]; m + 1];
            let mut h: Vec<Vec<f64>> = vec![vec![0.0f64; m]; m + 1];
            let inv_r = 1.0 / r_norm;
            for (v0, rl) in v[0].iter_mut().zip(&r[..n]) {
                *v0 = rl * inv_r;
            }

            // `j_end` is the number of basis vectors produced this cycle.
            let mut j_end = 1;
            for j in 1..m {
                let mut w = matvec(&v[j - 1]);
                total_mv += 1;
                for i in 0..j {
                    let hij = dot(&w, &v[i]);
                    h[i][j - 1] = hij;
                    for (wl, vl) in w[..n].iter_mut().zip(&v[i]) {
                        *wl -= hij * vl;
                    }
                }
                let h_sub = norm2(&w);
                h[j][j - 1] = h_sub;
                j_end = j + 1;
                if h_sub > 1e-15 {
                    let inv = 1.0 / h_sub;
                    for (vl, wl) in v[j].iter_mut().zip(&w[..n]) {
                        *vl = wl * inv;
                    }
                } else {
                    // Breakdown: the new direction is numerically zero.
                    break;
                }
                if total_mv >= max_iter {
                    break;
                }
            }

            let cols = (j_end - 1).max(1).min(m);
            let mut g = vec![0.0f64; j_end];
            g[0] = r_norm;
            let y = solve_least_squares_hessenberg(&h, &g, cols)?;
            for (yj, vj) in y.iter().zip(v.iter()) {
                for (xi, vi) in x[..n].iter_mut().zip(vj.iter()) {
                    *xi += yj * vi;
                }
            }

            // Keep only the vectors that were produced; moving avoids a copy.
            v.truncate(j_end);
            last_krylov = v;

            if total_mv >= max_iter {
                break;
            }
        }

        let residual_norm = match settled {
            Some(norm) => norm,
            None => {
                // `x` was updated after the last residual evaluation.
                let r_fin = residual(&x);
                total_mv += 1;
                let norm = norm2(&r_fin);
                residual_history.push(norm);
                norm
            }
        };
        let new_augmentation = extract_augmentation(&last_krylov, k_aug, n);

        Ok(AugmentedKrylovResult {
            x,
            residual_norm,
            iterations: total_mv,
            converged: residual_norm <= abs_tol,
            residual_history,
            new_augmentation,
        })
    }
}
/// Solves the leading `k x k` system `a * x = b` for a small dense matrix
/// that is expected to be symmetric positive definite.
///
/// * `k == 0` yields an empty vector.
/// * `k == 1` divides `b[0]` by `a[0][0]`, or yields `0.0` when the magnitude
///   of that entry does not exceed `1e-300`.
/// * Otherwise a Cholesky factorisation `a = L * L^T` is attempted using the
///   lower triangle of `a`, followed by forward and backward substitution.
///   If a pivot falls below `1e-300` the factorisation is abandoned and each
///   unknown is instead set to `b[i] / a[i][i]` (or `0.0` for a diagonal entry
///   whose magnitude does not exceed `1e-300`).
///
/// Indexing panics if `a` or `b` holds fewer than `k` rows/entries.
pub(crate) fn solve_small_spd(a: &[Vec<f64>], b: &[f64], k: usize) -> Vec<f64> {
    const PIVOT_FLOOR: f64 = 1e-300;

    /// A pivot is usable when its magnitude is strictly above the floor.
    fn usable(pivot: f64) -> bool {
        pivot.abs() > PIVOT_FLOOR
    }

    /// Lower Cholesky factor of the leading `k x k` block, or `None` on breakdown.
    fn cholesky_lower(a: &[Vec<f64>], k: usize) -> Option<Vec<Vec<f64>>> {
        let mut l = vec![vec![0.0f64; k]; k];
        for row in 0..k {
            for col in 0..=row {
                let reduced = (0..col).fold(a[row][col], |acc, p| acc - l[row][p] * l[col][p]);
                if row == col {
                    if reduced < PIVOT_FLOOR {
                        return None;
                    }
                    l[row][col] = reduced.sqrt();
                } else {
                    let pivot = l[col][col];
                    if usable(pivot) {
                        l[row][col] = reduced / pivot;
                    } else {
                        return None;
                    }
                }
            }
        }
        Some(l)
    }

    match k {
        0 => return Vec::new(),
        1 => {
            let pivot = a[0][0];
            return vec![if usable(pivot) { b[0] / pivot } else { 0.0 }];
        }
        _ => {}
    }

    match cholesky_lower(a, k) {
        Some(l) => {
            // Forward substitution: L * y = b.
            let mut y: Vec<f64> = Vec::with_capacity(k);
            for i in 0..k {
                let s = (0..i).fold(b[i], |acc, j| acc - l[i][j] * y[j]);
                y.push(if usable(l[i][i]) { s / l[i][i] } else { 0.0 });
            }

            // Backward substitution: L^T * x = y.
            let mut x = vec![0.0f64; k];
            for i in (0..k).rev() {
                let s = ((i + 1)..k).fold(y[i], |acc, j| acc - l[j][i] * x[j]);
                x[i] = if usable(l[i][i]) { s / l[i][i] } else { 0.0 };
            }
            x
        }
        // Not numerically positive definite: fall back to a diagonal solve.
        None => (0..k)
            .map(|i| {
                let d = a[i][i];
                if usable(d) {
                    b[i] / d
                } else {
                    0.0
                }
            })
            .collect(),
    }
}
/// Builds the augmentation vectors handed back to the caller.
///
/// Copies the first `min(k_aug, krylov.len())` vectors of `krylov`, passes
/// them through `gram_schmidt_mgs`, and keeps only those whose norm exceeds
/// `0.5` afterwards. Returns an empty list when there is nothing to copy.
/// `_n` is accepted for signature compatibility and is not read.
fn extract_augmentation(krylov: &[Vec<f64>], k_aug: usize, _n: usize) -> Vec<Vec<f64>> {
    let take = k_aug.min(krylov.len());
    if take == 0 {
        return Vec::new();
    }
    let mut basis: Vec<Vec<f64>> = krylov.iter().take(take).cloned().collect();
    gram_schmidt_mgs(&mut basis);
    basis.into_iter().filter(|vi| norm2(vi) > 0.5).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Matrix-free operator that scales each component by the matching `diag` entry.
    fn diag_mv(diag: Vec<f64>) -> impl Fn(&[f64]) -> Vec<f64> {
        move |x: &[f64]| x.iter().zip(&diag).map(|(xi, di)| xi * di).collect()
    }

    /// Diagonal entries `1, 2, ..., n`.
    fn linear_spectrum(n: usize) -> Vec<f64> {
        (1..=n).map(|i| i as f64).collect()
    }

    /// Coordinate vector of length `n` with a one at `index`.
    fn basis_vector(n: usize, index: usize) -> Vec<f64> {
        let mut e = vec![0.0f64; n];
        e[index] = 1.0;
        e
    }

    #[test]
    fn test_augmented_krylov_no_augmentation() {
        let n = 8;
        let rhs = vec![1.0f64; n];
        let config = AugmentedKrylovConfig {
            krylov_dim: 6,
            tol: 1e-12,
            max_iter: 300,
            max_cycles: 20,
        };
        let outcome = AugmentedKrylov::new(config)
            .solve(diag_mv(linear_spectrum(n)), &rhs, None, &[])
            .expect("augmented krylov solve failed");
        assert!(
            outcome.converged,
            "should converge without augmentation: residual = {:.3e}",
            outcome.residual_norm
        );
    }

    #[test]
    fn test_augmented_krylov_with_augmentation() {
        let n = 10;
        let rhs = vec![1.0f64; n];
        let aug = vec![basis_vector(n, 0), basis_vector(n, 1)];
        let config = AugmentedKrylovConfig {
            krylov_dim: 8,
            tol: 1e-12,
            max_iter: 300,
            max_cycles: 30,
        };
        let outcome = AugmentedKrylov::new(config)
            .solve(diag_mv(linear_spectrum(n)), &rhs, None, &aug)
            .expect("augmented krylov with augmentation failed");
        assert!(
            outcome.converged,
            "should converge with augmentation: residual = {:.3e}",
            outcome.residual_norm
        );
    }

    #[test]
    fn test_augmented_result_new_augmentation_populated() {
        let n = 6;
        let rhs = vec![1.0f64; n];
        let aug = vec![basis_vector(n, 0)];
        let outcome = AugmentedKrylov::with_defaults()
            .solve(diag_mv(linear_spectrum(n)), &rhs, None, &aug)
            .expect("solve failed");
        assert!(outcome.converged || outcome.residual_norm < 1e-8);
    }

    #[test]
    fn test_augmented_config_default() {
        let defaults = AugmentedKrylovConfig::default();
        assert_eq!(defaults.krylov_dim, 20);
        assert!(defaults.tol > 0.0);
        assert!(defaults.max_iter > 0);
    }
}