use crate::algebra::bridge::BridgeScratch;
#[allow(unused_imports)]
use crate::algebra::prelude::*;
use crate::core::block::BlockVec;
use crate::parallel::ReductionEngine;
use crate::reduction::ReproMode;
use crate::solver::common::givens::{apply_new_givens_and_update_g, apply_prev_givens_to_col};
use crate::solver::gmres::AugmentationPolicy;
use std::sync::Arc;
/// Reusable scratch storage for GMRES-family solvers.
///
/// Krylov data lives in flat, column-contiguous buffers (`v_mem`, `z_mem`, `h_mem`)
/// sized by `Workspace::acquire_gmres`; the remaining fields are optional or auxiliary
/// buffers used by the block, s-step, pipelined and communication code paths.
#[derive(Debug, Clone, Default)]
pub struct Workspace {
    /// Scratch vector of length `n` (sized by `new` / `acquire_gmres`).
    pub tmp1: Vec<S>,
    /// Second scratch vector of length `n`.
    pub tmp2: Vec<S>,
    // NOTE(review): `q_s`, `z_s`, `h_s`, `q`, `z`, `h` are nested-vector buffers that no
    // method in this file reads or writes; confirm whether external callers still use them.
    pub q_s: Vec<Vec<S>>,
    pub z_s: Vec<Vec<S>>,
    pub h_s: Vec<Vec<S>>,
    pub q: Vec<Vec<S>>,
    pub z: Vec<Vec<S>>,
    pub h: Vec<Vec<S>>,
    /// Basis `V`: `m + 1` columns of length `n`, column `j` at `[j * n..(j + 1) * n]`.
    pub v_mem: Vec<S>,
    /// Basis `Z`: `m` columns of length `n` when `need_z`, otherwise empty.
    pub z_mem: Vec<S>,
    /// Hessenberg matrix, `(m + 1) × m`, column-major with leading dimension `m + 1`.
    pub h_mem: Vec<S>,
    /// Staging copy of one Hessenberg column while Givens rotations are applied.
    pub givens_col_scratch: Vec<S>,
    /// Givens cosines (length `m`).
    pub cs: Vec<R>,
    /// Givens sines (length `m`).
    pub sn: Vec<S>,
    /// Vector updated together with the Givens rotations (length `m + 1`).
    pub g: Vec<S>,
    /// Block scratch of `n * block_s` scalars when block mode is enabled.
    pub blk_scratch: Vec<S>,
    /// Real-valued reduction payload for block mode; only its capacity is reserved here.
    pub blk_payload: Vec<R>,
    // NOTE(review): `bridge` / `bridge_tmp` are not touched by this file's methods.
    pub bridge: BridgeScratch,
    pub bridge_tmp: Vec<S>,
    /// Optional `n × p` block buffer managed by `ensure_block`.
    pub block_buf: Option<BlockVec>,
    /// Optional TSQR workspace managed by `ensure_tsqr`.
    pub tsqr: Option<TsqrWorkspace>,
    /// Input vector `w` of a pipelined Arnoldi step (length `n`).
    pub pipelined_w: Vec<S>,
    /// Working copy of `w` that is orthogonalized in place by the pipelined step.
    pub pipelined_wtmp: Vec<S>,
    /// Reduction payload buffer, recycled between pipelined steps.
    pub pipelined_payload: Vec<R>,
    /// Optional s-step workspace managed by `ensure_sstep`.
    pub gmres_sstep: Option<GmresSStepWorkspace>,
    /// Storage for recycled / augmentation vectors.
    pub gmres_recycle: RecyclingSpace,
    /// Reduction options; changing them resets `reduction_engine`.
    pub reduction: crate::utils::reduction::ReductOptions,
    /// Cached reduction engine, cleared whenever options or mode change.
    pub reduction_engine: Option<Arc<dyn ReductionEngine>>,
    /// Send byte arena (sized by `ensure_comm_bytes`).
    pub send_arena: crate::utils::buffer_pool::BufferPool<u8>,
    /// Receive byte arena (sized by `ensure_comm_bytes`).
    pub recv_arena: crate::utils::buffer_pool::BufferPool<u8>,
    // NOTE(review): `packet_arena` is not sized by any method in this file.
    pub packet_arena: crate::utils::buffer_pool::BufferPool<u8>,
    // Vector length, number of Arnoldi steps, and whether `z_mem` is in use;
    // set by `acquire_gmres` (`n` also by `new`).
    n: usize,
    m: usize,
    need_z: bool,
}
/// Fixed-capacity store of up to `rmax` vector pairs, each of length `n`.
///
/// Pair `j` occupies `[j * n..(j + 1) * n]` in both `u` and `au` (presumably `u_j` and
/// `A·u_j`, judging by the names — confirm against callers).
#[derive(Debug, Clone)]
pub struct RecyclingSpace {
    // Flat column-contiguous storage, `n * rmax` scalars each.
    u: Vec<S>,
    au: Vec<S>,
    // Column length and maximum number of stored pairs.
    n: usize,
    rmax: usize,
    // Number of pairs currently stored (`<= rmax`).
    cols: usize,
    // Augmentation policy recorded by `configure`.
    policy: AugmentationPolicy,
}
/// Outcome of the global reduction started by a pipelined Arnoldi step.
#[derive(Debug)]
pub enum PipeReduct {
    /// The reduction had already completed, so the step was finished immediately;
    /// `reductions` is the number of global reductions it used.
    Sync {
        reductions: usize,
    },
    /// The reduction is still in flight; `handle` must be waited on and the step
    /// finished (see `Workspace::finish_pipe_reduction`).
    Async {
        handle: crate::parallel::ReduceHandle<Vec<R>>,
    },
}
impl Default for RecyclingSpace {
    /// An empty, unconfigured space: no storage, zero capacity, no stored columns,
    /// and `AugmentationPolicy::None`.
    fn default() -> Self {
        let (n, rmax, cols) = (0, 0, 0);
        Self {
            policy: AugmentationPolicy::None,
            cols,
            rmax,
            n,
            au: Vec::new(),
            u: Vec::new(),
        }
    }
}
impl RecyclingSpace {
    /// Configures storage for up to `rmax` column pairs of length `n` and records `policy`.
    ///
    /// The buffers are resized, and the stored-column count reset to zero, only when `n`
    /// or `rmax` differs from the current configuration; the policy is always replaced.
    pub fn configure(&mut self, n: usize, rmax: usize, policy: AugmentationPolicy) {
        let reshaped = self.n != n || self.rmax != rmax;
        if reshaped {
            let total = n.saturating_mul(rmax);
            self.u.resize(total, S::zero());
            self.au.resize(total, S::zero());
            self.n = n;
            self.rmax = rmax;
            self.cols = 0;
        }
        self.policy = policy;
    }

    /// Returns a copy of the current augmentation policy.
    #[inline]
    pub fn policy(&self) -> AugmentationPolicy {
        self.policy.clone()
    }

    /// Maximum number of column pairs that can be stored (`rmax`).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.rmax
    }

    /// Number of column pairs currently stored.
    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Forgets all stored pairs; storage is neither released nor zeroed.
    pub fn clear(&mut self) {
        self.cols = 0;
    }

    /// Index range of column `j` inside the flat `u` / `au` buffers.
    fn span(&self, j: usize) -> std::ops::Range<usize> {
        j * self.n..(j + 1) * self.n
    }

    /// Read-only view of stored column `j` of `u`.
    ///
    /// `j` must be below `capacity()` (slicing panics otherwise when `n > 0`); columns at
    /// or beyond `cols()` may hold stale data.
    pub fn col(&self, j: usize) -> &[S] {
        &self.u[self.span(j)]
    }

    /// Mutable view of column `j` of `u`; same indexing rules as `col`.
    pub fn col_mut(&mut self, j: usize) -> &mut [S] {
        let span = self.span(j);
        &mut self.u[span]
    }

    /// Read-only view of column `j` of `au`; same indexing rules as `col`.
    pub fn a_col(&self, j: usize) -> &[S] {
        &self.au[self.span(j)]
    }

    /// Mutable view of column `j` of `au`; same indexing rules as `col`.
    pub fn a_col_mut(&mut self, j: usize) -> &mut [S] {
        let span = self.span(j);
        &mut self.au[span]
    }

    /// Appends the pair (`u`, `au`) as the next stored column.
    ///
    /// Does nothing once `capacity()` pairs are stored. Panics if either slice's length
    /// differs from the configured `n`.
    pub fn push_from(&mut self, u: &[S], au: &[S]) {
        if self.cols < self.rmax {
            let span = self.span(self.cols);
            self.u[span.clone()].copy_from_slice(u);
            self.au[span].copy_from_slice(au);
            self.cols += 1;
        }
    }
}
/// When the pipelined Arnoldi step runs a second Gram–Schmidt (re-orthogonalization)
/// pass, at the cost of one extra global reduction.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ReorthPolicy {
    /// Never run the second pass.
    Never,
    /// Run it only when the estimated squared norm of the new vector falls below
    /// `tol² ·` the squared norm of the vector before projection (the default).
    #[default]
    IfNeeded,
    /// Always run the second pass.
    Always,
}
/// Sizing parameters consumed by `Workspace::acquire_gmres`.
#[derive(Debug, Clone, Copy)]
pub struct GmresSpec {
    /// Length of every basis / scratch vector.
    pub n: usize,
    /// Number of Arnoldi steps the workspace is sized for (`v_mem` gets `m + 1` columns,
    /// the Hessenberg matrix is `(m + 1) × m`).
    pub m: usize,
    /// Whether to allocate `z_mem` (`m` extra columns of length `n`).
    pub need_z: bool,
    /// Block / s-step width; `0` releases the block and s-step buffers.
    pub block_s: usize,
}
/// Buffers for the s-step GMRES variant, sized for `n` rows, block width `s` and
/// `m` history blocks (see `GmresSStepWorkspace::ensure`).
#[derive(Debug, Clone)]
pub struct GmresSStepWorkspace {
    /// `n × s` block.
    pub w: BlockVec,
    /// `n × s` block.
    pub q: BlockVec,
    /// `n × s` block (presumably `A·q`, judging by the name — confirm).
    pub aq: BlockVec,
    /// `s × s` scalars.
    pub gram: Vec<S>,
    /// `m · s` reals.
    pub c_prev: Vec<R>,
    /// `s(s + 1)/2 + m · s` scalars (presumably a packed triangle plus projection
    /// coefficients — TODO confirm).
    pub payload: Vec<S>,
    /// `s × s` reals.
    pub r: Vec<R>,
}
impl GmresSStepWorkspace {
    /// Allocates a workspace for `n` rows, block width `s` and `m` history blocks.
    pub fn new(n: usize, s: usize, m: usize) -> Self {
        let square = s.saturating_mul(s);
        let history = m.saturating_mul(s);
        let payload_len = s.saturating_mul(s + 1) / 2 + history;
        let mut ws = Self {
            w: BlockVec::new(n, s),
            q: BlockVec::new(n, s),
            aq: BlockVec::new(n, s),
            gram: vec![S::zero(); square],
            c_prev: vec![R::default(); history],
            payload: vec![S::zero(); payload_len],
            r: vec![R::default(); square],
        };
        ws.ensure(n, s, m);
        ws
    }

    /// Resizes every buffer to the exact shape implied by (`n`, `s`, `m`).
    ///
    /// Flat vectors are set to their exact length (growing with defaults or truncating).
    pub fn ensure(&mut self, n: usize, s: usize, m: usize) {
        for block in [&mut self.w, &mut self.q, &mut self.aq] {
            block.resize(n, s);
        }
        let square = s.saturating_mul(s);
        let history = m.saturating_mul(s);
        ensure_len(&mut self.gram, square);
        ensure_len(&mut self.c_prev, history);
        ensure_len(&mut self.payload, s.saturating_mul(s + 1) / 2 + history);
        ensure_len(&mut self.r, square);
    }
}
/// Scratch buffers for TSQR, sized for panels up to `w_max` columns wide.
#[derive(Debug, Clone)]
pub struct TsqrWorkspace {
    /// `w_max` scalars.
    pub taus: Vec<S>,
    /// `w_max × w_max` scalars.
    pub rmat: Vec<S>,
    /// Width these buffers were allocated for.
    pub w_max: usize,
}
impl TsqrWorkspace {
    /// Builds zero-filled buffers for width `w_max`: `taus` has `w_max` entries and
    /// `rmat` has `w_max²` entries (saturating on overflow).
    pub fn with_width(w_max: usize) -> Self {
        let rmat_len = w_max.saturating_mul(w_max);
        Self {
            w_max,
            rmat: vec![S::zero(); rmat_len],
            taus: vec![S::zero(); w_max],
        }
    }
}
impl Workspace {
    /// Creates a workspace whose `tmp1`/`tmp2` scratch vectors have length `n`.
    ///
    /// Every other buffer starts empty (via `Default`); `acquire_gmres` sizes the
    /// Krylov storage.
    pub fn new(n: usize) -> Self {
        let mut ws = Self::default();
        ws.tmp1.resize(n, S::zero());
        ws.tmp2.resize(n, S::zero());
        ws.n = n;
        ws
    }

    /// Forwards byte-size requirements to the send and receive arenas' `ensure_len`.
    pub fn ensure_comm_bytes(&mut self, max_send: usize, max_recv: usize) {
        self.send_arena.ensure_len(max_send);
        self.recv_arena.ensure_len(max_recv);
    }

    /// Ensures `block_buf` holds a `BlockVec` with exactly `n` rows and at least `p`
    /// columns, allocating a fresh `n × p` block otherwise. `p == 0` drops the buffer.
    pub fn ensure_block(&mut self, n: usize, p: usize) {
        if p == 0 {
            self.block_buf = None;
            return;
        }
        let replace = match self.block_buf {
            Some(ref buf) if buf.nrows() == n && buf.ncols() >= p => false,
            _ => true,
        };
        if replace {
            self.block_buf = Some(BlockVec::new(n, p));
        }
    }

    /// Ensures a TSQR workspace of width at least `w_max` exists (reallocating only when
    /// the current one is narrower). `w_max == 0` drops it.
    pub fn ensure_tsqr(&mut self, w_max: usize) {
        if w_max == 0 {
            self.tsqr = None;
            return;
        }
        let replace = match self.tsqr {
            Some(ref tsqr) if tsqr.w_max >= w_max => false,
            _ => true,
        };
        if replace {
            self.tsqr = Some(TsqrWorkspace::with_width(w_max));
        }
    }

    /// Ensures the s-step workspace fits `n` rows, block width `s` and `m` history blocks.
    ///
    /// Reallocates when the row count differs or the existing buffers are too small;
    /// otherwise resizes in place. `s == 0` drops the workspace.
    pub fn ensure_sstep(&mut self, n: usize, s: usize, m: usize) {
        if s == 0 {
            self.gmres_sstep = None;
            return;
        }
        let need_new = match self.gmres_sstep {
            Some(ref buf) => {
                buf.w.nrows() != n || buf.w.ncols() < s || buf.c_prev.len() < m.saturating_mul(s)
            }
            None => true,
        };
        if need_new {
            self.gmres_sstep = Some(GmresSStepWorkspace::new(n, s, m));
        } else if let Some(ref mut buf) = self.gmres_sstep {
            buf.ensure(n, s, m);
        }
    }

    /// Mutable access to the s-step workspace, if one has been allocated.
    #[inline]
    pub fn sstep_mut(&mut self) -> Option<&mut GmresSStepWorkspace> {
        self.gmres_sstep.as_mut()
    }

    /// Vector length the workspace is currently sized for.
    #[inline]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Number of Arnoldi steps the workspace is currently sized for.
    #[inline]
    pub fn m(&self) -> usize {
        self.m
    }

    /// Whether `z_mem` was requested by the last `acquire_gmres`.
    #[inline]
    pub fn has_z(&self) -> bool {
        self.need_z
    }

    /// Leading dimension of the column-major Hessenberg storage (`m + 1`).
    #[inline]
    pub fn ld_h(&self) -> usize {
        self.m + 1
    }

    /// Sizes all GMRES buffers for vectors of length `spec.n` and `spec.m` Arnoldi steps.
    ///
    /// Afterwards `v_mem` holds `m + 1` columns of length `n`; `z_mem` holds `m` columns
    /// when `spec.need_z` and is emptied (capacity kept) otherwise; `h_mem` holds the
    /// `(m + 1) × m` Hessenberg matrix; `cs`/`sn` have length `m` and `g` length `m + 1`;
    /// the pipelined buffers fit one step's reduction payload. When `spec.block_s > 0`
    /// the block scratch/payload and s-step workspace are sized for that width,
    /// otherwise they are emptied/dropped.
    ///
    /// Lengths are set exactly (`ensure_len` also truncates) but capacity never shrinks,
    /// and existing contents are not zeroed.
    ///
    /// # Panics
    /// Panics if any required buffer length overflows `usize`.
    pub fn acquire_gmres(&mut self, spec: GmresSpec) {
        self.n = spec.n;
        self.m = spec.m;
        self.need_z = spec.need_z;
        let n = spec.n;
        let m = spec.m;
        // All size arithmetic is checked so an absurd spec panics with a clear message
        // instead of wrapping in release builds.
        let m1 = m.checked_add(1).expect("m + 1 overflow");
        let v_len = m1.checked_mul(n).expect("v_len overflow");
        let h_len = m1.checked_mul(m).expect("h_len overflow");
        ensure_len(&mut self.tmp1, n);
        ensure_len(&mut self.tmp2, n);
        ensure_len(&mut self.v_mem, v_len);
        if spec.need_z {
            let z_len = m.checked_mul(n).expect("z_len overflow");
            ensure_len(&mut self.z_mem, z_len);
        } else {
            // `clear` keeps the allocation, so re-enabling `need_z` does not reallocate.
            self.z_mem.clear();
        }
        ensure_len(&mut self.h_mem, h_len);
        ensure_len(&mut self.cs, m);
        ensure_len(&mut self.sn, m);
        ensure_len(&mut self.g, m1);
        ensure_len(&mut self.pipelined_w, n);
        ensure_len(&mut self.pipelined_wtmp, n);
        // Complex payloads store (re, im) per coefficient plus one norm slot.
        #[cfg(feature = "complex")]
        let payload_len = 2 * m1 + 1;
        #[cfg(not(feature = "complex"))]
        let payload_len = m1 + 1;
        ensure_len(&mut self.pipelined_payload, payload_len);
        if spec.block_s > 0 {
            let blk_len = n.checked_mul(spec.block_s).expect("blk_scratch overflow");
            ensure_len(&mut self.blk_scratch, blk_len);
            let payload_cap = block_payload_capacity(m1, spec.block_s);
            ensure_capacity(&mut self.blk_payload, payload_cap);
        } else {
            self.blk_scratch.clear();
            self.blk_payload.clear();
        }
        self.ensure_sstep(n, spec.block_s, m);
    }

    /// Replaces the reduction options and drops any cached reduction engine.
    pub fn set_reduction_options(&mut self, opt: crate::utils::reduction::ReductOptions) {
        self.reduction = opt;
        self.reduction_engine = None;
    }

    /// Installs a reduction engine to be cached on this workspace.
    pub fn set_reduction_engine(&mut self, engine: Arc<dyn ReductionEngine>) {
        self.reduction_engine = Some(engine);
    }

    /// The cached reduction engine, if any.
    pub fn reduction_engine(&self) -> Option<&Arc<dyn ReductionEngine>> {
        self.reduction_engine.as_ref()
    }

    /// Sets the reproducibility mode and drops any cached reduction engine.
    pub fn set_reduction_mode(&mut self, mode: ReproMode) {
        self.reduction.mode = mode;
        self.reduction_engine = None;
    }

    /// Current reduction options.
    pub fn reduction_options(&self) -> &crate::utils::reduction::ReductOptions {
        &self.reduction
    }

    /// Mutable view of basis column `v_j` (`j <= m`).
    #[inline]
    pub fn v_col(&mut self, j: usize) -> &mut [S] {
        debug_assert!(j <= self.m);
        let n = self.n;
        let off = j.checked_mul(n).expect("v offset overflow");
        &mut self.v_mem[off..off + n]
    }

    /// Mutable view of column `z_j` (`j < m`, requires `need_z`).
    #[inline]
    pub fn z_col(&mut self, j: usize) -> &mut [S] {
        debug_assert!(self.need_z && j < self.m);
        let n = self.n;
        let off = j.checked_mul(n).expect("z offset overflow");
        &mut self.z_mem[off..off + n]
    }

    /// Hessenberg entry `H[i, j]` (`i <= m`, `j < m`).
    #[inline]
    pub fn h_at(&self, i: usize, j: usize) -> S {
        debug_assert!(i <= self.m && j < self.m);
        self.h_mem[j * self.ld_h() + i]
    }

    /// Mutable reference to Hessenberg entry `H[i, j]` (`i <= m`, `j < m`).
    #[inline]
    pub fn h_at_mut(&mut self, i: usize, j: usize) -> &mut S {
        debug_assert!(i <= self.m && j < self.m);
        let idx = j * self.ld_h() + i;
        &mut self.h_mem[idx]
    }

    /// Splits `buf`, viewed as consecutive length-`n` columns, into disjoint mutable views
    /// of columns `a` and `b`, returned in argument order.
    ///
    /// Panics (slice indexing) if `a == b` with `n > 0`, or if either column is out of range.
    fn split_col_pair(buf: &mut [S], n: usize, a: usize, b: usize) -> (&mut [S], &mut [S]) {
        let (lo, hi) = if a < b { (a, b) } else { (b, a) };
        let (head, tail) = buf.split_at_mut(hi * n);
        let lo_slice = &mut head[lo * n..lo * n + n];
        let hi_slice = &mut tail[..n];
        if a < b {
            (lo_slice, hi_slice)
        } else {
            (hi_slice, lo_slice)
        }
    }

    /// Disjoint mutable views of basis columns `v_a` and `v_b` (`a != b`, both `<= m`).
    pub fn v_cols2(&mut self, a: usize, b: usize) -> (&mut [S], &mut [S]) {
        debug_assert!(a <= self.m && b <= self.m && a != b);
        let n = self.n;
        Self::split_col_pair(&mut self.v_mem, n, a, b)
    }

    /// Disjoint mutable views of columns `z_a` and `z_b` (`a != b`, both `< m`).
    pub fn z_cols2(&mut self, a: usize, b: usize) -> (&mut [S], &mut [S]) {
        debug_assert!(self.need_z && a < self.m && b < self.m && a != b);
        let n = self.n;
        Self::split_col_pair(&mut self.z_mem, n, a, b)
    }

    /// `v_j` (read-only) together with `z_j` (mutable).
    #[inline]
    pub fn v_and_z_mut(&mut self, j: usize) -> (&[S], &mut [S]) {
        debug_assert!(self.need_z && j < self.m);
        let n = self.n;
        let off = j * n;
        let vj: &[S] = &self.v_mem[off..off + n];
        let zj: &mut [S] = &mut self.z_mem[off..off + n];
        (vj, zj)
    }

    /// `tmp1` (read-only) together with `z_j` (mutable).
    #[inline]
    pub fn tmp1_and_z_mut(&mut self, j: usize) -> (&[S], &mut [S]) {
        debug_assert!(self.need_z && j < self.m);
        let n = self.n;
        let tmp: &[S] = &self.tmp1[..n];
        let z: &mut [S] = &mut self.z_mem[j * n..(j + 1) * n];
        (tmp, z)
    }

    /// `tmp2` (read-only) together with `z_j` (mutable).
    #[inline]
    pub fn tmp2_and_z_mut(&mut self, j: usize) -> (&[S], &mut [S]) {
        debug_assert!(self.need_z && j < self.m);
        let n = self.n;
        let tmp: &[S] = &self.tmp2[..n];
        let z: &mut [S] = &mut self.z_mem[j * n..(j + 1) * n];
        (tmp, z)
    }

    /// `z_j` (read-only) together with `tmp2` (mutable).
    #[inline]
    pub fn z_and_tmp2_mut(&mut self, j: usize) -> (&[S], &mut [S]) {
        debug_assert!(self.need_z && j < self.m);
        let n = self.n;
        let z: &[S] = &self.z_mem[j * n..(j + 1) * n];
        let tmp: &mut [S] = &mut self.tmp2[..n];
        (z, tmp)
    }

    /// Copies `tmp2[..n]` into basis column `v_j`.
    #[inline]
    pub fn copy_tmp2_into_vcol(&mut self, j: usize) {
        let n = self.n;
        let dst = &mut self.v_mem[j * n..(j + 1) * n];
        let src = &self.tmp2[..n];
        dst.copy_from_slice(src);
    }

    /// Copies `tmp1[..n]` into basis column `v_j`.
    #[inline]
    pub fn copy_tmp1_into_vcol(&mut self, j: usize) {
        let n = self.n;
        let dst = &mut self.v_mem[j * n..(j + 1) * n];
        let src = &self.tmp1[..n];
        dst.copy_from_slice(src);
    }

    /// Copies basis column `v_j` into `z_j`.
    #[inline]
    pub fn copy_vcol_into_zcol(&mut self, j: usize) {
        debug_assert!(self.need_z && j < self.m);
        let n = self.n;
        let src = &self.v_mem[j * n..(j + 1) * n];
        let dst = &mut self.z_mem[j * n..(j + 1) * n];
        dst.copy_from_slice(src);
    }

    /// Copies basis column `v_j` into `tmp1[..n]`.
    #[inline]
    pub fn copy_vcol_into_tmp1(&mut self, j: usize) {
        let n = self.n;
        let src = &self.v_mem[j * n..(j + 1) * n];
        self.tmp1[..n].copy_from_slice(src);
    }

    /// Applies the stored Givens rotations (`cs`/`sn`) to the first `upto + 1` entries of
    /// Hessenberg column `j`, via the shared `givens::apply_prev_givens_to_col` helper.
    ///
    /// No-op when `upto == 0`. The column is staged through `givens_col_scratch`.
    #[inline]
    pub fn apply_prev_givens_to_col(&mut self, j: usize, upto: usize) {
        if upto == 0 {
            return;
        }
        let ld = self.ld_h();
        let base = j * ld;
        let len = upto + 1;
        ensure_len(&mut self.givens_col_scratch, len);
        self.givens_col_scratch[..len].copy_from_slice(&self.h_mem[base..base + len]);
        apply_prev_givens_to_col(
            &mut self.givens_col_scratch[..len],
            upto,
            &self.cs,
            &self.sn,
        );
        self.h_mem[base..base + len].copy_from_slice(&self.givens_col_scratch[..len]);
    }

    /// Forms rotation `j` from entries `0..=j + 1` of Hessenberg column `j`, applies it,
    /// and updates `cs`, `sn` and `g` via `givens::apply_new_givens_and_update_g`.
    #[inline]
    pub fn apply_final_givens_and_update_g(&mut self, j: usize) {
        let ld = self.ld_h();
        let base = j * ld;
        let len = j + 2;
        ensure_len(&mut self.givens_col_scratch, len);
        self.givens_col_scratch[..len].copy_from_slice(&self.h_mem[base..base + len]);
        apply_new_givens_and_update_g(
            &mut self.givens_col_scratch[..len],
            j,
            &mut self.cs[..],
            &mut self.sn[..],
            &mut self.g[..],
        );
        self.h_mem[base..base + len].copy_from_slice(&self.givens_col_scratch[..len]);
    }

    /// Completes pipelined Arnoldi step `k` from the globally reduced payload `glob`.
    ///
    /// `glob[i]` (`i <= k`) holds `<v_i, w>` and `glob[k + 1]` holds `||w||²`, where `w`
    /// was copied into `pipelined_wtmp`. The step subtracts the projections from
    /// `pipelined_wtmp`, stores them in Hessenberg column `k`, and estimates
    /// `H[k+1, k]² = ||w||² - Σ h_ik²` (exact for an orthonormal basis; clamped at zero,
    /// non-finite → 0). Depending on `policy` one corrective Gram–Schmidt pass (one more
    /// reduction through `red`) follows; `IfNeeded` fires when the estimate is below
    /// `tol² · ||w||²`. Finally `H[k+1, k]` and `v_{k+1}` (zero on breakdown) are written
    /// and the payload buffer is kept in `pipelined_payload` for reuse.
    ///
    /// A `glob` of the wrong length is resized with zeros. Returns the number of global
    /// reductions used (1 or 2); this body never returns `Err`.
    #[cfg(not(feature = "complex"))]
    pub fn finish_pipelined_arnoldi(
        &mut self,
        k: usize,
        n: usize,
        red: &dyn crate::parallel::ReductionEngine,
        policy: ReorthPolicy,
        tol: R,
        mut glob: Vec<R>,
    ) -> Result<usize, crate::error::KError> {
        let payload_len = k + 2;
        if glob.len() != payload_len {
            glob.resize(payload_len, R::zero());
        }
        let mut reductions = 1usize;
        // First pass: apply the pipelined projection coefficients.
        let mut sum_h2 = R::zero();
        for i in 0..=k {
            let hij = glob[i];
            sum_h2 += hij * hij;
            let vi = &self.v_mem[i * n..(i + 1) * n];
            for idx in 0..n {
                self.pipelined_wtmp[idx] -= hij * vi[idx];
            }
            *self.h_at_mut(i, k) = hij;
        }
        let total_norm_sq = glob[k + 1];
        let mut hnext_sq = (total_norm_sq - sum_h2).max(R::zero());
        if !hnext_sq.is_finite() {
            hnext_sq = R::zero();
        }
        let tol = tol.max(R::zero());
        let tol_sq = tol * tol;
        let trigger_reorth = match policy {
            ReorthPolicy::Never => false,
            ReorthPolicy::Always => true,
            ReorthPolicy::IfNeeded => {
                total_norm_sq > R::zero() && hnext_sq < tol_sq * total_norm_sq
            }
        };
        if trigger_reorth {
            reductions += 1;
            // `glob` already has exactly `payload_len` entries; reuse it as the send buffer.
            for i in 0..=k {
                let vi = &self.v_mem[i * n..(i + 1) * n];
                glob[i] = vi
                    .iter()
                    .zip(&self.pipelined_wtmp[..n])
                    .map(|(a, b)| a * b)
                    .sum();
            }
            glob[k + 1] = self.pipelined_wtmp[..n].iter().map(|val| val * val).sum();
            let corr = red.iallreduce_sum_vec_r(glob).wait();
            let mut delta_norm_sq = R::zero();
            for i in 0..=k {
                let delta = corr[i];
                delta_norm_sq += delta * delta;
                let vi = &self.v_mem[i * n..(i + 1) * n];
                for idx in 0..n {
                    self.pipelined_wtmp[idx] -= delta * vi[idx];
                }
                let hij = self.h_at(i, k) + delta;
                *self.h_at_mut(i, k) = hij;
            }
            // ||w_corrected||² = ||w||² - Σ δ_i² for an orthonormal basis.
            let wtmp_norm_sq = corr[k + 1];
            hnext_sq = (wtmp_norm_sq - delta_norm_sq).max(R::zero());
            if !hnext_sq.is_finite() {
                hnext_sq = R::zero();
            }
            glob = corr;
        }
        let hnext = hnext_sq.sqrt();
        *self.h_at_mut(k + 1, k) = hnext;
        let base = (k + 1) * n;
        if hnext > R::zero() {
            let inv = S::from_real(hnext.recip());
            for idx in 0..n {
                self.v_mem[base + idx] = self.pipelined_wtmp[idx] * inv;
            }
        } else {
            self.v_mem[base..base + n].fill(S::zero());
        }
        self.pipelined_payload = glob;
        Ok(reductions)
    }

    /// Complex-scalar variant of the pipelined Arnoldi completion.
    ///
    /// Coefficient `i` occupies `glob[2i]` (real) and `glob[2i + 1]` (imaginary), with
    /// `||w||²` at `glob[2(k + 1)]`; inner products are `v_iᴴ w`. Otherwise identical to
    /// the real version: optional corrective pass per `policy`, then `H[k+1, k]` and
    /// `v_{k+1}` are written. Returns the number of global reductions used (1 or 2).
    #[cfg(feature = "complex")]
    pub fn finish_pipelined_arnoldi(
        &mut self,
        k: usize,
        n: usize,
        red: &dyn crate::parallel::ReductionEngine,
        policy: ReorthPolicy,
        tol: R,
        mut glob: Vec<R>,
    ) -> Result<usize, crate::error::KError> {
        let payload_len = 2 * (k + 1) + 1;
        if glob.len() != payload_len {
            glob.resize(payload_len, R::zero());
        }
        let mut reductions = 1usize;
        let mut sum_h2 = R::zero();
        for i in 0..=k {
            let hij = S::from_parts(glob[2 * i], glob[2 * i + 1]);
            sum_h2 += hij.abs2();
            let vi = &self.v_mem[i * n..(i + 1) * n];
            for idx in 0..n {
                self.pipelined_wtmp[idx] -= hij * vi[idx];
            }
            *self.h_at_mut(i, k) = hij;
        }
        let total_norm_sq = glob[2 * (k + 1)];
        let mut hnext_sq = (total_norm_sq - sum_h2).max(R::zero());
        if !hnext_sq.is_finite() {
            hnext_sq = R::zero();
        }
        let tol = tol.max(R::zero());
        let tol_sq = tol * tol;
        let trigger_reorth = match policy {
            ReorthPolicy::Never => false,
            ReorthPolicy::Always => true,
            ReorthPolicy::IfNeeded => {
                total_norm_sq > R::zero() && hnext_sq < tol_sq * total_norm_sq
            }
        };
        if trigger_reorth {
            reductions += 1;
            // `glob` already has exactly `payload_len` entries; reuse it as the send buffer.
            for i in 0..=k {
                let vi = &self.v_mem[i * n..(i + 1) * n];
                let mut acc = S::zero();
                for (&a, &b) in vi.iter().zip(&self.pipelined_wtmp[..n]) {
                    acc = acc + a.conj() * b;
                }
                glob[2 * i] = acc.real();
                glob[2 * i + 1] = acc.imag();
            }
            let mut norm_sq = R::zero();
            for &value in &self.pipelined_wtmp[..n] {
                norm_sq += value.abs2();
            }
            glob[2 * (k + 1)] = norm_sq;
            let corr = red.iallreduce_sum_vec_r(glob).wait();
            let mut delta_norm_sq = R::zero();
            for i in 0..=k {
                let delta = S::from_parts(corr[2 * i], corr[2 * i + 1]);
                delta_norm_sq += delta.abs2();
                let vi = &self.v_mem[i * n..(i + 1) * n];
                for idx in 0..n {
                    self.pipelined_wtmp[idx] -= delta * vi[idx];
                }
                let hij = self.h_at(i, k) + delta;
                *self.h_at_mut(i, k) = hij;
            }
            // ||w_corrected||² = ||w||² - Σ |δ_i|² for an orthonormal basis.
            let wtmp_norm_sq = corr[2 * (k + 1)];
            hnext_sq = (wtmp_norm_sq - delta_norm_sq).max(R::zero());
            if !hnext_sq.is_finite() {
                hnext_sq = R::zero();
            }
            glob = corr;
        }
        let hnext = hnext_sq.sqrt();
        *self.h_at_mut(k + 1, k) = S::from_real(hnext);
        let base = (k + 1) * n;
        if hnext > R::zero() {
            let inv = S::from_real(hnext.recip());
            for idx in 0..n {
                self.v_mem[base + idx] = self.pipelined_wtmp[idx] * inv;
            }
        } else {
            self.v_mem[base..base + n].fill(S::zero());
        }
        self.pipelined_payload = glob;
        Ok(reductions)
    }

    /// Resolves a [`PipeReduct`] produced by `pipelined_arnoldi_step`.
    ///
    /// `Sync` results were already finished and just report their reduction count;
    /// `Async` handles are waited on and then completed via `finish_pipelined_arnoldi`.
    /// The body is scalar-agnostic, so one definition serves both the real and the
    /// `complex` builds.
    pub fn finish_pipe_reduction(
        &mut self,
        pipe: PipeReduct,
        k: usize,
        n: usize,
        red: &dyn crate::parallel::ReductionEngine,
        policy: ReorthPolicy,
        tol: R,
    ) -> Result<usize, crate::error::KError> {
        match pipe {
            PipeReduct::Sync { reductions } => Ok(reductions),
            PipeReduct::Async { handle } => {
                let glob = handle.wait();
                self.finish_pipelined_arnoldi(k, n, red, policy, tol, glob)
            }
        }
    }

    /// Starts pipelined Arnoldi step `k` on `w = pipelined_w[..n]`.
    ///
    /// Packs `<v_i, w>` (`i <= k`) and `||w||²` into the recycled payload, launches a
    /// non-blocking sum-reduction through `red`, and copies `w` into `pipelined_wtmp`.
    /// If the reduction is already complete the step is finished immediately (`Sync`);
    /// otherwise the pending handle is returned (`Async`).
    #[cfg(not(feature = "complex"))]
    pub fn pipelined_arnoldi_step(
        &mut self,
        k: usize,
        n: usize,
        red: &dyn crate::parallel::ReductionEngine,
        policy: ReorthPolicy,
        tol: R,
    ) -> Result<PipeReduct, crate::error::KError> {
        debug_assert!(k < self.m);
        let w = &self.pipelined_w[..n];
        let payload_len = k + 2;
        let mut payload = std::mem::take(&mut self.pipelined_payload);
        payload.resize(payload_len, R::zero());
        for i in 0..=k {
            let vi = &self.v_mem[i * n..(i + 1) * n];
            payload[i] = vi.iter().zip(w).map(|(a, b)| a * b).sum();
        }
        payload[k + 1] = w.iter().map(|val| val * val).sum();
        let handle = red.iallreduce_sum_vec_r(payload);
        self.pipelined_wtmp[..n].copy_from_slice(w);
        if handle.is_ready() {
            let payload = handle.wait();
            let reductions = self.finish_pipelined_arnoldi(k, n, red, policy, tol, payload)?;
            Ok(PipeReduct::Sync { reductions })
        } else {
            Ok(PipeReduct::Async { handle })
        }
    }

    /// Complex-scalar variant of `pipelined_arnoldi_step`: coefficients `v_iᴴ w` are
    /// packed as (real, imag) pairs followed by `||w||²`.
    #[cfg(feature = "complex")]
    pub fn pipelined_arnoldi_step(
        &mut self,
        k: usize,
        n: usize,
        red: &dyn crate::parallel::ReductionEngine,
        policy: ReorthPolicy,
        tol: R,
    ) -> Result<PipeReduct, crate::error::KError> {
        debug_assert!(k < self.m);
        let w = &self.pipelined_w[..n];
        let payload_len = 2 * (k + 1) + 1;
        let mut payload = std::mem::take(&mut self.pipelined_payload);
        payload.resize(payload_len, R::zero());
        for i in 0..=k {
            let vi = &self.v_mem[i * n..(i + 1) * n];
            let mut acc = S::zero();
            for (&a, &b) in vi.iter().zip(w) {
                acc = acc + a.conj() * b;
            }
            payload[2 * i] = acc.real();
            payload[2 * i + 1] = acc.imag();
        }
        let mut norm_sq = R::zero();
        for &value in w.iter() {
            norm_sq += value.abs2();
        }
        payload[2 * (k + 1)] = norm_sq;
        let handle = red.iallreduce_sum_vec_r(payload);
        self.pipelined_wtmp[..n].copy_from_slice(w);
        if handle.is_ready() {
            let payload = handle.wait();
            let reductions = self.finish_pipelined_arnoldi(k, n, red, policy, tol, payload)?;
            Ok(PipeReduct::Sync { reductions })
        } else {
            Ok(PipeReduct::Async { handle })
        }
    }
}
/// Sets `v.len()` to exactly `need`, filling new slots with `T::default()` or truncating.
///
/// Truncation keeps the allocation, so capacity never shrinks.
#[inline]
fn ensure_len<T: Copy + Default>(v: &mut Vec<T>, need: usize) {
    v.resize(need, T::default());
}
/// Ensures `v.capacity() >= need` without changing the length or contents of `v`.
///
/// `Vec::reserve` counts *additional* elements beyond the current `len()`, not beyond the
/// current capacity, so the request must be `need - len`. (Passing `need - capacity`
/// under-reserves whenever `len < capacity`.) Since `need > capacity >= len` inside the
/// branch, the subtraction cannot underflow.
#[inline]
fn ensure_capacity<T>(v: &mut Vec<T>, need: usize) {
    if v.capacity() < need {
        v.reserve(need - v.len());
    }
}
#[cfg(test)]
mod ws_tests {
    use super::*;

    /// Turning `need_z` off and back on must keep reusing the first `z_mem` allocation.
    #[test]
    fn acquire_gmres_does_not_shrink_z_capacity_when_need_z_toggles() {
        const N: usize = 128;
        const M: usize = 50;
        let mut ws = Workspace::new(N);
        let mut acquire = |need_z: bool| {
            ws.acquire_gmres(GmresSpec {
                n: N,
                m: M,
                need_z,
                block_s: 0,
            });
            ws.z_mem.capacity()
        };
        let baseline = acquire(true);
        assert!(baseline >= N * M);
        for need_z in [false, true] {
            assert_eq!(acquire(need_z), baseline);
        }
    }
}
/// Number of real payload slots for `max_blocks · block_size²` scalars, saturating at
/// `usize::MAX` on overflow.
///
/// With the `complex` feature each scalar occupies two real slots.
#[inline]
fn block_payload_capacity(max_blocks: usize, block_size: usize) -> usize {
    // Multiply left-to-right, (max_blocks * b) * b, so overflow detection matches exactly.
    let scalars = [block_size, block_size]
        .iter()
        .try_fold(max_blocks, |acc, &factor| acc.checked_mul(factor))
        .unwrap_or(usize::MAX);
    let reals_per_scalar: usize = if cfg!(feature = "complex") { 2 } else { 1 };
    scalars.checked_mul(reals_per_scalar).unwrap_or(usize::MAX)
}