use super::{handle_klu_error_code, klu_ordering, klu_scaling};
use super::{ComplexCooMatrix, ComplexCscMatrix, ComplexLinSolTrait, LinSolParams, StatsLinSol, Sym};
use super::{KLU_ORDERING_AMD, KLU_ORDERING_COLAMD, KLU_SCALE_MAX, KLU_SCALE_NONE, KLU_SCALE_SUM};
use crate::constants::*;
use crate::StrError;
use russell_lab::{complex_vec_copy, Complex64, ComplexVector, Stopwatch};
/// Opaque handle type standing in for the solver structure owned by the C code
///
/// The type has no usable fields on the Rust side; it is only ever handled through
/// raw pointers returned by `complex_solver_klu_new`. The zero-sized array plus the
/// `PhantomData` marker follow the opaque-struct pattern described in the Rustonomicon
/// (FFI chapter, "Representing opaque structs").
#[repr(C)]
struct InterfaceComplexKLU {
_data: [u8; 0],
_marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}
// Enforce `Send` on the opaque C structure.
// NOTE(review): soundness relies on the C solver being safe to move across threads;
// that cannot be confirmed from this file -- verify against the C implementation.
unsafe impl Send for InterfaceComplexKLU {}
// Enforce `Send` on the Rust wrapper, which holds a raw pointer to the C structure
// and would otherwise not be `Send`.
unsafe impl Send for ComplexSolverKLU {}
// Functions implemented by the C side of the KLU wrapper.
//
// Except for `new` and `drop`, every function returns an `i32` status that the Rust
// code compares against `SUCCESSFUL_EXIT` and otherwise translates with
// `handle_klu_error_code`.
extern "C" {
// Allocates the C solver; the Rust constructor treats a null return as an allocation failure.
fn complex_solver_klu_new() -> *mut InterfaceComplexKLU;
// Releases a solver previously returned by `complex_solver_klu_new`.
fn complex_solver_klu_drop(solver: *mut InterfaceComplexKLU);
// One-time initialization; receives the requested ordering/scaling codes, the matrix
// dimension and the CSC structure (column pointers and row indices, no values).
fn complex_solver_klu_initialize(
solver: *mut InterfaceComplexKLU,
ordering: i32,
scaling: i32,
ndim: i32,
col_pointers: *const i32,
row_indices: *const i32,
) -> i32;
// Numeric factorization of the CSC matrix (structure plus values).
// `effective_ordering`, `effective_scaling` and `cond_estimate` are output locations;
// `compute_cond` carries the caller's `compute_condition_numbers` request.
fn complex_solver_klu_factorize(
solver: *mut InterfaceComplexKLU,
effective_ordering: *mut i32,
effective_scaling: *mut i32,
cond_estimate: *mut f64,
compute_cond: CcBool,
col_pointers: *const i32,
row_indices: *const i32,
values: *const Complex64,
) -> i32;
// Solves the linear system in place: the Rust caller copies the right-hand side into
// the buffer before the call and hands the same buffer back as the solution vector.
fn complex_solver_klu_solve(solver: *mut InterfaceComplexKLU, ndim: i32, in_rhs_out_x: *mut Complex64) -> i32;
}
/// Wraps the KLU solver for sparse linear systems with complex coefficients
///
/// The instance owns a pointer to the C-side solver (released on drop) together with
/// the CSC copy of the matrix, bookkeeping flags, and timing information.
pub struct ComplexSolverKLU {
/// Pointer to the solver structure allocated by the C code
solver: *mut InterfaceComplexKLU,
/// CSC version of the COO matrix given to `factorize` (`None` until the first call)
csc: Option<ComplexCscMatrix>,
/// Indicates that the one-time initialization has been completed
initialized: bool,
/// Indicates that a factorization has been completed
factorized: bool,
/// Symmetry flag recorded on the first call to `factorize`; later calls must match it
initialized_sym: Sym,
/// Matrix dimension recorded on the first call to `factorize`; later calls must match it
initialized_ndim: usize,
/// Number of non-zeros recorded on the first call to `factorize`; later calls must match it
initialized_nnz: usize,
/// Ordering code written by the C factorization routine (-1 until then)
effective_ordering: i32,
/// Scaling code written by the C factorization routine (-1 until then)
effective_scaling: i32,
/// Condition estimate written by the C factorization routine (0.0 until then)
cond_estimate: f64,
/// Stopwatch used to time the native calls
stopwatch: Stopwatch,
/// Time spent in the initialization call (nanoseconds)
time_initialize_ns: u128,
/// Time spent in the most recent factorization call (nanoseconds)
time_factorize_ns: u128,
/// Time spent in the most recent solve call (nanoseconds)
time_solve_ns: u128,
}
impl Drop for ComplexSolverKLU {
    /// Hands the native solver back to the C code so it can release its resources
    fn drop(&mut self) {
        let handle = self.solver;
        // SAFETY: `handle` is the pointer returned by `complex_solver_klu_new` in
        // `ComplexSolverKLU::new`, which refuses to build an instance around a null pointer.
        unsafe { complex_solver_klu_drop(handle) };
    }
}
impl ComplexSolverKLU {
    /// Allocates a new instance
    ///
    /// # Errors
    ///
    /// Returns an error if the C code hands back a null pointer instead of a solver.
    pub fn new() -> Result<Self, StrError> {
        // SAFETY: the allocation routine takes no arguments; its result is checked for
        // null right below, before anything dereferences or stores it.
        let solver = unsafe { complex_solver_klu_new() };
        if solver.is_null() {
            return Err("c-code failed to allocate the KLU solver");
        }
        let instance = ComplexSolverKLU {
            solver,
            csc: None,
            // nothing has been analyzed or factorized yet
            initialized: false,
            factorized: false,
            initialized_sym: Sym::No,
            initialized_ndim: 0,
            initialized_nnz: 0,
            // -1 marks "not reported yet" for the effective options
            effective_ordering: -1,
            effective_scaling: -1,
            cond_estimate: 0.0,
            stopwatch: Stopwatch::new(),
            time_initialize_ns: 0,
            time_factorize_ns: 0,
            time_solve_ns: 0,
        };
        Ok(instance)
    }
}
impl ComplexLinSolTrait for ComplexSolverKLU {
    /// Performs the factorization (and the one-time initialization on the first call)
    ///
    /// # Input
    ///
    /// * `mat` -- the coefficient matrix in COO form. It must be square, hold at least one
    ///   non-zero value, and must not be flagged `Sym::YesLower` or `Sym::YesUpper`.
    /// * `params` -- optional configuration; `LinSolParams::new()` is used when `None`.
    ///
    /// # Notes
    ///
    /// 1. The first call records the symmetry flag, the dimension, and the number of
    ///    non-zeros. Every later call must present a matrix with the same three properties,
    ///    otherwise an error is returned and the native solver is left untouched.
    /// 2. The ordering and scaling options are forwarded to the native initialization only,
    ///    which runs once. On later calls only `compute_condition_numbers` is forwarded.
    /// 3. The solver is flagged as factorized only while the most recent native
    ///    factorization attempt has succeeded. If a re-factorization fails, `solve` reports
    ///    that `factorize` must be called instead of using an unreliable factorization.
    ///
    /// # Errors
    ///
    /// Returns an error if the matrix is rejected by the checks above, if the COO to CSC
    /// conversion fails, or if the native code returns a status other than `SUCCESSFUL_EXIT`.
    fn factorize(&mut self, mat: &ComplexCooMatrix, params: Option<LinSolParams>) -> Result<(), StrError> {
        // get (or refresh) the CSC matrix
        if self.initialized {
            if mat.symmetric != self.initialized_sym {
                return Err("subsequent factorizations must use the same matrix (symmetric differs)");
            }
            if mat.nrow != self.initialized_ndim {
                return Err("subsequent factorizations must use the same matrix (ndim differs)");
            }
            if mat.nnz != self.initialized_nnz {
                return Err("subsequent factorizations must use the same matrix (nnz differs)");
            }
            self.csc
                .as_mut()
                .expect("the CSC matrix is always created before `initialized` is set")
                .update_from_coo(mat)?;
        } else {
            if mat.nrow != mat.ncol {
                return Err("the matrix must be square");
            }
            if mat.nnz < 1 {
                return Err("the COO matrix must have at least one non-zero value");
            }
            if mat.symmetric == Sym::YesLower || mat.symmetric == Sym::YesUpper {
                return Err("KLU requires Sym::YesFull for symmetric matrices");
            }
            self.initialized_sym = mat.symmetric;
            self.initialized_ndim = mat.nrow;
            self.initialized_nnz = mat.nnz;
            self.csc = Some(ComplexCscMatrix::from_coo(mat)?);
        }
        let csc = self
            .csc
            .as_ref()
            .expect("the CSC matrix has been created or refreshed by the branch above");

        // parameters
        let par = params.unwrap_or_else(LinSolParams::new);
        let ordering = klu_ordering(par.ordering);
        let scaling = klu_scaling(par.scaling);
        let compute_cond = if par.compute_condition_numbers { 1 } else { 0 };

        // matrix config
        let ndim = to_i32(csc.nrow);

        // one-time initialization
        if !self.initialized {
            self.stopwatch.reset();
            // SAFETY: `self.solver` is the non-null pointer obtained in `new`; the two
            // slices belong to `csc`, which stays alive and unmodified during the call.
            unsafe {
                let status = complex_solver_klu_initialize(
                    self.solver,
                    ordering,
                    scaling,
                    ndim,
                    csc.col_pointers.as_ptr(),
                    csc.row_indices.as_ptr(),
                );
                if status != SUCCESSFUL_EXIT {
                    return Err(handle_klu_error_code(status));
                }
            }
            self.time_initialize_ns = self.stopwatch.stop();
            self.initialized = true;
        }

        // The native numeric factorization is about to be replaced. Clear the flag first so
        // that a failure below cannot leave `solve` believing that a usable factorization
        // exists (the flag used to stay `true` after a failed re-factorization).
        // NOTE(review): whether the C code discards the previous numeric object before
        // attempting the new one cannot be seen from this file; clearing the flag is the
        // conservative choice either way.
        self.factorized = false;

        // numeric factorization
        self.stopwatch.reset();
        // SAFETY: `self.solver` is the non-null pointer obtained in `new`; the output
        // pointers refer to fields of `self`; the CSC slices belong to `csc`, which stays
        // alive and unmodified during the call.
        unsafe {
            let status = complex_solver_klu_factorize(
                self.solver,
                &mut self.effective_ordering,
                &mut self.effective_scaling,
                &mut self.cond_estimate,
                compute_cond,
                csc.col_pointers.as_ptr(),
                csc.row_indices.as_ptr(),
                csc.values.as_ptr(),
            );
            if status != SUCCESSFUL_EXIT {
                return Err(handle_klu_error_code(status));
            }
        }
        self.time_factorize_ns = self.stopwatch.stop();

        // done
        self.factorized = true;
        Ok(())
    }

    /// Computes the solution of the linear system
    ///
    /// # Output
    ///
    /// * `x` -- the vector of unknown values; its dimension must equal the dimension of the
    ///   factorized matrix
    ///
    /// # Input
    ///
    /// * `rhs` -- the right-hand side vector; its dimension must equal the dimension of the
    ///   factorized matrix
    /// * `_verbose` -- not used by this solver
    ///
    /// # Errors
    ///
    /// Returns an error if no successful factorization is available, if either vector has
    /// the wrong dimension, or if the native code returns a status other than
    /// `SUCCESSFUL_EXIT`.
    fn solve(&mut self, x: &mut ComplexVector, rhs: &ComplexVector, _verbose: bool) -> Result<(), StrError> {
        // check
        if !self.factorized {
            return Err("the function factorize must be called before solve");
        }
        if x.dim() != self.initialized_ndim {
            return Err("the dimension of the vector of unknown values x is incorrect");
        }
        if rhs.dim() != self.initialized_ndim {
            return Err("the dimension of the right-hand side vector is incorrect");
        }

        // the native routine works in place: load the right-hand side into `x` first
        let ndim = to_i32(self.initialized_ndim);
        complex_vec_copy(x, rhs).expect("x and rhs have been checked to have the same dimension");

        // call the native solve
        self.stopwatch.reset();
        // SAFETY: `self.solver` is the non-null pointer obtained in `new`, and `x` holds
        // exactly `ndim` values as verified above.
        unsafe {
            let status = complex_solver_klu_solve(self.solver, ndim, x.as_mut_data().as_mut_ptr());
            if status != SUCCESSFUL_EXIT {
                return Err(handle_klu_error_code(status));
            }
        }
        self.time_solve_ns = self.stopwatch.stop();

        // done
        Ok(())
    }

    /// Updates the stats structure (should be called after `solve`)
    ///
    /// Writes the solver name, the condition estimate, the effective ordering and scaling
    /// (as strings), and the recorded timings.
    fn update_stats(&self, stats: &mut StatsLinSol) {
        stats.main.solver = if cfg!(feature = "local_suitesparse") {
            "KLU-local".to_string()
        } else {
            "KLU".to_string()
        };
        // NOTE(review): the stats field is named after UMFPACK, but it is the only slot
        // this code uses for the KLU condition estimate.
        stats.output.umfpack_rcond_estimate = self.cond_estimate;
        stats.output.effective_ordering = match self.effective_ordering {
            KLU_ORDERING_AMD => "Amd".to_string(),
            KLU_ORDERING_COLAMD => "Colamd".to_string(),
            _ => "Unknown".to_string(),
        };
        stats.output.effective_scaling = match self.effective_scaling {
            KLU_SCALE_NONE => "No".to_string(),
            KLU_SCALE_SUM => "Sum".to_string(),
            KLU_SCALE_MAX => "Max".to_string(),
            _ => "Unknown".to_string(),
        };
        stats.time_nanoseconds.initialize = self.time_initialize_ns;
        stats.time_nanoseconds.factorize = self.time_factorize_ns;
        stats.time_nanoseconds.solve = self.time_solve_ns;
    }

    /// Returns the nanoseconds spent on initialize
    fn get_ns_init(&self) -> u128 {
        self.time_initialize_ns
    }

    /// Returns the nanoseconds spent on the most recent factorize
    fn get_ns_fact(&self) -> u128 {
        self.time_factorize_ns
    }

    /// Returns the nanoseconds spent on the most recent solve
    fn get_ns_solve(&self) -> u128 {
        self.time_solve_ns
    }
}
// Unit tests for ComplexSolverKLU. They call into the native KLU wrapper, and several
// of them rely on the solver's state carrying over from one statement to the next, so
// the order of the statements inside each test matters.
#[cfg(test)]
mod tests {
use super::*;
use crate::{ComplexCooMatrix, Ordering, Samples, Scaling};
use russell_lab::{complex_vec_approx_eq, cpx};
// Allocation succeeds and a fresh solver is not flagged as factorized; dropping the
// solver at the end of the test exercises the Drop implementation.
#[test]
fn new_and_drop_work() {
let solver = ComplexSolverKLU::new().unwrap();
assert!(!solver.factorized);
}
// Walks through every validation error of `factorize`, first on a fresh solver and
// then on an initialized one.
#[test]
fn factorize_handles_errors() {
let mut solver = ComplexSolverKLU::new().unwrap();
assert!(!solver.factorized);
// rectangular sample: expected to be rejected as non-square
let (coo, _, _, _) = Samples::complex_rectangular_4x3();
assert_eq!(solver.factorize(&coo, None).err(), Some("the matrix must be square"));
// no value is put into this matrix: expected to be rejected as empty
let coo = ComplexCooMatrix::new(1, 1, 1, Sym::No).unwrap();
assert_eq!(
solver.factorize(&coo, None).err(),
Some("the COO matrix must have at least one non-zero value")
);
// lower-triangular symmetric storage is expected to be refused
let (coo, _, _, _) = Samples::complex_symmetric_3x3_lower();
assert_eq!(
solver.factorize(&coo, None).err(),
Some("KLU requires Sym::YesFull for symmetric matrices")
);
// a valid 2x2 diagonal matrix: this call initializes the solver, so the remaining
// checks in this test exercise the "subsequent factorizations" branch
let mut coo = ComplexCooMatrix::new(2, 2, 2, Sym::No).unwrap();
coo.put(0, 0, cpx!(1.0, 0.0)).unwrap();
coo.put(1, 1, cpx!(2.0, 0.0)).unwrap();
solver.factorize(&coo, None).unwrap();
// same shape and entries, different symmetry flag
let mut coo = ComplexCooMatrix::new(2, 2, 2, Sym::YesFull).unwrap();
coo.put(0, 0, cpx!(1.0, 0.0)).unwrap();
coo.put(1, 1, cpx!(2.0, 0.0)).unwrap();
assert_eq!(
solver.factorize(&coo, None).err(),
Some("subsequent factorizations must use the same matrix (symmetric differs)")
);
// different dimension
let mut coo = ComplexCooMatrix::new(1, 1, 1, Sym::No).unwrap();
coo.put(0, 0, cpx!(1.0, 0.0)).unwrap();
assert_eq!(
solver.factorize(&coo, None).err(),
Some("subsequent factorizations must use the same matrix (ndim differs)")
);
// same dimension, different number of non-zeros
let mut coo = ComplexCooMatrix::new(2, 2, 1, Sym::No).unwrap();
coo.put(0, 0, cpx!(1.0, 0.0)).unwrap();
assert_eq!(
solver.factorize(&coo, None).err(),
Some("subsequent factorizations must use the same matrix (nnz differs)")
);
}
// Factorizes a full symmetric sample and checks the options reported by the solver.
#[test]
fn factorize_works() {
let mut solver = ComplexSolverKLU::new().unwrap();
assert!(!solver.factorized);
let (coo, _, _, _) = Samples::complex_symmetric_3x3_full();
let mut params = LinSolParams::new();
// Metis is requested but AMD is expected below
// (presumably klu_ordering falls back to AMD for this option -- confirm in klu_ordering)
params.ordering = Ordering::Metis;
params.scaling = Scaling::Sum;
solver.factorize(&coo, Some(params)).unwrap();
assert!(solver.factorized);
assert_eq!(solver.effective_ordering, KLU_ORDERING_AMD);
assert_eq!(solver.effective_scaling, KLU_SCALE_SUM);
// calling factorize again with the same matrix must also succeed
solver.factorize(&coo, Some(params)).unwrap();
}
// A diagonal matrix with a zero on the diagonal must make the native factorization fail.
#[test]
fn factorize_fails_on_singular_matrix() {
let mut solver = ComplexSolverKLU::new().unwrap();
let mut coo = ComplexCooMatrix::new(2, 2, 2, Sym::No).unwrap();
coo.put(0, 0, cpx!(1.0, 0.0)).unwrap();
coo.put(1, 1, cpx!(0.0, 0.0)).unwrap();
assert_eq!(solver.factorize(&coo, None), Err("klu_factor failed"));
}
// Checks the guards of `solve`: missing factorization and wrong vector dimensions.
#[test]
fn solve_handles_errors() {
let mut coo = ComplexCooMatrix::new(2, 2, 2, Sym::No).unwrap();
coo.put(0, 0, cpx!(123.0, 1.0)).unwrap();
coo.put(1, 1, cpx!(456.0, 2.0)).unwrap();
let mut solver = ComplexSolverKLU::new().unwrap();
assert!(!solver.factorized);
let mut x = ComplexVector::new(2);
let rhs = ComplexVector::new(2);
// solve before factorize
assert_eq!(
solver.solve(&mut x, &rhs, false),
Err("the function factorize must be called before solve")
);
// x has the wrong dimension
let mut x = ComplexVector::new(1);
solver.factorize(&coo, None).unwrap();
assert_eq!(
solver.solve(&mut x, &rhs, false),
Err("the dimension of the vector of unknown values x is incorrect")
);
// rhs has the wrong dimension
let mut x = ComplexVector::new(2);
let rhs = ComplexVector::new(1);
assert_eq!(
solver.solve(&mut x, &rhs, false),
Err("the dimension of the right-hand side vector is incorrect")
);
}
// Solves a 3x3 system twice with the same factorization and checks the reported stats.
#[test]
fn solve_works() {
let mut solver = ComplexSolverKLU::new().unwrap();
let (coo, _, _, _) = Samples::complex_symmetric_3x3_full();
let mut x = ComplexVector::new(3);
let rhs = ComplexVector::from(&[cpx!(-3.0, 3.0), cpx!(2.0, -2.0), cpx!(9.0, 7.0)]);
let x_correct = &[cpx!(1.0, 1.0), cpx!(2.0, -2.0), cpx!(3.0, 3.0)];
let mut params = LinSolParams::new();
// Cholmod is requested but "Amd" is expected in the stats below
// (presumably klu_ordering falls back to AMD for this option -- confirm in klu_ordering)
params.ordering = Ordering::Cholmod;
params.scaling = Scaling::Max;
solver.factorize(&coo, Some(params)).unwrap();
solver.solve(&mut x, &rhs, false).unwrap();
complex_vec_approx_eq(&x, x_correct, 1e-14);
// a second solve with the same factorization must give the same answer
let mut x_again = ComplexVector::new(3);
solver.solve(&mut x_again, &rhs, false).unwrap();
complex_vec_approx_eq(&x_again, x_correct, 1e-14);
// the stats must report the effective options as strings
let mut stats = StatsLinSol::new();
solver.update_stats(&mut stats);
assert_eq!(stats.output.effective_ordering, "Amd");
assert_eq!(stats.output.effective_scaling, "Max");
}
}