Skip to main content

gam_math/
jet_tower.rs

//! Taylor-jet tower algebra: write each family's row log-likelihood ONCE,
//! derive the entire `RowKernel<K>` derivative tower mechanically (#932).
//!
//! # The object
//!
//! [`Tower4<K>`] is a truncated multivariate Taylor scalar in `K` primary
//! variables, carrying the value and ALL partial derivatives through fourth
//! order as full (unsymmetrized) tensors:
//!
//! ```text
//!   v              ℓ
//!   g[a]           ∂ℓ/∂p_a
//!   h[a][b]        ∂²ℓ/∂p_a∂p_b
//!   t3[a][b][c]    ∂³ℓ/∂p_a∂p_b∂p_c
//!   t4[a][b][c][d] ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d
//! ```
//!
//! Arithmetic (`+ − × ÷`, scalar mixes) propagates the tower by the exact
//! Leibniz rule; unary transcendentals propagate by the exact multivariate
//! Faà di Bruno formula given a `[f, f′, f″, f‴, f⁗]` stack evaluated at the
//! inner value. This is truncated Taylor ALGEBRA — exact derivatives of the
//! evaluated expression, not finite differences, not an approximation —
//! fully compatible with the exact-REML-only policy.
//!
//! One evaluation of a row NLL program at seeded variables yields, in a
//! single pass, every channel the [`super::row_kernel::RowKernel`] trait
//! demands: `row_kernel` (value/∇/H), `row_third_contracted(dir)` (contract
//! `t3` with `dir`), and `row_fourth_contracted(u, v)` (contract `t4` with
//! `u` and `v`). The directional cross-channels that hand-written towers
//! drop (#736's residual gap) cannot be dropped here: there is no separate
//! "channel" to forget — every derivative of the one expression is carried.
//!
//! # Why this exists (the bug genus)
//!
//! Every family today hand-writes its tower: value in one function,
//! gradient in another, `pdfthird_derivative`/`pdffourth_derivative`,
//! entry/exit-specific cross blocks — thousands of lines of calculus that
//! drift. #736 was a sign flip in a hand-written cross-Hessian block,
//! invisible until a new consumer touched it; #948 is a derivative path
//! that is not the derivative of the evaluated row loss (clamped-μ
//! surrogate); the objective↔gradient desync class is the same disease at
//! the criterion level. A tower-derived kernel is exact by construction:
//! the value channel IS the production loss expression, so its derivative
//! channels cannot desync from it.
//!
//! # Relation to `jet_partitions::MultiDirJet`
//!
//! The tree already carries a *directional* jet (bitmask coefficients over
//! distinct seeded directions, heap-allocated, Bell-partition compose) used
//! inside the marginal-slope and latent-survival families. It answers "the
//! derivative along THESE specific directions" and must be re-seeded and
//! re-evaluated per direction tuple (e.g. 10 symmetric `(a,b)` pairs for a
//! K=4 fourth contraction). `Tower4` answers ALL of them from one
//! evaluation: contraction happens AFTER differentiation, as plain linear
//! algebra on the stored tensors. Use `MultiDirJet` when you need a handful
//! of directions of a huge-K expression; use `Tower4` when you need the
//! complete small-K tower — which is exactly the `RowKernel<K≤4>` shape.
//! The `[f64; 5]` unary-derivative stacks
//! (`unary_derivatives_neglog_phi`, …) are signature-compatible with
//! [`Tower4::compose_unary`], so the families' existing special-function
//! stacks are directly reusable.
//!
//! # Stability discipline (why this is NOT autodiff)
//!
//! Differentiating the primal code path inherits its instabilities: a jet
//! pushed through a naive `ln(1 + e^η)` is garbage in the saturated tail
//! even though the true derivative σ(η) is benign there. This module
//! therefore splits responsibility: **humans own primitive stability,
//! the algebra owns combinatorics**. Tail-critical special functions enter
//! a program ONLY as hand-certified `[f64; 5]` derivative stacks through
//! [`Tower4::compose_unary`] — the same stacks the families already write
//! (`unary_derivatives_neglog_phi` and friends, built on erfcx/log_ndtr) —
//! and the tower mechanizes only the Leibniz/Faà di Bruno composition,
//! which is where hand-written towers actually fail (#736 was a
//! composition sign flip, not a primitive error). Program authors must use
//! a stable primitive stack wherever the f64 production loss does; the
//! convenience methods (`exp`, `ln`, `sqrt`, …) are for expressions whose
//! arguments are tame by construction.
//!
//! # Storage convention
//!
//! Tensors are stored FULL, not symmetric-packed: `t4` for K=4 is 256
//! doubles where 35 would do. This is a deliberate choice of clarity over
//! speed for the oracle role — indexing is trivially auditable, contraction
//! loops are obvious, and the redundancy is itself a checked invariant (the
//! algebra only ever writes symmetric values). Symmetric packing is a later,
//! profile-justified optimization behind the same API.
//!
//! # Deployment ladder (#932)
//!
//! 1. This module: the algebra + the program seam + the oracle.
//! 2. Universal oracle: every hand-written `RowKernel` gains a CI test
//!    asserting channel-by-channel agreement with a `RowNllProgram` written
//!    once — see [`verify_kernel_channels`]. This alone would have caught
//!    #736 at introduction.
//! 3. Migrate error-dense / cold towers to [`derived_row_kernel`] et al.;
//!    keep hand-tuned hot paths, now verified against the single-expression
//!    truth instead of being the only definition.
//! 4. New families (#914/#916/#917 ZI/ordinal/expectile, #921's location-
//!    scale port) implement ONLY `RowNllProgram` and get an exact
//!    fourth-order tower for the price of writing the likelihood.
102
103use crate::jet_algebra;
104
/// Truncated fourth-order multivariate Taylor scalar in `K` variables.
///
/// See the module documentation for semantics and conventions.
///
/// `Copy` is intentional despite the size. A tower holds
/// `1 + K + K² + K³ + K⁴` doubles:
/// - 341 × 8 B ≈ 2.7 KiB at K=4;
/// - 7381 × 8 B ≈ 58 KiB at K=9.
///
/// Towers are per-row temporaries that live on the stack during a row
/// program, and value semantics keep program code readable (`a * b + c`).
/// Every by-value move copies the full tower, which matters at the larger
/// widths.
#[derive(Clone, Copy, Debug)]
pub struct Tower4<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
    pub t3: [[[f64; K]; K]; K],
    /// Fourth derivatives ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d (fully symmetric).
    pub t4: [[[[f64; K]; K]; K]; K],
}
124
125impl<const K: usize> Tower4<K> {
126    /// The additive identity.
127    pub fn zero() -> Self {
128        Self {
129            v: 0.0,
130            g: [0.0; K],
131            h: [[0.0; K]; K],
132            t3: [[[0.0; K]; K]; K],
133            t4: [[[[0.0; K]; K]; K]; K],
134        }
135    }
136
137    /// A constant: value `c`, all derivatives zero.
138    pub fn constant(c: f64) -> Self {
139        let mut out = Self::zero();
140        out.v = c;
141        out
142    }
143
144    /// The seeded variable `p_idx` with current value `value`:
145    /// unit first derivative in slot `idx`, zero elsewhere and above.
146    pub fn variable(value: f64, idx: usize) -> Self {
147        let mut out = Self::constant(value);
148        out.g[idx] = 1.0;
149        out
150    }
151
152    /// Read the (fully symmetric) derivative tensor entry whose differentiation
153    /// axes are `labels` (length 0..=4): value, `g`, `h`, `t3`, `t4`.
154    #[inline]
155    fn deriv(&self, labels: &[usize]) -> f64 {
156        assert!(
157            labels.len() <= 4,
158            "Tower4 carries at most fourth-order derivatives"
159        );
160        match labels.len() {
161            0 => self.v,
162            1 => self.g[labels[0]],
163            2 => self.h[labels[0]][labels[1]],
164            3 => self.t3[labels[0]][labels[1]][labels[2]],
165            _ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
166        }
167    }
168
169    /// Exact truncated Leibniz product `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`.
170    ///
171    /// # Codegen
172    ///
173    /// Each output entry's `2^m` subset sum is written as a compact straight-line
174    /// expression instead of the shared [`jet_algebra::leibniz_product`] subset
175    /// walker (which, per entry, builds `SlotBuf`s and `match`-dispatches the
176    /// `deriv` closure across all `2^m` subsets). The loop nest over `(i,j,k,l)`
177    /// is unchanged — only the inner per-entry sum is unrolled — so this does NOT
178    /// unroll over `K` and does NOT bloat code: on a `Tower4<9>` mul-and-read
179    /// consumer the new form is faster AND smaller (asm: 34 outlined walker `bl`
180    /// calls → 0, 21.1 KiB → 14.3 KiB, +100 NEON `.2d` ops).
181    ///
182    /// BIT-IDENTICAL to the walker: each entry's terms are in the walker's exact
183    /// subset-enumeration order (subset bit `b` ↔ position `b`, `sub = 0..2^m`),
184    /// and the per-entry `acc` accumulator mirrors the walker's `total = 0.0`
185    /// start so a signed-zero leading product collapses to `+0.0` identically —
186    /// which matters because real jets carry exact-`0.0` channels
187    /// (`constant`/`variable` towers). Proven `to_bits`-identical on
188    /// `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`, 5000 inputs each with ~30 %
189    /// exact-`0.0` channels and signed values (a no-leading-`0.0` form fails this
190    /// stress — the accumulator start is load-bearing).
191    pub fn mul(&self, o: &Self) -> Self {
192        let a = self;
193        let b = o;
194        let mut out = Self::zero();
195        out.v = a.v * b.v;
196        for i in 0..K {
197            // subsets of {i}: {} {i}
198            let mut acc = 0.0;
199            acc += a.v * b.g[i];
200            acc += a.g[i] * b.v;
201            out.g[i] = acc;
202        }
203        for i in 0..K {
204            for j in 0..K {
205                // subsets of {i,j}: {} {i} {j} {ij}
206                let mut acc = 0.0;
207                acc += a.v * b.h[i][j];
208                acc += a.g[i] * b.g[j];
209                acc += a.g[j] * b.g[i];
210                acc += a.h[i][j] * b.v;
211                out.h[i][j] = acc;
212            }
213        }
214        for i in 0..K {
215            for j in 0..K {
216                for k in 0..K {
217                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
218                    let mut acc = 0.0;
219                    acc += a.v * b.t3[i][j][k];
220                    acc += a.g[i] * b.h[j][k];
221                    acc += a.g[j] * b.h[i][k];
222                    acc += a.h[i][j] * b.g[k];
223                    acc += a.g[k] * b.h[i][j];
224                    acc += a.h[i][k] * b.g[j];
225                    acc += a.h[j][k] * b.g[i];
226                    acc += a.t3[i][j][k] * b.v;
227                    out.t3[i][j][k] = acc;
228                }
229            }
230        }
231        for i in 0..K {
232            for j in 0..K {
233                for k in 0..K {
234                    for l in 0..K {
235                        // subsets of {i,j,k,l} in bit order sub = 0..16
236                        let mut acc = 0.0;
237                        acc += a.v * b.t4[i][j][k][l];
238                        acc += a.g[i] * b.t3[j][k][l];
239                        acc += a.g[j] * b.t3[i][k][l];
240                        acc += a.h[i][j] * b.h[k][l];
241                        acc += a.g[k] * b.t3[i][j][l];
242                        acc += a.h[i][k] * b.h[j][l];
243                        acc += a.h[j][k] * b.h[i][l];
244                        acc += a.t3[i][j][k] * b.g[l];
245                        acc += a.g[l] * b.t3[i][j][k];
246                        acc += a.h[i][l] * b.h[j][k];
247                        acc += a.h[j][l] * b.h[i][k];
248                        acc += a.t3[i][j][l] * b.g[k];
249                        acc += a.h[k][l] * b.h[i][j];
250                        acc += a.t3[i][k][l] * b.g[j];
251                        acc += a.t3[j][k][l] * b.g[i];
252                        acc += a.t4[i][j][k][l] * b.v;
253                        out.t4[i][j][k][l] = acc;
254                    }
255                }
256            }
257        }
258        out
259    }
260
261    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
262    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
263    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
264    /// out of the borrowed operands.
265    pub fn add(&self, o: &Self) -> Self {
266        *self + *o
267    }
268
269    /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
270    pub fn sub(&self, o: &Self) -> Self {
271        *self + o.scale(-1.0)
272    }
273
274    /// Exact multivariate Faà di Bruno composition `f ∘ self`.
275    ///
276    /// `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` evaluated at `u = self.v` —
277    /// the SAME `[f64; 5]` stack shape the families' existing
278    /// `unary_derivatives_*` helpers produce, so those special-function
279    /// stacks (Φ, log-Φ, normal pdf, …) plug in directly.
280    ///
281    /// The order-m output sums over the set partitions of the m indices
282    /// (Bell(3) = 5 terms at order 3, Bell(4) = 15 at order 4), grouped by
283    /// block count: each partition into r blocks contributes
284    /// `f⁽ʳ⁾ · Π_blocks D_block(u)`.
285    ///
286    /// # Codegen
287    ///
288    /// Evaluated as a compact closed form (the Bell(4)=15 set-partitions of
289    /// `t4`, Bell(3)=5 of `t3`, …) instead of routing through the recursive
290    /// [`jet_algebra::faa_di_bruno`] walker (per-output `for_each_partition`
291    /// recursion + per-block `SlotBuf` + closure dispatch). The loop nest is
292    /// identical to the walker's (`for i,j,k,l`); only the per-entry partition
293    /// sum is straight-line, so this does NOT unroll over `K` and does NOT
294    /// bloat code — measured on a `Tower4<9>` compose-and-read consumer the new
295    /// form is both faster and SMALLER (asm: 94 outlined walker `bl` calls → 0,
296    /// 47.5 KiB → 16.7 KiB, +197 NEON `.2d` ops).
297    ///
298    /// BIT-IDENTICAL to the walker: each channel's terms are emitted in the
299    /// walker's exact partition-enumeration order, each term's block products
300    /// are left-associated exactly as the walker's `prod *= block`, and the
301    /// per-channel `acc` accumulator mirrors the walker's `total = 0.0` start
302    /// (so signed-zero products collapse to `+0.0` identically). The order-4
303    /// term sequence was generated from the walker's own enumeration. Proven
304    /// `to_bits`-identical on `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`,
305    /// 5000 random inputs each (zeroed / sign-varied stacks included).
306    pub fn compose_unary(&self, d: [f64; 5]) -> Self {
307        let mut out = Self::zero();
308        out.v = d[0];
309        for i in 0..K {
310            let mut acc = 0.0;
311            acc += d[1] * self.g[i];
312            out.g[i] = acc;
313        }
314        for i in 0..K {
315            for j in 0..K {
316                let mut acc = 0.0;
317                acc += d[1] * self.h[i][j];
318                acc += d[2] * self.g[i] * self.g[j];
319                out.h[i][j] = acc;
320            }
321        }
322        for i in 0..K {
323            for j in 0..K {
324                for k in 0..K {
325                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
326                    let mut acc = 0.0;
327                    acc += d[1] * self.t3[i][j][k];
328                    acc += d[2] * self.h[i][j] * self.g[k];
329                    acc += d[2] * self.h[i][k] * self.g[j];
330                    acc += d[2] * self.g[i] * self.h[j][k];
331                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
332                    out.t3[i][j][k] = acc;
333                }
334            }
335        }
336        for i in 0..K {
337            for j in 0..K {
338                for k in 0..K {
339                    for l in 0..K {
340                        // Bell(4)=15 partitions, walker enumeration order.
341                        let mut acc = 0.0;
342                        acc += d[1] * self.t4[i][j][k][l];
343                        acc += d[2] * self.t3[i][j][k] * self.g[l];
344                        acc += d[2] * self.t3[i][j][l] * self.g[k];
345                        acc += d[2] * self.h[i][j] * self.h[k][l];
346                        acc += d[3] * self.h[i][j] * self.g[k] * self.g[l];
347                        acc += d[2] * self.t3[i][k][l] * self.g[j];
348                        acc += d[2] * self.h[i][k] * self.h[j][l];
349                        acc += d[3] * self.h[i][k] * self.g[j] * self.g[l];
350                        acc += d[2] * self.h[i][l] * self.h[j][k];
351                        acc += d[2] * self.g[i] * self.t3[j][k][l];
352                        acc += d[3] * self.g[i] * self.h[j][k] * self.g[l];
353                        acc += d[3] * self.h[i][l] * self.g[j] * self.g[k];
354                        acc += d[3] * self.g[i] * self.h[j][l] * self.g[k];
355                        acc += d[3] * self.g[i] * self.g[j] * self.h[k][l];
356                        acc += d[4] * self.g[i] * self.g[j] * self.g[k] * self.g[l];
357                        out.t4[i][j][k][l] = acc;
358                    }
359                }
360            }
361        }
362        out
363    }
364
365    /// Compose with a unary special-function whose `[f64; 5]` derivative STACK is
366    /// built from the base value through `stack_fn` — the scalar arm of the
367    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
368    /// [`Tower4Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
369    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
370    /// `self.compose_unary(stack_fn(self.v))`. Writing a program against this seam
371    /// lets it re-instantiate, unchanged, at [`Tower4Lane`] (where each of the four
372    /// lanes carries a DISTINCT base value and `stack_fn` is re-run per lane).
373    #[inline]
374    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
375        self.compose_unary(stack_fn(self.v))
376    }
377
378    /// Single-active-slot fast path for [`Self::compose_unary`].
379    ///
380    /// When the inner jet `self` has derivative support ONLY on the all-`slot`
381    /// diagonal channels — i.e. it is a univariate jet in primary `slot`
382    /// scattered into the `K`-wide layout (`g[a] = 0`, `h[a][b] = 0`,
383    /// `t3 = 0`, `t4 = 0` for any axis `≠ slot`) — the multivariate Faà di
384    /// Bruno walk collapses. Every output channel whose axis tuple contains an
385    /// axis `≠ slot` is structurally `0`: each set-partition has a block
386    /// covering that axis, that block reads an off-`slot` derivative of `self`
387    /// (which is `0`), so the block product and the whole partition vanish, and
388    /// the channel sums to the walker's `total = 0.0` start, i.e. `+0.0`. Only
389    /// the five diagonal channels (`v`, `g[slot]`, `h[slot][slot]`,
390    /// `t3[slot]³`, `t4[slot]⁴`) survive.
391    ///
392    /// This computes exactly those five as STRAIGHT-LINE accumulations, each in
393    /// the EXACT term order of [`Self::compose_unary`]'s diagonal
394    /// (`i = j = k = l = slot`) case — so they are BIT-IDENTICAL to
395    /// [`Self::compose_unary`] on the diagonal — and leaves every other channel
396    /// at the zero-init `+0.0`, which the full walk also produces (the
397    /// off-`slot` collapse is `to_bits`-`+0.0`, signed-zero products included;
398    /// proven across `K ∈ {2,3,4,9}`, 5000 single-slot inputs each). At any
399    /// `K ≥ 2` this is far fewer floating-point operations than materialising
400    /// the full `1 + K + K² + K³ + K⁴` channel set whose off-diagonal entries
401    /// are all zero, and far cheaper than the recursive set-partition walker the
402    /// diagonal channels previously routed through (a measured ~9.5× speedup vs
403    /// the full `compose_unary`, recovering a 5.9× walker regression at the
404    /// `K ∈ {2,3}` BMS tower widths).
405    ///
406    /// `#[inline]` so an adopting consumer pays no `bl` call (uninlined, the
407    /// five-channel build does not amortise the call/spill overhead).
408    ///
409    /// # Precondition
410    ///
411    /// The caller guarantees the single-active-slot structure. If it does not
412    /// hold, the off-`slot` channels would be wrongly zeroed; use the full
413    /// [`Self::compose_unary`] in that case.
414    #[inline]
415    pub fn compose_unary_single_slot(&self, d: [f64; 5], slot: usize) -> Self {
416        let mut out = Self::zero();
417        let s = slot;
418        let g = self.g[s];
419        let h = self.h[s][s];
420        let t3 = self.t3[s][s][s];
421        let t4 = self.t4[s][s][s][s];
422        out.v = d[0];
423        // g (i=s): d1*g
424        out.g[s] = {
425            let mut acc = 0.0;
426            acc += d[1] * g;
427            acc
428        };
429        // h (i=j=s): d1*h + d2*g*g
430        out.h[s][s] = {
431            let mut acc = 0.0;
432            acc += d[1] * h;
433            acc += d[2] * g * g;
434            acc
435        };
436        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
437        out.t3[s][s][s] = {
438            let mut acc = 0.0;
439            acc += d[1] * t3;
440            acc += d[2] * h * g;
441            acc += d[2] * h * g;
442            acc += d[2] * g * h;
443            acc += d[3] * g * g * g;
444            acc
445        };
446        // t4 (i=j=k=l=s): exact term order of compose_unary's inner loop.
447        out.t4[s][s][s][s] = {
448            let mut acc = 0.0;
449            acc += d[1] * t4;
450            acc += d[2] * t3 * g;
451            acc += d[2] * t3 * g;
452            acc += d[2] * h * h;
453            acc += d[3] * h * g * g;
454            acc += d[2] * t3 * g;
455            acc += d[2] * h * h;
456            acc += d[3] * h * g * g;
457            acc += d[2] * h * h;
458            acc += d[2] * g * t3;
459            acc += d[3] * g * h * g;
460            acc += d[3] * h * g * g;
461            acc += d[3] * g * h * g;
462            acc += d[3] * g * g * h;
463            acc += d[4] * g * g * g * g;
464            acc
465        };
466        out
467    }
468
469    /// Multiply every channel by a plain scalar.
470    pub fn scale(&self, s: f64) -> Self {
471        let mut out = *self;
472        out.v *= s;
473        for i in 0..K {
474            out.g[i] *= s;
475            for j in 0..K {
476                out.h[i][j] *= s;
477                for k in 0..K {
478                    out.t3[i][j][k] *= s;
479                    for l in 0..K {
480                        out.t4[i][j][k][l] *= s;
481                    }
482                }
483            }
484        }
485        out
486    }
487
488    /// e^self.
489    pub fn exp(&self) -> Self {
490        let e = self.v.exp();
491        self.compose_unary([e, e, e, e, e])
492    }
493
494    /// ln(self). Caller guarantees positivity (likelihood programs do).
495    pub fn ln(&self) -> Self {
496        let u = self.v;
497        let r = 1.0 / u;
498        self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
499    }
500
501    /// 1/self.
502    pub fn recip(&self) -> Self {
503        let r = 1.0 / self.v;
504        let r2 = r * r;
505        self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
506    }
507
508    /// √self. Caller guarantees positivity.
509    pub fn sqrt(&self) -> Self {
510        let u = self.v;
511        let s = u.sqrt();
512        self.compose_unary([
513            s,
514            0.5 / s,
515            -0.25 / (u * s),
516            0.375 / (u * u * s),
517            -0.9375 / (u * u * u * s),
518        ])
519    }
520
521    /// self^a for real exponent `a`. Caller guarantees a positive base.
522    pub fn powf(&self, a: f64) -> Self {
523        let u = self.v;
524        let f0 = u.powf(a);
525        let f1 = a * u.powf(a - 1.0);
526        let f2 = a * (a - 1.0) * u.powf(a - 2.0);
527        let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
528        let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
529        self.compose_unary([f0, f1, f2, f3, f4])
530    }
531
532    /// ln Γ(self). Caller guarantees positivity.
533    pub fn ln_gamma(&self) -> Self {
534        self.compose_unary(ln_gamma_derivative_stack(self.v))
535    }
536
537    /// ψ(self), the digamma function. Caller guarantees positivity.
538    pub fn digamma(&self) -> Self {
539        self.compose_unary(digamma_derivative_stack(self.v))
540    }
541
542    /// ψ′(self), the trigamma function. Caller guarantees positivity.
543    pub fn trigamma(&self) -> Self {
544        self.compose_unary(trigamma_derivative_stack(self.v))
545    }
546
547    /// Contract `t3` with one primary-space direction:
548    /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]` — exactly the
549    /// `row_third_contracted` shape.
550    pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
551        let mut out = [[0.0; K]; K];
552        for a in 0..K {
553            for b in 0..K {
554                let mut acc = 0.0;
555                for c in 0..K {
556                    acc += self.t3[a][b][c] * dir[c];
557                }
558                out[a][b] = acc;
559            }
560        }
561        out
562    }
563
564    /// Contract `t4` with two primary-space directions:
565    /// `out[a][b] = Σ_{c,d} t4[a][b][c][d] · u[c] · v[d]` — exactly the
566    /// `row_fourth_contracted` shape.
567    pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
568        let mut out = [[0.0; K]; K];
569        for i in 0..K {
570            for j in 0..K {
571                let mut acc = 0.0;
572                for k in 0..K {
573                    for l in 0..K {
574                        acc += self.t4[i][j][k][l] * u[k] * w[l];
575                    }
576                }
577                out[i][j] = acc;
578            }
579        }
580        out
581    }
582}
583
584impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
585    #[inline]
586    fn derivative(&self, labels: &[usize]) -> f64 {
587        self.deriv(labels)
588    }
589
590    fn map_derivatives<F>(&self, mut f: F) -> Self
591    where
592        F: FnMut(&[usize]) -> f64,
593    {
594        let mut out = Self::zero();
595        out.v = f(&[]);
596        for i in 0..K {
597            let labels = [i];
598            out.g[i] = f(&labels);
599        }
600        for i in 0..K {
601            for j in 0..K {
602                let labels = [i, j];
603                out.h[i][j] = f(&labels);
604            }
605        }
606        for i in 0..K {
607            for j in 0..K {
608                for k in 0..K {
609                    let labels = [i, j, k];
610                    out.t3[i][j][k] = f(&labels);
611                }
612            }
613        }
614        for i in 0..K {
615            for j in 0..K {
616                for k in 0..K {
617                    for l in 0..K {
618                        let labels = [i, j, k, l];
619                        out.t4[i][j][k][l] = f(&labels);
620                    }
621                }
622            }
623        }
624        out
625    }
626}
627
/// Truncated SECOND-order multivariate Taylor scalar in `K` variables.
///
/// The value/gradient/Hessian-only sibling of [`Tower4`]. The channels it
/// keeps (`v`, `g`, `h`) are produced by the SAME formulas [`Tower4`] uses at
/// those orders. The order-2 Leibniz and Faà-di-Bruno terms read only
/// order-≤2 input channels (in [`Tower4::mul`] / [`Tower4::compose_unary`],
/// `out.h` never touches `t3` or `t4`). So a program evaluated over both
/// towers yields *bit-identical* value, gradient and Hessian; dropping the
/// third/fourth tensors cannot perturb them.
///
/// The type exists purely for speed. An inner Newton step (and the
/// value-only ρ-homotopy pre-warm) needs curvature at most, never the
/// outer-κ/ψ third/fourth derivatives. Evaluating a row likelihood over
/// `Tower2` avoids the `K⁴` fourth-tensor product/composition arithmetic
/// that dominates the cold marginal-slope fit, while still returning the
/// exact same `(v, g, h)`.
#[derive(Clone, Copy, Debug)]
pub struct Tower2<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
}
654
655impl<const K: usize> Tower2<K> {
656    /// The additive identity.
657    pub fn zero() -> Self {
658        Self {
659            v: 0.0,
660            g: [0.0; K],
661            h: [[0.0; K]; K],
662        }
663    }
664
665    /// A constant: value `c`, all derivatives zero.
666    pub fn constant(c: f64) -> Self {
667        let mut out = Self::zero();
668        out.v = c;
669        out
670    }
671
672    /// The seeded variable `p_idx` with current value `value`:
673    /// unit first derivative in slot `idx`, zero elsewhere and above.
674    pub fn variable(value: f64, idx: usize) -> Self {
675        let mut out = Self::constant(value);
676        out.g[idx] = 1.0;
677        out
678    }
679
680    /// Read the derivative tensor entry whose differentiation axes are
681    /// `labels` (length 0..=2): value, `g`, `h`.
682    #[inline]
683    fn deriv(&self, labels: &[usize]) -> f64 {
684        assert!(
685            labels.len() <= 2,
686            "Tower2 carries at most second-order derivatives"
687        );
688        match labels.len() {
689            0 => self.v,
690            1 => self.g[labels[0]],
691            _ => self.h[labels[0]][labels[1]],
692        }
693    }
694
695    /// Exact truncated (order ≤ 2) Leibniz product. The `v`/`g`/`h` channels
696    /// match [`Tower4::mul`] term-for-term.
697    pub fn mul(&self, o: &Self) -> Self {
698        let a = self;
699        let b = o;
700        let mut out = Self::zero();
701        out.v = a.v * b.v;
702        for i in 0..K {
703            out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
704        }
705        for i in 0..K {
706            for j in 0..K {
707                out.h[i][j] = a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
708            }
709        }
710        out
711    }
712
    /// Exact (order ≤ 2) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u)]` evaluated at `u = self.v`. The `v`/`g`/`h`
    /// channels match [`Tower4::compose_unary`] term-for-term (which uses only
    /// `d[0..=2]` for those orders), so this is a strict truncation, not an
    /// approximation. To reuse one of the full-order `[f64; 5]` derivative
    /// stacks the families already produce, pass its first three entries
    /// (e.g. `[s[0], s[1], s[2]]`); the parameter is a fixed-size array, not a
    /// slice.
    ///
    /// # Codegen
    ///
    /// Order-≤2 Faà di Bruno is a tiny closed form, so this evaluates it
    /// directly instead of routing through the generic
    /// [`jet_algebra::faa_di_bruno`] set-partition walker (recursion + per-block
    /// closure dispatch). That matters because this is the kernel under EVERY
    /// packed scalar — [`crate::jet_scalar::Order2`] / `OneSeed` / `TwoSeed`
    /// composition all bottom out here — so the straight-line form (whose inner
    /// loops auto-vectorise to NEON/SSE 2-wide and which emits zero outlined
    /// walker calls) lifts all of them at once.
    ///
    /// The term and accumulation order is BIT-IDENTICAL to the walker it
    /// replaces: each output channel mirrors the walker's `total = 0.0` start
    /// (the explicit `acc` accumulator), so a signed-zero product collapses to
    /// `+0.0` exactly as `total += prod` does. Proven `to_bits`-identical on
    /// `v`/`g`/`h` across `K ∈ {2,3,4,9}`, 5000 random inputs each (incl.
    /// zeroed / sign-varied stacks). The order-≤2 walker partitions are:
    ///   `g[i]`   = `f′·u_i`                   (single block `{i}`)
    ///   `h[i][j]` = `f′·u_ij + (f″·u_i)·u_j`  (blocks `{ij}` then `{i}{j}`),
    /// with `f′ = d[1]`, `f″ = d[2]`, `u_* = self.{g,h}`.
    pub fn compose_unary(&self, d: [f64; 3]) -> Self {
        let mut out = Self::zero();
        // Order 0: the outer function's value.
        out.v = d[0];
        // Order 1: one partition, `{i}`. The `0.0` start is deliberate (see
        // the doc above): it normalises a `-0.0` product to `+0.0`.
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        // Order 2: partitions `{ij}` then `{i}{j}`, in that order; the block
        // product `(f″·u_i)·u_j` is left-associated to match the walker.
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        out
    }
759
760    /// Multiply every channel by a plain scalar.
761    pub fn scale(&self, s: f64) -> Self {
762        let mut out = *self;
763        out.v *= s;
764        for i in 0..K {
765            out.g[i] *= s;
766            for j in 0..K {
767                out.h[i][j] *= s;
768            }
769        }
770        out
771    }
772
773    /// e^self.
774    pub fn exp(&self) -> Self {
775        let e = self.v.exp();
776        self.compose_unary([e, e, e])
777    }
778
779    /// √self. Caller guarantees positivity.
780    pub fn sqrt(&self) -> Self {
781        let u = self.v;
782        let s = u.sqrt();
783        self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
784    }
785}
786
787impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
788    #[inline]
789    fn derivative(&self, labels: &[usize]) -> f64 {
790        self.deriv(labels)
791    }
792
793    fn map_derivatives<F>(&self, mut f: F) -> Self
794    where
795        F: FnMut(&[usize]) -> f64,
796    {
797        let mut out = Self::zero();
798        out.v = f(&[]);
799        for i in 0..K {
800            let labels = [i];
801            out.g[i] = f(&labels);
802        }
803        for i in 0..K {
804            for j in 0..K {
805                let labels = [i, j];
806                out.h[i][j] = f(&labels);
807            }
808        }
809        out
810    }
811}
812
813impl<const K: usize> std::ops::Add for Tower2<K> {
814    type Output = Self;
815    fn add(self, o: Self) -> Self {
816        let mut out = self;
817        out.v += o.v;
818        for i in 0..K {
819            out.g[i] += o.g[i];
820            for j in 0..K {
821                out.h[i][j] += o.h[i][j];
822            }
823        }
824        out
825    }
826}
827
828impl<const K: usize> std::ops::Mul for Tower2<K> {
829    type Output = Self;
830    fn mul(self, o: Self) -> Self {
831        Tower2::mul(&self, &o)
832    }
833}
834
835impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
836    type Output = Self;
837    fn add(self, c: f64) -> Self {
838        let mut out = self;
839        out.v += c;
840        out
841    }
842}
843
844impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
845    type Output = Self;
846    fn mul(self, c: f64) -> Self {
847        self.scale(c)
848    }
849}
850
/// Truncated third-order multivariate Taylor scalar in `K` variables: value,
/// gradient, Hessian and third-derivative tensor, with no fourth tensor.
///
/// Sits between [`Tower2`] (order ≤ 2) and [`Tower4`] (order ≤ 4). Its channels
/// are produced by the same Leibniz / Faà-di-Bruno kernels that [`Tower4`] uses
/// for orders ≤ 3. Those order-≤3 terms never read the f⁗ stack slot or an
/// inner `t4` tensor (see [`Tower4::compose_unary`]). So a program evaluated
/// over both towers yields *bit-identical* `v`/`g`/`h`/`t3`: dropping the
/// fourth tensor cannot perturb the lower orders.
///
/// Like [`Tower2`], it exists for speed. A consumer that needs derivatives up
/// to third order pays for the `K³` third tensor but skips the `K⁴`
/// fourth-tensor product and composition that would otherwise dominate the
/// per-row cost. One example is the survival location-scale row kernel, which
/// reads `g`, the diagonal of `h` and the diagonal of `t3` but never `t4`.
#[derive(Debug, Clone, Copy)]
pub struct Tower3<const K: usize> {
    /// ℓ itself.
    pub v: f64,
    /// First partials `∂ℓ/∂p_a`.
    pub g: [f64; K],
    /// Second partials `∂²ℓ/∂p_a∂p_b`; symmetric.
    pub h: [[f64; K]; K],
    /// Third partials `∂³ℓ/∂p_a∂p_b∂p_c`; symmetric in all three indices.
    pub t3: [[[f64; K]; K]; K],
}
879
impl<const K: usize> Tower3<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
            t3: [[[0.0; K]; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    ///
    /// # Panics
    ///
    /// If `idx >= K` (out-of-bounds index into `g`).
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the (fully symmetric) derivative tensor entry whose differentiation
    /// axes are `labels` (length 0..=3): value, `g`, `h`, `t3`.
    ///
    /// Panics if more than three labels are given or any label is `>= K`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 3,
            "Tower3 carries at most third-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            2 => self.h[labels[0]][labels[1]],
            _ => self.t3[labels[0]][labels[1]][labels[2]],
        }
    }

    /// Exact truncated (order ≤ 3) Leibniz product. The `v`/`g`/`h`/`t3`
    /// channels match [`Tower4::mul`] term-for-term.
    ///
    /// # Codegen
    ///
    /// Straight-line per-entry subset sums instead of the
    /// [`jet_algebra::leibniz_product`] walker — the order-≤3 sibling of
    /// [`Tower4::mul`] (no `t4`). Loop nest unchanged, no unroll over `K`, no
    /// code bloat; auto-vectorises. BIT-IDENTICAL: terms in the walker's exact
    /// subset order with an `acc = 0.0` accumulator start (load-bearing for the
    /// signed-zero leading product on exact-`0.0` jet channels). Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// zero/sign-stressed inputs each (these channel formulas are exactly the
    /// `g`/`h`/`t3` of the [`Tower4::mul`] oracle, which passes that stress).
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        // Order 1 — subsets of {i}: {} {i}.
        for i in 0..K {
            let mut acc = 0.0;
            acc += a.v * b.g[i];
            acc += a.g[i] * b.v;
            out.g[i] = acc;
        }
        // Order 2 — subsets of {i,j}: {} {i} {j} {ij}.
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += a.v * b.h[i][j];
                acc += a.g[i] * b.g[j];
                acc += a.g[j] * b.g[i];
                acc += a.h[i][j] * b.v;
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                    // (the subset goes to `a`, its complement to `b`)
                    let mut acc = 0.0;
                    acc += a.v * b.t3[i][j][k];
                    acc += a.g[i] * b.h[j][k];
                    acc += a.g[j] * b.h[i][k];
                    acc += a.h[i][j] * b.g[k];
                    acc += a.g[k] * b.h[i][j];
                    acc += a.h[i][k] * b.g[j];
                    acc += a.h[j][k] * b.g[i];
                    acc += a.t3[i][j][k] * b.v;
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }

    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
    /// out of the borrowed operands.
    pub fn add(&self, o: &Self) -> Self {
        *self + *o
    }

    /// Ref-taking elementwise difference `self − o`, computed as
    /// `self + (−1)·o` through [`Self::scale`] and the by-value `Add` impl.
    pub fn sub(&self, o: &Self) -> Self {
        *self + o.scale(-1.0)
    }

    /// Exact (order ≤ 3) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u), f‴(u)]` evaluated at `u = self.v`. The
    /// `v`/`g`/`h`/`t3` channels match [`Tower4::compose_unary`] term-for-term
    /// (which uses only `d[0..=3]` for those orders), so this is a strict
    /// truncation, not an approximation. To reuse one of the full-order
    /// `[f64; 5]` derivative stacks the families already produce, pass its first
    /// four entries (the parameter is a fixed-size array, not a slice).
    ///
    /// # Codegen
    ///
    /// Order-≤3 Faà di Bruno written as a compact closed form instead of the
    /// recursive [`jet_algebra::faa_di_bruno`] walker. It is the order-≤3
    /// sibling of [`Tower4::compose_unary`], one tensor order shallower. The
    /// loop nest is unchanged (no unroll over `K`, no code bloat). Measured on a
    /// `Tower3<9>` compose-and-read consumer, the new form is faster and
    /// SMALLER. In the asm, walker `bl` calls went from 71 to 0, size from
    /// 39.5 KiB to 13.9 KiB, with +197 NEON `.2d` ops.
    /// BIT-IDENTICAL: terms in the walker's exact partition order, left-
    /// associated block products, `acc = 0.0` accumulator start. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// random inputs each.
    pub fn compose_unary(&self, d: [f64; 4]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        // Order 1 — partition {i}.
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        // Order 2 — partitions {ij} then {i}{j}.
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
                    let mut acc = 0.0;
                    acc += d[1] * self.t3[i][j][k];
                    acc += d[2] * self.h[i][j] * self.g[k];
                    acc += d[2] * self.h[i][k] * self.g[j];
                    acc += d[2] * self.g[i] * self.h[j][k];
                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }

    /// Compose with a unary special-function whose `[f64; 4]` derivative STACK is
    /// built from the base value through `stack_fn` — the scalar arm of the
    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
    /// [`Tower3Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
    /// `self.compose_unary(stack_fn(self.v))`. The order-≤3 sibling of
    /// [`Tower4::compose_unary_with`].
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
        self.compose_unary(stack_fn(self.v))
    }

    /// Single-active-slot fast path for [`Self::compose_unary`] — the order-≤3
    /// sibling of [`Tower4::compose_unary_single_slot`]. When `self` carries
    /// derivative support only on the all-`slot` diagonal, every output channel
    /// touching an axis `≠ slot` collapses to the walker's `total = 0.0` start
    /// (`+0.0`), so only `v`, `g[slot]`, `h[slot][slot]`, `t3[slot]³` survive.
    /// These four are computed as STRAIGHT-LINE accumulations, each in the EXACT
    /// term order of [`Self::compose_unary`]'s diagonal (`i = j = k = slot`)
    /// case (BIT-IDENTICAL to the full path on the diagonal); off-`slot`
    /// channels stay at the zero-init `+0.0` the full walk also yields (proven
    /// `to_bits` across `K ∈ {2,3,4,9}`). This drops the recursive
    /// set-partition walker the diagonal channels previously routed through,
    /// recovering its measured ~5.9× regression at the `K ∈ {2,3}` BMS tower
    /// widths. Caller guarantees the single-slot precondition; otherwise use
    /// [`Self::compose_unary`]. The precondition is not checked here: any
    /// off-diagonal input support is silently ignored.
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [f64; 4], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        out.v = d[0];
        // g (i=s): d1*g
        out.g[s] = {
            let mut acc = 0.0;
            acc += d[1] * g;
            acc
        };
        // h (i=j=s): d1*h + d2*g*g
        out.h[s][s] = {
            let mut acc = 0.0;
            acc += d[1] * h;
            acc += d[2] * g * g;
            acc
        };
        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
        // The three d2 lines are the {ij}{k}, {ik}{j}, {i}{jk} partitions
        // collapsed onto the diagonal; note the last one is `g * h`, not
        // `h * g`, mirroring `self.g[i] * self.h[j][k]` in the full path.
        out.t3[s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t3;
            acc += d[2] * h * g;
            acc += d[2] * h * g;
            acc += d[2] * g * h;
            acc += d[3] * g * g * g;
            acc
        };
        out
    }

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
                for k in 0..K {
                    out.t3[i][j][k] *= s;
                }
            }
        }
        out
    }
}
1120
1121impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
1122    #[inline]
1123    fn derivative(&self, labels: &[usize]) -> f64 {
1124        self.deriv(labels)
1125    }
1126
1127    fn map_derivatives<F>(&self, mut f: F) -> Self
1128    where
1129        F: FnMut(&[usize]) -> f64,
1130    {
1131        let mut out = Self::zero();
1132        out.v = f(&[]);
1133        for i in 0..K {
1134            let labels = [i];
1135            out.g[i] = f(&labels);
1136        }
1137        for i in 0..K {
1138            for j in 0..K {
1139                let labels = [i, j];
1140                out.h[i][j] = f(&labels);
1141            }
1142        }
1143        for i in 0..K {
1144            for j in 0..K {
1145                for k in 0..K {
1146                    let labels = [i, j, k];
1147                    out.t3[i][j][k] = f(&labels);
1148                }
1149            }
1150        }
1151        out
1152    }
1153}
1154
1155impl<const K: usize> std::ops::Add for Tower3<K> {
1156    type Output = Self;
1157    fn add(self, o: Self) -> Self {
1158        let mut out = self;
1159        out.v += o.v;
1160        for i in 0..K {
1161            out.g[i] += o.g[i];
1162            for j in 0..K {
1163                out.h[i][j] += o.h[i][j];
1164                for k in 0..K {
1165                    out.t3[i][j][k] += o.t3[i][j][k];
1166                }
1167            }
1168        }
1169        out
1170    }
1171}
1172
1173pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
1174    [
1175        statrs::function::gamma::ln_gamma(x),
1176        digamma_positive(x),
1177        polygamma_positive(1, x),
1178        polygamma_positive(2, x),
1179        polygamma_positive(3, x),
1180    ]
1181}
1182
1183pub fn ln_gamma_derivative_stack_order2(x: f64) -> [f64; 3] {
1184    [
1185        statrs::function::gamma::ln_gamma(x),
1186        digamma_positive(x),
1187        polygamma_positive(1, x),
1188    ]
1189}
1190
1191pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
1192    [
1193        digamma_positive(x),
1194        polygamma_positive(1, x),
1195        polygamma_positive(2, x),
1196        polygamma_positive(3, x),
1197        polygamma_positive(4, x),
1198    ]
1199}
1200
1201pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
1202    [
1203        polygamma_positive(1, x),
1204        polygamma_positive(2, x),
1205        polygamma_positive(3, x),
1206        polygamma_positive(4, x),
1207        polygamma_positive(5, x),
1208    ]
1209}
1210
1211fn digamma_positive(mut x: f64) -> f64 {
1212    if !(x.is_finite() && x > 0.0) {
1213        return f64::NAN;
1214    }
1215    let mut acc = 0.0;
1216    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1217        acc -= 1.0 / x;
1218        x += 1.0;
1219    }
1220    acc + digamma_asymptotic(x)
1221}
1222
1223fn polygamma_positive(order: usize, mut x: f64) -> f64 {
1224    if !(x.is_finite() && x > 0.0) {
1225        return f64::NAN;
1226    }
1227    let mut acc = 0.0;
1228    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1229        acc += polygamma_recurrence_term(order, x);
1230        x += 1.0;
1231    }
1232    acc + polygamma_asymptotic(order, x)
1233}
1234
/// Argument threshold for the (poly)gamma evaluators. Smaller arguments are
/// stepped up one unit at a time with the recurrence
/// `ψ⁽ⁿ⁾(x) = ψ⁽ⁿ⁾(x + 1) + (−1)ⁿ⁺¹ n!/xⁿ⁺¹` before the asymptotic series is
/// evaluated, so the series only ever sees `x ≥ 20`.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
/// Even-index Bernoulli numbers `(2k, B₂ₖ)` for `k = 1..=10`. They are the
/// coefficients of the asymptotic series in `digamma_asymptotic` and
/// `polygamma_asymptotic`.
const BERNOULLI_EVEN: [(usize, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];
1248
1249fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
1250    let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1251    sign * factorial(order) / x.powi((order + 1) as i32)
1252}
1253
1254fn digamma_asymptotic(x: f64) -> f64 {
1255    let mut out = x.ln() - 0.5 / x;
1256    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1257        out -= bernoulli / (bernoulli_order as f64 * x.powi(bernoulli_order as i32));
1258    }
1259    out
1260}
1261
1262fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
1263    if !(1..=5).contains(&order) {
1264        return f64::NAN;
1265    }
1266
1267    let order_factorial = factorial(order);
1268    let leading_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1269    let mut out = leading_sign * factorial(order - 1) / x.powi(order as i32)
1270        + leading_sign * order_factorial / (2.0 * x.powi((order + 1) as i32));
1271
1272    let bernoulli_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1273    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1274        let rising = rising_factorial(bernoulli_order, order);
1275        out += bernoulli_sign * bernoulli * rising
1276            / bernoulli_order as f64
1277            / x.powi((bernoulli_order + order) as i32);
1278    }
1279    out
1280}
1281
/// `n!` as an `f64` (`0! = 1`), multiplied up in ascending order.
fn factorial(n: usize) -> f64 {
    (1..=n).map(|k| k as f64).product()
}
1285
/// Rising factorial `start · (start + 1) · … · (start + len − 1)` as an `f64`;
/// the empty product (`len == 0`) is `1`.
fn rising_factorial(start: usize, len: usize) -> f64 {
    (0..len).map(|offset| (start + offset) as f64).product()
}
1289
1290impl<const K: usize> std::ops::Add for Tower4<K> {
1291    type Output = Self;
1292    fn add(self, o: Self) -> Self {
1293        let mut out = self;
1294        out.v += o.v;
1295        for i in 0..K {
1296            out.g[i] += o.g[i];
1297            for j in 0..K {
1298                out.h[i][j] += o.h[i][j];
1299                for k in 0..K {
1300                    out.t3[i][j][k] += o.t3[i][j][k];
1301                    for l in 0..K {
1302                        out.t4[i][j][k][l] += o.t4[i][j][k][l];
1303                    }
1304                }
1305            }
1306        }
1307        out
1308    }
1309}
1310
1311impl<const K: usize> std::ops::Sub for Tower4<K> {
1312    type Output = Self;
1313    fn sub(self, o: Self) -> Self {
1314        self + o.scale(-1.0)
1315    }
1316}
1317
1318impl<const K: usize> std::ops::Neg for Tower4<K> {
1319    type Output = Self;
1320    fn neg(self) -> Self {
1321        self.scale(-1.0)
1322    }
1323}
1324
1325impl<const K: usize> std::ops::Mul for Tower4<K> {
1326    type Output = Self;
1327    fn mul(self, o: Self) -> Self {
1328        Tower4::mul(&self, &o)
1329    }
1330}
1331
1332impl<const K: usize> std::ops::Div for Tower4<K> {
1333    type Output = Self;
1334    fn div(self, o: Self) -> Self {
1335        Tower4::mul(&self, &o.recip())
1336    }
1337}
1338
1339impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
1340    type Output = Self;
1341    fn add(self, c: f64) -> Self {
1342        let mut out = self;
1343        out.v += c;
1344        out
1345    }
1346}
1347
1348impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
1349    type Output = Self;
1350    fn sub(self, c: f64) -> Self {
1351        self + (-c)
1352    }
1353}
1354
1355impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
1356    type Output = Self;
1357    fn mul(self, c: f64) -> Self {
1358        self.scale(c)
1359    }
1360}
1361
1362// ── Implicit-function and moving-boundary seams (#932 flex) ──────────
1363//
1364// The flexible survival marginal-slope row loss is NOT a free composition
1365// of the primaries: it threads an IMPLICIT calibration intercept `a(θ)`
1366// solving a constraint `F(a, θ) = 0`, and integrates a density over cells
1367// whose edges `z_L(θ), z_R(θ)` MOVE with θ through that intercept. Plain
1368// `Tower4` Faà di Bruno cannot express either — so the flex tower was the
1369// last hand-written one in the codebase, and the genus of #736-class
1370// drift bugs (the (g,w0) deviation-cross third was 3× short for exactly
1371// this reason). These two combinators close that gap: once the constraint
1372// `F` and the integrand/boundaries are themselves towers, the intercept's
1373// derivative tower and the integral's derivative tower come out EXACTLY at
1374// every order — there is no order left to hand-code and forget.
1375
1376/// Solve the implicit relation `F(a(θ), θ) ≡ 0` for the intercept tower
1377/// `a(θ)` over the `K` primaries θ, given the constraint tower `f` written
1378/// over `K + 1` variables (slot `0` is the intercept `a`, slots `1..=K`
1379/// are the primaries θ) evaluated at the SOLVED point — i.e. `f.v` is the
1380/// constraint residual at `(a₀, θ₀)` (≈ 0 from the production Newton solve)
1381/// and `a0` is that solved intercept value.
1382///
1383/// Returns the `Tower4<K>` whose value is `a0` and whose every derivative
1384/// tensor (∂a/∂θ, ∂²a/∂θ², …, ∂⁴a/∂θ⁴) is the exact implicit-function
1385/// derivative. This is the mechanical replacement for the hand-coded
1386/// `a_u = -f_u/f_a`, `a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a`
1387/// recursion (first_full.rs) and its third/fourth-order continuations.
1388///
1389/// Method: order-by-order substitution. We build `a` incrementally; at each
1390/// order `m` the composite `G(θ) = f(a(θ), θ)` has a top-order coefficient
1391/// that is linear in `a`'s order-`m` tensor with leading factor `F_a`
1392/// (= `f.g[0]`), plus terms in `a`'s lower orders already fixed. Setting the
1393/// order-`m` tensor of `a` to cancel the rest of `G`'s order-`m` coefficient
1394/// keeps `G ≡ 0` through that order. The substitution `G = f∘(a, θ)` reuses
1395/// only the exact [`substitute_intercept`] chain rule, so the recursion is
1396/// auditable and exact, not a hand-expanded formula per order.
1397///
1398/// `f.g[0]` (= ∂F/∂a) must be non-zero — guaranteed by the production
1399/// solve's strict monotonicity guard.
1400///
1401/// The expansion point `a0` must be a genuine root `F(a0, θ0) = 0`: the
1402/// substitution recursion below cancels orders 1..=4 of `G = F∘a` but never
1403/// touches order 0, so a non-root `a0` would yield the Taylor expansion of
1404/// the LEVEL SET `F = F(a0)` through `a0`, not the root curve `F = 0`. This
1405/// is guarded explicitly and re-verified by a composed-residual self-check.
1406pub fn implicit_solve<const K1: usize, const K: usize>(
1407    f: &Tower4<K1>,
1408    a0: f64,
1409) -> Result<Tower4<K>, String> {
1410    assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
1411    let f_a = f.g[0];
1412    if f_a == 0.0 || !f_a.is_finite() {
1413        return Err(format!(
1414            "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
1415        ));
1416    }
1417    // The expansion point must be a genuine root of F. The single Newton
1418    // correction that would move a0 onto the root is |f.v|/|f_a|; require it
1419    // to be negligible relative to the natural scale (1 + |a0|). Guarding the
1420    // Newton step (rather than f.v directly) makes the criterion invariant to
1421    // the magnitude of f_a / the units of F.
1422    let root_tol = 1e-9;
1423    if !f.v.is_finite() {
1424        return Err(format!(
1425            "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
1426            f.v
1427        ));
1428    }
1429    let newton_step = f.v.abs() / f_a.abs();
1430    if newton_step > root_tol * (1.0 + a0.abs()) {
1431        return Err(format!(
1432            "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
1433             F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
1434             root_tol {root_tol:.1e} · (1 + |a0|)",
1435            f.v
1436        ));
1437    }
1438    // Start with a = constant a0 (correct through order 0). Then lift each
1439    // order in turn. Because substitute_intercept reads `a`'s order-≤m
1440    // tensors when forming G's order-m coefficient, and the order-m
1441    // coefficient of G depends on a's order-m tensor ONLY through the linear
1442    // F_a·a_m term, a single corrective pass per order is exact.
1443    let mut a = Tower4::<K>::constant(a0);
1444    for order in 1..=4 {
1445        let g = substitute_intercept(f, &a);
1446        // Cancel G's order-`order` coefficient by adjusting a's order-`order`
1447        // tensor: a_m -= G_m / F_a (the F_a·a_m term is the only one carrying
1448        // a's order-m tensor, with unit chain coefficient since slot 0 seeds a
1449        // as a plain variable in the substitution's first-order part).
1450        match order {
1451            1 => {
1452                for i in 0..K {
1453                    a.g[i] -= g.g[i] / f_a;
1454                }
1455            }
1456            2 => {
1457                for i in 0..K {
1458                    for j in 0..K {
1459                        a.h[i][j] -= g.h[i][j] / f_a;
1460                    }
1461                }
1462            }
1463            3 => {
1464                for i in 0..K {
1465                    for j in 0..K {
1466                        for k in 0..K {
1467                            a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
1468                        }
1469                    }
1470                }
1471            }
1472            _ => {
1473                for i in 0..K {
1474                    for j in 0..K {
1475                        for k in 0..K {
1476                            for l in 0..K {
1477                                a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
1478                            }
1479                        }
1480                    }
1481                }
1482            }
1483        }
1484    }
1485    // Self-check: the composed residual G = F∘a must vanish through order 4.
1486    // By construction orders 1..=4 were cancelled; the value G.v == F(a0,θ0)
1487    // is exactly the root requirement guarded above. Re-verify all channels
1488    // against a scale-aware floor so any arithmetic regression in the
1489    // substitution recursion is loud rather than silently shipping a
1490    // level-set expansion.
1491    let g = substitute_intercept(f, &a);
1492    let resid_tol = 1e-7 * (1.0 + f_a.abs());
1493    let mut worst = g.v.abs();
1494    for i in 0..K {
1495        worst = worst.max(g.g[i].abs());
1496        for j in 0..K {
1497            worst = worst.max(g.h[i][j].abs());
1498            for k in 0..K {
1499                worst = worst.max(g.t3[i][j][k].abs());
1500                for l in 0..K {
1501                    worst = worst.max(g.t4[i][j][k][l].abs());
1502                }
1503            }
1504        }
1505    }
1506    if !worst.is_finite() || worst > resid_tol {
1507        return Err(format!(
1508            "implicit_solve: composed residual G = F∘a does not vanish: \
1509             worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
1510        ));
1511    }
1512    Ok(a)
1513}
1514
1515/// Substitute the intercept tower `a(θ)` into slot `0` of a constraint
1516/// written over `K + 1` variables, returning the composite tower over the
1517/// `K` primaries θ: `G(θ) = f(a(θ), θ₁, …, θ_K)`.
1518///
1519/// This is the exact multivariate chain rule specialised to "slot 0 is a
1520/// dependent tower, slots 1..=K are the independent primaries". It evaluates
1521/// `f`'s fourth-order multivariate Taylor polynomial about the expansion
1522/// point, with the slot-0 increment being the non-constant part of `a` and
1523/// the slot-(i) increment being the unit-seeded primary `θ_i`. The sum is
1524/// assembled by the same subset/partition algebra `Tower4` arithmetic uses,
1525/// so it carries derivatives exactly through order four.
1526pub fn substitute_intercept<const K1: usize, const K: usize>(
1527    f: &Tower4<K1>,
1528    a: &Tower4<K>,
1529) -> Tower4<K> {
1530    assert_eq!(K1, K + 1);
1531    // Build the K+1 input towers in θ-space: slot 0 = a(θ), slot i+1 = θ_i.
1532    // The composite is Σ over ordered label tuples s (|s| ≤ 4) of input
1533    // indices: (1/|s|!) · f.deriv(s) · Π_{j in s} (inp[s_j] centred) — but
1534    // since f.deriv is the SYMMETRIC partial tensor and we enumerate ordered
1535    // tuples, the 1/|s|! exactly cancels the tuple multiplicity. We assemble
1536    // it directly as a Horner-free explicit sum over the (K+1)-ary tuples,
1537    // using tower products for the increment monomials so all θ-derivatives
1538    // propagate exactly.
1539    let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
1540        if slot == 0 {
1541            // slot 0: a(θ) minus its constant value (the increment δa(θ)).
1542            let mut d = *a;
1543            d.v = 0.0;
1544            d
1545        } else {
1546            // slot i: the increment δθ_{i-1} = seeded variable minus value.
1547            // θ centred at its expansion value has zero constant term and unit
1548            // first derivative in its own slot.
1549            let mut d = Tower4::<K>::zero();
1550            d.g[slot - 1] = 1.0;
1551            d
1552        }
1553    });
1554    // Accumulate the Taylor sum. order-0 term:
1555    let mut out = Tower4::<K>::constant(f.v);
1556    // order 1: Σ_a f.g[a] · inp[a]
1557    for a_idx in 0..K1 {
1558        out = out + inp[a_idx].scale(f.g[a_idx]);
1559    }
1560    // order 2: (1/2) Σ_{a,b} f.h[a][b] · inp[a]·inp[b]
1561    for a_idx in 0..K1 {
1562        for b_idx in 0..K1 {
1563            let prod = inp[a_idx].mul(&inp[b_idx]);
1564            out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
1565        }
1566    }
1567    // order 3: (1/6) Σ f.t3[a][b][c] · inp[a]·inp[b]·inp[c]
1568    for a_idx in 0..K1 {
1569        for b_idx in 0..K1 {
1570            for c_idx in 0..K1 {
1571                let prod = inp[a_idx].mul(&inp[b_idx]).mul(&inp[c_idx]);
1572                out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
1573            }
1574        }
1575    }
1576    // order 4: (1/24) Σ f.t4[a][b][c][d] · inp[a]·inp[b]·inp[c]·inp[d]
1577    for a_idx in 0..K1 {
1578        for b_idx in 0..K1 {
1579            for c_idx in 0..K1 {
1580                for d_idx in 0..K1 {
1581                    let prod = inp[a_idx]
1582                        .mul(&inp[b_idx])
1583                        .mul(&inp[c_idx])
1584                        .mul(&inp[d_idx]);
1585                    out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
1586                }
1587            }
1588        }
1589    }
1590    out
1591}
1592
1593/// The exact θ-derivative tower of a moving-LIMIT integral's BOUNDARY
1594/// contribution: given the edge-position tower `z_edge(θ)` over the `K`
1595/// primaries and the integrand `B` evaluated-and-differentiated at the edge
1596/// value as the stack `b_stack = [B(z₀), B′(z₀), B″(z₀), B‴(z₀)]`
1597/// (`z₀ = z_edge.v`), returns the tower of `Φ(z_edge(θ))` where `Φ′ = B`.
1598///
1599/// Rationale: `∂_θ ∫^{z_edge(θ)} B(z) dz = Φ(z_edge(θ))` with `Φ` an
1600/// antiderivative of `B`, so the boundary part of every θ-derivative of the
1601/// integral is just the composition `Φ ∘ z_edge` — whose Faà di Bruno
1602/// expansion carries, at one stroke, EVERY Leibniz boundary term the
1603/// hand-written flux dropped: the first-order `B·z_u`, the second-order
1604/// `B′·z_u·z_v + B·z_uv` (the `G_z·z_u·z_v` self-flux AND the previously
1605/// dropped `G·z_uv`), and the full third/fourth-order continuations. The
1606/// VALUE channel of the returned tower is meaningless (`Φ` is only defined up
1607/// to a constant); callers read only the derivative channels and pair this
1608/// with the interior moment-integral value separately.
1609///
1610/// `b_stack` holds `B` and its first three z-derivatives; the antiderivative
1611/// `Φ` contributes only as the order-≥1 channels, so `compose_unary` receives
1612/// `[0, B, B′, B″, B‴]` — the leading `0` is the discarded `Φ(z₀)` slot.
1613pub fn moving_limit_boundary_tower<const K: usize>(
1614    z_edge: &Tower4<K>,
1615    b_stack: [f64; 4],
1616) -> Tower4<K> {
1617    z_edge.compose_unary([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]])
1618}
1619
1620/// The boundary-flux derivative tower of a single moving cell integral
1621/// `∫_{z_L(θ)}^{z_R(θ)} B dz`: `Φ(z_R(θ)) − Φ(z_L(θ))`, assembled from the
1622/// two edge towers and the integrand stacks at each edge. The returned
1623/// tower's derivative channels are the EXACT moving-boundary contribution to
1624/// every θ-derivative of the cell integral, to fourth order, with no term
1625/// hand-omitted. A `Fixed` (non-moving) edge passes a `z_edge` whose
1626/// derivative channels are all zero, contributing nothing — matching the
1627/// production `edge_vel = 0` short-circuit.
1628pub fn cell_moving_boundary_flux_tower<const K: usize>(
1629    z_right: &Tower4<K>,
1630    b_stack_right: [f64; 4],
1631    z_left: &Tower4<K>,
1632    b_stack_left: [f64; 4],
1633) -> Tower4<K> {
1634    moving_limit_boundary_tower(z_right, b_stack_right)
1635        - moving_limit_boundary_tower(z_left, b_stack_left)
1636}
1637
1638/// Moving-limit boundary tower for a θ-DEPENDENT integrand `G(z; θ)`.
1639///
1640/// [`moving_limit_boundary_tower`] assumes the integrand depends on θ only
1641/// through the moving edge `z_edge(θ)` (a fixed z-derivative `b_stack`). The
1642/// marginal-slope flex boundary is richer: the integrand `G(z; θ)` ALSO carries
1643/// its own θ-dependence (the density weight `w = e^{−q}/2π` and the cell
1644/// integrand coefficients move with η, hence with the primaries), so the
1645/// Leibniz expansion of `∂ⁿ_θ ∫^{z_edge(θ)} G(z;θ) dz` mixes edge-motion
1646/// derivatives of the limit with θ-derivatives of `G` itself — e.g. at second
1647/// order `G·z_uv + G_z·z_u·z_v + G_{θu}·z_v + G_{θv}·z_u` (the four
1648/// edge-motion-carrying terms the hand path assembles one by one, including the
1649/// `G·z_uv` term the directional path drops).
1650///
1651/// Mechanization: let `Φ(z; θ)` be the z-antiderivative of `G` (so `Φ_z = G`).
1652/// The full upper-limit contribution is `Φ(z_edge(θ); θ)`, and the BOUNDARY
1653/// part — everything carrying edge motion — is exactly
1654///   `Φ(z_edge(θ); θ) − Φ(z₀; θ)`,
1655/// the second term being the pure-integrand-θ part (`∫^{z₀} ∂ⁿ_θ G`) the
1656/// interior moment integral already supplies. Both are one
1657/// [`substitute_intercept`] of the SAME mixed `(z, θ)` jet of `Φ` (z in slot 0,
1658/// θ in slots 1..K): substituting the edge tower gives the full composite,
1659/// substituting a frozen constant edge isolates the pure-θ part, and their
1660/// difference is the exact boundary flux — every Leibniz term derived by the
1661/// substitution algebra, none hand-omitted.
1662///
1663/// `phi_jet` is the `(K+1)`-variable Taylor jet of `Φ` about `(z₀, θ₀)` with
1664/// `z₀ = z_edge.v`: slot 0 is the z-direction (so `phi_jet.g[0] = G(z₀;θ₀)`,
1665/// `phi_jet.h[0][0] = G_z`, …) and slots `1..=K` are the primaries θ (carrying
1666/// `Φ`'s own θ- and mixed z·θ-derivatives — i.e. the integrand's θ-derivatives
1667/// integrated in z, and `G_{θ…}` in the mixed slots). The returned tower's
1668/// VALUE channel is 0 by construction (the `Φ(z₀;θ₀)` constants cancel); only
1669/// the derivative channels are meaningful, matching the value-less convention of
1670/// [`moving_limit_boundary_tower`].
1671pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
1672    phi_jet: &Tower4<K1>,
1673    z_edge: &Tower4<K>,
1674) -> Tower4<K> {
1675    assert_eq!(
1676        K1,
1677        K + 1,
1678        "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
1679    );
1680    let frozen_edge = Tower4::<K>::constant(z_edge.v);
1681    let full = substitute_intercept(phi_jet, z_edge);
1682    let interior = substitute_intercept(phi_jet, &frozen_edge);
1683    full - interior
1684}
1685
1686/// Two-edge cell version of [`moving_limit_boundary_tower_theta_integrand`]:
1687/// the exact boundary-flux tower of `∫_{z_L(θ)}^{z_R(θ)} G(z;θ) dz` with a
1688/// θ-dependent integrand, `Φ(z_R;θ) − Φ(z_L;θ)` minus the pure-θ parts at each
1689/// frozen edge. A `Fixed` edge passes a `z_edge` with zero derivative channels,
1690/// so its `full` and `interior` substitutions coincide and it contributes
1691/// nothing — matching the production `edge_vel = 0` short-circuit.
1692pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
1693    phi_jet_right: &Tower4<K1>,
1694    z_right: &Tower4<K>,
1695    phi_jet_left: &Tower4<K1>,
1696    z_left: &Tower4<K>,
1697) -> Tower4<K> {
1698    moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right)
1699        - moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left)
1700}
1701
1702// ── The program seam ─────────────────────────────────────────────────
1703
/// A family's row negative log-likelihood written ONCE over tower scalars.
///
/// This is the single source of truth #932 asks for: the value channel of
/// the returned tower must BE the production row NLL (same branches, same
/// guards, same numerics), and every derivative channel is then exact by
/// construction. The linear Jacobian wiring (coefficients ↔ primaries) is
/// NOT part of this trait — it is family data, not calculus, and stays on
/// the `RowKernel` implementor.
///
/// Implementors must be `Send + Sync` (a supertrait bound of this trait).
/// [`evaluate_program`] — and through it the `derived_*` helpers — calls
/// [`primaries`](Self::primaries) and then [`row_nll`](Self::row_nll) for the
/// same `row`, propagating any `Err` unchanged.
pub trait RowNllProgram<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed the tower).
    /// Entry `a` becomes the value of tower variable `a` in [`evaluate_program`].
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on tower scalars. `p[a]` arrives pre-seeded as
    /// variable `a` at the current primary value; implementations combine
    /// them with `Tower4` arithmetic and per-row data (response, censoring
    /// indicators, offsets) entering as constants.
    fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
}
1725
1726/// Evaluate a program's full tower at the current primaries for one row.
1727///
1728/// One call yields every `RowKernel` calculus channel; callers that need
1729/// several contractions of the same row should hold the returned tower and
1730/// contract repeatedly rather than re-evaluating.
1731pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
1732    prog: &P,
1733    row: usize,
1734) -> Result<Tower4<K>, String> {
1735    let p = prog.primaries(row)?;
1736    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
1737    prog.row_nll(row, &vars)
1738}
1739
1740/// Mechanically derived `row_kernel` channel: `(nll, ∇, H)`.
1741pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
1742    prog: &P,
1743    row: usize,
1744) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1745    let t = evaluate_program(prog, row)?;
1746    Ok((t.v, t.g, t.h))
1747}
1748
1749/// Mechanically derived `row_third_contracted` channel.
1750pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1751    prog: &P,
1752    row: usize,
1753    dir: &[f64; K],
1754) -> Result<[[f64; K]; K], String> {
1755    Ok(evaluate_program(prog, row)?.third_contracted(dir))
1756}
1757
1758/// Mechanically derived `row_fourth_contracted` channel.
1759pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1760    prog: &P,
1761    row: usize,
1762    dir_u: &[f64; K],
1763    dir_v: &[f64; K],
1764) -> Result<[[f64; K]; K], String> {
1765    Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
1766}
1767
1768// ── The generic program seam (#932 scalar cutover) ───────────────────
1769
/// A family's row negative log-likelihood written ONCE over the generic
/// [`crate::jet_scalar::JetScalar`] interface, so the SAME expression can be
/// re-instantiated at whatever order / representation a consumer needs
/// ([`crate::jet_scalar::Order2`] for `(v, g, H)`,
/// [`crate::jet_scalar::OneSeed`] for the contracted third,
/// [`crate::jet_scalar::TwoSeed`] for the contracted fourth, or the full
/// [`Tower4`] for every channel at once).
///
/// This is additive to [`RowNllProgram`] (which is `Tower4`-specialised): a
/// program implementing this generic trait gets the small contracted scalars for
/// free, dissolving the dense-`Tower4<9>` cost objection in the location-scale
/// gates (doc §A.4). An existing `Tower4`-only [`RowNllProgram`] continues to
/// work unchanged; new families should prefer this generic trait.
///
/// Because a `Tower4`-specialised `row_nll` body uses only
/// `add`/`sub`/`mul`/`scale`/`exp`/`ln`/… — all of which this trait also
/// provides — the same body is expressible directly over `S: JetScalar<K>`.
/// A program written that way needs no `Tower4`-specialised method and routes
/// the directional and joint-Hessian gates through the contracted scalars from
/// a single definition.
///
/// Each `generic_*` driver in this file ([`generic_row_kernel`],
/// [`generic_third_contracted`], [`generic_fourth_contracted`],
/// [`generic_full_tower`]) calls [`primaries`](Self::primaries), seeds one
/// scalar per primary, then calls [`row_nll_generic`](Self::row_nll_generic)
/// for the same `row`, propagating any `Err` unchanged.
pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed the scalar).
    /// Entry `a` is the base value of the scalar seeded in slot `a`.
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on a generic jet scalar. `p[a]` arrives pre-seeded
    /// (base value + per-scalar nilpotent directions) by the caller; the body
    /// uses ONLY [`crate::jet_scalar::JetScalar`] ops and per-row data
    /// (response, censoring, offsets) entering as constants.
    fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
        &self,
        row: usize,
        p: &[S; K],
    ) -> Result<S, String>;
}
1807
1808/// Evaluate a generic program at the value/gradient/Hessian scalar
1809/// [`crate::jet_scalar::Order2`], returning `(nll, ∇, H)` — the
1810/// `row_kernel` channel — WITHOUT materialising any third / fourth tensor.
1811///
1812/// This is the production seam for the inner-Newton `(v, g, H)` path: the row
1813/// loss is written ONCE in `row_nll_generic`, and this routes it through the
1814/// cheap order-2 scalar. The single source of truth means the gradient and
1815/// Hessian cannot desync from the value (the #736 / #948 bug genus).
1816pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1817    prog: &P,
1818    row: usize,
1819) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1820    let base = prog.primaries(row)?;
1821    let vars: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|a| {
1822        <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(base[a], a)
1823    });
1824    let s = prog.row_nll_generic(row, &vars)?;
1825    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
1826}
1827
1828/// Evaluate a generic program at the one-seed scalar
1829/// [`crate::jet_scalar::OneSeed`], returning the contracted third
1830/// `Σ_c ℓ_{abc} dir_c` — the `row_third_contracted(dir)` channel — WITHOUT
1831/// materialising the dense `t3` tensor. The contraction direction is folded
1832/// INTO the differentiation by the nilpotent ε seeded with `dir`.
1833pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1834    prog: &P,
1835    row: usize,
1836    dir: &[f64; K],
1837) -> Result<[[f64; K]; K], String> {
1838    let base = prog.primaries(row)?;
1839    let vars: [crate::jet_scalar::OneSeed<K>; K] =
1840        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
1841    let s = prog.row_nll_generic(row, &vars)?;
1842    Ok(s.contracted_third())
1843}
1844
1845/// Evaluate a generic program at the two-seed scalar
1846/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
1847/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `row_fourth_contracted(u, v)` channel —
1848/// WITHOUT materialising the dense `t4` tensor.
1849pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1850    prog: &P,
1851    row: usize,
1852    dir_u: &[f64; K],
1853    dir_v: &[f64; K],
1854) -> Result<[[f64; K]; K], String> {
1855    let base = prog.primaries(row)?;
1856    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
1857        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
1858    let s = prog.row_nll_generic(row, &vars)?;
1859    Ok(s.contracted_fourth())
1860}
1861
1862/// Evaluate a generic program at the full dense [`Tower4`] scalar, returning
1863/// every channel `(v, g, h, t3, t4)` in one pass. Used where the UNCONTRACTED
1864/// third / fourth tensors are needed (the BMS rigid `third_full` / `fourth_full`
1865/// caches): the dense tensors come from the SAME `row_nll_generic` expression
1866/// the order-2 / contracted scalars consume, so there is a single source of
1867/// truth across every channel.
1868pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1869    prog: &P,
1870    row: usize,
1871) -> Result<Tower4<K>, String> {
1872    let base = prog.primaries(row)?;
1873    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(base[a], a));
1874    prog.row_nll_generic(row, &vars)
1875}
1876
1877// ── The RowJet bridge: one row-NLL body over scalar jets AND lane towers ─
1878//
1879// `JetScalar<K>` (jet_scalar.rs) abstracts the SCALAR jets — its `value()`
1880// returns one `f64`, so the `f64x4` lane towers ([`Tower3Lane`] / [`Tower4Lane`])
1881// CANNOT implement it (their value channel is four rows). `compose_unary_with`
1882// exists as an inherent method on BOTH the scalar towers and the lane towers, but
1883// as separate inherent methods, not a shared trait bound — so a row-NLL body
1884// written `<S: JetScalar<K>>` could not be instantiated at `Tower4Lane`, and the
1885// 4-rows-per-pass SIMD batch path could not reuse the single source.
1886//
1887// [`RowJet<K>`] is that shared bound. It exposes exactly the ops a row-NLL body
1888// needs — `constant` / `variable` / `add` / `sub` / `mul` / `scale` / `neg`, the
1889// value-derived `compose_unary_with`, and a per-lane domain `guard` — over BOTH
1890// representations. A blanket impl makes every scalar `JetScalar<K>` a `RowJet<K>`
1891// (so the scalar call sites compile unchanged and bit-identically), and explicit
1892// impls route the `f64x4` lane towers through their existing per-lane methods. A
1893// body written once over `R: RowJet<K>` then instantiates at a scalar jet for the
1894// `(v, g, H)` / contracted-tensor channels AND at a lane tower for the batch.
1895
/// The verdict of a per-lane [`RowJet::guard`] domain check.
///
/// A scalar jet (a [`crate::jet_scalar::JetScalar`] via the blanket impl) carries
/// ONE value, so it reports `lanes == 1` and a one-bit mask. A lane tower
/// ([`Tower3Lane`] / [`Tower4Lane`] over `f64x4`) carries FOUR rows, so it reports
/// `lanes == 4` and one mask bit per lane. The mask lets a batched program bail
/// exactly the offending 4-group to the scalar tail ([`any_failed`](Self::any_failed)),
/// or inspect which lanes tripped ([`lane_failed`](Self::lane_failed)).
///
/// The constructors [`scalar`](Self::scalar) and [`lanes4`](Self::lanes4) never set a
/// mask bit at or above position `lanes`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GuardVerdict {
    /// Number of lanes inspected: 1 for a scalar jet, 4 for an `f64x4` batch.
    lanes: u8,
    /// Bit `i` set ⇔ lane `i` failed the predicate.
    failed_mask: u8,
}

impl GuardVerdict {
    /// A scalar (1-lane) verdict: `pass == true` ⇒ no failure.
    #[inline]
    pub fn scalar(pass: bool) -> Self {
        Self { lanes: 1, failed_mask: if pass { 0 } else { 1 } }
    }
    /// A 4-lane verdict from a per-lane failure mask (bit `i` ⇒ lane `i` failed).
    /// Bits above lane 3 are discarded.
    #[inline]
    pub fn lanes4(failed_mask: u8) -> Self {
        Self { lanes: 4, failed_mask: failed_mask & 0x0f }
    }
    /// Number of active lanes inspected (1 scalar, 4 batch).
    #[inline]
    pub fn lanes(self) -> usize {
        self.lanes as usize
    }
    /// True iff every inspected lane satisfied the predicate.
    #[inline]
    pub fn all_pass(self) -> bool {
        self.failed_mask == 0
    }
    /// True iff at least one inspected lane failed the predicate.
    #[inline]
    pub fn any_failed(self) -> bool {
        self.failed_mask != 0
    }
    /// True iff lane `i` failed the predicate.
    ///
    /// Lanes outside `0..self.lanes()` were never inspected and report `false`.
    /// The range check must come before the shift. Shifting the `u8` mask by 8 or
    /// more overflows: it panics with overflow checks on, and with them off the
    /// shift amount wraps, reporting the verdict of lane `i % 8` instead.
    #[inline]
    pub fn lane_failed(self, i: usize) -> bool {
        i < self.lanes() && (self.failed_mask >> i) & 1 == 1
    }
    /// The raw failure mask (bit `i` ⇒ lane `i` failed).
    #[inline]
    pub fn failed_mask(self) -> u8 {
        self.failed_mask
    }
}
1947
/// Fit a derivative stack of length `N` into length `M`.
///
/// Entry `k` of the result is `s[k]` when `k < N` and `0.0` otherwise:
///
/// * a longer target zero-pads the unseeded high derivatives;
/// * a shorter target drops the unused tail.
///
/// Both directions are total, so an order-`(M−1)` tower reads exactly the channels it
/// needs and never an uninitialised entry. When `N == M` the copy is verbatim; in
/// particular the common `N == 5` stack reaches width-5 towers unchanged.
///
/// Used by the [`RowJet::compose_unary_with`] impls to bridge a program's chosen stack
/// length to each tower's native compose width ([`Tower4Lane`]: 5, [`Tower3Lane`]: 4).
#[inline]
fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
    std::array::from_fn(|k| s.get(k).copied().unwrap_or(0.0))
}
1962
/// The shared row-NLL algebra over BOTH the scalar jets and the `f64x4` lane
/// towers — the bound that lets ONE single-source row-NLL body SIMD-batch 4
/// rows/pass without a dual-source copy (module §"The RowJet bridge").
///
/// Every scalar [`crate::jet_scalar::JetScalar<K>`] is a `RowJet<K>` via the
/// blanket impl below (`Value = f64`), bit-identically to its `JetScalar`
/// methods; [`Tower3Lane`] / [`Tower4Lane`] over `f64x4` are `RowJet<K>` with
/// `Value = [f64; 4]`, routing through their per-lane methods so lane `i` of a
/// batched evaluation is `to_bits`-identical to the scalar evaluation on row `i`.
///
/// NOTE(review): the default transcendental methods at the bottom of this trait
/// must keep their derivative-stack expressions exactly as written, including
/// operand order. The bit-identity claim above relies on them matching the
/// corresponding `JetScalar` stacks term for term.
pub trait RowJet<const K: usize>: Copy {
    /// The value channel(s) seen by [`guard`](Self::guard) and
    /// [`values`](Self::values): a single `f64` on a scalar jet, `[f64; 4]` on an
    /// `f64x4` lane tower.
    type Value: Copy;

    /// A constant (value `c`, all derivatives zero), broadcast to every lane.
    fn constant(c: f64) -> Self;
    /// The seeded primary `slot` at value `x` (unit first derivative in `slot`),
    /// broadcast to every lane. Per-lane-DISTINCT seeding for the batch path is
    /// done by the lane instantiators ([`generic_batched_fourth_tower`] /
    /// [`generic_batched_third_tower`]), which build the tower variables directly
    /// from each row's primaries; this method is for any row-invariant auxiliary
    /// variable a body introduces.
    fn variable(x: f64, slot: usize) -> Self;
    /// The value channel(s): `f64` (scalar) or `[f64; 4]` (lane).
    fn values(&self) -> Self::Value;

    /// Truncated Leibniz `self + o`.
    fn add(&self, o: &Self) -> Self;
    /// Truncated Leibniz `self − o`.
    fn sub(&self, o: &Self) -> Self;
    /// Truncated Leibniz `self · o`.
    fn mul(&self, o: &Self) -> Self;
    /// Multiply every channel by the plain scalar `s`.
    fn scale(&self, s: f64) -> Self;
    /// Negate every channel. Defaults to `scale(-1.0)`; the blanket overrides it
    /// to delegate to [`crate::jet_scalar::JetScalar::neg`].
    fn neg(&self) -> Self {
        self.scale(-1.0)
    }

    /// Faà di Bruno compose with a unary special function whose `[f64; N]`
    /// derivative stack is built from the running base value PER LANE through
    /// `stack_fn`. This is the SHARED-TRAIT version of the `compose_unary_with`
    /// inherent method that already exists on both the scalar towers and the lane
    /// towers: on a scalar jet `stack_fn` is run once at the value; on an `f64x4`
    /// lane tower it is re-run per lane (the four rows carry four distinct base
    /// values), so lane `i` is `to_bits`-identical to the scalar result on row `i`.
    /// Making it a trait method is precisely what lets a body written once over
    /// `R: RowJet<K>` instantiate at the batch towers. `N` is widened/narrowed to
    /// the tower's native width by [`resize_stack`] (`N == 5` is a verbatim copy).
    ///
    /// Stack layout: entry `k` is the `k`-th derivative of the unary function at
    /// the base value (entry 0 is the function value itself).
    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self;

    /// Per-lane domain guard: evaluate `pred` on each active lane's value channel
    /// and report which lanes failed (see [`GuardVerdict`]). A scalar jet checks
    /// its one value; a lane tower checks all four. Lets a batched program detect
    /// an out-of-domain row in a 4-group and bail that group to the scalar tail.
    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict;

    /// Per-lane scale: multiply every channel by the per-lane factor `s`
    /// ([`Self::Value`]). On a scalar jet `Self::Value = f64`, so this is exactly
    /// [`scale`](Self::scale) and the scalar call sites stay BIT-IDENTICAL when
    /// `.scale(x)` is rewritten to `.scale_rows(x)`; on an `f64x4` lane tower
    /// `Self::Value = [f64; 4]` and lane `i` is multiplied by `s[i]`. This is the
    /// primitive that lets a batched body carry CONTINUOUS per-row data — the
    /// survival `covariance_ones` / `z_sum` / observation-weight `wi` factors that
    /// enter the jet algebra as `.scale(per_row_value)` and that the single-`f64`
    /// [`scale`](Self::scale) would broadcast wrongly across the four rows. Build
    /// `s` from the lane→row map with [`pack_rows`](Self::pack_rows).
    fn scale_rows(&self, s: Self::Value) -> Self;

    /// Gather a per-lane auxiliary datum from the lane→row map `rows`: `value_of(r)`
    /// is evaluated for each active lane's row and packed into [`Self::Value`] (a
    /// single `f64` on a scalar jet, `[f64; 4]` on an `f64x4` lane tower). This is
    /// how a body written once over [`RowJet`] feeds per-row CONTINUOUS data (the
    /// arguments to [`scale_rows`](Self::scale_rows)) into the batch path without
    /// knowing the concrete representation: the program holds the per-row data and
    /// the caller threads `rows` (length 1 scalar, length 4 batch) into
    /// [`RowNllProgramRowJet::row_nll`], so the body writes
    /// `x.scale_rows(R::pack_rows(rows, |r| self.cov(r)))`. A multiplicative weight
    /// buried in a `compose_unary_with` stack is pulled out the same way:
    /// `x.compose_unary_with(|u| stack(u, 1.0)).scale_rows(R::pack_rows(rows, |r| self.wi(r)))`.
    /// (Binary per-row branches such as the event indicator `di` are kept
    /// lane-uniform by grouping and the [`guard`](Self::guard) bail, not packed.)
    ///
    /// The impls in this file index `rows` directly (`rows[0]` for a scalar jet,
    /// `rows[0]..=rows[3]` for [`Tower4Lane`]). A `rows` slice shorter than the
    /// lane count therefore panics.
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;

    // ── value-derived transcendental conveniences ───────────────────────
    // Each routes through `compose_unary_with` with the SAME derivative stack the
    // corresponding `JetScalar` method uses, so on a scalar jet (blanket) the
    // result is bit-identical to the `JetScalar` method, and on a lane tower lane
    // `i` is bit-identical to the scalar result on row `i`.

    /// `e^self`.
    fn exp(&self) -> Self {
        self.compose_unary_with(|u| {
            // e^u is its own derivative at every order.
            let e = u.exp();
            [e, e, e, e, e]
        })
    }
    /// `ln(self)`. Caller guarantees positivity.
    fn ln(&self) -> Self {
        self.compose_unary_with(|u| {
            // [ln u, 1/u, −1/u², 2/u³, −6/u⁴]
            let r = 1.0 / u;
            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
        })
    }
    /// `√self`. Caller guarantees positivity.
    fn sqrt(&self) -> Self {
        self.compose_unary_with(|u| {
            // [u^{1/2}, ½u^{−1/2}, −¼u^{−3/2}, ⅜u^{−5/2}, −(15/16)u^{−7/2}],
            // each written in terms of s = √u.
            let s = u.sqrt();
            [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
        })
    }
    /// `1/self`.
    fn recip(&self) -> Self {
        self.compose_unary_with(|u| {
            // [1/u, −1/u², 2/u³, −6/u⁴, 24/u⁵]
            let r = 1.0 / u;
            let r2 = r * r;
            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
        })
    }
    /// `self^a` for real `a`. Caller guarantees a positive base.
    fn powf(&self, a: f64) -> Self {
        // k-th entry: a(a−1)…(a−k+1) · u^{a−k}, for k = 0..=4.
        self.compose_unary_with(move |u| {
            [
                u.powf(a),
                a * u.powf(a - 1.0),
                a * (a - 1.0) * u.powf(a - 2.0),
                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
            ]
        })
    }
    /// `ln Γ(self)`. Caller guarantees a positive argument.
    fn ln_gamma(&self) -> Self {
        self.compose_unary_with(ln_gamma_derivative_stack)
    }
    /// `ψ(self)` (digamma). Caller guarantees a positive argument.
    fn digamma(&self) -> Self {
        self.compose_unary_with(digamma_derivative_stack)
    }
}
2105
2106/// Blanket: every scalar [`crate::jet_scalar::JetScalar<K>`] is a [`RowJet<K>`]
2107/// with `Value = f64`. Each op delegates to the identical `JetScalar` method, so
2108/// the existing scalar call sites compile UNCHANGED and bit-identically — the
2109/// bridge adds the lane representation without churning the scalar path. (The
2110/// concrete lane impls below cannot overlap this: [`Tower3Lane`] / [`Tower4Lane`]
2111/// are local types that do not implement `JetScalar`, and the orphan rule forbids
2112/// any downstream impl, so the coherence checker proves the impls disjoint.)
2113impl<const K: usize, S: crate::jet_scalar::JetScalar<K>> RowJet<K> for S {
2114    type Value = f64;
2115    #[inline]
2116    fn constant(c: f64) -> Self {
2117        <S as crate::jet_scalar::JetScalar<K>>::constant(c)
2118    }
2119    #[inline]
2120    fn variable(x: f64, slot: usize) -> Self {
2121        <S as crate::jet_scalar::JetScalar<K>>::variable(x, slot)
2122    }
2123    #[inline]
2124    fn values(&self) -> f64 {
2125        crate::jet_scalar::JetScalar::value(self)
2126    }
2127    #[inline]
2128    fn add(&self, o: &Self) -> Self {
2129        crate::jet_scalar::JetScalar::add(self, o)
2130    }
2131    #[inline]
2132    fn sub(&self, o: &Self) -> Self {
2133        crate::jet_scalar::JetScalar::sub(self, o)
2134    }
2135    #[inline]
2136    fn mul(&self, o: &Self) -> Self {
2137        crate::jet_scalar::JetScalar::mul(self, o)
2138    }
2139    #[inline]
2140    fn scale(&self, s: f64) -> Self {
2141        crate::jet_scalar::JetScalar::scale(self, s)
2142    }
2143    #[inline]
2144    fn neg(&self) -> Self {
2145        crate::jet_scalar::JetScalar::neg(self)
2146    }
2147    #[inline]
2148    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2149        crate::jet_scalar::JetScalar::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2150    }
2151    #[inline]
2152    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2153        GuardVerdict::scalar(pred(crate::jet_scalar::JetScalar::value(self)))
2154    }
2155    #[inline]
2156    fn scale_rows(&self, s: f64) -> Self {
2157        // `Value == f64`, so per-lane scale is exactly `scale` — the rewrite
2158        // `.scale(x)` → `.scale_rows(x)` is bit-identical on the scalar path.
2159        crate::jet_scalar::JetScalar::scale(self, s)
2160    }
2161    #[inline]
2162    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
2163        value_of(rows[0])
2164    }
2165}
2166
2167/// The `f64x4` lane [`Tower4Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2168/// routing each op through its existing per-lane method. Lane `i` of a batched
2169/// evaluation is `to_bits`-identical to the scalar [`Tower4`] evaluation on row
2170/// `i` (the per-lane methods are term-for-term lifts of the scalar tower).
2171impl<const K: usize> RowJet<K> for Tower4Lane<wide::f64x4, K> {
2172    type Value = [f64; 4];
2173    #[inline]
2174    fn constant(c: f64) -> Self {
2175        Tower4Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2176    }
2177    #[inline]
2178    fn variable(x: f64, slot: usize) -> Self {
2179        Tower4Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2180    }
2181    #[inline]
2182    fn values(&self) -> [f64; 4] {
2183        self.v.to_array()
2184    }
2185    #[inline]
2186    fn add(&self, o: &Self) -> Self {
2187        Tower4Lane::add(self, o)
2188    }
2189    #[inline]
2190    fn sub(&self, o: &Self) -> Self {
2191        Tower4Lane::sub(self, o)
2192    }
2193    #[inline]
2194    fn mul(&self, o: &Self) -> Self {
2195        Tower4Lane::mul(self, o)
2196    }
2197    #[inline]
2198    fn scale(&self, s: f64) -> Self {
2199        Tower4Lane::scale(self, s)
2200    }
2201    #[inline]
2202    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2203        Tower4Lane::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2204    }
2205    #[inline]
2206    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2207        let vals = self.v.to_array();
2208        let mut mask = 0u8;
2209        for (i, &v) in vals.iter().enumerate() {
2210            if !pred(v) {
2211                mask |= 1 << i;
2212            }
2213        }
2214        GuardVerdict::lanes4(mask)
2215    }
2216    #[inline]
2217    fn scale_rows(&self, s: [f64; 4]) -> Self {
2218        // True per-lane scale: lane `i` of every channel is multiplied by `s[i]`,
2219        // so lane `i` matches the scalar `Tower4::scale(s[i])` on row `i`.
2220        let sl = wide::f64x4::new(s);
2221        let mut out = *self;
2222        out.v = self.v * sl;
2223        for i in 0..K {
2224            out.g[i] = self.g[i] * sl;
2225            for j in 0..K {
2226                out.h[i][j] = self.h[i][j] * sl;
2227                for k in 0..K {
2228                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
2229                    for l in 0..K {
2230                        out.t4[i][j][k][l] = self.t4[i][j][k][l] * sl;
2231                    }
2232                }
2233            }
2234        }
2235        out
2236    }
2237    #[inline]
2238    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2239        [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2240    }
2241}
2242
2243/// The `f64x4` lane [`Tower3Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2244/// the order-≤3 sibling of the [`Tower4Lane`] impl. A body that uses `N == 5`
2245/// stacks drops the (unused) fourth-derivative entry here, matching the scalar
2246/// [`Tower3`] which also carries only up to the third tensor.
2247impl<const K: usize> RowJet<K> for Tower3Lane<wide::f64x4, K> {
2248    type Value = [f64; 4];
2249    #[inline]
2250    fn constant(c: f64) -> Self {
2251        Tower3Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2252    }
2253    #[inline]
2254    fn variable(x: f64, slot: usize) -> Self {
2255        Tower3Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2256    }
2257    #[inline]
2258    fn values(&self) -> [f64; 4] {
2259        self.v.to_array()
2260    }
2261    #[inline]
2262    fn add(&self, o: &Self) -> Self {
2263        Tower3Lane::add(self, o)
2264    }
2265    #[inline]
2266    fn sub(&self, o: &Self) -> Self {
2267        Tower3Lane::sub(self, o)
2268    }
2269    #[inline]
2270    fn mul(&self, o: &Self) -> Self {
2271        Tower3Lane::mul(self, o)
2272    }
2273    #[inline]
2274    fn scale(&self, s: f64) -> Self {
2275        Tower3Lane::scale(self, s)
2276    }
2277    #[inline]
2278    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2279        Tower3Lane::compose_unary_with(self, |u| resize_stack::<N, 4>(stack_fn(u)))
2280    }
2281    #[inline]
2282    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2283        let vals = self.v.to_array();
2284        let mut mask = 0u8;
2285        for (i, &v) in vals.iter().enumerate() {
2286            if !pred(v) {
2287                mask |= 1 << i;
2288            }
2289        }
2290        GuardVerdict::lanes4(mask)
2291    }
2292    #[inline]
2293    fn scale_rows(&self, s: [f64; 4]) -> Self {
2294        let sl = wide::f64x4::new(s);
2295        let mut out = *self;
2296        out.v = self.v * sl;
2297        for i in 0..K {
2298            out.g[i] = self.g[i] * sl;
2299            for j in 0..K {
2300                out.h[i][j] = self.h[i][j] * sl;
2301                for k in 0..K {
2302                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
2303                }
2304            }
2305        }
2306        out
2307    }
2308    #[inline]
2309    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2310        [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2311    }
2312}
2313
/// A family's row negative log-likelihood written ONCE over the [`RowJet`]
/// bridge, so the SAME body instantiates at the scalar jets (for the `(v, g, H)`
/// and contracted-tensor channels) AND at the `f64x4` lane towers (for the
/// 4-rows-per-pass SIMD batch). This is the lane-capable successor to
/// [`RowNllProgramGeneric`]: a body written here gets the scalar channels through
/// [`rowjet_row_kernel`] / [`rowjet_third_contracted`] / [`rowjet_fourth_contracted`]
/// and the batched channels through [`generic_batched_fourth_tower`] /
/// [`generic_batched_third_tower`], all from a single source.
///
/// Implementors must be `Send + Sync` (trait bound).
pub trait RowNllProgramRowJet<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed each lane).
    ///
    /// # Errors
    /// Implementation-defined. The evaluation helpers in this module propagate
    /// the `String` unchanged with `?`.
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on the [`RowJet`] bridge. `rows` is the lane→row map
    /// (length 1 for a scalar instantiation, length 4 for a batch); `p[a]` arrives
    /// pre-seeded by the caller (base value plus, for the directional scalars, the
    /// nilpotent contraction directions). The body uses ONLY [`RowJet`] ops and
    /// per-row data entering through `rows`/`self` as constants.
    ///
    /// # Errors
    /// Implementation-defined; propagated unchanged by the evaluation helpers.
    fn row_nll<R: RowJet<K>>(&self, rows: &[usize], p: &[R; K]) -> Result<R, String>;
}
2336
2337/// Evaluate a [`RowNllProgramRowJet`] at the value/gradient/Hessian scalar
2338/// [`crate::jet_scalar::Order2`] (the `(v, g, H)` inner-Newton channel) — the
2339/// `RowJet` twin of [`generic_row_kernel`].
2340pub fn rowjet_row_kernel<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2341    prog: &P,
2342    row: usize,
2343) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
2344    let base = prog.primaries(row)?;
2345    let vars: [crate::jet_scalar::Order2<K>; K] =
2346        std::array::from_fn(|a| <crate::jet_scalar::Order2<K> as RowJet<K>>::variable(base[a], a));
2347    let s = prog.row_nll(&[row], &vars)?;
2348    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
2349}
2350
2351/// Evaluate a [`RowNllProgramRowJet`] at the one-seed scalar
2352/// [`crate::jet_scalar::OneSeed`], returning the contracted third
2353/// `Σ_c ℓ_{abc} dir_c` — the `RowJet` twin of [`generic_third_contracted`].
2354pub fn rowjet_third_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2355    prog: &P,
2356    row: usize,
2357    dir: &[f64; K],
2358) -> Result<[[f64; K]; K], String> {
2359    let base = prog.primaries(row)?;
2360    let vars: [crate::jet_scalar::OneSeed<K>; K] =
2361        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
2362    let s = prog.row_nll(&[row], &vars)?;
2363    Ok(s.contracted_third())
2364}
2365
2366/// Evaluate a [`RowNllProgramRowJet`] at the two-seed scalar
2367/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
2368/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `RowJet` twin of [`generic_fourth_contracted`].
2369pub fn rowjet_fourth_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2370    prog: &P,
2371    row: usize,
2372    dir_u: &[f64; K],
2373    dir_v: &[f64; K],
2374) -> Result<[[f64; K]; K], String> {
2375    let base = prog.primaries(row)?;
2376    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
2377        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
2378    let s = prog.row_nll(&[row], &vars)?;
2379    Ok(s.contracted_fourth())
2380}
2381
2382/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower4Batch`],
2383/// computing the FULL `(v, g, H, t3, t4)` for FOUR rows in one SIMD pass — the
2384/// lane twin of [`generic_full_tower`]. Each of the four lanes is seeded with its
2385/// own row's primaries, so [`Tower4Batch::lane`]`(i)` is `to_bits`-identical to
2386/// the scalar [`generic_full_tower`] on `rows[i]`.
2387pub fn generic_batched_fourth_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2388    prog: &P,
2389    rows: [usize; 4],
2390) -> Result<Tower4Batch<K>, String> {
2391    let bases: [[f64; K]; 4] = [
2392        prog.primaries(rows[0])?,
2393        prog.primaries(rows[1])?,
2394        prog.primaries(rows[2])?,
2395        prog.primaries(rows[3])?,
2396    ];
2397    let vars: [Tower4Batch<K>; K] = std::array::from_fn(|a| {
2398        let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2399        Tower4Batch::variable(lane_vals, a)
2400    });
2401    prog.row_nll(&rows, &vars)
2402}
2403
2404/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower3Batch`],
2405/// computing `(v, g, H, t3)` for FOUR rows in one SIMD pass — the order-≤3 lane
2406/// twin of [`generic_full_tower`]. [`Tower3Batch::lane`]`(i)` is
2407/// `to_bits`-identical to the order-≤3 scalar evaluation on `rows[i]`.
2408pub fn generic_batched_third_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2409    prog: &P,
2410    rows: [usize; 4],
2411) -> Result<Tower3Batch<K>, String> {
2412    let bases: [[f64; K]; 4] = [
2413        prog.primaries(rows[0])?,
2414        prog.primaries(rows[1])?,
2415        prog.primaries(rows[2])?,
2416        prog.primaries(rows[3])?,
2417    ];
2418    let vars: [Tower3Batch<K>; K] = std::array::from_fn(|a| {
2419        let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2420        Tower3Batch::variable(lane_vals, a)
2421    });
2422    prog.row_nll(&rows, &vars)
2423}
2424
2425// ── The oracle ───────────────────────────────────────────────────────
2426
/// One row's worth of hand-written kernel outputs, as claimed by a
/// `RowKernel` implementation, packaged for verification against the
/// tower truth. Plain data (no trait coupling) so any kernel — whatever
/// its visibility — can be audited from its own test module.
///
/// `Debug` lets a failing oracle test print the full claim set; `Clone` lets
/// one claim set be reused across tolerance sweeps.
#[derive(Debug, Clone)]
pub struct KernelChannels<const K: usize> {
    /// Claimed NLL value from `row_kernel`.
    pub value: f64,
    /// Claimed gradient from `row_kernel`.
    pub gradient: [f64; K],
    /// Claimed Hessian from `row_kernel`.
    pub hessian: [[f64; K]; K],
    /// Claimed `row_third_contracted(dir)` outputs as `(dir, claim)` pairs.
    pub third: Vec<([f64; K], [[f64; K]; K])>,
    /// Claimed `row_fourth_contracted(u, v)` outputs as `(u, v, claim)`.
    pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
}
2443
2444/// Channel-by-channel audit of a hand-written kernel against the
2445/// single-expression tower truth. Returns `Err` naming the first channel,
2446/// index, claimed and true values on disagreement — designed as the body
2447/// of the per-family CI oracle tests (#932 deployment step 2).
2448///
2449/// Tolerance is PER ENTRY, mixed absolute/relative: each comparison uses
2450/// `|claim − truth| ≤ atol + rel_tol · max(|claim|, |truth|)`. The absolute
2451/// floor `atol = rel_tol` lets exact-zero entries of structurally sparse
2452/// towers pass without demanding bit-equality, while a tiny cross-block
2453/// entry dropped next to a huge one is still caught (it is NOT measured
2454/// against the largest entry of the whole channel — there is no per-channel
2455/// magnitude floor). Genuine sign flips (#736) and dropped channels are loud.
2456///
2457/// Non-finite handling is strict: a NaN on either side always fails; an
2458/// infinity passes only when both sides are the SAME signed infinity.
2459pub fn verify_kernel_channels<const K: usize>(
2460    tower: &Tower4<K>,
2461    claims: &KernelChannels<K>,
2462    rel_tol: f64,
2463) -> Result<(), String> {
2464    // Absolute floor: reuse rel_tol so a single knob controls both the
2465    // relative band and the absolute floor for entries near zero.
2466    let atol = rel_tol;
2467    let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
2468        // Non-finite values never silently pass the algebraic comparison
2469        // below (any comparison with NaN is false). Handle them explicitly:
2470        // NaN on either side always errs; an infinity passes only if both
2471        // sides are the identical signed infinity.
2472        if !claim.is_finite() || !truth.is_finite() {
2473            let agree = claim.is_infinite()
2474                && truth.is_infinite()
2475                && claim.is_sign_positive() == truth.is_sign_positive();
2476            if agree {
2477                return Ok(());
2478            }
2479            return Err(format!(
2480                "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
2481            ));
2482        }
2483        let band = atol + rel_tol * claim.abs().max(truth.abs());
2484        if (claim - truth).abs() > band {
2485            return Err(format!(
2486                "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
2487            ));
2488        }
2489        Ok(())
2490    };
2491
2492    check("value", claims.value, tower.v)?;
2493
2494    for a in 0..K {
2495        check(&format!("gradient[{a}]"), claims.gradient[a], tower.g[a])?;
2496    }
2497
2498    for a in 0..K {
2499        for b in 0..K {
2500            check(
2501                &format!("hessian[{a}][{b}]"),
2502                claims.hessian[a][b],
2503                tower.h[a][b],
2504            )?;
2505        }
2506    }
2507
2508    for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
2509        let truth = tower.third_contracted(dir);
2510        for a in 0..K {
2511            for b in 0..K {
2512                check(
2513                    &format!("third[{t_idx}][{a}][{b}]"),
2514                    claim[a][b],
2515                    truth[a][b],
2516                )?;
2517            }
2518        }
2519    }
2520
2521    for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
2522        let truth = tower.fourth_contracted(u, w);
2523        for a in 0..K {
2524            for b in 0..K {
2525                check(
2526                    &format!("fourth[{f_idx}][{a}][{b}]"),
2527                    claim[a][b],
2528                    truth[a][b],
2529                )?;
2530            }
2531        }
2532    }
2533
2534    Ok(())
2535}
2536
2537// ===========================================================================
2538// SIMD row-batched towers (#1151 follow-up): Tower3Lane / Tower4Lane
2539// ===========================================================================
2540//
2541// `Tower{3,4}Lane<L: Lane, K>` re-type every channel of `Tower{3,4}<K>` from a
2542// scalar `f64` to a SIMD lane field `L`. With `L = wide::f64x4` one instance
2543// carries FOUR rows at once, so a per-row kernel (BMS `row_nll`, survival
2544// `row_kernel`, `marginal_slope` `build_row_*_towers`) can evaluate 4 rows per
2545// vector pass instead of one per scalar pass.
2546//
2547// Every floating-point op is a DIRECT, term-for-term lift of the scalar
2548// `Tower{3,4}<K>` body — `a * b` -> `a.mul(b)`, `a + b` -> `a.add(b)`, a literal
2549// `c` -> `L::splat(c)` — in the SAME accumulation order. `wide::f64x4`
2550// add/sub/mul are lane-wise IEEE-754 ops with NO fused-multiply-add (Rust
2551// performs no fp-contraction), so lane `i` of any channel of a
2552// `Tower{3,4}Lane<wide::f64x4, K>` is `to_bits`-IDENTICAL to the scalar
2553// `Tower{3,4}<K>` channel computed on row `i` — exactly the structural
2554// bit-identity the existing [`crate::jet_scalar::Order2Lane`] relies on. Proven
2555// by the in-tree `batch_tests` (real `wide::f64x4`) and a standalone
2556// f64x4-model oracle, `K ∈ {2,3,4,9}`.
2557//
2558// Only the pure-arithmetic ops are lifted (the transcendental `exp`/`ln`/`sqrt`/
2559// `…` route through scalar libm, which has no `f64x4` form; consumers build the
2560// per-lane derivative stack scalar-side and feed it to `compose_unary([L; _])`,
2561// exactly as the scalar path already does).
2562
2563use crate::jet_scalar::Lane;
2564
/// Lane-batched [`Tower4`]: value / gradient / Hessian / 3rd / 4th tensors
/// carried in a SIMD field `L`. `Tower4Lane<f64x4, K>` lane `i` is
/// `to_bits`-identical to [`Tower4<K>`] on row `i`.
///
/// Each tensor has one slot per ordered index tuple (`K^n` entries for order
/// `n`). Use [`Tower4Lane::lane`] to pull a single row out as a scalar tower.
#[derive(Clone, Copy)]
pub struct Tower4Lane<L: Lane, const K: usize> {
    /// Value channel (one entry per lane/row).
    pub v: L,
    /// Gradient `∂/∂p_a`, indexed `g[a]`.
    pub g: [L; K],
    /// Hessian `∂²/∂p_a∂p_b`, indexed `h[a][b]`.
    pub h: [[L; K]; K],
    /// Third tensor `∂³/∂p_a∂p_b∂p_c`, indexed `t3[a][b][c]`.
    pub t3: [[[L; K]; K]; K],
    /// Fourth tensor `∂⁴/∂p_a∂p_b∂p_c∂p_d`, indexed `t4[a][b][c][d]`.
    pub t4: [[[[L; K]; K]; K]; K],
}
2581
/// The 4-rows-per-pass batched [`Tower4`] (`wide::f64x4` lanes): lane `i`
/// holds row `i` of the batch.
pub type Tower4Batch<const K: usize> = Tower4Lane<wide::f64x4, K>;
2584
impl<L: Lane, const K: usize> Tower4Lane<L, K> {
    /// All-zero tower (every channel `+0.0` in every lane).
    #[inline]
    pub fn zero() -> Self {
        let z = L::splat(0.0);
        Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K], t4: [[[[z; K]; K]; K]; K] }
    }
    /// Constant `c` (per lane): value channel only.
    #[inline]
    pub fn constant(c: L) -> Self {
        let mut o = Self::zero();
        o.v = c;
        o
    }
    /// Seeded variable `p_idx` at per-lane `value`: unit first derivative in
    /// slot `idx` (mirrors [`Tower4::variable`]).
    ///
    /// # Panics
    /// Panics if `idx >= K`.
    #[inline]
    pub fn variable(value: L, idx: usize) -> Self {
        let mut o = Self::constant(value);
        o.g[idx] = L::splat(1.0);
        o
    }
    /// Extract lane `i` as a scalar [`Tower4<K>`] (channel-for-channel).
    #[inline]
    pub fn lane(&self, i: usize) -> Tower4<K> {
        let mut out = Tower4::<K>::zero();
        out.v = self.v.lane(i);
        for a in 0..K {
            out.g[a] = self.g[a].lane(i);
            for b in 0..K {
                out.h[a][b] = self.h[a][b].lane(i);
                for c in 0..K {
                    out.t3[a][b][c] = self.t3[a][b][c].lane(i);
                    for d in 0..K {
                        out.t4[a][b][c][d] = self.t4[a][b][c][d].lane(i);
                    }
                }
            }
        }
        out
    }
    /// Per-channel lane-wise `self + o` (mirrors `Tower4` `Add`).
    #[inline]
    pub fn add(&self, o: &Self) -> Self {
        let mut out = *self;
        out.v = self.v.add(o.v);
        for i in 0..K {
            out.g[i] = self.g[i].add(o.g[i]);
            for j in 0..K {
                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
                for k in 0..K {
                    out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
                    for l in 0..K {
                        out.t4[i][j][k][l] = self.t4[i][j][k][l].add(o.t4[i][j][k][l]);
                    }
                }
            }
        }
        out
    }
    /// Per-channel lane-wise `self - o` (mirrors `Tower4` `Sub`).
    #[inline]
    pub fn sub(&self, o: &Self) -> Self {
        let mut out = *self;
        out.v = self.v.sub(o.v);
        for i in 0..K {
            out.g[i] = self.g[i].sub(o.g[i]);
            for j in 0..K {
                out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
                for k in 0..K {
                    out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
                    for l in 0..K {
                        out.t4[i][j][k][l] = self.t4[i][j][k][l].sub(o.t4[i][j][k][l]);
                    }
                }
            }
        }
        out
    }
    /// Multiply every channel by the plain scalar `s` (mirrors `Tower4::scale`).
    /// The same `s` is applied to every lane.
    #[inline]
    pub fn scale(&self, s: f64) -> Self {
        let sl = L::splat(s);
        let mut out = *self;
        out.v = self.v.mul(sl);
        for i in 0..K {
            out.g[i] = self.g[i].mul(sl);
            for j in 0..K {
                out.h[i][j] = self.h[i][j].mul(sl);
                for k in 0..K {
                    out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
                    for l in 0..K {
                        out.t4[i][j][k][l] = self.t4[i][j][k][l].mul(sl);
                    }
                }
            }
        }
        out
    }
    /// Leibniz product `self · o`, term-for-term lift of [`Tower4::mul`].
    ///
    /// Every order-`n` entry is the sum over subsets `S` of its index set of
    /// `a_S · b_{complement of S}`: 2, 4, 8 and 16 terms for orders 1–4.
    /// The summation order must stay identical to the scalar tower's, or the
    /// lane results stop being `to_bits`-identical to it.
    #[inline]
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v.mul(b.v);
        for i in 0..K {
            // Accumulators start from +0.0 (not from the first term) to mirror
            // the scalar accumulation exactly; `0.0 + x` maps -0.0 to +0.0.
            let mut acc = L::splat(0.0);
            acc = acc.add(a.v.mul(b.g[i]));
            acc = acc.add(a.g[i].mul(b.v));
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = L::splat(0.0);
                acc = acc.add(a.v.mul(b.h[i][j]));
                acc = acc.add(a.g[i].mul(b.g[j]));
                acc = acc.add(a.g[j].mul(b.g[i]));
                acc = acc.add(a.h[i][j].mul(b.v));
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let mut acc = L::splat(0.0);
                    acc = acc.add(a.v.mul(b.t3[i][j][k]));
                    acc = acc.add(a.g[i].mul(b.h[j][k]));
                    acc = acc.add(a.g[j].mul(b.h[i][k]));
                    acc = acc.add(a.h[i][j].mul(b.g[k]));
                    acc = acc.add(a.g[k].mul(b.h[i][j]));
                    acc = acc.add(a.h[i][k].mul(b.g[j]));
                    acc = acc.add(a.h[j][k].mul(b.g[i]));
                    acc = acc.add(a.t3[i][j][k].mul(b.v));
                    out.t3[i][j][k] = acc;
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // 16 terms, one per subset S ⊆ {i,j,k,l}: a_S · b_{rest}.
                        let mut acc = L::splat(0.0);
                        acc = acc.add(a.v.mul(b.t4[i][j][k][l]));
                        acc = acc.add(a.g[i].mul(b.t3[j][k][l]));
                        acc = acc.add(a.g[j].mul(b.t3[i][k][l]));
                        acc = acc.add(a.h[i][j].mul(b.h[k][l]));
                        acc = acc.add(a.g[k].mul(b.t3[i][j][l]));
                        acc = acc.add(a.h[i][k].mul(b.h[j][l]));
                        acc = acc.add(a.h[j][k].mul(b.h[i][l]));
                        acc = acc.add(a.t3[i][j][k].mul(b.g[l]));
                        acc = acc.add(a.g[l].mul(b.t3[i][j][k]));
                        acc = acc.add(a.h[i][l].mul(b.h[j][k]));
                        acc = acc.add(a.h[j][l].mul(b.h[i][k]));
                        acc = acc.add(a.t3[i][j][l].mul(b.g[k]));
                        acc = acc.add(a.h[k][l].mul(b.h[i][j]));
                        acc = acc.add(a.t3[i][k][l].mul(b.g[j]));
                        acc = acc.add(a.t3[j][k][l].mul(b.g[i]));
                        acc = acc.add(a.t4[i][j][k][l].mul(b.v));
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }
    /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
    /// [`Tower4::compose_unary`]. `d = [f, f′, f″, f‴, f⁗]` packed per lane
    /// (build via [`Lane::unary5`] from the scalar special-function stack).
    ///
    /// Each order-`n` entry sums one term per set partition of its index set,
    /// weighted by `f` derivative of order = number of blocks. Order 4 has
    /// 15 terms: 1 `f′`, 7 `f″`, 6 `f‴`, 1 `f⁗`.
    #[inline]
    pub fn compose_unary(&self, d: [L; 5]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(self.g[i]));
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = L::splat(0.0);
                acc = acc.add(d[1].mul(self.h[i][j]));
                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let mut acc = L::splat(0.0);
                    acc = acc.add(d[1].mul(self.t3[i][j][k]));
                    acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
                    acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
                    acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
                    acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
                    out.t3[i][j][k] = acc;
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // Operand order inside each product matters for bit
                        // identity (fp `mul` is not associative) — keep as is.
                        let mut acc = L::splat(0.0);
                        acc = acc.add(d[1].mul(self.t4[i][j][k][l]));
                        acc = acc.add(d[2].mul(self.t3[i][j][k]).mul(self.g[l]));
                        acc = acc.add(d[2].mul(self.t3[i][j][l]).mul(self.g[k]));
                        acc = acc.add(d[2].mul(self.h[i][j]).mul(self.h[k][l]));
                        acc = acc.add(d[3].mul(self.h[i][j]).mul(self.g[k]).mul(self.g[l]));
                        acc = acc.add(d[2].mul(self.t3[i][k][l]).mul(self.g[j]));
                        acc = acc.add(d[2].mul(self.h[i][k]).mul(self.h[j][l]));
                        acc = acc.add(d[3].mul(self.h[i][k]).mul(self.g[j]).mul(self.g[l]));
                        acc = acc.add(d[2].mul(self.h[i][l]).mul(self.h[j][k]));
                        acc = acc.add(d[2].mul(self.g[i]).mul(self.t3[j][k][l]));
                        acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][k]).mul(self.g[l]));
                        acc = acc.add(d[3].mul(self.h[i][l]).mul(self.g[j]).mul(self.g[k]));
                        acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][l]).mul(self.g[k]));
                        acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.h[k][l]));
                        acc = acc.add(d[4].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]).mul(self.g[l]));
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }
    /// Compose with a unary special-function whose `[f64; 5]` derivative stack is
    /// built from the base value through `stack_fn`, evaluated PER LANE — the
    /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
    /// seam (the SIMD twin of [`Tower4::compose_unary_with`]).
    ///
    /// Each of the four lanes carries a DISTINCT base value, so the scalar
    /// `stack_fn` is run once per lane at that lane's own value (via
    /// [`Lane::unary_with`]) and the `[f64; 5]` results are packed into `[L; 5]`;
    /// the composition is then the existing per-lane [`Self::compose_unary`].
    /// Because `unary_with` runs the identical scalar closure per lane and
    /// `compose_unary` is a term-for-term lift of the scalar tower, lane `i` of
    /// the result is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`
    /// — which is exactly what lets a row program written against the scalar
    /// [`Tower4::compose_unary_with`] seam re-instantiate, unchanged, at `f64x4`.
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
        self.compose_unary(self.v.unary_with(stack_fn))
    }

    /// Single-active-slot fast path, term-for-term lift of
    /// [`Tower4::compose_unary_single_slot`] (only the 5 diagonal channels).
    ///
    /// Only `g[slot]`, `h[slot][slot]`, `t3[slot]^3` and `t4[slot]^4` are read,
    /// and only those output entries are written; everything else stays `+0.0`.
    /// The caller is responsible for `self` actually depending on `slot` alone.
    /// The diagonal sums repeat [`Self::compose_unary`]'s 15 terms verbatim (not
    /// collapsed into 3·/4·/6· multiples), in the same order and operand
    /// order as the full loop at `i = j = k = l = slot`, so the diagonal entries
    /// reproduce the full composition's arithmetic exactly.
    ///
    /// # Panics
    /// Panics if `slot >= K`.
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [L; 5], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        let t4 = self.t4[s][s][s][s];
        out.v = d[0];
        out.g[s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(g));
            acc
        };
        out.h[s][s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(h));
            acc = acc.add(d[2].mul(g).mul(g));
            acc
        };
        out.t3[s][s][s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(t3));
            acc = acc.add(d[2].mul(h).mul(g));
            acc = acc.add(d[2].mul(h).mul(g));
            acc = acc.add(d[2].mul(g).mul(h));
            acc = acc.add(d[3].mul(g).mul(g).mul(g));
            acc
        };
        out.t4[s][s][s][s] = {
            let mut acc = L::splat(0.0);
            acc = acc.add(d[1].mul(t4));
            acc = acc.add(d[2].mul(t3).mul(g));
            acc = acc.add(d[2].mul(t3).mul(g));
            acc = acc.add(d[2].mul(h).mul(h));
            acc = acc.add(d[3].mul(h).mul(g).mul(g));
            acc = acc.add(d[2].mul(t3).mul(g));
            acc = acc.add(d[2].mul(h).mul(h));
            acc = acc.add(d[3].mul(h).mul(g).mul(g));
            acc = acc.add(d[2].mul(h).mul(h));
            acc = acc.add(d[2].mul(g).mul(t3));
            acc = acc.add(d[3].mul(g).mul(h).mul(g));
            acc = acc.add(d[3].mul(h).mul(g).mul(g));
            acc = acc.add(d[3].mul(g).mul(h).mul(g));
            acc = acc.add(d[3].mul(g).mul(g).mul(h));
            acc = acc.add(d[4].mul(g).mul(g).mul(g).mul(g));
            acc
        };
        out
    }
    /// Contract `t3` with a primary-space direction (lift of
    /// [`Tower4::third_contracted`]): `out[a][b] = Σ_c t3[a][b][c] · dir[c]`,
    /// summed in ascending `c`.
    #[inline]
    pub fn third_contracted(&self, dir: &[L; K]) -> [[L; K]; K] {
        let mut out = [[L::splat(0.0); K]; K];
        for a in 0..K {
            for b in 0..K {
                let mut acc = L::splat(0.0);
                for c in 0..K {
                    acc = acc.add(self.t3[a][b][c].mul(dir[c]));
                }
                out[a][b] = acc;
            }
        }
        out
    }
    /// Contract `t4` with two primary-space directions (lift of
    /// [`Tower4::fourth_contracted`]):
    /// `out[i][j] = Σ_{k,l} t4[i][j][k][l] · u[k] · w[l]`, with `l` innermost.
    #[inline]
    pub fn fourth_contracted(&self, u: &[L; K], w: &[L; K]) -> [[L; K]; K] {
        let mut out = [[L::splat(0.0); K]; K];
        for i in 0..K {
            for j in 0..K {
                let mut acc = L::splat(0.0);
                for k in 0..K {
                    for l in 0..K {
                        acc = acc.add(self.t4[i][j][k][l].mul(u[k]).mul(w[l]));
                    }
                }
                out[i][j] = acc;
            }
        }
        out
    }
}
2917
2918/// Lane-batched [`Tower3`] (order-≤3 sibling of [`Tower4Lane`]).
2919#[derive(Clone, Copy)]
2920pub struct Tower3Lane<L: Lane, const K: usize> {
2921    /// Value channel.
2922    pub v: L,
2923    /// Gradient.
2924    pub g: [L; K],
2925    /// Hessian.
2926    pub h: [[L; K]; K],
2927    /// Third tensor.
2928    pub t3: [[[L; K]; K]; K],
2929}
2930
2931/// The 4-rows-per-pass batched [`Tower3`] (`wide::f64x4` lanes).
2932pub type Tower3Batch<const K: usize> = Tower3Lane<wide::f64x4, K>;
2933
2934impl<L: Lane, const K: usize> Tower3Lane<L, K> {
2935    /// All-zero tower.
2936    #[inline]
2937    pub fn zero() -> Self {
2938        let z = L::splat(0.0);
2939        Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K] }
2940    }
2941    /// Constant `c` (per lane).
2942    #[inline]
2943    pub fn constant(c: L) -> Self {
2944        let mut o = Self::zero();
2945        o.v = c;
2946        o
2947    }
2948    /// Seeded variable `p_idx` at per-lane `value`.
2949    #[inline]
2950    pub fn variable(value: L, idx: usize) -> Self {
2951        let mut o = Self::constant(value);
2952        o.g[idx] = L::splat(1.0);
2953        o
2954    }
2955    /// Extract lane `i` as a scalar [`Tower3<K>`].
2956    #[inline]
2957    pub fn lane(&self, i: usize) -> Tower3<K> {
2958        let mut out = Tower3::<K>::zero();
2959        out.v = self.v.lane(i);
2960        for a in 0..K {
2961            out.g[a] = self.g[a].lane(i);
2962            for b in 0..K {
2963                out.h[a][b] = self.h[a][b].lane(i);
2964                for c in 0..K {
2965                    out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2966                }
2967            }
2968        }
2969        out
2970    }
2971    /// Per-channel lane-wise `self + o`.
2972    #[inline]
2973    pub fn add(&self, o: &Self) -> Self {
2974        let mut out = *self;
2975        out.v = self.v.add(o.v);
2976        for i in 0..K {
2977            out.g[i] = self.g[i].add(o.g[i]);
2978            for j in 0..K {
2979                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
2980                for k in 0..K {
2981                    out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
2982                }
2983            }
2984        }
2985        out
2986    }
2987    /// Per-channel lane-wise `self - o`.
2988    #[inline]
2989    pub fn sub(&self, o: &Self) -> Self {
2990        let mut out = *self;
2991        out.v = self.v.sub(o.v);
2992        for i in 0..K {
2993            out.g[i] = self.g[i].sub(o.g[i]);
2994            for j in 0..K {
2995                out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
2996                for k in 0..K {
2997                    out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
2998                }
2999            }
3000        }
3001        out
3002    }
3003    /// Multiply every channel by the plain scalar `s` (mirrors `Tower3::scale`).
3004    #[inline]
3005    pub fn scale(&self, s: f64) -> Self {
3006        let sl = L::splat(s);
3007        let mut out = *self;
3008        out.v = self.v.mul(sl);
3009        for i in 0..K {
3010            out.g[i] = self.g[i].mul(sl);
3011            for j in 0..K {
3012                out.h[i][j] = self.h[i][j].mul(sl);
3013                for k in 0..K {
3014                    out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
3015                }
3016            }
3017        }
3018        out
3019    }
3020    /// Leibniz product `self · o`, term-for-term lift of [`Tower3::mul`].
3021    #[inline]
3022    pub fn mul(&self, o: &Self) -> Self {
3023        let a = self;
3024        let b = o;
3025        let mut out = Self::zero();
3026        out.v = a.v.mul(b.v);
3027        for i in 0..K {
3028            let mut acc = L::splat(0.0);
3029            acc = acc.add(a.v.mul(b.g[i]));
3030            acc = acc.add(a.g[i].mul(b.v));
3031            out.g[i] = acc;
3032        }
3033        for i in 0..K {
3034            for j in 0..K {
3035                let mut acc = L::splat(0.0);
3036                acc = acc.add(a.v.mul(b.h[i][j]));
3037                acc = acc.add(a.g[i].mul(b.g[j]));
3038                acc = acc.add(a.g[j].mul(b.g[i]));
3039                acc = acc.add(a.h[i][j].mul(b.v));
3040                out.h[i][j] = acc;
3041            }
3042        }
3043        for i in 0..K {
3044            for j in 0..K {
3045                for k in 0..K {
3046                    let mut acc = L::splat(0.0);
3047                    acc = acc.add(a.v.mul(b.t3[i][j][k]));
3048                    acc = acc.add(a.g[i].mul(b.h[j][k]));
3049                    acc = acc.add(a.g[j].mul(b.h[i][k]));
3050                    acc = acc.add(a.h[i][j].mul(b.g[k]));
3051                    acc = acc.add(a.g[k].mul(b.h[i][j]));
3052                    acc = acc.add(a.h[i][k].mul(b.g[j]));
3053                    acc = acc.add(a.h[j][k].mul(b.g[i]));
3054                    acc = acc.add(a.t3[i][j][k].mul(b.v));
3055                    out.t3[i][j][k] = acc;
3056                }
3057            }
3058        }
3059        out
3060    }
3061    /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
3062    /// [`Tower3::compose_unary`]. `d = [f, f′, f″, f‴]` packed per lane.
3063    #[inline]
3064    pub fn compose_unary(&self, d: [L; 4]) -> Self {
3065        let mut out = Self::zero();
3066        out.v = d[0];
3067        for i in 0..K {
3068            let mut acc = L::splat(0.0);
3069            acc = acc.add(d[1].mul(self.g[i]));
3070            out.g[i] = acc;
3071        }
3072        for i in 0..K {
3073            for j in 0..K {
3074                let mut acc = L::splat(0.0);
3075                acc = acc.add(d[1].mul(self.h[i][j]));
3076                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
3077                out.h[i][j] = acc;
3078            }
3079        }
3080        for i in 0..K {
3081            for j in 0..K {
3082                for k in 0..K {
3083                    let mut acc = L::splat(0.0);
3084                    acc = acc.add(d[1].mul(self.t3[i][j][k]));
3085                    acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
3086                    acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
3087                    acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
3088                    acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
3089                    out.t3[i][j][k] = acc;
3090                }
3091            }
3092        }
3093        out
3094    }
3095    /// Compose with a unary special-function whose `[f64; 4]` derivative stack is
3096    /// built from the base value through `stack_fn`, evaluated PER LANE — the
3097    /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
3098    /// seam (the SIMD twin of [`Tower3::compose_unary_with`], order-≤3 sibling of
3099    /// [`Tower4Lane::compose_unary_with`]). The scalar `stack_fn` is run once per
3100    /// lane at that lane's own base value (via [`Lane::unary_with`]) and packed
3101    /// into `[L; 4]` for the existing per-lane [`Self::compose_unary`], so lane
3102    /// `i` is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`.
3103    #[inline]
3104    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
3105        self.compose_unary(self.v.unary_with(stack_fn))
3106    }
3107
3108    /// Single-active-slot fast path, term-for-term lift of
3109    /// [`Tower3::compose_unary_single_slot`].
3110    #[inline]
3111    pub fn compose_unary_single_slot(&self, d: [L; 4], slot: usize) -> Self {
3112        let mut out = Self::zero();
3113        let s = slot;
3114        let g = self.g[s];
3115        let h = self.h[s][s];
3116        let t3 = self.t3[s][s][s];
3117        out.v = d[0];
3118        out.g[s] = {
3119            let mut acc = L::splat(0.0);
3120            acc = acc.add(d[1].mul(g));
3121            acc
3122        };
3123        out.h[s][s] = {
3124            let mut acc = L::splat(0.0);
3125            acc = acc.add(d[1].mul(h));
3126            acc = acc.add(d[2].mul(g).mul(g));
3127            acc
3128        };
3129        out.t3[s][s][s] = {
3130            let mut acc = L::splat(0.0);
3131            acc = acc.add(d[1].mul(t3));
3132            acc = acc.add(d[2].mul(h).mul(g));
3133            acc = acc.add(d[2].mul(h).mul(g));
3134            acc = acc.add(d[2].mul(g).mul(h));
3135            acc = acc.add(d[3].mul(g).mul(g).mul(g));
3136            acc
3137        };
3138        out
3139    }
3140}
3141
#[cfg(test)]
mod batch_tests {
    // Bit-identity oracles for the lane-batched towers. Every lane of a
    // `Tower4Batch` / `Tower3Batch` result must equal, `to_bits`, the scalar
    // tower computed from the same row.
    use super::*;

    /// Deterministic 64-bit LCG test generator. `f()` returns multiples of
    /// 2⁻⁵¹ in `[-2, 2)`.
    struct Rng(u64);
    impl Rng {
        fn f(&mut self) -> f64 {
            self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            ((self.0 >> 11) as f64 / (1u64 << 53) as f64) * 4.0 - 2.0
        }
    }

    // Fill every channel of a scalar Tower4<K> with random data.
    fn rand_t4<const K: usize>(r: &mut Rng) -> Tower4<K> {
        let mut t = Tower4::<K>::zero();
        t.v = r.f();
        for i in 0..K {
            t.g[i] = r.f();
            for j in 0..K {
                t.h[i][j] = r.f();
                for k in 0..K {
                    t.t3[i][j][k] = r.f();
                    for l in 0..K {
                        t.t4[i][j][k][l] = r.f();
                    }
                }
            }
        }
        t
    }
    fn rand_t3<const K: usize>(r: &mut Rng) -> Tower3<K> {
        let mut t = Tower3::<K>::zero();
        t.v = r.f();
        for i in 0..K {
            t.g[i] = r.f();
            for j in 0..K {
                t.h[i][j] = r.f();
                for k in 0..K {
                    t.t3[i][j][k] = r.f();
                }
            }
        }
        t
    }
    fn pack4_t4<const K: usize>(rows: &[Tower4<K>; 4]) -> Tower4Batch<K> {
        let mut b = Tower4Batch::<K>::zero();
        let lane = |f: &dyn Fn(&Tower4<K>) -> f64| {
            wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
        };
        b.v = lane(&|t| t.v);
        for i in 0..K {
            b.g[i] = lane(&|t| t.g[i]);
            for j in 0..K {
                b.h[i][j] = lane(&|t| t.h[i][j]);
                for k in 0..K {
                    b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
                    for l in 0..K {
                        b.t4[i][j][k][l] = lane(&|t| t.t4[i][j][k][l]);
                    }
                }
            }
        }
        b
    }
    fn pack4_t3<const K: usize>(rows: &[Tower3<K>; 4]) -> Tower3Batch<K> {
        let mut b = Tower3Batch::<K>::zero();
        let lane = |f: &dyn Fn(&Tower3<K>) -> f64| {
            wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
        };
        b.v = lane(&|t| t.v);
        for i in 0..K {
            b.g[i] = lane(&|t| t.g[i]);
            for j in 0..K {
                b.h[i][j] = lane(&|t| t.h[i][j]);
                for k in 0..K {
                    b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
                }
            }
        }
        b
    }
    fn assert_t4_eq<const K: usize>(b: &Tower4<K>, s: &Tower4<K>, ctx: &str) {
        assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
        for i in 0..K {
            assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
            for j in 0..K {
                assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
                for k in 0..K {
                    assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
                    for l in 0..K {
                        assert_eq!(b.t4[i][j][k][l].to_bits(), s.t4[i][j][k][l].to_bits(), "t4 {ctx}");
                    }
                }
            }
        }
    }
    fn assert_t3_eq<const K: usize>(b: &Tower3<K>, s: &Tower3<K>, ctx: &str) {
        assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
        for i in 0..K {
            assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
            for j in 0..K {
                assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
                for k in 0..K {
                    assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
                }
            }
        }
    }

    // Run a representative op chain on 4 scalar rows and on the f64x4 batch,
    // then assert every channel of every lane is to_bits-identical.
    fn run4<const K: usize>(seed: u64, batches: usize) -> usize {
        let mut r = Rng(seed);
        let mut rows_checked = 0;
        for _ in 0..batches {
            let a: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
            let b: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
            let d: [[f64; 5]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
            let dir: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
            let dir2: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
            let s = r.f();

            // scalar per-row reference
            let scal: [Tower4<K>; 4] = std::array::from_fn(|rw| {
                let prod = a[rw].mul(&b[rw]);
                let comp = prod.compose_unary(d[rw]);
                let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
                summed.compose_unary_single_slot(d[rw], 0)
            });
            let third: [[[f64; K]; K]; 4] =
                std::array::from_fn(|rw| a[rw].third_contracted(&dir[rw]));
            let fourth: [[[f64; K]; K]; 4] =
                std::array::from_fn(|rw| a[rw].fourth_contracted(&dir[rw], &dir2[rw]));

            // batched f64x4
            let ab = pack4_t4(&a);
            let bb = pack4_t4(&b);
            let db: [wide::f64x4; 5] = std::array::from_fn(|c| {
                wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
            });
            let dirb: [wide::f64x4; K] = std::array::from_fn(|c| {
                wide::f64x4::new([dir[0][c], dir[1][c], dir[2][c], dir[3][c]])
            });
            let dir2b: [wide::f64x4; K] = std::array::from_fn(|c| {
                wide::f64x4::new([dir2[0][c], dir2[1][c], dir2[2][c], dir2[3][c]])
            });
            let prodb = ab.mul(&bb);
            let compb = prodb.compose_unary(db);
            let summedb = compb.add(&ab).sub(&bb).scale(s);
            let finalb = summedb.compose_unary_single_slot(db, 0);
            let thirdb = ab.third_contracted(&dirb);
            let fourthb = ab.fourth_contracted(&dirb, &dir2b);

            for rw in 0..4 {
                assert_t4_eq(&finalb.lane(rw), &scal[rw], "t4-chain");
                for i in 0..K {
                    for j in 0..K {
                        assert_eq!(thirdb[i][j].lane(rw).to_bits(), third[rw][i][j].to_bits(), "third");
                        assert_eq!(fourthb[i][j].lane(rw).to_bits(), fourth[rw][i][j].to_bits(), "fourth");
                    }
                }
                rows_checked += 1;
            }
        }
        rows_checked
    }
    fn run3<const K: usize>(seed: u64, batches: usize) -> usize {
        let mut r = Rng(seed);
        let mut rows_checked = 0;
        for _ in 0..batches {
            let a: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
            let b: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
            let d: [[f64; 4]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
            let s = r.f();
            let scal: [Tower3<K>; 4] = std::array::from_fn(|rw| {
                let prod = a[rw].mul(&b[rw]);
                let comp = prod.compose_unary(d[rw]);
                let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
                summed.compose_unary_single_slot(d[rw], 0)
            });
            let ab = pack4_t3(&a);
            let bb = pack4_t3(&b);
            let db: [wide::f64x4; 4] = std::array::from_fn(|c| {
                wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
            });
            let prodb = ab.mul(&bb);
            let compb = prodb.compose_unary(db);
            let summedb = compb.add(&ab).sub(&bb).scale(s);
            let finalb = summedb.compose_unary_single_slot(db, 0);
            for rw in 0..4 {
                assert_t3_eq(&finalb.lane(rw), &scal[rw], "t3-chain");
                rows_checked += 1;
            }
        }
        rows_checked
    }

    // A `Tower4Batch<9>` carries a `9⁴ = 6561`-entry `t4` tensor in 4-wide
    // lanes (≈210 KiB by value); the op chain keeps several live, which can
    // exceed a test thread's default stack. Run each width on a large-stack
    // thread so K=9 is exercised without a stack overflow.
    fn big_stack<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(f: F) -> R {
        std::thread::Builder::new()
            .stack_size(512 << 20)
            .spawn(f)
            .unwrap()
            .join()
            .unwrap()
    }

    #[test]
    fn tower4_batch_lane_bit_identical() {
        let batches = 2000;
        let rows_checked = big_stack(move || run4::<2>(0x1111_2222_3333_4444, batches))
            + big_stack(move || run4::<3>(0x5555_6666_7777_8888, batches))
            + big_stack(move || run4::<4>(0x9999_aaaa_bbbb_cccc, batches))
            + big_stack(move || run4::<9>(0xdddd_eeee_ffff_0000, batches));
        // 4 widths × `batches` batches × 4 rows each: guards the large-stack
        // worker threads against silently running zero comparisons.
        assert_eq!(rows_checked, 4 * batches * 4);
    }

    #[test]
    fn tower3_batch_lane_bit_identical() {
        let batches = 2000;
        let rows_checked = big_stack(move || run3::<2>(0x0f0f_1e1e_2d2d_3c3c, batches))
            + big_stack(move || run3::<3>(0x4b4b_5a5a_6969_7878, batches))
            + big_stack(move || run3::<4>(0x8787_9696_a5a5_b4b4, batches))
            + big_stack(move || run3::<9>(0xc3c3_d2d2_e1e1_f0f0, batches));
        // 4 widths × `batches` batches × 4 rows each: guards the large-stack
        // worker threads against silently running zero comparisons.
        assert_eq!(rows_checked, 4 * batches * 4);
    }

    // ── compose_unary_with seam (generic-over-Lane compose) ─────────────────
    //
    // The seam lets a single-sourced row program build its special-function
    // STACK from the base value through a closure, so the SAME expression
    // instantiates at a scalar tower (one base) AND a batch tower (four distinct
    // per-lane bases). These oracles pin both arms `to_bits`.

    /// A base-value-dependent `[f64; 5]` derivative stack (finite for finite `u`),
    /// standing in for a family's hand-certified special-function stack. `stack4`
    /// is its order-≤3 truncation.
    fn seam_stack5(u: f64) -> [f64; 5] {
        [u.sin(), u.cos(), (2.0 * u).sin(), (0.5 * u).cos(), u * u - 0.3]
    }
    fn seam_stack4(u: f64) -> [f64; 4] {
        let s = seam_stack5(u);
        [s[0], s[1], s[2], s[3]]
    }

    /// Force a distinct / edge per-lane base value (signed zeros included).
    fn seam_edge_base(r: &mut Rng, which: usize) -> f64 {
        match which {
            0 => -0.0,
            1 => 0.0,
            2 => r.f(),
            _ => r.f() + 3.0,
        }
    }

    /// Uniform selector in `0..4` for `seam_edge_base`.
    ///
    /// Do NOT derive it from the low bits of an `r.f()` sample. Those samples
    /// are multiples of 2⁻⁵¹, so their two lowest mantissa bits are always
    /// `0b00` or `0b10`, and `to_bits() % 4` only ever yields 0 or 2. That
    /// would silently skip the `+0.0` and shifted-base arms. Bucketing the
    /// sample's magnitude hits all four arms uniformly.
    fn seam_pick(r: &mut Rng) -> usize {
        // `r.f() + 2.0` lies exactly in [0, 4), so the truncating cast gives
        // 0..=3. `min` is only a defensive clamp.
        ((r.f() + 2.0) as usize).min(3)
    }

    /// (a) scalar arm: `Tower4::compose_unary_with(f)` is `to_bits`-identical to
    /// the explicit `compose_unary(f(value))` on every channel. Bases are drawn
    /// from all four `seam_edge_base` arms, and the function asserts that each
    /// arm was actually exercised.
    fn scalar_seam_t4<const K: usize>(seed: u64, n: usize) -> usize {
        let mut r = Rng(seed);
        let mut arms_hit = [0usize; 4];
        for _ in 0..n {
            let mut t = rand_t4::<K>(&mut r);
            let which = seam_pick(&mut r);
            arms_hit[which] += 1;
            t.v = seam_edge_base(&mut r, which);
            assert_t4_eq(
                &t.compose_unary_with(seam_stack5),
                &t.compose_unary(seam_stack5(t.v)),
                "scalar t4 seam",
            );
        }
        assert!(
            arms_hit.iter().all(|&hits| hits > 0),
            "scalar t4 seam skipped a base kind: {arms_hit:?}"
        );
        n
    }
    fn scalar_seam_t3<const K: usize>(seed: u64, n: usize) -> usize {
        let mut r = Rng(seed);
        let mut arms_hit = [0usize; 4];
        for _ in 0..n {
            let mut t = rand_t3::<K>(&mut r);
            let which = seam_pick(&mut r);
            arms_hit[which] += 1;
            t.v = seam_edge_base(&mut r, which);
            assert_t3_eq(
                &t.compose_unary_with(seam_stack4),
                &t.compose_unary(seam_stack4(t.v)),
                "scalar t3 seam",
            );
        }
        assert!(
            arms_hit.iter().all(|&hits| hits > 0),
            "scalar t3 seam skipped a base kind: {arms_hit:?}"
        );
        n
    }

    /// (b) lane arm: lane `i` of `Tower4Lane::compose_unary_with` is
    /// `to_bits`-identical to the scalar `Tower4::compose_unary_with` on row `i`.
    /// The four lanes carry DISTINCT base values (signed zeros included), so a
    /// buggy impl that reuses one lane's base would fail.
    fn lane_seam_t4<const K: usize>(seed: u64, batches: usize) -> usize {
        let mut r = Rng(seed);
        let mut verified = 0usize;
        for _ in 0..batches {
            let mut rows: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
            for (rw, row) in rows.iter_mut().enumerate() {
                row.v = seam_edge_base(&mut r, rw);
            }
            let batch_out = pack4_t4(&rows).compose_unary_with(seam_stack5);
            for (rw, row) in rows.iter().enumerate() {
                assert_t4_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack5), "lane t4 seam");
                verified += 1;
            }
        }
        verified
    }
    fn lane_seam_t3<const K: usize>(seed: u64, batches: usize) -> usize {
        let mut r = Rng(seed);
        let mut verified = 0usize;
        for _ in 0..batches {
            let mut rows: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
            for (rw, row) in rows.iter_mut().enumerate() {
                row.v = seam_edge_base(&mut r, rw);
            }
            let batch_out = pack4_t3(&rows).compose_unary_with(seam_stack4);
            for (rw, row) in rows.iter().enumerate() {
                assert_t3_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack4), "lane t3 seam");
                verified += 1;
            }
        }
        verified
    }

    #[test]
    fn compose_unary_with_scalar_bit_identical() {
        let n = 1100;
        let total = scalar_seam_t4::<2>(0x2200_0001, n)
            + scalar_seam_t4::<3>(0x2200_0002, n)
            + scalar_seam_t4::<4>(0x2200_0003, n)
            + big_stack(move || scalar_seam_t4::<9>(0x2200_0004, n))
            + scalar_seam_t3::<2>(0x3300_0001, n)
            + scalar_seam_t3::<3>(0x3300_0002, n)
            + scalar_seam_t3::<4>(0x3300_0003, n)
            + big_stack(move || scalar_seam_t3::<9>(0x3300_0004, n));
        // 8 arms × 1100 = 8800 ≥ 4000 inputs.
        assert_eq!(total, 8 * n);
    }

    #[test]
    fn compose_unary_with_lane_matches_scalar() {
        let b = 600;
        let total = lane_seam_t4::<2>(0x4400_0001, b)
            + lane_seam_t4::<3>(0x4400_0002, b)
            + lane_seam_t4::<4>(0x4400_0003, b)
            + big_stack(move || lane_seam_t4::<9>(0x4400_0004, b))
            + lane_seam_t3::<2>(0x5500_0001, b)
            + lane_seam_t3::<3>(0x5500_0002, b)
            + lane_seam_t3::<4>(0x5500_0003, b)
            + big_stack(move || lane_seam_t3::<9>(0x5500_0004, b));
        // 8 arms × 600 = 4800 batches ≥ 2000; each verifies 4 lanes (19200 checks).
        assert_eq!(total, 8 * b * 4);
    }
}
3500
3501#[cfg(test)]
3502mod tests {
3503    use super::*;
3504
3505    /// `Tower3<K>` must be bit-identical to `Tower4<K>` on every channel it
3506    /// carries (value, gradient, Hessian, third derivatives). The order-≤3
3507    /// Leibniz / Faà-di-Bruno terms read only order-≤3 inner channels, so
3508    /// dropping the fourth tensor cannot perturb them. Exercises products
3509    /// (Leibniz cross-terms), unary composition, scaling, and addition — the
3510    /// same operations the survival location-scale `nll_index_tower` composes —
3511    /// across all mixed partials, not just the diagonal entries that kernel reads.
3512    #[test]
3513    fn tower3_matches_tower4_through_third_order() {
3514        let s_a: [f64; 5] = [
3515            0.3_f64.sin(),
3516            0.3_f64.cos(),
3517            -0.3_f64.sin(),
3518            -0.3_f64.cos(),
3519            0.3_f64.sin(),
3520        ];
3521        let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
3522        let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];
3523
3524        let a4 = Tower4::<3>::variable(0.4, 0);
3525        let b4 = Tower4::<3>::variable(-0.7, 1);
3526        let c4 = Tower4::<3>::variable(0.9, 2);
3527        let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
3528            + a4.mul(&c4).scale(-0.7)
3529            + b4.compose_unary(s_b).scale(0.25);
3530
3531        let a3 = Tower3::<3>::variable(0.4, 0);
3532        let b3 = Tower3::<3>::variable(-0.7, 1);
3533        let c3 = Tower3::<3>::variable(0.9, 2);
3534        let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
3535            + a3.mul(&c3).scale(-0.7)
3536            + b3.compose_unary(s4(s_b)).scale(0.25);
3537
3538        assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
3539        for i in 0..3 {
3540            assert_eq!(
3541                prog3.g[i].to_bits(),
3542                prog4.g[i].to_bits(),
3543                "g[{i}] mismatch"
3544            );
3545            for j in 0..3 {
3546                assert_eq!(
3547                    prog3.h[i][j].to_bits(),
3548                    prog4.h[i][j].to_bits(),
3549                    "h[{i}][{j}] mismatch"
3550                );
3551                for k in 0..3 {
3552                    assert_eq!(
3553                        prog3.t3[i][j][k].to_bits(),
3554                        prog4.t3[i][j][k].to_bits(),
3555                        "t3[{i}][{j}][{k}] mismatch"
3556                    );
3557                }
3558            }
3559        }
3560    }
3561
3562    /// Binomial-logit row NLL, K=1: ℓ(η) = ln(1 + e^η) − y·η.
3563    /// The entire tower has textbook closed forms in μ = σ(η); this test
3564    /// pins the algebra (exp, ln, scalar mixes, Leibniz/Faà di Bruno) to
3565    /// analytic truth at near-machine precision.
3566    struct LogitProgram {
3567        eta: Vec<f64>,
3568        y: Vec<f64>,
3569    }
3570
3571    impl RowNllProgram<1> for LogitProgram {
3572        fn n_rows(&self) -> usize {
3573            self.eta.len()
3574        }
3575        fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
3576            Ok([self.eta[row]])
3577        }
3578        fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
3579            let eta = p[0];
3580            Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
3581        }
3582    }
3583
3584    #[test]
3585    fn logit_tower_matches_closed_forms() {
3586        let prog = LogitProgram {
3587            eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
3588            y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
3589        };
3590        for row in 0..prog.n_rows() {
3591            let t = evaluate_program(&prog, row).expect("logit program");
3592            let eta = prog.eta[row];
3593            let y = prog.y[row];
3594            let mu = 1.0 / (1.0 + (-eta).exp());
3595            let w = mu * (1.0 - mu);
3596            let expect = [
3597                (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
3598                (t.g[0], mu - y, "grad"),
3599                (t.h[0][0], w, "hess"),
3600                (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
3601                (
3602                    t.t4[0][0][0][0],
3603                    w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
3604                    "fourth",
3605                ),
3606            ];
3607            for (got, want, label) in expect {
3608                assert!(
3609                    (got - want).abs() <= 1e-12 * want.abs().max(1.0),
3610                    "row {row} {label}: got {got:+.15e} want {want:+.15e}"
3611                );
3612            }
3613        }
3614    }
3615
3616    fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
3617        let diff = (got - want).abs();
3618        assert!(
3619            diff <= rel_tol * want.abs().max(1.0),
3620            "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
3621        );
3622    }
3623
3624    #[test]
3625    fn gamma_special_function_stacks_match_reference_values() {
3626        const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
3627        let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
3628        let cases = [
3629            (
3630                "x=0.1",
3631                0.1,
3632                -10.423_754_940_411_076,
3633                101.433_299_150_792_75,
3634            ),
3635            (
3636                "x=0.5",
3637                0.5,
3638                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
3639                pi_sq / 2.0,
3640            ),
3641            ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
3642            (
3643                "x=2.5",
3644                2.5,
3645                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
3646                pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
3647            ),
3648            (
3649                "x=50",
3650                50.0,
3651                3.901_989_673_427_892,
3652                0.020_201_333_226_697_128,
3653            ),
3654        ];
3655
3656        for (label, x, digamma_ref, trigamma_ref) in cases {
3657            let ln_gamma_stack = ln_gamma_derivative_stack(x);
3658            let digamma_stack = digamma_derivative_stack(x);
3659            let trigamma_stack = trigamma_derivative_stack(x);
3660            assert_close(
3661                &format!("{label} ln_gamma_stack digamma"),
3662                ln_gamma_stack[1],
3663                digamma_ref,
3664                1e-13,
3665            );
3666            assert_close(
3667                &format!("{label} digamma value"),
3668                digamma_stack[0],
3669                digamma_ref,
3670                1e-13,
3671            );
3672            assert_close(
3673                &format!("{label} ln_gamma_stack trigamma"),
3674                ln_gamma_stack[2],
3675                trigamma_ref,
3676                1e-13,
3677            );
3678            assert_close(
3679                &format!("{label} digamma_stack trigamma"),
3680                digamma_stack[1],
3681                trigamma_ref,
3682                1e-13,
3683            );
3684            assert_close(
3685                &format!("{label} trigamma value"),
3686                trigamma_stack[0],
3687                trigamma_ref,
3688                1e-13,
3689            );
3690        }
3691    }
3692
3693    #[test]
3694    fn gamma_special_function_stacks_obey_recurrences() {
3695        for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
3696            let digamma_x = digamma_derivative_stack(x)[0];
3697            let digamma_next = digamma_derivative_stack(x + 1.0)[0];
3698            let trigamma_x = trigamma_derivative_stack(x)[0];
3699            let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
3700            assert_close(
3701                &format!("digamma recurrence x={x}"),
3702                digamma_next,
3703                digamma_x + 1.0 / x,
3704                1e-13,
3705            );
3706            assert_close(
3707                &format!("trigamma recurrence x={x}"),
3708                trigamma_next,
3709                trigamma_x - 1.0 / (x * x),
3710                1e-13,
3711            );
3712        }
3713    }
3714
3715    /// Gaussian location-scale row NLL, K=2 primaries (η, s = log σ):
3716    /// ℓ = s + ½ e^{−2s} (y − η)². Mixed cross blocks — the #736 fragility
3717    /// shape — all have one-line closed forms here.
3718    struct LocScaleProgram {
3719        eta: Vec<f64>,
3720        s: Vec<f64>,
3721        y: Vec<f64>,
3722    }
3723
3724    impl RowNllProgram<2> for LocScaleProgram {
3725        fn n_rows(&self) -> usize {
3726            self.eta.len()
3727        }
3728        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
3729            Ok([self.eta[row], self.s[row]])
3730        }
3731        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
3732            let r = -(p[0] - self.y[row]);
3733            Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
3734        }
3735    }
3736
3737    #[test]
3738    fn locscale_tower_matches_closed_forms_including_cross_blocks() {
3739        let prog = LocScaleProgram {
3740            eta: vec![0.3, -1.1, 2.0],
3741            s: vec![-0.5, 0.2, 0.8],
3742            y: vec![1.0, -2.0, 2.5],
3743        };
3744        let tol = 1e-12;
3745        for row in 0..prog.n_rows() {
3746            let t = evaluate_program(&prog, row).expect("locscale program");
3747            let r = prog.y[row] - prog.eta[row];
3748            let w = (-2.0 * prog.s[row]).exp();
3749            // (η, s) = indices (0, 1).
3750            let truth_g = [-w * r, 1.0 - w * r * r];
3751            let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
3752            // Third tensor: distinct-entry closed forms.
3753            // ∂ηηη = 0, ∂ηηs = −2w, ∂ηss = −4wr, ∂sss = −4wr².
3754            let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
3755                match a + b + c {
3756                    0 => 0.0,
3757                    1 => -2.0 * w,
3758                    2 => -4.0 * w * r,
3759                    _ => -4.0 * w * r * r,
3760                }
3761            };
3762            // Fourth tensor: ∂ηηηη = 0, ∂ηηηs = 0? No: d/ds(∂ηηη)=0 ✓;
3763            // ∂ηηss = 4w, ∂ηsss = 8wr, ∂ssss = 8wr².
3764            let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
3765                match a + b + c + d {
3766                    0 | 1 => 0.0,
3767                    2 => 4.0 * w,
3768                    3 => 8.0 * w * r,
3769                    _ => 8.0 * w * r * r,
3770                }
3771            };
3772            for a in 0..2 {
3773                assert!(
3774                    (t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
3775                    "row {row} grad[{a}]"
3776                );
3777                for b in 0..2 {
3778                    assert!(
3779                        (t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
3780                        "row {row} hess[{a}][{b}]: got {} want {}",
3781                        t.h[a][b],
3782                        truth_h[a][b]
3783                    );
3784                    for c in 0..2 {
3785                        assert!(
3786                            (t.t3[a][b][c] - t3_truth(a, b, c)).abs()
3787                                <= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3788                            "row {row} t3[{a}][{b}][{c}]: got {} want {}",
3789                            t.t3[a][b][c],
3790                            t3_truth(a, b, c)
3791                        );
3792                        for d in 0..2 {
3793                            assert!(
3794                                (t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
3795                                    <= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3796                                "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
3797                                t.t4[a][b][c][d],
3798                                t4_truth(a, b, c, d)
3799                            );
3800                        }
3801                    }
3802                }
3803            }
3804            // The derived trait-surface helpers agree with direct contraction.
3805            let dir = [0.7, -1.3];
3806            let third = derived_third_contracted(&prog, row, &dir).expect("third");
3807            for a in 0..2 {
3808                for b in 0..2 {
3809                    let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
3810                    assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
3811                }
3812            }
3813        }
3814    }
3815
3816    /// FD cross-check on a deliberately gnarly composition (div, sqrt,
3817    /// powf, nested exp/ln) in K=3, where no closed form is consulted:
3818    /// every tower channel is checked against central finite differences
3819    /// of the channel one order below — value→grad, grad→hess, hess→t3,
3820    /// t3→t4 — so each order is independently anchored.
3821    ///
3822    /// The program carries a per-row primary fixture plus a per-row offset
3823    /// `tau[row]` that enters the loss as a constant, so `row` genuinely
3824    /// drives both the seed point and the evaluated expression.
3825    struct GnarlyProgram {
3826        primaries: Vec<[f64; 3]>,
3827        tau: Vec<f64>,
3828    }
3829
3830    impl GnarlyProgram {
3831        fn fixture() -> Self {
3832            Self {
3833                primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
3834                tau: vec![0.15, -0.35, 0.5],
3835            }
3836        }
3837    }
3838
3839    impl RowNllProgram<3> for GnarlyProgram {
3840        fn n_rows(&self) -> usize {
3841            self.primaries.len()
3842        }
3843        fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3844            self.primaries
3845                .get(row)
3846                .copied()
3847                .ok_or_else(|| format!("gnarly: row {row} out of range"))
3848        }
3849        fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3850            let tau = *self
3851                .tau
3852                .get(row)
3853                .ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
3854            let a = (p[0] * p[1]).exp();
3855            let b = (p[2] * p[2] + 1.0).sqrt();
3856            let c = (a + b + tau).ln();
3857            let d = (p[1] * 0.5 + 2.0).powf(1.7);
3858            Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
3859        }
3860    }
3861
3862    /// Evaluate the gnarly program's tower at an ARBITRARY seed point for
3863    /// `row` (used to drive central differences off the fixture grid),
3864    /// while keeping `row`'s per-row data (`tau`) in the loss.
3865    fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
3866        struct At<'a> {
3867            base: &'a GnarlyProgram,
3868            row: usize,
3869            p: [f64; 3],
3870        }
3871        impl RowNllProgram<3> for At<'_> {
3872            fn n_rows(&self) -> usize {
3873                1
3874            }
3875            fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3876                if row != 0 {
3877                    return Err(format!("gnarly-at: row {row} out of range"));
3878                }
3879                Ok(self.p)
3880            }
3881            fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3882                if eval_row != 0 {
3883                    return Err(format!("gnarly-at: eval row {eval_row} out of range"));
3884                }
3885                self.base.row_nll(self.row, vars)
3886            }
3887        }
3888        evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
3889    }
3890
3891    #[test]
3892    fn gnarly_tower_is_fd_consistent_order_by_order() {
3893        let prog = GnarlyProgram::fixture();
3894        for row in 0..prog.n_rows() {
3895            let base = prog.primaries(row).expect("primaries");
3896            let t = gnarly_tower_at(&prog, row, base);
3897            let h_step = 1e-5;
3898            let tol = 1e-6;
3899            for c in 0..3 {
3900                let mut up = base;
3901                let mut dn = base;
3902                up[c] += h_step;
3903                dn[c] -= h_step;
3904                let t_up = gnarly_tower_at(&prog, row, up);
3905                let t_dn = gnarly_tower_at(&prog, row, dn);
3906                // value → gradient.
3907                let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
3908                assert!(
3909                    (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
3910                    "grad[{c}]: analytic {} fd {}",
3911                    t.g[c],
3912                    fd_g
3913                );
3914                for a in 0..3 {
3915                    // gradient → Hessian.
3916                    let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
3917                    assert!(
3918                        (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
3919                        "hess[{a}][{c}]: analytic {} fd {}",
3920                        t.h[a][c],
3921                        fd_h
3922                    );
3923                    for b in 0..3 {
3924                        // Hessian → third.
3925                        let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
3926                        assert!(
3927                            (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
3928                            "t3[{a}][{b}][{c}]: analytic {} fd {}",
3929                            t.t3[a][b][c],
3930                            fd_t3
3931                        );
3932                        for d in 0..3 {
3933                            // third → fourth.
3934                            let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
3935                            assert!(
3936                                (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
3937                                "t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
3938                                t.t4[a][b][d][c],
3939                                fd_t4
3940                            );
3941                        }
3942                    }
3943                }
3944            }
3945        }
3946    }
3947
3948    /// `implicit_solve` reproduces the true implicit function `a(θ)` of a
3949    /// constraint `F(a, θ) = 0` to fourth order. The constraint here is the
3950    /// smooth, strictly-`a`-monotone
3951    ///   F(a, θ) = a + θ₀·a² + θ₁·exp(a) − c
3952    /// whose root `a(θ)` is re-solved by scalar Newton at perturbed θ as the
3953    /// independent finite-difference oracle. Mirrors the survival flex
3954    /// calibration solve (one implicit intercept over the primaries) without
3955    /// any survival machinery, so a failure localises to the combinator.
3956    #[test]
3957    fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
3958        const C: f64 = 1.7;
3959        // The scalar constraint as a plain f64 closure (the production root
3960        // finder analogue) and its tower form in (a, θ₀, θ₁).
3961        let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
3962        let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
3963        let solve = |th: [f64; 2]| -> f64 {
3964            let mut a = 0.0_f64;
3965            for _ in 0..100 {
3966                let r = f_scalar(a, th);
3967                if r.abs() < 1e-14 {
3968                    break;
3969                }
3970                a -= r / f_da(a, th);
3971            }
3972            a
3973        };
3974        // Tower constraint over K1 = 3 vars: slot 0 = a, slots 1,2 = θ₀, θ₁.
3975        let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
3976            let a = Tower4::<3>::variable(a0, 0);
3977            let t0 = Tower4::<3>::variable(th[0], 1);
3978            let t1 = Tower4::<3>::variable(th[1], 2);
3979            a + t0 * a.mul(&a) + t1 * a.exp() - C
3980        };
3981
3982        let th0 = [0.35, 0.2];
3983        let a0 = solve(th0);
3984        let f = f_tower(a0, th0);
3985        // Residual at the solved point is ~0 (the combinator tolerates the
3986        // production Newton residual; here it is machine-zero).
3987        assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
3988        let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");
3989
3990        // FD oracle: central differences of the scalar re-solve. Each order is
3991        // built from the previous via one more central difference, exactly the
3992        // gnarly order-by-order ladder.
3993        let h = 1e-4;
3994        let tol = 1e-5;
3995        let re = |th: [f64; 2]| solve(th);
3996        for i in 0..2 {
3997            let mut up = th0;
3998            let mut dn = th0;
3999            up[i] += h;
4000            dn[i] -= h;
4001            let fd_g = (re(up) - re(dn)) / (2.0 * h);
4002            assert!(
4003                (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4004                "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
4005                a_tower.g[i],
4006                fd_g
4007            );
4008            // second order: FD of the analytic gradient component would re-use
4009            // the combinator; instead difference a SCALAR gradient computed by
4010            // a nested re-solve so the oracle stays production-independent.
4011            let grad_at = |th: [f64; 2], j: usize| -> f64 {
4012                let mut up = th;
4013                let mut dn = th;
4014                up[j] += h;
4015                dn[j] -= h;
4016                (re(up) - re(dn)) / (2.0 * h)
4017            };
4018            for j in 0..2 {
4019                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4020                assert!(
4021                    (a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4022                    "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
4023                    a_tower.h[i][j],
4024                    fd_h
4025                );
4026            }
4027        }
4028    }
4029
4030    /// `implicit_solve` degenerates to `a_θ = −F_θ / F_a` at first order on a
4031    /// linear-in-a constraint, and the second-order tensor matches the
4032    /// textbook IFT formula `a_uv = −(F_uv + F_au a_v + F_av a_u + F_aa a_u a_v)/F_a`.
4033    /// This pins the recursion against the hand-coded first_full.rs formula it
4034    /// replaces, independent of any FD step.
4035    #[test]
4036    fn implicit_solve_matches_textbook_ift_recursion() {
4037        // A constraint with non-trivial F_a, F_aa, F_au, F_uv all present.
4038        let a0 = 0.4_f64;
4039        let th = [0.25_f64, -0.15_f64];
4040        let f = {
4041            let a = Tower4::<3>::variable(a0, 0);
4042            let t0 = Tower4::<3>::variable(th[0], 1);
4043            let t1 = Tower4::<3>::variable(th[1], 2);
4044            // F = a·(1 + θ₀) + θ₁·a² + θ₀·θ₁ − 0.4385. The constant is chosen so
4045            // F(a0, θ0) = 0 exactly at a0 = 0.4, θ = [0.25, −0.15]:
4046            //   0.4·1.25 + (−0.15)·0.16 + 0.25·(−0.15) = 0.4385.
4047            // implicit_solve requires a genuine root; at the root the level-set
4048            // and root-curve derivatives coincide, so the textbook-IFT
4049            // assertions below are unaffected.
4050            a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
4051        };
4052        let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
4053        let f_a = f.g[0];
4054        // First order: a_u = −F_u / F_a.
4055        for u in 0..2 {
4056            let want = -f.g[u + 1] / f_a;
4057            assert!(
4058                (a_t.g[u] - want).abs() < 1e-12,
4059                "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
4060                a_t.g[u],
4061                want
4062            );
4063        }
4064        // Second order textbook IFT (indices shifted by 1 for the a-slot).
4065        for u in 0..2 {
4066            for v in 0..2 {
4067                let f_uv = f.h[u + 1][v + 1];
4068                let f_au = f.h[0][u + 1];
4069                let f_av = f.h[0][v + 1];
4070                let f_aa = f.h[0][0];
4071                let want =
4072                    -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
4073                assert!(
4074                    (a_t.h[u][v] - want).abs() < 1e-12,
4075                    "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
4076                    a_t.h[u][v],
4077                    want
4078                );
4079            }
4080        }
4081    }
4082
4083    /// The moving-boundary flux tower reproduces every θ-derivative of a
4084    /// moving-limit integral, INCLUDING the second-order `B·z_uv` term the
4085    /// hand-written flux dropped (#932). The edge `z_R(θ) = θ₀ + θ₁²` has a
4086    /// genuinely nonzero `∂²z_R/∂θ₁² = 2`, so a combinator that omitted
4087    /// `B·z_uv` would miss the [1][1] Hessian entry. Truth = central FD of the
4088    /// closed-form integral `∫₀^{z_R} e^{−z²/2} dz = √(π/2)·erf(z_R/√2)`.
4089    #[test]
4090    fn moving_boundary_flux_carries_b_zuv_term() {
4091        use std::f64::consts::PI;
4092        let b = |z: f64| (-0.5 * z * z).exp(); // integrand B(z)
4093        // Antiderivative-based closed-form integral I(z_R) = ∫₀^{z_R} B dz.
4094        let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
4095        let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
4096        let th0 = [0.7_f64, 0.5_f64];
4097
4098        // Edge tower z_R(θ) over K=2 primaries: value + exact derivatives.
4099        let mut z_edge = Tower4::<2>::constant(z_r(th0));
4100        z_edge.g[0] = 1.0; // ∂z_R/∂θ₀ = 1
4101        z_edge.g[1] = 2.0 * th0[1]; // ∂z_R/∂θ₁ = 2θ₁
4102        z_edge.h[1][1] = 2.0; // ∂²z_R/∂θ₁² = 2  (the z_uv the old flux dropped)
4103
4104        // Integrand stack [B, B′, B″, B‴] at z₀: B′=−z·B, B″=(z²−1)·B,
4105        // B‴=(3z−z³)·B.
4106        let z0 = z_edge.v;
4107        let b0 = b(z0);
4108        let stack = [
4109            b0,
4110            -z0 * b0,
4111            (z0 * z0 - 1.0) * b0,
4112            (3.0 * z0 - z0 * z0 * z0) * b0,
4113        ];
4114        let flux = moving_limit_boundary_tower(&z_edge, stack);
4115
4116        // FD truth of the integral's derivatives.
4117        let h = 1e-4;
4118        let tol = 1e-6;
4119        for i in 0..2 {
4120            let mut up = th0;
4121            let mut dn = th0;
4122            up[i] += h;
4123            dn[i] -= h;
4124            let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
4125            assert!(
4126                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4127                "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
4128                flux.g[i],
4129                fd_g
4130            );
4131        }
4132        // The decisive entry: ∂²I/∂θ₁² = B′·(z_θ₁)² + B·z_θ₁θ₁. With z_θ₁=2θ₁=1
4133        // and z_θ₁θ₁=2, the B·z_uv contribution is B(z₀)·2 — omitting it would
4134        // leave the [1][1] entry short by exactly 2·B(z₀).
4135        let grad1_at = |th: [f64; 2]| -> f64 {
4136            let mut up = th;
4137            let mut dn = th;
4138            up[1] += h;
4139            dn[1] -= h;
4140            (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
4141        };
4142        let mut up = th0;
4143        let mut dn = th0;
4144        up[1] += h;
4145        dn[1] -= h;
4146        let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
4147        assert!(
4148            (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
4149            "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
4150            flux.h[1][1],
4151            fd_h11
4152        );
4153        // Explicit witness that the B·z_uv term is present and material:
4154        // analytic h[1][1] minus the pure (z_u)² part must equal B·z_uv = 2·B₀.
4155        let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
4156        let b_zuv = flux.h[1][1] - pure_zu2;
4157        assert!(
4158            (b_zuv - b0 * 2.0).abs() < 1e-10,
4159            "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
4160            b_zuv,
4161            b0 * 2.0
4162        );
4163    }
4164
4165    /// `moving_limit_boundary_tower_theta_integrand` reproduces the marginal-slope
4166    /// flex boundary closure for a θ-DEPENDENT integrand `G(z;θ)` — the case the
4167    /// plain `moving_limit_boundary_tower` cannot express, and the case the
4168    /// survival directional/bidirectional paths hand-assemble term-by-term
4169    /// (`G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`, with the directional path
4170    /// dropping `G·z_uv`). Two independent oracles:
4171    ///   (1) closed-form: the boundary flux of `∫ G dz` is exactly
4172    ///       `Φ(z_edge(θ);θ) − Φ(z₀;θ)` (Φ = z-antiderivative of G), whose θ
4173    ///       derivatives we take by central FD of the closed form — no jet code.
4174    ///   (2) the explicit second-order hand closure, including the `G·z_uv` term,
4175    ///       built from the integrand's own (z,θ) partials.
4176    /// G(z;θ) = exp(z·θ₀) is genuinely θ-dependent (G_θ₀ = z·e^{zθ₀} ≠ 0), and
4177    /// the edge z_edge = z₀ + θ₀ + θ₁² has a real z_uv = ∂²/∂θ₁² = 2, so a
4178    /// combinator that dropped either the integrand-θ terms or `G·z_uv` would
4179    /// miss a Hessian entry.
4180    #[test]
4181    fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
4182        // G(z;θ) = exp(z·θ₀);  Φ(z;θ) = ∫₀^z G = (e^{zθ₀} − 1)/θ₀.
4183        let g = |z: f64, t0: f64| (z * t0).exp();
4184        let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
4185        let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
4186        let th0 = [0.4_f64, 0.5_f64];
4187        let z0 = z_r(th0);
4188
4189        // Edge tower z_edge(θ) over K=2 primaries.
4190        let mut z_edge = Tower4::<2>::constant(z0);
4191        z_edge.g[0] = 1.0; // ∂z/∂θ₀
4192        z_edge.g[1] = 2.0 * th0[1]; // ∂z/∂θ₁
4193        z_edge.h[1][1] = 2.0; // ∂²z/∂θ₁² (the z_uv the directional path drops)
4194
4195        // Φ's mixed (z, θ) jet over K1 = 3 vars: slot 0 = z, slots 1,2 = θ₀,θ₁.
4196        // Built ONCE in tower arithmetic so every (z^i θ^j) partial is exact.
4197        let z_var = Tower4::<3>::variable(z0, 0);
4198        let t0_var = Tower4::<3>::variable(th0[0], 1);
4199        // θ₁ does not enter G/Φ here, but seed it so the jet carries the full
4200        // K1 frame (its Φ-derivatives are zero; the z_edge chain supplies all θ₁
4201        // motion through slot 0).
4202        let _t1_var = Tower4::<3>::variable(th0[1], 2);
4203        let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
4204        // Sanity: slot-0 first derivative of Φ IS G(z₀;θ₀).
4205        assert!(
4206            (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
4207            "Φ_z {:+.8e} != G {:+.8e}",
4208            phi_jet.g[0],
4209            g(z0, th0[0])
4210        );
4211
4212        let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
4213
4214        // Value channel is 0 by construction (boundary, not the integral itself).
4215        assert!(
4216            flux.v.abs() < 1e-12,
4217            "boundary value channel {:+.3e}",
4218            flux.v
4219        );
4220
4221        // Oracle (1): central FD of the closed-form boundary flux
4222        //   Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ)   (z₀ FROZEN at the base edge).
4223        let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
4224        let h = 1e-4;
4225        let tol = 1e-6;
4226        for i in 0..2 {
4227            let mut up = th0;
4228            let mut dn = th0;
4229            up[i] += h;
4230            dn[i] -= h;
4231            let fd_g = (bnd(up) - bnd(dn)) / (2.0 * h);
4232            assert!(
4233                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4234                "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
4235                flux.g[i],
4236                fd_g
4237            );
4238        }
4239        let grad_at = |th: [f64; 2], j: usize| -> f64 {
4240            let mut up = th;
4241            let mut dn = th;
4242            up[j] += h;
4243            dn[j] -= h;
4244            (bnd(up) - bnd(dn)) / (2.0 * h)
4245        };
4246        for i in 0..2 {
4247            for j in 0..2 {
4248                let mut up = th0;
4249                let mut dn = th0;
4250                up[i] += h;
4251                dn[i] -= h;
4252                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4253                assert!(
4254                    (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4255                    "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
4256                    flux.h[i][j],
4257                    fd_h
4258                );
4259            }
4260        }
4261
4262        // Oracle (2): the explicit second-order hand closure, term by term —
4263        // `G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`. Read G's partials at the
4264        // base point directly (no jet): G = e^{zθ₀}, G_z = θ₀·G, G_θ₀ = z·G,
4265        // G_θ₁ = 0.
4266        let gg = g(z0, th0[0]);
4267        let g_z = th0[0] * gg;
4268        let g_theta = [z0 * gg, 0.0]; // [G_θ₀, G_θ₁]
4269        for i in 0..2 {
4270            for j in 0..2 {
4271                let z_u = z_edge.g[i];
4272                let z_v = z_edge.g[j];
4273                let z_uv = z_edge.h[i][j];
4274                let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
4275                assert!(
4276                    (flux.h[i][j] - hand).abs() < 1e-9,
4277                    "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
4278                    flux.h[i][j],
4279                    hand
4280                );
4281            }
4282        }
4283
4284        // Decisive: the `G·z_uv` term the directional path DROPS is present and
4285        // material in the [1][1] entry (z_uv = 2 there).
4286        let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
4287        let g_zuv = flux.h[1][1] - pure_no_zuv;
4288        assert!(
4289            (g_zuv - gg * 2.0).abs() < 1e-9,
4290            "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
4291            g_zuv,
4292            gg * 2.0
4293        );
4294    }
4295
4296    /// The survival crossing-edge position tower `z_edge = (τ − a(θ)) / b`,
4297    /// `b = exp(g)`, built from the intercept tower `a(θ)` (here a stand-in)
4298    /// and the seeded slope `g`, reproduces taylor-jet's exact hand-path
4299    /// boundary-velocity formulas:
4300    ///   z_u   = −(a_u + [u==g]·z) / b
4301    ///   z_uv  = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
4302    /// This pins the bridge between `implicit_solve` and
4303    /// `cell_moving_boundary_flux_tower`: the boundary jet that the production
4304    /// flex path hand-codes (and dropped `z_uv` from) is exactly `∂²` of this
4305    /// tower. K=3 reduced frame: slot 0 = a-axis carrier (an arbitrary smooth
4306    /// a(θ) with nonzero a_u/a_uv), slot 1 = g (the log-slope), slot 2 unused.
4307    #[test]
4308    fn crossing_edge_tower_matches_handpath_velocity_formulas() {
4309        const TAU: f64 = 1.3; // the link-knot crossing threshold τ
4310        let g_idx = 1usize;
4311        let g0 = 0.85_f64; // the slope value b (the g-primary IS the slope)
4312        // Stand-in intercept tower a(θ): nonzero value, gradient, Hessian in the
4313        // two live axes so a_u and a_uv are both exercised. (In production this
4314        // comes from implicit_solve; here we plant known derivatives.)
4315        let mut a = Tower4::<3>::constant(0.45);
4316        a.g[0] = 0.7;
4317        a.g[1] = -0.3;
4318        a.h[0][0] = 0.25;
4319        a.h[0][1] = 0.11;
4320        a.h[1][0] = 0.11;
4321        a.h[1][1] = -0.08;
4322
4323        // In the survival flex frame the slope `b` IS the g-primary directly
4324        // (the directional code passes `g` as `b`, and ∂z/∂g uses ∂b/∂g = 1):
4325        // z_edge = (τ − a) / b with b seeded as the g-axis variable.
4326        let b = Tower4::<3>::variable(g0, g_idx);
4327        let z_edge = (Tower4::<3>::constant(TAU) - a) / b;
4328
4329        let bv = g0;
4330        let z0 = z_edge.v;
4331        assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);
4332
4333        // z_u = −(a_u + [u==g]·z) / b.
4334        for u in 0..2 {
4335            let direct = if u == g_idx { z0 } else { 0.0 };
4336            let want = -(a.g[u] + direct) / bv;
4337            assert!(
4338                (z_edge.g[u] - want).abs() < 1e-10,
4339                "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
4340                z_edge.g[u],
4341                want
4342            );
4343        }
4344        // z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b, using the tower's own
4345        // first-order z_v/z_u (already verified above).
4346        for u in 0..2 {
4347            for v in 0..2 {
4348                let cross = if u == g_idx { z_edge.g[v] } else { 0.0 }
4349                    + if v == g_idx { z_edge.g[u] } else { 0.0 };
4350                let want = -(a.h[u][v] + cross) / bv;
4351                assert!(
4352                    (z_edge.h[u][v] - want).abs() < 1e-10,
4353                    "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
4354                    z_edge.h[u][v],
4355                    want
4356                );
4357            }
4358        }
4359    }
4360
4361    /// The crossing-edge tower in the CONSTRAINT frame (intercept `a` and
4362    /// slope `b` BOTH independent — slots 0 and 1) reproduces taylor-jet's
4363    /// FD-certified bare boundary-velocity constants exactly:
4364    ///   z_a  = ∂z/∂a   = −1/b
4365    ///   z_ab = ∂²z/∂a∂b = +1/b²
4366    ///   z_aa = ∂²z/∂a²  = 0
4367    ///   z_bb = ∂²z/∂b²  = +2(τ−a)/b³
4368    /// These are the `f_a`/`f_au`/`f_aa` constraint-jet boundary motions the
4369    /// production base path drops (and only adds in the dir twins, causing the
4370    /// #932 desync). Here `a` is independent (NOT yet substituted with a(θ)),
4371    /// so `z_aa = 0` and there is no `a_uv` chain — `implicit_solve` introduces
4372    /// that later. Pins the constant before the constraint-tower wiring.
4373    #[test]
4374    fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
4375        const TAU: f64 = 1.3;
4376        let a0 = 0.45_f64;
4377        let b0 = 0.85_f64;
4378        // Slot 0 = a, slot 1 = b, both seeded independent.
4379        let a = Tower4::<2>::variable(a0, 0);
4380        let b = Tower4::<2>::variable(b0, 1);
4381        let z = (Tower4::<2>::constant(TAU) - a) / b;
4382
4383        assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
4384        assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
4385        assert!(
4386            (z.h[0][1] - 1.0 / (b0 * b0)).abs() < 1e-12,
4387            "z_ab {:+.10e} vs +1/b² {:+.10e}",
4388            z.h[0][1],
4389            1.0 / (b0 * b0)
4390        );
4391        assert!(
4392            z.h[0][0].abs() < 1e-12,
4393            "z_aa must vanish, got {:+.10e}",
4394            z.h[0][0]
4395        );
4396        let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);
4397        assert!(
4398            (z.h[1][1] - want_zbb).abs() < 1e-12,
4399            "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
4400            z.h[1][1],
4401            want_zbb
4402        );
4403    }
4404
4405    /// The oracle harness catches a planted #736-style sign flip in a
4406    /// cross block and reports the channel by name.
4407    #[test]
4408    fn oracle_catches_planted_cross_block_sign_flip() {
4409        let prog = LocScaleProgram {
4410            eta: vec![0.3],
4411            s: vec![-0.5],
4412            y: vec![1.0],
4413        };
4414        let t = evaluate_program(&prog, 0).expect("tower");
4415        let dir = [0.6, -0.2];
4416        let mut third = t.third_contracted(&dir);
4417        let honest = KernelChannels {
4418            value: t.v,
4419            gradient: t.g,
4420            hessian: t.h,
4421            third: vec![(dir, third)],
4422            fourth: vec![(dir, [1.0, 0.5], t.fourth_contracted(&dir, &[1.0, 0.5]))],
4423        };
4424        verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");
4425
4426        // Plant the #736 flip: negate one mixed cross entry.
4427        third[0][1] = -third[0][1];
4428        let flipped = KernelChannels {
4429            value: t.v,
4430            gradient: t.g,
4431            hessian: t.h,
4432            third: vec![(dir, third)],
4433            fourth: vec![],
4434        };
4435        let err = verify_kernel_channels(&t, &flipped, 1e-10)
4436            .expect_err("planted sign flip must be caught");
4437        assert!(
4438            err.contains("third[0][0][1]"),
4439            "oracle must name the flipped channel, got: {err}"
4440        );
4441    }
4442
4443    /// The third- and fourth-order tensors must be FULLY symmetric under
4444    /// index permutation (mixed partials commute). The tower stores them
4445    /// unsymmetrized, so equal-by-construction is a real invariant of the
4446    /// Leibniz/Faà di Bruno writes — a cheap typo tripwire. Asserted on a
4447    /// nontrivial K=3 tower with all of div/sqrt/powf/exp/ln exercised, so
4448    /// every composition path contributes. Lives in a test (not the hot
4449    /// per-op path) on purpose.
4450    #[test]
4451    fn t3_t4_are_fully_index_symmetric() {
4452        let prog = GnarlyProgram::fixture();
4453        // 3! = 6 permutations of three indices.
4454        let perms3: [[usize; 3]; 6] = [
4455            [0, 1, 2],
4456            [0, 2, 1],
4457            [1, 0, 2],
4458            [1, 2, 0],
4459            [2, 0, 1],
4460            [2, 1, 0],
4461        ];
4462        // 4! = 24 permutations of four indices.
4463        let perms4: [[usize; 4]; 24] = [
4464            [0, 1, 2, 3],
4465            [0, 1, 3, 2],
4466            [0, 2, 1, 3],
4467            [0, 2, 3, 1],
4468            [0, 3, 1, 2],
4469            [0, 3, 2, 1],
4470            [1, 0, 2, 3],
4471            [1, 0, 3, 2],
4472            [1, 2, 0, 3],
4473            [1, 2, 3, 0],
4474            [1, 3, 0, 2],
4475            [1, 3, 2, 0],
4476            [2, 0, 1, 3],
4477            [2, 0, 3, 1],
4478            [2, 1, 0, 3],
4479            [2, 1, 3, 0],
4480            [2, 3, 0, 1],
4481            [2, 3, 1, 0],
4482            [3, 0, 1, 2],
4483            [3, 0, 2, 1],
4484            [3, 1, 0, 2],
4485            [3, 1, 2, 0],
4486            [3, 2, 0, 1],
4487            [3, 2, 1, 0],
4488        ];
4489        for row in 0..prog.n_rows() {
4490            let t = evaluate_program(&prog, row).expect("gnarly tower");
4491            let scale_t3 =
4492                t.t3.iter()
4493                    .flatten()
4494                    .flatten()
4495                    .fold(0.0_f64, |m, x| m.max(x.abs()))
4496                    .max(1.0);
4497            let scale_t4 =
4498                t.t4.iter()
4499                    .flatten()
4500                    .flatten()
4501                    .flatten()
4502                    .fold(0.0_f64, |m, x| m.max(x.abs()))
4503                    .max(1.0);
4504            for i in 0..3 {
4505                for j in 0..3 {
4506                    for k in 0..3 {
4507                        let base = t.t3[i][j][k];
4508                        let idx = [i, j, k];
4509                        for p in &perms3 {
4510                            let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
4511                            assert!(
4512                                (base - permed).abs() <= 1e-12 * scale_t3,
4513                                "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
4514                                 permuted {permed:+.15e} under {p:?}"
4515                            );
4516                        }
4517                        for l in 0..3 {
4518                            let base4 = t.t4[i][j][k][l];
4519                            let idx4 = [i, j, k, l];
4520                            for p in &perms4 {
4521                                let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
4522                                assert!(
4523                                    (base4 - permed).abs() <= 1e-12 * scale_t4,
4524                                    "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
4525                                     permuted {permed:+.15e} under {p:?}"
4526                                );
4527                            }
4528                        }
4529                    }
4530                }
4531            }
4532        }
4533    }
4534}
4535
4536#[inline]
4537fn erfcx_nonnegative(x: f64) -> f64 {
4538    if !x.is_finite() {
4539        return if x.is_sign_positive() {
4540            0.0
4541        } else {
4542            f64::INFINITY
4543        };
4544    }
4545    if x <= 0.0 {
4546        return 1.0;
4547    }
4548    if x < 26.0 {
4549        ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
4550    } else {
4551        let inv = 1.0 / x;
4552        let inv2 = inv * inv;
4553        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
4554            + 6.5625 * inv2 * inv2 * inv2 * inv2;
4555        inv * poly / std::f64::consts::PI.sqrt()
4556    }
4557}
4558
4559#[inline]
4560fn log1mexp_positive(a: f64) -> f64 {
4561    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
4562    if a > core::f64::consts::LN_2 {
4563        (-(-a).exp()).ln_1p()
4564    } else if a > 0.0 {
4565        (-(-a).exp_m1()).ln()
4566    } else {
4567        f64::NEG_INFINITY
4568    }
4569}
4570
4571#[inline]
4572fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
4573    if x == f64::INFINITY {
4574        return (0.0, 0.0);
4575    }
4576    if x == f64::NEG_INFINITY {
4577        return (f64::NEG_INFINITY, f64::INFINITY);
4578    }
4579    if x.is_nan() {
4580        return (f64::NAN, f64::NAN);
4581    }
4582    if x < 0.0 {
4583        let u = -x / std::f64::consts::SQRT_2;
4584        let ex = erfcx_nonnegative(u).max(1e-300);
4585        let log_cdf = -u * u + (0.5 * ex).ln();
4586        let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
4587        (log_cdf, lambda)
4588    } else {
4589        let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
4590        let lambda = crate::probability::normal_pdf(x) / cdf;
4591        (cdf.ln(), lambda)
4592    }
4593}
4594
4595/// Stable derivative stack for `log Phi(x)` through fourth order.
4596#[inline]
4597pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
4598    let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
4599    let lambda2 = lambda * lambda;
4600    let lambda3 = lambda2 * lambda;
4601    let x2 = x * x;
4602    [
4603        log_cdf,
4604        lambda,
4605        -lambda * (x + lambda),
4606        lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
4607        -lambda
4608            * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
4609    ]
4610}
4611
4612/// Stable derivative stack for `log(1 - exp(-x))`, `x > 0`, through fourth order.
4613#[inline]
4614pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
4615    let r = 1.0 / x.exp_m1();
4616    [
4617        log1mexp_positive(x),
4618        r,
4619        -r * (1.0 + r),
4620        r * (1.0 + r) * (1.0 + 2.0 * r),
4621        -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
4622    ]
4623}
4624// ── The RowJet bridge oracle (CI) ─────────────────────────────────────
4625#[cfg(test)]
4626mod rowjet_bridge_tests {
4627    use super::*;
4628    use crate::jet_scalar::{JetScalar, Order2};
4629
4630    /// A toy row-NLL written ONCE over the [`RowJet`] bridge: a product, a sum, a
4631    /// subtraction, a scale/neg, a constant, and two value-distinct
4632    /// `compose_unary_with` stacks (an exp stack and a smooth finite-everywhere
4633    /// stack), plus a domain `guard`. The body is generic over `R: RowJet<2>`, so
4634    /// the SAME source instantiates at the scalar jets and the `f64x4` lane towers.
4635    struct ToyProgram {
4636        primaries: Vec<[f64; 2]>,
4637        /// Per-row CONTINUOUS auxiliary data `[cov, z, wi]` — the survival
4638        /// `covariance_ones` / `z_sum` / observation-weight analogues that enter
4639        /// the jet algebra as `.scale_rows(per_row_value)`, distinct per lane.
4640        aux: Vec<[f64; 3]>,
4641    }
4642
4643    impl ToyProgram {
4644        /// The body uses `pack_rows` to gather the per-lane continuous data from
4645        /// the lane→row map and `scale_rows` to fold it in — so a 4-row batch
4646        /// carries four DISTINCT cov/z/wi, which the single-`f64` `scale` could not.
4647        fn body<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> R {
4648            let cov = R::pack_rows(rows, |r| self.aux[r][0]);
4649            let z = R::pack_rows(rows, |r| self.aux[r][1]);
4650            let wi = R::pack_rows(rows, |r| self.aux[r][2]);
4651
4652            let a = p[0].mul(&p[1]).scale_rows(cov);
4653            let b = a.add(&R::constant(0.5)).sub(&p[0].scale(0.25));
4654            let c = b
4655                .compose_unary_with(|u| {
4656                    let e = u.exp();
4657                    [e, e, e, e, e]
4658                })
4659                .scale_rows(z);
4660            let d = c.neg().add(&p[0]);
4661            let e = d
4662                .compose_unary_with(|u| {
4663                    let s = (1.0 + u * u).sqrt();
4664                    let s3 = s * s * s;
4665                    let s5 = s3 * s * s;
4666                    let s7 = s5 * s * s;
4667                    [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
4668                })
4669                .scale_rows(wi);
4670            e.mul(&p[1]).add(&e)
4671        }
4672    }
4673
4674    impl RowNllProgramRowJet<2> for ToyProgram {
4675        fn n_rows(&self) -> usize {
4676            self.primaries.len()
4677        }
4678        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
4679            Ok(self.primaries[row])
4680        }
4681        fn row_nll<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> Result<R, String> {
4682            assert!(rows.len() == 1 || rows.len() == 4, "lane→row map is 1 or 4 wide");
4683            Ok(self.body(rows, p))
4684        }
4685    }
4686
4687    fn assert_t4_bits_eq(a: &Tower4<2>, b: &Tower4<2>, ctx: &str) {
4688        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4689        for i in 0..2 {
4690            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4691            for j in 0..2 {
4692                assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4693                for k in 0..2 {
4694                    assert_eq!(
4695                        a.t3[i][j][k].to_bits(),
4696                        b.t3[i][j][k].to_bits(),
4697                        "{ctx}: t3[{i}][{j}][{k}]"
4698                    );
4699                    for l in 0..2 {
4700                        assert_eq!(
4701                            a.t4[i][j][k][l].to_bits(),
4702                            b.t4[i][j][k][l].to_bits(),
4703                            "{ctx}: t4[{i}][{j}][{k}][{l}]"
4704                        );
4705                    }
4706                }
4707            }
4708        }
4709    }
4710
4711    fn assert_t3_bits_eq(a: &Tower3<2>, b: &Tower3<2>, ctx: &str) {
4712        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4713        for i in 0..2 {
4714            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4715            for j in 0..2 {
4716                assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4717                for k in 0..2 {
4718                    assert_eq!(
4719                        a.t3[i][j][k].to_bits(),
4720                        b.t3[i][j][k].to_bits(),
4721                        "{ctx}: t3[{i}][{j}][{k}]"
4722                    );
4723                }
4724            }
4725        }
4726    }
4727
4728    // Deterministic LCG with signed-zero injection and per-lane-distinct values.
4729    struct Lcg(u64);
4730    impl Lcg {
4731        fn next(&mut self) -> f64 {
4732            self.0 = self
4733                .0
4734                .wrapping_mul(6364136223846793005)
4735                .wrapping_add(1442695040888963407);
4736            (self.0 >> 11) as f64 / (1u64 << 53) as f64
4737        }
4738        fn val(&mut self) -> f64 {
4739            let u = self.next();
4740            if u < 0.04 {
4741                return 0.0;
4742            }
4743            if u < 0.08 {
4744                return -0.0;
4745            }
4746            (self.next() - 0.5) * 5.0
4747        }
4748    }
4749
4750    /// Lane `i` of the batched order-4 / order-3 tower is `to_bits`-identical to
4751    /// the scalar tower on row `i`, for ≥2000 distinct 4-row batches with
4752    /// signed-zero and per-lane-distinct primaries.
4753    #[test]
4754    fn batched_lane_i_matches_scalar_row_i_bit_identical() {
4755        let mut rng = Lcg(0xA5A5_1234_DEAD_BEEF);
4756        let mut batches = 0usize;
4757        for _ in 0..2500 {
4758            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4759            // per-lane-DISTINCT continuous aux (cov/z/wi), signed-zero injected.
4760            let aux: [[f64; 3]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4761            let prog = ToyProgram { primaries: bases.to_vec(), aux: aux.to_vec() };
4762            let rows = [0usize, 1, 2, 3];
4763
4764            // order-4 batch vs scalar Tower4 (instantiated through the same body).
4765            let batch4 = generic_batched_fourth_tower(&prog, rows).expect("batch4");
4766            for (row, base) in bases.iter().enumerate() {
4767                let vars: [Tower4<2>; 2] =
4768                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4769                let scal = prog.row_nll(&[row], &vars).expect("scalar tower4");
4770                assert_t4_bits_eq(&batch4.lane(row), &scal, "batched_fourth");
4771            }
4772
4773            // order-3 batch vs scalar Tower3.
4774            let batch3 = generic_batched_third_tower(&prog, rows).expect("batch3");
4775            for (row, base) in bases.iter().enumerate() {
4776                let vars: [Tower3<2>; 2] =
4777                    std::array::from_fn(|a| <Tower3<2> as RowJet<2>>::variable(base[a], a));
4778                let scal = prog.row_nll(&[row], &vars).expect("scalar tower3");
4779                assert_t3_bits_eq(&batch3.lane(row), &scal, "batched_third");
4780            }
4781            batches += 1;
4782        }
4783        assert_eq!(batches, 2500);
4784    }
4785
4786    /// The blanket impl does not churn the scalar path: the body driven through
4787    /// `RowJet` ops is `to_bits`-identical to the body driven directly through
4788    /// `JetScalar` ops, and `rowjet_row_kernel`'s `(v, g, H)` matches the dense
4789    /// `Tower4` lower channels.
4790    #[test]
4791    fn blanket_scalar_path_is_unchanged_and_consistent() {
4792        let mut rng = Lcg(0x0BAD_F00D_1357_2468);
4793        for _ in 0..3000 {
4794            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
4795            let aux0: [f64; 3] = std::array::from_fn(|_| rng.val());
4796            let prog = ToyProgram { primaries: vec![base], aux: vec![aux0] };
4797
4798            // (a) RowJet-driven body == JetScalar-driven body, bit-for-bit. The
4799            // reference body uses `scale(f64)` where the RowJet body uses
4800            // `scale_rows(f64)` — proving the scalar `scale_rows` rewrite does not
4801            // churn the path (`scale_rows(s) == scale(s)` on `Value = f64`).
4802            let via_rowjet: Tower4<2> = {
4803                let vars: [Tower4<2>; 2] =
4804                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4805                prog.row_nll(&[0], &vars).expect("rowjet")
4806            };
4807            let via_jetscalar: Tower4<2> = {
4808                let vars: [Tower4<2>; 2] = std::array::from_fn(|a| {
4809                    <Tower4<2> as JetScalar<2>>::variable(base[a], a)
4810                });
4811                let (cov, z, wi) = (aux0[0], aux0[1], aux0[2]);
4812                // The body using JetScalar's own ops + scale(f64) directly.
4813                let a = vars[0].mul(&vars[1]).scale(cov);
4814                let b = a.add(&Tower4::constant(0.5)).sub(&vars[0].scale(0.25));
4815                let c = b
4816                    .compose_unary_with(|u| {
4817                        let e = u.exp();
4818                        [e, e, e, e, e]
4819                    })
4820                    .scale(z);
4821                let d = JetScalar::neg(&c).add(&vars[0]);
4822                let e = d
4823                    .compose_unary_with(|u| {
4824                        let s = (1.0 + u * u).sqrt();
4825                        let s3 = s * s * s;
4826                        let s5 = s3 * s * s;
4827                        let s7 = s5 * s * s;
4828                        [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
4829                    })
4830                    .scale(wi);
4831                e.mul(&vars[1]).add(&e)
4832            };
4833            assert_t4_bits_eq(&via_rowjet, &via_jetscalar, "blanket_vs_direct");
4834
4835            // (b) rowjet_row_kernel (v,g,H) == dense Tower4 lower channels.
4836            // Order2 and Tower4 use different internal representations so
4837            // signed-zero differences (−0.0 vs +0.0) may arise in gradient/
4838            // Hessian channels that evaluate to exactly zero; IEEE equality
4839            // treats these as equal, so `==` is the right comparison here.
4840            let (v, g, h) = rowjet_row_kernel(&prog, 0).expect("kernel");
4841            assert_eq!(v.to_bits(), via_rowjet.v.to_bits(), "kernel v");
4842            for i in 0..2 {
4843                assert!(g[i] == via_rowjet.g[i], "kernel g[{i}]: {} vs {}", g[i], via_rowjet.g[i]);
4844                for j in 0..2 {
4845                    assert!(
4846                        h[i][j] == via_rowjet.h[i][j],
4847                        "kernel h[{i}][{j}]: {} vs {}",
4848                        h[i][j],
4849                        via_rowjet.h[i][j]
4850                    );
4851                }
4852            }
4853
4854            // (c) the Order2 scalar IS a RowJet via the blanket.
4855            let o2: [Order2<2>; 2] =
4856                std::array::from_fn(|a| <Order2<2> as RowJet<2>>::variable(base[a], a));
4857            let _ = prog.body(&[0], &o2);
4858        }
4859    }
4860
4861    /// On the scalar path (`Value = f64`) `scale_rows(s)` is `to_bits`-identical
4862    /// to `scale(s)` for EVERY channel — so rewriting a survival `.scale(per_row)`
4863    /// to `.scale_rows(per_row)` cannot perturb the existing scalar fits.
4864    #[test]
4865    fn scale_rows_scalar_is_bit_identical_to_scale() {
4866        let mut rng = Lcg(0xFEED_FACE_0042_1001);
4867        for _ in 0..3000 {
4868            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
4869            let s = rng.val();
4870            // Build a dense tower with populated channels (exp of a product).
4871            let vars: [Tower4<2>; 2] =
4872                std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4873            let jet = vars[0].mul(&vars[1]).compose_unary_with(|u| {
4874                let e = u.exp();
4875                [e, e, e, e, e]
4876            });
4877            let via_scale = RowJet::scale(&jet, s);
4878            let via_scale_rows = RowJet::scale_rows(&jet, s);
4879            assert_t4_bits_eq(&via_scale_rows, &via_scale, "scale_rows==scale");
4880        }
4881    }
4882
4883    /// `scale_rows` on a batch multiplies lane `i` by `s[i]`, so lane `i` of a
4884    /// per-lane-scaled batch matches the scalar `scale(s[i])` on row `i` — the
4885    /// continuous per-row data path the single-`f64` `scale` could not carry.
4886    #[test]
4887    fn batched_scale_rows_matches_per_row_scalar_scale() {
4888        let mut rng = Lcg(0x1357_9BDF_2468_ACE0);
4889        for _ in 0..2500 {
4890            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4891            let s: [f64; 4] = std::array::from_fn(|_| rng.val());
4892            let batch: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
4893                Tower4Batch::variable(
4894                    wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]),
4895                    a,
4896                )
4897            });
4898            let prod = batch[0].mul(&batch[1]).compose_unary_with(|u| {
4899                let e = u.exp();
4900                [e, e, e, e, e]
4901            });
4902            let scaled = prod.scale_rows(s);
4903            for (row, base) in bases.iter().enumerate() {
4904                let v: [Tower4<2>; 2] =
4905                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4906                let prod_s = v[0].mul(&v[1]).compose_unary_with(|u| {
4907                    let e = u.exp();
4908                    [e, e, e, e, e]
4909                });
4910                let ref_s = RowJet::scale(&prod_s, s[row]);
4911                assert_t4_bits_eq(&scaled.lane(row), &ref_s, "batched_scale_rows");
4912            }
4913        }
4914    }
4915
4916    /// The per-lane guard reports exactly the failing lanes on a batch and the
4917    /// single lane on a scalar jet.
4918    #[test]
4919    fn guard_reports_per_lane_failures() {
4920        let cols: [[f64; 2]; 4] = [[1.0, 0.5], [-2.0, 0.5], [3.0, 0.5], [-0.0, 0.5]];
4921        let vars: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
4922            Tower4Batch::variable(
4923                wide::f64x4::new([cols[0][a], cols[1][a], cols[2][a], cols[3][a]]),
4924                a,
4925            )
4926        });
4927        let verdict = vars[0].guard(|v| v > 0.0);
4928        assert_eq!(verdict.lanes(), 4);
4929        assert!(verdict.any_failed());
4930        assert!(!verdict.all_pass());
4931        assert!(!verdict.lane_failed(0));
4932        assert!(verdict.lane_failed(1));
4933        assert!(!verdict.lane_failed(2));
4934        assert!(verdict.lane_failed(3));
4935        assert_eq!(verdict.failed_mask(), 0b1010);
4936
4937        let s_ok = <Tower4<2> as RowJet<2>>::variable(1.0, 0);
4938        let s_bad = <Tower4<2> as RowJet<2>>::variable(-1.0, 0);
4939        assert!(RowJet::guard(&s_ok, |v| v > 0.0).all_pass());
4940        assert!(RowJet::guard(&s_bad, |v| v > 0.0).any_failed());
4941        assert_eq!(RowJet::guard(&s_ok, |v| v > 0.0).lanes(), 1);
4942    }
4943}