// echidna_optim/piggyback.rs
1use std::fmt;
2
3use echidna::{BytecodeTape, Dual, Float};
4
/// Reason a piggyback solve failed to converge.
///
/// Marked `#[non_exhaustive]` so future variants can be added without
/// breaking exhaustive `match`es. Numeric fields use `f64` (cast via
/// `Float::to_f64`) for uniform diagnostic output regardless of the
/// solver's `F` type. Rendered by the `Display` impl below as one-line
/// `piggyback: ...` diagnostics.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum PiggybackError {
    /// The primal `z_{k+1} = G(z_k, x)` produced a non-finite norm
    /// (relative-norm `||z_new - z||/(1 + ||z||)` is NaN/Inf), or
    /// the primal vector itself contained non-finite components in
    /// the forward-adjoint loop. `last_norm` is the primal-delta
    /// relative norm at the detecting iteration: non-finite when
    /// detection came from the norm check (the usual case);
    /// finite — and itself diagnostic — when detection came from
    /// the componentwise finite check (primal vector overflowed
    /// mid-iteration while the step-to-step delta stayed bounded).
    PrimalDivergence { iteration: usize, last_norm: f64 },
    /// Primal stayed finite but the tangent
    /// `ż_{k+1} = G_z · ż_k + G_x · ẋ` produced non-finite values.
    /// Catches the ratio-converging case where the primal norm
    /// remains bounded while individual tangent components overflow.
    /// `last_norm` is the primal-delta relative norm at the
    /// detecting iteration — **finite** by construction here (the
    /// tangent-only divergence path takes the norm-finite branch
    /// before the componentwise check fires); surfacing it tells
    /// the caller the primal iteration was bounded while the JVP
    /// overflowed.
    TangentDivergence { iteration: usize, last_norm: f64 },
    /// Adjoint `λ_{k+1} = G_z^T · λ_k + z̄` produced non-finite
    /// values (norm or individual components). `last_norm` is the
    /// adjoint-delta relative norm at the detecting iteration:
    /// non-finite when detection came from the norm check; finite
    /// when it came from the componentwise `lambda_new` check.
    AdjointDivergence { iteration: usize, last_norm: f64 },
    /// `piggyback_tangent_solve` reached `max_iter` without meeting
    /// `tol`. `z_norm` is the final iteration's relative primal-delta
    /// norm (`||z_new - z|| / (1 + ||z||)`) — a value just over `tol`
    /// signals proximity to convergence; many orders over signals
    /// stagnation. `iteration` equals `max_iter`. With `max_iter == 0`
    /// the norm field is NaN (no iteration ever ran).
    IterationsExhaustedTangent { iteration: usize, z_norm: f64 },
    /// `piggyback_adjoint_solve` reached `max_iter` without meeting
    /// `tol`. `lam_norm` is the final iteration's relative adjoint-
    /// delta norm (`||λ_new - λ|| / (1 + ||λ||)`).
    IterationsExhaustedAdjoint { iteration: usize, lam_norm: f64 },
    /// `piggyback_forward_adjoint_solve` reached `max_iter` without
    /// meeting `tol` on both norms simultaneously. Each field is the
    /// final iteration's relative norm for the corresponding stream.
    IterationsExhaustedForwardAdjoint {
        iteration: usize,
        z_norm: f64,
        lam_norm: f64,
    },
    /// A runtime-supplied vector argument to a public `*_solve` fn
    /// had an unexpected length. `field` names the argument (e.g.
    /// `"z_dot"`, `"z_bar"`), `expected` is the length the API
    /// requires (typically `num_states` or `x.len()`), `actual` is
    /// the length the caller supplied.
    ///
    /// Note: tape-shape contract mismatches
    /// (`validate_step_tape`) and step-fn argument mismatches
    /// (`piggyback_tangent_step[_with_buf]`) continue to panic —
    /// those are programmer-contract violations, not recoverable
    /// runtime failures.
    DimensionMismatch {
        field: &'static str,
        expected: usize,
        actual: usize,
    },
}
76
77impl fmt::Display for PiggybackError {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 match self {
80 PiggybackError::PrimalDivergence {
81 iteration,
82 last_norm,
83 } => {
84 write!(
85 f,
86 "piggyback: primal diverged at iteration {iteration} (last_norm = {last_norm:.3e})"
87 )
88 }
89 PiggybackError::TangentDivergence {
90 iteration,
91 last_norm,
92 } => {
93 write!(
94 f,
95 "piggyback: tangent diverged at iteration {iteration} (last_norm = {last_norm:.3e})"
96 )
97 }
98 PiggybackError::AdjointDivergence {
99 iteration,
100 last_norm,
101 } => {
102 write!(
103 f,
104 "piggyback: adjoint diverged at iteration {iteration} (last_norm = {last_norm:.3e})"
105 )
106 }
107 PiggybackError::IterationsExhaustedTangent { iteration, z_norm } => {
108 write!(
109 f,
110 "piggyback: tangent solve reached max_iter = {iteration} (z_norm = {z_norm:.3e})"
111 )
112 }
113 PiggybackError::IterationsExhaustedAdjoint {
114 iteration,
115 lam_norm,
116 } => {
117 write!(
118 f,
119 "piggyback: adjoint solve reached max_iter = {iteration} (lam_norm = {lam_norm:.3e})"
120 )
121 }
122 PiggybackError::IterationsExhaustedForwardAdjoint {
123 iteration,
124 z_norm,
125 lam_norm,
126 } => {
127 write!(
128 f,
129 "piggyback: forward-adjoint solve reached max_iter = {iteration} (z_norm = {z_norm:.3e}, lam_norm = {lam_norm:.3e})"
130 )
131 }
132 PiggybackError::DimensionMismatch {
133 field,
134 expected,
135 actual,
136 } => {
137 write!(
138 f,
139 "piggyback: dimension mismatch for `{field}` (expected {expected}, got {actual})"
140 )
141 }
142 }
143 }
144}
145
// Leaf error: no `source()` override because no variant wraps an
// underlying error to chain.
impl std::error::Error for PiggybackError {}

// NOTE(review): presumably a compile-time `Send + Sync` assertion
// (macro defined in `echidna`) so the error can cross thread
// boundaries — confirm against the macro's definition.
echidna::assert_send_sync!(PiggybackError);
149
150/// Validate that a step tape G: R^(m+n) -> R^m has the expected shape.
151///
152/// Uses `assert_eq!` (panic) rather than `Result` because shape
153/// mismatches are programmer errors — calling `piggyback_*_solve` with
154/// an inconsistent tape is a contract violation, not a runtime
155/// numerical failure that callers should recover from.
156fn validate_step_tape<F: Float>(tape: &BytecodeTape<F>, z: &[F], x: &[F], num_states: usize) {
157 assert_eq!(z.len(), num_states);
158 assert_eq!(tape.num_inputs(), num_states + x.len());
159 assert_eq!(
160 tape.num_outputs(),
161 num_states,
162 "step tape must have num_outputs == num_states (G: R^(m+n) -> R^m)"
163 );
164}
165
166/// One tangent piggyback step through a fixed-point map G.
167///
168/// Given the iteration `z_{k+1} = G(z_k, x)`, computes both the primal step
169/// and the tangent propagation `ż_{k+1} = G_z · ż_k + G_x · ẋ` in a single
170/// forward pass using dual numbers.
171///
172/// Returns `(z_new, z_dot_new)`.
173pub fn piggyback_tangent_step<F: Float>(
174 step_tape: &BytecodeTape<F>,
175 z: &[F],
176 x: &[F],
177 z_dot: &[F],
178 x_dot: &[F],
179 num_states: usize,
180) -> (Vec<F>, Vec<F>) {
181 let mut buf = Vec::new();
182 piggyback_tangent_step_with_buf(step_tape, z, x, z_dot, x_dot, num_states, &mut buf)
183}
184
185/// One tangent piggyback step, reusing `buf` across calls.
186///
187/// Same as [`piggyback_tangent_step`] but avoids reallocating the internal
188/// dual-number buffer on each call.
189pub fn piggyback_tangent_step_with_buf<F: Float>(
190 step_tape: &BytecodeTape<F>,
191 z: &[F],
192 x: &[F],
193 z_dot: &[F],
194 x_dot: &[F],
195 num_states: usize,
196 buf: &mut Vec<Dual<F>>,
197) -> (Vec<F>, Vec<F>) {
198 validate_step_tape(step_tape, z, x, num_states);
199 let m = num_states;
200 let n = x.len();
201 assert_eq!(z_dot.len(), m, "z_dot length must equal num_states");
202 assert_eq!(x_dot.len(), n, "x_dot length must equal x length");
203
204 // Build dual inputs: [Dual(z_i, ż_i), ..., Dual(x_j, ẋ_j), ...]
205 let mut dual_inputs = Vec::with_capacity(m + n);
206 for i in 0..m {
207 dual_inputs.push(Dual::new(z[i], z_dot[i]));
208 }
209 for j in 0..n {
210 dual_inputs.push(Dual::new(x[j], x_dot[j]));
211 }
212
213 step_tape.forward_tangent(&dual_inputs, buf);
214
215 // Extract outputs: .re -> z_new, .eps -> z_dot_new
216 let out_indices = step_tape.all_output_indices();
217 let mut z_new = Vec::with_capacity(m);
218 let mut z_dot_new = Vec::with_capacity(m);
219 for &idx in out_indices {
220 let d = buf[idx as usize];
221 z_new.push(d.re);
222 z_dot_new.push(d.eps);
223 }
224
225 (z_new, z_dot_new)
226}
227
228/// Tangent piggyback solve: find fixed point z* = G(z*, x) and its tangent ż*.
229///
230/// Iterates the fixed-point map `z_{k+1} = G(z_k, x)` while simultaneously
231/// propagating tangents `ż_{k+1} = G_z · ż_k + G_x · ẋ`.
232///
233/// Returns `Ok((z_star, z_dot_star, iterations))` on convergence. Returns
234/// `Err(PiggybackError::PrimalDivergence)` when the primal norm becomes
235/// non-finite, `Err(PiggybackError::TangentDivergence)` when the primal
236/// stays finite but the tangent overflows (ratio-converging case), or
237/// `Err(PiggybackError::IterationsExhaustedTangent { iteration, z_norm })` when
238/// `max_iter` is reached without satisfying `tol`.
239pub fn piggyback_tangent_solve<F: Float>(
240 step_tape: &BytecodeTape<F>,
241 z0: &[F],
242 x: &[F],
243 x_dot: &[F],
244 num_states: usize,
245 max_iter: usize,
246 tol: F,
247) -> Result<(Vec<F>, Vec<F>, usize), PiggybackError> {
248 // Runtime vector-length check at solve-level: surfaces dimension
249 // mismatches as `Err` before the first iteration dispatches to the
250 // step fn (which still panics on bad input as a contract-level
251 // guarantee).
252 if x_dot.len() != x.len() {
253 return Err(PiggybackError::DimensionMismatch {
254 field: "x_dot",
255 expected: x.len(),
256 actual: x_dot.len(),
257 });
258 }
259 let m = num_states;
260 let mut z = z0.to_vec();
261 let mut z_dot = vec![F::zero(); m];
262 let mut buf = Vec::new();
263 let mut last_norm: f64 = f64::NAN;
264
265 for k in 0..max_iter {
266 let (z_new, z_dot_new) =
267 piggyback_tangent_step_with_buf(step_tape, &z, x, &z_dot, x_dot, num_states, &mut buf);
268
269 // Relative convergence: ||z_new - z|| / (1 + ||z||)
270 let mut delta_sq = F::zero();
271 let mut z_sq = F::zero();
272 for i in 0..m {
273 let d = z_new[i] - z[i];
274 delta_sq = delta_sq + d * d;
275 z_sq = z_sq + z[i] * z[i];
276 }
277 let norm = delta_sq.sqrt() / (F::one() + z_sq.sqrt());
278 // Variant-mapping order: norm-check first → PrimalDivergence;
279 // tangent-finite check second → TangentDivergence. A non-finite
280 // primal naturally produces a non-finite norm, so it falls into
281 // PrimalDivergence by detection priority.
282 if !norm.is_finite() {
283 return Err(PiggybackError::PrimalDivergence {
284 iteration: k,
285 last_norm: norm.to_f64().unwrap_or(f64::NAN),
286 });
287 }
288 // Detect tangent divergence even when the primal `z_new` itself is
289 // finite: the JVP iteration `z_dot_{k+1} = G_z·z_dot_k + G_x·x_dot`
290 // can produce Inf/NaN tangents that a primal-only norm check misses.
291 if !z_dot_new.iter().all(|v| v.is_finite()) {
292 // `norm` is guaranteed finite here — the `!norm.is_finite()`
293 // branch above would have returned `PrimalDivergence` first.
294 // The debug_assert guards against a future refactor that
295 // reorders these checks and silently invalidates the
296 // `TangentDivergence::last_norm` docstring's "finite by
297 // construction" promise.
298 debug_assert!(
299 norm.is_finite(),
300 "TangentDivergence path must see a finite primal norm"
301 );
302 return Err(PiggybackError::TangentDivergence {
303 iteration: k,
304 last_norm: norm.to_f64().unwrap_or(f64::NAN),
305 });
306 }
307 last_norm = norm.to_f64().unwrap_or(f64::NAN);
308 if norm < tol {
309 return Ok((z_new, z_dot_new, k + 1));
310 }
311
312 z = z_new;
313 z_dot = z_dot_new;
314 }
315
316 Err(PiggybackError::IterationsExhaustedTangent {
317 iteration: max_iter,
318 z_norm: last_norm,
319 })
320}
321
322/// Adjoint piggyback solve at a converged fixed point z* = G(z*, x).
323///
324/// Iterates the adjoint fixed-point equation `λ_{k+1} = G_z^T · λ_k + z̄`
325/// using reverse-mode sweeps through the step tape. At convergence, returns
326/// `x̄ = G_x^T · λ*`.
327///
328/// Requires z* to already be computed (e.g. by the primal solver).
329/// The iteration converges when G is a contraction (‖G_z‖ < 1).
330///
331/// Returns `Ok((x_bar, iterations))` on convergence. Returns
332/// `Err(PiggybackError::AdjointDivergence)` when the adjoint norm is
333/// non-finite or `lambda_new` overflows (ratio-converging case), or
334/// `Err(PiggybackError::IterationsExhaustedAdjoint { iteration, lam_norm })`
335/// when `max_iter` is reached without satisfying `tol`.
336pub fn piggyback_adjoint_solve<F: Float>(
337 step_tape: &mut BytecodeTape<F>,
338 z_star: &[F],
339 x: &[F],
340 z_bar: &[F],
341 num_states: usize,
342 max_iter: usize,
343 tol: F,
344) -> Result<(Vec<F>, usize), PiggybackError> {
345 // Runtime arg-length check first so solve-fn users see
346 // `Err(DimensionMismatch)` in preference to a `validate_step_tape`
347 // panic when both would fire. Matches the check ordering in
348 // `piggyback_tangent_solve`.
349 let m = num_states;
350 if z_bar.len() != m {
351 return Err(PiggybackError::DimensionMismatch {
352 field: "z_bar",
353 expected: m,
354 actual: z_bar.len(),
355 });
356 }
357 validate_step_tape(step_tape, z_star, x, num_states);
358
359 // Set primal values: forward([z*, x])
360 let mut input = Vec::with_capacity(m + x.len());
361 input.extend_from_slice(z_star);
362 input.extend_from_slice(x);
363 step_tape.forward(&input);
364
365 let mut lambda = z_bar.to_vec();
366 let mut last_norm: f64 = f64::NAN;
367
368 for k in 0..max_iter {
369 // reverse_seeded(λ) returns [G_z^T · λ; G_x^T · λ] (length m+n)
370 let adj = step_tape.reverse_seeded(&lambda);
371
372 // λ_new[i] = adj[i] + z_bar[i] for i = 0..m
373 let mut lambda_new = Vec::with_capacity(m);
374 let mut delta_sq = F::zero();
375 let mut lam_sq = F::zero();
376 for i in 0..m {
377 let l_new = adj[i] + z_bar[i];
378 let d = l_new - lambda[i];
379 delta_sq = delta_sq + d * d;
380 lam_sq = lam_sq + lambda[i] * lambda[i];
381 lambda_new.push(l_new);
382 }
383
384 let norm = delta_sq.sqrt() / (F::one() + lam_sq.sqrt());
385 if !norm.is_finite() {
386 return Err(PiggybackError::AdjointDivergence {
387 iteration: k,
388 last_norm: norm.to_f64().unwrap_or(f64::NAN),
389 });
390 }
391 // A ratio-converging iteration with exponentially-growing `lambda`
392 // magnitudes (spectral radius of `G_z^T` ≥ 1) can produce finite
393 // `norm` while `lambda_new` is Inf/NaN. Explicit finite check
394 // catches the divergence regardless of ratio behaviour.
395 if !lambda_new.iter().all(|v| v.is_finite()) {
396 debug_assert!(
397 norm.is_finite(),
398 "AdjointDivergence componentwise path must see a finite norm"
399 );
400 return Err(PiggybackError::AdjointDivergence {
401 iteration: k,
402 last_norm: norm.to_f64().unwrap_or(f64::NAN),
403 });
404 }
405 last_norm = norm.to_f64().unwrap_or(f64::NAN);
406 if norm < tol {
407 // One extra reverse pass with converged lambda to get consistent x_bar.
408 // Without this, adj[m..] uses the pre-convergence lambda, introducing
409 // O(tol * ||G_x||) error.
410 let adj_final = step_tape.reverse_seeded(&lambda_new);
411 return Ok((adj_final[m..].to_vec(), k + 1));
412 }
413
414 lambda = lambda_new;
415 }
416
417 Err(PiggybackError::IterationsExhaustedAdjoint {
418 iteration: max_iter,
419 lam_norm: last_norm,
420 })
421}
422
423/// Interleaved forward-adjoint piggyback solve.
424///
425/// Simultaneously iterates the primal fixed-point `z_{k+1} = G(z_k, x)` and
426/// the adjoint equation `λ_{k+1} = G_z^T · λ_k + z̄`. This cuts the total
427/// iteration count from `K_primal + K_adjoint` to `max(K_primal, K_adjoint)`.
428///
429/// Returns `Ok((z_star, x_bar, iterations))` when both `z` and `λ` converge.
430/// Returns `Err(PiggybackError::PrimalDivergence)` when `z_norm` becomes
431/// non-finite or `z_new` itself contains non-finite components,
432/// `Err(PiggybackError::AdjointDivergence)` when the adjoint norm or
433/// `lambda_new` overflows, or
434/// `Err(PiggybackError::IterationsExhaustedForwardAdjoint { iteration, z_norm, lam_norm })`
435/// when `max_iter` is reached without satisfying `tol`.
436pub fn piggyback_forward_adjoint_solve<F: Float>(
437 step_tape: &mut BytecodeTape<F>,
438 z0: &[F],
439 x: &[F],
440 z_bar: &[F],
441 num_states: usize,
442 max_iter: usize,
443 tol: F,
444) -> Result<(Vec<F>, Vec<F>, usize), PiggybackError> {
445 // Runtime arg-length check first — see note on
446 // `piggyback_adjoint_solve`.
447 let m = num_states;
448 if z_bar.len() != m {
449 return Err(PiggybackError::DimensionMismatch {
450 field: "z_bar",
451 expected: m,
452 actual: z_bar.len(),
453 });
454 }
455 validate_step_tape(step_tape, z0, x, num_states);
456
457 // Pre-allocate input buffer [z, x]
458 let mut input = Vec::with_capacity(m + x.len());
459 input.extend_from_slice(z0);
460 input.extend_from_slice(x);
461
462 let mut lambda = z_bar.to_vec();
463 let mut last_z_norm: f64 = f64::NAN;
464 let mut last_lam_norm: f64 = f64::NAN;
465
466 for k in 0..max_iter {
467 // Forward pass at current z
468 step_tape.forward(&input);
469 let z_new = step_tape.output_values();
470
471 // Reverse pass with current λ
472 let adj = step_tape.reverse_seeded(&lambda);
473
474 // Primal convergence: ||z_new - z|| / (1 + ||z||)
475 let mut z_delta_sq = F::zero();
476 let mut z_sq = F::zero();
477 for i in 0..m {
478 let d = z_new[i] - input[i];
479 z_delta_sq = z_delta_sq + d * d;
480 z_sq = z_sq + input[i] * input[i];
481 }
482 let z_norm = z_delta_sq.sqrt() / (F::one() + z_sq.sqrt());
483 if !z_norm.is_finite() {
484 return Err(PiggybackError::PrimalDivergence {
485 iteration: k,
486 last_norm: z_norm.to_f64().unwrap_or(f64::NAN),
487 });
488 }
489
490 // Adjoint update and convergence: λ_new = G_z^T · λ + z̄
491 let mut lam_delta_sq = F::zero();
492 let mut lam_sq = F::zero();
493 let mut lambda_new = Vec::with_capacity(m);
494 for i in 0..m {
495 let l_new = adj[i] + z_bar[i];
496 let d = l_new - lambda[i];
497 lam_delta_sq = lam_delta_sq + d * d;
498 lam_sq = lam_sq + lambda[i] * lambda[i];
499 lambda_new.push(l_new);
500 }
501 let lam_norm = lam_delta_sq.sqrt() / (F::one() + lam_sq.sqrt());
502 if !lam_norm.is_finite() {
503 return Err(PiggybackError::AdjointDivergence {
504 iteration: k,
505 last_norm: lam_norm.to_f64().unwrap_or(f64::NAN),
506 });
507 }
508 // Same divergence case as the standalone solvers: a ratio-converging
509 // iteration with exponentially-growing lambda magnitudes can produce
510 // finite `lam_norm` while `lambda_new` itself is Inf/NaN.
511 if !lambda_new.iter().all(|v| v.is_finite()) {
512 debug_assert!(
513 lam_norm.is_finite(),
514 "AdjointDivergence componentwise path must see a finite lam_norm"
515 );
516 return Err(PiggybackError::AdjointDivergence {
517 iteration: k,
518 last_norm: lam_norm.to_f64().unwrap_or(f64::NAN),
519 });
520 }
521 // Defense-in-depth: a non-finite `z_new[i]` would typically have
522 // already shown up as `!z_norm.is_finite()` above (the delta/sq
523 // loops touch every index), but guard the componentwise case
524 // explicitly so a future refactor of the norm computation can't
525 // silently lose primal-divergence detection.
526 if !z_new.iter().all(|v| v.is_finite()) {
527 debug_assert!(
528 z_norm.is_finite(),
529 "PrimalDivergence componentwise path must see a finite z_norm"
530 );
531 return Err(PiggybackError::PrimalDivergence {
532 iteration: k,
533 last_norm: z_norm.to_f64().unwrap_or(f64::NAN),
534 });
535 }
536
537 last_z_norm = z_norm.to_f64().unwrap_or(f64::NAN);
538 last_lam_norm = lam_norm.to_f64().unwrap_or(f64::NAN);
539
540 if z_norm < tol && lam_norm < tol {
541 // One extra reverse pass with converged lambda_new to get consistent x_bar,
542 // matching the pattern in piggyback_adjoint_solve.
543 input[..m].copy_from_slice(&z_new[..m]);
544 step_tape.forward(&input);
545 let adj_final = step_tape.reverse_seeded(&lambda_new);
546 return Ok((z_new, adj_final[m..].to_vec(), k + 1));
547 }
548
549 // Update z in the input buffer
550 input[..m].copy_from_slice(&z_new[..m]);
551 lambda = lambda_new;
552 }
553
554 Err(PiggybackError::IterationsExhaustedForwardAdjoint {
555 iteration: max_iter,
556 z_norm: last_z_norm,
557 lam_norm: last_lam_norm,
558 })
559}