use crate::DType;
#[cfg(feature = "sparse")]
use numr::algorithm::iterative::{ConvergenceReason, GmresOptions, IterativeSolvers};
#[cfg(feature = "sparse")]
use numr::algorithm::sparse_linalg::{
IluDecomposition, IluOptions, SparseLinAlgAlgorithms, SymbolicIlu0,
};
#[cfg(feature = "sparse")]
use numr::error::Result;
#[cfg(feature = "sparse")]
use numr::runtime::Runtime;
#[cfg(feature = "sparse")]
use numr::sparse::{CsrData, SparseOps};
#[cfg(feature = "sparse")]
use numr::tensor::Tensor;
#[cfg(feature = "sparse")]
use super::direct_solver::DirectSparseSolver;
#[cfg(feature = "sparse")]
use crate::integrate::ode::SparseJacobianConfig;
/// Reuses the symbolic ILU(0) factorization of a sparse matrix across calls, so
/// repeated factorizations only redo the numeric phase.
///
/// The cache never inspects `matrix` before reusing the stored symbolic
/// factorization. Callers are responsible for calling
/// [`SparseJacobianCache::invalidate`] whenever the matrix structure may have
/// changed.
#[cfg(feature = "sparse")]
pub struct SparseJacobianCache {
    /// Number of `get_or_compute_ilu` calls that found a cached symbolic
    /// factorization (counted before the numeric phase runs).
    pub cache_hits: usize,
    /// Number of `get_or_compute_ilu` calls that found the cache empty
    /// (counted even if the symbolic computation then fails).
    pub cache_misses: usize,
    /// Symbolic factorization computed on the most recent miss, if any.
    symbolic_ilu: Option<SymbolicIlu0>,
}

#[cfg(feature = "sparse")]
impl SparseJacobianCache {
    /// Creates an empty cache with both counters at zero.
    pub fn new() -> Self {
        Self {
            cache_hits: 0,
            cache_misses: 0,
            symbolic_ilu: None,
        }
    }

    /// Returns the numeric ILU(0) factorization of `matrix`.
    ///
    /// If a symbolic factorization is cached, it is reused and only
    /// `ilu0_numeric` runs. Otherwise `ilu0_symbolic` is computed from
    /// `matrix`, stored, and then used for the numeric phase.
    ///
    /// # Errors
    ///
    /// Propagates errors from `ilu0_symbolic` and `ilu0_numeric`. When the
    /// symbolic phase fails, nothing is stored in the cache.
    pub fn get_or_compute_ilu<R, C>(
        &mut self,
        client: &C,
        matrix: &CsrData<R>,
        options: IluOptions,
    ) -> Result<IluDecomposition<R>>
    where
        R: Runtime<DType = DType>,
        C: SparseLinAlgAlgorithms<R>,
    {
        // Fast path: reuse the stored symbolic structure.
        if let Some(cached) = self.symbolic_ilu.as_ref() {
            self.cache_hits += 1;
            return client.ilu0_numeric(matrix, cached, options);
        }

        // Slow path: build the symbolic structure once, then keep it.
        self.cache_misses += 1;
        let fresh = client.ilu0_symbolic(matrix)?;
        let symbolic: &SymbolicIlu0 = self.symbolic_ilu.insert(fresh);
        client.ilu0_numeric(matrix, symbolic, options)
    }

    /// Drops the cached symbolic factorization; hit/miss counters are kept.
    pub fn invalidate(&mut self) {
        self.symbolic_ilu.take();
    }

    /// Reports whether a symbolic factorization is currently cached.
    pub fn has_symbolic(&self) -> bool {
        self.symbolic_ilu.is_some()
    }
}

#[cfg(feature = "sparse")]
impl Default for SparseJacobianCache {
    /// Same as [`SparseJacobianCache::new`].
    fn default() -> Self {
        SparseJacobianCache::new()
    }
}
/// Converts a square dense matrix into CSR form.
///
/// The conversion is delegated to `SparseOps::dense_to_csr` with a fixed
/// threshold of `1e-15`. Presumably entries below that magnitude are dropped;
/// see that method's documentation to confirm.
///
/// # Errors
///
/// Returns `Error::ShapeMismatch` if `dense` is not a 2-D square tensor.
/// `expected` reports `[n, n]`, where `n` is the leading dimension, or 0 for a
/// rank-0 tensor. Any error from `dense_to_csr` is propagated.
#[cfg(feature = "sparse")]
pub fn dense_to_csr_full<R, C>(client: &C, dense: &Tensor<R>) -> Result<CsrData<R>>
where
    R: Runtime<DType = DType>,
    C: SparseOps<R>,
{
    let shape = dense.shape();
    // The `||` short-circuit guarantees `shape.len() == 2` before indexing.
    if shape.len() != 2 || shape[0] != shape[1] {
        // A rank-0 tensor has no leading dimension; indexing `shape[0]` here
        // would panic instead of returning the intended error.
        let n = shape.first().copied().unwrap_or(0);
        return Err(numr::error::Error::ShapeMismatch {
            expected: vec![n, n],
            got: shape.to_vec(),
        });
    }
    client.dense_to_csr(dense, 1e-15)
}
/// Converts `dense` to CSR for use alongside a previously established sparsity
/// pattern.
///
/// NOTE(review): `_pattern` is currently ignored. The result is exactly what
/// [`dense_to_csr_full`] returns, so its structure is not guaranteed to match
/// `_pattern`.
///
/// # Errors
///
/// Same as [`dense_to_csr_full`].
#[cfg(feature = "sparse")]
pub fn dense_to_csr_with_pattern<R, C>(
    client: &C,
    dense: &Tensor<R>,
    _pattern: &CsrData<R>,
) -> Result<CsrData<R>>
where
    R: Runtime<DType = DType>,
    C: SparseOps<R>,
{
    // Pattern-preserving extraction is not implemented; fall back to the
    // unconstrained conversion.
    dense_to_csr_full::<R, C>(client, dense)
}
/// Solves `m_sparse * x = b` with GMRES, configured from `sparse_config`.
///
/// The following options come from `sparse_config`:
/// - the iteration limit (`max_gmres_iter`),
/// - the relative tolerance (`gmres_tol`),
/// - the preconditioner.
///
/// The absolute tolerance is fixed at `1e-14`. No initial guess is supplied
/// (`None` is passed to the solver). `solver_name` only labels error messages.
///
/// # Errors
///
/// Returns `Error::Internal` in two cases:
/// - the GMRES call itself fails;
/// - GMRES returns without converging. The message then includes the
///   iteration count, the final residual norm, and a remediation hint chosen
///   from the reported `ConvergenceReason`.
#[cfg(feature = "sparse")]
pub fn solve_with_gmres<R, C>(
    client: &C,
    m_sparse: &CsrData<R>,
    b: &Tensor<R>,
    sparse_config: &SparseJacobianConfig<R>,
    solver_name: &str,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: IterativeSolvers<R>,
{
    let options = GmresOptions {
        max_iter: sparse_config.max_gmres_iter,
        rtol: sparse_config.gmres_tol,
        // Fixed absolute tolerance; not configurable through `sparse_config`.
        atol: 1e-14,
        preconditioner: sparse_config.preconditioner,
        ..Default::default()
    };

    let outcome = match client.gmres(m_sparse, b, None, options) {
        Ok(outcome) => outcome,
        Err(err) => {
            return Err(numr::error::Error::Internal(format!(
                "GMRES failed in sparse {} solve: {}",
                solver_name, err
            )));
        }
    };

    if outcome.converged {
        return Ok(outcome.solution);
    }

    // Non-convergence: build an actionable hint from the stop reason.
    let diag = &outcome.diagnostics;
    let hint = match outcome.reason {
        ConvergenceReason::MaxIterationsReached => format!(
            "Consider: (1) Increase max_gmres_iter (currently {}), \
             (2) Enable ILU preconditioning if not already, \
             (3) Loosen gmres_tol (currently {:.2e})",
            diag.max_iter, diag.rtol
        ),
        ConvergenceReason::Stagnation => format!(
            "Residual stagnated at {:.2e} (initial: {:.2e}). \
             System may be ill-conditioned. Consider: \
             (1) Enable ILU preconditioning, \
             (2) Increase restart parameter, \
             (3) Use a smaller time step",
            outcome.residual_norm, diag.initial_residual_norm
        ),
        ConvergenceReason::NumericalBreakdown => format!(
            "Numerical breakdown at iteration {}. Matrix may be singular. \
             Check: (1) Jacobian computation is correct, \
             (2) Problem is well-posed, \
             (3) Consider adding a diagonal shift",
            outcome.iterations
        ),
        _ => format!(
            "Unexpected non-convergence reason: {}. {}",
            outcome.reason,
            outcome.reason.hint()
        ),
    };

    Err(numr::error::Error::Internal(format!(
        "GMRES did not converge in {} {}: {} iterations, residual = {:.2e}. {}",
        solver_name, "Newton step", outcome.iterations, outcome.residual_norm, hint
    )))
}
/// Solves `m_dense * x = b` using the supplied direct sparse LU solver.
///
/// This is a thin forwarding wrapper around `DirectSparseSolver::solve`. The
/// solver is borrowed mutably, so it may update internal state between calls
/// (NOTE(review): presumably a cached factorization; confirm in
/// `direct_solver`).
///
/// # Errors
///
/// Propagates any error returned by `DirectSparseSolver::solve`.
#[cfg(feature = "sparse")]
pub fn solve_with_direct_lu<R, C>(
    client: &C,
    direct_solver: &mut DirectSparseSolver<R>,
    m_dense: &Tensor<R>,
    b: &Tensor<R>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: SparseOps<R>
        + numr::ops::IndexingOps<R>
        + numr::ops::TensorOps<R>
        + numr::ops::ScalarOps<R>,
{
    let solution = direct_solver.solve(client, m_dense, b)?;
    Ok(solution)
}
/// Solves the linear system `m_dense * x = b` for one Newton step.
///
/// Dispatch rules:
/// - If `direct_solver` holds a solver, the system is solved by direct LU on
///   `m_dense`, and GMRES is never attempted.
/// - Otherwise `m_dense` is converted to CSR. `dense_to_csr_with_pattern` is
///   used when `pattern` is given, and `dense_to_csr_full` when it is not.
///   An `(n, 1)` right-hand side is flattened to `(n,)`, and the system is
///   then solved with GMRES.
///
/// NOTE(review): the GMRES path returns whatever GMRES produces for the
/// (possibly flattened) RHS. The direct path may therefore return a
/// differently shaped tensor for the same `b`; confirm callers handle both.
///
/// # Errors
///
/// Propagates errors from the direct solver, the CSR conversion, or GMRES.
/// Returns `Error::Internal` if flattening the RHS fails.
#[cfg(feature = "sparse")]
pub fn solve_sparse_system<R, C>(
    client: &C,
    m_dense: &Tensor<R>,
    b: &Tensor<R>,
    sparse_config: &crate::integrate::ode::SparseJacobianConfig<R>,
    direct_solver: &mut Option<DirectSparseSolver<R>>,
    pattern: Option<&CsrData<R>>,
    solver_name: &str,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: SparseOps<R>
        + numr::ops::IndexingOps<R>
        + numr::ops::TensorOps<R>
        + numr::ops::ScalarOps<R>
        + IterativeSolvers<R>,
{
    // A configured direct solver always takes precedence.
    if let Some(lu) = direct_solver {
        return solve_with_direct_lu(client, lu, m_dense, b);
    }

    let csr = match pattern {
        Some(pat) => dense_to_csr_with_pattern(client, m_dense, pat)?,
        None => dense_to_csr_full(client, m_dense)?,
    };

    // Flatten a column vector (n, 1) to (n,); other shapes pass through as-is.
    let rhs_shape = b.shape();
    let is_column_vector = rhs_shape.len() == 2 && rhs_shape[1] == 1;
    let rhs = if is_column_vector {
        b.reshape(&[rhs_shape[0]]).map_err(|err| {
            numr::error::Error::Internal(format!("Failed to reshape RHS to 1D: {}", err))
        })?
    } else {
        b.clone()
    };

    solve_with_gmres(client, &csr, &rhs, sparse_config, solver_name)
}
/// Builds a direct sparse LU solver when the configuration calls for one.
///
/// A solver is returned only when `sparse_config.enabled` is set and either:
/// - the strategy is `SparseSolverStrategy::DirectLU`, or
/// - the strategy is `SparseSolverStrategy::Auto` and `n < 5000`.
///
/// Every other case yields `None`. When the result is passed on to
/// `solve_sparse_system`, `None` selects its GMRES path.
#[cfg(feature = "sparse")]
pub fn create_direct_solver<R: Runtime<DType = DType>>(
    sparse_config: &crate::integrate::ode::SparseJacobianConfig<R>,
    n: usize,
) -> Option<DirectSparseSolver<R>> {
    use super::direct_solver_config::SparseSolverStrategy;

    // System-size cut-off below which the `Auto` strategy picks direct LU.
    const AUTO_DIRECT_LU_MAX_N: usize = 5000;

    let use_direct = sparse_config.enabled
        && match sparse_config.solver_strategy {
            SparseSolverStrategy::DirectLU => true,
            SparseSolverStrategy::Auto => n < AUTO_DIRECT_LU_MAX_N,
            _ => false,
        };

    use_direct.then(|| DirectSparseSolver::new(&sparse_config.direct_solver_config))
}