use std::fmt;
use echidna::sparse::{column_coloring, row_coloring, JacobianSparsityPattern};
use echidna::BytecodeTape;
use faer::linalg::solvers::Solve;
use faer::sparse::SparseColMat;
use faer::Col;
/// Failure modes of the sparse implicit-differentiation routines.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must keep a
/// wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Debug)]
pub enum SparseImplicitError {
    /// The sparse `F_z` matrix could not be built.
    StructuralSingular {
        /// Underlying error reported while building the matrix.
        source: Box<dyn std::error::Error + Send + Sync + 'static>,
    },
    /// The sparse LU factorization of `F_z` failed.
    FactorSingular {
        /// Underlying error reported by the factorization.
        source: Box<dyn std::error::Error + Send + Sync + 'static>,
    },
    /// The probe solve produced a non-finite solution.
    NumericSingular,
    /// The probe solve residual exceeded the tolerance.
    ResidualExceeded {
        /// Relative residual measured on the probe system.
        relative_residual: f64,
        /// Tolerance the residual was compared against.
        tolerance: f64,
        /// Dimension of the linear system.
        dimension: usize,
    },
    /// A caller-supplied slice had the wrong length.
    DimensionMismatch {
        /// Name of the offending argument.
        field: &'static str,
        /// Length that was required.
        expected: usize,
        /// Length that was actually supplied.
        actual: usize,
    },
}

impl fmt::Display for SparseImplicitError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The first three variants have fixed messages with no interpolation.
            Self::StructuralSingular { .. } => {
                f.write_str("sparse_implicit: failed to build sparse F_z matrix")
            }
            Self::FactorSingular { .. } => {
                f.write_str("sparse_implicit: sparse LU factorization failed (F_z singular)")
            }
            Self::NumericSingular => f.write_str(
                "sparse_implicit: probe solve produced non-finite solution (F_z numerically singular)",
            ),
            Self::ResidualExceeded {
                relative_residual,
                tolerance,
                dimension,
            } => write!(
                f,
                "sparse_implicit: probe solve residual {relative_residual:.3e} exceeds tolerance {tolerance:.3e} (F_z ill-conditioned, dim = {dimension})"
            ),
            Self::DimensionMismatch {
                field,
                expected,
                actual,
            } => write!(
                f,
                "sparse_implicit: dimension mismatch for `{field}` (expected {expected}, got {actual})"
            ),
        }
    }
}

impl std::error::Error for SparseImplicitError {
    /// Returns the wrapped lower-level error for the variants that carry one.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let boxed = match self {
            Self::StructuralSingular { source } | Self::FactorSingular { source } => source,
            Self::NumericSingular
            | Self::ResidualExceeded { .. }
            | Self::DimensionMismatch { .. } => return None,
        };
        // Re-borrow through the box; the `Send + Sync` bounds are dropped by coercion.
        Some(&**boxed)
    }
}
// NOTE(review): presumably a compile-time assertion that `SparseImplicitError`
// is `Send + Sync` (its boxed `source` fields are declared with those bounds);
// confirm against the macro definition in `echidna`.
echidna::assert_send_sync!(SparseImplicitError);
/// Precomputed sparsity, coloring and index data for a residual tape whose
/// inputs are `num_states` state variables followed by `num_params` parameters.
///
/// Built once by `SparseImplicitContext::new` and then borrowed by the
/// `implicit_*_sparse` functions.
pub struct SparseImplicitContext {
/// Jacobian sparsity pattern obtained from `tape.detect_jacobian_sparsity()`.
pattern: JacobianSparsityPattern,
/// Coloring chosen in `new`: the column coloring when it needs no more colors
/// than the row coloring, otherwise the row coloring.
colors: Vec<u32>,
/// Number of colors reported alongside `colors`.
num_colors: u32,
/// `true` when the column coloring was chosen; forwarded to
/// `sparse_jacobian_with_pattern`.
forward_mode: bool,
/// Number of state variables (equals the number of tape outputs).
num_states: usize,
/// Number of parameters: `tape.num_inputs() - num_states`.
num_params: usize,
/// Indices into the pattern's nonzero list whose column is `< num_states`
/// (entries of `F_z`).
fz_indices: Vec<usize>,
/// Indices into the pattern's nonzero list whose column is `>= num_states`
/// (entries of `F_x`).
fx_indices: Vec<usize>,
/// The `F_x` indices grouped by parameter column (`col - num_states`).
fx_by_col: Vec<Vec<usize>>,
}
impl SparseImplicitContext {
    /// Builds the context for a residual tape whose first `num_states` inputs
    /// are the state variables and whose remaining inputs are parameters.
    ///
    /// Detects the Jacobian sparsity pattern, computes both a column and a row
    /// coloring, keeps whichever uses fewer colors (column coloring on ties),
    /// and splits the pattern's nonzeros into state (`F_z`) and parameter
    /// (`F_x`) entries.
    ///
    /// # Panics
    ///
    /// Panics if `tape.num_outputs() != num_states` or if
    /// `tape.num_inputs() <= num_states`.
    pub fn new(tape: &BytecodeTape<f64>, num_states: usize) -> Self {
        assert_eq!(
            tape.num_outputs(),
            num_states,
            "tape.num_outputs() ({}) must equal num_states ({})",
            tape.num_outputs(),
            num_states
        );
        assert!(
            tape.num_inputs() > num_states,
            "tape.num_inputs() ({}) must be greater than num_states ({})",
            tape.num_inputs(),
            num_states
        );
        let num_params = tape.num_inputs() - num_states;

        let pattern = tape.detect_jacobian_sparsity();
        let (col_colors, col_ncolors) = column_coloring(&pattern);
        let (row_colors, row_ncolors) = row_coloring(&pattern);

        // Prefer the column coloring unless the row coloring is strictly cheaper.
        let forward_mode = col_ncolors <= row_ncolors;
        let (colors, num_colors) = if forward_mode {
            (col_colors, col_ncolors)
        } else {
            (row_colors, row_ncolors)
        };

        let mut fz_indices = Vec::new();
        let mut fx_indices = Vec::new();
        let mut fx_by_col: Vec<Vec<usize>> = vec![Vec::new(); num_params];
        for k in 0..pattern.nnz() {
            // `checked_sub` is `None` exactly when the column is a state column.
            match (pattern.cols[k] as usize).checked_sub(num_states) {
                None => fz_indices.push(k),
                Some(param_col) => {
                    fx_indices.push(k);
                    fx_by_col[param_col].push(k);
                }
            }
        }

        Self {
            pattern,
            colors,
            num_colors,
            forward_mode,
            num_states,
            num_params,
            fz_indices,
            fx_indices,
            fx_by_col,
        }
    }

    /// Number of state variables.
    pub fn num_states(&self) -> usize {
        self.num_states
    }

    /// Number of parameters.
    pub fn num_params(&self) -> usize {
        self.num_params
    }

    /// Total number of nonzeros in the Jacobian sparsity pattern.
    pub fn nnz(&self) -> usize {
        self.pattern.nnz()
    }

    /// Number of pattern nonzeros that fall in state columns (`F_z`).
    pub fn fz_nnz(&self) -> usize {
        self.fz_indices.len()
    }

    /// Number of pattern nonzeros that fall in parameter columns (`F_x`).
    pub fn fx_nnz(&self) -> usize {
        self.fx_indices.len()
    }
}
/// Collects the `F_z` entries of the Jacobian as `(row, col, value)` triplets.
///
/// `jac_values` is indexed by the same positions as the context's sparsity
/// pattern; only positions listed in `ctx.fz_indices` are emitted.
fn extract_fz_triplets(
    ctx: &SparseImplicitContext,
    jac_values: &[f64],
) -> Vec<faer::sparse::Triplet<usize, usize, f64>> {
    let mut triplets = Vec::with_capacity(ctx.fz_indices.len());
    for &k in &ctx.fz_indices {
        triplets.push(faer::sparse::Triplet {
            row: ctx.pattern.rows[k] as usize,
            col: ctx.pattern.cols[k] as usize,
            val: jac_values[k],
        });
    }
    triplets
}
/// Assembles the sparse `F_z` matrix, factors it with sparse LU, and validates
/// the factorization with a probe solve.
///
/// The probe right-hand side alternates `+1, -1, +1, ...`. The factorization is
/// rejected when the probe solution is non-finite, or when the residual
/// `||F_z * sol - rhs||` exceeds `sqrt(EPSILON) * sqrt(m)` relative to `||rhs||`.
///
/// # Errors
///
/// * `StructuralSingular` — building the sparse matrix from triplets failed.
/// * `FactorSingular` — the sparse LU factorization failed.
/// * `NumericSingular` — the probe solution or its residual is non-finite.
/// * `ResidualExceeded` — the relative probe residual is above tolerance.
fn build_fz_and_factor(
    ctx: &SparseImplicitContext,
    jac_values: &[f64],
) -> Result<faer::sparse::linalg::solvers::Lu<usize, f64>, SparseImplicitError> {
    let dim = ctx.num_states;

    let fz_entries = extract_fz_triplets(ctx, jac_values);
    let fz = SparseColMat::<usize, f64>::try_new_from_triplets(dim, dim, &fz_entries).map_err(
        |err| SparseImplicitError::StructuralSingular {
            source: Box::new(err),
        },
    )?;
    let factorization = fz
        .sp_lu()
        .map_err(|err| SparseImplicitError::FactorSingular {
            source: Box::new(err),
        })?;

    // Probe system: alternating-sign right-hand side.
    let probe_rhs: Vec<f64> = [1.0_f64, -1.0]
        .iter()
        .copied()
        .cycle()
        .take(dim)
        .collect();
    let probe_rhs_col = Col::<f64>::from_fn(dim, |i| probe_rhs[i]);
    let probe_sol_col = factorization.solve(&probe_rhs_col);
    let probe_sol: Vec<f64> = (0..dim).map(|i| probe_sol_col[i]).collect();
    if !probe_sol.iter().all(|value| value.is_finite()) {
        return Err(SparseImplicitError::NumericSingular);
    }

    // Recompute F_z * sol from the raw Jacobian entries, independently of the factorization.
    let mut fz_times_sol = vec![0.0_f64; dim];
    for &k in &ctx.fz_indices {
        let (r, c) = (
            ctx.pattern.rows[k] as usize,
            ctx.pattern.cols[k] as usize,
        );
        fz_times_sol[r] += jac_values[k] * probe_sol[c];
    }

    // Squared norms of the residual and of the right-hand side, accumulated in index order.
    let (resid_sq, rhs_sq) = fz_times_sol.iter().zip(&probe_rhs).fold(
        (0.0_f64, 0.0_f64),
        |(resid_acc, rhs_acc), (&ax, &b)| {
            let r = ax - b;
            (resid_acc + r * r, rhs_acc + b * b)
        },
    );

    let tol = f64::EPSILON.sqrt() * (dim as f64).sqrt();
    if !resid_sq.is_finite() {
        return Err(SparseImplicitError::NumericSingular);
    }
    // Compare squared quantities so no square root is needed on the success path.
    if resid_sq > tol * tol * rhs_sq {
        return Err(SparseImplicitError::ResidualExceeded {
            relative_residual: (resid_sq / rhs_sq).sqrt(),
            tolerance: tol,
            dimension: dim,
        });
    }

    Ok(factorization)
}
/// Computes `F_x * v` from the sparse Jacobian values.
///
/// `v` is indexed by parameter column (pattern column minus `num_states`);
/// the result has one entry per state row.
fn fx_matvec(ctx: &SparseImplicitContext, jac_values: &[f64], v: &[f64]) -> Vec<f64> {
    let state_count = ctx.num_states;
    let mut product = vec![0.0; state_count];
    ctx.fx_indices.iter().for_each(|&k| {
        let out_row = ctx.pattern.rows[k] as usize;
        let param_col = ctx.pattern.cols[k] as usize - state_count;
        product[out_row] += jac_values[k] * v[param_col];
    });
    product
}
/// Computes `F_x^T * v` from the sparse Jacobian values.
///
/// `v` is indexed by state row; the result has one entry per parameter
/// (pattern column minus `num_states`).
fn fx_transpose_matvec(ctx: &SparseImplicitContext, jac_values: &[f64], v: &[f64]) -> Vec<f64> {
    let state_count = ctx.num_states;
    let mut product = vec![0.0; ctx.num_params];
    for k in ctx.fx_indices.iter().copied() {
        let in_row = ctx.pattern.rows[k] as usize;
        let param_col = ctx.pattern.cols[k] as usize - state_count;
        product[param_col] += jac_values[k] * v[in_row];
    }
    product
}
/// Evaluates the tape and its sparse Jacobian at the concatenated point `(z_star, x)`.
///
/// Returns the tape outputs together with the Jacobian values as produced by
/// `sparse_jacobian_with_pattern` for the context's pattern and coloring.
///
/// In builds with debug assertions enabled, a warning is written to stderr
/// when the Euclidean norm of the outputs exceeds `1e-6`.
fn compute_sparse_jacobian(
    tape: &mut BytecodeTape<f64>,
    z_star: &[f64],
    x: &[f64],
    ctx: &SparseImplicitContext,
) -> (Vec<f64>, Vec<f64>) {
    // States first, parameters second — the same layout the context's column split uses.
    let inputs: Vec<f64> = z_star.iter().chain(x).copied().collect();

    let (outputs, jac_values) = tape.sparse_jacobian_with_pattern(
        &inputs,
        &ctx.pattern,
        &ctx.colors,
        ctx.num_colors,
        ctx.forward_mode,
    );

    if cfg!(debug_assertions) {
        let norm = outputs.iter().map(|v| v * v).sum::<f64>().sqrt();
        if norm > 1e-6 {
            eprintln!(
                "WARNING: sparse implicit differentiation called with ||F(z*, x)|| = {:.6e} > 1e-6. \
                 Derivatives may be meaningless if z* is not a root.",
                norm
            );
        }
    }

    (outputs, jac_values)
}
/// Copies the first `len` entries of a `faer` column into a `Vec<f64>`.
fn col_to_vec(col: &Col<f64>, len: usize) -> Vec<f64> {
    let mut values = Vec::with_capacity(len);
    for i in 0..len {
        values.push(col[i]);
    }
    values
}
/// Forward-mode implicit derivative: solves `F_z * z_dot = -(F_x * x_dot)`
/// at `(z_star, x)` and returns `z_dot`.
///
/// # Errors
///
/// Returns `DimensionMismatch` when `z_star`, `x` or `x_dot` (checked in that
/// order) has the wrong length, and forwards any error from assembling,
/// factoring or probing `F_z`.
pub fn implicit_tangent_sparse(
    tape: &mut BytecodeTape<f64>,
    z_star: &[f64],
    x: &[f64],
    x_dot: &[f64],
    ctx: &SparseImplicitContext,
) -> Result<Vec<f64>, SparseImplicitError> {
    let state_count = ctx.num_states;
    let param_count = ctx.num_params;

    // Validate in declaration order so the first offending argument is reported.
    let length_checks = [
        ("z_star", state_count, z_star.len()),
        ("x", param_count, x.len()),
        ("x_dot", param_count, x_dot.len()),
    ];
    for (field, expected, actual) in length_checks {
        if actual != expected {
            return Err(SparseImplicitError::DimensionMismatch {
                field,
                expected,
                actual,
            });
        }
    }

    let (_outputs, jac_values) = compute_sparse_jacobian(tape, z_star, x, ctx);
    let fx_xdot = fx_matvec(ctx, &jac_values, x_dot);
    let factorization = build_fz_and_factor(ctx, &jac_values)?;

    let neg_rhs = Col::<f64>::from_fn(state_count, |i| -fx_xdot[i]);
    let z_dot = factorization.solve(&neg_rhs);
    Ok(col_to_vec(&z_dot, state_count))
}
/// Reverse-mode implicit derivative: solves `F_z^T * lambda = z_bar`
/// at `(z_star, x)` and returns `x_bar = -(F_x^T * lambda)`.
///
/// # Errors
///
/// Returns `DimensionMismatch` when `z_star`, `x` or `z_bar` (checked in that
/// order) has the wrong length, and forwards any error from assembling,
/// factoring or probing `F_z`.
pub fn implicit_adjoint_sparse(
    tape: &mut BytecodeTape<f64>,
    z_star: &[f64],
    x: &[f64],
    z_bar: &[f64],
    ctx: &SparseImplicitContext,
) -> Result<Vec<f64>, SparseImplicitError> {
    let state_count = ctx.num_states;
    let param_count = ctx.num_params;

    // Validate in declaration order so the first offending argument is reported.
    let length_checks = [
        ("z_star", state_count, z_star.len()),
        ("x", param_count, x.len()),
        ("z_bar", state_count, z_bar.len()),
    ];
    for (field, expected, actual) in length_checks {
        if actual != expected {
            return Err(SparseImplicitError::DimensionMismatch {
                field,
                expected,
                actual,
            });
        }
    }

    let (_outputs, jac_values) = compute_sparse_jacobian(tape, z_star, x, ctx);
    let factorization = build_fz_and_factor(ctx, &jac_values)?;

    let seed = Col::<f64>::from_fn(state_count, |i| z_bar[i]);
    let lambda = col_to_vec(&factorization.solve_transpose(&seed), state_count);

    let x_bar = fx_transpose_matvec(ctx, &jac_values, &lambda)
        .into_iter()
        .map(|component| -component)
        .collect();
    Ok(x_bar)
}
/// Computes the dense implicit Jacobian `dz/dx` at `(z_star, x)`.
///
/// For each parameter column `j`, solves `F_z * col = -F_x[:, j]` with a single
/// shared LU factorization. The result is row-major: `result[i][j]` is the
/// derivative of state `i` with respect to parameter `j`, giving
/// `num_states` rows of `num_params` entries each.
///
/// Parameter columns with no entries in the sparsity pattern are left at zero
/// without running a solve, since their right-hand side is identically zero.
///
/// # Errors
///
/// Returns `DimensionMismatch` when `z_star` or `x` has the wrong length, and
/// forwards any error from assembling, factoring or probing `F_z`.
pub fn implicit_jacobian_sparse(
    tape: &mut BytecodeTape<f64>,
    z_star: &[f64],
    x: &[f64],
    ctx: &SparseImplicitContext,
) -> Result<Vec<Vec<f64>>, SparseImplicitError> {
    let m = ctx.num_states;
    let n = ctx.num_params;
    if z_star.len() != m {
        return Err(SparseImplicitError::DimensionMismatch {
            field: "z_star",
            expected: m,
            actual: z_star.len(),
        });
    }
    if x.len() != n {
        return Err(SparseImplicitError::DimensionMismatch {
            field: "x",
            expected: n,
            actual: x.len(),
        });
    }

    let (_outputs, jac_values) = compute_sparse_jacobian(tape, z_star, x, ctx);
    let lu = build_fz_and_factor(ctx, &jac_values)?;

    let mut result = vec![vec![0.0; n]; m];
    // One scratch buffer for the negated F_x column, reused across all parameters.
    let mut neg_col = vec![0.0_f64; m];
    for (j, fx_col_indices) in ctx.fx_by_col.iter().enumerate() {
        // A structurally empty column gives a zero right-hand side and therefore a
        // zero solution, which `result` already holds — skip the solve.
        if fx_col_indices.is_empty() {
            continue;
        }

        neg_col.fill(0.0);
        for &k in fx_col_indices {
            let row = ctx.pattern.rows[k] as usize;
            neg_col[row] -= jac_values[k];
        }

        let rhs = Col::<f64>::from_fn(m, |i| neg_col[i]);
        let sol = lu.solve(&rhs);
        for (i, out_row) in result.iter_mut().enumerate() {
            out_row[j] = sol[i];
        }
    }
    Ok(result)
}