use crate::error::FeralError;
use crate::inertia::Inertia;
use crate::numeric::condition::estimate_condition_1norm;
use crate::numeric::factorize::{
factorize_multifrontal_with_workspace, FactorWorkspace, NumericParams, SparseFactors,
};
use crate::numeric::solve::{solve_sparse, solve_sparse_many, solve_sparse_refined};
use crate::scaling::ScalingStrategy;
use crate::sparse::csc::CscMatrix;
use crate::symbolic::supernode::SupernodeParams;
use crate::symbolic::{symbolic_factorize, SymbolicFactorization};
/// Result of a single [`Solver::factor`] call.
#[derive(Debug)]
pub enum FactorStatus {
    /// The factorization finished and, when an expected inertia was supplied, it matched.
    Success,
    /// The numeric phase reported the matrix as (numerically) singular, either through a
    /// rank-deficiency error or through partial-singular scaling information.
    Singular,
    /// The factorization finished, but its inertia differs from the one the caller requested.
    WrongInertia {
        /// Inertia actually produced by the factorization.
        actual: Inertia,
        /// Inertia the caller passed to `factor`.
        expected: Inertia,
    },
    /// Any other error raised by the symbolic or numeric phase.
    FatalError(FeralError),
}
/// Stage reached by [`Solver::increase_quality`]'s escalation ladder.
///
/// The ladder only moves forward: `Baseline` -> `ScalingEnabled` (when scaling was off)
/// -> `PivotRaised` (possibly repeatedly) -> `Exhausted`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QualityLevel {
    /// Initial state: the caller's parameters are used unchanged.
    Baseline,
    /// Identity scaling was replaced by infinity-norm scaling.
    ScalingEnabled,
    /// The Bunch–Kaufman pivot threshold has been raised at least once.
    PivotRaised,
    /// The pivot threshold reached its cap; no further escalation is possible.
    Exhausted,
}
/// Cheap, copyable summary of a matrix's sparsity pattern.
///
/// Used to decide whether a cached symbolic factorization can be reused. Values are not
/// hashed, so two matrices with identical `col_ptr`/`row_idx` fingerprint identically.
///
/// NOTE(review): equality relies on a 64-bit hash (plus `n` and `nnz`). A collision
/// between different patterns would silently reuse the wrong symbolic analysis. This is
/// unlikely for non-adversarial input, but it is not impossible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct PatternFingerprint {
    /// Matrix dimension.
    n: usize,
    /// Number of stored entries (`row_idx.len()`).
    nnz: usize,
    /// Hash of `col_ptr` followed by `row_idx`.
    structural_hash: u64,
}
impl PatternFingerprint {
    /// Computes the fingerprint of `matrix`'s structure (ignoring numeric values).
    fn of(matrix: &CscMatrix) -> Self {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // The hash order (column pointers first, then row indices) must stay the same.
        let structural_hash = {
            let mut hasher = DefaultHasher::new();
            matrix.col_ptr.hash(&mut hasher);
            matrix.row_idx.hash(&mut hasher);
            hasher.finish()
        };
        PatternFingerprint {
            n: matrix.n,
            nnz: matrix.row_idx.len(),
            structural_hash,
        }
    }
}
/// Stateful sparse symmetric solver.
///
/// It caches the symbolic analysis across calls with an unchanged sparsity pattern. It also
/// keeps the most recent numeric factors and inertia for subsequent solves.
pub struct Solver {
    /// Parameters for the numeric (multifrontal) phase; mutated by `increase_quality`.
    numeric_params: NumericParams,
    /// Parameters for supernode detection in the symbolic phase.
    snode_params: SupernodeParams,
    /// Upper bound for the pivot threshold during quality escalation.
    pivtol_max: f64,
    /// Current stage of the quality-escalation ladder.
    quality_level: QualityLevel,
    /// Cached symbolic analysis, valid for `last_pattern_fingerprint`.
    last_symbolic: Option<SymbolicFactorization>,
    /// Factors from the most recent successful numeric factorization, if any.
    last_factors: Option<SparseFactors>,
    /// Inertia paired with `last_factors`.
    last_inertia: Option<Inertia>,
    /// Fingerprint of the pattern `last_symbolic` was computed for.
    last_pattern_fingerprint: Option<PatternFingerprint>,
    /// Number of symbolic analyses actually performed (cache misses).
    symbolic_call_count: usize,
    /// Scratch buffers reused across numeric factorizations.
    workspace: FactorWorkspace,
}
impl Solver {
    /// Creates a solver with default numeric and supernode parameters.
    pub fn new() -> Self {
        Self::with_params(NumericParams::default(), SupernodeParams::default())
    }

    /// Creates a solver with explicit numeric (`np`) and supernode (`sn`) parameters.
    ///
    /// The solver starts at [`QualityLevel::Baseline`]. It has no cached symbolic analysis,
    /// factors or inertia, and the pivot-threshold cap is `0.5`.
    pub fn with_params(np: NumericParams, sn: SupernodeParams) -> Self {
        Self {
            numeric_params: np,
            snode_params: sn,
            pivtol_max: 0.5,
            quality_level: QualityLevel::Baseline,
            last_symbolic: None,
            last_factors: None,
            last_inertia: None,
            last_pattern_fingerprint: None,
            symbolic_call_count: 0,
            workspace: FactorWorkspace::new(),
        }
    }

    /// Factorizes `matrix`. The symbolic analysis is reused when the sparsity pattern is
    /// unchanged since the previous call.
    ///
    /// If `check_inertia` is `Some`, the computed inertia is compared against it.
    ///
    /// Returns:
    /// - `Singular` when the numeric phase reports rank deficiency. The stored factors are
    ///   cleared in that case.
    /// - `Singular` when the scaling info is partial-singular. The factors and inertia are
    ///   kept in that case.
    /// - `WrongInertia` when the inertia does not match `check_inertia`.
    /// - `FatalError` for any other symbolic or numeric failure.
    /// - `Success` otherwise.
    pub fn factor(&mut self, matrix: &CscMatrix, check_inertia: Option<Inertia>) -> FactorStatus {
        let fp = PatternFingerprint::of(matrix);
        if self.last_pattern_fingerprint != Some(fp) {
            // New (or first) pattern: every cached artifact is stale.
            self.last_symbolic = None;
            self.last_factors = None;
            self.last_inertia = None;
            self.last_pattern_fingerprint = None;
        }
        if self.last_symbolic.is_none() {
            match symbolic_factorize(matrix, &self.snode_params) {
                Ok(sym) => {
                    self.symbolic_call_count += 1;
                    self.last_symbolic = Some(sym);
                    self.last_pattern_fingerprint = Some(fp);
                }
                Err(e) => return FactorStatus::FatalError(e),
            }
        }
        let symbolic = self
            .last_symbolic
            .as_ref()
            .expect("symbolic analysis is populated above or we returned early");

        match factorize_multifrontal_with_workspace(
            matrix,
            symbolic,
            &self.numeric_params,
            &mut self.workspace,
        ) {
            Ok((factors, inertia)) => {
                let partial_singular = matches!(
                    factors.scaling_info,
                    crate::scaling::ScalingInfo::PartialSingular { .. }
                );
                self.last_factors = Some(factors);
                self.last_inertia = Some(inertia.clone());
                if partial_singular {
                    return FactorStatus::Singular;
                }
                match check_inertia {
                    Some(expected) if inertia != expected => FactorStatus::WrongInertia {
                        actual: inertia,
                        expected,
                    },
                    _ => FactorStatus::Success,
                }
            }
            Err(e) => {
                // Any numeric failure invalidates previously stored factors.
                self.last_factors = None;
                self.last_inertia = None;
                match e {
                    FeralError::NumericallyRankDeficient => FactorStatus::Singular,
                    other => FactorStatus::FatalError(other),
                }
            }
        }
    }

    /// Returns the stored factors, or `FeralError::NoFactor` if none are available.
    fn require_factors(&self) -> Result<&SparseFactors, FeralError> {
        self.last_factors.as_ref().ok_or(FeralError::NoFactor)
    }

    /// Solves `A x = rhs` with the stored factors.
    ///
    /// # Errors
    /// `FeralError::NoFactor` if no factors are stored. Errors from the solve are propagated.
    pub fn solve(&self, rhs: &[f64]) -> Result<Vec<f64>, FeralError> {
        solve_sparse(self.require_factors()?, rhs)
    }

    /// Solves `A x = rhs` with the stored factors plus iterative refinement against `matrix`.
    ///
    /// # Errors
    /// `FeralError::NoFactor` if no factors are stored. Errors from the solve are propagated.
    pub fn solve_refined(&self, matrix: &CscMatrix, rhs: &[f64]) -> Result<Vec<f64>, FeralError> {
        solve_sparse_refined(matrix, self.require_factors()?, rhs)
    }

    /// Solves for `nrhs` right-hand sides stored in `rhs`.
    ///
    /// The layout of `rhs` is defined by `solve_sparse_many`.
    ///
    /// # Errors
    /// `FeralError::NoFactor` if no factors are stored. Errors from the solve are propagated.
    pub fn solve_many(&self, rhs: &[f64], nrhs: usize) -> Result<Vec<f64>, FeralError> {
        solve_sparse_many(self.require_factors()?, rhs, nrhs)
    }

    /// Solves for `nrhs` right-hand sides, each with iterative refinement.
    ///
    /// `rhs` holds `nrhs` consecutive blocks of length `n`. The result uses the same layout.
    /// `nrhs == 0` yields an empty vector.
    ///
    /// # Errors
    /// - `FeralError::NoFactor` if no factors are stored.
    /// - `FeralError::DimensionMismatch` if `rhs.len() != n * nrhs`. This includes the case
    ///   where that product overflows.
    /// - The first error returned by a per-column refined solve.
    pub fn solve_many_refined(
        &self,
        matrix: &CscMatrix,
        rhs: &[f64],
        nrhs: usize,
    ) -> Result<Vec<f64>, FeralError> {
        let factors = self.require_factors()?;
        if nrhs == 0 {
            return Ok(Vec::new());
        }
        let n = factors.n;
        // On overflow this saturates to usize::MAX. No real slice can have that length, so
        // the overflow is reported as a mismatch instead of panicking (debug) or wrapping
        // (release).
        let expected = n.saturating_mul(nrhs);
        if rhs.len() != expected {
            return Err(FeralError::DimensionMismatch {
                expected,
                got: rhs.len(),
            });
        }
        let mut out = Vec::with_capacity(expected);
        for c in 0..nrhs {
            let column = &rhs[c * n..(c + 1) * n];
            let x = solve_sparse_refined(matrix, factors, column)?;
            debug_assert_eq!(x.len(), n, "refined solve returned wrong length");
            out.extend_from_slice(&x);
        }
        Ok(out)
    }

    /// Estimates the 1-norm condition number of `matrix` using the stored factors.
    ///
    /// # Errors
    /// `FeralError::NoFactor` if no factors are stored. Errors from the estimator are
    /// propagated.
    pub fn estimate_condition_1norm(&self, matrix: &CscMatrix) -> Result<f64, FeralError> {
        estimate_condition_1norm(matrix, self.require_factors()?)
    }

    /// Escalates the factorization parameters by one step.
    ///
    /// Returns `false` once [`QualityLevel::Exhausted`] has been reached.
    ///
    /// At `Baseline` with identity scaling, the first step switches to infinity-norm
    /// scaling. Every other step raises the pivot threshold: `0 -> 0.01`, and afterwards
    /// `t -> t^0.75`. The threshold is capped at `pivtol_max` and is never lowered.
    pub fn increase_quality(&mut self) -> bool {
        const FIRST_PIVOT_THRESHOLD: f64 = 0.01;
        const PIVOT_EXPONENT: f64 = 0.75;
        const EPS_CAP: f64 = 1e-12;
        match self.quality_level {
            QualityLevel::Exhausted => return false,
            QualityLevel::Baseline
                if matches!(self.numeric_params.scaling, ScalingStrategy::Identity) =>
            {
                self.numeric_params.scaling = ScalingStrategy::InfNorm;
                self.quality_level = QualityLevel::ScalingEnabled;
            }
            QualityLevel::Baseline | QualityLevel::ScalingEnabled | QualityLevel::PivotRaised => {
                self.bump_pivot_threshold(FIRST_PIVOT_THRESHOLD, PIVOT_EXPONENT, EPS_CAP);
            }
        }
        true
    }

    /// Raises the pivot threshold one step and updates `quality_level` accordingly.
    ///
    /// A zero threshold jumps to `first_jump`. Otherwise the threshold becomes
    /// `t^exponent`, clamped to `pivtol_max`. The level becomes `Exhausted` once the
    /// threshold is within `eps_cap` of the cap.
    fn bump_pivot_threshold(&mut self, first_jump: f64, exponent: f64, eps_cap: f64) {
        let cap = self.pivtol_max;
        let pivtol = &mut self.numeric_params.bk.pivot_threshold;
        *pivtol = if *pivtol == 0.0 {
            first_jump.min(cap)
        } else {
            // `.max(*pivtol)` keeps the update monotone. A caller-supplied threshold already
            // above the cap must not be lowered by a request for *more* quality. Clamping it
            // down would weaken pivoting.
            pivtol.powf(exponent).min(cap).max(*pivtol)
        };
        self.quality_level = if *pivtol >= cap - eps_cap {
            QualityLevel::Exhausted
        } else {
            QualityLevel::PivotRaised
        };
    }

    /// Current Bunch–Kaufman pivot threshold.
    pub fn pivot_threshold(&self) -> f64 {
        self.numeric_params.bk.pivot_threshold
    }

    /// Scaling strategy currently used by the numeric phase.
    pub fn scaling_strategy(&self) -> &ScalingStrategy {
        &self.numeric_params.scaling
    }

    /// Number of negative eigenvalues from the last stored inertia.
    ///
    /// # Panics
    /// Panics if no inertia is stored. That is the case before the first `factor` call, and
    /// after a `factor` call whose numeric phase failed.
    pub fn num_negative_eigenvalues(&self) -> usize {
        match &self.last_inertia {
            Some(i) => i.negative,
            None => panic!("num_negative_eigenvalues called before factor()"),
        }
    }

    /// Always `true`: this solver reports inertia.
    pub fn provides_inertia(&self) -> bool {
        true
    }

    /// Minimum diagonal value reported by the stored factors, if any.
    pub fn min_diagonal(&self) -> Option<f64> {
        self.last_factors.as_ref().and_then(|f| f.min_diagonal())
    }

    /// Stored factors from the most recent factorization, if any.
    pub fn factors(&self) -> Option<&SparseFactors> {
        self.last_factors.as_ref()
    }

    /// Current stage of the quality-escalation ladder.
    pub fn quality_level(&self) -> QualityLevel {
        self.quality_level
    }

    /// Number of symbolic analyses actually performed (cache misses).
    pub fn symbolic_call_count(&self) -> usize {
        self.symbolic_call_count
    }
}
impl Default for Solver {
fn default() -> Self {
Self::new()
}
}
/// Unit tests for the quality-escalation ladder (`u*`) and the pattern fingerprint (`f*`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dense::factor::BunchKaufmanParams;
    /// Builds a solver whose only non-default setting is the given scaling strategy.
    fn solver_with_scaling(scaling: ScalingStrategy) -> Solver {
        let np = NumericParams {
            bk: BunchKaufmanParams::default(),
            scaling,
            small_leaf: Default::default(),
            profiler: None,
        };
        Solver::with_params(np, SupernodeParams::default())
    }
    // Stage 1: identity scaling is upgraded to InfNorm without touching the pivot threshold.
    #[test]
    fn u1_increase_quality_baseline_identity_to_scaling_enabled() {
        let mut s = solver_with_scaling(ScalingStrategy::Identity);
        assert_eq!(s.quality_level(), QualityLevel::Baseline);
        assert_eq!(s.pivot_threshold(), 0.0);
        assert!(s.increase_quality());
        assert!(matches!(s.scaling_strategy(), ScalingStrategy::InfNorm));
        assert_eq!(s.pivot_threshold(), 0.0, "stage 1 must not touch pivot");
        assert_eq!(s.quality_level(), QualityLevel::ScalingEnabled);
    }
    // With scaling already on, the first step goes straight to the 0 -> 0.01 pivot jump.
    #[test]
    fn u2_increase_quality_baseline_nonidentity_skips_to_pivot_raised() {
        let mut s = solver_with_scaling(ScalingStrategy::InfNorm);
        assert_eq!(s.quality_level(), QualityLevel::Baseline);
        assert!(s.increase_quality());
        assert_eq!(s.pivot_threshold(), 0.01, "first jump rule");
        assert_eq!(s.quality_level(), QualityLevel::PivotRaised);
    }
    // A nonzero threshold grows geometrically: t -> t^0.75.
    #[test]
    fn u3_increase_quality_pivot_geometric_rule() {
        let mut s = solver_with_scaling(ScalingStrategy::InfNorm);
        s.numeric_params.bk.pivot_threshold = 0.01;
        s.quality_level = QualityLevel::PivotRaised;
        assert!(s.increase_quality());
        let want = 0.01_f64.powf(0.75);
        assert!(
            (s.pivot_threshold() - want).abs() < 1e-15,
            "got {}",
            s.pivot_threshold()
        );
        assert_eq!(s.quality_level(), QualityLevel::PivotRaised);
    }
    // 0.49^0.75 exceeds the 0.5 cap, so the threshold is clamped and the ladder is exhausted;
    // further calls are no-ops returning false.
    #[test]
    fn u4_increase_quality_caps_at_pivtol_max_then_exhausts() {
        let mut s = solver_with_scaling(ScalingStrategy::InfNorm);
        s.numeric_params.bk.pivot_threshold = 0.49;
        s.quality_level = QualityLevel::PivotRaised;
        assert!(s.increase_quality());
        assert_eq!(s.pivot_threshold(), 0.5);
        assert_eq!(s.quality_level(), QualityLevel::Exhausted);
        assert!(!s.increase_quality());
        assert_eq!(s.pivot_threshold(), 0.5);
        assert_eq!(s.quality_level(), QualityLevel::Exhausted);
    }
    // Starting from the very bottom, repeated escalation must terminate in a bounded number
    // of steps.
    #[test]
    fn u5_increase_quality_exhausted_returns_false() {
        let mut s = solver_with_scaling(ScalingStrategy::Identity);
        let mut steps = 0;
        while s.increase_quality() {
            steps += 1;
            assert!(steps < 20, "did not exhaust within 20 steps");
        }
        assert_eq!(s.quality_level(), QualityLevel::Exhausted);
    }
    // Values do not participate in the fingerprint; only the structure does.
    #[test]
    fn f1_fingerprint_same_pattern_equal() {
        let a = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[2.0, 3.0, 5.0]).unwrap();
        let b = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[7.0, 11.0, 13.0]).unwrap();
        let fa = PatternFingerprint::of(&a);
        let fb = PatternFingerprint::of(&b);
        assert_eq!(
            fa, fb,
            "byte-identical patterns must fingerprint identically"
        );
    }
    // Same n, nnz and col_ptr length, but different row indices -> different fingerprint.
    #[test]
    fn f2_fingerprint_distinguishes_same_n_nnz_different_pattern() {
        let a = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[2.0, 3.0, 5.0]).unwrap();
        let b = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[2.0, 3.0, 5.0]).unwrap();
        assert_eq!(PatternFingerprint::of(&a), PatternFingerprint::of(&b));
        let c = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 2, 1], &[2.0, 3.0, 5.0]).unwrap();
        assert_eq!(c.n, a.n);
        assert_eq!(c.row_idx.len(), a.row_idx.len());
        assert_eq!(c.col_ptr.len(), a.col_ptr.len());
        assert_ne!(
            PatternFingerprint::of(&a),
            PatternFingerprint::of(&c),
            "same (n, nnz) but different row_idx must fingerprint differently"
        );
    }
    // Same n and nnz, but entries distributed differently across columns (col_ptr differs).
    #[test]
    fn f3_fingerprint_distinguishes_different_col_ptr() {
        let a = CscMatrix::from_triplets(4, &[0, 1, 2, 3], &[0, 1, 2, 3], &[1.0, 2.0, 3.0, 4.0])
            .unwrap();
        let b = CscMatrix::from_triplets(4, &[0, 0, 1, 2], &[0, 1, 1, 2], &[1.0, 0.5, 2.0, 3.0])
            .unwrap();
        assert_eq!(a.n, b.n);
        assert_eq!(a.row_idx.len(), b.row_idx.len());
        assert_ne!(
            PatternFingerprint::of(&a),
            PatternFingerprint::of(&b),
            "different col_ptr distribution must fingerprint differently"
        );
    }
}