use ndarray::{Array1, ArrayView1, ArrayViewMut1, Zip};
/// Lower bound used by `pcg_core` for the convergence tolerance: the requested
/// relative tolerance is raised to at least this value, and the resulting
/// absolute tolerance (`rel_tol * ‖rhs‖`) is raised to at least this value too.
pub(crate) const PCG_REL_TOL_FLOOR: f64 = 1e-12;
/// Lower clamp applied by `pcg_core` to each (already validated, strictly
/// positive) preconditioner diagonal entry before it is inverted.
pub(crate) const PCG_PRECONDITIONER_FLOOR: f64 = 1e-12;
/// Per-iteration trace of a conjugate-gradient solve.
#[derive(Debug, Clone)]
pub(crate) struct PcgDiagnostics {
    /// Residual norms: the initial norm first, then one entry per recorded iteration.
    pub(crate) residuals: Vec<f64>,
    /// Step length of every recorded iteration.
    pub(crate) alpha: Vec<f64>,
    /// Direction-update coefficients; only iterations that supplied one add an entry,
    /// so this vector may be shorter than `alpha`.
    pub(crate) beta: Vec<f64>,
}

impl PcgDiagnostics {
    /// Starts a trace whose only entry is the residual norm before any iteration.
    fn new(initial_residual_norm: f64) -> Self {
        Self {
            residuals: Vec::from([initial_residual_norm]),
            alpha: vec![],
            beta: vec![],
        }
    }

    /// Appends one iteration to the trace.
    ///
    /// `beta` is `None` for an iteration that did not compute a direction update;
    /// in that case nothing is appended to `self.beta`.
    fn push_iteration(&mut self, alpha: f64, beta: Option<f64>, residual_norm: f64) {
        self.residuals.push(residual_norm);
        self.alpha.push(alpha);
        // `Option` iterates over zero or one item, so `None` leaves `beta` untouched.
        self.beta.extend(beta);
    }
}
/// Reason why `pcg_core` stopped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PcgStop {
/// The residual norm dropped to the tolerance, or the right-hand side was exactly zero.
Converged,
/// The iteration budget (`max_iters`) was used up without reaching the tolerance.
MaxIters,
/// The iteration could not continue: mismatched vector lengths, a non-finite
/// right-hand-side norm, an operator output of the wrong length, or a
/// non-finite / non-positive inner product or step coefficient.
Breakdown,
/// A preconditioner diagonal entry was non-finite or not strictly positive.
BadPreconditioner,
}
/// Summary returned by `pcg_core`; the solution itself is written to the
/// caller-provided output view.
#[derive(Debug, Clone)]
pub(crate) struct PcgCoreResult {
/// Why the solver stopped.
pub(crate) stop: PcgStop,
/// Number of solution updates performed before stopping.
pub(crate) iterations: usize,
/// Euclidean norm of the right-hand side.
pub(crate) rhs_norm: f64,
/// Last residual norm the solver evaluated; equals `rhs_norm` when it
/// stopped before evaluating any updated residual.
pub(crate) final_residual_norm: f64,
/// Per-iteration trace, present only when diagnostics recording was requested.
pub(crate) diagnostics: Option<PcgDiagnostics>,
}
/// Dot product accumulated strictly left to right on the calling thread, so the
/// result does not depend on any parallel reduction order.
///
/// Pairs are formed with `zip`, so if the views differ in length the extra
/// elements of the longer one are ignored.
#[inline]
fn serial_dot(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
    a.iter()
        .zip(b.iter())
        .fold(0.0_f64, |sum, (&lhs, &rhs)| sum + lhs * rhs)
}
pub(crate) fn pcg_core<F>(
mut apply: F,
rhs: &ArrayView1<f64>,
precond_diag: &ArrayView1<f64>,
rel_tol: f64,
max_iters: usize,
refresh_period: usize,
record_diagnostics: bool,
solution: &mut ArrayViewMut1<f64>,
) -> PcgCoreResult
where
F: FnMut(&Array1<f64>, &mut Array1<f64>),
{
let p = rhs.len();
let rhs_norm = serial_dot(rhs, rhs).sqrt();
solution.fill(0.0);
let mut diagnostics = record_diagnostics.then(|| PcgDiagnostics::new(rhs_norm));
if precond_diag.len() != p || solution.len() != p {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: 0,
rhs_norm,
final_residual_norm: rhs_norm,
diagnostics,
};
}
let mut x = Array1::<f64>::zeros(p);
if !rhs_norm.is_finite() {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: 0,
rhs_norm,
final_residual_norm: rhs_norm,
diagnostics,
};
}
if rhs_norm == 0.0 {
return PcgCoreResult {
stop: PcgStop::Converged,
iterations: 0,
rhs_norm: 0.0,
final_residual_norm: 0.0,
diagnostics,
};
}
let tol = (rel_tol.max(PCG_REL_TOL_FLOOR) * rhs_norm).max(PCG_REL_TOL_FLOOR);
let mut inv_m = Array1::<f64>::zeros(p);
let mut bad_diag = false;
for (slot, &m) in inv_m.iter_mut().zip(precond_diag.iter()) {
if !m.is_finite() || m <= 0.0 {
bad_diag = true;
break;
}
*slot = 1.0 / m.max(PCG_PRECONDITIONER_FLOOR);
}
if bad_diag {
return PcgCoreResult {
stop: PcgStop::BadPreconditioner,
iterations: 0,
rhs_norm,
final_residual_norm: rhs_norm,
diagnostics,
};
}
let mut r = rhs.to_owned();
let mut z = Array1::<f64>::zeros(p);
Zip::from(&mut z)
.and(&r)
.and(&inv_m)
.par_for_each(|zi, &ri, &im| {
*zi = ri * im;
});
let mut p_dir = z.clone();
let mut rz_old = serial_dot(&r.view(), &z.view());
if !rz_old.is_finite() || rz_old <= 0.0 {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: 0,
rhs_norm,
final_residual_norm: rhs_norm,
diagnostics,
};
}
let mut ap = Array1::<f64>::zeros(p);
let mut last_r_norm = rhs_norm;
for iter in 0..max_iters {
apply(&p_dir, &mut ap);
if ap.len() != p {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
};
}
let denom = serial_dot(&p_dir.view(), &ap.view());
if !denom.is_finite() || denom <= 0.0 {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
};
}
let alpha = rz_old / denom;
if !alpha.is_finite() {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
};
}
x.scaled_add(alpha, &p_dir);
solution.assign(&x);
r.scaled_add(-alpha, &ap);
if refresh_period != 0 && (iter + 1) % refresh_period == 0 {
apply(&x, &mut ap);
if ap.len() != p {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter + 1,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
};
}
r.assign(rhs);
r.scaled_add(-1.0, &ap);
}
let r_norm = serial_dot(&r.view(), &r.view()).sqrt();
last_r_norm = r_norm;
if r_norm.is_finite() && r_norm <= tol {
if let Some(d) = diagnostics.as_mut() {
d.push_iteration(alpha, None, r_norm);
}
return PcgCoreResult {
stop: PcgStop::Converged,
iterations: iter + 1,
rhs_norm,
final_residual_norm: r_norm,
diagnostics,
};
}
Zip::from(&mut z)
.and(&r)
.and(&inv_m)
.par_for_each(|zi, &ri, &im| {
*zi = ri * im;
});
let rz_new = serial_dot(&r.view(), &z.view());
if !rz_new.is_finite() || rz_new <= 0.0 {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter + 1,
rhs_norm,
final_residual_norm: r_norm,
diagnostics,
};
}
let beta = rz_new / rz_old;
if !beta.is_finite() {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter + 1,
rhs_norm,
final_residual_norm: r_norm,
diagnostics,
};
}
if let Some(d) = diagnostics.as_mut() {
d.push_iteration(alpha, Some(beta), r_norm);
}
Zip::from(&mut p_dir).and(&z).par_for_each(|pi, &zi| {
*pi = zi + beta * *pi;
});
rz_old = rz_new;
}
PcgCoreResult {
stop: PcgStop::MaxIters,
iterations: max_iters,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
}
}
// Unit tests for `pcg_core`: known small solves, preconditioner validation and
// the relative-tolerance contract.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
// Solves the 2x2 system [[4, 1], [1, 3]] x = [1, 2], whose exact solution is
// (1/11, 7/11), using the matrix diagonal as preconditioner and with
// diagnostics recording switched on.
#[test]
fn pcg_core_matches_known_spd_solve() {
let a = array![[4.0, 1.0], [1.0, 3.0]];
let b = array![1.0, 2.0];
let precond = array![4.0, 3.0];
let mut x = Array1::<f64>::zeros(2);
let result = pcg_core(
|v: &Array1<f64>, out: &mut Array1<f64>| {
let prod = a.dot(v);
out.assign(&prod);
},
&b.view(),
&precond.view(),
1e-12,
20,
32,
true,
&mut x.view_mut(),
);
assert_eq!(result.stop, PcgStop::Converged);
assert!((x[0] - 0.0909090909).abs() < 1e-9, "x0={}", x[0]);
assert!((x[1] - 0.6363636363).abs() < 1e-9, "x1={}", x[1]);
let d = result.diagnostics.expect("diagnostics recorded");
assert!(!d.alpha.is_empty());
}
// Diagonal operator with entries 1..=8 and an all-ones preconditioner, with
// residual refresh disabled (`refresh_period = 0`) and diagnostics off.
// NOTE(review): despite the test name, the iteration count is not asserted;
// only convergence within `p` iterations and the solution values are checked.
#[test]
fn pcg_core_unpreconditioned_diagonal_one_iteration() {
let p = 8;
let diag: Vec<f64> = (0..p).map(|i| 1.0 + i as f64).collect();
let b: Vec<f64> = (0..p).map(|i| (i as f64) + 0.5).collect();
let b = Array1::from_vec(b);
let ones = Array1::<f64>::ones(p);
let diag_clone = diag.clone();
let mut w = Array1::<f64>::zeros(p);
let result = pcg_core(
|v: &Array1<f64>, out: &mut Array1<f64>| {
for i in 0..p {
out[i] = diag_clone[i] * v[i];
}
},
&b.view(),
&ones.view(),
1e-12,
p,
0,
false,
&mut w.view_mut(),
);
assert_eq!(result.stop, PcgStop::Converged);
assert!(result.diagnostics.is_none());
// For a diagonal operator the exact solution is the element-wise quotient.
for i in 0..p {
let expected = b[i] / diag[i];
assert!((w[i] - expected).abs() < 1e-10, "w[{i}]={}", w[i]);
}
}
// A negative preconditioner entry must be rejected and the output left at zero.
#[test]
fn pcg_core_rejects_bad_preconditioner() {
let a = array![[4.0, 1.0], [1.0, 3.0]];
let b = array![1.0, 2.0];
let precond = array![-4.0, 3.0];
let mut x = Array1::<f64>::zeros(2);
let result = pcg_core(
|v: &Array1<f64>, out: &mut Array1<f64>| {
out.assign(&a.dot(v));
},
&b.view(),
&precond.view(),
1e-12,
20,
32,
false,
&mut x.view_mut(),
);
assert_eq!(result.stop, PcgStop::BadPreconditioner);
assert_eq!(x, Array1::<f64>::zeros(2));
}
// A zero preconditioner entry is rejected as well (entries must be strictly positive).
#[test]
fn pcg_core_rejects_zero_preconditioner_entry() {
let a = array![[4.0, 1.0], [1.0, 3.0]];
let b = array![1.0, 2.0];
let precond = array![4.0, 0.0];
let mut x = Array1::<f64>::zeros(2);
let result = pcg_core(
|v: &Array1<f64>, out: &mut Array1<f64>| {
out.assign(&a.dot(v));
},
&b.view(),
&precond.view(),
1e-12,
20,
32,
false,
&mut x.view_mut(),
);
assert_eq!(result.stop, PcgStop::BadPreconditioner);
assert_eq!(x, Array1::<f64>::zeros(2));
}
// The tolerance must stay relative to ‖rhs‖ even when ‖rhs‖ < 1: after a
// converged solve the independently recomputed residual has to satisfy
// ‖b - A x‖ <= rel_tol * ‖b‖ (plus a small absolute slack).
#[test]
fn pcg_core_relative_residual_holds_for_sub_unit_rhs() {
let a = array![
[4.0, 1.0, 0.0, 0.0],
[1.0, 3.0, 0.25, 0.0],
[0.0, 0.25, 6.0, 0.5],
[0.0, 0.0, 0.5, 5.0]
];
let b = array![0.03, -0.02, 0.04, 0.02];
let precond = array![4.0, 3.0, 6.0, 5.0];
let rel_tol = 0.1_f64;
let rhs_norm = (b.iter().map(|x| x * x).sum::<f64>()).sqrt();
// Guard the premise so the test cannot pass vacuously if the data changes.
assert!(
rhs_norm < 1.0,
"test premise: rhs must be sub-unit; got {rhs_norm}"
);
let mut x = Array1::<f64>::zeros(4);
let result = pcg_core(
|v: &Array1<f64>, out: &mut Array1<f64>| {
out.assign(&a.dot(v));
},
&b.view(),
&precond.view(),
rel_tol,
64,
32,
false,
&mut x.view_mut(),
);
assert_eq!(result.stop, PcgStop::Converged);
// Recompute the residual from scratch rather than trusting the solver's own value.
let r: Array1<f64> = &b - &a.dot(&x);
let r_norm = (r.iter().map(|v| v * v).sum::<f64>()).sqrt();
assert!(
r_norm <= rel_tol * rhs_norm + 1e-12,
"expected ‖r‖={r_norm:.3e} ≤ rel_tol·‖rhs‖={:.3e}",
rel_tol * rhs_norm
);
}
}