kryst 4.0.4

Krylov subspace and preconditioned iterative solvers for dense and sparse linear systems, with shared and distributed memory parallelism.
use crate::algebra::bridge::BridgeScratch;
#[allow(unused_imports)]
use crate::algebra::prelude::*;
use crate::context::ksp_context::Workspace;
use crate::error::KError;
use crate::matrix::op::{LinOp, LinOpF64};
use crate::ops::klinop::KLinOp;
use crate::ops::kpc::KPreconditioner;
use crate::ops::wrap::{as_s_op, as_s_pc};
use crate::parallel::{Comm, UniverseComm};
use crate::preconditioner::{PcSide, Preconditioner};
use crate::solver::{LinearSolver, MonitorCallback};
use crate::utils::convergence::{ConvergedReason, Convergence, SolveStats};
use std::any::Any;

/// Richardson iterative solver: repeatedly applies `x <- x + omega * z`, where `z` is the
/// (optionally preconditioned) residual `b - A x`.
pub struct RichardsonSolver {
    /// Stopping criteria; checked at iteration 0 and after every update. Its `max_iters`
    /// bounds the iteration loop.
    conv: Convergence,
    /// Relaxation factor `omega` scaling each update (`new` initialises it to 1.0).
    omega: f64,
}

impl RichardsonSolver {
    /// Creates a Richardson solver with relative tolerance `rtol` and at most `maxits`
    /// iterations, using relaxation factor `omega = 1.0`.
    ///
    /// The two remaining `Convergence::new` arguments are fixed at `1e-12` and `1e3`
    /// (presumably the absolute and divergence tolerances — confirm against
    /// `Convergence::new`).
    pub fn new(rtol: f64, maxits: usize) -> Self {
        Self {
            conv: Convergence::new(rtol, 1e-12, 1e3, maxits),
            omega: 1.0,
        }
    }

    /// Sets the relaxation factor `omega` used in the update `x <- x + omega * z`.
    pub fn set_omega(&mut self, omega: f64) {
        self.omega = omega;
    }

    /// Overwrites `r` with the element-wise difference `b - ax`.
    fn residual_into(b: &[S], ax: &[S], r: &mut [S]) {
        for ((ri, &bi), &axi) in r.iter_mut().zip(b).zip(ax) {
            *ri = bi - axi;
        }
    }

    /// Runs (preconditioned) Richardson iteration on `A x = b`, updating `x` in place.
    ///
    /// Each iteration does the following:
    /// - forms the direction `z = pc(r)`, or `z = r` when `pc` is `None`;
    /// - updates `x <- x + omega * z`;
    /// - recomputes the true residual `r = b - A x`.
    ///
    /// Every monitor is called with `(iteration, ||r||, 0)` at iteration 0 and after
    /// each update. Their return values are discarded.
    ///
    /// # Returns
    /// On success, the returned stats hold the iteration count, the final residual norm
    /// and the stopping reason reported by the convergence test. If the test never
    /// stops the loop, the reason is `DivergedMaxIts` after `max_iters` iterations.
    ///
    /// # Errors
    /// - `KError::InvalidInput` if `x.len() != b.len()`.
    /// - Any error returned by the preconditioner's `apply_s`. In that case `x` may
    ///   already have been partially updated.
    #[allow(clippy::too_many_arguments)] // mirrors the LinearSolver::solve parameter list
    pub fn solve<A>(
        &mut self,
        a: &A,
        pc: Option<&dyn KPreconditioner<Scalar = S>>,
        b: &[S],
        x: &mut [S],
        pc_side: PcSide,
        comm: &UniverseComm,
        monitors: Option<&[Box<MonitorCallback<R>>]>,
        work: Option<&mut Workspace>,
    ) -> Result<SolveStats<R>, KError>
    where
        A: KLinOp<Scalar = S> + ?Sized,
    {
        let n = b.len();
        if x.len() != n {
            return Err(KError::InvalidInput(
                "Richardson: vector length mismatch".into(),
            ));
        }

        // Borrow the caller's bridge scratch when a workspace is provided; otherwise use
        // a local one for the duration of this solve.
        let mut local_scratch = BridgeScratch::default();
        let scratch = work.map(|w| &mut w.bridge).unwrap_or(&mut local_scratch);

        let mut ax = vec![S::zero(); n];
        let mut r = vec![S::zero(); n];
        // The preconditioned residual is only needed when a preconditioner is present.
        // Without one, the update reads `r` directly instead of copying it each iteration.
        let mut z = if pc.is_some() {
            vec![S::zero(); n]
        } else {
            Vec::new()
        };

        a.matvec_s(x, &mut ax, scratch);
        Self::residual_into(b, &ax, &mut r);

        let bnorm = comm.norm2(b);
        let mut rnorm = comm.norm2(&r);
        let monitors = monitors.unwrap_or(&[]);
        for m in monitors {
            let _ = m(0, rnorm, 0);
        }
        let (reason0, s0) = self.conv.check(rnorm, bnorm, 0);
        if !matches!(reason0, ConvergedReason::Continued) {
            return Ok(SolveStats::new(0, rnorm, s0.reason));
        }

        let omega = S::from_real(self.omega);
        for k in 1..=self.conv.max_iters {
            let dir: &[S] = match pc {
                Some(pc_ref) => {
                    pc_ref.apply_s(pc_side, &r, &mut z, scratch)?;
                    &z
                }
                None => &r,
            };
            for (xi, &di) in x.iter_mut().zip(dir) {
                *xi += omega * di;
            }

            a.matvec_s(x, &mut ax, scratch);
            Self::residual_into(b, &ax, &mut r);
            rnorm = comm.norm2(&r);

            for m in monitors {
                let _ = m(k, rnorm, 0);
            }
            let (reason, s) = self.conv.check(rnorm, bnorm, k);
            if !matches!(reason, ConvergedReason::Continued) {
                return Ok(SolveStats::new(k, rnorm, s.reason));
            }
        }
        Ok(SolveStats::new(
            self.conv.max_iters,
            rnorm,
            ConvergedReason::DivergedMaxIts,
        ))
    }

    /// Solves an `f64` system by adapting the operator and preconditioner to the scalar
    /// type `S` and delegating to the generic [`RichardsonSolver::solve`].
    ///
    /// Without the `complex` feature, `b` and `x` are reinterpreted in place. With it, they
    /// are promoted to `S` copies, and the real parts of the solution are written back into
    /// `x` only if the solve succeeds.
    ///
    /// # Errors
    /// Same as [`RichardsonSolver::solve`].
    #[allow(clippy::too_many_arguments)] // mirrors the LinearSolver::solve parameter list
    pub fn solve_f64<A>(
        &mut self,
        a: &A,
        pc: Option<&dyn Preconditioner>,
        b: &[f64],
        x: &mut [f64],
        pc_side: PcSide,
        comm: &UniverseComm,
        monitors: Option<&[Box<MonitorCallback<f64>>]>,
        work: Option<&mut Workspace>,
    ) -> Result<SolveStats<f64>, KError>
    where
        A: LinOpF64 + LinOp<S = f64> + Send + Sync + ?Sized,
    {
        let op = as_s_op(a);
        let pc_wrapper = pc.map(as_s_pc);
        let pc_ref = pc_wrapper
            .as_ref()
            .map(|w| w as &dyn KPreconditioner<Scalar = S>);

        #[cfg(not(feature = "complex"))]
        {
            // SAFETY: without the `complex` feature the scalar `S` is expected to be `f64`
            // (NOTE(review): confirm in the prelude). If so, the slice pointer casts keep
            // the same length, alignment and bit validity. `x_s` is the only live access
            // to `x` for the duration of the call.
            let b_s: &[S] = unsafe { &*(b as *const [f64] as *const [S]) };
            // SAFETY: see above; `x` is exclusively borrowed, so the reborrow is unique.
            let x_s: &mut [S] = unsafe { &mut *(x as *mut [f64] as *mut [S]) };
            self.solve(&op, pc_ref, b_s, x_s, pc_side, comm, monitors, work)
        }
        #[cfg(feature = "complex")]
        {
            let b_s: Vec<S> = b.iter().copied().map(S::from_real).collect();
            let mut x_s: Vec<S> = x.iter().copied().map(S::from_real).collect();
            let result = self.solve(&op, pc_ref, &b_s, &mut x_s, pc_side, comm, monitors, work);
            if result.is_ok() {
                // Only the real part is meaningful for an f64 system.
                for (dst, src) in x.iter_mut().zip(x_s.iter()) {
                    *dst = src.real();
                }
            }
            result
        }
    }
}

impl LinearSolver for RichardsonSolver {
    type Error = KError;

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Trait entry point for `f64` systems; forwards every argument to
    /// [`RichardsonSolver::solve_f64`].
    ///
    /// `solve_f64` takes the preconditioner by shared reference, so the `&mut` handle is
    /// simply reborrowed immutably via `Option::as_deref`.
    fn solve(
        &mut self,
        a: &dyn LinOp<S = f64>,
        pc: Option<&mut dyn Preconditioner>,
        b: &[f64],
        x: &mut [f64],
        pc_side: PcSide,
        comm: &UniverseComm,
        monitors: Option<&[Box<MonitorCallback<f64>>]>,
        work: Option<&mut Workspace>,
    ) -> Result<SolveStats<f64>, Self::Error> {
        self.solve_f64(a, pc.as_deref(), b, x, pc_side, comm, monitors, work)
    }
}