#[cfg(feature = "cudss")]
use crate::SolverCUDSS;
#[cfg(feature = "local_sparse")]
use super::SolverMUMPS;
use super::{CooMatrix, Genie, LinSolParams, SolverUMFPACK, StatsLinSol};
use crate::StrError;
use russell_lab::Vector;
/// Common interface implemented by every sparse linear-solver backend that
/// [`LinSolver`] can wrap.
///
/// Implementors must be `Send` (required by the `: Send` supertrait).
pub trait LinSolTrait: Send {
    /// Factorizes the coefficient matrix `mat`.
    ///
    /// `params` optionally carries solver configuration; `None` lets the
    /// implementor choose its own configuration.
    ///
    /// # Errors
    ///
    /// Returns a `StrError` if the backend cannot perform the factorization.
    fn factorize(&mut self, mat: &CooMatrix, params: Option<LinSolParams>) -> Result<(), StrError>;

    /// Solves the linear system for the right-hand side `rhs`, writing the
    /// solution into `x`.
    ///
    /// NOTE(review): presumably requires a prior successful `factorize` call —
    /// confirm against the implementors.
    ///
    /// # Errors
    ///
    /// Returns a `StrError` if the backend cannot solve the system.
    fn solve(&mut self, x: &mut Vector, rhs: &Vector, verbose: bool) -> Result<(), StrError>;

    /// Writes this solver's statistics into `stats`.
    fn update_stats(&self, stats: &mut StatsLinSol);

    /// Returns the time spent in initialization (by its name, presumably in
    /// nanoseconds — confirm in the implementors).
    fn get_ns_init(&self) -> u128;

    /// Returns the time spent in factorization (by its name, presumably in
    /// nanoseconds — confirm in the implementors).
    fn get_ns_fact(&self) -> u128;

    /// Returns the time spent in the solve step (by its name, presumably in
    /// nanoseconds — confirm in the implementors).
    fn get_ns_solve(&self) -> u128;
}
/// Runtime-selected sparse linear solver.
///
/// Wraps one backend implementing [`LinSolTrait`], chosen by a [`Genie`] value
/// in [`LinSolver::new`].
pub struct LinSolver<'a> {
    /// The concrete backend, boxed as a trait object.
    pub actual: Box<dyn Send + LinSolTrait + 'a>,
}
impl<'a> LinSolver<'a> {
    /// Allocates the backend identified by `genie`.
    ///
    /// UMFPACK is always compiled in. cuDSS requires the `cudss` feature, and
    /// MUMPS requires the `local_sparse` feature.
    ///
    /// # Errors
    ///
    /// Returns an error if the requested backend was not compiled in. Also
    /// forwards any error from the backend's own constructor.
    pub fn new(genie: Genie) -> Result<Self, StrError> {
        // One exhaustive match. Each variant keeps exactly one arm for every
        // feature combination, because the two arms of a pair carry
        // complementary `cfg` predicates.
        let actual: Box<dyn Send + LinSolTrait> = match genie {
            #[cfg(feature = "cudss")]
            Genie::Cudss => Box::new(SolverCUDSS::new()?),
            #[cfg(not(feature = "cudss"))]
            Genie::Cudss => return Err("cuDSS solver is not available"),

            #[cfg(feature = "local_sparse")]
            Genie::Mumps => Box::new(SolverMUMPS::new()?),
            #[cfg(not(feature = "local_sparse"))]
            Genie::Mumps => return Err("MUMPS solver is not available"),

            Genie::Umfpack => Box::new(SolverUMFPACK::new()?),
        };
        Ok(Self { actual })
    }

    /// Convenience one-shot: builds the backend, factorizes `mat`, and solves
    /// for `rhs`, storing the solution in `x`.
    ///
    /// The `verbose` flag given to `solve` is taken from `params`. It is
    /// `false` when `params` is `None`. The solver is returned so the caller
    /// can reuse the factorization or inspect statistics.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`LinSolver::new`], `factorize`, or `solve`.
    pub fn compute(
        genie: Genie,
        x: &mut Vector,
        mat: &CooMatrix,
        rhs: &Vector,
        params: Option<LinSolParams>,
    ) -> Result<Self, StrError> {
        let mut lin_sol = Self::new(genie)?;
        lin_sol.actual.factorize(mat, params)?;
        let show_messages = matches!(params, Some(p) if p.verbose);
        lin_sol.actual.solve(x, rhs, show_messages)?;
        Ok(lin_sol)
    }
}
#[cfg(test)]
mod tests {
    use super::LinSolver;
    use crate::{Genie, Samples};
    use russell_lab::{Vector, vec_approx_eq};
    // `#[serial]` is used by both the cuDSS and the MUMPS tests. The macro must
    // therefore be in scope whenever either feature is enabled, not only when
    // `local_sparse` is on. Otherwise a `cudss`-only test build fails to resolve it.
    #[cfg(any(feature = "cudss", feature = "local_sparse"))]
    use serial_test::serial;

    /// Solves the 5x5 symmetric (lower-triangle storage) sample with cuDSS.
    // `#[cfg]` comes first so the item is stripped before `#[serial]` is resolved.
    #[cfg(feature = "cudss")]
    #[test]
    #[serial]
    fn lin_solver_compute_works_cudss() {
        let (coo, _, _, _) = Samples::mkl_symmetric_5x5_lower(true, false);
        let mut x = Vector::new(5);
        let rhs = Vector::from(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        LinSolver::compute(Genie::Cudss, &mut x, &coo, &rhs, None).unwrap();
        let x_correct = vec![-979.0 / 3.0, 983.0, 1961.0 / 12.0, 398.0, 123.0 / 2.0];
        vec_approx_eq(&x, &x_correct, 1e-10);
    }

    /// Solves the 5x5 symmetric (lower-triangle storage) sample with MUMPS.
    #[cfg(feature = "local_sparse")]
    #[test]
    #[serial]
    fn lin_solver_compute_works_mumps() {
        let (coo, _, _, _) = Samples::mkl_symmetric_5x5_lower(true, false);
        let mut x = Vector::new(5);
        let rhs = Vector::from(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        LinSolver::compute(Genie::Mumps, &mut x, &coo, &rhs, None).unwrap();
        let x_correct = vec![-979.0 / 3.0, 983.0, 1961.0 / 12.0, 398.0, 123.0 / 2.0];
        vec_approx_eq(&x, &x_correct, 1e-10);
    }

    /// Solves the 5x5 symmetric sample (full storage) with UMFPACK, which is
    /// always available.
    #[test]
    fn lin_solver_compute_works_umfpack() {
        let (coo, _, _, _) = Samples::mkl_symmetric_5x5_full();
        let mut x = Vector::new(5);
        let rhs = Vector::from(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        LinSolver::compute(Genie::Umfpack, &mut x, &coo, &rhs, None).unwrap();
        let x_correct = vec![-979.0 / 3.0, 983.0, 1961.0 / 12.0, 398.0, 123.0 / 2.0];
        vec_approx_eq(&x, &x_correct, 1e-10);
    }
}