use std::fmt;
use echidna::{BytecodeTape, Dual, Float};
use crate::linalg::{lu_back_solve, lu_factor, lu_solve};
/// Errors returned by the implicit-differentiation routines in this file.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum ImplicitError {
/// The state-block Jacobian `F_z` could not be LU-factored or solved, or a
/// solve produced a non-finite value.
Singular,
/// A caller-supplied vector does not have the length the routine requires.
DimensionMismatch {
/// Name of the offending argument (for example `"x_dot"` or `"z_bar"`).
field: &'static str,
/// Length the argument was required to have.
expected: usize,
/// Length the argument actually had.
actual: usize,
},
}
impl fmt::Display for ImplicitError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ImplicitError::Singular => {
write!(
f,
"implicit: F_z is singular, ill-conditioned, or produced a non-finite solve"
)
}
ImplicitError::DimensionMismatch {
field,
expected,
actual,
} => {
write!(
f,
"implicit: dimension mismatch for `{field}` (expected {expected}, got {actual})"
)
}
}
}
}
// Empty impl: `ImplicitError` relies on the default `Error` methods (it wraps no source error).
impl std::error::Error for ImplicitError {}
// NOTE(review): presumably a compile-time check that `ImplicitError` is `Send + Sync` —
// confirm against the macro definition in `echidna`.
echidna::assert_send_sync!(ImplicitError);
/// Splits every row of `jac` at column `num_states`.
///
/// Returns `(f_z, f_x)`: `f_z` holds the first `num_states` entries of each
/// row and `f_x` holds the remaining entries, in the original row order.
///
/// # Panics
///
/// Panics if any row has fewer than `num_states` entries.
fn partition_jacobian<F: Float>(jac: &[Vec<F>], num_states: usize) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
    jac.iter()
        .map(|row| {
            let (state_part, param_part) = row.split_at(num_states);
            (state_part.to_vec(), param_part.to_vec())
        })
        .unzip()
}
/// Returns the transpose of a row-major matrix stored as `Vec` rows.
///
/// The column count is taken from the first row. An empty matrix, or one
/// whose first row is empty, transposes to an empty `Vec`.
///
/// # Panics
///
/// Panics if a later row is shorter than the first row.
fn transpose<F: Float>(mat: &[Vec<F>]) -> Vec<Vec<F>> {
    // With no rows the column count is zero, so the result is empty.
    let cols = mat.first().map_or(0, Vec::len);
    (0..cols)
        .map(|j| mat.iter().map(|row| row[j]).collect())
        .collect()
}
/// Asserts that the tape and the supplied slices have mutually consistent sizes.
///
/// Checked in order: `z_star.len() == num_states`,
/// `tape.num_inputs() == num_states + x.len()`, and
/// `tape.num_outputs() == num_states`.
///
/// # Panics
///
/// Panics with a descriptive message on the first check that fails.
fn validate_inputs<F: Float>(tape: &BytecodeTape<F>, z_star: &[F], x: &[F], num_states: usize) {
    let state_len = z_star.len();
    assert_eq!(
        state_len, num_states,
        "z_star length ({}) must equal num_states ({})",
        state_len, num_states
    );

    let tape_inputs = tape.num_inputs();
    let required_inputs = num_states + x.len();
    assert_eq!(
        tape_inputs, required_inputs,
        "tape.num_inputs() ({}) must equal num_states + x.len() ({})",
        tape_inputs, required_inputs
    );

    let tape_outputs = tape.num_outputs();
    assert_eq!(
        tape_outputs, num_states,
        "tape.num_outputs() ({}) must equal num_states ({}) — IFT requires F: R^(m+n) → R^m to be square in the state block",
        tape_outputs, num_states
    );
}
/// Evaluates the tape's Jacobian at the concatenated point `[z_star, x]` and
/// splits it into the state block and the parameter block.
///
/// Returns `(f_z, f_x)` as produced by `partition_jacobian` with the split at
/// column `num_states`.
///
/// In builds with `debug_assertions`, the residual `F(z*, x)` is evaluated
/// first and a warning is printed to stderr when its Euclidean norm exceeds
/// `1e-6` or is not finite.
fn compute_partitioned_jacobian<F: Float>(
    tape: &mut BytecodeTape<F>,
    z_star: &[F],
    x: &[F],
    num_states: usize,
) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
    // Tape inputs are laid out as the states followed by the parameters.
    let mut inputs = Vec::with_capacity(z_star.len() + x.len());
    inputs.extend_from_slice(z_star);
    inputs.extend_from_slice(x);

    #[cfg(debug_assertions)]
    {
        tape.forward(&inputs);
        let residual = tape.output_values();
        let norm_sq: F = residual.iter().fold(F::zero(), |acc, &v| acc + v * v);
        let norm = norm_sq.sqrt();
        // Fall back to machine epsilon if 1e-6 is not representable in `F`.
        let threshold = F::from(1e-6).unwrap_or_else(F::epsilon);
        // `NaN > threshold` is false, so a NaN residual must be tested for
        // explicitly or it would silently pass the root check.
        if !norm.is_finite() || norm > threshold {
            eprintln!(
                "WARNING: implicit differentiation called with ||F(z*, x)|| = {:?} > 1e-6. \
                 Derivatives may be meaningless if z* is not a root.",
                norm.to_f64()
            );
        }
    }

    let jac = tape.jacobian(&inputs);
    partition_jacobian(&jac, num_states)
}
/// Computes the sensitivity matrix `dz*/dx = -F_z⁻¹ · F_x` at `(z_star, x)`.
///
/// The result has `num_states` rows and `x.len()` columns; column `j` is the
/// solution of `F_z · col = -F_x[:, j]`, obtained from a single LU
/// factorisation of `F_z`.
///
/// # Errors
///
/// Returns [`ImplicitError::Singular`] if `lu_factor` rejects `F_z` or any
/// solved column contains a non-finite value.
///
/// # Panics
///
/// Panics if the slice lengths are inconsistent with the tape (see
/// `validate_inputs`).
pub fn implicit_jacobian<F: Float>(
    tape: &mut BytecodeTape<F>,
    z_star: &[F],
    x: &[F],
    num_states: usize,
) -> Result<Vec<Vec<F>>, ImplicitError> {
    validate_inputs(tape, z_star, x, num_states);
    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);

    let factors = match lu_factor(&f_z) {
        Some(lu) => lu,
        None => return Err(ImplicitError::Singular),
    };

    let param_count = x.len();
    let mut sensitivity = vec![vec![F::zero(); param_count]; num_states];
    for j in 0..param_count {
        // Right-hand side is the negated j-th column of F_x.
        let rhs: Vec<F> = f_x.iter().map(|row| F::zero() - row[j]).collect();
        let column = lu_back_solve(&factors, &rhs);
        if !column.iter().all(|v| v.is_finite()) {
            return Err(ImplicitError::Singular);
        }
        // Scatter the solved column into column j of the row-major result.
        for (i, row) in sensitivity.iter_mut().enumerate() {
            row[j] = column[i];
        }
    }
    Ok(sensitivity)
}
/// Computes the forward-mode (tangent) product `ż = -F_z⁻¹ · (F_x · x_dot)`.
///
/// # Errors
///
/// * [`ImplicitError::DimensionMismatch`] if `x_dot.len() != x.len()`.
/// * [`ImplicitError::Singular`] if `lu_solve` fails or the solution contains
///   a non-finite value.
///
/// # Panics
///
/// Panics if the slice lengths are inconsistent with the tape (see
/// `validate_inputs`).
pub fn implicit_tangent<F: Float>(
    tape: &mut BytecodeTape<F>,
    z_star: &[F],
    x: &[F],
    x_dot: &[F],
    num_states: usize,
) -> Result<Vec<F>, ImplicitError> {
    let param_count = x.len();
    if x_dot.len() != param_count {
        return Err(ImplicitError::DimensionMismatch {
            field: "x_dot",
            expected: param_count,
            actual: x_dot.len(),
        });
    }
    validate_inputs(tape, z_star, x, num_states);
    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);

    // rhs[i] = -(F_x[i, :] · x_dot), accumulated left to right from zero.
    let rhs: Vec<F> = f_x
        .iter()
        .map(|row| {
            let directional = row
                .iter()
                .zip(x_dot)
                .fold(F::zero(), |acc, (&a, &b)| acc + a * b);
            F::zero() - directional
        })
        .collect();

    match lu_solve(&f_z, &rhs) {
        Some(z_dot) if z_dot.iter().all(|v| v.is_finite()) => Ok(z_dot),
        _ => Err(ImplicitError::Singular),
    }
}
/// Computes the reverse-mode (adjoint) product `x̄ = -F_xᵀ · λ`, where `λ`
/// solves `F_zᵀ · λ = z_bar`.
///
/// # Errors
///
/// * [`ImplicitError::DimensionMismatch`] if `z_bar.len() != num_states`.
/// * [`ImplicitError::Singular`] if `lu_solve` fails on `F_zᵀ` or the
///   resulting `x̄` contains a non-finite value.
///
/// # Panics
///
/// Panics if the slice lengths are inconsistent with the tape (see
/// `validate_inputs`).
pub fn implicit_adjoint<F: Float>(
    tape: &mut BytecodeTape<F>,
    z_star: &[F],
    x: &[F],
    z_bar: &[F],
    num_states: usize,
) -> Result<Vec<F>, ImplicitError> {
    if z_bar.len() != num_states {
        return Err(ImplicitError::DimensionMismatch {
            field: "z_bar",
            expected: num_states,
            actual: z_bar.len(),
        });
    }
    validate_inputs(tape, z_star, x, num_states);
    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);

    let lambda = match lu_solve(&transpose(&f_z), z_bar) {
        Some(solution) => solution,
        None => return Err(ImplicitError::Singular),
    };

    // x̄[j] = -Σ_i F_x[i][j] · λ[i]. Walking column j of F_x is the same as
    // walking row j of F_xᵀ, so no transposed copy is needed.
    let x_bar: Vec<F> = (0..x.len())
        .map(|j| {
            f_x.iter()
                .zip(lambda.iter())
                .fold(F::zero(), |acc, (row, &lam)| acc - row[j] * lam)
        })
        .collect();

    if x_bar.iter().all(|v| v.is_finite()) {
        Ok(x_bar)
    } else {
        Err(ImplicitError::Singular)
    }
}
/// Computes the second-order directional sensitivity of `z*` along the
/// parameter directions `v` and `w`.
///
/// Steps:
/// 1. Solve `F_z · ż_v = -F_x · v` and `F_z · ż_w = -F_x · w` with one LU
///    factorisation of `F_z`.
/// 2. Push nested duals seeded with `(ż_v, v)` on the inner level and
///    `(ż_w, w)` on the outer level through the tape.
/// 3. Take the `eps.eps` component of each output as the right-hand side and
///    solve `F_z · h = -rhs`.
///
/// # Errors
///
/// * [`ImplicitError::DimensionMismatch`] if `v` or `w` (checked in that
///   order) does not have length `x.len()`.
/// * [`ImplicitError::Singular`] if `lu_factor` rejects `F_z` or the final
///   solve contains a non-finite value.
///
/// # Panics
///
/// Panics if the slice lengths are inconsistent with the tape (see
/// `validate_inputs`).
pub fn implicit_hvp<F: Float>(
    tape: &mut BytecodeTape<F>,
    z_star: &[F],
    x: &[F],
    v: &[F],
    w: &[F],
    num_states: usize,
) -> Result<Vec<F>, ImplicitError> {
    let n = x.len();
    for (field, direction) in [("v", v), ("w", w)] {
        if direction.len() != n {
            return Err(ImplicitError::DimensionMismatch {
                field,
                expected: n,
                actual: direction.len(),
            });
        }
    }
    validate_inputs(tape, z_star, x, num_states);
    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
    let factors = lu_factor(&f_z).ok_or(ImplicitError::Singular)?;

    // First-order state sensitivity along one parameter direction:
    // solves F_z · ż = -(F_x · direction).
    let solve_direction = |direction: &[F]| -> Vec<F> {
        let rhs: Vec<F> = f_x
            .iter()
            .map(|row| {
                let product = row
                    .iter()
                    .zip(direction)
                    .fold(F::zero(), |acc, (&a, &b)| acc + a * b);
                F::zero() - product
            })
            .collect();
        lu_back_solve(&factors, &rhs)
    };
    let z_dot_v = solve_direction(v);
    let z_dot_w = solve_direction(w);

    // Nested-dual seeds: states first, then parameters, matching the tape's
    // input layout.
    let state_seeds = z_star
        .iter()
        .zip(z_dot_v.iter())
        .zip(z_dot_w.iter())
        .map(|((&z, &dv), &dw)| Dual::new(Dual::new(z, dv), Dual::new(dw, F::zero())));
    let param_seeds = x
        .iter()
        .zip(v)
        .zip(w)
        .map(|((&x_j, &v_j), &w_j)| Dual::new(Dual::new(x_j, v_j), Dual::new(w_j, F::zero())));
    let dd_inputs: Vec<Dual<Dual<F>>> = state_seeds.chain(param_seeds).collect();

    let mut buf: Vec<Dual<Dual<F>>> = Vec::new();
    tape.forward_tangent(&dd_inputs, &mut buf);

    // Negated mixed second-order component of every tape output.
    let neg_rhs: Vec<F> = tape
        .all_output_indices()
        .iter()
        .map(|&idx| F::zero() - buf[idx as usize].eps.eps)
        .collect();

    let h = lu_back_solve(&factors, &neg_rhs);
    if h.iter().all(|value| value.is_finite()) {
        Ok(h)
    } else {
        Err(ImplicitError::Singular)
    }
}
/// Computes the full second-order sensitivity tensor of `z*` with respect to
/// `x`.
///
/// The result is indexed `[i][j][k]` with `i < num_states` and
/// `j, k < x.len()`. Only pairs with `k >= j` are evaluated; each value is
/// mirrored into `[i][k][j]`.
///
/// For every pair `(j, k)` the tape is evaluated on nested duals seeded with
/// the first-order sensitivity column `j` / unit vector `e_j` on the inner
/// level and column `k` / `e_k` on the outer level. The `eps.eps` output
/// components form the right-hand side of a solve against the LU factors of
/// `F_z`.
///
/// # Errors
///
/// Returns [`ImplicitError::Singular`] if `lu_factor` rejects `F_z` or any
/// second-order solve contains a non-finite value.
///
/// # Panics
///
/// Panics if the slice lengths are inconsistent with the tape (see
/// `validate_inputs`).
pub fn implicit_hessian<F: Float>(
    tape: &mut BytecodeTape<F>,
    z_star: &[F],
    x: &[F],
    num_states: usize,
) -> Result<Vec<Vec<Vec<F>>>, ImplicitError> {
    let n = x.len();
    let m = num_states;
    validate_inputs(tape, z_star, x, num_states);
    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
    let factors = lu_factor(&f_z).ok_or(ImplicitError::Singular)?;

    // First-order sensitivity columns: sens_cols[j] solves F_z · c = -F_x[:, j].
    let sens_cols: Vec<Vec<F>> = (0..n)
        .map(|j| {
            let neg_col: Vec<F> = f_x.iter().map(|row| F::zero() - row[j]).collect();
            lu_back_solve(&factors, &neg_col)
        })
        .collect();

    let out_indices = tape.all_output_indices();
    let mut result = vec![vec![vec![F::zero(); n]; n]; m];
    // Scratch buffer handed to every tape evaluation.
    let mut buf: Vec<Dual<Dual<F>>> = Vec::new();
    // One where the index matches the seeded direction, zero elsewhere.
    let unit = |hit: bool| if hit { F::one() } else { F::zero() };

    for j in 0..n {
        for k in j..n {
            let state_seeds = z_star
                .iter()
                .zip(sens_cols[j].iter())
                .zip(sens_cols[k].iter())
                .map(|((&z, &s_j), &s_k)| Dual::new(Dual::new(z, s_j), Dual::new(s_k, F::zero())));
            let param_seeds = x.iter().enumerate().map(|(l, &x_l)| {
                Dual::new(
                    Dual::new(x_l, unit(l == j)),
                    Dual::new(unit(l == k), F::zero()),
                )
            });
            let dd_inputs: Vec<Dual<Dual<F>>> = state_seeds.chain(param_seeds).collect();

            tape.forward_tangent(&dd_inputs, &mut buf);

            let neg_rhs: Vec<F> = out_indices
                .iter()
                .map(|&idx| F::zero() - buf[idx as usize].eps.eps)
                .collect();
            let h = lu_back_solve(&factors, &neg_rhs);
            if !h.iter().all(|value| value.is_finite()) {
                return Err(ImplicitError::Singular);
            }

            // Store the (j, k) entry and its mirror (k, j) for every state.
            for (plane, &value) in result.iter_mut().zip(h.iter()) {
                plane[j][k] = value;
                plane[k][j] = value;
            }
        }
    }
    Ok(result)
}