1use std::fmt;
2
3use echidna::{BytecodeTape, Dual, Float};
4
5use crate::linalg::{lu_back_solve, lu_factor, lu_solve};
6
/// Failure modes of the implicit-differentiation routines in this module.
///
/// Comparable by value (`PartialEq`/`Eq`), so callers and tests can match on an
/// exact error, including the reported lengths.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImplicitError {
    /// The state Jacobian `F_z` could not be factored, or a linear solve against it
    /// produced a non-finite (NaN / ±inf) value.
    Singular,
    /// A caller-supplied slice had the wrong length.
    DimensionMismatch {
        /// Name of the offending argument (e.g. `"x_dot"`, `"z_bar"`, `"v"`, `"w"`).
        field: &'static str,
        /// Length the argument was required to have.
        expected: usize,
        /// Length the argument actually had.
        actual: usize,
    },
}
48
49impl fmt::Display for ImplicitError {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 match self {
52 ImplicitError::Singular => {
53 write!(
54 f,
55 "implicit: F_z is singular, ill-conditioned, or produced a non-finite solve"
56 )
57 }
58 ImplicitError::DimensionMismatch {
59 field,
60 expected,
61 actual,
62 } => {
63 write!(
64 f,
65 "implicit: dimension mismatch for `{field}` (expected {expected}, got {actual})"
66 )
67 }
68 }
69 }
70}
71
// `ImplicitError` already provides `Debug` (derived) and `Display`, which is all
// `std::error::Error` requires. No variant wraps an underlying error, so the default
// `source()` (which returns `None`) is appropriate.
impl std::error::Error for ImplicitError {}

// NOTE(review): presumably a compile-time assertion from `echidna` that
// `ImplicitError` is `Send + Sync`, so it can be carried across threads
// (e.g. in `Box<dyn Error + Send + Sync>`). Confirm against the macro's definition.
echidna::assert_send_sync!(ImplicitError);
75
76fn partition_jacobian<F: Float>(jac: &[Vec<F>], num_states: usize) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
80 let m = num_states;
81 let mut f_z = Vec::with_capacity(m);
82 let mut f_x = Vec::with_capacity(m);
83 for row in jac {
84 f_z.push(row[..m].to_vec());
85 f_x.push(row[m..].to_vec());
86 }
87 (f_z, f_x)
88}
89
90fn transpose<F: Float>(mat: &[Vec<F>]) -> Vec<Vec<F>> {
92 if mat.is_empty() {
93 return vec![];
94 }
95 let rows = mat.len();
96 let cols = mat[0].len();
97 let mut result = vec![vec![F::zero(); rows]; cols];
98 for i in 0..rows {
99 for j in 0..cols {
100 result[j][i] = mat[i][j];
101 }
102 }
103 result
104}
105
106fn validate_inputs<F: Float>(tape: &BytecodeTape<F>, z_star: &[F], x: &[F], num_states: usize) {
108 assert_eq!(
109 z_star.len(),
110 num_states,
111 "z_star length ({}) must equal num_states ({})",
112 z_star.len(),
113 num_states
114 );
115 assert_eq!(
116 tape.num_inputs(),
117 num_states + x.len(),
118 "tape.num_inputs() ({}) must equal num_states + x.len() ({})",
119 tape.num_inputs(),
120 num_states + x.len()
121 );
122 assert_eq!(
123 tape.num_outputs(),
124 num_states,
125 "tape.num_outputs() ({}) must equal num_states ({}) — IFT requires F: R^(m+n) → R^m to be square in the state block",
126 tape.num_outputs(),
127 num_states
128 );
129}
130
131fn compute_partitioned_jacobian<F: Float>(
134 tape: &mut BytecodeTape<F>,
135 z_star: &[F],
136 x: &[F],
137 num_states: usize,
138) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
139 let mut inputs = Vec::with_capacity(z_star.len() + x.len());
140 inputs.extend_from_slice(z_star);
141 inputs.extend_from_slice(x);
142
143 #[cfg(debug_assertions)]
145 {
146 tape.forward(&inputs);
147 let residual = tape.output_values();
148 let norm_sq: F = residual.iter().fold(F::zero(), |acc, &v| acc + v * v);
149 let norm = norm_sq.sqrt();
150 let threshold = F::from(1e-6).unwrap_or_else(|| F::epsilon());
151 if norm > threshold {
152 eprintln!(
153 "WARNING: implicit differentiation called with ||F(z*, x)|| = {:?} > 1e-6. \
154 Derivatives may be meaningless if z* is not a root.",
155 norm.to_f64()
156 );
157 }
158 }
159
160 let jac = tape.jacobian(&inputs);
161 partition_jacobian(&jac, num_states)
162}
163
164pub fn implicit_jacobian<F: Float>(
174 tape: &mut BytecodeTape<F>,
175 z_star: &[F],
176 x: &[F],
177 num_states: usize,
178) -> Result<Vec<Vec<F>>, ImplicitError> {
179 validate_inputs(tape, z_star, x, num_states);
180 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
181
182 let m = num_states;
183 let n = x.len();
184
185 let factors = lu_factor(&f_z).ok_or(ImplicitError::Singular)?;
187
188 let mut result = vec![vec![F::zero(); n]; m];
190 for j in 0..n {
191 let neg_col: Vec<F> = (0..m).map(|i| F::zero() - f_x[i][j]).collect();
192 let col = lu_back_solve(&factors, &neg_col);
193
194 if col.iter().any(|v| !v.is_finite()) {
201 return Err(ImplicitError::Singular);
202 }
203
204 for i in 0..m {
205 result[i][j] = col[i];
206 }
207 }
208
209 Ok(result)
210}
211
212pub fn implicit_tangent<F: Float>(
222 tape: &mut BytecodeTape<F>,
223 z_star: &[F],
224 x: &[F],
225 x_dot: &[F],
226 num_states: usize,
227) -> Result<Vec<F>, ImplicitError> {
228 if x_dot.len() != x.len() {
229 return Err(ImplicitError::DimensionMismatch {
230 field: "x_dot",
231 expected: x.len(),
232 actual: x_dot.len(),
233 });
234 }
235 validate_inputs(tape, z_star, x, num_states);
236 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
237
238 let m = num_states;
239 let n = x.len();
240
241 let mut fx_xdot = vec![F::zero(); m];
243 for i in 0..m {
244 for j in 0..n {
245 fx_xdot[i] = fx_xdot[i] + f_x[i][j] * x_dot[j];
246 }
247 }
248
249 let neg_fx_xdot: Vec<F> = fx_xdot.iter().map(|&v| F::zero() - v).collect();
251
252 let sol = lu_solve(&f_z, &neg_fx_xdot).ok_or(ImplicitError::Singular)?;
254
255 if sol.iter().any(|v| !v.is_finite()) {
262 return Err(ImplicitError::Singular);
263 }
264
265 Ok(sol)
266}
267
268pub fn implicit_adjoint<F: Float>(
278 tape: &mut BytecodeTape<F>,
279 z_star: &[F],
280 x: &[F],
281 z_bar: &[F],
282 num_states: usize,
283) -> Result<Vec<F>, ImplicitError> {
284 if z_bar.len() != num_states {
285 return Err(ImplicitError::DimensionMismatch {
286 field: "z_bar",
287 expected: num_states,
288 actual: z_bar.len(),
289 });
290 }
291 validate_inputs(tape, z_star, x, num_states);
292 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
293
294 let m = num_states;
295 let n = x.len();
296
297 let f_z_t = transpose(&f_z);
299 let lambda = lu_solve(&f_z_t, z_bar).ok_or(ImplicitError::Singular)?;
300
301 let f_x_t = transpose(&f_x);
303 let mut x_bar = vec![F::zero(); n];
304 for j in 0..n {
305 for i in 0..m {
306 x_bar[j] = x_bar[j] - f_x_t[j][i] * lambda[i];
307 }
308 }
309
310 if x_bar.iter().any(|v| !v.is_finite()) {
315 return Err(ImplicitError::Singular);
316 }
317
318 Ok(x_bar)
319}
320
321pub fn implicit_hvp<F: Float>(
335 tape: &mut BytecodeTape<F>,
336 z_star: &[F],
337 x: &[F],
338 v: &[F],
339 w: &[F],
340 num_states: usize,
341) -> Result<Vec<F>, ImplicitError> {
342 let n = x.len();
343 let m = num_states;
344 if v.len() != n {
345 return Err(ImplicitError::DimensionMismatch {
346 field: "v",
347 expected: n,
348 actual: v.len(),
349 });
350 }
351 if w.len() != n {
352 return Err(ImplicitError::DimensionMismatch {
353 field: "w",
354 expected: n,
355 actual: w.len(),
356 });
357 }
358 validate_inputs(tape, z_star, x, num_states);
359
360 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
361 let factors = lu_factor(&f_z).ok_or(ImplicitError::Singular)?;
362
363 let mut fx_v = vec![F::zero(); m];
365 let mut fx_w = vec![F::zero(); m];
366 for i in 0..m {
367 for j in 0..n {
368 fx_v[i] = fx_v[i] + f_x[i][j] * v[j];
369 fx_w[i] = fx_w[i] + f_x[i][j] * w[j];
370 }
371 }
372 let neg_fx_v: Vec<F> = fx_v.iter().map(|&val| F::zero() - val).collect();
373 let neg_fx_w: Vec<F> = fx_w.iter().map(|&val| F::zero() - val).collect();
374 let z_dot_v = lu_back_solve(&factors, &neg_fx_v);
375 let z_dot_w = lu_back_solve(&factors, &neg_fx_w);
376
377 let mut dd_inputs: Vec<Dual<Dual<F>>> = Vec::with_capacity(m + n);
381 for i in 0..m {
382 dd_inputs.push(Dual::new(
383 Dual::new(z_star[i], z_dot_v[i]),
384 Dual::new(z_dot_w[i], F::zero()),
385 ));
386 }
387 for j in 0..n {
388 dd_inputs.push(Dual::new(Dual::new(x[j], v[j]), Dual::new(w[j], F::zero())));
389 }
390
391 let mut buf = Vec::new();
392 tape.forward_tangent(&dd_inputs, &mut buf);
393
394 let out_indices = tape.all_output_indices();
396 let mut rhs = Vec::with_capacity(m);
397 for &idx in out_indices {
398 rhs.push(buf[idx as usize].eps.eps);
399 }
400
401 let neg_rhs: Vec<F> = rhs.iter().map(|&val| F::zero() - val).collect();
403 let h = lu_back_solve(&factors, &neg_rhs);
404
405 if h.iter().any(|v| !v.is_finite()) {
413 return Err(ImplicitError::Singular);
414 }
415
416 Ok(h)
417}
418
419pub fn implicit_hessian<F: Float>(
429 tape: &mut BytecodeTape<F>,
430 z_star: &[F],
431 x: &[F],
432 num_states: usize,
433) -> Result<Vec<Vec<Vec<F>>>, ImplicitError> {
434 let n = x.len();
435 let m = num_states;
436 validate_inputs(tape, z_star, x, num_states);
437
438 let (f_z, f_x) = compute_partitioned_jacobian(tape, z_star, x, num_states);
439 let factors = lu_factor(&f_z).ok_or(ImplicitError::Singular)?;
440
441 let mut sens_cols: Vec<Vec<F>> = Vec::with_capacity(n);
443 for j in 0..n {
444 let neg_col: Vec<F> = f_x.iter().map(|row| F::zero() - row[j]).collect();
445 sens_cols.push(lu_back_solve(&factors, &neg_col));
446 }
447
448 let out_indices = tape.all_output_indices();
449 let mut result = vec![vec![vec![F::zero(); n]; n]; m];
450 let mut buf: Vec<Dual<Dual<F>>> = Vec::new();
451
452 for j in 0..n {
453 for k in j..n {
454 let mut dd_inputs: Vec<Dual<Dual<F>>> = Vec::with_capacity(m + n);
456 for i in 0..m {
457 dd_inputs.push(Dual::new(
458 Dual::new(z_star[i], sens_cols[j][i]),
459 Dual::new(sens_cols[k][i], F::zero()),
460 ));
461 }
462 for (l, &x_l) in x.iter().enumerate() {
463 let p_l = if l == j { F::one() } else { F::zero() };
464 let w_l = if l == k { F::one() } else { F::zero() };
465 dd_inputs.push(Dual::new(Dual::new(x_l, p_l), Dual::new(w_l, F::zero())));
466 }
467
468 tape.forward_tangent(&dd_inputs, &mut buf);
469
470 let mut rhs = Vec::with_capacity(m);
472 for &idx in out_indices {
473 rhs.push(buf[idx as usize].eps.eps);
474 }
475 let neg_rhs: Vec<F> = rhs.iter().map(|&val| F::zero() - val).collect();
476 let h = lu_back_solve(&factors, &neg_rhs);
477
478 if h.iter().any(|v| !v.is_finite()) {
483 return Err(ImplicitError::Singular);
484 }
485
486 for i in 0..m {
487 result[i][j][k] = h[i];
488 result[i][k][j] = h[i]; }
490 }
491 }
492
493 Ok(result)
494}