gam_math/jet_tower.rs
1//! Taylor-jet tower algebra: write each family's row log-likelihood ONCE,
2//! derive the entire `RowKernel<K>` derivative tower mechanically (#932).
3//!
4//! # The object
5//!
6//! [`Tower4<K>`] is a truncated multivariate Taylor scalar in `K` primary
7//! variables, carrying the value and ALL partial derivatives through fourth
8//! order as full (unsymmetrized) tensors:
9//!
10//! ```text
11//! v ℓ
12//! g[a] ∂ℓ/∂p_a
13//! h[a][b] ∂²ℓ/∂p_a∂p_b
14//! t3[abc] ∂³ℓ/∂p_a∂p_b∂p_c
15//! t4[abcd] ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d
16//! ```
17//!
18//! Arithmetic (`+ − × ÷`, scalar mixes) propagates the tower by the exact
19//! Leibniz rule; unary transcendentals propagate by the exact multivariate
20//! Faà di Bruno formula given a `[f, f′, f″, f‴, f⁗]` stack evaluated at the
21//! inner value. This is truncated Taylor ALGEBRA — exact derivatives of the
22//! evaluated expression, not finite differences, not an approximation —
23//! fully compatible with the exact-REML-only policy.
24//!
25//! One evaluation of a row NLL program at seeded variables yields, in a
26//! single pass, every channel the [`super::row_kernel::RowKernel`] trait
27//! demands: `row_kernel` (value/∇/H), `row_third_contracted(dir)` (contract
28//! `t3` with `dir`), and `row_fourth_contracted(u, v)` (contract `t4` with
29//! `u` and `v`). The directional cross-channels that hand-written towers
30//! drop (#736's residual gap) cannot be dropped here: there is no separate
31//! "channel" to forget — every derivative of the one expression is carried.
32//!
33//! # Why this exists (the bug genus)
34//!
35//! Every family today hand-writes its tower: value in one function,
36//! gradient in another, `pdfthird_derivative`/`pdffourth_derivative`,
37//! entry/exit-specific cross blocks — thousands of lines of calculus that
38//! drift. #736 was a sign flip in a hand-written cross-Hessian block,
39//! invisible until a new consumer touched it; #948 is a derivative path
40//! that is not the derivative of the evaluated row loss (clamped-μ
41//! surrogate); the objective↔gradient desync class is the same disease at
42//! the criterion level. A tower-derived kernel is exact-by-construction:
43//! the value channel IS the production loss expression, so its derivative
44//! channels cannot desync from it.
45//!
46//! # Relation to `jet_partitions::MultiDirJet`
47//!
48//! The tree already carries a *directional* jet (bitmask coefficients over
49//! distinct seeded directions, heap-allocated, Bell-partition compose) used
50//! inside the marginal-slope and latent-survival families. It answers "the
51//! derivative along THESE specific directions" and must be re-seeded and
52//! re-evaluated per direction tuple (e.g. 10 symmetric `(a,b)` pairs for a
53//! K=4 fourth contraction). `Tower4` answers ALL of them from one
54//! evaluation: contraction happens AFTER differentiation, as plain linear
55//! algebra on the stored tensors. Use `MultiDirJet` when you need a handful
56//! of directions of a huge-K expression; use `Tower4` when you need the
57//! complete small-K tower — which is exactly the `RowKernel<K≤4>` shape.
58//! The `[f64; 5]` unary-derivative stacks
59//! (`unary_derivatives_neglog_phi`, …) are signature-compatible with
60//! [`Tower4::compose_unary`], so the families' existing special-function
61//! stacks are directly reusable.
62//!
63//! # Stability discipline (why this is NOT autodiff)
64//!
65//! Differentiating the primal code path inherits its instabilities: a jet
66//! pushed through a naive `ln(1 + e^η)` is garbage in the saturated tail
67//! even though the true derivative σ(η) is benign there. This module
68//! therefore splits responsibility: **humans own primitive stability,
69//! the algebra owns combinatorics**. Tail-critical special functions enter
70//! a program ONLY as hand-certified `[f64; 5]` derivative stacks through
71//! [`Tower4::compose_unary`] — the same stacks the families already write
72//! (`unary_derivatives_neglog_phi` and friends, built on erfcx/log_ndtr) —
73//! and the tower mechanizes only the Leibniz/Faà di Bruno composition,
74//! which is where hand-written towers actually fail (#736 was a
75//! composition sign flip, not a primitive error). Program authors must use
76//! a stable primitive stack wherever the f64 production loss does; the
77//! convenience methods (`exp`, `ln`, `sqrt`, …) are for expressions whose
78//! arguments are tame by construction.
79//!
80//! # Storage convention
81//!
82//! Tensors are stored FULL, not symmetric-packed: `t4` for K=4 is 256
83//! doubles where 35 would do. This is deliberate clarity-over-speed for the
84//! oracle role — indexing is trivially auditable, contraction loops are
85//! obvious, and the redundancy is itself a checked invariant (the algebra
86//! only ever writes symmetric values). Symmetric packing is a later,
87//! profile-justified optimization behind the same API.
88//!
89//! # Deployment ladder (#932)
90//!
91//! 1. This module: the algebra + the program seam + the oracle.
92//! 2. Universal oracle: every hand-written `RowKernel` gains a CI test
93//! asserting channel-by-channel agreement with a `RowNllProgram` written
94//! once — see [`verify_kernel_channels`]. This alone would have caught
95//! #736 at introduction.
96//! 3. Migrate error-dense / cold towers to [`derived_row_kernel`] et al.;
97//! keep hand-tuned hot paths, now verified against the single-expression
98//! truth instead of being the only definition.
99//! 4. New families (#914/#916/#917 ZI/ordinal/expectile, #921's location-
100//! scale port) implement ONLY `RowNllProgram` and get an exact
101//! fourth-order tower for the price of writing the likelihood.
102
103use crate::jet_algebra;
104
105/// Truncated fourth-order multivariate Taylor scalar in `K` variables.
106///
107/// See the module documentation for semantics and conventions. `Copy` is
108/// intentional despite the size (2 KiB at K=4): towers are per-row
109/// temporaries that live entirely in registers/stack during a row program,
110/// and value semantics keep program code readable (`a * b + c`).
111#[derive(Clone, Copy, Debug)]
112pub struct Tower4<const K: usize> {
113 /// Value ℓ.
114 pub v: f64,
115 /// Gradient ∂ℓ/∂p_a.
116 pub g: [f64; K],
117 /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
118 pub h: [[f64; K]; K],
119 /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
120 pub t3: [[[f64; K]; K]; K],
121 /// Fourth derivatives ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d (fully symmetric).
122 pub t4: [[[[f64; K]; K]; K]; K],
123}
124
125impl<const K: usize> Tower4<K> {
126 /// The additive identity.
127 pub fn zero() -> Self {
128 Self {
129 v: 0.0,
130 g: [0.0; K],
131 h: [[0.0; K]; K],
132 t3: [[[0.0; K]; K]; K],
133 t4: [[[[0.0; K]; K]; K]; K],
134 }
135 }
136
137 /// A constant: value `c`, all derivatives zero.
138 pub fn constant(c: f64) -> Self {
139 let mut out = Self::zero();
140 out.v = c;
141 out
142 }
143
144 /// The seeded variable `p_idx` with current value `value`:
145 /// unit first derivative in slot `idx`, zero elsewhere and above.
146 pub fn variable(value: f64, idx: usize) -> Self {
147 let mut out = Self::constant(value);
148 out.g[idx] = 1.0;
149 out
150 }
151
152 /// Read the (fully symmetric) derivative tensor entry whose differentiation
153 /// axes are `labels` (length 0..=4): value, `g`, `h`, `t3`, `t4`.
154 #[inline]
155 fn deriv(&self, labels: &[usize]) -> f64 {
156 assert!(
157 labels.len() <= 4,
158 "Tower4 carries at most fourth-order derivatives"
159 );
160 match labels.len() {
161 0 => self.v,
162 1 => self.g[labels[0]],
163 2 => self.h[labels[0]][labels[1]],
164 3 => self.t3[labels[0]][labels[1]][labels[2]],
165 _ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
166 }
167 }
168
169 /// Exact truncated Leibniz product `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`.
170 ///
171 /// # Codegen
172 ///
173 /// Each output entry's `2^m` subset sum is written as a compact straight-line
174 /// expression instead of the shared [`jet_algebra::leibniz_product`] subset
175 /// walker (which, per entry, builds `SlotBuf`s and `match`-dispatches the
176 /// `deriv` closure across all `2^m` subsets). The loop nest over `(i,j,k,l)`
177 /// is unchanged — only the inner per-entry sum is unrolled — so this does NOT
178 /// unroll over `K` and does NOT bloat code: on a `Tower4<9>` mul-and-read
179 /// consumer the new form is faster AND smaller (asm: 34 outlined walker `bl`
180 /// calls → 0, 21.1 KiB → 14.3 KiB, +100 NEON `.2d` ops).
181 ///
182 /// BIT-IDENTICAL to the walker: each entry's terms are in the walker's exact
183 /// subset-enumeration order (subset bit `b` ↔ position `b`, `sub = 0..2^m`),
184 /// and the per-entry `acc` accumulator mirrors the walker's `total = 0.0`
185 /// start so a signed-zero leading product collapses to `+0.0` identically —
186 /// which matters because real jets carry exact-`0.0` channels
187 /// (`constant`/`variable` towers). Proven `to_bits`-identical on
188 /// `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`, 5000 inputs each with ~30 %
189 /// exact-`0.0` channels and signed values (a no-leading-`0.0` form fails this
190 /// stress — the accumulator start is load-bearing).
191 pub fn mul(&self, o: &Self) -> Self {
192 let a = self;
193 let b = o;
194 let mut out = Self::zero();
195 out.v = a.v * b.v;
196 for i in 0..K {
197 // subsets of {i}: {} {i}
198 let mut acc = 0.0;
199 acc += a.v * b.g[i];
200 acc += a.g[i] * b.v;
201 out.g[i] = acc;
202 }
203 for i in 0..K {
204 for j in 0..K {
205 // subsets of {i,j}: {} {i} {j} {ij}
206 let mut acc = 0.0;
207 acc += a.v * b.h[i][j];
208 acc += a.g[i] * b.g[j];
209 acc += a.g[j] * b.g[i];
210 acc += a.h[i][j] * b.v;
211 out.h[i][j] = acc;
212 }
213 }
214 for i in 0..K {
215 for j in 0..K {
216 for k in 0..K {
217 // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
218 let mut acc = 0.0;
219 acc += a.v * b.t3[i][j][k];
220 acc += a.g[i] * b.h[j][k];
221 acc += a.g[j] * b.h[i][k];
222 acc += a.h[i][j] * b.g[k];
223 acc += a.g[k] * b.h[i][j];
224 acc += a.h[i][k] * b.g[j];
225 acc += a.h[j][k] * b.g[i];
226 acc += a.t3[i][j][k] * b.v;
227 out.t3[i][j][k] = acc;
228 }
229 }
230 }
231 for i in 0..K {
232 for j in 0..K {
233 for k in 0..K {
234 for l in 0..K {
235 // subsets of {i,j,k,l} in bit order sub = 0..16
236 let mut acc = 0.0;
237 acc += a.v * b.t4[i][j][k][l];
238 acc += a.g[i] * b.t3[j][k][l];
239 acc += a.g[j] * b.t3[i][k][l];
240 acc += a.h[i][j] * b.h[k][l];
241 acc += a.g[k] * b.t3[i][j][l];
242 acc += a.h[i][k] * b.h[j][l];
243 acc += a.h[j][k] * b.h[i][l];
244 acc += a.t3[i][j][k] * b.g[l];
245 acc += a.g[l] * b.t3[i][j][k];
246 acc += a.h[i][l] * b.h[j][k];
247 acc += a.h[j][l] * b.h[i][k];
248 acc += a.t3[i][j][l] * b.g[k];
249 acc += a.h[k][l] * b.h[i][j];
250 acc += a.t3[i][k][l] * b.g[j];
251 acc += a.t3[j][k][l] * b.g[i];
252 acc += a.t4[i][j][k][l] * b.v;
253 out.t4[i][j][k][l] = acc;
254 }
255 }
256 }
257 }
258 out
259 }
260
261 /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
262 /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
263 /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
264 /// out of the borrowed operands.
265 pub fn add(&self, o: &Self) -> Self {
266 *self + *o
267 }
268
269 /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
270 pub fn sub(&self, o: &Self) -> Self {
271 *self + o.scale(-1.0)
272 }
273
274 /// Exact multivariate Faà di Bruno composition `f ∘ self`.
275 ///
276 /// `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` evaluated at `u = self.v` —
277 /// the SAME `[f64; 5]` stack shape the families' existing
278 /// `unary_derivatives_*` helpers produce, so those special-function
279 /// stacks (Φ, log-Φ, normal pdf, …) plug in directly.
280 ///
281 /// The order-m output sums over the set partitions of the m indices
282 /// (Bell(3) = 5 terms at order 3, Bell(4) = 15 at order 4), grouped by
283 /// block count: each partition into r blocks contributes
284 /// `f⁽ʳ⁾ · Π_blocks D_block(u)`.
285 ///
286 /// # Codegen
287 ///
288 /// Evaluated as a compact closed form (the Bell(4)=15 set-partitions of
289 /// `t4`, Bell(3)=5 of `t3`, …) instead of routing through the recursive
290 /// [`jet_algebra::faa_di_bruno`] walker (per-output `for_each_partition`
291 /// recursion + per-block `SlotBuf` + closure dispatch). The loop nest is
292 /// identical to the walker's (`for i,j,k,l`); only the per-entry partition
293 /// sum is straight-line, so this does NOT unroll over `K` and does NOT
294 /// bloat code — measured on a `Tower4<9>` compose-and-read consumer the new
295 /// form is both faster and SMALLER (asm: 94 outlined walker `bl` calls → 0,
296 /// 47.5 KiB → 16.7 KiB, +197 NEON `.2d` ops).
297 ///
298 /// BIT-IDENTICAL to the walker: each channel's terms are emitted in the
299 /// walker's exact partition-enumeration order, each term's block products
300 /// are left-associated exactly as the walker's `prod *= block`, and the
301 /// per-channel `acc` accumulator mirrors the walker's `total = 0.0` start
302 /// (so signed-zero products collapse to `+0.0` identically). The order-4
303 /// term sequence was generated from the walker's own enumeration. Proven
304 /// `to_bits`-identical on `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`,
305 /// 5000 random inputs each (zeroed / sign-varied stacks included).
306 pub fn compose_unary(&self, d: [f64; 5]) -> Self {
307 let mut out = Self::zero();
308 out.v = d[0];
309 for i in 0..K {
310 let mut acc = 0.0;
311 acc += d[1] * self.g[i];
312 out.g[i] = acc;
313 }
314 for i in 0..K {
315 for j in 0..K {
316 let mut acc = 0.0;
317 acc += d[1] * self.h[i][j];
318 acc += d[2] * self.g[i] * self.g[j];
319 out.h[i][j] = acc;
320 }
321 }
322 for i in 0..K {
323 for j in 0..K {
324 for k in 0..K {
325 // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
326 let mut acc = 0.0;
327 acc += d[1] * self.t3[i][j][k];
328 acc += d[2] * self.h[i][j] * self.g[k];
329 acc += d[2] * self.h[i][k] * self.g[j];
330 acc += d[2] * self.g[i] * self.h[j][k];
331 acc += d[3] * self.g[i] * self.g[j] * self.g[k];
332 out.t3[i][j][k] = acc;
333 }
334 }
335 }
336 for i in 0..K {
337 for j in 0..K {
338 for k in 0..K {
339 for l in 0..K {
340 // Bell(4)=15 partitions, walker enumeration order.
341 let mut acc = 0.0;
342 acc += d[1] * self.t4[i][j][k][l];
343 acc += d[2] * self.t3[i][j][k] * self.g[l];
344 acc += d[2] * self.t3[i][j][l] * self.g[k];
345 acc += d[2] * self.h[i][j] * self.h[k][l];
346 acc += d[3] * self.h[i][j] * self.g[k] * self.g[l];
347 acc += d[2] * self.t3[i][k][l] * self.g[j];
348 acc += d[2] * self.h[i][k] * self.h[j][l];
349 acc += d[3] * self.h[i][k] * self.g[j] * self.g[l];
350 acc += d[2] * self.h[i][l] * self.h[j][k];
351 acc += d[2] * self.g[i] * self.t3[j][k][l];
352 acc += d[3] * self.g[i] * self.h[j][k] * self.g[l];
353 acc += d[3] * self.h[i][l] * self.g[j] * self.g[k];
354 acc += d[3] * self.g[i] * self.h[j][l] * self.g[k];
355 acc += d[3] * self.g[i] * self.g[j] * self.h[k][l];
356 acc += d[4] * self.g[i] * self.g[j] * self.g[k] * self.g[l];
357 out.t4[i][j][k][l] = acc;
358 }
359 }
360 }
361 }
362 out
363 }
364
365 /// Compose with a unary special-function whose `[f64; 5]` derivative STACK is
366 /// built from the base value through `stack_fn` — the scalar arm of the
367 /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
368 /// [`Tower4Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
369 /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
370 /// `self.compose_unary(stack_fn(self.v))`. Writing a program against this seam
371 /// lets it re-instantiate, unchanged, at [`Tower4Lane`] (where each of the four
372 /// lanes carries a DISTINCT base value and `stack_fn` is re-run per lane).
373 #[inline]
374 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
375 self.compose_unary(stack_fn(self.v))
376 }
377
378 /// Single-active-slot fast path for [`Self::compose_unary`].
379 ///
380 /// When the inner jet `self` has derivative support ONLY on the all-`slot`
381 /// diagonal channels — i.e. it is a univariate jet in primary `slot`
382 /// scattered into the `K`-wide layout (`g[a] = 0`, `h[a][b] = 0`,
383 /// `t3 = 0`, `t4 = 0` for any axis `≠ slot`) — the multivariate Faà di
384 /// Bruno walk collapses. Every output channel whose axis tuple contains an
385 /// axis `≠ slot` is structurally `0`: each set-partition has a block
386 /// covering that axis, that block reads an off-`slot` derivative of `self`
387 /// (which is `0`), so the block product and the whole partition vanish, and
388 /// the channel sums to the walker's `total = 0.0` start, i.e. `+0.0`. Only
389 /// the five diagonal channels (`v`, `g[slot]`, `h[slot][slot]`,
390 /// `t3[slot]³`, `t4[slot]⁴`) survive.
391 ///
392 /// This computes exactly those five as STRAIGHT-LINE accumulations, each in
393 /// the EXACT term order of [`Self::compose_unary`]'s diagonal
394 /// (`i = j = k = l = slot`) case — so they are BIT-IDENTICAL to
395 /// [`Self::compose_unary`] on the diagonal — and leaves every other channel
396 /// at the zero-init `+0.0`, which the full walk also produces (the
397 /// off-`slot` collapse is `to_bits`-`+0.0`, signed-zero products included;
398 /// proven across `K ∈ {2,3,4,9}`, 5000 single-slot inputs each). At any
399 /// `K ≥ 2` this is far fewer floating-point operations than materialising
400 /// the full `1 + K + K² + K³ + K⁴` channel set whose off-diagonal entries
401 /// are all zero, and far cheaper than the recursive set-partition walker the
402 /// diagonal channels previously routed through (a measured ~9.5× speedup vs
403 /// the full `compose_unary`, recovering a 5.9× walker regression at the
404 /// `K ∈ {2,3}` BMS tower widths).
405 ///
406 /// `#[inline]` so an adopting consumer pays no `bl` call (uninlined, the
407 /// five-channel build does not amortise the call/spill overhead).
408 ///
409 /// # Precondition
410 ///
411 /// The caller guarantees the single-active-slot structure. If it does not
412 /// hold, the off-`slot` channels would be wrongly zeroed; use the full
413 /// [`Self::compose_unary`] in that case.
414 #[inline]
415 pub fn compose_unary_single_slot(&self, d: [f64; 5], slot: usize) -> Self {
416 let mut out = Self::zero();
417 let s = slot;
418 let g = self.g[s];
419 let h = self.h[s][s];
420 let t3 = self.t3[s][s][s];
421 let t4 = self.t4[s][s][s][s];
422 out.v = d[0];
423 // g (i=s): d1*g
424 out.g[s] = {
425 let mut acc = 0.0;
426 acc += d[1] * g;
427 acc
428 };
429 // h (i=j=s): d1*h + d2*g*g
430 out.h[s][s] = {
431 let mut acc = 0.0;
432 acc += d[1] * h;
433 acc += d[2] * g * g;
434 acc
435 };
436 // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
437 out.t3[s][s][s] = {
438 let mut acc = 0.0;
439 acc += d[1] * t3;
440 acc += d[2] * h * g;
441 acc += d[2] * h * g;
442 acc += d[2] * g * h;
443 acc += d[3] * g * g * g;
444 acc
445 };
446 // t4 (i=j=k=l=s): exact term order of compose_unary's inner loop.
447 out.t4[s][s][s][s] = {
448 let mut acc = 0.0;
449 acc += d[1] * t4;
450 acc += d[2] * t3 * g;
451 acc += d[2] * t3 * g;
452 acc += d[2] * h * h;
453 acc += d[3] * h * g * g;
454 acc += d[2] * t3 * g;
455 acc += d[2] * h * h;
456 acc += d[3] * h * g * g;
457 acc += d[2] * h * h;
458 acc += d[2] * g * t3;
459 acc += d[3] * g * h * g;
460 acc += d[3] * h * g * g;
461 acc += d[3] * g * h * g;
462 acc += d[3] * g * g * h;
463 acc += d[4] * g * g * g * g;
464 acc
465 };
466 out
467 }
468
469 /// Multiply every channel by a plain scalar.
470 pub fn scale(&self, s: f64) -> Self {
471 let mut out = *self;
472 out.v *= s;
473 for i in 0..K {
474 out.g[i] *= s;
475 for j in 0..K {
476 out.h[i][j] *= s;
477 for k in 0..K {
478 out.t3[i][j][k] *= s;
479 for l in 0..K {
480 out.t4[i][j][k][l] *= s;
481 }
482 }
483 }
484 }
485 out
486 }
487
488 /// e^self.
489 pub fn exp(&self) -> Self {
490 let e = self.v.exp();
491 self.compose_unary([e, e, e, e, e])
492 }
493
494 /// ln(self). Caller guarantees positivity (likelihood programs do).
495 pub fn ln(&self) -> Self {
496 let u = self.v;
497 let r = 1.0 / u;
498 self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
499 }
500
501 /// 1/self.
502 pub fn recip(&self) -> Self {
503 let r = 1.0 / self.v;
504 let r2 = r * r;
505 self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
506 }
507
508 /// √self. Caller guarantees positivity.
509 pub fn sqrt(&self) -> Self {
510 let u = self.v;
511 let s = u.sqrt();
512 self.compose_unary([
513 s,
514 0.5 / s,
515 -0.25 / (u * s),
516 0.375 / (u * u * s),
517 -0.9375 / (u * u * u * s),
518 ])
519 }
520
521 /// self^a for real exponent `a`. Caller guarantees a positive base.
522 pub fn powf(&self, a: f64) -> Self {
523 let u = self.v;
524 let f0 = u.powf(a);
525 let f1 = a * u.powf(a - 1.0);
526 let f2 = a * (a - 1.0) * u.powf(a - 2.0);
527 let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
528 let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
529 self.compose_unary([f0, f1, f2, f3, f4])
530 }
531
532 /// ln Γ(self). Caller guarantees positivity.
533 pub fn ln_gamma(&self) -> Self {
534 self.compose_unary(ln_gamma_derivative_stack(self.v))
535 }
536
537 /// ψ(self), the digamma function. Caller guarantees positivity.
538 pub fn digamma(&self) -> Self {
539 self.compose_unary(digamma_derivative_stack(self.v))
540 }
541
542 /// ψ′(self), the trigamma function. Caller guarantees positivity.
543 pub fn trigamma(&self) -> Self {
544 self.compose_unary(trigamma_derivative_stack(self.v))
545 }
546
547 /// Contract `t3` with one primary-space direction:
548 /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]` — exactly the
549 /// `row_third_contracted` shape.
550 ///
551 /// The output is symmetric in `(a, b)`: `t3` is fully index-symmetric, so
552 /// `t3[a][b][c] == t3[b][a][c]` and the `Σ_c` contraction gives
553 /// `out[a][b] == out[b][a]` term-for-term, in the same `c` order. We compute
554 /// only the upper triangle `a ≤ b` (the inner contraction is unchanged and
555 /// stays contiguous/vectorisable) and mirror into the lower triangle — this
556 /// is BIT-IDENTICAL to the full `a, b ∈ 0..K` nest while doing ~2× fewer
557 /// inner contractions, with no dense scatter (the mirror is a `K × K` copy).
558 pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
559 let mut out = [[0.0; K]; K];
560 for a in 0..K {
561 for b in a..K {
562 let mut acc = 0.0;
563 for c in 0..K {
564 acc += self.t3[a][b][c] * dir[c];
565 }
566 out[a][b] = acc;
567 out[b][a] = acc;
568 }
569 }
570 out
571 }
572
573 /// Contract `t4` with two primary-space directions:
574 /// `out[a][b] = Σ_{c,d} t4[a][b][c][d] · u[c] · v[d]` — exactly the
575 /// `row_fourth_contracted` shape.
576 ///
577 /// As in [`Self::third_contracted`], the output is symmetric in `(i, j)`
578 /// (`t4[j][i][k][l] == t4[i][j][k][l]`, contracted in the same `(k, l)`
579 /// order), so the upper triangle `i ≤ j` is computed and mirrored —
580 /// BIT-IDENTICAL to the full nest, ~2× fewer inner `Σ_{k,l}` contractions,
581 /// and the inner double loop stays the original contiguous/vectorisable form.
582 pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
583 let mut out = [[0.0; K]; K];
584 for i in 0..K {
585 for j in i..K {
586 let mut acc = 0.0;
587 for k in 0..K {
588 for l in 0..K {
589 acc += self.t4[i][j][k][l] * u[k] * w[l];
590 }
591 }
592 out[i][j] = acc;
593 out[j][i] = acc;
594 }
595 }
596 out
597 }
598}
599
600impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
601 #[inline]
602 fn derivative(&self, labels: &[usize]) -> f64 {
603 self.deriv(labels)
604 }
605
606 fn map_derivatives<F>(&self, mut f: F) -> Self
607 where
608 F: FnMut(&[usize]) -> f64,
609 {
610 let mut out = Self::zero();
611 out.v = f(&[]);
612 for i in 0..K {
613 let labels = [i];
614 out.g[i] = f(&labels);
615 }
616 for i in 0..K {
617 for j in 0..K {
618 let labels = [i, j];
619 out.h[i][j] = f(&labels);
620 }
621 }
622 for i in 0..K {
623 for j in 0..K {
624 for k in 0..K {
625 let labels = [i, j, k];
626 out.t3[i][j][k] = f(&labels);
627 }
628 }
629 }
630 for i in 0..K {
631 for j in 0..K {
632 for k in 0..K {
633 for l in 0..K {
634 let labels = [i, j, k, l];
635 out.t4[i][j][k][l] = f(&labels);
636 }
637 }
638 }
639 }
640 out
641 }
642}
643
644/// Truncated SECOND-order multivariate Taylor scalar in `K` variables.
645///
646/// This is the value/gradient/Hessian-only sibling of [`Tower4`]. Every
647/// channel it carries (`v`, `g`, `h`) is computed by the SAME formulas
648/// [`Tower4`] uses for those orders, so for any program written over both
649/// towers the order-≤2 outputs are *bit-identical*: the order-2 Leibniz and
650/// Faà-di-Bruno terms read only the order-≤2 channels of their inputs (see
651/// [`Tower4::mul`] / [`Tower4::compose_unary`] — `out.h` never touches `t3`
652/// or `t4`), so dropping the third/fourth tensors cannot perturb the value,
653/// gradient, or Hessian.
654///
655/// It exists purely for performance: an inner Newton step (and the
656/// value-only ρ-homotopy pre-warm) needs at most curvature, never the
657/// outer-κ/ψ third/fourth derivatives. Evaluating a row likelihood over
658/// `Tower2` skips the `K⁴` fourth-tensor product/composition arithmetic that
659/// dominates the cold marginal-slope fit, while returning the exact same
660/// `(v, g, h)`.
661#[derive(Clone, Copy, Debug)]
662pub struct Tower2<const K: usize> {
663 /// Value ℓ.
664 pub v: f64,
665 /// Gradient ∂ℓ/∂p_a.
666 pub g: [f64; K],
667 /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
668 pub h: [[f64; K]; K],
669}
670
671impl<const K: usize> Tower2<K> {
672 /// The additive identity.
673 pub fn zero() -> Self {
674 Self {
675 v: 0.0,
676 g: [0.0; K],
677 h: [[0.0; K]; K],
678 }
679 }
680
681 /// A constant: value `c`, all derivatives zero.
682 pub fn constant(c: f64) -> Self {
683 let mut out = Self::zero();
684 out.v = c;
685 out
686 }
687
688 /// The seeded variable `p_idx` with current value `value`:
689 /// unit first derivative in slot `idx`, zero elsewhere and above.
690 pub fn variable(value: f64, idx: usize) -> Self {
691 let mut out = Self::constant(value);
692 out.g[idx] = 1.0;
693 out
694 }
695
696 /// Read the derivative tensor entry whose differentiation axes are
697 /// `labels` (length 0..=2): value, `g`, `h`.
698 #[inline]
699 fn deriv(&self, labels: &[usize]) -> f64 {
700 assert!(
701 labels.len() <= 2,
702 "Tower2 carries at most second-order derivatives"
703 );
704 match labels.len() {
705 0 => self.v,
706 1 => self.g[labels[0]],
707 _ => self.h[labels[0]][labels[1]],
708 }
709 }
710
711 /// Exact truncated (order ≤ 2) Leibniz product. The `v`/`g`/`h` channels
712 /// match [`Tower4::mul`] term-for-term.
713 pub fn mul(&self, o: &Self) -> Self {
714 let a = self;
715 let b = o;
716 let mut out = Self::zero();
717 out.v = a.v * b.v;
718 for i in 0..K {
719 out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
720 }
721 for i in 0..K {
722 for j in 0..K {
723 out.h[i][j] = a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
724 }
725 }
726 out
727 }
728
729 /// Exact (order ≤ 2) multivariate Faà di Bruno composition `f ∘ self`.
730 ///
731 /// `d = [f(u), f′(u), f″(u)]` evaluated at `u = self.v`. The `v`/`g`/`h`
732 /// channels match [`Tower4::compose_unary`] term-for-term (which uses only
733 /// `d[0..=2]` for those orders), so this is a strict truncation, not an
734 /// approximation. The full-order `[f64; 5]` derivative stacks the families
735 /// already produce can be passed by slicing their first three entries.
736 ///
737 /// # Codegen
738 ///
739 /// Order-≤2 Faà di Bruno is a tiny closed form, so this evaluates it
740 /// directly instead of routing through the generic
741 /// [`jet_algebra::faa_di_bruno`] set-partition walker (recursion + per-block
742 /// closure dispatch). That matters because this is the kernel under EVERY
743 /// packed scalar — [`crate::jet_scalar::Order2`] / `OneSeed` / `TwoSeed`
744 /// composition all bottom out here — so the straight-line form (whose inner
745 /// loops auto-vectorise to NEON/SSE 2-wide and which emits zero outlined
746 /// walker calls) lifts all of them at once.
747 ///
748 /// The term and accumulation order is BIT-IDENTICAL to the walker it
749 /// replaces: each output channel mirrors the walker's `total = 0.0` start
750 /// (the explicit `acc` accumulator), so a signed-zero product collapses to
751 /// `+0.0` exactly as `total += prod` does. Proven `to_bits`-identical on
752 /// `v`/`g`/`h` across `K ∈ {2,3,4,9}`, 5000 random inputs each (incl.
753 /// zeroed / sign-varied stacks). The order-≤2 walker partitions are:
754 /// `g[i]` = `f′·u_i` (single block `{i}`)
755 /// `h[i][j]` = `f′·u_ij + (f″·u_i)·u_j` (blocks `{ij}` then `{i}{j}`),
756 /// with `f′ = d[1]`, `f″ = d[2]`, `u_* = self.{g,h}`.
757 pub fn compose_unary(&self, d: [f64; 3]) -> Self {
758 let mut out = Self::zero();
759 out.v = d[0];
760 for i in 0..K {
761 let mut acc = 0.0;
762 acc += d[1] * self.g[i];
763 out.g[i] = acc;
764 }
765 for i in 0..K {
766 for j in 0..K {
767 let mut acc = 0.0;
768 acc += d[1] * self.h[i][j];
769 acc += d[2] * self.g[i] * self.g[j];
770 out.h[i][j] = acc;
771 }
772 }
773 out
774 }
775
776 /// Multiply every channel by a plain scalar.
777 pub fn scale(&self, s: f64) -> Self {
778 let mut out = *self;
779 out.v *= s;
780 for i in 0..K {
781 out.g[i] *= s;
782 for j in 0..K {
783 out.h[i][j] *= s;
784 }
785 }
786 out
787 }
788
789 /// e^self.
790 pub fn exp(&self) -> Self {
791 let e = self.v.exp();
792 self.compose_unary([e, e, e])
793 }
794
795 /// √self. Caller guarantees positivity.
796 pub fn sqrt(&self) -> Self {
797 let u = self.v;
798 let s = u.sqrt();
799 self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
800 }
801}
802
803impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
804 #[inline]
805 fn derivative(&self, labels: &[usize]) -> f64 {
806 self.deriv(labels)
807 }
808
809 fn map_derivatives<F>(&self, mut f: F) -> Self
810 where
811 F: FnMut(&[usize]) -> f64,
812 {
813 let mut out = Self::zero();
814 out.v = f(&[]);
815 for i in 0..K {
816 let labels = [i];
817 out.g[i] = f(&labels);
818 }
819 for i in 0..K {
820 for j in 0..K {
821 let labels = [i, j];
822 out.h[i][j] = f(&labels);
823 }
824 }
825 out
826 }
827}
828
829impl<const K: usize> std::ops::Add for Tower2<K> {
830 type Output = Self;
831 fn add(self, o: Self) -> Self {
832 let mut out = self;
833 out.v += o.v;
834 for i in 0..K {
835 out.g[i] += o.g[i];
836 for j in 0..K {
837 out.h[i][j] += o.h[i][j];
838 }
839 }
840 out
841 }
842}
843
844impl<const K: usize> std::ops::Mul for Tower2<K> {
845 type Output = Self;
846 fn mul(self, o: Self) -> Self {
847 Tower2::mul(&self, &o)
848 }
849}
850
851impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
852 type Output = Self;
853 fn add(self, c: f64) -> Self {
854 let mut out = self;
855 out.v += c;
856 out
857 }
858}
859
860impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
861 type Output = Self;
862 fn mul(self, c: f64) -> Self {
863 self.scale(c)
864 }
865}
866
867/// Truncated THIRD-order multivariate Taylor scalar in `K` variables.
868///
869/// The value/gradient/Hessian/third-derivative sibling of [`Tower4`], standing
870/// between [`Tower2`] and [`Tower4`]. Every channel it carries (`v`, `g`, `h`,
871/// `t3`) is computed by the SAME shared Leibniz / Faà-di-Bruno kernels
872/// [`Tower4`] uses for those orders, and the order-≤3 terms of those kernels
873/// read only the order-≤3 channels of their inputs (the order-3 Faà-di-Bruno
874/// partitions never reach the f⁗ stack slot or the inner `t4` tensor — see
875/// [`Tower4::compose_unary`]). So for any program written over both towers the
876/// order-≤3 outputs are *bit-identical*: dropping the fourth tensor cannot
877/// perturb the value, gradient, Hessian, or third derivatives.
878///
879/// It exists purely for performance, exactly like [`Tower2`]: a consumer that
880/// needs up to third derivatives (the survival location-scale row kernel reads
881/// `g`, the diagonal `h`, and the diagonal `t3`, but never `t4`) pays the
882/// `K³` third-tensor arithmetic but skips the `K⁴` fourth-tensor
883/// product/composition that otherwise dominates the per-row cost.
884#[derive(Clone, Copy, Debug)]
885pub struct Tower3<const K: usize> {
886 /// Value ℓ.
887 pub v: f64,
888 /// Gradient ∂ℓ/∂p_a.
889 pub g: [f64; K],
890 /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
891 pub h: [[f64; K]; K],
892 /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
893 pub t3: [[[f64; K]; K]; K],
894}
895
896impl<const K: usize> Tower3<K> {
897 /// The additive identity.
898 pub fn zero() -> Self {
899 Self {
900 v: 0.0,
901 g: [0.0; K],
902 h: [[0.0; K]; K],
903 t3: [[[0.0; K]; K]; K],
904 }
905 }
906
907 /// A constant: value `c`, all derivatives zero.
908 pub fn constant(c: f64) -> Self {
909 let mut out = Self::zero();
910 out.v = c;
911 out
912 }
913
914 /// The seeded variable `p_idx` with current value `value`:
915 /// unit first derivative in slot `idx`, zero elsewhere and above.
916 pub fn variable(value: f64, idx: usize) -> Self {
917 let mut out = Self::constant(value);
918 out.g[idx] = 1.0;
919 out
920 }
921
922 /// Read the (fully symmetric) derivative tensor entry whose differentiation
923 /// axes are `labels` (length 0..=3): value, `g`, `h`, `t3`.
924 #[inline]
925 fn deriv(&self, labels: &[usize]) -> f64 {
926 assert!(
927 labels.len() <= 3,
928 "Tower3 carries at most third-order derivatives"
929 );
930 match labels.len() {
931 0 => self.v,
932 1 => self.g[labels[0]],
933 2 => self.h[labels[0]][labels[1]],
934 _ => self.t3[labels[0]][labels[1]][labels[2]],
935 }
936 }
937
938 /// Exact truncated (order ≤ 3) Leibniz product. The `v`/`g`/`h`/`t3`
939 /// channels match [`Tower4::mul`] term-for-term.
940 ///
941 /// # Codegen
942 ///
943 /// Straight-line per-entry subset sums instead of the
944 /// [`jet_algebra::leibniz_product`] walker — the order-≤3 sibling of
945 /// [`Tower4::mul`] (no `t4`). Loop nest unchanged, no unroll over `K`, no
946 /// code bloat; auto-vectorises. BIT-IDENTICAL: terms in the walker's exact
947 /// subset order with an `acc = 0.0` accumulator start (load-bearing for the
948 /// signed-zero leading product on exact-`0.0` jet channels). Proven
949 /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
950 /// zero/sign-stressed inputs each (these channel formulas are exactly the
951 /// `g`/`h`/`t3` of the [`Tower4::mul`] oracle, which passes that stress).
952 pub fn mul(&self, o: &Self) -> Self {
953 let a = self;
954 let b = o;
955 let mut out = Self::zero();
956 out.v = a.v * b.v;
957 for i in 0..K {
958 let mut acc = 0.0;
959 acc += a.v * b.g[i];
960 acc += a.g[i] * b.v;
961 out.g[i] = acc;
962 }
963 for i in 0..K {
964 for j in 0..K {
965 let mut acc = 0.0;
966 acc += a.v * b.h[i][j];
967 acc += a.g[i] * b.g[j];
968 acc += a.g[j] * b.g[i];
969 acc += a.h[i][j] * b.v;
970 out.h[i][j] = acc;
971 }
972 }
973 for i in 0..K {
974 for j in 0..K {
975 for k in 0..K {
976 // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
977 let mut acc = 0.0;
978 acc += a.v * b.t3[i][j][k];
979 acc += a.g[i] * b.h[j][k];
980 acc += a.g[j] * b.h[i][k];
981 acc += a.h[i][j] * b.g[k];
982 acc += a.g[k] * b.h[i][j];
983 acc += a.h[i][k] * b.g[j];
984 acc += a.h[j][k] * b.g[i];
985 acc += a.t3[i][j][k] * b.v;
986 out.t3[i][j][k] = acc;
987 }
988 }
989 }
990 out
991 }
992
993 /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
994 /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
995 /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
996 /// out of the borrowed operands.
997 pub fn add(&self, o: &Self) -> Self {
998 *self + *o
999 }
1000
1001 /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
1002 pub fn sub(&self, o: &Self) -> Self {
1003 *self + o.scale(-1.0)
1004 }
1005
1006 /// Exact (order ≤ 3) multivariate Faà di Bruno composition `f ∘ self`.
1007 ///
1008 /// `d = [f(u), f′(u), f″(u), f‴(u)]` evaluated at `u = self.v`. The
1009 /// `v`/`g`/`h`/`t3` channels match [`Tower4::compose_unary`] term-for-term
1010 /// (which uses only `d[0..=3]` for those orders), so this is a strict
1011 /// truncation, not an approximation. The full-order `[f64; 5]` derivative
1012 /// stacks the families already produce can be passed by slicing their first
1013 /// four entries.
1014 ///
1015 /// # Codegen
1016 ///
1017 /// Order-≤3 Faà di Bruno written as a compact closed form instead of the
1018 /// recursive [`jet_algebra::faa_di_bruno`] walker — the order-≤2 sibling of
1019 /// [`Tower4::compose_unary`], one tensor order shallower. The loop nest is
1020 /// unchanged (no unroll over `K`, no code bloat: measured on a `Tower3<9>`
1021 /// compose-and-read consumer the new form is faster and SMALLER — asm: 71
1022 /// walker `bl` calls → 0, 39.5 KiB → 13.9 KiB, +197 NEON `.2d` ops).
1023 /// BIT-IDENTICAL: terms in the walker's exact partition order, left-
1024 /// associated block products, `acc = 0.0` accumulator start. Proven
1025 /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
1026 /// random inputs each.
1027 pub fn compose_unary(&self, d: [f64; 4]) -> Self {
1028 let mut out = Self::zero();
1029 out.v = d[0];
1030 for i in 0..K {
1031 let mut acc = 0.0;
1032 acc += d[1] * self.g[i];
1033 out.g[i] = acc;
1034 }
1035 for i in 0..K {
1036 for j in 0..K {
1037 let mut acc = 0.0;
1038 acc += d[1] * self.h[i][j];
1039 acc += d[2] * self.g[i] * self.g[j];
1040 out.h[i][j] = acc;
1041 }
1042 }
1043 for i in 0..K {
1044 for j in 0..K {
1045 for k in 0..K {
1046 // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
1047 let mut acc = 0.0;
1048 acc += d[1] * self.t3[i][j][k];
1049 acc += d[2] * self.h[i][j] * self.g[k];
1050 acc += d[2] * self.h[i][k] * self.g[j];
1051 acc += d[2] * self.g[i] * self.h[j][k];
1052 acc += d[3] * self.g[i] * self.g[j] * self.g[k];
1053 out.t3[i][j][k] = acc;
1054 }
1055 }
1056 }
1057 out
1058 }
1059
1060 /// Compose with a unary special-function whose `[f64; 4]` derivative STACK is
1061 /// built from the base value through `stack_fn` — the scalar arm of the
1062 /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
1063 /// [`Tower3Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
1064 /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
1065 /// `self.compose_unary(stack_fn(self.v))`. The order-≤3 sibling of
1066 /// [`Tower4::compose_unary_with`].
1067 #[inline]
1068 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
1069 self.compose_unary(stack_fn(self.v))
1070 }
1071
1072 /// Single-active-slot fast path for [`Self::compose_unary`] — the order-≤3
1073 /// sibling of [`Tower4::compose_unary_single_slot`]. When `self` carries
1074 /// derivative support only on the all-`slot` diagonal, every output channel
1075 /// touching an axis `≠ slot` collapses to the walker's `total = 0.0` start
1076 /// (`+0.0`), so only `v`, `g[slot]`, `h[slot][slot]`, `t3[slot]³` survive.
1077 /// These four are computed as STRAIGHT-LINE accumulations, each in the EXACT
1078 /// term order of [`Self::compose_unary`]'s diagonal (`i = j = k = slot`)
1079 /// case (BIT-IDENTICAL to the full path on the diagonal); off-`slot`
1080 /// channels stay at the zero-init `+0.0` the full walk also yields (proven
1081 /// `to_bits` across `K ∈ {2,3,4,9}`). This drops the recursive
1082 /// set-partition walker the diagonal channels previously routed through,
1083 /// recovering its measured ~5.9× regression at the `K ∈ {2,3}` BMS tower
1084 /// widths. Caller guarantees the single-slot precondition; otherwise use
1085 /// [`Self::compose_unary`].
1086 #[inline]
1087 pub fn compose_unary_single_slot(&self, d: [f64; 4], slot: usize) -> Self {
1088 let mut out = Self::zero();
1089 let s = slot;
1090 let g = self.g[s];
1091 let h = self.h[s][s];
1092 let t3 = self.t3[s][s][s];
1093 out.v = d[0];
1094 // g (i=s): d1*g
1095 out.g[s] = {
1096 let mut acc = 0.0;
1097 acc += d[1] * g;
1098 acc
1099 };
1100 // h (i=j=s): d1*h + d2*g*g
1101 out.h[s][s] = {
1102 let mut acc = 0.0;
1103 acc += d[1] * h;
1104 acc += d[2] * g * g;
1105 acc
1106 };
1107 // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
1108 out.t3[s][s][s] = {
1109 let mut acc = 0.0;
1110 acc += d[1] * t3;
1111 acc += d[2] * h * g;
1112 acc += d[2] * h * g;
1113 acc += d[2] * g * h;
1114 acc += d[3] * g * g * g;
1115 acc
1116 };
1117 out
1118 }
1119
1120 /// Multiply every channel by a plain scalar.
1121 pub fn scale(&self, s: f64) -> Self {
1122 let mut out = *self;
1123 out.v *= s;
1124 for i in 0..K {
1125 out.g[i] *= s;
1126 for j in 0..K {
1127 out.h[i][j] *= s;
1128 for k in 0..K {
1129 out.t3[i][j][k] *= s;
1130 }
1131 }
1132 }
1133 out
1134 }
1135}
1136
1137impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
1138 #[inline]
1139 fn derivative(&self, labels: &[usize]) -> f64 {
1140 self.deriv(labels)
1141 }
1142
1143 fn map_derivatives<F>(&self, mut f: F) -> Self
1144 where
1145 F: FnMut(&[usize]) -> f64,
1146 {
1147 let mut out = Self::zero();
1148 out.v = f(&[]);
1149 for i in 0..K {
1150 let labels = [i];
1151 out.g[i] = f(&labels);
1152 }
1153 for i in 0..K {
1154 for j in 0..K {
1155 let labels = [i, j];
1156 out.h[i][j] = f(&labels);
1157 }
1158 }
1159 for i in 0..K {
1160 for j in 0..K {
1161 for k in 0..K {
1162 let labels = [i, j, k];
1163 out.t3[i][j][k] = f(&labels);
1164 }
1165 }
1166 }
1167 out
1168 }
1169}
1170
1171impl<const K: usize> std::ops::Add for Tower3<K> {
1172 type Output = Self;
1173 fn add(self, o: Self) -> Self {
1174 let mut out = self;
1175 out.v += o.v;
1176 for i in 0..K {
1177 out.g[i] += o.g[i];
1178 for j in 0..K {
1179 out.h[i][j] += o.h[i][j];
1180 for k in 0..K {
1181 out.t3[i][j][k] += o.t3[i][j][k];
1182 }
1183 }
1184 }
1185 out
1186 }
1187}
1188
1189pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
1190 [
1191 statrs::function::gamma::ln_gamma(x),
1192 digamma_positive(x),
1193 polygamma_positive(1, x),
1194 polygamma_positive(2, x),
1195 polygamma_positive(3, x),
1196 ]
1197}
1198
1199pub fn ln_gamma_derivative_stack_order2(x: f64) -> [f64; 3] {
1200 [
1201 statrs::function::gamma::ln_gamma(x),
1202 digamma_positive(x),
1203 polygamma_positive(1, x),
1204 ]
1205}
1206
1207pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
1208 [
1209 digamma_positive(x),
1210 polygamma_positive(1, x),
1211 polygamma_positive(2, x),
1212 polygamma_positive(3, x),
1213 polygamma_positive(4, x),
1214 ]
1215}
1216
1217pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
1218 [
1219 polygamma_positive(1, x),
1220 polygamma_positive(2, x),
1221 polygamma_positive(3, x),
1222 polygamma_positive(4, x),
1223 polygamma_positive(5, x),
1224 ]
1225}
1226
1227fn digamma_positive(mut x: f64) -> f64 {
1228 if !(x.is_finite() && x > 0.0) {
1229 return f64::NAN;
1230 }
1231 let mut acc = 0.0;
1232 while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1233 acc -= 1.0 / x;
1234 x += 1.0;
1235 }
1236 acc + digamma_asymptotic(x)
1237}
1238
1239fn polygamma_positive(order: usize, mut x: f64) -> f64 {
1240 if !(x.is_finite() && x > 0.0) {
1241 return f64::NAN;
1242 }
1243 let mut acc = 0.0;
1244 while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1245 acc += polygamma_recurrence_term(order, x);
1246 x += 1.0;
1247 }
1248 acc + polygamma_asymptotic(order, x)
1249}
1250
1251const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
1252const BERNOULLI_EVEN: [(usize, f64); 10] = [
1253 (2, 1.0 / 6.0),
1254 (4, -1.0 / 30.0),
1255 (6, 1.0 / 42.0),
1256 (8, -1.0 / 30.0),
1257 (10, 5.0 / 66.0),
1258 (12, -691.0 / 2730.0),
1259 (14, 7.0 / 6.0),
1260 (16, -3617.0 / 510.0),
1261 (18, 43867.0 / 798.0),
1262 (20, -174611.0 / 330.0),
1263];
1264
1265fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
1266 let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1267 sign * factorial(order) / x.powi((order + 1) as i32)
1268}
1269
1270fn digamma_asymptotic(x: f64) -> f64 {
1271 let mut out = x.ln() - 0.5 / x;
1272 for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1273 out -= bernoulli / (bernoulli_order as f64 * x.powi(bernoulli_order as i32));
1274 }
1275 out
1276}
1277
1278fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
1279 if !(1..=5).contains(&order) {
1280 return f64::NAN;
1281 }
1282
1283 let order_factorial = factorial(order);
1284 let leading_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1285 let mut out = leading_sign * factorial(order - 1) / x.powi(order as i32)
1286 + leading_sign * order_factorial / (2.0 * x.powi((order + 1) as i32));
1287
1288 let bernoulli_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1289 for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1290 let rising = rising_factorial(bernoulli_order, order);
1291 out += bernoulli_sign * bernoulli * rising
1292 / bernoulli_order as f64
1293 / x.powi((bernoulli_order + order) as i32);
1294 }
1295 out
1296}
1297
1298fn factorial(n: usize) -> f64 {
1299 (1..=n).fold(1.0, |acc, k| acc * k as f64)
1300}
1301
1302fn rising_factorial(start: usize, len: usize) -> f64 {
1303 (start..start + len).fold(1.0, |acc, k| acc * k as f64)
1304}
1305
1306impl<const K: usize> std::ops::Add for Tower4<K> {
1307 type Output = Self;
1308 fn add(self, o: Self) -> Self {
1309 let mut out = self;
1310 out.v += o.v;
1311 for i in 0..K {
1312 out.g[i] += o.g[i];
1313 for j in 0..K {
1314 out.h[i][j] += o.h[i][j];
1315 for k in 0..K {
1316 out.t3[i][j][k] += o.t3[i][j][k];
1317 for l in 0..K {
1318 out.t4[i][j][k][l] += o.t4[i][j][k][l];
1319 }
1320 }
1321 }
1322 }
1323 out
1324 }
1325}
1326
1327impl<const K: usize> std::ops::Sub for Tower4<K> {
1328 type Output = Self;
1329 fn sub(self, o: Self) -> Self {
1330 self + o.scale(-1.0)
1331 }
1332}
1333
1334impl<const K: usize> std::ops::Neg for Tower4<K> {
1335 type Output = Self;
1336 fn neg(self) -> Self {
1337 self.scale(-1.0)
1338 }
1339}
1340
1341impl<const K: usize> std::ops::Mul for Tower4<K> {
1342 type Output = Self;
1343 fn mul(self, o: Self) -> Self {
1344 Tower4::mul(&self, &o)
1345 }
1346}
1347
1348impl<const K: usize> std::ops::Div for Tower4<K> {
1349 type Output = Self;
1350 fn div(self, o: Self) -> Self {
1351 Tower4::mul(&self, &o.recip())
1352 }
1353}
1354
1355impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
1356 type Output = Self;
1357 fn add(self, c: f64) -> Self {
1358 let mut out = self;
1359 out.v += c;
1360 out
1361 }
1362}
1363
1364impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
1365 type Output = Self;
1366 fn sub(self, c: f64) -> Self {
1367 self + (-c)
1368 }
1369}
1370
1371impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
1372 type Output = Self;
1373 fn mul(self, c: f64) -> Self {
1374 self.scale(c)
1375 }
1376}
1377
1378// ── Implicit-function and moving-boundary seams (#932 flex) ──────────
1379//
1380// The flexible survival marginal-slope row loss is NOT a free composition
1381// of the primaries: it threads an IMPLICIT calibration intercept `a(θ)`
1382// solving a constraint `F(a, θ) = 0`, and integrates a density over cells
1383// whose edges `z_L(θ), z_R(θ)` MOVE with θ through that intercept. Plain
1384// `Tower4` Faà di Bruno cannot express either — so the flex tower was the
1385// last hand-written one in the codebase, and the genus of #736-class
1386// drift bugs (the (g,w0) deviation-cross third was 3× short for exactly
1387// this reason). These two combinators close that gap: once the constraint
1388// `F` and the integrand/boundaries are themselves towers, the intercept's
1389// derivative tower and the integral's derivative tower come out EXACTLY at
1390// every order — there is no order left to hand-code and forget.
1391
1392/// Solve the implicit relation `F(a(θ), θ) ≡ 0` for the intercept tower
1393/// `a(θ)` over the `K` primaries θ, given the constraint tower `f` written
1394/// over `K + 1` variables (slot `0` is the intercept `a`, slots `1..=K`
1395/// are the primaries θ) evaluated at the SOLVED point — i.e. `f.v` is the
1396/// constraint residual at `(a₀, θ₀)` (≈ 0 from the production Newton solve)
1397/// and `a0` is that solved intercept value.
1398///
1399/// Returns the `Tower4<K>` whose value is `a0` and whose every derivative
1400/// tensor (∂a/∂θ, ∂²a/∂θ², …, ∂⁴a/∂θ⁴) is the exact implicit-function
1401/// derivative. This is the mechanical replacement for the hand-coded
1402/// `a_u = -f_u/f_a`, `a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a`
1403/// recursion (first_full.rs) and its third/fourth-order continuations.
1404///
1405/// Method: order-by-order substitution. We build `a` incrementally; at each
1406/// order `m` the composite `G(θ) = f(a(θ), θ)` has a top-order coefficient
1407/// that is linear in `a`'s order-`m` tensor with leading factor `F_a`
1408/// (= `f.g[0]`), plus terms in `a`'s lower orders already fixed. Setting the
1409/// order-`m` tensor of `a` to cancel the rest of `G`'s order-`m` coefficient
1410/// keeps `G ≡ 0` through that order. The substitution `G = f∘(a, θ)` reuses
1411/// only the exact [`substitute_intercept`] chain rule, so the recursion is
1412/// auditable and exact, not a hand-expanded formula per order.
1413///
1414/// `f.g[0]` (= ∂F/∂a) must be non-zero — guaranteed by the production
1415/// solve's strict monotonicity guard.
1416///
1417/// The expansion point `a0` must be a genuine root `F(a0, θ0) = 0`: the
1418/// substitution recursion below cancels orders 1..=4 of `G = F∘a` but never
1419/// touches order 0, so a non-root `a0` would yield the Taylor expansion of
1420/// the LEVEL SET `F = F(a0)` through `a0`, not the root curve `F = 0`. This
1421/// is guarded explicitly and re-verified by a composed-residual self-check.
1422pub fn implicit_solve<const K1: usize, const K: usize>(
1423 f: &Tower4<K1>,
1424 a0: f64,
1425) -> Result<Tower4<K>, String> {
1426 assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
1427 let f_a = f.g[0];
1428 if f_a == 0.0 || !f_a.is_finite() {
1429 return Err(format!(
1430 "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
1431 ));
1432 }
1433 // The expansion point must be a genuine root of F. The single Newton
1434 // correction that would move a0 onto the root is |f.v|/|f_a|; require it
1435 // to be negligible relative to the natural scale (1 + |a0|). Guarding the
1436 // Newton step (rather than f.v directly) makes the criterion invariant to
1437 // the magnitude of f_a / the units of F.
1438 let root_tol = 1e-9;
1439 if !f.v.is_finite() {
1440 return Err(format!(
1441 "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
1442 f.v
1443 ));
1444 }
1445 let newton_step = f.v.abs() / f_a.abs();
1446 if newton_step > root_tol * (1.0 + a0.abs()) {
1447 return Err(format!(
1448 "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
1449 F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
1450 root_tol {root_tol:.1e} · (1 + |a0|)",
1451 f.v
1452 ));
1453 }
1454 // Start with a = constant a0 (correct through order 0). Then lift each
1455 // order in turn. Because substitute_intercept reads `a`'s order-≤m
1456 // tensors when forming G's order-m coefficient, and the order-m
1457 // coefficient of G depends on a's order-m tensor ONLY through the linear
1458 // F_a·a_m term, a single corrective pass per order is exact.
1459 let mut a = Tower4::<K>::constant(a0);
1460 for order in 1..=4 {
1461 let g = substitute_intercept(f, &a);
1462 // Cancel G's order-`order` coefficient by adjusting a's order-`order`
1463 // tensor: a_m -= G_m / F_a (the F_a·a_m term is the only one carrying
1464 // a's order-m tensor, with unit chain coefficient since slot 0 seeds a
1465 // as a plain variable in the substitution's first-order part).
1466 match order {
1467 1 => {
1468 for i in 0..K {
1469 a.g[i] -= g.g[i] / f_a;
1470 }
1471 }
1472 2 => {
1473 for i in 0..K {
1474 for j in 0..K {
1475 a.h[i][j] -= g.h[i][j] / f_a;
1476 }
1477 }
1478 }
1479 3 => {
1480 for i in 0..K {
1481 for j in 0..K {
1482 for k in 0..K {
1483 a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
1484 }
1485 }
1486 }
1487 }
1488 _ => {
1489 for i in 0..K {
1490 for j in 0..K {
1491 for k in 0..K {
1492 for l in 0..K {
1493 a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
1494 }
1495 }
1496 }
1497 }
1498 }
1499 }
1500 }
1501 // Self-check: the composed residual G = F∘a must vanish through order 4.
1502 // By construction orders 1..=4 were cancelled; the value G.v == F(a0,θ0)
1503 // is exactly the root requirement guarded above. Re-verify all channels
1504 // against a scale-aware floor so any arithmetic regression in the
1505 // substitution recursion is loud rather than silently shipping a
1506 // level-set expansion.
1507 let g = substitute_intercept(f, &a);
1508 let resid_tol = 1e-7 * (1.0 + f_a.abs());
1509 let mut worst = g.v.abs();
1510 for i in 0..K {
1511 worst = worst.max(g.g[i].abs());
1512 for j in 0..K {
1513 worst = worst.max(g.h[i][j].abs());
1514 for k in 0..K {
1515 worst = worst.max(g.t3[i][j][k].abs());
1516 for l in 0..K {
1517 worst = worst.max(g.t4[i][j][k][l].abs());
1518 }
1519 }
1520 }
1521 }
1522 if !worst.is_finite() || worst > resid_tol {
1523 return Err(format!(
1524 "implicit_solve: composed residual G = F∘a does not vanish: \
1525 worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
1526 ));
1527 }
1528 Ok(a)
1529}
1530
1531/// Substitute the intercept tower `a(θ)` into slot `0` of a constraint
1532/// written over `K + 1` variables, returning the composite tower over the
1533/// `K` primaries θ: `G(θ) = f(a(θ), θ₁, …, θ_K)`.
1534///
1535/// This is the exact multivariate chain rule specialised to "slot 0 is a
1536/// dependent tower, slots 1..=K are the independent primaries". It evaluates
1537/// `f`'s fourth-order multivariate Taylor polynomial about the expansion
1538/// point, with the slot-0 increment being the non-constant part of `a` and
1539/// the slot-(i) increment being the unit-seeded primary `θ_i`. The sum is
1540/// assembled by the same subset/partition algebra `Tower4` arithmetic uses,
1541/// so it carries derivatives exactly through order four.
1542pub fn substitute_intercept<const K1: usize, const K: usize>(
1543 f: &Tower4<K1>,
1544 a: &Tower4<K>,
1545) -> Tower4<K> {
1546 assert_eq!(K1, K + 1);
1547 // Build the K+1 input towers in θ-space: slot 0 = a(θ), slot i+1 = θ_i.
1548 // The composite is Σ over ordered label tuples s (|s| ≤ 4) of input
1549 // indices: (1/|s|!) · f.deriv(s) · Π_{j in s} (inp[s_j] centred) — but
1550 // since f.deriv is the SYMMETRIC partial tensor and we enumerate ordered
1551 // tuples, the 1/|s|! exactly cancels the tuple multiplicity. We assemble
1552 // it directly as a Horner-free explicit sum over the (K+1)-ary tuples,
1553 // using tower products for the increment monomials so all θ-derivatives
1554 // propagate exactly.
1555 let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
1556 if slot == 0 {
1557 // slot 0: a(θ) minus its constant value (the increment δa(θ)).
1558 let mut d = *a;
1559 d.v = 0.0;
1560 d
1561 } else {
1562 // slot i: the increment δθ_{i-1} = seeded variable minus value.
1563 // θ centred at its expansion value has zero constant term and unit
1564 // first derivative in its own slot.
1565 let mut d = Tower4::<K>::zero();
1566 d.g[slot - 1] = 1.0;
1567 d
1568 }
1569 });
1570 // Accumulate the Taylor sum. order-0 term:
1571 let mut out = Tower4::<K>::constant(f.v);
1572 // order 1: Σ_a f.g[a] · inp[a]
1573 for a_idx in 0..K1 {
1574 out = out + inp[a_idx].scale(f.g[a_idx]);
1575 }
1576 // order 2: (1/2) Σ_{a,b} f.h[a][b] · inp[a]·inp[b]
1577 for a_idx in 0..K1 {
1578 for b_idx in 0..K1 {
1579 let prod = inp[a_idx].mul(&inp[b_idx]);
1580 out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
1581 }
1582 }
1583 // order 3: (1/6) Σ f.t3[a][b][c] · inp[a]·inp[b]·inp[c]
1584 for a_idx in 0..K1 {
1585 for b_idx in 0..K1 {
1586 for c_idx in 0..K1 {
1587 let prod = inp[a_idx].mul(&inp[b_idx]).mul(&inp[c_idx]);
1588 out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
1589 }
1590 }
1591 }
1592 // order 4: (1/24) Σ f.t4[a][b][c][d] · inp[a]·inp[b]·inp[c]·inp[d]
1593 for a_idx in 0..K1 {
1594 for b_idx in 0..K1 {
1595 for c_idx in 0..K1 {
1596 for d_idx in 0..K1 {
1597 let prod = inp[a_idx]
1598 .mul(&inp[b_idx])
1599 .mul(&inp[c_idx])
1600 .mul(&inp[d_idx]);
1601 out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
1602 }
1603 }
1604 }
1605 }
1606 out
1607}
1608
1609/// The exact θ-derivative tower of a moving-LIMIT integral's BOUNDARY
1610/// contribution: given the edge-position tower `z_edge(θ)` over the `K`
1611/// primaries and the integrand `B` evaluated-and-differentiated at the edge
1612/// value as the stack `b_stack = [B(z₀), B′(z₀), B″(z₀), B‴(z₀)]`
1613/// (`z₀ = z_edge.v`), returns the tower of `Φ(z_edge(θ))` where `Φ′ = B`.
1614///
1615/// Rationale: `∂_θ ∫^{z_edge(θ)} B(z) dz = Φ(z_edge(θ))` with `Φ` an
1616/// antiderivative of `B`, so the boundary part of every θ-derivative of the
1617/// integral is just the composition `Φ ∘ z_edge` — whose Faà di Bruno
1618/// expansion carries, at one stroke, EVERY Leibniz boundary term the
1619/// hand-written flux dropped: the first-order `B·z_u`, the second-order
1620/// `B′·z_u·z_v + B·z_uv` (the `G_z·z_u·z_v` self-flux AND the previously
1621/// dropped `G·z_uv`), and the full third/fourth-order continuations. The
1622/// VALUE channel of the returned tower is meaningless (`Φ` is only defined up
1623/// to a constant); callers read only the derivative channels and pair this
1624/// with the interior moment-integral value separately.
1625///
1626/// `b_stack` holds `B` and its first three z-derivatives; the antiderivative
1627/// `Φ` contributes only as the order-≥1 channels, so `compose_unary` receives
1628/// `[0, B, B′, B″, B‴]` — the leading `0` is the discarded `Φ(z₀)` slot.
1629pub fn moving_limit_boundary_tower<const K: usize>(
1630 z_edge: &Tower4<K>,
1631 b_stack: [f64; 4],
1632) -> Tower4<K> {
1633 z_edge.compose_unary([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]])
1634}
1635
1636/// The boundary-flux derivative tower of a single moving cell integral
1637/// `∫_{z_L(θ)}^{z_R(θ)} B dz`: `Φ(z_R(θ)) − Φ(z_L(θ))`, assembled from the
1638/// two edge towers and the integrand stacks at each edge. The returned
1639/// tower's derivative channels are the EXACT moving-boundary contribution to
1640/// every θ-derivative of the cell integral, to fourth order, with no term
1641/// hand-omitted. A `Fixed` (non-moving) edge passes a `z_edge` whose
1642/// derivative channels are all zero, contributing nothing — matching the
1643/// production `edge_vel = 0` short-circuit.
1644pub fn cell_moving_boundary_flux_tower<const K: usize>(
1645 z_right: &Tower4<K>,
1646 b_stack_right: [f64; 4],
1647 z_left: &Tower4<K>,
1648 b_stack_left: [f64; 4],
1649) -> Tower4<K> {
1650 moving_limit_boundary_tower(z_right, b_stack_right)
1651 - moving_limit_boundary_tower(z_left, b_stack_left)
1652}
1653
1654/// Moving-limit boundary tower for a θ-DEPENDENT integrand `G(z; θ)`.
1655///
1656/// [`moving_limit_boundary_tower`] assumes the integrand depends on θ only
1657/// through the moving edge `z_edge(θ)` (a fixed z-derivative `b_stack`). The
1658/// marginal-slope flex boundary is richer: the integrand `G(z; θ)` ALSO carries
1659/// its own θ-dependence (the density weight `w = e^{−q}/2π` and the cell
1660/// integrand coefficients move with η, hence with the primaries), so the
1661/// Leibniz expansion of `∂ⁿ_θ ∫^{z_edge(θ)} G(z;θ) dz` mixes edge-motion
1662/// derivatives of the limit with θ-derivatives of `G` itself — e.g. at second
1663/// order `G·z_uv + G_z·z_u·z_v + G_{θu}·z_v + G_{θv}·z_u` (the four
1664/// edge-motion-carrying terms the hand path assembles one by one, including the
1665/// `G·z_uv` term the directional path drops).
1666///
1667/// Mechanization: let `Φ(z; θ)` be the z-antiderivative of `G` (so `Φ_z = G`).
1668/// The full upper-limit contribution is `Φ(z_edge(θ); θ)`, and the BOUNDARY
1669/// part — everything carrying edge motion — is exactly
1670/// `Φ(z_edge(θ); θ) − Φ(z₀; θ)`,
1671/// the second term being the pure-integrand-θ part (`∫^{z₀} ∂ⁿ_θ G`) the
1672/// interior moment integral already supplies. Both are one
1673/// [`substitute_intercept`] of the SAME mixed `(z, θ)` jet of `Φ` (z in slot 0,
1674/// θ in slots 1..K): substituting the edge tower gives the full composite,
1675/// substituting a frozen constant edge isolates the pure-θ part, and their
1676/// difference is the exact boundary flux — every Leibniz term derived by the
1677/// substitution algebra, none hand-omitted.
1678///
1679/// `phi_jet` is the `(K+1)`-variable Taylor jet of `Φ` about `(z₀, θ₀)` with
1680/// `z₀ = z_edge.v`: slot 0 is the z-direction (so `phi_jet.g[0] = G(z₀;θ₀)`,
1681/// `phi_jet.h[0][0] = G_z`, …) and slots `1..=K` are the primaries θ (carrying
1682/// `Φ`'s own θ- and mixed z·θ-derivatives — i.e. the integrand's θ-derivatives
1683/// integrated in z, and `G_{θ…}` in the mixed slots). The returned tower's
1684/// VALUE channel is 0 by construction (the `Φ(z₀;θ₀)` constants cancel); only
1685/// the derivative channels are meaningful, matching the value-less convention of
1686/// [`moving_limit_boundary_tower`].
1687pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
1688 phi_jet: &Tower4<K1>,
1689 z_edge: &Tower4<K>,
1690) -> Tower4<K> {
1691 assert_eq!(
1692 K1,
1693 K + 1,
1694 "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
1695 );
1696 let frozen_edge = Tower4::<K>::constant(z_edge.v);
1697 let full = substitute_intercept(phi_jet, z_edge);
1698 let interior = substitute_intercept(phi_jet, &frozen_edge);
1699 full - interior
1700}
1701
1702/// Two-edge cell version of [`moving_limit_boundary_tower_theta_integrand`]:
1703/// the exact boundary-flux tower of `∫_{z_L(θ)}^{z_R(θ)} G(z;θ) dz` with a
1704/// θ-dependent integrand, `Φ(z_R;θ) − Φ(z_L;θ)` minus the pure-θ parts at each
1705/// frozen edge. A `Fixed` edge passes a `z_edge` with zero derivative channels,
1706/// so its `full` and `interior` substitutions coincide and it contributes
1707/// nothing — matching the production `edge_vel = 0` short-circuit.
1708pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
1709 phi_jet_right: &Tower4<K1>,
1710 z_right: &Tower4<K>,
1711 phi_jet_left: &Tower4<K1>,
1712 z_left: &Tower4<K>,
1713) -> Tower4<K> {
1714 moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right)
1715 - moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left)
1716}
1717
1718// ── The program seam ─────────────────────────────────────────────────
1719
1720/// A family's row negative log-likelihood written ONCE over tower scalars.
1721///
1722/// This is the single source of truth #932 asks for: the value channel of
1723/// the returned tower must BE the production row NLL (same branches, same
1724/// guards, same numerics), and every derivative channel is then exact by
1725/// construction. The linear Jacobian wiring (coefficients ↔ primaries) is
1726/// NOT part of this trait — it is family data, not calculus, and stays on
1727/// the `RowKernel` implementor.
1728pub trait RowNllProgram<const K: usize>: Send + Sync {
1729 /// Number of observations the program covers.
1730 fn n_rows(&self) -> usize;
1731
1732 /// Current primary-scalar values for `row` (where to seed the tower).
1733 fn primaries(&self, row: usize) -> Result<[f64; K], String>;
1734
1735 /// The row NLL evaluated on tower scalars. `p[a]` arrives pre-seeded as
1736 /// variable `a` at the current primary value; implementations combine
1737 /// them with `Tower4` arithmetic and per-row data (response, censoring
1738 /// indicators, offsets) entering as constants.
1739 fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
1740}
1741
1742/// Evaluate a program's full tower at the current primaries for one row.
1743///
1744/// One call yields every `RowKernel` calculus channel; callers that need
1745/// several contractions of the same row should hold the returned tower and
1746/// contract repeatedly rather than re-evaluating.
1747pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
1748 prog: &P,
1749 row: usize,
1750) -> Result<Tower4<K>, String> {
1751 let p = prog.primaries(row)?;
1752 let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
1753 prog.row_nll(row, &vars)
1754}
1755
1756/// Mechanically derived `row_kernel` channel: `(nll, ∇, H)`.
1757pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
1758 prog: &P,
1759 row: usize,
1760) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1761 let t = evaluate_program(prog, row)?;
1762 Ok((t.v, t.g, t.h))
1763}
1764
1765/// Mechanically derived `row_third_contracted` channel.
1766pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1767 prog: &P,
1768 row: usize,
1769 dir: &[f64; K],
1770) -> Result<[[f64; K]; K], String> {
1771 Ok(evaluate_program(prog, row)?.third_contracted(dir))
1772}
1773
1774/// Mechanically derived `row_fourth_contracted` channel.
1775pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1776 prog: &P,
1777 row: usize,
1778 dir_u: &[f64; K],
1779 dir_v: &[f64; K],
1780) -> Result<[[f64; K]; K], String> {
1781 Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
1782}
1783
1784// ── The generic program seam (#932 scalar cutover) ───────────────────
1785
1786/// A family's row negative log-likelihood written ONCE over the generic
1787/// [`crate::jet_scalar::JetScalar`] interface, so the SAME expression can be
1788/// re-instantiated at whatever order / representation a consumer needs
1789/// ([`crate::jet_scalar::Order2`] for `(v, g, H)`,
1790/// [`crate::jet_scalar::OneSeed`] for the contracted third,
1791/// [`crate::jet_scalar::TwoSeed`] for the contracted fourth, or the full
1792/// [`Tower4`] for every channel at once).
1793///
1794/// This is additive to [`RowNllProgram`] (which is `Tower4`-specialised): a
1795/// program implementing this generic trait gets the small contracted scalars for
1796/// free, dissolving the dense-`Tower4<9>` cost objection in the location-scale
1797/// gates (doc §A.4). An existing `Tower4`-only [`RowNllProgram`] continues to
1798/// work unchanged; new families should prefer this generic trait.
1799///
1800/// Because a `Tower4`-specialised `row_nll` body uses only
1801/// `add`/`sub`/`mul`/`scale`/`exp`/`ln`/… — all of which this trait also
1802/// provides — the same body is expressible directly over `S: JetScalar<K>`.
1803/// A program written that way needs no `Tower4`-specialised method and routes
1804/// the directional and joint-Hessian gates through the contracted scalars from
1805/// a single definition.
1806pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
1807 /// Number of observations the program covers.
1808 fn n_rows(&self) -> usize;
1809
1810 /// Current primary-scalar values for `row` (where to seed the scalar).
1811 fn primaries(&self, row: usize) -> Result<[f64; K], String>;
1812
1813 /// The row NLL evaluated on a generic jet scalar. `p[a]` arrives pre-seeded
1814 /// (base value + per-scalar nilpotent directions) by the caller; the body
1815 /// uses ONLY [`crate::jet_scalar::JetScalar`] ops and per-row data
1816 /// (response, censoring, offsets) entering as constants.
1817 fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
1818 &self,
1819 row: usize,
1820 p: &[S; K],
1821 ) -> Result<S, String>;
1822}
1823
1824/// Evaluate a generic program at the value/gradient/Hessian scalar
1825/// [`crate::jet_scalar::Order2`], returning `(nll, ∇, H)` — the
1826/// `row_kernel` channel — WITHOUT materialising any third / fourth tensor.
1827///
1828/// This is the production seam for the inner-Newton `(v, g, H)` path: the row
1829/// loss is written ONCE in `row_nll_generic`, and this routes it through the
1830/// cheap order-2 scalar. The single source of truth means the gradient and
1831/// Hessian cannot desync from the value (the #736 / #948 bug genus).
1832pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1833 prog: &P,
1834 row: usize,
1835) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1836 let base = prog.primaries(row)?;
1837 let vars: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|a| {
1838 <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(base[a], a)
1839 });
1840 let s = prog.row_nll_generic(row, &vars)?;
1841 Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
1842}
1843
1844/// Evaluate a generic program at the one-seed scalar
1845/// [`crate::jet_scalar::OneSeed`], returning the contracted third
1846/// `Σ_c ℓ_{abc} dir_c` — the `row_third_contracted(dir)` channel — WITHOUT
1847/// materialising the dense `t3` tensor. The contraction direction is folded
1848/// INTO the differentiation by the nilpotent ε seeded with `dir`.
1849pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1850 prog: &P,
1851 row: usize,
1852 dir: &[f64; K],
1853) -> Result<[[f64; K]; K], String> {
1854 let base = prog.primaries(row)?;
1855 let vars: [crate::jet_scalar::OneSeed<K>; K] =
1856 std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
1857 let s = prog.row_nll_generic(row, &vars)?;
1858 Ok(s.contracted_third())
1859}
1860
1861/// Evaluate a generic program at the two-seed scalar
1862/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
1863/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `row_fourth_contracted(u, v)` channel —
1864/// WITHOUT materialising the dense `t4` tensor.
1865pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1866 prog: &P,
1867 row: usize,
1868 dir_u: &[f64; K],
1869 dir_v: &[f64; K],
1870) -> Result<[[f64; K]; K], String> {
1871 let base = prog.primaries(row)?;
1872 let vars: [crate::jet_scalar::TwoSeed<K>; K] =
1873 std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
1874 let s = prog.row_nll_generic(row, &vars)?;
1875 Ok(s.contracted_fourth())
1876}
1877
1878/// Evaluate a generic program at the full dense [`Tower4`] scalar, returning
1879/// every channel `(v, g, h, t3, t4)` in one pass. Used where the UNCONTRACTED
1880/// third / fourth tensors are needed (the BMS rigid `third_full` / `fourth_full`
1881/// caches): the dense tensors come from the SAME `row_nll_generic` expression
1882/// the order-2 / contracted scalars consume, so there is a single source of
1883/// truth across every channel.
1884pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1885 prog: &P,
1886 row: usize,
1887) -> Result<Tower4<K>, String> {
1888 let base = prog.primaries(row)?;
1889 let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(base[a], a));
1890 prog.row_nll_generic(row, &vars)
1891}
1892
1893// ── The RowJet bridge: one row-NLL body over scalar jets AND lane towers ─
1894//
1895// `JetScalar<K>` (jet_scalar.rs) abstracts the SCALAR jets — its `value()`
1896// returns one `f64`, so the `f64x4` lane towers ([`Tower3Lane`] / [`Tower4Lane`])
1897// CANNOT implement it (their value channel is four rows). `compose_unary_with`
1898// exists as an inherent method on BOTH the scalar towers and the lane towers, but
1899// as separate inherent methods, not a shared trait bound — so a row-NLL body
1900// written `<S: JetScalar<K>>` could not be instantiated at `Tower4Lane`, and the
1901// 4-rows-per-pass SIMD batch path could not reuse the single source.
1902//
1903// [`RowJet<K>`] is that shared bound. It exposes exactly the ops a row-NLL body
1904// needs — `constant` / `variable` / `add` / `sub` / `mul` / `scale` / `neg`, the
1905// value-derived `compose_unary_with`, and a per-lane domain `guard` — over BOTH
1906// representations. A blanket impl makes every scalar `JetScalar<K>` a `RowJet<K>`
1907// (so the scalar call sites compile unchanged and bit-identically), and explicit
1908// impls route the `f64x4` lane towers through their existing per-lane methods. A
1909// body written once over `R: RowJet<K>` then instantiates at a scalar jet for the
1910// `(v, g, H)` / contracted-tensor channels AND at a lane tower for the batch.
1911
1912/// The verdict of a per-lane [`RowJet::guard`] domain check.
1913///
1914/// A scalar jet (a [`crate::jet_scalar::JetScalar`] via the blanket impl) carries
1915/// ONE value, so it reports `lanes == 1` and a one-bit mask. A lane tower
1916/// ([`Tower3Lane`] / [`Tower4Lane`] over `f64x4`) carries FOUR rows, so it reports
1917/// `lanes == 4` and one mask bit per lane. The mask lets a batched program bail
1918/// exactly the offending 4-group to the scalar tail ([`any_failed`](Self::any_failed)),
1919/// or inspect which lanes tripped ([`lane_failed`](Self::lane_failed)).
1920#[derive(Clone, Copy, Debug, PartialEq, Eq)]
1921pub struct GuardVerdict {
1922 lanes: u8,
1923 failed_mask: u8,
1924}
1925
1926impl GuardVerdict {
1927 /// A scalar (1-lane) verdict: `pass == true` ⇒ no failure.
1928 #[inline]
1929 pub fn scalar(pass: bool) -> Self {
1930 Self { lanes: 1, failed_mask: if pass { 0 } else { 1 } }
1931 }
1932 /// A 4-lane verdict from a per-lane failure mask (bit `i` ⇒ lane `i` failed).
1933 #[inline]
1934 pub fn lanes4(failed_mask: u8) -> Self {
1935 Self { lanes: 4, failed_mask: failed_mask & 0x0f }
1936 }
1937 /// Number of active lanes inspected (1 scalar, 4 batch).
1938 #[inline]
1939 pub fn lanes(self) -> usize {
1940 self.lanes as usize
1941 }
1942 /// True iff every inspected lane satisfied the predicate.
1943 #[inline]
1944 pub fn all_pass(self) -> bool {
1945 self.failed_mask == 0
1946 }
1947 /// True iff at least one inspected lane failed the predicate.
1948 #[inline]
1949 pub fn any_failed(self) -> bool {
1950 self.failed_mask != 0
1951 }
1952 /// True iff lane `i` failed the predicate.
1953 #[inline]
1954 pub fn lane_failed(self, i: usize) -> bool {
1955 (self.failed_mask >> i) & 1 == 1
1956 }
1957 /// The raw failure mask (bit `i` ⇒ lane `i` failed).
1958 #[inline]
1959 pub fn failed_mask(self) -> u8 {
1960 self.failed_mask
1961 }
1962}
1963
1964/// Copy-or-zero-pad a derivative stack from length `N` to length `M`. Used by the
1965/// [`RowJet::compose_unary_with`] impls to bridge a program's chosen stack length
1966/// to each tower's native compose width ([`Tower4Lane`]: 5, [`Tower3Lane`]: 4).
1967/// `M ≥ N` zero-pads the unseeded high derivatives; `M < N` drops the unused tail
1968/// — both total, so the order-`(M−1)` tower reads exactly the channels it needs
1969/// and never an uninitialised entry. With `N == M` it is a verbatim copy (the
1970/// common `N == 5` case is bit-identical to passing the stack straight through).
1971#[inline]
1972fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
1973 let mut out = [0.0_f64; M];
1974 let m = N.min(M);
1975 out[..m].copy_from_slice(&s[..m]);
1976 out
1977}
1978
1979/// The shared row-NLL algebra over BOTH the scalar jets and the `f64x4` lane
1980/// towers — the bound that lets ONE single-source row-NLL body SIMD-batch 4
1981/// rows/pass without a dual-source copy (module §"The RowJet bridge").
1982///
1983/// Every scalar [`crate::jet_scalar::JetScalar<K>`] is a `RowJet<K>` via the
1984/// blanket impl below (`Value = f64`), bit-identically to its `JetScalar`
1985/// methods; [`Tower3Lane`] / [`Tower4Lane`] over `f64x4` are `RowJet<K>` with
1986/// `Value = [f64; 4]`, routing through their per-lane methods so lane `i` of a
1987/// batched evaluation is `to_bits`-identical to the scalar evaluation on row `i`.
1988pub trait RowJet<const K: usize>: Copy {
1989 /// The value channel(s) seen by [`guard`](Self::guard) and
1990 /// [`values`](Self::values): a single `f64` on a scalar jet, `[f64; 4]` on an
1991 /// `f64x4` lane tower.
1992 type Value: Copy;
1993
1994 /// A constant (value `c`, all derivatives zero), broadcast to every lane.
1995 fn constant(c: f64) -> Self;
1996 /// The seeded primary `slot` at value `x` (unit first derivative in `slot`),
1997 /// broadcast to every lane. Per-lane-DISTINCT seeding for the batch path is
1998 /// done by the lane instantiators ([`generic_batched_fourth_tower`] /
1999 /// [`generic_batched_third_tower`]), which build the tower variables directly
2000 /// from each row's primaries; this method is for any row-invariant auxiliary
2001 /// variable a body introduces.
2002 fn variable(x: f64, slot: usize) -> Self;
2003 /// The value channel(s): `f64` (scalar) or `[f64; 4]` (lane).
2004 fn values(&self) -> Self::Value;
2005
2006 /// Truncated Leibniz `self + o`.
2007 fn add(&self, o: &Self) -> Self;
2008 /// Truncated Leibniz `self − o`.
2009 fn sub(&self, o: &Self) -> Self;
2010 /// Truncated Leibniz `self · o`.
2011 fn mul(&self, o: &Self) -> Self;
2012 /// Multiply every channel by the plain scalar `s`.
2013 fn scale(&self, s: f64) -> Self;
2014 /// Negate every channel. Defaults to `scale(-1.0)`; the blanket overrides it
2015 /// to delegate to [`crate::jet_scalar::JetScalar::neg`].
2016 fn neg(&self) -> Self {
2017 self.scale(-1.0)
2018 }
2019
2020 /// Faà di Bruno compose with a unary special function whose `[f64; N]`
2021 /// derivative stack is built from the running base value PER LANE through
2022 /// `stack_fn`. This is the SHARED-TRAIT version of the `compose_unary_with`
2023 /// inherent method that already exists on both the scalar towers and the lane
2024 /// towers: on a scalar jet `stack_fn` is run once at the value; on an `f64x4`
2025 /// lane tower it is re-run per lane (the four rows carry four distinct base
2026 /// values), so lane `i` is `to_bits`-identical to the scalar result on row `i`.
2027 /// Making it a trait method is precisely what lets a body written once over
2028 /// `R: RowJet<K>` instantiate at the batch towers. `N` is widened/narrowed to
2029 /// the tower's native width by [`resize_stack`] (`N == 5` is a verbatim copy).
2030 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self;
2031
2032 /// Per-lane domain guard: evaluate `pred` on each active lane's value channel
2033 /// and report which lanes failed (see [`GuardVerdict`]). A scalar jet checks
2034 /// its one value; a lane tower checks all four. Lets a batched program detect
2035 /// an out-of-domain row in a 4-group and bail that group to the scalar tail.
2036 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict;
2037
2038 /// Per-lane scale: multiply every channel by the per-lane factor `s`
2039 /// ([`Self::Value`]). On a scalar jet `Self::Value = f64`, so this is exactly
2040 /// [`scale`](Self::scale) and the scalar call sites stay BIT-IDENTICAL when
2041 /// `.scale(x)` is rewritten to `.scale_rows(x)`; on an `f64x4` lane tower
2042 /// `Self::Value = [f64; 4]` and lane `i` is multiplied by `s[i]`. This is the
2043 /// primitive that lets a batched body carry CONTINUOUS per-row data — the
2044 /// survival `covariance_ones` / `z_sum` / observation-weight `wi` factors that
2045 /// enter the jet algebra as `.scale(per_row_value)` and that the single-`f64`
2046 /// [`scale`](Self::scale) would broadcast wrongly across the four rows. Build
2047 /// `s` from the lane→row map with [`pack_rows`](Self::pack_rows).
2048 fn scale_rows(&self, s: Self::Value) -> Self;
2049
2050 /// Gather a per-lane auxiliary datum from the lane→row map `rows`: `value_of(r)`
2051 /// is evaluated for each active lane's row and packed into [`Self::Value`] (a
2052 /// single `f64` on a scalar jet, `[f64; 4]` on an `f64x4` lane tower). This is
2053 /// how a body written once over [`RowJet`] feeds per-row CONTINUOUS data (the
2054 /// arguments to [`scale_rows`](Self::scale_rows)) into the batch path without
2055 /// knowing the concrete representation: the program holds the per-row data and
2056 /// the caller threads `rows` (length 1 scalar, length 4 batch) into
2057 /// [`RowNllProgramRowJet::row_nll`], so the body writes
2058 /// `x.scale_rows(R::pack_rows(rows, |r| self.cov(r)))`. A multiplicative weight
2059 /// buried in a `compose_unary_with` stack is pulled out the same way:
2060 /// `x.compose_unary_with(|u| stack(u, 1.0)).scale_rows(R::pack_rows(rows, |r| self.wi(r)))`.
2061 /// (Binary per-row branches such as the event indicator `di` are kept
2062 /// lane-uniform by grouping and the [`guard`](Self::guard) bail, not packed.)
2063 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;
2064
2065 // ── value-derived transcendental conveniences ───────────────────────
2066 // Each routes through `compose_unary_with` with the SAME derivative stack the
2067 // corresponding `JetScalar` method uses, so on a scalar jet (blanket) the
2068 // result is bit-identical to the `JetScalar` method, and on a lane tower lane
2069 // `i` is bit-identical to the scalar result on row `i`.
2070
2071 /// `e^self`.
2072 fn exp(&self) -> Self {
2073 self.compose_unary_with(|u| {
2074 let e = u.exp();
2075 [e, e, e, e, e]
2076 })
2077 }
2078 /// `ln(self)`. Caller guarantees positivity.
2079 fn ln(&self) -> Self {
2080 self.compose_unary_with(|u| {
2081 let r = 1.0 / u;
2082 [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
2083 })
2084 }
2085 /// `√self`. Caller guarantees positivity.
2086 fn sqrt(&self) -> Self {
2087 self.compose_unary_with(|u| {
2088 let s = u.sqrt();
2089 [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
2090 })
2091 }
2092 /// `1/self`.
2093 fn recip(&self) -> Self {
2094 self.compose_unary_with(|u| {
2095 let r = 1.0 / u;
2096 let r2 = r * r;
2097 [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
2098 })
2099 }
2100 /// `self^a` for real `a`. Caller guarantees a positive base.
2101 fn powf(&self, a: f64) -> Self {
2102 self.compose_unary_with(move |u| {
2103 [
2104 u.powf(a),
2105 a * u.powf(a - 1.0),
2106 a * (a - 1.0) * u.powf(a - 2.0),
2107 a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
2108 a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
2109 ]
2110 })
2111 }
2112 /// `ln Γ(self)`. Caller guarantees a positive argument.
2113 fn ln_gamma(&self) -> Self {
2114 self.compose_unary_with(ln_gamma_derivative_stack)
2115 }
2116 /// `ψ(self)` (digamma). Caller guarantees a positive argument.
2117 fn digamma(&self) -> Self {
2118 self.compose_unary_with(digamma_derivative_stack)
2119 }
2120}
2121
2122/// Blanket: every scalar [`crate::jet_scalar::JetScalar<K>`] is a [`RowJet<K>`]
2123/// with `Value = f64`. Each op delegates to the identical `JetScalar` method, so
2124/// the existing scalar call sites compile UNCHANGED and bit-identically — the
2125/// bridge adds the lane representation without churning the scalar path. (The
2126/// concrete lane impls below cannot overlap this: [`Tower3Lane`] / [`Tower4Lane`]
2127/// are local types that do not implement `JetScalar`, and the orphan rule forbids
2128/// any downstream impl, so the coherence checker proves the impls disjoint.)
2129impl<const K: usize, S: crate::jet_scalar::JetScalar<K>> RowJet<K> for S {
2130 type Value = f64;
2131 #[inline]
2132 fn constant(c: f64) -> Self {
2133 <S as crate::jet_scalar::JetScalar<K>>::constant(c)
2134 }
2135 #[inline]
2136 fn variable(x: f64, slot: usize) -> Self {
2137 <S as crate::jet_scalar::JetScalar<K>>::variable(x, slot)
2138 }
2139 #[inline]
2140 fn values(&self) -> f64 {
2141 crate::jet_scalar::JetScalar::value(self)
2142 }
2143 #[inline]
2144 fn add(&self, o: &Self) -> Self {
2145 crate::jet_scalar::JetScalar::add(self, o)
2146 }
2147 #[inline]
2148 fn sub(&self, o: &Self) -> Self {
2149 crate::jet_scalar::JetScalar::sub(self, o)
2150 }
2151 #[inline]
2152 fn mul(&self, o: &Self) -> Self {
2153 crate::jet_scalar::JetScalar::mul(self, o)
2154 }
2155 #[inline]
2156 fn scale(&self, s: f64) -> Self {
2157 crate::jet_scalar::JetScalar::scale(self, s)
2158 }
2159 #[inline]
2160 fn neg(&self) -> Self {
2161 crate::jet_scalar::JetScalar::neg(self)
2162 }
2163 #[inline]
2164 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2165 crate::jet_scalar::JetScalar::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2166 }
2167 #[inline]
2168 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2169 GuardVerdict::scalar(pred(crate::jet_scalar::JetScalar::value(self)))
2170 }
2171 #[inline]
2172 fn scale_rows(&self, s: f64) -> Self {
2173 // `Value == f64`, so per-lane scale is exactly `scale` — the rewrite
2174 // `.scale(x)` → `.scale_rows(x)` is bit-identical on the scalar path.
2175 crate::jet_scalar::JetScalar::scale(self, s)
2176 }
2177 #[inline]
2178 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
2179 value_of(rows[0])
2180 }
2181}
2182
2183/// The `f64x4` lane [`Tower4Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2184/// routing each op through its existing per-lane method. Lane `i` of a batched
2185/// evaluation is `to_bits`-identical to the scalar [`Tower4`] evaluation on row
2186/// `i` (the per-lane methods are term-for-term lifts of the scalar tower).
2187impl<const K: usize> RowJet<K> for Tower4Lane<wide::f64x4, K> {
2188 type Value = [f64; 4];
2189 #[inline]
2190 fn constant(c: f64) -> Self {
2191 Tower4Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2192 }
2193 #[inline]
2194 fn variable(x: f64, slot: usize) -> Self {
2195 Tower4Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2196 }
2197 #[inline]
2198 fn values(&self) -> [f64; 4] {
2199 self.v.to_array()
2200 }
2201 #[inline]
2202 fn add(&self, o: &Self) -> Self {
2203 Tower4Lane::add(self, o)
2204 }
2205 #[inline]
2206 fn sub(&self, o: &Self) -> Self {
2207 Tower4Lane::sub(self, o)
2208 }
2209 #[inline]
2210 fn mul(&self, o: &Self) -> Self {
2211 Tower4Lane::mul(self, o)
2212 }
2213 #[inline]
2214 fn scale(&self, s: f64) -> Self {
2215 Tower4Lane::scale(self, s)
2216 }
2217 #[inline]
2218 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2219 Tower4Lane::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2220 }
2221 #[inline]
2222 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2223 let vals = self.v.to_array();
2224 let mut mask = 0u8;
2225 for (i, &v) in vals.iter().enumerate() {
2226 if !pred(v) {
2227 mask |= 1 << i;
2228 }
2229 }
2230 GuardVerdict::lanes4(mask)
2231 }
2232 #[inline]
2233 fn scale_rows(&self, s: [f64; 4]) -> Self {
2234 // True per-lane scale: lane `i` of every channel is multiplied by `s[i]`,
2235 // so lane `i` matches the scalar `Tower4::scale(s[i])` on row `i`.
2236 let sl = wide::f64x4::new(s);
2237 let mut out = *self;
2238 out.v = self.v * sl;
2239 for i in 0..K {
2240 out.g[i] = self.g[i] * sl;
2241 for j in 0..K {
2242 out.h[i][j] = self.h[i][j] * sl;
2243 for k in 0..K {
2244 out.t3[i][j][k] = self.t3[i][j][k] * sl;
2245 for l in 0..K {
2246 out.t4[i][j][k][l] = self.t4[i][j][k][l] * sl;
2247 }
2248 }
2249 }
2250 }
2251 out
2252 }
2253 #[inline]
2254 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2255 [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2256 }
2257}
2258
2259/// The `f64x4` lane [`Tower3Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2260/// the order-≤3 sibling of the [`Tower4Lane`] impl. A body that uses `N == 5`
2261/// stacks drops the (unused) fourth-derivative entry here, matching the scalar
2262/// [`Tower3`] which also carries only up to the third tensor.
2263impl<const K: usize> RowJet<K> for Tower3Lane<wide::f64x4, K> {
2264 type Value = [f64; 4];
2265 #[inline]
2266 fn constant(c: f64) -> Self {
2267 Tower3Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2268 }
2269 #[inline]
2270 fn variable(x: f64, slot: usize) -> Self {
2271 Tower3Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2272 }
2273 #[inline]
2274 fn values(&self) -> [f64; 4] {
2275 self.v.to_array()
2276 }
2277 #[inline]
2278 fn add(&self, o: &Self) -> Self {
2279 Tower3Lane::add(self, o)
2280 }
2281 #[inline]
2282 fn sub(&self, o: &Self) -> Self {
2283 Tower3Lane::sub(self, o)
2284 }
2285 #[inline]
2286 fn mul(&self, o: &Self) -> Self {
2287 Tower3Lane::mul(self, o)
2288 }
2289 #[inline]
2290 fn scale(&self, s: f64) -> Self {
2291 Tower3Lane::scale(self, s)
2292 }
2293 #[inline]
2294 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2295 Tower3Lane::compose_unary_with(self, |u| resize_stack::<N, 4>(stack_fn(u)))
2296 }
2297 #[inline]
2298 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2299 let vals = self.v.to_array();
2300 let mut mask = 0u8;
2301 for (i, &v) in vals.iter().enumerate() {
2302 if !pred(v) {
2303 mask |= 1 << i;
2304 }
2305 }
2306 GuardVerdict::lanes4(mask)
2307 }
2308 #[inline]
2309 fn scale_rows(&self, s: [f64; 4]) -> Self {
2310 let sl = wide::f64x4::new(s);
2311 let mut out = *self;
2312 out.v = self.v * sl;
2313 for i in 0..K {
2314 out.g[i] = self.g[i] * sl;
2315 for j in 0..K {
2316 out.h[i][j] = self.h[i][j] * sl;
2317 for k in 0..K {
2318 out.t3[i][j][k] = self.t3[i][j][k] * sl;
2319 }
2320 }
2321 }
2322 out
2323 }
2324 #[inline]
2325 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2326 [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2327 }
2328}
2329
2330/// A family's row negative log-likelihood written ONCE over the [`RowJet`]
2331/// bridge, so the SAME body instantiates at the scalar jets (for the `(v, g, H)`
2332/// and contracted-tensor channels) AND at the `f64x4` lane towers (for the
2333/// 4-rows-per-pass SIMD batch). This is the lane-capable successor to
2334/// [`RowNllProgramGeneric`]: a body written here gets the scalar channels through
2335/// [`rowjet_row_kernel`] / [`rowjet_third_contracted`] / [`rowjet_fourth_contracted`]
2336/// and the batched channels through [`generic_batched_fourth_tower`] /
2337/// [`generic_batched_third_tower`], all from a single source.
2338pub trait RowNllProgramRowJet<const K: usize>: Send + Sync {
2339 /// Number of observations the program covers.
2340 fn n_rows(&self) -> usize;
2341
2342 /// Current primary-scalar values for `row` (where to seed each lane).
2343 fn primaries(&self, row: usize) -> Result<[f64; K], String>;
2344
2345 /// The row NLL evaluated on the [`RowJet`] bridge. `rows` is the lane→row map
2346 /// (length 1 for a scalar instantiation, length 4 for a batch); `p[a]` arrives
2347 /// pre-seeded by the caller (base value plus, for the directional scalars, the
2348 /// nilpotent contraction directions). The body uses ONLY [`RowJet`] ops and
2349 /// per-row data entering through `rows`/`self` as constants.
2350 fn row_nll<R: RowJet<K>>(&self, rows: &[usize], p: &[R; K]) -> Result<R, String>;
2351}
2352
2353/// Evaluate a [`RowNllProgramRowJet`] at the value/gradient/Hessian scalar
2354/// [`crate::jet_scalar::Order2`] (the `(v, g, H)` inner-Newton channel) — the
2355/// `RowJet` twin of [`generic_row_kernel`].
2356pub fn rowjet_row_kernel<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2357 prog: &P,
2358 row: usize,
2359) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
2360 let base = prog.primaries(row)?;
2361 let vars: [crate::jet_scalar::Order2<K>; K] =
2362 std::array::from_fn(|a| <crate::jet_scalar::Order2<K> as RowJet<K>>::variable(base[a], a));
2363 let s = prog.row_nll(&[row], &vars)?;
2364 Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
2365}
2366
2367/// Evaluate a [`RowNllProgramRowJet`] at the one-seed scalar
2368/// [`crate::jet_scalar::OneSeed`], returning the contracted third
2369/// `Σ_c ℓ_{abc} dir_c` — the `RowJet` twin of [`generic_third_contracted`].
2370pub fn rowjet_third_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2371 prog: &P,
2372 row: usize,
2373 dir: &[f64; K],
2374) -> Result<[[f64; K]; K], String> {
2375 let base = prog.primaries(row)?;
2376 let vars: [crate::jet_scalar::OneSeed<K>; K] =
2377 std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
2378 let s = prog.row_nll(&[row], &vars)?;
2379 Ok(s.contracted_third())
2380}
2381
2382/// Evaluate a [`RowNllProgramRowJet`] at the two-seed scalar
2383/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
2384/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `RowJet` twin of [`generic_fourth_contracted`].
2385pub fn rowjet_fourth_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2386 prog: &P,
2387 row: usize,
2388 dir_u: &[f64; K],
2389 dir_v: &[f64; K],
2390) -> Result<[[f64; K]; K], String> {
2391 let base = prog.primaries(row)?;
2392 let vars: [crate::jet_scalar::TwoSeed<K>; K] =
2393 std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
2394 let s = prog.row_nll(&[row], &vars)?;
2395 Ok(s.contracted_fourth())
2396}
2397
2398/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower4Batch`],
2399/// computing the FULL `(v, g, H, t3, t4)` for FOUR rows in one SIMD pass — the
2400/// lane twin of [`generic_full_tower`]. Each of the four lanes is seeded with its
2401/// own row's primaries, so [`Tower4Batch::lane`]`(i)` is `to_bits`-identical to
2402/// the scalar [`generic_full_tower`] on `rows[i]`.
2403pub fn generic_batched_fourth_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2404 prog: &P,
2405 rows: [usize; 4],
2406) -> Result<Tower4Batch<K>, String> {
2407 let bases: [[f64; K]; 4] = [
2408 prog.primaries(rows[0])?,
2409 prog.primaries(rows[1])?,
2410 prog.primaries(rows[2])?,
2411 prog.primaries(rows[3])?,
2412 ];
2413 let vars: [Tower4Batch<K>; K] = std::array::from_fn(|a| {
2414 let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2415 Tower4Batch::variable(lane_vals, a)
2416 });
2417 prog.row_nll(&rows, &vars)
2418}
2419
2420/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower3Batch`],
2421/// computing `(v, g, H, t3)` for FOUR rows in one SIMD pass — the order-≤3 lane
2422/// twin of [`generic_full_tower`]. [`Tower3Batch::lane`]`(i)` is
2423/// `to_bits`-identical to the order-≤3 scalar evaluation on `rows[i]`.
2424pub fn generic_batched_third_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2425 prog: &P,
2426 rows: [usize; 4],
2427) -> Result<Tower3Batch<K>, String> {
2428 let bases: [[f64; K]; 4] = [
2429 prog.primaries(rows[0])?,
2430 prog.primaries(rows[1])?,
2431 prog.primaries(rows[2])?,
2432 prog.primaries(rows[3])?,
2433 ];
2434 let vars: [Tower3Batch<K>; K] = std::array::from_fn(|a| {
2435 let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2436 Tower3Batch::variable(lane_vals, a)
2437 });
2438 prog.row_nll(&rows, &vars)
2439}
2440
2441// ── The oracle ───────────────────────────────────────────────────────
2442
2443/// One row's worth of hand-written kernel outputs, as claimed by a
2444/// `RowKernel` implementation, packaged for verification against the
2445/// tower truth. Plain data (no trait coupling) so any kernel — whatever
2446/// its visibility — can be audited from its own test module.
2447pub struct KernelChannels<const K: usize> {
2448 /// Claimed `(nll, ∇, H)` from `row_kernel`.
2449 pub value: f64,
2450 /// Claimed gradient.
2451 pub gradient: [f64; K],
2452 /// Claimed Hessian.
2453 pub hessian: [[f64; K]; K],
2454 /// Claimed `row_third_contracted(dir)` outputs as `(dir, claim)` pairs.
2455 pub third: Vec<([f64; K], [[f64; K]; K])>,
2456 /// Claimed `row_fourth_contracted(u, v)` outputs as `(u, v, claim)`.
2457 pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
2458}
2459
2460/// Channel-by-channel audit of a hand-written kernel against the
2461/// single-expression tower truth. Returns `Err` naming the first channel,
2462/// index, claimed and true values on disagreement — designed as the body
2463/// of the per-family CI oracle tests (#932 deployment step 2).
2464///
2465/// Tolerance is PER ENTRY, mixed absolute/relative: each comparison uses
2466/// `|claim − truth| ≤ atol + rel_tol · max(|claim|, |truth|)`. The absolute
2467/// floor `atol = rel_tol` lets exact-zero entries of structurally sparse
2468/// towers pass without demanding bit-equality, while a tiny cross-block
2469/// entry dropped next to a huge one is still caught (it is NOT measured
2470/// against the largest entry of the whole channel — there is no per-channel
2471/// magnitude floor). Genuine sign flips (#736) and dropped channels are loud.
2472///
2473/// Non-finite handling is strict: a NaN on either side always fails; an
2474/// infinity passes only when both sides are the SAME signed infinity.
2475pub fn verify_kernel_channels<const K: usize>(
2476 tower: &Tower4<K>,
2477 claims: &KernelChannels<K>,
2478 rel_tol: f64,
2479) -> Result<(), String> {
2480 // Absolute floor: reuse rel_tol so a single knob controls both the
2481 // relative band and the absolute floor for entries near zero.
2482 let atol = rel_tol;
2483 let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
2484 // Non-finite values never silently pass the algebraic comparison
2485 // below (any comparison with NaN is false). Handle them explicitly:
2486 // NaN on either side always errs; an infinity passes only if both
2487 // sides are the identical signed infinity.
2488 if !claim.is_finite() || !truth.is_finite() {
2489 let agree = claim.is_infinite()
2490 && truth.is_infinite()
2491 && claim.is_sign_positive() == truth.is_sign_positive();
2492 if agree {
2493 return Ok(());
2494 }
2495 return Err(format!(
2496 "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
2497 ));
2498 }
2499 let band = atol + rel_tol * claim.abs().max(truth.abs());
2500 if (claim - truth).abs() > band {
2501 return Err(format!(
2502 "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
2503 ));
2504 }
2505 Ok(())
2506 };
2507
2508 check("value", claims.value, tower.v)?;
2509
2510 for a in 0..K {
2511 check(&format!("gradient[{a}]"), claims.gradient[a], tower.g[a])?;
2512 }
2513
2514 for a in 0..K {
2515 for b in 0..K {
2516 check(
2517 &format!("hessian[{a}][{b}]"),
2518 claims.hessian[a][b],
2519 tower.h[a][b],
2520 )?;
2521 }
2522 }
2523
2524 for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
2525 let truth = tower.third_contracted(dir);
2526 for a in 0..K {
2527 for b in 0..K {
2528 check(
2529 &format!("third[{t_idx}][{a}][{b}]"),
2530 claim[a][b],
2531 truth[a][b],
2532 )?;
2533 }
2534 }
2535 }
2536
2537 for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
2538 let truth = tower.fourth_contracted(u, w);
2539 for a in 0..K {
2540 for b in 0..K {
2541 check(
2542 &format!("fourth[{f_idx}][{a}][{b}]"),
2543 claim[a][b],
2544 truth[a][b],
2545 )?;
2546 }
2547 }
2548 }
2549
2550 Ok(())
2551}
2552
2553// ===========================================================================
2554// SIMD row-batched towers (#1151 follow-up): Tower3Lane / Tower4Lane
2555// ===========================================================================
2556//
2557// `Tower{3,4}Lane<L: Lane, K>` re-type every channel of `Tower{3,4}<K>` from a
2558// scalar `f64` to a SIMD lane field `L`. With `L = wide::f64x4` one instance
2559// carries FOUR rows at once, so a per-row kernel (BMS `row_nll`, survival
2560// `row_kernel`, `marginal_slope` `build_row_*_towers`) can evaluate 4 rows per
2561// vector pass instead of one per scalar pass.
2562//
2563// Every floating-point op is a DIRECT, term-for-term lift of the scalar
2564// `Tower{3,4}<K>` body — `a * b` -> `a.mul(b)`, `a + b` -> `a.add(b)`, a literal
2565// `c` -> `L::splat(c)` — in the SAME accumulation order. `wide::f64x4`
2566// add/sub/mul are lane-wise IEEE-754 ops with NO fused-multiply-add (Rust
2567// performs no fp-contraction), so lane `i` of any channel of a
2568// `Tower{3,4}Lane<wide::f64x4, K>` is `to_bits`-IDENTICAL to the scalar
2569// `Tower{3,4}<K>` channel computed on row `i` — exactly the structural
2570// bit-identity the existing [`crate::jet_scalar::Order2Lane`] relies on. Proven
2571// by the in-tree `batch_tests` (real `wide::f64x4`) and a standalone
2572// f64x4-model oracle, `K ∈ {2,3,4,9}`.
2573//
2574// Only the pure-arithmetic ops are lifted (the transcendental `exp`/`ln`/`sqrt`/
2575// `…` route through scalar libm, which has no `f64x4` form; consumers build the
2576// per-lane derivative stack scalar-side and feed it to `compose_unary([L; _])`,
2577// exactly as the scalar path already does).
2578
2579use crate::jet_scalar::Lane;
2580
2581/// Lane-batched [`Tower4`]: value / gradient / Hessian / 3rd / 4th tensors
2582/// carried in a SIMD field `L`. `Tower4Lane<f64x4, K>` lane `i` is
2583/// `to_bits`-identical to [`Tower4<K>`] on row `i`.
2584#[derive(Clone, Copy)]
2585pub struct Tower4Lane<L: Lane, const K: usize> {
2586 /// Value channel (one entry per lane/row).
2587 pub v: L,
2588 /// Gradient `∂/∂p_a`.
2589 pub g: [L; K],
2590 /// Hessian `∂²/∂p_a∂p_b`.
2591 pub h: [[L; K]; K],
2592 /// Third tensor `∂³`.
2593 pub t3: [[[L; K]; K]; K],
2594 /// Fourth tensor `∂⁴`.
2595 pub t4: [[[[L; K]; K]; K]; K],
2596}
2597
2598/// The 4-rows-per-pass batched [`Tower4`] (`wide::f64x4` lanes).
2599pub type Tower4Batch<const K: usize> = Tower4Lane<wide::f64x4, K>;
2600
2601impl<L: Lane, const K: usize> Tower4Lane<L, K> {
2602 /// All-zero tower (every channel `+0.0` in every lane).
2603 #[inline]
2604 pub fn zero() -> Self {
2605 let z = L::splat(0.0);
2606 Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K], t4: [[[[z; K]; K]; K]; K] }
2607 }
2608 /// Constant `c` (per lane): value channel only.
2609 #[inline]
2610 pub fn constant(c: L) -> Self {
2611 let mut o = Self::zero();
2612 o.v = c;
2613 o
2614 }
2615 /// Seeded variable `p_idx` at per-lane `value`: unit first derivative in
2616 /// slot `idx` (mirrors [`Tower4::variable`]).
2617 #[inline]
2618 pub fn variable(value: L, idx: usize) -> Self {
2619 let mut o = Self::constant(value);
2620 o.g[idx] = L::splat(1.0);
2621 o
2622 }
2623 /// Extract lane `i` as a scalar [`Tower4<K>`] (channel-for-channel).
2624 #[inline]
2625 pub fn lane(&self, i: usize) -> Tower4<K> {
2626 let mut out = Tower4::<K>::zero();
2627 out.v = self.v.lane(i);
2628 for a in 0..K {
2629 out.g[a] = self.g[a].lane(i);
2630 for b in 0..K {
2631 out.h[a][b] = self.h[a][b].lane(i);
2632 for c in 0..K {
2633 out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2634 for d in 0..K {
2635 out.t4[a][b][c][d] = self.t4[a][b][c][d].lane(i);
2636 }
2637 }
2638 }
2639 }
2640 out
2641 }
2642 /// Per-channel lane-wise `self + o` (mirrors `Tower4` `Add`).
2643 #[inline]
2644 pub fn add(&self, o: &Self) -> Self {
2645 let mut out = *self;
2646 out.v = self.v.add(o.v);
2647 for i in 0..K {
2648 out.g[i] = self.g[i].add(o.g[i]);
2649 for j in 0..K {
2650 out.h[i][j] = self.h[i][j].add(o.h[i][j]);
2651 for k in 0..K {
2652 out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
2653 for l in 0..K {
2654 out.t4[i][j][k][l] = self.t4[i][j][k][l].add(o.t4[i][j][k][l]);
2655 }
2656 }
2657 }
2658 }
2659 out
2660 }
2661 /// Per-channel lane-wise `self - o` (mirrors `Tower4` `Sub`).
2662 #[inline]
2663 pub fn sub(&self, o: &Self) -> Self {
2664 let mut out = *self;
2665 out.v = self.v.sub(o.v);
2666 for i in 0..K {
2667 out.g[i] = self.g[i].sub(o.g[i]);
2668 for j in 0..K {
2669 out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
2670 for k in 0..K {
2671 out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
2672 for l in 0..K {
2673 out.t4[i][j][k][l] = self.t4[i][j][k][l].sub(o.t4[i][j][k][l]);
2674 }
2675 }
2676 }
2677 }
2678 out
2679 }
2680 /// Multiply every channel by the plain scalar `s` (mirrors `Tower4::scale`).
2681 #[inline]
2682 pub fn scale(&self, s: f64) -> Self {
2683 let sl = L::splat(s);
2684 let mut out = *self;
2685 out.v = self.v.mul(sl);
2686 for i in 0..K {
2687 out.g[i] = self.g[i].mul(sl);
2688 for j in 0..K {
2689 out.h[i][j] = self.h[i][j].mul(sl);
2690 for k in 0..K {
2691 out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
2692 for l in 0..K {
2693 out.t4[i][j][k][l] = self.t4[i][j][k][l].mul(sl);
2694 }
2695 }
2696 }
2697 }
2698 out
2699 }
2700 /// Leibniz product `self · o`, term-for-term lift of [`Tower4::mul`].
2701 #[inline]
2702 pub fn mul(&self, o: &Self) -> Self {
2703 let a = self;
2704 let b = o;
2705 let mut out = Self::zero();
2706 out.v = a.v.mul(b.v);
2707 for i in 0..K {
2708 let mut acc = L::splat(0.0);
2709 acc = acc.add(a.v.mul(b.g[i]));
2710 acc = acc.add(a.g[i].mul(b.v));
2711 out.g[i] = acc;
2712 }
2713 for i in 0..K {
2714 for j in 0..K {
2715 let mut acc = L::splat(0.0);
2716 acc = acc.add(a.v.mul(b.h[i][j]));
2717 acc = acc.add(a.g[i].mul(b.g[j]));
2718 acc = acc.add(a.g[j].mul(b.g[i]));
2719 acc = acc.add(a.h[i][j].mul(b.v));
2720 out.h[i][j] = acc;
2721 }
2722 }
2723 for i in 0..K {
2724 for j in 0..K {
2725 for k in 0..K {
2726 let mut acc = L::splat(0.0);
2727 acc = acc.add(a.v.mul(b.t3[i][j][k]));
2728 acc = acc.add(a.g[i].mul(b.h[j][k]));
2729 acc = acc.add(a.g[j].mul(b.h[i][k]));
2730 acc = acc.add(a.h[i][j].mul(b.g[k]));
2731 acc = acc.add(a.g[k].mul(b.h[i][j]));
2732 acc = acc.add(a.h[i][k].mul(b.g[j]));
2733 acc = acc.add(a.h[j][k].mul(b.g[i]));
2734 acc = acc.add(a.t3[i][j][k].mul(b.v));
2735 out.t3[i][j][k] = acc;
2736 }
2737 }
2738 }
2739 for i in 0..K {
2740 for j in 0..K {
2741 for k in 0..K {
2742 for l in 0..K {
2743 let mut acc = L::splat(0.0);
2744 acc = acc.add(a.v.mul(b.t4[i][j][k][l]));
2745 acc = acc.add(a.g[i].mul(b.t3[j][k][l]));
2746 acc = acc.add(a.g[j].mul(b.t3[i][k][l]));
2747 acc = acc.add(a.h[i][j].mul(b.h[k][l]));
2748 acc = acc.add(a.g[k].mul(b.t3[i][j][l]));
2749 acc = acc.add(a.h[i][k].mul(b.h[j][l]));
2750 acc = acc.add(a.h[j][k].mul(b.h[i][l]));
2751 acc = acc.add(a.t3[i][j][k].mul(b.g[l]));
2752 acc = acc.add(a.g[l].mul(b.t3[i][j][k]));
2753 acc = acc.add(a.h[i][l].mul(b.h[j][k]));
2754 acc = acc.add(a.h[j][l].mul(b.h[i][k]));
2755 acc = acc.add(a.t3[i][j][l].mul(b.g[k]));
2756 acc = acc.add(a.h[k][l].mul(b.h[i][j]));
2757 acc = acc.add(a.t3[i][k][l].mul(b.g[j]));
2758 acc = acc.add(a.t3[j][k][l].mul(b.g[i]));
2759 acc = acc.add(a.t4[i][j][k][l].mul(b.v));
2760 out.t4[i][j][k][l] = acc;
2761 }
2762 }
2763 }
2764 }
2765 out
2766 }
2767 /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
2768 /// [`Tower4::compose_unary`]. `d = [f, f′, f″, f‴, f⁗]` packed per lane
2769 /// (build via [`Lane::unary5`] from the scalar special-function stack).
2770 #[inline]
2771 pub fn compose_unary(&self, d: [L; 5]) -> Self {
2772 let mut out = Self::zero();
2773 out.v = d[0];
2774 for i in 0..K {
2775 let mut acc = L::splat(0.0);
2776 acc = acc.add(d[1].mul(self.g[i]));
2777 out.g[i] = acc;
2778 }
2779 for i in 0..K {
2780 for j in 0..K {
2781 let mut acc = L::splat(0.0);
2782 acc = acc.add(d[1].mul(self.h[i][j]));
2783 acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
2784 out.h[i][j] = acc;
2785 }
2786 }
2787 for i in 0..K {
2788 for j in 0..K {
2789 for k in 0..K {
2790 let mut acc = L::splat(0.0);
2791 acc = acc.add(d[1].mul(self.t3[i][j][k]));
2792 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
2793 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
2794 acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
2795 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
2796 out.t3[i][j][k] = acc;
2797 }
2798 }
2799 }
2800 for i in 0..K {
2801 for j in 0..K {
2802 for k in 0..K {
2803 for l in 0..K {
2804 let mut acc = L::splat(0.0);
2805 acc = acc.add(d[1].mul(self.t4[i][j][k][l]));
2806 acc = acc.add(d[2].mul(self.t3[i][j][k]).mul(self.g[l]));
2807 acc = acc.add(d[2].mul(self.t3[i][j][l]).mul(self.g[k]));
2808 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.h[k][l]));
2809 acc = acc.add(d[3].mul(self.h[i][j]).mul(self.g[k]).mul(self.g[l]));
2810 acc = acc.add(d[2].mul(self.t3[i][k][l]).mul(self.g[j]));
2811 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.h[j][l]));
2812 acc = acc.add(d[3].mul(self.h[i][k]).mul(self.g[j]).mul(self.g[l]));
2813 acc = acc.add(d[2].mul(self.h[i][l]).mul(self.h[j][k]));
2814 acc = acc.add(d[2].mul(self.g[i]).mul(self.t3[j][k][l]));
2815 acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][k]).mul(self.g[l]));
2816 acc = acc.add(d[3].mul(self.h[i][l]).mul(self.g[j]).mul(self.g[k]));
2817 acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][l]).mul(self.g[k]));
2818 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.h[k][l]));
2819 acc = acc.add(d[4].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]).mul(self.g[l]));
2820 out.t4[i][j][k][l] = acc;
2821 }
2822 }
2823 }
2824 }
2825 out
2826 }
2827 /// Compose with a unary special-function whose `[f64; 5]` derivative stack is
2828 /// built from the base value through `stack_fn`, evaluated PER LANE — the
2829 /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
2830 /// seam (the SIMD twin of [`Tower4::compose_unary_with`]).
2831 ///
2832 /// Each of the four lanes carries a DISTINCT base value, so the scalar
2833 /// `stack_fn` is run once per lane at that lane's own value (via
2834 /// [`Lane::unary_with`]) and the `[f64; 5]` results are packed into `[L; 5]`;
2835 /// the composition is then the existing per-lane [`Self::compose_unary`].
2836 /// Because `unary_with` runs the identical scalar closure per lane and
2837 /// `compose_unary` is a term-for-term lift of the scalar tower, lane `i` of
2838 /// the result is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`
2839 /// — which is exactly what lets a row program written against the scalar
2840 /// [`Tower4::compose_unary_with`] seam re-instantiate, unchanged, at `f64x4`.
2841 #[inline]
2842 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
2843 self.compose_unary(self.v.unary_with(stack_fn))
2844 }
2845
2846 /// Single-active-slot fast path, term-for-term lift of
2847 /// [`Tower4::compose_unary_single_slot`] (only the 5 diagonal channels).
2848 #[inline]
2849 pub fn compose_unary_single_slot(&self, d: [L; 5], slot: usize) -> Self {
2850 let mut out = Self::zero();
2851 let s = slot;
2852 let g = self.g[s];
2853 let h = self.h[s][s];
2854 let t3 = self.t3[s][s][s];
2855 let t4 = self.t4[s][s][s][s];
2856 out.v = d[0];
2857 out.g[s] = {
2858 let mut acc = L::splat(0.0);
2859 acc = acc.add(d[1].mul(g));
2860 acc
2861 };
2862 out.h[s][s] = {
2863 let mut acc = L::splat(0.0);
2864 acc = acc.add(d[1].mul(h));
2865 acc = acc.add(d[2].mul(g).mul(g));
2866 acc
2867 };
2868 out.t3[s][s][s] = {
2869 let mut acc = L::splat(0.0);
2870 acc = acc.add(d[1].mul(t3));
2871 acc = acc.add(d[2].mul(h).mul(g));
2872 acc = acc.add(d[2].mul(h).mul(g));
2873 acc = acc.add(d[2].mul(g).mul(h));
2874 acc = acc.add(d[3].mul(g).mul(g).mul(g));
2875 acc
2876 };
2877 out.t4[s][s][s][s] = {
2878 let mut acc = L::splat(0.0);
2879 acc = acc.add(d[1].mul(t4));
2880 acc = acc.add(d[2].mul(t3).mul(g));
2881 acc = acc.add(d[2].mul(t3).mul(g));
2882 acc = acc.add(d[2].mul(h).mul(h));
2883 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2884 acc = acc.add(d[2].mul(t3).mul(g));
2885 acc = acc.add(d[2].mul(h).mul(h));
2886 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2887 acc = acc.add(d[2].mul(h).mul(h));
2888 acc = acc.add(d[2].mul(g).mul(t3));
2889 acc = acc.add(d[3].mul(g).mul(h).mul(g));
2890 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2891 acc = acc.add(d[3].mul(g).mul(h).mul(g));
2892 acc = acc.add(d[3].mul(g).mul(g).mul(h));
2893 acc = acc.add(d[4].mul(g).mul(g).mul(g).mul(g));
2894 acc
2895 };
2896 out
2897 }
2898 /// Contract `t3` with a primary-space direction (lift of
2899 /// [`Tower4::third_contracted`]). Output-symmetric in `(a, b)`: compute the
2900 /// upper triangle and mirror — bit-identical to the full nest, lane-for-lane.
2901 #[inline]
2902 pub fn third_contracted(&self, dir: &[L; K]) -> [[L; K]; K] {
2903 let mut out = [[L::splat(0.0); K]; K];
2904 for a in 0..K {
2905 for b in a..K {
2906 let mut acc = L::splat(0.0);
2907 for c in 0..K {
2908 acc = acc.add(self.t3[a][b][c].mul(dir[c]));
2909 }
2910 out[a][b] = acc;
2911 out[b][a] = acc;
2912 }
2913 }
2914 out
2915 }
2916 /// Contract `t4` with two primary-space directions (lift of
2917 /// [`Tower4::fourth_contracted`]). Output-symmetric in `(i, j)`: compute the
2918 /// upper triangle and mirror — bit-identical to the full nest, lane-for-lane.
2919 #[inline]
2920 pub fn fourth_contracted(&self, u: &[L; K], w: &[L; K]) -> [[L; K]; K] {
2921 let mut out = [[L::splat(0.0); K]; K];
2922 for i in 0..K {
2923 for j in i..K {
2924 let mut acc = L::splat(0.0);
2925 for k in 0..K {
2926 for l in 0..K {
2927 acc = acc.add(self.t4[i][j][k][l].mul(u[k]).mul(w[l]));
2928 }
2929 }
2930 out[i][j] = acc;
2931 out[j][i] = acc;
2932 }
2933 }
2934 out
2935 }
2936}
2937
2938/// Lane-batched [`Tower3`] (order-≤3 sibling of [`Tower4Lane`]).
2939#[derive(Clone, Copy)]
2940pub struct Tower3Lane<L: Lane, const K: usize> {
2941 /// Value channel.
2942 pub v: L,
2943 /// Gradient.
2944 pub g: [L; K],
2945 /// Hessian.
2946 pub h: [[L; K]; K],
2947 /// Third tensor.
2948 pub t3: [[[L; K]; K]; K],
2949}
2950
2951/// The 4-rows-per-pass batched [`Tower3`] (`wide::f64x4` lanes).
2952pub type Tower3Batch<const K: usize> = Tower3Lane<wide::f64x4, K>;
2953
2954impl<L: Lane, const K: usize> Tower3Lane<L, K> {
2955 /// All-zero tower.
2956 #[inline]
2957 pub fn zero() -> Self {
2958 let z = L::splat(0.0);
2959 Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K] }
2960 }
2961 /// Constant `c` (per lane).
2962 #[inline]
2963 pub fn constant(c: L) -> Self {
2964 let mut o = Self::zero();
2965 o.v = c;
2966 o
2967 }
2968 /// Seeded variable `p_idx` at per-lane `value`.
2969 #[inline]
2970 pub fn variable(value: L, idx: usize) -> Self {
2971 let mut o = Self::constant(value);
2972 o.g[idx] = L::splat(1.0);
2973 o
2974 }
2975 /// Extract lane `i` as a scalar [`Tower3<K>`].
2976 #[inline]
2977 pub fn lane(&self, i: usize) -> Tower3<K> {
2978 let mut out = Tower3::<K>::zero();
2979 out.v = self.v.lane(i);
2980 for a in 0..K {
2981 out.g[a] = self.g[a].lane(i);
2982 for b in 0..K {
2983 out.h[a][b] = self.h[a][b].lane(i);
2984 for c in 0..K {
2985 out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2986 }
2987 }
2988 }
2989 out
2990 }
2991 /// Per-channel lane-wise `self + o`.
2992 #[inline]
2993 pub fn add(&self, o: &Self) -> Self {
2994 let mut out = *self;
2995 out.v = self.v.add(o.v);
2996 for i in 0..K {
2997 out.g[i] = self.g[i].add(o.g[i]);
2998 for j in 0..K {
2999 out.h[i][j] = self.h[i][j].add(o.h[i][j]);
3000 for k in 0..K {
3001 out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
3002 }
3003 }
3004 }
3005 out
3006 }
3007 /// Per-channel lane-wise `self - o`.
3008 #[inline]
3009 pub fn sub(&self, o: &Self) -> Self {
3010 let mut out = *self;
3011 out.v = self.v.sub(o.v);
3012 for i in 0..K {
3013 out.g[i] = self.g[i].sub(o.g[i]);
3014 for j in 0..K {
3015 out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
3016 for k in 0..K {
3017 out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
3018 }
3019 }
3020 }
3021 out
3022 }
3023 /// Multiply every channel by the plain scalar `s` (mirrors `Tower3::scale`).
3024 #[inline]
3025 pub fn scale(&self, s: f64) -> Self {
3026 let sl = L::splat(s);
3027 let mut out = *self;
3028 out.v = self.v.mul(sl);
3029 for i in 0..K {
3030 out.g[i] = self.g[i].mul(sl);
3031 for j in 0..K {
3032 out.h[i][j] = self.h[i][j].mul(sl);
3033 for k in 0..K {
3034 out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
3035 }
3036 }
3037 }
3038 out
3039 }
3040 /// Leibniz product `self · o`, term-for-term lift of [`Tower3::mul`].
3041 #[inline]
3042 pub fn mul(&self, o: &Self) -> Self {
3043 let a = self;
3044 let b = o;
3045 let mut out = Self::zero();
3046 out.v = a.v.mul(b.v);
3047 for i in 0..K {
3048 let mut acc = L::splat(0.0);
3049 acc = acc.add(a.v.mul(b.g[i]));
3050 acc = acc.add(a.g[i].mul(b.v));
3051 out.g[i] = acc;
3052 }
3053 for i in 0..K {
3054 for j in 0..K {
3055 let mut acc = L::splat(0.0);
3056 acc = acc.add(a.v.mul(b.h[i][j]));
3057 acc = acc.add(a.g[i].mul(b.g[j]));
3058 acc = acc.add(a.g[j].mul(b.g[i]));
3059 acc = acc.add(a.h[i][j].mul(b.v));
3060 out.h[i][j] = acc;
3061 }
3062 }
3063 for i in 0..K {
3064 for j in 0..K {
3065 for k in 0..K {
3066 let mut acc = L::splat(0.0);
3067 acc = acc.add(a.v.mul(b.t3[i][j][k]));
3068 acc = acc.add(a.g[i].mul(b.h[j][k]));
3069 acc = acc.add(a.g[j].mul(b.h[i][k]));
3070 acc = acc.add(a.h[i][j].mul(b.g[k]));
3071 acc = acc.add(a.g[k].mul(b.h[i][j]));
3072 acc = acc.add(a.h[i][k].mul(b.g[j]));
3073 acc = acc.add(a.h[j][k].mul(b.g[i]));
3074 acc = acc.add(a.t3[i][j][k].mul(b.v));
3075 out.t3[i][j][k] = acc;
3076 }
3077 }
3078 }
3079 out
3080 }
3081 /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
3082 /// [`Tower3::compose_unary`]. `d = [f, f′, f″, f‴]` packed per lane.
3083 #[inline]
3084 pub fn compose_unary(&self, d: [L; 4]) -> Self {
3085 let mut out = Self::zero();
3086 out.v = d[0];
3087 for i in 0..K {
3088 let mut acc = L::splat(0.0);
3089 acc = acc.add(d[1].mul(self.g[i]));
3090 out.g[i] = acc;
3091 }
3092 for i in 0..K {
3093 for j in 0..K {
3094 let mut acc = L::splat(0.0);
3095 acc = acc.add(d[1].mul(self.h[i][j]));
3096 acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
3097 out.h[i][j] = acc;
3098 }
3099 }
3100 for i in 0..K {
3101 for j in 0..K {
3102 for k in 0..K {
3103 let mut acc = L::splat(0.0);
3104 acc = acc.add(d[1].mul(self.t3[i][j][k]));
3105 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
3106 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
3107 acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
3108 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
3109 out.t3[i][j][k] = acc;
3110 }
3111 }
3112 }
3113 out
3114 }
3115 /// Compose with a unary special-function whose `[f64; 4]` derivative stack is
3116 /// built from the base value through `stack_fn`, evaluated PER LANE — the
3117 /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
3118 /// seam (the SIMD twin of [`Tower3::compose_unary_with`], order-≤3 sibling of
3119 /// [`Tower4Lane::compose_unary_with`]). The scalar `stack_fn` is run once per
3120 /// lane at that lane's own base value (via [`Lane::unary_with`]) and packed
3121 /// into `[L; 4]` for the existing per-lane [`Self::compose_unary`], so lane
3122 /// `i` is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`.
3123 #[inline]
3124 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
3125 self.compose_unary(self.v.unary_with(stack_fn))
3126 }
3127
3128 /// Single-active-slot fast path, term-for-term lift of
3129 /// [`Tower3::compose_unary_single_slot`].
3130 #[inline]
3131 pub fn compose_unary_single_slot(&self, d: [L; 4], slot: usize) -> Self {
3132 let mut out = Self::zero();
3133 let s = slot;
3134 let g = self.g[s];
3135 let h = self.h[s][s];
3136 let t3 = self.t3[s][s][s];
3137 out.v = d[0];
3138 out.g[s] = {
3139 let mut acc = L::splat(0.0);
3140 acc = acc.add(d[1].mul(g));
3141 acc
3142 };
3143 out.h[s][s] = {
3144 let mut acc = L::splat(0.0);
3145 acc = acc.add(d[1].mul(h));
3146 acc = acc.add(d[2].mul(g).mul(g));
3147 acc
3148 };
3149 out.t3[s][s][s] = {
3150 let mut acc = L::splat(0.0);
3151 acc = acc.add(d[1].mul(t3));
3152 acc = acc.add(d[2].mul(h).mul(g));
3153 acc = acc.add(d[2].mul(h).mul(g));
3154 acc = acc.add(d[2].mul(g).mul(h));
3155 acc = acc.add(d[3].mul(g).mul(g).mul(g));
3156 acc
3157 };
3158 out
3159 }
3160}
3161
3162#[cfg(test)]
3163mod batch_tests {
3164 use super::*;
3165
3166 struct Rng(u64);
3167 impl Rng {
3168 fn f(&mut self) -> f64 {
3169 self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
3170 ((self.0 >> 11) as f64 / (1u64 << 53) as f64) * 4.0 - 2.0
3171 }
3172 }
3173
3174 // Fill every channel of a scalar Tower4<K> with random data.
3175 fn rand_t4<const K: usize>(r: &mut Rng) -> Tower4<K> {
3176 let mut t = Tower4::<K>::zero();
3177 t.v = r.f();
3178 for i in 0..K {
3179 t.g[i] = r.f();
3180 for j in 0..K {
3181 t.h[i][j] = r.f();
3182 for k in 0..K {
3183 t.t3[i][j][k] = r.f();
3184 for l in 0..K {
3185 t.t4[i][j][k][l] = r.f();
3186 }
3187 }
3188 }
3189 }
3190 t
3191 }
3192 fn rand_t3<const K: usize>(r: &mut Rng) -> Tower3<K> {
3193 let mut t = Tower3::<K>::zero();
3194 t.v = r.f();
3195 for i in 0..K {
3196 t.g[i] = r.f();
3197 for j in 0..K {
3198 t.h[i][j] = r.f();
3199 for k in 0..K {
3200 t.t3[i][j][k] = r.f();
3201 }
3202 }
3203 }
3204 t
3205 }
3206 fn pack4_t4<const K: usize>(rows: &[Tower4<K>; 4]) -> Tower4Batch<K> {
3207 let mut b = Tower4Batch::<K>::zero();
3208 let lane = |f: &dyn Fn(&Tower4<K>) -> f64| {
3209 wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3210 };
3211 b.v = lane(&|t| t.v);
3212 for i in 0..K {
3213 b.g[i] = lane(&|t| t.g[i]);
3214 for j in 0..K {
3215 b.h[i][j] = lane(&|t| t.h[i][j]);
3216 for k in 0..K {
3217 b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3218 for l in 0..K {
3219 b.t4[i][j][k][l] = lane(&|t| t.t4[i][j][k][l]);
3220 }
3221 }
3222 }
3223 }
3224 b
3225 }
3226 fn pack4_t3<const K: usize>(rows: &[Tower3<K>; 4]) -> Tower3Batch<K> {
3227 let mut b = Tower3Batch::<K>::zero();
3228 let lane = |f: &dyn Fn(&Tower3<K>) -> f64| {
3229 wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3230 };
3231 b.v = lane(&|t| t.v);
3232 for i in 0..K {
3233 b.g[i] = lane(&|t| t.g[i]);
3234 for j in 0..K {
3235 b.h[i][j] = lane(&|t| t.h[i][j]);
3236 for k in 0..K {
3237 b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3238 }
3239 }
3240 }
3241 b
3242 }
3243 fn assert_t4_eq<const K: usize>(b: &Tower4<K>, s: &Tower4<K>, ctx: &str) {
3244 assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3245 for i in 0..K {
3246 assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3247 for j in 0..K {
3248 assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3249 for k in 0..K {
3250 assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3251 for l in 0..K {
3252 assert_eq!(b.t4[i][j][k][l].to_bits(), s.t4[i][j][k][l].to_bits(), "t4 {ctx}");
3253 }
3254 }
3255 }
3256 }
3257 }
3258 fn assert_t3_eq<const K: usize>(b: &Tower3<K>, s: &Tower3<K>, ctx: &str) {
3259 assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3260 for i in 0..K {
3261 assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3262 for j in 0..K {
3263 assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3264 for k in 0..K {
3265 assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3266 }
3267 }
3268 }
3269 }
3270
3271 // Run a representative op chain on 4 scalar rows and on the f64x4 batch,
3272 // then assert every channel of every lane is to_bits-identical.
3273 fn run4<const K: usize>(seed: u64, batches: usize) -> usize {
3274 let mut r = Rng(seed);
3275 let mut rows_checked = 0;
3276 for _ in 0..batches {
3277 let a: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3278 let b: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3279 let d: [[f64; 5]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3280 let dir: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3281 let dir2: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3282 let s = r.f();
3283
3284 // scalar per-row reference
3285 let scal: [Tower4<K>; 4] = std::array::from_fn(|rw| {
3286 let prod = a[rw].mul(&b[rw]);
3287 let comp = prod.compose_unary(d[rw]);
3288 let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3289 summed.compose_unary_single_slot(d[rw], 0)
3290 });
3291 let third: [[[f64; K]; K]; 4] =
3292 std::array::from_fn(|rw| a[rw].third_contracted(&dir[rw]));
3293 let fourth: [[[f64; K]; K]; 4] =
3294 std::array::from_fn(|rw| a[rw].fourth_contracted(&dir[rw], &dir2[rw]));
3295
3296 // batched f64x4
3297 let ab = pack4_t4(&a);
3298 let bb = pack4_t4(&b);
3299 let db: [wide::f64x4; 5] = std::array::from_fn(|c| {
3300 wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
3301 });
3302 let dirb: [wide::f64x4; K] = std::array::from_fn(|c| {
3303 wide::f64x4::new([dir[0][c], dir[1][c], dir[2][c], dir[3][c]])
3304 });
3305 let dir2b: [wide::f64x4; K] = std::array::from_fn(|c| {
3306 wide::f64x4::new([dir2[0][c], dir2[1][c], dir2[2][c], dir2[3][c]])
3307 });
3308 let prodb = ab.mul(&bb);
3309 let compb = prodb.compose_unary(db);
3310 let summedb = compb.add(&ab).sub(&bb).scale(s);
3311 let finalb = summedb.compose_unary_single_slot(db, 0);
3312 let thirdb = ab.third_contracted(&dirb);
3313 let fourthb = ab.fourth_contracted(&dirb, &dir2b);
3314
3315 for rw in 0..4 {
3316 assert_t4_eq(&finalb.lane(rw), &scal[rw], "t4-chain");
3317 for i in 0..K {
3318 for j in 0..K {
3319 assert_eq!(thirdb[i][j].lane(rw).to_bits(), third[rw][i][j].to_bits(), "third");
3320 assert_eq!(fourthb[i][j].lane(rw).to_bits(), fourth[rw][i][j].to_bits(), "fourth");
3321 }
3322 }
3323 rows_checked += 1;
3324 }
3325 }
3326 rows_checked
3327 }
3328 fn run3<const K: usize>(seed: u64, batches: usize) -> usize {
3329 let mut r = Rng(seed);
3330 let mut rows_checked = 0;
3331 for _ in 0..batches {
3332 let a: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3333 let b: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3334 let d: [[f64; 4]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3335 let s = r.f();
3336 let scal: [Tower3<K>; 4] = std::array::from_fn(|rw| {
3337 let prod = a[rw].mul(&b[rw]);
3338 let comp = prod.compose_unary(d[rw]);
3339 let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3340 summed.compose_unary_single_slot(d[rw], 0)
3341 });
3342 let ab = pack4_t3(&a);
3343 let bb = pack4_t3(&b);
3344 let db: [wide::f64x4; 4] = std::array::from_fn(|c| {
3345 wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
3346 });
3347 let prodb = ab.mul(&bb);
3348 let compb = prodb.compose_unary(db);
3349 let summedb = compb.add(&ab).sub(&bb).scale(s);
3350 let finalb = summedb.compose_unary_single_slot(db, 0);
3351 for rw in 0..4 {
3352 assert_t3_eq(&finalb.lane(rw), &scal[rw], "t3-chain");
3353 rows_checked += 1;
3354 }
3355 }
3356 rows_checked
3357 }
3358
3359 // A `Tower4Batch<9>` carries a `9⁴ = 6561`-entry `t4` tensor in 4-wide
3360 // lanes (≈210 KiB by value); the op chain keeps several live, which can
3361 // exceed a test thread's default stack. Run each width on a large-stack
3362 // thread so K=9 is exercised without a stack overflow.
3363 fn big_stack<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(f: F) -> R {
3364 std::thread::Builder::new()
3365 .stack_size(512 << 20)
3366 .spawn(f)
3367 .unwrap()
3368 .join()
3369 .unwrap()
3370 }
3371
3372 #[test]
3373 fn tower4_batch_lane_bit_identical() {
3374 let batches = 2000;
3375 let rows_checked = big_stack(move || run4::<2>(0x1111_2222_3333_4444, batches))
3376 + big_stack(move || run4::<3>(0x5555_6666_7777_8888, batches))
3377 + big_stack(move || run4::<4>(0x9999_aaaa_bbbb_cccc, batches))
3378 + big_stack(move || run4::<9>(0xdddd_eeee_ffff_0000, batches));
3379 // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3380 // worker threads against silently running zero comparisons.
3381 assert_eq!(rows_checked, 4 * batches * 4);
3382 }
3383
3384 #[test]
3385 fn tower3_batch_lane_bit_identical() {
3386 let batches = 2000;
3387 let rows_checked = big_stack(move || run3::<2>(0x0f0f_1e1e_2d2d_3c3c, batches))
3388 + big_stack(move || run3::<3>(0x4b4b_5a5a_6969_7878, batches))
3389 + big_stack(move || run3::<4>(0x8787_9696_a5a5_b4b4, batches))
3390 + big_stack(move || run3::<9>(0xc3c3_d2d2_e1e1_f0f0, batches));
3391 // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3392 // worker threads against silently running zero comparisons.
3393 assert_eq!(rows_checked, 4 * batches * 4);
3394 }
3395
3396 // ── compose_unary_with seam (generic-over-Lane compose) ─────────────────
3397 //
3398 // The seam lets a single-sourced row program build its special-function
3399 // STACK from the base value through a closure, so the SAME expression
3400 // instantiates at a scalar tower (one base) AND a batch tower (four distinct
3401 // per-lane bases). These oracles pin both arms `to_bits`.
3402
3403 /// A base-value-dependent `[f64; 5]` derivative stack (finite for finite `u`),
3404 /// standing in for a family's hand-certified special-function stack. `stack4`
3405 /// is its order-≤3 truncation.
3406 fn seam_stack5(u: f64) -> [f64; 5] {
3407 [u.sin(), u.cos(), (2.0 * u).sin(), (0.5 * u).cos(), u * u - 0.3]
3408 }
3409 fn seam_stack4(u: f64) -> [f64; 4] {
3410 let s = seam_stack5(u);
3411 [s[0], s[1], s[2], s[3]]
3412 }
3413
3414 /// Force a distinct / edge per-lane base value (signed zeros included).
3415 fn seam_edge_base(r: &mut Rng, which: usize) -> f64 {
3416 match which {
3417 0 => -0.0,
3418 1 => 0.0,
3419 2 => r.f(),
3420 _ => r.f() + 3.0,
3421 }
3422 }
3423
3424 /// (a) scalar arm: `Tower4::compose_unary_with(f)` is `to_bits`-identical to
3425 /// the explicit `compose_unary(f(value))` on every channel.
3426 fn scalar_seam_t4<const K: usize>(seed: u64, n: usize) -> usize {
3427 let mut r = Rng(seed);
3428 for _ in 0..n {
3429 let mut t = rand_t4::<K>(&mut r);
3430 t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3431 assert_t4_eq(
3432 &t.compose_unary_with(seam_stack5),
3433 &t.compose_unary(seam_stack5(t.v)),
3434 "scalar t4 seam",
3435 );
3436 }
3437 n
3438 }
3439 fn scalar_seam_t3<const K: usize>(seed: u64, n: usize) -> usize {
3440 let mut r = Rng(seed);
3441 for _ in 0..n {
3442 let mut t = rand_t3::<K>(&mut r);
3443 t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3444 assert_t3_eq(
3445 &t.compose_unary_with(seam_stack4),
3446 &t.compose_unary(seam_stack4(t.v)),
3447 "scalar t3 seam",
3448 );
3449 }
3450 n
3451 }
3452
3453 /// (b) lane arm: `Tower4Lane::compose_unary_with` lane `i` is
3454 /// `to_bits`-identical to the scalar `Tower4::compose_unary_with` on row `i`,
3455 /// with the four lanes carrying DISTINCT base values (signed zeros included),
3456 /// so a buggy impl reusing one lane's base would fail.
3457 fn lane_seam_t4<const K: usize>(seed: u64, batches: usize) -> usize {
3458 let mut r = Rng(seed);
3459 let mut verified = 0usize;
3460 for _ in 0..batches {
3461 let mut rows: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3462 for (rw, row) in rows.iter_mut().enumerate() {
3463 row.v = seam_edge_base(&mut r, rw);
3464 }
3465 let batch_out = pack4_t4(&rows).compose_unary_with(seam_stack5);
3466 for (rw, row) in rows.iter().enumerate() {
3467 assert_t4_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack5), "lane t4 seam");
3468 verified += 1;
3469 }
3470 }
3471 verified
3472 }
3473 fn lane_seam_t3<const K: usize>(seed: u64, batches: usize) -> usize {
3474 let mut r = Rng(seed);
3475 let mut verified = 0usize;
3476 for _ in 0..batches {
3477 let mut rows: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3478 for (rw, row) in rows.iter_mut().enumerate() {
3479 row.v = seam_edge_base(&mut r, rw);
3480 }
3481 let batch_out = pack4_t3(&rows).compose_unary_with(seam_stack4);
3482 for (rw, row) in rows.iter().enumerate() {
3483 assert_t3_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack4), "lane t3 seam");
3484 verified += 1;
3485 }
3486 }
3487 verified
3488 }
3489
3490 #[test]
3491 fn compose_unary_with_scalar_bit_identical() {
3492 let n = 1100;
3493 let total = scalar_seam_t4::<2>(0x2200_0001, n)
3494 + scalar_seam_t4::<3>(0x2200_0002, n)
3495 + scalar_seam_t4::<4>(0x2200_0003, n)
3496 + big_stack(move || scalar_seam_t4::<9>(0x2200_0004, n))
3497 + scalar_seam_t3::<2>(0x3300_0001, n)
3498 + scalar_seam_t3::<3>(0x3300_0002, n)
3499 + scalar_seam_t3::<4>(0x3300_0003, n)
3500 + big_stack(move || scalar_seam_t3::<9>(0x3300_0004, n));
3501 // 8 arms × 1100 = 8800 ≥ 4000 inputs.
3502 assert_eq!(total, 8 * n);
3503 }
3504
3505 #[test]
3506 fn compose_unary_with_lane_matches_scalar() {
3507 let b = 600;
3508 let total = lane_seam_t4::<2>(0x4400_0001, b)
3509 + lane_seam_t4::<3>(0x4400_0002, b)
3510 + lane_seam_t4::<4>(0x4400_0003, b)
3511 + big_stack(move || lane_seam_t4::<9>(0x4400_0004, b))
3512 + lane_seam_t3::<2>(0x5500_0001, b)
3513 + lane_seam_t3::<3>(0x5500_0002, b)
3514 + lane_seam_t3::<4>(0x5500_0003, b)
3515 + big_stack(move || lane_seam_t3::<9>(0x5500_0004, b));
3516 // 8 arms × 600 = 4800 batches ≥ 2000; each verifies 4 lanes (19200 checks).
3517 assert_eq!(total, 8 * b * 4);
3518 }
3519}
3520
3521#[cfg(test)]
3522mod tests {
3523 use super::*;
3524
3525 /// `Tower3<K>` must be bit-identical to `Tower4<K>` on every channel it
3526 /// carries (value, gradient, Hessian, third derivatives). The order-≤3
3527 /// Leibniz / Faà-di-Bruno terms read only order-≤3 inner channels, so
3528 /// dropping the fourth tensor cannot perturb them. Exercises products
3529 /// (Leibniz cross-terms), unary composition, scaling, and addition — the
3530 /// same operations the survival location-scale `nll_index_tower` composes —
3531 /// across all mixed partials, not just the diagonal entries that kernel reads.
3532 #[test]
3533 fn tower3_matches_tower4_through_third_order() {
3534 let s_a: [f64; 5] = [
3535 0.3_f64.sin(),
3536 0.3_f64.cos(),
3537 -0.3_f64.sin(),
3538 -0.3_f64.cos(),
3539 0.3_f64.sin(),
3540 ];
3541 let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
3542 let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];
3543
3544 let a4 = Tower4::<3>::variable(0.4, 0);
3545 let b4 = Tower4::<3>::variable(-0.7, 1);
3546 let c4 = Tower4::<3>::variable(0.9, 2);
3547 let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
3548 + a4.mul(&c4).scale(-0.7)
3549 + b4.compose_unary(s_b).scale(0.25);
3550
3551 let a3 = Tower3::<3>::variable(0.4, 0);
3552 let b3 = Tower3::<3>::variable(-0.7, 1);
3553 let c3 = Tower3::<3>::variable(0.9, 2);
3554 let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
3555 + a3.mul(&c3).scale(-0.7)
3556 + b3.compose_unary(s4(s_b)).scale(0.25);
3557
3558 assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
3559 for i in 0..3 {
3560 assert_eq!(
3561 prog3.g[i].to_bits(),
3562 prog4.g[i].to_bits(),
3563 "g[{i}] mismatch"
3564 );
3565 for j in 0..3 {
3566 assert_eq!(
3567 prog3.h[i][j].to_bits(),
3568 prog4.h[i][j].to_bits(),
3569 "h[{i}][{j}] mismatch"
3570 );
3571 for k in 0..3 {
3572 assert_eq!(
3573 prog3.t3[i][j][k].to_bits(),
3574 prog4.t3[i][j][k].to_bits(),
3575 "t3[{i}][{j}][{k}] mismatch"
3576 );
3577 }
3578 }
3579 }
3580 }
3581
3582 /// Binomial-logit row NLL, K=1: ℓ(η) = ln(1 + e^η) − y·η.
3583 /// The entire tower has textbook closed forms in μ = σ(η); this test
3584 /// pins the algebra (exp, ln, scalar mixes, Leibniz/Faà di Bruno) to
3585 /// analytic truth at near-machine precision.
3586 struct LogitProgram {
3587 eta: Vec<f64>,
3588 y: Vec<f64>,
3589 }
3590
3591 impl RowNllProgram<1> for LogitProgram {
3592 fn n_rows(&self) -> usize {
3593 self.eta.len()
3594 }
3595 fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
3596 Ok([self.eta[row]])
3597 }
3598 fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
3599 let eta = p[0];
3600 Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
3601 }
3602 }
3603
3604 #[test]
3605 fn logit_tower_matches_closed_forms() {
3606 let prog = LogitProgram {
3607 eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
3608 y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
3609 };
3610 for row in 0..prog.n_rows() {
3611 let t = evaluate_program(&prog, row).expect("logit program");
3612 let eta = prog.eta[row];
3613 let y = prog.y[row];
3614 let mu = 1.0 / (1.0 + (-eta).exp());
3615 let w = mu * (1.0 - mu);
3616 let expect = [
3617 (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
3618 (t.g[0], mu - y, "grad"),
3619 (t.h[0][0], w, "hess"),
3620 (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
3621 (
3622 t.t4[0][0][0][0],
3623 w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
3624 "fourth",
3625 ),
3626 ];
3627 for (got, want, label) in expect {
3628 assert!(
3629 (got - want).abs() <= 1e-12 * want.abs().max(1.0),
3630 "row {row} {label}: got {got:+.15e} want {want:+.15e}"
3631 );
3632 }
3633 }
3634 }
3635
3636 fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
3637 let diff = (got - want).abs();
3638 assert!(
3639 diff <= rel_tol * want.abs().max(1.0),
3640 "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
3641 );
3642 }
3643
3644 #[test]
3645 fn gamma_special_function_stacks_match_reference_values() {
3646 const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
3647 let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
3648 let cases = [
3649 (
3650 "x=0.1",
3651 0.1,
3652 -10.423_754_940_411_076,
3653 101.433_299_150_792_75,
3654 ),
3655 (
3656 "x=0.5",
3657 0.5,
3658 -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
3659 pi_sq / 2.0,
3660 ),
3661 ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
3662 (
3663 "x=2.5",
3664 2.5,
3665 -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
3666 pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
3667 ),
3668 (
3669 "x=50",
3670 50.0,
3671 3.901_989_673_427_892,
3672 0.020_201_333_226_697_128,
3673 ),
3674 ];
3675
3676 for (label, x, digamma_ref, trigamma_ref) in cases {
3677 let ln_gamma_stack = ln_gamma_derivative_stack(x);
3678 let digamma_stack = digamma_derivative_stack(x);
3679 let trigamma_stack = trigamma_derivative_stack(x);
3680 assert_close(
3681 &format!("{label} ln_gamma_stack digamma"),
3682 ln_gamma_stack[1],
3683 digamma_ref,
3684 1e-13,
3685 );
3686 assert_close(
3687 &format!("{label} digamma value"),
3688 digamma_stack[0],
3689 digamma_ref,
3690 1e-13,
3691 );
3692 assert_close(
3693 &format!("{label} ln_gamma_stack trigamma"),
3694 ln_gamma_stack[2],
3695 trigamma_ref,
3696 1e-13,
3697 );
3698 assert_close(
3699 &format!("{label} digamma_stack trigamma"),
3700 digamma_stack[1],
3701 trigamma_ref,
3702 1e-13,
3703 );
3704 assert_close(
3705 &format!("{label} trigamma value"),
3706 trigamma_stack[0],
3707 trigamma_ref,
3708 1e-13,
3709 );
3710 }
3711 }
3712
3713 #[test]
3714 fn gamma_special_function_stacks_obey_recurrences() {
3715 for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
3716 let digamma_x = digamma_derivative_stack(x)[0];
3717 let digamma_next = digamma_derivative_stack(x + 1.0)[0];
3718 let trigamma_x = trigamma_derivative_stack(x)[0];
3719 let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
3720 assert_close(
3721 &format!("digamma recurrence x={x}"),
3722 digamma_next,
3723 digamma_x + 1.0 / x,
3724 1e-13,
3725 );
3726 assert_close(
3727 &format!("trigamma recurrence x={x}"),
3728 trigamma_next,
3729 trigamma_x - 1.0 / (x * x),
3730 1e-13,
3731 );
3732 }
3733 }
3734
3735 /// Gaussian location-scale row NLL, K=2 primaries (η, s = log σ):
3736 /// ℓ = s + ½ e^{−2s} (y − η)². Mixed cross blocks — the #736 fragility
3737 /// shape — all have one-line closed forms here.
3738 struct LocScaleProgram {
3739 eta: Vec<f64>,
3740 s: Vec<f64>,
3741 y: Vec<f64>,
3742 }
3743
3744 impl RowNllProgram<2> for LocScaleProgram {
3745 fn n_rows(&self) -> usize {
3746 self.eta.len()
3747 }
3748 fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
3749 Ok([self.eta[row], self.s[row]])
3750 }
3751 fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
3752 let r = -(p[0] - self.y[row]);
3753 Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
3754 }
3755 }
3756
3757 #[test]
3758 fn locscale_tower_matches_closed_forms_including_cross_blocks() {
3759 let prog = LocScaleProgram {
3760 eta: vec![0.3, -1.1, 2.0],
3761 s: vec![-0.5, 0.2, 0.8],
3762 y: vec![1.0, -2.0, 2.5],
3763 };
3764 let tol = 1e-12;
3765 for row in 0..prog.n_rows() {
3766 let t = evaluate_program(&prog, row).expect("locscale program");
3767 let r = prog.y[row] - prog.eta[row];
3768 let w = (-2.0 * prog.s[row]).exp();
3769 // (η, s) = indices (0, 1).
3770 let truth_g = [-w * r, 1.0 - w * r * r];
3771 let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
3772 // Third tensor: distinct-entry closed forms.
3773 // ∂ηηη = 0, ∂ηηs = −2w, ∂ηss = −4wr, ∂sss = −4wr².
3774 let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
3775 match a + b + c {
3776 0 => 0.0,
3777 1 => -2.0 * w,
3778 2 => -4.0 * w * r,
3779 _ => -4.0 * w * r * r,
3780 }
3781 };
3782 // Fourth tensor: ∂ηηηη = 0, ∂ηηηs = 0? No: d/ds(∂ηηη)=0 ✓;
3783 // ∂ηηss = 4w, ∂ηsss = 8wr, ∂ssss = 8wr².
3784 let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
3785 match a + b + c + d {
3786 0 | 1 => 0.0,
3787 2 => 4.0 * w,
3788 3 => 8.0 * w * r,
3789 _ => 8.0 * w * r * r,
3790 }
3791 };
3792 for a in 0..2 {
3793 assert!(
3794 (t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
3795 "row {row} grad[{a}]"
3796 );
3797 for b in 0..2 {
3798 assert!(
3799 (t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
3800 "row {row} hess[{a}][{b}]: got {} want {}",
3801 t.h[a][b],
3802 truth_h[a][b]
3803 );
3804 for c in 0..2 {
3805 assert!(
3806 (t.t3[a][b][c] - t3_truth(a, b, c)).abs()
3807 <= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3808 "row {row} t3[{a}][{b}][{c}]: got {} want {}",
3809 t.t3[a][b][c],
3810 t3_truth(a, b, c)
3811 );
3812 for d in 0..2 {
3813 assert!(
3814 (t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
3815 <= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3816 "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
3817 t.t4[a][b][c][d],
3818 t4_truth(a, b, c, d)
3819 );
3820 }
3821 }
3822 }
3823 }
3824 // The derived trait-surface helpers agree with direct contraction.
3825 let dir = [0.7, -1.3];
3826 let third = derived_third_contracted(&prog, row, &dir).expect("third");
3827 for a in 0..2 {
3828 for b in 0..2 {
3829 let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
3830 assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
3831 }
3832 }
3833 }
3834 }
3835
3836 /// FD cross-check on a deliberately gnarly composition (div, sqrt,
3837 /// powf, nested exp/ln) in K=3, where no closed form is consulted:
3838 /// every tower channel is checked against central finite differences
3839 /// of the channel one order below — value→grad, grad→hess, hess→t3,
3840 /// t3→t4 — so each order is independently anchored.
3841 ///
3842 /// The program carries a per-row primary fixture plus a per-row offset
3843 /// `tau[row]` that enters the loss as a constant, so `row` genuinely
3844 /// drives both the seed point and the evaluated expression.
3845 struct GnarlyProgram {
3846 primaries: Vec<[f64; 3]>,
3847 tau: Vec<f64>,
3848 }
3849
3850 impl GnarlyProgram {
3851 fn fixture() -> Self {
3852 Self {
3853 primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
3854 tau: vec![0.15, -0.35, 0.5],
3855 }
3856 }
3857 }
3858
3859 impl RowNllProgram<3> for GnarlyProgram {
3860 fn n_rows(&self) -> usize {
3861 self.primaries.len()
3862 }
3863 fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3864 self.primaries
3865 .get(row)
3866 .copied()
3867 .ok_or_else(|| format!("gnarly: row {row} out of range"))
3868 }
3869 fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3870 let tau = *self
3871 .tau
3872 .get(row)
3873 .ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
3874 let a = (p[0] * p[1]).exp();
3875 let b = (p[2] * p[2] + 1.0).sqrt();
3876 let c = (a + b + tau).ln();
3877 let d = (p[1] * 0.5 + 2.0).powf(1.7);
3878 Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
3879 }
3880 }
3881
3882 /// Evaluate the gnarly program's tower at an ARBITRARY seed point for
3883 /// `row` (used to drive central differences off the fixture grid),
3884 /// while keeping `row`'s per-row data (`tau`) in the loss.
3885 fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
3886 struct At<'a> {
3887 base: &'a GnarlyProgram,
3888 row: usize,
3889 p: [f64; 3],
3890 }
3891 impl RowNllProgram<3> for At<'_> {
3892 fn n_rows(&self) -> usize {
3893 1
3894 }
3895 fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3896 if row != 0 {
3897 return Err(format!("gnarly-at: row {row} out of range"));
3898 }
3899 Ok(self.p)
3900 }
3901 fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3902 if eval_row != 0 {
3903 return Err(format!("gnarly-at: eval row {eval_row} out of range"));
3904 }
3905 self.base.row_nll(self.row, vars)
3906 }
3907 }
3908 evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
3909 }
3910
3911 #[test]
3912 fn gnarly_tower_is_fd_consistent_order_by_order() {
3913 let prog = GnarlyProgram::fixture();
3914 for row in 0..prog.n_rows() {
3915 let base = prog.primaries(row).expect("primaries");
3916 let t = gnarly_tower_at(&prog, row, base);
3917 let h_step = 1e-5;
3918 let tol = 1e-6;
3919 for c in 0..3 {
3920 let mut up = base;
3921 let mut dn = base;
3922 up[c] += h_step;
3923 dn[c] -= h_step;
3924 let t_up = gnarly_tower_at(&prog, row, up);
3925 let t_dn = gnarly_tower_at(&prog, row, dn);
3926 // value → gradient.
3927 let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
3928 assert!(
3929 (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
3930 "grad[{c}]: analytic {} fd {}",
3931 t.g[c],
3932 fd_g
3933 );
3934 for a in 0..3 {
3935 // gradient → Hessian.
3936 let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
3937 assert!(
3938 (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
3939 "hess[{a}][{c}]: analytic {} fd {}",
3940 t.h[a][c],
3941 fd_h
3942 );
3943 for b in 0..3 {
3944 // Hessian → third.
3945 let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
3946 assert!(
3947 (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
3948 "t3[{a}][{b}][{c}]: analytic {} fd {}",
3949 t.t3[a][b][c],
3950 fd_t3
3951 );
3952 for d in 0..3 {
3953 // third → fourth.
3954 let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
3955 assert!(
3956 (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
3957 "t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
3958 t.t4[a][b][d][c],
3959 fd_t4
3960 );
3961 }
3962 }
3963 }
3964 }
3965 }
3966 }
3967
3968 /// `implicit_solve` reproduces the true implicit function `a(θ)` of a
3969 /// constraint `F(a, θ) = 0` to fourth order. The constraint here is the
3970 /// smooth, strictly-`a`-monotone
3971 /// F(a, θ) = a + θ₀·a² + θ₁·exp(a) − c
3972 /// whose root `a(θ)` is re-solved by scalar Newton at perturbed θ as the
3973 /// independent finite-difference oracle. Mirrors the survival flex
3974 /// calibration solve (one implicit intercept over the primaries) without
3975 /// any survival machinery, so a failure localises to the combinator.
3976 #[test]
3977 fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
3978 const C: f64 = 1.7;
3979 // The scalar constraint as a plain f64 closure (the production root
3980 // finder analogue) and its tower form in (a, θ₀, θ₁).
3981 let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
3982 let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
3983 let solve = |th: [f64; 2]| -> f64 {
3984 let mut a = 0.0_f64;
3985 for _ in 0..100 {
3986 let r = f_scalar(a, th);
3987 if r.abs() < 1e-14 {
3988 break;
3989 }
3990 a -= r / f_da(a, th);
3991 }
3992 a
3993 };
3994 // Tower constraint over K1 = 3 vars: slot 0 = a, slots 1,2 = θ₀, θ₁.
3995 let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
3996 let a = Tower4::<3>::variable(a0, 0);
3997 let t0 = Tower4::<3>::variable(th[0], 1);
3998 let t1 = Tower4::<3>::variable(th[1], 2);
3999 a + t0 * a.mul(&a) + t1 * a.exp() - C
4000 };
4001
4002 let th0 = [0.35, 0.2];
4003 let a0 = solve(th0);
4004 let f = f_tower(a0, th0);
4005 // Residual at the solved point is ~0 (the combinator tolerates the
4006 // production Newton residual; here it is machine-zero).
4007 assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
4008 let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");
4009
4010 // FD oracle: central differences of the scalar re-solve. Each order is
4011 // built from the previous via one more central difference, exactly the
4012 // gnarly order-by-order ladder.
4013 let h = 1e-4;
4014 let tol = 1e-5;
4015 let re = |th: [f64; 2]| solve(th);
4016 for i in 0..2 {
4017 let mut up = th0;
4018 let mut dn = th0;
4019 up[i] += h;
4020 dn[i] -= h;
4021 let fd_g = (re(up) - re(dn)) / (2.0 * h);
4022 assert!(
4023 (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4024 "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
4025 a_tower.g[i],
4026 fd_g
4027 );
4028 // second order: FD of the analytic gradient component would re-use
4029 // the combinator; instead difference a SCALAR gradient computed by
4030 // a nested re-solve so the oracle stays production-independent.
4031 let grad_at = |th: [f64; 2], j: usize| -> f64 {
4032 let mut up = th;
4033 let mut dn = th;
4034 up[j] += h;
4035 dn[j] -= h;
4036 (re(up) - re(dn)) / (2.0 * h)
4037 };
4038 for j in 0..2 {
4039 let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4040 assert!(
4041 (a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4042 "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
4043 a_tower.h[i][j],
4044 fd_h
4045 );
4046 }
4047 }
4048 }
4049
4050 /// `implicit_solve` degenerates to `a_θ = −F_θ / F_a` at first order on a
4051 /// linear-in-a constraint, and the second-order tensor matches the
4052 /// textbook IFT formula `a_uv = −(F_uv + F_au a_v + F_av a_u + F_aa a_u a_v)/F_a`.
4053 /// This pins the recursion against the hand-coded first_full.rs formula it
4054 /// replaces, independent of any FD step.
4055 #[test]
4056 fn implicit_solve_matches_textbook_ift_recursion() {
4057 // A constraint with non-trivial F_a, F_aa, F_au, F_uv all present.
4058 let a0 = 0.4_f64;
4059 let th = [0.25_f64, -0.15_f64];
4060 let f = {
4061 let a = Tower4::<3>::variable(a0, 0);
4062 let t0 = Tower4::<3>::variable(th[0], 1);
4063 let t1 = Tower4::<3>::variable(th[1], 2);
4064 // F = a·(1 + θ₀) + θ₁·a² + θ₀·θ₁ − 0.4385. The constant is chosen so
4065 // F(a0, θ0) = 0 exactly at a0 = 0.4, θ = [0.25, −0.15]:
4066 // 0.4·1.25 + (−0.15)·0.16 + 0.25·(−0.15) = 0.4385.
4067 // implicit_solve requires a genuine root; at the root the level-set
4068 // and root-curve derivatives coincide, so the textbook-IFT
4069 // assertions below are unaffected.
4070 a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
4071 };
4072 let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
4073 let f_a = f.g[0];
4074 // First order: a_u = −F_u / F_a.
4075 for u in 0..2 {
4076 let want = -f.g[u + 1] / f_a;
4077 assert!(
4078 (a_t.g[u] - want).abs() < 1e-12,
4079 "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
4080 a_t.g[u],
4081 want
4082 );
4083 }
4084 // Second order textbook IFT (indices shifted by 1 for the a-slot).
4085 for u in 0..2 {
4086 for v in 0..2 {
4087 let f_uv = f.h[u + 1][v + 1];
4088 let f_au = f.h[0][u + 1];
4089 let f_av = f.h[0][v + 1];
4090 let f_aa = f.h[0][0];
4091 let want =
4092 -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
4093 assert!(
4094 (a_t.h[u][v] - want).abs() < 1e-12,
4095 "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
4096 a_t.h[u][v],
4097 want
4098 );
4099 }
4100 }
4101 }
4102
4103 /// The moving-boundary flux tower reproduces every θ-derivative of a
4104 /// moving-limit integral, INCLUDING the second-order `B·z_uv` term the
4105 /// hand-written flux dropped (#932). The edge `z_R(θ) = θ₀ + θ₁²` has a
4106 /// genuinely nonzero `∂²z_R/∂θ₁² = 2`, so a combinator that omitted
4107 /// `B·z_uv` would miss the [1][1] Hessian entry. Truth = central FD of the
4108 /// closed-form integral `∫₀^{z_R} e^{−z²/2} dz = √(π/2)·erf(z_R/√2)`.
4109 #[test]
4110 fn moving_boundary_flux_carries_b_zuv_term() {
4111 use std::f64::consts::PI;
4112 let b = |z: f64| (-0.5 * z * z).exp(); // integrand B(z)
4113 // Antiderivative-based closed-form integral I(z_R) = ∫₀^{z_R} B dz.
4114 let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
4115 let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
4116 let th0 = [0.7_f64, 0.5_f64];
4117
4118 // Edge tower z_R(θ) over K=2 primaries: value + exact derivatives.
4119 let mut z_edge = Tower4::<2>::constant(z_r(th0));
4120 z_edge.g[0] = 1.0; // ∂z_R/∂θ₀ = 1
4121 z_edge.g[1] = 2.0 * th0[1]; // ∂z_R/∂θ₁ = 2θ₁
4122 z_edge.h[1][1] = 2.0; // ∂²z_R/∂θ₁² = 2 (the z_uv the old flux dropped)
4123
4124 // Integrand stack [B, B′, B″, B‴] at z₀: B′=−z·B, B″=(z²−1)·B,
4125 // B‴=(3z−z³)·B.
4126 let z0 = z_edge.v;
4127 let b0 = b(z0);
4128 let stack = [
4129 b0,
4130 -z0 * b0,
4131 (z0 * z0 - 1.0) * b0,
4132 (3.0 * z0 - z0 * z0 * z0) * b0,
4133 ];
4134 let flux = moving_limit_boundary_tower(&z_edge, stack);
4135
4136 // FD truth of the integral's derivatives.
4137 let h = 1e-4;
4138 let tol = 1e-6;
4139 for i in 0..2 {
4140 let mut up = th0;
4141 let mut dn = th0;
4142 up[i] += h;
4143 dn[i] -= h;
4144 let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
4145 assert!(
4146 (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4147 "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
4148 flux.g[i],
4149 fd_g
4150 );
4151 }
4152 // The decisive entry: ∂²I/∂θ₁² = B′·(z_θ₁)² + B·z_θ₁θ₁. With z_θ₁=2θ₁=1
4153 // and z_θ₁θ₁=2, the B·z_uv contribution is B(z₀)·2 — omitting it would
4154 // leave the [1][1] entry short by exactly 2·B(z₀).
4155 let grad1_at = |th: [f64; 2]| -> f64 {
4156 let mut up = th;
4157 let mut dn = th;
4158 up[1] += h;
4159 dn[1] -= h;
4160 (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
4161 };
4162 let mut up = th0;
4163 let mut dn = th0;
4164 up[1] += h;
4165 dn[1] -= h;
4166 let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
4167 assert!(
4168 (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
4169 "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
4170 flux.h[1][1],
4171 fd_h11
4172 );
4173 // Explicit witness that the B·z_uv term is present and material:
4174 // analytic h[1][1] minus the pure (z_u)² part must equal B·z_uv = 2·B₀.
4175 let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
4176 let b_zuv = flux.h[1][1] - pure_zu2;
4177 assert!(
4178 (b_zuv - b0 * 2.0).abs() < 1e-10,
4179 "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
4180 b_zuv,
4181 b0 * 2.0
4182 );
4183 }
4184
4185 /// `moving_limit_boundary_tower_theta_integrand` reproduces the marginal-slope
4186 /// flex boundary closure for a θ-DEPENDENT integrand `G(z;θ)` — the case the
4187 /// plain `moving_limit_boundary_tower` cannot express, and the case the
4188 /// survival directional/bidirectional paths hand-assemble term-by-term
4189 /// (`G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`, with the directional path
4190 /// dropping `G·z_uv`). Two independent oracles:
4191 /// (1) closed-form: the boundary flux of `∫ G dz` is exactly
4192 /// `Φ(z_edge(θ);θ) − Φ(z₀;θ)` (Φ = z-antiderivative of G), whose θ
4193 /// derivatives we take by central FD of the closed form — no jet code.
4194 /// (2) the explicit second-order hand closure, including the `G·z_uv` term,
4195 /// built from the integrand's own (z,θ) partials.
4196 /// G(z;θ) = exp(z·θ₀) is genuinely θ-dependent (G_θ₀ = z·e^{zθ₀} ≠ 0), and
4197 /// the edge z_edge = z₀ + θ₀ + θ₁² has a real z_uv = ∂²/∂θ₁² = 2, so a
4198 /// combinator that dropped either the integrand-θ terms or `G·z_uv` would
4199 /// miss a Hessian entry.
4200 #[test]
4201 fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
4202 // G(z;θ) = exp(z·θ₀); Φ(z;θ) = ∫₀^z G = (e^{zθ₀} − 1)/θ₀.
4203 let g = |z: f64, t0: f64| (z * t0).exp();
4204 let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
4205 let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
4206 let th0 = [0.4_f64, 0.5_f64];
4207 let z0 = z_r(th0);
4208
4209 // Edge tower z_edge(θ) over K=2 primaries.
4210 let mut z_edge = Tower4::<2>::constant(z0);
4211 z_edge.g[0] = 1.0; // ∂z/∂θ₀
4212 z_edge.g[1] = 2.0 * th0[1]; // ∂z/∂θ₁
4213 z_edge.h[1][1] = 2.0; // ∂²z/∂θ₁² (the z_uv the directional path drops)
4214
4215 // Φ's mixed (z, θ) jet over K1 = 3 vars: slot 0 = z, slots 1,2 = θ₀,θ₁.
4216 // Built ONCE in tower arithmetic so every (z^i θ^j) partial is exact.
4217 let z_var = Tower4::<3>::variable(z0, 0);
4218 let t0_var = Tower4::<3>::variable(th0[0], 1);
4219 // θ₁ does not enter G/Φ here, but seed it so the jet carries the full
4220 // K1 frame (its Φ-derivatives are zero; the z_edge chain supplies all θ₁
4221 // motion through slot 0).
4222 let _t1_var = Tower4::<3>::variable(th0[1], 2);
4223 let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
4224 // Sanity: slot-0 first derivative of Φ IS G(z₀;θ₀).
4225 assert!(
4226 (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
4227 "Φ_z {:+.8e} != G {:+.8e}",
4228 phi_jet.g[0],
4229 g(z0, th0[0])
4230 );
4231
4232 let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
4233
4234 // Value channel is 0 by construction (boundary, not the integral itself).
4235 assert!(
4236 flux.v.abs() < 1e-12,
4237 "boundary value channel {:+.3e}",
4238 flux.v
4239 );
4240
4241 // Oracle (1): central FD of the closed-form boundary flux
4242 // Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ) (z₀ FROZEN at the base edge).
4243 let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
4244 let h = 1e-4;
4245 let tol = 1e-6;
4246 for i in 0..2 {
4247 let mut up = th0;
4248 let mut dn = th0;
4249 up[i] += h;
4250 dn[i] -= h;
4251 let fd_g = (bnd(up) - bnd(dn)) / (2.0 * h);
4252 assert!(
4253 (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4254 "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
4255 flux.g[i],
4256 fd_g
4257 );
4258 }
4259 let grad_at = |th: [f64; 2], j: usize| -> f64 {
4260 let mut up = th;
4261 let mut dn = th;
4262 up[j] += h;
4263 dn[j] -= h;
4264 (bnd(up) - bnd(dn)) / (2.0 * h)
4265 };
4266 for i in 0..2 {
4267 for j in 0..2 {
4268 let mut up = th0;
4269 let mut dn = th0;
4270 up[i] += h;
4271 dn[i] -= h;
4272 let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4273 assert!(
4274 (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4275 "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
4276 flux.h[i][j],
4277 fd_h
4278 );
4279 }
4280 }
4281
4282 // Oracle (2): the explicit second-order hand closure, term by term —
4283 // `G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`. Read G's partials at the
4284 // base point directly (no jet): G = e^{zθ₀}, G_z = θ₀·G, G_θ₀ = z·G,
4285 // G_θ₁ = 0.
4286 let gg = g(z0, th0[0]);
4287 let g_z = th0[0] * gg;
4288 let g_theta = [z0 * gg, 0.0]; // [G_θ₀, G_θ₁]
4289 for i in 0..2 {
4290 for j in 0..2 {
4291 let z_u = z_edge.g[i];
4292 let z_v = z_edge.g[j];
4293 let z_uv = z_edge.h[i][j];
4294 let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
4295 assert!(
4296 (flux.h[i][j] - hand).abs() < 1e-9,
4297 "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
4298 flux.h[i][j],
4299 hand
4300 );
4301 }
4302 }
4303
4304 // Decisive: the `G·z_uv` term the directional path DROPS is present and
4305 // material in the [1][1] entry (z_uv = 2 there).
4306 let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
4307 let g_zuv = flux.h[1][1] - pure_no_zuv;
4308 assert!(
4309 (g_zuv - gg * 2.0).abs() < 1e-9,
4310 "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
4311 g_zuv,
4312 gg * 2.0
4313 );
4314 }
4315
4316 /// The survival crossing-edge position tower `z_edge = (τ − a(θ)) / b`,
4317 /// `b = exp(g)`, built from the intercept tower `a(θ)` (here a stand-in)
4318 /// and the seeded slope `g`, reproduces taylor-jet's exact hand-path
4319 /// boundary-velocity formulas:
4320 /// z_u = −(a_u + [u==g]·z) / b
4321 /// z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
4322 /// This pins the bridge between `implicit_solve` and
4323 /// `cell_moving_boundary_flux_tower`: the boundary jet that the production
4324 /// flex path hand-codes (and dropped `z_uv` from) is exactly `∂²` of this
4325 /// tower. K=3 reduced frame: slot 0 = a-axis carrier (an arbitrary smooth
4326 /// a(θ) with nonzero a_u/a_uv), slot 1 = g (the log-slope), slot 2 unused.
4327 #[test]
4328 fn crossing_edge_tower_matches_handpath_velocity_formulas() {
4329 const TAU: f64 = 1.3; // the link-knot crossing threshold τ
4330 let g_idx = 1usize;
4331 let g0 = 0.85_f64; // the slope value b (the g-primary IS the slope)
4332 // Stand-in intercept tower a(θ): nonzero value, gradient, Hessian in the
4333 // two live axes so a_u and a_uv are both exercised. (In production this
4334 // comes from implicit_solve; here we plant known derivatives.)
4335 let mut a = Tower4::<3>::constant(0.45);
4336 a.g[0] = 0.7;
4337 a.g[1] = -0.3;
4338 a.h[0][0] = 0.25;
4339 a.h[0][1] = 0.11;
4340 a.h[1][0] = 0.11;
4341 a.h[1][1] = -0.08;
4342
4343 // In the survival flex frame the slope `b` IS the g-primary directly
4344 // (the directional code passes `g` as `b`, and ∂z/∂g uses ∂b/∂g = 1):
4345 // z_edge = (τ − a) / b with b seeded as the g-axis variable.
4346 let b = Tower4::<3>::variable(g0, g_idx);
4347 let z_edge = (Tower4::<3>::constant(TAU) - a) / b;
4348
4349 let bv = g0;
4350 let z0 = z_edge.v;
4351 assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);
4352
4353 // z_u = −(a_u + [u==g]·z) / b.
4354 for u in 0..2 {
4355 let direct = if u == g_idx { z0 } else { 0.0 };
4356 let want = -(a.g[u] + direct) / bv;
4357 assert!(
4358 (z_edge.g[u] - want).abs() < 1e-10,
4359 "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
4360 z_edge.g[u],
4361 want
4362 );
4363 }
4364 // z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b, using the tower's own
4365 // first-order z_v/z_u (already verified above).
4366 for u in 0..2 {
4367 for v in 0..2 {
4368 let cross = if u == g_idx { z_edge.g[v] } else { 0.0 }
4369 + if v == g_idx { z_edge.g[u] } else { 0.0 };
4370 let want = -(a.h[u][v] + cross) / bv;
4371 assert!(
4372 (z_edge.h[u][v] - want).abs() < 1e-10,
4373 "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
4374 z_edge.h[u][v],
4375 want
4376 );
4377 }
4378 }
4379 }
4380
4381 /// The crossing-edge tower in the CONSTRAINT frame (intercept `a` and
4382 /// slope `b` BOTH independent — slots 0 and 1) reproduces taylor-jet's
4383 /// FD-certified bare boundary-velocity constants exactly:
4384 /// z_a = ∂z/∂a = −1/b
4385 /// z_ab = ∂²z/∂a∂b = +1/b²
4386 /// z_aa = ∂²z/∂a² = 0
4387 /// z_bb = ∂²z/∂b² = +2(τ−a)/b³
4388 /// These are the `f_a`/`f_au`/`f_aa` constraint-jet boundary motions the
4389 /// production base path drops (and only adds in the dir twins, causing the
4390 /// #932 desync). Here `a` is independent (NOT yet substituted with a(θ)),
4391 /// so `z_aa = 0` and there is no `a_uv` chain — `implicit_solve` introduces
4392 /// that later. Pins the constant before the constraint-tower wiring.
4393 #[test]
4394 fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
4395 const TAU: f64 = 1.3;
4396 let a0 = 0.45_f64;
4397 let b0 = 0.85_f64;
4398 // Slot 0 = a, slot 1 = b, both seeded independent.
4399 let a = Tower4::<2>::variable(a0, 0);
4400 let b = Tower4::<2>::variable(b0, 1);
4401 let z = (Tower4::<2>::constant(TAU) - a) / b;
4402
4403 assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
4404 assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
4405 assert!(
4406 (z.h[0][1] - 1.0 / (b0 * b0)).abs() < 1e-12,
4407 "z_ab {:+.10e} vs +1/b² {:+.10e}",
4408 z.h[0][1],
4409 1.0 / (b0 * b0)
4410 );
4411 assert!(
4412 z.h[0][0].abs() < 1e-12,
4413 "z_aa must vanish, got {:+.10e}",
4414 z.h[0][0]
4415 );
4416 let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);
4417 assert!(
4418 (z.h[1][1] - want_zbb).abs() < 1e-12,
4419 "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
4420 z.h[1][1],
4421 want_zbb
4422 );
4423 }
4424
4425 /// The oracle harness catches a planted #736-style sign flip in a
4426 /// cross block and reports the channel by name.
4427 #[test]
4428 fn oracle_catches_planted_cross_block_sign_flip() {
4429 let prog = LocScaleProgram {
4430 eta: vec![0.3],
4431 s: vec![-0.5],
4432 y: vec![1.0],
4433 };
4434 let t = evaluate_program(&prog, 0).expect("tower");
4435 let dir = [0.6, -0.2];
4436 let mut third = t.third_contracted(&dir);
4437 let honest = KernelChannels {
4438 value: t.v,
4439 gradient: t.g,
4440 hessian: t.h,
4441 third: vec![(dir, third)],
4442 fourth: vec![(dir, [1.0, 0.5], t.fourth_contracted(&dir, &[1.0, 0.5]))],
4443 };
4444 verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");
4445
4446 // Plant the #736 flip: negate one mixed cross entry.
4447 third[0][1] = -third[0][1];
4448 let flipped = KernelChannels {
4449 value: t.v,
4450 gradient: t.g,
4451 hessian: t.h,
4452 third: vec![(dir, third)],
4453 fourth: vec![],
4454 };
4455 let err = verify_kernel_channels(&t, &flipped, 1e-10)
4456 .expect_err("planted sign flip must be caught");
4457 assert!(
4458 err.contains("third[0][0][1]"),
4459 "oracle must name the flipped channel, got: {err}"
4460 );
4461 }
4462
4463 /// The third- and fourth-order tensors must be FULLY symmetric under
4464 /// index permutation (mixed partials commute). The tower stores them
4465 /// unsymmetrized, so equal-by-construction is a real invariant of the
4466 /// Leibniz/Faà di Bruno writes — a cheap typo tripwire. Asserted on a
4467 /// nontrivial K=3 tower with all of div/sqrt/powf/exp/ln exercised, so
4468 /// every composition path contributes. Lives in a test (not the hot
4469 /// per-op path) on purpose.
4470 #[test]
4471 fn t3_t4_are_fully_index_symmetric() {
4472 let prog = GnarlyProgram::fixture();
4473 // 3! = 6 permutations of three indices.
4474 let perms3: [[usize; 3]; 6] = [
4475 [0, 1, 2],
4476 [0, 2, 1],
4477 [1, 0, 2],
4478 [1, 2, 0],
4479 [2, 0, 1],
4480 [2, 1, 0],
4481 ];
4482 // 4! = 24 permutations of four indices.
4483 let perms4: [[usize; 4]; 24] = [
4484 [0, 1, 2, 3],
4485 [0, 1, 3, 2],
4486 [0, 2, 1, 3],
4487 [0, 2, 3, 1],
4488 [0, 3, 1, 2],
4489 [0, 3, 2, 1],
4490 [1, 0, 2, 3],
4491 [1, 0, 3, 2],
4492 [1, 2, 0, 3],
4493 [1, 2, 3, 0],
4494 [1, 3, 0, 2],
4495 [1, 3, 2, 0],
4496 [2, 0, 1, 3],
4497 [2, 0, 3, 1],
4498 [2, 1, 0, 3],
4499 [2, 1, 3, 0],
4500 [2, 3, 0, 1],
4501 [2, 3, 1, 0],
4502 [3, 0, 1, 2],
4503 [3, 0, 2, 1],
4504 [3, 1, 0, 2],
4505 [3, 1, 2, 0],
4506 [3, 2, 0, 1],
4507 [3, 2, 1, 0],
4508 ];
4509 for row in 0..prog.n_rows() {
4510 let t = evaluate_program(&prog, row).expect("gnarly tower");
4511 let scale_t3 =
4512 t.t3.iter()
4513 .flatten()
4514 .flatten()
4515 .fold(0.0_f64, |m, x| m.max(x.abs()))
4516 .max(1.0);
4517 let scale_t4 =
4518 t.t4.iter()
4519 .flatten()
4520 .flatten()
4521 .flatten()
4522 .fold(0.0_f64, |m, x| m.max(x.abs()))
4523 .max(1.0);
4524 for i in 0..3 {
4525 for j in 0..3 {
4526 for k in 0..3 {
4527 let base = t.t3[i][j][k];
4528 let idx = [i, j, k];
4529 for p in &perms3 {
4530 let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
4531 assert!(
4532 (base - permed).abs() <= 1e-12 * scale_t3,
4533 "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
4534 permuted {permed:+.15e} under {p:?}"
4535 );
4536 }
4537 for l in 0..3 {
4538 let base4 = t.t4[i][j][k][l];
4539 let idx4 = [i, j, k, l];
4540 for p in &perms4 {
4541 let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
4542 assert!(
4543 (base4 - permed).abs() <= 1e-12 * scale_t4,
4544 "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
4545 permuted {permed:+.15e} under {p:?}"
4546 );
4547 }
4548 }
4549 }
4550 }
4551 }
4552 }
4553 }
4554}
4555
4556#[inline]
4557fn erfcx_nonnegative(x: f64) -> f64 {
4558 if !x.is_finite() {
4559 return if x.is_sign_positive() {
4560 0.0
4561 } else {
4562 f64::INFINITY
4563 };
4564 }
4565 if x <= 0.0 {
4566 return 1.0;
4567 }
4568 if x < 26.0 {
4569 ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
4570 } else {
4571 let inv = 1.0 / x;
4572 let inv2 = inv * inv;
4573 let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
4574 + 6.5625 * inv2 * inv2 * inv2 * inv2;
4575 inv * poly / std::f64::consts::PI.sqrt()
4576 }
4577}
4578
4579#[inline]
4580fn log1mexp_positive(a: f64) -> f64 {
4581 assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
4582 if a > core::f64::consts::LN_2 {
4583 (-(-a).exp()).ln_1p()
4584 } else if a > 0.0 {
4585 (-(-a).exp_m1()).ln()
4586 } else {
4587 f64::NEG_INFINITY
4588 }
4589}
4590
4591#[inline]
4592fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
4593 if x == f64::INFINITY {
4594 return (0.0, 0.0);
4595 }
4596 if x == f64::NEG_INFINITY {
4597 return (f64::NEG_INFINITY, f64::INFINITY);
4598 }
4599 if x.is_nan() {
4600 return (f64::NAN, f64::NAN);
4601 }
4602 if x < 0.0 {
4603 let u = -x / std::f64::consts::SQRT_2;
4604 let ex = erfcx_nonnegative(u).max(1e-300);
4605 let log_cdf = -u * u + (0.5 * ex).ln();
4606 let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
4607 (log_cdf, lambda)
4608 } else {
4609 let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
4610 let lambda = crate::probability::normal_pdf(x) / cdf;
4611 (cdf.ln(), lambda)
4612 }
4613}
4614
4615/// Stable derivative stack for `log Phi(x)` through fourth order.
4616#[inline]
4617pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
4618 let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
4619 let lambda2 = lambda * lambda;
4620 let lambda3 = lambda2 * lambda;
4621 let x2 = x * x;
4622 [
4623 log_cdf,
4624 lambda,
4625 -lambda * (x + lambda),
4626 lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
4627 -lambda
4628 * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
4629 ]
4630}
4631
4632/// Stable derivative stack for `log(1 - exp(-x))`, `x > 0`, through fourth order.
4633#[inline]
4634pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
4635 let r = 1.0 / x.exp_m1();
4636 [
4637 log1mexp_positive(x),
4638 r,
4639 -r * (1.0 + r),
4640 r * (1.0 + r) * (1.0 + 2.0 * r),
4641 -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
4642 ]
4643}
4644// ── The RowJet bridge oracle (CI) ─────────────────────────────────────
4645#[cfg(test)]
4646mod rowjet_bridge_tests {
4647 use super::*;
4648 use crate::jet_scalar::{JetScalar, Order2};
4649
4650 /// A toy row-NLL written ONCE over the [`RowJet`] bridge: a product, a sum, a
4651 /// subtraction, a scale/neg, a constant, and two value-distinct
4652 /// `compose_unary_with` stacks (an exp stack and a smooth finite-everywhere
4653 /// stack), plus a domain `guard`. The body is generic over `R: RowJet<2>`, so
4654 /// the SAME source instantiates at the scalar jets and the `f64x4` lane towers.
4655 struct ToyProgram {
4656 primaries: Vec<[f64; 2]>,
4657 /// Per-row CONTINUOUS auxiliary data `[cov, z, wi]` — the survival
4658 /// `covariance_ones` / `z_sum` / observation-weight analogues that enter
4659 /// the jet algebra as `.scale_rows(per_row_value)`, distinct per lane.
4660 aux: Vec<[f64; 3]>,
4661 }
4662
4663 impl ToyProgram {
4664 /// The body uses `pack_rows` to gather the per-lane continuous data from
4665 /// the lane→row map and `scale_rows` to fold it in — so a 4-row batch
4666 /// carries four DISTINCT cov/z/wi, which the single-`f64` `scale` could not.
4667 fn body<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> R {
4668 let cov = R::pack_rows(rows, |r| self.aux[r][0]);
4669 let z = R::pack_rows(rows, |r| self.aux[r][1]);
4670 let wi = R::pack_rows(rows, |r| self.aux[r][2]);
4671
4672 let a = p[0].mul(&p[1]).scale_rows(cov);
4673 let b = a.add(&R::constant(0.5)).sub(&p[0].scale(0.25));
4674 let c = b
4675 .compose_unary_with(|u| {
4676 let e = u.exp();
4677 [e, e, e, e, e]
4678 })
4679 .scale_rows(z);
4680 let d = c.neg().add(&p[0]);
4681 let e = d
4682 .compose_unary_with(|u| {
4683 let s = (1.0 + u * u).sqrt();
4684 let s3 = s * s * s;
4685 let s5 = s3 * s * s;
4686 let s7 = s5 * s * s;
4687 [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
4688 })
4689 .scale_rows(wi);
4690 e.mul(&p[1]).add(&e)
4691 }
4692 }
4693
4694 impl RowNllProgramRowJet<2> for ToyProgram {
4695 fn n_rows(&self) -> usize {
4696 self.primaries.len()
4697 }
4698 fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
4699 Ok(self.primaries[row])
4700 }
4701 fn row_nll<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> Result<R, String> {
4702 assert!(rows.len() == 1 || rows.len() == 4, "lane→row map is 1 or 4 wide");
4703 Ok(self.body(rows, p))
4704 }
4705 }
4706
4707 fn assert_t4_bits_eq(a: &Tower4<2>, b: &Tower4<2>, ctx: &str) {
4708 assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4709 for i in 0..2 {
4710 assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4711 for j in 0..2 {
4712 assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4713 for k in 0..2 {
4714 assert_eq!(
4715 a.t3[i][j][k].to_bits(),
4716 b.t3[i][j][k].to_bits(),
4717 "{ctx}: t3[{i}][{j}][{k}]"
4718 );
4719 for l in 0..2 {
4720 assert_eq!(
4721 a.t4[i][j][k][l].to_bits(),
4722 b.t4[i][j][k][l].to_bits(),
4723 "{ctx}: t4[{i}][{j}][{k}][{l}]"
4724 );
4725 }
4726 }
4727 }
4728 }
4729 }
4730
4731 fn assert_t3_bits_eq(a: &Tower3<2>, b: &Tower3<2>, ctx: &str) {
4732 assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4733 for i in 0..2 {
4734 assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4735 for j in 0..2 {
4736 assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4737 for k in 0..2 {
4738 assert_eq!(
4739 a.t3[i][j][k].to_bits(),
4740 b.t3[i][j][k].to_bits(),
4741 "{ctx}: t3[{i}][{j}][{k}]"
4742 );
4743 }
4744 }
4745 }
4746 }
4747
4748 // Deterministic LCG with signed-zero injection and per-lane-distinct values.
4749 struct Lcg(u64);
4750 impl Lcg {
4751 fn next(&mut self) -> f64 {
4752 self.0 = self
4753 .0
4754 .wrapping_mul(6364136223846793005)
4755 .wrapping_add(1442695040888963407);
4756 (self.0 >> 11) as f64 / (1u64 << 53) as f64
4757 }
4758 fn val(&mut self) -> f64 {
4759 let u = self.next();
4760 if u < 0.04 {
4761 return 0.0;
4762 }
4763 if u < 0.08 {
4764 return -0.0;
4765 }
4766 (self.next() - 0.5) * 5.0
4767 }
4768 }
4769
4770 /// Lane `i` of the batched order-4 / order-3 tower is `to_bits`-identical to
4771 /// the scalar tower on row `i`, for ≥2000 distinct 4-row batches with
4772 /// signed-zero and per-lane-distinct primaries.
4773 #[test]
4774 fn batched_lane_i_matches_scalar_row_i_bit_identical() {
4775 let mut rng = Lcg(0xA5A5_1234_DEAD_BEEF);
4776 let mut batches = 0usize;
4777 for _ in 0..2500 {
4778 let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4779 // per-lane-DISTINCT continuous aux (cov/z/wi), signed-zero injected.
4780 let aux: [[f64; 3]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4781 let prog = ToyProgram { primaries: bases.to_vec(), aux: aux.to_vec() };
4782 let rows = [0usize, 1, 2, 3];
4783
4784 // order-4 batch vs scalar Tower4 (instantiated through the same body).
4785 let batch4 = generic_batched_fourth_tower(&prog, rows).expect("batch4");
4786 for (row, base) in bases.iter().enumerate() {
4787 let vars: [Tower4<2>; 2] =
4788 std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4789 let scal = prog.row_nll(&[row], &vars).expect("scalar tower4");
4790 assert_t4_bits_eq(&batch4.lane(row), &scal, "batched_fourth");
4791 }
4792
4793 // order-3 batch vs scalar Tower3.
4794 let batch3 = generic_batched_third_tower(&prog, rows).expect("batch3");
4795 for (row, base) in bases.iter().enumerate() {
4796 let vars: [Tower3<2>; 2] =
4797 std::array::from_fn(|a| <Tower3<2> as RowJet<2>>::variable(base[a], a));
4798 let scal = prog.row_nll(&[row], &vars).expect("scalar tower3");
4799 assert_t3_bits_eq(&batch3.lane(row), &scal, "batched_third");
4800 }
4801 batches += 1;
4802 }
4803 assert_eq!(batches, 2500);
4804 }
4805
4806 /// The blanket impl does not churn the scalar path: the body driven through
4807 /// `RowJet` ops is `to_bits`-identical to the body driven directly through
4808 /// `JetScalar` ops, and `rowjet_row_kernel`'s `(v, g, H)` matches the dense
4809 /// `Tower4` lower channels.
4810 #[test]
4811 fn blanket_scalar_path_is_unchanged_and_consistent() {
4812 let mut rng = Lcg(0x0BAD_F00D_1357_2468);
4813 for _ in 0..3000 {
4814 let base: [f64; 2] = std::array::from_fn(|_| rng.val());
4815 let aux0: [f64; 3] = std::array::from_fn(|_| rng.val());
4816 let prog = ToyProgram { primaries: vec![base], aux: vec![aux0] };
4817
4818 // (a) RowJet-driven body == JetScalar-driven body, bit-for-bit. The
4819 // reference body uses `scale(f64)` where the RowJet body uses
4820 // `scale_rows(f64)` — proving the scalar `scale_rows` rewrite does not
4821 // churn the path (`scale_rows(s) == scale(s)` on `Value = f64`).
4822 let via_rowjet: Tower4<2> = {
4823 let vars: [Tower4<2>; 2] =
4824 std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4825 prog.row_nll(&[0], &vars).expect("rowjet")
4826 };
4827 let via_jetscalar: Tower4<2> = {
4828 let vars: [Tower4<2>; 2] = std::array::from_fn(|a| {
4829 <Tower4<2> as JetScalar<2>>::variable(base[a], a)
4830 });
4831 let (cov, z, wi) = (aux0[0], aux0[1], aux0[2]);
4832 // The body using JetScalar's own ops + scale(f64) directly.
4833 let a = vars[0].mul(&vars[1]).scale(cov);
4834 let b = a.add(&Tower4::constant(0.5)).sub(&vars[0].scale(0.25));
4835 let c = b
4836 .compose_unary_with(|u| {
4837 let e = u.exp();
4838 [e, e, e, e, e]
4839 })
4840 .scale(z);
4841 let d = JetScalar::neg(&c).add(&vars[0]);
4842 let e = d
4843 .compose_unary_with(|u| {
4844 let s = (1.0 + u * u).sqrt();
4845 let s3 = s * s * s;
4846 let s5 = s3 * s * s;
4847 let s7 = s5 * s * s;
4848 [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
4849 })
4850 .scale(wi);
4851 e.mul(&vars[1]).add(&e)
4852 };
4853 assert_t4_bits_eq(&via_rowjet, &via_jetscalar, "blanket_vs_direct");
4854
4855 // (b) rowjet_row_kernel (v,g,H) == dense Tower4 lower channels.
4856 // Order2 and Tower4 use different internal representations so
4857 // signed-zero differences (−0.0 vs +0.0) may arise in gradient/
4858 // Hessian channels that evaluate to exactly zero; IEEE equality
4859 // treats these as equal, so `==` is the right comparison here.
4860 let (v, g, h) = rowjet_row_kernel(&prog, 0).expect("kernel");
4861 assert_eq!(v.to_bits(), via_rowjet.v.to_bits(), "kernel v");
4862 for i in 0..2 {
4863 assert!(g[i] == via_rowjet.g[i], "kernel g[{i}]: {} vs {}", g[i], via_rowjet.g[i]);
4864 for j in 0..2 {
4865 assert!(
4866 h[i][j] == via_rowjet.h[i][j],
4867 "kernel h[{i}][{j}]: {} vs {}",
4868 h[i][j],
4869 via_rowjet.h[i][j]
4870 );
4871 }
4872 }
4873
4874 // (c) the Order2 scalar IS a RowJet via the blanket.
4875 let o2: [Order2<2>; 2] =
4876 std::array::from_fn(|a| <Order2<2> as RowJet<2>>::variable(base[a], a));
4877 let _ = prog.body(&[0], &o2);
4878 }
4879 }
4880
4881 /// On the scalar path (`Value = f64`) `scale_rows(s)` is `to_bits`-identical
4882 /// to `scale(s)` for EVERY channel — so rewriting a survival `.scale(per_row)`
4883 /// to `.scale_rows(per_row)` cannot perturb the existing scalar fits.
4884 #[test]
4885 fn scale_rows_scalar_is_bit_identical_to_scale() {
4886 let mut rng = Lcg(0xFEED_FACE_0042_1001);
4887 for _ in 0..3000 {
4888 let base: [f64; 2] = std::array::from_fn(|_| rng.val());
4889 let s = rng.val();
4890 // Build a dense tower with populated channels (exp of a product).
4891 let vars: [Tower4<2>; 2] =
4892 std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4893 let jet = vars[0].mul(&vars[1]).compose_unary_with(|u| {
4894 let e = u.exp();
4895 [e, e, e, e, e]
4896 });
4897 let via_scale = RowJet::scale(&jet, s);
4898 let via_scale_rows = RowJet::scale_rows(&jet, s);
4899 assert_t4_bits_eq(&via_scale_rows, &via_scale, "scale_rows==scale");
4900 }
4901 }
4902
4903 /// `scale_rows` on a batch multiplies lane `i` by `s[i]`, so lane `i` of a
4904 /// per-lane-scaled batch matches the scalar `scale(s[i])` on row `i` — the
4905 /// continuous per-row data path the single-`f64` `scale` could not carry.
4906 #[test]
4907 fn batched_scale_rows_matches_per_row_scalar_scale() {
4908 let mut rng = Lcg(0x1357_9BDF_2468_ACE0);
4909 for _ in 0..2500 {
4910 let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
4911 let s: [f64; 4] = std::array::from_fn(|_| rng.val());
4912 let batch: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
4913 Tower4Batch::variable(
4914 wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]),
4915 a,
4916 )
4917 });
4918 let prod = batch[0].mul(&batch[1]).compose_unary_with(|u| {
4919 let e = u.exp();
4920 [e, e, e, e, e]
4921 });
4922 let scaled = prod.scale_rows(s);
4923 for (row, base) in bases.iter().enumerate() {
4924 let v: [Tower4<2>; 2] =
4925 std::array::from_fn(|a| <Tower4<2> as RowJet<2>>::variable(base[a], a));
4926 let prod_s = v[0].mul(&v[1]).compose_unary_with(|u| {
4927 let e = u.exp();
4928 [e, e, e, e, e]
4929 });
4930 let ref_s = RowJet::scale(&prod_s, s[row]);
4931 assert_t4_bits_eq(&scaled.lane(row), &ref_s, "batched_scale_rows");
4932 }
4933 }
4934 }
4935
4936 /// The per-lane guard reports exactly the failing lanes on a batch and the
4937 /// single lane on a scalar jet.
4938 #[test]
4939 fn guard_reports_per_lane_failures() {
4940 let cols: [[f64; 2]; 4] = [[1.0, 0.5], [-2.0, 0.5], [3.0, 0.5], [-0.0, 0.5]];
4941 let vars: [Tower4Batch<2>; 2] = std::array::from_fn(|a| {
4942 Tower4Batch::variable(
4943 wide::f64x4::new([cols[0][a], cols[1][a], cols[2][a], cols[3][a]]),
4944 a,
4945 )
4946 });
4947 let verdict = vars[0].guard(|v| v > 0.0);
4948 assert_eq!(verdict.lanes(), 4);
4949 assert!(verdict.any_failed());
4950 assert!(!verdict.all_pass());
4951 assert!(!verdict.lane_failed(0));
4952 assert!(verdict.lane_failed(1));
4953 assert!(!verdict.lane_failed(2));
4954 assert!(verdict.lane_failed(3));
4955 assert_eq!(verdict.failed_mask(), 0b1010);
4956
4957 let s_ok = <Tower4<2> as RowJet<2>>::variable(1.0, 0);
4958 let s_bad = <Tower4<2> as RowJet<2>>::variable(-1.0, 0);
4959 assert!(RowJet::guard(&s_ok, |v| v > 0.0).all_pass());
4960 assert!(RowJet::guard(&s_bad, |v| v > 0.0).any_failed());
4961 assert_eq!(RowJet::guard(&s_ok, |v| v > 0.0).lanes(), 1);
4962 }
4963
4964 // ── ln_gamma_derivative_stack / digamma_derivative_stack / trigamma_derivative_stack ──
4965
4966 #[test]
4967 fn ln_gamma_derivative_stack_known_values_at_1() {
4968 let s = ln_gamma_derivative_stack(1.0);
4969 // ln Γ(1) = 0; statrs uses Lanczos so the result is within ULP noise
4970 assert!(s[0].abs() < 1e-14, "ln_gamma(1) must be ~0, got {}", s[0]);
4971 // ψ₀(1) = -γ (Euler–Mascheroni)
4972 let euler_mascheroni = 0.577_215_664_901_532_9_f64;
4973 assert!(
4974 (s[1] + euler_mascheroni).abs() < 1e-10,
4975 "digamma(1) ≈ -{euler_mascheroni:.6}, got {}",
4976 s[1]
4977 );
4978 // ψ₁(1) = π²/6
4979 let pi2_6 = std::f64::consts::PI * std::f64::consts::PI / 6.0;
4980 assert!(
4981 (s[2] - pi2_6).abs() < 1e-10,
4982 "trigamma(1) ≈ {pi2_6:.6}, got {}",
4983 s[2]
4984 );
4985 }
4986
4987 #[test]
4988 fn ln_gamma_derivative_stack_known_values_at_2() {
4989 let s = ln_gamma_derivative_stack(2.0);
4990 // ln Γ(2) = ln(1) = 0 exactly
4991 assert!(s[0].abs() < 1e-14, "ln_gamma(2) must be 0, got {}", s[0]);
4992 // ψ₀(2) = 1 − γ (recurrence: ψ₀(x+1) = ψ₀(x) + 1/x)
4993 let euler_mascheroni = 0.577_215_664_901_532_9_f64;
4994 let digamma_2 = 1.0 - euler_mascheroni;
4995 assert!(
4996 (s[1] - digamma_2).abs() < 1e-10,
4997 "digamma(2) ≈ {digamma_2:.6}, got {}",
4998 s[1]
4999 );
5000 }
5001
5002 #[test]
5003 fn ln_gamma_derivative_stack_order2_is_prefix() {
5004 for &x in &[0.5_f64, 1.0, 2.0, 5.0] {
5005 let full = ln_gamma_derivative_stack(x);
5006 let ord2 = ln_gamma_derivative_stack_order2(x);
5007 assert_eq!(
5008 ord2[0], full[0],
5009 "order2[0] != full[0] at x={x}"
5010 );
5011 assert_eq!(
5012 ord2[1], full[1],
5013 "order2[1] != full[1] at x={x}"
5014 );
5015 assert_eq!(
5016 ord2[2], full[2],
5017 "order2[2] != full[2] at x={x}"
5018 );
5019 }
5020 }
5021
5022 #[test]
5023 fn digamma_derivative_stack_overlaps_ln_gamma_stack() {
5024 // The two stacks share a run of four polygamma values:
5025 // ln_gamma_stack[1..5] == digamma_stack[0..4]
5026 for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
5027 let lg = ln_gamma_derivative_stack(x);
5028 let dg = digamma_derivative_stack(x);
5029 for i in 0..4 {
5030 assert_eq!(
5031 lg[i + 1], dg[i],
5032 "ln_gamma_stack[{}] != digamma_stack[{}] at x={x}",
5033 i + 1,
5034 i
5035 );
5036 }
5037 }
5038 }
5039
5040 #[test]
5041 fn trigamma_derivative_stack_overlaps_digamma_stack() {
5042 // digamma_stack[1..5] == trigamma_stack[0..4]
5043 for &x in &[0.5_f64, 1.0, 2.0, 7.0] {
5044 let dg = digamma_derivative_stack(x);
5045 let tg = trigamma_derivative_stack(x);
5046 for i in 0..4 {
5047 assert_eq!(
5048 dg[i + 1], tg[i],
5049 "digamma_stack[{}] != trigamma_stack[{}] at x={x}",
5050 i + 1,
5051 i
5052 );
5053 }
5054 }
5055 }
5056
5057 #[test]
5058 fn derivative_stacks_all_finite_at_positive_inputs() {
5059 for &x in &[0.01_f64, 0.5, 1.0, 2.0, 10.0, 100.0] {
5060 for v in ln_gamma_derivative_stack(x) {
5061 assert!(v.is_finite(), "ln_gamma_stack non-finite at x={x}: {v}");
5062 }
5063 for v in digamma_derivative_stack(x) {
5064 assert!(v.is_finite(), "digamma_stack non-finite at x={x}: {v}");
5065 }
5066 for v in trigamma_derivative_stack(x) {
5067 assert!(v.is_finite(), "trigamma_stack non-finite at x={x}: {v}");
5068 }
5069 }
5070 }
5071}
5072
5073// ── Contraction-symmetry optimization gate ────────────────────────────────────
5074//
5075// `Tower4::third_contracted` / `fourth_contracted` contract the (fully
5076// index-symmetric) `t3`/`t4` tensors against directions, leaving the output
5077// indices `(a, b)` / `(i, j)` free. Those free indices inherit the tensor's
5078// symmetry — `out[a][b] == out[b][a]` term-for-term — so only the upper triangle
5079// need be summed and the lower triangle mirrored. Unlike the dense symmetric
5080// FILL (which needs a K⁴ scatter and loses inner-loop vectorisation, and was
5081// measured SLOWER), the mirror here is a tiny K×K copy and the inner contraction
5082// is untouched (contiguous, vectorisable). This is BIT-IDENTICAL to the full
5083// nest, so it needs no fingerprint re-baseline; the gate is (1) bit-identity vs
5084// the full reference and (2) a measured wall-clock that is not slower.
5085#[cfg(test)]
5086mod contraction_symmetry_tests {
5087 use super::*;
5088
5089 struct Rng(u64);
5090 impl Rng {
5091 fn u(&mut self) -> f64 {
5092 self.0 = self
5093 .0
5094 .wrapping_mul(6364136223846793005)
5095 .wrapping_add(1442695040888963407);
5096 (self.0 >> 11) as f64 / (1u64 << 53) as f64
5097 }
5098 fn s(&mut self) -> f64 {
5099 (self.u() - 0.5) * 4.0
5100 }
5101 }
5102
5103 /// Random VALID fully-symmetric `Tower4<K>` (symmetric `h`/`t3`/`t4`).
5104 fn rand_sym4<const K: usize>(r: &mut Rng) -> Tower4<K> {
5105 let mut t = Tower4::<K>::zero();
5106 t.v = r.s();
5107 for i in 0..K {
5108 t.g[i] = r.s();
5109 }
5110 for a in 0..K {
5111 for b in a..K {
5112 let v2 = r.s();
5113 t.h[a][b] = v2;
5114 t.h[b][a] = v2;
5115 for c in b..K {
5116 let v3 = r.s();
5117 for p in perms3([a, b, c]) {
5118 t.t3[p[0]][p[1]][p[2]] = v3;
5119 }
5120 for d in c..K {
5121 let v4 = r.s();
5122 for p in perms4([a, b, c, d]) {
5123 t.t4[p[0]][p[1]][p[2]][p[3]] = v4;
5124 }
5125 }
5126 }
5127 }
5128 }
5129 t
5130 }
5131
5132 fn perms3(idx: [usize; 3]) -> [[usize; 3]; 6] {
5133 let [a, b, c] = idx;
5134 [[a, b, c], [a, c, b], [b, a, c], [b, c, a], [c, a, b], [c, b, a]]
5135 }
5136 fn perms4(idx: [usize; 4]) -> [[usize; 4]; 24] {
5137 let [a, b, c, d] = idx;
5138 [
5139 [a, b, c, d], [a, b, d, c], [a, c, b, d], [a, c, d, b], [a, d, b, c], [a, d, c, b],
5140 [b, a, c, d], [b, a, d, c], [b, c, a, d], [b, c, d, a], [b, d, a, c], [b, d, c, a],
5141 [c, a, b, d], [c, a, d, b], [c, b, a, d], [c, b, d, a], [c, d, a, b], [c, d, b, a],
5142 [d, a, b, c], [d, a, c, b], [d, b, a, c], [d, b, c, a], [d, c, a, b], [d, c, b, a],
5143 ]
5144 }
5145
5146 /// Full-nest reference (the pre-opt `a, b ∈ 0..K` form).
5147 fn third_full<const K: usize>(t: &Tower4<K>, dir: &[f64; K]) -> [[f64; K]; K] {
5148 let mut out = [[0.0; K]; K];
5149 for a in 0..K {
5150 for b in 0..K {
5151 let mut acc = 0.0;
5152 for c in 0..K {
5153 acc += t.t3[a][b][c] * dir[c];
5154 }
5155 out[a][b] = acc;
5156 }
5157 }
5158 out
5159 }
5160 fn fourth_full<const K: usize>(t: &Tower4<K>, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
5161 let mut out = [[0.0; K]; K];
5162 for i in 0..K {
5163 for j in 0..K {
5164 let mut acc = 0.0;
5165 for k in 0..K {
5166 for l in 0..K {
5167 acc += t.t4[i][j][k][l] * u[k] * w[l];
5168 }
5169 }
5170 out[i][j] = acc;
5171 }
5172 }
5173 out
5174 }
5175
5176 fn check_bit_identical<const K: usize>(seed: u64, n: usize) {
5177 let mut r = Rng(seed);
5178 for _ in 0..n {
5179 let t = rand_sym4::<K>(&mut r);
5180 let dir: [f64; K] = std::array::from_fn(|_| r.s());
5181 let u: [f64; K] = std::array::from_fn(|_| r.s());
5182 let w: [f64; K] = std::array::from_fn(|_| r.s());
5183 let t3_sym = t.third_contracted(&dir);
5184 let t3_full = third_full(&t, &dir);
5185 let t4_sym = t.fourth_contracted(&u, &w);
5186 let t4_full = fourth_full(&t, &u, &w);
5187 for a in 0..K {
5188 for b in 0..K {
5189 assert_eq!(
5190 t3_sym[a][b].to_bits(),
5191 t3_full[a][b].to_bits(),
5192 "third K={K} [{a}][{b}]"
5193 );
5194 assert_eq!(
5195 t4_sym[a][b].to_bits(),
5196 t4_full[a][b].to_bits(),
5197 "fourth K={K} [{a}][{b}]"
5198 );
5199 }
5200 }
5201 }
5202 }
5203
5204 /// The output-symmetric contraction is BIT-IDENTICAL to the full nest across
5205 /// `K ∈ {2,3,4,9}` (so no fingerprint re-baseline is owed — accuracy and bits
5206 /// are unchanged; this is a pure speed-only optimization).
5207 #[test]
5208 fn contraction_symmetry_is_bit_identical_to_full_nest() {
5209 check_bit_identical::<2>(0x0000_0002_C0FF_EE01, 1000);
5210 check_bit_identical::<3>(0x0000_0003_C0FF_EE01, 800);
5211 check_bit_identical::<4>(0x0000_0004_C0FF_EE01, 600);
5212 check_bit_identical::<9>(0x0000_0009_C0FF_EE01, 300);
5213 }
5214
5215 /// Measure the wall-clock of the output-symmetric contraction vs the full
5216 /// nest at `K = 9` (it does ~2× fewer inner contractions; the bit-identity
5217 /// test is the correctness gate). Informational — wall-clock is noisy — with
5218 /// only a PATHOLOGICAL-regression guard (the symmetric form does strictly
5219 /// fewer inner contractions, so it must not be materially slower).
5220 #[test]
5221 fn contraction_symmetry_speedup_is_reported() {
5222 const K: usize = 9;
5223 let mut r = Rng(0xC0FF_EE99_1234_5678);
5224 let towers: Vec<Tower4<K>> = (0..512).map(|_| rand_sym4::<K>(&mut r)).collect();
5225 let dir: [f64; K] = std::array::from_fn(|_| r.s());
5226 let u: [f64; K] = std::array::from_fn(|_| r.s());
5227 let w: [f64; K] = std::array::from_fn(|_| r.s());
5228
5229 let reps = 400usize;
5230 let t_sym = {
5231 let start = std::time::Instant::now();
5232 let mut sink = 0.0f64;
5233 for _ in 0..reps {
5234 for t in &towers {
5235 let o3 = std::hint::black_box(t).third_contracted(std::hint::black_box(&dir));
5236 let o4 = std::hint::black_box(t)
5237 .fourth_contracted(std::hint::black_box(&u), std::hint::black_box(&w));
5238 sink += o3[0][K - 1] + o4[0][K - 1];
5239 }
5240 }
5241 std::hint::black_box(sink);
5242 start.elapsed().as_secs_f64()
5243 };
5244 let t_full = {
5245 let start = std::time::Instant::now();
5246 let mut sink = 0.0f64;
5247 for _ in 0..reps {
5248 for t in &towers {
5249 let o3 = third_full(std::hint::black_box(t), std::hint::black_box(&dir));
5250 let o4 = fourth_full(
5251 std::hint::black_box(t),
5252 std::hint::black_box(&u),
5253 std::hint::black_box(&w),
5254 );
5255 sink += o3[0][K - 1] + o4[0][K - 1];
5256 }
5257 }
5258 std::hint::black_box(sink);
5259 start.elapsed().as_secs_f64()
5260 };
5261 let calls = (reps * towers.len()) as f64;
5262 eprintln!(
5263 "[contraction-symmetry speedup K=9] sym={:.1}ns/call full={:.1}ns/call \
5264 wall_speedup={:.2}x",
5265 t_sym / calls * 1e9,
5266 t_full / calls * 1e9,
5267 t_full / t_sym
5268 );
5269 assert!(
5270 t_sym <= t_full * 1.5,
5271 "output-symmetric contraction pathologically slower: \
5272 sym={t_sym:.4}s full={t_full:.4}s"
5273 );
5274 }
5275}