kryst 3.2.1

Krylov subspace and preconditioned iterative solvers for dense and sparse linear systems, with shared and distributed memory parallelism.
#![cfg(feature = "backend-nalgebra")]

use crate::algebra::prelude::*;
use crate::error::KError;
use crate::matrix::op::LinOp;
use crate::preconditioner::{PcSide, Preconditioner};
use nalgebra::{DMatrix, DVector, Dynamic};

/// Returns an owned copy of the dense nalgebra matrix wrapped by `op`.
///
/// # Errors
/// Yields `KError::InvalidInput` when `op` cannot be downcast to
/// `NalgebraDenseOp`.
fn extract_dense(op: &dyn LinOp<S = S>) -> Result<DMatrix<S>, KError> {
    match op
        .as_any()
        .downcast_ref::<crate::matrix::op_nalgebra::NalgebraDenseOp>()
    {
        // The caller receives its own matrix, so the operator's storage is cloned.
        Some(wrapper) => Ok(wrapper.inner().clone()),
        None => Err(KError::InvalidInput(
            "nalgebra direct PC requires NalgebraDenseOp".into(),
        )),
    }
}

/// Accepts `op` only when its communicator is trivial or has at most one rank.
///
/// # Errors
/// Yields `KError::Unsupported` when the communicator is non-trivial and its
/// size is greater than one.
fn ensure_trivial_comm(op: &dyn LinOp<S = S>) -> Result<(), KError> {
    // Happy path first: either test passing is enough to accept the operator.
    if op.comm().is_trivial() || op.comm().size() <= 1 {
        return Ok(());
    }
    Err(KError::Unsupported(
        "nalgebra direct preconditioners require a trivial communicator",
    ))
}

/// Direct-solve preconditioner that caches a nalgebra dense LU factorization.
pub struct NalgebraLuPc {
    /// Cached LU factors; `None` until a factorization has been computed.
    lu: Option<nalgebra::linalg::LU<S, Dynamic, Dynamic>>,
    /// Row (and column) count of the factorized matrix; `0` before factorization.
    n: usize,
}

impl NalgebraLuPc {
    /// Creates a preconditioner with no factorization cached.
    ///
    /// The factors are computed by `setup`, or on the first `direct_solve`
    /// call if `setup` was never invoked.
    pub fn new() -> Self {
        Self { lu: None, n: 0 }
    }

    /// Computes the LU factorization of the dense matrix behind `op` and caches it.
    ///
    /// On failure the previously cached state (if any) is left untouched.
    ///
    /// # Errors
    /// * `KError::Unsupported` if `ensure_trivial_comm` rejects the operator's
    ///   communicator.
    /// * `KError::InvalidInput` if `op` is not a `NalgebraDenseOp`, or if the
    ///   extracted matrix is not square.
    fn factorize(&mut self, op: &dyn LinOp<S = S>) -> Result<(), KError> {
        ensure_trivial_comm(op)?;
        let a = extract_dense(op)?;
        let n = a.nrows();
        if n != a.ncols() {
            return Err(KError::InvalidInput(
                "nalgebra LU requires a square matrix".into(),
            ));
        }
        // `lu()` consumes the matrix, which is why `extract_dense` hands back a copy.
        self.lu = Some(a.lu());
        self.n = n;
        Ok(())
    }
}

impl Default for NalgebraLuPc {
    /// Equivalent to [`NalgebraLuPc::new`]: no factorization cached.
    fn default() -> Self {
        Self::new()
    }
}

impl Preconditioner for NalgebraLuPc {
    /// Reports `(n, n)` for the factorized matrix, or `(0, 0)` before any
    /// factorization has been computed.
    fn dims(&self) -> (usize, usize) {
        let order = self.n;
        (order, order)
    }

    /// Factorizes `op` and caches the result, replacing earlier factors.
    fn setup(&mut self, op: &dyn LinOp<S = S>) -> Result<(), KError> {
        Self::factorize(self, op)
    }

    /// Always fails with `KError::Unsupported`: this type only supports
    /// `direct_solve`.
    fn apply(&self, _side: PcSide, _r: &[S], _z: &mut [S]) -> Result<(), KError> {
        const MESSAGE: &str =
            "NalgebraLuPc::apply is PREONLY-only; use SolverType::Preonly or call direct_solve";
        Err(KError::Unsupported(MESSAGE))
    }

    /// Solves `A x = b` with the cached LU factors, factorizing `op` first if
    /// nothing is cached yet.
    ///
    /// # Errors
    /// * Any error from factorizing `op` when no factors are cached.
    /// * `KError::SolveError` if the factors are still missing afterwards.
    /// * `KError::InvalidInput` if `b` or `x` does not have length `n`.
    /// * `KError::FactorError` if the nalgebra solve returns `None`.
    fn direct_solve(
        &mut self,
        op: &dyn LinOp<S = S>,
        b: &[S],
        x: &mut [S],
    ) -> Result<(), KError> {
        // Lazy factorization: only runs when `setup` has not populated the cache.
        if self.lu.is_none() {
            self.factorize(op)?;
        }
        let factors = match self.lu.as_ref() {
            Some(cached) => cached,
            None => {
                return Err(KError::SolveError(
                    "nalgebra LU factorization missing".into(),
                ))
            }
        };
        // Validate lengths up front, before handing the data to nalgebra.
        let order = self.n;
        if !(b.len() == order && x.len() == order) {
            return Err(KError::InvalidInput(
                "nalgebra LU solve dimension mismatch".into(),
            ));
        }
        let rhs = DVector::from_column_slice(b);
        match factors.solve(&rhs) {
            Some(solution) => {
                x.copy_from_slice(solution.as_slice());
                Ok(())
            }
            None => Err(KError::FactorError("nalgebra LU solve failed".into())),
        }
    }

    /// This preconditioner needs the operator in dense format.
    fn required_format(&self) -> crate::matrix::format::OpFormat {
        use crate::matrix::format::OpFormat;
        OpFormat::Dense
    }
}

/// Direct-solve preconditioner that caches a nalgebra dense QR factorization.
pub struct NalgebraQrPc {
    /// Cached QR factors; `None` until a factorization has been computed.
    qr: Option<nalgebra::linalg::QR<S, Dynamic, Dynamic>>,
    /// Row (and column) count of the factorized matrix; `0` before factorization.
    n: usize,
}

impl NalgebraQrPc {
    /// Creates a preconditioner with no factorization cached.
    ///
    /// The factors are computed by `setup`, or on the first `direct_solve`
    /// call if `setup` was never invoked.
    pub fn new() -> Self {
        Self { qr: None, n: 0 }
    }

    /// Computes the QR factorization of the dense matrix behind `op` and caches it.
    ///
    /// On failure the previously cached state (if any) is left untouched.
    ///
    /// # Errors
    /// * `KError::Unsupported` if `ensure_trivial_comm` rejects the operator's
    ///   communicator.
    /// * `KError::InvalidInput` if `op` is not a `NalgebraDenseOp`, or if the
    ///   extracted matrix is not square.
    fn factorize(&mut self, op: &dyn LinOp<S = S>) -> Result<(), KError> {
        ensure_trivial_comm(op)?;
        let a = extract_dense(op)?;
        let n = a.nrows();
        if n != a.ncols() {
            return Err(KError::InvalidInput(
                "nalgebra QR requires a square matrix".into(),
            ));
        }
        // `qr()` consumes the matrix, which is why `extract_dense` hands back a copy.
        self.qr = Some(a.qr());
        self.n = n;
        Ok(())
    }
}

impl Default for NalgebraQrPc {
    /// Equivalent to [`NalgebraQrPc::new`]: no factorization cached.
    fn default() -> Self {
        Self::new()
    }
}

impl Preconditioner for NalgebraQrPc {
    /// Reports `(n, n)` for the factorized matrix, or `(0, 0)` before any
    /// factorization has been computed.
    fn dims(&self) -> (usize, usize) {
        let order = self.n;
        (order, order)
    }

    /// Factorizes `op` and caches the result, replacing earlier factors.
    fn setup(&mut self, op: &dyn LinOp<S = S>) -> Result<(), KError> {
        Self::factorize(self, op)
    }

    /// Always fails with `KError::Unsupported`: this type only supports
    /// `direct_solve`.
    fn apply(&self, _side: PcSide, _r: &[S], _z: &mut [S]) -> Result<(), KError> {
        const MESSAGE: &str =
            "NalgebraQrPc::apply is PREONLY-only; use SolverType::Preonly or call direct_solve";
        Err(KError::Unsupported(MESSAGE))
    }

    /// Solves `A x = b` with the cached QR factors, factorizing `op` first if
    /// nothing is cached yet.
    ///
    /// # Errors
    /// * Any error from factorizing `op` when no factors are cached.
    /// * `KError::SolveError` if the factors are still missing afterwards.
    /// * `KError::InvalidInput` if `b` or `x` does not have length `n`.
    /// * `KError::FactorError` if the nalgebra solve returns `None`.
    fn direct_solve(
        &mut self,
        op: &dyn LinOp<S = S>,
        b: &[S],
        x: &mut [S],
    ) -> Result<(), KError> {
        // Lazy factorization: only runs when `setup` has not populated the cache.
        if self.qr.is_none() {
            self.factorize(op)?;
        }
        let factors = match self.qr.as_ref() {
            Some(cached) => cached,
            None => {
                return Err(KError::SolveError(
                    "nalgebra QR factorization missing".into(),
                ))
            }
        };
        // Validate lengths up front, before handing the data to nalgebra.
        let order = self.n;
        if !(b.len() == order && x.len() == order) {
            return Err(KError::InvalidInput(
                "nalgebra QR solve dimension mismatch".into(),
            ));
        }
        let rhs = DVector::from_column_slice(b);
        match factors.solve(&rhs) {
            Some(solution) => {
                x.copy_from_slice(solution.as_slice());
                Ok(())
            }
            None => Err(KError::FactorError("nalgebra QR solve failed".into())),
        }
    }

    /// This preconditioner needs the operator in dense format.
    fn required_format(&self) -> crate::matrix::format::OpFormat {
        use crate::matrix::format::OpFormat;
        OpFormat::Dense
    }
}