Skip to main content

pounce_cli/
nl_tape.rs

1//! Flat-tape reverse-mode AD for `.nl` expression trees.
2//!
3//! Replaces the FD-based Hessian path with a port of the tape AD used
4//! in `ripopt::nl::autodiff`. The tape is a `Vec<TapeOp>` where each op
5//! refers to its operands by tape-slot index; forward evaluation runs
6//! through the slice once filling a parallel `Vec<f64>` of values, and
7//! reverse-mode adjoints walk the same buffer backwards.
8//!
9//! Sparse Hessians are computed by forward-over-reverse: for each
10//! variable `j` that the tape depends on, run a forward tangent sweep
11//! seeded with `e_j`, then a second-order reverse sweep that produces
12//! column `j` of the Hessian. The caller supplies a `(row, col) -> nnz
13//! position` map (lower triangle, row >= col), and contributions are
14//! accumulated in place — the outer loop in `eval_h` calls the same
15//! map for the objective and every active constraint, so every
16//! Lagrangian term lands in the right slot.
17//!
18//! Common subexpressions are tape-emitted **once**: when the recursive
19//! builder hits `Expr::Cse(rc)` it keys on the `Rc` pointer identity,
20//! emitting the body the first time and returning the cached
21//! result-slot index on subsequent references. The forward pass then
22//! computes each CSE once and the reverse pass folds adjoints from
23//! every reference into a single slot — exact chain-rule behaviour.
24
25use std::collections::{BTreeSet, HashMap, HashSet};
26use std::rc::Rc;
27use std::sync::Arc;
28
29use super::nl_external::{ExternalArg, ExternalLibrary, ExternalResolver};
30use super::nl_reader::{BinOp, Expr, FuncallArg, UnaryOp};
31
32/// One operation in the flattened tape. Operand fields are tape-slot
33/// indices into the same tape; `Var(i)` references problem variable
34/// index `i` (read from the input `x` slice during forward).
35#[derive(Debug, Clone)]
36pub enum TapeOp {
37    Const(f64),
38    Var(usize),
39    Add(usize, usize),
40    Sub(usize, usize),
41    Mul(usize, usize),
42    Div(usize, usize),
43    Pow(usize, usize),
44    Neg(usize),
45    Abs(usize),
46    Sqrt(usize),
47    Exp(usize),
48    Log(usize),
49    Log10(usize),
50    Sin(usize),
51    Cos(usize),
52    /// AMPL imported (external) function call. The library is kept alive by
53    /// the `Arc`; `name` is the registered function name; `args` carries
54    /// positional arguments where real-valued args reference earlier tape
55    /// slots and string args are inline literals.
56    Funcall {
57        lib: Arc<ExternalLibrary>,
58        name: String,
59        args: Vec<TapeFuncallArg>,
60    },
61}
62
/// Argument passed to a `TapeOp::Funcall`.
///
/// Real-valued arguments are not stored by value. They name the tape slot
/// whose forward value is read when the call is made. String arguments
/// are carried as owned literals (the AMPL `h<len>:<chars>` tokens).
#[derive(Debug, Clone)]
pub enum TapeFuncallArg {
    /// Index of the tape slot that supplies a real argument.
    Tape(usize),
    /// Literal string argument.
    Str(String),
}
71
72fn funcall_to_ext_args<'a>(args: &'a [TapeFuncallArg], vals: &[f64]) -> Vec<ExternalArg<'a>> {
73    args.iter()
74        .map(|a| match a {
75            TapeFuncallArg::Tape(idx) => ExternalArg::Real(vals[*idx]),
76            TapeFuncallArg::Str(s) => ExternalArg::Str(s.as_str()),
77        })
78        .collect()
79}
80
81/// A flattened expression tape. The result of evaluation is the value
82/// at slot `ops.len() - 1` (i.e. the last op).
83#[derive(Debug, Clone)]
84pub struct Tape {
85    pub ops: Vec<TapeOp>,
86}
87
88impl Tape {
89    /// Build a tape from an `Expr` tree (no AMPL external functions). CSE
90    /// bodies (`Expr::Cse(rc)`) are cached by `Rc` pointer identity so each
91    /// body is emitted once even when referenced many times.
92    pub fn build(expr: &Expr) -> Self {
93        Self::build_with_externals(expr, &ExternalResolver::default())
94    }
95
96    /// Build a tape from an `Expr` tree, resolving any `Expr::Funcall`
97    /// nodes through `resolver`. Panics if the expression references a
98    /// funcall id that is not in the resolver — `NlProblem::resolve_externals`
99    /// must populate the resolver before tape construction.
100    pub fn build_with_externals(expr: &Expr, resolver: &ExternalResolver) -> Self {
101        let mut ops = Vec::new();
102        let mut cache: HashMap<*const Expr, usize> = HashMap::new();
103        build_recursive(expr, &mut ops, &mut cache, resolver);
104        Tape { ops }
105    }
106
107    /// Forward sweep: returns `vals[i] = value of tape slot i`. The
108    /// scalar tape result is `vals[ops.len() - 1]`.
109    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
110        let mut vals: Vec<f64> = Vec::with_capacity(self.ops.len());
111        for op in &self.ops {
112            let v = match op {
113                TapeOp::Const(c) => *c,
114                TapeOp::Var(i) => x[*i],
115                TapeOp::Add(a, b) => vals[*a] + vals[*b],
116                TapeOp::Sub(a, b) => vals[*a] - vals[*b],
117                TapeOp::Mul(a, b) => vals[*a] * vals[*b],
118                TapeOp::Div(a, b) => vals[*a] / vals[*b],
119                TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
120                TapeOp::Neg(a) => -vals[*a],
121                TapeOp::Abs(a) => vals[*a].abs(),
122                TapeOp::Sqrt(a) => vals[*a].sqrt(),
123                TapeOp::Exp(a) => vals[*a].exp(),
124                TapeOp::Log(a) => vals[*a].ln(),
125                TapeOp::Log10(a) => vals[*a].log10(),
126                TapeOp::Sin(a) => vals[*a].sin(),
127                TapeOp::Cos(a) => vals[*a].cos(),
128                TapeOp::Funcall { lib, name, args } => {
129                    let call_args = funcall_to_ext_args(args, &vals);
130                    let res = lib
131                        .eval(name, &call_args, false, false)
132                        .unwrap_or_else(|e| {
133                            panic!("external function '{name}' forward eval failed: {e}")
134                        });
135                    res.value
136                }
137            };
138            vals.push(v);
139        }
140        vals
141    }
142
143    pub fn eval(&self, x: &[f64]) -> f64 {
144        let vals = self.forward(x);
145        *vals.last().unwrap_or(&0.0)
146    }
147
148    /// Reverse-mode AD: accumulate `seed * df/dx_i` into `grad[i]` for
149    /// every problem variable `i` referenced by the tape. `grad` is
150    /// **not** zeroed by this routine — the caller can chain multiple
151    /// gradient accumulations into the same buffer.
152    pub fn gradient_seed(&self, x: &[f64], seed: f64, grad: &mut [f64]) {
153        if seed == 0.0 || self.ops.is_empty() {
154            return;
155        }
156        let vals = self.forward(x);
157        self.reverse(&vals, seed, grad);
158    }
159
160    fn reverse(&self, vals: &[f64], seed: f64, grad: &mut [f64]) {
161        let n = self.ops.len();
162        let mut adj = vec![0.0f64; n];
163        adj[n - 1] = seed;
164
165        for i in (0..n).rev() {
166            let a = adj[i];
167            if a == 0.0 {
168                continue;
169            }
170            match &self.ops[i] {
171                TapeOp::Const(_) => {}
172                TapeOp::Var(j) => {
173                    grad[*j] += a;
174                }
175                TapeOp::Add(l, r) => {
176                    adj[*l] += a;
177                    adj[*r] += a;
178                }
179                TapeOp::Sub(l, r) => {
180                    adj[*l] += a;
181                    adj[*r] -= a;
182                }
183                TapeOp::Mul(l, r) => {
184                    adj[*l] += a * vals[*r];
185                    adj[*r] += a * vals[*l];
186                }
187                TapeOp::Div(l, r) => {
188                    let rv = vals[*r];
189                    adj[*l] += a / rv;
190                    adj[*r] -= a * vals[*l] / (rv * rv);
191                }
192                TapeOp::Pow(l, r) => {
193                    let lv = vals[*l];
194                    let rv = vals[*r];
195                    if rv != 0.0 {
196                        adj[*l] += a * rv * lv.powf(rv - 1.0);
197                    }
198                    if lv > 0.0 {
199                        adj[*r] += a * vals[i] * lv.ln();
200                    }
201                }
202                TapeOp::Neg(j) => {
203                    adj[*j] -= a;
204                }
205                TapeOp::Abs(j) => {
206                    if vals[*j] >= 0.0 {
207                        adj[*j] += a;
208                    } else {
209                        adj[*j] -= a;
210                    }
211                }
212                TapeOp::Sqrt(j) => {
213                    let sv = vals[i];
214                    if sv > 0.0 {
215                        adj[*j] += a * 0.5 / sv;
216                    }
217                }
218                TapeOp::Exp(j) => {
219                    adj[*j] += a * vals[i];
220                }
221                TapeOp::Log(j) => {
222                    adj[*j] += a / vals[*j];
223                }
224                TapeOp::Log10(j) => {
225                    adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
226                }
227                TapeOp::Sin(j) => {
228                    adj[*j] += a * vals[*j].cos();
229                }
230                TapeOp::Cos(j) => {
231                    adj[*j] -= a * vals[*j].sin();
232                }
233                TapeOp::Funcall { lib, name, args } => {
234                    let call_args = funcall_to_ext_args(args, vals);
235                    let res = lib.eval(name, &call_args, true, false).unwrap_or_else(|e| {
236                        panic!("external function '{name}' reverse eval failed: {e}")
237                    });
238                    let derivs = res.derivs.expect("want_derivs=true returns derivs");
239                    let mut k = 0usize;
240                    for arg in args {
241                        if let TapeFuncallArg::Tape(idx) = arg {
242                            adj[*idx] += a * derivs[k];
243                            k += 1;
244                        }
245                    }
246                }
247            }
248        }
249    }
250
251    /// Sorted distinct problem-variable indices that the tape depends on.
252    pub fn variables(&self) -> Vec<usize> {
253        let mut s: BTreeSet<usize> = BTreeSet::new();
254        for op in &self.ops {
255            if let TapeOp::Var(j) = op {
256                s.insert(*j);
257            }
258        }
259        s.into_iter().collect()
260    }
261
262    /// Forward tangent sweep: `dot[i] = d(slot_i) / dx_{seed_var}`.
263    /// Caller-supplied `dot` buffer is overwritten in full; no zeroing
264    /// needed beforehand because every slot is written before it is
265    /// read (the loop walks forward and only reads earlier slots).
266    fn forward_tangent(&self, vals: &[f64], seed_var: usize, dot: &mut [f64]) {
267        let n = self.ops.len();
268        debug_assert_eq!(dot.len(), n);
269        for i in 0..n {
270            dot[i] = match &self.ops[i] {
271                TapeOp::Const(_) => 0.0,
272                TapeOp::Var(k) => {
273                    if *k == seed_var {
274                        1.0
275                    } else {
276                        0.0
277                    }
278                }
279                TapeOp::Add(a, b) => dot[*a] + dot[*b],
280                TapeOp::Sub(a, b) => dot[*a] - dot[*b],
281                TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
282                TapeOp::Div(a, b) => {
283                    let vb = vals[*b];
284                    (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
285                }
286                TapeOp::Pow(a, b) => {
287                    let u = vals[*a];
288                    let r = vals[*b];
289                    let du = dot[*a];
290                    let dr = dot[*b];
291                    let mut result = 0.0;
292                    if r != 0.0 && u != 0.0 {
293                        result += r * u.powf(r - 1.0) * du;
294                    }
295                    if u > 0.0 {
296                        result += vals[i] * u.ln() * dr;
297                    }
298                    result
299                }
300                TapeOp::Neg(a) => -dot[*a],
301                TapeOp::Abs(a) => {
302                    if vals[*a] >= 0.0 {
303                        dot[*a]
304                    } else {
305                        -dot[*a]
306                    }
307                }
308                TapeOp::Sqrt(a) => {
309                    let sv = vals[i];
310                    if sv > 0.0 {
311                        dot[*a] * 0.5 / sv
312                    } else {
313                        0.0
314                    }
315                }
316                TapeOp::Exp(a) => dot[*a] * vals[i],
317                TapeOp::Log(a) => dot[*a] / vals[*a],
318                TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
319                TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
320                TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
321                TapeOp::Funcall { lib, name, args } => {
322                    let call_args = funcall_to_ext_args(args, vals);
323                    let res = lib.eval(name, &call_args, true, false).unwrap_or_else(|e| {
324                        panic!("external function '{name}' tangent eval failed: {e}")
325                    });
326                    let derivs = res.derivs.expect("want_derivs=true returns derivs");
327                    let mut acc = 0.0;
328                    let mut k = 0usize;
329                    for arg in args {
330                        if let TapeFuncallArg::Tape(idx) = arg {
331                            acc += derivs[k] * dot[*idx];
332                            k += 1;
333                        }
334                    }
335                    acc
336                }
337            };
338        }
339    }
340
341    /// Forward sweep into a caller-supplied buffer. Avoids the
342    /// per-call allocation of `forward()` so hot paths can reuse
343    /// one scratch arena across many tapes.
344    pub fn forward_into(&self, x: &[f64], vals: &mut [f64]) {
345        let n = self.ops.len();
346        debug_assert!(vals.len() >= n);
347        for i in 0..n {
348            vals[i] = match &self.ops[i] {
349                TapeOp::Const(c) => *c,
350                TapeOp::Var(j) => x[*j],
351                TapeOp::Add(a, b) => vals[*a] + vals[*b],
352                TapeOp::Sub(a, b) => vals[*a] - vals[*b],
353                TapeOp::Mul(a, b) => vals[*a] * vals[*b],
354                TapeOp::Div(a, b) => vals[*a] / vals[*b],
355                TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
356                TapeOp::Neg(a) => -vals[*a],
357                TapeOp::Abs(a) => vals[*a].abs(),
358                TapeOp::Sqrt(a) => vals[*a].sqrt(),
359                TapeOp::Exp(a) => vals[*a].exp(),
360                TapeOp::Log(a) => vals[*a].ln(),
361                TapeOp::Log10(a) => vals[*a].log10(),
362                TapeOp::Sin(a) => vals[*a].sin(),
363                TapeOp::Cos(a) => vals[*a].cos(),
364                TapeOp::Funcall { lib, name, args } => {
365                    let call_args = funcall_to_ext_args(args, &*vals);
366                    let res = lib
367                        .eval(name, &call_args, false, false)
368                        .unwrap_or_else(|e| {
369                            panic!("external function '{name}' forward_into failed: {e}")
370                        });
371                    res.value
372                }
373            };
374        }
375    }
376
377    /// Directional Hessian-vector product: emits
378    /// `weight * (∇²f · seed)[k]` into `out[k]` for every problem
379    /// variable `k` the tape references. Caller supplies the
380    /// forward-pass result `vals` (use [`forward_into`]) plus three
381    /// scratch buffers (`dot`, `adj`, `adj_dot`), each at least
382    /// `self.ops.len()` long. `out` must be at least one past the
383    /// largest variable index in the tape; the routine reads
384    /// `seed[k]` for each `Var(k)` and writes `out[k] += weight *
385    /// (Hess · seed)[k]`.
386    ///
387    /// This is one forward-over-reverse AD pass — O(n_ops) work —
388    /// regardless of how many variables the tape depends on, which
389    /// is what makes Hessian coloring efficient: a single
390    /// directional pass recovers a whole color group of columns.
391    ///
392    /// [`forward_into`]: Tape::forward_into
393    pub fn hessian_directional(
394        &self,
395        vals: &[f64],
396        seed: &[f64],
397        weight: f64,
398        out: &mut [f64],
399        dot: &mut [f64],
400        adj: &mut [f64],
401        adj_dot: &mut [f64],
402    ) {
403        let n = self.ops.len();
404        if n == 0 || weight == 0.0 {
405            return;
406        }
407        debug_assert!(vals.len() >= n);
408        debug_assert!(dot.len() >= n);
409        debug_assert!(adj.len() >= n);
410        debug_assert!(adj_dot.len() >= n);
411
412        // Forward tangent: dot[i] = (∂vals[i] / ∂x · seed). At
413        // Var(k) the seed entry feeds in; the rest of the chain
414        // rule matches `forward_tangent` exactly.
415        for i in 0..n {
416            dot[i] = match &self.ops[i] {
417                TapeOp::Const(_) => 0.0,
418                TapeOp::Var(k) => seed[*k],
419                TapeOp::Add(a, b) => dot[*a] + dot[*b],
420                TapeOp::Sub(a, b) => dot[*a] - dot[*b],
421                TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
422                TapeOp::Div(a, b) => {
423                    let vb = vals[*b];
424                    (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
425                }
426                TapeOp::Pow(a, b) => {
427                    let u = vals[*a];
428                    let r = vals[*b];
429                    let du = dot[*a];
430                    let dr = dot[*b];
431                    let mut result = 0.0;
432                    if r != 0.0 && u != 0.0 {
433                        result += r * u.powf(r - 1.0) * du;
434                    }
435                    if u > 0.0 {
436                        result += vals[i] * u.ln() * dr;
437                    }
438                    result
439                }
440                TapeOp::Neg(a) => -dot[*a],
441                TapeOp::Abs(a) => {
442                    if vals[*a] >= 0.0 {
443                        dot[*a]
444                    } else {
445                        -dot[*a]
446                    }
447                }
448                TapeOp::Sqrt(a) => {
449                    let sv = vals[i];
450                    if sv > 0.0 {
451                        dot[*a] * 0.5 / sv
452                    } else {
453                        0.0
454                    }
455                }
456                TapeOp::Exp(a) => vals[i] * dot[*a],
457                TapeOp::Log(a) => dot[*a] / vals[*a],
458                TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
459                TapeOp::Sin(a) => vals[*a].cos() * dot[*a],
460                TapeOp::Cos(a) => -vals[*a].sin() * dot[*a],
461                TapeOp::Funcall { lib, name, args } => {
462                    let call_args = funcall_to_ext_args(args, vals);
463                    let res = lib.eval(name, &call_args, true, false).unwrap_or_else(|e| {
464                        panic!("external function '{name}' tangent eval failed: {e}")
465                    });
466                    let derivs = res.derivs.expect("want_derivs=true returns derivs");
467                    let mut acc = 0.0;
468                    let mut k = 0usize;
469                    for arg in args {
470                        if let TapeFuncallArg::Tape(idx) = arg {
471                            acc += derivs[k] * dot[*idx];
472                            k += 1;
473                        }
474                    }
475                    acc
476                }
477            };
478        }
479
480        // Reverse over tangent. adj[i] = ∂f/∂vals[i],
481        // adj_dot[i] = derivative of adj[i] along `seed`
482        // direction = (Hess · seed) projected onto slot i.
483        for slot in adj.iter_mut().take(n) {
484            *slot = 0.0;
485        }
486        for slot in adj_dot.iter_mut().take(n) {
487            *slot = 0.0;
488        }
489        adj[n - 1] = 1.0;
490
491        for i in (0..n).rev() {
492            let w = adj[i];
493            let wd = adj_dot[i];
494            if w == 0.0 && wd == 0.0 {
495                continue;
496            }
497            match &self.ops[i] {
498                TapeOp::Const(_) => {}
499                TapeOp::Var(k) => {
500                    if wd != 0.0 {
501                        out[*k] += weight * wd;
502                    }
503                }
504                TapeOp::Add(a, b) => {
505                    adj[*a] += w;
506                    adj[*b] += w;
507                    adj_dot[*a] += wd;
508                    adj_dot[*b] += wd;
509                }
510                TapeOp::Sub(a, b) => {
511                    adj[*a] += w;
512                    adj[*b] -= w;
513                    adj_dot[*a] += wd;
514                    adj_dot[*b] -= wd;
515                }
516                TapeOp::Mul(a, b) => {
517                    adj[*a] += w * vals[*b];
518                    adj[*b] += w * vals[*a];
519                    adj_dot[*a] += wd * vals[*b] + w * dot[*b];
520                    adj_dot[*b] += wd * vals[*a] + w * dot[*a];
521                }
522                TapeOp::Div(a, b) => {
523                    let vb = vals[*b];
524                    let vb2 = vb * vb;
525                    let vb3 = vb2 * vb;
526                    adj[*a] += w / vb;
527                    adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
528                    adj[*b] += w * (-vals[*a] / vb2);
529                    adj_dot[*b] += wd * (-vals[*a] / vb2)
530                        + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
531                }
532                TapeOp::Pow(a, b) => {
533                    let u = vals[*a];
534                    let r = vals[*b];
535                    let du = dot[*a];
536                    let dr = dot[*b];
537                    if r != 0.0 {
538                        if u != 0.0 {
539                            let p_a = r * u.powf(r - 1.0);
540                            adj[*a] += w * p_a;
541                            let mut dp_a = dr * u.powf(r - 1.0);
542                            if u > 0.0 {
543                                dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
544                            } else {
545                                dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
546                            }
547                            adj_dot[*a] += wd * p_a + w * dp_a;
548                        } else if r >= 2.0 {
549                            let p_a = 0.0;
550                            adj[*a] += w * p_a;
551                            let dp_a = if r == 2.0 {
552                                2.0 * du
553                            } else {
554                                r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
555                            };
556                            adj_dot[*a] += wd * p_a + w * dp_a;
557                        }
558                    }
559                    if u > 0.0 {
560                        let ln_u = u.ln();
561                        let p_b = vals[i] * ln_u;
562                        adj[*b] += w * p_b;
563                        let dur = vals[i] * (r * du / u + dr * ln_u);
564                        let dp_b = dur * ln_u + vals[i] * du / u;
565                        adj_dot[*b] += wd * p_b + w * dp_b;
566                    }
567                }
568                TapeOp::Neg(a) => {
569                    adj[*a] -= w;
570                    adj_dot[*a] -= wd;
571                }
572                TapeOp::Abs(a) => {
573                    let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
574                    adj[*a] += w * s;
575                    adj_dot[*a] += wd * s;
576                }
577                TapeOp::Sqrt(a) => {
578                    let sv = vals[i];
579                    if sv > 0.0 {
580                        let fp = 0.5 / sv;
581                        let fpp = -0.25 / (vals[*a] * sv);
582                        adj[*a] += w * fp;
583                        adj_dot[*a] += wd * fp + w * fpp * dot[*a];
584                    }
585                }
586                TapeOp::Exp(a) => {
587                    let ev = vals[i];
588                    adj[*a] += w * ev;
589                    adj_dot[*a] += wd * ev + w * ev * dot[*a];
590                }
591                TapeOp::Log(a) => {
592                    let u = vals[*a];
593                    adj[*a] += w / u;
594                    adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
595                }
596                TapeOp::Log10(a) => {
597                    let u = vals[*a];
598                    let c = std::f64::consts::LN_10;
599                    adj[*a] += w / (u * c);
600                    adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
601                }
602                TapeOp::Sin(a) => {
603                    let u = vals[*a];
604                    let cu = u.cos();
605                    adj[*a] += w * cu;
606                    adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
607                }
608                TapeOp::Cos(a) => {
609                    let u = vals[*a];
610                    let su = u.sin();
611                    adj[*a] -= w * su;
612                    adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
613                }
614                TapeOp::Funcall { lib, name, args } => {
615                    let call_args = funcall_to_ext_args(args, vals);
616                    let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
617                        panic!("external function '{name}' 2nd-order eval failed: {e}")
618                    });
619                    let derivs = res.derivs.expect("want_derivs=true returns derivs");
620                    let hes = res.hessian.expect("want_hes=true returns hessian");
621                    let real_tape: Vec<usize> = args
622                        .iter()
623                        .filter_map(|a| match a {
624                            TapeFuncallArg::Tape(t) => Some(*t),
625                            TapeFuncallArg::Str(_) => None,
626                        })
627                        .collect();
628                    for (k, &tk) in real_tape.iter().enumerate() {
629                        adj[tk] += w * derivs[k];
630                        let mut second_term = 0.0;
631                        for (l, &tl) in real_tape.iter().enumerate() {
632                            let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
633                            let h_kl = hes[lo + hi * (hi + 1) / 2];
634                            second_term += h_kl * dot[tl];
635                        }
636                        adj_dot[tk] += wd * derivs[k] + w * second_term;
637                    }
638                }
639            }
640        }
641    }
642
643    /// Forward-over-reverse Hessian: for each variable `j` the tape
644    /// depends on, accumulate `weight * (d²f / dx_i dx_j)` into
645    /// `values[hess_map[(i, j)]]` for every `(i, j)` lower-triangle
646    /// pair in the map. The same routine is used for the objective
647    /// (with `weight = obj_factor`) and each active constraint (with
648    /// `weight = lambda[k]`); contributions sum into the shared map.
649    pub fn hessian_accumulate(
650        &self,
651        x: &[f64],
652        weight: f64,
653        hess_map: &HashMap<(usize, usize), usize>,
654        values: &mut [f64],
655    ) {
656        let n = self.ops.len();
657        if n == 0 || weight == 0.0 {
658            return;
659        }
660        let v = self.forward(x);
661        let var_indices = self.variables();
662
663        // Hoist scratch allocations out of the per-variable loop —
664        // each was costing O(n) per j on every hessian_accumulate
665        // call, which dominated runtime on large tapes (the dense-
666        // Hessian Mittelmann problems). `forward_tangent` fully
667        // overwrites `dot`, so no reset is needed there. `adj` and
668        // `adj_dot` are mutated additively, so we zero them per j.
669        let mut dot = vec![0.0f64; n];
670        let mut adj = vec![0.0f64; n];
671        let mut adj_dot = vec![0.0f64; n];
672        for &j in &var_indices {
673            self.forward_tangent(&v, j, &mut dot);
674
675            // adj[i] = standard adjoint (∂f/∂slot_i)
676            // adj_dot[i] = derivative of adj[i] w.r.t. x_j = ∂²f/(∂slot_i ∂x_j)
677            adj.fill(0.0);
678            adj_dot.fill(0.0);
679            adj[n - 1] = 1.0;
680
681            for i in (0..n).rev() {
682                let w = adj[i];
683                let wd = adj_dot[i];
684                if w == 0.0 && wd == 0.0 {
685                    continue;
686                }
687                match &self.ops[i] {
688                    TapeOp::Const(_) => {}
689                    TapeOp::Var(k) => {
690                        // Lower-triangle: only emit when row k >= col j
691                        // so an off-diagonal pair appears once.
692                        if wd != 0.0 && *k >= j {
693                            if let Some(&pos) = hess_map.get(&(*k, j)) {
694                                values[pos] += weight * wd;
695                            }
696                        }
697                    }
698                    TapeOp::Add(a, b) => {
699                        adj[*a] += w;
700                        adj[*b] += w;
701                        adj_dot[*a] += wd;
702                        adj_dot[*b] += wd;
703                    }
704                    TapeOp::Sub(a, b) => {
705                        adj[*a] += w;
706                        adj[*b] -= w;
707                        adj_dot[*a] += wd;
708                        adj_dot[*b] -= wd;
709                    }
710                    TapeOp::Mul(a, b) => {
711                        adj[*a] += w * v[*b];
712                        adj[*b] += w * v[*a];
713                        adj_dot[*a] += wd * v[*b] + w * dot[*b];
714                        adj_dot[*b] += wd * v[*a] + w * dot[*a];
715                    }
716                    TapeOp::Div(a, b) => {
717                        let vb = v[*b];
718                        let vb2 = vb * vb;
719                        let vb3 = vb2 * vb;
720                        adj[*a] += w / vb;
721                        adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
722                        adj[*b] += w * (-v[*a] / vb2);
723                        adj_dot[*b] += wd * (-v[*a] / vb2)
724                            + w * (-dot[*a] / vb2 + 2.0 * v[*a] * dot[*b] / vb3);
725                    }
726                    TapeOp::Pow(a, b) => {
727                        let u = v[*a];
728                        let r = v[*b];
729                        let du = dot[*a];
730                        let dr = dot[*b];
731                        if r != 0.0 {
732                            if u != 0.0 {
733                                let p_a = r * u.powf(r - 1.0);
734                                adj[*a] += w * p_a;
735                                let mut dp_a = dr * u.powf(r - 1.0);
736                                if u > 0.0 {
737                                    dp_a +=
738                                        r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
739                                } else {
740                                    dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
741                                }
742                                adj_dot[*a] += wd * p_a + w * dp_a;
743                            } else if r >= 2.0 {
744                                let p_a = 0.0;
745                                adj[*a] += w * p_a;
746                                let dp_a = if r == 2.0 {
747                                    2.0 * du
748                                } else {
749                                    r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
750                                };
751                                adj_dot[*a] += wd * p_a + w * dp_a;
752                            }
753                        }
754                        if u > 0.0 {
755                            let ln_u = u.ln();
756                            let p_b = v[i] * ln_u;
757                            adj[*b] += w * p_b;
758                            let dur = v[i] * (r * du / u + dr * ln_u);
759                            let dp_b = dur * ln_u + v[i] * du / u;
760                            adj_dot[*b] += wd * p_b + w * dp_b;
761                        }
762                    }
763                    TapeOp::Neg(a) => {
764                        adj[*a] -= w;
765                        adj_dot[*a] -= wd;
766                    }
767                    TapeOp::Abs(a) => {
768                        let s = if v[*a] >= 0.0 { 1.0 } else { -1.0 };
769                        adj[*a] += w * s;
770                        adj_dot[*a] += wd * s;
771                    }
772                    TapeOp::Sqrt(a) => {
773                        let sv = v[i];
774                        if sv > 0.0 {
775                            let fp = 0.5 / sv;
776                            let fpp = -0.25 / (v[*a] * sv);
777                            adj[*a] += w * fp;
778                            adj_dot[*a] += wd * fp + w * fpp * dot[*a];
779                        }
780                    }
781                    TapeOp::Exp(a) => {
782                        let ev = v[i];
783                        adj[*a] += w * ev;
784                        adj_dot[*a] += wd * ev + w * ev * dot[*a];
785                    }
786                    TapeOp::Log(a) => {
787                        let u = v[*a];
788                        adj[*a] += w / u;
789                        adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
790                    }
791                    TapeOp::Log10(a) => {
792                        let u = v[*a];
793                        let c = std::f64::consts::LN_10;
794                        adj[*a] += w / (u * c);
795                        adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
796                    }
797                    TapeOp::Sin(a) => {
798                        let u = v[*a];
799                        let cu = u.cos();
800                        adj[*a] += w * cu;
801                        adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
802                    }
803                    TapeOp::Cos(a) => {
804                        let u = v[*a];
805                        let su = u.sin();
806                        adj[*a] -= w * su;
807                        adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
808                    }
809                    TapeOp::Funcall { lib, name, args } => {
810                        let call_args = funcall_to_ext_args(args, &v);
811                        let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
812                            panic!("external function '{name}' 2nd-order eval failed: {e}")
813                        });
814                        let derivs = res.derivs.expect("want_derivs=true returns derivs");
815                        let hes = res.hessian.expect("want_hes=true returns hessian");
816                        let real_tape: Vec<usize> = args
817                            .iter()
818                            .filter_map(|a| match a {
819                                TapeFuncallArg::Tape(t) => Some(*t),
820                                TapeFuncallArg::Str(_) => None,
821                            })
822                            .collect();
823                        for (k, &tk) in real_tape.iter().enumerate() {
824                            adj[tk] += w * derivs[k];
825                            let mut second_term = 0.0;
826                            for (l, &tl) in real_tape.iter().enumerate() {
827                                let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
828                                let h_kl = hes[lo + hi * (hi + 1) / 2];
829                                second_term += h_kl * dot[tl];
830                            }
831                            adj_dot[tk] += wd * derivs[k] + w * second_term;
832                        }
833                    }
834                }
835            }
836        }
837    }
838
839    /// Structural Hessian sparsity (lower triangle, row >= col).
840    /// Propagates per-slot variable-dependence sets forward; each
841    /// nonlinear op emits the cross/self products of its operand sets.
842    /// Linear ops contribute no second-derivative pairs.
843    pub fn hessian_sparsity(&self) -> BTreeSet<(usize, usize)> {
844        let n = self.ops.len();
845        let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
846        let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
847
848        let emit_cross =
849            |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
850                for &v1 in s1 {
851                    for &v2 in s2 {
852                        let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
853                        pairs.insert((r, c));
854                    }
855                }
856            };
857        let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
858            let vars: Vec<usize> = s.iter().copied().collect();
859            for (ai, &vi) in vars.iter().enumerate() {
860                for &vj in &vars[..=ai] {
861                    let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
862                    pairs.insert((r, c));
863                }
864            }
865        };
866
867        for op in &self.ops {
868            let vset = match op {
869                TapeOp::Const(_) => BTreeSet::new(),
870                TapeOp::Var(j) => {
871                    let mut s = BTreeSet::new();
872                    s.insert(*j);
873                    s
874                }
875                TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
876                    var_sets[*a].union(&var_sets[*b]).copied().collect()
877                }
878                TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
879                TapeOp::Mul(a, b) => {
880                    emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
881                    var_sets[*a].union(&var_sets[*b]).copied().collect()
882                }
883                TapeOp::Div(a, b) => {
884                    emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
885                    emit_self(&var_sets[*b], &mut pairs);
886                    var_sets[*a].union(&var_sets[*b]).copied().collect()
887                }
888                TapeOp::Pow(a, b) => {
889                    let combined: BTreeSet<usize> =
890                        var_sets[*a].union(&var_sets[*b]).copied().collect();
891                    emit_self(&combined, &mut pairs);
892                    combined
893                }
894                TapeOp::Sqrt(a)
895                | TapeOp::Exp(a)
896                | TapeOp::Log(a)
897                | TapeOp::Log10(a)
898                | TapeOp::Sin(a)
899                | TapeOp::Cos(a) => {
900                    emit_self(&var_sets[*a], &mut pairs);
901                    var_sets[*a].clone()
902                }
903                TapeOp::Funcall { args, .. } => {
904                    let mut combined: BTreeSet<usize> = BTreeSet::new();
905                    for arg in args {
906                        if let TapeFuncallArg::Tape(t) = arg {
907                            for &vv in &var_sets[*t] {
908                                combined.insert(vv);
909                            }
910                        }
911                    }
912                    emit_self(&combined, &mut pairs);
913                    combined
914                }
915            };
916            var_sets.push(vset);
917        }
918        pairs
919    }
920}
921
922fn build_recursive(
923    expr: &Expr,
924    ops: &mut Vec<TapeOp>,
925    cache: &mut HashMap<*const Expr, usize>,
926    resolver: &ExternalResolver,
927) -> usize {
928    match expr {
929        Expr::Const(c) => {
930            let idx = ops.len();
931            ops.push(TapeOp::Const(*c));
932            idx
933        }
934        Expr::Var(i) => {
935            let idx = ops.len();
936            ops.push(TapeOp::Var(*i));
937            idx
938        }
939        Expr::Binary(op, a, b) => {
940            // Pow(x, const) is the dominant libm/dispatch cost in
941            // transcendental-heavy AMPL tapes (henon, lane_emden, …):
942            // `powf` itself is ~30–50 cycles AND the reverse-mode arm
943            // for `Pow` carries an extra `ln(x)` branch. Rewriting
944            // small integer / half-integer exponents into mul/sqrt
945            // chains drops these calls entirely and reroutes the AD
946            // through the much cheaper `Mul`/`Sqrt` arms.
947            if let BinOp::Pow = op {
948                if let Some(c) = peek_const(b) {
949                    if let Some(idx) = try_emit_const_pow(a, c, ops, cache, resolver) {
950                        return idx;
951                    }
952                }
953            }
954            let l = build_recursive(a, ops, cache, resolver);
955            let r = build_recursive(b, ops, cache, resolver);
956            let idx = ops.len();
957            ops.push(match op {
958                BinOp::Add => TapeOp::Add(l, r),
959                BinOp::Sub => TapeOp::Sub(l, r),
960                BinOp::Mul => TapeOp::Mul(l, r),
961                BinOp::Div => TapeOp::Div(l, r),
962                BinOp::Pow => TapeOp::Pow(l, r),
963            });
964            idx
965        }
966        Expr::Unary(op, a) => {
967            let v = build_recursive(a, ops, cache, resolver);
968            let idx = ops.len();
969            ops.push(match op {
970                UnaryOp::Neg => TapeOp::Neg(v),
971                UnaryOp::Sqrt => TapeOp::Sqrt(v),
972                UnaryOp::Log => TapeOp::Log(v),
973                UnaryOp::Log10 => TapeOp::Log10(v),
974                UnaryOp::Exp => TapeOp::Exp(v),
975                UnaryOp::Abs => TapeOp::Abs(v),
976                UnaryOp::Sin => TapeOp::Sin(v),
977                UnaryOp::Cos => TapeOp::Cos(v),
978            });
979            idx
980        }
981        Expr::Sum(args) => {
982            if args.is_empty() {
983                let idx = ops.len();
984                ops.push(TapeOp::Const(0.0));
985                return idx;
986            }
987            let mut acc = build_recursive(&args[0], ops, cache, resolver);
988            for a in &args[1..] {
989                let next = build_recursive(a, ops, cache, resolver);
990                let idx = ops.len();
991                ops.push(TapeOp::Add(acc, next));
992                acc = idx;
993            }
994            acc
995        }
996        Expr::Cse(body) => {
997            // Cache by Rc identity so each shared body is emitted into
998            // the tape exactly once and every reference resolves to the
999            // same result-slot index. Forward computes the body once;
1000            // reverse-mode adjoint sums contributions from every ref
1001            // into that shared slot — exact chain rule for shared
1002            // sub-expressions.
1003            let key = Rc::as_ptr(body) as *const Expr;
1004            if let Some(&idx) = cache.get(&key) {
1005                idx
1006            } else {
1007                let idx = build_recursive(body, ops, cache, resolver);
1008                cache.insert(key, idx);
1009                idx
1010            }
1011        }
1012        Expr::Funcall { id, args } => {
1013            let (lib, name) = resolver
1014                .funcs_by_id
1015                .get(id)
1016                .unwrap_or_else(|| panic!("unresolved AMPL funcall id {id}"));
1017            let tape_args: Vec<TapeFuncallArg> = args
1018                .iter()
1019                .map(|a| match a {
1020                    FuncallArg::Real(e) => {
1021                        TapeFuncallArg::Tape(build_recursive(e, ops, cache, resolver))
1022                    }
1023                    FuncallArg::Str(s) => TapeFuncallArg::Str(s.clone()),
1024                })
1025                .collect();
1026            let idx = ops.len();
1027            ops.push(TapeOp::Funcall {
1028                lib: Arc::clone(lib),
1029                name: name.clone(),
1030                args: tape_args,
1031            });
1032            idx
1033        }
1034    }
1035}
1036
1037/// Resolve `e` to a literal constant if it is one (transparently
1038/// peering through `Cse` wrappers, which AMPL emits around shared
1039/// constants in CSE-heavy problems).
1040fn peek_const(e: &Expr) -> Option<f64> {
1041    match e {
1042        Expr::Const(c) => Some(*c),
1043        Expr::Cse(body) => peek_const(body),
1044        _ => None,
1045    }
1046}
1047
1048/// Try to rewrite `base ^ exponent_const` into cheaper ops. Returns
1049/// the result tape-slot on success; `None` means "fall through to
1050/// generic Pow." Handles the cases that account for the bulk of
1051/// AMPL-emitted Pow nodes: integer exponents up to ±8 and the
1052/// `Sqrt`/passthrough/one specials. Half-integer exponents (e.g.
1053/// `^1.5`) and larger integers are left to generic `Pow` since the
1054/// resulting mul chain grows the tape faster than it saves work.
1055fn try_emit_const_pow(
1056    base_expr: &Expr,
1057    c: f64,
1058    ops: &mut Vec<TapeOp>,
1059    cache: &mut HashMap<*const Expr, usize>,
1060    resolver: &ExternalResolver,
1061) -> Option<usize> {
1062    if c == 0.0 {
1063        let idx = ops.len();
1064        ops.push(TapeOp::Const(1.0));
1065        return Some(idx);
1066    }
1067    if c == 1.0 {
1068        return Some(build_recursive(base_expr, ops, cache, resolver));
1069    }
1070    if c == 0.5 {
1071        let b = build_recursive(base_expr, ops, cache, resolver);
1072        let idx = ops.len();
1073        ops.push(TapeOp::Sqrt(b));
1074        return Some(idx);
1075    }
1076    // Integer exponents: bounded so a bad tape can't blow up the
1077    // op count. 8 covers everything AMPL typically emits for
1078    // polynomial models; beyond that the binary-expansion mul
1079    // chain (≥4 ops) starts to lose to a single `powf`.
1080    if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
1081        let n = c.abs() as u32;
1082        if n == 0 {
1083            // Already handled above, but guard.
1084            let idx = ops.len();
1085            ops.push(TapeOp::Const(1.0));
1086            return Some(idx);
1087        }
1088        let b = build_recursive(base_expr, ops, cache, resolver);
1089        let pos = emit_int_pow(b, n, ops);
1090        if c < 0.0 {
1091            // x^-n = 1 / x^n. Saves the powf and its ln branch in
1092            // reverse mode; cost is one Div in their place.
1093            let one_idx = ops.len();
1094            ops.push(TapeOp::Const(1.0));
1095            let idx = ops.len();
1096            ops.push(TapeOp::Div(one_idx, pos));
1097            return Some(idx);
1098        }
1099        return Some(pos);
1100    }
1101    None
1102}
1103
1104/// Emit `base^n` for `n >= 1` as a binary-expansion mul chain.
1105/// Worst-case op count is `2·floor(log2(n))` — i.e. 1 op for n=2, 2
1106/// for n=3/4, 3 for n=5..8.
1107fn emit_int_pow(base: usize, n: u32, ops: &mut Vec<TapeOp>) -> usize {
1108    debug_assert!(n >= 1);
1109    if n == 1 {
1110        return base;
1111    }
1112    let half = emit_int_pow(base, n / 2, ops);
1113    let squared = ops.len();
1114    ops.push(TapeOp::Mul(half, half));
1115    if n % 2 == 1 {
1116        let idx = ops.len();
1117        ops.push(TapeOp::Mul(squared, base));
1118        idx
1119    } else {
1120        squared
1121    }
1122}
1123
1124// ============================================================
1125// HybridTape: per-summand local tapes + shared CSE prelude.
1126//
1127// Partial separability — the .nl Sum/Add structure — gets each
1128// summand its own local Vec<SummandOp>. CSE bodies (V-segments
1129// in .nl) that appear in two or more summands are promoted into
1130// a single shared `prelude: Vec<TapeOp>`; per-summand references
1131// to a promoted CSE are SummandOp::Shared(prelude_slot).
1132//
1133// This is strictly better than either extreme:
1134//   - per-summand Tape (no cross-summand sharing): re-inlines
1135//     every shared CSE, blows up tape size when many constraints
1136//     share a stencil derivative (Mittelmann *120 problems).
1137//   - GlobalTape (single shared Vec<TapeOp> for everything):
1138//     per-root reverse sweeps scatter across a many-MB buffer,
1139//     thrashing cache when no CSE is actually shared (lane_emden
1140//     120: each constraint owns its own ops → 50% regression
1141//     vs per-summand tapes).
1142//
1143// Forward: prelude once, then each summand's local pass.
1144// Reverse / forward-over-reverse: per-summand sweep over local
1145// reach (which propagates adjoints into prelude_adj at Shared
1146// boundaries), then a small reverse pass over the summand's
1147// prelude_reach to fold those into grad / Hessian.
1148// ============================================================
1149
1150/// One slot in a per-summand local tape.
1151#[derive(Debug, Clone)]
1152pub enum SummandOp {
1153    /// Local op — operand indices reference other slots in the
1154    /// same per-summand vector.
1155    Local(TapeOp),
1156    /// Pull a value from the shared prelude at slot `usize`. No
1157    /// downstream cost beyond the lookup; adjoints flowing into
1158    /// this slot accumulate into the prelude adjoint buffer.
1159    Shared(usize),
1160}
1161
1162#[derive(Debug, Clone)]
1163pub struct Summand {
1164    pub ops: Vec<SummandOp>,
1165    /// Local slot holding the summand's final value.
1166    pub root_slot: usize,
1167    /// Local slots reachable from `root_slot`, ascending (topo).
1168    pub local_reach: Vec<usize>,
1169    /// Prelude slots reachable from the summand's Shared refs,
1170    /// ascending (topo in prelude's operand DAG).
1171    pub prelude_reach: Vec<usize>,
1172    /// Variables touched by Var ops inside `local_reach`.
1173    pub local_vars: Vec<usize>,
1174    /// Variables touched by Var ops inside `prelude_reach`.
1175    pub prelude_vars: Vec<usize>,
1176    /// `local_vars ∪ prelude_vars`, sorted. Hessian j-loop set.
1177    pub all_vars: Vec<usize>,
1178}
1179
1180#[derive(Debug)]
1181pub struct HybridTape {
1182    /// Shared CSE bodies. Slot indices in `SummandOp::Shared`
1183    /// point here; this Vec is built bottom-up by `build_recursive`,
1184    /// so operand indices are always less than the consumer's
1185    /// index (topo in ascending order).
1186    pub prelude: Vec<TapeOp>,
1187    pub summands: Vec<Summand>,
1188}
1189
impl HybridTape {
    /// Build hybrid tape from a list of root expressions. CSE
    /// bodies referenced from ≥ 2 roots are promoted into the
    /// shared prelude; CSEs touched by only one root are inlined
    /// into that summand's local ops.
    ///
    /// Produces exactly one `Summand` per entry of `exprs`, in input
    /// order. Each summand's reach sets and variable lists are filled
    /// in before returning.
    pub fn build_multi(exprs: &[Expr]) -> Self {
        // Pass 1: per-Cse-pointer count of how many roots reference
        // it (each root contributes at most 1 to the count). The
        // ≥2 threshold means a CSE is shared across summands.
        let mut cse_count: HashMap<*const Expr, usize> = HashMap::new();
        for e in exprs {
            // Fresh per root, so one root adds at most 1 per CSE.
            let mut seen_in_root: HashSet<*const Expr> = HashSet::new();
            count_cse_appearances(e, &mut seen_in_root, &mut cse_count);
        }

        // Pass 2: build prelude + each summand. The summand builder
        // hits the prelude path lazily — only when it encounters a
        // promoted Cse — so the prelude grows only with bodies that
        // are actually referenced from multiple summands.
        let mut prelude: Vec<TapeOp> = Vec::new();
        let mut prelude_map: HashMap<*const Expr, usize> = HashMap::new();
        let mut summands: Vec<Summand> = Vec::with_capacity(exprs.len());
        for e in exprs {
            // The local tape and CSE cache are per summand; `prelude`
            // and `prelude_map` are shared across every root.
            let mut local: Vec<SummandOp> = Vec::new();
            let mut local_cache: HashMap<*const Expr, usize> = HashMap::new();
            let root_slot = build_into_summand(
                e,
                &mut local,
                &mut local_cache,
                &mut prelude,
                &mut prelude_map,
                &cse_count,
            );
            // Reach sets and var lists start empty here; pass 3 fills
            // them in once the whole prelude has been built.
            summands.push(Summand {
                ops: local,
                root_slot,
                local_reach: Vec::new(),
                prelude_reach: Vec::new(),
                local_vars: Vec::new(),
                prelude_vars: Vec::new(),
                all_vars: Vec::new(),
            });
        }

        // Pass 3: per-summand reach / vars. Prelude reach uses an
        // epoch-tagged shared visited buffer so total cost stays
        // O(Σ |prelude_reach_i|) rather than O(n_summands × |prelude|).
        // `p_epoch` is incremented before each use, so epoch 0 is never
        // assigned and the zero-filled buffer reads as "unvisited".
        let mut p_visited: Vec<u32> = vec![0; prelude.len()];
        let mut p_epoch: u32 = 0;
        let mut p_stack: Vec<usize> = Vec::new();
        for s in &mut summands {
            // NOTE(review): `shared_refs` presumably holds the prelude
            // slots named by reachable `Shared` ops; they seed the prelude
            // walk below. Confirm against `compute_local_reach`.
            let (local_reach, shared_refs) = compute_local_reach(&s.ops, s.root_slot);
            s.local_reach = local_reach;

            // Variables read by reachable local `Var` ops. The BTreeSet
            // keeps them sorted and deduplicated.
            let mut lv: BTreeSet<usize> = BTreeSet::new();
            for &i in &s.local_reach {
                if let SummandOp::Local(TapeOp::Var(j)) = &s.ops[i] {
                    lv.insert(*j);
                }
            }
            s.local_vars = lv.iter().copied().collect();

            // Summands without Shared refs keep the empty prelude
            // reach / vars assigned in pass 2.
            if !shared_refs.is_empty() {
                p_epoch += 1;
                let mut preach: Vec<usize> = Vec::new();
                for &start in &shared_refs {
                    bfs_prelude(
                        &prelude,
                        start,
                        &mut p_visited,
                        p_epoch,
                        &mut p_stack,
                        &mut preach,
                    );
                }
                // Ascending prelude index is topological order (operands
                // precede consumers). The forward and reverse sweeps over
                // `prelude_reach` rely on this.
                preach.sort_unstable();
                s.prelude_vars = vars_in(&prelude, &preach);
                s.prelude_reach = preach;
            }

            // all_vars = local_vars ∪ prelude_vars, sorted via BTreeSet.
            let mut av: BTreeSet<usize> = lv;
            for &v in &s.prelude_vars {
                av.insert(v);
            }
            s.all_vars = av.into_iter().collect();
        }

        HybridTape { prelude, summands }
    }

    /// Number of ops in the shared prelude.
    pub fn n_prelude_ops(&self) -> usize {
        self.prelude.len()
    }
    /// Number of summands (one per root expression).
    pub fn n_summands(&self) -> usize {
        self.summands.len()
    }
    /// Op count of the largest summand's local tape (0 if there are none).
    pub fn max_summand_ops(&self) -> usize {
        self.summands.iter().map(|s| s.ops.len()).max().unwrap_or(0)
    }
    /// Total op count across all summands' local tapes.
    pub fn total_local_ops(&self) -> usize {
        self.summands.iter().map(|s| s.ops.len()).sum()
    }

    /// Forward sweep over the shared prelude. `prelude_vals` must
    /// have length `n_prelude_ops`.
    ///
    /// Every prelude slot is evaluated, in ascending index order.
    pub fn forward_prelude(&self, x: &[f64], prelude_vals: &mut [f64]) {
        debug_assert_eq!(prelude_vals.len(), self.prelude.len());
        // Ascending order is a valid evaluation order: each op's
        // operands sit at smaller indices and were evaluated first.
        for i in 0..self.prelude.len() {
            prelude_vals[i] = fwd_step(&self.prelude[i], x, prelude_vals);
        }
    }

    /// Forward sweep over one summand. `local_vals` must hold at
    /// least `s.ops.len()` entries.
    ///
    /// `prelude_vals` must already hold the prelude's forward values
    /// (e.g. from `forward_prelude`), because `Shared` slots copy from
    /// it. All of `s.ops` is evaluated, not only `local_reach`.
    pub fn forward_summand(
        &self,
        s: &Summand,
        x: &[f64],
        prelude_vals: &[f64],
        local_vals: &mut [f64],
    ) {
        debug_assert!(local_vals.len() >= s.ops.len());
        for i in 0..s.ops.len() {
            local_vals[i] = match &s.ops[i] {
                SummandOp::Local(op) => fwd_step(op, x, local_vals),
                SummandOp::Shared(k) => prelude_vals[*k],
            };
        }
    }

    /// Value at the summand root after `forward_summand`.
    #[inline]
    pub fn root_value(&self, s: &Summand, local_vals: &[f64]) -> f64 {
        local_vals[s.root_slot]
    }

    /// Reverse-mode gradient for one summand. Walks `local_reach`
    /// in reverse — propagating adjoints into `prelude_adj` at
    /// Shared boundaries — and then walks `prelude_reach` in
    /// reverse to land contributions in `grad`. Scratch arrays
    /// `local_adj` and `prelude_adj` are zeroed only at the slots
    /// actually touched.
    ///
    /// Does nothing when `seed == 0.0` or `local_reach` is empty.
    /// `grad` is never zeroed here; it is handed to `rev_step`.
    /// NOTE(review): `rev_step` presumably accumulates `Var`
    /// contributions into `grad` rather than overwriting them — confirm.
    #[allow(clippy::too_many_arguments)]
    pub fn gradient_summand(
        &self,
        s: &Summand,
        prelude_vals: &[f64],
        local_vals: &[f64],
        seed: f64,
        grad: &mut [f64],
        local_adj: &mut [f64],
        prelude_adj: &mut [f64],
    ) {
        if seed == 0.0 || s.local_reach.is_empty() {
            return;
        }
        // Clear only the scratch slots this summand can reach.
        for &i in &s.local_reach {
            local_adj[i] = 0.0;
        }
        for &i in &s.prelude_reach {
            prelude_adj[i] = 0.0;
        }
        local_adj[s.root_slot] = seed;
        // Phase 1: local reverse sweep. Adjoint that reaches a `Shared`
        // slot is handed to `prelude_adj` instead of going further locally.
        for &i in s.local_reach.iter().rev() {
            let a = local_adj[i];
            // A zero adjoint contributes nothing downstream.
            if a == 0.0 {
                continue;
            }
            match &s.ops[i] {
                SummandOp::Local(op) => rev_step(op, i, local_vals, local_adj, a, grad),
                SummandOp::Shared(k) => {
                    prelude_adj[*k] += a;
                }
            }
        }
        // Phase 2: propagate the handed-over adjoints through the
        // prelude DAG (reverse topological order) into `grad`.
        for &i in s.prelude_reach.iter().rev() {
            let a = prelude_adj[i];
            if a == 0.0 {
                continue;
            }
            rev_step(&self.prelude[i], i, prelude_vals, prelude_adj, a, grad);
        }
    }

    /// Forward-over-reverse Hessian for one summand with multiplier
    /// `weight`. Iterates over `s.all_vars`; for each seed variable
    /// j: (1) forward tangent through prelude_reach then local_reach,
    /// (2) reverse over local (folding adj/adj_dot into prelude at
    /// Shared boundaries), (3) reverse over prelude_reach. All
    /// scratch buffers are zeroed only at the touched slots inside
    /// the per-j loop.
    ///
    /// Does nothing when `weight == 0.0` or `local_reach` is empty.
    /// Results are written by `ror_step` into `values` at the positions
    /// given by `hess_map`.
    #[allow(clippy::too_many_arguments)]
    pub fn hessian_summand(
        &self,
        s: &Summand,
        prelude_vals: &[f64],
        local_vals: &[f64],
        weight: f64,
        hess_map: &HashMap<(usize, usize), usize>,
        values: &mut [f64],
        local_dot: &mut [f64],
        local_adj: &mut [f64],
        local_adj_dot: &mut [f64],
        prelude_dot: &mut [f64],
        prelude_adj: &mut [f64],
        prelude_adj_dot: &mut [f64],
    ) {
        if weight == 0.0 || s.local_reach.is_empty() {
            return;
        }
        for &j in &s.all_vars {
            // Reset scratch for this seed direction (reachable slots only).
            for &i in &s.local_reach {
                local_dot[i] = 0.0;
                local_adj[i] = 0.0;
                local_adj_dot[i] = 0.0;
            }
            for &i in &s.prelude_reach {
                prelude_dot[i] = 0.0;
                prelude_adj[i] = 0.0;
                prelude_adj_dot[i] = 0.0;
            }
            // Forward tangent seeded with e_j. The prelude goes first
            // because `Shared` local slots read their tangent from
            // `prelude_dot`.
            for &i in &s.prelude_reach {
                prelude_dot[i] = fwd_tan_step(&self.prelude[i], j, prelude_vals, prelude_dot, i);
            }
            for &i in &s.local_reach {
                local_dot[i] = match &s.ops[i] {
                    SummandOp::Local(op) => fwd_tan_step(op, j, local_vals, local_dot, i),
                    SummandOp::Shared(k) => prelude_dot[*k],
                };
            }
            // The root adjoint is seeded with 1.0 and `weight` is passed
            // to `ror_step`. NOTE(review): `ror_step` presumably applies
            // `weight` when writing into `values` — confirm.
            local_adj[s.root_slot] = 1.0;
            for &i in s.local_reach.iter().rev() {
                let w = local_adj[i];
                let wd = local_adj_dot[i];
                // Nothing to propagate if both adjoint parts are zero.
                if w == 0.0 && wd == 0.0 {
                    continue;
                }
                match &s.ops[i] {
                    SummandOp::Local(op) => {
                        ror_step(
                            op,
                            i,
                            j,
                            local_vals,
                            local_dot,
                            local_adj,
                            local_adj_dot,
                            w,
                            wd,
                            weight,
                            hess_map,
                            values,
                        );
                    }
                    SummandOp::Shared(k) => {
                        // Hand both the adjoint and its tangent to the
                        // prelude; phase 3 propagates them.
                        prelude_adj[*k] += w;
                        prelude_adj_dot[*k] += wd;
                    }
                }
            }
            // Reverse over the prelude, folding in what was collected
            // at Shared boundaries.
            for &i in s.prelude_reach.iter().rev() {
                let w = prelude_adj[i];
                let wd = prelude_adj_dot[i];
                if w == 0.0 && wd == 0.0 {
                    continue;
                }
                ror_step(
                    &self.prelude[i],
                    i,
                    j,
                    prelude_vals,
                    prelude_dot,
                    prelude_adj,
                    prelude_adj_dot,
                    w,
                    wd,
                    weight,
                    hess_map,
                    values,
                );
            }
        }
    }

    /// Structural Hessian sparsity over the whole hybrid tape:
    /// every pair the prelude or any summand can produce.
    ///
    /// Each summand's full `ops` vector is passed to `summand_sparsity`,
    /// not just its `local_reach`.
    pub fn hessian_sparsity_all(&self) -> BTreeSet<(usize, usize)> {
        let mut pairs = hessian_sparsity_impl(&self.prelude);

        // Per-prelude-slot var-set, reused across summands as the
        // var-set carrier for Shared refs.
        let prelude_var_sets = compute_var_sets(&self.prelude);

        for s in &self.summands {
            summand_sparsity(&s.ops, &prelude_var_sets, &mut pairs);
        }
        pairs
    }
}
1489
1490/// Pass-1 helper: per-root walk that increments `counts[ptr]` the
1491/// first time a Cse pointer is encountered in this root. Recursing
1492/// into the body is gated on the first visit to avoid quadratic
1493/// blowup on heavily shared CSE DAGs.
1494fn count_cse_appearances(
1495    e: &Expr,
1496    seen_in_root: &mut HashSet<*const Expr>,
1497    counts: &mut HashMap<*const Expr, usize>,
1498) {
1499    match e {
1500        Expr::Const(_) | Expr::Var(_) => {}
1501        Expr::Binary(_, a, b) => {
1502            count_cse_appearances(a, seen_in_root, counts);
1503            count_cse_appearances(b, seen_in_root, counts);
1504        }
1505        Expr::Unary(_, a) => count_cse_appearances(a, seen_in_root, counts),
1506        Expr::Sum(args) => {
1507            for a in args {
1508                count_cse_appearances(a, seen_in_root, counts);
1509            }
1510        }
1511        Expr::Cse(body) => {
1512            let key = Rc::as_ptr(body) as *const Expr;
1513            if seen_in_root.insert(key) {
1514                *counts.entry(key).or_insert(0) += 1;
1515                count_cse_appearances(body, seen_in_root, counts);
1516            }
1517        }
1518        Expr::Funcall { args, .. } => {
1519            for arg in args {
1520                if let FuncallArg::Real(e) = arg {
1521                    count_cse_appearances(e, seen_in_root, counts);
1522                }
1523            }
1524        }
1525    }
1526}
1527
1528/// Recursive summand builder. CSEs that meet the promotion bar
1529/// (≥ 2 roots reference them per `cse_count`) get a single prelude
1530/// emission via `build_recursive`; the summand records a Shared op
1531/// pointing at the prelude slot. Non-promoted CSEs are inlined
1532/// into the summand with intra-summand Rc-pointer dedup.
1533fn build_into_summand(
1534    expr: &Expr,
1535    local: &mut Vec<SummandOp>,
1536    local_cache: &mut HashMap<*const Expr, usize>,
1537    prelude: &mut Vec<TapeOp>,
1538    prelude_map: &mut HashMap<*const Expr, usize>,
1539    cse_count: &HashMap<*const Expr, usize>,
1540) -> usize {
1541    match expr {
1542        Expr::Const(c) => {
1543            let i = local.len();
1544            local.push(SummandOp::Local(TapeOp::Const(*c)));
1545            i
1546        }
1547        Expr::Var(j) => {
1548            let i = local.len();
1549            local.push(SummandOp::Local(TapeOp::Var(*j)));
1550            i
1551        }
1552        Expr::Binary(op, a, b) => {
1553            if let BinOp::Pow = op {
1554                if let Some(c) = peek_const(b) {
1555                    if let Some(i) = try_emit_const_pow_summand(
1556                        a,
1557                        c,
1558                        local,
1559                        local_cache,
1560                        prelude,
1561                        prelude_map,
1562                        cse_count,
1563                    ) {
1564                        return i;
1565                    }
1566                }
1567            }
1568            let l = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
1569            let r = build_into_summand(b, local, local_cache, prelude, prelude_map, cse_count);
1570            let i = local.len();
1571            local.push(SummandOp::Local(match op {
1572                BinOp::Add => TapeOp::Add(l, r),
1573                BinOp::Sub => TapeOp::Sub(l, r),
1574                BinOp::Mul => TapeOp::Mul(l, r),
1575                BinOp::Div => TapeOp::Div(l, r),
1576                BinOp::Pow => TapeOp::Pow(l, r),
1577            }));
1578            i
1579        }
1580        Expr::Unary(op, a) => {
1581            let v = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
1582            let i = local.len();
1583            local.push(SummandOp::Local(match op {
1584                UnaryOp::Neg => TapeOp::Neg(v),
1585                UnaryOp::Sqrt => TapeOp::Sqrt(v),
1586                UnaryOp::Log => TapeOp::Log(v),
1587                UnaryOp::Log10 => TapeOp::Log10(v),
1588                UnaryOp::Exp => TapeOp::Exp(v),
1589                UnaryOp::Abs => TapeOp::Abs(v),
1590                UnaryOp::Sin => TapeOp::Sin(v),
1591                UnaryOp::Cos => TapeOp::Cos(v),
1592            }));
1593            i
1594        }
1595        Expr::Sum(args) => {
1596            if args.is_empty() {
1597                let i = local.len();
1598                local.push(SummandOp::Local(TapeOp::Const(0.0)));
1599                return i;
1600            }
1601            let mut acc = build_into_summand(
1602                &args[0],
1603                local,
1604                local_cache,
1605                prelude,
1606                prelude_map,
1607                cse_count,
1608            );
1609            for a in &args[1..] {
1610                let nxt =
1611                    build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
1612                let i = local.len();
1613                local.push(SummandOp::Local(TapeOp::Add(acc, nxt)));
1614                acc = i;
1615            }
1616            acc
1617        }
1618        Expr::Cse(body) => {
1619            let key = Rc::as_ptr(body) as *const Expr;
1620            if let Some(&li) = local_cache.get(&key) {
1621                return li;
1622            }
1623            let promoted = cse_count.get(&key).copied().unwrap_or(0) >= 2;
1624            if promoted {
1625                // Build (or reuse) the prelude slot for this CSE.
1626                // `build_recursive(expr, ...)` hits the Cse arm,
1627                // emits the body once into prelude, and caches it
1628                // in `prelude_map` keyed by this Rc pointer.
1629                let pslot =
1630                    build_recursive(expr, prelude, prelude_map, &ExternalResolver::default());
1631                let li = local.len();
1632                local.push(SummandOp::Shared(pslot));
1633                local_cache.insert(key, li);
1634                li
1635            } else {
1636                let li =
1637                    build_into_summand(body, local, local_cache, prelude, prelude_map, cse_count);
1638                local_cache.insert(key, li);
1639                li
1640            }
1641        }
1642        Expr::Funcall { .. } => {
1643            panic!(
1644                "HybridTape: AMPL external function calls are not supported on the \
1645                 hybrid (partial-separability) tape path. Build with Tape::build_with_externals \
1646                 instead."
1647            );
1648        }
1649    }
1650}
1651
1652/// Pow-lowering specialised for summand builds. Mirrors
1653/// `try_emit_const_pow` but with summand-flavoured emission.
1654fn try_emit_const_pow_summand(
1655    base_expr: &Expr,
1656    c: f64,
1657    local: &mut Vec<SummandOp>,
1658    local_cache: &mut HashMap<*const Expr, usize>,
1659    prelude: &mut Vec<TapeOp>,
1660    prelude_map: &mut HashMap<*const Expr, usize>,
1661    cse_count: &HashMap<*const Expr, usize>,
1662) -> Option<usize> {
1663    if c == 0.0 {
1664        let i = local.len();
1665        local.push(SummandOp::Local(TapeOp::Const(1.0)));
1666        return Some(i);
1667    }
1668    if c == 1.0 {
1669        return Some(build_into_summand(
1670            base_expr,
1671            local,
1672            local_cache,
1673            prelude,
1674            prelude_map,
1675            cse_count,
1676        ));
1677    }
1678    if c == 0.5 {
1679        let b = build_into_summand(
1680            base_expr,
1681            local,
1682            local_cache,
1683            prelude,
1684            prelude_map,
1685            cse_count,
1686        );
1687        let i = local.len();
1688        local.push(SummandOp::Local(TapeOp::Sqrt(b)));
1689        return Some(i);
1690    }
1691    if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
1692        let n = c.abs() as u32;
1693        if n == 0 {
1694            let i = local.len();
1695            local.push(SummandOp::Local(TapeOp::Const(1.0)));
1696            return Some(i);
1697        }
1698        let b = build_into_summand(
1699            base_expr,
1700            local,
1701            local_cache,
1702            prelude,
1703            prelude_map,
1704            cse_count,
1705        );
1706        let pos = emit_int_pow_summand(b, n, local);
1707        if c < 0.0 {
1708            let one_idx = local.len();
1709            local.push(SummandOp::Local(TapeOp::Const(1.0)));
1710            let i = local.len();
1711            local.push(SummandOp::Local(TapeOp::Div(one_idx, pos)));
1712            return Some(i);
1713        }
1714        return Some(pos);
1715    }
1716    None
1717}
1718
1719fn emit_int_pow_summand(base: usize, n: u32, local: &mut Vec<SummandOp>) -> usize {
1720    debug_assert!(n >= 1);
1721    if n == 1 {
1722        return base;
1723    }
1724    let half = emit_int_pow_summand(base, n / 2, local);
1725    let squared = local.len();
1726    local.push(SummandOp::Local(TapeOp::Mul(half, half)));
1727    if n % 2 == 1 {
1728        let i = local.len();
1729        local.push(SummandOp::Local(TapeOp::Mul(squared, base)));
1730        i
1731    } else {
1732        squared
1733    }
1734}
1735
1736/// Walk a summand's local op DAG from `root`, returning the
1737/// reachable local slots (sorted ascending) plus the distinct
1738/// prelude slots referenced by any Shared op along the way.
1739fn compute_local_reach(ops: &[SummandOp], root: usize) -> (Vec<usize>, Vec<usize>) {
1740    let mut visited = vec![false; ops.len()];
1741    let mut reach: Vec<usize> = Vec::new();
1742    let mut shared: BTreeSet<usize> = BTreeSet::new();
1743    let mut stack: Vec<usize> = Vec::with_capacity(16);
1744    visited[root] = true;
1745    reach.push(root);
1746    stack.push(root);
1747    while let Some(s) = stack.pop() {
1748        match &ops[s] {
1749            SummandOp::Local(op) => {
1750                let (a, b) = op_operands(op);
1751                if let Some(a) = a {
1752                    if !visited[a] {
1753                        visited[a] = true;
1754                        reach.push(a);
1755                        stack.push(a);
1756                    }
1757                }
1758                if let Some(b) = b {
1759                    if !visited[b] {
1760                        visited[b] = true;
1761                        reach.push(b);
1762                        stack.push(b);
1763                    }
1764                }
1765            }
1766            SummandOp::Shared(k) => {
1767                shared.insert(*k);
1768            }
1769        }
1770    }
1771    reach.sort_unstable();
1772    (reach, shared.into_iter().collect())
1773}
1774
1775/// Epoch-tagged BFS over the prelude operand DAG, accumulating
1776/// reachable slots into `out`. Caller is responsible for sorting
1777/// `out` after a batch of starts has been processed.
1778fn bfs_prelude(
1779    prelude: &[TapeOp],
1780    start: usize,
1781    visited: &mut [u32],
1782    cur: u32,
1783    stack: &mut Vec<usize>,
1784    out: &mut Vec<usize>,
1785) {
1786    if visited[start] == cur {
1787        return;
1788    }
1789    visited[start] = cur;
1790    out.push(start);
1791    stack.push(start);
1792    while let Some(s) = stack.pop() {
1793        let (a, b) = op_operands(&prelude[s]);
1794        if let Some(a) = a {
1795            if visited[a] != cur {
1796                visited[a] = cur;
1797                out.push(a);
1798                stack.push(a);
1799            }
1800        }
1801        if let Some(b) = b {
1802            if visited[b] != cur {
1803                visited[b] = cur;
1804                out.push(b);
1805                stack.push(b);
1806            }
1807        }
1808    }
1809}
1810
1811/// Per-op var-set for the prelude — every slot's transitive
1812/// variable footprint. Used by `summand_sparsity` to expand
1813/// `SummandOp::Shared(k)` into its var-set carrier.
1814fn compute_var_sets(ops: &[TapeOp]) -> Vec<BTreeSet<usize>> {
1815    let mut out: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
1816    for op in ops {
1817        let vs: BTreeSet<usize> = match op {
1818            TapeOp::Const(_) => BTreeSet::new(),
1819            TapeOp::Var(j) => {
1820                let mut s = BTreeSet::new();
1821                s.insert(*j);
1822                s
1823            }
1824            TapeOp::Add(a, b)
1825            | TapeOp::Sub(a, b)
1826            | TapeOp::Mul(a, b)
1827            | TapeOp::Div(a, b)
1828            | TapeOp::Pow(a, b) => out[*a].union(&out[*b]).copied().collect(),
1829            TapeOp::Neg(a)
1830            | TapeOp::Abs(a)
1831            | TapeOp::Sqrt(a)
1832            | TapeOp::Exp(a)
1833            | TapeOp::Log(a)
1834            | TapeOp::Log10(a)
1835            | TapeOp::Sin(a)
1836            | TapeOp::Cos(a) => out[*a].clone(),
1837            TapeOp::Funcall { .. } => unreachable!(
1838                "HybridTape prelude cannot contain TapeOp::Funcall; \
1839                 build_into_summand panics on Expr::Funcall."
1840            ),
1841        };
1842        out.push(vs);
1843    }
1844    out
1845}
1846
1847/// Per-op Hessian-sparsity propagation over a summand's mixed
1848/// SummandOp slice. Shared refs contribute their prelude var-set
1849/// but do not themselves emit pairs (those came from
1850/// `hessian_sparsity_impl(&prelude)`).
1851fn summand_sparsity(
1852    ops: &[SummandOp],
1853    prelude_var_sets: &[BTreeSet<usize>],
1854    pairs: &mut BTreeSet<(usize, usize)>,
1855) {
1856    let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
1857    let emit_cross =
1858        |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1859            for &v1 in s1 {
1860                for &v2 in s2 {
1861                    let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
1862                    pairs.insert((r, c));
1863                }
1864            }
1865        };
1866    let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1867        let vars: Vec<usize> = s.iter().copied().collect();
1868        for (ai, &vi) in vars.iter().enumerate() {
1869            for &vj in &vars[..=ai] {
1870                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
1871                pairs.insert((r, c));
1872            }
1873        }
1874    };
1875    for so in ops {
1876        let vset: BTreeSet<usize> = match so {
1877            SummandOp::Shared(k) => prelude_var_sets[*k].clone(),
1878            SummandOp::Local(op) => match op {
1879                TapeOp::Const(_) => BTreeSet::new(),
1880                TapeOp::Var(j) => {
1881                    let mut s = BTreeSet::new();
1882                    s.insert(*j);
1883                    s
1884                }
1885                TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
1886                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1887                }
1888                TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
1889                TapeOp::Mul(a, b) => {
1890                    emit_cross(&var_sets[*a], &var_sets[*b], pairs);
1891                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1892                }
1893                TapeOp::Div(a, b) => {
1894                    emit_cross(&var_sets[*a], &var_sets[*b], pairs);
1895                    emit_self(&var_sets[*b], pairs);
1896                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1897                }
1898                TapeOp::Pow(a, b) => {
1899                    let combined: BTreeSet<usize> =
1900                        var_sets[*a].union(&var_sets[*b]).copied().collect();
1901                    emit_self(&combined, pairs);
1902                    combined
1903                }
1904                TapeOp::Sqrt(a)
1905                | TapeOp::Exp(a)
1906                | TapeOp::Log(a)
1907                | TapeOp::Log10(a)
1908                | TapeOp::Sin(a)
1909                | TapeOp::Cos(a) => {
1910                    emit_self(&var_sets[*a], pairs);
1911                    var_sets[*a].clone()
1912                }
1913                TapeOp::Funcall { .. } => unreachable!(
1914                    "HybridTape summand cannot contain TapeOp::Funcall; \
1915                     build_into_summand panics on Expr::Funcall."
1916                ),
1917            },
1918        };
1919        var_sets.push(vset);
1920    }
1921}
1922
1923/// Operand indices of a `TapeOp`, normalized into a fixed-length
1924/// array so callers don't need to re-match every site.
1925#[inline]
1926fn op_operands(op: &TapeOp) -> (Option<usize>, Option<usize>) {
1927    match op {
1928        TapeOp::Const(_) | TapeOp::Var(_) => (None, None),
1929        TapeOp::Add(a, b)
1930        | TapeOp::Sub(a, b)
1931        | TapeOp::Mul(a, b)
1932        | TapeOp::Div(a, b)
1933        | TapeOp::Pow(a, b) => (Some(*a), Some(*b)),
1934        TapeOp::Neg(a)
1935        | TapeOp::Abs(a)
1936        | TapeOp::Sqrt(a)
1937        | TapeOp::Exp(a)
1938        | TapeOp::Log(a)
1939        | TapeOp::Log10(a)
1940        | TapeOp::Sin(a)
1941        | TapeOp::Cos(a) => (Some(*a), None),
1942        TapeOp::Funcall { .. } => (None, None),
1943    }
1944}
1945
1946fn vars_in(ops: &[TapeOp], reach: &[usize]) -> Vec<usize> {
1947    let mut s: BTreeSet<usize> = BTreeSet::new();
1948    for &i in reach {
1949        if let TapeOp::Var(j) = &ops[i] {
1950            s.insert(*j);
1951        }
1952    }
1953    s.into_iter().collect()
1954}
1955
// ----- Free-function AD step kernels shared by GlobalTape and HybridTape -----
1957
1958#[inline]
1959fn fwd_step(op: &TapeOp, x: &[f64], vals: &[f64]) -> f64 {
1960    match op {
1961        TapeOp::Const(c) => *c,
1962        TapeOp::Var(i) => x[*i],
1963        TapeOp::Add(a, b) => vals[*a] + vals[*b],
1964        TapeOp::Sub(a, b) => vals[*a] - vals[*b],
1965        TapeOp::Mul(a, b) => vals[*a] * vals[*b],
1966        TapeOp::Div(a, b) => vals[*a] / vals[*b],
1967        TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
1968        TapeOp::Neg(a) => -vals[*a],
1969        TapeOp::Abs(a) => vals[*a].abs(),
1970        TapeOp::Sqrt(a) => vals[*a].sqrt(),
1971        TapeOp::Exp(a) => vals[*a].exp(),
1972        TapeOp::Log(a) => vals[*a].ln(),
1973        TapeOp::Log10(a) => vals[*a].log10(),
1974        TapeOp::Sin(a) => vals[*a].sin(),
1975        TapeOp::Cos(a) => vals[*a].cos(),
1976        TapeOp::Funcall { lib, name, args } => {
1977            let call_args = funcall_to_ext_args(args, vals);
1978            let res = lib
1979                .eval(name, &call_args, false, false)
1980                .unwrap_or_else(|e| panic!("external function '{name}' eval failed: {e}"));
1981            res.value
1982        }
1983    }
1984}
1985
1986#[inline]
1987fn rev_step(op: &TapeOp, i: usize, vals: &[f64], adj: &mut [f64], a: f64, grad: &mut [f64]) {
1988    match op {
1989        TapeOp::Const(_) => {}
1990        TapeOp::Var(j) => {
1991            grad[*j] += a;
1992        }
1993        TapeOp::Add(l, r) => {
1994            adj[*l] += a;
1995            adj[*r] += a;
1996        }
1997        TapeOp::Sub(l, r) => {
1998            adj[*l] += a;
1999            adj[*r] -= a;
2000        }
2001        TapeOp::Mul(l, r) => {
2002            adj[*l] += a * vals[*r];
2003            adj[*r] += a * vals[*l];
2004        }
2005        TapeOp::Div(l, r) => {
2006            let rv = vals[*r];
2007            adj[*l] += a / rv;
2008            adj[*r] -= a * vals[*l] / (rv * rv);
2009        }
2010        TapeOp::Pow(l, r) => {
2011            let lv = vals[*l];
2012            let rv = vals[*r];
2013            if rv != 0.0 {
2014                adj[*l] += a * rv * lv.powf(rv - 1.0);
2015            }
2016            if lv > 0.0 {
2017                adj[*r] += a * vals[i] * lv.ln();
2018            }
2019        }
2020        TapeOp::Neg(j) => {
2021            adj[*j] -= a;
2022        }
2023        TapeOp::Abs(j) => {
2024            if vals[*j] >= 0.0 {
2025                adj[*j] += a;
2026            } else {
2027                adj[*j] -= a;
2028            }
2029        }
2030        TapeOp::Sqrt(j) => {
2031            let sv = vals[i];
2032            if sv > 0.0 {
2033                adj[*j] += a * 0.5 / sv;
2034            }
2035        }
2036        TapeOp::Exp(j) => {
2037            adj[*j] += a * vals[i];
2038        }
2039        TapeOp::Log(j) => {
2040            adj[*j] += a / vals[*j];
2041        }
2042        TapeOp::Log10(j) => {
2043            adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
2044        }
2045        TapeOp::Sin(j) => {
2046            adj[*j] += a * vals[*j].cos();
2047        }
2048        TapeOp::Cos(j) => {
2049            adj[*j] -= a * vals[*j].sin();
2050        }
2051        TapeOp::Funcall { lib, name, args } => {
2052            let call_args = funcall_to_ext_args(args, vals);
2053            let res = lib
2054                .eval(name, &call_args, true, false)
2055                .unwrap_or_else(|e| panic!("external function '{name}' reverse eval failed: {e}"));
2056            let derivs = res.derivs.expect("want_derivs=true returns derivs");
2057            let mut k = 0usize;
2058            for arg in args {
2059                if let TapeFuncallArg::Tape(idx) = arg {
2060                    adj[*idx] += a * derivs[k];
2061                    k += 1;
2062                }
2063            }
2064            let _ = i;
2065            let _ = grad;
2066        }
2067    }
2068}
2069
2070#[inline]
2071fn fwd_tan_step(op: &TapeOp, seed_var: usize, vals: &[f64], dot: &[f64], i: usize) -> f64 {
2072    match op {
2073        TapeOp::Const(_) => 0.0,
2074        TapeOp::Var(k) => {
2075            if *k == seed_var {
2076                1.0
2077            } else {
2078                0.0
2079            }
2080        }
2081        TapeOp::Add(a, b) => dot[*a] + dot[*b],
2082        TapeOp::Sub(a, b) => dot[*a] - dot[*b],
2083        TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
2084        TapeOp::Div(a, b) => {
2085            let vb = vals[*b];
2086            (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
2087        }
2088        TapeOp::Pow(a, b) => {
2089            let u = vals[*a];
2090            let r = vals[*b];
2091            let du = dot[*a];
2092            let dr = dot[*b];
2093            let mut result = 0.0;
2094            if r != 0.0 && u != 0.0 {
2095                result += r * u.powf(r - 1.0) * du;
2096            }
2097            if u > 0.0 {
2098                result += vals[i] * u.ln() * dr;
2099            }
2100            result
2101        }
2102        TapeOp::Neg(a) => -dot[*a],
2103        TapeOp::Abs(a) => {
2104            if vals[*a] >= 0.0 {
2105                dot[*a]
2106            } else {
2107                -dot[*a]
2108            }
2109        }
2110        TapeOp::Sqrt(a) => {
2111            let sv = vals[i];
2112            if sv > 0.0 {
2113                dot[*a] * 0.5 / sv
2114            } else {
2115                0.0
2116            }
2117        }
2118        TapeOp::Exp(a) => dot[*a] * vals[i],
2119        TapeOp::Log(a) => dot[*a] / vals[*a],
2120        TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
2121        TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
2122        TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
2123        TapeOp::Funcall { lib, name, args } => {
2124            let call_args = funcall_to_ext_args(args, vals);
2125            let res = lib
2126                .eval(name, &call_args, true, false)
2127                .unwrap_or_else(|e| panic!("external function '{name}' tangent eval failed: {e}"));
2128            let derivs = res.derivs.expect("want_derivs=true returns derivs");
2129            let mut acc = 0.0;
2130            let mut k = 0usize;
2131            for arg in args {
2132                if let TapeFuncallArg::Tape(idx) = arg {
2133                    acc += derivs[k] * dot[*idx];
2134                    k += 1;
2135                }
2136            }
2137            let _ = seed_var;
2138            acc
2139        }
2140    }
2141}
2142
2143#[allow(clippy::too_many_arguments)]
2144#[inline]
2145fn ror_step(
2146    op: &TapeOp,
2147    i: usize,
2148    seed_var: usize,
2149    vals: &[f64],
2150    dot: &[f64],
2151    adj: &mut [f64],
2152    adj_dot: &mut [f64],
2153    w: f64,
2154    wd: f64,
2155    weight: f64,
2156    hess_map: &HashMap<(usize, usize), usize>,
2157    values: &mut [f64],
2158) {
2159    match op {
2160        TapeOp::Const(_) => {}
2161        TapeOp::Var(k) => {
2162            if wd != 0.0 && *k >= seed_var {
2163                if let Some(&pos) = hess_map.get(&(*k, seed_var)) {
2164                    values[pos] += weight * wd;
2165                }
2166            }
2167        }
2168        TapeOp::Add(a, b) => {
2169            adj[*a] += w;
2170            adj[*b] += w;
2171            adj_dot[*a] += wd;
2172            adj_dot[*b] += wd;
2173        }
2174        TapeOp::Sub(a, b) => {
2175            adj[*a] += w;
2176            adj[*b] -= w;
2177            adj_dot[*a] += wd;
2178            adj_dot[*b] -= wd;
2179        }
2180        TapeOp::Mul(a, b) => {
2181            adj[*a] += w * vals[*b];
2182            adj[*b] += w * vals[*a];
2183            adj_dot[*a] += wd * vals[*b] + w * dot[*b];
2184            adj_dot[*b] += wd * vals[*a] + w * dot[*a];
2185        }
2186        TapeOp::Div(a, b) => {
2187            let vb = vals[*b];
2188            let vb2 = vb * vb;
2189            let vb3 = vb2 * vb;
2190            adj[*a] += w / vb;
2191            adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
2192            adj[*b] += w * (-vals[*a] / vb2);
2193            adj_dot[*b] +=
2194                wd * (-vals[*a] / vb2) + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
2195        }
2196        TapeOp::Pow(a, b) => {
2197            let u = vals[*a];
2198            let r = vals[*b];
2199            let du = dot[*a];
2200            let dr = dot[*b];
2201            if r != 0.0 {
2202                if u != 0.0 {
2203                    let p_a = r * u.powf(r - 1.0);
2204                    adj[*a] += w * p_a;
2205                    let mut dp_a = dr * u.powf(r - 1.0);
2206                    if u > 0.0 {
2207                        dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
2208                    } else {
2209                        dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
2210                    }
2211                    adj_dot[*a] += wd * p_a + w * dp_a;
2212                } else if r >= 2.0 {
2213                    let p_a = 0.0;
2214                    adj[*a] += w * p_a;
2215                    let dp_a = if r == 2.0 {
2216                        2.0 * du
2217                    } else {
2218                        r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
2219                    };
2220                    adj_dot[*a] += wd * p_a + w * dp_a;
2221                }
2222            }
2223            if u > 0.0 {
2224                let ln_u = u.ln();
2225                let p_b = vals[i] * ln_u;
2226                adj[*b] += w * p_b;
2227                let dur = vals[i] * (r * du / u + dr * ln_u);
2228                let dp_b = dur * ln_u + vals[i] * du / u;
2229                adj_dot[*b] += wd * p_b + w * dp_b;
2230            }
2231        }
2232        TapeOp::Neg(a) => {
2233            adj[*a] -= w;
2234            adj_dot[*a] -= wd;
2235        }
2236        TapeOp::Abs(a) => {
2237            let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
2238            adj[*a] += w * s;
2239            adj_dot[*a] += wd * s;
2240        }
2241        TapeOp::Sqrt(a) => {
2242            let sv = vals[i];
2243            if sv > 0.0 {
2244                let fp = 0.5 / sv;
2245                let fpp = -0.25 / (vals[*a] * sv);
2246                adj[*a] += w * fp;
2247                adj_dot[*a] += wd * fp + w * fpp * dot[*a];
2248            }
2249        }
2250        TapeOp::Exp(a) => {
2251            let ev = vals[i];
2252            adj[*a] += w * ev;
2253            adj_dot[*a] += wd * ev + w * ev * dot[*a];
2254        }
2255        TapeOp::Log(a) => {
2256            let u = vals[*a];
2257            adj[*a] += w / u;
2258            adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
2259        }
2260        TapeOp::Log10(a) => {
2261            let u = vals[*a];
2262            let c = std::f64::consts::LN_10;
2263            adj[*a] += w / (u * c);
2264            adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
2265        }
2266        TapeOp::Sin(a) => {
2267            let u = vals[*a];
2268            let cu = u.cos();
2269            adj[*a] += w * cu;
2270            adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
2271        }
2272        TapeOp::Cos(a) => {
2273            let u = vals[*a];
2274            let su = u.sin();
2275            adj[*a] -= w * su;
2276            adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
2277        }
2278        TapeOp::Funcall { lib, name, args } => {
2279            let call_args = funcall_to_ext_args(args, vals);
2280            let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
2281                panic!("external function '{name}' 2nd-order eval failed: {e}")
2282            });
2283            let derivs = res.derivs.expect("want_derivs=true returns derivs");
2284            let hes = res.hessian.expect("want_hes=true returns hessian");
2285            let real_tape: Vec<usize> = args
2286                .iter()
2287                .filter_map(|a| match a {
2288                    TapeFuncallArg::Tape(t) => Some(*t),
2289                    TapeFuncallArg::Str(_) => None,
2290                })
2291                .collect();
2292            for (k, &tk) in real_tape.iter().enumerate() {
2293                adj[tk] += w * derivs[k];
2294                let mut second_term = 0.0;
2295                for (l, &tl) in real_tape.iter().enumerate() {
2296                    let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
2297                    let h_kl = hes[lo + hi * (hi + 1) / 2];
2298                    second_term += h_kl * dot[tl];
2299                }
2300                adj_dot[tk] += wd * derivs[k] + w * second_term;
2301            }
2302            let _ = seed_var;
2303            let _ = hess_map;
2304            let _ = values;
2305            let _ = weight;
2306            let _ = i;
2307        }
2308    }
2309}
2310
2311/// Per-op Hessian-sparsity propagation. Same algorithm as
2312/// `Tape::hessian_sparsity` but as a free function so `GlobalTape`
2313/// can call it over its shared `ops` slice.
2314fn hessian_sparsity_impl(ops: &[TapeOp]) -> BTreeSet<(usize, usize)> {
2315    let n = ops.len();
2316    let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
2317    let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
2318
2319    let emit_cross =
2320        |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2321            for &v1 in s1 {
2322                for &v2 in s2 {
2323                    let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
2324                    pairs.insert((r, c));
2325                }
2326            }
2327        };
2328    let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2329        let vars: Vec<usize> = s.iter().copied().collect();
2330        for (ai, &vi) in vars.iter().enumerate() {
2331            for &vj in &vars[..=ai] {
2332                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2333                pairs.insert((r, c));
2334            }
2335        }
2336    };
2337
2338    for op in ops {
2339        let vset = match op {
2340            TapeOp::Const(_) => BTreeSet::new(),
2341            TapeOp::Var(j) => {
2342                let mut s = BTreeSet::new();
2343                s.insert(*j);
2344                s
2345            }
2346            TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
2347                var_sets[*a].union(&var_sets[*b]).copied().collect()
2348            }
2349            TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
2350            TapeOp::Mul(a, b) => {
2351                emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
2352                var_sets[*a].union(&var_sets[*b]).copied().collect()
2353            }
2354            TapeOp::Div(a, b) => {
2355                emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
2356                emit_self(&var_sets[*b], &mut pairs);
2357                var_sets[*a].union(&var_sets[*b]).copied().collect()
2358            }
2359            TapeOp::Pow(a, b) => {
2360                let combined: BTreeSet<usize> =
2361                    var_sets[*a].union(&var_sets[*b]).copied().collect();
2362                emit_self(&combined, &mut pairs);
2363                combined
2364            }
2365            TapeOp::Sqrt(a)
2366            | TapeOp::Exp(a)
2367            | TapeOp::Log(a)
2368            | TapeOp::Log10(a)
2369            | TapeOp::Sin(a)
2370            | TapeOp::Cos(a) => {
2371                emit_self(&var_sets[*a], &mut pairs);
2372                var_sets[*a].clone()
2373            }
2374            TapeOp::Funcall { args, .. } => {
2375                let mut combined: BTreeSet<usize> = BTreeSet::new();
2376                for arg in args {
2377                    if let TapeFuncallArg::Tape(t) = arg {
2378                        for &vv in &var_sets[*t] {
2379                            combined.insert(vv);
2380                        }
2381                    }
2382                }
2383                emit_self(&combined, &mut pairs);
2384                combined
2385            }
2386        };
2387        var_sets.push(vset);
2388    }
2389    pairs
2390}
2391
/// Unit tests for tape construction (including CSE sharing and constant
/// `Pow` lowering), forward evaluation, reverse gradients, Hessian
/// accumulation, directional Hessian products and Hessian sparsity.
#[cfg(test)]
mod tests {
    use super::*;

    // Terse `Expr` constructors so test expressions read like formulas.
    fn cnst(c: f64) -> Expr {
        Expr::Const(c)
    }
    fn var(i: usize) -> Expr {
        Expr::Var(i)
    }
    fn add(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
    }
    fn mul(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
    }
    fn pow(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
    }
    fn div(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
    }
    fn unary(op: UnaryOp, a: Expr) -> Expr {
        Expr::Unary(op, Box::new(a))
    }

    /// Forward value and reverse-mode gradient of a small polynomial.
    #[test]
    fn polynomial_eval_and_grad() {
        // f = 3*x0^2 + 2*x1
        let e = add(
            mul(cnst(3.0), pow(var(0), cnst(2.0))),
            mul(cnst(2.0), var(1)),
        );
        let t = Tape::build(&e);
        assert!((t.eval(&[2.0, 3.0]) - 18.0).abs() < 1e-12);
        let mut g = vec![0.0; 2];
        t.gradient_seed(&[2.0, 3.0], 1.0, &mut g);
        // df/dx0 = 6*x0 = 12, df/dx1 = 2
        assert!((g[0] - 12.0).abs() < 1e-12);
        assert!((g[1] - 2.0).abs() < 1e-12);
    }

    /// A CSE body referenced twice is emitted once on the tape, and the
    /// gradient still sums the contributions of both references.
    #[test]
    fn cse_shared_body_evaluated_once() {
        // body = x0 + x1, shared via Rc. f = body^2 + body.
        let body = Rc::new(add(var(0), var(1)));
        let e = add(
            pow(Expr::Cse(body.clone()), cnst(2.0)),
            Expr::Cse(body.clone()),
        );
        let t = Tape::build(&e);
        // body should appear once in the tape: count Add(Var(0),Var(1)) ops
        let n_body_adds = t
            .ops
            .iter()
            .filter(|op| {
                matches!(op, TapeOp::Add(a, b) if {
                    matches!(t.ops[*a], TapeOp::Var(0)) && matches!(t.ops[*b], TapeOp::Var(1))
                })
            })
            .count();
        assert_eq!(n_body_adds, 1, "CSE body should be emitted exactly once");

        // f(1, 2) = 9 + 3 = 12
        assert!((t.eval(&[1.0, 2.0]) - 12.0).abs() < 1e-12);
        let mut g = vec![0.0; 2];
        t.gradient_seed(&[1.0, 2.0], 1.0, &mut g);
        // df/dx0 = 2*(x0+x1) + 1 = 7, same for x1
        assert!((g[0] - 7.0).abs() < 1e-12);
        assert!((g[1] - 7.0).abs() < 1e-12);
    }

    /// Checks `hessian_accumulate` against central finite differences of
    /// `gradient_seed`.
    ///
    /// Builds a dense lower-triangle `(row, col) -> position` map over
    /// `tape.variables()`, then for each variable `j` differences the
    /// gradient at `x ± h e_j` with `h = max(1e-7, 1e-7 * |x_j|)`.
    /// `n` is the length of the gradient buffers; `tol` is a relative
    /// tolerance scaled by `max(|FD|, 1)`. Panics on the first mismatch.
    fn fd_check(tape: &Tape, x: &[f64], n: usize, tol: f64) {
        let vars = tape.variables();
        let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
        let mut pairs = Vec::new();
        for (ai, &vi) in vars.iter().enumerate() {
            for &vj in &vars[..=ai] {
                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
                hess_map.entry((r, c)).or_insert_with(|| {
                    let p = pairs.len();
                    pairs.push((r, c));
                    p
                });
            }
        }
        let nnz = pairs.len();
        let mut ad = vec![0.0; nnz];
        tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);

        let mut fd = vec![0.0; nnz];
        let mut xp = x.to_vec();
        let mut gp = vec![0.0; n];
        let mut gm = vec![0.0; n];
        for &j in &vars {
            let h = (1e-7_f64).max(x[j].abs() * 1e-7);
            xp[j] = x[j] + h;
            gp.iter_mut().for_each(|v| *v = 0.0);
            tape.gradient_seed(&xp, 1.0, &mut gp);
            xp[j] = x[j] - h;
            gm.iter_mut().for_each(|v| *v = 0.0);
            tape.gradient_seed(&xp, 1.0, &mut gm);
            // Restore x_j so the next column starts from the base point.
            xp[j] = x[j];
            for &i in &vars {
                // Only the lower triangle (i >= j) is stored in the map.
                if i >= j {
                    if let Some(&pos) = hess_map.get(&(i, j)) {
                        fd[pos] = (gp[i] - gm[i]) / (2.0 * h);
                    }
                }
            }
        }
        for (k, &(r, c)) in pairs.iter().enumerate() {
            let scale = fd[k].abs().max(1.0);
            assert!(
                (ad[k] - fd[k]).abs() / scale < tol,
                "H[{},{}]: AD={:.6e} FD={:.6e}",
                r,
                c,
                ad[k],
                fd[k]
            );
        }
    }

    /// Hessian of a quadratic with a cross term matches finite differences.
    #[test]
    fn hessian_quadratic_matches_fd() {
        // f = 3 x0^2 + 2 x0 x1 + x1^2
        let e = add(
            add(
                mul(cnst(3.0), pow(var(0), cnst(2.0))),
                mul(cnst(2.0), mul(var(0), var(1))),
            ),
            pow(var(1), cnst(2.0)),
        );
        let t = Tape::build(&e);
        fd_check(&t, &[2.0, 3.0], 2, 1e-5);
    }

    /// Hessian through exp/sin/log/sqrt plus a product matches finite differences.
    #[test]
    fn hessian_transcendental_matches_fd() {
        // f = exp(x0) + sin(x1) + log(x0) + sqrt(x1) + x0*x1
        let e = Expr::Sum(vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
        ]);
        let t = Tape::build(&e);
        fd_check(&t, &[1.5, 2.0], 2, 1e-5);
    }

    /// Hessian through division and cos matches finite differences.
    #[test]
    fn hessian_division_matches_fd() {
        // f = x0/x1 + cos(x0)
        let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
        let t = Tape::build(&e);
        fd_check(&t, &[0.5, 1.2], 2, 1e-5);
    }

    /// `hessian_directional` (one forward-over-reverse pass with
    /// a seed vector) recovers `H · e_j` for each unit-vector seed,
    /// matching column `j` of the dense Hessian computed by
    /// `hessian_accumulate`.
    ///
    /// `n` is the length of the seed and output-column vectors. The
    /// `dot`/`adj`/`adj_dot` work buffers are sized to the tape and reused
    /// across seeds; the comparison tolerance is an absolute 1e-10.
    fn directional_matches_accumulate(tape: &Tape, x: &[f64], n: usize) {
        let vars = tape.variables();
        let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
        let mut pairs = Vec::new();
        for (ai, &vi) in vars.iter().enumerate() {
            for &vj in &vars[..=ai] {
                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
                hess_map.entry((r, c)).or_insert_with(|| {
                    let p = pairs.len();
                    pairs.push((r, c));
                    p
                });
            }
        }
        let nnz = pairs.len();
        let mut ad = vec![0.0; nnz];
        tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);

        let nops = tape.ops.len();
        let mut vals = vec![0.0; nops];
        tape.forward_into(x, &mut vals);
        let mut dot = vec![0.0; nops];
        let mut adj = vec![0.0; nops];
        let mut adj_dot = vec![0.0; nops];

        for &j in &vars {
            let mut seed = vec![0.0; n];
            seed[j] = 1.0;
            let mut col = vec![0.0; n];
            tape.hessian_directional(
                &vals,
                &seed,
                1.0,
                &mut col,
                &mut dot,
                &mut adj,
                &mut adj_dot,
            );
            for &i in &vars {
                // Look up the symmetric entry via its lower-triangle key.
                let (r, c) = if i >= j { (i, j) } else { (j, i) };
                let expect = ad[hess_map[&(r, c)]];
                assert!(
                    (col[i] - expect).abs() < 1e-10,
                    "directional H[{i},{j}] = {} vs accumulate {}",
                    col[i],
                    expect
                );
            }
        }
    }

    /// Directional Hessian columns match `hessian_accumulate` for a quadratic.
    #[test]
    fn directional_quadratic_matches_accumulate() {
        // f = 3 x0^2 + 2 x0 x1 + x1^2
        let e = add(
            add(
                mul(cnst(3.0), pow(var(0), cnst(2.0))),
                mul(mul(cnst(2.0), var(0)), var(1)),
            ),
            pow(var(1), cnst(2.0)),
        );
        let t = Tape::build(&e);
        directional_matches_accumulate(&t, &[0.5, -0.3], 2);
    }

    /// Directional Hessian columns match `hessian_accumulate` through
    /// transcendental ops.
    #[test]
    fn directional_transcendental_matches_accumulate() {
        let e = Expr::Sum(vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
        ]);
        let t = Tape::build(&e);
        directional_matches_accumulate(&t, &[1.5, 2.0], 2);
    }

    /// Directional Hessian columns match `hessian_accumulate` through division.
    #[test]
    fn directional_with_division_matches_accumulate() {
        let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
        let t = Tape::build(&e);
        directional_matches_accumulate(&t, &[0.5, 1.2], 2);
    }

    /// Separable terms produce no cross-variable sparsity entries.
    #[test]
    fn hessian_sparsity_separable() {
        // f = sin(x0) + x1*x2; couplings: (0,0) from sin, (2,1) from x1*x2
        let e = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
        let t = Tape::build(&e);
        let s = t.hessian_sparsity();
        assert!(s.contains(&(0, 0)));
        assert!(s.contains(&(2, 1)));
        assert!(!s.contains(&(1, 0)));
        assert!(!s.contains(&(2, 0)));
    }

    /// Number of ops on `t` for which `pred` returns true.
    fn count_op<F: Fn(&TapeOp) -> bool>(t: &Tape, pred: F) -> usize {
        t.ops.iter().filter(|o| pred(o)).count()
    }

    // The `pow_*` tests below pin how `Tape::build` lowers `Pow` with a
    // constant exponent, by counting the op kinds emitted.

    #[test]
    fn pow_zero_const_folds_to_one() {
        // x^0 → 1 (no Pow, no reference to x in the tape)
        let e = pow(var(0), cnst(0.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Var(_))), 0);
        assert!((t.eval(&[7.0]) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn pow_one_passes_through() {
        // x^1 → x (no Pow, no Const introduced for the exponent)
        let e = pow(var(0), cnst(1.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Const(_))), 0);
        assert!((t.eval(&[3.5]) - 3.5).abs() < 1e-12);
    }

    #[test]
    fn pow_half_lowers_to_sqrt() {
        // x^0.5 → sqrt(x)
        let e = pow(var(0), cnst(0.5));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Sqrt(_))), 1);
        assert!((t.eval(&[16.0]) - 4.0).abs() < 1e-12);
    }

    #[test]
    fn pow_two_lowers_to_single_mul() {
        // x^2 → x*x
        let e = pow(var(0), cnst(2.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
        assert!((t.eval(&[3.0]) - 9.0).abs() < 1e-12);
    }

    #[test]
    fn pow_three_lowers_to_two_muls() {
        // x^3 → two multiplications
        let e = pow(var(0), cnst(3.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 2);
        assert!((t.eval(&[2.0]) - 8.0).abs() < 1e-12);
    }

    #[test]
    fn pow_eight_lowers_to_three_muls() {
        // Binary expansion: x → x² → x⁴ → x⁸ (3 squarings)
        let e = pow(var(0), cnst(8.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 3);
        assert!((t.eval(&[2.0]) - 256.0).abs() < 1e-12);
    }

    #[test]
    fn pow_negative_two_lowers_to_div() {
        // x^-2 → 1 / (x*x)
        let e = pow(var(0), cnst(-2.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Div(..))), 1);
        assert!((t.eval(&[4.0]) - (1.0 / 16.0)).abs() < 1e-12);
    }

    #[test]
    fn pow_large_const_stays_generic() {
        // x^9 stays as Pow — beyond the cutoff, generic is cheaper.
        let e = pow(var(0), cnst(9.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
    }

    #[test]
    fn pow_non_integer_const_stays_generic() {
        // x^1.5 stays as Pow until half-integer handling is added.
        let e = pow(var(0), cnst(1.5));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
    }

    #[test]
    fn pow_const_through_cse_const() {
        // Exponent wrapped in Cse — peek_const should still see it.
        let two = Rc::new(cnst(2.0));
        let e = Expr::Binary(BinOp::Pow, Box::new(var(0)), Box::new(Expr::Cse(two)));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
    }

    /// Hessian of the lowered cubic matches finite differences.
    #[test]
    fn hessian_pow_three_matches_fd() {
        // f = 5 * x0^3 + x0 * x1 — exercises the lowered cubic + cross term.
        let e = add(mul(cnst(5.0), pow(var(0), cnst(3.0))), mul(var(0), var(1)));
        let t = Tape::build(&e);
        fd_check(&t, &[1.7, 0.8], 2, 1e-5);
    }

    /// Hessian of the lowered negative power matches finite differences.
    #[test]
    fn hessian_pow_negative_matches_fd() {
        // f = 1/x0^2 + x1^2 — exercises lowered x^-2 and x^2.
        let e = add(pow(var(0), cnst(-2.0)), pow(var(1), cnst(2.0)));
        let t = Tape::build(&e);
        fd_check(&t, &[1.3, 2.4], 2, 1e-5);
    }

    /// Hessian of `Pow(_, 0.5)` lowered to `Sqrt` matches finite differences.
    #[test]
    fn hessian_pow_half_matches_fd() {
        // f = sqrt(x0) + x0*x1 (via Pow(_, 0.5) → Sqrt)
        let e = add(pow(var(0), cnst(0.5)), mul(var(0), var(1)));
        let t = Tape::build(&e);
        fd_check(&t, &[2.5, 1.1], 2, 1e-5);
    }

    /// Sparsity propagates through a shared CSE slot and yields exactly the
    /// full lower triangle over {x0, x1}.
    #[test]
    fn hessian_sparsity_through_cse() {
        // body = x0+x1 (CSE). f = body^2 + body.
        // d²/dx² of body^2 couples (0,0), (1,0), (1,1).
        let body = Rc::new(add(var(0), var(1)));
        let e = add(
            pow(Expr::Cse(body.clone()), cnst(2.0)),
            Expr::Cse(body.clone()),
        );
        let t = Tape::build(&e);
        let s = t.hessian_sparsity();
        assert!(s.contains(&(0, 0)));
        assert!(s.contains(&(1, 0)));
        assert!(s.contains(&(1, 1)));
        assert_eq!(s.len(), 3);
    }
}