Skip to main content

pounce_cli/
nl_hessian_program.rs

//! Precompiled symbolic-Hessian program for one `Tape`.
//!
//! `Tape::hessian_accumulate` runs forward-over-reverse AD at every
//! call: for each tape variable `j` it (a) match-dispatches every op
//! in a forward-tangent sweep, (b) zeros adj/adj_dot, (c)
//! match-dispatches every op again in the reverse-over-tangent
//! sweep, and (d) HashMap-looks-up every `Var(k)` slot to find its
//! Hessian output position. On evaluator-bound problems (dirichlet,
//! lane_emden, henon) that match-dispatch + symbolic-AD overhead is
//! ~80% of total CPU.
//!
//! This module compiles all of that ONCE at tape-build time into a
//! flat `Vec<HOp>` of pre-resolved primitive ops:
//!
//!   * Forward pass — one `Fwd*` op per `TapeOp`. Mirrors
//!     `Tape::forward`.
//!   * Per-`j` forward tangent — only the ops touching slots that
//!     statically depend on `j` are emitted (the rest stay zero
//!     from the per-`j` `ZeroRange` reset).
//!   * Per-`j` reverse-over-tangent — only ops on slots reachable
//!     backward from output, with all slot indices and Hessian
//!     output pointers pre-resolved.
//!
//! ## Scratch layout
//!
//! The program reads/writes a single `&mut [f64]` arena of
//! `n_slots` cells. We allocate four contiguous regions of length
//! `n` (`n` = `tape.ops.len()`):
//!
//!   * `v[i]`        in slot `i`
//!   * `dot[i]`      in slot `n + i`
//!   * `adj[i]`      in slot `2n + i`
//!   * `adj_dot[i]`  in slot `3n + i`
//!
//! Per-`j` setup zeros the `[n, 4n)` range and seeds `adj[n-1]`.
//! Allocation pattern is intentionally trivial — finer-grained
//! slot recycling buys little once the dispatch loop is the
//! bottleneck, and a contiguous layout makes the per-`j`
//! `ZeroRange` reset a single `memset`-friendly loop.
40
41use std::collections::HashMap;
42
43use super::nl_tape::{Tape, TapeOp};
44
/// A single pre-resolved instruction of the compiled Hessian program.
///
/// Every field is a `u32` index. Fields naming a slot (`dst`, `a`, `b`,
/// `va`, `dot_a`, `adj_a`, `w`, ...) are offsets into the scratch slice
/// handed to the executor; `x_idx`, `c_idx` and `hess_ptr` index the
/// input point, the constant pool and the Hessian values buffer
/// respectively.
#[derive(Debug, Clone, Copy)]
pub enum HOp {
    // ---------------------------------------------------------------
    // Forward (value) pass: scratch[dst] = f(scratch[a], scratch[b]).
    // ---------------------------------------------------------------
    /// `scratch[dst] = x[x_idx]`.
    FwdLoadVar { dst: u32, x_idx: u32 },
    /// `scratch[dst] = consts[c_idx]`.
    FwdLoadConst { dst: u32, c_idx: u32 },
    /// `scratch[dst] = scratch[a] + scratch[b]`.
    FwdAdd { dst: u32, a: u32, b: u32 },
    /// `scratch[dst] = scratch[a] - scratch[b]`.
    FwdSub { dst: u32, a: u32, b: u32 },
    /// `scratch[dst] = scratch[a] * scratch[b]`.
    FwdMul { dst: u32, a: u32, b: u32 },
    /// `scratch[dst] = scratch[a] / scratch[b]`.
    FwdDiv { dst: u32, a: u32, b: u32 },
    /// `scratch[dst] = scratch[a].powf(scratch[b])`.
    FwdPow { dst: u32, a: u32, b: u32 },
    /// `scratch[dst] = -scratch[a]`.
    FwdNeg { dst: u32, a: u32 },
    /// `scratch[dst] = scratch[a].abs()`.
    FwdAbs { dst: u32, a: u32 },
    /// `scratch[dst] = scratch[a].sqrt()`.
    FwdSqrt { dst: u32, a: u32 },
    /// `scratch[dst] = scratch[a].exp()`.
    FwdExp { dst: u32, a: u32 },
    /// `scratch[dst] = scratch[a].ln()`.
    FwdLog { dst: u32, a: u32 },
    /// `scratch[dst] = scratch[a].log10()`.
    FwdLog10 { dst: u32, a: u32 },
    /// `scratch[dst] = scratch[a].sin()`.
    FwdSin { dst: u32, a: u32 },
    /// `scratch[dst] = scratch[a].cos()`.
    FwdCos { dst: u32, a: u32 },

    // ---------------------------------------------------------------
    // Single-cell initialisation.
    // ---------------------------------------------------------------
    /// `scratch[dst] = 0.0`.
    SetZero { dst: u32 },
    /// `scratch[dst] = 1.0`.
    SetOne { dst: u32 },

    // ---------------------------------------------------------------
    // Bulk reset, emitted at the start of every per-`j` section.
    // ---------------------------------------------------------------
    /// Fills `scratch[start .. start + len]` with `0.0`.
    ZeroRange { start: u32, len: u32 },

    // ---------------------------------------------------------------
    // Forward tangent (per `j`). `dot_*` fields are tangent slots,
    // `va` / `vb` / `vd` are value slots of the operands / the result.
    // ---------------------------------------------------------------
    /// `dot[d] = dot[a] + dot[b]` (`a`, `b` are already tangent slots).
    DotAdd { dst: u32, a: u32, b: u32 },
    /// `dot[d] = dot[a] - dot[b]` (`a`, `b` are already tangent slots).
    DotSub { dst: u32, a: u32, b: u32 },
    /// `dot[d] = dot[a]*v[b] + v[a]*dot[b]`
    DotMul { dst: u32, dot_a: u32, vb: u32, va: u32, dot_b: u32 },
    /// `dot[d] = (dot[a]*v[b] - v[a]*dot[b]) / (v[b]*v[b])`
    DotDiv { dst: u32, dot_a: u32, vb: u32, va: u32, dot_b: u32 },
    /// `dot[d] = 0.5 / v[d] * dot[a]`, where `v[d] = sqrt(v[a])`.
    DotSqrt { dst: u32, dot_a: u32, vd: u32 },
    /// `dot[d] = v[d] * dot[a]`, where `v[d] = exp(v[a])`.
    DotExp { dst: u32, dot_a: u32, vd: u32 },
    /// Tangent of a `Log` node; `va` is the operand's value slot.
    DotLog { dst: u32, dot_a: u32, va: u32 },
    /// Tangent of a `Log10` node; `va` is the operand's value slot.
    DotLog10 { dst: u32, dot_a: u32, va: u32 },
    /// Tangent of a `Sin` node; `va` is the operand's value slot.
    DotSin { dst: u32, dot_a: u32, va: u32 },
    /// Tangent of a `Cos` node; `va` is the operand's value slot.
    DotCos { dst: u32, dot_a: u32, va: u32 },
    /// Tangent of a `Neg` node.
    DotNeg { dst: u32, dot_a: u32 },
    /// Tangent of an `Abs` node; `va` is the operand's value slot.
    DotAbs { dst: u32, dot_a: u32, va: u32 },
    /// Compound tangent for `Pow(a, b)`. Carries the runtime
    /// `u != 0` / `u > 0` branches.
    DotPow {
        dst: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },

    // ---------------------------------------------------------------
    // Reverse sweep with adj_dot update (per `j`).
    //
    // `w` is the adj slot and `wd` the adj_dot slot of the consuming
    // node; each op reads those and `+=`-accumulates into the adj /
    // adj_dot slots of its operand(s) (`adj_a`, `adj_b`, `adj_dot_a`,
    // `adj_dot_b`). `va` / `vb` / `vd` and `dot_a` / `dot_b` are the
    // value and tangent slots the derivative formula needs.
    // ---------------------------------------------------------------
    /// Reverse step for an `Add` node.
    RevAdd {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse step for a `Sub` node.
    RevSub {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse step for a `Mul` node.
    RevMul {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse step for a `Div` node.
    RevDiv {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse step for a `Pow` node; `vd` is the node's own value slot.
    RevPow {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse step for a `Neg` node.
    RevNeg { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32 },
    /// Reverse step for an `Abs` node.
    RevAbs { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32, va: u32 },
    /// Reverse step for a `Sqrt` node; `vd` is the node's own value slot.
    RevSqrt {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse step for an `Exp` node; `vd` is the node's own value slot.
    RevExp {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse step for a `Log` node.
    RevLog {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse step for a `Log10` node.
    RevLog10 {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse step for a `Sin` node.
    RevSin {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse step for a `Cos` node.
    RevCos {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },

    // ---------------------------------------------------------------
    // Output.
    // ---------------------------------------------------------------
    /// `values[hess_ptr] += weight * scratch[adj_dot_slot]`.
    HessEmit { hess_ptr: u32, adj_dot_slot: u32 },
}
337
338/// Precompiled Hessian-of-one-tape program. Built once via
339/// [`HessianProgram::compile`]; executed many times.
340#[derive(Debug, Clone)]
341pub struct HessianProgram {
342    ops: Vec<HOp>,
343    consts: Vec<f64>,
344    n_slots: u32,
345}
346
347impl HessianProgram {
348    /// Build the program. The `hess_map` is the same `(row, col)
349    /// -> values-index` map that [`Tape::hessian_accumulate`] uses;
350    /// the compiler inlines each lookup into a `HessEmit` op.
351    ///
352    /// Returns `None` when `tape` contains an opcode the program path
353    /// cannot lower (see [`program_supports_op`]). A caller must fall
354    /// back to the `Tape` (`build_with_externals`) interpreter path for
355    /// those tapes — this is a *graceful* signal, not a panic, so a
356    /// problem built from arbitrary user `.nl` input can never crash the
357    /// process here (code review L28).
358    pub fn compile(tape: &Tape, hess_map: &HashMap<(usize, usize), usize>) -> Option<Self> {
359        // Gate up front: every downstream sweep (forward / tangent /
360        // reverse) and the dependence/reachability analyses only handle the
361        // supported opcode set. Reject unsupported tapes here so none of
362        // those match arms is ever reached with an opcode it can't lower.
363        if !tape.ops.iter().all(program_supports_op) {
364            return None;
365        }
366
367        let n = tape.ops.len() as u32;
368        let v_base = 0u32;
369        let dot_base = n;
370        let adj_base = 2 * n;
371        let adj_dot_base = 3 * n;
372        let n_slots = 4 * n;
373
374        let v_slot = |i: u32| v_base + i;
375        let dot_slot = |i: u32| dot_base + i;
376        let adj_slot = |i: u32| adj_base + i;
377        let adj_dot_slot = |i: u32| adj_dot_base + i;
378
379        let reachable = reachable_to_output(tape);
380        let var_indices = tape.variables();
381        // depends_on[k_idx][i] — does slot i depend on var_indices[k_idx]?
382        let depends_on: Vec<Vec<bool>> = (0..var_indices.len())
383            .map(|k_idx| depends_on_var(tape, var_indices[k_idx]))
384            .collect();
385
386        let mut consts: Vec<f64> = Vec::new();
387        let mut const_intern: HashMap<u64, u32> = HashMap::new();
388        let mut intern_const = |c: f64, consts: &mut Vec<f64>| -> u32 {
389            let bits = c.to_bits();
390            if let Some(&idx) = const_intern.get(&bits) {
391                return idx;
392            }
393            let idx = consts.len() as u32;
394            consts.push(c);
395            const_intern.insert(bits, idx);
396            idx
397        };
398
399        let mut ops: Vec<HOp> = Vec::new();
400
401        // ---- Forward pass ----
402        for (i, tape_op) in tape.ops.iter().enumerate() {
403            let i = i as u32;
404            let dst = v_slot(i);
405            let op = match *tape_op {
406                TapeOp::Const(c) => HOp::FwdLoadConst {
407                    dst,
408                    c_idx: intern_const(c, &mut consts),
409                },
410                TapeOp::Var(x_idx) => HOp::FwdLoadVar {
411                    dst,
412                    x_idx: x_idx as u32,
413                },
414                TapeOp::Add(a, b) => HOp::FwdAdd {
415                    dst,
416                    a: v_slot(a as u32),
417                    b: v_slot(b as u32),
418                },
419                TapeOp::Sub(a, b) => HOp::FwdSub {
420                    dst,
421                    a: v_slot(a as u32),
422                    b: v_slot(b as u32),
423                },
424                TapeOp::Mul(a, b) => HOp::FwdMul {
425                    dst,
426                    a: v_slot(a as u32),
427                    b: v_slot(b as u32),
428                },
429                TapeOp::Div(a, b) => HOp::FwdDiv {
430                    dst,
431                    a: v_slot(a as u32),
432                    b: v_slot(b as u32),
433                },
434                TapeOp::Pow(a, b) => HOp::FwdPow {
435                    dst,
436                    a: v_slot(a as u32),
437                    b: v_slot(b as u32),
438                },
439                TapeOp::Neg(a) => HOp::FwdNeg {
440                    dst,
441                    a: v_slot(a as u32),
442                },
443                TapeOp::Abs(a) => HOp::FwdAbs {
444                    dst,
445                    a: v_slot(a as u32),
446                },
447                TapeOp::Sqrt(a) => HOp::FwdSqrt {
448                    dst,
449                    a: v_slot(a as u32),
450                },
451                TapeOp::Exp(a) => HOp::FwdExp {
452                    dst,
453                    a: v_slot(a as u32),
454                },
455                TapeOp::Log(a) => HOp::FwdLog {
456                    dst,
457                    a: v_slot(a as u32),
458                },
459                TapeOp::Log10(a) => HOp::FwdLog10 {
460                    dst,
461                    a: v_slot(a as u32),
462                },
463                TapeOp::Sin(a) => HOp::FwdSin {
464                    dst,
465                    a: v_slot(a as u32),
466                },
467                TapeOp::Cos(a) => HOp::FwdCos {
468                    dst,
469                    a: v_slot(a as u32),
470                },
471                TapeOp::Funcall(_) => unreachable!(
472                    "HessianProgram path does not support AMPL external functions; \
473                     use the Tape (build_with_externals) path instead."
474                ),
475                TapeOp::Tan(_)
476                | TapeOp::Atan(_)
477                | TapeOp::Acos(_)
478                | TapeOp::Sinh(_)
479                | TapeOp::Cosh(_)
480                | TapeOp::Tanh(_)
481                | TapeOp::Asin(_)
482                | TapeOp::Acosh(_)
483                | TapeOp::Asinh(_)
484                | TapeOp::Atanh(_)
485                | TapeOp::Atan2(_, _)
486                | TapeOp::Cmp(_, _, _)
487                | TapeOp::And(_, _)
488                | TapeOp::Or(_, _)
489                | TapeOp::Not(_)
490                | TapeOp::Select(_, _, _)
491                | TapeOp::Min(_, _)
492                | TapeOp::Max(_, _) => unreachable!(
493                    "HessianProgram path does not yet support tan/atan/acos, the \
494                     other transcendental opcodes, atan2, min/max, or \
495                     conditional / logical opcodes; use the Tape \
496                     (build_with_externals) interpreter path instead."
497                ),
498            };
499            ops.push(op);
500        }
501
502        if n == 0 || var_indices.is_empty() {
503            return Some(HessianProgram {
504                ops,
505                consts,
506                n_slots,
507            });
508        }
509
510        // ---- Per-j forward-tangent + reverse-over-tangent ----
511        for (k_idx, &j) in var_indices.iter().enumerate() {
512            // Reset dot, adj, adj_dot for this j. Seed adj[n-1] = 1.
513            ops.push(HOp::ZeroRange {
514                start: dot_base,
515                len: 3 * n,
516            });
517            ops.push(HOp::SetOne {
518                dst: adj_slot(n - 1),
519            });
520
521            // Forward tangent: only emit ops for slots that
522            // statically depend on j (the rest stay zero from the
523            // ZeroRange above).
524            for (i, tape_op) in tape.ops.iter().enumerate() {
525                let i_u = i as u32;
526                if !depends_on[k_idx][i] {
527                    continue;
528                }
529                let dst = dot_slot(i_u);
530                let dot_op = match *tape_op {
531                    // Const: dot stays 0 (filtered above by
532                    // depends_on, since Const has no var-deps).
533                    TapeOp::Const(_) => continue,
534                    // Var(k): dot = 1 iff k == j, else 0. We only
535                    // get here if depends_on[k_idx][i] is true,
536                    // which for Var(k) means k == j.
537                    TapeOp::Var(_) => HOp::SetOne { dst },
538                    TapeOp::Add(a, b) => HOp::DotAdd {
539                        dst,
540                        a: dot_slot(a as u32),
541                        b: dot_slot(b as u32),
542                    },
543                    TapeOp::Sub(a, b) => HOp::DotSub {
544                        dst,
545                        a: dot_slot(a as u32),
546                        b: dot_slot(b as u32),
547                    },
548                    TapeOp::Mul(a, b) => HOp::DotMul {
549                        dst,
550                        dot_a: dot_slot(a as u32),
551                        vb: v_slot(b as u32),
552                        va: v_slot(a as u32),
553                        dot_b: dot_slot(b as u32),
554                    },
555                    TapeOp::Div(a, b) => HOp::DotDiv {
556                        dst,
557                        dot_a: dot_slot(a as u32),
558                        vb: v_slot(b as u32),
559                        va: v_slot(a as u32),
560                        dot_b: dot_slot(b as u32),
561                    },
562                    TapeOp::Pow(a, b) => HOp::DotPow {
563                        dst,
564                        va: v_slot(a as u32),
565                        vb: v_slot(b as u32),
566                        vd: v_slot(i_u),
567                        dot_a: dot_slot(a as u32),
568                        dot_b: dot_slot(b as u32),
569                    },
570                    TapeOp::Neg(a) => HOp::DotNeg {
571                        dst,
572                        dot_a: dot_slot(a as u32),
573                    },
574                    TapeOp::Abs(a) => HOp::DotAbs {
575                        dst,
576                        dot_a: dot_slot(a as u32),
577                        va: v_slot(a as u32),
578                    },
579                    TapeOp::Sqrt(a) => HOp::DotSqrt {
580                        dst,
581                        dot_a: dot_slot(a as u32),
582                        vd: v_slot(i_u),
583                    },
584                    TapeOp::Exp(a) => HOp::DotExp {
585                        dst,
586                        dot_a: dot_slot(a as u32),
587                        vd: v_slot(i_u),
588                    },
589                    TapeOp::Log(a) => HOp::DotLog {
590                        dst,
591                        dot_a: dot_slot(a as u32),
592                        va: v_slot(a as u32),
593                    },
594                    TapeOp::Log10(a) => HOp::DotLog10 {
595                        dst,
596                        dot_a: dot_slot(a as u32),
597                        va: v_slot(a as u32),
598                    },
599                    TapeOp::Sin(a) => HOp::DotSin {
600                        dst,
601                        dot_a: dot_slot(a as u32),
602                        va: v_slot(a as u32),
603                    },
604                    TapeOp::Cos(a) => HOp::DotCos {
605                        dst,
606                        dot_a: dot_slot(a as u32),
607                        va: v_slot(a as u32),
608                    },
609                    TapeOp::Funcall(_) => unreachable!(
610                        "HessianProgram path does not support AMPL external functions; \
611                         use the Tape (build_with_externals) path instead."
612                    ),
613                    TapeOp::Tan(_)
614                    | TapeOp::Atan(_)
615                    | TapeOp::Acos(_)
616                    | TapeOp::Sinh(_)
617                    | TapeOp::Cosh(_)
618                    | TapeOp::Tanh(_)
619                    | TapeOp::Asin(_)
620                    | TapeOp::Acosh(_)
621                    | TapeOp::Asinh(_)
622                    | TapeOp::Atanh(_)
623                    | TapeOp::Atan2(_, _)
624                    | TapeOp::Cmp(_, _, _)
625                    | TapeOp::And(_, _)
626                    | TapeOp::Or(_, _)
627                    | TapeOp::Not(_)
628                    | TapeOp::Select(_, _, _)
629                    | TapeOp::Min(_, _)
630                    | TapeOp::Max(_, _) => unreachable!(
631                        "HessianProgram path does not yet support tan/atan/acos, the \
632                         other transcendental opcodes, atan2, min/max, or \
633                         conditional / logical opcodes; use the Tape \
634                         (build_with_externals) interpreter path instead."
635                    ),
636                };
637                ops.push(dot_op);
638            }
639
640            // Reverse-over-tangent: walk slots backward, emit only
641            // for reachable slots.
642            for i in (0..n as usize).rev() {
643                if !reachable[i] {
644                    continue;
645                }
646                let i_u = i as u32;
647                let w = adj_slot(i_u);
648                let wd = adj_dot_slot(i_u);
649                let tape_op = &tape.ops[i];
650                let rev_op = match *tape_op {
651                    TapeOp::Const(_) => continue,
652                    TapeOp::Var(k) => {
653                        // At a Var slot: if k >= j and hess_map has
654                        // an entry for (k, j), emit a HessEmit op.
655                        // No adj/adj_dot propagation (no operands).
656                        if k >= j {
657                            if let Some(&ptr) = hess_map.get(&(k, j)) {
658                                ops.push(HOp::HessEmit {
659                                    hess_ptr: ptr as u32,
660                                    adj_dot_slot: wd,
661                                });
662                            }
663                        }
664                        continue;
665                    }
666                    TapeOp::Add(a, b) => HOp::RevAdd {
667                        adj_a: adj_slot(a as u32),
668                        adj_b: adj_slot(b as u32),
669                        adj_dot_a: adj_dot_slot(a as u32),
670                        adj_dot_b: adj_dot_slot(b as u32),
671                        w,
672                        wd,
673                    },
674                    TapeOp::Sub(a, b) => HOp::RevSub {
675                        adj_a: adj_slot(a as u32),
676                        adj_b: adj_slot(b as u32),
677                        adj_dot_a: adj_dot_slot(a as u32),
678                        adj_dot_b: adj_dot_slot(b as u32),
679                        w,
680                        wd,
681                    },
682                    TapeOp::Mul(a, b) => HOp::RevMul {
683                        adj_a: adj_slot(a as u32),
684                        adj_b: adj_slot(b as u32),
685                        adj_dot_a: adj_dot_slot(a as u32),
686                        adj_dot_b: adj_dot_slot(b as u32),
687                        w,
688                        wd,
689                        va: v_slot(a as u32),
690                        vb: v_slot(b as u32),
691                        dot_a: dot_slot(a as u32),
692                        dot_b: dot_slot(b as u32),
693                    },
694                    TapeOp::Div(a, b) => HOp::RevDiv {
695                        adj_a: adj_slot(a as u32),
696                        adj_b: adj_slot(b as u32),
697                        adj_dot_a: adj_dot_slot(a as u32),
698                        adj_dot_b: adj_dot_slot(b as u32),
699                        w,
700                        wd,
701                        va: v_slot(a as u32),
702                        vb: v_slot(b as u32),
703                        dot_a: dot_slot(a as u32),
704                        dot_b: dot_slot(b as u32),
705                    },
706                    TapeOp::Pow(a, b) => HOp::RevPow {
707                        adj_a: adj_slot(a as u32),
708                        adj_b: adj_slot(b as u32),
709                        adj_dot_a: adj_dot_slot(a as u32),
710                        adj_dot_b: adj_dot_slot(b as u32),
711                        w,
712                        wd,
713                        va: v_slot(a as u32),
714                        vb: v_slot(b as u32),
715                        vd: v_slot(i_u),
716                        dot_a: dot_slot(a as u32),
717                        dot_b: dot_slot(b as u32),
718                    },
719                    TapeOp::Neg(a) => HOp::RevNeg {
720                        adj_a: adj_slot(a as u32),
721                        adj_dot_a: adj_dot_slot(a as u32),
722                        w,
723                        wd,
724                    },
725                    TapeOp::Abs(a) => HOp::RevAbs {
726                        adj_a: adj_slot(a as u32),
727                        adj_dot_a: adj_dot_slot(a as u32),
728                        w,
729                        wd,
730                        va: v_slot(a as u32),
731                    },
732                    TapeOp::Sqrt(a) => HOp::RevSqrt {
733                        adj_a: adj_slot(a as u32),
734                        adj_dot_a: adj_dot_slot(a as u32),
735                        w,
736                        wd,
737                        va: v_slot(a as u32),
738                        vd: v_slot(i_u),
739                        dot_a: dot_slot(a as u32),
740                    },
741                    TapeOp::Exp(a) => HOp::RevExp {
742                        adj_a: adj_slot(a as u32),
743                        adj_dot_a: adj_dot_slot(a as u32),
744                        w,
745                        wd,
746                        vd: v_slot(i_u),
747                        dot_a: dot_slot(a as u32),
748                    },
749                    TapeOp::Log(a) => HOp::RevLog {
750                        adj_a: adj_slot(a as u32),
751                        adj_dot_a: adj_dot_slot(a as u32),
752                        w,
753                        wd,
754                        va: v_slot(a as u32),
755                        dot_a: dot_slot(a as u32),
756                    },
757                    TapeOp::Log10(a) => HOp::RevLog10 {
758                        adj_a: adj_slot(a as u32),
759                        adj_dot_a: adj_dot_slot(a as u32),
760                        w,
761                        wd,
762                        va: v_slot(a as u32),
763                        dot_a: dot_slot(a as u32),
764                    },
765                    TapeOp::Sin(a) => HOp::RevSin {
766                        adj_a: adj_slot(a as u32),
767                        adj_dot_a: adj_dot_slot(a as u32),
768                        w,
769                        wd,
770                        va: v_slot(a as u32),
771                        dot_a: dot_slot(a as u32),
772                    },
773                    TapeOp::Cos(a) => HOp::RevCos {
774                        adj_a: adj_slot(a as u32),
775                        adj_dot_a: adj_dot_slot(a as u32),
776                        w,
777                        wd,
778                        va: v_slot(a as u32),
779                        dot_a: dot_slot(a as u32),
780                    },
781                    TapeOp::Funcall(_) => unreachable!(
782                        "HessianProgram path does not support AMPL external functions; \
783                         use the Tape (build_with_externals) path instead."
784                    ),
785                    TapeOp::Tan(_)
786                    | TapeOp::Atan(_)
787                    | TapeOp::Acos(_)
788                    | TapeOp::Sinh(_)
789                    | TapeOp::Cosh(_)
790                    | TapeOp::Tanh(_)
791                    | TapeOp::Asin(_)
792                    | TapeOp::Acosh(_)
793                    | TapeOp::Asinh(_)
794                    | TapeOp::Atanh(_)
795                    | TapeOp::Atan2(_, _)
796                    | TapeOp::Cmp(_, _, _)
797                    | TapeOp::And(_, _)
798                    | TapeOp::Or(_, _)
799                    | TapeOp::Not(_)
800                    | TapeOp::Select(_, _, _)
801                    | TapeOp::Min(_, _)
802                    | TapeOp::Max(_, _) => unreachable!(
803                        "HessianProgram path does not yet support tan/atan/acos, the \
804                         other transcendental opcodes, atan2, min/max, or \
805                         conditional / logical opcodes; use the Tape \
806                         (build_with_externals) interpreter path instead."
807                    ),
808                };
809                ops.push(rev_op);
810            }
811        }
812
813        Some(HessianProgram {
814            ops,
815            consts,
816            n_slots,
817        })
818    }
819
820    pub fn n_slots(&self) -> usize {
821        self.n_slots as usize
822    }
823
824    pub fn n_ops(&self) -> usize {
825        self.ops.len()
826    }
827
828    /// Execute the program. `scratch` is overwritten throughout;
829    /// it must be at least [`n_slots`] long. `values` is the
830    /// shared Hessian-values buffer the caller is accumulating
831    /// into (same semantics as
832    /// [`Tape::hessian_accumulate`]'s `values`).
833    pub fn execute(&self, x: &[f64], weight: f64, scratch: &mut [f64], values: &mut [f64]) {
834        debug_assert!(scratch.len() >= self.n_slots as usize);
835        if self.ops.is_empty() || weight == 0.0 {
836            return;
837        }
838        let consts = &self.consts[..];
839        for &op in &self.ops {
840            match op {
841                HOp::FwdLoadVar { dst, x_idx } => {
842                    scratch[dst as usize] = x[x_idx as usize];
843                }
844                HOp::FwdLoadConst { dst, c_idx } => {
845                    scratch[dst as usize] = consts[c_idx as usize];
846                }
847                HOp::FwdAdd { dst, a, b } => {
848                    scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
849                }
850                HOp::FwdSub { dst, a, b } => {
851                    scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
852                }
853                HOp::FwdMul { dst, a, b } => {
854                    scratch[dst as usize] = scratch[a as usize] * scratch[b as usize];
855                }
856                HOp::FwdDiv { dst, a, b } => {
857                    scratch[dst as usize] = scratch[a as usize] / scratch[b as usize];
858                }
859                HOp::FwdPow { dst, a, b } => {
860                    scratch[dst as usize] = scratch[a as usize].powf(scratch[b as usize]);
861                }
862                HOp::FwdNeg { dst, a } => {
863                    scratch[dst as usize] = -scratch[a as usize];
864                }
865                HOp::FwdAbs { dst, a } => {
866                    scratch[dst as usize] = scratch[a as usize].abs();
867                }
868                HOp::FwdSqrt { dst, a } => {
869                    scratch[dst as usize] = scratch[a as usize].sqrt();
870                }
871                HOp::FwdExp { dst, a } => {
872                    scratch[dst as usize] = scratch[a as usize].exp();
873                }
874                HOp::FwdLog { dst, a } => {
875                    scratch[dst as usize] = scratch[a as usize].ln();
876                }
877                HOp::FwdLog10 { dst, a } => {
878                    scratch[dst as usize] = scratch[a as usize].log10();
879                }
880                HOp::FwdSin { dst, a } => {
881                    scratch[dst as usize] = scratch[a as usize].sin();
882                }
883                HOp::FwdCos { dst, a } => {
884                    scratch[dst as usize] = scratch[a as usize].cos();
885                }
886
887                HOp::SetZero { dst } => {
888                    scratch[dst as usize] = 0.0;
889                }
890                HOp::SetOne { dst } => {
891                    scratch[dst as usize] = 1.0;
892                }
893                HOp::ZeroRange { start, len } => {
894                    let s = start as usize;
895                    let e = s + len as usize;
896                    scratch[s..e].fill(0.0);
897                }
898
899                HOp::DotAdd { dst, a, b } => {
900                    scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
901                }
902                HOp::DotSub { dst, a, b } => {
903                    scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
904                }
905                HOp::DotMul {
906                    dst,
907                    dot_a,
908                    vb,
909                    va,
910                    dot_b,
911                } => {
912                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[vb as usize]
913                        + scratch[va as usize] * scratch[dot_b as usize];
914                }
915                HOp::DotDiv {
916                    dst,
917                    dot_a,
918                    vb,
919                    va,
920                    dot_b,
921                } => {
922                    let v_b = scratch[vb as usize];
923                    scratch[dst as usize] = (scratch[dot_a as usize] * v_b
924                        - scratch[va as usize] * scratch[dot_b as usize])
925                        / (v_b * v_b);
926                }
927                HOp::DotSqrt { dst, dot_a, vd } => {
928                    let svd = scratch[vd as usize];
929                    scratch[dst as usize] = if svd > 0.0 {
930                        scratch[dot_a as usize] * 0.5 / svd
931                    } else {
932                        0.0
933                    };
934                }
935                HOp::DotExp { dst, dot_a, vd } => {
936                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[vd as usize];
937                }
938                HOp::DotLog { dst, dot_a, va } => {
939                    scratch[dst as usize] = scratch[dot_a as usize] / scratch[va as usize];
940                }
941                HOp::DotLog10 { dst, dot_a, va } => {
942                    scratch[dst as usize] =
943                        scratch[dot_a as usize] / (scratch[va as usize] * std::f64::consts::LN_10);
944                }
945                HOp::DotSin { dst, dot_a, va } => {
946                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[va as usize].cos();
947                }
948                HOp::DotCos { dst, dot_a, va } => {
949                    scratch[dst as usize] = -scratch[dot_a as usize] * scratch[va as usize].sin();
950                }
951                HOp::DotNeg { dst, dot_a } => {
952                    scratch[dst as usize] = -scratch[dot_a as usize];
953                }
954                HOp::DotAbs { dst, dot_a, va } => {
955                    scratch[dst as usize] = if scratch[va as usize] >= 0.0 {
956                        scratch[dot_a as usize]
957                    } else {
958                        -scratch[dot_a as usize]
959                    };
960                }
961                HOp::DotPow {
962                    dst,
963                    va,
964                    vb,
965                    vd,
966                    dot_a,
967                    dot_b,
968                } => {
969                    let u = scratch[va as usize];
970                    let r = scratch[vb as usize];
971                    let du = scratch[dot_a as usize];
972                    let dr = scratch[dot_b as usize];
973                    let mut result = 0.0;
974                    if r != 0.0 && u != 0.0 {
975                        result += r * u.powf(r - 1.0) * du;
976                    }
977                    if u > 0.0 {
978                        result += scratch[vd as usize] * u.ln() * dr;
979                    }
980                    scratch[dst as usize] = result;
981                }
982
983                HOp::RevAdd {
984                    adj_a,
985                    adj_b,
986                    adj_dot_a,
987                    adj_dot_b,
988                    w,
989                    wd,
990                } => {
991                    let w_v = scratch[w as usize];
992                    let wd_v = scratch[wd as usize];
993                    scratch[adj_a as usize] += w_v;
994                    scratch[adj_b as usize] += w_v;
995                    scratch[adj_dot_a as usize] += wd_v;
996                    scratch[adj_dot_b as usize] += wd_v;
997                }
998                HOp::RevSub {
999                    adj_a,
1000                    adj_b,
1001                    adj_dot_a,
1002                    adj_dot_b,
1003                    w,
1004                    wd,
1005                } => {
1006                    let w_v = scratch[w as usize];
1007                    let wd_v = scratch[wd as usize];
1008                    scratch[adj_a as usize] += w_v;
1009                    scratch[adj_b as usize] -= w_v;
1010                    scratch[adj_dot_a as usize] += wd_v;
1011                    scratch[adj_dot_b as usize] -= wd_v;
1012                }
1013                HOp::RevMul {
1014                    adj_a,
1015                    adj_b,
1016                    adj_dot_a,
1017                    adj_dot_b,
1018                    w,
1019                    wd,
1020                    va,
1021                    vb,
1022                    dot_a,
1023                    dot_b,
1024                } => {
1025                    let w_v = scratch[w as usize];
1026                    let wd_v = scratch[wd as usize];
1027                    let va_v = scratch[va as usize];
1028                    let vb_v = scratch[vb as usize];
1029                    let da_v = scratch[dot_a as usize];
1030                    let db_v = scratch[dot_b as usize];
1031                    scratch[adj_a as usize] += w_v * vb_v;
1032                    scratch[adj_b as usize] += w_v * va_v;
1033                    scratch[adj_dot_a as usize] += wd_v * vb_v + w_v * db_v;
1034                    scratch[adj_dot_b as usize] += wd_v * va_v + w_v * da_v;
1035                }
1036                HOp::RevDiv {
1037                    adj_a,
1038                    adj_b,
1039                    adj_dot_a,
1040                    adj_dot_b,
1041                    w,
1042                    wd,
1043                    va,
1044                    vb,
1045                    dot_a,
1046                    dot_b,
1047                } => {
1048                    let w_v = scratch[w as usize];
1049                    let wd_v = scratch[wd as usize];
1050                    let va_v = scratch[va as usize];
1051                    let vb_v = scratch[vb as usize];
1052                    let vb2 = vb_v * vb_v;
1053                    let vb3 = vb2 * vb_v;
1054                    let da_v = scratch[dot_a as usize];
1055                    let db_v = scratch[dot_b as usize];
1056                    scratch[adj_a as usize] += w_v / vb_v;
1057                    scratch[adj_dot_a as usize] += wd_v / vb_v + w_v * (-db_v / vb2);
1058                    scratch[adj_b as usize] += w_v * (-va_v / vb2);
1059                    scratch[adj_dot_b as usize] +=
1060                        wd_v * (-va_v / vb2) + w_v * (-da_v / vb2 + 2.0 * va_v * db_v / vb3);
1061                }
1062                HOp::RevPow {
1063                    adj_a,
1064                    adj_b,
1065                    adj_dot_a,
1066                    adj_dot_b,
1067                    w,
1068                    wd,
1069                    va,
1070                    vb,
1071                    vd,
1072                    dot_a,
1073                    dot_b,
1074                } => {
1075                    let w_v = scratch[w as usize];
1076                    let wd_v = scratch[wd as usize];
1077                    let u = scratch[va as usize];
1078                    let r = scratch[vb as usize];
1079                    let du = scratch[dot_a as usize];
1080                    let dr = scratch[dot_b as usize];
1081                    if r != 0.0 {
1082                        if u != 0.0 {
1083                            let p_a = r * u.powf(r - 1.0);
1084                            scratch[adj_a as usize] += w_v * p_a;
1085                            let mut dp_a = dr * u.powf(r - 1.0);
1086                            if u > 0.0 {
1087                                dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
1088                            } else {
1089                                dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
1090                            }
1091                            scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1092                        } else if r >= 2.0 {
1093                            let p_a = 0.0;
1094                            scratch[adj_a as usize] += w_v * p_a;
1095                            let dp_a = if r == 2.0 {
1096                                2.0 * du
1097                            } else {
1098                                r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
1099                            };
1100                            scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1101                        }
1102                    }
1103                    if u > 0.0 {
1104                        let ln_u = u.ln();
1105                        let p_b = scratch[vd as usize] * ln_u;
1106                        scratch[adj_b as usize] += w_v * p_b;
1107                        let dur = scratch[vd as usize] * (r * du / u + dr * ln_u);
1108                        let dp_b = dur * ln_u + scratch[vd as usize] * du / u;
1109                        scratch[adj_dot_b as usize] += wd_v * p_b + w_v * dp_b;
1110                    }
1111                }
1112                HOp::RevNeg {
1113                    adj_a,
1114                    adj_dot_a,
1115                    w,
1116                    wd,
1117                } => {
1118                    scratch[adj_a as usize] -= scratch[w as usize];
1119                    scratch[adj_dot_a as usize] -= scratch[wd as usize];
1120                }
1121                HOp::RevAbs {
1122                    adj_a,
1123                    adj_dot_a,
1124                    w,
1125                    wd,
1126                    va,
1127                } => {
1128                    let s = if scratch[va as usize] >= 0.0 {
1129                        1.0
1130                    } else {
1131                        -1.0
1132                    };
1133                    scratch[adj_a as usize] += scratch[w as usize] * s;
1134                    scratch[adj_dot_a as usize] += scratch[wd as usize] * s;
1135                }
1136                HOp::RevSqrt {
1137                    adj_a,
1138                    adj_dot_a,
1139                    w,
1140                    wd,
1141                    va: _,
1142                    vd,
1143                    dot_a,
1144                } => {
1145                    let sv = scratch[vd as usize];
1146                    if sv > 0.0 {
1147                        let fp = 0.5 / sv;
1148                        let fpp = -0.25 / (sv * sv * sv);
1149                        let w_v = scratch[w as usize];
1150                        let wd_v = scratch[wd as usize];
1151                        scratch[adj_a as usize] += w_v * fp;
1152                        scratch[adj_dot_a as usize] +=
1153                            wd_v * fp + w_v * fpp * scratch[dot_a as usize];
1154                    }
1155                }
1156                HOp::RevExp {
1157                    adj_a,
1158                    adj_dot_a,
1159                    w,
1160                    wd,
1161                    vd,
1162                    dot_a,
1163                } => {
1164                    let ev = scratch[vd as usize];
1165                    let w_v = scratch[w as usize];
1166                    let wd_v = scratch[wd as usize];
1167                    scratch[adj_a as usize] += w_v * ev;
1168                    scratch[adj_dot_a as usize] += wd_v * ev + w_v * ev * scratch[dot_a as usize];
1169                }
1170                HOp::RevLog {
1171                    adj_a,
1172                    adj_dot_a,
1173                    w,
1174                    wd,
1175                    va,
1176                    dot_a,
1177                } => {
1178                    let u = scratch[va as usize];
1179                    let w_v = scratch[w as usize];
1180                    let wd_v = scratch[wd as usize];
1181                    scratch[adj_a as usize] += w_v / u;
1182                    scratch[adj_dot_a as usize] +=
1183                        wd_v / u + w_v * (-1.0 / (u * u)) * scratch[dot_a as usize];
1184                }
1185                HOp::RevLog10 {
1186                    adj_a,
1187                    adj_dot_a,
1188                    w,
1189                    wd,
1190                    va,
1191                    dot_a,
1192                } => {
1193                    let u = scratch[va as usize];
1194                    let c = std::f64::consts::LN_10;
1195                    let w_v = scratch[w as usize];
1196                    let wd_v = scratch[wd as usize];
1197                    scratch[adj_a as usize] += w_v / (u * c);
1198                    scratch[adj_dot_a as usize] +=
1199                        wd_v / (u * c) + w_v * (-1.0 / (u * u * c)) * scratch[dot_a as usize];
1200                }
1201                HOp::RevSin {
1202                    adj_a,
1203                    adj_dot_a,
1204                    w,
1205                    wd,
1206                    va,
1207                    dot_a,
1208                } => {
1209                    let u = scratch[va as usize];
1210                    let cu = u.cos();
1211                    let w_v = scratch[w as usize];
1212                    let wd_v = scratch[wd as usize];
1213                    scratch[adj_a as usize] += w_v * cu;
1214                    scratch[adj_dot_a as usize] +=
1215                        wd_v * cu + w_v * (-u.sin()) * scratch[dot_a as usize];
1216                }
1217                HOp::RevCos {
1218                    adj_a,
1219                    adj_dot_a,
1220                    w,
1221                    wd,
1222                    va,
1223                    dot_a,
1224                } => {
1225                    let u = scratch[va as usize];
1226                    let su = u.sin();
1227                    let w_v = scratch[w as usize];
1228                    let wd_v = scratch[wd as usize];
1229                    scratch[adj_a as usize] -= w_v * su;
1230                    scratch[adj_dot_a as usize] +=
1231                        wd_v * (-su) + w_v * (-u.cos()) * scratch[dot_a as usize];
1232                }
1233
1234                HOp::HessEmit {
1235                    hess_ptr,
1236                    adj_dot_slot,
1237                } => {
1238                    values[hess_ptr as usize] += weight * scratch[adj_dot_slot as usize];
1239                }
1240            }
1241        }
1242    }
1243}
1244
1245/// Whether [`HessianProgram::compile`] can lower a single opcode. The
1246/// program path covers smooth arithmetic plus `sin`/`cos`; every other
1247/// opcode — AMPL external `Funcall`, the remaining transcendentals
1248/// (`tan`/`atan`/`acos`/the hyperbolics/`asin`…) and `atan2`, and the
1249/// `min`/`max`/conditional/logical family — is unsupported, so `compile`
1250/// returns `None` and the caller falls back to the `Tape`
1251/// (`build_with_externals`) interpreter path rather than panicking on user
1252/// input (code review L28). This is the single source of truth for the
1253/// supported set: the per-sweep match arms and the dependence/reachability
1254/// analyses are only ever reached with opcodes this predicate accepts, so
1255/// their unsupported branches are `unreachable!`.
1256fn program_supports_op(op: &TapeOp) -> bool {
1257    matches!(
1258        op,
1259        TapeOp::Const(_)
1260            | TapeOp::Var(_)
1261            | TapeOp::Add(_, _)
1262            | TapeOp::Sub(_, _)
1263            | TapeOp::Mul(_, _)
1264            | TapeOp::Div(_, _)
1265            | TapeOp::Pow(_, _)
1266            | TapeOp::Neg(_)
1267            | TapeOp::Abs(_)
1268            | TapeOp::Sqrt(_)
1269            | TapeOp::Exp(_)
1270            | TapeOp::Log(_)
1271            | TapeOp::Log10(_)
1272            | TapeOp::Sin(_)
1273            | TapeOp::Cos(_)
1274    )
1275}
1276
1277/// `out[i]` = does tape slot `i` contribute (transitively) to the
1278/// output slot `n-1`. Used to skip emitting reverse-pass ops for
1279/// dead slots.
1280fn reachable_to_output(tape: &Tape) -> Vec<bool> {
1281    let n = tape.ops.len();
1282    let mut r = vec![false; n];
1283    if n == 0 {
1284        return r;
1285    }
1286    r[n - 1] = true;
1287    for i in (0..n).rev() {
1288        if !r[i] {
1289            continue;
1290        }
1291        match tape.ops[i] {
1292            TapeOp::Const(_) | TapeOp::Var(_) => {}
1293            TapeOp::Add(a, b)
1294            | TapeOp::Sub(a, b)
1295            | TapeOp::Mul(a, b)
1296            | TapeOp::Div(a, b)
1297            | TapeOp::Pow(a, b)
1298            | TapeOp::Atan2(a, b) => {
1299                r[a] = true;
1300                r[b] = true;
1301            }
1302            TapeOp::Neg(a)
1303            | TapeOp::Abs(a)
1304            | TapeOp::Sqrt(a)
1305            | TapeOp::Exp(a)
1306            | TapeOp::Log(a)
1307            | TapeOp::Log10(a)
1308            | TapeOp::Sin(a)
1309            | TapeOp::Cos(a)
1310            | TapeOp::Tan(a)
1311            | TapeOp::Atan(a)
1312            | TapeOp::Acos(a)
1313            | TapeOp::Sinh(a)
1314            | TapeOp::Cosh(a)
1315            | TapeOp::Tanh(a)
1316            | TapeOp::Asin(a)
1317            | TapeOp::Acosh(a)
1318            | TapeOp::Asinh(a)
1319            | TapeOp::Atanh(a) => {
1320                r[a] = true;
1321            }
1322            TapeOp::Funcall(_) => unreachable!(
1323                "HessianProgram path does not support AMPL external functions; \
1324                 use the Tape (build_with_externals) path instead."
1325            ),
1326            TapeOp::Cmp(_, _, _)
1327            | TapeOp::And(_, _)
1328            | TapeOp::Or(_, _)
1329            | TapeOp::Not(_)
1330            | TapeOp::Select(_, _, _)
1331            | TapeOp::Min(_, _)
1332            | TapeOp::Max(_, _) => unreachable!(
1333                "HessianProgram path does not support conditional / logical / min-max \
1334                 opcodes; use the Tape (build_with_externals) path instead."
1335            ),
1336        }
1337    }
1338    r
1339}
1340
1341/// `out[i]` = does tape slot `i` transitively read from `Var(j)`.
1342/// Used to prune forward-tangent ops (slots with `out[i] = false`
1343/// have `dot[i] = 0` and the rest of the per-`j` pass can skip
1344/// them).
1345fn depends_on_var(tape: &Tape, j: usize) -> Vec<bool> {
1346    let n = tape.ops.len();
1347    let mut d = vec![false; n];
1348    for (i, op) in tape.ops.iter().enumerate() {
1349        d[i] = match *op {
1350            TapeOp::Const(_) => false,
1351            TapeOp::Var(k) => k == j,
1352            TapeOp::Add(a, b)
1353            | TapeOp::Sub(a, b)
1354            | TapeOp::Mul(a, b)
1355            | TapeOp::Div(a, b)
1356            | TapeOp::Pow(a, b)
1357            | TapeOp::Atan2(a, b) => d[a] || d[b],
1358            TapeOp::Neg(a)
1359            | TapeOp::Abs(a)
1360            | TapeOp::Sqrt(a)
1361            | TapeOp::Exp(a)
1362            | TapeOp::Log(a)
1363            | TapeOp::Log10(a)
1364            | TapeOp::Sin(a)
1365            | TapeOp::Cos(a)
1366            | TapeOp::Tan(a)
1367            | TapeOp::Atan(a)
1368            | TapeOp::Acos(a)
1369            | TapeOp::Sinh(a)
1370            | TapeOp::Cosh(a)
1371            | TapeOp::Tanh(a)
1372            | TapeOp::Asin(a)
1373            | TapeOp::Acosh(a)
1374            | TapeOp::Asinh(a)
1375            | TapeOp::Atanh(a) => d[a],
1376            TapeOp::Funcall(_) => unreachable!(
1377                "HessianProgram path does not support AMPL external functions; \
1378                 use the Tape (build_with_externals) path instead."
1379            ),
1380            TapeOp::Cmp(_, _, _)
1381            | TapeOp::And(_, _)
1382            | TapeOp::Or(_, _)
1383            | TapeOp::Not(_)
1384            | TapeOp::Select(_, _, _)
1385            | TapeOp::Min(_, _)
1386            | TapeOp::Max(_, _) => unreachable!(
1387                "HessianProgram path does not support conditional / logical / min-max \
1388                 opcodes; use the Tape (build_with_externals) path instead."
1389            ),
1390        };
1391    }
1392    d
1393}
1394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nl_reader::{BinOp, Expr, UnaryOp};
    use std::collections::BTreeSet;
    use std::sync::Arc;

    // ---- Expression-building shorthands -------------------------------
    // These keep the test bodies close to the formulas they encode.

    /// Constant leaf.
    fn cnst(c: f64) -> Expr {
        Expr::Const(c)
    }

    /// Variable leaf `x_i`.
    fn var(i: usize) -> Expr {
        Expr::Var(i)
    }

    /// Wraps both operands in boxes and joins them with `op`.
    fn binary(op: BinOp, lhs: Expr, rhs: Expr) -> Expr {
        Expr::Binary(op, Box::new(lhs), Box::new(rhs))
    }

    /// `lhs + rhs`
    fn add(lhs: Expr, rhs: Expr) -> Expr {
        binary(BinOp::Add, lhs, rhs)
    }

    /// `lhs * rhs`
    fn mul(lhs: Expr, rhs: Expr) -> Expr {
        binary(BinOp::Mul, lhs, rhs)
    }

    /// `lhs ^ rhs`
    fn pow(lhs: Expr, rhs: Expr) -> Expr {
        binary(BinOp::Pow, lhs, rhs)
    }

    /// `lhs / rhs`
    fn div(lhs: Expr, rhs: Expr) -> Expr {
        binary(BinOp::Div, lhs, rhs)
    }

    /// `lhs - rhs`
    fn sub(lhs: Expr, rhs: Expr) -> Expr {
        binary(BinOp::Sub, lhs, rhs)
    }

    /// Applies the unary operator `op` to `arg`.
    fn unary(op: UnaryOp, arg: Expr) -> Expr {
        Expr::Unary(op, Box::new(arg))
    }

    // ---- Shared harness ------------------------------------------------

    /// Builds the `(row, col) -> position` map that both AD paths scatter
    /// into, together with the pairs listed in position order.
    ///
    /// Pairs are normalised so that `row >= col` (lower triangle) and are
    /// numbered in first-seen order while walking `tape.variables()`.
    fn build_hess_map(tape: &Tape) -> (HashMap<(usize, usize), usize>, Vec<(usize, usize)>) {
        let vars = tape.variables();
        let mut positions: HashMap<(usize, usize), usize> = HashMap::new();
        let mut ordered: Vec<(usize, usize)> = Vec::new();

        for (outer, &vi) in vars.iter().enumerate() {
            for &vj in &vars[..=outer] {
                // Larger index first, so the key is always lower-triangular.
                let key = if vj > vi { (vj, vi) } else { (vi, vj) };
                if let std::collections::hash_map::Entry::Vacant(free) = positions.entry(key) {
                    free.insert(ordered.len());
                    ordered.push(key);
                }
            }
        }

        (positions, ordered)
    }

    /// Evaluates the Hessian with both the interpreter
    /// (`Tape::hessian_accumulate`) and the compiled program on the same
    /// point and weight, then asserts that every entry agrees to within
    /// `1e-12 * max(|tape value|, 1)`.
    fn assert_program_matches_tape(tape: &Tape, x: &[f64], weight: f64) {
        let (hess_map, pairs) = build_hess_map(tape);
        let entry_count = pairs.len();

        // Reference values from the interpreter path.
        let mut expected = vec![0.0; entry_count];
        tape.hessian_accumulate(x, weight, &hess_map, &mut expected);

        // Values from the compiled program, run on a zeroed scratch arena.
        let program =
            HessianProgram::compile(tape, &hess_map).expect("tape uses only supported opcodes");
        let mut scratch = vec![0.0; program.n_slots()];
        let mut actual = vec![0.0; entry_count];
        program.execute(x, weight, &mut scratch, &mut actual);

        for ((&(r, c), &want), &got) in pairs.iter().zip(&expected).zip(&actual) {
            let tol = want.abs().max(1.0) * 1e-12;
            assert!(
                (want - got).abs() < tol,
                "H[{},{}]: tape={:.6e} prog={:.6e}",
                r,
                c,
                want,
                got
            );
        }
    }

    // ---- Agreement tests -------------------------------------------------

    /// `3*x0^2 + 2*x0*x1 + x1^2`.
    #[test]
    fn matches_quadratic() {
        let square_x0 = mul(cnst(3.0), pow(var(0), cnst(2.0)));
        let cross = mul(cnst(2.0), mul(var(0), var(1)));
        let square_x1 = pow(var(1), cnst(2.0));
        let expr = add(add(square_x0, cross), square_x1);

        let tape = Tape::build(&expr);
        assert_program_matches_tape(&tape, &[2.0, 3.0], 1.0);
        assert_program_matches_tape(&tape, &[-1.5, 0.7], 2.5);
    }

    /// A sum mixing `exp`, `sin`, `log`, `sqrt`, a product and a `cos` of
    /// a sum.
    #[test]
    fn matches_transcendental() {
        let terms = vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
            unary(UnaryOp::Cos, add(var(0), var(1))),
        ];
        let expr = Expr::Sum(terms);

        let tape = Tape::build(&expr);
        assert_program_matches_tape(&tape, &[1.5, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[0.3, 4.1], -0.4);
    }

    /// `x0 / x1 + cos(x0)`.
    #[test]
    fn matches_division() {
        let quotient = div(var(0), var(1));
        let expr = add(quotient, unary(UnaryOp::Cos, var(0)));

        let tape = Tape::build(&expr);
        assert_program_matches_tape(&tape, &[0.5, 1.2], 1.0);
    }

    /// The same `Cse` body (`x0 + x1`) referenced twice: once squared,
    /// once as-is.
    #[test]
    fn matches_through_cse() {
        let shared = Arc::new(add(var(0), var(1)));
        let squared = pow(Expr::Cse(Arc::clone(&shared)), cnst(2.0));
        let expr = add(squared, Expr::Cse(Arc::clone(&shared)));

        let tape = Tape::build(&expr);
        assert_program_matches_tape(&tape, &[1.0, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[-0.5, 3.3], 0.7);
    }

    /// `x0^3 + x1^(-2)`.
    #[test]
    fn matches_pow_chain() {
        // After Tier 1 this lowers to a Mul chain; verify both
        // paths agree on the lowered form too.
        let cubic = pow(var(0), cnst(3.0));
        let inverse_square = pow(var(1), cnst(-2.0));
        let expr = add(cubic, inverse_square);

        let tape = Tape::build(&expr);
        assert_program_matches_tape(&tape, &[1.7, 0.8], 1.0);
    }

    /// `x0 ^ x1`.
    #[test]
    fn matches_residual_pow_with_var_exponent() {
        // Pow where the exponent is variable (not constant), so
        // it survives Tier 1 and exercises the RevPow / DotPow
        // compound branches.
        let expr = pow(var(0), var(1));

        let tape = Tape::build(&expr);
        assert_program_matches_tape(&tape, &[2.5, 1.4], 1.0);
        assert_program_matches_tape(&tape, &[0.6, 2.1], -1.0);
    }

    /// `(-x0) - |x1 - x0|`.
    #[test]
    fn matches_sub_neg_abs() {
        let negated = unary(UnaryOp::Neg, var(0));
        let distance = unary(UnaryOp::Abs, sub(var(1), var(0)));
        let expr = sub(negated, distance);

        let tape = Tape::build(&expr);
        assert_program_matches_tape(&tape, &[1.0, -2.0], 1.0);
        assert_program_matches_tape(&tape, &[-3.5, 4.0], 0.9);
    }

    // ---- Structural tests ------------------------------------------------

    /// The scratch arena is four regions of `tape.ops.len()` cells each.
    #[test]
    fn slots_layout_matches_design() {
        let expr = mul(var(0), var(1));
        let tape = Tape::build(&expr);
        let (hess_map, _) = build_hess_map(&tape);

        let program = HessianProgram::compile(&tape, &hess_map).expect("mul tape is supported");
        assert_eq!(program.n_slots(), 4 * tape.ops.len());
    }

    /// Sanity: the pruning analyses are consistent with the slot
    /// structure exposed via `hessian_sparsity()`.
    #[test]
    fn dependence_matches_hessian_sparsity_for_simple_case() {
        let expr = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
        let tape = Tape::build(&expr);

        let sparsity: BTreeSet<(usize, usize)> = tape.hessian_sparsity();
        // (0,0) from sin, (2,1) from x1*x2, (1,1)/(2,2) NOT there
        // because Mul(x1, x2) emits cross only.
        assert!(sparsity.contains(&(0, 0)));
        assert!(sparsity.contains(&(2, 1)));

        assert_program_matches_tape(&tape, &[0.7, 1.1, 2.2], 1.0);
    }

    /// Code review L28: a tape using an opcode the program path cannot
    /// lower (here `tan`) must make `compile` return `None` — a graceful
    /// fall-back-to-the-`Tape`-path signal — rather than panic. Previously
    /// the per-sweep match arms `panic!`'d on such ops, which would crash
    /// the process on arbitrary user `.nl` input if this path were ever
    /// wired into dispatch.
    #[test]
    fn unsupported_opcode_returns_none_instead_of_panicking() {
        // `tan(x0)` lowers to `TapeOp::Tan`, which the HessianProgram
        // compiler does not support.
        let rejected = unary(UnaryOp::Tan, var(0));
        let rejected_tape = Tape::build(&rejected);
        let (rejected_map, _) = build_hess_map(&rejected_tape);
        assert!(
            HessianProgram::compile(&rejected_tape, &rejected_map).is_none(),
            "tan() tape must fall back (None), not compile"
        );

        // A fully-supported tape still compiles to `Some`, so the guard
        // rejects only genuinely-unsupported ops.
        let accepted = mul(var(0), var(1));
        let accepted_tape = Tape::build(&accepted);
        let (accepted_map, _) = build_hess_map(&accepted_tape);
        assert!(
            HessianProgram::compile(&accepted_tape, &accepted_map).is_some(),
            "a supported (mul) tape must still compile"
        );
    }
}