Skip to main content

pounce_cli/
nl_hessian_program.rs

//! Precompiled symbolic-Hessian program for one `Tape`.
//!
//! `Tape::hessian_accumulate` runs forward-over-reverse AD on every
//! call: for each tape variable `j` it (a) match-dispatches every op
//! in a forward-tangent sweep, (b) zeros adj/adj_dot, (c)
//! match-dispatches every op again in the reverse-over-tangent
//! sweep, and (d) does a HashMap lookup at every `Var(k)` slot to find
//! its Hessian output position. On evaluator-bound problems (dirichlet,
//! lane_emden, henon) that match-dispatch + symbolic-AD overhead is
//! ~80% of total CPU.
//!
//! This module compiles all of that ONCE at tape-build time into a
//! flat `Vec<HOp>` of pre-resolved primitive ops:
//!
//!   * Forward pass — one `Fwd*` op per `TapeOp`. Mirrors
//!     `Tape::forward`.
//!   * Per-`j` forward tangent — only ops for slots that statically
//!     depend on `j` are emitted (every other tangent stays zero
//!     from the per-`j` `ZeroRange` reset).
//!   * Per-`j` reverse-over-tangent — only ops on slots reachable
//!     backward from the output slot, with all slot indices and
//!     Hessian output pointers pre-resolved.
//!
//! ## Scratch layout
//!
//! The program reads/writes a single `&mut [f64]` arena of
//! `n_slots = 4n` cells, split into four contiguous regions of
//! length `n` (`n` = `tape.ops.len()`):
//!
//!   * `v[i]`        in slot `i`
//!   * `dot[i]`      in slot `n + i`
//!   * `adj[i]`      in slot `2n + i`
//!   * `adj_dot[i]`  in slot `3n + i`
//!
//! Per-`j` setup zeros the `[n, 4n)` range (all of `dot`, `adj` and
//! `adj_dot`) and seeds `adj[n-1] = 1`, the adjoint of the output slot.
//! The allocation pattern is intentionally trivial — finer-grained
//! slot recycling buys little once the dispatch loop is the
//! bottleneck, and a contiguous layout makes the per-`j`
//! `ZeroRange` reset a single `memset`-friendly loop.
40
41use std::collections::HashMap;
42
43use super::nl_tape::{Tape, TapeOp};
44
/// One primitive operation in the compiled Hessian program.
///
/// Most `u32` fields are offsets into the caller's `scratch` slice (see the
/// module docs for the `v` / `dot` / `adj` / `adj_dot` slot layout). The
/// exceptions are:
///
/// * `FwdLoadVar::x_idx` — index into the caller's `x` slice;
/// * `FwdLoadConst::c_idx` — index into the program's constant pool;
/// * `ZeroRange::len` — a number of cells, not an offset;
/// * `HessEmit::hess_ptr` — index into the caller's `values` slice.
///
/// In the formulas below `s[..]` is the scratch slice; `v`, `dot`, `adj`
/// and `adj_dot` name the region a slot lives in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HOp {
    // ===== Forward pass =====
    /// `s[dst] = x[x_idx]`
    FwdLoadVar {
        dst: u32,
        x_idx: u32,
    },
    /// `s[dst] = consts[c_idx]`
    FwdLoadConst {
        dst: u32,
        c_idx: u32,
    },
    /// `s[dst] = s[a] + s[b]`
    FwdAdd {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `s[dst] = s[a] - s[b]`
    FwdSub {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `s[dst] = s[a] * s[b]`
    FwdMul {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `s[dst] = s[a] / s[b]`
    FwdDiv {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `s[dst] = s[a].powf(s[b])`
    FwdPow {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `s[dst] = -s[a]`
    FwdNeg {
        dst: u32,
        a: u32,
    },
    /// `s[dst] = |s[a]|`
    FwdAbs {
        dst: u32,
        a: u32,
    },
    /// `s[dst] = sqrt(s[a])`
    FwdSqrt {
        dst: u32,
        a: u32,
    },
    /// `s[dst] = exp(s[a])`
    FwdExp {
        dst: u32,
        a: u32,
    },
    /// `s[dst] = ln(s[a])`
    FwdLog {
        dst: u32,
        a: u32,
    },
    /// `s[dst] = log10(s[a])`
    FwdLog10 {
        dst: u32,
        a: u32,
    },
    /// `s[dst] = sin(s[a])`
    FwdSin {
        dst: u32,
        a: u32,
    },
    /// `s[dst] = cos(s[a])`
    FwdCos {
        dst: u32,
        a: u32,
    },

    // ===== Scalar slot init =====
    /// `s[dst] = 0.0`. Not currently emitted by [`HessianProgram::compile`].
    SetZero {
        dst: u32,
    },
    /// `s[dst] = 1.0`. `compile` uses it to seed the output adjoint
    /// `adj[n-1]` and the tangent of the active variable's `Var` slots.
    SetOne {
        dst: u32,
    },

    // ===== Bulk reset (start of each j) =====
    /// `s[start..start + len] = 0.0`. Note `len` is a cell count.
    ZeroRange {
        start: u32,
        len: u32,
    },

    // ===== Forward tangent (per j) =====
    /// `dot[d] = dot[a] + dot[b]` (`a`/`b` are the operands' `dot` slots).
    DotAdd {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `dot[d] = dot[a] - dot[b]` (`a`/`b` are the operands' `dot` slots).
    DotSub {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// dot[d] = dot[a]*v[b] + v[a]*dot[b]
    DotMul {
        dst: u32,
        dot_a: u32,
        vb: u32,
        va: u32,
        dot_b: u32,
    },
    /// dot[d] = (dot[a]*v[b] - v[a]*dot[b]) / (v[b]*v[b])
    DotDiv {
        dst: u32,
        dot_a: u32,
        vb: u32,
        va: u32,
        dot_b: u32,
    },
    /// dot[d] = 0.5 * dot[a] / v[d] when v[d] > 0, else 0
    /// (v[d] = sqrt(v[a])). The guard avoids dividing by a zero or
    /// non-positive root.
    DotSqrt {
        dst: u32,
        dot_a: u32,
        vd: u32,
    },
    /// dot[d] = v[d] * dot[a]  (v[d] = exp(v[a]))
    DotExp {
        dst: u32,
        dot_a: u32,
        vd: u32,
    },
    /// dot[d] = dot[a] / v[a]
    DotLog {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// dot[d] = dot[a] / (v[a] * ln 10)
    DotLog10 {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// dot[d] = dot[a] * cos(v[a])
    DotSin {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// dot[d] = -dot[a] * sin(v[a])
    DotCos {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// dot[d] = -dot[a]
    DotNeg {
        dst: u32,
        dot_a: u32,
    },
    /// dot[d] = dot[a] if v[a] >= 0, else -dot[a]
    DotAbs {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// Compound: dot[d] for Pow(a, b). With u = v[a], r = v[b]:
    /// dot[d] = r * u^(r-1) * dot[a]   (only when r != 0 and u != 0)
    ///        + v[d] * ln(u) * dot[b]  (only when u > 0).
    /// These runtime branches skip terms that are zero or undefined.
    DotPow {
        dst: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },

    // ===== Reverse + adj_dot update (per j) =====
    // Each op consumes adj[i] (= `w`) and adj_dot[i] (= `wd`) of
    // the consumer slot, then `+=`-accumulates into the adj /
    // adj_dot of the operand slots. `va`/`vb` are the operands' `v`
    // slots, `vd` the consumer's own `v` slot, and `dot_a`/`dot_b` the
    // operands' `dot` slots; those are only read.
    /// adj[a] += s[w]; adj[b] += s[w];
    /// adj_dot[a] += s[wd]; adj_dot[b] += s[wd]
    RevAdd {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// adj[a] += s[w]; adj[b] -= s[w];
    /// adj_dot[a] += s[wd]; adj_dot[b] -= s[wd]
    RevSub {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse-over-tangent step for `Mul(a, b)`.
    RevMul {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse-over-tangent step for `Div(a, b)`.
    RevDiv {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse-over-tangent step for `Pow(a, b)`.
    RevPow {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse-over-tangent step for `Neg(a)`.
    RevNeg {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse-over-tangent step for `Abs(a)`.
    RevAbs {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
    },
    /// Reverse-over-tangent step for `Sqrt(a)`.
    RevSqrt {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse-over-tangent step for `Exp(a)`.
    RevExp {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse-over-tangent step for `Log(a)`.
    RevLog {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse-over-tangent step for `Log10(a)`.
    RevLog10 {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse-over-tangent step for `Sin(a)`.
    RevSin {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse-over-tangent step for `Cos(a)`.
    RevCos {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },

    // ===== Output =====
    /// values[hess_ptr] += weight * scratch[adj_dot_slot].
    /// `hess_ptr` indexes the caller's `values` slice, not scratch.
    HessEmit {
        hess_ptr: u32,
        adj_dot_slot: u32,
    },
}
337
338/// Precompiled Hessian-of-one-tape program. Built once via
339/// [`HessianProgram::compile`]; executed many times.
340#[derive(Debug, Clone)]
341pub struct HessianProgram {
342    ops: Vec<HOp>,
343    consts: Vec<f64>,
344    n_slots: u32,
345}
346
347impl HessianProgram {
348    /// Build the program. The `hess_map` is the same `(row, col)
349    /// -> values-index` map that [`Tape::hessian_accumulate`] uses;
350    /// the compiler inlines each lookup into a `HessEmit` op.
351    pub fn compile(tape: &Tape, hess_map: &HashMap<(usize, usize), usize>) -> Self {
352        let n = tape.ops.len() as u32;
353        let v_base = 0u32;
354        let dot_base = n;
355        let adj_base = 2 * n;
356        let adj_dot_base = 3 * n;
357        let n_slots = 4 * n;
358
359        let v_slot = |i: u32| v_base + i;
360        let dot_slot = |i: u32| dot_base + i;
361        let adj_slot = |i: u32| adj_base + i;
362        let adj_dot_slot = |i: u32| adj_dot_base + i;
363
364        let reachable = reachable_to_output(tape);
365        let var_indices = tape.variables();
366        // depends_on[k_idx][i] — does slot i depend on var_indices[k_idx]?
367        let depends_on: Vec<Vec<bool>> = (0..var_indices.len())
368            .map(|k_idx| depends_on_var(tape, var_indices[k_idx]))
369            .collect();
370
371        let mut consts: Vec<f64> = Vec::new();
372        let mut const_intern: HashMap<u64, u32> = HashMap::new();
373        let mut intern_const = |c: f64, consts: &mut Vec<f64>| -> u32 {
374            let bits = c.to_bits();
375            if let Some(&idx) = const_intern.get(&bits) {
376                return idx;
377            }
378            let idx = consts.len() as u32;
379            consts.push(c);
380            const_intern.insert(bits, idx);
381            idx
382        };
383
384        let mut ops: Vec<HOp> = Vec::new();
385
386        // ---- Forward pass ----
387        for (i, tape_op) in tape.ops.iter().enumerate() {
388            let i = i as u32;
389            let dst = v_slot(i);
390            let op = match *tape_op {
391                TapeOp::Const(c) => HOp::FwdLoadConst {
392                    dst,
393                    c_idx: intern_const(c, &mut consts),
394                },
395                TapeOp::Var(x_idx) => HOp::FwdLoadVar {
396                    dst,
397                    x_idx: x_idx as u32,
398                },
399                TapeOp::Add(a, b) => HOp::FwdAdd {
400                    dst,
401                    a: v_slot(a as u32),
402                    b: v_slot(b as u32),
403                },
404                TapeOp::Sub(a, b) => HOp::FwdSub {
405                    dst,
406                    a: v_slot(a as u32),
407                    b: v_slot(b as u32),
408                },
409                TapeOp::Mul(a, b) => HOp::FwdMul {
410                    dst,
411                    a: v_slot(a as u32),
412                    b: v_slot(b as u32),
413                },
414                TapeOp::Div(a, b) => HOp::FwdDiv {
415                    dst,
416                    a: v_slot(a as u32),
417                    b: v_slot(b as u32),
418                },
419                TapeOp::Pow(a, b) => HOp::FwdPow {
420                    dst,
421                    a: v_slot(a as u32),
422                    b: v_slot(b as u32),
423                },
424                TapeOp::Neg(a) => HOp::FwdNeg {
425                    dst,
426                    a: v_slot(a as u32),
427                },
428                TapeOp::Abs(a) => HOp::FwdAbs {
429                    dst,
430                    a: v_slot(a as u32),
431                },
432                TapeOp::Sqrt(a) => HOp::FwdSqrt {
433                    dst,
434                    a: v_slot(a as u32),
435                },
436                TapeOp::Exp(a) => HOp::FwdExp {
437                    dst,
438                    a: v_slot(a as u32),
439                },
440                TapeOp::Log(a) => HOp::FwdLog {
441                    dst,
442                    a: v_slot(a as u32),
443                },
444                TapeOp::Log10(a) => HOp::FwdLog10 {
445                    dst,
446                    a: v_slot(a as u32),
447                },
448                TapeOp::Sin(a) => HOp::FwdSin {
449                    dst,
450                    a: v_slot(a as u32),
451                },
452                TapeOp::Cos(a) => HOp::FwdCos {
453                    dst,
454                    a: v_slot(a as u32),
455                },
456                TapeOp::Funcall { .. } => panic!(
457                    "HessianProgram path does not support AMPL external functions; \
458                     use the Tape (build_with_externals) path instead."
459                ),
460            };
461            ops.push(op);
462        }
463
464        if n == 0 || var_indices.is_empty() {
465            return HessianProgram {
466                ops,
467                consts,
468                n_slots,
469            };
470        }
471
472        // ---- Per-j forward-tangent + reverse-over-tangent ----
473        for (k_idx, &j) in var_indices.iter().enumerate() {
474            // Reset dot, adj, adj_dot for this j. Seed adj[n-1] = 1.
475            ops.push(HOp::ZeroRange {
476                start: dot_base,
477                len: 3 * n,
478            });
479            ops.push(HOp::SetOne {
480                dst: adj_slot(n - 1),
481            });
482
483            // Forward tangent: only emit ops for slots that
484            // statically depend on j (the rest stay zero from the
485            // ZeroRange above).
486            for (i, tape_op) in tape.ops.iter().enumerate() {
487                let i_u = i as u32;
488                if !depends_on[k_idx][i] {
489                    continue;
490                }
491                let dst = dot_slot(i_u);
492                let dot_op = match *tape_op {
493                    // Const: dot stays 0 (filtered above by
494                    // depends_on, since Const has no var-deps).
495                    TapeOp::Const(_) => continue,
496                    // Var(k): dot = 1 iff k == j, else 0. We only
497                    // get here if depends_on[k_idx][i] is true,
498                    // which for Var(k) means k == j.
499                    TapeOp::Var(_) => HOp::SetOne { dst },
500                    TapeOp::Add(a, b) => HOp::DotAdd {
501                        dst,
502                        a: dot_slot(a as u32),
503                        b: dot_slot(b as u32),
504                    },
505                    TapeOp::Sub(a, b) => HOp::DotSub {
506                        dst,
507                        a: dot_slot(a as u32),
508                        b: dot_slot(b as u32),
509                    },
510                    TapeOp::Mul(a, b) => HOp::DotMul {
511                        dst,
512                        dot_a: dot_slot(a as u32),
513                        vb: v_slot(b as u32),
514                        va: v_slot(a as u32),
515                        dot_b: dot_slot(b as u32),
516                    },
517                    TapeOp::Div(a, b) => HOp::DotDiv {
518                        dst,
519                        dot_a: dot_slot(a as u32),
520                        vb: v_slot(b as u32),
521                        va: v_slot(a as u32),
522                        dot_b: dot_slot(b as u32),
523                    },
524                    TapeOp::Pow(a, b) => HOp::DotPow {
525                        dst,
526                        va: v_slot(a as u32),
527                        vb: v_slot(b as u32),
528                        vd: v_slot(i_u),
529                        dot_a: dot_slot(a as u32),
530                        dot_b: dot_slot(b as u32),
531                    },
532                    TapeOp::Neg(a) => HOp::DotNeg {
533                        dst,
534                        dot_a: dot_slot(a as u32),
535                    },
536                    TapeOp::Abs(a) => HOp::DotAbs {
537                        dst,
538                        dot_a: dot_slot(a as u32),
539                        va: v_slot(a as u32),
540                    },
541                    TapeOp::Sqrt(a) => HOp::DotSqrt {
542                        dst,
543                        dot_a: dot_slot(a as u32),
544                        vd: v_slot(i_u),
545                    },
546                    TapeOp::Exp(a) => HOp::DotExp {
547                        dst,
548                        dot_a: dot_slot(a as u32),
549                        vd: v_slot(i_u),
550                    },
551                    TapeOp::Log(a) => HOp::DotLog {
552                        dst,
553                        dot_a: dot_slot(a as u32),
554                        va: v_slot(a as u32),
555                    },
556                    TapeOp::Log10(a) => HOp::DotLog10 {
557                        dst,
558                        dot_a: dot_slot(a as u32),
559                        va: v_slot(a as u32),
560                    },
561                    TapeOp::Sin(a) => HOp::DotSin {
562                        dst,
563                        dot_a: dot_slot(a as u32),
564                        va: v_slot(a as u32),
565                    },
566                    TapeOp::Cos(a) => HOp::DotCos {
567                        dst,
568                        dot_a: dot_slot(a as u32),
569                        va: v_slot(a as u32),
570                    },
571                    TapeOp::Funcall { .. } => panic!(
572                        "HessianProgram path does not support AMPL external functions; \
573                         use the Tape (build_with_externals) path instead."
574                    ),
575                };
576                ops.push(dot_op);
577            }
578
579            // Reverse-over-tangent: walk slots backward, emit only
580            // for reachable slots.
581            for i in (0..n as usize).rev() {
582                if !reachable[i] {
583                    continue;
584                }
585                let i_u = i as u32;
586                let w = adj_slot(i_u);
587                let wd = adj_dot_slot(i_u);
588                let tape_op = &tape.ops[i];
589                let rev_op = match *tape_op {
590                    TapeOp::Const(_) => continue,
591                    TapeOp::Var(k) => {
592                        // At a Var slot: if k >= j and hess_map has
593                        // an entry for (k, j), emit a HessEmit op.
594                        // No adj/adj_dot propagation (no operands).
595                        if k >= j {
596                            if let Some(&ptr) = hess_map.get(&(k, j)) {
597                                ops.push(HOp::HessEmit {
598                                    hess_ptr: ptr as u32,
599                                    adj_dot_slot: wd,
600                                });
601                            }
602                        }
603                        continue;
604                    }
605                    TapeOp::Add(a, b) => HOp::RevAdd {
606                        adj_a: adj_slot(a as u32),
607                        adj_b: adj_slot(b as u32),
608                        adj_dot_a: adj_dot_slot(a as u32),
609                        adj_dot_b: adj_dot_slot(b as u32),
610                        w,
611                        wd,
612                    },
613                    TapeOp::Sub(a, b) => HOp::RevSub {
614                        adj_a: adj_slot(a as u32),
615                        adj_b: adj_slot(b as u32),
616                        adj_dot_a: adj_dot_slot(a as u32),
617                        adj_dot_b: adj_dot_slot(b as u32),
618                        w,
619                        wd,
620                    },
621                    TapeOp::Mul(a, b) => HOp::RevMul {
622                        adj_a: adj_slot(a as u32),
623                        adj_b: adj_slot(b as u32),
624                        adj_dot_a: adj_dot_slot(a as u32),
625                        adj_dot_b: adj_dot_slot(b as u32),
626                        w,
627                        wd,
628                        va: v_slot(a as u32),
629                        vb: v_slot(b as u32),
630                        dot_a: dot_slot(a as u32),
631                        dot_b: dot_slot(b as u32),
632                    },
633                    TapeOp::Div(a, b) => HOp::RevDiv {
634                        adj_a: adj_slot(a as u32),
635                        adj_b: adj_slot(b as u32),
636                        adj_dot_a: adj_dot_slot(a as u32),
637                        adj_dot_b: adj_dot_slot(b as u32),
638                        w,
639                        wd,
640                        va: v_slot(a as u32),
641                        vb: v_slot(b as u32),
642                        dot_a: dot_slot(a as u32),
643                        dot_b: dot_slot(b as u32),
644                    },
645                    TapeOp::Pow(a, b) => HOp::RevPow {
646                        adj_a: adj_slot(a as u32),
647                        adj_b: adj_slot(b as u32),
648                        adj_dot_a: adj_dot_slot(a as u32),
649                        adj_dot_b: adj_dot_slot(b as u32),
650                        w,
651                        wd,
652                        va: v_slot(a as u32),
653                        vb: v_slot(b as u32),
654                        vd: v_slot(i_u),
655                        dot_a: dot_slot(a as u32),
656                        dot_b: dot_slot(b as u32),
657                    },
658                    TapeOp::Neg(a) => HOp::RevNeg {
659                        adj_a: adj_slot(a as u32),
660                        adj_dot_a: adj_dot_slot(a as u32),
661                        w,
662                        wd,
663                    },
664                    TapeOp::Abs(a) => HOp::RevAbs {
665                        adj_a: adj_slot(a as u32),
666                        adj_dot_a: adj_dot_slot(a as u32),
667                        w,
668                        wd,
669                        va: v_slot(a as u32),
670                    },
671                    TapeOp::Sqrt(a) => HOp::RevSqrt {
672                        adj_a: adj_slot(a as u32),
673                        adj_dot_a: adj_dot_slot(a as u32),
674                        w,
675                        wd,
676                        va: v_slot(a as u32),
677                        vd: v_slot(i_u),
678                        dot_a: dot_slot(a as u32),
679                    },
680                    TapeOp::Exp(a) => HOp::RevExp {
681                        adj_a: adj_slot(a as u32),
682                        adj_dot_a: adj_dot_slot(a as u32),
683                        w,
684                        wd,
685                        vd: v_slot(i_u),
686                        dot_a: dot_slot(a as u32),
687                    },
688                    TapeOp::Log(a) => HOp::RevLog {
689                        adj_a: adj_slot(a as u32),
690                        adj_dot_a: adj_dot_slot(a as u32),
691                        w,
692                        wd,
693                        va: v_slot(a as u32),
694                        dot_a: dot_slot(a as u32),
695                    },
696                    TapeOp::Log10(a) => HOp::RevLog10 {
697                        adj_a: adj_slot(a as u32),
698                        adj_dot_a: adj_dot_slot(a as u32),
699                        w,
700                        wd,
701                        va: v_slot(a as u32),
702                        dot_a: dot_slot(a as u32),
703                    },
704                    TapeOp::Sin(a) => HOp::RevSin {
705                        adj_a: adj_slot(a as u32),
706                        adj_dot_a: adj_dot_slot(a as u32),
707                        w,
708                        wd,
709                        va: v_slot(a as u32),
710                        dot_a: dot_slot(a as u32),
711                    },
712                    TapeOp::Cos(a) => HOp::RevCos {
713                        adj_a: adj_slot(a as u32),
714                        adj_dot_a: adj_dot_slot(a as u32),
715                        w,
716                        wd,
717                        va: v_slot(a as u32),
718                        dot_a: dot_slot(a as u32),
719                    },
720                    TapeOp::Funcall { .. } => panic!(
721                        "HessianProgram path does not support AMPL external functions; \
722                         use the Tape (build_with_externals) path instead."
723                    ),
724                };
725                ops.push(rev_op);
726            }
727        }
728
729        HessianProgram {
730            ops,
731            consts,
732            n_slots,
733        }
734    }
735
736    pub fn n_slots(&self) -> usize {
737        self.n_slots as usize
738    }
739
740    pub fn n_ops(&self) -> usize {
741        self.ops.len()
742    }
743
744    /// Execute the program. `scratch` is overwritten throughout;
745    /// it must be at least [`n_slots`] long. `values` is the
746    /// shared Hessian-values buffer the caller is accumulating
747    /// into (same semantics as
748    /// [`Tape::hessian_accumulate`]'s `values`).
749    pub fn execute(&self, x: &[f64], weight: f64, scratch: &mut [f64], values: &mut [f64]) {
750        debug_assert!(scratch.len() >= self.n_slots as usize);
751        if self.ops.is_empty() || weight == 0.0 {
752            return;
753        }
754        let consts = &self.consts[..];
755        for &op in &self.ops {
756            match op {
757                HOp::FwdLoadVar { dst, x_idx } => {
758                    scratch[dst as usize] = x[x_idx as usize];
759                }
760                HOp::FwdLoadConst { dst, c_idx } => {
761                    scratch[dst as usize] = consts[c_idx as usize];
762                }
763                HOp::FwdAdd { dst, a, b } => {
764                    scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
765                }
766                HOp::FwdSub { dst, a, b } => {
767                    scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
768                }
769                HOp::FwdMul { dst, a, b } => {
770                    scratch[dst as usize] = scratch[a as usize] * scratch[b as usize];
771                }
772                HOp::FwdDiv { dst, a, b } => {
773                    scratch[dst as usize] = scratch[a as usize] / scratch[b as usize];
774                }
775                HOp::FwdPow { dst, a, b } => {
776                    scratch[dst as usize] = scratch[a as usize].powf(scratch[b as usize]);
777                }
778                HOp::FwdNeg { dst, a } => {
779                    scratch[dst as usize] = -scratch[a as usize];
780                }
781                HOp::FwdAbs { dst, a } => {
782                    scratch[dst as usize] = scratch[a as usize].abs();
783                }
784                HOp::FwdSqrt { dst, a } => {
785                    scratch[dst as usize] = scratch[a as usize].sqrt();
786                }
787                HOp::FwdExp { dst, a } => {
788                    scratch[dst as usize] = scratch[a as usize].exp();
789                }
790                HOp::FwdLog { dst, a } => {
791                    scratch[dst as usize] = scratch[a as usize].ln();
792                }
793                HOp::FwdLog10 { dst, a } => {
794                    scratch[dst as usize] = scratch[a as usize].log10();
795                }
796                HOp::FwdSin { dst, a } => {
797                    scratch[dst as usize] = scratch[a as usize].sin();
798                }
799                HOp::FwdCos { dst, a } => {
800                    scratch[dst as usize] = scratch[a as usize].cos();
801                }
802
803                HOp::SetZero { dst } => {
804                    scratch[dst as usize] = 0.0;
805                }
806                HOp::SetOne { dst } => {
807                    scratch[dst as usize] = 1.0;
808                }
809                HOp::ZeroRange { start, len } => {
810                    let s = start as usize;
811                    let e = s + len as usize;
812                    scratch[s..e].fill(0.0);
813                }
814
815                HOp::DotAdd { dst, a, b } => {
816                    scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
817                }
818                HOp::DotSub { dst, a, b } => {
819                    scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
820                }
821                HOp::DotMul {
822                    dst,
823                    dot_a,
824                    vb,
825                    va,
826                    dot_b,
827                } => {
828                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[vb as usize]
829                        + scratch[va as usize] * scratch[dot_b as usize];
830                }
831                HOp::DotDiv {
832                    dst,
833                    dot_a,
834                    vb,
835                    va,
836                    dot_b,
837                } => {
838                    let v_b = scratch[vb as usize];
839                    scratch[dst as usize] = (scratch[dot_a as usize] * v_b
840                        - scratch[va as usize] * scratch[dot_b as usize])
841                        / (v_b * v_b);
842                }
843                HOp::DotSqrt { dst, dot_a, vd } => {
844                    let svd = scratch[vd as usize];
845                    scratch[dst as usize] = if svd > 0.0 {
846                        scratch[dot_a as usize] * 0.5 / svd
847                    } else {
848                        0.0
849                    };
850                }
851                HOp::DotExp { dst, dot_a, vd } => {
852                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[vd as usize];
853                }
854                HOp::DotLog { dst, dot_a, va } => {
855                    scratch[dst as usize] = scratch[dot_a as usize] / scratch[va as usize];
856                }
857                HOp::DotLog10 { dst, dot_a, va } => {
858                    scratch[dst as usize] =
859                        scratch[dot_a as usize] / (scratch[va as usize] * std::f64::consts::LN_10);
860                }
861                HOp::DotSin { dst, dot_a, va } => {
862                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[va as usize].cos();
863                }
864                HOp::DotCos { dst, dot_a, va } => {
865                    scratch[dst as usize] = -scratch[dot_a as usize] * scratch[va as usize].sin();
866                }
867                HOp::DotNeg { dst, dot_a } => {
868                    scratch[dst as usize] = -scratch[dot_a as usize];
869                }
870                HOp::DotAbs { dst, dot_a, va } => {
871                    scratch[dst as usize] = if scratch[va as usize] >= 0.0 {
872                        scratch[dot_a as usize]
873                    } else {
874                        -scratch[dot_a as usize]
875                    };
876                }
877                HOp::DotPow {
878                    dst,
879                    va,
880                    vb,
881                    vd,
882                    dot_a,
883                    dot_b,
884                } => {
885                    let u = scratch[va as usize];
886                    let r = scratch[vb as usize];
887                    let du = scratch[dot_a as usize];
888                    let dr = scratch[dot_b as usize];
889                    let mut result = 0.0;
890                    if r != 0.0 && u != 0.0 {
891                        result += r * u.powf(r - 1.0) * du;
892                    }
893                    if u > 0.0 {
894                        result += scratch[vd as usize] * u.ln() * dr;
895                    }
896                    scratch[dst as usize] = result;
897                }
898
899                HOp::RevAdd {
900                    adj_a,
901                    adj_b,
902                    adj_dot_a,
903                    adj_dot_b,
904                    w,
905                    wd,
906                } => {
907                    let w_v = scratch[w as usize];
908                    let wd_v = scratch[wd as usize];
909                    scratch[adj_a as usize] += w_v;
910                    scratch[adj_b as usize] += w_v;
911                    scratch[adj_dot_a as usize] += wd_v;
912                    scratch[adj_dot_b as usize] += wd_v;
913                }
914                HOp::RevSub {
915                    adj_a,
916                    adj_b,
917                    adj_dot_a,
918                    adj_dot_b,
919                    w,
920                    wd,
921                } => {
922                    let w_v = scratch[w as usize];
923                    let wd_v = scratch[wd as usize];
924                    scratch[adj_a as usize] += w_v;
925                    scratch[adj_b as usize] -= w_v;
926                    scratch[adj_dot_a as usize] += wd_v;
927                    scratch[adj_dot_b as usize] -= wd_v;
928                }
929                HOp::RevMul {
930                    adj_a,
931                    adj_b,
932                    adj_dot_a,
933                    adj_dot_b,
934                    w,
935                    wd,
936                    va,
937                    vb,
938                    dot_a,
939                    dot_b,
940                } => {
941                    let w_v = scratch[w as usize];
942                    let wd_v = scratch[wd as usize];
943                    let va_v = scratch[va as usize];
944                    let vb_v = scratch[vb as usize];
945                    let da_v = scratch[dot_a as usize];
946                    let db_v = scratch[dot_b as usize];
947                    scratch[adj_a as usize] += w_v * vb_v;
948                    scratch[adj_b as usize] += w_v * va_v;
949                    scratch[adj_dot_a as usize] += wd_v * vb_v + w_v * db_v;
950                    scratch[adj_dot_b as usize] += wd_v * va_v + w_v * da_v;
951                }
952                HOp::RevDiv {
953                    adj_a,
954                    adj_b,
955                    adj_dot_a,
956                    adj_dot_b,
957                    w,
958                    wd,
959                    va,
960                    vb,
961                    dot_a,
962                    dot_b,
963                } => {
964                    let w_v = scratch[w as usize];
965                    let wd_v = scratch[wd as usize];
966                    let va_v = scratch[va as usize];
967                    let vb_v = scratch[vb as usize];
968                    let vb2 = vb_v * vb_v;
969                    let vb3 = vb2 * vb_v;
970                    let da_v = scratch[dot_a as usize];
971                    let db_v = scratch[dot_b as usize];
972                    scratch[adj_a as usize] += w_v / vb_v;
973                    scratch[adj_dot_a as usize] += wd_v / vb_v + w_v * (-db_v / vb2);
974                    scratch[adj_b as usize] += w_v * (-va_v / vb2);
975                    scratch[adj_dot_b as usize] +=
976                        wd_v * (-va_v / vb2) + w_v * (-da_v / vb2 + 2.0 * va_v * db_v / vb3);
977                }
978                HOp::RevPow {
979                    adj_a,
980                    adj_b,
981                    adj_dot_a,
982                    adj_dot_b,
983                    w,
984                    wd,
985                    va,
986                    vb,
987                    vd,
988                    dot_a,
989                    dot_b,
990                } => {
991                    let w_v = scratch[w as usize];
992                    let wd_v = scratch[wd as usize];
993                    let u = scratch[va as usize];
994                    let r = scratch[vb as usize];
995                    let du = scratch[dot_a as usize];
996                    let dr = scratch[dot_b as usize];
997                    if r != 0.0 {
998                        if u != 0.0 {
999                            let p_a = r * u.powf(r - 1.0);
1000                            scratch[adj_a as usize] += w_v * p_a;
1001                            let mut dp_a = dr * u.powf(r - 1.0);
1002                            if u > 0.0 {
1003                                dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
1004                            } else {
1005                                dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
1006                            }
1007                            scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1008                        } else if r >= 2.0 {
1009                            let p_a = 0.0;
1010                            scratch[adj_a as usize] += w_v * p_a;
1011                            let dp_a = if r == 2.0 {
1012                                2.0 * du
1013                            } else {
1014                                r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
1015                            };
1016                            scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1017                        }
1018                    }
1019                    if u > 0.0 {
1020                        let ln_u = u.ln();
1021                        let p_b = scratch[vd as usize] * ln_u;
1022                        scratch[adj_b as usize] += w_v * p_b;
1023                        let dur = scratch[vd as usize] * (r * du / u + dr * ln_u);
1024                        let dp_b = dur * ln_u + scratch[vd as usize] * du / u;
1025                        scratch[adj_dot_b as usize] += wd_v * p_b + w_v * dp_b;
1026                    }
1027                }
1028                HOp::RevNeg {
1029                    adj_a,
1030                    adj_dot_a,
1031                    w,
1032                    wd,
1033                } => {
1034                    scratch[adj_a as usize] -= scratch[w as usize];
1035                    scratch[adj_dot_a as usize] -= scratch[wd as usize];
1036                }
1037                HOp::RevAbs {
1038                    adj_a,
1039                    adj_dot_a,
1040                    w,
1041                    wd,
1042                    va,
1043                } => {
1044                    let s = if scratch[va as usize] >= 0.0 {
1045                        1.0
1046                    } else {
1047                        -1.0
1048                    };
1049                    scratch[adj_a as usize] += scratch[w as usize] * s;
1050                    scratch[adj_dot_a as usize] += scratch[wd as usize] * s;
1051                }
1052                HOp::RevSqrt {
1053                    adj_a,
1054                    adj_dot_a,
1055                    w,
1056                    wd,
1057                    va: _,
1058                    vd,
1059                    dot_a,
1060                } => {
1061                    let sv = scratch[vd as usize];
1062                    if sv > 0.0 {
1063                        let fp = 0.5 / sv;
1064                        let fpp = -0.25 / (sv * sv * sv);
1065                        let w_v = scratch[w as usize];
1066                        let wd_v = scratch[wd as usize];
1067                        scratch[adj_a as usize] += w_v * fp;
1068                        scratch[adj_dot_a as usize] +=
1069                            wd_v * fp + w_v * fpp * scratch[dot_a as usize];
1070                    }
1071                }
1072                HOp::RevExp {
1073                    adj_a,
1074                    adj_dot_a,
1075                    w,
1076                    wd,
1077                    vd,
1078                    dot_a,
1079                } => {
1080                    let ev = scratch[vd as usize];
1081                    let w_v = scratch[w as usize];
1082                    let wd_v = scratch[wd as usize];
1083                    scratch[adj_a as usize] += w_v * ev;
1084                    scratch[adj_dot_a as usize] += wd_v * ev + w_v * ev * scratch[dot_a as usize];
1085                }
1086                HOp::RevLog {
1087                    adj_a,
1088                    adj_dot_a,
1089                    w,
1090                    wd,
1091                    va,
1092                    dot_a,
1093                } => {
1094                    let u = scratch[va as usize];
1095                    let w_v = scratch[w as usize];
1096                    let wd_v = scratch[wd as usize];
1097                    scratch[adj_a as usize] += w_v / u;
1098                    scratch[adj_dot_a as usize] +=
1099                        wd_v / u + w_v * (-1.0 / (u * u)) * scratch[dot_a as usize];
1100                }
1101                HOp::RevLog10 {
1102                    adj_a,
1103                    adj_dot_a,
1104                    w,
1105                    wd,
1106                    va,
1107                    dot_a,
1108                } => {
1109                    let u = scratch[va as usize];
1110                    let c = std::f64::consts::LN_10;
1111                    let w_v = scratch[w as usize];
1112                    let wd_v = scratch[wd as usize];
1113                    scratch[adj_a as usize] += w_v / (u * c);
1114                    scratch[adj_dot_a as usize] +=
1115                        wd_v / (u * c) + w_v * (-1.0 / (u * u * c)) * scratch[dot_a as usize];
1116                }
1117                HOp::RevSin {
1118                    adj_a,
1119                    adj_dot_a,
1120                    w,
1121                    wd,
1122                    va,
1123                    dot_a,
1124                } => {
1125                    let u = scratch[va as usize];
1126                    let cu = u.cos();
1127                    let w_v = scratch[w as usize];
1128                    let wd_v = scratch[wd as usize];
1129                    scratch[adj_a as usize] += w_v * cu;
1130                    scratch[adj_dot_a as usize] +=
1131                        wd_v * cu + w_v * (-u.sin()) * scratch[dot_a as usize];
1132                }
1133                HOp::RevCos {
1134                    adj_a,
1135                    adj_dot_a,
1136                    w,
1137                    wd,
1138                    va,
1139                    dot_a,
1140                } => {
1141                    let u = scratch[va as usize];
1142                    let su = u.sin();
1143                    let w_v = scratch[w as usize];
1144                    let wd_v = scratch[wd as usize];
1145                    scratch[adj_a as usize] -= w_v * su;
1146                    scratch[adj_dot_a as usize] +=
1147                        wd_v * (-su) + w_v * (-u.cos()) * scratch[dot_a as usize];
1148                }
1149
1150                HOp::HessEmit {
1151                    hess_ptr,
1152                    adj_dot_slot,
1153                } => {
1154                    values[hess_ptr as usize] += weight * scratch[adj_dot_slot as usize];
1155                }
1156            }
1157        }
1158    }
1159}
1160
1161/// `out[i]` = does tape slot `i` contribute (transitively) to the
1162/// output slot `n-1`. Used to skip emitting reverse-pass ops for
1163/// dead slots.
1164fn reachable_to_output(tape: &Tape) -> Vec<bool> {
1165    let n = tape.ops.len();
1166    let mut r = vec![false; n];
1167    if n == 0 {
1168        return r;
1169    }
1170    r[n - 1] = true;
1171    for i in (0..n).rev() {
1172        if !r[i] {
1173            continue;
1174        }
1175        match tape.ops[i] {
1176            TapeOp::Const(_) | TapeOp::Var(_) => {}
1177            TapeOp::Add(a, b)
1178            | TapeOp::Sub(a, b)
1179            | TapeOp::Mul(a, b)
1180            | TapeOp::Div(a, b)
1181            | TapeOp::Pow(a, b) => {
1182                r[a] = true;
1183                r[b] = true;
1184            }
1185            TapeOp::Neg(a)
1186            | TapeOp::Abs(a)
1187            | TapeOp::Sqrt(a)
1188            | TapeOp::Exp(a)
1189            | TapeOp::Log(a)
1190            | TapeOp::Log10(a)
1191            | TapeOp::Sin(a)
1192            | TapeOp::Cos(a) => {
1193                r[a] = true;
1194            }
1195            TapeOp::Funcall { .. } => panic!(
1196                "HessianProgram path does not support AMPL external functions; \
1197                 use the Tape (build_with_externals) path instead."
1198            ),
1199        }
1200    }
1201    r
1202}
1203
1204/// `out[i]` = does tape slot `i` transitively read from `Var(j)`.
1205/// Used to prune forward-tangent ops (slots with `out[i] = false`
1206/// have `dot[i] = 0` and the rest of the per-`j` pass can skip
1207/// them).
1208fn depends_on_var(tape: &Tape, j: usize) -> Vec<bool> {
1209    let n = tape.ops.len();
1210    let mut d = vec![false; n];
1211    for (i, op) in tape.ops.iter().enumerate() {
1212        d[i] = match *op {
1213            TapeOp::Const(_) => false,
1214            TapeOp::Var(k) => k == j,
1215            TapeOp::Add(a, b)
1216            | TapeOp::Sub(a, b)
1217            | TapeOp::Mul(a, b)
1218            | TapeOp::Div(a, b)
1219            | TapeOp::Pow(a, b) => d[a] || d[b],
1220            TapeOp::Neg(a)
1221            | TapeOp::Abs(a)
1222            | TapeOp::Sqrt(a)
1223            | TapeOp::Exp(a)
1224            | TapeOp::Log(a)
1225            | TapeOp::Log10(a)
1226            | TapeOp::Sin(a)
1227            | TapeOp::Cos(a) => d[a],
1228            TapeOp::Funcall { .. } => panic!(
1229                "HessianProgram path does not support AMPL external functions; \
1230                 use the Tape (build_with_externals) path instead."
1231            ),
1232        };
1233    }
1234    d
1235}
1236
/// Cross-checks the compiled `HessianProgram` against the reference
/// forward-over-reverse implementation `Tape::hessian_accumulate`.
/// Also unit-tests the two pruning analyses (`reachable_to_output`,
/// `depends_on_var`) that decide which ops the compiler emits.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nl_reader::{BinOp, Expr, UnaryOp};
    use std::collections::BTreeSet;
    use std::rc::Rc;

    // Small constructors so test expressions read like the math.

    fn cnst(c: f64) -> Expr {
        Expr::Const(c)
    }
    fn var(i: usize) -> Expr {
        Expr::Var(i)
    }
    fn add(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
    }
    fn mul(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
    }
    fn pow(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
    }
    fn div(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
    }
    fn sub(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Sub, Box::new(a), Box::new(b))
    }
    fn unary(op: UnaryOp, a: Expr) -> Expr {
        Expr::Unary(op, Box::new(a))
    }

    /// Build the shared (row, col) -> position map that both AD paths
    /// scatter into. It holds lower-triangle pairs (`row >= col`) over
    /// `tape.variables()`, numbered in first-seen order. Returns the
    /// map together with the pairs indexed by position.
    fn build_hess_map(tape: &Tape) -> (HashMap<(usize, usize), usize>, Vec<(usize, usize)>) {
        let vars = tape.variables();
        let mut pairs: Vec<(usize, usize)> = Vec::new();
        let mut map: HashMap<(usize, usize), usize> = HashMap::new();
        for (ai, &vi) in vars.iter().enumerate() {
            for &vj in &vars[..=ai] {
                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
                map.entry((r, c)).or_insert_with(|| {
                    let p = pairs.len();
                    pairs.push((r, c));
                    p
                });
            }
        }
        (map, pairs)
    }

    /// Evaluate the weighted Hessian of `tape` at `x` with both
    /// `Tape::hessian_accumulate` and the compiled `HessianProgram`.
    /// Asserts that every entry agrees to a mixed tolerance of `1e-12`:
    /// relative for entries with magnitude >= 1, absolute below that.
    fn assert_program_matches_tape(tape: &Tape, x: &[f64], weight: f64) {
        let (hess_map, pairs) = build_hess_map(tape);
        let nnz = pairs.len();

        let mut tape_vals = vec![0.0; nnz];
        tape.hessian_accumulate(x, weight, &hess_map, &mut tape_vals);

        let program = HessianProgram::compile(tape, &hess_map);
        let mut scratch = vec![0.0; program.n_slots()];
        let mut prog_vals = vec![0.0; nnz];
        program.execute(x, weight, &mut scratch, &mut prog_vals);

        for ((&(r, c), &t), &p) in pairs.iter().zip(&tape_vals).zip(&prog_vals) {
            let tol = t.abs().max(1.0) * 1e-12;
            assert!(
                (t - p).abs() < tol,
                "H[{},{}]: tape={:.6e} prog={:.6e} (tol={:.1e})",
                r,
                c,
                t,
                p,
                tol
            );
        }
    }

    #[test]
    fn matches_quadratic() {
        let e = add(
            add(
                mul(cnst(3.0), pow(var(0), cnst(2.0))),
                mul(cnst(2.0), mul(var(0), var(1))),
            ),
            pow(var(1), cnst(2.0)),
        );
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[2.0, 3.0], 1.0);
        assert_program_matches_tape(&tape, &[-1.5, 0.7], 2.5);
    }

    #[test]
    fn matches_transcendental() {
        let e = Expr::Sum(vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
            unary(UnaryOp::Cos, add(var(0), var(1))),
        ]);
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[1.5, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[0.3, 4.1], -0.4);
    }

    #[test]
    fn matches_division() {
        let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[0.5, 1.2], 1.0);
    }

    #[test]
    fn matches_through_cse() {
        let body = Rc::new(add(var(0), var(1)));
        let e = add(
            pow(Expr::Cse(body.clone()), cnst(2.0)),
            Expr::Cse(body.clone()),
        );
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[1.0, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[-0.5, 3.3], 0.7);
    }

    #[test]
    fn matches_pow_chain() {
        // After Tier 1 this lowers to a Mul chain; verify both
        // paths agree on the lowered form too.
        let e = add(pow(var(0), cnst(3.0)), pow(var(1), cnst(-2.0)));
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[1.7, 0.8], 1.0);
    }

    #[test]
    fn matches_residual_pow_with_var_exponent() {
        // Pow where the exponent is variable (not constant), so
        // it survives Tier 1 and exercises the RevPow / DotPow
        // compound branches.
        let e = pow(var(0), var(1));
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[2.5, 1.4], 1.0);
        assert_program_matches_tape(&tape, &[0.6, 2.1], -1.0);
    }

    #[test]
    fn matches_sub_neg_abs() {
        let e = sub(
            unary(UnaryOp::Neg, var(0)),
            unary(UnaryOp::Abs, sub(var(1), var(0))),
        );
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[1.0, -2.0], 1.0);
        assert_program_matches_tape(&tape, &[-3.5, 4.0], 0.9);
    }

    #[test]
    fn slots_layout_matches_design() {
        let e = mul(var(0), var(1));
        let tape = Tape::build(&e);
        let (hess_map, _) = build_hess_map(&tape);
        let prog = HessianProgram::compile(&tape, &hess_map);
        assert_eq!(prog.n_slots(), 4 * tape.ops.len());
    }

    /// Sanity: the program agrees with the tape on a case whose
    /// structural sparsity (via `hessian_sparsity()`) is known exactly.
    #[test]
    fn dependence_matches_hessian_sparsity_for_simple_case() {
        let e = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
        let tape = Tape::build(&e);
        let s: BTreeSet<(usize, usize)> = tape.hessian_sparsity();
        // (0,0) from sin, (2,1) from x1*x2. (1,1)/(2,2) are absent
        // because Mul(x1, x2) contributes only the cross term.
        assert!(s.contains(&(0, 0)));
        assert!(s.contains(&(2, 1)));
        assert!(!s.contains(&(1, 1)));
        assert!(!s.contains(&(2, 2)));
        assert_program_matches_tape(&tape, &[0.7, 1.1, 2.2], 1.0);
    }

    /// Direct checks of the pruning analyses on the same expression.
    /// The output slot is live, and every variable slot is live. A
    /// `Var(k)` slot depends on `x_j` exactly when `k == j`. The output
    /// depends on every variable that appears, and no slot depends on
    /// an absent variable.
    #[test]
    fn pruning_analyses_on_simple_case() {
        let e = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
        let tape = Tape::build(&e);
        let n = tape.ops.len();
        assert!(n > 0, "tape for a non-trivial expression must be non-empty");
        let out = n - 1;

        let live = reachable_to_output(&tape);
        assert_eq!(live.len(), n);
        assert!(live[out]);

        for j in 0..3 {
            let dep = depends_on_var(&tape, j);
            assert_eq!(dep.len(), n);
            assert!(dep[out], "output should depend on x{}", j);
            for (i, op) in tape.ops.iter().enumerate() {
                if let TapeOp::Var(k) = *op {
                    assert_eq!(dep[i], k == j, "slot {} = Var({}), j = {}", i, k, j);
                    assert!(live[i], "Var({}) at slot {} should be live", k, i);
                }
            }
        }

        assert!(depends_on_var(&tape, 7).iter().all(|&d| !d));
    }
}