Skip to main content

pounce_nl/
nl_tape.rs

1//! Flat-tape reverse-mode AD for `.nl` expression trees.
2//!
3//! Replaces the FD-based Hessian path with a port of the tape AD used
4//! in `ripopt::nl::autodiff`. The tape is a `Vec<TapeOp>` where each op
5//! refers to its operands by tape-slot index; forward evaluation runs
6//! through the slice once filling a parallel `Vec<f64>` of values, and
7//! reverse-mode adjoints walk the same buffer backwards.
8//!
9//! Sparse Hessians are computed by forward-over-reverse: for each
10//! variable `j` that the tape depends on, run a forward tangent sweep
11//! seeded with `e_j`, then a second-order reverse sweep that produces
12//! column `j` of the Hessian. The caller supplies a `(row, col) -> nnz
13//! position` map (lower triangle, row >= col), and contributions are
14//! accumulated in place — the outer loop in `eval_h` calls the same
15//! map for the objective and every active constraint, so every
16//! Lagrangian term lands in the right slot.
17//!
18//! Common subexpressions are tape-emitted **once**: when the recursive
19//! builder hits `Expr::Cse(rc)` it keys on the `Arc` pointer identity,
20//! emitting the body the first time and returning the cached
21//! result-slot index on subsequent references. The forward pass then
22//! computes each CSE once and the reverse pass folds adjoints from
23//! every reference into a single slot — exact chain-rule behaviour.
24
25use std::collections::{BTreeSet, HashMap, HashSet};
26use std::sync::Arc;
27
28use super::nl_external::{EvalResult, ExternalArg, ExternalLibrary, ExternalResolver};
29use super::nl_reader::{BinOp, CmpOp, Expr, FuncallArg, UnaryOp};
30
31/// One operation in the flattened tape. Operand fields are tape-slot
32/// indices into the same tape; `Var(i)` references problem variable
33/// index `i` (read from the input `x` slice during forward).
34#[derive(Debug, Clone)]
35pub enum TapeOp {
36    Const(f64),
37    Var(usize),
38    Add(usize, usize),
39    Sub(usize, usize),
40    Mul(usize, usize),
41    Div(usize, usize),
42    Pow(usize, usize),
43    Neg(usize),
44    Abs(usize),
45    Sqrt(usize),
46    Exp(usize),
47    Log(usize),
48    Log10(usize),
49    Sin(usize),
50    Cos(usize),
51    Tan(usize),
52    Atan(usize),
53    Acos(usize),
54    Sinh(usize),
55    Cosh(usize),
56    Tanh(usize),
57    Asin(usize),
58    Acosh(usize),
59    Asinh(usize),
60    Atanh(usize),
61    /// Two-argument arctangent `atan2(vals[a], vals[b])` (operands are
62    /// `(y, x)`, matching AMPL's `atan2(y, x)` / `.nl` opcode o48).
63    Atan2(usize, usize),
64    /// Pairwise minimum `min(vals[a], vals[b])`. Piecewise linear: the
65    /// value/tangent/adjoint route through whichever operand is smaller
66    /// (ties pick the first), and the second derivative is identically
67    /// zero. n-ary AMPL `min` (opcode o11) folds to a chain of these.
68    Min(usize, usize),
69    /// Pairwise maximum `max(vals[a], vals[b])` — the `Min` mirror;
70    /// n-ary AMPL `max` (opcode o12) folds to a chain of these.
71    Max(usize, usize),
72    /// Relational comparison `vals[a] OP vals[b]` → `1.0`/`0.0`.
73    /// Piecewise constant, so its derivative is identically zero — the
74    /// AD passes treat it as a constant w.r.t. its operands.
75    Cmp(CmpOp, usize, usize),
76    /// Logical AND: `1.0` iff both operands are nonzero. Zero derivative.
77    And(usize, usize),
78    /// Logical OR: `1.0` iff either operand is nonzero. Zero derivative.
79    Or(usize, usize),
80    /// Logical NOT: `1.0` iff the operand is zero. Zero derivative.
81    Not(usize),
82    /// `if-then-else`: operands `(cond, then, else)`. The value is
83    /// `vals[then]` when `vals[cond] != 0` else `vals[else]`, and the
84    /// value/tangent/adjoint all route through the active branch only.
85    /// The condition contributes no derivative (the branch switch is a
86    /// non-smooth event the AD ignores).
87    Select(usize, usize, usize),
88    /// AMPL imported (external) function call. The payload (library
89    /// handle, name, and argument list) is boxed so this rare variant
90    /// does not inflate `size_of::<TapeOp>()`: without the box the
91    /// `Arc`+`String`+`Vec` make every op ~64 bytes, which on a
92    /// summand-split objective with millions of tiny tapes (e.g.
93    /// `sensors`) costs gigabytes. Boxing drops the common arithmetic
94    /// ops back to the size of the next-largest variant.
95    Funcall(Box<FuncallData>),
96}
97
98/// Boxed payload of [`TapeOp::Funcall`]. The library is kept alive by
99/// the `Arc`; `name` is the registered function name; `args` carries
100/// positional arguments where real-valued args reference earlier tape
101/// slots and string args are inline literals.
102#[derive(Debug, Clone)]
103pub struct FuncallData {
104    pub lib: Arc<ExternalLibrary>,
105    pub name: String,
106    pub args: Vec<TapeFuncallArg>,
107}
108
/// A single positional argument of a `TapeOp::Funcall`.
///
/// `Tape` holds a tape-slot index whose value is taken from the running
/// `vals[]` buffer during evaluation; `Str` holds an owned string literal
/// (an AMPL `h<len>:<chars>` token).
#[derive(Clone, Debug)]
pub enum TapeFuncallArg {
    Tape(usize),
    Str(String),
}
117
118/// Evaluate a relational opcode on two scalar values, returning the
119/// boolean truth (callers map it to `1.0`/`0.0`).
120#[inline]
121fn cmp_holds(op: CmpOp, a: f64, b: f64) -> bool {
122    match op {
123        CmpOp::Lt => a < b,
124        CmpOp::Le => a <= b,
125        CmpOp::Eq => a == b,
126        CmpOp::Ge => a >= b,
127        CmpOp::Gt => a > b,
128        CmpOp::Ne => a != b,
129    }
130}
131
132fn funcall_to_ext_args<'a>(args: &'a [TapeFuncallArg], vals: &[f64]) -> Vec<ExternalArg<'a>> {
133    args.iter()
134        .map(|a| match a {
135            TapeFuncallArg::Tape(idx) => ExternalArg::Real(vals[*idx]),
136            TapeFuncallArg::Str(s) => ExternalArg::Str(s.as_str()),
137        })
138        .collect()
139}
140
141/// Evaluate an external (AMPL imported) function, poisoning the result with
142/// `NaN` instead of panicking when the library reports an error.
143///
144/// An external eval fails on user-controllable conditions — most commonly an
145/// out-of-domain property evaluation (e.g. an IDAES Helmholtz thermo call
146/// outside its valid pressure/temperature range). We mirror the tape's own
147/// arithmetic domain-error semantics (`log(-1) → NaN`): hand back NaN so the
148/// IPM sees a failed evaluation and the line search backs off, rather than
149/// raising an uncatchable panic across the pyo3 boundary on the `read_nl`
150/// surface. The NaN derivative/Hessian vectors are sized by the full argument
151/// count — an upper bound on the real-arg count a successful eval returns — so
152/// every downstream index into them stays in range.
153fn ext_eval_or_nan(
154    lib: &ExternalLibrary,
155    name: &str,
156    call_args: &[ExternalArg<'_>],
157    n_args: usize,
158    want_derivs: bool,
159    want_hes: bool,
160) -> EvalResult {
161    lib.eval(name, call_args, want_derivs, want_hes)
162        .unwrap_or_else(|_| EvalResult {
163            value: f64::NAN,
164            derivs: want_derivs.then(|| vec![f64::NAN; n_args]),
165            hessian: want_hes.then(|| vec![f64::NAN; n_args * (n_args + 1) / 2]),
166        })
167}
168
169/// A flattened expression tape. The result of evaluation is the value
170/// at slot `ops.len() - 1` (i.e. the last op).
171#[derive(Debug, Clone)]
172pub struct Tape {
173    pub ops: Vec<TapeOp>,
174}
175
176impl Tape {
177    /// Build a tape from an `Expr` tree (no AMPL external functions). CSE
178    /// bodies (`Expr::Cse(rc)`) are cached by `Arc` pointer identity so each
179    /// body is emitted once even when referenced many times.
180    pub fn build(expr: &Expr) -> Self {
181        Self::build_with_externals(expr, &ExternalResolver::default())
182    }
183
184    /// Build a tape from an `Expr` tree, resolving any `Expr::Funcall`
185    /// nodes through `resolver`. Panics if the expression references a
186    /// funcall id that is not in the resolver — `NlProblem::resolve_externals`
187    /// must populate the resolver before tape construction.
188    pub fn build_with_externals(expr: &Expr, resolver: &ExternalResolver) -> Self {
189        let mut ops = Vec::new();
190        let mut cache: HashMap<*const Expr, usize> = HashMap::new();
191        build_recursive(expr, &mut ops, &mut cache, resolver);
192        Tape { ops }
193    }
194
195    /// Forward sweep: returns `vals[i] = value of tape slot i`. The
196    /// scalar tape result is `vals[ops.len() - 1]`.
197    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
198        let mut vals: Vec<f64> = Vec::with_capacity(self.ops.len());
199        for op in &self.ops {
200            let v = match op {
201                TapeOp::Const(c) => *c,
202                TapeOp::Var(i) => x[*i],
203                TapeOp::Add(a, b) => vals[*a] + vals[*b],
204                TapeOp::Sub(a, b) => vals[*a] - vals[*b],
205                TapeOp::Mul(a, b) => vals[*a] * vals[*b],
206                TapeOp::Div(a, b) => vals[*a] / vals[*b],
207                TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
208                TapeOp::Neg(a) => -vals[*a],
209                TapeOp::Abs(a) => vals[*a].abs(),
210                TapeOp::Sqrt(a) => vals[*a].sqrt(),
211                TapeOp::Exp(a) => vals[*a].exp(),
212                TapeOp::Log(a) => vals[*a].ln(),
213                TapeOp::Log10(a) => vals[*a].log10(),
214                TapeOp::Sin(a) => vals[*a].sin(),
215                TapeOp::Cos(a) => vals[*a].cos(),
216                TapeOp::Tan(a) => vals[*a].tan(),
217                TapeOp::Atan(a) => vals[*a].atan(),
218                TapeOp::Acos(a) => vals[*a].acos(),
219                TapeOp::Sinh(a) => vals[*a].sinh(),
220                TapeOp::Cosh(a) => vals[*a].cosh(),
221                TapeOp::Tanh(a) => vals[*a].tanh(),
222                TapeOp::Asin(a) => vals[*a].asin(),
223                TapeOp::Acosh(a) => vals[*a].acosh(),
224                TapeOp::Asinh(a) => vals[*a].asinh(),
225                TapeOp::Atanh(a) => vals[*a].atanh(),
226                TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
227                TapeOp::Min(a, b) => vals[*a].min(vals[*b]),
228                TapeOp::Max(a, b) => vals[*a].max(vals[*b]),
229                TapeOp::Cmp(op, a, b) => f64::from(cmp_holds(*op, vals[*a], vals[*b])),
230                TapeOp::And(a, b) => f64::from(vals[*a] != 0.0 && vals[*b] != 0.0),
231                TapeOp::Or(a, b) => f64::from(vals[*a] != 0.0 || vals[*b] != 0.0),
232                TapeOp::Not(a) => f64::from(vals[*a] == 0.0),
233                TapeOp::Select(c, t, e) => {
234                    if vals[*c] != 0.0 {
235                        vals[*t]
236                    } else {
237                        vals[*e]
238                    }
239                }
240                TapeOp::Funcall(fc) => {
241                    let FuncallData { lib, name, args } = fc.as_ref();
242                    let call_args = funcall_to_ext_args(args, &vals);
243                    let res = ext_eval_or_nan(lib, name, &call_args, args.len(), false, false);
244                    res.value
245                }
246            };
247            vals.push(v);
248        }
249        vals
250    }
251
252    pub fn eval(&self, x: &[f64]) -> f64 {
253        let vals = self.forward(x);
254        *vals.last().unwrap_or(&0.0)
255    }
256
257    /// Reverse-mode AD: accumulate `seed * df/dx_i` into `grad[i]` for
258    /// every problem variable `i` referenced by the tape. `grad` is
259    /// **not** zeroed by this routine — the caller can chain multiple
260    /// gradient accumulations into the same buffer.
261    pub fn gradient_seed(&self, x: &[f64], seed: f64, grad: &mut [f64]) {
262        if seed == 0.0 || self.ops.is_empty() {
263            return;
264        }
265        let vals = self.forward(x);
266        self.reverse(&vals, seed, grad);
267    }
268
269    /// Reverse-mode AD reusing two caller-supplied scratch buffers
270    /// (`vals` from [`forward_into`], and an `adj` arena ≥
271    /// `self.ops.len()`) instead of allocating a forward-value vector and
272    /// an adjoint vector per call like [`gradient_seed`]. The `.nl` design
273    /// emits one tiny tape per summand — ~10⁶ on large models — so a single
274    /// `eval_jac_g` / `eval_grad_f` drives this millions of times and the
275    /// per-call allocation dominated. `grad` is accumulated into (not
276    /// zeroed); `adj` may be passed dirty (it is zeroed at the touched
277    /// slots internally).
278    ///
279    /// [`forward_into`]: Tape::forward_into
280    pub fn gradient_seed_into(
281        &self,
282        x: &[f64],
283        seed: f64,
284        grad: &mut [f64],
285        vals: &mut [f64],
286        adj: &mut [f64],
287    ) {
288        if seed == 0.0 || self.ops.is_empty() {
289            return;
290        }
291        debug_assert!(vals.len() >= self.ops.len());
292        self.forward_into(x, vals);
293        self.reverse_into(vals, seed, grad, adj);
294    }
295
296    fn reverse(&self, vals: &[f64], seed: f64, grad: &mut [f64]) {
297        let n = self.ops.len();
298        let mut adj = vec![0.0f64; n];
299        self.reverse_into(vals, seed, grad, &mut adj);
300    }
301
302    /// Reverse adjoint sweep into a caller-supplied `adj` scratch buffer
303    /// (length ≥ `self.ops.len()`), the allocation-free core of [`reverse`].
304    /// `adj` is zeroed over `[0, n)` internally, so a dirty arena is fine;
305    /// `grad` is accumulated into (not zeroed).
306    fn reverse_into(&self, vals: &[f64], seed: f64, grad: &mut [f64], adj: &mut [f64]) {
307        let n = self.ops.len();
308        debug_assert!(adj.len() >= n);
309        adj[..n].fill(0.0);
310        adj[n - 1] = seed;
311
312        for i in (0..n).rev() {
313            let a = adj[i];
314            if a == 0.0 {
315                continue;
316            }
317            match &self.ops[i] {
318                TapeOp::Const(_) => {}
319                TapeOp::Var(j) => {
320                    grad[*j] += a;
321                }
322                TapeOp::Add(l, r) => {
323                    adj[*l] += a;
324                    adj[*r] += a;
325                }
326                TapeOp::Sub(l, r) => {
327                    adj[*l] += a;
328                    adj[*r] -= a;
329                }
330                TapeOp::Mul(l, r) => {
331                    adj[*l] += a * vals[*r];
332                    adj[*r] += a * vals[*l];
333                }
334                TapeOp::Div(l, r) => {
335                    let rv = vals[*r];
336                    adj[*l] += a / rv;
337                    adj[*r] -= a * vals[*l] / (rv * rv);
338                }
339                TapeOp::Pow(l, r) => {
340                    let lv = vals[*l];
341                    let rv = vals[*r];
342                    if rv != 0.0 {
343                        adj[*l] += a * rv * lv.powf(rv - 1.0);
344                    }
345                    if lv > 0.0 {
346                        adj[*r] += a * vals[i] * lv.ln();
347                    }
348                }
349                TapeOp::Neg(j) => {
350                    adj[*j] -= a;
351                }
352                TapeOp::Abs(j) => {
353                    if vals[*j] >= 0.0 {
354                        adj[*j] += a;
355                    } else {
356                        adj[*j] -= a;
357                    }
358                }
359                TapeOp::Sqrt(j) => {
360                    let sv = vals[i];
361                    if sv > 0.0 {
362                        adj[*j] += a * 0.5 / sv;
363                    }
364                }
365                TapeOp::Exp(j) => {
366                    adj[*j] += a * vals[i];
367                }
368                TapeOp::Log(j) => {
369                    adj[*j] += a / vals[*j];
370                }
371                TapeOp::Log10(j) => {
372                    adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
373                }
374                TapeOp::Sin(j) => {
375                    adj[*j] += a * vals[*j].cos();
376                }
377                TapeOp::Cos(j) => {
378                    adj[*j] -= a * vals[*j].sin();
379                }
380                TapeOp::Tan(j) => {
381                    let t = vals[i];
382                    adj[*j] += a * (1.0 + t * t);
383                }
384                TapeOp::Atan(j) => {
385                    let u = vals[*j];
386                    adj[*j] += a / (1.0 + u * u);
387                }
388                TapeOp::Acos(j) => {
389                    let u = vals[*j];
390                    adj[*j] -= a / (1.0 - u * u).sqrt();
391                }
392                TapeOp::Sinh(j) => {
393                    adj[*j] += a * vals[*j].cosh();
394                }
395                TapeOp::Cosh(j) => {
396                    adj[*j] += a * vals[*j].sinh();
397                }
398                TapeOp::Tanh(j) => {
399                    let t = vals[i];
400                    adj[*j] += a * (1.0 - t * t);
401                }
402                TapeOp::Asin(j) => {
403                    let u = vals[*j];
404                    adj[*j] += a / (1.0 - u * u).sqrt();
405                }
406                TapeOp::Acosh(j) => {
407                    let u = vals[*j];
408                    adj[*j] += a / (u * u - 1.0).sqrt();
409                }
410                TapeOp::Asinh(j) => {
411                    let u = vals[*j];
412                    adj[*j] += a / (u * u + 1.0).sqrt();
413                }
414                TapeOp::Atanh(j) => {
415                    let u = vals[*j];
416                    adj[*j] += a / (1.0 - u * u);
417                }
418                TapeOp::Atan2(l, r) => {
419                    let y = vals[*l];
420                    let x = vals[*r];
421                    let d = y * y + x * x;
422                    adj[*l] += a * (x / d);
423                    adj[*r] += a * (-y / d);
424                }
425                // min/max are piecewise linear: the adjoint flows to the
426                // selected operand only (ties pick the first, a valid
427                // subgradient choice).
428                TapeOp::Min(l, r) => {
429                    if vals[*l] <= vals[*r] {
430                        adj[*l] += a;
431                    } else {
432                        adj[*r] += a;
433                    }
434                }
435                TapeOp::Max(l, r) => {
436                    if vals[*l] >= vals[*r] {
437                        adj[*l] += a;
438                    } else {
439                        adj[*r] += a;
440                    }
441                }
442                // Comparisons and logical connectives are piecewise
443                // constant: zero derivative, so no adjoint propagates.
444                TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {}
445                // if-then-else: the adjoint flows entirely into the
446                // active branch; the condition gets none.
447                TapeOp::Select(c, t, e) => {
448                    if vals[*c] != 0.0 {
449                        adj[*t] += a;
450                    } else {
451                        adj[*e] += a;
452                    }
453                }
454                TapeOp::Funcall(fc) => {
455                    let FuncallData { lib, name, args } = fc.as_ref();
456                    let call_args = funcall_to_ext_args(args, vals);
457                    let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, false);
458                    let derivs = res.derivs.expect("want_derivs=true returns derivs");
459                    let mut k = 0usize;
460                    for arg in args {
461                        if let TapeFuncallArg::Tape(idx) = arg {
462                            adj[*idx] += a * derivs[k];
463                            k += 1;
464                        }
465                    }
466                }
467            }
468        }
469    }
470
471    /// Sorted distinct problem-variable indices that the tape depends on.
472    pub fn variables(&self) -> Vec<usize> {
473        let mut s: BTreeSet<usize> = BTreeSet::new();
474        for op in &self.ops {
475            if let TapeOp::Var(j) = op {
476                s.insert(*j);
477            }
478        }
479        s.into_iter().collect()
480    }
481
482    /// Forward tangent sweep: `dot[i] = d(slot_i) / dx_{seed_var}`.
483    /// Caller-supplied `dot` buffer is overwritten in full; no zeroing
484    /// needed beforehand because every slot is written before it is
485    /// read (the loop walks forward and only reads earlier slots).
486    fn forward_tangent(&self, vals: &[f64], seed_var: usize, dot: &mut [f64]) {
487        let n = self.ops.len();
488        debug_assert_eq!(dot.len(), n);
489        for i in 0..n {
490            dot[i] = match &self.ops[i] {
491                TapeOp::Const(_) => 0.0,
492                TapeOp::Var(k) => {
493                    if *k == seed_var {
494                        1.0
495                    } else {
496                        0.0
497                    }
498                }
499                TapeOp::Add(a, b) => dot[*a] + dot[*b],
500                TapeOp::Sub(a, b) => dot[*a] - dot[*b],
501                TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
502                TapeOp::Div(a, b) => {
503                    let vb = vals[*b];
504                    (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
505                }
506                TapeOp::Pow(a, b) => {
507                    let u = vals[*a];
508                    let r = vals[*b];
509                    let du = dot[*a];
510                    let dr = dot[*b];
511                    let mut result = 0.0;
512                    // Match the reverse-mode gradient's guard (`rv != 0.0` only): at base
513                    // u == 0 the slope is still well defined for r >= 1 (and a
514                    // genuine ±inf for r < 1), so it must not be silently dropped,
515                    // or the forward tangent disagrees with the reverse gradient.
516                    if r != 0.0 {
517                        result += r * u.powf(r - 1.0) * du;
518                    }
519                    if u > 0.0 {
520                        result += vals[i] * u.ln() * dr;
521                    }
522                    result
523                }
524                TapeOp::Neg(a) => -dot[*a],
525                TapeOp::Abs(a) => {
526                    if vals[*a] >= 0.0 {
527                        dot[*a]
528                    } else {
529                        -dot[*a]
530                    }
531                }
532                TapeOp::Sqrt(a) => {
533                    let sv = vals[i];
534                    if sv > 0.0 {
535                        dot[*a] * 0.5 / sv
536                    } else {
537                        0.0
538                    }
539                }
540                TapeOp::Exp(a) => dot[*a] * vals[i],
541                TapeOp::Log(a) => dot[*a] / vals[*a],
542                TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
543                TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
544                TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
545                TapeOp::Tan(a) => {
546                    let t = vals[i];
547                    dot[*a] * (1.0 + t * t)
548                }
549                TapeOp::Atan(a) => {
550                    let u = vals[*a];
551                    dot[*a] / (1.0 + u * u)
552                }
553                TapeOp::Acos(a) => {
554                    let u = vals[*a];
555                    -dot[*a] / (1.0 - u * u).sqrt()
556                }
557                TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
558                TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
559                TapeOp::Tanh(a) => {
560                    let t = vals[i];
561                    dot[*a] * (1.0 - t * t)
562                }
563                TapeOp::Asin(a) => {
564                    let u = vals[*a];
565                    dot[*a] / (1.0 - u * u).sqrt()
566                }
567                TapeOp::Acosh(a) => {
568                    let u = vals[*a];
569                    dot[*a] / (u * u - 1.0).sqrt()
570                }
571                TapeOp::Asinh(a) => {
572                    let u = vals[*a];
573                    dot[*a] / (u * u + 1.0).sqrt()
574                }
575                TapeOp::Atanh(a) => {
576                    let u = vals[*a];
577                    dot[*a] / (1.0 - u * u)
578                }
579                TapeOp::Atan2(a, b) => {
580                    let y = vals[*a];
581                    let x = vals[*b];
582                    let d = y * y + x * x;
583                    (x * dot[*a] - y * dot[*b]) / d
584                }
585                // min/max: the tangent follows the selected operand.
586                TapeOp::Min(a, b) => {
587                    if vals[*a] <= vals[*b] {
588                        dot[*a]
589                    } else {
590                        dot[*b]
591                    }
592                }
593                TapeOp::Max(a, b) => {
594                    if vals[*a] >= vals[*b] {
595                        dot[*a]
596                    } else {
597                        dot[*b]
598                    }
599                }
600                TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => 0.0,
601                TapeOp::Select(c, t, e) => {
602                    if vals[*c] != 0.0 {
603                        dot[*t]
604                    } else {
605                        dot[*e]
606                    }
607                }
608                TapeOp::Funcall(fc) => {
609                    let FuncallData { lib, name, args } = fc.as_ref();
610                    let call_args = funcall_to_ext_args(args, vals);
611                    let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, false);
612                    let derivs = res.derivs.expect("want_derivs=true returns derivs");
613                    let mut acc = 0.0;
614                    let mut k = 0usize;
615                    for arg in args {
616                        if let TapeFuncallArg::Tape(idx) = arg {
617                            acc += derivs[k] * dot[*idx];
618                            k += 1;
619                        }
620                    }
621                    acc
622                }
623            };
624        }
625    }
626
627    /// Forward sweep into a caller-supplied buffer. Avoids the
628    /// per-call allocation of `forward()` so hot paths can reuse
629    /// one scratch arena across many tapes.
630    pub fn forward_into(&self, x: &[f64], vals: &mut [f64]) {
631        let n = self.ops.len();
632        debug_assert!(vals.len() >= n);
633        for i in 0..n {
634            vals[i] = match &self.ops[i] {
635                TapeOp::Const(c) => *c,
636                TapeOp::Var(j) => x[*j],
637                TapeOp::Add(a, b) => vals[*a] + vals[*b],
638                TapeOp::Sub(a, b) => vals[*a] - vals[*b],
639                TapeOp::Mul(a, b) => vals[*a] * vals[*b],
640                TapeOp::Div(a, b) => vals[*a] / vals[*b],
641                TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
642                TapeOp::Neg(a) => -vals[*a],
643                TapeOp::Abs(a) => vals[*a].abs(),
644                TapeOp::Sqrt(a) => vals[*a].sqrt(),
645                TapeOp::Exp(a) => vals[*a].exp(),
646                TapeOp::Log(a) => vals[*a].ln(),
647                TapeOp::Log10(a) => vals[*a].log10(),
648                TapeOp::Sin(a) => vals[*a].sin(),
649                TapeOp::Cos(a) => vals[*a].cos(),
650                TapeOp::Tan(a) => vals[*a].tan(),
651                TapeOp::Atan(a) => vals[*a].atan(),
652                TapeOp::Acos(a) => vals[*a].acos(),
653                TapeOp::Sinh(a) => vals[*a].sinh(),
654                TapeOp::Cosh(a) => vals[*a].cosh(),
655                TapeOp::Tanh(a) => vals[*a].tanh(),
656                TapeOp::Asin(a) => vals[*a].asin(),
657                TapeOp::Acosh(a) => vals[*a].acosh(),
658                TapeOp::Asinh(a) => vals[*a].asinh(),
659                TapeOp::Atanh(a) => vals[*a].atanh(),
660                TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
661                TapeOp::Min(a, b) => vals[*a].min(vals[*b]),
662                TapeOp::Max(a, b) => vals[*a].max(vals[*b]),
663                TapeOp::Cmp(op, a, b) => f64::from(cmp_holds(*op, vals[*a], vals[*b])),
664                TapeOp::And(a, b) => f64::from(vals[*a] != 0.0 && vals[*b] != 0.0),
665                TapeOp::Or(a, b) => f64::from(vals[*a] != 0.0 || vals[*b] != 0.0),
666                TapeOp::Not(a) => f64::from(vals[*a] == 0.0),
667                TapeOp::Select(c, t, e) => {
668                    if vals[*c] != 0.0 {
669                        vals[*t]
670                    } else {
671                        vals[*e]
672                    }
673                }
674                TapeOp::Funcall(fc) => {
675                    let FuncallData { lib, name, args } = fc.as_ref();
676                    let call_args = funcall_to_ext_args(args, &*vals);
677                    let res = ext_eval_or_nan(lib, name, &call_args, args.len(), false, false);
678                    res.value
679                }
680            };
681        }
682    }
683
684    /// Directional Hessian-vector product: emits
685    /// `weight * (∇²f · seed)[k]` into `out[k]` for every problem
686    /// variable `k` the tape references. Caller supplies the
687    /// forward-pass result `vals` (use [`forward_into`]) plus three
688    /// scratch buffers (`dot`, `adj`, `adj_dot`), each at least
689    /// `self.ops.len()` long. `out` must be at least one past the
690    /// largest variable index in the tape; the routine reads
691    /// `seed[k]` for each `Var(k)` and writes `out[k] += weight *
692    /// (Hess · seed)[k]`.
693    ///
694    /// This is one forward-over-reverse AD pass — O(n_ops) work —
695    /// regardless of how many variables the tape depends on, which
696    /// is what makes Hessian coloring efficient: a single
697    /// directional pass recovers a whole color group of columns.
698    ///
699    /// [`forward_into`]: Tape::forward_into
700    pub fn hessian_directional(
701        &self,
702        vals: &[f64],
703        seed: &[f64],
704        weight: f64,
705        out: &mut [f64],
706        dot: &mut [f64],
707        adj: &mut [f64],
708        adj_dot: &mut [f64],
709    ) {
710        let n = self.ops.len();
711        if n == 0 || weight == 0.0 {
712            return;
713        }
714        debug_assert!(vals.len() >= n);
715        debug_assert!(dot.len() >= n);
716        debug_assert!(adj.len() >= n);
717        debug_assert!(adj_dot.len() >= n);
718
719        // Forward tangent: dot[i] = (∂vals[i] / ∂x · seed). At
720        // Var(k) the seed entry feeds in; the rest of the chain
721        // rule matches `forward_tangent` exactly.
722        for i in 0..n {
723            dot[i] = match &self.ops[i] {
724                TapeOp::Const(_) => 0.0,
725                TapeOp::Var(k) => seed[*k],
726                TapeOp::Add(a, b) => dot[*a] + dot[*b],
727                TapeOp::Sub(a, b) => dot[*a] - dot[*b],
728                TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
729                TapeOp::Div(a, b) => {
730                    let vb = vals[*b];
731                    (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
732                }
733                TapeOp::Pow(a, b) => {
734                    let u = vals[*a];
735                    let r = vals[*b];
736                    let du = dot[*a];
737                    let dr = dot[*b];
738                    let mut result = 0.0;
739                    // Match the reverse-mode gradient's guard (`rv != 0.0` only): at base
740                    // u == 0 the slope is still well defined for r >= 1 (and a
741                    // genuine ±inf for r < 1), so it must not be silently dropped,
742                    // or the forward tangent disagrees with the reverse gradient.
743                    if r != 0.0 {
744                        result += r * u.powf(r - 1.0) * du;
745                    }
746                    if u > 0.0 {
747                        result += vals[i] * u.ln() * dr;
748                    }
749                    result
750                }
751                TapeOp::Neg(a) => -dot[*a],
752                TapeOp::Abs(a) => {
753                    if vals[*a] >= 0.0 {
754                        dot[*a]
755                    } else {
756                        -dot[*a]
757                    }
758                }
759                TapeOp::Sqrt(a) => {
760                    let sv = vals[i];
761                    if sv > 0.0 {
762                        dot[*a] * 0.5 / sv
763                    } else {
764                        0.0
765                    }
766                }
767                TapeOp::Exp(a) => vals[i] * dot[*a],
768                TapeOp::Log(a) => dot[*a] / vals[*a],
769                TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
770                TapeOp::Sin(a) => vals[*a].cos() * dot[*a],
771                TapeOp::Cos(a) => -vals[*a].sin() * dot[*a],
772                TapeOp::Tan(a) => {
773                    let t = vals[i];
774                    (1.0 + t * t) * dot[*a]
775                }
776                TapeOp::Atan(a) => {
777                    let u = vals[*a];
778                    dot[*a] / (1.0 + u * u)
779                }
780                TapeOp::Acos(a) => {
781                    let u = vals[*a];
782                    -dot[*a] / (1.0 - u * u).sqrt()
783                }
784                TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
785                TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
786                TapeOp::Tanh(a) => {
787                    let t = vals[i];
788                    (1.0 - t * t) * dot[*a]
789                }
790                TapeOp::Asin(a) => {
791                    let u = vals[*a];
792                    dot[*a] / (1.0 - u * u).sqrt()
793                }
794                TapeOp::Acosh(a) => {
795                    let u = vals[*a];
796                    dot[*a] / (u * u - 1.0).sqrt()
797                }
798                TapeOp::Asinh(a) => {
799                    let u = vals[*a];
800                    dot[*a] / (u * u + 1.0).sqrt()
801                }
802                TapeOp::Atanh(a) => {
803                    let u = vals[*a];
804                    dot[*a] / (1.0 - u * u)
805                }
806                TapeOp::Atan2(a, b) => {
807                    let y = vals[*a];
808                    let x = vals[*b];
809                    let d = y * y + x * x;
810                    (x * dot[*a] - y * dot[*b]) / d
811                }
812                // min/max: the tangent follows the selected operand.
813                TapeOp::Min(a, b) => {
814                    if vals[*a] <= vals[*b] {
815                        dot[*a]
816                    } else {
817                        dot[*b]
818                    }
819                }
820                TapeOp::Max(a, b) => {
821                    if vals[*a] >= vals[*b] {
822                        dot[*a]
823                    } else {
824                        dot[*b]
825                    }
826                }
827                TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => 0.0,
828                TapeOp::Select(c, t, e) => {
829                    if vals[*c] != 0.0 {
830                        dot[*t]
831                    } else {
832                        dot[*e]
833                    }
834                }
835                TapeOp::Funcall(fc) => {
836                    let FuncallData { lib, name, args } = fc.as_ref();
837                    let call_args = funcall_to_ext_args(args, vals);
838                    let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, false);
839                    let derivs = res.derivs.expect("want_derivs=true returns derivs");
840                    let mut acc = 0.0;
841                    let mut k = 0usize;
842                    for arg in args {
843                        if let TapeFuncallArg::Tape(idx) = arg {
844                            acc += derivs[k] * dot[*idx];
845                            k += 1;
846                        }
847                    }
848                    acc
849                }
850            };
851        }
852
853        // Reverse over tangent. adj[i] = ∂f/∂vals[i],
854        // adj_dot[i] = derivative of adj[i] along `seed`
855        // direction = (Hess · seed) projected onto slot i.
856        for slot in adj.iter_mut().take(n) {
857            *slot = 0.0;
858        }
859        for slot in adj_dot.iter_mut().take(n) {
860            *slot = 0.0;
861        }
862        adj[n - 1] = 1.0;
863
864        for i in (0..n).rev() {
865            let w = adj[i];
866            let wd = adj_dot[i];
867            if w == 0.0 && wd == 0.0 {
868                continue;
869            }
870            match &self.ops[i] {
871                TapeOp::Const(_) => {}
872                TapeOp::Var(k) => {
873                    if wd != 0.0 {
874                        out[*k] += weight * wd;
875                    }
876                }
877                TapeOp::Add(a, b) => {
878                    adj[*a] += w;
879                    adj[*b] += w;
880                    adj_dot[*a] += wd;
881                    adj_dot[*b] += wd;
882                }
883                TapeOp::Sub(a, b) => {
884                    adj[*a] += w;
885                    adj[*b] -= w;
886                    adj_dot[*a] += wd;
887                    adj_dot[*b] -= wd;
888                }
889                TapeOp::Mul(a, b) => {
890                    adj[*a] += w * vals[*b];
891                    adj[*b] += w * vals[*a];
892                    adj_dot[*a] += wd * vals[*b] + w * dot[*b];
893                    adj_dot[*b] += wd * vals[*a] + w * dot[*a];
894                }
895                TapeOp::Div(a, b) => {
896                    let vb = vals[*b];
897                    let vb2 = vb * vb;
898                    let vb3 = vb2 * vb;
899                    adj[*a] += w / vb;
900                    adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
901                    adj[*b] += w * (-vals[*a] / vb2);
902                    adj_dot[*b] += wd * (-vals[*a] / vb2)
903                        + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
904                }
905                TapeOp::Pow(a, b) => {
906                    let u = vals[*a];
907                    let r = vals[*b];
908                    let du = dot[*a];
909                    let dr = dot[*b];
910                    if r != 0.0 {
911                        if u != 0.0 {
912                            let p_a = r * u.powf(r - 1.0);
913                            adj[*a] += w * p_a;
914                            let mut dp_a = dr * u.powf(r - 1.0);
915                            if u > 0.0 {
916                                dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
917                            } else {
918                                dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
919                            }
920                            adj_dot[*a] += wd * p_a + w * dp_a;
921                        } else if r >= 2.0 {
922                            let p_a = 0.0;
923                            adj[*a] += w * p_a;
924                            let dp_a = if r == 2.0 {
925                                2.0 * du
926                            } else {
927                                r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
928                            };
929                            adj_dot[*a] += wd * p_a + w * dp_a;
930                        }
931                    }
932                    if u > 0.0 {
933                        let ln_u = u.ln();
934                        let p_b = vals[i] * ln_u;
935                        adj[*b] += w * p_b;
936                        let dur = vals[i] * (r * du / u + dr * ln_u);
937                        let dp_b = dur * ln_u + vals[i] * du / u;
938                        adj_dot[*b] += wd * p_b + w * dp_b;
939                    }
940                }
941                TapeOp::Neg(a) => {
942                    adj[*a] -= w;
943                    adj_dot[*a] -= wd;
944                }
945                TapeOp::Abs(a) => {
946                    let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
947                    adj[*a] += w * s;
948                    adj_dot[*a] += wd * s;
949                }
950                TapeOp::Sqrt(a) => {
951                    let sv = vals[i];
952                    if sv > 0.0 {
953                        let fp = 0.5 / sv;
954                        let fpp = -0.25 / (vals[*a] * sv);
955                        adj[*a] += w * fp;
956                        adj_dot[*a] += wd * fp + w * fpp * dot[*a];
957                    }
958                }
959                TapeOp::Exp(a) => {
960                    let ev = vals[i];
961                    adj[*a] += w * ev;
962                    adj_dot[*a] += wd * ev + w * ev * dot[*a];
963                }
964                TapeOp::Log(a) => {
965                    let u = vals[*a];
966                    adj[*a] += w / u;
967                    adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
968                }
969                TapeOp::Log10(a) => {
970                    let u = vals[*a];
971                    let c = std::f64::consts::LN_10;
972                    adj[*a] += w / (u * c);
973                    adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
974                }
975                TapeOp::Sin(a) => {
976                    let u = vals[*a];
977                    let cu = u.cos();
978                    adj[*a] += w * cu;
979                    adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
980                }
981                TapeOp::Cos(a) => {
982                    let u = vals[*a];
983                    let su = u.sin();
984                    adj[*a] -= w * su;
985                    adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
986                }
987                TapeOp::Tan(a) => {
988                    let t = vals[i];
989                    let gp = 1.0 + t * t;
990                    let gpp = 2.0 * t * gp;
991                    adj[*a] += w * gp;
992                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
993                }
994                TapeOp::Atan(a) => {
995                    let u = vals[*a];
996                    let d = 1.0 + u * u;
997                    let gp = 1.0 / d;
998                    let gpp = -2.0 * u / (d * d);
999                    adj[*a] += w * gp;
1000                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1001                }
1002                TapeOp::Acos(a) => {
1003                    let u = vals[*a];
1004                    let s = 1.0 - u * u;
1005                    let r = s.sqrt();
1006                    let gp = -1.0 / r;
1007                    let gpp = -u / (s * r);
1008                    adj[*a] += w * gp;
1009                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1010                }
1011                TapeOp::Sinh(a) => {
1012                    let u = vals[*a];
1013                    let gp = u.cosh();
1014                    let gpp = u.sinh();
1015                    adj[*a] += w * gp;
1016                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1017                }
1018                TapeOp::Cosh(a) => {
1019                    let u = vals[*a];
1020                    let gp = u.sinh();
1021                    let gpp = u.cosh();
1022                    adj[*a] += w * gp;
1023                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1024                }
1025                TapeOp::Tanh(a) => {
1026                    let t = vals[i];
1027                    let gp = 1.0 - t * t;
1028                    let gpp = -2.0 * t * gp;
1029                    adj[*a] += w * gp;
1030                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1031                }
1032                TapeOp::Asin(a) => {
1033                    let u = vals[*a];
1034                    let s = 1.0 - u * u;
1035                    let r = s.sqrt();
1036                    let gp = 1.0 / r;
1037                    let gpp = u / (s * r);
1038                    adj[*a] += w * gp;
1039                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1040                }
1041                TapeOp::Acosh(a) => {
1042                    let u = vals[*a];
1043                    let s = u * u - 1.0;
1044                    let r = s.sqrt();
1045                    let gp = 1.0 / r;
1046                    let gpp = -u / (s * r);
1047                    adj[*a] += w * gp;
1048                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1049                }
1050                TapeOp::Asinh(a) => {
1051                    let u = vals[*a];
1052                    let s = u * u + 1.0;
1053                    let r = s.sqrt();
1054                    let gp = 1.0 / r;
1055                    let gpp = -u / (s * r);
1056                    adj[*a] += w * gp;
1057                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1058                }
1059                TapeOp::Atanh(a) => {
1060                    let u = vals[*a];
1061                    let d = 1.0 - u * u;
1062                    let gp = 1.0 / d;
1063                    let gpp = 2.0 * u / (d * d);
1064                    adj[*a] += w * gp;
1065                    adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1066                }
1067                TapeOp::Atan2(a, b) => {
1068                    let y = vals[*a];
1069                    let x = vals[*b];
1070                    let d = y * y + x * x;
1071                    let d2 = d * d;
1072                    let fa = x / d;
1073                    let fb = -y / d;
1074                    let faa = -2.0 * y * x / d2;
1075                    let fab = (y * y - x * x) / d2;
1076                    let fbb = 2.0 * y * x / d2;
1077                    adj[*a] += w * fa;
1078                    adj[*b] += w * fb;
1079                    adj_dot[*a] += wd * fa + w * (faa * dot[*a] + fab * dot[*b]);
1080                    adj_dot[*b] += wd * fb + w * (fab * dot[*a] + fbb * dot[*b]);
1081                }
1082                // min/max are piecewise linear (zero second derivative):
1083                // route the adjoint and its tangent into the selected
1084                // operand, exactly like the active branch of a Select.
1085                TapeOp::Min(a, b) => {
1086                    let br = if vals[*a] <= vals[*b] { *a } else { *b };
1087                    adj[br] += w;
1088                    adj_dot[br] += wd;
1089                }
1090                TapeOp::Max(a, b) => {
1091                    let br = if vals[*a] >= vals[*b] { *a } else { *b };
1092                    adj[br] += w;
1093                    adj_dot[br] += wd;
1094                }
1095                // Zero derivative: no first- or second-order adjoint.
1096                TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {}
1097                // Route both the adjoint and its tangent into the
1098                // active branch; the condition contributes nothing.
1099                TapeOp::Select(c, t, e) => {
1100                    let br = if vals[*c] != 0.0 { *t } else { *e };
1101                    adj[br] += w;
1102                    adj_dot[br] += wd;
1103                }
1104                TapeOp::Funcall(fc) => {
1105                    let FuncallData { lib, name, args } = fc.as_ref();
1106                    let call_args = funcall_to_ext_args(args, vals);
1107                    let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, true);
1108                    let derivs = res.derivs.expect("want_derivs=true returns derivs");
1109                    let hes = res.hessian.expect("want_hes=true returns hessian");
1110                    let real_tape: Vec<usize> = args
1111                        .iter()
1112                        .filter_map(|a| match a {
1113                            TapeFuncallArg::Tape(t) => Some(*t),
1114                            TapeFuncallArg::Str(_) => None,
1115                        })
1116                        .collect();
1117                    for (k, &tk) in real_tape.iter().enumerate() {
1118                        adj[tk] += w * derivs[k];
1119                        let mut second_term = 0.0;
1120                        for (l, &tl) in real_tape.iter().enumerate() {
1121                            let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
1122                            let h_kl = hes[lo + hi * (hi + 1) / 2];
1123                            second_term += h_kl * dot[tl];
1124                        }
1125                        adj_dot[tk] += wd * derivs[k] + w * second_term;
1126                    }
1127                }
1128            }
1129        }
1130    }
1131
1132    /// Forward-over-reverse Hessian: for each variable `j` the tape
1133    /// depends on, accumulate `weight * (d²f / dx_i dx_j)` into
1134    /// `values[hess_map[(i, j)]]` for every `(i, j)` lower-triangle
1135    /// pair in the map. The same routine is used for the objective
1136    /// (with `weight = obj_factor`) and each active constraint (with
1137    /// `weight = lambda[k]`); contributions sum into the shared map.
1138    pub fn hessian_accumulate(
1139        &self,
1140        x: &[f64],
1141        weight: f64,
1142        hess_map: &HashMap<(usize, usize), usize>,
1143        values: &mut [f64],
1144    ) {
1145        let n = self.ops.len();
1146        if n == 0 || weight == 0.0 {
1147            return;
1148        }
1149        let v = self.forward(x);
1150        let var_indices = self.variables();
1151
1152        // Hoist scratch allocations out of the per-variable loop —
1153        // each was costing O(n) per j on every hessian_accumulate
1154        // call, which dominated runtime on large tapes (the dense-
1155        // Hessian Mittelmann problems). `forward_tangent` fully
1156        // overwrites `dot`, so no reset is needed there. `adj` and
1157        // `adj_dot` are mutated additively, so we zero them per j.
1158        let mut dot = vec![0.0f64; n];
1159        let mut adj = vec![0.0f64; n];
1160        let mut adj_dot = vec![0.0f64; n];
1161        for &j in &var_indices {
1162            self.forward_tangent(&v, j, &mut dot);
1163
1164            // adj[i] = standard adjoint (∂f/∂slot_i)
1165            // adj_dot[i] = derivative of adj[i] w.r.t. x_j = ∂²f/(∂slot_i ∂x_j)
1166            adj.fill(0.0);
1167            adj_dot.fill(0.0);
1168            adj[n - 1] = 1.0;
1169
1170            for i in (0..n).rev() {
1171                let w = adj[i];
1172                let wd = adj_dot[i];
1173                if w == 0.0 && wd == 0.0 {
1174                    continue;
1175                }
1176                match &self.ops[i] {
1177                    TapeOp::Const(_) => {}
1178                    TapeOp::Var(k) => {
1179                        // Lower-triangle: only emit when row k >= col j
1180                        // so an off-diagonal pair appears once.
1181                        if wd != 0.0 && *k >= j {
1182                            if let Some(&pos) = hess_map.get(&(*k, j)) {
1183                                values[pos] += weight * wd;
1184                            }
1185                        }
1186                    }
1187                    TapeOp::Add(a, b) => {
1188                        adj[*a] += w;
1189                        adj[*b] += w;
1190                        adj_dot[*a] += wd;
1191                        adj_dot[*b] += wd;
1192                    }
1193                    TapeOp::Sub(a, b) => {
1194                        adj[*a] += w;
1195                        adj[*b] -= w;
1196                        adj_dot[*a] += wd;
1197                        adj_dot[*b] -= wd;
1198                    }
1199                    TapeOp::Mul(a, b) => {
1200                        adj[*a] += w * v[*b];
1201                        adj[*b] += w * v[*a];
1202                        adj_dot[*a] += wd * v[*b] + w * dot[*b];
1203                        adj_dot[*b] += wd * v[*a] + w * dot[*a];
1204                    }
1205                    TapeOp::Div(a, b) => {
1206                        let vb = v[*b];
1207                        let vb2 = vb * vb;
1208                        let vb3 = vb2 * vb;
1209                        adj[*a] += w / vb;
1210                        adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
1211                        adj[*b] += w * (-v[*a] / vb2);
1212                        adj_dot[*b] += wd * (-v[*a] / vb2)
1213                            + w * (-dot[*a] / vb2 + 2.0 * v[*a] * dot[*b] / vb3);
1214                    }
1215                    TapeOp::Pow(a, b) => {
1216                        let u = v[*a];
1217                        let r = v[*b];
1218                        let du = dot[*a];
1219                        let dr = dot[*b];
1220                        if r != 0.0 {
1221                            if u != 0.0 {
1222                                let p_a = r * u.powf(r - 1.0);
1223                                adj[*a] += w * p_a;
1224                                let mut dp_a = dr * u.powf(r - 1.0);
1225                                if u > 0.0 {
1226                                    dp_a +=
1227                                        r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
1228                                } else {
1229                                    dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
1230                                }
1231                                adj_dot[*a] += wd * p_a + w * dp_a;
1232                            } else if r >= 2.0 {
1233                                let p_a = 0.0;
1234                                adj[*a] += w * p_a;
1235                                let dp_a = if r == 2.0 {
1236                                    2.0 * du
1237                                } else {
1238                                    r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
1239                                };
1240                                adj_dot[*a] += wd * p_a + w * dp_a;
1241                            }
1242                        }
1243                        if u > 0.0 {
1244                            let ln_u = u.ln();
1245                            let p_b = v[i] * ln_u;
1246                            adj[*b] += w * p_b;
1247                            let dur = v[i] * (r * du / u + dr * ln_u);
1248                            let dp_b = dur * ln_u + v[i] * du / u;
1249                            adj_dot[*b] += wd * p_b + w * dp_b;
1250                        }
1251                    }
1252                    TapeOp::Neg(a) => {
1253                        adj[*a] -= w;
1254                        adj_dot[*a] -= wd;
1255                    }
1256                    TapeOp::Abs(a) => {
1257                        let s = if v[*a] >= 0.0 { 1.0 } else { -1.0 };
1258                        adj[*a] += w * s;
1259                        adj_dot[*a] += wd * s;
1260                    }
1261                    TapeOp::Sqrt(a) => {
1262                        let sv = v[i];
1263                        if sv > 0.0 {
1264                            let fp = 0.5 / sv;
1265                            let fpp = -0.25 / (v[*a] * sv);
1266                            adj[*a] += w * fp;
1267                            adj_dot[*a] += wd * fp + w * fpp * dot[*a];
1268                        }
1269                    }
1270                    TapeOp::Exp(a) => {
1271                        let ev = v[i];
1272                        adj[*a] += w * ev;
1273                        adj_dot[*a] += wd * ev + w * ev * dot[*a];
1274                    }
1275                    TapeOp::Log(a) => {
1276                        let u = v[*a];
1277                        adj[*a] += w / u;
1278                        adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
1279                    }
1280                    TapeOp::Log10(a) => {
1281                        let u = v[*a];
1282                        let c = std::f64::consts::LN_10;
1283                        adj[*a] += w / (u * c);
1284                        adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
1285                    }
1286                    TapeOp::Sin(a) => {
1287                        let u = v[*a];
1288                        let cu = u.cos();
1289                        adj[*a] += w * cu;
1290                        adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
1291                    }
1292                    TapeOp::Cos(a) => {
1293                        let u = v[*a];
1294                        let su = u.sin();
1295                        adj[*a] -= w * su;
1296                        adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
1297                    }
1298                    TapeOp::Tan(a) => {
1299                        let t = v[i];
1300                        let gp = 1.0 + t * t;
1301                        let gpp = 2.0 * t * gp;
1302                        adj[*a] += w * gp;
1303                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1304                    }
1305                    TapeOp::Atan(a) => {
1306                        let u = v[*a];
1307                        let d = 1.0 + u * u;
1308                        let gp = 1.0 / d;
1309                        let gpp = -2.0 * u / (d * d);
1310                        adj[*a] += w * gp;
1311                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1312                    }
1313                    TapeOp::Acos(a) => {
1314                        let u = v[*a];
1315                        let s = 1.0 - u * u;
1316                        let r = s.sqrt();
1317                        let gp = -1.0 / r;
1318                        let gpp = -u / (s * r);
1319                        adj[*a] += w * gp;
1320                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1321                    }
1322                    TapeOp::Sinh(a) => {
1323                        let u = v[*a];
1324                        let gp = u.cosh();
1325                        let gpp = u.sinh();
1326                        adj[*a] += w * gp;
1327                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1328                    }
1329                    TapeOp::Cosh(a) => {
1330                        let u = v[*a];
1331                        let gp = u.sinh();
1332                        let gpp = u.cosh();
1333                        adj[*a] += w * gp;
1334                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1335                    }
1336                    TapeOp::Tanh(a) => {
1337                        let t = v[i];
1338                        let gp = 1.0 - t * t;
1339                        let gpp = -2.0 * t * gp;
1340                        adj[*a] += w * gp;
1341                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1342                    }
1343                    TapeOp::Asin(a) => {
1344                        let u = v[*a];
1345                        let s = 1.0 - u * u;
1346                        let r = s.sqrt();
1347                        let gp = 1.0 / r;
1348                        let gpp = u / (s * r);
1349                        adj[*a] += w * gp;
1350                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1351                    }
1352                    TapeOp::Acosh(a) => {
1353                        let u = v[*a];
1354                        let s = u * u - 1.0;
1355                        let r = s.sqrt();
1356                        let gp = 1.0 / r;
1357                        let gpp = -u / (s * r);
1358                        adj[*a] += w * gp;
1359                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1360                    }
1361                    TapeOp::Asinh(a) => {
1362                        let u = v[*a];
1363                        let s = u * u + 1.0;
1364                        let r = s.sqrt();
1365                        let gp = 1.0 / r;
1366                        let gpp = -u / (s * r);
1367                        adj[*a] += w * gp;
1368                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1369                    }
1370                    TapeOp::Atanh(a) => {
1371                        let u = v[*a];
1372                        let d = 1.0 - u * u;
1373                        let gp = 1.0 / d;
1374                        let gpp = 2.0 * u / (d * d);
1375                        adj[*a] += w * gp;
1376                        adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1377                    }
1378                    TapeOp::Atan2(a, b) => {
1379                        let y = v[*a];
1380                        let x = v[*b];
1381                        let d = y * y + x * x;
1382                        let d2 = d * d;
1383                        let fa = x / d;
1384                        let fb = -y / d;
1385                        let faa = -2.0 * y * x / d2;
1386                        let fab = (y * y - x * x) / d2;
1387                        let fbb = 2.0 * y * x / d2;
1388                        adj[*a] += w * fa;
1389                        adj[*b] += w * fb;
1390                        adj_dot[*a] += wd * fa + w * (faa * dot[*a] + fab * dot[*b]);
1391                        adj_dot[*b] += wd * fb + w * (fab * dot[*a] + fbb * dot[*b]);
1392                    }
1393                    // min/max are piecewise linear (zero second
1394                    // derivative): route adjoint and its tangent into
1395                    // the selected operand, like an active Select branch.
1396                    TapeOp::Min(a, b) => {
1397                        let br = if v[*a] <= v[*b] { *a } else { *b };
1398                        adj[br] += w;
1399                        adj_dot[br] += wd;
1400                    }
1401                    TapeOp::Max(a, b) => {
1402                        let br = if v[*a] >= v[*b] { *a } else { *b };
1403                        adj[br] += w;
1404                        adj_dot[br] += wd;
1405                    }
1406                    // Zero derivative: no first- or second-order adjoint.
1407                    TapeOp::Cmp(_, _, _)
1408                    | TapeOp::And(_, _)
1409                    | TapeOp::Or(_, _)
1410                    | TapeOp::Not(_) => {}
1411                    // Route adjoint and its tangent into the active
1412                    // branch only; the condition contributes nothing.
1413                    TapeOp::Select(c, t, e) => {
1414                        let br = if v[*c] != 0.0 { *t } else { *e };
1415                        adj[br] += w;
1416                        adj_dot[br] += wd;
1417                    }
1418                    TapeOp::Funcall(fc) => {
1419                        let FuncallData { lib, name, args } = fc.as_ref();
1420                        let call_args = funcall_to_ext_args(args, &v);
1421                        let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, true);
1422                        let derivs = res.derivs.expect("want_derivs=true returns derivs");
1423                        let hes = res.hessian.expect("want_hes=true returns hessian");
1424                        let real_tape: Vec<usize> = args
1425                            .iter()
1426                            .filter_map(|a| match a {
1427                                TapeFuncallArg::Tape(t) => Some(*t),
1428                                TapeFuncallArg::Str(_) => None,
1429                            })
1430                            .collect();
1431                        for (k, &tk) in real_tape.iter().enumerate() {
1432                            adj[tk] += w * derivs[k];
1433                            let mut second_term = 0.0;
1434                            for (l, &tl) in real_tape.iter().enumerate() {
1435                                let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
1436                                let h_kl = hes[lo + hi * (hi + 1) / 2];
1437                                second_term += h_kl * dot[tl];
1438                            }
1439                            adj_dot[tk] += wd * derivs[k] + w * second_term;
1440                        }
1441                    }
1442                }
1443            }
1444        }
1445    }
1446
1447    /// Structural Hessian sparsity (lower triangle, row >= col).
1448    /// Propagates per-slot variable-dependence sets forward; each
1449    /// nonlinear op emits the cross/self products of its operand sets.
1450    /// Linear ops contribute no second-derivative pairs.
1451    pub fn hessian_sparsity(&self) -> BTreeSet<(usize, usize)> {
1452        let n = self.ops.len();
1453        let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
1454        let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
1455
1456        let emit_cross =
1457            |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1458                for &v1 in s1 {
1459                    for &v2 in s2 {
1460                        let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
1461                        pairs.insert((r, c));
1462                    }
1463                }
1464            };
1465        let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1466            let vars: Vec<usize> = s.iter().copied().collect();
1467            for (ai, &vi) in vars.iter().enumerate() {
1468                for &vj in &vars[..=ai] {
1469                    let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
1470                    pairs.insert((r, c));
1471                }
1472            }
1473        };
1474
1475        for op in &self.ops {
1476            let vset = match op {
1477                TapeOp::Const(_) => BTreeSet::new(),
1478                TapeOp::Var(j) => {
1479                    let mut s = BTreeSet::new();
1480                    s.insert(*j);
1481                    s
1482                }
1483                TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
1484                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1485                }
1486                TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
1487                TapeOp::Mul(a, b) => {
1488                    emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
1489                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1490                }
1491                TapeOp::Div(a, b) => {
1492                    emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
1493                    emit_self(&var_sets[*b], &mut pairs);
1494                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1495                }
1496                TapeOp::Pow(a, b) => {
1497                    let combined: BTreeSet<usize> =
1498                        var_sets[*a].union(&var_sets[*b]).copied().collect();
1499                    emit_self(&combined, &mut pairs);
1500                    combined
1501                }
1502                TapeOp::Sqrt(a)
1503                | TapeOp::Exp(a)
1504                | TapeOp::Log(a)
1505                | TapeOp::Log10(a)
1506                | TapeOp::Sin(a)
1507                | TapeOp::Cos(a)
1508                | TapeOp::Tan(a)
1509                | TapeOp::Atan(a)
1510                | TapeOp::Acos(a)
1511                | TapeOp::Sinh(a)
1512                | TapeOp::Cosh(a)
1513                | TapeOp::Tanh(a)
1514                | TapeOp::Asin(a)
1515                | TapeOp::Acosh(a)
1516                | TapeOp::Asinh(a)
1517                | TapeOp::Atanh(a) => {
1518                    emit_self(&var_sets[*a], &mut pairs);
1519                    var_sets[*a].clone()
1520                }
1521                // atan2(y, x) is nonlinear in both operands with a full
1522                // 2×2 second-derivative block; the structural superset is
1523                // every self/cross pair within the combined operand set.
1524                TapeOp::Atan2(a, b) => {
1525                    let combined: BTreeSet<usize> =
1526                        var_sets[*a].union(&var_sets[*b]).copied().collect();
1527                    emit_self(&combined, &mut pairs);
1528                    combined
1529                }
1530                // Comparisons / logical connectives are piecewise
1531                // constant: identically-zero derivative, so they
1532                // introduce no second-derivative pairs and carry no
1533                // variable dependence downstream (their result is a
1534                // constant as far as AD is concerned).
1535                TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {
1536                    BTreeSet::new()
1537                }
1538                // Select passes through the active branch's value with
1539                // unit derivative, so it emits no pairs of its own; its
1540                // dependence set is the union of *both* branches
1541                // (either may become active as x varies — conservative
1542                // and correct for a structural superset). The condition
1543                // contributes no derivative and is excluded.
1544                TapeOp::Select(_c, t, e) => var_sets[*t].union(&var_sets[*e]).copied().collect(),
1545                // min/max are piecewise linear: the active operand passes
1546                // through with unit derivative, so the second derivative is
1547                // identically zero (no pairs). Their dependence set is the
1548                // union of both operands (either may become active as x
1549                // varies — conservative and correct for a structural
1550                // superset), mirroring Select.
1551                TapeOp::Min(a, b) | TapeOp::Max(a, b) => {
1552                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1553                }
1554                TapeOp::Funcall(fc) => {
1555                    let args = &fc.args;
1556                    let mut combined: BTreeSet<usize> = BTreeSet::new();
1557                    for arg in args {
1558                        if let TapeFuncallArg::Tape(t) = arg {
1559                            for &vv in &var_sets[*t] {
1560                                combined.insert(vv);
1561                            }
1562                        }
1563                    }
1564                    emit_self(&combined, &mut pairs);
1565                    combined
1566                }
1567            };
1568            var_sets.push(vset);
1569        }
1570        pairs
1571    }
1572}
1573
1574fn build_recursive(
1575    expr: &Expr,
1576    ops: &mut Vec<TapeOp>,
1577    cache: &mut HashMap<*const Expr, usize>,
1578    resolver: &ExternalResolver,
1579) -> usize {
1580    match expr {
1581        Expr::Const(c) => {
1582            let idx = ops.len();
1583            ops.push(TapeOp::Const(*c));
1584            idx
1585        }
1586        Expr::Var(i) => {
1587            let idx = ops.len();
1588            ops.push(TapeOp::Var(*i));
1589            idx
1590        }
1591        Expr::Binary(op, a, b) => {
1592            // Pow(x, const) is the dominant libm/dispatch cost in
1593            // transcendental-heavy AMPL tapes (henon, lane_emden, …):
1594            // `powf` itself is ~30–50 cycles AND the reverse-mode arm
1595            // for `Pow` carries an extra `ln(x)` branch. Rewriting
1596            // small integer / half-integer exponents into mul/sqrt
1597            // chains drops these calls entirely and reroutes the AD
1598            // through the much cheaper `Mul`/`Sqrt` arms.
1599            if let BinOp::Pow = op {
1600                if let Some(c) = peek_const(b) {
1601                    if let Some(idx) = try_emit_const_pow(a, c, ops, cache, resolver) {
1602                        return idx;
1603                    }
1604                }
1605            }
1606            let l = build_recursive(a, ops, cache, resolver);
1607            let r = build_recursive(b, ops, cache, resolver);
1608            let idx = ops.len();
1609            ops.push(match op {
1610                BinOp::Add => TapeOp::Add(l, r),
1611                BinOp::Sub => TapeOp::Sub(l, r),
1612                BinOp::Mul => TapeOp::Mul(l, r),
1613                BinOp::Div => TapeOp::Div(l, r),
1614                BinOp::Pow => TapeOp::Pow(l, r),
1615                BinOp::Atan2 => TapeOp::Atan2(l, r),
1616            });
1617            idx
1618        }
1619        Expr::Unary(op, a) => {
1620            let v = build_recursive(a, ops, cache, resolver);
1621            let idx = ops.len();
1622            ops.push(match op {
1623                UnaryOp::Neg => TapeOp::Neg(v),
1624                UnaryOp::Sqrt => TapeOp::Sqrt(v),
1625                UnaryOp::Log => TapeOp::Log(v),
1626                UnaryOp::Log10 => TapeOp::Log10(v),
1627                UnaryOp::Exp => TapeOp::Exp(v),
1628                UnaryOp::Abs => TapeOp::Abs(v),
1629                UnaryOp::Sin => TapeOp::Sin(v),
1630                UnaryOp::Cos => TapeOp::Cos(v),
1631                UnaryOp::Tan => TapeOp::Tan(v),
1632                UnaryOp::Atan => TapeOp::Atan(v),
1633                UnaryOp::Acos => TapeOp::Acos(v),
1634                UnaryOp::Sinh => TapeOp::Sinh(v),
1635                UnaryOp::Cosh => TapeOp::Cosh(v),
1636                UnaryOp::Tanh => TapeOp::Tanh(v),
1637                UnaryOp::Asin => TapeOp::Asin(v),
1638                UnaryOp::Acosh => TapeOp::Acosh(v),
1639                UnaryOp::Asinh => TapeOp::Asinh(v),
1640                UnaryOp::Atanh => TapeOp::Atanh(v),
1641            });
1642            idx
1643        }
1644        Expr::Sum(args) => {
1645            if args.is_empty() {
1646                let idx = ops.len();
1647                ops.push(TapeOp::Const(0.0));
1648                return idx;
1649            }
1650            let mut acc = build_recursive(&args[0], ops, cache, resolver);
1651            for a in &args[1..] {
1652                let next = build_recursive(a, ops, cache, resolver);
1653                let idx = ops.len();
1654                ops.push(TapeOp::Add(acc, next));
1655                acc = idx;
1656            }
1657            acc
1658        }
1659        // n-ary min/max fold to a left-associative chain of binary
1660        // Min/Max TapeOps. The chain reproduces the list extremum, and
1661        // the binary Min/Max AD arms route the (sub)gradient to the
1662        // active operand at each step — equivalent to selecting the one
1663        // active operand of the whole list. An empty list cannot arise
1664        // from a well-formed `.nl` MINLIST/MAXLIST (count >= 1); guard
1665        // with a 0 constant for safety rather than panicking.
1666        Expr::MinList(args) | Expr::MaxList(args) => {
1667            let is_min = matches!(expr, Expr::MinList(_));
1668            if args.is_empty() {
1669                let idx = ops.len();
1670                ops.push(TapeOp::Const(0.0));
1671                return idx;
1672            }
1673            let mut acc = build_recursive(&args[0], ops, cache, resolver);
1674            for a in &args[1..] {
1675                let next = build_recursive(a, ops, cache, resolver);
1676                let idx = ops.len();
1677                ops.push(if is_min {
1678                    TapeOp::Min(acc, next)
1679                } else {
1680                    TapeOp::Max(acc, next)
1681                });
1682                acc = idx;
1683            }
1684            acc
1685        }
1686        Expr::Cse(body) => {
1687            // Cache by Arc identity so each shared body is emitted into
1688            // the tape exactly once and every reference resolves to the
1689            // same result-slot index. Forward computes the body once;
1690            // reverse-mode adjoint sums contributions from every ref
1691            // into that shared slot — exact chain rule for shared
1692            // sub-expressions.
1693            let key = Arc::as_ptr(body) as *const Expr;
1694            if let Some(&idx) = cache.get(&key) {
1695                idx
1696            } else {
1697                let idx = build_recursive(body, ops, cache, resolver);
1698                cache.insert(key, idx);
1699                idx
1700            }
1701        }
1702        Expr::Compare(op, a, b) => {
1703            let l = build_recursive(a, ops, cache, resolver);
1704            let r = build_recursive(b, ops, cache, resolver);
1705            let idx = ops.len();
1706            ops.push(TapeOp::Cmp(*op, l, r));
1707            idx
1708        }
1709        Expr::And(a, b) => {
1710            let l = build_recursive(a, ops, cache, resolver);
1711            let r = build_recursive(b, ops, cache, resolver);
1712            let idx = ops.len();
1713            ops.push(TapeOp::And(l, r));
1714            idx
1715        }
1716        Expr::Or(a, b) => {
1717            let l = build_recursive(a, ops, cache, resolver);
1718            let r = build_recursive(b, ops, cache, resolver);
1719            let idx = ops.len();
1720            ops.push(TapeOp::Or(l, r));
1721            idx
1722        }
1723        Expr::Not(a) => {
1724            let v = build_recursive(a, ops, cache, resolver);
1725            let idx = ops.len();
1726            ops.push(TapeOp::Not(v));
1727            idx
1728        }
1729        Expr::Cond { cond, then_, else_ } => {
1730            let c = build_recursive(cond, ops, cache, resolver);
1731            let t = build_recursive(then_, ops, cache, resolver);
1732            let e = build_recursive(else_, ops, cache, resolver);
1733            let idx = ops.len();
1734            ops.push(TapeOp::Select(c, t, e));
1735            idx
1736        }
1737        Expr::Funcall { id, args } => {
1738            let (lib, name) = resolver
1739                .funcs_by_id
1740                .get(id)
1741                .unwrap_or_else(|| panic!("unresolved AMPL funcall id {id}"));
1742            let tape_args: Vec<TapeFuncallArg> = args
1743                .iter()
1744                .map(|a| match a {
1745                    FuncallArg::Real(e) => {
1746                        TapeFuncallArg::Tape(build_recursive(e, ops, cache, resolver))
1747                    }
1748                    FuncallArg::Str(s) => TapeFuncallArg::Str(s.clone()),
1749                })
1750                .collect();
1751            let idx = ops.len();
1752            ops.push(TapeOp::Funcall(Box::new(FuncallData {
1753                lib: Arc::clone(lib),
1754                name: name.clone(),
1755                args: tape_args,
1756            })));
1757            idx
1758        }
1759    }
1760}
1761
1762/// Resolve `e` to a literal constant if it is one (transparently
1763/// peering through `Cse` wrappers, which AMPL emits around shared
1764/// constants in CSE-heavy problems).
1765fn peek_const(e: &Expr) -> Option<f64> {
1766    match e {
1767        Expr::Const(c) => Some(*c),
1768        Expr::Cse(body) => peek_const(body),
1769        _ => None,
1770    }
1771}
1772
1773/// Try to rewrite `base ^ exponent_const` into cheaper ops. Returns
1774/// the result tape-slot on success; `None` means "fall through to
1775/// generic Pow." Handles the cases that account for the bulk of
1776/// AMPL-emitted Pow nodes: integer exponents up to ±8 and the
1777/// `Sqrt`/passthrough/one specials. Half-integer exponents (e.g.
1778/// `^1.5`) and larger integers are left to generic `Pow` since the
1779/// resulting mul chain grows the tape faster than it saves work.
1780fn try_emit_const_pow(
1781    base_expr: &Expr,
1782    c: f64,
1783    ops: &mut Vec<TapeOp>,
1784    cache: &mut HashMap<*const Expr, usize>,
1785    resolver: &ExternalResolver,
1786) -> Option<usize> {
1787    if c == 0.0 {
1788        let idx = ops.len();
1789        ops.push(TapeOp::Const(1.0));
1790        return Some(idx);
1791    }
1792    if c == 1.0 {
1793        return Some(build_recursive(base_expr, ops, cache, resolver));
1794    }
1795    if c == 0.5 {
1796        let b = build_recursive(base_expr, ops, cache, resolver);
1797        let idx = ops.len();
1798        ops.push(TapeOp::Sqrt(b));
1799        return Some(idx);
1800    }
1801    // Integer exponents: bounded so a bad tape can't blow up the
1802    // op count. 8 covers everything AMPL typically emits for
1803    // polynomial models; beyond that the binary-expansion mul
1804    // chain (≥4 ops) starts to lose to a single `powf`.
1805    if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
1806        let n = c.abs() as u32;
1807        if n == 0 {
1808            // Already handled above, but guard.
1809            let idx = ops.len();
1810            ops.push(TapeOp::Const(1.0));
1811            return Some(idx);
1812        }
1813        let b = build_recursive(base_expr, ops, cache, resolver);
1814        let pos = emit_int_pow(b, n, ops);
1815        if c < 0.0 {
1816            // x^-n = 1 / x^n. Saves the powf and its ln branch in
1817            // reverse mode; cost is one Div in their place.
1818            let one_idx = ops.len();
1819            ops.push(TapeOp::Const(1.0));
1820            let idx = ops.len();
1821            ops.push(TapeOp::Div(one_idx, pos));
1822            return Some(idx);
1823        }
1824        return Some(pos);
1825    }
1826    None
1827}
1828
1829/// Emit `base^n` for `n >= 1` as a binary-expansion mul chain.
1830/// Worst-case op count is `2·floor(log2(n))` — i.e. 1 op for n=2, 2
1831/// for n=3/4, 3 for n=5..8.
1832fn emit_int_pow(base: usize, n: u32, ops: &mut Vec<TapeOp>) -> usize {
1833    debug_assert!(n >= 1);
1834    if n == 1 {
1835        return base;
1836    }
1837    let half = emit_int_pow(base, n / 2, ops);
1838    let squared = ops.len();
1839    ops.push(TapeOp::Mul(half, half));
1840    if n % 2 == 1 {
1841        let idx = ops.len();
1842        ops.push(TapeOp::Mul(squared, base));
1843        idx
1844    } else {
1845        squared
1846    }
1847}
1848
1849// ============================================================
1850// HybridTape: per-summand local tapes + shared CSE prelude.
1851//
1852// Partial separability — the .nl Sum/Add structure — gets each
1853// summand its own local Vec<SummandOp>. CSE bodies (V-segments
1854// in .nl) that appear in two or more summands are promoted into
1855// a single shared `prelude: Vec<TapeOp>`; per-summand references
1856// to a promoted CSE are SummandOp::Shared(prelude_slot).
1857//
1858// This is strictly better than either extreme:
1859//   - per-summand Tape (no cross-summand sharing): re-inlines
1860//     every shared CSE, blows up tape size when many constraints
1861//     share a stencil derivative (Mittelmann *120 problems).
1862//   - GlobalTape (single shared Vec<TapeOp> for everything):
1863//     per-root reverse sweeps scatter across a many-MB buffer,
1864//     thrashing cache when no CSE is actually shared (lane_emden
1865//     120: each constraint owns its own ops → 50% regression
1866//     vs per-summand tapes).
1867//
1868// Forward: prelude once, then each summand's local pass.
1869// Reverse / forward-over-reverse: per-summand sweep over local
1870// reach (which propagates adjoints into prelude_adj at Shared
1871// boundaries), then a small reverse pass over the summand's
1872// prelude_reach to fold those into grad / Hessian.
1873// ============================================================
1874
/// One slot in a per-summand local tape.
///
/// A summand's `ops` vector interleaves ordinary tape ops with
/// references into the shared prelude; both kinds occupy one local slot
/// and produce one value in the summand's value buffer.
#[derive(Debug, Clone)]
pub enum SummandOp {
    /// Local op — operand indices reference other slots in the
    /// same per-summand vector.
    Local(TapeOp),
    /// Pull a value from the shared prelude at slot `usize`. No
    /// downstream cost beyond the lookup; adjoints flowing into
    /// this slot accumulate into the prelude adjoint buffer.
    Shared(usize),
}
1886
/// One root expression lowered to a local tape, plus the reachability
/// and variable metadata that `HybridTape::build_multi` computes for it.
/// The gradient and Hessian sweeps iterate over the `*_reach` lists
/// instead of over every slot.
#[derive(Debug, Clone)]
pub struct Summand {
    /// The local tape: ordinary ops plus references into the prelude.
    pub ops: Vec<SummandOp>,
    /// Local slot holding the summand's final value.
    pub root_slot: usize,
    /// Local slots reachable from `root_slot`, ascending (topo).
    pub local_reach: Vec<usize>,
    /// Prelude slots reachable from the summand's Shared refs,
    /// ascending (topo in prelude's operand DAG). Empty when the
    /// summand has no Shared refs.
    pub prelude_reach: Vec<usize>,
    /// Variables touched by Var ops inside `local_reach`.
    pub local_vars: Vec<usize>,
    /// Variables touched by Var ops inside `prelude_reach`.
    pub prelude_vars: Vec<usize>,
    /// `local_vars ∪ prelude_vars`, sorted. Hessian j-loop set.
    pub all_vars: Vec<usize>,
}
1904
/// A shared prelude of promoted CSE bodies plus one local tape per root
/// expression. The prelude is evaluated once per point; each summand
/// then runs its own forward / reverse sweeps and reads prelude values
/// through `SummandOp::Shared` slots.
#[derive(Debug)]
pub struct HybridTape {
    /// Shared CSE bodies. Slot indices in `SummandOp::Shared`
    /// point here; this Vec is built bottom-up by `build_recursive`,
    /// so operand indices are always less than the consumer's
    /// index (topo in ascending order).
    pub prelude: Vec<TapeOp>,
    /// One entry per root expression, in the order the roots were given.
    pub summands: Vec<Summand>,
}
1914
1915impl HybridTape {
1916    /// Build hybrid tape from a list of root expressions. CSE
1917    /// bodies referenced from ≥ 2 roots are promoted into the
1918    /// shared prelude; CSEs touched by only one root are inlined
1919    /// into that summand's local ops.
1920    pub fn build_multi(exprs: &[Expr]) -> Self {
1921        // Pass 1: per-Cse-pointer count of how many roots reference
1922        // it (each root contributes at most 1 to the count). The
1923        // ≥2 threshold means a CSE is shared across summands.
1924        let mut cse_count: HashMap<*const Expr, usize> = HashMap::new();
1925        for e in exprs {
1926            let mut seen_in_root: HashSet<*const Expr> = HashSet::new();
1927            count_cse_appearances(e, &mut seen_in_root, &mut cse_count);
1928        }
1929
1930        // Pass 2: build prelude + each summand. The summand builder
1931        // hits the prelude path lazily — only when it encounters a
1932        // promoted Cse — so the prelude grows only with bodies that
1933        // are actually referenced from multiple summands.
1934        let mut prelude: Vec<TapeOp> = Vec::new();
1935        let mut prelude_map: HashMap<*const Expr, usize> = HashMap::new();
1936        let mut summands: Vec<Summand> = Vec::with_capacity(exprs.len());
1937        for e in exprs {
1938            let mut local: Vec<SummandOp> = Vec::new();
1939            let mut local_cache: HashMap<*const Expr, usize> = HashMap::new();
1940            let root_slot = build_into_summand(
1941                e,
1942                &mut local,
1943                &mut local_cache,
1944                &mut prelude,
1945                &mut prelude_map,
1946                &cse_count,
1947            );
1948            summands.push(Summand {
1949                ops: local,
1950                root_slot,
1951                local_reach: Vec::new(),
1952                prelude_reach: Vec::new(),
1953                local_vars: Vec::new(),
1954                prelude_vars: Vec::new(),
1955                all_vars: Vec::new(),
1956            });
1957        }
1958
1959        // Pass 3: per-summand reach / vars. Prelude reach uses an
1960        // epoch-tagged shared visited buffer so total cost stays
1961        // O(Σ |prelude_reach_i|) rather than O(n_summands × |prelude|).
1962        let mut p_visited: Vec<u32> = vec![0; prelude.len()];
1963        let mut p_epoch: u32 = 0;
1964        let mut p_stack: Vec<usize> = Vec::new();
1965        for s in &mut summands {
1966            let (local_reach, shared_refs) = compute_local_reach(&s.ops, s.root_slot);
1967            s.local_reach = local_reach;
1968
1969            let mut lv: BTreeSet<usize> = BTreeSet::new();
1970            for &i in &s.local_reach {
1971                if let SummandOp::Local(TapeOp::Var(j)) = &s.ops[i] {
1972                    lv.insert(*j);
1973                }
1974            }
1975            s.local_vars = lv.iter().copied().collect();
1976
1977            if !shared_refs.is_empty() {
1978                p_epoch += 1;
1979                let mut preach: Vec<usize> = Vec::new();
1980                for &start in &shared_refs {
1981                    bfs_prelude(
1982                        &prelude,
1983                        start,
1984                        &mut p_visited,
1985                        p_epoch,
1986                        &mut p_stack,
1987                        &mut preach,
1988                    );
1989                }
1990                preach.sort_unstable();
1991                s.prelude_vars = vars_in(&prelude, &preach);
1992                s.prelude_reach = preach;
1993            }
1994
1995            let mut av: BTreeSet<usize> = lv;
1996            for &v in &s.prelude_vars {
1997                av.insert(v);
1998            }
1999            s.all_vars = av.into_iter().collect();
2000        }
2001
2002        HybridTape { prelude, summands }
2003    }
2004
2005    pub fn n_prelude_ops(&self) -> usize {
2006        self.prelude.len()
2007    }
2008    pub fn n_summands(&self) -> usize {
2009        self.summands.len()
2010    }
2011    pub fn max_summand_ops(&self) -> usize {
2012        self.summands.iter().map(|s| s.ops.len()).max().unwrap_or(0)
2013    }
2014    pub fn total_local_ops(&self) -> usize {
2015        self.summands.iter().map(|s| s.ops.len()).sum()
2016    }
2017
2018    /// Forward sweep over the shared prelude. `prelude_vals` must
2019    /// have length `n_prelude_ops`.
2020    pub fn forward_prelude(&self, x: &[f64], prelude_vals: &mut [f64]) {
2021        debug_assert_eq!(prelude_vals.len(), self.prelude.len());
2022        for i in 0..self.prelude.len() {
2023            prelude_vals[i] = fwd_step(&self.prelude[i], x, prelude_vals);
2024        }
2025    }
2026
2027    /// Forward sweep over one summand. `local_vals` must hold at
2028    /// least `s.ops.len()` entries.
2029    pub fn forward_summand(
2030        &self,
2031        s: &Summand,
2032        x: &[f64],
2033        prelude_vals: &[f64],
2034        local_vals: &mut [f64],
2035    ) {
2036        debug_assert!(local_vals.len() >= s.ops.len());
2037        for i in 0..s.ops.len() {
2038            local_vals[i] = match &s.ops[i] {
2039                SummandOp::Local(op) => fwd_step(op, x, local_vals),
2040                SummandOp::Shared(k) => prelude_vals[*k],
2041            };
2042        }
2043    }
2044
2045    /// Value at the summand root after `forward_summand`.
2046    #[inline]
2047    pub fn root_value(&self, s: &Summand, local_vals: &[f64]) -> f64 {
2048        local_vals[s.root_slot]
2049    }
2050
2051    /// Reverse-mode gradient for one summand. Walks `local_reach`
2052    /// in reverse — propagating adjoints into `prelude_adj` at
2053    /// Shared boundaries — and then walks `prelude_reach` in
2054    /// reverse to land contributions in `grad`. Scratch arrays
2055    /// `local_adj` and `prelude_adj` are zeroed only at the slots
2056    /// actually touched.
2057    #[allow(clippy::too_many_arguments)]
2058    pub fn gradient_summand(
2059        &self,
2060        s: &Summand,
2061        prelude_vals: &[f64],
2062        local_vals: &[f64],
2063        seed: f64,
2064        grad: &mut [f64],
2065        local_adj: &mut [f64],
2066        prelude_adj: &mut [f64],
2067    ) {
2068        if seed == 0.0 || s.local_reach.is_empty() {
2069            return;
2070        }
2071        for &i in &s.local_reach {
2072            local_adj[i] = 0.0;
2073        }
2074        for &i in &s.prelude_reach {
2075            prelude_adj[i] = 0.0;
2076        }
2077        local_adj[s.root_slot] = seed;
2078        for &i in s.local_reach.iter().rev() {
2079            let a = local_adj[i];
2080            if a == 0.0 {
2081                continue;
2082            }
2083            match &s.ops[i] {
2084                SummandOp::Local(op) => rev_step(op, i, local_vals, local_adj, a, grad),
2085                SummandOp::Shared(k) => {
2086                    prelude_adj[*k] += a;
2087                }
2088            }
2089        }
2090        for &i in s.prelude_reach.iter().rev() {
2091            let a = prelude_adj[i];
2092            if a == 0.0 {
2093                continue;
2094            }
2095            rev_step(&self.prelude[i], i, prelude_vals, prelude_adj, a, grad);
2096        }
2097    }
2098
2099    /// Forward-over-reverse Hessian for one summand with multiplier
2100    /// `weight`. Iterates over `s.all_vars`; for each seed variable
2101    /// j: (1) forward tangent through prelude_reach then local_reach,
2102    /// (2) reverse over local (folding adj/adj_dot into prelude at
2103    /// Shared boundaries), (3) reverse over prelude_reach. All
2104    /// scratch buffers are zeroed only at the touched slots inside
2105    /// the per-j loop.
2106    #[allow(clippy::too_many_arguments)]
2107    pub fn hessian_summand(
2108        &self,
2109        s: &Summand,
2110        prelude_vals: &[f64],
2111        local_vals: &[f64],
2112        weight: f64,
2113        hess_map: &HashMap<(usize, usize), usize>,
2114        values: &mut [f64],
2115        local_dot: &mut [f64],
2116        local_adj: &mut [f64],
2117        local_adj_dot: &mut [f64],
2118        prelude_dot: &mut [f64],
2119        prelude_adj: &mut [f64],
2120        prelude_adj_dot: &mut [f64],
2121    ) {
2122        if weight == 0.0 || s.local_reach.is_empty() {
2123            return;
2124        }
2125        for &j in &s.all_vars {
2126            for &i in &s.local_reach {
2127                local_dot[i] = 0.0;
2128                local_adj[i] = 0.0;
2129                local_adj_dot[i] = 0.0;
2130            }
2131            for &i in &s.prelude_reach {
2132                prelude_dot[i] = 0.0;
2133                prelude_adj[i] = 0.0;
2134                prelude_adj_dot[i] = 0.0;
2135            }
2136            for &i in &s.prelude_reach {
2137                prelude_dot[i] = fwd_tan_step(&self.prelude[i], j, prelude_vals, prelude_dot, i);
2138            }
2139            for &i in &s.local_reach {
2140                local_dot[i] = match &s.ops[i] {
2141                    SummandOp::Local(op) => fwd_tan_step(op, j, local_vals, local_dot, i),
2142                    SummandOp::Shared(k) => prelude_dot[*k],
2143                };
2144            }
2145            local_adj[s.root_slot] = 1.0;
2146            for &i in s.local_reach.iter().rev() {
2147                let w = local_adj[i];
2148                let wd = local_adj_dot[i];
2149                if w == 0.0 && wd == 0.0 {
2150                    continue;
2151                }
2152                match &s.ops[i] {
2153                    SummandOp::Local(op) => {
2154                        ror_step(
2155                            op,
2156                            i,
2157                            j,
2158                            local_vals,
2159                            local_dot,
2160                            local_adj,
2161                            local_adj_dot,
2162                            w,
2163                            wd,
2164                            weight,
2165                            hess_map,
2166                            values,
2167                        );
2168                    }
2169                    SummandOp::Shared(k) => {
2170                        prelude_adj[*k] += w;
2171                        prelude_adj_dot[*k] += wd;
2172                    }
2173                }
2174            }
2175            for &i in s.prelude_reach.iter().rev() {
2176                let w = prelude_adj[i];
2177                let wd = prelude_adj_dot[i];
2178                if w == 0.0 && wd == 0.0 {
2179                    continue;
2180                }
2181                ror_step(
2182                    &self.prelude[i],
2183                    i,
2184                    j,
2185                    prelude_vals,
2186                    prelude_dot,
2187                    prelude_adj,
2188                    prelude_adj_dot,
2189                    w,
2190                    wd,
2191                    weight,
2192                    hess_map,
2193                    values,
2194                );
2195            }
2196        }
2197    }
2198
2199    /// Structural Hessian sparsity over the whole hybrid tape:
2200    /// every pair the prelude or any summand can produce.
2201    pub fn hessian_sparsity_all(&self) -> BTreeSet<(usize, usize)> {
2202        let mut pairs = hessian_sparsity_impl(&self.prelude);
2203
2204        // Per-prelude-slot var-set, reused across summands as the
2205        // var-set carrier for Shared refs.
2206        let prelude_var_sets = compute_var_sets(&self.prelude);
2207
2208        for s in &self.summands {
2209            summand_sparsity(&s.ops, &prelude_var_sets, &mut pairs);
2210        }
2211        pairs
2212    }
2213}
2214
2215/// Pass-1 helper: per-root walk that increments `counts[ptr]` the
2216/// first time a Cse pointer is encountered in this root. Recursing
2217/// into the body is gated on the first visit to avoid quadratic
2218/// blowup on heavily shared CSE DAGs.
2219/// True when `expr` (or any subexpression) is an AMPL external function
2220/// call. The hybrid summand path rejects funcalls outright, but the
2221/// *promoted*-CSE branch emits a shared CSE body via `build_recursive`
2222/// with an **empty** `ExternalResolver::default()` — it has no resolver
2223/// of its own. Without this pre-scan a funcall buried in a promoted CSE
2224/// would reach `build_recursive`'s `Expr::Funcall` arm and panic with the
2225/// misleading `unresolved AMPL funcall id <n>` message, instead of the
2226/// clear "not supported on the hybrid path" message the non-promoted
2227/// summand path raises. Pre-scanning makes both paths report the same
2228/// reason. (Funcalls are unsupported on the hybrid path regardless of
2229/// whether the id would resolve, so this never rejects a buildable tape.)
2230fn cse_contains_funcall(expr: &Expr) -> bool {
2231    match expr {
2232        Expr::Funcall { .. } => true,
2233        Expr::Const(_) | Expr::Var(_) => false,
2234        Expr::Binary(_, a, b) => cse_contains_funcall(a) || cse_contains_funcall(b),
2235        Expr::Unary(_, a) => cse_contains_funcall(a),
2236        Expr::Sum(args) | Expr::MinList(args) | Expr::MaxList(args) => {
2237            args.iter().any(cse_contains_funcall)
2238        }
2239        Expr::Compare(_, a, b) | Expr::And(a, b) | Expr::Or(a, b) => {
2240            cse_contains_funcall(a) || cse_contains_funcall(b)
2241        }
2242        Expr::Not(a) => cse_contains_funcall(a),
2243        Expr::Cond { cond, then_, else_ } => {
2244            cse_contains_funcall(cond) || cse_contains_funcall(then_) || cse_contains_funcall(else_)
2245        }
2246        Expr::Cse(body) => cse_contains_funcall(body),
2247    }
2248}
2249
2250fn count_cse_appearances(
2251    e: &Expr,
2252    seen_in_root: &mut HashSet<*const Expr>,
2253    counts: &mut HashMap<*const Expr, usize>,
2254) {
2255    match e {
2256        Expr::Const(_) | Expr::Var(_) => {}
2257        Expr::Binary(_, a, b) => {
2258            count_cse_appearances(a, seen_in_root, counts);
2259            count_cse_appearances(b, seen_in_root, counts);
2260        }
2261        Expr::Unary(_, a) => count_cse_appearances(a, seen_in_root, counts),
2262        Expr::Sum(args) | Expr::MinList(args) | Expr::MaxList(args) => {
2263            for a in args {
2264                count_cse_appearances(a, seen_in_root, counts);
2265            }
2266        }
2267        Expr::Compare(_, a, b) | Expr::And(a, b) | Expr::Or(a, b) => {
2268            count_cse_appearances(a, seen_in_root, counts);
2269            count_cse_appearances(b, seen_in_root, counts);
2270        }
2271        Expr::Not(a) => count_cse_appearances(a, seen_in_root, counts),
2272        Expr::Cond { cond, then_, else_ } => {
2273            count_cse_appearances(cond, seen_in_root, counts);
2274            count_cse_appearances(then_, seen_in_root, counts);
2275            count_cse_appearances(else_, seen_in_root, counts);
2276        }
2277        Expr::Cse(body) => {
2278            let key = Arc::as_ptr(body) as *const Expr;
2279            if seen_in_root.insert(key) {
2280                *counts.entry(key).or_insert(0) += 1;
2281                count_cse_appearances(body, seen_in_root, counts);
2282            }
2283        }
2284        Expr::Funcall { args, .. } => {
2285            for arg in args {
2286                if let FuncallArg::Real(e) = arg {
2287                    count_cse_appearances(e, seen_in_root, counts);
2288                }
2289            }
2290        }
2291    }
2292}
2293
2294/// Recursive summand builder. CSEs that meet the promotion bar
2295/// (≥ 2 roots reference them per `cse_count`) get a single prelude
2296/// emission via `build_recursive`; the summand records a Shared op
2297/// pointing at the prelude slot. Non-promoted CSEs are inlined
2298/// into the summand with intra-summand Arc-pointer dedup.
2299fn build_into_summand(
2300    expr: &Expr,
2301    local: &mut Vec<SummandOp>,
2302    local_cache: &mut HashMap<*const Expr, usize>,
2303    prelude: &mut Vec<TapeOp>,
2304    prelude_map: &mut HashMap<*const Expr, usize>,
2305    cse_count: &HashMap<*const Expr, usize>,
2306) -> usize {
2307    match expr {
2308        Expr::Const(c) => {
2309            let i = local.len();
2310            local.push(SummandOp::Local(TapeOp::Const(*c)));
2311            i
2312        }
2313        Expr::Var(j) => {
2314            let i = local.len();
2315            local.push(SummandOp::Local(TapeOp::Var(*j)));
2316            i
2317        }
2318        Expr::Binary(op, a, b) => {
2319            if let BinOp::Pow = op {
2320                if let Some(c) = peek_const(b) {
2321                    if let Some(i) = try_emit_const_pow_summand(
2322                        a,
2323                        c,
2324                        local,
2325                        local_cache,
2326                        prelude,
2327                        prelude_map,
2328                        cse_count,
2329                    ) {
2330                        return i;
2331                    }
2332                }
2333            }
2334            let l = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
2335            let r = build_into_summand(b, local, local_cache, prelude, prelude_map, cse_count);
2336            let i = local.len();
2337            local.push(SummandOp::Local(match op {
2338                BinOp::Add => TapeOp::Add(l, r),
2339                BinOp::Sub => TapeOp::Sub(l, r),
2340                BinOp::Mul => TapeOp::Mul(l, r),
2341                BinOp::Div => TapeOp::Div(l, r),
2342                BinOp::Pow => TapeOp::Pow(l, r),
2343                BinOp::Atan2 => TapeOp::Atan2(l, r),
2344            }));
2345            i
2346        }
2347        Expr::Unary(op, a) => {
2348            let v = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
2349            let i = local.len();
2350            local.push(SummandOp::Local(match op {
2351                UnaryOp::Neg => TapeOp::Neg(v),
2352                UnaryOp::Sqrt => TapeOp::Sqrt(v),
2353                UnaryOp::Log => TapeOp::Log(v),
2354                UnaryOp::Log10 => TapeOp::Log10(v),
2355                UnaryOp::Exp => TapeOp::Exp(v),
2356                UnaryOp::Abs => TapeOp::Abs(v),
2357                UnaryOp::Sin => TapeOp::Sin(v),
2358                UnaryOp::Cos => TapeOp::Cos(v),
2359                UnaryOp::Tan => TapeOp::Tan(v),
2360                UnaryOp::Atan => TapeOp::Atan(v),
2361                UnaryOp::Acos => TapeOp::Acos(v),
2362                UnaryOp::Sinh => TapeOp::Sinh(v),
2363                UnaryOp::Cosh => TapeOp::Cosh(v),
2364                UnaryOp::Tanh => TapeOp::Tanh(v),
2365                UnaryOp::Asin => TapeOp::Asin(v),
2366                UnaryOp::Acosh => TapeOp::Acosh(v),
2367                UnaryOp::Asinh => TapeOp::Asinh(v),
2368                UnaryOp::Atanh => TapeOp::Atanh(v),
2369            }));
2370            i
2371        }
2372        Expr::Sum(args) => {
2373            if args.is_empty() {
2374                let i = local.len();
2375                local.push(SummandOp::Local(TapeOp::Const(0.0)));
2376                return i;
2377            }
2378            let mut acc = build_into_summand(
2379                &args[0],
2380                local,
2381                local_cache,
2382                prelude,
2383                prelude_map,
2384                cse_count,
2385            );
2386            for a in &args[1..] {
2387                let nxt =
2388                    build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
2389                let i = local.len();
2390                local.push(SummandOp::Local(TapeOp::Add(acc, nxt)));
2391                acc = i;
2392            }
2393            acc
2394        }
2395        Expr::Cse(body) => {
2396            let key = Arc::as_ptr(body) as *const Expr;
2397            if let Some(&li) = local_cache.get(&key) {
2398                return li;
2399            }
2400            let promoted = cse_count.get(&key).copied().unwrap_or(0) >= 2;
2401            if promoted {
2402                // `build_recursive` below runs with an empty resolver, so a
2403                // funcall hidden inside the promoted body would panic with the
2404                // misleading "unresolved AMPL funcall id" message rather than
2405                // the clear hybrid-unsupported message the non-promoted summand
2406                // path (and the `Expr::Funcall` arm at the bottom) raises.
2407                // Reject it up front so both CSE paths report the same reason.
2408                if cse_contains_funcall(body) {
2409                    panic!(
2410                        "HybridTape: AMPL external function calls are not supported on the \
2411                         hybrid (partial-separability) tape path. Build with \
2412                         Tape::build_with_externals instead."
2413                    );
2414                }
2415                // Build (or reuse) the prelude slot for this CSE.
2416                // `build_recursive(expr, ...)` hits the Cse arm,
2417                // emits the body once into prelude, and caches it
2418                // in `prelude_map` keyed by this Arc pointer.
2419                let pslot =
2420                    build_recursive(expr, prelude, prelude_map, &ExternalResolver::default());
2421                let li = local.len();
2422                local.push(SummandOp::Shared(pslot));
2423                local_cache.insert(key, li);
2424                li
2425            } else {
2426                let li =
2427                    build_into_summand(body, local, local_cache, prelude, prelude_map, cse_count);
2428                local_cache.insert(key, li);
2429                li
2430            }
2431        }
2432        Expr::Compare(_, _, _)
2433        | Expr::And(_, _)
2434        | Expr::Or(_, _)
2435        | Expr::Not(_)
2436        | Expr::Cond { .. }
2437        | Expr::MinList(_)
2438        | Expr::MaxList(_) => {
2439            panic!(
2440                "HybridTape: conditional / logical / min-max opcodes (comparisons, \
2441                 AND/OR/NOT, if-then-else, min/max lists) are not supported on the \
2442                 hybrid (partial-separability) tape path. Build with \
2443                 Tape::build_with_externals instead."
2444            );
2445        }
2446        Expr::Funcall { .. } => {
2447            panic!(
2448                "HybridTape: AMPL external function calls are not supported on the \
2449                 hybrid (partial-separability) tape path. Build with Tape::build_with_externals \
2450                 instead."
2451            );
2452        }
2453    }
2454}
2455
2456/// Pow-lowering specialised for summand builds. Mirrors
2457/// `try_emit_const_pow` but with summand-flavoured emission.
2458fn try_emit_const_pow_summand(
2459    base_expr: &Expr,
2460    c: f64,
2461    local: &mut Vec<SummandOp>,
2462    local_cache: &mut HashMap<*const Expr, usize>,
2463    prelude: &mut Vec<TapeOp>,
2464    prelude_map: &mut HashMap<*const Expr, usize>,
2465    cse_count: &HashMap<*const Expr, usize>,
2466) -> Option<usize> {
2467    if c == 0.0 {
2468        let i = local.len();
2469        local.push(SummandOp::Local(TapeOp::Const(1.0)));
2470        return Some(i);
2471    }
2472    if c == 1.0 {
2473        return Some(build_into_summand(
2474            base_expr,
2475            local,
2476            local_cache,
2477            prelude,
2478            prelude_map,
2479            cse_count,
2480        ));
2481    }
2482    if c == 0.5 {
2483        let b = build_into_summand(
2484            base_expr,
2485            local,
2486            local_cache,
2487            prelude,
2488            prelude_map,
2489            cse_count,
2490        );
2491        let i = local.len();
2492        local.push(SummandOp::Local(TapeOp::Sqrt(b)));
2493        return Some(i);
2494    }
2495    if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
2496        let n = c.abs() as u32;
2497        if n == 0 {
2498            let i = local.len();
2499            local.push(SummandOp::Local(TapeOp::Const(1.0)));
2500            return Some(i);
2501        }
2502        let b = build_into_summand(
2503            base_expr,
2504            local,
2505            local_cache,
2506            prelude,
2507            prelude_map,
2508            cse_count,
2509        );
2510        let pos = emit_int_pow_summand(b, n, local);
2511        if c < 0.0 {
2512            let one_idx = local.len();
2513            local.push(SummandOp::Local(TapeOp::Const(1.0)));
2514            let i = local.len();
2515            local.push(SummandOp::Local(TapeOp::Div(one_idx, pos)));
2516            return Some(i);
2517        }
2518        return Some(pos);
2519    }
2520    None
2521}
2522
2523fn emit_int_pow_summand(base: usize, n: u32, local: &mut Vec<SummandOp>) -> usize {
2524    debug_assert!(n >= 1);
2525    if n == 1 {
2526        return base;
2527    }
2528    let half = emit_int_pow_summand(base, n / 2, local);
2529    let squared = local.len();
2530    local.push(SummandOp::Local(TapeOp::Mul(half, half)));
2531    if n % 2 == 1 {
2532        let i = local.len();
2533        local.push(SummandOp::Local(TapeOp::Mul(squared, base)));
2534        i
2535    } else {
2536        squared
2537    }
2538}
2539
2540/// Walk a summand's local op DAG from `root`, returning the
2541/// reachable local slots (sorted ascending) plus the distinct
2542/// prelude slots referenced by any Shared op along the way.
2543fn compute_local_reach(ops: &[SummandOp], root: usize) -> (Vec<usize>, Vec<usize>) {
2544    let mut visited = vec![false; ops.len()];
2545    let mut reach: Vec<usize> = Vec::new();
2546    let mut shared: BTreeSet<usize> = BTreeSet::new();
2547    let mut stack: Vec<usize> = Vec::with_capacity(16);
2548    visited[root] = true;
2549    reach.push(root);
2550    stack.push(root);
2551    while let Some(s) = stack.pop() {
2552        match &ops[s] {
2553            SummandOp::Local(op) => {
2554                let (a, b) = op_operands(op);
2555                if let Some(a) = a {
2556                    if !visited[a] {
2557                        visited[a] = true;
2558                        reach.push(a);
2559                        stack.push(a);
2560                    }
2561                }
2562                if let Some(b) = b {
2563                    if !visited[b] {
2564                        visited[b] = true;
2565                        reach.push(b);
2566                        stack.push(b);
2567                    }
2568                }
2569            }
2570            SummandOp::Shared(k) => {
2571                shared.insert(*k);
2572            }
2573        }
2574    }
2575    reach.sort_unstable();
2576    (reach, shared.into_iter().collect())
2577}
2578
2579/// Epoch-tagged BFS over the prelude operand DAG, accumulating
2580/// reachable slots into `out`. Caller is responsible for sorting
2581/// `out` after a batch of starts has been processed.
2582fn bfs_prelude(
2583    prelude: &[TapeOp],
2584    start: usize,
2585    visited: &mut [u32],
2586    cur: u32,
2587    stack: &mut Vec<usize>,
2588    out: &mut Vec<usize>,
2589) {
2590    if visited[start] == cur {
2591        return;
2592    }
2593    visited[start] = cur;
2594    out.push(start);
2595    stack.push(start);
2596    while let Some(s) = stack.pop() {
2597        let (a, b) = op_operands(&prelude[s]);
2598        if let Some(a) = a {
2599            if visited[a] != cur {
2600                visited[a] = cur;
2601                out.push(a);
2602                stack.push(a);
2603            }
2604        }
2605        if let Some(b) = b {
2606            if visited[b] != cur {
2607                visited[b] = cur;
2608                out.push(b);
2609                stack.push(b);
2610            }
2611        }
2612    }
2613}
2614
2615/// Per-op var-set for the prelude — every slot's transitive
2616/// variable footprint. Used by `summand_sparsity` to expand
2617/// `SummandOp::Shared(k)` into its var-set carrier.
2618fn compute_var_sets(ops: &[TapeOp]) -> Vec<BTreeSet<usize>> {
2619    let mut out: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
2620    for op in ops {
2621        let vs: BTreeSet<usize> = match op {
2622            TapeOp::Const(_) => BTreeSet::new(),
2623            TapeOp::Var(j) => {
2624                let mut s = BTreeSet::new();
2625                s.insert(*j);
2626                s
2627            }
2628            TapeOp::Add(a, b)
2629            | TapeOp::Sub(a, b)
2630            | TapeOp::Mul(a, b)
2631            | TapeOp::Div(a, b)
2632            | TapeOp::Pow(a, b)
2633            | TapeOp::Atan2(a, b) => out[*a].union(&out[*b]).copied().collect(),
2634            TapeOp::Neg(a)
2635            | TapeOp::Abs(a)
2636            | TapeOp::Sqrt(a)
2637            | TapeOp::Exp(a)
2638            | TapeOp::Log(a)
2639            | TapeOp::Log10(a)
2640            | TapeOp::Sin(a)
2641            | TapeOp::Cos(a)
2642            | TapeOp::Tan(a)
2643            | TapeOp::Atan(a)
2644            | TapeOp::Acos(a)
2645            | TapeOp::Sinh(a)
2646            | TapeOp::Cosh(a)
2647            | TapeOp::Tanh(a)
2648            | TapeOp::Asin(a)
2649            | TapeOp::Acosh(a)
2650            | TapeOp::Asinh(a)
2651            | TapeOp::Atanh(a) => out[*a].clone(),
2652            TapeOp::Cmp(_, _, _)
2653            | TapeOp::And(_, _)
2654            | TapeOp::Or(_, _)
2655            | TapeOp::Not(_)
2656            | TapeOp::Select(_, _, _)
2657            | TapeOp::Min(_, _)
2658            | TapeOp::Max(_, _) => unreachable!(
2659                "HybridTape prelude cannot contain conditional / logical / min-max \
2660                 TapeOps; build_into_summand panics on those Expr variants."
2661            ),
2662            TapeOp::Funcall(_) => unreachable!(
2663                "HybridTape prelude cannot contain TapeOp::Funcall; \
2664                 build_into_summand panics on Expr::Funcall."
2665            ),
2666        };
2667        out.push(vs);
2668    }
2669    out
2670}
2671
2672/// Per-op Hessian-sparsity propagation over a summand's mixed
2673/// SummandOp slice. Shared refs contribute their prelude var-set
2674/// but do not themselves emit pairs (those came from
2675/// `hessian_sparsity_impl(&prelude)`).
2676fn summand_sparsity(
2677    ops: &[SummandOp],
2678    prelude_var_sets: &[BTreeSet<usize>],
2679    pairs: &mut BTreeSet<(usize, usize)>,
2680) {
2681    let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
2682    let emit_cross =
2683        |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2684            for &v1 in s1 {
2685                for &v2 in s2 {
2686                    let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
2687                    pairs.insert((r, c));
2688                }
2689            }
2690        };
2691    let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2692        let vars: Vec<usize> = s.iter().copied().collect();
2693        for (ai, &vi) in vars.iter().enumerate() {
2694            for &vj in &vars[..=ai] {
2695                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2696                pairs.insert((r, c));
2697            }
2698        }
2699    };
2700    for so in ops {
2701        let vset: BTreeSet<usize> = match so {
2702            SummandOp::Shared(k) => prelude_var_sets[*k].clone(),
2703            SummandOp::Local(op) => match op {
2704                TapeOp::Const(_) => BTreeSet::new(),
2705                TapeOp::Var(j) => {
2706                    let mut s = BTreeSet::new();
2707                    s.insert(*j);
2708                    s
2709                }
2710                TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
2711                    var_sets[*a].union(&var_sets[*b]).copied().collect()
2712                }
2713                TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
2714                TapeOp::Mul(a, b) => {
2715                    emit_cross(&var_sets[*a], &var_sets[*b], pairs);
2716                    var_sets[*a].union(&var_sets[*b]).copied().collect()
2717                }
2718                TapeOp::Div(a, b) => {
2719                    emit_cross(&var_sets[*a], &var_sets[*b], pairs);
2720                    emit_self(&var_sets[*b], pairs);
2721                    var_sets[*a].union(&var_sets[*b]).copied().collect()
2722                }
2723                TapeOp::Pow(a, b) | TapeOp::Atan2(a, b) => {
2724                    let combined: BTreeSet<usize> =
2725                        var_sets[*a].union(&var_sets[*b]).copied().collect();
2726                    emit_self(&combined, pairs);
2727                    combined
2728                }
2729                TapeOp::Sqrt(a)
2730                | TapeOp::Exp(a)
2731                | TapeOp::Log(a)
2732                | TapeOp::Log10(a)
2733                | TapeOp::Sin(a)
2734                | TapeOp::Cos(a)
2735                | TapeOp::Tan(a)
2736                | TapeOp::Atan(a)
2737                | TapeOp::Acos(a)
2738                | TapeOp::Sinh(a)
2739                | TapeOp::Cosh(a)
2740                | TapeOp::Tanh(a)
2741                | TapeOp::Asin(a)
2742                | TapeOp::Acosh(a)
2743                | TapeOp::Asinh(a)
2744                | TapeOp::Atanh(a) => {
2745                    emit_self(&var_sets[*a], pairs);
2746                    var_sets[*a].clone()
2747                }
2748                TapeOp::Cmp(_, _, _)
2749                | TapeOp::And(_, _)
2750                | TapeOp::Or(_, _)
2751                | TapeOp::Not(_)
2752                | TapeOp::Select(_, _, _)
2753                | TapeOp::Min(_, _)
2754                | TapeOp::Max(_, _) => unreachable!(
2755                    "HybridTape summand cannot contain conditional / logical / min-max \
2756                     TapeOps; build_into_summand panics on those Expr variants."
2757                ),
2758                TapeOp::Funcall(_) => unreachable!(
2759                    "HybridTape summand cannot contain TapeOp::Funcall; \
2760                     build_into_summand panics on Expr::Funcall."
2761                ),
2762            },
2763        };
2764        var_sets.push(vset);
2765    }
2766}
2767
2768/// Operand indices of a `TapeOp`, normalized into a fixed-length
2769/// array so callers don't need to re-match every site.
2770#[inline]
2771fn op_operands(op: &TapeOp) -> (Option<usize>, Option<usize>) {
2772    match op {
2773        TapeOp::Const(_) | TapeOp::Var(_) => (None, None),
2774        TapeOp::Add(a, b)
2775        | TapeOp::Sub(a, b)
2776        | TapeOp::Mul(a, b)
2777        | TapeOp::Div(a, b)
2778        | TapeOp::Pow(a, b)
2779        | TapeOp::Atan2(a, b) => (Some(*a), Some(*b)),
2780        TapeOp::Neg(a)
2781        | TapeOp::Abs(a)
2782        | TapeOp::Sqrt(a)
2783        | TapeOp::Exp(a)
2784        | TapeOp::Log(a)
2785        | TapeOp::Log10(a)
2786        | TapeOp::Sin(a)
2787        | TapeOp::Cos(a)
2788        | TapeOp::Tan(a)
2789        | TapeOp::Atan(a)
2790        | TapeOp::Acos(a)
2791        | TapeOp::Sinh(a)
2792        | TapeOp::Cosh(a)
2793        | TapeOp::Tanh(a)
2794        | TapeOp::Asin(a)
2795        | TapeOp::Acosh(a)
2796        | TapeOp::Asinh(a)
2797        | TapeOp::Atanh(a) => (Some(*a), None),
2798        // Conditional / logical TapeOps never reach the HybridTape
2799        // operand-walk (build_into_summand rejects them). Cmp/And/Or
2800        // have two operands; Not has one; Select's three can't be
2801        // expressed in this two-slot shape, so it would be a bug to
2802        // see it here.
2803        TapeOp::Cmp(_, a, b) | TapeOp::And(a, b) | TapeOp::Or(a, b) => (Some(*a), Some(*b)),
2804        TapeOp::Not(a) => (Some(*a), None),
2805        TapeOp::Select(_, _, _) => unreachable!(
2806            "op_operands: TapeOp::Select has three operands and is unsupported on \
2807             the HybridTape path"
2808        ),
2809        TapeOp::Min(_, _) | TapeOp::Max(_, _) => unreachable!(
2810            "op_operands: TapeOp::Min/Max are unsupported on the HybridTape path \
2811             (build_into_summand rejects min/max lists)"
2812        ),
2813        TapeOp::Funcall(_) => (None, None),
2814    }
2815}
2816
2817fn vars_in(ops: &[TapeOp], reach: &[usize]) -> Vec<usize> {
2818    let mut s: BTreeSet<usize> = BTreeSet::new();
2819    for &i in reach {
2820        if let TapeOp::Var(j) = &ops[i] {
2821            s.insert(*j);
2822        }
2823    }
2824    s.into_iter().collect()
2825}
2826
2827// ----- Free-function AD step kernels used by GlobalTape -----
2828
2829#[inline]
2830fn fwd_step(op: &TapeOp, x: &[f64], vals: &[f64]) -> f64 {
2831    match op {
2832        TapeOp::Const(c) => *c,
2833        TapeOp::Var(i) => x[*i],
2834        TapeOp::Add(a, b) => vals[*a] + vals[*b],
2835        TapeOp::Sub(a, b) => vals[*a] - vals[*b],
2836        TapeOp::Mul(a, b) => vals[*a] * vals[*b],
2837        TapeOp::Div(a, b) => vals[*a] / vals[*b],
2838        TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
2839        TapeOp::Neg(a) => -vals[*a],
2840        TapeOp::Abs(a) => vals[*a].abs(),
2841        TapeOp::Sqrt(a) => vals[*a].sqrt(),
2842        TapeOp::Exp(a) => vals[*a].exp(),
2843        TapeOp::Log(a) => vals[*a].ln(),
2844        TapeOp::Log10(a) => vals[*a].log10(),
2845        TapeOp::Sin(a) => vals[*a].sin(),
2846        TapeOp::Cos(a) => vals[*a].cos(),
2847        TapeOp::Tan(a) => vals[*a].tan(),
2848        TapeOp::Atan(a) => vals[*a].atan(),
2849        TapeOp::Acos(a) => vals[*a].acos(),
2850        TapeOp::Sinh(a) => vals[*a].sinh(),
2851        TapeOp::Cosh(a) => vals[*a].cosh(),
2852        TapeOp::Tanh(a) => vals[*a].tanh(),
2853        TapeOp::Asin(a) => vals[*a].asin(),
2854        TapeOp::Acosh(a) => vals[*a].acosh(),
2855        TapeOp::Asinh(a) => vals[*a].asinh(),
2856        TapeOp::Atanh(a) => vals[*a].atanh(),
2857        TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
2858        TapeOp::Cmp(_, _, _)
2859        | TapeOp::And(_, _)
2860        | TapeOp::Or(_, _)
2861        | TapeOp::Not(_)
2862        | TapeOp::Select(_, _, _)
2863        | TapeOp::Min(_, _)
2864        | TapeOp::Max(_, _) => panic!(
2865            "GlobalTape free-function kernels do not implement conditional / logical \
2866             / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
2867             instead."
2868        ),
2869        TapeOp::Funcall(fc) => {
2870            let FuncallData { lib, name, args } = fc.as_ref();
2871            let call_args = funcall_to_ext_args(args, vals);
2872            let res = lib
2873                .eval(name, &call_args, false, false)
2874                .unwrap_or_else(|e| panic!("external function '{name}' eval failed: {e}"));
2875            res.value
2876        }
2877    }
2878}
2879
2880#[inline]
2881fn rev_step(op: &TapeOp, i: usize, vals: &[f64], adj: &mut [f64], a: f64, grad: &mut [f64]) {
2882    match op {
2883        TapeOp::Const(_) => {}
2884        TapeOp::Var(j) => {
2885            grad[*j] += a;
2886        }
2887        TapeOp::Add(l, r) => {
2888            adj[*l] += a;
2889            adj[*r] += a;
2890        }
2891        TapeOp::Sub(l, r) => {
2892            adj[*l] += a;
2893            adj[*r] -= a;
2894        }
2895        TapeOp::Mul(l, r) => {
2896            adj[*l] += a * vals[*r];
2897            adj[*r] += a * vals[*l];
2898        }
2899        TapeOp::Div(l, r) => {
2900            let rv = vals[*r];
2901            adj[*l] += a / rv;
2902            adj[*r] -= a * vals[*l] / (rv * rv);
2903        }
2904        TapeOp::Pow(l, r) => {
2905            let lv = vals[*l];
2906            let rv = vals[*r];
2907            if rv != 0.0 {
2908                adj[*l] += a * rv * lv.powf(rv - 1.0);
2909            }
2910            if lv > 0.0 {
2911                adj[*r] += a * vals[i] * lv.ln();
2912            }
2913        }
2914        TapeOp::Neg(j) => {
2915            adj[*j] -= a;
2916        }
2917        TapeOp::Abs(j) => {
2918            if vals[*j] >= 0.0 {
2919                adj[*j] += a;
2920            } else {
2921                adj[*j] -= a;
2922            }
2923        }
2924        TapeOp::Sqrt(j) => {
2925            let sv = vals[i];
2926            if sv > 0.0 {
2927                adj[*j] += a * 0.5 / sv;
2928            }
2929        }
2930        TapeOp::Exp(j) => {
2931            adj[*j] += a * vals[i];
2932        }
2933        TapeOp::Log(j) => {
2934            adj[*j] += a / vals[*j];
2935        }
2936        TapeOp::Log10(j) => {
2937            adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
2938        }
2939        TapeOp::Sin(j) => {
2940            adj[*j] += a * vals[*j].cos();
2941        }
2942        TapeOp::Cos(j) => {
2943            adj[*j] -= a * vals[*j].sin();
2944        }
2945        TapeOp::Tan(j) => {
2946            let t = vals[i];
2947            adj[*j] += a * (1.0 + t * t);
2948        }
2949        TapeOp::Atan(j) => {
2950            let u = vals[*j];
2951            adj[*j] += a / (1.0 + u * u);
2952        }
2953        TapeOp::Acos(j) => {
2954            let u = vals[*j];
2955            adj[*j] -= a / (1.0 - u * u).sqrt();
2956        }
2957        TapeOp::Sinh(j) => {
2958            adj[*j] += a * vals[*j].cosh();
2959        }
2960        TapeOp::Cosh(j) => {
2961            adj[*j] += a * vals[*j].sinh();
2962        }
2963        TapeOp::Tanh(j) => {
2964            let t = vals[i];
2965            adj[*j] += a * (1.0 - t * t);
2966        }
2967        TapeOp::Asin(j) => {
2968            let u = vals[*j];
2969            adj[*j] += a / (1.0 - u * u).sqrt();
2970        }
2971        TapeOp::Acosh(j) => {
2972            let u = vals[*j];
2973            adj[*j] += a / (u * u - 1.0).sqrt();
2974        }
2975        TapeOp::Asinh(j) => {
2976            let u = vals[*j];
2977            adj[*j] += a / (u * u + 1.0).sqrt();
2978        }
2979        TapeOp::Atanh(j) => {
2980            let u = vals[*j];
2981            adj[*j] += a / (1.0 - u * u);
2982        }
2983        TapeOp::Atan2(l, r) => {
2984            let y = vals[*l];
2985            let x = vals[*r];
2986            let d = y * y + x * x;
2987            adj[*l] += a * (x / d);
2988            adj[*r] += a * (-y / d);
2989        }
2990        TapeOp::Cmp(_, _, _)
2991        | TapeOp::And(_, _)
2992        | TapeOp::Or(_, _)
2993        | TapeOp::Not(_)
2994        | TapeOp::Select(_, _, _)
2995        | TapeOp::Min(_, _)
2996        | TapeOp::Max(_, _) => panic!(
2997            "GlobalTape free-function kernels do not implement conditional / logical \
2998             / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
2999             instead."
3000        ),
3001        TapeOp::Funcall(fc) => {
3002            let FuncallData { lib, name, args } = fc.as_ref();
3003            let call_args = funcall_to_ext_args(args, vals);
3004            let res = lib
3005                .eval(name, &call_args, true, false)
3006                .unwrap_or_else(|e| panic!("external function '{name}' reverse eval failed: {e}"));
3007            let derivs = res.derivs.expect("want_derivs=true returns derivs");
3008            let mut k = 0usize;
3009            for arg in args {
3010                if let TapeFuncallArg::Tape(idx) = arg {
3011                    adj[*idx] += a * derivs[k];
3012                    k += 1;
3013                }
3014            }
3015            let _ = i;
3016            let _ = grad;
3017        }
3018    }
3019}
3020
3021#[inline]
3022fn fwd_tan_step(op: &TapeOp, seed_var: usize, vals: &[f64], dot: &[f64], i: usize) -> f64 {
3023    match op {
3024        TapeOp::Const(_) => 0.0,
3025        TapeOp::Var(k) => {
3026            if *k == seed_var {
3027                1.0
3028            } else {
3029                0.0
3030            }
3031        }
3032        TapeOp::Add(a, b) => dot[*a] + dot[*b],
3033        TapeOp::Sub(a, b) => dot[*a] - dot[*b],
3034        TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
3035        TapeOp::Div(a, b) => {
3036            let vb = vals[*b];
3037            (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
3038        }
3039        TapeOp::Pow(a, b) => {
3040            let u = vals[*a];
3041            let r = vals[*b];
3042            let du = dot[*a];
3043            let dr = dot[*b];
3044            let mut result = 0.0;
3045            // Match the reverse-mode gradient's guard (`rv != 0.0` only): at base
3046            // u == 0 the slope is still well defined for r >= 1 (and a genuine
3047            // ±inf for r < 1), so it must not be silently dropped, or the forward
3048            // tangent disagrees with the reverse gradient.
3049            if r != 0.0 {
3050                result += r * u.powf(r - 1.0) * du;
3051            }
3052            if u > 0.0 {
3053                result += vals[i] * u.ln() * dr;
3054            }
3055            result
3056        }
3057        TapeOp::Neg(a) => -dot[*a],
3058        TapeOp::Abs(a) => {
3059            if vals[*a] >= 0.0 {
3060                dot[*a]
3061            } else {
3062                -dot[*a]
3063            }
3064        }
3065        TapeOp::Sqrt(a) => {
3066            let sv = vals[i];
3067            if sv > 0.0 {
3068                dot[*a] * 0.5 / sv
3069            } else {
3070                0.0
3071            }
3072        }
3073        TapeOp::Exp(a) => dot[*a] * vals[i],
3074        TapeOp::Log(a) => dot[*a] / vals[*a],
3075        TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
3076        TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
3077        TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
3078        TapeOp::Tan(a) => {
3079            let t = vals[i];
3080            dot[*a] * (1.0 + t * t)
3081        }
3082        TapeOp::Atan(a) => {
3083            let u = vals[*a];
3084            dot[*a] / (1.0 + u * u)
3085        }
3086        TapeOp::Acos(a) => {
3087            let u = vals[*a];
3088            -dot[*a] / (1.0 - u * u).sqrt()
3089        }
3090        TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
3091        TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
3092        TapeOp::Tanh(a) => {
3093            let t = vals[i];
3094            dot[*a] * (1.0 - t * t)
3095        }
3096        TapeOp::Asin(a) => {
3097            let u = vals[*a];
3098            dot[*a] / (1.0 - u * u).sqrt()
3099        }
3100        TapeOp::Acosh(a) => {
3101            let u = vals[*a];
3102            dot[*a] / (u * u - 1.0).sqrt()
3103        }
3104        TapeOp::Asinh(a) => {
3105            let u = vals[*a];
3106            dot[*a] / (u * u + 1.0).sqrt()
3107        }
3108        TapeOp::Atanh(a) => {
3109            let u = vals[*a];
3110            dot[*a] / (1.0 - u * u)
3111        }
3112        TapeOp::Atan2(a, b) => {
3113            let y = vals[*a];
3114            let x = vals[*b];
3115            let d = y * y + x * x;
3116            (x * dot[*a] - y * dot[*b]) / d
3117        }
3118        TapeOp::Cmp(_, _, _)
3119        | TapeOp::And(_, _)
3120        | TapeOp::Or(_, _)
3121        | TapeOp::Not(_)
3122        | TapeOp::Select(_, _, _)
3123        | TapeOp::Min(_, _)
3124        | TapeOp::Max(_, _) => panic!(
3125            "GlobalTape free-function kernels do not implement conditional / logical \
3126             / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
3127             instead."
3128        ),
3129        TapeOp::Funcall(fc) => {
3130            let FuncallData { lib, name, args } = fc.as_ref();
3131            let call_args = funcall_to_ext_args(args, vals);
3132            let res = lib
3133                .eval(name, &call_args, true, false)
3134                .unwrap_or_else(|e| panic!("external function '{name}' tangent eval failed: {e}"));
3135            let derivs = res.derivs.expect("want_derivs=true returns derivs");
3136            let mut acc = 0.0;
3137            let mut k = 0usize;
3138            for arg in args {
3139                if let TapeFuncallArg::Tape(idx) = arg {
3140                    acc += derivs[k] * dot[*idx];
3141                    k += 1;
3142                }
3143            }
3144            let _ = seed_var;
3145            acc
3146        }
3147    }
3148}
3149
3150#[allow(clippy::too_many_arguments)]
3151#[inline]
3152fn ror_step(
3153    op: &TapeOp,
3154    i: usize,
3155    seed_var: usize,
3156    vals: &[f64],
3157    dot: &[f64],
3158    adj: &mut [f64],
3159    adj_dot: &mut [f64],
3160    w: f64,
3161    wd: f64,
3162    weight: f64,
3163    hess_map: &HashMap<(usize, usize), usize>,
3164    values: &mut [f64],
3165) {
3166    match op {
3167        TapeOp::Const(_) => {}
3168        TapeOp::Var(k) => {
3169            if wd != 0.0 && *k >= seed_var {
3170                if let Some(&pos) = hess_map.get(&(*k, seed_var)) {
3171                    values[pos] += weight * wd;
3172                }
3173            }
3174        }
3175        TapeOp::Add(a, b) => {
3176            adj[*a] += w;
3177            adj[*b] += w;
3178            adj_dot[*a] += wd;
3179            adj_dot[*b] += wd;
3180        }
3181        TapeOp::Sub(a, b) => {
3182            adj[*a] += w;
3183            adj[*b] -= w;
3184            adj_dot[*a] += wd;
3185            adj_dot[*b] -= wd;
3186        }
3187        TapeOp::Mul(a, b) => {
3188            adj[*a] += w * vals[*b];
3189            adj[*b] += w * vals[*a];
3190            adj_dot[*a] += wd * vals[*b] + w * dot[*b];
3191            adj_dot[*b] += wd * vals[*a] + w * dot[*a];
3192        }
3193        TapeOp::Div(a, b) => {
3194            let vb = vals[*b];
3195            let vb2 = vb * vb;
3196            let vb3 = vb2 * vb;
3197            adj[*a] += w / vb;
3198            adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
3199            adj[*b] += w * (-vals[*a] / vb2);
3200            adj_dot[*b] +=
3201                wd * (-vals[*a] / vb2) + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
3202        }
3203        TapeOp::Pow(a, b) => {
3204            let u = vals[*a];
3205            let r = vals[*b];
3206            let du = dot[*a];
3207            let dr = dot[*b];
3208            if r != 0.0 {
3209                if u != 0.0 {
3210                    let p_a = r * u.powf(r - 1.0);
3211                    adj[*a] += w * p_a;
3212                    let mut dp_a = dr * u.powf(r - 1.0);
3213                    if u > 0.0 {
3214                        dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
3215                    } else {
3216                        dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
3217                    }
3218                    adj_dot[*a] += wd * p_a + w * dp_a;
3219                } else if r >= 2.0 {
3220                    let p_a = 0.0;
3221                    adj[*a] += w * p_a;
3222                    let dp_a = if r == 2.0 {
3223                        2.0 * du
3224                    } else {
3225                        r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
3226                    };
3227                    adj_dot[*a] += wd * p_a + w * dp_a;
3228                }
3229            }
3230            if u > 0.0 {
3231                let ln_u = u.ln();
3232                let p_b = vals[i] * ln_u;
3233                adj[*b] += w * p_b;
3234                let dur = vals[i] * (r * du / u + dr * ln_u);
3235                let dp_b = dur * ln_u + vals[i] * du / u;
3236                adj_dot[*b] += wd * p_b + w * dp_b;
3237            }
3238        }
3239        TapeOp::Neg(a) => {
3240            adj[*a] -= w;
3241            adj_dot[*a] -= wd;
3242        }
3243        TapeOp::Abs(a) => {
3244            let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
3245            adj[*a] += w * s;
3246            adj_dot[*a] += wd * s;
3247        }
3248        TapeOp::Sqrt(a) => {
3249            let sv = vals[i];
3250            if sv > 0.0 {
3251                let fp = 0.5 / sv;
3252                let fpp = -0.25 / (vals[*a] * sv);
3253                adj[*a] += w * fp;
3254                adj_dot[*a] += wd * fp + w * fpp * dot[*a];
3255            }
3256        }
3257        TapeOp::Exp(a) => {
3258            let ev = vals[i];
3259            adj[*a] += w * ev;
3260            adj_dot[*a] += wd * ev + w * ev * dot[*a];
3261        }
3262        TapeOp::Log(a) => {
3263            let u = vals[*a];
3264            adj[*a] += w / u;
3265            adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
3266        }
3267        TapeOp::Log10(a) => {
3268            let u = vals[*a];
3269            let c = std::f64::consts::LN_10;
3270            adj[*a] += w / (u * c);
3271            adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
3272        }
3273        TapeOp::Sin(a) => {
3274            let u = vals[*a];
3275            let cu = u.cos();
3276            adj[*a] += w * cu;
3277            adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
3278        }
3279        TapeOp::Cos(a) => {
3280            let u = vals[*a];
3281            let su = u.sin();
3282            adj[*a] -= w * su;
3283            adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
3284        }
3285        TapeOp::Tan(a) => {
3286            let t = vals[i];
3287            let gp = 1.0 + t * t;
3288            let gpp = 2.0 * t * gp;
3289            adj[*a] += w * gp;
3290            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3291        }
3292        TapeOp::Atan(a) => {
3293            let u = vals[*a];
3294            let d = 1.0 + u * u;
3295            let gp = 1.0 / d;
3296            let gpp = -2.0 * u / (d * d);
3297            adj[*a] += w * gp;
3298            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3299        }
3300        TapeOp::Acos(a) => {
3301            let u = vals[*a];
3302            let s = 1.0 - u * u;
3303            let r = s.sqrt();
3304            let gp = -1.0 / r;
3305            let gpp = -u / (s * r);
3306            adj[*a] += w * gp;
3307            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3308        }
3309        TapeOp::Sinh(a) => {
3310            let u = vals[*a];
3311            let gp = u.cosh();
3312            let gpp = vals[i]; // sinh(u)
3313            adj[*a] += w * gp;
3314            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3315        }
3316        TapeOp::Cosh(a) => {
3317            let u = vals[*a];
3318            let gp = u.sinh();
3319            let gpp = vals[i]; // cosh(u)
3320            adj[*a] += w * gp;
3321            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3322        }
3323        TapeOp::Tanh(a) => {
3324            let t = vals[i];
3325            let gp = 1.0 - t * t;
3326            let gpp = -2.0 * t * gp;
3327            adj[*a] += w * gp;
3328            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3329        }
3330        TapeOp::Asin(a) => {
3331            let u = vals[*a];
3332            let s = 1.0 - u * u;
3333            let r = s.sqrt();
3334            let gp = 1.0 / r;
3335            let gpp = u / (s * r);
3336            adj[*a] += w * gp;
3337            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3338        }
3339        TapeOp::Acosh(a) => {
3340            let u = vals[*a];
3341            let s = u * u - 1.0;
3342            let r = s.sqrt();
3343            let gp = 1.0 / r;
3344            let gpp = -u / (s * r);
3345            adj[*a] += w * gp;
3346            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3347        }
3348        TapeOp::Asinh(a) => {
3349            let u = vals[*a];
3350            let s = u * u + 1.0;
3351            let r = s.sqrt();
3352            let gp = 1.0 / r;
3353            let gpp = -u / (s * r);
3354            adj[*a] += w * gp;
3355            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3356        }
3357        TapeOp::Atanh(a) => {
3358            let u = vals[*a];
3359            let d = 1.0 - u * u;
3360            let gp = 1.0 / d;
3361            let gpp = 2.0 * u / (d * d);
3362            adj[*a] += w * gp;
3363            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3364        }
3365        TapeOp::Atan2(a, b) => {
3366            let y = vals[*a];
3367            let x = vals[*b];
3368            let d = y * y + x * x;
3369            let d2 = d * d;
3370            let fa = x / d;
3371            let fb = -y / d;
3372            let faa = -2.0 * x * y / d2;
3373            let fab = (y * y - x * x) / d2;
3374            let fbb = 2.0 * x * y / d2;
3375            adj[*a] += w * fa;
3376            adj[*b] += w * fb;
3377            adj_dot[*a] += wd * fa + w * (faa * dot[*a] + fab * dot[*b]);
3378            adj_dot[*b] += wd * fb + w * (fab * dot[*a] + fbb * dot[*b]);
3379        }
3380        TapeOp::Cmp(_, _, _)
3381        | TapeOp::And(_, _)
3382        | TapeOp::Or(_, _)
3383        | TapeOp::Not(_)
3384        | TapeOp::Select(_, _, _)
3385        | TapeOp::Min(_, _)
3386        | TapeOp::Max(_, _) => panic!(
3387            "GlobalTape free-function kernels do not implement conditional / logical \
3388             / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
3389             instead."
3390        ),
3391        TapeOp::Funcall(fc) => {
3392            let FuncallData { lib, name, args } = fc.as_ref();
3393            let call_args = funcall_to_ext_args(args, vals);
3394            let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
3395                panic!("external function '{name}' 2nd-order eval failed: {e}")
3396            });
3397            let derivs = res.derivs.expect("want_derivs=true returns derivs");
3398            let hes = res.hessian.expect("want_hes=true returns hessian");
3399            let real_tape: Vec<usize> = args
3400                .iter()
3401                .filter_map(|a| match a {
3402                    TapeFuncallArg::Tape(t) => Some(*t),
3403                    TapeFuncallArg::Str(_) => None,
3404                })
3405                .collect();
3406            for (k, &tk) in real_tape.iter().enumerate() {
3407                adj[tk] += w * derivs[k];
3408                let mut second_term = 0.0;
3409                for (l, &tl) in real_tape.iter().enumerate() {
3410                    let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
3411                    let h_kl = hes[lo + hi * (hi + 1) / 2];
3412                    second_term += h_kl * dot[tl];
3413                }
3414                adj_dot[tk] += wd * derivs[k] + w * second_term;
3415            }
3416            let _ = seed_var;
3417            let _ = hess_map;
3418            let _ = values;
3419            let _ = weight;
3420            let _ = i;
3421        }
3422    }
3423}
3424
3425/// Per-op Hessian-sparsity propagation. Same algorithm as
3426/// `Tape::hessian_sparsity` but as a free function so `GlobalTape`
3427/// can call it over its shared `ops` slice.
3428fn hessian_sparsity_impl(ops: &[TapeOp]) -> BTreeSet<(usize, usize)> {
3429    let n = ops.len();
3430    let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
3431    let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
3432
3433    let emit_cross =
3434        |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
3435            for &v1 in s1 {
3436                for &v2 in s2 {
3437                    let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
3438                    pairs.insert((r, c));
3439                }
3440            }
3441        };
3442    let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
3443        let vars: Vec<usize> = s.iter().copied().collect();
3444        for (ai, &vi) in vars.iter().enumerate() {
3445            for &vj in &vars[..=ai] {
3446                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
3447                pairs.insert((r, c));
3448            }
3449        }
3450    };
3451
3452    for op in ops {
3453        let vset = match op {
3454            TapeOp::Const(_) => BTreeSet::new(),
3455            TapeOp::Var(j) => {
3456                let mut s = BTreeSet::new();
3457                s.insert(*j);
3458                s
3459            }
3460            TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
3461                var_sets[*a].union(&var_sets[*b]).copied().collect()
3462            }
3463            TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
3464            TapeOp::Mul(a, b) => {
3465                emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
3466                var_sets[*a].union(&var_sets[*b]).copied().collect()
3467            }
3468            TapeOp::Div(a, b) => {
3469                emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
3470                emit_self(&var_sets[*b], &mut pairs);
3471                var_sets[*a].union(&var_sets[*b]).copied().collect()
3472            }
3473            TapeOp::Pow(a, b) | TapeOp::Atan2(a, b) => {
3474                let combined: BTreeSet<usize> =
3475                    var_sets[*a].union(&var_sets[*b]).copied().collect();
3476                emit_self(&combined, &mut pairs);
3477                combined
3478            }
3479            TapeOp::Sqrt(a)
3480            | TapeOp::Exp(a)
3481            | TapeOp::Log(a)
3482            | TapeOp::Log10(a)
3483            | TapeOp::Sin(a)
3484            | TapeOp::Cos(a)
3485            | TapeOp::Tan(a)
3486            | TapeOp::Atan(a)
3487            | TapeOp::Acos(a)
3488            | TapeOp::Sinh(a)
3489            | TapeOp::Cosh(a)
3490            | TapeOp::Tanh(a)
3491            | TapeOp::Asin(a)
3492            | TapeOp::Acosh(a)
3493            | TapeOp::Asinh(a)
3494            | TapeOp::Atanh(a) => {
3495                emit_self(&var_sets[*a], &mut pairs);
3496                var_sets[*a].clone()
3497            }
3498            TapeOp::Funcall(fc) => {
3499                let args = &fc.args;
3500                let mut combined: BTreeSet<usize> = BTreeSet::new();
3501                for arg in args {
3502                    if let TapeFuncallArg::Tape(t) = arg {
3503                        for &vv in &var_sets[*t] {
3504                            combined.insert(vv);
3505                        }
3506                    }
3507                }
3508                emit_self(&combined, &mut pairs);
3509                combined
3510            }
3511            TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {
3512                // Comparisons / logical ops have identically-zero derivative, so
3513                // they contribute no Hessian structure.
3514                BTreeSet::new()
3515            }
3516            TapeOp::Select(_, t, e) => {
3517                // Either branch may be active; the structural superset is the
3518                // union of both branches' variable sets.
3519                var_sets[*t].union(&var_sets[*e]).copied().collect()
3520            }
3521            TapeOp::Min(a, b) | TapeOp::Max(a, b) => {
3522                // min/max are piecewise linear: zero second derivative (no
3523                // pairs); dependence set is the union of both operands.
3524                var_sets[*a].union(&var_sets[*b]).copied().collect()
3525            }
3526        };
3527        var_sets.push(vset);
3528    }
3529    pairs
3530}
3531
3532#[cfg(test)]
3533mod tests {
3534    use super::*;
3535
3536    fn cnst(c: f64) -> Expr {
3537        Expr::Const(c)
3538    }
3539    fn var(i: usize) -> Expr {
3540        Expr::Var(i)
3541    }
3542    fn add(a: Expr, b: Expr) -> Expr {
3543        Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
3544    }
3545    fn mul(a: Expr, b: Expr) -> Expr {
3546        Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
3547    }
3548    fn pow(a: Expr, b: Expr) -> Expr {
3549        Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
3550    }
3551    fn div(a: Expr, b: Expr) -> Expr {
3552        Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
3553    }
3554    fn unary(op: UnaryOp, a: Expr) -> Expr {
3555        Expr::Unary(op, Box::new(a))
3556    }
3557    fn cmp(op: CmpOp, a: Expr, b: Expr) -> Expr {
3558        Expr::Compare(op, Box::new(a), Box::new(b))
3559    }
3560    fn cond(c: Expr, t: Expr, e: Expr) -> Expr {
3561        Expr::Cond {
3562            cond: Box::new(c),
3563            then_: Box::new(t),
3564            else_: Box::new(e),
3565        }
3566    }
3567
3568    #[test]
3569    fn polynomial_eval_and_grad() {
3570        // f = 3*x0^2 + 2*x1
3571        let e = add(
3572            mul(cnst(3.0), pow(var(0), cnst(2.0))),
3573            mul(cnst(2.0), var(1)),
3574        );
3575        let t = Tape::build(&e);
3576        assert!((t.eval(&[2.0, 3.0]) - 18.0).abs() < 1e-12);
3577        let mut g = vec![0.0; 2];
3578        t.gradient_seed(&[2.0, 3.0], 1.0, &mut g);
3579        // df/dx0 = 6*x0 = 12, df/dx1 = 2
3580        assert!((g[0] - 12.0).abs() < 1e-12);
3581        assert!((g[1] - 2.0).abs() < 1e-12);
3582    }
3583
3584    #[test]
3585    fn cse_shared_body_evaluated_once() {
3586        // body = x0 + x1, shared via Arc. f = body^2 + body.
3587        let body = Arc::new(add(var(0), var(1)));
3588        let e = add(
3589            pow(Expr::Cse(body.clone()), cnst(2.0)),
3590            Expr::Cse(body.clone()),
3591        );
3592        let t = Tape::build(&e);
3593        // body should appear once in the tape: count Add(Var(0),Var(1)) ops
3594        let n_body_adds = t
3595            .ops
3596            .iter()
3597            .filter(|op| {
3598                matches!(op, TapeOp::Add(a, b) if {
3599                    matches!(t.ops[*a], TapeOp::Var(0)) && matches!(t.ops[*b], TapeOp::Var(1))
3600                })
3601            })
3602            .count();
3603        assert_eq!(n_body_adds, 1, "CSE body should be emitted exactly once");
3604
3605        // f(1, 2) = 9 + 3 = 12
3606        assert!((t.eval(&[1.0, 2.0]) - 12.0).abs() < 1e-12);
3607        let mut g = vec![0.0; 2];
3608        t.gradient_seed(&[1.0, 2.0], 1.0, &mut g);
3609        // df/dx0 = 2*(x0+x1) + 1 = 7, same for x1
3610        assert!((g[0] - 7.0).abs() < 1e-12);
3611        assert!((g[1] - 7.0).abs() < 1e-12);
3612    }
3613
3614    fn fd_check(tape: &Tape, x: &[f64], n: usize, tol: f64) {
3615        let vars = tape.variables();
3616        let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
3617        let mut pairs = Vec::new();
3618        for (ai, &vi) in vars.iter().enumerate() {
3619            for &vj in &vars[..=ai] {
3620                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
3621                hess_map.entry((r, c)).or_insert_with(|| {
3622                    let p = pairs.len();
3623                    pairs.push((r, c));
3624                    p
3625                });
3626            }
3627        }
3628        let nnz = pairs.len();
3629        let mut ad = vec![0.0; nnz];
3630        tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
3631
3632        let mut fd = vec![0.0; nnz];
3633        let mut xp = x.to_vec();
3634        let mut gp = vec![0.0; n];
3635        let mut gm = vec![0.0; n];
3636        for &j in &vars {
3637            let h = (1e-7_f64).max(x[j].abs() * 1e-7);
3638            xp[j] = x[j] + h;
3639            gp.iter_mut().for_each(|v| *v = 0.0);
3640            tape.gradient_seed(&xp, 1.0, &mut gp);
3641            xp[j] = x[j] - h;
3642            gm.iter_mut().for_each(|v| *v = 0.0);
3643            tape.gradient_seed(&xp, 1.0, &mut gm);
3644            xp[j] = x[j];
3645            for &i in &vars {
3646                if i >= j {
3647                    if let Some(&pos) = hess_map.get(&(i, j)) {
3648                        fd[pos] = (gp[i] - gm[i]) / (2.0 * h);
3649                    }
3650                }
3651            }
3652        }
3653        for (k, &(r, c)) in pairs.iter().enumerate() {
3654            let scale = fd[k].abs().max(1.0);
3655            assert!(
3656                (ad[k] - fd[k]).abs() / scale < tol,
3657                "H[{},{}]: AD={:.6e} FD={:.6e}",
3658                r,
3659                c,
3660                ad[k],
3661                fd[k]
3662            );
3663        }
3664    }
3665
3666    #[test]
3667    fn hessian_quadratic_matches_fd() {
3668        // f = 3 x0^2 + 2 x0 x1 + x1^2
3669        let e = add(
3670            add(
3671                mul(cnst(3.0), pow(var(0), cnst(2.0))),
3672                mul(cnst(2.0), mul(var(0), var(1))),
3673            ),
3674            pow(var(1), cnst(2.0)),
3675        );
3676        let t = Tape::build(&e);
3677        fd_check(&t, &[2.0, 3.0], 2, 1e-5);
3678    }
3679
3680    #[test]
3681    fn hessian_transcendental_matches_fd() {
3682        // f = exp(x0) + sin(x1) + log(x0) + sqrt(x1) + x0*x1
3683        let e = Expr::Sum(vec![
3684            unary(UnaryOp::Exp, var(0)),
3685            unary(UnaryOp::Sin, var(1)),
3686            unary(UnaryOp::Log, var(0)),
3687            unary(UnaryOp::Sqrt, var(1)),
3688            mul(var(0), var(1)),
3689        ]);
3690        let t = Tape::build(&e);
3691        fd_check(&t, &[1.5, 2.0], 2, 1e-5);
3692    }
3693
3694    #[test]
3695    fn inverse_trig_grad_and_hessian_match_fd() {
3696        // f = tan(x0) + atan(x1) + acos(x2) + x0*x1
3697        // Point chosen so every op is in its smooth domain:
3698        // tan away from pi/2, acos arg in (-1, 1).
3699        let e = Expr::Sum(vec![
3700            unary(UnaryOp::Tan, var(0)),
3701            unary(UnaryOp::Atan, var(1)),
3702            unary(UnaryOp::Acos, var(2)),
3703            mul(var(0), var(1)),
3704        ]);
3705        let t = Tape::build(&e);
3706        let x = [0.5, 1.3, 0.3];
3707
3708        // Gradient vs central finite difference of the value. This
3709        // pins the first derivatives independently of the Hessian
3710        // (fd_check only ties the Hessian to the AD gradient).
3711        let mut g = vec![0.0; 3];
3712        t.gradient_seed(&x, 1.0, &mut g);
3713        for j in 0..3 {
3714            let h = (1e-7_f64).max(x[j].abs() * 1e-7);
3715            let mut xp = x;
3716            let mut xm = x;
3717            xp[j] += h;
3718            xm[j] -= h;
3719            let fd = (t.eval(&xp) - t.eval(&xm)) / (2.0 * h);
3720            let scale = fd.abs().max(1.0);
3721            assert!(
3722                (g[j] - fd).abs() / scale < 1e-5,
3723                "grad[{j}]: AD={:.6e} FD={:.6e}",
3724                g[j],
3725                fd
3726            );
3727        }
3728
3729        // Hessian (forward-over-reverse) vs FD of the gradient.
3730        fd_check(&t, &x, 3, 1e-5);
3731    }
3732
3733    /// Shared helper: check AD gradient vs central FD of the value at
3734    /// `x`, then the Hessian via `fd_check`.
3735    fn grad_and_hess_match_fd(e: &Expr, x: &[f64], tol: f64) {
3736        let n = x.len();
3737        let t = Tape::build(e);
3738        let mut g = vec![0.0; n];
3739        t.gradient_seed(x, 1.0, &mut g);
3740        for j in 0..n {
3741            let h = (1e-7_f64).max(x[j].abs() * 1e-7);
3742            let mut xp = x.to_vec();
3743            let mut xm = x.to_vec();
3744            xp[j] += h;
3745            xm[j] -= h;
3746            let fd = (t.eval(&xp) - t.eval(&xm)) / (2.0 * h);
3747            let scale = fd.abs().max(1.0);
3748            assert!(
3749                (g[j] - fd).abs() / scale < tol,
3750                "grad[{j}]: AD={:.6e} FD={:.6e}",
3751                g[j],
3752                fd
3753            );
3754        }
3755        fd_check(&t, x, n, tol);
3756    }
3757
3758    #[test]
3759    fn hyperbolic_grad_and_hessian_match_fd() {
3760        // f = sinh(x0) + cosh(x1) + tanh(x2) + asinh(x3) + x0*x1 + x2*x3
3761        // sinh/cosh/tanh/asinh are smooth on all of R.
3762        let e = Expr::Sum(vec![
3763            unary(UnaryOp::Sinh, var(0)),
3764            unary(UnaryOp::Cosh, var(1)),
3765            unary(UnaryOp::Tanh, var(2)),
3766            unary(UnaryOp::Asinh, var(3)),
3767            mul(var(0), var(1)),
3768            mul(var(2), var(3)),
3769        ]);
3770        grad_and_hess_match_fd(&e, &[0.5, 0.7, 0.3, 1.1], 1e-5);
3771    }
3772
3773    #[test]
3774    fn restricted_inverse_grad_and_hessian_match_fd() {
3775        // f = asin(x0) + acosh(x1) + atanh(x2) + x0*x2
3776        // Point chosen in each op's smooth domain:
3777        // asin/atanh need |arg| < 1; acosh needs arg > 1.
3778        let e = Expr::Sum(vec![
3779            unary(UnaryOp::Asin, var(0)),
3780            unary(UnaryOp::Acosh, var(1)),
3781            unary(UnaryOp::Atanh, var(2)),
3782            mul(var(0), var(2)),
3783        ]);
3784        grad_and_hess_match_fd(&e, &[0.4, 1.8, 0.3], 1e-5);
3785    }
3786
3787    #[test]
3788    fn atan2_grad_and_hessian_match_fd() {
3789        // f = atan2(x0, x1) + x0*x1, away from the origin.
3790        let atan2 = |a: Expr, b: Expr| Expr::Binary(BinOp::Atan2, Box::new(a), Box::new(b));
3791        let e = Expr::Sum(vec![atan2(var(0), var(1)), mul(var(0), var(1))]);
3792        grad_and_hess_match_fd(&e, &[1.2, 0.7], 1e-5);
3793    }
3794
3795    #[test]
3796    fn minmax_grad_and_hessian_match_fd() {
3797        // f = min(x0, x1, x2) + max(x1, x2) + x0*x2
3798        // Point chosen so each list has a UNIQUE strictly-active
3799        // operand, so the subgradient equals the FD slope (the ±h
3800        // probes never cross a kink):
3801        //   min(0.5, 3.0, 2.0) = 0.5  -> active x0
3802        //   max(3.0, 2.0)      = 3.0  -> active x1
3803        let e = Expr::Sum(vec![
3804            Expr::MinList(vec![var(0), var(1), var(2)]),
3805            Expr::MaxList(vec![var(1), var(2)]),
3806            mul(var(0), var(2)),
3807        ]);
3808        grad_and_hess_match_fd(&e, &[0.5, 3.0, 2.0], 1e-5);
3809    }
3810
3811    #[test]
3812    fn minmax_value_and_active_operand() {
3813        // Spot-check the value and that the gradient routes entirely
3814        // through the active operand (zero second derivative).
3815        let e = Expr::Sum(vec![
3816            Expr::MinList(vec![var(0), var(1)]),
3817            Expr::MaxList(vec![var(0), var(1)]),
3818        ]);
3819        let t = Tape::build(&e);
3820        // min(x0,x1) + max(x0,x1) == x0 + x1 for any inputs.
3821        let x = [1.3, -0.4];
3822        assert!((t.eval(&x) - (x[0] + x[1])).abs() < 1e-12);
3823        let mut g = vec![0.0; 2];
3824        t.gradient_seed(&x, 1.0, &mut g);
3825        // min active = x1 (smaller), max active = x0 (larger):
3826        // d/dx0 = 1 (from max), d/dx1 = 1 (from min).
3827        assert!((g[0] - 1.0).abs() < 1e-12, "g0={}", g[0]);
3828        assert!((g[1] - 1.0).abs() < 1e-12, "g1={}", g[1]);
3829    }
3830
3831    #[test]
3832    fn hessian_division_matches_fd() {
3833        // f = x0/x1 + cos(x0)
3834        let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
3835        let t = Tape::build(&e);
3836        fd_check(&t, &[0.5, 1.2], 2, 1e-5);
3837    }
3838
3839    #[test]
3840    fn conditional_value_grad_hessian_active_branch() {
3841        // f = if x0 >= 1 then x0*x1 else x1^2
3842        // The if-then-else differentiates only the active branch; the
3843        // condition (a comparison) contributes no derivative.
3844        let e = cond(
3845            cmp(CmpOp::Ge, var(0), cnst(1.0)),
3846            mul(var(0), var(1)),
3847            pow(var(1), cnst(2.0)),
3848        );
3849        let t = Tape::build(&e);
3850
3851        // x0 = 2 (>= 1) -> "then" branch x0*x1 is active.
3852        let x = [2.0, 5.0];
3853        assert!((t.eval(&x) - 10.0).abs() < 1e-12);
3854        let mut g = vec![0.0; 2];
3855        t.gradient_seed(&x, 1.0, &mut g);
3856        // d(x0*x1) = (x1, x0) = (5, 2)
3857        assert!((g[0] - 5.0).abs() < 1e-10);
3858        assert!((g[1] - 2.0).abs() < 1e-10);
3859        // H[0,1] = 1, diagonals 0. (Stay clear of the x0 = 1 kink.)
3860        fd_check(&t, &x, 2, 1e-5);
3861
3862        // x0 = 0 (< 1) -> "else" branch x1^2 is active; x0 drops out.
3863        let x2 = [0.0, 5.0];
3864        assert!((t.eval(&x2) - 25.0).abs() < 1e-12);
3865        let mut g2 = vec![0.0; 2];
3866        t.gradient_seed(&x2, 1.0, &mut g2);
3867        assert!(g2[0].abs() < 1e-10);
3868        assert!((g2[1] - 10.0).abs() < 1e-10);
3869        fd_check(&t, &x2, 2, 1e-5);
3870    }
3871
3872    #[test]
3873    fn comparison_and_logical_have_zero_derivative() {
3874        // f = (x0 < x1) + (x0 > 0 && x1 > 0) + !(x0 == x1)
3875        // Every term is piecewise-constant in the variables, so the
3876        // gradient must be identically zero away from the kinks.
3877        let lt = cmp(CmpOp::Lt, var(0), var(1));
3878        let and = Expr::And(
3879            Box::new(cmp(CmpOp::Gt, var(0), cnst(0.0))),
3880            Box::new(cmp(CmpOp::Gt, var(1), cnst(0.0))),
3881        );
3882        let notc = Expr::Not(Box::new(cmp(CmpOp::Eq, var(0), var(1))));
3883        let e = add(add(lt, and), notc);
3884        let t = Tape::build(&e);
3885
3886        let x = [1.0, 2.0];
3887        // 1 (1<2) + 1 (both > 0) + 1 (1 != 2) = 3
3888        assert!((t.eval(&x) - 3.0).abs() < 1e-12);
3889        let mut g = vec![0.0; 2];
3890        t.gradient_seed(&x, 1.0, &mut g);
3891        assert!(g[0].abs() < 1e-12, "d/dx0 should be 0, got {}", g[0]);
3892        assert!(g[1].abs() < 1e-12, "d/dx1 should be 0, got {}", g[1]);
3893    }
3894
3895    #[test]
3896    fn logical_or_value() {
3897        // f = (x0 > 0 || x1 > 0)
3898        let e = Expr::Or(
3899            Box::new(cmp(CmpOp::Gt, var(0), cnst(0.0))),
3900            Box::new(cmp(CmpOp::Gt, var(1), cnst(0.0))),
3901        );
3902        let t = Tape::build(&e);
3903        assert!((t.eval(&[-1.0, 3.0]) - 1.0).abs() < 1e-12);
3904        assert!((t.eval(&[-1.0, -3.0]) - 0.0).abs() < 1e-12);
3905    }
3906
3907    /// `hessian_directional` (one forward-over-reverse pass with
3908    /// a seed vector) recovers `H · e_j` for each unit-vector seed,
3909    /// matching column `j` of the dense Hessian computed by
3910    /// `hessian_accumulate`.
3911    fn directional_matches_accumulate(tape: &Tape, x: &[f64], n: usize) {
3912        let vars = tape.variables();
3913        let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
3914        let mut pairs = Vec::new();
3915        for (ai, &vi) in vars.iter().enumerate() {
3916            for &vj in &vars[..=ai] {
3917                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
3918                hess_map.entry((r, c)).or_insert_with(|| {
3919                    let p = pairs.len();
3920                    pairs.push((r, c));
3921                    p
3922                });
3923            }
3924        }
3925        let nnz = pairs.len();
3926        let mut ad = vec![0.0; nnz];
3927        tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
3928
3929        let nops = tape.ops.len();
3930        let mut vals = vec![0.0; nops];
3931        tape.forward_into(x, &mut vals);
3932        let mut dot = vec![0.0; nops];
3933        let mut adj = vec![0.0; nops];
3934        let mut adj_dot = vec![0.0; nops];
3935
3936        for &j in &vars {
3937            let mut seed = vec![0.0; n];
3938            seed[j] = 1.0;
3939            let mut col = vec![0.0; n];
3940            tape.hessian_directional(
3941                &vals,
3942                &seed,
3943                1.0,
3944                &mut col,
3945                &mut dot,
3946                &mut adj,
3947                &mut adj_dot,
3948            );
3949            for &i in &vars {
3950                let (r, c) = if i >= j { (i, j) } else { (j, i) };
3951                let expect = ad[hess_map[&(r, c)]];
3952                assert!(
3953                    (col[i] - expect).abs() < 1e-10,
3954                    "directional H[{i},{j}] = {} vs accumulate {}",
3955                    col[i],
3956                    expect
3957                );
3958            }
3959        }
3960    }
3961
3962    #[test]
3963    fn directional_quadratic_matches_accumulate() {
3964        // f = 3 x0^2 + 2 x0 x1 + x1^2
3965        let e = add(
3966            add(
3967                mul(cnst(3.0), pow(var(0), cnst(2.0))),
3968                mul(mul(cnst(2.0), var(0)), var(1)),
3969            ),
3970            pow(var(1), cnst(2.0)),
3971        );
3972        let t = Tape::build(&e);
3973        directional_matches_accumulate(&t, &[0.5, -0.3], 2);
3974    }
3975
3976    #[test]
3977    fn directional_transcendental_matches_accumulate() {
3978        let e = Expr::Sum(vec![
3979            unary(UnaryOp::Exp, var(0)),
3980            unary(UnaryOp::Sin, var(1)),
3981            unary(UnaryOp::Log, var(0)),
3982            unary(UnaryOp::Sqrt, var(1)),
3983            mul(var(0), var(1)),
3984        ]);
3985        let t = Tape::build(&e);
3986        directional_matches_accumulate(&t, &[1.5, 2.0], 2);
3987    }
3988
3989    #[test]
3990    fn directional_with_division_matches_accumulate() {
3991        let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
3992        let t = Tape::build(&e);
3993        directional_matches_accumulate(&t, &[0.5, 1.2], 2);
3994    }
3995
3996    #[test]
3997    fn hessian_sparsity_separable() {
3998        // f = sin(x0) + x1*x2; couplings: (0,0) from sin, (2,1) from x1*x2
3999        let e = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
4000        let t = Tape::build(&e);
4001        let s = t.hessian_sparsity();
4002        assert!(s.contains(&(0, 0)));
4003        assert!(s.contains(&(2, 1)));
4004        assert!(!s.contains(&(1, 0)));
4005        assert!(!s.contains(&(2, 0)));
4006    }
4007
4008    fn count_op<F: Fn(&TapeOp) -> bool>(t: &Tape, pred: F) -> usize {
4009        t.ops.iter().filter(|o| pred(o)).count()
4010    }
4011
4012    #[test]
4013    fn pow_zero_const_folds_to_one() {
4014        // x^0 → 1 (no Pow, no reference to x in the tape)
4015        let e = pow(var(0), cnst(0.0));
4016        let t = Tape::build(&e);
4017        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
4018        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Var(_))), 0);
4019        assert!((t.eval(&[7.0]) - 1.0).abs() < 1e-12);
4020    }
4021
4022    #[test]
4023    fn pow_one_passes_through() {
4024        // x^1 → x (no Pow, no Const introduced for the exponent)
4025        let e = pow(var(0), cnst(1.0));
4026        let t = Tape::build(&e);
4027        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
4028        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Const(_))), 0);
4029        assert!((t.eval(&[3.5]) - 3.5).abs() < 1e-12);
4030    }
4031
4032    #[test]
4033    fn pow_half_lowers_to_sqrt() {
4034        let e = pow(var(0), cnst(0.5));
4035        let t = Tape::build(&e);
4036        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
4037        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Sqrt(_))), 1);
4038        assert!((t.eval(&[16.0]) - 4.0).abs() < 1e-12);
4039    }
4040
4041    #[test]
4042    fn pow_two_lowers_to_single_mul() {
4043        let e = pow(var(0), cnst(2.0));
4044        let t = Tape::build(&e);
4045        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
4046        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
4047        assert!((t.eval(&[3.0]) - 9.0).abs() < 1e-12);
4048    }
4049
4050    #[test]
4051    fn pow_three_lowers_to_two_muls() {
4052        let e = pow(var(0), cnst(3.0));
4053        let t = Tape::build(&e);
4054        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
4055        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 2);
4056        assert!((t.eval(&[2.0]) - 8.0).abs() < 1e-12);
4057    }
4058
4059    #[test]
4060    fn pow_eight_lowers_to_three_muls() {
4061        // Binary expansion: x → x² → x⁴ → x⁸ (3 squarings)
4062        let e = pow(var(0), cnst(8.0));
4063        let t = Tape::build(&e);
4064        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
4065        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 3);
4066        assert!((t.eval(&[2.0]) - 256.0).abs() < 1e-12);
4067    }
4068
4069    #[test]
4070    fn pow_negative_two_lowers_to_div() {
4071        // x^-2 → 1 / (x*x)
4072        let e = pow(var(0), cnst(-2.0));
4073        let t = Tape::build(&e);
4074        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
4075        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Div(..))), 1);
4076        assert!((t.eval(&[4.0]) - (1.0 / 16.0)).abs() < 1e-12);
4077    }
4078
4079    #[test]
4080    fn pow_large_const_stays_generic() {
4081        // x^9 stays as Pow — beyond the cutoff, generic is cheaper.
4082        let e = pow(var(0), cnst(9.0));
4083        let t = Tape::build(&e);
4084        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
4085    }
4086
4087    #[test]
4088    fn pow_non_integer_const_stays_generic() {
4089        // x^1.5 stays as Pow until half-integer handling is added.
4090        let e = pow(var(0), cnst(1.5));
4091        let t = Tape::build(&e);
4092        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
4093    }
4094
4095    #[test]
4096    fn pow_const_through_cse_const() {
4097        // Exponent wrapped in Cse — peek_const should still see it.
4098        let two = Arc::new(cnst(2.0));
4099        let e = Expr::Binary(BinOp::Pow, Box::new(var(0)), Box::new(Expr::Cse(two)));
4100        let t = Tape::build(&e);
4101        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
4102        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
4103    }
4104
4105    #[test]
4106    fn hessian_pow_three_matches_fd() {
4107        // f = 5 * x0^3 + x0 * x1 — exercises the lowered cubic + cross term.
4108        let e = add(mul(cnst(5.0), pow(var(0), cnst(3.0))), mul(var(0), var(1)));
4109        let t = Tape::build(&e);
4110        fd_check(&t, &[1.7, 0.8], 2, 1e-5);
4111    }
4112
4113    #[test]
4114    fn hessian_pow_negative_matches_fd() {
4115        // f = 1/x0^2 + x1^2 — exercises lowered x^-2 and x^2.
4116        let e = add(pow(var(0), cnst(-2.0)), pow(var(1), cnst(2.0)));
4117        let t = Tape::build(&e);
4118        fd_check(&t, &[1.3, 2.4], 2, 1e-5);
4119    }
4120
4121    #[test]
4122    fn hessian_pow_half_matches_fd() {
4123        // f = sqrt(x0) + x0*x1 (via Pow(_, 0.5) → Sqrt)
4124        let e = add(pow(var(0), cnst(0.5)), mul(var(0), var(1)));
4125        let t = Tape::build(&e);
4126        fd_check(&t, &[2.5, 1.1], 2, 1e-5);
4127    }
4128
4129    #[test]
4130    fn hessian_sparsity_through_cse() {
4131        // body = x0+x1 (CSE). f = body^2 + body.
4132        // d²/dx² of body^2 couples (0,0), (1,0), (1,1).
4133        let body = Arc::new(add(var(0), var(1)));
4134        let e = add(
4135            pow(Expr::Cse(body.clone()), cnst(2.0)),
4136            Expr::Cse(body.clone()),
4137        );
4138        let t = Tape::build(&e);
4139        let s = t.hessian_sparsity();
4140        assert!(s.contains(&(0, 0)));
4141        assert!(s.contains(&(1, 0)));
4142        assert!(s.contains(&(1, 1)));
4143        assert_eq!(s.len(), 3);
4144    }
4145
4146    #[test]
4147    fn pow_forward_tangent_matches_reverse_gradient_at_base_zero() {
4148        // Code review L29: `Pow` first-order tangent disagreed with the
4149        // reverse-mode gradient at base 0. f = x0 ^ x1 keeps a genuine `Pow`
4150        // op (variable exponent is not lowered to a Mul/Sqrt chain). At the
4151        // `.nl` default start x0 = 0, the base derivative d/dx0 (x0^1) = 1 is
4152        // well defined; reverse mode has always computed it, but the forward
4153        // tangent used to guard on `u != 0` and drop it, so Jacobian-vector
4154        // products silently disagreed with the gradient at x = 0. After the
4155        // fix both arms must agree.
4156        let e = pow(var(0), var(1));
4157        let t = Tape::build(&e);
4158        // Guard: the op must survive as a real Pow (not lowered away), else
4159        // this test would no longer exercise the fixed branch.
4160        assert!(
4161            t.ops.iter().any(|op| matches!(op, TapeOp::Pow(_, _))),
4162            "expected a Pow op in the tape; got {:?}",
4163            t.ops
4164        );
4165        let x = [0.0, 1.0];
4166        let n = t.ops.len();
4167
4168        // Reverse-mode gradient w.r.t. x0.
4169        let mut grad = vec![0.0; 2];
4170        t.gradient_seed(&x, 1.0, &mut grad);
4171
4172        // Forward tangent seeded on x0: dot[output] = df/dx0.
4173        let vals = t.forward(&x);
4174        let mut dot = vec![0.0; n];
4175        t.forward_tangent(&vals, 0, &mut dot);
4176        let fwd_dfx0 = dot[n - 1];
4177
4178        assert!(
4179            (grad[0] - 1.0).abs() < 1e-12,
4180            "reverse gradient df/dx0 at base 0 should be 1, got {}",
4181            grad[0]
4182        );
4183        assert!(
4184            (fwd_dfx0 - grad[0]).abs() < 1e-12,
4185            "forward tangent df/dx0 = {fwd_dfx0} must match reverse gradient {} at base 0",
4186            grad[0]
4187        );
4188    }
4189
4190    #[test]
4191    #[should_panic(expected = "external function calls are not supported on the")]
4192    fn hybrid_promoted_cse_with_funcall_reports_clear_message() {
4193        // Code review L34: `HybridTape::build_multi` builds a promoted CSE
4194        // (one shared across ≥2 summands) via `build_recursive` with an empty
4195        // resolver. A funcall inside that promoted body used to panic with the
4196        // misleading "unresolved AMPL funcall id 0" — implying a resolution
4197        // failure — instead of the real reason: funcalls are unsupported on
4198        // the hybrid path. Here the funcall body is shared across two roots so
4199        // it is promoted; assert the clear hybrid-unsupported message fires.
4200        let body = Arc::new(Expr::Funcall {
4201            id: 0,
4202            args: vec![FuncallArg::Real(var(0))],
4203        });
4204        let exprs = vec![
4205            add(Expr::Cse(body.clone()), cnst(1.0)),
4206            add(Expr::Cse(body.clone()), cnst(2.0)),
4207        ];
4208        HybridTape::build_multi(&exprs);
4209    }
4210}