pub(crate) mod eta;
pub(crate) mod lu;
pub(crate) mod refactor;
#[cfg(test)]
pub(crate) mod test_utils;
use crate::error::SolverError;
use crate::sparse::{CscMatrix, SparseVec};
use std::time::Instant;
/// Interface to a factorized basis matrix `B` (a selection of columns of a
/// constraint matrix) that supports solves with `B` and `B^T` and replacement
/// of one basis column at a time.
///
/// Implementors must be `Send`.
pub(crate) trait BasisManager: Send {
/// FTRAN: overwrites `rhs` with the solution `x` of `B x = rhs` (sparse form).
fn ftran(&mut self, rhs: &mut SparseVec);
/// BTRAN: overwrites `rhs` with the solution `y` of `B^T y = rhs` (sparse form).
fn btran(&mut self, rhs: &mut SparseVec);
/// Dense-vector counterpart of [`BasisManager::ftran`].
fn ftran_dense(&mut self, rhs: &mut [f64]);
/// Dense-vector counterpart of [`BasisManager::btran`].
fn btran_dense(&mut self, rhs: &mut [f64]);
/// Replaces the basis column at position `leaving_row` with column
/// `entering_col` of the constraint matrix.
///
/// `pivot_col` is the entering column after it has been passed through
/// FTRAN (this is how this module's tests drive it).
/// NOTE(review): confirm every caller passes the FTRAN'd entering column.
fn update(&mut self, entering_col: usize, leaving_row: usize, pivot_col: &SparseVec);
}
/// [`BasisManager`] backed by an LU factorization of the basis plus a list of
/// eta transformations recorded by basis updates since that factorization.
pub(crate) struct LuBasis {
/// Factorization of the basis columns as of the last (re)factorization.
lu: lu::LuFactorization,
/// Eta transformations pushed by `update` since `lu` was computed; cleared
/// by a successful refactorization.
eta_file: eta::EtaFile,
/// Constraint-matrix column index occupying each basis position.
basis_indices: Vec<usize>,
/// Set when a refactorization reported `SolverError::SingularBasis`.
pub(crate) singular_basis: bool,
/// Set when a refactorization failed for any reason (singularity included).
pub(crate) refactor_failed: bool,
}
impl LuBasis {
#[cfg(test)]
pub fn new(a: &CscMatrix, basis: &[usize], max_etas: usize) -> Result<Self, SolverError> {
Self::new_timed(a, basis, max_etas, None)
}
pub fn new_timed(
a: &CscMatrix,
basis: &[usize],
max_etas: usize,
deadline: Option<std::time::Instant>,
) -> Result<Self, SolverError> {
let lu = lu::LuFactorization::factorize_timed(a, basis, deadline)?;
let effective_max_etas = if max_etas == 0 {
crate::options::default_max_etas(basis.len())
} else {
max_etas
};
Ok(Self {
lu,
eta_file: eta::EtaFile::new(effective_max_etas),
basis_indices: basis.to_vec(),
singular_basis: false,
refactor_failed: false,
})
}
pub(crate) fn needs_refactor(&self) -> bool {
self.eta_file.needs_refactor()
}
pub(crate) fn eta_count(&self) -> usize {
self.eta_file.etas.len()
}
pub(crate) fn force_refactor_timed(
&mut self,
a: &CscMatrix,
basis: &[usize],
deadline: Option<Instant>,
) {
match refactor::refactor_timed(a, basis, deadline) {
Ok(new_lu) => {
self.lu = new_lu;
self.eta_file.etas.clear();
self.basis_indices = basis.to_vec();
}
Err(crate::error::SolverError::SingularBasis { .. }) => {
self.singular_basis = true;
self.refactor_failed = true;
}
Err(_) => {
self.refactor_failed = true;
}
}
}
pub(crate) fn refactor_if_needed_timed(
&mut self,
a: &CscMatrix,
basis: &[usize],
deadline: Option<Instant>,
) {
if self.eta_file.needs_refactor() {
match refactor::refactor_timed(a, basis, deadline) {
Ok(new_lu) => {
self.lu = new_lu;
self.eta_file.etas.clear();
self.basis_indices = basis.to_vec();
}
Err(crate::error::SolverError::SingularBasis { .. }) => {
self.singular_basis = true;
self.refactor_failed = true;
}
Err(_) => {
self.refactor_failed = true;
}
}
}
}
}
impl BasisManager for LuBasis {
    /// Sparse FTRAN: expand to a dense work vector, solve densely, re-sparsify.
    fn ftran(&mut self, rhs: &mut SparseVec) {
        let mut work = rhs.to_dense();
        self.ftran_dense(&mut work);
        *rhs = SparseVec::from_dense(&work);
    }

    /// Sparse BTRAN: expand to a dense work vector, solve densely, re-sparsify.
    fn btran(&mut self, rhs: &mut SparseVec) {
        let mut work = rhs.to_dense();
        self.btran_dense(&mut work);
        *rhs = SparseVec::from_dense(&work);
    }

    /// Dense FTRAN: the LU solve comes first, then the recorded etas.
    fn ftran_dense(&mut self, rhs: &mut [f64]) {
        let etas = &self.eta_file.etas;
        lu::solve_ftran(&self.lu, rhs);
        eta::apply_ftran(etas, rhs);
    }

    /// Dense BTRAN: mirror image of FTRAN, so the etas come before the LU solve.
    fn btran_dense(&mut self, rhs: &mut [f64]) {
        let etas = &self.eta_file.etas;
        eta::apply_btran(etas, rhs);
        lu::solve_btran(&self.lu, rhs);
    }

    /// Records the column swap as a new eta, then updates the position map.
    fn update(&mut self, entering_col: usize, leaving_row: usize, pivot_col: &SparseVec) {
        self.eta_file
            .etas
            .push(eta::add_eta_sparse(pivot_col, leaving_row));
        self.basis_indices[leaving_row] = entering_col;
    }
}
#[cfg(test)]
mod tests {
    use super::test_utils::*;
    use super::*;

    /// Symmetric, strictly diagonally dominant (hence nonsingular) 3x3 matrix
    /// shared by the tests that use all three of its columns as the basis.
    fn tridiag_3x3() -> CscMatrix {
        let dense = vec![
            vec![2.0, 1.0, 0.0],
            vec![1.0, 3.0, 1.0],
            vec![0.0, 1.0, 2.0],
        ];
        dense_to_csc(&dense, 3, 3)
    }

    #[test]
    fn test_lu_basis_ftran_btran() {
        let a = tridiag_3x3();
        let basis = vec![0, 1, 2];
        let mut lb = LuBasis::new(&a, &basis, 50).unwrap();
        let rhs_orig = vec![3.0, 5.0, 3.0];

        // FTRAN must solve B x = rhs.
        let mut rhs_sv = SparseVec::from_dense(&rhs_orig);
        lb.ftran(&mut rhs_sv);
        let x = rhs_sv.to_dense();
        let check = a.mat_vec_mul(&x).unwrap();
        assert_vec_near(&check, &rhs_orig, 1e-10);

        // BTRAN must solve B^T y = rhs.
        let mut rhs_sv2 = SparseVec::from_dense(&rhs_orig);
        lb.btran(&mut rhs_sv2);
        let y = rhs_sv2.to_dense();
        let bt = a.transpose();
        let check2 = bt.mat_vec_mul(&y).unwrap();
        assert_vec_near(&check2, &rhs_orig, 1e-10);
    }

    #[test]
    fn test_lu_basis_update() {
        let dense = vec![
            vec![2.0, 1.0, 0.0, 3.0],
            vec![1.0, 3.0, 1.0, 1.0],
            vec![0.0, 1.0, 2.0, 2.0],
        ];
        let a = dense_to_csc(&dense, 3, 4);
        let basis = vec![0, 1, 2];
        let mut lb = LuBasis::new(&a, &basis, 50).unwrap();

        // Column 3 of `a` enters, replacing basis position 1.
        let entering_col_dense = vec![3.0, 1.0, 2.0];
        let mut pivot_sv = SparseVec::from_dense(&entering_col_dense);
        lb.ftran(&mut pivot_sv);
        lb.update(3, 1, &pivot_sv);
        assert_eq!(lb.eta_count(), 1, "update should record exactly one eta");

        let rhs_orig = vec![5.0, 2.0, 4.0];
        let mut rhs_sv = SparseVec::from_dense(&rhs_orig);
        lb.ftran(&mut rhs_sv);
        let x = rhs_sv.to_dense();

        // Updated basis is [col0, col3, col2] of `a`.
        let b_new_dense = vec![
            vec![2.0, 3.0, 0.0],
            vec![1.0, 1.0, 1.0],
            vec![0.0, 2.0, 2.0],
        ];
        let b_new = dense_to_csc(&b_new_dense, 3, 3);
        let check = b_new.mat_vec_mul(&x).unwrap();
        assert_vec_near(&check, &rhs_orig, 1e-10);
    }

    #[test]
    fn test_lu_basis_refactor_after_50_etas() {
        let a = tridiag_3x3();
        let basis = vec![0, 1, 2];
        let mut lb = LuBasis::new(&a, &basis, 50).unwrap();
        assert!(
            !lb.eta_file.needs_refactor(),
            "Initially should not need refactor"
        );

        // Identity etas: they fill the eta file without changing the basis.
        for i in 0..50 {
            let r = i % 3;
            let mut pivot = vec![0.0f64, 0.0, 0.0];
            pivot[r] = 1.0;
            lb.eta_file.etas.push(eta::add_eta(&pivot, r));
        }
        assert!(
            lb.eta_file.needs_refactor(),
            "50 etas with max_etas=50 should trigger refactor"
        );

        lb.refactor_if_needed_timed(&a, &basis, None);
        assert!(
            !lb.eta_file.needs_refactor(),
            "After refactor, should not need refactor"
        );
        assert_eq!(
            lb.eta_file.etas.len(),
            0,
            "Etas should be cleared after refactor"
        );
        assert!(!lb.refactor_failed, "Refactor of a nonsingular basis must succeed");
        assert!(!lb.singular_basis, "Nonsingular basis must not be flagged singular");

        let rhs_orig = vec![3.0, 5.0, 3.0];
        let mut rhs_sv = SparseVec::from_dense(&rhs_orig);
        lb.ftran(&mut rhs_sv);
        let x = rhs_sv.to_dense();
        let check = a.mat_vec_mul(&x).unwrap();
        assert_vec_near(&check, &rhs_orig, 1e-10);

        let bt = a.transpose();
        let mut rhs_sv2 = SparseVec::from_dense(&rhs_orig);
        lb.btran(&mut rhs_sv2);
        let y = rhs_sv2.to_dense();
        let check2 = bt.mat_vec_mul(&y).unwrap();
        assert_vec_near(&check2, &rhs_orig, 1e-10);
    }

    #[test]
    fn test_lu_basis_refactor() {
        let a = tridiag_3x3();
        let basis = vec![0, 1, 2];
        let mut lb = LuBasis::new(&a, &basis, 50).unwrap();
        lb.eta_file.max_etas = 2;
        lb.eta_file.etas.push(eta::add_eta(&[1.0, 0.0, 0.0], 0));
        lb.eta_file.etas.push(eta::add_eta(&[0.0, 1.0, 0.0], 1));
        assert!(lb.eta_file.needs_refactor());

        lb.refactor_if_needed_timed(&a, &basis, None);
        assert!(!lb.eta_file.needs_refactor());
        assert_eq!(lb.eta_file.etas.len(), 0);
        assert!(!lb.refactor_failed);

        let rhs_orig = vec![3.0, 5.0, 3.0];
        let mut rhs_sv = SparseVec::from_dense(&rhs_orig);
        lb.ftran(&mut rhs_sv);
        let x = rhs_sv.to_dense();
        let check = a.mat_vec_mul(&x).unwrap();
        assert_vec_near(&check, &rhs_orig, 1e-10);
    }

    #[test]
    fn test_refactor_if_needed_is_noop_below_limit() {
        let a = tridiag_3x3();
        let basis = vec![0, 1, 2];
        let mut lb = LuBasis::new(&a, &basis, 50).unwrap();
        lb.eta_file.etas.push(eta::add_eta(&[1.0, 0.0, 0.0], 0));
        assert!(!lb.needs_refactor(), "One eta is far below max_etas=50");

        lb.refactor_if_needed_timed(&a, &basis, None);
        assert_eq!(
            lb.eta_count(),
            1,
            "No refactor should happen, so the eta must be kept"
        );
        assert!(!lb.refactor_failed);
    }
}