Skip to main content

echidna_optim/
implicit.rs

1use std::fmt;
2
3use echidna::{BytecodeTape, Dual, Float};
4
5use crate::linalg::{lu_back_solve, lu_factor, lu_solve};
6
/// Reason a dense implicit-differentiation call failed.
///
/// Marked `#[non_exhaustive]` so future variants can be added without
/// breaking exhaustive `match`es.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare outcomes
/// directly, e.g. `assert_eq!(err, ImplicitError::Singular)`, instead of
/// pattern-matching every time.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImplicitError {
    /// `F_z` could not be used for a reliable solve. Fires in three
    /// situations, currently collapsed under one name because the dense
    /// LU does not expose which branch tripped:
    ///
    /// 1. **Structural singularity** — `linalg::lu_factor` encountered
    ///    an exactly-zero pivot (e.g. a rank-deficient `F_z`).
    /// 2. **Numeric singularity** — a pivot below `ε·n·‖F_z‖∞` (the
    ///    relative threshold anchored on the matrix infinity norm).
    /// 3. **Non-finite input or intermediate** — a NaN or ±Inf reached
    ///    the LU pivot (rejected up-front by `lu_factor` to prevent
    ///    silent NaN-tainted factors), or, for `implicit_hvp` /
    ///    `implicit_hessian`, the nested-dual forward pass produced
    ///    non-finite higher-order coefficients that would poison the
    ///    back-solve output. Every public `implicit_*` function also
    ///    inspects its final solve result and reports this variant
    ///    rather than returning `Ok` with NaN or ±Inf entries.
    ///
    /// `#[non_exhaustive]` leaves room to split these cases (or to
    /// align with a future unified naming axis across dense and sparse
    /// implicit modules) without a breaking change.
    Singular,
    /// A runtime-supplied vector argument to a public `implicit_*`
    /// fn had an unexpected length. `field` names the argument
    /// (e.g. `"x_dot"`, `"z_bar"`, `"v"`, `"w"`), `expected` is the
    /// length the API requires (typically `num_states` or `x.len()`),
    /// `actual` is the length the caller supplied.
    ///
    /// Tape-shape contract mismatches (checked in `validate_inputs`)
    /// continue to panic — those are programmer-contract violations,
    /// not recoverable runtime failures.
    DimensionMismatch {
        /// Name of the offending argument.
        field: &'static str,
        /// Length the API requires.
        expected: usize,
        /// Length the caller supplied.
        actual: usize,
    },
}
48
49impl fmt::Display for ImplicitError {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        match self {
52            ImplicitError::Singular => {
53                write!(
54                    f,
55                    "implicit: F_z is singular, ill-conditioned, or produced a non-finite solve"
56                )
57            }
58            ImplicitError::DimensionMismatch {
59                field,
60                expected,
61                actual,
62            } => {
63                write!(
64                    f,
65                    "implicit: dimension mismatch for `{field}` (expected {expected}, got {actual})"
66                )
67            }
68        }
69    }
70}
71
72impl std::error::Error for ImplicitError {}
73
// Macro-based check supplied by the `echidna` crate. Judging by its name it
// presumably asserts at compile time that `ImplicitError` is `Send + Sync`
// — NOTE(review): confirm against the macro definition in `echidna`.
echidna::assert_send_sync!(ImplicitError);
75
76/// Partition a full Jacobian `J_F` (m × (m+n)) into `F_z` (m × m) and `F_x` (m × n).
77///
78/// `num_states` is `m`, the number of state variables (first `m` columns → `F_z`).
79fn partition_jacobian<F: Float>(jac: &[Vec<F>], num_states: usize) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
80    let m = num_states;
81    let mut f_z = Vec::with_capacity(m);
82    let mut f_x = Vec::with_capacity(m);
83    for row in jac {
84        f_z.push(row[..m].to_vec());
85        f_x.push(row[m..].to_vec());
86    }
87    (f_z, f_x)
88}
89
90/// Transpose an m × n matrix stored as `Vec<Vec<F>>`.
91fn transpose<F: Float>(mat: &[Vec<F>]) -> Vec<Vec<F>> {
92    if mat.is_empty() {
93        return vec![];
94    }
95    let rows = mat.len();
96    let cols = mat[0].len();
97    let mut result = vec![vec![F::zero(); rows]; cols];
98    for i in 0..rows {
99        for j in 0..cols {
100            result[j][i] = mat[i][j];
101        }
102    }
103    result
104}
105
106/// Validate inputs shared by all implicit differentiation functions.
107fn validate_inputs<F: Float>(tape: &BytecodeTape<F>, z_star: &[F], x: &[F], num_states: usize) {
108    assert_eq!(
109        z_star.len(),
110        num_states,
111        "z_star length ({}) must equal num_states ({})",
112        z_star.len(),
113        num_states
114    );
115    assert_eq!(
116        tape.num_inputs(),
117        num_states + x.len(),
118        "tape.num_inputs() ({}) must equal num_states + x.len() ({})",
119        tape.num_inputs(),
120        num_states + x.len()
121    );
122    assert_eq!(
123        tape.num_outputs(),
124        num_states,
125        "tape.num_outputs() ({}) must equal num_states ({}) — IFT requires F: R^(m+n) → R^m to be square in the state block",
126        tape.num_outputs(),
127        num_states
128    );
129}
130
131/// Build concatenated input `[z_star..., x...]` and compute the full Jacobian,
132/// partitioned into `(F_z, F_x)`.
133fn compute_partitioned_jacobian<F: Float>(
134    tape: &mut BytecodeTape<F>,
135    z_star: &[F],
136    x: &[F],
137    num_states: usize,
138) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
139    let mut inputs = Vec::with_capacity(z_star.len() + x.len());
140    inputs.extend_from_slice(z_star);
141    inputs.extend_from_slice(x);
142
143    // Debug check: warn if residual is not near zero
144    #[cfg(debug_assertions)]
145    {
146        tape.forward(&inputs);
147        let residual = tape.output_values();
148        let norm_sq: F = residual.iter().fold(F::zero(), |acc, &v| acc + v * v);
149        let norm = norm_sq.sqrt();
150        let threshold = F::from(1e-6).unwrap_or_else(|| F::epsilon());
151        if norm > threshold {
152            eprintln!(
153                "WARNING: implicit differentiation called with ||F(z*, x)|| = {:?} > 1e-6. \
154                 Derivatives may be meaningless if z* is not a root.",
155                norm.to_f64()
156            );
157        }
158    }
159
160    let jac = tape.jacobian(&inputs);
161    partition_jacobian(&jac, num_states)
162}
163
164/// Compute the full implicit Jacobian `dz*/dx` (m × n matrix).
165///
166/// Given a multi-output residual tape `F: R^(m+n) → R^m` with `F(z*, x) = 0`,
167/// computes `dz*/dx = -F_z^{-1} · F_x` via the Implicit Function Theorem.
168///
169/// The first `num_states` tape inputs are state variables `z`, the remaining are
170/// parameters `x`.
171///
172/// Returns `Err(ImplicitError::Singular)` if `F_z` is singular.
173pub fn implicit_jacobian<F: Float>(
174    tape: &mut BytecodeTape<F>,
175    z_star: &[F],
176    x: &[F],
177    num_states: usize,
178) -> Result<Vec<Vec<F>>, ImplicitError> {
179    validate_inputs(tape, z_star, x, num_states);
180    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
181
182    let m = num_states;
183    let n = x.len();
184
185    // LU-factorize F_z once, then solve for each column of -F_x
186    let factors = lu_factor(&f_z).ok_or(ImplicitError::Singular)?;
187
188    // Build result column by column: solve F_z · col_j = -F_x[:, j]
189    let mut result = vec![vec![F::zero(); n]; m];
190    for j in 0..n {
191        let neg_col: Vec<F> = (0..m).map(|i| F::zero() - f_x[i][j]).collect();
192        let col = lu_back_solve(&factors, &neg_col);
193
194        // Same non-finite guard as the other publics. When `F_z` is
195        // finite but one column of `F_x` is not (possible in principle
196        // for tapes where `∂F/∂x` carries NaN without `∂F/∂z` doing so),
197        // the back-solve propagates NaN into this column. Check per
198        // column for early-return; `result` is a local and dropped on
199        // `Err`.
200        if col.iter().any(|v| !v.is_finite()) {
201            return Err(ImplicitError::Singular);
202        }
203
204        for i in 0..m {
205            result[i][j] = col[i];
206        }
207    }
208
209    Ok(result)
210}
211
212/// Compute the implicit tangent `dz*/dx · x_dot` (m-vector).
213///
214/// Given a multi-output residual tape `F: R^(m+n) → R^m` with `F(z*, x) = 0`,
215/// computes the directional derivative `dz*/dx · x_dot = -F_z^{-1} · (F_x · x_dot)`.
216///
217/// This solves a single linear system rather than computing the full Jacobian,
218/// which is more efficient when only one direction is needed.
219///
220/// Returns `Err(ImplicitError::Singular)` if `F_z` is singular.
221pub fn implicit_tangent<F: Float>(
222    tape: &mut BytecodeTape<F>,
223    z_star: &[F],
224    x: &[F],
225    x_dot: &[F],
226    num_states: usize,
227) -> Result<Vec<F>, ImplicitError> {
228    if x_dot.len() != x.len() {
229        return Err(ImplicitError::DimensionMismatch {
230            field: "x_dot",
231            expected: x.len(),
232            actual: x_dot.len(),
233        });
234    }
235    validate_inputs(tape, z_star, x, num_states);
236    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
237
238    let m = num_states;
239    let n = x.len();
240
241    // Compute F_x · x_dot (matrix-vector product)
242    let mut fx_xdot = vec![F::zero(); m];
243    for i in 0..m {
244        for j in 0..n {
245            fx_xdot[i] = fx_xdot[i] + f_x[i][j] * x_dot[j];
246        }
247    }
248
249    // Negate: rhs = -(F_x · x_dot)
250    let neg_fx_xdot: Vec<F> = fx_xdot.iter().map(|&v| F::zero() - v).collect();
251
252    // Solve F_z · z_dot = -(F_x · x_dot)
253    let sol = lu_solve(&f_z, &neg_fx_xdot).ok_or(ImplicitError::Singular)?;
254
255    // Guard: when `F_z` is finite but the RHS `-(F_x · x_dot)` contains
256    // non-finite entries (e.g. NaN in `x_dot`, or a tape whose `F_x` went
257    // non-finite without `F_z` doing so), the back-solve propagates the
258    // NaN into the returned vector. Without this check it escapes as
259    // `Ok(vec![NaN, ...])`, violating the contract that `Ok` implies a
260    // finite result.
261    if sol.iter().any(|v| !v.is_finite()) {
262        return Err(ImplicitError::Singular);
263    }
264
265    Ok(sol)
266}
267
268/// Compute the implicit adjoint `(dz*/dx)^T · z_bar` (n-vector).
269///
270/// Given a multi-output residual tape `F: R^(m+n) → R^m` with `F(z*, x) = 0`,
271/// computes `x_bar = -F_x^T · (F_z^{-T} · z_bar)`.
272///
273/// This is the reverse-mode (adjoint) form, useful when `n > m` or when
274/// propagating gradients backward through an implicit layer.
275///
276/// Returns `Err(ImplicitError::Singular)` if `F_z` is singular.
277pub fn implicit_adjoint<F: Float>(
278    tape: &mut BytecodeTape<F>,
279    z_star: &[F],
280    x: &[F],
281    z_bar: &[F],
282    num_states: usize,
283) -> Result<Vec<F>, ImplicitError> {
284    if z_bar.len() != num_states {
285        return Err(ImplicitError::DimensionMismatch {
286            field: "z_bar",
287            expected: num_states,
288            actual: z_bar.len(),
289        });
290    }
291    validate_inputs(tape, z_star, x, num_states);
292    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
293
294    let m = num_states;
295    let n = x.len();
296
297    // Solve F_z^T · lambda = z_bar
298    let f_z_t = transpose(&f_z);
299    let lambda = lu_solve(&f_z_t, z_bar).ok_or(ImplicitError::Singular)?;
300
301    // Compute x_bar = -F_x^T · lambda
302    let f_x_t = transpose(&f_x);
303    let mut x_bar = vec![F::zero(); n];
304    for j in 0..n {
305        for i in 0..m {
306            x_bar[j] = x_bar[j] - f_x_t[j][i] * lambda[i];
307        }
308    }
309
310    // Same non-finite guard as `implicit_tangent`. A non-finite `z_bar`
311    // makes the transpose-solve RHS non-finite; `lu_back_solve`
312    // propagates NaN through the substitution and without this check it
313    // escapes as `Ok(vec![NaN, ...])`.
314    if x_bar.iter().any(|v| !v.is_finite()) {
315        return Err(ImplicitError::Singular);
316    }
317
318    Ok(x_bar)
319}
320
321/// Compute the implicit Hessian-vector-vector product `d²z*/dx² · v · w` (m-vector).
322///
323/// Given a residual tape `F: R^(m+n) → R^m` with `F(z*, x) = 0`, computes the
324/// second-order sensitivity by differentiating the IFT identity twice:
325///
326///   `F_z · h + [ṗ^T · Hess(F_i) · ẇ]_i = 0`
327///
328/// where `ṗ = [dz*/dx · v; v]`, `ẇ = [dz*/dx · w; w]`, and `h = d²z*/dx² · v · w`.
329///
330/// Uses nested `Dual<Dual<F>>` forward passes to compute the second-order correction
331/// in a single O(tape_length) pass per direction pair.
332///
333/// Returns `Err(ImplicitError::Singular)` if `F_z` is singular.
334pub fn implicit_hvp<F: Float>(
335    tape: &mut BytecodeTape<F>,
336    z_star: &[F],
337    x: &[F],
338    v: &[F],
339    w: &[F],
340    num_states: usize,
341) -> Result<Vec<F>, ImplicitError> {
342    let n = x.len();
343    let m = num_states;
344    if v.len() != n {
345        return Err(ImplicitError::DimensionMismatch {
346            field: "v",
347            expected: n,
348            actual: v.len(),
349        });
350    }
351    if w.len() != n {
352        return Err(ImplicitError::DimensionMismatch {
353            field: "w",
354            expected: n,
355            actual: w.len(),
356        });
357    }
358    validate_inputs(tape, z_star, x, num_states);
359
360    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
361    let factors = lu_factor(&f_z).ok_or(ImplicitError::Singular)?;
362
363    // First-order sensitivities: ż_v = -F_z^{-1} · (F_x · v)
364    let mut fx_v = vec![F::zero(); m];
365    let mut fx_w = vec![F::zero(); m];
366    for i in 0..m {
367        for j in 0..n {
368            fx_v[i] = fx_v[i] + f_x[i][j] * v[j];
369            fx_w[i] = fx_w[i] + f_x[i][j] * w[j];
370        }
371    }
372    let neg_fx_v: Vec<F> = fx_v.iter().map(|&val| F::zero() - val).collect();
373    let neg_fx_w: Vec<F> = fx_w.iter().map(|&val| F::zero() - val).collect();
374    let z_dot_v = lu_back_solve(&factors, &neg_fx_v);
375    let z_dot_w = lu_back_solve(&factors, &neg_fx_w);
376
377    // Build Dual<Dual<F>> inputs for nested forward pass
378    // ṗ = [ż_v; v], ẇ = [ż_w; w]
379    // Input j: Dual { re: Dual(u_j, ṗ_j), eps: Dual(ẇ_j, 0) }
380    let mut dd_inputs: Vec<Dual<Dual<F>>> = Vec::with_capacity(m + n);
381    for i in 0..m {
382        dd_inputs.push(Dual::new(
383            Dual::new(z_star[i], z_dot_v[i]),
384            Dual::new(z_dot_w[i], F::zero()),
385        ));
386    }
387    for j in 0..n {
388        dd_inputs.push(Dual::new(Dual::new(x[j], v[j]), Dual::new(w[j], F::zero())));
389    }
390
391    let mut buf = Vec::new();
392    tape.forward_tangent(&dd_inputs, &mut buf);
393
394    // Extract second-order correction: buf[out_idx].eps.eps for each output
395    let out_indices = tape.all_output_indices();
396    let mut rhs = Vec::with_capacity(m);
397    for &idx in out_indices {
398        rhs.push(buf[idx as usize].eps.eps);
399    }
400
401    // Solve F_z · h = -rhs
402    let neg_rhs: Vec<F> = rhs.iter().map(|&val| F::zero() - val).collect();
403    let h = lu_back_solve(&factors, &neg_rhs);
404
405    // Guard against non-finite output. `forward_tangent` is an infallible
406    // straight-line pass; if any tape op on the `Dual<Dual<F>>` inputs
407    // produced NaN or ±Inf (e.g. a pathological higher-order derivative
408    // at a function-domain boundary), it lands in `buf[idx].eps.eps`,
409    // flows into the back-solve, and without this check would escape as
410    // `Ok(vec![NaN, ...])` — violating the contract that `Ok` implies a
411    // finite result.
412    if h.iter().any(|v| !v.is_finite()) {
413        return Err(ImplicitError::Singular);
414    }
415
416    Ok(h)
417}
418
419/// Compute the full implicit Hessian tensor `d²z*/dx²` (m × n × n).
420///
421/// Returns `result[i][j][k]` = ∂²z*_i / (∂x_j ∂x_k). The tensor is symmetric
422/// in the last two indices (j, k).
423///
424/// Cost: `n(n+1)/2` nested `Dual<Dual<F>>` forward passes plus `n(n+1)/2` back-solves,
425/// all sharing a single LU factorization of `F_z`.
426///
427/// Returns `Err(ImplicitError::Singular)` if `F_z` is singular.
428pub fn implicit_hessian<F: Float>(
429    tape: &mut BytecodeTape<F>,
430    z_star: &[F],
431    x: &[F],
432    num_states: usize,
433) -> Result<Vec<Vec<Vec<F>>>, ImplicitError> {
434    let n = x.len();
435    let m = num_states;
436    validate_inputs(tape, z_star, x, num_states);
437
438    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
439    let factors = lu_factor(&f_z).ok_or(ImplicitError::Singular)?;
440
441    // First-order sensitivity columns: S[:,j] = -F_z^{-1} · F_x[:,j]
442    let mut sens_cols: Vec<Vec<F>> = Vec::with_capacity(n);
443    for j in 0..n {
444        let neg_col: Vec<F> = f_x.iter().map(|row| F::zero() - row[j]).collect();
445        sens_cols.push(lu_back_solve(&factors, &neg_col));
446    }
447
448    let out_indices = tape.all_output_indices();
449    let mut result = vec![vec![vec![F::zero(); n]; n]; m];
450    let mut buf: Vec<Dual<Dual<F>>> = Vec::new();
451
452    for j in 0..n {
453        for k in j..n {
454            // ṗ = [S[:,j]; e_j], ẇ = [S[:,k]; e_k]
455            let mut dd_inputs: Vec<Dual<Dual<F>>> = Vec::with_capacity(m + n);
456            for i in 0..m {
457                dd_inputs.push(Dual::new(
458                    Dual::new(z_star[i], sens_cols[j][i]),
459                    Dual::new(sens_cols[k][i], F::zero()),
460                ));
461            }
462            for (l, &x_l) in x.iter().enumerate() {
463                let p_l = if l == j { F::one() } else { F::zero() };
464                let w_l = if l == k { F::one() } else { F::zero() };
465                dd_inputs.push(Dual::new(Dual::new(x_l, p_l), Dual::new(w_l, F::zero())));
466            }
467
468            tape.forward_tangent(&dd_inputs, &mut buf);
469
470            // Extract RHS and solve
471            let mut rhs = Vec::with_capacity(m);
472            for &idx in out_indices {
473                rhs.push(buf[idx as usize].eps.eps);
474            }
475            let neg_rhs: Vec<F> = rhs.iter().map(|&val| F::zero() - val).collect();
476            let h = lu_back_solve(&factors, &neg_rhs);
477
478            // Same non-finite guard as `implicit_hvp`. A single bad (j, k)
479            // pair from a pathological higher-order derivative would
480            // otherwise corrupt one symmetric plane of the returned tensor
481            // while leaving the rest apparently valid.
482            if h.iter().any(|v| !v.is_finite()) {
483                return Err(ImplicitError::Singular);
484            }
485
486            for i in 0..m {
487                result[i][j][k] = h[i];
488                result[i][k][j] = h[i]; // Symmetric
489            }
490        }
491    }
492
493    Ok(result)
494}