use crate::error::FactorizationError;
use crate::sparse_sym_iface::SparseSymLinearSolverInterface;
use crate::t_sym_solver::TSymLinearSolver;
use pounce_common::types::{Index, Number};
/// Numeric factorization of a sparse symmetric matrix given in triplet form.
///
/// The object keeps its own copy of the numeric values so it can be solved against
/// repeatedly and refactored with new values on the same sparsity pattern.
pub struct Factorization {
    /// Order of the square matrix.
    dim: Index,
    /// Number of stored triplet entries (always equal to `values.len()`).
    nnz: Index,
    /// Current numeric values, in the triplet order supplied at construction.
    values: Vec<Number>,
    /// Set only after the current `values` were factored successfully; cleared while a
    /// (re)factorization is pending or has failed.
    inertia_known: bool,
    /// Solver driver; it owns the sparse backend passed to the constructor.
    inner: TSymLinearSolver,
}
impl std::fmt::Debug for Factorization {
    /// Prints only the shape metadata and the inertia flag. The numeric values and the
    /// solver backend are deliberately left out, which `finish_non_exhaustive` signals
    /// with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            dim,
            nnz,
            inertia_known,
            ..
        } = self;
        let mut out = f.debug_struct("Factorization");
        out.field("dim", dim);
        out.field("nnz", nnz);
        out.field("inertia_known", inertia_known);
        out.finish_non_exhaustive()
    }
}
impl Factorization {
pub fn new(
dim: Index,
airn: Vec<Index>,
ajcn: Vec<Index>,
values: Vec<Number>,
backend: Box<dyn SparseSymLinearSolverInterface>,
) -> Result<Self, FactorizationError> {
assert_eq!(
airn.len(),
ajcn.len(),
"airn and ajcn must have same length"
);
assert_eq!(values.len(), airn.len(), "values must match nnz");
let nnz = airn.len() as Index;
let mut inner = TSymLinearSolver::new(backend, None, false);
FactorizationError::from_status(inner.initialize_structure(dim, &airn, &ajcn))?;
let mut me = Self {
inner,
dim,
nnz,
values,
inertia_known: false,
};
me.do_factor()?;
Ok(me)
}
pub fn solve(&mut self, rhs: &mut [Number], nrhs: usize) -> Result<(), FactorizationError> {
assert_eq!(
rhs.len(),
self.dim as usize * nrhs,
"rhs length must equal dim * nrhs"
);
let status = self.inner.multi_solve(
&self.values,
false, nrhs as Index,
rhs,
false,
0,
);
FactorizationError::from_status(status)
}
pub fn solve_one(&mut self, rhs: &mut [Number]) -> Result<(), FactorizationError> {
self.solve(rhs, 1)
}
pub fn refactor(&mut self, new_values: &[Number]) -> Result<(), FactorizationError> {
assert_eq!(
new_values.len(),
self.nnz as usize,
"new_values length must equal nnz",
);
self.values.copy_from_slice(new_values);
self.inertia_known = false;
self.do_factor()
}
pub fn number_of_neg_evals(&self) -> Option<Index> {
use crate::sym_solver::SymLinearSolver;
if self.inertia_known && self.inner.provides_inertia() {
Some(self.inner.number_of_neg_evals())
} else {
None
}
}
pub fn dim(&self) -> Index {
self.dim
}
pub fn nnz(&self) -> Index {
self.nnz
}
fn do_factor(&mut self) -> Result<(), FactorizationError> {
let mut dummy_rhs = vec![0.0; self.dim as usize];
let status = self.inner.multi_solve(
&self.values,
true, 1,
&mut dummy_rhs,
false,
0,
);
FactorizationError::from_status(status)?;
self.inertia_known = true;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_sym_iface::EMatrixFormat;
    use crate::status::ESymSolverStatus;

    /// Dense in-memory test double for `SparseSymLinearSolverInterface`.
    ///
    /// It expands the stored triplets into a full dense matrix and factors it with
    /// Gaussian elimination using partial (row) pivoting. This is enough to check that
    /// `Factorization` drives a backend correctly on tiny matrices.
    struct DenseLuBackend {
        // Matrix order and number of stored triplets, set by `initialize_structure`.
        dim: usize,
        nnz: usize,
        // Triplet row/column indices (1-based, see `assemble_dense`) and the values
        // array exposed through `values_array_mut`.
        rows: Vec<Index>,
        cols: Vec<Index>,
        values: Vec<Number>,
        // Factor from the most recent successful `factor_dense`, if any.
        factor: Option<DenseLu>,
    }

    /// Packed LU factors produced by `DenseLuBackend::factor_dense`.
    struct DenseLu {
        // Physical row `perm[i]` holds logical (pivot-order) row `i`. Its entries in
        // columns `< i` are the elimination multipliers (unit lower-triangular L), and
        // its entries in columns `>= i` belong to U.
        a: Vec<Vec<f64>>,
        perm: Vec<usize>,
        // Count of negative U pivots, reported as the number of negative eigenvalues.
        neg_evals: Index,
    }

    impl DenseLuBackend {
        /// Creates an empty backend; its shape is populated by `initialize_structure`.
        fn new() -> Self {
            Self {
                dim: 0,
                nnz: 0,
                rows: Vec::new(),
                cols: Vec::new(),
                values: Vec::new(),
                factor: None,
            }
        }

        /// Expands the 1-based triplets into a dense `dim x dim` matrix.
        ///
        /// Each off-diagonal entry is mirrored across the diagonal, so only one triangle
        /// should be stored. Duplicate triplets are summed.
        fn assemble_dense(&self) -> Vec<Vec<f64>> {
            let n = self.dim;
            let mut a = vec![vec![0.0; n]; n];
            for k in 0..self.nnz {
                let i = (self.rows[k] - 1) as usize;
                let j = (self.cols[k] - 1) as usize;
                a[i][j] += self.values[k];
                if i != j {
                    a[j][i] += self.values[k];
                }
            }
            a
        }

        /// Factors the assembled matrix with partial pivoting.
        ///
        /// Returns `Singular` when the largest remaining magnitude in a pivot column is
        /// below 1e-300; in that case any previously stored factor is left as it was.
        /// Otherwise it stores the new factor and returns `Success`.
        fn factor_dense(&mut self) -> ESymSolverStatus {
            let n = self.dim;
            let mut a = self.assemble_dense();
            // Rows are permuted logically through `perm` rather than swapped in `a`.
            let mut perm: Vec<usize> = (0..n).collect();
            for k in 0..n {
                // Pick the not-yet-eliminated row with the largest |a[.][k]|.
                let mut p = k;
                let mut maxv = a[perm[k]][k].abs();
                for i in (k + 1)..n {
                    let v = a[perm[i]][k].abs();
                    if v > maxv {
                        maxv = v;
                        p = i;
                    }
                }
                if maxv < 1e-300 {
                    return ESymSolverStatus::Singular;
                }
                perm.swap(k, p);
                let pk = perm[k];
                for &pi in &perm[(k + 1)..n] {
                    // The multiplier is stored in the slot it eliminates (the L part).
                    let factor = a[pi][k] / a[pk][k];
                    a[pi][k] = factor;
                    // Index loop on purpose: it reads row `pk` while writing row `pi` of
                    // the same `Vec<Vec<_>>`, which iterators over both rows cannot borrow
                    // at the same time.
                    #[allow(clippy::needless_range_loop)]
                    for j in (k + 1)..n {
                        a[pi][j] -= factor * a[pk][j];
                    }
                }
            }
            // NOTE(review): this counts negative U pivots. That equals the true inertia
            // only when no row interchange happened (then the elimination is an
            // unpivoted LDL^T and Sylvester's law applies). With interchanges it can be
            // wrong even for SPD input: [[1, 2], [2, 5]] yields pivots 2 and -0.5. The
            // inertia test below uses a matrix that needs no interchange.
            let mut neg = 0;
            for k in 0..n {
                if a[perm[k]][k] < 0.0 {
                    neg += 1;
                }
            }
            self.factor = Some(DenseLu {
                a,
                perm,
                neg_evals: neg as Index,
            });
            ESymSolverStatus::Success
        }

        /// Solves `A x = b` in place with the stored factor.
        ///
        /// The steps are: gather `b` in pivot order, forward-substitute with the unit
        /// lower-triangular L, then back-substitute with U. Panics (via `unwrap`) if no
        /// factor has been computed yet.
        fn solve_one(&self, b: &mut [f64]) {
            let factor = self.factor.as_ref().unwrap();
            let n = self.dim;
            // Apply the row permutation: x = P b.
            let mut x: Vec<f64> = factor.perm.iter().map(|&p| b[p]).collect();
            // Forward substitution with L (its diagonal is implicitly 1).
            for i in 0..n {
                let pi = factor.perm[i];
                for j in 0..i {
                    x[i] -= factor.a[pi][j] * x[j];
                }
            }
            // Back substitution with U.
            for i in (0..n).rev() {
                let pi = factor.perm[i];
                for j in (i + 1)..n {
                    x[i] -= factor.a[pi][j] * x[j];
                }
                x[i] /= factor.a[pi][i];
            }
            b.copy_from_slice(&x);
        }
    }

    impl SparseSymLinearSolverInterface for DenseLuBackend {
        /// Records the triplet structure and allocates a zeroed values array of length
        /// `nonzeros`.
        fn initialize_structure(
            &mut self,
            dim: Index,
            nonzeros: Index,
            ia: &[Index],
            ja: &[Index],
        ) -> ESymSolverStatus {
            self.dim = dim as usize;
            self.nnz = nonzeros as usize;
            self.rows = ia.to_vec();
            self.cols = ja.to_vec();
            self.values = vec![0.0; self.nnz];
            ESymSolverStatus::Success
        }

        /// Mutable view of the values array. Presumably the caller fills it before a
        /// `multi_solve` with `new_matrix == true` (not visible here - TODO confirm).
        fn values_array_mut(&mut self) -> &mut [Number] {
            &mut self.values
        }

        /// Optionally refactors, then solves the packed right-hand sides in place.
        ///
        /// When `new_matrix` is set, the matrix is refactored from the current values.
        /// With `check_neg_evals`, the result is `WrongInertia` unless the negative-pivot
        /// count equals `number_of_neg_evals`. The `nrhs` right-hand sides are then
        /// solved in place; column `k` is `rhs_vals[k * dim .. (k + 1) * dim]`.
        fn multi_solve(
            &mut self,
            new_matrix: bool,
            _ia: &[Index],
            _ja: &[Index],
            nrhs: Index,
            rhs_vals: &mut [Number],
            check_neg_evals: bool,
            number_of_neg_evals: Index,
        ) -> ESymSolverStatus {
            if new_matrix {
                let s = self.factor_dense();
                if s != ESymSolverStatus::Success {
                    return s;
                }
                if check_neg_evals {
                    let actual = self.factor.as_ref().unwrap().neg_evals;
                    if actual != number_of_neg_evals {
                        return ESymSolverStatus::WrongInertia;
                    }
                }
            }
            let n = self.dim;
            for k in 0..nrhs as usize {
                let base = k * n;
                self.solve_one(&mut rhs_vals[base..base + n]);
            }
            ESymSolverStatus::Success
        }

        /// Negative-pivot count of the stored factor, or 0 before any factorization.
        fn number_of_neg_evals(&self) -> Index {
            self.factor.as_ref().map(|f| f.neg_evals).unwrap_or(0)
        }

        /// This mock has no way to improve accuracy, so it always reports failure.
        fn increase_quality(&mut self) -> bool {
            false
        }

        fn provides_inertia(&self) -> bool {
            true
        }

        fn matrix_format(&self) -> EMatrixFormat {
            EMatrixFormat::TripletFormat
        }
    }

    /// A = [[2, 1], [1, 3]], stored as one triangle; A * [1, 1]^T = [3, 4]^T.
    #[test]
    fn factors_spd_2x2_and_solves_one_rhs() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let values = vec![2.0, 1.0, 3.0];
        let mut f =
            Factorization::new(2, airn, ajcn, values, Box::new(DenseLuBackend::new())).unwrap();
        let mut rhs = vec![3.0, 4.0];
        f.solve_one(&mut rhs).unwrap();
        assert!((rhs[0] - 1.0).abs() < 1e-12);
        assert!((rhs[1] - 1.0).abs() < 1e-12);
    }

    /// Three right-hand sides packed column by column must give the same answers as
    /// three single solves against an identical factorization.
    #[test]
    fn packed_multi_rhs_matches_one_at_a_time() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let values = vec![2.0, 1.0, 3.0];
        let backend1 = Box::new(DenseLuBackend::new());
        let backend2 = Box::new(DenseLuBackend::new());
        let mut f1 =
            Factorization::new(2, airn.clone(), ajcn.clone(), values.clone(), backend1).unwrap();
        let mut f2 = Factorization::new(2, airn, ajcn, values, backend2).unwrap();
        let mut packed = vec![
            3.0, 4.0, // column 0
            5.0, 5.0, // column 1
            2.0, 6.0, // column 2
        ];
        f1.solve(&mut packed, 3).unwrap();
        let mut col0 = vec![3.0, 4.0];
        let mut col1 = vec![5.0, 5.0];
        let mut col2 = vec![2.0, 6.0];
        f2.solve_one(&mut col0).unwrap();
        f2.solve_one(&mut col1).unwrap();
        f2.solve_one(&mut col2).unwrap();
        for (i, &v) in col0.iter().enumerate() {
            assert!((packed[i] - v).abs() < 1e-12, "col0 mismatch at {i}");
        }
        for (i, &v) in col1.iter().enumerate() {
            assert!((packed[2 + i] - v).abs() < 1e-12, "col1 mismatch at {i}");
        }
        for (i, &v) in col2.iter().enumerate() {
            assert!((packed[4 + i] - v).abs() < 1e-12, "col2 mismatch at {i}");
        }
    }

    /// After refactoring to A = [[4, 1], [1, 5]], the solution must satisfy the new
    /// system; this is checked through the residual rather than a closed form.
    #[test]
    fn refactor_yields_correct_solution_for_new_values() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let mut f = Factorization::new(
            2,
            airn,
            ajcn,
            vec![2.0, 1.0, 3.0],
            Box::new(DenseLuBackend::new()),
        )
        .unwrap();
        f.refactor(&[4.0, 1.0, 5.0]).unwrap();
        let mut rhs = vec![5.0, 6.0];
        f.solve_one(&mut rhs).unwrap();
        let r0 = 4.0 * rhs[0] + rhs[1] - 5.0;
        let r1 = rhs[0] + 5.0 * rhs[1] - 6.0;
        assert!(r0.abs() < 1e-10);
        assert!(r1.abs() < 1e-10);
    }

    /// [[1, 1], [1, 1]] is singular. The backend's `Singular` status must surface from
    /// `Factorization::new` as `FactorizationError::Singular`.
    #[test]
    fn singular_matrix_returns_singular_error() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let err = Factorization::new(
            2,
            airn,
            ajcn,
            vec![1.0, 1.0, 1.0],
            Box::new(DenseLuBackend::new()),
        )
        .unwrap_err();
        assert_eq!(err, FactorizationError::Singular);
    }

    /// `solve_one` must be bit-for-bit identical to `solve` with `nrhs == 1`.
    #[test]
    fn solve_one_matches_solve_with_nrhs_one() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let values = vec![2.0, 1.0, 3.0];
        let mut f1 = Factorization::new(
            2,
            airn.clone(),
            ajcn.clone(),
            values.clone(),
            Box::new(DenseLuBackend::new()),
        )
        .unwrap();
        let mut f2 =
            Factorization::new(2, airn, ajcn, values, Box::new(DenseLuBackend::new())).unwrap();
        let mut rhs1 = vec![3.0, 4.0];
        let mut rhs2 = vec![3.0, 4.0];
        f1.solve_one(&mut rhs1).unwrap();
        f2.solve(&mut rhs2, 1).unwrap();
        assert_eq!(rhs1, rhs2);
    }

    /// [[2, 1], [1, 3]] is positive definite (trace 5, determinant 5), so it has no
    /// negative eigenvalues. Also checks the `dim` / `nnz` accessors.
    #[test]
    fn inertia_is_reported_when_backend_provides_it() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let f = Factorization::new(
            2,
            airn,
            ajcn,
            vec![2.0, 1.0, 3.0],
            Box::new(DenseLuBackend::new()),
        )
        .unwrap();
        assert_eq!(f.number_of_neg_evals(), Some(0));
        assert_eq!(f.dim(), 2);
        assert_eq!(f.nnz(), 3);
    }
}