gam_math/jet_tower.rs
//! Taylor-jet tower algebra: write each family's row log-likelihood ONCE,
//! derive the entire `RowKernel<K>` derivative tower mechanically (#932).
//!
//! # The object
//!
//! [`Tower4<K>`] is a truncated multivariate Taylor scalar in `K` primary
//! variables, carrying the value and ALL partial derivatives through fourth
//! order as full (unsymmetrized) tensors:
//!
//! ```text
//! v          ℓ
//! g[a]       ∂ℓ/∂p_a
//! h[a][b]    ∂²ℓ/∂p_a∂p_b
//! t3[abc]    ∂³ℓ/∂p_a∂p_b∂p_c
//! t4[abcd]   ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d
//! ```
//!
//! Arithmetic (`+ − × ÷`, scalar mixes) propagates the tower by the exact
//! Leibniz rule; unary transcendentals propagate by the exact multivariate
//! Faà di Bruno formula given a `[f, f′, f″, f‴, f⁗]` stack evaluated at the
//! inner value. This is truncated Taylor ALGEBRA — exact derivatives of the
//! evaluated expression, not finite differences, not an approximation —
//! fully compatible with the exact-REML-only policy.
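//!
//! For a single variable (K = 1) the subset sum behind the Leibniz rule
//! collapses to the binomial form `(ab)⁽ⁿ⁾ = Σₖ C(n, k) · a⁽ᵏ⁾ · b⁽ⁿ⁻ᵏ⁾`. The
//! sketch below uses a hypothetical standalone `Jet1` type, not `Tower4`'s
//! code. It illustrates that multiplying seeded jets yields exact derivatives
//! (here of x³ at x = 2) rather than difference quotients.
//!
//! ```rust
//! // Hypothetical univariate jet: `d[n]` is the n-th derivative, n = 0..=4.
//! #[derive(Clone, Copy, Debug, PartialEq)]
//! struct Jet1 {
//!     d: [f64; 5],
//! }
//!
//! impl Jet1 {
//!     // The seeded variable x = value: unit first derivative, zero above.
//!     fn variable(value: f64) -> Self {
//!         Jet1 { d: [value, 1.0, 0.0, 0.0, 0.0] }
//!     }
//!
//!     // Exact truncated Leibniz product (ab)⁽ⁿ⁾ = Σₖ C(n, k) · a⁽ᵏ⁾ · b⁽ⁿ⁻ᵏ⁾.
//!     fn mul(&self, o: &Self) -> Self {
//!         // Binomial coefficients C(n, k) for n ≤ 4.
//!         const C: [[f64; 5]; 5] = [
//!             [1.0, 0.0, 0.0, 0.0, 0.0],
//!             [1.0, 1.0, 0.0, 0.0, 0.0],
//!             [1.0, 2.0, 1.0, 0.0, 0.0],
//!             [1.0, 3.0, 3.0, 1.0, 0.0],
//!             [1.0, 4.0, 6.0, 4.0, 1.0],
//!         ];
//!         let mut out = [0.0; 5];
//!         for n in 0..5 {
//!             for k in 0..=n {
//!                 out[n] += C[n][k] * self.d[k] * o.d[n - k];
//!             }
//!         }
//!         Jet1 { d: out }
//!     }
//! }
//! ```
//!
//! Seeding `x = Jet1::variable(2.0)` and forming `x.mul(&x).mul(&x)` gives
//! `[8, 12, 12, 6, 0]`, i.e. x³ and its first four derivatives at x = 2.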
//!
//! One evaluation of a row NLL program at seeded variables yields, in a
//! single pass, every channel the [`super::row_kernel::RowKernel`] trait
//! demands: `row_kernel` (value/∇/H), `row_third_contracted(dir)` (contract
//! `t3` with `dir`), and `row_fourth_contracted(u, v)` (contract `t4` with
//! `u` and `v`). The directional cross-channels that hand-written towers
//! drop (#736's residual gap) cannot be dropped here: there is no separate
//! "channel" to forget — every derivative of the one expression is carried.
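//!
//! Contraction is plain linear algebra on the stored tensor. As an
//! illustrative check with hypothetical helper names (K = 2), take
//! ℓ(p) = p₀²p₁, whose only nonzero third derivatives are the permutations of
//! ∂₀∂₀∂₁ = 2. Its Hessian is `[[2p₁, 2p₀], [2p₀, 0]]`, so contracting `t3`
//! with a direction `d` must give that Hessian's directional derivative,
//! `[[2d₁, 2d₀], [2d₀, 0]]`:
//!
//! ```rust
//! // Third-derivative tensor of ℓ(p) = p0² · p1 (it is constant in p):
//! // every axis permutation of (0, 0, 1) holds 2, everything else is zero.
//! fn t3_of_p0sq_p1() -> [[[f64; 2]; 2]; 2] {
//!     let mut t3 = [[[0.0; 2]; 2]; 2];
//!     for a in 0..2 {
//!         for b in 0..2 {
//!             for c in 0..2 {
//!                 // Exactly one axis equal to 1 means the multiset {0, 0, 1}.
//!                 if a + b + c == 1 {
//!                     t3[a][b][c] = 2.0;
//!                 }
//!             }
//!         }
//!     }
//!     t3
//! }
//!
//! // out[a][b] = Σ_c t3[a][b][c] · dir[c], the `row_third_contracted` shape.
//! fn contract_t3<const K: usize>(t3: &[[[f64; K]; K]; K], dir: &[f64; K]) -> [[f64; K]; K] {
//!     let mut out = [[0.0; K]; K];
//!     for a in 0..K {
//!         for b in 0..K {
//!             for c in 0..K {
//!                 out[a][b] += t3[a][b][c] * dir[c];
//!             }
//!         }
//!     }
//!     out
//! }
//! ```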
//!
//! # Why this exists (the bug genus)
//!
//! Every family today hand-writes its tower: value in one function,
//! gradient in another, `pdfthird_derivative`/`pdffourth_derivative`,
//! entry/exit-specific cross blocks — thousands of lines of calculus that
//! drift. #736 was a sign flip in a hand-written cross-Hessian block,
//! invisible until a new consumer touched it; #948 is a derivative path
//! that is not the derivative of the evaluated row loss (clamped-μ
//! surrogate); the objective↔gradient desync class is the same disease at
//! the criterion level. A tower-derived kernel is exact-by-construction:
//! the value channel IS the production loss expression, so its derivative
//! channels cannot desync from it.
//!
//! # Relation to `jet_partitions::MultiDirJet`
//!
//! The tree already carries a *directional* jet (bitmask coefficients over
//! distinct seeded directions, heap-allocated, Bell-partition compose) used
//! inside the marginal-slope and latent-survival families. It answers "the
//! derivative along THESE specific directions" and must be re-seeded and
//! re-evaluated per direction tuple (e.g. 10 symmetric `(a,b)` pairs for a
//! K=4 fourth contraction). `Tower4` answers ALL of them from one
//! evaluation: contraction happens AFTER differentiation, as plain linear
//! algebra on the stored tensors. Use `MultiDirJet` when you need a handful
//! of directions of a huge-K expression; use `Tower4` when you need the
//! complete small-K tower — which is exactly the `RowKernel<K≤4>` shape.
//! The `[f64; 5]` unary-derivative stacks (`unary_derivatives_neglog_phi`,
//! …) are signature-compatible with [`Tower4::compose_unary`], so the
//! families' existing special-function stacks are directly reusable.
//!
//! # Stability discipline (why this is NOT autodiff)
//!
//! Differentiating the primal code path inherits its instabilities: a jet
//! pushed through a naive `ln(1 + e^η)` overflows or loses precision in the
//! saturated tails, even though the true derivative σ(η) is benign there.
//! This module therefore splits responsibility: **humans own primitive
//! stability, the algebra owns combinatorics**. Tail-critical special
//! functions enter a program ONLY as hand-certified `[f64; 5]` derivative
//! stacks through [`Tower4::compose_unary`] — the same stacks the families
//! already write (`unary_derivatives_neglog_phi` and friends, built on
//! erfcx/log_ndtr) — and the tower mechanizes only the Leibniz/Faà di Bruno
//! composition, which is where hand-written towers actually fail (#736 was
//! a composition sign flip, not a primitive error). Program authors must use
//! a stable primitive stack wherever the f64 production loss does; the
//! convenience methods (`exp`, `ln`, `sqrt`, …) are for expressions whose
//! arguments are tame by construction.
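//!
//! As a hedged illustration of a hand-certified primitive, here is a
//! softplus stack for f(η) = ln(1 + e^η) that stays finite and accurate in
//! both tails. The names `sigmoid` and `softplus_stack` are hypothetical and
//! do not refer to existing helpers in this crate. The stack uses σ(η) and its
//! complement σ(−η) directly instead of forming 1 − σ, so f″ = σ(η)σ(−η)
//! stays positive where 1 − σ would round to zero:
//!
//! ```rust
//! // Logistic σ(η), evaluated without overflow in either tail.
//! fn sigmoid(eta: f64) -> f64 {
//!     if eta >= 0.0 {
//!         1.0 / (1.0 + (-eta).exp())
//!     } else {
//!         let e = eta.exp();
//!         e / (1.0 + e)
//!     }
//! }
//!
//! // Hypothetical [f, f′, f″, f‴, f⁗] stack for softplus f(η) = ln(1 + e^η).
//! // With s = σ(η) and c = σ(−η) = 1 − s:
//! //   f′ = s, f″ = s·c, f‴ = s·c·(c − s), f⁗ = s·c·(1 − 6·s·c).
//! fn softplus_stack(eta: f64) -> [f64; 5] {
//!     // Stable value: ln(1 + e^η) = max(η, 0) + ln1p(e^{−|η|}).
//!     let value = eta.max(0.0) + (-eta.abs()).exp().ln_1p();
//!     let s = sigmoid(eta);
//!     let c = sigmoid(-eta);
//!     let sc = s * c;
//!     [value, s, sc, sc * (c - s), sc * (1.0 - 6.0 * sc)]
//! }
//! ```
//!
//! At η = 800 the naive `(1.0 + eta.exp()).ln()` is `inf`, while the stable
//! value is exactly 800; at η = 40 the naive `1.0 - sigmoid(eta)` is `0.0`,
//! while `s·σ(−η)` keeps f″ positive.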
//!
//! # Storage convention
//!
//! Tensors are stored FULL, not symmetric-packed: `t4` for K=4 is 256
//! doubles where 35 would do. This is deliberate clarity-over-speed for the
//! oracle role — indexing is trivially auditable, contraction loops are
//! obvious, and the redundancy is itself a checked invariant (the algebra
//! only ever writes mathematically symmetric values; permuted entries can
//! still differ by rounding, because they are accumulated from the same
//! terms in different orders). Symmetric packing is a later,
//! profile-justified optimization behind the same API.
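//!
//! The cost of the full layout is easy to quantify. A full order-n tensor in
//! K variables has Kⁿ entries, while a symmetric-packed one needs one entry
//! per multiset of n axes, C(K + n − 1, n). The helpers below are
//! illustrative sketches, not part of this module:
//!
//! ```rust
//! // Binomial coefficient C(n, k); every intermediate value is C(n, i), so
//! // each integer division is exact.
//! fn binomial(n: usize, k: usize) -> usize {
//!     (0..k).fold(1, |acc, i| acc * (n - i) / (i + 1))
//! }
//!
//! // Entries of a full (unsymmetrized) order-`order` tensor in `k` variables.
//! fn full_len(k: usize, order: u32) -> usize {
//!     k.pow(order)
//! }
//!
//! // Entries of a symmetric-packed tensor: one per multiset of axes.
//! fn packed_len(k: usize, order: usize) -> usize {
//!     binomial(k + order - 1, order)
//! }
//!
//! // Total f64 channels (v, g, h, t3, t4) of a full fourth-order tower.
//! fn tower4_channels(k: usize) -> usize {
//!     (0..=4).map(|n| full_len(k, n)).sum()
//! }
//! ```
//!
//! At K = 4 this gives 256 versus 35 entries for `t4`, 10 distinct `(a, b)`
//! pairs for a symmetric 2-index output, and 341 `f64`s (2728 bytes) for the
//! whole tower.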
//!
//! # Deployment ladder (#932)
//!
//! 1. This module: the algebra + the program seam + the oracle.
//! 2. Universal oracle: every hand-written `RowKernel` gains a CI test
//!    asserting channel-by-channel agreement with a `RowNllProgram` written
//!    once — see [`verify_kernel_channels`]. This alone would have caught
//!    #736 at introduction.
//! 3. Migrate error-dense / cold towers to [`derived_row_kernel`] et al.;
//!    keep hand-tuned hot paths, now verified against the single-expression
//!    truth instead of being the only definition.
//! 4. New families (#914/#916/#917 ZI/ordinal/expectile, #921's
//!    location-scale port) implement ONLY `RowNllProgram` and get an exact
//!    fourth-order tower for the price of writing the likelihood.

use crate::jet_algebra;

/// Truncated fourth-order multivariate Taylor scalar in `K` variables.
///
/// See the module documentation for semantics and conventions. `Copy` is
/// intentional despite the size (341 `f64`s, about 2.7 KiB, at K=4): towers
/// are per-row temporaries that live on the stack for the duration of a row
/// program, and value semantics keep program code readable (`a * b + c`).
#[derive(Clone, Copy, Debug)]
pub struct Tower4<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
    pub t3: [[[f64; K]; K]; K],
    /// Fourth derivatives ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d (fully symmetric).
    pub t4: [[[[f64; K]; K]; K]; K],
}

impl<const K: usize> Tower4<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
            t3: [[[0.0; K]; K]; K],
            t4: [[[[0.0; K]; K]; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the (fully symmetric) derivative tensor entry whose differentiation
    /// axes are `labels` (length 0..=4): value, `g`, `h`, `t3`, `t4`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 4,
            "Tower4 carries at most fourth-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            2 => self.h[labels[0]][labels[1]],
            3 => self.t3[labels[0]][labels[1]][labels[2]],
            _ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
        }
    }

    /// Exact truncated Leibniz product `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`.
    ///
    /// # Codegen
    ///
    /// Each output entry's `2^m` subset sum is written as a compact straight-line
    /// expression instead of the shared [`jet_algebra::leibniz_product`] subset
    /// walker (which, per entry, builds `SlotBuf`s and `match`-dispatches the
    /// `deriv` closure across all `2^m` subsets). The loop nest over `(i,j,k,l)`
    /// is unchanged — only the inner per-entry sum is unrolled — so this does NOT
    /// unroll over `K` and does NOT bloat code: on a `Tower4<9>` mul-and-read
    /// consumer the new form is faster AND smaller (asm: 34 outlined walker `bl`
    /// calls → 0, 21.1 KiB → 14.3 KiB, +100 NEON `.2d` ops).
    ///
    /// BIT-IDENTICAL to the walker: each entry's terms are in the walker's exact
    /// subset-enumeration order (subset bit `b` ↔ position `b`, `sub = 0..2^m`),
    /// and the per-entry `acc` accumulator mirrors the walker's `total = 0.0`
    /// start so a signed-zero leading product collapses to `+0.0` identically —
    /// which matters because real jets carry exact-`0.0` channels
    /// (`constant`/`variable` towers). Proven `to_bits`-identical on
    /// `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`, 5000 inputs each with ~30 %
    /// exact-`0.0` channels and signed values (a no-leading-`0.0` form fails this
    /// stress — the accumulator start is load-bearing).
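    ///
    /// # Subset-walk sketch
    ///
    /// For reference, a minimal sketch of the per-entry subset walk that this
    /// method unrolls, assuming one derivative-lookup closure per operand. The
    /// names `leibniz_entry`, `d_xy`, and `d_x` are illustrative; this is not the
    /// signature of [`jet_algebra::leibniz_product`]. Subset bit `b` selects
    /// label position `b`, and `sub` runs over `0..2^m`, which is the term order
    /// the unrolled loops below follow:
    ///
    /// ```rust
    /// // D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b), subsets enumerated in bit order.
    /// fn leibniz_entry(
    ///     labels: &[usize],
    ///     d_a: &dyn Fn(&[usize]) -> f64,
    ///     d_b: &dyn Fn(&[usize]) -> f64,
    /// ) -> f64 {
    ///     let m = labels.len();
    ///     let mut total = 0.0;
    ///     for sub in 0..(1usize << m) {
    ///         let mut taken = Vec::with_capacity(m);
    ///         let mut rest = Vec::with_capacity(m);
    ///         for (bit, &axis) in labels.iter().enumerate() {
    ///             if sub & (1 << bit) != 0 {
    ///                 taken.push(axis);
    ///             } else {
    ///                 rest.push(axis);
    ///             }
    ///         }
    ///         total += d_a(&taken) * d_b(&rest);
    ///     }
    ///     total
    /// }
    ///
    /// // Derivatives of a(x, y) = x·y at (x, y); axis 0 is x, axis 1 is y.
    /// fn d_xy(labels: &[usize], x: f64, y: f64) -> f64 {
    ///     let mut s = labels.to_vec();
    ///     s.sort_unstable();
    ///     match s.as_slice() {
    ///         [] => x * y,
    ///         [0] => y,
    ///         [1] => x,
    ///         [0, 1] => 1.0,
    ///         _ => 0.0,
    ///     }
    /// }
    ///
    /// // Derivatives of b(x, y) = x.
    /// fn d_x(labels: &[usize], x: f64) -> f64 {
    ///     match labels {
    ///         [] => x,
    ///         [0] => 1.0,
    ///         _ => 0.0,
    ///     }
    /// }
    /// ```
    ///
    /// At (x, y) = (3, 5) the product a·b = x²y has value 45, ∂x∂y = 2x = 6,
    /// ∂x∂x = 2y = 10, and ∂x∂x∂y = 2.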
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            // subsets of {i}: {} {i}
            let mut acc = 0.0;
            acc += a.v * b.g[i];
            acc += a.g[i] * b.v;
            out.g[i] = acc;
        }
        // Hessian is symmetric under i↔j; compute the upper triangle and mirror
        // (see [`Tower2::mul`] — same term order, enforces exact symmetry).
        for i in 0..K {
            for j in i..K {
                // subsets of {i,j}: {} {i} {j} {ij}
                let mut acc = 0.0;
                acc += a.v * b.h[i][j];
                acc += a.g[i] * b.g[j];
                acc += a.g[j] * b.g[i];
                acc += a.h[i][j] * b.v;
                out.h[i][j] = acc;
                out.h[j][i] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                    let mut acc = 0.0;
                    acc += a.v * b.t3[i][j][k];
                    acc += a.g[i] * b.h[j][k];
                    acc += a.g[j] * b.h[i][k];
                    acc += a.h[i][j] * b.g[k];
                    acc += a.g[k] * b.h[i][j];
                    acc += a.h[i][k] * b.g[j];
                    acc += a.h[j][k] * b.g[i];
                    acc += a.t3[i][j][k] * b.v;
                    out.t3[i][j][k] = acc;
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // subsets of {i,j,k,l} in bit order sub = 0..16
                        let mut acc = 0.0;
                        acc += a.v * b.t4[i][j][k][l];
                        acc += a.g[i] * b.t3[j][k][l];
                        acc += a.g[j] * b.t3[i][k][l];
                        acc += a.h[i][j] * b.h[k][l];
                        acc += a.g[k] * b.t3[i][j][l];
                        acc += a.h[i][k] * b.h[j][l];
                        acc += a.h[j][k] * b.h[i][l];
                        acc += a.t3[i][j][k] * b.g[l];
                        acc += a.g[l] * b.t3[i][j][k];
                        acc += a.h[i][l] * b.h[j][k];
                        acc += a.h[j][l] * b.h[i][k];
                        acc += a.t3[i][j][l] * b.g[k];
                        acc += a.h[k][l] * b.h[i][j];
                        acc += a.t3[i][k][l] * b.g[j];
                        acc += a.t3[j][k][l] * b.g[i];
                        acc += a.t4[i][j][k][l] * b.v;
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }

    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
    /// out of the borrowed operands.
    pub fn add(&self, o: &Self) -> Self {
        *self + *o
    }

    /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
    pub fn sub(&self, o: &Self) -> Self {
        *self + o.scale(-1.0)
    }

    /// Exact multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` evaluated at `u = self.v` —
    /// the SAME `[f64; 5]` stack shape the families' existing
    /// `unary_derivatives_*` helpers produce, so those special-function
    /// stacks (Φ, log-Φ, normal pdf, …) plug in directly.
    ///
    /// The order-m output sums over the set partitions of the m indices
    /// (Bell(3) = 5 terms at order 3, Bell(4) = 15 at order 4), grouped by
    /// block count: each partition into r blocks contributes
    /// `f⁽ʳ⁾ · Π_blocks D_block(u)`.
    ///
    /// # Codegen
    ///
    /// Evaluated as a compact closed form (the Bell(4)=15 set-partitions of
    /// `t4`, Bell(3)=5 of `t3`, …) instead of routing through the recursive
    /// [`jet_algebra::faa_di_bruno`] walker (per-output `for_each_partition`
    /// recursion + per-block `SlotBuf` + closure dispatch). The loop nest is
    /// identical to the walker's (`for i,j,k,l`); only the per-entry partition
    /// sum is straight-line, so this does NOT unroll over `K` and does NOT
    /// bloat code — measured on a `Tower4<9>` compose-and-read consumer the new
    /// form is both faster and SMALLER (asm: 94 outlined walker `bl` calls → 0,
    /// 47.5 KiB → 16.7 KiB, +197 NEON `.2d` ops).
    ///
    /// BIT-IDENTICAL to the walker: each channel's terms are emitted in the
    /// walker's exact partition-enumeration order, each term's block products
    /// are left-associated exactly as the walker's `prod *= block`, and the
    /// per-channel `acc` accumulator mirrors the walker's `total = 0.0` start
    /// (so signed-zero products collapse to `+0.0` identically). The order-4
    /// term sequence was generated from the walker's own enumeration. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`,
    /// 5000 random inputs each (zeroed / sign-varied stacks included).
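    ///
    /// # Partition counts
    ///
    /// The block-count grouping can be checked independently. Counting the set
    /// partitions of m labels by number of blocks r gives the Stirling numbers
    /// S(m, r). At m = 4 they are 1, 7, 6, 1: exactly how many terms multiply
    /// `d[1]`, `d[2]`, `d[3]`, and `d[4]` in the order-4 loop (15 in total); at
    /// m = 3 they are 1, 3, 1 (5 in total). The counter is an illustrative
    /// sketch with hypothetical names, not the walker's `for_each_partition`:
    ///
    /// ```rust
    /// // Count set partitions of `remaining` further labels, given `blocks`
    /// // blocks opened so far; `counts[r]` accumulates partitions with r blocks.
    /// fn count_partitions(remaining: usize, blocks: usize, counts: &mut [usize]) {
    ///     if remaining == 0 {
    ///         counts[blocks] += 1;
    ///         return;
    ///     }
    ///     // The next label joins one of the existing blocks...
    ///     for _ in 0..blocks {
    ///         count_partitions(remaining - 1, blocks, counts);
    ///     }
    ///     // ...or opens a new block.
    ///     count_partitions(remaining - 1, blocks + 1, counts);
    /// }
    ///
    /// // Stirling numbers of the second kind S(m, r) for r = 0..=m.
    /// fn stirling2_row(m: usize) -> Vec<usize> {
    ///     let mut counts = vec![0; m + 1];
    ///     count_partitions(m, 0, &mut counts);
    ///     counts
    /// }
    /// ```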
    pub fn compose_unary(&self, d: [f64; 5]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
                    let mut acc = 0.0;
                    acc += d[1] * self.t3[i][j][k];
                    acc += d[2] * self.h[i][j] * self.g[k];
                    acc += d[2] * self.h[i][k] * self.g[j];
                    acc += d[2] * self.g[i] * self.h[j][k];
                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
                    out.t3[i][j][k] = acc;
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // Bell(4)=15 partitions, walker enumeration order.
                        let mut acc = 0.0;
                        acc += d[1] * self.t4[i][j][k][l];
                        acc += d[2] * self.t3[i][j][k] * self.g[l];
                        acc += d[2] * self.t3[i][j][l] * self.g[k];
                        acc += d[2] * self.h[i][j] * self.h[k][l];
                        acc += d[3] * self.h[i][j] * self.g[k] * self.g[l];
                        acc += d[2] * self.t3[i][k][l] * self.g[j];
                        acc += d[2] * self.h[i][k] * self.h[j][l];
                        acc += d[3] * self.h[i][k] * self.g[j] * self.g[l];
                        acc += d[2] * self.h[i][l] * self.h[j][k];
                        acc += d[2] * self.g[i] * self.t3[j][k][l];
                        acc += d[3] * self.g[i] * self.h[j][k] * self.g[l];
                        acc += d[3] * self.h[i][l] * self.g[j] * self.g[k];
                        acc += d[3] * self.g[i] * self.h[j][l] * self.g[k];
                        acc += d[3] * self.g[i] * self.g[j] * self.h[k][l];
                        acc += d[4] * self.g[i] * self.g[j] * self.g[k] * self.g[l];
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }

    /// Compose with a unary special-function whose `[f64; 5]` derivative STACK is
    /// built from the base value through `stack_fn` — the scalar arm of the
    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
    /// [`Tower4Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
    /// `self.compose_unary(stack_fn(self.v))`. Writing a program against this seam
    /// lets it re-instantiate, unchanged, at [`Tower4Lane`] (where each of the four
    /// lanes carries a DISTINCT base value and `stack_fn` is re-run per lane).
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
        self.compose_unary(stack_fn(self.v))
    }

    /// Single-active-slot fast path for [`Self::compose_unary`].
    ///
    /// When the inner jet `self` has derivative support ONLY on the all-`slot`
    /// diagonal channels — i.e. it is a univariate jet in primary `slot`
    /// scattered into the `K`-wide layout (`g[a] = 0`, `h[a][b] = 0`,
    /// `t3 = 0`, `t4 = 0` for any axis `≠ slot`) — the multivariate Faà di
    /// Bruno walk collapses. Every output channel whose axis tuple contains an
    /// axis `≠ slot` is structurally `0`: each set-partition has a block
    /// covering that axis, that block reads an off-`slot` derivative of `self`
    /// (which is `0`), so the block product and the whole partition vanish, and
    /// the channel sums to the walker's `total = 0.0` start, i.e. `+0.0`. Only
    /// the five diagonal channels (`v`, `g[slot]`, `h[slot][slot]`,
    /// `t3[slot]³`, `t4[slot]⁴`) survive.
    ///
    /// This computes exactly those five as STRAIGHT-LINE accumulations, each in
    /// the EXACT term order of [`Self::compose_unary`]'s diagonal
    /// (`i = j = k = l = slot`) case — so they are BIT-IDENTICAL to
    /// [`Self::compose_unary`] on the diagonal — and leaves every other channel
    /// at the zero-init `+0.0`, which the full walk also produces (the
    /// off-`slot` collapse is `to_bits`-`+0.0`, signed-zero products included;
    /// proven across `K ∈ {2,3,4,9}`, 5000 single-slot inputs each). At any
    /// `K ≥ 2` this is far fewer floating-point operations than materialising
    /// the full `1 + K + K² + K³ + K⁴` channel set whose off-diagonal entries
    /// are all zero, and far cheaper than the recursive set-partition walker the
    /// diagonal channels previously routed through (a measured ~9.5× speedup vs
    /// the full `compose_unary`, recovering a 5.9× walker regression at the
    /// `K ∈ {2,3}` BMS tower widths).
    ///
    /// `#[inline]` so an adopting consumer pays no `bl` call (uninlined, the
    /// five-channel build does not amortise the call/spill overhead).
    ///
    /// # Precondition
    ///
    /// The caller guarantees the single-active-slot structure. If it does not
    /// hold, the off-`slot` channels would be wrongly zeroed; use the full
    /// [`Self::compose_unary`] in that case.
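    ///
    /// # Univariate form
    ///
    /// On the diagonal the repeated terms collapse into the familiar univariate
    /// Faà di Bruno formula
    /// `(f∘u)⁗ = f′u⁗ + 4f″u‴u′ + 3f″u″² + 6f‴u″u′² + f⁗u′⁴`, matching the
    /// 4 / 3 / 6 multiplicities in the `t4` accumulation below. The sketch
    /// (a hypothetical standalone function) uses the collapsed coefficients,
    /// so it agrees with this method mathematically but not bit-for-bit; it
    /// can be checked against `d⁴/dx⁴ e^{x²} = e^{x²}(16x⁴ + 48x² + 12)`:
    ///
    /// ```rust
    /// // Univariate Faà di Bruno through fourth order. `d` is the outer stack
    /// // [f, f′, f″, f‴, f⁗] at u(x); `u` is the inner stack [u, u′, u″, u‴, u⁗].
    /// fn compose_univariate(d: [f64; 5], u: [f64; 5]) -> [f64; 5] {
    ///     let (g, h, t3, t4) = (u[1], u[2], u[3], u[4]);
    ///     [
    ///         d[0],
    ///         d[1] * g,
    ///         d[1] * h + d[2] * g * g,
    ///         d[1] * t3 + 3.0 * d[2] * h * g + d[3] * g * g * g,
    ///         d[1] * t4
    ///             + 4.0 * d[2] * t3 * g
    ///             + 3.0 * d[2] * h * h
    ///             + 6.0 * d[3] * h * g * g
    ///             + d[4] * g * g * g * g,
    ///     ]
    /// }
    /// ```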
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [f64; 5], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        let t4 = self.t4[s][s][s][s];
        out.v = d[0];
        // g (i=s): d1*g
        out.g[s] = {
            let mut acc = 0.0;
            acc += d[1] * g;
            acc
        };
        // h (i=j=s): d1*h + d2*g*g
        out.h[s][s] = {
            let mut acc = 0.0;
            acc += d[1] * h;
            acc += d[2] * g * g;
            acc
        };
        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
        out.t3[s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t3;
            acc += d[2] * h * g;
            acc += d[2] * h * g;
            acc += d[2] * g * h;
            acc += d[3] * g * g * g;
            acc
        };
        // t4 (i=j=k=l=s): exact term order of compose_unary's inner loop.
        out.t4[s][s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t4;
            acc += d[2] * t3 * g;
            acc += d[2] * t3 * g;
            acc += d[2] * h * h;
            acc += d[3] * h * g * g;
            acc += d[2] * t3 * g;
            acc += d[2] * h * h;
            acc += d[3] * h * g * g;
            acc += d[2] * h * h;
            acc += d[2] * g * t3;
            acc += d[3] * g * h * g;
            acc += d[3] * h * g * g;
            acc += d[3] * g * h * g;
            acc += d[3] * g * g * h;
            acc += d[4] * g * g * g * g;
            acc
        };
        out
    }

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
                for k in 0..K {
                    out.t3[i][j][k] *= s;
                    for l in 0..K {
                        out.t4[i][j][k][l] *= s;
                    }
                }
            }
        }
        out
    }

    /// e^self.
    pub fn exp(&self) -> Self {
        let e = self.v.exp();
        self.compose_unary([e, e, e, e, e])
    }

    /// ln(self). Caller guarantees positivity (likelihood programs do).
    pub fn ln(&self) -> Self {
        let u = self.v;
        let r = 1.0 / u;
        self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
    }

    /// 1/self.
    pub fn recip(&self) -> Self {
        let r = 1.0 / self.v;
        let r2 = r * r;
        self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
    }

    /// √self. Caller guarantees positivity.
    pub fn sqrt(&self) -> Self {
        let u = self.v;
        let s = u.sqrt();
        self.compose_unary([
            s,
            0.5 / s,
            -0.25 / (u * s),
            0.375 / (u * u * s),
            -0.9375 / (u * u * u * s),
        ])
    }

    /// self^a for real exponent `a`. Caller guarantees a positive base.
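    ///
    /// The power-rule stack generalizes two of the fixed stacks in this impl:
    /// at `a = −1` it reproduces the `recip` stack and at `a = 1/2` the `sqrt`
    /// stack. A standalone consistency check, with the three stack formulas
    /// copied into hypothetical free functions:
    ///
    /// ```rust
    /// // [f, f′, f″, f‴, f⁗] of u^a (the same formulas as `powf`).
    /// fn pow_stack(u: f64, a: f64) -> [f64; 5] {
    ///     [
    ///         u.powf(a),
    ///         a * u.powf(a - 1.0),
    ///         a * (a - 1.0) * u.powf(a - 2.0),
    ///         a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
    ///         a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
    ///     ]
    /// }
    ///
    /// // Stack of 1/u (the same formulas as `recip`).
    /// fn recip_stack(u: f64) -> [f64; 5] {
    ///     let r = 1.0 / u;
    ///     let r2 = r * r;
    ///     [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
    /// }
    ///
    /// // Stack of √u (the same formulas as `sqrt`).
    /// fn sqrt_stack(u: f64) -> [f64; 5] {
    ///     let s = u.sqrt();
    ///     [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
    /// }
    ///
    /// // Entry-by-entry relative agreement of two stacks.
    /// fn stacks_close(x: [f64; 5], y: [f64; 5], rel: f64) -> bool {
    ///     x.iter()
    ///         .zip(y.iter())
    ///         .all(|(p, q)| (p - q).abs() <= rel * q.abs().max(1e-300))
    /// }
    /// ```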
    pub fn powf(&self, a: f64) -> Self {
        let u = self.v;
        let f0 = u.powf(a);
        let f1 = a * u.powf(a - 1.0);
        let f2 = a * (a - 1.0) * u.powf(a - 2.0);
        let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
        let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
        self.compose_unary([f0, f1, f2, f3, f4])
    }

    /// ln Γ(self). Caller guarantees positivity.
    pub fn ln_gamma(&self) -> Self {
        self.compose_unary(ln_gamma_derivative_stack(self.v))
    }

    /// ψ(self), the digamma function. Caller guarantees positivity.
    pub fn digamma(&self) -> Self {
        self.compose_unary(digamma_derivative_stack(self.v))
    }

    /// ψ′(self), the trigamma function. Caller guarantees positivity.
    pub fn trigamma(&self) -> Self {
        self.compose_unary(trigamma_derivative_stack(self.v))
    }

    /// Contract `t3` with one primary-space direction:
    /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]` — exactly the
    /// `row_third_contracted` shape.
    ///
    /// The output is symmetric in `(a, b)`: `t3` is fully index-symmetric, so
    /// `t3[a][b][c] == t3[b][a][c]` and the `Σ_c` contraction gives
    /// `out[a][b] == out[b][a]` term-for-term, in the same `c` order. We compute
    /// only the upper triangle `a ≤ b` (the inner contraction is unchanged and
    /// stays contiguous/vectorisable) and mirror into the lower triangle — this
    /// is BIT-IDENTICAL to the full `a, b ∈ 0..K` nest while doing ~2× fewer
    /// inner contractions, with no dense scatter (the mirror is a `K × K` copy).
    ///
    /// Bit-identity assumes the stored `t3` is bitwise symmetric in its first
    /// two axes. The algebra guarantees that symmetry mathematically, but `mul`
    /// and `compose_unary` form permuted entries from the same terms in
    /// different orders, so `t3[a][b][c]` and `t3[b][a][c]` can differ by
    /// rounding; the mirror then reports the upper-triangle value in both slots.
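    ///
    /// The triangle-plus-mirror trick can be checked in isolation. The sketch
    /// below uses hypothetical free functions over plain arrays, not this
    /// method. It builds a `t3` that is bitwise symmetric by evaluating each
    /// entry from its sorted index, then compares the mirrored contraction with
    /// the full `a, b ∈ 0..K` nest via `to_bits`:
    ///
    /// ```rust
    /// // Fully symmetric test tensor: each entry depends only on its sorted
    /// // index, so permuted entries are bitwise equal.
    /// fn symmetric_t3<const K: usize>() -> [[[f64; K]; K]; K] {
    ///     let mut t = [[[0.0; K]; K]; K];
    ///     for a in 0..K {
    ///         for b in 0..K {
    ///             for c in 0..K {
    ///                 let mut idx = [a, b, c];
    ///                 idx.sort_unstable();
    ///                 let key = (idx[0] * K + idx[1]) * K + idx[2];
    ///                 t[a][b][c] = (key as f64 + 0.5).sin();
    ///             }
    ///         }
    ///     }
    ///     t
    /// }
    ///
    /// // Reference: contract every (a, b) independently.
    /// fn contract_full<const K: usize>(t3: &[[[f64; K]; K]; K], dir: &[f64; K]) -> [[f64; K]; K] {
    ///     let mut out = [[0.0; K]; K];
    ///     for a in 0..K {
    ///         for b in 0..K {
    ///             let mut acc = 0.0;
    ///             for c in 0..K {
    ///                 acc += t3[a][b][c] * dir[c];
    ///             }
    ///             out[a][b] = acc;
    ///         }
    ///     }
    ///     out
    /// }
    ///
    /// // Upper triangle a ≤ b, mirrored into the lower triangle.
    /// fn contract_mirrored<const K: usize>(t3: &[[[f64; K]; K]; K], dir: &[f64; K]) -> [[f64; K]; K] {
    ///     let mut out = [[0.0; K]; K];
    ///     for a in 0..K {
    ///         for b in a..K {
    ///             let mut acc = 0.0;
    ///             for c in 0..K {
    ///                 acc += t3[a][b][c] * dir[c];
    ///             }
    ///             out[a][b] = acc;
    ///             out[b][a] = acc;
    ///         }
    ///     }
    ///     out
    /// }
    /// ```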
    pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in a..K {
                let mut acc = 0.0;
                for c in 0..K {
                    acc += self.t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
                out[b][a] = acc;
            }
        }
        out
    }

    /// Contract `t4` with two primary-space directions:
    /// `out[i][j] = Σ_{k,l} t4[i][j][k][l] · u[k] · w[l]` — exactly the
    /// `row_fourth_contracted` shape.
    ///
    /// As in [`Self::third_contracted`], the output is symmetric in `(i, j)`
    /// (`t4[j][i][k][l] == t4[i][j][k][l]`, contracted in the same `(k, l)`
    /// order), so the upper triangle `i ≤ j` is computed and mirrored. This is
    /// BIT-IDENTICAL to the full nest under the same bitwise-symmetry
    /// assumption, does ~2× fewer inner `Σ_{k,l}` contractions, and keeps the
    /// inner double loop in its original contiguous/vectorisable form.
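    ///
    /// Read through the Hessian, this contraction is the mixed second
    /// directional derivative of `h` along `u` and `w`. For ℓ(p) = p₀²p₁²
    /// (K = 2), whose only nonzero fourth derivatives are the six permutations of
    /// ∂₀∂₀∂₁∂₁ = 4, the Hessian is `[[2p₁², 4p₀p₁], [4p₀p₁, 2p₀²]]`, so the
    /// result must be `[[4u₁w₁, 4(u₀w₁ + u₁w₀)], [4(u₀w₁ + u₁w₀), 4u₀w₀]]`. An
    /// illustrative check with hypothetical free functions:
    ///
    /// ```rust
    /// // t4 of ℓ(p) = p0²·p1²: permutations of axes (0, 0, 1, 1) hold 4.
    /// fn t4_of_p0sq_p1sq() -> [[[[f64; 2]; 2]; 2]; 2] {
    ///     let mut t4 = [[[[0.0; 2]; 2]; 2]; 2];
    ///     for i in 0..2 {
    ///         for j in 0..2 {
    ///             for k in 0..2 {
    ///                 for l in 0..2 {
    ///                     // Exactly two axes equal to 1 means the multiset {0, 0, 1, 1}.
    ///                     if i + j + k + l == 2 {
    ///                         t4[i][j][k][l] = 4.0;
    ///                     }
    ///                 }
    ///             }
    ///         }
    ///     }
    ///     t4
    /// }
    ///
    /// // out[i][j] = Σ_{k,l} t4[i][j][k][l] · u[k] · w[l], upper triangle mirrored.
    /// fn contract_t4<const K: usize>(
    ///     t4: &[[[[f64; K]; K]; K]; K],
    ///     u: &[f64; K],
    ///     w: &[f64; K],
    /// ) -> [[f64; K]; K] {
    ///     let mut out = [[0.0; K]; K];
    ///     for i in 0..K {
    ///         for j in i..K {
    ///             let mut acc = 0.0;
    ///             for k in 0..K {
    ///                 for l in 0..K {
    ///                     acc += t4[i][j][k][l] * u[k] * w[l];
    ///                 }
    ///             }
    ///             out[i][j] = acc;
    ///             out[j][i] = acc;
    ///         }
    ///     }
    ///     out
    /// }
    /// ```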
    pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for i in 0..K {
            for j in i..K {
                let mut acc = 0.0;
                for k in 0..K {
                    for l in 0..K {
                        acc += self.t4[i][j][k][l] * u[k] * w[l];
                    }
                }
                out[i][j] = acc;
                out[j][i] = acc;
            }
        }
        out
    }
}

impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut out = Self::zero();
        out.v = f(&[]);
        for i in 0..K {
            let labels = [i];
            out.g[i] = f(&labels);
        }
        for i in 0..K {
            for j in 0..K {
                let labels = [i, j];
                out.h[i][j] = f(&labels);
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let labels = [i, j, k];
                    out.t3[i][j][k] = f(&labels);
                }
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        let labels = [i, j, k, l];
                        out.t4[i][j][k][l] = f(&labels);
                    }
                }
            }
        }
        out
    }
}

/// Truncated SECOND-order multivariate Taylor scalar in `K` variables.
///
/// This is the value/gradient/Hessian-only sibling of [`Tower4`]. Every
/// channel it carries (`v`, `g`, `h`) is computed by the SAME formulas
/// [`Tower4`] uses for those orders, so for any program written over both
/// towers the order-≤2 outputs are *bit-identical*: the order-2 Leibniz and
/// Faà-di-Bruno terms read only the order-≤2 channels of their inputs (see
/// [`Tower4::mul`] / [`Tower4::compose_unary`] — `out.h` never touches `t3`
/// or `t4`), so dropping the third/fourth tensors cannot perturb the value,
/// gradient, or Hessian. The one exception is the sign of an exact zero:
/// [`Tower2::mul`] sums its terms directly instead of starting from a `+0.0`
/// accumulator, so a `-0.0` entry that [`Tower4::mul`] normalises to `+0.0`
/// can survive here (the two values still compare equal).
///
/// It exists purely for performance: an inner Newton step (and the
/// value-only ρ-homotopy pre-warm) needs at most curvature, never the
/// outer-κ/ψ third/fourth derivatives. Evaluating a row likelihood over
/// `Tower2` skips the `K⁴` fourth-tensor product/composition arithmetic that
/// dominates the cold marginal-slope fit, while returning the exact same
/// `(v, g, h)`.
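///
/// As a hedged, self-contained illustration of what the order-≤2 channels
/// carry, the sketch below re-implements this type's `mul` and
/// `compose_unary` formulas for K = 2 on a hypothetical `J2` struct, and
/// differentiates ℓ(x, y) = x·e^y at (2, 0), where ∇ℓ = (e^y, x e^y) = (1, 2)
/// and ∇²ℓ = [[0, e^y], [e^y, x e^y]] = [[0, 1], [1, 2]]:
///
/// ```rust
/// // Hypothetical order-2 jet in two variables: value, gradient, Hessian.
/// #[derive(Clone, Copy, Debug, PartialEq)]
/// struct J2 {
///     v: f64,
///     g: [f64; 2],
///     h: [[f64; 2]; 2],
/// }
///
/// impl J2 {
///     // Seeded variable: unit first derivative in slot `idx`.
///     fn variable(value: f64, idx: usize) -> Self {
///         let mut g = [0.0; 2];
///         g[idx] = 1.0;
///         J2 { v: value, g, h: [[0.0; 2]; 2] }
///     }
///
///     // Order-≤2 Leibniz product.
///     fn mul(&self, o: &Self) -> Self {
///         let mut out = J2 { v: self.v * o.v, g: [0.0; 2], h: [[0.0; 2]; 2] };
///         for i in 0..2 {
///             out.g[i] = self.v * o.g[i] + self.g[i] * o.v;
///             for j in 0..2 {
///                 out.h[i][j] = self.v * o.h[i][j]
///                     + self.g[i] * o.g[j]
///                     + self.g[j] * o.g[i]
///                     + self.h[i][j] * o.v;
///             }
///         }
///         out
///     }
///
///     // Order-≤2 Faà di Bruno with stack d = [f, f′, f″] at self.v.
///     fn compose_unary(&self, d: [f64; 3]) -> Self {
///         let mut out = J2 { v: d[0], g: [0.0; 2], h: [[0.0; 2]; 2] };
///         for i in 0..2 {
///             out.g[i] = d[1] * self.g[i];
///             for j in 0..2 {
///                 out.h[i][j] = d[1] * self.h[i][j] + d[2] * self.g[i] * self.g[j];
///             }
///         }
///         out
///     }
///
///     fn exp(&self) -> Self {
///         let e = self.v.exp();
///         self.compose_unary([e, e, e])
///     }
/// }
/// ```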
#[derive(Clone, Copy, Debug)]
pub struct Tower2<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
}

impl<const K: usize> Tower2<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the derivative tensor entry whose differentiation axes are
    /// `labels` (length 0..=2): value, `g`, `h`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 2,
            "Tower2 carries at most second-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            _ => self.h[labels[0]][labels[1]],
        }
    }

    /// Exact truncated (order ≤ 2) Leibniz product. The `v`/`g`/`h` upper
    /// triangle matches [`Tower4::mul`] term-for-term.
    ///
    /// # Symmetry fast path
    ///
    /// The order-≤2 Leibniz Hessian
    /// `h[i][j] = a.v·b.h[i][j] + a.g[i]·b.g[j] + a.g[j]·b.g[i] + a.h[i][j]·b.v`
    /// is symmetric under `i ↔ j` whenever the operand Hessians are — which they
    /// always are: `constant`/`variable` seed a symmetric (zero) `h`, and
    /// `mul`/`compose_unary`/`add`/`scale` each preserve symmetry, so the
    /// invariant holds for every tower a row program can build. We therefore
    /// compute only the upper triangle `j ≥ i` and mirror it into the lower
    /// triangle. At the `K = 9` survival width that is `K(K+1)/2 = 45`
    /// four-product entry evaluations instead of `K² = 81`, and the win is
    /// larger in wall-clock because the 648-byte `h` (81 `f64`s) spills at
    /// `K = 9`: halving the expensive stores/reloads roughly halves the kernel
    /// (measured ≈2× on a `Tower2<9>` mul-and-read throughput microbench; the
    /// dominant `mul` under every packed scalar bottoms out here).
    ///
    /// The upper-triangle entries are BIT-IDENTICAL to the old rectangular form
    /// (same term/accumulation order). The lower triangle now equals its mirror
    /// exactly, where the rectangular form rounded `h[i][j]` and `h[j][i]`
    /// independently (the two cross products accumulate in opposite order) and
    /// left a ≤1-ulp asymmetry; mirroring removes it, so the result is exactly
    /// symmetric — strictly closer to the true symmetric Hessian, not merely a
    /// reordering. Dense-`h` consumers are all tolerance-gated (rel-tol ≥ 1e-11 ≫
    /// 1e-16); the `f64`/`f64x4` lane oracle stays exact because
    /// [`crate::jet_scalar::Order2Lane::mul`] mirrors term-for-term.
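    ///
    /// The arithmetic behind the fast path, as a quick standalone check with
    /// hypothetical helper names: at K = 9 the upper triangle holds 45 of the 81
    /// entries, the `h` block occupies 648 bytes, and computing the upper
    /// triangle then mirroring makes the result exactly symmetric by
    /// construction:
    ///
    /// ```rust
    /// // Entries evaluated by the upper-triangle loop (versus K² for the rectangle).
    /// fn upper_triangle_len(k: usize) -> usize {
    ///     k * (k + 1) / 2
    /// }
    ///
    /// // Order-≤2 Leibniz Hessian: upper triangle computed, lower triangle mirrored.
    /// fn leibniz_hessian<const K: usize>(
    ///     av: f64,
    ///     ag: &[f64; K],
    ///     ah: &[[f64; K]; K],
    ///     bv: f64,
    ///     bg: &[f64; K],
    ///     bh: &[[f64; K]; K],
    /// ) -> [[f64; K]; K] {
    ///     let mut h = [[0.0; K]; K];
    ///     for i in 0..K {
    ///         for j in i..K {
    ///             let hij = av * bh[i][j] + ag[i] * bg[j] + ag[j] * bg[i] + ah[i][j] * bv;
    ///             h[i][j] = hij;
    ///             h[j][i] = hij;
    ///         }
    ///     }
    ///     h
    /// }
    /// ```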
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
        }
        for i in 0..K {
            for j in i..K {
                let hij = a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
                out.h[i][j] = hij;
                out.h[j][i] = hij;
            }
        }
        out
    }

    /// Exact (order ≤ 2) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u)]` evaluated at `u = self.v`. The `v`/`g`/`h`
    /// channels match [`Tower4::compose_unary`] term-for-term (which uses only
    /// `d[0..=2]` for those orders), so this is a strict truncation, not an
    /// approximation. The full-order `[f64; 5]` derivative stacks the families
    /// already produce can be passed by slicing their first three entries.
    ///
    /// # Codegen
    ///
    /// Order-≤2 Faà di Bruno is a tiny closed form, so this evaluates it
    /// directly instead of routing through the generic
    /// [`jet_algebra::faa_di_bruno`] set-partition walker (recursion + per-block
    /// closure dispatch). That matters because this is the kernel under EVERY
    /// packed scalar — [`crate::jet_scalar::Order2`] / `OneSeed` / `TwoSeed`
    /// composition all bottom out here — so the straight-line form (whose inner
    /// loops auto-vectorise to NEON/SSE 2-wide and which emits zero outlined
    /// walker calls) lifts all of them at once.
    ///
    /// The term and accumulation order is BIT-IDENTICAL to the walker it
    /// replaces: each output channel mirrors the walker's `total = 0.0` start
    /// (the explicit `acc` accumulator), so a signed-zero product collapses to
    /// `+0.0` exactly as `total += prod` does. Proven `to_bits`-identical on
    /// `v`/`g`/`h` across `K ∈ {2,3,4,9}`, 5000 random inputs each (incl.
    /// zeroed / sign-varied stacks). The order-≤2 walker partitions are:
    ///
    /// - `g[i]` = `f′·u_i` (single block `{i}`);
    /// - `h[i][j]` = `f′·u_ij + (f″·u_i)·u_j` (blocks `{ij}` then `{i}{j}`),
    ///
    /// with `f′ = d[1]`, `f″ = d[2]`, `u_* = self.{g,h}`.
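    ///
    /// # Truncating a full stack
    ///
    /// Passing a full-order stack only needs its first three entries. A minimal
    /// sketch with hypothetical helper names: truncate the `exp` stack and apply
    /// the order-≤2 univariate formulas `g = f′·u′` and `h = f′·u″ + f″·u′²` to
    /// u(x) = x² at x = 0.5, where `d²/dx² e^{x²} = e^{x²}(4x² + 2)`:
    ///
    /// ```rust
    /// // Keep [f, f′, f″] of a full [f, f′, f″, f‴, f⁗] stack.
    /// fn truncate_stack(d: [f64; 5]) -> [f64; 3] {
    ///     [d[0], d[1], d[2]]
    /// }
    ///
    /// // Order-≤2 composition for one variable with inner jet (u, u′, u″).
    /// fn compose2_univariate(d: [f64; 3], u: [f64; 3]) -> [f64; 3] {
    ///     [d[0], d[1] * u[1], d[1] * u[2] + d[2] * u[1] * u[1]]
    /// }
    /// ```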
788 pub fn compose_unary(&self, d: [f64; 3]) -> Self {
789 let mut out = Self::zero();
790 out.v = d[0];
791 for i in 0..K {
792 let mut acc = 0.0;
793 acc += d[1] * self.g[i];
794 out.g[i] = acc;
795 }
796 for i in 0..K {
797 for j in 0..K {
798 let mut acc = 0.0;
799 acc += d[1] * self.h[i][j];
800 acc += d[2] * self.g[i] * self.g[j];
801 out.h[i][j] = acc;
802 }
803 }
804 out
805 }
806
807 /// Multiply every channel by a plain scalar.
808 pub fn scale(&self, s: f64) -> Self {
809 let mut out = *self;
810 out.v *= s;
811 for i in 0..K {
812 out.g[i] *= s;
813 for j in 0..K {
814 out.h[i][j] *= s;
815 }
816 }
817 out
818 }
819
820 /// e^self.
821 pub fn exp(&self) -> Self {
822 let e = self.v.exp();
823 self.compose_unary([e, e, e])
824 }
825
826 /// √self. Caller guarantees positivity.
827 pub fn sqrt(&self) -> Self {
828 let u = self.v;
829 let s = u.sqrt();
830 self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
831 }
832}

impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut out = Self::zero();
        out.v = f(&[]);
        for i in 0..K {
            let labels = [i];
            out.g[i] = f(&labels);
        }
        for i in 0..K {
            for j in 0..K {
                let labels = [i, j];
                out.h[i][j] = f(&labels);
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Add for Tower2<K> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        let mut out = self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Mul for Tower2<K> {
    type Output = Self;
    fn mul(self, o: Self) -> Self {
        Tower2::mul(&self, &o)
    }
}

impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
    type Output = Self;
    fn add(self, c: f64) -> Self {
        let mut out = self;
        out.v += c;
        out
    }
}

impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
    type Output = Self;
    fn mul(self, c: f64) -> Self {
        self.scale(c)
    }
}

/// Truncated THIRD-order multivariate Taylor scalar in `K` variables.
///
/// The value/gradient/Hessian/third-derivative sibling of [`Tower4`], standing
/// between [`Tower2`] and [`Tower4`]. Every channel it carries (`v`, `g`, `h`,
/// `t3`) is computed by the SAME shared Leibniz / Faà-di-Bruno kernels
/// [`Tower4`] uses for those orders, and the order-≤3 terms of those kernels
/// read only the order-≤3 channels of their inputs (the order-3 Faà-di-Bruno
/// partitions never reach the f⁗ stack slot or the inner `t4` tensor — see
/// [`Tower4::compose_unary`]). So for any program written over both towers the
/// order-≤3 outputs are *bit-identical*: dropping the fourth tensor cannot
/// perturb the value, gradient, Hessian, or third derivatives.
///
/// It exists purely for performance, exactly like [`Tower2`]: a consumer that
/// needs up to third derivatives (the survival location-scale row kernel reads
/// `g`, the diagonal `h`, and the diagonal `t3`, but never `t4`) pays the
/// `K³` third-tensor arithmetic but skips the `K⁴` fourth-tensor
/// product/composition that otherwise dominates the per-row cost.
#[derive(Clone, Copy, Debug)]
pub struct Tower3<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
    pub t3: [[[f64; K]; K]; K],
}
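// The truncation argument above can be checked in miniature. This sketch
// assumes univariate jets stored as raw derivatives `[f, f′, f″, f‴, f⁗]`;
// `leibniz_mul` and `truncate` are hypothetical helpers, not crate API. Order
// `n` of the Leibniz sum reads only orders ≤ n of its factors, so the
// order-≤3 entries of a product carry the same bits whether the inputs hold a
// fourth-order entry, drop it, or perturb it.

```rust
/// Leibniz product of univariate jets holding raw derivatives:
/// (ab)⁽ⁿ⁾ = Σₖ C(n,k)·a⁽ᵏ⁾·b⁽ⁿ⁻ᵏ⁾ for every stored order n < N.
fn leibniz_mul<const N: usize>(a: &[f64; N], b: &[f64; N]) -> [f64; N] {
    let mut out = [0.0; N];
    for n in 0..N {
        let mut binom = 1.0; // C(n, 0)
        let mut acc = 0.0;
        for k in 0..=n {
            acc += binom * a[k] * b[n - k];
            // C(n, k+1) = C(n, k)·(n−k)/(k+1)
            binom = binom * (n - k) as f64 / (k + 1) as f64;
        }
        out[n] = acc;
    }
    out
}

/// Drop the fourth-order entry: the `Tower4 → Tower3` truncation in one variable.
fn truncate(x: &[f64; 5]) -> [f64; 4] {
    [x[0], x[1], x[2], x[3]]
}
```

// A quick sanity value: with `a` the jet of eˣ at 0 (all ones) and `b` the jet
// of x, the product is x·eˣ, whose n-th derivative at 0 is n.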

impl<const K: usize> Tower3<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
            t3: [[[0.0; K]; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the (fully symmetric) derivative tensor entry whose differentiation
    /// axes are `labels` (length 0..=3): value, `g`, `h`, `t3`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 3,
            "Tower3 carries at most third-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            2 => self.h[labels[0]][labels[1]],
            _ => self.t3[labels[0]][labels[1]][labels[2]],
        }
    }

    /// Exact truncated (order ≤ 3) Leibniz product. The `v`/`g`/`h`/`t3`
    /// channels match [`Tower4::mul`] term-for-term.
    ///
    /// # Codegen
    ///
    /// Straight-line per-entry subset sums instead of the
    /// [`jet_algebra::leibniz_product`] walker — the order-≤3 sibling of
    /// [`Tower4::mul`] (no `t4`). Loop nest unchanged, no unroll over `K`, no
    /// code bloat; auto-vectorises. BIT-IDENTICAL: terms in the walker's exact
    /// subset order with an `acc = 0.0` accumulator start (load-bearing for the
    /// signed-zero leading product on exact-`0.0` jet channels). Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// zero/sign-stressed inputs each (these channel formulas are exactly the
    /// `g`/`h`/`t3` of the [`Tower4::mul`] oracle, which passes that stress).
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            let mut acc = 0.0;
            acc += a.v * b.g[i];
            acc += a.g[i] * b.v;
            out.g[i] = acc;
        }
        // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
        for i in 0..K {
            for j in i..K {
                let mut acc = 0.0;
                acc += a.v * b.h[i][j];
                acc += a.g[i] * b.g[j];
                acc += a.g[j] * b.g[i];
                acc += a.h[i][j] * b.v;
                out.h[i][j] = acc;
                out.h[j][i] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                    let mut acc = 0.0;
                    acc += a.v * b.t3[i][j][k];
                    acc += a.g[i] * b.h[j][k];
                    acc += a.g[j] * b.h[i][k];
                    acc += a.h[i][j] * b.g[k];
                    acc += a.g[k] * b.h[i][j];
                    acc += a.h[i][k] * b.g[j];
                    acc += a.h[j][k] * b.g[i];
                    acc += a.t3[i][j][k] * b.v;
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }

    /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
    /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
    /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
    /// out of the borrowed operands.
    pub fn add(&self, o: &Self) -> Self {
        *self + *o
    }

    /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
    pub fn sub(&self, o: &Self) -> Self {
        *self + o.scale(-1.0)
    }

    /// Exact (order ≤ 3) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u), f‴(u)]` evaluated at `u = self.v`. The
    /// `v`/`g`/`h`/`t3` channels match [`Tower4::compose_unary`] term-for-term
    /// (which uses only `d[0..=3]` for those orders), so this is a strict
    /// truncation, not an approximation. The full-order `[f64; 5]` derivative
    /// stacks the families already produce can be passed by taking their first
    /// four entries (`[d[0], d[1], d[2], d[3]]`).
    ///
    /// # Codegen
    ///
    /// Order-≤3 Faà di Bruno written as a compact closed form instead of the
    /// recursive [`jet_algebra::faa_di_bruno`] walker: the order-≤3 sibling of
    /// [`Tower4::compose_unary`], one tensor order shallower. The loop nest is
    /// unchanged (no unroll over `K`, no code bloat: measured on a `Tower3<9>`
    /// compose-and-read consumer the new form is faster and SMALLER — asm: 71
    /// walker `bl` calls → 0, 39.5 KiB → 13.9 KiB, +197 NEON `.2d` ops).
    /// BIT-IDENTICAL: terms in the walker's exact partition order,
    /// left-associated block products, `acc = 0.0` accumulator start. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// random inputs each.
    pub fn compose_unary(&self, d: [f64; 4]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
                    let mut acc = 0.0;
                    acc += d[1] * self.t3[i][j][k];
                    acc += d[2] * self.h[i][j] * self.g[k];
                    acc += d[2] * self.h[i][k] * self.g[j];
                    acc += d[2] * self.g[i] * self.h[j][k];
                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }

    /// Compose with a unary special-function whose `[f64; 4]` derivative STACK is
    /// built from the base value through `stack_fn` — the scalar arm of the
    /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
    /// [`Tower3Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
    /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
    /// `self.compose_unary(stack_fn(self.v))`. The order-≤3 sibling of
    /// [`Tower4::compose_unary_with`].
    #[inline]
    pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
        self.compose_unary(stack_fn(self.v))
    }

    /// Single-active-slot fast path for [`Self::compose_unary`], the order-≤3
    /// sibling of [`Tower4::compose_unary_single_slot`]. When `self` carries
    /// derivative support only on the all-`slot` diagonal, every output channel
    /// touching an axis `≠ slot` collapses to the walker's `total = 0.0` start
    /// (`+0.0`), so only `v`, `g[slot]`, `h[slot][slot]` and
    /// `t3[slot][slot][slot]` survive. These four are computed as STRAIGHT-LINE
    /// accumulations, each in the EXACT term order of [`Self::compose_unary`]'s
    /// diagonal (`i = j = k = slot`) case (BIT-IDENTICAL to the full path on the
    /// diagonal); off-`slot` channels stay at the zero-init `+0.0` the full walk
    /// also yields (proven `to_bits` across `K ∈ {2,3,4,9}`). This drops the
    /// recursive set-partition walker the diagonal channels previously routed
    /// through, recovering the ~5.9× slowdown measured for that walker at the
    /// `K ∈ {2,3}` BMS tower widths. Caller guarantees the single-slot
    /// precondition; otherwise use [`Self::compose_unary`].
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [f64; 4], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        out.v = d[0];
        // g (i=s): d1*g
        out.g[s] = {
            let mut acc = 0.0;
            acc += d[1] * g;
            acc
        };
        // h (i=j=s): d1*h + d2*g*g
        out.h[s][s] = {
            let mut acc = 0.0;
            acc += d[1] * h;
            acc += d[2] * g * g;
            acc
        };
        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
        out.t3[s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t3;
            acc += d[2] * h * g;
            acc += d[2] * h * g;
            acc += d[2] * g * h;
            acc += d[3] * g * g * g;
            acc
        };
        out
    }

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
                for k in 0..K {
                    out.t3[i][j][k] *= s;
                }
            }
        }
        out
    }
}
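// A univariate sketch of `compose_unary`'s third-order rule (hypothetical
// helpers `compose3` and `exp_sin`, not crate API). With a single variable all
// indices coincide, so the five partitions {ijk} {ij}{k} {ik}{j} {i}{jk}
// {i}{j}{k} give f′·u‴ + 3·f″·u″·u′ + f‴·u′³; the three middle terms are the
// ones `compose_unary_single_slot` writes out separately. The check uses
// y = exp(sin x), whose derivatives have closed forms.

```rust
/// Univariate reduction of `Tower3::compose_unary`: `u = [u, u′, u″, u‴]` and
/// `d = [f, f′, f″, f‴]` evaluated at `u[0]`.
fn compose3(u: [f64; 4], d: [f64; 4]) -> [f64; 4] {
    let (g, h, t3) = (u[1], u[2], u[3]);
    [
        d[0],
        d[1] * g,
        d[1] * h + d[2] * g * g,
        // Partitions {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k} with i = j = k,
        // in the same term order as the single-slot diagonal.
        d[1] * t3 + d[2] * h * g + d[2] * h * g + d[2] * g * h + d[3] * g * g * g,
    ]
}

/// exp(sin x) as a jet: inner jet of sin at x, exp stack at sin x.
fn exp_sin(x: f64) -> [f64; 4] {
    let (s, c) = x.sin_cos();
    let e = s.exp();
    compose3([s, c, -s, -c], [e, e, e, e])
}
```

// Closed forms for comparison, with y = e^{sin x}: y′ = cos x·y,
// y″ = (cos²x − sin x)·y, y‴ = (cos³x − 3·sin x·cos x − cos x)·y.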

impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
    #[inline]
    fn derivative(&self, labels: &[usize]) -> f64 {
        self.deriv(labels)
    }

    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let mut out = Self::zero();
        out.v = f(&[]);
        for i in 0..K {
            let labels = [i];
            out.g[i] = f(&labels);
        }
        for i in 0..K {
            for j in 0..K {
                let labels = [i, j];
                out.h[i][j] = f(&labels);
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    let labels = [i, j, k];
                    out.t3[i][j][k] = f(&labels);
                }
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Add for Tower3<K> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        let mut out = self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
                for k in 0..K {
                    out.t3[i][j][k] += o.t3[i][j][k];
                }
            }
        }
        out
    }
}

/// Derivative stack of `ln Γ` for `x > 0`: `[ln Γ(x), ψ(x), ψ⁽¹⁾(x), ψ⁽²⁾(x),
/// ψ⁽³⁾(x)]`, the `[f, f′, f″, f‴, f⁗]` input of the Faà di Bruno compose.
pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
    [
        statrs::function::gamma::ln_gamma(x),
        digamma_positive(x),
        polygamma_positive(1, x),
        polygamma_positive(2, x),
        polygamma_positive(3, x),
    ]
}

/// Order-≤2 truncation of [`ln_gamma_derivative_stack`]: `[ln Γ(x), ψ(x),
/// ψ⁽¹⁾(x)]`, sized to match the `[f64; 3]` stack of [`Tower2::compose_unary`].
pub fn ln_gamma_derivative_stack_order2(x: f64) -> [f64; 3] {
    [
        statrs::function::gamma::ln_gamma(x),
        digamma_positive(x),
        polygamma_positive(1, x),
    ]
}

/// Derivative stack of the digamma function for `x > 0`:
/// `[ψ(x), ψ⁽¹⁾(x), ψ⁽²⁾(x), ψ⁽³⁾(x), ψ⁽⁴⁾(x)]`.
pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
    [
        digamma_positive(x),
        polygamma_positive(1, x),
        polygamma_positive(2, x),
        polygamma_positive(3, x),
        polygamma_positive(4, x),
    ]
}

/// Derivative stack of the trigamma function for `x > 0`:
/// `[ψ⁽¹⁾(x), ψ⁽²⁾(x), ψ⁽³⁾(x), ψ⁽⁴⁾(x), ψ⁽⁵⁾(x)]`.
pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
    [
        polygamma_positive(1, x),
        polygamma_positive(2, x),
        polygamma_positive(3, x),
        polygamma_positive(4, x),
        polygamma_positive(5, x),
    ]
}

/// Scalar digamma ψ(x) for x>0. Bit-identical to `digamma_derivative_stack(x)[0]`
/// and to `ln_gamma_derivative_stack(x)[1]`, but evaluates ONLY ψ — the other
/// four entries those `[f64; 5]` stacks build are pure discarded work at a
/// scalar consumer that reads a single element. Hot-path row kernels that need
/// only the digamma value (e.g. the GAMLSS Beta observed cross weight) call this
/// instead of indexing off a full derivative stack.
#[inline]
pub fn digamma(x: f64) -> f64 {
    digamma_positive(x)
}

/// Scalar trigamma ψ′(x) for x>0. Bit-identical to
/// `trigamma_derivative_stack(x)[0]` (both bottom out in
/// `polygamma_positive(1, x)`), but evaluates ONLY ψ′ — the four higher
/// polygammas (orders 2–5) the `[f64; 5]` stack builds are discarded at a `[0]`
/// consumer. Used by the dispersion-channel Fisher-information row kernels (NB2
/// `ψ′(θ)−ψ′(θ+μ)`, Beta `μψ′(μφ)−(1−μ)ψ′((1−μ)φ)`), which read the trigamma
/// value alone.
#[inline]
pub fn trigamma(x: f64) -> f64 {
    polygamma_positive(1, x)
}
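
/// ψ(x) for finite `x > 0` (NaN otherwise): shift `x` upward with the
/// recurrence ψ(x) = ψ(x + 1) − 1/x until `x ≥ POLYGAMMA_ASYMPTOTIC_MIN_X`,
/// then add the asymptotic series [`digamma_asymptotic`] at the shifted point.
fn digamma_positive(mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
        acc -= 1.0 / x;
        x += 1.0;
    }
    acc + digamma_asymptotic(x)
}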
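// A standalone sketch of the shift-then-asymptote scheme `digamma_positive`
// uses, checked against the known values ψ(1) = −γ and ψ(1/2) = −γ − 2 ln 2
// and against the recurrence itself. `digamma_sketch`, `BERNOULLI_B2K` and
// `SHIFT_MIN_X` are hypothetical names that mirror, but are not, the private
// items in this file.

```rust
/// Even-index Bernoulli numbers (2k, B₂ₖ) for k = 1..=10.
const BERNOULLI_B2K: [(i32, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];
/// Recurrence shift target (the same 20.0 as this file uses).
const SHIFT_MIN_X: f64 = 20.0;

/// ψ(x) for finite x > 0, NaN otherwise.
fn digamma_sketch(mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < SHIFT_MIN_X {
        acc -= 1.0 / x; // ψ(x) = ψ(x + 1) − 1/x
        x += 1.0;
    }
    // ψ(x) ~ ln x − 1/(2x) − Σₖ B₂ₖ / (2k·x²ᵏ)
    let mut asym = x.ln() - 0.5 / x;
    for (two_k, b) in BERNOULLI_B2K {
        asym -= b / (two_k as f64 * x.powi(two_k));
    }
    acc + asym
}
```

// Shifting to x ≥ 20 before using the series keeps the first omitted
// Bernoulli term far below f64 resolution, so the error is dominated by the
// rounding of the recurrence sum.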
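
/// ψ⁽ⁿ⁾(x) for `order = n` and finite `x > 0` (NaN otherwise, including any
/// order outside the `1..=5` range [`polygamma_asymptotic`] supports): the same
/// upward shift via [`polygamma_recurrence_term`], then the asymptotic series
/// at the shifted point.
fn polygamma_positive(order: usize, mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
        acc += polygamma_recurrence_term(order, x);
        x += 1.0;
    }
    acc + polygamma_asymptotic(order, x)
}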
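// The same scheme for the higher polygammas, as a hypothetical standalone
// `polygamma_sketch` (not crate API). It factors the sign (−1)ⁿ⁺¹ out of the
// series instead of applying it per term, so it is not bit-identical to
// `polygamma_positive`, only numerically equivalent. The checks use
// ψ′(1) = π²/6, ψ″(1) = −2ζ(3) and ψ‴(1) = π⁴/15.

```rust
/// Even-index Bernoulli numbers (2k, B₂ₖ) for k = 1..=10.
const BERNOULLI_B2K: [(usize, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];
const SHIFT_MIN_X: f64 = 20.0;

fn factorial(n: usize) -> f64 {
    (1..=n).fold(1.0, |acc, k| acc * k as f64)
}

/// ψ⁽ⁿ⁾(x) for n ≥ 1 and finite x > 0, NaN otherwise.
fn polygamma_sketch(n: usize, mut x: f64) -> f64 {
    if n == 0 || !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let sign = if n % 2 == 1 { 1.0 } else { -1.0 }; // (−1)ⁿ⁺¹
    let n_fact = factorial(n);
    let mut acc = 0.0;
    while x < SHIFT_MIN_X {
        // ψ⁽ⁿ⁾(x) = ψ⁽ⁿ⁾(x + 1) + (−1)ⁿ⁺¹·n!/xⁿ⁺¹
        acc += sign * n_fact / x.powi(n as i32 + 1);
        x += 1.0;
    }
    // Bracket of (−1)ⁿ⁺¹·[(n−1)!/xⁿ + n!/(2xⁿ⁺¹) + Σₖ B₂ₖ·(2k+n−1)!/((2k)!·x²ᵏ⁺ⁿ)]
    let mut asym = factorial(n - 1) / x.powi(n as i32) + n_fact / (2.0 * x.powi(n as i32 + 1));
    for (two_k, b) in BERNOULLI_B2K {
        // (2k+n−1)!/(2k)! = [2k·(2k+1)···(2k+n−1)] / 2k
        let rising: f64 = (two_k..two_k + n).map(|j| j as f64).product();
        asym += b * rising / two_k as f64 / x.powi((two_k + n) as i32);
    }
    acc + sign * asym
}
```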

/// Recurrence shift target: below this `x`, [`digamma_positive`] and
/// [`polygamma_positive`] step `x → x + 1` before using the asymptotic series.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
/// Even-index Bernoulli numbers `(2k, B₂ₖ)` for `k = 1..=10`, the coefficients
/// of the asymptotic series (truncated after `B₂₀`).
const BERNOULLI_EVEN: [(usize, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];

/// One upward-shift increment `(−1)ⁿ⁺¹·n!/xⁿ⁺¹`, from the recurrence
/// ψ⁽ⁿ⁾(x) = ψ⁽ⁿ⁾(x + 1) + (−1)ⁿ⁺¹·n!/xⁿ⁺¹.
fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
    let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    sign * factorial(order) / x.powi((order + 1) as i32)
}

/// Asymptotic series ψ(x) ~ ln x − 1/(2x) − Σₖ B₂ₖ/(2k·x²ᵏ), truncated at `B₂₀`.
fn digamma_asymptotic(x: f64) -> f64 {
    let mut out = x.ln() - 0.5 / x;
    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
        out -= bernoulli / (bernoulli_order as f64 * x.powi(bernoulli_order as i32));
    }
    out
}

/// Asymptotic series for ψ⁽ⁿ⁾(x), `n ∈ 1..=5` (NaN otherwise):
/// (−1)ⁿ⁺¹·[(n−1)!/xⁿ + n!/(2xⁿ⁺¹) + Σₖ B₂ₖ·(2k+n−1)!/((2k)!·x²ᵏ⁺ⁿ)], with
/// (2k+n−1)!/(2k)! formed as `rising_factorial(2k, n) / (2k)`.
fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
    if !(1..=5).contains(&order) {
        return f64::NAN;
    }

    let order_factorial = factorial(order);
    let leading_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    let mut out = leading_sign * factorial(order - 1) / x.powi(order as i32)
        + leading_sign * order_factorial / (2.0 * x.powi((order + 1) as i32));

    let bernoulli_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
        let rising = rising_factorial(bernoulli_order, order);
        out += bernoulli_sign * bernoulli * rising
            / bernoulli_order as f64
            / x.powi((bernoulli_order + order) as i32);
    }
    out
}
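// For reference, the identities the helpers above implement, written out
// (standard results; the code truncates the sums at B₂₀ after shifting x to at
// least 20). In `polygamma_asymptotic`, `leading_sign` and `bernoulli_sign` are
// both (−1)ⁿ⁺¹, and `rising_factorial(2k, n) / 2k` is the ratio in the last
// identity.

```latex
\psi(x) = \psi(x+1) - \frac{1}{x},
\qquad
\psi^{(n)}(x) = \psi^{(n)}(x+1) + (-1)^{n+1}\,\frac{n!}{x^{n+1}},

\psi(x) \sim \ln x - \frac{1}{2x} - \sum_{k\ge 1} \frac{B_{2k}}{2k\,x^{2k}},

\psi^{(n)}(x) \sim (-1)^{n+1}\left[\frac{(n-1)!}{x^{n}} + \frac{n!}{2x^{n+1}}
  + \sum_{k\ge 1} B_{2k}\,\frac{(2k+n-1)!}{(2k)!\,x^{2k+n}}\right],
\qquad
\frac{(2k+n-1)!}{(2k)!} = \frac{1}{2k}\prod_{j=2k}^{2k+n-1} j .
```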

/// `n!` as `f64`.
fn factorial(n: usize) -> f64 {
    (1..=n).fold(1.0, |acc, k| acc * k as f64)
}

/// Rising product `start·(start+1)···(start+len−1)` as `f64` (1.0 when `len == 0`).
fn rising_factorial(start: usize, len: usize) -> f64 {
    (start..start + len).fold(1.0, |acc, k| acc * k as f64)
}

impl<const K: usize> std::ops::Add for Tower4<K> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        let mut out = self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
                for k in 0..K {
                    out.t3[i][j][k] += o.t3[i][j][k];
                    for l in 0..K {
                        out.t4[i][j][k][l] += o.t4[i][j][k][l];
                    }
                }
            }
        }
        out
    }
}

impl<const K: usize> std::ops::Sub for Tower4<K> {
    type Output = Self;
    fn sub(self, o: Self) -> Self {
        self + o.scale(-1.0)
    }
}

impl<const K: usize> std::ops::Neg for Tower4<K> {
    type Output = Self;
    fn neg(self) -> Self {
        self.scale(-1.0)
    }
}

impl<const K: usize> std::ops::Mul for Tower4<K> {
    type Output = Self;
    fn mul(self, o: Self) -> Self {
        Tower4::mul(&self, &o)
    }
}

impl<const K: usize> std::ops::Div for Tower4<K> {
    type Output = Self;
    fn div(self, o: Self) -> Self {
        Tower4::mul(&self, &o.recip())
    }
}

impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
    type Output = Self;
    fn add(self, c: f64) -> Self {
        let mut out = self;
        out.v += c;
        out
    }
}

impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
    type Output = Self;
    fn sub(self, c: f64) -> Self {
        self + (-c)
    }
}

impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
    type Output = Self;
    fn mul(self, c: f64) -> Self {
        self.scale(c)
    }
}

// ── Implicit-function and moving-boundary seams (#932 flex) ──────────
//
// The flexible survival marginal-slope row loss is NOT a free composition
// of the primaries: it threads an IMPLICIT calibration intercept `a(θ)`
// solving a constraint `F(a, θ) = 0`, and integrates a density over cells
// whose edges `z_L(θ), z_R(θ)` MOVE with θ through that intercept. Plain
// `Tower4` Faà di Bruno cannot express either — so the flex tower was the
// last hand-written one in the codebase, and a live instance of the
// #736-class drift-bug genus (the (g,w0) deviation-cross third was 3× short
// for exactly this reason). These two combinators close that gap: once the
// constraint `F` and the integrand/boundaries are themselves towers, the
// intercept's derivative tower and the integral's derivative tower come out
// EXACTLY at every order — there is no order left to hand-code and forget.

/// Solve the implicit relation `F(a(θ), θ) ≡ 0` for the intercept tower
/// `a(θ)` over the `K` primaries θ, given the constraint tower `f` written
/// over `K + 1` variables (slot `0` is the intercept `a`, slots `1..=K`
/// are the primaries θ) evaluated at the SOLVED point — i.e. `f.v` is the
/// constraint residual at `(a₀, θ₀)` (≈ 0 from the production Newton solve)
/// and `a0` is that solved intercept value.
///
/// On success, returns the `Tower4<K>` whose value is `a0` and whose every
/// derivative tensor (∂a/∂θ, ∂²a/∂θ², …, ∂⁴a/∂θ⁴) is the exact
/// implicit-function derivative. This is the mechanical replacement for the
/// hand-coded `a_u = -f_u/f_a`,
/// `a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a` recursion
/// (first_full.rs) and its third/fourth-order continuations. Returns `Err`
/// when `∂F/∂a` is zero or non-finite, when `a0` fails the root check below,
/// or when the composed-residual self-check fails.
///
/// Method: order-by-order substitution. We build `a` incrementally; at each
/// order `m` the composite `G(θ) = f(a(θ), θ)` has a top-order coefficient
/// that is linear in `a`'s order-`m` tensor with leading factor `F_a`
/// (= `f.g[0]`), plus terms in `a`'s lower orders already fixed. Setting the
/// order-`m` tensor of `a` to cancel the rest of `G`'s order-`m` coefficient
/// keeps `G ≡ 0` through that order. The substitution `G = f∘(a, θ)` reuses
/// only the exact [`substitute_intercept`] chain rule, so the recursion is
/// auditable and exact, not a hand-expanded formula per order.
///
/// `f.g[0]` (= ∂F/∂a) must be non-zero — guaranteed by the production
/// solve's strict monotonicity guard.
///
/// The expansion point `a0` must be a genuine root `F(a0, θ0) = 0`: the
/// substitution recursion below cancels orders 1..=4 of `G = F∘a` but never
/// touches order 0, so a non-root `a0` would yield the Taylor expansion of
/// the LEVEL SET `F = F(a0)` through `a0`, not the root curve `F = 0`. This
/// is guarded explicitly and re-verified by a composed-residual self-check.
pub fn implicit_solve<const K1: usize, const K: usize>(
    f: &Tower4<K1>,
    a0: f64,
) -> Result<Tower4<K>, String> {
    assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
    let f_a = f.g[0];
    if f_a == 0.0 || !f_a.is_finite() {
        return Err(format!(
            "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
        ));
    }
    // The expansion point must be a genuine root of F. The single Newton
    // correction that would move a0 onto the root is |f.v|/|f_a|; require it
    // to be negligible relative to the natural scale (1 + |a0|). Guarding the
    // Newton step (rather than f.v directly) makes the criterion invariant to
    // the magnitude of f_a / the units of F.
    let root_tol = 1e-9;
    if !f.v.is_finite() {
        return Err(format!(
            "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
            f.v
        ));
    }
    let newton_step = f.v.abs() / f_a.abs();
    if newton_step > root_tol * (1.0 + a0.abs()) {
        return Err(format!(
            "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
             F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
             root_tol {root_tol:.1e} · (1 + |a0|)",
            f.v
        ));
    }
    // Start with a = constant a0 (correct through order 0). Then lift each
    // order in turn. Because substitute_intercept reads `a`'s order-≤m
    // tensors when forming G's order-m coefficient, and the order-m
    // coefficient of G depends on a's order-m tensor ONLY through the linear
    // F_a·a_m term, a single corrective pass per order is exact.
    let mut a = Tower4::<K>::constant(a0);
    for order in 1..=4 {
        let g = substitute_intercept(f, &a);
        // Cancel G's order-`order` coefficient by adjusting a's order-`order`
        // tensor: a_m -= G_m / F_a (the F_a·a_m term is the only one carrying
        // a's order-m tensor, with unit chain coefficient since slot 0 seeds a
        // as a plain variable in the substitution's first-order part).
        match order {
            1 => {
                for i in 0..K {
                    a.g[i] -= g.g[i] / f_a;
                }
            }
            2 => {
                for i in 0..K {
                    for j in 0..K {
                        a.h[i][j] -= g.h[i][j] / f_a;
                    }
                }
            }
            3 => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
                        }
                    }
                }
            }
            _ => {
                for i in 0..K {
                    for j in 0..K {
                        for k in 0..K {
                            for l in 0..K {
                                a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
                            }
                        }
                    }
                }
            }
        }
    }
    // Self-check: the composed residual G = F∘a must vanish through order 4.
    // By construction orders 1..=4 were cancelled; the value G.v == F(a0,θ0)
    // is exactly the root requirement guarded above. Re-verify all channels
    // against a scale-aware floor so any arithmetic regression in the
    // substitution recursion is loud rather than silently shipping a
    // level-set expansion.
    let g = substitute_intercept(f, &a);
    let resid_tol = 1e-7 * (1.0 + f_a.abs());
    let mut worst = g.v.abs();
    for i in 0..K {
        worst = worst.max(g.g[i].abs());
        for j in 0..K {
            worst = worst.max(g.h[i][j].abs());
            for k in 0..K {
                worst = worst.max(g.t3[i][j][k].abs());
                for l in 0..K {
                    worst = worst.max(g.t4[i][j][k][l].abs());
                }
            }
        }
    }
    if !worst.is_finite() || worst > resid_tol {
        return Err(format!(
            "implicit_solve: composed residual G = F∘a does not vanish: \
             worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
        ));
    }
    Ok(a)
}
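// A univariate sketch of the order-by-order correction `implicit_solve`
// performs, for the hypothetical constraint F(a, θ) = a² − θ, whose root curve
// a(θ) = √θ gives closed-form targets at θ₀ = 4: a′ = 1/4, a″ = −1/32,
// a‴ = 3/256. `Jet`, `mul`, `residual` and `implicit_sqrt` are illustrative
// names, not crate API; jets hold raw derivatives, truncated at third order.

```rust
/// Univariate jets of raw derivatives [v, v′, v″, v‴] in one primary θ.
type Jet = [f64; 4];

/// Leibniz product through order 3: (ab)⁽ⁿ⁾ = Σₖ C(n,k)·a⁽ᵏ⁾·b⁽ⁿ⁻ᵏ⁾.
fn mul(a: &Jet, b: &Jet) -> Jet {
    const C: [[f64; 4]; 4] = [
        [1.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 0.0],
        [1.0, 2.0, 1.0, 0.0],
        [1.0, 3.0, 3.0, 1.0],
    ];
    let mut out = [0.0; 4];
    for n in 0..4 {
        for k in 0..=n {
            out[n] += C[n][k] * a[k] * b[n - k];
        }
    }
    out
}

/// Composed constraint G(θ) = F(a(θ), θ) for F(a, θ) = a² − θ.
fn residual(a: &Jet, theta0: f64) -> Jet {
    let theta: Jet = [theta0, 1.0, 0.0, 0.0]; // seeded primary
    let sq = mul(a, a);
    [sq[0] - theta[0], sq[1] - theta[1], sq[2] - theta[2], sq[3] - theta[3]]
}

/// Order-by-order solve about a root (θ₀, a₀): G's order-m entry depends on
/// a's order-m entry only through F_a·a_m, so one correction per order
/// cancels it without disturbing the lower orders.
fn implicit_sqrt(theta0: f64, a0: f64) -> Jet {
    let f_a = 2.0 * a0; // ∂F/∂a at the root
    let mut a: Jet = [a0, 0.0, 0.0, 0.0];
    for m in 1..4 {
        let g = residual(&a, theta0);
        a[m] -= g[m] / f_a;
    }
    a
}
```

// Every value in this example is dyadic, so the recovered derivatives and the
// vanishing residual come out exact rather than merely close.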

/// Substitute the intercept tower `a(θ)` into slot `0` of a constraint
/// written over `K + 1` variables, returning the composite tower over the
/// `K` primaries θ: `G(θ) = f(a(θ), θ₁, …, θ_K)`.
///
/// This is the exact multivariate chain rule specialised to "slot 0 is a
/// dependent tower, slots 1..=K are the independent primaries". It evaluates
/// `f`'s fourth-order multivariate Taylor polynomial about the expansion
/// point, with the slot-0 increment being the non-constant part of `a` and
/// each slot-`i` increment (`i ≥ 1`) being the unit-seeded primary `θ_i`.
/// Every increment has zero value, so a product of `m` increments has no
/// derivative below order `m`; the Taylor terms of order ≥ 5 therefore cannot
/// reach the order-≤4 channels, and truncating at order four loses nothing.
/// The sum is assembled from exact `Tower4` products, so it carries
/// derivatives exactly through order four.
pub fn substitute_intercept<const K1: usize, const K: usize>(
    f: &Tower4<K1>,
    a: &Tower4<K>,
) -> Tower4<K> {
    assert_eq!(K1, K + 1);
    // Build the K+1 input towers in θ-space: slot 0 = δa(θ) (the centred
    // intercept), slot i+1 = δθ_i. The composite is the Taylor sum over
    // ORDERED label tuples s (|s| ≤ 4) of input indices:
    // (1/|s|!) · f.deriv(s) · Π_{j in s} inp[s_j]. Because f.deriv is the
    // SYMMETRIC partial tensor, the |s|!/α! orderings of a multi-index α each
    // carry weight 1/|s|!, which together give the usual multi-index weight
    // 1/α!; hence the explicit 1/2, 1/6, 1/24 factors below. We assemble it as
    // an explicit (non-Horner) sum over the (K+1)-ary tuples, using tower
    // products for the increment monomials so all θ-derivatives propagate
    // exactly.
    let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
        if slot == 0 {
            // slot 0: a(θ) minus its constant value (the increment δa(θ)).
            let mut d = *a;
            d.v = 0.0;
            d
        } else {
            // slot i: the increment δθ_{i-1} = seeded variable minus value.
            // θ centred at its expansion value has zero constant term and unit
            // first derivative in its own slot.
            let mut d = Tower4::<K>::zero();
            d.g[slot - 1] = 1.0;
            d
        }
    });
    // Accumulate the Taylor sum. order-0 term:
    let mut out = Tower4::<K>::constant(f.v);
    // order 1: Σ_a f.g[a] · inp[a]
    for a_idx in 0..K1 {
        out = out + inp[a_idx].scale(f.g[a_idx]);
    }
    // order 2: (1/2) Σ_{a,b} f.h[a][b] · inp[a]·inp[b]
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            let prod = inp[a_idx].mul(&inp[b_idx]);
            out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
        }
    }
    // order 3: (1/6) Σ f.t3[a][b][c] · inp[a]·inp[b]·inp[c]
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            for c_idx in 0..K1 {
                let prod = inp[a_idx].mul(&inp[b_idx]).mul(&inp[c_idx]);
                out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
            }
        }
    }
    // order 4: (1/24) Σ f.t4[a][b][c][d] · inp[a]·inp[b]·inp[c]·inp[d]
    for a_idx in 0..K1 {
        for b_idx in 0..K1 {
            for c_idx in 0..K1 {
                for d_idx in 0..K1 {
                    let prod = inp[a_idx]
                        .mul(&inp[b_idx])
                        .mul(&inp[c_idx])
                        .mul(&inp[d_idx]);
                    out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
                }
            }
        }
    }
    out
}
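// A univariate sketch of the ordered-tuple Taylor substitution, with one
// primary θ and two slots (0 = a, 1 = θ), truncated at third order. `Jet`,
// `mul`, `axpy` and `substitute` are hypothetical helpers; `f` is passed as
// its symmetric derivative tensors at the expansion point, as in
// `substitute_intercept`. Both checks use a polynomial `f` of degree ≤ 3 with
// a(θ) = θ², where the composite G has closed-form derivatives.

```rust
/// Univariate jets of raw derivatives [v, v′, v″, v‴] in one primary θ.
type Jet = [f64; 4];

/// Leibniz product through order 3.
fn mul(a: &Jet, b: &Jet) -> Jet {
    const C: [[f64; 4]; 4] = [
        [1.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 0.0],
        [1.0, 2.0, 1.0, 0.0],
        [1.0, 3.0, 3.0, 1.0],
    ];
    let mut out = [0.0; 4];
    for n in 0..4 {
        for k in 0..=n {
            out[n] += C[n][k] * a[k] * b[n - k];
        }
    }
    out
}

/// out + s·x, channel by channel.
fn axpy(out: &Jet, s: f64, x: &Jet) -> Jet {
    [out[0] + s * x[0], out[1] + s * x[1], out[2] + s * x[2], out[3] + s * x[3]]
}

/// G(θ) = f(a(θ), θ) via ordered tuples weighted by 1/|s|!.
fn substitute(fv: f64, fg: [f64; 2], fh: [[f64; 2]; 2], ft3: [[[f64; 2]; 2]; 2], a: &Jet) -> Jet {
    // Centred increments: δa = a(θ) − a₀ (value zeroed) and δθ (unit seed).
    let mut da = *a;
    da[0] = 0.0;
    let inc: [Jet; 2] = [da, [0.0, 1.0, 0.0, 0.0]];
    let mut out: Jet = [fv, 0.0, 0.0, 0.0];
    for p in 0..2 {
        out = axpy(&out, fg[p], &inc[p]);
    }
    for p in 0..2 {
        for q in 0..2 {
            out = axpy(&out, 0.5 * fh[p][q], &mul(&inc[p], &inc[q]));
        }
    }
    for p in 0..2 {
        for q in 0..2 {
            for r in 0..2 {
                let prod = mul(&mul(&inc[p], &inc[q]), &inc[r]);
                out = axpy(&out, ft3[p][q][r] / 6.0, &prod);
            }
        }
    }
    out
}
```

// First check: f(a, θ) = a·θ at (1, 1) gives G = θ³, derivatives [1, 3, 6, 6].
// Second: f(a, θ) = a²·θ gives G = θ⁵, derivatives [1, 5, 20, 60]; its only
// nonzero third partials are the three orderings of f_aaθ = 2.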

/// The exact θ-derivative tower of a moving-LIMIT integral's BOUNDARY
/// contribution: given the edge-position tower `z_edge(θ)` over the `K`
/// primaries and the integrand `B` evaluated-and-differentiated at the edge
/// value as the stack `b_stack = [B(z₀), B′(z₀), B″(z₀), B‴(z₀)]`
/// (`z₀ = z_edge.v`), returns the tower of `Φ(z_edge(θ))` where `Φ′ = B`.
///
/// Rationale: `∫^{z_edge(θ)} B(z) dz = Φ(z_edge(θ)) + const` with `Φ` an
/// antiderivative of `B`, so every θ-derivative of the integral equals the
/// same θ-derivative of the composition `Φ ∘ z_edge`. Its Faà di Bruno
/// expansion carries, at one stroke, EVERY Leibniz boundary term, including
/// those the hand-written flux dropped: the first-order `B·z_u`, the
/// second-order `B′·z_u·z_v + B·z_uv` (the `G_z·z_u·z_v` self-flux AND the
/// previously dropped `G·z_uv`), and the full third/fourth-order
/// continuations. The VALUE channel of the returned tower is meaningless (`Φ`
/// is only defined up to a constant); callers read only the derivative
/// channels and pair this with the interior moment-integral value separately.
///
/// `b_stack` holds `B` and its first three z-derivatives; the antiderivative
/// `Φ` contributes only as the order-≥1 channels, so `compose_unary` receives
/// `[0, B, B′, B″, B‴]`: the leading `0` is the discarded `Φ(z₀)` slot.
pub fn moving_limit_boundary_tower<const K: usize>(
    z_edge: &Tower4<K>,
    b_stack: [f64; 4],
) -> Tower4<K> {
    z_edge.compose_unary([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]])
}
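// A univariate sketch of the boundary composition (hypothetical `compose3`
// and `boundary3`, truncated at third order, not crate API). It takes
// B(z) = 2z, so Φ = z², and the edge z(θ) = θ³ at θ₀ = 1, so Φ(z(θ)) = θ⁶ with
// derivatives 6, 30, 120; it also checks that a fixed edge contributes nothing.

```rust
/// Univariate order-≤3 Faà di Bruno for Φ ∘ z, with z = [z, z′, z″, z‴] and
/// the stack d = [Φ, Φ′, Φ″, Φ‴] evaluated at z[0].
fn compose3(z: [f64; 4], d: [f64; 4]) -> [f64; 4] {
    let (g, h, t) = (z[1], z[2], z[3]);
    [
        d[0],
        d[1] * g,
        d[1] * h + d[2] * g * g,
        d[1] * t + 3.0 * d[2] * h * g + d[3] * g * g * g,
    ]
}

/// Boundary tower of ∫^{z(θ)} B(z) dz from b = [B, B′, B″] at z[0]. The stack
/// handed to the composition is [0, B, B′, B″]: the leading 0 stands in for
/// the undefined constant Φ(z₀).
fn boundary3(z: [f64; 4], b: [f64; 3]) -> [f64; 4] {
    compose3(z, [0.0, b[0], b[1], b[2]])
}
```

// The second channel is B·z″ + B′·z′² = 2·6 + 2·9 = 30: both the self-flux term
// and the `B·z_uv` term appear without being listed by hand.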

/// The boundary-flux derivative tower of a single moving cell integral
/// `∫_{z_L(θ)}^{z_R(θ)} B dz`: `Φ(z_R(θ)) − Φ(z_L(θ))`, assembled from the
/// two edge towers and the integrand stacks at each edge. The returned
/// tower's derivative channels are the EXACT moving-boundary contribution to
/// every θ-derivative of the cell integral, to fourth order, with no term
/// hand-omitted. A `Fixed` (non-moving) edge passes a `z_edge` whose
/// derivative channels are all zero, contributing nothing — matching the
/// production `edge_vel = 0` short-circuit.
pub fn cell_moving_boundary_flux_tower<const K: usize>(
    z_right: &Tower4<K>,
    b_stack_right: [f64; 4],
    z_left: &Tower4<K>,
    b_stack_left: [f64; 4],
) -> Tower4<K> {
    moving_limit_boundary_tower(z_right, b_stack_right)
        - moving_limit_boundary_tower(z_left, b_stack_left)
}
1708
1709/// Moving-limit boundary tower for a θ-DEPENDENT integrand `G(z; θ)`.
1710///
1711/// [`moving_limit_boundary_tower`] assumes the integrand depends on θ only
1712/// through the moving edge `z_edge(θ)` (a fixed z-derivative `b_stack`). The
1713/// marginal-slope flex boundary is richer: the integrand `G(z; θ)` ALSO carries
1714/// its own θ-dependence (the density weight `w = e^{−q}/2π` and the cell
1715/// integrand coefficients move with η, hence with the primaries), so the
1716/// Leibniz expansion of `∂ⁿ_θ ∫^{z_edge(θ)} G(z;θ) dz` mixes edge-motion
1717/// derivatives of the limit with θ-derivatives of `G` itself — e.g. at second
1718/// order `G·z_uv + G_z·z_u·z_v + G_{θu}·z_v + G_{θv}·z_u` (the four
1719/// edge-motion-carrying terms the hand path assembles one by one, including the
1720/// `G·z_uv` term the directional path drops).
1721///
1722/// Mechanization: let `Φ(z; θ)` be the z-antiderivative of `G` (so `Φ_z = G`).
1723/// The full upper-limit contribution is `Φ(z_edge(θ); θ)`, and the BOUNDARY
1724/// part — everything carrying edge motion — is exactly
1725/// `Φ(z_edge(θ); θ) − Φ(z₀; θ)`,
1726/// the second term being the pure-integrand-θ part (`∫^{z₀} ∂ⁿ_θ G`) the
1727/// interior moment integral already supplies. Both are one
1728/// [`substitute_intercept`] of the SAME mixed `(z, θ)` jet of `Φ` (z in slot 0,
1729/// θ in slots 1..K): substituting the edge tower gives the full composite,
1730/// substituting a frozen constant edge isolates the pure-θ part, and their
1731/// difference is the exact boundary flux — every Leibniz term derived by the
1732/// substitution algebra, none hand-omitted.
1733///
1734/// `phi_jet` is the `(K+1)`-variable Taylor jet of `Φ` about `(z₀, θ₀)` with
1735/// `z₀ = z_edge.v`: slot 0 is the z-direction (so `phi_jet.g[0] = G(z₀;θ₀)`,
1736/// `phi_jet.h[0][0] = G_z`, …) and slots `1..=K` are the primaries θ (carrying
1737/// `Φ`'s own θ- and mixed z·θ-derivatives — i.e. the integrand's θ-derivatives
1738/// integrated in z, and `G_{θ…}` in the mixed slots). The returned tower's
1739/// VALUE channel is 0 by construction (the `Φ(z₀;θ₀)` constants cancel); only
1740/// the derivative channels are meaningful, matching the value-less convention of
1741/// [`moving_limit_boundary_tower`].
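/**
A minimal numeric sketch of the second-order expansion above, with one
primary θ and hypothetical closed forms (not this crate's API): take
`G(z; θ) = θ·z` and `z_edge(θ) = θ²`, so `F(θ) = ∫₀^{θ²} θ·z dz = θ⁵/2` and
`F″(θ) = 10θ³`. With u = v = θ the four edge-motion terms become
`G·z_θθ + G_z·z_θ² + 2·G_θ·z_θ`, and the interior part `∫ G_θθ dz` vanishes
because `G` is linear in θ, so the boundary terms alone must reproduce `F″`.

```rust
// Hypothetical integrand G(z; θ) = θ·z and its first partials.
fn g(z: f64, th: f64) -> f64 {
    th * z
}
fn g_z(_z: f64, th: f64) -> f64 {
    th
}
fn g_th(z: f64, _th: f64) -> f64 {
    z
}

// Hypothetical moving edge z(θ) = θ², returned as (z, z_θ, z_θθ).
fn edge(th: f64) -> (f64, f64, f64) {
    (th * th, 2.0 * th, 2.0)
}

// Edge-motion part of ∂²_θ ∫_0^{z(θ)} G(z; θ) dz, assembled term by term.
fn boundary_second(th: f64) -> f64 {
    let (z, z_t, z_tt) = edge(th);
    g(z, th) * z_tt // G · z_θθ (the term a directional path can drop)
        + g_z(z, th) * z_t * z_t // G_z · z_θ²
        + 2.0 * g_th(z, th) * z_t // G_θu·z_v + G_θv·z_u with u = v = θ
}

// Closed form: F(θ) = θ⁵/2, hence F″(θ) = 10θ³.
fn exact_second(th: f64) -> f64 {
    10.0 * th.powi(3)
}
```

Dropping the `G·z_θθ` term alone would shift the result by `2θ³`, the kind of
omission the substitution-based construction rules out.
*/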
pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
    phi_jet: &Tower4<K1>,
    z_edge: &Tower4<K>,
) -> Tower4<K> {
    assert_eq!(
        K1,
        K + 1,
        "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
    );
    let frozen_edge = Tower4::<K>::constant(z_edge.v);
    let full = substitute_intercept(phi_jet, z_edge);
    let interior = substitute_intercept(phi_jet, &frozen_edge);
    full - interior
}

/// Two-edge cell version of [`moving_limit_boundary_tower_theta_integrand`]:
/// the exact boundary-flux tower of `∫_{z_L(θ)}^{z_R(θ)} G(z;θ) dz` with a
/// θ-dependent integrand, `Φ(z_R;θ) − Φ(z_L;θ)` minus the pure-θ parts at each
/// frozen edge. A `Fixed` edge passes a `z_edge` with zero derivative channels,
/// so its `full` and `interior` substitutions coincide and it contributes
/// nothing — matching the production `edge_vel = 0` short-circuit.
pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
    phi_jet_right: &Tower4<K1>,
    z_right: &Tower4<K>,
    phi_jet_left: &Tower4<K1>,
    z_left: &Tower4<K>,
) -> Tower4<K> {
    moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right)
        - moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left)
}

// ── The program seam ─────────────────────────────────────────────────

/// A family's row negative log-likelihood written ONCE over tower scalars.
///
/// This is the single source of truth #932 asks for: the value channel of
/// the returned tower must BE the production row NLL (same branches, same
/// guards, same numerics), and every derivative channel is then exact by
/// construction. The linear Jacobian wiring (coefficients ↔ primaries) is
/// NOT part of this trait — it is family data, not calculus, and stays on
/// the `RowKernel` implementor.
pub trait RowNllProgram<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed the tower).
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on tower scalars. `p[a]` arrives pre-seeded as
    /// variable `a` at the current primary value; implementations combine
    /// them with `Tower4` arithmetic and per-row data (response, censoring
    /// indicators, offsets) entering as constants.
    fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
}

/// Evaluate a program's full tower at the current primaries for one row.
///
/// One call yields every `RowKernel` calculus channel; callers that need
/// several contractions of the same row should hold the returned tower and
/// contract repeatedly rather than re-evaluating.
pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<Tower4<K>, String> {
    let p = prog.primaries(row)?;
    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
    prog.row_nll(row, &vars)
}

/// Mechanically derived `row_kernel` channel: `(nll, ∇, H)`.
pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
    let t = evaluate_program(prog, row)?;
    Ok((t.v, t.g, t.h))
}

/// Mechanically derived `row_third_contracted` channel.
pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    Ok(evaluate_program(prog, row)?.third_contracted(dir))
}

/// Mechanically derived `row_fourth_contracted` channel.
pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir_u: &[f64; K],
    dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
}

// ── The generic program seam (#932 scalar cutover) ───────────────────

/// A family's row negative log-likelihood written ONCE over the generic
/// [`crate::jet_scalar::JetScalar`] interface, so the SAME expression can be
/// re-instantiated at whatever order / representation a consumer needs
/// ([`crate::jet_scalar::Order2`] for `(v, g, H)`,
/// [`crate::jet_scalar::OneSeed`] for the contracted third,
/// [`crate::jet_scalar::TwoSeed`] for the contracted fourth, or the full
/// [`Tower4`] for every channel at once).
///
/// This is additive to [`RowNllProgram`] (which is `Tower4`-specialised): a
/// program implementing this generic trait gets the small contracted scalars for
/// free, dissolving the dense-`Tower4<9>` cost objection in the location-scale
/// gates (doc §A.4). An existing `Tower4`-only [`RowNllProgram`] continues to
/// work unchanged; new families should prefer this generic trait.
///
/// Because a `Tower4`-specialised `row_nll` body uses only
/// `add`/`sub`/`mul`/`scale`/`exp`/`ln`/…, all of which
/// [`crate::jet_scalar::JetScalar`] also provides, the same body is expressible
/// directly over `S: JetScalar<K>`. A program written that way needs no
/// `Tower4`-specialised method and routes the directional and joint-Hessian
/// gates through the contracted scalars from a single definition.
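/**
A minimal sketch of the write-once idea, using a hypothetical two-method trait
`Jet` and a first-order dual number `Dual` as stand-ins (neither is
[`crate::jet_scalar::JetScalar`] nor one of its scalars). The Poisson log-link
row NLL `ℓ(η) = e^η − y·η` is written once, then instantiated at `f64` for the
value and at `Dual` for value plus derivative, so the derivative cannot drift
from the value it differentiates.

```rust
use std::ops::{Add, Mul, Sub};

// Hypothetical minimal jet interface.
trait Jet: Copy + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self> {
    fn constant(c: f64) -> Self;
    fn exp(self) -> Self;
}

impl Jet for f64 {
    fn constant(c: f64) -> Self {
        c
    }
    fn exp(self) -> Self {
        f64::exp(self) // the inherent `f64::exp`, not a recursive call
    }
}

// First-order dual number: value `v` and derivative `d` (ε² = 0).
#[derive(Clone, Copy)]
struct Dual {
    v: f64,
    d: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { v: self.v + o.v, d: self.d + o.d }
    }
}
impl Sub for Dual {
    type Output = Dual;
    fn sub(self, o: Dual) -> Dual {
        Dual { v: self.v - o.v, d: self.d - o.d }
    }
}
impl Mul for Dual {
    type Output = Dual;
    fn mul(self, o: Dual) -> Dual {
        Dual { v: self.v * o.v, d: self.d * o.v + self.v * o.d }
    }
}
impl Jet for Dual {
    fn constant(c: f64) -> Self {
        Dual { v: c, d: 0.0 }
    }
    fn exp(self) -> Self {
        let e = self.v.exp();
        Dual { v: e, d: e * self.d }
    }
}

// The row NLL, written ONCE over any `Jet`.
fn poisson_row_nll<S: Jet>(eta: S, y: f64) -> S {
    eta.exp() - S::constant(y) * eta
}
```

Seeding `Dual { v: η, d: 1.0 }` returns `e^η − y` in the `d` channel, and its
`v` channel is bit-identical to the `f64` instantiation because both run the
same operations in the same order.
*/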
pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed the scalar).
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on a generic jet scalar. `p[a]` arrives pre-seeded
    /// (base value + per-scalar nilpotent directions) by the caller; the body
    /// uses ONLY [`crate::jet_scalar::JetScalar`] ops and per-row data
    /// (response, censoring, offsets) entering as constants.
    fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
        &self,
        row: usize,
        p: &[S; K],
    ) -> Result<S, String>;
}

/// Evaluate a generic program at the value/gradient/Hessian scalar
/// [`crate::jet_scalar::Order2`], returning `(nll, ∇, H)` — the
/// `row_kernel` channel — WITHOUT materialising any third / fourth tensor.
///
/// This is the production seam for the inner-Newton `(v, g, H)` path: the row
/// loss is written ONCE in `row_nll_generic`, and this routes it through the
/// cheap order-2 scalar. The single source of truth means the gradient and
/// Hessian cannot desync from the value (the #736 / #948 bug genus).
pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
    let base = prog.primaries(row)?;
    let vars: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|a| {
        <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(base[a], a)
    });
    let s = prog.row_nll_generic(row, &vars)?;
    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
}

/// Evaluate a generic program at the one-seed scalar
/// [`crate::jet_scalar::OneSeed`], returning the contracted third
/// `Σ_c ℓ_{abc} dir_c` — the `row_third_contracted(dir)` channel — WITHOUT
/// materialising the dense `t3` tensor. The contraction direction is folded
/// INTO the differentiation by the nilpotent ε seeded with `dir`.
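/**
A minimal sketch of that folding, assuming a hypothetical closed-form Hessian
for `ℓ(p) = p₀²·p₁²` and a first-order dual standing in for `OneSeed`: the
Hessian is evaluated on duals seeded as `p_a + ε·dir_a`, and its ε-parts are
exactly `Σ_c ℓ_{abc} dir_c`, compared here against an explicit dense `t3`
contraction.

```rust
// First-order dual: value `v` and nilpotent ε-part `e` (ε² = 0).
#[derive(Clone, Copy)]
struct Dual {
    v: f64,
    e: f64,
}

impl Dual {
    fn mul(self, o: Dual) -> Dual {
        Dual { v: self.v * o.v, e: self.e * o.v + self.v * o.e }
    }
    fn scale(self, s: f64) -> Dual {
        Dual { v: s * self.v, e: s * self.e }
    }
}

// Closed-form Hessian of ℓ(p) = p0²·p1², evaluated on duals.
fn hessian(p: [Dual; 2]) -> [[Dual; 2]; 2] {
    let h01 = p[0].mul(p[1]).scale(4.0);
    [[p[1].mul(p[1]).scale(2.0), h01], [h01, p[0].mul(p[0]).scale(2.0)]]
}

// Seed p_a = base_a + ε·dir_a and read off the ε-parts of the Hessian.
fn third_contracted_seeded(base: [f64; 2], dir: [f64; 2]) -> [[f64; 2]; 2] {
    let p = [Dual { v: base[0], e: dir[0] }, Dual { v: base[1], e: dir[1] }];
    let h = hessian(p);
    [[h[0][0].e, h[0][1].e], [h[1][0].e, h[1][1].e]]
}

// Dense reference: ℓ_001 (and permutations) = 4·p1, ℓ_011 (and permutations) = 4·p0.
fn third_contracted_dense(base: [f64; 2], dir: [f64; 2]) -> [[f64; 2]; 2] {
    let mut t3 = [[[0.0_f64; 2]; 2]; 2];
    for (a, b, c) in [(0, 0, 1), (0, 1, 0), (1, 0, 0)] {
        t3[a][b][c] = 4.0 * base[1];
    }
    for (a, b, c) in [(0, 1, 1), (1, 0, 1), (1, 1, 0)] {
        t3[a][b][c] = 4.0 * base[0];
    }
    let mut out = [[0.0_f64; 2]; 2];
    for a in 0..2 {
        for b in 0..2 {
            for c in 0..2 {
                out[a][b] += t3[a][b][c] * dir[c];
            }
        }
    }
    out
}
```
*/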
pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let base = prog.primaries(row)?;
    let vars: [crate::jet_scalar::OneSeed<K>; K] =
        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
    let s = prog.row_nll_generic(row, &vars)?;
    Ok(s.contracted_third())
}

/// Evaluate a generic program at the two-seed scalar
/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `row_fourth_contracted(u, v)` channel —
/// WITHOUT materialising the dense `t4` tensor.
pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
    dir_u: &[f64; K],
    dir_v: &[f64; K],
) -> Result<[[f64; K]; K], String> {
    let base = prog.primaries(row)?;
    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
    let s = prog.row_nll_generic(row, &vars)?;
    Ok(s.contracted_fourth())
}

/// Evaluate a generic program at the full dense [`Tower4`] scalar, returning
/// every channel `(v, g, h, t3, t4)` in one pass. Used where the UNCONTRACTED
/// third / fourth tensors are needed (the BMS rigid `third_full` / `fourth_full`
/// caches): the dense tensors come from the SAME `row_nll_generic` expression
/// the order-2 / contracted scalars consume, so there is a single source of
/// truth across every channel.
pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
    prog: &P,
    row: usize,
) -> Result<Tower4<K>, String> {
    let base = prog.primaries(row)?;
    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(base[a], a));
    prog.row_nll_generic(row, &vars)
}

// ── The RowJet bridge: one row-NLL body over scalar jets AND lane towers ─
//
// `JetScalar<K>` (jet_scalar.rs) abstracts the SCALAR jets — its `value()`
// returns one `f64`, so the `f64x4` lane towers ([`Tower3Lane`] / [`Tower4Lane`])
// CANNOT implement it (their value channel is four rows). `compose_unary_with`
// exists as an inherent method on BOTH the scalar towers and the lane towers, but
// as separate inherent methods, not a shared trait bound — so a row-NLL body
// written `<S: JetScalar<K>>` could not be instantiated at `Tower4Lane`, and the
// 4-rows-per-pass SIMD batch path could not reuse the single source.
//
// [`RowJet<K>`] is that shared bound. It exposes the ops a row-NLL body needs
// (`constant` / `variable` / `add` / `sub` / `mul` / `scale` / `neg`, the
// value-derived `compose_unary_with` with the transcendentals built on it, a
// per-lane domain `guard`, and the per-row data primitives `scale_rows` /
// `pack_rows`) over BOTH representations. A blanket impl makes every scalar
// `JetScalar<K>` a `RowJet<K>` (so the scalar call sites compile unchanged and
// bit-identically), and explicit impls route the `f64x4` lane towers through
// their existing per-lane methods. A body written once over `R: RowJet<K>` then
// instantiates at a scalar jet for the `(v, g, H)` / contracted-tensor channels
// AND at a lane tower for the batch.

/// The verdict of a per-lane [`RowJet::guard`] domain check.
///
/// A scalar jet (a [`crate::jet_scalar::JetScalar`] via the blanket impl) carries
/// ONE value, so it reports `lanes == 1` and a one-bit mask. A lane tower
/// ([`Tower3Lane`] / [`Tower4Lane`] over `f64x4`) carries FOUR rows, so it reports
/// `lanes == 4` and one mask bit per lane. The mask lets a batched program bail
/// exactly the offending 4-group to the scalar tail ([`any_failed`](Self::any_failed)),
/// or inspect which lanes tripped ([`lane_failed`](Self::lane_failed)).
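/**
A standalone sketch of the mask convention, mirroring the lane `guard` loop
rather than importing it (the helper names are hypothetical): four lane values
are checked against a positivity predicate, as a body might do before `ln`,
and bit `i` of the mask is set when lane `i` fails.

```rust
// Bit i is set iff lane i fails `pred`; only the four lane bits are kept.
fn failed_mask(lanes: [f64; 4], pred: impl Fn(f64) -> bool) -> u8 {
    let mut mask = 0u8;
    for (i, &v) in lanes.iter().enumerate() {
        if !pred(v) {
            mask |= 1 << i;
        }
    }
    mask & 0x0f
}

// The bit test used by `GuardVerdict::lane_failed`, restricted to four lanes.
fn lane_failed(mask: u8, i: usize) -> bool {
    i < 4 && (mask >> i) & 1 == 1
}
```

With lane values `[0.5, -1.0, 2.0, 0.0]` and `|v| v > 0.0`, lanes 1 and 3 fail,
so the mask is `0b1010` and the whole 4-group would go to the scalar tail.
*/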
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GuardVerdict {
    lanes: u8,
    failed_mask: u8,
}

impl GuardVerdict {
    /// A scalar (1-lane) verdict: `pass == true` ⇒ no failure.
    #[inline]
    pub fn scalar(pass: bool) -> Self {
        Self {
            lanes: 1,
            failed_mask: if pass { 0 } else { 1 },
        }
    }
    /// A 4-lane verdict from a per-lane failure mask (bit `i` ⇒ lane `i` failed).
    /// Bits above lane 3 are discarded.
    #[inline]
    pub fn lanes4(failed_mask: u8) -> Self {
        Self {
            lanes: 4,
            failed_mask: failed_mask & 0x0f,
        }
    }
    /// Number of active lanes inspected (1 scalar, 4 batch).
    #[inline]
    pub fn lanes(self) -> usize {
        self.lanes as usize
    }
    /// True iff every inspected lane satisfied the predicate.
    #[inline]
    pub fn all_pass(self) -> bool {
        self.failed_mask == 0
    }
    /// True iff at least one inspected lane failed the predicate.
    #[inline]
    pub fn any_failed(self) -> bool {
        self.failed_mask != 0
    }
    /// True iff lane `i` failed the predicate. Lanes at or beyond
    /// [`lanes`](Self::lanes) were not inspected and report `false`, which also
    /// keeps the shift in range for any `i`.
    #[inline]
    pub fn lane_failed(self, i: usize) -> bool {
        i < self.lanes() && (self.failed_mask >> i) & 1 == 1
    }
    /// The raw failure mask (bit `i` ⇒ lane `i` failed).
    #[inline]
    pub fn failed_mask(self) -> u8 {
        self.failed_mask
    }
}

/// Copy-or-zero-pad a derivative stack from length `N` to length `M`. Used by the
/// [`RowJet::compose_unary_with`] impls to bridge a program's chosen stack length
/// to each tower's native compose width ([`Tower4Lane`]: 5, [`Tower3Lane`]: 4).
/// `M ≥ N` zero-pads the unseeded high derivatives; `M < N` drops the unused tail.
/// Both cases are total, so the order-`(M−1)` tower reads exactly the channels it
/// needs and never an uninitialised entry. With `N == M` it is a verbatim copy, so
/// the common `N == 5` stack on a width-5 tower is bit-identical to passing it
/// straight through.
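/**
The resized stack feeds the compose step. For a single primary (K = 1) that
step is the univariate Faà di Bruno formula, sketched below with plain arrays
(a hypothetical helper, not this crate's multivariate implementation).
Derivative channel `n` of `f(g(x))` reads only `f′ … f⁽ⁿ⁾`, which is why an
order-`(M−1)` tower can safely ignore stack entries past index `M − 1`.

```rust
// Univariate Faà di Bruno to fourth order. `f` = [f, f′, f″, f‴, f⁗] evaluated
// at the inner value; `g` = [g, g′, g″, g‴, g⁗] of the inner function at x.
// Returns [h, h′, h″, h‴, h⁗] for h = f ∘ g.
fn compose_1d(f: [f64; 5], g: [f64; 5]) -> [f64; 5] {
    let (g1, g2, g3, g4) = (g[1], g[2], g[3], g[4]);
    [
        f[0],
        f[1] * g1,
        f[2] * g1 * g1 + f[1] * g2,
        f[3] * g1.powi(3) + 3.0 * f[2] * g1 * g2 + f[1] * g3,
        f[4] * g1.powi(4)
            + 6.0 * f[3] * g1 * g1 * g2
            + f[2] * (3.0 * g2 * g2 + 4.0 * g1 * g3)
            + f[1] * g4,
    ]
}

// The stack `RowJet::exp` builds: every derivative of e^u is e^u.
fn exp_stack(u: f64) -> [f64; 5] {
    let e = u.exp();
    [e, e, e, e, e]
}
```

For `h(x) = e^{x²}` the inner jet is `[x², 2x, 2, 0, 0]`, and composing with
`exp_stack(x * x)` reproduces
`e^{x²}·[1, 2x, 2 + 4x², 12x + 8x³, 12 + 48x² + 16x⁴]`.
*/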
#[inline]
fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
    let mut out = [0.0_f64; M];
    let m = N.min(M);
    out[..m].copy_from_slice(&s[..m]);
    out
}

/// The shared row-NLL algebra over BOTH the scalar jets and the `f64x4` lane
/// towers — the bound that lets ONE single-source row-NLL body SIMD-batch 4
/// rows/pass without a dual-source copy (module §"The RowJet bridge").
///
/// Every scalar [`crate::jet_scalar::JetScalar<K>`] is a `RowJet<K>` via the
/// blanket impl below (`Value = f64`), bit-identically to its `JetScalar`
/// methods; [`Tower3Lane`] / [`Tower4Lane`] over `f64x4` are `RowJet<K>` with
/// `Value = [f64; 4]`, routing through their per-lane methods so lane `i` of a
/// batched evaluation is `to_bits`-identical to the scalar evaluation on row `i`.
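/**
The default transcendentals below differentiate through fixed five-entry
derivative stacks. As a sanity sketch (the free-function names are
hypothetical copies of the stack closures in `sqrt`, `recip`, and `powf`), the
`sqrt` and `recip` stacks should agree with the general power-rule stack at
`a = 0.5` and `a = −1` up to rounding.

```rust
fn sqrt_stack(u: f64) -> [f64; 5] {
    let s = u.sqrt();
    [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
}

fn recip_stack(u: f64) -> [f64; 5] {
    let r = 1.0 / u;
    let r2 = r * r;
    [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
}

// Power rule: the k-th derivative of u^a is a(a−1)…(a−k+1)·u^(a−k).
fn powf_stack(u: f64, a: f64) -> [f64; 5] {
    [
        u.powf(a),
        a * u.powf(a - 1.0),
        a * (a - 1.0) * u.powf(a - 2.0),
        a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
        a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
    ]
}

// Largest relative mismatch between two stacks with nonzero reference entries.
fn max_rel_diff(x: [f64; 5], reference: [f64; 5]) -> f64 {
    x.iter()
        .zip(reference.iter())
        .map(|(p, q)| ((p - q) / q).abs())
        .fold(0.0, f64::max)
}
```
*/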
pub trait RowJet<const K: usize>: Copy {
    /// The value channel(s) seen by [`guard`](Self::guard) and
    /// [`values`](Self::values): a single `f64` on a scalar jet, `[f64; 4]` on an
    /// `f64x4` lane tower.
    type Value: Copy;

    /// A constant (value `c`, all derivatives zero), broadcast to every lane.
    fn constant(c: f64) -> Self;
    /// The seeded primary `slot` at value `x` (unit first derivative in `slot`),
    /// broadcast to every lane. Per-lane-DISTINCT seeding for the batch path is
    /// done by the lane instantiators ([`generic_batched_fourth_tower`] /
    /// [`generic_batched_third_tower`]), which build the tower variables directly
    /// from each row's primaries; this method is for any row-invariant auxiliary
    /// variable a body introduces.
    fn variable(x: f64, slot: usize) -> Self;
    /// The value channel(s): `f64` (scalar) or `[f64; 4]` (lane).
    fn values(&self) -> Self::Value;

    /// Truncated Leibniz `self + o`.
    fn add(&self, o: &Self) -> Self;
    /// Truncated Leibniz `self − o`.
    fn sub(&self, o: &Self) -> Self;
    /// Truncated Leibniz `self · o`.
    fn mul(&self, o: &Self) -> Self;
    /// Multiply every channel by the plain scalar `s`.
    fn scale(&self, s: f64) -> Self;
    /// Negate every channel. Defaults to `scale(-1.0)`; the blanket overrides it
    /// to delegate to [`crate::jet_scalar::JetScalar::neg`].
    fn neg(&self) -> Self {
        self.scale(-1.0)
    }

    /// Faà di Bruno compose with a unary special function whose `[f64; N]`
    /// derivative stack is built from the running base value PER LANE through
    /// `stack_fn`. This is the SHARED-TRAIT version of the `compose_unary_with`
    /// inherent method that already exists on both the scalar towers and the lane
    /// towers: on a scalar jet `stack_fn` is run once at the value; on an `f64x4`
    /// lane tower it is re-run per lane (the four rows carry four distinct base
    /// values), so lane `i` is `to_bits`-identical to the scalar result on row `i`.
    /// Making it a trait method is precisely what lets a body written once over
    /// `R: RowJet<K>` instantiate at the batch towers. `N` is widened or narrowed
    /// to the tower's native width by [`resize_stack`] (a verbatim copy when `N`
    /// equals that width).
    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self;

    /// Per-lane domain guard: evaluate `pred` on each active lane's value channel
    /// and report which lanes failed (see [`GuardVerdict`]). A scalar jet checks
    /// its one value; a lane tower checks all four. Lets a batched program detect
    /// an out-of-domain row in a 4-group and bail that group to the scalar tail.
    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict;

    /// Per-lane scale: multiply every channel by the per-lane factor `s`
    /// ([`Self::Value`]). On a scalar jet `Self::Value = f64`, so this is exactly
    /// [`scale`](Self::scale) and the scalar call sites stay BIT-IDENTICAL when
    /// `.scale(x)` is rewritten to `.scale_rows(x)`; on an `f64x4` lane tower
    /// `Self::Value = [f64; 4]` and lane `i` is multiplied by `s[i]`. This is the
    /// primitive that lets a batched body carry CONTINUOUS per-row data — the
    /// survival `covariance_ones` / `z_sum` / observation-weight `wi` factors that
    /// enter the jet algebra as `.scale(per_row_value)` and that the single-`f64`
    /// [`scale`](Self::scale) would broadcast wrongly across the four rows. Build
    /// `s` from the lane→row map with [`pack_rows`](Self::pack_rows).
    fn scale_rows(&self, s: Self::Value) -> Self;

    /// Gather a per-lane auxiliary datum from the lane→row map `rows`: `value_of(r)`
    /// is evaluated for each active lane's row and packed into [`Self::Value`] (a
    /// single `f64` on a scalar jet, `[f64; 4]` on an `f64x4` lane tower). This is
    /// how a body written once over [`RowJet`] feeds per-row CONTINUOUS data (the
    /// arguments to [`scale_rows`](Self::scale_rows)) into the batch path without
    /// knowing the concrete representation: the program holds the per-row data and
    /// the caller threads `rows` (length 1 scalar, length 4 batch) into
    /// [`RowNllProgramRowJet::row_nll`], so the body writes
    /// `x.scale_rows(R::pack_rows(rows, |r| self.cov(r)))`. A multiplicative weight
    /// buried in a `compose_unary_with` stack is pulled out the same way:
    /// `x.compose_unary_with(|u| stack(u, 1.0)).scale_rows(R::pack_rows(rows, |r| self.wi(r)))`.
    /// (Binary per-row branches such as the event indicator `di` are kept
    /// lane-uniform by grouping and the [`guard`](Self::guard) bail, not packed.)
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;

    // ── value-derived transcendental conveniences ───────────────────────
    // Each routes through `compose_unary_with` with the SAME derivative stack the
    // corresponding `JetScalar` method uses, so on a scalar jet (blanket) the
    // result is bit-identical to the `JetScalar` method, and on a lane tower lane
    // `i` is bit-identical to the scalar result on row `i`.

    /// `e^self`.
    fn exp(&self) -> Self {
        self.compose_unary_with(|u| {
            let e = u.exp();
            [e, e, e, e, e]
        })
    }
    /// `ln(self)`. Caller guarantees positivity.
    fn ln(&self) -> Self {
        self.compose_unary_with(|u| {
            let r = 1.0 / u;
            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
        })
    }
    /// `√self`. Caller guarantees positivity.
    fn sqrt(&self) -> Self {
        self.compose_unary_with(|u| {
            let s = u.sqrt();
            [
                s,
                0.5 / s,
                -0.25 / (u * s),
                0.375 / (u * u * s),
                -0.9375 / (u * u * u * s),
            ]
        })
    }
    /// `1/self`.
    fn recip(&self) -> Self {
        self.compose_unary_with(|u| {
            let r = 1.0 / u;
            let r2 = r * r;
            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
        })
    }
    /// `self^a` for real `a`. Caller guarantees a positive base.
    fn powf(&self, a: f64) -> Self {
        self.compose_unary_with(move |u| {
            [
                u.powf(a),
                a * u.powf(a - 1.0),
                a * (a - 1.0) * u.powf(a - 2.0),
                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
            ]
        })
    }
    /// `ln Γ(self)`. Caller guarantees a positive argument.
    fn ln_gamma(&self) -> Self {
        self.compose_unary_with(ln_gamma_derivative_stack)
    }
    /// `ψ(self)` (digamma). Caller guarantees a positive argument.
    fn digamma(&self) -> Self {
        self.compose_unary_with(digamma_derivative_stack)
    }
}

/// Blanket: every scalar [`crate::jet_scalar::JetScalar<K>`] is a [`RowJet<K>`]
/// with `Value = f64`. Each op delegates to the identical `JetScalar` method, so
/// the existing scalar call sites compile UNCHANGED and bit-identically — the
/// bridge adds the lane representation without churning the scalar path. (The
/// concrete lane impls below cannot overlap this: [`Tower3Lane`] / [`Tower4Lane`]
/// are local types that do not implement `JetScalar`, and the orphan rule forbids
/// any downstream impl, so the coherence checker proves the impls disjoint.)
impl<const K: usize, S: crate::jet_scalar::JetScalar<K>> RowJet<K> for S {
    type Value = f64;
    #[inline]
    fn constant(c: f64) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::constant(c)
    }
    #[inline]
    fn variable(x: f64, slot: usize) -> Self {
        <S as crate::jet_scalar::JetScalar<K>>::variable(x, slot)
    }
    #[inline]
    fn values(&self) -> f64 {
        crate::jet_scalar::JetScalar::value(self)
    }
    #[inline]
    fn add(&self, o: &Self) -> Self {
        crate::jet_scalar::JetScalar::add(self, o)
    }
    #[inline]
    fn sub(&self, o: &Self) -> Self {
        crate::jet_scalar::JetScalar::sub(self, o)
    }
    #[inline]
    fn mul(&self, o: &Self) -> Self {
        crate::jet_scalar::JetScalar::mul(self, o)
    }
    #[inline]
    fn scale(&self, s: f64) -> Self {
        crate::jet_scalar::JetScalar::scale(self, s)
    }
    #[inline]
    fn neg(&self) -> Self {
        crate::jet_scalar::JetScalar::neg(self)
    }
    #[inline]
    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
        crate::jet_scalar::JetScalar::compose_unary_with(self, |u| {
            resize_stack::<N, 5>(stack_fn(u))
        })
    }
    #[inline]
    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
        GuardVerdict::scalar(pred(crate::jet_scalar::JetScalar::value(self)))
    }
    #[inline]
    fn scale_rows(&self, s: f64) -> Self {
        // `Value == f64`, so per-lane scale is exactly `scale` — the rewrite
        // `.scale(x)` → `.scale_rows(x)` is bit-identical on the scalar path.
        crate::jet_scalar::JetScalar::scale(self, s)
    }
    #[inline]
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
        // A scalar jet has exactly one lane, so only `rows[0]` is read.
        value_of(rows[0])
    }
}

/// The `f64x4` lane [`Tower4Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
/// routing each op through its existing per-lane method. Lane `i` of a batched
/// evaluation is `to_bits`-identical to the scalar [`Tower4`] evaluation on row
/// `i` (the per-lane methods are term-for-term lifts of the scalar tower).
impl<const K: usize> RowJet<K> for Tower4Lane<wide::f64x4, K> {
    type Value = [f64; 4];
    #[inline]
    fn constant(c: f64) -> Self {
        Tower4Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
    }
    #[inline]
    fn variable(x: f64, slot: usize) -> Self {
        Tower4Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
    }
    #[inline]
    fn values(&self) -> [f64; 4] {
        self.v.to_array()
    }
    #[inline]
    fn add(&self, o: &Self) -> Self {
        Tower4Lane::add(self, o)
    }
    #[inline]
    fn sub(&self, o: &Self) -> Self {
        Tower4Lane::sub(self, o)
    }
    #[inline]
    fn mul(&self, o: &Self) -> Self {
        Tower4Lane::mul(self, o)
    }
    #[inline]
    fn scale(&self, s: f64) -> Self {
        Tower4Lane::scale(self, s)
    }
    #[inline]
    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
        Tower4Lane::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
    }
    #[inline]
    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
        let vals = self.v.to_array();
        let mut mask = 0u8;
        for (i, &v) in vals.iter().enumerate() {
            if !pred(v) {
                mask |= 1 << i;
            }
        }
        GuardVerdict::lanes4(mask)
    }
    #[inline]
    fn scale_rows(&self, s: [f64; 4]) -> Self {
        // True per-lane scale: lane `i` of every channel is multiplied by `s[i]`,
        // so lane `i` matches the scalar `Tower4::scale(s[i])` on row `i`.
        let sl = wide::f64x4::new(s);
        let mut out = *self;
        out.v = self.v * sl;
        for i in 0..K {
            out.g[i] = self.g[i] * sl;
            for j in 0..K {
                out.h[i][j] = self.h[i][j] * sl;
                for k in 0..K {
                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
                    for l in 0..K {
                        out.t4[i][j][k][l] = self.t4[i][j][k][l] * sl;
                    }
                }
            }
        }
        out
    }
    #[inline]
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
        [
            value_of(rows[0]),
            value_of(rows[1]),
            value_of(rows[2]),
            value_of(rows[3]),
        ]
    }
}

/// The `f64x4` lane [`Tower3Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
/// the order-≤3 sibling of the [`Tower4Lane`] impl. A body that uses `N == 5`
/// stacks drops the (unused) fourth-derivative entry here, matching the scalar
/// [`Tower3`] which also carries only up to the third tensor.
impl<const K: usize> RowJet<K> for Tower3Lane<wide::f64x4, K> {
    type Value = [f64; 4];
    #[inline]
    fn constant(c: f64) -> Self {
        Tower3Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
    }
    #[inline]
    fn variable(x: f64, slot: usize) -> Self {
        Tower3Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
    }
    #[inline]
    fn values(&self) -> [f64; 4] {
        self.v.to_array()
    }
    #[inline]
    fn add(&self, o: &Self) -> Self {
        Tower3Lane::add(self, o)
    }
    #[inline]
    fn sub(&self, o: &Self) -> Self {
        Tower3Lane::sub(self, o)
    }
    #[inline]
    fn mul(&self, o: &Self) -> Self {
        Tower3Lane::mul(self, o)
    }
    #[inline]
    fn scale(&self, s: f64) -> Self {
        Tower3Lane::scale(self, s)
    }
    #[inline]
    fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
        Tower3Lane::compose_unary_with(self, |u| resize_stack::<N, 4>(stack_fn(u)))
    }
    #[inline]
    fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
        let vals = self.v.to_array();
        let mut mask = 0u8;
        for (i, &v) in vals.iter().enumerate() {
            if !pred(v) {
                mask |= 1 << i;
            }
        }
        GuardVerdict::lanes4(mask)
    }
    #[inline]
    fn scale_rows(&self, s: [f64; 4]) -> Self {
        // True per-lane scale over the order-≤3 channels: lane `i` of every
        // channel is multiplied by `s[i]`, matching the scalar `scale(s[i])` on
        // row `i`.
        let sl = wide::f64x4::new(s);
        let mut out = *self;
        out.v = self.v * sl;
        for i in 0..K {
            out.g[i] = self.g[i] * sl;
            for j in 0..K {
                out.h[i][j] = self.h[i][j] * sl;
                for k in 0..K {
                    out.t3[i][j][k] = self.t3[i][j][k] * sl;
                }
            }
        }
        out
    }
    #[inline]
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
        [
            value_of(rows[0]),
            value_of(rows[1]),
            value_of(rows[2]),
            value_of(rows[3]),
        ]
    }
}

/// A family's row negative log-likelihood written ONCE over the [`RowJet`]
/// bridge, so the SAME body instantiates at the scalar jets (for the `(v, g, H)`
/// and contracted-tensor channels) AND at the `f64x4` lane towers (for the
/// 4-rows-per-pass SIMD batch). This is the lane-capable successor to
/// [`RowNllProgramGeneric`]: a body written here gets the scalar channels through
/// [`rowjet_row_kernel`] / [`rowjet_third_contracted`] / [`rowjet_fourth_contracted`]
/// and the batched channels through [`generic_batched_fourth_tower`] /
/// [`generic_batched_third_tower`], all from a single source.
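/**
A minimal sketch of why per-row data in such a body goes through `pack_rows`
and `scale_rows` rather than `scale`, using a hypothetical value-only stand-in
`Lanes` for a four-lane tower (the real lane towers scale every derivative
channel, not just the value): broadcasting one row's weight to all four lanes
is wrong, while packing through the lane→row map gives each lane its own row's
factor.

```rust
// Hypothetical four-lane stand-in carrying only a value per lane.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Lanes([f64; 4]);

impl Lanes {
    // Broadcast: the same factor on every lane (wrong for per-row data).
    fn scale(self, s: f64) -> Lanes {
        Lanes(self.0.map(|x| x * s))
    }
    // Per-lane: lane i is multiplied by s[i].
    fn scale_rows(self, s: [f64; 4]) -> Lanes {
        let mut out = self.0;
        for i in 0..4 {
            out[i] *= s[i];
        }
        Lanes(out)
    }
    // Gather one datum per lane through the lane→row map `rows`.
    fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
        [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
    }
}
```

With weights `[0.5, 1.0, 2.0, 4.0, 8.0, 16.0]` and `rows = [5, 0, 3, 2]`,
packing yields `[16.0, 0.5, 4.0, 2.0]`, whereas `scale(weights[rows[0]])` would
apply `16.0` to all four rows.
*/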
2417pub trait RowNllProgramRowJet<const K: usize>: Send + Sync {
2418 /// Number of observations the program covers.
2419 fn n_rows(&self) -> usize;
2420
2421 /// Current primary-scalar values for `row` (where to seed each lane).
2422 fn primaries(&self, row: usize) -> Result<[f64; K], String>;
2423
2424 /// The row NLL evaluated on the [`RowJet`] bridge. `rows` is the lane→row map
2425 /// (length 1 for a scalar instantiation, length 4 for a batch); `p[a]` arrives
2426 /// pre-seeded by the caller (base value plus, for the directional scalars, the
2427 /// nilpotent contraction directions). The body uses ONLY [`RowJet`] ops and
2428 /// per-row data entering through `rows`/`self` as constants.
2429 fn row_nll<R: RowJet<K>>(&self, rows: &[usize], p: &[R; K]) -> Result<R, String>;
2430}
2431
2432/// Evaluate a [`RowNllProgramRowJet`] at the value/gradient/Hessian scalar
2433/// [`crate::jet_scalar::Order2`] (the `(v, g, H)` inner-Newton channel) — the
2434/// `RowJet` twin of [`generic_row_kernel`].
2435pub fn rowjet_row_kernel<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2436 prog: &P,
2437 row: usize,
2438) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
2439 let base = prog.primaries(row)?;
2440 let vars: [crate::jet_scalar::Order2<K>; K] =
2441 std::array::from_fn(|a| <crate::jet_scalar::Order2<K> as RowJet<K>>::variable(base[a], a));
2442 let s = prog.row_nll(&[row], &vars)?;
2443 Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
2444}
2445
2446/// Evaluate a [`RowNllProgramRowJet`] at the one-seed scalar
2447/// [`crate::jet_scalar::OneSeed`], returning the contracted third
2448/// `Σ_c ℓ_{abc} dir_c` — the `RowJet` twin of [`generic_third_contracted`].
2449pub fn rowjet_third_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2450 prog: &P,
2451 row: usize,
2452 dir: &[f64; K],
2453) -> Result<[[f64; K]; K], String> {
2454 let base = prog.primaries(row)?;
2455 let vars: [crate::jet_scalar::OneSeed<K>; K] =
2456 std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
2457 let s = prog.row_nll(&[row], &vars)?;
2458 Ok(s.contracted_third())
2459}
2460
2461/// Evaluate a [`RowNllProgramRowJet`] at the two-seed scalar
2462/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
2463/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `RowJet` twin of [`generic_fourth_contracted`].
2464pub fn rowjet_fourth_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2465 prog: &P,
2466 row: usize,
2467 dir_u: &[f64; K],
2468 dir_v: &[f64; K],
2469) -> Result<[[f64; K]; K], String> {
2470 let base = prog.primaries(row)?;
2471 let vars: [crate::jet_scalar::TwoSeed<K>; K] =
2472 std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
2473 let s = prog.row_nll(&[row], &vars)?;
2474 Ok(s.contracted_fourth())
2475}
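/*
Sketch (illustrative only, not part of this module): the two-seed analogue. The hypothetical `Seed2<K>` carries `a + ε₁b + ε₂c + ε₁ε₂d` over an order-2 jet `J2<K>`, with `ε₁² = ε₂² = 0`. Seeding `p_a = variable(base_a, a) + ε₁u_a + ε₂v_a` leaves `Σ_{cd} ℓ_{cd} u_c v_d` in the ε₁ε₂ slot, so that slot's Hessian is `Σ_{cd} ℓ_{abcd} u_c v_d`. The real `crate::jet_scalar::TwoSeed` layout is not shown here and may differ.

```rust
/// Hypothetical order-2 jet in K variables: value, gradient, Hessian.
#[derive(Clone, Copy)]
struct J2<const K: usize> {
    v: f64,
    g: [f64; K],
    h: [[f64; K]; K],
}

impl<const K: usize> J2<K> {
    fn constant(c: f64) -> Self {
        J2 { v: c, g: [0.0; K], h: [[0.0; K]; K] }
    }
    fn variable(x: f64, idx: usize) -> Self {
        let mut o = Self::constant(x);
        o.g[idx] = 1.0;
        o
    }
    fn add(&self, o: &Self) -> Self {
        let mut out = *self;
        out.v += o.v;
        for i in 0..K {
            out.g[i] += o.g[i];
            for j in 0..K {
                out.h[i][j] += o.h[i][j];
            }
        }
        out
    }
    /// Leibniz product truncated at second order.
    fn mul(&self, o: &Self) -> Self {
        let mut out = Self::constant(self.v * o.v);
        for i in 0..K {
            out.g[i] = self.v * o.g[i] + self.g[i] * o.v;
            for j in 0..K {
                out.h[i][j] = self.v * o.h[i][j]
                    + self.g[i] * o.g[j]
                    + self.g[j] * o.g[i]
                    + self.h[i][j] * o.v;
            }
        }
        out
    }
}

/// Hypothetical two-seed scalar a + ε₁b + ε₂c + ε₁ε₂d, with ε₁² = ε₂² = 0.
#[derive(Clone, Copy)]
struct Seed2<const K: usize> {
    a: J2<K>,
    b: J2<K>,
    c: J2<K>,
    d: J2<K>,
}

impl<const K: usize> Seed2<K> {
    fn seed(x: f64, idx: usize, u: f64, v: f64) -> Self {
        Seed2 {
            a: J2::variable(x, idx),
            b: J2::constant(u),
            c: J2::constant(v),
            d: J2::constant(0.0),
        }
    }
    fn mul(&self, o: &Self) -> Self {
        Seed2 {
            a: self.a.mul(&o.a),
            b: self.a.mul(&o.b).add(&self.b.mul(&o.a)),
            c: self.a.mul(&o.c).add(&self.c.mul(&o.a)),
            // ε₁ε₂ collects every cross pairing; ε₁² and ε₂² products vanish.
            d: self.a.mul(&o.d)
                .add(&self.b.mul(&o.c))
                .add(&self.c.mul(&o.b))
                .add(&self.d.mul(&o.a)),
        }
    }
}

/// Σ_{cd} ℓ_{abcd} u_c v_d for the example row loss ℓ(p) = p0² · p1².
fn fourth_contracted_at(base: [f64; 2], u: [f64; 2], v: [f64; 2]) -> [[f64; 2]; 2] {
    let p: [Seed2<2>; 2] = std::array::from_fn(|i| Seed2::seed(base[i], i, u[i], v[i]));
    let ell = p[0].mul(&p[0]).mul(&p[1].mul(&p[1]));
    ell.d.h
}
```

For `ℓ = p0²·p1²` the exact result is `[[4·u1·v1, 4(u0·v1 + u1·v0)], [4(u0·v1 + u1·v0), 4·u0·v0]]`, independent of `base`.
*/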
2476
2477/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower4Batch`],
2478/// computing the FULL `(v, g, H, t3, t4)` for FOUR rows in one SIMD pass — the
2479/// lane twin of [`generic_full_tower`]. Each of the four lanes is seeded with its
2480/// own row's primaries, so [`Tower4Batch::lane`]`(i)` is `to_bits`-identical to
2481/// the scalar [`generic_full_tower`] on `rows[i]`.
2482pub fn generic_batched_fourth_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2483 prog: &P,
2484 rows: [usize; 4],
2485) -> Result<Tower4Batch<K>, String> {
2486 let bases: [[f64; K]; 4] = [
2487 prog.primaries(rows[0])?,
2488 prog.primaries(rows[1])?,
2489 prog.primaries(rows[2])?,
2490 prog.primaries(rows[3])?,
2491 ];
2492 let vars: [Tower4Batch<K>; K] = std::array::from_fn(|a| {
2493 let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2494 Tower4Batch::variable(lane_vals, a)
2495 });
2496 prog.row_nll(&rows, &vars)
2497}
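/*
Sketch (illustrative only, not part of this module): the row-to-lane transpose performed above, shown with `[f64; 4]` standing in for `wide::f64x4` and a hypothetical `primaries` closure standing in for `prog.primaries`. As in the original, the first failing row aborts the whole batch through `?`.

```rust
/// Gather four rows' primaries and transpose them so that `out[a][i]` is
/// primary `a` of `rows[i]` (one 4-wide lane vector per primary).
fn gather_lanes<const K: usize, F>(primaries: F, rows: [usize; 4]) -> Result<[[f64; 4]; K], String>
where
    F: Fn(usize) -> Result<[f64; K], String>,
{
    let bases: [[f64; K]; 4] = [
        primaries(rows[0])?,
        primaries(rows[1])?,
        primaries(rows[2])?,
        primaries(rows[3])?,
    ];
    Ok(std::array::from_fn(|a| {
        [bases[0][a], bases[1][a], bases[2][a], bases[3][a]]
    }))
}
```

Component-major packing is what lets each of the `K` seeded variables be one vector, so every later tower op touches all four rows at once.
*/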
2498
2499/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower3Batch`],
2500/// computing `(v, g, H, t3)` for FOUR rows in one SIMD pass — the order-≤3 lane
2501/// twin of [`generic_full_tower`]. [`Tower3Batch::lane`]`(i)` is
2502/// `to_bits`-identical to the order-≤3 scalar evaluation on `rows[i]`.
2503pub fn generic_batched_third_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2504 prog: &P,
2505 rows: [usize; 4],
2506) -> Result<Tower3Batch<K>, String> {
2507 let bases: [[f64; K]; 4] = [
2508 prog.primaries(rows[0])?,
2509 prog.primaries(rows[1])?,
2510 prog.primaries(rows[2])?,
2511 prog.primaries(rows[3])?,
2512 ];
2513 let vars: [Tower3Batch<K>; K] = std::array::from_fn(|a| {
2514 let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2515 Tower3Batch::variable(lane_vals, a)
2516 });
2517 prog.row_nll(&rows, &vars)
2518}
2519
2520// ── The oracle ───────────────────────────────────────────────────────
2521
2522/// One row's worth of hand-written kernel outputs, as claimed by a
2523/// `RowKernel` implementation, packaged for verification against the
2524/// tower truth. Plain data (no trait coupling) so any kernel — whatever
2525/// its visibility — can be audited from its own test module.
2526pub struct KernelChannels<const K: usize> {
2527 /// Claimed NLL value from `row_kernel`.
2528 pub value: f64,
2529 /// Claimed gradient.
2530 pub gradient: [f64; K],
2531 /// Claimed Hessian.
2532 pub hessian: [[f64; K]; K],
2533 /// Claimed `row_third_contracted(dir)` outputs as `(dir, claim)` pairs.
2534 pub third: Vec<([f64; K], [[f64; K]; K])>,
2535 /// Claimed `row_fourth_contracted(u, v)` outputs as `(u, v, claim)`.
2536 pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
2537}
2538
2539/// Channel-by-channel audit of a hand-written kernel against the
2540/// single-expression tower truth. On the first disagreement, returns `Err`
2541/// naming the channel, the index, and the claimed and true values. Designed
2542/// as the body of the per-family CI oracle tests (#932 deployment step 2).
2543///
2544/// Tolerance is PER ENTRY, mixed absolute/relative: each comparison uses
2545/// `|claim − truth| ≤ atol + rel_tol · max(|claim|, |truth|)`. The absolute
2546/// floor `atol = rel_tol` lets exact-zero entries of structurally sparse
2547/// towers pass without demanding bit-equality, while a dropped cross-block
2548/// entry is caught once its own magnitude exceeds roughly `atol`, however
2549/// large its neighbors (entries are NOT scaled by the channel's largest entry:
2550/// there is no per-channel magnitude floor). Genuine sign flips (#736) and dropped channels are loud.
2551///
2552/// Non-finite handling is strict: a NaN on either side always fails; an
2553/// infinity passes only when both sides are the SAME signed infinity.
2554pub fn verify_kernel_channels<const K: usize>(
2555 tower: &Tower4<K>,
2556 claims: &KernelChannels<K>,
2557 rel_tol: f64,
2558) -> Result<(), String> {
2559 // Absolute floor: reuse rel_tol so a single knob controls both the
2560 // relative band and the absolute floor for entries near zero.
2561 let atol = rel_tol;
2562 let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
2563 // Non-finite values never silently pass the algebraic comparison
2564 // below (any comparison with NaN is false). Handle them explicitly:
2565 // NaN on either side always errs; an infinity passes only if both
2566 // sides are the identical signed infinity.
2567 if !claim.is_finite() || !truth.is_finite() {
2568 let agree = claim.is_infinite()
2569 && truth.is_infinite()
2570 && claim.is_sign_positive() == truth.is_sign_positive();
2571 if agree {
2572 return Ok(());
2573 }
2574 return Err(format!(
2575 "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
2576 ));
2577 }
2578 let band = atol + rel_tol * claim.abs().max(truth.abs());
2579 if (claim - truth).abs() > band {
2580 return Err(format!(
2581 "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
2582 ));
2583 }
2584 Ok(())
2585 };
2586
2587 check("value", claims.value, tower.v)?;
2588
2589 for a in 0..K {
2590 check(&format!("gradient[{a}]"), claims.gradient[a], tower.g[a])?;
2591 }
2592
2593 for a in 0..K {
2594 for b in 0..K {
2595 check(
2596 &format!("hessian[{a}][{b}]"),
2597 claims.hessian[a][b],
2598 tower.h[a][b],
2599 )?;
2600 }
2601 }
2602
2603 for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
2604 let truth = tower.third_contracted(dir);
2605 for a in 0..K {
2606 for b in 0..K {
2607 check(
2608 &format!("third[{t_idx}][{a}][{b}]"),
2609 claim[a][b],
2610 truth[a][b],
2611 )?;
2612 }
2613 }
2614 }
2615
2616 for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
2617 let truth = tower.fourth_contracted(u, w);
2618 for a in 0..K {
2619 for b in 0..K {
2620 check(
2621 &format!("fourth[{f_idx}][{a}][{b}]"),
2622 claim[a][b],
2623 truth[a][b],
2624 )?;
2625 }
2626 }
2627 }
2628
2629 Ok(())
2630}
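/*
Sketch (illustrative only, not part of this module): the per-entry rule above as a free-standing predicate with the hypothetical name `entry_agrees`, which makes it easy to see which deviations the oracle tolerates. One of the passing cases shows the flip side of the absolute floor: a dropped entry smaller than `atol` is not flagged.

```rust
/// Per-entry agreement following the documented rule:
/// |claim - truth| <= atol + rel_tol * max(|claim|, |truth|), with atol = rel_tol,
/// plus strict non-finite handling.
fn entry_agrees(claim: f64, truth: f64, rel_tol: f64) -> bool {
    if !claim.is_finite() || !truth.is_finite() {
        // NaN never agrees; infinities agree only with the same signed infinity.
        return claim.is_infinite()
            && truth.is_infinite()
            && claim.is_sign_positive() == truth.is_sign_positive();
    }
    let atol = rel_tol;
    (claim - truth).abs() <= atol + rel_tol * claim.abs().max(truth.abs())
}
```
*/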
2631
2632// ===========================================================================
2633// SIMD row-batched towers (#1151 follow-up): Tower3Lane / Tower4Lane
2634// ===========================================================================
2635//
2636// `Tower{3,4}Lane<L: Lane, K>` re-type every channel of `Tower{3,4}<K>` from a
2637// scalar `f64` to a SIMD lane field `L`. With `L = wide::f64x4` one instance
2638// carries FOUR rows at once, so a per-row kernel (BMS `row_nll`, survival
2639// `row_kernel`, `marginal_slope` `build_row_*_towers`) can evaluate 4 rows per
2640// vector pass instead of one per scalar pass.
2641//
2642// Every floating-point op is a DIRECT, term-for-term lift of the scalar
2643// `Tower{3,4}<K>` body — `a * b` -> `a.mul(b)`, `a + b` -> `a.add(b)`, a literal
2644// `c` -> `L::splat(c)` — in the SAME accumulation order. `wide::f64x4`
2645// add/sub/mul are lane-wise IEEE-754 ops with NO fused-multiply-add (Rust
2646// performs no fp-contraction), so lane `i` of any channel of a
2647// `Tower{3,4}Lane<wide::f64x4, K>` is `to_bits`-IDENTICAL to the scalar
2648// `Tower{3,4}<K>` channel computed on row `i` — exactly the structural
2649// bit-identity the existing [`crate::jet_scalar::Order2Lane`] relies on. Confirmed
2650// by the in-tree `batch_tests` (real `wide::f64x4`) and a standalone
2651// f64x4-model oracle for `K ∈ {2,3,4,9}`.
2652//
2653// Only the pure-arithmetic ops are lifted (the transcendental `exp`/`ln`/`sqrt`/
2654// `…` route through scalar libm, which has no `f64x4` form; consumers build the
2655// per-lane derivative stack scalar-side and feed it to `compose_unary([L; _])`,
2656// exactly as the scalar path already does).
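/*
Sketch (illustrative only, not part of this module): the bit-identity argument in miniature. A hypothetical `Field` trait and a toy `Lane4([f64; 4])` stand in for `crate::jet_scalar::Lane` and `wide::f64x4`. One generic routine, written against `Field`, accumulates a Leibniz Hessian entry in a fixed order. Running it per row on `f64`, or once on `Lane4`, gives bit-identical lanes. `order_matters` shows why the accumulation ORDER, not just the set of terms, has to be preserved.

```rust
/// Minimal field abstraction (hypothetical stand-in for `Lane`).
trait Field: Copy {
    fn splat(x: f64) -> Self;
    fn add(self, o: Self) -> Self;
    fn mul(self, o: Self) -> Self;
}

impl Field for f64 {
    fn splat(x: f64) -> Self {
        x
    }
    fn add(self, o: Self) -> Self {
        self + o
    }
    fn mul(self, o: Self) -> Self {
        self * o
    }
}

/// Toy 4-lane vector: every op is an independent IEEE-754 op per lane.
#[derive(Clone, Copy)]
struct Lane4([f64; 4]);

impl Field for Lane4 {
    fn splat(x: f64) -> Self {
        Lane4([x; 4])
    }
    fn add(self, o: Self) -> Self {
        Lane4(std::array::from_fn(|i| self.0[i] + o.0[i]))
    }
    fn mul(self, o: Self) -> Self {
        Lane4(std::array::from_fn(|i| self.0[i] * o.0[i]))
    }
}

/// Hessian entry of a Leibniz product for x = [v, g_i, g_j, h_ij],
/// accumulated in a fixed order: v·h' + g_i·g_j' + g_j·g_i' + h·v'.
fn leibniz_h<F: Field>(a: [F; 4], b: [F; 4]) -> F {
    let mut acc = F::splat(0.0);
    acc = acc.add(a[0].mul(b[3]));
    acc = acc.add(a[1].mul(b[2]));
    acc = acc.add(a[2].mul(b[1]));
    acc = acc.add(a[3].mul(b[0]));
    acc
}

/// Evaluate four rows one by one on `f64` and once on `Lane4`; report
/// whether every lane is bit-identical to its scalar row.
fn lanes_match_scalar(rows_a: [[f64; 4]; 4], rows_b: [[f64; 4]; 4]) -> bool {
    // Transpose rows into component-major lane vectors.
    let pack = |rows: &[[f64; 4]; 4]| -> [Lane4; 4] {
        std::array::from_fn(|c| Lane4(std::array::from_fn(|r| rows[r][c])))
    };
    let batched = leibniz_h(pack(&rows_a), pack(&rows_b));
    (0..4).all(|r| leibniz_h(rows_a[r], rows_b[r]).to_bits() == batched.0[r].to_bits())
}

/// Same three terms, two summation orders: the results differ in the last bit.
fn order_matters() -> bool {
    let (x, y, z) = (0.1_f64, 0.2_f64, 0.3_f64);
    ((x + y) + z).to_bits() != (x + (y + z)).to_bits()
}
```
*/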
2657
2658use crate::jet_scalar::Lane;
2659
2660/// Lane-batched [`Tower4`]: value / gradient / Hessian / 3rd / 4th tensors
2661/// carried in a SIMD field `L`. `Tower4Lane<f64x4, K>` lane `i` is
2662/// `to_bits`-identical to [`Tower4<K>`] on row `i`.
2663#[derive(Clone, Copy)]
2664pub struct Tower4Lane<L: Lane, const K: usize> {
2665 /// Value channel (one entry per lane/row).
2666 pub v: L,
2667 /// Gradient `∂/∂p_a`.
2668 pub g: [L; K],
2669 /// Hessian `∂²/∂p_a∂p_b`.
2670 pub h: [[L; K]; K],
2671 /// Third tensor `∂³`.
2672 pub t3: [[[L; K]; K]; K],
2673 /// Fourth tensor `∂⁴`.
2674 pub t4: [[[[L; K]; K]; K]; K],
2675}
2676
2677/// The 4-rows-per-pass batched [`Tower4`] (`wide::f64x4` lanes).
2678pub type Tower4Batch<const K: usize> = Tower4Lane<wide::f64x4, K>;
2679
2680impl<L: Lane, const K: usize> Tower4Lane<L, K> {
2681 /// All-zero tower (every channel `+0.0` in every lane).
2682 #[inline]
2683 pub fn zero() -> Self {
2684 let z = L::splat(0.0);
2685 Self {
2686 v: z,
2687 g: [z; K],
2688 h: [[z; K]; K],
2689 t3: [[[z; K]; K]; K],
2690 t4: [[[[z; K]; K]; K]; K],
2691 }
2692 }
2693 /// Constant `c` (per lane): value channel only.
2694 #[inline]
2695 pub fn constant(c: L) -> Self {
2696 let mut o = Self::zero();
2697 o.v = c;
2698 o
2699 }
2700 /// Seeded variable `p_idx` at per-lane `value`: unit first derivative in
2701 /// slot `idx` (mirrors [`Tower4::variable`]).
2702 #[inline]
2703 pub fn variable(value: L, idx: usize) -> Self {
2704 let mut o = Self::constant(value);
2705 o.g[idx] = L::splat(1.0);
2706 o
2707 }
2708 /// Extract lane `i` as a scalar [`Tower4<K>`] (channel-for-channel).
2709 #[inline]
2710 pub fn lane(&self, i: usize) -> Tower4<K> {
2711 let mut out = Tower4::<K>::zero();
2712 out.v = self.v.lane(i);
2713 for a in 0..K {
2714 out.g[a] = self.g[a].lane(i);
2715 for b in 0..K {
2716 out.h[a][b] = self.h[a][b].lane(i);
2717 for c in 0..K {
2718 out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2719 for d in 0..K {
2720 out.t4[a][b][c][d] = self.t4[a][b][c][d].lane(i);
2721 }
2722 }
2723 }
2724 }
2725 out
2726 }
2727 /// Per-channel lane-wise `self + o` (mirrors `Tower4` `Add`).
2728 #[inline]
2729 pub fn add(&self, o: &Self) -> Self {
2730 let mut out = *self;
2731 out.v = self.v.add(o.v);
2732 for i in 0..K {
2733 out.g[i] = self.g[i].add(o.g[i]);
2734 for j in 0..K {
2735 out.h[i][j] = self.h[i][j].add(o.h[i][j]);
2736 for k in 0..K {
2737 out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
2738 for l in 0..K {
2739 out.t4[i][j][k][l] = self.t4[i][j][k][l].add(o.t4[i][j][k][l]);
2740 }
2741 }
2742 }
2743 }
2744 out
2745 }
2746 /// Per-channel lane-wise `self - o` (mirrors `Tower4` `Sub`).
2747 #[inline]
2748 pub fn sub(&self, o: &Self) -> Self {
2749 let mut out = *self;
2750 out.v = self.v.sub(o.v);
2751 for i in 0..K {
2752 out.g[i] = self.g[i].sub(o.g[i]);
2753 for j in 0..K {
2754 out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
2755 for k in 0..K {
2756 out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
2757 for l in 0..K {
2758 out.t4[i][j][k][l] = self.t4[i][j][k][l].sub(o.t4[i][j][k][l]);
2759 }
2760 }
2761 }
2762 }
2763 out
2764 }
2765 /// Multiply every channel by the plain scalar `s` (mirrors `Tower4::scale`).
2766 #[inline]
2767 pub fn scale(&self, s: f64) -> Self {
2768 let sl = L::splat(s);
2769 let mut out = *self;
2770 out.v = self.v.mul(sl);
2771 for i in 0..K {
2772 out.g[i] = self.g[i].mul(sl);
2773 for j in 0..K {
2774 out.h[i][j] = self.h[i][j].mul(sl);
2775 for k in 0..K {
2776 out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
2777 for l in 0..K {
2778 out.t4[i][j][k][l] = self.t4[i][j][k][l].mul(sl);
2779 }
2780 }
2781 }
2782 }
2783 out
2784 }
2785 /// Leibniz product `self · o`, term-for-term lift of [`Tower4::mul`].
2786 #[inline]
2787 pub fn mul(&self, o: &Self) -> Self {
2788 let a = self;
2789 let b = o;
2790 let mut out = Self::zero();
2791 out.v = a.v.mul(b.v);
2792 for i in 0..K {
2793 let mut acc = L::splat(0.0);
2794 acc = acc.add(a.v.mul(b.g[i]));
2795 acc = acc.add(a.g[i].mul(b.v));
2796 out.g[i] = acc;
2797 }
2798 // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
2799 for i in 0..K {
2800 for j in i..K {
2801 let mut acc = L::splat(0.0);
2802 acc = acc.add(a.v.mul(b.h[i][j]));
2803 acc = acc.add(a.g[i].mul(b.g[j]));
2804 acc = acc.add(a.g[j].mul(b.g[i]));
2805 acc = acc.add(a.h[i][j].mul(b.v));
2806 out.h[i][j] = acc;
2807 out.h[j][i] = acc;
2808 }
2809 }
2810 for i in 0..K {
2811 for j in 0..K {
2812 for k in 0..K {
2813 let mut acc = L::splat(0.0);
2814 acc = acc.add(a.v.mul(b.t3[i][j][k]));
2815 acc = acc.add(a.g[i].mul(b.h[j][k]));
2816 acc = acc.add(a.g[j].mul(b.h[i][k]));
2817 acc = acc.add(a.h[i][j].mul(b.g[k]));
2818 acc = acc.add(a.g[k].mul(b.h[i][j]));
2819 acc = acc.add(a.h[i][k].mul(b.g[j]));
2820 acc = acc.add(a.h[j][k].mul(b.g[i]));
2821 acc = acc.add(a.t3[i][j][k].mul(b.v));
2822 out.t3[i][j][k] = acc;
2823 }
2824 }
2825 }
2826 for i in 0..K {
2827 for j in 0..K {
2828 for k in 0..K {
2829 for l in 0..K {
2830 let mut acc = L::splat(0.0);
2831 acc = acc.add(a.v.mul(b.t4[i][j][k][l]));
2832 acc = acc.add(a.g[i].mul(b.t3[j][k][l]));
2833 acc = acc.add(a.g[j].mul(b.t3[i][k][l]));
2834 acc = acc.add(a.h[i][j].mul(b.h[k][l]));
2835 acc = acc.add(a.g[k].mul(b.t3[i][j][l]));
2836 acc = acc.add(a.h[i][k].mul(b.h[j][l]));
2837 acc = acc.add(a.h[j][k].mul(b.h[i][l]));
2838 acc = acc.add(a.t3[i][j][k].mul(b.g[l]));
2839 acc = acc.add(a.g[l].mul(b.t3[i][j][k]));
2840 acc = acc.add(a.h[i][l].mul(b.h[j][k]));
2841 acc = acc.add(a.h[j][l].mul(b.h[i][k]));
2842 acc = acc.add(a.t3[i][j][l].mul(b.g[k]));
2843 acc = acc.add(a.h[k][l].mul(b.h[i][j]));
2844 acc = acc.add(a.t3[i][k][l].mul(b.g[j]));
2845 acc = acc.add(a.t3[j][k][l].mul(b.g[i]));
2846 acc = acc.add(a.t4[i][j][k][l].mul(b.v));
2847 out.t4[i][j][k][l] = acc;
2848 }
2849 }
2850 }
2851 }
2852 out
2853 }
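/*
Sketch (illustrative only, not part of this module): with a single primary (`K = 1`) every index coincides, so the 16 subset terms of the fourth-order Leibniz rule above collapse to the binomial weights `1, 4, 6, 4, 1` (16 = 2⁴). The hypothetical univariate helper below checks that collapse against the closed form `(eˣ sin x)⁗ = −4 eˣ sin x`.

```rust
/// Univariate tower [f, f', f'', f''', f''''] at one point (K = 1).
type Tower1D = [f64; 5];

/// Leibniz product (ab)^(n) = Σ_m C(n, m) a^(m) b^(n-m), for n ≤ 4.
fn leibniz_mul(a: &Tower1D, b: &Tower1D) -> Tower1D {
    // Pascal's triangle, rows 0..=4.
    const C: [[f64; 5]; 5] = [
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 0.0, 0.0],
        [1.0, 2.0, 1.0, 0.0, 0.0],
        [1.0, 3.0, 3.0, 1.0, 0.0],
        [1.0, 4.0, 6.0, 4.0, 1.0],
    ];
    std::array::from_fn(|n| (0..=n).map(|m| C[n][m] * a[m] * b[n - m]).sum::<f64>())
}

/// Exact derivative stacks of exp and sin at x.
fn exp_tower(x: f64) -> Tower1D {
    [x.exp(); 5]
}
fn sin_tower(x: f64) -> Tower1D {
    [x.sin(), x.cos(), -x.sin(), -x.cos(), x.sin()]
}
```
*/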
2854 /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
2855 /// [`Tower4::compose_unary`]. `d = [f, f′, f″, f‴, f⁗]` packed per lane
2856 /// (build via [`Lane::unary5`] from the scalar special-function stack).
2857 #[inline]
2858 pub fn compose_unary(&self, d: [L; 5]) -> Self {
2859 let mut out = Self::zero();
2860 out.v = d[0];
2861 for i in 0..K {
2862 let mut acc = L::splat(0.0);
2863 acc = acc.add(d[1].mul(self.g[i]));
2864 out.g[i] = acc;
2865 }
2866 for i in 0..K {
2867 for j in 0..K {
2868 let mut acc = L::splat(0.0);
2869 acc = acc.add(d[1].mul(self.h[i][j]));
2870 acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
2871 out.h[i][j] = acc;
2872 }
2873 }
2874 for i in 0..K {
2875 for j in 0..K {
2876 for k in 0..K {
2877 let mut acc = L::splat(0.0);
2878 acc = acc.add(d[1].mul(self.t3[i][j][k]));
2879 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
2880 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
2881 acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
2882 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
2883 out.t3[i][j][k] = acc;
2884 }
2885 }
2886 }
2887 for i in 0..K {
2888 for j in 0..K {
2889 for k in 0..K {
2890 for l in 0..K {
2891 let mut acc = L::splat(0.0);
2892 acc = acc.add(d[1].mul(self.t4[i][j][k][l]));
2893 acc = acc.add(d[2].mul(self.t3[i][j][k]).mul(self.g[l]));
2894 acc = acc.add(d[2].mul(self.t3[i][j][l]).mul(self.g[k]));
2895 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.h[k][l]));
2896 acc = acc.add(d[3].mul(self.h[i][j]).mul(self.g[k]).mul(self.g[l]));
2897 acc = acc.add(d[2].mul(self.t3[i][k][l]).mul(self.g[j]));
2898 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.h[j][l]));
2899 acc = acc.add(d[3].mul(self.h[i][k]).mul(self.g[j]).mul(self.g[l]));
2900 acc = acc.add(d[2].mul(self.h[i][l]).mul(self.h[j][k]));
2901 acc = acc.add(d[2].mul(self.g[i]).mul(self.t3[j][k][l]));
2902 acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][k]).mul(self.g[l]));
2903 acc = acc.add(d[3].mul(self.h[i][l]).mul(self.g[j]).mul(self.g[k]));
2904 acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][l]).mul(self.g[k]));
2905 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.h[k][l]));
2906 acc = acc.add(
2907 d[4].mul(self.g[i])
2908 .mul(self.g[j])
2909 .mul(self.g[k])
2910 .mul(self.g[l]),
2911 );
2912 out.t4[i][j][k][l] = acc;
2913 }
2914 }
2915 }
2916 }
2917 out
2918 }
2919 /// Compose with a unary special-function whose `[f64; 5]` derivative stack is
2920 /// built from the base value through `stack_fn`, evaluated PER LANE — the
2921 /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
2922 /// seam (the SIMD twin of [`Tower4::compose_unary_with`]).
2923 ///
2924 /// Each lane (four for `f64x4`) carries its OWN base value, so the scalar
2925 /// `stack_fn` is run once per lane at that lane's value (via
2926 /// [`Lane::unary_with`]) and the `[f64; 5]` results are packed into `[L; 5]`;
2927 /// the composition is then the existing per-lane [`Self::compose_unary`].
2928 /// Because `unary_with` runs the identical scalar closure per lane and
2929 /// `compose_unary` is a term-for-term lift of the scalar tower, lane `i` of
2930 /// the result is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`
2931 /// — which is exactly what lets a row program written against the scalar
2932 /// [`Tower4::compose_unary_with`] seam re-instantiate, unchanged, at `f64x4`.
2933 #[inline]
2934 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
2935 self.compose_unary(self.v.unary_with(stack_fn))
2936 }
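/*
Sketch (illustrative only, not part of this module): what per-lane stack evaluation amounts to. It uses `[f64; 4]` in place of `wide::f64x4` and a hypothetical `unary_with_4` in place of `Lane::unary_with`, whose real implementation is not shown here. The closure runs once per lane on that lane's own value. The four `[f64; 5]` results are then repacked component-major, the `[L; 5]` shape `compose_unary` consumes.

```rust
/// Run a scalar derivative-stack closure once per lane and repack the
/// results component-major: out[c][i] = stack_fn(lanes[i])[c].
fn unary_with_4(lanes: [f64; 4], stack_fn: impl Fn(f64) -> [f64; 5]) -> [[f64; 4]; 5] {
    let per_lane: [[f64; 5]; 4] = std::array::from_fn(|i| stack_fn(lanes[i]));
    std::array::from_fn(|c| std::array::from_fn(|i| per_lane[i][c]))
}

/// Derivative stack of ln at x > 0: [ln x, 1/x, -1/x², 2/x³, -6/x⁴].
fn ln_stack(x: f64) -> [f64; 5] {
    [
        x.ln(),
        1.0 / x,
        -1.0 / (x * x),
        2.0 / (x * x * x),
        -6.0 / (x * x * x * x),
    ]
}
```

Because the same scalar closure is called per lane, each packed entry is bit-for-bit the scalar stack value for that lane.
*/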
2937
2938 /// Single-active-slot fast path (term-for-term lift of [`Tower4::compose_unary_single_slot`]):
2939 /// fills only the 5 `slot`-diagonal channels, zeroing the rest; exact when `self` depends on `slot` alone.
2940 #[inline]
2941 pub fn compose_unary_single_slot(&self, d: [L; 5], slot: usize) -> Self {
2942 let mut out = Self::zero();
2943 let s = slot;
2944 let g = self.g[s];
2945 let h = self.h[s][s];
2946 let t3 = self.t3[s][s][s];
2947 let t4 = self.t4[s][s][s][s];
2948 out.v = d[0];
2949 out.g[s] = {
2950 let mut acc = L::splat(0.0);
2951 acc = acc.add(d[1].mul(g));
2952 acc
2953 };
2954 out.h[s][s] = {
2955 let mut acc = L::splat(0.0);
2956 acc = acc.add(d[1].mul(h));
2957 acc = acc.add(d[2].mul(g).mul(g));
2958 acc
2959 };
2960 out.t3[s][s][s] = {
2961 let mut acc = L::splat(0.0);
2962 acc = acc.add(d[1].mul(t3));
2963 acc = acc.add(d[2].mul(h).mul(g));
2964 acc = acc.add(d[2].mul(h).mul(g));
2965 acc = acc.add(d[2].mul(g).mul(h));
2966 acc = acc.add(d[3].mul(g).mul(g).mul(g));
2967 acc
2968 };
2969 out.t4[s][s][s][s] = {
2970 let mut acc = L::splat(0.0);
2971 acc = acc.add(d[1].mul(t4));
2972 acc = acc.add(d[2].mul(t3).mul(g));
2973 acc = acc.add(d[2].mul(t3).mul(g));
2974 acc = acc.add(d[2].mul(h).mul(h));
2975 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2976 acc = acc.add(d[2].mul(t3).mul(g));
2977 acc = acc.add(d[2].mul(h).mul(h));
2978 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2979 acc = acc.add(d[2].mul(h).mul(h));
2980 acc = acc.add(d[2].mul(g).mul(t3));
2981 acc = acc.add(d[3].mul(g).mul(h).mul(g));
2982 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2983 acc = acc.add(d[3].mul(g).mul(h).mul(g));
2984 acc = acc.add(d[3].mul(g).mul(g).mul(h));
2985 acc = acc.add(d[4].mul(g).mul(g).mul(g).mul(g));
2986 acc
2987 };
2988 out
2989 }
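/*
Sketch (illustrative only, not part of this module): the univariate form of the fourth-order Faà di Bruno rule. In the single-slot body above, the 15 set-partition terms reduce to `f′·t4`, `4·f″·t3·g`, `3·f″·h²`, `6·f‴·h·g²` and `f⁗·g⁴` (1 + 4 + 3 + 6 + 1 = 15). The hypothetical helper below writes those weights directly. It checks them on `exp(x²)`, whose fourth derivative is `e^{x²}(16x⁴ + 48x² + 12)`.

```rust
/// Univariate Faà di Bruno to fourth order: derivatives of f∘g at x, given
/// d = [f, f', f'', f''', f''''] at g(x) and g = [g, g', g'', g''', g''''] at x.
fn faa_di_bruno(d: [f64; 5], g: [f64; 5]) -> [f64; 5] {
    let (g1, g2, g3, g4) = (g[1], g[2], g[3], g[4]);
    [
        d[0],
        d[1] * g1,
        d[1] * g2 + d[2] * g1 * g1,
        d[1] * g3 + 3.0 * d[2] * g2 * g1 + d[3] * g1 * g1 * g1,
        d[1] * g4
            + d[2] * (4.0 * g3 * g1 + 3.0 * g2 * g2)
            + 6.0 * d[3] * g2 * g1 * g1
            + d[4] * g1 * g1 * g1 * g1,
    ]
}

/// Tower of exp(x²): inner g = x², outer stack of exp evaluated at x².
fn exp_of_square(x: f64) -> [f64; 5] {
    let g = [x * x, 2.0 * x, 2.0, 0.0, 0.0];
    faa_di_bruno([g[0].exp(); 5], g)
}
```
*/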
2990 /// Contract `t3` with a primary-space direction (lift of [`Tower4::third_contracted`]).
2991 /// Computes the `(a, b)` upper triangle and mirrors it; lane `i` is bit-identical to the scalar
2992 /// version on row `i`, and matches the full nest exactly when `t3` is symmetric in `(a, b)`.
2993 #[inline]
2994 pub fn third_contracted(&self, dir: &[L; K]) -> [[L; K]; K] {
2995 let mut out = [[L::splat(0.0); K]; K];
2996 for a in 0..K {
2997 for b in a..K {
2998 let mut acc = L::splat(0.0);
2999 for c in 0..K {
3000 acc = acc.add(self.t3[a][b][c].mul(dir[c]));
3001 }
3002 out[a][b] = acc;
3003 out[b][a] = acc;
3004 }
3005 }
3006 out
3007 }
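/*
Sketch (illustrative only, not part of this module): mirroring the upper triangle reproduces the full `(a, b)` nest bit-for-bit when `t3` is symmetric in its first two indices. Otherwise it silently symmetrizes the result. The hypothetical scalar helpers below show both cases; `sym_t3` builds a symmetric tensor as `w_a·w_b·w_c`.

```rust
/// Upper triangle over (a, b), mirrored: out[a][b] = out[b][a] = Σ_c t3[a][b][c]·dir[c].
fn third_mirrored<const K: usize>(t3: &[[[f64; K]; K]; K], dir: &[f64; K]) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for a in 0..K {
        for b in a..K {
            let mut acc = 0.0;
            for c in 0..K {
                acc += t3[a][b][c] * dir[c];
            }
            out[a][b] = acc;
            out[b][a] = acc;
        }
    }
    out
}

/// Reference full nest over every (a, b), same summation order over c.
fn third_full<const K: usize>(t3: &[[[f64; K]; K]; K], dir: &[f64; K]) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for a in 0..K {
        for b in 0..K {
            let mut acc = 0.0;
            for c in 0..K {
                acc += t3[a][b][c] * dir[c];
            }
            out[a][b] = acc;
        }
    }
    out
}

/// Symmetric test tensor t3[a][b][c] = w[a]·w[b]·w[c]; bitwise symmetric in
/// (a, b) because f64 multiplication is commutative.
fn sym_t3<const K: usize>(w: [f64; K]) -> [[[f64; K]; K]; K] {
    std::array::from_fn(|a| std::array::from_fn(|b| std::array::from_fn(|c| w[a] * w[b] * w[c])))
}
```
*/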
3008 /// Contract `t4` with two primary-space directions (lift of [`Tower4::fourth_contracted`]).
3009 /// Computes the `(i, j)` upper triangle and mirrors it; lane `i` is bit-identical to the scalar
3010 /// version on that row, and matches the full nest exactly when `t4` is symmetric in `(i, j)`.
3011 #[inline]
3012 pub fn fourth_contracted(&self, u: &[L; K], w: &[L; K]) -> [[L; K]; K] {
3013 let mut out = [[L::splat(0.0); K]; K];
3014 for i in 0..K {
3015 for j in i..K {
3016 let mut acc = L::splat(0.0);
3017 for k in 0..K {
3018 for l in 0..K {
3019 acc = acc.add(self.t4[i][j][k][l].mul(u[k]).mul(w[l]));
3020 }
3021 }
3022 out[i][j] = acc;
3023 out[j][i] = acc;
3024 }
3025 }
3026 out
3027 }
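/*
Sketch (illustrative only, not part of this module): a hand-checkable contraction. For the hypothetical row loss `ℓ(p) = p0²·p1²` the fourth tensor is constant. It equals 4 on every permutation of `(0, 0, 1, 1)` and 0 elsewhere, so `Σ_{kl} ℓ_{ijkl} u_k w_l` is `[[4·u1·w1, 4(u0·w1 + u1·w0)], [4(u0·w1 + u1·w0), 4·u0·w0]]`. The scalar helper mirrors the upper-triangle loop above, including its `(t4·u)·w` multiplication order.

```rust
/// Σ_{kl} t4[i][j][k][l]·u[k]·w[l], upper triangle over (i, j), then mirrored.
fn fourth_contracted<const K: usize>(
    t4: &[[[[f64; K]; K]; K]; K],
    u: &[f64; K],
    w: &[f64; K],
) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for i in 0..K {
        for j in i..K {
            let mut acc = 0.0;
            for k in 0..K {
                for l in 0..K {
                    acc += t4[i][j][k][l] * u[k] * w[l];
                }
            }
            out[i][j] = acc;
            out[j][i] = acc;
        }
    }
    out
}

/// Fourth-derivative tensor of ℓ(p) = p0²·p1²: 4 where exactly two indices are 0.
fn t4_p0sq_p1sq() -> [[[[f64; 2]; 2]; 2]; 2] {
    std::array::from_fn(|i| {
        std::array::from_fn(|j| {
            std::array::from_fn(|k| {
                std::array::from_fn(|l| {
                    let zeros = [i, j, k, l].iter().filter(|&&x| x == 0).count();
                    if zeros == 2 { 4.0 } else { 0.0 }
                })
            })
        })
    })
}
```
*/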
3028}
3029
3030/// Lane-batched [`Tower3`] (order-≤3 sibling of [`Tower4Lane`]).
3031#[derive(Clone, Copy)]
3032pub struct Tower3Lane<L: Lane, const K: usize> {
3033 /// Value channel.
3034 pub v: L,
3035 /// Gradient.
3036 pub g: [L; K],
3037 /// Hessian.
3038 pub h: [[L; K]; K],
3039 /// Third tensor.
3040 pub t3: [[[L; K]; K]; K],
3041}
3042
3043/// The 4-rows-per-pass batched [`Tower3`] (`wide::f64x4` lanes).
3044pub type Tower3Batch<const K: usize> = Tower3Lane<wide::f64x4, K>;
3045
3046impl<L: Lane, const K: usize> Tower3Lane<L, K> {
3047 /// All-zero tower.
3048 #[inline]
3049 pub fn zero() -> Self {
3050 let z = L::splat(0.0);
3051 Self {
3052 v: z,
3053 g: [z; K],
3054 h: [[z; K]; K],
3055 t3: [[[z; K]; K]; K],
3056 }
3057 }
3058 /// Constant `c` (per lane).
3059 #[inline]
3060 pub fn constant(c: L) -> Self {
3061 let mut o = Self::zero();
3062 o.v = c;
3063 o
3064 }
3065 /// Seeded variable `p_idx` at per-lane `value`.
3066 #[inline]
3067 pub fn variable(value: L, idx: usize) -> Self {
3068 let mut o = Self::constant(value);
3069 o.g[idx] = L::splat(1.0);
3070 o
3071 }
3072 /// Extract lane `i` as a scalar [`Tower3<K>`].
3073 #[inline]
3074 pub fn lane(&self, i: usize) -> Tower3<K> {
3075 let mut out = Tower3::<K>::zero();
3076 out.v = self.v.lane(i);
3077 for a in 0..K {
3078 out.g[a] = self.g[a].lane(i);
3079 for b in 0..K {
3080 out.h[a][b] = self.h[a][b].lane(i);
3081 for c in 0..K {
3082 out.t3[a][b][c] = self.t3[a][b][c].lane(i);
3083 }
3084 }
3085 }
3086 out
3087 }
3088 /// Per-channel lane-wise `self + o`.
3089 #[inline]
3090 pub fn add(&self, o: &Self) -> Self {
3091 let mut out = *self;
3092 out.v = self.v.add(o.v);
3093 for i in 0..K {
3094 out.g[i] = self.g[i].add(o.g[i]);
3095 for j in 0..K {
3096 out.h[i][j] = self.h[i][j].add(o.h[i][j]);
3097 for k in 0..K {
3098 out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
3099 }
3100 }
3101 }
3102 out
3103 }
3104 /// Per-channel lane-wise `self - o`.
3105 #[inline]
3106 pub fn sub(&self, o: &Self) -> Self {
3107 let mut out = *self;
3108 out.v = self.v.sub(o.v);
3109 for i in 0..K {
3110 out.g[i] = self.g[i].sub(o.g[i]);
3111 for j in 0..K {
3112 out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
3113 for k in 0..K {
3114 out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
3115 }
3116 }
3117 }
3118 out
3119 }
3120 /// Multiply every channel by the plain scalar `s` (mirrors `Tower3::scale`).
3121 #[inline]
3122 pub fn scale(&self, s: f64) -> Self {
3123 let sl = L::splat(s);
3124 let mut out = *self;
3125 out.v = self.v.mul(sl);
3126 for i in 0..K {
3127 out.g[i] = self.g[i].mul(sl);
3128 for j in 0..K {
3129 out.h[i][j] = self.h[i][j].mul(sl);
3130 for k in 0..K {
3131 out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
3132 }
3133 }
3134 }
3135 out
3136 }
3137 /// Leibniz product `self · o`, term-for-term lift of [`Tower3::mul`].
3138 #[inline]
3139 pub fn mul(&self, o: &Self) -> Self {
3140 let a = self;
3141 let b = o;
3142 let mut out = Self::zero();
3143 out.v = a.v.mul(b.v);
3144 for i in 0..K {
3145 let mut acc = L::splat(0.0);
3146 acc = acc.add(a.v.mul(b.g[i]));
3147 acc = acc.add(a.g[i].mul(b.v));
3148 out.g[i] = acc;
3149 }
3150 // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
3151 for i in 0..K {
3152 for j in i..K {
3153 let mut acc = L::splat(0.0);
3154 acc = acc.add(a.v.mul(b.h[i][j]));
3155 acc = acc.add(a.g[i].mul(b.g[j]));
3156 acc = acc.add(a.g[j].mul(b.g[i]));
3157 acc = acc.add(a.h[i][j].mul(b.v));
3158 out.h[i][j] = acc;
3159 out.h[j][i] = acc;
3160 }
3161 }
3162 for i in 0..K {
3163 for j in 0..K {
3164 for k in 0..K {
3165 let mut acc = L::splat(0.0);
3166 acc = acc.add(a.v.mul(b.t3[i][j][k]));
3167 acc = acc.add(a.g[i].mul(b.h[j][k]));
3168 acc = acc.add(a.g[j].mul(b.h[i][k]));
3169 acc = acc.add(a.h[i][j].mul(b.g[k]));
3170 acc = acc.add(a.g[k].mul(b.h[i][j]));
3171 acc = acc.add(a.h[i][k].mul(b.g[j]));
3172 acc = acc.add(a.h[j][k].mul(b.g[i]));
3173 acc = acc.add(a.t3[i][j][k].mul(b.v));
3174 out.t3[i][j][k] = acc;
3175 }
3176 }
3177 }
3178 out
3179 }
3180 /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
3181 /// [`Tower3::compose_unary`]. `d = [f, f′, f″, f‴]` packed per lane.
3182 #[inline]
3183 pub fn compose_unary(&self, d: [L; 4]) -> Self {
3184 let mut out = Self::zero();
3185 out.v = d[0];
3186 for i in 0..K {
3187 let mut acc = L::splat(0.0);
3188 acc = acc.add(d[1].mul(self.g[i]));
3189 out.g[i] = acc;
3190 }
3191 for i in 0..K {
3192 for j in 0..K {
3193 let mut acc = L::splat(0.0);
3194 acc = acc.add(d[1].mul(self.h[i][j]));
3195 acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
3196 out.h[i][j] = acc;
3197 }
3198 }
3199 for i in 0..K {
3200 for j in 0..K {
3201 for k in 0..K {
3202 let mut acc = L::splat(0.0);
3203 acc = acc.add(d[1].mul(self.t3[i][j][k]));
3204 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
3205 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
3206 acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
3207 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
3208 out.t3[i][j][k] = acc;
3209 }
3210 }
3211 }
3212 out
3213 }
3214 /// Compose with a unary special-function whose `[f64; 4]` derivative stack is
3215 /// built from the base value through `stack_fn`, evaluated PER LANE — the
3216 /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
3217 /// seam (the SIMD twin of [`Tower3::compose_unary_with`], order-≤3 sibling of
3218 /// [`Tower4Lane::compose_unary_with`]). The scalar `stack_fn` is run once per
3219 /// lane at that lane's own base value (via [`Lane::unary_with`]) and packed
3220 /// into `[L; 4]` for the existing per-lane [`Self::compose_unary`], so lane
3221 /// `i` is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`.
3222 #[inline]
3223 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
3224 self.compose_unary(self.v.unary_with(stack_fn))
3225 }
3226
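3227 /// Single-active-slot fast path (term-for-term lift of [`Tower3::compose_unary_single_slot`]):
3228 /// fills only the 4 `slot`-diagonal channels, zeroing the rest; exact when `self` depends on `slot` alone.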
3229 #[inline]
3230 pub fn compose_unary_single_slot(&self, d: [L; 4], slot: usize) -> Self {
3231 let mut out = Self::zero();
3232 let s = slot;
3233 let g = self.g[s];
3234 let h = self.h[s][s];
3235 let t3 = self.t3[s][s][s];
3236 out.v = d[0];
3237 out.g[s] = {
3238 let mut acc = L::splat(0.0);
3239 acc = acc.add(d[1].mul(g));
3240 acc
3241 };
3242 out.h[s][s] = {
3243 let mut acc = L::splat(0.0);
3244 acc = acc.add(d[1].mul(h));
3245 acc = acc.add(d[2].mul(g).mul(g));
3246 acc
3247 };
3248 out.t3[s][s][s] = {
3249 let mut acc = L::splat(0.0);
3250 acc = acc.add(d[1].mul(t3));
3251 acc = acc.add(d[2].mul(h).mul(g));
3252 acc = acc.add(d[2].mul(h).mul(g));
3253 acc = acc.add(d[2].mul(g).mul(h));
3254 acc = acc.add(d[3].mul(g).mul(g).mul(g));
3255 acc
3256 };
3257 out
3258 }
3259}
3260
3261#[cfg(test)]
3262mod batch_tests {
3263 use super::*;
3264
3265 struct Rng(u64);
3266 impl Rng {
3267 fn f(&mut self) -> f64 {
3268 self.0 = self
3269 .0
3270 .wrapping_mul(6364136223846793005)
3271 .wrapping_add(1442695040888963407);
3272 ((self.0 >> 11) as f64 / (1u64 << 53) as f64) * 4.0 - 2.0
3273 }
3274 }
3275
3276 // Fill every channel of a scalar Tower4<K> with random data.
3277 fn rand_t4<const K: usize>(r: &mut Rng) -> Tower4<K> {
3278 let mut t = Tower4::<K>::zero();
3279 t.v = r.f();
3280 for i in 0..K {
3281 t.g[i] = r.f();
3282 for j in 0..K {
3283 t.h[i][j] = r.f();
3284 for k in 0..K {
3285 t.t3[i][j][k] = r.f();
3286 for l in 0..K {
3287 t.t4[i][j][k][l] = r.f();
3288 }
3289 }
3290 }
3291 }
3292 t
3293 }
3294 fn rand_t3<const K: usize>(r: &mut Rng) -> Tower3<K> {
3295 let mut t = Tower3::<K>::zero();
3296 t.v = r.f();
3297 for i in 0..K {
3298 t.g[i] = r.f();
3299 for j in 0..K {
3300 t.h[i][j] = r.f();
3301 for k in 0..K {
3302 t.t3[i][j][k] = r.f();
3303 }
3304 }
3305 }
3306 t
3307 }
3308 fn pack4_t4<const K: usize>(rows: &[Tower4<K>; 4]) -> Tower4Batch<K> {
3309 let mut b = Tower4Batch::<K>::zero();
3310 let lane = |f: &dyn Fn(&Tower4<K>) -> f64| {
3311 wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3312 };
3313 b.v = lane(&|t| t.v);
3314 for i in 0..K {
3315 b.g[i] = lane(&|t| t.g[i]);
3316 for j in 0..K {
3317 b.h[i][j] = lane(&|t| t.h[i][j]);
3318 for k in 0..K {
3319 b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3320 for l in 0..K {
3321 b.t4[i][j][k][l] = lane(&|t| t.t4[i][j][k][l]);
3322 }
3323 }
3324 }
3325 }
3326 b
3327 }
3328 fn pack4_t3<const K: usize>(rows: &[Tower3<K>; 4]) -> Tower3Batch<K> {
3329 let mut b = Tower3Batch::<K>::zero();
3330 let lane = |f: &dyn Fn(&Tower3<K>) -> f64| {
3331 wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3332 };
3333 b.v = lane(&|t| t.v);
3334 for i in 0..K {
3335 b.g[i] = lane(&|t| t.g[i]);
3336 for j in 0..K {
3337 b.h[i][j] = lane(&|t| t.h[i][j]);
3338 for k in 0..K {
3339 b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3340 }
3341 }
3342 }
3343 b
3344 }
3345 fn assert_t4_eq<const K: usize>(b: &Tower4<K>, s: &Tower4<K>, ctx: &str) {
3346 assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3347 for i in 0..K {
3348 assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3349 for j in 0..K {
3350 assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3351 for k in 0..K {
3352 assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3353 for l in 0..K {
3354 assert_eq!(
3355 b.t4[i][j][k][l].to_bits(),
3356 s.t4[i][j][k][l].to_bits(),
3357 "t4 {ctx}"
3358 );
3359 }
3360 }
3361 }
3362 }
3363 }
3364 fn assert_t3_eq<const K: usize>(b: &Tower3<K>, s: &Tower3<K>, ctx: &str) {
3365 assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3366 for i in 0..K {
3367 assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3368 for j in 0..K {
3369 assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3370 for k in 0..K {
3371 assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3372 }
3373 }
3374 }
3375 }
3376
3377 // Run a representative op chain on 4 scalar rows and on the f64x4 batch,
3378 // then assert every channel of every lane is to_bits-identical.
3379 fn run4<const K: usize>(seed: u64, batches: usize) -> usize {
3380 let mut r = Rng(seed);
3381 let mut rows_checked = 0;
3382 for _ in 0..batches {
3383 let a: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3384 let b: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3385 let d: [[f64; 5]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3386 let dir: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3387 let dir2: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3388 let s = r.f();
3389
3390 // scalar per-row reference
3391 let scal: [Tower4<K>; 4] = std::array::from_fn(|rw| {
3392 let prod = a[rw].mul(&b[rw]);
3393 let comp = prod.compose_unary(d[rw]);
3394 let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3395 summed.compose_unary_single_slot(d[rw], 0)
3396 });
3397 let third: [[[f64; K]; K]; 4] =
3398 std::array::from_fn(|rw| a[rw].third_contracted(&dir[rw]));
3399 let fourth: [[[f64; K]; K]; 4] =
3400 std::array::from_fn(|rw| a[rw].fourth_contracted(&dir[rw], &dir2[rw]));
3401
3402 // batched f64x4
3403 let ab = pack4_t4(&a);
3404 let bb = pack4_t4(&b);
3405 let db: [wide::f64x4; 5] =
3406 std::array::from_fn(|c| wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]]));
3407 let dirb: [wide::f64x4; K] = std::array::from_fn(|c| {
3408 wide::f64x4::new([dir[0][c], dir[1][c], dir[2][c], dir[3][c]])
3409 });
3410 let dir2b: [wide::f64x4; K] = std::array::from_fn(|c| {
3411 wide::f64x4::new([dir2[0][c], dir2[1][c], dir2[2][c], dir2[3][c]])
3412 });
3413 let prodb = ab.mul(&bb);
3414 let compb = prodb.compose_unary(db);
3415 let summedb = compb.add(&ab).sub(&bb).scale(s);
3416 let finalb = summedb.compose_unary_single_slot(db, 0);
3417 let thirdb = ab.third_contracted(&dirb);
3418 let fourthb = ab.fourth_contracted(&dirb, &dir2b);
3419
3420 for rw in 0..4 {
3421 assert_t4_eq(&finalb.lane(rw), &scal[rw], "t4-chain");
3422 for i in 0..K {
3423 for j in 0..K {
3424 assert_eq!(
3425 thirdb[i][j].lane(rw).to_bits(),
3426 third[rw][i][j].to_bits(),
3427 "third"
3428 );
3429 assert_eq!(
3430 fourthb[i][j].lane(rw).to_bits(),
3431 fourth[rw][i][j].to_bits(),
3432 "fourth"
3433 );
3434 }
3435 }
3436 rows_checked += 1;
3437 }
3438 }
3439 rows_checked
3440 }
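
// Minimal sketch (hypothetical helpers, plain arrays standing in for
// `wide::f64x4`): packing four rows into lanes, as `run4` does for `d`, `dir`
// and `dir2`, is a transpose from row-major `[[f64; N]; 4]` to lane-major
// `[[f64; 4]; N]`. Lane `i` of channel `c` is row `i`'s entry `c`, and no
// arithmetic touches the values, so unpacking lane `i` recovers row `i`
// bit-for-bit (signed zeros included).
fn sketch_pack_lanes<const N: usize>(rows: &[[f64; N]; 4]) -> [[f64; 4]; N] {
    std::array::from_fn(|c| [rows[0][c], rows[1][c], rows[2][c], rows[3][c]])
}

fn sketch_unpack_lane<const N: usize>(packed: &[[f64; 4]; N], lane: usize) -> [f64; N] {
    std::array::from_fn(|c| packed[c][lane])
}

#[test]
fn sketch_pack_lanes_round_trips_bitwise() {
    let rows = [
        [-0.0, 1.5, 2.0],
        [0.0, -3.25, 7.0],
        [f64::MIN_POSITIVE, 0.1, 0.2],
        [1e300, -1e-300, 4.0],
    ];
    let packed = sketch_pack_lanes(&rows);
    for (lane, row) in rows.iter().enumerate() {
        let back = sketch_unpack_lane(&packed, lane);
        for c in 0..3 {
            assert_eq!(back[c].to_bits(), row[c].to_bits());
        }
    }
}
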
3441 fn run3<const K: usize>(seed: u64, batches: usize) -> usize {
3442 let mut r = Rng(seed);
3443 let mut rows_checked = 0;
3444 for _ in 0..batches {
3445 let a: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3446 let b: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3447 let d: [[f64; 4]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3448 let s = r.f();
3449 let scal: [Tower3<K>; 4] = std::array::from_fn(|rw| {
3450 let prod = a[rw].mul(&b[rw]);
3451 let comp = prod.compose_unary(d[rw]);
3452 let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3453 summed.compose_unary_single_slot(d[rw], 0)
3454 });
3455 let ab = pack4_t3(&a);
3456 let bb = pack4_t3(&b);
3457 let db: [wide::f64x4; 4] =
3458 std::array::from_fn(|c| wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]]));
3459 let prodb = ab.mul(&bb);
3460 let compb = prodb.compose_unary(db);
3461 let summedb = compb.add(&ab).sub(&bb).scale(s);
3462 let finalb = summedb.compose_unary_single_slot(db, 0);
3463 for rw in 0..4 {
3464 assert_t3_eq(&finalb.lane(rw), &scal[rw], "t3-chain");
3465 rows_checked += 1;
3466 }
3467 }
3468 rows_checked
3469 }
3470
3471 // A `Tower4Batch<9>` carries a `9⁴ = 6561`-entry `t4` tensor in 4-wide
3472 // lanes (≈205 KiB by value for `t4` alone); the op chain keeps several live, which can
3473 // exceed a test thread's default stack. Run each width on a large-stack
3474 // thread so K=9 is exercised without a stack overflow.
3475 fn big_stack<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(f: F) -> R {
3476 std::thread::Builder::new()
3477 .stack_size(512 << 20)
3478 .spawn(f)
3479 .unwrap()
3480 .join()
3481 .unwrap()
3482 }
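
// Minimal sketch backing the stack estimate above (hypothetical helper, not
// part of the jet-tower API), assuming a 4-lane order-≤4 batch stores v, g, h,
// t3 and t4 as dense f64x4 arrays: 1 + K + K² + K³ + K⁴ entries of 32 bytes.
// For K = 9 that is 7381 × 32 B = 236 192 B (≈231 KiB), of which `t4` alone is
// 6561 × 32 B = 209 952 B (≈205 KiB). A handful of live copies can exhaust the
// 2 MiB default stack of a spawned std thread (absent `RUST_MIN_STACK`).
const fn sketch_tower4_batch_bytes(k: usize) -> usize {
    let entries = 1 + k + k * k + k * k * k + k * k * k * k;
    entries * 4 * core::mem::size_of::<f64>()
}

#[test]
fn sketch_tower4_batch_bytes_k9() {
    assert_eq!(sketch_tower4_batch_bytes(9), 236_192);
    assert_eq!(6561 * 4 * core::mem::size_of::<f64>(), 209_952);
}
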
3483
3484 #[test]
3485 fn tower4_batch_lane_bit_identical() {
3486 let batches = 2000;
3487 let rows_checked = big_stack(move || run4::<2>(0x1111_2222_3333_4444, batches))
3488 + big_stack(move || run4::<3>(0x5555_6666_7777_8888, batches))
3489 + big_stack(move || run4::<4>(0x9999_aaaa_bbbb_cccc, batches))
3490 + big_stack(move || run4::<9>(0xdddd_eeee_ffff_0000, batches));
3491 // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3492 // worker threads against silently running zero comparisons.
3493 assert_eq!(rows_checked, 4 * batches * 4);
3494 }
3495
3496 #[test]
3497 fn tower3_batch_lane_bit_identical() {
3498 let batches = 2000;
3499 let rows_checked = big_stack(move || run3::<2>(0x0f0f_1e1e_2d2d_3c3c, batches))
3500 + big_stack(move || run3::<3>(0x4b4b_5a5a_6969_7878, batches))
3501 + big_stack(move || run3::<4>(0x8787_9696_a5a5_b4b4, batches))
3502 + big_stack(move || run3::<9>(0xc3c3_d2d2_e1e1_f0f0, batches));
3503 // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3504 // worker threads against silently running zero comparisons.
3505 assert_eq!(rows_checked, 4 * batches * 4);
3506 }
3507
3508 // ── compose_unary_with seam (generic-over-Lane compose) ─────────────────
3509 //
3510 // The seam lets a single-sourced row program build its special-function
3511 // STACK from the base value through a closure, so the SAME expression
3512 // instantiates at a scalar tower (one base) AND a batch tower (four distinct
3513 // per-lane bases). These oracles pin both arms `to_bits`.
3514
3515 /// A base-value-dependent `[f64; 5]` stack (finite for finite `u`) standing in for
3516 /// a family's hand-certified special-function stack; its entries need not be consistent
3517 /// derivatives of one function, since only bit-identity is checked. `seam_stack4` is its order-≤3 truncation.
3518 fn seam_stack5(u: f64) -> [f64; 5] {
3519 [
3520 u.sin(),
3521 u.cos(),
3522 (2.0 * u).sin(),
3523 (0.5 * u).cos(),
3524 u * u - 0.3,
3525 ]
3526 }
3527 fn seam_stack4(u: f64) -> [f64; 4] {
3528 let s = seam_stack5(u);
3529 [s[0], s[1], s[2], s[3]]
3530 }
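
// Minimal sketch of the seam's per-lane contract, using hypothetical names
// (`SketchLane`, `sketch_stack`) rather than the real `compose_unary_with`
// signature: the stack closure runs once per lane on THAT lane's base value,
// so four distinct bases yield four independent stacks. An impl that evaluated
// the closure on lane 0 only would return lane 0's stack in every lane, which
// is exactly what the lane oracles below are built to catch. Order 1 only.
trait SketchLane: Copy {
    fn sketch_stack(self, f: impl Fn(f64) -> [f64; 2]) -> [Self; 2];
}

impl SketchLane for f64 {
    fn sketch_stack(self, f: impl Fn(f64) -> [f64; 2]) -> [f64; 2] {
        f(self)
    }
}

impl SketchLane for [f64; 4] {
    fn sketch_stack(self, f: impl Fn(f64) -> [f64; 2]) -> [[f64; 4]; 2] {
        // One stack per lane, each evaluated at that lane's own base value.
        let per_lane = self.map(&f);
        [per_lane.map(|s| s[0]), per_lane.map(|s| s[1])]
    }
}

#[test]
fn sketch_lane_stack_matches_scalar_per_lane() {
    let f = |u: f64| [u.sin(), u.cos()];
    let base = [-0.0, 0.0, 0.7, 3.7];
    let batch = base.sketch_stack(f);
    for lane in 0..4 {
        let scalar = base[lane].sketch_stack(f);
        for k in 0..2 {
            assert_eq!(batch[k][lane].to_bits(), scalar[k].to_bits());
        }
    }
}
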
3531
3532 /// Pick an edge or distinct per-lane base value by `which`: −0.0, +0.0, a draw from `r`, or a draw shifted by 3.0.
3533 fn seam_edge_base(r: &mut Rng, which: usize) -> f64 {
3534 match which {
3535 0 => -0.0,
3536 1 => 0.0,
3537 2 => r.f(),
3538 _ => r.f() + 3.0,
3539 }
3540 }
3541
3542 /// (a) scalar arm: `Tower4::compose_unary_with(f)` is `to_bits`-identical to
3543 /// the explicit `compose_unary(f(value))` on every channel.
3544 fn scalar_seam_t4<const K: usize>(seed: u64, n: usize) -> usize {
3545 let mut r = Rng(seed);
3546 for _ in 0..n {
3547 let mut t = rand_t4::<K>(&mut r);
3548 t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3549 assert_t4_eq(
3550 &t.compose_unary_with(seam_stack5),
3551 &t.compose_unary(seam_stack5(t.v)),
3552 "scalar t4 seam",
3553 );
3554 }
3555 n
3556 }
3557 fn scalar_seam_t3<const K: usize>(seed: u64, n: usize) -> usize {
3558 let mut r = Rng(seed);
3559 for _ in 0..n {
3560 let mut t = rand_t3::<K>(&mut r);
3561 t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3562 assert_t3_eq(
3563 &t.compose_unary_with(seam_stack4),
3564 &t.compose_unary(seam_stack4(t.v)),
3565 "scalar t3 seam",
3566 );
3567 }
3568 n
3569 }
3570
3571 /// (b) lane arm: `Tower4Lane::compose_unary_with` lane `i` is
3572 /// `to_bits`-identical to the scalar `Tower4::compose_unary_with` on row `i`,
3573 /// with the four lanes carrying DISTINCT base values (signed zeros included),
3574 /// so a buggy impl reusing one lane's base would fail.
3575 fn lane_seam_t4<const K: usize>(seed: u64, batches: usize) -> usize {
3576 let mut r = Rng(seed);
3577 let mut verified = 0usize;
3578 for _ in 0..batches {
3579 let mut rows: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3580 for (rw, row) in rows.iter_mut().enumerate() {
3581 row.v = seam_edge_base(&mut r, rw);
3582 }
3583 let batch_out = pack4_t4(&rows).compose_unary_with(seam_stack5);
3584 for (rw, row) in rows.iter().enumerate() {
3585 assert_t4_eq(
3586 &batch_out.lane(rw),
3587 &row.compose_unary_with(seam_stack5),
3588 "lane t4 seam",
3589 );
3590 verified += 1;
3591 }
3592 }
3593 verified
3594 }
3595 fn lane_seam_t3<const K: usize>(seed: u64, batches: usize) -> usize {
3596 let mut r = Rng(seed);
3597 let mut verified = 0usize;
3598 for _ in 0..batches {
3599 let mut rows: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3600 for (rw, row) in rows.iter_mut().enumerate() {
3601 row.v = seam_edge_base(&mut r, rw);
3602 }
3603 let batch_out = pack4_t3(&rows).compose_unary_with(seam_stack4);
3604 for (rw, row) in rows.iter().enumerate() {
3605 assert_t3_eq(
3606 &batch_out.lane(rw),
3607 &row.compose_unary_with(seam_stack4),
3608 "lane t3 seam",
3609 );
3610 verified += 1;
3611 }
3612 }
3613 verified
3614 }
3615
3616 #[test]
3617 fn compose_unary_with_scalar_bit_identical() {
3618 let n = 1100;
3619 let total = scalar_seam_t4::<2>(0x2200_0001, n)
3620 + scalar_seam_t4::<3>(0x2200_0002, n)
3621 + scalar_seam_t4::<4>(0x2200_0003, n)
3622 + big_stack(move || scalar_seam_t4::<9>(0x2200_0004, n))
3623 + scalar_seam_t3::<2>(0x3300_0001, n)
3624 + scalar_seam_t3::<3>(0x3300_0002, n)
3625 + scalar_seam_t3::<4>(0x3300_0003, n)
3626 + big_stack(move || scalar_seam_t3::<9>(0x3300_0004, n));
3627 // 8 arms × 1100 = 8800 ≥ 4000 inputs.
3628 assert_eq!(total, 8 * n);
3629 }
3630
3631 #[test]
3632 fn compose_unary_with_lane_matches_scalar() {
3633 let b = 600;
3634 let total = lane_seam_t4::<2>(0x4400_0001, b)
3635 + lane_seam_t4::<3>(0x4400_0002, b)
3636 + lane_seam_t4::<4>(0x4400_0003, b)
3637 + big_stack(move || lane_seam_t4::<9>(0x4400_0004, b))
3638 + lane_seam_t3::<2>(0x5500_0001, b)
3639 + lane_seam_t3::<3>(0x5500_0002, b)
3640 + lane_seam_t3::<4>(0x5500_0003, b)
3641 + big_stack(move || lane_seam_t3::<9>(0x5500_0004, b));
3642 // 8 arms × 600 = 4800 batches ≥ 2000; each verifies 4 lanes (19200 checks).
3643 assert_eq!(total, 8 * b * 4);
3644 }
3645}
3646
3647#[cfg(test)]
3648mod tests {
3649 use super::*;
3650
3651 /// `Tower3<K>` must be bit-identical to `Tower4<K>` on every channel it
3652 /// carries (value, gradient, Hessian, third derivatives). The order-≤3
3653 /// Leibniz / Faà-di-Bruno terms read only order-≤3 inner channels, so
3654 /// dropping the fourth tensor cannot perturb them. Exercises products
3655 /// (Leibniz cross-terms), unary composition, scaling, and addition — the
3656 /// same operations the survival location-scale `nll_index_tower` composes —
3657 /// across all mixed partials, not just the diagonal entries that kernel reads.
3658 #[test]
3659 fn tower3_matches_tower4_through_third_order() {
3660 let s_a: [f64; 5] = [
3661 0.3_f64.sin(),
3662 0.3_f64.cos(),
3663 -0.3_f64.sin(),
3664 -0.3_f64.cos(),
3665 0.3_f64.sin(),
3666 ];
3667 let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
3668 let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];
3669
3670 let a4 = Tower4::<3>::variable(0.4, 0);
3671 let b4 = Tower4::<3>::variable(-0.7, 1);
3672 let c4 = Tower4::<3>::variable(0.9, 2);
3673 let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
3674 + a4.mul(&c4).scale(-0.7)
3675 + b4.compose_unary(s_b).scale(0.25);
3676
3677 let a3 = Tower3::<3>::variable(0.4, 0);
3678 let b3 = Tower3::<3>::variable(-0.7, 1);
3679 let c3 = Tower3::<3>::variable(0.9, 2);
3680 let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
3681 + a3.mul(&c3).scale(-0.7)
3682 + b3.compose_unary(s4(s_b)).scale(0.25);
3683
3684 assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
3685 for i in 0..3 {
3686 assert_eq!(
3687 prog3.g[i].to_bits(),
3688 prog4.g[i].to_bits(),
3689 "g[{i}] mismatch"
3690 );
3691 for j in 0..3 {
3692 assert_eq!(
3693 prog3.h[i][j].to_bits(),
3694 prog4.h[i][j].to_bits(),
3695 "h[{i}][{j}] mismatch"
3696 );
3697 for k in 0..3 {
3698 assert_eq!(
3699 prog3.t3[i][j][k].to_bits(),
3700 prog4.t3[i][j][k].to_bits(),
3701 "t3[{i}][{j}][{k}] mismatch"
3702 );
3703 }
3704 }
3705 }
3706 }
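
// Minimal sketch (K = 1, hypothetical helper) of why the order-≤3 channels
// cannot see the fourth tensor: the truncated Leibniz rule on univariate jets
// `[f, f′, f″, f‴, f⁗]` is h⁽ⁿ⁾ = Σₖ C(n, k)·a⁽ᵏ⁾·b⁽ⁿ⁻ᵏ⁾, so output order n reads
// only input orders ≤ n. Perturbing a⁽⁴⁾ and b⁽⁴⁾ moves h⁽⁴⁾ and nothing else.
fn sketch_jet_mul(a: [f64; 5], b: [f64; 5]) -> [f64; 5] {
    // Pascal's triangle rows 0..=4: BINOM[n][k] = C(n, k).
    const BINOM: [[f64; 5]; 5] = [
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 0.0, 0.0],
        [1.0, 2.0, 1.0, 0.0, 0.0],
        [1.0, 3.0, 3.0, 1.0, 0.0],
        [1.0, 4.0, 6.0, 4.0, 1.0],
    ];
    std::array::from_fn(|n| (0..=n).map(|k| BINOM[n][k] * a[k] * b[n - k]).sum::<f64>())
}

#[test]
fn sketch_jet_mul_truncation_is_order_local() {
    // x·x seeded at x = 3: value 9, slope 6, curvature 2.
    let x = [3.0, 1.0, 0.0, 0.0, 0.0];
    assert_eq!(sketch_jet_mul(x, x), [9.0, 6.0, 2.0, 0.0, 0.0]);
    let h = sketch_jet_mul([3.0, 1.0, 0.0, 0.0, 7.0], [3.0, 1.0, 0.0, 0.0, -2.0]);
    assert_eq!(&h[..4], &[9.0, 6.0, 2.0, 0.0]);
    assert_eq!(h[4], 15.0);
}
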
3707
3708 /// Binomial-logit row NLL, K=1: ℓ(η) = ln(1 + e^η) − y·η.
3709 /// The entire tower has textbook closed forms in μ = σ(η); this test
3710 /// pins the algebra (exp, ln, scalar mixes, Leibniz/Faà di Bruno) to
3711 /// analytic truth within a 1e-12 relative tolerance.
3712 struct LogitProgram {
3713 eta: Vec<f64>,
3714 y: Vec<f64>,
3715 }
3716
3717 impl RowNllProgram<1> for LogitProgram {
3718 fn n_rows(&self) -> usize {
3719 self.eta.len()
3720 }
3721 fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
3722 Ok([self.eta[row]])
3723 }
3724 fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
3725 let eta = p[0];
3726 Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
3727 }
3728 }
3729
3730 #[test]
3731 fn logit_tower_matches_closed_forms() {
3732 let prog = LogitProgram {
3733 eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
3734 y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
3735 };
3736 for row in 0..prog.n_rows() {
3737 let t = evaluate_program(&prog, row).expect("logit program");
3738 let eta = prog.eta[row];
3739 let y = prog.y[row];
3740 let mu = 1.0 / (1.0 + (-eta).exp());
3741 let w = mu * (1.0 - mu);
3742 let expect = [
3743 (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
3744 (t.g[0], mu - y, "grad"),
3745 (t.h[0][0], w, "hess"),
3746 (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
3747 (
3748 t.t4[0][0][0][0],
3749 w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
3750 "fourth",
3751 ),
3752 ];
3753 for (got, want, label) in expect {
3754 assert!(
3755 (got - want).abs() <= 1e-12 * want.abs().max(1.0),
3756 "row {row} {label}: got {got:+.15e} want {want:+.15e}"
3757 );
3758 }
3759 }
3760 }
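
// Minimal sketch (K = 1, hypothetical helpers) of the univariate Faà di Bruno
// step the logit test exercises. Given the outer stack f = [f, f′, f″, f‴, f⁗]
// at the inner value and the inner jet g = [g, g′, g″, g‴, g⁗]:
//   h′ = f′g′,  h″ = f″g′² + f′g″,  h‴ = f‴g′³ + 3f″g′g″ + f′g‴,
//   h⁗ = f⁗g′⁴ + 6f‴g′²g″ + f″(3g″² + 4g′g‴) + f′g⁗.
// Composing ln with 1 + e^η gives softplus, whose derivatives are μ, w,
// w(1 − 2μ), w(1 − 6μ + 6μ²); the −y·η term only shifts value and gradient.
fn sketch_fdb4(f: [f64; 5], g: [f64; 5]) -> [f64; 5] {
    let (g1, g2, g3, g4) = (g[1], g[2], g[3], g[4]);
    [
        f[0],
        f[1] * g1,
        f[2] * g1 * g1 + f[1] * g2,
        f[3] * g1.powi(3) + 3.0 * f[2] * g1 * g2 + f[1] * g3,
        f[4] * g1.powi(4)
            + 6.0 * f[3] * g1 * g1 * g2
            + f[2] * (3.0 * g2 * g2 + 4.0 * g1 * g3)
            + f[1] * g4,
    ]
}

// Softplus as ln ∘ (1 + exp): every inner derivative is e^η, and ln's stack
// at u is [ln u, 1/u, −1/u², 2/u³, −6/u⁴].
fn sketch_softplus_jet(eta: f64) -> [f64; 5] {
    let e = eta.exp();
    let u = 1.0 + e;
    let ln_stack = [u.ln(), 1.0 / u, -1.0 / (u * u), 2.0 / u.powi(3), -6.0 / u.powi(4)];
    sketch_fdb4(ln_stack, [u, e, e, e, e])
}

#[test]
fn sketch_softplus_jet_matches_logit_closed_forms() {
    for eta in [-2.3, -0.4, 0.0, 0.9, 3.1] {
        let h = sketch_softplus_jet(eta);
        let mu = 1.0 / (1.0 + (-eta).exp());
        let w = mu * (1.0 - mu);
        let want = [
            (1.0 + eta.exp()).ln(),
            mu,
            w,
            w * (1.0 - 2.0 * mu),
            w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
        ];
        for k in 0..5 {
            assert!((h[k] - want[k]).abs() <= 1e-12 * want[k].abs().max(1.0));
        }
    }
}
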
3761
3762 fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
3763 let diff = (got - want).abs();
3764 assert!(
3765 diff <= rel_tol * want.abs().max(1.0),
3766 "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
3767 );
3768 }
3769
3770 #[test]
3771 fn gamma_special_function_stacks_match_reference_values() {
3772 const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
3773 let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
3774 let cases = [
3775 (
3776 "x=0.1",
3777 0.1,
3778 -10.423_754_940_411_076,
3779 101.433_299_150_792_75,
3780 ),
3781 (
3782 "x=0.5",
3783 0.5,
3784 -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
3785 pi_sq / 2.0,
3786 ),
3787 ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
3788 (
3789 "x=2.5",
3790 2.5,
3791 -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
3792 pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
3793 ),
3794 (
3795 "x=50",
3796 50.0,
3797 3.901_989_673_427_892,
3798 0.020_201_333_226_697_128,
3799 ),
3800 ];
3801
3802 for (label, x, digamma_ref, trigamma_ref) in cases {
3803 let ln_gamma_stack = ln_gamma_derivative_stack(x);
3804 let digamma_stack = digamma_derivative_stack(x);
3805 let trigamma_stack = trigamma_derivative_stack(x);
3806 assert_close(
3807 &format!("{label} ln_gamma_stack digamma"),
3808 ln_gamma_stack[1],
3809 digamma_ref,
3810 1e-13,
3811 );
3812 assert_close(
3813 &format!("{label} digamma value"),
3814 digamma_stack[0],
3815 digamma_ref,
3816 1e-13,
3817 );
3818 assert_close(
3819 &format!("{label} ln_gamma_stack trigamma"),
3820 ln_gamma_stack[2],
3821 trigamma_ref,
3822 1e-13,
3823 );
3824 assert_close(
3825 &format!("{label} digamma_stack trigamma"),
3826 digamma_stack[1],
3827 trigamma_ref,
3828 1e-13,
3829 );
3830 assert_close(
3831 &format!("{label} trigamma value"),
3832 trigamma_stack[0],
3833 trigamma_ref,
3834 1e-13,
3835 );
3836 }
3837 }
3838
3839 #[test]
3840 fn gamma_special_function_stacks_obey_recurrences() {
3841 for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
3842 let digamma_x = digamma_derivative_stack(x)[0];
3843 let digamma_next = digamma_derivative_stack(x + 1.0)[0];
3844 let trigamma_x = trigamma_derivative_stack(x)[0];
3845 let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
3846 assert_close(
3847 &format!("digamma recurrence x={x}"),
3848 digamma_next,
3849 digamma_x + 1.0 / x,
3850 1e-13,
3851 );
3852 assert_close(
3853 &format!("trigamma recurrence x={x}"),
3854 trigamma_next,
3855 trigamma_x - 1.0 / (x * x),
3856 1e-13,
3857 );
3858 }
3859 }
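
// Minimal sketch of a textbook digamma evaluation, offered as an independent
// cross-check only (an assumption: not necessarily how
// `digamma_derivative_stack` is implemented). It pushes x up with the
// recurrence ψ(x) = ψ(x + 1) − 1/x until x ≥ 10, then applies the asymptotic
// series ψ(x) ≈ ln x − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶)
// + 1/(240x⁸) − 1/(132x¹⁰), whose first omitted term is below 3e-14 for x ≥ 10.
// Valid for x > 0.
fn sketch_digamma(mut x: f64) -> f64 {
    let mut shift = 0.0;
    while x < 10.0 {
        shift -= 1.0 / x;
        x += 1.0;
    }
    let t = 1.0 / (x * x);
    // Horner form of 1/(12x²) − 1/(120x⁴) + 1/(252x⁶) − 1/(240x⁸) + 1/(132x¹⁰).
    let series =
        t * (1.0 / 12.0 - t * (1.0 / 120.0 - t * (1.0 / 252.0 - t * (1.0 / 240.0 - t / 132.0))));
    shift + x.ln() - 0.5 / x - series
}

#[test]
fn sketch_digamma_matches_reference_values() {
    const EG: f64 = 0.577_215_664_901_532_9;
    assert!((sketch_digamma(1.0) + EG).abs() < 1e-12);
    assert!((sketch_digamma(0.5) + EG + 2.0 * std::f64::consts::LN_2).abs() < 1e-12);
    assert!((sketch_digamma(0.1) + 10.423_754_940_411_076).abs() < 1e-11);
    assert!((sketch_digamma(50.0) - 3.901_989_673_427_892).abs() < 1e-12);
}
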
3860
3861 /// Gaussian location-scale row NLL, K=2 primaries (η, s = log σ):
3862 /// ℓ = s + ½ e^{−2s} (y − η)². Mixed cross blocks — the #736 fragility
3863 /// shape — all have one-line closed forms here.
3864 struct LocScaleProgram {
3865 eta: Vec<f64>,
3866 s: Vec<f64>,
3867 y: Vec<f64>,
3868 }
3869
3870 impl RowNllProgram<2> for LocScaleProgram {
3871 fn n_rows(&self) -> usize {
3872 self.eta.len()
3873 }
3874 fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
3875 Ok([self.eta[row], self.s[row]])
3876 }
3877 fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
3878 let r = -(p[0] - self.y[row]);
3879 Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
3880 }
3881 }
3882
3883 #[test]
3884 fn locscale_tower_matches_closed_forms_including_cross_blocks() {
3885 let prog = LocScaleProgram {
3886 eta: vec![0.3, -1.1, 2.0],
3887 s: vec![-0.5, 0.2, 0.8],
3888 y: vec![1.0, -2.0, 2.5],
3889 };
3890 let tol = 1e-12;
3891 for row in 0..prog.n_rows() {
3892 let t = evaluate_program(&prog, row).expect("locscale program");
3893 let r = prog.y[row] - prog.eta[row];
3894 let w = (-2.0 * prog.s[row]).exp();
3895 // (η, s) = indices (0, 1).
3896 let truth_g = [-w * r, 1.0 - w * r * r];
3897 let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
3898 // Third tensor: distinct-entry closed forms.
3899 // ∂ηηη = 0, ∂ηηs = −2w, ∂ηss = −4wr, ∂sss = −4wr².
3900 let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
3901 match a + b + c {
3902 0 => 0.0,
3903 1 => -2.0 * w,
3904 2 => -4.0 * w * r,
3905 _ => -4.0 * w * r * r,
3906 }
3907 };
3908 // Fourth tensor: ∂ηηηη = 0, ∂ηηηs = ∂s(∂ηηη) = 0,
3909 // ∂ηηss = 4w, ∂ηsss = 8wr, ∂ssss = 8wr².
3910 let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
3911 match a + b + c + d {
3912 0 | 1 => 0.0,
3913 2 => 4.0 * w,
3914 3 => 8.0 * w * r,
3915 _ => 8.0 * w * r * r,
3916 }
3917 };
3918 for a in 0..2 {
3919 assert!(
3920 (t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
3921 "row {row} grad[{a}]"
3922 );
3923 for b in 0..2 {
3924 assert!(
3925 (t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
3926 "row {row} hess[{a}][{b}]: got {} want {}",
3927 t.h[a][b],
3928 truth_h[a][b]
3929 );
3930 for c in 0..2 {
3931 assert!(
3932 (t.t3[a][b][c] - t3_truth(a, b, c)).abs()
3933 <= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3934 "row {row} t3[{a}][{b}][{c}]: got {} want {}",
3935 t.t3[a][b][c],
3936 t3_truth(a, b, c)
3937 );
3938 for d in 0..2 {
3939 assert!(
3940 (t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
3941 <= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3942 "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
3943 t.t4[a][b][c][d],
3944 t4_truth(a, b, c, d)
3945 );
3946 }
3947 }
3948 }
3949 }
3950 // The derived third-order trait-surface helper agrees with direct contraction.
3951 let dir = [0.7, -1.3];
3952 let third = derived_third_contracted(&prog, row, &dir).expect("third");
3953 for a in 0..2 {
3954 for b in 0..2 {
3955 let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
3956 assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
3957 }
3958 }
3959 }
3960 }
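
// Minimal sketch (hypothetical helpers mirroring the `row_third_contracted(dir)`
// and `row_fourth_contracted(u, v)` channels) of the contractions a consumer
// reads: the last index of `t3` against one direction, and the last two
// indices of `t4` against two directions, each giving a K×K matrix. For a
// symmetric tensor the choice of contracted slots does not matter. Checked on
// the location-scale closed forms, where entries depend only on the s-count.
fn sketch_contract3<const K: usize>(t3: &[[[f64; K]; K]; K], dir: &[f64; K]) -> [[f64; K]; K] {
    std::array::from_fn(|a| {
        std::array::from_fn(|b| (0..K).map(|c| t3[a][b][c] * dir[c]).sum::<f64>())
    })
}

fn sketch_contract4<const K: usize>(
    t4: &[[[[f64; K]; K]; K]; K],
    u: &[f64; K],
    v: &[f64; K],
) -> [[f64; K]; K] {
    std::array::from_fn(|a| {
        std::array::from_fn(|b| {
            let mut acc = 0.0;
            for c in 0..K {
                for d in 0..K {
                    acc += t4[a][b][c][d] * u[c] * v[d];
                }
            }
            acc
        })
    })
}

#[test]
fn sketch_contractions_on_locscale_closed_forms() {
    let (w, r) = (1.0_f64.exp(), 0.7_f64);
    let t3: [[[f64; 2]; 2]; 2] = std::array::from_fn(|a| {
        std::array::from_fn(|b| {
            std::array::from_fn(|c| match a + b + c {
                0 => 0.0,
                1 => -2.0 * w,
                2 => -4.0 * w * r,
                _ => -4.0 * w * r * r,
            })
        })
    });
    let third = sketch_contract3(&t3, &[0.7, -1.3]);
    assert!((third[0][0] - 2.6 * w).abs() < 1e-12);
    assert!((third[1][1] - (-4.0 * w * r * 0.7 + 4.0 * w * r * r * 1.3)).abs() < 1e-12);
}
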
3961
3962 /// FD cross-check on a deliberately gnarly composition (div, sqrt,
3963 /// powf, nested exp/ln) in K=3, where no closed form is consulted:
3964 /// every tower channel is checked against central finite differences
3965 /// of the channel one order below — value→grad, grad→hess, hess→t3,
3966 /// t3→t4 — so each order is independently anchored.
3967 ///
3968 /// The program carries a per-row primary fixture plus a per-row offset
3969 /// `tau[row]` that enters the loss as a constant, so `row` genuinely
3970 /// drives both the seed point and the evaluated expression.
3971 struct GnarlyProgram {
3972 primaries: Vec<[f64; 3]>,
3973 tau: Vec<f64>,
3974 }
3975
3976 impl GnarlyProgram {
3977 fn fixture() -> Self {
3978 Self {
3979 primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
3980 tau: vec![0.15, -0.35, 0.5],
3981 }
3982 }
3983 }
3984
3985 impl RowNllProgram<3> for GnarlyProgram {
3986 fn n_rows(&self) -> usize {
3987 self.primaries.len()
3988 }
3989 fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3990 self.primaries
3991 .get(row)
3992 .copied()
3993 .ok_or_else(|| format!("gnarly: row {row} out of range"))
3994 }
3995 fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3996 let tau = *self
3997 .tau
3998 .get(row)
3999 .ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
4000 let a = (p[0] * p[1]).exp();
4001 let b = (p[2] * p[2] + 1.0).sqrt();
4002 let c = (a + b + tau).ln();
4003 let d = (p[1] * 0.5 + 2.0).powf(1.7);
4004 Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
4005 }
4006 }
4007
4008 /// Evaluate the gnarly program's tower at an ARBITRARY seed point for
4009 /// `row` (used to drive central differences off the fixture grid),
4010 /// while keeping `row`'s per-row data (`tau`) in the loss.
4011 fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
4012 struct At<'a> {
4013 base: &'a GnarlyProgram,
4014 row: usize,
4015 p: [f64; 3],
4016 }
4017 impl RowNllProgram<3> for At<'_> {
4018 fn n_rows(&self) -> usize {
4019 1
4020 }
4021 fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
4022 if row != 0 {
4023 return Err(format!("gnarly-at: row {row} out of range"));
4024 }
4025 Ok(self.p)
4026 }
4027 fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
4028 if eval_row != 0 {
4029 return Err(format!("gnarly-at: eval row {eval_row} out of range"));
4030 }
4031 self.base.row_nll(self.row, vars)
4032 }
4033 }
4034 evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
4035 }
4036
4037 #[test]
4038 fn gnarly_tower_is_fd_consistent_order_by_order() {
4039 let prog = GnarlyProgram::fixture();
4040 for row in 0..prog.n_rows() {
4041 let base = prog.primaries(row).expect("primaries");
4042 let t = gnarly_tower_at(&prog, row, base);
4043 let h_step = 1e-5;
4044 let tol = 1e-6;
4045 for c in 0..3 {
4046 let mut up = base;
4047 let mut dn = base;
4048 up[c] += h_step;
4049 dn[c] -= h_step;
4050 let t_up = gnarly_tower_at(&prog, row, up);
4051 let t_dn = gnarly_tower_at(&prog, row, dn);
4052 // value → gradient.
4053 let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
4054 assert!(
4055 (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4056 "grad[{c}]: analytic {} fd {}",
4057 t.g[c],
4058 fd_g
4059 );
4060 for a in 0..3 {
4061 // gradient → Hessian.
4062 let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
4063 assert!(
4064 (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
4065 "hess[{a}][{c}]: analytic {} fd {}",
4066 t.h[a][c],
4067 fd_h
4068 );
4069 for b in 0..3 {
4070 // Hessian → third.
4071 let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
4072 assert!(
4073 (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
4074 "t3[{a}][{b}][{c}]: analytic {} fd {}",
4075 t.t3[a][b][c],
4076 fd_t3
4077 );
4078 for d in 0..3 {
4079 // third → fourth.
4080 let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
4081 assert!(
4082 (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
4083 "t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
4084 t.t4[a][b][d][c],
4085 fd_t4
4086 );
4087 }
4088 }
4089 }
4090 }
4091 }
4092 }
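
// Minimal sketch of why `h_step = 1e-5` with `tol = 1e-6` leaves headroom:
// a central difference D_h f(x) = (f(x + h) − f(x − h)) / 2h has truncation
// error f‴(x)·h²/6 + O(h⁴) plus rounding error of roughly ε·|f|/h. For f = sin
// the difference is exactly cos(x)·sin(h)/h, so both effects are checkable.
// `sketch_central_diff` is a hypothetical helper, not part of the module.
fn sketch_central_diff(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}

#[test]
fn sketch_central_diff_error_is_second_order() {
    let x = 0.3_f64;
    // h = 1e-2: the leading truncation term −cos(x)·h²/6 is clearly visible.
    let d = sketch_central_diff(f64::sin, x, 1e-2);
    assert!((d - x.cos() * (1.0 - 1e-4 / 6.0)).abs() < 1e-9);
    // h = 1e-5: truncation ~1.6e-11, rounding ~1e-11, far under 1e-6.
    let d_small = sketch_central_diff(f64::sin, x, 1e-5);
    assert!((d_small - x.cos()).abs() < 1e-9);
}
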
4093
4094 /// `implicit_solve` reproduces the true implicit function `a(θ)` of a
4095 /// constraint `F(a, θ) = 0`; the FD oracle below checks orders 1 and 2. The
4096 /// constraint is smooth and, for θ₀, θ₁ ≥ 0, strictly increasing in `a` on a ≥ 0:
4097 /// F(a, θ) = a + θ₀·a² + θ₁·exp(a) − c
4098 /// whose root `a(θ)` is re-solved by scalar Newton at perturbed θ as the
4099 /// independent finite-difference oracle. Mirrors the survival flex
4100 /// calibration solve (one implicit intercept over the primaries) without
4101 /// any survival machinery, so a failure localises to the combinator.
4102 #[test]
4103 fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
4104 const C: f64 = 1.7;
4105 // The scalar constraint as a plain f64 closure (the production root
4106 // finder analogue) and its tower form in (a, θ₀, θ₁).
4107 let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
4108 let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
4109 let solve = |th: [f64; 2]| -> f64 {
4110 let mut a = 0.0_f64;
4111 for _ in 0..100 {
4112 let r = f_scalar(a, th);
4113 if r.abs() < 1e-14 {
4114 break;
4115 }
4116 a -= r / f_da(a, th);
4117 }
4118 a
4119 };
4120 // Tower constraint over K1 = 3 vars: slot 0 = a, slots 1,2 = θ₀, θ₁.
4121 let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
4122 let a = Tower4::<3>::variable(a0, 0);
4123 let t0 = Tower4::<3>::variable(th[0], 1);
4124 let t1 = Tower4::<3>::variable(th[1], 2);
4125 a + t0 * a.mul(&a) + t1 * a.exp() - C
4126 };
4127
4128 let th0 = [0.35, 0.2];
4129 let a0 = solve(th0);
4130 let f = f_tower(a0, th0);
4131 // Residual at the solved point is ~0 (the combinator tolerates the
4132 // production Newton residual; here it is machine-zero).
4133 assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
4134 let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");
4135
4136 // FD oracle: central differences of the scalar re-solve. The second order
4137 // nests one more central difference around the first: the same ladder as the
4138 // gnarly test, but applied to scalar re-solves rather than tower channels.
4139 let h = 1e-4;
4140 let tol = 1e-5;
4141 let re = |th: [f64; 2]| solve(th);
4142 for i in 0..2 {
4143 let mut up = th0;
4144 let mut dn = th0;
4145 up[i] += h;
4146 dn[i] -= h;
4147 let fd_g = (re(up) - re(dn)) / (2.0 * h);
4148 assert!(
4149 (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4150 "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
4151 a_tower.g[i],
4152 fd_g
4153 );
4154 // second order: FD of the analytic gradient component would re-use
4155 // the combinator; instead difference a SCALAR gradient computed by
4156 // a nested re-solve so the oracle stays production-independent.
4157 let grad_at = |th: [f64; 2], j: usize| -> f64 {
4158 let mut up = th;
4159 let mut dn = th;
4160 up[j] += h;
4161 dn[j] -= h;
4162 (re(up) - re(dn)) / (2.0 * h)
4163 };
4164 for j in 0..2 {
4165 let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4166 assert!(
4167 (a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4168 "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
4169 a_tower.h[i][j],
4170 fd_h
4171 );
4172 }
4173 }
4174 }
4175
4176 /// `implicit_solve` matches the implicit-function theorem: `a_θ = −F_θ / F_a`
4177 /// at first order, and the second-order tensor matches the
4178 /// textbook IFT formula `a_uv = −(F_uv + F_au a_v + F_av a_u + F_aa a_u a_v)/F_a`.
4179 /// This pins the recursion against the hand-coded first_full.rs formula it
4180 /// replaces, independent of any FD step.
4181 #[test]
4182 fn implicit_solve_matches_textbook_ift_recursion() {
4183 // A constraint with non-trivial F_a, F_aa, F_au, F_uv all present.
4184 let a0 = 0.4_f64;
4185 let th = [0.25_f64, -0.15_f64];
4186 let f = {
4187 let a = Tower4::<3>::variable(a0, 0);
4188 let t0 = Tower4::<3>::variable(th[0], 1);
4189 let t1 = Tower4::<3>::variable(th[1], 2);
4190 // F = a·(1 + θ₀) + θ₁·a² + θ₀·θ₁ − 0.4385. The constant is chosen so
4191 // F(a0, θ0) = 0 exactly at a0 = 0.4, θ = [0.25, −0.15]:
4192 // 0.4·1.25 + (−0.15)·0.16 + 0.25·(−0.15) = 0.4385.
4193 // implicit_solve requires a genuine root; at the root the level-set
4194 // and root-curve derivatives coincide, so the textbook-IFT
4195 // assertions below are unaffected.
4196 a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
4197 };
4198 let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
4199 let f_a = f.g[0];
4200 // First order: a_u = −F_u / F_a.
4201 for u in 0..2 {
4202 let want = -f.g[u + 1] / f_a;
4203 assert!(
4204 (a_t.g[u] - want).abs() < 1e-12,
4205 "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
4206 a_t.g[u],
4207 want
4208 );
4209 }
4210 // Second order textbook IFT (indices shifted by 1 for the a-slot).
4211 for u in 0..2 {
4212 for v in 0..2 {
4213 let f_uv = f.h[u + 1][v + 1];
4214 let f_au = f.h[0][u + 1];
4215 let f_av = f.h[0][v + 1];
4216 let f_aa = f.h[0][0];
4217 let want =
4218 -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
4219 assert!(
4220 (a_t.h[u][v] - want).abs() < 1e-12,
4221 "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
4222 a_t.h[u][v],
4223 want
4224 );
4225 }
4226 }
4227 }
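
// Minimal sketch (one parameter θ, hypothetical helper) of the scalar IFT
// recursion the test above pins: a_θ = −F_θ/F_a and, with u = v = θ,
// a_θθ = −(F_θθ + 2·F_aθ·a_θ + F_aa·a_θ²)/F_a. On F(a, θ) = a² − θ the root is
// a = √θ, so a_θ = 1/(2√θ) and a_θθ = −θ^{−3/2}/4 in closed form.
fn sketch_ift2(f_a: f64, f_t: f64, f_aa: f64, f_at: f64, f_tt: f64) -> (f64, f64) {
    let a_t = -f_t / f_a;
    let a_tt = -(f_tt + 2.0 * f_at * a_t + f_aa * a_t * a_t) / f_a;
    (a_t, a_tt)
}

#[test]
fn sketch_ift2_matches_sqrt_root() {
    let (theta, a) = (2.25_f64, 1.5_f64);
    // Partials of F = a² − θ: F_a = 2a, F_θ = −1, F_aa = 2, F_aθ = F_θθ = 0.
    let (a_t, a_tt) = sketch_ift2(2.0 * a, -1.0, 2.0, 0.0, 0.0);
    assert!((a_t - 0.5 / theta.sqrt()).abs() < 1e-15);
    assert!((a_tt + 0.25 * theta.powf(-1.5)).abs() < 1e-15);
}
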
4228
4229 /// The moving-boundary flux tower reproduces every θ-derivative of a
4230 /// moving-limit integral, INCLUDING the second-order `B·z_uv` term the
4231 /// hand-written flux dropped (#932). The edge `z_R(θ) = θ₀ + θ₁²` has a
4232 /// genuinely nonzero `∂²z_R/∂θ₁² = 2`, so a combinator that omitted
4233 /// `B·z_uv` would miss the [1][1] Hessian entry. Truth = central FD of the
4234 /// closed-form integral `∫₀^{z_R} e^{−z²/2} dz = √(π/2)·erf(z_R/√2)`.
4235 #[test]
4236 fn moving_boundary_flux_carries_b_zuv_term() {
4237 use std::f64::consts::PI;
4238 let b = |z: f64| (-0.5 * z * z).exp(); // integrand B(z)
4239 // Antiderivative-based closed-form integral I(z_R) = ∫₀^{z_R} B dz.
4240 let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
4241 let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
4242 let th0 = [0.7_f64, 0.5_f64];
4243
4244 // Edge tower z_R(θ) over K=2 primaries: value + exact derivatives.
4245 let mut z_edge = Tower4::<2>::constant(z_r(th0));
4246 z_edge.g[0] = 1.0; // ∂z_R/∂θ₀ = 1
4247 z_edge.g[1] = 2.0 * th0[1]; // ∂z_R/∂θ₁ = 2θ₁
4248 z_edge.h[1][1] = 2.0; // ∂²z_R/∂θ₁² = 2 (the z_uv the old flux dropped)
4249
4250 // Integrand stack [B, B′, B″, B‴] at z₀: B′=−z·B, B″=(z²−1)·B,
4251 // B‴=(3z−z³)·B.
4252 let z0 = z_edge.v;
4253 let b0 = b(z0);
4254 let stack = [
4255 b0,
4256 -z0 * b0,
4257 (z0 * z0 - 1.0) * b0,
4258 (3.0 * z0 - z0 * z0 * z0) * b0,
4259 ];
4260 let flux = moving_limit_boundary_tower(&z_edge, stack);
4261
4262 // FD truth of the integral's derivatives.
4263 let h = 1e-4;
4264 let tol = 1e-6;
4265 for i in 0..2 {
4266 let mut up = th0;
4267 let mut dn = th0;
4268 up[i] += h;
4269 dn[i] -= h;
4270 let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
4271 assert!(
4272 (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4273 "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
4274 flux.g[i],
4275 fd_g
4276 );
4277 }
4278 // The decisive entry: ∂²I/∂θ₁² = B′·(z_θ₁)² + B·z_θ₁θ₁. With z_θ₁=2θ₁=1
4279 // and z_θ₁θ₁=2, the B·z_uv contribution is B(z₀)·2 — omitting it would
4280 // leave the [1][1] entry short by exactly 2·B(z₀).
4281 let grad1_at = |th: [f64; 2]| -> f64 {
4282 let mut up = th;
4283 let mut dn = th;
4284 up[1] += h;
4285 dn[1] -= h;
4286 (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
4287 };
4288 let mut up = th0;
4289 let mut dn = th0;
4290 up[1] += h;
4291 dn[1] -= h;
4292 let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
4293 assert!(
4294 (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
4295 "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
4296 flux.h[1][1],
4297 fd_h11
4298 );
4299 // Explicit witness that the B·z_uv term is present and material:
4300 // analytic h[1][1] minus the pure (z_u)² part must equal B·z_uv = 2·B₀.
4301 let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
4302 let b_zuv = flux.h[1][1] - pure_zu2;
4303 assert!(
4304 (b_zuv - b0 * 2.0).abs() < 1e-10,
4305 "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
4306 b_zuv,
4307 b0 * 2.0
4308 );
4309 }
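
// Minimal sketch (one parameter θ, hypothetical helper) of Leibniz's rule for a
// moving upper limit, I(θ) = ∫₀^{z(θ)} B(z) dz, through second order:
//   I_θ = B(z)·z_θ,   I_θθ = B′(z)·z_θ² + B(z)·z_θθ.
// The B·z_θθ term is the one the test above pins. With B = e^z and z = θ²,
// I = e^{θ²} − 1, so I_θ = 2θ·e^{θ²} and I_θθ = (4θ² + 2)·e^{θ²}; dropping
// B·z_θθ at θ = 0.5 would report e^{0.25} instead of 3·e^{0.25}.
fn sketch_moving_limit2(b: f64, b_prime: f64, z_t: f64, z_tt: f64) -> (f64, f64) {
    (b * z_t, b_prime * z_t * z_t + b * z_tt)
}

#[test]
fn sketch_moving_limit2_keeps_b_ztt() {
    let theta = 0.5_f64;
    let z = theta * theta;
    let (i_t, i_tt) = sketch_moving_limit2(z.exp(), z.exp(), 2.0 * theta, 2.0);
    assert!((i_t - 2.0 * theta * z.exp()).abs() < 1e-15);
    assert!((i_tt - (4.0 * theta * theta + 2.0) * z.exp()).abs() < 1e-14);
}
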
4310
4311 /// `moving_limit_boundary_tower_theta_integrand` reproduces the marginal-slope
4312 /// flex boundary closure for a θ-DEPENDENT integrand `G(z;θ)` — the case the
4313 /// plain `moving_limit_boundary_tower` cannot express, and the case the
4314 /// survival directional/bidirectional paths hand-assemble term-by-term
4315 /// (`G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`, with the directional path
4316 /// dropping `G·z_uv`). Two independent oracles:
4317 /// (1) closed-form: the boundary flux of `∫ G dz` is exactly
4318 /// `Φ(z_edge(θ);θ) − Φ(z₀;θ)` (Φ = z-antiderivative of G), whose θ
4319 /// derivatives we take by central FD of the closed form — no jet code.
4320 /// (2) the explicit second-order hand closure, including the `G·z_uv` term,
4321 /// built from the integrand's own (z,θ) partials.
4322 /// G(z;θ) = exp(z·θ₀) is genuinely θ-dependent (G_θ₀ = z·e^{zθ₀} ≠ 0), and
4323 /// the edge z_edge = 0.6 + θ₀ + θ₁² has a real z_uv = ∂²z_edge/∂θ₁² = 2, so a
4324 /// combinator that dropped either the integrand-θ terms or `G·z_uv` would
4325 /// miss a Hessian entry.
4326 #[test]
4327 fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
4328 // G(z;θ) = exp(z·θ₀); Φ(z;θ) = ∫₀^z G = (e^{zθ₀} − 1)/θ₀.
4329 let g = |z: f64, t0: f64| (z * t0).exp();
4330 let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
4331 let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
4332 let th0 = [0.4_f64, 0.5_f64];
4333 let z0 = z_r(th0);
4334
4335 // Edge tower z_edge(θ) over K=2 primaries.
4336 let mut z_edge = Tower4::<2>::constant(z0);
4337 z_edge.g[0] = 1.0; // ∂z/∂θ₀
4338 z_edge.g[1] = 2.0 * th0[1]; // ∂z/∂θ₁
4339 z_edge.h[1][1] = 2.0; // ∂²z/∂θ₁² (the z_uv the directional path drops)
4340
4341 // Φ's mixed (z, θ) jet over K1 = 3 vars: slot 0 = z, slots 1,2 = θ₀,θ₁.
4342 // Built ONCE in tower arithmetic so every (z^i θ^j) partial is exact.
4343 let z_var = Tower4::<3>::variable(z0, 0);
4344 let t0_var = Tower4::<3>::variable(th0[0], 1);
4345 // θ₁ does not enter G/Φ here (its Φ-derivatives are zero; the z_edge
4346 // chain supplies all θ₁ motion through slot 0), so the K1 frame's θ₁
4347 // slot is intentionally left unseeded.
4348 let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
4349 // Sanity: slot-0 first derivative of Φ IS G(z₀;θ₀).
4350 assert!(
4351 (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
4352 "Φ_z {:+.8e} != G {:+.8e}",
4353 phi_jet.g[0],
4354 g(z0, th0[0])
4355 );
4356
4357 let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
4358
4359 // Value channel is 0 by construction (boundary, not the integral itself).
4360 assert!(
4361 flux.v.abs() < 1e-12,
4362 "boundary value channel {:+.3e}",
4363 flux.v
4364 );
4365
4366 // Oracle (1): central FD of the closed-form boundary flux
4367 // Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ) (z₀ FROZEN at the base edge).
4368 let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
4369 let h = 1e-4;
4370 let tol = 1e-6;
4371 for i in 0..2 {
4372 let mut up = th0;
4373 let mut dn = th0;
4374 up[i] += h;
4375 dn[i] -= h;
4376 let fd_g = (bnd(up) - bnd(dn)) / (2.0 * h);
4377 assert!(
4378 (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4379 "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
4380 flux.g[i],
4381 fd_g
4382 );
4383 }
4384 let grad_at = |th: [f64; 2], j: usize| -> f64 {
4385 let mut up = th;
4386 let mut dn = th;
4387 up[j] += h;
4388 dn[j] -= h;
4389 (bnd(up) - bnd(dn)) / (2.0 * h)
4390 };
4391 for i in 0..2 {
4392 for j in 0..2 {
4393 let mut up = th0;
4394 let mut dn = th0;
4395 up[i] += h;
4396 dn[i] -= h;
4397 let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4398 assert!(
4399 (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4400 "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
4401 flux.h[i][j],
4402 fd_h
4403 );
4404 }
4405 }
4406
4407 // Oracle (2): the explicit second-order hand closure, term by term —
4408 // `G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`. Read G's partials at the
4409 // base point directly (no jet): G = e^{zθ₀}, G_z = θ₀·G, G_θ₀ = z·G,
4410 // G_θ₁ = 0.
4411 let gg = g(z0, th0[0]);
4412 let g_z = th0[0] * gg;
4413 let g_theta = [z0 * gg, 0.0]; // [G_θ₀, G_θ₁]
4414 for i in 0..2 {
4415 for j in 0..2 {
4416 let z_u = z_edge.g[i];
4417 let z_v = z_edge.g[j];
4418 let z_uv = z_edge.h[i][j];
4419 let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
4420 assert!(
4421 (flux.h[i][j] - hand).abs() < 1e-9,
4422 "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
4423 flux.h[i][j],
4424 hand
4425 );
4426 }
4427 }
4428
4429 // Decisive: the `G·z_uv` term the directional path DROPS is present and
4430 // material in the [1][1] entry (z_uv = 2 there).
4431 let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
4432 let g_zuv = flux.h[1][1] - pure_no_zuv;
4433 assert!(
4434 (g_zuv - gg * 2.0).abs() < 1e-9,
4435 "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
4436 g_zuv,
4437 gg * 2.0
4438 );
4439 }
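
    // A std-only illustration of why `G·z_uv` matters, reusing this test's
    // G(z;θ) = exp(z·θ₀) and z_edge = 0.6 + θ₀ + θ₁² in plain `f64` (no
    // `Tower4`): a central second difference of the closed-form flux in θ₁
    // agrees with the hand closure only when `G·z_uv` is kept. The
    // `bnd_sketch_*` helpers are hypothetical, for illustration only.
    fn bnd_sketch_phi(z: f64, t0: f64) -> f64 {
        ((z * t0).exp() - 1.0) / t0
    }

    fn bnd_sketch_z_edge(th: [f64; 2]) -> f64 {
        0.6 + th[0] + th[1] * th[1]
    }

    /// Central second difference in θ₁ of Bnd(θ) = Φ(z_edge(θ); θ₀) − Φ(z₀; θ₀).
    fn bnd_sketch_fd_h11(th0: [f64; 2], h: f64) -> f64 {
        let z0 = bnd_sketch_z_edge(th0); // frozen base edge
        let bnd = |th: [f64; 2]| {
            bnd_sketch_phi(bnd_sketch_z_edge(th), th[0]) - bnd_sketch_phi(z0, th[0])
        };
        let (mut up, mut dn) = (th0, th0);
        up[1] += h;
        dn[1] -= h;
        (bnd(up) - 2.0 * bnd(th0) + bnd(dn)) / (h * h)
    }

    /// Hand closure for the [1][1] entry. θ₁ does not enter G, so G_θ₁ = 0 and
    /// only `G_z·z_1²` and (optionally) `G·z_11` remain.
    fn bnd_sketch_hand_h11(th0: [f64; 2], keep_g_zuv: bool) -> f64 {
        let z0 = bnd_sketch_z_edge(th0);
        let g = (z0 * th0[0]).exp();
        let g_z = th0[0] * g;
        let (z_1, z_11) = (2.0 * th0[1], 2.0);
        let without = g_z * z_1 * z_1;
        if keep_g_zuv {
            without + g * z_11
        } else {
            without
        }
    }

    #[test]
    fn bnd_sketch_dropping_g_zuv_misses_the_hessian_entry() {
        let th0 = [0.4, 0.5];
        let fd = bnd_sketch_fd_h11(th0, 1e-4);
        assert!((bnd_sketch_hand_h11(th0, true) - fd).abs() < 1e-5 * fd.abs().max(1.0));
        // Dropping G·z_uv is off by 2·G ≈ 3.3 at this base point.
        assert!((bnd_sketch_hand_h11(th0, false) - fd).abs() > 1.0);
    }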

    /// The survival crossing-edge position tower `z_edge = (τ − a(θ)) / b`,
    /// built from the intercept tower `a(θ)` (here a stand-in) and the slope
    /// `b`, reproduces taylor-jet's exact hand-path boundary-velocity formulas.
    /// As in the survival flex frame, `b` itself is seeded as the g-primary, so
    /// the formulas take ∂b/∂g = 1 rather than the chain rule of `b = exp(g)`:
    /// z_u = −(a_u + [u==g]·z) / b
    /// z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
    /// This pins the bridge between `implicit_solve` and
    /// `cell_moving_boundary_flux_tower`: the boundary jet that the production
    /// flex path hand-codes (and dropped `z_uv` from) is exactly `∂²` of this
    /// tower. K=3 reduced frame: slot 0 = a-axis carrier (an arbitrary smooth
    /// a(θ) with nonzero a_u/a_uv), slot 1 = g (seeded as the slope `b`), slot 2
    /// unused.
    #[test]
    fn crossing_edge_tower_matches_handpath_velocity_formulas() {
        const TAU: f64 = 1.3; // the link-knot crossing threshold τ
        let g_idx = 1usize;
        let g0 = 0.85_f64; // the slope value b (the g-primary IS the slope)
        // Stand-in intercept tower a(θ): nonzero value, gradient, Hessian in the
        // two live axes so a_u and a_uv are both exercised. (In production this
        // comes from implicit_solve; here we plant known derivatives.)
        let mut a = Tower4::<3>::constant(0.45);
        a.g[0] = 0.7;
        a.g[1] = -0.3;
        a.h[0][0] = 0.25;
        a.h[0][1] = 0.11;
        a.h[1][0] = 0.11;
        a.h[1][1] = -0.08;

        // In the survival flex frame the slope `b` IS the g-primary directly
        // (the directional code passes `g` as `b`, and ∂z/∂g uses ∂b/∂g = 1):
        // z_edge = (τ − a) / b with b seeded as the g-axis variable.
        let b = Tower4::<3>::variable(g0, g_idx);
        let z_edge = (Tower4::<3>::constant(TAU) - a) / b;

        let bv = g0;
        let z0 = z_edge.v;
        assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);

        // z_u = −(a_u + [u==g]·z) / b.
        for u in 0..2 {
            let direct = if u == g_idx { z0 } else { 0.0 };
            let want = -(a.g[u] + direct) / bv;
            assert!(
                (z_edge.g[u] - want).abs() < 1e-10,
                "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
                z_edge.g[u],
                want
            );
        }
        // z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b, using the tower's own
        // first-order z_v/z_u (already verified above).
        for u in 0..2 {
            for v in 0..2 {
                let cross = if u == g_idx { z_edge.g[v] } else { 0.0 }
                    + if v == g_idx { z_edge.g[u] } else { 0.0 };
                let want = -(a.h[u][v] + cross) / bv;
                assert!(
                    (z_edge.h[u][v] - want).abs() < 1e-10,
                    "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
                    z_edge.h[u][v],
                    want
                );
            }
        }
    }
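
    // A std-only cross-check of the two hand formulas above by central finite
    // differences of a concrete intercept instead of a planted tower. Assumed
    // for illustration: a(θ₀, g) is the quadratic with the same base value,
    // gradient, and Hessian as the stand-in `a` above, expanded around
    // (θ₀, g) = (0, 0.85), and the slope is b = g itself (∂b/∂g = 1). The
    // `edge_sketch_*` / `EDGE_SKETCH_*` names are hypothetical.
    const EDGE_SKETCH_TAU: f64 = 1.3;
    const EDGE_SKETCH_G0: f64 = 0.85;
    const EDGE_SKETCH_A_U: [f64; 2] = [0.7, -0.3];
    const EDGE_SKETCH_A_UV: [[f64; 2]; 2] = [[0.25, 0.11], [0.11, -0.08]];

    /// Quadratic intercept a(θ₀, g) around the base point (0, g₀).
    fn edge_sketch_a(th: [f64; 2]) -> f64 {
        let d = [th[0], th[1] - EDGE_SKETCH_G0];
        let (g1, h2) = (EDGE_SKETCH_A_U, EDGE_SKETCH_A_UV);
        0.45 + g1[0] * d[0]
            + g1[1] * d[1]
            + 0.5 * (h2[0][0] * d[0] * d[0] + 2.0 * h2[0][1] * d[0] * d[1] + h2[1][1] * d[1] * d[1])
    }

    /// Crossing edge z = (τ − a) / b with b = g.
    fn edge_sketch_z(th: [f64; 2]) -> f64 {
        (EDGE_SKETCH_TAU - edge_sketch_a(th)) / th[1]
    }

    /// Central-FD gradient of `edge_sketch_z` at `th` with step `h`.
    fn edge_sketch_fd_grad(th: [f64; 2], h: f64) -> [f64; 2] {
        std::array::from_fn(|u| {
            let (mut up, mut dn) = (th, th);
            up[u] += h;
            dn[u] -= h;
            (edge_sketch_z(up) - edge_sketch_z(dn)) / (2.0 * h)
        })
    }

    #[test]
    fn edge_sketch_hand_formulas_match_finite_differences() {
        let (g, b) = (1usize, EDGE_SKETCH_G0);
        let th0 = [0.0, EDGE_SKETCH_G0];
        let z0 = edge_sketch_z(th0);
        // z_u = −(a_u + [u==g]·z) / b
        let z_u: [f64; 2] = std::array::from_fn(|u| {
            let direct = if u == g { z0 } else { 0.0 };
            -(EDGE_SKETCH_A_U[u] + direct) / b
        });
        let fd_u = edge_sketch_fd_grad(th0, 1e-5);
        let h = 1e-4;
        for u in 0..2 {
            assert!((fd_u[u] - z_u[u]).abs() < 1e-8, "z_u[{u}]");
            for v in 0..2 {
                // z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
                let from_u = if u == g { z_u[v] } else { 0.0 };
                let from_v = if v == g { z_u[u] } else { 0.0 };
                let want = -(EDGE_SKETCH_A_UV[u][v] + from_u + from_v) / b;
                let (mut up, mut dn) = (th0, th0);
                up[u] += h;
                dn[u] -= h;
                let fd_uv =
                    (edge_sketch_fd_grad(up, h)[v] - edge_sketch_fd_grad(dn, h)[v]) / (2.0 * h);
                assert!((fd_uv - want).abs() < 1e-6, "z_uv[{u}][{v}]");
            }
        }
    }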

    /// The crossing-edge tower in the CONSTRAINT frame (intercept `a` and
    /// slope `b` BOTH independent — slots 0 and 1) reproduces taylor-jet's
    /// FD-certified bare boundary-velocity constants exactly:
    /// z_a = ∂z/∂a = −1/b
    /// z_ab = ∂²z/∂a∂b = +1/b²
    /// z_aa = ∂²z/∂a² = 0
    /// z_bb = ∂²z/∂b² = +2(τ−a)/b³
    /// These are the `f_a`/`f_au`/`f_aa` constraint-jet boundary motions the
    /// production base path drops (adding them only in the dir twins, which
    /// causes the #932 desync). Here `a` is independent (NOT yet substituted
    /// with a(θ)), so `z_aa = 0` and there is no `a_uv` chain — `implicit_solve`
    /// introduces that later. Pins these constants before the constraint-tower
    /// wiring.
    #[test]
    fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
        const TAU: f64 = 1.3;
        let a0 = 0.45_f64;
        let b0 = 0.85_f64;
        // Slot 0 = a, slot 1 = b, both seeded independent.
        let a = Tower4::<2>::variable(a0, 0);
        let b = Tower4::<2>::variable(b0, 1);
        let z = (Tower4::<2>::constant(TAU) - a) / b;

        assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
        assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
        assert!(
            (z.h[0][1] - 1.0 / (b0 * b0)).abs() < 1e-12,
            "z_ab {:+.10e} vs +1/b² {:+.10e}",
            z.h[0][1],
            1.0 / (b0 * b0)
        );
        assert!(
            z.h[0][0].abs() < 1e-12,
            "z_aa must vanish, got {:+.10e}",
            z.h[0][0]
        );
        let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);
        assert!(
            (z.h[1][1] - want_zbb).abs() < 1e-12,
            "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
            z.h[1][1],
            want_zbb
        );
    }
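
    // A std-only cross-check of the four bare constants, using only central
    // finite differences of z(a, b) = (τ − a)/b at the same (a, b) = (0.45, 0.85)
    // and τ = 1.3. The `bare_sketch_*` names are hypothetical helpers.
    fn bare_sketch_z(a: f64, b: f64) -> f64 {
        (1.3 - a) / b
    }

    /// Central-FD `[z_a, z_aa, z_ab, z_bb]` at `(a, b)` with step `h`.
    fn bare_sketch_fd(a: f64, b: f64, h: f64) -> [f64; 4] {
        let z = bare_sketch_z;
        let z_a = (z(a + h, b) - z(a - h, b)) / (2.0 * h);
        let z_aa = (z(a + h, b) - 2.0 * z(a, b) + z(a - h, b)) / (h * h);
        let z_ab = (z(a + h, b + h) - z(a + h, b - h) - z(a - h, b + h) + z(a - h, b - h))
            / (4.0 * h * h);
        let z_bb = (z(a, b + h) - 2.0 * z(a, b) + z(a, b - h)) / (h * h);
        [z_a, z_aa, z_ab, z_bb]
    }

    #[test]
    fn bare_sketch_constants_match_finite_differences() {
        let (a0, b0) = (0.45_f64, 0.85_f64);
        let [z_a, z_aa, z_ab, z_bb] = bare_sketch_fd(a0, b0, 1e-4);
        assert!((z_a + 1.0 / b0).abs() < 1e-8);
        assert!(z_aa.abs() < 1e-5);
        assert!((z_ab - 1.0 / (b0 * b0)).abs() < 1e-5);
        assert!((z_bb - 2.0 * (1.3 - a0) / (b0 * b0 * b0)).abs() < 1e-5);
    }

    // Because z is linear in `a`, the z_aa difference is pure rounding noise;
    // the z_bb tolerance leaves room for the O(h²) truncation of a second
    // difference.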

    /// The oracle harness catches a planted #736-style sign flip in a
    /// cross block and reports the channel by name.
    #[test]
    fn oracle_catches_planted_cross_block_sign_flip() {
        let prog = LocScaleProgram {
            eta: vec![0.3],
            s: vec![-0.5],
            y: vec![1.0],
        };
        let t = evaluate_program(&prog, 0).expect("tower");
        let dir = [0.6, -0.2];
        let mut third = t.third_contracted(&dir);
        let honest = KernelChannels {
            value: t.v,
            gradient: t.g,
            hessian: t.h,
            third: vec![(dir, third)],
            fourth: vec![(dir, [1.0, 0.5], t.fourth_contracted(&dir, &[1.0, 0.5]))],
        };
        verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");

        // Plant the #736 flip: negate one mixed cross entry.
        third[0][1] = -third[0][1];
        let flipped = KernelChannels {
            value: t.v,
            gradient: t.g,
            hessian: t.h,
            third: vec![(dir, third)],
            fourth: vec![],
        };
        let err = verify_kernel_channels(&t, &flipped, 1e-10)
            .expect_err("planted sign flip must be caught");
        assert!(
            err.contains("third[0][0][1]"),
            "oracle must name the flipped channel, got: {err}"
        );
    }
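
    // A std-only sketch of the idea behind `verify_kernel_channels` reporting
    // "third[0][0][1]": compare a kernel-supplied block against the reference
    // entry by entry and name the first mismatching channel. This is NOT the
    // crate's implementation; `naming_sketch_check` is hypothetical and only
    // handles one 2×2 contracted block.
    fn naming_sketch_check(
        label: &str,
        reference: &[[f64; 2]; 2],
        candidate: &[[f64; 2]; 2],
        tol: f64,
    ) -> Result<(), String> {
        for a in 0..2 {
            for b in 0..2 {
                let (r, c) = (reference[a][b], candidate[a][b]);
                // Relative tolerance with an absolute floor; written as !(<=)
                // so that a NaN entry also counts as a mismatch.
                if !((r - c).abs() <= tol * r.abs().max(1.0)) {
                    return Err(format!("{label}[{a}][{b}]: reference {r:+.6e} vs kernel {c:+.6e}"));
                }
            }
        }
        Ok(())
    }

    #[test]
    fn naming_sketch_names_the_flipped_entry() {
        let reference: [[f64; 2]; 2] = [[0.3, -1.2], [-1.2, 0.8]];
        let mut flipped = reference;
        flipped[0][1] = -flipped[0][1];
        assert!(naming_sketch_check("third[0]", &reference, &reference, 1e-10).is_ok());
        let err = naming_sketch_check("third[0]", &reference, &flipped, 1e-10).unwrap_err();
        assert!(err.contains("third[0][0][1]"), "got: {err}");
    }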

    /// The third- and fourth-order tensors must be FULLY symmetric under
    /// index permutation (mixed partials commute). The tower stores them
    /// unsymmetrized, so equal-by-construction is a real invariant of the
    /// Leibniz/Faà di Bruno writes — a cheap typo tripwire. Asserted on a
    /// nontrivial K=3 tower with all of div/sqrt/powf/exp/ln exercised, so
    /// every composition path contributes. Lives in a test (not the hot
    /// per-op path) on purpose.
    #[test]
    fn t3_t4_are_fully_index_symmetric() {
        let prog = GnarlyProgram::fixture();
        // 3! = 6 permutations of three indices.
        let perms3: [[usize; 3]; 6] = [
            [0, 1, 2],
            [0, 2, 1],
            [1, 0, 2],
            [1, 2, 0],
            [2, 0, 1],
            [2, 1, 0],
        ];
        // 4! = 24 permutations of four indices.
        let perms4: [[usize; 4]; 24] = [
            [0, 1, 2, 3],
            [0, 1, 3, 2],
            [0, 2, 1, 3],
            [0, 2, 3, 1],
            [0, 3, 1, 2],
            [0, 3, 2, 1],
            [1, 0, 2, 3],
            [1, 0, 3, 2],
            [1, 2, 0, 3],
            [1, 2, 3, 0],
            [1, 3, 0, 2],
            [1, 3, 2, 0],
            [2, 0, 1, 3],
            [2, 0, 3, 1],
            [2, 1, 0, 3],
            [2, 1, 3, 0],
            [2, 3, 0, 1],
            [2, 3, 1, 0],
            [3, 0, 1, 2],
            [3, 0, 2, 1],
            [3, 1, 0, 2],
            [3, 1, 2, 0],
            [3, 2, 0, 1],
            [3, 2, 1, 0],
        ];
        for row in 0..prog.n_rows() {
            let t = evaluate_program(&prog, row).expect("gnarly tower");
            let scale_t3 =
                t.t3.iter()
                    .flatten()
                    .flatten()
                    .fold(0.0_f64, |m, x| m.max(x.abs()))
                    .max(1.0);
            let scale_t4 =
                t.t4.iter()
                    .flatten()
                    .flatten()
                    .flatten()
                    .fold(0.0_f64, |m, x| m.max(x.abs()))
                    .max(1.0);
            for i in 0..3 {
                for j in 0..3 {
                    for k in 0..3 {
                        let base = t.t3[i][j][k];
                        let idx = [i, j, k];
                        for p in &perms3 {
                            let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
                            assert!(
                                (base - permed).abs() <= 1e-12 * scale_t3,
                                "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
                                 permuted {permed:+.15e} under {p:?}"
                            );
                        }
                        for l in 0..3 {
                            let base4 = t.t4[i][j][k][l];
                            let idx4 = [i, j, k, l];
                            for p in &perms4 {
                                let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
                                assert!(
                                    (base4 - permed).abs() <= 1e-12 * scale_t4,
                                    "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
                                     permuted {permed:+.15e} under {p:?}"
                                );
                            }
                        }
                    }
                }
            }
        }
    }
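
    // The hand-listed `perms3`/`perms4` tables above could instead be
    // generated. A std-only sketch using the iterative form of Heap's
    // algorithm (one swap between consecutive permutations);
    // `heap_sketch_permutations` is a hypothetical helper, not crate API.
    fn heap_sketch_permutations(n: usize) -> Vec<Vec<usize>> {
        let mut a: Vec<usize> = (0..n).collect();
        // c[i] counts the swaps already performed at level i.
        let mut c = vec![0usize; n];
        let mut out = vec![a.clone()];
        let mut i = 1;
        while i < n {
            if c[i] < i {
                // Even levels swap with position 0, odd levels with position c[i].
                if i % 2 == 0 {
                    a.swap(0, i);
                } else {
                    a.swap(c[i], i);
                }
                out.push(a.clone());
                c[i] += 1;
                i = 1;
            } else {
                c[i] = 0;
                i += 1;
            }
        }
        out
    }

    #[test]
    fn heap_sketch_yields_all_distinct_permutations() {
        for (n, want) in [(3usize, 6usize), (4, 24)] {
            let mut perms = heap_sketch_permutations(n);
            assert_eq!(perms.len(), want);
            perms.sort();
            perms.dedup();
            assert_eq!(perms.len(), want, "permutations of {n} must be distinct");
        }
    }

    // The literal tables keep the symmetry test checkable by eye; a generator
    // like this one trades that for not having to audit 24 rows by hand.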
}

/// Scaled complementary error function `erfcx(x) = exp(x²)·erfc(x)`, for use
/// on `x ≥ 0`. Finite `x ≤ 0` returns `erfcx(0) = 1`; `+∞` maps to `0` and
/// `−∞` to `+∞`. For `x ≥ 26` a truncated asymptotic series replaces the
/// direct product.
#[inline]
fn erfcx_nonnegative(x: f64) -> f64 {
    if !x.is_finite() {
        return if x.is_sign_positive() {
            0.0
        } else {
            f64::INFINITY
        };
    }
    if x <= 0.0 {
        return 1.0;
    }
    if x < 26.0 {
        ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
    } else {
        let inv = 1.0 / x;
        let inv2 = inv * inv;
        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
            + 6.5625 * inv2 * inv2 * inv2 * inv2;
        inv * poly / std::f64::consts::PI.sqrt()
    }
}
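
// A worked check of the `x ≥ 26` constants, assuming the standard asymptotic
// expansion of the scaled complementary error function:
//     erfcx(x) ~ 1/(x·√π) · Σₙ (−1)ⁿ·(2n − 1)!! / (2x²)ⁿ
// For n = 0..4 the coefficients of x⁻²ⁿ are
//     1, −1/2, +3/4, −15/8, +105/16 = 6.5625,
// which are exactly the `poly` terms above. The first omitted term,
// 945/32 · x⁻¹⁰, is about 2e-13 relative at x = 26.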

/// `log(1 − exp(−a))` for `a ≥ 0`, returning `−∞` at `a = 0`. Panics on
/// negative or NaN `a`.
#[inline]
fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    if a > core::f64::consts::LN_2 {
        (-(-a).exp()).ln_1p()
    } else if a > 0.0 {
        (-(-a).exp_m1()).ln()
    } else {
        f64::NEG_INFINITY
    }
}
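
// Why the branch sits at a = ln 2 (the usual argument for this split): for
// a > ln 2, e^{−a} < 1/2, so `ln_1p(−e^{−a})` sees a small, well-conditioned
// argument. For 0 < a ≤ ln 2, 1 − e^{−a} is a difference of nearby numbers, so
// it is formed as `−exp_m1(−a)`, which has no cancellation, and then logged.
// Example: at a = 1e-10, 1 − e^{−a} ≈ 1e-10; the `exp_m1` route keeps full
// relative precision, whereas `1.0 - (-a).exp()` keeps only about 6 digits.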

/// Returns `(log Φ(x), λ(x))` with `λ = φ(x)/Φ(x)` (the ratio usually called
/// the inverse Mills ratio). For `x < 0` both are evaluated through `erfcx`, so
/// they stay finite far into the lower tail; `±∞` and NaN are special-cased.
#[inline]
fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
    if x == f64::INFINITY {
        return (0.0, 0.0);
    }
    if x == f64::NEG_INFINITY {
        return (f64::NEG_INFINITY, f64::INFINITY);
    }
    if x.is_nan() {
        return (f64::NAN, f64::NAN);
    }
    if x < 0.0 {
        let u = -x / std::f64::consts::SQRT_2;
        let ex = erfcx_nonnegative(u).max(1e-300);
        let log_cdf = -u * u + (0.5 * ex).ln();
        let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
        (log_cdf, lambda)
    } else {
        let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
        let lambda = crate::probability::normal_pdf(x) / cdf;
        (cdf.ln(), lambda)
    }
}
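
// Worked identities behind the `x < 0` branch, with u = −x/√2 > 0:
//     Φ(x) = ½·erfc(u) = ½·e^{−u²}·erfcx(u)
//     ⇒ log Φ(x) = −u² + ln(½·erfcx(u))
//     λ(x) = φ(x)/Φ(x) = (e^{−x²/2}/√(2π)) / (½·e^{−u²}·erfcx(u)) = √(2/π) / erfcx(u),
// because u² = x²/2 makes the exponentials cancel. Neither form underflows for
// very negative finite x, unlike taking `normal_cdf(x).ln()` directly.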

/// Stable derivative stack `[f, f′, f″, f‴, f⁗]` for `f(x) = log Φ(x)` through
/// fourth order, with every derivative written in `x` and `λ = φ(x)/Φ(x)`.
#[inline]
pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
    let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
    let lambda2 = lambda * lambda;
    let lambda3 = lambda2 * lambda;
    let x2 = x * x;
    [
        log_cdf,
        lambda,
        -lambda * (x + lambda),
        lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
        -lambda
            * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
    ]
}

/// Stable derivative stack for `log(1 − exp(−x))`, `x > 0`, through fourth
/// order, written in `r = 1/(eˣ − 1)` (computed via `exp_m1`).
#[inline]
pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
    let r = 1.0 / x.exp_m1();
    [
        log1mexp_positive(x),
        r,
        -r * (1.0 + r),
        r * (1.0 + r) * (1.0 + 2.0 * r),
        -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
    ]
}
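
// Same pattern for f(x) = log(1 − e^{−x}): with r = 1/(eˣ − 1), f′ = r and
// r′ = −eˣ·r² = −r(1 + r) (using eˣ = 1 + 1/r), so
//     f″ = −r(1 + r),   f‴ = r(1 + r)(1 + 2r),   f⁗ = −r(1 + r)(1 + 6r + 6r²),
// where f⁗ comes from differentiating f‴ = r + 3r² + 2r³.
// Example: at x = ln 2, r = 1 and the stack is [−ln 2, 1, −2, 6, −26].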

// ── The RowJet bridge oracle (CI) ─────────────────────────────────────
#[cfg(test)]
mod rowjet_bridge_tests {
    use super::*;
    use crate::jet_scalar::{JetScalar, Order2};

    /// A toy row-NLL written ONCE over the [`RowJet`] bridge: a product, a sum, a
    /// subtraction, a scale/neg, a constant, per-row `scale_rows` data, and two
    /// value-distinct `compose_unary_with` stacks (an exp stack and the smooth,
    /// finite-everywhere `√(1 + u²)` stack). The domain `guard` is exercised
    /// separately, by `guard_reports_per_lane_failures`. The body is generic over
    /// `R: RowJet<2>`, so the SAME source instantiates at the scalar jets and the
    /// `f64x4` lane towers.
    struct ToyProgram {
        primaries: Vec<[f64; 2]>,
        /// Per-row CONTINUOUS auxiliary data `[cov, z, wi]` — the survival
        /// `covariance_ones` / `z_sum` / observation-weight analogues that enter
        /// the jet algebra as `.scale_rows(per_row_value)`, distinct per lane.
        aux: Vec<[f64; 3]>,
    }

    impl ToyProgram {
        /// The body uses `pack_rows` to gather the per-lane continuous data from
        /// the lane→row map and `scale_rows` to fold it in — so a 4-row batch
        /// carries four DISTINCT cov/z/wi, which the single-`f64` `scale` could not.
        fn body<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> R {
            let cov = R::pack_rows(rows, |r| self.aux[r][0]);
            let z = R::pack_rows(rows, |r| self.aux[r][1]);
            let wi = R::pack_rows(rows, |r| self.aux[r][2]);

            let a = p[0].mul(&p[1]).scale_rows(cov);
            let b = a.add(&R::constant(0.5)).sub(&p[0].scale(0.25));
            let c = b
                .compose_unary_with(|u| {
                    let e = u.exp();
                    [e, e, e, e, e]
                })
                .scale_rows(z);
            let d = c.neg().add(&p[0]);
            let e = d
                .compose_unary_with(|u| {
                    let s = (1.0 + u * u).sqrt();
                    let s3 = s * s * s;
                    let s5 = s3 * s * s;
                    let s7 = s5 * s * s;
                    [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
                })
                .scale_rows(wi);
            e.mul(&p[1]).add(&e)
        }
    }

    impl RowNllProgramRowJet<2> for ToyProgram {
        fn n_rows(&self) -> usize {
            self.primaries.len()
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            Ok(self.primaries[row])
        }
        fn row_nll<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> Result<R, String> {
            assert!(
                rows.len() == 1 || rows.len() == 4,
                "lane→row map is 1 or 4 wide"
            );
            Ok(self.body(rows, p))
        }
    }
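
    // A std-only sketch of the "write the body once, instantiate per lane
    // width" pattern `ToyProgram` relies on, reduced to values (no derivative
    // channels). `LaneSketch` and its `lane_*` methods are hypothetical
    // stand-ins for `RowJet`, implemented for a scalar `f64` and a 4-lane
    // `[f64; 4]`. Because both instantiations perform the same IEEE operations
    // in the same order, lane `i` of the batch is bit-identical to row `i`.
    trait LaneSketch: Copy {
        fn lane_splat(x: f64) -> Self;
        /// Gather one per-row value per lane through the lane→row map.
        fn lane_pack(rows: &[usize], f: impl Fn(usize) -> f64) -> Self;
        fn lane_mul(self, o: Self) -> Self;
        fn lane_add(self, o: Self) -> Self;
        fn lane_exp(self) -> Self;
    }

    impl LaneSketch for f64 {
        fn lane_splat(x: f64) -> Self {
            x
        }
        fn lane_pack(rows: &[usize], f: impl Fn(usize) -> f64) -> Self {
            f(rows[0])
        }
        fn lane_mul(self, o: Self) -> Self {
            self * o
        }
        fn lane_add(self, o: Self) -> Self {
            self + o
        }
        fn lane_exp(self) -> Self {
            f64::exp(self)
        }
    }

    impl LaneSketch for [f64; 4] {
        fn lane_splat(x: f64) -> Self {
            [x; 4]
        }
        fn lane_pack(rows: &[usize], f: impl Fn(usize) -> f64) -> Self {
            std::array::from_fn(|i| f(rows[i]))
        }
        fn lane_mul(self, o: Self) -> Self {
            std::array::from_fn(|i| self[i] * o[i])
        }
        fn lane_add(self, o: Self) -> Self {
            std::array::from_fn(|i| self[i] + o[i])
        }
        fn lane_exp(self) -> Self {
            self.map(f64::exp)
        }
    }

    /// The body, written once: exp(x·w) + 0.5 with a per-row weight `w`.
    fn lane_sketch_body<R: LaneSketch>(rows: &[usize], x: R, w_of: &[f64]) -> R {
        let w = R::lane_pack(rows, |r| w_of[r]);
        x.lane_mul(w).lane_exp().lane_add(R::lane_splat(0.5))
    }

    #[test]
    fn lane_sketch_lane_i_matches_scalar_row_i() {
        let xs: [f64; 4] = [0.3, -1.0, 0.0, 2.5];
        let ws: [f64; 4] = [1.5, 0.2, -0.7, 0.1];
        let batch = lane_sketch_body(&[0, 1, 2, 3], xs, &ws);
        for row in 0..4 {
            let scalar = lane_sketch_body(&[row], xs[row], &ws);
            assert_eq!(batch[row].to_bits(), scalar.to_bits(), "row {row}");
        }
    }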

    fn assert_t4_bits_eq(a: &Tower4<2>, b: &Tower4<2>, ctx: &str) {
        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
        for i in 0..2 {
            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
            for j in 0..2 {
                assert_eq!(
                    a.h[i][j].to_bits(),
                    b.h[i][j].to_bits(),
                    "{ctx}: h[{i}][{j}]"
                );
                for k in 0..2 {
                    assert_eq!(
                        a.t3[i][j][k].to_bits(),
                        b.t3[i][j][k].to_bits(),
                        "{ctx}: t3[{i}][{j}][{k}]"
                    );
                    for l in 0..2 {
                        assert_eq!(
                            a.t4[i][j][k][l].to_bits(),
                            b.t4[i][j][k][l].to_bits(),
                            "{ctx}: t4[{i}][{j}][{k}][{l}]"
                        );
                    }
                }
            }
        }
    }

    fn assert_t3_bits_eq(a: &Tower3<2>, b: &Tower3<2>, ctx: &str) {
        assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
        for i in 0..2 {
            assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
            for j in 0..2 {
                assert_eq!(
                    a.h[i][j].to_bits(),
                    b.h[i][j].to_bits(),
                    "{ctx}: h[{i}][{j}]"
                );
                for k in 0..2 {
                    assert_eq!(
                        a.t3[i][j][k].to_bits(),
                        b.t3[i][j][k].to_bits(),
                        "{ctx}: t3[{i}][{j}][{k}]"
                    );
                }
            }
        }
    }

    // Deterministic 64-bit LCG (Knuth's MMIX multiplier and increment). `val`
    // returns +0.0 and −0.0 about 4% of the time each (signed-zero injection),
    // otherwise a uniform draw in [−2.5, 2.5); successive calls fill successive
    // lanes, so each lane gets its own value.
    struct Lcg(u64);
    impl Lcg {
        fn next(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
        fn val(&mut self) -> f64 {
            let u = self.next();
            if u < 0.04 {
                return 0.0;
            }
            if u < 0.08 {
                return -0.0;
            }
            (self.next() - 0.5) * 5.0
        }
    }
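
    // One reason the ±0.0 injection pairs with bitwise checks: IEEE-754 `==`
    // treats `0.0` and `-0.0` as equal while `to_bits` does not, so a lane that
    // flips the sign of a zero is only caught by the `assert_*_bits_eq`
    // helpers above. A std-only illustration; `signed_zero_sketch_hidden_by_eq`
    // is a hypothetical helper.
    fn signed_zero_sketch_hidden_by_eq(a: f64, b: f64) -> bool {
        a == b && a.to_bits() != b.to_bits()
    }

    #[test]
    fn signed_zero_sketch() {
        assert!(signed_zero_sketch_hidden_by_eq(0.0, -0.0));
        assert!(!signed_zero_sketch_hidden_by_eq(1.5, 1.5));
        // The sign is observable downstream: 1 / −0 is −∞.
        assert_eq!(1.0 / -0.0_f64, f64::NEG_INFINITY);
    }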

    /// Lane `i` of the batched order-4 / order-3 tower is `to_bits`-identical to
    /// the scalar tower on row `i`, for ≥2000 distinct 4-row batches with
    /// signed-zero and per-lane-distinct primaries.
    #[test]
    fn batched_lane_i_matches_scalar_row_i_bit_identical() {
        let mut rng = Lcg(0xA5A5_1234_DEAD_BEEF);
        let mut batches = 0usize;
        for _ in 0..2500 {
            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
            // per-lane-DISTINCT continuous aux (cov/z/wi), signed-zero injected.
            let aux: [[f64; 3]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
            let prog = ToyProgram {
                primaries: bases.to_vec(),
                aux: aux.to_vec(),
            };
            let rows = [0usize, 1, 2, 3];

            // order-4 batch vs scalar Tower4 (instantiated through the same body).
            let batch4 = generic_batched_fourth_tower(&prog, rows).expect("batch4");
            for (row, base) in bases.iter().enumerate() {
                let vars: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
                let scal = prog.row_nll(&[row], &vars).expect("scalar tower4");
                assert_t4_bits_eq(&batch4.lane(row), &scal, "batched_fourth");
            }

            // order-3 batch vs scalar Tower3.
            let batch3 = generic_batched_third_tower(&prog, rows).expect("batch3");
            for (row, base) in bases.iter().enumerate() {
                let vars: [Tower3<2>; 2] =
                    std::array::from_fn(|a| <Tower3<2> as RowJet<2>>::variable(base[a], a));
                let scal = prog.row_nll(&[row], &vars).expect("scalar tower3");
                assert_t3_bits_eq(&batch3.lane(row), &scal, "batched_third");
            }
            batches += 1;
        }
        assert_eq!(batches, 2500);
    }

    /// The blanket impl does not churn the scalar path: the body driven through
    /// `RowJet` ops is `to_bits`-identical to the body driven directly through
    /// `JetScalar` ops, and `rowjet_row_kernel`'s `(v, g, H)` matches the dense
    /// `Tower4` lower channels.
    #[test]
    fn blanket_scalar_path_is_unchanged_and_consistent() {
        let mut rng = Lcg(0x0BAD_F00D_1357_2468);
        for _ in 0..3000 {
            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
            let aux0: [f64; 3] = std::array::from_fn(|_| rng.val());
            let prog = ToyProgram {
                primaries: vec![base],
                aux: vec![aux0],
            };

            // (a) RowJet-driven body == JetScalar-driven body, bit-for-bit. The
            // reference body uses `scale(f64)` where the RowJet body uses
            // `scale_rows(f64)` — proving the scalar `scale_rows` rewrite does not
            // churn the path (`scale_rows(s) == scale(s)` on `Value = f64`).
            let via_rowjet: Tower4<2> = {
                let vars: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
                prog.row_nll(&[0], &vars).expect("rowjet")
            };
            let via_jetscalar: Tower4<2> = {
                let vars: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as JetScalar<2>>::variable(base[a], a));
                let (cov, z, wi) = (aux0[0], aux0[1], aux0[2]);
                // The body using JetScalar's own ops + scale(f64) directly.
                let a = vars[0].mul(&vars[1]).scale(cov);
                let b = a.add(&Tower4::constant(0.5)).sub(&vars[0].scale(0.25));
                let c = b
                    .compose_unary_with(|u| {
                        let e = u.exp();
                        [e, e, e, e, e]
                    })
                    .scale(z);
                let d = JetScalar::neg(&c).add(&vars[0]);
                let e = d
                    .compose_unary_with(|u| {
                        let s = (1.0 + u * u).sqrt();
                        let s3 = s * s * s;
                        let s5 = s3 * s * s;
                        let s7 = s5 * s * s;
                        [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
                    })
                    .scale(wi);
                e.mul(&vars[1]).add(&e)
            };
            assert_t4_bits_eq(&via_rowjet, &via_jetscalar, "blanket_vs_direct");

            // (b) rowjet_row_kernel (v,g,H) == dense Tower4 lower channels.
            // Order2 and Tower4 use different internal representations, so
            // signed-zero differences (−0.0 vs +0.0) may arise in gradient/
            // Hessian channels that evaluate to exactly zero; IEEE equality
            // treats these as equal, so `==` is the right comparison here.
            let (v, g, h) = rowjet_row_kernel(&prog, 0).expect("kernel");
            assert_eq!(v.to_bits(), via_rowjet.v.to_bits(), "kernel v");
            for i in 0..2 {
                assert!(
                    g[i] == via_rowjet.g[i],
                    "kernel g[{i}]: {} vs {}",
                    g[i],
                    via_rowjet.g[i]
                );
                for j in 0..2 {
                    assert!(
                        h[i][j] == via_rowjet.h[i][j],
                        "kernel h[{i}][{j}]: {} vs {}",
                        h[i][j],
                        via_rowjet.h[i][j]
                    );
                }
            }

            // (c) the Order2 scalar IS a RowJet via the blanket.
            let o2: [Order2<2>; 2] =
                std::array::from_fn(|a| <Order2<2> as RowJet<2>>::variable(base[a], a));
            let via_order2 = prog.body(&[0], &o2);
            assert_eq!(
                via_order2.0.v.to_bits(),
                via_rowjet.v.to_bits(),
                "Order2 blanket value channel must match the dense Tower4 program body"
            );
        }
    }

    /// On the scalar path (`Value = f64`) `scale_rows(s)` is `to_bits`-identical
    /// to `scale(s)` for EVERY channel — so rewriting a survival `.scale(per_row)`
    /// to `.scale_rows(per_row)` cannot perturb the existing scalar fits.
    #[test]
    fn scale_rows_scalar_is_bit_identical_to_scale() {
        let mut rng = Lcg(0xFEED_FACE_0042_1001);
        for _ in 0..3000 {
            let base: [f64; 2] = std::array::from_fn(|_| rng.val());
            let s = rng.val();
            // Build a dense tower with populated channels (exp of a product).
            let vars: [Tower4<2>; 2] =
                std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
            let jet = vars[0].mul(&vars[1]).compose_unary_with(|u| {
                let e = u.exp();
                [e, e, e, e, e]
            });
            let via_scale = RowJet::scale(&jet, s);
            let via_scale_rows = RowJet::scale_rows(&jet, s);
            assert_t4_bits_eq(&via_scale_rows, &via_scale, "scale_rows==scale");
        }
    }

    /// `scale_rows` on a batch multiplies lane `i` by `s[i]`, so lane `i` of a
    /// per-lane-scaled batch matches the scalar `scale(s[i])` on row `i` — the
    /// continuous per-row data path the single-`f64` `scale` could not carry.
    #[test]
    fn batched_scale_rows_matches_per_row_scalar_scale() {
        let mut rng = Lcg(0x1357_9BDF_2468_ACE0);
        for _ in 0..2500 {
            let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
            let s: [f64; 4] = std::array::from_fn(|_| rng.val());
            let batch: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
                Tower4Batch::variable(
                    wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]),
                    a,
                )
            });
            let prod = batch[0].mul(&batch[1]).compose_unary_with(|u| {
                let e = u.exp();
                [e, e, e, e, e]
            });
            let scaled = prod.scale_rows(s);
            for (row, base) in bases.iter().enumerate() {
                let v: [Tower4<2>; 2] =
                    std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
                let prod_s = v[0].mul(&v[1]).compose_unary_with(|u| {
                    let e = u.exp();
                    [e, e, e, e, e]
                });
                let ref_s = RowJet::scale(&prod_s, s[row]);
                assert_t4_bits_eq(&scaled.lane(row), &ref_s, "batched_scale_rows");
            }
        }
    }

    /// The per-lane guard reports exactly the failing lanes on a batch and the
    /// single lane on a scalar jet.
    #[test]
    fn guard_reports_per_lane_failures() {
        let cols: [[f64; 2]; 4] = [[1.0, 0.5], [-2.0, 0.5], [3.0, 0.5], [-0.0, 0.5]];
        let vars: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
            Tower4Batch::variable(
                wide::f64x4::new([cols[0][a], cols[1][a], cols[2][a], cols[3][a]]),
                a,
            )
        });
        let verdict = vars[0].guard(|v| v > 0.0);
        assert_eq!(verdict.lanes(), 4);
        assert!(verdict.any_failed());
        assert!(!verdict.all_pass());
        assert!(!verdict.lane_failed(0));
        assert!(verdict.lane_failed(1));
        assert!(!verdict.lane_failed(2));
        assert!(verdict.lane_failed(3));
        assert_eq!(verdict.failed_mask(), 0b1010);

        let s_ok = <Tower4<2> as RowJet<2>>::variable(1.0, 0);
        let s_bad = <Tower4<2> as RowJet<2>>::variable(-1.0, 0);
        assert!(RowJet::guard(&s_ok, |v| v > 0.0).all_pass());
        assert!(RowJet::guard(&s_bad, |v| v > 0.0).any_failed());
        assert_eq!(RowJet::guard(&s_ok, |v| v > 0.0).lanes(), 1);
    }
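
    // A std-only sketch of a per-lane verdict packed as a bitmask (bit i set
    // ⇔ lane i failed), the encoding in which the `failed_mask() == 0b1010`
    // expectation above reads naturally. `mask_sketch_failed` is hypothetical;
    // the crate's verdict type may store lanes differently.
    fn mask_sketch_failed(lanes: &[f64], pass: impl Fn(f64) -> bool) -> u32 {
        // Assumes at most 32 lanes (one bit per lane in a u32).
        lanes
            .iter()
            .enumerate()
            .filter(|&(_, &v)| !pass(v))
            .fold(0u32, |mask, (i, _)| mask | (1u32 << i))
    }

    #[test]
    fn mask_sketch_matches_guard_expectation() {
        // −0.0 > 0.0 is false, so lane 3 fails exactly like lane 1.
        assert_eq!(mask_sketch_failed(&[1.0, -2.0, 3.0, -0.0], |v| v > 0.0), 0b1010);
        // A one-lane (scalar) verdict that passes has an empty mask.
        assert_eq!(mask_sketch_failed(&[1.0], |v| v > 0.0), 0);
    }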

    // ── ln_gamma_derivative_stack / digamma_derivative_stack / trigamma_derivative_stack ──

    #[test]
    fn ln_gamma_derivative_stack_known_values_at_1() {
        let s = ln_gamma_derivative_stack(1.0);
        // ln Γ(1) = 0; statrs uses a Lanczos approximation, so the result is
        // zero only up to ULP noise.
        assert!(s[0].abs() < 1e-14, "ln_gamma(1) must be ~0, got {}", s[0]);
        // ψ₀(1) = -γ (Euler–Mascheroni)
        let euler_mascheroni = 0.577_215_664_901_532_9_f64;
        assert!(
            (s[1] + euler_mascheroni).abs() < 1e-10,
            "digamma(1) ≈ -{euler_mascheroni:.6}, got {}",
            s[1]
        );
        // ψ₁(1) = π²/6
        let pi2_6 = std::f64::consts::PI * std::f64::consts::PI / 6.0;
        assert!(
            (s[2] - pi2_6).abs() < 1e-10,
            "trigamma(1) ≈ {pi2_6:.6}, got {}",
            s[2]
        );
    }

    #[test]
    fn ln_gamma_derivative_stack_known_values_at_2() {
        let s = ln_gamma_derivative_stack(2.0);
        // ln Γ(2) = ln(1) = 0 exactly
        assert!(s[0].abs() < 1e-14, "ln_gamma(2) must be 0, got {}", s[0]);
        // ψ₀(2) = 1 − γ (recurrence: ψ₀(x+1) = ψ₀(x) + 1/x)
        let euler_mascheroni = 0.577_215_664_901_532_9_f64;
        let digamma_2 = 1.0 - euler_mascheroni;
        assert!(
            (s[1] - digamma_2).abs() < 1e-10,
            "digamma(2) ≈ {digamma_2:.6}, got {}",
            s[1]
        );
    }
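
    // A std-only sketch of the recurrence quoted above, used as an independent
    // reference for ψ₀: shift x upward with ψ₀(x) = ψ₀(x + 1) − 1/x until
    // x ≥ 10, then apply the standard asymptotic series
    //     ψ₀(x) ≈ ln x − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶) + 1/(240x⁸).
    // `digamma_sketch` is hypothetical (valid for x > 0 only), not the crate's
    // stack implementation.
    fn digamma_sketch(mut x: f64) -> f64 {
        let mut acc = 0.0_f64;
        while x < 10.0 {
            acc -= 1.0 / x;
            x += 1.0;
        }
        let r = 1.0 / (x * x);
        // Horner form of −r/12 + r²/120 − r³/252 + r⁴/240 with r = 1/x².
        let tail = -r * (1.0 / 12.0 - r * (1.0 / 120.0 - r * (1.0 / 252.0 - r / 240.0)));
        acc + x.ln() - 0.5 / x + tail
    }

    #[test]
    fn digamma_sketch_reproduces_the_known_values() {
        let euler_mascheroni = 0.577_215_664_901_532_9_f64;
        assert!((digamma_sketch(1.0) + euler_mascheroni).abs() < 1e-10);
        assert!((digamma_sketch(2.0) - (1.0 - euler_mascheroni)).abs() < 1e-10);
        // The recurrence itself: ψ₀(x + 1) − ψ₀(x) = 1/x.
        assert!((digamma_sketch(3.5) - digamma_sketch(2.5) - 1.0 / 2.5).abs() < 1e-12);
    }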

    #[test]
    fn ln_gamma_derivative_stack_order2_is_prefix() {
        for &x in &[0.5_f64, 1.0, 2.0, 5.0] {
            let full = ln_gamma_derivative_stack(x);
            let ord2 = ln_gamma_derivative_stack_order2(x);
            assert_eq!(ord2[0], full[0], "order2[0] != full[0] at x={x}");
            assert_eq!(ord2[1], full[1], "order2[1] != full[1] at x={x}");
            assert_eq!(ord2[2], full[2], "order2[2] != full[2] at x={x}");
        }
    }

    #[test]
    fn digamma_derivative_stack_overlaps_ln_gamma_stack() {
        // The two stacks share a run of four polygamma values:
        // ln_gamma_stack[1..5] == digamma_stack[0..4]
        for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
            let lg = ln_gamma_derivative_stack(x);
            let dg = digamma_derivative_stack(x);
            for i in 0..4 {
                assert_eq!(
                    lg[i + 1],
                    dg[i],
                    "ln_gamma_stack[{}] != digamma_stack[{}] at x={x}",
                    i + 1,
                    i
                );
            }
        }
    }

    #[test]
    fn trigamma_derivative_stack_overlaps_digamma_stack() {
        // digamma_stack[1..5] == trigamma_stack[0..4]
        for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
            let dg = digamma_derivative_stack(x);
            let tg = trigamma_derivative_stack(x);
            for i in 0..4 {
                assert_eq!(
                    dg[i + 1],
                    tg[i],
                    "digamma_stack[{}] != trigamma_stack[{}] at x={x}",
                    i + 1,
                    i
                );
            }
        }
    }

    #[test]
    fn derivative_stacks_all_finite_at_positive_inputs() {
        for &x in &[0.01_f64, 0.5, 1.0, 2.0, 10.0, 100.0] {
            for v in ln_gamma_derivative_stack(x) {
                assert!(v.is_finite(), "ln_gamma_stack non-finite at x={x}: {v}");
            }
            for v in digamma_derivative_stack(x) {
                assert!(v.is_finite(), "digamma_stack non-finite at x={x}: {v}");
            }
            for v in trigamma_derivative_stack(x) {
                assert!(v.is_finite(), "trigamma_stack non-finite at x={x}: {v}");
            }
        }
    }
}

// ── Contraction-symmetry optimization gate ────────────────────────────────────
//
// `Tower4::third_contracted` / `fourth_contracted` contract the (fully
// index-symmetric) `t3`/`t4` tensors against directions, leaving the output
// indices `(a, b)` / `(i, j)` free. Those free indices inherit the tensor's
// symmetry — `out[a][b] == out[b][a]` term-for-term — so only the upper triangle
// needs to be summed and the lower triangle can be mirrored. Unlike the dense
// symmetric FILL (which needs a K⁴ scatter, loses inner-loop vectorisation, and
// was measured SLOWER), the mirror here is a tiny K×K copy and the inner
// contraction is untouched (contiguous, vectorisable). When `t3`/`t4` are exactly
// symmetric, as the randomized gate below constructs them, this is BIT-IDENTICAL
// to the full nest, so it needs no fingerprint re-baseline; the gate is
// (1) bit-identity vs the full reference and (2) a measured wall-clock that is
// not slower.
#[cfg(test)]
mod contraction_symmetry_tests {
    use super::*;
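
    // A std-only sketch of the upper-triangle-plus-mirror contraction the
    // header comment describes, on a plain `[[[f64; K]; K]; K]` array. It
    // assumes an exactly symmetric `t3` and is an illustration, not the
    // crate's `third_contracted` (whose body is not shown here); the
    // `mirror_sketch_*` names are hypothetical.
    fn mirror_sketch_third<const K: usize>(
        t3: &[[[f64; K]; K]; K],
        dir: &[f64; K],
    ) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in a..K {
                // Same contiguous inner sum, in the same order, as the full nest.
                let mut acc = 0.0;
                for c in 0..K {
                    acc += t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
                // Mirror: valid because t3[b][a][c] == t3[a][b][c] bit-for-bit.
                out[b][a] = acc;
            }
        }
        out
    }

    /// An exactly symmetric 3×3×3 tensor: each entry is built only from
    /// symmetric integer functions of {a, b, c}, so permuted indices give
    /// identical bits.
    fn mirror_sketch_tensor() -> [[[f64; 3]; 3]; 3] {
        let mut t3 = [[[0.0; 3]; 3]; 3];
        for a in 0..3 {
            for b in 0..3 {
                for c in 0..3 {
                    let (s, p, q) = (a + b + c, a * b * c, a * a + b * b + c * c);
                    t3[a][b][c] = 0.37 * s as f64 - 0.11 * p as f64 + 0.05 * q as f64;
                }
            }
        }
        t3
    }

    #[test]
    fn mirror_sketch_is_bitwise_full_nest_on_exact_symmetric_tensor() {
        let t3 = mirror_sketch_tensor();
        let dir = [0.6, -0.2, 1.3];
        let sym = mirror_sketch_third(&t3, &dir);
        for a in 0..3 {
            for b in 0..3 {
                let mut full = 0.0_f64;
                for c in 0..3 {
                    full += t3[a][b][c] * dir[c];
                }
                assert_eq!(sym[a][b].to_bits(), full.to_bits(), "[{a}][{b}]");
            }
        }
    }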

    struct Rng(u64);
    impl Rng {
        fn u(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
        fn s(&mut self) -> f64 {
            (self.u() - 0.5) * 4.0
        }
    }

    /// Random VALID fully-symmetric `Tower4<K>` (symmetric `h`/`t3`/`t4`).
    fn rand_sym4<const K: usize>(r: &mut Rng) -> Tower4<K> {
        let mut t = Tower4::<K>::zero();
        t.v = r.s();
        for i in 0..K {
            t.g[i] = r.s();
        }
        for a in 0..K {
            for b in a..K {
                let v2 = r.s();
                t.h[a][b] = v2;
                t.h[b][a] = v2;
                for c in b..K {
                    let v3 = r.s();
                    for p in perms3([a, b, c]) {
                        t.t3[p[0]][p[1]][p[2]] = v3;
                    }
                    for d in c..K {
                        let v4 = r.s();
                        for p in perms4([a, b, c, d]) {
                            t.t4[p[0]][p[1]][p[2]][p[3]] = v4;
                        }
                    }
                }
            }
        }
        t
    }

    fn perms3(idx: [usize; 3]) -> [[usize; 3]; 6] {
        let [a, b, c] = idx;
        [
            [a, b, c],
            [a, c, b],
            [b, a, c],
            [b, c, a],
            [c, a, b],
            [c, b, a],
        ]
    }
    fn perms4(idx: [usize; 4]) -> [[usize; 4]; 24] {
        let [a, b, c, d] = idx;
        [
            [a, b, c, d],
            [a, b, d, c],
            [a, c, b, d],
            [a, c, d, b],
            [a, d, b, c],
            [a, d, c, b],
            [b, a, c, d],
            [b, a, d, c],
            [b, c, a, d],
            [b, c, d, a],
            [b, d, a, c],
            [b, d, c, a],
            [c, a, b, d],
            [c, a, d, b],
            [c, b, a, d],
            [c, b, d, a],
            [c, d, a, b],
            [c, d, b, a],
            [d, a, b, c],
            [d, a, c, b],
            [d, b, a, c],
            [d, b, c, a],
            [d, c, a, b],
            [d, c, b, a],
        ]
    }

    /// Full-nest reference for the third-order contraction (the
    /// pre-optimization form, looping over every `a, b ∈ 0..K`):
    /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]`.
    fn third_full<const K: usize>(t: &Tower4<K>, dir: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in 0..K {
                let mut acc = 0.0;
                for c in 0..K {
                    acc += t.t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
            }
        }
        out
    }
    /// Full-nest reference for the fourth-order contraction:
    /// `out[i][j] = Σ_{k,l} t4[i][j][k][l] · u[k] · w[l]`, accumulated with `k`
    /// as the outer and `l` as the inner loop.
    fn fourth_full<const K: usize>(t: &Tower4<K>, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                for k in 0..K {
                    for l in 0..K {
                        acc += t.t4[i][j][k][l] * u[k] * w[l];
                    }
                }
                out[i][j] = acc;
            }
        }
        out
    }
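
    // Hypothetical sketch of the output-symmetric idea on plain arrays. This is
    // NOT the crate's `Tower4::third_contracted`, whose internals are not shown
    // here; the `sketch_*` names are illustrative. Compute only the `a ≤ b` half
    // of the output and mirror it. When `t3` is symmetric, `t3[b][a][c]` and
    // `t3[a][b][c]` hold the same bits, so each mirrored entry is the same sum
    // of the same terms in the same order and matches the full nest bit for bit,
    // while only K(K+1)/2 of the K² entries are computed (45 of 81 at K = 9).
    // The fourth-order contraction can be mirrored the same way.
    fn sketch_third_upper<const K: usize>(
        t3: &[[[f64; K]; K]; K],
        dir: &[f64; K],
    ) -> ([[f64; K]; K], usize) {
        let mut out = [[0.0; K]; K];
        let mut computed = 0;
        for a in 0..K {
            for b in a..K {
                let mut acc = 0.0;
                for c in 0..K {
                    acc += t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
                out[b][a] = acc; // mirror: only valid for a symmetric `t3`
                computed += 1;
            }
        }
        (out, computed)
    }

    fn sketch_third_full<const K: usize>(
        t3: &[[[f64; K]; K]; K],
        dir: &[f64; K],
    ) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in 0..K {
                let mut acc = 0.0;
                for c in 0..K {
                    acc += t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
            }
        }
        out
    }

    /// Symmetric test tensor: depends on `(a, b, c)` only through their integer
    /// sum and product, so every index permutation yields identical bits.
    fn sketch_sym_t3<const K: usize>() -> [[[f64; K]; K]; K] {
        std::array::from_fn(|a| {
            std::array::from_fn(|b| {
                std::array::from_fn(|c| 1.0 / (1.0 + (a + b + c) as f64) + 0.3 * (a * b * c) as f64)
            })
        })
    }

    #[test]
    fn sketch_upper_half_contraction_is_bit_identical() {
        const K: usize = 9;
        let t3 = sketch_sym_t3::<K>();
        let dir: [f64; K] = std::array::from_fn(|i| (0.37 * i as f64).sin() - 0.2);
        let (sym, computed) = sketch_third_upper(&t3, &dir);
        let full = sketch_third_full(&t3, &dir);
        assert_eq!(computed, K * (K + 1) / 2);
        for a in 0..K {
            for b in 0..K {
                assert_eq!(sym[a][b].to_bits(), full[a][b].to_bits());
            }
        }
    }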

    /// Compares both contractions against their full-nest references on `n`
    /// random symmetric towers and returns the number of bit-equality
    /// comparisons performed (`n·K·K·2`). The assertions live inside this
    /// generic helper, so a loop that silently ran zero times would still pass;
    /// surfacing the count lets the call site assert that the intended workload
    /// actually ran.
    fn check_bit_identical<const K: usize>(seed: u64, n: usize) -> usize {
        let mut r = Rng(seed);
        let mut checks = 0usize;
        for _ in 0..n {
            let t = rand_sym4::<K>(&mut r);
            let dir: [f64; K] = std::array::from_fn(|_| r.s());
            let u: [f64; K] = std::array::from_fn(|_| r.s());
            let w: [f64; K] = std::array::from_fn(|_| r.s());
            let t3_sym = t.third_contracted(&dir);
            let t3_full = third_full(&t, &dir);
            let t4_sym = t.fourth_contracted(&u, &w);
            let t4_full = fourth_full(&t, &u, &w);
            for a in 0..K {
                for b in 0..K {
                    assert_eq!(
                        t3_sym[a][b].to_bits(),
                        t3_full[a][b].to_bits(),
                        "third K={K} [{a}][{b}]"
                    );
                    assert_eq!(
                        t4_sym[a][b].to_bits(),
                        t4_full[a][b].to_bits(),
                        "fourth K={K} [{a}][{b}]"
                    );
                    checks += 2;
                }
            }
        }
        checks
    }

    /// The output-symmetric contraction is BIT-IDENTICAL to the full nest across
    /// `K ∈ {2,3,4,9}`, so no fingerprint re-baseline is owed: accuracy and bits
    /// are unchanged, and this is a speed-only optimization. The check relies on
    /// input symmetry (hence `rand_sym4`): for a symmetric tower, `t3[b][a][·]`
    /// and `t3[a][b][·]` are bit-equal, so a mirrored output entry is the same
    /// sum as the full nest's, term for term and in the same order.
    #[test]
    fn contraction_symmetry_is_bit_identical_to_full_nest() {
        let checks = check_bit_identical::<2>(0x0000_0002_C0FF_EE01, 1000)
            + check_bit_identical::<3>(0x0000_0003_C0FF_EE01, 800)
            + check_bit_identical::<4>(0x0000_0004_C0FF_EE01, 600)
            + check_bit_identical::<9>(0x0000_0009_C0FF_EE01, 300);
        // Guards against the loops silently not running (which would otherwise
        // pass with zero assertions): n·K²·2 summed over the four calls,
        // 1000·2²·2 + 800·3²·2 + 600·4²·2 + 300·9²·2.
        assert_eq!(checks, 8000 + 14400 + 19200 + 48600);
    }
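
    // Hypothetical refactor sketch (`expected_bit_checks` and
    // `EXPECTED_BIT_CHECKS` are illustrative names, not existing helpers):
    // derive the expected workload from the same (K, n) pairs the test passes,
    // rather than multiplying by hand, and pin the hand-written total at
    // compile time so an arithmetic slip fails the build.
    const fn expected_bit_checks(k: usize, n: usize) -> usize {
        // Two comparisons (third, fourth) per output entry, K² entries per tower.
        n * k * k * 2
    }
    const EXPECTED_BIT_CHECKS: usize = expected_bit_checks(2, 1000)
        + expected_bit_checks(3, 800)
        + expected_bit_checks(4, 600)
        + expected_bit_checks(9, 300);
    const _: () = assert!(EXPECTED_BIT_CHECKS == 8000 + 14400 + 19200 + 48600);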

    /// Measures the wall-clock cost of the output-symmetric contraction vs the
    /// full nest at `K = 9`. The symmetric form does about 2× fewer inner
    /// contractions (only the `a ≤ b` half of each K×K output, 45 of 81 entries
    /// at `K = 9`); the bit-identity test is the correctness gate. The timing is
    /// informational, since wall-clock is noisy, and only a PATHOLOGICAL
    /// regression fails: doing strictly fewer inner contractions, the symmetric
    /// form must not be more than 1.5× slower than the full nest.
    #[test]
    fn contraction_symmetry_speedup_is_reported() {
        const K: usize = 9;
        let mut r = Rng(0xC0FF_EE99_1234_5678);
        let towers: Vec<Tower4<K>> = (0..512).map(|_| rand_sym4::<K>(&mut r)).collect();
        let dir: [f64; K] = std::array::from_fn(|_| r.s());
        let u: [f64; K] = std::array::from_fn(|_| r.s());
        let w: [f64; K] = std::array::from_fn(|_| r.s());

        let reps = 400usize;
        let t_sym = {
            let start = std::time::Instant::now();
            let mut sink = 0.0f64;
            for _ in 0..reps {
                for t in &towers {
                    let o3 = std::hint::black_box(t).third_contracted(std::hint::black_box(&dir));
                    let o4 = std::hint::black_box(t)
                        .fourth_contracted(std::hint::black_box(&u), std::hint::black_box(&w));
                    sink += o3[0][K - 1] + o4[0][K - 1];
                }
            }
            std::hint::black_box(sink);
            start.elapsed().as_secs_f64()
        };
        let t_full = {
            let start = std::time::Instant::now();
            let mut sink = 0.0f64;
            for _ in 0..reps {
                for t in &towers {
                    let o3 = third_full(std::hint::black_box(t), std::hint::black_box(&dir));
                    let o4 = fourth_full(
                        std::hint::black_box(t),
                        std::hint::black_box(&u),
                        std::hint::black_box(&w),
                    );
                    sink += o3[0][K - 1] + o4[0][K - 1];
                }
            }
            std::hint::black_box(sink);
            start.elapsed().as_secs_f64()
        };
        let calls = (reps * towers.len()) as f64;
        eprintln!(
            "[contraction-symmetry speedup K=9] sym={:.1}ns/call full={:.1}ns/call \
             wall_speedup={:.2}x",
            t_sym / calls * 1e9,
            t_full / calls * 1e9,
            t_full / t_sym
        );
        assert!(
            t_sym <= t_full * 1.5,
            "output-symmetric contraction pathologically slower: \
             sym={t_sym:.4}s full={t_full:.4}s"
        );
    }
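
    // Hypothetical sketch of a less noise-sensitive timing helper; it is not
    // used by the test above, and `min_secs_over_trials` is an illustrative
    // name. Taking the minimum over several trials discards runs inflated by
    // scheduling or cache noise, a common micro-benchmark tactic; a
    // statistics-driven harness such as the Criterion crate is the heavier
    // alternative.
    fn min_secs_over_trials<F: FnMut()>(trials: usize, mut body: F) -> f64 {
        assert!(trials > 0, "need at least one trial");
        let mut best = f64::INFINITY;
        for _ in 0..trials {
            let start = std::time::Instant::now();
            body();
            best = best.min(start.elapsed().as_secs_f64());
        }
        best
    }

    #[test]
    fn min_secs_over_trials_runs_body_each_trial() {
        let mut runs = 0usize;
        let secs = min_secs_over_trials(5, || runs += 1);
        assert_eq!(runs, 5);
        assert!(secs.is_finite() && secs >= 0.0);
    }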
}