use faer::Col;
use faer::Par;
use faer::dyn_stack::{MemBuffer, MemStack, StackReq};
use faer::perm::Perm;
use faer::sparse::linalg::cholesky::SymmetricOrdering;
use faer::sparse::{SparseColMat, SymbolicSparseColMatRef};
use super::factor::{AptpFallback, AptpOptions};
use super::inertia::Inertia;
use super::numeric::{AptpNumeric, FactorizationStats};
use super::solve::{aptp_solve, aptp_solve_scratch};
use super::symbolic::AptpSymbolic;
use crate::error::SparseError;
use crate::ordering::{match_order_metis, metis_ordering};
/// Fill-reducing ordering used by the symbolic analysis of [`SparseLDLT`].
#[derive(Debug, Clone)]
pub enum OrderingStrategy {
    /// faer's built-in AMD ordering (`SymmetricOrdering::Amd`).
    Amd,
    /// Ordering computed by `crate::ordering::metis_ordering` from the
    /// sparsity pattern only.
    Metis,
    /// Ordering plus a scaling vector computed by
    /// `crate::ordering::match_order_metis`. It needs numeric values, so only
    /// `SparseLDLT::analyze_with_matrix` accepts it; `SparseLDLT::analyze`
    /// rejects it with `SparseError::AnalysisFailure`.
    /// NOTE(review): presumably a matching-based (MC64-style) step followed by
    /// METIS — confirm in `crate::ordering`.
    MatchOrderMetis,
    /// Caller-provided permutation, passed to the symbolic analysis as
    /// `SymmetricOrdering::Custom`.
    UserSupplied(Perm<usize>),
}
/// Options for `SparseLDLT::analyze` and `SparseLDLT::analyze_with_matrix`.
#[derive(Debug, Clone)]
pub struct AnalyzeOptions {
    /// Fill-reducing ordering to use. The `Default` impl selects
    /// `OrderingStrategy::MatchOrderMetis`, which requires
    /// `SparseLDLT::analyze_with_matrix`.
    pub ordering: OrderingStrategy,
}
impl Default for AnalyzeOptions {
fn default() -> Self {
Self {
ordering: OrderingStrategy::MatchOrderMetis,
}
}
}
/// Options for the numeric phase (`SparseLDLT::factor` / `refactor`).
///
/// Every field is forwarded unchanged into `AptpOptions`; any remaining
/// `AptpOptions` fields keep their own defaults.
#[derive(Debug, Clone)]
pub struct FactorOptions {
    /// Forwarded to `AptpOptions::threshold` (default `0.01`).
    /// NOTE(review): presumably the relative pivot-acceptance threshold of
    /// a posteriori threshold pivoting — confirm in `factor.rs`.
    pub threshold: f64,
    /// Forwarded to `AptpOptions::fallback` (default `BunchKaufman`).
    /// NOTE(review): presumably the pivoting strategy used when a pivot fails
    /// the threshold test — confirm.
    pub fallback: AptpFallback,
    /// Forwarded to `AptpOptions::outer_block_size` (default `256`).
    pub outer_block_size: usize,
    /// Forwarded to `AptpOptions::inner_block_size` (default `32`).
    pub inner_block_size: usize,
    /// faer parallelism used during factorization (default `Par::Seq`).
    pub par: Par,
    /// Forwarded to `AptpOptions::nemin` (default `32`).
    /// NOTE(review): presumably a supernode-amalgamation size — confirm.
    pub nemin: usize,
    /// Forwarded to `AptpOptions::small_leaf_threshold` (default `256`).
    pub small_leaf_threshold: usize,
}
impl Default for FactorOptions {
fn default() -> Self {
Self {
threshold: 0.01,
fallback: AptpFallback::BunchKaufman,
outer_block_size: 256,
inner_block_size: 32,
par: Par::Seq,
nemin: 32,
small_leaf_threshold: 256,
}
}
}
/// Combined options for the one-shot `SparseLDLT::solve_full`.
///
/// `ordering` becomes `AnalyzeOptions::ordering`; the remaining fields map to
/// the `FactorOptions` fields of the same name, and the block sizes not listed
/// here come from `FactorOptions::default()`.
#[derive(Debug, Clone)]
pub struct SolverOptions {
    /// Fill-reducing ordering for the analysis phase.
    pub ordering: OrderingStrategy,
    /// See `FactorOptions::threshold`.
    pub threshold: f64,
    /// See `FactorOptions::fallback`.
    pub fallback: AptpFallback,
    /// Parallelism used both for factorization and for the solve.
    pub par: Par,
    /// See `FactorOptions::nemin`.
    pub nemin: usize,
    /// See `FactorOptions::small_leaf_threshold`.
    pub small_leaf_threshold: usize,
}
impl Default for SolverOptions {
fn default() -> Self {
Self {
ordering: OrderingStrategy::MatchOrderMetis,
threshold: 0.01,
fallback: AptpFallback::BunchKaufman,
par: Par::Seq,
nemin: 32,
small_leaf_threshold: 256,
}
}
}
/// Sparse LDLᵀ solver wrapping the APTP symbolic, numeric and solve phases.
///
/// Typical lifecycle: `analyze` / `analyze_with_matrix` → `factor` (or
/// `refactor` for new values) → `solve` / `solve_in_place`.
pub struct SparseLDLT {
    /// Symbolic analysis; reused by every call to `factor`.
    symbolic: AptpSymbolic,
    /// Numeric factorization; `None` until `factor` succeeds, and cleared at
    /// the start of each `factor` call.
    numeric: Option<AptpNumeric>,
    /// Optional scaling vector in elimination order (only set on the
    /// `MatchOrderMetis` path); multiplied into the RHS before and after the
    /// solve and passed to `AptpNumeric::factor`.
    scaling: Option<Vec<f64>>,
    /// Cached forward permutation: `perm_fwd[new] = old`.
    perm_fwd: Vec<usize>,
}
impl SparseLDLT {
    /// Builds a solver in the analyzed-but-unfactored state.
    ///
    /// The forward permutation is extracted from `symbolic` once here so that
    /// `solve_in_place` can reuse it instead of recomputing it per solve.
    fn new_with_cached_perm(symbolic: AptpSymbolic, scaling: Option<Vec<f64>>) -> Self {
        let (perm_fwd, _) = symbolic.perm_vecs();
        SparseLDLT {
            symbolic,
            numeric: None,
            scaling,
            perm_fwd,
        }
    }

    /// Copies the sparsity pattern of `matrix` into an owned `SparseColMat`
    /// whose stored values are all `1.0`, for use as METIS input.
    ///
    /// Columns are read with `row_idx_of_col_raw`, which honours faer's
    /// optional per-column nnz array. Slicing
    /// `row_idx()[col_ptr()[j]..col_ptr()[j + 1]]` directly would also pick up
    /// unused slots when the pattern is stored non-compactly.
    ///
    /// # Errors
    /// - `DimensionMismatch` if `matrix` is not square; otherwise columns
    ///   `ncols..nrows` would not exist and the per-column loop would panic.
    /// - `AnalysisFailure` if the pattern matrix cannot be assembled.
    fn metis_pattern_matrix(
        matrix: SymbolicSparseColMatRef<'_, usize>,
    ) -> Result<SparseColMat<usize, f64>, SparseError> {
        let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
        if nrows != ncols {
            return Err(SparseError::DimensionMismatch {
                expected: (nrows, nrows),
                got: (nrows, ncols),
                context: "METIS ordering requires a square matrix".to_string(),
            });
        }
        let n = nrows;
        let mut triplets = Vec::with_capacity(matrix.compute_nnz());
        for j in 0..n {
            triplets.extend(
                matrix
                    .row_idx_of_col_raw(j)
                    .iter()
                    .map(|&i| faer::sparse::Triplet::new(i, j, 1.0f64)),
            );
        }
        SparseColMat::try_new_from_triplets(n, n, &triplets).map_err(|e| {
            SparseError::AnalysisFailure {
                reason: format!("Failed to construct matrix for METIS ordering: {}", e),
            }
        })
    }

    /// Runs the symbolic analysis (`AptpSymbolic::analyze`) using only the
    /// sparsity pattern and the ordering selected by `options.ordering`.
    ///
    /// The returned solver holds no numeric factorization and no scaling; call
    /// [`SparseLDLT::factor`] before solving.
    ///
    /// # Errors
    /// - `AnalysisFailure` for [`OrderingStrategy::MatchOrderMetis`], which
    ///   needs numeric values; use [`SparseLDLT::analyze_with_matrix`] instead.
    /// - `DimensionMismatch` if METIS ordering is requested on a non-square
    ///   pattern.
    /// - Any error from the ordering routine or `AptpSymbolic::analyze`.
    pub fn analyze(
        matrix: SymbolicSparseColMatRef<'_, usize>,
        options: &AnalyzeOptions,
    ) -> Result<Self, SparseError> {
        let symbolic = match &options.ordering {
            OrderingStrategy::Amd => AptpSymbolic::analyze(matrix, SymmetricOrdering::Amd)?,
            OrderingStrategy::Metis => {
                let pattern = Self::metis_pattern_matrix(matrix)?;
                let perm = metis_ordering(pattern.symbolic())?;
                AptpSymbolic::analyze(matrix, SymmetricOrdering::Custom(perm.as_ref()))?
            }
            OrderingStrategy::MatchOrderMetis => {
                return Err(SparseError::AnalysisFailure {
                    reason: "MatchOrderMetis requires numeric matrix values; use analyze_with_matrix() instead".to_string(),
                });
            }
            OrderingStrategy::UserSupplied(perm) => {
                AptpSymbolic::analyze(matrix, SymmetricOrdering::Custom(perm.as_ref()))?
            }
        };
        Ok(Self::new_with_cached_perm(symbolic, None))
    }

    /// Symbolic analysis that may use numeric values.
    ///
    /// For [`OrderingStrategy::MatchOrderMetis`], this calls
    /// `match_order_metis` to obtain both an ordering and a scaling vector.
    /// The scaling is re-indexed into elimination order and stored for
    /// `factor` and `solve`. Every other strategy is delegated to
    /// [`SparseLDLT::analyze`] on the matrix pattern.
    ///
    /// # Errors
    /// Any error from `match_order_metis`, `AptpSymbolic::analyze` or
    /// [`SparseLDLT::analyze`].
    pub fn analyze_with_matrix(
        matrix: &SparseColMat<usize, f64>,
        options: &AnalyzeOptions,
    ) -> Result<Self, SparseError> {
        match &options.ordering {
            OrderingStrategy::MatchOrderMetis => {
                let result = match_order_metis(matrix)?;
                let symbolic = AptpSymbolic::analyze(
                    matrix.symbolic(),
                    SymmetricOrdering::Custom(result.ordering.as_ref()),
                )?;
                let (perm_fwd, _) = symbolic.perm_vecs();
                // The scaling from `match_order_metis` is indexed by original
                // row; permute it so position `new` holds the factor for the
                // row eliminated at step `new`.
                let elim_scaling: Vec<f64> =
                    perm_fwd.iter().map(|&old| result.scaling[old]).collect();
                // Build directly so `perm_vecs()` is not computed a second time.
                Ok(SparseLDLT {
                    symbolic,
                    numeric: None,
                    scaling: Some(elim_scaling),
                    perm_fwd,
                })
            }
            OrderingStrategy::Amd | OrderingStrategy::Metis | OrderingStrategy::UserSupplied(_) => {
                Self::analyze(matrix.symbolic(), options)
            }
        }
    }

    /// Computes the numeric APTP factorization of `matrix` using the stored
    /// symbolic analysis and scaling.
    ///
    /// `matrix` should presumably have the sparsity pattern that was analyzed.
    /// Any previous factorization is dropped before factoring starts. If this
    /// returns an error, the solver therefore holds no factorization and
    /// `solve` reports `SolveBeforeFactor`.
    ///
    /// # Errors
    /// Any error from `AptpNumeric::factor`.
    pub fn factor(
        &mut self,
        matrix: &SparseColMat<usize, f64>,
        options: &FactorOptions,
    ) -> Result<(), SparseError> {
        let aptp_options = AptpOptions {
            threshold: options.threshold,
            fallback: options.fallback,
            outer_block_size: options.outer_block_size,
            inner_block_size: options.inner_block_size,
            par: options.par,
            nemin: options.nemin,
            small_leaf_threshold: options.small_leaf_threshold,
            ..AptpOptions::default()
        };
        // Release the old factors first so a failed factorization can never
        // leave stale factors paired with new values.
        self.numeric = None;
        let numeric = AptpNumeric::factor(
            &self.symbolic,
            matrix,
            &aptp_options,
            self.scaling.as_deref(),
        )?;
        self.numeric = Some(numeric);
        Ok(())
    }

    /// Re-factors with new numeric values, reusing the stored symbolic
    /// analysis. This is currently identical to [`SparseLDLT::factor`].
    pub fn refactor(
        &mut self,
        matrix: &SparseColMat<usize, f64>,
        options: &FactorOptions,
    ) -> Result<(), SparseError> {
        self.factor(matrix, options)
    }

    /// Solves for a single right-hand side and returns the solution as a new
    /// column; `rhs` is left untouched.
    ///
    /// # Errors
    /// Same as [`SparseLDLT::solve_in_place`].
    pub fn solve(
        &self,
        rhs: &Col<f64>,
        stack: &mut MemStack,
        par: Par,
    ) -> Result<Col<f64>, SparseError> {
        let mut result = rhs.to_owned();
        self.solve_in_place(&mut result, stack, par)?;
        Ok(result)
    }

    /// Solves in place, overwriting `rhs` with the solution.
    ///
    /// The solve runs in five steps:
    /// 1. Gather `rhs` into elimination order using the cached forward
    ///    permutation.
    /// 2. Multiply by the scaling vector, if any.
    /// 3. Run `aptp_solve`.
    /// 4. Multiply by the scaling vector again.
    /// 5. Scatter back to the original ordering.
    ///
    /// Scaling on both sides is consistent with factoring the symmetrically
    /// scaled matrix `S·A·S` (assumption — confirm in `AptpNumeric::factor`).
    ///
    /// `stack` presumably needs at least [`SparseLDLT::solve_scratch`]`(1)`.
    ///
    /// # Errors
    /// - `SolveBeforeFactor` if no numeric factorization is held.
    /// - `DimensionMismatch` if `rhs.nrows()` differs from the matrix dimension.
    /// - Any error from `aptp_solve`.
    pub fn solve_in_place(
        &self,
        rhs: &mut Col<f64>,
        stack: &mut MemStack,
        par: Par,
    ) -> Result<(), SparseError> {
        let numeric = self
            .numeric
            .as_ref()
            .ok_or_else(|| SparseError::SolveBeforeFactor {
                context: "factor() must be called before solve()".to_string(),
            })?;
        let n = self.symbolic.nrows();
        if rhs.nrows() != n {
            return Err(SparseError::DimensionMismatch {
                expected: (n, 1),
                got: (rhs.nrows(), 1),
                context: "RHS length must match matrix dimension".to_string(),
            });
        }
        if n == 0 {
            return Ok(());
        }
        let perm_fwd = &self.perm_fwd[..n];
        // Gather into elimination order: work[new] = rhs[perm_fwd[new]].
        let mut work: Vec<f64> = perm_fwd.iter().map(|&old| rhs[old]).collect();
        let apply_scaling = |v: &mut [f64]| {
            if let Some(scaling) = &self.scaling {
                for (x, &s) in v.iter_mut().zip(&scaling[..v.len()]) {
                    *x *= s;
                }
            }
        };
        apply_scaling(&mut work);
        aptp_solve(&self.symbolic, numeric, &mut work, stack, par)?;
        apply_scaling(&mut work);
        // Scatter back to the caller's ordering: rhs[perm_fwd[new]] = work[new].
        for (&old, &value) in perm_fwd.iter().zip(&work) {
            rhs[old] = value;
        }
        Ok(())
    }

    /// Scratch memory needed by `solve` / `solve_in_place` for `rhs_ncols`
    /// right-hand sides.
    ///
    /// Returns an empty requirement when no factorization is held (the solve
    /// would fail with `SolveBeforeFactor` before touching the stack).
    pub fn solve_scratch(&self, rhs_ncols: usize) -> StackReq {
        match &self.numeric {
            Some(numeric) => aptp_solve_scratch(numeric, rhs_ncols),
            None => StackReq::empty(),
        }
    }

    /// One-shot convenience: analyze, factor and solve a single RHS.
    ///
    /// Analysis goes through [`SparseLDLT::analyze_with_matrix`], so every
    /// [`OrderingStrategy`] is accepted. Scratch memory is allocated
    /// internally. Block sizes come from `FactorOptions::default()`.
    ///
    /// # Errors
    /// Any error from the analysis, factorization or solve phase.
    pub fn solve_full(
        matrix: &SparseColMat<usize, f64>,
        rhs: &Col<f64>,
        options: &SolverOptions,
    ) -> Result<Col<f64>, SparseError> {
        let analyze_opts = AnalyzeOptions {
            ordering: options.ordering.clone(),
        };
        let factor_opts = FactorOptions {
            threshold: options.threshold,
            fallback: options.fallback,
            par: options.par,
            nemin: options.nemin,
            small_leaf_threshold: options.small_leaf_threshold,
            ..FactorOptions::default()
        };
        let mut solver = Self::analyze_with_matrix(matrix, &analyze_opts)?;
        solver.factor(matrix, &factor_opts)?;
        let scratch = solver.solve_scratch(1);
        let mut mem = MemBuffer::new(scratch);
        let stack = MemStack::new(&mut mem);
        solver.solve(rhs, stack, options.par)
    }

    /// Inertia (positive / negative / zero counts) of the factorization,
    /// summed over the `D11` block of every front.
    ///
    /// Returns `None` when no factorization is held.
    /// NOTE(review): assumes every eliminated pivot appears in exactly one
    /// front's `D11` — confirm in `numeric.rs`.
    pub fn inertia(&self) -> Option<Inertia> {
        let numeric = self.numeric.as_ref()?;
        let mut total = Inertia {
            positive: 0,
            negative: 0,
            zero: 0,
        };
        for ff in numeric.front_factors() {
            let local = ff.d11().compute_inertia();
            total.positive += local.positive;
            total.negative += local.negative;
            total.zero += local.zero;
        }
        Some(total)
    }

    /// Statistics of the current factorization; `None` if none is held.
    pub fn stats(&self) -> Option<&FactorizationStats> {
        self.numeric.as_ref().map(|n| n.stats())
    }

    /// Per-supernode statistics of the current factorization; `None` if none
    /// is held.
    pub fn per_supernode_stats(&self) -> Option<&[super::numeric::PerSupernodeStats]> {
        self.numeric.as_ref().map(|n| n.per_supernode_stats())
    }

    /// Dimension of the analyzed matrix.
    pub fn n(&self) -> usize {
        self.symbolic.nrows()
    }
}