//! Piggyback differentiation solvers: tangent, adjoint, and interleaved
//! forward-adjoint fixed-point iteration through a step tape.

use std::fmt;

use echidna::{BytecodeTape, Dual, Float};
/// Reason a piggyback solve failed to converge.
///
/// Marked `#[non_exhaustive]` so future variants can be added without
/// breaking exhaustive `match`es. Numeric fields use `f64` (cast via
/// `Float::to_f64`) for uniform diagnostic output regardless of the
/// solver's `F` type.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum PiggybackError {
    /// The primal `z_{k+1} = G(z_k, x)` produced a non-finite norm
    /// (relative-norm `||z_new - z||/(1 + ||z||)` is NaN/Inf), or
    /// the primal vector itself contained non-finite components in
    /// the forward-adjoint loop. `last_norm` is the primal-delta
    /// relative norm at the detecting iteration: non-finite when
    /// detection came from the norm check (the usual case);
    /// finite — and itself diagnostic — when detection came from
    /// the componentwise finite check (primal vector overflowed
    /// mid-iteration while the step-to-step delta stayed bounded).
    PrimalDivergence { iteration: usize, last_norm: f64 },
    /// Primal stayed finite but the tangent
    /// `ż_{k+1} = G_z · ż_k + G_x · ẋ` produced non-finite values.
    /// Catches the ratio-converging case where the primal norm
    /// remains bounded while individual tangent components overflow.
    /// `last_norm` is the primal-delta relative norm at the
    /// detecting iteration — **finite** by construction here (the
    /// tangent-only divergence path takes the norm-finite branch
    /// before the componentwise check fires); surfacing it tells
    /// the caller the primal iteration was bounded while the JVP
    /// overflowed.
    TangentDivergence { iteration: usize, last_norm: f64 },
    /// Adjoint `λ_{k+1} = G_z^T · λ_k + z̄` produced non-finite
    /// values (norm or individual components). `last_norm` is the
    /// adjoint-delta relative norm at the detecting iteration:
    /// non-finite when detection came from the norm check; finite
    /// when it came from the componentwise `lambda_new` check.
    AdjointDivergence { iteration: usize, last_norm: f64 },
    /// `piggyback_tangent_solve` reached `max_iter` without meeting
    /// `tol`. `z_norm` is the final iteration's relative primal-delta
    /// norm (`||z_new - z|| / (1 + ||z||)`) — a value just over `tol`
    /// signals proximity to convergence; many orders over signals
    /// stagnation. `iteration` equals `max_iter`. `z_norm` is NaN when
    /// `max_iter == 0` (no iteration ever ran).
    IterationsExhaustedTangent { iteration: usize, z_norm: f64 },
    /// `piggyback_adjoint_solve` reached `max_iter` without meeting
    /// `tol`. `lam_norm` is the final iteration's relative adjoint-
    /// delta norm (`||λ_new - λ|| / (1 + ||λ||)`); NaN when
    /// `max_iter == 0`.
    IterationsExhaustedAdjoint { iteration: usize, lam_norm: f64 },
    /// `piggyback_forward_adjoint_solve` reached `max_iter` without
    /// meeting `tol` on both norms simultaneously. Each field is the
    /// final iteration's relative norm for the corresponding stream.
    IterationsExhaustedForwardAdjoint {
        iteration: usize,
        z_norm: f64,
        lam_norm: f64,
    },
    /// A runtime-supplied vector argument to a public `*_solve` fn
    /// had an unexpected length. `field` names the argument (e.g.
    /// `"z_dot"`, `"z_bar"`), `expected` is the length the API
    /// requires (typically `num_states` or `x.len()`), `actual` is
    /// the length the caller supplied.
    ///
    /// Note: tape-shape contract mismatches
    /// (`validate_step_tape`) and step-fn argument mismatches
    /// (`piggyback_tangent_step[_with_buf]`) continue to panic —
    /// those are programmer-contract violations, not recoverable
    /// runtime failures.
    DimensionMismatch {
        field: &'static str,
        expected: usize,
        actual: usize,
    },
}
76
77impl fmt::Display for PiggybackError {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        match self {
80            PiggybackError::PrimalDivergence {
81                iteration,
82                last_norm,
83            } => {
84                write!(
85                    f,
86                    "piggyback: primal diverged at iteration {iteration} (last_norm = {last_norm:.3e})"
87                )
88            }
89            PiggybackError::TangentDivergence {
90                iteration,
91                last_norm,
92            } => {
93                write!(
94                    f,
95                    "piggyback: tangent diverged at iteration {iteration} (last_norm = {last_norm:.3e})"
96                )
97            }
98            PiggybackError::AdjointDivergence {
99                iteration,
100                last_norm,
101            } => {
102                write!(
103                    f,
104                    "piggyback: adjoint diverged at iteration {iteration} (last_norm = {last_norm:.3e})"
105                )
106            }
107            PiggybackError::IterationsExhaustedTangent { iteration, z_norm } => {
108                write!(
109                    f,
110                    "piggyback: tangent solve reached max_iter = {iteration} (z_norm = {z_norm:.3e})"
111                )
112            }
113            PiggybackError::IterationsExhaustedAdjoint {
114                iteration,
115                lam_norm,
116            } => {
117                write!(
118                    f,
119                    "piggyback: adjoint solve reached max_iter = {iteration} (lam_norm = {lam_norm:.3e})"
120                )
121            }
122            PiggybackError::IterationsExhaustedForwardAdjoint {
123                iteration,
124                z_norm,
125                lam_norm,
126            } => {
127                write!(
128                    f,
129                    "piggyback: forward-adjoint solve reached max_iter = {iteration} (z_norm = {z_norm:.3e}, lam_norm = {lam_norm:.3e})"
130                )
131            }
132            PiggybackError::DimensionMismatch {
133                field,
134                expected,
135                actual,
136            } => {
137                write!(
138                    f,
139                    "piggyback: dimension mismatch for `{field}` (expected {expected}, got {actual})"
140                )
141            }
142        }
143    }
144}
145
// Marker impl: `PiggybackError` has no underlying cause, so the default
// `source()` (returning `None`) is correct and the body stays empty.
impl std::error::Error for PiggybackError {}

// Compile-time check that the error type remains `Send + Sync` (e.g. so
// it can travel in `Box<dyn Error + Send + Sync>` across threads).
echidna::assert_send_sync!(PiggybackError);
149
150/// Validate that a step tape G: R^(m+n) -> R^m has the expected shape.
151///
152/// Uses `assert_eq!` (panic) rather than `Result` because shape
153/// mismatches are programmer errors — calling `piggyback_*_solve` with
154/// an inconsistent tape is a contract violation, not a runtime
155/// numerical failure that callers should recover from.
156fn validate_step_tape<F: Float>(tape: &BytecodeTape<F>, z: &[F], x: &[F], num_states: usize) {
157    assert_eq!(z.len(), num_states);
158    assert_eq!(tape.num_inputs(), num_states + x.len());
159    assert_eq!(
160        tape.num_outputs(),
161        num_states,
162        "step tape must have num_outputs == num_states (G: R^(m+n) -> R^m)"
163    );
164}
165
166/// One tangent piggyback step through a fixed-point map G.
167///
168/// Given the iteration `z_{k+1} = G(z_k, x)`, computes both the primal step
169/// and the tangent propagation `ż_{k+1} = G_z · ż_k + G_x · ẋ` in a single
170/// forward pass using dual numbers.
171///
172/// Returns `(z_new, z_dot_new)`.
173pub fn piggyback_tangent_step<F: Float>(
174    step_tape: &BytecodeTape<F>,
175    z: &[F],
176    x: &[F],
177    z_dot: &[F],
178    x_dot: &[F],
179    num_states: usize,
180) -> (Vec<F>, Vec<F>) {
181    let mut buf = Vec::new();
182    piggyback_tangent_step_with_buf(step_tape, z, x, z_dot, x_dot, num_states, &mut buf)
183}
184
185/// One tangent piggyback step, reusing `buf` across calls.
186///
187/// Same as [`piggyback_tangent_step`] but avoids reallocating the internal
188/// dual-number buffer on each call.
189pub fn piggyback_tangent_step_with_buf<F: Float>(
190    step_tape: &BytecodeTape<F>,
191    z: &[F],
192    x: &[F],
193    z_dot: &[F],
194    x_dot: &[F],
195    num_states: usize,
196    buf: &mut Vec<Dual<F>>,
197) -> (Vec<F>, Vec<F>) {
198    validate_step_tape(step_tape, z, x, num_states);
199    let m = num_states;
200    let n = x.len();
201    assert_eq!(z_dot.len(), m, "z_dot length must equal num_states");
202    assert_eq!(x_dot.len(), n, "x_dot length must equal x length");
203
204    // Build dual inputs: [Dual(z_i, ż_i), ..., Dual(x_j, ẋ_j), ...]
205    let mut dual_inputs = Vec::with_capacity(m + n);
206    for i in 0..m {
207        dual_inputs.push(Dual::new(z[i], z_dot[i]));
208    }
209    for j in 0..n {
210        dual_inputs.push(Dual::new(x[j], x_dot[j]));
211    }
212
213    step_tape.forward_tangent(&dual_inputs, buf);
214
215    // Extract outputs: .re -> z_new, .eps -> z_dot_new
216    let out_indices = step_tape.all_output_indices();
217    let mut z_new = Vec::with_capacity(m);
218    let mut z_dot_new = Vec::with_capacity(m);
219    for &idx in out_indices {
220        let d = buf[idx as usize];
221        z_new.push(d.re);
222        z_dot_new.push(d.eps);
223    }
224
225    (z_new, z_dot_new)
226}
227
228/// Tangent piggyback solve: find fixed point z* = G(z*, x) and its tangent ż*.
229///
230/// Iterates the fixed-point map `z_{k+1} = G(z_k, x)` while simultaneously
231/// propagating tangents `ż_{k+1} = G_z · ż_k + G_x · ẋ`.
232///
233/// Returns `Ok((z_star, z_dot_star, iterations))` on convergence. Returns
234/// `Err(PiggybackError::PrimalDivergence)` when the primal norm becomes
235/// non-finite, `Err(PiggybackError::TangentDivergence)` when the primal
236/// stays finite but the tangent overflows (ratio-converging case), or
237/// `Err(PiggybackError::IterationsExhaustedTangent { iteration, z_norm })` when
238/// `max_iter` is reached without satisfying `tol`.
239pub fn piggyback_tangent_solve<F: Float>(
240    step_tape: &BytecodeTape<F>,
241    z0: &[F],
242    x: &[F],
243    x_dot: &[F],
244    num_states: usize,
245    max_iter: usize,
246    tol: F,
247) -> Result<(Vec<F>, Vec<F>, usize), PiggybackError> {
248    // Runtime vector-length check at solve-level: surfaces dimension
249    // mismatches as `Err` before the first iteration dispatches to the
250    // step fn (which still panics on bad input as a contract-level
251    // guarantee).
252    if x_dot.len() != x.len() {
253        return Err(PiggybackError::DimensionMismatch {
254            field: "x_dot",
255            expected: x.len(),
256            actual: x_dot.len(),
257        });
258    }
259    let m = num_states;
260    let mut z = z0.to_vec();
261    let mut z_dot = vec![F::zero(); m];
262    let mut buf = Vec::new();
263    let mut last_norm: f64 = f64::NAN;
264
265    for k in 0..max_iter {
266        let (z_new, z_dot_new) =
267            piggyback_tangent_step_with_buf(step_tape, &z, x, &z_dot, x_dot, num_states, &mut buf);
268
269        // Relative convergence: ||z_new - z|| / (1 + ||z||)
270        let mut delta_sq = F::zero();
271        let mut z_sq = F::zero();
272        for i in 0..m {
273            let d = z_new[i] - z[i];
274            delta_sq = delta_sq + d * d;
275            z_sq = z_sq + z[i] * z[i];
276        }
277        let norm = delta_sq.sqrt() / (F::one() + z_sq.sqrt());
278        // Variant-mapping order: norm-check first → PrimalDivergence;
279        // tangent-finite check second → TangentDivergence. A non-finite
280        // primal naturally produces a non-finite norm, so it falls into
281        // PrimalDivergence by detection priority.
282        if !norm.is_finite() {
283            return Err(PiggybackError::PrimalDivergence {
284                iteration: k,
285                last_norm: norm.to_f64().unwrap_or(f64::NAN),
286            });
287        }
288        // Detect tangent divergence even when the primal `z_new` itself is
289        // finite: the JVP iteration `z_dot_{k+1} = G_z·z_dot_k + G_x·x_dot`
290        // can produce Inf/NaN tangents that a primal-only norm check misses.
291        if !z_dot_new.iter().all(|v| v.is_finite()) {
292            // `norm` is guaranteed finite here — the `!norm.is_finite()`
293            // branch above would have returned `PrimalDivergence` first.
294            // The debug_assert guards against a future refactor that
295            // reorders these checks and silently invalidates the
296            // `TangentDivergence::last_norm` docstring's "finite by
297            // construction" promise.
298            debug_assert!(
299                norm.is_finite(),
300                "TangentDivergence path must see a finite primal norm"
301            );
302            return Err(PiggybackError::TangentDivergence {
303                iteration: k,
304                last_norm: norm.to_f64().unwrap_or(f64::NAN),
305            });
306        }
307        last_norm = norm.to_f64().unwrap_or(f64::NAN);
308        if norm < tol {
309            return Ok((z_new, z_dot_new, k + 1));
310        }
311
312        z = z_new;
313        z_dot = z_dot_new;
314    }
315
316    Err(PiggybackError::IterationsExhaustedTangent {
317        iteration: max_iter,
318        z_norm: last_norm,
319    })
320}
321
322/// Adjoint piggyback solve at a converged fixed point z* = G(z*, x).
323///
324/// Iterates the adjoint fixed-point equation `λ_{k+1} = G_z^T · λ_k + z̄`
325/// using reverse-mode sweeps through the step tape. At convergence, returns
326/// `x̄ = G_x^T · λ*`.
327///
328/// Requires z* to already be computed (e.g. by the primal solver).
329/// The iteration converges when G is a contraction (‖G_z‖ < 1).
330///
331/// Returns `Ok((x_bar, iterations))` on convergence. Returns
332/// `Err(PiggybackError::AdjointDivergence)` when the adjoint norm is
333/// non-finite or `lambda_new` overflows (ratio-converging case), or
334/// `Err(PiggybackError::IterationsExhaustedAdjoint { iteration, lam_norm })`
335/// when `max_iter` is reached without satisfying `tol`.
336pub fn piggyback_adjoint_solve<F: Float>(
337    step_tape: &mut BytecodeTape<F>,
338    z_star: &[F],
339    x: &[F],
340    z_bar: &[F],
341    num_states: usize,
342    max_iter: usize,
343    tol: F,
344) -> Result<(Vec<F>, usize), PiggybackError> {
345    // Runtime arg-length check first so solve-fn users see
346    // `Err(DimensionMismatch)` in preference to a `validate_step_tape`
347    // panic when both would fire. Matches the check ordering in
348    // `piggyback_tangent_solve`.
349    let m = num_states;
350    if z_bar.len() != m {
351        return Err(PiggybackError::DimensionMismatch {
352            field: "z_bar",
353            expected: m,
354            actual: z_bar.len(),
355        });
356    }
357    validate_step_tape(step_tape, z_star, x, num_states);
358
359    // Set primal values: forward([z*, x])
360    let mut input = Vec::with_capacity(m + x.len());
361    input.extend_from_slice(z_star);
362    input.extend_from_slice(x);
363    step_tape.forward(&input);
364
365    let mut lambda = z_bar.to_vec();
366    let mut last_norm: f64 = f64::NAN;
367
368    for k in 0..max_iter {
369        // reverse_seeded(λ) returns [G_z^T · λ; G_x^T · λ] (length m+n)
370        let adj = step_tape.reverse_seeded(&lambda);
371
372        // λ_new[i] = adj[i] + z_bar[i] for i = 0..m
373        let mut lambda_new = Vec::with_capacity(m);
374        let mut delta_sq = F::zero();
375        let mut lam_sq = F::zero();
376        for i in 0..m {
377            let l_new = adj[i] + z_bar[i];
378            let d = l_new - lambda[i];
379            delta_sq = delta_sq + d * d;
380            lam_sq = lam_sq + lambda[i] * lambda[i];
381            lambda_new.push(l_new);
382        }
383
384        let norm = delta_sq.sqrt() / (F::one() + lam_sq.sqrt());
385        if !norm.is_finite() {
386            return Err(PiggybackError::AdjointDivergence {
387                iteration: k,
388                last_norm: norm.to_f64().unwrap_or(f64::NAN),
389            });
390        }
391        // A ratio-converging iteration with exponentially-growing `lambda`
392        // magnitudes (spectral radius of `G_z^T` ≥ 1) can produce finite
393        // `norm` while `lambda_new` is Inf/NaN. Explicit finite check
394        // catches the divergence regardless of ratio behaviour.
395        if !lambda_new.iter().all(|v| v.is_finite()) {
396            debug_assert!(
397                norm.is_finite(),
398                "AdjointDivergence componentwise path must see a finite norm"
399            );
400            return Err(PiggybackError::AdjointDivergence {
401                iteration: k,
402                last_norm: norm.to_f64().unwrap_or(f64::NAN),
403            });
404        }
405        last_norm = norm.to_f64().unwrap_or(f64::NAN);
406        if norm < tol {
407            // One extra reverse pass with converged lambda to get consistent x_bar.
408            // Without this, adj[m..] uses the pre-convergence lambda, introducing
409            // O(tol * ||G_x||) error.
410            let adj_final = step_tape.reverse_seeded(&lambda_new);
411            return Ok((adj_final[m..].to_vec(), k + 1));
412        }
413
414        lambda = lambda_new;
415    }
416
417    Err(PiggybackError::IterationsExhaustedAdjoint {
418        iteration: max_iter,
419        lam_norm: last_norm,
420    })
421}
422
423/// Interleaved forward-adjoint piggyback solve.
424///
425/// Simultaneously iterates the primal fixed-point `z_{k+1} = G(z_k, x)` and
426/// the adjoint equation `λ_{k+1} = G_z^T · λ_k + z̄`. This cuts the total
427/// iteration count from `K_primal + K_adjoint` to `max(K_primal, K_adjoint)`.
428///
429/// Returns `Ok((z_star, x_bar, iterations))` when both `z` and `λ` converge.
430/// Returns `Err(PiggybackError::PrimalDivergence)` when `z_norm` becomes
431/// non-finite or `z_new` itself contains non-finite components,
432/// `Err(PiggybackError::AdjointDivergence)` when the adjoint norm or
433/// `lambda_new` overflows, or
434/// `Err(PiggybackError::IterationsExhaustedForwardAdjoint { iteration, z_norm, lam_norm })`
435/// when `max_iter` is reached without satisfying `tol`.
436pub fn piggyback_forward_adjoint_solve<F: Float>(
437    step_tape: &mut BytecodeTape<F>,
438    z0: &[F],
439    x: &[F],
440    z_bar: &[F],
441    num_states: usize,
442    max_iter: usize,
443    tol: F,
444) -> Result<(Vec<F>, Vec<F>, usize), PiggybackError> {
445    // Runtime arg-length check first — see note on
446    // `piggyback_adjoint_solve`.
447    let m = num_states;
448    if z_bar.len() != m {
449        return Err(PiggybackError::DimensionMismatch {
450            field: "z_bar",
451            expected: m,
452            actual: z_bar.len(),
453        });
454    }
455    validate_step_tape(step_tape, z0, x, num_states);
456
457    // Pre-allocate input buffer [z, x]
458    let mut input = Vec::with_capacity(m + x.len());
459    input.extend_from_slice(z0);
460    input.extend_from_slice(x);
461
462    let mut lambda = z_bar.to_vec();
463    let mut last_z_norm: f64 = f64::NAN;
464    let mut last_lam_norm: f64 = f64::NAN;
465
466    for k in 0..max_iter {
467        // Forward pass at current z
468        step_tape.forward(&input);
469        let z_new = step_tape.output_values();
470
471        // Reverse pass with current λ
472        let adj = step_tape.reverse_seeded(&lambda);
473
474        // Primal convergence: ||z_new - z|| / (1 + ||z||)
475        let mut z_delta_sq = F::zero();
476        let mut z_sq = F::zero();
477        for i in 0..m {
478            let d = z_new[i] - input[i];
479            z_delta_sq = z_delta_sq + d * d;
480            z_sq = z_sq + input[i] * input[i];
481        }
482        let z_norm = z_delta_sq.sqrt() / (F::one() + z_sq.sqrt());
483        if !z_norm.is_finite() {
484            return Err(PiggybackError::PrimalDivergence {
485                iteration: k,
486                last_norm: z_norm.to_f64().unwrap_or(f64::NAN),
487            });
488        }
489
490        // Adjoint update and convergence: λ_new = G_z^T · λ + z̄
491        let mut lam_delta_sq = F::zero();
492        let mut lam_sq = F::zero();
493        let mut lambda_new = Vec::with_capacity(m);
494        for i in 0..m {
495            let l_new = adj[i] + z_bar[i];
496            let d = l_new - lambda[i];
497            lam_delta_sq = lam_delta_sq + d * d;
498            lam_sq = lam_sq + lambda[i] * lambda[i];
499            lambda_new.push(l_new);
500        }
501        let lam_norm = lam_delta_sq.sqrt() / (F::one() + lam_sq.sqrt());
502        if !lam_norm.is_finite() {
503            return Err(PiggybackError::AdjointDivergence {
504                iteration: k,
505                last_norm: lam_norm.to_f64().unwrap_or(f64::NAN),
506            });
507        }
508        // Same divergence case as the standalone solvers: a ratio-converging
509        // iteration with exponentially-growing lambda magnitudes can produce
510        // finite `lam_norm` while `lambda_new` itself is Inf/NaN.
511        if !lambda_new.iter().all(|v| v.is_finite()) {
512            debug_assert!(
513                lam_norm.is_finite(),
514                "AdjointDivergence componentwise path must see a finite lam_norm"
515            );
516            return Err(PiggybackError::AdjointDivergence {
517                iteration: k,
518                last_norm: lam_norm.to_f64().unwrap_or(f64::NAN),
519            });
520        }
521        // Defense-in-depth: a non-finite `z_new[i]` would typically have
522        // already shown up as `!z_norm.is_finite()` above (the delta/sq
523        // loops touch every index), but guard the componentwise case
524        // explicitly so a future refactor of the norm computation can't
525        // silently lose primal-divergence detection.
526        if !z_new.iter().all(|v| v.is_finite()) {
527            debug_assert!(
528                z_norm.is_finite(),
529                "PrimalDivergence componentwise path must see a finite z_norm"
530            );
531            return Err(PiggybackError::PrimalDivergence {
532                iteration: k,
533                last_norm: z_norm.to_f64().unwrap_or(f64::NAN),
534            });
535        }
536
537        last_z_norm = z_norm.to_f64().unwrap_or(f64::NAN);
538        last_lam_norm = lam_norm.to_f64().unwrap_or(f64::NAN);
539
540        if z_norm < tol && lam_norm < tol {
541            // One extra reverse pass with converged lambda_new to get consistent x_bar,
542            // matching the pattern in piggyback_adjoint_solve.
543            input[..m].copy_from_slice(&z_new[..m]);
544            step_tape.forward(&input);
545            let adj_final = step_tape.reverse_seeded(&lambda_new);
546            return Ok((z_new, adj_final[m..].to_vec(), k + 1));
547        }
548
549        // Update z in the input buffer
550        input[..m].copy_from_slice(&z_new[..m]);
551        lambda = lambda_new;
552    }
553
554    Err(PiggybackError::IterationsExhaustedForwardAdjoint {
555        iteration: max_iter,
556        z_norm: last_z_norm,
557        lam_norm: last_lam_norm,
558    })
559}