1use echidna::{BytecodeTape, Dual, Float};
2
3use crate::linalg::{lu_back_solve, lu_factor, lu_solve};
4
5fn partition_jacobian<F: Float>(jac: &[Vec<F>], num_states: usize) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
9 let m = num_states;
10 let mut f_z = Vec::with_capacity(m);
11 let mut f_x = Vec::with_capacity(m);
12 for row in jac {
13 f_z.push(row[..m].to_vec());
14 f_x.push(row[m..].to_vec());
15 }
16 (f_z, f_x)
17}
18
19fn transpose<F: Float>(mat: &[Vec<F>]) -> Vec<Vec<F>> {
21 if mat.is_empty() {
22 return vec![];
23 }
24 let rows = mat.len();
25 let cols = mat[0].len();
26 let mut result = vec![vec![F::zero(); rows]; cols];
27 for i in 0..rows {
28 for j in 0..cols {
29 result[j][i] = mat[i][j];
30 }
31 }
32 result
33}
34
35fn validate_inputs<F: Float>(tape: &BytecodeTape<F>, z_star: &[F], x: &[F], num_states: usize) {
37 assert_eq!(
38 z_star.len(),
39 num_states,
40 "z_star length ({}) must equal num_states ({})",
41 z_star.len(),
42 num_states
43 );
44 assert_eq!(
45 tape.num_inputs(),
46 num_states + x.len(),
47 "tape.num_inputs() ({}) must equal num_states + x.len() ({})",
48 tape.num_inputs(),
49 num_states + x.len()
50 );
51 assert_eq!(
52 tape.num_outputs(),
53 num_states,
54 "tape.num_outputs() ({}) must equal num_states ({}) — IFT requires F: R^(m+n) → R^m to be square in the state block",
55 tape.num_outputs(),
56 num_states
57 );
58}
59
60fn compute_partitioned_jacobian<F: Float>(
63 tape: &mut BytecodeTape<F>,
64 z_star: &[F],
65 x: &[F],
66 num_states: usize,
67) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
68 let mut inputs = Vec::with_capacity(z_star.len() + x.len());
69 inputs.extend_from_slice(z_star);
70 inputs.extend_from_slice(x);
71
72 #[cfg(debug_assertions)]
74 {
75 tape.forward(&inputs);
76 let residual = tape.output_values();
77 let norm_sq: F = residual.iter().fold(F::zero(), |acc, &v| acc + v * v);
78 let norm = norm_sq.sqrt();
79 let threshold = F::from(1e-6).unwrap_or_else(|| F::epsilon());
80 if norm > threshold {
81 eprintln!(
82 "WARNING: implicit differentiation called with ||F(z*, x)|| = {:?} > 1e-6. \
83 Derivatives may be meaningless if z* is not a root.",
84 norm.to_f64()
85 );
86 }
87 }
88
89 let jac = tape.jacobian(&inputs);
90 partition_jacobian(&jac, num_states)
91}
92
93pub fn implicit_jacobian<F: Float>(
103 tape: &mut BytecodeTape<F>,
104 z_star: &[F],
105 x: &[F],
106 num_states: usize,
107) -> Option<Vec<Vec<F>>> {
108 validate_inputs(tape, z_star, x, num_states);
109 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
110
111 let m = num_states;
112 let n = x.len();
113
114 let factors = lu_factor(&f_z)?;
116
117 let mut result = vec![vec![F::zero(); n]; m];
119 for j in 0..n {
120 let neg_col: Vec<F> = (0..m).map(|i| F::zero() - f_x[i][j]).collect();
121 let col = lu_back_solve(&factors, &neg_col);
122 for i in 0..m {
123 result[i][j] = col[i];
124 }
125 }
126
127 Some(result)
128}
129
130pub fn implicit_tangent<F: Float>(
140 tape: &mut BytecodeTape<F>,
141 z_star: &[F],
142 x: &[F],
143 x_dot: &[F],
144 num_states: usize,
145) -> Option<Vec<F>> {
146 assert_eq!(
147 x_dot.len(),
148 x.len(),
149 "x_dot length ({}) must equal x length ({})",
150 x_dot.len(),
151 x.len()
152 );
153 validate_inputs(tape, z_star, x, num_states);
154 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
155
156 let m = num_states;
157 let n = x.len();
158
159 let mut fx_xdot = vec![F::zero(); m];
161 for i in 0..m {
162 for j in 0..n {
163 fx_xdot[i] = fx_xdot[i] + f_x[i][j] * x_dot[j];
164 }
165 }
166
167 let neg_fx_xdot: Vec<F> = fx_xdot.iter().map(|&v| F::zero() - v).collect();
169
170 lu_solve(&f_z, &neg_fx_xdot)
172}
173
174pub fn implicit_adjoint<F: Float>(
184 tape: &mut BytecodeTape<F>,
185 z_star: &[F],
186 x: &[F],
187 z_bar: &[F],
188 num_states: usize,
189) -> Option<Vec<F>> {
190 assert_eq!(
191 z_bar.len(),
192 num_states,
193 "z_bar length ({}) must equal num_states ({})",
194 z_bar.len(),
195 num_states
196 );
197 validate_inputs(tape, z_star, x, num_states);
198 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
199
200 let m = num_states;
201 let n = x.len();
202
203 let f_z_t = transpose(&f_z);
205 let lambda = lu_solve(&f_z_t, z_bar)?;
206
207 let f_x_t = transpose(&f_x);
209 let mut x_bar = vec![F::zero(); n];
210 for j in 0..n {
211 for i in 0..m {
212 x_bar[j] = x_bar[j] - f_x_t[j][i] * lambda[i];
213 }
214 }
215
216 Some(x_bar)
217}
218
219pub fn implicit_hvp<F: Float>(
233 tape: &mut BytecodeTape<F>,
234 z_star: &[F],
235 x: &[F],
236 v: &[F],
237 w: &[F],
238 num_states: usize,
239) -> Option<Vec<F>> {
240 let n = x.len();
241 let m = num_states;
242 assert_eq!(
243 v.len(),
244 n,
245 "v length ({}) must equal x length ({})",
246 v.len(),
247 n
248 );
249 assert_eq!(
250 w.len(),
251 n,
252 "w length ({}) must equal x length ({})",
253 w.len(),
254 n
255 );
256 validate_inputs(tape, z_star, x, num_states);
257
258 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
259 let factors = lu_factor(&f_z)?;
260
261 let mut fx_v = vec![F::zero(); m];
263 let mut fx_w = vec![F::zero(); m];
264 for i in 0..m {
265 for j in 0..n {
266 fx_v[i] = fx_v[i] + f_x[i][j] * v[j];
267 fx_w[i] = fx_w[i] + f_x[i][j] * w[j];
268 }
269 }
270 let neg_fx_v: Vec<F> = fx_v.iter().map(|&val| F::zero() - val).collect();
271 let neg_fx_w: Vec<F> = fx_w.iter().map(|&val| F::zero() - val).collect();
272 let z_dot_v = lu_back_solve(&factors, &neg_fx_v);
273 let z_dot_w = lu_back_solve(&factors, &neg_fx_w);
274
275 let mut dd_inputs: Vec<Dual<Dual<F>>> = Vec::with_capacity(m + n);
279 for i in 0..m {
280 dd_inputs.push(Dual::new(
281 Dual::new(z_star[i], z_dot_v[i]),
282 Dual::new(z_dot_w[i], F::zero()),
283 ));
284 }
285 for j in 0..n {
286 dd_inputs.push(Dual::new(Dual::new(x[j], v[j]), Dual::new(w[j], F::zero())));
287 }
288
289 let mut buf = Vec::new();
290 tape.forward_tangent(&dd_inputs, &mut buf);
291
292 let out_indices = tape.all_output_indices();
294 let mut rhs = Vec::with_capacity(m);
295 for &idx in out_indices {
296 rhs.push(buf[idx as usize].eps.eps);
297 }
298
299 let neg_rhs: Vec<F> = rhs.iter().map(|&val| F::zero() - val).collect();
301 let h = lu_back_solve(&factors, &neg_rhs);
302
303 Some(h)
304}
305
306pub fn implicit_hessian<F: Float>(
316 tape: &mut BytecodeTape<F>,
317 z_star: &[F],
318 x: &[F],
319 num_states: usize,
320) -> Option<Vec<Vec<Vec<F>>>> {
321 let n = x.len();
322 let m = num_states;
323 validate_inputs(tape, z_star, x, num_states);
324
325 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
326 let factors = lu_factor(&f_z)?;
327
328 let mut sens_cols: Vec<Vec<F>> = Vec::with_capacity(n);
330 for j in 0..n {
331 let neg_col: Vec<F> = f_x.iter().map(|row| F::zero() - row[j]).collect();
332 sens_cols.push(lu_back_solve(&factors, &neg_col));
333 }
334
335 let out_indices = tape.all_output_indices();
336 let mut result = vec![vec![vec![F::zero(); n]; n]; m];
337 let mut buf: Vec<Dual<Dual<F>>> = Vec::new();
338
339 for j in 0..n {
340 for k in j..n {
341 let mut dd_inputs: Vec<Dual<Dual<F>>> = Vec::with_capacity(m + n);
343 for i in 0..m {
344 dd_inputs.push(Dual::new(
345 Dual::new(z_star[i], sens_cols[j][i]),
346 Dual::new(sens_cols[k][i], F::zero()),
347 ));
348 }
349 for (l, &x_l) in x.iter().enumerate() {
350 let p_l = if l == j { F::one() } else { F::zero() };
351 let w_l = if l == k { F::one() } else { F::zero() };
352 dd_inputs.push(Dual::new(Dual::new(x_l, p_l), Dual::new(w_l, F::zero())));
353 }
354
355 tape.forward_tangent(&dd_inputs, &mut buf);
356
357 let mut rhs = Vec::with_capacity(m);
359 for &idx in out_indices {
360 rhs.push(buf[idx as usize].eps.eps);
361 }
362 let neg_rhs: Vec<F> = rhs.iter().map(|&val| F::zero() - val).collect();
363 let h = lu_back_solve(&factors, &neg_rhs);
364
365 for i in 0..m {
366 result[i][j][k] = h[i];
367 result[i][k][j] = h[i]; }
369 }
370 }
371
372 Some(result)
373}