use ndarray::{Array1, ArrayView1, ArrayViewMut1, Zip};
/// Lower bound used by [`pcg_core`] both for the relative tolerance and for the resulting
/// absolute convergence threshold on the residual norm.
pub const PCG_REL_TOL_FLOOR: f64 = 1.0e-12;
/// Positive preconditioner diagonal entries smaller than this are clamped up to it before
/// being inverted in [`pcg_core`].
pub const PCG_PRECONDITIONER_FLOOR: f64 = 1.0e-12;
/// Iteration history optionally recorded by [`pcg_core`].
///
/// `residuals` starts with the initial residual norm and gains one entry per recorded
/// iteration; `alpha` gains one entry per recorded iteration; `beta` gains an entry only
/// when one is supplied, so it may be shorter than `alpha`.
#[derive(Clone, Debug)]
pub struct PcgDiagnostics {
    /// Residual norms, initial value first.
    pub residuals: Vec<f64>,
    /// Step lengths, one per recorded iteration.
    pub alpha: Vec<f64>,
    /// Search-direction update coefficients for the iterations that produced one.
    pub beta: Vec<f64>,
}

impl PcgDiagnostics {
    /// Creates a history whose only entry is `initial_residual_norm`.
    fn new(initial_residual_norm: f64) -> Self {
        let residuals = vec![initial_residual_norm];
        Self {
            alpha: Vec::new(),
            beta: Vec::new(),
            residuals,
        }
    }

    /// Records one iteration; `beta` is appended only when it is `Some`.
    fn push_iteration(&mut self, alpha: f64, beta: Option<f64>, residual_norm: f64) {
        self.residuals.push(residual_norm);
        self.alpha.push(alpha);
        self.beta.extend(beta);
    }
}
/// Reason [`pcg_core`] returned.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PcgStop {
    /// The residual norm reached the tolerance (or the right-hand side was zero).
    Converged,
    /// The iteration budget ran out before the tolerance was reached.
    MaxIters,
    /// A structural or numerical failure: mismatched lengths, a non-finite right-hand side,
    /// or a non-finite / non-positive inner product or step coefficient.
    Breakdown,
    /// The preconditioner diagonal contained a non-finite or non-positive entry.
    BadPreconditioner,
}
/// Outcome of a [`pcg_core`] solve; the solution itself is written to the caller's buffer.
#[derive(Clone, Debug)]
pub struct PcgCoreResult {
    /// Why the iteration stopped.
    pub stop: PcgStop,
    /// Number of completed iterations (updates of the iterate); `0` when the solve returned
    /// before the first update.
    pub iterations: usize,
    /// Euclidean norm of the right-hand side, computed with the selected [`DotReduction`].
    pub rhs_norm: f64,
    /// Most recently computed residual norm (`rhs_norm` when no iteration completed).
    pub final_residual_norm: f64,
    /// Per-iteration history; `Some` only when diagnostics recording was requested.
    pub diagnostics: Option<PcgDiagnostics>,
}
/// Summation strategy used for every inner product and norm inside [`pcg_core`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DotReduction {
    /// Plain left-to-right accumulation with a fixed, reproducible summation order.
    Serial,
    /// Eight interleaved partial sums combined pairwise (contiguous inputs only); results may
    /// differ from `Serial` by rounding.
    Reordered,
}
/// Inner product as a strict left-to-right sum `((a₀b₀ + a₁b₁) + a₂b₂) + …`, starting from `0.0`.
///
/// The summation order is fixed, so the result is bit-for-bit reproducible for a given
/// input. Elements are paired up to the shorter of the two lengths.
#[inline]
fn serial_dot(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
    a.iter()
        .zip(b.iter())
        .fold(0.0_f64, |sum, (&x, &y)| sum + x * y)
}
/// Inner product using eight interleaved partial sums (reassociated summation).
///
/// When both views expose a slice (`as_slice` returns `Some`), lane `l` accumulates the
/// `l`-th product of every full block of eight elements. The lanes are then combined as a
/// balanced pairwise tree, and the leftover tail (fewer than eight elements) is added left
/// to right. The independent lanes are presumably there to let the compiler vectorise.
/// Results may differ from [`serial_dot`] by rounding. Views without a slice fall back to
/// [`serial_dot`]. As with `serial_dot`, only the first `min(a.len(), b.len())` elements
/// contribute.
#[inline]
fn reordered_dot(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
    const LANES: usize = 8;
    let (xs, ys) = match (a.as_slice(), b.as_slice()) {
        (Some(xs), Some(ys)) => (xs, ys),
        _ => return serial_dot(a, b),
    };
    let len = xs.len().min(ys.len());
    let x_blocks = xs[..len].chunks_exact(LANES);
    let y_blocks = ys[..len].chunks_exact(LANES);
    // `remainder` borrows the underlying slices, not the iterators, so grab it up front.
    let (x_tail, y_tail) = (x_blocks.remainder(), y_blocks.remainder());
    let mut lanes = [0.0_f64; LANES];
    for (xb, yb) in x_blocks.zip(y_blocks) {
        for ((lane, &x), &y) in lanes.iter_mut().zip(xb).zip(yb) {
            *lane += x * y;
        }
    }
    let [l0, l1, l2, l3, l4, l5, l6, l7] = lanes;
    let combined = ((l0 + l1) + (l2 + l3)) + ((l4 + l5) + (l6 + l7));
    x_tail
        .iter()
        .zip(y_tail)
        .fold(combined, |acc, (&x, &y)| acc + x * y)
}
#[inline]
fn dot(a: &ArrayView1<f64>, b: &ArrayView1<f64>, reduction: DotReduction) -> f64 {
match reduction {
DotReduction::Serial => serial_dot(a, b),
DotReduction::Reordered => reordered_dot(a, b),
}
}
pub fn pcg_core<F>(
mut apply: F,
rhs: &ArrayView1<f64>,
precond_diag: &ArrayView1<f64>,
rel_tol: f64,
max_iters: usize,
refresh_period: usize,
record_diagnostics: bool,
reduction: DotReduction,
solution: &mut ArrayViewMut1<f64>,
) -> PcgCoreResult
where
F: FnMut(&Array1<f64>, &mut Array1<f64>),
{
let p = rhs.len();
let rhs_norm = dot(rhs, rhs, reduction).sqrt();
solution.fill(0.0);
let mut diagnostics = record_diagnostics.then(|| PcgDiagnostics::new(rhs_norm));
if precond_diag.len() != p || solution.len() != p {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: 0,
rhs_norm,
final_residual_norm: rhs_norm,
diagnostics,
};
}
let mut x = Array1::<f64>::zeros(p);
if !rhs_norm.is_finite() {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: 0,
rhs_norm,
final_residual_norm: rhs_norm,
diagnostics,
};
}
if rhs_norm == 0.0 {
return PcgCoreResult {
stop: PcgStop::Converged,
iterations: 0,
rhs_norm: 0.0,
final_residual_norm: 0.0,
diagnostics,
};
}
let tol = (rel_tol.max(PCG_REL_TOL_FLOOR) * rhs_norm).max(PCG_REL_TOL_FLOOR);
let mut inv_m = Array1::<f64>::zeros(p);
let mut bad_diag = false;
for (slot, &m) in inv_m.iter_mut().zip(precond_diag.iter()) {
if !m.is_finite() || m <= 0.0 {
bad_diag = true;
break;
}
*slot = 1.0 / m.max(PCG_PRECONDITIONER_FLOOR);
}
if bad_diag {
return PcgCoreResult {
stop: PcgStop::BadPreconditioner,
iterations: 0,
rhs_norm,
final_residual_norm: rhs_norm,
diagnostics,
};
}
let mut r = rhs.to_owned();
let mut z = Array1::<f64>::zeros(p);
Zip::from(&mut z)
.and(&r)
.and(&inv_m)
.par_for_each(|zi, &ri, &im| {
*zi = ri * im;
});
let mut p_dir = z.clone();
let mut rz_old = dot(&r.view(), &z.view(), reduction);
if !rz_old.is_finite() || rz_old <= 0.0 {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: 0,
rhs_norm,
final_residual_norm: rhs_norm,
diagnostics,
};
}
let mut ap = Array1::<f64>::zeros(p);
let mut last_r_norm = rhs_norm;
for iter in 0..max_iters {
apply(&p_dir, &mut ap);
if ap.len() != p {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
};
}
let denom = dot(&p_dir.view(), &ap.view(), reduction);
if !denom.is_finite() || denom <= 0.0 {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
};
}
let alpha = rz_old / denom;
if !alpha.is_finite() {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
};
}
x.scaled_add(alpha, &p_dir);
solution.assign(&x);
r.scaled_add(-alpha, &ap);
if refresh_period != 0 && (iter + 1) % refresh_period == 0 {
apply(&x, &mut ap);
if ap.len() != p {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter + 1,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
};
}
r.assign(rhs);
r.scaled_add(-1.0, &ap);
}
let r_norm = dot(&r.view(), &r.view(), reduction).sqrt();
last_r_norm = r_norm;
if r_norm.is_finite() && r_norm <= tol {
if let Some(d) = diagnostics.as_mut() {
d.push_iteration(alpha, None, r_norm);
}
return PcgCoreResult {
stop: PcgStop::Converged,
iterations: iter + 1,
rhs_norm,
final_residual_norm: r_norm,
diagnostics,
};
}
Zip::from(&mut z)
.and(&r)
.and(&inv_m)
.par_for_each(|zi, &ri, &im| {
*zi = ri * im;
});
let rz_new = dot(&r.view(), &z.view(), reduction);
if !rz_new.is_finite() || rz_new <= 0.0 {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter + 1,
rhs_norm,
final_residual_norm: r_norm,
diagnostics,
};
}
let beta = rz_new / rz_old;
if !beta.is_finite() {
return PcgCoreResult {
stop: PcgStop::Breakdown,
iterations: iter + 1,
rhs_norm,
final_residual_norm: r_norm,
diagnostics,
};
}
if let Some(d) = diagnostics.as_mut() {
d.push_iteration(alpha, Some(beta), r_norm);
}
Zip::from(&mut p_dir).and(&z).par_for_each(|pi, &zi| {
*pi = zi + beta * *pi;
});
rz_old = rz_new;
}
PcgCoreResult {
stop: PcgStop::MaxIters,
iterations: max_iters,
rhs_norm,
final_residual_norm: last_r_norm,
diagnostics,
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the dot-product kernels and the PCG core.
    use super::*;
    use ndarray::{array, s};
    #[test]
    fn dot_serial_is_bit_identical_to_plain_left_fold() {
        let mut av = vec![1e16, 1.0];
        let mut bv = vec![1.0, 1.0];
        for k in 0..4096 {
            av.push(1.0);
            bv.push(((k as f64).sin()).abs() + 1e-3);
        }
        let a = Array1::from(av);
        let b = Array1::from(bv);
        let mut reference = 0.0_f64;
        for (x, y) in a.iter().zip(b.iter()) {
            reference += x * y;
        }
        let got = dot(&a.view(), &b.view(), DotReduction::Serial);
        assert_eq!(
            got.to_bits(),
            reference.to_bits(),
            "Serial reduction must be bit-identical to the plain left fold"
        );
    }
    #[test]
    fn dot_reordered_matches_serial_to_loose_tol() {
        for &n in &[7usize, 8, 9, 16, 100, 513, 1024, 4096] {
            let a: Array1<f64> = Array1::from_shape_fn(n, |i| ((i * 7 + 1) as f64).sin() * 3.0);
            let b: Array1<f64> = Array1::from_shape_fn(n, |i| ((i * 13 + 3) as f64).cos() * 2.0);
            let s = dot(&a.view(), &b.view(), DotReduction::Serial);
            let r = dot(&a.view(), &b.view(), DotReduction::Reordered);
            let rel = (s - r).abs() / s.abs().max(1e-300);
            assert!(
                rel < 1e-12,
                "n={n}: reordered rel diff {rel:.3e} should be far below trace SE"
            );
        }
    }
    #[test]
    fn dot_reordered_handles_tail_and_short_lengths() {
        for &n in &[0usize, 1, 3, 5, 7] {
            let a: Array1<f64> = Array1::from_shape_fn(n, |i| (i as f64) + 0.25);
            let b: Array1<f64> = Array1::from_shape_fn(n, |i| (i as f64) * 0.5 + 1.0);
            let s = dot(&a.view(), &b.view(), DotReduction::Serial);
            let r = dot(&a.view(), &b.view(), DotReduction::Reordered);
            assert_eq!(s.to_bits(), r.to_bits(), "n={n}");
        }
    }
    /// Strided and reversed views have no standard-layout slice, so the reordered kernel
    /// must fall back to the serial fold and agree with it bit for bit.
    #[test]
    fn dot_reordered_falls_back_to_serial_for_strided_views() {
        let a: Array1<f64> = Array1::from_shape_fn(100, |i| ((i * 5 + 2) as f64).sin());
        let b: Array1<f64> = Array1::from_shape_fn(100, |i| ((i * 3 + 1) as f64).cos());
        for (av, bv) in [
            (a.slice(s![..;2]), b.slice(s![1..;2])),
            (a.slice(s![..;-1]), b.slice(s![..;-1])),
        ] {
            let serial = dot(&av, &bv, DotReduction::Serial);
            let reordered = dot(&av, &bv, DotReduction::Reordered);
            assert_eq!(serial.to_bits(), reordered.to_bits());
        }
    }
    #[test]
    fn pcg_core_matches_known_spd_solve() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![1.0, 2.0];
        let precond = array![4.0, 3.0];
        let mut x = Array1::<f64>::zeros(2);
        let result = pcg_core(
            |v: &Array1<f64>, out: &mut Array1<f64>| {
                let prod = a.dot(v);
                out.assign(&prod);
            },
            &b.view(),
            &precond.view(),
            1e-12,
            20,
            32,
            true,
            DotReduction::Serial,
            &mut x.view_mut(),
        );
        assert_eq!(result.stop, PcgStop::Converged);
        assert!((x[0] - 0.0909090909).abs() < 1e-9, "x0={}", x[0]);
        assert!((x[1] - 0.6363636363).abs() < 1e-9, "x1={}", x[1]);
        let d = result.diagnostics.expect("diagnostics recorded");
        assert!(!d.alpha.is_empty());
    }
    /// With an identity preconditioner, CG on a diagonal matrix with `p` distinct entries
    /// needs up to `p` iterations in exact arithmetic, so the budget is set to `p`.
    #[test]
    fn pcg_core_unpreconditioned_diagonal_converges_within_dimension() {
        let p = 8;
        let diag: Vec<f64> = (0..p).map(|i| 1.0 + i as f64).collect();
        let b: Vec<f64> = (0..p).map(|i| (i as f64) + 0.5).collect();
        let b = Array1::from_vec(b);
        let ones = Array1::<f64>::ones(p);
        let diag_clone = diag.clone();
        let mut w = Array1::<f64>::zeros(p);
        let result = pcg_core(
            |v: &Array1<f64>, out: &mut Array1<f64>| {
                for i in 0..p {
                    out[i] = diag_clone[i] * v[i];
                }
            },
            &b.view(),
            &ones.view(),
            1e-12,
            p,
            0,
            false,
            DotReduction::Serial,
            &mut w.view_mut(),
        );
        assert_eq!(result.stop, PcgStop::Converged);
        assert!(result.iterations <= p, "iterations={}", result.iterations);
        assert!(result.diagnostics.is_none());
        for i in 0..p {
            let expected = b[i] / diag[i];
            assert!((w[i] - expected).abs() < 1e-10, "w[{i}]={}", w[i]);
        }
    }
    /// When the preconditioner equals a diagonal operator, the first preconditioned
    /// direction is already the exact solution, so a single step with `alpha ≈ 1` suffices.
    #[test]
    fn pcg_core_exact_diagonal_preconditioner_converges_in_one_iteration() {
        let p = 8;
        let diag = Array1::from_shape_fn(p, |i| 1.0 + i as f64);
        let b = Array1::from_shape_fn(p, |i| (i as f64) + 0.5);
        let mut w = Array1::<f64>::zeros(p);
        let result = pcg_core(
            |v: &Array1<f64>, out: &mut Array1<f64>| {
                out.assign(&(&diag * v));
            },
            &b.view(),
            &diag.view(),
            1e-12,
            p,
            0,
            true,
            DotReduction::Reordered,
            &mut w.view_mut(),
        );
        assert_eq!(result.stop, PcgStop::Converged);
        assert_eq!(result.iterations, 1);
        let d = result.diagnostics.expect("diagnostics recorded");
        assert_eq!(d.alpha.len(), 1);
        assert!(d.beta.is_empty());
        assert_eq!(d.residuals.len(), 2);
        assert!((d.alpha[0] - 1.0).abs() < 1e-12, "alpha={}", d.alpha[0]);
        for i in 0..p {
            assert!((w[i] - b[i] / diag[i]).abs() < 1e-12, "w[{i}]={}", w[i]);
        }
    }
    #[test]
    fn pcg_core_rejects_bad_preconditioner() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![1.0, 2.0];
        let precond = array![-4.0, 3.0];
        let mut x = Array1::<f64>::zeros(2);
        let result = pcg_core(
            |v: &Array1<f64>, out: &mut Array1<f64>| {
                out.assign(&a.dot(v));
            },
            &b.view(),
            &precond.view(),
            1e-12,
            20,
            32,
            false,
            DotReduction::Serial,
            &mut x.view_mut(),
        );
        assert_eq!(result.stop, PcgStop::BadPreconditioner);
        assert_eq!(x, Array1::<f64>::zeros(2));
    }
    #[test]
    fn pcg_core_rejects_zero_preconditioner_entry() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![1.0, 2.0];
        let precond = array![4.0, 0.0];
        let mut x = Array1::<f64>::zeros(2);
        let result = pcg_core(
            |v: &Array1<f64>, out: &mut Array1<f64>| {
                out.assign(&a.dot(v));
            },
            &b.view(),
            &precond.view(),
            1e-12,
            20,
            32,
            false,
            DotReduction::Serial,
            &mut x.view_mut(),
        );
        assert_eq!(result.stop, PcgStop::BadPreconditioner);
        assert_eq!(x, Array1::<f64>::zeros(2));
    }
    #[test]
    fn pcg_core_relative_residual_holds_for_sub_unit_rhs() {
        let a = array![
            [4.0, 1.0, 0.0, 0.0],
            [1.0, 3.0, 0.25, 0.0],
            [0.0, 0.25, 6.0, 0.5],
            [0.0, 0.0, 0.5, 5.0]
        ];
        let b = array![0.03, -0.02, 0.04, 0.02];
        let precond = array![4.0, 3.0, 6.0, 5.0];
        let rel_tol = 0.1_f64;
        let rhs_norm = (b.iter().map(|x| x * x).sum::<f64>()).sqrt();
        assert!(
            rhs_norm < 1.0,
            "test premise: rhs must be sub-unit; got {rhs_norm}"
        );
        let mut x = Array1::<f64>::zeros(4);
        let result = pcg_core(
            |v: &Array1<f64>, out: &mut Array1<f64>| {
                out.assign(&a.dot(v));
            },
            &b.view(),
            &precond.view(),
            rel_tol,
            64,
            32,
            false,
            DotReduction::Serial,
            &mut x.view_mut(),
        );
        assert_eq!(result.stop, PcgStop::Converged);
        let r: Array1<f64> = &b - &a.dot(&x);
        let r_norm = (r.iter().map(|v| v * v).sum::<f64>()).sqrt();
        assert!(
            r_norm <= rel_tol * rhs_norm + 1e-12,
            "expected ‖r‖={r_norm:.3e} ≤ rel_tol·‖rhs‖={:.3e}",
            rel_tol * rhs_norm
        );
    }
}