// gam_math/jet_tower.rs

//! Taylor-jet tower algebra: write each family's row log-likelihood ONCE,
//! derive the entire `RowKernel<K>` derivative tower mechanically (#932).
//!
//! # The object
//!
//! [`Tower4<K>`] is a truncated multivariate Taylor scalar in `K` primary
//! variables, carrying the value and ALL partial derivatives through fourth
//! order as full (unsymmetrized) tensors:
//!
//! ```text
//!   v        ℓ
//!   g[a]     ∂ℓ/∂p_a
//!   h[a][b]  ∂²ℓ/∂p_a∂p_b
//!   t3[abc]  ∂³ℓ/∂p_a∂p_b∂p_c
//!   t4[abcd] ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d
//! ```
//!
//! Arithmetic (`+ − × ÷`, scalar mixes) propagates the tower by the exact
//! Leibniz rule; unary transcendentals propagate by the exact multivariate
//! Faà di Bruno formula given a `[f, f′, f″, f‴, f⁗]` stack evaluated at the
//! inner value. This is truncated Taylor ALGEBRA — exact derivatives of the
//! evaluated expression, not finite differences, not an approximation —
//! fully compatible with the exact-REML-only policy.
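//!
//! As a minimal univariate sketch of the two propagation rules above (not
//! this module's code: `Jet`, `jet_mul`, and `jet_compose` are hypothetical
//! names), a one-variable jet `[f, f′, f″, f‴, f⁗]` multiplies by the
//! binomial Leibniz rule and composes by the scalar Faà di Bruno formula.
//! The multivariate [`Tower4`] rules should reduce to these when every index
//! is the same variable:
//!
//! ```rust
//! // Univariate jet: [f, f', f'', f''', f''''] at one point.
//! type Jet = [f64; 5];
//!
//! // Leibniz product: (ab)^(n) = sum over k of C(n, k) * a^(k) * b^(n-k).
//! fn jet_mul(a: Jet, b: Jet) -> Jet {
//!     // Pascal's triangle rows 0..=4 supply the binomial coefficients.
//!     const BINOM: [[f64; 5]; 5] = [
//!         [1.0, 0.0, 0.0, 0.0, 0.0],
//!         [1.0, 1.0, 0.0, 0.0, 0.0],
//!         [1.0, 2.0, 1.0, 0.0, 0.0],
//!         [1.0, 3.0, 3.0, 1.0, 0.0],
//!         [1.0, 4.0, 6.0, 4.0, 1.0],
//!     ];
//!     let mut out = [0.0; 5];
//!     for n in 0..5 {
//!         for k in 0..=n {
//!             out[n] += BINOM[n][k] * a[k] * b[n - k];
//!         }
//!     }
//!     out
//! }
//!
//! // Faà di Bruno composition f∘u, where d = [f, f', f'', f''', f'''']
//! // is evaluated at the inner value u[0].
//! fn jet_compose(u: Jet, d: [f64; 5]) -> Jet {
//!     let [_, u1, u2, u3, u4] = u;
//!     [
//!         d[0],
//!         d[1] * u1,
//!         d[1] * u2 + d[2] * u1 * u1,
//!         d[1] * u3 + 3.0 * d[2] * u1 * u2 + d[3] * u1 * u1 * u1,
//!         d[1] * u4
//!             + d[2] * (4.0 * u1 * u3 + 3.0 * u2 * u2)
//!             + 6.0 * d[3] * u1 * u1 * u2
//!             + d[4] * u1 * u1 * u1 * u1,
//!     ]
//! }
//! ```
//!
//! Multiplying the jet of `x` at `x = 1.5`, `[1.5, 1, 0, 0, 0]`, by itself
//! gives `[2.25, 3, 2, 0, 0]`, the jet of `x²`. The order-4 coefficients
//! `4, 3, 6` count the set partitions of four indices by block shape
//! (`4 + 3 + 6 + 1 + 1 = 15`).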
//!
//! One evaluation of a row NLL program at seeded variables yields, in a
//! single pass, every channel the [`super::row_kernel::RowKernel`] trait
//! demands: `row_kernel` (value/∇/H), `row_third_contracted(dir)` (contract
//! `t3` with `dir`), and `row_fourth_contracted(u, v)` (contract `t4` with
//! `u` and `v`). The directional cross-channels that hand-written towers
//! drop (#736's residual gap) cannot be dropped here: there is no separate
//! "channel" to forget — every derivative of the one expression is carried.
//!
//! # Why this exists (the bug genus)
//!
//! Every family today hand-writes its tower: value in one function,
//! gradient in another, `pdfthird_derivative`/`pdffourth_derivative`,
//! entry/exit-specific cross blocks — thousands of lines of calculus that
//! drift. #736 was a sign flip in a hand-written cross-Hessian block,
//! invisible until a new consumer touched it; #948 is a derivative path
//! that is not the derivative of the evaluated row loss (clamped-μ
//! surrogate); the objective↔gradient desync class is the same disease at
//! the criterion level. A tower-derived kernel is exact-by-construction:
//! the value channel IS the production loss expression, so its derivative
//! channels cannot desync from it.
//!
//! # Relation to `jet_partitions::MultiDirJet`
//!
//! The tree already carries a *directional* jet (bitmask coefficients over
//! distinct seeded directions, heap-allocated, Bell-partition compose) used
//! inside the marginal-slope and latent-survival families. It answers "the
//! derivative along THESE specific directions" and must be re-seeded and
//! re-evaluated per direction tuple (e.g. 10 symmetric `(a,b)` pairs for a
//! K=4 fourth contraction). `Tower4` answers ALL of them from one
//! evaluation: contraction happens AFTER differentiation, as plain linear
//! algebra on the stored tensors. Use `MultiDirJet` when you need a handful
//! of directions of a huge-K expression; use `Tower4` when you need the
//! complete small-K tower — which is exactly the `RowKernel<K≤4>` shape.
//! The `[f64; 5]` unary-derivative stacks
//! (`unary_derivatives_neglog_phi`, …) are signature-compatible with
//! [`Tower4::compose_unary`], so the families' existing special-function
//! stacks are directly reusable.
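//!
//! A minimal sketch of "contraction after differentiation", using
//! hypothetical free functions rather than this module's methods: once the
//! full fourth-order tensor is stored, each direction pair costs only a
//! contraction, with no re-seeding and no re-evaluation of the program. The
//! toy function `f(p0, p1) = p0² · p1²` is an assumption chosen because its
//! only nonzero fourth partial is `∂⁴f/∂p0²∂p1² = 4`:
//!
//! ```rust
//! // Contract a stored fourth-order tensor with two directions:
//! // out[a][b] = sum over c, d of t4[a][b][c][d] * u[c] * w[d].
//! fn contract4<const K: usize>(
//!     t4: &[[[[f64; K]; K]; K]; K],
//!     u: &[f64; K],
//!     w: &[f64; K],
//! ) -> [[f64; K]; K] {
//!     let mut out = [[0.0; K]; K];
//!     for a in 0..K {
//!         for b in 0..K {
//!             for c in 0..K {
//!                 for d in 0..K {
//!                     out[a][b] += t4[a][b][c][d] * u[c] * w[d];
//!                 }
//!             }
//!         }
//!     }
//!     out
//! }
//!
//! // Full (unpacked) fourth-derivative tensor of f(p0, p1) = p0^2 * p1^2.
//! // An entry is nonzero exactly when its index tuple holds two 0s and
//! // two 1s, i.e. when the indices sum to 2.
//! fn t4_of_p0sq_p1sq() -> [[[[f64; 2]; 2]; 2]; 2] {
//!     let mut t4 = [[[[0.0; 2]; 2]; 2]; 2];
//!     for a in 0..2 {
//!         for b in 0..2 {
//!             for c in 0..2 {
//!                 for d in 0..2 {
//!                     if a + b + c + d == 2 {
//!                         t4[a][b][c][d] = 4.0;
//!                     }
//!                 }
//!             }
//!         }
//!     }
//!     t4
//! }
//! ```
//!
//! With the tensor built once, `contract4(&t4, &[1.0, 0.0], &[0.0, 1.0])` and
//! `contract4(&t4, &[0.0, 1.0], &[0.0, 1.0])` both read from the same stored
//! values; only the linear-algebra step is repeated.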
//!
//! # Stability discipline (why this is NOT autodiff)
//!
//! Differentiating the primal code path inherits its instabilities: a jet
//! pushed through a naive `ln(1 + e^η)` is garbage in the saturated tail
//! even though the true derivative σ(η) is benign there. This module
//! therefore splits responsibility: **humans own primitive stability,
//! the algebra owns combinatorics**. Tail-critical special functions enter
//! a program ONLY as hand-certified `[f64; 5]` derivative stacks through
//! [`Tower4::compose_unary`] — the same stacks the families already write
//! (`unary_derivatives_neglog_phi` and friends, built on erfcx/log_ndtr) —
//! and the tower mechanizes only the Leibniz/Faà di Bruno composition,
//! which is where hand-written towers actually fail (#736 was a
//! composition sign flip, not a primitive error). Program authors must use
//! a stable primitive stack wherever the f64 production loss does; the
//! convenience methods (`exp`, `ln`, `sqrt`, …) are for expressions whose
//! arguments are tame by construction.
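//!
//! A minimal sketch of what such a stack could look like for the softplus
//! example above. `softplus_stack` and `softplus_naive` are hypothetical
//! names, not stacks that exist in the families; the sketch assumes the
//! identities `f′ = σ`, `f″ = σ(1−σ)`, `f‴ = σ(1−σ)(1−2σ)`, and
//! `f⁗ = σ(1−σ)(1−6σ(1−σ))`, and evaluates them from `e^{−|η|}` so nothing
//! overflows in either tail:
//!
//! ```rust
//! // [f, f', f'', f''', f''''] for softplus f(η) = ln(1 + e^η).
//! fn softplus_stack(eta: f64) -> [f64; 5] {
//!     // e = exp(-|η|) lies in [0, 1] and cannot overflow.
//!     let e = (-eta.abs()).exp();
//!     // softplus(η) = max(η, 0) + ln(1 + e^{-|η|}).
//!     let f = eta.max(0.0) + e.ln_1p();
//!     // s = σ(η) and c = σ(-η) = 1 - σ(η), each formed without forming
//!     // the difference 1 - s explicitly.
//!     let (s, c) = if eta >= 0.0 {
//!         (1.0 / (1.0 + e), e / (1.0 + e))
//!     } else {
//!         (e / (1.0 + e), 1.0 / (1.0 + e))
//!     };
//!     let sc = s * c;
//!     [f, s, sc, sc * (c - s), sc * (1.0 - 6.0 * sc)]
//! }
//!
//! // The naive primal, for contrast: it overflows once e^η does.
//! fn softplus_naive(eta: f64) -> f64 {
//!     (1.0 + eta.exp()).ln()
//! }
//! ```
//!
//! At `η = 800` the naive primal returns `+∞`, while the stack keeps the
//! value at `800` and the first derivative at `1`. A stack of this shape is
//! what would be handed to [`Tower4::compose_unary`].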
//!
//! # Storage convention
//!
//! Tensors are stored FULL, not symmetric-packed: `t4` for K=4 is 256
//! doubles where 35 would do. This is deliberate clarity-over-speed for the
//! oracle role — indexing is trivially auditable, contraction loops are
//! obvious, and the redundancy is itself a checked invariant (the algebra
//! only ever writes symmetric values). Symmetric packing is a later,
//! profile-justified optimization behind the same API.
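//!
//! The two counts above can be reproduced with a short sketch. The helper
//! names `full_len` and `packed_len` are hypothetical and not part of this
//! module; `K ≥ 1` is assumed:
//!
//! ```rust
//! // Entries in a full (unpacked) order-m tensor over k variables: k^m.
//! fn full_len(k: u64, m: u32) -> u64 {
//!     k.pow(m)
//! }
//!
//! // Distinct entries of a fully symmetric order-m tensor over k variables:
//! // the multiset coefficient C(k + m - 1, m). Assumes k >= 1.
//! fn packed_len(k: u64, m: u32) -> u64 {
//!     // Multiplicative binomial: after step i the accumulator equals
//!     // C(k - 1 + i, i), so every integer division is exact.
//!     let mut acc = 1u64;
//!     for i in 1..=u64::from(m) {
//!         acc = acc * (k - 1 + i) / i;
//!     }
//!     acc
//! }
//! ```
//!
//! For `K = 4` at order 4 this gives `full_len(4, 4) = 256` against
//! `packed_len(4, 4) = 35`, the redundancy the full layout accepts.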
88//!
89//! # Deployment ladder (#932)
90//!
91//! 1. This module: the algebra + the program seam + the oracle.
92//! 2. Universal oracle: every hand-written `RowKernel` gains a CI test
93//!    asserting channel-by-channel agreement with a `RowNllProgram` written
94//!    once — see [`verify_kernel_channels`]. This alone would have caught
95//!    #736 at introduction.
96//! 3. Migrate error-dense / cold towers to [`derived_row_kernel`] et al.;
97//!    keep hand-tuned hot paths, now verified against the single-expression
98//!    truth instead of being the only definition.
99//! 4. New families (#914/#916/#917 ZI/ordinal/expectile, #921's location-
100//!    scale port) implement ONLY `RowNllProgram` and get an exact
101//!    fourth-order tower for the price of writing the likelihood.
102
103use crate::jet_algebra;
104
105/// Truncated fourth-order multivariate Taylor scalar in `K` variables.
106///
107/// See the module documentation for semantics and conventions. `Copy` is
108/// intentional despite the size (2 KiB at K=4): towers are per-row
109/// temporaries that live entirely in registers/stack during a row program,
110/// and value semantics keep program code readable (`a * b + c`).
#[derive(Clone, Copy, Debug)]
pub struct Tower4<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
    pub t3: [[[f64; K]; K]; K],
    /// Fourth derivatives ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d (fully symmetric).
    pub t4: [[[[f64; K]; K]; K]; K],
}

impl<const K: usize> Tower4<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
            t3: [[[0.0; K]; K]; K],
            t4: [[[[0.0; K]; K]; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the (fully symmetric) derivative tensor entry whose differentiation
    /// axes are `labels` (length 0..=4): value, `g`, `h`, `t3`, `t4`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 4,
            "Tower4 carries at most fourth-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            2 => self.h[labels[0]][labels[1]],
            3 => self.t3[labels[0]][labels[1]][labels[2]],
            _ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
        }
    }

    /// Exact truncated Leibniz product `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`.
    ///
    /// # Codegen
    ///
    /// Each output entry's `2^m` subset sum is written as a compact straight-line
    /// expression instead of the shared [`jet_algebra::leibniz_product`] subset
    /// walker (which, per entry, builds `SlotBuf`s and `match`-dispatches the
    /// `deriv` closure across all `2^m` subsets). The loop nest over `(i,j,k,l)`
    /// is unchanged — only the inner per-entry sum is unrolled — so this does NOT
    /// unroll over `K` and does NOT bloat code: on a `Tower4<9>` mul-and-read
    /// consumer the new form is faster AND smaller (asm: 34 outlined walker `bl`
    /// calls → 0, 21.1 KiB → 14.3 KiB, +100 NEON `.2d` ops).
    ///
    /// BIT-IDENTICAL to the walker: each entry's terms are in the walker's exact
    /// subset-enumeration order (subset bit `b` ↔ position `b`, `sub = 0..2^m`),
    /// and the per-entry `acc` accumulator mirrors the walker's `total = 0.0`
    /// start so a signed-zero leading product collapses to `+0.0` identically —
    /// which matters because real jets carry exact-`0.0` channels
    /// (`constant`/`variable` towers). Proven `to_bits`-identical on
    /// `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`, 5000 inputs each with ~30 %
    /// exact-`0.0` channels and signed values (a no-leading-`0.0` form fails this
    /// stress — the accumulator start is load-bearing).
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            // subsets of {i}: {} {i}
            let mut acc = 0.0;
            acc += a.v * b.g[i];
            acc += a.g[i] * b.v;
            out.g[i] = acc;
        }
        // Hessian is symmetric under i↔j; compute the upper triangle and mirror
        // (see [`Tower2::mul`] — same term order, enforces exact symmetry).
        for i in 0..K {
            for j in i..K {
                // subsets of {i,j}: {} {i} {j} {ij}
                let mut acc = 0.0;
                acc += a.v * b.h[i][j];
                acc += a.g[i] * b.g[j];
                acc += a.g[j] * b.g[i];
                acc += a.h[i][j] * b.v;
                out.h[i][j] = acc;
                out.h[j][i] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                    let mut acc = 0.0;
                    acc += a.v * b.t3[i][j][k];
                    acc += a.g[i] * b.h[j][k];
                    acc += a.g[j] * b.h[i][k];
                    acc += a.h[i][j] * b.g[k];
                    acc += a.g[k] * b.h[i][j];
                    acc += a.h[i][k] * b.g[j];
                    acc += a.h[j][k] * b.g[i];
                    acc += a.t3[i][j][k] * b.v;
                    out.t3[i][j][k] = acc;
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // subsets of {i,j,k,l} in bit order sub = 0..16
                        let mut acc = 0.0;
                        acc += a.v * b.t4[i][j][k][l];
                        acc += a.g[i] * b.t3[j][k][l];
                        acc += a.g[j] * b.t3[i][k][l];
                        acc += a.h[i][j] * b.h[k][l];
                        acc += a.g[k] * b.t3[i][j][l];
                        acc += a.h[i][k] * b.h[j][l];
                        acc += a.h[j][k] * b.h[i][l];
                        acc += a.t3[i][j][k] * b.g[l];
                        acc += a.g[l] * b.t3[i][j][k];
                        acc += a.h[i][l] * b.h[j][k];
                        acc += a.h[j][l] * b.h[i][k];
                        acc += a.t3[i][j][l] * b.g[k];
                        acc += a.h[k][l] * b.h[i][j];
                        acc += a.t3[i][k][l] * b.g[j];
                        acc += a.t3[j][k][l] * b.g[i];
                        acc += a.t4[i][j][k][l] * b.v;
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }

    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
    /// out of the borrowed operands.
    pub fn add(&self, o: &Self) -> Self {
        *self + *o
    }

    /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
    pub fn sub(&self, o: &Self) -> Self {
        *self + o.scale(-1.0)
    }

    /// Exact multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` evaluated at `u = self.v` —
    /// the SAME `[f64; 5]` stack shape the families' existing
    /// `unary_derivatives_*` helpers produce, so those special-function
    /// stacks (Φ, log-Φ, normal pdf, …) plug in directly.
    ///
    /// The order-m output sums over the set partitions of the m indices
    /// (Bell(3) = 5 terms at order 3, Bell(4) = 15 at order 4), grouped by
    /// block count: each partition into r blocks contributes
    /// `f⁽ʳ⁾ · Π_blocks D_block(u)`.
    ///
    /// # Codegen
    ///
    /// Evaluated as a compact closed form (the Bell(4)=15 set-partitions of
    /// `t4`, Bell(3)=5 of `t3`, …) instead of routing through the recursive
    /// [`jet_algebra::faa_di_bruno`] walker (per-output `for_each_partition`
    /// recursion + per-block `SlotBuf` + closure dispatch). The loop nest is
    /// identical to the walker's (`for i,j,k,l`); only the per-entry partition
    /// sum is straight-line, so this does NOT unroll over `K` and does NOT
    /// bloat code — measured on a `Tower4<9>` compose-and-read consumer the new
    /// form is both faster and SMALLER (asm: 94 outlined walker `bl` calls → 0,
    /// 47.5 KiB → 16.7 KiB, +197 NEON `.2d` ops).
    ///
    /// BIT-IDENTICAL to the walker: each channel's terms are emitted in the
    /// walker's exact partition-enumeration order, each term's block products
    /// are left-associated exactly as the walker's `prod *= block`, and the
    /// per-channel `acc` accumulator mirrors the walker's `total = 0.0` start
    /// (so signed-zero products collapse to `+0.0` identically). The order-4
    /// term sequence was generated from the walker's own enumeration. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`,
    /// 5000 random inputs each (zeroed / sign-varied stacks included).
    pub fn compose_unary(&self, d: [f64; 5]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
                    let mut acc = 0.0;
                    acc += d[1] * self.t3[i][j][k];
                    acc += d[2] * self.h[i][j] * self.g[k];
                    acc += d[2] * self.h[i][k] * self.g[j];
                    acc += d[2] * self.g[i] * self.h[j][k];
                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
                    out.t3[i][j][k] = acc;
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // Bell(4)=15 partitions, walker enumeration order.
                        let mut acc = 0.0;
                        acc += d[1] * self.t4[i][j][k][l];
                        acc += d[2] * self.t3[i][j][k] * self.g[l];
                        acc += d[2] * self.t3[i][j][l] * self.g[k];
                        acc += d[2] * self.h[i][j] * self.h[k][l];
                        acc += d[3] * self.h[i][j] * self.g[k] * self.g[l];
                        acc += d[2] * self.t3[i][k][l] * self.g[j];
                        acc += d[2] * self.h[i][k] * self.h[j][l];
                        acc += d[3] * self.h[i][k] * self.g[j] * self.g[l];
                        acc += d[2] * self.h[i][l] * self.h[j][k];
                        acc += d[2] * self.g[i] * self.t3[j][k][l];
                        acc += d[3] * self.g[i] * self.h[j][k] * self.g[l];
                        acc += d[3] * self.h[i][l] * self.g[j] * self.g[k];
                        acc += d[3] * self.g[i] * self.h[j][l] * self.g[k];
                        acc += d[3] * self.g[i] * self.g[j] * self.h[k][l];
                        acc += d[4] * self.g[i] * self.g[j] * self.g[k] * self.g[l];
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }

    /// Compose with a unary special-function whose `[f64; 5]` derivative STACK is
    /// built from the base value through `stack_fn` — the scalar arm of the
    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
    /// [`Tower4Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
    /// `self.compose_unary(stack_fn(self.v))`. Writing a program against this seam
    /// lets it re-instantiate, unchanged, at [`Tower4Lane`] (where each of the four
    /// lanes carries a DISTINCT base value and `stack_fn` is re-run per lane).
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
        self.compose_unary(stack_fn(self.v))
    }

    /// Single-active-slot fast path for [`Self::compose_unary`].
    ///
    /// When the inner jet `self` has derivative support ONLY on the all-`slot`
    /// diagonal channels — i.e. it is a univariate jet in primary `slot`
    /// scattered into the `K`-wide layout (`g[a] = 0`, `h[a][b] = 0`,
    /// `t3 = 0`, `t4 = 0` for any axis `≠ slot`) — the multivariate Faà di
    /// Bruno walk collapses. Every output channel whose axis tuple contains an
    /// axis `≠ slot` is structurally `0`: each set-partition has a block
    /// covering that axis, that block reads an off-`slot` derivative of `self`
    /// (which is `0`), so the block product and the whole partition vanish, and
    /// the channel sums to the walker's `total = 0.0` start, i.e. `+0.0`. Only
    /// the five diagonal channels (`v`, `g[slot]`, `h[slot][slot]`,
    /// `t3[slot]³`, `t4[slot]⁴`) survive.
    ///
    /// This computes exactly those five as STRAIGHT-LINE accumulations, each in
    /// the EXACT term order of [`Self::compose_unary`]'s diagonal
    /// (`i = j = k = l = slot`) case — so they are BIT-IDENTICAL to
    /// [`Self::compose_unary`] on the diagonal — and leaves every other channel
    /// at the zero-init `+0.0`, which the full walk also produces (the
    /// off-`slot` collapse is `to_bits`-`+0.0`, signed-zero products included;
    /// proven across `K ∈ {2,3,4,9}`, 5000 single-slot inputs each). At any
    /// `K ≥ 2` this is far fewer floating-point operations than materialising
    /// the full `1 + K + K² + K³ + K⁴` channel set whose off-diagonal entries
    /// are all zero, and far cheaper than the recursive set-partition walker the
    /// diagonal channels previously routed through (a measured ~9.5× speedup vs
    /// the full `compose_unary`, recovering a 5.9× walker regression at the
    /// `K ∈ {2,3}` BMS tower widths).
    ///
    /// `#[inline]` so an adopting consumer pays no `bl` call (uninlined, the
    /// five-channel build does not amortise the call/spill overhead).
    ///
    /// # Precondition
    ///
    /// The caller guarantees the single-active-slot structure. If it does not
    /// hold, the off-`slot` channels would be wrongly zeroed; use the full
    /// [`Self::compose_unary`] in that case.
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [f64; 5], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        let t4 = self.t4[s][s][s][s];
        out.v = d[0];
        // g (i=s): d1*g
        out.g[s] = {
            let mut acc = 0.0;
            acc += d[1] * g;
            acc
        };
        // h (i=j=s): d1*h + d2*g*g
        out.h[s][s] = {
            let mut acc = 0.0;
            acc += d[1] * h;
            acc += d[2] * g * g;
            acc
        };
        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
        out.t3[s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t3;
            acc += d[2] * h * g;
            acc += d[2] * h * g;
            acc += d[2] * g * h;
            acc += d[3] * g * g * g;
            acc
        };
        // t4 (i=j=k=l=s): exact term order of compose_unary's inner loop.
        out.t4[s][s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t4;
            acc += d[2] * t3 * g;
            acc += d[2] * t3 * g;
            acc += d[2] * h * h;
            acc += d[3] * h * g * g;
            acc += d[2] * t3 * g;
            acc += d[2] * h * h;
            acc += d[3] * h * g * g;
            acc += d[2] * h * h;
            acc += d[2] * g * t3;
            acc += d[3] * g * h * g;
            acc += d[3] * h * g * g;
            acc += d[3] * g * h * g;
            acc += d[3] * g * g * h;
            acc += d[4] * g * g * g * g;
            acc
        };
        out
    }

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
                for k in 0..K {
                    out.t3[i][j][k] *= s;
                    for l in 0..K {
                        out.t4[i][j][k][l] *= s;
                    }
                }
            }
        }
        out
    }

    /// e^self.
    pub fn exp(&self) -> Self {
        let e = self.v.exp();
        self.compose_unary([e, e, e, e, e])
    }

    /// ln(self). Caller guarantees positivity (likelihood programs do).
    pub fn ln(&self) -> Self {
        let u = self.v;
        let r = 1.0 / u;
        self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
    }

    /// 1/self.
    pub fn recip(&self) -> Self {
        let r = 1.0 / self.v;
        let r2 = r * r;
        self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
    }

    /// √self. Caller guarantees positivity.
    pub fn sqrt(&self) -> Self {
        let u = self.v;
        let s = u.sqrt();
        self.compose_unary([
            s,
            0.5 / s,
            -0.25 / (u * s),
            0.375 / (u * u * s),
            -0.9375 / (u * u * u * s),
        ])
    }

    /// self^a for real exponent `a`. Caller guarantees a positive base.
    pub fn powf(&self, a: f64) -> Self {
        let u = self.v;
        let f0 = u.powf(a);
        let f1 = a * u.powf(a - 1.0);
        let f2 = a * (a - 1.0) * u.powf(a - 2.0);
        let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
        let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
        self.compose_unary([f0, f1, f2, f3, f4])
    }

    /// ln Γ(self). Caller guarantees positivity.
    pub fn ln_gamma(&self) -> Self {
        self.compose_unary(ln_gamma_derivative_stack(self.v))
    }

    /// ψ(self), the digamma function. Caller guarantees positivity.
    pub fn digamma(&self) -> Self {
        self.compose_unary(digamma_derivative_stack(self.v))
    }

    /// ψ′(self), the trigamma function. Caller guarantees positivity.
    pub fn trigamma(&self) -> Self {
        self.compose_unary(trigamma_derivative_stack(self.v))
    }

    /// Contract `t3` with one primary-space direction:
    /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]` — exactly the
    /// `row_third_contracted` shape.
    ///
    /// The output is symmetric in `(a, b)`: `t3` is fully index-symmetric, so
    /// `t3[a][b][c] == t3[b][a][c]` and the `Σ_c` contraction gives
    /// `out[a][b] == out[b][a]` term-for-term, in the same `c` order. We compute
    /// only the upper triangle `a ≤ b` (the inner contraction is unchanged and
    /// stays contiguous/vectorisable) and mirror into the lower triangle — this
    /// is BIT-IDENTICAL to the full `a, b ∈ 0..K` nest while doing ~2× fewer
    /// inner contractions, with no dense scatter (the mirror is a `K × K` copy).
    pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in a..K {
                let mut acc = 0.0;
                for c in 0..K {
                    acc += self.t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
                out[b][a] = acc;
            }
        }
        out
    }

    /// Contract `t4` with two primary-space directions `u` and `w`:
    /// `out[i][j] = Σ_{k,l} t4[i][j][k][l] · u[k] · w[l]`, exactly the
    /// `row_fourth_contracted` shape.
    ///
    /// As in [`Self::third_contracted`], the output is symmetric in `(i, j)`
    /// (`t4[j][i][k][l] == t4[i][j][k][l]`, contracted in the same `(k, l)`
    /// order), so the upper triangle `i ≤ j` is computed and mirrored —
    /// BIT-IDENTICAL to the full nest, ~2× fewer inner `Σ_{k,l}` contractions,
    /// and the inner double loop stays the original contiguous/vectorisable form.
    pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for i in 0..K {
            for j in i..K {
                let mut acc = 0.0;
                for k in 0..K {
                    for l in 0..K {
                        acc += self.t4[i][j][k][l] * u[k] * w[l];
                    }
                }
                out[i][j] = acc;
                out[j][i] = acc;
            }
        }
        out
    }
}

impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut out = Self::zero();
        out.v = f(&[]);
        for i in 0..K {
            let labels = [i];
            out.g[i] = f(&labels);
        }
        for i in 0..K {
            for j in 0..K {
                let labels = [i, j];
                out.h[i][j] = f(&labels);
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let labels = [i, j, k];
                    out.t3[i][j][k] = f(&labels);
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        let labels = [i, j, k, l];
                        out.t4[i][j][k][l] = f(&labels);
                    }
                }
            }
        }
        out
    }
}

/// Truncated SECOND-order multivariate Taylor scalar in `K` variables.
///
/// This is the value/gradient/Hessian-only sibling of [`Tower4`]. Every
/// channel it carries (`v`, `g`, `h`) is computed by the SAME formulas
/// [`Tower4`] uses for those orders, so for any program written over both
/// towers the order-≤2 outputs are *bit-identical*: the order-2 Leibniz and
/// Faà-di-Bruno terms read only the order-≤2 channels of their inputs (see
/// [`Tower4::mul`] / [`Tower4::compose_unary`] — `out.h` never touches `t3`
/// or `t4`), so dropping the third/fourth tensors cannot perturb the value,
/// gradient, or Hessian.
///
/// It exists purely for performance: an inner Newton step (and the
/// value-only ρ-homotopy pre-warm) needs at most curvature, never the
/// outer-κ/ψ third/fourth derivatives. Evaluating a row likelihood over
/// `Tower2` skips the `K⁴` fourth-tensor product/composition arithmetic that
/// dominates the cold marginal-slope fit, while returning the exact same
/// `(v, g, h)`.
#[derive(Clone, Copy, Debug)]
pub struct Tower2<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
}

impl<const K: usize> Tower2<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the derivative tensor entry whose differentiation axes are
    /// `labels` (length 0..=2): value, `g`, `h`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 2,
            "Tower2 carries at most second-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            _ => self.h[labels[0]][labels[1]],
        }
    }

    /// Exact truncated (order ≤ 2) Leibniz product. The `v`/`g`/`h` upper
    /// triangle matches [`Tower4::mul`] term-for-term.
    ///
    /// # Symmetry fast path
    ///
    /// The order-≤2 Leibniz Hessian
    /// `h[i][j] = a.v·b.h[i][j] + a.g[i]·b.g[j] + a.g[j]·b.g[i] + a.h[i][j]·b.v`
    /// is symmetric under `i ↔ j` whenever the operand Hessians are — which they
    /// always are: `constant`/`variable` seed a symmetric (zero) `h`, and
    /// `mul`/`compose_unary`/`add`/`scale` each preserve symmetry, so the
    /// invariant holds for every tower a row program can build. We therefore
    /// compute only the upper triangle `j ≥ i` and mirror it into the lower
    /// triangle. At the `K = 9` survival width that is `K(K+1)/2 = 45` four-product
    /// entry evaluations instead of `K² = 81`, and the win is larger in wall-clock
    /// because the 81-entry (`648`-byte) `h` spills at `K = 9` — halving the
    /// expensive stores/reloads roughly halves the kernel (measured ≈2× on a
    /// `Tower2<9>` mul-and-read throughput microbench; the dominant `mul` under
    /// every packed scalar bottoms out here).
    ///
    /// The upper-triangle entries are BIT-IDENTICAL to the old rectangular form
    /// (same term/accumulation order). The lower triangle now equals its mirror
    /// exactly, where the rectangular form rounded `h[i][j]` and `h[j][i]`
    /// independently (the two cross products accumulate in opposite order) and
    /// left a ≤1-ulp asymmetry; mirroring removes it, so the result is exactly
    /// symmetric — strictly closer to the true symmetric Hessian, not merely a
    /// reordering. Dense-`h` consumers are all tolerance-gated (rel-tol ≥ 1e-11 ≫
    /// 1e-16); the `f64`/`f64x4` lane oracle stays exact because
    /// [`crate::jet_scalar::Order2Lane::mul`] mirrors term-for-term.
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
        }
        for i in 0..K {
            for j in i..K {
                let hij = a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
                out.h[i][j] = hij;
                out.h[j][i] = hij;
            }
        }
        out
    }
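/* Illustrative sketch, not part of this module: the upper-triangle-and-mirror
Hessian product documented on `mul` above, reduced to a hypothetical fixed-width
`Jet2` (two variables) so that it compiles on its own. The `Jet2` type and the
`x_squared_y` helper are assumptions made for the example; only the four-term
Hessian formula is taken from the code above.

```rust
// Hypothetical stand-in for an order-2 tower with two variables.
#[derive(Clone, Copy, Debug)]
struct Jet2 {
    v: f64,
    g: [f64; 2],
    h: [[f64; 2]; 2],
}

impl Jet2 {
    // Seed variable `idx`: unit gradient in that slot, zero Hessian.
    fn variable(value: f64, idx: usize) -> Self {
        let mut g = [0.0; 2];
        g[idx] = 1.0;
        Jet2 { v: value, g, h: [[0.0; 2]; 2] }
    }

    // Order-2 Leibniz product: compute the upper triangle, mirror the rest.
    fn mul(&self, o: &Self) -> Self {
        let (a, b) = (self, o);
        let mut out = Jet2 { v: a.v * b.v, g: [0.0; 2], h: [[0.0; 2]; 2] };
        for i in 0..2 {
            out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
        }
        for i in 0..2 {
            for j in i..2 {
                let hij = a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
                out.h[i][j] = hij;
                out.h[j][i] = hij;
            }
        }
        out
    }
}

// l(x, y) = (x * y) * x = x^2 * y, evaluated with both variables seeded.
fn x_squared_y(x: f64, y: f64) -> Jet2 {
    let xs = Jet2::variable(x, 0);
    let ys = Jet2::variable(y, 1);
    xs.mul(&ys).mul(&xs)
}
```

Seeding x = 3, y = 2 gives value 18, gradient [12, 9] and Hessian
[[4, 6], [6, 0]], which are the derivatives of x²y at that point. */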

    /// Exact (order ≤ 2) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u)]` evaluated at `u = self.v`. The `v`/`g`/`h`
    /// channels match [`Tower4::compose_unary`] term-for-term (which uses only
    /// `d[0..=2]` for those orders), so this is a strict truncation, not an
    /// approximation. The full-order `[f64; 5]` derivative stacks the families
    /// already produce can be passed by slicing their first three entries.
    ///
    /// # Codegen
    ///
    /// Order-≤2 Faà di Bruno is a tiny closed form, so this evaluates it
    /// directly instead of routing through the generic
    /// [`jet_algebra::faa_di_bruno`] set-partition walker (recursion + per-block
    /// closure dispatch). That matters because this is the kernel under EVERY
    /// packed scalar — [`crate::jet_scalar::Order2`] / `OneSeed` / `TwoSeed`
    /// composition all bottom out here — so the straight-line form (whose inner
    /// loops auto-vectorise to NEON/SSE 2-wide and which emits zero outlined
    /// walker calls) lifts all of them at once.
    ///
    /// The term and accumulation order is BIT-IDENTICAL to the walker it
    /// replaces: each output channel mirrors the walker's `total = 0.0` start
    /// (the explicit `acc` accumulator), so a signed-zero product collapses to
    /// `+0.0` exactly as `total += prod` does. Proven `to_bits`-identical on
    /// `v`/`g`/`h` across `K ∈ {2,3,4,9}`, 5000 random inputs each (incl.
    /// zeroed / sign-varied stacks). The order-≤2 walker partitions are:
    ///   `g[i]`   = `f′·u_i`                   (single block `{i}`)
    ///   `h[i][j]` = `f′·u_ij + (f″·u_i)·u_j`  (blocks `{ij}` then `{i}{j}`),
    /// with `f′ = d[1]`, `f″ = d[2]`, `u_* = self.{g,h}`.
    pub fn compose_unary(&self, d: [f64; 3]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        out
    }
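/* Illustrative sketch, not part of this module: the two order-≤2 Faà di Bruno
formulas listed in the doc comment above, applied to exp(2x + 3y). The tuple
alias `Jet2` and both function names are hypothetical and exist only so the
example is self-contained; it makes no attempt to reproduce the bit-exact
accumulation order that the real kernel guarantees.

```rust
// Hypothetical order-2 jet in two variables: (value, gradient, Hessian).
type Jet2 = (f64, [f64; 2], [[f64; 2]; 2]);

// Compose with a unary f, given d = [f(u), f'(u), f''(u)] at the inner value u:
//   g[i]    = f' * u_i
//   h[i][j] = f' * u_ij + f'' * u_i * u_j
fn compose_unary(inner: &Jet2, d: [f64; 3]) -> Jet2 {
    let (_, g, h) = inner;
    let mut og = [0.0; 2];
    let mut oh = [[0.0; 2]; 2];
    for i in 0..2 {
        og[i] = d[1] * g[i];
        for j in 0..2 {
            oh[i][j] = d[1] * h[i][j] + d[2] * g[i] * g[j];
        }
    }
    (d[0], og, oh)
}

// exp(2x + 3y): the inner jet is linear, so its Hessian is zero.
fn exp_of_linear(x: f64, y: f64) -> Jet2 {
    let u = 2.0 * x + 3.0 * y;
    let inner: Jet2 = (u, [2.0, 3.0], [[0.0; 2]; 2]);
    let e = u.exp();
    compose_unary(&inner, [e, e, e])
}
```

At x = y = 0 the exponential and all its derivatives equal 1, so the result is
value 1, gradient [2, 3] and Hessian [[4, 6], [6, 9]] (the outer product of the
inner gradient with itself). */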

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
            }
        }
        out
    }

    /// e^self.
    pub fn exp(&self) -> Self {
        let e = self.v.exp();
        self.compose_unary([e, e, e])
    }

    /// √self. Caller guarantees positivity.
    pub fn sqrt(&self) -> Self {
        let u = self.v;
        let s = u.sqrt();
        self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
    }
}

impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut out = Self::zero();
        out.v = f(&[]);
        for i in 0..K {
            let labels = [i];
            out.g[i] = f(&labels);
        }
        for i in 0..K {
            for j in 0..K {
                let labels = [i, j];
                out.h[i][j] = f(&labels);
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Add for Tower2<K> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        let mut out = self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Mul for Tower2<K> {
    type Output = Self;
    fn mul(self, o: Self) -> Self {
        Tower2::mul(&self, &o)
    }
}

impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
    type Output = Self;
    fn add(self, c: f64) -> Self {
        let mut out = self;
        out.v += c;
        out
    }
}

impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
    type Output = Self;
    fn mul(self, c: f64) -> Self {
        self.scale(c)
    }
}

/// Truncated THIRD-order multivariate Taylor scalar in `K` variables.
///
/// The value/gradient/Hessian/third-derivative sibling of [`Tower4`], standing
/// between [`Tower2`] and [`Tower4`]. Every channel it carries (`v`, `g`, `h`,
/// `t3`) is computed by the SAME shared Leibniz / Faà-di-Bruno kernels
/// [`Tower4`] uses for those orders, and the order-≤3 terms of those kernels
/// read only the order-≤3 channels of their inputs (the order-3 Faà-di-Bruno
/// partitions never reach the f⁗ stack slot or the inner `t4` tensor — see
/// [`Tower4::compose_unary`]). So for any program written over both towers the
/// order-≤3 outputs are *bit-identical*: dropping the fourth tensor cannot
/// perturb the value, gradient, Hessian, or third derivatives.
///
/// It exists purely for performance, exactly like [`Tower2`]: a consumer that
/// needs up to third derivatives (the survival location-scale row kernel reads
/// `g`, the diagonal `h`, and the diagonal `t3`, but never `t4`) pays the
/// `K³` third-tensor arithmetic but skips the `K⁴` fourth-tensor
/// product/composition that otherwise dominates the per-row cost.
#[derive(Clone, Copy, Debug)]
pub struct Tower3<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
    pub t3: [[[f64; K]; K]; K],
}

impl<const K: usize> Tower3<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
            t3: [[[0.0; K]; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the (fully symmetric) derivative tensor entry whose differentiation
    /// axes are `labels` (length 0..=3): value, `g`, `h`, `t3`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 3,
            "Tower3 carries at most third-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            2 => self.h[labels[0]][labels[1]],
            _ => self.t3[labels[0]][labels[1]][labels[2]],
        }
    }

    /// Exact truncated (order ≤ 3) Leibniz product. The `v`/`g`/`h`/`t3`
    /// channels match [`Tower4::mul`] term-for-term.
    ///
    /// # Codegen
    ///
    /// Straight-line per-entry subset sums instead of the
    /// [`jet_algebra::leibniz_product`] walker — the order-≤3 sibling of
    /// [`Tower4::mul`] (no `t4`). Loop nest unchanged, no unroll over `K`, no
    /// code bloat; auto-vectorises. BIT-IDENTICAL: terms in the walker's exact
    /// subset order with an `acc = 0.0` accumulator start (load-bearing for the
    /// signed-zero leading product on exact-`0.0` jet channels). Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// zero/sign-stressed inputs each (these channel formulas are exactly the
    /// `g`/`h`/`t3` of the [`Tower4::mul`] oracle, which passes that stress).
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            let mut acc = 0.0;
            acc += a.v * b.g[i];
            acc += a.g[i] * b.v;
            out.g[i] = acc;
        }
        // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
        for i in 0..K {
            for j in i..K {
                let mut acc = 0.0;
                acc += a.v * b.h[i][j];
                acc += a.g[i] * b.g[j];
                acc += a.g[j] * b.g[i];
                acc += a.h[i][j] * b.v;
                out.h[i][j] = acc;
                out.h[j][i] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                    let mut acc = 0.0;
                    acc += a.v * b.t3[i][j][k];
                    acc += a.g[i] * b.h[j][k];
                    acc += a.g[j] * b.h[i][k];
                    acc += a.h[i][j] * b.g[k];
                    acc += a.g[k] * b.h[i][j];
                    acc += a.h[i][k] * b.g[j];
                    acc += a.h[j][k] * b.g[i];
                    acc += a.t3[i][j][k] * b.v;
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }
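/* Illustrative sketch, not part of this module: with a single variable the
eight-subset sum in the `t3` loop above collapses to the binomial Leibniz rule,
because the three one-element subsets (and the three two-element subsets)
contribute identical terms. The alias `Jet3` and the functions `mul3` and `cube`
are hypothetical names introduced for this example only.

```rust
// Hypothetical univariate order-3 jet: [f, f', f'', f'''].
type Jet3 = [f64; 4];

// Leibniz product through order 3 with binomial coefficients 1-1, 1-2-1, 1-3-3-1.
fn mul3(a: &Jet3, b: &Jet3) -> Jet3 {
    [
        a[0] * b[0],
        a[0] * b[1] + a[1] * b[0],
        a[0] * b[2] + 2.0 * a[1] * b[1] + a[2] * b[0],
        a[0] * b[3] + 3.0 * a[1] * b[2] + 3.0 * a[2] * b[1] + a[3] * b[0],
    ]
}

// x^3 built from two products of the seeded variable [x, 1, 0, 0].
fn cube(x: f64) -> Jet3 {
    let xs: Jet3 = [x, 1.0, 0.0, 0.0];
    mul3(&mul3(&xs, &xs), &xs)
}
```

At x = 2 this yields [8, 12, 12, 6], matching x³, 3x², 6x and 6. */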

    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
    /// out of the borrowed operands.
    pub fn add(&self, o: &Self) -> Self {
        *self + *o
    }

    /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
    pub fn sub(&self, o: &Self) -> Self {
        *self + o.scale(-1.0)
    }

    /// Exact (order ≤ 3) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u), f‴(u)]` evaluated at `u = self.v`. The
    /// `v`/`g`/`h`/`t3` channels match [`Tower4::compose_unary`] term-for-term
    /// (which uses only `d[0..=3]` for those orders), so this is a strict
    /// truncation, not an approximation. The full-order `[f64; 5]` derivative
    /// stacks the families already produce can be passed by slicing their first
    /// four entries.
    ///
    /// # Codegen
    ///
    /// Order-≤3 Faà di Bruno written as a compact closed form instead of the
    /// recursive [`jet_algebra::faa_di_bruno`] walker — the order-≤3 sibling of
    /// [`Tower4::compose_unary`], one tensor order shallower. The loop nest is
    /// unchanged (no unroll over `K`, no code bloat: measured on a `Tower3<9>`
    /// compose-and-read consumer the new form is faster and SMALLER — asm: 71
    /// walker `bl` calls → 0, 39.5 KiB → 13.9 KiB, +197 NEON `.2d` ops).
    /// BIT-IDENTICAL: terms in the walker's exact partition order, left-
    /// associated block products, `acc = 0.0` accumulator start. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// random inputs each.
    pub fn compose_unary(&self, d: [f64; 4]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
                    let mut acc = 0.0;
                    acc += d[1] * self.t3[i][j][k];
                    acc += d[2] * self.h[i][j] * self.g[k];
                    acc += d[2] * self.h[i][k] * self.g[j];
                    acc += d[2] * self.g[i] * self.h[j][k];
                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }
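/* Illustrative sketch, not part of this module: the five set partitions in the
`t3` loop above, specialised to one variable, where the three `{pair}{single}`
partitions merge into a single term with coefficient 3. `Jet3`, `compose3` and
`ln_of_square` are hypothetical names used only for this example.

```rust
// Hypothetical univariate order-3 jet: [u, u', u'', u'''].
type Jet3 = [f64; 4];

// Faà di Bruno through order 3, given d = [f, f', f'', f'''] at the inner value:
//   (f∘u)'   = f' u'
//   (f∘u)''  = f' u'' + f'' u'^2
//   (f∘u)''' = f' u''' + 3 f'' u'' u' + f''' u'^3
fn compose3(u: &Jet3, d: [f64; 4]) -> Jet3 {
    [
        d[0],
        d[1] * u[1],
        d[1] * u[2] + d[2] * u[1] * u[1],
        d[1] * u[3] + 3.0 * d[2] * u[2] * u[1] + d[3] * u[1] * u[1] * u[1],
    ]
}

// ln(x^2) for x != 0: inner jet of x^2, outer stack of ln at v = x^2.
fn ln_of_square(x: f64) -> Jet3 {
    let u: Jet3 = [x * x, 2.0 * x, 2.0, 0.0];
    let v = u[0];
    compose3(&u, [v.ln(), 1.0 / v, -1.0 / (v * v), 2.0 / (v * v * v)])
}
```

Since ln(x²) = 2 ln|x|, the derivatives at x = 2 are 2/x = 1, −2/x² = −0.5
and 4/x³ = 0.5, which is what the composition returns. */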

    /// Compose with a unary special-function whose `[f64; 4]` derivative STACK is
    /// built from the base value through `stack_fn` — the scalar arm of the
    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
    /// [`Tower3Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
    /// `self.compose_unary(stack_fn(self.v))`. The order-≤3 sibling of
    /// [`Tower4::compose_unary_with`].
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
        self.compose_unary(stack_fn(self.v))
    }

    /// Single-active-slot fast path for [`Self::compose_unary`] — the order-≤3
    /// sibling of [`Tower4::compose_unary_single_slot`]. When `self` carries
    /// derivative support only on the all-`slot` diagonal, every output channel
    /// touching an axis `≠ slot` collapses to the walker's `total = 0.0` start
    /// (`+0.0`), so only `v`, `g[slot]`, `h[slot][slot]`, `t3[slot]³` survive.
    /// These four are computed as STRAIGHT-LINE accumulations, each in the EXACT
    /// term order of [`Self::compose_unary`]'s diagonal (`i = j = k = slot`)
    /// case (BIT-IDENTICAL to the full path on the diagonal); off-`slot`
    /// channels stay at the zero-init `+0.0` the full walk also yields (proven
    /// `to_bits`-identical across `K ∈ {2,3,4,9}`). This drops the recursive
    /// set-partition walker the diagonal channels previously routed through,
    /// recovering its measured ~5.9× regression at the `K ∈ {2,3}` BMS tower
    /// widths. Caller guarantees the single-slot precondition; otherwise use
    /// [`Self::compose_unary`].
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [f64; 4], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        out.v = d[0];
        // g (i=s): d1*g
        out.g[s] = {
            let mut acc = 0.0;
            acc += d[1] * g;
            acc
        };
        // h (i=j=s): d1*h + d2*g*g
        out.h[s][s] = {
            let mut acc = 0.0;
            acc += d[1] * h;
            acc += d[2] * g * g;
            acc
        };
        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
        out.t3[s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t3;
            acc += d[2] * h * g;
            acc += d[2] * h * g;
            acc += d[2] * g * h;
            acc += d[3] * g * g * g;
            acc
        };
        out
    }

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
                for k in 0..K {
                    out.t3[i][j][k] *= s;
                }
            }
        }
        out
    }
}

impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut out = Self::zero();
        out.v = f(&[]);
        for i in 0..K {
            let labels = [i];
            out.g[i] = f(&labels);
        }
        for i in 0..K {
            for j in 0..K {
                let labels = [i, j];
                out.h[i][j] = f(&labels);
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let labels = [i, j, k];
                    out.t3[i][j][k] = f(&labels);
                }
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Add for Tower3<K> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        let mut out = self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
                for k in 0..K {
                    out.t3[i][j][k] += o.t3[i][j][k];
                }
            }
        }
        out
    }
}

/// Derivative stack `[ln Γ(x), ψ(x), ψ′(x), ψ″(x), ψ‴(x)]` for `x > 0`, i.e. the
/// `[f, f′, f″, f‴, f⁗]` input of a fourth-order composition with `f = ln Γ`.
/// The polygamma entries are NaN unless `x` is finite and positive.
pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
    [
        statrs::function::gamma::ln_gamma(x),
        digamma_positive(x),
        polygamma_positive(1, x),
        polygamma_positive(2, x),
        polygamma_positive(3, x),
    ]
}

/// Order-≤2 truncation of [`ln_gamma_derivative_stack`]:
/// `[ln Γ(x), ψ(x), ψ′(x)]`, sized for [`Tower2::compose_unary`].
pub fn ln_gamma_derivative_stack_order2(x: f64) -> [f64; 3] {
    [
        statrs::function::gamma::ln_gamma(x),
        digamma_positive(x),
        polygamma_positive(1, x),
    ]
}

/// Derivative stack `[ψ, ψ′, ψ″, ψ‴, ψ⁗]` of the digamma function at `x > 0`.
pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
    [
        digamma_positive(x),
        polygamma_positive(1, x),
        polygamma_positive(2, x),
        polygamma_positive(3, x),
        polygamma_positive(4, x),
    ]
}

/// Derivative stack of the trigamma function ψ′ at `x > 0`: polygamma orders
/// 1 through 5.
pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
    [
        polygamma_positive(1, x),
        polygamma_positive(2, x),
        polygamma_positive(3, x),
        polygamma_positive(4, x),
        polygamma_positive(5, x),
    ]
}

/// Scalar digamma ψ(x) for x>0. Bit-identical to `digamma_derivative_stack(x)[0]`
/// and to `ln_gamma_derivative_stack(x)[1]`, but evaluates ONLY ψ — the four
/// higher polygammas those `[f64; 5]` stacks build are pure discarded work at a
/// scalar consumer that reads a single element. Hot-path row kernels that need
/// only the digamma value (e.g. the GAMLSS Beta observed cross weight) call this
/// instead of indexing `[0]` off a full derivative stack.
#[inline]
pub fn digamma(x: f64) -> f64 {
    digamma_positive(x)
}

/// Scalar trigamma ψ′(x) for x>0. Bit-identical to
/// `trigamma_derivative_stack(x)[0]` (both bottom out in `polygamma_positive(1,
/// x)`), but evaluates ONLY ψ′ — the four higher polygammas (orders 2–5) the
/// `[f64; 5]` stack builds are discarded at a `[0]` consumer. Used by the
/// dispersion-channel Fisher-information row kernels (NB2 `ψ′(θ)−ψ′(θ+μ)`, Beta
/// `μψ′(μφ)−(1−μ)ψ′((1−μ)φ)`) which read the trigamma value alone.
#[inline]
pub fn trigamma(x: f64) -> f64 {
    polygamma_positive(1, x)
}

/// ψ(x) for finite `x > 0`, NaN otherwise. Shifts `x` up to at least
/// `POLYGAMMA_ASYMPTOTIC_MIN_X` with the recurrence `ψ(x) = ψ(x + 1) − 1/x`,
/// then evaluates the asymptotic series at the shifted argument.
fn digamma_positive(mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
        acc -= 1.0 / x;
        x += 1.0;
    }
    acc + digamma_asymptotic(x)
}

/// ψ⁽ⁿ⁾(x) for `n = order` in `1..=5` and finite `x > 0`, NaN otherwise. Uses
/// the same shift-then-asymptotic scheme as [`digamma_positive`].
fn polygamma_positive(order: usize, mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
        acc += polygamma_recurrence_term(order, x);
        x += 1.0;
    }
    acc + polygamma_asymptotic(order, x)
}

/// Arguments below this threshold are shifted upward by recurrence before the
/// asymptotic series is applied.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
/// Even-index Bernoulli numbers as `(2k, B_2k)` for `2k = 2..=20`, the
/// coefficients of the asymptotic series.
const BERNOULLI_EVEN: [(usize, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];

/// One recurrence step: `ψ⁽ⁿ⁾(x) − ψ⁽ⁿ⁾(x + 1) = (−1)ⁿ⁺¹ · n! / xⁿ⁺¹`.
fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
    let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    sign * factorial(order) / x.powi((order + 1) as i32)
}

/// Asymptotic series `ψ(x) ≈ ln x − 1/(2x) − Σ B_2k / (2k · x^2k)`, intended
/// for large `x`.
fn digamma_asymptotic(x: f64) -> f64 {
    let mut out = x.ln() - 0.5 / x;
    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
        out -= bernoulli / (bernoulli_order as f64 * x.powi(bernoulli_order as i32));
    }
    out
}

/// Asymptotic series
/// `ψ⁽ⁿ⁾(x) ≈ (−1)ⁿ⁺¹ [(n−1)!/xⁿ + n!/(2xⁿ⁺¹) + Σ B_2k · (2k+n−1)!/((2k)! · x^(2k+n))]`
/// for `n` in `1..=5`; returns NaN for any other order.
fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
    if !(1..=5).contains(&order) {
        return f64::NAN;
    }

    let order_factorial = factorial(order);
    let leading_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    let mut out = leading_sign * factorial(order - 1) / x.powi(order as i32)
        + leading_sign * order_factorial / (2.0 * x.powi((order + 1) as i32));

    let bernoulli_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
        // rising / 2k = (2k+n−1)! / (2k)!
        let rising = rising_factorial(bernoulli_order, order);
        out += bernoulli_sign * bernoulli * rising
            / bernoulli_order as f64
            / x.powi((bernoulli_order + order) as i32);
    }
    out
}

/// `n!` as an `f64`.
fn factorial(n: usize) -> f64 {
    (1..=n).fold(1.0, |acc, k| acc * k as f64)
}

/// Product of `len` consecutive integers starting at `start`:
/// `start · (start + 1) ⋯ (start + len − 1)`.
fn rising_factorial(start: usize, len: usize) -> f64 {
    (start..start + len).fold(1.0, |acc, k| acc * k as f64)
}
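/* Illustrative sketch, not part of this module: the shift-then-asymptotic
evaluation used by the digamma helper above, cut down to three Bernoulli terms
so it fits in a few lines. The names `MIN_X`, `BERNOULLI` and `digamma_sketch`
are hypothetical, and the shorter series is an assumption made for brevity; it
is less accurate than the ten-term table in this file.

```rust
// Shift threshold, mirroring the value used above.
const MIN_X: f64 = 20.0;
// (2k, B_2k) for the first three even Bernoulli numbers.
const BERNOULLI: [(i32, f64); 3] = [(2, 1.0 / 6.0), (4, -1.0 / 30.0), (6, 1.0 / 42.0)];

// Digamma for x > 0: shift x upward with psi(x) = psi(x + 1) - 1/x, then apply
// the series ln x - 1/(2x) - sum B_2k / (2k * x^(2k)) at the shifted argument.
fn digamma_sketch(mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < MIN_X {
        acc -= 1.0 / x;
        x += 1.0;
    }
    let mut tail = x.ln() - 0.5 / x;
    for (order, b) in BERNOULLI {
        tail -= b / (order as f64 * x.powi(order));
    }
    acc + tail
}
```

A quick sanity check is ψ(1) = −γ ≈ −0.5772156649 (the negated Euler-Mascheroni
constant) together with the recurrence identity ψ(2) − ψ(1) = 1. */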

impl<const K: usize> std::ops::Add for Tower4<K> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        let mut out = self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
                for k in 0..K {
                    out.t3[i][j][k] += o.t3[i][j][k];
                    for l in 0..K {
                        out.t4[i][j][k][l] += o.t4[i][j][k][l];
                    }
                }
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Sub for Tower4<K> {
    type Output = Self;
    fn sub(self, o: Self) -> Self {
        self + o.scale(-1.0)
    }
}

impl<const K: usize> std::ops::Neg for Tower4<K> {
    type Output = Self;
    fn neg(self) -> Self {
        self.scale(-1.0)
    }
}

impl<const K: usize> std::ops::Mul for Tower4<K> {
    type Output = Self;
    fn mul(self, o: Self) -> Self {
        Tower4::mul(&self, &o)
    }
}

impl<const K: usize> std::ops::Div for Tower4<K> {
    type Output = Self;
    fn div(self, o: Self) -> Self {
        Tower4::mul(&self, &o.recip())
    }
}

impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
    type Output = Self;
    fn add(self, c: f64) -> Self {
        let mut out = self;
        out.v += c;
        out
    }
}

impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
    type Output = Self;
    fn sub(self, c: f64) -> Self {
        self + (-c)
    }
}

impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
    type Output = Self;
    fn mul(self, c: f64) -> Self {
        self.scale(c)
    }
}

// ── Implicit-function and moving-boundary seams (#932 flex) ──────────
//
// The flexible survival marginal-slope row loss is NOT a free composition
// of the primaries: it threads an IMPLICIT calibration intercept `a(θ)`
// solving a constraint `F(a, θ) = 0`, and integrates a density over cells
// whose edges `z_L(θ), z_R(θ)` MOVE with θ through that intercept. Plain
// `Tower4` Faà di Bruno cannot express either — so the flex tower was the
// last hand-written one in the codebase, and the genus of #736-class
// drift bugs (the (g,w0) deviation-cross third was 3× short for exactly
// this reason). These two combinators close that gap: once the constraint
// `F` and the integrand/boundaries are themselves towers, the intercept's
// derivative tower and the integral's derivative tower come out EXACTLY at
// every order — there is no order left to hand-code and forget.
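/* Illustrative sketch, not part of this module: for a single primary, the
implicit-intercept seam described above reduces to the textbook
implicit-function recursion through second order. The `Partials` struct and
both functions are hypothetical names for this example; the real combinator
works on full towers and continues to fourth order.

```rust
// Partial derivatives of a hypothetical constraint F(a, t) at the solved point.
struct Partials {
    f_a: f64,
    f_t: f64,
    f_aa: f64,
    f_at: f64,
    f_tt: f64,
}

// First and second derivatives of the root curve a(t) with F(a(t), t) = 0:
//   a'  = -F_t / F_a
//   a'' = -(F_tt + 2 F_at a' + F_aa a'^2) / F_a
fn implicit_derivs(p: &Partials) -> Option<(f64, f64)> {
    if p.f_a == 0.0 || !p.f_a.is_finite() {
        return None; // the constraint is not invertible in a
    }
    let a1 = -p.f_t / p.f_a;
    let a2 = -(p.f_tt + 2.0 * p.f_at * a1 + p.f_aa * a1 * a1) / p.f_a;
    Some((a1, a2))
}

// F(a, t) = a^2 - t, whose positive root is a(t) = sqrt(t).
fn sqrt_root_derivs(t: f64) -> Option<(f64, f64)> {
    let a = t.sqrt();
    implicit_derivs(&Partials { f_a: 2.0 * a, f_t: -1.0, f_aa: 2.0, f_at: 0.0, f_tt: 0.0 })
}
```

At t = 4 the root is a = 2, and the recursion returns a′ = 0.25 and
a″ = −0.03125, matching the first two derivatives of √t. At t = 0 the partial
F_a vanishes and the sketch returns `None`. */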

/// Solve the implicit relation `F(a(θ), θ) ≡ 0` for the intercept tower
/// `a(θ)` over the `K` primaries θ, given the constraint tower `f` written
/// over `K + 1` variables (slot `0` is the intercept `a`, slots `1..=K`
/// are the primaries θ) evaluated at the SOLVED point — i.e. `f.v` is the
/// constraint residual at `(a₀, θ₀)` (≈ 0 from the production Newton solve)
/// and `a0` is that solved intercept value.
///
/// Returns the `Tower4<K>` whose value is `a0` and whose every derivative
/// tensor (∂a/∂θ, ∂²a/∂θ², …, ∂⁴a/∂θ⁴) is the exact implicit-function
/// derivative. This is the mechanical replacement for the hand-coded
/// `a_u = -f_u/f_a`, `a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a`
/// recursion (first_full.rs) and its third/fourth-order continuations.
///
/// Method: order-by-order substitution. We build `a` incrementally; at each
/// order `m` the composite `G(θ) = f(a(θ), θ)` has a top-order coefficient
/// that is linear in `a`'s order-`m` tensor with leading factor `F_a`
/// (= `f.g[0]`), plus terms in `a`'s lower orders already fixed. Setting the
/// order-`m` tensor of `a` to cancel the rest of `G`'s order-`m` coefficient
/// keeps `G ≡ 0` through that order. The substitution `G = f∘(a, θ)` reuses
/// only the exact [`substitute_intercept`] chain rule, so the recursion is
/// auditable and exact, not a hand-expanded formula per order.
///
/// `f.g[0]` (= ∂F/∂a) must be non-zero — guaranteed by the production
/// solve's strict monotonicity guard.
///
/// The expansion point `a0` must be a genuine root `F(a0, θ0) = 0`: the
/// substitution recursion below cancels orders 1..=4 of `G = F∘a` but never
/// touches order 0, so a non-root `a0` would yield the Taylor expansion of
/// the LEVEL SET `F = F(a0)` through `a0`, not the root curve `F = 0`. This
/// is guarded explicitly and re-verified by a composed-residual self-check.
1477pub fn implicit_solve<const K1: usize, const K: usize>(
1478    f: &Tower4<K1>,
1479    a0: f64,
1480) -> Result<Tower4<K>, String> {
1481    assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
1482    let f_a = f.g[0];
1483    if f_a == 0.0 || !f_a.is_finite() {
1484        return Err(format!(
1485            "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
1486        ));
1487    }
1488    // The expansion point must be a genuine root of F. The single Newton
1489    // correction that would move a0 onto the root is |f.v|/|f_a|; require it
1490    // to be negligible relative to the natural scale (1 + |a0|). Guarding the
1491    // Newton step (rather than f.v directly) makes the criterion invariant to
1492    // the magnitude of f_a / the units of F.
1493    let root_tol = 1e-9;
1494    if !f.v.is_finite() {
1495        return Err(format!(
1496            "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
1497            f.v
1498        ));
1499    }
1500    let newton_step = f.v.abs() / f_a.abs();
1501    if newton_step > root_tol * (1.0 + a0.abs()) {
1502        return Err(format!(
1503            "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
1504             F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
1505             root_tol {root_tol:.1e} · (1 + |a0|)",
1506            f.v
1507        ));
1508    }
    // Start with a = constant a0 (correct through order 0). Then lift each
    // order in turn. Because substitute_intercept reads `a`'s order-≤m
    // tensors when forming G's order-m coefficient, and the order-m
    // coefficient of G depends on a's order-m tensor ONLY through the linear
    // F_a·a_m term, a single corrective pass per order is exact.
    let mut a = Tower4::<K>::constant(a0);
    for order in 1..=4 {
        let g = substitute_intercept(f, &a);
        // Cancel G's order-`order` coefficient by adjusting a's order-`order`
        // tensor: a_m -= G_m / F_a (the F_a·a_m term is the only one carrying
        // a's order-m tensor, with unit chain coefficient since slot 0 seeds a
        // as a plain variable in the substitution's first-order part).
        match order {
            1 => {
                for i in 0..K {
                    a.g[i] -= g.g[i] / f_a;
                }
            }
            2 => {
                for i in 0..K {
                    for j in 0..K {
                        a.h[i][j] -= g.h[i][j] / f_a;
                    }
                }
            }
            3 => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
                        }
                    }
                }
            }
            _ => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            for l in 0..K {
                                a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
                            }
                        }
                    }
                }
            }
        }
    }
    // Self-check: the composed residual G = F∘a must vanish through order 4.
    // By construction orders 1..=4 were cancelled; the value G.v == F(a0,θ0)
    // is exactly the root requirement guarded above. Re-verify all channels
    // against a scale-aware tolerance so that any arithmetic regression in the
    // substitution recursion fails loudly instead of silently shipping a wrong
    // level-set expansion.
    let g = substitute_intercept(f, &a);
    let resid_tol = 1e-7 * (1.0 + f_a.abs());
    let mut worst = g.v.abs();
    for i in 0..K {
        worst = worst.max(g.g[i].abs());
        for j in 0..K {
            worst = worst.max(g.h[i][j].abs());
            for k in 0..K {
                worst = worst.max(g.t3[i][j][k].abs());
                for l in 0..K {
                    worst = worst.max(g.t4[i][j][k][l].abs());
                }
            }
        }
    }
    if !worst.is_finite() || worst > resid_tol {
        return Err(format!(
            "implicit_solve: composed residual G = F∘a does not vanish: \
             worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
        ));
    }
    Ok(a)
}

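// Illustrative sketch only, not part of this module's API: the same
// order-by-order lift in ONE variable, on plain Taylor coefficients instead
// of `Tower4`. The constraint F(a, θ) = a² − (1 + θ) and the helper names
// `sketch_poly_mul` / `sketch_implicit_sqrt_coeffs` are assumptions chosen for
// the illustration. The root a(θ) = √(1 + θ) has the known coefficients
// 1, 1/2, −1/8, 1/16, −5/128, which the lift should reproduce.
#[allow(dead_code)]
fn sketch_poly_mul(a: &[f64; 5], b: &[f64; 5]) -> [f64; 5] {
    // Product of two polynomials in θ, truncated after θ⁴.
    let mut out = [0.0_f64; 5];
    for i in 0..5 {
        for j in 0..(5 - i) {
            out[i + j] += a[i] * b[j];
        }
    }
    out
}

#[allow(dead_code)]
fn sketch_implicit_sqrt_coeffs() -> [f64; 5] {
    let a0 = 1.0_f64; // root of F(a, 0) = a² − 1
    let f_a = 2.0 * a0; // ∂F/∂a at the expansion point
    let mut a = [a0, 0.0, 0.0, 0.0, 0.0];
    for order in 1..=4 {
        // G(θ) = a(θ)² − 1 − θ with the current, partially lifted a.
        let mut g = sketch_poly_mul(&a, &a);
        g[0] -= 1.0;
        g[1] -= 1.0;
        // a's order-m coefficient enters G's order-m coefficient only through
        // F_a · a_m, so one correction per order cancels it exactly.
        a[order] -= g[order] / f_a;
    }
    a
}
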
/// Substitute the intercept tower `a(θ)` into slot `0` of a constraint
/// written over `K + 1` variables, returning the composite tower over the
/// `K` primaries θ: `G(θ) = f(a(θ), θ₁, …, θ_K)`.
///
/// This is the exact multivariate chain rule specialised to "slot 0 is a
/// dependent tower, slots 1..=K are the independent primaries". It evaluates
/// `f`'s fourth-order multivariate Taylor polynomial about the expansion
/// point, with the slot-0 increment being the non-constant part of `a` and
/// the slot-(i) increment being the unit-seeded primary `θ_i`. The sum is
/// assembled by the same subset/partition algebra `Tower4` arithmetic uses,
/// so it carries derivatives exactly through order four.
pub fn substitute_intercept<const K1: usize, const K: usize>(
    f: &Tower4<K1>,
    a: &Tower4<K>,
) -> Tower4<K> {
    assert_eq!(K1, K + 1);
    // Build the K+1 input towers in θ-space: slot 0 = a(θ), slot i+1 = θ_i.
    // The composite is Σ over ordered label tuples s (|s| ≤ 4) of input
    // indices: (1/|s|!) · f.deriv(s) · Π_{j in s} (inp[s_j] centred) — but
    // since f.deriv is the SYMMETRIC partial tensor and we enumerate ordered
    // tuples, the 1/|s|! exactly cancels the tuple multiplicity. We assemble
    // it directly as a Horner-free explicit sum over the (K+1)-ary tuples,
    // using tower products for the increment monomials so all θ-derivatives
    // propagate exactly.
    let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
        if slot == 0 {
            // slot 0: a(θ) minus its constant value (the increment δa(θ)).
            let mut d = *a;
            d.v = 0.0;
            d
        } else {
            // slot i: the increment δθ_{i-1} = seeded variable minus value.
            // θ centred at its expansion value has zero constant term and unit
            // first derivative in its own slot.
            let mut d = Tower4::<K>::zero();
            d.g[slot - 1] = 1.0;
            d
        }
    });
    // Accumulate the Taylor sum. order-0 term:
    let mut out = Tower4::<K>::constant(f.v);
    // order 1: Σ_a f.g[a] · inp[a]
    for a_idx in 0..K1 {
        out = out + inp[a_idx].scale(f.g[a_idx]);
    }
    // order 2: (1/2) Σ_{a,b} f.h[a][b] · inp[a]·inp[b]
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            let prod = inp[a_idx].mul(&inp[b_idx]);
            out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
        }
    }
    // order 3: (1/6) Σ f.t3[a][b][c] · inp[a]·inp[b]·inp[c]
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            for c_idx in 0..K1 {
                let prod = inp[a_idx].mul(&inp[b_idx]).mul(&inp[c_idx]);
                out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
            }
        }
    }
    // order 4: (1/24) Σ f.t4[a][b][c][d] · inp[a]·inp[b]·inp[c]·inp[d]
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            for c_idx in 0..K1 {
                for d_idx in 0..K1 {
                    let prod = inp[a_idx]
                        .mul(&inp[b_idx])
                        .mul(&inp[c_idx])
                        .mul(&inp[d_idx]);
                    out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
                }
            }
        }
    }
    out
}

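// Illustrative sketch only (hypothetical helper, not part of this module's
// API): the order-2 term of a Taylor sum for plain `f64` increments. Summing
// over ORDERED index pairs with the factor 1/2 visits each mixed partial
// twice, which is how the 1/2! cancels against the tuple multiplicity when
// the Hessian `h` is symmetric.
#[allow(dead_code)]
fn sketch_second_order_term<const N: usize>(h: &[[f64; N]; N], d: &[f64; N]) -> f64 {
    let mut acc = 0.0_f64;
    for a in 0..N {
        for b in 0..N {
            acc += 0.5 * h[a][b] * d[a] * d[b];
        }
    }
    acc
}
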
/// The exact θ-derivative tower of a moving-LIMIT integral's BOUNDARY
/// contribution: given the edge-position tower `z_edge(θ)` over the `K`
/// primaries and the integrand `B` evaluated-and-differentiated at the edge
/// value as the stack `b_stack = [B(z₀), B′(z₀), B″(z₀), B‴(z₀)]`
/// (`z₀ = z_edge.v`), returns the tower of `Φ(z_edge(θ))` where `Φ′ = B`.
///
/// Rationale: `∫^{z_edge(θ)} B(z) dz = Φ(z_edge(θ))` up to a θ-independent
/// constant, with `Φ` an antiderivative of `B`, so the boundary part of every
/// θ-derivative of the integral is the matching derivative of the composition
/// `Φ ∘ z_edge`. Its Faà di Bruno expansion carries, at one stroke, EVERY
/// Leibniz boundary term the hand-written flux dropped: the first-order
/// `B·z_u`, the second-order `B′·z_u·z_v + B·z_uv` (the `G_z·z_u·z_v`
/// self-flux AND the previously dropped `G·z_uv`), and the full
/// third/fourth-order continuations. The VALUE channel of the returned tower
/// is meaningless (`Φ` is only defined up to a constant); callers read only
/// the derivative channels and pair this with the interior moment-integral
/// value separately.
///
/// `b_stack` holds `B` and its first three z-derivatives; the antiderivative
/// `Φ` contributes only through the order-≥1 channels, so `compose_unary`
/// receives `[0, B, B′, B″, B‴]`, where the leading `0` is the discarded
/// `Φ(z₀)` slot.
pub fn moving_limit_boundary_tower<const K: usize>(
    z_edge: &Tower4<K>,
    b_stack: [f64; 4],
) -> Tower4<K> {
    z_edge.compose_unary([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]])
}

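// Illustrative sketch only (hypothetical helper, not part of this module's
// API): the one-variable Faà di Bruno expansion behind the boundary tower,
// through third order. `b` is assumed to hold `[B, B′, B″]` at the edge value
// and `z` the θ-derivatives `[z′, z″, z‴]` of the edge position; the result is
// the first three θ-derivatives of `Φ(z(θ))` with `Φ′ = B`. The multivariate
// `compose_unary` is taken to generalise exactly these terms.
#[allow(dead_code)]
fn sketch_boundary_derivs_1d(b: [f64; 3], z: [f64; 3]) -> [f64; 3] {
    let [b0, b1, b2] = b;
    let [z1, z2, z3] = z;
    [
        // first order: B·z′
        b0 * z1,
        // second order: B′·z′² + B·z″
        b1 * z1 * z1 + b0 * z2,
        // third order: B″·z′³ + 3·B′·z′·z″ + B·z‴
        b2 * z1 * z1 * z1 + 3.0 * b1 * z1 * z2 + b0 * z3,
    ]
}
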
/// The boundary-flux derivative tower of a single moving cell integral
/// `∫_{z_L(θ)}^{z_R(θ)} B dz`: `Φ(z_R(θ)) − Φ(z_L(θ))`, assembled from the
/// two edge towers and the integrand stacks at each edge. The returned
/// tower's derivative channels are the EXACT moving-boundary contribution to
/// every θ-derivative of the cell integral, to fourth order, with no term
/// hand-omitted. A `Fixed` (non-moving) edge passes a `z_edge` whose
/// derivative channels are all zero, contributing nothing — matching the
/// production `edge_vel = 0` short-circuit.
pub fn cell_moving_boundary_flux_tower<const K: usize>(
    z_right: &Tower4<K>,
    b_stack_right: [f64; 4],
    z_left: &Tower4<K>,
    b_stack_left: [f64; 4],
) -> Tower4<K> {
    moving_limit_boundary_tower(z_right, b_stack_right)
        - moving_limit_boundary_tower(z_left, b_stack_left)
}

/// Moving-limit boundary tower for a θ-DEPENDENT integrand `G(z; θ)`.
///
/// [`moving_limit_boundary_tower`] assumes the integrand depends on θ only
/// through the moving edge `z_edge(θ)` (a fixed z-derivative `b_stack`). The
/// marginal-slope flex boundary is richer: the integrand `G(z; θ)` ALSO carries
/// its own θ-dependence (the density weight `w = e^{−q}/2π` and the cell
/// integrand coefficients move with η, hence with the primaries), so the
/// Leibniz expansion of `∂ⁿ_θ ∫^{z_edge(θ)} G(z;θ) dz` mixes edge-motion
/// derivatives of the limit with θ-derivatives of `G` itself — e.g. at second
/// order `G·z_uv + G_z·z_u·z_v + G_{θu}·z_v + G_{θv}·z_u` (the four
/// edge-motion-carrying terms the hand path assembles one by one, including the
/// `G·z_uv` term the directional path drops).
///
/// Mechanization: let `Φ(z; θ)` be the z-antiderivative of `G` (so `Φ_z = G`).
/// The full upper-limit contribution is `Φ(z_edge(θ); θ)`, and the BOUNDARY
/// part — everything carrying edge motion — is exactly
///   `Φ(z_edge(θ); θ) − Φ(z₀; θ)`,
/// the second term being the pure-integrand-θ part (`∫^{z₀} ∂ⁿ_θ G`) the
/// interior moment integral already supplies. Both are one
/// [`substitute_intercept`] of the SAME mixed `(z, θ)` jet of `Φ` (z in slot 0,
/// θ in slots 1..=K): substituting the edge tower gives the full composite,
/// substituting a frozen constant edge isolates the pure-θ part, and their
/// difference is the exact boundary flux — every Leibniz term derived by the
/// substitution algebra, none hand-omitted.
///
/// `phi_jet` is the `(K+1)`-variable Taylor jet of `Φ` about `(z₀, θ₀)` with
/// `z₀ = z_edge.v`: slot 0 is the z-direction (so `phi_jet.g[0] = G(z₀;θ₀)`,
/// `phi_jet.h[0][0] = G_z`, …) and slots `1..=K` are the primaries θ (carrying
/// `Φ`'s own θ- and mixed z·θ-derivatives — i.e. the integrand's θ-derivatives
/// integrated in z, and `G_{θ…}` in the mixed slots). The returned tower's
/// VALUE channel is 0 by construction (the `Φ(z₀;θ₀)` constants cancel); only
/// the derivative channels are meaningful, matching the value-less convention of
/// [`moving_limit_boundary_tower`].
pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
    phi_jet: &Tower4<K1>,
    z_edge: &Tower4<K>,
) -> Tower4<K> {
    assert_eq!(
        K1,
        K + 1,
        "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
    );
    let frozen_edge = Tower4::<K>::constant(z_edge.v);
    let full = substitute_intercept(phi_jet, z_edge);
    let interior = substitute_intercept(phi_jet, &frozen_edge);
    full - interior
}

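// Illustrative sketch only (hypothetical helper, not part of this module's
// API): the second-order boundary term for a θ-dependent integrand in ONE
// primary, where u = v collapses the four terms to
// `G·z″ + G_z·z′² + 2·G_θ·z′`. All arguments are assumed to be evaluated at
// the expansion point `(z₀, θ₀)`.
#[allow(dead_code)]
fn sketch_theta_integrand_second_order(g: f64, g_z: f64, g_theta: f64, z1: f64, z2: f64) -> f64 {
    g * z2 + g_z * z1 * z1 + 2.0 * g_theta * z1
}
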
/// Two-edge cell version of [`moving_limit_boundary_tower_theta_integrand`]:
/// the exact boundary-flux tower of `∫_{z_L(θ)}^{z_R(θ)} G(z;θ) dz` with a
/// θ-dependent integrand, `Φ(z_R;θ) − Φ(z_L;θ)` minus the pure-θ parts at each
/// frozen edge. A `Fixed` edge passes a `z_edge` with zero derivative channels,
/// so its `full` and `interior` substitutions coincide and it contributes
/// nothing — matching the production `edge_vel = 0` short-circuit.
pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
    phi_jet_right: &Tower4<K1>,
    z_right: &Tower4<K>,
    phi_jet_left: &Tower4<K1>,
    z_left: &Tower4<K>,
) -> Tower4<K> {
    moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right)
        - moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left)
}

// ── The program seam ─────────────────────────────────────────────────

/// A family's row negative log-likelihood written ONCE over tower scalars.
///
/// This is the single source of truth #932 asks for: the value channel of
/// the returned tower must BE the production row NLL (same branches, same
/// guards, same numerics), and every derivative channel is then exact by
/// construction. The linear Jacobian wiring (coefficients ↔ primaries) is
/// NOT part of this trait — it is family data, not calculus, and stays on
/// the `RowKernel` implementor.
pub trait RowNllProgram<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed the tower).
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on tower scalars. `p[a]` arrives pre-seeded as
    /// variable `a` at the current primary value; implementations combine
    /// them with `Tower4` arithmetic and per-row data (response, censoring
    /// indicators, offsets) entering as constants.
    fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
}

/// Evaluate a program's full tower at the current primaries for one row.
///
/// One call yields every `RowKernel` calculus channel; callers that need
/// several contractions of the same row should hold the returned tower and
/// contract repeatedly rather than re-evaluating.
pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<Tower4<K>, String> {
    let p = prog.primaries(row)?;
    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
    prog.row_nll(row, &vars)
}

/// Mechanically derived `row_kernel` channel: `(nll, ∇, H)`.
pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
    let t = evaluate_program(prog, row)?;
    Ok((t.v, t.g, t.h))
}

/// Mechanically derived `row_third_contracted` channel.
pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    Ok(evaluate_program(prog, row)?.third_contracted(dir))
}

/// Mechanically derived `row_fourth_contracted` channel.
pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir_u: &[f64; K],
    dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
}

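// Illustrative sketch only, not part of this module's API: the "write the row
// NLL once, read every channel off one evaluation" idea in its smallest form,
// a ONE-variable second-order jet. `SketchJet2` and `sketch_poisson_row_nll`
// are hypothetical names, and the Poisson log-link loss `e^η − y·η` is an
// assumed example family; the real seam uses `Tower4<K>` and goes to fourth
// order.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug)]
struct SketchJet2 {
    v: f64,  // value
    d1: f64, // first derivative
    d2: f64, // second derivative
}

#[allow(dead_code)]
impl SketchJet2 {
    /// Seed the independent variable at `x`: unit first derivative.
    fn variable(x: f64) -> Self {
        Self { v: x, d1: 1.0, d2: 0.0 }
    }
    /// Multiply every channel by a plain constant.
    fn scale(self, s: f64) -> Self {
        Self { v: s * self.v, d1: s * self.d1, d2: s * self.d2 }
    }
    /// Channel-wise difference.
    fn sub(self, o: Self) -> Self {
        Self { v: self.v - o.v, d1: self.d1 - o.d1, d2: self.d2 - o.d2 }
    }
    /// Chain rule for `e^u`: `(e^u)′ = e^u·u′`, `(e^u)″ = e^u·(u′² + u″)`.
    fn exp(self) -> Self {
        let e = self.v.exp();
        Self { v: e, d1: e * self.d1, d2: e * (self.d1 * self.d1 + self.d2) }
    }
}

/// The row loss, written once; value, gradient and Hessian all come from it.
#[allow(dead_code)]
fn sketch_poisson_row_nll(eta: SketchJet2, y: f64) -> SketchJet2 {
    eta.exp().sub(eta.scale(y))
}
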
// ── The generic program seam (#932 scalar cutover) ───────────────────

/// A family's row negative log-likelihood written ONCE over the generic
/// [`crate::jet_scalar::JetScalar`] interface, so the SAME expression can be
/// re-instantiated at whatever order / representation a consumer needs
/// ([`crate::jet_scalar::Order2`] for `(v, g, H)`,
/// [`crate::jet_scalar::OneSeed`] for the contracted third,
/// [`crate::jet_scalar::TwoSeed`] for the contracted fourth, or the full
/// [`Tower4`] for every channel at once).
///
/// This is additive to [`RowNllProgram`] (which is `Tower4`-specialised): a
/// program implementing this generic trait gets the small contracted scalars for
/// free, dissolving the dense-`Tower4<9>` cost objection in the location-scale
/// gates (doc §A.4). An existing `Tower4`-only [`RowNllProgram`] continues to
/// work unchanged; new families should prefer this generic trait.
///
/// Because a `Tower4`-specialised `row_nll` body uses only
/// `add`/`sub`/`mul`/`scale`/`exp`/`ln`/… — all of which this trait also
/// provides — the same body is expressible directly over `S: JetScalar<K>`.
/// A program written that way needs no `Tower4`-specialised method and routes
/// the directional and joint-Hessian gates through the contracted scalars from
/// a single definition.
pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed the scalar).
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on a generic jet scalar. `p[a]` arrives pre-seeded
    /// (base value + per-scalar nilpotent directions) by the caller; the body
    /// uses ONLY [`crate::jet_scalar::JetScalar`] ops and per-row data
    /// (response, censoring, offsets) entering as constants.
    fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
        &self,
        row: usize,
        p: &[S; K],
    ) -> Result<S, String>;
}

/// Evaluate a generic program at the value/gradient/Hessian scalar
/// [`crate::jet_scalar::Order2`], returning `(nll, ∇, H)` — the
/// `row_kernel` channel — WITHOUT materialising any third / fourth tensor.
///
/// This is the production seam for the inner-Newton `(v, g, H)` path: the row
/// loss is written ONCE in `row_nll_generic`, and this routes it through the
/// cheap order-2 scalar. The single source of truth means the gradient and
/// Hessian cannot desync from the value (the #736 / #948 bug genus).
pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
    let base = prog.primaries(row)?;
    let vars: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|a| {
        <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(base[a], a)
    });
    let s = prog.row_nll_generic(row, &vars)?;
    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
}

/// Evaluate a generic program at the one-seed scalar
/// [`crate::jet_scalar::OneSeed`], returning the contracted third
/// `Σ_c ℓ_{abc} dir_c` — the `row_third_contracted(dir)` channel — WITHOUT
/// materialising the dense `t3` tensor. The contraction direction is folded
/// INTO the differentiation by the nilpotent ε seeded with `dir`.
pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let base = prog.primaries(row)?;
    let vars: [crate::jet_scalar::OneSeed<K>; K] =
        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
    let s = prog.row_nll_generic(row, &vars)?;
    Ok(s.contracted_third())
}

/// Evaluate a generic program at the two-seed scalar
/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `row_fourth_contracted(u, v)` channel —
/// WITHOUT materialising the dense `t4` tensor.
pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir_u: &[f64; K],
    dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let base = prog.primaries(row)?;
    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
    let s = prog.row_nll_generic(row, &vars)?;
    Ok(s.contracted_fourth())
}

/// Evaluate a generic program at the full dense [`Tower4`] scalar, returning
/// every channel `(v, g, h, t3, t4)` in one pass. Used where the UNCONTRACTED
/// third / fourth tensors are needed (the BMS rigid `third_full` / `fourth_full`
/// caches): the dense tensors come from the SAME `row_nll_generic` expression
/// the order-2 / contracted scalars consume, so there is a single source of
/// truth across every channel.
pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<Tower4<K>, String> {
    let base = prog.primaries(row)?;
    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(base[a], a));
    prog.row_nll_generic(row, &vars)
}

// ── The RowJet bridge: one row-NLL body over scalar jets AND lane towers ─
//
// `JetScalar<K>` (jet_scalar.rs) abstracts the SCALAR jets — its `value()`
// returns one `f64`, so the `f64x4` lane towers ([`Tower3Lane`] / [`Tower4Lane`])
// CANNOT implement it (their value channel is four rows). `compose_unary_with`
// exists as an inherent method on BOTH the scalar towers and the lane towers, but
// as separate inherent methods, not a shared trait bound — so a row-NLL body
// written `<S: JetScalar<K>>` could not be instantiated at `Tower4Lane`, and the
// 4-rows-per-pass SIMD batch path could not reuse the single source.
//
// [`RowJet<K>`] is that shared bound. It exposes exactly the ops a row-NLL body
// needs — `constant` / `variable` / `add` / `sub` / `mul` / `scale` / `neg`, the
// value-derived `compose_unary_with`, and a per-lane domain `guard` — over BOTH
// representations. A blanket impl makes every scalar `JetScalar<K>` a `RowJet<K>`
// (so the scalar call sites compile unchanged and bit-identically), and explicit
// impls route the `f64x4` lane towers through their existing per-lane methods. A
// body written once over `R: RowJet<K>` then instantiates at a scalar jet for the
// `(v, g, H)` / contracted-tensor channels AND at a lane tower for the batch.

/// The verdict of a per-lane [`RowJet::guard`] domain check.
///
/// A scalar jet (a [`crate::jet_scalar::JetScalar`] via the blanket impl) carries
/// ONE value, so it reports `lanes == 1` and a one-bit mask. A lane tower
/// ([`Tower3Lane`] / [`Tower4Lane`] over `f64x4`) carries FOUR rows, so it reports
/// `lanes == 4` and one mask bit per lane. The mask lets a batched program bail
/// exactly the offending 4-group to the scalar tail ([`any_failed`](Self::any_failed)),
/// or inspect which lanes tripped ([`lane_failed`](Self::lane_failed)).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GuardVerdict {
    lanes: u8,
    failed_mask: u8,
}

impl GuardVerdict {
    /// A scalar (1-lane) verdict: `pass == true` ⇒ no failure.
    #[inline]
    pub fn scalar(pass: bool) -> Self {
        Self {
            lanes: 1,
            failed_mask: if pass { 0 } else { 1 },
        }
    }
    /// A 4-lane verdict from a per-lane failure mask (bit `i` ⇒ lane `i` failed).
    #[inline]
    pub fn lanes4(failed_mask: u8) -> Self {
        Self {
            lanes: 4,
            failed_mask: failed_mask & 0x0f,
        }
    }
    /// Number of active lanes inspected (1 scalar, 4 batch).
    #[inline]
    pub fn lanes(self) -> usize {
        self.lanes as usize
    }
    /// True iff every inspected lane satisfied the predicate.
    #[inline]
    pub fn all_pass(self) -> bool {
        self.failed_mask == 0
    }
    /// True iff at least one inspected lane failed the predicate.
    #[inline]
    pub fn any_failed(self) -> bool {
        self.failed_mask != 0
    }
    /// True iff lane `i` failed the predicate. An out-of-range `i` (at or
    /// beyond [`lanes`](Self::lanes)) reports `false` rather than overflowing
    /// the `u8` shift.
    #[inline]
    pub fn lane_failed(self, i: usize) -> bool {
        i < self.lanes() && (self.failed_mask >> i) & 1 == 1
    }
    /// The raw failure mask (bit `i` ⇒ lane `i` failed).
    #[inline]
    pub fn failed_mask(self) -> u8 {
        self.failed_mask
    }
}

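// Illustrative sketch only (hypothetical helper, not part of this module's
// API): one way a 4-lane failure mask of the shape `GuardVerdict::lanes4`
// accepts could be built from four lane values and a domain predicate. Bit `i`
// is set when lane `i` FAILS the predicate, so a zero mask means all lanes pass.
#[allow(dead_code)]
fn sketch_failed_mask4(values: [f64; 4], pred: impl Fn(f64) -> bool) -> u8 {
    let mut mask = 0u8;
    for (lane, &x) in values.iter().enumerate() {
        if !pred(x) {
            mask |= 1u8 << lane;
        }
    }
    mask
}
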
/// Copy-or-zero-pad a derivative stack from length `N` to length `M`. Used by the
/// [`RowJet::compose_unary_with`] impls to bridge a program's chosen stack length
/// to each tower's native compose width ([`Tower4Lane`]: 5, [`Tower3Lane`]: 4).
/// `M ≥ N` zero-pads the unseeded high derivatives; `M < N` drops the unused tail
/// — both total, so the order-`(M−1)` tower reads exactly the channels it needs
/// and never an uninitialised entry. With `N == M` it is a verbatim copy (the
/// common `N == 5` case is bit-identical to passing the stack straight through).
#[inline]
fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
    let mut out = [0.0_f64; M];
    let m = N.min(M);
    out[..m].copy_from_slice(&s[..m]);
    out
}

/// The shared row-NLL algebra over BOTH the scalar jets and the `f64x4` lane
/// towers — the bound that lets ONE single-source row-NLL body SIMD-batch 4
/// rows/pass without a dual-source copy (module §"The RowJet bridge").
///
/// Every scalar [`crate::jet_scalar::JetScalar<K>`] is a `RowJet<K>` via the
/// blanket impl below (`Value = f64`), bit-identically to its `JetScalar`
/// methods; [`Tower3Lane`] / [`Tower4Lane`] over `f64x4` are `RowJet<K>` with
/// `Value = [f64; 4]`, routing through their per-lane methods so lane `i` of a
/// batched evaluation is `to_bits`-identical to the scalar evaluation on row `i`.
pub trait RowJet<const K: usize>: Copy {
    /// The value channel(s) seen by [`guard`](Self::guard) and
    /// [`values`](Self::values): a single `f64` on a scalar jet, `[f64; 4]` on an
    /// `f64x4` lane tower.
    type Value: Copy;

    /// A constant (value `c`, all derivatives zero), broadcast to every lane.
    fn constant(c: f64) -> Self;
    /// The seeded primary `slot` at value `x` (unit first derivative in `slot`),
    /// broadcast to every lane. Per-lane-DISTINCT seeding for the batch path is
    /// done by the lane instantiators ([`generic_batched_fourth_tower`] /
    /// [`generic_batched_third_tower`]), which build the tower variables directly
    /// from each row's primaries; this method is for any row-invariant auxiliary
    /// variable a body introduces.
    fn variable(x: f64, slot: usize) -> Self;
    /// The value channel(s): `f64` (scalar) or `[f64; 4]` (lane).
    fn values(&self) -> Self::Value;

    /// Truncated Leibniz `self + o`.
    fn add(&self, o: &Self) -> Self;
    /// Truncated Leibniz `self − o`.
    fn sub(&self, o: &Self) -> Self;
    /// Truncated Leibniz `self · o`.
    fn mul(&self, o: &Self) -> Self;
    /// Multiply every channel by the plain scalar `s`.
    fn scale(&self, s: f64) -> Self;
    /// Negate every channel. Defaults to `scale(-1.0)`; the blanket impl
    /// overrides it to delegate to [`crate::jet_scalar::JetScalar::neg`].
    fn neg(&self) -> Self {
        self.scale(-1.0)
    }

    /// Faà di Bruno compose with a unary special function whose `[f64; N]`
    /// derivative stack is built from the running base value PER LANE through
    /// `stack_fn`. This is the SHARED-TRAIT version of the `compose_unary_with`
    /// inherent method that already exists on both the scalar towers and the lane
    /// towers: on a scalar jet `stack_fn` is run once at the value; on an `f64x4`
    /// lane tower it is re-run per lane (the four rows carry four distinct base
    /// values), so lane `i` is `to_bits`-identical to the scalar result on row `i`.
    /// Making it a trait method is precisely what lets a body written once over
    /// `R: RowJet<K>` instantiate at the batch towers. `N` is widened/narrowed to
    /// the tower's native width by [`resize_stack`]; when `N` already equals that
    /// width (e.g. `N == 5` on [`Tower4Lane`]) the copy is verbatim.
    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self;

    /// Per-lane domain guard: evaluate `pred` on each active lane's value channel
    /// and report which lanes failed (see [`GuardVerdict`]). A scalar jet checks
    /// its one value; a lane tower checks all four. Lets a batched program detect
    /// an out-of-domain row in a 4-group and bail that group to the scalar tail.
    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict;

    /// Per-lane scale: multiply every channel by the per-lane factor `s`
    /// ([`Self::Value`]). On a scalar jet `Self::Value = f64`, so this is exactly
    /// [`scale`](Self::scale) and the scalar call sites stay BIT-IDENTICAL when
    /// `.scale(x)` is rewritten to `.scale_rows(x)`; on an `f64x4` lane tower
    /// `Self::Value = [f64; 4]` and lane `i` is multiplied by `s[i]`. This is the
    /// primitive that lets a batched body carry CONTINUOUS per-row data — the
    /// survival `covariance_ones` / `z_sum` / observation-weight `wi` factors that
    /// enter the jet algebra as `.scale(per_row_value)` and that the single-`f64`
    /// [`scale`](Self::scale) would broadcast wrongly across the four rows. Build
    /// `s` from the lane→row map with [`pack_rows`](Self::pack_rows).
    fn scale_rows(&self, s: Self::Value) -> Self;

    /// Gather a per-lane auxiliary datum from the lane→row map `rows`: `value_of(r)`
    /// is evaluated for each active lane's row and packed into [`Self::Value`] (a
    /// single `f64` on a scalar jet, `[f64; 4]` on an `f64x4` lane tower). This is
    /// how a body written once over [`RowJet`] feeds per-row CONTINUOUS data (the
    /// arguments to [`scale_rows`](Self::scale_rows)) into the batch path without
    /// knowing the concrete representation: the program holds the per-row data and
    /// the caller threads `rows` (length 1 scalar, length 4 batch) into
    /// [`RowNllProgramRowJet::row_nll`], so the body writes
    /// `x.scale_rows(R::pack_rows(rows, |r| self.cov(r)))`. A multiplicative weight
    /// buried in a `compose_unary_with` stack is pulled out the same way:
    /// `x.compose_unary_with(|u| stack(u, 1.0)).scale_rows(R::pack_rows(rows, |r| self.wi(r)))`.
    /// (Binary per-row branches such as the event indicator `di` are kept
    /// lane-uniform by grouping and the [`guard`](Self::guard) bail, not packed.)
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;

2126    // ── value-derived transcendental conveniences ───────────────────────
2127    // Each routes through `compose_unary_with` with the SAME derivative stack the
2128    // corresponding `JetScalar` method uses, so on a scalar jet (blanket) the
2129    // result is bit-identical to the `JetScalar` method, and on a lane tower lane
2130    // `i` is bit-identical to the scalar result on row `i`.
2131
2132    /// `e^self`.
2133    fn exp(&self) -> Self {
2134        self.compose_unary_with(|u| {
2135            let e = u.exp();
2136            [e, e, e, e, e]
2137        })
2138    }
2139    /// `ln(self)`. Caller guarantees positivity.
2140    fn ln(&self) -> Self {
2141        self.compose_unary_with(|u| {
2142            let r = 1.0 / u;
2143            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
2144        })
2145    }
2146    /// `√self`. Caller guarantees positivity.
2147    fn sqrt(&self) -> Self {
2148        self.compose_unary_with(|u| {
2149            let s = u.sqrt();
2150            [
2151                s,
2152                0.5 / s,
2153                -0.25 / (u * s),
2154                0.375 / (u * u * s),
2155                -0.9375 / (u * u * u * s),
2156            ]
2157        })
2158    }
2159    /// `1/self`.
2160    fn recip(&self) -> Self {
2161        self.compose_unary_with(|u| {
2162            let r = 1.0 / u;
2163            let r2 = r * r;
2164            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
2165        })
2166    }
2167    /// `self^a` for real `a`. Caller guarantees a positive base.
2168    fn powf(&self, a: f64) -> Self {
2169        self.compose_unary_with(move |u| {
2170            [
2171                u.powf(a),
2172                a * u.powf(a - 1.0),
2173                a * (a - 1.0) * u.powf(a - 2.0),
2174                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
2175                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
2176            ]
2177        })
2178    }
2179    /// `ln Γ(self)`. Caller guarantees a positive argument.
2180    fn ln_gamma(&self) -> Self {
2181        self.compose_unary_with(ln_gamma_derivative_stack)
2182    }
2183    /// `ψ(self)` (digamma). Caller guarantees a positive argument.
2184    fn digamma(&self) -> Self {
2185        self.compose_unary_with(digamma_derivative_stack)
2186    }
2187}
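// Illustrative sketch only, not part of this module: how a `[f, f′, f″, f‴, f⁗]`
// stack drives a composition, reduced to ONE primary variable so the Faà di Bruno
// terms fit on a few lines. `Jet1` and its methods are hypothetical stand-ins for
// the real `RowJet` / `Tower4` types; the `exp`, `ln` and `recip` stacks are the
// ones written in the trait above.

```rust
/// Hypothetical one-variable jet: `d[k]` is the k-th derivative in `x`, k = 0..=4.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Jet1 {
    pub d: [f64; 5],
}

impl Jet1 {
    /// Seed the variable `x`: unit first derivative, higher derivatives zero.
    pub fn variable(x: f64) -> Self {
        Jet1 { d: [x, 1.0, 0.0, 0.0, 0.0] }
    }

    /// Univariate Faà di Bruno through fourth order. `stack_fn` returns the
    /// outer function's `[f, f', f'', f''', f'''']` at the inner value.
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
        let [_, u1, u2, u3, u4] = self.d;
        let f = stack_fn(self.d[0]);
        Jet1 {
            d: [
                f[0],
                f[1] * u1,
                f[1] * u2 + f[2] * u1 * u1,
                f[1] * u3 + 3.0 * f[2] * u1 * u2 + f[3] * u1 * u1 * u1,
                f[1] * u4
                    + f[2] * (4.0 * u1 * u3 + 3.0 * u2 * u2)
                    + 6.0 * f[3] * u1 * u1 * u2
                    + f[4] * u1 * u1 * u1 * u1,
            ],
        }
    }

    /// `e^self`: every derivative of `exp` is `exp`.
    pub fn exp(&self) -> Self {
        self.compose_unary_with(|u| {
            let e = u.exp();
            [e, e, e, e, e]
        })
    }

    /// `ln(self)`, assuming a positive inner value.
    pub fn ln(&self) -> Self {
        self.compose_unary_with(|u| {
            let r = 1.0 / u;
            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
        })
    }

    /// `1/self`, assuming a nonzero inner value.
    pub fn recip(&self) -> Self {
        self.compose_unary_with(|u| {
            let r = 1.0 / u;
            let r2 = r * r;
            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
        })
    }
}
```

// Composing `ln` after `exp` on a seeded variable should give back the identity
// jet `[x, 1, 0, 0, 0]` up to rounding, which is a cheap sanity check on any new
// derivative stack before it is used in a row program.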
2188
2189/// Blanket: every scalar [`crate::jet_scalar::JetScalar<K>`] is a [`RowJet<K>`]
2190/// with `Value = f64`. Each op delegates to the identical `JetScalar` method, so
2191/// the existing scalar call sites compile UNCHANGED and bit-identically — the
2192/// bridge adds the lane representation without churning the scalar path. (The
2193/// concrete lane impls below cannot overlap this: [`Tower3Lane`] / [`Tower4Lane`]
2194/// are local types that do not implement `JetScalar`, and the orphan rule forbids
2195/// any downstream impl, so the coherence checker proves the impls disjoint.)
2196impl<const K: usize, S: crate::jet_scalar::JetScalar<K>> RowJet<K> for S {
2197    type Value = f64;
2198    #[inline]
2199    fn constant(c: f64) -> Self {
2200        <S as crate::jet_scalar::JetScalar<K>>::constant(c)
2201    }
2202    #[inline]
2203    fn variable(x: f64, slot: usize) -> Self {
2204        <S as crate::jet_scalar::JetScalar<K>>::variable(x, slot)
2205    }
2206    #[inline]
2207    fn values(&self) -> f64 {
2208        crate::jet_scalar::JetScalar::value(self)
2209    }
2210    #[inline]
2211    fn add(&self, o: &Self) -> Self {
2212        crate::jet_scalar::JetScalar::add(self, o)
2213    }
2214    #[inline]
2215    fn sub(&self, o: &Self) -> Self {
2216        crate::jet_scalar::JetScalar::sub(self, o)
2217    }
2218    #[inline]
2219    fn mul(&self, o: &Self) -> Self {
2220        crate::jet_scalar::JetScalar::mul(self, o)
2221    }
2222    #[inline]
2223    fn scale(&self, s: f64) -> Self {
2224        crate::jet_scalar::JetScalar::scale(self, s)
2225    }
2226    #[inline]
2227    fn neg(&self) -> Self {
2228        crate::jet_scalar::JetScalar::neg(self)
2229    }
2230    #[inline]
2231    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2232        crate::jet_scalar::JetScalar::compose_unary_with(self, |u| {
2233            resize_stack::<N, 5>(stack_fn(u))
2234        })
2235    }
2236    #[inline]
2237    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2238        GuardVerdict::scalar(pred(crate::jet_scalar::JetScalar::value(self)))
2239    }
2240    #[inline]
2241    fn scale_rows(&self, s: f64) -> Self {
2242        // `Value == f64`, so per-lane scale is exactly `scale` — the rewrite
2243        // `.scale(x)` → `.scale_rows(x)` is bit-identical on the scalar path.
2244        crate::jet_scalar::JetScalar::scale(self, s)
2245    }
2246    #[inline]
2247    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
2248        value_of(rows[0])
2249    }
2250}
2251
2252/// The `f64x4` lane [`Tower4Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2253/// routing each op through its existing per-lane method. Lane `i` of a batched
2254/// evaluation is `to_bits`-identical to the scalar [`Tower4`] evaluation on row
2255/// `i` (the per-lane methods are term-for-term lifts of the scalar tower).
2256impl<const K: usize> RowJet<K> for Tower4Lane<wide::f64x4, K> {
2257    type Value = [f64; 4];
2258    #[inline]
2259    fn constant(c: f64) -> Self {
2260        Tower4Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2261    }
2262    #[inline]
2263    fn variable(x: f64, slot: usize) -> Self {
2264        Tower4Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2265    }
2266    #[inline]
2267    fn values(&self) -> [f64; 4] {
2268        self.v.to_array()
2269    }
2270    #[inline]
2271    fn add(&self, o: &Self) -> Self {
2272        Tower4Lane::add(self, o)
2273    }
2274    #[inline]
2275    fn sub(&self, o: &Self) -> Self {
2276        Tower4Lane::sub(self, o)
2277    }
2278    #[inline]
2279    fn mul(&self, o: &Self) -> Self {
2280        Tower4Lane::mul(self, o)
2281    }
2282    #[inline]
2283    fn scale(&self, s: f64) -> Self {
2284        Tower4Lane::scale(self, s)
2285    }
2286    #[inline]
2287    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2288        Tower4Lane::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2289    }
2290    #[inline]
2291    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2292        let vals = self.v.to_array();
2293        let mut mask = 0u8;
2294        for (i, &v) in vals.iter().enumerate() {
2295            if !pred(v) {
2296                mask |= 1 << i;
2297            }
2298        }
2299        GuardVerdict::lanes4(mask)
2300    }
2301    #[inline]
2302    fn scale_rows(&self, s: [f64; 4]) -> Self {
2303        // True per-lane scale: lane `i` of every channel is multiplied by `s[i]`,
2304        // so lane `i` matches the scalar `Tower4::scale(s[i])` on row `i`.
2305        let sl = wide::f64x4::new(s);
2306        let mut out = *self;
2307        out.v = self.v * sl;
2308        for i in 0..K {
2309            out.g[i] = self.g[i] * sl;
2310            for j in 0..K {
2311                out.h[i][j] = self.h[i][j] * sl;
2312                for k in 0..K {
2313                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
2314                    for l in 0..K {
2315                        out.t4[i][j][k][l] = self.t4[i][j][k][l] * sl;
2316                    }
2317                }
2318            }
2319        }
2320        out
2321    }
2322    #[inline]
2323    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2324        [
2325            value_of(rows[0]),
2326            value_of(rows[1]),
2327            value_of(rows[2]),
2328            value_of(rows[3]),
2329        ]
2330    }
2331}
2332
2333/// The `f64x4` lane [`Tower3Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2334/// the order-≤3 sibling of the [`Tower4Lane`] impl. A body that uses `N == 5`
2335/// stacks drops the (unused) fourth-derivative entry here, matching the scalar
2336/// [`Tower3`] which also carries only up to the third tensor.
2337impl<const K: usize> RowJet<K> for Tower3Lane<wide::f64x4, K> {
2338    type Value = [f64; 4];
2339    #[inline]
2340    fn constant(c: f64) -> Self {
2341        Tower3Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2342    }
2343    #[inline]
2344    fn variable(x: f64, slot: usize) -> Self {
2345        Tower3Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2346    }
2347    #[inline]
2348    fn values(&self) -> [f64; 4] {
2349        self.v.to_array()
2350    }
2351    #[inline]
2352    fn add(&self, o: &Self) -> Self {
2353        Tower3Lane::add(self, o)
2354    }
2355    #[inline]
2356    fn sub(&self, o: &Self) -> Self {
2357        Tower3Lane::sub(self, o)
2358    }
2359    #[inline]
2360    fn mul(&self, o: &Self) -> Self {
2361        Tower3Lane::mul(self, o)
2362    }
2363    #[inline]
2364    fn scale(&self, s: f64) -> Self {
2365        Tower3Lane::scale(self, s)
2366    }
2367    #[inline]
2368    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2369        Tower3Lane::compose_unary_with(self, |u| resize_stack::<N, 4>(stack_fn(u)))
2370    }
2371    #[inline]
2372    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2373        let vals = self.v.to_array();
2374        let mut mask = 0u8;
2375        for (i, &v) in vals.iter().enumerate() {
2376            if !pred(v) {
2377                mask |= 1 << i;
2378            }
2379        }
2380        GuardVerdict::lanes4(mask)
2381    }
2382    #[inline]
2383    fn scale_rows(&self, s: [f64; 4]) -> Self {
2384        let sl = wide::f64x4::new(s);
2385        let mut out = *self;
2386        out.v = self.v * sl;
2387        for i in 0..K {
2388            out.g[i] = self.g[i] * sl;
2389            for j in 0..K {
2390                out.h[i][j] = self.h[i][j] * sl;
2391                for k in 0..K {
2392                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
2393                }
2394            }
2395        }
2396        out
2397    }
2398    #[inline]
2399    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2400        [
2401            value_of(rows[0]),
2402            value_of(rows[1]),
2403            value_of(rows[2]),
2404            value_of(rows[3]),
2405        ]
2406    }
2407}
2408
2409/// A family's row negative log-likelihood written ONCE over the [`RowJet`]
2410/// bridge, so the SAME body instantiates at the scalar jets (for the `(v, g, H)`
2411/// and contracted-tensor channels) AND at the `f64x4` lane towers (for the
2412/// 4-rows-per-pass SIMD batch). This is the lane-capable successor to
2413/// [`RowNllProgramGeneric`]: a body written here gets the scalar channels through
2414/// [`rowjet_row_kernel`] / [`rowjet_third_contracted`] / [`rowjet_fourth_contracted`]
2415/// and the batched channels through [`generic_batched_fourth_tower`] /
2416/// [`generic_batched_third_tower`], all from a single source.
2417pub trait RowNllProgramRowJet<const K: usize>: Send + Sync {
2418    /// Number of observations the program covers.
2419    fn n_rows(&self) -> usize;
2420
2421    /// Current primary-scalar values for `row` (where to seed each lane).
2422    fn primaries(&self, row: usize) -> Result<[f64; K], String>;
2423
2424    /// The row NLL evaluated on the [`RowJet`] bridge. `rows` is the lane→row map
2425    /// (length 1 for a scalar instantiation, length 4 for a batch); `p[a]` arrives
2426    /// pre-seeded by the caller (base value plus, for the directional scalars, the
2427    /// nilpotent contraction directions). The body uses ONLY [`RowJet`] ops and
2428    /// per-row data entering through `rows`/`self` as constants.
2429    fn row_nll<R: RowJet<K>>(&self, rows: &[usize], p: &[R; K]) -> Result<R, String>;
2430}
2431
2432/// Evaluate a [`RowNllProgramRowJet`] at the value/gradient/Hessian scalar
2433/// [`crate::jet_scalar::Order2`] (the `(v, g, H)` inner-Newton channel) — the
2434/// `RowJet` twin of [`generic_row_kernel`].
2435pub fn rowjet_row_kernel<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2436    prog: &P,
2437    row: usize,
2438) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
2439    let base = prog.primaries(row)?;
2440    let vars: [crate::jet_scalar::Order2<K>; K] =
2441        std::array::from_fn(|a| <crate::jet_scalar::Order2<K> as RowJet<K>>::variable(base[a], a));
2442    let s = prog.row_nll(&[row], &vars)?;
2443    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
2444}
2445
/// Evaluate a [`RowNllProgramRowJet`] at the one-seed scalar
/// [`crate::jet_scalar::OneSeed`], returning the contracted third-derivative
/// tensor `Σ_c ℓ_{abc} dir_c` as a `K × K` matrix indexed by `(a, b)`. This is
/// the `RowJet` twin of [`generic_third_contracted`].
2449pub fn rowjet_third_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2450    prog: &P,
2451    row: usize,
2452    dir: &[f64; K],
2453) -> Result<[[f64; K]; K], String> {
2454    let base = prog.primaries(row)?;
2455    let vars: [crate::jet_scalar::OneSeed<K>; K] =
2456        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
2457    let s = prog.row_nll(&[row], &vars)?;
2458    Ok(s.contracted_third())
2459}
2460
/// Evaluate a [`RowNllProgramRowJet`] at the two-seed scalar
/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth-derivative
/// tensor `Σ_{cd} ℓ_{abcd} u_c v_d` as a `K × K` matrix indexed by `(a, b)`.
/// This is the `RowJet` twin of [`generic_fourth_contracted`].
2464pub fn rowjet_fourth_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2465    prog: &P,
2466    row: usize,
2467    dir_u: &[f64; K],
2468    dir_v: &[f64; K],
2469) -> Result<[[f64; K]; K], String> {
2470    let base = prog.primaries(row)?;
2471    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
2472        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
2473    let s = prog.row_nll(&[row], &vars)?;
2474    Ok(s.contracted_fourth())
2475}
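// Illustrative sketch only, not part of this module: what the two contracted
// channels mean when written against FULL tensors. `contract_third` and
// `contract_fourth` are hypothetical helper names; the seeded scalars above
// presumably reach the same `K × K` results without materializing `t3`/`t4`,
// and the real code may accumulate the sums in a different order.

```rust
/// `out[a][b] = Σ_c t3[a][b][c] * dir[c]`.
pub fn contract_third<const K: usize>(
    t3: &[[[f64; K]; K]; K],
    dir: &[f64; K],
) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for a in 0..K {
        for b in 0..K {
            for c in 0..K {
                out[a][b] += t3[a][b][c] * dir[c];
            }
        }
    }
    out
}

/// `out[a][b] = Σ_{c,d} t4[a][b][c][d] * u[c] * v[d]`.
pub fn contract_fourth<const K: usize>(
    t4: &[[[[f64; K]; K]; K]; K],
    u: &[f64; K],
    v: &[f64; K],
) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for a in 0..K {
        for b in 0..K {
            for c in 0..K {
                for d in 0..K {
                    out[a][b] += t4[a][b][c][d] * u[c] * v[d];
                }
            }
        }
    }
    out
}
```

// For `ℓ = p0² · p1` the only nonzero third derivatives are the three
// permutations of `ℓ_{001} = 2`, so contracting with `dir = (1, 3)` gives the
// directional derivative of the Hessian `[[2·p1, 2·p0], [2·p0, 0]]`, namely
// `[[6, 2], [2, 0]]`.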
2476
2477/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower4Batch`],
2478/// computing the FULL `(v, g, H, t3, t4)` for FOUR rows in one SIMD pass — the
2479/// lane twin of [`generic_full_tower`]. Each of the four lanes is seeded with its
2480/// own row's primaries, so [`Tower4Batch::lane`]`(i)` is `to_bits`-identical to
2481/// the scalar [`generic_full_tower`] on `rows[i]`.
2482pub fn generic_batched_fourth_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2483    prog: &P,
2484    rows: [usize; 4],
2485) -> Result<Tower4Batch<K>, String> {
2486    let bases: [[f64; K]; 4] = [
2487        prog.primaries(rows[0])?,
2488        prog.primaries(rows[1])?,
2489        prog.primaries(rows[2])?,
2490        prog.primaries(rows[3])?,
2491    ];
2492    let vars: [Tower4Batch<K>; K] = std::array::from_fn(|a| {
2493        let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2494        Tower4Batch::variable(lane_vals, a)
2495    });
2496    prog.row_nll(&rows, &vars)
2497}
2498
2499/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower3Batch`],
2500/// computing `(v, g, H, t3)` for FOUR rows in one SIMD pass — the order-≤3 lane
2501/// twin of [`generic_full_tower`]. [`Tower3Batch::lane`]`(i)` is
2502/// `to_bits`-identical to the order-≤3 scalar evaluation on `rows[i]`.
2503pub fn generic_batched_third_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2504    prog: &P,
2505    rows: [usize; 4],
2506) -> Result<Tower3Batch<K>, String> {
2507    let bases: [[f64; K]; 4] = [
2508        prog.primaries(rows[0])?,
2509        prog.primaries(rows[1])?,
2510        prog.primaries(rows[2])?,
2511        prog.primaries(rows[3])?,
2512    ];
2513    let vars: [Tower3Batch<K>; K] = std::array::from_fn(|a| {
2514        let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2515        Tower3Batch::variable(lane_vals, a)
2516    });
2517    prog.row_nll(&rows, &vars)
2518}
2519
2520// ── The oracle ───────────────────────────────────────────────────────
2521
/// One row's worth of hand-written kernel outputs, as claimed by a
/// `RowKernel` implementation, packaged for verification against the
/// tower truth. Plain data (no trait coupling), so any kernel, whatever
/// its visibility, can be audited from its own test module.
pub struct KernelChannels<const K: usize> {
    /// Claimed NLL value: the first component of `row_kernel`'s `(nll, ∇, H)`.
    pub value: f64,
    /// Claimed gradient `∇` from `row_kernel`.
    pub gradient: [f64; K],
    /// Claimed Hessian `H` from `row_kernel`.
    pub hessian: [[f64; K]; K],
    /// Claimed `row_third_contracted(dir)` outputs as `(dir, claim)` pairs.
    pub third: Vec<([f64; K], [[f64; K]; K])>,
    /// Claimed `row_fourth_contracted(u, v)` outputs as `(u, v, claim)` triples.
    pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
}
2538
/// Channel-by-channel audit of a hand-written kernel against the
/// single-expression tower truth. On the first disagreement it returns `Err`
/// naming the channel, the index, and the claimed and true values. It is
/// designed as the body of the per-family CI oracle tests (#932 deployment
/// step 2).
///
/// Tolerance is PER ENTRY, mixed absolute/relative: each comparison uses
/// `|claim − truth| ≤ atol + rel_tol · max(|claim|, |truth|)`. The absolute
/// floor `atol = rel_tol` lets exact-zero entries of structurally sparse
/// towers pass without demanding bit-equality, while a tiny cross-block
/// entry dropped next to a huge one is still caught: each entry is measured
/// against its own magnitude, NOT against the largest entry of the whole
/// channel (there is no per-channel magnitude floor). Genuine sign flips
/// (#736) and dropped channels are loud.
///
/// Non-finite handling is strict: a NaN on either side always fails; an
/// infinity passes only when both sides are the SAME signed infinity.
2554pub fn verify_kernel_channels<const K: usize>(
2555    tower: &Tower4<K>,
2556    claims: &KernelChannels<K>,
2557    rel_tol: f64,
2558) -> Result<(), String> {
    // Absolute floor: reuse rel_tol so a single knob controls both the
    // relative band and the absolute floor for entries near zero.
    let atol = rel_tol;
    let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
        // The band test below cannot reject non-finite values by itself:
        // `NaN > band` is false, and `inf - inf` is NaN, so either would
        // slip through as `Ok`. Handle them explicitly first: NaN on either
        // side always errs; an infinity passes only if both sides are the
        // identical signed infinity.
        if !claim.is_finite() || !truth.is_finite() {
            let agree = claim.is_infinite()
                && truth.is_infinite()
                && claim.is_sign_positive() == truth.is_sign_positive();
            if agree {
                return Ok(());
            }
            return Err(format!(
                "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
            ));
        }
        let band = atol + rel_tol * claim.abs().max(truth.abs());
        if (claim - truth).abs() > band {
            return Err(format!(
                "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
            ));
        }
        Ok(())
    };
2586
2587    check("value", claims.value, tower.v)?;
2588
2589    for a in 0..K {
2590        check(&format!("gradient[{a}]"), claims.gradient[a], tower.g[a])?;
2591    }
2592
2593    for a in 0..K {
2594        for b in 0..K {
2595            check(
2596                &format!("hessian[{a}][{b}]"),
2597                claims.hessian[a][b],
2598                tower.h[a][b],
2599            )?;
2600        }
2601    }
2602
2603    for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
2604        let truth = tower.third_contracted(dir);
2605        for a in 0..K {
2606            for b in 0..K {
2607                check(
2608                    &format!("third[{t_idx}][{a}][{b}]"),
2609                    claim[a][b],
2610                    truth[a][b],
2611                )?;
2612            }
2613        }
2614    }
2615
2616    for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
2617        let truth = tower.fourth_contracted(u, w);
2618        for a in 0..K {
2619            for b in 0..K {
2620                check(
2621                    &format!("fourth[{f_idx}][{a}][{b}]"),
2622                    claim[a][b],
2623                    truth[a][b],
2624                )?;
2625            }
2626        }
2627    }
2628
2629    Ok(())
2630}
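// Illustrative sketch only, not part of this module: the per-entry acceptance
// rule of `verify_kernel_channels` as a standalone predicate, next to a naive
// band-only test that shows why the explicit non-finite branch is needed.
// `within_band` and `naive_disagrees` are hypothetical names.

```rust
/// True when `claim` agrees with `truth` under the mixed absolute/relative
/// band `atol + rel_tol * max(|claim|, |truth|)` with `atol = rel_tol`.
/// NaN never agrees; an infinity agrees only with the same signed infinity.
pub fn within_band(claim: f64, truth: f64, rel_tol: f64) -> bool {
    if !claim.is_finite() || !truth.is_finite() {
        return claim.is_infinite()
            && truth.is_infinite()
            && claim.is_sign_positive() == truth.is_sign_positive();
    }
    let atol = rel_tol;
    let band = atol + rel_tol * claim.abs().max(truth.abs());
    (claim - truth).abs() <= band
}

/// The band test alone, WITHOUT the non-finite branch. Every comparison
/// against NaN is false, so this reports "no disagreement" for NaN inputs
/// and for opposite-signed infinities (where the band itself is infinite).
pub fn naive_disagrees(claim: f64, truth: f64, rel_tol: f64) -> bool {
    let band = rel_tol + rel_tol * claim.abs().max(truth.abs());
    (claim - truth).abs() > band
}
```

// With `rel_tol = 1e-9`, an exact-zero claim against a truth of `1e-12` passes
// through the absolute floor, a sign flip `1.0` vs `-1.0` fails, and
// `naive_disagrees(f64::NAN, 1.0, 1e-9)` is `false`, which is the silent pass
// the explicit branch exists to prevent.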
2631
2632// ===========================================================================
2633// SIMD row-batched towers (#1151 follow-up): Tower3Lane / Tower4Lane
2634// ===========================================================================
2635//
2636// `Tower{3,4}Lane<L: Lane, K>` re-type every channel of `Tower{3,4}<K>` from a
2637// scalar `f64` to a SIMD lane field `L`. With `L = wide::f64x4` one instance
2638// carries FOUR rows at once, so a per-row kernel (BMS `row_nll`, survival
2639// `row_kernel`, `marginal_slope` `build_row_*_towers`) can evaluate 4 rows per
2640// vector pass instead of one per scalar pass.
2641//
2642// Every floating-point op is a DIRECT, term-for-term lift of the scalar
2643// `Tower{3,4}<K>` body — `a * b` -> `a.mul(b)`, `a + b` -> `a.add(b)`, a literal
2644// `c` -> `L::splat(c)` — in the SAME accumulation order. `wide::f64x4`
2645// add/sub/mul are lane-wise IEEE-754 ops with NO fused-multiply-add (Rust
2646// performs no fp-contraction), so lane `i` of any channel of a
2647// `Tower{3,4}Lane<wide::f64x4, K>` is `to_bits`-IDENTICAL to the scalar
2648// `Tower{3,4}<K>` channel computed on row `i` — exactly the structural
2649// bit-identity the existing [`crate::jet_scalar::Order2Lane`] relies on. Proven
2650// by the in-tree `batch_tests` (real `wide::f64x4`) and a standalone
2651// f64x4-model oracle, `K ∈ {2,3,4,9}`.
2652//
2653// Only the pure-arithmetic ops are lifted (the transcendental `exp`/`ln`/`sqrt`/
2654// `…` route through scalar libm, which has no `f64x4` form; consumers build the
2655// per-lane derivative stack scalar-side and feed it to `compose_unary([L; _])`,
2656// exactly as the scalar path already does).
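// Illustrative sketch only, not part of this module: the bit-identity argument
// above on the smallest channel that has an accumulation, the gradient of a
// product. `Lane4` is a hypothetical plain-array model of a 4-wide lane (it is
// NOT `wide::f64x4`); the point is that the lane body performs the same IEEE-754
// operations in the same order as the scalar body, lane by lane.

```rust
/// Hypothetical 4-lane model: lane-wise add and mul on a plain array.
#[derive(Clone, Copy)]
pub struct Lane4(pub [f64; 4]);

impl Lane4 {
    pub fn splat(c: f64) -> Self {
        Lane4([c; 4])
    }
    pub fn add(self, o: Self) -> Self {
        Lane4(std::array::from_fn(|i| self.0[i] + o.0[i]))
    }
    pub fn mul(self, o: Self) -> Self {
        Lane4(std::array::from_fn(|i| self.0[i] * o.0[i]))
    }
}

/// Scalar gradient channel of `a * b`: `a.v * b.g + a.g * b.v`, accumulated
/// from zero in that order.
pub fn product_gradient_scalar(av: f64, ag: f64, bv: f64, bg: f64) -> f64 {
    let mut acc = 0.0_f64;
    acc = acc + av * bg;
    acc = acc + ag * bv;
    acc
}

/// Term-for-term lift: `*` becomes `mul`, `+` becomes `add`, the literal
/// `0.0` becomes `splat(0.0)`, and the accumulation order is unchanged.
pub fn product_gradient_lanes(av: Lane4, ag: Lane4, bv: Lane4, bg: Lane4) -> Lane4 {
    let mut acc = Lane4::splat(0.0);
    acc = acc.add(av.mul(bg));
    acc = acc.add(ag.mul(bv));
    acc
}
```

// Comparing `to_bits()` of each lane against the scalar result on the matching
// inputs is the shape of check the comment above attributes to `batch_tests`;
// a reordered sum or a fused multiply-add in either body would break it.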
2657
2658use crate::jet_scalar::Lane;
2659
2660/// Lane-batched [`Tower4`]: value / gradient / Hessian / 3rd / 4th tensors
2661/// carried in a SIMD field `L`. `Tower4Lane<f64x4, K>` lane `i` is
2662/// `to_bits`-identical to [`Tower4<K>`] on row `i`.
2663#[derive(Clone, Copy)]
2664pub struct Tower4Lane<L: Lane, const K: usize> {
2665    /// Value channel (one entry per lane/row).
2666    pub v: L,
2667    /// Gradient `∂/∂p_a`.
2668    pub g: [L; K],
2669    /// Hessian `∂²/∂p_a∂p_b`.
2670    pub h: [[L; K]; K],
2671    /// Third tensor `∂³`.
2672    pub t3: [[[L; K]; K]; K],
2673    /// Fourth tensor `∂⁴`.
2674    pub t4: [[[[L; K]; K]; K]; K],
2675}
2676
2677/// The 4-rows-per-pass batched [`Tower4`] (`wide::f64x4` lanes).
2678pub type Tower4Batch<const K: usize> = Tower4Lane<wide::f64x4, K>;
2679
2680impl<L: Lane, const K: usize> Tower4Lane<L, K> {
2681    /// All-zero tower (every channel `+0.0` in every lane).
2682    #[inline]
2683    pub fn zero() -> Self {
2684        let z = L::splat(0.0);
2685        Self {
2686            v: z,
2687            g: [z; K],
2688            h: [[z; K]; K],
2689            t3: [[[z; K]; K]; K],
2690            t4: [[[[z; K]; K]; K]; K],
2691        }
2692    }
2693    /// Constant `c` (per lane): value channel only.
2694    #[inline]
2695    pub fn constant(c: L) -> Self {
2696        let mut o = Self::zero();
2697        o.v = c;
2698        o
2699    }
2700    /// Seeded variable `p_idx` at per-lane `value`: unit first derivative in
2701    /// slot `idx` (mirrors [`Tower4::variable`]).
2702    #[inline]
2703    pub fn variable(value: L, idx: usize) -> Self {
2704        let mut o = Self::constant(value);
2705        o.g[idx] = L::splat(1.0);
2706        o
2707    }
2708    /// Extract lane `i` as a scalar [`Tower4<K>`] (channel-for-channel).
2709    #[inline]
2710    pub fn lane(&self, i: usize) -> Tower4<K> {
2711        let mut out = Tower4::<K>::zero();
2712        out.v = self.v.lane(i);
2713        for a in 0..K {
2714            out.g[a] = self.g[a].lane(i);
2715            for b in 0..K {
2716                out.h[a][b] = self.h[a][b].lane(i);
2717                for c in 0..K {
2718                    out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2719                    for d in 0..K {
2720                        out.t4[a][b][c][d] = self.t4[a][b][c][d].lane(i);
2721                    }
2722                }
2723            }
2724        }
2725        out
2726    }
2727    /// Per-channel lane-wise `self + o` (mirrors `Tower4` `Add`).
2728    #[inline]
2729    pub fn add(&self, o: &Self) -> Self {
2730        let mut out = *self;
2731        out.v = self.v.add(o.v);
2732        for i in 0..K {
2733            out.g[i] = self.g[i].add(o.g[i]);
2734            for j in 0..K {
2735                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
2736                for k in 0..K {
2737                    out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
2738                    for l in 0..K {
2739                        out.t4[i][j][k][l] = self.t4[i][j][k][l].add(o.t4[i][j][k][l]);
2740                    }
2741                }
2742            }
2743        }
2744        out
2745    }
2746    /// Per-channel lane-wise `self - o` (mirrors `Tower4` `Sub`).
2747    #[inline]
2748    pub fn sub(&self, o: &Self) -> Self {
2749        let mut out = *self;
2750        out.v = self.v.sub(o.v);
2751        for i in 0..K {
2752            out.g[i] = self.g[i].sub(o.g[i]);
2753            for j in 0..K {
2754                out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
2755                for k in 0..K {
2756                    out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
2757                    for l in 0..K {
2758                        out.t4[i][j][k][l] = self.t4[i][j][k][l].sub(o.t4[i][j][k][l]);
2759                    }
2760                }
2761            }
2762        }
2763        out
2764    }
2765    /// Multiply every channel by the plain scalar `s` (mirrors `Tower4::scale`).
2766    #[inline]
2767    pub fn scale(&self, s: f64) -> Self {
2768        let sl = L::splat(s);
2769        let mut out = *self;
2770        out.v = self.v.mul(sl);
2771        for i in 0..K {
2772            out.g[i] = self.g[i].mul(sl);
2773            for j in 0..K {
2774                out.h[i][j] = self.h[i][j].mul(sl);
2775                for k in 0..K {
2776                    out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
2777                    for l in 0..K {
2778                        out.t4[i][j][k][l] = self.t4[i][j][k][l].mul(sl);
2779                    }
2780                }
2781            }
2782        }
2783        out
2784    }
2785    /// Leibniz product `self · o`, term-for-term lift of [`Tower4::mul`].
2786    #[inline]
2787    pub fn mul(&self, o: &Self) -> Self {
2788        let a = self;
2789        let b = o;
2790        let mut out = Self::zero();
2791        out.v = a.v.mul(b.v);
2792        for i in 0..K {
2793            let mut acc = L::splat(0.0);
2794            acc = acc.add(a.v.mul(b.g[i]));
2795            acc = acc.add(a.g[i].mul(b.v));
2796            out.g[i] = acc;
2797        }
2798        // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
2799        for i in 0..K {
2800            for j in i..K {
2801                let mut acc = L::splat(0.0);
2802                acc = acc.add(a.v.mul(b.h[i][j]));
2803                acc = acc.add(a.g[i].mul(b.g[j]));
2804                acc = acc.add(a.g[j].mul(b.g[i]));
2805                acc = acc.add(a.h[i][j].mul(b.v));
2806                out.h[i][j] = acc;
2807                out.h[j][i] = acc;
2808            }
2809        }
2810        for i in 0..K {
2811            for j in 0..K {
2812                for k in 0..K {
2813                    let mut acc = L::splat(0.0);
2814                    acc = acc.add(a.v.mul(b.t3[i][j][k]));
2815                    acc = acc.add(a.g[i].mul(b.h[j][k]));
2816                    acc = acc.add(a.g[j].mul(b.h[i][k]));
2817                    acc = acc.add(a.h[i][j].mul(b.g[k]));
2818                    acc = acc.add(a.g[k].mul(b.h[i][j]));
2819                    acc = acc.add(a.h[i][k].mul(b.g[j]));
2820                    acc = acc.add(a.h[j][k].mul(b.g[i]));
2821                    acc = acc.add(a.t3[i][j][k].mul(b.v));
2822                    out.t3[i][j][k] = acc;
2823                }
2824            }
2825        }
2826        for i in 0..K {
2827            for j in 0..K {
2828                for k in 0..K {
2829                    for l in 0..K {
2830                        let mut acc = L::splat(0.0);
2831                        acc = acc.add(a.v.mul(b.t4[i][j][k][l]));
2832                        acc = acc.add(a.g[i].mul(b.t3[j][k][l]));
2833                        acc = acc.add(a.g[j].mul(b.t3[i][k][l]));
2834                        acc = acc.add(a.h[i][j].mul(b.h[k][l]));
2835                        acc = acc.add(a.g[k].mul(b.t3[i][j][l]));
2836                        acc = acc.add(a.h[i][k].mul(b.h[j][l]));
2837                        acc = acc.add(a.h[j][k].mul(b.h[i][l]));
2838                        acc = acc.add(a.t3[i][j][k].mul(b.g[l]));
2839                        acc = acc.add(a.g[l].mul(b.t3[i][j][k]));
2840                        acc = acc.add(a.h[i][l].mul(b.h[j][k]));
2841                        acc = acc.add(a.h[j][l].mul(b.h[i][k]));
2842                        acc = acc.add(a.t3[i][j][l].mul(b.g[k]));
2843                        acc = acc.add(a.h[k][l].mul(b.h[i][j]));
2844                        acc = acc.add(a.t3[i][k][l].mul(b.g[j]));
2845                        acc = acc.add(a.t3[j][k][l].mul(b.g[i]));
2846                        acc = acc.add(a.t4[i][j][k][l].mul(b.v));
2847                        out.t4[i][j][k][l] = acc;
2848                    }
2849                }
2850            }
2851        }
2852        out
2853    }
2854    /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
2855    /// [`Tower4::compose_unary`]. `d = [f, f′, f″, f‴, f⁗]` packed per lane
2856    /// (build via [`Lane::unary5`] from the scalar special-function stack).
2857    #[inline]
2858    pub fn compose_unary(&self, d: [L; 5]) -> Self {
2859        let mut out = Self::zero();
2860        out.v = d[0];
2861        for i in 0..K {
2862            let mut acc = L::splat(0.0);
2863            acc = acc.add(d[1].mul(self.g[i]));
2864            out.g[i] = acc;
2865        }
2866        for i in 0..K {
2867            for j in 0..K {
2868                let mut acc = L::splat(0.0);
2869                acc = acc.add(d[1].mul(self.h[i][j]));
2870                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
2871                out.h[i][j] = acc;
2872            }
2873        }
2874        for i in 0..K {
2875            for j in 0..K {
2876                for k in 0..K {
2877                    let mut acc = L::splat(0.0);
2878                    acc = acc.add(d[1].mul(self.t3[i][j][k]));
2879                    acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
2880                    acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
2881                    acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
2882                    acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
2883                    out.t3[i][j][k] = acc;
2884                }
2885            }
2886        }
2887        for i in 0..K {
2888            for j in 0..K {
2889                for k in 0..K {
2890                    for l in 0..K {
2891                        let mut acc = L::splat(0.0);
2892                        acc = acc.add(d[1].mul(self.t4[i][j][k][l]));
2893                        acc = acc.add(d[2].mul(self.t3[i][j][k]).mul(self.g[l]));
2894                        acc = acc.add(d[2].mul(self.t3[i][j][l]).mul(self.g[k]));
2895                        acc = acc.add(d[2].mul(self.h[i][j]).mul(self.h[k][l]));
2896                        acc = acc.add(d[3].mul(self.h[i][j]).mul(self.g[k]).mul(self.g[l]));
2897                        acc = acc.add(d[2].mul(self.t3[i][k][l]).mul(self.g[j]));
2898                        acc = acc.add(d[2].mul(self.h[i][k]).mul(self.h[j][l]));
2899                        acc = acc.add(d[3].mul(self.h[i][k]).mul(self.g[j]).mul(self.g[l]));
2900                        acc = acc.add(d[2].mul(self.h[i][l]).mul(self.h[j][k]));
2901                        acc = acc.add(d[2].mul(self.g[i]).mul(self.t3[j][k][l]));
2902                        acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][k]).mul(self.g[l]));
2903                        acc = acc.add(d[3].mul(self.h[i][l]).mul(self.g[j]).mul(self.g[k]));
2904                        acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][l]).mul(self.g[k]));
2905                        acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.h[k][l]));
2906                        acc = acc.add(
2907                            d[4].mul(self.g[i])
2908                                .mul(self.g[j])
2909                                .mul(self.g[k])
2910                                .mul(self.g[l]),
2911                        );
2912                        out.t4[i][j][k][l] = acc;
2913                    }
2914                }
2915            }
2916        }
2917        out
2918    }
2919    /// Compose with a unary special-function whose `[f64; 5]` derivative stack is
2920    /// built from the base value through `stack_fn`, evaluated PER LANE — the
2921    /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
2922    /// seam (the SIMD twin of [`Tower4::compose_unary_with`]).
2923    ///
2924    /// Each of the four lanes carries a DISTINCT base value, so the scalar
2925    /// `stack_fn` is run once per lane at that lane's own value (via
2926    /// [`Lane::unary_with`]) and the `[f64; 5]` results are packed into `[L; 5]`;
2927    /// the composition is then the existing per-lane [`Self::compose_unary`].
2928    /// Because `unary_with` runs the identical scalar closure per lane and
2929    /// `compose_unary` is a term-for-term lift of the scalar tower, lane `i` of
2930    /// the result is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`
2931    /// — which is exactly what lets a row program written against the scalar
2932    /// [`Tower4::compose_unary_with`] seam re-instantiate, unchanged, at `f64x4`.
2933    #[inline]
2934    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
2935        self.compose_unary(self.v.unary_with(stack_fn))
2936    }
2937
2938    /// Single-active-slot fast path, term-for-term lift of
2939    /// [`Tower4::compose_unary_single_slot`] (only the 5 diagonal channels).
2940    #[inline]
2941    pub fn compose_unary_single_slot(&self, d: [L; 5], slot: usize) -> Self {
2942        let mut out = Self::zero();
2943        let s = slot;
2944        let g = self.g[s];
2945        let h = self.h[s][s];
2946        let t3 = self.t3[s][s][s];
2947        let t4 = self.t4[s][s][s][s];
2948        out.v = d[0];
2949        out.g[s] = {
2950            let mut acc = L::splat(0.0);
2951            acc = acc.add(d[1].mul(g));
2952            acc
2953        };
2954        out.h[s][s] = {
2955            let mut acc = L::splat(0.0);
2956            acc = acc.add(d[1].mul(h));
2957            acc = acc.add(d[2].mul(g).mul(g));
2958            acc
2959        };
2960        out.t3[s][s][s] = {
2961            let mut acc = L::splat(0.0);
2962            acc = acc.add(d[1].mul(t3));
2963            acc = acc.add(d[2].mul(h).mul(g));
2964            acc = acc.add(d[2].mul(h).mul(g));
2965            acc = acc.add(d[2].mul(g).mul(h));
2966            acc = acc.add(d[3].mul(g).mul(g).mul(g));
2967            acc
2968        };
2969        out.t4[s][s][s][s] = {
2970            let mut acc = L::splat(0.0);
2971            acc = acc.add(d[1].mul(t4));
2972            acc = acc.add(d[2].mul(t3).mul(g));
2973            acc = acc.add(d[2].mul(t3).mul(g));
2974            acc = acc.add(d[2].mul(h).mul(h));
2975            acc = acc.add(d[3].mul(h).mul(g).mul(g));
2976            acc = acc.add(d[2].mul(t3).mul(g));
2977            acc = acc.add(d[2].mul(h).mul(h));
2978            acc = acc.add(d[3].mul(h).mul(g).mul(g));
2979            acc = acc.add(d[2].mul(h).mul(h));
2980            acc = acc.add(d[2].mul(g).mul(t3));
2981            acc = acc.add(d[3].mul(g).mul(h).mul(g));
2982            acc = acc.add(d[3].mul(h).mul(g).mul(g));
2983            acc = acc.add(d[3].mul(g).mul(h).mul(g));
2984            acc = acc.add(d[3].mul(g).mul(g).mul(h));
2985            acc = acc.add(d[4].mul(g).mul(g).mul(g).mul(g));
2986            acc
2987        };
2988        out
2989    }
2990    /// Contract `t3` with a primary-space direction (lift of
2991    /// [`Tower4::third_contracted`]). Output-symmetric in `(a, b)`: compute the
2992    /// upper triangle and mirror — bit-identical to the full nest, lane-for-lane.
2993    #[inline]
2994    pub fn third_contracted(&self, dir: &[L; K]) -> [[L; K]; K] {
2995        let mut out = [[L::splat(0.0); K]; K];
2996        for a in 0..K {
2997            for b in a..K {
2998                let mut acc = L::splat(0.0);
2999                for c in 0..K {
3000                    acc = acc.add(self.t3[a][b][c].mul(dir[c]));
3001                }
3002                out[a][b] = acc;
3003                out[b][a] = acc;
3004            }
3005        }
3006        out
3007    }
3008    /// Contract `t4` with two primary-space directions (lift of
3009    /// [`Tower4::fourth_contracted`]). Output-symmetric in `(i, j)`: compute the
3010    /// upper triangle and mirror — bit-identical to the full nest, lane-for-lane.
3011    #[inline]
3012    pub fn fourth_contracted(&self, u: &[L; K], w: &[L; K]) -> [[L; K]; K] {
3013        let mut out = [[L::splat(0.0); K]; K];
3014        for i in 0..K {
3015            for j in i..K {
3016                let mut acc = L::splat(0.0);
3017                for k in 0..K {
3018                    for l in 0..K {
3019                        acc = acc.add(self.t4[i][j][k][l].mul(u[k]).mul(w[l]));
3020                    }
3021                }
3022                out[i][j] = acc;
3023                out[j][i] = acc;
3024            }
3025        }
3026        out
3027    }
3028}
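/*
Illustrative sketch only, not part of this module. With a single active slot,
the Faà di Bruno sums in `compose_unary_single_slot` collapse to the univariate
formula, where repeated terms appear with multiplicities 3 (third order) and
4, 3, 6 (fourth order). The names `Jet4`, `compose_unary_1d`,
`exp_x2_reference` and `exp_x2_by_composition` are hypothetical. The sketch
groups repeated terms under one coefficient, so unlike the term-by-term
accumulation above it makes no bit-identity claim.

```rust
// Univariate jet through fourth order: [u, u', u'', u''', u''''].
type Jet4 = [f64; 5];

// Compose f(u(x)) given d = [f, f', f'', f''', f''''] evaluated at u[0].
fn compose_unary_1d(u: Jet4, d: [f64; 5]) -> Jet4 {
    let (g, h, t3, t4) = (u[1], u[2], u[3], u[4]);
    [
        d[0],
        d[1] * g,
        d[1] * h + d[2] * g * g,
        d[1] * t3 + 3.0 * d[2] * h * g + d[3] * g * g * g,
        d[1] * t4
            + 4.0 * d[2] * t3 * g
            + 3.0 * d[2] * h * h
            + 6.0 * d[3] * h * g * g
            + d[4] * g * g * g * g,
    ]
}

// Closed-form derivatives of exp(x^2), used only as a reference.
fn exp_x2_reference(x: f64) -> Jet4 {
    let x2 = x * x;
    let e = x2.exp();
    [
        e,
        2.0 * x * e,
        (2.0 + 4.0 * x2) * e,
        (12.0 * x + 8.0 * x * x2) * e,
        (12.0 + 48.0 * x2 + 16.0 * x2 * x2) * e,
    ]
}

// Jet of exp(x^2) at x by composition: inner u = x^2, outer f = exp.
fn exp_x2_by_composition(x: f64) -> Jet4 {
    let u: Jet4 = [x * x, 2.0 * x, 2.0, 0.0, 0.0];
    // Every derivative of exp equals exp, so the stack is constant.
    let e = u[0].exp();
    compose_unary_1d(u, [e; 5])
}
```

Calling `exp_x2_by_composition(0.7)` should agree with `exp_x2_reference(0.7)`
on all five channels up to floating-point rounding.
*/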
3029
3030/// Lane-batched [`Tower3`] (order-≤3 sibling of [`Tower4Lane`]).
3031#[derive(Clone, Copy)]
3032pub struct Tower3Lane<L: Lane, const K: usize> {
3033    /// Value channel.
3034    pub v: L,
3035    /// Gradient.
3036    pub g: [L; K],
3037    /// Hessian.
3038    pub h: [[L; K]; K],
3039    /// Third tensor.
3040    pub t3: [[[L; K]; K]; K],
3041}
3042
3043/// The 4-rows-per-pass batched [`Tower3`] (`wide::f64x4` lanes).
3044pub type Tower3Batch<const K: usize> = Tower3Lane<wide::f64x4, K>;
3045
3046impl<L: Lane, const K: usize> Tower3Lane<L, K> {
3047    /// All-zero tower.
3048    #[inline]
3049    pub fn zero() -> Self {
3050        let z = L::splat(0.0);
3051        Self {
3052            v: z,
3053            g: [z; K],
3054            h: [[z; K]; K],
3055            t3: [[[z; K]; K]; K],
3056        }
3057    }
3058    /// Constant `c` (per lane).
3059    #[inline]
3060    pub fn constant(c: L) -> Self {
3061        let mut o = Self::zero();
3062        o.v = c;
3063        o
3064    }
3065    /// Seeded variable `p_idx` at per-lane `value`.
3066    #[inline]
3067    pub fn variable(value: L, idx: usize) -> Self {
3068        let mut o = Self::constant(value);
3069        o.g[idx] = L::splat(1.0);
3070        o
3071    }
3072    /// Extract lane `i` as a scalar [`Tower3<K>`].
3073    #[inline]
3074    pub fn lane(&self, i: usize) -> Tower3<K> {
3075        let mut out = Tower3::<K>::zero();
3076        out.v = self.v.lane(i);
3077        for a in 0..K {
3078            out.g[a] = self.g[a].lane(i);
3079            for b in 0..K {
3080                out.h[a][b] = self.h[a][b].lane(i);
3081                for c in 0..K {
3082                    out.t3[a][b][c] = self.t3[a][b][c].lane(i);
3083                }
3084            }
3085        }
3086        out
3087    }
3088    /// Per-channel lane-wise `self + o`.
3089    #[inline]
3090    pub fn add(&self, o: &Self) -> Self {
3091        let mut out = *self;
3092        out.v = self.v.add(o.v);
3093        for i in 0..K {
3094            out.g[i] = self.g[i].add(o.g[i]);
3095            for j in 0..K {
3096                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
3097                for k in 0..K {
3098                    out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
3099                }
3100            }
3101        }
3102        out
3103    }
3104    /// Per-channel lane-wise `self - o`.
3105    #[inline]
3106    pub fn sub(&self, o: &Self) -> Self {
3107        let mut out = *self;
3108        out.v = self.v.sub(o.v);
3109        for i in 0..K {
3110            out.g[i] = self.g[i].sub(o.g[i]);
3111            for j in 0..K {
3112                out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
3113                for k in 0..K {
3114                    out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
3115                }
3116            }
3117        }
3118        out
3119    }
3120    /// Multiply every channel by the plain scalar `s` (mirrors `Tower3::scale`).
3121    #[inline]
3122    pub fn scale(&self, s: f64) -> Self {
3123        let sl = L::splat(s);
3124        let mut out = *self;
3125        out.v = self.v.mul(sl);
3126        for i in 0..K {
3127            out.g[i] = self.g[i].mul(sl);
3128            for j in 0..K {
3129                out.h[i][j] = self.h[i][j].mul(sl);
3130                for k in 0..K {
3131                    out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
3132                }
3133            }
3134        }
3135        out
3136    }
3137    /// Leibniz product `self · o`, term-for-term lift of [`Tower3::mul`].
3138    #[inline]
3139    pub fn mul(&self, o: &Self) -> Self {
3140        let a = self;
3141        let b = o;
3142        let mut out = Self::zero();
3143        out.v = a.v.mul(b.v);
3144        for i in 0..K {
3145            let mut acc = L::splat(0.0);
3146            acc = acc.add(a.v.mul(b.g[i]));
3147            acc = acc.add(a.g[i].mul(b.v));
3148            out.g[i] = acc;
3149        }
3150        // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
3151        for i in 0..K {
3152            for j in i..K {
3153                let mut acc = L::splat(0.0);
3154                acc = acc.add(a.v.mul(b.h[i][j]));
3155                acc = acc.add(a.g[i].mul(b.g[j]));
3156                acc = acc.add(a.g[j].mul(b.g[i]));
3157                acc = acc.add(a.h[i][j].mul(b.v));
3158                out.h[i][j] = acc;
3159                out.h[j][i] = acc;
3160            }
3161        }
3162        for i in 0..K {
3163            for j in 0..K {
3164                for k in 0..K {
3165                    let mut acc = L::splat(0.0);
3166                    acc = acc.add(a.v.mul(b.t3[i][j][k]));
3167                    acc = acc.add(a.g[i].mul(b.h[j][k]));
3168                    acc = acc.add(a.g[j].mul(b.h[i][k]));
3169                    acc = acc.add(a.h[i][j].mul(b.g[k]));
3170                    acc = acc.add(a.g[k].mul(b.h[i][j]));
3171                    acc = acc.add(a.h[i][k].mul(b.g[j]));
3172                    acc = acc.add(a.h[j][k].mul(b.g[i]));
3173                    acc = acc.add(a.t3[i][j][k].mul(b.v));
3174                    out.t3[i][j][k] = acc;
3175                }
3176            }
3177        }
3178        out
3179    }
3180    /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
3181    /// [`Tower3::compose_unary`]. `d = [f, f′, f″, f‴]` packed per lane.
3182    #[inline]
3183    pub fn compose_unary(&self, d: [L; 4]) -> Self {
3184        let mut out = Self::zero();
3185        out.v = d[0];
3186        for i in 0..K {
3187            let mut acc = L::splat(0.0);
3188            acc = acc.add(d[1].mul(self.g[i]));
3189            out.g[i] = acc;
3190        }
3191        for i in 0..K {
3192            for j in 0..K {
3193                let mut acc = L::splat(0.0);
3194                acc = acc.add(d[1].mul(self.h[i][j]));
3195                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
3196                out.h[i][j] = acc;
3197            }
3198        }
3199        for i in 0..K {
3200            for j in 0..K {
3201                for k in 0..K {
3202                    let mut acc = L::splat(0.0);
3203                    acc = acc.add(d[1].mul(self.t3[i][j][k]));
3204                    acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
3205                    acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
3206                    acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
3207                    acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
3208                    out.t3[i][j][k] = acc;
3209                }
3210            }
3211        }
3212        out
3213    }
3214    /// Compose with a unary special-function whose `[f64; 4]` derivative stack is
3215    /// built from the base value through `stack_fn`, evaluated PER LANE — the
3216    /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
3217    /// seam (the SIMD twin of [`Tower3::compose_unary_with`], order-≤3 sibling of
3218    /// [`Tower4Lane::compose_unary_with`]). The scalar `stack_fn` is run once per
3219    /// lane at that lane's own base value (via [`Lane::unary_with`]) and packed
3220    /// into `[L; 4]` for the existing per-lane [`Self::compose_unary`], so lane
3221    /// `i` is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`.
3222    #[inline]
3223    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
3224        self.compose_unary(self.v.unary_with(stack_fn))
3225    }
3226
3227    /// Single-active-slot fast path, term-for-term lift of
3228    /// [`Tower3::compose_unary_single_slot`].
3229    #[inline]
3230    pub fn compose_unary_single_slot(&self, d: [L; 4], slot: usize) -> Self {
3231        let mut out = Self::zero();
3232        let s = slot;
3233        let g = self.g[s];
3234        let h = self.h[s][s];
3235        let t3 = self.t3[s][s][s];
3236        out.v = d[0];
3237        out.g[s] = {
3238            let mut acc = L::splat(0.0);
3239            acc = acc.add(d[1].mul(g));
3240            acc
3241        };
3242        out.h[s][s] = {
3243            let mut acc = L::splat(0.0);
3244            acc = acc.add(d[1].mul(h));
3245            acc = acc.add(d[2].mul(g).mul(g));
3246            acc
3247        };
3248        out.t3[s][s][s] = {
3249            let mut acc = L::splat(0.0);
3250            acc = acc.add(d[1].mul(t3));
3251            acc = acc.add(d[2].mul(h).mul(g));
3252            acc = acc.add(d[2].mul(h).mul(g));
3253            acc = acc.add(d[2].mul(g).mul(h));
3254            acc = acc.add(d[3].mul(g).mul(g).mul(g));
3255            acc
3256        };
3257        out
3258    }
3259}
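/*
Illustrative sketch only, not part of this module. On the diagonal
(`i == j == k`) the Leibniz sums in `Tower3Lane::mul` reduce to the univariate
product rule with binomial coefficients 1, 2, 1 and 1, 3, 3, 1. The names
`Jet3`, `mul_1d`, `square_jet` and `cube_jet` are hypothetical, and the grouped
coefficients mean this sketch is not a bit-for-bit model of the accumulation
order used above.

```rust
// Univariate jet through third order: [a, a', a'', a'''].
type Jet3 = [f64; 4];

// Leibniz product of two univariate jets through third order.
fn mul_1d(a: Jet3, b: Jet3) -> Jet3 {
    [
        a[0] * b[0],
        a[0] * b[1] + a[1] * b[0],
        a[0] * b[2] + 2.0 * a[1] * b[1] + a[2] * b[0],
        a[0] * b[3] + 3.0 * a[1] * b[2] + 3.0 * a[2] * b[1] + a[3] * b[0],
    ]
}

// Jet of x^2 at x.
fn square_jet(x: f64) -> Jet3 {
    [x * x, 2.0 * x, 2.0, 0.0]
}

// Jet of x^3 at x.
fn cube_jet(x: f64) -> Jet3 {
    [x * x * x, 3.0 * x * x, 6.0 * x, 6.0]
}
```

Multiplying `square_jet(2.0)` by `cube_jet(2.0)` gives the jet of `x^5` at
`x = 2`, namely `[32, 80, 160, 240]`; all intermediate values are small
integers, so the comparison is exact in `f64`.
*/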
3260
3261#[cfg(test)]
3262mod batch_tests {
3263    use super::*;
3264
3265    struct Rng(u64);
3266    impl Rng {
3267        fn f(&mut self) -> f64 {
3268            self.0 = self
3269                .0
3270                .wrapping_mul(6364136223846793005)
3271                .wrapping_add(1442695040888963407);
3272            ((self.0 >> 11) as f64 / (1u64 << 53) as f64) * 4.0 - 2.0
3273        }
3274    }
3275
3276    // Fill every channel of a scalar Tower4<K> with random data.
3277    fn rand_t4<const K: usize>(r: &mut Rng) -> Tower4<K> {
3278        let mut t = Tower4::<K>::zero();
3279        t.v = r.f();
3280        for i in 0..K {
3281            t.g[i] = r.f();
3282            for j in 0..K {
3283                t.h[i][j] = r.f();
3284                for k in 0..K {
3285                    t.t3[i][j][k] = r.f();
3286                    for l in 0..K {
3287                        t.t4[i][j][k][l] = r.f();
3288                    }
3289                }
3290            }
3291        }
3292        t
3293    }
3294    fn rand_t3<const K: usize>(r: &mut Rng) -> Tower3<K> {
3295        let mut t = Tower3::<K>::zero();
3296        t.v = r.f();
3297        for i in 0..K {
3298            t.g[i] = r.f();
3299            for j in 0..K {
3300                t.h[i][j] = r.f();
3301                for k in 0..K {
3302                    t.t3[i][j][k] = r.f();
3303                }
3304            }
3305        }
3306        t
3307    }
3308    fn pack4_t4<const K: usize>(rows: &[Tower4<K>; 4]) -> Tower4Batch<K> {
3309        let mut b = Tower4Batch::<K>::zero();
3310        let lane = |f: &dyn Fn(&Tower4<K>) -> f64| {
3311            wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3312        };
3313        b.v = lane(&|t| t.v);
3314        for i in 0..K {
3315            b.g[i] = lane(&|t| t.g[i]);
3316            for j in 0..K {
3317                b.h[i][j] = lane(&|t| t.h[i][j]);
3318                for k in 0..K {
3319                    b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3320                    for l in 0..K {
3321                        b.t4[i][j][k][l] = lane(&|t| t.t4[i][j][k][l]);
3322                    }
3323                }
3324            }
3325        }
3326        b
3327    }
3328    fn pack4_t3<const K: usize>(rows: &[Tower3<K>; 4]) -> Tower3Batch<K> {
3329        let mut b = Tower3Batch::<K>::zero();
3330        let lane = |f: &dyn Fn(&Tower3<K>) -> f64| {
3331            wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3332        };
3333        b.v = lane(&|t| t.v);
3334        for i in 0..K {
3335            b.g[i] = lane(&|t| t.g[i]);
3336            for j in 0..K {
3337                b.h[i][j] = lane(&|t| t.h[i][j]);
3338                for k in 0..K {
3339                    b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3340                }
3341            }
3342        }
3343        b
3344    }
3345    fn assert_t4_eq<const K: usize>(b: &Tower4<K>, s: &Tower4<K>, ctx: &str) {
3346        assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3347        for i in 0..K {
3348            assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3349            for j in 0..K {
3350                assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3351                for k in 0..K {
3352                    assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3353                    for l in 0..K {
3354                        assert_eq!(
3355                            b.t4[i][j][k][l].to_bits(),
3356                            s.t4[i][j][k][l].to_bits(),
3357                            "t4 {ctx}"
3358                        );
3359                    }
3360                }
3361            }
3362        }
3363    }
3364    fn assert_t3_eq<const K: usize>(b: &Tower3<K>, s: &Tower3<K>, ctx: &str) {
3365        assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3366        for i in 0..K {
3367            assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3368            for j in 0..K {
3369                assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3370                for k in 0..K {
3371                    assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3372                }
3373            }
3374        }
3375    }
3376
3377    // Run a representative op chain on 4 scalar rows and on the f64x4 batch,
3378    // then assert every channel of every lane is to_bits-identical.
3379    fn run4<const K: usize>(seed: u64, batches: usize) -> usize {
3380        let mut r = Rng(seed);
3381        let mut rows_checked = 0;
3382        for _ in 0..batches {
3383            let a: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3384            let b: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3385            let d: [[f64; 5]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3386            let dir: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3387            let dir2: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3388            let s = r.f();
3389
3390            // scalar per-row reference
3391            let scal: [Tower4<K>; 4] = std::array::from_fn(|rw| {
3392                let prod = a[rw].mul(&b[rw]);
3393                let comp = prod.compose_unary(d[rw]);
3394                let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3395                summed.compose_unary_single_slot(d[rw], 0)
3396            });
3397            let third: [[[f64; K]; K]; 4] =
3398                std::array::from_fn(|rw| a[rw].third_contracted(&dir[rw]));
3399            let fourth: [[[f64; K]; K]; 4] =
3400                std::array::from_fn(|rw| a[rw].fourth_contracted(&dir[rw], &dir2[rw]));
3401
3402            // batched f64x4
3403            let ab = pack4_t4(&a);
3404            let bb = pack4_t4(&b);
3405            let db: [wide::f64x4; 5] =
3406                std::array::from_fn(|c| wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]]));
3407            let dirb: [wide::f64x4; K] = std::array::from_fn(|c| {
3408                wide::f64x4::new([dir[0][c], dir[1][c], dir[2][c], dir[3][c]])
3409            });
3410            let dir2b: [wide::f64x4; K] = std::array::from_fn(|c| {
3411                wide::f64x4::new([dir2[0][c], dir2[1][c], dir2[2][c], dir2[3][c]])
3412            });
3413            let prodb = ab.mul(&bb);
3414            let compb = prodb.compose_unary(db);
3415            let summedb = compb.add(&ab).sub(&bb).scale(s);
3416            let finalb = summedb.compose_unary_single_slot(db, 0);
3417            let thirdb = ab.third_contracted(&dirb);
3418            let fourthb = ab.fourth_contracted(&dirb, &dir2b);
3419
3420            for rw in 0..4 {
3421                assert_t4_eq(&finalb.lane(rw), &scal[rw], "t4-chain");
3422                for i in 0..K {
3423                    for j in 0..K {
3424                        assert_eq!(
3425                            thirdb[i][j].lane(rw).to_bits(),
3426                            third[rw][i][j].to_bits(),
3427                            "third"
3428                        );
3429                        assert_eq!(
3430                            fourthb[i][j].lane(rw).to_bits(),
3431                            fourth[rw][i][j].to_bits(),
3432                            "fourth"
3433                        );
3434                    }
3435                }
3436                rows_checked += 1;
3437            }
3438        }
3439        rows_checked
3440    }
3441    fn run3<const K: usize>(seed: u64, batches: usize) -> usize {
3442        let mut r = Rng(seed);
3443        let mut rows_checked = 0;
3444        for _ in 0..batches {
3445            let a: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3446            let b: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3447            let d: [[f64; 4]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3448            let s = r.f();
3449            let scal: [Tower3<K>; 4] = std::array::from_fn(|rw| {
3450                let prod = a[rw].mul(&b[rw]);
3451                let comp = prod.compose_unary(d[rw]);
3452                let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3453                summed.compose_unary_single_slot(d[rw], 0)
3454            });
3455            let ab = pack4_t3(&a);
3456            let bb = pack4_t3(&b);
3457            let db: [wide::f64x4; 4] =
3458                std::array::from_fn(|c| wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]]));
3459            let prodb = ab.mul(&bb);
3460            let compb = prodb.compose_unary(db);
3461            let summedb = compb.add(&ab).sub(&bb).scale(s);
3462            let finalb = summedb.compose_unary_single_slot(db, 0);
3463            for rw in 0..4 {
3464                assert_t3_eq(&finalb.lane(rw), &scal[rw], "t3-chain");
3465                rows_checked += 1;
3466            }
3467        }
3468        rows_checked
3469    }
3470
3471    // A `Tower4Batch<9>` carries a `9⁴ = 6561`-entry `t4` tensor in 4-wide
3472    // `f64` lanes (6561 × 32 B ≈ 205 KiB by value); the op chain keeps several
3473    // live, which can exceed a test thread's default stack. Run each width on
3474    // a large-stack thread so K=9 is exercised without a stack overflow.
3475    fn big_stack<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(f: F) -> R {
3476        std::thread::Builder::new()
3477            .stack_size(512 << 20)
3478            .spawn(f)
3479            .unwrap()
3480            .join()
3481            .unwrap()
3482    }
3483
3484    #[test]
3485    fn tower4_batch_lane_bit_identical() {
3486        let batches = 2000;
3487        let rows_checked = big_stack(move || run4::<2>(0x1111_2222_3333_4444, batches))
3488            + big_stack(move || run4::<3>(0x5555_6666_7777_8888, batches))
3489            + big_stack(move || run4::<4>(0x9999_aaaa_bbbb_cccc, batches))
3490            + big_stack(move || run4::<9>(0xdddd_eeee_ffff_0000, batches));
3491        // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3492        // worker threads against silently running zero comparisons.
3493        assert_eq!(rows_checked, 4 * batches * 4);
3494    }
3495
3496    #[test]
3497    fn tower3_batch_lane_bit_identical() {
3498        let batches = 2000;
3499        let rows_checked = big_stack(move || run3::<2>(0x0f0f_1e1e_2d2d_3c3c, batches))
3500            + big_stack(move || run3::<3>(0x4b4b_5a5a_6969_7878, batches))
3501            + big_stack(move || run3::<4>(0x8787_9696_a5a5_b4b4, batches))
3502            + big_stack(move || run3::<9>(0xc3c3_d2d2_e1e1_f0f0, batches));
3503        // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3504        // worker threads against silently running zero comparisons.
3505        assert_eq!(rows_checked, 4 * batches * 4);
3506    }
3507
3508    // ── compose_unary_with seam (generic-over-Lane compose) ─────────────────
3509    //
3510    // The seam lets a single-sourced row program build its special-function
3511    // STACK from the base value through a closure, so the SAME expression
3512    // instantiates at a scalar tower (one base) AND a batch tower (four distinct
3513    // per-lane bases). These oracles pin both arms to `to_bits` identity.
3514
3515    /// A base-value-dependent `[f64; 5]` derivative stack (finite for finite `u`),
3516    /// standing in for a family's hand-certified special-function stack.
3517    /// `seam_stack4` is its order-≤3 truncation.
3518    fn seam_stack5(u: f64) -> [f64; 5] {
3519        [
3520            u.sin(),
3521            u.cos(),
3522            (2.0 * u).sin(),
3523            (0.5 * u).cos(),
3524            u * u - 0.3,
3525        ]
3526    }
3527    fn seam_stack4(u: f64) -> [f64; 4] {
3528        let s = seam_stack5(u);
3529        [s[0], s[1], s[2], s[3]]
3530    }
3531
3532    /// Edge or distinct base value by lane index: `-0.0`, `0.0`, random in `[-2, 2)`, random in `[1, 5)`.
3533    fn seam_edge_base(r: &mut Rng, which: usize) -> f64 {
3534        match which {
3535            0 => -0.0,
3536            1 => 0.0,
3537            2 => r.f(),
3538            _ => r.f() + 3.0,
3539        }
3540    }
3541
3542    /// (a) scalar arm: `Tower4::compose_unary_with(f)` is `to_bits`-identical to
3543    /// the explicit `compose_unary(f(value))` on every channel.
3544    fn scalar_seam_t4<const K: usize>(seed: u64, n: usize) -> usize {
3545        let mut r = Rng(seed);
3546        for _ in 0..n {
3547            let mut t = rand_t4::<K>(&mut r);
3548            t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3549            assert_t4_eq(
3550                &t.compose_unary_with(seam_stack5),
3551                &t.compose_unary(seam_stack5(t.v)),
3552                "scalar t4 seam",
3553            );
3554        }
3555        n
3556    }
3557    fn scalar_seam_t3<const K: usize>(seed: u64, n: usize) -> usize {
3558        let mut r = Rng(seed);
3559        for _ in 0..n {
3560            let mut t = rand_t3::<K>(&mut r);
3561            t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3562            assert_t3_eq(
3563                &t.compose_unary_with(seam_stack4),
3564                &t.compose_unary(seam_stack4(t.v)),
3565                "scalar t3 seam",
3566            );
3567        }
3568        n
3569    }
3570
3571    /// (b) lane arm: `Tower4Lane::compose_unary_with` lane `i` is
3572    /// `to_bits`-identical to the scalar `Tower4::compose_unary_with` on row `i`,
3573    /// with the four lanes carrying DISTINCT base values (signed zeros included),
3574    /// so a buggy impl reusing one lane's base would fail.
3575    fn lane_seam_t4<const K: usize>(seed: u64, batches: usize) -> usize {
3576        let mut r = Rng(seed);
3577        let mut verified = 0usize;
3578        for _ in 0..batches {
3579            let mut rows: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3580            for (rw, row) in rows.iter_mut().enumerate() {
3581                row.v = seam_edge_base(&mut r, rw);
3582            }
3583            let batch_out = pack4_t4(&rows).compose_unary_with(seam_stack5);
3584            for (rw, row) in rows.iter().enumerate() {
3585                assert_t4_eq(
3586                    &batch_out.lane(rw),
3587                    &row.compose_unary_with(seam_stack5),
3588                    "lane t4 seam",
3589                );
3590                verified += 1;
3591            }
3592        }
3593        verified
3594    }
3595    fn lane_seam_t3<const K: usize>(seed: u64, batches: usize) -> usize {
3596        let mut r = Rng(seed);
3597        let mut verified = 0usize;
3598        for _ in 0..batches {
3599            let mut rows: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3600            for (rw, row) in rows.iter_mut().enumerate() {
3601                row.v = seam_edge_base(&mut r, rw);
3602            }
3603            let batch_out = pack4_t3(&rows).compose_unary_with(seam_stack4);
3604            for (rw, row) in rows.iter().enumerate() {
3605                assert_t3_eq(
3606                    &batch_out.lane(rw),
3607                    &row.compose_unary_with(seam_stack4),
3608                    "lane t3 seam",
3609                );
3610                verified += 1;
3611            }
3612        }
3613        verified
3614    }
3615
3616    #[test]
3617    fn compose_unary_with_scalar_bit_identical() {
3618        let n = 1100;
3619        let total = scalar_seam_t4::<2>(0x2200_0001, n)
3620            + scalar_seam_t4::<3>(0x2200_0002, n)
3621            + scalar_seam_t4::<4>(0x2200_0003, n)
3622            + big_stack(move || scalar_seam_t4::<9>(0x2200_0004, n))
3623            + scalar_seam_t3::<2>(0x3300_0001, n)
3624            + scalar_seam_t3::<3>(0x3300_0002, n)
3625            + scalar_seam_t3::<4>(0x3300_0003, n)
3626            + big_stack(move || scalar_seam_t3::<9>(0x3300_0004, n));
3627        // 8 arms × 1100 = 8800 ≥ 4000 inputs.
3628        assert_eq!(total, 8 * n);
3629    }
3630
3631    #[test]
3632    fn compose_unary_with_lane_matches_scalar() {
3633        let b = 600;
3634        let total = lane_seam_t4::<2>(0x4400_0001, b)
3635            + lane_seam_t4::<3>(0x4400_0002, b)
3636            + lane_seam_t4::<4>(0x4400_0003, b)
3637            + big_stack(move || lane_seam_t4::<9>(0x4400_0004, b))
3638            + lane_seam_t3::<2>(0x5500_0001, b)
3639            + lane_seam_t3::<3>(0x5500_0002, b)
3640            + lane_seam_t3::<4>(0x5500_0003, b)
3641            + big_stack(move || lane_seam_t3::<9>(0x5500_0004, b));
3642        // 8 arms × 600 = 4800 batches ≥ 2000; each verifies 4 lanes (19200 checks).
3643        assert_eq!(total, 8 * b * 4);
3644    }
3645}
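/*
Illustrative sketch only, not part of this module. The lane arm of the
`compose_unary_with` seam relies on running the same scalar closure once per
lane and packing the results by derivative order. The sketch below models
that step under the assumption that a lane value behaves like a plain
`[f64; 4]`; `Lanes4`, `unary_with_lanes` and `sin_stack` are hypothetical
names, and the real `Lane::unary_with` signature is not shown in this section.

```rust
// Hypothetical stand-in for a 4-lane SIMD value: one f64 per row.
type Lanes4 = [f64; 4];

// Run the scalar stack closure once per lane, then transpose so that
// out[k] holds derivative order k for all four lanes.
fn unary_with_lanes<const N: usize>(
    base: Lanes4,
    stack_fn: impl Fn(f64) -> [f64; N],
) -> [Lanes4; N] {
    let per_lane: [[f64; N]; 4] = std::array::from_fn(|lane| stack_fn(base[lane]));
    std::array::from_fn(|k| std::array::from_fn(|lane| per_lane[lane][k]))
}

// Derivative stack of sin at u: [sin, cos, -sin, -cos, sin].
fn sin_stack(u: f64) -> [f64; 5] {
    [u.sin(), u.cos(), -u.sin(), -u.cos(), u.sin()]
}
```

Because each lane sees exactly the scalar closure applied to its own base,
`unary_with_lanes(base, sin_stack)[k][lane]` has the same bit pattern as
`sin_stack(base[lane])[k]`, including for the signed-zero bases.
*/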
3646
3647#[cfg(test)]
3648mod tests {
3649    use super::*;
3650
3651    /// `Tower3<K>` must be bit-identical to `Tower4<K>` on every channel it
3652    /// carries (value, gradient, Hessian, third derivatives). The order-≤3
3653    /// Leibniz / Faà-di-Bruno terms read only order-≤3 inner channels, so
3654    /// dropping the fourth tensor cannot perturb them. Exercises products
3655    /// (Leibniz cross-terms), unary composition, scaling, and addition — the
3656    /// same operations the survival location-scale `nll_index_tower` composes —
3657    /// across all mixed partials, not just the diagonal entries that kernel reads.
    #[test]
    fn tower3_matches_tower4_through_third_order() {
        let s_a: [f64; 5] = [
            0.3_f64.sin(),
            0.3_f64.cos(),
            -0.3_f64.sin(),
            -0.3_f64.cos(),
            0.3_f64.sin(),
        ];
        let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
        let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];

        let a4 = Tower4::<3>::variable(0.4, 0);
        let b4 = Tower4::<3>::variable(-0.7, 1);
        let c4 = Tower4::<3>::variable(0.9, 2);
        let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
            + a4.mul(&c4).scale(-0.7)
            + b4.compose_unary(s_b).scale(0.25);

        let a3 = Tower3::<3>::variable(0.4, 0);
        let b3 = Tower3::<3>::variable(-0.7, 1);
        let c3 = Tower3::<3>::variable(0.9, 2);
        let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
            + a3.mul(&c3).scale(-0.7)
            + b3.compose_unary(s4(s_b)).scale(0.25);

        assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
        for i in 0..3 {
            assert_eq!(
                prog3.g[i].to_bits(),
                prog4.g[i].to_bits(),
                "g[{i}] mismatch"
            );
            for j in 0..3 {
                assert_eq!(
                    prog3.h[i][j].to_bits(),
                    prog4.h[i][j].to_bits(),
                    "h[{i}][{j}] mismatch"
                );
                for k in 0..3 {
                    assert_eq!(
                        prog3.t3[i][j][k].to_bits(),
                        prog4.t3[i][j][k].to_bits(),
                        "t3[{i}][{j}][{k}] mismatch"
                    );
                }
            }
        }
    }

    /// Binomial-logit row NLL, K=1: ℓ(η) = ln(1 + e^η) − y·η.
    /// The entire tower has textbook closed forms in μ = σ(η); this test
    /// pins the algebra (exp, ln, scalar mixes, Leibniz/Faà di Bruno) to
    /// analytic truth at near-machine precision.
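    ///
    /// A sketch of where those closed forms come from (a standard derivation,
    /// added here for orientation; `w = μ(1 − μ)` is the shorthand the test
    /// below also uses). Since dμ/dη = w, each derivative follows from the
    /// one before it:
    ///
    /// ```text
    ///   ℓ′  = μ − y
    ///   ℓ″  = dμ/dη               = w
    ///   ℓ‴  = dw/dη               = w·(1 − 2μ)
    ///   ℓ⁗  = w·(1 − 2μ)² − 2w²   = w·(1 − 6μ + 6μ²)
    /// ```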
    struct LogitProgram {
        eta: Vec<f64>,
        y: Vec<f64>,
    }

    impl RowNllProgram<1> for LogitProgram {
        fn n_rows(&self) -> usize {
            self.eta.len()
        }
        fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
            Ok([self.eta[row]])
        }
        fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
            let eta = p[0];
            Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
        }
    }

    #[test]
    fn logit_tower_matches_closed_forms() {
        let prog = LogitProgram {
            eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
            y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
        };
        for row in 0..prog.n_rows() {
            let t = evaluate_program(&prog, row).expect("logit program");
            let eta = prog.eta[row];
            let y = prog.y[row];
            let mu = 1.0 / (1.0 + (-eta).exp());
            let w = mu * (1.0 - mu);
            let expect = [
                (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
                (t.g[0], mu - y, "grad"),
                (t.h[0][0], w, "hess"),
                (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
                (
                    t.t4[0][0][0][0],
                    w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
                    "fourth",
                ),
            ];
            for (got, want, label) in expect {
                assert!(
                    (got - want).abs() <= 1e-12 * want.abs().max(1.0),
                    "row {row} {label}: got {got:+.15e} want {want:+.15e}"
                );
            }
        }
    }

    fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
        let diff = (got - want).abs();
        assert!(
            diff <= rel_tol * want.abs().max(1.0),
            "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
        );
    }

    /// Reference values: ψ(½) = −γ − 2 ln 2, ψ(1) = −γ, ψ₁(½) = π²/2 and
    /// ψ₁(1) = π²/6 are exact; the x = 2.5 row follows from two applications of
    /// the recurrences ψ(x+1) = ψ(x) + 1/x and ψ₁(x+1) = ψ₁(x) − 1/x²; the
    /// x = 0.1 and x = 50 rows are precomputed decimals.
    #[test]
    fn gamma_special_function_stacks_match_reference_values() {
        const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
        let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
        let cases = [
            (
                "x=0.1",
                0.1,
                -10.423_754_940_411_076,
                101.433_299_150_792_75,
            ),
            (
                "x=0.5",
                0.5,
                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
                pi_sq / 2.0,
            ),
            ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
            (
                "x=2.5",
                2.5,
                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
                pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
            ),
            (
                "x=50",
                50.0,
                3.901_989_673_427_892,
                0.020_201_333_226_697_128,
            ),
        ];

        for (label, x, digamma_ref, trigamma_ref) in cases {
            let ln_gamma_stack = ln_gamma_derivative_stack(x);
            let digamma_stack = digamma_derivative_stack(x);
            let trigamma_stack = trigamma_derivative_stack(x);
            assert_close(
                &format!("{label} ln_gamma_stack digamma"),
                ln_gamma_stack[1],
                digamma_ref,
                1e-13,
            );
            assert_close(
                &format!("{label} digamma value"),
                digamma_stack[0],
                digamma_ref,
                1e-13,
            );
            assert_close(
                &format!("{label} ln_gamma_stack trigamma"),
                ln_gamma_stack[2],
                trigamma_ref,
                1e-13,
            );
            assert_close(
                &format!("{label} digamma_stack trigamma"),
                digamma_stack[1],
                trigamma_ref,
                1e-13,
            );
            assert_close(
                &format!("{label} trigamma value"),
                trigamma_stack[0],
                trigamma_ref,
                1e-13,
            );
        }
    }

    #[test]
    fn gamma_special_function_stacks_obey_recurrences() {
        for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
            let digamma_x = digamma_derivative_stack(x)[0];
            let digamma_next = digamma_derivative_stack(x + 1.0)[0];
            let trigamma_x = trigamma_derivative_stack(x)[0];
            let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
            assert_close(
                &format!("digamma recurrence x={x}"),
                digamma_next,
                digamma_x + 1.0 / x,
                1e-13,
            );
            assert_close(
                &format!("trigamma recurrence x={x}"),
                trigamma_next,
                trigamma_x - 1.0 / (x * x),
                1e-13,
            );
        }
    }

    /// Gaussian location-scale row NLL, K=2 primaries (η, s = log σ):
    /// ℓ = s + ½ e^{−2s} (y − η)². Mixed cross blocks — the #736 fragility
    /// shape — all have one-line closed forms here.
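    ///
    /// A compact way to read off every entry (a sketch; `w = e^{−2s}` and
    /// `r = y − η` are the shorthands the test below uses): each ∂s multiplies
    /// the data term `½·w·r²` by −2, and each ∂η steps `½r²` through `−r`, `1`,
    /// `0`. So for m η-derivatives and n s-derivatives, with m + n ≥ 1:
    ///
    /// ```text
    ///   m = 0:  (−2)ⁿ · w · ½r²      (plus 1 from the leading `s` when n = 1)
    ///   m = 1:  (−2)ⁿ · w · (−r)
    ///   m = 2:  (−2)ⁿ · w
    ///   m ≥ 3:  0
    /// ```
    ///
    /// For example ∂ηss = (−2)²·w·(−r) = −4wr and ∂ssss = (−2)⁴·w·½r² = 8wr².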
    struct LocScaleProgram {
        eta: Vec<f64>,
        s: Vec<f64>,
        y: Vec<f64>,
    }

    impl RowNllProgram<2> for LocScaleProgram {
        fn n_rows(&self) -> usize {
            self.eta.len()
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            Ok([self.eta[row], self.s[row]])
        }
        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
            let r = -(p[0] - self.y[row]);
            Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
        }
    }

    #[test]
    fn locscale_tower_matches_closed_forms_including_cross_blocks() {
        let prog = LocScaleProgram {
            eta: vec![0.3, -1.1, 2.0],
            s: vec![-0.5, 0.2, 0.8],
            y: vec![1.0, -2.0, 2.5],
        };
        let tol = 1e-12;
        for row in 0..prog.n_rows() {
            let t = evaluate_program(&prog, row).expect("locscale program");
            let r = prog.y[row] - prog.eta[row];
            let w = (-2.0 * prog.s[row]).exp();
            // (η, s) = indices (0, 1).
            let truth_g = [-w * r, 1.0 - w * r * r];
            let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
            // Third tensor: distinct-entry closed forms.
            // ∂ηηη = 0, ∂ηηs = −2w, ∂ηss = −4wr, ∂sss = −4wr².
            let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
                match a + b + c {
                    0 => 0.0,
                    1 => -2.0 * w,
                    2 => -4.0 * w * r,
                    _ => -4.0 * w * r * r,
                }
            };
            // Fourth tensor: ∂ηηηη = 0, ∂ηηηs = ∂s(∂ηηη) = 0,
            // ∂ηηss = 4w, ∂ηsss = 8wr, ∂ssss = 8wr².
            let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
                match a + b + c + d {
                    0 | 1 => 0.0,
                    2 => 4.0 * w,
                    3 => 8.0 * w * r,
                    _ => 8.0 * w * r * r,
                }
            };
            for a in 0..2 {
                assert!(
                    (t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
                    "row {row} grad[{a}]"
                );
                for b in 0..2 {
                    assert!(
                        (t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
                        "row {row} hess[{a}][{b}]: got {} want {}",
                        t.h[a][b],
                        truth_h[a][b]
                    );
                    for c in 0..2 {
                        assert!(
                            (t.t3[a][b][c] - t3_truth(a, b, c)).abs()
                                <= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
                            "row {row} t3[{a}][{b}][{c}]: got {} want {}",
                            t.t3[a][b][c],
                            t3_truth(a, b, c)
                        );
                        for d in 0..2 {
                            assert!(
                                (t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
                                    <= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
                                "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
                                t.t4[a][b][c][d],
                                t4_truth(a, b, c, d)
                            );
                        }
                    }
                }
            }
            // The derived trait-surface helpers agree with direct contraction.
            let dir = [0.7, -1.3];
            let third = derived_third_contracted(&prog, row, &dir).expect("third");
            for a in 0..2 {
                for b in 0..2 {
                    let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
                    assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
                }
            }
        }
    }

    /// FD cross-check on a deliberately gnarly composition (div, sqrt,
    /// powf, nested exp/ln) in K=3, where no closed form is consulted:
    /// every tower channel is checked against central finite differences
    /// of the channel one order below — value→grad, grad→hess, hess→t3,
    /// t3→t4 — so each order is independently anchored.
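    ///
    /// Why a step of 1e-5 pairs comfortably with a tolerance of 1e-6 (a
    /// back-of-envelope sketch, assuming the differenced channel and its third
    /// derivative are O(1) at the fixture points, which is not verified here):
    ///
    /// ```text
    ///   central difference:  (f(x+h) − f(x−h)) / 2h = f′(x) + (h²/6)·f‴(x) + O(h⁴)
    ///   truncation error  ≈  h²/6   ≈ 2e-11   at h = 1e-5
    ///   rounding error    ≈  ε/h    ≈ 2e-11   with ε ≈ 2.2e-16
    /// ```
    ///
    /// Under that assumption both terms sit several orders of magnitude below
    /// the tolerance.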
    ///
    /// The program carries a per-row primary fixture plus a per-row offset
    /// `tau[row]` that enters the loss as a constant, so `row` genuinely
    /// drives both the seed point and the evaluated expression.
    struct GnarlyProgram {
        primaries: Vec<[f64; 3]>,
        tau: Vec<f64>,
    }

    impl GnarlyProgram {
        fn fixture() -> Self {
            Self {
                primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
                tau: vec![0.15, -0.35, 0.5],
            }
        }
    }

    impl RowNllProgram<3> for GnarlyProgram {
        fn n_rows(&self) -> usize {
            self.primaries.len()
        }
        fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
            self.primaries
                .get(row)
                .copied()
                .ok_or_else(|| format!("gnarly: row {row} out of range"))
        }
        fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
            let tau = *self
                .tau
                .get(row)
                .ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
            let a = (p[0] * p[1]).exp();
            let b = (p[2] * p[2] + 1.0).sqrt();
            let c = (a + b + tau).ln();
            let d = (p[1] * 0.5 + 2.0).powf(1.7);
            Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
        }
    }

    /// Evaluate the gnarly program's tower at an ARBITRARY seed point for
    /// `row` (used to drive central differences off the fixture grid),
    /// while keeping `row`'s per-row data (`tau`) in the loss.
    fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
        struct At<'a> {
            base: &'a GnarlyProgram,
            row: usize,
            p: [f64; 3],
        }
        impl RowNllProgram<3> for At<'_> {
            fn n_rows(&self) -> usize {
                1
            }
            fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
                if row != 0 {
                    return Err(format!("gnarly-at: row {row} out of range"));
                }
                Ok(self.p)
            }
            fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
                if eval_row != 0 {
                    return Err(format!("gnarly-at: eval row {eval_row} out of range"));
                }
                self.base.row_nll(self.row, vars)
            }
        }
        evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
    }

    #[test]
    fn gnarly_tower_is_fd_consistent_order_by_order() {
        let prog = GnarlyProgram::fixture();
        for row in 0..prog.n_rows() {
            let base = prog.primaries(row).expect("primaries");
            let t = gnarly_tower_at(&prog, row, base);
            let h_step = 1e-5;
            let tol = 1e-6;
            for c in 0..3 {
                let mut up = base;
                let mut dn = base;
                up[c] += h_step;
                dn[c] -= h_step;
                let t_up = gnarly_tower_at(&prog, row, up);
                let t_dn = gnarly_tower_at(&prog, row, dn);
                // value → gradient.
                let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
                assert!(
                    (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
                    "grad[{c}]: analytic {} fd {}",
                    t.g[c],
                    fd_g
                );
                for a in 0..3 {
                    // gradient → Hessian.
                    let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
                    assert!(
                        (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
                        "hess[{a}][{c}]: analytic {} fd {}",
                        t.h[a][c],
                        fd_h
                    );
                    for b in 0..3 {
                        // Hessian → third.
                        let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
                        assert!(
                            (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
                            "t3[{a}][{b}][{c}]: analytic {} fd {}",
                            t.t3[a][b][c],
                            fd_t3
                        );
                        for d in 0..3 {
                            // third → fourth.
                            let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
                            assert!(
                                (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
                                "t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
                                t.t4[a][b][d][c],
                                fd_t4
                            );
                        }
                    }
                }
            }
        }
    }

    /// `implicit_solve` reproduces the true implicit function `a(θ)` of a
    /// constraint `F(a, θ) = 0` to fourth order. The constraint here is the
    /// smooth
    ///   F(a, θ) = a + θ₀·a² + θ₁·exp(a) − c,
    /// strictly increasing in `a` for a ≥ 0 when θ₀, θ₁ > 0 (as in this test),
    /// whose root `a(θ)` is re-solved by scalar Newton at perturbed θ as the
    /// independent finite-difference oracle. This test checks the gradient and
    /// Hessian channels against that oracle. Mirrors the survival flex
    /// calibration solve (one implicit intercept over the primaries) without
    /// any survival machinery, so a failure localises to the combinator.
    #[test]
    fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
        const C: f64 = 1.7;
        // The scalar constraint as a plain f64 closure (the production root
        // finder analogue) and its tower form in (a, θ₀, θ₁).
        let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
        let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
        let solve = |th: [f64; 2]| -> f64 {
            let mut a = 0.0_f64;
            for _ in 0..100 {
                let r = f_scalar(a, th);
                if r.abs() < 1e-14 {
                    break;
                }
                a -= r / f_da(a, th);
            }
            a
        };
        // Tower constraint over K1 = 3 vars: slot 0 = a, slots 1,2 = θ₀, θ₁.
        let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
            let a = Tower4::<3>::variable(a0, 0);
            let t0 = Tower4::<3>::variable(th[0], 1);
            let t1 = Tower4::<3>::variable(th[1], 2);
            a + t0 * a.mul(&a) + t1 * a.exp() - C
        };

        let th0 = [0.35, 0.2];
        let a0 = solve(th0);
        let f = f_tower(a0, th0);
        // Residual at the solved point is ~0 (the combinator tolerates the
        // production Newton residual; here it is machine-zero).
        assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
        let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");

        // FD oracle: central differences of the scalar re-solve. Each order is
        // built from the previous via one more central difference, mirroring
        // the gnarly order-by-order ladder.
        let h = 1e-4;
        let tol = 1e-5;
        let re = |th: [f64; 2]| solve(th);
        for i in 0..2 {
            let mut up = th0;
            let mut dn = th0;
            up[i] += h;
            dn[i] -= h;
            let fd_g = (re(up) - re(dn)) / (2.0 * h);
            assert!(
                (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
                "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
                a_tower.g[i],
                fd_g
            );
            // second order: FD of the analytic gradient component would re-use
            // the combinator; instead difference a SCALAR gradient computed by
            // a nested re-solve so the oracle stays production-independent.
            let grad_at = |th: [f64; 2], j: usize| -> f64 {
                let mut up = th;
                let mut dn = th;
                up[j] += h;
                dn[j] -= h;
                (re(up) - re(dn)) / (2.0 * h)
            };
            for j in 0..2 {
                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
                assert!(
                    (a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
                    "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
                    a_tower.h[i][j],
                    fd_h
                );
            }
        }
    }

    /// `implicit_solve` reduces to `a_θ = −F_θ / F_a` at first order, and the
    /// second-order tensor matches the textbook IFT formula
    /// `a_uv = −(F_uv + F_au a_v + F_av a_u + F_aa a_u a_v)/F_a`, checked here on
    /// a constraint that is quadratic in `a`.
    /// This pins the recursion against the hand-coded first_full.rs formula it
    /// replaces, independent of any FD step.
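    ///
    /// Where the textbook formula comes from (a standard derivation sketch;
    /// subscripts denote partials, with `F` evaluated at `(a(θ), θ)`):
    /// differentiate the identity `F(a(θ), θ) = 0` once in θ_u, then once more
    /// in θ_v, and solve each line for the highest-order derivative of `a`.
    ///
    /// ```text
    ///   ∂u:   F_a·a_u + F_u = 0
    ///         ⇒ a_u = −F_u / F_a
    ///   ∂v∂u: F_a·a_uv + F_aa·a_u·a_v + F_au·a_v + F_av·a_u + F_uv = 0
    ///         ⇒ a_uv = −(F_uv + F_au·a_v + F_av·a_u + F_aa·a_u·a_v) / F_a
    /// ```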
    #[test]
    fn implicit_solve_matches_textbook_ift_recursion() {
        // A constraint with non-trivial F_a, F_aa, F_au, F_uv all present.
        let a0 = 0.4_f64;
        let th = [0.25_f64, -0.15_f64];
        let f = {
            let a = Tower4::<3>::variable(a0, 0);
            let t0 = Tower4::<3>::variable(th[0], 1);
            let t1 = Tower4::<3>::variable(th[1], 2);
            // F = a·(1 + θ₀) + θ₁·a² + θ₀·θ₁ − 0.4385. The constant is chosen so
            // F(a0, θ0) = 0 exactly at a0 = 0.4, θ = [0.25, −0.15]:
            //   0.4·1.25 + (−0.15)·0.16 + 0.25·(−0.15) = 0.4385.
            // implicit_solve requires a genuine root; at the root the level-set
            // and root-curve derivatives coincide, so the textbook-IFT
            // assertions below are unaffected.
            a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
        };
        let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
        let f_a = f.g[0];
        // First order: a_u = −F_u / F_a.
        for u in 0..2 {
            let want = -f.g[u + 1] / f_a;
            assert!(
                (a_t.g[u] - want).abs() < 1e-12,
                "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
                a_t.g[u],
                want
            );
        }
        // Second order textbook IFT (indices shifted by 1 for the a-slot).
        for u in 0..2 {
            for v in 0..2 {
                let f_uv = f.h[u + 1][v + 1];
                let f_au = f.h[0][u + 1];
                let f_av = f.h[0][v + 1];
                let f_aa = f.h[0][0];
                let want =
                    -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
                assert!(
                    (a_t.h[u][v] - want).abs() < 1e-12,
                    "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
                    a_t.h[u][v],
                    want
                );
            }
        }
    }

    /// The moving-boundary flux tower reproduces every θ-derivative of a
    /// moving-limit integral, INCLUDING the second-order `B·z_uv` term the
    /// hand-written flux dropped (#932). The edge `z_R(θ) = θ₀ + θ₁²` has a
    /// genuinely nonzero `∂²z_R/∂θ₁² = 2`, so a combinator that omitted
    /// `B·z_uv` would miss the [1][1] Hessian entry. Truth = central FD of the
    /// closed-form integral `∫₀^{z_R} e^{−z²/2} dz = √(π/2)·erf(z_R/√2)`.
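    ///
    /// The expected entries follow from the Leibniz rule for a moving upper
    /// limit (a derivation sketch; `I`, `z_u`, `z_uv` are notation introduced
    /// here for the integral and the edge's θ-partials):
    ///
    /// ```text
    ///   I(θ) = ∫₀^{z_R(θ)} B(z) dz
    ///   I_u  = B(z_R)·z_u
    ///   I_uv = B′(z_R)·z_u·z_v + B(z_R)·z_uv
    /// ```
    ///
    /// At the test point θ = (0.7, 0.5): z_R = 0.95, z_θ₁ = 1, z_θ₁θ₁ = 2 and
    /// B′ = −z·B, so I_θ₁θ₁ = (2 − 0.95)·B(0.95) = 1.05·B(0.95).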
    #[test]
    fn moving_boundary_flux_carries_b_zuv_term() {
        use std::f64::consts::PI;
        let b = |z: f64| (-0.5 * z * z).exp(); // integrand B(z)
        // Antiderivative-based closed-form integral I(z_R) = ∫₀^{z_R} B dz.
        let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
        let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
        let th0 = [0.7_f64, 0.5_f64];

        // Edge tower z_R(θ) over K=2 primaries: value + exact derivatives.
        let mut z_edge = Tower4::<2>::constant(z_r(th0));
        z_edge.g[0] = 1.0; // ∂z_R/∂θ₀ = 1
        z_edge.g[1] = 2.0 * th0[1]; // ∂z_R/∂θ₁ = 2θ₁
        z_edge.h[1][1] = 2.0; // ∂²z_R/∂θ₁² = 2  (the z_uv the old flux dropped)

        // Integrand stack [B, B′, B″, B‴] at z₀: B′=−z·B, B″=(z²−1)·B,
        // B‴=(3z−z³)·B.
        let z0 = z_edge.v;
        let b0 = b(z0);
        let stack = [
            b0,
            -z0 * b0,
            (z0 * z0 - 1.0) * b0,
            (3.0 * z0 - z0 * z0 * z0) * b0,
        ];
        let flux = moving_limit_boundary_tower(&z_edge, stack);

        // FD truth of the integral's derivatives.
        let h = 1e-4;
        let tol = 1e-6;
        for i in 0..2 {
            let mut up = th0;
            let mut dn = th0;
            up[i] += h;
            dn[i] -= h;
            let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
            assert!(
                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
                "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
                flux.g[i],
                fd_g
            );
        }
        // The decisive entry: ∂²I/∂θ₁² = B′·(z_θ₁)² + B·z_θ₁θ₁. With z_θ₁=2θ₁=1
        // and z_θ₁θ₁=2, the B·z_uv contribution is B(z₀)·2 — omitting it would
        // leave the [1][1] entry short by exactly 2·B(z₀).
        let grad1_at = |th: [f64; 2]| -> f64 {
            let mut up = th;
            let mut dn = th;
            up[1] += h;
            dn[1] -= h;
            (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
        };
        let mut up = th0;
        let mut dn = th0;
        up[1] += h;
        dn[1] -= h;
        let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
        assert!(
            (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
            "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
            flux.h[1][1],
            fd_h11
        );
        // Explicit witness that the B·z_uv term is present and material:
        // analytic h[1][1] minus the pure (z_u)² part must equal B·z_uv = 2·B₀.
        let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
        let b_zuv = flux.h[1][1] - pure_zu2;
        assert!(
            (b_zuv - b0 * 2.0).abs() < 1e-10,
            "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
            b_zuv,
            b0 * 2.0
        );
    }

    /// `moving_limit_boundary_tower_theta_integrand` reproduces the marginal-slope
    /// flex boundary closure for a θ-DEPENDENT integrand `G(z;θ)` — the case the
    /// plain `moving_limit_boundary_tower` cannot express, and the case the
    /// survival directional/bidirectional paths hand-assemble term-by-term
    /// (`G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`, with the directional path
    /// dropping `G·z_uv`). Two independent oracles:
    ///   (1) closed-form: the boundary flux of `∫ G dz` is exactly
    ///       `Φ(z_edge(θ);θ) − Φ(z₀;θ)` (Φ = z-antiderivative of G, z₀ = the
    ///       edge frozen at the base θ), whose θ derivatives we take by central
    ///       FD of the closed form — no jet code.
    ///   (2) the explicit second-order hand closure, including the `G·z_uv` term,
    ///       built from the integrand's own (z,θ) partials.
    /// G(z;θ) = exp(z·θ₀) is genuinely θ-dependent (G_θ₀ = z·e^{zθ₀} ≠ 0), and
    /// the edge z_edge(θ) = 0.6 + θ₀ + θ₁² has a real z_uv = ∂²/∂θ₁² = 2, so a
    /// combinator that dropped either the integrand-θ terms or `G·z_uv` would
    /// miss a Hessian entry.
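    ///
    /// The hand closure in (2) is the chain rule applied twice to the boundary
    /// flux of (1) (a derivation sketch; `Bnd` is the name the test body uses,
    /// and every partial is evaluated at the base point, where the edge equals
    /// the frozen lower limit of (1), so each bracket vanishes):
    ///
    /// ```text
    ///   Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ),    Φ_z = G
    ///   Bnd_u  = G·z_u + [Φ_θu(z_edge) − Φ_θu(z₀)]
    ///          = G·z_u
    ///   Bnd_uv = G_z·z_u·z_v + G·z_uv + G_θv·z_u + G_θu·z_v
    ///            + [Φ_θuθv(z_edge) − Φ_θuθv(z₀)]
    ///          = G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u
    /// ```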
    #[test]
    fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
        // G(z;θ) = exp(z·θ₀);  Φ(z;θ) = ∫₀^z G = (e^{zθ₀} − 1)/θ₀.
        let g = |z: f64, t0: f64| (z * t0).exp();
        let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
        let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
        let th0 = [0.4_f64, 0.5_f64];
        let z0 = z_r(th0);

        // Edge tower z_edge(θ) over K=2 primaries.
        let mut z_edge = Tower4::<2>::constant(z0);
        z_edge.g[0] = 1.0; // ∂z/∂θ₀
        z_edge.g[1] = 2.0 * th0[1]; // ∂z/∂θ₁
        z_edge.h[1][1] = 2.0; // ∂²z/∂θ₁² (the z_uv the directional path drops)

        // Φ's mixed (z, θ) jet over K1 = 3 vars: slot 0 = z, slots 1,2 = θ₀,θ₁.
        // Built ONCE in tower arithmetic so every (z^i θ^j) partial is exact.
        let z_var = Tower4::<3>::variable(z0, 0);
        let t0_var = Tower4::<3>::variable(th0[0], 1);
        // θ₁ does not enter G/Φ here (its Φ-derivatives are zero; the z_edge
        // chain supplies all θ₁ motion through slot 0), so the K1 frame's θ₁
        // slot is intentionally left unseeded.
        let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
        // Sanity: slot-0 first derivative of Φ IS G(z₀;θ₀).
        assert!(
            (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
            "Φ_z {:+.8e} != G {:+.8e}",
            phi_jet.g[0],
            g(z0, th0[0])
        );

        let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);

        // Value channel is 0 by construction (boundary, not the integral itself).
        assert!(
            flux.v.abs() < 1e-12,
            "boundary value channel {:+.3e}",
            flux.v
        );

        // Oracle (1): central FD of the closed-form boundary flux
        //   Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ)   (z₀ FROZEN at the base edge).
        let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
        let h = 1e-4;
        let tol = 1e-6;
        for i in 0..2 {
            let mut up = th0;
            let mut dn = th0;
            up[i] += h;
            dn[i] -= h;
            let fd_g = (bnd(up) - bnd(dn)) / (2.0 * h);
            assert!(
                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
                "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
                flux.g[i],
                fd_g
            );
        }
        let grad_at = |th: [f64; 2], j: usize| -> f64 {
            let mut up = th;
            let mut dn = th;
            up[j] += h;
            dn[j] -= h;
            (bnd(up) - bnd(dn)) / (2.0 * h)
        };
        for i in 0..2 {
            for j in 0..2 {
                let mut up = th0;
                let mut dn = th0;
                up[i] += h;
                dn[i] -= h;
                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
                assert!(
                    (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
                    "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
                    flux.h[i][j],
                    fd_h
                );
            }
        }

        // Oracle (2): the explicit second-order hand closure, term by term —
        // `G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`. Read G's partials at the
        // base point directly (no jet): G = e^{zθ₀}, G_z = θ₀·G, G_θ₀ = z·G,
        // G_θ₁ = 0.
        let gg = g(z0, th0[0]);
        let g_z = th0[0] * gg;
        let g_theta = [z0 * gg, 0.0]; // [G_θ₀, G_θ₁]
        for i in 0..2 {
            for j in 0..2 {
                let z_u = z_edge.g[i];
                let z_v = z_edge.g[j];
                let z_uv = z_edge.h[i][j];
                let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
                assert!(
                    (flux.h[i][j] - hand).abs() < 1e-9,
                    "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
                    flux.h[i][j],
                    hand
                );
            }
        }

        // Decisive: the `G·z_uv` term the directional path DROPS is present and
        // material in the [1][1] entry (z_uv = 2 there).
        let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
        let g_zuv = flux.h[1][1] - pure_no_zuv;
        assert!(
4434            (g_zuv - gg * 2.0).abs() < 1e-9,
4435            "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
4436            g_zuv,
4437            gg * 2.0
4438        );
4439    }

    /// The survival crossing-edge position tower `z_edge = (τ − a(θ)) / b`,
    /// with the slope `b` seeded directly as the `g`-primary (so ∂b/∂g = 1)
    /// and the intercept tower `a(θ)` supplied as a stand-in, reproduces
    /// taylor-jet's exact hand-path boundary-velocity formulas:
    ///   z_u   = −(a_u + [u==g]·z) / b
    ///   z_uv  = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
    /// This pins the bridge between `implicit_solve` and
    /// `cell_moving_boundary_flux_tower`: the boundary jet that the production
    /// flex path hand-codes (and dropped `z_uv` from) is exactly `∂²` of this
    /// tower. K=3 reduced frame: slot 0 = a-axis carrier (an arbitrary smooth
    /// a(θ) with nonzero a_u/a_uv), slot 1 = g (the slope primary, `b = g`),
    /// slot 2 unused.
    #[test]
    fn crossing_edge_tower_matches_handpath_velocity_formulas() {
        const TAU: f64 = 1.3; // the link-knot crossing threshold τ
        let g_idx = 1usize;
        let g0 = 0.85_f64; // the slope value b (the g-primary IS the slope)
        // Stand-in intercept tower a(θ): nonzero value, gradient, Hessian in the
        // two live axes so a_u and a_uv are both exercised. (In production this
        // comes from implicit_solve; here we plant known derivatives.)
        let mut a = Tower4::<3>::constant(0.45);
        a.g[0] = 0.7;
        a.g[1] = -0.3;
        a.h[0][0] = 0.25;
        a.h[0][1] = 0.11;
        a.h[1][0] = 0.11;
        a.h[1][1] = -0.08;

        // In the survival flex frame the slope `b` IS the g-primary directly
        // (the directional code passes `g` as `b`, and ∂z/∂g uses ∂b/∂g = 1):
        // z_edge = (τ − a) / b with b seeded as the g-axis variable.
        let b = Tower4::<3>::variable(g0, g_idx);
        let z_edge = (Tower4::<3>::constant(TAU) - a) / b;

        let bv = g0;
        let z0 = z_edge.v;
        assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);

        // z_u = −(a_u + [u==g]·z) / b.
        for u in 0..2 {
            let direct = if u == g_idx { z0 } else { 0.0 };
            let want = -(a.g[u] + direct) / bv;
            assert!(
                (z_edge.g[u] - want).abs() < 1e-10,
                "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
                z_edge.g[u],
                want
            );
        }
        // z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b, using the tower's own
        // first-order z_v/z_u (already verified above).
        for u in 0..2 {
            for v in 0..2 {
                let cross = if u == g_idx { z_edge.g[v] } else { 0.0 }
                    + if v == g_idx { z_edge.g[u] } else { 0.0 };
                let want = -(a.h[u][v] + cross) / bv;
                assert!(
                    (z_edge.h[u][v] - want).abs() < 1e-10,
                    "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
                    z_edge.h[u][v],
                    want
                );
            }
        }
    }

    /// The crossing-edge tower in the CONSTRAINT frame (intercept `a` and
    /// slope `b` BOTH independent, in slots 0 and 1) reproduces taylor-jet's
    /// FD-certified bare boundary-velocity constants exactly:
    ///   z_a  = ∂z/∂a   = −1/b
    ///   z_ab = ∂²z/∂a∂b = +1/b²
    ///   z_aa = ∂²z/∂a²  = 0
    ///   z_bb = ∂²z/∂b²  = +2(τ−a)/b³
    /// These are the `f_a`/`f_au`/`f_aa` constraint-jet boundary motions the
    /// production base path drops (and only adds in the dir twins, causing the
    /// #932 desync). Here `a` is independent (NOT yet substituted with a(θ)),
    /// so `z_aa = 0` and there is no `a_uv` chain; `implicit_solve` introduces
    /// that later. Pins the constant before the constraint-tower wiring.
    #[test]
    fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
        const TAU: f64 = 1.3;
        let a0 = 0.45_f64;
        let b0 = 0.85_f64;
        // Slot 0 = a, slot 1 = b, both seeded independent.
        let a = Tower4::<2>::variable(a0, 0);
        let b = Tower4::<2>::variable(b0, 1);
        let z = (Tower4::<2>::constant(TAU) - a) / b;

        assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
        assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
        assert!(
            (z.h[0][1] - 1.0 / (b0 * b0)).abs() < 1e-12,
            "z_ab {:+.10e} vs +1/b² {:+.10e}",
            z.h[0][1],
            1.0 / (b0 * b0)
        );
        assert!(
            z.h[0][0].abs() < 1e-12,
            "z_aa must vanish, got {:+.10e}",
            z.h[0][0]
        );
        let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);
        assert!(
            (z.h[1][1] - want_zbb).abs() < 1e-12,
            "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
            z.h[1][1],
            want_zbb
        );
    }
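
    // Illustrative sketch (not part of the production tower): the bare
    // constants pinned above can also be cross-checked without `Tower4`, by
    // writing the closed forms in plain `f64` and differencing
    // `z(a, b) = (τ − a) / b` numerically. Both helper names are hypothetical
    // and exist only for this example; the step size and tolerance are
    // assumptions, not values taken from the codebase.
    fn crossing_edge_closed_forms_sketch(tau: f64, a: f64, b: f64) -> [f64; 4] {
        // Order: [z_a, z_ab, z_aa, z_bb].
        [-1.0 / b, 1.0 / (b * b), 0.0, 2.0 * (tau - a) / (b * b * b)]
    }

    fn crossing_edge_fd_sketch(tau: f64, a: f64, b: f64, h: f64) -> [f64; 4] {
        let z = |a: f64, b: f64| (tau - a) / b;
        // Central first difference in `a`.
        let z_a = (z(a + h, b) - z(a - h, b)) / (2.0 * h);
        // Four-point mixed second difference in `(a, b)`.
        let z_ab = (z(a + h, b + h) - z(a + h, b - h) - z(a - h, b + h) + z(a - h, b - h))
            / (4.0 * h * h);
        // Three-point pure second differences.
        let z_aa = (z(a + h, b) - 2.0 * z(a, b) + z(a - h, b)) / (h * h);
        let z_bb = (z(a, b + h) - 2.0 * z(a, b) + z(a, b - h)) / (h * h);
        [z_a, z_ab, z_aa, z_bb]
    }

    #[test]
    fn crossing_edge_closed_forms_match_fd_sketch() {
        let exact = crossing_edge_closed_forms_sketch(1.3, 0.45, 0.85);
        let fd = crossing_edge_fd_sketch(1.3, 0.45, 0.85, 1e-4);
        for k in 0..4 {
            assert!(
                (exact[k] - fd[k]).abs() < 1e-5,
                "entry {k}: closed form {:+.8e} vs fd {:+.8e}",
                exact[k],
                fd[k]
            );
        }
    }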

    /// The oracle harness catches a planted #736-style sign flip in a
    /// cross block and reports the channel by name.
    #[test]
    fn oracle_catches_planted_cross_block_sign_flip() {
        let prog = LocScaleProgram {
            eta: vec![0.3],
            s: vec![-0.5],
            y: vec![1.0],
        };
        let t = evaluate_program(&prog, 0).expect("tower");
        let dir = [0.6, -0.2];
        let mut third = t.third_contracted(&dir);
        let honest = KernelChannels {
            value: t.v,
            gradient: t.g,
            hessian: t.h,
            third: vec![(dir, third)],
            fourth: vec![(dir, [1.0, 0.5], t.fourth_contracted(&dir, &[1.0, 0.5]))],
        };
        verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");

        // Plant the #736 flip: negate one mixed cross entry.
        third[0][1] = -third[0][1];
        let flipped = KernelChannels {
            value: t.v,
            gradient: t.g,
            hessian: t.h,
            third: vec![(dir, third)],
            fourth: vec![],
        };
        let err = verify_kernel_channels(&t, &flipped, 1e-10)
            .expect_err("planted sign flip must be caught");
        assert!(
            err.contains("third[0][0][1]"),
            "oracle must name the flipped channel, got: {err}"
        );
    }

    /// The third- and fourth-order tensors must be FULLY symmetric under
    /// index permutation (mixed partials commute). The tower stores them
    /// unsymmetrized, so equal-by-construction is a real invariant of the
    /// Leibniz/Faà di Bruno writes, which makes it a cheap typo tripwire.
    /// Asserted on a nontrivial K=3 tower with all of div/sqrt/powf/exp/ln
    /// exercised, so every composition path contributes. Lives in a test (not
    /// the hot per-op path) on purpose.
    #[test]
    fn t3_t4_are_fully_index_symmetric() {
        let prog = GnarlyProgram::fixture();
        // 3! = 6 permutations of three indices.
        let perms3: [[usize; 3]; 6] = [
            [0, 1, 2],
            [0, 2, 1],
            [1, 0, 2],
            [1, 2, 0],
            [2, 0, 1],
            [2, 1, 0],
        ];
        // 4! = 24 permutations of four indices.
        let perms4: [[usize; 4]; 24] = [
            [0, 1, 2, 3],
            [0, 1, 3, 2],
            [0, 2, 1, 3],
            [0, 2, 3, 1],
            [0, 3, 1, 2],
            [0, 3, 2, 1],
            [1, 0, 2, 3],
            [1, 0, 3, 2],
            [1, 2, 0, 3],
            [1, 2, 3, 0],
            [1, 3, 0, 2],
            [1, 3, 2, 0],
            [2, 0, 1, 3],
            [2, 0, 3, 1],
            [2, 1, 0, 3],
            [2, 1, 3, 0],
            [2, 3, 0, 1],
            [2, 3, 1, 0],
            [3, 0, 1, 2],
            [3, 0, 2, 1],
            [3, 1, 0, 2],
            [3, 1, 2, 0],
            [3, 2, 0, 1],
            [3, 2, 1, 0],
        ];
        for row in 0..prog.n_rows() {
            let t = evaluate_program(&prog, row).expect("gnarly tower");
            let scale_t3 =
                t.t3.iter()
                    .flatten()
                    .flatten()
                    .fold(0.0_f64, |m, x| m.max(x.abs()))
                    .max(1.0);
            let scale_t4 =
                t.t4.iter()
                    .flatten()
                    .flatten()
                    .flatten()
                    .fold(0.0_f64, |m, x| m.max(x.abs()))
                    .max(1.0);
            for i in 0..3 {
                for j in 0..3 {
                    for k in 0..3 {
                        let base = t.t3[i][j][k];
                        let idx = [i, j, k];
                        for p in &perms3 {
                            let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
                            assert!(
                                (base - permed).abs() <= 1e-12 * scale_t3,
                                "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
                                 permuted {permed:+.15e} under {p:?}"
                            );
                        }
                        for l in 0..3 {
                            let base4 = t.t4[i][j][k][l];
                            let idx4 = [i, j, k, l];
                            for p in &perms4 {
                                let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
                                assert!(
                                    (base4 - permed).abs() <= 1e-12 * scale_t4,
                                    "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
                                     permuted {permed:+.15e} under {p:?}"
                                );
                            }
                        }
                    }
                }
            }
        }
    }
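
    // Illustrative sketch: the hand-listed `perms3` / `perms4` tables above
    // could instead be generated, which removes the risk of a mistyped row.
    // `index_permutations_sketch` is a hypothetical helper (not an existing
    // function in this crate); it returns all permutations of `0..n` in
    // lexicographic order, the same order the tables above are written in.
    fn index_permutations_sketch(n: usize) -> Vec<Vec<usize>> {
        // Depth-first extension of a partial permutation, trying unused
        // indices in ascending order.
        fn extend(prefix: &mut Vec<usize>, used: &mut Vec<bool>, out: &mut Vec<Vec<usize>>) {
            if prefix.len() == used.len() {
                out.push(prefix.clone());
                return;
            }
            for i in 0..used.len() {
                if !used[i] {
                    used[i] = true;
                    prefix.push(i);
                    extend(prefix, used, out);
                    prefix.pop();
                    used[i] = false;
                }
            }
        }
        let mut out = Vec::new();
        extend(&mut Vec::with_capacity(n), &mut vec![false; n], &mut out);
        out
    }

    #[test]
    fn index_permutations_sketch_counts_are_factorial() {
        assert_eq!(index_permutations_sketch(3).len(), 6);
        assert_eq!(index_permutations_sketch(4).len(), 24);
    }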
}

/// Scaled complementary error function `erfcx(x) = exp(x²)·erfc(x)`, intended
/// for `x ≥ 0`. Finite `x ≤ 0` returns `erfcx(0) = 1` (a clamp, not the true
/// value for negative `x`). For `x ≥ 26` the product form is replaced by the
/// asymptotic series
/// `1/(x√π) · (1 − 1/(2x²) + 3/(4x⁴) − 15/(8x⁶) + 105/(16x⁸))`.
#[inline]
fn erfcx_nonnegative(x: f64) -> f64 {
    if !x.is_finite() {
        return if x.is_sign_positive() {
            0.0
        } else {
            f64::INFINITY
        };
    }
    if x <= 0.0 {
        return 1.0;
    }
    if x < 26.0 {
        ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
    } else {
        let inv = 1.0 / x;
        let inv2 = inv * inv;
        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
            + 6.5625 * inv2 * inv2 * inv2 * inv2;
        inv * poly / std::f64::consts::PI.sqrt()
    }
}

/// `log(1 − exp(−a))` for `a ≥ 0`. Switches at `a = ln 2` between
/// `ln_1p(−exp(−a))` (larger `a`) and `ln(−exp_m1(−a))` (smaller `a`) so that
/// neither branch forms `1 − exp(−a)` by direct subtraction. Returns `−∞` at
/// `a = 0` and panics on negative input.
#[inline]
fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    if a > core::f64::consts::LN_2 {
        (-(-a).exp()).ln_1p()
    } else if a > 0.0 {
        (-(-a).exp_m1()).ln()
    } else {
        f64::NEG_INFINITY
    }
}
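
// Illustrative sketch of why `log1mexp_positive` never forms `1 − exp(−a)`
// by direct subtraction. Both helpers below are hypothetical and exist only
// for this comparison: the naive form rounds `1 − exp(−a)` to `0` for tiny
// `a` (giving `−∞`) and to `1` for large `a` (giving exactly `0`), while the
// two-branch form keeps the leading digits in both regimes.
#[allow(dead_code)]
fn log1mexp_naive_sketch(a: f64) -> f64 {
    (1.0 - (-a).exp()).ln()
}

#[allow(dead_code)]
fn log1mexp_two_branch_sketch(a: f64) -> f64 {
    if a > core::f64::consts::LN_2 {
        // exp(−a) < 1/2 here, so ln_1p of a small negative number is accurate.
        (-(-a).exp()).ln_1p()
    } else {
        // exp_m1 keeps the low-order bits of exp(−a) − 1 for small a.
        (-(-a).exp_m1()).ln()
    }
}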

/// Returns `(log Φ(x), λ(x))`, where `λ = φ(x)/Φ(x)` is the inverse Mills
/// ratio. The left tail (`x < 0`) is routed through `erfcx_nonnegative` so
/// that the underflowing `Φ(x)` is never formed directly.
#[inline]
fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
    if x == f64::INFINITY {
        return (0.0, 0.0);
    }
    if x == f64::NEG_INFINITY {
        return (f64::NEG_INFINITY, f64::INFINITY);
    }
    if x.is_nan() {
        return (f64::NAN, f64::NAN);
    }
    if x < 0.0 {
        let u = -x / std::f64::consts::SQRT_2;
        let ex = erfcx_nonnegative(u).max(1e-300);
        let log_cdf = -u * u + (0.5 * ex).ln();
        let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
        (log_cdf, lambda)
    } else {
        let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
        let lambda = crate::probability::normal_pdf(x) / cdf;
        (cdf.ln(), lambda)
    }
}

/// Stable derivative stack for `log Phi(x)` through fourth order.
/// Every higher entry follows from `λ′ = −λ·(x + λ)`.
#[inline]
pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
    let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
    let lambda2 = lambda * lambda;
    let lambda3 = lambda2 * lambda;
    let x2 = x * x;
    [
        log_cdf,
        lambda,
        -lambda * (x + lambda),
        lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
        -lambda
            * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
    ]
}

/// Stable derivative stack for `log(1 - exp(-x))`, `x > 0`, through fourth order.
/// With `r = 1/(exp(x) − 1)`, every entry follows from `r′ = −r·(1 + r)`.
#[inline]
pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
    let r = 1.0 / x.exp_m1();
    [
        log1mexp_positive(x),
        r,
        -r * (1.0 + r),
        r * (1.0 + r) * (1.0 + 2.0 * r),
        -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
    ]
}
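
// Illustrative sketch: one way to sanity-check a `[f, f′, f″, f‴, f⁗]` stack
// is to difference entry `k` numerically and compare it against entry
// `k + 1`. The two helpers below are hypothetical and self-contained; the
// first re-derives the `log(1 − exp(−x))` stack from `r = 1/(exp(x) − 1)`
// and `r′ = −r·(1 + r)`, the second returns the worst relative gap. The step
// size passed in is an assumption chosen for moderate `x`.
#[allow(dead_code)]
fn log1mexp_stack_sketch(x: f64) -> [f64; 5] {
    let r = 1.0 / x.exp_m1();
    [
        // Value channel; the ln_1p form is adequate for the moderate `x`
        // used in this sketch.
        (-(-x).exp()).ln_1p(),
        r,
        -r * (1.0 + r),
        r * (1.0 + r) * (1.0 + 2.0 * r),
        -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
    ]
}

#[allow(dead_code)]
fn log1mexp_stack_fd_gap_sketch(x: f64, h: f64) -> f64 {
    let up = log1mexp_stack_sketch(x + h);
    let dn = log1mexp_stack_sketch(x - h);
    let at = log1mexp_stack_sketch(x);
    let mut worst = 0.0_f64;
    for k in 0..4 {
        // Central difference of entry k approximates entry k + 1.
        let fd = (up[k] - dn[k]) / (2.0 * h);
        let rel = (fd - at[k + 1]).abs() / at[k + 1].abs().max(1.0);
        worst = worst.max(rel);
    }
    worst
}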

// ── The RowJet bridge oracle (CI) ─────────────────────────────────────
#[cfg(test)]
mod rowjet_bridge_tests {
    use super::*;
    use crate::jet_scalar::{JetScalar, Order2};

    /// A toy row-NLL written ONCE over the [`RowJet`] bridge: a product, a sum, a
    /// subtraction, a scale/neg, a constant, and two value-distinct
    /// `compose_unary_with` stacks (an exp stack and a smooth finite-everywhere
    /// stack), plus a domain `guard`. The body is generic over `R: RowJet<2>`, so
    /// the SAME source instantiates at the scalar jets and the `f64x4` lane towers.
    struct ToyProgram {
        primaries: Vec<[f64; 2]>,
        /// Per-row CONTINUOUS auxiliary data `[cov, z, wi]`: the survival
        /// `covariance_ones` / `z_sum` / observation-weight analogues that enter
        /// the jet algebra as `.scale_rows(per_row_value)`, distinct per lane.
        aux: Vec<[f64; 3]>,
    }

    impl ToyProgram {
        /// The body uses `pack_rows` to gather the per-lane continuous data from
        /// the lane→row map and `scale_rows` to fold it in, so a 4-row batch
        /// carries four DISTINCT cov/z/wi, which the single-`f64` `scale` could not.
        fn body<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> R {
            let cov = R::pack_rows(rows, |r| self.aux[r][0]);
            let z = R::pack_rows(rows, |r| self.aux[r][1]);
            let wi = R::pack_rows(rows, |r| self.aux[r][2]);

            let a = p[0].mul(&p[1]).scale_rows(cov);
            let b = a.add(&R::constant(0.5)).sub(&p[0].scale(0.25));
            let c = b
                .compose_unary_with(|u| {
                    let e = u.exp();
                    [e, e, e, e, e]
                })
                .scale_rows(z);
            let d = c.neg().add(&p[0]);
            let e = d
                .compose_unary_with(|u| {
                    let s = (1.0 + u * u).sqrt();
                    let s3 = s * s * s;
                    let s5 = s3 * s * s;
                    let s7 = s5 * s * s;
                    [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
                })
                .scale_rows(wi);
            e.mul(&p[1]).add(&e)
        }
    }

    impl RowNllProgramRowJet<2> for ToyProgram {
        fn n_rows(&self) -> usize {
            self.primaries.len()
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            Ok(self.primaries[row])
        }
        fn row_nll<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> Result<R, String> {
            assert!(
                rows.len() == 1 || rows.len() == 4,
                "lane→row map is 1 or 4 wide"
            );
            Ok(self.body(rows, p))
        }
    }

    fn assert_t4_bits_eq(a: &Tower4<2>, b: &Tower4<2>, ctx: &str) {
        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
        for i in 0..2 {
            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
            for j in 0..2 {
                assert_eq!(
                    a.h[i][j].to_bits(),
                    b.h[i][j].to_bits(),
                    "{ctx}: h[{i}][{j}]"
                );
                for k in 0..2 {
                    assert_eq!(
                        a.t3[i][j][k].to_bits(),
                        b.t3[i][j][k].to_bits(),
                        "{ctx}: t3[{i}][{j}][{k}]"
                    );
                    for l in 0..2 {
                        assert_eq!(
                            a.t4[i][j][k][l].to_bits(),
                            b.t4[i][j][k][l].to_bits(),
                            "{ctx}: t4[{i}][{j}][{k}][{l}]"
                        );
                    }
                }
            }
        }
    }

    fn assert_t3_bits_eq(a: &Tower3<2>, b: &Tower3<2>, ctx: &str) {
        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
        for i in 0..2 {
            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
            for j in 0..2 {
                assert_eq!(
                    a.h[i][j].to_bits(),
                    b.h[i][j].to_bits(),
                    "{ctx}: h[{i}][{j}]"
                );
                for k in 0..2 {
                    assert_eq!(
                        a.t3[i][j][k].to_bits(),
                        b.t3[i][j][k].to_bits(),
                        "{ctx}: t3[{i}][{j}][{k}]"
                    );
                }
            }
        }
    }
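
    // Illustrative sketch of why the two helpers above compare `to_bits()`
    // rather than using `==`: IEEE equality treats `+0.0` and `-0.0` as equal
    // and treats NaN as unequal to itself, whereas a bit comparison does the
    // opposite in both cases. `bits_eq_sketch` is a hypothetical helper used
    // only for this demonstration.
    fn bits_eq_sketch(a: f64, b: f64) -> bool {
        a.to_bits() == b.to_bits()
    }

    #[test]
    fn bits_eq_sketch_separates_signed_zeros() {
        // `==` cannot see the sign of zero; the bit pattern can.
        assert!(0.0_f64 == -0.0_f64);
        assert!(!bits_eq_sketch(0.0, -0.0));
        // The same NaN constant has the same bits even though `==` is false.
        assert!(bits_eq_sketch(f64::NAN, f64::NAN));
    }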

    // Deterministic LCG with signed-zero injection and per-lane-distinct values.
    struct Lcg(u64);
    impl Lcg {
        fn next(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
        fn val(&mut self) -> f64 {
            let u = self.next();
            if u < 0.04 {
                return 0.0;
            }
            if u < 0.08 {
                return -0.0;
            }
            (self.next() - 0.5) * 5.0
        }
    }

    /// Lane `i` of the batched order-4 / order-3 tower is `to_bits`-identical to
    /// the scalar tower on row `i`, for ≥2000 distinct 4-row batches with
    /// signed-zero and per-lane-distinct primaries.
    #[test]
    fn batched_lane_i_matches_scalar_row_i_bit_identical() {
        let mut rng = Lcg(0xA5A5_1234_DEAD_BEEF);
        let mut batches = 0usize;
        for _ in 0..2500 {
            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
            // per-lane-DISTINCT continuous aux (cov/z/wi), signed-zero injected.
            let aux: [[f64; 3]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
            let prog = ToyProgram {
                primaries: bases.to_vec(),
                aux: aux.to_vec(),
            };
            let rows = [0usize, 1, 2, 3];

            // order-4 batch vs scalar Tower4 (instantiated through the same body).
            let batch4 = generic_batched_fourth_tower(&prog, rows).expect("batch4");
            for (row, base) in bases.iter().enumerate() {
                let vars: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
                let scal = prog.row_nll(&[row], &vars).expect("scalar tower4");
                assert_t4_bits_eq(&batch4.lane(row), &scal, "batched_fourth");
            }

            // order-3 batch vs scalar Tower3.
            let batch3 = generic_batched_third_tower(&prog, rows).expect("batch3");
            for (row, base) in bases.iter().enumerate() {
                let vars: [Tower3<2>; 2] =
                    std::array::from_fn(|a| <Tower3<2> as RowJet<2>>::variable(base[a], a));
                let scal = prog.row_nll(&[row], &vars).expect("scalar tower3");
                assert_t3_bits_eq(&batch3.lane(row), &scal, "batched_third");
            }
            batches += 1;
        }
        assert_eq!(batches, 2500);
    }

    /// The blanket impl does not churn the scalar path: the body driven through
    /// `RowJet` ops is `to_bits`-identical to the body driven directly through
    /// `JetScalar` ops, and `rowjet_row_kernel`'s `(v, g, H)` matches the dense
    /// `Tower4` lower channels.
    #[test]
    fn blanket_scalar_path_is_unchanged_and_consistent() {
        let mut rng = Lcg(0x0BAD_F00D_1357_2468);
        for _ in 0..3000 {
            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
            let aux0: [f64; 3] = std::array::from_fn(|_| rng.val());
            let prog = ToyProgram {
                primaries: vec![base],
                aux: vec![aux0],
            };

            // (a) RowJet-driven body == JetScalar-driven body, bit-for-bit. The
            // reference body uses `scale(f64)` where the RowJet body uses
            // `scale_rows(f64)`, proving the scalar `scale_rows` rewrite does not
            // churn the path (`scale_rows(s) == scale(s)` on `Value = f64`).
            let via_rowjet: Tower4<2> = {
                let vars: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
                prog.row_nll(&[0], &vars).expect("rowjet")
            };
            let via_jetscalar: Tower4<2> = {
                let vars: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as JetScalar<2>>::variable(base[a], a));
                let (cov, z, wi) = (aux0[0], aux0[1], aux0[2]);
                // The body using JetScalar's own ops + scale(f64) directly.
                let a = vars[0].mul(&vars[1]).scale(cov);
                let b = a.add(&Tower4::constant(0.5)).sub(&vars[0].scale(0.25));
                let c = b
                    .compose_unary_with(|u| {
                        let e = u.exp();
                        [e, e, e, e, e]
                    })
                    .scale(z);
                let d = JetScalar::neg(&c).add(&vars[0]);
                let e = d
                    .compose_unary_with(|u| {
                        let s = (1.0 + u * u).sqrt();
                        let s3 = s * s * s;
                        let s5 = s3 * s * s;
                        let s7 = s5 * s * s;
                        [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
                    })
                    .scale(wi);
                e.mul(&vars[1]).add(&e)
            };
            assert_t4_bits_eq(&via_rowjet, &via_jetscalar, "blanket_vs_direct");

            // (b) rowjet_row_kernel (v,g,H) == dense Tower4 lower channels.
            // Order2 and Tower4 use different internal representations so
            // signed-zero differences (−0.0 vs +0.0) may arise in gradient/
            // Hessian channels that evaluate to exactly zero; IEEE equality
            // treats these as equal, so `==` is the right comparison here.
            let (v, g, h) = rowjet_row_kernel(&prog, 0).expect("kernel");
            assert_eq!(v.to_bits(), via_rowjet.v.to_bits(), "kernel v");
            for i in 0..2 {
                assert!(
                    g[i] == via_rowjet.g[i],
                    "kernel g[{i}]: {} vs {}",
                    g[i],
                    via_rowjet.g[i]
                );
                for j in 0..2 {
                    assert!(
                        h[i][j] == via_rowjet.h[i][j],
                        "kernel h[{i}][{j}]: {} vs {}",
                        h[i][j],
                        via_rowjet.h[i][j]
                    );
                }
            }

            // (c) the Order2 scalar IS a RowJet via the blanket.
            let o2: [Order2<2>; 2] =
                std::array::from_fn(|a| <Order2<2> as RowJet<2>>::variable(base[a], a));
            let via_order2 = prog.body(&[0], &o2);
            assert_eq!(
                via_order2.0.v.to_bits(),
                via_rowjet.v.to_bits(),
                "Order2 blanket value channel must match the dense Tower4 program body"
            );
        }
    }

    /// On the scalar path (`Value = f64`) `scale_rows(s)` is `to_bits`-identical
    /// to `scale(s)` for EVERY channel, so rewriting a survival `.scale(per_row)`
    /// to `.scale_rows(per_row)` cannot perturb the existing scalar fits.
    #[test]
    fn scale_rows_scalar_is_bit_identical_to_scale() {
        let mut rng = Lcg(0xFEED_FACE_0042_1001);
        for _ in 0..3000 {
            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
            let s = rng.val();
            // Build a dense tower with populated channels (exp of a product).
            let vars: [Tower4<2>; 2] =
                std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
            let jet = vars[0].mul(&vars[1]).compose_unary_with(|u| {
                let e = u.exp();
                [e, e, e, e, e]
            });
            let via_scale = RowJet::scale(&jet, s);
            let via_scale_rows = RowJet::scale_rows(&jet, s);
            assert_t4_bits_eq(&via_scale_rows, &via_scale, "scale_rows==scale");
        }
    }

    /// `scale_rows` on a batch multiplies lane `i` by `s[i]`, so lane `i` of a
    /// per-lane-scaled batch matches the scalar `scale(s[i])` on row `i`: the
    /// continuous per-row data path the single-`f64` `scale` could not carry.
    #[test]
    fn batched_scale_rows_matches_per_row_scalar_scale() {
        let mut rng = Lcg(0x1357_9BDF_2468_ACE0);
        for _ in 0..2500 {
            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
            let s: [f64; 4] = std::array::from_fn(|_| rng.val());
            let batch: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
                Tower4Batch::variable(
                    wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]),
                    a,
                )
            });
            let prod = batch[0].mul(&batch[1]).compose_unary_with(|u| {
                let e = u.exp();
                [e, e, e, e, e]
            });
            let scaled = prod.scale_rows(s);
            for (row, base) in bases.iter().enumerate() {
                let v: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
                let prod_s = v[0].mul(&v[1]).compose_unary_with(|u| {
                    let e = u.exp();
                    [e, e, e, e, e]
                });
                let ref_s = RowJet::scale(&prod_s, s[row]);
                assert_t4_bits_eq(&scaled.lane(row), &ref_s, "batched_scale_rows");
            }
        }
    }

    /// The per-lane guard reports exactly the failing lanes on a batch and the
    /// single lane on a scalar jet.
    #[test]
    fn guard_reports_per_lane_failures() {
        let cols: [[f64; 2]; 4] = [[1.0, 0.5], [-2.0, 0.5], [3.0, 0.5], [-0.0, 0.5]];
        let vars: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
            Tower4Batch::variable(
                wide::f64x4::new([cols[0][a], cols[1][a], cols[2][a], cols[3][a]]),
                a,
            )
        });
        let verdict = vars[0].guard(|v| v > 0.0);
        assert_eq!(verdict.lanes(), 4);
        assert!(verdict.any_failed());
        assert!(!verdict.all_pass());
        assert!(!verdict.lane_failed(0));
        assert!(verdict.lane_failed(1));
        assert!(!verdict.lane_failed(2));
        assert!(verdict.lane_failed(3));
        assert_eq!(verdict.failed_mask(), 0b1010);

        let s_ok = <Tower4<2> as RowJet<2>>::variable(1.0, 0);
        let s_bad = <Tower4<2> as RowJet<2>>::variable(-1.0, 0);
        assert!(RowJet::guard(&s_ok, |v| v > 0.0).all_pass());
        assert!(RowJet::guard(&s_bad, |v| v > 0.0).any_failed());
        assert_eq!(RowJet::guard(&s_ok, |v| v > 0.0).lanes(), 1);
    }
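
    // Illustrative sketch of the mask convention the test above relies on:
    // bit `i` is set when lane `i` fails the predicate, so failing lanes 1
    // and 3 give `0b1010`. `failed_mask_sketch` is a hypothetical stand-alone
    // helper, not the crate's verdict type; the `u8` return type is an
    // assumption made for this example.
    fn failed_mask_sketch(lanes: [f64; 4], ok: impl Fn(f64) -> bool) -> u8 {
        let mut mask = 0u8;
        for (i, &v) in lanes.iter().enumerate() {
            if !ok(v) {
                mask |= 1 << i;
            }
        }
        mask
    }

    #[test]
    fn failed_mask_sketch_matches_lane_convention() {
        // `-0.0 > 0.0` is false, so the signed zero in lane 3 fails the guard.
        assert_eq!(failed_mask_sketch([1.0, -2.0, 3.0, -0.0], |v| v > 0.0), 0b1010);
    }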
5114
5115    // ── ln_gamma_derivative_stack / digamma_derivative_stack / trigamma_derivative_stack ──
5116
5117    #[test]
    fn ln_gamma_derivative_stack_known_values_at_1() {
        let s = ln_gamma_derivative_stack(1.0);
        // ln Γ(1) = 0; statrs uses Lanczos so the result is within ULP noise
        assert!(s[0].abs() < 1e-14, "ln_gamma(1) must be ~0, got {}", s[0]);
        // ψ₀(1) = -γ  (Euler–Mascheroni)
        let euler_mascheroni = 0.577_215_664_901_532_9_f64;
        assert!(
            (s[1] + euler_mascheroni).abs() < 1e-10,
            "digamma(1) ≈ -{euler_mascheroni:.6}, got {}",
            s[1]
        );
        // ψ₁(1) = π²/6
        let pi2_6 = std::f64::consts::PI * std::f64::consts::PI / 6.0;
        assert!(
            (s[2] - pi2_6).abs() < 1e-10,
            "trigamma(1) ≈ {pi2_6:.6}, got {}",
            s[2]
        );
    }

    #[test]
    fn ln_gamma_derivative_stack_known_values_at_2() {
        let s = ln_gamma_derivative_stack(2.0);
        // ln Γ(2) = ln(1!) = 0 mathematically; allow rounding noise as at x = 1
        assert!(s[0].abs() < 1e-14, "ln_gamma(2) must be ~0, got {}", s[0]);
        // ψ₀(2) = 1 − γ (recurrence: ψ₀(x+1) = ψ₀(x) + 1/x)
        let euler_mascheroni = 0.577_215_664_901_532_9_f64;
        let digamma_2 = 1.0 - euler_mascheroni;
        assert!(
            (s[1] - digamma_2).abs() < 1e-10,
            "digamma(2) ≈ {digamma_2:.6}, got {}",
            s[1]
        );
    }

    #[test]
    fn ln_gamma_derivative_stack_order2_is_prefix() {
        for &x in &[0.5_f64, 1.0, 2.0, 5.0] {
            let full = ln_gamma_derivative_stack(x);
            let ord2 = ln_gamma_derivative_stack_order2(x);
            assert_eq!(ord2[0], full[0], "order2[0] != full[0] at x={x}");
            assert_eq!(ord2[1], full[1], "order2[1] != full[1] at x={x}");
            assert_eq!(ord2[2], full[2], "order2[2] != full[2] at x={x}");
        }
    }

    #[test]
    fn digamma_derivative_stack_overlaps_ln_gamma_stack() {
        // The two stacks share a run of four polygamma values:
        // ln_gamma_stack[1..5] == digamma_stack[0..4]
        for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
            let lg = ln_gamma_derivative_stack(x);
            let dg = digamma_derivative_stack(x);
            for i in 0..4 {
                assert_eq!(
                    lg[i + 1],
                    dg[i],
                    "ln_gamma_stack[{}] != digamma_stack[{}] at x={x}",
                    i + 1,
                    i
                );
            }
        }
    }

    #[test]
    fn trigamma_derivative_stack_overlaps_digamma_stack() {
        // digamma_stack[1..5] == trigamma_stack[0..4]
        for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
            let dg = digamma_derivative_stack(x);
            let tg = trigamma_derivative_stack(x);
            for i in 0..4 {
                assert_eq!(
                    dg[i + 1],
                    tg[i],
                    "digamma_stack[{}] != trigamma_stack[{}] at x={x}",
                    i + 1,
                    i
                );
            }
        }
    }

    #[test]
    fn derivative_stacks_all_finite_at_positive_inputs() {
        for &x in &[0.01_f64, 0.5, 1.0, 2.0, 10.0, 100.0] {
            for v in ln_gamma_derivative_stack(x) {
                assert!(v.is_finite(), "ln_gamma_stack non-finite at x={x}: {v}");
            }
            for v in digamma_derivative_stack(x) {
                assert!(v.is_finite(), "digamma_stack non-finite at x={x}: {v}");
            }
            for v in trigamma_derivative_stack(x) {
                assert!(v.is_finite(), "trigamma_stack non-finite at x={x}: {v}");
            }
        }
    }
}

// ── Contraction-symmetry optimization gate ────────────────────────────────────
//
// `Tower4::third_contracted` / `fourth_contracted` contract the (fully
// index-symmetric) `t3`/`t4` tensors against directions, leaving the output
// indices `(a, b)` / `(i, j)` free. Those free indices inherit the tensor's
// symmetry — `out[a][b] == out[b][a]` term-for-term — so only the upper triangle
// need be summed and the lower triangle mirrored. Unlike the dense symmetric
// FILL (which needs a K⁴ scatter and loses inner-loop vectorisation, and was
// measured SLOWER), the mirror here is a tiny K×K copy and the inner contraction
// is untouched (contiguous, vectorisable). This is BIT-IDENTICAL to the full
// nest, so it needs no fingerprint re-baseline; the gate is (1) bit-identity vs
// the full reference and (2) a measured wall-clock that is not slower.
#[cfg(test)]
mod contraction_symmetry_tests {
    use super::*;

    struct Rng(u64);
    impl Rng {
        fn u(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
        fn s(&mut self) -> f64 {
            (self.u() - 0.5) * 4.0
        }
    }

    /// Random VALID fully-symmetric `Tower4<K>` (symmetric `h`/`t3`/`t4`).
    fn rand_sym4<const K: usize>(r: &mut Rng) -> Tower4<K> {
        let mut t = Tower4::<K>::zero();
        t.v = r.s();
        for i in 0..K {
            t.g[i] = r.s();
        }
        for a in 0..K {
            for b in a..K {
                let v2 = r.s();
                t.h[a][b] = v2;
                t.h[b][a] = v2;
                for c in b..K {
                    let v3 = r.s();
                    for p in perms3([a, b, c]) {
                        t.t3[p[0]][p[1]][p[2]] = v3;
                    }
                    for d in c..K {
                        let v4 = r.s();
                        for p in perms4([a, b, c, d]) {
                            t.t4[p[0]][p[1]][p[2]][p[3]] = v4;
                        }
                    }
                }
            }
        }
        t
    }

    fn perms3(idx: [usize; 3]) -> [[usize; 3]; 6] {
        let [a, b, c] = idx;
        [
            [a, b, c],
            [a, c, b],
            [b, a, c],
            [b, c, a],
            [c, a, b],
            [c, b, a],
        ]
    }
    fn perms4(idx: [usize; 4]) -> [[usize; 4]; 24] {
        let [a, b, c, d] = idx;
        [
            [a, b, c, d],
            [a, b, d, c],
            [a, c, b, d],
            [a, c, d, b],
            [a, d, b, c],
            [a, d, c, b],
            [b, a, c, d],
            [b, a, d, c],
            [b, c, a, d],
            [b, c, d, a],
            [b, d, a, c],
            [b, d, c, a],
            [c, a, b, d],
            [c, a, d, b],
            [c, b, a, d],
            [c, b, d, a],
            [c, d, a, b],
            [c, d, b, a],
            [d, a, b, c],
            [d, a, c, b],
            [d, b, a, c],
            [d, b, c, a],
            [d, c, a, b],
            [d, c, b, a],
        ]
    }

    /// Full-nest reference (the pre-opt `a, b ∈ 0..K` form).
    fn third_full<const K: usize>(t: &Tower4<K>, dir: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in 0..K {
                let mut acc = 0.0;
                for c in 0..K {
                    acc += t.t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
            }
        }
        out
    }
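    // Illustrative sketch only (not the crate's `Tower4::third_contracted`,
    // whose body is not shown here): the upper-triangle-plus-mirror form that
    // the gate comment above describes, written against a raw `t3` array. It
    // assumes `t3` is fully index-symmetric; `third_upper_mirror_sketch` is a
    // hypothetical name used for illustration.
    #[allow(dead_code)]
    fn third_upper_mirror_sketch<const K: usize>(
        t3: &[[[f64; K]; K]; K],
        dir: &[f64; K],
    ) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in a..K {
                // Same inner loop and summation order as the full nest.
                let mut acc = 0.0;
                for c in 0..K {
                    acc += t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
                // Mirror: for symmetric `t3` the full nest would compute the
                // same sum for `[b][a]`, so a copy reproduces it bit-for-bit.
                out[b][a] = acc;
            }
        }
        out
    }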
    fn fourth_full<const K: usize>(t: &Tower4<K>, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                for k in 0..K {
                    for l in 0..K {
                        acc += t.t4[i][j][k][l] * u[k] * w[l];
                    }
                }
                out[i][j] = acc;
            }
        }
        out
    }

    /// Returns the number of bit-equality comparisons performed (`n·K·K·2`), so
    /// the caller can assert the intended workload actually ran: a generic
    /// (turbofish) helper call hides its internal assertions, so the count is
    /// surfaced and checked at the call site.
    fn check_bit_identical<const K: usize>(seed: u64, n: usize) -> usize {
        let mut r = Rng(seed);
        let mut checks = 0usize;
        for _ in 0..n {
            let t = rand_sym4::<K>(&mut r);
            let dir: [f64; K] = std::array::from_fn(|_| r.s());
            let u: [f64; K] = std::array::from_fn(|_| r.s());
            let w: [f64; K] = std::array::from_fn(|_| r.s());
            let t3_sym = t.third_contracted(&dir);
            let t3_full = third_full(&t, &dir);
            let t4_sym = t.fourth_contracted(&u, &w);
            let t4_full = fourth_full(&t, &u, &w);
            for a in 0..K {
                for b in 0..K {
                    assert_eq!(
                        t3_sym[a][b].to_bits(),
                        t3_full[a][b].to_bits(),
                        "third K={K} [{a}][{b}]"
                    );
                    assert_eq!(
                        t4_sym[a][b].to_bits(),
                        t4_full[a][b].to_bits(),
                        "fourth K={K} [{a}][{b}]"
                    );
                    checks += 2;
                }
            }
        }
        checks
    }

    /// The output-symmetric contraction is BIT-IDENTICAL to the full nest across
    /// `K ∈ {2,3,4,9}` (so no fingerprint re-baseline is owed — accuracy and bits
    /// are unchanged; this is a pure speed-only optimization).
    #[test]
    fn contraction_symmetry_is_bit_identical_to_full_nest() {
        let checks = check_bit_identical::<2>(0x0000_0002_C0FF_EE01, 1000)
            + check_bit_identical::<3>(0x0000_0003_C0FF_EE01, 800)
            + check_bit_identical::<4>(0x0000_0004_C0FF_EE01, 600)
            + check_bit_identical::<9>(0x0000_0009_C0FF_EE01, 300);
        // Guards against the loops silently not running (e.g. a zeroed count):
        // 1000·2²·2 + 800·3²·2 + 600·4²·2 + 300·9²·2.
        assert_eq!(checks, 8000 + 14400 + 19200 + 48600);
    }

    /// Measure the wall-clock time of the output-symmetric contraction vs the
    /// full nest at `K = 9`, where it sums 45 of the 81 output entries (~1.8×
    /// fewer inner contractions); the bit-identity test is the correctness gate.
    /// Informational only, since wall-clock is noisy: the sole assertion is a
    /// PATHOLOGICAL-regression guard (the symmetric form does strictly fewer
    /// inner contractions, so it must not be more than 1.5× slower).
    #[test]
    fn contraction_symmetry_speedup_is_reported() {
        const K: usize = 9;
        let mut r = Rng(0xC0FF_EE99_1234_5678);
        let towers: Vec<Tower4<K>> = (0..512).map(|_| rand_sym4::<K>(&mut r)).collect();
        let dir: [f64; K] = std::array::from_fn(|_| r.s());
        let u: [f64; K] = std::array::from_fn(|_| r.s());
        let w: [f64; K] = std::array::from_fn(|_| r.s());

        let reps = 400usize;
        let t_sym = {
            let start = std::time::Instant::now();
            let mut sink = 0.0f64;
            for _ in 0..reps {
                for t in &towers {
                    let o3 = std::hint::black_box(t).third_contracted(std::hint::black_box(&dir));
                    let o4 = std::hint::black_box(t)
                        .fourth_contracted(std::hint::black_box(&u), std::hint::black_box(&w));
                    sink += o3[0][K - 1] + o4[0][K - 1];
                }
            }
            std::hint::black_box(sink);
            start.elapsed().as_secs_f64()
        };
        let t_full = {
            let start = std::time::Instant::now();
            let mut sink = 0.0f64;
            for _ in 0..reps {
                for t in &towers {
                    let o3 = third_full(std::hint::black_box(t), std::hint::black_box(&dir));
                    let o4 = fourth_full(
                        std::hint::black_box(t),
                        std::hint::black_box(&u),
                        std::hint::black_box(&w),
                    );
                    sink += o3[0][K - 1] + o4[0][K - 1];
                }
            }
            std::hint::black_box(sink);
            start.elapsed().as_secs_f64()
        };
        let calls = (reps * towers.len()) as f64;
        eprintln!(
            "[contraction-symmetry speedup K=9] sym={:.1}ns/call full={:.1}ns/call \
             wall_speedup={:.2}x",
            t_sym / calls * 1e9,
            t_full / calls * 1e9,
            t_full / t_sym
        );
        assert!(
            t_sym <= t_full * 1.5,
            "output-symmetric contraction pathologically slower: \
             sym={t_sym:.4}s full={t_full:.4}s"
        );
    }
}