Skip to main content

pounce_cli/
nl_hessian_program.rs

1//! Precompiled symbolic-Hessian program for one `Tape`.
2//!
3//! `Tape::hessian_accumulate` runs forward-over-reverse AD at every
4//! call: for each tape variable `j` it (a) match-dispatches every op
5//! in a forward-tangent sweep, (b) zeros adj/adj_dot, (c)
6//! match-dispatches every op again in the reverse-over-tangent
7//! sweep, and (d) HashMap-looks-up every `Var(k)` slot to find its
8//! Hessian output position. On evaluator-bound problems (dirichlet,
9//! lane_emden, henon) that match-dispatch + symbolic-AD overhead is
10//! ~80% of total CPU.
11//!
12//! This module compiles all of that ONCE at tape-build time into a
13//! flat `Vec<HOp>` of pre-resolved primitive ops:
14//!
15//!   * Forward pass — one `Fwd*` op per `TapeOp`. Mirrors
16//!     `Tape::forward`.
17//!   * Per-`j` forward tangent — only the ops touching slots that
18//!     statically depend on `j` are emitted (the rest stay zero
19//!     from the per-`j` `ZeroRange` reset).
20//!   * Per-`j` reverse-over-tangent — only ops on slots reachable
21//!     backward from output, with all slot indices and Hessian
22//!     output pointers pre-resolved.
23//!
24//! ## Scratch layout
25//!
26//! The program reads/writes a single `&mut [f64]` arena of
27//! `n_slots` cells. We allocate four contiguous regions of length
28//! `n` (`n` = `tape.ops.len()`):
29//!
30//!   * `v[i]`        in slot `i`
31//!   * `dot[i]`      in slot `n + i`
32//!   * `adj[i]`      in slot `2n + i`
33//!   * `adj_dot[i]`  in slot `3n + i`
34//!
//! Per-`j` setup zeros the `[n, 4n)` range (`dot`, `adj`, `adj_dot`) and seeds `adj[n-1] = 1`.
36//! Allocation pattern is intentionally trivial — finer-grained
37//! slot recycling buys little once the dispatch loop is the
38//! bottleneck, and a contiguous layout makes the per-`j`
39//! `ZeroRange` reset a single `memset`-friendly loop.
40
41use std::collections::HashMap;
42
43use super::nl_tape::{Tape, TapeOp};
44
/// One pre-resolved instruction of a compiled [`HessianProgram`].
///
/// Unless a variant's doc says otherwise, every `u32` field is an index into
/// the caller's scratch arena. `HessianProgram::compile` lays that arena out
/// as four regions of `n = tape.ops.len()` cells: values `v` in `[0, n)`,
/// forward tangents `dot` in `[n, 2n)`, adjoints `adj` in `[2n, 3n)` and
/// adjoint tangents `adj_dot` in `[3n, 4n)`. Operand field names carry the
/// region as a prefix (`va` = value slot, `dot_a` = tangent slot, `adj_a` =
/// adjoint slot, `adj_dot_a` = adjoint-tangent slot).
#[derive(Debug, Clone, Copy)]
pub enum HOp {
    // -------------------------------------------------------------------
    // Forward value sweep: emitted once, one op per tape op.
    // -------------------------------------------------------------------
    /// `v[dst] = x[x_idx]`; `x_idx` indexes the input point, not scratch.
    FwdLoadVar {
        dst: u32,
        x_idx: u32,
    },
    /// `v[dst] = consts[c_idx]`; `c_idx` indexes the program's constant pool.
    FwdLoadConst {
        dst: u32,
        c_idx: u32,
    },
    /// `v[dst] = v[a] + v[b]`
    FwdAdd {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `v[dst] = v[a] - v[b]`
    FwdSub {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `v[dst] = v[a] * v[b]`
    FwdMul {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `v[dst] = v[a] / v[b]`
    FwdDiv {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `v[dst] = v[a].powf(v[b])`
    FwdPow {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `v[dst] = -v[a]`
    FwdNeg {
        dst: u32,
        a: u32,
    },
    /// `v[dst] = |v[a]|`
    FwdAbs {
        dst: u32,
        a: u32,
    },
    /// `v[dst] = sqrt(v[a])`
    FwdSqrt {
        dst: u32,
        a: u32,
    },
    /// `v[dst] = exp(v[a])`
    FwdExp {
        dst: u32,
        a: u32,
    },
    /// `v[dst] = ln(v[a])`
    FwdLog {
        dst: u32,
        a: u32,
    },
    /// `v[dst] = log10(v[a])`
    FwdLog10 {
        dst: u32,
        a: u32,
    },
    /// `v[dst] = sin(v[a])`
    FwdSin {
        dst: u32,
        a: u32,
    },
    /// `v[dst] = cos(v[a])`
    FwdCos {
        dst: u32,
        a: u32,
    },

    // -------------------------------------------------------------------
    // Single-cell initialisation.
    // -------------------------------------------------------------------
    /// `scratch[dst] = 0.0`
    SetZero {
        dst: u32,
    },
    /// `scratch[dst] = 1.0`. `compile` uses it both to seed the output
    /// adjoint `adj[n-1]` and to set the tangent of a `Var(j)` slot.
    SetOne {
        dst: u32,
    },

    // -------------------------------------------------------------------
    // Bulk reset, emitted at the start of every per-variable section.
    // -------------------------------------------------------------------
    /// `scratch[start .. start + len] = 0.0`
    ZeroRange {
        start: u32,
        len: u32,
    },

    // -------------------------------------------------------------------
    // Forward-tangent sweep (per variable `j`). `dst` is a `dot` slot.
    // -------------------------------------------------------------------
    /// `dot[dst] = dot[a] + dot[b]`. Here `a` and `b` are already `dot` slots.
    DotAdd {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `dot[dst] = dot[a] - dot[b]`. Here `a` and `b` are already `dot` slots.
    DotSub {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// Product rule: `dot[dst] = dot[a] * v[b] + v[a] * dot[b]`.
    DotMul {
        dst: u32,
        dot_a: u32,
        vb: u32,
        va: u32,
        dot_b: u32,
    },
    /// Quotient rule: `dot[dst] = (dot[a] * v[b] - v[a] * dot[b]) / (v[b] * v[b])`.
    DotDiv {
        dst: u32,
        dot_a: u32,
        vb: u32,
        va: u32,
        dot_b: u32,
    },
    /// `dot[dst] = 0.5 * dot[a] / v[d]`, where `v[d] = sqrt(v[a])` is the
    /// result slot. Yields `0.0` unless `v[d] > 0`.
    DotSqrt {
        dst: u32,
        dot_a: u32,
        vd: u32,
    },
    /// `dot[dst] = dot[a] * v[d]`, reusing `v[d] = exp(v[a])`.
    DotExp {
        dst: u32,
        dot_a: u32,
        vd: u32,
    },
    /// `dot[dst] = dot[a] / v[a]`
    DotLog {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// `dot[dst] = dot[a] / (v[a] * ln 10)`
    DotLog10 {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// `dot[dst] = dot[a] * cos(v[a])`
    DotSin {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// `dot[dst] = -dot[a] * sin(v[a])`
    DotCos {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// Tangent of a negation; reads the operand tangent `dot_a`.
    DotNeg {
        dst: u32,
        dot_a: u32,
    },
    /// Tangent of an absolute value; reads the operand tangent `dot_a` and
    /// the operand value `va`.
    DotAbs {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// Tangent of `Pow(a, b)`; `vd` is the result value. Evaluation carries
    /// the runtime `u != 0` / `u > 0` branches.
    DotPow {
        dst: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },

    // -------------------------------------------------------------------
    // Reverse-over-tangent sweep (per variable `j`). Each op reads the
    // consumer slot's adjoint `w = adj[i]` and adjoint tangent
    // `wd = adj_dot[i]`, then `+=`-accumulates into the `adj` / `adj_dot`
    // slots of its operand(s). Value/tangent fields supply the local partials.
    // -------------------------------------------------------------------
    /// Reverse step for `Add(a, b)`.
    RevAdd {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse step for `Sub(a, b)`.
    RevSub {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse step for `Mul(a, b)`.
    RevMul {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse step for `Div(a, b)`.
    RevDiv {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse step for `Pow(a, b)`; `vd` is the result value.
    RevPow {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse step for `Neg(a)`.
    RevNeg {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse step for `Abs(a)`.
    RevAbs {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
    },
    /// Reverse step for `Sqrt(a)`; `vd` is the result value.
    RevSqrt {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse step for `Exp(a)`; `vd` is the result value.
    RevExp {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse step for `Log(a)`.
    RevLog {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse step for `Log10(a)`.
    RevLog10 {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse step for `Sin(a)`.
    RevSin {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse step for `Cos(a)`.
    RevCos {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },

    // -------------------------------------------------------------------
    // Output.
    // -------------------------------------------------------------------
    /// `values[hess_ptr] += weight * scratch[adj_dot_slot]`. `hess_ptr`
    /// indexes the caller's Hessian `values` buffer, not scratch.
    HessEmit {
        hess_ptr: u32,
        adj_dot_slot: u32,
    },
}
337
338/// Precompiled Hessian-of-one-tape program. Built once via
339/// [`HessianProgram::compile`]; executed many times.
340#[derive(Debug, Clone)]
341pub struct HessianProgram {
342    ops: Vec<HOp>,
343    consts: Vec<f64>,
344    n_slots: u32,
345}
346
347impl HessianProgram {
348    /// Build the program. The `hess_map` is the same `(row, col)
349    /// -> values-index` map that [`Tape::hessian_accumulate`] uses;
350    /// the compiler inlines each lookup into a `HessEmit` op.
351    pub fn compile(tape: &Tape, hess_map: &HashMap<(usize, usize), usize>) -> Self {
352        let n = tape.ops.len() as u32;
353        let v_base = 0u32;
354        let dot_base = n;
355        let adj_base = 2 * n;
356        let adj_dot_base = 3 * n;
357        let n_slots = 4 * n;
358
359        let v_slot = |i: u32| v_base + i;
360        let dot_slot = |i: u32| dot_base + i;
361        let adj_slot = |i: u32| adj_base + i;
362        let adj_dot_slot = |i: u32| adj_dot_base + i;
363
364        let reachable = reachable_to_output(tape);
365        let var_indices = tape.variables();
366        // depends_on[k_idx][i] — does slot i depend on var_indices[k_idx]?
367        let depends_on: Vec<Vec<bool>> = (0..var_indices.len())
368            .map(|k_idx| depends_on_var(tape, var_indices[k_idx]))
369            .collect();
370
371        let mut consts: Vec<f64> = Vec::new();
372        let mut const_intern: HashMap<u64, u32> = HashMap::new();
373        let mut intern_const = |c: f64, consts: &mut Vec<f64>| -> u32 {
374            let bits = c.to_bits();
375            if let Some(&idx) = const_intern.get(&bits) {
376                return idx;
377            }
378            let idx = consts.len() as u32;
379            consts.push(c);
380            const_intern.insert(bits, idx);
381            idx
382        };
383
384        let mut ops: Vec<HOp> = Vec::new();
385
386        // ---- Forward pass ----
387        for (i, tape_op) in tape.ops.iter().enumerate() {
388            let i = i as u32;
389            let dst = v_slot(i);
390            let op = match *tape_op {
391                TapeOp::Const(c) => HOp::FwdLoadConst {
392                    dst,
393                    c_idx: intern_const(c, &mut consts),
394                },
395                TapeOp::Var(x_idx) => HOp::FwdLoadVar {
396                    dst,
397                    x_idx: x_idx as u32,
398                },
399                TapeOp::Add(a, b) => HOp::FwdAdd {
400                    dst,
401                    a: v_slot(a as u32),
402                    b: v_slot(b as u32),
403                },
404                TapeOp::Sub(a, b) => HOp::FwdSub {
405                    dst,
406                    a: v_slot(a as u32),
407                    b: v_slot(b as u32),
408                },
409                TapeOp::Mul(a, b) => HOp::FwdMul {
410                    dst,
411                    a: v_slot(a as u32),
412                    b: v_slot(b as u32),
413                },
414                TapeOp::Div(a, b) => HOp::FwdDiv {
415                    dst,
416                    a: v_slot(a as u32),
417                    b: v_slot(b as u32),
418                },
419                TapeOp::Pow(a, b) => HOp::FwdPow {
420                    dst,
421                    a: v_slot(a as u32),
422                    b: v_slot(b as u32),
423                },
424                TapeOp::Neg(a) => HOp::FwdNeg {
425                    dst,
426                    a: v_slot(a as u32),
427                },
428                TapeOp::Abs(a) => HOp::FwdAbs {
429                    dst,
430                    a: v_slot(a as u32),
431                },
432                TapeOp::Sqrt(a) => HOp::FwdSqrt {
433                    dst,
434                    a: v_slot(a as u32),
435                },
436                TapeOp::Exp(a) => HOp::FwdExp {
437                    dst,
438                    a: v_slot(a as u32),
439                },
440                TapeOp::Log(a) => HOp::FwdLog {
441                    dst,
442                    a: v_slot(a as u32),
443                },
444                TapeOp::Log10(a) => HOp::FwdLog10 {
445                    dst,
446                    a: v_slot(a as u32),
447                },
448                TapeOp::Sin(a) => HOp::FwdSin {
449                    dst,
450                    a: v_slot(a as u32),
451                },
452                TapeOp::Cos(a) => HOp::FwdCos {
453                    dst,
454                    a: v_slot(a as u32),
455                },
456                TapeOp::Funcall(_) => panic!(
457                    "HessianProgram path does not support AMPL external functions; \
458                     use the Tape (build_with_externals) path instead."
459                ),
460                TapeOp::Tan(_)
461                | TapeOp::Atan(_)
462                | TapeOp::Acos(_)
463                | TapeOp::Sinh(_)
464                | TapeOp::Cosh(_)
465                | TapeOp::Tanh(_)
466                | TapeOp::Asin(_)
467                | TapeOp::Acosh(_)
468                | TapeOp::Asinh(_)
469                | TapeOp::Atanh(_)
470                | TapeOp::Atan2(_, _)
471                | TapeOp::Cmp(_, _, _)
472                | TapeOp::And(_, _)
473                | TapeOp::Or(_, _)
474                | TapeOp::Not(_)
475                | TapeOp::Select(_, _, _)
476                | TapeOp::Min(_, _)
477                | TapeOp::Max(_, _) => panic!(
478                    "HessianProgram path does not yet support tan/atan/acos, the \
479                     other transcendental opcodes, atan2, min/max, or \
480                     conditional / logical opcodes; use the Tape \
481                     (build_with_externals) interpreter path instead."
482                ),
483            };
484            ops.push(op);
485        }
486
487        if n == 0 || var_indices.is_empty() {
488            return HessianProgram {
489                ops,
490                consts,
491                n_slots,
492            };
493        }
494
495        // ---- Per-j forward-tangent + reverse-over-tangent ----
496        for (k_idx, &j) in var_indices.iter().enumerate() {
497            // Reset dot, adj, adj_dot for this j. Seed adj[n-1] = 1.
498            ops.push(HOp::ZeroRange {
499                start: dot_base,
500                len: 3 * n,
501            });
502            ops.push(HOp::SetOne {
503                dst: adj_slot(n - 1),
504            });
505
506            // Forward tangent: only emit ops for slots that
507            // statically depend on j (the rest stay zero from the
508            // ZeroRange above).
509            for (i, tape_op) in tape.ops.iter().enumerate() {
510                let i_u = i as u32;
511                if !depends_on[k_idx][i] {
512                    continue;
513                }
514                let dst = dot_slot(i_u);
515                let dot_op = match *tape_op {
516                    // Const: dot stays 0 (filtered above by
517                    // depends_on, since Const has no var-deps).
518                    TapeOp::Const(_) => continue,
519                    // Var(k): dot = 1 iff k == j, else 0. We only
520                    // get here if depends_on[k_idx][i] is true,
521                    // which for Var(k) means k == j.
522                    TapeOp::Var(_) => HOp::SetOne { dst },
523                    TapeOp::Add(a, b) => HOp::DotAdd {
524                        dst,
525                        a: dot_slot(a as u32),
526                        b: dot_slot(b as u32),
527                    },
528                    TapeOp::Sub(a, b) => HOp::DotSub {
529                        dst,
530                        a: dot_slot(a as u32),
531                        b: dot_slot(b as u32),
532                    },
533                    TapeOp::Mul(a, b) => HOp::DotMul {
534                        dst,
535                        dot_a: dot_slot(a as u32),
536                        vb: v_slot(b as u32),
537                        va: v_slot(a as u32),
538                        dot_b: dot_slot(b as u32),
539                    },
540                    TapeOp::Div(a, b) => HOp::DotDiv {
541                        dst,
542                        dot_a: dot_slot(a as u32),
543                        vb: v_slot(b as u32),
544                        va: v_slot(a as u32),
545                        dot_b: dot_slot(b as u32),
546                    },
547                    TapeOp::Pow(a, b) => HOp::DotPow {
548                        dst,
549                        va: v_slot(a as u32),
550                        vb: v_slot(b as u32),
551                        vd: v_slot(i_u),
552                        dot_a: dot_slot(a as u32),
553                        dot_b: dot_slot(b as u32),
554                    },
555                    TapeOp::Neg(a) => HOp::DotNeg {
556                        dst,
557                        dot_a: dot_slot(a as u32),
558                    },
559                    TapeOp::Abs(a) => HOp::DotAbs {
560                        dst,
561                        dot_a: dot_slot(a as u32),
562                        va: v_slot(a as u32),
563                    },
564                    TapeOp::Sqrt(a) => HOp::DotSqrt {
565                        dst,
566                        dot_a: dot_slot(a as u32),
567                        vd: v_slot(i_u),
568                    },
569                    TapeOp::Exp(a) => HOp::DotExp {
570                        dst,
571                        dot_a: dot_slot(a as u32),
572                        vd: v_slot(i_u),
573                    },
574                    TapeOp::Log(a) => HOp::DotLog {
575                        dst,
576                        dot_a: dot_slot(a as u32),
577                        va: v_slot(a as u32),
578                    },
579                    TapeOp::Log10(a) => HOp::DotLog10 {
580                        dst,
581                        dot_a: dot_slot(a as u32),
582                        va: v_slot(a as u32),
583                    },
584                    TapeOp::Sin(a) => HOp::DotSin {
585                        dst,
586                        dot_a: dot_slot(a as u32),
587                        va: v_slot(a as u32),
588                    },
589                    TapeOp::Cos(a) => HOp::DotCos {
590                        dst,
591                        dot_a: dot_slot(a as u32),
592                        va: v_slot(a as u32),
593                    },
594                    TapeOp::Funcall(_) => panic!(
595                        "HessianProgram path does not support AMPL external functions; \
596                         use the Tape (build_with_externals) path instead."
597                    ),
598                    TapeOp::Tan(_)
599                    | TapeOp::Atan(_)
600                    | TapeOp::Acos(_)
601                    | TapeOp::Sinh(_)
602                    | TapeOp::Cosh(_)
603                    | TapeOp::Tanh(_)
604                    | TapeOp::Asin(_)
605                    | TapeOp::Acosh(_)
606                    | TapeOp::Asinh(_)
607                    | TapeOp::Atanh(_)
608                    | TapeOp::Atan2(_, _)
609                    | TapeOp::Cmp(_, _, _)
610                    | TapeOp::And(_, _)
611                    | TapeOp::Or(_, _)
612                    | TapeOp::Not(_)
613                    | TapeOp::Select(_, _, _)
614                    | TapeOp::Min(_, _)
615                    | TapeOp::Max(_, _) => panic!(
616                        "HessianProgram path does not yet support tan/atan/acos, the \
617                         other transcendental opcodes, atan2, min/max, or \
618                         conditional / logical opcodes; use the Tape \
619                         (build_with_externals) interpreter path instead."
620                    ),
621                };
622                ops.push(dot_op);
623            }
624
625            // Reverse-over-tangent: walk slots backward, emit only
626            // for reachable slots.
627            for i in (0..n as usize).rev() {
628                if !reachable[i] {
629                    continue;
630                }
631                let i_u = i as u32;
632                let w = adj_slot(i_u);
633                let wd = adj_dot_slot(i_u);
634                let tape_op = &tape.ops[i];
635                let rev_op = match *tape_op {
636                    TapeOp::Const(_) => continue,
637                    TapeOp::Var(k) => {
638                        // At a Var slot: if k >= j and hess_map has
639                        // an entry for (k, j), emit a HessEmit op.
640                        // No adj/adj_dot propagation (no operands).
641                        if k >= j {
642                            if let Some(&ptr) = hess_map.get(&(k, j)) {
643                                ops.push(HOp::HessEmit {
644                                    hess_ptr: ptr as u32,
645                                    adj_dot_slot: wd,
646                                });
647                            }
648                        }
649                        continue;
650                    }
651                    TapeOp::Add(a, b) => HOp::RevAdd {
652                        adj_a: adj_slot(a as u32),
653                        adj_b: adj_slot(b as u32),
654                        adj_dot_a: adj_dot_slot(a as u32),
655                        adj_dot_b: adj_dot_slot(b as u32),
656                        w,
657                        wd,
658                    },
659                    TapeOp::Sub(a, b) => HOp::RevSub {
660                        adj_a: adj_slot(a as u32),
661                        adj_b: adj_slot(b as u32),
662                        adj_dot_a: adj_dot_slot(a as u32),
663                        adj_dot_b: adj_dot_slot(b as u32),
664                        w,
665                        wd,
666                    },
667                    TapeOp::Mul(a, b) => HOp::RevMul {
668                        adj_a: adj_slot(a as u32),
669                        adj_b: adj_slot(b as u32),
670                        adj_dot_a: adj_dot_slot(a as u32),
671                        adj_dot_b: adj_dot_slot(b as u32),
672                        w,
673                        wd,
674                        va: v_slot(a as u32),
675                        vb: v_slot(b as u32),
676                        dot_a: dot_slot(a as u32),
677                        dot_b: dot_slot(b as u32),
678                    },
679                    TapeOp::Div(a, b) => HOp::RevDiv {
680                        adj_a: adj_slot(a as u32),
681                        adj_b: adj_slot(b as u32),
682                        adj_dot_a: adj_dot_slot(a as u32),
683                        adj_dot_b: adj_dot_slot(b as u32),
684                        w,
685                        wd,
686                        va: v_slot(a as u32),
687                        vb: v_slot(b as u32),
688                        dot_a: dot_slot(a as u32),
689                        dot_b: dot_slot(b as u32),
690                    },
691                    TapeOp::Pow(a, b) => HOp::RevPow {
692                        adj_a: adj_slot(a as u32),
693                        adj_b: adj_slot(b as u32),
694                        adj_dot_a: adj_dot_slot(a as u32),
695                        adj_dot_b: adj_dot_slot(b as u32),
696                        w,
697                        wd,
698                        va: v_slot(a as u32),
699                        vb: v_slot(b as u32),
700                        vd: v_slot(i_u),
701                        dot_a: dot_slot(a as u32),
702                        dot_b: dot_slot(b as u32),
703                    },
704                    TapeOp::Neg(a) => HOp::RevNeg {
705                        adj_a: adj_slot(a as u32),
706                        adj_dot_a: adj_dot_slot(a as u32),
707                        w,
708                        wd,
709                    },
710                    TapeOp::Abs(a) => HOp::RevAbs {
711                        adj_a: adj_slot(a as u32),
712                        adj_dot_a: adj_dot_slot(a as u32),
713                        w,
714                        wd,
715                        va: v_slot(a as u32),
716                    },
717                    TapeOp::Sqrt(a) => HOp::RevSqrt {
718                        adj_a: adj_slot(a as u32),
719                        adj_dot_a: adj_dot_slot(a as u32),
720                        w,
721                        wd,
722                        va: v_slot(a as u32),
723                        vd: v_slot(i_u),
724                        dot_a: dot_slot(a as u32),
725                    },
726                    TapeOp::Exp(a) => HOp::RevExp {
727                        adj_a: adj_slot(a as u32),
728                        adj_dot_a: adj_dot_slot(a as u32),
729                        w,
730                        wd,
731                        vd: v_slot(i_u),
732                        dot_a: dot_slot(a as u32),
733                    },
734                    TapeOp::Log(a) => HOp::RevLog {
735                        adj_a: adj_slot(a as u32),
736                        adj_dot_a: adj_dot_slot(a as u32),
737                        w,
738                        wd,
739                        va: v_slot(a as u32),
740                        dot_a: dot_slot(a as u32),
741                    },
742                    TapeOp::Log10(a) => HOp::RevLog10 {
743                        adj_a: adj_slot(a as u32),
744                        adj_dot_a: adj_dot_slot(a as u32),
745                        w,
746                        wd,
747                        va: v_slot(a as u32),
748                        dot_a: dot_slot(a as u32),
749                    },
750                    TapeOp::Sin(a) => HOp::RevSin {
751                        adj_a: adj_slot(a as u32),
752                        adj_dot_a: adj_dot_slot(a as u32),
753                        w,
754                        wd,
755                        va: v_slot(a as u32),
756                        dot_a: dot_slot(a as u32),
757                    },
758                    TapeOp::Cos(a) => HOp::RevCos {
759                        adj_a: adj_slot(a as u32),
760                        adj_dot_a: adj_dot_slot(a as u32),
761                        w,
762                        wd,
763                        va: v_slot(a as u32),
764                        dot_a: dot_slot(a as u32),
765                    },
766                    TapeOp::Funcall(_) => panic!(
767                        "HessianProgram path does not support AMPL external functions; \
768                         use the Tape (build_with_externals) path instead."
769                    ),
770                    TapeOp::Tan(_)
771                    | TapeOp::Atan(_)
772                    | TapeOp::Acos(_)
773                    | TapeOp::Sinh(_)
774                    | TapeOp::Cosh(_)
775                    | TapeOp::Tanh(_)
776                    | TapeOp::Asin(_)
777                    | TapeOp::Acosh(_)
778                    | TapeOp::Asinh(_)
779                    | TapeOp::Atanh(_)
780                    | TapeOp::Atan2(_, _)
781                    | TapeOp::Cmp(_, _, _)
782                    | TapeOp::And(_, _)
783                    | TapeOp::Or(_, _)
784                    | TapeOp::Not(_)
785                    | TapeOp::Select(_, _, _)
786                    | TapeOp::Min(_, _)
787                    | TapeOp::Max(_, _) => panic!(
788                        "HessianProgram path does not yet support tan/atan/acos, the \
789                         other transcendental opcodes, atan2, min/max, or \
790                         conditional / logical opcodes; use the Tape \
791                         (build_with_externals) interpreter path instead."
792                    ),
793                };
794                ops.push(rev_op);
795            }
796        }
797
798        HessianProgram {
799            ops,
800            consts,
801            n_slots,
802        }
803    }
804
805    pub fn n_slots(&self) -> usize {
806        self.n_slots as usize
807    }
808
809    pub fn n_ops(&self) -> usize {
810        self.ops.len()
811    }
812
813    /// Execute the program. `scratch` is overwritten throughout;
814    /// it must be at least [`n_slots`] long. `values` is the
815    /// shared Hessian-values buffer the caller is accumulating
816    /// into (same semantics as
817    /// [`Tape::hessian_accumulate`]'s `values`).
818    pub fn execute(&self, x: &[f64], weight: f64, scratch: &mut [f64], values: &mut [f64]) {
819        debug_assert!(scratch.len() >= self.n_slots as usize);
820        if self.ops.is_empty() || weight == 0.0 {
821            return;
822        }
823        let consts = &self.consts[..];
824        for &op in &self.ops {
825            match op {
826                HOp::FwdLoadVar { dst, x_idx } => {
827                    scratch[dst as usize] = x[x_idx as usize];
828                }
829                HOp::FwdLoadConst { dst, c_idx } => {
830                    scratch[dst as usize] = consts[c_idx as usize];
831                }
832                HOp::FwdAdd { dst, a, b } => {
833                    scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
834                }
835                HOp::FwdSub { dst, a, b } => {
836                    scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
837                }
838                HOp::FwdMul { dst, a, b } => {
839                    scratch[dst as usize] = scratch[a as usize] * scratch[b as usize];
840                }
841                HOp::FwdDiv { dst, a, b } => {
842                    scratch[dst as usize] = scratch[a as usize] / scratch[b as usize];
843                }
844                HOp::FwdPow { dst, a, b } => {
845                    scratch[dst as usize] = scratch[a as usize].powf(scratch[b as usize]);
846                }
847                HOp::FwdNeg { dst, a } => {
848                    scratch[dst as usize] = -scratch[a as usize];
849                }
850                HOp::FwdAbs { dst, a } => {
851                    scratch[dst as usize] = scratch[a as usize].abs();
852                }
853                HOp::FwdSqrt { dst, a } => {
854                    scratch[dst as usize] = scratch[a as usize].sqrt();
855                }
856                HOp::FwdExp { dst, a } => {
857                    scratch[dst as usize] = scratch[a as usize].exp();
858                }
859                HOp::FwdLog { dst, a } => {
860                    scratch[dst as usize] = scratch[a as usize].ln();
861                }
862                HOp::FwdLog10 { dst, a } => {
863                    scratch[dst as usize] = scratch[a as usize].log10();
864                }
865                HOp::FwdSin { dst, a } => {
866                    scratch[dst as usize] = scratch[a as usize].sin();
867                }
868                HOp::FwdCos { dst, a } => {
869                    scratch[dst as usize] = scratch[a as usize].cos();
870                }
871
872                HOp::SetZero { dst } => {
873                    scratch[dst as usize] = 0.0;
874                }
875                HOp::SetOne { dst } => {
876                    scratch[dst as usize] = 1.0;
877                }
878                HOp::ZeroRange { start, len } => {
879                    let s = start as usize;
880                    let e = s + len as usize;
881                    scratch[s..e].fill(0.0);
882                }
883
884                HOp::DotAdd { dst, a, b } => {
885                    scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
886                }
887                HOp::DotSub { dst, a, b } => {
888                    scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
889                }
890                HOp::DotMul {
891                    dst,
892                    dot_a,
893                    vb,
894                    va,
895                    dot_b,
896                } => {
897                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[vb as usize]
898                        + scratch[va as usize] * scratch[dot_b as usize];
899                }
900                HOp::DotDiv {
901                    dst,
902                    dot_a,
903                    vb,
904                    va,
905                    dot_b,
906                } => {
907                    let v_b = scratch[vb as usize];
908                    scratch[dst as usize] = (scratch[dot_a as usize] * v_b
909                        - scratch[va as usize] * scratch[dot_b as usize])
910                        / (v_b * v_b);
911                }
912                HOp::DotSqrt { dst, dot_a, vd } => {
913                    let svd = scratch[vd as usize];
914                    scratch[dst as usize] = if svd > 0.0 {
915                        scratch[dot_a as usize] * 0.5 / svd
916                    } else {
917                        0.0
918                    };
919                }
920                HOp::DotExp { dst, dot_a, vd } => {
921                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[vd as usize];
922                }
923                HOp::DotLog { dst, dot_a, va } => {
924                    scratch[dst as usize] = scratch[dot_a as usize] / scratch[va as usize];
925                }
926                HOp::DotLog10 { dst, dot_a, va } => {
927                    scratch[dst as usize] =
928                        scratch[dot_a as usize] / (scratch[va as usize] * std::f64::consts::LN_10);
929                }
930                HOp::DotSin { dst, dot_a, va } => {
931                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[va as usize].cos();
932                }
933                HOp::DotCos { dst, dot_a, va } => {
934                    scratch[dst as usize] = -scratch[dot_a as usize] * scratch[va as usize].sin();
935                }
936                HOp::DotNeg { dst, dot_a } => {
937                    scratch[dst as usize] = -scratch[dot_a as usize];
938                }
939                HOp::DotAbs { dst, dot_a, va } => {
940                    scratch[dst as usize] = if scratch[va as usize] >= 0.0 {
941                        scratch[dot_a as usize]
942                    } else {
943                        -scratch[dot_a as usize]
944                    };
945                }
946                HOp::DotPow {
947                    dst,
948                    va,
949                    vb,
950                    vd,
951                    dot_a,
952                    dot_b,
953                } => {
954                    let u = scratch[va as usize];
955                    let r = scratch[vb as usize];
956                    let du = scratch[dot_a as usize];
957                    let dr = scratch[dot_b as usize];
958                    let mut result = 0.0;
959                    if r != 0.0 && u != 0.0 {
960                        result += r * u.powf(r - 1.0) * du;
961                    }
962                    if u > 0.0 {
963                        result += scratch[vd as usize] * u.ln() * dr;
964                    }
965                    scratch[dst as usize] = result;
966                }
967
968                HOp::RevAdd {
969                    adj_a,
970                    adj_b,
971                    adj_dot_a,
972                    adj_dot_b,
973                    w,
974                    wd,
975                } => {
976                    let w_v = scratch[w as usize];
977                    let wd_v = scratch[wd as usize];
978                    scratch[adj_a as usize] += w_v;
979                    scratch[adj_b as usize] += w_v;
980                    scratch[adj_dot_a as usize] += wd_v;
981                    scratch[adj_dot_b as usize] += wd_v;
982                }
983                HOp::RevSub {
984                    adj_a,
985                    adj_b,
986                    adj_dot_a,
987                    adj_dot_b,
988                    w,
989                    wd,
990                } => {
991                    let w_v = scratch[w as usize];
992                    let wd_v = scratch[wd as usize];
993                    scratch[adj_a as usize] += w_v;
994                    scratch[adj_b as usize] -= w_v;
995                    scratch[adj_dot_a as usize] += wd_v;
996                    scratch[adj_dot_b as usize] -= wd_v;
997                }
998                HOp::RevMul {
999                    adj_a,
1000                    adj_b,
1001                    adj_dot_a,
1002                    adj_dot_b,
1003                    w,
1004                    wd,
1005                    va,
1006                    vb,
1007                    dot_a,
1008                    dot_b,
1009                } => {
1010                    let w_v = scratch[w as usize];
1011                    let wd_v = scratch[wd as usize];
1012                    let va_v = scratch[va as usize];
1013                    let vb_v = scratch[vb as usize];
1014                    let da_v = scratch[dot_a as usize];
1015                    let db_v = scratch[dot_b as usize];
1016                    scratch[adj_a as usize] += w_v * vb_v;
1017                    scratch[adj_b as usize] += w_v * va_v;
1018                    scratch[adj_dot_a as usize] += wd_v * vb_v + w_v * db_v;
1019                    scratch[adj_dot_b as usize] += wd_v * va_v + w_v * da_v;
1020                }
1021                HOp::RevDiv {
1022                    adj_a,
1023                    adj_b,
1024                    adj_dot_a,
1025                    adj_dot_b,
1026                    w,
1027                    wd,
1028                    va,
1029                    vb,
1030                    dot_a,
1031                    dot_b,
1032                } => {
1033                    let w_v = scratch[w as usize];
1034                    let wd_v = scratch[wd as usize];
1035                    let va_v = scratch[va as usize];
1036                    let vb_v = scratch[vb as usize];
1037                    let vb2 = vb_v * vb_v;
1038                    let vb3 = vb2 * vb_v;
1039                    let da_v = scratch[dot_a as usize];
1040                    let db_v = scratch[dot_b as usize];
1041                    scratch[adj_a as usize] += w_v / vb_v;
1042                    scratch[adj_dot_a as usize] += wd_v / vb_v + w_v * (-db_v / vb2);
1043                    scratch[adj_b as usize] += w_v * (-va_v / vb2);
1044                    scratch[adj_dot_b as usize] +=
1045                        wd_v * (-va_v / vb2) + w_v * (-da_v / vb2 + 2.0 * va_v * db_v / vb3);
1046                }
1047                HOp::RevPow {
1048                    adj_a,
1049                    adj_b,
1050                    adj_dot_a,
1051                    adj_dot_b,
1052                    w,
1053                    wd,
1054                    va,
1055                    vb,
1056                    vd,
1057                    dot_a,
1058                    dot_b,
1059                } => {
1060                    let w_v = scratch[w as usize];
1061                    let wd_v = scratch[wd as usize];
1062                    let u = scratch[va as usize];
1063                    let r = scratch[vb as usize];
1064                    let du = scratch[dot_a as usize];
1065                    let dr = scratch[dot_b as usize];
1066                    if r != 0.0 {
1067                        if u != 0.0 {
1068                            let p_a = r * u.powf(r - 1.0);
1069                            scratch[adj_a as usize] += w_v * p_a;
1070                            let mut dp_a = dr * u.powf(r - 1.0);
1071                            if u > 0.0 {
1072                                dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
1073                            } else {
1074                                dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
1075                            }
1076                            scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1077                        } else if r >= 2.0 {
1078                            let p_a = 0.0;
1079                            scratch[adj_a as usize] += w_v * p_a;
1080                            let dp_a = if r == 2.0 {
1081                                2.0 * du
1082                            } else {
1083                                r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
1084                            };
1085                            scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1086                        }
1087                    }
1088                    if u > 0.0 {
1089                        let ln_u = u.ln();
1090                        let p_b = scratch[vd as usize] * ln_u;
1091                        scratch[adj_b as usize] += w_v * p_b;
1092                        let dur = scratch[vd as usize] * (r * du / u + dr * ln_u);
1093                        let dp_b = dur * ln_u + scratch[vd as usize] * du / u;
1094                        scratch[adj_dot_b as usize] += wd_v * p_b + w_v * dp_b;
1095                    }
1096                }
1097                HOp::RevNeg {
1098                    adj_a,
1099                    adj_dot_a,
1100                    w,
1101                    wd,
1102                } => {
1103                    scratch[adj_a as usize] -= scratch[w as usize];
1104                    scratch[adj_dot_a as usize] -= scratch[wd as usize];
1105                }
1106                HOp::RevAbs {
1107                    adj_a,
1108                    adj_dot_a,
1109                    w,
1110                    wd,
1111                    va,
1112                } => {
1113                    let s = if scratch[va as usize] >= 0.0 {
1114                        1.0
1115                    } else {
1116                        -1.0
1117                    };
1118                    scratch[adj_a as usize] += scratch[w as usize] * s;
1119                    scratch[adj_dot_a as usize] += scratch[wd as usize] * s;
1120                }
1121                HOp::RevSqrt {
1122                    adj_a,
1123                    adj_dot_a,
1124                    w,
1125                    wd,
1126                    va: _,
1127                    vd,
1128                    dot_a,
1129                } => {
1130                    let sv = scratch[vd as usize];
1131                    if sv > 0.0 {
1132                        let fp = 0.5 / sv;
1133                        let fpp = -0.25 / (sv * sv * sv);
1134                        let w_v = scratch[w as usize];
1135                        let wd_v = scratch[wd as usize];
1136                        scratch[adj_a as usize] += w_v * fp;
1137                        scratch[adj_dot_a as usize] +=
1138                            wd_v * fp + w_v * fpp * scratch[dot_a as usize];
1139                    }
1140                }
1141                HOp::RevExp {
1142                    adj_a,
1143                    adj_dot_a,
1144                    w,
1145                    wd,
1146                    vd,
1147                    dot_a,
1148                } => {
1149                    let ev = scratch[vd as usize];
1150                    let w_v = scratch[w as usize];
1151                    let wd_v = scratch[wd as usize];
1152                    scratch[adj_a as usize] += w_v * ev;
1153                    scratch[adj_dot_a as usize] += wd_v * ev + w_v * ev * scratch[dot_a as usize];
1154                }
1155                HOp::RevLog {
1156                    adj_a,
1157                    adj_dot_a,
1158                    w,
1159                    wd,
1160                    va,
1161                    dot_a,
1162                } => {
1163                    let u = scratch[va as usize];
1164                    let w_v = scratch[w as usize];
1165                    let wd_v = scratch[wd as usize];
1166                    scratch[adj_a as usize] += w_v / u;
1167                    scratch[adj_dot_a as usize] +=
1168                        wd_v / u + w_v * (-1.0 / (u * u)) * scratch[dot_a as usize];
1169                }
1170                HOp::RevLog10 {
1171                    adj_a,
1172                    adj_dot_a,
1173                    w,
1174                    wd,
1175                    va,
1176                    dot_a,
1177                } => {
1178                    let u = scratch[va as usize];
1179                    let c = std::f64::consts::LN_10;
1180                    let w_v = scratch[w as usize];
1181                    let wd_v = scratch[wd as usize];
1182                    scratch[adj_a as usize] += w_v / (u * c);
1183                    scratch[adj_dot_a as usize] +=
1184                        wd_v / (u * c) + w_v * (-1.0 / (u * u * c)) * scratch[dot_a as usize];
1185                }
1186                HOp::RevSin {
1187                    adj_a,
1188                    adj_dot_a,
1189                    w,
1190                    wd,
1191                    va,
1192                    dot_a,
1193                } => {
1194                    let u = scratch[va as usize];
1195                    let cu = u.cos();
1196                    let w_v = scratch[w as usize];
1197                    let wd_v = scratch[wd as usize];
1198                    scratch[adj_a as usize] += w_v * cu;
1199                    scratch[adj_dot_a as usize] +=
1200                        wd_v * cu + w_v * (-u.sin()) * scratch[dot_a as usize];
1201                }
1202                HOp::RevCos {
1203                    adj_a,
1204                    adj_dot_a,
1205                    w,
1206                    wd,
1207                    va,
1208                    dot_a,
1209                } => {
1210                    let u = scratch[va as usize];
1211                    let su = u.sin();
1212                    let w_v = scratch[w as usize];
1213                    let wd_v = scratch[wd as usize];
1214                    scratch[adj_a as usize] -= w_v * su;
1215                    scratch[adj_dot_a as usize] +=
1216                        wd_v * (-su) + w_v * (-u.cos()) * scratch[dot_a as usize];
1217                }
1218
1219                HOp::HessEmit {
1220                    hess_ptr,
1221                    adj_dot_slot,
1222                } => {
1223                    values[hess_ptr as usize] += weight * scratch[adj_dot_slot as usize];
1224                }
1225            }
1226        }
1227    }
1228}
1229
1230/// `out[i]` = does tape slot `i` contribute (transitively) to the
1231/// output slot `n-1`. Used to skip emitting reverse-pass ops for
1232/// dead slots.
1233fn reachable_to_output(tape: &Tape) -> Vec<bool> {
1234    let n = tape.ops.len();
1235    let mut r = vec![false; n];
1236    if n == 0 {
1237        return r;
1238    }
1239    r[n - 1] = true;
1240    for i in (0..n).rev() {
1241        if !r[i] {
1242            continue;
1243        }
1244        match tape.ops[i] {
1245            TapeOp::Const(_) | TapeOp::Var(_) => {}
1246            TapeOp::Add(a, b)
1247            | TapeOp::Sub(a, b)
1248            | TapeOp::Mul(a, b)
1249            | TapeOp::Div(a, b)
1250            | TapeOp::Pow(a, b)
1251            | TapeOp::Atan2(a, b) => {
1252                r[a] = true;
1253                r[b] = true;
1254            }
1255            TapeOp::Neg(a)
1256            | TapeOp::Abs(a)
1257            | TapeOp::Sqrt(a)
1258            | TapeOp::Exp(a)
1259            | TapeOp::Log(a)
1260            | TapeOp::Log10(a)
1261            | TapeOp::Sin(a)
1262            | TapeOp::Cos(a)
1263            | TapeOp::Tan(a)
1264            | TapeOp::Atan(a)
1265            | TapeOp::Acos(a)
1266            | TapeOp::Sinh(a)
1267            | TapeOp::Cosh(a)
1268            | TapeOp::Tanh(a)
1269            | TapeOp::Asin(a)
1270            | TapeOp::Acosh(a)
1271            | TapeOp::Asinh(a)
1272            | TapeOp::Atanh(a) => {
1273                r[a] = true;
1274            }
1275            TapeOp::Funcall(_) => panic!(
1276                "HessianProgram path does not support AMPL external functions; \
1277                 use the Tape (build_with_externals) path instead."
1278            ),
1279            TapeOp::Cmp(_, _, _)
1280            | TapeOp::And(_, _)
1281            | TapeOp::Or(_, _)
1282            | TapeOp::Not(_)
1283            | TapeOp::Select(_, _, _)
1284            | TapeOp::Min(_, _)
1285            | TapeOp::Max(_, _) => panic!(
1286                "HessianProgram path does not support conditional / logical / min-max \
1287                 opcodes; use the Tape (build_with_externals) path instead."
1288            ),
1289        }
1290    }
1291    r
1292}
1293
1294/// `out[i]` = does tape slot `i` transitively read from `Var(j)`.
1295/// Used to prune forward-tangent ops (slots with `out[i] = false`
1296/// have `dot[i] = 0` and the rest of the per-`j` pass can skip
1297/// them).
1298fn depends_on_var(tape: &Tape, j: usize) -> Vec<bool> {
1299    let n = tape.ops.len();
1300    let mut d = vec![false; n];
1301    for (i, op) in tape.ops.iter().enumerate() {
1302        d[i] = match *op {
1303            TapeOp::Const(_) => false,
1304            TapeOp::Var(k) => k == j,
1305            TapeOp::Add(a, b)
1306            | TapeOp::Sub(a, b)
1307            | TapeOp::Mul(a, b)
1308            | TapeOp::Div(a, b)
1309            | TapeOp::Pow(a, b)
1310            | TapeOp::Atan2(a, b) => d[a] || d[b],
1311            TapeOp::Neg(a)
1312            | TapeOp::Abs(a)
1313            | TapeOp::Sqrt(a)
1314            | TapeOp::Exp(a)
1315            | TapeOp::Log(a)
1316            | TapeOp::Log10(a)
1317            | TapeOp::Sin(a)
1318            | TapeOp::Cos(a)
1319            | TapeOp::Tan(a)
1320            | TapeOp::Atan(a)
1321            | TapeOp::Acos(a)
1322            | TapeOp::Sinh(a)
1323            | TapeOp::Cosh(a)
1324            | TapeOp::Tanh(a)
1325            | TapeOp::Asin(a)
1326            | TapeOp::Acosh(a)
1327            | TapeOp::Asinh(a)
1328            | TapeOp::Atanh(a) => d[a],
1329            TapeOp::Funcall(_) => panic!(
1330                "HessianProgram path does not support AMPL external functions; \
1331                 use the Tape (build_with_externals) path instead."
1332            ),
1333            TapeOp::Cmp(_, _, _)
1334            | TapeOp::And(_, _)
1335            | TapeOp::Or(_, _)
1336            | TapeOp::Not(_)
1337            | TapeOp::Select(_, _, _)
1338            | TapeOp::Min(_, _)
1339            | TapeOp::Max(_, _) => panic!(
1340                "HessianProgram path does not support conditional / logical / min-max \
1341                 opcodes; use the Tape (build_with_externals) path instead."
1342            ),
1343        };
1344    }
1345    d
1346}
1347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nl_reader::{BinOp, Expr, UnaryOp};
    use std::collections::BTreeSet;
    use std::rc::Rc;

    fn cnst(c: f64) -> Expr {
        Expr::Const(c)
    }
    fn var(i: usize) -> Expr {
        Expr::Var(i)
    }
    fn add(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
    }
    fn mul(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
    }
    fn pow(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
    }
    fn div(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
    }
    fn sub(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Sub, Box::new(a), Box::new(b))
    }
    fn unary(op: UnaryOp, a: Expr) -> Expr {
        Expr::Unary(op, Box::new(a))
    }

    /// Build the shared `(row, col) -> pos` map that both AD paths
    /// scatter into. Every pair of variables from `tape.variables()`,
    /// including each variable with itself, is stored once in
    /// lower-triangle form `(max, min)`. Positions are assigned in
    /// first-seen order, and the pair list is returned alongside.
    fn build_hess_map(tape: &Tape) -> (HashMap<(usize, usize), usize>, Vec<(usize, usize)>) {
        let vars = tape.variables();
        let mut pairs: Vec<(usize, usize)> = Vec::new();
        let mut map: HashMap<(usize, usize), usize> = HashMap::new();
        for (ai, &vi) in vars.iter().enumerate() {
            for &vj in &vars[..=ai] {
                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
                map.entry((r, c)).or_insert_with(|| {
                    let p = pairs.len();
                    pairs.push((r, c));
                    p
                });
            }
        }
        (map, pairs)
    }

    /// Run both implementations against the same input and assert they
    /// agree. `Tape::hessian_accumulate` is the reference. Each entry
    /// must match within `1e-12 * max(|reference|, 1)`, which is a
    /// relative bound for large entries and an absolute one for small
    /// entries.
    fn assert_program_matches_tape(tape: &Tape, x: &[f64], weight: f64) {
        let (hess_map, pairs) = build_hess_map(tape);
        let nnz = pairs.len();

        let mut tape_vals = vec![0.0; nnz];
        tape.hessian_accumulate(x, weight, &hess_map, &mut tape_vals);

        let program = HessianProgram::compile(tape, &hess_map);
        let mut scratch = vec![0.0; program.n_slots()];
        let mut prog_vals = vec![0.0; nnz];
        program.execute(x, weight, &mut scratch, &mut prog_vals);

        for (k, &(r, c)) in pairs.iter().enumerate() {
            let tol = tape_vals[k].abs().max(1.0) * 1e-12;
            assert!(
                (tape_vals[k] - prog_vals[k]).abs() < tol,
                "H[{},{}]: tape={:.6e} prog={:.6e}",
                r,
                c,
                tape_vals[k],
                prog_vals[k]
            );
        }
    }

    #[test]
    fn matches_quadratic() {
        let e = add(
            add(
                mul(cnst(3.0), pow(var(0), cnst(2.0))),
                mul(cnst(2.0), mul(var(0), var(1))),
            ),
            pow(var(1), cnst(2.0)),
        );
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[2.0, 3.0], 1.0);
        assert_program_matches_tape(&tape, &[-1.5, 0.7], 2.5);
    }

    #[test]
    fn matches_transcendental() {
        let e = Expr::Sum(vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
            unary(UnaryOp::Cos, add(var(0), var(1))),
        ]);
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[1.5, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[0.3, 4.1], -0.4);
    }

    #[test]
    fn matches_division() {
        let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[0.5, 1.2], 1.0);
    }

    #[test]
    fn matches_through_cse() {
        let body = Rc::new(add(var(0), var(1)));
        let e = add(
            pow(Expr::Cse(body.clone()), cnst(2.0)),
            Expr::Cse(body.clone()),
        );
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[1.0, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[-0.5, 3.3], 0.7);
    }

    #[test]
    fn matches_pow_chain() {
        // After Tier 1 this lowers to a Mul chain; verify both
        // paths agree on the lowered form too.
        let e = add(pow(var(0), cnst(3.0)), pow(var(1), cnst(-2.0)));
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[1.7, 0.8], 1.0);
    }

    #[test]
    fn matches_residual_pow_with_var_exponent() {
        // Pow where the exponent is variable (not constant), so
        // it survives Tier 1 and exercises the RevPow / DotPow
        // compound branches.
        let e = pow(var(0), var(1));
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[2.5, 1.4], 1.0);
        assert_program_matches_tape(&tape, &[0.6, 2.1], -1.0);
    }

    #[test]
    fn matches_sub_neg_abs() {
        let e = sub(
            unary(UnaryOp::Neg, var(0)),
            unary(UnaryOp::Abs, sub(var(1), var(0))),
        );
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[1.0, -2.0], 1.0);
        assert_program_matches_tape(&tape, &[-3.5, 4.0], 0.9);
    }

    /// The scratch arena is four regions (v, dot, adj, adj_dot) of
    /// `tape.ops.len()` cells each.
    #[test]
    fn slots_layout_matches_design() {
        let e = mul(var(0), var(1));
        let tape = Tape::build(&e);
        let (hess_map, _) = build_hess_map(&tape);
        let prog = HessianProgram::compile(&tape, &hess_map);
        assert_eq!(prog.n_slots(), 4 * tape.ops.len());
    }

    /// Sanity checks for the pruning analyses on `sin(x0) + x1 * x2`.
    ///
    /// `hessian_sparsity()` must report the expected entries. From
    /// `reachable_to_output`, the output slot must be live. From
    /// `depends_on_var`, the output must depend on every variable,
    /// `Var(k)` slots must depend exactly on `k`, and a variable absent
    /// from the tape must taint nothing. Finally, both AD paths must
    /// agree.
    #[test]
    fn dependence_matches_hessian_sparsity_for_simple_case() {
        let e = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
        let tape = Tape::build(&e);
        let s: BTreeSet<(usize, usize)> = tape.hessian_sparsity();
        // (0,0) comes from sin and (2,1) from x1*x2. (1,1) and (2,2)
        // should be absent because x1*x2 has only a cross second
        // derivative; only the present entries are asserted here.
        assert!(s.contains(&(0, 0)));
        assert!(s.contains(&(2, 1)));

        // Exercise the pruning analyses directly.
        let n = tape.ops.len();
        let live = reachable_to_output(&tape);
        assert_eq!(live.len(), n);
        assert_eq!(live.last(), Some(&true), "output slot must be live");

        for j in 0..3 {
            let dep = depends_on_var(&tape, j);
            assert_eq!(dep.len(), n);
            assert_eq!(dep.last(), Some(&true), "output must depend on x{}", j);
            for (i, op) in tape.ops.iter().enumerate() {
                if let TapeOp::Var(k) = *op {
                    assert_eq!(dep[i], k == j, "Var({}) at slot {} vs j={}", k, i, j);
                }
            }
        }
        // Variable 7 never appears in the expression.
        assert!(depends_on_var(&tape, 7).iter().all(|&d| !d));

        assert_program_matches_tape(&tape, &[0.7, 1.1, 2.2], 1.0);
    }
}