Skip to main content

pounce_cli/
nl_tape.rs

1//! Flat-tape reverse-mode AD for `.nl` expression trees.
2//!
3//! Replaces the FD-based Hessian path with a port of the tape AD used
4//! in `ripopt::nl::autodiff`. The tape is a `Vec<TapeOp>` where each op
5//! refers to its operands by tape-slot index; forward evaluation runs
6//! through the slice once filling a parallel `Vec<f64>` of values, and
7//! reverse-mode adjoints walk the same buffer backwards.
8//!
9//! Sparse Hessians are computed by forward-over-reverse: for each
10//! variable `j` that the tape depends on, run a forward tangent sweep
11//! seeded with `e_j`, then a second-order reverse sweep that produces
12//! column `j` of the Hessian. The caller supplies a `(row, col) -> nnz
13//! position` map (lower triangle, row >= col), and contributions are
14//! accumulated in place — the outer loop in `eval_h` calls the same
15//! map for the objective and every active constraint, so every
16//! Lagrangian term lands in the right slot.
17//!
18//! Common subexpressions are tape-emitted **once**: when the recursive
19//! builder hits `Expr::Cse(rc)` it keys on the `Rc` pointer identity,
20//! emitting the body the first time and returning the cached
21//! result-slot index on subsequent references. The forward pass then
22//! computes each CSE once and the reverse pass folds adjoints from
23//! every reference into a single slot — exact chain-rule behaviour.
24
25use std::collections::{BTreeSet, HashMap, HashSet};
26use std::rc::Rc;
27
28use super::nl_reader::{BinOp, Expr, UnaryOp};
29
/// One operation in the flattened tape.
///
/// Operand fields are tape-slot indices into the same tape; the only
/// exception is `Var(i)`, whose payload is a problem-variable index
/// read from the input `x` slice during the forward sweep.
///
/// `PartialEq` is derived so tapes can be compared structurally (for
/// example in tests). Note that `Const` compares with `f64` semantics,
/// so a `Const(NaN)` is never equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum TapeOp {
    /// Literal constant.
    Const(f64),
    /// Problem variable `x[i]`.
    Var(usize),
    /// `slot[a] + slot[b]`.
    Add(usize, usize),
    /// `slot[a] - slot[b]`.
    Sub(usize, usize),
    /// `slot[a] * slot[b]`.
    Mul(usize, usize),
    /// `slot[a] / slot[b]`.
    Div(usize, usize),
    /// `slot[a].powf(slot[b])` — base first, exponent second.
    Pow(usize, usize),
    /// `-slot[a]`.
    Neg(usize),
    /// `|slot[a]|`.
    Abs(usize),
    /// Square root of `slot[a]`.
    Sqrt(usize),
    /// `e^slot[a]`.
    Exp(usize),
    /// Natural logarithm of `slot[a]`.
    Log(usize),
    /// Base-10 logarithm of `slot[a]`.
    Log10(usize),
    /// Sine of `slot[a]`.
    Sin(usize),
    /// Cosine of `slot[a]`.
    Cos(usize),
}
51
52/// A flattened expression tape. The result of evaluation is the value
53/// at slot `ops.len() - 1` (i.e. the last op).
54#[derive(Debug, Clone)]
55pub struct Tape {
56    pub ops: Vec<TapeOp>,
57}
58
59impl Tape {
60    /// Build a tape from an `Expr` tree. CSE bodies (`Expr::Cse(rc)`)
61    /// are cached by `Rc` pointer identity so each body is emitted
62    /// once even when referenced many times.
63    pub fn build(expr: &Expr) -> Self {
64        let mut ops = Vec::new();
65        let mut cache: HashMap<*const Expr, usize> = HashMap::new();
66        build_recursive(expr, &mut ops, &mut cache);
67        Tape { ops }
68    }
69
70    /// Forward sweep: returns `vals[i] = value of tape slot i`. The
71    /// scalar tape result is `vals[ops.len() - 1]`.
72    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
73        let mut vals: Vec<f64> = Vec::with_capacity(self.ops.len());
74        for op in &self.ops {
75            let v = match op {
76                TapeOp::Const(c) => *c,
77                TapeOp::Var(i) => x[*i],
78                TapeOp::Add(a, b) => vals[*a] + vals[*b],
79                TapeOp::Sub(a, b) => vals[*a] - vals[*b],
80                TapeOp::Mul(a, b) => vals[*a] * vals[*b],
81                TapeOp::Div(a, b) => vals[*a] / vals[*b],
82                TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
83                TapeOp::Neg(a) => -vals[*a],
84                TapeOp::Abs(a) => vals[*a].abs(),
85                TapeOp::Sqrt(a) => vals[*a].sqrt(),
86                TapeOp::Exp(a) => vals[*a].exp(),
87                TapeOp::Log(a) => vals[*a].ln(),
88                TapeOp::Log10(a) => vals[*a].log10(),
89                TapeOp::Sin(a) => vals[*a].sin(),
90                TapeOp::Cos(a) => vals[*a].cos(),
91            };
92            vals.push(v);
93        }
94        vals
95    }
96
97    pub fn eval(&self, x: &[f64]) -> f64 {
98        let vals = self.forward(x);
99        *vals.last().unwrap_or(&0.0)
100    }
101
102    /// Reverse-mode AD: accumulate `seed * df/dx_i` into `grad[i]` for
103    /// every problem variable `i` referenced by the tape. `grad` is
104    /// **not** zeroed by this routine — the caller can chain multiple
105    /// gradient accumulations into the same buffer.
106    pub fn gradient_seed(&self, x: &[f64], seed: f64, grad: &mut [f64]) {
107        if seed == 0.0 || self.ops.is_empty() {
108            return;
109        }
110        let vals = self.forward(x);
111        self.reverse(&vals, seed, grad);
112    }
113
114    fn reverse(&self, vals: &[f64], seed: f64, grad: &mut [f64]) {
115        let n = self.ops.len();
116        let mut adj = vec![0.0f64; n];
117        adj[n - 1] = seed;
118
119        for i in (0..n).rev() {
120            let a = adj[i];
121            if a == 0.0 {
122                continue;
123            }
124            match &self.ops[i] {
125                TapeOp::Const(_) => {}
126                TapeOp::Var(j) => {
127                    grad[*j] += a;
128                }
129                TapeOp::Add(l, r) => {
130                    adj[*l] += a;
131                    adj[*r] += a;
132                }
133                TapeOp::Sub(l, r) => {
134                    adj[*l] += a;
135                    adj[*r] -= a;
136                }
137                TapeOp::Mul(l, r) => {
138                    adj[*l] += a * vals[*r];
139                    adj[*r] += a * vals[*l];
140                }
141                TapeOp::Div(l, r) => {
142                    let rv = vals[*r];
143                    adj[*l] += a / rv;
144                    adj[*r] -= a * vals[*l] / (rv * rv);
145                }
146                TapeOp::Pow(l, r) => {
147                    let lv = vals[*l];
148                    let rv = vals[*r];
149                    if rv != 0.0 {
150                        adj[*l] += a * rv * lv.powf(rv - 1.0);
151                    }
152                    if lv > 0.0 {
153                        adj[*r] += a * vals[i] * lv.ln();
154                    }
155                }
156                TapeOp::Neg(j) => {
157                    adj[*j] -= a;
158                }
159                TapeOp::Abs(j) => {
160                    if vals[*j] >= 0.0 {
161                        adj[*j] += a;
162                    } else {
163                        adj[*j] -= a;
164                    }
165                }
166                TapeOp::Sqrt(j) => {
167                    let sv = vals[i];
168                    if sv > 0.0 {
169                        adj[*j] += a * 0.5 / sv;
170                    }
171                }
172                TapeOp::Exp(j) => {
173                    adj[*j] += a * vals[i];
174                }
175                TapeOp::Log(j) => {
176                    adj[*j] += a / vals[*j];
177                }
178                TapeOp::Log10(j) => {
179                    adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
180                }
181                TapeOp::Sin(j) => {
182                    adj[*j] += a * vals[*j].cos();
183                }
184                TapeOp::Cos(j) => {
185                    adj[*j] -= a * vals[*j].sin();
186                }
187            }
188        }
189    }
190
191    /// Sorted distinct problem-variable indices that the tape depends on.
192    pub fn variables(&self) -> Vec<usize> {
193        let mut s: BTreeSet<usize> = BTreeSet::new();
194        for op in &self.ops {
195            if let TapeOp::Var(j) = op {
196                s.insert(*j);
197            }
198        }
199        s.into_iter().collect()
200    }
201
202    /// Forward tangent sweep: `dot[i] = d(slot_i) / dx_{seed_var}`.
203    /// Caller-supplied `dot` buffer is overwritten in full; no zeroing
204    /// needed beforehand because every slot is written before it is
205    /// read (the loop walks forward and only reads earlier slots).
206    fn forward_tangent(&self, vals: &[f64], seed_var: usize, dot: &mut [f64]) {
207        let n = self.ops.len();
208        debug_assert_eq!(dot.len(), n);
209        for i in 0..n {
210            dot[i] = match &self.ops[i] {
211                TapeOp::Const(_) => 0.0,
212                TapeOp::Var(k) => {
213                    if *k == seed_var {
214                        1.0
215                    } else {
216                        0.0
217                    }
218                }
219                TapeOp::Add(a, b) => dot[*a] + dot[*b],
220                TapeOp::Sub(a, b) => dot[*a] - dot[*b],
221                TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
222                TapeOp::Div(a, b) => {
223                    let vb = vals[*b];
224                    (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
225                }
226                TapeOp::Pow(a, b) => {
227                    let u = vals[*a];
228                    let r = vals[*b];
229                    let du = dot[*a];
230                    let dr = dot[*b];
231                    let mut result = 0.0;
232                    if r != 0.0 && u != 0.0 {
233                        result += r * u.powf(r - 1.0) * du;
234                    }
235                    if u > 0.0 {
236                        result += vals[i] * u.ln() * dr;
237                    }
238                    result
239                }
240                TapeOp::Neg(a) => -dot[*a],
241                TapeOp::Abs(a) => {
242                    if vals[*a] >= 0.0 {
243                        dot[*a]
244                    } else {
245                        -dot[*a]
246                    }
247                }
248                TapeOp::Sqrt(a) => {
249                    let sv = vals[i];
250                    if sv > 0.0 {
251                        dot[*a] * 0.5 / sv
252                    } else {
253                        0.0
254                    }
255                }
256                TapeOp::Exp(a) => dot[*a] * vals[i],
257                TapeOp::Log(a) => dot[*a] / vals[*a],
258                TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
259                TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
260                TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
261            };
262        }
263    }
264
265    /// Forward sweep into a caller-supplied buffer. Avoids the
266    /// per-call allocation of `forward()` so hot paths can reuse
267    /// one scratch arena across many tapes.
268    pub fn forward_into(&self, x: &[f64], vals: &mut [f64]) {
269        let n = self.ops.len();
270        debug_assert!(vals.len() >= n);
271        for i in 0..n {
272            vals[i] = match &self.ops[i] {
273                TapeOp::Const(c) => *c,
274                TapeOp::Var(j) => x[*j],
275                TapeOp::Add(a, b) => vals[*a] + vals[*b],
276                TapeOp::Sub(a, b) => vals[*a] - vals[*b],
277                TapeOp::Mul(a, b) => vals[*a] * vals[*b],
278                TapeOp::Div(a, b) => vals[*a] / vals[*b],
279                TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
280                TapeOp::Neg(a) => -vals[*a],
281                TapeOp::Abs(a) => vals[*a].abs(),
282                TapeOp::Sqrt(a) => vals[*a].sqrt(),
283                TapeOp::Exp(a) => vals[*a].exp(),
284                TapeOp::Log(a) => vals[*a].ln(),
285                TapeOp::Log10(a) => vals[*a].log10(),
286                TapeOp::Sin(a) => vals[*a].sin(),
287                TapeOp::Cos(a) => vals[*a].cos(),
288            };
289        }
290    }
291
292    /// Directional Hessian-vector product: emits
293    /// `weight * (∇²f · seed)[k]` into `out[k]` for every problem
294    /// variable `k` the tape references. Caller supplies the
295    /// forward-pass result `vals` (use [`forward_into`]) plus three
296    /// scratch buffers (`dot`, `adj`, `adj_dot`), each at least
297    /// `self.ops.len()` long. `out` must be at least one past the
298    /// largest variable index in the tape; the routine reads
299    /// `seed[k]` for each `Var(k)` and writes `out[k] += weight *
300    /// (Hess · seed)[k]`.
301    ///
302    /// This is one forward-over-reverse AD pass — O(n_ops) work —
303    /// regardless of how many variables the tape depends on, which
304    /// is what makes Hessian coloring efficient: a single
305    /// directional pass recovers a whole color group of columns.
306    ///
307    /// [`forward_into`]: Tape::forward_into
308    pub fn hessian_directional(
309        &self,
310        vals: &[f64],
311        seed: &[f64],
312        weight: f64,
313        out: &mut [f64],
314        dot: &mut [f64],
315        adj: &mut [f64],
316        adj_dot: &mut [f64],
317    ) {
318        let n = self.ops.len();
319        if n == 0 || weight == 0.0 {
320            return;
321        }
322        debug_assert!(vals.len() >= n);
323        debug_assert!(dot.len() >= n);
324        debug_assert!(adj.len() >= n);
325        debug_assert!(adj_dot.len() >= n);
326
327        // Forward tangent: dot[i] = (∂vals[i] / ∂x · seed). At
328        // Var(k) the seed entry feeds in; the rest of the chain
329        // rule matches `forward_tangent` exactly.
330        for i in 0..n {
331            dot[i] = match &self.ops[i] {
332                TapeOp::Const(_) => 0.0,
333                TapeOp::Var(k) => seed[*k],
334                TapeOp::Add(a, b) => dot[*a] + dot[*b],
335                TapeOp::Sub(a, b) => dot[*a] - dot[*b],
336                TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
337                TapeOp::Div(a, b) => {
338                    let vb = vals[*b];
339                    (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
340                }
341                TapeOp::Pow(a, b) => {
342                    let u = vals[*a];
343                    let r = vals[*b];
344                    let du = dot[*a];
345                    let dr = dot[*b];
346                    let mut result = 0.0;
347                    if r != 0.0 && u != 0.0 {
348                        result += r * u.powf(r - 1.0) * du;
349                    }
350                    if u > 0.0 {
351                        result += vals[i] * u.ln() * dr;
352                    }
353                    result
354                }
355                TapeOp::Neg(a) => -dot[*a],
356                TapeOp::Abs(a) => {
357                    if vals[*a] >= 0.0 {
358                        dot[*a]
359                    } else {
360                        -dot[*a]
361                    }
362                }
363                TapeOp::Sqrt(a) => {
364                    let sv = vals[i];
365                    if sv > 0.0 {
366                        dot[*a] * 0.5 / sv
367                    } else {
368                        0.0
369                    }
370                }
371                TapeOp::Exp(a) => vals[i] * dot[*a],
372                TapeOp::Log(a) => dot[*a] / vals[*a],
373                TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
374                TapeOp::Sin(a) => vals[*a].cos() * dot[*a],
375                TapeOp::Cos(a) => -vals[*a].sin() * dot[*a],
376            };
377        }
378
379        // Reverse over tangent. adj[i] = ∂f/∂vals[i],
380        // adj_dot[i] = derivative of adj[i] along `seed`
381        // direction = (Hess · seed) projected onto slot i.
382        for slot in adj.iter_mut().take(n) {
383            *slot = 0.0;
384        }
385        for slot in adj_dot.iter_mut().take(n) {
386            *slot = 0.0;
387        }
388        adj[n - 1] = 1.0;
389
390        for i in (0..n).rev() {
391            let w = adj[i];
392            let wd = adj_dot[i];
393            if w == 0.0 && wd == 0.0 {
394                continue;
395            }
396            match &self.ops[i] {
397                TapeOp::Const(_) => {}
398                TapeOp::Var(k) => {
399                    if wd != 0.0 {
400                        out[*k] += weight * wd;
401                    }
402                }
403                TapeOp::Add(a, b) => {
404                    adj[*a] += w;
405                    adj[*b] += w;
406                    adj_dot[*a] += wd;
407                    adj_dot[*b] += wd;
408                }
409                TapeOp::Sub(a, b) => {
410                    adj[*a] += w;
411                    adj[*b] -= w;
412                    adj_dot[*a] += wd;
413                    adj_dot[*b] -= wd;
414                }
415                TapeOp::Mul(a, b) => {
416                    adj[*a] += w * vals[*b];
417                    adj[*b] += w * vals[*a];
418                    adj_dot[*a] += wd * vals[*b] + w * dot[*b];
419                    adj_dot[*b] += wd * vals[*a] + w * dot[*a];
420                }
421                TapeOp::Div(a, b) => {
422                    let vb = vals[*b];
423                    let vb2 = vb * vb;
424                    let vb3 = vb2 * vb;
425                    adj[*a] += w / vb;
426                    adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
427                    adj[*b] += w * (-vals[*a] / vb2);
428                    adj_dot[*b] += wd * (-vals[*a] / vb2)
429                        + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
430                }
431                TapeOp::Pow(a, b) => {
432                    let u = vals[*a];
433                    let r = vals[*b];
434                    let du = dot[*a];
435                    let dr = dot[*b];
436                    if r != 0.0 {
437                        if u != 0.0 {
438                            let p_a = r * u.powf(r - 1.0);
439                            adj[*a] += w * p_a;
440                            let mut dp_a = dr * u.powf(r - 1.0);
441                            if u > 0.0 {
442                                dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
443                            } else {
444                                dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
445                            }
446                            adj_dot[*a] += wd * p_a + w * dp_a;
447                        } else if r >= 2.0 {
448                            let p_a = 0.0;
449                            adj[*a] += w * p_a;
450                            let dp_a = if r == 2.0 {
451                                2.0 * du
452                            } else {
453                                r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
454                            };
455                            adj_dot[*a] += wd * p_a + w * dp_a;
456                        }
457                    }
458                    if u > 0.0 {
459                        let ln_u = u.ln();
460                        let p_b = vals[i] * ln_u;
461                        adj[*b] += w * p_b;
462                        let dur = vals[i] * (r * du / u + dr * ln_u);
463                        let dp_b = dur * ln_u + vals[i] * du / u;
464                        adj_dot[*b] += wd * p_b + w * dp_b;
465                    }
466                }
467                TapeOp::Neg(a) => {
468                    adj[*a] -= w;
469                    adj_dot[*a] -= wd;
470                }
471                TapeOp::Abs(a) => {
472                    let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
473                    adj[*a] += w * s;
474                    adj_dot[*a] += wd * s;
475                }
476                TapeOp::Sqrt(a) => {
477                    let sv = vals[i];
478                    if sv > 0.0 {
479                        let fp = 0.5 / sv;
480                        let fpp = -0.25 / (vals[*a] * sv);
481                        adj[*a] += w * fp;
482                        adj_dot[*a] += wd * fp + w * fpp * dot[*a];
483                    }
484                }
485                TapeOp::Exp(a) => {
486                    let ev = vals[i];
487                    adj[*a] += w * ev;
488                    adj_dot[*a] += wd * ev + w * ev * dot[*a];
489                }
490                TapeOp::Log(a) => {
491                    let u = vals[*a];
492                    adj[*a] += w / u;
493                    adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
494                }
495                TapeOp::Log10(a) => {
496                    let u = vals[*a];
497                    let c = std::f64::consts::LN_10;
498                    adj[*a] += w / (u * c);
499                    adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
500                }
501                TapeOp::Sin(a) => {
502                    let u = vals[*a];
503                    let cu = u.cos();
504                    adj[*a] += w * cu;
505                    adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
506                }
507                TapeOp::Cos(a) => {
508                    let u = vals[*a];
509                    let su = u.sin();
510                    adj[*a] -= w * su;
511                    adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
512                }
513            }
514        }
515    }
516
517    /// Forward-over-reverse Hessian: for each variable `j` the tape
518    /// depends on, accumulate `weight * (d²f / dx_i dx_j)` into
519    /// `values[hess_map[(i, j)]]` for every `(i, j)` lower-triangle
520    /// pair in the map. The same routine is used for the objective
521    /// (with `weight = obj_factor`) and each active constraint (with
522    /// `weight = lambda[k]`); contributions sum into the shared map.
523    pub fn hessian_accumulate(
524        &self,
525        x: &[f64],
526        weight: f64,
527        hess_map: &HashMap<(usize, usize), usize>,
528        values: &mut [f64],
529    ) {
530        let n = self.ops.len();
531        if n == 0 || weight == 0.0 {
532            return;
533        }
534        let v = self.forward(x);
535        let var_indices = self.variables();
536
537        // Hoist scratch allocations out of the per-variable loop —
538        // each was costing O(n) per j on every hessian_accumulate
539        // call, which dominated runtime on large tapes (the dense-
540        // Hessian Mittelmann problems). `forward_tangent` fully
541        // overwrites `dot`, so no reset is needed there. `adj` and
542        // `adj_dot` are mutated additively, so we zero them per j.
543        let mut dot = vec![0.0f64; n];
544        let mut adj = vec![0.0f64; n];
545        let mut adj_dot = vec![0.0f64; n];
546        for &j in &var_indices {
547            self.forward_tangent(&v, j, &mut dot);
548
549            // adj[i] = standard adjoint (∂f/∂slot_i)
550            // adj_dot[i] = derivative of adj[i] w.r.t. x_j = ∂²f/(∂slot_i ∂x_j)
551            adj.fill(0.0);
552            adj_dot.fill(0.0);
553            adj[n - 1] = 1.0;
554
555            for i in (0..n).rev() {
556                let w = adj[i];
557                let wd = adj_dot[i];
558                if w == 0.0 && wd == 0.0 {
559                    continue;
560                }
561                match &self.ops[i] {
562                    TapeOp::Const(_) => {}
563                    TapeOp::Var(k) => {
564                        // Lower-triangle: only emit when row k >= col j
565                        // so an off-diagonal pair appears once.
566                        if wd != 0.0 && *k >= j {
567                            if let Some(&pos) = hess_map.get(&(*k, j)) {
568                                values[pos] += weight * wd;
569                            }
570                        }
571                    }
572                    TapeOp::Add(a, b) => {
573                        adj[*a] += w;
574                        adj[*b] += w;
575                        adj_dot[*a] += wd;
576                        adj_dot[*b] += wd;
577                    }
578                    TapeOp::Sub(a, b) => {
579                        adj[*a] += w;
580                        adj[*b] -= w;
581                        adj_dot[*a] += wd;
582                        adj_dot[*b] -= wd;
583                    }
584                    TapeOp::Mul(a, b) => {
585                        adj[*a] += w * v[*b];
586                        adj[*b] += w * v[*a];
587                        adj_dot[*a] += wd * v[*b] + w * dot[*b];
588                        adj_dot[*b] += wd * v[*a] + w * dot[*a];
589                    }
590                    TapeOp::Div(a, b) => {
591                        let vb = v[*b];
592                        let vb2 = vb * vb;
593                        let vb3 = vb2 * vb;
594                        adj[*a] += w / vb;
595                        adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
596                        adj[*b] += w * (-v[*a] / vb2);
597                        adj_dot[*b] += wd * (-v[*a] / vb2)
598                            + w * (-dot[*a] / vb2 + 2.0 * v[*a] * dot[*b] / vb3);
599                    }
600                    TapeOp::Pow(a, b) => {
601                        let u = v[*a];
602                        let r = v[*b];
603                        let du = dot[*a];
604                        let dr = dot[*b];
605                        if r != 0.0 {
606                            if u != 0.0 {
607                                let p_a = r * u.powf(r - 1.0);
608                                adj[*a] += w * p_a;
609                                let mut dp_a = dr * u.powf(r - 1.0);
610                                if u > 0.0 {
611                                    dp_a +=
612                                        r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
613                                } else {
614                                    dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
615                                }
616                                adj_dot[*a] += wd * p_a + w * dp_a;
617                            } else if r >= 2.0 {
618                                let p_a = 0.0;
619                                adj[*a] += w * p_a;
620                                let dp_a = if r == 2.0 {
621                                    2.0 * du
622                                } else {
623                                    r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
624                                };
625                                adj_dot[*a] += wd * p_a + w * dp_a;
626                            }
627                        }
628                        if u > 0.0 {
629                            let ln_u = u.ln();
630                            let p_b = v[i] * ln_u;
631                            adj[*b] += w * p_b;
632                            let dur = v[i] * (r * du / u + dr * ln_u);
633                            let dp_b = dur * ln_u + v[i] * du / u;
634                            adj_dot[*b] += wd * p_b + w * dp_b;
635                        }
636                    }
637                    TapeOp::Neg(a) => {
638                        adj[*a] -= w;
639                        adj_dot[*a] -= wd;
640                    }
641                    TapeOp::Abs(a) => {
642                        let s = if v[*a] >= 0.0 { 1.0 } else { -1.0 };
643                        adj[*a] += w * s;
644                        adj_dot[*a] += wd * s;
645                    }
646                    TapeOp::Sqrt(a) => {
647                        let sv = v[i];
648                        if sv > 0.0 {
649                            let fp = 0.5 / sv;
650                            let fpp = -0.25 / (v[*a] * sv);
651                            adj[*a] += w * fp;
652                            adj_dot[*a] += wd * fp + w * fpp * dot[*a];
653                        }
654                    }
655                    TapeOp::Exp(a) => {
656                        let ev = v[i];
657                        adj[*a] += w * ev;
658                        adj_dot[*a] += wd * ev + w * ev * dot[*a];
659                    }
660                    TapeOp::Log(a) => {
661                        let u = v[*a];
662                        adj[*a] += w / u;
663                        adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
664                    }
665                    TapeOp::Log10(a) => {
666                        let u = v[*a];
667                        let c = std::f64::consts::LN_10;
668                        adj[*a] += w / (u * c);
669                        adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
670                    }
671                    TapeOp::Sin(a) => {
672                        let u = v[*a];
673                        let cu = u.cos();
674                        adj[*a] += w * cu;
675                        adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
676                    }
677                    TapeOp::Cos(a) => {
678                        let u = v[*a];
679                        let su = u.sin();
680                        adj[*a] -= w * su;
681                        adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
682                    }
683                }
684            }
685        }
686    }
687
688    /// Structural Hessian sparsity (lower triangle, row >= col).
689    /// Propagates per-slot variable-dependence sets forward; each
690    /// nonlinear op emits the cross/self products of its operand sets.
691    /// Linear ops contribute no second-derivative pairs.
692    pub fn hessian_sparsity(&self) -> BTreeSet<(usize, usize)> {
693        let n = self.ops.len();
694        let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
695        let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
696
697        let emit_cross =
698            |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
699                for &v1 in s1 {
700                    for &v2 in s2 {
701                        let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
702                        pairs.insert((r, c));
703                    }
704                }
705            };
706        let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
707            let vars: Vec<usize> = s.iter().copied().collect();
708            for (ai, &vi) in vars.iter().enumerate() {
709                for &vj in &vars[..=ai] {
710                    let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
711                    pairs.insert((r, c));
712                }
713            }
714        };
715
716        for op in &self.ops {
717            let vset = match op {
718                TapeOp::Const(_) => BTreeSet::new(),
719                TapeOp::Var(j) => {
720                    let mut s = BTreeSet::new();
721                    s.insert(*j);
722                    s
723                }
724                TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
725                    var_sets[*a].union(&var_sets[*b]).copied().collect()
726                }
727                TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
728                TapeOp::Mul(a, b) => {
729                    emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
730                    var_sets[*a].union(&var_sets[*b]).copied().collect()
731                }
732                TapeOp::Div(a, b) => {
733                    emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
734                    emit_self(&var_sets[*b], &mut pairs);
735                    var_sets[*a].union(&var_sets[*b]).copied().collect()
736                }
737                TapeOp::Pow(a, b) => {
738                    let combined: BTreeSet<usize> =
739                        var_sets[*a].union(&var_sets[*b]).copied().collect();
740                    emit_self(&combined, &mut pairs);
741                    combined
742                }
743                TapeOp::Sqrt(a)
744                | TapeOp::Exp(a)
745                | TapeOp::Log(a)
746                | TapeOp::Log10(a)
747                | TapeOp::Sin(a)
748                | TapeOp::Cos(a) => {
749                    emit_self(&var_sets[*a], &mut pairs);
750                    var_sets[*a].clone()
751                }
752            };
753            var_sets.push(vset);
754        }
755        pairs
756    }
757}
758
759fn build_recursive(
760    expr: &Expr,
761    ops: &mut Vec<TapeOp>,
762    cache: &mut HashMap<*const Expr, usize>,
763) -> usize {
764    match expr {
765        Expr::Const(c) => {
766            let idx = ops.len();
767            ops.push(TapeOp::Const(*c));
768            idx
769        }
770        Expr::Var(i) => {
771            let idx = ops.len();
772            ops.push(TapeOp::Var(*i));
773            idx
774        }
775        Expr::Binary(op, a, b) => {
776            // Pow(x, const) is the dominant libm/dispatch cost in
777            // transcendental-heavy AMPL tapes (henon, lane_emden, …):
778            // `powf` itself is ~30–50 cycles AND the reverse-mode arm
779            // for `Pow` carries an extra `ln(x)` branch. Rewriting
780            // small integer / half-integer exponents into mul/sqrt
781            // chains drops these calls entirely and reroutes the AD
782            // through the much cheaper `Mul`/`Sqrt` arms.
783            if let BinOp::Pow = op {
784                if let Some(c) = peek_const(b) {
785                    if let Some(idx) = try_emit_const_pow(a, c, ops, cache) {
786                        return idx;
787                    }
788                }
789            }
790            let l = build_recursive(a, ops, cache);
791            let r = build_recursive(b, ops, cache);
792            let idx = ops.len();
793            ops.push(match op {
794                BinOp::Add => TapeOp::Add(l, r),
795                BinOp::Sub => TapeOp::Sub(l, r),
796                BinOp::Mul => TapeOp::Mul(l, r),
797                BinOp::Div => TapeOp::Div(l, r),
798                BinOp::Pow => TapeOp::Pow(l, r),
799            });
800            idx
801        }
802        Expr::Unary(op, a) => {
803            let v = build_recursive(a, ops, cache);
804            let idx = ops.len();
805            ops.push(match op {
806                UnaryOp::Neg => TapeOp::Neg(v),
807                UnaryOp::Sqrt => TapeOp::Sqrt(v),
808                UnaryOp::Log => TapeOp::Log(v),
809                UnaryOp::Log10 => TapeOp::Log10(v),
810                UnaryOp::Exp => TapeOp::Exp(v),
811                UnaryOp::Abs => TapeOp::Abs(v),
812                UnaryOp::Sin => TapeOp::Sin(v),
813                UnaryOp::Cos => TapeOp::Cos(v),
814            });
815            idx
816        }
817        Expr::Sum(args) => {
818            if args.is_empty() {
819                let idx = ops.len();
820                ops.push(TapeOp::Const(0.0));
821                return idx;
822            }
823            let mut acc = build_recursive(&args[0], ops, cache);
824            for a in &args[1..] {
825                let next = build_recursive(a, ops, cache);
826                let idx = ops.len();
827                ops.push(TapeOp::Add(acc, next));
828                acc = idx;
829            }
830            acc
831        }
832        Expr::Cse(body) => {
833            // Cache by Rc identity so each shared body is emitted into
834            // the tape exactly once and every reference resolves to the
835            // same result-slot index. Forward computes the body once;
836            // reverse-mode adjoint sums contributions from every ref
837            // into that shared slot — exact chain rule for shared
838            // sub-expressions.
839            let key = Rc::as_ptr(body) as *const Expr;
840            if let Some(&idx) = cache.get(&key) {
841                idx
842            } else {
843                let idx = build_recursive(body, ops, cache);
844                cache.insert(key, idx);
845                idx
846            }
847        }
848    }
849}
850
851/// Resolve `e` to a literal constant if it is one (transparently
852/// peering through `Cse` wrappers, which AMPL emits around shared
853/// constants in CSE-heavy problems).
854fn peek_const(e: &Expr) -> Option<f64> {
855    match e {
856        Expr::Const(c) => Some(*c),
857        Expr::Cse(body) => peek_const(body),
858        _ => None,
859    }
860}
861
862/// Try to rewrite `base ^ exponent_const` into cheaper ops. Returns
863/// the result tape-slot on success; `None` means "fall through to
864/// generic Pow." Handles the cases that account for the bulk of
865/// AMPL-emitted Pow nodes: integer exponents up to ±8 and the
866/// `Sqrt`/passthrough/one specials. Half-integer exponents (e.g.
867/// `^1.5`) and larger integers are left to generic `Pow` since the
868/// resulting mul chain grows the tape faster than it saves work.
869fn try_emit_const_pow(
870    base_expr: &Expr,
871    c: f64,
872    ops: &mut Vec<TapeOp>,
873    cache: &mut HashMap<*const Expr, usize>,
874) -> Option<usize> {
875    if c == 0.0 {
876        let idx = ops.len();
877        ops.push(TapeOp::Const(1.0));
878        return Some(idx);
879    }
880    if c == 1.0 {
881        return Some(build_recursive(base_expr, ops, cache));
882    }
883    if c == 0.5 {
884        let b = build_recursive(base_expr, ops, cache);
885        let idx = ops.len();
886        ops.push(TapeOp::Sqrt(b));
887        return Some(idx);
888    }
889    // Integer exponents: bounded so a bad tape can't blow up the
890    // op count. 8 covers everything AMPL typically emits for
891    // polynomial models; beyond that the binary-expansion mul
892    // chain (≥4 ops) starts to lose to a single `powf`.
893    if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
894        let n = c.abs() as u32;
895        if n == 0 {
896            // Already handled above, but guard.
897            let idx = ops.len();
898            ops.push(TapeOp::Const(1.0));
899            return Some(idx);
900        }
901        let b = build_recursive(base_expr, ops, cache);
902        let pos = emit_int_pow(b, n, ops);
903        if c < 0.0 {
904            // x^-n = 1 / x^n. Saves the powf and its ln branch in
905            // reverse mode; cost is one Div in their place.
906            let one_idx = ops.len();
907            ops.push(TapeOp::Const(1.0));
908            let idx = ops.len();
909            ops.push(TapeOp::Div(one_idx, pos));
910            return Some(idx);
911        }
912        return Some(pos);
913    }
914    None
915}
916
917/// Emit `base^n` for `n >= 1` as a binary-expansion mul chain.
918/// Worst-case op count is `2·floor(log2(n))` — i.e. 1 op for n=2, 2
919/// for n=3/4, 3 for n=5..8.
920fn emit_int_pow(base: usize, n: u32, ops: &mut Vec<TapeOp>) -> usize {
921    debug_assert!(n >= 1);
922    if n == 1 {
923        return base;
924    }
925    let half = emit_int_pow(base, n / 2, ops);
926    let squared = ops.len();
927    ops.push(TapeOp::Mul(half, half));
928    if n % 2 == 1 {
929        let idx = ops.len();
930        ops.push(TapeOp::Mul(squared, base));
931        idx
932    } else {
933        squared
934    }
935}
936
937// ============================================================
938// HybridTape: per-summand local tapes + shared CSE prelude.
939//
940// Partial separability — the .nl Sum/Add structure — gets each
941// summand its own local Vec<SummandOp>. CSE bodies (V-segments
942// in .nl) that appear in two or more summands are promoted into
943// a single shared `prelude: Vec<TapeOp>`; per-summand references
944// to a promoted CSE are SummandOp::Shared(prelude_slot).
945//
946// This is strictly better than either extreme:
947//   - per-summand Tape (no cross-summand sharing): re-inlines
948//     every shared CSE, blows up tape size when many constraints
949//     share a stencil derivative (Mittelmann *120 problems).
950//   - GlobalTape (single shared Vec<TapeOp> for everything):
951//     per-root reverse sweeps scatter across a many-MB buffer,
952//     thrashing cache when no CSE is actually shared (lane_emden
953//     120: each constraint owns its own ops → 50% regression
954//     vs per-summand tapes).
955//
956// Forward: prelude once, then each summand's local pass.
957// Reverse / forward-over-reverse: per-summand sweep over local
958// reach (which propagates adjoints into prelude_adj at Shared
959// boundaries), then a small reverse pass over the summand's
960// prelude_reach to fold those into grad / Hessian.
961// ============================================================
962
/// One slot in a per-summand local tape: either an ordinary op whose
/// operands live in the same summand, or a reference to a value computed
/// in the shared prelude.
#[derive(Debug, Clone)]
pub enum SummandOp {
    /// Local op — operand indices reference other slots in the
    /// same per-summand vector.
    Local(TapeOp),
    /// Value of prelude slot `usize`. The forward sweep copies
    /// `prelude_vals[k]` into this slot; the reverse sweeps add this
    /// slot's adjoint into the prelude adjoint buffer at `k`.
    Shared(usize),
}
974
/// One root expression lowered to its own local tape, together with the
/// reachability and variable sets filled in by `HybridTape::build_multi`.
#[derive(Debug, Clone)]
pub struct Summand {
    /// The summand's local tape; `SummandOp::Shared` entries point into
    /// the owning `HybridTape`'s prelude.
    pub ops: Vec<SummandOp>,
    /// Local slot holding the summand's final value.
    pub root_slot: usize,
    /// Local slots reachable from `root_slot`, ascending (topo).
    pub local_reach: Vec<usize>,
    /// Prelude slots reachable from the summand's Shared refs,
    /// ascending (topo in prelude's operand DAG).
    pub prelude_reach: Vec<usize>,
    /// Variables touched by Var ops inside `local_reach`.
    pub local_vars: Vec<usize>,
    /// Variables touched by Var ops inside `prelude_reach`.
    pub prelude_vars: Vec<usize>,
    /// `local_vars ∪ prelude_vars`, sorted. Hessian j-loop set.
    pub all_vars: Vec<usize>,
}
992
/// A shared prelude tape plus one local tape per root expression.
/// Built by `HybridTape::build_multi`.
#[derive(Debug)]
pub struct HybridTape {
    /// Shared CSE bodies. Slot indices in `SummandOp::Shared`
    /// point here; this Vec is built bottom-up by `build_recursive`,
    /// so operand indices are always less than the consumer's
    /// index (topo in ascending order).
    pub prelude: Vec<TapeOp>,
    /// One entry per root expression passed to `build_multi`, in the
    /// same order.
    pub summands: Vec<Summand>,
}
1002
1003impl HybridTape {
1004    /// Build hybrid tape from a list of root expressions. CSE
1005    /// bodies referenced from ≥ 2 roots are promoted into the
1006    /// shared prelude; CSEs touched by only one root are inlined
1007    /// into that summand's local ops.
1008    pub fn build_multi(exprs: &[Expr]) -> Self {
1009        // Pass 1: per-Cse-pointer count of how many roots reference
1010        // it (each root contributes at most 1 to the count). The
1011        // ≥2 threshold means a CSE is shared across summands.
1012        let mut cse_count: HashMap<*const Expr, usize> = HashMap::new();
1013        for e in exprs {
1014            let mut seen_in_root: HashSet<*const Expr> = HashSet::new();
1015            count_cse_appearances(e, &mut seen_in_root, &mut cse_count);
1016        }
1017
1018        // Pass 2: build prelude + each summand. The summand builder
1019        // hits the prelude path lazily — only when it encounters a
1020        // promoted Cse — so the prelude grows only with bodies that
1021        // are actually referenced from multiple summands.
1022        let mut prelude: Vec<TapeOp> = Vec::new();
1023        let mut prelude_map: HashMap<*const Expr, usize> = HashMap::new();
1024        let mut summands: Vec<Summand> = Vec::with_capacity(exprs.len());
1025        for e in exprs {
1026            let mut local: Vec<SummandOp> = Vec::new();
1027            let mut local_cache: HashMap<*const Expr, usize> = HashMap::new();
1028            let root_slot = build_into_summand(
1029                e,
1030                &mut local,
1031                &mut local_cache,
1032                &mut prelude,
1033                &mut prelude_map,
1034                &cse_count,
1035            );
1036            summands.push(Summand {
1037                ops: local,
1038                root_slot,
1039                local_reach: Vec::new(),
1040                prelude_reach: Vec::new(),
1041                local_vars: Vec::new(),
1042                prelude_vars: Vec::new(),
1043                all_vars: Vec::new(),
1044            });
1045        }
1046
1047        // Pass 3: per-summand reach / vars. Prelude reach uses an
1048        // epoch-tagged shared visited buffer so total cost stays
1049        // O(Σ |prelude_reach_i|) rather than O(n_summands × |prelude|).
1050        let mut p_visited: Vec<u32> = vec![0; prelude.len()];
1051        let mut p_epoch: u32 = 0;
1052        let mut p_stack: Vec<usize> = Vec::new();
1053        for s in &mut summands {
1054            let (local_reach, shared_refs) = compute_local_reach(&s.ops, s.root_slot);
1055            s.local_reach = local_reach;
1056
1057            let mut lv: BTreeSet<usize> = BTreeSet::new();
1058            for &i in &s.local_reach {
1059                if let SummandOp::Local(TapeOp::Var(j)) = &s.ops[i] {
1060                    lv.insert(*j);
1061                }
1062            }
1063            s.local_vars = lv.iter().copied().collect();
1064
1065            if !shared_refs.is_empty() {
1066                p_epoch += 1;
1067                let mut preach: Vec<usize> = Vec::new();
1068                for &start in &shared_refs {
1069                    bfs_prelude(
1070                        &prelude,
1071                        start,
1072                        &mut p_visited,
1073                        p_epoch,
1074                        &mut p_stack,
1075                        &mut preach,
1076                    );
1077                }
1078                preach.sort_unstable();
1079                s.prelude_vars = vars_in(&prelude, &preach);
1080                s.prelude_reach = preach;
1081            }
1082
1083            let mut av: BTreeSet<usize> = lv;
1084            for &v in &s.prelude_vars {
1085                av.insert(v);
1086            }
1087            s.all_vars = av.into_iter().collect();
1088        }
1089
1090        HybridTape { prelude, summands }
1091    }
1092
1093    pub fn n_prelude_ops(&self) -> usize {
1094        self.prelude.len()
1095    }
1096    pub fn n_summands(&self) -> usize {
1097        self.summands.len()
1098    }
1099    pub fn max_summand_ops(&self) -> usize {
1100        self.summands.iter().map(|s| s.ops.len()).max().unwrap_or(0)
1101    }
1102    pub fn total_local_ops(&self) -> usize {
1103        self.summands.iter().map(|s| s.ops.len()).sum()
1104    }
1105
1106    /// Forward sweep over the shared prelude. `prelude_vals` must
1107    /// have length `n_prelude_ops`.
1108    pub fn forward_prelude(&self, x: &[f64], prelude_vals: &mut [f64]) {
1109        debug_assert_eq!(prelude_vals.len(), self.prelude.len());
1110        for i in 0..self.prelude.len() {
1111            prelude_vals[i] = fwd_step(&self.prelude[i], x, prelude_vals);
1112        }
1113    }
1114
1115    /// Forward sweep over one summand. `local_vals` must hold at
1116    /// least `s.ops.len()` entries.
1117    pub fn forward_summand(
1118        &self,
1119        s: &Summand,
1120        x: &[f64],
1121        prelude_vals: &[f64],
1122        local_vals: &mut [f64],
1123    ) {
1124        debug_assert!(local_vals.len() >= s.ops.len());
1125        for i in 0..s.ops.len() {
1126            local_vals[i] = match &s.ops[i] {
1127                SummandOp::Local(op) => fwd_step(op, x, local_vals),
1128                SummandOp::Shared(k) => prelude_vals[*k],
1129            };
1130        }
1131    }
1132
1133    /// Value at the summand root after `forward_summand`.
1134    #[inline]
1135    pub fn root_value(&self, s: &Summand, local_vals: &[f64]) -> f64 {
1136        local_vals[s.root_slot]
1137    }
1138
1139    /// Reverse-mode gradient for one summand. Walks `local_reach`
1140    /// in reverse — propagating adjoints into `prelude_adj` at
1141    /// Shared boundaries — and then walks `prelude_reach` in
1142    /// reverse to land contributions in `grad`. Scratch arrays
1143    /// `local_adj` and `prelude_adj` are zeroed only at the slots
1144    /// actually touched.
1145    #[allow(clippy::too_many_arguments)]
1146    pub fn gradient_summand(
1147        &self,
1148        s: &Summand,
1149        prelude_vals: &[f64],
1150        local_vals: &[f64],
1151        seed: f64,
1152        grad: &mut [f64],
1153        local_adj: &mut [f64],
1154        prelude_adj: &mut [f64],
1155    ) {
1156        if seed == 0.0 || s.local_reach.is_empty() {
1157            return;
1158        }
1159        for &i in &s.local_reach {
1160            local_adj[i] = 0.0;
1161        }
1162        for &i in &s.prelude_reach {
1163            prelude_adj[i] = 0.0;
1164        }
1165        local_adj[s.root_slot] = seed;
1166        for &i in s.local_reach.iter().rev() {
1167            let a = local_adj[i];
1168            if a == 0.0 {
1169                continue;
1170            }
1171            match &s.ops[i] {
1172                SummandOp::Local(op) => rev_step(op, i, local_vals, local_adj, a, grad),
1173                SummandOp::Shared(k) => {
1174                    prelude_adj[*k] += a;
1175                }
1176            }
1177        }
1178        for &i in s.prelude_reach.iter().rev() {
1179            let a = prelude_adj[i];
1180            if a == 0.0 {
1181                continue;
1182            }
1183            rev_step(&self.prelude[i], i, prelude_vals, prelude_adj, a, grad);
1184        }
1185    }
1186
    /// Forward-over-reverse Hessian for one summand with multiplier
    /// `weight`. Iterates over `s.all_vars`; for each seed variable
    /// j: (1) forward tangent through prelude_reach then local_reach,
    /// (2) reverse over local (folding adj/adj_dot into prelude at
    /// Shared boundaries), (3) reverse over prelude_reach. All
    /// scratch buffers are zeroed only at the touched slots inside
    /// the per-j loop.
    ///
    /// `prelude_vals` / `local_vals` are the forward values of the prelude
    /// and of this summand. `hess_map`, `values` and `weight` are passed
    /// straight through to `ror_step` — presumably the `(row, col) -> nnz`
    /// map and the Hessian value array it accumulates into; confirm
    /// against `ror_step`. The six `*_dot` / `*_adj` / `*_adj_dot` slices
    /// are scratch space indexed by local or prelude slot.
    ///
    /// Returns immediately when `weight` is zero or the summand has no
    /// reachable slots.
    #[allow(clippy::too_many_arguments)]
    pub fn hessian_summand(
        &self,
        s: &Summand,
        prelude_vals: &[f64],
        local_vals: &[f64],
        weight: f64,
        hess_map: &HashMap<(usize, usize), usize>,
        values: &mut [f64],
        local_dot: &mut [f64],
        local_adj: &mut [f64],
        local_adj_dot: &mut [f64],
        prelude_dot: &mut [f64],
        prelude_adj: &mut [f64],
        prelude_adj_dot: &mut [f64],
    ) {
        if weight == 0.0 || s.local_reach.is_empty() {
            return;
        }
        for &j in &s.all_vars {
            // Reset the scratch entries this summand uses; entries outside
            // its reach are never read below.
            for &i in &s.local_reach {
                local_dot[i] = 0.0;
                local_adj[i] = 0.0;
                local_adj_dot[i] = 0.0;
            }
            for &i in &s.prelude_reach {
                prelude_dot[i] = 0.0;
                prelude_adj[i] = 0.0;
                prelude_adj_dot[i] = 0.0;
            }
            // (1) Forward tangent sweep for seed variable j. Both reach
            // lists are ascending, so operands are visited before their
            // consumers; the prelude goes first because local Shared slots
            // read their tangent from it.
            for &i in &s.prelude_reach {
                prelude_dot[i] = fwd_tan_step(&self.prelude[i], j, prelude_vals, prelude_dot, i);
            }
            for &i in &s.local_reach {
                local_dot[i] = match &s.ops[i] {
                    SummandOp::Local(op) => fwd_tan_step(op, j, local_vals, local_dot, i),
                    SummandOp::Shared(k) => prelude_dot[*k],
                };
            }
            // (2) Reverse sweep over the local tape. The root adjoint is
            // seeded with 1.0; `weight` is handed to `ror_step` separately.
            local_adj[s.root_slot] = 1.0;
            for &i in s.local_reach.iter().rev() {
                let w = local_adj[i];
                let wd = local_adj_dot[i];
                // Nothing to propagate when both the adjoint and its
                // tangent are zero.
                if w == 0.0 && wd == 0.0 {
                    continue;
                }
                match &s.ops[i] {
                    SummandOp::Local(op) => {
                        ror_step(
                            op,
                            i,
                            j,
                            local_vals,
                            local_dot,
                            local_adj,
                            local_adj_dot,
                            w,
                            wd,
                            weight,
                            hess_map,
                            values,
                        );
                    }
                    SummandOp::Shared(k) => {
                        // Hand the adjoint pair over to the prelude slot.
                        prelude_adj[*k] += w;
                        prelude_adj_dot[*k] += wd;
                    }
                }
            }
            // (3) Reverse sweep over the prelude slots this summand reaches.
            for &i in s.prelude_reach.iter().rev() {
                let w = prelude_adj[i];
                let wd = prelude_adj_dot[i];
                if w == 0.0 && wd == 0.0 {
                    continue;
                }
                ror_step(
                    &self.prelude[i],
                    i,
                    j,
                    prelude_vals,
                    prelude_dot,
                    prelude_adj,
                    prelude_adj_dot,
                    w,
                    wd,
                    weight,
                    hess_map,
                    values,
                );
            }
        }
    }
1286
1287    /// Structural Hessian sparsity over the whole hybrid tape:
1288    /// every pair the prelude or any summand can produce.
1289    pub fn hessian_sparsity_all(&self) -> BTreeSet<(usize, usize)> {
1290        let mut pairs = hessian_sparsity_impl(&self.prelude);
1291
1292        // Per-prelude-slot var-set, reused across summands as the
1293        // var-set carrier for Shared refs.
1294        let prelude_var_sets = compute_var_sets(&self.prelude);
1295
1296        for s in &self.summands {
1297            summand_sparsity(&s.ops, &prelude_var_sets, &mut pairs);
1298        }
1299        pairs
1300    }
1301}
1302
1303/// Pass-1 helper: per-root walk that increments `counts[ptr]` the
1304/// first time a Cse pointer is encountered in this root. Recursing
1305/// into the body is gated on the first visit to avoid quadratic
1306/// blowup on heavily shared CSE DAGs.
1307fn count_cse_appearances(
1308    e: &Expr,
1309    seen_in_root: &mut HashSet<*const Expr>,
1310    counts: &mut HashMap<*const Expr, usize>,
1311) {
1312    match e {
1313        Expr::Const(_) | Expr::Var(_) => {}
1314        Expr::Binary(_, a, b) => {
1315            count_cse_appearances(a, seen_in_root, counts);
1316            count_cse_appearances(b, seen_in_root, counts);
1317        }
1318        Expr::Unary(_, a) => count_cse_appearances(a, seen_in_root, counts),
1319        Expr::Sum(args) => {
1320            for a in args {
1321                count_cse_appearances(a, seen_in_root, counts);
1322            }
1323        }
1324        Expr::Cse(body) => {
1325            let key = Rc::as_ptr(body) as *const Expr;
1326            if seen_in_root.insert(key) {
1327                *counts.entry(key).or_insert(0) += 1;
1328                count_cse_appearances(body, seen_in_root, counts);
1329            }
1330        }
1331    }
1332}
1333
1334/// Recursive summand builder. CSEs that meet the promotion bar
1335/// (≥ 2 roots reference them per `cse_count`) get a single prelude
1336/// emission via `build_recursive`; the summand records a Shared op
1337/// pointing at the prelude slot. Non-promoted CSEs are inlined
1338/// into the summand with intra-summand Rc-pointer dedup.
1339fn build_into_summand(
1340    expr: &Expr,
1341    local: &mut Vec<SummandOp>,
1342    local_cache: &mut HashMap<*const Expr, usize>,
1343    prelude: &mut Vec<TapeOp>,
1344    prelude_map: &mut HashMap<*const Expr, usize>,
1345    cse_count: &HashMap<*const Expr, usize>,
1346) -> usize {
1347    match expr {
1348        Expr::Const(c) => {
1349            let i = local.len();
1350            local.push(SummandOp::Local(TapeOp::Const(*c)));
1351            i
1352        }
1353        Expr::Var(j) => {
1354            let i = local.len();
1355            local.push(SummandOp::Local(TapeOp::Var(*j)));
1356            i
1357        }
1358        Expr::Binary(op, a, b) => {
1359            if let BinOp::Pow = op {
1360                if let Some(c) = peek_const(b) {
1361                    if let Some(i) = try_emit_const_pow_summand(
1362                        a,
1363                        c,
1364                        local,
1365                        local_cache,
1366                        prelude,
1367                        prelude_map,
1368                        cse_count,
1369                    ) {
1370                        return i;
1371                    }
1372                }
1373            }
1374            let l = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
1375            let r = build_into_summand(b, local, local_cache, prelude, prelude_map, cse_count);
1376            let i = local.len();
1377            local.push(SummandOp::Local(match op {
1378                BinOp::Add => TapeOp::Add(l, r),
1379                BinOp::Sub => TapeOp::Sub(l, r),
1380                BinOp::Mul => TapeOp::Mul(l, r),
1381                BinOp::Div => TapeOp::Div(l, r),
1382                BinOp::Pow => TapeOp::Pow(l, r),
1383            }));
1384            i
1385        }
1386        Expr::Unary(op, a) => {
1387            let v = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
1388            let i = local.len();
1389            local.push(SummandOp::Local(match op {
1390                UnaryOp::Neg => TapeOp::Neg(v),
1391                UnaryOp::Sqrt => TapeOp::Sqrt(v),
1392                UnaryOp::Log => TapeOp::Log(v),
1393                UnaryOp::Log10 => TapeOp::Log10(v),
1394                UnaryOp::Exp => TapeOp::Exp(v),
1395                UnaryOp::Abs => TapeOp::Abs(v),
1396                UnaryOp::Sin => TapeOp::Sin(v),
1397                UnaryOp::Cos => TapeOp::Cos(v),
1398            }));
1399            i
1400        }
1401        Expr::Sum(args) => {
1402            if args.is_empty() {
1403                let i = local.len();
1404                local.push(SummandOp::Local(TapeOp::Const(0.0)));
1405                return i;
1406            }
1407            let mut acc = build_into_summand(
1408                &args[0],
1409                local,
1410                local_cache,
1411                prelude,
1412                prelude_map,
1413                cse_count,
1414            );
1415            for a in &args[1..] {
1416                let nxt =
1417                    build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
1418                let i = local.len();
1419                local.push(SummandOp::Local(TapeOp::Add(acc, nxt)));
1420                acc = i;
1421            }
1422            acc
1423        }
1424        Expr::Cse(body) => {
1425            let key = Rc::as_ptr(body) as *const Expr;
1426            if let Some(&li) = local_cache.get(&key) {
1427                return li;
1428            }
1429            let promoted = cse_count.get(&key).copied().unwrap_or(0) >= 2;
1430            if promoted {
1431                // Build (or reuse) the prelude slot for this CSE.
1432                // `build_recursive(expr, ...)` hits the Cse arm,
1433                // emits the body once into prelude, and caches it
1434                // in `prelude_map` keyed by this Rc pointer.
1435                let pslot = build_recursive(expr, prelude, prelude_map);
1436                let li = local.len();
1437                local.push(SummandOp::Shared(pslot));
1438                local_cache.insert(key, li);
1439                li
1440            } else {
1441                let li =
1442                    build_into_summand(body, local, local_cache, prelude, prelude_map, cse_count);
1443                local_cache.insert(key, li);
1444                li
1445            }
1446        }
1447    }
1448}
1449
1450/// Pow-lowering specialised for summand builds. Mirrors
1451/// `try_emit_const_pow` but with summand-flavoured emission.
1452fn try_emit_const_pow_summand(
1453    base_expr: &Expr,
1454    c: f64,
1455    local: &mut Vec<SummandOp>,
1456    local_cache: &mut HashMap<*const Expr, usize>,
1457    prelude: &mut Vec<TapeOp>,
1458    prelude_map: &mut HashMap<*const Expr, usize>,
1459    cse_count: &HashMap<*const Expr, usize>,
1460) -> Option<usize> {
1461    if c == 0.0 {
1462        let i = local.len();
1463        local.push(SummandOp::Local(TapeOp::Const(1.0)));
1464        return Some(i);
1465    }
1466    if c == 1.0 {
1467        return Some(build_into_summand(
1468            base_expr,
1469            local,
1470            local_cache,
1471            prelude,
1472            prelude_map,
1473            cse_count,
1474        ));
1475    }
1476    if c == 0.5 {
1477        let b = build_into_summand(
1478            base_expr,
1479            local,
1480            local_cache,
1481            prelude,
1482            prelude_map,
1483            cse_count,
1484        );
1485        let i = local.len();
1486        local.push(SummandOp::Local(TapeOp::Sqrt(b)));
1487        return Some(i);
1488    }
1489    if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
1490        let n = c.abs() as u32;
1491        if n == 0 {
1492            let i = local.len();
1493            local.push(SummandOp::Local(TapeOp::Const(1.0)));
1494            return Some(i);
1495        }
1496        let b = build_into_summand(
1497            base_expr,
1498            local,
1499            local_cache,
1500            prelude,
1501            prelude_map,
1502            cse_count,
1503        );
1504        let pos = emit_int_pow_summand(b, n, local);
1505        if c < 0.0 {
1506            let one_idx = local.len();
1507            local.push(SummandOp::Local(TapeOp::Const(1.0)));
1508            let i = local.len();
1509            local.push(SummandOp::Local(TapeOp::Div(one_idx, pos)));
1510            return Some(i);
1511        }
1512        return Some(pos);
1513    }
1514    None
1515}
1516
1517fn emit_int_pow_summand(base: usize, n: u32, local: &mut Vec<SummandOp>) -> usize {
1518    debug_assert!(n >= 1);
1519    if n == 1 {
1520        return base;
1521    }
1522    let half = emit_int_pow_summand(base, n / 2, local);
1523    let squared = local.len();
1524    local.push(SummandOp::Local(TapeOp::Mul(half, half)));
1525    if n % 2 == 1 {
1526        let i = local.len();
1527        local.push(SummandOp::Local(TapeOp::Mul(squared, base)));
1528        i
1529    } else {
1530        squared
1531    }
1532}
1533
1534/// Walk a summand's local op DAG from `root`, returning the
1535/// reachable local slots (sorted ascending) plus the distinct
1536/// prelude slots referenced by any Shared op along the way.
1537fn compute_local_reach(ops: &[SummandOp], root: usize) -> (Vec<usize>, Vec<usize>) {
1538    let mut visited = vec![false; ops.len()];
1539    let mut reach: Vec<usize> = Vec::new();
1540    let mut shared: BTreeSet<usize> = BTreeSet::new();
1541    let mut stack: Vec<usize> = Vec::with_capacity(16);
1542    visited[root] = true;
1543    reach.push(root);
1544    stack.push(root);
1545    while let Some(s) = stack.pop() {
1546        match &ops[s] {
1547            SummandOp::Local(op) => {
1548                let (a, b) = op_operands(op);
1549                if let Some(a) = a {
1550                    if !visited[a] {
1551                        visited[a] = true;
1552                        reach.push(a);
1553                        stack.push(a);
1554                    }
1555                }
1556                if let Some(b) = b {
1557                    if !visited[b] {
1558                        visited[b] = true;
1559                        reach.push(b);
1560                        stack.push(b);
1561                    }
1562                }
1563            }
1564            SummandOp::Shared(k) => {
1565                shared.insert(*k);
1566            }
1567        }
1568    }
1569    reach.sort_unstable();
1570    (reach, shared.into_iter().collect())
1571}
1572
1573/// Epoch-tagged BFS over the prelude operand DAG, accumulating
1574/// reachable slots into `out`. Caller is responsible for sorting
1575/// `out` after a batch of starts has been processed.
1576fn bfs_prelude(
1577    prelude: &[TapeOp],
1578    start: usize,
1579    visited: &mut [u32],
1580    cur: u32,
1581    stack: &mut Vec<usize>,
1582    out: &mut Vec<usize>,
1583) {
1584    if visited[start] == cur {
1585        return;
1586    }
1587    visited[start] = cur;
1588    out.push(start);
1589    stack.push(start);
1590    while let Some(s) = stack.pop() {
1591        let (a, b) = op_operands(&prelude[s]);
1592        if let Some(a) = a {
1593            if visited[a] != cur {
1594                visited[a] = cur;
1595                out.push(a);
1596                stack.push(a);
1597            }
1598        }
1599        if let Some(b) = b {
1600            if visited[b] != cur {
1601                visited[b] = cur;
1602                out.push(b);
1603                stack.push(b);
1604            }
1605        }
1606    }
1607}
1608
1609/// Per-op var-set for the prelude — every slot's transitive
1610/// variable footprint. Used by `summand_sparsity` to expand
1611/// `SummandOp::Shared(k)` into its var-set carrier.
1612fn compute_var_sets(ops: &[TapeOp]) -> Vec<BTreeSet<usize>> {
1613    let mut out: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
1614    for op in ops {
1615        let vs: BTreeSet<usize> = match op {
1616            TapeOp::Const(_) => BTreeSet::new(),
1617            TapeOp::Var(j) => {
1618                let mut s = BTreeSet::new();
1619                s.insert(*j);
1620                s
1621            }
1622            TapeOp::Add(a, b)
1623            | TapeOp::Sub(a, b)
1624            | TapeOp::Mul(a, b)
1625            | TapeOp::Div(a, b)
1626            | TapeOp::Pow(a, b) => out[*a].union(&out[*b]).copied().collect(),
1627            TapeOp::Neg(a)
1628            | TapeOp::Abs(a)
1629            | TapeOp::Sqrt(a)
1630            | TapeOp::Exp(a)
1631            | TapeOp::Log(a)
1632            | TapeOp::Log10(a)
1633            | TapeOp::Sin(a)
1634            | TapeOp::Cos(a) => out[*a].clone(),
1635        };
1636        out.push(vs);
1637    }
1638    out
1639}
1640
1641/// Per-op Hessian-sparsity propagation over a summand's mixed
1642/// SummandOp slice. Shared refs contribute their prelude var-set
1643/// but do not themselves emit pairs (those came from
1644/// `hessian_sparsity_impl(&prelude)`).
1645fn summand_sparsity(
1646    ops: &[SummandOp],
1647    prelude_var_sets: &[BTreeSet<usize>],
1648    pairs: &mut BTreeSet<(usize, usize)>,
1649) {
1650    let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
1651    let emit_cross =
1652        |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1653            for &v1 in s1 {
1654                for &v2 in s2 {
1655                    let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
1656                    pairs.insert((r, c));
1657                }
1658            }
1659        };
1660    let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1661        let vars: Vec<usize> = s.iter().copied().collect();
1662        for (ai, &vi) in vars.iter().enumerate() {
1663            for &vj in &vars[..=ai] {
1664                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
1665                pairs.insert((r, c));
1666            }
1667        }
1668    };
1669    for so in ops {
1670        let vset: BTreeSet<usize> = match so {
1671            SummandOp::Shared(k) => prelude_var_sets[*k].clone(),
1672            SummandOp::Local(op) => match op {
1673                TapeOp::Const(_) => BTreeSet::new(),
1674                TapeOp::Var(j) => {
1675                    let mut s = BTreeSet::new();
1676                    s.insert(*j);
1677                    s
1678                }
1679                TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
1680                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1681                }
1682                TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
1683                TapeOp::Mul(a, b) => {
1684                    emit_cross(&var_sets[*a], &var_sets[*b], pairs);
1685                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1686                }
1687                TapeOp::Div(a, b) => {
1688                    emit_cross(&var_sets[*a], &var_sets[*b], pairs);
1689                    emit_self(&var_sets[*b], pairs);
1690                    var_sets[*a].union(&var_sets[*b]).copied().collect()
1691                }
1692                TapeOp::Pow(a, b) => {
1693                    let combined: BTreeSet<usize> =
1694                        var_sets[*a].union(&var_sets[*b]).copied().collect();
1695                    emit_self(&combined, pairs);
1696                    combined
1697                }
1698                TapeOp::Sqrt(a)
1699                | TapeOp::Exp(a)
1700                | TapeOp::Log(a)
1701                | TapeOp::Log10(a)
1702                | TapeOp::Sin(a)
1703                | TapeOp::Cos(a) => {
1704                    emit_self(&var_sets[*a], pairs);
1705                    var_sets[*a].clone()
1706                }
1707            },
1708        };
1709        var_sets.push(vset);
1710    }
1711}
1712
1713/// Operand indices of a `TapeOp`, normalized into a fixed-length
1714/// array so callers don't need to re-match every site.
1715#[inline]
1716fn op_operands(op: &TapeOp) -> (Option<usize>, Option<usize>) {
1717    match op {
1718        TapeOp::Const(_) | TapeOp::Var(_) => (None, None),
1719        TapeOp::Add(a, b)
1720        | TapeOp::Sub(a, b)
1721        | TapeOp::Mul(a, b)
1722        | TapeOp::Div(a, b)
1723        | TapeOp::Pow(a, b) => (Some(*a), Some(*b)),
1724        TapeOp::Neg(a)
1725        | TapeOp::Abs(a)
1726        | TapeOp::Sqrt(a)
1727        | TapeOp::Exp(a)
1728        | TapeOp::Log(a)
1729        | TapeOp::Log10(a)
1730        | TapeOp::Sin(a)
1731        | TapeOp::Cos(a) => (Some(*a), None),
1732    }
1733}
1734
1735fn vars_in(ops: &[TapeOp], reach: &[usize]) -> Vec<usize> {
1736    let mut s: BTreeSet<usize> = BTreeSet::new();
1737    for &i in reach {
1738        if let TapeOp::Var(j) = &ops[i] {
1739            s.insert(*j);
1740        }
1741    }
1742    s.into_iter().collect()
1743}
1744
1745// ----- Free-function AD step kernels used by GlobalTape -----
1746
1747#[inline]
1748fn fwd_step(op: &TapeOp, x: &[f64], vals: &[f64]) -> f64 {
1749    match op {
1750        TapeOp::Const(c) => *c,
1751        TapeOp::Var(i) => x[*i],
1752        TapeOp::Add(a, b) => vals[*a] + vals[*b],
1753        TapeOp::Sub(a, b) => vals[*a] - vals[*b],
1754        TapeOp::Mul(a, b) => vals[*a] * vals[*b],
1755        TapeOp::Div(a, b) => vals[*a] / vals[*b],
1756        TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
1757        TapeOp::Neg(a) => -vals[*a],
1758        TapeOp::Abs(a) => vals[*a].abs(),
1759        TapeOp::Sqrt(a) => vals[*a].sqrt(),
1760        TapeOp::Exp(a) => vals[*a].exp(),
1761        TapeOp::Log(a) => vals[*a].ln(),
1762        TapeOp::Log10(a) => vals[*a].log10(),
1763        TapeOp::Sin(a) => vals[*a].sin(),
1764        TapeOp::Cos(a) => vals[*a].cos(),
1765    }
1766}
1767
1768#[inline]
1769fn rev_step(op: &TapeOp, i: usize, vals: &[f64], adj: &mut [f64], a: f64, grad: &mut [f64]) {
1770    match op {
1771        TapeOp::Const(_) => {}
1772        TapeOp::Var(j) => {
1773            grad[*j] += a;
1774        }
1775        TapeOp::Add(l, r) => {
1776            adj[*l] += a;
1777            adj[*r] += a;
1778        }
1779        TapeOp::Sub(l, r) => {
1780            adj[*l] += a;
1781            adj[*r] -= a;
1782        }
1783        TapeOp::Mul(l, r) => {
1784            adj[*l] += a * vals[*r];
1785            adj[*r] += a * vals[*l];
1786        }
1787        TapeOp::Div(l, r) => {
1788            let rv = vals[*r];
1789            adj[*l] += a / rv;
1790            adj[*r] -= a * vals[*l] / (rv * rv);
1791        }
1792        TapeOp::Pow(l, r) => {
1793            let lv = vals[*l];
1794            let rv = vals[*r];
1795            if rv != 0.0 {
1796                adj[*l] += a * rv * lv.powf(rv - 1.0);
1797            }
1798            if lv > 0.0 {
1799                adj[*r] += a * vals[i] * lv.ln();
1800            }
1801        }
1802        TapeOp::Neg(j) => {
1803            adj[*j] -= a;
1804        }
1805        TapeOp::Abs(j) => {
1806            if vals[*j] >= 0.0 {
1807                adj[*j] += a;
1808            } else {
1809                adj[*j] -= a;
1810            }
1811        }
1812        TapeOp::Sqrt(j) => {
1813            let sv = vals[i];
1814            if sv > 0.0 {
1815                adj[*j] += a * 0.5 / sv;
1816            }
1817        }
1818        TapeOp::Exp(j) => {
1819            adj[*j] += a * vals[i];
1820        }
1821        TapeOp::Log(j) => {
1822            adj[*j] += a / vals[*j];
1823        }
1824        TapeOp::Log10(j) => {
1825            adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
1826        }
1827        TapeOp::Sin(j) => {
1828            adj[*j] += a * vals[*j].cos();
1829        }
1830        TapeOp::Cos(j) => {
1831            adj[*j] -= a * vals[*j].sin();
1832        }
1833    }
1834}
1835
1836#[inline]
1837fn fwd_tan_step(op: &TapeOp, seed_var: usize, vals: &[f64], dot: &[f64], i: usize) -> f64 {
1838    match op {
1839        TapeOp::Const(_) => 0.0,
1840        TapeOp::Var(k) => {
1841            if *k == seed_var {
1842                1.0
1843            } else {
1844                0.0
1845            }
1846        }
1847        TapeOp::Add(a, b) => dot[*a] + dot[*b],
1848        TapeOp::Sub(a, b) => dot[*a] - dot[*b],
1849        TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
1850        TapeOp::Div(a, b) => {
1851            let vb = vals[*b];
1852            (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
1853        }
1854        TapeOp::Pow(a, b) => {
1855            let u = vals[*a];
1856            let r = vals[*b];
1857            let du = dot[*a];
1858            let dr = dot[*b];
1859            let mut result = 0.0;
1860            if r != 0.0 && u != 0.0 {
1861                result += r * u.powf(r - 1.0) * du;
1862            }
1863            if u > 0.0 {
1864                result += vals[i] * u.ln() * dr;
1865            }
1866            result
1867        }
1868        TapeOp::Neg(a) => -dot[*a],
1869        TapeOp::Abs(a) => {
1870            if vals[*a] >= 0.0 {
1871                dot[*a]
1872            } else {
1873                -dot[*a]
1874            }
1875        }
1876        TapeOp::Sqrt(a) => {
1877            let sv = vals[i];
1878            if sv > 0.0 {
1879                dot[*a] * 0.5 / sv
1880            } else {
1881                0.0
1882            }
1883        }
1884        TapeOp::Exp(a) => dot[*a] * vals[i],
1885        TapeOp::Log(a) => dot[*a] / vals[*a],
1886        TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
1887        TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
1888        TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
1889    }
1890}
1891
1892#[allow(clippy::too_many_arguments)]
1893#[inline]
1894fn ror_step(
1895    op: &TapeOp,
1896    i: usize,
1897    seed_var: usize,
1898    vals: &[f64],
1899    dot: &[f64],
1900    adj: &mut [f64],
1901    adj_dot: &mut [f64],
1902    w: f64,
1903    wd: f64,
1904    weight: f64,
1905    hess_map: &HashMap<(usize, usize), usize>,
1906    values: &mut [f64],
1907) {
1908    match op {
1909        TapeOp::Const(_) => {}
1910        TapeOp::Var(k) => {
1911            if wd != 0.0 && *k >= seed_var {
1912                if let Some(&pos) = hess_map.get(&(*k, seed_var)) {
1913                    values[pos] += weight * wd;
1914                }
1915            }
1916        }
1917        TapeOp::Add(a, b) => {
1918            adj[*a] += w;
1919            adj[*b] += w;
1920            adj_dot[*a] += wd;
1921            adj_dot[*b] += wd;
1922        }
1923        TapeOp::Sub(a, b) => {
1924            adj[*a] += w;
1925            adj[*b] -= w;
1926            adj_dot[*a] += wd;
1927            adj_dot[*b] -= wd;
1928        }
1929        TapeOp::Mul(a, b) => {
1930            adj[*a] += w * vals[*b];
1931            adj[*b] += w * vals[*a];
1932            adj_dot[*a] += wd * vals[*b] + w * dot[*b];
1933            adj_dot[*b] += wd * vals[*a] + w * dot[*a];
1934        }
1935        TapeOp::Div(a, b) => {
1936            let vb = vals[*b];
1937            let vb2 = vb * vb;
1938            let vb3 = vb2 * vb;
1939            adj[*a] += w / vb;
1940            adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
1941            adj[*b] += w * (-vals[*a] / vb2);
1942            adj_dot[*b] +=
1943                wd * (-vals[*a] / vb2) + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
1944        }
1945        TapeOp::Pow(a, b) => {
1946            let u = vals[*a];
1947            let r = vals[*b];
1948            let du = dot[*a];
1949            let dr = dot[*b];
1950            if r != 0.0 {
1951                if u != 0.0 {
1952                    let p_a = r * u.powf(r - 1.0);
1953                    adj[*a] += w * p_a;
1954                    let mut dp_a = dr * u.powf(r - 1.0);
1955                    if u > 0.0 {
1956                        dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
1957                    } else {
1958                        dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
1959                    }
1960                    adj_dot[*a] += wd * p_a + w * dp_a;
1961                } else if r >= 2.0 {
1962                    let p_a = 0.0;
1963                    adj[*a] += w * p_a;
1964                    let dp_a = if r == 2.0 {
1965                        2.0 * du
1966                    } else {
1967                        r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
1968                    };
1969                    adj_dot[*a] += wd * p_a + w * dp_a;
1970                }
1971            }
1972            if u > 0.0 {
1973                let ln_u = u.ln();
1974                let p_b = vals[i] * ln_u;
1975                adj[*b] += w * p_b;
1976                let dur = vals[i] * (r * du / u + dr * ln_u);
1977                let dp_b = dur * ln_u + vals[i] * du / u;
1978                adj_dot[*b] += wd * p_b + w * dp_b;
1979            }
1980        }
1981        TapeOp::Neg(a) => {
1982            adj[*a] -= w;
1983            adj_dot[*a] -= wd;
1984        }
1985        TapeOp::Abs(a) => {
1986            let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
1987            adj[*a] += w * s;
1988            adj_dot[*a] += wd * s;
1989        }
1990        TapeOp::Sqrt(a) => {
1991            let sv = vals[i];
1992            if sv > 0.0 {
1993                let fp = 0.5 / sv;
1994                let fpp = -0.25 / (vals[*a] * sv);
1995                adj[*a] += w * fp;
1996                adj_dot[*a] += wd * fp + w * fpp * dot[*a];
1997            }
1998        }
1999        TapeOp::Exp(a) => {
2000            let ev = vals[i];
2001            adj[*a] += w * ev;
2002            adj_dot[*a] += wd * ev + w * ev * dot[*a];
2003        }
2004        TapeOp::Log(a) => {
2005            let u = vals[*a];
2006            adj[*a] += w / u;
2007            adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
2008        }
2009        TapeOp::Log10(a) => {
2010            let u = vals[*a];
2011            let c = std::f64::consts::LN_10;
2012            adj[*a] += w / (u * c);
2013            adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
2014        }
2015        TapeOp::Sin(a) => {
2016            let u = vals[*a];
2017            let cu = u.cos();
2018            adj[*a] += w * cu;
2019            adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
2020        }
2021        TapeOp::Cos(a) => {
2022            let u = vals[*a];
2023            let su = u.sin();
2024            adj[*a] -= w * su;
2025            adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
2026        }
2027    }
2028}
2029
2030/// Per-op Hessian-sparsity propagation. Same algorithm as
2031/// `Tape::hessian_sparsity` but as a free function so `GlobalTape`
2032/// can call it over its shared `ops` slice.
2033fn hessian_sparsity_impl(ops: &[TapeOp]) -> BTreeSet<(usize, usize)> {
2034    let n = ops.len();
2035    let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
2036    let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
2037
2038    let emit_cross =
2039        |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2040            for &v1 in s1 {
2041                for &v2 in s2 {
2042                    let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
2043                    pairs.insert((r, c));
2044                }
2045            }
2046        };
2047    let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2048        let vars: Vec<usize> = s.iter().copied().collect();
2049        for (ai, &vi) in vars.iter().enumerate() {
2050            for &vj in &vars[..=ai] {
2051                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2052                pairs.insert((r, c));
2053            }
2054        }
2055    };
2056
2057    for op in ops {
2058        let vset = match op {
2059            TapeOp::Const(_) => BTreeSet::new(),
2060            TapeOp::Var(j) => {
2061                let mut s = BTreeSet::new();
2062                s.insert(*j);
2063                s
2064            }
2065            TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
2066                var_sets[*a].union(&var_sets[*b]).copied().collect()
2067            }
2068            TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
2069            TapeOp::Mul(a, b) => {
2070                emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
2071                var_sets[*a].union(&var_sets[*b]).copied().collect()
2072            }
2073            TapeOp::Div(a, b) => {
2074                emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
2075                emit_self(&var_sets[*b], &mut pairs);
2076                var_sets[*a].union(&var_sets[*b]).copied().collect()
2077            }
2078            TapeOp::Pow(a, b) => {
2079                let combined: BTreeSet<usize> =
2080                    var_sets[*a].union(&var_sets[*b]).copied().collect();
2081                emit_self(&combined, &mut pairs);
2082                combined
2083            }
2084            TapeOp::Sqrt(a)
2085            | TapeOp::Exp(a)
2086            | TapeOp::Log(a)
2087            | TapeOp::Log10(a)
2088            | TapeOp::Sin(a)
2089            | TapeOp::Cos(a) => {
2090                emit_self(&var_sets[*a], &mut pairs);
2091                var_sets[*a].clone()
2092            }
2093        };
2094        var_sets.push(vset);
2095    }
2096    pairs
2097}
2098
2099#[cfg(test)]
2100mod tests {
2101    use super::*;
2102
2103    fn cnst(c: f64) -> Expr {
2104        Expr::Const(c)
2105    }
2106    fn var(i: usize) -> Expr {
2107        Expr::Var(i)
2108    }
2109    fn add(a: Expr, b: Expr) -> Expr {
2110        Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
2111    }
2112    fn mul(a: Expr, b: Expr) -> Expr {
2113        Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
2114    }
2115    fn pow(a: Expr, b: Expr) -> Expr {
2116        Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
2117    }
2118    fn div(a: Expr, b: Expr) -> Expr {
2119        Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
2120    }
2121    fn unary(op: UnaryOp, a: Expr) -> Expr {
2122        Expr::Unary(op, Box::new(a))
2123    }
2124
/// Value and gradient of `f = 3*x0^2 + 2*x1` at `(2, 3)`.
#[test]
fn polynomial_eval_and_grad() {
    let quadratic = mul(cnst(3.0), pow(var(0), cnst(2.0)));
    let linear = mul(cnst(2.0), var(1));
    let expr = add(quadratic, linear);
    let tape = Tape::build(&expr);
    let point = [2.0, 3.0];

    // f(2, 3) = 12 + 6
    let value = tape.eval(&point);
    assert!((value - 18.0).abs() < 1e-12);

    let mut grad = vec![0.0; 2];
    tape.gradient_seed(&point, 1.0, &mut grad);
    // df/dx0 = 6*x0 = 12
    assert!((grad[0] - 12.0).abs() < 1e-12);
    // df/dx1 = 2
    assert!((grad[1] - 2.0).abs() < 1e-12);
}
2140
/// A CSE body referenced twice (`f = body^2 + body`, `body = x0 + x1`)
/// is emitted once, and value/gradient still follow the chain rule.
#[test]
fn cse_shared_body_evaluated_once() {
    let shared = Rc::new(add(var(0), var(1)));
    let squared = pow(Expr::Cse(Rc::clone(&shared)), cnst(2.0));
    let expr = add(squared, Expr::Cse(Rc::clone(&shared)));
    let tape = Tape::build(&expr);

    // An `Add` whose operands are exactly Var(0) and Var(1) is the body.
    let is_body_add = |op: &TapeOp| match op {
        TapeOp::Add(a, b) => {
            matches!(tape.ops[*a], TapeOp::Var(0)) && matches!(tape.ops[*b], TapeOp::Var(1))
        }
        _ => false,
    };
    let body_adds = tape.ops.iter().filter(|op| is_body_add(*op)).count();
    assert_eq!(body_adds, 1, "CSE body should be emitted exactly once");

    // f(1, 2) = 9 + 3 = 12
    let point = [1.0, 2.0];
    let value = tape.eval(&point);
    assert!((value - 12.0).abs() < 1e-12);

    // df/dx0 = df/dx1 = 2*(x0+x1) + 1 = 7
    let mut grad = vec![0.0; 2];
    tape.gradient_seed(&point, 1.0, &mut grad);
    assert!((grad[0] - 7.0).abs() < 1e-12);
    assert!((grad[1] - 7.0).abs() < 1e-12);
}
2170
2171    fn fd_check(tape: &Tape, x: &[f64], n: usize, tol: f64) {
2172        let vars = tape.variables();
2173        let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
2174        let mut pairs = Vec::new();
2175        for (ai, &vi) in vars.iter().enumerate() {
2176            for &vj in &vars[..=ai] {
2177                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2178                hess_map.entry((r, c)).or_insert_with(|| {
2179                    let p = pairs.len();
2180                    pairs.push((r, c));
2181                    p
2182                });
2183            }
2184        }
2185        let nnz = pairs.len();
2186        let mut ad = vec![0.0; nnz];
2187        tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
2188
2189        let mut fd = vec![0.0; nnz];
2190        let mut xp = x.to_vec();
2191        let mut gp = vec![0.0; n];
2192        let mut gm = vec![0.0; n];
2193        for &j in &vars {
2194            let h = (1e-7_f64).max(x[j].abs() * 1e-7);
2195            xp[j] = x[j] + h;
2196            gp.iter_mut().for_each(|v| *v = 0.0);
2197            tape.gradient_seed(&xp, 1.0, &mut gp);
2198            xp[j] = x[j] - h;
2199            gm.iter_mut().for_each(|v| *v = 0.0);
2200            tape.gradient_seed(&xp, 1.0, &mut gm);
2201            xp[j] = x[j];
2202            for &i in &vars {
2203                if i >= j {
2204                    if let Some(&pos) = hess_map.get(&(i, j)) {
2205                        fd[pos] = (gp[i] - gm[i]) / (2.0 * h);
2206                    }
2207                }
2208            }
2209        }
2210        for (k, &(r, c)) in pairs.iter().enumerate() {
2211            let scale = fd[k].abs().max(1.0);
2212            assert!(
2213                (ad[k] - fd[k]).abs() / scale < tol,
2214                "H[{},{}]: AD={:.6e} FD={:.6e}",
2215                r,
2216                c,
2217                ad[k],
2218                fd[k]
2219            );
2220        }
2221    }
2222
/// Hessian of `f = 3 x0^2 + 2 x0 x1 + x1^2` agrees with finite differences.
#[test]
fn hessian_quadratic_matches_fd() {
    let x0_term = mul(cnst(3.0), pow(var(0), cnst(2.0)));
    let cross_term = mul(cnst(2.0), mul(var(0), var(1)));
    let x1_term = pow(var(1), cnst(2.0));
    let expr = add(add(x0_term, cross_term), x1_term);
    let tape = Tape::build(&expr);
    let point = [2.0, 3.0];
    fd_check(&tape, &point, 2, 1e-5);
}
2236
/// Hessian of `f = exp(x0) + sin(x1) + log(x0) + sqrt(x1) + x0*x1`
/// agrees with finite differences.
#[test]
fn hessian_transcendental_matches_fd() {
    let terms = vec![
        unary(UnaryOp::Exp, var(0)),
        unary(UnaryOp::Sin, var(1)),
        unary(UnaryOp::Log, var(0)),
        unary(UnaryOp::Sqrt, var(1)),
        mul(var(0), var(1)),
    ];
    let expr = Expr::Sum(terms);
    let tape = Tape::build(&expr);
    let point = [1.5, 2.0];
    fd_check(&tape, &point, 2, 1e-5);
}
2250
/// Hessian of `f = x0/x1 + cos(x0)` agrees with finite differences.
#[test]
fn hessian_division_matches_fd() {
    let quotient = div(var(0), var(1));
    let cosine = unary(UnaryOp::Cos, var(0));
    let expr = add(quotient, cosine);
    let tape = Tape::build(&expr);
    let point = [0.5, 1.2];
    fd_check(&tape, &point, 2, 1e-5);
}
2258
2259    /// `hessian_directional` (one forward-over-reverse pass with
2260    /// a seed vector) recovers `H · e_j` for each unit-vector seed,
2261    /// matching column `j` of the dense Hessian computed by
2262    /// `hessian_accumulate`.
2263    fn directional_matches_accumulate(tape: &Tape, x: &[f64], n: usize) {
2264        let vars = tape.variables();
2265        let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
2266        let mut pairs = Vec::new();
2267        for (ai, &vi) in vars.iter().enumerate() {
2268            for &vj in &vars[..=ai] {
2269                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2270                hess_map.entry((r, c)).or_insert_with(|| {
2271                    let p = pairs.len();
2272                    pairs.push((r, c));
2273                    p
2274                });
2275            }
2276        }
2277        let nnz = pairs.len();
2278        let mut ad = vec![0.0; nnz];
2279        tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
2280
2281        let nops = tape.ops.len();
2282        let mut vals = vec![0.0; nops];
2283        tape.forward_into(x, &mut vals);
2284        let mut dot = vec![0.0; nops];
2285        let mut adj = vec![0.0; nops];
2286        let mut adj_dot = vec![0.0; nops];
2287
2288        for &j in &vars {
2289            let mut seed = vec![0.0; n];
2290            seed[j] = 1.0;
2291            let mut col = vec![0.0; n];
2292            tape.hessian_directional(
2293                &vals,
2294                &seed,
2295                1.0,
2296                &mut col,
2297                &mut dot,
2298                &mut adj,
2299                &mut adj_dot,
2300            );
2301            for &i in &vars {
2302                let (r, c) = if i >= j { (i, j) } else { (j, i) };
2303                let expect = ad[hess_map[&(r, c)]];
2304                assert!(
2305                    (col[i] - expect).abs() < 1e-10,
2306                    "directional H[{i},{j}] = {} vs accumulate {}",
2307                    col[i],
2308                    expect
2309                );
2310            }
2311        }
2312    }
2313
/// Directional and accumulated Hessians agree for
/// `f = 3 x0^2 + 2 x0 x1 + x1^2`.
#[test]
fn directional_quadratic_matches_accumulate() {
    let x0_term = mul(cnst(3.0), pow(var(0), cnst(2.0)));
    let cross_term = mul(mul(cnst(2.0), var(0)), var(1));
    let x1_term = pow(var(1), cnst(2.0));
    let expr = add(add(x0_term, cross_term), x1_term);
    let tape = Tape::build(&expr);
    let point = [0.5, -0.3];
    directional_matches_accumulate(&tape, &point, 2);
}
2327
/// Directional and accumulated Hessians agree for
/// `f = exp(x0) + sin(x1) + log(x0) + sqrt(x1) + x0*x1`.
#[test]
fn directional_transcendental_matches_accumulate() {
    let terms = vec![
        unary(UnaryOp::Exp, var(0)),
        unary(UnaryOp::Sin, var(1)),
        unary(UnaryOp::Log, var(0)),
        unary(UnaryOp::Sqrt, var(1)),
        mul(var(0), var(1)),
    ];
    let expr = Expr::Sum(terms);
    let tape = Tape::build(&expr);
    let point = [1.5, 2.0];
    directional_matches_accumulate(&tape, &point, 2);
}
2340
/// Directional and accumulated Hessians agree for `f = x0/x1 + cos(x0)`.
#[test]
fn directional_with_division_matches_accumulate() {
    let quotient = div(var(0), var(1));
    let cosine = unary(UnaryOp::Cos, var(0));
    let expr = add(quotient, cosine);
    let tape = Tape::build(&expr);
    let point = [0.5, 1.2];
    directional_matches_accumulate(&tape, &point, 2);
}
2347
/// `f = sin(x0) + x1*x2` is separable: the pattern has (0,0) from the
/// sine and (2,1) from the product, and nothing coupling x0 to x1/x2.
#[test]
fn hessian_sparsity_separable() {
    let sine = unary(UnaryOp::Sin, var(0));
    let product = mul(var(1), var(2));
    let expr = add(sine, product);
    let tape = Tape::build(&expr);
    let pattern = tape.hessian_sparsity();

    // Present couplings.
    assert!(pattern.contains(&(0, 0)));
    assert!(pattern.contains(&(2, 1)));
    // Absent couplings.
    assert!(!pattern.contains(&(1, 0)));
    assert!(!pattern.contains(&(2, 0)));
}
2359
2360    fn count_op<F: Fn(&TapeOp) -> bool>(t: &Tape, pred: F) -> usize {
2361        t.ops.iter().filter(|o| pred(o)).count()
2362    }
2363
/// `x^0` folds to the constant 1: no `Pow` and no reference to `x`
/// remain on the tape.
#[test]
fn pow_zero_const_folds_to_one() {
    let expr = pow(var(0), cnst(0.0));
    let tape = Tape::build(&expr);

    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 0);
    let var_refs = count_op(&tape, |op| matches!(op, TapeOp::Var(_)));
    assert_eq!(var_refs, 0);

    let value = tape.eval(&[7.0]);
    assert!((value - 1.0).abs() < 1e-12);
}
2373
/// `x^1` is just `x`: no `Pow`, and no `Const` emitted for the exponent.
#[test]
fn pow_one_passes_through() {
    let expr = pow(var(0), cnst(1.0));
    let tape = Tape::build(&expr);

    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 0);
    let consts = count_op(&tape, |op| matches!(op, TapeOp::Const(_)));
    assert_eq!(consts, 0);

    let value = tape.eval(&[3.5]);
    assert!((value - 3.5).abs() < 1e-12);
}
2383
/// `x^0.5` lowers to a single `Sqrt` with no generic `Pow`.
#[test]
fn pow_half_lowers_to_sqrt() {
    let expr = pow(var(0), cnst(0.5));
    let tape = Tape::build(&expr);

    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 0);
    let sqrts = count_op(&tape, |op| matches!(op, TapeOp::Sqrt(_)));
    assert_eq!(sqrts, 1);

    let value = tape.eval(&[16.0]);
    assert!((value - 4.0).abs() < 1e-12);
}
2392
/// `x^2` lowers to exactly one `Mul` with no generic `Pow`.
#[test]
fn pow_two_lowers_to_single_mul() {
    let expr = pow(var(0), cnst(2.0));
    let tape = Tape::build(&expr);

    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 0);
    let muls = count_op(&tape, |op| matches!(op, TapeOp::Mul(..)));
    assert_eq!(muls, 1);

    let value = tape.eval(&[3.0]);
    assert!((value - 9.0).abs() < 1e-12);
}
2401
/// `x^3` lowers to two `Mul` ops with no generic `Pow`.
#[test]
fn pow_three_lowers_to_two_muls() {
    let expr = pow(var(0), cnst(3.0));
    let tape = Tape::build(&expr);

    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 0);
    let muls = count_op(&tape, |op| matches!(op, TapeOp::Mul(..)));
    assert_eq!(muls, 2);

    let value = tape.eval(&[2.0]);
    assert!((value - 8.0).abs() < 1e-12);
}
2410
/// `x^8` lowers by repeated squaring (x -> x^2 -> x^4 -> x^8): three
/// `Mul` ops and no generic `Pow`.
#[test]
fn pow_eight_lowers_to_three_muls() {
    let expr = pow(var(0), cnst(8.0));
    let tape = Tape::build(&expr);

    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 0);
    let muls = count_op(&tape, |op| matches!(op, TapeOp::Mul(..)));
    assert_eq!(muls, 3);

    let value = tape.eval(&[2.0]);
    assert!((value - 256.0).abs() < 1e-12);
}
2420
/// `x^-2` lowers to `1 / (x*x)`: one `Div` and no generic `Pow`.
#[test]
fn pow_negative_two_lowers_to_div() {
    let expr = pow(var(0), cnst(-2.0));
    let tape = Tape::build(&expr);

    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 0);
    let divs = count_op(&tape, |op| matches!(op, TapeOp::Div(..)));
    assert_eq!(divs, 1);

    let value = tape.eval(&[4.0]);
    assert!((value - (1.0 / 16.0)).abs() < 1e-12);
}
2430
/// `x^9` is past the lowering cutoff and stays a generic `Pow`.
#[test]
fn pow_large_const_stays_generic() {
    let expr = pow(var(0), cnst(9.0));
    let tape = Tape::build(&expr);
    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 1);
}
2438
2439    #[test]
2440    fn pow_non_integer_const_stays_generic() {
2441        // x^1.5 stays as Pow until half-integer handling is added.
2442        let e = pow(var(0), cnst(1.5));
2443        let t = Tape::build(&e);
2444        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
2445    }
2446
2447    #[test]
2448    fn pow_const_through_cse_const() {
2449        // Exponent wrapped in Cse — peek_const should still see it.
2450        let two = Rc::new(cnst(2.0));
2451        let e = Expr::Binary(BinOp::Pow, Box::new(var(0)), Box::new(Expr::Cse(two)));
2452        let t = Tape::build(&e);
2453        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
2454        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
2455    }
2456
2457    #[test]
2458    fn hessian_pow_three_matches_fd() {
2459        // f = 5 * x0^3 + x0 * x1 — exercises the lowered cubic + cross term.
2460        let e = add(mul(cnst(5.0), pow(var(0), cnst(3.0))), mul(var(0), var(1)));
2461        let t = Tape::build(&e);
2462        fd_check(&t, &[1.7, 0.8], 2, 1e-5);
2463    }
2464
2465    #[test]
2466    fn hessian_pow_negative_matches_fd() {
2467        // f = 1/x0^2 + x1^2 — exercises lowered x^-2 and x^2.
2468        let e = add(pow(var(0), cnst(-2.0)), pow(var(1), cnst(2.0)));
2469        let t = Tape::build(&e);
2470        fd_check(&t, &[1.3, 2.4], 2, 1e-5);
2471    }
2472
2473    #[test]
2474    fn hessian_pow_half_matches_fd() {
2475        // f = sqrt(x0) + x0*x1 (via Pow(_, 0.5) → Sqrt)
2476        let e = add(pow(var(0), cnst(0.5)), mul(var(0), var(1)));
2477        let t = Tape::build(&e);
2478        fd_check(&t, &[2.5, 1.1], 2, 1e-5);
2479    }
2480
2481    #[test]
2482    fn hessian_sparsity_through_cse() {
2483        // body = x0+x1 (CSE). f = body^2 + body.
2484        // d²/dx² of body^2 couples (0,0), (1,0), (1,1).
2485        let body = Rc::new(add(var(0), var(1)));
2486        let e = add(
2487            pow(Expr::Cse(body.clone()), cnst(2.0)),
2488            Expr::Cse(body.clone()),
2489        );
2490        let t = Tape::build(&e);
2491        let s = t.hessian_sparsity();
2492        assert!(s.contains(&(0, 0)));
2493        assert!(s.contains(&(1, 0)));
2494        assert!(s.contains(&(1, 1)));
2495        assert_eq!(s.len(), 3);
2496    }
2497}