#[allow(unused_imports)]
use crate::algebra::blas::{dot_conj, nrm2};
use crate::solver::MonitorCallback;
use crate::solver::common::call_monitors;
use crate::algebra::bridge::BridgeScratch;
#[allow(unused_imports)]
use crate::algebra::prelude::*;
use crate::context::ksp_context::Workspace;
use crate::error::KError;
use crate::matrix::op::LinOp;
use crate::matrix::op_bridge::matvec_s;
use crate::parallel::{Comm, UniverseComm};
use crate::preconditioner::bridge::apply_pc_s;
use crate::preconditioner::{PcSide, Preconditioner};
#[cfg(feature = "complex")]
use crate::reduction::Packet;
use crate::reduction::{CommDeterministic, DotEngine, ReductionOptions, ReproMode};
use crate::solver::LinearSolver;
#[cfg(feature = "complex")]
use crate::solver::common::dot_result_to_real;
use crate::solver::common::{dot1_async_s, nrm2_async_s};
use crate::utils::convergence::{ConvergedReason, Convergence, SolveStats, SolverCounters};
use crate::utils::reduction::{AllreduceHandle, AllreduceOps, ReductOptions};
use smallvec::SmallVec;
use std::any::Any;
/// Selects the residual norm that `PcgSolver` reports to monitors and, in the
/// classic path, feeds into the convergence test.
///
/// With `z = M⁻¹ r`, the classic path uses:
/// * `Preconditioned` — `sqrt(|r·z|)`;
/// * `Unpreconditioned` — `||r||₂`;
/// * `Natural` — `||z||₂`;
/// * `None` — monitors keep receiving the initial `||r||₂`, while the
///   convergence test still uses the current `||r||₂`.
///
/// NOTE(review): PETSc names these the other way round (`KSP_NORM_PRECONDITIONED`
/// is `||M⁻¹ r||`, `KSP_NORM_NATURAL` is `sqrt(r·M⁻¹ r)`); confirm which
/// convention is intended here.
#[derive(Clone, Copy, Debug)]
pub enum CgNormType {
    /// `sqrt(|r·z|)` in the classic path.
    Preconditioned,
    /// `||r||₂`.
    Unpreconditioned,
    /// `||z||₂` in the classic path.
    Natural,
    /// Tests on `||r||₂` but reports only the initial value to monitors.
    None,
}
/// Algorithmic variant executed by `PcgSolver`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PcgVariant {
    /// Textbook PCG: every inner product is reduced synchronously.
    Classic,
    /// PCG that posts the next `r·z` reduction asynchronously.
    ///
    /// A positive `replace_every` swaps the recursive residual for the explicit
    /// `b - A x` every `replace_every` iterations; `0` disables this.
    Pipelined { replace_every: usize },
}
/// Named default for the residual-replacement period of `PcgVariant::Pipelined`.
pub const PCG_PIPELINED_DEFAULT_REPLACE_EVERY: usize = 50;
/// Preconditioned conjugate-gradient solver (classic or pipelined variant).
///
/// Only left preconditioning is accepted. The classic path rejects
/// non-positive curvature `p·Ap`, so `A` is expected to be positive definite.
pub struct PcgSolver {
    /// Tolerances (`rtol`, `atol`, `dtol`) and the iteration cap.
    pub(crate) conv: Convergence,
    /// Residual norm reported to monitors / tested for convergence.
    norm_type: CgNormType,
    /// Options for the blocking dot-product reductions; its `mode` also
    /// overrides the mode of the asynchronous reductions.
    reduction: ReductionOptions,
    /// Optional callback receiving `||r||₂` of the recursively updated residual
    /// at every iteration.
    true_residual_monitor: Option<Box<MonitorCallback<f64>>>,
    /// When `false` and `x` is all zeros, the initial residual is taken as `b`
    /// without an operator application.
    initial_guess_nonzero: bool,
    /// Which algorithm `solve` dispatches to.
    variant: PcgVariant,
    /// Options for the asynchronous reductions of the pipelined path.
    async_reduction: ReductOptions,
}
/// Disjoint mutable views of the per-solve vectors (each of length `n`) plus
/// the bridge scratch, all borrowed from one `Workspace`.
struct ClassicWorkspace<'a> {
    /// Residual, backed by `work.tmp1`.
    r: &'a mut [S],
    /// Preconditioned residual, backed by `work.tmp2`.
    z: &'a mut [S],
    /// Search direction, backed by `work.q_s[0]`.
    p: &'a mut [S],
    /// Operator output (`A p`, or `A x` when forming `b - A x`), backed by `work.q_s[1]`.
    ap: &'a mut [S],
    /// Scratch used by the operator and preconditioner bridges.
    scratch: &'a mut BridgeScratch,
}
impl<'a> ClassicWorkspace<'a> {
    /// Sizes the backing buffers in `work` to exactly `n` and hands out
    /// disjoint mutable views of them.
    ///
    /// `r`/`z` come from `tmp1`/`tmp2`; `p`/`ap` from the first two entries of
    /// `q_s`, which are appended empty if missing. Every buffer whose length is
    /// not `n` is resized (truncated or zero-extended).
    fn acquire(work: &'a mut Workspace, n: usize) -> Self {
        if work.tmp1.len() != n {
            work.tmp1.resize(n, S::zero());
        }
        if work.tmp2.len() != n {
            work.tmp2.resize(n, S::zero());
        }
        while work.q_s.len() < 2 {
            work.q_s.push(Vec::new());
        }
        work.q_s[..2]
            .iter_mut()
            .filter(|buf| buf.len() != n)
            .for_each(|buf| buf.resize(n, S::zero()));
        // One split gives independent borrows of q_s[0] and q_s[1].
        let (head, tail) = work.q_s.split_at_mut(1);
        Self {
            r: &mut work.tmp1[..n],
            z: &mut work.tmp2[..n],
            p: &mut head[0][..n],
            ap: &mut tail[0][..n],
            scratch: &mut work.bridge,
        }
    }
}
impl PcgSolver {
/// Builds a classic PCG solver with relative tolerance `rtol` and at most
/// `maxits` iterations.
///
/// Remaining settings: `atol = 1e-50`, `dtol = 1e5`, `Preconditioned` norm,
/// default (blocking and asynchronous) reduction options, no true-residual
/// monitor, and `initial_guess_nonzero = false`.
pub fn new(rtol: f64, maxits: usize) -> Self {
    let conv = Convergence {
        rtol,
        atol: 1e-50,
        dtol: 1e5,
        max_iters: maxits,
    };
    Self {
        conv,
        norm_type: CgNormType::Preconditioned,
        reduction: ReductionOptions::default(),
        true_residual_monitor: None,
        initial_guess_nonzero: false,
        variant: PcgVariant::Classic,
        async_reduction: ReductOptions::default(),
    }
}
pub fn set_tolerances(&mut self, rtol: f64, atol: f64, dtol: f64, maxits: usize) {
self.conv.rtol = rtol;
self.conv.atol = atol;
self.conv.dtol = dtol;
self.conv.max_iters = maxits;
}
pub fn with_norm(mut self, norm_type: CgNormType) -> Self {
self.norm_type = norm_type;
self
}
pub fn with_reproducible_dot(mut self, f: bool) -> Self {
self.reduction.mode = if f {
ReproMode::Deterministic
} else {
ReproMode::Fast
};
self
}
/// Builder: installs a callback that receives `||r||₂` at every iteration.
pub fn with_true_residual_monitor(mut self, m: Box<MonitorCallback<f64>>) -> Self {
    self.set_true_residual_monitor(Some(m));
    self
}
#[must_use = "with_variant returns an updated solver; assign it before continuing"]
pub fn with_variant(mut self, variant: PcgVariant) -> Self {
self.variant = variant;
self
}
/// In-place counterpart of `with_variant`.
pub fn set_variant(&mut self, variant: PcgVariant) {
    self.variant = variant;
}
pub fn variant(&self) -> PcgVariant {
self.variant
}
/// Sets the options for the pipelined path's asynchronous reductions.
///
/// The `mode` field is overridden at solve time by the reproducible-dot
/// setting (see `async_options`).
pub fn set_async_reduction_options(&mut self, opt: ReductOptions) {
    self.async_reduction = opt;
}
fn async_options(&self) -> ReductOptions {
let mut opt = self.async_reduction.clone();
opt.mode = self.reduction.mode;
opt
}
pub fn with_nonzero_guess(mut self, f: bool) -> Self {
self.initial_guess_nonzero = f;
self
}
/// In-place counterpart of `with_nonzero_guess`.
pub fn set_nonzero_guess(&mut self, f: bool) {
    self.initial_guess_nonzero = f;
}
/// `true` selects `ReproMode::Deterministic` reductions, `false` selects
/// `ReproMode::Fast`.
pub fn set_reproducible_dot(&mut self, f: bool) {
    let mode = match f {
        true => ReproMode::Deterministic,
        false => ReproMode::Fast,
    };
    self.reduction.mode = mode;
}
/// Installs (`Some`) or removes (`None`) the per-iteration `||r||₂` callback.
pub fn set_true_residual_monitor(&mut self, m: Option<Box<MonitorCallback<f64>>>) {
    self.true_residual_monitor = m;
}
/// Global real dot product `u·v` using the solver's reduction options.
#[inline]
fn dot<C: Comm + CommDeterministic>(&self, u: &[f64], v: &[f64], comm: &C) -> f64 {
    DotEngine {
        opts: self.reduction,
    }
    .dot(u, v, comm)
}
/// Global inner product of two `S` vectors, returned as a real number.
///
/// Without the `complex` feature this is the real dot product. With it, the
/// local conjugated dot is reduced either as a scalar (fast mode) or as a
/// `[re, im]` packet (deterministic modes); the real part is returned.
#[inline]
fn dot_scalar<C: Comm + CommDeterministic>(&self, u: &[S], v: &[S], comm: &C) -> R {
    #[cfg(not(feature = "complex"))]
    {
        // SAFETY: relies on `S` having the same layout as `f64` when the
        // `complex` feature is off (assumed from this module's usage — TODO
        // confirm in algebra::prelude); lengths are preserved by the cast.
        let (u_re, v_re): (&[f64], &[f64]) = unsafe {
            (
                &*(u as *const [S] as *const [f64]),
                &*(v as *const [S] as *const [f64]),
            )
        };
        self.dot(u_re, v_re, comm)
    }
    #[cfg(feature = "complex")]
    {
        let local = crate::algebra::blas::dot_conj(u, v);
        if matches!(self.reduction.mode, ReproMode::Fast) {
            dot_result_to_real(comm.allreduce_sum_scalar(local))
        } else {
            let packet = Packet::<2> {
                v: [local.real(), local.imag()],
            };
            comm.allreduce_det(&packet, self.reduction.mode).v[0]
        }
    }
}
/// Computes several global inner products, writing `pairs[i].0 · pairs[i].1`
/// into `out[i]`.
///
/// Without `complex`, all products go through one batched `DotEngine` call.
/// With `complex`, each pair is reduced individually exactly as in `dot_scalar`.
///
/// # Panics
/// If `pairs.len() != out.len()`.
#[inline]
fn dot_scalar_many<C: Comm + CommDeterministic>(
    &self,
    pairs: &[(&[S], &[S])],
    comm: &C,
    out: &mut [R],
) {
    assert_eq!(pairs.len(), out.len());
    if pairs.is_empty() {
        return;
    }
    #[cfg(not(feature = "complex"))]
    {
        // SAFETY: relies on `S` having the same layout as `f64` when the
        // `complex` feature is off (assumed — TODO confirm in algebra::prelude).
        fn as_real(s: &[S]) -> &[f64] {
            unsafe { &*(s as *const [S] as *const [f64]) }
        }
        let real_pairs: SmallVec<[(&[f64], &[f64]); 8]> = pairs
            .iter()
            .map(|&(u, v)| (as_real(u), as_real(v)))
            .collect();
        DotEngine {
            opts: self.reduction,
        }
        .dot_many_into(real_pairs.as_slice(), out, comm);
    }
    #[cfg(feature = "complex")]
    {
        for (slot, &(u, v)) in out.iter_mut().zip(pairs.iter()) {
            *slot = self.dot_scalar(u, v, comm);
        }
    }
}
/// Global 2-norm `sqrt(|u·u|)`.
#[inline]
fn nrm2_scalar<C: Comm + CommDeterministic>(&self, u: &[S], comm: &C) -> R {
    self.dot_scalar(u, u, comm).abs().sqrt()
}
/// Returns the cached norm of `vec`, computing (one global reduction) and
/// caching it first if `cache` is empty.
#[inline]
fn ensure_norm<C: Comm + CommDeterministic>(
    &self,
    vec: &[S],
    comm: &C,
    cache: &mut Option<R>,
) -> R {
    *cache.get_or_insert_with(|| self.nrm2_scalar(vec, comm))
}
/// Classic preconditioned CG on the scalar type `S` (left preconditioning).
///
/// Buffers come from `work` via `ClassicWorkspace::acquire`. When
/// `initial_guess_nonzero` is unset and every entry of `x` is zero, the initial
/// residual is `b`; otherwise it is formed explicitly as `b - A x`.
///
/// The quantity given to `monitors` and to the convergence test follows
/// `self.norm_type` (with `z = M⁻¹ r`):
/// * `Preconditioned`: `sqrt(|r·z|)`;
/// * `Unpreconditioned`: `||r||₂`;
/// * `Natural`: `||z||₂`;
/// * `None`: monitors keep seeing the initial `||r||₂`; the test uses the
///   current `||r||₂`.
///
/// `r` is updated recursively (`r -= α·Ap`). When the convergence test or the
/// iteration cap ends the solve, `final_residual` is `||r||₂` of that recursive
/// residual.
///
/// An exactly zero initial residual (e.g. `b = 0` with a zero guess) returns
/// immediately with zero iterations and `ConvergedReason::ConvergedRtol`,
/// matching the pipelined path.
///
/// # Errors
/// * `KError::IndefinitePreconditioner` if `r·z` is negative or non-finite, or
///   is initially zero while `r` is non-zero.
/// * `KError::IndefiniteMatrix` if `p·Ap` is non-positive or non-finite.
/// * Errors from applying the preconditioner are propagated.
#[allow(clippy::too_many_arguments)]
fn solve_classic_scalar<C: Comm + CommDeterministic>(
    &mut self,
    a: &dyn LinOp<S = f64>,
    pc: Option<&dyn Preconditioner>,
    b: &[S],
    x: &mut [S],
    comm: &C,
    monitors: &[Box<MonitorCallback<f64>>],
    work: &mut Workspace,
) -> Result<SolveStats<f64>, KError> {
    let n = b.len();
    let mut buffers = ClassicWorkspace::acquire(work, n);
    let ClassicWorkspace {
        r,
        z,
        p,
        ap,
        scratch,
    } = &mut buffers;
    let (r, z, p, ap, scratch) = (&mut **r, &mut **z, &mut **p, &mut **ap, &mut **scratch);

    // Extra reductions needed besides r·z. They depend only on the solver
    // configuration, so they are decided once for the whole solve.
    let need_r_norm_res = matches!(
        self.norm_type,
        CgNormType::Unpreconditioned | CgNormType::None
    );
    let need_z_norm_res = matches!(self.norm_type, CgNormType::Natural);
    let need_r_norm_monitor = self.true_residual_monitor.is_some();

    let zero_guess =
        !self.initial_guess_nonzero && x.iter().all(|&xi| xi.abs() <= R::default());
    if zero_guess {
        r.copy_from_slice(b);
    } else {
        matvec_s(a, x, &mut ap[..], scratch);
        for i in 0..n {
            r[i] = b[i] - ap[i];
        }
    }
    if let Some(pc) = pc {
        apply_pc_s(pc, PcSide::Left, r, &mut z[..], scratch)?;
    } else {
        z.copy_from_slice(r);
    }

    // Fuse r·z with whichever norms are needed into one batched reduction.
    let mut initial_pairs: SmallVec<[(&[S], &[S]); 3]> = SmallVec::new();
    initial_pairs.push((&r[..], &z[..]));
    if need_r_norm_res || need_r_norm_monitor {
        initial_pairs.push((&r[..], &r[..]));
    }
    if need_z_norm_res {
        initial_pairs.push((&z[..], &z[..]));
    }
    let mut reductions: SmallVec<[R; 3]> = SmallVec::new();
    reductions.resize(initial_pairs.len(), R::zero());
    self.dot_scalar_many(initial_pairs.as_slice(), comm, reductions.as_mut_slice());
    let mut idx = 0;
    let mut rho = reductions[idx];
    idx += 1;
    // Negative or non-finite r·z: M is not positive definite. The r·z == 0
    // case is resolved below, once ||r|| is available.
    if !rho.is_finite() || rho < R::default() {
        return Err(KError::IndefinitePreconditioner);
    }
    let mut cached_r_norm = if need_r_norm_res || need_r_norm_monitor {
        let value = reductions[idx];
        idx += 1;
        // Clamp tiny negative round-off before the square root.
        Some(
            if value < R::default() {
                R::default()
            } else {
                value
            }
            .sqrt(),
        )
    } else {
        None
    };
    let cached_z_norm = if need_z_norm_res {
        let value = reductions[idx];
        Some(
            if value < R::default() {
                R::default()
            } else {
                value
            }
            .sqrt(),
        )
    } else {
        None
    };
    let mut rho_prev = rho;
    drop(initial_pairs);
    drop(reductions);
    if rho == R::default() {
        // r·z == 0 means either the residual is exactly zero (x already solves
        // the system, e.g. b = 0 with a zero guess) or M⁻¹ annihilates a
        // non-zero residual, which is a genuine preconditioner breakdown.
        if self.ensure_norm(r, comm, &mut cached_r_norm) == R::default() {
            return Ok(SolveStats::new(
                0,
                R::default(),
                ConvergedReason::ConvergedRtol,
            ));
        }
        return Err(KError::IndefinitePreconditioner);
    }
    let mut res = match self.norm_type {
        CgNormType::Preconditioned => rho.abs().sqrt(),
        CgNormType::Unpreconditioned => cached_r_norm.unwrap(),
        CgNormType::Natural => cached_z_norm.unwrap(),
        CgNormType::None => cached_r_norm.unwrap(),
    };
    let res0 = res;
    if call_monitors(monitors, 0, res, 0) {
        return Ok(SolveStats::new(0, res, ConvergedReason::StoppedByMonitor));
    }
    if let Some(m) = &self.true_residual_monitor {
        let value = self.ensure_norm(r, comm, &mut cached_r_norm);
        m(0, value, 0);
    }
    p.copy_from_slice(z);
    let (reason0, mut stats0) = self.conv.check(res, res0, 0);
    if !matches!(reason0, ConvergedReason::Continued) {
        stats0.final_residual = self.ensure_norm(r, comm, &mut cached_r_norm);
        return Ok(stats0);
    }
    for k in 1..=self.conv.max_iters {
        if k > 1 {
            // p = z + β p with β = (r·z)_k / (r·z)_{k-1}.
            let beta = rho / rho_prev;
            let beta_s = S::from_real(beta);
            for i in 0..n {
                p[i] = z[i] + beta_s * p[i];
            }
        }
        matvec_s(a, p, &mut ap[..], scratch);
        let p_ap = self.dot_scalar(p, ap, comm);
        if !p_ap.is_finite() || p_ap <= R::default() {
            return Err(KError::IndefiniteMatrix);
        }
        let alpha = rho / p_ap;
        let alpha_s = S::from_real(alpha);
        for i in 0..n {
            x[i] += alpha_s * p[i];
            r[i] -= alpha_s * ap[i];
        }
        if let Some(pc) = pc {
            apply_pc_s(pc, PcSide::Left, r, &mut z[..], scratch)?;
        } else {
            z.copy_from_slice(r);
        }
        let mut dot_pairs: SmallVec<[(&[S], &[S]); 3]> = SmallVec::new();
        dot_pairs.push((&r[..], &z[..]));
        if need_r_norm_res || need_r_norm_monitor {
            dot_pairs.push((&r[..], &r[..]));
        }
        if need_z_norm_res {
            dot_pairs.push((&z[..], &z[..]));
        }
        let mut dot_results: SmallVec<[R; 3]> = SmallVec::new();
        dot_results.resize(dot_pairs.len(), R::default());
        self.dot_scalar_many(dot_pairs.as_slice(), comm, dot_results.as_mut_slice());
        let mut idx = 0;
        let mut rho_new = dot_results[idx];
        idx += 1;
        if !rho_new.is_finite() || rho_new < R::default() {
            return Err(KError::IndefinitePreconditioner);
        }
        // Flush denormal-scale r·z to exactly zero.
        if rho_new < 1e-300 {
            rho_new = R::default();
        }
        let mut r_norm = if need_r_norm_res || need_r_norm_monitor {
            let value = dot_results[idx];
            idx += 1;
            Some(
                if value < R::default() {
                    R::default()
                } else {
                    value
                }
                .sqrt(),
            )
        } else {
            None
        };
        let mut z_norm = if need_z_norm_res {
            let value = dot_results[idx];
            Some(
                if value < R::default() {
                    R::default()
                } else {
                    value
                }
                .sqrt(),
            )
        } else {
            None
        };
        drop(dot_pairs);
        drop(dot_results);
        match self.norm_type {
            CgNormType::Preconditioned => {
                res = rho_new.abs().sqrt();
            }
            CgNormType::Unpreconditioned => {
                res = r_norm.unwrap();
            }
            CgNormType::Natural => {
                res = z_norm.unwrap();
            }
            // Monitors keep seeing the initial value for `None`.
            CgNormType::None => {}
        }
        if call_monitors(monitors, k, res, 0) {
            return Ok(SolveStats::new(k, res, ConvergedReason::StoppedByMonitor));
        }
        if let Some(m) = &self.true_residual_monitor {
            let value = self.ensure_norm(r, comm, &mut r_norm);
            m(k, value, 0);
        }
        let res_check = match self.norm_type {
            CgNormType::Preconditioned => rho_new.abs().sqrt(),
            CgNormType::Unpreconditioned => self.ensure_norm(r, comm, &mut r_norm),
            CgNormType::Natural => self.ensure_norm(z, comm, &mut z_norm),
            CgNormType::None => self.ensure_norm(r, comm, &mut r_norm),
        };
        let (reason, mut stats) = self.conv.check(res_check, res0, k);
        if !matches!(reason, ConvergedReason::Continued) {
            stats.final_residual = self.ensure_norm(r, comm, &mut r_norm);
            return Ok(stats);
        }
        rho_prev = rho;
        rho = rho_new;
        cached_r_norm = r_norm;
    }
    let final_res = self.ensure_norm(r, comm, &mut cached_r_norm);
    Ok(SolveStats::new(
        self.conv.max_iters,
        final_res,
        ConvergedReason::DivergedMaxIts,
    ))
}
/// Pipelined-variant PCG on the scalar type `S` (left preconditioning only).
///
/// After each residual/preconditioner update, the next `r·z` reduction is
/// posted asynchronously and only waited on at the top of the following
/// iteration. Monitors and the convergence test use `||r||₂` of the
/// recursively updated residual, relative to the initial `||r||₂`.
///
/// With `PcgVariant::Pipelined { replace_every }` and `replace_every > 0`, the
/// recursive residual is replaced by the explicit `b - A x` every
/// `replace_every` iterations and the search direction is restarted. If the
/// convergence check stops the solve while `||r||₂` is still above
/// tolerance (i.e. divergence), the residual is replaced once and the same
/// iteration is re-checked; a second failure returns the check's result.
///
/// # Errors
/// * `KError::InvalidInput` if `pc_side` is not `Left` or `x`/`b` lengths differ.
/// * `KError::IndefinitePreconditioner` if `r·z` is negative or non-finite.
/// * `KError::IndefiniteMatrix` if `p·Ap` is non-finite or negative beyond
///   the happy-breakdown threshold.
/// * Errors from the preconditioner or the asynchronous reductions.
#[allow(clippy::too_many_arguments)]
fn solve_pipelined_scalar<C: Comm + CommDeterministic + AllreduceOps>(
    &mut self,
    a: &dyn LinOp<S = f64>,
    pc: Option<&dyn Preconditioner>,
    b: &[S],
    x: &mut [S],
    pc_side: PcSide,
    comm: &C,
    monitors: &[Box<MonitorCallback<f64>>],
    work: &mut Workspace,
) -> Result<SolveStats<f64>, KError> {
    if pc_side != PcSide::Left {
        return Err(KError::InvalidInput(
            "Pipelined PCG requires left preconditioning with HPD M; choose PcSide::Left or use MINRES (Hermitian) / GMRES (general) instead".into(),
        ));
    }
    let n = b.len();
    if x.len() != n {
        return Err(KError::InvalidInput("dimension mismatch: x,b".into()));
    }
    let mut counters = SolverCounters::default();
    let mut buffers = ClassicWorkspace::acquire(work, n);
    let ClassicWorkspace {
        r,
        z,
        p,
        ap,
        scratch,
    } = &mut buffers;
    let zero_guess =
        !self.initial_guess_nonzero && x.iter().all(|&xi| xi.abs() <= R::default());
    if zero_guess {
        r.copy_from_slice(b);
    } else {
        matvec_s(a, x, &mut ap[..], scratch);
        for i in 0..n {
            r[i] = b[i] - ap[i];
        }
    }
    if let Some(pc) = pc {
        apply_pc_s(pc, PcSide::Left, r, &mut z[..], scratch)?;
    } else {
        z.copy_from_slice(r);
    }
    p.copy_from_slice(z);
    let mut opt = self.async_options();
    // At least one reduction must be allowed in flight.
    if opt.max_inflight == 0 {
        opt.max_inflight = 1;
    }
    let (h_rho0, _) = dot1_async_s(comm, r, z, &opt)?;
    let rho0 = {
        counters.num_global_reductions += 1;
        <C as AllreduceOps>::wait_pair(h_rho0).0
    };
    if !rho0.is_finite() || rho0 < R::default() {
        return Err(KError::IndefinitePreconditioner);
    }
    let (h_rnorm0, _) = nrm2_async_s(comm, r, &opt);
    let rnorm0_sq = {
        counters.num_global_reductions += 1;
        <C as AllreduceOps>::wait_pair(h_rnorm0).0
    };
    let rnorm0 = rnorm0_sq.sqrt();
    // Exactly zero residual: x already solves the system.
    if rnorm0 == R::default() {
        return Ok(
            SolveStats::new(0, R::default(), ConvergedReason::ConvergedRtol)
                .with_counters(counters),
        );
    }
    // NOTE(review): the value of this match is discarded; for the
    // non-Preconditioned norms it only costs an extra reduction (counted).
    let _ = match self.norm_type {
        CgNormType::Preconditioned => rho0.sqrt(),
        CgNormType::Unpreconditioned | CgNormType::None => rnorm0,
        CgNormType::Natural => {
            let (h_z, _) = nrm2_async_s(comm, z, &opt);
            counters.num_global_reductions += 1;
            <C as AllreduceOps>::wait_pair(h_z).0.sqrt()
        }
    };
    let actual_res0 = self.nrm2_scalar(r, comm);
    counters.num_global_reductions += 1;
    if call_monitors(monitors, 0, actual_res0, counters.num_global_reductions) {
        return Ok(
            SolveStats::new(0, actual_res0, ConvergedReason::StoppedByMonitor)
                .with_counters(counters),
        );
    }
    if let Some(m) = &self.true_residual_monitor {
        m(0, actual_res0, 0);
    }
    let (reason0, mut stats0) = self.conv.check(actual_res0, rnorm0, 0);
    if !matches!(reason0, ConvergedReason::Continued) {
        stats0.final_residual = actual_res0;
        counters.residual_replacements = 0;
        stats0.counters = counters;
        return Ok(stats0);
    }
    let replace_every = match self.variant {
        PcgVariant::Pipelined { replace_every } => replace_every,
        PcgVariant::Classic => 0,
    };
    let mut rho_curr = rho0;
    let mut rho_prev = rho0;
    let mut pending_rho: Option<AllreduceHandle<(R, R)>> = None;
    let mut iterations = 0usize;
    let mut residual_replacements = 0usize;
    let mut force_restart = false;
    // Iteration at which the divergence safeguard last replaced the residual.
    // Re-checking the same iteration more than once would loop forever when
    // the true residual itself has diverged.
    let mut divergence_replaced_at: Option<usize> = None;
    while iterations < self.conv.max_iters {
        if iterations > 0 {
            // Complete the r·z reduction posted at the end of the previous pass.
            let handle = pending_rho
                .take()
                .expect("pipelined PCG pending rho handle");
            rho_curr = {
                counters.num_global_reductions += 1;
                <C as AllreduceOps>::wait_pair(handle).0
            };
            if !rho_curr.is_finite() || rho_curr < R::default() {
                return Err(KError::IndefinitePreconditioner);
            }
            // NOTE(review): value discarded; see the note before the loop.
            let _ = match self.norm_type {
                CgNormType::Preconditioned => rho_curr.sqrt(),
                CgNormType::Unpreconditioned | CgNormType::None => {
                    let (h_nr, _) = nrm2_async_s(comm, r, &opt);
                    counters.num_global_reductions += 1;
                    <C as AllreduceOps>::wait_pair(h_nr).0.sqrt()
                }
                CgNormType::Natural => {
                    let (h_z, _) = nrm2_async_s(comm, z, &opt);
                    counters.num_global_reductions += 1;
                    <C as AllreduceOps>::wait_pair(h_z).0.sqrt()
                }
            };
            let actual_res = self.nrm2_scalar(r, comm);
            counters.num_global_reductions += 1;
            if call_monitors(
                monitors,
                iterations,
                actual_res,
                counters.num_global_reductions,
            ) {
                return Ok(
                    SolveStats::new(
                        iterations,
                        actual_res,
                        ConvergedReason::StoppedByMonitor,
                    )
                    .with_counters(counters),
                );
            }
            if let Some(m) = &self.true_residual_monitor {
                m(iterations, actual_res, 0);
            }
            let (reason, mut stats) = self.conv.check(actual_res, rnorm0, iterations);
            if !matches!(reason, ConvergedReason::Continued) {
                stats.final_residual = actual_res;
                let tol = self.conv.atol.max(self.conv.rtol * rnorm0);
                // Stopped while still above tolerance: replace the recursive
                // residual with b - A x, restart the direction and re-check —
                // at most once per iteration.
                if actual_res > tol && divergence_replaced_at != Some(iterations) {
                    matvec_s(a, x, &mut ap[..], scratch);
                    for i in 0..n {
                        r[i] = b[i] - ap[i];
                    }
                    if let Some(pc) = pc {
                        apply_pc_s(pc, PcSide::Left, r, &mut z[..], scratch)?;
                    } else {
                        z.copy_from_slice(r);
                    }
                    residual_replacements += 1;
                    p.copy_from_slice(z);
                    rho_prev = R::default();
                    rho_curr = self.dot_scalar(r, z, comm);
                    counters.num_global_reductions += 1;
                    if !rho_curr.is_finite() || rho_curr < R::default() {
                        return Err(KError::IndefinitePreconditioner);
                    }
                    let _ = match self.norm_type {
                        CgNormType::Preconditioned => rho_curr.sqrt(),
                        CgNormType::Unpreconditioned | CgNormType::None => {
                            counters.num_global_reductions += 1;
                            self.nrm2_scalar(r, comm)
                        }
                        CgNormType::Natural => {
                            counters.num_global_reductions += 1;
                            self.nrm2_scalar(z, comm)
                        }
                    };
                    let (h_rho_next, _) = dot1_async_s(comm, r, z, &opt)?;
                    pending_rho = Some(h_rho_next);
                    divergence_replaced_at = Some(iterations);
                    counters.residual_replacements = residual_replacements;
                    continue;
                }
                counters.residual_replacements = residual_replacements;
                stats.counters = counters;
                return Ok(stats);
            }
        }
        if iterations > 0 {
            // β = 0 right after a restart (rho_prev reset to zero).
            let beta = if rho_prev == R::default() {
                R::default()
            } else {
                rho_curr / rho_prev
            };
            let beta_s = S::from_real(beta);
            for i in 0..n {
                p[i] = z[i] + beta_s * p[i];
            }
        }
        matvec_s(a, p, &mut ap[..], scratch);
        #[cfg(not(feature = "complex"))]
        let pp_ap_local: R = p
            .iter()
            .zip(ap.iter())
            .fold(R::default(), |acc, (&pi, &api)| acc + pi * api);
        #[cfg(feature = "complex")]
        let pp_ap_local: R = dot_result_to_real(dot_conj(p, ap));
        let (h_ppap, _) = comm.allreduce2_async(pp_ap_local, R::default(), &opt)?;
        let pp_ap = {
            counters.num_global_reductions += 1;
            <C as AllreduceOps>::wait_pair(h_ppap).0
        };
        if !pp_ap.is_finite() {
            return Err(KError::IndefiniteMatrix);
        }
        // NOTE(review): absolute threshold, not scaled by ||p||·||Ap||; badly
        // scaled operators may trip this spuriously.
        if pp_ap.abs() <= f64::EPSILON {
            return Ok(SolveStats::new(
                iterations,
                self.nrm2_scalar(r, comm),
                ConvergedReason::ConvergedHappyBreakdown,
            )
            .with_counters({
                counters.residual_replacements = residual_replacements;
                counters
            }));
        }
        // Negative curvature: A is not positive definite and the CG
        // recurrences are invalid (the classic path rejects this as well).
        if pp_ap < R::default() {
            return Err(KError::IndefiniteMatrix);
        }
        let alpha = rho_curr / pp_ap;
        let alpha_s = S::from_real(alpha);
        for i in 0..n {
            x[i] += alpha_s * p[i];
            r[i] -= alpha_s * ap[i];
        }
        if let Some(pc) = pc {
            apply_pc_s(pc, PcSide::Left, r, &mut z[..], scratch)?;
        } else {
            z.copy_from_slice(r);
        }
        // Periodic residual replacement to limit drift of the recursive residual.
        if replace_every > 0 && ((iterations + 1) % replace_every == 0) {
            matvec_s(a, x, &mut ap[..], scratch);
            for i in 0..n {
                r[i] = b[i] - ap[i];
            }
            if let Some(pc) = pc {
                apply_pc_s(pc, PcSide::Left, r, &mut z[..], scratch)?;
            } else {
                z.copy_from_slice(r);
            }
            residual_replacements += 1;
            p.copy_from_slice(z);
            force_restart = true;
        }
        // Post the next r·z reduction; it is completed at the top of the next pass.
        let (h_rho_next, _) = dot1_async_s(comm, r, z, &opt)?;
        pending_rho = Some(h_rho_next);
        if force_restart {
            rho_prev = R::default();
            force_restart = false;
        } else {
            rho_prev = rho_curr;
        }
        iterations += 1;
    }
    // Final check at the iteration cap, using the reduction still in flight.
    if let Some(handle) = pending_rho.take() {
        counters.num_global_reductions += 1;
        rho_curr = <C as AllreduceOps>::wait_pair(handle).0;
        if !rho_curr.is_finite() || rho_curr < R::default() {
            return Err(KError::IndefinitePreconditioner);
        }
        let _ = match self.norm_type {
            CgNormType::Preconditioned => rho_curr.sqrt(),
            CgNormType::Unpreconditioned | CgNormType::None => {
                let (h_nr, _) = nrm2_async_s(comm, r, &opt);
                counters.num_global_reductions += 1;
                <C as AllreduceOps>::wait_pair(h_nr).0.sqrt()
            }
            CgNormType::Natural => {
                let (h_z, _) = nrm2_async_s(comm, z, &opt);
                counters.num_global_reductions += 1;
                <C as AllreduceOps>::wait_pair(h_z).0.sqrt()
            }
        };
        let actual_res = self.nrm2_scalar(r, comm);
        counters.num_global_reductions += 1;
        if call_monitors(
            monitors,
            iterations,
            actual_res,
            counters.num_global_reductions,
        ) {
            return Ok(
                SolveStats::new(
                    iterations,
                    actual_res,
                    ConvergedReason::StoppedByMonitor,
                )
                .with_counters(counters),
            );
        }
        if let Some(m) = &self.true_residual_monitor {
            m(iterations, actual_res, 0);
        }
        let (reason, mut stats) = self.conv.check(actual_res, rnorm0, iterations);
        if !matches!(reason, ConvergedReason::Continued) {
            stats.final_residual = actual_res;
            let tol = self.conv.atol.max(self.conv.rtol * rnorm0);
            let above_tol = actual_res > tol;
            // Above tolerance at the cap falls through to DivergedMaxIts below.
            if !above_tol {
                counters.residual_replacements = residual_replacements;
                stats.counters = counters;
                return Ok(stats);
            }
        }
    }
    counters.residual_replacements = residual_replacements;
    let final_res = self.nrm2_scalar(r, comm);
    Ok(
        SolveStats::new(iterations, final_res, ConvergedReason::DivergedMaxIts)
            .with_counters(counters),
    )
}
/// Solves `A x = b` over an arbitrary communicator type.
///
/// Dispatches on `self.variant` exactly like `LinearSolver::solve`, but is
/// not tied to `UniverseComm`.
#[allow(clippy::too_many_arguments)]
pub fn solve_with_comm<C: Comm + CommDeterministic + AllreduceOps>(
    &mut self,
    a: &dyn LinOp<S = f64>,
    pc: Option<&mut dyn Preconditioner>,
    b: &[f64],
    x: &mut [f64],
    pc_side: PcSide,
    comm: &C,
    monitors: Option<&[Box<MonitorCallback<f64>>]>,
    work: Option<&mut Workspace>,
) -> Result<SolveStats<f64>, KError> {
    self.solve_impl(a, pc, b, x, pc_side, comm, monitors, work)
}
/// Routes the solve to the classic or pipelined implementation.
#[allow(clippy::too_many_arguments)]
fn solve_impl<C: Comm + CommDeterministic + AllreduceOps>(
    &mut self,
    a: &dyn LinOp<S = f64>,
    pc: Option<&mut dyn Preconditioner>,
    b: &[f64],
    x: &mut [f64],
    pc_side: PcSide,
    comm: &C,
    monitors: Option<&[Box<MonitorCallback<f64>>]>,
    work: Option<&mut Workspace>,
) -> Result<SolveStats<f64>, KError> {
    match self.variant {
        PcgVariant::Pipelined { .. } => {
            self.solve_pipelined(a, pc, b, x, pc_side, comm, monitors, work)
        }
        PcgVariant::Classic => {
            self.solve_classic(a, pc, b, x, pc_side, comm, monitors, work)
        }
    }
}
/// Real-valued entry point for the classic PCG path.
///
/// Validates the preconditioning side and dimensions, falls back to a freshly
/// allocated `Workspace` when none is given, views (or, with the `complex`
/// feature, copies) `b`/`x` as `S`, and delegates to `solve_classic_scalar`.
/// With `complex`, the real part of the solution is written back into `x`
/// on success.
///
/// # Errors
/// `KError::InvalidInput` for a non-left `pc_side` or mismatched lengths;
/// otherwise whatever `solve_classic_scalar` returns.
#[allow(clippy::too_many_arguments)]
fn solve_classic<C: Comm + CommDeterministic>(
    &mut self,
    a: &dyn LinOp<S = f64>,
    pc: Option<&mut dyn Preconditioner>,
    b: &[f64],
    x: &mut [f64],
    pc_side: PcSide,
    comm: &C,
    monitors: Option<&[Box<MonitorCallback<f64>>]>,
    work: Option<&mut Workspace>,
) -> Result<SolveStats<f64>, KError> {
    let pc_view = pc.as_deref();
    if pc_side != PcSide::Left {
        return Err(KError::InvalidInput(
            "CG/PCG requires left preconditioning with HPD M; choose PcSide::Left or use MINRES (Hermitian) / GMRES (general) instead".into(),
        ));
    }
    let len = b.len();
    if x.len() != len {
        return Err(KError::InvalidInput("dimension mismatch: x,b".into()));
    }
    let monitor_list: &[Box<MonitorCallback<f64>>] = monitors.unwrap_or_default();
    let mut local_workspace;
    let workspace: &mut Workspace = if let Some(ws) = work {
        ws
    } else {
        local_workspace = Workspace::new(len);
        &mut local_workspace
    };
    #[cfg(feature = "complex")]
    let b_owned: Vec<S> = b.iter().map(|&v| S::from_real(v)).collect();
    #[cfg(feature = "complex")]
    let mut x_owned: Vec<S> = x.iter().map(|&v| S::from_real(v)).collect();
    #[cfg(feature = "complex")]
    let b_slice: &[S] = &b_owned;
    #[cfg(feature = "complex")]
    let x_slice: &mut [S] = &mut x_owned;
    // SAFETY: relies on `S` having the same layout as `f64` when the
    // `complex` feature is off (assumed — TODO confirm in algebra::prelude).
    #[cfg(not(feature = "complex"))]
    let b_slice: &[S] = unsafe { &*(b as *const [f64] as *const [S]) };
    // SAFETY: as above; `x` is not used again while `x_slice` is live.
    #[cfg(not(feature = "complex"))]
    let x_slice: &mut [S] = unsafe { &mut *(x as *mut [f64] as *mut [S]) };
    let stats = self.solve_classic_scalar(
        a,
        pc_view,
        b_slice,
        x_slice,
        comm,
        monitor_list,
        workspace,
    )?;
    #[cfg(feature = "complex")]
    {
        x.iter_mut()
            .zip(x_slice.iter())
            .for_each(|(dst, src)| *dst = src.real());
    }
    Ok(stats)
}
/// Real-valued entry point for the pipelined PCG path.
///
/// Validates the preconditioning side and dimensions, falls back to a freshly
/// allocated `Workspace` when none is given, views (or, with the `complex`
/// feature, copies) `b`/`x` as `S`, and delegates to `solve_pipelined_scalar`.
/// With `complex`, the real part of the solution is written back into `x`
/// on success.
///
/// # Errors
/// `KError::InvalidInput` for a non-left `pc_side` or mismatched lengths;
/// otherwise whatever `solve_pipelined_scalar` returns.
#[allow(clippy::too_many_arguments)]
fn solve_pipelined<C: Comm + CommDeterministic + AllreduceOps>(
    &mut self,
    a: &dyn LinOp<S = f64>,
    pc: Option<&mut dyn Preconditioner>,
    b: &[f64],
    x: &mut [f64],
    pc_side: PcSide,
    comm: &C,
    monitors: Option<&[Box<MonitorCallback<f64>>]>,
    work: Option<&mut Workspace>,
) -> Result<SolveStats<f64>, KError> {
    let pc_view = pc.as_deref();
    if pc_side != PcSide::Left {
        return Err(KError::InvalidInput(
            "Pipelined PCG requires left preconditioning with HPD M; choose PcSide::Left or use MINRES (Hermitian) / GMRES (general) instead".into(),
        ));
    }
    let len = b.len();
    if x.len() != len {
        return Err(KError::InvalidInput("dimension mismatch: x,b".into()));
    }
    let monitor_list: &[Box<MonitorCallback<f64>>] = monitors.unwrap_or_default();
    let mut local_workspace;
    let workspace: &mut Workspace = if let Some(ws) = work {
        ws
    } else {
        local_workspace = Workspace::new(len);
        &mut local_workspace
    };
    #[cfg(feature = "complex")]
    let b_owned: Vec<S> = b.iter().map(|&v| S::from_real(v)).collect();
    #[cfg(feature = "complex")]
    let mut x_owned: Vec<S> = x.iter().map(|&v| S::from_real(v)).collect();
    #[cfg(feature = "complex")]
    let b_slice: &[S] = &b_owned;
    #[cfg(feature = "complex")]
    let x_slice: &mut [S] = &mut x_owned;
    // SAFETY: relies on `S` having the same layout as `f64` when the
    // `complex` feature is off (assumed — TODO confirm in algebra::prelude).
    #[cfg(not(feature = "complex"))]
    let b_slice: &[S] = unsafe { &*(b as *const [f64] as *const [S]) };
    // SAFETY: as above; `x` is not used again while `x_slice` is live.
    #[cfg(not(feature = "complex"))]
    let x_slice: &mut [S] = unsafe { &mut *(x as *mut [f64] as *mut [S]) };
    let stats = self.solve_pipelined_scalar(
        a,
        pc_view,
        b_slice,
        x_slice,
        pc_side,
        comm,
        monitor_list,
        workspace,
    )?;
    #[cfg(feature = "complex")]
    {
        x.iter_mut()
            .zip(x_slice.iter())
            .for_each(|(dst, src)| *dst = src.real());
    }
    Ok(stats)
}
}
impl LinearSolver for PcgSolver {
    type Error = KError;

    /// Exposes the solver as `Any` so callers can downcast to `PcgSolver`.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Guarantees that `work.q` holds at least two vectors.
    ///
    /// NOTE(review): the solve paths in this module draw their buffers from
    /// `tmp1`, `tmp2` and `q_s` (sized on demand at solve time), not from `q`.
    fn setup_workspace(&mut self, work: &mut Workspace) {
        let required = 2;
        if work.q.len() < required {
            work.q.resize(required, Vec::new());
        }
    }

    /// Solves `A x = b` on the universe communicator, dispatching on the
    /// configured variant.
    fn solve(
        &mut self,
        a: &dyn LinOp<S = f64>,
        pc: Option<&mut dyn Preconditioner>,
        b: &[f64],
        x: &mut [f64],
        pc_side: PcSide,
        comm: &UniverseComm,
        monitors: Option<&[Box<MonitorCallback<f64>>]>,
        work: Option<&mut Workspace>,
    ) -> Result<SolveStats<f64>, Self::Error> {
        self.solve_impl(a, pc, b, x, pc_side, comm, monitors, work)
    }
}