// gam_math/jet_tower.rs
//! Taylor-jet tower algebra: write each family's row log-likelihood ONCE,
//! derive the entire `RowKernel<K>` derivative tower mechanically (#932).
//!
//! # The object
//!
//! [`Tower4<K>`] is a truncated multivariate Taylor scalar in `K` primary
//! variables, carrying the value and ALL partial derivatives through fourth
//! order as full (unsymmetrized) tensors:
//!
//! ```text
//!   v        ℓ
//!   g[a]     ∂ℓ/∂p_a
//!   h[a][b]  ∂²ℓ/∂p_a∂p_b
//!   t3[abc]  ∂³ℓ/∂p_a∂p_b∂p_c
//!   t4[abcd] ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d
//! ```
//!
//! Arithmetic (`+ − × ÷`, scalar mixes) propagates the tower by the exact
//! Leibniz rule; unary transcendentals propagate by the exact multivariate
//! Faà di Bruno formula given a `[f, f′, f″, f‴, f⁗]` stack evaluated at the
//! inner value. This is truncated Taylor ALGEBRA — exact derivatives of the
//! evaluated expression, not finite differences, not an approximation —
//! fully compatible with the exact-REML-only policy.
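/*!

With a single variable (`K = 1`), every label names the same axis. `D_T` then
depends only on `|T|`, and the `C(m, r)` subsets of size `r` contribute
identical terms, so the subset-sum Leibniz rule collapses to the binomial form
`D^m(ab) = Σ_r C(m, r) · D^r(a) · D^(m−r)(b)`. The minimal standalone sketch
below assumes a hypothetical `Jet1` type that is not part of this module. It
shows the truncated algebra returning exact derivatives with no finite
differencing:

```rust
/// Hypothetical univariate analogue of `Tower4<1>` (not this module's API):
/// `d[m]` holds the m-th derivative, not the Taylor coefficient, for m = 0..=4.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Jet1 {
    d: [f64; 5],
}

impl Jet1 {
    /// Constant: value `c`, every derivative zero.
    fn constant(c: f64) -> Self {
        Jet1 { d: [c, 0.0, 0.0, 0.0, 0.0] }
    }

    /// Seeded variable: value `x`, unit first derivative.
    fn variable(x: f64) -> Self {
        Jet1 { d: [x, 1.0, 0.0, 0.0, 0.0] }
    }

    /// Truncated Leibniz product: D^m(ab) = sum over r of C(m, r) D^r(a) D^(m-r)(b).
    fn mul(&self, o: &Self) -> Self {
        // Rows of Pascal's triangle: C[m][r] = C(m, r) for m, r in 0..=4.
        const C: [[f64; 5]; 5] = [
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [1.0, 1.0, 0.0, 0.0, 0.0],
            [1.0, 2.0, 1.0, 0.0, 0.0],
            [1.0, 3.0, 3.0, 1.0, 0.0],
            [1.0, 4.0, 6.0, 4.0, 1.0],
        ];
        let mut out = [0.0; 5];
        for m in 0..5 {
            let mut acc = 0.0;
            for r in 0..=m {
                acc += C[m][r] * self.d[r] * o.d[m - r];
            }
            out[m] = acc;
        }
        Jet1 { d: out }
    }
}
```

Squaring `Jet1::variable(2.0)` twice yields `[16, 32, 48, 48, 24]`. These are
the value of `x⁴` at 2 and its first four derivatives, exactly.
*/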
//!
//! One evaluation of a row NLL program at seeded variables yields, in a
//! single pass, every channel the [`super::row_kernel::RowKernel`] trait
//! demands: `row_kernel` (value/∇/H), `row_third_contracted(dir)` (contract
//! `t3` with `dir`), and `row_fourth_contracted(u, v)` (contract `t4` with
//! `u` and `v`). The directional cross-channels that hand-written towers
//! drop (#736's residual gap) cannot be dropped here: there is no separate
//! "channel" to forget — every derivative of the one expression is carried.
//!
//! # Why this exists (the bug genus)
//!
//! Every family today hand-writes its tower: value in one function,
//! gradient in another, `pdfthird_derivative`/`pdffourth_derivative`,
//! entry/exit-specific cross blocks — thousands of lines of calculus that
//! drift. #736 was a sign flip in a hand-written cross-Hessian block,
//! invisible until a new consumer touched it; #948 is a derivative path
//! that is not the derivative of the evaluated row loss (clamped-μ
//! surrogate); the objective↔gradient desync class is the same disease at
//! the criterion level. A tower-derived kernel is exact-by-construction:
//! the value channel IS the production loss expression, so its derivative
//! channels cannot desync from it.
//!
//! # Relation to `jet_partitions::MultiDirJet`
//!
//! The tree already carries a *directional* jet (bitmask coefficients over
//! distinct seeded directions, heap-allocated, Bell-partition compose) used
//! inside the marginal-slope and latent-survival families. It answers "the
//! derivative along THESE specific directions" and must be re-seeded and
//! re-evaluated per direction tuple (e.g. 10 symmetric `(a,b)` pairs for a
//! K=4 fourth contraction). `Tower4` answers ALL of them from one
//! evaluation: contraction happens AFTER differentiation, as plain linear
//! algebra on the stored tensors. Use `MultiDirJet` when you need a handful
//! of directions of a huge-K expression; use `Tower4` when you need the
//! complete small-K tower — which is exactly the `RowKernel<K≤4>` shape.
//! The `[f64; 5]` unary-derivative stacks
//! (`unary_derivatives_neglog_phi`, …) are signature-compatible with
//! [`Tower4::compose_unary`], so the families' existing special-function
//! stacks are directly reusable.
//!
//! # Stability discipline (why this is NOT autodiff)
//!
//! Differentiating the primal code path inherits its instabilities: a jet
//! pushed through a naive `ln(1 + e^η)` is garbage in the saturated tail
//! even though the true derivative σ(η) is benign there. This module
//! therefore splits responsibility: **humans own primitive stability,
//! the algebra owns combinatorics**. Tail-critical special functions enter
//! a program ONLY as hand-certified `[f64; 5]` derivative stacks through
//! [`Tower4::compose_unary`] — the same stacks the families already write
//! (`unary_derivatives_neglog_phi` and friends, built on erfcx/log_ndtr) —
//! and the tower mechanizes only the Leibniz/Faà di Bruno composition,
//! which is where hand-written towers actually fail (#736 was a
//! composition sign flip, not a primitive error). Program authors must use
//! a stable primitive stack wherever the f64 production loss does; the
//! convenience methods (`exp`, `ln`, `sqrt`, …) are for expressions whose
//! arguments are tame by construction.
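/*!

Softplus `f(η) = ln(1 + e^η)` makes this split concrete. The sketch below
shows the kind of stack meant here, in the `[f64; 5]` shape `compose_unary`
takes. It is a minimal illustration, not a certified production stack, and
the name `softplus_stack` is hypothetical (no such helper is claimed to exist
in the tree). Every derivative is written through σ(η) and 1 − σ(η):
`f′ = σ`, `f″ = σ(1−σ)`, `f‴ = σ(1−σ)(1−2σ)`, `f⁗ = σ(1−σ)(1 − 6σ(1−σ))`.
Both are built from `e^{−|η|} ≤ 1`, so neither tail overflows:

```rust
/// Hypothetical stack [f, f', f'', f''', f''''] for softplus(eta) = ln(1 + e^eta),
/// in the `[f64; 5]` shape `compose_unary` takes (illustrative, not the tree's API).
fn softplus_stack(eta: f64) -> [f64; 5] {
    // e^(-|eta|) is at most 1, so this never overflows in either tail.
    let e = (-eta.abs()).exp();
    // sigma(eta) and its complement 1 - sigma(eta), each formed by a single
    // division rather than by subtracting two nearly equal numbers.
    let (sig, comp) = if eta >= 0.0 {
        (1.0 / (1.0 + e), e / (1.0 + e))
    } else {
        (e / (1.0 + e), 1.0 / (1.0 + e))
    };
    // Stable primal: max(eta, 0) + ln(1 + e^(-|eta|)).
    let f0 = eta.max(0.0) + e.ln_1p();
    let s1 = sig * comp; // sigma' = sigma (1 - sigma)
    [
        f0,
        sig,
        s1,
        s1 * (comp - sig),             // sigma (1 - sigma)(1 - 2 sigma)
        s1 * (1.0 - 6.0 * sig * comp), // sigma (1 - sigma)(1 - 6 sigma (1 - sigma))
    ]
}
```

At η = 800 the naive `(1.0 + η.exp()).ln()` is `inf`, while this stack returns
800 with σ = 1. At η = −40 the naive form rounds to `0.0`, while the stack keeps
the correct ≈ 4.2e−18.
*/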
//!
//! # Storage convention
//!
//! Tensors are stored FULL, not symmetric-packed: `t4` for K=4 is 256
//! doubles where 35 would do. This is deliberate clarity-over-speed for the
//! oracle role — indexing is trivially auditable, contraction loops are
//! obvious, and the redundancy is itself a checked invariant (the algebra
//! only ever writes values that are symmetric up to floating-point rounding,
//! since permuted entries may accumulate the same terms in a different order).
//! Symmetric packing is a later, profile-justified optimization behind the
//! same API.
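/*!

The 35-versus-256 figure compares two counts. A fully symmetric order-`m`
tensor over `K` axes has one distinct entry per multiset of size `m`, which is
`C(K + m − 1, m)`. The full tensor stores `K^m` entries. The small sketch
below tabulates both. The function names are illustrative only:

```rust
/// Binomial coefficient C(n, k) in exact integer arithmetic.
fn binomial(n: u64, k: u64) -> u64 {
    // After step i, acc = C(n - k + i, i), an integer, so each division is exact.
    let mut acc = 1u64;
    for i in 1..=k {
        acc = acc * (n - k + i) / i;
    }
    acc
}

/// Distinct entries of a fully symmetric order-`m` tensor in `k` variables.
fn packed_len(k: u64, m: u64) -> u64 {
    binomial(k + m - 1, m)
}

/// Entries of the same tensor stored FULL, as `Tower4` does: k^m.
fn full_len(k: u64, m: u32) -> u64 {
    k.pow(m)
}
```

At K=4, the whole tower (orders 0 through 4) holds 1 + 4 + 10 + 20 + 35 = 70
distinct values in 1 + 4 + 16 + 64 + 256 = 341 stored slots.
*/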
//!
//! # Deployment ladder (#932)
//!
//! 1. This module: the algebra + the program seam + the oracle.
//! 2. Universal oracle: every hand-written `RowKernel` gains a CI test
//!    asserting channel-by-channel agreement with a `RowNllProgram` written
//!    once — see [`verify_kernel_channels`]. This alone would have caught
//!    #736 at introduction.
//! 3. Migrate error-dense / cold towers to [`derived_row_kernel`] et al.;
//!    keep hand-tuned hot paths, now verified against the single-expression
//!    truth instead of being the only definition.
//! 4. New families (#914/#916/#917 ZI/ordinal/expectile, #921's
//!    location-scale port) implement ONLY `RowNllProgram` and get an exact
//!    fourth-order tower for the price of writing the likelihood.

use crate::jet_algebra;

/// Truncated fourth-order multivariate Taylor scalar in `K` variables.
///
/// See the module documentation for semantics and conventions. `Copy` is
/// intentional despite the size (341 doubles, ≈2.7 KiB, at K=4): towers are
/// per-row temporaries that live on the stack during a row program, and
/// value semantics keep program code readable (`a * b + c`).
#[derive(Clone, Copy, Debug)]
pub struct Tower4<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
    pub t3: [[[f64; K]; K]; K],
    /// Fourth derivatives ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d (fully symmetric).
    pub t4: [[[[f64; K]; K]; K]; K],
}

impl<const K: usize> Tower4<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
            t3: [[[0.0; K]; K]; K],
            t4: [[[[0.0; K]; K]; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the (fully symmetric) derivative tensor entry whose differentiation
    /// axes are `labels` (length 0..=4): value, `g`, `h`, `t3`, `t4`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 4,
            "Tower4 carries at most fourth-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            2 => self.h[labels[0]][labels[1]],
            3 => self.t3[labels[0]][labels[1]][labels[2]],
            _ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
        }
    }

    /// Exact truncated Leibniz product `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`.
    ///
    /// # Codegen
    ///
    /// Each output entry's `2^m` subset sum is written as a compact straight-line
    /// expression instead of the shared [`jet_algebra::leibniz_product`] subset
    /// walker (which, per entry, builds `SlotBuf`s and `match`-dispatches the
    /// `deriv` closure across all `2^m` subsets). The loop nest over `(i,j,k,l)`
    /// is unchanged — only the inner per-entry sum is unrolled — so this does NOT
    /// unroll over `K` and does NOT bloat code: on a `Tower4<9>` mul-and-read
    /// consumer the new form is faster AND smaller (asm: 34 outlined walker `bl`
    /// calls → 0, 21.1 KiB → 14.3 KiB, +100 NEON `.2d` ops).
    ///
    /// BIT-IDENTICAL to the walker: each entry's terms are in the walker's exact
    /// subset-enumeration order (subset bit `b` ↔ position `b`, `sub = 0..2^m`),
    /// and the per-entry `acc` accumulator mirrors the walker's `total = 0.0`
    /// start so a signed-zero leading product collapses to `+0.0` identically —
    /// which matters because real jets carry exact-`0.0` channels
    /// (`constant`/`variable` towers). Proven `to_bits`-identical on
    /// `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`, 5000 inputs each with ~30 %
    /// exact-`0.0` channels and signed values (a no-leading-`0.0` form fails this
    /// stress — the accumulator start is load-bearing). One caveat: the lower
    /// triangle of `h` is mirrored from the upper triangle, so it is exactly
    /// symmetric. A walker that evaluates `h[j][i]` independently accumulates the
    /// two cross products in the opposite order and can therefore differ from the
    /// mirror by rounding there.
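    /**

The load-bearing accumulator start follows from IEEE 754 signed-zero addition
under round-to-nearest: `-0.0 + -0.0` is `-0.0`, but `+0.0 + -0.0` is `+0.0`.
The standalone sketch below uses two illustrative helper names. It shows that
the same terms summed directly and summed from a `+0.0` accumulator compare
equal, yet differ in `to_bits`:

```rust
/// Sum terms left to right, starting from the first term itself.
fn direct_sum(terms: &[f64]) -> f64 {
    let mut it = terms.iter();
    let first = *it.next().expect("at least one term");
    it.fold(first, |acc, t| acc + *t)
}

/// Sum terms left to right from a `+0.0` accumulator (the walker's
/// `total = 0.0` start, mirrored by `acc` in the unrolled `mul`).
fn accumulator_sum(terms: &[f64]) -> f64 {
    let mut acc = 0.0;
    for t in terms {
        acc += *t;
    }
    acc
}
```

A leading product such as `(-1.0) * 0.0`, for example `a.v * b.g[i]` against
an exact-zero channel, is exactly such a `-0.0` term.
    */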
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            // subsets of {i}: {} {i}
            let mut acc = 0.0;
            acc += a.v * b.g[i];
            acc += a.g[i] * b.v;
            out.g[i] = acc;
        }
        // Hessian is symmetric under i↔j; compute the upper triangle and mirror
        // (see [`Tower2::mul`] — same term order, enforces exact symmetry).
        for i in 0..K {
            for j in i..K {
                // subsets of {i,j}: {} {i} {j} {ij}
                let mut acc = 0.0;
                acc += a.v * b.h[i][j];
                acc += a.g[i] * b.g[j];
                acc += a.g[j] * b.g[i];
                acc += a.h[i][j] * b.v;
                out.h[i][j] = acc;
                out.h[j][i] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                    let mut acc = 0.0;
                    acc += a.v * b.t3[i][j][k];
                    acc += a.g[i] * b.h[j][k];
                    acc += a.g[j] * b.h[i][k];
                    acc += a.h[i][j] * b.g[k];
                    acc += a.g[k] * b.h[i][j];
                    acc += a.h[i][k] * b.g[j];
                    acc += a.h[j][k] * b.g[i];
                    acc += a.t3[i][j][k] * b.v;
                    out.t3[i][j][k] = acc;
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // subsets of {i,j,k,l} in bit order sub = 0..16
                        let mut acc = 0.0;
                        acc += a.v * b.t4[i][j][k][l];
                        acc += a.g[i] * b.t3[j][k][l];
                        acc += a.g[j] * b.t3[i][k][l];
                        acc += a.h[i][j] * b.h[k][l];
                        acc += a.g[k] * b.t3[i][j][l];
                        acc += a.h[i][k] * b.h[j][l];
                        acc += a.h[j][k] * b.h[i][l];
                        acc += a.t3[i][j][k] * b.g[l];
                        acc += a.g[l] * b.t3[i][j][k];
                        acc += a.h[i][l] * b.h[j][k];
                        acc += a.h[j][l] * b.h[i][k];
                        acc += a.t3[i][j][l] * b.g[k];
                        acc += a.h[k][l] * b.h[i][j];
                        acc += a.t3[i][k][l] * b.g[j];
                        acc += a.t3[j][k][l] * b.g[i];
                        acc += a.t4[i][j][k][l] * b.v;
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }

    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
    /// out of the borrowed operands.
    pub fn add(&self, o: &Self) -> Self {
        *self + *o
    }

    /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
    pub fn sub(&self, o: &Self) -> Self {
        *self + o.scale(-1.0)
    }

    /// Exact multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` evaluated at `u = self.v` —
    /// the SAME `[f64; 5]` stack shape the families' existing
    /// `unary_derivatives_*` helpers produce, so those special-function
    /// stacks (Φ, log-Φ, normal pdf, …) plug in directly.
    ///
    /// The order-m output sums over the set partitions of the m indices
    /// (Bell(3) = 5 terms at order 3, Bell(4) = 15 at order 4), grouped by
    /// block count: each partition into r blocks contributes
    /// `f⁽ʳ⁾ · Π_blocks D_block(u)`.
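    /**

With a single variable (`K = 1`), the partitions collapse by block shape. At
order 4 there is 1 one-block partition, 4 of shape {3,1} and 3 of shape {2,2},
6 of shape {2,1,1}, and 1 into singletons (1 + 7 + 6 + 1 = 15). The sketch
below works under that `K = 1` assumption, and the function `compose_1d` is
hypothetical. It checks the collapsed formula against the closed form
`d⁴/dx⁴ e^{x²} = e^{x²}(16x⁴ + 48x² + 12)`:

```rust
/// Hypothetical univariate (K = 1) Faà di Bruno to fourth order.
/// `f` = [f(u), f'(u), f''(u), f'''(u), f''''(u)] evaluated at the inner value;
/// `u` = [u, u', u'', u''', u''''] is the inner function's derivative stack.
fn compose_1d(f: [f64; 5], u: [f64; 5]) -> [f64; 5] {
    let [_, u1, u2, u3, u4] = u;
    [
        f[0],
        f[1] * u1,
        f[1] * u2 + f[2] * u1 * u1,
        // Bell(3) = 5: {3}, three of shape {2,1}, {1,1,1}.
        f[1] * u3 + 3.0 * f[2] * u2 * u1 + f[3] * u1 * u1 * u1,
        // Bell(4) = 15: 1 + (4 + 3) + 6 + 1, grouped by block count.
        f[1] * u4
            + f[2] * (4.0 * u3 * u1 + 3.0 * u2 * u2)
            + f[3] * 6.0 * u2 * u1 * u1
            + f[4] * u1 * u1 * u1 * u1,
    ]
}
```

At x = 1 the inner stack of `x²` is `[1, 2, 2, 0, 0]`, the outer stack of `exp`
is `[e; 5]`, and the outputs are `e · [1, 2, 6, 20, 76]`.
    */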
    ///
    /// # Codegen
    ///
    /// Evaluated as a compact closed form (the Bell(4)=15 set-partitions of
    /// `t4`, Bell(3)=5 of `t3`, …) instead of routing through the recursive
    /// [`jet_algebra::faa_di_bruno`] walker (per-output `for_each_partition`
    /// recursion + per-block `SlotBuf` + closure dispatch). The loop nest is
    /// identical to the walker's (`for i,j,k,l`); only the per-entry partition
    /// sum is straight-line, so this does NOT unroll over `K` and does NOT
    /// bloat code — measured on a `Tower4<9>` compose-and-read consumer the new
    /// form is both faster and SMALLER (asm: 94 outlined walker `bl` calls → 0,
    /// 47.5 KiB → 16.7 KiB, +197 NEON `.2d` ops).
    ///
    /// BIT-IDENTICAL to the walker: each channel's terms are emitted in the
    /// walker's exact partition-enumeration order, each term's block products
    /// are left-associated exactly as the walker's `prod *= block`, and the
    /// per-channel `acc` accumulator mirrors the walker's `total = 0.0` start
    /// (so signed-zero products collapse to `+0.0` identically). The order-4
    /// term sequence was generated from the walker's own enumeration. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`,
    /// 5000 random inputs each (zeroed / sign-varied stacks included).
    pub fn compose_unary(&self, d: [f64; 5]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
                    let mut acc = 0.0;
                    acc += d[1] * self.t3[i][j][k];
                    acc += d[2] * self.h[i][j] * self.g[k];
                    acc += d[2] * self.h[i][k] * self.g[j];
                    acc += d[2] * self.g[i] * self.h[j][k];
                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
                    out.t3[i][j][k] = acc;
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // Bell(4)=15 partitions, walker enumeration order.
                        let mut acc = 0.0;
                        acc += d[1] * self.t4[i][j][k][l];
                        acc += d[2] * self.t3[i][j][k] * self.g[l];
                        acc += d[2] * self.t3[i][j][l] * self.g[k];
                        acc += d[2] * self.h[i][j] * self.h[k][l];
                        acc += d[3] * self.h[i][j] * self.g[k] * self.g[l];
                        acc += d[2] * self.t3[i][k][l] * self.g[j];
                        acc += d[2] * self.h[i][k] * self.h[j][l];
                        acc += d[3] * self.h[i][k] * self.g[j] * self.g[l];
                        acc += d[2] * self.h[i][l] * self.h[j][k];
                        acc += d[2] * self.g[i] * self.t3[j][k][l];
                        acc += d[3] * self.g[i] * self.h[j][k] * self.g[l];
                        acc += d[3] * self.h[i][l] * self.g[j] * self.g[k];
                        acc += d[3] * self.g[i] * self.h[j][l] * self.g[k];
                        acc += d[3] * self.g[i] * self.g[j] * self.h[k][l];
                        acc += d[4] * self.g[i] * self.g[j] * self.g[k] * self.g[l];
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }

    /// Compose with a unary special-function whose `[f64; 5]` derivative STACK is
    /// built from the base value through `stack_fn` — the scalar arm of the
    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
    /// [`Tower4Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
    /// `self.compose_unary(stack_fn(self.v))`. Writing a program against this seam
    /// lets it re-instantiate, unchanged, at [`Tower4Lane`] (where each of the four
    /// lanes carries a DISTINCT base value and `stack_fn` is re-run per lane).
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
        self.compose_unary(stack_fn(self.v))
    }

    /// Single-active-slot fast path for [`Self::compose_unary`].
    ///
    /// When the inner jet `self` has derivative support ONLY on the all-`slot`
    /// diagonal channels — i.e. it is a univariate jet in primary `slot`
    /// scattered into the `K`-wide layout (`g[a] = 0`, `h[a][b] = 0`,
    /// `t3 = 0`, `t4 = 0` for any axis `≠ slot`) — the multivariate Faà di
    /// Bruno walk collapses. Every output channel whose axis tuple contains an
    /// axis `≠ slot` is structurally `0`: each set-partition has a block
    /// covering that axis, that block reads an off-`slot` derivative of `self`
    /// (which is `0`), so the block product and the whole partition vanish, and
    /// the channel sums to the walker's `total = 0.0` start, i.e. `+0.0`. Only
    /// the five diagonal channels (`v`, `g[slot]`, `h[slot][slot]`,
    /// `t3[slot]³`, `t4[slot]⁴`) survive.
    ///
    /// This computes exactly those five as STRAIGHT-LINE accumulations, each in
    /// the EXACT term order of [`Self::compose_unary`]'s diagonal
    /// (`i = j = k = l = slot`) case — so they are BIT-IDENTICAL to
    /// [`Self::compose_unary`] on the diagonal — and leaves every other channel
    /// at the zero-init `+0.0`, which the full walk also produces (the
    /// off-`slot` collapse is `to_bits`-`+0.0`, signed-zero products included;
    /// proven across `K ∈ {2,3,4,9}`, 5000 single-slot inputs each). At any
    /// `K ≥ 2` this is far fewer floating-point operations than materialising
    /// the full `1 + K + K² + K³ + K⁴` channel set whose off-diagonal entries
    /// are all zero, and far cheaper than the recursive set-partition walker the
    /// diagonal channels previously routed through (a measured ~9.5× speedup vs
    /// the full `compose_unary`, recovering a 5.9× walker regression at the
    /// `K ∈ {2,3}` BMS tower widths).
    ///
    /// `#[inline]` so an adopting consumer pays no `bl` call (uninlined, the
    /// five-channel build does not amortise the call/spill overhead).
    ///
    /// # Precondition
    ///
    /// The caller guarantees the single-active-slot structure. If it does not
    /// hold, the off-`slot` channels would be wrongly zeroed; use the full
    /// [`Self::compose_unary`] in that case.
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [f64; 5], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        let t4 = self.t4[s][s][s][s];
        out.v = d[0];
        // g (i=s): d1*g
        out.g[s] = {
            let mut acc = 0.0;
            acc += d[1] * g;
            acc
        };
        // h (i=j=s): d1*h + d2*g*g
        out.h[s][s] = {
            let mut acc = 0.0;
            acc += d[1] * h;
            acc += d[2] * g * g;
            acc
        };
        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
        out.t3[s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t3;
            acc += d[2] * h * g;
            acc += d[2] * h * g;
            acc += d[2] * g * h;
            acc += d[3] * g * g * g;
            acc
        };
        // t4 (i=j=k=l=s): exact term order of compose_unary's inner loop.
        out.t4[s][s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t4;
            acc += d[2] * t3 * g;
            acc += d[2] * t3 * g;
            acc += d[2] * h * h;
            acc += d[3] * h * g * g;
            acc += d[2] * t3 * g;
            acc += d[2] * h * h;
            acc += d[3] * h * g * g;
            acc += d[2] * h * h;
            acc += d[2] * g * t3;
            acc += d[3] * g * h * g;
            acc += d[3] * h * g * g;
            acc += d[3] * g * h * g;
            acc += d[3] * g * g * h;
            acc += d[4] * g * g * g * g;
            acc
        };
        out
    }

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
                for k in 0..K {
                    out.t3[i][j][k] *= s;
                    for l in 0..K {
                        out.t4[i][j][k][l] *= s;
                    }
                }
            }
        }
        out
    }

    /// e^self.
    pub fn exp(&self) -> Self {
        let e = self.v.exp();
        self.compose_unary([e, e, e, e, e])
    }

    /// ln(self). Caller guarantees positivity (likelihood programs do).
    pub fn ln(&self) -> Self {
        let u = self.v;
        let r = 1.0 / u;
        self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
    }

    /// 1/self.
    pub fn recip(&self) -> Self {
        let r = 1.0 / self.v;
        let r2 = r * r;
        self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
    }

    /// √self. Caller guarantees positivity.
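    /**

Because √u = u^{1/2}, the hard-coded coefficients `0.5, −0.25, 0.375, −0.9375`
are the falling-factorial products `a(a−1)⋯(a−m+1)` at `a = 1/2`. These are the
same ones `powf(0.5)` forms. The standalone check below assumes two
illustrative free functions that mirror the two stacks rather than calling
this module:

```rust
/// Mirrors the stack `Tower4::sqrt` passes to `compose_unary`:
/// u^(1/2) and its first four derivatives (illustrative free function).
fn sqrt_stack(u: f64) -> [f64; 5] {
    let s = u.sqrt();
    [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
}

/// General power stack: the m-th derivative of u^a is a (a-1) ... (a-m+1) u^(a-m).
fn powf_stack(u: f64, a: f64) -> [f64; 5] {
    let mut out = [0.0; 5];
    let mut falling = 1.0; // a (a-1) ... (a-m+1); the empty product for m = 0
    for m in 0..5 {
        out[m] = falling * u.powf(a - m as f64);
        falling *= a - m as f64;
    }
    out
}
```

Writing the stack with `sqrt` and divisions avoids four `powf` calls per
evaluation while encoding the same coefficients.
    */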
    pub fn sqrt(&self) -> Self {
        let u = self.v;
        let s = u.sqrt();
        self.compose_unary([
            s,
            0.5 / s,
            -0.25 / (u * s),
            0.375 / (u * u * s),
            -0.9375 / (u * u * u * s),
        ])
    }

    /// self^a for real exponent `a`. Caller guarantees a positive base.
    pub fn powf(&self, a: f64) -> Self {
        let u = self.v;
        let f0 = u.powf(a);
        let f1 = a * u.powf(a - 1.0);
        let f2 = a * (a - 1.0) * u.powf(a - 2.0);
        let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
        let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
        self.compose_unary([f0, f1, f2, f3, f4])
    }

    /// ln Γ(self). Caller guarantees positivity.
    pub fn ln_gamma(&self) -> Self {
        self.compose_unary(ln_gamma_derivative_stack(self.v))
    }

    /// ψ(self), the digamma function. Caller guarantees positivity.
    pub fn digamma(&self) -> Self {
        self.compose_unary(digamma_derivative_stack(self.v))
    }

    /// ψ′(self), the trigamma function. Caller guarantees positivity.
    pub fn trigamma(&self) -> Self {
        self.compose_unary(trigamma_derivative_stack(self.v))
    }

    /// Contract `t3` with one primary-space direction:
    /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]` — exactly the
    /// `row_third_contracted` shape.
    ///
    /// The output is symmetric in `(a, b)`: `t3` is (mathematically) fully
    /// index-symmetric, so `t3[a][b][c] == t3[b][a][c]` and the `Σ_c` contraction
    /// gives `out[a][b] == out[b][a]` term-for-term, in the same `c` order. We
    /// compute only the upper triangle `a ≤ b` (the inner contraction is unchanged
    /// and stays contiguous/vectorisable) and mirror into the lower triangle. This
    /// is BIT-IDENTICAL to the full `a, b ∈ 0..K` nest whenever the stored `t3` is
    /// bit-symmetric in its first two axes (otherwise the lower triangle differs
    /// from the full nest only by rounding), while doing ~2× fewer inner
    /// contractions, with no dense scatter (the mirror is a `K × K` copy).
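    /**

The plain-array sketch below illustrates the contraction. The free functions
`contract_t3` and `t3_of_triple_product` are illustrative names, not this
module's API. It uses ℓ(p) = p₀p₁p₂, whose only nonzero third derivatives are
the six permutations of (0, 1, 2). That ℓ has Hessian entries
`H[0][1] = p₂`, `H[0][2] = p₁`, `H[1][2] = p₀`. Contracting with `dir` should
therefore give the directional derivative of `H`: `out[0][1] = dir[2]`,
`out[0][2] = dir[1]`, `out[1][2] = dir[0]`, and zero on the diagonal.

```rust
/// Contract a fully symmetric third-derivative tensor with one direction,
/// computing the upper triangle a <= b and mirroring it (illustrative).
fn contract_t3<const K: usize>(t3: &[[[f64; K]; K]; K], dir: &[f64; K]) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for a in 0..K {
        for b in a..K {
            let mut acc = 0.0;
            for c in 0..K {
                acc += t3[a][b][c] * dir[c];
            }
            out[a][b] = acc;
            out[b][a] = acc;
        }
    }
    out
}

/// t3 of l(p) = p0 * p1 * p2: one at every permutation of (0, 1, 2), zero elsewhere.
fn t3_of_triple_product() -> [[[f64; 3]; 3]; 3] {
    let mut t3 = [[[0.0; 3]; 3]; 3];
    for a in 0..3 {
        for b in 0..3 {
            for c in 0..3 {
                if a != b && b != c && a != c {
                    t3[a][b][c] = 1.0;
                }
            }
        }
    }
    t3
}
```
    */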
    pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in a..K {
                let mut acc = 0.0;
                for c in 0..K {
                    acc += self.t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
                out[b][a] = acc;
            }
        }
        out
    }

    /// Contract `t4` with two primary-space directions:
    /// `out[i][j] = Σ_{k,l} t4[i][j][k][l] · u[k] · w[l]`, exactly the
    /// `row_fourth_contracted` shape.
    ///
    /// As in [`Self::third_contracted`], the output is symmetric in `(i, j)`
    /// (`t4[j][i][k][l] == t4[i][j][k][l]`, contracted in the same `(k, l)`
    /// order), so the upper triangle `i ≤ j` is computed and mirrored. This is
    /// BIT-IDENTICAL to the full nest under the same bit-symmetry condition on
    /// the stored tensor, does ~2× fewer inner `Σ_{k,l}` contractions, and keeps
    /// the inner double loop in the original contiguous/vectorisable form.
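    /**

As a worked check of the formula, take ℓ(p) = (q·p)⁴ for a fixed vector `q`.
Every fourth derivative is `24 q_i q_j q_k q_l`, so the contraction reduces to
`out[i][j] = 24 q_i q_j (q·u)(q·w)`. The following plain-array sketch works
under that assumption. `contract_t4` and `t4_of_quartic` are illustrative
names, not this module's API:

```rust
/// Contract a fourth-derivative tensor with directions `u` and `w`,
/// computing the upper triangle i <= j and mirroring it (illustrative).
fn contract_t4<const K: usize>(
    t4: &[[[[f64; K]; K]; K]; K],
    u: &[f64; K],
    w: &[f64; K],
) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for i in 0..K {
        for j in i..K {
            let mut acc = 0.0;
            for k in 0..K {
                for l in 0..K {
                    acc += t4[i][j][k][l] * u[k] * w[l];
                }
            }
            out[i][j] = acc;
            out[j][i] = acc;
        }
    }
    out
}

/// t4 of l(p) = (q . p)^4: every entry is 24 q_i q_j q_k q_l.
fn t4_of_quartic<const K: usize>(q: &[f64; K]) -> [[[[f64; K]; K]; K]; K] {
    let mut t4 = [[[[0.0; K]; K]; K]; K];
    for i in 0..K {
        for j in 0..K {
            for k in 0..K {
                for l in 0..K {
                    t4[i][j][k][l] = 24.0 * q[i] * q[j] * q[k] * q[l];
                }
            }
        }
    }
    t4
}
```
    */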
    pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for i in 0..K {
            for j in i..K {
                let mut acc = 0.0;
                for k in 0..K {
                    for l in 0..K {
                        acc += self.t4[i][j][k][l] * u[k] * w[l];
                    }
                }
                out[i][j] = acc;
                out[j][i] = acc;
            }
        }
        out
    }
}

impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut out = Self::zero();
        out.v = f(&[]);
        for i in 0..K {
            let labels = [i];
            out.g[i] = f(&labels);
        }
        for i in 0..K {
            for j in 0..K {
                let labels = [i, j];
                out.h[i][j] = f(&labels);
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let labels = [i, j, k];
                    out.t3[i][j][k] = f(&labels);
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        let labels = [i, j, k, l];
                        out.t4[i][j][k][l] = f(&labels);
                    }
                }
            }
        }
        out
    }
}

/// Truncated SECOND-order multivariate Taylor scalar in `K` variables.
///
/// This is the value/gradient/Hessian-only sibling of [`Tower4`]. Every
/// channel it carries (`v`, `g`, `h`) is computed by the SAME formulas
/// [`Tower4`] uses for those orders, so for any program written over both
/// towers the order-≤2 outputs agree: the order-2 Leibniz and Faà-di-Bruno
/// terms read only the order-≤2 channels of their inputs (see
/// [`Tower4::mul`] / [`Tower4::compose_unary`] — `out.h` never touches `t3`
/// or `t4`), so dropping the third/fourth tensors cannot perturb the value,
/// gradient, or Hessian. The agreement is *bit-identical* up to the sign of
/// exact zeros. [`Tower4::mul`] starts each sum from a `+0.0` accumulator,
/// which turns an all-`-0.0` sum into `+0.0`. [`Tower2::mul`] instead sums its
/// terms directly and can return `-0.0` there; the two results still compare
/// equal.
///
/// It exists purely for performance: an inner Newton step (and the
/// value-only ρ-homotopy pre-warm) needs at most curvature, never the
/// outer-κ/ψ third/fourth derivatives. Evaluating a row likelihood over
/// `Tower2` skips the `K⁴` fourth-tensor product/composition arithmetic that
/// dominates the cold marginal-slope fit, while returning the exact same
/// `(v, g, h)`.
#[derive(Clone, Copy, Debug)]
pub struct Tower2<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
}

impl<const K: usize> Tower2<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the derivative tensor entry whose differentiation axes are
    /// `labels` (length 0..=2): value, `g`, `h`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 2,
            "Tower2 carries at most second-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            _ => self.h[labels[0]][labels[1]],
        }
    }

    /// Exact truncated (order ≤ 2) Leibniz product. The `v`/`g`/`h` upper
    /// triangle matches [`Tower4::mul`] term-for-term.
    ///
    /// # Symmetry fast path
    ///
    /// The order-≤2 Leibniz Hessian
    /// `h[i][j] = a.v·b.h[i][j] + a.g[i]·b.g[j] + a.g[j]·b.g[i] + a.h[i][j]·b.v`
    /// is symmetric under `i ↔ j` whenever the operand Hessians are — which they
    /// always are: `constant`/`variable` seed a symmetric (zero) `h`, and
    /// `mul`/`compose_unary`/`add`/`scale` each preserve symmetry, so the
    /// invariant holds for every tower a row program can build. We therefore
    /// compute only the upper triangle `j ≥ i` and mirror it into the lower
726    /// triangle. At the `K = 9` survival width that is `K(K+1)/2 = 45` four-product
727    /// entry evaluations instead of `K² = 81`, and the win is larger in wall-clock
728    /// because the `648`-entry `h` spills at `K = 9` — halving the expensive
729    /// stores/reloads roughly halves the kernel (measured ≈2× on a `Tower2<9>`
730    /// mul-and-read throughput microbench; the dominant `mul` under every packed
731    /// scalar bottoms out here).
732    ///
733    /// The upper-triangle entries are BIT-IDENTICAL to the old rectangular form
734    /// (same term/accumulation order). The lower triangle now equals its mirror
735    /// exactly, where the rectangular form rounded `h[i][j]` and `h[j][i]`
736    /// independently (the two cross products accumulate in opposite order) and
737    /// left a ≤1-ulp asymmetry; mirroring removes it, so the result is exactly
738    /// symmetric — strictly closer to the true symmetric Hessian, not merely a
739    /// reordering. Dense-`h` consumers are all tolerance-gated (rel-tol ≥ 1e-11 ≫
740    /// 1e-16); the `f64`/`f64x4` lane oracle stays exact because
741    /// [`crate::jet_scalar::Order2Lane::mul`] mirrors term-for-term.
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
        }
        for i in 0..K {
            for j in i..K {
                let hij =
                    a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
                out.h[i][j] = hij;
                out.h[j][i] = hij;
            }
        }
        out
    }

    /// Exact (order ≤ 2) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u)]` evaluated at `u = self.v`. The `v`/`g`/`h`
    /// channels match [`Tower4::compose_unary`] term-for-term (which uses only
    /// `d[0..=2]` for those orders), so this is a strict truncation, not an
    /// approximation. The full-order `[f64; 5]` derivative stacks the families
    /// already produce can be passed by taking their first three entries
    /// (e.g. `[s[0], s[1], s[2]]`).
    ///
    /// # Codegen
    ///
    /// Order-≤2 Faà di Bruno is a tiny closed form, so this evaluates it
    /// directly instead of routing through the generic
    /// [`jet_algebra::faa_di_bruno`] set-partition walker (recursion + per-block
    /// closure dispatch). That matters because this is the kernel under EVERY
    /// packed scalar — [`crate::jet_scalar::Order2`] / `OneSeed` / `TwoSeed`
    /// composition all bottom out here — so the straight-line form (whose inner
    /// loops auto-vectorise to NEON/SSE 2-wide and which emits zero outlined
    /// walker calls) lifts all of them at once.
    ///
    /// The term and accumulation order is BIT-IDENTICAL to the walker it
    /// replaces: each output channel mirrors the walker's `total = 0.0` start
    /// (the explicit `acc` accumulator), so a signed-zero product collapses to
    /// `+0.0` exactly as `total += prod` does. Proven `to_bits`-identical on
    /// `v`/`g`/`h` across `K ∈ {2,3,4,9}`, 5000 random inputs each (incl.
    /// zeroed / sign-varied stacks). The order-≤2 walker partitions, with
    /// `f′ = d[1]`, `f″ = d[2]`, `u_* = self.{g,h}`, are:
    ///
    /// - `g[i] = f′·u_i` (single block `{i}`);
    /// - `h[i][j] = f′·u_ij + (f″·u_i)·u_j` (blocks `{ij}` then `{i}{j}`).
    pub fn compose_unary(&self, d: [f64; 3]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        out
    }

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
            }
        }
        out
    }

    /// e^self.
    pub fn exp(&self) -> Self {
        let e = self.v.exp();
        // Every derivative of e^u is e^u.
        self.compose_unary([e, e, e])
    }

    /// √self. The caller guarantees `self.v > 0`.
    pub fn sqrt(&self) -> Self {
        let u = self.v;
        let s = u.sqrt();
        // Stack [√u, 1/(2√u), −1/(4·u^{3/2})].
        self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
    }
}
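// Illustrative sketch (hypothetical module `doc_sketch_tower2_mul`, not part of
// this crate's API): a minimal value/gradient/Hessian jet that mirrors the
// order-≤2 Leibniz product of `Tower2::mul` above, computing only the upper
// triangle `j ≥ i` and copying it into the lower triangle. For seeded
// variables x = p_0 and y = p_1 the product x·y should have gradient [y, x]
// and Hessian [[0, 1], [1, 0]], and x·x should have ∂²/∂x² = 2. Because the
// lower triangle is a copy, the result is exactly symmetric.
#[allow(dead_code)]
mod doc_sketch_tower2_mul {
    /// Minimal order-≤2 jet over `K` variables.
    #[derive(Clone, Copy, Debug)]
    pub struct Jet2<const K: usize> {
        pub v: f64,
        pub g: [f64; K],
        pub h: [[f64; K]; K],
    }

    impl<const K: usize> Jet2<K> {
        /// Seed variable `p_idx` at `value`: unit gradient in slot `idx`.
        pub fn variable(value: f64, idx: usize) -> Self {
            let mut g = [0.0; K];
            g[idx] = 1.0;
            Self { v: value, g, h: [[0.0; K]; K] }
        }

        /// Order-≤2 Leibniz product: upper triangle computed, lower mirrored.
        pub fn mul(&self, b: &Self) -> Self {
            let a = self;
            let mut out = Self { v: a.v * b.v, g: [0.0; K], h: [[0.0; K]; K] };
            for i in 0..K {
                out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
            }
            for i in 0..K {
                for j in i..K {
                    let hij = a.v * b.h[i][j]
                        + a.g[i] * b.g[j]
                        + a.g[j] * b.g[i]
                        + a.h[i][j] * b.v;
                    out.h[i][j] = hij;
                    out.h[j][i] = hij; // mirror instead of recomputing
                }
            }
            out
        }
    }
}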

impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut out = Self::zero();
        out.v = f(&[]);
        for i in 0..K {
            let labels = [i];
            out.g[i] = f(&labels);
        }
        for i in 0..K {
            for j in 0..K {
                let labels = [i, j];
                out.h[i][j] = f(&labels);
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Add for Tower2<K> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        let mut out = self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Mul for Tower2<K> {
    type Output = Self;
    fn mul(self, o: Self) -> Self {
        Tower2::mul(&self, &o)
    }
}

impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
    type Output = Self;
    fn add(self, c: f64) -> Self {
        let mut out = self;
        out.v += c;
        out
    }
}

impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
    type Output = Self;
    fn mul(self, c: f64) -> Self {
        self.scale(c)
    }
}
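// Illustrative sketch (hypothetical module `doc_sketch_tower2_compose`): the
// order-≤2 Faà di Bruno closed form evaluated by `Tower2::compose_unary`,
//   ∂_i(f∘u)  = f′(u)·u_i,
//   ∂_ij(f∘u) = f′(u)·u_ij + f″(u)·u_i·u_j,
// together with the `[f, f′, f″]` stacks that `Tower2::exp` and `Tower2::sqrt`
// pass in. Expected behaviour: for the linear inner u = 2·p_0 + 3·p_1 at 0,
// e^u has Hessian u_i·u_j = [[4, 6], [6, 9]]; √x at x = 4 has first and
// second derivatives 1/4 and −1/32. This sketch does not reproduce the
// explicit `acc = 0.0` accumulator, so it is exact in value but not claimed to
// be bit-identical to the crate's kernel.
#[allow(dead_code)]
mod doc_sketch_tower2_compose {
    #[derive(Clone, Copy, Debug)]
    pub struct Jet2<const K: usize> {
        pub v: f64,
        pub g: [f64; K],
        pub h: [[f64; K]; K],
    }

    impl<const K: usize> Jet2<K> {
        /// `f ∘ self` given `d = [f(u), f′(u), f″(u)]` at `u = self.v`.
        pub fn compose_unary(&self, d: [f64; 3]) -> Self {
            let mut out = Self { v: d[0], g: [0.0; K], h: [[0.0; K]; K] };
            for i in 0..K {
                out.g[i] = d[1] * self.g[i];
                for j in 0..K {
                    out.h[i][j] = d[1] * self.h[i][j] + d[2] * self.g[i] * self.g[j];
                }
            }
            out
        }

        /// Every derivative of e^u is e^u.
        pub fn exp(&self) -> Self {
            let e = self.v.exp();
            self.compose_unary([e, e, e])
        }

        /// Stack [√u, 1/(2√u), −1/(4·u^{3/2})]; requires `self.v > 0`.
        pub fn sqrt(&self) -> Self {
            let u = self.v;
            let s = u.sqrt();
            self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
        }
    }
}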

/// Truncated THIRD-order multivariate Taylor scalar in `K` variables.
///
/// The value/gradient/Hessian/third-derivative sibling of [`Tower4`], standing
/// between [`Tower2`] and [`Tower4`]. Every channel it carries (`v`, `g`, `h`,
/// `t3`) is computed with the same term-for-term Leibniz / Faà di Bruno
/// formulas [`Tower4`] uses for those orders, and the order-≤3 terms of those
/// formulas read only the order-≤3 channels of their inputs (the order-3 Faà di
/// Bruno partitions never reach the f⁗ stack slot or the inner `t4` tensor; see
/// [`Tower4::compose_unary`]). So for any program written over both towers the
/// order-≤3 outputs are *bit-identical*: dropping the fourth tensor cannot
/// perturb the value, gradient, Hessian, or third derivatives.
///
/// It exists purely for performance, exactly like [`Tower2`]: a consumer that
/// needs up to third derivatives (the survival location-scale row kernel reads
/// `g`, the diagonal `h`, and the diagonal `t3`, but never `t4`) pays the
/// `K³` third-tensor arithmetic but skips the `K⁴` fourth-tensor
/// product/composition that otherwise dominates the per-row cost.
#[derive(Clone, Copy, Debug)]
pub struct Tower3<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
    pub t3: [[[f64; K]; K]; K],
}

impl<const K: usize> Tower3<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
            t3: [[[0.0; K]; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`: unit first
    /// derivative in slot `idx`; every other gradient slot and every
    /// higher-order channel is zero.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the (fully symmetric) derivative tensor entry whose differentiation
    /// axes are `labels` (length 0..=3): value, `g`, `h`, `t3`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 3,
            "Tower3 carries at most third-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            2 => self.h[labels[0]][labels[1]],
            _ => self.t3[labels[0]][labels[1]][labels[2]],
        }
    }

    /// Exact truncated (order ≤ 3) Leibniz product. The `v`/`g`/`h`/`t3`
    /// channels match [`Tower4::mul`] term-for-term.
    ///
    /// # Codegen
    ///
    /// Straight-line per-entry subset sums instead of the
    /// [`jet_algebra::leibniz_product`] walker — the order-≤3 sibling of
    /// [`Tower4::mul`] (no `t4`). Loop nest unchanged, no unroll over `K`, no
    /// code bloat; auto-vectorises. BIT-IDENTICAL: terms in the walker's exact
    /// subset order with an `acc = 0.0` accumulator start (load-bearing for the
    /// signed-zero leading product on exact-`0.0` jet channels). Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// zero/sign-stressed inputs each (these channel formulas are exactly the
    /// `g`/`h`/`t3` of the [`Tower4::mul`] oracle, which passes that stress).
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            let mut acc = 0.0;
            acc += a.v * b.g[i];
            acc += a.g[i] * b.v;
            out.g[i] = acc;
        }
        // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
        for i in 0..K {
            for j in i..K {
                let mut acc = 0.0;
                acc += a.v * b.h[i][j];
                acc += a.g[i] * b.g[j];
                acc += a.g[j] * b.g[i];
                acc += a.h[i][j] * b.v;
                out.h[i][j] = acc;
                out.h[j][i] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // Subsets S ⊆ {i,j,k} differentiated on `a` (`b` takes the
                    // complement), in walker order: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                    let mut acc = 0.0;
                    acc += a.v * b.t3[i][j][k];
                    acc += a.g[i] * b.h[j][k];
                    acc += a.g[j] * b.h[i][k];
                    acc += a.h[i][j] * b.g[k];
                    acc += a.g[k] * b.h[i][j];
                    acc += a.h[i][k] * b.g[j];
                    acc += a.h[j][k] * b.g[i];
                    acc += a.t3[i][j][k] * b.v;
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }

    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
    /// out of the borrowed operands.
    pub fn add(&self, o: &Self) -> Self {
        *self + *o
    }

    /// Ref-taking elementwise difference `self − o`, evaluated as
    /// `self + (−1)·o` through [`Self::scale`].
    pub fn sub(&self, o: &Self) -> Self {
        *self + o.scale(-1.0)
    }

    /// Exact (order ≤ 3) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u), f‴(u)]` evaluated at `u = self.v`. The
    /// `v`/`g`/`h`/`t3` channels match [`Tower4::compose_unary`] term-for-term
    /// (which uses only `d[0..=3]` for those orders), so this is a strict
    /// truncation, not an approximation. The full-order `[f64; 5]` derivative
    /// stacks the families already produce can be passed by taking their first
    /// four entries (e.g. `[s[0], s[1], s[2], s[3]]`).
    ///
    /// # Codegen
    ///
    /// Order-≤3 Faà di Bruno written as a compact closed form instead of the
    /// recursive [`jet_algebra::faa_di_bruno`] walker: the order-≤3 sibling of
    /// [`Tower4::compose_unary`], one tensor order shallower. The loop nest is
    /// unchanged (no unroll over `K`, no code bloat): measured on a `Tower3<9>`
    /// compose-and-read consumer, the new form is faster and SMALLER (asm: 71
    /// walker `bl` calls → 0, 39.5 KiB → 13.9 KiB, +197 NEON `.2d` ops).
    /// BIT-IDENTICAL: terms in the walker's exact partition order,
    /// left-associated block products, `acc = 0.0` accumulator start. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// random inputs each.
    pub fn compose_unary(&self, d: [f64; 4]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
                    let mut acc = 0.0;
                    acc += d[1] * self.t3[i][j][k];
                    acc += d[2] * self.h[i][j] * self.g[k];
                    acc += d[2] * self.h[i][k] * self.g[j];
                    acc += d[2] * self.g[i] * self.h[j][k];
                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }

    /// Compose with a unary special-function whose `[f64; 4]` derivative STACK is
    /// built from the base value through `stack_fn` — the scalar arm of the
    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
    /// [`Tower3Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
    /// `self.compose_unary(stack_fn(self.v))`. The order-≤3 sibling of
    /// [`Tower4::compose_unary_with`].
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
        self.compose_unary(stack_fn(self.v))
    }

    /// Single-active-slot fast path for [`Self::compose_unary`] — the order-≤3
    /// sibling of [`Tower4::compose_unary_single_slot`]. When `self` carries
    /// derivative support only on the all-`slot` diagonal, every output channel
    /// touching an axis `≠ slot` collapses to the walker's `total = 0.0` start
    /// (`+0.0`), so only `v`, `g[slot]`, `h[slot][slot]`, and
    /// `t3[slot][slot][slot]` survive. These four are computed as STRAIGHT-LINE
    /// accumulations, each in the EXACT term order of [`Self::compose_unary`]'s
    /// diagonal (`i = j = k = slot`) case (BIT-IDENTICAL to the full path on the
    /// diagonal); off-`slot` channels stay at the zero-init `+0.0` the full walk
    /// also yields (proven `to_bits` across `K ∈ {2,3,4,9}`). This drops the
    /// recursive set-partition walker the diagonal channels previously routed
    /// through, recovering its measured ~5.9× regression at the `K ∈ {2,3}` BMS
    /// tower widths. The caller guarantees the single-slot precondition;
    /// otherwise use [`Self::compose_unary`].
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [f64; 4], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        out.v = d[0];
        // g (i=s): d1*g
        out.g[s] = {
            let mut acc = 0.0;
            acc += d[1] * g;
            acc
        };
        // h (i=j=s): d1*h + d2*g*g
        out.h[s][s] = {
            let mut acc = 0.0;
            acc += d[1] * h;
            acc += d[2] * g * g;
            acc
        };
        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
        out.t3[s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t3;
            acc += d[2] * h * g;
            acc += d[2] * h * g;
            acc += d[2] * g * h;
            acc += d[3] * g * g * g;
            acc
        };
        out
    }

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
                for k in 0..K {
                    out.t3[i][j][k] *= s;
                }
            }
        }
        out
    }
}
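// Illustrative sketch (hypothetical module `doc_sketch_tower3_mul`): the
// order-3 Leibniz subset sum behind `Tower3::mul`. Each entry of ∂³(a·b) sums,
// over every subset S ⊆ {i, j, k}, the S-derivative of `a` times the
// complementary derivative of `b` (8 terms). Expected behaviour: for
// x = p_0, y = p_1, z = p_2 the product x·y·z has ∂³/∂x∂y∂z = 1 in every index
// order and 0 on repeated indices; the univariate cube x³ at x = 2 has
// (value, first, second, third) = (8, 12, 12, 6). Plain `+` chains are used
// here, so bit-level agreement with the crate kernel is not claimed.
#[allow(dead_code)]
mod doc_sketch_tower3_mul {
    #[derive(Clone, Copy, Debug)]
    pub struct Jet3<const K: usize> {
        pub v: f64,
        pub g: [f64; K],
        pub h: [[f64; K]; K],
        pub t3: [[[f64; K]; K]; K],
    }

    impl<const K: usize> Jet3<K> {
        pub fn zero() -> Self {
            Self { v: 0.0, g: [0.0; K], h: [[0.0; K]; K], t3: [[[0.0; K]; K]; K] }
        }

        /// Seed variable `p_idx` at `value`.
        pub fn variable(value: f64, idx: usize) -> Self {
            let mut out = Self::zero();
            out.v = value;
            out.g[idx] = 1.0;
            out
        }

        /// Order-≤3 Leibniz product.
        pub fn mul(&self, b: &Self) -> Self {
            let a = self;
            let mut out = Self::zero();
            out.v = a.v * b.v;
            for i in 0..K {
                out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
                for j in 0..K {
                    out.h[i][j] =
                        a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
                    for k in 0..K {
                        // S on `a`: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                        out.t3[i][j][k] = a.v * b.t3[i][j][k]
                            + a.g[i] * b.h[j][k]
                            + a.g[j] * b.h[i][k]
                            + a.h[i][j] * b.g[k]
                            + a.g[k] * b.h[i][j]
                            + a.h[i][k] * b.g[j]
                            + a.h[j][k] * b.g[i]
                            + a.t3[i][j][k] * b.v;
                    }
                }
            }
            out
        }
    }
}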

impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut out = Self::zero();
        out.v = f(&[]);
        for i in 0..K {
            let labels = [i];
            out.g[i] = f(&labels);
        }
        for i in 0..K {
            for j in 0..K {
                let labels = [i, j];
                out.h[i][j] = f(&labels);
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let labels = [i, j, k];
                    out.t3[i][j][k] = f(&labels);
                }
            }
        }
        out
    }
}
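// Illustrative sketch (hypothetical module `doc_sketch_tower3_compose`): the
// order-3 Faà di Bruno closed form of `Tower3::compose_unary`. Beyond the
// order-≤2 terms, each third-order entry sums the five set partitions of
// {i, j, k}:
//   f′·u_ijk + f″·(u_ij·u_k + u_ik·u_j + u_i·u_jk) + f‴·u_i·u_j·u_k.
// Expected behaviour: for the linear inner u = 2·p_0 + 3·p_1 at 0, e^u has
// ∂³/∂p_0∂p_1² = 2·3·3 = 18; for w = x + x²/2 at x = 0 (w′ = 1, w″ = 1), e^w
// has second and third derivatives 2 and 4. The f″ terms are grouped, so this
// is exact in real arithmetic but not the crate's bit-exact term order.
#[allow(dead_code)]
mod doc_sketch_tower3_compose {
    #[derive(Clone, Copy, Debug)]
    pub struct Jet3<const K: usize> {
        pub v: f64,
        pub g: [f64; K],
        pub h: [[f64; K]; K],
        pub t3: [[[f64; K]; K]; K],
    }

    impl<const K: usize> Jet3<K> {
        /// `f ∘ self` given `d = [f, f′, f″, f‴]` at `u = self.v`.
        pub fn compose_unary(&self, d: [f64; 4]) -> Self {
            let (g, h, t3) = (&self.g, &self.h, &self.t3);
            let mut out = Self {
                v: d[0],
                g: [0.0; K],
                h: [[0.0; K]; K],
                t3: [[[0.0; K]; K]; K],
            };
            for i in 0..K {
                out.g[i] = d[1] * g[i];
                for j in 0..K {
                    out.h[i][j] = d[1] * h[i][j] + d[2] * g[i] * g[j];
                    for k in 0..K {
                        out.t3[i][j][k] = d[1] * t3[i][j][k]
                            + d[2] * (h[i][j] * g[k] + h[i][k] * g[j] + g[i] * h[j][k])
                            + d[3] * g[i] * g[j] * g[k];
                    }
                }
            }
            out
        }
    }
}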

impl<const K: usize> std::ops::Add for Tower3<K> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        let mut out = self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
                for k in 0..K {
                    out.t3[i][j][k] += o.t3[i][j][k];
                }
            }
        }
        out
    }
}
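// Illustrative sketch (hypothetical module `doc_sketch_single_slot`): with
// support on one slot only, the diagonal channels of the order-≤3 composition
// collapse to the univariate Faà di Bruno formulas
//   g′ = f′·g,  h′ = f′·h + f″·g²,  t′ = f′·t + 3·f″·g·h + f‴·g³,
// which `Tower3::compose_unary_single_slot` accumulates straight-line in the
// full walk's term order. The explicit `0.0` accumulator start also turns a
// `-0.0` product into `+0.0`, the signed-zero behaviour the doc comments call
// load-bearing. Expected behaviour: ln(1 + x) at x = 0 (stack [0, 1, −1, 2])
// has derivatives 1, −1, 2.
#[allow(dead_code)]
mod doc_sketch_single_slot {
    /// Diagonal `(g, h, t3)` of `f ∘ u` from `d = [f, f′, f″, f‴]` and the
    /// inner diagonal channels `(g, h, t3)`.
    pub fn diag_compose3(d: [f64; 4], g: f64, h: f64, t3: f64) -> (f64, f64, f64) {
        let mut g_out = 0.0;
        g_out += d[1] * g;
        let mut h_out = 0.0;
        h_out += d[1] * h;
        h_out += d[2] * g * g;
        let mut t_out = 0.0;
        t_out += d[1] * t3; // {ijk}
        t_out += d[2] * h * g; // {ij}{k}
        t_out += d[2] * h * g; // {ik}{j}
        t_out += d[2] * g * h; // {i}{jk}
        t_out += d[3] * g * g * g; // {i}{j}{k}
        (g_out, h_out, t_out)
    }
}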

/// Derivative stack `[ln Γ(x), ψ(x), ψ′(x), ψ″(x), ψ‴(x)]` of `ln Γ` at
/// `x > 0`, in the `[f, f′, f″, f‴, f⁗]` layout the tower compositions take.
/// The ψ slots are NaN for non-positive or non-finite `x`.
pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
    [
        statrs::function::gamma::ln_gamma(x),
        digamma_positive(x),
        polygamma_positive(1, x),
        polygamma_positive(2, x),
        polygamma_positive(3, x),
    ]
}

/// Order-≤2 prefix `[ln Γ(x), ψ(x), ψ′(x)]` of [`ln_gamma_derivative_stack`],
/// sized for [`Tower2::compose_unary`].
pub fn ln_gamma_derivative_stack_order2(x: f64) -> [f64; 3] {
    [
        statrs::function::gamma::ln_gamma(x),
        digamma_positive(x),
        polygamma_positive(1, x),
    ]
}

/// Derivative stack `[ψ(x), ψ′(x), ψ″(x), ψ‴(x), ψ⁽⁴⁾(x)]` of digamma at
/// `x > 0`.
pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
    [
        digamma_positive(x),
        polygamma_positive(1, x),
        polygamma_positive(2, x),
        polygamma_positive(3, x),
        polygamma_positive(4, x),
    ]
}

/// Derivative stack `[ψ′(x), ψ″(x), ψ‴(x), ψ⁽⁴⁾(x), ψ⁽⁵⁾(x)]` of trigamma at
/// `x > 0`.
pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
    [
        polygamma_positive(1, x),
        polygamma_positive(2, x),
        polygamma_positive(3, x),
        polygamma_positive(4, x),
        polygamma_positive(5, x),
    ]
}

/// Scalar digamma ψ(x) for x>0. Bit-identical to `digamma_derivative_stack(x)[0]`
/// and to `ln_gamma_derivative_stack(x)[1]`, but evaluates ONLY ψ — the four
/// higher polygammas those `[f64; 5]` stacks build are pure discarded work at a
/// scalar consumer that reads a single element. Hot-path row kernels that need
/// only the digamma value (e.g. the GAMLSS Beta observed cross weight) call this
/// instead of indexing `[0]` off a full derivative stack.
#[inline]
pub fn digamma(x: f64) -> f64 {
    digamma_positive(x)
}

/// Scalar trigamma ψ′(x) for x>0. Bit-identical to
/// `trigamma_derivative_stack(x)[0]` (both bottom out in
/// `polygamma_positive(1, x)`), but evaluates ONLY ψ′ — the four higher
/// polygammas (orders 2–5) the `[f64; 5]` stack builds are discarded at a
/// `[0]` consumer. Used by the dispersion-channel Fisher-information row
/// kernels (NB2 `ψ′(θ)−ψ′(θ+μ)`, Beta `μψ′(μφ)−(1−μ)ψ′((1−μ)φ)`) which read
/// the trigamma value alone.
#[inline]
pub fn trigamma(x: f64) -> f64 {
    polygamma_positive(1, x)
}

/// ψ(x) for finite `x > 0` (NaN otherwise): shift `x` up to at least
/// [`POLYGAMMA_ASYMPTOTIC_MIN_X`] with the recurrence ψ(x) = ψ(x + 1) − 1/x,
/// then evaluate the asymptotic series at the shifted argument.
fn digamma_positive(mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
        acc -= 1.0 / x;
        x += 1.0;
    }
    acc + digamma_asymptotic(x)
}
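// Illustrative sketch (hypothetical module `doc_sketch_digamma`): the same
// two-stage strategy as `digamma_positive` above, shortened to the first four
// Bernoulli terms (an assumption of this sketch; the crate's series keeps
// ten). Stage 1 shifts x up with ψ(x) = ψ(x + 1) − 1/x until x ≥ 20; stage 2
// sums ψ(x) ≈ ln x − 1/(2x) − Σ_k B_2k / (2k·x^2k). At x ≥ 20 the first
// dropped term (B_10) is below 1e-15, so ψ(1) should match the negated
// Euler–Mascheroni constant to well within 1e-12.
#[allow(dead_code)]
mod doc_sketch_digamma {
    const MIN_X: f64 = 20.0;
    const BERNOULLI_EVEN: [(i32, f64); 4] =
        [(2, 1.0 / 6.0), (4, -1.0 / 30.0), (6, 1.0 / 42.0), (8, -1.0 / 30.0)];

    /// ψ(x) for finite x > 0, NaN otherwise.
    pub fn digamma(mut x: f64) -> f64 {
        if !(x.is_finite() && x > 0.0) {
            return f64::NAN;
        }
        let mut acc = 0.0;
        while x < MIN_X {
            acc -= 1.0 / x; // ψ(x) = ψ(x + 1) − 1/x
            x += 1.0;
        }
        let mut series = x.ln() - 0.5 / x;
        for (n, b) in BERNOULLI_EVEN {
            series -= b / (n as f64 * x.powi(n));
        }
        acc + series
    }
}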

/// ψ⁽ⁿ⁾(x) for `order = n ∈ 1..=5` and finite `x > 0` (NaN otherwise): shift
/// `x` up with the recurrence ψ⁽ⁿ⁾(x) = ψ⁽ⁿ⁾(x + 1) + (−1)ⁿ⁺¹·n!/xⁿ⁺¹, then
/// evaluate the asymptotic series at the shifted argument.
fn polygamma_positive(order: usize, mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
        acc += polygamma_recurrence_term(order, x);
        x += 1.0;
    }
    acc + polygamma_asymptotic(order, x)
}

/// Threshold below which the recurrences shift `x` upward; at or above it the
/// truncated asymptotic series is used directly.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
/// Even Bernoulli numbers `(2k, B₂ₖ)` for `k = 1..=10`: the coefficients of
/// the asymptotic ψ and ψ⁽ⁿ⁾ series.
const BERNOULLI_EVEN: [(usize, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];

/// One recurrence step `(−1)ⁿ⁺¹·n!/xⁿ⁺¹`, so that ψ⁽ⁿ⁾(x) = ψ⁽ⁿ⁾(x + 1) + term.
fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
    let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    sign * factorial(order) / x.powi((order + 1) as i32)
}

/// Asymptotic ψ(x) ≈ ln x − 1/(2x) − Σₖ B₂ₖ/(2k·x²ᵏ), truncated after B₂₀.
fn digamma_asymptotic(x: f64) -> f64 {
    let mut out = x.ln() - 0.5 / x;
    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
        out -= bernoulli / (bernoulli_order as f64 * x.powi(bernoulli_order as i32));
    }
    out
}

/// Asymptotic ψ⁽ⁿ⁾(x) ≈ (−1)ⁿ⁺¹·[(n−1)!/xⁿ + n!/(2xⁿ⁺¹) +
/// Σₖ B₂ₖ·(2k+n−1)!/((2k)!·x²ᵏ⁺ⁿ)], truncated after B₂₀. The ratio
/// (2k+n−1)!/(2k)! is formed as `rising_factorial(2k, n) / 2k`. Returns NaN
/// outside `order ∈ 1..=5`.
fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
    if !(1..=5).contains(&order) {
        return f64::NAN;
    }

    let order_factorial = factorial(order);
    let leading_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    let mut out = leading_sign * factorial(order - 1) / x.powi(order as i32)
        + leading_sign * order_factorial / (2.0 * x.powi((order + 1) as i32));

    let bernoulli_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
        let rising = rising_factorial(bernoulli_order, order);
        out += bernoulli_sign * bernoulli * rising
            / bernoulli_order as f64
            / x.powi((bernoulli_order + order) as i32);
    }
    out
}
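// Illustrative sketch (hypothetical module `doc_sketch_polygamma`): the
// general-order recurrence and asymptotic series behind `polygamma_positive`,
//   ψ⁽ⁿ⁾(x) = ψ⁽ⁿ⁾(x + 1) + (−1)ⁿ⁺¹·n!/xⁿ⁺¹,
//   ψ⁽ⁿ⁾(x) ≈ (−1)ⁿ⁺¹[(n−1)!/xⁿ + n!/(2xⁿ⁺¹) + Σ_k B_2k·(2k+n−1)!/((2k)!·x^(2k+n))],
// forming (2k+n−1)!/(2k)! as the rising product 2k·(2k+1)···(2k+n−1) divided
// by 2k. Kept to six Bernoulli terms here (an assumption; the crate's table
// has ten). Expected behaviour against closed forms at x = 1:
// ψ′(1) = π²/6, ψ″(1) = −2ζ(3), ψ‴(1) = π⁴/15.
#[allow(dead_code)]
mod doc_sketch_polygamma {
    const MIN_X: f64 = 20.0;
    const BERNOULLI_EVEN: [(usize, f64); 6] = [
        (2, 1.0 / 6.0),
        (4, -1.0 / 30.0),
        (6, 1.0 / 42.0),
        (8, -1.0 / 30.0),
        (10, 5.0 / 66.0),
        (12, -691.0 / 2730.0),
    ];

    fn factorial(n: usize) -> f64 {
        (1..=n).fold(1.0, |acc, k| acc * k as f64)
    }

    /// start·(start+1)···(start+len−1).
    fn rising(start: usize, len: usize) -> f64 {
        (start..start + len).fold(1.0, |acc, k| acc * k as f64)
    }

    /// ψ⁽ⁿ⁾(x) for n ≥ 1 and finite x > 0, NaN otherwise.
    pub fn polygamma(n: usize, mut x: f64) -> f64 {
        if n == 0 || !(x.is_finite() && x > 0.0) {
            return f64::NAN;
        }
        let sign = if n % 2 == 1 { 1.0 } else { -1.0 };
        let mut acc = 0.0;
        while x < MIN_X {
            acc += sign * factorial(n) / x.powi((n + 1) as i32);
            x += 1.0;
        }
        let mut series =
            factorial(n - 1) / x.powi(n as i32) + factorial(n) / (2.0 * x.powi((n + 1) as i32));
        for (k2, b) in BERNOULLI_EVEN {
            series += b * rising(k2, n) / k2 as f64 / x.powi((k2 + n) as i32);
        }
        acc + sign * series
    }
}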

/// `n!` as `f64` (exact for the small orders used here).
fn factorial(n: usize) -> f64 {
    (1..=n).fold(1.0, |acc, k| acc * k as f64)
}

/// Rising product `start·(start+1)···(start+len−1)` as `f64`; `1.0` when
/// `len = 0`.
fn rising_factorial(start: usize, len: usize) -> f64 {
    (start..start + len).fold(1.0, |acc, k| acc * k as f64)
}

impl<const K: usize> std::ops::Add for Tower4<K> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        let mut out = self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
                for k in 0..K {
                    out.t3[i][j][k] += o.t3[i][j][k];
                    for l in 0..K {
                        out.t4[i][j][k][l] += o.t4[i][j][k][l];
                    }
                }
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Sub for Tower4<K> {
    type Output = Self;
    fn sub(self, o: Self) -> Self {
        self + o.scale(-1.0)
    }
}

impl<const K: usize> std::ops::Neg for Tower4<K> {
    type Output = Self;
    fn neg(self) -> Self {
        self.scale(-1.0)
    }
}

impl<const K: usize> std::ops::Mul for Tower4<K> {
    type Output = Self;
    fn mul(self, o: Self) -> Self {
        Tower4::mul(&self, &o)
    }
}

impl<const K: usize> std::ops::Div for Tower4<K> {
    type Output = Self;
    fn div(self, o: Self) -> Self {
        // a / b = a · b⁻¹, with the reciprocal tower from `recip`.
        Tower4::mul(&self, &o.recip())
    }
}
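// Illustrative sketch (hypothetical module `doc_sketch_recip`): `Tower4`'s
// `Div` multiplies by `recip()`, whose body lies outside this excerpt.
// Assuming `recip` composes f(u) = 1/u through a derivative stack like the
// other unary maps, that stack is f⁽ⁿ⁾(u) = (−1)ⁿ·n!/uⁿ⁺¹ for n = 0..=4. A
// consistency check: u·(1/u) ≡ 1, so by the univariate Leibniz rule (with
// u(x) = x) the derivatives u·r⁽ⁿ⁾ + n·r⁽ⁿ⁻¹⁾ must vanish for n = 1..=4.
#[allow(dead_code)]
mod doc_sketch_recip {
    /// `[1/u, −1/u², 2/u³, −6/u⁴, 24/u⁵]`.
    pub fn recip_stack(u: f64) -> [f64; 5] {
        let mut out = [0.0; 5];
        let mut fact = 1.0; // n!
        let mut sign = 1.0; // (−1)ⁿ
        for n in 0..5 {
            if n > 0 {
                fact *= n as f64;
                sign = -sign;
            }
            out[n] = sign * fact / u.powi(n as i32 + 1);
        }
        out
    }
}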

impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
    type Output = Self;
    fn add(self, c: f64) -> Self {
        let mut out = self;
        out.v += c;
        out
    }
}

impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
    type Output = Self;
    fn sub(self, c: f64) -> Self {
        self + (-c)
    }
}

impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
    type Output = Self;
    fn mul(self, c: f64) -> Self {
        self.scale(c)
    }
}

// ── Implicit-function and moving-boundary seams (#932 flex) ──────────
//
// The flexible survival marginal-slope row loss is NOT a free composition
// of the primaries: it threads an IMPLICIT calibration intercept `a(θ)`
// solving a constraint `F(a, θ) = 0`, and integrates a density over cells
// whose edges `z_L(θ), z_R(θ)` MOVE with θ through that intercept. Plain
// `Tower4` Faà di Bruno can express neither, so the flex tower was the
// last hand-written one in the codebase and a standing source of
// #736-class drift bugs (the (g,w0) deviation-cross third was 3× short for
// exactly this reason). These two combinators close that gap: once the
// constraint `F` and the integrand/boundaries are themselves towers, the
// intercept's derivative tower and the integral's derivative tower come out
// EXACTLY at every order, leaving no order to hand-code and forget.
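// Illustrative sketch (hypothetical module `doc_sketch_implicit_series`): the
// order-by-order cancellation that `implicit_solve` performs, shown in one
// primary and in Taylor-COEFFICIENT form (c_m = a⁽ᵐ⁾(θ₀)/m!) rather than on
// derivative tensors; that simplification is an assumption of the sketch.
// The residual G(t) = F(a(θ₀ + t), θ₀ + t) has degree-m coefficient
// F_a·c_m + (terms in c_0..c_{m−1}), so with c_m still 0 a single step
// c_m −= G_m / F_a zeroes it. The root guard mirrors the Newton-step test.
// Expected behaviour: a² − θ = 0 at (1, 1) gives √(1 + t); a + a³ − θ = 0 at
// (0, 0) gives the series reversion t − t³ + O(t⁵).
#[allow(dead_code)]
mod doc_sketch_implicit_series {
    /// Truncated power series in t: coefficients of t⁰..t⁴.
    pub type Series = [f64; 5];

    /// Truncated series product.
    pub fn mul(a: &Series, b: &Series) -> Series {
        let mut out = [0.0; 5];
        for i in 0..5 {
            for j in 0..5 - i {
                out[i + j] += a[i] * b[j];
            }
        }
        out
    }

    /// Expand the root branch a(t) of `residual(a)(t) = F(a(t), θ₀ + t) = 0`
    /// through t⁴, given the root `a0` and `f_a = ∂F/∂a` at `(a0, θ₀)`.
    pub fn implicit_series(
        a0: f64,
        f_a: f64,
        residual: impl Fn(&Series) -> Series,
    ) -> Result<Series, String> {
        if f_a == 0.0 || !f_a.is_finite() {
            return Err(format!("dF/da = {f_a} is not invertible"));
        }
        let mut a: Series = [a0, 0.0, 0.0, 0.0, 0.0];
        let r0 = residual(&a)[0];
        // Newton-step root guard; `!(.. <= ..)` also rejects NaN.
        if !((r0 / f_a).abs() <= 1e-9 * (1.0 + a0.abs())) {
            return Err(format!("a0 = {a0} is not a root: F = {r0}"));
        }
        for m in 1..5 {
            let g = residual(&a); // a[m] is still 0 here
            a[m] -= g[m] / f_a;
        }
        Ok(a)
    }
}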

/// Solve the implicit relation `F(a(θ), θ) ≡ 0` for the intercept tower
/// `a(θ)` over the `K` primaries θ, given the constraint tower `f` written
/// over `K + 1` variables (slot `0` is the intercept `a`, slots `1..=K`
/// are the primaries θ) evaluated at the SOLVED point — i.e. `f.v` is the
/// constraint residual at `(a₀, θ₀)` (≈ 0 from the production Newton solve)
/// and `a0` is that solved intercept value.
///
/// Returns the `Tower4<K>` whose value is `a0` and whose every derivative
/// tensor (∂a/∂θ, ∂²a/∂θ², …, ∂⁴a/∂θ⁴) is the exact implicit-function
/// derivative. This is the mechanical replacement for the hand-coded
/// `a_u = -f_u/f_a`, `a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a`
/// recursion (first_full.rs) and its third/fourth-order continuations.
///
/// Method: order-by-order substitution. We build `a` incrementally; at each
/// order `m` the composite `G(θ) = f(a(θ), θ)` has a top-order coefficient
/// that is linear in `a`'s order-`m` tensor with leading factor `F_a`
/// (= `f.g[0]`), plus terms in `a`'s lower orders already fixed. Setting the
/// order-`m` tensor of `a` to cancel the rest of `G`'s order-`m` coefficient
/// keeps `G ≡ 0` through that order. The substitution `G = f∘(a, θ)` reuses
/// only the exact [`substitute_intercept`] chain rule, so the recursion is
/// auditable and exact, not a hand-expanded formula per order.
///
/// `f.g[0]` (= ∂F/∂a) must be non-zero — guaranteed by the production
/// solve's strict monotonicity guard.
///
/// The expansion point `a0` must be a genuine root `F(a0, θ0) = 0`: the
/// substitution recursion below cancels orders 1..=4 of `G = F∘a` but never
/// touches order 0, so a non-root `a0` would yield the Taylor expansion of
/// the LEVEL SET `F = F(a0)` through `a0`, not the root curve `F = 0`. This
/// is guarded explicitly and re-verified by a composed-residual self-check.
pub fn implicit_solve<const K1: usize, const K: usize>(
    f: &Tower4<K1>,
    a0: f64,
) -> Result<Tower4<K>, String> {
    assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
    let f_a = f.g[0];
    if f_a == 0.0 || !f_a.is_finite() {
        return Err(format!(
            "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
        ));
    }
    // The expansion point must be a genuine root of F. The single Newton
    // correction that would move a0 onto the root is |f.v|/|f_a|; require it
    // to be negligible relative to the natural scale (1 + |a0|). Guarding the
    // Newton step (rather than f.v directly) makes the criterion invariant to
    // the magnitude of f_a / the units of F.
    let root_tol = 1e-9;
    if !f.v.is_finite() {
        return Err(format!(
            "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
            f.v
        ));
    }
    let newton_step = f.v.abs() / f_a.abs();
    if newton_step > root_tol * (1.0 + a0.abs()) {
        return Err(format!(
            "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
             F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
             root_tol {root_tol:.1e} · (1 + |a0|)",
            f.v
        ));
    }
    // Start with a = constant a0 (correct through order 0). Then lift each
    // order in turn. Because substitute_intercept reads `a`'s order-≤m
    // tensors when forming G's order-m coefficient, and the order-m
    // coefficient of G depends on a's order-m tensor ONLY through the linear
    // F_a·a_m term, a single corrective pass per order is exact.
    let mut a = Tower4::<K>::constant(a0);
    for order in 1..=4 {
        let g = substitute_intercept(f, &a);
        // Cancel G's order-`order` coefficient by adjusting a's order-`order`
        // tensor: a_m -= G_m / F_a (the F_a·a_m term is the only one carrying
        // a's order-m tensor, with unit chain coefficient since slot 0 seeds a
        // as a plain variable in the substitution's first-order part).
        match order {
            1 => {
                for i in 0..K {
                    a.g[i] -= g.g[i] / f_a;
                }
            }
            2 => {
                for i in 0..K {
                    for j in 0..K {
                        a.h[i][j] -= g.h[i][j] / f_a;
                    }
                }
            }
            3 => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
                        }
                    }
                }
            }
            _ => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            for l in 0..K {
                                a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
                            }
                        }
                    }
                }
            }
        }
    }
    // Self-check: the composed residual G = F∘a must vanish through order 4.
    // By construction orders 1..=4 were cancelled; the value G.v == F(a0,θ0)
    // is exactly the root requirement guarded above. Re-verify all channels
    // against a scale-aware floor so any arithmetic regression in the
    // substitution recursion is loud rather than silently shipping a
    // level-set expansion.
    let g = substitute_intercept(f, &a);
    let resid_tol = 1e-7 * (1.0 + f_a.abs());
    let mut worst = g.v.abs();
    for i in 0..K {
        worst = worst.max(g.g[i].abs());
        for j in 0..K {
            worst = worst.max(g.h[i][j].abs());
            for k in 0..K {
                worst = worst.max(g.t3[i][j][k].abs());
                for l in 0..K {
                    worst = worst.max(g.t4[i][j][k][l].abs());
                }
            }
        }
    }
    if !worst.is_finite() || worst > resid_tol {
1579        return Err(format!(
1580            "implicit_solve: composed residual G = F∘a does not vanish: \
1581             worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
1582        ));
1583    }
1584    Ok(a)
1585}
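The order-by-order lift in `implicit_solve` can be seen in isolation with one primary (`K = 1`). The sketch below is a minimal illustration, not the crate's code. It assumes a hypothetical constraint `F(a, θ) = a² + θ − 1` (root `a0 = 1` at `θ = 0`, so `a(θ) = √(1 − θ)`) and stores normalized Taylor coefficients `c_k = a^(k)/k!` in a plain array instead of a `Tower4<1>`. The correction `a_m −= G_m / F_a` is unaffected by that rescaling, because the order-`m` coefficient of `G` depends on `a_m` only through `F_a · a_m`.

```rust
/// Hypothetical truncated Taylor polynomial in θ: c[k] = f^(k)(0) / k!, k = 0..=4.
type Poly = [f64; 5];

/// Truncated product of two Taylor polynomials (orders above 4 are dropped).
fn poly_mul(a: &Poly, b: &Poly) -> Poly {
    let mut out = [0.0; 5];
    for i in 0..5 {
        for j in 0..5 - i {
            out[i + j] += a[i] * b[j];
        }
    }
    out
}

/// Residual G(θ) = F(a(θ), θ) for the illustrative F(a, θ) = a² + θ − 1.
fn residual(a: &Poly) -> Poly {
    let mut g = poly_mul(a, a);
    g[0] -= 1.0; // the constant −1
    g[1] += 1.0; // the θ increment, unit slope
    g
}

/// Lift the root a0 to a 4th-order expansion of a(θ), one order per pass:
/// a_m −= G_m / F_a with F_a = ∂F/∂a = 2·a0 for this constraint.
fn implicit_solve_1d(a0: f64) -> Poly {
    let f_a = 2.0 * a0;
    let mut a: Poly = [a0, 0.0, 0.0, 0.0, 0.0];
    for m in 1..=4 {
        // a_m is still zero here, so G_m holds only lower-order contributions.
        let g = residual(&a);
        a[m] -= g[m] / f_a;
    }
    a
}
```

Each pass cancels exactly one order, so `residual(&implicit_solve_1d(1.0))` vanishes through order four and the coefficients match the series `1 − θ/2 − θ²/8 − θ³/16 − 5θ⁴/128` of `√(1 − θ)`.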
1586
1587/// Substitute the intercept tower `a(θ)` into slot `0` of a constraint
1588/// written over `K + 1` variables, returning the composite tower over the
1589/// `K` primaries θ: `G(θ) = f(a(θ), θ₁, …, θ_K)`.
1590///
1591/// This is the exact multivariate chain rule specialised to "slot 0 is a
1592/// dependent tower, slots 1..=K are the independent primaries". It evaluates
1593/// `f`'s fourth-order multivariate Taylor polynomial about the expansion
1594/// point, with the slot-0 increment being the non-constant part of `a` and
1595/// the slot-`i` increment (`i ≥ 1`) being the unit-seeded primary `θ_i`. The sum is
1596/// assembled by the same subset/partition algebra `Tower4` arithmetic uses,
1597/// so it carries derivatives exactly through order four.
1598pub fn substitute_intercept<const K1: usize, const K: usize>(
1599    f: &Tower4<K1>,
1600    a: &Tower4<K>,
1601) -> Tower4<K> {
1602    assert_eq!(K1, K + 1);
1603    // Build the K+1 input increments in θ-space: slot 0 = δa(θ), slot i+1 = δθ_i.
1604    // The composite is the 4th-order Taylor sum Σ_{n≤4} (1/n!) Σ_s
1605    // f.deriv(s) · Π_{j in s} inp[s_j], with s running over ORDERED index
1606    // tuples of length n. Because f's partial tensors are symmetric (and stored
1607    // in full) while we enumerate ordered tuples, the explicit 1/n! times each
1608    // multiset's tuple count reproduces the multi-index coefficient 1/α!. We
1609    // assemble it as an explicit sum, using tower products for the increment
1610    // monomials so all θ-derivatives propagate exactly.
1611    let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
1612        if slot == 0 {
1613            // slot 0: a(θ) minus its constant value (the increment δa(θ)).
1614            let mut d = *a;
1615            d.v = 0.0;
1616            d
1617        } else {
1618            // slot i: the increment δθ_{i-1} = seeded variable minus value.
1619            // θ centred at its expansion value has zero constant term and unit
1620            // first derivative in its own slot.
1621            let mut d = Tower4::<K>::zero();
1622            d.g[slot - 1] = 1.0;
1623            d
1624        }
1625    });
1626    // Accumulate the Taylor sum. order-0 term:
1627    let mut out = Tower4::<K>::constant(f.v);
1628    // order 1: Σ_a f.g[a] · inp[a]
1629    for a_idx in 0..K1 {
1630        out = out + inp[a_idx].scale(f.g[a_idx]);
1631    }
1632    // order 2: (1/2) Σ_{a,b} f.h[a][b] · inp[a]·inp[b]
1633    for a_idx in 0..K1 {
1634        for b_idx in 0..K1 {
1635            let prod = inp[a_idx].mul(&inp[b_idx]);
1636            out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
1637        }
1638    }
1639    // order 3: (1/6) Σ f.t3[a][b][c] · inp[a]·inp[b]·inp[c]
1640    for a_idx in 0..K1 {
1641        for b_idx in 0..K1 {
1642            for c_idx in 0..K1 {
1643                let prod = inp[a_idx].mul(&inp[b_idx]).mul(&inp[c_idx]);
1644                out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
1645            }
1646        }
1647    }
1648    // order 4: (1/24) Σ f.t4[a][b][c][d] · inp[a]·inp[b]·inp[c]·inp[d]
1649    for a_idx in 0..K1 {
1650        for b_idx in 0..K1 {
1651            for c_idx in 0..K1 {
1652                for d_idx in 0..K1 {
1653                    let prod = inp[a_idx]
1654                        .mul(&inp[b_idx])
1655                        .mul(&inp[c_idx])
1656                        .mul(&inp[d_idx]);
1657                    out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
1658                }
1659            }
1660        }
1661    }
1662    out
1663}
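The Taylor sum assembled by `substitute_intercept` can be checked by hand for `K = 1` (`K1 = 2`). This is a hedged sketch under stated assumptions: truncated polynomials in `θ` (normalized coefficients) stand in for `Tower4<1>`, `Jet2` and `substitute_1d` are hypothetical names, and the test constraint is `f(z, θ) = z²θ` expanded at `(z, θ) = (1, 0)` with `a(θ) = 1 + θ`. The nested loops visit the same ordered index tuples and apply the same `1/2`, `1/6`, `1/24` factors as the four order blocks above, sharing partial products across orders.

```rust
/// Hypothetical truncated Taylor polynomial in θ (normalized coefficients).
type Poly = [f64; 5];

/// Truncated polynomial product (orders above 4 are dropped).
fn poly_mul(a: &Poly, b: &Poly) -> Poly {
    let mut out = [0.0; 5];
    for i in 0..5 {
        for j in 0..5 - i {
            out[i + j] += a[i] * b[j];
        }
    }
    out
}

/// acc += s · p, coefficient by coefficient.
fn add_scaled(acc: &mut Poly, p: &Poly, s: f64) {
    for k in 0..5 {
        acc[k] += s * p[k];
    }
}

/// Partials of f over 2 variables (slot 0 = a, slot 1 = θ), stored in full
/// (unsymmetrized) like a Tower4<2>.
struct Jet2 {
    v: f64,
    g: [f64; 2],
    h: [[f64; 2]; 2],
    t3: [[[f64; 2]; 2]; 2],
    t4: [[[[f64; 2]; 2]; 2]; 2],
}

/// G(θ) = f(a0 + δa(θ), θ0 + δθ) as a 4th-order Taylor sum over ordered tuples.
fn substitute_1d(f: &Jet2, delta_a: &Poly) -> Poly {
    let delta_theta: Poly = [0.0, 1.0, 0.0, 0.0, 0.0];
    let inp = [*delta_a, delta_theta];
    let mut out: Poly = [f.v, 0.0, 0.0, 0.0, 0.0];
    for a in 0..2 {
        add_scaled(&mut out, &inp[a], f.g[a]);
        for b in 0..2 {
            let ab = poly_mul(&inp[a], &inp[b]);
            add_scaled(&mut out, &ab, 0.5 * f.h[a][b]);
            for c in 0..2 {
                let abc = poly_mul(&ab, &inp[c]);
                add_scaled(&mut out, &abc, f.t3[a][b][c] / 6.0);
                for d in 0..2 {
                    let abcd = poly_mul(&abc, &inp[d]);
                    add_scaled(&mut out, &abcd, f.t4[a][b][c][d] / 24.0);
                }
            }
        }
    }
    out
}
```

Substituting gives `G(θ) = (1 + θ)²θ = θ + 2θ² + θ³`, so `substitute_1d` should return `[0, 1, 2, 1, 0]`. The cross term `2θ²` only appears because both orderings `h[0][1]` and `h[1][0]` are summed.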
1664
1665/// The exact θ-derivative tower of a moving-LIMIT integral's BOUNDARY
1666/// contribution: given the edge-position tower `z_edge(θ)` over the `K`
1667/// primaries and the integrand `B` evaluated-and-differentiated at the edge
1668/// value as the stack `b_stack = [B(z₀), B′(z₀), B″(z₀), B‴(z₀)]`
1669/// (`z₀ = z_edge.v`), returns the tower of `Φ(z_edge(θ))` where `Φ′ = B`.
1670///
1671/// Rationale: `∫^{z_edge(θ)} B(z) dz = Φ(z_edge(θ)) + C` with `Φ` an
1672/// antiderivative of `B`, so the boundary part of every θ-derivative of the
1673/// integral is the matching derivative of the composition `Φ ∘ z_edge`. Its
1674/// Faà di Bruno expansion carries, at one stroke, EVERY Leibniz boundary
1675/// term, including those the hand-written flux dropped: the first-order `B·z_u`, the second-order
1676/// `B′·z_u·z_v + B·z_uv` (the `G_z·z_u·z_v` self-flux AND the previously
1677/// dropped `G·z_uv`), and the full third/fourth-order continuations. The
1678/// VALUE channel of the returned tower is meaningless (`Φ` is only defined up
1679/// to a constant); callers read only the derivative channels and pair this
1680/// with the interior moment-integral value separately.
1681///
1682/// `b_stack` holds `B` and its first three z-derivatives; the antiderivative
1683/// `Φ` contributes only as the order-≥1 channels, so `compose_unary` receives
1684/// `[0, B, B′, B″, B‴]` — the leading `0` is the discarded `Φ(z₀)` slot.
1685pub fn moving_limit_boundary_tower<const K: usize>(
1686    z_edge: &Tower4<K>,
1687    b_stack: [f64; 4],
1688) -> Tower4<K> {
1689    z_edge.compose_unary([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]])
1690}
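For a single primary, `moving_limit_boundary_tower` reduces to the univariate Faà di Bruno formula applied to `[0, B, B′, B″, B‴]`. The sketch below writes that formula out through fourth order. `compose4` and `boundary_tower` are hypothetical stand-ins for `compose_unary` on a `Tower4<1>`. The test case (`B = cos`, so `Φ = sin`, with a hypothetical edge `z(θ) = θ²`) is chosen because `sin(θ²)` is easy to differentiate by hand.

```rust
/// Univariate Faà di Bruno through fourth order.
/// `f` = [f, f′, f″, f‴, f⁗] evaluated at the inner value u0 = g(θ0);
/// `g` = [g, g′, g″, g‴, g⁗] at θ0. Returns the derivatives of f∘g at θ0.
fn compose4(f: [f64; 5], g: [f64; 5]) -> [f64; 5] {
    let (g1, g2, g3, g4) = (g[1], g[2], g[3], g[4]);
    [
        f[0],
        f[1] * g1,
        f[2] * g1 * g1 + f[1] * g2,
        f[3] * g1.powi(3) + 3.0 * f[2] * g1 * g2 + f[1] * g3,
        f[4] * g1.powi(4)
            + 6.0 * f[3] * g1 * g1 * g2
            + f[2] * (3.0 * g2 * g2 + 4.0 * g1 * g3)
            + f[1] * g4,
    ]
}

/// Boundary tower of ∫^{z(θ)} B dz: compose the edge jet with [0, B, B′, B″, B‴].
/// Slot 0 (the value of Φ) is a placeholder; only slots 1..=4 are meaningful.
fn boundary_tower(z_edge: [f64; 5], b_stack: [f64; 4]) -> [f64; 5] {
    compose4([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]], z_edge)
}
```

The second-order channel `2 cos(θ²) − 4θ² sin(θ²)` is `B·z_θθ + B′·z_θ²`, i.e. both the `B·z_uv` term and the self-flux term listed above.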
1691
1692/// The boundary-flux derivative tower of a single moving cell integral
1693/// `∫_{z_L(θ)}^{z_R(θ)} B dz`: `Φ(z_R(θ)) − Φ(z_L(θ))`, assembled from the
1694/// two edge towers and the integrand stacks at each edge. The returned
1695/// tower's derivative channels are the EXACT moving-boundary contribution to
1696/// every θ-derivative of the cell integral, to fourth order, with no term
1697/// hand-omitted. A `Fixed` (non-moving) edge passes a `z_edge` whose
1698/// derivative channels are all zero, contributing nothing — matching the
1699/// production `edge_vel = 0` short-circuit.
1700pub fn cell_moving_boundary_flux_tower<const K: usize>(
1701    z_right: &Tower4<K>,
1702    b_stack_right: [f64; 4],
1703    z_left: &Tower4<K>,
1704    b_stack_left: [f64; 4],
1705) -> Tower4<K> {
1706    moving_limit_boundary_tower(z_right, b_stack_right)
1707        - moving_limit_boundary_tower(z_left, b_stack_left)
1708}
1709
1710/// Moving-limit boundary tower for a θ-DEPENDENT integrand `G(z; θ)`.
1711///
1712/// [`moving_limit_boundary_tower`] assumes the integrand depends on θ only
1713/// through the moving edge `z_edge(θ)` (a fixed z-derivative `b_stack`). The
1714/// marginal-slope flex boundary is richer: the integrand `G(z; θ)` ALSO carries
1715/// its own θ-dependence (the density weight `w = e^{−q}/2π` and the cell
1716/// integrand coefficients move with η, hence with the primaries), so the
1717/// Leibniz expansion of `∂ⁿ_θ ∫^{z_edge(θ)} G(z;θ) dz` mixes edge-motion
1718/// derivatives of the limit with θ-derivatives of `G` itself — e.g. at second
1719/// order `G·z_uv + G_z·z_u·z_v + G_{θu}·z_v + G_{θv}·z_u` (the four
1720/// edge-motion-carrying terms the hand path assembles one by one, including the
1721/// `G·z_uv` term the directional path drops).
1722///
1723/// Mechanization: let `Φ(z; θ)` be the z-antiderivative of `G` (so `Φ_z = G`).
1724/// The full upper-limit contribution is `Φ(z_edge(θ); θ)`, and the BOUNDARY
1725/// part — everything carrying edge motion — is exactly
1726///   `Φ(z_edge(θ); θ) − Φ(z₀; θ)`,
1727/// the second term being the pure-integrand-θ part (`∫^{z₀} ∂ⁿ_θ G`) the
1728/// interior moment integral already supplies. Both are one
1729/// [`substitute_intercept`] of the SAME mixed `(z, θ)` jet of `Φ` (z in slot 0,
1730/// θ in slots 1..=K): substituting the edge tower gives the full composite,
1731/// substituting a frozen constant edge isolates the pure-θ part, and their
1732/// difference is the exact boundary flux — every Leibniz term derived by the
1733/// substitution algebra, none hand-omitted.
1734///
1735/// `phi_jet` is the `(K+1)`-variable Taylor jet of `Φ` about `(z₀, θ₀)` with
1736/// `z₀ = z_edge.v`: slot 0 is the z-direction (so `phi_jet.g[0] = G(z₀;θ₀)`,
1737/// `phi_jet.h[0][0] = G_z`, …) and slots `1..=K` are the primaries θ (carrying
1738/// `Φ`'s own θ- and mixed z·θ-derivatives — i.e. the integrand's θ-derivatives
1739/// integrated in z, and `G_{θ…}` in the mixed slots). The returned tower's
1740/// VALUE channel is 0 by construction (the `Φ(z₀;θ₀)` constants cancel); only
1741/// the derivative channels are meaningful, matching the value-less convention of
1742/// [`moving_limit_boundary_tower`].
1743pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
1744    phi_jet: &Tower4<K1>,
1745    z_edge: &Tower4<K>,
1746) -> Tower4<K> {
1747    assert_eq!(
1748        K1,
1749        K + 1,
1750        "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
1751    );
1752    let frozen_edge = Tower4::<K>::constant(z_edge.v);
1753    let full = substitute_intercept(phi_jet, z_edge);
1754    let interior = substitute_intercept(phi_jet, &frozen_edge);
1755    full - interior
1756}
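The `full − interior` construction can be sanity-checked at first order with a toy integrand. The following sketch assumes a hypothetical `G(z; θ) = θ·z` (antiderivative `Φ(z; θ) = θ·z²/2`) and a hypothetical edge `z_edge(θ) = θ`. It uses central differences purely as an independent check; the functions here compute with finite differences, not jets. The difference should equal the Leibniz boundary term `G(z₀; θ₀)·z_edge′(θ₀) = θ₀²`, while the frozen-edge part supplies `∫_0^{z₀} G_θ dz = z₀²/2`.

```rust
/// Hypothetical z-antiderivative of G(z; θ) = θ·z, so Φ_z = G.
fn phi(z: f64, theta: f64) -> f64 {
    theta * z * z / 2.0
}

/// Hypothetical moving edge z_edge(θ) = θ.
fn z_edge(theta: f64) -> f64 {
    theta
}

/// Central difference of a scalar function of θ.
fn d_dtheta(f: impl Fn(f64) -> f64, theta: f64, step: f64) -> f64 {
    (f(theta + step) - f(theta - step)) / (2.0 * step)
}

/// First-order (full, interior, boundary) at θ0, with boundary = full − interior:
///   full     = d/dθ Φ(z_edge(θ); θ)      edge and integrand both move
///   interior = d/dθ Φ(z0; θ), z0 frozen  integrand moves only
fn boundary_flux_first_order(theta0: f64) -> (f64, f64, f64) {
    let z0 = z_edge(theta0);
    let h = 1e-5;
    let full = d_dtheta(|t| phi(z_edge(t), t), theta0, h);
    let interior = d_dtheta(|t| phi(z0, t), theta0, h);
    (full, interior, full - interior)
}
```

At `θ₀ = 0.8` this should give `full ≈ 0.96` (the derivative of `θ³/2`), `interior ≈ 0.32` and `boundary ≈ 0.64`.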
1757
1758/// Two-edge cell version of [`moving_limit_boundary_tower_theta_integrand`]:
1759/// the exact boundary-flux tower of `∫_{z_L(θ)}^{z_R(θ)} G(z;θ) dz` with a
1760/// θ-dependent integrand, `Φ(z_R;θ) − Φ(z_L;θ)` minus the pure-θ parts at each
1761/// frozen edge. A `Fixed` edge passes a `z_edge` with zero derivative channels,
1762/// so its `full` and `interior` substitutions coincide and it contributes
1763/// nothing — matching the production `edge_vel = 0` short-circuit.
1764pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
1765    phi_jet_right: &Tower4<K1>,
1766    z_right: &Tower4<K>,
1767    phi_jet_left: &Tower4<K1>,
1768    z_left: &Tower4<K>,
1769) -> Tower4<K> {
1770    moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right)
1771        - moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left)
1772}
1773
1774// ── The program seam ─────────────────────────────────────────────────
1775
1776/// A family's row negative log-likelihood written ONCE over tower scalars.
1777///
1778/// This is the single source of truth #932 asks for: the value channel of
1779/// the returned tower must BE the production row NLL (same branches, same
1780/// guards, same numerics), and every derivative channel is then exact by
1781/// construction. The linear Jacobian wiring (coefficients ↔ primaries) is
1782/// NOT part of this trait — it is family data, not calculus, and stays on
1783/// the `RowKernel` implementor.
1784pub trait RowNllProgram<const K: usize>: Send + Sync {
1785    /// Number of observations the program covers.
1786    fn n_rows(&self) -> usize;
1787
1788    /// Current primary-scalar values for `row` (where to seed the tower).
1789    fn primaries(&self, row: usize) -> Result<[f64; K], String>;
1790
1791    /// The row NLL evaluated on tower scalars. `p[a]` arrives pre-seeded as
1792    /// variable `a` at the current primary value; implementations combine
1793    /// them with `Tower4` arithmetic and per-row data (response, censoring
1794    /// indicators, offsets) entering as constants.
1795    fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
1796}
1797
1798/// Evaluate a program's full tower at the current primaries for one row.
1799///
1800/// One call yields every `RowKernel` calculus channel; callers that need
1801/// several contractions of the same row should hold the returned tower and
1802/// contract repeatedly rather than re-evaluating.
1803pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
1804    prog: &P,
1805    row: usize,
1806) -> Result<Tower4<K>, String> {
1807    let p = prog.primaries(row)?;
1808    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
1809    prog.row_nll(row, &vars)
1810}
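The seeding done by `evaluate_program` (`Tower4::variable(p[a], a)`, a unit derivative in the primary's own slot) can be mimicked with a much smaller jet. This is a minimal sketch assuming a hypothetical first-order `Dual<K>` (value and gradient only, whereas `Tower4<K>` also carries `h`, `t3`, `t4`) and a hypothetical Gaussian location-scale row NLL `ℓ(μ, s) = s + (y − μ)²·e^{−2s}/2` with `s = log σ`, written once.

```rust
/// Hypothetical first-order jet over K primaries: value plus gradient.
#[derive(Clone, Copy, Debug)]
struct Dual<const K: usize> {
    v: f64,
    g: [f64; K],
}

impl<const K: usize> Dual<K> {
    fn constant(c: f64) -> Self {
        Self { v: c, g: [0.0; K] }
    }
    /// Seed primary `slot` at value `x`: unit derivative in its own slot.
    fn variable(x: f64, slot: usize) -> Self {
        let mut g = [0.0; K];
        g[slot] = 1.0;
        Self { v: x, g }
    }
    fn add(&self, o: &Self) -> Self {
        Self { v: self.v + o.v, g: std::array::from_fn(|i| self.g[i] + o.g[i]) }
    }
    /// Leibniz rule truncated at first order.
    fn mul(&self, o: &Self) -> Self {
        Self { v: self.v * o.v, g: std::array::from_fn(|i| self.g[i] * o.v + self.v * o.g[i]) }
    }
    fn exp(&self) -> Self {
        let e = self.v.exp();
        Self { v: e, g: self.g.map(|d| e * d) }
    }
}

type D = Dual<2>;

/// Row NLL written once over the jet; p = [μ, s].
fn gaussian_row_nll(y: f64, p: &[D; 2]) -> D {
    let resid = D::constant(y).add(&p[0].mul(&D::constant(-1.0)));
    let inv_var = p[1].mul(&D::constant(-2.0)).exp();
    let quad = resid.mul(&resid).mul(&inv_var).mul(&D::constant(0.5));
    p[1].add(&quad)
}

/// Mirror of `evaluate_program`: seed each primary, then run the body.
fn evaluate_row(y: f64, primaries: [f64; 2]) -> D {
    let vars: [D; 2] = std::array::from_fn(|a| D::variable(primaries[a], a));
    gaussian_row_nll(y, &vars)
}
```

At `y = 2`, `μ = 1`, `s = 0.5` the gradient should be `[−e^{−1}, 1 − e^{−1}]`, matching the closed forms `∂ℓ/∂μ = −(y − μ)e^{−2s}` and `∂ℓ/∂s = 1 − (y − μ)²e^{−2s}`.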
1811
1812/// Mechanically derived `row_kernel` channel: `(nll, ∇, H)`.
1813pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
1814    prog: &P,
1815    row: usize,
1816) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1817    let t = evaluate_program(prog, row)?;
1818    Ok((t.v, t.g, t.h))
1819}
1820
1821/// Mechanically derived `row_third_contracted` channel.
1822pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1823    prog: &P,
1824    row: usize,
1825    dir: &[f64; K],
1826) -> Result<[[f64; K]; K], String> {
1827    Ok(evaluate_program(prog, row)?.third_contracted(dir))
1828}
1829
1830/// Mechanically derived `row_fourth_contracted` channel.
1831pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1832    prog: &P,
1833    row: usize,
1834    dir_u: &[f64; K],
1835    dir_v: &[f64; K],
1836) -> Result<[[f64; K]; K], String> {
1837    Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
1838}
1839
1840// ── The generic program seam (#932 scalar cutover) ───────────────────
1841
1842/// A family's row negative log-likelihood written ONCE over the generic
1843/// [`crate::jet_scalar::JetScalar`] interface, so the SAME expression can be
1844/// re-instantiated at whatever order / representation a consumer needs
1845/// ([`crate::jet_scalar::Order2`] for `(v, g, H)`,
1846/// [`crate::jet_scalar::OneSeed`] for the contracted third,
1847/// [`crate::jet_scalar::TwoSeed`] for the contracted fourth, or the full
1848/// [`Tower4`] for every channel at once).
1849///
1850/// This is additive to [`RowNllProgram`] (which is `Tower4`-specialised): a
1851/// program implementing this generic trait gets the small contracted scalars for
1852/// free, dissolving the dense-`Tower4<9>` cost objection in the location-scale
1853/// gates (doc §A.4). An existing `Tower4`-only [`RowNllProgram`] continues to
1854/// work unchanged; new families should prefer this generic trait.
1855///
1856/// Because a `Tower4`-specialised `row_nll` body uses only
1857/// `add`/`sub`/`mul`/`scale`/`exp`/`ln`/… — all of which this trait also
1858/// provides — the same body is expressible directly over `S: JetScalar<K>`.
1859/// A program written that way needs no `Tower4`-specialised method and routes
1860/// the directional and joint-Hessian gates through the contracted scalars from
1861/// a single definition.
1862pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
1863    /// Number of observations the program covers.
1864    fn n_rows(&self) -> usize;
1865
1866    /// Current primary-scalar values for `row` (where to seed the scalar).
1867    fn primaries(&self, row: usize) -> Result<[f64; K], String>;
1868
1869    /// The row NLL evaluated on a generic jet scalar. `p[a]` arrives pre-seeded
1870    /// (base value + per-scalar nilpotent directions) by the caller; the body
1871    /// uses ONLY [`crate::jet_scalar::JetScalar`] ops and per-row data
1872    /// (response, censoring, offsets) entering as constants.
1873    fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
1874        &self,
1875        row: usize,
1876        p: &[S; K],
1877    ) -> Result<S, String>;
1878}
1879
1880/// Evaluate a generic program at the value/gradient/Hessian scalar
1881/// [`crate::jet_scalar::Order2`], returning `(nll, ∇, H)` — the
1882/// `row_kernel` channel — WITHOUT materialising any third / fourth tensor.
1883///
1884/// This is the production seam for the inner-Newton `(v, g, H)` path: the row
1885/// loss is written ONCE in `row_nll_generic`, and this routes it through the
1886/// cheap order-2 scalar. The single source of truth means the gradient and
1887/// Hessian cannot desync from the value (the #736 / #948 bug genus).
1888pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1889    prog: &P,
1890    row: usize,
1891) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1892    let base = prog.primaries(row)?;
1893    let vars: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|a| {
1894        <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(base[a], a)
1895    });
1896    let s = prog.row_nll_generic(row, &vars)?;
1897    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
1898}
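The point of `RowNllProgramGeneric` is that one body serves several scalar types. The sketch below shows the pattern in miniature. `Scalar`, `Order2` and `poisson_row_nll` are hypothetical single-primary simplifications, not the crate's `JetScalar`/`Order2<K>`. The same Poisson body `ℓ(η) = e^η − y·η` runs on plain `f64` (value only) and on a `(v, g, h)` jet, and the two value channels agree bit for bit.

```rust
/// Hypothetical minimal scalar interface (a stand-in for JetScalar<1>).
trait Scalar: Copy {
    fn constant(c: f64) -> Self;
    fn add(&self, o: &Self) -> Self;
    fn mul(&self, o: &Self) -> Self;
    fn exp(&self) -> Self;
}

impl Scalar for f64 {
    fn constant(c: f64) -> Self { c }
    fn add(&self, o: &Self) -> Self { self + o }
    fn mul(&self, o: &Self) -> Self { self * o }
    fn exp(&self) -> Self { f64::exp(*self) } // inherent f64::exp
}

/// Univariate order-2 jet: value, first and second derivative.
#[derive(Clone, Copy, Debug)]
struct Order2 { v: f64, g: f64, h: f64 }

impl Scalar for Order2 {
    fn constant(c: f64) -> Self { Order2 { v: c, g: 0.0, h: 0.0 } }
    fn add(&self, o: &Self) -> Self {
        Order2 { v: self.v + o.v, g: self.g + o.g, h: self.h + o.h }
    }
    /// Leibniz through second order: (fg)″ = f″g + 2f′g′ + fg″.
    fn mul(&self, o: &Self) -> Self {
        Order2 {
            v: self.v * o.v,
            g: self.g * o.v + self.v * o.g,
            h: self.h * o.v + 2.0 * self.g * o.g + self.v * o.h,
        }
    }
    /// Faà di Bruno with f = f′ = f″ = e^v.
    fn exp(&self) -> Self {
        let e = self.v.exp();
        Order2 { v: e, g: e * self.g, h: e * (self.h + self.g * self.g) }
    }
}

/// Poisson row NLL written ONCE: ℓ(η) = e^η − y·η.
fn poisson_row_nll<S: Scalar>(y: f64, eta: &S) -> S {
    eta.exp().add(&eta.mul(&S::constant(-y)))
}
```

Seeding `Order2 { v: η, g: 1.0, h: 0.0 }` yields `g = e^η − y` and `h = e^η`, the Poisson gradient and Hessian, from the same expression that produced the value.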
1899
1900/// Evaluate a generic program at the one-seed scalar
1901/// [`crate::jet_scalar::OneSeed`], returning the contracted third
1902/// `Σ_c ℓ_{abc} dir_c` — the `row_third_contracted(dir)` channel — WITHOUT
1903/// materialising the dense `t3` tensor. The contraction direction is folded
1904/// INTO the differentiation by the nilpotent ε seeded with `dir`.
1905pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1906    prog: &P,
1907    row: usize,
1908    dir: &[f64; K],
1909) -> Result<[[f64; K]; K], String> {
1910    let base = prog.primaries(row)?;
1911    let vars: [crate::jet_scalar::OneSeed<K>; K] =
1912        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
1913    let s = prog.row_nll_generic(row, &vars)?;
1914    Ok(s.contracted_third())
1915}
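The identity behind a one-seed contraction is that `Σ_c ℓ_abc dir_c` is the directional derivative of the Hessian along `dir`, i.e. the ε-coefficient of `H(p + ε·dir)` when `ε² = 0`. The sketch below illustrates only that identity. It assumes a hypothetical `ℓ(x, y) = x³y + y²` whose Hessian `[[6xy, 3x²], [3x², 2]]` is written out by hand over a first-order `Dual`; per the doc above, `OneSeed` instead obtains the contraction from the single `row_nll_generic` body. A dense `t3` contraction is included as a reference.

```rust
/// First-order dual number re + eps·ε with ε² = 0.
#[derive(Clone, Copy, Debug)]
struct Dual { re: f64, eps: f64 }

impl Dual {
    fn new(re: f64, eps: f64) -> Self { Dual { re, eps } }
    fn constant(re: f64) -> Self { Dual { re, eps: 0.0 } }
    fn mul(self, o: Dual) -> Dual {
        Dual { re: self.re * o.re, eps: self.eps * o.re + self.re * o.eps }
    }
    fn scale(self, s: f64) -> Dual { Dual { re: s * self.re, eps: s * self.eps } }
}

/// Hand-written Hessian of the illustrative ℓ(x, y) = x³y + y².
fn hessian(x: Dual, y: Dual) -> [[Dual; 2]; 2] {
    let off = x.mul(x).scale(3.0);
    [[x.mul(y).scale(6.0), off], [off, Dual::constant(2.0)]]
}

/// Σ_c ℓ_abc dir_c, read off the ε-part of H(p + ε·dir).
fn third_contracted(p: [f64; 2], dir: [f64; 2]) -> [[f64; 2]; 2] {
    let h = hessian(Dual::new(p[0], dir[0]), Dual::new(p[1], dir[1]));
    [[h[0][0].eps, h[0][1].eps], [h[1][0].eps, h[1][1].eps]]
}

/// Dense reference: build the full t3 of ℓ, then contract its last index.
fn third_contracted_dense(p: [f64; 2], dir: [f64; 2]) -> [[f64; 2]; 2] {
    let (x, y) = (p[0], p[1]);
    let mut t3 = [[[0.0; 2]; 2]; 2];
    t3[0][0][0] = 6.0 * y; // ℓ_xxx
    t3[0][0][1] = 6.0 * x; // ℓ_xxy, in all three index orders
    t3[0][1][0] = 6.0 * x;
    t3[1][0][0] = 6.0 * x;
    let mut out = [[0.0; 2]; 2];
    for a in 0..2 {
        for b in 0..2 {
            for c in 0..2 {
                out[a][b] += t3[a][b][c] * dir[c];
            }
        }
    }
    out
}
```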
1916
1917/// Evaluate a generic program at the two-seed scalar
1918/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
1919/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `row_fourth_contracted(u, v)` channel —
1920/// WITHOUT materialising the dense `t4` tensor.
1921pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1922    prog: &P,
1923    row: usize,
1924    dir_u: &[f64; K],
1925    dir_v: &[f64; K],
1926) -> Result<[[f64; K]; K], String> {
1927    let base = prog.primaries(row)?;
1928    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
1929        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
1930    let s = prog.row_nll_generic(row, &vars)?;
1931    Ok(s.contracted_fourth())
1932}
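The contracted fourth works the same way with two independent nilpotents `ε₁`, `ε₂` (`ε₁² = ε₂² = 0`): `Σ_{cd} ℓ_abcd u_c v_d` is the `ε₁ε₂` coefficient of `H(p + ε₁u + ε₂v)`. This is again a hedged sketch. `BiDual` is a hypothetical type, and the Hessian of the illustrative `ℓ(x, y) = x³y + y²` is hand-written. Its only nonzero fourth partial is `ℓ_xxxy = 6`, so the expected result is `C[0][0] = 6(u_x v_y + u_y v_x)`, `C[0][1] = C[1][0] = 6 u_x v_x`, `C[1][1] = 0`.

```rust
/// v + e1·ε1 + e2·ε2 + e12·ε1ε2, with ε1² = ε2² = 0.
#[derive(Clone, Copy, Debug)]
struct BiDual { v: f64, e1: f64, e2: f64, e12: f64 }

impl BiDual {
    fn seed(x: f64, u: f64, w: f64) -> Self { BiDual { v: x, e1: u, e2: w, e12: 0.0 } }
    fn constant(x: f64) -> Self { BiDual { v: x, e1: 0.0, e2: 0.0, e12: 0.0 } }
    fn mul(self, o: BiDual) -> BiDual {
        BiDual {
            v: self.v * o.v,
            e1: self.e1 * o.v + self.v * o.e1,
            e2: self.e2 * o.v + self.v * o.e2,
            // ε1ε2 part: each side's own ε1ε2 term plus the two cross terms.
            e12: self.e12 * o.v + self.e1 * o.e2 + self.e2 * o.e1 + self.v * o.e12,
        }
    }
    fn scale(self, s: f64) -> BiDual {
        BiDual { v: s * self.v, e1: s * self.e1, e2: s * self.e2, e12: s * self.e12 }
    }
}

/// Hand-written Hessian of the illustrative ℓ(x, y) = x³y + y².
fn hessian(x: BiDual, y: BiDual) -> [[BiDual; 2]; 2] {
    let off = x.mul(x).scale(3.0);
    [[x.mul(y).scale(6.0), off], [off, BiDual::constant(2.0)]]
}

/// Σ_{cd} ℓ_abcd u_c v_d, read off the ε1ε2 part of H(p + ε1·u + ε2·v).
fn fourth_contracted(p: [f64; 2], dir_u: [f64; 2], dir_v: [f64; 2]) -> [[f64; 2]; 2] {
    let h = hessian(
        BiDual::seed(p[0], dir_u[0], dir_v[0]),
        BiDual::seed(p[1], dir_u[1], dir_v[1]),
    );
    [[h[0][0].e12, h[0][1].e12], [h[1][0].e12, h[1][1].e12]]
}
```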
1933
1934/// Evaluate a generic program at the full dense [`Tower4`] scalar, returning
1935/// every channel `(v, g, h, t3, t4)` in one pass. Used where the UNCONTRACTED
1936/// third / fourth tensors are needed (the BMS rigid `third_full` / `fourth_full`
1937/// caches): the dense tensors come from the SAME `row_nll_generic` expression
1938/// the order-2 / contracted scalars consume, so there is a single source of
1939/// truth across every channel.
1940pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1941    prog: &P,
1942    row: usize,
1943) -> Result<Tower4<K>, String> {
1944    let base = prog.primaries(row)?;
1945    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(base[a], a));
1946    prog.row_nll_generic(row, &vars)
1947}
1948
1949// ── The RowJet bridge: one row-NLL body over scalar jets AND lane towers ─
1950//
1951// `JetScalar<K>` (jet_scalar.rs) abstracts the SCALAR jets — its `value()`
1952// returns one `f64`, so the `f64x4` lane towers ([`Tower3Lane`] / [`Tower4Lane`])
1953// CANNOT implement it (their value channel holds four rows). `compose_unary_with`
1954// exists on BOTH the scalar towers and the lane towers, but only as separate
1955// inherent methods, not behind a shared trait bound. So a row-NLL body
1956// written `<S: JetScalar<K>>` could not be instantiated at `Tower4Lane`, and the
1957// 4-rows-per-pass SIMD batch path could not reuse the single source.
1958//
1959// [`RowJet<K>`] is that shared bound. It exposes exactly the ops a row-NLL body
1960// needs — `constant` / `variable` / `add` / `sub` / `mul` / `scale` / `neg`, the
1961// value-derived `compose_unary_with`, and a per-lane domain `guard` — over BOTH
1962// representations. A blanket impl makes every scalar `JetScalar<K>` a `RowJet<K>`
1963// (so the scalar call sites compile unchanged and bit-identically), and explicit
1964// impls route the `f64x4` lane towers through their existing per-lane methods. A
1965// body written once over `R: RowJet<K>` then instantiates at a scalar jet for the
1966// `(v, g, H)` / contracted-tensor channels AND at a lane tower for the batch.
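The bridge's key move, an associated `Value` type that is `f64` for a scalar and `[f64; 4]` for a four-row lane, can be shown without any derivative channels. In the sketch below `RowVal`, `Lanes4` and `weighted_sq` are hypothetical names, and `Lanes4` is a plain array rather than an `f64x4`. The body applies a continuous per-row weight through `scale_rows`/`pack_rows`. Lane `i` of the batch reproduces the scalar evaluation on row `i` bit for bit because both run the same operations in the same order.

```rust
/// Hypothetical cut-down analogue of the RowJet idea (value channel only).
trait RowVal: Copy {
    type Value: Copy;
    fn constant(c: f64) -> Self;
    fn add(&self, o: &Self) -> Self;
    fn mul(&self, o: &Self) -> Self;
    fn scale_rows(&self, s: Self::Value) -> Self;
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;
}

impl RowVal for f64 {
    type Value = f64;
    fn constant(c: f64) -> Self { c }
    fn add(&self, o: &Self) -> Self { self + o }
    fn mul(&self, o: &Self) -> Self { self * o }
    fn scale_rows(&self, s: Self::Value) -> Self { self * s }
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value {
        value_of(rows[0]) // one lane, one row
    }
}

/// Four rows side by side; a plain array standing in for an f64x4 lane.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Lanes4([f64; 4]);

impl RowVal for Lanes4 {
    type Value = [f64; 4];
    fn constant(c: f64) -> Self { Lanes4([c; 4]) }
    fn add(&self, o: &Self) -> Self { Lanes4(std::array::from_fn(|i| self.0[i] + o.0[i])) }
    fn mul(&self, o: &Self) -> Self { Lanes4(std::array::from_fn(|i| self.0[i] * o.0[i])) }
    fn scale_rows(&self, s: Self::Value) -> Self {
        Lanes4(std::array::from_fn(|i| self.0[i] * s[i]))
    }
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value {
        std::array::from_fn(|i| value_of(rows[i])) // lane i reads row rows[i]
    }
}

/// A body written once: per-row weighted squared residual w_r · (x − 1)².
fn weighted_sq<R: RowVal>(x: &R, rows: &[usize], weights: &[f64]) -> R {
    let d = x.add(&R::constant(-1.0));
    d.mul(&d).scale_rows(R::pack_rows(rows, |r| weights[r]))
}
```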
1967
1968/// The verdict of a per-lane [`RowJet::guard`] domain check.
1969///
1970/// A scalar jet (a [`crate::jet_scalar::JetScalar`] via the blanket impl) carries
1971/// ONE value, so it reports `lanes == 1` and a one-bit mask. A lane tower
1972/// ([`Tower3Lane`] / [`Tower4Lane`] over `f64x4`) carries FOUR rows, so it reports
1973/// `lanes == 4` and one mask bit per lane. The mask lets a batched program bail
1974/// exactly the offending 4-group to the scalar tail ([`any_failed`](Self::any_failed)),
1975/// or inspect which lanes tripped ([`lane_failed`](Self::lane_failed)).
1976#[derive(Clone, Copy, Debug, PartialEq, Eq)]
1977pub struct GuardVerdict {
1978    lanes: u8,
1979    failed_mask: u8,
1980}
1981
1982impl GuardVerdict {
1983    /// A scalar (1-lane) verdict: `pass == true` ⇒ no failure.
1984    #[inline]
1985    pub fn scalar(pass: bool) -> Self {
1986        Self { lanes: 1, failed_mask: if pass { 0 } else { 1 } }
1987    }
1988    /// A 4-lane verdict from a per-lane failure mask (bit `i` ⇒ lane `i` failed).
1989    #[inline]
1990    pub fn lanes4(failed_mask: u8) -> Self {
1991        Self { lanes: 4, failed_mask: failed_mask & 0x0f }
1992    }
1993    /// Number of active lanes inspected (1 scalar, 4 batch).
1994    #[inline]
1995    pub fn lanes(self) -> usize {
1996        self.lanes as usize
1997    }
1998    /// True iff every inspected lane satisfied the predicate.
1999    #[inline]
2000    pub fn all_pass(self) -> bool {
2001        self.failed_mask == 0
2002    }
2003    /// True iff at least one inspected lane failed the predicate.
2004    #[inline]
2005    pub fn any_failed(self) -> bool {
2006        self.failed_mask != 0
2007    }
2008    /// True iff lane `i` failed the predicate.
2009    #[inline]
2010    pub fn lane_failed(self, i: usize) -> bool {
2011        (self.failed_mask >> i) & 1 == 1
2012    }
2013    /// The raw failure mask (bit `i` ⇒ lane `i` failed).
2014    #[inline]
2015    pub fn failed_mask(self) -> u8 {
2016        self.failed_mask
2017    }
2018}
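A lane tower's `guard` has to reduce four per-lane predicate results to the mask that `GuardVerdict::lanes4` accepts. A minimal sketch of that reduction follows; `lane_failure_mask` and `lane_failed` are hypothetical free functions that mirror the bit convention above (bit `i` set means lane `i` failed).

```rust
/// Build the per-lane failure mask: bit i is set iff `pred` rejects lane i.
fn lane_failure_mask(values: [f64; 4], pred: impl Fn(f64) -> bool) -> u8 {
    let mut mask = 0u8;
    for (i, &v) in values.iter().enumerate() {
        if !pred(v) {
            mask |= 1 << i;
        }
    }
    mask
}

/// Same test as `GuardVerdict::lane_failed`.
fn lane_failed(mask: u8, i: usize) -> bool {
    (mask >> i) & 1 == 1
}
```

Writing the predicate as the domain condition (`v > 0.0` before an `ln`) also fails NaN lanes, since every comparison with NaN is false.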
2019
2020/// Copy-or-zero-pad a derivative stack from length `N` to length `M`. Used by the
2021/// [`RowJet::compose_unary_with`] impls to bridge a program's chosen stack length
2022/// to each tower's native compose width ([`Tower4Lane`]: 5, [`Tower3Lane`]: 4).
2023/// `M ≥ N` zero-pads the unseeded high derivatives; `M < N` drops the unused tail
2024/// (both total), so the order-`(M−1)` tower reads exactly the channels it needs
2025/// and never indexes out of bounds. With `N == M` it is a verbatim copy (the
2026/// common `N == 5` stack into a [`Tower4Lane`] is passed through bit-identically).
2027#[inline]
2028fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
2029    let mut out = [0.0_f64; M];
2030    let m = N.min(M);
2031    out[..m].copy_from_slice(&s[..m]);
2032    out
2033}
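The widen/narrow behaviour of `resize_stack` is easy to pin down with a few calls. The function is private, so this sketch re-declares it verbatim.

```rust
/// Copy-or-zero-pad a derivative stack from length N to length M.
fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
    let mut out = [0.0_f64; M];
    let m = N.min(M);
    out[..m].copy_from_slice(&s[..m]);
    out
}
```

Narrowing a 5-entry stack to a `Tower3Lane`-style width of 4 drops `f⁗`; widening a 3-entry stack pads `f‴` and `f⁗` with zeros. A zero-padded entry is used as-is by the wider tower's higher-order channels, so padding a short stack amounts to asserting that those higher derivatives are zero.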
2034
2035/// The shared row-NLL algebra over BOTH the scalar jets and the `f64x4` lane
2036/// towers — the bound that lets ONE single-source row-NLL body SIMD-batch 4
2037/// rows/pass without a dual-source copy (module §"The RowJet bridge").
2038///
2039/// Every scalar [`crate::jet_scalar::JetScalar<K>`] is a `RowJet<K>` via the
2040/// blanket impl below (`Value = f64`), bit-identically to its `JetScalar`
2041/// methods; [`Tower3Lane`] / [`Tower4Lane`] over `f64x4` are `RowJet<K>` with
2042/// `Value = [f64; 4]`, routing through their per-lane methods so lane `i` of a
2043/// batched evaluation is `to_bits`-identical to the scalar evaluation on row `i`.
2044pub trait RowJet<const K: usize>: Copy {
2045    /// The value channel(s) seen by [`guard`](Self::guard) and
2046    /// [`values`](Self::values): a single `f64` on a scalar jet, `[f64; 4]` on an
2047    /// `f64x4` lane tower.
2048    type Value: Copy;
2049
2050    /// A constant (value `c`, all derivatives zero), broadcast to every lane.
2051    fn constant(c: f64) -> Self;
2052    /// The seeded primary `slot` at value `x` (unit first derivative in `slot`),
2053    /// broadcast to every lane. Per-lane-DISTINCT seeding for the batch path is
2054    /// done by the lane instantiators ([`generic_batched_fourth_tower`] /
2055    /// [`generic_batched_third_tower`]), which build the tower variables directly
2056    /// from each row's primaries; this method is for any row-invariant auxiliary
2057    /// variable a body introduces.
2058    fn variable(x: f64, slot: usize) -> Self;
2059    /// The value channel(s): `f64` (scalar) or `[f64; 4]` (lane).
2060    fn values(&self) -> Self::Value;
2061
2062    /// Truncated Leibniz `self + o`.
2063    fn add(&self, o: &Self) -> Self;
2064    /// Truncated Leibniz `self − o`.
2065    fn sub(&self, o: &Self) -> Self;
2066    /// Truncated Leibniz `self · o`.
2067    fn mul(&self, o: &Self) -> Self;
2068    /// Multiply every channel by the plain scalar `s`.
2069    fn scale(&self, s: f64) -> Self;
2070    /// Negate every channel. Defaults to `scale(-1.0)`; the blanket overrides it
2071    /// to delegate to [`crate::jet_scalar::JetScalar::neg`].
2072    fn neg(&self) -> Self {
2073        self.scale(-1.0)
2074    }
2075
2076    /// Faà di Bruno compose with a unary special function whose `[f64; N]`
2077    /// derivative stack is built from the running base value PER LANE through
2078    /// `stack_fn`. This is the SHARED-TRAIT version of the `compose_unary_with`
2079    /// inherent method that already exists on both the scalar towers and the lane
2080    /// towers: on a scalar jet `stack_fn` is run once at the value; on an `f64x4`
2081    /// lane tower it is re-run per lane (the four rows carry four distinct base
2082    /// values), so lane `i` is `to_bits`-identical to the scalar result on row `i`.
2083    /// Making it a trait method is precisely what lets a body written once over
2084    /// `R: RowJet<K>` instantiate at the batch towers. `N` is widened/narrowed to
2085    /// the tower's native width by [`resize_stack`] (`N == 5` is a verbatim copy for [`Tower4Lane`]).
2086    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self;
2087
2088    /// Per-lane domain guard: evaluate `pred` on each active lane's value channel
2089    /// and report which lanes failed (see [`GuardVerdict`]). A scalar jet checks
2090    /// its one value; a lane tower checks all four. Lets a batched program detect
2091    /// an out-of-domain row in a 4-group and bail that group to the scalar tail.
2092    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict;
2093
2094    /// Per-lane scale: multiply every channel by the per-lane factor `s`
2095    /// ([`Self::Value`]). On a scalar jet `Self::Value = f64`, so this is exactly
2096    /// [`scale`](Self::scale) and the scalar call sites stay BIT-IDENTICAL when
2097    /// `.scale(x)` is rewritten to `.scale_rows(x)`; on an `f64x4` lane tower
2098    /// `Self::Value = [f64; 4]` and lane `i` is multiplied by `s[i]`. This is the
2099    /// primitive that lets a batched body carry CONTINUOUS per-row data — the
2100    /// survival `covariance_ones` / `z_sum` / observation-weight `wi` factors that
2101    /// enter the jet algebra as `.scale(per_row_value)` and that the single-`f64`
2102    /// [`scale`](Self::scale) would broadcast wrongly across the four rows. Build
2103    /// `s` from the lane→row map with [`pack_rows`](Self::pack_rows).
2104    fn scale_rows(&self, s: Self::Value) -> Self;
2105
2106    /// Gather a per-lane auxiliary datum from the lane→row map `rows`: `value_of(r)`
2107    /// is evaluated for each active lane's row and packed into [`Self::Value`] (a
2108    /// single `f64` on a scalar jet, `[f64; 4]` on an `f64x4` lane tower). This is
2109    /// how a body written once over [`RowJet`] feeds per-row CONTINUOUS data (the
2110    /// arguments to [`scale_rows`](Self::scale_rows)) into the batch path without
2111    /// knowing the concrete representation: the program holds the per-row data and
2112    /// the caller threads `rows` (length 1 scalar, length 4 batch) into
2113    /// [`RowNllProgramRowJet::row_nll`], so the body writes
2114    /// `x.scale_rows(R::pack_rows(rows, |r| self.cov(r)))`. A multiplicative weight
2115    /// buried in a `compose_unary_with` stack is pulled out the same way:
2116    /// `x.compose_unary_with(|u| stack(u, 1.0)).scale_rows(R::pack_rows(rows, |r| self.wi(r)))`.
2117    /// (Binary per-row branches such as the event indicator `di` are kept
2118    /// lane-uniform by grouping and the [`guard`](Self::guard) bail, not packed.)
2119    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;
2120
2121    // ── value-derived transcendental conveniences ───────────────────────
2122    // Each routes through `compose_unary_with` with the SAME derivative stack the
2123    // corresponding `JetScalar` method uses, so on a scalar jet (blanket) the
2124    // result is bit-identical to the `JetScalar` method, and on a lane tower lane
2125    // `i` is bit-identical to the scalar result on row `i`.
2126
2127    /// `e^self`.
2128    fn exp(&self) -> Self {
2129        self.compose_unary_with(|u| {
2130            let e = u.exp();
2131            [e, e, e, e, e]
2132        })
2133    }
2134    /// `ln(self)`. Caller guarantees positivity.
2135    fn ln(&self) -> Self {
2136        self.compose_unary_with(|u| {
2137            let r = 1.0 / u;
2138            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
2139        })
2140    }
2141    /// `√self`. Caller guarantees positivity.
2142    fn sqrt(&self) -> Self {
2143        self.compose_unary_with(|u| {
2144            let s = u.sqrt();
2145            [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
2146        })
2147    }
2148    /// `1/self`.
2149    fn recip(&self) -> Self {
2150        self.compose_unary_with(|u| {
2151            let r = 1.0 / u;
2152            let r2 = r * r;
2153            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
2154        })
2155    }
2156    /// `self^a` for real `a`. Caller guarantees a positive base.
2157    fn powf(&self, a: f64) -> Self {
2158        self.compose_unary_with(move |u| {
2159            [
2160                u.powf(a),
2161                a * u.powf(a - 1.0),
2162                a * (a - 1.0) * u.powf(a - 2.0),
2163                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
2164                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
2165            ]
2166        })
2167    }
2168    /// `ln Γ(self)`. Caller guarantees a positive argument.
2169    fn ln_gamma(&self) -> Self {
2170        self.compose_unary_with(ln_gamma_derivative_stack)
2171    }
2172    /// `ψ(self)` (digamma). Caller guarantees a positive argument.
2173    fn digamma(&self) -> Self {
2174        self.compose_unary_with(digamma_derivative_stack)
2175    }
2176}
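The hard-coded `sqrt` and `recip` stacks above are the power-rule stack evaluated at `a = 0.5` and `a = -1`. The falling-factorial products `a(a-1)(a-2)...` produce the literals `0.5, -0.25, 0.375, -0.9375` and `-1, 2, -6, 24`. The std-only sketch below rebuilds both stacks and checks them against a generic power-rule stack, which is one cheap way to audit such literals. The helper names are hypothetical and are not part of the crate.

```rust
/// Hypothetical helper: power-rule stack [f, f', f'', f''', f''''] for u^a (u > 0),
/// built from the falling factorial a (a-1) (a-2) ...
fn powf_stack(u: f64, a: f64) -> [f64; 5] {
    let mut out = [0.0; 5];
    let mut coeff = 1.0; // running product a (a-1) ... (a-n+1)
    for n in 0..5 {
        out[n] = coeff * u.powf(a - n as f64);
        coeff *= a - n as f64;
    }
    out
}

/// The hand-written sqrt stack, copied from `RowJet::sqrt`.
fn sqrt_stack(u: f64) -> [f64; 5] {
    let s = u.sqrt();
    [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
}

/// The hand-written reciprocal stack, copied from `RowJet::recip`.
fn recip_stack(u: f64) -> [f64; 5] {
    let r = 1.0 / u;
    let r2 = r * r;
    [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
}

/// Hypothetical helper: entrywise agreement, relative to max(|x|, |y|, 1).
fn stacks_agree(x: &[f64; 5], y: &[f64; 5], tol: f64) -> bool {
    x.iter()
        .zip(y)
        .all(|(p, q)| (p - q).abs() <= tol * p.abs().max(q.abs()).max(1.0))
}
```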
2177
2178/// Blanket: every scalar [`crate::jet_scalar::JetScalar<K>`] is a [`RowJet<K>`]
2179/// with `Value = f64`. Each op delegates to the identical `JetScalar` method, so
2180/// the existing scalar call sites compile UNCHANGED and bit-identically — the
2181/// bridge adds the lane representation without churning the scalar path. (The
2182/// concrete lane impls below cannot overlap this: [`Tower3Lane`] / [`Tower4Lane`]
2183/// are local types that do not implement `JetScalar`, and the orphan rule forbids
2184/// any downstream impl, so the coherence checker proves the impls disjoint.)
2185impl<const K: usize, S: crate::jet_scalar::JetScalar<K>> RowJet<K> for S {
2186    type Value = f64;
2187    #[inline]
2188    fn constant(c: f64) -> Self {
2189        <S as crate::jet_scalar::JetScalar<K>>::constant(c)
2190    }
2191    #[inline]
2192    fn variable(x: f64, slot: usize) -> Self {
2193        <S as crate::jet_scalar::JetScalar<K>>::variable(x, slot)
2194    }
2195    #[inline]
2196    fn values(&self) -> f64 {
2197        crate::jet_scalar::JetScalar::value(self)
2198    }
2199    #[inline]
2200    fn add(&self, o: &Self) -> Self {
2201        crate::jet_scalar::JetScalar::add(self, o)
2202    }
2203    #[inline]
2204    fn sub(&self, o: &Self) -> Self {
2205        crate::jet_scalar::JetScalar::sub(self, o)
2206    }
2207    #[inline]
2208    fn mul(&self, o: &Self) -> Self {
2209        crate::jet_scalar::JetScalar::mul(self, o)
2210    }
2211    #[inline]
2212    fn scale(&self, s: f64) -> Self {
2213        crate::jet_scalar::JetScalar::scale(self, s)
2214    }
2215    #[inline]
2216    fn neg(&self) -> Self {
2217        crate::jet_scalar::JetScalar::neg(self)
2218    }
2219    #[inline]
2220    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2221        crate::jet_scalar::JetScalar::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2222    }
2223    #[inline]
2224    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2225        GuardVerdict::scalar(pred(crate::jet_scalar::JetScalar::value(self)))
2226    }
2227    #[inline]
2228    fn scale_rows(&self, s: f64) -> Self {
2229        // `Value == f64`, so per-lane scale is exactly `scale` — the rewrite
2230        // `.scale(x)` → `.scale_rows(x)` is bit-identical on the scalar path.
2231        crate::jet_scalar::JetScalar::scale(self, s)
2232    }
2233    #[inline]
2234    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
2235        value_of(rows[0])
2236    }
2237}
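The following std-only toy shows why `scale_rows`/`pack_rows` let one body serve both representations. It models only the per-row weighting seam of the bridge. A hypothetical two-method trait uses `Value = f64` for a single row and `Value = [f64; 4]` for a four-row batch, and one generic body weights its input by a per-row covariate. The names `ToyRowJet` and `weighted_body` are illustrative assumptions. The real `RowJet` carries full derivative towers, not bare values.

```rust
/// Hypothetical stand-in for `RowJet`: only the per-row weighting seam.
trait ToyRowJet: Sized {
    type Value;
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;
    fn scale_rows(&self, s: Self::Value) -> Self;
}

/// Scalar instantiation: one row, so `pack_rows` reads `rows[0]`.
impl ToyRowJet for f64 {
    type Value = f64;
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
        value_of(rows[0])
    }
    fn scale_rows(&self, s: f64) -> f64 {
        self * s
    }
}

/// Batch instantiation: four rows, one lane each.
impl ToyRowJet for [f64; 4] {
    type Value = [f64; 4];
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
        [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
    }
    fn scale_rows(&self, s: [f64; 4]) -> [f64; 4] {
        // Lane i is multiplied by its own row's weight s[i].
        std::array::from_fn(|i| self[i] * s[i])
    }
}

/// Hypothetical body, written once: scale `x` by the per-row covariate `cov[row]`.
fn weighted_body<R: ToyRowJet>(rows: &[usize], x: &R, cov: &[f64]) -> R {
    x.scale_rows(R::pack_rows(rows, |r| cov[r]))
}
```

Calling `weighted_body` with a one-element `rows` slice and an `f64` gives the scalar path. Calling it with four rows and a `[f64; 4]` gives the batch, and lane `i` of the batch matches the scalar call on `rows[i]`.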
2238
2239/// The `f64x4` lane [`Tower4Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2240/// routing each op through its existing per-lane method. Lane `i` of a batched
2241/// evaluation is `to_bits`-identical to the scalar [`Tower4`] evaluation on row
2242/// `i` (the per-lane methods are term-for-term lifts of the scalar tower).
2243impl<const K: usize> RowJet<K> for Tower4Lane<wide::f64x4, K> {
2244    type Value = [f64; 4];
2245    #[inline]
2246    fn constant(c: f64) -> Self {
2247        Tower4Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2248    }
2249    #[inline]
2250    fn variable(x: f64, slot: usize) -> Self {
2251        Tower4Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2252    }
2253    #[inline]
2254    fn values(&self) -> [f64; 4] {
2255        self.v.to_array()
2256    }
2257    #[inline]
2258    fn add(&self, o: &Self) -> Self {
2259        Tower4Lane::add(self, o)
2260    }
2261    #[inline]
2262    fn sub(&self, o: &Self) -> Self {
2263        Tower4Lane::sub(self, o)
2264    }
2265    #[inline]
2266    fn mul(&self, o: &Self) -> Self {
2267        Tower4Lane::mul(self, o)
2268    }
2269    #[inline]
2270    fn scale(&self, s: f64) -> Self {
2271        Tower4Lane::scale(self, s)
2272    }
2273    #[inline]
2274    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2275        Tower4Lane::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2276    }
2277    #[inline]
2278    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2279        let vals = self.v.to_array();
2280        let mut mask = 0u8;
2281        for (i, &v) in vals.iter().enumerate() {
2282            if !pred(v) {
2283                mask |= 1 << i;
2284            }
2285        }
2286        GuardVerdict::lanes4(mask)
2287    }
2288    #[inline]
2289    fn scale_rows(&self, s: [f64; 4]) -> Self {
2290        // True per-lane scale: lane `i` of every channel is multiplied by `s[i]`,
2291        // so lane `i` matches the scalar `Tower4::scale(s[i])` on row `i`.
2292        let sl = wide::f64x4::new(s);
2293        let mut out = *self;
2294        out.v = self.v * sl;
2295        for i in 0..K {
2296            out.g[i] = self.g[i] * sl;
2297            for j in 0..K {
2298                out.h[i][j] = self.h[i][j] * sl;
2299                for k in 0..K {
2300                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
2301                    for l in 0..K {
2302                        out.t4[i][j][k][l] = self.t4[i][j][k][l] * sl;
2303                    }
2304                }
2305            }
2306        }
2307        out
2308    }
2309    #[inline]
2310    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2311        [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2312    }
2313}
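The lane `guard` sets bit `i` of a `u8` mask for every lane whose value FAILS the predicate, so a zero mask means all four rows may proceed. The std-only sketch below reproduces that bit-packing. The helper names are hypothetical, and `GuardVerdict::lanes4` itself is not modeled.

```rust
/// Hypothetical helper: bit i is set when lane i fails `pred`
/// (the same loop as the lane `guard` impls).
fn failing_lanes_mask(vals: [f64; 4], pred: impl Fn(f64) -> bool) -> u8 {
    let mut mask = 0u8;
    for (i, &v) in vals.iter().enumerate() {
        if !pred(v) {
            mask |= 1 << i;
        }
    }
    mask
}

/// Hypothetical helper: indices of the lanes that passed.
fn passing_lanes(mask: u8) -> Vec<usize> {
    (0..4).filter(|i| (mask & (1 << *i)) == 0).collect()
}
```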
2314
2315/// The `f64x4` lane [`Tower3Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2316/// the order-≤3 sibling of the [`Tower4Lane`] impl. When a body passes `N == 5`
2317/// stacks, `resize_stack::<N, 4>` drops the (unused) fourth-derivative entry,
2318/// matching the scalar [`Tower3`], which also carries only up to the third tensor.
2319impl<const K: usize> RowJet<K> for Tower3Lane<wide::f64x4, K> {
2320    type Value = [f64; 4];
2321    #[inline]
2322    fn constant(c: f64) -> Self {
2323        Tower3Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2324    }
2325    #[inline]
2326    fn variable(x: f64, slot: usize) -> Self {
2327        Tower3Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2328    }
2329    #[inline]
2330    fn values(&self) -> [f64; 4] {
2331        self.v.to_array()
2332    }
2333    #[inline]
2334    fn add(&self, o: &Self) -> Self {
2335        Tower3Lane::add(self, o)
2336    }
2337    #[inline]
2338    fn sub(&self, o: &Self) -> Self {
2339        Tower3Lane::sub(self, o)
2340    }
2341    #[inline]
2342    fn mul(&self, o: &Self) -> Self {
2343        Tower3Lane::mul(self, o)
2344    }
2345    #[inline]
2346    fn scale(&self, s: f64) -> Self {
2347        Tower3Lane::scale(self, s)
2348    }
2349    #[inline]
2350    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2351        Tower3Lane::compose_unary_with(self, |u| resize_stack::<N, 4>(stack_fn(u)))
2352    }
2353    #[inline]
2354    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2355        let vals = self.v.to_array();
2356        let mut mask = 0u8;
2357        for (i, &v) in vals.iter().enumerate() {
2358            if !pred(v) {
2359                mask |= 1 << i;
2360            }
2361        }
2362        GuardVerdict::lanes4(mask)
2363    }
2364    #[inline]
2365    fn scale_rows(&self, s: [f64; 4]) -> Self {
2366        let sl = wide::f64x4::new(s); // lane `i` of every channel is multiplied by `s[i]`
2367        let mut out = *self;
2368        out.v = self.v * sl;
2369        for i in 0..K {
2370            out.g[i] = self.g[i] * sl;
2371            for j in 0..K {
2372                out.h[i][j] = self.h[i][j] * sl;
2373                for k in 0..K {
2374                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
2375                }
2376            }
2377        }
2378        out
2379    }
2380    #[inline]
2381    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2382        [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2383    }
2384}
2385
2386/// A family's row negative log-likelihood written ONCE over the [`RowJet`]
2387/// bridge, so the SAME body instantiates at the scalar jets (for the `(v, g, H)`
2388/// and contracted-tensor channels) AND at the `f64x4` lane towers (for the
2389/// 4-rows-per-pass SIMD batch). This is the lane-capable successor to
2390/// [`RowNllProgramGeneric`]: a body written here gets the scalar channels through
2391/// [`rowjet_row_kernel`] / [`rowjet_third_contracted`] / [`rowjet_fourth_contracted`]
2392/// and the batched channels through [`generic_batched_fourth_tower`] /
2393/// [`generic_batched_third_tower`], all from a single source.
2394pub trait RowNllProgramRowJet<const K: usize>: Send + Sync {
2395    /// Number of observations the program covers.
2396    fn n_rows(&self) -> usize;
2397
2398    /// Current primary-scalar values for `row` (where to seed each lane).
2399    fn primaries(&self, row: usize) -> Result<[f64; K], String>;
2400
2401    /// The row NLL evaluated on the [`RowJet`] bridge. `rows` is the lane→row map
2402    /// (length 1 for a scalar instantiation, length 4 for a batch); `p[a]` arrives
2403    /// pre-seeded by the caller (base value plus, for the directional scalars, the
2404    /// nilpotent contraction directions). The body uses ONLY [`RowJet`] ops and
2405    /// per-row data entering through `rows`/`self` as constants.
2406    fn row_nll<R: RowJet<K>>(&self, rows: &[usize], p: &[R; K]) -> Result<R, String>;
2407}
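The sketch below illustrates the write-once contract on a toy Poisson log-link row NLL, `ℓ(η) = exp(η) − y·η`. The family and all names are hypothetical, and the toy jet is univariate and second-order only. The body is generic over a small jet trait, so the same source evaluates at a plain `f64` (value only) and at a seeded `(v, g, h)` jet. The gradient and Hessian of the seeded jet must then equal `exp(η) − y` and `exp(η)`.

```rust
/// Hypothetical minimal jet interface (a tiny subset of what `RowJet` offers).
trait Jet: Copy {
    fn constant(c: f64) -> Self;
    fn add(self, o: Self) -> Self;
    fn mul(self, o: Self) -> Self;
    fn exp(self) -> Self;
}

impl Jet for f64 {
    fn constant(c: f64) -> Self { c }
    fn add(self, o: Self) -> Self { self + o }
    fn mul(self, o: Self) -> Self { self * o }
    fn exp(self) -> Self { f64::exp(self) } // inherent f64::exp, not recursion
}

/// Hypothetical univariate second-order jet: value, first and second derivative.
#[derive(Clone, Copy, Debug)]
struct Jet2 { v: f64, g: f64, h: f64 }

impl Jet2 {
    /// Seed the primary variable: unit first derivative.
    fn variable(x: f64) -> Self { Jet2 { v: x, g: 1.0, h: 0.0 } }
}

impl Jet for Jet2 {
    fn constant(c: f64) -> Self { Jet2 { v: c, g: 0.0, h: 0.0 } }
    fn add(self, o: Self) -> Self {
        Jet2 { v: self.v + o.v, g: self.g + o.g, h: self.h + o.h }
    }
    // Leibniz: (ab)'' = a''b + 2a'b' + ab''.
    fn mul(self, o: Self) -> Self {
        Jet2 {
            v: self.v * o.v,
            g: self.g * o.v + self.v * o.g,
            h: self.h * o.v + 2.0 * self.g * o.g + self.v * o.h,
        }
    }
    // Chain rule with f = f' = f'' = e^u: (f(x))'' = f'' x'^2 + f' x''.
    fn exp(self) -> Self {
        let e = self.v.exp();
        Jet2 { v: e, g: e * self.g, h: e * (self.g * self.g + self.h) }
    }
}

/// Hypothetical row NLL, written ONCE: exp(eta) - y * eta.
fn poisson_row_nll<J: Jet>(eta: J, y: f64) -> J {
    eta.exp().add(J::constant(-y).mul(eta))
}
```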
2408
2409/// Evaluate a [`RowNllProgramRowJet`] at the value/gradient/Hessian scalar
2410/// [`crate::jet_scalar::Order2`] (the `(v, g, H)` inner-Newton channel) — the
2411/// `RowJet` twin of [`generic_row_kernel`].
2412pub fn rowjet_row_kernel<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2413    prog: &P,
2414    row: usize,
2415) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
2416    let base = prog.primaries(row)?;
2417    let vars: [crate::jet_scalar::Order2<K>; K] =
2418        std::array::from_fn(|a| <crate::jet_scalar::Order2<K> as RowJet<K>>::variable(base[a], a));
2419    let s = prog.row_nll(&[row], &vars)?;
2420    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
2421}
2422
2423/// Evaluate a [`RowNllProgramRowJet`] at the one-seed scalar
2424/// [`crate::jet_scalar::OneSeed`], returning the contracted third
2425/// `Σ_c ℓ_{abc} dir_c` — the `RowJet` twin of [`generic_third_contracted`].
2426pub fn rowjet_third_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2427    prog: &P,
2428    row: usize,
2429    dir: &[f64; K],
2430) -> Result<[[f64; K]; K], String> {
2431    let base = prog.primaries(row)?;
2432    let vars: [crate::jet_scalar::OneSeed<K>; K] =
2433        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
2434    let s = prog.row_nll(&[row], &vars)?;
2435    Ok(s.contracted_third())
2436}
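The matrix returned by `rowjet_third_contracted` can be written out against a full third tensor as `M[a][b] = Σ_c t3[a][b][c]·dir[c]`. The one-seed scalar produces this matrix directly from the seeded direction. The explicit std-only contraction below is only a reference sketch, and its function names are hypothetical. It is checked on `ℓ = p0²·p1`, whose only nonzero third derivatives are `∂³ℓ/∂p0∂p0∂p1 = 2` and its index permutations.

```rust
/// Hypothetical reference contraction: M[a][b] = sum_c t3[a][b][c] * dir[c].
fn third_contracted<const K: usize>(t3: &[[[f64; K]; K]; K], dir: &[f64; K]) -> [[f64; K]; K] {
    let mut m = [[0.0; K]; K];
    for a in 0..K {
        for b in 0..K {
            for c in 0..K {
                m[a][b] += t3[a][b][c] * dir[c];
            }
        }
    }
    m
}

/// Full third tensor of l(p) = p0^2 * p1 (constant in p).
fn t3_p0sq_p1() -> [[[f64; 2]; 2]; 2] {
    let mut t3 = [[[0.0; 2]; 2]; 2];
    // Every ordering of the index multiset {0, 0, 1} carries d^3 l = 2.
    t3[0][0][1] = 2.0;
    t3[0][1][0] = 2.0;
    t3[1][0][0] = 2.0;
    t3
}
```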
2437
2438/// Evaluate a [`RowNllProgramRowJet`] at the two-seed scalar
2439/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
2440/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `RowJet` twin of [`generic_fourth_contracted`].
2441pub fn rowjet_fourth_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2442    prog: &P,
2443    row: usize,
2444    dir_u: &[f64; K],
2445    dir_v: &[f64; K],
2446) -> Result<[[f64; K]; K], String> {
2447    let base = prog.primaries(row)?;
2448    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
2449        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
2450    let s = prog.row_nll(&[row], &vars)?;
2451    Ok(s.contracted_fourth())
2452}
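The two-seed channel is the analogous double contraction, `M[a][b] = Σ_{c,d} t4[a][b][c][d]·u[c]·v[d]`. The reference sketch below uses hypothetical names. It is checked on `ℓ = p0²·p1²`, whose only nonzero fourth derivative is `∂⁴ℓ/∂p0²∂p1² = 4`, repeated in each of the six orderings of `{0, 0, 1, 1}`.

```rust
/// Hypothetical reference contraction: M[a][b] = sum_{c,d} t4[a][b][c][d] * u[c] * v[d].
fn fourth_contracted<const K: usize>(
    t4: &[[[[f64; K]; K]; K]; K],
    u: &[f64; K],
    v: &[f64; K],
) -> [[f64; K]; K] {
    let mut m = [[0.0; K]; K];
    for a in 0..K {
        for b in 0..K {
            for c in 0..K {
                for d in 0..K {
                    m[a][b] += t4[a][b][c][d] * u[c] * v[d];
                }
            }
        }
    }
    m
}

/// Full fourth tensor of l(p) = p0^2 * p1^2: entries with two 0s and two 1s equal 4.
fn t4_p0sq_p1sq() -> [[[[f64; 2]; 2]; 2]; 2] {
    let mut t4 = [[[[0.0; 2]; 2]; 2]; 2];
    for a in 0..2 {
        for b in 0..2 {
            for c in 0..2 {
                for d in 0..2 {
                    if a + b + c + d == 2 {
                        t4[a][b][c][d] = 4.0;
                    }
                }
            }
        }
    }
    t4
}
```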
2453
2454/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower4Batch`],
2455/// computing the FULL `(v, g, H, t3, t4)` for FOUR rows in one SIMD pass — the
2456/// lane twin of [`generic_full_tower`]. Each of the four lanes is seeded with its
2457/// own row's primaries, so [`Tower4Batch::lane`]`(i)` is `to_bits`-identical to
2458/// the scalar [`generic_full_tower`] on `rows[i]`.
2459pub fn generic_batched_fourth_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2460    prog: &P,
2461    rows: [usize; 4],
2462) -> Result<Tower4Batch<K>, String> {
2463    let bases: [[f64; K]; 4] = [
2464        prog.primaries(rows[0])?,
2465        prog.primaries(rows[1])?,
2466        prog.primaries(rows[2])?,
2467        prog.primaries(rows[3])?,
2468    ];
2469    let vars: [Tower4Batch<K>; K] = std::array::from_fn(|a| {
2470        let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2471        Tower4Batch::variable(lane_vals, a)
2472    });
2473    prog.row_nll(&rows, &vars)
2474}
2475
2476/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower3Batch`],
2477/// computing `(v, g, H, t3)` for FOUR rows in one SIMD pass — the order-≤3 lane
2478/// twin of [`generic_full_tower`]. [`Tower3Batch::lane`]`(i)` is
2479/// `to_bits`-identical to the order-≤3 scalar evaluation on `rows[i]`.
2480pub fn generic_batched_third_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2481    prog: &P,
2482    rows: [usize; 4],
2483) -> Result<Tower3Batch<K>, String> {
2484    let bases: [[f64; K]; 4] = [
2485        prog.primaries(rows[0])?,
2486        prog.primaries(rows[1])?,
2487        prog.primaries(rows[2])?,
2488        prog.primaries(rows[3])?,
2489    ];
2490    let vars: [Tower3Batch<K>; K] = std::array::from_fn(|a| {
2491        let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2492        Tower3Batch::variable(lane_vals, a)
2493    });
2494    prog.row_nll(&rows, &vars)
2495}
2496
2497// ── The oracle ───────────────────────────────────────────────────────
2498
2499/// One row's worth of hand-written kernel outputs, as claimed by a
2500/// `RowKernel` implementation, packaged for verification against the
2501/// tower truth. Plain data (no trait coupling) so any kernel — whatever
2502/// its visibility — can be audited from its own test module.
2503pub struct KernelChannels<const K: usize> {
2504    /// Claimed NLL value (the `v` of `row_kernel`'s `(nll, ∇, H)`).
2505    pub value: f64,
2506    /// Claimed gradient.
2507    pub gradient: [f64; K],
2508    /// Claimed Hessian.
2509    pub hessian: [[f64; K]; K],
2510    /// Claimed `row_third_contracted(dir)` outputs as `(dir, claim)` pairs.
2511    pub third: Vec<([f64; K], [[f64; K]; K])>,
2512    /// Claimed `row_fourth_contracted(u, v)` outputs as `(u, v, claim)`.
2513    pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
2514}
2515
2516/// Channel-by-channel audit of a hand-written kernel against the
2517/// single-expression tower truth. On disagreement, returns `Err` naming the
2518/// first mismatching channel and index with its claimed and true values;
2519/// designed as the body of the per-family CI oracle tests (#932 deployment step 2).
2520///
2521/// Tolerance is PER ENTRY, mixed absolute/relative: each comparison uses
2522/// `|claim − truth| ≤ atol + rel_tol · max(|claim|, |truth|)`. The absolute
2523/// floor `atol = rel_tol` lets exact-zero entries of structurally sparse
2524/// towers pass without demanding bit-equality. Entries are NOT measured
2525/// against the largest entry of the channel (no per-channel magnitude floor),
2526/// so a small cross-block entry dropped next to a huge one is still caught once
2527/// it exceeds `atol`. Genuine sign flips (#736) and dropped channels are loud.
2528///
2529/// Non-finite handling is strict: a NaN on either side always fails; an
2530/// infinity passes only when both sides are the SAME signed infinity.
2531pub fn verify_kernel_channels<const K: usize>(
2532    tower: &Tower4<K>,
2533    claims: &KernelChannels<K>,
2534    rel_tol: f64,
2535) -> Result<(), String> {
2536    // Absolute floor: reuse rel_tol so a single knob controls both the
2537    // relative band and the absolute floor for entries near zero.
2538    let atol = rel_tol;
2539    let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
2540        // Non-finite values must be settled before the band test below: NaN
2541        // compares false, and an infinite input makes `band` infinite, so
2542        // `diff > band` would never fire and such mismatches would slip through.
2543        // NaN on either side errs; infinities pass only as the same signed infinity.
2544        if !claim.is_finite() || !truth.is_finite() {
2545            let agree = claim.is_infinite()
2546                && truth.is_infinite()
2547                && claim.is_sign_positive() == truth.is_sign_positive();
2548            if agree {
2549                return Ok(());
2550            }
2551            return Err(format!(
2552                "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
2553            ));
2554        }
2555        let band = atol + rel_tol * claim.abs().max(truth.abs());
2556        if (claim - truth).abs() > band {
2557            return Err(format!(
2558                "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
2559            ));
2560        }
2561        Ok(())
2562    };
2563
2564    check("value", claims.value, tower.v)?;
2565
2566    for a in 0..K {
2567        check(&format!("gradient[{a}]"), claims.gradient[a], tower.g[a])?;
2568    }
2569
2570    for a in 0..K {
2571        for b in 0..K {
2572            check(
2573                &format!("hessian[{a}][{b}]"),
2574                claims.hessian[a][b],
2575                tower.h[a][b],
2576            )?;
2577        }
2578    }
2579
2580    for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
2581        let truth = tower.third_contracted(dir);
2582        for a in 0..K {
2583            for b in 0..K {
2584                check(
2585                    &format!("third[{t_idx}][{a}][{b}]"),
2586                    claim[a][b],
2587                    truth[a][b],
2588                )?;
2589            }
2590        }
2591    }
2592
2593    for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
2594        let truth = tower.fourth_contracted(u, w);
2595        for a in 0..K {
2596            for b in 0..K {
2597                check(
2598                    &format!("fourth[{f_idx}][{a}][{b}]"),
2599                    claim[a][b],
2600                    truth[a][b],
2601                )?;
2602            }
2603        }
2604    }
2605
2606    Ok(())
2607}
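The per-entry acceptance rule of `verify_kernel_channels` can be isolated into a small predicate. This std-only sketch mirrors the `check` closure, and the name `entry_agrees` is hypothetical. Non-finite values are settled first; finite values must then satisfy `|claim − truth| ≤ atol + rel_tol·max(|claim|, |truth|)` with `atol = rel_tol`. The last assertion in the checks shows why the non-finite branch is needed: on its own, the band test lets a NaN claim through.

```rust
/// Hypothetical helper mirroring the oracle's per-entry rule; `atol` is tied to `rel_tol`.
fn entry_agrees(claim: f64, truth: f64, rel_tol: f64) -> bool {
    let atol = rel_tol;
    if !claim.is_finite() || !truth.is_finite() {
        // Only identical signed infinities agree; any NaN fails.
        return claim.is_infinite()
            && truth.is_infinite()
            && claim.is_sign_positive() == truth.is_sign_positive();
    }
    let band = atol + rel_tol * claim.abs().max(truth.abs());
    (claim - truth).abs() <= band
}

/// Hypothetical helper: the band test alone, without the non-finite guard.
fn naive_band_passes(claim: f64, truth: f64, rel_tol: f64) -> bool {
    let band = rel_tol + rel_tol * claim.abs().max(truth.abs());
    !((claim - truth).abs() > band) // NaN > band is false, so NaN "passes"
}
```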
2608
2609// ===========================================================================
2610// SIMD row-batched towers (#1151 follow-up): Tower3Lane / Tower4Lane
2611// ===========================================================================
2612//
2613// `Tower{3,4}Lane<L: Lane, K>` re-type every channel of `Tower{3,4}<K>` from a
2614// scalar `f64` to a SIMD lane field `L`. With `L = wide::f64x4` one instance
2615// carries FOUR rows at once, so a per-row kernel (BMS `row_nll`, survival
2616// `row_kernel`, `marginal_slope` `build_row_*_towers`) can evaluate 4 rows per
2617// vector pass instead of one per scalar pass.
2618//
2619// Every floating-point op is a DIRECT, term-for-term lift of the scalar
2620// `Tower{3,4}<K>` body — `a * b` -> `a.mul(b)`, `a + b` -> `a.add(b)`, a literal
2621// `c` -> `L::splat(c)` — in the SAME accumulation order. `wide::f64x4`
2622// add/sub/mul are lane-wise IEEE-754 ops with NO fused-multiply-add (Rust
2623// performs no fp-contraction), so lane `i` of any channel of a
2624// `Tower{3,4}Lane<wide::f64x4, K>` is `to_bits`-IDENTICAL to the scalar
2625// `Tower{3,4}<K>` channel computed on row `i` — exactly the structural
2626// bit-identity the existing [`crate::jet_scalar::Order2Lane`] relies on. Verified
2627// by the in-tree `batch_tests` (real `wide::f64x4`) and a standalone
2628// f64x4-model oracle, `K ∈ {2,3,4,9}`.
2629//
2630// Only the pure-arithmetic ops are lifted. Transcendentals (`exp`/`ln`/`sqrt`/`…`)
2631// route through scalar libm, which has no `f64x4` form; as on the scalar path,
2632// consumers build the per-lane derivative stack scalar-side (by hand, or via
2633// `compose_unary_with`, one scalar closure call per lane) for `compose_unary([L; _])`.
2634
2635use crate::jet_scalar::Lane;
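The bit-identity argument rests on two facts that a short std-only experiment can illustrate. First, a lane-wise operation applied in the same order reproduces the scalar result bit for bit. Second, a fused multiply-add, which rounds once instead of twice, can differ. In this sketch `[f64; 4]` stands in for `wide::f64x4` (an assumption; the real lane type is not used), `f64::mul_add` plays the role of a contracted FMA, and the function names are hypothetical.

```rust
/// Hypothetical helper: lane-wise `a * b + c`, rounding after the multiply and
/// after the add, exactly like the scalar expression on each row.
fn lanewise_mul_add_unfused(a: [f64; 4], b: [f64; 4], c: [f64; 4]) -> [f64; 4] {
    std::array::from_fn(|i| a[i] * b[i] + c[i])
}

/// Hypothetical helper: the same expression with a single rounding (fused multiply-add).
fn lanewise_fma(a: [f64; 4], b: [f64; 4], c: [f64; 4]) -> [f64; 4] {
    std::array::from_fn(|i| a[i].mul_add(b[i], c[i]))
}
```

In lane 0, `0.1 * 10.0` rounds to exactly `1.0`, so the unfused result is `0.0`. The fused form keeps the representation error of `0.1` and returns a tiny nonzero value instead.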
2636
2637/// Lane-batched [`Tower4`]: value / gradient / Hessian / 3rd / 4th tensors
2638/// carried in a SIMD field `L`. `Tower4Lane<f64x4, K>` lane `i` is
2639/// `to_bits`-identical to [`Tower4<K>`] on row `i`.
2640#[derive(Clone, Copy)]
2641pub struct Tower4Lane<L: Lane, const K: usize> {
2642    /// Value channel (one entry per lane/row).
2643    pub v: L,
2644    /// Gradient `∂/∂p_a`.
2645    pub g: [L; K],
2646    /// Hessian `∂²/∂p_a∂p_b`.
2647    pub h: [[L; K]; K],
2648    /// Third tensor `∂³`.
2649    pub t3: [[[L; K]; K]; K],
2650    /// Fourth tensor `∂⁴`.
2651    pub t4: [[[[L; K]; K]; K]; K],
2652}
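Because every tensor is stored in full (unsymmetrized), a `Tower4Lane` holds `1 + K + K² + K³ + K⁴` lane values and a `Tower3Lane` holds `1 + K + K² + K³`. The arithmetic below shows how fast this grows. At `K = 9`, one of the test sizes listed above, the fourth tensor alone is 6561 lanes. The symmetric count (distinct index multisets, `C(K+3, 4)` for the fourth tensor) is shown only for comparison; the towers do not exploit it. Both helper names are hypothetical.

```rust
/// Hypothetical helper: stored lanes in a full (unsymmetrized) tower of order
/// `order` over `k` variables, 1 + k + k^2 + ... + k^order.
fn full_tower_len(k: u64, order: u32) -> u64 {
    (0..=order).map(|n| k.pow(n)).sum()
}

/// Hypothetical helper: distinct entries of a symmetric order-`n` tensor in `k`
/// variables, C(k+n-1, n).
fn symmetric_tensor_len(k: u64, n: u64) -> u64 {
    // Each partial product is itself a binomial coefficient, so the division is exact.
    (1..=n).fold(1, |acc, i| acc * (k + i - 1) / i)
}
```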
2653
2654/// The 4-rows-per-pass batched [`Tower4`] (`wide::f64x4` lanes).
2655pub type Tower4Batch<const K: usize> = Tower4Lane<wide::f64x4, K>;
2656
2657impl<L: Lane, const K: usize> Tower4Lane<L, K> {
2658    /// All-zero tower (every channel `+0.0` in every lane).
2659    #[inline]
2660    pub fn zero() -> Self {
2661        let z = L::splat(0.0);
2662        Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K], t4: [[[[z; K]; K]; K]; K] }
2663    }
2664    /// Constant `c` (per lane): value channel only.
2665    #[inline]
2666    pub fn constant(c: L) -> Self {
2667        let mut o = Self::zero();
2668        o.v = c;
2669        o
2670    }
2671    /// Seeded variable `p_idx` at per-lane `value`: unit first derivative in
2672    /// slot `idx` (mirrors [`Tower4::variable`]).
2673    #[inline]
2674    pub fn variable(value: L, idx: usize) -> Self {
2675        let mut o = Self::constant(value);
2676        o.g[idx] = L::splat(1.0);
2677        o
2678    }
2679    /// Extract lane `i` as a scalar [`Tower4<K>`] (channel-for-channel).
2680    #[inline]
2681    pub fn lane(&self, i: usize) -> Tower4<K> {
2682        let mut out = Tower4::<K>::zero();
2683        out.v = self.v.lane(i);
2684        for a in 0..K {
2685            out.g[a] = self.g[a].lane(i);
2686            for b in 0..K {
2687                out.h[a][b] = self.h[a][b].lane(i);
2688                for c in 0..K {
2689                    out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2690                    for d in 0..K {
2691                        out.t4[a][b][c][d] = self.t4[a][b][c][d].lane(i);
2692                    }
2693                }
2694            }
2695        }
2696        out
2697    }
2698    /// Per-channel lane-wise `self + o` (mirrors `Tower4` `Add`).
2699    #[inline]
2700    pub fn add(&self, o: &Self) -> Self {
2701        let mut out = *self;
2702        out.v = self.v.add(o.v);
2703        for i in 0..K {
2704            out.g[i] = self.g[i].add(o.g[i]);
2705            for j in 0..K {
2706                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
2707                for k in 0..K {
2708                    out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
2709                    for l in 0..K {
2710                        out.t4[i][j][k][l] = self.t4[i][j][k][l].add(o.t4[i][j][k][l]);
2711                    }
2712                }
2713            }
2714        }
2715        out
2716    }
2717    /// Per-channel lane-wise `self - o` (mirrors `Tower4` `Sub`).
2718    #[inline]
2719    pub fn sub(&self, o: &Self) -> Self {
2720        let mut out = *self;
2721        out.v = self.v.sub(o.v);
2722        for i in 0..K {
2723            out.g[i] = self.g[i].sub(o.g[i]);
2724            for j in 0..K {
2725                out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
2726                for k in 0..K {
2727                    out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
2728                    for l in 0..K {
2729                        out.t4[i][j][k][l] = self.t4[i][j][k][l].sub(o.t4[i][j][k][l]);
2730                    }
2731                }
2732            }
2733        }
2734        out
2735    }
2736    /// Multiply every channel by the plain scalar `s` (mirrors `Tower4::scale`).
2737    #[inline]
2738    pub fn scale(&self, s: f64) -> Self {
2739        let sl = L::splat(s);
2740        let mut out = *self;
2741        out.v = self.v.mul(sl);
2742        for i in 0..K {
2743            out.g[i] = self.g[i].mul(sl);
2744            for j in 0..K {
2745                out.h[i][j] = self.h[i][j].mul(sl);
2746                for k in 0..K {
2747                    out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
2748                    for l in 0..K {
2749                        out.t4[i][j][k][l] = self.t4[i][j][k][l].mul(sl);
2750                    }
2751                }
2752            }
2753        }
2754        out
2755    }
2756    /// Leibniz product `self · o`, term-for-term lift of [`Tower4::mul`].
2757    #[inline]
2758    pub fn mul(&self, o: &Self) -> Self {
2759        let a = self;
2760        let b = o;
2761        let mut out = Self::zero();
2762        out.v = a.v.mul(b.v);
2763        for i in 0..K {
2764            let mut acc = L::splat(0.0);
2765            acc = acc.add(a.v.mul(b.g[i]));
2766            acc = acc.add(a.g[i].mul(b.v));
2767            out.g[i] = acc;
2768        }
2769        // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
2770        for i in 0..K {
2771            for j in i..K {
2772                let mut acc = L::splat(0.0);
2773                acc = acc.add(a.v.mul(b.h[i][j]));
2774                acc = acc.add(a.g[i].mul(b.g[j]));
2775                acc = acc.add(a.g[j].mul(b.g[i]));
2776                acc = acc.add(a.h[i][j].mul(b.v));
2777                out.h[i][j] = acc;
2778                out.h[j][i] = acc;
2779            }
2780        }
2781        for i in 0..K {
2782            for j in 0..K {
2783                for k in 0..K {
2784                    let mut acc = L::splat(0.0);
2785                    acc = acc.add(a.v.mul(b.t3[i][j][k]));
2786                    acc = acc.add(a.g[i].mul(b.h[j][k]));
2787                    acc = acc.add(a.g[j].mul(b.h[i][k]));
2788                    acc = acc.add(a.h[i][j].mul(b.g[k]));
2789                    acc = acc.add(a.g[k].mul(b.h[i][j]));
2790                    acc = acc.add(a.h[i][k].mul(b.g[j]));
2791                    acc = acc.add(a.h[j][k].mul(b.g[i]));
2792                    acc = acc.add(a.t3[i][j][k].mul(b.v));
2793                    out.t3[i][j][k] = acc;
2794                }
2795            }
2796        }
2797        for i in 0..K {
2798            for j in 0..K {
2799                for k in 0..K {
2800                    for l in 0..K {
2801                        let mut acc = L::splat(0.0);
2802                        acc = acc.add(a.v.mul(b.t4[i][j][k][l]));
2803                        acc = acc.add(a.g[i].mul(b.t3[j][k][l]));
2804                        acc = acc.add(a.g[j].mul(b.t3[i][k][l]));
2805                        acc = acc.add(a.h[i][j].mul(b.h[k][l]));
2806                        acc = acc.add(a.g[k].mul(b.t3[i][j][l]));
2807                        acc = acc.add(a.h[i][k].mul(b.h[j][l]));
2808                        acc = acc.add(a.h[j][k].mul(b.h[i][l]));
2809                        acc = acc.add(a.t3[i][j][k].mul(b.g[l]));
2810                        acc = acc.add(a.g[l].mul(b.t3[i][j][k]));
2811                        acc = acc.add(a.h[i][l].mul(b.h[j][k]));
2812                        acc = acc.add(a.h[j][l].mul(b.h[i][k]));
2813                        acc = acc.add(a.t3[i][j][l].mul(b.g[k]));
2814                        acc = acc.add(a.h[k][l].mul(b.h[i][j]));
2815                        acc = acc.add(a.t3[i][k][l].mul(b.g[j]));
2816                        acc = acc.add(a.t3[j][k][l].mul(b.g[i]));
2817                        acc = acc.add(a.t4[i][j][k][l].mul(b.v));
2818                        out.t4[i][j][k][l] = acc;
2819                    }
2820                }
2821            }
2822        }
2823        out
2824    }
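With `K = 1` every index collapses to the same variable. The 16 terms of the fourth-order Leibniz sum then group into the binomial pattern `1, 4, 6, 4, 1`: `(ab)⁗ = a·b⁗ + 4a′b‴ + 6a″b″ + 4a‴b′ + a⁗·b`. The std-only sketch below computes the same truncated product for a univariate jet stored as a `[f64; 5]` of `[f, f′, f″, f‴, f⁗]`. The function name is hypothetical.

```rust
/// Hypothetical helper: truncated product of two univariate jets
/// [f, f', f'', f''', f''''] by the general Leibniz rule
/// (fg)^(n) = sum_k C(n, k) f^(k) g^(n-k).
fn leibniz_mul(a: [f64; 5], b: [f64; 5]) -> [f64; 5] {
    // Rows of Pascal's triangle up to n = 4.
    const BINOM: [[f64; 5]; 5] = [
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 0.0, 0.0],
        [1.0, 2.0, 1.0, 0.0, 0.0],
        [1.0, 3.0, 3.0, 1.0, 0.0],
        [1.0, 4.0, 6.0, 4.0, 1.0],
    ];
    let mut out = [0.0; 5];
    for n in 0..5 {
        for k in 0..=n {
            out[n] += BINOM[n][k] * a[k] * b[n - k];
        }
    }
    out
}
```

Seeding `x = [2, 1, 0, 0, 0]` (the variable at `x = 2`) and squaring twice yields the derivatives of `x⁴` at 2, namely `[16, 32, 48, 48, 24]`.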
2825    /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
2826    /// [`Tower4::compose_unary`]. `d = [f, f′, f″, f‴, f⁗]` packed per lane
2827    /// (build via [`Lane::unary5`] from the scalar special-function stack).
2828    #[inline]
2829    pub fn compose_unary(&self, d: [L; 5]) -> Self {
2830        let mut out = Self::zero();
2831        out.v = d[0];
2832        for i in 0..K {
2833            let mut acc = L::splat(0.0);
2834            acc = acc.add(d[1].mul(self.g[i]));
2835            out.g[i] = acc;
2836        }
2837        for i in 0..K {
2838            for j in 0..K {
2839                let mut acc = L::splat(0.0);
2840                acc = acc.add(d[1].mul(self.h[i][j]));
2841                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
2842                out.h[i][j] = acc;
2843            }
2844        }
2845        for i in 0..K {
2846            for j in 0..K {
2847                for k in 0..K {
2848                    let mut acc = L::splat(0.0);
2849                    acc = acc.add(d[1].mul(self.t3[i][j][k]));
2850                    acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
                    acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
                    acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
                    acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
                    out.t3[i][j][k] = acc;
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        let mut acc = L::splat(0.0);
                        acc = acc.add(d[1].mul(self.t4[i][j][k][l]));
                        acc = acc.add(d[2].mul(self.t3[i][j][k]).mul(self.g[l]));
                        acc = acc.add(d[2].mul(self.t3[i][j][l]).mul(self.g[k]));
                        acc = acc.add(d[2].mul(self.h[i][j]).mul(self.h[k][l]));
                        acc = acc.add(d[3].mul(self.h[i][j]).mul(self.g[k]).mul(self.g[l]));
                        acc = acc.add(d[2].mul(self.t3[i][k][l]).mul(self.g[j]));
                        acc = acc.add(d[2].mul(self.h[i][k]).mul(self.h[j][l]));
                        acc = acc.add(d[3].mul(self.h[i][k]).mul(self.g[j]).mul(self.g[l]));
                        acc = acc.add(d[2].mul(self.h[i][l]).mul(self.h[j][k]));
                        acc = acc.add(d[2].mul(self.g[i]).mul(self.t3[j][k][l]));
                        acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][k]).mul(self.g[l]));
                        acc = acc.add(d[3].mul(self.h[i][l]).mul(self.g[j]).mul(self.g[k]));
                        acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][l]).mul(self.g[k]));
                        acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.h[k][l]));
                        acc = acc.add(d[4].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]).mul(self.g[l]));
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }
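/*
Why the order-4 nest above has exactly fifteen terms: each one is a set
partition of the index set {i, j, k, l}. A partition into m blocks contributes
f^(m) (`d[m]`) times one inner channel per block, a block of size r reading the
order-r channel (`g`, `h`, `t3` or `t4`). So `d[m]` appears S(4, m) times, with
S the Stirling numbers of the second kind: 1, 7, 6, 1 (for `d[2]`, four
`t3·g` terms plus three `h·h` terms); the order-3 nest likewise has 1, 3, 1.
A small sketch that counts partitions by block count; `stirling2_row` is a
hypothetical helper, not part of this module.

```rust
/// Entry m of the result is S(n, m): the number of partitions of an n-element
/// index set into m blocks, counted by enumerating restricted growth strings
/// (index 0 opens block 0; every later index joins an existing block or opens
/// the next one).
pub fn stirling2_row(n: usize) -> Vec<usize> {
    assert!(n >= 1, "need at least one index");
    let mut counts = vec![0usize; n + 1];
    // `max` is the largest block label used so far.
    fn grow(i: usize, n: usize, max: usize, counts: &mut [usize]) {
        if i == n {
            counts[max + 1] += 1; // labels 0..=max means max + 1 blocks
            return;
        }
        for label in 0..=max + 1 {
            grow(i + 1, n, max.max(label), counts);
        }
    }
    grow(1, n, 0, &mut counts);
    counts
}
```

`stirling2_row(4)` returns `[0, 1, 7, 6, 1]` (15 terms in total, the Bell
number B₄) and `stirling2_row(3)` returns `[0, 1, 3, 1]`, matching the
coefficient multiplicities of the `t4` and `t3` nests.
*/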
    /// Compose with a unary special function whose `[f64; 5]` derivative stack is
    /// built from the base value through `stack_fn`, evaluated PER LANE. This is
    /// the batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane)
    /// compose seam (the SIMD twin of [`Tower4::compose_unary_with`]).
    ///
    /// Each of the four lanes carries a DISTINCT base value, so the scalar
    /// `stack_fn` is run once per lane at that lane's own value (via
    /// [`Lane::unary_with`]) and the `[f64; 5]` results are packed into `[L; 5]`;
    /// the composition is then the existing per-lane [`Self::compose_unary`].
    /// Because `unary_with` runs the identical scalar closure per lane and
    /// `compose_unary` is a term-for-term lift of the scalar tower, lane `i` of
    /// the result is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`.
    /// That is exactly what lets a row program written against the scalar
    /// [`Tower4::compose_unary_with`] seam re-instantiate, unchanged, at `f64x4`.
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
        self.compose_unary(self.v.unary_with(stack_fn))
    }
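/*
The per-lane stack evaluation that `compose_unary_with` relies on amounts to a
transpose: run the scalar closure once per lane at that lane's own base value,
then regroup the four `[f64; 5]` stacks into five lane vectors. A minimal
sketch, assuming plain `[f64; 4]` arrays in place of `wide::f64x4`;
`unary_with_lanes` is a hypothetical stand-in, not the crate's
`Lane::unary_with`.

```rust
/// Plain-array stand-in for a 4-wide SIMD lane.
pub type Lanes4 = [f64; 4];

/// Evaluate the scalar `stack_fn` at each lane's own base value, then transpose
/// the per-lane `[[f64; 5]; 4]` results into one `Lanes4` per derivative order.
pub fn unary_with_lanes(base: Lanes4, stack_fn: impl Fn(f64) -> [f64; 5]) -> [Lanes4; 5] {
    let per_lane: [[f64; 5]; 4] = std::array::from_fn(|lane| stack_fn(base[lane]));
    std::array::from_fn(|order| std::array::from_fn(|lane| per_lane[lane][order]))
}
```

Because every lane sees the identical scalar function at its own input, lane
`i` of the packed result is bit-for-bit the scalar stack at `base[i]`; the
`to_bits` seam oracles in `batch_tests` pin the same property for the real
lane type.
*/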

    /// Single-active-slot fast path, term-for-term lift of
    /// [`Tower4::compose_unary_single_slot`]: only `v` and the four
    /// `slot`-diagonal entries (`g[s]`, `h[s][s]`, `t3[s][s][s]`,
    /// `t4[s][s][s][s]`) are written; every other entry stays zero.
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [L; 5], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        let t4 = self.t4[s][s][s][s];
        out.v = d[0];
        out.g[s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(g));
            acc
        };
        out.h[s][s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(h));
            acc = acc.add(d[2].mul(g).mul(g));
            acc
        };
        out.t3[s][s][s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(t3));
            acc = acc.add(d[2].mul(h).mul(g));
            acc = acc.add(d[2].mul(h).mul(g));
            acc = acc.add(d[2].mul(g).mul(h));
            acc = acc.add(d[3].mul(g).mul(g).mul(g));
            acc
        };
        out.t4[s][s][s][s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(t4));
            acc = acc.add(d[2].mul(t3).mul(g));
            acc = acc.add(d[2].mul(t3).mul(g));
            acc = acc.add(d[2].mul(h).mul(h));
            acc = acc.add(d[3].mul(h).mul(g).mul(g));
            acc = acc.add(d[2].mul(t3).mul(g));
            acc = acc.add(d[2].mul(h).mul(h));
            acc = acc.add(d[3].mul(h).mul(g).mul(g));
            acc = acc.add(d[2].mul(h).mul(h));
            acc = acc.add(d[2].mul(g).mul(t3));
            acc = acc.add(d[3].mul(g).mul(h).mul(g));
            acc = acc.add(d[3].mul(h).mul(g).mul(g));
            acc = acc.add(d[3].mul(g).mul(h).mul(g));
            acc = acc.add(d[3].mul(g).mul(g).mul(h));
            acc = acc.add(d[4].mul(g).mul(g).mul(g).mul(g));
            acc
        };
        out
    }
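/*
The fifteen accumulations into `t4[s][s][s][s]` above collapse algebraically to
f′·t4 + f″·(4·t3·g + 3·h²) + 6·f‴·h·g² + f⁗·g⁴. The lift nevertheless adds the
repeated terms one at a time, in the same order as the full nest at
(s, s, s, s), presumably so the fast path stays bit-identical to it; the
collapsed form agrees with it only up to rounding. A sketch of both on plain
`f64` (hypothetical free functions, not part of this module):

```rust
/// Term-for-term single-slot `t4`: the same products, accumulated left to
/// right in the same order as the lane code (starting from 0.0).
pub fn t4_term_for_term(d: &[f64; 5], g: f64, h: f64, t3: f64, t4: f64) -> f64 {
    let terms = [
        d[1] * t4,
        d[2] * t3 * g,
        d[2] * t3 * g,
        d[2] * h * h,
        d[3] * h * g * g,
        d[2] * t3 * g,
        d[2] * h * h,
        d[3] * h * g * g,
        d[2] * h * h,
        d[2] * g * t3,
        d[3] * g * h * g,
        d[3] * h * g * g,
        d[3] * g * h * g,
        d[3] * g * g * h,
        d[4] * g * g * g * g,
    ];
    terms.iter().fold(0.0, |acc, t| acc + t)
}

/// Algebraically equal collapsed form with the 4, 3 and 6 multiplicities folded in.
pub fn t4_collapsed(d: &[f64; 5], g: f64, h: f64, t3: f64, t4: f64) -> f64 {
    d[1] * t4 + d[2] * (4.0 * t3 * g + 3.0 * h * h) + 6.0 * d[3] * h * g * g + d[4] * g.powi(4)
}
```

The collapsed form does fewer multiplications, but only the term-for-term order
can be compared `to_bits` against the full nest.
*/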
    /// Contract `t3` with a primary-space direction (lift of
    /// [`Tower4::third_contracted`]). Output-symmetric in `(a, b)`: compute the
    /// upper triangle and mirror — bit-identical to the full nest, lane-for-lane.
    #[inline]
    pub fn third_contracted(&self, dir: &[L; K]) -> [[L; K]; K] {
        let mut out = [[L::splat(0.0); K]; K];
        for a in 0..K {
            for b in a..K {
                let mut acc = L::splat(0.0);
                for c in 0..K {
                    acc = acc.add(self.t3[a][b][c].mul(dir[c]));
                }
                out[a][b] = acc;
                out[b][a] = acc;
            }
        }
        out
    }
    /// Contract `t4` with two primary-space directions (lift of
    /// [`Tower4::fourth_contracted`]). Output-symmetric in `(i, j)`: compute the
    /// upper triangle and mirror — bit-identical to the full nest, lane-for-lane.
    #[inline]
    pub fn fourth_contracted(&self, u: &[L; K], w: &[L; K]) -> [[L; K]; K] {
        let mut out = [[L::splat(0.0); K]; K];
        for i in 0..K {
            for j in i..K {
                let mut acc = L::splat(0.0);
                for k in 0..K {
                    for l in 0..K {
                        acc = acc.add(self.t4[i][j][k][l].mul(u[k]).mul(w[l]));
                    }
                }
                out[i][j] = acc;
                out[j][i] = acc;
            }
        }
        out
    }
}
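/*
What the two contractions compute can be checked on a ridge function whose
derivative tensors are outer products. For ℓ(p) = (w·p)⁴/24 with s = w·p, the
tensors are t3 = s·(w⊗w⊗w) and t4 = w⊗w⊗w⊗w, so contracting `t3` with `dir`
should give s·(w·dir)·w wᵀ and contracting `t4` with `u` and `v` should give
(w·u)(w·v)·w wᵀ. A sketch on plain `f64` arrays with K fixed at 3;
`ridge_tensors` and the two free functions are hypothetical mirrors of the
lane methods that compute the full nest instead of upper triangle plus mirror.

```rust
const K: usize = 3;

fn dot(a: &[f64; K], b: &[f64; K]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// `t3` and `t4` of the hypothetical ridge NLL ℓ(p) = (w·p)⁴ / 24 at `p`.
pub fn ridge_tensors(w: &[f64; K], p: &[f64; K]) -> ([[[f64; K]; K]; K], [[[[f64; K]; K]; K]; K]) {
    let s = dot(w, p);
    let mut t3 = [[[0.0; K]; K]; K];
    let mut t4 = [[[[0.0; K]; K]; K]; K];
    for a in 0..K {
        for b in 0..K {
            for c in 0..K {
                t3[a][b][c] = s * w[a] * w[b] * w[c];
                for d in 0..K {
                    t4[a][b][c][d] = w[a] * w[b] * w[c] * w[d];
                }
            }
        }
    }
    (t3, t4)
}

/// out[a][b] = Σ_c t3[a][b][c] · dir[c]
pub fn third_contracted(t3: &[[[f64; K]; K]; K], dir: &[f64; K]) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for a in 0..K {
        for b in 0..K {
            out[a][b] = (0..K).map(|c| t3[a][b][c] * dir[c]).sum::<f64>();
        }
    }
    out
}

/// out[i][j] = Σ_{k,l} t4[i][j][k][l] · u[k] · v[l]
pub fn fourth_contracted(t4: &[[[[f64; K]; K]; K]; K], u: &[f64; K], v: &[f64; K]) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for i in 0..K {
        for j in 0..K {
            let mut acc = 0.0;
            for k in 0..K {
                for l in 0..K {
                    acc += t4[i][j][k][l] * u[k] * v[l];
                }
            }
            out[i][j] = acc;
        }
    }
    out
}
```

When `t3`/`t4` are exactly symmetric in their first two indices, the entries
the lane methods mirror are the same sums this full nest computes, so upper
triangle plus mirror halves the work without changing the result.
*/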

/// Lane-batched [`Tower3`] (order-≤3 sibling of [`Tower4Lane`]).
#[derive(Clone, Copy)]
pub struct Tower3Lane<L: Lane, const K: usize> {
    /// Value channel.
    pub v: L,
    /// Gradient.
    pub g: [L; K],
    /// Hessian.
    pub h: [[L; K]; K],
    /// Third tensor.
    pub t3: [[[L; K]; K]; K],
}

/// The 4-rows-per-pass batched [`Tower3`] (`wide::f64x4` lanes).
pub type Tower3Batch<const K: usize> = Tower3Lane<wide::f64x4, K>;

impl<L: Lane, const K: usize> Tower3Lane<L, K> {
    /// All-zero tower.
    #[inline]
    pub fn zero() -> Self {
        let z = L::splat(0.0);
        Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K] }
    }
    /// Constant `c` (per lane).
    #[inline]
    pub fn constant(c: L) -> Self {
        let mut o = Self::zero();
        o.v = c;
        o
    }
    /// Seeded variable `p_idx` at per-lane `value`.
    #[inline]
    pub fn variable(value: L, idx: usize) -> Self {
        let mut o = Self::constant(value);
        o.g[idx] = L::splat(1.0);
        o
    }
    /// Extract lane `i` as a scalar [`Tower3<K>`].
    #[inline]
    pub fn lane(&self, i: usize) -> Tower3<K> {
        let mut out = Tower3::<K>::zero();
        out.v = self.v.lane(i);
        for a in 0..K {
            out.g[a] = self.g[a].lane(i);
            for b in 0..K {
                out.h[a][b] = self.h[a][b].lane(i);
                for c in 0..K {
                    out.t3[a][b][c] = self.t3[a][b][c].lane(i);
                }
            }
        }
        out
    }
    /// Per-channel lane-wise `self + o`.
    #[inline]
    pub fn add(&self, o: &Self) -> Self {
        let mut out = *self;
        out.v = self.v.add(o.v);
        for i in 0..K {
            out.g[i] = self.g[i].add(o.g[i]);
            for j in 0..K {
                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
                for k in 0..K {
                    out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
                }
            }
        }
        out
    }
    /// Per-channel lane-wise `self - o`.
    #[inline]
    pub fn sub(&self, o: &Self) -> Self {
        let mut out = *self;
        out.v = self.v.sub(o.v);
        for i in 0..K {
            out.g[i] = self.g[i].sub(o.g[i]);
            for j in 0..K {
                out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
                for k in 0..K {
                    out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
                }
            }
        }
        out
    }
    /// Multiply every channel by the plain scalar `s` (mirrors `Tower3::scale`).
    #[inline]
    pub fn scale(&self, s: f64) -> Self {
        let sl = L::splat(s);
        let mut out = *self;
        out.v = self.v.mul(sl);
        for i in 0..K {
            out.g[i] = self.g[i].mul(sl);
            for j in 0..K {
                out.h[i][j] = self.h[i][j].mul(sl);
                for k in 0..K {
                    out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
                }
            }
        }
        out
    }
    /// Leibniz product `self · o`, term-for-term lift of [`Tower3::mul`].
    #[inline]
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v.mul(b.v);
        for i in 0..K {
            let mut acc = L::splat(0.0);
            acc = acc.add(a.v.mul(b.g[i]));
            acc = acc.add(a.g[i].mul(b.v));
            out.g[i] = acc;
        }
        // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
        for i in 0..K {
            for j in i..K {
                let mut acc = L::splat(0.0);
                acc = acc.add(a.v.mul(b.h[i][j]));
                acc = acc.add(a.g[i].mul(b.g[j]));
                acc = acc.add(a.g[j].mul(b.g[i]));
                acc = acc.add(a.h[i][j].mul(b.v));
                out.h[i][j] = acc;
                out.h[j][i] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let mut acc = L::splat(0.0);
                    acc = acc.add(a.v.mul(b.t3[i][j][k]));
                    acc = acc.add(a.g[i].mul(b.h[j][k]));
                    acc = acc.add(a.g[j].mul(b.h[i][k]));
                    acc = acc.add(a.h[i][j].mul(b.g[k]));
                    acc = acc.add(a.g[k].mul(b.h[i][j]));
                    acc = acc.add(a.h[i][k].mul(b.g[j]));
                    acc = acc.add(a.h[j][k].mul(b.g[i]));
                    acc = acc.add(a.t3[i][j][k].mul(b.v));
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }
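/*
At K = 1 the eight-term `t3` nest above collapses to the univariate Leibniz
rule (ab)‴ = a‴b + 3a″b′ + 3a′b″ + ab‴: the three `a.g · b.h` terms coincide,
and so do the three `a.h · b.g` terms. A small check on a(x) = x² and
b(x) = sin x, whose product has the closed-form derivatives
(x² sin x)″ = 2 sin x + 4x cos x − x² sin x and
(x² sin x)‴ = 6 cos x − 6x sin x − x² cos x. `Jet3` and `mul3` are
hypothetical plain-`f64` stand-ins, not this module's types.

```rust
/// Hypothetical univariate order-3 jet `[v, v′, v″, v‴]`.
pub type Jet3 = [f64; 4];

/// Leibniz product through order 3 (the K = 1 collapse of the lane `mul`).
pub fn mul3(a: &Jet3, b: &Jet3) -> Jet3 {
    [
        a[0] * b[0],
        a[0] * b[1] + a[1] * b[0],
        a[0] * b[2] + 2.0 * a[1] * b[1] + a[2] * b[0],
        a[0] * b[3] + 3.0 * a[1] * b[2] + 3.0 * a[2] * b[1] + a[3] * b[0],
    ]
}
```
*/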
    /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
    /// [`Tower3::compose_unary`]. `d = [f, f′, f″, f‴]` packed per lane.
    #[inline]
    pub fn compose_unary(&self, d: [L; 4]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(self.g[i]));
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = L::splat(0.0);
                acc = acc.add(d[1].mul(self.h[i][j]));
                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let mut acc = L::splat(0.0);
                    acc = acc.add(d[1].mul(self.t3[i][j][k]));
                    acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
                    acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
                    acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
                    acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }
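/*
A K = 2 check of the order-≤3 composition rule above, on a mixed partial: for
f = exp and inner u(x, y) = x·y (so g = [y, x], h = [[0, 1], [1, 0]], t3 = 0),
the formula must reproduce ∂²e^{xy}/∂x∂y = e^{xy}(1 + xy) and
∂³e^{xy}/∂x²∂y = e^{xy}(2y + xy²). A plain-`f64` sketch; `T3`, `compose3` and
`product_xy` are hypothetical mirrors of the lane code, not this module's types.

```rust
const K: usize = 2;

/// Hypothetical plain-`f64` order-3 tower in K primaries (same channel layout).
#[derive(Clone, Copy, Debug)]
pub struct T3 {
    pub v: f64,
    pub g: [f64; K],
    pub h: [[f64; K]; K],
    pub t3: [[[f64; K]; K]; K],
}

/// Faà di Bruno composition f∘u through order 3, `d = [f, f′, f″, f‴]` at `u.v`.
pub fn compose3(u: &T3, d: [f64; 4]) -> T3 {
    let mut out = T3 { v: d[0], g: [0.0; K], h: [[0.0; K]; K], t3: [[[0.0; K]; K]; K] };
    for i in 0..K {
        out.g[i] = d[1] * u.g[i];
        for j in 0..K {
            out.h[i][j] = d[1] * u.h[i][j] + d[2] * u.g[i] * u.g[j];
            for k in 0..K {
                out.t3[i][j][k] = d[1] * u.t3[i][j][k]
                    + d[2] * (u.h[i][j] * u.g[k] + u.h[i][k] * u.g[j] + u.g[i] * u.h[j][k])
                    + d[3] * u.g[i] * u.g[j] * u.g[k];
            }
        }
    }
    out
}

/// Inner tower of u(x, y) = x·y at the point (x, y).
pub fn product_xy(x: f64, y: f64) -> T3 {
    T3 { v: x * y, g: [y, x], h: [[0.0, 1.0], [1.0, 0.0]], t3: [[[0.0; K]; K]; K] }
}
```

Since exp is its own derivative, the stack is simply `[e; 4]` with e = exp(x·y).
*/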
    /// Compose with a unary special function whose `[f64; 4]` derivative stack is
    /// built from the base value through `stack_fn`, evaluated PER LANE. This is
    /// the batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane)
    /// compose seam (the SIMD twin of [`Tower3::compose_unary_with`], order-≤3
    /// sibling of [`Tower4Lane::compose_unary_with`]). The scalar `stack_fn` is
    /// run once per lane at that lane's own base value (via [`Lane::unary_with`])
    /// and packed into `[L; 4]` for the existing per-lane [`Self::compose_unary`],
    /// so lane `i` is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`.
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
        self.compose_unary(self.v.unary_with(stack_fn))
    }

    /// Single-active-slot fast path, term-for-term lift of
    /// [`Tower3::compose_unary_single_slot`]: only `v` and the `slot`-diagonal
    /// entries `g[s]`, `h[s][s]`, `t3[s][s][s]` are written.
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [L; 4], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        out.v = d[0];
        out.g[s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(g));
            acc
        };
        out.h[s][s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(h));
            acc = acc.add(d[2].mul(g).mul(g));
            acc
        };
        out.t3[s][s][s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(t3));
            acc = acc.add(d[2].mul(h).mul(g));
            acc = acc.add(d[2].mul(h).mul(g));
            acc = acc.add(d[2].mul(g).mul(h));
            acc = acc.add(d[3].mul(g).mul(g).mul(g));
            acc
        };
        out
    }
}

#[cfg(test)]
mod batch_tests {
    use super::*;

    // Deterministic 64-bit LCG (Knuth's MMIX multiplier and increment); `f`
    // maps the top 53 bits of the state to a uniform draw in [-2, 2).
    struct Rng(u64);
    impl Rng {
        fn f(&mut self) -> f64 {
            self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            ((self.0 >> 11) as f64 / (1u64 << 53) as f64) * 4.0 - 2.0
        }
    }

    // Fill every channel of a scalar Tower4<K> with random data.
    fn rand_t4<const K: usize>(r: &mut Rng) -> Tower4<K> {
        let mut t = Tower4::<K>::zero();
        t.v = r.f();
        for i in 0..K {
            t.g[i] = r.f();
            for j in 0..K {
                t.h[i][j] = r.f();
                for k in 0..K {
                    t.t3[i][j][k] = r.f();
                    for l in 0..K {
                        t.t4[i][j][k][l] = r.f();
                    }
                }
            }
        }
        t
    }
    // Same, for the order-≤3 Tower3<K>.
    fn rand_t3<const K: usize>(r: &mut Rng) -> Tower3<K> {
        let mut t = Tower3::<K>::zero();
        t.v = r.f();
        for i in 0..K {
            t.g[i] = r.f();
            for j in 0..K {
                t.h[i][j] = r.f();
                for k in 0..K {
                    t.t3[i][j][k] = r.f();
                }
            }
        }
        t
    }
    // Pack four scalar rows into one f64x4 batch: row `r` becomes lane `r` of
    // every channel.
    fn pack4_t4<const K: usize>(rows: &[Tower4<K>; 4]) -> Tower4Batch<K> {
        let mut b = Tower4Batch::<K>::zero();
        let lane = |f: &dyn Fn(&Tower4<K>) -> f64| {
            wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
        };
        b.v = lane(&|t| t.v);
        for i in 0..K {
            b.g[i] = lane(&|t| t.g[i]);
            for j in 0..K {
                b.h[i][j] = lane(&|t| t.h[i][j]);
                for k in 0..K {
                    b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
                    for l in 0..K {
                        b.t4[i][j][k][l] = lane(&|t| t.t4[i][j][k][l]);
                    }
                }
            }
        }
        b
    }
    fn pack4_t3<const K: usize>(rows: &[Tower3<K>; 4]) -> Tower3Batch<K> {
        let mut b = Tower3Batch::<K>::zero();
        let lane = |f: &dyn Fn(&Tower3<K>) -> f64| {
            wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
        };
        b.v = lane(&|t| t.v);
        for i in 0..K {
            b.g[i] = lane(&|t| t.g[i]);
            for j in 0..K {
                b.h[i][j] = lane(&|t| t.h[i][j]);
                for k in 0..K {
                    b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
                }
            }
        }
        b
    }
    // Assert bitwise (`to_bits`) equality of every channel.
    fn assert_t4_eq<const K: usize>(b: &Tower4<K>, s: &Tower4<K>, ctx: &str) {
        assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
        for i in 0..K {
            assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
            for j in 0..K {
                assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
                for k in 0..K {
                    assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
                    for l in 0..K {
                        assert_eq!(b.t4[i][j][k][l].to_bits(), s.t4[i][j][k][l].to_bits(), "t4 {ctx}");
                    }
                }
            }
        }
    }
    fn assert_t3_eq<const K: usize>(b: &Tower3<K>, s: &Tower3<K>, ctx: &str) {
        assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
        for i in 0..K {
            assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
            for j in 0..K {
                assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
                for k in 0..K {
                    assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
                }
            }
        }
    }

    // Run a representative op chain on 4 scalar rows and on the f64x4 batch,
    // then assert every channel of every lane is to_bits-identical.
    fn run4<const K: usize>(seed: u64, batches: usize) -> usize {
        let mut r = Rng(seed);
        let mut rows_checked = 0;
        for _ in 0..batches {
            let a: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
            let b: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
            let d: [[f64; 5]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
            let dir: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
            let dir2: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
            let s = r.f();

            // scalar per-row reference
            let scal: [Tower4<K>; 4] = std::array::from_fn(|rw| {
                let prod = a[rw].mul(&b[rw]);
                let comp = prod.compose_unary(d[rw]);
                let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
                summed.compose_unary_single_slot(d[rw], 0)
            });
            let third: [[[f64; K]; K]; 4] =
                std::array::from_fn(|rw| a[rw].third_contracted(&dir[rw]));
            let fourth: [[[f64; K]; K]; 4] =
                std::array::from_fn(|rw| a[rw].fourth_contracted(&dir[rw], &dir2[rw]));

            // batched f64x4
            let ab = pack4_t4(&a);
            let bb = pack4_t4(&b);
            let db: [wide::f64x4; 5] = std::array::from_fn(|c| {
                wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
            });
            let dirb: [wide::f64x4; K] = std::array::from_fn(|c| {
                wide::f64x4::new([dir[0][c], dir[1][c], dir[2][c], dir[3][c]])
            });
            let dir2b: [wide::f64x4; K] = std::array::from_fn(|c| {
                wide::f64x4::new([dir2[0][c], dir2[1][c], dir2[2][c], dir2[3][c]])
            });
            let prodb = ab.mul(&bb);
            let compb = prodb.compose_unary(db);
            let summedb = compb.add(&ab).sub(&bb).scale(s);
            let finalb = summedb.compose_unary_single_slot(db, 0);
            let thirdb = ab.third_contracted(&dirb);
            let fourthb = ab.fourth_contracted(&dirb, &dir2b);

            for rw in 0..4 {
                assert_t4_eq(&finalb.lane(rw), &scal[rw], "t4-chain");
                for i in 0..K {
                    for j in 0..K {
                        assert_eq!(thirdb[i][j].lane(rw).to_bits(), third[rw][i][j].to_bits(), "third");
                        assert_eq!(fourthb[i][j].lane(rw).to_bits(), fourth[rw][i][j].to_bits(), "fourth");
                    }
                }
                rows_checked += 1;
            }
        }
        rows_checked
    }
    fn run3<const K: usize>(seed: u64, batches: usize) -> usize {
        let mut r = Rng(seed);
        let mut rows_checked = 0;
        for _ in 0..batches {
            let a: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
            let b: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
            let d: [[f64; 4]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
            let s = r.f();
            let scal: [Tower3<K>; 4] = std::array::from_fn(|rw| {
                let prod = a[rw].mul(&b[rw]);
                let comp = prod.compose_unary(d[rw]);
                let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
                summed.compose_unary_single_slot(d[rw], 0)
            });
            let ab = pack4_t3(&a);
            let bb = pack4_t3(&b);
            let db: [wide::f64x4; 4] = std::array::from_fn(|c| {
                wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
            });
            let prodb = ab.mul(&bb);
            let compb = prodb.compose_unary(db);
            let summedb = compb.add(&ab).sub(&bb).scale(s);
            let finalb = summedb.compose_unary_single_slot(db, 0);
            for rw in 0..4 {
                assert_t3_eq(&finalb.lane(rw), &scal[rw], "t3-chain");
                rows_checked += 1;
            }
        }
        rows_checked
    }

    // A `Tower4Batch<9>` carries a `9⁴ = 6561`-entry `t4` tensor of 4-wide
    // `f64` lanes (6561 × 32 B ≈ 205 KiB by value); the op chain keeps several
    // live, which can exceed a test thread's default stack. Run each width on
    // a large-stack (512 MiB) thread so K=9 is exercised without a stack
    // overflow.
    fn big_stack<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(f: F) -> R {
        std::thread::Builder::new()
            .stack_size(512 << 20)
            .spawn(f)
            .unwrap()
            .join()
            .unwrap()
    }
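/*
The stack estimate in the comment above can be reproduced by counting channel
entries: an order-n tower in K primaries holds 1 + K + K² + … + Kⁿ entries,
each a lane of four `f64`s. A sketch with a hypothetical helper; it ignores
padding and assumes a lane is exactly four packed `f64`s (32 bytes).

```rust
/// Bytes held by value by an order-`order` tower in `k` primaries whose
/// entries are `lanes`-wide `f64` vectors (no padding assumed).
pub fn tower_bytes(k: usize, order: u32, lanes: usize) -> usize {
    let entries: usize = (0..=order).map(|r| k.pow(r)).sum();
    entries * lanes * std::mem::size_of::<f64>()
}
```

Under those assumptions a `Tower4Batch<9>` is 236,192 bytes (≈ 231 KiB), of
which the `t4` tensor is 209,952 bytes (≈ 205 KiB), and a `Tower3Batch<9>` is
26,240 bytes; several such values live at once can approach a default thread
stack of a few MiB.
*/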

    #[test]
    fn tower4_batch_lane_bit_identical() {
        let batches = 2000;
        let rows_checked = big_stack(move || run4::<2>(0x1111_2222_3333_4444, batches))
            + big_stack(move || run4::<3>(0x5555_6666_7777_8888, batches))
            + big_stack(move || run4::<4>(0x9999_aaaa_bbbb_cccc, batches))
            + big_stack(move || run4::<9>(0xdddd_eeee_ffff_0000, batches));
        // 4 widths × `batches` batches × 4 rows each: guards the large-stack
        // worker threads against silently running zero comparisons.
        assert_eq!(rows_checked, 4 * batches * 4);
    }

    #[test]
    fn tower3_batch_lane_bit_identical() {
        let batches = 2000;
        let rows_checked = big_stack(move || run3::<2>(0x0f0f_1e1e_2d2d_3c3c, batches))
            + big_stack(move || run3::<3>(0x4b4b_5a5a_6969_7878, batches))
            + big_stack(move || run3::<4>(0x8787_9696_a5a5_b4b4, batches))
            + big_stack(move || run3::<9>(0xc3c3_d2d2_e1e1_f0f0, batches));
        // 4 widths × `batches` batches × 4 rows each: guards the large-stack
        // worker threads against silently running zero comparisons.
        assert_eq!(rows_checked, 4 * batches * 4);
    }

    // ── compose_unary_with seam (generic-over-Lane compose) ─────────────────
    //
    // The seam lets a single-sourced row program build its special-function
    // STACK from the base value through a closure, so the SAME expression
    // instantiates at a scalar tower (one base) AND a batch tower (four distinct
    // per-lane bases). These oracles pin both arms `to_bits`.

    /// A base-value-dependent `[f64; 5]` stack (finite for finite `u`), standing
    /// in for a family's hand-certified special-function stack. Its entries need
    /// not be mutually consistent derivatives: the seam oracles only compare two
    /// routes to the same composition bit for bit. `seam_stack4` is its order-≤3
    /// truncation.
    fn seam_stack5(u: f64) -> [f64; 5] {
        [u.sin(), u.cos(), (2.0 * u).sin(), (0.5 * u).cos(), u * u - 0.3]
    }
    fn seam_stack4(u: f64) -> [f64; 4] {
        let s = seam_stack5(u);
        [s[0], s[1], s[2], s[3]]
    }

    /// Base value for case `which`: -0.0, +0.0, a draw in [-2, 2), or a draw
    /// shifted into [1, 5). The signed zeros are compared with `to_bits`, which
    /// (unlike `==`) tells them apart.
    fn seam_edge_base(r: &mut Rng, which: usize) -> f64 {
        match which {
            0 => -0.0,
            1 => 0.0,
            2 => r.f(),
            _ => r.f() + 3.0,
        }
    }
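/*
The signed-zero bases matter because the oracles compare with `to_bits`, not
`==`: IEEE 754 makes -0.0 == 0.0 true even though the bit patterns differ, and
a stack entry such as 1/u maps the two to -∞ and +∞. Conversely, a NaN never
equals itself under `==` but does under `to_bits` when the payload is
identical. A minimal sketch of that distinction; `bit_eq` and `recip_stack` are
hypothetical helpers, not part of this module.

```rust
/// Bitwise equality, as used by the `to_bits` oracles.
pub fn bit_eq(a: f64, b: f64) -> bool {
    a.to_bits() == b.to_bits()
}

/// A stack whose first entry depends on the sign of a zero base:
/// IEEE 754 division gives 1/(+0.0) = +inf and 1/(-0.0) = -inf.
pub fn recip_stack(u: f64) -> [f64; 2] {
    [1.0 / u, -1.0 / (u * u)]
}
```
*/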

    /// (a) scalar arm: `Tower4::compose_unary_with(f)` is `to_bits`-identical to
    /// the explicit `compose_unary(f(value))` on every channel.
    fn scalar_seam_t4<const K: usize>(seed: u64, n: usize) -> usize {
        let mut r = Rng(seed);
        for _ in 0..n {
            let mut t = rand_t4::<K>(&mut r);
            t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
            assert_t4_eq(
                &t.compose_unary_with(seam_stack5),
                &t.compose_unary(seam_stack5(t.v)),
                "scalar t4 seam",
            );
        }
        n
    }
    fn scalar_seam_t3<const K: usize>(seed: u64, n: usize) -> usize {
        let mut r = Rng(seed);
        for _ in 0..n {
            let mut t = rand_t3::<K>(&mut r);
            t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
            assert_t3_eq(
                &t.compose_unary_with(seam_stack4),
                &t.compose_unary(seam_stack4(t.v)),
                "scalar t3 seam",
            );
        }
        n
    }

    /// (b) lane arm: `Tower4Lane::compose_unary_with` lane `i` is
    /// `to_bits`-identical to the scalar `Tower4::compose_unary_with` on row `i`,
    /// with the four lanes carrying DISTINCT base values (signed zeros included),
    /// so a buggy impl reusing one lane's base would fail.
    fn lane_seam_t4<const K: usize>(seed: u64, batches: usize) -> usize {
        let mut r = Rng(seed);
        let mut verified = 0usize;
        for _ in 0..batches {
            let mut rows: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
            for (rw, row) in rows.iter_mut().enumerate() {
                row.v = seam_edge_base(&mut r, rw);
            }
            let batch_out = pack4_t4(&rows).compose_unary_with(seam_stack5);
            for (rw, row) in rows.iter().enumerate() {
                assert_t4_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack5), "lane t4 seam");
                verified += 1;
            }
        }
        verified
    }
    fn lane_seam_t3<const K: usize>(seed: u64, batches: usize) -> usize {
        let mut r = Rng(seed);
        let mut verified = 0usize;
        for _ in 0..batches {
            let mut rows: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
            for (rw, row) in rows.iter_mut().enumerate() {
                row.v = seam_edge_base(&mut r, rw);
            }
            let batch_out = pack4_t3(&rows).compose_unary_with(seam_stack4);
            for (rw, row) in rows.iter().enumerate() {
                assert_t3_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack4), "lane t3 seam");
                verified += 1;
            }
        }
        verified
    }

    #[test]
    fn compose_unary_with_scalar_bit_identical() {
        let n = 1100;
        let total = scalar_seam_t4::<2>(0x2200_0001, n)
            + scalar_seam_t4::<3>(0x2200_0002, n)
            + scalar_seam_t4::<4>(0x2200_0003, n)
            + big_stack(move || scalar_seam_t4::<9>(0x2200_0004, n))
            + scalar_seam_t3::<2>(0x3300_0001, n)
            + scalar_seam_t3::<3>(0x3300_0002, n)
            + scalar_seam_t3::<4>(0x3300_0003, n)
            + big_stack(move || scalar_seam_t3::<9>(0x3300_0004, n));
        // 8 arms × 1100 = 8800 ≥ 4000 inputs.
        assert_eq!(total, 8 * n);
    }

    #[test]
    fn compose_unary_with_lane_matches_scalar() {
        let b = 600;
        let total = lane_seam_t4::<2>(0x4400_0001, b)
            + lane_seam_t4::<3>(0x4400_0002, b)
            + lane_seam_t4::<4>(0x4400_0003, b)
            + big_stack(move || lane_seam_t4::<9>(0x4400_0004, b))
            + lane_seam_t3::<2>(0x5500_0001, b)
            + lane_seam_t3::<3>(0x5500_0002, b)
            + lane_seam_t3::<4>(0x5500_0003, b)
            + big_stack(move || lane_seam_t3::<9>(0x5500_0004, b));
        // 8 arms × 600 = 4800 batches ≥ 2000; each verifies 4 lanes (19200 checks).
        assert_eq!(total, 8 * b * 4);
    }
}
3580
3581#[cfg(test)]
3582mod tests {
3583    use super::*;
3584
    /// `Tower3<K>` must be bit-identical to `Tower4<K>` on every channel it
    /// carries (value, gradient, Hessian, third derivatives). The order-≤3
    /// Leibniz / Faà-di-Bruno terms read only order-≤3 inner channels, so
    /// dropping the fourth tensor cannot perturb them. Exercises products
    /// (Leibniz cross-terms), unary composition, scaling, and addition — the
    /// same operations the survival location-scale `nll_index_tower` composes —
    /// across all mixed partials, not just the diagonal entries that kernel reads.
    #[test]
    fn tower3_matches_tower4_through_third_order() {
        // Outer derivative stacks [f, f′, f″, f‴, f⁗]. Only bit-identity between
        // the two towers is asserted, so the stacks need not correspond to the
        // inner values they are composed at. `s4` keeps the first four entries,
        // since Tower3 has no fourth-order channel.
        let s_a: [f64; 5] = [
            0.3_f64.sin(),
            0.3_f64.cos(),
            -0.3_f64.sin(),
            -0.3_f64.cos(),
            0.3_f64.sin(),
        ];
        let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
        let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];

        let a4 = Tower4::<3>::variable(0.4, 0);
        let b4 = Tower4::<3>::variable(-0.7, 1);
        let c4 = Tower4::<3>::variable(0.9, 2);
        let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
            + a4.mul(&c4).scale(-0.7)
            + b4.compose_unary(s_b).scale(0.25);

        let a3 = Tower3::<3>::variable(0.4, 0);
        let b3 = Tower3::<3>::variable(-0.7, 1);
        let c3 = Tower3::<3>::variable(0.9, 2);
        let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
            + a3.mul(&c3).scale(-0.7)
            + b3.compose_unary(s4(s_b)).scale(0.25);

        assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
        for i in 0..3 {
            assert_eq!(
                prog3.g[i].to_bits(),
                prog4.g[i].to_bits(),
                "g[{i}] mismatch"
            );
            for j in 0..3 {
                assert_eq!(
                    prog3.h[i][j].to_bits(),
                    prog4.h[i][j].to_bits(),
                    "h[{i}][{j}] mismatch"
                );
                for k in 0..3 {
                    assert_eq!(
                        prog3.t3[i][j][k].to_bits(),
                        prog4.t3[i][j][k].to_bits(),
                        "t3[{i}][{j}][{k}] mismatch"
                    );
                }
            }
        }
    }

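// A univariate sketch of the truncation argument in the doc comment above.
// The hypothetical `compose_order3` / `compose_order4` helpers (not the
// crate's `Tower3`/`Tower4` internals) apply the Faà di Bruno formula at two
// truncation orders. The order-≤3 expressions are written identically and
// none of them reads g⁗ or f⁗; since rustc does not fuse float operations on
// its own, the shared channels come out bit for bit the same.

```rust
/// Order-3 Faà di Bruno for f∘g. `g` = [g, g′, g″, g‴] of the inner function,
/// `f` = [f, f′, f″, f‴] of the outer function evaluated at g[0].
fn compose_order3(g: [f64; 4], f: [f64; 4]) -> [f64; 4] {
    let [_, g1, g2, g3] = g;
    [
        f[0],
        f[1] * g1,
        f[2] * g1 * g1 + f[1] * g2,
        f[3] * g1 * g1 * g1 + 3.0 * f[2] * g1 * g2 + f[1] * g3,
    ]
}

/// Order-4 Faà di Bruno; channels 0..=3 are the same expressions as above.
fn compose_order4(g: [f64; 5], f: [f64; 5]) -> [f64; 5] {
    let [_, g1, g2, g3, g4] = g;
    [
        f[0],
        f[1] * g1,
        f[2] * g1 * g1 + f[1] * g2,
        f[3] * g1 * g1 * g1 + 3.0 * f[2] * g1 * g2 + f[1] * g3,
        // Only this channel reads g⁗ and f⁗.
        f[4] * g1 * g1 * g1 * g1
            + 6.0 * f[3] * g1 * g1 * g2
            + f[2] * (3.0 * g2 * g2 + 4.0 * g1 * g3)
            + f[1] * g4,
    ]
}

/// True when the order-3 result equals the leading four channels of the
/// order-4 result bit for bit.
fn truncation_is_bit_identical(g: [f64; 5], f: [f64; 5]) -> bool {
    let lo = compose_order3([g[0], g[1], g[2], g[3]], [f[0], f[1], f[2], f[3]]);
    let hi = compose_order4(g, f);
    (0..4).all(|k| lo[k].to_bits() == hi[k].to_bits())
}
```

// Changing only g⁗ and f⁗ moves the fourth channel and nothing else, which is
// the property the test asserts on the full multivariate towers.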
    /// Binomial-logit row NLL, K=1: ℓ(η) = ln(1 + e^η) − y·η.
    /// The entire tower has textbook closed forms in μ = σ(η); this test
    /// pins the algebra (exp, ln, scalar mixes, Leibniz/Faà di Bruno) to
    /// analytic truth at near-machine precision.
    struct LogitProgram {
        eta: Vec<f64>,
        y: Vec<f64>,
    }

    impl RowNllProgram<1> for LogitProgram {
        fn n_rows(&self) -> usize {
            self.eta.len()
        }
        fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
            Ok([self.eta[row]])
        }
        fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
            let eta = p[0];
            Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
        }
    }

    #[test]
    fn logit_tower_matches_closed_forms() {
        let prog = LogitProgram {
            eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
            y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
        };
        for row in 0..prog.n_rows() {
            let t = evaluate_program(&prog, row).expect("logit program");
            let eta = prog.eta[row];
            let y = prog.y[row];
            let mu = 1.0 / (1.0 + (-eta).exp());
            let w = mu * (1.0 - mu);
            let expect = [
                (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
                (t.g[0], mu - y, "grad"),
                (t.h[0][0], w, "hess"),
                (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
                (
                    t.t4[0][0][0][0],
                    w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
                    "fourth",
                ),
            ];
            for (got, want, label) in expect {
                assert!(
                    (got - want).abs() <= 1e-12 * want.abs().max(1.0),
                    "row {row} {label}: got {got:+.15e} want {want:+.15e}"
                );
            }
        }
    }

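// A univariate sketch of the algebra the logit test pins. `Jet4` is a
// hypothetical five-channel jet (value plus four derivatives), not the
// crate's multivariate `Tower4<K>`; `exp` and `ln` each go through one
// fourth-order Faà di Bruno step, and `logit_nll` writes ℓ(η) exactly once.

```rust
/// Univariate truncated Taylor jet: [v, v′, v″, v‴, v⁗] (derivatives, not
/// Taylor coefficients).
#[derive(Clone, Copy, Debug)]
struct Jet4([f64; 5]);

impl Jet4 {
    fn variable(x: f64) -> Self {
        Jet4([x, 1.0, 0.0, 0.0, 0.0])
    }
    fn add(self, o: Jet4) -> Self {
        let mut r = self.0;
        for k in 0..5 {
            r[k] += o.0[k];
        }
        Jet4(r)
    }
    fn add_scalar(self, c: f64) -> Self {
        let mut r = self.0;
        r[0] += c;
        Jet4(r)
    }
    fn scale(self, s: f64) -> Self {
        Jet4(self.0.map(|x| x * s))
    }
    /// Faà di Bruno through fourth order, given the outer derivative stack
    /// [f, f′, f″, f‴, f⁗] evaluated at the inner value.
    fn compose(self, f: [f64; 5]) -> Self {
        let [_, g1, g2, g3, g4] = self.0;
        Jet4([
            f[0],
            f[1] * g1,
            f[2] * g1 * g1 + f[1] * g2,
            f[3] * g1.powi(3) + 3.0 * f[2] * g1 * g2 + f[1] * g3,
            f[4] * g1.powi(4)
                + 6.0 * f[3] * g1 * g1 * g2
                + f[2] * (3.0 * g2 * g2 + 4.0 * g1 * g3)
                + f[1] * g4,
        ])
    }
    fn exp(self) -> Self {
        let e = self.0[0].exp();
        self.compose([e; 5]) // every derivative of exp is exp
    }
    fn ln(self) -> Self {
        let x = self.0[0];
        self.compose([x.ln(), 1.0 / x, -1.0 / (x * x), 2.0 / x.powi(3), -6.0 / x.powi(4)])
    }
}

/// ℓ(η) = ln(1 + e^η) − y·η, written once; every derivative falls out.
fn logit_nll(eta: f64, y: f64) -> Jet4 {
    let e = Jet4::variable(eta);
    e.exp().add_scalar(1.0).ln().add(e.scale(-y))
}
```

// Each channel then matches μ − y, w, w(1 − 2μ) and w(1 − 6μ + 6μ²), with
// w = μ(1 − μ), without any of those formulas being coded by hand.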
    fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
        let diff = (got - want).abs();
        assert!(
            diff <= rel_tol * want.abs().max(1.0),
            "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
        );
    }

    /// ψ and ψ′ references: x = 0.5 and x = 1 use exact closed forms, x = 2.5
    /// follows from x = 0.5 by two recurrence steps, and x = 0.1 and x = 50
    /// are hard-coded reference values. Each reference is checked on every
    /// stack that carries it (the ln Γ, ψ and ψ′ stacks).
    #[test]
    fn gamma_special_function_stacks_match_reference_values() {
        const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
        let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
        let cases = [
            (
                "x=0.1",
                0.1,
                -10.423_754_940_411_076,
                101.433_299_150_792_75,
            ),
            (
                "x=0.5",
                0.5,
                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
                pi_sq / 2.0,
            ),
            ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
            (
                "x=2.5",
                2.5,
                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
                pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
            ),
            (
                "x=50",
                50.0,
                3.901_989_673_427_892,
                0.020_201_333_226_697_128,
            ),
        ];

        for (label, x, digamma_ref, trigamma_ref) in cases {
            let ln_gamma_stack = ln_gamma_derivative_stack(x);
            let digamma_stack = digamma_derivative_stack(x);
            let trigamma_stack = trigamma_derivative_stack(x);
            assert_close(
                &format!("{label} ln_gamma_stack digamma"),
                ln_gamma_stack[1],
                digamma_ref,
                1e-13,
            );
            assert_close(
                &format!("{label} digamma value"),
                digamma_stack[0],
                digamma_ref,
                1e-13,
            );
            assert_close(
                &format!("{label} ln_gamma_stack trigamma"),
                ln_gamma_stack[2],
                trigamma_ref,
                1e-13,
            );
            assert_close(
                &format!("{label} digamma_stack trigamma"),
                digamma_stack[1],
                trigamma_ref,
                1e-13,
            );
            assert_close(
                &format!("{label} trigamma value"),
                trigamma_stack[0],
                trigamma_ref,
                1e-13,
            );
        }
    }

    /// Recurrences ψ(x + 1) = ψ(x) + 1/x and ψ′(x + 1) = ψ′(x) − 1/x² on the
    /// same grid as the reference-value test.
    #[test]
    fn gamma_special_function_stacks_obey_recurrences() {
        for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
            let digamma_x = digamma_derivative_stack(x)[0];
            let digamma_next = digamma_derivative_stack(x + 1.0)[0];
            let trigamma_x = trigamma_derivative_stack(x)[0];
            let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
            assert_close(
                &format!("digamma recurrence x={x}"),
                digamma_next,
                digamma_x + 1.0 / x,
                1e-13,
            );
            assert_close(
                &format!("trigamma recurrence x={x}"),
                trigamma_next,
                trigamma_x - 1.0 / (x * x),
                1e-13,
            );
        }
    }

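// One common way to produce the ψ and ψ′ values pinned above: shift x upward
// with the recurrences just tested, then finish with the asymptotic
// (Bernoulli-number) series. This sketches that technique only; it is an
// assumption, not a description of how `digamma_derivative_stack` or
// `trigamma_derivative_stack` are implemented. The shift threshold of 20 is
// also an assumption, chosen so the first omitted series term is far below
// the 1e-13 tolerances used in the tests.

```rust
/// ψ(x) for x > 0: recurrence ψ(x) = ψ(x + 1) − 1/x up to x ≥ 20, then
/// ψ(x) ≈ ln x − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶) + 1/(240x⁸) − 1/(132x¹⁰).
fn digamma(mut x: f64) -> f64 {
    let mut acc = 0.0;
    while x < 20.0 {
        acc -= 1.0 / x;
        x += 1.0;
    }
    let x2 = 1.0 / (x * x);
    acc + x.ln()
        - 0.5 / x
        - x2 * (1.0 / 12.0 - x2 * (1.0 / 120.0 - x2 * (1.0 / 252.0 - x2 * (1.0 / 240.0 - x2 / 132.0))))
}

/// ψ′(x) for x > 0: recurrence ψ′(x) = ψ′(x + 1) + 1/x² up to x ≥ 20, then
/// ψ′(x) ≈ 1/x + 1/(2x²) + 1/(6x³) − 1/(30x⁵) + 1/(42x⁷) − 1/(30x⁹) + 5/(66x¹¹).
fn trigamma(mut x: f64) -> f64 {
    let mut acc = 0.0;
    while x < 20.0 {
        acc += 1.0 / (x * x);
        x += 1.0;
    }
    let x2 = 1.0 / (x * x);
    acc + 1.0 / x
        + 0.5 * x2
        + (x2 / x) * (1.0 / 6.0 - x2 * (1.0 / 30.0 - x2 * (1.0 / 42.0 - x2 * (1.0 / 30.0 - x2 * 5.0 / 66.0))))
}
```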
    /// Gaussian location-scale row NLL, K=2 primaries (η, s = log σ):
    /// ℓ = s + ½ e^{−2s} (y − η)². Mixed cross blocks — the #736 fragility
    /// shape — all have one-line closed forms here.
    struct LocScaleProgram {
        eta: Vec<f64>,
        s: Vec<f64>,
        y: Vec<f64>,
    }

    impl RowNllProgram<2> for LocScaleProgram {
        fn n_rows(&self) -> usize {
            self.eta.len()
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            Ok([self.eta[row], self.s[row]])
        }
        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
            let r = -(p[0] - self.y[row]);
            Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
        }
    }

    #[test]
    fn locscale_tower_matches_closed_forms_including_cross_blocks() {
        let prog = LocScaleProgram {
            eta: vec![0.3, -1.1, 2.0],
            s: vec![-0.5, 0.2, 0.8],
            y: vec![1.0, -2.0, 2.5],
        };
        let tol = 1e-12;
        for row in 0..prog.n_rows() {
            let t = evaluate_program(&prog, row).expect("locscale program");
            let r = prog.y[row] - prog.eta[row];
            let w = (-2.0 * prog.s[row]).exp();
            // (η, s) = indices (0, 1).
            let truth_g = [-w * r, 1.0 - w * r * r];
            let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
            // Third tensor: distinct-entry closed forms.
            // ∂ηηη = 0, ∂ηηs = −2w, ∂ηss = −4wr, ∂sss = −4wr².
            let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
                match a + b + c {
                    0 => 0.0,
                    1 => -2.0 * w,
                    2 => -4.0 * w * r,
                    _ => -4.0 * w * r * r,
                }
            };
            // Fourth tensor: ∂ηηηη = ∂ηηηs = 0 (∂ηηη vanishes identically);
            // ∂ηηss = 4w, ∂ηsss = 8wr, ∂ssss = 8wr².
            let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
                match a + b + c + d {
                    0 | 1 => 0.0,
                    2 => 4.0 * w,
                    3 => 8.0 * w * r,
                    _ => 8.0 * w * r * r,
                }
            };
            for a in 0..2 {
                assert!(
                    (t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
                    "row {row} grad[{a}]"
                );
                for b in 0..2 {
                    assert!(
                        (t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
                        "row {row} hess[{a}][{b}]: got {} want {}",
                        t.h[a][b],
                        truth_h[a][b]
                    );
                    for c in 0..2 {
                        assert!(
                            (t.t3[a][b][c] - t3_truth(a, b, c)).abs()
                                <= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
                            "row {row} t3[{a}][{b}][{c}]: got {} want {}",
                            t.t3[a][b][c],
                            t3_truth(a, b, c)
                        );
                        for d in 0..2 {
                            assert!(
                                (t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
                                    <= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
                                "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
                                t.t4[a][b][c][d],
                                t4_truth(a, b, c, d)
                            );
                        }
                    }
                }
            }
            // The derived trait-surface helpers agree with direct contraction.
            let dir = [0.7, -1.3];
            let third = derived_third_contracted(&prog, row, &dir).expect("third");
            for a in 0..2 {
                for b in 0..2 {
                    let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
                    assert!(
                        (third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0),
                        "row {row} third[{a}][{b}]: got {} want {}",
                        third[a][b],
                        want
                    );
                }
            }
        }
    }

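// The closed forms above can be cross-checked without any tower code. In this
// sketch (hypothetical helper names), central differences of the closed-form
// Hessian reproduce the third tensor, and `contract_third` is the plain sum
// Σ_c t3[a][b][c]·dir[c] that `derived_third_contracted` is compared against.

```rust
/// Closed-form Hessian of ℓ(η, s) = s + ½·e^{−2s}·(y − η)² in (η, s).
fn hessian(eta: f64, s: f64, y: f64) -> [[f64; 2]; 2] {
    let (r, w) = (y - eta, (-2.0 * s).exp());
    [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]]
}

/// Closed-form third tensor; an entry depends only on how many indices are s.
fn third(eta: f64, s: f64, y: f64) -> [[[f64; 2]; 2]; 2] {
    let (r, w) = (y - eta, (-2.0 * s).exp());
    let by_s_count = [0.0, -2.0 * w, -4.0 * w * r, -4.0 * w * r * r];
    let mut t = [[[0.0; 2]; 2]; 2];
    for a in 0..2 {
        for b in 0..2 {
            for c in 0..2 {
                t[a][b][c] = by_s_count[a + b + c];
            }
        }
    }
    t
}

/// out[a][b] = Σ_c t[a][b][c]·dir[c].
fn contract_third(t: &[[[f64; 2]; 2]; 2], dir: [f64; 2]) -> [[f64; 2]; 2] {
    let mut out = [[0.0; 2]; 2];
    for a in 0..2 {
        for b in 0..2 {
            out[a][b] = t[a][b][0] * dir[0] + t[a][b][1] * dir[1];
        }
    }
    out
}

/// Largest |t3[a][b][c] − central FD of H[a][b] along primary c|.
fn max_fd_gap(eta: f64, s: f64, y: f64, h: f64) -> f64 {
    let t = third(eta, s, y);
    let mut worst: f64 = 0.0;
    for c in 0..2 {
        let (up, dn) = if c == 0 {
            (hessian(eta + h, s, y), hessian(eta - h, s, y))
        } else {
            (hessian(eta, s + h, y), hessian(eta, s - h, y))
        };
        for a in 0..2 {
            for b in 0..2 {
                let fd = (up[a][b] - dn[a][b]) / (2.0 * h);
                worst = worst.max((t[a][b][c] - fd).abs());
            }
        }
    }
    worst
}
```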
3891                }
3892            }
3893        }
3894    }
3895
3896    /// FD cross-check on a deliberately gnarly composition (div, sqrt,
3897    /// powf, nested exp/ln) in K=3, where no closed form is consulted:
3898    /// every tower channel is checked against central finite differences
3899    /// of the channel one order below — value→grad, grad→hess, hess→t3,
3900    /// t3→t4 — so each order is independently anchored.
3901    ///
3902    /// The program carries a per-row primary fixture plus a per-row offset
3903    /// `tau[row]` that enters the loss as a constant, so `row` genuinely
3904    /// drives both the seed point and the evaluated expression.
3905    struct GnarlyProgram {
3906        primaries: Vec<[f64; 3]>,
3907        tau: Vec<f64>,
3908    }
3909
3910    impl GnarlyProgram {
3911        fn fixture() -> Self {
3912            Self {
3913                primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
3914                tau: vec![0.15, -0.35, 0.5],
3915            }
3916        }
3917    }
3918
3919    impl RowNllProgram<3> for GnarlyProgram {
3920        fn n_rows(&self) -> usize {
3921            self.primaries.len()
3922        }
3923        fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3924            self.primaries
3925                .get(row)
3926                .copied()
3927                .ok_or_else(|| format!("gnarly: row {row} out of range"))
3928        }
3929        fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3930            let tau = *self
3931                .tau
3932                .get(row)
3933                .ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
3934            let a = (p[0] * p[1]).exp();
3935            let b = (p[2] * p[2] + 1.0).sqrt();
3936            let c = (a + b + tau).ln();
3937            let d = (p[1] * 0.5 + 2.0).powf(1.7);
3938            Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
3939        }
3940    }
3941
3942    /// Evaluate the gnarly program's tower at an ARBITRARY seed point for
3943    /// `row` (used to drive central differences off the fixture grid),
3944    /// while keeping `row`'s per-row data (`tau`) in the loss.
3945    fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
3946        struct At<'a> {
3947            base: &'a GnarlyProgram,
3948            row: usize,
3949            p: [f64; 3],
3950        }
3951        impl RowNllProgram<3> for At<'_> {
3952            fn n_rows(&self) -> usize {
3953                1
3954            }
3955            fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3956                if row != 0 {
3957                    return Err(format!("gnarly-at: row {row} out of range"));
3958                }
3959                Ok(self.p)
3960            }
3961            fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3962                if eval_row != 0 {
3963                    return Err(format!("gnarly-at: eval row {eval_row} out of range"));
3964                }
3965                self.base.row_nll(self.row, vars)
3966            }
3967        }
3968        evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
3969    }
3970
3971    #[test]
3972    fn gnarly_tower_is_fd_consistent_order_by_order() {
3973        let prog = GnarlyProgram::fixture();
3974        for row in 0..prog.n_rows() {
3975            let base = prog.primaries(row).expect("primaries");
3976            let t = gnarly_tower_at(&prog, row, base);
3977            let h_step = 1e-5;
3978            let tol = 1e-6;
3979            for c in 0..3 {
3980                let mut up = base;
3981                let mut dn = base;
3982                up[c] += h_step;
3983                dn[c] -= h_step;
3984                let t_up = gnarly_tower_at(&prog, row, up);
3985                let t_dn = gnarly_tower_at(&prog, row, dn);
3986                // value → gradient.
3987                let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
3988                assert!(
3989                    (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
3990                    "grad[{c}]: analytic {} fd {}",
3991                    t.g[c],
3992                    fd_g
3993                );
3994                for a in 0..3 {
3995                    // gradient → Hessian.
3996                    let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
3997                    assert!(
3998                        (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
3999                        "hess[{a}][{c}]: analytic {} fd {}",
4000                        t.h[a][c],
4001                        fd_h
4002                    );
4003                    for b in 0..3 {
4004                        // Hessian → third.
4005                        let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
4006                        assert!(
4007                            (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
4008                            "t3[{a}][{b}][{c}]: analytic {} fd {}",
4009                            t.t3[a][b][c],
4010                            fd_t3
4011                        );
4012                        for d in 0..3 {
4013                            // third → fourth.
4014                            let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
4015                            assert!(
4016                                (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
4017                                "t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
4018                                t.t4[a][b][d][c],
4019                                fd_t4
4020                            );
4021                        }
4022                    }
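// A scalar sketch of the order-by-order ladder used above (hypothetical helper
// names). Differencing channel k and comparing with channel k + 1 passes for a
// correct derivative stack and flags a stack whose fourth channel is wrong.
// The gap also shrinks about fourfold when h is halved (central differences
// are O(h²)), which leaves a wide margin for an h = 1e-5 step against a 1e-6
// tolerance.

```rust
/// Derivative stack [f, f′, f″, f‴, f⁗] of a scalar function at x.
type Stack = [f64; 5];

/// Worst relative gap, over k = 0..=3, between analytic channel k + 1 and the
/// central difference of channel k.
fn ladder_gap(stack: impl Fn(f64) -> Stack, x: f64, h: f64) -> f64 {
    let (mid, up, dn) = (stack(x), stack(x + h), stack(x - h));
    (0..4)
        .map(|k| {
            let fd = (up[k] - dn[k]) / (2.0 * h);
            (mid[k + 1] - fd).abs() / fd.abs().max(1.0)
        })
        .fold(0.0, f64::max)
}

/// Correct stack of f(x) = e^{2x}: the k-th derivative is 2ᵏ·e^{2x}.
fn exp2x(x: f64) -> Stack {
    let e = (2.0 * x).exp();
    [e, 2.0 * e, 4.0 * e, 8.0 * e, 16.0 * e]
}

/// Same stack with a deliberately wrong fourth channel (a "dropped term").
fn exp2x_bad_fourth(x: f64) -> Stack {
    let mut s = exp2x(x);
    s[4] = 15.0 * (2.0 * x).exp();
    s
}
```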
4023                }
4024            }
4025        }
4026    }
4027
4028    /// `implicit_solve` reproduces the true implicit function `a(θ)` of a
4029    /// constraint `F(a, θ) = 0` to fourth order. The constraint here is the
4030    /// smooth, strictly-`a`-monotone
4031    ///   F(a, θ) = a + θ₀·a² + θ₁·exp(a) − c
4032    /// whose root `a(θ)` is re-solved by scalar Newton at perturbed θ as the
4033    /// independent finite-difference oracle. Mirrors the survival flex
4034    /// calibration solve (one implicit intercept over the primaries) without
4035    /// any survival machinery, so a failure localises to the combinator.
4036    #[test]
4037    fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
4038        const C: f64 = 1.7;
4039        // The scalar constraint as a plain f64 closure (the production root
4040        // finder analogue) and its tower form in (a, θ₀, θ₁).
4041        let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
4042        let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
4043        let solve = |th: [f64; 2]| -> f64 {
4044            let mut a = 0.0_f64;
4045            for _ in 0..100 {
4046                let r = f_scalar(a, th);
4047                if r.abs() < 1e-14 {
4048                    break;
4049                }
4050                a -= r / f_da(a, th);
4051            }
4052            a
4053        };
4054        // Tower constraint over K1 = 3 vars: slot 0 = a, slots 1,2 = θ₀, θ₁.
4055        let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
4056            let a = Tower4::<3>::variable(a0, 0);
4057            let t0 = Tower4::<3>::variable(th[0], 1);
4058            let t1 = Tower4::<3>::variable(th[1], 2);
4059            a + t0 * a.mul(&a) + t1 * a.exp() - C
4060        };
4061
4062        let th0 = [0.35, 0.2];
4063        let a0 = solve(th0);
4064        let f = f_tower(a0, th0);
4065        // Residual at the solved point is ~0 (the combinator tolerates the
4066        // production Newton residual; here it is machine-zero).
4067        assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
4068        let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");
4069
4070        // FD oracle: central differences of the scalar re-solve. Each order is
4071        // built from the previous via one more central difference, exactly the
4072        // gnarly order-by-order ladder.
4073        let h = 1e-4;
4074        let tol = 1e-5;
4075        let re = |th: [f64; 2]| solve(th);
4076        for i in 0..2 {
4077            let mut up = th0;
4078            let mut dn = th0;
4079            up[i] += h;
4080            dn[i] -= h;
4081            let fd_g = (re(up) - re(dn)) / (2.0 * h);
4082            assert!(
4083                (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4084                "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
4085                a_tower.g[i],
4086                fd_g
4087            );
4088            // second order: FD of the analytic gradient component would re-use
4089            // the combinator; instead difference a SCALAR gradient computed by
4090            // a nested re-solve so the oracle stays production-independent.
4091            let grad_at = |th: [f64; 2], j: usize| -> f64 {
4092                let mut up = th;
4093                let mut dn = th;
4094                up[j] += h;
4095                dn[j] -= h;
4096                (re(up) - re(dn)) / (2.0 * h)
4097            };
4098            for j in 0..2 {
4099                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4100                assert!(
4101                    (a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4102                    "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
4103                    a_tower.h[i][j],
4104                    fd_h
4105                );
4106            }
4107        }
4108    }
4109
4110    /// `implicit_solve` degenerates to `a_θ = −F_θ / F_a` at first order on a
4111    /// linear-in-a constraint, and the second-order tensor matches the
4112    /// textbook IFT formula `a_uv = −(F_uv + F_au a_v + F_av a_u + F_aa a_u a_v)/F_a`.
4113    /// This pins the recursion against the hand-coded first_full.rs formula it
4114    /// replaces, independent of any FD step.
4115    #[test]
4116    fn implicit_solve_matches_textbook_ift_recursion() {
4117        // A constraint with non-trivial F_a, F_aa, F_au, F_uv all present.
4118        let a0 = 0.4_f64;
4119        let th = [0.25_f64, -0.15_f64];
4120        let f = {
4121            let a = Tower4::<3>::variable(a0, 0);
4122            let t0 = Tower4::<3>::variable(th[0], 1);
4123            let t1 = Tower4::<3>::variable(th[1], 2);
4124            // F = a·(1 + θ₀) + θ₁·a² + θ₀·θ₁ − 0.4385. The constant is chosen so
4125            // F(a0, θ0) = 0 exactly at a0 = 0.4, θ = [0.25, −0.15]:
4126            //   0.4·1.25 + (−0.15)·0.16 + 0.25·(−0.15) = 0.4385.
4127            // implicit_solve requires a genuine root; at the root the level-set
4128            // and root-curve derivatives coincide, so the textbook-IFT
4129            // assertions below are unaffected.
4130            a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
4131        };
4132        let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
4133        let f_a = f.g[0];
4134        // First order: a_u = −F_u / F_a.
4135        for u in 0..2 {
4136            let want = -f.g[u + 1] / f_a;
4137            assert!(
4138                (a_t.g[u] - want).abs() < 1e-12,
4139                "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
4140                a_t.g[u],
4141                want
4142            );
4143        }
4144        // Second order textbook IFT (indices shifted by 1 for the a-slot).
4145        for u in 0..2 {
4146            for v in 0..2 {
4147                let f_uv = f.h[u + 1][v + 1];
4148                let f_au = f.h[0][u + 1];
4149                let f_av = f.h[0][v + 1];
4150                let f_aa = f.h[0][0];
4151                let want =
4152                    -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
4153                assert!(
4154                    (a_t.h[u][v] - want).abs() < 1e-12,
4155                    "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
4156                    a_t.h[u][v],
4157                    want
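// The recursion pinned above, written out for the scalar constraint of the
// previous test, F(a, θ) = a + θ₀·a² + θ₁·eᵃ − c. A sketch with hypothetical
// helper names: first order a_u = −F_u/F_a, second order the textbook formula
// (F_uv = 0 here because F is linear in θ), both checked against a Newton
// re-solve differenced in θ.

```rust
const C: f64 = 1.7;

/// F(a, θ) = a + θ₀·a² + θ₁·eᵃ − C and its a-derivative.
fn f(a: f64, th: [f64; 2]) -> f64 {
    a + th[0] * a * a + th[1] * a.exp() - C
}
fn f_a(a: f64, th: [f64; 2]) -> f64 {
    1.0 + 2.0 * th[0] * a + th[1] * a.exp()
}

/// Scalar Newton re-solve of F(a, θ) = 0, started from a = 0.
fn solve(th: [f64; 2]) -> f64 {
    let mut a = 0.0;
    for _ in 0..100 {
        let r = f(a, th);
        if r.abs() < 1e-14 {
            break;
        }
        a -= r / f_a(a, th);
    }
    a
}

/// IFT gradient and Hessian of a(θ) at θ.
fn ift(th: [f64; 2]) -> ([f64; 2], [[f64; 2]; 2]) {
    let a = solve(th);
    let e = a.exp();
    let fa = f_a(a, th);
    let faa = 2.0 * th[0] + th[1] * e;
    let fu = [a * a, e]; // [F_θ₀, F_θ₁]
    let fau = [2.0 * a, e]; // [F_aθ₀, F_aθ₁]
    let g = [-fu[0] / fa, -fu[1] / fa];
    let mut hess = [[0.0; 2]; 2];
    for u in 0..2 {
        for v in 0..2 {
            // F_uv is zero for this constraint, so it is omitted.
            hess[u][v] = -(fau[u] * g[v] + fau[v] * g[u] + faa * g[u] * g[v]) / fa;
        }
    }
    (g, hess)
}
```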
4158                );
4159            }
4160        }
4161    }
4162
4163    /// The moving-boundary flux tower reproduces every θ-derivative of a
4164    /// moving-limit integral, INCLUDING the second-order `B·z_uv` term the
4165    /// hand-written flux dropped (#932). The edge `z_R(θ) = θ₀ + θ₁²` has a
4166    /// genuinely nonzero `∂²z_R/∂θ₁² = 2`, so a combinator that omitted
4167    /// `B·z_uv` would miss the [1][1] Hessian entry. Truth = central FD of the
4168    /// closed-form integral `∫₀^{z_R} e^{−z²/2} dz = √(π/2)·erf(z_R/√2)`.
4169    #[test]
4170    fn moving_boundary_flux_carries_b_zuv_term() {
4171        use std::f64::consts::PI;
4172        let b = |z: f64| (-0.5 * z * z).exp(); // integrand B(z)
4173        // Antiderivative-based closed-form integral I(z_R) = ∫₀^{z_R} B dz.
4174        let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
4175        let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
4176        let th0 = [0.7_f64, 0.5_f64];
4177
4178        // Edge tower z_R(θ) over K=2 primaries: value + exact derivatives.
4179        let mut z_edge = Tower4::<2>::constant(z_r(th0));
4180        z_edge.g[0] = 1.0; // ∂z_R/∂θ₀ = 1
4181        z_edge.g[1] = 2.0 * th0[1]; // ∂z_R/∂θ₁ = 2θ₁
4182        z_edge.h[1][1] = 2.0; // ∂²z_R/∂θ₁² = 2  (the z_uv the old flux dropped)
4183
4184        // Integrand stack [B, B′, B″, B‴] at z₀: B′=−z·B, B″=(z²−1)·B,
4185        // B‴=(3z−z³)·B.
4186        let z0 = z_edge.v;
4187        let b0 = b(z0);
4188        let stack = [
4189            b0,
4190            -z0 * b0,
4191            (z0 * z0 - 1.0) * b0,
4192            (3.0 * z0 - z0 * z0 * z0) * b0,
4193        ];
4194        let flux = moving_limit_boundary_tower(&z_edge, stack);
4195
4196        // FD truth of the integral's derivatives.
4197        let h = 1e-4;
4198        let tol = 1e-6;
4199        for i in 0..2 {
4200            let mut up = th0;
4201            let mut dn = th0;
4202            up[i] += h;
4203            dn[i] -= h;
4204            let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
4205            assert!(
4206                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4207                "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
4208                flux.g[i],
4209                fd_g
4210            );
4211        }
4212        // The decisive entry: ∂²I/∂θ₁² = B′·(z_θ₁)² + B·z_θ₁θ₁. With z_θ₁=2θ₁=1
4213        // and z_θ₁θ₁=2, the B·z_uv contribution is B(z₀)·2 — omitting it would
4214        // leave the [1][1] entry short by exactly 2·B(z₀).
4215        let grad1_at = |th: [f64; 2]| -> f64 {
4216            let mut up = th;
4217            let mut dn = th;
4218            up[1] += h;
4219            dn[1] -= h;
4220            (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
4221        };
4222        let mut up = th0;
4223        let mut dn = th0;
4224        up[1] += h;
4225        dn[1] -= h;
4226        let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
4227        assert!(
4228            (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
4229            "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
4230            flux.h[1][1],
4231            fd_h11
4232        );
4233        // Explicit witness that the B·z_uv term is present and material:
4234        // analytic h[1][1] minus the pure (z_u)² part must equal B·z_uv = 2·B₀.
4235        let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
4236        let b_zuv = flux.h[1][1] - pure_zu2;
4237        assert!(
4238            (b_zuv - b0 * 2.0).abs() < 1e-10,
4239            "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
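// The chain rule behind the flux tower for a θ-independent integrand:
// ∂_u I = B·z_u and ∂_u∂_v I = B′·z_u·z_v + B·z_uv. This sketch (hypothetical
// helper names) swaps the Gaussian integrand for B = cos so the closed form
// ∫₀^z cos t dt = sin z needs only std, and keeps the edge z = θ₀ + θ₁² so the
// dropped term shows up in the [1][1] entry.

```rust
/// Edge z(θ) = θ₀ + θ₁² with its exact gradient and Hessian.
fn edge(th: [f64; 2]) -> (f64, [f64; 2], [[f64; 2]; 2]) {
    (th[0] + th[1] * th[1], [1.0, 2.0 * th[1]], [[0.0, 0.0], [0.0, 2.0]])
}

/// θ-Hessian of I(θ) = ∫₀^{z(θ)} cos t dt by the moving-limit chain rule
/// (B = cos, B′ = −sin). `keep_b_zuv = false` reproduces the dropped-term bug.
fn flux_hessian(th: [f64; 2], keep_b_zuv: bool) -> [[f64; 2]; 2] {
    let (z, zu, zuv) = edge(th);
    let (b, db) = (z.cos(), -z.sin());
    let mut out = [[0.0; 2]; 2];
    for u in 0..2 {
        for v in 0..2 {
            out[u][v] = db * zu[u] * zu[v] + if keep_b_zuv { b * zuv[u][v] } else { 0.0 };
        }
    }
    out
}

/// Closed-form integral: ∫₀^{z} cos t dt = sin z.
fn integral(th: [f64; 2]) -> f64 {
    edge(th).0.sin()
}

/// Nested central-difference Hessian of the closed form (no chain rule used).
fn fd_hessian(th: [f64; 2], h: f64) -> [[f64; 2]; 2] {
    let shift = |mut t: [f64; 2], k: usize, d: f64| {
        t[k] += d;
        t
    };
    let mut out = [[0.0; 2]; 2];
    for u in 0..2 {
        for v in 0..2 {
            let pp = integral(shift(shift(th, u, h), v, h));
            let pm = integral(shift(shift(th, u, h), v, -h));
            let mp = integral(shift(shift(th, u, -h), v, h));
            let mm = integral(shift(shift(th, u, -h), v, -h));
            out[u][v] = (pp - pm - mp + mm) / (4.0 * h * h);
        }
    }
    out
}
```

// With the term dropped, the [1][1] entry is short by exactly 2·B(z₀), the same
// signature the test above looks for.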
4240            b_zuv,
4241            b0 * 2.0
4242        );
4243    }
4244
4245    /// `moving_limit_boundary_tower_theta_integrand` reproduces the marginal-slope
4246    /// flex boundary closure for a θ-DEPENDENT integrand `G(z;θ)` — the case the
4247    /// plain `moving_limit_boundary_tower` cannot express, and the case the
4248    /// survival directional/bidirectional paths hand-assemble term-by-term
4249    /// (`G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`, with the directional path
4250    /// dropping `G·z_uv`). Two independent oracles:
4251    ///   (1) closed-form: the boundary flux of `∫ G dz` is exactly
4252    ///       `Φ(z_edge(θ);θ) − Φ(z₀;θ)` (Φ = z-antiderivative of G), whose θ
4253    ///       derivatives we take by central FD of the closed form — no jet code.
4254    ///   (2) the explicit second-order hand closure, including the `G·z_uv` term,
4255    ///       built from the integrand's own (z,θ) partials.
4256    /// G(z;θ) = exp(z·θ₀) is genuinely θ-dependent (G_θ₀ = z·e^{zθ₀} ≠ 0), and
4257    /// the edge z_edge = z₀ + θ₀ + θ₁² has a real z_uv = ∂²/∂θ₁² = 2, so a
4258    /// combinator that dropped either the integrand-θ terms or `G·z_uv` would
4259    /// miss a Hessian entry.
4260    #[test]
4261    fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
4262        // G(z;θ) = exp(z·θ₀);  Φ(z;θ) = ∫₀^z G = (e^{zθ₀} − 1)/θ₀.
4263        let g = |z: f64, t0: f64| (z * t0).exp();
4264        let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
4265        let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
4266        let th0 = [0.4_f64, 0.5_f64];
4267        let z0 = z_r(th0);
4268
4269        // Edge tower z_edge(θ) over K=2 primaries.
4270        let mut z_edge = Tower4::<2>::constant(z0);
4271        z_edge.g[0] = 1.0; // ∂z/∂θ₀
4272        z_edge.g[1] = 2.0 * th0[1]; // ∂z/∂θ₁
4273        z_edge.h[1][1] = 2.0; // ∂²z/∂θ₁² (the z_uv the directional path drops)
4274
4275        // Φ's mixed (z, θ) jet over K1 = 3 vars: slot 0 = z, slots 1,2 = θ₀,θ₁.
4276        // Built ONCE in tower arithmetic so every (z^i θ^j) partial is exact.
4277        let z_var = Tower4::<3>::variable(z0, 0);
4278        let t0_var = Tower4::<3>::variable(th0[0], 1);
4279        // θ₁ does not enter G/Φ here (its Φ-derivatives are zero; the z_edge
4280        // chain supplies all θ₁ motion through slot 0), so the K1 frame's θ₁
4281        // slot is intentionally left unseeded.
4282        let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
4283        // Sanity: slot-0 first derivative of Φ IS G(z₀;θ₀).
4284        assert!(
4285            (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
4286            "Φ_z {:+.8e} != G {:+.8e}",
4287            phi_jet.g[0],
4288            g(z0, th0[0])
4289        );
4290
4291        let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
4292
4293        // Value channel is 0 by construction (boundary, not the integral itself).
4294        assert!(
4295            flux.v.abs() < 1e-12,
4296            "boundary value channel {:+.3e}",
4297            flux.v
4298        );
4299
4300        // Oracle (1): central FD of the closed-form boundary flux
4301        //   Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ)   (z₀ FROZEN at the base edge).
4302        let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
4303        let h = 1e-4;
4304        let tol = 1e-6;
4305        for i in 0..2 {
4306            let mut up = th0;
4307            let mut dn = th0;
4308            up[i] += h;
4309            dn[i] -= h;
4310            let fd_g = (bnd(up) - bnd(dn)) / (2.0 * h);
4311            assert!(
4312                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4313                "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
4314                flux.g[i],
4315                fd_g
4316            );
4317        }
4318        let grad_at = |th: [f64; 2], j: usize| -> f64 {
4319            let mut up = th;
4320            let mut dn = th;
4321            up[j] += h;
4322            dn[j] -= h;
4323            (bnd(up) - bnd(dn)) / (2.0 * h)
4324        };
4325        for i in 0..2 {
4326            for j in 0..2 {
4327                let mut up = th0;
4328                let mut dn = th0;
4329                up[i] += h;
4330                dn[i] -= h;
4331                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4332                assert!(
4333                    (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4334                    "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
4335                    flux.h[i][j],
4336                    fd_h
4337                );
4338            }
4339        }
4340
4341        // Oracle (2): the explicit second-order hand closure, term by term —
4342        // `G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`. Read G's partials at the
4343        // base point directly (no jet): G = e^{zθ₀}, G_z = θ₀·G, G_θ₀ = z·G,
4344        // G_θ₁ = 0.
4345        let gg = g(z0, th0[0]);
4346        let g_z = th0[0] * gg;
4347        let g_theta = [z0 * gg, 0.0]; // [G_θ₀, G_θ₁]
4348        for i in 0..2 {
4349            for j in 0..2 {
4350                let z_u = z_edge.g[i];
4351                let z_v = z_edge.g[j];
4352                let z_uv = z_edge.h[i][j];
4353                let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
4354                assert!(
4355                    (flux.h[i][j] - hand).abs() < 1e-9,
4356                    "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
4357                    flux.h[i][j],
4358                    hand
4359                );
4360            }
4361        }
4362
4363        // Decisive: the `G·z_uv` term the directional path DROPS is present and
4364        // material in the [1][1] entry (z_uv = 2 there).
4365        let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
4366        let g_zuv = flux.h[1][1] - pure_no_zuv;
4367        assert!(
4368            (g_zuv - gg * 2.0).abs() < 1e-9,
4369            "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
4370            g_zuv,
4371            gg * 2.0
4372        );
4373    }
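
    // Sketch (illustrative only, not crate code): the nested central-difference
    // Hessian oracle used in oracle (1) above, applied to a function whose
    // Hessian is known in closed form. The module name, `fd_grad`,
    // `fd_hessian`, `exact_hessian` and the test function f(θ) = exp(θ₀·θ₁)
    // are assumptions made for this example. The nested quotient divides by
    // (2h)², so rounding error grows roughly like ε/h²; that is presumably why
    // the Hessian check above uses a looser tolerance (1e-3) than the gradient.
    mod sketch_nested_fd_hessian {
        /// Central-difference partial ∂f/∂θ_j at `th`.
        pub fn fd_grad(f: &dyn Fn([f64; 2]) -> f64, th: [f64; 2], j: usize, h: f64) -> f64 {
            let (mut up, mut dn) = (th, th);
            up[j] += h;
            dn[j] -= h;
            (f(up) - f(dn)) / (2.0 * h)
        }

        /// ∂²f/∂θ_i∂θ_j as a central difference of `fd_grad` along θ_i.
        pub fn fd_hessian(
            f: &dyn Fn([f64; 2]) -> f64,
            th: [f64; 2],
            i: usize,
            j: usize,
            h: f64,
        ) -> f64 {
            let (mut up, mut dn) = (th, th);
            up[i] += h;
            dn[i] -= h;
            (fd_grad(f, up, j, h) - fd_grad(f, dn, j, h)) / (2.0 * h)
        }

        /// Exact Hessian of f(θ) = exp(θ₀θ₁).
        pub fn exact_hessian(th: [f64; 2]) -> [[f64; 2]; 2] {
            let (a, b) = (th[0], th[1]);
            let e = (a * b).exp();
            [[b * b * e, (1.0 + a * b) * e], [(1.0 + a * b) * e, a * a * e]]
        }

        #[test]
        fn nested_fd_matches_exact_hessian() {
            let f = |th: [f64; 2]| (th[0] * th[1]).exp();
            let th = [0.4, -0.7];
            let exact = exact_hessian(th);
            for i in 0..2 {
                for j in 0..2 {
                    let fd = fd_hessian(&f, th, i, j, 1e-4);
                    assert!((fd - exact[i][j]).abs() <= 1e-6 * exact[i][j].abs().max(1.0));
                }
            }
        }
    }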

    /// The survival crossing-edge position tower `z_edge = (τ − a(θ)) / b`,
    /// built from the intercept tower `a(θ)` (here a stand-in) and the seeded
    /// slope `b`, which is the g-primary itself (so ∂b/∂g = 1), reproduces
    /// taylor-jet's exact hand-path boundary-velocity formulas:
    ///   z_u   = −(a_u + [u==g]·z) / b
    ///   z_uv  = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
    /// This pins the bridge between `implicit_solve` and
    /// `cell_moving_boundary_flux_tower`: the boundary jet that the production
    /// flex path hand-codes (and dropped `z_uv` from) is exactly `∂²` of this
    /// tower. K=3 reduced frame: slot 0 = a-axis carrier (an arbitrary smooth
    /// a(θ) with nonzero a_u/a_uv), slot 1 = g (the slope b itself), slot 2
    /// unused.
    #[test]
    fn crossing_edge_tower_matches_handpath_velocity_formulas() {
        const TAU: f64 = 1.3; // the link-knot crossing threshold τ
        let g_idx = 1usize;
        let g0 = 0.85_f64; // the slope value b (the g-primary IS the slope)
        // Stand-in intercept tower a(θ): nonzero value, gradient, Hessian in the
        // two live axes so a_u and a_uv are both exercised. (In production this
        // comes from implicit_solve; here we plant known derivatives.)
        let mut a = Tower4::<3>::constant(0.45);
        a.g[0] = 0.7;
        a.g[1] = -0.3;
        a.h[0][0] = 0.25;
        a.h[0][1] = 0.11;
        a.h[1][0] = 0.11;
        a.h[1][1] = -0.08;

        // In the survival flex frame the slope `b` IS the g-primary directly
        // (the directional code passes `g` as `b`, and ∂z/∂g uses ∂b/∂g = 1):
        // z_edge = (τ − a) / b with b seeded as the g-axis variable.
        let b = Tower4::<3>::variable(g0, g_idx);
        let z_edge = (Tower4::<3>::constant(TAU) - a) / b;

        let bv = g0;
        let z0 = z_edge.v;
        assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);

        // z_u = −(a_u + [u==g]·z) / b.
        for u in 0..2 {
            let direct = if u == g_idx { z0 } else { 0.0 };
            let want = -(a.g[u] + direct) / bv;
            assert!(
                (z_edge.g[u] - want).abs() < 1e-10,
                "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
                z_edge.g[u],
                want
            );
        }
        // z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b, using the tower's own
        // first-order z_v/z_u (already verified above).
        for u in 0..2 {
            for v in 0..2 {
                let cross = if u == g_idx { z_edge.g[v] } else { 0.0 }
                    + if v == g_idx { z_edge.g[u] } else { 0.0 };
                let want = -(a.h[u][v] + cross) / bv;
                assert!(
                    (z_edge.h[u][v] - want).abs() < 1e-10,
                    "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
                    z_edge.h[u][v],
                    want
                );
            }
        }
    }
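
    // Sketch (illustrative, independent of the tower): the same hand-path
    // velocity formulas checked against central differences of an explicit
    // z(θ) = (τ − a(θ)) / b. The intercept a(θ) below is a hypothetical
    // quadratic about θ* = (0, 0.85) chosen so that a(θ*), ∇a(θ*) and ∇²a(θ*)
    // equal the stand-in values planted in the test above; slot 1 is the
    // slope b itself. All names in this module are assumptions for the example.
    mod sketch_crossing_edge_velocity {
        /// Link-knot crossing threshold τ, as in the test above.
        pub const TAU: f64 = 1.3;
        /// Expansion point θ* = (0, 0.85); slot 1 is the slope b.
        pub const THETA_STAR: [f64; 2] = [0.0, 0.85];

        /// Hypothetical intercept: a(θ*) = 0.45, ∇a(θ*) = (0.7, −0.3),
        /// ∇²a(θ*) = [[0.25, 0.11], [0.11, −0.08]].
        pub fn a(th: [f64; 2]) -> f64 {
            let (d0, d1) = (th[0] - THETA_STAR[0], th[1] - THETA_STAR[1]);
            0.45 + 0.7 * d0 - 0.3 * d1 + 0.125 * d0 * d0 + 0.11 * d0 * d1 - 0.04 * d1 * d1
        }

        /// Crossing-edge position z(θ) = (τ − a(θ)) / b with b = θ₁.
        pub fn z(th: [f64; 2]) -> f64 {
            (TAU - a(th)) / th[1]
        }

        /// Hand-path velocities (z_u, z_uv) at θ*, with g = slot 1.
        pub fn hand_velocities() -> ([f64; 2], [[f64; 2]; 2]) {
            let b = THETA_STAR[1];
            let z0 = z(THETA_STAR);
            let a_u = [0.7, -0.3];
            let a_uv = [[0.25, 0.11], [0.11, -0.08]];
            // Indicator [u == g] for g = slot 1.
            let is_g = |u: usize| if u == 1 { 1.0 } else { 0.0 };
            let z_u: [f64; 2] = std::array::from_fn(|u| -(a_u[u] + is_g(u) * z0) / b);
            let z_uv: [[f64; 2]; 2] = std::array::from_fn(|u| {
                std::array::from_fn(|v| -(a_uv[u][v] + is_g(u) * z_u[v] + is_g(v) * z_u[u]) / b)
            });
            (z_u, z_uv)
        }

        #[test]
        fn hand_velocities_match_central_differences() {
            let (z_u, z_uv) = hand_velocities();
            let h = 1e-4;
            let shift = |th: [f64; 2], k: usize, s: f64| {
                let mut t = th;
                t[k] += s;
                t
            };
            let dz = |th: [f64; 2], k: usize| (z(shift(th, k, h)) - z(shift(th, k, -h))) / (2.0 * h);
            for u in 0..2 {
                assert!((z_u[u] - dz(THETA_STAR, u)).abs() < 1e-6);
                for v in 0..2 {
                    let fd = (dz(shift(THETA_STAR, u, h), v) - dz(shift(THETA_STAR, u, -h), v))
                        / (2.0 * h);
                    assert!((z_uv[u][v] - fd).abs() < 1e-5);
                }
            }
        }
    }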

    /// The crossing-edge tower in the CONSTRAINT frame (intercept `a` and
    /// slope `b` BOTH independent — slots 0 and 1) reproduces taylor-jet's
    /// FD-certified bare boundary-velocity constants exactly:
    ///   z_a  = ∂z/∂a   = −1/b
    ///   z_ab = ∂²z/∂a∂b = +1/b²
    ///   z_aa = ∂²z/∂a²  = 0
    ///   z_bb = ∂²z/∂b²  = +2(τ−a)/b³
    /// These are the `f_a`/`f_au`/`f_aa` constraint-jet boundary motions the
    /// production base path drops (and only adds in the dir twins, causing the
    /// #932 desync). Here `a` is independent (NOT yet substituted with a(θ)),
    /// so `z_aa = 0` and there is no `a_uv` chain — `implicit_solve` introduces
    /// that later. Pins these constants before the constraint-tower wiring.
    #[test]
    fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
        const TAU: f64 = 1.3;
        let a0 = 0.45_f64;
        let b0 = 0.85_f64;
        // Slot 0 = a, slot 1 = b, both seeded independent.
        let a = Tower4::<2>::variable(a0, 0);
        let b = Tower4::<2>::variable(b0, 1);
        let z = (Tower4::<2>::constant(TAU) - a) / b;

        assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
        assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
        assert!(
            (z.h[0][1] - 1.0 / (b0 * b0)).abs() < 1e-12,
            "z_ab {:+.10e} vs +1/b² {:+.10e}",
            z.h[0][1],
            1.0 / (b0 * b0)
        );
        assert!(
            z.h[0][0].abs() < 1e-12,
            "z_aa must vanish, got {:+.10e}",
            z.h[0][0]
        );
        let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);
        assert!(
            (z.h[1][1] - want_zbb).abs() < 1e-12,
            "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
            z.h[1][1],
            want_zbb
        );
    }
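
    // Sketch (illustrative, not the crate's `Tower4`): a minimal second-order
    // forward-mode jet in two variables (value, gradient, Hessian) with just
    // subtraction, the Leibniz product, and the 1/x chain rule. Seeding a in
    // slot 0 and b in slot 1 and evaluating z = (τ − a)/b reproduces the four
    // bare constants pinned above. `Jet2` and its method names are hypothetical;
    // they are named `minus`/`times`/`over` to avoid shadowing operator traits.
    mod sketch_second_order_jet {
        /// Value, gradient and Hessian of a scalar in two primary variables.
        #[derive(Clone, Copy, Debug)]
        pub struct Jet2 {
            pub v: f64,
            pub g: [f64; 2],
            pub h: [[f64; 2]; 2],
        }

        impl Jet2 {
            pub fn constant(v: f64) -> Self {
                Jet2 { v, g: [0.0; 2], h: [[0.0; 2]; 2] }
            }

            /// Seed primary `slot` with a unit gradient.
            pub fn variable(v: f64, slot: usize) -> Self {
                let mut j = Self::constant(v);
                j.g[slot] = 1.0;
                j
            }

            pub fn minus(&self, o: &Self) -> Self {
                let mut out = Self::constant(self.v - o.v);
                for i in 0..2 {
                    out.g[i] = self.g[i] - o.g[i];
                    for k in 0..2 {
                        out.h[i][k] = self.h[i][k] - o.h[i][k];
                    }
                }
                out
            }

            /// Leibniz rule through second order.
            pub fn times(&self, o: &Self) -> Self {
                let mut out = Self::constant(self.v * o.v);
                for i in 0..2 {
                    out.g[i] = self.g[i] * o.v + self.v * o.g[i];
                    for k in 0..2 {
                        out.h[i][k] = self.h[i][k] * o.v
                            + self.g[i] * o.g[k]
                            + self.g[k] * o.g[i]
                            + self.v * o.h[i][k];
                    }
                }
                out
            }

            /// Chain rule for 1/x with the stack [1/x, −1/x², 2/x³].
            pub fn recip(&self) -> Self {
                let x = self.v;
                let (f1, f2) = (-1.0 / (x * x), 2.0 / (x * x * x));
                let mut out = Self::constant(1.0 / x);
                for i in 0..2 {
                    out.g[i] = f1 * self.g[i];
                    for k in 0..2 {
                        out.h[i][k] = f2 * self.g[i] * self.g[k] + f1 * self.h[i][k];
                    }
                }
                out
            }

            pub fn over(&self, o: &Self) -> Self {
                self.times(&o.recip())
            }
        }

        /// z = (τ − a)/b with `a` seeded in slot 0 and `b` in slot 1.
        pub fn crossing_edge(tau: f64, a0: f64, b0: f64) -> Jet2 {
            let a = Jet2::variable(a0, 0);
            let b = Jet2::variable(b0, 1);
            Jet2::constant(tau).minus(&a).over(&b)
        }

        #[test]
        fn matches_bare_velocity_constants() {
            let (tau, a0, b0) = (1.3, 0.45, 0.85);
            let z = crossing_edge(tau, a0, b0);
            assert!((z.g[0] + 1.0 / b0).abs() < 1e-12);
            assert!((z.h[0][1] - 1.0 / (b0 * b0)).abs() < 1e-12);
            assert!(z.h[0][0].abs() < 1e-12);
            assert!((z.h[1][1] - 2.0 * (tau - a0) / (b0 * b0 * b0)).abs() < 1e-12);
        }
    }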

    /// The oracle harness catches a planted #736-style sign flip in a
    /// cross block and reports the channel by name.
    #[test]
    fn oracle_catches_planted_cross_block_sign_flip() {
        let prog = LocScaleProgram {
            eta: vec![0.3],
            s: vec![-0.5],
            y: vec![1.0],
        };
        let t = evaluate_program(&prog, 0).expect("tower");
        let dir = [0.6, -0.2];
        let mut third = t.third_contracted(&dir);
        let honest = KernelChannels {
            value: t.v,
            gradient: t.g,
            hessian: t.h,
            third: vec![(dir, third)],
            fourth: vec![(dir, [1.0, 0.5], t.fourth_contracted(&dir, &[1.0, 0.5]))],
        };
        verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");

        // Plant the #736 flip: negate one mixed cross entry.
        third[0][1] = -third[0][1];
        let flipped = KernelChannels {
            value: t.v,
            gradient: t.g,
            hessian: t.h,
            third: vec![(dir, third)],
            fourth: vec![],
        };
        let err = verify_kernel_channels(&t, &flipped, 1e-10)
            .expect_err("planted sign flip must be caught");
        assert!(
            err.contains("third[0][0][1]"),
            "oracle must name the flipped channel, got: {err}"
        );
    }
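
    // Sketch of the channel-naming idea behind the oracle above, under stated
    // assumptions: `verify_kernel_channels`'s real signature and message format
    // may differ. Here a hand-written contracted-third matrix is compared
    // entry by entry against a reference, and the first mismatch is reported by
    // its full index path, so a #736-style sign flip is located rather than just
    // detected. `sketch_named_channel_check::check_matrix` is hypothetical.
    mod sketch_named_channel_check {
        pub fn check_matrix(
            channel: &str,
            reference: &[[f64; 2]; 2],
            candidate: &[[f64; 2]; 2],
            tol: f64,
        ) -> Result<(), String> {
            for i in 0..2 {
                for j in 0..2 {
                    let (r, c) = (reference[i][j], candidate[i][j]);
                    // Relative tolerance with an absolute floor of `tol`.
                    if (r - c).abs() > tol * r.abs().max(1.0) {
                        return Err(format!(
                            "{channel}[{i}][{j}]: reference {r:+.6e} vs candidate {c:+.6e}"
                        ));
                    }
                }
            }
            Ok(())
        }

        #[test]
        fn planted_flip_is_named() {
            let honest = [[0.3, -1.2], [-1.2, 0.8]];
            let mut flipped = honest;
            flipped[0][1] = -flipped[0][1];
            assert!(check_matrix("third[0]", &honest, &honest, 1e-10).is_ok());
            let err = check_matrix("third[0]", &honest, &flipped, 1e-10).unwrap_err();
            assert!(err.starts_with("third[0][0][1]"), "{err}");
        }
    }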

    /// The third- and fourth-order tensors must be FULLY symmetric under
    /// index permutation (mixed partials commute). The tower stores them
    /// unsymmetrized, so equal-by-construction is a real invariant of the
    /// Leibniz/Faà di Bruno writes — a cheap typo tripwire. Asserted on a
    /// nontrivial K=3 tower with all of div/sqrt/powf/exp/ln exercised, so
    /// every composition path contributes. Lives in a test (not the hot
    /// per-op path) on purpose.
    #[test]
    fn t3_t4_are_fully_index_symmetric() {
        let prog = GnarlyProgram::fixture();
        // 3! = 6 permutations of three indices.
        let perms3: [[usize; 3]; 6] = [
            [0, 1, 2],
            [0, 2, 1],
            [1, 0, 2],
            [1, 2, 0],
            [2, 0, 1],
            [2, 1, 0],
        ];
        // 4! = 24 permutations of four indices.
        let perms4: [[usize; 4]; 24] = [
            [0, 1, 2, 3],
            [0, 1, 3, 2],
            [0, 2, 1, 3],
            [0, 2, 3, 1],
            [0, 3, 1, 2],
            [0, 3, 2, 1],
            [1, 0, 2, 3],
            [1, 0, 3, 2],
            [1, 2, 0, 3],
            [1, 2, 3, 0],
            [1, 3, 0, 2],
            [1, 3, 2, 0],
            [2, 0, 1, 3],
            [2, 0, 3, 1],
            [2, 1, 0, 3],
            [2, 1, 3, 0],
            [2, 3, 0, 1],
            [2, 3, 1, 0],
            [3, 0, 1, 2],
            [3, 0, 2, 1],
            [3, 1, 0, 2],
            [3, 1, 2, 0],
            [3, 2, 0, 1],
            [3, 2, 1, 0],
        ];
        for row in 0..prog.n_rows() {
            let t = evaluate_program(&prog, row).expect("gnarly tower");
            let scale_t3 = t
                .t3
                .iter()
                .flatten()
                .flatten()
                .fold(0.0_f64, |m, x| m.max(x.abs()))
                .max(1.0);
            let scale_t4 = t
                .t4
                .iter()
                .flatten()
                .flatten()
                .flatten()
                .fold(0.0_f64, |m, x| m.max(x.abs()))
                .max(1.0);
            for i in 0..3 {
                for j in 0..3 {
                    for k in 0..3 {
                        let base = t.t3[i][j][k];
                        let idx = [i, j, k];
                        for p in &perms3 {
                            let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
                            assert!(
                                (base - permed).abs() <= 1e-12 * scale_t3,
                                "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
                                 permuted {permed:+.15e} under {p:?}"
                            );
                        }
                        for l in 0..3 {
                            let base4 = t.t4[i][j][k][l];
                            let idx4 = [i, j, k, l];
                            for p in &perms4 {
                                let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
                                assert!(
                                    (base4 - permed).abs() <= 1e-12 * scale_t4,
                                    "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
                                     permuted {permed:+.15e} under {p:?}"
                                );
                            }
                        }
                    }
                }
            }
        }
    }
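
    // Sketch (not what the symmetry test above does): the permutation tables
    // there are written out by hand. Heap's algorithm generates all n!
    // orderings of 0..n, so such a table could be built, or cross-checked,
    // programmatically. `sketch_heap_permutations::permutations` is a
    // hypothetical helper using the standard iterative form of the algorithm.
    mod sketch_heap_permutations {
        /// All permutations of `0..n`, in Heap's-algorithm order.
        pub fn permutations(n: usize) -> Vec<Vec<usize>> {
            let mut a: Vec<usize> = (0..n).collect();
            // c[i] counts how many swaps level i has performed.
            let mut c = vec![0usize; n];
            let mut out = vec![a.clone()];
            let mut i = 1;
            while i < n {
                if c[i] < i {
                    if i % 2 == 0 {
                        a.swap(0, i);
                    } else {
                        a.swap(c[i], i);
                    }
                    out.push(a.clone());
                    c[i] += 1;
                    i = 1;
                } else {
                    c[i] = 0;
                    i += 1;
                }
            }
            out
        }

        #[test]
        fn generates_every_ordering_exactly_once() {
            for (n, fact) in [(3usize, 6usize), (4, 24)] {
                let mut perms = permutations(n);
                assert_eq!(perms.len(), fact);
                perms.sort();
                perms.dedup();
                assert_eq!(perms.len(), fact);
            }
        }
    }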
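}

/// Scaled complementary error function `erfcx(x) = exp(x²)·erfc(x)` for
/// `x ≥ 0`. Finite `x ≤ 0` is clamped to `erfcx(0) = 1`, `+∞` returns 0,
/// `−∞` returns `+∞`, and NaN propagates. For `x ≥ 26` it switches to the
/// five-term asymptotic series
/// `(1/(x√π))·(1 − 1/(2x²) + 3/(4x⁴) − 15/(8x⁶) + 105/(16x⁸))`
/// instead of multiplying the huge `exp(x²)` by the tiny `erfc(x)`.
#[inline]
fn erfcx_nonnegative(x: f64) -> f64 {
    if x.is_nan() {
        return f64::NAN;
    }
    if !x.is_finite() {
        return if x.is_sign_positive() {
            0.0
        } else {
            f64::INFINITY
        };
    }
    if x <= 0.0 {
        return 1.0;
    }
    if x < 26.0 {
        ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
    } else {
        let inv = 1.0 / x;
        let inv2 = inv * inv;
        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
            + 6.5625 * inv2 * inv2 * inv2 * inv2;
        inv * poly / std::f64::consts::PI.sqrt()
    }
}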
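
// Sketch (illustrative only; the module and `coefficient` are hypothetical):
// where the constants in the `x ≥ 26` branch above come from. The standard
// asymptotic expansion is
//   erfcx(x) ~ (1/(x·√π)) · Σ_n (−1)^n · (2n − 1)!! / (2x²)^n,
// so the coefficient of x^(−2n) in `poly` is (−1)^n·(2n − 1)!!/2^n, which
// regenerates 1, −0.5, 0.75, −1.875 and 6.5625 exactly.
#[allow(dead_code)]
mod sketch_erfcx_asymptotic_coefficients {
    /// Coefficient of x^(−2n) in the erfcx asymptotic series (without 1/(x√π)).
    pub fn coefficient(n: u32) -> f64 {
        // (2n − 1)!! = 1·3·5···(2n − 1); the empty product (n = 0) is 1.
        let double_factorial: f64 = (1..=n).map(|k| (2 * k - 1) as f64).product();
        let sign = if n % 2 == 0 { 1.0 } else { -1.0 };
        sign * double_factorial / 2f64.powi(n as i32)
    }

    #[test]
    fn matches_hard_coded_poly() {
        let hard_coded = [1.0, -0.5, 0.75, -1.875, 6.5625];
        for (n, &c) in hard_coded.iter().enumerate() {
            assert_eq!(coefficient(n as u32), c);
        }
    }
}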
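
/// `log(1 − exp(−a))` for `a ≥ 0`. Above `a = ln 2` it uses
/// `log1p(−exp(−a))`; at or below it, `log(−expm1(−a))`, which avoids the
/// cancellation in `1 − exp(−a)` as `a → 0⁺` (Mächler's log1mexp switch).
/// Returns `−∞` at `a = 0` and panics if `a` is negative or NaN.
#[inline]
fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    if a > core::f64::consts::LN_2 {
        (-(-a).exp()).ln_1p()
    } else if a > 0.0 {
        (-(-a).exp_m1()).ln()
    } else {
        f64::NEG_INFINITY
    }
}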
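
// Sketch (illustrative copies, not crate API): why `log1mexp_positive`
// branches. At a = 1e−20, `1.0 - (-a).exp()` rounds to exactly 0, so the
// naive log is −∞, while `log(−expm1(−a))` returns ≈ ln(1e−20) ≈ −46.05 at
// full precision. `naive` and `stable` are hypothetical names for this example.
#[allow(dead_code)]
mod sketch_log1mexp_cancellation {
    /// Naive log(1 − e^{−a}); loses all precision as a → 0⁺.
    pub fn naive(a: f64) -> f64 {
        (1.0 - (-a).exp()).ln()
    }

    /// Branching form with the switch at a = ln 2, mirroring `log1mexp_positive`.
    pub fn stable(a: f64) -> f64 {
        if a > std::f64::consts::LN_2 {
            (-(-a).exp()).ln_1p()
        } else {
            (-(-a).exp_m1()).ln()
        }
    }

    #[test]
    fn stable_form_survives_tiny_arguments() {
        let a = 1e-20;
        assert_eq!(naive(a), f64::NEG_INFINITY);
        // For tiny a, log(1 − e^{−a}) ≈ log(a).
        assert!((stable(a) - a.ln()).abs() < 1e-12);
    }
}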

/// Returns `(log Φ(x), λ(x))`, where `λ = φ(x)/Φ(x)` is the inverse Mills
/// ratio. For `x < 0` both are built from `erfcx` so the left tail avoids
/// underflow: with `u = −x/√2`, `Φ(x) = ½·exp(−u²)·erfcx(u)` and
/// `λ = √(2/π) / erfcx(u)`.
#[inline]
fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
    if x == f64::INFINITY {
        return (0.0, 0.0);
    }
    if x == f64::NEG_INFINITY {
        return (f64::NEG_INFINITY, f64::INFINITY);
    }
    if x.is_nan() {
        return (f64::NAN, f64::NAN);
    }
    if x < 0.0 {
        let u = -x / std::f64::consts::SQRT_2;
        let ex = erfcx_nonnegative(u).max(1e-300);
        let log_cdf = -u * u + (0.5 * ex).ln();
        let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
        (log_cdf, lambda)
    } else {
        let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
        let lambda = crate::probability::normal_pdf(x) / cdf;
        (cdf.ln(), lambda)
    }
}

/// Stable derivative stack `[f, f′, f″, f‴, f⁗]` for `f(x) = log Φ(x)`,
/// written in terms of the inverse Mills ratio `λ` using `λ′ = −λ(x + λ)`.
#[inline]
pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
    let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
    let lambda2 = lambda * lambda;
    let lambda3 = lambda2 * lambda;
    let x2 = x * x;
    [
        log_cdf,
        lambda,
        -lambda * (x + lambda),
        lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
        -lambda
            * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
    ]
}
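
// Sketch (illustrative consistency check, not crate API): every entry of the
// stack above is a polynomial in (x, λ), and the Mills ratio obeys
// λ′ = −λ(x + λ). The total derivative
//   d/dx f_k(x, λ(x)) = ∂_x f_k + ∂_λ f_k · (−λ(x + λ))
// must therefore reproduce f_{k+1} at ANY (x, λ), without evaluating Φ.
// Partials are taken by central differences; the module and its functions
// are hypothetical names for this example.
#[allow(dead_code)]
mod sketch_logcdf_stack_recurrence {
    /// Entries 1..=4 of the log Φ stack written in (x, λ).
    pub fn entry(k: usize, x: f64, l: f64) -> f64 {
        match k {
            1 => l,
            2 => -l * (x + l),
            3 => l * (x * x - 1.0 + 3.0 * x * l + 2.0 * l * l),
            4 => {
                -l * ((x * x * x - 3.0 * x)
                    + (7.0 * x * x - 4.0) * l
                    + 12.0 * x * l * l
                    + 6.0 * l * l * l)
            }
            _ => panic!("entry k={k} out of range 1..=4"),
        }
    }

    /// d/dx of `entry(k)` along λ′ = −λ(x + λ), via central differences.
    pub fn total_derivative(k: usize, x: f64, l: f64) -> f64 {
        let h = 1e-5;
        let d_x = (entry(k, x + h, l) - entry(k, x - h, l)) / (2.0 * h);
        let d_l = (entry(k, x, l + h) - entry(k, x, l - h)) / (2.0 * h);
        d_x + d_l * (-l * (x + l))
    }

    #[test]
    fn each_entry_differentiates_to_the_next() {
        for &(x, l) in &[(-3.0, 3.3), (0.0, 0.8), (1.7, 0.1), (0.4, 2.5)] {
            for k in 1..4 {
                let want = entry(k + 1, x, l);
                let got = total_derivative(k, x, l);
                assert!((got - want).abs() <= 1e-6 * want.abs().max(1.0), "k={k}");
            }
        }
    }
}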

/// Stable derivative stack `[f, f′, f″, f‴, f⁗]` for `f(x) = log(1 − exp(−x))`,
/// `x > 0`, written in terms of `r = 1/(eˣ − 1)` using `r′ = −r(1 + r)`.
#[inline]
pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
    let r = 1.0 / x.exp_m1();
    [
        log1mexp_positive(x),
        r,
        -r * (1.0 + r),
        r * (1.0 + r) * (1.0 + 2.0 * r),
        -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
    ]
}
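
// Sketch (illustrative, not crate API): a finite-difference spot check of the
// r-form above. With f(x) = ln(1 − e^{−x}) and r = 1/(e^x − 1), the stack
// claims f′ = r and f″ = −r(1 + r); central differences of f agree to the
// expected O(h²) accuracy. The module and function names are hypothetical.
#[allow(dead_code)]
mod sketch_log1mexp_stack_fd {
    /// f(x) = ln(1 − e^{−x}) in the small-x-accurate form.
    pub fn f(x: f64) -> f64 {
        (-(-x).exp_m1()).ln()
    }

    /// (f′, f″) from the closed forms in r.
    pub fn closed_form(x: f64) -> (f64, f64) {
        let r = 1.0 / x.exp_m1();
        (r, -r * (1.0 + r))
    }

    /// (f′, f″) from central differences with step `h`.
    pub fn central_differences(x: f64, h: f64) -> (f64, f64) {
        let (fp, f0, fm) = (f(x + h), f(x), f(x - h));
        ((fp - fm) / (2.0 * h), (fp - 2.0 * f0 + fm) / (h * h))
    }

    #[test]
    fn closed_form_matches_finite_differences() {
        for &x in &[0.3, 0.7, 3.0] {
            let (d1, d2) = closed_form(x);
            let (fd1, fd2) = central_differences(x, 1e-4);
            assert!((d1 - fd1).abs() <= 1e-5 * d1.abs().max(1.0));
            assert!((d2 - fd2).abs() <= 1e-4 * d2.abs().max(1.0));
        }
    }
}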

// ── The RowJet bridge oracle (CI) ─────────────────────────────────────
#[cfg(test)]
mod rowjet_bridge_tests {
    use super::*;
    use crate::jet_scalar::{JetScalar, Order2};

    /// A toy row NLL written ONCE over the [`RowJet`] bridge: a product, a sum,
    /// a subtraction, a scale and a negation, a constant, per-row `scale_rows`
    /// data, and two distinct `compose_unary_with` stacks (an exp stack and the
    /// smooth, finite-everywhere `sqrt(1 + u²)` stack). The body is generic over
    /// `R: RowJet<2>`, so the SAME source instantiates at the scalar jets and the
    /// `f64x4` lane towers. Domain guards are exercised separately, in
    /// `guard_reports_per_lane_failures`.
    struct ToyProgram {
        primaries: Vec<[f64; 2]>,
        /// Per-row CONTINUOUS auxiliary data `[cov, z, wi]` — the survival
        /// `covariance_ones` / `z_sum` / observation-weight analogues that enter
        /// the jet algebra as `.scale_rows(per_row_value)`, distinct per lane.
        aux: Vec<[f64; 3]>,
    }

    impl ToyProgram {
        /// The body uses `pack_rows` to gather the per-lane continuous data from
        /// the lane→row map and `scale_rows` to fold it in — so a 4-row batch
        /// carries four DISTINCT cov/z/wi, which the single-`f64` `scale` could not.
        fn body<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> R {
            let cov = R::pack_rows(rows, |r| self.aux[r][0]);
            let z = R::pack_rows(rows, |r| self.aux[r][1]);
            let wi = R::pack_rows(rows, |r| self.aux[r][2]);

            let a = p[0].mul(&p[1]).scale_rows(cov);
            let b = a.add(&R::constant(0.5)).sub(&p[0].scale(0.25));
            let c = b
                .compose_unary_with(|u| {
                    let e = u.exp();
                    [e, e, e, e, e]
                })
                .scale_rows(z);
            let d = c.neg().add(&p[0]);
            let e = d
                .compose_unary_with(|u| {
                    let s = (1.0 + u * u).sqrt();
                    let s3 = s * s * s;
                    let s5 = s3 * s * s;
                    let s7 = s5 * s * s;
                    [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
                })
                .scale_rows(wi);
            e.mul(&p[1]).add(&e)
        }
    }
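
    // Sketch of what `compose_unary_with` computes along a single direction
    // (illustrative; the crate's multivariate tower generalises this over index
    // tuples, and `sketch_faa_di_bruno_1d::compose` is a hypothetical name).
    // Given an inner jet (u, u′, u″, u‴, u⁗) and an outer stack
    // [f, f′, f″, f‴, f⁗] evaluated at u, Faà di Bruno gives
    //   h′ = f′u′
    //   h″ = f″u′² + f′u″
    //   h‴ = f‴u′³ + 3f″u′u″ + f′u‴
    //   h⁗ = f⁗u′⁴ + 6f‴u′²u″ + f″(3u″² + 4u′u‴) + f′u⁗
    mod sketch_faa_di_bruno_1d {
        /// Compose outer derivative stack `f` with inner jet `u` through order 4.
        pub fn compose(f: [f64; 5], u: [f64; 5]) -> [f64; 5] {
            let (u1, u2, u3, u4) = (u[1], u[2], u[3], u[4]);
            [
                f[0],
                f[1] * u1,
                f[2] * u1 * u1 + f[1] * u2,
                f[3] * u1 * u1 * u1 + 3.0 * f[2] * u1 * u2 + f[1] * u3,
                f[4] * u1.powi(4)
                    + 6.0 * f[3] * u1 * u1 * u2
                    + f[2] * (3.0 * u2 * u2 + 4.0 * u1 * u3)
                    + f[1] * u4,
            ]
        }

        #[test]
        fn square_of_square_is_x_to_the_fourth() {
            // u(x) = x², f(u) = u², so h(x) = x⁴; at x = 1: 1, 4, 12, 24, 24.
            let x = 1.0;
            let u = [x * x, 2.0 * x, 2.0, 0.0, 0.0];
            let f = [u[0] * u[0], 2.0 * u[0], 2.0, 0.0, 0.0];
            assert_eq!(compose(f, u), [1.0, 4.0, 12.0, 24.0, 24.0]);
        }
    }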

    impl RowNllProgramRowJet<2> for ToyProgram {
        fn n_rows(&self) -> usize {
            self.primaries.len()
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            Ok(self.primaries[row])
        }
        fn row_nll<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> Result<R, String> {
            assert!(rows.len() == 1 || rows.len() == 4, "lane→row map is 1 or 4 wide");
            Ok(self.body(rows, p))
        }
    }

    fn assert_t4_bits_eq(a: &Tower4<2>, b: &Tower4<2>, ctx: &str) {
        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
        for i in 0..2 {
            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
            for j in 0..2 {
                assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
                for k in 0..2 {
                    assert_eq!(
                        a.t3[i][j][k].to_bits(),
                        b.t3[i][j][k].to_bits(),
                        "{ctx}: t3[{i}][{j}][{k}]"
                    );
                    for l in 0..2 {
                        assert_eq!(
                            a.t4[i][j][k][l].to_bits(),
                            b.t4[i][j][k][l].to_bits(),
                            "{ctx}: t4[{i}][{j}][{k}][{l}]"
                        );
                    }
                }
            }
        }
    }

    fn assert_t3_bits_eq(a: &Tower3<2>, b: &Tower3<2>, ctx: &str) {
        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
        for i in 0..2 {
            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
            for j in 0..2 {
                assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
                for k in 0..2 {
                    assert_eq!(
                        a.t3[i][j][k].to_bits(),
                        b.t3[i][j][k].to_bits(),
                        "{ctx}: t3[{i}][{j}][{k}]"
                    );
                }
            }
        }
    }

    /// Deterministic LCG (Knuth's MMIX multiplier and increment). `val` returns
    /// +0.0 and −0.0 with probability ≈4% each (signed-zero injection) and
    /// otherwise a uniform draw in [−2.5, 2.5), so the lanes of a batch hold
    /// distinct values with overwhelming probability.
    struct Lcg(u64);
    impl Lcg {
        /// Uniform draw in [0, 1) from the top 53 bits of the state.
        fn next(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
        fn val(&mut self) -> f64 {
            let u = self.next();
            if u < 0.04 {
                return 0.0;
            }
            if u < 0.08 {
                return -0.0;
            }
            (self.next() - 0.5) * 5.0
        }
    }

    /// Lane `i` of the batched order-4 / order-3 tower is `to_bits`-identical to
    /// the scalar tower on row `i`, for 2500 deterministic pseudo-random 4-row
    /// batches with signed-zero and per-lane-distinct primaries.
    #[test]
    fn batched_lane_i_matches_scalar_row_i_bit_identical() {
        let mut rng = Lcg(0xA5A5_1234_DEAD_BEEF);
        let mut batches = 0usize;
        for _ in 0..2500 {
            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
            // per-lane-DISTINCT continuous aux (cov/z/wi), signed-zero injected.
            let aux: [[f64; 3]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
            let prog = ToyProgram { primaries: bases.to_vec(), aux: aux.to_vec() };
            let rows = [0usize, 1, 2, 3];

            // order-4 batch vs scalar Tower4 (instantiated through the same body).
            let batch4 = generic_batched_fourth_tower(&prog, rows).expect("batch4");
            for (row, base) in bases.iter().enumerate() {
                let vars: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
                let scal = prog.row_nll(&[row], &vars).expect("scalar tower4");
                assert_t4_bits_eq(&batch4.lane(row), &scal, "batched_fourth");
            }

            // order-3 batch vs scalar Tower3.
            let batch3 = generic_batched_third_tower(&prog, rows).expect("batch3");
            for (row, base) in bases.iter().enumerate() {
                let vars: [Tower3<2>; 2] =
                    std::array::from_fn(|a| <Tower3<2> as RowJet<2>>::variable(base[a], a));
                let scal = prog.row_nll(&[row], &vars).expect("scalar tower3");
                assert_t3_bits_eq(&batch3.lane(row), &scal, "batched_third");
            }
            batches += 1;
        }
        assert_eq!(batches, 2500);
    }

    /// The blanket impl does not churn the scalar path: the body driven through
    /// `RowJet` ops is `to_bits`-identical to the body driven directly through
    /// `JetScalar` ops, and `rowjet_row_kernel`'s `(v, g, H)` matches the dense
    /// `Tower4` lower channels.
    #[test]
    fn blanket_scalar_path_is_unchanged_and_consistent() {
        let mut rng = Lcg(0x0BAD_F00D_1357_2468);
        for _ in 0..3000 {
            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
            let aux0: [f64; 3] = std::array::from_fn(|_| rng.val());
            let prog = ToyProgram { primaries: vec![base], aux: vec![aux0] };

            // (a) RowJet-driven body == JetScalar-driven body, bit-for-bit. The
            // reference body uses `scale(f64)` where the RowJet body uses
            // `scale_rows(f64)` — proving the scalar `scale_rows` rewrite does not
            // churn the path (`scale_rows(s) == scale(s)` on `Value = f64`).
            let via_rowjet: Tower4<2> = {
                let vars: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
                prog.row_nll(&[0], &vars).expect("rowjet")
            };
            let via_jetscalar: Tower4<2> = {
                let vars: [Tower4<2>; 2] = std::array::from_fn(|a| {
                    <Tower4<2> as JetScalar<2>>::variable(base[a], a)
                });
                let (cov, z, wi) = (aux0[0], aux0[1], aux0[2]);
                // The body using JetScalar's own ops + scale(f64) directly.
                let a = vars[0].mul(&vars[1]).scale(cov);
                let b = a.add(&Tower4::constant(0.5)).sub(&vars[0].scale(0.25));
                let c = b
                    .compose_unary_with(|u| {
                        let e = u.exp();
                        [e, e, e, e, e]
                    })
                    .scale(z);
                let d = JetScalar::neg(&c).add(&vars[0]);
                let e = d
                    .compose_unary_with(|u| {
                        let s = (1.0 + u * u).sqrt();
                        let s3 = s * s * s;
                        let s5 = s3 * s * s;
                        let s7 = s5 * s * s;
                        [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
                    })
                    .scale(wi);
                e.mul(&vars[1]).add(&e)
            };
            assert_t4_bits_eq(&via_rowjet, &via_jetscalar, "blanket_vs_direct");

            // (b) rowjet_row_kernel (v,g,H) == dense Tower4 lower channels.
            // Order2 and Tower4 use different internal representations so
            // signed-zero differences (−0.0 vs +0.0) may arise in gradient/
            // Hessian channels that evaluate to exactly zero; IEEE equality
            // treats these as equal, so `==` is the right comparison here.
            let (v, g, h) = rowjet_row_kernel(&prog, 0).expect("kernel");
            assert_eq!(v.to_bits(), via_rowjet.v.to_bits(), "kernel v");
            for i in 0..2 {
                assert!(g[i] == via_rowjet.g[i], "kernel g[{i}]: {} vs {}", g[i], via_rowjet.g[i]);
                for j in 0..2 {
                    assert!(
                        h[i][j] == via_rowjet.h[i][j],
                        "kernel h[{i}][{j}]: {} vs {}",
                        h[i][j],
                        via_rowjet.h[i][j]
                    );
                }
            }

            // (c) the Order2 scalar IS a RowJet via the blanket.
            let o2: [Order2<2>; 2] =
                std::array::from_fn(|a| <Order2<2> as RowJet<2>>::variable(base[a], a));
            let via_order2 = prog.body(&[0], &o2);
            assert_eq!(
                via_order2.0.v.to_bits(),
                via_rowjet.v.to_bits(),
                "Order2 blanket value channel must match the dense Tower4 program body"
            );
        }
    }
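
    // Sketch (illustrative; `sketch_bits_vs_ieee_equality::compare` is a
    // hypothetical helper): why part (b) above uses `==` while the other checks
    // use `to_bits`. IEEE-754 equality treats +0.0 and −0.0 as equal although
    // their bit patterns differ, and NaN is unequal to itself under `==` while
    // `to_bits` compares it by bit pattern.
    mod sketch_bits_vs_ieee_equality {
        /// (IEEE-equal, bit-identical) for a pair of floats.
        pub fn compare(a: f64, b: f64) -> (bool, bool) {
            (a == b, a.to_bits() == b.to_bits())
        }

        #[test]
        fn signed_zero_and_nan() {
            assert_eq!(compare(0.0, -0.0), (true, false));
            assert_eq!(compare(f64::NAN, f64::NAN), (false, true));
            assert_eq!(compare(1.5, 1.5), (true, true));
        }
    }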

    /// On the scalar path (`Value = f64`) `scale_rows(s)` is `to_bits`-identical
    /// to `scale(s)` for EVERY channel — so rewriting a survival `.scale(per_row)`
    /// to `.scale_rows(per_row)` cannot perturb the existing scalar fits.
    #[test]
    fn scale_rows_scalar_is_bit_identical_to_scale() {
        let mut rng = Lcg(0xFEED_FACE_0042_1001);
        for _ in 0..3000 {
            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
            let s = rng.val();
            // Build a dense tower with populated channels (exp of a product).
            let vars: [Tower4<2>; 2] =
                std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
            let jet = vars[0].mul(&vars[1]).compose_unary_with(|u| {
                let e = u.exp();
                [e, e, e, e, e]
            });
            let via_scale = RowJet::scale(&jet, s);
            let via_scale_rows = RowJet::scale_rows(&jet, s);
            assert_t4_bits_eq(&via_scale_rows, &via_scale, "scale_rows==scale");
        }
    }

    /// `scale_rows` on a batch multiplies lane `i` by `s[i]`, so lane `i` of a
    /// per-lane-scaled batch matches the scalar `scale(s[i])` on row `i` — the
    /// continuous per-row data path the single-`f64` `scale` could not carry.
    #[test]
    fn batched_scale_rows_matches_per_row_scalar_scale() {
        let mut rng = Lcg(0x1357_9BDF_2468_ACE0);
        for _ in 0..2500 {
            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
            let s: [f64; 4] = std::array::from_fn(|_| rng.val());
            let batch: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
                Tower4Batch::variable(
                    wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]),
                    a,
                )
            });
            let prod = batch[0].mul(&batch[1]).compose_unary_with(|u| {
                let e = u.exp();
                [e, e, e, e, e]
            });
            let scaled = prod.scale_rows(s);
            for (row, base) in bases.iter().enumerate() {
                let v: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
                let prod_s = v[0].mul(&v[1]).compose_unary_with(|u| {
                    let e = u.exp();
                    [e, e, e, e, e]
                });
                let ref_s = RowJet::scale(&prod_s, s[row]);
                assert_t4_bits_eq(&scaled.lane(row), &ref_s, "batched_scale_rows");
            }
        }
    }

    /// The per-lane guard reports exactly the failing lanes on a batch and the
    /// single lane on a scalar jet.
    #[test]
    fn guard_reports_per_lane_failures() {
        // Lane 3 holds −0.0, which fails `v > 0.0` just like the negative lane 1.
        let cols: [[f64; 2]; 4] = [[1.0, 0.5], [-2.0, 0.5], [3.0, 0.5], [-0.0, 0.5]];
        let vars: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
            Tower4Batch::variable(
                wide::f64x4::new([cols[0][a], cols[1][a], cols[2][a], cols[3][a]]),
                a,
            )
        });
        let verdict = vars[0].guard(|v| v > 0.0);
        assert_eq!(verdict.lanes(), 4);
        assert!(verdict.any_failed());
        assert!(!verdict.all_pass());
        assert!(!verdict.lane_failed(0));
        assert!(verdict.lane_failed(1));
        assert!(!verdict.lane_failed(2));
        assert!(verdict.lane_failed(3));
        assert_eq!(verdict.failed_mask(), 0b1010);

        let s_ok = <Tower4<2> as RowJet<2>>::variable(1.0, 0);
        let s_bad = <Tower4<2> as RowJet<2>>::variable(-1.0, 0);
        assert!(RowJet::guard(&s_ok, |v| v > 0.0).all_pass());
        assert!(RowJet::guard(&s_bad, |v| v > 0.0).any_failed());
        assert_eq!(RowJet::guard(&s_ok, |v| v > 0.0).lanes(), 1);
    }
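
    // Sketch of the bookkeeping behind `failed_mask`, assuming (as the
    // assertion `failed_mask() == 0b1010` above suggests) that bit i is set
    // exactly when lane i fails the predicate. `sketch_lane_mask::failed_mask`
    // is hypothetical and works on a plain `[f64; 4]` rather than `wide::f64x4`.
    mod sketch_lane_mask {
        pub fn failed_mask(lanes: [f64; 4], pass: impl Fn(f64) -> bool) -> u8 {
            lanes
                .iter()
                .enumerate()
                .filter(|&(_, &v)| !pass(v))
                .fold(0u8, |mask, (i, _)| mask | (1u8 << i))
        }

        #[test]
        fn negative_and_signed_zero_lanes_fail() {
            // −0.0 > 0.0 is false, so lane 3 fails like lane 1.
            assert_eq!(failed_mask([1.0, -2.0, 3.0, -0.0], |v| v > 0.0), 0b1010);
        }
    }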

    // ── ln_gamma_derivative_stack / digamma_derivative_stack / trigamma_derivative_stack ──

    #[test]
    fn ln_gamma_derivative_stack_known_values_at_1() {
        let s = ln_gamma_derivative_stack(1.0);
        // ln Γ(1) = 0; statrs uses Lanczos so the result is within ULP noise
        assert!(s[0].abs() < 1e-14, "ln_gamma(1) must be ~0, got {}", s[0]);
        // ψ₀(1) = -γ  (Euler–Mascheroni)
        let euler_mascheroni = 0.577_215_664_901_532_9_f64;
        assert!(
            (s[1] + euler_mascheroni).abs() < 1e-10,
            "digamma(1) ≈ -{euler_mascheroni:.6}, got {}",
            s[1]
        );
        // ψ₁(1) = π²/6
        let pi2_6 = std::f64::consts::PI * std::f64::consts::PI / 6.0;
        assert!(
            (s[2] - pi2_6).abs() < 1e-10,
            "trigamma(1) ≈ {pi2_6:.6}, got {}",
            s[2]
        );
    }

    #[test]
    fn ln_gamma_derivative_stack_known_values_at_2() {
        let s = ln_gamma_derivative_stack(2.0);
        // ln Γ(2) = ln(1) = 0 exactly
        assert!(s[0].abs() < 1e-14, "ln_gamma(2) must be 0, got {}", s[0]);
        // ψ₀(2) = 1 − γ (recurrence: ψ₀(x+1) = ψ₀(x) + 1/x)
        let euler_mascheroni = 0.577_215_664_901_532_9_f64;
        let digamma_2 = 1.0 - euler_mascheroni;
        assert!(
            (s[1] - digamma_2).abs() < 1e-10,
            "digamma(2) ≈ {digamma_2:.6}, got {}",
            s[1]
        );
    }
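
    // Sketch (an independent reference, not how `ln_gamma_derivative_stack`
    // computes ψ₀): shift x upward with ψ₀(x) = ψ₀(x + 1) − 1/x until x ≥ 12,
    // then apply the asymptotic series
    //   ψ₀(x) ≈ ln x − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶),
    // whose first omitted term is below 1e−11 at x = 12. This reproduces both
    // ψ₀(1) = −γ and ψ₀(2) = 1 − γ. `sketch_digamma_recurrence` is hypothetical.
    mod sketch_digamma_recurrence {
        pub fn digamma(mut x: f64) -> f64 {
            assert!(x > 0.0, "sketch handles x > 0 only");
            // Upward recurrence: accumulate −1/x for each unit shift.
            let mut shift = 0.0;
            while x < 12.0 {
                shift -= 1.0 / x;
                x += 1.0;
            }
            // Asymptotic series, truncated after the x^(−6) term.
            let inv2 = 1.0 / (x * x);
            shift + x.ln() - 0.5 / x - inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))
        }

        #[test]
        fn reproduces_known_values() {
            let euler_mascheroni = 0.577_215_664_901_532_9_f64;
            assert!((digamma(1.0) + euler_mascheroni).abs() < 1e-10);
            assert!((digamma(2.0) - (1.0 - euler_mascheroni)).abs() < 1e-10);
        }
    }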

    #[test]
    fn ln_gamma_derivative_stack_order2_is_prefix() {
        for &x in &[0.5_f64, 1.0, 2.0, 5.0] {
            let full = ln_gamma_derivative_stack(x);
            let ord2 = ln_gamma_derivative_stack_order2(x);
            assert_eq!(ord2[0], full[0], "order2[0] != full[0] at x={x}");
            assert_eq!(ord2[1], full[1], "order2[1] != full[1] at x={x}");
            assert_eq!(ord2[2], full[2], "order2[2] != full[2] at x={x}");
        }
    }

    #[test]
    fn digamma_derivative_stack_overlaps_ln_gamma_stack() {
        // The two stacks share a run of four polygamma values (ψ₀ through ψ₃):
        // ln_gamma_stack[1..5] == digamma_stack[0..4]
        for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
            let lg = ln_gamma_derivative_stack(x);
            let dg = digamma_derivative_stack(x);
            for i in 0..4 {
                assert_eq!(
                    lg[i + 1], dg[i],
                    "ln_gamma_stack[{}] != digamma_stack[{}] at x={x}",
                    i + 1,
                    i
                );
            }
        }
    }

    #[test]
    fn trigamma_derivative_stack_overlaps_digamma_stack() {
        // The two stacks share ψ₁ through ψ₄:
        // digamma_stack[1..5] == trigamma_stack[0..4]
        for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
            let dg = digamma_derivative_stack(x);
            let tg = trigamma_derivative_stack(x);
            for i in 0..4 {
                assert_eq!(
                    dg[i + 1], tg[i],
                    "digamma_stack[{}] != trigamma_stack[{}] at x={x}",
                    i + 1,
                    i
                );
            }
        }
    }

    #[test]
    fn derivative_stacks_all_finite_at_positive_inputs() {
        for &x in &[0.01_f64, 0.5, 1.0, 2.0, 10.0, 100.0] {
            for v in ln_gamma_derivative_stack(x) {
                assert!(v.is_finite(), "ln_gamma_stack non-finite at x={x}: {v}");
            }
            for v in digamma_derivative_stack(x) {
                assert!(v.is_finite(), "digamma_stack non-finite at x={x}: {v}");
            }
            for v in trigamma_derivative_stack(x) {
                assert!(v.is_finite(), "trigamma_stack non-finite at x={x}: {v}");
            }
        }
    }
}

// ── Contraction-symmetry optimization gate ────────────────────────────────────
//
// `Tower4::third_contracted` / `fourth_contracted` contract the (fully
// index-symmetric) `t3`/`t4` tensors against directions, leaving the output
// indices `(a, b)` / `(i, j)` free. Those free indices inherit the tensor's
// symmetry (`out[a][b] == out[b][a]` term-for-term), so only the upper triangle
// needs to be summed; the lower triangle is then mirrored. Unlike the dense
// symmetric FILL (which needs a K⁴ scatter, loses inner-loop vectorisation, and
// was measured SLOWER), the mirror here is a tiny K×K copy and the inner
// contraction is untouched (contiguous, vectorisable). The result is
// BIT-IDENTICAL to the full nest: for a symmetric tensor, `t3[b][a][c]` holds
// the same f64 as `t3[a][b][c]`, and the untouched inner loop visits the
// contracted indices in the same order, so `out[b][a]` would be the same sum of
// the same products in the same order. No fingerprint re-baseline is therefore
// needed. The gate is (1) bit-identity against the full reference and (2) a
// measured wall-clock that is not slower (reported, with the test asserting
// only against a pathological regression).
#[cfg(test)]
mod contraction_symmetry_tests {
    use super::*;

    /// Deterministic test-only RNG: a 64-bit LCG with Knuth's MMIX constants.
    struct Rng(u64);
    impl Rng {
        /// Uniform in `[0, 1)`, built from the top 53 bits of the state.
        fn u(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
        /// Uniform in `[-2, 2)`.
        fn s(&mut self) -> f64 {
            (self.u() - 0.5) * 4.0
        }
    }

    /// Random VALID fully-symmetric `Tower4<K>` (symmetric `h`/`t3`/`t4`): one
    /// value is drawn per sorted index tuple `a ≤ b ≤ c ≤ d` and written to every
    /// permutation of that tuple.
    fn rand_sym4<const K: usize>(r: &mut Rng) -> Tower4<K> {
        let mut t = Tower4::<K>::zero();
        t.v = r.s();
        for i in 0..K {
            t.g[i] = r.s();
        }
        for a in 0..K {
            for b in a..K {
                let v2 = r.s();
                t.h[a][b] = v2;
                t.h[b][a] = v2;
                for c in b..K {
                    let v3 = r.s();
                    for p in perms3([a, b, c]) {
                        t.t3[p[0]][p[1]][p[2]] = v3;
                    }
                    for d in c..K {
                        let v4 = r.s();
                        for p in perms4([a, b, c, d]) {
                            t.t4[p[0]][p[1]][p[2]][p[3]] = v4;
                        }
                    }
                }
            }
        }
        t
    }

    fn perms3(idx: [usize; 3]) -> [[usize; 3]; 6] {
        let [a, b, c] = idx;
        [[a, b, c], [a, c, b], [b, a, c], [b, c, a], [c, a, b], [c, b, a]]
    }
    fn perms4(idx: [usize; 4]) -> [[usize; 4]; 24] {
        let [a, b, c, d] = idx;
        [
            [a, b, c, d], [a, b, d, c], [a, c, b, d], [a, c, d, b], [a, d, b, c], [a, d, c, b],
            [b, a, c, d], [b, a, d, c], [b, c, a, d], [b, c, d, a], [b, d, a, c], [b, d, c, a],
            [c, a, b, d], [c, a, d, b], [c, b, a, d], [c, b, d, a], [c, d, a, b], [c, d, b, a],
            [d, a, b, c], [d, a, c, b], [d, b, a, c], [d, b, c, a], [d, c, a, b], [d, c, b, a],
        ]
    }

    /// Full-nest reference for the third-order contraction (the pre-optimization
    /// `a, b ∈ 0..K` form).
    fn third_full<const K: usize>(t: &Tower4<K>, dir: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in 0..K {
                let mut acc = 0.0;
                for c in 0..K {
                    acc += t.t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
            }
        }
        out
    }
    /// Full-nest reference for the fourth-order contraction (the pre-optimization
    /// `i, j ∈ 0..K` form).
    fn fourth_full<const K: usize>(t: &Tower4<K>, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                for k in 0..K {
                    for l in 0..K {
                        acc += t.t4[i][j][k][l] * u[k] * w[l];
                    }
                }
                out[i][j] = acc;
            }
        }
        out
    }
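
    // Illustrative sketch (a hypothetical helper written for this module, not the
    // crate's `Tower4::third_contracted`, whose body is not reproduced here) of
    // the upper-triangle-plus-mirror contraction described in the gate comment
    // above. It works on a bare `[[[f64; K]; K]; K]`, so it depends only on `std`.
    // Each `out[a][b]` with `b >= a` runs the same `c`-loop, in the same order, as
    // `third_full`; for a symmetric tensor `t3[b][a][c]` is the same f64 as
    // `t3[a][b][c]`, so mirroring `out[a][b]` into `out[b][a]` yields exactly the
    // bits the full nest would compute there.
    fn third_upper_mirror_sketch<const K: usize>(
        t3: &[[[f64; K]; K]; K],
        dir: &[f64; K],
    ) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            // Upper triangle only, diagonal included: K(K+1)/2 inner contractions.
            for b in a..K {
                // Inner contraction unchanged: contiguous over `c`.
                let mut acc = 0.0f64;
                for c in 0..K {
                    acc += t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
                // Mirror into the lower triangle (a no-op overwrite when a == b).
                out[b][a] = acc;
            }
        }
        out
    }

    #[test]
    fn third_upper_mirror_sketch_matches_full_nest_bitwise() {
        const K: usize = 3;
        // Deterministic symmetric t3: each entry depends only on the sorted
        // index multiset, so every permutation of (a, b, c) holds the same f64.
        let mut t3 = [[[0.0f64; K]; K]; K];
        for a in 0..K {
            for b in 0..K {
                for c in 0..K {
                    let mut idx = [a, b, c];
                    idx.sort_unstable();
                    t3[a][b][c] = 0.1 + (idx[0] * 9 + idx[1] * 3 + idx[2]) as f64 / 7.0;
                }
            }
        }
        let dir = [0.3, -1.7, 2.9];
        let sym = third_upper_mirror_sketch(&t3, &dir);
        // Compare every entry against the full a, b ∈ 0..K nest, bit for bit.
        for a in 0..K {
            for b in 0..K {
                let mut acc = 0.0f64;
                for c in 0..K {
                    acc += t3[a][b][c] * dir[c];
                }
                assert_eq!(sym[a][b].to_bits(), acc.to_bits(), "[{a}][{b}]");
            }
        }
    }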

    /// Returns the number of bit-equality comparisons performed (`n·K·K·2`).
    /// The assertions live inside a generic helper called via turbofish, so a
    /// run in which the loops silently did nothing would also pass; returning
    /// the count lets the call site assert that the intended workload ran.
    fn check_bit_identical<const K: usize>(seed: u64, n: usize) -> usize {
        let mut r = Rng(seed);
        let mut checks = 0usize;
        for _ in 0..n {
            let t = rand_sym4::<K>(&mut r);
            let dir: [f64; K] = std::array::from_fn(|_| r.s());
            let u: [f64; K] = std::array::from_fn(|_| r.s());
            let w: [f64; K] = std::array::from_fn(|_| r.s());
            let t3_sym = t.third_contracted(&dir);
            let t3_full = third_full(&t, &dir);
            let t4_sym = t.fourth_contracted(&u, &w);
            let t4_full = fourth_full(&t, &u, &w);
            for a in 0..K {
                for b in 0..K {
                    assert_eq!(
                        t3_sym[a][b].to_bits(),
                        t3_full[a][b].to_bits(),
                        "third K={K} [{a}][{b}]"
                    );
                    assert_eq!(
                        t4_sym[a][b].to_bits(),
                        t4_full[a][b].to_bits(),
                        "fourth K={K} [{a}][{b}]"
                    );
                    checks += 2;
                }
            }
        }
        checks
    }

    /// The output-symmetric contraction is BIT-IDENTICAL to the full nest across
    /// `K ∈ {2,3,4,9}`, so no fingerprint re-baseline is owed: accuracy and bits
    /// are unchanged, and this is a speed-only optimization.
    #[test]
    fn contraction_symmetry_is_bit_identical_to_full_nest() {
        let checks = check_bit_identical::<2>(0x0000_0002_C0FF_EE01, 1000)
            + check_bit_identical::<3>(0x0000_0003_C0FF_EE01, 800)
            + check_bit_identical::<4>(0x0000_0004_C0FF_EE01, 600)
            + check_bit_identical::<9>(0x0000_0009_C0FF_EE01, 300);
        // Guards against the loops silently not running (e.g. a zeroed count):
        // 1000·2²·2 + 800·3²·2 + 600·4²·2 + 300·9²·2.
        assert_eq!(checks, 8000 + 14400 + 19200 + 48600);
    }

    /// Measures the wall-clock of the output-symmetric contraction against the
    /// full nest at `K = 9`, where the symmetric form computes only the upper
    /// triangle of each output (45 of the 81 entries, counting the diagonal).
    /// The bit-identity test is the correctness gate. This one is informational,
    /// because wall-clock is noisy, and guards only against a PATHOLOGICAL
    /// regression: doing strictly fewer inner contractions, the symmetric form
    /// must not be materially slower (the assertion allows up to 1.5×).
    #[test]
    fn contraction_symmetry_speedup_is_reported() {
        const K: usize = 9;
        let mut r = Rng(0xC0FF_EE99_1234_5678);
        let towers: Vec<Tower4<K>> = (0..512).map(|_| rand_sym4::<K>(&mut r)).collect();
        let dir: [f64; K] = std::array::from_fn(|_| r.s());
        let u: [f64; K] = std::array::from_fn(|_| r.s());
        let w: [f64; K] = std::array::from_fn(|_| r.s());

        let reps = 400usize;
        let t_sym = {
            let start = std::time::Instant::now();
            let mut sink = 0.0f64;
            for _ in 0..reps {
                for t in &towers {
                    let o3 = std::hint::black_box(t).third_contracted(std::hint::black_box(&dir));
                    let o4 = std::hint::black_box(t)
                        .fourth_contracted(std::hint::black_box(&u), std::hint::black_box(&w));
                    sink += o3[0][K - 1] + o4[0][K - 1];
                }
            }
            std::hint::black_box(sink);
            start.elapsed().as_secs_f64()
        };
        let t_full = {
            let start = std::time::Instant::now();
            let mut sink = 0.0f64;
            for _ in 0..reps {
                for t in &towers {
                    let o3 = third_full(std::hint::black_box(t), std::hint::black_box(&dir));
                    let o4 = fourth_full(
                        std::hint::black_box(t),
                        std::hint::black_box(&u),
                        std::hint::black_box(&w),
                    );
                    sink += o3[0][K - 1] + o4[0][K - 1];
                }
            }
            std::hint::black_box(sink);
            start.elapsed().as_secs_f64()
        };
        let calls = (reps * towers.len()) as f64;
        eprintln!(
            "[contraction-symmetry speedup K=9] sym={:.1}ns/call full={:.1}ns/call \
             wall_speedup={:.2}x",
            t_sym / calls * 1e9,
            t_full / calls * 1e9,
            t_full / t_sym
        );
        assert!(
            t_sym <= t_full * 1.5,
            "output-symmetric contraction pathologically slower: \
             sym={t_sym:.4}s full={t_full:.4}s"
        );
    }
}