gam_math/jet_tower.rs
1//! Taylor-jet tower algebra: write each family's row log-likelihood ONCE,
2//! derive the entire `RowKernel<K>` derivative tower mechanically (#932).
3//!
4//! # The object
5//!
6//! [`Tower4<K>`] is a truncated multivariate Taylor scalar in `K` primary
7//! variables, carrying the value and ALL partial derivatives through fourth
8//! order as full (unsymmetrized) tensors:
9//!
10//! ```text
//! v          ℓ
//! g[a]       ∂ℓ/∂p_a
//! h[a][b]    ∂²ℓ/∂p_a∂p_b
//! t3[abc]    ∂³ℓ/∂p_a∂p_b∂p_c
//! t4[abcd]   ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d
16//! ```
17//!
18//! Arithmetic (`+ − × ÷`, scalar mixes) propagates the tower by the exact
19//! Leibniz rule; unary transcendentals propagate by the exact multivariate
20//! Faà di Bruno formula given a `[f, f′, f″, f‴, f⁗]` stack evaluated at the
21//! inner value. This is truncated Taylor ALGEBRA — exact derivatives of the
22//! evaluated expression, not finite differences, not an approximation —
23//! fully compatible with the exact-REML-only policy.
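/*!
A minimal sketch of the two propagation rules at order 2 in two variables,
assuming nothing beyond the formulas above. `MiniTower2` and `exp_of_product`
are hypothetical names used only for illustration; they are not part of this
module, and the sketch does not reproduce [`Tower4`]'s exact term order.

```rust
/// Hypothetical order-2 jet in two variables: value, gradient, Hessian.
#[derive(Clone, Copy, Debug, PartialEq)]
struct MiniTower2 {
    v: f64,
    g: [f64; 2],
    h: [[f64; 2]; 2],
}

impl MiniTower2 {
    /// Seed variable `idx` at `value`: unit gradient in slot `idx`.
    fn variable(value: f64, idx: usize) -> Self {
        let mut g = [0.0; 2];
        g[idx] = 1.0;
        Self { v: value, g, h: [[0.0; 2]; 2] }
    }

    /// Leibniz product truncated at order 2.
    fn mul(&self, o: &Self) -> Self {
        let mut out = Self { v: self.v * o.v, g: [0.0; 2], h: [[0.0; 2]; 2] };
        for i in 0..2 {
            out.g[i] = self.v * o.g[i] + self.g[i] * o.v;
            for j in 0..2 {
                out.h[i][j] = self.v * o.h[i][j]
                    + self.g[i] * o.g[j]
                    + self.g[j] * o.g[i]
                    + self.h[i][j] * o.v;
            }
        }
        out
    }

    /// Faà di Bruno composition with `d = [f, f', f'']` taken at `self.v`.
    fn compose_unary(&self, d: [f64; 3]) -> Self {
        let mut out = Self { v: d[0], g: [0.0; 2], h: [[0.0; 2]; 2] };
        for i in 0..2 {
            out.g[i] = d[1] * self.g[i];
            for j in 0..2 {
                out.h[i][j] = d[1] * self.h[i][j] + d[2] * self.g[i] * self.g[j];
            }
        }
        out
    }
}

/// Evaluate `exp(p0 * p1)` with its gradient and Hessian in one pass.
fn exp_of_product(p0: f64, p1: f64) -> MiniTower2 {
    let prod = MiniTower2::variable(p0, 0).mul(&MiniTower2::variable(p1, 1));
    let e = prod.v.exp();
    prod.compose_unary([e, e, e])
}
```

Seeding `p0 = 0`, `p1 = 2` gives value `1`, gradient `[2, 0]` and Hessian
`[[4, 1], [1, 0]]`, which are the analytic derivatives of `exp(p0 * p1)` at
that point.
*/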
24//!
25//! One evaluation of a row NLL program at seeded variables yields, in a
26//! single pass, every channel the [`super::row_kernel::RowKernel`] trait
27//! demands: `row_kernel` (value/∇/H), `row_third_contracted(dir)` (contract
28//! `t3` with `dir`), and `row_fourth_contracted(u, v)` (contract `t4` with
29//! `u` and `v`). The directional cross-channels that hand-written towers
30//! drop (#736's residual gap) cannot be dropped here: there is no separate
31//! "channel" to forget — every derivative of the one expression is carried.
32//!
33//! # Why this exists (the bug genus)
34//!
35//! Every family today hand-writes its tower: value in one function,
36//! gradient in another, `pdfthird_derivative`/`pdffourth_derivative`,
37//! entry/exit-specific cross blocks — thousands of lines of calculus that
38//! drift. #736 was a sign flip in a hand-written cross-Hessian block,
39//! invisible until a new consumer touched it; #948 is a derivative path
40//! that is not the derivative of the evaluated row loss (clamped-μ
41//! surrogate); the objective↔gradient desync class is the same disease at
42//! the criterion level. A tower-derived kernel is exact-by-construction:
43//! the value channel IS the production loss expression, so its derivative
44//! channels cannot desync from it.
45//!
46//! # Relation to `jet_partitions::MultiDirJet`
47//!
48//! The tree already carries a *directional* jet (bitmask coefficients over
49//! distinct seeded directions, heap-allocated, Bell-partition compose) used
50//! inside the marginal-slope and latent-survival families. It answers "the
51//! derivative along THESE specific directions" and must be re-seeded and
52//! re-evaluated per direction tuple (e.g. 10 symmetric `(a,b)` pairs for a
53//! K=4 fourth contraction). `Tower4` answers ALL of them from one
54//! evaluation: contraction happens AFTER differentiation, as plain linear
55//! algebra on the stored tensors. Use `MultiDirJet` when you need a handful
56//! of directions of a huge-K expression; use `Tower4` when you need the
57//! complete small-K tower — which is exactly the `RowKernel<K≤4>` shape.
58//! The `[f64; 5]` unary-derivative stacks
59//! (`unary_derivatives_neglog_phi`, …) are signature-compatible with
60//! [`Tower4::compose_unary`], so the families' existing special-function
61//! stacks are directly reusable.
62//!
63//! # Stability discipline (why this is NOT autodiff)
64//!
65//! Differentiating the primal code path inherits its instabilities: a jet
66//! pushed through a naive `ln(1 + e^η)` is garbage in the saturated tail
67//! even though the true derivative σ(η) is benign there. This module
68//! therefore splits responsibility: **humans own primitive stability,
69//! the algebra owns combinatorics**. Tail-critical special functions enter
70//! a program ONLY as hand-certified `[f64; 5]` derivative stacks through
71//! [`Tower4::compose_unary`] — the same stacks the families already write
72//! (`unary_derivatives_neglog_phi` and friends, built on erfcx/log_ndtr) —
73//! and the tower mechanizes only the Leibniz/Faà di Bruno composition,
74//! which is where hand-written towers actually fail (#736 was a
75//! composition sign flip, not a primitive error). Program authors must use
76//! a stable primitive stack wherever the f64 production loss does; the
77//! convenience methods (`exp`, `ln`, `sqrt`, …) are for expressions whose
78//! arguments are tame by construction.
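/*!
As an illustration of the split, here is a sketch of what a hand-written
stable stack for the softplus `ln(1 + e^η)` could look like. `softplus_stack`
and `naive_softplus` are hypothetical names, not helpers that exist in this
crate; the derivative identities used are the standard ones for the logistic
sigmoid.

```rust
/// Naive value: overflows once `e^eta` leaves the f64 range.
fn naive_softplus(eta: f64) -> f64 {
    (1.0 + eta.exp()).ln()
}

/// Hypothetical stable `[f, f', f'', f''', f'''']` stack for softplus.
fn softplus_stack(eta: f64) -> [f64; 5] {
    // Only ever exponentiate a non-positive argument.
    let e = (-eta.abs()).exp();
    // Value: max(eta, 0) + ln(1 + e^{-|eta|}).
    let value = eta.max(0.0) + e.ln_1p();
    // Sigmoid s = f' and its complement c = 1 - s, each formed without
    // subtracting nearly equal numbers.
    let (s, c) = if eta >= 0.0 {
        (1.0 / (1.0 + e), e / (1.0 + e))
    } else {
        (e / (1.0 + e), 1.0 / (1.0 + e))
    };
    let sc = s * c; // f'' = s(1 - s)
    [
        value,
        s,
        sc,
        sc * (c - s),          // f''' = s(1-s)(1-2s), with 1 - 2s = c - s
        sc * (1.0 - 6.0 * sc), // f'''' = s(1-s)(1 - 6s + 6s^2)
    ]
}
```

At `eta = 800` the naive value overflows to `inf`, while the stack evaluates
to `[800, 1, 0, 0, 0]`. A stack of this shape is what would be handed to
[`Tower4::compose_unary`].
*/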
79//!
80//! # Storage convention
81//!
82//! Tensors are stored FULL, not symmetric-packed: `t4` for K=4 is 256
83//! doubles where 35 would do. This is deliberate clarity-over-speed for the
84//! oracle role — indexing is trivially auditable, contraction loops are
85//! obvious, and the redundancy is itself a checked invariant (the algebra
86//! only ever writes symmetric values). Symmetric packing is a later,
87//! profile-justified optimization behind the same API.
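/*!
The entry counts quoted above can be checked with a short sketch.
`full_entries` and `symmetric_entries` are hypothetical helpers written for
this illustration only; the packed count is the usual multiset formula
`C(K + m - 1, m)`.

```rust
/// Entries in a full (unpacked) order-`m` derivative tensor: `k^m`.
fn full_entries(k: u64, m: u32) -> u64 {
    k.pow(m)
}

/// Distinct entries of a fully symmetric order-`m` tensor in `k` variables:
/// the number of multisets of size `m` drawn from `k` axes.
/// Assumes `k >= 1`.
fn symmetric_entries(k: u64, m: u32) -> u64 {
    let m = u64::from(m);
    // Multiplicative binomial: after step `i` the running product equals
    // C(k - 1 + i, i), so every division is exact.
    (1..=m).fold(1, |acc, i| acc * (k - 1 + i) / i)
}
```

For `K = 4` the fourth-order tensor has 256 full entries against 35 distinct
ones, and the whole tower has `1 + 4 + 16 + 64 + 256 = 341` full entries
against `1 + 4 + 10 + 20 + 35 = 70` distinct ones.
*/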
88//!
89//! # Deployment ladder (#932)
90//!
//! 1. This module: the algebra + the program seam + the oracle.
//! 2. Universal oracle: every hand-written `RowKernel` gains a CI test
//!    asserting channel-by-channel agreement with a `RowNllProgram` written
//!    once (see [`verify_kernel_channels`]). This alone would have caught
//!    #736 at introduction.
//! 3. Migrate error-dense / cold towers to [`derived_row_kernel`] et al.;
//!    keep hand-tuned hot paths, now verified against the single-expression
//!    truth instead of being the only definition.
//! 4. New families (#914/#916/#917 ZI/ordinal/expectile, #921's
//!    location-scale port) implement ONLY `RowNllProgram` and get an exact
//!    fourth-order tower for the price of writing the likelihood.
102
103use crate::jet_algebra;
104
105/// Truncated fourth-order multivariate Taylor scalar in `K` variables.
106///
/// See the module documentation for semantics and conventions. `Copy` is
/// intentional despite the size (341 `f64`s, about 2.7 KiB, at K=4): towers
/// are per-row temporaries that live on the stack during a row program, and
/// value semantics keep program code readable (`a * b + c`).
111#[derive(Clone, Copy, Debug)]
112pub struct Tower4<const K: usize> {
113 /// Value ℓ.
114 pub v: f64,
115 /// Gradient ∂ℓ/∂p_a.
116 pub g: [f64; K],
117 /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
118 pub h: [[f64; K]; K],
119 /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
120 pub t3: [[[f64; K]; K]; K],
121 /// Fourth derivatives ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d (fully symmetric).
122 pub t4: [[[[f64; K]; K]; K]; K],
123}
124
125impl<const K: usize> Tower4<K> {
126 /// The additive identity.
127 pub fn zero() -> Self {
128 Self {
129 v: 0.0,
130 g: [0.0; K],
131 h: [[0.0; K]; K],
132 t3: [[[0.0; K]; K]; K],
133 t4: [[[[0.0; K]; K]; K]; K],
134 }
135 }
136
137 /// A constant: value `c`, all derivatives zero.
138 pub fn constant(c: f64) -> Self {
139 let mut out = Self::zero();
140 out.v = c;
141 out
142 }
143
144 /// The seeded variable `p_idx` with current value `value`:
145 /// unit first derivative in slot `idx`, zero elsewhere and above.
146 pub fn variable(value: f64, idx: usize) -> Self {
147 let mut out = Self::constant(value);
148 out.g[idx] = 1.0;
149 out
150 }
151
152 /// Read the (fully symmetric) derivative tensor entry whose differentiation
153 /// axes are `labels` (length 0..=4): value, `g`, `h`, `t3`, `t4`.
154 #[inline]
155 fn deriv(&self, labels: &[usize]) -> f64 {
156 assert!(
157 labels.len() <= 4,
158 "Tower4 carries at most fourth-order derivatives"
159 );
160 match labels.len() {
161 0 => self.v,
162 1 => self.g[labels[0]],
163 2 => self.h[labels[0]][labels[1]],
164 3 => self.t3[labels[0]][labels[1]][labels[2]],
165 _ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
166 }
167 }
168
169 /// Exact truncated Leibniz product `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`.
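/**
In one variable the subset sum collapses to the familiar binomial Leibniz
rule, because all subsets `T` of the same size contribute the same product.
The sketch below shows that collapsed form; `leibniz_univariate` is a
hypothetical helper for illustration and makes no claim about the term order
used by this method.

```rust
/// Univariate Leibniz rule through order 4.
/// `a[n]` and `b[n]` hold the n-th derivatives of the two factors.
fn leibniz_univariate(a: [f64; 5], b: [f64; 5]) -> [f64; 5] {
    // Pascal's triangle rows 0..=4: C(n, k) counts the subsets T of an
    // n-element index set with |T| = k.
    const BINOM: [[f64; 5]; 5] = [
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 0.0, 0.0],
        [1.0, 2.0, 1.0, 0.0, 0.0],
        [1.0, 3.0, 3.0, 1.0, 0.0],
        [1.0, 4.0, 6.0, 4.0, 1.0],
    ];
    let mut out = [0.0; 5];
    for n in 0..5 {
        for k in 0..=n {
            out[n] += BINOM[n][k] * a[k] * b[n - k];
        }
    }
    out
}
```

With `a = x²` and `b = x` at `x = 1` (jets `[1, 2, 2, 0, 0]` and
`[1, 1, 0, 0, 0]`) the result is `[1, 3, 6, 6, 0]`, the derivatives of `x³`.
*/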
170 ///
171 /// # Codegen
172 ///
173 /// Each output entry's `2^m` subset sum is written as a compact straight-line
174 /// expression instead of the shared [`jet_algebra::leibniz_product`] subset
175 /// walker (which, per entry, builds `SlotBuf`s and `match`-dispatches the
176 /// `deriv` closure across all `2^m` subsets). The loop nest over `(i,j,k,l)`
177 /// is unchanged — only the inner per-entry sum is unrolled — so this does NOT
178 /// unroll over `K` and does NOT bloat code: on a `Tower4<9>` mul-and-read
179 /// consumer the new form is faster AND smaller (asm: 34 outlined walker `bl`
180 /// calls → 0, 21.1 KiB → 14.3 KiB, +100 NEON `.2d` ops).
181 ///
182 /// BIT-IDENTICAL to the walker: each entry's terms are in the walker's exact
183 /// subset-enumeration order (subset bit `b` ↔ position `b`, `sub = 0..2^m`),
184 /// and the per-entry `acc` accumulator mirrors the walker's `total = 0.0`
185 /// start so a signed-zero leading product collapses to `+0.0` identically —
186 /// which matters because real jets carry exact-`0.0` channels
187 /// (`constant`/`variable` towers). Proven `to_bits`-identical on
188 /// `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`, 5000 inputs each with ~30 %
189 /// exact-`0.0` channels and signed values (a no-leading-`0.0` form fails this
190 /// stress — the accumulator start is load-bearing).
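/**
Why the accumulator start matters can be seen with two signed-zero products.
This is a sketch of the IEEE 754 behaviour only; `sum_from_zero` and
`sum_without_zero` are hypothetical names, not functions of this crate.

```rust
/// Sum products starting from an explicit `0.0`, as an `acc` accumulator does.
fn sum_from_zero(terms: &[f64]) -> f64 {
    let mut acc = 0.0;
    for &t in terms {
        acc += t;
    }
    acc
}

/// Sum the same products with the first term seeding the sum instead.
fn sum_without_zero(terms: &[f64]) -> f64 {
    let mut it = terms.iter().copied();
    let first = it.next().unwrap_or(0.0);
    it.fold(first, |acc, t| acc + t)
}
```

For the terms `[-0.0, -0.0]` the first form returns `+0.0` (because
`0.0 + (-0.0)` is `+0.0` under round-to-nearest) and the second returns
`-0.0`. The two results compare equal with `==` but differ under `to_bits`.
*/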
191 pub fn mul(&self, o: &Self) -> Self {
192 let a = self;
193 let b = o;
194 let mut out = Self::zero();
195 out.v = a.v * b.v;
196 for i in 0..K {
197 // subsets of {i}: {} {i}
198 let mut acc = 0.0;
199 acc += a.v * b.g[i];
200 acc += a.g[i] * b.v;
201 out.g[i] = acc;
202 }
203 // Hessian is symmetric under i↔j; compute the upper triangle and mirror
204 // (see [`Tower2::mul`] — same term order, enforces exact symmetry).
205 for i in 0..K {
206 for j in i..K {
207 // subsets of {i,j}: {} {i} {j} {ij}
208 let mut acc = 0.0;
209 acc += a.v * b.h[i][j];
210 acc += a.g[i] * b.g[j];
211 acc += a.g[j] * b.g[i];
212 acc += a.h[i][j] * b.v;
213 out.h[i][j] = acc;
214 out.h[j][i] = acc;
215 }
216 }
217 for i in 0..K {
218 for j in 0..K {
219 for k in 0..K {
220 // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
221 let mut acc = 0.0;
222 acc += a.v * b.t3[i][j][k];
223 acc += a.g[i] * b.h[j][k];
224 acc += a.g[j] * b.h[i][k];
225 acc += a.h[i][j] * b.g[k];
226 acc += a.g[k] * b.h[i][j];
227 acc += a.h[i][k] * b.g[j];
228 acc += a.h[j][k] * b.g[i];
229 acc += a.t3[i][j][k] * b.v;
230 out.t3[i][j][k] = acc;
231 }
232 }
233 }
234 for i in 0..K {
235 for j in 0..K {
236 for k in 0..K {
237 for l in 0..K {
238 // subsets of {i,j,k,l} in bit order sub = 0..16
239 let mut acc = 0.0;
240 acc += a.v * b.t4[i][j][k][l];
241 acc += a.g[i] * b.t3[j][k][l];
242 acc += a.g[j] * b.t3[i][k][l];
243 acc += a.h[i][j] * b.h[k][l];
244 acc += a.g[k] * b.t3[i][j][l];
245 acc += a.h[i][k] * b.h[j][l];
246 acc += a.h[j][k] * b.h[i][l];
247 acc += a.t3[i][j][k] * b.g[l];
248 acc += a.g[l] * b.t3[i][j][k];
249 acc += a.h[i][l] * b.h[j][k];
250 acc += a.h[j][l] * b.h[i][k];
251 acc += a.t3[i][j][l] * b.g[k];
252 acc += a.h[k][l] * b.h[i][j];
253 acc += a.t3[i][k][l] * b.g[j];
254 acc += a.t3[j][k][l] * b.g[i];
255 acc += a.t4[i][j][k][l] * b.v;
256 out.t4[i][j][k][l] = acc;
257 }
258 }
259 }
260 }
261 out
262 }
263
264 /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
265 /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
266 /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
267 /// out of the borrowed operands.
268 pub fn add(&self, o: &Self) -> Self {
269 *self + *o
270 }
271
272 /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
273 pub fn sub(&self, o: &Self) -> Self {
274 *self + o.scale(-1.0)
275 }
276
277 /// Exact multivariate Faà di Bruno composition `f ∘ self`.
278 ///
279 /// `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` evaluated at `u = self.v` —
280 /// the SAME `[f64; 5]` stack shape the families' existing
281 /// `unary_derivatives_*` helpers produce, so those special-function
282 /// stacks (Φ, log-Φ, normal pdf, …) plug in directly.
283 ///
284 /// The order-m output sums over the set partitions of the m indices
285 /// (Bell(3) = 5 terms at order 3, Bell(4) = 15 at order 4), grouped by
286 /// block count: each partition into r blocks contributes
287 /// `f⁽ʳ⁾ · Π_blocks D_block(u)`.
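/**
The grouping by block count can be checked independently: the number of
partitions of `m` indices into exactly `r` blocks is a Stirling number of the
second kind, and those counts sum to the Bell number. A small sketch, with
`stirling2` and `bell` as hypothetical helper names:

```rust
/// Number of set partitions of `m` labelled indices into exactly `r` blocks,
/// by the standard recurrence S(m, r) = r * S(m-1, r) + S(m-1, r-1).
fn stirling2(m: usize, r: usize) -> u64 {
    if m == 0 && r == 0 {
        return 1;
    }
    if m == 0 || r == 0 || r > m {
        return 0;
    }
    r as u64 * stirling2(m - 1, r) + stirling2(m - 1, r - 1)
}

/// Total partition count: Bell(m) is the sum of S(m, r) over r.
fn bell(m: usize) -> u64 {
    (0..=m).map(|r| stirling2(m, r)).sum()
}
```

At order 4 the counts are `1, 7, 6, 1` for `r = 1..=4`, which is the number of
terms multiplied by `d[1]`, `d[2]`, `d[3]` and `d[4]` respectively in the
order-4 sum.
*/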
288 ///
289 /// # Codegen
290 ///
291 /// Evaluated as a compact closed form (the Bell(4)=15 set-partitions of
292 /// `t4`, Bell(3)=5 of `t3`, …) instead of routing through the recursive
293 /// [`jet_algebra::faa_di_bruno`] walker (per-output `for_each_partition`
294 /// recursion + per-block `SlotBuf` + closure dispatch). The loop nest is
295 /// identical to the walker's (`for i,j,k,l`); only the per-entry partition
296 /// sum is straight-line, so this does NOT unroll over `K` and does NOT
297 /// bloat code — measured on a `Tower4<9>` compose-and-read consumer the new
298 /// form is both faster and SMALLER (asm: 94 outlined walker `bl` calls → 0,
299 /// 47.5 KiB → 16.7 KiB, +197 NEON `.2d` ops).
300 ///
301 /// BIT-IDENTICAL to the walker: each channel's terms are emitted in the
302 /// walker's exact partition-enumeration order, each term's block products
303 /// are left-associated exactly as the walker's `prod *= block`, and the
304 /// per-channel `acc` accumulator mirrors the walker's `total = 0.0` start
305 /// (so signed-zero products collapse to `+0.0` identically). The order-4
306 /// term sequence was generated from the walker's own enumeration. Proven
307 /// `to_bits`-identical on `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`,
308 /// 5000 random inputs each (zeroed / sign-varied stacks included).
309 pub fn compose_unary(&self, d: [f64; 5]) -> Self {
310 let mut out = Self::zero();
311 out.v = d[0];
312 for i in 0..K {
313 let mut acc = 0.0;
314 acc += d[1] * self.g[i];
315 out.g[i] = acc;
316 }
317 for i in 0..K {
318 for j in 0..K {
319 let mut acc = 0.0;
320 acc += d[1] * self.h[i][j];
321 acc += d[2] * self.g[i] * self.g[j];
322 out.h[i][j] = acc;
323 }
324 }
325 for i in 0..K {
326 for j in 0..K {
327 for k in 0..K {
328 // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
329 let mut acc = 0.0;
330 acc += d[1] * self.t3[i][j][k];
331 acc += d[2] * self.h[i][j] * self.g[k];
332 acc += d[2] * self.h[i][k] * self.g[j];
333 acc += d[2] * self.g[i] * self.h[j][k];
334 acc += d[3] * self.g[i] * self.g[j] * self.g[k];
335 out.t3[i][j][k] = acc;
336 }
337 }
338 }
339 for i in 0..K {
340 for j in 0..K {
341 for k in 0..K {
342 for l in 0..K {
343 // Bell(4)=15 partitions, walker enumeration order.
344 let mut acc = 0.0;
345 acc += d[1] * self.t4[i][j][k][l];
346 acc += d[2] * self.t3[i][j][k] * self.g[l];
347 acc += d[2] * self.t3[i][j][l] * self.g[k];
348 acc += d[2] * self.h[i][j] * self.h[k][l];
349 acc += d[3] * self.h[i][j] * self.g[k] * self.g[l];
350 acc += d[2] * self.t3[i][k][l] * self.g[j];
351 acc += d[2] * self.h[i][k] * self.h[j][l];
352 acc += d[3] * self.h[i][k] * self.g[j] * self.g[l];
353 acc += d[2] * self.h[i][l] * self.h[j][k];
354 acc += d[2] * self.g[i] * self.t3[j][k][l];
355 acc += d[3] * self.g[i] * self.h[j][k] * self.g[l];
356 acc += d[3] * self.h[i][l] * self.g[j] * self.g[k];
357 acc += d[3] * self.g[i] * self.h[j][l] * self.g[k];
358 acc += d[3] * self.g[i] * self.g[j] * self.h[k][l];
359 acc += d[4] * self.g[i] * self.g[j] * self.g[k] * self.g[l];
360 out.t4[i][j][k][l] = acc;
361 }
362 }
363 }
364 }
365 out
366 }
367
368 /// Compose with a unary special-function whose `[f64; 5]` derivative STACK is
369 /// built from the base value through `stack_fn` — the scalar arm of the
370 /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
371 /// [`Tower4Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
372 /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
373 /// `self.compose_unary(stack_fn(self.v))`. Writing a program against this seam
374 /// lets it re-instantiate, unchanged, at [`Tower4Lane`] (where each of the four
375 /// lanes carries a DISTINCT base value and `stack_fn` is re-run per lane).
376 #[inline]
377 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
378 self.compose_unary(stack_fn(self.v))
379 }
380
381 /// Single-active-slot fast path for [`Self::compose_unary`].
382 ///
383 /// When the inner jet `self` has derivative support ONLY on the all-`slot`
384 /// diagonal channels — i.e. it is a univariate jet in primary `slot`
385 /// scattered into the `K`-wide layout (`g[a] = 0`, `h[a][b] = 0`,
386 /// `t3 = 0`, `t4 = 0` for any axis `≠ slot`) — the multivariate Faà di
387 /// Bruno walk collapses. Every output channel whose axis tuple contains an
388 /// axis `≠ slot` is structurally `0`: each set-partition has a block
389 /// covering that axis, that block reads an off-`slot` derivative of `self`
390 /// (which is `0`), so the block product and the whole partition vanish, and
391 /// the channel sums to the walker's `total = 0.0` start, i.e. `+0.0`. Only
392 /// the five diagonal channels (`v`, `g[slot]`, `h[slot][slot]`,
393 /// `t3[slot]³`, `t4[slot]⁴`) survive.
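/**
On the diagonal the fifteen order-4 partitions merge into the classical
univariate Faà di Bruno formula with multiplicities 1, 4, 3, 6, 1. The sketch
below uses that merged form; `compose_univariate` is a hypothetical helper.
It is equal to the diagonal channels as a real-number formula, but it sums
in a different order, so it is NOT claimed to be bit-identical to this
method.

```rust
/// Univariate Faà di Bruno through order 4.
/// `u = [u, u', u'', u''', u'''']` is the inner jet, `d` the outer
/// derivative stack evaluated at `u[0]`.
fn compose_univariate(u: [f64; 5], d: [f64; 5]) -> [f64; 5] {
    let [_, u1, u2, u3, u4] = u;
    [
        d[0],
        d[1] * u1,
        d[1] * u2 + d[2] * u1 * u1,
        d[1] * u3 + 3.0 * d[2] * u1 * u2 + d[3] * u1 * u1 * u1,
        d[1] * u4
            + d[2] * (4.0 * u1 * u3 + 3.0 * u2 * u2)
            + 6.0 * d[3] * u1 * u1 * u2
            + d[4] * u1 * u1 * u1 * u1,
    ]
}
```

Composing `f(u) = u²` with `u(x) = x²` at `x = 1` (inner jet
`[1, 2, 2, 0, 0]`, outer stack `[1, 2, 2, 0, 0]`) gives `[1, 4, 12, 24, 24]`,
the derivatives of `x⁴` at `x = 1`.
*/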
394 ///
395 /// This computes exactly those five as STRAIGHT-LINE accumulations, each in
396 /// the EXACT term order of [`Self::compose_unary`]'s diagonal
397 /// (`i = j = k = l = slot`) case — so they are BIT-IDENTICAL to
398 /// [`Self::compose_unary`] on the diagonal — and leaves every other channel
399 /// at the zero-init `+0.0`, which the full walk also produces (the
400 /// off-`slot` collapse is `to_bits`-`+0.0`, signed-zero products included;
401 /// proven across `K ∈ {2,3,4,9}`, 5000 single-slot inputs each). At any
402 /// `K ≥ 2` this is far fewer floating-point operations than materialising
403 /// the full `1 + K + K² + K³ + K⁴` channel set whose off-diagonal entries
404 /// are all zero, and far cheaper than the recursive set-partition walker the
405 /// diagonal channels previously routed through (a measured ~9.5× speedup vs
406 /// the full `compose_unary`, recovering a 5.9× walker regression at the
407 /// `K ∈ {2,3}` BMS tower widths).
408 ///
409 /// `#[inline]` so an adopting consumer pays no `bl` call (uninlined, the
410 /// five-channel build does not amortise the call/spill overhead).
411 ///
412 /// # Precondition
413 ///
414 /// The caller guarantees the single-active-slot structure. If it does not
415 /// hold, the off-`slot` channels would be wrongly zeroed; use the full
416 /// [`Self::compose_unary`] in that case.
417 #[inline]
418 pub fn compose_unary_single_slot(&self, d: [f64; 5], slot: usize) -> Self {
419 let mut out = Self::zero();
420 let s = slot;
421 let g = self.g[s];
422 let h = self.h[s][s];
423 let t3 = self.t3[s][s][s];
424 let t4 = self.t4[s][s][s][s];
425 out.v = d[0];
426 // g (i=s): d1*g
427 out.g[s] = {
428 let mut acc = 0.0;
429 acc += d[1] * g;
430 acc
431 };
432 // h (i=j=s): d1*h + d2*g*g
433 out.h[s][s] = {
434 let mut acc = 0.0;
435 acc += d[1] * h;
436 acc += d[2] * g * g;
437 acc
438 };
439 // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
440 out.t3[s][s][s] = {
441 let mut acc = 0.0;
442 acc += d[1] * t3;
443 acc += d[2] * h * g;
444 acc += d[2] * h * g;
445 acc += d[2] * g * h;
446 acc += d[3] * g * g * g;
447 acc
448 };
449 // t4 (i=j=k=l=s): exact term order of compose_unary's inner loop.
450 out.t4[s][s][s][s] = {
451 let mut acc = 0.0;
452 acc += d[1] * t4;
453 acc += d[2] * t3 * g;
454 acc += d[2] * t3 * g;
455 acc += d[2] * h * h;
456 acc += d[3] * h * g * g;
457 acc += d[2] * t3 * g;
458 acc += d[2] * h * h;
459 acc += d[3] * h * g * g;
460 acc += d[2] * h * h;
461 acc += d[2] * g * t3;
462 acc += d[3] * g * h * g;
463 acc += d[3] * h * g * g;
464 acc += d[3] * g * h * g;
465 acc += d[3] * g * g * h;
466 acc += d[4] * g * g * g * g;
467 acc
468 };
469 out
470 }
471
472 /// Multiply every channel by a plain scalar.
473 pub fn scale(&self, s: f64) -> Self {
474 let mut out = *self;
475 out.v *= s;
476 for i in 0..K {
477 out.g[i] *= s;
478 for j in 0..K {
479 out.h[i][j] *= s;
480 for k in 0..K {
481 out.t3[i][j][k] *= s;
482 for l in 0..K {
483 out.t4[i][j][k][l] *= s;
484 }
485 }
486 }
487 }
488 out
489 }
490
491 /// e^self.
492 pub fn exp(&self) -> Self {
493 let e = self.v.exp();
494 self.compose_unary([e, e, e, e, e])
495 }
496
497 /// ln(self). Caller guarantees positivity (likelihood programs do).
498 pub fn ln(&self) -> Self {
499 let u = self.v;
500 let r = 1.0 / u;
501 self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
502 }
503
504 /// 1/self.
505 pub fn recip(&self) -> Self {
506 let r = 1.0 / self.v;
507 let r2 = r * r;
508 self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
509 }
510
511 /// √self. Caller guarantees positivity.
512 pub fn sqrt(&self) -> Self {
513 let u = self.v;
514 let s = u.sqrt();
515 self.compose_unary([
516 s,
517 0.5 / s,
518 -0.25 / (u * s),
519 0.375 / (u * u * s),
520 -0.9375 / (u * u * u * s),
521 ])
522 }
523
524 /// self^a for real exponent `a`. Caller guarantees a positive base.
525 pub fn powf(&self, a: f64) -> Self {
526 let u = self.v;
527 let f0 = u.powf(a);
528 let f1 = a * u.powf(a - 1.0);
529 let f2 = a * (a - 1.0) * u.powf(a - 2.0);
530 let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
531 let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
532 self.compose_unary([f0, f1, f2, f3, f4])
533 }
534
535 /// ln Γ(self). Caller guarantees positivity.
536 pub fn ln_gamma(&self) -> Self {
537 self.compose_unary(ln_gamma_derivative_stack(self.v))
538 }
539
540 /// ψ(self), the digamma function. Caller guarantees positivity.
541 pub fn digamma(&self) -> Self {
542 self.compose_unary(digamma_derivative_stack(self.v))
543 }
544
545 /// ψ′(self), the trigamma function. Caller guarantees positivity.
546 pub fn trigamma(&self) -> Self {
547 self.compose_unary(trigamma_derivative_stack(self.v))
548 }
549
550 /// Contract `t3` with one primary-space direction:
551 /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]` — exactly the
552 /// `row_third_contracted` shape.
553 ///
554 /// The output is symmetric in `(a, b)`: `t3` is fully index-symmetric, so
555 /// `t3[a][b][c] == t3[b][a][c]` and the `Σ_c` contraction gives
556 /// `out[a][b] == out[b][a]` term-for-term, in the same `c` order. We compute
557 /// only the upper triangle `a ≤ b` (the inner contraction is unchanged and
558 /// stays contiguous/vectorisable) and mirror into the lower triangle — this
559 /// is BIT-IDENTICAL to the full `a, b ∈ 0..K` nest while doing ~2× fewer
560 /// inner contractions, with no dense scatter (the mirror is a `K × K` copy).
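/**
A standalone sketch of the same upper-triangle-and-mirror contraction on a
bare array, for a case that can be checked by hand. `contract_third` and
`third_derivatives_of_p0_squared_p1` are hypothetical names for this
illustration only.

```rust
/// Contract a fully symmetric third-derivative tensor with one direction,
/// computing the upper triangle `a <= b` and mirroring it.
fn contract_third<const K: usize>(
    t3: &[[[f64; K]; K]; K],
    dir: &[f64; K],
) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for a in 0..K {
        for b in a..K {
            let mut acc = 0.0;
            for c in 0..K {
                acc += t3[a][b][c] * dir[c];
            }
            out[a][b] = acc;
            out[b][a] = acc;
        }
    }
    out
}

/// Third derivatives of `l(p0, p1) = p0^2 * p1`: every index permutation of
/// (0, 0, 1) equals 2, all other entries are 0.
fn third_derivatives_of_p0_squared_p1() -> [[[f64; 2]; 2]; 2] {
    let mut t3 = [[[0.0; 2]; 2]; 2];
    t3[0][0][1] = 2.0;
    t3[0][1][0] = 2.0;
    t3[1][0][0] = 2.0;
    t3
}
```

Contracting with `dir = [1, 3]` gives `[[6, 2], [2, 0]]`, which is the
directional derivative of the Hessian `[[2·p1, 2·p0], [2·p0, 0]]` along
`(1, 3)`.
*/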
561 pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
562 let mut out = [[0.0; K]; K];
563 for a in 0..K {
564 for b in a..K {
565 let mut acc = 0.0;
566 for c in 0..K {
567 acc += self.t3[a][b][c] * dir[c];
568 }
569 out[a][b] = acc;
570 out[b][a] = acc;
571 }
572 }
573 out
574 }
575
    /// Contract `t4` with two primary-space directions:
    /// `out[i][j] = Σ_{k,l} t4[i][j][k][l] · u[k] · w[l]`, exactly the
    /// `row_fourth_contracted` shape.
    ///
    /// As in [`Self::third_contracted`], the output is symmetric in `(i, j)`
    /// (`t4[j][i][k][l] == t4[i][j][k][l]`, contracted in the same `(k, l)`
    /// order), so the upper triangle `i ≤ j` is computed and mirrored. That is
    /// BIT-IDENTICAL to the full nest with ~2× fewer inner `Σ_{k,l}`
    /// contractions, and the inner double loop keeps its original
    /// contiguous/vectorisable form.
585 pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
586 let mut out = [[0.0; K]; K];
587 for i in 0..K {
588 for j in i..K {
589 let mut acc = 0.0;
590 for k in 0..K {
591 for l in 0..K {
592 acc += self.t4[i][j][k][l] * u[k] * w[l];
593 }
594 }
595 out[i][j] = acc;
596 out[j][i] = acc;
597 }
598 }
599 out
600 }
601}
602
603impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
604 #[inline]
605 fn derivative(&self, labels: &[usize]) -> f64 {
606 self.deriv(labels)
607 }
608
609 fn map_derivatives<F>(&self, mut f: F) -> Self
610 where
611 F: FnMut(&[usize]) -> f64,
612 {
613 let mut out = Self::zero();
614 out.v = f(&[]);
615 for i in 0..K {
616 let labels = [i];
617 out.g[i] = f(&labels);
618 }
619 for i in 0..K {
620 for j in 0..K {
621 let labels = [i, j];
622 out.h[i][j] = f(&labels);
623 }
624 }
625 for i in 0..K {
626 for j in 0..K {
627 for k in 0..K {
628 let labels = [i, j, k];
629 out.t3[i][j][k] = f(&labels);
630 }
631 }
632 }
633 for i in 0..K {
634 for j in 0..K {
635 for k in 0..K {
636 for l in 0..K {
637 let labels = [i, j, k, l];
638 out.t4[i][j][k][l] = f(&labels);
639 }
640 }
641 }
642 }
643 out
644 }
645}
646
647/// Truncated SECOND-order multivariate Taylor scalar in `K` variables.
648///
649/// This is the value/gradient/Hessian-only sibling of [`Tower4`]. Every
650/// channel it carries (`v`, `g`, `h`) is computed by the SAME formulas
651/// [`Tower4`] uses for those orders, so for any program written over both
652/// towers the order-≤2 outputs are *bit-identical*: the order-2 Leibniz and
653/// Faà-di-Bruno terms read only the order-≤2 channels of their inputs (see
654/// [`Tower4::mul`] / [`Tower4::compose_unary`] — `out.h` never touches `t3`
655/// or `t4`), so dropping the third/fourth tensors cannot perturb the value,
656/// gradient, or Hessian.
657///
658/// It exists purely for performance: an inner Newton step (and the
659/// value-only ρ-homotopy pre-warm) needs at most curvature, never the
660/// outer-κ/ψ third/fourth derivatives. Evaluating a row likelihood over
661/// `Tower2` skips the `K⁴` fourth-tensor product/composition arithmetic that
662/// dominates the cold marginal-slope fit, while returning the exact same
663/// `(v, g, h)`.
664#[derive(Clone, Copy, Debug)]
665pub struct Tower2<const K: usize> {
666 /// Value ℓ.
667 pub v: f64,
668 /// Gradient ∂ℓ/∂p_a.
669 pub g: [f64; K],
670 /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
671 pub h: [[f64; K]; K],
672}
673
674impl<const K: usize> Tower2<K> {
675 /// The additive identity.
676 pub fn zero() -> Self {
677 Self {
678 v: 0.0,
679 g: [0.0; K],
680 h: [[0.0; K]; K],
681 }
682 }
683
684 /// A constant: value `c`, all derivatives zero.
685 pub fn constant(c: f64) -> Self {
686 let mut out = Self::zero();
687 out.v = c;
688 out
689 }
690
691 /// The seeded variable `p_idx` with current value `value`:
692 /// unit first derivative in slot `idx`, zero elsewhere and above.
693 pub fn variable(value: f64, idx: usize) -> Self {
694 let mut out = Self::constant(value);
695 out.g[idx] = 1.0;
696 out
697 }
698
699 /// Read the derivative tensor entry whose differentiation axes are
700 /// `labels` (length 0..=2): value, `g`, `h`.
701 #[inline]
702 fn deriv(&self, labels: &[usize]) -> f64 {
703 assert!(
704 labels.len() <= 2,
705 "Tower2 carries at most second-order derivatives"
706 );
707 match labels.len() {
708 0 => self.v,
709 1 => self.g[labels[0]],
710 _ => self.h[labels[0]][labels[1]],
711 }
712 }
713
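    /// Exact truncated (order ≤ 2) Leibniz product. The `v`/`g`/`h` upper
    /// triangle uses the same products in the same order as [`Tower4::mul`].
    /// The one visible difference is the accumulator: this form sums the
    /// products directly, while [`Tower4::mul`] starts each entry from `0.0`,
    /// so a sum whose products are all `-0.0` is `-0.0` here and `+0.0` there.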
716 ///
717 /// # Symmetry fast path
718 ///
719 /// The order-≤2 Leibniz Hessian
720 /// `h[i][j] = a.v·b.h[i][j] + a.g[i]·b.g[j] + a.g[j]·b.g[i] + a.h[i][j]·b.v`
721 /// is symmetric under `i ↔ j` whenever the operand Hessians are — which they
722 /// always are: `constant`/`variable` seed a symmetric (zero) `h`, and
723 /// `mul`/`compose_unary`/`add`/`scale` each preserve symmetry, so the
724 /// invariant holds for every tower a row program can build. We therefore
725 /// compute only the upper triangle `j ≥ i` and mirror it into the lower
    /// triangle. At the `K = 9` survival width that is `K(K+1)/2 = 45` four-product
    /// entry evaluations instead of `K² = 81`, and the win is larger in wall-clock
    /// because the 81-entry (648-byte) `h` spills at `K = 9`: halving the expensive
    /// stores/reloads roughly halves the kernel (measured ≈2× on a `Tower2<9>`
    /// mul-and-read throughput microbench; the dominant `mul` under every packed
    /// scalar bottoms out here).
732 ///
733 /// The upper-triangle entries are BIT-IDENTICAL to the old rectangular form
734 /// (same term/accumulation order). The lower triangle now equals its mirror
735 /// exactly, where the rectangular form rounded `h[i][j]` and `h[j][i]`
736 /// independently (the two cross products accumulate in opposite order) and
737 /// left a ≤1-ulp asymmetry; mirroring removes it, so the result is exactly
738 /// symmetric — strictly closer to the true symmetric Hessian, not merely a
739 /// reordering. Dense-`h` consumers are all tolerance-gated (rel-tol ≥ 1e-11 ≫
740 /// 1e-16); the `f64`/`f64x4` lane oracle stays exact because
741 /// [`crate::jet_scalar::Order2Lane::mul`] mirrors term-for-term.
742 pub fn mul(&self, o: &Self) -> Self {
743 let a = self;
744 let b = o;
745 let mut out = Self::zero();
746 out.v = a.v * b.v;
747 for i in 0..K {
748 out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
749 }
750 for i in 0..K {
751 for j in i..K {
752 let hij =
753 a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
754 out.h[i][j] = hij;
755 out.h[j][i] = hij;
756 }
757 }
758 out
759 }
760
761 /// Exact (order ≤ 2) multivariate Faà di Bruno composition `f ∘ self`.
762 ///
763 /// `d = [f(u), f′(u), f″(u)]` evaluated at `u = self.v`. The `v`/`g`/`h`
764 /// channels match [`Tower4::compose_unary`] term-for-term (which uses only
765 /// `d[0..=2]` for those orders), so this is a strict truncation, not an
    /// approximation. The full-order `[f64; 5]` derivative stacks the families
    /// already produce can be passed by copying their first three entries
    /// (`[d[0], d[1], d[2]]`).
768 ///
769 /// # Codegen
770 ///
771 /// Order-≤2 Faà di Bruno is a tiny closed form, so this evaluates it
772 /// directly instead of routing through the generic
773 /// [`jet_algebra::faa_di_bruno`] set-partition walker (recursion + per-block
774 /// closure dispatch). That matters because this is the kernel under EVERY
775 /// packed scalar — [`crate::jet_scalar::Order2`] / `OneSeed` / `TwoSeed`
776 /// composition all bottom out here — so the straight-line form (whose inner
777 /// loops auto-vectorise to NEON/SSE 2-wide and which emits zero outlined
778 /// walker calls) lifts all of them at once.
779 ///
780 /// The term and accumulation order is BIT-IDENTICAL to the walker it
781 /// replaces: each output channel mirrors the walker's `total = 0.0` start
782 /// (the explicit `acc` accumulator), so a signed-zero product collapses to
783 /// `+0.0` exactly as `total += prod` does. Proven `to_bits`-identical on
784 /// `v`/`g`/`h` across `K ∈ {2,3,4,9}`, 5000 random inputs each (incl.
785 /// zeroed / sign-varied stacks). The order-≤2 walker partitions are:
    ///
    /// - `g[i]` = `f′·u_i` (single block `{i}`)
    /// - `h[i][j]` = `f′·u_ij + (f″·u_i)·u_j` (blocks `{ij}` then `{i}{j}`)
    ///
    /// with `f′ = d[1]`, `f″ = d[2]`, `u_* = self.{g,h}`.
789 pub fn compose_unary(&self, d: [f64; 3]) -> Self {
790 let mut out = Self::zero();
791 out.v = d[0];
792 for i in 0..K {
793 let mut acc = 0.0;
794 acc += d[1] * self.g[i];
795 out.g[i] = acc;
796 }
797 for i in 0..K {
798 for j in 0..K {
799 let mut acc = 0.0;
800 acc += d[1] * self.h[i][j];
801 acc += d[2] * self.g[i] * self.g[j];
802 out.h[i][j] = acc;
803 }
804 }
805 out
806 }
807
808 /// Multiply every channel by a plain scalar.
809 pub fn scale(&self, s: f64) -> Self {
810 let mut out = *self;
811 out.v *= s;
812 for i in 0..K {
813 out.g[i] *= s;
814 for j in 0..K {
815 out.h[i][j] *= s;
816 }
817 }
818 out
819 }
820
821 /// e^self.
822 pub fn exp(&self) -> Self {
823 let e = self.v.exp();
824 self.compose_unary([e, e, e])
825 }
826
827 /// √self. Caller guarantees positivity.
828 pub fn sqrt(&self) -> Self {
829 let u = self.v;
830 let s = u.sqrt();
831 self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
832 }
833}
834
835impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
836 #[inline]
837 fn derivative(&self, labels: &[usize]) -> f64 {
838 self.deriv(labels)
839 }
840
841 fn map_derivatives<F>(&self, mut f: F) -> Self
842 where
843 F: FnMut(&[usize]) -> f64,
844 {
845 let mut out = Self::zero();
846 out.v = f(&[]);
847 for i in 0..K {
848 let labels = [i];
849 out.g[i] = f(&labels);
850 }
851 for i in 0..K {
852 for j in 0..K {
853 let labels = [i, j];
854 out.h[i][j] = f(&labels);
855 }
856 }
857 out
858 }
859}
860
861impl<const K: usize> std::ops::Add for Tower2<K> {
862 type Output = Self;
863 fn add(self, o: Self) -> Self {
864 let mut out = self;
865 out.v += o.v;
866 for i in 0..K {
867 out.g[i] += o.g[i];
868 for j in 0..K {
869 out.h[i][j] += o.h[i][j];
870 }
871 }
872 out
873 }
874}
875
876impl<const K: usize> std::ops::Mul for Tower2<K> {
877 type Output = Self;
878 fn mul(self, o: Self) -> Self {
879 Tower2::mul(&self, &o)
880 }
881}
882
883impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
884 type Output = Self;
885 fn add(self, c: f64) -> Self {
886 let mut out = self;
887 out.v += c;
888 out
889 }
890}
891
892impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
893 type Output = Self;
894 fn mul(self, c: f64) -> Self {
895 self.scale(c)
896 }
897}
898
899/// Truncated THIRD-order multivariate Taylor scalar in `K` variables.
900///
901/// The value/gradient/Hessian/third-derivative sibling of [`Tower4`], standing
902/// between [`Tower2`] and [`Tower4`]. Every channel it carries (`v`, `g`, `h`,
903/// `t3`) is computed by the SAME shared Leibniz / Faà-di-Bruno kernels
904/// [`Tower4`] uses for those orders, and the order-≤3 terms of those kernels
905/// read only the order-≤3 channels of their inputs (the order-3 Faà-di-Bruno
906/// partitions never reach the f⁗ stack slot or the inner `t4` tensor — see
907/// [`Tower4::compose_unary`]). So for any program written over both towers the
908/// order-≤3 outputs are *bit-identical*: dropping the fourth tensor cannot
909/// perturb the value, gradient, Hessian, or third derivatives.
910///
911/// It exists purely for performance, exactly like [`Tower2`]: a consumer that
912/// needs up to third derivatives (the survival location-scale row kernel reads
913/// `g`, the diagonal `h`, and the diagonal `t3`, but never `t4`) pays the
914/// `K³` third-tensor arithmetic but skips the `K⁴` fourth-tensor
915/// product/composition that otherwise dominates the per-row cost.
916#[derive(Clone, Copy, Debug)]
917pub struct Tower3<const K: usize> {
918 /// Value ℓ.
919 pub v: f64,
920 /// Gradient ∂ℓ/∂p_a.
921 pub g: [f64; K],
922 /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
923 pub h: [[f64; K]; K],
924 /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
925 pub t3: [[[f64; K]; K]; K],
926}
927
928impl<const K: usize> Tower3<K> {
929 /// The additive identity.
930 pub fn zero() -> Self {
931 Self {
932 v: 0.0,
933 g: [0.0; K],
934 h: [[0.0; K]; K],
935 t3: [[[0.0; K]; K]; K],
936 }
937 }
938
939 /// A constant: value `c`, all derivatives zero.
940 pub fn constant(c: f64) -> Self {
941 let mut out = Self::zero();
942 out.v = c;
943 out
944 }
945
946 /// The seeded variable `p_idx` with current value `value`:
947 /// unit first derivative in slot `idx`, zero elsewhere and above.
948 pub fn variable(value: f64, idx: usize) -> Self {
949 let mut out = Self::constant(value);
950 out.g[idx] = 1.0;
951 out
952 }
953
954 /// Read the (fully symmetric) derivative tensor entry whose differentiation
955 /// axes are `labels` (length 0..=3): value, `g`, `h`, `t3`.
956 #[inline]
957 fn deriv(&self, labels: &[usize]) -> f64 {
958 assert!(
959 labels.len() <= 3,
960 "Tower3 carries at most third-order derivatives"
961 );
962 match labels.len() {
963 0 => self.v,
964 1 => self.g[labels[0]],
965 2 => self.h[labels[0]][labels[1]],
966 _ => self.t3[labels[0]][labels[1]][labels[2]],
967 }
968 }
969
970 /// Exact truncated (order ≤ 3) Leibniz product. The `v`/`g`/`h`/`t3`
971 /// channels match [`Tower4::mul`] term-for-term.
972 ///
973 /// # Codegen
974 ///
975 /// Straight-line per-entry subset sums instead of the
976 /// [`jet_algebra::leibniz_product`] walker — the order-≤3 sibling of
977 /// [`Tower4::mul`] (no `t4`). Loop nest unchanged, no unroll over `K`, no
978 /// code bloat; auto-vectorises. BIT-IDENTICAL: terms in the walker's exact
979 /// subset order with an `acc = 0.0` accumulator start (load-bearing for the
980 /// signed-zero leading product on exact-`0.0` jet channels). Proven
981 /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
982 /// zero/sign-stressed inputs each (these channel formulas are exactly the
983 /// `g`/`h`/`t3` of the [`Tower4::mul`] oracle, which passes that stress).
984 pub fn mul(&self, o: &Self) -> Self {
985 let a = self;
986 let b = o;
987 let mut out = Self::zero();
988 out.v = a.v * b.v;
989 for i in 0..K {
990 let mut acc = 0.0;
991 acc += a.v * b.g[i];
992 acc += a.g[i] * b.v;
993 out.g[i] = acc;
994 }
995 // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
996 for i in 0..K {
997 for j in i..K {
998 let mut acc = 0.0;
999 acc += a.v * b.h[i][j];
1000 acc += a.g[i] * b.g[j];
1001 acc += a.g[j] * b.g[i];
1002 acc += a.h[i][j] * b.v;
1003 out.h[i][j] = acc;
1004 out.h[j][i] = acc;
1005 }
1006 }
1007 for i in 0..K {
1008 for j in 0..K {
1009 for k in 0..K {
1010 // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
1011 let mut acc = 0.0;
1012 acc += a.v * b.t3[i][j][k];
1013 acc += a.g[i] * b.h[j][k];
1014 acc += a.g[j] * b.h[i][k];
1015 acc += a.h[i][j] * b.g[k];
1016 acc += a.g[k] * b.h[i][j];
1017 acc += a.h[i][k] * b.g[j];
1018 acc += a.h[j][k] * b.g[i];
1019 acc += a.t3[i][j][k] * b.v;
1020 out.t3[i][j][k] = acc;
1021 }
1022 }
1023 }
1024 out
1025 }
1026
1027 /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
1028 /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
1029 /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
1030 /// out of the borrowed operands.
1031 pub fn add(&self, o: &Self) -> Self {
1032 *self + *o
1033 }
1034
1035 /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
1036 pub fn sub(&self, o: &Self) -> Self {
1037 *self + o.scale(-1.0)
1038 }
1039
1040 /// Exact (order ≤ 3) multivariate Faà di Bruno composition `f ∘ self`.
1041 ///
1042 /// `d = [f(u), f′(u), f″(u), f‴(u)]` evaluated at `u = self.v`. The
1043 /// `v`/`g`/`h`/`t3` channels match [`Tower4::compose_unary`] term-for-term
1044 /// (which uses only `d[0..=3]` for those orders), so this is a strict
1045 /// truncation, not an approximation. The full-order `[f64; 5]` derivative
1046 /// stacks the families already produce can be passed by slicing their first
1047 /// four entries.
1048 ///
1049 /// # Codegen
1050 ///
    /// Order-≤3 Faà di Bruno written as a compact closed form instead of the
    /// recursive [`jet_algebra::faa_di_bruno`] walker: the order-≤3 sibling of
    /// [`Tower4::compose_unary`], one tensor order shallower. The loop nest is
    /// unchanged, with no unroll over `K` and no code bloat. Measured on a
    /// `Tower3<9>` compose-and-read consumer, the new form is faster and
    /// SMALLER (asm: 71 walker `bl` calls → 0, 39.5 KiB → 13.9 KiB, +197 NEON
    /// `.2d` ops). BIT-IDENTICAL: terms in the walker's exact partition order,
    /// left-associated block products, `acc = 0.0` accumulator start. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// random inputs each.
1061 pub fn compose_unary(&self, d: [f64; 4]) -> Self {
1062 let mut out = Self::zero();
1063 out.v = d[0];
1064 for i in 0..K {
1065 let mut acc = 0.0;
1066 acc += d[1] * self.g[i];
1067 out.g[i] = acc;
1068 }
1069 for i in 0..K {
1070 for j in 0..K {
1071 let mut acc = 0.0;
1072 acc += d[1] * self.h[i][j];
1073 acc += d[2] * self.g[i] * self.g[j];
1074 out.h[i][j] = acc;
1075 }
1076 }
1077 for i in 0..K {
1078 for j in 0..K {
1079 for k in 0..K {
1080 // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
1081 let mut acc = 0.0;
1082 acc += d[1] * self.t3[i][j][k];
1083 acc += d[2] * self.h[i][j] * self.g[k];
1084 acc += d[2] * self.h[i][k] * self.g[j];
1085 acc += d[2] * self.g[i] * self.h[j][k];
1086 acc += d[3] * self.g[i] * self.g[j] * self.g[k];
1087 out.t3[i][j][k] = acc;
1088 }
1089 }
1090 }
1091 out
1092 }
1093
1094 /// Compose with a unary special-function whose `[f64; 4]` derivative STACK is
1095 /// built from the base value through `stack_fn` — the scalar arm of the
1096 /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
1097 /// [`Tower3Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
1098 /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
1099 /// `self.compose_unary(stack_fn(self.v))`. The order-≤3 sibling of
1100 /// [`Tower4::compose_unary_with`].
1101 #[inline]
1102 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
1103 self.compose_unary(stack_fn(self.v))
1104 }
1105
1106 /// Single-active-slot fast path for [`Self::compose_unary`] — the order-≤3
1107 /// sibling of [`Tower4::compose_unary_single_slot`]. When `self` carries
1108 /// derivative support only on the all-`slot` diagonal, every output channel
1109 /// touching an axis `≠ slot` collapses to the walker's `total = 0.0` start
1110 /// (`+0.0`), so only `v`, `g[slot]`, `h[slot][slot]`, `t3[slot]³` survive.
1111 /// These four are computed as STRAIGHT-LINE accumulations, each in the EXACT
1112 /// term order of [`Self::compose_unary`]'s diagonal (`i = j = k = slot`)
1113 /// case (BIT-IDENTICAL to the full path on the diagonal); off-`slot`
1114 /// channels stay at the zero-init `+0.0` the full walk also yields (proven
1115 /// `to_bits` across `K ∈ {2,3,4,9}`). This drops the recursive
1116 /// set-partition walker the diagonal channels previously routed through,
1117 /// recovering its measured ~5.9× regression at the `K ∈ {2,3}` BMS tower
1118 /// widths. Caller guarantees the single-slot precondition; otherwise use
1119 /// [`Self::compose_unary`].
1120 #[inline]
1121 pub fn compose_unary_single_slot(&self, d: [f64; 4], slot: usize) -> Self {
1122 let mut out = Self::zero();
1123 let s = slot;
1124 let g = self.g[s];
1125 let h = self.h[s][s];
1126 let t3 = self.t3[s][s][s];
1127 out.v = d[0];
1128 // g (i=s): d1*g
1129 out.g[s] = {
1130 let mut acc = 0.0;
1131 acc += d[1] * g;
1132 acc
1133 };
1134 // h (i=j=s): d1*h + d2*g*g
1135 out.h[s][s] = {
1136 let mut acc = 0.0;
1137 acc += d[1] * h;
1138 acc += d[2] * g * g;
1139 acc
1140 };
1141 // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
1142 out.t3[s][s][s] = {
1143 let mut acc = 0.0;
1144 acc += d[1] * t3;
1145 acc += d[2] * h * g;
1146 acc += d[2] * h * g;
1147 acc += d[2] * g * h;
1148 acc += d[3] * g * g * g;
1149 acc
1150 };
1151 out
1152 }
1153
1154 /// Multiply every channel by a plain scalar.
1155 pub fn scale(&self, s: f64) -> Self {
1156 let mut out = *self;
1157 out.v *= s;
1158 for i in 0..K {
1159 out.g[i] *= s;
1160 for j in 0..K {
1161 out.h[i][j] *= s;
1162 for k in 0..K {
1163 out.t3[i][j][k] *= s;
1164 }
1165 }
1166 }
1167 out
1168 }
1169}
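/*
Illustrative sketch, not part of this module: the order-≤3 Leibniz product and
Faà di Bruno composition implemented by `Tower3::mul` and
`Tower3::compose_unary` collapse, in one variable, to the familiar binomial and
partition coefficients. `Jet3`, `square`, `cube`, and `ln_stack` are
hypothetical names for this example only. The three equal cross terms that the
multivariate code adds one by one appear here folded into a factor of 3, so
this sketch is not bit-identical to the tower code.

```rust
/// Hypothetical one-variable jet: `[f, f', f'', f''']` at a point.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Jet3(pub [f64; 4]);

impl Jet3 {
    /// Truncated Leibniz product through third order.
    pub fn mul(&self, o: &Jet3) -> Jet3 {
        let (a, b) = (self.0, o.0);
        Jet3([
            a[0] * b[0],
            a[0] * b[1] + a[1] * b[0],
            a[0] * b[2] + 2.0 * a[1] * b[1] + a[2] * b[0],
            a[0] * b[3] + 3.0 * a[1] * b[2] + 3.0 * a[2] * b[1] + a[3] * b[0],
        ])
    }

    /// Composition `f(self)`, with `d = [f, f', f'', f''']` at `self.0[0]`.
    pub fn compose(&self, d: [f64; 4]) -> Jet3 {
        let u = self.0;
        Jet3([
            d[0],
            d[1] * u[1],
            d[1] * u[2] + d[2] * u[1] * u[1],
            d[1] * u[3] + 3.0 * d[2] * u[1] * u[2] + d[3] * u[1] * u[1] * u[1],
        ])
    }
}

/// x^2 as a jet at `x`.
pub fn square(x: f64) -> Jet3 {
    Jet3([x * x, 2.0 * x, 2.0, 0.0])
}

/// x^3 as a jet at `x`.
pub fn cube(x: f64) -> Jet3 {
    Jet3([x * x * x, 3.0 * x * x, 6.0 * x, 6.0])
}

/// Derivative stack of the natural log at `u > 0`.
pub fn ln_stack(u: f64) -> [f64; 4] {
    [u.ln(), 1.0 / u, -1.0 / (u * u), 2.0 / (u * u * u)]
}
```

At `x = 2`, `square(2.0).mul(&cube(2.0))` carries the derivatives of `x⁵`, and
`square(2.0).compose(ln_stack(4.0))` carries those of `ln(x²) = 2 ln x`, which
gives two independent closed forms to compare against.
*/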
1170
1171impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
1172 #[inline]
1173 fn derivative(&self, labels: &[usize]) -> f64 {
1174 self.deriv(labels)
1175 }
1176
1177 fn map_derivatives<F>(&self, mut f: F) -> Self
1178 where
1179 F: FnMut(&[usize]) -> f64,
1180 {
1181 let mut out = Self::zero();
1182 out.v = f(&[]);
1183 for i in 0..K {
1184 let labels = [i];
1185 out.g[i] = f(&labels);
1186 }
1187 for i in 0..K {
1188 for j in 0..K {
1189 let labels = [i, j];
1190 out.h[i][j] = f(&labels);
1191 }
1192 }
1193 for i in 0..K {
1194 for j in 0..K {
1195 for k in 0..K {
1196 let labels = [i, j, k];
1197 out.t3[i][j][k] = f(&labels);
1198 }
1199 }
1200 }
1201 out
1202 }
1203}
1204
1205impl<const K: usize> std::ops::Add for Tower3<K> {
1206 type Output = Self;
1207 fn add(self, o: Self) -> Self {
1208 let mut out = self;
1209 out.v += o.v;
1210 for i in 0..K {
1211 out.g[i] += o.g[i];
1212 for j in 0..K {
1213 out.h[i][j] += o.h[i][j];
1214 for k in 0..K {
1215 out.t3[i][j][k] += o.t3[i][j][k];
1216 }
1217 }
1218 }
1219 out
1220 }
1221}
1222
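/// Derivative stack `[ln Γ(x), ψ(x), ψ′(x), ψ″(x), ψ‴(x)]` of `ln Γ` at `x`,
/// in the `[f, f′, f″, f‴, f⁗]` layout the order-4 `compose_unary` consumes.
/// The ψ-family entries are NaN unless `x` is finite and positive.
pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {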
1224 [
1225 statrs::function::gamma::ln_gamma(x),
1226 digamma_positive(x),
1227 polygamma_positive(1, x),
1228 polygamma_positive(2, x),
1229 polygamma_positive(3, x),
1230 ]
1231}
1232
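/// Order-≤2 truncation `[ln Γ(x), ψ(x), ψ′(x)]` of
/// [`ln_gamma_derivative_stack`], sized for the `[f64; 3]` stack that the
/// order-≤2 `compose_unary` takes.
pub fn ln_gamma_derivative_stack_order2(x: f64) -> [f64; 3] {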
1234 [
1235 statrs::function::gamma::ln_gamma(x),
1236 digamma_positive(x),
1237 polygamma_positive(1, x),
1238 ]
1239}
1240
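/// Derivative stack `[ψ(x), ψ′(x), ψ″(x), ψ‴(x), ψ⁗(x)]` of the digamma
/// function at `x`. Every entry is NaN unless `x` is finite and positive.
pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {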
1242 [
1243 digamma_positive(x),
1244 polygamma_positive(1, x),
1245 polygamma_positive(2, x),
1246 polygamma_positive(3, x),
1247 polygamma_positive(4, x),
1248 ]
1249}
1250
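/// Derivative stack `[ψ′(x), ψ″(x), ψ‴(x), ψ⁗(x), ψ⁽⁵⁾(x)]` of the trigamma
/// function at `x`. Every entry is NaN unless `x` is finite and positive.
pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {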
1252 [
1253 polygamma_positive(1, x),
1254 polygamma_positive(2, x),
1255 polygamma_positive(3, x),
1256 polygamma_positive(4, x),
1257 polygamma_positive(5, x),
1258 ]
1259}
1260
/// Scalar digamma ψ(x) for x>0. Bit-identical to `digamma_derivative_stack(x)[0]`
/// and to `ln_gamma_derivative_stack(x)[1]`, but evaluates ONLY ψ: the four
/// other entries those `[f64; 5]` stacks build are pure discarded work at a
/// scalar consumer that reads a single element. Hot-path row kernels that need
/// only the digamma value (e.g. the GAMLSS Beta observed cross weight) call this
/// instead of indexing into a full derivative stack.
1267#[inline]
1268pub fn digamma(x: f64) -> f64 {
1269 digamma_positive(x)
1270}
1271
1272/// Scalar trigamma ψ′(x) for x>0. Bit-identical to
1273/// `trigamma_derivative_stack(x)[0]` (both bottom out in `polygamma_positive(1,
1274/// x)`), but evaluates ONLY ψ′ — the four higher polygammas (orders 2–5) the
1275/// `[f64; 5]` stack builds are discarded at a `[0]` consumer. Used by the
1276/// dispersion-channel Fisher-information row kernels (NB2 `ψ′(θ)−ψ′(θ+μ)`, Beta
1277/// `μψ′(μφ)−(1−μ)ψ′((1−μ)φ)`) which read the trigamma value alone.
1278#[inline]
1279pub fn trigamma(x: f64) -> f64 {
1280 polygamma_positive(1, x)
1281}
1282
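/// ψ(x) for finite `x > 0` (NaN otherwise): shift `x` upward with the
/// recurrence ψ(x) = ψ(x + 1) − 1/x until `x ≥ POLYGAMMA_ASYMPTOTIC_MIN_X`,
/// then evaluate the asymptotic series there.
fn digamma_positive(mut x: f64) -> f64 {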
1284 if !(x.is_finite() && x > 0.0) {
1285 return f64::NAN;
1286 }
1287 let mut acc = 0.0;
1288 while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1289 acc -= 1.0 / x;
1290 x += 1.0;
1291 }
1292 acc + digamma_asymptotic(x)
1293}
1294
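/// ψ⁽ⁿ⁾(x) for `order = n` in `1..=5` and finite `x > 0` (NaN otherwise), by
/// the same upward recurrence plus asymptotic tail as [`digamma_positive`].
fn polygamma_positive(order: usize, mut x: f64) -> f64 {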
1296 if !(x.is_finite() && x > 0.0) {
1297 return f64::NAN;
1298 }
1299 let mut acc = 0.0;
1300 while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1301 acc += polygamma_recurrence_term(order, x);
1302 x += 1.0;
1303 }
1304 acc + polygamma_asymptotic(order, x)
1305}
1306
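/// Smallest argument at which the asymptotic series is evaluated; smaller
/// arguments are first shifted up by the recurrence.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
/// Even-index Bernoulli numbers as `(2k, B_2k)` pairs for `2k = 2..=20`.
const BERNOULLI_EVEN: [(usize, f64); 10] = [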
1309 (2, 1.0 / 6.0),
1310 (4, -1.0 / 30.0),
1311 (6, 1.0 / 42.0),
1312 (8, -1.0 / 30.0),
1313 (10, 5.0 / 66.0),
1314 (12, -691.0 / 2730.0),
1315 (14, 7.0 / 6.0),
1316 (16, -3617.0 / 510.0),
1317 (18, 43867.0 / 798.0),
1318 (20, -174611.0 / 330.0),
1319];
1320
/// Recurrence increment `(−1)ⁿ⁺¹ · n! / xⁿ⁺¹` with `n = order`, i.e.
/// ψ⁽ⁿ⁾(x) − ψ⁽ⁿ⁾(x + 1).
fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
1322 let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1323 sign * factorial(order) / x.powi((order + 1) as i32)
1324}
1325
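/// Asymptotic series ψ(x) ≈ ln x − 1/(2x) − Σₖ B₂ₖ / (2k · x²ᵏ), truncated at
/// the terms listed in `BERNOULLI_EVEN`.
fn digamma_asymptotic(x: f64) -> f64 {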
1327 let mut out = x.ln() - 0.5 / x;
1328 for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1329 out -= bernoulli / (bernoulli_order as f64 * x.powi(bernoulli_order as i32));
1330 }
1331 out
1332}
1333
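/// Asymptotic series for `n = order` in `1..=5` (NaN otherwise):
/// ψ⁽ⁿ⁾(x) ≈ (−1)ⁿ⁺¹ · [ (n−1)!/xⁿ + n!/(2xⁿ⁺¹) + Σₖ B₂ₖ · (2k+n−1)!/((2k)! · x²ᵏ⁺ⁿ) ].
fn polygamma_asymptotic(order: usize, x: f64) -> f64 {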
1335 if !(1..=5).contains(&order) {
1336 return f64::NAN;
1337 }
1338
1339 let order_factorial = factorial(order);
1340 let leading_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1341 let mut out = leading_sign * factorial(order - 1) / x.powi(order as i32)
1342 + leading_sign * order_factorial / (2.0 * x.powi((order + 1) as i32));
1343
1344 let bernoulli_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1345 for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1346 let rising = rising_factorial(bernoulli_order, order);
1347 out += bernoulli_sign * bernoulli * rising
1348 / bernoulli_order as f64
1349 / x.powi((bernoulli_order + order) as i32);
1350 }
1351 out
1352}
1353
/// `n!` as an `f64`.
fn factorial(n: usize) -> f64 {
1355 (1..=n).fold(1.0, |acc, k| acc * k as f64)
1356}
1357
/// Rising product `start · (start + 1) ⋯ (start + len − 1)` as an `f64`.
fn rising_factorial(start: usize, len: usize) -> f64 {
1359 (start..start + len).fold(1.0, |acc, k| acc * k as f64)
1360}
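/*
Illustrative sketch, not part of this module: the recurrence-then-asymptotic
scheme used by `polygamma_positive` can be written out for the trigamma case
alone. `trigamma_sketch` and `B_EVEN` are hypothetical names, and the sketch
assumes only five Bernoulli terms, which is a shorter table than
`BERNOULLI_EVEN` above.

```rust
/// Even-index Bernoulli numbers B2..B10 as (2k, B_2k) pairs.
const B_EVEN: [(i32, f64); 5] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
];

/// Trigamma for finite x > 0; NaN otherwise.
pub fn trigamma_sketch(mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    // Recurrence: psi'(x) = psi'(x + 1) + 1/x^2, applied until x >= 20.
    while x < 20.0 {
        acc += 1.0 / (x * x);
        x += 1.0;
    }
    // Asymptotic tail: psi'(x) ~ 1/x + 1/(2 x^2) + sum_k B_2k / x^(2k+1).
    let mut tail = 1.0 / x + 0.5 / (x * x);
    for (two_k, b) in B_EVEN {
        tail += b / x.powi(two_k + 1);
    }
    acc + tail
}
```

Two closed forms are convenient for checking such a routine: ψ′(1) = π²/6 and
ψ′(1/2) = π²/2.
*/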
1361
1362impl<const K: usize> std::ops::Add for Tower4<K> {
1363 type Output = Self;
1364 fn add(self, o: Self) -> Self {
1365 let mut out = self;
1366 out.v += o.v;
1367 for i in 0..K {
1368 out.g[i] += o.g[i];
1369 for j in 0..K {
1370 out.h[i][j] += o.h[i][j];
1371 for k in 0..K {
1372 out.t3[i][j][k] += o.t3[i][j][k];
1373 for l in 0..K {
1374 out.t4[i][j][k][l] += o.t4[i][j][k][l];
1375 }
1376 }
1377 }
1378 }
1379 out
1380 }
1381}
1382
1383impl<const K: usize> std::ops::Sub for Tower4<K> {
1384 type Output = Self;
1385 fn sub(self, o: Self) -> Self {
1386 self + o.scale(-1.0)
1387 }
1388}
1389
1390impl<const K: usize> std::ops::Neg for Tower4<K> {
1391 type Output = Self;
1392 fn neg(self) -> Self {
1393 self.scale(-1.0)
1394 }
1395}
1396
1397impl<const K: usize> std::ops::Mul for Tower4<K> {
1398 type Output = Self;
1399 fn mul(self, o: Self) -> Self {
1400 Tower4::mul(&self, &o)
1401 }
1402}
1403
1404impl<const K: usize> std::ops::Div for Tower4<K> {
1405 type Output = Self;
1406 fn div(self, o: Self) -> Self {
1407 Tower4::mul(&self, &o.recip())
1408 }
1409}
1410
1411impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
1412 type Output = Self;
1413 fn add(self, c: f64) -> Self {
1414 let mut out = self;
1415 out.v += c;
1416 out
1417 }
1418}
1419
1420impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
1421 type Output = Self;
1422 fn sub(self, c: f64) -> Self {
1423 self + (-c)
1424 }
1425}
1426
1427impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
1428 type Output = Self;
1429 fn mul(self, c: f64) -> Self {
1430 self.scale(c)
1431 }
1432}
1433
1434// ── Implicit-function and moving-boundary seams (#932 flex) ──────────
1435//
1436// The flexible survival marginal-slope row loss is NOT a free composition
1437// of the primaries: it threads an IMPLICIT calibration intercept `a(θ)`
1438// solving a constraint `F(a, θ) = 0`, and integrates a density over cells
// whose edges `z_L(θ), z_R(θ)` MOVE with θ through that intercept. Plain
// `Tower4` Faà di Bruno cannot express either, so the flex tower was the
// last hand-written one in the codebase and a source of #736-class drift
// bugs (the (g,w0) deviation-cross third was 3× short for exactly this
// reason). These two combinators close that gap: once the constraint
// `F` and the integrand/boundaries are themselves towers, the intercept's
// derivative tower and the integral's derivative tower come out EXACTLY at
// every order, so there is no order left to hand-code and forget.
1447
1448/// Solve the implicit relation `F(a(θ), θ) ≡ 0` for the intercept tower
1449/// `a(θ)` over the `K` primaries θ, given the constraint tower `f` written
1450/// over `K + 1` variables (slot `0` is the intercept `a`, slots `1..=K`
1451/// are the primaries θ) evaluated at the SOLVED point — i.e. `f.v` is the
1452/// constraint residual at `(a₀, θ₀)` (≈ 0 from the production Newton solve)
1453/// and `a0` is that solved intercept value.
1454///
1455/// Returns the `Tower4<K>` whose value is `a0` and whose every derivative
1456/// tensor (∂a/∂θ, ∂²a/∂θ², …, ∂⁴a/∂θ⁴) is the exact implicit-function
1457/// derivative. This is the mechanical replacement for the hand-coded
1458/// `a_u = -f_u/f_a`, `a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a`
1459/// recursion (first_full.rs) and its third/fourth-order continuations.
1460///
1461/// Method: order-by-order substitution. We build `a` incrementally; at each
1462/// order `m` the composite `G(θ) = f(a(θ), θ)` has a top-order coefficient
1463/// that is linear in `a`'s order-`m` tensor with leading factor `F_a`
1464/// (= `f.g[0]`), plus terms in `a`'s lower orders already fixed. Setting the
1465/// order-`m` tensor of `a` to cancel the rest of `G`'s order-`m` coefficient
1466/// keeps `G ≡ 0` through that order. The substitution `G = f∘(a, θ)` reuses
1467/// only the exact [`substitute_intercept`] chain rule, so the recursion is
1468/// auditable and exact, not a hand-expanded formula per order.
1469///
1470/// `f.g[0]` (= ∂F/∂a) must be non-zero — guaranteed by the production
1471/// solve's strict monotonicity guard.
1472///
1473/// The expansion point `a0` must be a genuine root `F(a0, θ0) = 0`: the
1474/// substitution recursion below cancels orders 1..=4 of `G = F∘a` but never
1475/// touches order 0, so a non-root `a0` would yield the Taylor expansion of
1476/// the LEVEL SET `F = F(a0)` through `a0`, not the root curve `F = 0`. This
1477/// is guarded explicitly and re-verified by a composed-residual self-check.
1478pub fn implicit_solve<const K1: usize, const K: usize>(
1479 f: &Tower4<K1>,
1480 a0: f64,
1481) -> Result<Tower4<K>, String> {
1482 assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
1483 let f_a = f.g[0];
1484 if f_a == 0.0 || !f_a.is_finite() {
1485 return Err(format!(
1486 "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
1487 ));
1488 }
1489 // The expansion point must be a genuine root of F. The single Newton
1490 // correction that would move a0 onto the root is |f.v|/|f_a|; require it
1491 // to be negligible relative to the natural scale (1 + |a0|). Guarding the
1492 // Newton step (rather than f.v directly) makes the criterion invariant to
1493 // the magnitude of f_a / the units of F.
1494 let root_tol = 1e-9;
1495 if !f.v.is_finite() {
1496 return Err(format!(
1497 "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
1498 f.v
1499 ));
1500 }
1501 let newton_step = f.v.abs() / f_a.abs();
1502 if newton_step > root_tol * (1.0 + a0.abs()) {
1503 return Err(format!(
1504 "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
1505 F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
1506 root_tol {root_tol:.1e} · (1 + |a0|)",
1507 f.v
1508 ));
1509 }
1510 // Start with a = constant a0 (correct through order 0). Then lift each
1511 // order in turn. Because substitute_intercept reads `a`'s order-≤m
1512 // tensors when forming G's order-m coefficient, and the order-m
1513 // coefficient of G depends on a's order-m tensor ONLY through the linear
1514 // F_a·a_m term, a single corrective pass per order is exact.
1515 let mut a = Tower4::<K>::constant(a0);
1516 for order in 1..=4 {
1517 let g = substitute_intercept(f, &a);
1518 // Cancel G's order-`order` coefficient by adjusting a's order-`order`
1519 // tensor: a_m -= G_m / F_a (the F_a·a_m term is the only one carrying
1520 // a's order-m tensor, with unit chain coefficient since slot 0 seeds a
1521 // as a plain variable in the substitution's first-order part).
1522 match order {
1523 1 => {
1524 for i in 0..K {
1525 a.g[i] -= g.g[i] / f_a;
1526 }
1527 }
1528 2 => {
1529 for i in 0..K {
1530 for j in 0..K {
1531 a.h[i][j] -= g.h[i][j] / f_a;
1532 }
1533 }
1534 }
1535 3 => {
1536 for i in 0..K {
1537 for j in 0..K {
1538 for k in 0..K {
1539 a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
1540 }
1541 }
1542 }
1543 }
1544 _ => {
1545 for i in 0..K {
1546 for j in 0..K {
1547 for k in 0..K {
1548 for l in 0..K {
1549 a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
1550 }
1551 }
1552 }
1553 }
1554 }
1555 }
1556 }
1557 // Self-check: the composed residual G = F∘a must vanish through order 4.
1558 // By construction orders 1..=4 were cancelled; the value G.v == F(a0,θ0)
1559 // is exactly the root requirement guarded above. Re-verify all channels
1560 // against a scale-aware floor so any arithmetic regression in the
1561 // substitution recursion is loud rather than silently shipping a
1562 // level-set expansion.
1563 let g = substitute_intercept(f, &a);
1564 let resid_tol = 1e-7 * (1.0 + f_a.abs());
1565 let mut worst = g.v.abs();
1566 for i in 0..K {
1567 worst = worst.max(g.g[i].abs());
1568 for j in 0..K {
1569 worst = worst.max(g.h[i][j].abs());
1570 for k in 0..K {
1571 worst = worst.max(g.t3[i][j][k].abs());
1572 for l in 0..K {
1573 worst = worst.max(g.t4[i][j][k][l].abs());
1574 }
1575 }
1576 }
1577 }
1578 if !worst.is_finite() || worst > resid_tol {
1579 return Err(format!(
1580 "implicit_solve: composed residual G = F∘a does not vanish: \
1581 worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
1582 ));
1583 }
1584 Ok(a)
1585}
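/*
Illustrative sketch, not part of this module: the hand-coded recursion that
`implicit_solve` replaces, `a_u = -f_u/f_a` and
`a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a`, written for a single
primary (`u = v`) so it can be checked against a closed form.
`ConstraintPartials`, `implicit_first_second`, and `sqrt_constraint_example`
are hypothetical names for this example only.

```rust
/// Partials of a constraint F(a, u) at a solved point, one primary `u`.
pub struct ConstraintPartials {
    pub f_a: f64,
    pub f_u: f64,
    pub f_aa: f64,
    pub f_au: f64,
    pub f_uu: f64,
}

/// First and second implicit derivatives (a_u, a_uu) of the root a(u).
/// Returns None when F_a is zero or non-finite.
pub fn implicit_first_second(p: &ConstraintPartials) -> Option<(f64, f64)> {
    if p.f_a == 0.0 || !p.f_a.is_finite() {
        return None;
    }
    let a_u = -p.f_u / p.f_a;
    // With u = v the two mixed terms f_au*a_v and f_av*a_u coincide.
    let a_uu = -(p.f_uu + 2.0 * p.f_au * a_u + p.f_aa * a_u * a_u) / p.f_a;
    Some((a_u, a_uu))
}

/// F(a, u) = a^2 - u at u = 4, a = 2, whose root curve is a(u) = sqrt(u).
pub fn sqrt_constraint_example() -> Option<(f64, f64)> {
    let a = 2.0;
    implicit_first_second(&ConstraintPartials {
        f_a: 2.0 * a,
        f_u: -1.0,
        f_aa: 2.0,
        f_au: 0.0,
        f_uu: 0.0,
    })
}
```

The closed form for comparison is `a′ = 1/(2√u)` and `a″ = −1/(4u√u)` at
`u = 4`. The tower-based routine above reaches the same numbers by cancelling
one order of the composite at a time instead of expanding a formula per order.
*/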
1586
1587/// Substitute the intercept tower `a(θ)` into slot `0` of a constraint
1588/// written over `K + 1` variables, returning the composite tower over the
1589/// `K` primaries θ: `G(θ) = f(a(θ), θ₁, …, θ_K)`.
1590///
1591/// This is the exact multivariate chain rule specialised to "slot 0 is a
1592/// dependent tower, slots 1..=K are the independent primaries". It evaluates
1593/// `f`'s fourth-order multivariate Taylor polynomial about the expansion
1594/// point, with the slot-0 increment being the non-constant part of `a` and
1595/// the slot-(i) increment being the unit-seeded primary `θ_i`. The sum is
1596/// assembled by the same subset/partition algebra `Tower4` arithmetic uses,
1597/// so it carries derivatives exactly through order four.
1598pub fn substitute_intercept<const K1: usize, const K: usize>(
1599 f: &Tower4<K1>,
1600 a: &Tower4<K>,
1601) -> Tower4<K> {
1602 assert_eq!(K1, K + 1);
    // Build the K+1 input towers in θ-space: slot 0 = a(θ), slot i+1 = θ_i.
    // The composite is the Taylor sum over ORDERED label tuples s (|s| ≤ 4)
    // of input indices: (1/|s|!) · f.deriv(s) · Π_{j in s} (inp[s_j] centred).
    // Because f.deriv is the SYMMETRIC partial tensor and every ordering of a
    // multi-index is visited, the explicit 1/|s|! factor combines with the
    // permutation multiplicity to give the usual multi-index Taylor
    // coefficients. We assemble it directly as a Horner-free explicit sum
    // over the (K+1)-ary tuples, using tower products for the increment
    // monomials so all θ-derivatives propagate exactly.
1611 let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
1612 if slot == 0 {
1613 // slot 0: a(θ) minus its constant value (the increment δa(θ)).
1614 let mut d = *a;
1615 d.v = 0.0;
1616 d
1617 } else {
1618 // slot i: the increment δθ_{i-1} = seeded variable minus value.
1619 // θ centred at its expansion value has zero constant term and unit
1620 // first derivative in its own slot.
1621 let mut d = Tower4::<K>::zero();
1622 d.g[slot - 1] = 1.0;
1623 d
1624 }
1625 });
1626 // Accumulate the Taylor sum. order-0 term:
1627 let mut out = Tower4::<K>::constant(f.v);
1628 // order 1: Σ_a f.g[a] · inp[a]
1629 for a_idx in 0..K1 {
1630 out = out + inp[a_idx].scale(f.g[a_idx]);
1631 }
1632 // order 2: (1/2) Σ_{a,b} f.h[a][b] · inp[a]·inp[b]
1633 for a_idx in 0..K1 {
1634 for b_idx in 0..K1 {
1635 let prod = inp[a_idx].mul(&inp[b_idx]);
1636 out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
1637 }
1638 }
1639 // order 3: (1/6) Σ f.t3[a][b][c] · inp[a]·inp[b]·inp[c]
1640 for a_idx in 0..K1 {
1641 for b_idx in 0..K1 {
1642 for c_idx in 0..K1 {
1643 let prod = inp[a_idx].mul(&inp[b_idx]).mul(&inp[c_idx]);
1644 out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
1645 }
1646 }
1647 }
1648 // order 4: (1/24) Σ f.t4[a][b][c][d] · inp[a]·inp[b]·inp[c]·inp[d]
1649 for a_idx in 0..K1 {
1650 for b_idx in 0..K1 {
1651 for c_idx in 0..K1 {
1652 for d_idx in 0..K1 {
1653 let prod = inp[a_idx]
1654 .mul(&inp[b_idx])
1655 .mul(&inp[c_idx])
1656 .mul(&inp[d_idx]);
1657 out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
1658 }
1659 }
1660 }
1661 }
1662 out
1663}
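/*
Illustrative sketch, not part of this module: the ordered-tuple Taylor sum that
`substitute_intercept` assembles, cut down to second order in two variables and
with plain `f64` increments in place of tower increments. `taylor2`, `quad`,
and `taylor_of_quad_about_1_2` are hypothetical names for this example only.

```rust
/// Second-order Taylor polynomial in two variables, summed over ORDERED
/// index pairs with the 1/2! factor kept explicit.
pub fn taylor2(f0: f64, g: [f64; 2], h: [[f64; 2]; 2], d: [f64; 2]) -> f64 {
    let mut out = f0;
    for a in 0..2 {
        out += g[a] * d[a];
    }
    for a in 0..2 {
        for b in 0..2 {
            // Both (a, b) and (b, a) are visited; the 0.5 compensates.
            out += 0.5 * h[a][b] * d[a] * d[b];
        }
    }
    out
}

/// A quadratic, so its order-2 Taylor polynomial reproduces it exactly.
pub fn quad(x: f64, y: f64) -> f64 {
    x * x + x * y
}

/// Taylor polynomial of `quad` about (1, 2), evaluated at increment `d`.
pub fn taylor_of_quad_about_1_2(d: [f64; 2]) -> f64 {
    // Gradient (2x + y, x) and Hessian [[2, 1], [1, 0]] at (1, 2).
    taylor2(quad(1.0, 2.0), [4.0, 1.0], [[2.0, 1.0], [1.0, 0.0]], d)
}
```

The off-diagonal Hessian entry is visited twice, once per ordering, and the
explicit `0.5` brings its total weight back to the unit coefficient of the
`δx·δy` term. The third- and fourth-order loops above follow the same pattern
with `1/6` and `1/24`.
*/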
1664
1665/// The exact θ-derivative tower of a moving-LIMIT integral's BOUNDARY
1666/// contribution: given the edge-position tower `z_edge(θ)` over the `K`
1667/// primaries and the integrand `B` evaluated-and-differentiated at the edge
1668/// value as the stack `b_stack = [B(z₀), B′(z₀), B″(z₀), B‴(z₀)]`
1669/// (`z₀ = z_edge.v`), returns the tower of `Φ(z_edge(θ))` where `Φ′ = B`.
1670///
/// Rationale: `∫^{z_edge(θ)} B(z) dz = Φ(z_edge(θ))` up to a θ-independent
/// constant, with `Φ` an antiderivative of `B`, so the boundary part of every
/// θ-derivative of the integral is the matching derivative of the
/// composition `Φ ∘ z_edge`, whose Faà di Bruno
1674/// expansion carries, at one stroke, EVERY Leibniz boundary term the
1675/// hand-written flux dropped: the first-order `B·z_u`, the second-order
1676/// `B′·z_u·z_v + B·z_uv` (the `G_z·z_u·z_v` self-flux AND the previously
1677/// dropped `G·z_uv`), and the full third/fourth-order continuations. The
1678/// VALUE channel of the returned tower is meaningless (`Φ` is only defined up
1679/// to a constant); callers read only the derivative channels and pair this
1680/// with the interior moment-integral value separately.
1681///
/// `b_stack` holds `B` and its first three z-derivatives. Since `Φ′ = B`, these
/// are the order-1 through order-4 entries of `Φ`'s stack, so `compose_unary`
/// receives `[0, B, B′, B″, B‴]`, where the leading `0` fills the discarded
/// `Φ(z₀)` slot.
1685pub fn moving_limit_boundary_tower<const K: usize>(
1686 z_edge: &Tower4<K>,
1687 b_stack: [f64; 4],
1688) -> Tower4<K> {
1689 z_edge.compose_unary([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]])
1690}
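/*
Illustrative sketch, not part of this module: the first- and second-order
boundary terms named in the doc comment above (`B·z_u` and
`B′·z_u·z_v + B·z_uv`), written for a single primary θ and checked against an
integral with a closed form. `moving_limit_derivs` and `moving_limit_case` are
hypothetical names for this example only.

```rust
/// First and second theta-derivatives of I(theta) = integral of B(z) dz up to
/// z(theta), from B and B' at the edge and z', z'' of the edge position.
pub fn moving_limit_derivs(b: f64, b_z: f64, z_1: f64, z_2: f64) -> (f64, f64) {
    let first = b * z_1;
    // Self-flux term B'·z'^2 plus the edge-curvature term B·z''.
    let second = b_z * z_1 * z_1 + b * z_2;
    (first, second)
}

/// Worked case: B(z) = z^2 and z(theta) = theta^2, so I(theta) = theta^6 / 3.
pub fn moving_limit_case(theta: f64) -> (f64, f64) {
    let z = theta * theta;
    moving_limit_derivs(z * z, 2.0 * z, 2.0 * theta, 2.0)
}
```

For this case the closed forms are `I′ = 2θ⁵` and `I″ = 10θ⁴`. Dropping the
`B·z″` term would leave `I″ = 8θ⁴`, which is the kind of shortfall the doc
comment attributes to the hand-written flux.
*/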
1691
1692/// The boundary-flux derivative tower of a single moving cell integral
1693/// `∫_{z_L(θ)}^{z_R(θ)} B dz`: `Φ(z_R(θ)) − Φ(z_L(θ))`, assembled from the
1694/// two edge towers and the integrand stacks at each edge. The returned
1695/// tower's derivative channels are the EXACT moving-boundary contribution to
1696/// every θ-derivative of the cell integral, to fourth order, with no term
1697/// hand-omitted. A `Fixed` (non-moving) edge passes a `z_edge` whose
1698/// derivative channels are all zero, contributing nothing — matching the
1699/// production `edge_vel = 0` short-circuit.
1700pub fn cell_moving_boundary_flux_tower<const K: usize>(
1701 z_right: &Tower4<K>,
1702 b_stack_right: [f64; 4],
1703 z_left: &Tower4<K>,
1704 b_stack_left: [f64; 4],
1705) -> Tower4<K> {
1706 moving_limit_boundary_tower(z_right, b_stack_right)
1707 - moving_limit_boundary_tower(z_left, b_stack_left)
1708}
1709
1710/// Moving-limit boundary tower for a θ-DEPENDENT integrand `G(z; θ)`.
1711///
1712/// [`moving_limit_boundary_tower`] assumes the integrand depends on θ only
1713/// through the moving edge `z_edge(θ)` (a fixed z-derivative `b_stack`). The
1714/// marginal-slope flex boundary is richer: the integrand `G(z; θ)` ALSO carries
1715/// its own θ-dependence (the density weight `w = e^{−q}/2π` and the cell
1716/// integrand coefficients move with η, hence with the primaries), so the
1717/// Leibniz expansion of `∂ⁿ_θ ∫^{z_edge(θ)} G(z;θ) dz` mixes edge-motion
1718/// derivatives of the limit with θ-derivatives of `G` itself — e.g. at second
1719/// order `G·z_uv + G_z·z_u·z_v + G_{θu}·z_v + G_{θv}·z_u` (the four
1720/// edge-motion-carrying terms the hand path assembles one by one, including the
1721/// `G·z_uv` term the directional path drops).
1722///
1723/// Mechanization: let `Φ(z; θ)` be the z-antiderivative of `G` (so `Φ_z = G`).
1724/// The full upper-limit contribution is `Φ(z_edge(θ); θ)`, and the BOUNDARY
1725/// part — everything carrying edge motion — is exactly
1726/// `Φ(z_edge(θ); θ) − Φ(z₀; θ)`,
1727/// the second term being the pure-integrand-θ part (`∫^{z₀} ∂ⁿ_θ G`) the
1728/// interior moment integral already supplies. Both are one
1729/// [`substitute_intercept`] of the SAME mixed `(z, θ)` jet of `Φ` (z in slot 0,
/// θ in slots 1..=K): substituting the edge tower gives the full composite,
1731/// substituting a frozen constant edge isolates the pure-θ part, and their
1732/// difference is the exact boundary flux — every Leibniz term derived by the
1733/// substitution algebra, none hand-omitted.
1734///
1735/// `phi_jet` is the `(K+1)`-variable Taylor jet of `Φ` about `(z₀, θ₀)` with
1736/// `z₀ = z_edge.v`: slot 0 is the z-direction (so `phi_jet.g[0] = G(z₀;θ₀)`,
1737/// `phi_jet.h[0][0] = G_z`, …) and slots `1..=K` are the primaries θ (carrying
1738/// `Φ`'s own θ- and mixed z·θ-derivatives — i.e. the integrand's θ-derivatives
1739/// integrated in z, and `G_{θ…}` in the mixed slots). The returned tower's
1740/// VALUE channel is 0 by construction (the `Φ(z₀;θ₀)` constants cancel); only
1741/// the derivative channels are meaningful, matching the value-less convention of
1742/// [`moving_limit_boundary_tower`].
1743pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
1744 phi_jet: &Tower4<K1>,
1745 z_edge: &Tower4<K>,
1746) -> Tower4<K> {
1747 assert_eq!(
1748 K1,
1749 K + 1,
1750 "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
1751 );
1752 let frozen_edge = Tower4::<K>::constant(z_edge.v);
1753 let full = substitute_intercept(phi_jet, z_edge);
1754 let interior = substitute_intercept(phi_jet, &frozen_edge);
1755 full - interior
1756}
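
/* Illustrative aside, not part of the module: a minimal sketch of the
second-order edge-motion terms quoted in the doc comment on
`moving_limit_boundary_tower_theta_integrand`, specialised to a single primary
`t` so that `u = v`. The toy integrand `G(z; t) = t*z`, the edge `z(t) = t^2`,
and both helper names are assumptions chosen for the example; nothing here
calls `substitute_intercept`.

```rust
// Second-order boundary flux of I(t) = integral from z0 to z(t) of G(z; t) dz
// for one primary t. Every name in this block is hypothetical.
// Terms that carry edge motion: G*z'' + G_z*z'^2 + 2*G_t*z'.
fn boundary_flux_second_order(g: f64, g_z: f64, g_t: f64, z1: f64, z2: f64) -> f64 {
    g * z2 + g_z * z1 * z1 + 2.0 * g_t * z1
}

// Toy case: G(z; t) = t*z, so Phi(z; t) = t*z^2/2, with moving edge z(t) = t^2.
// Then Phi(z(t); t) - Phi(z0; t) = t^5/2 - t*z0^2/2, whose second t-derivative
// is 10*t^3 (the frozen-edge term is linear in t and drops out).
fn toy_flux_at(t0: f64) -> (f64, f64) {
    let z0 = t0 * t0; // edge position at t0
    let (z1, z2) = (2.0 * t0, 2.0); // z'(t0) and z''(t0)
    let g = t0 * z0; // G at the edge
    let g_z = t0; // dG/dz
    let g_t = z0; // dG/dt at the edge
    let from_terms = boundary_flux_second_order(g, g_z, g_t, z1, z2);
    let closed_form = 10.0 * t0 * t0 * t0;
    (from_terms, closed_form)
}
```

Calling `toy_flux_at(1.5)` returns the term-by-term sum and the closed form,
which should agree; dropping the `G*z''` term would leave the sum short by
`2*t0^3`.
*/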
1757
1758/// Two-edge cell version of [`moving_limit_boundary_tower_theta_integrand`]:
1759/// the exact boundary-flux tower of `∫_{z_L(θ)}^{z_R(θ)} G(z;θ) dz` with a
1760/// θ-dependent integrand, `Φ(z_R;θ) − Φ(z_L;θ)` minus the pure-θ parts at each
1761/// frozen edge. A `Fixed` edge passes a `z_edge` with zero derivative channels,
1762/// so its `full` and `interior` substitutions coincide and it contributes
1763/// nothing — matching the production `edge_vel = 0` short-circuit.
1764pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
1765 phi_jet_right: &Tower4<K1>,
1766 z_right: &Tower4<K>,
1767 phi_jet_left: &Tower4<K1>,
1768 z_left: &Tower4<K>,
1769) -> Tower4<K> {
1770 moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right)
1771 - moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left)
1772}
1773
1774// ── The program seam ─────────────────────────────────────────────────
1775
1776/// A family's row negative log-likelihood written ONCE over tower scalars.
1777///
1778/// This is the single source of truth #932 asks for: the value channel of
1779/// the returned tower must BE the production row NLL (same branches, same
1780/// guards, same numerics), and every derivative channel is then exact by
1781/// construction. The linear Jacobian wiring (coefficients ↔ primaries) is
1782/// NOT part of this trait — it is family data, not calculus, and stays on
1783/// the `RowKernel` implementor.
1784pub trait RowNllProgram<const K: usize>: Send + Sync {
1785 /// Number of observations the program covers.
1786 fn n_rows(&self) -> usize;
1787
1788 /// Current primary-scalar values for `row` (where to seed the tower).
1789 fn primaries(&self, row: usize) -> Result<[f64; K], String>;
1790
1791 /// The row NLL evaluated on tower scalars. `p[a]` arrives pre-seeded as
1792 /// variable `a` at the current primary value; implementations combine
1793 /// them with `Tower4` arithmetic and per-row data (response, censoring
1794 /// indicators, offsets) entering as constants.
1795 fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
1796}
1797
1798/// Evaluate a program's full tower at the current primaries for one row.
1799///
1800/// One call yields every `RowKernel` calculus channel; callers that need
1801/// several contractions of the same row should hold the returned tower and
1802/// contract repeatedly rather than re-evaluating.
1803pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
1804 prog: &P,
1805 row: usize,
1806) -> Result<Tower4<K>, String> {
1807 let p = prog.primaries(row)?;
1808 let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
1809 prog.row_nll(row, &vars)
1810}
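
/* Illustrative aside, not part of the module: a minimal sketch of the
"write the row NLL once, read off every channel" idea, shrunk to one variable
and second order. `Jet2`, `poisson_row_nll`, and `poisson_kernel` are
hypothetical names; the real `Tower4<K>` carries `K` variables and tensors
through fourth order, which this toy does not attempt.

```rust
// Hypothetical one-variable, order-2 jet: value, first and second derivative.
#[derive(Clone, Copy, Debug)]
struct Jet2 {
    v: f64,
    d1: f64,
    d2: f64,
}

impl Jet2 {
    // Seed the variable: unit first derivative, zero second derivative.
    fn variable(x: f64) -> Self {
        Jet2 { v: x, d1: 1.0, d2: 0.0 }
    }
    fn sub(self, o: Jet2) -> Jet2 {
        Jet2 { v: self.v - o.v, d1: self.d1 - o.d1, d2: self.d2 - o.d2 }
    }
    fn scale(self, s: f64) -> Jet2 {
        Jet2 { v: self.v * s, d1: self.d1 * s, d2: self.d2 * s }
    }
    // Chain rule: (e^u)' = e^u u' and (e^u)'' = e^u (u'^2 + u'').
    fn exp(self) -> Jet2 {
        let e = self.v.exp();
        Jet2 { v: e, d1: e * self.d1, d2: e * (self.d1 * self.d1 + self.d2) }
    }
}

// Poisson row NLL with log link, constants dropped: exp(eta) - y*eta.
// Written once; gradient and Hessian fall out of the jet arithmetic.
fn poisson_row_nll(eta: Jet2, y: f64) -> Jet2 {
    eta.exp().sub(eta.scale(y))
}

// The (value, gradient, Hessian) triple at eta0, read from one evaluation.
fn poisson_kernel(eta0: f64, y: f64) -> (f64, f64, f64) {
    let t = poisson_row_nll(Jet2::variable(eta0), y);
    (t.v, t.d1, t.d2)
}
```

For this loss the gradient is `exp(eta) - y` and the Hessian is `exp(eta)`, so
`poisson_kernel` can be checked against those closed forms directly.
*/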
1811
1812/// Mechanically derived `row_kernel` channel: `(nll, ∇, H)`.
1813pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
1814 prog: &P,
1815 row: usize,
1816) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1817 let t = evaluate_program(prog, row)?;
1818 Ok((t.v, t.g, t.h))
1819}
1820
1821/// Mechanically derived `row_third_contracted` channel.
1822pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1823 prog: &P,
1824 row: usize,
1825 dir: &[f64; K],
1826) -> Result<[[f64; K]; K], String> {
1827 Ok(evaluate_program(prog, row)?.third_contracted(dir))
1828}
1829
1830/// Mechanically derived `row_fourth_contracted` channel.
1831pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1832 prog: &P,
1833 row: usize,
1834 dir_u: &[f64; K],
1835 dir_v: &[f64; K],
1836) -> Result<[[f64; K]; K], String> {
1837 Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
1838}
1839
1840// ── The generic program seam (#932 scalar cutover) ───────────────────
1841
1842/// A family's row negative log-likelihood written ONCE over the generic
1843/// [`crate::jet_scalar::JetScalar`] interface, so the SAME expression can be
1844/// re-instantiated at whatever order / representation a consumer needs
1845/// ([`crate::jet_scalar::Order2`] for `(v, g, H)`,
1846/// [`crate::jet_scalar::OneSeed`] for the contracted third,
1847/// [`crate::jet_scalar::TwoSeed`] for the contracted fourth, or the full
1848/// [`Tower4`] for every channel at once).
1849///
1850/// This is additive to [`RowNllProgram`] (which is `Tower4`-specialised): a
1851/// program implementing this generic trait gets the small contracted scalars for
1852/// free, dissolving the dense-`Tower4<9>` cost objection in the location-scale
1853/// gates (doc §A.4). An existing `Tower4`-only [`RowNllProgram`] continues to
1854/// work unchanged; new families should prefer this generic trait.
1855///
/// Because a `Tower4`-specialised `row_nll` body uses only
/// `add`/`sub`/`mul`/`scale`/`exp`/`ln`/… (all of which
/// [`crate::jet_scalar::JetScalar`] also provides), the same body is expressible
/// directly over `S: JetScalar<K>`. A program written that way needs no
/// `Tower4`-specialised method and routes the directional and joint-Hessian
/// gates through the contracted scalars from a single definition.
1862pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
1863 /// Number of observations the program covers.
1864 fn n_rows(&self) -> usize;
1865
1866 /// Current primary-scalar values for `row` (where to seed the scalar).
1867 fn primaries(&self, row: usize) -> Result<[f64; K], String>;
1868
1869 /// The row NLL evaluated on a generic jet scalar. `p[a]` arrives pre-seeded
1870 /// (base value + per-scalar nilpotent directions) by the caller; the body
1871 /// uses ONLY [`crate::jet_scalar::JetScalar`] ops and per-row data
1872 /// (response, censoring, offsets) entering as constants.
1873 fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
1874 &self,
1875 row: usize,
1876 p: &[S; K],
1877 ) -> Result<S, String>;
1878}
1879
1880/// Evaluate a generic program at the value/gradient/Hessian scalar
1881/// [`crate::jet_scalar::Order2`], returning `(nll, ∇, H)` — the
1882/// `row_kernel` channel — WITHOUT materialising any third / fourth tensor.
1883///
1884/// This is the production seam for the inner-Newton `(v, g, H)` path: the row
1885/// loss is written ONCE in `row_nll_generic`, and this routes it through the
1886/// cheap order-2 scalar. The single source of truth means the gradient and
1887/// Hessian cannot desync from the value (the #736 / #948 bug genus).
1888pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1889 prog: &P,
1890 row: usize,
1891) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1892 let base = prog.primaries(row)?;
1893 let vars: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|a| {
1894 <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(base[a], a)
1895 });
1896 let s = prog.row_nll_generic(row, &vars)?;
1897 Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
1898}
1899
1900/// Evaluate a generic program at the one-seed scalar
1901/// [`crate::jet_scalar::OneSeed`], returning the contracted third
1902/// `Σ_c ℓ_{abc} dir_c` — the `row_third_contracted(dir)` channel — WITHOUT
1903/// materialising the dense `t3` tensor. The contraction direction is folded
1904/// INTO the differentiation by the nilpotent ε seeded with `dir`.
1905pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1906 prog: &P,
1907 row: usize,
1908 dir: &[f64; K],
1909) -> Result<[[f64; K]; K], String> {
1910 let base = prog.primaries(row)?;
1911 let vars: [crate::jet_scalar::OneSeed<K>; K] =
1912 std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
1913 let s = prog.row_nll_generic(row, &vars)?;
1914 Ok(s.contracted_third())
1915}
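
/* Illustrative aside, not part of the module: a minimal sketch of how a
nilpotent seed folds the contraction direction into the differentiation. It
uses a plain first-order dual number and a hand-written toy Hessian, which is
an assumption made to keep the example short; the real `OneSeed` scalar
derives the Hessian mechanically as well. `Dual` and `contracted_third_toy`
are hypothetical names.

```rust
// Hypothetical first-order dual number a + b*eps with eps^2 = 0.
#[derive(Clone, Copy, Debug)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Dual {
    // Base value x, with the contraction direction in the eps part.
    fn seeded(x: f64, dir: f64) -> Self {
        Dual { re: x, eps: dir }
    }
    fn scale(self, s: f64) -> Dual {
        Dual { re: self.re * s, eps: self.eps * s }
    }
}

// Toy loss l(p0, p1) = p0^2 * p1, with Hessian [[2 p1, 2 p0], [2 p0, 0]].
// Evaluating the Hessian on seeded duals p_a + eps*dir_a puts the contracted
// third derivative sum_c l_abc dir_c in the eps part of each entry.
fn contracted_third_toy(p: [f64; 2], dir: [f64; 2]) -> [[f64; 2]; 2] {
    let p0 = Dual::seeded(p[0], dir[0]);
    let p1 = Dual::seeded(p[1], dir[1]);
    let h = [
        [p1.scale(2.0), p0.scale(2.0)],
        [p0.scale(2.0), Dual { re: 0.0, eps: 0.0 }],
    ];
    [[h[0][0].eps, h[0][1].eps], [h[1][0].eps, h[1][1].eps]]
}
```

The only nonzero third partials of the toy loss are the permutations of
`l_001 = 2`, so the result is `[[2*dir[1], 2*dir[0]], [2*dir[0], 0]]` and no
dense third tensor is ever formed.
*/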
1916
1917/// Evaluate a generic program at the two-seed scalar
1918/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
1919/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `row_fourth_contracted(u, v)` channel —
1920/// WITHOUT materialising the dense `t4` tensor.
1921pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1922 prog: &P,
1923 row: usize,
1924 dir_u: &[f64; K],
1925 dir_v: &[f64; K],
1926) -> Result<[[f64; K]; K], String> {
1927 let base = prog.primaries(row)?;
1928 let vars: [crate::jet_scalar::TwoSeed<K>; K] =
1929 std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
1930 let s = prog.row_nll_generic(row, &vars)?;
1931 Ok(s.contracted_fourth())
1932}
1933
1934/// Evaluate a generic program at the full dense [`Tower4`] scalar, returning
1935/// every channel `(v, g, h, t3, t4)` in one pass. Used where the UNCONTRACTED
1936/// third / fourth tensors are needed (the BMS rigid `third_full` / `fourth_full`
1937/// caches): the dense tensors come from the SAME `row_nll_generic` expression
1938/// the order-2 / contracted scalars consume, so there is a single source of
1939/// truth across every channel.
1940pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1941 prog: &P,
1942 row: usize,
1943) -> Result<Tower4<K>, String> {
1944 let base = prog.primaries(row)?;
1945 let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(base[a], a));
1946 prog.row_nll_generic(row, &vars)
1947}
1948
1949// ── The RowJet bridge: one row-NLL body over scalar jets AND lane towers ─
1950//
1951// `JetScalar<K>` (jet_scalar.rs) abstracts the SCALAR jets — its `value()`
1952// returns one `f64`, so the `f64x4` lane towers ([`Tower3Lane`] / [`Tower4Lane`])
1953// CANNOT implement it (their value channel is four rows). `compose_unary_with`
1954// exists as an inherent method on BOTH the scalar towers and the lane towers, but
1955// as separate inherent methods, not a shared trait bound — so a row-NLL body
1956// written `<S: JetScalar<K>>` could not be instantiated at `Tower4Lane`, and the
1957// 4-rows-per-pass SIMD batch path could not reuse the single source.
1958//
1959// [`RowJet<K>`] is that shared bound. It exposes exactly the ops a row-NLL body
1960// needs — `constant` / `variable` / `add` / `sub` / `mul` / `scale` / `neg`, the
1961// value-derived `compose_unary_with`, and a per-lane domain `guard` — over BOTH
1962// representations. A blanket impl makes every scalar `JetScalar<K>` a `RowJet<K>`
1963// (so the scalar call sites compile unchanged and bit-identically), and explicit
1964// impls route the `f64x4` lane towers through their existing per-lane methods. A
1965// body written once over `R: RowJet<K>` then instantiates at a scalar jet for the
1966// `(v, g, H)` / contracted-tensor channels AND at a lane tower for the batch.
1967
1968/// The verdict of a per-lane [`RowJet::guard`] domain check.
1969///
1970/// A scalar jet (a [`crate::jet_scalar::JetScalar`] via the blanket impl) carries
1971/// ONE value, so it reports `lanes == 1` and a one-bit mask. A lane tower
1972/// ([`Tower3Lane`] / [`Tower4Lane`] over `f64x4`) carries FOUR rows, so it reports
1973/// `lanes == 4` and one mask bit per lane. The mask lets a batched program bail
1974/// exactly the offending 4-group to the scalar tail ([`any_failed`](Self::any_failed)),
1975/// or inspect which lanes tripped ([`lane_failed`](Self::lane_failed)).
1976#[derive(Clone, Copy, Debug, PartialEq, Eq)]
1977pub struct GuardVerdict {
1978 lanes: u8,
1979 failed_mask: u8,
1980}
1981
1982impl GuardVerdict {
1983 /// A scalar (1-lane) verdict: `pass == true` ⇒ no failure.
1984 #[inline]
1985 pub fn scalar(pass: bool) -> Self {
1986 Self { lanes: 1, failed_mask: if pass { 0 } else { 1 } }
1987 }
1988 /// A 4-lane verdict from a per-lane failure mask (bit `i` ⇒ lane `i` failed).
1989 #[inline]
1990 pub fn lanes4(failed_mask: u8) -> Self {
1991 Self { lanes: 4, failed_mask: failed_mask & 0x0f }
1992 }
1993 /// Number of active lanes inspected (1 scalar, 4 batch).
1994 #[inline]
1995 pub fn lanes(self) -> usize {
1996 self.lanes as usize
1997 }
1998 /// True iff every inspected lane satisfied the predicate.
1999 #[inline]
2000 pub fn all_pass(self) -> bool {
2001 self.failed_mask == 0
2002 }
2003 /// True iff at least one inspected lane failed the predicate.
2004 #[inline]
2005 pub fn any_failed(self) -> bool {
2006 self.failed_mask != 0
2007 }
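    /// True iff lane `i` failed the predicate. Returns `false` for any `i`
    /// outside the inspected lanes, so an out-of-range index cannot overflow
    /// the shift.
    #[inline]
    pub fn lane_failed(self, i: usize) -> bool {
        i < self.lanes as usize && (self.failed_mask >> i) & 1 == 1
    }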
2013 /// The raw failure mask (bit `i` ⇒ lane `i` failed).
2014 #[inline]
2015 pub fn failed_mask(self) -> u8 {
2016 self.failed_mask
2017 }
2018}
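
/* Illustrative aside, not part of the module: a minimal standalone sketch of
the per-lane failure mask that `GuardVerdict` wraps. Plain `[f64; 4]` stands in
for the `f64x4` value channel, and `failure_mask` / `first_failed_lane` are
hypothetical helper names.

```rust
// Build a 4-lane failure mask: bit i is set when lane i fails the predicate.
fn failure_mask(values: [f64; 4], pred: impl Fn(f64) -> bool) -> u8 {
    let mut mask = 0u8;
    for (i, &v) in values.iter().enumerate() {
        if !pred(v) {
            mask |= 1 << i;
        }
    }
    mask
}

// Index of the first failing lane, if any.
fn first_failed_lane(mask: u8) -> Option<usize> {
    (0..4).find(|&i| (mask >> i) & 1 == 1)
}
```

With a positivity predicate, the values `[1.0, -2.0, 3.0, 0.0]` fail in lanes 1
and 3, so the mask is `0b1010`; a batched caller would treat any nonzero mask
as the signal to send that 4-group to the scalar path.
*/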
2019
/// Copy-or-zero-pad a derivative stack from length `N` to length `M`. Used by the
/// [`RowJet::compose_unary_with`] impls to bridge a program's chosen stack length
/// to each tower's native compose width ([`Tower4Lane`]: 5, [`Tower3Lane`]: 4).
/// `M > N` zero-pads the missing high derivatives; `M < N` drops the unused
/// tail. Both cases are total, so the order-`(M−1)` tower reads exactly the
/// channels it needs and never an uninitialised entry. With `N == M` it is a
/// verbatim copy (the common `N == M == 5` case is bit-identical to passing the
/// stack straight through).
2027#[inline]
2028fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
2029 let mut out = [0.0_f64; M];
2030 let m = N.min(M);
2031 out[..m].copy_from_slice(&s[..m]);
2032 out
2033}
2034
2035/// The shared row-NLL algebra over BOTH the scalar jets and the `f64x4` lane
2036/// towers — the bound that lets ONE single-source row-NLL body SIMD-batch 4
2037/// rows/pass without a dual-source copy (module §"The RowJet bridge").
2038///
2039/// Every scalar [`crate::jet_scalar::JetScalar<K>`] is a `RowJet<K>` via the
2040/// blanket impl below (`Value = f64`), bit-identically to its `JetScalar`
2041/// methods; [`Tower3Lane`] / [`Tower4Lane`] over `f64x4` are `RowJet<K>` with
2042/// `Value = [f64; 4]`, routing through their per-lane methods so lane `i` of a
2043/// batched evaluation is `to_bits`-identical to the scalar evaluation on row `i`.
2044pub trait RowJet<const K: usize>: Copy {
2045 /// The value channel(s) seen by [`guard`](Self::guard) and
2046 /// [`values`](Self::values): a single `f64` on a scalar jet, `[f64; 4]` on an
2047 /// `f64x4` lane tower.
2048 type Value: Copy;
2049
2050 /// A constant (value `c`, all derivatives zero), broadcast to every lane.
2051 fn constant(c: f64) -> Self;
2052 /// The seeded primary `slot` at value `x` (unit first derivative in `slot`),
2053 /// broadcast to every lane. Per-lane-DISTINCT seeding for the batch path is
2054 /// done by the lane instantiators ([`generic_batched_fourth_tower`] /
2055 /// [`generic_batched_third_tower`]), which build the tower variables directly
2056 /// from each row's primaries; this method is for any row-invariant auxiliary
2057 /// variable a body introduces.
2058 fn variable(x: f64, slot: usize) -> Self;
2059 /// The value channel(s): `f64` (scalar) or `[f64; 4]` (lane).
2060 fn values(&self) -> Self::Value;
2061
2062 /// Truncated Leibniz `self + o`.
2063 fn add(&self, o: &Self) -> Self;
2064 /// Truncated Leibniz `self − o`.
2065 fn sub(&self, o: &Self) -> Self;
2066 /// Truncated Leibniz `self · o`.
2067 fn mul(&self, o: &Self) -> Self;
2068 /// Multiply every channel by the plain scalar `s`.
2069 fn scale(&self, s: f64) -> Self;
    /// Negate every channel. Defaults to `scale(-1.0)`; the blanket impl
    /// overrides it to delegate to [`crate::jet_scalar::JetScalar::neg`].
2072 fn neg(&self) -> Self {
2073 self.scale(-1.0)
2074 }
2075
2076 /// Faà di Bruno compose with a unary special function whose `[f64; N]`
2077 /// derivative stack is built from the running base value PER LANE through
2078 /// `stack_fn`. This is the SHARED-TRAIT version of the `compose_unary_with`
2079 /// inherent method that already exists on both the scalar towers and the lane
2080 /// towers: on a scalar jet `stack_fn` is run once at the value; on an `f64x4`
2081 /// lane tower it is re-run per lane (the four rows carry four distinct base
2082 /// values), so lane `i` is `to_bits`-identical to the scalar result on row `i`.
2083 /// Making it a trait method is precisely what lets a body written once over
2084 /// `R: RowJet<K>` instantiate at the batch towers. `N` is widened/narrowed to
2085 /// the tower's native width by [`resize_stack`] (`N == 5` is a verbatim copy).
2086 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self;
2087
2088 /// Per-lane domain guard: evaluate `pred` on each active lane's value channel
2089 /// and report which lanes failed (see [`GuardVerdict`]). A scalar jet checks
2090 /// its one value; a lane tower checks all four. Lets a batched program detect
2091 /// an out-of-domain row in a 4-group and bail that group to the scalar tail.
2092 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict;
2093
2094 /// Per-lane scale: multiply every channel by the per-lane factor `s`
2095 /// ([`Self::Value`]). On a scalar jet `Self::Value = f64`, so this is exactly
2096 /// [`scale`](Self::scale) and the scalar call sites stay BIT-IDENTICAL when
2097 /// `.scale(x)` is rewritten to `.scale_rows(x)`; on an `f64x4` lane tower
2098 /// `Self::Value = [f64; 4]` and lane `i` is multiplied by `s[i]`. This is the
2099 /// primitive that lets a batched body carry CONTINUOUS per-row data — the
2100 /// survival `covariance_ones` / `z_sum` / observation-weight `wi` factors that
2101 /// enter the jet algebra as `.scale(per_row_value)` and that the single-`f64`
2102 /// [`scale`](Self::scale) would broadcast wrongly across the four rows. Build
2103 /// `s` from the lane→row map with [`pack_rows`](Self::pack_rows).
2104 fn scale_rows(&self, s: Self::Value) -> Self;
2105
2106 /// Gather a per-lane auxiliary datum from the lane→row map `rows`: `value_of(r)`
2107 /// is evaluated for each active lane's row and packed into [`Self::Value`] (a
2108 /// single `f64` on a scalar jet, `[f64; 4]` on an `f64x4` lane tower). This is
2109 /// how a body written once over [`RowJet`] feeds per-row CONTINUOUS data (the
2110 /// arguments to [`scale_rows`](Self::scale_rows)) into the batch path without
2111 /// knowing the concrete representation: the program holds the per-row data and
2112 /// the caller threads `rows` (length 1 scalar, length 4 batch) into
2113 /// [`RowNllProgramRowJet::row_nll`], so the body writes
2114 /// `x.scale_rows(R::pack_rows(rows, |r| self.cov(r)))`. A multiplicative weight
2115 /// buried in a `compose_unary_with` stack is pulled out the same way:
2116 /// `x.compose_unary_with(|u| stack(u, 1.0)).scale_rows(R::pack_rows(rows, |r| self.wi(r)))`.
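    /// (Binary per-row branches such as the event indicator `di` are kept
    /// lane-uniform by grouping and the [`guard`](Self::guard) bail, not packed.)
    /// The implementations in this module index `rows` directly, so a `rows`
    /// slice shorter than the lane count (1 scalar, 4 batch) panics.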
2119 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;
2120
2121 // ── value-derived transcendental conveniences ───────────────────────
2122 // Each routes through `compose_unary_with` with the SAME derivative stack the
2123 // corresponding `JetScalar` method uses, so on a scalar jet (blanket) the
2124 // result is bit-identical to the `JetScalar` method, and on a lane tower lane
2125 // `i` is bit-identical to the scalar result on row `i`.
2126
2127 /// `e^self`.
2128 fn exp(&self) -> Self {
2129 self.compose_unary_with(|u| {
2130 let e = u.exp();
2131 [e, e, e, e, e]
2132 })
2133 }
2134 /// `ln(self)`. Caller guarantees positivity.
2135 fn ln(&self) -> Self {
2136 self.compose_unary_with(|u| {
2137 let r = 1.0 / u;
2138 [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
2139 })
2140 }
2141 /// `√self`. Caller guarantees positivity.
2142 fn sqrt(&self) -> Self {
2143 self.compose_unary_with(|u| {
2144 let s = u.sqrt();
2145 [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
2146 })
2147 }
    /// `1/self`. Caller guarantees a nonzero argument.
2149 fn recip(&self) -> Self {
2150 self.compose_unary_with(|u| {
2151 let r = 1.0 / u;
2152 let r2 = r * r;
2153 [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
2154 })
2155 }
2156 /// `self^a` for real `a`. Caller guarantees a positive base.
2157 fn powf(&self, a: f64) -> Self {
2158 self.compose_unary_with(move |u| {
2159 [
2160 u.powf(a),
2161 a * u.powf(a - 1.0),
2162 a * (a - 1.0) * u.powf(a - 2.0),
2163 a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
2164 a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
2165 ]
2166 })
2167 }
2168 /// `ln Γ(self)`. Caller guarantees a positive argument.
2169 fn ln_gamma(&self) -> Self {
2170 self.compose_unary_with(ln_gamma_derivative_stack)
2171 }
2172 /// `ψ(self)` (digamma). Caller guarantees a positive argument.
2173 fn digamma(&self) -> Self {
2174 self.compose_unary_with(digamma_derivative_stack)
2175 }
2176}
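
/* Illustrative aside, not part of the module: a minimal univariate sketch of
what `compose_unary_with` does with a derivative stack, truncated at second
order. `compose2`, `ln_stack3`, and `ln_of_square` are hypothetical names; the
module's version is multivariate and runs through fourth order.

```rust
// Univariate Faà di Bruno through second order. `inner` is (u, u', u'') and
// `stack` is [f(u), f'(u), f''(u)] evaluated at the inner value.
fn compose2(inner: (f64, f64, f64), stack: [f64; 3]) -> (f64, f64, f64) {
    let (_, u1, u2) = inner;
    (stack[0], stack[1] * u1, stack[2] * u1 * u1 + stack[1] * u2)
}

// The first three entries of the ln stack: [ln u, 1/u, -1/u^2].
fn ln_stack3(u: f64) -> [f64; 3] {
    let r = 1.0 / u;
    [u.ln(), r, -r * r]
}

// ln(x^2) at x, composed from the inner jet of u(x) = x^2.
fn ln_of_square(x: f64) -> (f64, f64, f64) {
    compose2((x * x, 2.0 * x, 2.0), ln_stack3(x * x))
}
```

Since `ln(x^2) = 2 ln x` for positive `x`, the composed derivatives should
match `2/x` and `-2/x^2`, which gives a quick check on both the stack entries
and the composition rule.
*/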
2177
2178/// Blanket: every scalar [`crate::jet_scalar::JetScalar<K>`] is a [`RowJet<K>`]
2179/// with `Value = f64`. Each op delegates to the identical `JetScalar` method, so
2180/// the existing scalar call sites compile UNCHANGED and bit-identically — the
2181/// bridge adds the lane representation without churning the scalar path. (The
2182/// concrete lane impls below cannot overlap this: [`Tower3Lane`] / [`Tower4Lane`]
2183/// are local types that do not implement `JetScalar`, and the orphan rule forbids
2184/// any downstream impl, so the coherence checker proves the impls disjoint.)
2185impl<const K: usize, S: crate::jet_scalar::JetScalar<K>> RowJet<K> for S {
2186 type Value = f64;
2187 #[inline]
2188 fn constant(c: f64) -> Self {
2189 <S as crate::jet_scalar::JetScalar<K>>::constant(c)
2190 }
2191 #[inline]
2192 fn variable(x: f64, slot: usize) -> Self {
2193 <S as crate::jet_scalar::JetScalar<K>>::variable(x, slot)
2194 }
2195 #[inline]
2196 fn values(&self) -> f64 {
2197 crate::jet_scalar::JetScalar::value(self)
2198 }
2199 #[inline]
2200 fn add(&self, o: &Self) -> Self {
2201 crate::jet_scalar::JetScalar::add(self, o)
2202 }
2203 #[inline]
2204 fn sub(&self, o: &Self) -> Self {
2205 crate::jet_scalar::JetScalar::sub(self, o)
2206 }
2207 #[inline]
2208 fn mul(&self, o: &Self) -> Self {
2209 crate::jet_scalar::JetScalar::mul(self, o)
2210 }
2211 #[inline]
2212 fn scale(&self, s: f64) -> Self {
2213 crate::jet_scalar::JetScalar::scale(self, s)
2214 }
2215 #[inline]
2216 fn neg(&self) -> Self {
2217 crate::jet_scalar::JetScalar::neg(self)
2218 }
2219 #[inline]
2220 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2221 crate::jet_scalar::JetScalar::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2222 }
2223 #[inline]
2224 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2225 GuardVerdict::scalar(pred(crate::jet_scalar::JetScalar::value(self)))
2226 }
2227 #[inline]
2228 fn scale_rows(&self, s: f64) -> Self {
2229 // `Value == f64`, so per-lane scale is exactly `scale` — the rewrite
2230 // `.scale(x)` → `.scale_rows(x)` is bit-identical on the scalar path.
2231 crate::jet_scalar::JetScalar::scale(self, s)
2232 }
2233 #[inline]
2234 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
2235 value_of(rows[0])
2236 }
2237}
2238
2239/// The `f64x4` lane [`Tower4Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2240/// routing each op through its existing per-lane method. Lane `i` of a batched
2241/// evaluation is `to_bits`-identical to the scalar [`Tower4`] evaluation on row
2242/// `i` (the per-lane methods are term-for-term lifts of the scalar tower).
2243impl<const K: usize> RowJet<K> for Tower4Lane<wide::f64x4, K> {
2244 type Value = [f64; 4];
2245 #[inline]
2246 fn constant(c: f64) -> Self {
2247 Tower4Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2248 }
2249 #[inline]
2250 fn variable(x: f64, slot: usize) -> Self {
2251 Tower4Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2252 }
2253 #[inline]
2254 fn values(&self) -> [f64; 4] {
2255 self.v.to_array()
2256 }
2257 #[inline]
2258 fn add(&self, o: &Self) -> Self {
2259 Tower4Lane::add(self, o)
2260 }
2261 #[inline]
2262 fn sub(&self, o: &Self) -> Self {
2263 Tower4Lane::sub(self, o)
2264 }
2265 #[inline]
2266 fn mul(&self, o: &Self) -> Self {
2267 Tower4Lane::mul(self, o)
2268 }
2269 #[inline]
2270 fn scale(&self, s: f64) -> Self {
2271 Tower4Lane::scale(self, s)
2272 }
2273 #[inline]
2274 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2275 Tower4Lane::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2276 }
2277 #[inline]
2278 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2279 let vals = self.v.to_array();
2280 let mut mask = 0u8;
2281 for (i, &v) in vals.iter().enumerate() {
2282 if !pred(v) {
2283 mask |= 1 << i;
2284 }
2285 }
2286 GuardVerdict::lanes4(mask)
2287 }
2288 #[inline]
2289 fn scale_rows(&self, s: [f64; 4]) -> Self {
2290 // True per-lane scale: lane `i` of every channel is multiplied by `s[i]`,
2291 // so lane `i` matches the scalar `Tower4::scale(s[i])` on row `i`.
2292 let sl = wide::f64x4::new(s);
2293 let mut out = *self;
2294 out.v = self.v * sl;
2295 for i in 0..K {
2296 out.g[i] = self.g[i] * sl;
2297 for j in 0..K {
2298 out.h[i][j] = self.h[i][j] * sl;
2299 for k in 0..K {
2300 out.t3[i][j][k] = self.t3[i][j][k] * sl;
2301 for l in 0..K {
2302 out.t4[i][j][k][l] = self.t4[i][j][k][l] * sl;
2303 }
2304 }
2305 }
2306 }
2307 out
2308 }
2309 #[inline]
2310 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2311 [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2312 }
2313}
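
/* Illustrative aside, not part of the module: a minimal sketch of why per-row
continuous data needs a per-lane scale rather than a broadcast one. Plain
`[f64; 4]` arrays stand in for a single `f64x4` channel, and all three helper
names are hypothetical.

```rust
// Broadcast scale: every lane gets the same factor.
fn scale_broadcast(channel: [f64; 4], s: f64) -> [f64; 4] {
    channel.map(|c| c * s)
}

// Per-lane scale: lane i is multiplied by its own row's factor s[i].
fn scale_per_lane(channel: [f64; 4], s: [f64; 4]) -> [f64; 4] {
    std::array::from_fn(|i| channel[i] * s[i])
}

// Gather one factor per lane from the lane-to-row map.
// Panics if `rows` has fewer than four entries, like the indexing it mirrors.
fn pack4(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
    [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
}
```

A body would gather the weights for its four rows with `pack4` and apply them
with `scale_per_lane`; using `scale_broadcast` with any single row's weight
would apply that weight to the other three rows as well.
*/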
2314
2315/// The `f64x4` lane [`Tower3Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2316/// the order-≤3 sibling of the [`Tower4Lane`] impl. A body that uses `N == 5`
2317/// stacks drops the (unused) fourth-derivative entry here, matching the scalar
2318/// [`Tower3`] which also carries only up to the third tensor.
2319impl<const K: usize> RowJet<K> for Tower3Lane<wide::f64x4, K> {
2320 type Value = [f64; 4];
2321 #[inline]
2322 fn constant(c: f64) -> Self {
2323 Tower3Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2324 }
2325 #[inline]
2326 fn variable(x: f64, slot: usize) -> Self {
2327 Tower3Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2328 }
2329 #[inline]
2330 fn values(&self) -> [f64; 4] {
2331 self.v.to_array()
2332 }
2333 #[inline]
2334 fn add(&self, o: &Self) -> Self {
2335 Tower3Lane::add(self, o)
2336 }
2337 #[inline]
2338 fn sub(&self, o: &Self) -> Self {
2339 Tower3Lane::sub(self, o)
2340 }
2341 #[inline]
2342 fn mul(&self, o: &Self) -> Self {
2343 Tower3Lane::mul(self, o)
2344 }
2345 #[inline]
2346 fn scale(&self, s: f64) -> Self {
2347 Tower3Lane::scale(self, s)
2348 }
2349 #[inline]
2350 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2351 Tower3Lane::compose_unary_with(self, |u| resize_stack::<N, 4>(stack_fn(u)))
2352 }
2353 #[inline]
2354 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2355 let vals = self.v.to_array();
2356 let mut mask = 0u8;
2357 for (i, &v) in vals.iter().enumerate() {
2358 if !pred(v) {
2359 mask |= 1 << i;
2360 }
2361 }
2362 GuardVerdict::lanes4(mask)
2363 }
2364 #[inline]
2365 fn scale_rows(&self, s: [f64; 4]) -> Self {
2366 let sl = wide::f64x4::new(s);
2367 let mut out = *self;
2368 out.v = self.v * sl;
2369 for i in 0..K {
2370 out.g[i] = self.g[i] * sl;
2371 for j in 0..K {
2372 out.h[i][j] = self.h[i][j] * sl;
2373 for k in 0..K {
2374 out.t3[i][j][k] = self.t3[i][j][k] * sl;
2375 }
2376 }
2377 }
2378 out
2379 }
2380 #[inline]
2381 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2382 [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2383 }
2384}
2385
2386/// A family's row negative log-likelihood written ONCE over the [`RowJet`]
2387/// bridge, so the SAME body instantiates at the scalar jets (for the `(v, g, H)`
2388/// and contracted-tensor channels) AND at the `f64x4` lane towers (for the
2389/// 4-rows-per-pass SIMD batch). This is the lane-capable successor to
2390/// [`RowNllProgramGeneric`]: a body written here gets the scalar channels through
2391/// [`rowjet_row_kernel`] / [`rowjet_third_contracted`] / [`rowjet_fourth_contracted`]
2392/// and the batched channels through [`generic_batched_fourth_tower`] /
2393/// [`generic_batched_third_tower`], all from a single source.
2394pub trait RowNllProgramRowJet<const K: usize>: Send + Sync {
2395 /// Number of observations the program covers.
2396 fn n_rows(&self) -> usize;
2397
2398 /// Current primary-scalar values for `row` (where to seed each lane).
2399 fn primaries(&self, row: usize) -> Result<[f64; K], String>;
2400
2401 /// The row NLL evaluated on the [`RowJet`] bridge. `rows` is the lane→row map
2402 /// (length 1 for a scalar instantiation, length 4 for a batch); `p[a]` arrives
2403 /// pre-seeded by the caller (base value plus, for the directional scalars, the
2404 /// nilpotent contraction directions). The body uses ONLY [`RowJet`] ops and
2405 /// per-row data entering through `rows`/`self` as constants.
2406 fn row_nll<R: RowJet<K>>(&self, rows: &[usize], p: &[R; K]) -> Result<R, String>;
2407}
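
/* Illustrative aside, not part of the module: a minimal value-only sketch of
the single-source idea behind `RowNllProgramRowJet`. One generic body is
instantiated at a scalar (`f64`) and at a 4-lane stand-in (`[f64; 4]`).
`MiniJet` and `body` are hypothetical; the real bridge also carries derivative
channels, guards, and per-row data.

```rust
// Hypothetical stand-in for the bridge: one body, two representations.
trait MiniJet: Copy {
    fn constant(c: f64) -> Self;
    fn sub(self, o: Self) -> Self;
    fn mul(self, o: Self) -> Self;
    fn exp(self) -> Self;
}

impl MiniJet for f64 {
    fn constant(c: f64) -> Self { c }
    fn sub(self, o: Self) -> Self { self - o }
    fn mul(self, o: Self) -> Self { self * o }
    // Resolves to the inherent f64 method, which takes priority in a path.
    fn exp(self) -> Self { f64::exp(self) }
}

impl MiniJet for [f64; 4] {
    fn constant(c: f64) -> Self { [c; 4] }
    fn sub(self, o: Self) -> Self { std::array::from_fn(|i| self[i] - o[i]) }
    fn mul(self, o: Self) -> Self { std::array::from_fn(|i| self[i] * o[i]) }
    fn exp(self) -> Self { self.map(f64::exp) }
}

// Written once: exp(eta) - 2*eta, a Poisson-style row loss with y = 2.
fn body<R: MiniJet>(eta: R) -> R {
    eta.exp().sub(R::constant(2.0).mul(eta))
}
```

Because each lane runs the same floating-point operations in the same order as
the scalar instantiation, lane `i` of `body(xs)` can be compared to
`body(xs[i])` with `to_bits`.
*/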
2408
2409/// Evaluate a [`RowNllProgramRowJet`] at the value/gradient/Hessian scalar
2410/// [`crate::jet_scalar::Order2`] (the `(v, g, H)` inner-Newton channel) — the
2411/// `RowJet` twin of [`generic_row_kernel`].
2412pub fn rowjet_row_kernel<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2413 prog: &P,
2414 row: usize,
2415) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
2416 let base = prog.primaries(row)?;
2417 let vars: [crate::jet_scalar::Order2<K>; K] =
2418 std::array::from_fn(|a| <crate::jet_scalar::Order2<K> as RowJet<K>>::variable(base[a], a));
2419 let s = prog.row_nll(&[row], &vars)?;
2420 Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
2421}
2422
2423/// Evaluate a [`RowNllProgramRowJet`] at the one-seed scalar
2424/// [`crate::jet_scalar::OneSeed`], returning the contracted third
2425/// `Σ_c ℓ_{abc} dir_c` — the `RowJet` twin of [`generic_third_contracted`].
2426pub fn rowjet_third_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2427 prog: &P,
2428 row: usize,
2429 dir: &[f64; K],
2430) -> Result<[[f64; K]; K], String> {
2431 let base = prog.primaries(row)?;
2432 let vars: [crate::jet_scalar::OneSeed<K>; K] =
2433 std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
2434 let s = prog.row_nll(&[row], &vars)?;
2435 Ok(s.contracted_third())
2436}
2437
2438/// Evaluate a [`RowNllProgramRowJet`] at the two-seed scalar
2439/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
2440/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `RowJet` twin of [`generic_fourth_contracted`].
2441pub fn rowjet_fourth_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2442 prog: &P,
2443 row: usize,
2444 dir_u: &[f64; K],
2445 dir_v: &[f64; K],
2446) -> Result<[[f64; K]; K], String> {
2447 let base = prog.primaries(row)?;
2448 let vars: [crate::jet_scalar::TwoSeed<K>; K] =
2449 std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
2450 let s = prog.row_nll(&[row], &vars)?;
2451 Ok(s.contracted_fourth())
2452}
2453
2454/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower4Batch`],
2455/// computing the FULL `(v, g, H, t3, t4)` for FOUR rows in one SIMD pass — the
2456/// lane twin of [`generic_full_tower`]. Each of the four lanes is seeded with its
2457/// own row's primaries, so [`Tower4Batch::lane`]`(i)` is `to_bits`-identical to
2458/// the scalar [`generic_full_tower`] on `rows[i]`.
2459pub fn generic_batched_fourth_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2460 prog: &P,
2461 rows: [usize; 4],
2462) -> Result<Tower4Batch<K>, String> {
2463 let bases: [[f64; K]; 4] = [
2464 prog.primaries(rows[0])?,
2465 prog.primaries(rows[1])?,
2466 prog.primaries(rows[2])?,
2467 prog.primaries(rows[3])?,
2468 ];
2469 let vars: [Tower4Batch<K>; K] = std::array::from_fn(|a| {
2470 let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2471 Tower4Batch::variable(lane_vals, a)
2472 });
2473 prog.row_nll(&rows, &vars)
2474}
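/*
The lane seeding in `generic_batched_fourth_tower` is a transpose: four per-row
primary vectors become `K` four-wide values, one per primary. The block below is
a minimal standalone sketch of that step only. It is an illustration, not part
of this module: `seed_lanes` is a hypothetical name, and a plain `[f64; 4]`
stands in for `wide::f64x4`.

```rust
// Transpose row-major primaries `bases[row][a]` into per-primary lanes
// `lanes[a][row]`. `[f64; 4]` is an assumed stand-in for a 4-wide SIMD value.
fn seed_lanes<const K: usize>(bases: &[[f64; K]; 4]) -> [[f64; 4]; K] {
    std::array::from_fn(|a| [bases[0][a], bases[1][a], bases[2][a], bases[3][a]])
}
```

Slot `i` of entry `a` then holds row `i`'s value of primary `a`, which is the
per-lane value handed to `Tower4Batch::variable(lane_vals, a)` above.
*/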
2475
2476/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower3Batch`],
2477/// computing `(v, g, H, t3)` for FOUR rows in one SIMD pass — the order-≤3 lane
2478/// twin of [`generic_full_tower`]. [`Tower3Batch::lane`]`(i)` is
2479/// `to_bits`-identical to the order-≤3 scalar evaluation on `rows[i]`.
2480pub fn generic_batched_third_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2481 prog: &P,
2482 rows: [usize; 4],
2483) -> Result<Tower3Batch<K>, String> {
2484 let bases: [[f64; K]; 4] = [
2485 prog.primaries(rows[0])?,
2486 prog.primaries(rows[1])?,
2487 prog.primaries(rows[2])?,
2488 prog.primaries(rows[3])?,
2489 ];
2490 let vars: [Tower3Batch<K>; K] = std::array::from_fn(|a| {
2491 let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2492 Tower3Batch::variable(lane_vals, a)
2493 });
2494 prog.row_nll(&rows, &vars)
2495}
2496
2497// ── The oracle ───────────────────────────────────────────────────────
2498
2499/// One row's worth of hand-written kernel outputs, as claimed by a
2500/// `RowKernel` implementation, packaged for verification against the
2501/// tower truth. Plain data (no trait coupling) so any kernel — whatever
2502/// its visibility — can be audited from its own test module.
2503pub struct KernelChannels<const K: usize> {
    /// Claimed value (the `nll` component of the `(nll, ∇, H)` triple from `row_kernel`).
2505 pub value: f64,
2506 /// Claimed gradient.
2507 pub gradient: [f64; K],
2508 /// Claimed Hessian.
2509 pub hessian: [[f64; K]; K],
2510 /// Claimed `row_third_contracted(dir)` outputs as `(dir, claim)` pairs.
2511 pub third: Vec<([f64; K], [[f64; K]; K])>,
2512 /// Claimed `row_fourth_contracted(u, v)` outputs as `(u, v, claim)`.
2513 pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
2514}
2515
2516/// Channel-by-channel audit of a hand-written kernel against the
2517/// single-expression tower truth. Returns `Err` naming the first channel,
2518/// index, claimed and true values on disagreement — designed as the body
2519/// of the per-family CI oracle tests (#932 deployment step 2).
2520///
/// Tolerance is PER ENTRY, mixed absolute/relative: each comparison uses
/// `|claim − truth| ≤ atol + rel_tol · max(|claim|, |truth|)`. The absolute
/// floor `atol = rel_tol` lets exact-zero entries of structurally sparse
/// towers pass without demanding bit-equality. A small cross-block entry
/// dropped next to a huge one is still caught as long as its magnitude
/// exceeds (roughly) the absolute floor `atol`: it is NOT measured against
/// the largest entry of the whole channel, and there is no per-channel
/// magnitude floor. Genuine sign flips (#736) and dropped channels therefore
/// fail loudly.
2528///
2529/// Non-finite handling is strict: a NaN on either side always fails; an
2530/// infinity passes only when both sides are the SAME signed infinity.
2531pub fn verify_kernel_channels<const K: usize>(
2532 tower: &Tower4<K>,
2533 claims: &KernelChannels<K>,
2534 rel_tol: f64,
2535) -> Result<(), String> {
2536 // Absolute floor: reuse rel_tol so a single knob controls both the
2537 // relative band and the absolute floor for entries near zero.
2538 let atol = rel_tol;
2539 let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
        // Non-finite values must not slip through the algebraic comparison
        // below: every comparison with NaN is false, so `> band` would never
        // fire and a NaN would pass silently. Handle them explicitly: NaN on
        // either side always errs; an infinity passes only if both sides are
        // the identical signed infinity.
2544 if !claim.is_finite() || !truth.is_finite() {
2545 let agree = claim.is_infinite()
2546 && truth.is_infinite()
2547 && claim.is_sign_positive() == truth.is_sign_positive();
2548 if agree {
2549 return Ok(());
2550 }
2551 return Err(format!(
2552 "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
2553 ));
2554 }
2555 let band = atol + rel_tol * claim.abs().max(truth.abs());
2556 if (claim - truth).abs() > band {
2557 return Err(format!(
2558 "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
2559 ));
2560 }
2561 Ok(())
2562 };
2563
2564 check("value", claims.value, tower.v)?;
2565
2566 for a in 0..K {
2567 check(&format!("gradient[{a}]"), claims.gradient[a], tower.g[a])?;
2568 }
2569
2570 for a in 0..K {
2571 for b in 0..K {
2572 check(
2573 &format!("hessian[{a}][{b}]"),
2574 claims.hessian[a][b],
2575 tower.h[a][b],
2576 )?;
2577 }
2578 }
2579
2580 for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
2581 let truth = tower.third_contracted(dir);
2582 for a in 0..K {
2583 for b in 0..K {
2584 check(
2585 &format!("third[{t_idx}][{a}][{b}]"),
2586 claim[a][b],
2587 truth[a][b],
2588 )?;
2589 }
2590 }
2591 }
2592
2593 for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
2594 let truth = tower.fourth_contracted(u, w);
2595 for a in 0..K {
2596 for b in 0..K {
2597 check(
2598 &format!("fourth[{f_idx}][{a}][{b}]"),
2599 claim[a][b],
2600 truth[a][b],
2601 )?;
2602 }
2603 }
2604 }
2605
2606 Ok(())
2607}
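/*
The per-entry rule used by `verify_kernel_channels` can be tried in isolation.
The block below is a minimal standalone sketch, not part of this module:
`within_band` is a hypothetical helper name, and it assumes the same
conventions as the function above (`atol = rel_tol`, NaN always fails,
infinities must agree in sign).

```rust
// Returns true when `claim` agrees with `truth` under the mixed
// absolute/relative band `atol + rel_tol * max(|claim|, |truth|)`.
fn within_band(claim: f64, truth: f64, rel_tol: f64) -> bool {
    // Non-finite inputs are decided explicitly: NaN never agrees, and an
    // infinity agrees only with the identical signed infinity.
    if !claim.is_finite() || !truth.is_finite() {
        return claim.is_infinite()
            && truth.is_infinite()
            && claim.is_sign_positive() == truth.is_sign_positive();
    }
    let atol = rel_tol; // one knob for both the floor and the relative band
    let band = atol + rel_tol * claim.abs().max(truth.abs());
    (claim - truth).abs() <= band
}
```

With `rel_tol = 1e-9`, an exact zero against a true `1e-12` passes (inside the
absolute floor), while a dropped entry of true size `1e-6` fails, as does any
sign flip of an entry well above the floor.
*/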
2608
2609// ===========================================================================
2610// SIMD row-batched towers (#1151 follow-up): Tower3Lane / Tower4Lane
2611// ===========================================================================
2612//
2613// `Tower{3,4}Lane<L: Lane, K>` re-type every channel of `Tower{3,4}<K>` from a
2614// scalar `f64` to a SIMD lane field `L`. With `L = wide::f64x4` one instance
2615// carries FOUR rows at once, so a per-row kernel (BMS `row_nll`, survival
2616// `row_kernel`, `marginal_slope` `build_row_*_towers`) can evaluate 4 rows per
2617// vector pass instead of one per scalar pass.
2618//
2619// Every floating-point op is a DIRECT, term-for-term lift of the scalar
2620// `Tower{3,4}<K>` body — `a * b` -> `a.mul(b)`, `a + b` -> `a.add(b)`, a literal
2621// `c` -> `L::splat(c)` — in the SAME accumulation order. `wide::f64x4`
2622// add/sub/mul are lane-wise IEEE-754 ops with NO fused-multiply-add (Rust
2623// performs no fp-contraction), so lane `i` of any channel of a
2624// `Tower{3,4}Lane<wide::f64x4, K>` is `to_bits`-IDENTICAL to the scalar
2625// `Tower{3,4}<K>` channel computed on row `i` — exactly the structural
// bit-identity the existing [`crate::jet_scalar::Order2Lane`] relies on. Checked
// by the in-tree `batch_tests` (real `wide::f64x4`) and a standalone
// f64x4-model oracle for `K ∈ {2,3,4,9}`.
2629//
2630// Only the pure-arithmetic ops are lifted (the transcendental `exp`/`ln`/`sqrt`/
2631// `…` route through scalar libm, which has no `f64x4` form; consumers build the
2632// per-lane derivative stack scalar-side and feed it to `compose_unary([L; _])`,
2633// exactly as the scalar path already does).
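/*
The bit-identity argument above rests on one fact: if the lane code performs
the same IEEE-754 operations in the same order as the scalar code, each slot
reproduces the scalar result bit for bit. The block below is a minimal
standalone sketch of that idea, not part of this module. `Lane4`, `hess_scalar`
and `hess_lane` are hypothetical names, and `Lane4` (a plain `[f64; 4]`) is an
assumed stand-in for `wide::f64x4`, modelling only slot-wise add and mul.

```rust
// Assumed stand-in for a 4-wide SIMD value: slot-wise add and mul only.
#[derive(Clone, Copy)]
struct Lane4([f64; 4]);

impl Lane4 {
    fn splat(c: f64) -> Self {
        Lane4([c; 4])
    }
    fn add(self, o: Self) -> Self {
        Lane4(std::array::from_fn(|i| self.0[i] + o.0[i]))
    }
    fn mul(self, o: Self) -> Self {
        Lane4(std::array::from_fn(|i| self.0[i] * o.0[i]))
    }
}

// Scalar reference: diagonal Hessian entry of a product `a * b` for one row.
// `a` and `b` are `[value, gradient, hessian]` in a single variable.
fn hess_scalar(a: [f64; 3], b: [f64; 3]) -> f64 {
    let mut acc = 0.0;
    acc += a[0] * b[2]; // a.v * b.h
    acc += a[1] * b[1]; // a.g * b.g
    acc += a[1] * b[1]; // a.g * b.g (second cross term)
    acc += a[2] * b[0]; // a.h * b.v
    acc
}

// Lane lift: the same four terms, accumulated in the same order.
fn hess_lane(a: [Lane4; 3], b: [Lane4; 3]) -> Lane4 {
    let mut acc = Lane4::splat(0.0);
    acc = acc.add(a[0].mul(b[2]));
    acc = acc.add(a[1].mul(b[1]));
    acc = acc.add(a[1].mul(b[1]));
    acc = acc.add(a[2].mul(b[0]));
    acc
}
```

Comparing `hess_lane(a, b).0[i].to_bits()` with `hess_scalar` on slot `i` of
the inputs gives equal bit patterns for every slot, because neither version
reorders or fuses any operation.
*/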
2634
2635use crate::jet_scalar::Lane;
2636
2637/// Lane-batched [`Tower4`]: value / gradient / Hessian / 3rd / 4th tensors
2638/// carried in a SIMD field `L`. `Tower4Lane<f64x4, K>` lane `i` is
2639/// `to_bits`-identical to [`Tower4<K>`] on row `i`.
2640#[derive(Clone, Copy)]
2641pub struct Tower4Lane<L: Lane, const K: usize> {
2642 /// Value channel (one entry per lane/row).
2643 pub v: L,
2644 /// Gradient `∂/∂p_a`.
2645 pub g: [L; K],
2646 /// Hessian `∂²/∂p_a∂p_b`.
2647 pub h: [[L; K]; K],
2648 /// Third tensor `∂³`.
2649 pub t3: [[[L; K]; K]; K],
2650 /// Fourth tensor `∂⁴`.
2651 pub t4: [[[[L; K]; K]; K]; K],
2652}
2653
2654/// The 4-rows-per-pass batched [`Tower4`] (`wide::f64x4` lanes).
2655pub type Tower4Batch<const K: usize> = Tower4Lane<wide::f64x4, K>;
2656
2657impl<L: Lane, const K: usize> Tower4Lane<L, K> {
2658 /// All-zero tower (every channel `+0.0` in every lane).
2659 #[inline]
2660 pub fn zero() -> Self {
2661 let z = L::splat(0.0);
2662 Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K], t4: [[[[z; K]; K]; K]; K] }
2663 }
2664 /// Constant `c` (per lane): value channel only.
2665 #[inline]
2666 pub fn constant(c: L) -> Self {
2667 let mut o = Self::zero();
2668 o.v = c;
2669 o
2670 }
2671 /// Seeded variable `p_idx` at per-lane `value`: unit first derivative in
2672 /// slot `idx` (mirrors [`Tower4::variable`]).
2673 #[inline]
2674 pub fn variable(value: L, idx: usize) -> Self {
2675 let mut o = Self::constant(value);
2676 o.g[idx] = L::splat(1.0);
2677 o
2678 }
2679 /// Extract lane `i` as a scalar [`Tower4<K>`] (channel-for-channel).
2680 #[inline]
2681 pub fn lane(&self, i: usize) -> Tower4<K> {
2682 let mut out = Tower4::<K>::zero();
2683 out.v = self.v.lane(i);
2684 for a in 0..K {
2685 out.g[a] = self.g[a].lane(i);
2686 for b in 0..K {
2687 out.h[a][b] = self.h[a][b].lane(i);
2688 for c in 0..K {
2689 out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2690 for d in 0..K {
2691 out.t4[a][b][c][d] = self.t4[a][b][c][d].lane(i);
2692 }
2693 }
2694 }
2695 }
2696 out
2697 }
2698 /// Per-channel lane-wise `self + o` (mirrors `Tower4` `Add`).
2699 #[inline]
2700 pub fn add(&self, o: &Self) -> Self {
2701 let mut out = *self;
2702 out.v = self.v.add(o.v);
2703 for i in 0..K {
2704 out.g[i] = self.g[i].add(o.g[i]);
2705 for j in 0..K {
2706 out.h[i][j] = self.h[i][j].add(o.h[i][j]);
2707 for k in 0..K {
2708 out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
2709 for l in 0..K {
2710 out.t4[i][j][k][l] = self.t4[i][j][k][l].add(o.t4[i][j][k][l]);
2711 }
2712 }
2713 }
2714 }
2715 out
2716 }
2717 /// Per-channel lane-wise `self - o` (mirrors `Tower4` `Sub`).
2718 #[inline]
2719 pub fn sub(&self, o: &Self) -> Self {
2720 let mut out = *self;
2721 out.v = self.v.sub(o.v);
2722 for i in 0..K {
2723 out.g[i] = self.g[i].sub(o.g[i]);
2724 for j in 0..K {
2725 out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
2726 for k in 0..K {
2727 out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
2728 for l in 0..K {
2729 out.t4[i][j][k][l] = self.t4[i][j][k][l].sub(o.t4[i][j][k][l]);
2730 }
2731 }
2732 }
2733 }
2734 out
2735 }
2736 /// Multiply every channel by the plain scalar `s` (mirrors `Tower4::scale`).
2737 #[inline]
2738 pub fn scale(&self, s: f64) -> Self {
2739 let sl = L::splat(s);
2740 let mut out = *self;
2741 out.v = self.v.mul(sl);
2742 for i in 0..K {
2743 out.g[i] = self.g[i].mul(sl);
2744 for j in 0..K {
2745 out.h[i][j] = self.h[i][j].mul(sl);
2746 for k in 0..K {
2747 out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
2748 for l in 0..K {
2749 out.t4[i][j][k][l] = self.t4[i][j][k][l].mul(sl);
2750 }
2751 }
2752 }
2753 }
2754 out
2755 }
2756 /// Leibniz product `self · o`, term-for-term lift of [`Tower4::mul`].
2757 #[inline]
2758 pub fn mul(&self, o: &Self) -> Self {
2759 let a = self;
2760 let b = o;
2761 let mut out = Self::zero();
2762 out.v = a.v.mul(b.v);
2763 for i in 0..K {
2764 let mut acc = L::splat(0.0);
2765 acc = acc.add(a.v.mul(b.g[i]));
2766 acc = acc.add(a.g[i].mul(b.v));
2767 out.g[i] = acc;
2768 }
2769 // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
2770 for i in 0..K {
2771 for j in i..K {
2772 let mut acc = L::splat(0.0);
2773 acc = acc.add(a.v.mul(b.h[i][j]));
2774 acc = acc.add(a.g[i].mul(b.g[j]));
2775 acc = acc.add(a.g[j].mul(b.g[i]));
2776 acc = acc.add(a.h[i][j].mul(b.v));
2777 out.h[i][j] = acc;
2778 out.h[j][i] = acc;
2779 }
2780 }
2781 for i in 0..K {
2782 for j in 0..K {
2783 for k in 0..K {
2784 let mut acc = L::splat(0.0);
2785 acc = acc.add(a.v.mul(b.t3[i][j][k]));
2786 acc = acc.add(a.g[i].mul(b.h[j][k]));
2787 acc = acc.add(a.g[j].mul(b.h[i][k]));
2788 acc = acc.add(a.h[i][j].mul(b.g[k]));
2789 acc = acc.add(a.g[k].mul(b.h[i][j]));
2790 acc = acc.add(a.h[i][k].mul(b.g[j]));
2791 acc = acc.add(a.h[j][k].mul(b.g[i]));
2792 acc = acc.add(a.t3[i][j][k].mul(b.v));
2793 out.t3[i][j][k] = acc;
2794 }
2795 }
2796 }
2797 for i in 0..K {
2798 for j in 0..K {
2799 for k in 0..K {
2800 for l in 0..K {
2801 let mut acc = L::splat(0.0);
2802 acc = acc.add(a.v.mul(b.t4[i][j][k][l]));
2803 acc = acc.add(a.g[i].mul(b.t3[j][k][l]));
2804 acc = acc.add(a.g[j].mul(b.t3[i][k][l]));
2805 acc = acc.add(a.h[i][j].mul(b.h[k][l]));
2806 acc = acc.add(a.g[k].mul(b.t3[i][j][l]));
2807 acc = acc.add(a.h[i][k].mul(b.h[j][l]));
2808 acc = acc.add(a.h[j][k].mul(b.h[i][l]));
2809 acc = acc.add(a.t3[i][j][k].mul(b.g[l]));
2810 acc = acc.add(a.g[l].mul(b.t3[i][j][k]));
2811 acc = acc.add(a.h[i][l].mul(b.h[j][k]));
2812 acc = acc.add(a.h[j][l].mul(b.h[i][k]));
2813 acc = acc.add(a.t3[i][j][l].mul(b.g[k]));
2814 acc = acc.add(a.h[k][l].mul(b.h[i][j]));
2815 acc = acc.add(a.t3[i][k][l].mul(b.g[j]));
2816 acc = acc.add(a.t3[j][k][l].mul(b.g[i]));
2817 acc = acc.add(a.t4[i][j][k][l].mul(b.v));
2818 out.t4[i][j][k][l] = acc;
2819 }
2820 }
2821 }
2822 }
2823 out
2824 }
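/*
For a single variable (`K = 1`) the sixteen `t4` terms of the product above
collapse to the binomial pattern 1, 4, 6, 4, 1, and the eight `t3` terms to
1, 3, 3, 1. The block below is a minimal standalone sketch of that collapsed
Leibniz rule, useful as a sanity check of the term counts. It is not part of
this module: `Jet` and `leibniz_mul` are hypothetical names.

```rust
// Univariate jet `[f, f', f'', f''', f'''']` at one point.
type Jet = [f64; 5];

// Leibniz product through fourth order:
// (ab)^(n) = sum_k C(n, k) * a^(k) * b^(n-k).
fn leibniz_mul(a: &Jet, b: &Jet) -> Jet {
    // Pascal's triangle rows 0..=4, padded with zeros.
    const BINOM: [[f64; 5]; 5] = [
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 0.0, 0.0],
        [1.0, 2.0, 1.0, 0.0, 0.0],
        [1.0, 3.0, 3.0, 1.0, 0.0],
        [1.0, 4.0, 6.0, 4.0, 1.0],
    ];
    let mut out = [0.0; 5];
    for n in 0..5 {
        for k in 0..=n {
            out[n] += BINOM[n][k] * a[k] * b[n - k];
        }
    }
    out
}
```

As a check, take `a = x²` and `b = x³` at `x = 2`, i.e. jets `[4, 4, 2, 0, 0]`
and `[8, 12, 12, 6, 0]`. The product is `x⁵`, whose value and first four
derivatives at `x = 2` are `[32, 80, 160, 240, 240]`.
*/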
2825 /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
2826 /// [`Tower4::compose_unary`]. `d = [f, f′, f″, f‴, f⁗]` packed per lane
2827 /// (build via [`Lane::unary5`] from the scalar special-function stack).
2828 #[inline]
2829 pub fn compose_unary(&self, d: [L; 5]) -> Self {
2830 let mut out = Self::zero();
2831 out.v = d[0];
2832 for i in 0..K {
2833 let mut acc = L::splat(0.0);
2834 acc = acc.add(d[1].mul(self.g[i]));
2835 out.g[i] = acc;
2836 }
2837 for i in 0..K {
2838 for j in 0..K {
2839 let mut acc = L::splat(0.0);
2840 acc = acc.add(d[1].mul(self.h[i][j]));
2841 acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
2842 out.h[i][j] = acc;
2843 }
2844 }
2845 for i in 0..K {
2846 for j in 0..K {
2847 for k in 0..K {
2848 let mut acc = L::splat(0.0);
2849 acc = acc.add(d[1].mul(self.t3[i][j][k]));
2850 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
2851 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
2852 acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
2853 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
2854 out.t3[i][j][k] = acc;
2855 }
2856 }
2857 }
2858 for i in 0..K {
2859 for j in 0..K {
2860 for k in 0..K {
2861 for l in 0..K {
2862 let mut acc = L::splat(0.0);
2863 acc = acc.add(d[1].mul(self.t4[i][j][k][l]));
2864 acc = acc.add(d[2].mul(self.t3[i][j][k]).mul(self.g[l]));
2865 acc = acc.add(d[2].mul(self.t3[i][j][l]).mul(self.g[k]));
2866 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.h[k][l]));
2867 acc = acc.add(d[3].mul(self.h[i][j]).mul(self.g[k]).mul(self.g[l]));
2868 acc = acc.add(d[2].mul(self.t3[i][k][l]).mul(self.g[j]));
2869 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.h[j][l]));
2870 acc = acc.add(d[3].mul(self.h[i][k]).mul(self.g[j]).mul(self.g[l]));
2871 acc = acc.add(d[2].mul(self.h[i][l]).mul(self.h[j][k]));
2872 acc = acc.add(d[2].mul(self.g[i]).mul(self.t3[j][k][l]));
2873 acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][k]).mul(self.g[l]));
2874 acc = acc.add(d[3].mul(self.h[i][l]).mul(self.g[j]).mul(self.g[k]));
2875 acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][l]).mul(self.g[k]));
2876 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.h[k][l]));
2877 acc = acc.add(d[4].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]).mul(self.g[l]));
2878 out.t4[i][j][k][l] = acc;
2879 }
2880 }
2881 }
2882 }
2883 out
2884 }
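/*
For a single variable (`K = 1`) the fifteen `t4` terms of the composition above
collapse to the classical univariate Faà di Bruno coefficients 1, 4, 3, 6, 1.
The block below is a minimal standalone sketch of that collapsed formula, not
part of this module: `compose_unary_1d` is a hypothetical name. `u` is the
inner jet `[u, u', u'', u''', u'''']` and `d` is the outer stack
`[f, f', f'', f''', f'''']` evaluated at the inner value.

```rust
// Univariate Faà di Bruno through fourth order.
fn compose_unary_1d(u: [f64; 5], d: [f64; 5]) -> [f64; 5] {
    // The inner value itself is not needed: `d` was already evaluated at it.
    let [_, g, h, t3, t4] = u;
    [
        d[0],
        d[1] * g,
        d[1] * h + d[2] * g * g,
        d[1] * t3 + 3.0 * d[2] * h * g + d[3] * g * g * g,
        d[1] * t4
            + 4.0 * d[2] * t3 * g
            + 3.0 * d[2] * h * h
            + 6.0 * d[3] * h * g * g
            + d[4] * g * g * g * g,
    ]
}
```

As a check, take the inner function `u(x) = x² − 1` at `x = 1` (jet
`[0, 2, 2, 0, 0]`) and the outer function `exp` at `0` (stack all ones). The
composite `exp(x² − 1)` has value and derivatives `[1, 2, 6, 20, 76]` at
`x = 1`, which is what the formula returns.
*/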
2885 /// Compose with a unary special-function whose `[f64; 5]` derivative stack is
2886 /// built from the base value through `stack_fn`, evaluated PER LANE — the
2887 /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
2888 /// seam (the SIMD twin of [`Tower4::compose_unary_with`]).
2889 ///
2890 /// Each of the four lanes carries a DISTINCT base value, so the scalar
2891 /// `stack_fn` is run once per lane at that lane's own value (via
2892 /// [`Lane::unary_with`]) and the `[f64; 5]` results are packed into `[L; 5]`;
2893 /// the composition is then the existing per-lane [`Self::compose_unary`].
2894 /// Because `unary_with` runs the identical scalar closure per lane and
2895 /// `compose_unary` is a term-for-term lift of the scalar tower, lane `i` of
2896 /// the result is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`
2897 /// — which is exactly what lets a row program written against the scalar
2898 /// [`Tower4::compose_unary_with`] seam re-instantiate, unchanged, at `f64x4`.
2899 #[inline]
2900 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
2901 self.compose_unary(self.v.unary_with(stack_fn))
2902 }
2903
2904 /// Single-active-slot fast path, term-for-term lift of
2905 /// [`Tower4::compose_unary_single_slot`] (only the 5 diagonal channels).
2906 #[inline]
2907 pub fn compose_unary_single_slot(&self, d: [L; 5], slot: usize) -> Self {
2908 let mut out = Self::zero();
2909 let s = slot;
2910 let g = self.g[s];
2911 let h = self.h[s][s];
2912 let t3 = self.t3[s][s][s];
2913 let t4 = self.t4[s][s][s][s];
2914 out.v = d[0];
2915 out.g[s] = {
2916 let mut acc = L::splat(0.0);
2917 acc = acc.add(d[1].mul(g));
2918 acc
2919 };
2920 out.h[s][s] = {
2921 let mut acc = L::splat(0.0);
2922 acc = acc.add(d[1].mul(h));
2923 acc = acc.add(d[2].mul(g).mul(g));
2924 acc
2925 };
2926 out.t3[s][s][s] = {
2927 let mut acc = L::splat(0.0);
2928 acc = acc.add(d[1].mul(t3));
2929 acc = acc.add(d[2].mul(h).mul(g));
2930 acc = acc.add(d[2].mul(h).mul(g));
2931 acc = acc.add(d[2].mul(g).mul(h));
2932 acc = acc.add(d[3].mul(g).mul(g).mul(g));
2933 acc
2934 };
2935 out.t4[s][s][s][s] = {
2936 let mut acc = L::splat(0.0);
2937 acc = acc.add(d[1].mul(t4));
2938 acc = acc.add(d[2].mul(t3).mul(g));
2939 acc = acc.add(d[2].mul(t3).mul(g));
2940 acc = acc.add(d[2].mul(h).mul(h));
2941 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2942 acc = acc.add(d[2].mul(t3).mul(g));
2943 acc = acc.add(d[2].mul(h).mul(h));
2944 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2945 acc = acc.add(d[2].mul(h).mul(h));
2946 acc = acc.add(d[2].mul(g).mul(t3));
2947 acc = acc.add(d[3].mul(g).mul(h).mul(g));
2948 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2949 acc = acc.add(d[3].mul(g).mul(h).mul(g));
2950 acc = acc.add(d[3].mul(g).mul(g).mul(h));
2951 acc = acc.add(d[4].mul(g).mul(g).mul(g).mul(g));
2952 acc
2953 };
2954 out
2955 }
    /// Contract `t3` with a primary-space direction (lift of
    /// [`Tower4::third_contracted`]). The output is symmetric in `(a, b)`, so only
    /// the upper triangle is computed and then mirrored; for a `t3` that is
    /// symmetric in its first two indices this is bit-identical to the full
    /// nest, lane-for-lane.
2959 #[inline]
2960 pub fn third_contracted(&self, dir: &[L; K]) -> [[L; K]; K] {
2961 let mut out = [[L::splat(0.0); K]; K];
2962 for a in 0..K {
2963 for b in a..K {
2964 let mut acc = L::splat(0.0);
2965 for c in 0..K {
2966 acc = acc.add(self.t3[a][b][c].mul(dir[c]));
2967 }
2968 out[a][b] = acc;
2969 out[b][a] = acc;
2970 }
2971 }
2972 out
2973 }
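/*
The contraction above is easiest to read on plain `f64` data. The block below
is a minimal standalone sketch of the same upper-triangle-and-mirror loop, not
part of this module: `third_contracted_sketch` is a hypothetical name, and it
assumes `t3` is symmetric in its first two indices, as a third-derivative
tensor is.

```rust
// out[a][b] = sum_c t3[a][b][c] * dir[c], computed for b >= a and mirrored.
fn third_contracted_sketch<const K: usize>(
    t3: &[[[f64; K]; K]; K],
    dir: &[f64; K],
) -> [[f64; K]; K] {
    let mut out = [[0.0; K]; K];
    for a in 0..K {
        for b in a..K {
            let mut acc = 0.0;
            for c in 0..K {
                acc += t3[a][b][c] * dir[c];
            }
            out[a][b] = acc;
            out[b][a] = acc; // valid because the output is symmetric in (a, b)
        }
    }
    out
}
```

For `K = 2`, `t3[a][b][c] = (a+1)(b+1)(c+1)` and `dir = [1, 2]`, each output
entry is `5 (a+1)(b+1)`, giving `[[5, 10], [10, 20]]`.
*/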
    /// Contract `t4` with two primary-space directions (lift of
    /// [`Tower4::fourth_contracted`]). The output is symmetric in `(i, j)`, so only
    /// the upper triangle is computed and then mirrored; for a `t4` that is
    /// symmetric in its first two indices this is bit-identical to the full
    /// nest, lane-for-lane.
2977 #[inline]
2978 pub fn fourth_contracted(&self, u: &[L; K], w: &[L; K]) -> [[L; K]; K] {
2979 let mut out = [[L::splat(0.0); K]; K];
2980 for i in 0..K {
2981 for j in i..K {
2982 let mut acc = L::splat(0.0);
2983 for k in 0..K {
2984 for l in 0..K {
2985 acc = acc.add(self.t4[i][j][k][l].mul(u[k]).mul(w[l]));
2986 }
2987 }
2988 out[i][j] = acc;
2989 out[j][i] = acc;
2990 }
2991 }
2992 out
2993 }
2994}
2995
2996/// Lane-batched [`Tower3`] (order-≤3 sibling of [`Tower4Lane`]).
2997#[derive(Clone, Copy)]
2998pub struct Tower3Lane<L: Lane, const K: usize> {
2999 /// Value channel.
3000 pub v: L,
3001 /// Gradient.
3002 pub g: [L; K],
3003 /// Hessian.
3004 pub h: [[L; K]; K],
3005 /// Third tensor.
3006 pub t3: [[[L; K]; K]; K],
3007}
3008
3009/// The 4-rows-per-pass batched [`Tower3`] (`wide::f64x4` lanes).
3010pub type Tower3Batch<const K: usize> = Tower3Lane<wide::f64x4, K>;
3011
3012impl<L: Lane, const K: usize> Tower3Lane<L, K> {
3013 /// All-zero tower.
3014 #[inline]
3015 pub fn zero() -> Self {
3016 let z = L::splat(0.0);
3017 Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K] }
3018 }
3019 /// Constant `c` (per lane).
3020 #[inline]
3021 pub fn constant(c: L) -> Self {
3022 let mut o = Self::zero();
3023 o.v = c;
3024 o
3025 }
3026 /// Seeded variable `p_idx` at per-lane `value`.
3027 #[inline]
3028 pub fn variable(value: L, idx: usize) -> Self {
3029 let mut o = Self::constant(value);
3030 o.g[idx] = L::splat(1.0);
3031 o
3032 }
3033 /// Extract lane `i` as a scalar [`Tower3<K>`].
3034 #[inline]
3035 pub fn lane(&self, i: usize) -> Tower3<K> {
3036 let mut out = Tower3::<K>::zero();
3037 out.v = self.v.lane(i);
3038 for a in 0..K {
3039 out.g[a] = self.g[a].lane(i);
3040 for b in 0..K {
3041 out.h[a][b] = self.h[a][b].lane(i);
3042 for c in 0..K {
3043 out.t3[a][b][c] = self.t3[a][b][c].lane(i);
3044 }
3045 }
3046 }
3047 out
3048 }
3049 /// Per-channel lane-wise `self + o`.
3050 #[inline]
3051 pub fn add(&self, o: &Self) -> Self {
3052 let mut out = *self;
3053 out.v = self.v.add(o.v);
3054 for i in 0..K {
3055 out.g[i] = self.g[i].add(o.g[i]);
3056 for j in 0..K {
3057 out.h[i][j] = self.h[i][j].add(o.h[i][j]);
3058 for k in 0..K {
3059 out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
3060 }
3061 }
3062 }
3063 out
3064 }
3065 /// Per-channel lane-wise `self - o`.
3066 #[inline]
3067 pub fn sub(&self, o: &Self) -> Self {
3068 let mut out = *self;
3069 out.v = self.v.sub(o.v);
3070 for i in 0..K {
3071 out.g[i] = self.g[i].sub(o.g[i]);
3072 for j in 0..K {
3073 out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
3074 for k in 0..K {
3075 out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
3076 }
3077 }
3078 }
3079 out
3080 }
3081 /// Multiply every channel by the plain scalar `s` (mirrors `Tower3::scale`).
3082 #[inline]
3083 pub fn scale(&self, s: f64) -> Self {
3084 let sl = L::splat(s);
3085 let mut out = *self;
3086 out.v = self.v.mul(sl);
3087 for i in 0..K {
3088 out.g[i] = self.g[i].mul(sl);
3089 for j in 0..K {
3090 out.h[i][j] = self.h[i][j].mul(sl);
3091 for k in 0..K {
3092 out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
3093 }
3094 }
3095 }
3096 out
3097 }
3098 /// Leibniz product `self · o`, term-for-term lift of [`Tower3::mul`].
3099 #[inline]
3100 pub fn mul(&self, o: &Self) -> Self {
3101 let a = self;
3102 let b = o;
3103 let mut out = Self::zero();
3104 out.v = a.v.mul(b.v);
3105 for i in 0..K {
3106 let mut acc = L::splat(0.0);
3107 acc = acc.add(a.v.mul(b.g[i]));
3108 acc = acc.add(a.g[i].mul(b.v));
3109 out.g[i] = acc;
3110 }
3111 // Hessian is symmetric under i↔j; upper triangle + mirror (see Tower2::mul).
3112 for i in 0..K {
3113 for j in i..K {
3114 let mut acc = L::splat(0.0);
3115 acc = acc.add(a.v.mul(b.h[i][j]));
3116 acc = acc.add(a.g[i].mul(b.g[j]));
3117 acc = acc.add(a.g[j].mul(b.g[i]));
3118 acc = acc.add(a.h[i][j].mul(b.v));
3119 out.h[i][j] = acc;
3120 out.h[j][i] = acc;
3121 }
3122 }
3123 for i in 0..K {
3124 for j in 0..K {
3125 for k in 0..K {
3126 let mut acc = L::splat(0.0);
3127 acc = acc.add(a.v.mul(b.t3[i][j][k]));
3128 acc = acc.add(a.g[i].mul(b.h[j][k]));
3129 acc = acc.add(a.g[j].mul(b.h[i][k]));
3130 acc = acc.add(a.h[i][j].mul(b.g[k]));
3131 acc = acc.add(a.g[k].mul(b.h[i][j]));
3132 acc = acc.add(a.h[i][k].mul(b.g[j]));
3133 acc = acc.add(a.h[j][k].mul(b.g[i]));
3134 acc = acc.add(a.t3[i][j][k].mul(b.v));
3135 out.t3[i][j][k] = acc;
3136 }
3137 }
3138 }
3139 out
3140 }
3141 /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
3142 /// [`Tower3::compose_unary`]. `d = [f, f′, f″, f‴]` packed per lane.
3143 #[inline]
3144 pub fn compose_unary(&self, d: [L; 4]) -> Self {
3145 let mut out = Self::zero();
3146 out.v = d[0];
3147 for i in 0..K {
3148 let mut acc = L::splat(0.0);
3149 acc = acc.add(d[1].mul(self.g[i]));
3150 out.g[i] = acc;
3151 }
3152 for i in 0..K {
3153 for j in 0..K {
3154 let mut acc = L::splat(0.0);
3155 acc = acc.add(d[1].mul(self.h[i][j]));
3156 acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
3157 out.h[i][j] = acc;
3158 }
3159 }
3160 for i in 0..K {
3161 for j in 0..K {
3162 for k in 0..K {
3163 let mut acc = L::splat(0.0);
3164 acc = acc.add(d[1].mul(self.t3[i][j][k]));
3165 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
3166 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
3167 acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
3168 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
3169 out.t3[i][j][k] = acc;
3170 }
3171 }
3172 }
3173 out
3174 }
3175 /// Compose with a unary special-function whose `[f64; 4]` derivative stack is
3176 /// built from the base value through `stack_fn`, evaluated PER LANE — the
3177 /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
3178 /// seam (the SIMD twin of [`Tower3::compose_unary_with`], order-≤3 sibling of
3179 /// [`Tower4Lane::compose_unary_with`]). The scalar `stack_fn` is run once per
3180 /// lane at that lane's own base value (via [`Lane::unary_with`]) and packed
3181 /// into `[L; 4]` for the existing per-lane [`Self::compose_unary`], so lane
3182 /// `i` is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`.
3183 #[inline]
3184 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
3185 self.compose_unary(self.v.unary_with(stack_fn))
3186 }
3187
3188 /// Single-active-slot fast path, term-for-term lift of
3189 /// [`Tower3::compose_unary_single_slot`].
3190 #[inline]
3191 pub fn compose_unary_single_slot(&self, d: [L; 4], slot: usize) -> Self {
3192 let mut out = Self::zero();
3193 let s = slot;
3194 let g = self.g[s];
3195 let h = self.h[s][s];
3196 let t3 = self.t3[s][s][s];
3197 out.v = d[0];
3198 out.g[s] = {
3199 let mut acc = L::splat(0.0);
3200 acc = acc.add(d[1].mul(g));
3201 acc
3202 };
3203 out.h[s][s] = {
3204 let mut acc = L::splat(0.0);
3205 acc = acc.add(d[1].mul(h));
3206 acc = acc.add(d[2].mul(g).mul(g));
3207 acc
3208 };
3209 out.t3[s][s][s] = {
3210 let mut acc = L::splat(0.0);
3211 acc = acc.add(d[1].mul(t3));
3212 acc = acc.add(d[2].mul(h).mul(g));
3213 acc = acc.add(d[2].mul(h).mul(g));
3214 acc = acc.add(d[2].mul(g).mul(h));
3215 acc = acc.add(d[3].mul(g).mul(g).mul(g));
3216 acc
3217 };
3218 out
3219 }
3220}
3221
3222#[cfg(test)]
3223mod batch_tests {
3224 use super::*;
3225
3226 struct Rng(u64);
3227 impl Rng {
3228 fn f(&mut self) -> f64 {
3229 self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
3230 ((self.0 >> 11) as f64 / (1u64 << 53) as f64) * 4.0 - 2.0
3231 }
3232 }
3233
3234 // Fill every channel of a scalar Tower4<K> with random data.
3235 fn rand_t4<const K: usize>(r: &mut Rng) -> Tower4<K> {
3236 let mut t = Tower4::<K>::zero();
3237 t.v = r.f();
3238 for i in 0..K {
3239 t.g[i] = r.f();
3240 for j in 0..K {
3241 t.h[i][j] = r.f();
3242 for k in 0..K {
3243 t.t3[i][j][k] = r.f();
3244 for l in 0..K {
3245 t.t4[i][j][k][l] = r.f();
3246 }
3247 }
3248 }
3249 }
3250 t
3251 }
3252 fn rand_t3<const K: usize>(r: &mut Rng) -> Tower3<K> {
3253 let mut t = Tower3::<K>::zero();
3254 t.v = r.f();
3255 for i in 0..K {
3256 t.g[i] = r.f();
3257 for j in 0..K {
3258 t.h[i][j] = r.f();
3259 for k in 0..K {
3260 t.t3[i][j][k] = r.f();
3261 }
3262 }
3263 }
3264 t
3265 }
3266 fn pack4_t4<const K: usize>(rows: &[Tower4<K>; 4]) -> Tower4Batch<K> {
3267 let mut b = Tower4Batch::<K>::zero();
3268 let lane = |f: &dyn Fn(&Tower4<K>) -> f64| {
3269 wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3270 };
3271 b.v = lane(&|t| t.v);
3272 for i in 0..K {
3273 b.g[i] = lane(&|t| t.g[i]);
3274 for j in 0..K {
3275 b.h[i][j] = lane(&|t| t.h[i][j]);
3276 for k in 0..K {
3277 b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3278 for l in 0..K {
3279 b.t4[i][j][k][l] = lane(&|t| t.t4[i][j][k][l]);
3280 }
3281 }
3282 }
3283 }
3284 b
3285 }
3286 fn pack4_t3<const K: usize>(rows: &[Tower3<K>; 4]) -> Tower3Batch<K> {
3287 let mut b = Tower3Batch::<K>::zero();
3288 let lane = |f: &dyn Fn(&Tower3<K>) -> f64| {
3289 wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3290 };
3291 b.v = lane(&|t| t.v);
3292 for i in 0..K {
3293 b.g[i] = lane(&|t| t.g[i]);
3294 for j in 0..K {
3295 b.h[i][j] = lane(&|t| t.h[i][j]);
3296 for k in 0..K {
3297 b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3298 }
3299 }
3300 }
3301 b
3302 }
3303 fn assert_t4_eq<const K: usize>(b: &Tower4<K>, s: &Tower4<K>, ctx: &str) {
3304 assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3305 for i in 0..K {
3306 assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3307 for j in 0..K {
3308 assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3309 for k in 0..K {
3310 assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3311 for l in 0..K {
3312 assert_eq!(b.t4[i][j][k][l].to_bits(), s.t4[i][j][k][l].to_bits(), "t4 {ctx}");
3313 }
3314 }
3315 }
3316 }
3317 }
3318 fn assert_t3_eq<const K: usize>(b: &Tower3<K>, s: &Tower3<K>, ctx: &str) {
3319 assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3320 for i in 0..K {
3321 assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3322 for j in 0..K {
3323 assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3324 for k in 0..K {
3325 assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3326 }
3327 }
3328 }
3329 }
3330
3331 // Run a representative op chain on 4 scalar rows and on the f64x4 batch,
3332 // then assert every channel of every lane is to_bits-identical.
3333 fn run4<const K: usize>(seed: u64, batches: usize) -> usize {
3334 let mut r = Rng(seed);
3335 let mut rows_checked = 0;
3336 for _ in 0..batches {
3337 let a: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3338 let b: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3339 let d: [[f64; 5]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3340 let dir: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3341 let dir2: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3342 let s = r.f();
3343
3344 // scalar per-row reference
3345 let scal: [Tower4<K>; 4] = std::array::from_fn(|rw| {
3346 let prod = a[rw].mul(&b[rw]);
3347 let comp = prod.compose_unary(d[rw]);
3348 let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3349 summed.compose_unary_single_slot(d[rw], 0)
3350 });
3351 let third: [[[f64; K]; K]; 4] =
3352 std::array::from_fn(|rw| a[rw].third_contracted(&dir[rw]));
3353 let fourth: [[[f64; K]; K]; 4] =
3354 std::array::from_fn(|rw| a[rw].fourth_contracted(&dir[rw], &dir2[rw]));
3355
3356 // batched f64x4
3357 let ab = pack4_t4(&a);
3358 let bb = pack4_t4(&b);
3359 let db: [wide::f64x4; 5] = std::array::from_fn(|c| {
3360 wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
3361 });
3362 let dirb: [wide::f64x4; K] = std::array::from_fn(|c| {
3363 wide::f64x4::new([dir[0][c], dir[1][c], dir[2][c], dir[3][c]])
3364 });
3365 let dir2b: [wide::f64x4; K] = std::array::from_fn(|c| {
3366 wide::f64x4::new([dir2[0][c], dir2[1][c], dir2[2][c], dir2[3][c]])
3367 });
3368 let prodb = ab.mul(&bb);
3369 let compb = prodb.compose_unary(db);
3370 let summedb = compb.add(&ab).sub(&bb).scale(s);
3371 let finalb = summedb.compose_unary_single_slot(db, 0);
3372 let thirdb = ab.third_contracted(&dirb);
3373 let fourthb = ab.fourth_contracted(&dirb, &dir2b);
3374
3375 for rw in 0..4 {
3376 assert_t4_eq(&finalb.lane(rw), &scal[rw], "t4-chain");
3377 for i in 0..K {
3378 for j in 0..K {
3379 assert_eq!(thirdb[i][j].lane(rw).to_bits(), third[rw][i][j].to_bits(), "third");
3380 assert_eq!(fourthb[i][j].lane(rw).to_bits(), fourth[rw][i][j].to_bits(), "fourth");
3381 }
3382 }
3383 rows_checked += 1;
3384 }
3385 }
3386 rows_checked
3387 }
3388 fn run3<const K: usize>(seed: u64, batches: usize) -> usize {
3389 let mut r = Rng(seed);
3390 let mut rows_checked = 0;
3391 for _ in 0..batches {
3392 let a: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3393 let b: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3394 let d: [[f64; 4]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3395 let s = r.f();
3396 let scal: [Tower3<K>; 4] = std::array::from_fn(|rw| {
3397 let prod = a[rw].mul(&b[rw]);
3398 let comp = prod.compose_unary(d[rw]);
3399 let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3400 summed.compose_unary_single_slot(d[rw], 0)
3401 });
3402 let ab = pack4_t3(&a);
3403 let bb = pack4_t3(&b);
3404 let db: [wide::f64x4; 4] = std::array::from_fn(|c| {
3405 wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
3406 });
3407 let prodb = ab.mul(&bb);
3408 let compb = prodb.compose_unary(db);
3409 let summedb = compb.add(&ab).sub(&bb).scale(s);
3410 let finalb = summedb.compose_unary_single_slot(db, 0);
3411 for rw in 0..4 {
3412 assert_t3_eq(&finalb.lane(rw), &scal[rw], "t3-chain");
3413 rows_checked += 1;
3414 }
3415 }
3416 rows_checked
3417 }
3418
    // A `Tower4Batch<9>` carries a `9⁴ = 6561`-entry `t4` tensor in 4-wide
    // lanes (6561 × 32 B ≈ 205 KiB by value for `t4` alone); the op chain keeps
    // several live at once, which can exceed a test thread's default stack. Run
    // each width on a large-stack thread so K=9 is exercised without overflow.
3423 fn big_stack<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(f: F) -> R {
3424 std::thread::Builder::new()
3425 .stack_size(512 << 20)
3426 .spawn(f)
3427 .unwrap()
3428 .join()
3429 .unwrap()
3430 }
3431
3432 #[test]
3433 fn tower4_batch_lane_bit_identical() {
3434 let batches = 2000;
3435 let rows_checked = big_stack(move || run4::<2>(0x1111_2222_3333_4444, batches))
3436 + big_stack(move || run4::<3>(0x5555_6666_7777_8888, batches))
3437 + big_stack(move || run4::<4>(0x9999_aaaa_bbbb_cccc, batches))
3438 + big_stack(move || run4::<9>(0xdddd_eeee_ffff_0000, batches));
3439 // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3440 // worker threads against silently running zero comparisons.
3441 assert_eq!(rows_checked, 4 * batches * 4);
3442 }
3443
3444 #[test]
3445 fn tower3_batch_lane_bit_identical() {
3446 let batches = 2000;
3447 let rows_checked = big_stack(move || run3::<2>(0x0f0f_1e1e_2d2d_3c3c, batches))
3448 + big_stack(move || run3::<3>(0x4b4b_5a5a_6969_7878, batches))
3449 + big_stack(move || run3::<4>(0x8787_9696_a5a5_b4b4, batches))
3450 + big_stack(move || run3::<9>(0xc3c3_d2d2_e1e1_f0f0, batches));
3451 // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3452 // worker threads against silently running zero comparisons.
3453 assert_eq!(rows_checked, 4 * batches * 4);
3454 }
3455
    // ── compose_unary_with seam (generic-over-Lane compose) ─────────────────
    //
    // The seam lets a single-sourced row program build its special-function
    // STACK from the base value through a closure, so the SAME expression
    // instantiates at a scalar tower (one base) AND a batch tower (four distinct
    // per-lane bases). These oracles pin both arms to `to_bits` identity.
3462
    /// A base-value-dependent `[f64; 5]` derivative stack (finite for finite `u`),
    /// standing in for a family's hand-certified special-function stack.
    /// `seam_stack4` is its order-≤3 truncation.
3466 fn seam_stack5(u: f64) -> [f64; 5] {
3467 [u.sin(), u.cos(), (2.0 * u).sin(), (0.5 * u).cos(), u * u - 0.3]
3468 }
3469 fn seam_stack4(u: f64) -> [f64; 4] {
3470 let s = seam_stack5(u);
3471 [s[0], s[1], s[2], s[3]]
3472 }
3473
    /// Pick an edge-case base value by index `which`: the two signed zeros, a
    /// random draw, or a random draw shifted by 3.
3475 fn seam_edge_base(r: &mut Rng, which: usize) -> f64 {
3476 match which {
3477 0 => -0.0,
3478 1 => 0.0,
3479 2 => r.f(),
3480 _ => r.f() + 3.0,
3481 }
3482 }
3483
3484 /// (a) scalar arm: `Tower4::compose_unary_with(f)` is `to_bits`-identical to
3485 /// the explicit `compose_unary(f(value))` on every channel.
3486 fn scalar_seam_t4<const K: usize>(seed: u64, n: usize) -> usize {
3487 let mut r = Rng(seed);
3488 for _ in 0..n {
3489 let mut t = rand_t4::<K>(&mut r);
3490 t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3491 assert_t4_eq(
3492 &t.compose_unary_with(seam_stack5),
3493 &t.compose_unary(seam_stack5(t.v)),
3494 "scalar t4 seam",
3495 );
3496 }
3497 n
3498 }
3499 fn scalar_seam_t3<const K: usize>(seed: u64, n: usize) -> usize {
3500 let mut r = Rng(seed);
3501 for _ in 0..n {
3502 let mut t = rand_t3::<K>(&mut r);
3503 t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3504 assert_t3_eq(
3505 &t.compose_unary_with(seam_stack4),
3506 &t.compose_unary(seam_stack4(t.v)),
3507 "scalar t3 seam",
3508 );
3509 }
3510 n
3511 }
3512
3513 /// (b) lane arm: `Tower4Lane::compose_unary_with` lane `i` is
3514 /// `to_bits`-identical to the scalar `Tower4::compose_unary_with` on row `i`,
3515 /// with the four lanes carrying DISTINCT base values (signed zeros included),
3516 /// so a buggy impl reusing one lane's base would fail.
3517 fn lane_seam_t4<const K: usize>(seed: u64, batches: usize) -> usize {
3518 let mut r = Rng(seed);
3519 let mut verified = 0usize;
3520 for _ in 0..batches {
3521 let mut rows: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3522 for (rw, row) in rows.iter_mut().enumerate() {
3523 row.v = seam_edge_base(&mut r, rw);
3524 }
3525 let batch_out = pack4_t4(&rows).compose_unary_with(seam_stack5);
3526 for (rw, row) in rows.iter().enumerate() {
3527 assert_t4_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack5), "lane t4 seam");
3528 verified += 1;
3529 }
3530 }
3531 verified
3532 }
3533 fn lane_seam_t3<const K: usize>(seed: u64, batches: usize) -> usize {
3534 let mut r = Rng(seed);
3535 let mut verified = 0usize;
3536 for _ in 0..batches {
3537 let mut rows: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3538 for (rw, row) in rows.iter_mut().enumerate() {
3539 row.v = seam_edge_base(&mut r, rw);
3540 }
3541 let batch_out = pack4_t3(&rows).compose_unary_with(seam_stack4);
3542 for (rw, row) in rows.iter().enumerate() {
3543 assert_t3_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack4), "lane t3 seam");
3544 verified += 1;
3545 }
3546 }
3547 verified
3548 }
3549
3550 #[test]
3551 fn compose_unary_with_scalar_bit_identical() {
3552 let n = 1100;
3553 let total = scalar_seam_t4::<2>(0x2200_0001, n)
3554 + scalar_seam_t4::<3>(0x2200_0002, n)
3555 + scalar_seam_t4::<4>(0x2200_0003, n)
3556 + big_stack(move || scalar_seam_t4::<9>(0x2200_0004, n))
3557 + scalar_seam_t3::<2>(0x3300_0001, n)
3558 + scalar_seam_t3::<3>(0x3300_0002, n)
3559 + scalar_seam_t3::<4>(0x3300_0003, n)
3560 + big_stack(move || scalar_seam_t3::<9>(0x3300_0004, n));
3561 // 8 arms × 1100 = 8800 ≥ 4000 inputs.
3562 assert_eq!(total, 8 * n);
3563 }
3564
3565 #[test]
3566 fn compose_unary_with_lane_matches_scalar() {
3567 let b = 600;
3568 let total = lane_seam_t4::<2>(0x4400_0001, b)
3569 + lane_seam_t4::<3>(0x4400_0002, b)
3570 + lane_seam_t4::<4>(0x4400_0003, b)
3571 + big_stack(move || lane_seam_t4::<9>(0x4400_0004, b))
3572 + lane_seam_t3::<2>(0x5500_0001, b)
3573 + lane_seam_t3::<3>(0x5500_0002, b)
3574 + lane_seam_t3::<4>(0x5500_0003, b)
3575 + big_stack(move || lane_seam_t3::<9>(0x5500_0004, b));
3576 // 8 arms × 600 = 4800 batches ≥ 2000; each verifies 4 lanes (19200 checks).
3577 assert_eq!(total, 8 * b * 4);
3578 }
3579}
3580
3581#[cfg(test)]
3582mod tests {
3583 use super::*;
3584
3585 /// `Tower3<K>` must be bit-identical to `Tower4<K>` on every channel it
3586 /// carries (value, gradient, Hessian, third derivatives). The order-≤3
3587 /// Leibniz / Faà-di-Bruno terms read only order-≤3 inner channels, so
3588 /// dropping the fourth tensor cannot perturb them. Exercises products
3589 /// (Leibniz cross-terms), unary composition, scaling, and addition — the
3590 /// same operations the survival location-scale `nll_index_tower` composes —
3591 /// across all mixed partials, not just the diagonal entries that kernel reads.
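    ///
    /// The third-order composition rule makes the claim concrete (standard
    /// Faà di Bruno for `w = f(u)`, with subscripts denoting partials in the
    /// primaries; the exact evaluation order inside `compose_unary` is not
    /// assumed here):
    ///
    /// ```text
    /// w_abc = f‴·u_a·u_b·u_c
    ///       + f″·(u_ab·u_c + u_ac·u_b + u_bc·u_a)
    ///       + f′·u_abc
    /// ```
    ///
    /// Neither `f⁗` nor any fourth-order channel of `u` appears, so a tower
    /// that never stores `t4` has everything this line needs.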
3592 #[test]
3593 fn tower3_matches_tower4_through_third_order() {
3594 let s_a: [f64; 5] = [
3595 0.3_f64.sin(),
3596 0.3_f64.cos(),
3597 -0.3_f64.sin(),
3598 -0.3_f64.cos(),
3599 0.3_f64.sin(),
3600 ];
3601 let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
3602 let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];
3603
3604 let a4 = Tower4::<3>::variable(0.4, 0);
3605 let b4 = Tower4::<3>::variable(-0.7, 1);
3606 let c4 = Tower4::<3>::variable(0.9, 2);
3607 let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
3608 + a4.mul(&c4).scale(-0.7)
3609 + b4.compose_unary(s_b).scale(0.25);
3610
3611 let a3 = Tower3::<3>::variable(0.4, 0);
3612 let b3 = Tower3::<3>::variable(-0.7, 1);
3613 let c3 = Tower3::<3>::variable(0.9, 2);
3614 let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
3615 + a3.mul(&c3).scale(-0.7)
3616 + b3.compose_unary(s4(s_b)).scale(0.25);
3617
3618 assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
3619 for i in 0..3 {
3620 assert_eq!(
3621 prog3.g[i].to_bits(),
3622 prog4.g[i].to_bits(),
3623 "g[{i}] mismatch"
3624 );
3625 for j in 0..3 {
3626 assert_eq!(
3627 prog3.h[i][j].to_bits(),
3628 prog4.h[i][j].to_bits(),
3629 "h[{i}][{j}] mismatch"
3630 );
3631 for k in 0..3 {
3632 assert_eq!(
3633 prog3.t3[i][j][k].to_bits(),
3634 prog4.t3[i][j][k].to_bits(),
3635 "t3[{i}][{j}][{k}] mismatch"
3636 );
3637 }
3638 }
3639 }
3640 }
3641
3642 /// Binomial-logit row NLL, K=1: ℓ(η) = ln(1 + e^η) − y·η.
3643 /// The entire tower has textbook closed forms in μ = σ(η); this test
3644 /// pins the algebra (exp, ln, scalar mixes, Leibniz/Faà di Bruno) to
3645 /// analytic truth at near-machine precision.
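    ///
    /// A short derivation of the closed forms asserted below, as a reading aid
    /// (standard logistic identities, not taken from the tower code): with
    /// μ = σ(η) and w = μ(1 − μ), the chain rule uses dμ/dη = w and
    /// dw/dη = w·(1 − 2μ).
    ///
    /// ```text
    /// ℓ′ = e^η/(1 + e^η) − y      = μ − y
    /// ℓ″ = dμ/dη                   = w
    /// ℓ‴ = dw/dη                   = w·(1 − 2μ)
    /// ℓ⁗ = w·(1 − 2μ)² − 2w²       = w·(1 − 6μ + 6μ²)
    /// ```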
3646 struct LogitProgram {
3647 eta: Vec<f64>,
3648 y: Vec<f64>,
3649 }
3650
3651 impl RowNllProgram<1> for LogitProgram {
3652 fn n_rows(&self) -> usize {
3653 self.eta.len()
3654 }
3655 fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
3656 Ok([self.eta[row]])
3657 }
3658 fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
3659 let eta = p[0];
3660 Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
3661 }
3662 }
3663
3664 #[test]
3665 fn logit_tower_matches_closed_forms() {
3666 let prog = LogitProgram {
3667 eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
3668 y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
3669 };
3670 for row in 0..prog.n_rows() {
3671 let t = evaluate_program(&prog, row).expect("logit program");
3672 let eta = prog.eta[row];
3673 let y = prog.y[row];
3674 let mu = 1.0 / (1.0 + (-eta).exp());
3675 let w = mu * (1.0 - mu);
3676 let expect = [
3677 (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
3678 (t.g[0], mu - y, "grad"),
3679 (t.h[0][0], w, "hess"),
3680 (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
3681 (
3682 t.t4[0][0][0][0],
3683 w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
3684 "fourth",
3685 ),
3686 ];
3687 for (got, want, label) in expect {
3688 assert!(
3689 (got - want).abs() <= 1e-12 * want.abs().max(1.0),
3690 "row {row} {label}: got {got:+.15e} want {want:+.15e}"
3691 );
3692 }
3693 }
3694 }
3695
3696 fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
3697 let diff = (got - want).abs();
3698 assert!(
3699 diff <= rel_tol * want.abs().max(1.0),
3700 "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
3701 );
3702 }
3703
3704 #[test]
3705 fn gamma_special_function_stacks_match_reference_values() {
3706 const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
3707 let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
3708 let cases = [
3709 (
3710 "x=0.1",
3711 0.1,
3712 -10.423_754_940_411_076,
3713 101.433_299_150_792_75,
3714 ),
3715 (
3716 "x=0.5",
3717 0.5,
3718 -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
3719 pi_sq / 2.0,
3720 ),
3721 ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
3722 (
3723 "x=2.5",
3724 2.5,
3725 -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
3726 pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
3727 ),
3728 (
3729 "x=50",
3730 50.0,
3731 3.901_989_673_427_892,
3732 0.020_201_333_226_697_128,
3733 ),
3734 ];
3735
3736 for (label, x, digamma_ref, trigamma_ref) in cases {
3737 let ln_gamma_stack = ln_gamma_derivative_stack(x);
3738 let digamma_stack = digamma_derivative_stack(x);
3739 let trigamma_stack = trigamma_derivative_stack(x);
3740 assert_close(
3741 &format!("{label} ln_gamma_stack digamma"),
3742 ln_gamma_stack[1],
3743 digamma_ref,
3744 1e-13,
3745 );
3746 assert_close(
3747 &format!("{label} digamma value"),
3748 digamma_stack[0],
3749 digamma_ref,
3750 1e-13,
3751 );
3752 assert_close(
3753 &format!("{label} ln_gamma_stack trigamma"),
3754 ln_gamma_stack[2],
3755 trigamma_ref,
3756 1e-13,
3757 );
3758 assert_close(
3759 &format!("{label} digamma_stack trigamma"),
3760 digamma_stack[1],
3761 trigamma_ref,
3762 1e-13,
3763 );
3764 assert_close(
3765 &format!("{label} trigamma value"),
3766 trigamma_stack[0],
3767 trigamma_ref,
3768 1e-13,
3769 );
3770 }
3771 }
3772
3773 #[test]
3774 fn gamma_special_function_stacks_obey_recurrences() {
3775 for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
3776 let digamma_x = digamma_derivative_stack(x)[0];
3777 let digamma_next = digamma_derivative_stack(x + 1.0)[0];
3778 let trigamma_x = trigamma_derivative_stack(x)[0];
3779 let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
3780 assert_close(
3781 &format!("digamma recurrence x={x}"),
3782 digamma_next,
3783 digamma_x + 1.0 / x,
3784 1e-13,
3785 );
3786 assert_close(
3787 &format!("trigamma recurrence x={x}"),
3788 trigamma_next,
3789 trigamma_x - 1.0 / (x * x),
3790 1e-13,
3791 );
3792 }
3793 }
3794
3795 /// Gaussian location-scale row NLL, K=2 primaries (η, s = log σ):
3796 /// ℓ = s + ½ e^{−2s} (y − η)². Mixed cross blocks — the #736 fragility
3797 /// shape — all have one-line closed forms here.
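    ///
    /// For reference, the first- and second-order closed forms the test asserts
    /// follow by direct differentiation, writing w = e^{−2s} and r = y − η
    /// (so ∂r/∂η = −1 and ∂w/∂s = −2w):
    ///
    /// ```text
    /// ∂η ℓ  = −w·r        ∂s ℓ  = 1 − w·r²
    /// ∂ηη ℓ = w           ∂ηs ℓ = 2w·r        ∂ss ℓ = 2w·r²
    /// ```
    ///
    /// Each further ∂s multiplies the w-bearing part by −2, and each further ∂η
    /// differentiates the power of r (∂η rᵏ = −k·rᵏ⁻¹), which is how the third-
    /// and fourth-order entries in the test body arise.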
3798 struct LocScaleProgram {
3799 eta: Vec<f64>,
3800 s: Vec<f64>,
3801 y: Vec<f64>,
3802 }
3803
3804 impl RowNllProgram<2> for LocScaleProgram {
3805 fn n_rows(&self) -> usize {
3806 self.eta.len()
3807 }
3808 fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
3809 Ok([self.eta[row], self.s[row]])
3810 }
3811 fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
3812 let r = -(p[0] - self.y[row]);
3813 Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
3814 }
3815 }
3816
3817 #[test]
3818 fn locscale_tower_matches_closed_forms_including_cross_blocks() {
3819 let prog = LocScaleProgram {
3820 eta: vec![0.3, -1.1, 2.0],
3821 s: vec![-0.5, 0.2, 0.8],
3822 y: vec![1.0, -2.0, 2.5],
3823 };
3824 let tol = 1e-12;
3825 for row in 0..prog.n_rows() {
3826 let t = evaluate_program(&prog, row).expect("locscale program");
3827 let r = prog.y[row] - prog.eta[row];
3828 let w = (-2.0 * prog.s[row]).exp();
3829 // (η, s) = indices (0, 1).
3830 let truth_g = [-w * r, 1.0 - w * r * r];
3831 let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
3832 // Third tensor: distinct-entry closed forms.
3833 // ∂ηηη = 0, ∂ηηs = −2w, ∂ηss = −4wr, ∂sss = −4wr².
3834 let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
3835 match a + b + c {
3836 0 => 0.0,
3837 1 => -2.0 * w,
3838 2 => -4.0 * w * r,
3839 _ => -4.0 * w * r * r,
3840 }
3841 };
            // Fourth tensor: ∂ηηηη = 0 and ∂ηηηs = 0 (both follow from ∂ηηη ≡ 0),
            // ∂ηηss = 4w, ∂ηsss = 8wr, ∂ssss = 8wr².
3844 let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
3845 match a + b + c + d {
3846 0 | 1 => 0.0,
3847 2 => 4.0 * w,
3848 3 => 8.0 * w * r,
3849 _ => 8.0 * w * r * r,
3850 }
3851 };
3852 for a in 0..2 {
3853 assert!(
3854 (t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
3855 "row {row} grad[{a}]"
3856 );
3857 for b in 0..2 {
3858 assert!(
3859 (t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
3860 "row {row} hess[{a}][{b}]: got {} want {}",
3861 t.h[a][b],
3862 truth_h[a][b]
3863 );
3864 for c in 0..2 {
3865 assert!(
3866 (t.t3[a][b][c] - t3_truth(a, b, c)).abs()
3867 <= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3868 "row {row} t3[{a}][{b}][{c}]: got {} want {}",
3869 t.t3[a][b][c],
3870 t3_truth(a, b, c)
3871 );
3872 for d in 0..2 {
3873 assert!(
3874 (t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
3875 <= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3876 "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
3877 t.t4[a][b][c][d],
3878 t4_truth(a, b, c, d)
3879 );
3880 }
3881 }
3882 }
3883 }
            // The derived trait-surface helper `derived_third_contracted` agrees
            // with direct contraction of `t3` over its last index.
3885 let dir = [0.7, -1.3];
3886 let third = derived_third_contracted(&prog, row, &dir).expect("third");
3887 for a in 0..2 {
3888 for b in 0..2 {
3889 let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
3890 assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
3891 }
3892 }
3893 }
3894 }
3895
3896 /// FD cross-check on a deliberately gnarly composition (div, sqrt,
3897 /// powf, nested exp/ln) in K=3, where no closed form is consulted:
3898 /// every tower channel is checked against central finite differences
3899 /// of the channel one order below — value→grad, grad→hess, hess→t3,
3900 /// t3→t4 — so each order is independently anchored.
3901 ///
3902 /// The program carries a per-row primary fixture plus a per-row offset
3903 /// `tau[row]` that enters the loss as a constant, so `row` genuinely
3904 /// drives both the seed point and the evaluated expression.
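    ///
    /// A rough error budget for the central differences used below (standard
    /// estimates, stated as an order-of-magnitude sketch; `f` stands for
    /// whichever lower-order channel is being differenced):
    ///
    /// ```text
    /// fd(p)      = (f(p + h·e_c) − f(p − h·e_c)) / (2h)
    /// truncation ≈ (h²/6)·|∂³f/∂p_c³|  ≈ 1.7e−11·|∂³f|   at h = 1e−5
    /// round-off  ≈ ε·|f| / h            ≈ 2.2e−11·|f|     at ε ≈ 2.2e−16
    /// ```
    ///
    /// Both sit several orders below the `1e-6` relative tolerance, assuming the
    /// channels and their derivatives are of moderate size at the fixture points.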
3905 struct GnarlyProgram {
3906 primaries: Vec<[f64; 3]>,
3907 tau: Vec<f64>,
3908 }
3909
3910 impl GnarlyProgram {
3911 fn fixture() -> Self {
3912 Self {
3913 primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
3914 tau: vec![0.15, -0.35, 0.5],
3915 }
3916 }
3917 }
3918
3919 impl RowNllProgram<3> for GnarlyProgram {
3920 fn n_rows(&self) -> usize {
3921 self.primaries.len()
3922 }
3923 fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3924 self.primaries
3925 .get(row)
3926 .copied()
3927 .ok_or_else(|| format!("gnarly: row {row} out of range"))
3928 }
3929 fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3930 let tau = *self
3931 .tau
3932 .get(row)
3933 .ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
3934 let a = (p[0] * p[1]).exp();
3935 let b = (p[2] * p[2] + 1.0).sqrt();
3936 let c = (a + b + tau).ln();
3937 let d = (p[1] * 0.5 + 2.0).powf(1.7);
3938 Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
3939 }
3940 }
3941
3942 /// Evaluate the gnarly program's tower at an ARBITRARY seed point for
3943 /// `row` (used to drive central differences off the fixture grid),
3944 /// while keeping `row`'s per-row data (`tau`) in the loss.
3945 fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
3946 struct At<'a> {
3947 base: &'a GnarlyProgram,
3948 row: usize,
3949 p: [f64; 3],
3950 }
3951 impl RowNllProgram<3> for At<'_> {
3952 fn n_rows(&self) -> usize {
3953 1
3954 }
3955 fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3956 if row != 0 {
3957 return Err(format!("gnarly-at: row {row} out of range"));
3958 }
3959 Ok(self.p)
3960 }
3961 fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3962 if eval_row != 0 {
3963 return Err(format!("gnarly-at: eval row {eval_row} out of range"));
3964 }
3965 self.base.row_nll(self.row, vars)
3966 }
3967 }
3968 evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
3969 }
3970
3971 #[test]
3972 fn gnarly_tower_is_fd_consistent_order_by_order() {
3973 let prog = GnarlyProgram::fixture();
3974 for row in 0..prog.n_rows() {
3975 let base = prog.primaries(row).expect("primaries");
3976 let t = gnarly_tower_at(&prog, row, base);
3977 let h_step = 1e-5;
3978 let tol = 1e-6;
3979 for c in 0..3 {
3980 let mut up = base;
3981 let mut dn = base;
3982 up[c] += h_step;
3983 dn[c] -= h_step;
3984 let t_up = gnarly_tower_at(&prog, row, up);
3985 let t_dn = gnarly_tower_at(&prog, row, dn);
3986 // value → gradient.
3987 let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
3988 assert!(
3989 (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
3990 "grad[{c}]: analytic {} fd {}",
3991 t.g[c],
3992 fd_g
3993 );
3994 for a in 0..3 {
3995 // gradient → Hessian.
3996 let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
3997 assert!(
3998 (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
3999 "hess[{a}][{c}]: analytic {} fd {}",
4000 t.h[a][c],
4001 fd_h
4002 );
4003 for b in 0..3 {
4004 // Hessian → third.
4005 let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
4006 assert!(
4007 (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
4008 "t3[{a}][{b}][{c}]: analytic {} fd {}",
4009 t.t3[a][b][c],
4010 fd_t3
4011 );
4012 for d in 0..3 {
4013 // third → fourth.
4014 let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
4015 assert!(
4016 (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
4017 "t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
4018 t.t4[a][b][d][c],
4019 fd_t4
4020 );
4021 }
4022 }
4023 }
4024 }
4025 }
4026 }
4027
    /// `implicit_solve` returns the implicit function `a(θ)` of a constraint
    /// `F(a, θ) = 0` as a fourth-order tower; the assertions below check its
    /// first- and second-order channels. The constraint here is the smooth
    ///     F(a, θ) = a + θ₀·a² + θ₁·exp(a) − c,
    /// strictly increasing in `a` around the root at the test point, whose root
    /// `a(θ)` is re-solved by scalar Newton at perturbed θ as the independent
    /// finite-difference oracle. Mirrors the survival flex calibration solve
    /// (one implicit intercept over the primaries) without any survival
    /// machinery, so a failure localises to the combinator.
4036 #[test]
4037 fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
4038 const C: f64 = 1.7;
4039 // The scalar constraint as a plain f64 closure (the production root
4040 // finder analogue) and its tower form in (a, θ₀, θ₁).
4041 let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
4042 let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
4043 let solve = |th: [f64; 2]| -> f64 {
4044 let mut a = 0.0_f64;
4045 for _ in 0..100 {
4046 let r = f_scalar(a, th);
4047 if r.abs() < 1e-14 {
4048 break;
4049 }
4050 a -= r / f_da(a, th);
4051 }
4052 a
4053 };
4054 // Tower constraint over K1 = 3 vars: slot 0 = a, slots 1,2 = θ₀, θ₁.
4055 let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
4056 let a = Tower4::<3>::variable(a0, 0);
4057 let t0 = Tower4::<3>::variable(th[0], 1);
4058 let t1 = Tower4::<3>::variable(th[1], 2);
4059 a + t0 * a.mul(&a) + t1 * a.exp() - C
4060 };
4061
4062 let th0 = [0.35, 0.2];
4063 let a0 = solve(th0);
4064 let f = f_tower(a0, th0);
4065 // Residual at the solved point is ~0 (the combinator tolerates the
4066 // production Newton residual; here it is machine-zero).
4067 assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
4068 let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");
4069
        // FD oracle: central differences of the scalar re-solve. Each order
        // applies one more central difference to the order below, in the spirit
        // of the gnarly order-by-order ladder, but always on re-solved scalar
        // roots rather than on tower channels.
4073 let h = 1e-4;
4074 let tol = 1e-5;
4075 let re = |th: [f64; 2]| solve(th);
4076 for i in 0..2 {
4077 let mut up = th0;
4078 let mut dn = th0;
4079 up[i] += h;
4080 dn[i] -= h;
4081 let fd_g = (re(up) - re(dn)) / (2.0 * h);
4082 assert!(
4083 (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4084 "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
4085 a_tower.g[i],
4086 fd_g
4087 );
4088 // second order: FD of the analytic gradient component would re-use
4089 // the combinator; instead difference a SCALAR gradient computed by
4090 // a nested re-solve so the oracle stays production-independent.
4091 let grad_at = |th: [f64; 2], j: usize| -> f64 {
4092 let mut up = th;
4093 let mut dn = th;
4094 up[j] += h;
4095 dn[j] -= h;
4096 (re(up) - re(dn)) / (2.0 * h)
4097 };
4098 for j in 0..2 {
4099 let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4100 assert!(
4101 (a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4102 "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
4103 a_tower.h[i][j],
4104 fd_h
4105 );
4106 }
4107 }
4108 }
4109
    /// `implicit_solve` reduces to `a_θ = −F_θ / F_a` at first order, and its
    /// second-order tensor matches the textbook IFT formula
    /// `a_uv = −(F_uv + F_au a_v + F_av a_u + F_aa a_u a_v)/F_a`.
    /// This pins the recursion against the hand-coded first_full.rs formula it
    /// replaces, independent of any FD step.
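    ///
    /// Where the two formulas come from (a standard implicit-function
    /// derivation, shown as a reading aid): differentiate the identity
    /// `F(a(θ), θ) = 0` with respect to θ_u, then again with respect to θ_v,
    /// and solve each line for the highest derivative of `a`.
    ///
    /// ```text
    /// 0 = F_a·a_u + F_u
    /// 0 = F_aa·a_u·a_v + F_au·a_v + F_av·a_u + F_a·a_uv + F_uv
    /// ```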
4115 #[test]
4116 fn implicit_solve_matches_textbook_ift_recursion() {
4117 // A constraint with non-trivial F_a, F_aa, F_au, F_uv all present.
4118 let a0 = 0.4_f64;
4119 let th = [0.25_f64, -0.15_f64];
4120 let f = {
4121 let a = Tower4::<3>::variable(a0, 0);
4122 let t0 = Tower4::<3>::variable(th[0], 1);
4123 let t1 = Tower4::<3>::variable(th[1], 2);
4124 // F = a·(1 + θ₀) + θ₁·a² + θ₀·θ₁ − 0.4385. The constant is chosen so
4125 // F(a0, θ0) = 0 exactly at a0 = 0.4, θ = [0.25, −0.15]:
4126 // 0.4·1.25 + (−0.15)·0.16 + 0.25·(−0.15) = 0.4385.
4127 // implicit_solve requires a genuine root; at the root the level-set
4128 // and root-curve derivatives coincide, so the textbook-IFT
4129 // assertions below are unaffected.
4130 a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
4131 };
4132 let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
4133 let f_a = f.g[0];
4134 // First order: a_u = −F_u / F_a.
4135 for u in 0..2 {
4136 let want = -f.g[u + 1] / f_a;
4137 assert!(
4138 (a_t.g[u] - want).abs() < 1e-12,
4139 "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
4140 a_t.g[u],
4141 want
4142 );
4143 }
4144 // Second order textbook IFT (indices shifted by 1 for the a-slot).
4145 for u in 0..2 {
4146 for v in 0..2 {
4147 let f_uv = f.h[u + 1][v + 1];
4148 let f_au = f.h[0][u + 1];
4149 let f_av = f.h[0][v + 1];
4150 let f_aa = f.h[0][0];
4151 let want =
4152 -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
4153 assert!(
4154 (a_t.h[u][v] - want).abs() < 1e-12,
4155 "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
4156 a_t.h[u][v],
4157 want
4158 );
4159 }
4160 }
4161 }
4162
4163 /// The moving-boundary flux tower reproduces every θ-derivative of a
4164 /// moving-limit integral, INCLUDING the second-order `B·z_uv` term the
4165 /// hand-written flux dropped (#932). The edge `z_R(θ) = θ₀ + θ₁²` has a
4166 /// genuinely nonzero `∂²z_R/∂θ₁² = 2`, so a combinator that omitted
4167 /// `B·z_uv` would miss the [1][1] Hessian entry. Truth = central FD of the
4168 /// closed-form integral `∫₀^{z_R} e^{−z²/2} dz = √(π/2)·erf(z_R/√2)`.
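    ///
    /// Worked at the test point θ = (0.7, 0.5), as a sketch of the Leibniz-rule
    /// bookkeeping (the lower limit 0 is fixed and B does not depend on θ):
    ///
    /// ```text
    /// I(θ) = ∫₀^{z_R(θ)} B(z) dz
    /// I_u  = B(z_R)·z_u
    /// I_uv = B′(z_R)·z_u·z_v + B(z_R)·z_uv
    ///
    /// z_R = 0.7 + 0.5² = 0.95,   z_θ₁ = 2·0.5 = 1,   z_θ₁θ₁ = 2,   B′ = −z·B
    /// I_θ₁θ₁ = −0.95·B(0.95)·1² + B(0.95)·2 = 1.05·B(0.95)
    /// ```
    ///
    /// Dropping `B·z_uv` would instead give −0.95·B(0.95), which has the wrong sign.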
4169 #[test]
4170 fn moving_boundary_flux_carries_b_zuv_term() {
4171 use std::f64::consts::PI;
4172 let b = |z: f64| (-0.5 * z * z).exp(); // integrand B(z)
4173 // Antiderivative-based closed-form integral I(z_R) = ∫₀^{z_R} B dz.
4174 let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
4175 let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
4176 let th0 = [0.7_f64, 0.5_f64];
4177
4178 // Edge tower z_R(θ) over K=2 primaries: value + exact derivatives.
4179 let mut z_edge = Tower4::<2>::constant(z_r(th0));
4180 z_edge.g[0] = 1.0; // ∂z_R/∂θ₀ = 1
4181 z_edge.g[1] = 2.0 * th0[1]; // ∂z_R/∂θ₁ = 2θ₁
4182 z_edge.h[1][1] = 2.0; // ∂²z_R/∂θ₁² = 2 (the z_uv the old flux dropped)
4183
4184 // Integrand stack [B, B′, B″, B‴] at z₀: B′=−z·B, B″=(z²−1)·B,
4185 // B‴=(3z−z³)·B.
4186 let z0 = z_edge.v;
4187 let b0 = b(z0);
4188 let stack = [
4189 b0,
4190 -z0 * b0,
4191 (z0 * z0 - 1.0) * b0,
4192 (3.0 * z0 - z0 * z0 * z0) * b0,
4193 ];
4194 let flux = moving_limit_boundary_tower(&z_edge, stack);
4195
4196 // FD truth of the integral's derivatives.
4197 let h = 1e-4;
4198 let tol = 1e-6;
4199 for i in 0..2 {
4200 let mut up = th0;
4201 let mut dn = th0;
4202 up[i] += h;
4203 dn[i] -= h;
4204 let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
4205 assert!(
4206 (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4207 "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
4208 flux.g[i],
4209 fd_g
4210 );
4211 }
4212 // The decisive entry: ∂²I/∂θ₁² = B′·(z_θ₁)² + B·z_θ₁θ₁. With z_θ₁=2θ₁=1
4213 // and z_θ₁θ₁=2, the B·z_uv contribution is B(z₀)·2 — omitting it would
4214 // leave the [1][1] entry short by exactly 2·B(z₀).
4215 let grad1_at = |th: [f64; 2]| -> f64 {
4216 let mut up = th;
4217 let mut dn = th;
4218 up[1] += h;
4219 dn[1] -= h;
4220 (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
4221 };
4222 let mut up = th0;
4223 let mut dn = th0;
4224 up[1] += h;
4225 dn[1] -= h;
4226 let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
4227 assert!(
4228 (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
4229 "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
4230 flux.h[1][1],
4231 fd_h11
4232 );
4233 // Explicit witness that the B·z_uv term is present and material:
4234 // analytic h[1][1] minus the pure (z_u)² part must equal B·z_uv = 2·B₀.
4235 let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
4236 let b_zuv = flux.h[1][1] - pure_zu2;
4237 assert!(
4238 (b_zuv - b0 * 2.0).abs() < 1e-10,
4239 "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
4240 b_zuv,
4241 b0 * 2.0
4242 );
4243 }
4244
4245 /// `moving_limit_boundary_tower_theta_integrand` reproduces the marginal-slope
4246 /// flex boundary closure for a θ-DEPENDENT integrand `G(z;θ)` — the case the
4247 /// plain `moving_limit_boundary_tower` cannot express, and the case the
4248 /// survival directional/bidirectional paths hand-assemble term-by-term
4249 /// (`G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`, with the directional path
4250 /// dropping `G·z_uv`). Two independent oracles:
4251 /// (1) closed-form: the boundary flux of `∫ G dz` is exactly
4252 /// `Φ(z_edge(θ);θ) − Φ(z₀;θ)` (Φ = z-antiderivative of G), whose θ
4253 /// derivatives we take by central FD of the closed form — no jet code.
4254 /// (2) the explicit second-order hand closure, including the `G·z_uv` term,
4255 /// built from the integrand's own (z,θ) partials.
    /// G(z;θ) = exp(z·θ₀) is genuinely θ-dependent (G_θ₀ = z·e^{zθ₀} ≠ 0), and
    /// the edge z_edge(θ) = 0.6 + θ₀ + θ₁² has a real z_uv = ∂²z_edge/∂θ₁² = 2,
    /// so a combinator that dropped either the integrand-θ terms or `G·z_uv`
    /// would miss a Hessian entry.
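    ///
    /// Why the hand closure has exactly those four terms (a chain-rule sketch;
    /// here z₀ is the frozen base edge value, Φ_z = G, and the bracketed
    /// pure-θ differences vanish once evaluated at z_edge = z₀):
    ///
    /// ```text
    /// Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ)
    /// Bnd_u  = G·z_u + [Φ_θu(z_edge) − Φ_θu(z₀)]              = G·z_u
    /// Bnd_uv = G_z·z_u·z_v + G_θv·z_u + G_θu·z_v + G·z_uv
    ///          + [Φ_θuθv(z_edge) − Φ_θuθv(z₀)]
    ///        = G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u
    /// ```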
4260 #[test]
4261 fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
4262 // G(z;θ) = exp(z·θ₀); Φ(z;θ) = ∫₀^z G = (e^{zθ₀} − 1)/θ₀.
4263 let g = |z: f64, t0: f64| (z * t0).exp();
4264 let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
4265 let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
4266 let th0 = [0.4_f64, 0.5_f64];
4267 let z0 = z_r(th0);
4268
4269 // Edge tower z_edge(θ) over K=2 primaries.
4270 let mut z_edge = Tower4::<2>::constant(z0);
4271 z_edge.g[0] = 1.0; // ∂z/∂θ₀
4272 z_edge.g[1] = 2.0 * th0[1]; // ∂z/∂θ₁
4273 z_edge.h[1][1] = 2.0; // ∂²z/∂θ₁² (the z_uv the directional path drops)
4274
4275 // Φ's mixed (z, θ) jet over K1 = 3 vars: slot 0 = z, slots 1,2 = θ₀,θ₁.
4276 // Built ONCE in tower arithmetic so every (z^i θ^j) partial is exact.
4277 let z_var = Tower4::<3>::variable(z0, 0);
4278 let t0_var = Tower4::<3>::variable(th0[0], 1);
4279 // θ₁ does not enter G/Φ here (its Φ-derivatives are zero; the z_edge
4280 // chain supplies all θ₁ motion through slot 0), so the K1 frame's θ₁
4281 // slot is intentionally left unseeded.
4282 let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
4283 // Sanity: slot-0 first derivative of Φ IS G(z₀;θ₀).
4284 assert!(
4285 (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
4286 "Φ_z {:+.8e} != G {:+.8e}",
4287 phi_jet.g[0],
4288 g(z0, th0[0])
4289 );
4290
4291 let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
4292
4293 // Value channel is 0 by construction (boundary, not the integral itself).
4294 assert!(
4295 flux.v.abs() < 1e-12,
4296 "boundary value channel {:+.3e}",
4297 flux.v
4298 );
4299
4300 // Oracle (1): central FD of the closed-form boundary flux
4301 // Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ) (z₀ FROZEN at the base edge).
4302 let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
4303 let h = 1e-4;
4304 let tol = 1e-6;
4305 for i in 0..2 {
4306 let mut up = th0;
4307 let mut dn = th0;
4308 up[i] += h;
4309 dn[i] -= h;
4310 let fd_g = (bnd(up) - bnd(dn)) / (2.0 * h);
4311 assert!(
4312 (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4313 "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
4314 flux.g[i],
4315 fd_g
4316 );
4317 }
4318 let grad_at = |th: [f64; 2], j: usize| -> f64 {
4319 let mut up = th;
4320 let mut dn = th;
4321 up[j] += h;
4322 dn[j] -= h;
4323 (bnd(up) - bnd(dn)) / (2.0 * h)
4324 };
4325 for i in 0..2 {
4326 for j in 0..2 {
4327 let mut up = th0;
4328 let mut dn = th0;
4329 up[i] += h;
4330 dn[i] -= h;
4331 let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4332 assert!(
4333 (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4334 "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
4335 flux.h[i][j],
4336 fd_h
4337 );
4338 }
4339 }
4340
4341 // Oracle (2): the explicit second-order hand closure, term by term —
4342 // `G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`. Read G's partials at the
4343 // base point directly (no jet): G = e^{zθ₀}, G_z = θ₀·G, G_θ₀ = z·G,
4344 // G_θ₁ = 0.
4345 let gg = g(z0, th0[0]);
4346 let g_z = th0[0] * gg;
4347 let g_theta = [z0 * gg, 0.0]; // [G_θ₀, G_θ₁]
4348 for i in 0..2 {
4349 for j in 0..2 {
4350 let z_u = z_edge.g[i];
4351 let z_v = z_edge.g[j];
4352 let z_uv = z_edge.h[i][j];
4353 let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
4354 assert!(
4355 (flux.h[i][j] - hand).abs() < 1e-9,
4356 "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
4357 flux.h[i][j],
4358 hand
4359 );
4360 }
4361 }
4362
4363 // Decisive: the `G·z_uv` term the directional path DROPS is present and
4364 // material in the [1][1] entry (z_uv = 2 there).
4365 let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
4366 let g_zuv = flux.h[1][1] - pure_no_zuv;
4367 assert!(
4368 (g_zuv - gg * 2.0).abs() < 1e-9,
4369 "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
4370 g_zuv,
4371 gg * 2.0
4372 );
4373 }
4374
    /// The survival crossing-edge position tower `z_edge = (τ − a(θ)) / b`,
    /// with the slope `b` seeded directly as the g-primary (`b = g`, so
    /// `∂b/∂g = 1`) and the intercept tower `a(θ)` supplied as a stand-in,
    /// reproduces taylor-jet's exact hand-path boundary-velocity formulas:
    ///   z_u  = −(a_u + [u==g]·z) / b
    ///   z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
    /// This pins the bridge between `implicit_solve` and
    /// `cell_moving_boundary_flux_tower`: the boundary jet that the production
    /// flex path hand-codes (and dropped `z_uv` from) is exactly `∂²` of this
    /// tower. K=3 reduced frame: slot 0 = a-axis carrier (an arbitrary smooth
    /// a(θ) with nonzero a_u/a_uv), slot 1 = g (the slope itself), slot 2 unused.
4386 #[test]
4387 fn crossing_edge_tower_matches_handpath_velocity_formulas() {
4388 const TAU: f64 = 1.3; // the link-knot crossing threshold τ
4389 let g_idx = 1usize;
4390 let g0 = 0.85_f64; // the slope value b (the g-primary IS the slope)
4391 // Stand-in intercept tower a(θ): nonzero value, gradient, Hessian in the
4392 // two live axes so a_u and a_uv are both exercised. (In production this
4393 // comes from implicit_solve; here we plant known derivatives.)
4394 let mut a = Tower4::<3>::constant(0.45);
4395 a.g[0] = 0.7;
4396 a.g[1] = -0.3;
4397 a.h[0][0] = 0.25;
4398 a.h[0][1] = 0.11;
4399 a.h[1][0] = 0.11;
4400 a.h[1][1] = -0.08;
4401
4402 // In the survival flex frame the slope `b` IS the g-primary directly
4403 // (the directional code passes `g` as `b`, and ∂z/∂g uses ∂b/∂g = 1):
4404 // z_edge = (τ − a) / b with b seeded as the g-axis variable.
4405 let b = Tower4::<3>::variable(g0, g_idx);
4406 let z_edge = (Tower4::<3>::constant(TAU) - a) / b;
4407
4408 let bv = g0;
4409 let z0 = z_edge.v;
4410 assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);
4411
4412 // z_u = −(a_u + [u==g]·z) / b.
4413 for u in 0..2 {
4414 let direct = if u == g_idx { z0 } else { 0.0 };
4415 let want = -(a.g[u] + direct) / bv;
4416 assert!(
4417 (z_edge.g[u] - want).abs() < 1e-10,
4418 "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
4419 z_edge.g[u],
4420 want
4421 );
4422 }
4423 // z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b, using the tower's own
4424 // first-order z_v/z_u (already verified above).
4425 for u in 0..2 {
4426 for v in 0..2 {
4427 let cross = if u == g_idx { z_edge.g[v] } else { 0.0 }
4428 + if v == g_idx { z_edge.g[u] } else { 0.0 };
4429 let want = -(a.h[u][v] + cross) / bv;
4430 assert!(
4431 (z_edge.h[u][v] - want).abs() < 1e-10,
4432 "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
4433 z_edge.h[u][v],
4434 want
4435 );
4436 }
4437 }
4438 }
4439
4440 /// The crossing-edge tower in the CONSTRAINT frame (intercept `a` and
4441 /// slope `b` BOTH independent — slots 0 and 1) reproduces taylor-jet's
4442 /// FD-certified bare boundary-velocity constants exactly:
4443 /// z_a = ∂z/∂a = −1/b
4444 /// z_ab = ∂²z/∂a∂b = +1/b²
4445 /// z_aa = ∂²z/∂a² = 0
4446 /// z_bb = ∂²z/∂b² = +2(τ−a)/b³
4447 /// These are the `f_a`/`f_au`/`f_aa` constraint-jet boundary motions the
4448 /// production base path drops (and only adds in the dir twins, causing the
4449 /// #932 desync). Here `a` is independent (NOT yet substituted with a(θ)),
4450 /// so `z_aa = 0` and there is no `a_uv` chain — `implicit_solve` introduces
4451 /// that later. Pins the constant before the constraint-tower wiring.
4452 #[test]
4453 fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
4454 const TAU: f64 = 1.3;
4455 let a0 = 0.45_f64;
4456 let b0 = 0.85_f64;
4457 // Slot 0 = a, slot 1 = b, both seeded independent.
4458 let a = Tower4::<2>::variable(a0, 0);
4459 let b = Tower4::<2>::variable(b0, 1);
4460 let z = (Tower4::<2>::constant(TAU) - a) / b;
4461
4462 assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
4463 assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
4464 assert!(
4465 (z.h[0][1] - 1.0 / (b0 * b0)).abs() < 1e-12,
4466 "z_ab {:+.10e} vs +1/b² {:+.10e}",
4467 z.h[0][1],
4468 1.0 / (b0 * b0)
4469 );
4470 assert!(
4471 z.h[0][0].abs() < 1e-12,
4472 "z_aa must vanish, got {:+.10e}",
4473 z.h[0][0]
4474 );
4475 let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);
4476 assert!(
4477 (z.h[1][1] - want_zbb).abs() < 1e-12,
4478 "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
4479 z.h[1][1],
4480 want_zbb
4481 );
4482 }
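
    // A minimal sketch, not from the original suite: the same bare
    // boundary-velocity constants in plain `f64`, as an oracle that needs no
    // tower type. The name `edge_position_partials_sketch` and its return
    // layout are hypothetical, chosen only for illustration.
    #[allow(dead_code)]
    fn edge_position_partials_sketch(tau: f64, a: f64, b: f64) -> [f64; 5] {
        let z = (tau - a) / b;
        // Layout: [z, z_a, z_b, z_ab, z_bb]. z_aa is omitted because z is
        // affine in a, so that entry is identically zero.
        [z, -1.0 / b, -z / b, 1.0 / (b * b), 2.0 * z / (b * b)]
    }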
4483
4484 /// The oracle harness catches a planted #736-style sign flip in a
4485 /// cross block and reports the channel by name.
4486 #[test]
4487 fn oracle_catches_planted_cross_block_sign_flip() {
4488 let prog = LocScaleProgram {
4489 eta: vec![0.3],
4490 s: vec![-0.5],
4491 y: vec![1.0],
4492 };
4493 let t = evaluate_program(&prog, 0).expect("tower");
4494 let dir = [0.6, -0.2];
4495 let mut third = t.third_contracted(&dir);
4496 let honest = KernelChannels {
4497 value: t.v,
4498 gradient: t.g,
4499 hessian: t.h,
4500 third: vec![(dir, third)],
4501 fourth: vec![(dir, [1.0, 0.5], t.fourth_contracted(&dir, &[1.0, 0.5]))],
4502 };
4503 verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");
4504
4505 // Plant the #736 flip: negate one mixed cross entry.
4506 third[0][1] = -third[0][1];
4507 let flipped = KernelChannels {
4508 value: t.v,
4509 gradient: t.g,
4510 hessian: t.h,
4511 third: vec![(dir, third)],
4512 fourth: vec![],
4513 };
4514 let err = verify_kernel_channels(&t, &flipped, 1e-10)
4515 .expect_err("planted sign flip must be caught");
4516 assert!(
4517 err.contains("third[0][0][1]"),
4518 "oracle must name the flipped channel, got: {err}"
4519 );
4520 }
4521
4522 /// The third- and fourth-order tensors must be FULLY symmetric under
4523 /// index permutation (mixed partials commute). The tower stores them
4524 /// unsymmetrized, so equal-by-construction is a real invariant of the
4525 /// Leibniz/Faà di Bruno writes — a cheap typo tripwire. Asserted on a
4526 /// nontrivial K=3 tower with all of div/sqrt/powf/exp/ln exercised, so
4527 /// every composition path contributes. Lives in a test (not the hot
4528 /// per-op path) on purpose.
4529 #[test]
4530 fn t3_t4_are_fully_index_symmetric() {
4531 let prog = GnarlyProgram::fixture();
4532 // 3! = 6 permutations of three indices.
4533 let perms3: [[usize; 3]; 6] = [
4534 [0, 1, 2],
4535 [0, 2, 1],
4536 [1, 0, 2],
4537 [1, 2, 0],
4538 [2, 0, 1],
4539 [2, 1, 0],
4540 ];
4541 // 4! = 24 permutations of four indices.
4542 let perms4: [[usize; 4]; 24] = [
4543 [0, 1, 2, 3],
4544 [0, 1, 3, 2],
4545 [0, 2, 1, 3],
4546 [0, 2, 3, 1],
4547 [0, 3, 1, 2],
4548 [0, 3, 2, 1],
4549 [1, 0, 2, 3],
4550 [1, 0, 3, 2],
4551 [1, 2, 0, 3],
4552 [1, 2, 3, 0],
4553 [1, 3, 0, 2],
4554 [1, 3, 2, 0],
4555 [2, 0, 1, 3],
4556 [2, 0, 3, 1],
4557 [2, 1, 0, 3],
4558 [2, 1, 3, 0],
4559 [2, 3, 0, 1],
4560 [2, 3, 1, 0],
4561 [3, 0, 1, 2],
4562 [3, 0, 2, 1],
4563 [3, 1, 0, 2],
4564 [3, 1, 2, 0],
4565 [3, 2, 0, 1],
4566 [3, 2, 1, 0],
4567 ];
4568 for row in 0..prog.n_rows() {
4569 let t = evaluate_program(&prog, row).expect("gnarly tower");
4570 let scale_t3 =
4571 t.t3.iter()
4572 .flatten()
4573 .flatten()
4574 .fold(0.0_f64, |m, x| m.max(x.abs()))
4575 .max(1.0);
4576 let scale_t4 =
4577 t.t4.iter()
4578 .flatten()
4579 .flatten()
4580 .flatten()
4581 .fold(0.0_f64, |m, x| m.max(x.abs()))
4582 .max(1.0);
4583 for i in 0..3 {
4584 for j in 0..3 {
4585 for k in 0..3 {
4586 let base = t.t3[i][j][k];
4587 let idx = [i, j, k];
4588 for p in &perms3 {
4589 let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
4590 assert!(
4591 (base - permed).abs() <= 1e-12 * scale_t3,
4592 "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
4593 permuted {permed:+.15e} under {p:?}"
4594 );
4595 }
4596 for l in 0..3 {
4597 let base4 = t.t4[i][j][k][l];
4598 let idx4 = [i, j, k, l];
4599 for p in &perms4 {
4600 let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
4601 assert!(
4602 (base4 - permed).abs() <= 1e-12 * scale_t4,
4603 "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
4604 permuted {permed:+.15e} under {p:?}"
4605 );
4606 }
4607 }
4608 }
4609 }
4610 }
4611 }
4612 }
4613}
4614
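/// Scaled complementary error function `erfcx(x) = exp(x²)·erfc(x)`, intended
/// for `x >= 0`. Finite `x <= 0` is clamped to `erfcx(0) = 1`; `x >= 26`
/// switches to the five-term asymptotic series in `1/x²`, so the product form
/// is never evaluated where `erfc(x)` has underflowed to zero.
#[inline]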
4616fn erfcx_nonnegative(x: f64) -> f64 {
4617 if !x.is_finite() {
4618 return if x.is_sign_positive() {
4619 0.0
4620 } else {
4621 f64::INFINITY
4622 };
4623 }
4624 if x <= 0.0 {
4625 return 1.0;
4626 }
4627 if x < 26.0 {
4628 ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
4629 } else {
4630 let inv = 1.0 / x;
4631 let inv2 = inv * inv;
4632 let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
4633 + 6.5625 * inv2 * inv2 * inv2 * inv2;
4634 inv * poly / std::f64::consts::PI.sqrt()
4635 }
4636}
4637
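/// `log(1 − exp(−a))` for `a >= 0`. Uses `ln_1p` above `a = ln 2` and `exp_m1`
/// below it, whichever avoids cancellation; returns `−∞` at `a = 0`.
#[inline]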
4639fn log1mexp_positive(a: f64) -> f64 {
4640 assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
4641 if a > core::f64::consts::LN_2 {
4642 (-(-a).exp()).ln_1p()
4643 } else if a > 0.0 {
4644 (-(-a).exp_m1()).ln()
4645 } else {
4646 f64::NEG_INFINITY
4647 }
4648}
4649
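/// Returns `(log Φ(x), λ(x))`, with `λ = φ/Φ` the inverse Mills ratio. For
/// `x < 0` both are formed from `erfcx(−x/√2)` to avoid cancellation in the tail.
#[inline]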
4651fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
4652 if x == f64::INFINITY {
4653 return (0.0, 0.0);
4654 }
4655 if x == f64::NEG_INFINITY {
4656 return (f64::NEG_INFINITY, f64::INFINITY);
4657 }
4658 if x.is_nan() {
4659 return (f64::NAN, f64::NAN);
4660 }
4661 if x < 0.0 {
4662 let u = -x / std::f64::consts::SQRT_2;
4663 let ex = erfcx_nonnegative(u).max(1e-300);
4664 let log_cdf = -u * u + (0.5 * ex).ln();
4665 let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
4666 (log_cdf, lambda)
4667 } else {
4668 let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
4669 let lambda = crate::probability::normal_pdf(x) / cdf;
4670 (cdf.ln(), lambda)
4671 }
4672}
4673
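/// Stable derivative stack `[f, f′, f″, f‴, f⁗]` for `f(x) = log Φ(x)`,
/// written as polynomials in `x` and the inverse Mills ratio `λ = φ(x)/Φ(x)`
/// (`f′ = λ`, `λ′ = −λ(x + λ)`).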
4675#[inline]
4676pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
4677 let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
4678 let lambda2 = lambda * lambda;
4679 let lambda3 = lambda2 * lambda;
4680 let x2 = x * x;
4681 [
4682 log_cdf,
4683 lambda,
4684 -lambda * (x + lambda),
4685 lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
4686 -lambda
4687 * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
4688 ]
4689}
4690
/// Stable derivative stack `[f, f′, f″, f‴, f⁗]` for `f(x) = log(1 − exp(−x))`,
/// `x > 0`, written in `r = 1/(exp(x) − 1)` (`f′ = r`, `r′ = −r(1 + r)`).
4692#[inline]
4693pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
4694 let r = 1.0 / x.exp_m1();
4695 [
4696 log1mexp_positive(x),
4697 r,
4698 -r * (1.0 + r),
4699 r * (1.0 + r) * (1.0 + 2.0 * r),
4700 -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
4701 ]
4702}

// ── The RowJet bridge oracle (CI) ─────────────────────────────────────
4704#[cfg(test)]
4705mod rowjet_bridge_tests {
4706 use super::*;
4707 use crate::jet_scalar::{JetScalar, Order2};
4708
    /// A toy row-NLL written ONCE over the [`RowJet`] bridge: a product, a sum, a
    /// subtraction, a scale/neg, a constant, and two value-distinct
    /// `compose_unary_with` stacks (an exp stack and a smooth finite-everywhere
    /// stack). The body is generic over `R: RowJet<2>`, so the SAME source
    /// instantiates at the scalar jets and the `f64x4` lane towers. The domain
    /// `guard` is not part of this body; `guard_reports_per_lane_failures`
    /// exercises it separately.
4714 struct ToyProgram {
4715 primaries: Vec<[f64; 2]>,
4716 /// Per-row CONTINUOUS auxiliary data `[cov, z, wi]` — the survival
4717 /// `covariance_ones` / `z_sum` / observation-weight analogues that enter
4718 /// the jet algebra as `.scale_rows(per_row_value)`, distinct per lane.
4719 aux: Vec<[f64; 3]>,
4720 }
4721
4722 impl ToyProgram {
4723 /// The body uses `pack_rows` to gather the per-lane continuous data from
4724 /// the lane→row map and `scale_rows` to fold it in — so a 4-row batch
4725 /// carries four DISTINCT cov/z/wi, which the single-`f64` `scale` could not.
4726 fn body<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> R {
4727 let cov = R::pack_rows(rows, |r| self.aux[r][0]);
4728 let z = R::pack_rows(rows, |r| self.aux[r][1]);
4729 let wi = R::pack_rows(rows, |r| self.aux[r][2]);
4730
4731 let a = p[0].mul(&p[1]).scale_rows(cov);
4732 let b = a.add(&R::constant(0.5)).sub(&p[0].scale(0.25));
4733 let c = b
4734 .compose_unary_with(|u| {
4735 let e = u.exp();
4736 [e, e, e, e, e]
4737 })
4738 .scale_rows(z);
4739 let d = c.neg().add(&p[0]);
4740 let e = d
4741 .compose_unary_with(|u| {
4742 let s = (1.0 + u * u).sqrt();
4743 let s3 = s * s * s;
4744 let s5 = s3 * s * s;
4745 let s7 = s5 * s * s;
4746 [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
4747 })
4748 .scale_rows(wi);
4749 e.mul(&p[1]).add(&e)
4750 }
4751 }
4752
4753 impl RowNllProgramRowJet<2> for ToyProgram {
4754 fn n_rows(&self) -> usize {
4755 self.primaries.len()
4756 }
4757 fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
4758 Ok(self.primaries[row])
4759 }
4760 fn row_nll<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> Result<R, String> {
4761 assert!(rows.len() == 1 || rows.len() == 4, "lane→row map is 1 or 4 wide");
4762 Ok(self.body(rows, p))
4763 }
4764 }
4765
4766 fn assert_t4_bits_eq(a: &Tower4<2>, b: &Tower4<2>, ctx: &str) {
4767 assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4768 for i in 0..2 {
4769 assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4770 for j in 0..2 {
4771 assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4772 for k in 0..2 {
4773 assert_eq!(
4774 a.t3[i][j][k].to_bits(),
4775 b.t3[i][j][k].to_bits(),
4776 "{ctx}: t3[{i}][{j}][{k}]"
4777 );
4778 for l in 0..2 {
4779 assert_eq!(
4780 a.t4[i][j][k][l].to_bits(),
4781 b.t4[i][j][k][l].to_bits(),
4782 "{ctx}: t4[{i}][{j}][{k}][{l}]"
4783 );
4784 }
4785 }
4786 }
4787 }
4788 }
4789
4790 fn assert_t3_bits_eq(a: &Tower3<2>, b: &Tower3<2>, ctx: &str) {
4791 assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4792 for i in 0..2 {
4793 assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4794 for j in 0..2 {
4795 assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4796 for k in 0..2 {
4797 assert_eq!(
4798 a.t3[i][j][k].to_bits(),
4799 b.t3[i][j][k].to_bits(),
4800 "{ctx}: t3[{i}][{j}][{k}]"
4801 );
4802 }
4803 }
4804 }
4805 }
4806
4807 // Deterministic LCG with signed-zero injection and per-lane-distinct values.
4808 struct Lcg(u64);
4809 impl Lcg {
4810 fn next(&mut self) -> f64 {
4811 self.0 = self
4812 .0
4813 .wrapping_mul(6364136223846793005)
4814 .wrapping_add(1442695040888963407);
4815 (self.0 >> 11) as f64 / (1u64 << 53) as f64
4816 }
4817 fn val(&mut self) -> f64 {
4818 let u = self.next();
4819 if u < 0.04 {
4820 return 0.0;
4821 }
4822 if u < 0.08 {
4823 return -0.0;
4824 }
4825 (self.next() - 0.5) * 5.0
4826 }
4827 }
4828
    /// Lane `i` of the batched order-4 / order-3 tower is `to_bits`-identical to
    /// the scalar tower on row `i`, across 2500 pseudo-random 4-row batches whose
    /// primaries and aux data are drawn per lane with signed zeros injected.
4832 #[test]
4833 fn batched_lane_i_matches_scalar_row_i_bit_identical() {
4834 let mut rng = Lcg(0xA5A5_1234_DEAD_BEEF);
4835 let mut batches = 0usize;
4836 for _ in 0..2500 {
4837 let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4838 // per-lane-DISTINCT continuous aux (cov/z/wi), signed-zero injected.
4839 let aux: [[f64; 3]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4840 let prog = ToyProgram { primaries: bases.to_vec(), aux: aux.to_vec() };
4841 let rows = [0usize, 1, 2, 3];
4842
4843 // order-4 batch vs scalar Tower4 (instantiated through the same body).
4844 let batch4 = generic_batched_fourth_tower(&prog, rows).expect("batch4");
4845 for (row, base) in bases.iter().enumerate() {
4846 let vars: [Tower4<2>; 2] =
4847 std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4848 let scal = prog.row_nll(&[row], &vars).expect("scalar tower4");
4849 assert_t4_bits_eq(&batch4.lane(row), &scal, "batched_fourth");
4850 }
4851
4852 // order-3 batch vs scalar Tower3.
4853 let batch3 = generic_batched_third_tower(&prog, rows).expect("batch3");
4854 for (row, base) in bases.iter().enumerate() {
4855 let vars: [Tower3<2>; 2] =
4856 std::array::from_fn(|a| <Tower3<2> as RowJet<2>>::variable(base[a], a));
4857 let scal = prog.row_nll(&[row], &vars).expect("scalar tower3");
4858 assert_t3_bits_eq(&batch3.lane(row), &scal, "batched_third");
4859 }
4860 batches += 1;
4861 }
4862 assert_eq!(batches, 2500);
4863 }
4864
4865 /// The blanket impl does not churn the scalar path: the body driven through
4866 /// `RowJet` ops is `to_bits`-identical to the body driven directly through
4867 /// `JetScalar` ops, and `rowjet_row_kernel`'s `(v, g, H)` matches the dense
4868 /// `Tower4` lower channels.
4869 #[test]
4870 fn blanket_scalar_path_is_unchanged_and_consistent() {
4871 let mut rng = Lcg(0x0BAD_F00D_1357_2468);
4872 for _ in 0..3000 {
4873 let base: [f64; 2] = std::array::from_fn(|_| rng.val());
4874 let aux0: [f64; 3] = std::array::from_fn(|_| rng.val());
4875 let prog = ToyProgram { primaries: vec![base], aux: vec![aux0] };
4876
4877 // (a) RowJet-driven body == JetScalar-driven body, bit-for-bit. The
4878 // reference body uses `scale(f64)` where the RowJet body uses
4879 // `scale_rows(f64)` — proving the scalar `scale_rows` rewrite does not
4880 // churn the path (`scale_rows(s) == scale(s)` on `Value = f64`).
4881 let via_rowjet: Tower4<2> = {
4882 let vars: [Tower4<2>; 2] =
4883 std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4884 prog.row_nll(&[0], &vars).expect("rowjet")
4885 };
4886 let via_jetscalar: Tower4<2> = {
4887 let vars: [Tower4<2>; 2] = std::array::from_fn(|a| {
4888 <Tower4<2> as JetScalar<2>>::variable(base[a], a)
4889 });
4890 let (cov, z, wi) = (aux0[0], aux0[1], aux0[2]);
4891 // The body using JetScalar's own ops + scale(f64) directly.
4892 let a = vars[0].mul(&vars[1]).scale(cov);
4893 let b = a.add(&Tower4::constant(0.5)).sub(&vars[0].scale(0.25));
4894 let c = b
4895 .compose_unary_with(|u| {
4896 let e = u.exp();
4897 [e, e, e, e, e]
4898 })
4899 .scale(z);
4900 let d = JetScalar::neg(&c).add(&vars[0]);
4901 let e = d
4902 .compose_unary_with(|u| {
4903 let s = (1.0 + u * u).sqrt();
4904 let s3 = s * s * s;
4905 let s5 = s3 * s * s;
4906 let s7 = s5 * s * s;
4907 [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
4908 })
4909 .scale(wi);
4910 e.mul(&vars[1]).add(&e)
4911 };
4912 assert_t4_bits_eq(&via_rowjet, &via_jetscalar, "blanket_vs_direct");
4913
4914 // (b) rowjet_row_kernel (v,g,H) == dense Tower4 lower channels.
4915 // Order2 and Tower4 use different internal representations so
4916 // signed-zero differences (−0.0 vs +0.0) may arise in gradient/
4917 // Hessian channels that evaluate to exactly zero; IEEE equality
4918 // treats these as equal, so `==` is the right comparison here.
4919 let (v, g, h) = rowjet_row_kernel(&prog, 0).expect("kernel");
4920 assert_eq!(v.to_bits(), via_rowjet.v.to_bits(), "kernel v");
4921 for i in 0..2 {
4922 assert!(g[i] == via_rowjet.g[i], "kernel g[{i}]: {} vs {}", g[i], via_rowjet.g[i]);
4923 for j in 0..2 {
4924 assert!(
4925 h[i][j] == via_rowjet.h[i][j],
4926 "kernel h[{i}][{j}]: {} vs {}",
4927 h[i][j],
4928 via_rowjet.h[i][j]
4929 );
4930 }
4931 }
4932
4933 // (c) the Order2 scalar IS a RowJet via the blanket.
4934 let o2: [Order2<2>; 2] =
4935 std::array::from_fn(|a| <Order2<2> as RowJet<2>>::variable(base[a], a));
4936 let via_order2 = prog.body(&[0], &o2);
4937 assert_eq!(
4938 via_order2.0.v.to_bits(),
4939 via_rowjet.v.to_bits(),
4940 "Order2 blanket value channel must match the dense Tower4 program body"
4941 );
4942 }
4943 }
4944
4945 /// On the scalar path (`Value = f64`) `scale_rows(s)` is `to_bits`-identical
4946 /// to `scale(s)` for EVERY channel — so rewriting a survival `.scale(per_row)`
4947 /// to `.scale_rows(per_row)` cannot perturb the existing scalar fits.
4948 #[test]
4949 fn scale_rows_scalar_is_bit_identical_to_scale() {
4950 let mut rng = Lcg(0xFEED_FACE_0042_1001);
4951 for _ in 0..3000 {
4952 let base: [f64; 2] = std::array::from_fn(|_| rng.val());
4953 let s = rng.val();
4954 // Build a dense tower with populated channels (exp of a product).
4955 let vars: [Tower4<2>; 2] =
4956 std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4957 let jet = vars[0].mul(&vars[1]).compose_unary_with(|u| {
4958 let e = u.exp();
4959 [e, e, e, e, e]
4960 });
4961 let via_scale = RowJet::scale(&jet, s);
4962 let via_scale_rows = RowJet::scale_rows(&jet, s);
4963 assert_t4_bits_eq(&via_scale_rows, &via_scale, "scale_rows==scale");
4964 }
4965 }
4966
4967 /// `scale_rows` on a batch multiplies lane `i` by `s[i]`, so lane `i` of a
4968 /// per-lane-scaled batch matches the scalar `scale(s[i])` on row `i` — the
4969 /// continuous per-row data path the single-`f64` `scale` could not carry.
4970 #[test]
4971 fn batched_scale_rows_matches_per_row_scalar_scale() {
4972 let mut rng = Lcg(0x1357_9BDF_2468_ACE0);
4973 for _ in 0..2500 {
4974 let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4975 let s: [f64; 4] = std::array::from_fn(|_| rng.val());
4976 let batch: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
4977 Tower4Batch::variable(
4978 wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]),
4979 a,
4980 )
4981 });
4982 let prod = batch[0].mul(&batch[1]).compose_unary_with(|u| {
4983 let e = u.exp();
4984 [e, e, e, e, e]
4985 });
4986 let scaled = prod.scale_rows(s);
4987 for (row, base) in bases.iter().enumerate() {
4988 let v: [Tower4<2>; 2] =
4989 std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4990 let prod_s = v[0].mul(&v[1]).compose_unary_with(|u| {
4991 let e = u.exp();
4992 [e, e, e, e, e]
4993 });
4994 let ref_s = RowJet::scale(&prod_s, s[row]);
4995 assert_t4_bits_eq(&scaled.lane(row), &ref_s, "batched_scale_rows");
4996 }
4997 }
4998 }
4999
5000 /// The per-lane guard reports exactly the failing lanes on a batch and the
5001 /// single lane on a scalar jet.
5002 #[test]
5003 fn guard_reports_per_lane_failures() {
5004 let cols: [[f64; 2]; 4] = [[1.0, 0.5], [-2.0, 0.5], [3.0, 0.5], [-0.0, 0.5]];
5005 let vars: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
5006 Tower4Batch::variable(
5007 wide::f64x4::new([cols[0][a], cols[1][a], cols[2][a], cols[3][a]]),
5008 a,
5009 )
5010 });
5011 let verdict = vars[0].guard(|v| v > 0.0);
5012 assert_eq!(verdict.lanes(), 4);
5013 assert!(verdict.any_failed());
5014 assert!(!verdict.all_pass());
5015 assert!(!verdict.lane_failed(0));
5016 assert!(verdict.lane_failed(1));
5017 assert!(!verdict.lane_failed(2));
5018 assert!(verdict.lane_failed(3));
5019 assert_eq!(verdict.failed_mask(), 0b1010);
5020
5021 let s_ok = <Tower4<2> as RowJet<2>>::variable(1.0, 0);
5022 let s_bad = <Tower4<2> as RowJet<2>>::variable(-1.0, 0);
5023 assert!(RowJet::guard(&s_ok, |v| v > 0.0).all_pass());
5024 assert!(RowJet::guard(&s_bad, |v| v > 0.0).any_failed());
5025 assert_eq!(RowJet::guard(&s_ok, |v| v > 0.0).lanes(), 1);
5026 }
5027
5028 // ── ln_gamma_derivative_stack / digamma_derivative_stack / trigamma_derivative_stack ──
5029
5030 #[test]
5031 fn ln_gamma_derivative_stack_known_values_at_1() {
5032 let s = ln_gamma_derivative_stack(1.0);
5033 // ln Γ(1) = 0; statrs uses Lanczos so the result is within ULP noise
5034 assert!(s[0].abs() < 1e-14, "ln_gamma(1) must be ~0, got {}", s[0]);
5035 // ψ₀(1) = -γ (Euler–Mascheroni)
5036 let euler_mascheroni = 0.577_215_664_901_532_9_f64;
5037 assert!(
5038 (s[1] + euler_mascheroni).abs() < 1e-10,
5039 "digamma(1) ≈ -{euler_mascheroni:.6}, got {}",
5040 s[1]
5041 );
5042 // ψ₁(1) = π²/6
5043 let pi2_6 = std::f64::consts::PI * std::f64::consts::PI / 6.0;
5044 assert!(
5045 (s[2] - pi2_6).abs() < 1e-10,
5046 "trigamma(1) ≈ {pi2_6:.6}, got {}",
5047 s[2]
5048 );
5049 }
5050
5051 #[test]
5052 fn ln_gamma_derivative_stack_known_values_at_2() {
5053 let s = ln_gamma_derivative_stack(2.0);
5054 // ln Γ(2) = ln(1) = 0 exactly
5055 assert!(s[0].abs() < 1e-14, "ln_gamma(2) must be 0, got {}", s[0]);
5056 // ψ₀(2) = 1 − γ (recurrence: ψ₀(x+1) = ψ₀(x) + 1/x)
5057 let euler_mascheroni = 0.577_215_664_901_532_9_f64;
5058 let digamma_2 = 1.0 - euler_mascheroni;
5059 assert!(
5060 (s[1] - digamma_2).abs() < 1e-10,
5061 "digamma(2) ≈ {digamma_2:.6}, got {}",
5062 s[1]
5063 );
5064 }
5065
5066 #[test]
5067 fn ln_gamma_derivative_stack_order2_is_prefix() {
5068 for &x in &[0.5_f64, 1.0, 2.0, 5.0] {
5069 let full = ln_gamma_derivative_stack(x);
5070 let ord2 = ln_gamma_derivative_stack_order2(x);
5071 assert_eq!(
5072 ord2[0], full[0],
5073 "order2[0] != full[0] at x={x}"
5074 );
5075 assert_eq!(
5076 ord2[1], full[1],
5077 "order2[1] != full[1] at x={x}"
5078 );
5079 assert_eq!(
5080 ord2[2], full[2],
5081 "order2[2] != full[2] at x={x}"
5082 );
5083 }
5084 }
5085
5086 #[test]
5087 fn digamma_derivative_stack_overlaps_ln_gamma_stack() {
5088 // The two stacks share a run of four polygamma values:
5089 // ln_gamma_stack[1..5] == digamma_stack[0..4]
5090 for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
5091 let lg = ln_gamma_derivative_stack(x);
5092 let dg = digamma_derivative_stack(x);
5093 for i in 0..4 {
5094 assert_eq!(
5095 lg[i + 1], dg[i],
5096 "ln_gamma_stack[{}] != digamma_stack[{}] at x={x}",
5097 i + 1,
5098 i
5099 );
5100 }
5101 }
5102 }
5103
5104 #[test]
5105 fn trigamma_derivative_stack_overlaps_digamma_stack() {
5106 // digamma_stack[1..5] == trigamma_stack[0..4]
5107 for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
5108 let dg = digamma_derivative_stack(x);
5109 let tg = trigamma_derivative_stack(x);
5110 for i in 0..4 {
5111 assert_eq!(
5112 dg[i + 1], tg[i],
5113 "digamma_stack[{}] != trigamma_stack[{}] at x={x}",
5114 i + 1,
5115 i
5116 );
5117 }
5118 }
5119 }
5120
5121 #[test]
5122 fn derivative_stacks_all_finite_at_positive_inputs() {
5123 for &x in &[0.01_f64, 0.5, 1.0, 2.0, 10.0, 100.0] {
5124 for v in ln_gamma_derivative_stack(x) {
5125 assert!(v.is_finite(), "ln_gamma_stack non-finite at x={x}: {v}");
5126 }
5127 for v in digamma_derivative_stack(x) {
5128 assert!(v.is_finite(), "digamma_stack non-finite at x={x}: {v}");
5129 }
5130 for v in trigamma_derivative_stack(x) {
5131 assert!(v.is_finite(), "trigamma_stack non-finite at x={x}: {v}");
5132 }
5133 }
5134 }
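
    // A minimal sketch, not part of the original suite: a finite-difference
    // cross-check that entry k+1 of a `[f, f′, f″, f‴, f⁗]` stack is the
    // derivative of entry k. The helper name `fd_stack_max_rel_err_sketch` and
    // its step-size argument are hypothetical; it assumes the stack function
    // is smooth on `[x − h, x + h]`.
    #[allow(dead_code)]
    fn fd_stack_max_rel_err_sketch(stack: impl Fn(f64) -> [f64; 5], x: f64, h: f64) -> f64 {
        let (up, dn, at) = (stack(x + h), stack(x - h), stack(x));
        let mut worst = 0.0_f64;
        for k in 0..4 {
            // Central difference of entry k approximates entry k + 1.
            let fd = (up[k] - dn[k]) / (2.0 * h);
            // Relative error, floored at 1 so near-zero entries do not blow up.
            let rel = (at[k + 1] - fd).abs() / at[k + 1].abs().max(1.0);
            worst = worst.max(rel);
        }
        worst
    }
    // Usage idea (an assumption, not run here): point it at
    // `unary_derivatives_log1mexp_positive` at an interior `x` such as 0.7 with
    // a step near 1e-5, and assert the result stays below a loose bound.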
5135}
5136
5137// ── Contraction-symmetry optimization gate ────────────────────────────────────
5138//
5139// `Tower4::third_contracted` / `fourth_contracted` contract the (fully
5140// index-symmetric) `t3`/`t4` tensors against directions, leaving the output
5141// indices `(a, b)` / `(i, j)` free. Those free indices inherit the tensor's
5142// symmetry — `out[a][b] == out[b][a]` term-for-term — so only the upper triangle
5143// need be summed and the lower triangle mirrored. Unlike the dense symmetric
5144// FILL (which needs a K⁴ scatter and loses inner-loop vectorisation, and was
5145// measured SLOWER), the mirror here is a tiny K×K copy and the inner contraction
5146// is untouched (contiguous, vectorisable). This is BIT-IDENTICAL to the full
5147// nest, so it needs no fingerprint re-baseline; the gate is (1) bit-identity vs
5148// the full reference and (2) a measured wall-clock that is not slower.
5149#[cfg(test)]
5150mod contraction_symmetry_tests {
5151 use super::*;
5152
5153 struct Rng(u64);
5154 impl Rng {
5155 fn u(&mut self) -> f64 {
5156 self.0 = self
5157 .0
5158 .wrapping_mul(6364136223846793005)
5159 .wrapping_add(1442695040888963407);
5160 (self.0 >> 11) as f64 / (1u64 << 53) as f64
5161 }
5162 fn s(&mut self) -> f64 {
5163 (self.u() - 0.5) * 4.0
5164 }
5165 }
5166
5167 /// Random VALID fully-symmetric `Tower4<K>` (symmetric `h`/`t3`/`t4`).
5168 fn rand_sym4<const K: usize>(r: &mut Rng) -> Tower4<K> {
5169 let mut t = Tower4::<K>::zero();
5170 t.v = r.s();
5171 for i in 0..K {
5172 t.g[i] = r.s();
5173 }
5174 for a in 0..K {
5175 for b in a..K {
5176 let v2 = r.s();
5177 t.h[a][b] = v2;
5178 t.h[b][a] = v2;
5179 for c in b..K {
5180 let v3 = r.s();
5181 for p in perms3([a, b, c]) {
5182 t.t3[p[0]][p[1]][p[2]] = v3;
5183 }
5184 for d in c..K {
5185 let v4 = r.s();
5186 for p in perms4([a, b, c, d]) {
5187 t.t4[p[0]][p[1]][p[2]][p[3]] = v4;
5188 }
5189 }
5190 }
5191 }
5192 }
5193 t
5194 }
5195
5196 fn perms3(idx: [usize; 3]) -> [[usize; 3]; 6] {
5197 let [a, b, c] = idx;
5198 [[a, b, c], [a, c, b], [b, a, c], [b, c, a], [c, a, b], [c, b, a]]
5199 }
5200 fn perms4(idx: [usize; 4]) -> [[usize; 4]; 24] {
5201 let [a, b, c, d] = idx;
5202 [
5203 [a, b, c, d], [a, b, d, c], [a, c, b, d], [a, c, d, b], [a, d, b, c], [a, d, c, b],
5204 [b, a, c, d], [b, a, d, c], [b, c, a, d], [b, c, d, a], [b, d, a, c], [b, d, c, a],
5205 [c, a, b, d], [c, a, d, b], [c, b, a, d], [c, b, d, a], [c, d, a, b], [c, d, b, a],
5206 [d, a, b, c], [d, a, c, b], [d, b, a, c], [d, b, c, a], [d, c, a, b], [d, c, b, a],
5207 ]
5208 }
5209
5210 /// Full-nest reference (the pre-opt `a, b ∈ 0..K` form).
5211 fn third_full<const K: usize>(t: &Tower4<K>, dir: &[f64; K]) -> [[f64; K]; K] {
5212 let mut out = [[0.0; K]; K];
5213 for a in 0..K {
5214 for b in 0..K {
5215 let mut acc = 0.0;
5216 for c in 0..K {
5217 acc += t.t3[a][b][c] * dir[c];
5218 }
5219 out[a][b] = acc;
5220 }
5221 }
5222 out
5223 }
5224 fn fourth_full<const K: usize>(t: &Tower4<K>, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
5225 let mut out = [[0.0; K]; K];
5226 for i in 0..K {
5227 for j in 0..K {
5228 let mut acc = 0.0;
5229 for k in 0..K {
5230 for l in 0..K {
5231 acc += t.t4[i][j][k][l] * u[k] * w[l];
5232 }
5233 }
5234 out[i][j] = acc;
5235 }
5236 }
5237 out
5238 }
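
    // A minimal sketch of the output-symmetric form described in the header
    // comment of this module, on a bare array rather than `Tower4`: sum only
    // `b >= a`, then mirror into the lower triangle. The name
    // `third_upper_mirror_sketch` is hypothetical and this is not claimed to be
    // the body of `Tower4::third_contracted`; it assumes `t3` is symmetric in
    // its first two indices.
    #[allow(dead_code)]
    fn third_upper_mirror_sketch<const K: usize>(
        t3: &[[[f64; K]; K]; K],
        dir: &[f64; K],
    ) -> [[f64; K]; K] {
        let mut out = [[0.0; K]; K];
        for a in 0..K {
            for b in a..K {
                // Inner contraction is the same contiguous loop as the full nest.
                let mut acc = 0.0;
                for c in 0..K {
                    acc += t3[a][b][c] * dir[c];
                }
                out[a][b] = acc;
                out[b][a] = acc; // mirror: a K×K copy, no extra contraction
            }
        }
        out
    }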
5239
5240 /// Returns the number of bit-equality comparisons performed (`n·K·K·2`), so
5241 /// the caller can assert the intended workload actually ran: a generic
5242 /// (turbofish) helper call hides its internal assertions, so the count is
5243 /// surfaced and checked at the call site.
    fn check_bit_identical<const K: usize>(seed: u64, n: usize) -> usize {
        let mut r = Rng(seed);
        let mut checks = 0usize;
        for _ in 0..n {
            let t = rand_sym4::<K>(&mut r);
            let dir: [f64; K] = std::array::from_fn(|_| r.s());
            let u: [f64; K] = std::array::from_fn(|_| r.s());
            let w: [f64; K] = std::array::from_fn(|_| r.s());
            let t3_sym = t.third_contracted(&dir);
            let t3_full = third_full(&t, &dir);
            let t4_sym = t.fourth_contracted(&u, &w);
            let t4_full = fourth_full(&t, &u, &w);
            for a in 0..K {
                for b in 0..K {
                    // Compare raw bit patterns, not approximate values.
                    assert_eq!(
                        t3_sym[a][b].to_bits(),
                        t3_full[a][b].to_bits(),
                        "third K={K} [{a}][{b}]"
                    );
                    assert_eq!(
                        t4_sym[a][b].to_bits(),
                        t4_full[a][b].to_bits(),
                        "fourth K={K} [{a}][{b}]"
                    );
                    checks += 2;
                }
            }
        }
        checks
    }

    /// The output-symmetric contraction is BIT-IDENTICAL to the full nest across
    /// `K ∈ {2,3,4,9}`, so no fingerprint re-baseline is owed: accuracy and bits
    /// are unchanged, and this is a pure speed-only optimization.
    #[test]
    fn contraction_symmetry_is_bit_identical_to_full_nest() {
        let checks = check_bit_identical::<2>(0x0000_0002_C0FF_EE01, 1000)
            + check_bit_identical::<3>(0x0000_0003_C0FF_EE01, 800)
            + check_bit_identical::<4>(0x0000_0004_C0FF_EE01, 600)
            + check_bit_identical::<9>(0x0000_0009_C0FF_EE01, 300);
        // Guards against the loops silently not running (e.g. a zeroed count):
        // 1000·2²·2 + 800·3²·2 + 600·4²·2 + 300·9²·2.
        assert_eq!(checks, 8000 + 14400 + 19200 + 48600);
    }

    /// Measures the wall-clock time of the output-symmetric contraction vs the
    /// full nest at `K = 9` (it does ~2× fewer inner contractions; the
    /// bit-identity test is the correctness gate). Informational only, since
    /// wall-clock time is noisy, with a single PATHOLOGICAL-regression guard:
    /// the symmetric form does strictly fewer inner contractions, so it must
    /// not be materially slower.
    #[test]
    fn contraction_symmetry_speedup_is_reported() {
        const K: usize = 9;
        let mut r = Rng(0xC0FF_EE99_1234_5678);
        let towers: Vec<Tower4<K>> = (0..512).map(|_| rand_sym4::<K>(&mut r)).collect();
        let dir: [f64; K] = std::array::from_fn(|_| r.s());
        let u: [f64; K] = std::array::from_fn(|_| r.s());
        let w: [f64; K] = std::array::from_fn(|_| r.s());

        let reps = 400usize;
        let t_sym = {
            let start = std::time::Instant::now();
            let mut sink = 0.0f64;
            for _ in 0..reps {
                for t in &towers {
                    let o3 = std::hint::black_box(t).third_contracted(std::hint::black_box(&dir));
                    let o4 = std::hint::black_box(t)
                        .fourth_contracted(std::hint::black_box(&u), std::hint::black_box(&w));
                    sink += o3[0][K - 1] + o4[0][K - 1];
                }
            }
            std::hint::black_box(sink);
            start.elapsed().as_secs_f64()
        };
        let t_full = {
            let start = std::time::Instant::now();
            let mut sink = 0.0f64;
            for _ in 0..reps {
                for t in &towers {
                    let o3 = third_full(std::hint::black_box(t), std::hint::black_box(&dir));
                    let o4 = fourth_full(
                        std::hint::black_box(t),
                        std::hint::black_box(&u),
                        std::hint::black_box(&w),
                    );
                    sink += o3[0][K - 1] + o4[0][K - 1];
                }
            }
            std::hint::black_box(sink);
            start.elapsed().as_secs_f64()
        };
        let calls = (reps * towers.len()) as f64;
        eprintln!(
            "[contraction-symmetry speedup K=9] sym={:.1}ns/call full={:.1}ns/call \
             wall_speedup={:.2}x",
            t_sym / calls * 1e9,
            t_full / calls * 1e9,
            t_full / t_sym
        );
        assert!(
            t_sym <= t_full * 1.5,
            "output-symmetric contraction pathologically slower: \
             sym={t_sym:.4}s full={t_full:.4}s"
        );
    }
}