use crate::error::KError;
use crate::solver::MonitorCallback;
use crate::solver::legacy::LinearSolver;
use crate::utils::convergence::{ConvergedReason, SolveStats};
use crate::{parallel::UniverseComm, preconditioner::PcSide};
use faer::linalg::solvers::{FullPivLu, Qr, SolveCore};
use faer::{Conj, Mat, MatMut};
type Scalar = f64;
#[cfg(feature = "logging")]
use crate::utils::profiling::StageGuard;
/// Dense direct solver based on faer's fully pivoted LU factorization.
///
/// The factorization computed by the most recent `LinearSolver::solve` call is
/// kept so that additional right-hand sides can be solved with
/// [`LuSolver::solve_cached`] without refactoring the matrix.
pub struct LuSolver {
    /// Factorization from the last `solve` call; `None` until the first solve.
    factor: Option<FullPivLu<Scalar>>,
}

impl LuSolver {
    /// Creates a solver with no cached factorization.
    pub fn new() -> Self {
        LuSolver { factor: None }
    }

    /// Solves `A x = b` by reusing the factorization cached by the last `solve` call.
    ///
    /// `b` is copied into `x`, and `x` is then overwritten in place with the solution.
    ///
    /// # Panics
    ///
    /// * If no factorization has been cached yet (`solve` was never called).
    /// * If `x.len() != b.len()`.
    ///
    /// NOTE(review): presumably faer also asserts that `b.len()` matches the
    /// dimension of the factored matrix. TODO confirm.
    pub fn solve_cached(&self, b: &[Scalar], x: &mut [Scalar]) {
        let factor = match &self.factor {
            Some(factor) => factor,
            None => panic!("LuSolver: solve_cached called before factorization"),
        };
        // Give a clear diagnostic instead of the generic slice-copy length panic.
        assert_eq!(
            x.len(),
            b.len(),
            "LuSolver: solve_cached requires x and b to have the same length"
        );
        // `Scalar` is `Copy`, so a plain memcpy is enough.
        x.copy_from_slice(b);
        let n = x.len();
        let x_mat = MatMut::from_column_major_slice_mut(x, n, 1);
        factor.solve_in_place_with_conj(Conj::No, x_mat);
    }
}
impl LinearSolver<Mat<Scalar>, Vec<Scalar>> for LuSolver {
    type Error = KError;
    type Scalar = Scalar;

    /// Solves `a * x = b` directly with a fully pivoted LU factorization.
    ///
    /// The new factorization replaces any previously cached one, so later
    /// right-hand sides can be solved with `LuSolver::solve_cached`.
    ///
    /// * `pc`, `_pc_side`, `_comm` and `_work` are accepted for interface
    ///   compatibility and ignored; a direct solve uses no preconditioner.
    /// * `x` is overwritten with a copy of `b` (taking `b`'s length) and then
    ///   solved in place.
    /// * Each monitor is called once with iteration `0` before the solve and once
    ///   with iteration `1` after it. The residual argument is always `0.0`.
    ///
    /// Always returns `Ok` with one iteration, a residual of `0.0` and
    /// `ConvergedReason::ConvergedAtol`.
    ///
    /// NOTE(review): no residual is computed and no singularity check is done,
    /// so a singular `a` is still reported as converged.
    fn solve(
        &mut self,
        a: &Mat<Scalar>,
        pc: Option<
            &(dyn crate::preconditioner::legacy::Preconditioner<Mat<Scalar>, Vec<Scalar>> + '_),
        >,
        b: &Vec<Scalar>,
        x: &mut Vec<Scalar>,
        _pc_side: PcSide,
        _comm: &UniverseComm,
        monitors: Option<&[Box<MonitorCallback<Self::Scalar>>]>,
        _work: Option<&mut crate::context::ksp_context::Workspace>,
    ) -> Result<SolveStats<Scalar>, KError> {
        #[cfg(feature = "logging")]
        let _guard = StageGuard::new("LuSolve");
        // Direct solver: the preconditioner is intentionally unused.
        let _ = pc;

        for monitor in monitors.into_iter().flatten() {
            monitor(0, 0.0, 0);
        }

        // Limit the "LuFactor" stage to the factorization itself. Previously the
        // guard lived until `solve` returned, so the triangular solves and the
        // final monitor calls were also attributed to factorization.
        let factor = {
            #[cfg(feature = "logging")]
            let _fact_guard = StageGuard::new("LuFactor");
            FullPivLu::new(a.as_ref())
        };
        // Cache the factorization for `solve_cached` and borrow it back without `unwrap`.
        let factor = self.factor.insert(factor);

        x.clone_from(b);
        let n = x.len();
        let x_mat = MatMut::from_column_major_slice_mut(x, n, 1);
        factor.solve_in_place_with_conj(Conj::No, x_mat);

        for monitor in monitors.into_iter().flatten() {
            monitor(1, 0.0, 0);
        }
        Ok(SolveStats::new(1, 0.0, ConvergedReason::ConvergedAtol))
    }
}
impl Default for LuSolver {
fn default() -> Self {
Self::new()
}
}
/// Dense direct solver based on faer's QR factorization.
///
/// Holds no state: each `solve` call factors the matrix from scratch.
pub struct QrSolver;

impl QrSolver {
    /// Creates a new, stateless QR solver.
    pub fn new() -> Self {
        Self
    }
}
impl LinearSolver<Mat<Scalar>, Vec<Scalar>> for QrSolver {
    type Error = KError;
    type Scalar = Scalar;

    /// Solves `a * x = b` directly through a QR factorization of `a`.
    ///
    /// Nothing is cached; `a` is factored on every call. The preconditioner,
    /// side, communicator and workspace arguments are ignored. `x` receives a
    /// copy of `b` and is then solved in place. Monitors are called with
    /// iteration `0` before and iteration `1` after the solve (residual `0.0`),
    /// and the result is always one iteration reported as `ConvergedAtol`.
    fn solve(
        &mut self,
        a: &Mat<Scalar>,
        pc: Option<
            &(dyn crate::preconditioner::legacy::Preconditioner<Mat<Scalar>, Vec<Scalar>> + '_),
        >,
        b: &Vec<Scalar>,
        x: &mut Vec<Scalar>,
        _pc_side: PcSide,
        _comm: &UniverseComm,
        monitors: Option<&[Box<MonitorCallback<Self::Scalar>>]>,
        _work: Option<&mut crate::context::ksp_context::Workspace>,
    ) -> Result<SolveStats<Scalar>, KError> {
        #[cfg(feature = "logging")]
        let _guard = StageGuard::new("QrSolve");
        // A direct solve has no use for the preconditioner.
        let _ = pc;

        // An absent monitor list behaves like an empty one.
        let callbacks = monitors.unwrap_or_default();
        callbacks.iter().for_each(|cb| cb(0, 0.0, 0));

        let qr = Qr::new(a.as_ref());
        x.clone_from(b);
        let len = x.len();
        let rhs = MatMut::from_column_major_slice_mut(x, len, 1);
        qr.solve_in_place_with_conj(Conj::No, rhs);

        callbacks.iter().for_each(|cb| cb(1, 0.0, 0));
        Ok(SolveStats::new(1, 0.0, ConvergedReason::ConvergedAtol))
    }
}
impl Default for QrSolver {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::legacy::LinearSolver;
    use faer::Mat;

    /// Row-by-row entries of the 3x3 matrix shared by both solver tests.
    const A_ROWS: [[f64; 3]; 3] = [[2.0, 1.0, 1.0], [1.0, 3.0, 2.0], [1.0, 0.0, 0.0]];
    /// Right-hand side of the test system.
    const RHS: [f64; 3] = [4.0, 5.0, 6.0];
    /// Exact solution: the last row gives x0 = 6, and the first two rows then
    /// give x1 = 15 and x2 = -23.
    const EXACT: [f64; 3] = [6.0, 15.0, -23.0];
    /// Absolute per-component tolerance when comparing against `EXACT`.
    const TOL: f64 = 1e-10;

    /// Builds the dense test matrix from `A_ROWS`.
    fn test_matrix() -> Mat<f64> {
        Mat::from_fn(3, 3, |i, j| A_ROWS[i][j])
    }

    /// Serial (no-communication) communicator used by both tests.
    fn serial_comm() -> UniverseComm {
        UniverseComm::NoComm(crate::parallel::NoComm)
    }

    /// Checks `x` against `EXACT` and that `reason` is a converged variant.
    fn check_result(x: &[f64], reason: &ConvergedReason, failure_msg: &str) {
        x.iter().zip(EXACT.iter()).for_each(|(xi, ei)| {
            assert!((xi - ei).abs() < TOL, "xi = {}, expected = {}", xi, ei);
        });
        assert!(
            matches!(
                reason,
                ConvergedReason::ConvergedAtol | ConvergedReason::ConvergedRtol
            ),
            "{}",
            failure_msg
        );
    }

    #[test]
    fn lu_solver_solves_dense_system() {
        let a = test_matrix();
        let b = RHS.to_vec();
        let mut x = vec![0.0; 3];
        let mut lu = LuSolver::new();
        let stats = lu
            .solve(&a, None, &b, &mut x, PcSide::Left, &serial_comm(), None, None)
            .unwrap();
        check_result(&x, &stats.reason, "LU did not report Converged reason");
    }

    #[test]
    fn qr_solver_solves_dense_system() {
        let a = test_matrix();
        let b = RHS.to_vec();
        let mut x = vec![0.0; 3];
        let mut qr = QrSolver::new();
        let stats = qr
            .solve(&a, None, &b, &mut x, PcSide::Left, &serial_comm(), None, None)
            .unwrap();
        check_result(&x, &stats.reason, "QR did not report Converged reason");
    }
}