Skip to main content

echidna_optim/
implicit.rs

1use echidna::{BytecodeTape, Dual, Float};
2
3use crate::linalg::{lu_back_solve, lu_factor, lu_solve};
4
5/// Partition a full Jacobian `J_F` (m × (m+n)) into `F_z` (m × m) and `F_x` (m × n).
6///
7/// `num_states` is `m`, the number of state variables (first `m` columns → `F_z`).
8fn partition_jacobian<F: Float>(jac: &[Vec<F>], num_states: usize) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
9    let m = num_states;
10    let mut f_z = Vec::with_capacity(m);
11    let mut f_x = Vec::with_capacity(m);
12    for row in jac {
13        f_z.push(row[..m].to_vec());
14        f_x.push(row[m..].to_vec());
15    }
16    (f_z, f_x)
17}
18
19/// Transpose an m × n matrix stored as `Vec<Vec<F>>`.
20fn transpose<F: Float>(mat: &[Vec<F>]) -> Vec<Vec<F>> {
21    if mat.is_empty() {
22        return vec![];
23    }
24    let rows = mat.len();
25    let cols = mat[0].len();
26    let mut result = vec![vec![F::zero(); rows]; cols];
27    for i in 0..rows {
28        for j in 0..cols {
29            result[j][i] = mat[i][j];
30        }
31    }
32    result
33}
34
35/// Validate inputs shared by all implicit differentiation functions.
36fn validate_inputs<F: Float>(tape: &BytecodeTape<F>, z_star: &[F], x: &[F], num_states: usize) {
37    assert_eq!(
38        z_star.len(),
39        num_states,
40        "z_star length ({}) must equal num_states ({})",
41        z_star.len(),
42        num_states
43    );
44    assert_eq!(
45        tape.num_inputs(),
46        num_states + x.len(),
47        "tape.num_inputs() ({}) must equal num_states + x.len() ({})",
48        tape.num_inputs(),
49        num_states + x.len()
50    );
51    assert_eq!(
52        tape.num_outputs(),
53        num_states,
54        "tape.num_outputs() ({}) must equal num_states ({}) — IFT requires F: R^(m+n) → R^m to be square in the state block",
55        tape.num_outputs(),
56        num_states
57    );
58}
59
60/// Build concatenated input `[z_star..., x...]` and compute the full Jacobian,
61/// partitioned into `(F_z, F_x)`.
62fn compute_partitioned_jacobian<F: Float>(
63    tape: &mut BytecodeTape<F>,
64    z_star: &[F],
65    x: &[F],
66    num_states: usize,
67) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
68    let mut inputs = Vec::with_capacity(z_star.len() + x.len());
69    inputs.extend_from_slice(z_star);
70    inputs.extend_from_slice(x);
71
72    // Debug check: warn if residual is not near zero
73    #[cfg(debug_assertions)]
74    {
75        tape.forward(&inputs);
76        let residual = tape.output_values();
77        let norm_sq: F = residual.iter().fold(F::zero(), |acc, &v| acc + v * v);
78        let norm = norm_sq.sqrt();
79        let threshold = F::from(1e-6).unwrap_or_else(|| F::epsilon());
80        if norm > threshold {
81            eprintln!(
82                "WARNING: implicit differentiation called with ||F(z*, x)|| = {:?} > 1e-6. \
83                 Derivatives may be meaningless if z* is not a root.",
84                norm.to_f64()
85            );
86        }
87    }
88
89    let jac = tape.jacobian(&inputs);
90    partition_jacobian(&jac, num_states)
91}
92
93/// Compute the full implicit Jacobian `dz*/dx` (m × n matrix).
94///
95/// Given a multi-output residual tape `F: R^(m+n) → R^m` with `F(z*, x) = 0`,
96/// computes `dz*/dx = -F_z^{-1} · F_x` via the Implicit Function Theorem.
97///
98/// The first `num_states` tape inputs are state variables `z`, the remaining are
99/// parameters `x`.
100///
101/// Returns `None` if `F_z` is singular.
102pub fn implicit_jacobian<F: Float>(
103    tape: &mut BytecodeTape<F>,
104    z_star: &[F],
105    x: &[F],
106    num_states: usize,
107) -> Option<Vec<Vec<F>>> {
108    validate_inputs(tape, z_star, x, num_states);
109    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
110
111    let m = num_states;
112    let n = x.len();
113
114    // LU-factorize F_z once, then solve for each column of -F_x
115    let factors = lu_factor(&f_z)?;
116
117    // Build result column by column: solve F_z · col_j = -F_x[:, j]
118    let mut result = vec![vec![F::zero(); n]; m];
119    for j in 0..n {
120        let neg_col: Vec<F> = (0..m).map(|i| F::zero() - f_x[i][j]).collect();
121        let col = lu_back_solve(&factors, &neg_col);
122        for i in 0..m {
123            result[i][j] = col[i];
124        }
125    }
126
127    Some(result)
128}
129
130/// Compute the implicit tangent `dz*/dx · x_dot` (m-vector).
131///
132/// Given a multi-output residual tape `F: R^(m+n) → R^m` with `F(z*, x) = 0`,
133/// computes the directional derivative `dz*/dx · x_dot = -F_z^{-1} · (F_x · x_dot)`.
134///
135/// This solves a single linear system rather than computing the full Jacobian,
136/// which is more efficient when only one direction is needed.
137///
138/// Returns `None` if `F_z` is singular.
139pub fn implicit_tangent<F: Float>(
140    tape: &mut BytecodeTape<F>,
141    z_star: &[F],
142    x: &[F],
143    x_dot: &[F],
144    num_states: usize,
145) -> Option<Vec<F>> {
146    assert_eq!(
147        x_dot.len(),
148        x.len(),
149        "x_dot length ({}) must equal x length ({})",
150        x_dot.len(),
151        x.len()
152    );
153    validate_inputs(tape, z_star, x, num_states);
154    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
155
156    let m = num_states;
157    let n = x.len();
158
159    // Compute F_x · x_dot (matrix-vector product)
160    let mut fx_xdot = vec![F::zero(); m];
161    for i in 0..m {
162        for j in 0..n {
163            fx_xdot[i] = fx_xdot[i] + f_x[i][j] * x_dot[j];
164        }
165    }
166
167    // Negate: rhs = -(F_x · x_dot)
168    let neg_fx_xdot: Vec<F> = fx_xdot.iter().map(|&v| F::zero() - v).collect();
169
170    // Solve F_z · z_dot = -(F_x · x_dot)
171    lu_solve(&f_z, &neg_fx_xdot)
172}
173
174/// Compute the implicit adjoint `(dz*/dx)^T · z_bar` (n-vector).
175///
176/// Given a multi-output residual tape `F: R^(m+n) → R^m` with `F(z*, x) = 0`,
177/// computes `x_bar = -F_x^T · (F_z^{-T} · z_bar)`.
178///
179/// This is the reverse-mode (adjoint) form, useful when `n > m` or when
180/// propagating gradients backward through an implicit layer.
181///
182/// Returns `None` if `F_z` is singular.
183pub fn implicit_adjoint<F: Float>(
184    tape: &mut BytecodeTape<F>,
185    z_star: &[F],
186    x: &[F],
187    z_bar: &[F],
188    num_states: usize,
189) -> Option<Vec<F>> {
190    assert_eq!(
191        z_bar.len(),
192        num_states,
193        "z_bar length ({}) must equal num_states ({})",
194        z_bar.len(),
195        num_states
196    );
197    validate_inputs(tape, z_star, x, num_states);
198    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
199
200    let m = num_states;
201    let n = x.len();
202
203    // Solve F_z^T · lambda = z_bar
204    let f_z_t = transpose(&f_z);
205    let lambda = lu_solve(&f_z_t, z_bar)?;
206
207    // Compute x_bar = -F_x^T · lambda
208    let f_x_t = transpose(&f_x);
209    let mut x_bar = vec![F::zero(); n];
210    for j in 0..n {
211        for i in 0..m {
212            x_bar[j] = x_bar[j] - f_x_t[j][i] * lambda[i];
213        }
214    }
215
216    Some(x_bar)
217}
218
219/// Compute the implicit Hessian-vector-vector product `d²z*/dx² · v · w` (m-vector).
220///
221/// Given a residual tape `F: R^(m+n) → R^m` with `F(z*, x) = 0`, computes the
222/// second-order sensitivity by differentiating the IFT identity twice:
223///
224///   `F_z · h + [ṗ^T · Hess(F_i) · ẇ]_i = 0`
225///
226/// where `ṗ = [dz*/dx · v; v]`, `ẇ = [dz*/dx · w; w]`, and `h = d²z*/dx² · v · w`.
227///
228/// Uses nested `Dual<Dual<F>>` forward passes to compute the second-order correction
229/// in a single O(tape_length) pass per direction pair.
230///
231/// Returns `None` if `F_z` is singular.
232pub fn implicit_hvp<F: Float>(
233    tape: &mut BytecodeTape<F>,
234    z_star: &[F],
235    x: &[F],
236    v: &[F],
237    w: &[F],
238    num_states: usize,
239) -> Option<Vec<F>> {
240    let n = x.len();
241    let m = num_states;
242    assert_eq!(
243        v.len(),
244        n,
245        "v length ({}) must equal x length ({})",
246        v.len(),
247        n
248    );
249    assert_eq!(
250        w.len(),
251        n,
252        "w length ({}) must equal x length ({})",
253        w.len(),
254        n
255    );
256    validate_inputs(tape, z_star, x, num_states);
257
258    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
259    let factors = lu_factor(&f_z)?;
260
261    // First-order sensitivities: ż_v = -F_z^{-1} · (F_x · v)
262    let mut fx_v = vec![F::zero(); m];
263    let mut fx_w = vec![F::zero(); m];
264    for i in 0..m {
265        for j in 0..n {
266            fx_v[i] = fx_v[i] + f_x[i][j] * v[j];
267            fx_w[i] = fx_w[i] + f_x[i][j] * w[j];
268        }
269    }
270    let neg_fx_v: Vec<F> = fx_v.iter().map(|&val| F::zero() - val).collect();
271    let neg_fx_w: Vec<F> = fx_w.iter().map(|&val| F::zero() - val).collect();
272    let z_dot_v = lu_back_solve(&factors, &neg_fx_v);
273    let z_dot_w = lu_back_solve(&factors, &neg_fx_w);
274
275    // Build Dual<Dual<F>> inputs for nested forward pass
276    // ṗ = [ż_v; v], ẇ = [ż_w; w]
277    // Input j: Dual { re: Dual(u_j, ṗ_j), eps: Dual(ẇ_j, 0) }
278    let mut dd_inputs: Vec<Dual<Dual<F>>> = Vec::with_capacity(m + n);
279    for i in 0..m {
280        dd_inputs.push(Dual::new(
281            Dual::new(z_star[i], z_dot_v[i]),
282            Dual::new(z_dot_w[i], F::zero()),
283        ));
284    }
285    for j in 0..n {
286        dd_inputs.push(Dual::new(Dual::new(x[j], v[j]), Dual::new(w[j], F::zero())));
287    }
288
289    let mut buf = Vec::new();
290    tape.forward_tangent(&dd_inputs, &mut buf);
291
292    // Extract second-order correction: buf[out_idx].eps.eps for each output
293    let out_indices = tape.all_output_indices();
294    let mut rhs = Vec::with_capacity(m);
295    for &idx in out_indices {
296        rhs.push(buf[idx as usize].eps.eps);
297    }
298
299    // Solve F_z · h = -rhs
300    let neg_rhs: Vec<F> = rhs.iter().map(|&val| F::zero() - val).collect();
301    let h = lu_back_solve(&factors, &neg_rhs);
302
303    Some(h)
304}
305
306/// Compute the full implicit Hessian tensor `d²z*/dx²` (m × n × n).
307///
308/// Returns `result[i][j][k]` = ∂²z*_i / (∂x_j ∂x_k). The tensor is symmetric
309/// in the last two indices (j, k).
310///
311/// Cost: `n(n+1)/2` nested `Dual<Dual<F>>` forward passes plus `n(n+1)/2` back-solves,
312/// all sharing a single LU factorization of `F_z`.
313///
314/// Returns `None` if `F_z` is singular.
315pub fn implicit_hessian<F: Float>(
316    tape: &mut BytecodeTape<F>,
317    z_star: &[F],
318    x: &[F],
319    num_states: usize,
320) -> Option<Vec<Vec<Vec<F>>>> {
321    let n = x.len();
322    let m = num_states;
323    validate_inputs(tape, z_star, x, num_states);
324
325    let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
326    let factors = lu_factor(&f_z)?;
327
328    // First-order sensitivity columns: S[:,j] = -F_z^{-1} · F_x[:,j]
329    let mut sens_cols: Vec<Vec<F>> = Vec::with_capacity(n);
330    for j in 0..n {
331        let neg_col: Vec<F> = f_x.iter().map(|row| F::zero() - row[j]).collect();
332        sens_cols.push(lu_back_solve(&factors, &neg_col));
333    }
334
335    let out_indices = tape.all_output_indices();
336    let mut result = vec![vec![vec![F::zero(); n]; n]; m];
337    let mut buf: Vec<Dual<Dual<F>>> = Vec::new();
338
339    for j in 0..n {
340        for k in j..n {
341            // ṗ = [S[:,j]; e_j], ẇ = [S[:,k]; e_k]
342            let mut dd_inputs: Vec<Dual<Dual<F>>> = Vec::with_capacity(m + n);
343            for i in 0..m {
344                dd_inputs.push(Dual::new(
345                    Dual::new(z_star[i], sens_cols[j][i]),
346                    Dual::new(sens_cols[k][i], F::zero()),
347                ));
348            }
349            for (l, &x_l) in x.iter().enumerate() {
350                let p_l = if l == j { F::one() } else { F::zero() };
351                let w_l = if l == k { F::one() } else { F::zero() };
352                dd_inputs.push(Dual::new(Dual::new(x_l, p_l), Dual::new(w_l, F::zero())));
353            }
354
355            tape.forward_tangent(&dd_inputs, &mut buf);
356
357            // Extract RHS and solve
358            let mut rhs = Vec::with_capacity(m);
359            for &idx in out_indices {
360                rhs.push(buf[idx as usize].eps.eps);
361            }
362            let neg_rhs: Vec<F> = rhs.iter().map(|&val| F::zero() - val).collect();
363            let h = lu_back_solve(&factors, &neg_rhs);
364
365            for i in 0..m {
366                result[i][j][k] = h[i];
367                result[i][k][j] = h[i]; // Symmetric
368            }
369        }
370    }
371
372    Some(result)
373}