Skip to main content

gam_math/
jet_tower.rs

1//! Taylor-jet tower algebra: write each family's row log-likelihood ONCE,
2//! derive the entire `RowKernel<K>` derivative tower mechanically (#932).
3//!
4//! # The object
5//!
6//! [`Tower4<K>`] is a truncated multivariate Taylor scalar in `K` primary
7//! variables, carrying the value and ALL partial derivatives through fourth
8//! order as full (unsymmetrized) tensors:
9//!
10//! ```text
11//!   v        ℓ
12//!   g[a]     ∂ℓ/∂p_a
13//!   h[a][b]  ∂²ℓ/∂p_a∂p_b
14//!   t3[abc]  ∂³ℓ/∂p_a∂p_b∂p_c
15//!   t4[abcd] ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d
16//! ```
17//!
18//! Arithmetic (`+ − × ÷`, scalar mixes) propagates the tower by the exact
19//! Leibniz rule; unary transcendentals propagate by the exact multivariate
20//! Faà di Bruno formula given a `[f, f′, f″, f‴, f⁗]` stack evaluated at the
21//! inner value. This is truncated Taylor ALGEBRA — exact derivatives of the
22//! evaluated expression, not finite differences, not an approximation —
23//! fully compatible with the exact-REML-only policy.
24//!
25//! One evaluation of a row NLL program at seeded variables yields, in a
26//! single pass, every channel the [`super::row_kernel::RowKernel`] trait
27//! demands: `row_kernel` (value/∇/H), `row_third_contracted(dir)` (contract
28//! `t3` with `dir`), and `row_fourth_contracted(u, v)` (contract `t4` with
29//! `u` and `v`). The directional cross-channels that hand-written towers
30//! drop (#736's residual gap) cannot be dropped here: there is no separate
31//! "channel" to forget — every derivative of the one expression is carried.
32//!
33//! # Why this exists (the bug genus)
34//!
35//! Every family today hand-writes its tower: value in one function,
36//! gradient in another, `pdfthird_derivative`/`pdffourth_derivative`,
37//! entry/exit-specific cross blocks — thousands of lines of calculus that
38//! drift. #736 was a sign flip in a hand-written cross-Hessian block,
39//! invisible until a new consumer touched it; #948 is a derivative path
40//! that is not the derivative of the evaluated row loss (clamped-μ
41//! surrogate); the objective↔gradient desync class is the same disease at
42//! the criterion level. A tower-derived kernel is exact-by-construction:
43//! the value channel IS the production loss expression, so its derivative
44//! channels cannot desync from it.
45//!
46//! # Relation to `jet_partitions::MultiDirJet`
47//!
48//! The tree already carries a *directional* jet (bitmask coefficients over
49//! distinct seeded directions, heap-allocated, Bell-partition compose) used
50//! inside the marginal-slope and latent-survival families. It answers "the
51//! derivative along THESE specific directions" and must be re-seeded and
52//! re-evaluated per direction tuple (e.g. 10 symmetric `(a,b)` pairs for a
53//! K=4 fourth contraction). `Tower4` answers ALL of them from one
54//! evaluation: contraction happens AFTER differentiation, as plain linear
55//! algebra on the stored tensors. Use `MultiDirJet` when you need a handful
56//! of directions of a huge-K expression; use `Tower4` when you need the
57//! complete small-K tower — which is exactly the `RowKernel<K≤4>` shape.
58//! The `[f64; 5]` unary-derivative stacks
59//! (`unary_derivatives_neglog_phi`, …) are signature-compatible with
60//! [`Tower4::compose_unary`], so the families' existing special-function
61//! stacks are directly reusable.
62//!
63//! # Stability discipline (why this is NOT autodiff)
64//!
65//! Differentiating the primal code path inherits its instabilities: a jet
66//! pushed through a naive `ln(1 + e^η)` is garbage in the saturated tail
67//! even though the true derivative σ(η) is benign there. This module
68//! therefore splits responsibility: **humans own primitive stability,
69//! the algebra owns combinatorics**. Tail-critical special functions enter
70//! a program ONLY as hand-certified `[f64; 5]` derivative stacks through
71//! [`Tower4::compose_unary`] — the same stacks the families already write
72//! (`unary_derivatives_neglog_phi` and friends, built on erfcx/log_ndtr) —
73//! and the tower mechanizes only the Leibniz/Faà di Bruno composition,
74//! which is where hand-written towers actually fail (#736 was a
75//! composition sign flip, not a primitive error). Program authors must use
76//! a stable primitive stack wherever the f64 production loss does; the
77//! convenience methods (`exp`, `ln`, `sqrt`, …) are for expressions whose
78//! arguments are tame by construction.
79//!
80//! # Storage convention
81//!
82//! Tensors are stored FULL, not symmetric-packed: `t4` for K=4 is 256
83//! doubles where 35 would do. This is deliberate clarity-over-speed for the
84//! oracle role — indexing is trivially auditable, contraction loops are
85//! obvious, and the redundancy is itself a checked invariant (the algebra
86//! only ever writes symmetric values). Symmetric packing is a later,
87//! profile-justified optimization behind the same API.
88//!
89//! # Deployment ladder (#932)
90//!
91//! 1. This module: the algebra + the program seam + the oracle.
92//! 2. Universal oracle: every hand-written `RowKernel` gains a CI test
93//!    asserting channel-by-channel agreement with a `RowNllProgram` written
94//!    once — see [`verify_kernel_channels`]. This alone would have caught
95//!    #736 at introduction.
96//! 3. Migrate error-dense / cold towers to [`derived_row_kernel`] et al.;
97//!    keep hand-tuned hot paths, now verified against the single-expression
98//!    truth instead of being the only definition.
99//! 4. New families (#914/#916/#917 ZI/ordinal/expectile, #921's location-
100//!    scale port) implement ONLY `RowNllProgram` and get an exact
101//!    fourth-order tower for the price of writing the likelihood.
102
103use crate::jet_algebra;
104
105/// Truncated fourth-order multivariate Taylor scalar in `K` variables.
106///
107/// See the module documentation for semantics and conventions. `Copy` is
108/// intentional despite the size (2 KiB at K=4): towers are per-row
109/// temporaries that live entirely in registers/stack during a row program,
110/// and value semantics keep program code readable (`a * b + c`).
111#[derive(Clone, Copy, Debug)]
112pub struct Tower4<const K: usize> {
113    /// Value ℓ.
114    pub v: f64,
115    /// Gradient ∂ℓ/∂p_a.
116    pub g: [f64; K],
117    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
118    pub h: [[f64; K]; K],
119    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
120    pub t3: [[[f64; K]; K]; K],
121    /// Fourth derivatives ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d (fully symmetric).
122    pub t4: [[[[f64; K]; K]; K]; K],
123}
124
125impl<const K: usize> Tower4<K> {
126    /// The additive identity.
127    pub fn zero() -> Self {
128        Self {
129            v: 0.0,
130            g: [0.0; K],
131            h: [[0.0; K]; K],
132            t3: [[[0.0; K]; K]; K],
133            t4: [[[[0.0; K]; K]; K]; K],
134        }
135    }
136
137    /// A constant: value `c`, all derivatives zero.
138    pub fn constant(c: f64) -> Self {
139        let mut out = Self::zero();
140        out.v = c;
141        out
142    }
143
144    /// The seeded variable `p_idx` with current value `value`:
145    /// unit first derivative in slot `idx`, zero elsewhere and above.
146    pub fn variable(value: f64, idx: usize) -> Self {
147        let mut out = Self::constant(value);
148        out.g[idx] = 1.0;
149        out
150    }
151
152    /// Read the (fully symmetric) derivative tensor entry whose differentiation
153    /// axes are `labels` (length 0..=4): value, `g`, `h`, `t3`, `t4`.
154    #[inline]
155    fn deriv(&self, labels: &[usize]) -> f64 {
156        assert!(
157            labels.len() <= 4,
158            "Tower4 carries at most fourth-order derivatives"
159        );
160        match labels.len() {
161            0 => self.v,
162            1 => self.g[labels[0]],
163            2 => self.h[labels[0]][labels[1]],
164            3 => self.t3[labels[0]][labels[1]][labels[2]],
165            _ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
166        }
167    }
168
169    /// Exact truncated Leibniz product `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`.
170    ///
171    /// # Codegen
172    ///
173    /// Each output entry's `2^m` subset sum is written as a compact straight-line
174    /// expression instead of the shared [`jet_algebra::leibniz_product`] subset
175    /// walker (which, per entry, builds `SlotBuf`s and `match`-dispatches the
176    /// `deriv` closure across all `2^m` subsets). The loop nest over `(i,j,k,l)`
177    /// is unchanged — only the inner per-entry sum is unrolled — so this does NOT
178    /// unroll over `K` and does NOT bloat code: on a `Tower4<9>` mul-and-read
179    /// consumer the new form is faster AND smaller (asm: 34 outlined walker `bl`
180    /// calls → 0, 21.1 KiB → 14.3 KiB, +100 NEON `.2d` ops).
181    ///
182    /// BIT-IDENTICAL to the walker: each entry's terms are in the walker's exact
183    /// subset-enumeration order (subset bit `b` ↔ position `b`, `sub = 0..2^m`),
184    /// and the per-entry `acc` accumulator mirrors the walker's `total = 0.0`
185    /// start so a signed-zero leading product collapses to `+0.0` identically —
186    /// which matters because real jets carry exact-`0.0` channels
187    /// (`constant`/`variable` towers). Proven `to_bits`-identical on
188    /// `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`, 5000 inputs each with ~30 %
189    /// exact-`0.0` channels and signed values (a no-leading-`0.0` form fails this
190    /// stress — the accumulator start is load-bearing).
191    pub fn mul(&self, o: &Self) -> Self {
192        let a = self;
193        let b = o;
194        let mut out = Self::zero();
195        out.v = a.v * b.v;
196        for i in 0..K {
197            // subsets of {i}: {} {i}
198            let mut acc = 0.0;
199            acc += a.v * b.g[i];
200            acc += a.g[i] * b.v;
201            out.g[i] = acc;
202        }
203        for i in 0..K {
204            for j in 0..K {
205                // subsets of {i,j}: {} {i} {j} {ij}
206                let mut acc = 0.0;
207                acc += a.v * b.h[i][j];
208                acc += a.g[i] * b.g[j];
209                acc += a.g[j] * b.g[i];
210                acc += a.h[i][j] * b.v;
211                out.h[i][j] = acc;
212            }
213        }
214        for i in 0..K {
215            for j in 0..K {
216                for k in 0..K {
217                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
218                    let mut acc = 0.0;
219                    acc += a.v * b.t3[i][j][k];
220                    acc += a.g[i] * b.h[j][k];
221                    acc += a.g[j] * b.h[i][k];
222                    acc += a.h[i][j] * b.g[k];
223                    acc += a.g[k] * b.h[i][j];
224                    acc += a.h[i][k] * b.g[j];
225                    acc += a.h[j][k] * b.g[i];
226                    acc += a.t3[i][j][k] * b.v;
227                    out.t3[i][j][k] = acc;
228                }
229            }
230        }
231        for i in 0..K {
232            for j in 0..K {
233                for k in 0..K {
234                    for l in 0..K {
235                        // subsets of {i,j,k,l} in bit order sub = 0..16
236                        let mut acc = 0.0;
237                        acc += a.v * b.t4[i][j][k][l];
238                        acc += a.g[i] * b.t3[j][k][l];
239                        acc += a.g[j] * b.t3[i][k][l];
240                        acc += a.h[i][j] * b.h[k][l];
241                        acc += a.g[k] * b.t3[i][j][l];
242                        acc += a.h[i][k] * b.h[j][l];
243                        acc += a.h[j][k] * b.h[i][l];
244                        acc += a.t3[i][j][k] * b.g[l];
245                        acc += a.g[l] * b.t3[i][j][k];
246                        acc += a.h[i][l] * b.h[j][k];
247                        acc += a.h[j][l] * b.h[i][k];
248                        acc += a.t3[i][j][l] * b.g[k];
249                        acc += a.h[k][l] * b.h[i][j];
250                        acc += a.t3[i][k][l] * b.g[j];
251                        acc += a.t3[j][k][l] * b.g[i];
252                        acc += a.t4[i][j][k][l] * b.v;
253                        out.t4[i][j][k][l] = acc;
254                    }
255                }
256            }
257        }
258        out
259    }
260
261    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
262    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
263    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
264    /// out of the borrowed operands.
265    pub fn add(&self, o: &Self) -> Self {
266        *self + *o
267    }
268
269    /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
270    pub fn sub(&self, o: &Self) -> Self {
271        *self + o.scale(-1.0)
272    }
273
274    /// Exact multivariate Faà di Bruno composition `f ∘ self`.
275    ///
276    /// `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` evaluated at `u = self.v` —
277    /// the SAME `[f64; 5]` stack shape the families' existing
278    /// `unary_derivatives_*` helpers produce, so those special-function
279    /// stacks (Φ, log-Φ, normal pdf, …) plug in directly.
280    ///
281    /// The order-m output sums over the set partitions of the m indices
282    /// (Bell(3) = 5 terms at order 3, Bell(4) = 15 at order 4), grouped by
283    /// block count: each partition into r blocks contributes
284    /// `f⁽ʳ⁾ · Π_blocks D_block(u)`.
285    ///
286    /// # Codegen
287    ///
288    /// Evaluated as a compact closed form (the Bell(4)=15 set-partitions of
289    /// `t4`, Bell(3)=5 of `t3`, …) instead of routing through the recursive
290    /// [`jet_algebra::faa_di_bruno`] walker (per-output `for_each_partition`
291    /// recursion + per-block `SlotBuf` + closure dispatch). The loop nest is
292    /// identical to the walker's (`for i,j,k,l`); only the per-entry partition
293    /// sum is straight-line, so this does NOT unroll over `K` and does NOT
294    /// bloat code — measured on a `Tower4<9>` compose-and-read consumer the new
295    /// form is both faster and SMALLER (asm: 94 outlined walker `bl` calls → 0,
296    /// 47.5 KiB → 16.7 KiB, +197 NEON `.2d` ops).
297    ///
298    /// BIT-IDENTICAL to the walker: each channel's terms are emitted in the
299    /// walker's exact partition-enumeration order, each term's block products
300    /// are left-associated exactly as the walker's `prod *= block`, and the
301    /// per-channel `acc` accumulator mirrors the walker's `total = 0.0` start
302    /// (so signed-zero products collapse to `+0.0` identically). The order-4
303    /// term sequence was generated from the walker's own enumeration. Proven
304    /// `to_bits`-identical on `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`,
305    /// 5000 random inputs each (zeroed / sign-varied stacks included).
306    pub fn compose_unary(&self, d: [f64; 5]) -> Self {
307        let mut out = Self::zero();
308        out.v = d[0];
309        for i in 0..K {
310            let mut acc = 0.0;
311            acc += d[1] * self.g[i];
312            out.g[i] = acc;
313        }
314        for i in 0..K {
315            for j in 0..K {
316                let mut acc = 0.0;
317                acc += d[1] * self.h[i][j];
318                acc += d[2] * self.g[i] * self.g[j];
319                out.h[i][j] = acc;
320            }
321        }
322        for i in 0..K {
323            for j in 0..K {
324                for k in 0..K {
325                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
326                    let mut acc = 0.0;
327                    acc += d[1] * self.t3[i][j][k];
328                    acc += d[2] * self.h[i][j] * self.g[k];
329                    acc += d[2] * self.h[i][k] * self.g[j];
330                    acc += d[2] * self.g[i] * self.h[j][k];
331                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
332                    out.t3[i][j][k] = acc;
333                }
334            }
335        }
336        for i in 0..K {
337            for j in 0..K {
338                for k in 0..K {
339                    for l in 0..K {
340                        // Bell(4)=15 partitions, walker enumeration order.
341                        let mut acc = 0.0;
342                        acc += d[1] * self.t4[i][j][k][l];
343                        acc += d[2] * self.t3[i][j][k] * self.g[l];
344                        acc += d[2] * self.t3[i][j][l] * self.g[k];
345                        acc += d[2] * self.h[i][j] * self.h[k][l];
346                        acc += d[3] * self.h[i][j] * self.g[k] * self.g[l];
347                        acc += d[2] * self.t3[i][k][l] * self.g[j];
348                        acc += d[2] * self.h[i][k] * self.h[j][l];
349                        acc += d[3] * self.h[i][k] * self.g[j] * self.g[l];
350                        acc += d[2] * self.h[i][l] * self.h[j][k];
351                        acc += d[2] * self.g[i] * self.t3[j][k][l];
352                        acc += d[3] * self.g[i] * self.h[j][k] * self.g[l];
353                        acc += d[3] * self.h[i][l] * self.g[j] * self.g[k];
354                        acc += d[3] * self.g[i] * self.h[j][l] * self.g[k];
355                        acc += d[3] * self.g[i] * self.g[j] * self.h[k][l];
356                        acc += d[4] * self.g[i] * self.g[j] * self.g[k] * self.g[l];
357                        out.t4[i][j][k][l] = acc;
358                    }
359                }
360            }
361        }
362        out
363    }
364
365    /// Compose with a unary special-function whose `[f64; 5]` derivative STACK is
366    /// built from the base value through `stack_fn` — the scalar arm of the
367    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
368    /// [`Tower4Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
369    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
370    /// `self.compose_unary(stack_fn(self.v))`. Writing a program against this seam
371    /// lets it re-instantiate, unchanged, at [`Tower4Lane`] (where each of the four
372    /// lanes carries a DISTINCT base value and `stack_fn` is re-run per lane).
373    #[inline]
374    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
375        self.compose_unary(stack_fn(self.v))
376    }
377
378    /// Single-active-slot fast path for [`Self::compose_unary`].
379    ///
380    /// When the inner jet `self` has derivative support ONLY on the all-`slot`
381    /// diagonal channels — i.e. it is a univariate jet in primary `slot`
382    /// scattered into the `K`-wide layout (`g[a] = 0`, `h[a][b] = 0`,
383    /// `t3 = 0`, `t4 = 0` for any axis `≠ slot`) — the multivariate Faà di
384    /// Bruno walk collapses. Every output channel whose axis tuple contains an
385    /// axis `≠ slot` is structurally `0`: each set-partition has a block
386    /// covering that axis, that block reads an off-`slot` derivative of `self`
387    /// (which is `0`), so the block product and the whole partition vanish, and
388    /// the channel sums to the walker's `total = 0.0` start, i.e. `+0.0`. Only
389    /// the five diagonal channels (`v`, `g[slot]`, `h[slot][slot]`,
390    /// `t3[slot]³`, `t4[slot]⁴`) survive.
391    ///
392    /// This computes exactly those five as STRAIGHT-LINE accumulations, each in
393    /// the EXACT term order of [`Self::compose_unary`]'s diagonal
394    /// (`i = j = k = l = slot`) case — so they are BIT-IDENTICAL to
395    /// [`Self::compose_unary`] on the diagonal — and leaves every other channel
396    /// at the zero-init `+0.0`, which the full walk also produces (the
397    /// off-`slot` collapse is `to_bits`-`+0.0`, signed-zero products included;
398    /// proven across `K ∈ {2,3,4,9}`, 5000 single-slot inputs each). At any
399    /// `K ≥ 2` this is far fewer floating-point operations than materialising
400    /// the full `1 + K + K² + K³ + K⁴` channel set whose off-diagonal entries
401    /// are all zero, and far cheaper than the recursive set-partition walker the
402    /// diagonal channels previously routed through (a measured ~9.5× speedup vs
403    /// the full `compose_unary`, recovering a 5.9× walker regression at the
404    /// `K ∈ {2,3}` BMS tower widths).
405    ///
406    /// `#[inline]` so an adopting consumer pays no `bl` call (uninlined, the
407    /// five-channel build does not amortise the call/spill overhead).
408    ///
409    /// # Precondition
410    ///
411    /// The caller guarantees the single-active-slot structure. If it does not
412    /// hold, the off-`slot` channels would be wrongly zeroed; use the full
413    /// [`Self::compose_unary`] in that case.
414    #[inline]
415    pub fn compose_unary_single_slot(&self, d: [f64; 5], slot: usize) -> Self {
416        let mut out = Self::zero();
417        let s = slot;
418        let g = self.g[s];
419        let h = self.h[s][s];
420        let t3 = self.t3[s][s][s];
421        let t4 = self.t4[s][s][s][s];
422        out.v = d[0];
423        // g (i=s): d1*g
424        out.g[s] = {
425            let mut acc = 0.0;
426            acc += d[1] * g;
427            acc
428        };
429        // h (i=j=s): d1*h + d2*g*g
430        out.h[s][s] = {
431            let mut acc = 0.0;
432            acc += d[1] * h;
433            acc += d[2] * g * g;
434            acc
435        };
436        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
437        out.t3[s][s][s] = {
438            let mut acc = 0.0;
439            acc += d[1] * t3;
440            acc += d[2] * h * g;
441            acc += d[2] * h * g;
442            acc += d[2] * g * h;
443            acc += d[3] * g * g * g;
444            acc
445        };
446        // t4 (i=j=k=l=s): exact term order of compose_unary's inner loop.
447        out.t4[s][s][s][s] = {
448            let mut acc = 0.0;
449            acc += d[1] * t4;
450            acc += d[2] * t3 * g;
451            acc += d[2] * t3 * g;
452            acc += d[2] * h * h;
453            acc += d[3] * h * g * g;
454            acc += d[2] * t3 * g;
455            acc += d[2] * h * h;
456            acc += d[3] * h * g * g;
457            acc += d[2] * h * h;
458            acc += d[2] * g * t3;
459            acc += d[3] * g * h * g;
460            acc += d[3] * h * g * g;
461            acc += d[3] * g * h * g;
462            acc += d[3] * g * g * h;
463            acc += d[4] * g * g * g * g;
464            acc
465        };
466        out
467    }
468
469    /// Multiply every channel by a plain scalar.
470    pub fn scale(&self, s: f64) -> Self {
471        let mut out = *self;
472        out.v *= s;
473        for i in 0..K {
474            out.g[i] *= s;
475            for j in 0..K {
476                out.h[i][j] *= s;
477                for k in 0..K {
478                    out.t3[i][j][k] *= s;
479                    for l in 0..K {
480                        out.t4[i][j][k][l] *= s;
481                    }
482                }
483            }
484        }
485        out
486    }
487
488    /// e^self.
489    pub fn exp(&self) -> Self {
490        let e = self.v.exp();
491        self.compose_unary([e, e, e, e, e])
492    }
493
494    /// ln(self). Caller guarantees positivity (likelihood programs do).
495    pub fn ln(&self) -> Self {
496        let u = self.v;
497        let r = 1.0 / u;
498        self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
499    }
500
501    /// 1/self.
502    pub fn recip(&self) -> Self {
503        let r = 1.0 / self.v;
504        let r2 = r * r;
505        self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
506    }
507
508    /// √self. Caller guarantees positivity.
509    pub fn sqrt(&self) -> Self {
510        let u = self.v;
511        let s = u.sqrt();
512        self.compose_unary([
513            s,
514            0.5 / s,
515            -0.25 / (u * s),
516            0.375 / (u * u * s),
517            -0.9375 / (u * u * u * s),
518        ])
519    }
520
521    /// self^a for real exponent `a`. Caller guarantees a positive base.
522    pub fn powf(&self, a: f64) -> Self {
523        let u = self.v;
524        let f0 = u.powf(a);
525        let f1 = a * u.powf(a - 1.0);
526        let f2 = a * (a - 1.0) * u.powf(a - 2.0);
527        let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
528        let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
529        self.compose_unary([f0, f1, f2, f3, f4])
530    }
531
532    /// ln Γ(self). Caller guarantees positivity.
533    pub fn ln_gamma(&self) -> Self {
534        self.compose_unary(ln_gamma_derivative_stack(self.v))
535    }
536
537    /// ψ(self), the digamma function. Caller guarantees positivity.
538    pub fn digamma(&self) -> Self {
539        self.compose_unary(digamma_derivative_stack(self.v))
540    }
541
542    /// ψ′(self), the trigamma function. Caller guarantees positivity.
543    pub fn trigamma(&self) -> Self {
544        self.compose_unary(trigamma_derivative_stack(self.v))
545    }
546
547    /// Contract `t3` with one primary-space direction:
548    /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]` — exactly the
549    /// `row_third_contracted` shape.
550    ///
551    /// The output is symmetric in `(a, b)`: `t3` is fully index-symmetric, so
552    /// `t3[a][b][c] == t3[b][a][c]` and the `Σ_c` contraction gives
553    /// `out[a][b] == out[b][a]` term-for-term, in the same `c` order. We compute
554    /// only the upper triangle `a ≤ b` (the inner contraction is unchanged and
555    /// stays contiguous/vectorisable) and mirror into the lower triangle — this
556    /// is BIT-IDENTICAL to the full `a, b ∈ 0..K` nest while doing ~2× fewer
557    /// inner contractions, with no dense scatter (the mirror is a `K × K` copy).
558    pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
559        let mut out = [[0.0; K]; K];
560        for a in 0..K {
561            for b in a..K {
562                let mut acc = 0.0;
563                for c in 0..K {
564                    acc += self.t3[a][b][c] * dir[c];
565                }
566                out[a][b] = acc;
567                out[b][a] = acc;
568            }
569        }
570        out
571    }
572
573    /// Contract `t4` with two primary-space directions:
574    /// `out[a][b] = Σ_{c,d} t4[a][b][c][d] · u[c] · v[d]` — exactly the
575    /// `row_fourth_contracted` shape.
576    ///
577    /// As in [`Self::third_contracted`], the output is symmetric in `(i, j)`
578    /// (`t4[j][i][k][l] == t4[i][j][k][l]`, contracted in the same `(k, l)`
579    /// order), so the upper triangle `i ≤ j` is computed and mirrored —
580    /// BIT-IDENTICAL to the full nest, ~2× fewer inner `Σ_{k,l}` contractions,
581    /// and the inner double loop stays the original contiguous/vectorisable form.
582    pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
583        let mut out = [[0.0; K]; K];
584        for i in 0..K {
585            for j in i..K {
586                let mut acc = 0.0;
587                for k in 0..K {
588                    for l in 0..K {
589                        acc += self.t4[i][j][k][l] * u[k] * w[l];
590                    }
591                }
592                out[i][j] = acc;
593                out[j][i] = acc;
594            }
595        }
596        out
597    }
598}
599
600impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
601    #[inline]
602    fn derivative(&self, labels: &[usize]) -> f64 {
603        self.deriv(labels)
604    }
605
606    fn map_derivatives<F>(&self, mut f: F) -> Self
607    where
608        F: FnMut(&[usize]) -> f64,
609    {
610        let mut out = Self::zero();
611        out.v = f(&[]);
612        for i in 0..K {
613            let labels = [i];
614            out.g[i] = f(&labels);
615        }
616        for i in 0..K {
617            for j in 0..K {
618                let labels = [i, j];
619                out.h[i][j] = f(&labels);
620            }
621        }
622        for i in 0..K {
623            for j in 0..K {
624                for k in 0..K {
625                    let labels = [i, j, k];
626                    out.t3[i][j][k] = f(&labels);
627                }
628            }
629        }
630        for i in 0..K {
631            for j in 0..K {
632                for k in 0..K {
633                    for l in 0..K {
634                        let labels = [i, j, k, l];
635                        out.t4[i][j][k][l] = f(&labels);
636                    }
637                }
638            }
639        }
640        out
641    }
642}
643
644/// Truncated SECOND-order multivariate Taylor scalar in `K` variables.
645///
646/// This is the value/gradient/Hessian-only sibling of [`Tower4`]. Every
647/// channel it carries (`v`, `g`, `h`) is computed by the SAME formulas
648/// [`Tower4`] uses for those orders, so for any program written over both
649/// towers the order-≤2 outputs are *bit-identical*: the order-2 Leibniz and
650/// Faà-di-Bruno terms read only the order-≤2 channels of their inputs (see
651/// [`Tower4::mul`] / [`Tower4::compose_unary`] — `out.h` never touches `t3`
652/// or `t4`), so dropping the third/fourth tensors cannot perturb the value,
653/// gradient, or Hessian.
654///
655/// It exists purely for performance: an inner Newton step (and the
656/// value-only ρ-homotopy pre-warm) needs at most curvature, never the
657/// outer-κ/ψ third/fourth derivatives. Evaluating a row likelihood over
658/// `Tower2` skips the `K⁴` fourth-tensor product/composition arithmetic that
659/// dominates the cold marginal-slope fit, while returning the exact same
660/// `(v, g, h)`.
661#[derive(Clone, Copy, Debug)]
662pub struct Tower2<const K: usize> {
663    /// Value ℓ.
664    pub v: f64,
665    /// Gradient ∂ℓ/∂p_a.
666    pub g: [f64; K],
667    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
668    pub h: [[f64; K]; K],
669}
670
671impl<const K: usize> Tower2<K> {
672    /// The additive identity.
673    pub fn zero() -> Self {
674        Self {
675            v: 0.0,
676            g: [0.0; K],
677            h: [[0.0; K]; K],
678        }
679    }
680
681    /// A constant: value `c`, all derivatives zero.
682    pub fn constant(c: f64) -> Self {
683        let mut out = Self::zero();
684        out.v = c;
685        out
686    }
687
688    /// The seeded variable `p_idx` with current value `value`:
689    /// unit first derivative in slot `idx`, zero elsewhere and above.
690    pub fn variable(value: f64, idx: usize) -> Self {
691        let mut out = Self::constant(value);
692        out.g[idx] = 1.0;
693        out
694    }
695
696    /// Read the derivative tensor entry whose differentiation axes are
697    /// `labels` (length 0..=2): value, `g`, `h`.
698    #[inline]
699    fn deriv(&self, labels: &[usize]) -> f64 {
700        assert!(
701            labels.len() <= 2,
702            "Tower2 carries at most second-order derivatives"
703        );
704        match labels.len() {
705            0 => self.v,
706            1 => self.g[labels[0]],
707            _ => self.h[labels[0]][labels[1]],
708        }
709    }
710
711    /// Exact truncated (order ≤ 2) Leibniz product. The `v`/`g`/`h` channels
712    /// match [`Tower4::mul`] term-for-term.
713    pub fn mul(&self, o: &Self) -> Self {
714        let a = self;
715        let b = o;
716        let mut out = Self::zero();
717        out.v = a.v * b.v;
718        for i in 0..K {
719            out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
720        }
721        for i in 0..K {
722            for j in 0..K {
723                out.h[i][j] = a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
724            }
725        }
726        out
727    }
728
729    /// Exact (order ≤ 2) multivariate Faà di Bruno composition `f ∘ self`.
730    ///
731    /// `d = [f(u), f′(u), f″(u)]` evaluated at `u = self.v`. The `v`/`g`/`h`
732    /// channels match [`Tower4::compose_unary`] term-for-term (which uses only
733    /// `d[0..=2]` for those orders), so this is a strict truncation, not an
734    /// approximation. The full-order `[f64; 5]` derivative stacks the families
735    /// already produce can be passed by slicing their first three entries.
736    ///
737    /// # Codegen
738    ///
739    /// Order-≤2 Faà di Bruno is a tiny closed form, so this evaluates it
740    /// directly instead of routing through the generic
741    /// [`jet_algebra::faa_di_bruno`] set-partition walker (recursion + per-block
742    /// closure dispatch). That matters because this is the kernel under EVERY
743    /// packed scalar — [`crate::jet_scalar::Order2`] / `OneSeed` / `TwoSeed`
744    /// composition all bottom out here — so the straight-line form (whose inner
745    /// loops auto-vectorise to NEON/SSE 2-wide and which emits zero outlined
746    /// walker calls) lifts all of them at once.
747    ///
748    /// The term and accumulation order is BIT-IDENTICAL to the walker it
749    /// replaces: each output channel mirrors the walker's `total = 0.0` start
750    /// (the explicit `acc` accumulator), so a signed-zero product collapses to
751    /// `+0.0` exactly as `total += prod` does. Proven `to_bits`-identical on
752    /// `v`/`g`/`h` across `K ∈ {2,3,4,9}`, 5000 random inputs each (incl.
753    /// zeroed / sign-varied stacks). The order-≤2 walker partitions are:
754    ///   `g[i]`   = `f′·u_i`                   (single block `{i}`)
755    ///   `h[i][j]` = `f′·u_ij + (f″·u_i)·u_j`  (blocks `{ij}` then `{i}{j}`),
756    /// with `f′ = d[1]`, `f″ = d[2]`, `u_* = self.{g,h}`.
757    pub fn compose_unary(&self, d: [f64; 3]) -> Self {
758        let mut out = Self::zero();
759        out.v = d[0];
760        for i in 0..K {
761            let mut acc = 0.0;
762            acc += d[1] * self.g[i];
763            out.g[i] = acc;
764        }
765        for i in 0..K {
766            for j in 0..K {
767                let mut acc = 0.0;
768                acc += d[1] * self.h[i][j];
769                acc += d[2] * self.g[i] * self.g[j];
770                out.h[i][j] = acc;
771            }
772        }
773        out
774    }
775
776    /// Multiply every channel by a plain scalar.
777    pub fn scale(&self, s: f64) -> Self {
778        let mut out = *self;
779        out.v *= s;
780        for i in 0..K {
781            out.g[i] *= s;
782            for j in 0..K {
783                out.h[i][j] *= s;
784            }
785        }
786        out
787    }
788
789    /// e^self.
790    pub fn exp(&self) -> Self {
791        let e = self.v.exp();
792        self.compose_unary([e, e, e])
793    }
794
795    /// √self. Caller guarantees positivity.
796    pub fn sqrt(&self) -> Self {
797        let u = self.v;
798        let s = u.sqrt();
799        self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
800    }
801}
802
803impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
804    #[inline]
805    fn derivative(&self, labels: &[usize]) -> f64 {
806        self.deriv(labels)
807    }
808
809    fn map_derivatives<F>(&self, mut f: F) -> Self
810    where
811        F: FnMut(&[usize]) -> f64,
812    {
813        let mut out = Self::zero();
814        out.v = f(&[]);
815        for i in 0..K {
816            let labels = [i];
817            out.g[i] = f(&labels);
818        }
819        for i in 0..K {
820            for j in 0..K {
821                let labels = [i, j];
822                out.h[i][j] = f(&labels);
823            }
824        }
825        out
826    }
827}
828
829impl<const K: usize> std::ops::Add for Tower2<K> {
830    type Output = Self;
831    fn add(self, o: Self) -> Self {
832        let mut out = self;
833        out.v += o.v;
834        for i in 0..K {
835            out.g[i] += o.g[i];
836            for j in 0..K {
837                out.h[i][j] += o.h[i][j];
838            }
839        }
840        out
841    }
842}
843
844impl<const K: usize> std::ops::Mul for Tower2<K> {
845    type Output = Self;
846    fn mul(self, o: Self) -> Self {
847        Tower2::mul(&self, &o)
848    }
849}
850
851impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
852    type Output = Self;
853    fn add(self, c: f64) -> Self {
854        let mut out = self;
855        out.v += c;
856        out
857    }
858}
859
860impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
861    type Output = Self;
862    fn mul(self, c: f64) -> Self {
863        self.scale(c)
864    }
865}
866
867/// Truncated THIRD-order multivariate Taylor scalar in `K` variables.
868///
869/// The value/gradient/Hessian/third-derivative sibling of [`Tower4`], standing
870/// between [`Tower2`] and [`Tower4`]. Every channel it carries (`v`, `g`, `h`,
871/// `t3`) is computed by the SAME shared Leibniz / Faà-di-Bruno kernels
872/// [`Tower4`] uses for those orders, and the order-≤3 terms of those kernels
873/// read only the order-≤3 channels of their inputs (the order-3 Faà-di-Bruno
874/// partitions never reach the f⁗ stack slot or the inner `t4` tensor — see
875/// [`Tower4::compose_unary`]). So for any program written over both towers the
876/// order-≤3 outputs are *bit-identical*: dropping the fourth tensor cannot
877/// perturb the value, gradient, Hessian, or third derivatives.
878///
879/// It exists purely for performance, exactly like [`Tower2`]: a consumer that
880/// needs up to third derivatives (the survival location-scale row kernel reads
881/// `g`, the diagonal `h`, and the diagonal `t3`, but never `t4`) pays the
882/// `K³` third-tensor arithmetic but skips the `K⁴` fourth-tensor
883/// product/composition that otherwise dominates the per-row cost.
884#[derive(Clone, Copy, Debug)]
885pub struct Tower3<const K: usize> {
886    /// Value ℓ.
887    pub v: f64,
888    /// Gradient ∂ℓ/∂p_a.
889    pub g: [f64; K],
890    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
891    pub h: [[f64; K]; K],
892    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
893    pub t3: [[[f64; K]; K]; K],
894}
895
896impl<const K: usize> Tower3<K> {
897    /// The additive identity.
898    pub fn zero() -> Self {
899        Self {
900            v: 0.0,
901            g: [0.0; K],
902            h: [[0.0; K]; K],
903            t3: [[[0.0; K]; K]; K],
904        }
905    }
906
907    /// A constant: value `c`, all derivatives zero.
908    pub fn constant(c: f64) -> Self {
909        let mut out = Self::zero();
910        out.v = c;
911        out
912    }
913
914    /// The seeded variable `p_idx` with current value `value`:
915    /// unit first derivative in slot `idx`, zero elsewhere and above.
916    pub fn variable(value: f64, idx: usize) -> Self {
917        let mut out = Self::constant(value);
918        out.g[idx] = 1.0;
919        out
920    }
921
922    /// Read the (fully symmetric) derivative tensor entry whose differentiation
923    /// axes are `labels` (length 0..=3): value, `g`, `h`, `t3`.
924    #[inline]
925    fn deriv(&self, labels: &[usize]) -> f64 {
926        assert!(
927            labels.len() <= 3,
928            "Tower3 carries at most third-order derivatives"
929        );
930        match labels.len() {
931            0 => self.v,
932            1 => self.g[labels[0]],
933            2 => self.h[labels[0]][labels[1]],
934            _ => self.t3[labels[0]][labels[1]][labels[2]],
935        }
936    }
937
938    /// Exact truncated (order ≤ 3) Leibniz product. The `v`/`g`/`h`/`t3`
939    /// channels match [`Tower4::mul`] term-for-term.
940    ///
941    /// # Codegen
942    ///
943    /// Straight-line per-entry subset sums instead of the
944    /// [`jet_algebra::leibniz_product`] walker — the order-≤3 sibling of
945    /// [`Tower4::mul`] (no `t4`). Loop nest unchanged, no unroll over `K`, no
946    /// code bloat; auto-vectorises. BIT-IDENTICAL: terms in the walker's exact
947    /// subset order with an `acc = 0.0` accumulator start (load-bearing for the
948    /// signed-zero leading product on exact-`0.0` jet channels). Proven
949    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
950    /// zero/sign-stressed inputs each (these channel formulas are exactly the
951    /// `g`/`h`/`t3` of the [`Tower4::mul`] oracle, which passes that stress).
952    pub fn mul(&self, o: &Self) -> Self {
953        let a = self;
954        let b = o;
955        let mut out = Self::zero();
956        out.v = a.v * b.v;
957        for i in 0..K {
958            let mut acc = 0.0;
959            acc += a.v * b.g[i];
960            acc += a.g[i] * b.v;
961            out.g[i] = acc;
962        }
963        for i in 0..K {
964            for j in 0..K {
965                let mut acc = 0.0;
966                acc += a.v * b.h[i][j];
967                acc += a.g[i] * b.g[j];
968                acc += a.g[j] * b.g[i];
969                acc += a.h[i][j] * b.v;
970                out.h[i][j] = acc;
971            }
972        }
973        for i in 0..K {
974            for j in 0..K {
975                for k in 0..K {
976                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
977                    let mut acc = 0.0;
978                    acc += a.v * b.t3[i][j][k];
979                    acc += a.g[i] * b.h[j][k];
980                    acc += a.g[j] * b.h[i][k];
981                    acc += a.h[i][j] * b.g[k];
982                    acc += a.g[k] * b.h[i][j];
983                    acc += a.h[i][k] * b.g[j];
984                    acc += a.h[j][k] * b.g[i];
985                    acc += a.t3[i][j][k] * b.v;
986                    out.t3[i][j][k] = acc;
987                }
988            }
989        }
990        out
991    }
992
993    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
994    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
995    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
996    /// out of the borrowed operands.
997    pub fn add(&self, o: &Self) -> Self {
998        *self + *o
999    }
1000
1001    /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
1002    pub fn sub(&self, o: &Self) -> Self {
1003        *self + o.scale(-1.0)
1004    }
1005
1006    /// Exact (order ≤ 3) multivariate Faà di Bruno composition `f ∘ self`.
1007    ///
1008    /// `d = [f(u), f′(u), f″(u), f‴(u)]` evaluated at `u = self.v`. The
1009    /// `v`/`g`/`h`/`t3` channels match [`Tower4::compose_unary`] term-for-term
1010    /// (which uses only `d[0..=3]` for those orders), so this is a strict
1011    /// truncation, not an approximation. The full-order `[f64; 5]` derivative
1012    /// stacks the families already produce can be passed by slicing their first
1013    /// four entries.
1014    ///
1015    /// # Codegen
1016    ///
1017    /// Order-≤3 Faà di Bruno written as a compact closed form instead of the
1018    /// recursive [`jet_algebra::faa_di_bruno`] walker — the order-≤2 sibling of
1019    /// [`Tower4::compose_unary`], one tensor order shallower. The loop nest is
1020    /// unchanged (no unroll over `K`, no code bloat: measured on a `Tower3<9>`
1021    /// compose-and-read consumer the new form is faster and SMALLER — asm: 71
1022    /// walker `bl` calls → 0, 39.5 KiB → 13.9 KiB, +197 NEON `.2d` ops).
1023    /// BIT-IDENTICAL: terms in the walker's exact partition order, left-
1024    /// associated block products, `acc = 0.0` accumulator start. Proven
1025    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
1026    /// random inputs each.
1027    pub fn compose_unary(&self, d: [f64; 4]) -> Self {
1028        let mut out = Self::zero();
1029        out.v = d[0];
1030        for i in 0..K {
1031            let mut acc = 0.0;
1032            acc += d[1] * self.g[i];
1033            out.g[i] = acc;
1034        }
1035        for i in 0..K {
1036            for j in 0..K {
1037                let mut acc = 0.0;
1038                acc += d[1] * self.h[i][j];
1039                acc += d[2] * self.g[i] * self.g[j];
1040                out.h[i][j] = acc;
1041            }
1042        }
1043        for i in 0..K {
1044            for j in 0..K {
1045                for k in 0..K {
1046                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
1047                    let mut acc = 0.0;
1048                    acc += d[1] * self.t3[i][j][k];
1049                    acc += d[2] * self.h[i][j] * self.g[k];
1050                    acc += d[2] * self.h[i][k] * self.g[j];
1051                    acc += d[2] * self.g[i] * self.h[j][k];
1052                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
1053                    out.t3[i][j][k] = acc;
1054                }
1055            }
1056        }
1057        out
1058    }
1059
1060    /// Compose with a unary special-function whose `[f64; 4]` derivative STACK is
1061    /// built from the base value through `stack_fn` — the scalar arm of the
1062    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
1063    /// [`Tower3Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
1064    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
1065    /// `self.compose_unary(stack_fn(self.v))`. The order-≤3 sibling of
1066    /// [`Tower4::compose_unary_with`].
1067    #[inline]
1068    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
1069        self.compose_unary(stack_fn(self.v))
1070    }
1071
1072    /// Single-active-slot fast path for [`Self::compose_unary`] — the order-≤3
1073    /// sibling of [`Tower4::compose_unary_single_slot`]. When `self` carries
1074    /// derivative support only on the all-`slot` diagonal, every output channel
1075    /// touching an axis `≠ slot` collapses to the walker's `total = 0.0` start
1076    /// (`+0.0`), so only `v`, `g[slot]`, `h[slot][slot]`, `t3[slot]³` survive.
1077    /// These four are computed as STRAIGHT-LINE accumulations, each in the EXACT
1078    /// term order of [`Self::compose_unary`]'s diagonal (`i = j = k = slot`)
1079    /// case (BIT-IDENTICAL to the full path on the diagonal); off-`slot`
1080    /// channels stay at the zero-init `+0.0` the full walk also yields (proven
1081    /// `to_bits` across `K ∈ {2,3,4,9}`). This drops the recursive
1082    /// set-partition walker the diagonal channels previously routed through,
1083    /// recovering its measured ~5.9× regression at the `K ∈ {2,3}` BMS tower
1084    /// widths. Caller guarantees the single-slot precondition; otherwise use
1085    /// [`Self::compose_unary`].
1086    #[inline]
1087    pub fn compose_unary_single_slot(&self, d: [f64; 4], slot: usize) -> Self {
1088        let mut out = Self::zero();
1089        let s = slot;
1090        let g = self.g[s];
1091        let h = self.h[s][s];
1092        let t3 = self.t3[s][s][s];
1093        out.v = d[0];
1094        // g (i=s): d1*g
1095        out.g[s] = {
1096            let mut acc = 0.0;
1097            acc += d[1] * g;
1098            acc
1099        };
1100        // h (i=j=s): d1*h + d2*g*g
1101        out.h[s][s] = {
1102            let mut acc = 0.0;
1103            acc += d[1] * h;
1104            acc += d[2] * g * g;
1105            acc
1106        };
1107        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
1108        out.t3[s][s][s] = {
1109            let mut acc = 0.0;
1110            acc += d[1] * t3;
1111            acc += d[2] * h * g;
1112            acc += d[2] * h * g;
1113            acc += d[2] * g * h;
1114            acc += d[3] * g * g * g;
1115            acc
1116        };
1117        out
1118    }
1119
1120    /// Multiply every channel by a plain scalar.
1121    pub fn scale(&self, s: f64) -> Self {
1122        let mut out = *self;
1123        out.v *= s;
1124        for i in 0..K {
1125            out.g[i] *= s;
1126            for j in 0..K {
1127                out.h[i][j] *= s;
1128                for k in 0..K {
1129                    out.t3[i][j][k] *= s;
1130                }
1131            }
1132        }
1133        out
1134    }
1135}
1136
1137impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
1138    #[inline]
1139    fn derivative(&self, labels: &[usize]) -> f64 {
1140        self.deriv(labels)
1141    }
1142
1143    fn map_derivatives<F>(&self, mut f: F) -> Self
1144    where
1145        F: FnMut(&[usize]) -> f64,
1146    {
1147        let mut out = Self::zero();
1148        out.v = f(&[]);
1149        for i in 0..K {
1150            let labels = [i];
1151            out.g[i] = f(&labels);
1152        }
1153        for i in 0..K {
1154            for j in 0..K {
1155                let labels = [i, j];
1156                out.h[i][j] = f(&labels);
1157            }
1158        }
1159        for i in 0..K {
1160            for j in 0..K {
1161                for k in 0..K {
1162                    let labels = [i, j, k];
1163                    out.t3[i][j][k] = f(&labels);
1164                }
1165            }
1166        }
1167        out
1168    }
1169}
1170
1171impl<const K: usize> std::ops::Add for Tower3<K> {
1172    type Output = Self;
1173    fn add(self, o: Self) -> Self {
1174        let mut out = self;
1175        out.v += o.v;
1176        for i in 0..K {
1177            out.g[i] += o.g[i];
1178            for j in 0..K {
1179                out.h[i][j] += o.h[i][j];
1180                for k in 0..K {
1181                    out.t3[i][j][k] += o.t3[i][j][k];
1182                }
1183            }
1184        }
1185        out
1186    }
1187}
1188
1189pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
1190    [
1191        statrs::function::gamma::ln_gamma(x),
1192        digamma_positive(x),
1193        polygamma_positive(1, x),
1194        polygamma_positive(2, x),
1195        polygamma_positive(3, x),
1196    ]
1197}
1198
1199pub fn ln_gamma_derivative_stack_order2(x: f64) -> [f64; 3] {
1200    [
1201        statrs::function::gamma::ln_gamma(x),
1202        digamma_positive(x),
1203        polygamma_positive(1, x),
1204    ]
1205}
1206
1207pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
1208    [
1209        digamma_positive(x),
1210        polygamma_positive(1, x),
1211        polygamma_positive(2, x),
1212        polygamma_positive(3, x),
1213        polygamma_positive(4, x),
1214    ]
1215}
1216
1217pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
1218    [
1219        polygamma_positive(1, x),
1220        polygamma_positive(2, x),
1221        polygamma_positive(3, x),
1222        polygamma_positive(4, x),
1223        polygamma_positive(5, x),
1224    ]
1225}
1226
1227fn digamma_positive(mut x: f64) -> f64 {
1228    if !(x.is_finite() && x > 0.0) {
1229        return f64::NAN;
1230    }
1231    let mut acc = 0.0;
1232    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1233        acc -= 1.0 / x;
1234        x += 1.0;
1235    }
1236    acc + digamma_asymptotic(x)
1237}
1238
1239fn polygamma_positive(order: usize, mut x: f64) -> f64 {
1240    if !(x.is_finite() && x > 0.0) {
1241        return f64::NAN;
1242    }
1243    let mut acc = 0.0;
1244    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1245        acc += polygamma_recurrence_term(order, x);
1246        x += 1.0;
1247    }
1248    acc + polygamma_asymptotic(order, x)
1249}
1250
1251const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
1252const BERNOULLI_EVEN: [(usize, f64); 10] = [
1253    (2, 1.0 / 6.0),
1254    (4, -1.0 / 30.0),
1255    (6, 1.0 / 42.0),
1256    (8, -1.0 / 30.0),
1257    (10, 5.0 / 66.0),
1258    (12, -691.0 / 2730.0),
1259    (14, 7.0 / 6.0),
1260    (16, -3617.0 / 510.0),
1261    (18, 43867.0 / 798.0),
1262    (20, -174611.0 / 330.0),
1263];
1264
1265fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
1266    let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1267    sign * factorial(order) / x.powi((order + 1) as i32)
1268}
1269
1270fn digamma_asymptotic(x: f64) -> f64 {
1271    let mut out = x.ln() - 0.5 / x;
1272    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1273        out -= bernoulli / (bernoulli_order as f64 * x.powi(bernoulli_order as i32));
1274    }
1275    out
1276}
1277
1278fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
1279    if !(1..=5).contains(&order) {
1280        return f64::NAN;
1281    }
1282
1283    let order_factorial = factorial(order);
1284    let leading_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1285    let mut out = leading_sign * factorial(order - 1) / x.powi(order as i32)
1286        + leading_sign * order_factorial / (2.0 * x.powi((order + 1) as i32));
1287
1288    let bernoulli_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1289    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1290        let rising = rising_factorial(bernoulli_order, order);
1291        out += bernoulli_sign * bernoulli * rising
1292            / bernoulli_order as f64
1293            / x.powi((bernoulli_order + order) as i32);
1294    }
1295    out
1296}
1297
1298fn factorial(n: usize) -> f64 {
1299    (1..=n).fold(1.0, |acc, k| acc * k as f64)
1300}
1301
1302fn rising_factorial(start: usize, len: usize) -> f64 {
1303    (start..start + len).fold(1.0, |acc, k| acc * k as f64)
1304}
1305
1306impl<const K: usize> std::ops::Add for Tower4<K> {
1307    type Output = Self;
1308    fn add(self, o: Self) -> Self {
1309        let mut out = self;
1310        out.v += o.v;
1311        for i in 0..K {
1312            out.g[i] += o.g[i];
1313            for j in 0..K {
1314                out.h[i][j] += o.h[i][j];
1315                for k in 0..K {
1316                    out.t3[i][j][k] += o.t3[i][j][k];
1317                    for l in 0..K {
1318                        out.t4[i][j][k][l] += o.t4[i][j][k][l];
1319                    }
1320                }
1321            }
1322        }
1323        out
1324    }
1325}
1326
1327impl<const K: usize> std::ops::Sub for Tower4<K> {
1328    type Output = Self;
1329    fn sub(self, o: Self) -> Self {
1330        self + o.scale(-1.0)
1331    }
1332}
1333
1334impl<const K: usize> std::ops::Neg for Tower4<K> {
1335    type Output = Self;
1336    fn neg(self) -> Self {
1337        self.scale(-1.0)
1338    }
1339}
1340
1341impl<const K: usize> std::ops::Mul for Tower4<K> {
1342    type Output = Self;
1343    fn mul(self, o: Self) -> Self {
1344        Tower4::mul(&self, &o)
1345    }
1346}
1347
1348impl<const K: usize> std::ops::Div for Tower4<K> {
1349    type Output = Self;
1350    fn div(self, o: Self) -> Self {
1351        Tower4::mul(&self, &o.recip())
1352    }
1353}
1354
1355impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
1356    type Output = Self;
1357    fn add(self, c: f64) -> Self {
1358        let mut out = self;
1359        out.v += c;
1360        out
1361    }
1362}
1363
1364impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
1365    type Output = Self;
1366    fn sub(self, c: f64) -> Self {
1367        self + (-c)
1368    }
1369}
1370
1371impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
1372    type Output = Self;
1373    fn mul(self, c: f64) -> Self {
1374        self.scale(c)
1375    }
1376}
1377
1378// ── Implicit-function and moving-boundary seams (#932 flex) ──────────
1379//
1380// The flexible survival marginal-slope row loss is NOT a free composition
1381// of the primaries: it threads an IMPLICIT calibration intercept `a(θ)`
1382// solving a constraint `F(a, θ) = 0`, and integrates a density over cells
1383// whose edges `z_L(θ), z_R(θ)` MOVE with θ through that intercept. Plain
1384// `Tower4` Faà di Bruno cannot express either — so the flex tower was the
1385// last hand-written one in the codebase, and the genus of #736-class
1386// drift bugs (the (g,w0) deviation-cross third was 3× short for exactly
1387// this reason). These two combinators close that gap: once the constraint
1388// `F` and the integrand/boundaries are themselves towers, the intercept's
1389// derivative tower and the integral's derivative tower come out EXACTLY at
1390// every order — there is no order left to hand-code and forget.
1391
1392/// Solve the implicit relation `F(a(θ), θ) ≡ 0` for the intercept tower
1393/// `a(θ)` over the `K` primaries θ, given the constraint tower `f` written
1394/// over `K + 1` variables (slot `0` is the intercept `a`, slots `1..=K`
1395/// are the primaries θ) evaluated at the SOLVED point — i.e. `f.v` is the
1396/// constraint residual at `(a₀, θ₀)` (≈ 0 from the production Newton solve)
1397/// and `a0` is that solved intercept value.
1398///
1399/// Returns the `Tower4<K>` whose value is `a0` and whose every derivative
1400/// tensor (∂a/∂θ, ∂²a/∂θ², …, ∂⁴a/∂θ⁴) is the exact implicit-function
1401/// derivative. This is the mechanical replacement for the hand-coded
1402/// `a_u = -f_u/f_a`, `a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a`
1403/// recursion (first_full.rs) and its third/fourth-order continuations.
1404///
1405/// Method: order-by-order substitution. We build `a` incrementally; at each
1406/// order `m` the composite `G(θ) = f(a(θ), θ)` has a top-order coefficient
1407/// that is linear in `a`'s order-`m` tensor with leading factor `F_a`
1408/// (= `f.g[0]`), plus terms in `a`'s lower orders already fixed. Setting the
1409/// order-`m` tensor of `a` to cancel the rest of `G`'s order-`m` coefficient
1410/// keeps `G ≡ 0` through that order. The substitution `G = f∘(a, θ)` reuses
1411/// only the exact [`substitute_intercept`] chain rule, so the recursion is
1412/// auditable and exact, not a hand-expanded formula per order.
1413///
1414/// `f.g[0]` (= ∂F/∂a) must be non-zero — guaranteed by the production
1415/// solve's strict monotonicity guard.
1416///
1417/// The expansion point `a0` must be a genuine root `F(a0, θ0) = 0`: the
1418/// substitution recursion below cancels orders 1..=4 of `G = F∘a` but never
1419/// touches order 0, so a non-root `a0` would yield the Taylor expansion of
1420/// the LEVEL SET `F = F(a0)` through `a0`, not the root curve `F = 0`. This
1421/// is guarded explicitly and re-verified by a composed-residual self-check.
1422pub fn implicit_solve<const K1: usize, const K: usize>(
1423    f: &Tower4<K1>,
1424    a0: f64,
1425) -> Result<Tower4<K>, String> {
1426    assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
1427    let f_a = f.g[0];
1428    if f_a == 0.0 || !f_a.is_finite() {
1429        return Err(format!(
1430            "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
1431        ));
1432    }
1433    // The expansion point must be a genuine root of F. The single Newton
1434    // correction that would move a0 onto the root is |f.v|/|f_a|; require it
1435    // to be negligible relative to the natural scale (1 + |a0|). Guarding the
1436    // Newton step (rather than f.v directly) makes the criterion invariant to
1437    // the magnitude of f_a / the units of F.
1438    let root_tol = 1e-9;
1439    if !f.v.is_finite() {
1440        return Err(format!(
1441            "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
1442            f.v
1443        ));
1444    }
1445    let newton_step = f.v.abs() / f_a.abs();
1446    if newton_step > root_tol * (1.0 + a0.abs()) {
1447        return Err(format!(
1448            "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
1449             F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
1450             root_tol {root_tol:.1e} · (1 + |a0|)",
1451            f.v
1452        ));
1453    }
1454    // Start with a = constant a0 (correct through order 0). Then lift each
1455    // order in turn. Because substitute_intercept reads `a`'s order-≤m
1456    // tensors when forming G's order-m coefficient, and the order-m
1457    // coefficient of G depends on a's order-m tensor ONLY through the linear
1458    // F_a·a_m term, a single corrective pass per order is exact.
1459    let mut a = Tower4::<K>::constant(a0);
1460    for order in 1..=4 {
1461        let g = substitute_intercept(f, &a);
1462        // Cancel G's order-`order` coefficient by adjusting a's order-`order`
1463        // tensor: a_m -= G_m / F_a (the F_a·a_m term is the only one carrying
1464        // a's order-m tensor, with unit chain coefficient since slot 0 seeds a
1465        // as a plain variable in the substitution's first-order part).
1466        match order {
1467            1 => {
1468                for i in 0..K {
1469                    a.g[i] -= g.g[i] / f_a;
1470                }
1471            }
1472            2 => {
1473                for i in 0..K {
1474                    for j in 0..K {
1475                        a.h[i][j] -= g.h[i][j] / f_a;
1476                    }
1477                }
1478            }
1479            3 => {
1480                for i in 0..K {
1481                    for j in 0..K {
1482                        for k in 0..K {
1483                            a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
1484                        }
1485                    }
1486                }
1487            }
1488            _ => {
1489                for i in 0..K {
1490                    for j in 0..K {
1491                        for k in 0..K {
1492                            for l in 0..K {
1493                                a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
1494                            }
1495                        }
1496                    }
1497                }
1498            }
1499        }
1500    }
1501    // Self-check: the composed residual G = F∘a must vanish through order 4.
1502    // By construction orders 1..=4 were cancelled; the value G.v == F(a0,θ0)
1503    // is exactly the root requirement guarded above. Re-verify all channels
1504    // against a scale-aware floor so any arithmetic regression in the
1505    // substitution recursion is loud rather than silently shipping a
1506    // level-set expansion.
1507    let g = substitute_intercept(f, &a);
1508    let resid_tol = 1e-7 * (1.0 + f_a.abs());
1509    let mut worst = g.v.abs();
1510    for i in 0..K {
1511        worst = worst.max(g.g[i].abs());
1512        for j in 0..K {
1513            worst = worst.max(g.h[i][j].abs());
1514            for k in 0..K {
1515                worst = worst.max(g.t3[i][j][k].abs());
1516                for l in 0..K {
1517                    worst = worst.max(g.t4[i][j][k][l].abs());
1518                }
1519            }
1520        }
1521    }
1522    if !worst.is_finite() || worst > resid_tol {
1523        return Err(format!(
1524            "implicit_solve: composed residual G = F∘a does not vanish: \
1525             worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
1526        ));
1527    }
1528    Ok(a)
1529}
1530
1531/// Substitute the intercept tower `a(θ)` into slot `0` of a constraint
1532/// written over `K + 1` variables, returning the composite tower over the
1533/// `K` primaries θ: `G(θ) = f(a(θ), θ₁, …, θ_K)`.
1534///
1535/// This is the exact multivariate chain rule specialised to "slot 0 is a
1536/// dependent tower, slots 1..=K are the independent primaries". It evaluates
1537/// `f`'s fourth-order multivariate Taylor polynomial about the expansion
1538/// point, with the slot-0 increment being the non-constant part of `a` and
1539/// the slot-(i) increment being the unit-seeded primary `θ_i`. The sum is
1540/// assembled by the same subset/partition algebra `Tower4` arithmetic uses,
1541/// so it carries derivatives exactly through order four.
1542pub fn substitute_intercept<const K1: usize, const K: usize>(
1543    f: &Tower4<K1>,
1544    a: &Tower4<K>,
1545) -> Tower4<K> {
1546    assert_eq!(K1, K + 1);
1547    // Build the K+1 input towers in θ-space: slot 0 = a(θ), slot i+1 = θ_i.
1548    // The composite is Σ over ordered label tuples s (|s| ≤ 4) of input
1549    // indices: (1/|s|!) · f.deriv(s) · Π_{j in s} (inp[s_j] centred) — but
1550    // since f.deriv is the SYMMETRIC partial tensor and we enumerate ordered
1551    // tuples, the 1/|s|! exactly cancels the tuple multiplicity. We assemble
1552    // it directly as a Horner-free explicit sum over the (K+1)-ary tuples,
1553    // using tower products for the increment monomials so all θ-derivatives
1554    // propagate exactly.
1555    let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
1556        if slot == 0 {
1557            // slot 0: a(θ) minus its constant value (the increment δa(θ)).
1558            let mut d = *a;
1559            d.v = 0.0;
1560            d
1561        } else {
1562            // slot i: the increment δθ_{i-1} = seeded variable minus value.
1563            // θ centred at its expansion value has zero constant term and unit
1564            // first derivative in its own slot.
1565            let mut d = Tower4::<K>::zero();
1566            d.g[slot - 1] = 1.0;
1567            d
1568        }
1569    });
1570    // Accumulate the Taylor sum. order-0 term:
1571    let mut out = Tower4::<K>::constant(f.v);
1572    // order 1: Σ_a f.g[a] · inp[a]
1573    for a_idx in 0..K1 {
1574        out = out + inp[a_idx].scale(f.g[a_idx]);
1575    }
1576    // order 2: (1/2) Σ_{a,b} f.h[a][b] · inp[a]·inp[b]
1577    for a_idx in 0..K1 {
1578        for b_idx in 0..K1 {
1579            let prod = inp[a_idx].mul(&inp[b_idx]);
1580            out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
1581        }
1582    }
1583    // order 3: (1/6) Σ f.t3[a][b][c] · inp[a]·inp[b]·inp[c]
1584    for a_idx in 0..K1 {
1585        for b_idx in 0..K1 {
1586            for c_idx in 0..K1 {
1587                let prod = inp[a_idx].mul(&inp[b_idx]).mul(&inp[c_idx]);
1588                out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
1589            }
1590        }
1591    }
1592    // order 4: (1/24) Σ f.t4[a][b][c][d] · inp[a]·inp[b]·inp[c]·inp[d]
1593    for a_idx in 0..K1 {
1594        for b_idx in 0..K1 {
1595            for c_idx in 0..K1 {
1596                for d_idx in 0..K1 {
1597                    let prod = inp[a_idx]
1598                        .mul(&inp[b_idx])
1599                        .mul(&inp[c_idx])
1600                        .mul(&inp[d_idx]);
1601                    out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
1602                }
1603            }
1604        }
1605    }
1606    out
1607}
1608
1609/// The exact θ-derivative tower of a moving-LIMIT integral's BOUNDARY
1610/// contribution: given the edge-position tower `z_edge(θ)` over the `K`
1611/// primaries and the integrand `B` evaluated-and-differentiated at the edge
1612/// value as the stack `b_stack = [B(z₀), B′(z₀), B″(z₀), B‴(z₀)]`
1613/// (`z₀ = z_edge.v`), returns the tower of `Φ(z_edge(θ))` where `Φ′ = B`.
1614///
1615/// Rationale: `∂_θ ∫^{z_edge(θ)} B(z) dz = Φ(z_edge(θ))` with `Φ` an
1616/// antiderivative of `B`, so the boundary part of every θ-derivative of the
1617/// integral is just the composition `Φ ∘ z_edge` — whose Faà di Bruno
1618/// expansion carries, at one stroke, EVERY Leibniz boundary term the
1619/// hand-written flux dropped: the first-order `B·z_u`, the second-order
1620/// `B′·z_u·z_v + B·z_uv` (the `G_z·z_u·z_v` self-flux AND the previously
1621/// dropped `G·z_uv`), and the full third/fourth-order continuations. The
1622/// VALUE channel of the returned tower is meaningless (`Φ` is only defined up
1623/// to a constant); callers read only the derivative channels and pair this
1624/// with the interior moment-integral value separately.
1625///
1626/// `b_stack` holds `B` and its first three z-derivatives; the antiderivative
1627/// `Φ` contributes only as the order-≥1 channels, so `compose_unary` receives
1628/// `[0, B, B′, B″, B‴]` — the leading `0` is the discarded `Φ(z₀)` slot.
1629pub fn moving_limit_boundary_tower<const K: usize>(
1630    z_edge: &Tower4<K>,
1631    b_stack: [f64; 4],
1632) -> Tower4<K> {
1633    z_edge.compose_unary([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]])
1634}
1635
1636/// The boundary-flux derivative tower of a single moving cell integral
1637/// `∫_{z_L(θ)}^{z_R(θ)} B dz`: `Φ(z_R(θ)) − Φ(z_L(θ))`, assembled from the
1638/// two edge towers and the integrand stacks at each edge. The returned
1639/// tower's derivative channels are the EXACT moving-boundary contribution to
1640/// every θ-derivative of the cell integral, to fourth order, with no term
1641/// hand-omitted. A `Fixed` (non-moving) edge passes a `z_edge` whose
1642/// derivative channels are all zero, contributing nothing — matching the
1643/// production `edge_vel = 0` short-circuit.
1644pub fn cell_moving_boundary_flux_tower<const K: usize>(
1645    z_right: &Tower4<K>,
1646    b_stack_right: [f64; 4],
1647    z_left: &Tower4<K>,
1648    b_stack_left: [f64; 4],
1649) -> Tower4<K> {
1650    moving_limit_boundary_tower(z_right, b_stack_right)
1651        - moving_limit_boundary_tower(z_left, b_stack_left)
1652}
1653
1654/// Moving-limit boundary tower for a θ-DEPENDENT integrand `G(z; θ)`.
1655///
1656/// [`moving_limit_boundary_tower`] assumes the integrand depends on θ only
1657/// through the moving edge `z_edge(θ)` (a fixed z-derivative `b_stack`). The
1658/// marginal-slope flex boundary is richer: the integrand `G(z; θ)` ALSO carries
1659/// its own θ-dependence (the density weight `w = e^{−q}/2π` and the cell
1660/// integrand coefficients move with η, hence with the primaries), so the
1661/// Leibniz expansion of `∂ⁿ_θ ∫^{z_edge(θ)} G(z;θ) dz` mixes edge-motion
1662/// derivatives of the limit with θ-derivatives of `G` itself — e.g. at second
1663/// order `G·z_uv + G_z·z_u·z_v + G_{θu}·z_v + G_{θv}·z_u` (the four
1664/// edge-motion-carrying terms the hand path assembles one by one, including the
1665/// `G·z_uv` term the directional path drops).
1666///
1667/// Mechanization: let `Φ(z; θ)` be the z-antiderivative of `G` (so `Φ_z = G`).
1668/// The full upper-limit contribution is `Φ(z_edge(θ); θ)`, and the BOUNDARY
1669/// part — everything carrying edge motion — is exactly
1670///   `Φ(z_edge(θ); θ) − Φ(z₀; θ)`,
1671/// the second term being the pure-integrand-θ part (`∫^{z₀} ∂ⁿ_θ G`) the
1672/// interior moment integral already supplies. Both are one
1673/// [`substitute_intercept`] of the SAME mixed `(z, θ)` jet of `Φ` (z in slot 0,
1674/// θ in slots 1..K): substituting the edge tower gives the full composite,
1675/// substituting a frozen constant edge isolates the pure-θ part, and their
1676/// difference is the exact boundary flux — every Leibniz term derived by the
1677/// substitution algebra, none hand-omitted.
1678///
1679/// `phi_jet` is the `(K+1)`-variable Taylor jet of `Φ` about `(z₀, θ₀)` with
1680/// `z₀ = z_edge.v`: slot 0 is the z-direction (so `phi_jet.g[0] = G(z₀;θ₀)`,
1681/// `phi_jet.h[0][0] = G_z`, …) and slots `1..=K` are the primaries θ (carrying
1682/// `Φ`'s own θ- and mixed z·θ-derivatives — i.e. the integrand's θ-derivatives
1683/// integrated in z, and `G_{θ…}` in the mixed slots). The returned tower's
1684/// VALUE channel is 0 by construction (the `Φ(z₀;θ₀)` constants cancel); only
1685/// the derivative channels are meaningful, matching the value-less convention of
1686/// [`moving_limit_boundary_tower`].
1687pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
1688    phi_jet: &Tower4<K1>,
1689    z_edge: &Tower4<K>,
1690) -> Tower4<K> {
1691    assert_eq!(
1692        K1,
1693        K + 1,
1694        "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
1695    );
1696    let frozen_edge = Tower4::<K>::constant(z_edge.v);
1697    let full = substitute_intercept(phi_jet, z_edge);
1698    let interior = substitute_intercept(phi_jet, &frozen_edge);
1699    full - interior
1700}
1701
1702/// Two-edge cell version of [`moving_limit_boundary_tower_theta_integrand`]:
1703/// the exact boundary-flux tower of `∫_{z_L(θ)}^{z_R(θ)} G(z;θ) dz` with a
1704/// θ-dependent integrand, `Φ(z_R;θ) − Φ(z_L;θ)` minus the pure-θ parts at each
1705/// frozen edge. A `Fixed` edge passes a `z_edge` with zero derivative channels,
1706/// so its `full` and `interior` substitutions coincide and it contributes
1707/// nothing — matching the production `edge_vel = 0` short-circuit.
1708pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
1709    phi_jet_right: &Tower4<K1>,
1710    z_right: &Tower4<K>,
1711    phi_jet_left: &Tower4<K1>,
1712    z_left: &Tower4<K>,
1713) -> Tower4<K> {
1714    moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right)
1715        - moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left)
1716}
1717
1718// ── The program seam ─────────────────────────────────────────────────
1719
1720/// A family's row negative log-likelihood written ONCE over tower scalars.
1721///
1722/// This is the single source of truth #932 asks for: the value channel of
1723/// the returned tower must BE the production row NLL (same branches, same
1724/// guards, same numerics), and every derivative channel is then exact by
1725/// construction. The linear Jacobian wiring (coefficients ↔ primaries) is
1726/// NOT part of this trait — it is family data, not calculus, and stays on
1727/// the `RowKernel` implementor.
1728pub trait RowNllProgram<const K: usize>: Send + Sync {
1729    /// Number of observations the program covers.
1730    fn n_rows(&self) -> usize;
1731
1732    /// Current primary-scalar values for `row` (where to seed the tower).
1733    fn primaries(&self, row: usize) -> Result<[f64; K], String>;
1734
1735    /// The row NLL evaluated on tower scalars. `p[a]` arrives pre-seeded as
1736    /// variable `a` at the current primary value; implementations combine
1737    /// them with `Tower4` arithmetic and per-row data (response, censoring
1738    /// indicators, offsets) entering as constants.
1739    fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
1740}
1741
1742/// Evaluate a program's full tower at the current primaries for one row.
1743///
1744/// One call yields every `RowKernel` calculus channel; callers that need
1745/// several contractions of the same row should hold the returned tower and
1746/// contract repeatedly rather than re-evaluating.
1747pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
1748    prog: &P,
1749    row: usize,
1750) -> Result<Tower4<K>, String> {
1751    let p = prog.primaries(row)?;
1752    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
1753    prog.row_nll(row, &vars)
1754}
1755
1756/// Mechanically derived `row_kernel` channel: `(nll, ∇, H)`.
1757pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
1758    prog: &P,
1759    row: usize,
1760) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1761    let t = evaluate_program(prog, row)?;
1762    Ok((t.v, t.g, t.h))
1763}
1764
1765/// Mechanically derived `row_third_contracted` channel.
1766pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1767    prog: &P,
1768    row: usize,
1769    dir: &[f64; K],
1770) -> Result<[[f64; K]; K], String> {
1771    Ok(evaluate_program(prog, row)?.third_contracted(dir))
1772}
1773
1774/// Mechanically derived `row_fourth_contracted` channel.
1775pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1776    prog: &P,
1777    row: usize,
1778    dir_u: &[f64; K],
1779    dir_v: &[f64; K],
1780) -> Result<[[f64; K]; K], String> {
1781    Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
1782}
1783
1784// ── The generic program seam (#932 scalar cutover) ───────────────────
1785
1786/// A family's row negative log-likelihood written ONCE over the generic
1787/// [`crate::jet_scalar::JetScalar`] interface, so the SAME expression can be
1788/// re-instantiated at whatever order / representation a consumer needs
1789/// ([`crate::jet_scalar::Order2`] for `(v, g, H)`,
1790/// [`crate::jet_scalar::OneSeed`] for the contracted third,
1791/// [`crate::jet_scalar::TwoSeed`] for the contracted fourth, or the full
1792/// [`Tower4`] for every channel at once).
1793///
1794/// This is additive to [`RowNllProgram`] (which is `Tower4`-specialised): a
1795/// program implementing this generic trait gets the small contracted scalars for
1796/// free, dissolving the dense-`Tower4<9>` cost objection in the location-scale
1797/// gates (doc §A.4). An existing `Tower4`-only [`RowNllProgram`] continues to
1798/// work unchanged; new families should prefer this generic trait.
1799///
1800/// Because a `Tower4`-specialised `row_nll` body uses only
1801/// `add`/`sub`/`mul`/`scale`/`exp`/`ln`/… — all of which this trait also
1802/// provides — the same body is expressible directly over `S: JetScalar<K>`.
1803/// A program written that way needs no `Tower4`-specialised method and routes
1804/// the directional and joint-Hessian gates through the contracted scalars from
1805/// a single definition.
1806pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
1807    /// Number of observations the program covers.
1808    fn n_rows(&self) -> usize;
1809
1810    /// Current primary-scalar values for `row` (where to seed the scalar).
1811    fn primaries(&self, row: usize) -> Result<[f64; K], String>;
1812
1813    /// The row NLL evaluated on a generic jet scalar. `p[a]` arrives pre-seeded
1814    /// (base value + per-scalar nilpotent directions) by the caller; the body
1815    /// uses ONLY [`crate::jet_scalar::JetScalar`] ops and per-row data
1816    /// (response, censoring, offsets) entering as constants.
1817    fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
1818        &self,
1819        row: usize,
1820        p: &[S; K],
1821    ) -> Result<S, String>;
1822}
1823
1824/// Evaluate a generic program at the value/gradient/Hessian scalar
1825/// [`crate::jet_scalar::Order2`], returning `(nll, ∇, H)` — the
1826/// `row_kernel` channel — WITHOUT materialising any third / fourth tensor.
1827///
1828/// This is the production seam for the inner-Newton `(v, g, H)` path: the row
1829/// loss is written ONCE in `row_nll_generic`, and this routes it through the
1830/// cheap order-2 scalar. The single source of truth means the gradient and
1831/// Hessian cannot desync from the value (the #736 / #948 bug genus).
1832pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1833    prog: &P,
1834    row: usize,
1835) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1836    let base = prog.primaries(row)?;
1837    let vars: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|a| {
1838        <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(base[a], a)
1839    });
1840    let s = prog.row_nll_generic(row, &vars)?;
1841    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
1842}
1843
1844/// Evaluate a generic program at the one-seed scalar
1845/// [`crate::jet_scalar::OneSeed`], returning the contracted third
1846/// `Σ_c ℓ_{abc} dir_c` — the `row_third_contracted(dir)` channel — WITHOUT
1847/// materialising the dense `t3` tensor. The contraction direction is folded
1848/// INTO the differentiation by the nilpotent ε seeded with `dir`.
1849pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1850    prog: &P,
1851    row: usize,
1852    dir: &[f64; K],
1853) -> Result<[[f64; K]; K], String> {
1854    let base = prog.primaries(row)?;
1855    let vars: [crate::jet_scalar::OneSeed<K>; K] =
1856        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
1857    let s = prog.row_nll_generic(row, &vars)?;
1858    Ok(s.contracted_third())
1859}
1860
1861/// Evaluate a generic program at the two-seed scalar
1862/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
1863/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `row_fourth_contracted(u, v)` channel —
1864/// WITHOUT materialising the dense `t4` tensor.
1865pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1866    prog: &P,
1867    row: usize,
1868    dir_u: &[f64; K],
1869    dir_v: &[f64; K],
1870) -> Result<[[f64; K]; K], String> {
1871    let base = prog.primaries(row)?;
1872    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
1873        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
1874    let s = prog.row_nll_generic(row, &vars)?;
1875    Ok(s.contracted_fourth())
1876}
1877
1878/// Evaluate a generic program at the full dense [`Tower4`] scalar, returning
1879/// every channel `(v, g, h, t3, t4)` in one pass. Used where the UNCONTRACTED
1880/// third / fourth tensors are needed (the BMS rigid `third_full` / `fourth_full`
1881/// caches): the dense tensors come from the SAME `row_nll_generic` expression
1882/// the order-2 / contracted scalars consume, so there is a single source of
1883/// truth across every channel.
1884pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1885    prog: &P,
1886    row: usize,
1887) -> Result<Tower4<K>, String> {
1888    let base = prog.primaries(row)?;
1889    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(base[a], a));
1890    prog.row_nll_generic(row, &vars)
1891}
1892
1893// ── The RowJet bridge: one row-NLL body over scalar jets AND lane towers ─
1894//
1895// `JetScalar<K>` (jet_scalar.rs) abstracts the SCALAR jets — its `value()`
1896// returns one `f64`, so the `f64x4` lane towers ([`Tower3Lane`] / [`Tower4Lane`])
1897// CANNOT implement it (their value channel is four rows). `compose_unary_with`
1898// exists as an inherent method on BOTH the scalar towers and the lane towers, but
1899// as separate inherent methods, not a shared trait bound — so a row-NLL body
1900// written `<S: JetScalar<K>>` could not be instantiated at `Tower4Lane`, and the
1901// 4-rows-per-pass SIMD batch path could not reuse the single source.
1902//
1903// [`RowJet<K>`] is that shared bound. It exposes exactly the ops a row-NLL body
1904// needs — `constant` / `variable` / `add` / `sub` / `mul` / `scale` / `neg`, the
1905// value-derived `compose_unary_with`, and a per-lane domain `guard` — over BOTH
1906// representations. A blanket impl makes every scalar `JetScalar<K>` a `RowJet<K>`
1907// (so the scalar call sites compile unchanged and bit-identically), and explicit
1908// impls route the `f64x4` lane towers through their existing per-lane methods. A
1909// body written once over `R: RowJet<K>` then instantiates at a scalar jet for the
1910// `(v, g, H)` / contracted-tensor channels AND at a lane tower for the batch.
1911
1912/// The verdict of a per-lane [`RowJet::guard`] domain check.
1913///
1914/// A scalar jet (a [`crate::jet_scalar::JetScalar`] via the blanket impl) carries
1915/// ONE value, so it reports `lanes == 1` and a one-bit mask. A lane tower
1916/// ([`Tower3Lane`] / [`Tower4Lane`] over `f64x4`) carries FOUR rows, so it reports
1917/// `lanes == 4` and one mask bit per lane. The mask lets a batched program bail
1918/// exactly the offending 4-group to the scalar tail ([`any_failed`](Self::any_failed)),
1919/// or inspect which lanes tripped ([`lane_failed`](Self::lane_failed)).
1920#[derive(Clone, Copy, Debug, PartialEq, Eq)]
1921pub struct GuardVerdict {
1922    lanes: u8,
1923    failed_mask: u8,
1924}
1925
1926impl GuardVerdict {
1927    /// A scalar (1-lane) verdict: `pass == true` ⇒ no failure.
1928    #[inline]
1929    pub fn scalar(pass: bool) -> Self {
1930        Self { lanes: 1, failed_mask: if pass { 0 } else { 1 } }
1931    }
1932    /// A 4-lane verdict from a per-lane failure mask (bit `i` ⇒ lane `i` failed).
1933    #[inline]
1934    pub fn lanes4(failed_mask: u8) -> Self {
1935        Self { lanes: 4, failed_mask: failed_mask & 0x0f }
1936    }
1937    /// Number of active lanes inspected (1 scalar, 4 batch).
1938    #[inline]
1939    pub fn lanes(self) -> usize {
1940        self.lanes as usize
1941    }
1942    /// True iff every inspected lane satisfied the predicate.
1943    #[inline]
1944    pub fn all_pass(self) -> bool {
1945        self.failed_mask == 0
1946    }
1947    /// True iff at least one inspected lane failed the predicate.
1948    #[inline]
1949    pub fn any_failed(self) -> bool {
1950        self.failed_mask != 0
1951    }
1952    /// True iff lane `i` failed the predicate.
1953    #[inline]
1954    pub fn lane_failed(self, i: usize) -> bool {
1955        (self.failed_mask >> i) & 1 == 1
1956    }
1957    /// The raw failure mask (bit `i` ⇒ lane `i` failed).
1958    #[inline]
1959    pub fn failed_mask(self) -> u8 {
1960        self.failed_mask
1961    }
1962}
1963
1964/// Copy-or-zero-pad a derivative stack from length `N` to length `M`. Used by the
1965/// [`RowJet::compose_unary_with`] impls to bridge a program's chosen stack length
1966/// to each tower's native compose width ([`Tower4Lane`]: 5, [`Tower3Lane`]: 4).
1967/// `M ≥ N` zero-pads the unseeded high derivatives; `M < N` drops the unused tail
1968/// — both total, so the order-`(M−1)` tower reads exactly the channels it needs
1969/// and never an uninitialised entry. With `N == M` it is a verbatim copy (the
1970/// common `N == 5` case is bit-identical to passing the stack straight through).
1971#[inline]
1972fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
1973    let mut out = [0.0_f64; M];
1974    let m = N.min(M);
1975    out[..m].copy_from_slice(&s[..m]);
1976    out
1977}
1978
1979/// The shared row-NLL algebra over BOTH the scalar jets and the `f64x4` lane
1980/// towers — the bound that lets ONE single-source row-NLL body SIMD-batch 4
1981/// rows/pass without a dual-source copy (module §"The RowJet bridge").
1982///
1983/// Every scalar [`crate::jet_scalar::JetScalar<K>`] is a `RowJet<K>` via the
1984/// blanket impl below (`Value = f64`), bit-identically to its `JetScalar`
1985/// methods; [`Tower3Lane`] / [`Tower4Lane`] over `f64x4` are `RowJet<K>` with
1986/// `Value = [f64; 4]`, routing through their per-lane methods so lane `i` of a
1987/// batched evaluation is `to_bits`-identical to the scalar evaluation on row `i`.
1988pub trait RowJet<const K: usize>: Copy {
1989    /// The value channel(s) seen by [`guard`](Self::guard) and
1990    /// [`values`](Self::values): a single `f64` on a scalar jet, `[f64; 4]` on an
1991    /// `f64x4` lane tower.
1992    type Value: Copy;
1993
1994    /// A constant (value `c`, all derivatives zero), broadcast to every lane.
1995    fn constant(c: f64) -> Self;
1996    /// The seeded primary `slot` at value `x` (unit first derivative in `slot`),
1997    /// broadcast to every lane. Per-lane-DISTINCT seeding for the batch path is
1998    /// done by the lane instantiators ([`generic_batched_fourth_tower`] /
1999    /// [`generic_batched_third_tower`]), which build the tower variables directly
2000    /// from each row's primaries; this method is for any row-invariant auxiliary
2001    /// variable a body introduces.
2002    fn variable(x: f64, slot: usize) -> Self;
2003    /// The value channel(s): `f64` (scalar) or `[f64; 4]` (lane).
2004    fn values(&self) -> Self::Value;
2005
2006    /// Truncated Leibniz `self + o`.
2007    fn add(&self, o: &Self) -> Self;
2008    /// Truncated Leibniz `self − o`.
2009    fn sub(&self, o: &Self) -> Self;
2010    /// Truncated Leibniz `self · o`.
2011    fn mul(&self, o: &Self) -> Self;
2012    /// Multiply every channel by the plain scalar `s`.
2013    fn scale(&self, s: f64) -> Self;
2014    /// Negate every channel. Defaults to `scale(-1.0)`; the blanket overrides it
2015    /// to delegate to [`crate::jet_scalar::JetScalar::neg`].
2016    fn neg(&self) -> Self {
2017        self.scale(-1.0)
2018    }
2019
2020    /// Faà di Bruno compose with a unary special function whose `[f64; N]`
2021    /// derivative stack is built from the running base value PER LANE through
2022    /// `stack_fn`. This is the SHARED-TRAIT version of the `compose_unary_with`
2023    /// inherent method that already exists on both the scalar towers and the lane
2024    /// towers: on a scalar jet `stack_fn` is run once at the value; on an `f64x4`
2025    /// lane tower it is re-run per lane (the four rows carry four distinct base
2026    /// values), so lane `i` is `to_bits`-identical to the scalar result on row `i`.
2027    /// Making it a trait method is precisely what lets a body written once over
2028    /// `R: RowJet<K>` instantiate at the batch towers. `N` is widened/narrowed to
2029    /// the tower's native width by [`resize_stack`] (`N == 5` is a verbatim copy).
2030    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self;
2031
2032    /// Per-lane domain guard: evaluate `pred` on each active lane's value channel
2033    /// and report which lanes failed (see [`GuardVerdict`]). A scalar jet checks
2034    /// its one value; a lane tower checks all four. Lets a batched program detect
2035    /// an out-of-domain row in a 4-group and bail that group to the scalar tail.
2036    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict;
2037
2038    /// Per-lane scale: multiply every channel by the per-lane factor `s`
2039    /// ([`Self::Value`]). On a scalar jet `Self::Value = f64`, so this is exactly
2040    /// [`scale`](Self::scale) and the scalar call sites stay BIT-IDENTICAL when
2041    /// `.scale(x)` is rewritten to `.scale_rows(x)`; on an `f64x4` lane tower
2042    /// `Self::Value = [f64; 4]` and lane `i` is multiplied by `s[i]`. This is the
2043    /// primitive that lets a batched body carry CONTINUOUS per-row data — the
2044    /// survival `covariance_ones` / `z_sum` / observation-weight `wi` factors that
2045    /// enter the jet algebra as `.scale(per_row_value)` and that the single-`f64`
2046    /// [`scale`](Self::scale) would broadcast wrongly across the four rows. Build
2047    /// `s` from the lane→row map with [`pack_rows`](Self::pack_rows).
2048    fn scale_rows(&self, s: Self::Value) -> Self;
2049
2050    /// Gather a per-lane auxiliary datum from the lane→row map `rows`: `value_of(r)`
2051    /// is evaluated for each active lane's row and packed into [`Self::Value`] (a
2052    /// single `f64` on a scalar jet, `[f64; 4]` on an `f64x4` lane tower). This is
2053    /// how a body written once over [`RowJet`] feeds per-row CONTINUOUS data (the
2054    /// arguments to [`scale_rows`](Self::scale_rows)) into the batch path without
2055    /// knowing the concrete representation: the program holds the per-row data and
2056    /// the caller threads `rows` (length 1 scalar, length 4 batch) into
2057    /// [`RowNllProgramRowJet::row_nll`], so the body writes
2058    /// `x.scale_rows(R::pack_rows(rows, |r| self.cov(r)))`. A multiplicative weight
2059    /// buried in a `compose_unary_with` stack is pulled out the same way:
2060    /// `x.compose_unary_with(|u| stack(u, 1.0)).scale_rows(R::pack_rows(rows, |r| self.wi(r)))`.
2061    /// (Binary per-row branches such as the event indicator `di` are kept
2062    /// lane-uniform by grouping and the [`guard`](Self::guard) bail, not packed.)
2063    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;
2064
2065    // ── value-derived transcendental conveniences ───────────────────────
2066    // Each routes through `compose_unary_with` with the SAME derivative stack the
2067    // corresponding `JetScalar` method uses, so on a scalar jet (blanket) the
2068    // result is bit-identical to the `JetScalar` method, and on a lane tower lane
2069    // `i` is bit-identical to the scalar result on row `i`.
2070
2071    /// `e^self`.
2072    fn exp(&self) -> Self {
2073        self.compose_unary_with(|u| {
2074            let e = u.exp();
2075            [e, e, e, e, e]
2076        })
2077    }
2078    /// `ln(self)`. Caller guarantees positivity.
2079    fn ln(&self) -> Self {
2080        self.compose_unary_with(|u| {
2081            let r = 1.0 / u;
2082            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
2083        })
2084    }
2085    /// `√self`. Caller guarantees positivity.
2086    fn sqrt(&self) -> Self {
2087        self.compose_unary_with(|u| {
2088            let s = u.sqrt();
2089            [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
2090        })
2091    }
2092    /// `1/self`.
2093    fn recip(&self) -> Self {
2094        self.compose_unary_with(|u| {
2095            let r = 1.0 / u;
2096            let r2 = r * r;
2097            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
2098        })
2099    }
2100    /// `self^a` for real `a`. Caller guarantees a positive base.
2101    fn powf(&self, a: f64) -> Self {
2102        self.compose_unary_with(move |u| {
2103            [
2104                u.powf(a),
2105                a * u.powf(a - 1.0),
2106                a * (a - 1.0) * u.powf(a - 2.0),
2107                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
2108                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
2109            ]
2110        })
2111    }
2112    /// `ln Γ(self)`. Caller guarantees a positive argument.
2113    fn ln_gamma(&self) -> Self {
2114        self.compose_unary_with(ln_gamma_derivative_stack)
2115    }
2116    /// `ψ(self)` (digamma). Caller guarantees a positive argument.
2117    fn digamma(&self) -> Self {
2118        self.compose_unary_with(digamma_derivative_stack)
2119    }
2120}
2121
2122/// Blanket: every scalar [`crate::jet_scalar::JetScalar<K>`] is a [`RowJet<K>`]
2123/// with `Value = f64`. Each op delegates to the identical `JetScalar` method, so
2124/// the existing scalar call sites compile UNCHANGED and bit-identically — the
2125/// bridge adds the lane representation without churning the scalar path. (The
2126/// concrete lane impls below cannot overlap this: [`Tower3Lane`] / [`Tower4Lane`]
2127/// are local types that do not implement `JetScalar`, and the orphan rule forbids
2128/// any downstream impl, so the coherence checker proves the impls disjoint.)
2129impl<const K: usize, S: crate::jet_scalar::JetScalar<K>> RowJet<K> for S {
2130    type Value = f64;
2131    #[inline]
2132    fn constant(c: f64) -> Self {
2133        <S as crate::jet_scalar::JetScalar<K>>::constant(c)
2134    }
2135    #[inline]
2136    fn variable(x: f64, slot: usize) -> Self {
2137        <S as crate::jet_scalar::JetScalar<K>>::variable(x, slot)
2138    }
2139    #[inline]
2140    fn values(&self) -> f64 {
2141        crate::jet_scalar::JetScalar::value(self)
2142    }
2143    #[inline]
2144    fn add(&self, o: &Self) -> Self {
2145        crate::jet_scalar::JetScalar::add(self, o)
2146    }
2147    #[inline]
2148    fn sub(&self, o: &Self) -> Self {
2149        crate::jet_scalar::JetScalar::sub(self, o)
2150    }
2151    #[inline]
2152    fn mul(&self, o: &Self) -> Self {
2153        crate::jet_scalar::JetScalar::mul(self, o)
2154    }
2155    #[inline]
2156    fn scale(&self, s: f64) -> Self {
2157        crate::jet_scalar::JetScalar::scale(self, s)
2158    }
2159    #[inline]
2160    fn neg(&self) -> Self {
2161        crate::jet_scalar::JetScalar::neg(self)
2162    }
2163    #[inline]
2164    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2165        crate::jet_scalar::JetScalar::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2166    }
2167    #[inline]
2168    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2169        GuardVerdict::scalar(pred(crate::jet_scalar::JetScalar::value(self)))
2170    }
2171    #[inline]
2172    fn scale_rows(&self, s: f64) -> Self {
2173        // `Value == f64`, so per-lane scale is exactly `scale` — the rewrite
2174        // `.scale(x)` → `.scale_rows(x)` is bit-identical on the scalar path.
2175        crate::jet_scalar::JetScalar::scale(self, s)
2176    }
2177    #[inline]
2178    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
2179        value_of(rows[0])
2180    }
2181}
2182
2183/// The `f64x4` lane [`Tower4Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2184/// routing each op through its existing per-lane method. Lane `i` of a batched
2185/// evaluation is `to_bits`-identical to the scalar [`Tower4`] evaluation on row
2186/// `i` (the per-lane methods are term-for-term lifts of the scalar tower).
2187impl<const K: usize> RowJet<K> for Tower4Lane<wide::f64x4, K> {
2188    type Value = [f64; 4];
2189    #[inline]
2190    fn constant(c: f64) -> Self {
2191        Tower4Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2192    }
2193    #[inline]
2194    fn variable(x: f64, slot: usize) -> Self {
2195        Tower4Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2196    }
2197    #[inline]
2198    fn values(&self) -> [f64; 4] {
2199        self.v.to_array()
2200    }
2201    #[inline]
2202    fn add(&self, o: &Self) -> Self {
2203        Tower4Lane::add(self, o)
2204    }
2205    #[inline]
2206    fn sub(&self, o: &Self) -> Self {
2207        Tower4Lane::sub(self, o)
2208    }
2209    #[inline]
2210    fn mul(&self, o: &Self) -> Self {
2211        Tower4Lane::mul(self, o)
2212    }
2213    #[inline]
2214    fn scale(&self, s: f64) -> Self {
2215        Tower4Lane::scale(self, s)
2216    }
2217    #[inline]
2218    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2219        Tower4Lane::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2220    }
2221    #[inline]
2222    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2223        let vals = self.v.to_array();
2224        let mut mask = 0u8;
2225        for (i, &v) in vals.iter().enumerate() {
2226            if !pred(v) {
2227                mask |= 1 << i;
2228            }
2229        }
2230        GuardVerdict::lanes4(mask)
2231    }
2232    #[inline]
2233    fn scale_rows(&self, s: [f64; 4]) -> Self {
2234        // True per-lane scale: lane `i` of every channel is multiplied by `s[i]`,
2235        // so lane `i` matches the scalar `Tower4::scale(s[i])` on row `i`.
2236        let sl = wide::f64x4::new(s);
2237        let mut out = *self;
2238        out.v = self.v * sl;
2239        for i in 0..K {
2240            out.g[i] = self.g[i] * sl;
2241            for j in 0..K {
2242                out.h[i][j] = self.h[i][j] * sl;
2243                for k in 0..K {
2244                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
2245                    for l in 0..K {
2246                        out.t4[i][j][k][l] = self.t4[i][j][k][l] * sl;
2247                    }
2248                }
2249            }
2250        }
2251        out
2252    }
2253    #[inline]
2254    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2255        [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2256    }
2257}
2258
2259/// The `f64x4` lane [`Tower3Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2260/// the order-≤3 sibling of the [`Tower4Lane`] impl. A body that uses `N == 5`
2261/// stacks drops the (unused) fourth-derivative entry here, matching the scalar
2262/// [`Tower3`] which also carries only up to the third tensor.
2263impl<const K: usize> RowJet<K> for Tower3Lane<wide::f64x4, K> {
2264    type Value = [f64; 4];
2265    #[inline]
2266    fn constant(c: f64) -> Self {
2267        Tower3Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2268    }
2269    #[inline]
2270    fn variable(x: f64, slot: usize) -> Self {
2271        Tower3Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2272    }
2273    #[inline]
2274    fn values(&self) -> [f64; 4] {
2275        self.v.to_array()
2276    }
2277    #[inline]
2278    fn add(&self, o: &Self) -> Self {
2279        Tower3Lane::add(self, o)
2280    }
2281    #[inline]
2282    fn sub(&self, o: &Self) -> Self {
2283        Tower3Lane::sub(self, o)
2284    }
2285    #[inline]
2286    fn mul(&self, o: &Self) -> Self {
2287        Tower3Lane::mul(self, o)
2288    }
2289    #[inline]
2290    fn scale(&self, s: f64) -> Self {
2291        Tower3Lane::scale(self, s)
2292    }
2293    #[inline]
2294    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2295        Tower3Lane::compose_unary_with(self, |u| resize_stack::<N, 4>(stack_fn(u)))
2296    }
2297    #[inline]
2298    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2299        let vals = self.v.to_array();
2300        let mut mask = 0u8;
2301        for (i, &v) in vals.iter().enumerate() {
2302            if !pred(v) {
2303                mask |= 1 << i;
2304            }
2305        }
2306        GuardVerdict::lanes4(mask)
2307    }
2308    #[inline]
2309    fn scale_rows(&self, s: [f64; 4]) -> Self {
2310        let sl = wide::f64x4::new(s);
2311        let mut out = *self;
2312        out.v = self.v * sl;
2313        for i in 0..K {
2314            out.g[i] = self.g[i] * sl;
2315            for j in 0..K {
2316                out.h[i][j] = self.h[i][j] * sl;
2317                for k in 0..K {
2318                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
2319                }
2320            }
2321        }
2322        out
2323    }
2324    #[inline]
2325    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2326        [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2327    }
2328}
2329
2330/// A family's row negative log-likelihood written ONCE over the [`RowJet`]
2331/// bridge, so the SAME body instantiates at the scalar jets (for the `(v, g, H)`
2332/// and contracted-tensor channels) AND at the `f64x4` lane towers (for the
2333/// 4-rows-per-pass SIMD batch). This is the lane-capable successor to
2334/// [`RowNllProgramGeneric`]: a body written here gets the scalar channels through
2335/// [`rowjet_row_kernel`] / [`rowjet_third_contracted`] / [`rowjet_fourth_contracted`]
2336/// and the batched channels through [`generic_batched_fourth_tower`] /
2337/// [`generic_batched_third_tower`], all from a single source.
2338pub trait RowNllProgramRowJet<const K: usize>: Send + Sync {
2339    /// Number of observations the program covers.
2340    fn n_rows(&self) -> usize;
2341
2342    /// Current primary-scalar values for `row` (where to seed each lane).
2343    fn primaries(&self, row: usize) -> Result<[f64; K], String>;
2344
2345    /// The row NLL evaluated on the [`RowJet`] bridge. `rows` is the lane→row map
2346    /// (length 1 for a scalar instantiation, length 4 for a batch); `p[a]` arrives
2347    /// pre-seeded by the caller (base value plus, for the directional scalars, the
2348    /// nilpotent contraction directions). The body uses ONLY [`RowJet`] ops and
2349    /// per-row data entering through `rows`/`self` as constants.
2350    fn row_nll<R: RowJet<K>>(&self, rows: &[usize], p: &[R; K]) -> Result<R, String>;
2351}
2352
2353/// Evaluate a [`RowNllProgramRowJet`] at the value/gradient/Hessian scalar
2354/// [`crate::jet_scalar::Order2`] (the `(v, g, H)` inner-Newton channel) — the
2355/// `RowJet` twin of [`generic_row_kernel`].
2356pub fn rowjet_row_kernel<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2357    prog: &P,
2358    row: usize,
2359) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
2360    let base = prog.primaries(row)?;
2361    let vars: [crate::jet_scalar::Order2<K>; K] =
2362        std::array::from_fn(|a| <crate::jet_scalar::Order2<K> as RowJet<K>>::variable(base[a], a));
2363    let s = prog.row_nll(&[row], &vars)?;
2364    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
2365}
2366
2367/// Evaluate a [`RowNllProgramRowJet`] at the one-seed scalar
2368/// [`crate::jet_scalar::OneSeed`], returning the contracted third
2369/// `Σ_c ℓ_{abc} dir_c` — the `RowJet` twin of [`generic_third_contracted`].
2370pub fn rowjet_third_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2371    prog: &P,
2372    row: usize,
2373    dir: &[f64; K],
2374) -> Result<[[f64; K]; K], String> {
2375    let base = prog.primaries(row)?;
2376    let vars: [crate::jet_scalar::OneSeed<K>; K] =
2377        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
2378    let s = prog.row_nll(&[row], &vars)?;
2379    Ok(s.contracted_third())
2380}
2381
2382/// Evaluate a [`RowNllProgramRowJet`] at the two-seed scalar
2383/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
2384/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `RowJet` twin of [`generic_fourth_contracted`].
2385pub fn rowjet_fourth_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2386    prog: &P,
2387    row: usize,
2388    dir_u: &[f64; K],
2389    dir_v: &[f64; K],
2390) -> Result<[[f64; K]; K], String> {
2391    let base = prog.primaries(row)?;
2392    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
2393        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
2394    let s = prog.row_nll(&[row], &vars)?;
2395    Ok(s.contracted_fourth())
2396}
2397
2398/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower4Batch`],
2399/// computing the FULL `(v, g, H, t3, t4)` for FOUR rows in one SIMD pass — the
2400/// lane twin of [`generic_full_tower`]. Each of the four lanes is seeded with its
2401/// own row's primaries, so [`Tower4Batch::lane`]`(i)` is `to_bits`-identical to
2402/// the scalar [`generic_full_tower`] on `rows[i]`.
2403pub fn generic_batched_fourth_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2404    prog: &P,
2405    rows: [usize; 4],
2406) -> Result<Tower4Batch<K>, String> {
2407    let bases: [[f64; K]; 4] = [
2408        prog.primaries(rows[0])?,
2409        prog.primaries(rows[1])?,
2410        prog.primaries(rows[2])?,
2411        prog.primaries(rows[3])?,
2412    ];
2413    let vars: [Tower4Batch<K>; K] = std::array::from_fn(|a| {
2414        let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2415        Tower4Batch::variable(lane_vals, a)
2416    });
2417    prog.row_nll(&rows, &vars)
2418}
2419
2420/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower3Batch`],
2421/// computing `(v, g, H, t3)` for FOUR rows in one SIMD pass — the order-≤3 lane
2422/// twin of [`generic_full_tower`]. [`Tower3Batch::lane`]`(i)` is
2423/// `to_bits`-identical to the order-≤3 scalar evaluation on `rows[i]`.
2424pub fn generic_batched_third_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2425    prog: &P,
2426    rows: [usize; 4],
2427) -> Result<Tower3Batch<K>, String> {
2428    let bases: [[f64; K]; 4] = [
2429        prog.primaries(rows[0])?,
2430        prog.primaries(rows[1])?,
2431        prog.primaries(rows[2])?,
2432        prog.primaries(rows[3])?,
2433    ];
2434    let vars: [Tower3Batch<K>; K] = std::array::from_fn(|a| {
2435        let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2436        Tower3Batch::variable(lane_vals, a)
2437    });
2438    prog.row_nll(&rows, &vars)
2439}
2440
2441// ── The oracle ───────────────────────────────────────────────────────
2442
2443/// One row's worth of hand-written kernel outputs, as claimed by a
2444/// `RowKernel` implementation, packaged for verification against the
2445/// tower truth. Plain data (no trait coupling) so any kernel — whatever
2446/// its visibility — can be audited from its own test module.
2447pub struct KernelChannels<const K: usize> {
2448    /// Claimed `(nll, ∇, H)` from `row_kernel`.
2449    pub value: f64,
2450    /// Claimed gradient.
2451    pub gradient: [f64; K],
2452    /// Claimed Hessian.
2453    pub hessian: [[f64; K]; K],
2454    /// Claimed `row_third_contracted(dir)` outputs as `(dir, claim)` pairs.
2455    pub third: Vec<([f64; K], [[f64; K]; K])>,
2456    /// Claimed `row_fourth_contracted(u, v)` outputs as `(u, v, claim)`.
2457    pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
2458}
2459
2460/// Channel-by-channel audit of a hand-written kernel against the
2461/// single-expression tower truth. Returns `Err` naming the first channel,
2462/// index, claimed and true values on disagreement — designed as the body
2463/// of the per-family CI oracle tests (#932 deployment step 2).
2464///
2465/// Tolerance is PER ENTRY, mixed absolute/relative: each comparison uses
2466/// `|claim − truth| ≤ atol + rel_tol · max(|claim|, |truth|)`. The absolute
2467/// floor `atol = rel_tol` lets exact-zero entries of structurally sparse
2468/// towers pass without demanding bit-equality, while a tiny cross-block
2469/// entry dropped next to a huge one is still caught (it is NOT measured
2470/// against the largest entry of the whole channel — there is no per-channel
2471/// magnitude floor). Genuine sign flips (#736) and dropped channels are loud.
2472///
2473/// Non-finite handling is strict: a NaN on either side always fails; an
2474/// infinity passes only when both sides are the SAME signed infinity.
2475pub fn verify_kernel_channels<const K: usize>(
2476    tower: &Tower4<K>,
2477    claims: &KernelChannels<K>,
2478    rel_tol: f64,
2479) -> Result<(), String> {
2480    // Absolute floor: reuse rel_tol so a single knob controls both the
2481    // relative band and the absolute floor for entries near zero.
2482    let atol = rel_tol;
2483    let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
2484        // Non-finite values never silently pass the algebraic comparison
2485        // below (any comparison with NaN is false). Handle them explicitly:
2486        // NaN on either side always errs; an infinity passes only if both
2487        // sides are the identical signed infinity.
2488        if !claim.is_finite() || !truth.is_finite() {
2489            let agree = claim.is_infinite()
2490                && truth.is_infinite()
2491                && claim.is_sign_positive() == truth.is_sign_positive();
2492            if agree {
2493                return Ok(());
2494            }
2495            return Err(format!(
2496                "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
2497            ));
2498        }
2499        let band = atol + rel_tol * claim.abs().max(truth.abs());
2500        if (claim - truth).abs() > band {
2501            return Err(format!(
2502                "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
2503            ));
2504        }
2505        Ok(())
2506    };
2507
2508    check("value", claims.value, tower.v)?;
2509
2510    for a in 0..K {
2511        check(&format!("gradient[{a}]"), claims.gradient[a], tower.g[a])?;
2512    }
2513
2514    for a in 0..K {
2515        for b in 0..K {
2516            check(
2517                &format!("hessian[{a}][{b}]"),
2518                claims.hessian[a][b],
2519                tower.h[a][b],
2520            )?;
2521        }
2522    }
2523
2524    for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
2525        let truth = tower.third_contracted(dir);
2526        for a in 0..K {
2527            for b in 0..K {
2528                check(
2529                    &format!("third[{t_idx}][{a}][{b}]"),
2530                    claim[a][b],
2531                    truth[a][b],
2532                )?;
2533            }
2534        }
2535    }
2536
2537    for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
2538        let truth = tower.fourth_contracted(u, w);
2539        for a in 0..K {
2540            for b in 0..K {
2541                check(
2542                    &format!("fourth[{f_idx}][{a}][{b}]"),
2543                    claim[a][b],
2544                    truth[a][b],
2545                )?;
2546            }
2547        }
2548    }
2549
2550    Ok(())
2551}
2552
2553// ===========================================================================
2554// SIMD row-batched towers (#1151 follow-up): Tower3Lane / Tower4Lane
2555// ===========================================================================
2556//
2557// `Tower{3,4}Lane<L: Lane, K>` re-type every channel of `Tower{3,4}<K>` from a
2558// scalar `f64` to a SIMD lane field `L`. With `L = wide::f64x4` one instance
2559// carries FOUR rows at once, so a per-row kernel (BMS `row_nll`, survival
2560// `row_kernel`, `marginal_slope` `build_row_*_towers`) can evaluate 4 rows per
2561// vector pass instead of one per scalar pass.
2562//
2563// Every floating-point op is a DIRECT, term-for-term lift of the scalar
2564// `Tower{3,4}<K>` body — `a * b` -> `a.mul(b)`, `a + b` -> `a.add(b)`, a literal
2565// `c` -> `L::splat(c)` — in the SAME accumulation order. `wide::f64x4`
2566// add/sub/mul are lane-wise IEEE-754 ops with NO fused-multiply-add (Rust
2567// performs no fp-contraction), so lane `i` of any channel of a
2568// `Tower{3,4}Lane<wide::f64x4, K>` is `to_bits`-IDENTICAL to the scalar
2569// `Tower{3,4}<K>` channel computed on row `i` — exactly the structural
2570// bit-identity the existing [`crate::jet_scalar::Order2Lane`] relies on. Proven
2571// by the in-tree `batch_tests` (real `wide::f64x4`) and a standalone
2572// f64x4-model oracle, `K ∈ {2,3,4,9}`.
2573//
2574// Only the pure-arithmetic ops are lifted (the transcendental `exp`/`ln`/`sqrt`/
2575// `…` route through scalar libm, which has no `f64x4` form; consumers build the
2576// per-lane derivative stack scalar-side and feed it to `compose_unary([L; _])`,
2577// exactly as the scalar path already does).
2578
2579use crate::jet_scalar::Lane;
2580
2581/// Lane-batched [`Tower4`]: value / gradient / Hessian / 3rd / 4th tensors
2582/// carried in a SIMD field `L`. `Tower4Lane<f64x4, K>` lane `i` is
2583/// `to_bits`-identical to [`Tower4<K>`] on row `i`.
2584#[derive(Clone, Copy)]
2585pub struct Tower4Lane<L: Lane, const K: usize> {
2586    /// Value channel (one entry per lane/row).
2587    pub v: L,
2588    /// Gradient `∂/∂p_a`.
2589    pub g: [L; K],
2590    /// Hessian `∂²/∂p_a∂p_b`.
2591    pub h: [[L; K]; K],
2592    /// Third tensor `∂³`.
2593    pub t3: [[[L; K]; K]; K],
2594    /// Fourth tensor `∂⁴`.
2595    pub t4: [[[[L; K]; K]; K]; K],
2596}
2597
2598/// The 4-rows-per-pass batched [`Tower4`] (`wide::f64x4` lanes).
2599pub type Tower4Batch<const K: usize> = Tower4Lane<wide::f64x4, K>;
2600
2601impl<L: Lane, const K: usize> Tower4Lane<L, K> {
2602    /// All-zero tower (every channel `+0.0` in every lane).
2603    #[inline]
2604    pub fn zero() -> Self {
2605        let z = L::splat(0.0);
2606        Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K], t4: [[[[z; K]; K]; K]; K] }
2607    }
2608    /// Constant `c` (per lane): value channel only.
2609    #[inline]
2610    pub fn constant(c: L) -> Self {
2611        let mut o = Self::zero();
2612        o.v = c;
2613        o
2614    }
2615    /// Seeded variable `p_idx` at per-lane `value`: unit first derivative in
2616    /// slot `idx` (mirrors [`Tower4::variable`]).
2617    #[inline]
2618    pub fn variable(value: L, idx: usize) -> Self {
2619        let mut o = Self::constant(value);
2620        o.g[idx] = L::splat(1.0);
2621        o
2622    }
2623    /// Extract lane `i` as a scalar [`Tower4<K>`] (channel-for-channel).
2624    #[inline]
2625    pub fn lane(&self, i: usize) -> Tower4<K> {
2626        let mut out = Tower4::<K>::zero();
2627        out.v = self.v.lane(i);
2628        for a in 0..K {
2629            out.g[a] = self.g[a].lane(i);
2630            for b in 0..K {
2631                out.h[a][b] = self.h[a][b].lane(i);
2632                for c in 0..K {
2633                    out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2634                    for d in 0..K {
2635                        out.t4[a][b][c][d] = self.t4[a][b][c][d].lane(i);
2636                    }
2637                }
2638            }
2639        }
2640        out
2641    }
2642    /// Per-channel lane-wise `self + o` (mirrors `Tower4` `Add`).
2643    #[inline]
2644    pub fn add(&self, o: &Self) -> Self {
2645        let mut out = *self;
2646        out.v = self.v.add(o.v);
2647        for i in 0..K {
2648            out.g[i] = self.g[i].add(o.g[i]);
2649            for j in 0..K {
2650                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
2651                for k in 0..K {
2652                    out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
2653                    for l in 0..K {
2654                        out.t4[i][j][k][l] = self.t4[i][j][k][l].add(o.t4[i][j][k][l]);
2655                    }
2656                }
2657            }
2658        }
2659        out
2660    }
2661    /// Per-channel lane-wise `self - o` (mirrors `Tower4` `Sub`).
2662    #[inline]
2663    pub fn sub(&self, o: &Self) -> Self {
2664        let mut out = *self;
2665        out.v = self.v.sub(o.v);
2666        for i in 0..K {
2667            out.g[i] = self.g[i].sub(o.g[i]);
2668            for j in 0..K {
2669                out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
2670                for k in 0..K {
2671                    out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
2672                    for l in 0..K {
2673                        out.t4[i][j][k][l] = self.t4[i][j][k][l].sub(o.t4[i][j][k][l]);
2674                    }
2675                }
2676            }
2677        }
2678        out
2679    }
2680    /// Multiply every channel by the plain scalar `s` (mirrors `Tower4::scale`).
2681    #[inline]
2682    pub fn scale(&self, s: f64) -> Self {
2683        let sl = L::splat(s);
2684        let mut out = *self;
2685        out.v = self.v.mul(sl);
2686        for i in 0..K {
2687            out.g[i] = self.g[i].mul(sl);
2688            for j in 0..K {
2689                out.h[i][j] = self.h[i][j].mul(sl);
2690                for k in 0..K {
2691                    out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
2692                    for l in 0..K {
2693                        out.t4[i][j][k][l] = self.t4[i][j][k][l].mul(sl);
2694                    }
2695                }
2696            }
2697        }
2698        out
2699    }
2700    /// Leibniz product `self · o`, term-for-term lift of [`Tower4::mul`].
2701    #[inline]
2702    pub fn mul(&self, o: &Self) -> Self {
2703        let a = self;
2704        let b = o;
2705        let mut out = Self::zero();
2706        out.v = a.v.mul(b.v);
2707        for i in 0..K {
2708            let mut acc = L::splat(0.0);
2709            acc = acc.add(a.v.mul(b.g[i]));
2710            acc = acc.add(a.g[i].mul(b.v));
2711            out.g[i] = acc;
2712        }
2713        for i in 0..K {
2714            for j in 0..K {
2715                let mut acc = L::splat(0.0);
2716                acc = acc.add(a.v.mul(b.h[i][j]));
2717                acc = acc.add(a.g[i].mul(b.g[j]));
2718                acc = acc.add(a.g[j].mul(b.g[i]));
2719                acc = acc.add(a.h[i][j].mul(b.v));
2720                out.h[i][j] = acc;
2721            }
2722        }
2723        for i in 0..K {
2724            for j in 0..K {
2725                for k in 0..K {
2726                    let mut acc = L::splat(0.0);
2727                    acc = acc.add(a.v.mul(b.t3[i][j][k]));
2728                    acc = acc.add(a.g[i].mul(b.h[j][k]));
2729                    acc = acc.add(a.g[j].mul(b.h[i][k]));
2730                    acc = acc.add(a.h[i][j].mul(b.g[k]));
2731                    acc = acc.add(a.g[k].mul(b.h[i][j]));
2732                    acc = acc.add(a.h[i][k].mul(b.g[j]));
2733                    acc = acc.add(a.h[j][k].mul(b.g[i]));
2734                    acc = acc.add(a.t3[i][j][k].mul(b.v));
2735                    out.t3[i][j][k] = acc;
2736                }
2737            }
2738        }
2739        for i in 0..K {
2740            for j in 0..K {
2741                for k in 0..K {
2742                    for l in 0..K {
2743                        let mut acc = L::splat(0.0);
2744                        acc = acc.add(a.v.mul(b.t4[i][j][k][l]));
2745                        acc = acc.add(a.g[i].mul(b.t3[j][k][l]));
2746                        acc = acc.add(a.g[j].mul(b.t3[i][k][l]));
2747                        acc = acc.add(a.h[i][j].mul(b.h[k][l]));
2748                        acc = acc.add(a.g[k].mul(b.t3[i][j][l]));
2749                        acc = acc.add(a.h[i][k].mul(b.h[j][l]));
2750                        acc = acc.add(a.h[j][k].mul(b.h[i][l]));
2751                        acc = acc.add(a.t3[i][j][k].mul(b.g[l]));
2752                        acc = acc.add(a.g[l].mul(b.t3[i][j][k]));
2753                        acc = acc.add(a.h[i][l].mul(b.h[j][k]));
2754                        acc = acc.add(a.h[j][l].mul(b.h[i][k]));
2755                        acc = acc.add(a.t3[i][j][l].mul(b.g[k]));
2756                        acc = acc.add(a.h[k][l].mul(b.h[i][j]));
2757                        acc = acc.add(a.t3[i][k][l].mul(b.g[j]));
2758                        acc = acc.add(a.t3[j][k][l].mul(b.g[i]));
2759                        acc = acc.add(a.t4[i][j][k][l].mul(b.v));
2760                        out.t4[i][j][k][l] = acc;
2761                    }
2762                }
2763            }
2764        }
2765        out
2766    }
2767    /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
2768    /// [`Tower4::compose_unary`]. `d = [f, f′, f″, f‴, f⁗]` packed per lane
2769    /// (build via [`Lane::unary5`] from the scalar special-function stack).
2770    #[inline]
2771    pub fn compose_unary(&self, d: [L; 5]) -> Self {
2772        let mut out = Self::zero();
2773        out.v = d[0];
2774        for i in 0..K {
2775            let mut acc = L::splat(0.0);
2776            acc = acc.add(d[1].mul(self.g[i]));
2777            out.g[i] = acc;
2778        }
2779        for i in 0..K {
2780            for j in 0..K {
2781                let mut acc = L::splat(0.0);
2782                acc = acc.add(d[1].mul(self.h[i][j]));
2783                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
2784                out.h[i][j] = acc;
2785            }
2786        }
2787        for i in 0..K {
2788            for j in 0..K {
2789                for k in 0..K {
2790                    let mut acc = L::splat(0.0);
2791                    acc = acc.add(d[1].mul(self.t3[i][j][k]));
2792                    acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
2793                    acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
2794                    acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
2795                    acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
2796                    out.t3[i][j][k] = acc;
2797                }
2798            }
2799        }
2800        for i in 0..K {
2801            for j in 0..K {
2802                for k in 0..K {
2803                    for l in 0..K {
2804                        let mut acc = L::splat(0.0);
2805                        acc = acc.add(d[1].mul(self.t4[i][j][k][l]));
2806                        acc = acc.add(d[2].mul(self.t3[i][j][k]).mul(self.g[l]));
2807                        acc = acc.add(d[2].mul(self.t3[i][j][l]).mul(self.g[k]));
2808                        acc = acc.add(d[2].mul(self.h[i][j]).mul(self.h[k][l]));
2809                        acc = acc.add(d[3].mul(self.h[i][j]).mul(self.g[k]).mul(self.g[l]));
2810                        acc = acc.add(d[2].mul(self.t3[i][k][l]).mul(self.g[j]));
2811                        acc = acc.add(d[2].mul(self.h[i][k]).mul(self.h[j][l]));
2812                        acc = acc.add(d[3].mul(self.h[i][k]).mul(self.g[j]).mul(self.g[l]));
2813                        acc = acc.add(d[2].mul(self.h[i][l]).mul(self.h[j][k]));
2814                        acc = acc.add(d[2].mul(self.g[i]).mul(self.t3[j][k][l]));
2815                        acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][k]).mul(self.g[l]));
2816                        acc = acc.add(d[3].mul(self.h[i][l]).mul(self.g[j]).mul(self.g[k]));
2817                        acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][l]).mul(self.g[k]));
2818                        acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.h[k][l]));
2819                        acc = acc.add(d[4].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]).mul(self.g[l]));
2820                        out.t4[i][j][k][l] = acc;
2821                    }
2822                }
2823            }
2824        }
2825        out
2826    }
2827    /// Compose with a unary special-function whose `[f64; 5]` derivative stack is
2828    /// built from the base value through `stack_fn`, evaluated PER LANE — the
2829    /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
2830    /// seam (the SIMD twin of [`Tower4::compose_unary_with`]).
2831    ///
2832    /// Each of the four lanes carries a DISTINCT base value, so the scalar
2833    /// `stack_fn` is run once per lane at that lane's own value (via
2834    /// [`Lane::unary_with`]) and the `[f64; 5]` results are packed into `[L; 5]`;
2835    /// the composition is then the existing per-lane [`Self::compose_unary`].
2836    /// Because `unary_with` runs the identical scalar closure per lane and
2837    /// `compose_unary` is a term-for-term lift of the scalar tower, lane `i` of
2838    /// the result is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`
2839    /// — which is exactly what lets a row program written against the scalar
2840    /// [`Tower4::compose_unary_with`] seam re-instantiate, unchanged, at `f64x4`.
2841    #[inline]
2842    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
2843        self.compose_unary(self.v.unary_with(stack_fn))
2844    }
2845
2846    /// Single-active-slot fast path, term-for-term lift of
2847    /// [`Tower4::compose_unary_single_slot`] (only the 5 diagonal channels).
2848    #[inline]
2849    pub fn compose_unary_single_slot(&self, d: [L; 5], slot: usize) -> Self {
2850        let mut out = Self::zero();
2851        let s = slot;
2852        let g = self.g[s];
2853        let h = self.h[s][s];
2854        let t3 = self.t3[s][s][s];
2855        let t4 = self.t4[s][s][s][s];
2856        out.v = d[0];
2857        out.g[s] = {
2858            let mut acc = L::splat(0.0);
2859            acc = acc.add(d[1].mul(g));
2860            acc
2861        };
2862        out.h[s][s] = {
2863            let mut acc = L::splat(0.0);
2864            acc = acc.add(d[1].mul(h));
2865            acc = acc.add(d[2].mul(g).mul(g));
2866            acc
2867        };
2868        out.t3[s][s][s] = {
2869            let mut acc = L::splat(0.0);
2870            acc = acc.add(d[1].mul(t3));
2871            acc = acc.add(d[2].mul(h).mul(g));
2872            acc = acc.add(d[2].mul(h).mul(g));
2873            acc = acc.add(d[2].mul(g).mul(h));
2874            acc = acc.add(d[3].mul(g).mul(g).mul(g));
2875            acc
2876        };
2877        out.t4[s][s][s][s] = {
2878            let mut acc = L::splat(0.0);
2879            acc = acc.add(d[1].mul(t4));
2880            acc = acc.add(d[2].mul(t3).mul(g));
2881            acc = acc.add(d[2].mul(t3).mul(g));
2882            acc = acc.add(d[2].mul(h).mul(h));
2883            acc = acc.add(d[3].mul(h).mul(g).mul(g));
2884            acc = acc.add(d[2].mul(t3).mul(g));
2885            acc = acc.add(d[2].mul(h).mul(h));
2886            acc = acc.add(d[3].mul(h).mul(g).mul(g));
2887            acc = acc.add(d[2].mul(h).mul(h));
2888            acc = acc.add(d[2].mul(g).mul(t3));
2889            acc = acc.add(d[3].mul(g).mul(h).mul(g));
2890            acc = acc.add(d[3].mul(h).mul(g).mul(g));
2891            acc = acc.add(d[3].mul(g).mul(h).mul(g));
2892            acc = acc.add(d[3].mul(g).mul(g).mul(h));
2893            acc = acc.add(d[4].mul(g).mul(g).mul(g).mul(g));
2894            acc
2895        };
2896        out
2897    }
2898    /// Contract `t3` with a primary-space direction (lift of
2899    /// [`Tower4::third_contracted`]). Output-symmetric in `(a, b)`: compute the
2900    /// upper triangle and mirror — bit-identical to the full nest, lane-for-lane.
2901    #[inline]
2902    pub fn third_contracted(&self, dir: &[L; K]) -> [[L; K]; K] {
2903        let mut out = [[L::splat(0.0); K]; K];
2904        for a in 0..K {
2905            for b in a..K {
2906                let mut acc = L::splat(0.0);
2907                for c in 0..K {
2908                    acc = acc.add(self.t3[a][b][c].mul(dir[c]));
2909                }
2910                out[a][b] = acc;
2911                out[b][a] = acc;
2912            }
2913        }
2914        out
2915    }
2916    /// Contract `t4` with two primary-space directions (lift of
2917    /// [`Tower4::fourth_contracted`]). Output-symmetric in `(i, j)`: compute the
2918    /// upper triangle and mirror — bit-identical to the full nest, lane-for-lane.
2919    #[inline]
2920    pub fn fourth_contracted(&self, u: &[L; K], w: &[L; K]) -> [[L; K]; K] {
2921        let mut out = [[L::splat(0.0); K]; K];
2922        for i in 0..K {
2923            for j in i..K {
2924                let mut acc = L::splat(0.0);
2925                for k in 0..K {
2926                    for l in 0..K {
2927                        acc = acc.add(self.t4[i][j][k][l].mul(u[k]).mul(w[l]));
2928                    }
2929                }
2930                out[i][j] = acc;
2931                out[j][i] = acc;
2932            }
2933        }
2934        out
2935    }
2936}
2937
2938/// Lane-batched [`Tower3`] (order-≤3 sibling of [`Tower4Lane`]).
2939#[derive(Clone, Copy)]
2940pub struct Tower3Lane<L: Lane, const K: usize> {
2941    /// Value channel.
2942    pub v: L,
2943    /// Gradient.
2944    pub g: [L; K],
2945    /// Hessian.
2946    pub h: [[L; K]; K],
2947    /// Third tensor.
2948    pub t3: [[[L; K]; K]; K],
2949}
2950
2951/// The 4-rows-per-pass batched [`Tower3`] (`wide::f64x4` lanes).
2952pub type Tower3Batch<const K: usize> = Tower3Lane<wide::f64x4, K>;
2953
2954impl<L: Lane, const K: usize> Tower3Lane<L, K> {
2955    /// All-zero tower.
2956    #[inline]
2957    pub fn zero() -> Self {
2958        let z = L::splat(0.0);
2959        Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K] }
2960    }
2961    /// Constant `c` (per lane).
2962    #[inline]
2963    pub fn constant(c: L) -> Self {
2964        let mut o = Self::zero();
2965        o.v = c;
2966        o
2967    }
2968    /// Seeded variable `p_idx` at per-lane `value`.
2969    #[inline]
2970    pub fn variable(value: L, idx: usize) -> Self {
2971        let mut o = Self::constant(value);
2972        o.g[idx] = L::splat(1.0);
2973        o
2974    }
2975    /// Extract lane `i` as a scalar [`Tower3<K>`].
2976    #[inline]
2977    pub fn lane(&self, i: usize) -> Tower3<K> {
2978        let mut out = Tower3::<K>::zero();
2979        out.v = self.v.lane(i);
2980        for a in 0..K {
2981            out.g[a] = self.g[a].lane(i);
2982            for b in 0..K {
2983                out.h[a][b] = self.h[a][b].lane(i);
2984                for c in 0..K {
2985                    out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2986                }
2987            }
2988        }
2989        out
2990    }
2991    /// Per-channel lane-wise `self + o`.
2992    #[inline]
2993    pub fn add(&self, o: &Self) -> Self {
2994        let mut out = *self;
2995        out.v = self.v.add(o.v);
2996        for i in 0..K {
2997            out.g[i] = self.g[i].add(o.g[i]);
2998            for j in 0..K {
2999                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
3000                for k in 0..K {
3001                    out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
3002                }
3003            }
3004        }
3005        out
3006    }
3007    /// Per-channel lane-wise `self - o`.
3008    #[inline]
3009    pub fn sub(&self, o: &Self) -> Self {
3010        let mut out = *self;
3011        out.v = self.v.sub(o.v);
3012        for i in 0..K {
3013            out.g[i] = self.g[i].sub(o.g[i]);
3014            for j in 0..K {
3015                out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
3016                for k in 0..K {
3017                    out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
3018                }
3019            }
3020        }
3021        out
3022    }
3023    /// Multiply every channel by the plain scalar `s` (mirrors `Tower3::scale`).
3024    #[inline]
3025    pub fn scale(&self, s: f64) -> Self {
3026        let sl = L::splat(s);
3027        let mut out = *self;
3028        out.v = self.v.mul(sl);
3029        for i in 0..K {
3030            out.g[i] = self.g[i].mul(sl);
3031            for j in 0..K {
3032                out.h[i][j] = self.h[i][j].mul(sl);
3033                for k in 0..K {
3034                    out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
3035                }
3036            }
3037        }
3038        out
3039    }
3040    /// Leibniz product `self · o`, term-for-term lift of [`Tower3::mul`].
3041    #[inline]
3042    pub fn mul(&self, o: &Self) -> Self {
3043        let a = self;
3044        let b = o;
3045        let mut out = Self::zero();
3046        out.v = a.v.mul(b.v);
3047        for i in 0..K {
3048            let mut acc = L::splat(0.0);
3049            acc = acc.add(a.v.mul(b.g[i]));
3050            acc = acc.add(a.g[i].mul(b.v));
3051            out.g[i] = acc;
3052        }
3053        for i in 0..K {
3054            for j in 0..K {
3055                let mut acc = L::splat(0.0);
3056                acc = acc.add(a.v.mul(b.h[i][j]));
3057                acc = acc.add(a.g[i].mul(b.g[j]));
3058                acc = acc.add(a.g[j].mul(b.g[i]));
3059                acc = acc.add(a.h[i][j].mul(b.v));
3060                out.h[i][j] = acc;
3061            }
3062        }
3063        for i in 0..K {
3064            for j in 0..K {
3065                for k in 0..K {
3066                    let mut acc = L::splat(0.0);
3067                    acc = acc.add(a.v.mul(b.t3[i][j][k]));
3068                    acc = acc.add(a.g[i].mul(b.h[j][k]));
3069                    acc = acc.add(a.g[j].mul(b.h[i][k]));
3070                    acc = acc.add(a.h[i][j].mul(b.g[k]));
3071                    acc = acc.add(a.g[k].mul(b.h[i][j]));
3072                    acc = acc.add(a.h[i][k].mul(b.g[j]));
3073                    acc = acc.add(a.h[j][k].mul(b.g[i]));
3074                    acc = acc.add(a.t3[i][j][k].mul(b.v));
3075                    out.t3[i][j][k] = acc;
3076                }
3077            }
3078        }
3079        out
3080    }
3081    /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
3082    /// [`Tower3::compose_unary`]. `d = [f, f′, f″, f‴]` packed per lane.
3083    #[inline]
3084    pub fn compose_unary(&self, d: [L; 4]) -> Self {
3085        let mut out = Self::zero();
3086        out.v = d[0];
3087        for i in 0..K {
3088            let mut acc = L::splat(0.0);
3089            acc = acc.add(d[1].mul(self.g[i]));
3090            out.g[i] = acc;
3091        }
3092        for i in 0..K {
3093            for j in 0..K {
3094                let mut acc = L::splat(0.0);
3095                acc = acc.add(d[1].mul(self.h[i][j]));
3096                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
3097                out.h[i][j] = acc;
3098            }
3099        }
3100        for i in 0..K {
3101            for j in 0..K {
3102                for k in 0..K {
3103                    let mut acc = L::splat(0.0);
3104                    acc = acc.add(d[1].mul(self.t3[i][j][k]));
3105                    acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
3106                    acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
3107                    acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
3108                    acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
3109                    out.t3[i][j][k] = acc;
3110                }
3111            }
3112        }
3113        out
3114    }
3115    /// Compose with a unary special-function whose `[f64; 4]` derivative stack is
3116    /// built from the base value through `stack_fn`, evaluated PER LANE — the
3117    /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
3118    /// seam (the SIMD twin of [`Tower3::compose_unary_with`], order-≤3 sibling of
3119    /// [`Tower4Lane::compose_unary_with`]). The scalar `stack_fn` is run once per
3120    /// lane at that lane's own base value (via [`Lane::unary_with`]) and packed
3121    /// into `[L; 4]` for the existing per-lane [`Self::compose_unary`], so lane
3122    /// `i` is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`.
3123    #[inline]
3124    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
3125        self.compose_unary(self.v.unary_with(stack_fn))
3126    }
3127
3128    /// Single-active-slot fast path, term-for-term lift of
3129    /// [`Tower3::compose_unary_single_slot`].
3130    #[inline]
3131    pub fn compose_unary_single_slot(&self, d: [L; 4], slot: usize) -> Self {
3132        let mut out = Self::zero();
3133        let s = slot;
3134        let g = self.g[s];
3135        let h = self.h[s][s];
3136        let t3 = self.t3[s][s][s];
3137        out.v = d[0];
3138        out.g[s] = {
3139            let mut acc = L::splat(0.0);
3140            acc = acc.add(d[1].mul(g));
3141            acc
3142        };
3143        out.h[s][s] = {
3144            let mut acc = L::splat(0.0);
3145            acc = acc.add(d[1].mul(h));
3146            acc = acc.add(d[2].mul(g).mul(g));
3147            acc
3148        };
3149        out.t3[s][s][s] = {
3150            let mut acc = L::splat(0.0);
3151            acc = acc.add(d[1].mul(t3));
3152            acc = acc.add(d[2].mul(h).mul(g));
3153            acc = acc.add(d[2].mul(h).mul(g));
3154            acc = acc.add(d[2].mul(g).mul(h));
3155            acc = acc.add(d[3].mul(g).mul(g).mul(g));
3156            acc
3157        };
3158        out
3159    }
3160}
3161
3162#[cfg(test)]
3163mod batch_tests {
3164    use super::*;
3165
3166    struct Rng(u64);
3167    impl Rng {
3168        fn f(&mut self) -> f64 {
3169            self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
3170            ((self.0 >> 11) as f64 / (1u64 << 53) as f64) * 4.0 - 2.0
3171        }
3172    }
3173
3174    // Fill every channel of a scalar Tower4<K> with random data.
3175    fn rand_t4<const K: usize>(r: &mut Rng) -> Tower4<K> {
3176        let mut t = Tower4::<K>::zero();
3177        t.v = r.f();
3178        for i in 0..K {
3179            t.g[i] = r.f();
3180            for j in 0..K {
3181                t.h[i][j] = r.f();
3182                for k in 0..K {
3183                    t.t3[i][j][k] = r.f();
3184                    for l in 0..K {
3185                        t.t4[i][j][k][l] = r.f();
3186                    }
3187                }
3188            }
3189        }
3190        t
3191    }
3192    fn rand_t3<const K: usize>(r: &mut Rng) -> Tower3<K> {
3193        let mut t = Tower3::<K>::zero();
3194        t.v = r.f();
3195        for i in 0..K {
3196            t.g[i] = r.f();
3197            for j in 0..K {
3198                t.h[i][j] = r.f();
3199                for k in 0..K {
3200                    t.t3[i][j][k] = r.f();
3201                }
3202            }
3203        }
3204        t
3205    }
3206    fn pack4_t4<const K: usize>(rows: &[Tower4<K>; 4]) -> Tower4Batch<K> {
3207        let mut b = Tower4Batch::<K>::zero();
3208        let lane = |f: &dyn Fn(&Tower4<K>) -> f64| {
3209            wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3210        };
3211        b.v = lane(&|t| t.v);
3212        for i in 0..K {
3213            b.g[i] = lane(&|t| t.g[i]);
3214            for j in 0..K {
3215                b.h[i][j] = lane(&|t| t.h[i][j]);
3216                for k in 0..K {
3217                    b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3218                    for l in 0..K {
3219                        b.t4[i][j][k][l] = lane(&|t| t.t4[i][j][k][l]);
3220                    }
3221                }
3222            }
3223        }
3224        b
3225    }
3226    fn pack4_t3<const K: usize>(rows: &[Tower3<K>; 4]) -> Tower3Batch<K> {
3227        let mut b = Tower3Batch::<K>::zero();
3228        let lane = |f: &dyn Fn(&Tower3<K>) -> f64| {
3229            wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3230        };
3231        b.v = lane(&|t| t.v);
3232        for i in 0..K {
3233            b.g[i] = lane(&|t| t.g[i]);
3234            for j in 0..K {
3235                b.h[i][j] = lane(&|t| t.h[i][j]);
3236                for k in 0..K {
3237                    b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3238                }
3239            }
3240        }
3241        b
3242    }
3243    fn assert_t4_eq<const K: usize>(b: &Tower4<K>, s: &Tower4<K>, ctx: &str) {
3244        assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3245        for i in 0..K {
3246            assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3247            for j in 0..K {
3248                assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3249                for k in 0..K {
3250                    assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3251                    for l in 0..K {
3252                        assert_eq!(b.t4[i][j][k][l].to_bits(), s.t4[i][j][k][l].to_bits(), "t4 {ctx}");
3253                    }
3254                }
3255            }
3256        }
3257    }
3258    fn assert_t3_eq<const K: usize>(b: &Tower3<K>, s: &Tower3<K>, ctx: &str) {
3259        assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3260        for i in 0..K {
3261            assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3262            for j in 0..K {
3263                assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3264                for k in 0..K {
3265                    assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3266                }
3267            }
3268        }
3269    }
3270
3271    // Run a representative op chain on 4 scalar rows and on the f64x4 batch,
3272    // then assert every channel of every lane is to_bits-identical.
3273    fn run4<const K: usize>(seed: u64, batches: usize) -> usize {
3274        let mut r = Rng(seed);
3275        let mut rows_checked = 0;
3276        for _ in 0..batches {
3277            let a: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3278            let b: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3279            let d: [[f64; 5]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3280            let dir: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3281            let dir2: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3282            let s = r.f();
3283
3284            // scalar per-row reference
3285            let scal: [Tower4<K>; 4] = std::array::from_fn(|rw| {
3286                let prod = a[rw].mul(&b[rw]);
3287                let comp = prod.compose_unary(d[rw]);
3288                let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3289                summed.compose_unary_single_slot(d[rw], 0)
3290            });
3291            let third: [[[f64; K]; K]; 4] =
3292                std::array::from_fn(|rw| a[rw].third_contracted(&dir[rw]));
3293            let fourth: [[[f64; K]; K]; 4] =
3294                std::array::from_fn(|rw| a[rw].fourth_contracted(&dir[rw], &dir2[rw]));
3295
3296            // batched f64x4
3297            let ab = pack4_t4(&a);
3298            let bb = pack4_t4(&b);
3299            let db: [wide::f64x4; 5] = std::array::from_fn(|c| {
3300                wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
3301            });
3302            let dirb: [wide::f64x4; K] = std::array::from_fn(|c| {
3303                wide::f64x4::new([dir[0][c], dir[1][c], dir[2][c], dir[3][c]])
3304            });
3305            let dir2b: [wide::f64x4; K] = std::array::from_fn(|c| {
3306                wide::f64x4::new([dir2[0][c], dir2[1][c], dir2[2][c], dir2[3][c]])
3307            });
3308            let prodb = ab.mul(&bb);
3309            let compb = prodb.compose_unary(db);
3310            let summedb = compb.add(&ab).sub(&bb).scale(s);
3311            let finalb = summedb.compose_unary_single_slot(db, 0);
3312            let thirdb = ab.third_contracted(&dirb);
3313            let fourthb = ab.fourth_contracted(&dirb, &dir2b);
3314
3315            for rw in 0..4 {
3316                assert_t4_eq(&finalb.lane(rw), &scal[rw], "t4-chain");
3317                for i in 0..K {
3318                    for j in 0..K {
3319                        assert_eq!(thirdb[i][j].lane(rw).to_bits(), third[rw][i][j].to_bits(), "third");
3320                        assert_eq!(fourthb[i][j].lane(rw).to_bits(), fourth[rw][i][j].to_bits(), "fourth");
3321                    }
3322                }
3323                rows_checked += 1;
3324            }
3325        }
3326        rows_checked
3327    }
3328    fn run3<const K: usize>(seed: u64, batches: usize) -> usize {
3329        let mut r = Rng(seed);
3330        let mut rows_checked = 0;
3331        for _ in 0..batches {
3332            let a: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3333            let b: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3334            let d: [[f64; 4]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3335            let s = r.f();
3336            let scal: [Tower3<K>; 4] = std::array::from_fn(|rw| {
3337                let prod = a[rw].mul(&b[rw]);
3338                let comp = prod.compose_unary(d[rw]);
3339                let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3340                summed.compose_unary_single_slot(d[rw], 0)
3341            });
3342            let ab = pack4_t3(&a);
3343            let bb = pack4_t3(&b);
3344            let db: [wide::f64x4; 4] = std::array::from_fn(|c| {
3345                wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
3346            });
3347            let prodb = ab.mul(&bb);
3348            let compb = prodb.compose_unary(db);
3349            let summedb = compb.add(&ab).sub(&bb).scale(s);
3350            let finalb = summedb.compose_unary_single_slot(db, 0);
3351            for rw in 0..4 {
3352                assert_t3_eq(&finalb.lane(rw), &scal[rw], "t3-chain");
3353                rows_checked += 1;
3354            }
3355        }
3356        rows_checked
3357    }
3358
3359    // A `Tower4Batch<9>` carries a `9⁴ = 6561`-entry `t4` tensor in 4-wide
3360    // lanes (≈210 KiB by value); the op chain keeps several live, which can
3361    // exceed a test thread's default stack. Run each width on a large-stack
3362    // thread so K=9 is exercised without a stack overflow.
3363    fn big_stack<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(f: F) -> R {
3364        std::thread::Builder::new()
3365            .stack_size(512 << 20)
3366            .spawn(f)
3367            .unwrap()
3368            .join()
3369            .unwrap()
3370    }
3371
3372    #[test]
3373    fn tower4_batch_lane_bit_identical() {
3374        let batches = 2000;
3375        let rows_checked = big_stack(move || run4::<2>(0x1111_2222_3333_4444, batches))
3376            + big_stack(move || run4::<3>(0x5555_6666_7777_8888, batches))
3377            + big_stack(move || run4::<4>(0x9999_aaaa_bbbb_cccc, batches))
3378            + big_stack(move || run4::<9>(0xdddd_eeee_ffff_0000, batches));
3379        // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3380        // worker threads against silently running zero comparisons.
3381        assert_eq!(rows_checked, 4 * batches * 4);
3382    }
3383
3384    #[test]
3385    fn tower3_batch_lane_bit_identical() {
3386        let batches = 2000;
3387        let rows_checked = big_stack(move || run3::<2>(0x0f0f_1e1e_2d2d_3c3c, batches))
3388            + big_stack(move || run3::<3>(0x4b4b_5a5a_6969_7878, batches))
3389            + big_stack(move || run3::<4>(0x8787_9696_a5a5_b4b4, batches))
3390            + big_stack(move || run3::<9>(0xc3c3_d2d2_e1e1_f0f0, batches));
3391        // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3392        // worker threads against silently running zero comparisons.
3393        assert_eq!(rows_checked, 4 * batches * 4);
3394    }
3395
3396    // ── compose_unary_with seam (generic-over-Lane compose) ─────────────────
3397    //
3398    // The seam lets a single-sourced row program build its special-function
3399    // STACK from the base value through a closure, so the SAME expression
3400    // instantiates at a scalar tower (one base) AND a batch tower (four distinct
3401    // per-lane bases). These oracles pin both arms `to_bits`.
3402
3403    /// A base-value-dependent `[f64; 5]` derivative stack (finite for finite `u`),
3404    /// standing in for a family's hand-certified special-function stack. `stack4`
3405    /// is its order-≤3 truncation.
3406    fn seam_stack5(u: f64) -> [f64; 5] {
3407        [u.sin(), u.cos(), (2.0 * u).sin(), (0.5 * u).cos(), u * u - 0.3]
3408    }
3409    fn seam_stack4(u: f64) -> [f64; 4] {
3410        let s = seam_stack5(u);
3411        [s[0], s[1], s[2], s[3]]
3412    }
3413
3414    /// Force a distinct / edge per-lane base value (signed zeros included).
3415    fn seam_edge_base(r: &mut Rng, which: usize) -> f64 {
3416        match which {
3417            0 => -0.0,
3418            1 => 0.0,
3419            2 => r.f(),
3420            _ => r.f() + 3.0,
3421        }
3422    }
3423
3424    /// (a) scalar arm: `Tower4::compose_unary_with(f)` is `to_bits`-identical to
3425    /// the explicit `compose_unary(f(value))` on every channel.
3426    fn scalar_seam_t4<const K: usize>(seed: u64, n: usize) -> usize {
3427        let mut r = Rng(seed);
3428        for _ in 0..n {
3429            let mut t = rand_t4::<K>(&mut r);
3430            t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3431            assert_t4_eq(
3432                &t.compose_unary_with(seam_stack5),
3433                &t.compose_unary(seam_stack5(t.v)),
3434                "scalar t4 seam",
3435            );
3436        }
3437        n
3438    }
3439    fn scalar_seam_t3<const K: usize>(seed: u64, n: usize) -> usize {
3440        let mut r = Rng(seed);
3441        for _ in 0..n {
3442            let mut t = rand_t3::<K>(&mut r);
3443            t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3444            assert_t3_eq(
3445                &t.compose_unary_with(seam_stack4),
3446                &t.compose_unary(seam_stack4(t.v)),
3447                "scalar t3 seam",
3448            );
3449        }
3450        n
3451    }
3452
3453    /// (b) lane arm: `Tower4Lane::compose_unary_with` lane `i` is
3454    /// `to_bits`-identical to the scalar `Tower4::compose_unary_with` on row `i`,
3455    /// with the four lanes carrying DISTINCT base values (signed zeros included),
3456    /// so a buggy impl reusing one lane's base would fail.
3457    fn lane_seam_t4<const K: usize>(seed: u64, batches: usize) -> usize {
3458        let mut r = Rng(seed);
3459        let mut verified = 0usize;
3460        for _ in 0..batches {
3461            let mut rows: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3462            for (rw, row) in rows.iter_mut().enumerate() {
3463                row.v = seam_edge_base(&mut r, rw);
3464            }
3465            let batch_out = pack4_t4(&rows).compose_unary_with(seam_stack5);
3466            for (rw, row) in rows.iter().enumerate() {
3467                assert_t4_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack5), "lane t4 seam");
3468                verified += 1;
3469            }
3470        }
3471        verified
3472    }
3473    fn lane_seam_t3<const K: usize>(seed: u64, batches: usize) -> usize {
3474        let mut r = Rng(seed);
3475        let mut verified = 0usize;
3476        for _ in 0..batches {
3477            let mut rows: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3478            for (rw, row) in rows.iter_mut().enumerate() {
3479                row.v = seam_edge_base(&mut r, rw);
3480            }
3481            let batch_out = pack4_t3(&rows).compose_unary_with(seam_stack4);
3482            for (rw, row) in rows.iter().enumerate() {
3483                assert_t3_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack4), "lane t3 seam");
3484                verified += 1;
3485            }
3486        }
3487        verified
3488    }
3489
3490    #[test]
3491    fn compose_unary_with_scalar_bit_identical() {
3492        let n = 1100;
3493        let total = scalar_seam_t4::<2>(0x2200_0001, n)
3494            + scalar_seam_t4::<3>(0x2200_0002, n)
3495            + scalar_seam_t4::<4>(0x2200_0003, n)
3496            + big_stack(move || scalar_seam_t4::<9>(0x2200_0004, n))
3497            + scalar_seam_t3::<2>(0x3300_0001, n)
3498            + scalar_seam_t3::<3>(0x3300_0002, n)
3499            + scalar_seam_t3::<4>(0x3300_0003, n)
3500            + big_stack(move || scalar_seam_t3::<9>(0x3300_0004, n));
3501        // 8 arms × 1100 = 8800 ≥ 4000 inputs.
3502        assert_eq!(total, 8 * n);
3503    }
3504
3505    #[test]
3506    fn compose_unary_with_lane_matches_scalar() {
3507        let b = 600;
3508        let total = lane_seam_t4::<2>(0x4400_0001, b)
3509            + lane_seam_t4::<3>(0x4400_0002, b)
3510            + lane_seam_t4::<4>(0x4400_0003, b)
3511            + big_stack(move || lane_seam_t4::<9>(0x4400_0004, b))
3512            + lane_seam_t3::<2>(0x5500_0001, b)
3513            + lane_seam_t3::<3>(0x5500_0002, b)
3514            + lane_seam_t3::<4>(0x5500_0003, b)
3515            + big_stack(move || lane_seam_t3::<9>(0x5500_0004, b));
3516        // 8 arms × 600 = 4800 batches ≥ 2000; each verifies 4 lanes (19200 checks).
3517        assert_eq!(total, 8 * b * 4);
3518    }
3519}
3520
3521#[cfg(test)]
3522mod tests {
3523    use super::*;
3524
3525    /// `Tower3<K>` must be bit-identical to `Tower4<K>` on every channel it
3526    /// carries (value, gradient, Hessian, third derivatives). The order-≤3
3527    /// Leibniz / Faà-di-Bruno terms read only order-≤3 inner channels, so
3528    /// dropping the fourth tensor cannot perturb them. Exercises products
3529    /// (Leibniz cross-terms), unary composition, scaling, and addition — the
3530    /// same operations the survival location-scale `nll_index_tower` composes —
3531    /// across all mixed partials, not just the diagonal entries that kernel reads.
3532    #[test]
3533    fn tower3_matches_tower4_through_third_order() {
3534        let s_a: [f64; 5] = [
3535            0.3_f64.sin(),
3536            0.3_f64.cos(),
3537            -0.3_f64.sin(),
3538            -0.3_f64.cos(),
3539            0.3_f64.sin(),
3540        ];
3541        let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
3542        let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];
3543
3544        let a4 = Tower4::<3>::variable(0.4, 0);
3545        let b4 = Tower4::<3>::variable(-0.7, 1);
3546        let c4 = Tower4::<3>::variable(0.9, 2);
3547        let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
3548            + a4.mul(&c4).scale(-0.7)
3549            + b4.compose_unary(s_b).scale(0.25);
3550
3551        let a3 = Tower3::<3>::variable(0.4, 0);
3552        let b3 = Tower3::<3>::variable(-0.7, 1);
3553        let c3 = Tower3::<3>::variable(0.9, 2);
3554        let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
3555            + a3.mul(&c3).scale(-0.7)
3556            + b3.compose_unary(s4(s_b)).scale(0.25);
3557
3558        assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
3559        for i in 0..3 {
3560            assert_eq!(
3561                prog3.g[i].to_bits(),
3562                prog4.g[i].to_bits(),
3563                "g[{i}] mismatch"
3564            );
3565            for j in 0..3 {
3566                assert_eq!(
3567                    prog3.h[i][j].to_bits(),
3568                    prog4.h[i][j].to_bits(),
3569                    "h[{i}][{j}] mismatch"
3570                );
3571                for k in 0..3 {
3572                    assert_eq!(
3573                        prog3.t3[i][j][k].to_bits(),
3574                        prog4.t3[i][j][k].to_bits(),
3575                        "t3[{i}][{j}][{k}] mismatch"
3576                    );
3577                }
3578            }
3579        }
3580    }
3581
3582    /// Binomial-logit row NLL, K=1: ℓ(η) = ln(1 + e^η) − y·η.
3583    /// The entire tower has textbook closed forms in μ = σ(η); this test
3584    /// pins the algebra (exp, ln, scalar mixes, Leibniz/Faà di Bruno) to
3585    /// analytic truth at near-machine precision.
3586    struct LogitProgram {
3587        eta: Vec<f64>,
3588        y: Vec<f64>,
3589    }
3590
3591    impl RowNllProgram<1> for LogitProgram {
3592        fn n_rows(&self) -> usize {
3593            self.eta.len()
3594        }
3595        fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
3596            Ok([self.eta[row]])
3597        }
3598        fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
3599            let eta = p[0];
3600            Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
3601        }
3602    }
3603
3604    #[test]
3605    fn logit_tower_matches_closed_forms() {
3606        let prog = LogitProgram {
3607            eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
3608            y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
3609        };
3610        for row in 0..prog.n_rows() {
3611            let t = evaluate_program(&prog, row).expect("logit program");
3612            let eta = prog.eta[row];
3613            let y = prog.y[row];
3614            let mu = 1.0 / (1.0 + (-eta).exp());
3615            let w = mu * (1.0 - mu);
3616            let expect = [
3617                (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
3618                (t.g[0], mu - y, "grad"),
3619                (t.h[0][0], w, "hess"),
3620                (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
3621                (
3622                    t.t4[0][0][0][0],
3623                    w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
3624                    "fourth",
3625                ),
3626            ];
3627            for (got, want, label) in expect {
3628                assert!(
3629                    (got - want).abs() <= 1e-12 * want.abs().max(1.0),
3630                    "row {row} {label}: got {got:+.15e} want {want:+.15e}"
3631                );
3632            }
3633        }
3634    }
3635
3636    fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
3637        let diff = (got - want).abs();
3638        assert!(
3639            diff <= rel_tol * want.abs().max(1.0),
3640            "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
3641        );
3642    }
3643
3644    #[test]
3645    fn gamma_special_function_stacks_match_reference_values() {
3646        const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
3647        let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
3648        let cases = [
3649            (
3650                "x=0.1",
3651                0.1,
3652                -10.423_754_940_411_076,
3653                101.433_299_150_792_75,
3654            ),
3655            (
3656                "x=0.5",
3657                0.5,
3658                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
3659                pi_sq / 2.0,
3660            ),
3661            ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
3662            (
3663                "x=2.5",
3664                2.5,
3665                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
3666                pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
3667            ),
3668            (
3669                "x=50",
3670                50.0,
3671                3.901_989_673_427_892,
3672                0.020_201_333_226_697_128,
3673            ),
3674        ];
3675
3676        for (label, x, digamma_ref, trigamma_ref) in cases {
3677            let ln_gamma_stack = ln_gamma_derivative_stack(x);
3678            let digamma_stack = digamma_derivative_stack(x);
3679            let trigamma_stack = trigamma_derivative_stack(x);
3680            assert_close(
3681                &format!("{label} ln_gamma_stack digamma"),
3682                ln_gamma_stack[1],
3683                digamma_ref,
3684                1e-13,
3685            );
3686            assert_close(
3687                &format!("{label} digamma value"),
3688                digamma_stack[0],
3689                digamma_ref,
3690                1e-13,
3691            );
3692            assert_close(
3693                &format!("{label} ln_gamma_stack trigamma"),
3694                ln_gamma_stack[2],
3695                trigamma_ref,
3696                1e-13,
3697            );
3698            assert_close(
3699                &format!("{label} digamma_stack trigamma"),
3700                digamma_stack[1],
3701                trigamma_ref,
3702                1e-13,
3703            );
3704            assert_close(
3705                &format!("{label} trigamma value"),
3706                trigamma_stack[0],
3707                trigamma_ref,
3708                1e-13,
3709            );
3710        }
3711    }
3712
3713    #[test]
3714    fn gamma_special_function_stacks_obey_recurrences() {
3715        for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
3716            let digamma_x = digamma_derivative_stack(x)[0];
3717            let digamma_next = digamma_derivative_stack(x + 1.0)[0];
3718            let trigamma_x = trigamma_derivative_stack(x)[0];
3719            let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
3720            assert_close(
3721                &format!("digamma recurrence x={x}"),
3722                digamma_next,
3723                digamma_x + 1.0 / x,
3724                1e-13,
3725            );
3726            assert_close(
3727                &format!("trigamma recurrence x={x}"),
3728                trigamma_next,
3729                trigamma_x - 1.0 / (x * x),
3730                1e-13,
3731            );
3732        }
3733    }
3734
3735    /// Gaussian location-scale row NLL, K=2 primaries (η, s = log σ):
3736    /// ℓ = s + ½ e^{−2s} (y − η)². Mixed cross blocks — the #736 fragility
3737    /// shape — all have one-line closed forms here.
3738    struct LocScaleProgram {
3739        eta: Vec<f64>,
3740        s: Vec<f64>,
3741        y: Vec<f64>,
3742    }
3743
3744    impl RowNllProgram<2> for LocScaleProgram {
3745        fn n_rows(&self) -> usize {
3746            self.eta.len()
3747        }
3748        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
3749            Ok([self.eta[row], self.s[row]])
3750        }
3751        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
3752            let r = -(p[0] - self.y[row]);
3753            Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
3754        }
3755    }
3756
3757    #[test]
3758    fn locscale_tower_matches_closed_forms_including_cross_blocks() {
3759        let prog = LocScaleProgram {
3760            eta: vec![0.3, -1.1, 2.0],
3761            s: vec![-0.5, 0.2, 0.8],
3762            y: vec![1.0, -2.0, 2.5],
3763        };
3764        let tol = 1e-12;
3765        for row in 0..prog.n_rows() {
3766            let t = evaluate_program(&prog, row).expect("locscale program");
3767            let r = prog.y[row] - prog.eta[row];
3768            let w = (-2.0 * prog.s[row]).exp();
3769            // (η, s) = indices (0, 1).
3770            let truth_g = [-w * r, 1.0 - w * r * r];
3771            let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
3772            // Third tensor: distinct-entry closed forms.
3773            // ∂ηηη = 0, ∂ηηs = −2w, ∂ηss = −4wr, ∂sss = −4wr².
3774            let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
3775                match a + b + c {
3776                    0 => 0.0,
3777                    1 => -2.0 * w,
3778                    2 => -4.0 * w * r,
3779                    _ => -4.0 * w * r * r,
3780                }
3781            };
3782            // Fourth tensor: ∂ηηηη = 0, ∂ηηηs = 0? No: d/ds(∂ηηη)=0 ✓;
3783            // ∂ηηss = 4w, ∂ηsss = 8wr, ∂ssss = 8wr².
3784            let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
3785                match a + b + c + d {
3786                    0 | 1 => 0.0,
3787                    2 => 4.0 * w,
3788                    3 => 8.0 * w * r,
3789                    _ => 8.0 * w * r * r,
3790                }
3791            };
3792            for a in 0..2 {
3793                assert!(
3794                    (t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
3795                    "row {row} grad[{a}]"
3796                );
3797                for b in 0..2 {
3798                    assert!(
3799                        (t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
3800                        "row {row} hess[{a}][{b}]: got {} want {}",
3801                        t.h[a][b],
3802                        truth_h[a][b]
3803                    );
3804                    for c in 0..2 {
3805                        assert!(
3806                            (t.t3[a][b][c] - t3_truth(a, b, c)).abs()
3807                                <= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3808                            "row {row} t3[{a}][{b}][{c}]: got {} want {}",
3809                            t.t3[a][b][c],
3810                            t3_truth(a, b, c)
3811                        );
3812                        for d in 0..2 {
3813                            assert!(
3814                                (t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
3815                                    <= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3816                                "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
3817                                t.t4[a][b][c][d],
3818                                t4_truth(a, b, c, d)
3819                            );
3820                        }
3821                    }
3822                }
3823            }
3824            // The derived trait-surface helpers agree with direct contraction.
3825            let dir = [0.7, -1.3];
3826            let third = derived_third_contracted(&prog, row, &dir).expect("third");
3827            for a in 0..2 {
3828                for b in 0..2 {
3829                    let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
3830                    assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
3831                }
3832            }
3833        }
3834    }
3835
3836    /// FD cross-check on a deliberately gnarly composition (div, sqrt,
3837    /// powf, nested exp/ln) in K=3, where no closed form is consulted:
3838    /// every tower channel is checked against central finite differences
3839    /// of the channel one order below — value→grad, grad→hess, hess→t3,
3840    /// t3→t4 — so each order is independently anchored.
3841    ///
3842    /// The program carries a per-row primary fixture plus a per-row offset
3843    /// `tau[row]` that enters the loss as a constant, so `row` genuinely
3844    /// drives both the seed point and the evaluated expression.
3845    struct GnarlyProgram {
3846        primaries: Vec<[f64; 3]>,
3847        tau: Vec<f64>,
3848    }
3849
3850    impl GnarlyProgram {
3851        fn fixture() -> Self {
3852            Self {
3853                primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
3854                tau: vec![0.15, -0.35, 0.5],
3855            }
3856        }
3857    }
3858
3859    impl RowNllProgram<3> for GnarlyProgram {
3860        fn n_rows(&self) -> usize {
3861            self.primaries.len()
3862        }
3863        fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3864            self.primaries
3865                .get(row)
3866                .copied()
3867                .ok_or_else(|| format!("gnarly: row {row} out of range"))
3868        }
3869        fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3870            let tau = *self
3871                .tau
3872                .get(row)
3873                .ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
3874            let a = (p[0] * p[1]).exp();
3875            let b = (p[2] * p[2] + 1.0).sqrt();
3876            let c = (a + b + tau).ln();
3877            let d = (p[1] * 0.5 + 2.0).powf(1.7);
3878            Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
3879        }
3880    }
3881
3882    /// Evaluate the gnarly program's tower at an ARBITRARY seed point for
3883    /// `row` (used to drive central differences off the fixture grid),
3884    /// while keeping `row`'s per-row data (`tau`) in the loss.
3885    fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
3886        struct At<'a> {
3887            base: &'a GnarlyProgram,
3888            row: usize,
3889            p: [f64; 3],
3890        }
3891        impl RowNllProgram<3> for At<'_> {
3892            fn n_rows(&self) -> usize {
3893                1
3894            }
3895            fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3896                if row != 0 {
3897                    return Err(format!("gnarly-at: row {row} out of range"));
3898                }
3899                Ok(self.p)
3900            }
3901            fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3902                if eval_row != 0 {
3903                    return Err(format!("gnarly-at: eval row {eval_row} out of range"));
3904                }
3905                self.base.row_nll(self.row, vars)
3906            }
3907        }
3908        evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
3909    }
3910
3911    #[test]
3912    fn gnarly_tower_is_fd_consistent_order_by_order() {
3913        let prog = GnarlyProgram::fixture();
3914        for row in 0..prog.n_rows() {
3915            let base = prog.primaries(row).expect("primaries");
3916            let t = gnarly_tower_at(&prog, row, base);
3917            let h_step = 1e-5;
3918            let tol = 1e-6;
3919            for c in 0..3 {
3920                let mut up = base;
3921                let mut dn = base;
3922                up[c] += h_step;
3923                dn[c] -= h_step;
3924                let t_up = gnarly_tower_at(&prog, row, up);
3925                let t_dn = gnarly_tower_at(&prog, row, dn);
3926                // value → gradient.
3927                let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
3928                assert!(
3929                    (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
3930                    "grad[{c}]: analytic {} fd {}",
3931                    t.g[c],
3932                    fd_g
3933                );
3934                for a in 0..3 {
3935                    // gradient → Hessian.
3936                    let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
3937                    assert!(
3938                        (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
3939                        "hess[{a}][{c}]: analytic {} fd {}",
3940                        t.h[a][c],
3941                        fd_h
3942                    );
3943                    for b in 0..3 {
3944                        // Hessian → third.
3945                        let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
3946                        assert!(
3947                            (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
3948                            "t3[{a}][{b}][{c}]: analytic {} fd {}",
3949                            t.t3[a][b][c],
3950                            fd_t3
3951                        );
3952                        for d in 0..3 {
3953                            // third → fourth.
3954                            let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
3955                            assert!(
3956                                (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
3957                                "t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
3958                                t.t4[a][b][d][c],
3959                                fd_t4
3960                            );
3961                        }
3962                    }
3963                }
3964            }
3965        }
3966    }
3967
3968    /// `implicit_solve` reproduces the true implicit function `a(θ)` of a
3969    /// constraint `F(a, θ) = 0` to fourth order. The constraint here is the
3970    /// smooth, strictly-`a`-monotone
3971    ///   F(a, θ) = a + θ₀·a² + θ₁·exp(a) − c
3972    /// whose root `a(θ)` is re-solved by scalar Newton at perturbed θ as the
3973    /// independent finite-difference oracle. Mirrors the survival flex
3974    /// calibration solve (one implicit intercept over the primaries) without
3975    /// any survival machinery, so a failure localises to the combinator.
3976    #[test]
3977    fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
3978        const C: f64 = 1.7;
3979        // The scalar constraint as a plain f64 closure (the production root
3980        // finder analogue) and its tower form in (a, θ₀, θ₁).
3981        let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
3982        let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
3983        let solve = |th: [f64; 2]| -> f64 {
3984            let mut a = 0.0_f64;
3985            for _ in 0..100 {
3986                let r = f_scalar(a, th);
3987                if r.abs() < 1e-14 {
3988                    break;
3989                }
3990                a -= r / f_da(a, th);
3991            }
3992            a
3993        };
3994        // Tower constraint over K1 = 3 vars: slot 0 = a, slots 1,2 = θ₀, θ₁.
3995        let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
3996            let a = Tower4::<3>::variable(a0, 0);
3997            let t0 = Tower4::<3>::variable(th[0], 1);
3998            let t1 = Tower4::<3>::variable(th[1], 2);
3999            a + t0 * a.mul(&a) + t1 * a.exp() - C
4000        };
4001
4002        let th0 = [0.35, 0.2];
4003        let a0 = solve(th0);
4004        let f = f_tower(a0, th0);
4005        // Residual at the solved point is ~0 (the combinator tolerates the
4006        // production Newton residual; here it is machine-zero).
4007        assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
4008        let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");
4009
4010        // FD oracle: central differences of the scalar re-solve. Each order is
4011        // built from the previous via one more central difference, exactly the
4012        // gnarly order-by-order ladder.
4013        let h = 1e-4;
4014        let tol = 1e-5;
4015        let re = |th: [f64; 2]| solve(th);
4016        for i in 0..2 {
4017            let mut up = th0;
4018            let mut dn = th0;
4019            up[i] += h;
4020            dn[i] -= h;
4021            let fd_g = (re(up) - re(dn)) / (2.0 * h);
4022            assert!(
4023                (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4024                "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
4025                a_tower.g[i],
4026                fd_g
4027            );
4028            // second order: FD of the analytic gradient component would re-use
4029            // the combinator; instead difference a SCALAR gradient computed by
4030            // a nested re-solve so the oracle stays production-independent.
4031            let grad_at = |th: [f64; 2], j: usize| -> f64 {
4032                let mut up = th;
4033                let mut dn = th;
4034                up[j] += h;
4035                dn[j] -= h;
4036                (re(up) - re(dn)) / (2.0 * h)
4037            };
4038            for j in 0..2 {
4039                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4040                assert!(
4041                    (a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4042                    "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
4043                    a_tower.h[i][j],
4044                    fd_h
4045                );
4046            }
4047        }
4048    }
4049
4050    /// `implicit_solve` degenerates to `a_θ = −F_θ / F_a` at first order on a
4051    /// linear-in-a constraint, and the second-order tensor matches the
4052    /// textbook IFT formula `a_uv = −(F_uv + F_au a_v + F_av a_u + F_aa a_u a_v)/F_a`.
4053    /// This pins the recursion against the hand-coded first_full.rs formula it
4054    /// replaces, independent of any FD step.
4055    #[test]
4056    fn implicit_solve_matches_textbook_ift_recursion() {
4057        // A constraint with non-trivial F_a, F_aa, F_au, F_uv all present.
4058        let a0 = 0.4_f64;
4059        let th = [0.25_f64, -0.15_f64];
4060        let f = {
4061            let a = Tower4::<3>::variable(a0, 0);
4062            let t0 = Tower4::<3>::variable(th[0], 1);
4063            let t1 = Tower4::<3>::variable(th[1], 2);
4064            // F = a·(1 + θ₀) + θ₁·a² + θ₀·θ₁ − 0.4385. The constant is chosen so
4065            // F(a0, θ0) = 0 exactly at a0 = 0.4, θ = [0.25, −0.15]:
4066            //   0.4·1.25 + (−0.15)·0.16 + 0.25·(−0.15) = 0.4385.
4067            // implicit_solve requires a genuine root; at the root the level-set
4068            // and root-curve derivatives coincide, so the textbook-IFT
4069            // assertions below are unaffected.
4070            a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
4071        };
4072        let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
4073        let f_a = f.g[0];
4074        // First order: a_u = −F_u / F_a.
4075        for u in 0..2 {
4076            let want = -f.g[u + 1] / f_a;
4077            assert!(
4078                (a_t.g[u] - want).abs() < 1e-12,
4079                "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
4080                a_t.g[u],
4081                want
4082            );
4083        }
4084        // Second order textbook IFT (indices shifted by 1 for the a-slot).
4085        for u in 0..2 {
4086            for v in 0..2 {
4087                let f_uv = f.h[u + 1][v + 1];
4088                let f_au = f.h[0][u + 1];
4089                let f_av = f.h[0][v + 1];
4090                let f_aa = f.h[0][0];
4091                let want =
4092                    -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
4093                assert!(
4094                    (a_t.h[u][v] - want).abs() < 1e-12,
4095                    "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
4096                    a_t.h[u][v],
4097                    want
4098                );
4099            }
4100        }
4101    }
4102
4103    /// The moving-boundary flux tower reproduces every θ-derivative of a
4104    /// moving-limit integral, INCLUDING the second-order `B·z_uv` term the
4105    /// hand-written flux dropped (#932). The edge `z_R(θ) = θ₀ + θ₁²` has a
4106    /// genuinely nonzero `∂²z_R/∂θ₁² = 2`, so a combinator that omitted
4107    /// `B·z_uv` would miss the [1][1] Hessian entry. Truth = central FD of the
4108    /// closed-form integral `∫₀^{z_R} e^{−z²/2} dz = √(π/2)·erf(z_R/√2)`.
4109    #[test]
4110    fn moving_boundary_flux_carries_b_zuv_term() {
4111        use std::f64::consts::PI;
4112        let b = |z: f64| (-0.5 * z * z).exp(); // integrand B(z)
4113        // Antiderivative-based closed-form integral I(z_R) = ∫₀^{z_R} B dz.
4114        let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
4115        let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
4116        let th0 = [0.7_f64, 0.5_f64];
4117
4118        // Edge tower z_R(θ) over K=2 primaries: value + exact derivatives.
4119        let mut z_edge = Tower4::<2>::constant(z_r(th0));
4120        z_edge.g[0] = 1.0; // ∂z_R/∂θ₀ = 1
4121        z_edge.g[1] = 2.0 * th0[1]; // ∂z_R/∂θ₁ = 2θ₁
4122        z_edge.h[1][1] = 2.0; // ∂²z_R/∂θ₁² = 2  (the z_uv the old flux dropped)
4123
4124        // Integrand stack [B, B′, B″, B‴] at z₀: B′=−z·B, B″=(z²−1)·B,
4125        // B‴=(3z−z³)·B.
4126        let z0 = z_edge.v;
4127        let b0 = b(z0);
4128        let stack = [
4129            b0,
4130            -z0 * b0,
4131            (z0 * z0 - 1.0) * b0,
4132            (3.0 * z0 - z0 * z0 * z0) * b0,
4133        ];
4134        let flux = moving_limit_boundary_tower(&z_edge, stack);
4135
4136        // FD truth of the integral's derivatives.
4137        let h = 1e-4;
4138        let tol = 1e-6;
4139        for i in 0..2 {
4140            let mut up = th0;
4141            let mut dn = th0;
4142            up[i] += h;
4143            dn[i] -= h;
4144            let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
4145            assert!(
4146                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4147                "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
4148                flux.g[i],
4149                fd_g
4150            );
4151        }
4152        // The decisive entry: ∂²I/∂θ₁² = B′·(z_θ₁)² + B·z_θ₁θ₁. With z_θ₁=2θ₁=1
4153        // and z_θ₁θ₁=2, the B·z_uv contribution is B(z₀)·2 — omitting it would
4154        // leave the [1][1] entry short by exactly 2·B(z₀).
4155        let grad1_at = |th: [f64; 2]| -> f64 {
4156            let mut up = th;
4157            let mut dn = th;
4158            up[1] += h;
4159            dn[1] -= h;
4160            (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
4161        };
4162        let mut up = th0;
4163        let mut dn = th0;
4164        up[1] += h;
4165        dn[1] -= h;
4166        let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
4167        assert!(
4168            (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
4169            "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
4170            flux.h[1][1],
4171            fd_h11
4172        );
4173        // Explicit witness that the B·z_uv term is present and material:
4174        // analytic h[1][1] minus the pure (z_u)² part must equal B·z_uv = 2·B₀.
4175        let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
4176        let b_zuv = flux.h[1][1] - pure_zu2;
4177        assert!(
4178            (b_zuv - b0 * 2.0).abs() < 1e-10,
4179            "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
4180            b_zuv,
4181            b0 * 2.0
4182        );
4183    }
4184
4185    /// `moving_limit_boundary_tower_theta_integrand` reproduces the marginal-slope
4186    /// flex boundary closure for a θ-DEPENDENT integrand `G(z;θ)` — the case the
4187    /// plain `moving_limit_boundary_tower` cannot express, and the case the
4188    /// survival directional/bidirectional paths hand-assemble term-by-term
4189    /// (`G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`, with the directional path
4190    /// dropping `G·z_uv`). Two independent oracles:
4191    ///   (1) closed-form: the boundary flux of `∫ G dz` is exactly
4192    ///       `Φ(z_edge(θ);θ) − Φ(z₀;θ)` (Φ = z-antiderivative of G), whose θ
4193    ///       derivatives we take by central FD of the closed form — no jet code.
4194    ///   (2) the explicit second-order hand closure, including the `G·z_uv` term,
4195    ///       built from the integrand's own (z,θ) partials.
4196    /// G(z;θ) = exp(z·θ₀) is genuinely θ-dependent (G_θ₀ = z·e^{zθ₀} ≠ 0), and
4197    /// the edge z_edge = z₀ + θ₀ + θ₁² has a real z_uv = ∂²/∂θ₁² = 2, so a
4198    /// combinator that dropped either the integrand-θ terms or `G·z_uv` would
4199    /// miss a Hessian entry.
4200    #[test]
4201    fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
4202        // G(z;θ) = exp(z·θ₀);  Φ(z;θ) = ∫₀^z G = (e^{zθ₀} − 1)/θ₀.
4203        let g = |z: f64, t0: f64| (z * t0).exp();
4204        let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
4205        let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
4206        let th0 = [0.4_f64, 0.5_f64];
4207        let z0 = z_r(th0);
4208
4209        // Edge tower z_edge(θ) over K=2 primaries.
4210        let mut z_edge = Tower4::<2>::constant(z0);
4211        z_edge.g[0] = 1.0; // ∂z/∂θ₀
4212        z_edge.g[1] = 2.0 * th0[1]; // ∂z/∂θ₁
4213        z_edge.h[1][1] = 2.0; // ∂²z/∂θ₁² (the z_uv the directional path drops)
4214
4215        // Φ's mixed (z, θ) jet over K1 = 3 vars: slot 0 = z, slots 1,2 = θ₀,θ₁.
4216        // Built ONCE in tower arithmetic so every (z^i θ^j) partial is exact.
4217        let z_var = Tower4::<3>::variable(z0, 0);
4218        let t0_var = Tower4::<3>::variable(th0[0], 1);
4219        // θ₁ does not enter G/Φ here, but seed it so the jet carries the full
4220        // K1 frame (its Φ-derivatives are zero; the z_edge chain supplies all θ₁
4221        // motion through slot 0).
4222        let _t1_var = Tower4::<3>::variable(th0[1], 2);
4223        let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
4224        // Sanity: slot-0 first derivative of Φ IS G(z₀;θ₀).
4225        assert!(
4226            (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
4227            "Φ_z {:+.8e} != G {:+.8e}",
4228            phi_jet.g[0],
4229            g(z0, th0[0])
4230        );
4231
4232        let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
4233
4234        // Value channel is 0 by construction (boundary, not the integral itself).
4235        assert!(
4236            flux.v.abs() < 1e-12,
4237            "boundary value channel {:+.3e}",
4238            flux.v
4239        );
4240
4241        // Oracle (1): central FD of the closed-form boundary flux
4242        //   Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ)   (z₀ FROZEN at the base edge).
4243        let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
4244        let h = 1e-4;
4245        let tol = 1e-6;
4246        for i in 0..2 {
4247            let mut up = th0;
4248            let mut dn = th0;
4249            up[i] += h;
4250            dn[i] -= h;
4251            let fd_g = (bnd(up) - bnd(dn)) / (2.0 * h);
4252            assert!(
4253                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4254                "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
4255                flux.g[i],
4256                fd_g
4257            );
4258        }
4259        let grad_at = |th: [f64; 2], j: usize| -> f64 {
4260            let mut up = th;
4261            let mut dn = th;
4262            up[j] += h;
4263            dn[j] -= h;
4264            (bnd(up) - bnd(dn)) / (2.0 * h)
4265        };
4266        for i in 0..2 {
4267            for j in 0..2 {
4268                let mut up = th0;
4269                let mut dn = th0;
4270                up[i] += h;
4271                dn[i] -= h;
4272                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4273                assert!(
4274                    (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4275                    "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
4276                    flux.h[i][j],
4277                    fd_h
4278                );
4279            }
4280        }
4281
4282        // Oracle (2): the explicit second-order hand closure, term by term —
4283        // `G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`. Read G's partials at the
4284        // base point directly (no jet): G = e^{zθ₀}, G_z = θ₀·G, G_θ₀ = z·G,
4285        // G_θ₁ = 0.
4286        let gg = g(z0, th0[0]);
4287        let g_z = th0[0] * gg;
4288        let g_theta = [z0 * gg, 0.0]; // [G_θ₀, G_θ₁]
4289        for i in 0..2 {
4290            for j in 0..2 {
4291                let z_u = z_edge.g[i];
4292                let z_v = z_edge.g[j];
4293                let z_uv = z_edge.h[i][j];
4294                let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
4295                assert!(
4296                    (flux.h[i][j] - hand).abs() < 1e-9,
4297                    "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
4298                    flux.h[i][j],
4299                    hand
4300                );
4301            }
4302        }
4303
4304        // Decisive: the `G·z_uv` term the directional path DROPS is present and
4305        // material in the [1][1] entry (z_uv = 2 there).
4306        let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
4307        let g_zuv = flux.h[1][1] - pure_no_zuv;
4308        assert!(
4309            (g_zuv - gg * 2.0).abs() < 1e-9,
4310            "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
4311            g_zuv,
4312            gg * 2.0
4313        );
4314    }
4315
4316    /// The survival crossing-edge position tower `z_edge = (τ − a(θ)) / b`,
4317    /// `b = exp(g)`, built from the intercept tower `a(θ)` (here a stand-in)
4318    /// and the seeded slope `g`, reproduces taylor-jet's exact hand-path
4319    /// boundary-velocity formulas:
4320    ///   z_u   = −(a_u + [u==g]·z) / b
4321    ///   z_uv  = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
4322    /// This pins the bridge between `implicit_solve` and
4323    /// `cell_moving_boundary_flux_tower`: the boundary jet that the production
4324    /// flex path hand-codes (and dropped `z_uv` from) is exactly `∂²` of this
4325    /// tower. K=3 reduced frame: slot 0 = a-axis carrier (an arbitrary smooth
4326    /// a(θ) with nonzero a_u/a_uv), slot 1 = g (the log-slope), slot 2 unused.
4327    #[test]
4328    fn crossing_edge_tower_matches_handpath_velocity_formulas() {
4329        const TAU: f64 = 1.3; // the link-knot crossing threshold τ
4330        let g_idx = 1usize;
4331        let g0 = 0.85_f64; // the slope value b (the g-primary IS the slope)
4332        // Stand-in intercept tower a(θ): nonzero value, gradient, Hessian in the
4333        // two live axes so a_u and a_uv are both exercised. (In production this
4334        // comes from implicit_solve; here we plant known derivatives.)
4335        let mut a = Tower4::<3>::constant(0.45);
4336        a.g[0] = 0.7;
4337        a.g[1] = -0.3;
4338        a.h[0][0] = 0.25;
4339        a.h[0][1] = 0.11;
4340        a.h[1][0] = 0.11;
4341        a.h[1][1] = -0.08;
4342
4343        // In the survival flex frame the slope `b` IS the g-primary directly
4344        // (the directional code passes `g` as `b`, and ∂z/∂g uses ∂b/∂g = 1):
4345        // z_edge = (τ − a) / b with b seeded as the g-axis variable.
4346        let b = Tower4::<3>::variable(g0, g_idx);
4347        let z_edge = (Tower4::<3>::constant(TAU) - a) / b;
4348
4349        let bv = g0;
4350        let z0 = z_edge.v;
4351        assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);
4352
4353        // z_u = −(a_u + [u==g]·z) / b.
4354        for u in 0..2 {
4355            let direct = if u == g_idx { z0 } else { 0.0 };
4356            let want = -(a.g[u] + direct) / bv;
4357            assert!(
4358                (z_edge.g[u] - want).abs() < 1e-10,
4359                "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
4360                z_edge.g[u],
4361                want
4362            );
4363        }
4364        // z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b, using the tower's own
4365        // first-order z_v/z_u (already verified above).
4366        for u in 0..2 {
4367            for v in 0..2 {
4368                let cross = if u == g_idx { z_edge.g[v] } else { 0.0 }
4369                    + if v == g_idx { z_edge.g[u] } else { 0.0 };
4370                let want = -(a.h[u][v] + cross) / bv;
4371                assert!(
4372                    (z_edge.h[u][v] - want).abs() < 1e-10,
4373                    "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
4374                    z_edge.h[u][v],
4375                    want
4376                );
4377            }
4378        }
4379    }
4380
4381    /// The crossing-edge tower in the CONSTRAINT frame (intercept `a` and
4382    /// slope `b` BOTH independent — slots 0 and 1) reproduces taylor-jet's
4383    /// FD-certified bare boundary-velocity constants exactly:
4384    ///   z_a  = ∂z/∂a   = −1/b
4385    ///   z_ab = ∂²z/∂a∂b = +1/b²
4386    ///   z_aa = ∂²z/∂a²  = 0
4387    ///   z_bb = ∂²z/∂b²  = +2(τ−a)/b³
4388    /// These are the `f_a`/`f_au`/`f_aa` constraint-jet boundary motions the
4389    /// production base path drops (and only adds in the dir twins, causing the
4390    /// #932 desync). Here `a` is independent (NOT yet substituted with a(θ)),
4391    /// so `z_aa = 0` and there is no `a_uv` chain — `implicit_solve` introduces
4392    /// that later. Pins the constant before the constraint-tower wiring.
4393    #[test]
4394    fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
4395        const TAU: f64 = 1.3;
4396        let a0 = 0.45_f64;
4397        let b0 = 0.85_f64;
4398        // Slot 0 = a, slot 1 = b, both seeded independent.
4399        let a = Tower4::<2>::variable(a0, 0);
4400        let b = Tower4::<2>::variable(b0, 1);
4401        let z = (Tower4::<2>::constant(TAU) - a) / b;
4402
4403        assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
4404        assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
4405        assert!(
4406            (z.h[0][1] - 1.0 / (b0 * b0)).abs() < 1e-12,
4407            "z_ab {:+.10e} vs +1/b² {:+.10e}",
4408            z.h[0][1],
4409            1.0 / (b0 * b0)
4410        );
4411        assert!(
4412            z.h[0][0].abs() < 1e-12,
4413            "z_aa must vanish, got {:+.10e}",
4414            z.h[0][0]
4415        );
4416        let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);
4417        assert!(
4418            (z.h[1][1] - want_zbb).abs() < 1e-12,
4419            "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
4420            z.h[1][1],
4421            want_zbb
4422        );
4423    }
4424
4425    /// The oracle harness catches a planted #736-style sign flip in a
4426    /// cross block and reports the channel by name.
4427    #[test]
4428    fn oracle_catches_planted_cross_block_sign_flip() {
4429        let prog = LocScaleProgram {
4430            eta: vec![0.3],
4431            s: vec![-0.5],
4432            y: vec![1.0],
4433        };
4434        let t = evaluate_program(&prog, 0).expect("tower");
4435        let dir = [0.6, -0.2];
4436        let mut third = t.third_contracted(&dir);
4437        let honest = KernelChannels {
4438            value: t.v,
4439            gradient: t.g,
4440            hessian: t.h,
4441            third: vec![(dir, third)],
4442            fourth: vec![(dir, [1.0, 0.5], t.fourth_contracted(&dir, &[1.0, 0.5]))],
4443        };
4444        verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");
4445
4446        // Plant the #736 flip: negate one mixed cross entry.
4447        third[0][1] = -third[0][1];
4448        let flipped = KernelChannels {
4449            value: t.v,
4450            gradient: t.g,
4451            hessian: t.h,
4452            third: vec![(dir, third)],
4453            fourth: vec![],
4454        };
4455        let err = verify_kernel_channels(&t, &flipped, 1e-10)
4456            .expect_err("planted sign flip must be caught");
4457        assert!(
4458            err.contains("third[0][0][1]"),
4459            "oracle must name the flipped channel, got: {err}"
4460        );
4461    }
4462
4463    /// The third- and fourth-order tensors must be FULLY symmetric under
4464    /// index permutation (mixed partials commute). The tower stores them
4465    /// unsymmetrized, so equal-by-construction is a real invariant of the
4466    /// Leibniz/Faà di Bruno writes — a cheap typo tripwire. Asserted on a
4467    /// nontrivial K=3 tower with all of div/sqrt/powf/exp/ln exercised, so
4468    /// every composition path contributes. Lives in a test (not the hot
4469    /// per-op path) on purpose.
4470    #[test]
4471    fn t3_t4_are_fully_index_symmetric() {
4472        let prog = GnarlyProgram::fixture();
4473        // 3! = 6 permutations of three indices.
4474        let perms3: [[usize; 3]; 6] = [
4475            [0, 1, 2],
4476            [0, 2, 1],
4477            [1, 0, 2],
4478            [1, 2, 0],
4479            [2, 0, 1],
4480            [2, 1, 0],
4481        ];
4482        // 4! = 24 permutations of four indices.
4483        let perms4: [[usize; 4]; 24] = [
4484            [0, 1, 2, 3],
4485            [0, 1, 3, 2],
4486            [0, 2, 1, 3],
4487            [0, 2, 3, 1],
4488            [0, 3, 1, 2],
4489            [0, 3, 2, 1],
4490            [1, 0, 2, 3],
4491            [1, 0, 3, 2],
4492            [1, 2, 0, 3],
4493            [1, 2, 3, 0],
4494            [1, 3, 0, 2],
4495            [1, 3, 2, 0],
4496            [2, 0, 1, 3],
4497            [2, 0, 3, 1],
4498            [2, 1, 0, 3],
4499            [2, 1, 3, 0],
4500            [2, 3, 0, 1],
4501            [2, 3, 1, 0],
4502            [3, 0, 1, 2],
4503            [3, 0, 2, 1],
4504            [3, 1, 0, 2],
4505            [3, 1, 2, 0],
4506            [3, 2, 0, 1],
4507            [3, 2, 1, 0],
4508        ];
4509        for row in 0..prog.n_rows() {
4510            let t = evaluate_program(&prog, row).expect("gnarly tower");
4511            let scale_t3 =
4512                t.t3.iter()
4513                    .flatten()
4514                    .flatten()
4515                    .fold(0.0_f64, |m, x| m.max(x.abs()))
4516                    .max(1.0);
4517            let scale_t4 =
4518                t.t4.iter()
4519                    .flatten()
4520                    .flatten()
4521                    .flatten()
4522                    .fold(0.0_f64, |m, x| m.max(x.abs()))
4523                    .max(1.0);
4524            for i in 0..3 {
4525                for j in 0..3 {
4526                    for k in 0..3 {
4527                        let base = t.t3[i][j][k];
4528                        let idx = [i, j, k];
4529                        for p in &perms3 {
4530                            let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
4531                            assert!(
4532                                (base - permed).abs() <= 1e-12 * scale_t3,
4533                                "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
4534                                 permuted {permed:+.15e} under {p:?}"
4535                            );
4536                        }
4537                        for l in 0..3 {
4538                            let base4 = t.t4[i][j][k][l];
4539                            let idx4 = [i, j, k, l];
4540                            for p in &perms4 {
4541                                let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
4542                                assert!(
4543                                    (base4 - permed).abs() <= 1e-12 * scale_t4,
4544                                    "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
4545                                     permuted {permed:+.15e} under {p:?}"
4546                                );
4547                            }
4548                        }
4549                    }
4550                }
4551            }
4552        }
4553    }
4554}
4555
4556#[inline]
4557fn erfcx_nonnegative(x: f64) -> f64 {
4558    if !x.is_finite() {
4559        return if x.is_sign_positive() {
4560            0.0
4561        } else {
4562            f64::INFINITY
4563        };
4564    }
4565    if x <= 0.0 {
4566        return 1.0;
4567    }
4568    if x < 26.0 {
4569        ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
4570    } else {
4571        let inv = 1.0 / x;
4572        let inv2 = inv * inv;
4573        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
4574            + 6.5625 * inv2 * inv2 * inv2 * inv2;
4575        inv * poly / std::f64::consts::PI.sqrt()
4576    }
4577}
4578
4579#[inline]
4580fn log1mexp_positive(a: f64) -> f64 {
4581    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
4582    if a > core::f64::consts::LN_2 {
4583        (-(-a).exp()).ln_1p()
4584    } else if a > 0.0 {
4585        (-(-a).exp_m1()).ln()
4586    } else {
4587        f64::NEG_INFINITY
4588    }
4589}
4590
4591#[inline]
4592fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
4593    if x == f64::INFINITY {
4594        return (0.0, 0.0);
4595    }
4596    if x == f64::NEG_INFINITY {
4597        return (f64::NEG_INFINITY, f64::INFINITY);
4598    }
4599    if x.is_nan() {
4600        return (f64::NAN, f64::NAN);
4601    }
4602    if x < 0.0 {
4603        let u = -x / std::f64::consts::SQRT_2;
4604        let ex = erfcx_nonnegative(u).max(1e-300);
4605        let log_cdf = -u * u + (0.5 * ex).ln();
4606        let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
4607        (log_cdf, lambda)
4608    } else {
4609        let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
4610        let lambda = crate::probability::normal_pdf(x) / cdf;
4611        (cdf.ln(), lambda)
4612    }
4613}
4614
4615/// Stable derivative stack for `log Phi(x)` through fourth order.
4616#[inline]
4617pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
4618    let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
4619    let lambda2 = lambda * lambda;
4620    let lambda3 = lambda2 * lambda;
4621    let x2 = x * x;
4622    [
4623        log_cdf,
4624        lambda,
4625        -lambda * (x + lambda),
4626        lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
4627        -lambda
4628            * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
4629    ]
4630}
4631
4632/// Stable derivative stack for `log(1 - exp(-x))`, `x > 0`, through fourth order.
4633#[inline]
4634pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
4635    let r = 1.0 / x.exp_m1();
4636    [
4637        log1mexp_positive(x),
4638        r,
4639        -r * (1.0 + r),
4640        r * (1.0 + r) * (1.0 + 2.0 * r),
4641        -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
4642    ]
4643}
4644// ── The RowJet bridge oracle (CI) ─────────────────────────────────────
4645#[cfg(test)]
4646mod rowjet_bridge_tests {
4647    use super::*;
4648    use crate::jet_scalar::{JetScalar, Order2};
4649
4650    /// A toy row-NLL written ONCE over the [`RowJet`] bridge: a product, a sum, a
4651    /// subtraction, a scale/neg, a constant, and two value-distinct
4652    /// `compose_unary_with` stacks (an exp stack and a smooth finite-everywhere
4653    /// stack), plus a domain `guard`. The body is generic over `R: RowJet<2>`, so
4654    /// the SAME source instantiates at the scalar jets and the `f64x4` lane towers.
4655    struct ToyProgram {
4656        primaries: Vec<[f64; 2]>,
4657        /// Per-row CONTINUOUS auxiliary data `[cov, z, wi]` — the survival
4658        /// `covariance_ones` / `z_sum` / observation-weight analogues that enter
4659        /// the jet algebra as `.scale_rows(per_row_value)`, distinct per lane.
4660        aux: Vec<[f64; 3]>,
4661    }
4662
4663    impl ToyProgram {
4664        /// The body uses `pack_rows` to gather the per-lane continuous data from
4665        /// the lane→row map and `scale_rows` to fold it in — so a 4-row batch
4666        /// carries four DISTINCT cov/z/wi, which the single-`f64` `scale` could not.
4667        fn body<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> R {
4668            let cov = R::pack_rows(rows, |r| self.aux[r][0]);
4669            let z = R::pack_rows(rows, |r| self.aux[r][1]);
4670            let wi = R::pack_rows(rows, |r| self.aux[r][2]);
4671
4672            let a = p[0].mul(&p[1]).scale_rows(cov);
4673            let b = a.add(&R::constant(0.5)).sub(&p[0].scale(0.25));
4674            let c = b
4675                .compose_unary_with(|u| {
4676                    let e = u.exp();
4677                    [e, e, e, e, e]
4678                })
4679                .scale_rows(z);
4680            let d = c.neg().add(&p[0]);
4681            let e = d
4682                .compose_unary_with(|u| {
4683                    let s = (1.0 + u * u).sqrt();
4684                    let s3 = s * s * s;
4685                    let s5 = s3 * s * s;
4686                    let s7 = s5 * s * s;
4687                    [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
4688                })
4689                .scale_rows(wi);
4690            e.mul(&p[1]).add(&e)
4691        }
4692    }
4693
4694    impl RowNllProgramRowJet<2> for ToyProgram {
4695        fn n_rows(&self) -> usize {
4696            self.primaries.len()
4697        }
4698        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
4699            Ok(self.primaries[row])
4700        }
4701        fn row_nll<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> Result<R, String> {
4702            assert!(rows.len() == 1 || rows.len() == 4, "lane→row map is 1 or 4 wide");
4703            Ok(self.body(rows, p))
4704        }
4705    }
4706
4707    fn assert_t4_bits_eq(a: &Tower4<2>, b: &Tower4<2>, ctx: &str) {
4708        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4709        for i in 0..2 {
4710            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4711            for j in 0..2 {
4712                assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4713                for k in 0..2 {
4714                    assert_eq!(
4715                        a.t3[i][j][k].to_bits(),
4716                        b.t3[i][j][k].to_bits(),
4717                        "{ctx}: t3[{i}][{j}][{k}]"
4718                    );
4719                    for l in 0..2 {
4720                        assert_eq!(
4721                            a.t4[i][j][k][l].to_bits(),
4722                            b.t4[i][j][k][l].to_bits(),
4723                            "{ctx}: t4[{i}][{j}][{k}][{l}]"
4724                        );
4725                    }
4726                }
4727            }
4728        }
4729    }
4730
4731    fn assert_t3_bits_eq(a: &Tower3<2>, b: &Tower3<2>, ctx: &str) {
4732        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4733        for i in 0..2 {
4734            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4735            for j in 0..2 {
4736                assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4737                for k in 0..2 {
4738                    assert_eq!(
4739                        a.t3[i][j][k].to_bits(),
4740                        b.t3[i][j][k].to_bits(),
4741                        "{ctx}: t3[{i}][{j}][{k}]"
4742                    );
4743                }
4744            }
4745        }
4746    }
4747
4748    // Deterministic LCG with signed-zero injection and per-lane-distinct values.
4749    struct Lcg(u64);
4750    impl Lcg {
4751        fn next(&mut self) -> f64 {
4752            self.0 = self
4753                .0
4754                .wrapping_mul(6364136223846793005)
4755                .wrapping_add(1442695040888963407);
4756            (self.0 >> 11) as f64 / (1u64 << 53) as f64
4757        }
4758        fn val(&mut self) -> f64 {
4759            let u = self.next();
4760            if u < 0.04 {
4761                return 0.0;
4762            }
4763            if u < 0.08 {
4764                return -0.0;
4765            }
4766            (self.next() - 0.5) * 5.0
4767        }
4768    }
4769
4770    /// Lane `i` of the batched order-4 / order-3 tower is `to_bits`-identical to
4771    /// the scalar tower on row `i`, for ≥2000 distinct 4-row batches with
4772    /// signed-zero and per-lane-distinct primaries.
4773    #[test]
4774    fn batched_lane_i_matches_scalar_row_i_bit_identical() {
4775        let mut rng = Lcg(0xA5A5_1234_DEAD_BEEF);
4776        let mut batches = 0usize;
4777        for _ in 0..2500 {
4778            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4779            // per-lane-DISTINCT continuous aux (cov/z/wi), signed-zero injected.
4780            let aux: [[f64; 3]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4781            let prog = ToyProgram { primaries: bases.to_vec(), aux: aux.to_vec() };
4782            let rows = [0usize, 1, 2, 3];
4783
4784            // order-4 batch vs scalar Tower4 (instantiated through the same body).
4785            let batch4 = generic_batched_fourth_tower(&prog, rows).expect("batch4");
4786            for (row, base) in bases.iter().enumerate() {
4787                let vars: [Tower4<2>; 2] =
4788                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4789                let scal = prog.row_nll(&[row], &vars).expect("scalar tower4");
4790                assert_t4_bits_eq(&batch4.lane(row), &scal, "batched_fourth");
4791            }
4792
4793            // order-3 batch vs scalar Tower3.
4794            let batch3 = generic_batched_third_tower(&prog, rows).expect("batch3");
4795            for (row, base) in bases.iter().enumerate() {
4796                let vars: [Tower3<2>; 2] =
4797                    std::array::from_fn(|a| <Tower3<2> as RowJet<2>>::variable(base[a], a));
4798                let scal = prog.row_nll(&[row], &vars).expect("scalar tower3");
4799                assert_t3_bits_eq(&batch3.lane(row), &scal, "batched_third");
4800            }
4801            batches += 1;
4802        }
4803        assert_eq!(batches, 2500);
4804    }
4805
4806    /// The blanket impl does not churn the scalar path: the body driven through
4807    /// `RowJet` ops is `to_bits`-identical to the body driven directly through
4808    /// `JetScalar` ops, and `rowjet_row_kernel`'s `(v, g, H)` matches the dense
4809    /// `Tower4` lower channels.
4810    #[test]
4811    fn blanket_scalar_path_is_unchanged_and_consistent() {
4812        let mut rng = Lcg(0x0BAD_F00D_1357_2468);
4813        for _ in 0..3000 {
4814            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
4815            let aux0: [f64; 3] = std::array::from_fn(|_| rng.val());
4816            let prog = ToyProgram { primaries: vec![base], aux: vec![aux0] };
4817
4818            // (a) RowJet-driven body == JetScalar-driven body, bit-for-bit. The
4819            // reference body uses `scale(f64)` where the RowJet body uses
4820            // `scale_rows(f64)` — proving the scalar `scale_rows` rewrite does not
4821            // churn the path (`scale_rows(s) == scale(s)` on `Value = f64`).
4822            let via_rowjet: Tower4<2> = {
4823                let vars: [Tower4<2>; 2] =
4824                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4825                prog.row_nll(&[0], &vars).expect("rowjet")
4826            };
4827            let via_jetscalar: Tower4<2> = {
4828                let vars: [Tower4<2>; 2] = std::array::from_fn(|a| {
4829                    <Tower4<2> as JetScalar<2>>::variable(base[a], a)
4830                });
4831                let (cov, z, wi) = (aux0[0], aux0[1], aux0[2]);
4832                // The body using JetScalar's own ops + scale(f64) directly.
4833                let a = vars[0].mul(&vars[1]).scale(cov);
4834                let b = a.add(&Tower4::constant(0.5)).sub(&vars[0].scale(0.25));
4835                let c = b
4836                    .compose_unary_with(|u| {
4837                        let e = u.exp();
4838                        [e, e, e, e, e]
4839                    })
4840                    .scale(z);
4841                let d = JetScalar::neg(&c).add(&vars[0]);
4842                let e = d
4843                    .compose_unary_with(|u| {
4844                        let s = (1.0 + u * u).sqrt();
4845                        let s3 = s * s * s;
4846                        let s5 = s3 * s * s;
4847                        let s7 = s5 * s * s;
4848                        [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
4849                    })
4850                    .scale(wi);
4851                e.mul(&vars[1]).add(&e)
4852            };
4853            assert_t4_bits_eq(&via_rowjet, &via_jetscalar, "blanket_vs_direct");
4854
4855            // (b) rowjet_row_kernel (v,g,H) == dense Tower4 lower channels.
4856            // Order2 and Tower4 use different internal representations so
4857            // signed-zero differences (−0.0 vs +0.0) may arise in gradient/
4858            // Hessian channels that evaluate to exactly zero; IEEE equality
4859            // treats these as equal, so `==` is the right comparison here.
4860            let (v, g, h) = rowjet_row_kernel(&prog, 0).expect("kernel");
4861            assert_eq!(v.to_bits(), via_rowjet.v.to_bits(), "kernel v");
4862            for i in 0..2 {
4863                assert!(g[i] == via_rowjet.g[i], "kernel g[{i}]: {} vs {}", g[i], via_rowjet.g[i]);
4864                for j in 0..2 {
4865                    assert!(
4866                        h[i][j] == via_rowjet.h[i][j],
4867                        "kernel h[{i}][{j}]: {} vs {}",
4868                        h[i][j],
4869                        via_rowjet.h[i][j]
4870                    );
4871                }
4872            }
4873
4874            // (c) the Order2 scalar IS a RowJet via the blanket.
4875            let o2: [Order2<2>; 2] =
4876                std::array::from_fn(|a| <Order2<2> as RowJet<2>>::variable(base[a], a));
4877            let _ = prog.body(&[0], &o2);
4878        }
4879    }
4880
4881    /// On the scalar path (`Value = f64`) `scale_rows(s)` is `to_bits`-identical
4882    /// to `scale(s)` for EVERY channel — so rewriting a survival `.scale(per_row)`
4883    /// to `.scale_rows(per_row)` cannot perturb the existing scalar fits.
4884    #[test]
4885    fn scale_rows_scalar_is_bit_identical_to_scale() {
4886        let mut rng = Lcg(0xFEED_FACE_0042_1001);
4887        for _ in 0..3000 {
4888            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
4889            let s = rng.val();
4890            // Build a dense tower with populated channels (exp of a product).
4891            let vars: [Tower4<2>; 2] =
4892                std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4893            let jet = vars[0].mul(&vars[1]).compose_unary_with(|u| {
4894                let e = u.exp();
4895                [e, e, e, e, e]
4896            });
4897            let via_scale = RowJet::scale(&jet, s);
4898            let via_scale_rows = RowJet::scale_rows(&jet, s);
4899            assert_t4_bits_eq(&via_scale_rows, &via_scale, "scale_rows==scale");
4900        }
4901    }
4902
4903    /// `scale_rows` on a batch multiplies lane `i` by `s[i]`, so lane `i` of a
4904    /// per-lane-scaled batch matches the scalar `scale(s[i])` on row `i` — the
4905    /// continuous per-row data path the single-`f64` `scale` could not carry.
4906    #[test]
4907    fn batched_scale_rows_matches_per_row_scalar_scale() {
4908        let mut rng = Lcg(0x1357_9BDF_2468_ACE0);
4909        for _ in 0..2500 {
4910            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4911            let s: [f64; 4] = std::array::from_fn(|_| rng.val());
4912            let batch: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
4913                Tower4Batch::variable(
4914                    wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]),
4915                    a,
4916                )
4917            });
4918            let prod = batch[0].mul(&batch[1]).compose_unary_with(|u| {
4919                let e = u.exp();
4920                [e, e, e, e, e]
4921            });
4922            let scaled = prod.scale_rows(s);
4923            for (row, base) in bases.iter().enumerate() {
4924                let v: [Tower4<2>; 2] =
4925                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4926                let prod_s = v[0].mul(&v[1]).compose_unary_with(|u| {
4927                    let e = u.exp();
4928                    [e, e, e, e, e]
4929                });
4930                let ref_s = RowJet::scale(&prod_s, s[row]);
4931                assert_t4_bits_eq(&scaled.lane(row), &ref_s, "batched_scale_rows");
4932            }
4933        }
4934    }
4935
4936    /// The per-lane guard reports exactly the failing lanes on a batch and the
4937    /// single lane on a scalar jet.
4938    #[test]
4939    fn guard_reports_per_lane_failures() {
4940        let cols: [[f64; 2]; 4] = [[1.0, 0.5], [-2.0, 0.5], [3.0, 0.5], [-0.0, 0.5]];
4941        let vars: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
4942            Tower4Batch::variable(
4943                wide::f64x4::new([cols[0][a], cols[1][a], cols[2][a], cols[3][a]]),
4944                a,
4945            )
4946        });
4947        let verdict = vars[0].guard(|v| v > 0.0);
4948        assert_eq!(verdict.lanes(), 4);
4949        assert!(verdict.any_failed());
4950        assert!(!verdict.all_pass());
4951        assert!(!verdict.lane_failed(0));
4952        assert!(verdict.lane_failed(1));
4953        assert!(!verdict.lane_failed(2));
4954        assert!(verdict.lane_failed(3));
4955        assert_eq!(verdict.failed_mask(), 0b1010);
4956
4957        let s_ok = <Tower4<2> as RowJet<2>>::variable(1.0, 0);
4958        let s_bad = <Tower4<2> as RowJet<2>>::variable(-1.0, 0);
4959        assert!(RowJet::guard(&s_ok, |v| v > 0.0).all_pass());
4960        assert!(RowJet::guard(&s_bad, |v| v > 0.0).any_failed());
4961        assert_eq!(RowJet::guard(&s_ok, |v| v > 0.0).lanes(), 1);
4962    }
4963
4964    // ── ln_gamma_derivative_stack / digamma_derivative_stack / trigamma_derivative_stack ──
4965
4966    #[test]
4967    fn ln_gamma_derivative_stack_known_values_at_1() {
4968        let s = ln_gamma_derivative_stack(1.0);
4969        // ln Γ(1) = 0; statrs uses Lanczos so the result is within ULP noise
4970        assert!(s[0].abs() < 1e-14, "ln_gamma(1) must be ~0, got {}", s[0]);
4971        // ψ₀(1) = -γ  (Euler–Mascheroni)
4972        let euler_mascheroni = 0.577_215_664_901_532_9_f64;
4973        assert!(
4974            (s[1] + euler_mascheroni).abs() < 1e-10,
4975            "digamma(1) ≈ -{euler_mascheroni:.6}, got {}",
4976            s[1]
4977        );
4978        // ψ₁(1) = π²/6
4979        let pi2_6 = std::f64::consts::PI * std::f64::consts::PI / 6.0;
4980        assert!(
4981            (s[2] - pi2_6).abs() < 1e-10,
4982            "trigamma(1) ≈ {pi2_6:.6}, got {}",
4983            s[2]
4984        );
4985    }
4986
4987    #[test]
4988    fn ln_gamma_derivative_stack_known_values_at_2() {
4989        let s = ln_gamma_derivative_stack(2.0);
4990        // ln Γ(2) = ln(1) = 0 exactly
4991        assert!(s[0].abs() < 1e-14, "ln_gamma(2) must be 0, got {}", s[0]);
4992        // ψ₀(2) = 1 − γ (recurrence: ψ₀(x+1) = ψ₀(x) + 1/x)
4993        let euler_mascheroni = 0.577_215_664_901_532_9_f64;
4994        let digamma_2 = 1.0 - euler_mascheroni;
4995        assert!(
4996            (s[1] - digamma_2).abs() < 1e-10,
4997            "digamma(2) ≈ {digamma_2:.6}, got {}",
4998            s[1]
4999        );
5000    }
5001
5002    #[test]
5003    fn ln_gamma_derivative_stack_order2_is_prefix() {
5004        for &x in &[0.5_f64, 1.0, 2.0, 5.0] {
5005            let full = ln_gamma_derivative_stack(x);
5006            let ord2 = ln_gamma_derivative_stack_order2(x);
5007            assert_eq!(
5008                ord2[0], full[0],
5009                "order2[0] != full[0] at x={x}"
5010            );
5011            assert_eq!(
5012                ord2[1], full[1],
5013                "order2[1] != full[1] at x={x}"
5014            );
5015            assert_eq!(
5016                ord2[2], full[2],
5017                "order2[2] != full[2] at x={x}"
5018            );
5019        }
5020    }
5021
5022    #[test]
5023    fn digamma_derivative_stack_overlaps_ln_gamma_stack() {
5024        // The two stacks share a run of four polygamma values:
5025        // ln_gamma_stack[1..5] == digamma_stack[0..4]
5026        for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
5027            let lg = ln_gamma_derivative_stack(x);
5028            let dg = digamma_derivative_stack(x);
5029            for i in 0..4 {
5030                assert_eq!(
5031                    lg[i + 1], dg[i],
5032                    "ln_gamma_stack[{}] != digamma_stack[{}] at x={x}",
5033                    i + 1,
5034                    i
5035                );
5036            }
5037        }
5038    }
5039
5040    #[test]
5041    fn trigamma_derivative_stack_overlaps_digamma_stack() {
5042        // digamma_stack[1..5] == trigamma_stack[0..4]
5043        for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
5044            let dg = digamma_derivative_stack(x);
5045            let tg = trigamma_derivative_stack(x);
5046            for i in 0..4 {
5047                assert_eq!(
5048                    dg[i + 1], tg[i],
5049                    "digamma_stack[{}] != trigamma_stack[{}] at x={x}",
5050                    i + 1,
5051                    i
5052                );
5053            }
5054        }
5055    }
5056
5057    #[test]
5058    fn derivative_stacks_all_finite_at_positive_inputs() {
5059        for &x in &[0.01_f64, 0.5, 1.0, 2.0, 10.0, 100.0] {
5060            for v in ln_gamma_derivative_stack(x) {
5061                assert!(v.is_finite(), "ln_gamma_stack non-finite at x={x}: {v}");
5062            }
5063            for v in digamma_derivative_stack(x) {
5064                assert!(v.is_finite(), "digamma_stack non-finite at x={x}: {v}");
5065            }
5066            for v in trigamma_derivative_stack(x) {
5067                assert!(v.is_finite(), "trigamma_stack non-finite at x={x}: {v}");
5068            }
5069        }
5070    }
5071}
5072
5073// ── Contraction-symmetry optimization gate ────────────────────────────────────
5074//
5075// `Tower4::third_contracted` / `fourth_contracted` contract the (fully
5076// index-symmetric) `t3`/`t4` tensors against directions, leaving the output
5077// indices `(a, b)` / `(i, j)` free. Those free indices inherit the tensor's
5078// symmetry — `out[a][b] == out[b][a]` term-for-term — so only the upper triangle
5079// need be summed and the lower triangle mirrored. Unlike the dense symmetric
5080// FILL (which needs a K⁴ scatter and loses inner-loop vectorisation, and was
5081// measured SLOWER), the mirror here is a tiny K×K copy and the inner contraction
5082// is untouched (contiguous, vectorisable). This is BIT-IDENTICAL to the full
5083// nest, so it needs no fingerprint re-baseline; the gate is (1) bit-identity vs
5084// the full reference and (2) a measured wall-clock that is not slower.
5085#[cfg(test)]
5086mod contraction_symmetry_tests {
5087    use super::*;
5088
5089    struct Rng(u64);
5090    impl Rng {
5091        fn u(&mut self) -> f64 {
5092            self.0 = self
5093                .0
5094                .wrapping_mul(6364136223846793005)
5095                .wrapping_add(1442695040888963407);
5096            (self.0 >> 11) as f64 / (1u64 << 53) as f64
5097        }
5098        fn s(&mut self) -> f64 {
5099            (self.u() - 0.5) * 4.0
5100        }
5101    }
5102
5103    /// Random VALID fully-symmetric `Tower4<K>` (symmetric `h`/`t3`/`t4`).
5104    fn rand_sym4<const K: usize>(r: &mut Rng) -> Tower4<K> {
5105        let mut t = Tower4::<K>::zero();
5106        t.v = r.s();
5107        for i in 0..K {
5108            t.g[i] = r.s();
5109        }
5110        for a in 0..K {
5111            for b in a..K {
5112                let v2 = r.s();
5113                t.h[a][b] = v2;
5114                t.h[b][a] = v2;
5115                for c in b..K {
5116                    let v3 = r.s();
5117                    for p in perms3([a, b, c]) {
5118                        t.t3[p[0]][p[1]][p[2]] = v3;
5119                    }
5120                    for d in c..K {
5121                        let v4 = r.s();
5122                        for p in perms4([a, b, c, d]) {
5123                            t.t4[p[0]][p[1]][p[2]][p[3]] = v4;
5124                        }
5125                    }
5126                }
5127            }
5128        }
5129        t
5130    }
5131
5132    fn perms3(idx: [usize; 3]) -> [[usize; 3]; 6] {
5133        let [a, b, c] = idx;
5134        [[a, b, c], [a, c, b], [b, a, c], [b, c, a], [c, a, b], [c, b, a]]
5135    }
5136    fn perms4(idx: [usize; 4]) -> [[usize; 4]; 24] {
5137        let [a, b, c, d] = idx;
5138        [
5139            [a, b, c, d], [a, b, d, c], [a, c, b, d], [a, c, d, b], [a, d, b, c], [a, d, c, b],
5140            [b, a, c, d], [b, a, d, c], [b, c, a, d], [b, c, d, a], [b, d, a, c], [b, d, c, a],
5141            [c, a, b, d], [c, a, d, b], [c, b, a, d], [c, b, d, a], [c, d, a, b], [c, d, b, a],
5142            [d, a, b, c], [d, a, c, b], [d, b, a, c], [d, b, c, a], [d, c, a, b], [d, c, b, a],
5143        ]
5144    }
5145
5146    /// Full-nest reference (the pre-opt `a, b ∈ 0..K` form).
5147    fn third_full<const K: usize>(t: &Tower4<K>, dir: &[f64; K]) -> [[f64; K]; K] {
5148        let mut out = [[0.0; K]; K];
5149        for a in 0..K {
5150            for b in 0..K {
5151                let mut acc = 0.0;
5152                for c in 0..K {
5153                    acc += t.t3[a][b][c] * dir[c];
5154                }
5155                out[a][b] = acc;
5156            }
5157        }
5158        out
5159    }
5160    fn fourth_full<const K: usize>(t: &Tower4<K>, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
5161        let mut out = [[0.0; K]; K];
5162        for i in 0..K {
5163            for j in 0..K {
5164                let mut acc = 0.0;
5165                for k in 0..K {
5166                    for l in 0..K {
5167                        acc += t.t4[i][j][k][l] * u[k] * w[l];
5168                    }
5169                }
5170                out[i][j] = acc;
5171            }
5172        }
5173        out
5174    }
5175
5176    fn check_bit_identical<const K: usize>(seed: u64, n: usize) {
5177        let mut r = Rng(seed);
5178        for _ in 0..n {
5179            let t = rand_sym4::<K>(&mut r);
5180            let dir: [f64; K] = std::array::from_fn(|_| r.s());
5181            let u: [f64; K] = std::array::from_fn(|_| r.s());
5182            let w: [f64; K] = std::array::from_fn(|_| r.s());
5183            let t3_sym = t.third_contracted(&dir);
5184            let t3_full = third_full(&t, &dir);
5185            let t4_sym = t.fourth_contracted(&u, &w);
5186            let t4_full = fourth_full(&t, &u, &w);
5187            for a in 0..K {
5188                for b in 0..K {
5189                    assert_eq!(
5190                        t3_sym[a][b].to_bits(),
5191                        t3_full[a][b].to_bits(),
5192                        "third K={K} [{a}][{b}]"
5193                    );
5194                    assert_eq!(
5195                        t4_sym[a][b].to_bits(),
5196                        t4_full[a][b].to_bits(),
5197                        "fourth K={K} [{a}][{b}]"
5198                    );
5199                }
5200            }
5201        }
5202    }
5203
5204    /// The output-symmetric contraction is BIT-IDENTICAL to the full nest across
5205    /// `K ∈ {2,3,4,9}` (so no fingerprint re-baseline is owed — accuracy and bits
5206    /// are unchanged; this is a pure speed-only optimization).
5207    #[test]
5208    fn contraction_symmetry_is_bit_identical_to_full_nest() {
5209        check_bit_identical::<2>(0x0000_0002_C0FF_EE01, 1000);
5210        check_bit_identical::<3>(0x0000_0003_C0FF_EE01, 800);
5211        check_bit_identical::<4>(0x0000_0004_C0FF_EE01, 600);
5212        check_bit_identical::<9>(0x0000_0009_C0FF_EE01, 300);
5213    }
5214
5215    /// Measure the wall-clock of the output-symmetric contraction vs the full
5216    /// nest at `K = 9` (it does ~2× fewer inner contractions; the bit-identity
5217    /// test is the correctness gate). Informational — wall-clock is noisy — with
5218    /// only a PATHOLOGICAL-regression guard (the symmetric form does strictly
5219    /// fewer inner contractions, so it must not be materially slower).
5220    #[test]
5221    fn contraction_symmetry_speedup_is_reported() {
5222        const K: usize = 9;
5223        let mut r = Rng(0xC0FF_EE99_1234_5678);
5224        let towers: Vec<Tower4<K>> = (0..512).map(|_| rand_sym4::<K>(&mut r)).collect();
5225        let dir: [f64; K] = std::array::from_fn(|_| r.s());
5226        let u: [f64; K] = std::array::from_fn(|_| r.s());
5227        let w: [f64; K] = std::array::from_fn(|_| r.s());
5228
5229        let reps = 400usize;
5230        let t_sym = {
5231            let start = std::time::Instant::now();
5232            let mut sink = 0.0f64;
5233            for _ in 0..reps {
5234                for t in &towers {
5235                    let o3 = std::hint::black_box(t).third_contracted(std::hint::black_box(&dir));
5236                    let o4 = std::hint::black_box(t)
5237                        .fourth_contracted(std::hint::black_box(&u), std::hint::black_box(&w));
5238                    sink += o3[0][K - 1] + o4[0][K - 1];
5239                }
5240            }
5241            std::hint::black_box(sink);
5242            start.elapsed().as_secs_f64()
5243        };
5244        let t_full = {
5245            let start = std::time::Instant::now();
5246            let mut sink = 0.0f64;
5247            for _ in 0..reps {
5248                for t in &towers {
5249                    let o3 = third_full(std::hint::black_box(t), std::hint::black_box(&dir));
5250                    let o4 = fourth_full(
5251                        std::hint::black_box(t),
5252                        std::hint::black_box(&u),
5253                        std::hint::black_box(&w),
5254                    );
5255                    sink += o3[0][K - 1] + o4[0][K - 1];
5256                }
5257            }
5258            std::hint::black_box(sink);
5259            start.elapsed().as_secs_f64()
5260        };
5261        let calls = (reps * towers.len()) as f64;
5262        eprintln!(
5263            "[contraction-symmetry speedup K=9] sym={:.1}ns/call full={:.1}ns/call \
5264             wall_speedup={:.2}x",
5265            t_sym / calls * 1e9,
5266            t_full / calls * 1e9,
5267            t_full / t_sym
5268        );
5269        assert!(
5270            t_sym <= t_full * 1.5,
5271            "output-symmetric contraction pathologically slower: \
5272             sym={t_sym:.4}s full={t_full:.4}s"
5273        );
5274    }
5275}