use echidna::sparse::{column_coloring, row_coloring, JacobianSparsityPattern};
use echidna::BytecodeTape;
use faer::linalg::solvers::Solve;
use faer::sparse::SparseColMat;
use faer::Col;
/// Reusable sparsity analysis for implicit differentiation of `F(z, x) = 0` recorded on a
/// [`BytecodeTape`].
///
/// The tape's inputs are laid out as `[z; x]`: `num_states` states followed by `num_params`
/// parameters. It has exactly `num_states` outputs. Every nonzero position `k` of the full
/// Jacobian pattern is pre-sorted into the state block `F_z` or the parameter block `F_x`.
pub struct SparseImplicitContext {
    /// Number of state variables `z` (equal to the number of tape outputs).
    num_states: usize,
    /// Number of parameters `x` (tape inputs beyond the states).
    num_params: usize,
    /// Sparsity pattern of the full Jacobian `[F_z | F_x]`, as detected from the tape.
    pattern: JacobianSparsityPattern,
    /// Coloring returned by `column_coloring` or `row_coloring`, whichever was chosen.
    colors: Vec<u32>,
    /// Color count returned alongside `colors`.
    num_colors: u32,
    /// `true` when the column coloring was chosen, `false` for the row coloring.
    forward_mode: bool,
    /// Nonzero positions of `pattern` whose column is a state (`col < num_states`).
    fz_indices: Vec<usize>,
    /// Nonzero positions of `pattern` whose column is a parameter (`col >= num_states`).
    fx_indices: Vec<usize>,
    /// `fx_indices` grouped by parameter: entry `j` lists positions in column `num_states + j`.
    fx_by_col: Vec<Vec<usize>>,
}
impl SparseImplicitContext {
    /// Analyzes `tape` once, so that repeated implicit-differentiation calls can reuse the
    /// sparsity pattern, its coloring, and the `F_z` / `F_x` split of the nonzeros.
    ///
    /// The tape must take `[z; x]` as inputs (`num_states` states, then the parameters)
    /// and produce exactly `num_states` outputs.
    ///
    /// # Panics
    ///
    /// Panics if `tape.num_outputs() != num_states`, or if `tape.num_inputs() <= num_states`
    /// (at least one parameter is required).
    pub fn new(tape: &BytecodeTape<f64>, num_states: usize) -> Self {
        assert_eq!(
            tape.num_outputs(),
            num_states,
            "tape.num_outputs() ({}) must equal num_states ({})",
            tape.num_outputs(),
            num_states
        );
        let total_inputs = tape.num_inputs();
        assert!(
            total_inputs > num_states,
            "tape.num_inputs() ({}) must be greater than num_states ({})",
            total_inputs,
            num_states
        );
        let num_params = total_inputs - num_states;

        let pattern = tape.detect_jacobian_sparsity();

        // Keep whichever compression needs fewer colors; ties favour the column coloring.
        let (col_colors, col_ncolors) = column_coloring(&pattern);
        let (row_colors, row_ncolors) = row_coloring(&pattern);
        let forward_mode = col_ncolors <= row_ncolors;
        let (colors, num_colors) = if forward_mode {
            (col_colors, col_ncolors)
        } else {
            (row_colors, row_ncolors)
        };

        // Sort each nonzero into F_z (state column) or F_x (parameter column). Parameter
        // columns are stored shifted down by `num_states`.
        let mut fz_indices = Vec::new();
        let mut fx_indices = Vec::new();
        let mut fx_by_col: Vec<Vec<usize>> = vec![Vec::new(); num_params];
        for nz in 0..pattern.nnz() {
            match (pattern.cols[nz] as usize).checked_sub(num_states) {
                None => fz_indices.push(nz),
                Some(param) => {
                    fx_indices.push(nz);
                    fx_by_col[param].push(nz);
                }
            }
        }

        Self {
            pattern,
            colors,
            num_colors,
            forward_mode,
            num_states,
            num_params,
            fz_indices,
            fx_indices,
            fx_by_col,
        }
    }

    /// Number of state variables `z`.
    pub fn num_states(&self) -> usize {
        self.num_states
    }

    /// Number of parameters `x`.
    pub fn num_params(&self) -> usize {
        self.num_params
    }

    /// Structural nonzeros in the full Jacobian `[F_z | F_x]`.
    pub fn nnz(&self) -> usize {
        self.pattern.nnz()
    }

    /// Structural nonzeros in the state block `F_z`.
    pub fn fz_nnz(&self) -> usize {
        self.fz_indices.len()
    }

    /// Structural nonzeros in the parameter block `F_x`.
    pub fn fx_nnz(&self) -> usize {
        self.fx_indices.len()
    }
}
/// Collects the state block `F_z` of the Jacobian as `(row, col, value)` triplets.
///
/// `jac_values` is indexed by nonzero position in `ctx.pattern`. Only the positions
/// recorded in `ctx.fz_indices` are emitted, in that order.
fn extract_fz_triplets(
    ctx: &SparseImplicitContext,
    jac_values: &[f64],
) -> Vec<faer::sparse::Triplet<usize, usize, f64>> {
    let mut triplets = Vec::with_capacity(ctx.fz_indices.len());
    for &nz in &ctx.fz_indices {
        triplets.push(faer::sparse::Triplet {
            row: ctx.pattern.rows[nz] as usize,
            col: ctx.pattern.cols[nz] as usize,
            val: jac_values[nz],
        });
    }
    triplets
}
/// Assembles the square `num_states x num_states` matrix `F_z` from `jac_values` and
/// LU-factors it.
///
/// Returns `None` in any of these cases:
/// - the sparse matrix cannot be built from the triplets;
/// - the sparse LU factorization fails;
/// - solving against an all-ones right-hand side yields any non-finite entry. This is a
///   cheap heuristic for a singular or numerically broken `F_z`.
fn build_fz_and_factor(
    ctx: &SparseImplicitContext,
    jac_values: &[f64],
) -> Option<faer::sparse::linalg::solvers::Lu<usize, f64>> {
    let dim = ctx.num_states;
    let fz = SparseColMat::<usize, f64>::try_new_from_triplets(
        dim,
        dim,
        &extract_fz_triplets(ctx, jac_values),
    )
    .ok()?;
    let lu = fz.sp_lu().ok()?;

    let probe = lu.solve(&Col::<f64>::from_fn(dim, |_| 1.0));
    let all_finite = (0..dim).all(|i| probe[i].is_finite());
    all_finite.then_some(lu)
}
/// Computes `F_x * v` and returns a vector of length `num_states`.
///
/// `v` is indexed by parameter position (pattern column minus `num_states`). Products are
/// accumulated in `ctx.fx_indices` order.
fn fx_matvec(ctx: &SparseImplicitContext, jac_values: &[f64], v: &[f64]) -> Vec<f64> {
    let offset = ctx.num_states;
    let mut out = vec![0.0; offset];
    ctx.fx_indices.iter().for_each(|&nz| {
        let r = ctx.pattern.rows[nz] as usize;
        let c = ctx.pattern.cols[nz] as usize - offset;
        out[r] += jac_values[nz] * v[c];
    });
    out
}
/// Computes `F_x^T * v` and returns a vector of length `num_params`.
///
/// `v` is indexed by Jacobian row (state-equation index). Products are accumulated in
/// `ctx.fx_indices` order.
fn fx_transpose_matvec(ctx: &SparseImplicitContext, jac_values: &[f64], v: &[f64]) -> Vec<f64> {
    let offset = ctx.num_states;
    let mut out = vec![0.0; ctx.num_params];
    ctx.fx_indices.iter().for_each(|&nz| {
        let r = ctx.pattern.rows[nz] as usize;
        let c = ctx.pattern.cols[nz] as usize - offset;
        out[c] += jac_values[nz] * v[r];
    });
    out
}
/// Evaluates `F(z_star, x)` and the nonzero values of its full Jacobian `[F_z | F_x]`, using
/// the pattern and coloring precomputed in `ctx`.
///
/// Returns `(outputs, jac_values)`, where `jac_values` is indexed by nonzero position in
/// `ctx.pattern`. In debug builds, a warning is printed to stderr when the residual norm
/// `||F(z_star, x)||_2` exceeds `1e-6`, because implicit derivatives only make sense at a root.
fn compute_sparse_jacobian(
    tape: &mut BytecodeTape<f64>,
    z_star: &[f64],
    x: &[f64],
    ctx: &SparseImplicitContext,
) -> (Vec<f64>, Vec<f64>) {
    // Tape inputs are the states followed by the parameters.
    let inputs: Vec<f64> = z_star.iter().chain(x.iter()).copied().collect();
    let (outputs, jac_values) = tape.sparse_jacobian_with_pattern(
        &inputs,
        &ctx.pattern,
        &ctx.colors,
        ctx.num_colors,
        ctx.forward_mode,
    );

    #[cfg(debug_assertions)]
    {
        let residual = outputs.iter().map(|v| v * v).sum::<f64>().sqrt();
        if residual > 1e-6 {
            eprintln!(
                "WARNING: sparse implicit differentiation called with ||F(z*, x)|| = {:.6e} > 1e-6. \
                 Derivatives may be meaningless if z* is not a root.",
                residual
            );
        }
    }

    (outputs, jac_values)
}
/// Copies the first `len` entries of a faer column into a new `Vec<f64>`.
fn col_to_vec(col: &Col<f64>, len: usize) -> Vec<f64> {
    let mut out = Vec::with_capacity(len);
    for i in 0..len {
        out.push(col[i]);
    }
    out
}
/// Forward-mode (tangent) sensitivity of the implicit solution `z*(x)` of `F(z, x) = 0`.
///
/// For a parameter direction `x_dot`, returns `z_dot = -F_z^{-1} (F_x x_dot)`. Both Jacobian
/// blocks are evaluated at `(z_star, x)`.
///
/// Returns `None` if `F_z` cannot be assembled or factored, or appears singular.
///
/// # Panics
///
/// Panics if `z_star.len() != num_states`, or if `x.len()` or `x_dot.len()` differs from
/// `num_params`.
pub fn implicit_tangent_sparse(
    tape: &mut BytecodeTape<f64>,
    z_star: &[f64],
    x: &[f64],
    x_dot: &[f64],
    ctx: &SparseImplicitContext,
) -> Option<Vec<f64>> {
    let (m, n) = (ctx.num_states, ctx.num_params);
    assert_eq!(
        z_star.len(),
        m,
        "z_star length ({}) must equal num_states ({})",
        z_star.len(),
        m
    );
    assert_eq!(
        x.len(),
        n,
        "x length ({}) must equal num_params ({})",
        x.len(),
        n
    );
    assert_eq!(
        x_dot.len(),
        n,
        "x_dot length ({}) must equal num_params ({})",
        x_dot.len(),
        n
    );

    let (_, jac_values) = compute_sparse_jacobian(tape, z_star, x, ctx);
    // Factor first: if F_z is unusable there is no point forming F_x * x_dot.
    let lu = build_fz_and_factor(ctx, &jac_values)?;
    let fx_xdot = fx_matvec(ctx, &jac_values, x_dot);
    let rhs = Col::<f64>::from_fn(m, |i| -fx_xdot[i]);
    Some(col_to_vec(&lu.solve(&rhs), m))
}
/// Reverse-mode (adjoint) sensitivity of the implicit solution `z*(x)` of `F(z, x) = 0`.
///
/// For a state cotangent `z_bar`, solves `F_z^T lambda = z_bar` and returns
/// `x_bar = -F_x^T lambda`. Both Jacobian blocks are evaluated at `(z_star, x)`.
///
/// Returns `None` if `F_z` cannot be assembled or factored, or appears singular.
///
/// # Panics
///
/// Panics if `z_star.len()` or `z_bar.len()` differs from `num_states`, or if
/// `x.len() != num_params`.
pub fn implicit_adjoint_sparse(
    tape: &mut BytecodeTape<f64>,
    z_star: &[f64],
    x: &[f64],
    z_bar: &[f64],
    ctx: &SparseImplicitContext,
) -> Option<Vec<f64>> {
    let (m, n) = (ctx.num_states, ctx.num_params);
    assert_eq!(
        z_star.len(),
        m,
        "z_star length ({}) must equal num_states ({})",
        z_star.len(),
        m
    );
    assert_eq!(
        x.len(),
        n,
        "x length ({}) must equal num_params ({})",
        x.len(),
        n
    );
    assert_eq!(
        z_bar.len(),
        m,
        "z_bar length ({}) must equal num_states ({})",
        z_bar.len(),
        m
    );

    let (_, jac_values) = compute_sparse_jacobian(tape, z_star, x, ctx);
    let lu = build_fz_and_factor(ctx, &jac_values)?;

    // lambda = F_z^{-T} z_bar
    let lambda = lu.solve_transpose(&Col::<f64>::from_fn(m, |i| z_bar[i]));
    let mut x_bar = fx_transpose_matvec(ctx, &jac_values, &col_to_vec(&lambda, m));
    x_bar.iter_mut().for_each(|entry| *entry = -*entry);
    Some(x_bar)
}
/// Dense implicit Jacobian `dz*/dx = -F_z^{-1} F_x`, evaluated at `(z_star, x)`.
///
/// The result is row-major: `result[i][j]` is `dz_i / dx_j`, with `num_states` rows of
/// `num_params` entries each. One sparse triangular solve is performed for each parameter
/// column that has at least one structural nonzero in `F_x`. Other columns have a zero
/// right-hand side, so their sensitivities are zero and the solve is skipped.
///
/// Returns `None` if `F_z` cannot be assembled or factored, or appears singular.
///
/// # Panics
///
/// Panics if `z_star.len() != num_states` or `x.len() != num_params`.
pub fn implicit_jacobian_sparse(
    tape: &mut BytecodeTape<f64>,
    z_star: &[f64],
    x: &[f64],
    ctx: &SparseImplicitContext,
) -> Option<Vec<Vec<f64>>> {
    let m = ctx.num_states;
    let n = ctx.num_params;
    assert_eq!(
        z_star.len(),
        m,
        "z_star length ({}) must equal num_states ({})",
        z_star.len(),
        m
    );
    assert_eq!(
        x.len(),
        n,
        "x length ({}) must equal num_params ({})",
        x.len(),
        n
    );

    let (_outputs, jac_values) = compute_sparse_jacobian(tape, z_star, x, ctx);
    let lu = build_fz_and_factor(ctx, &jac_values)?;

    let mut result = vec![vec![0.0; n]; m];
    // Scratch buffer for -F_x[:, j], reused across columns instead of reallocated.
    let mut neg_col = vec![0.0; m];
    for (j, fx_col_indices) in ctx.fx_by_col.iter().enumerate() {
        // F does not depend structurally on x_j: the right-hand side is zero, so the
        // solution is the zero column `result` already holds.
        if fx_col_indices.is_empty() {
            continue;
        }
        neg_col.fill(0.0);
        for &k in fx_col_indices {
            let row = ctx.pattern.rows[k] as usize;
            neg_col[row] -= jac_values[k];
        }
        let rhs = Col::<f64>::from_fn(m, |i| neg_col[i]);
        let sol = lu.solve(&rhs);
        for (i, result_row) in result.iter_mut().enumerate() {
            result_row[j] = sol[i];
        }
    }
    Some(result)
}