gam_math/jet_tower.rs
1//! Taylor-jet tower algebra: write each family's row log-likelihood ONCE,
2//! derive the entire `RowKernel<K>` derivative tower mechanically (#932).
3//!
4//! # The object
5//!
6//! [`Tower4<K>`] is a truncated multivariate Taylor scalar in `K` primary
7//! variables, carrying the value and ALL partial derivatives through fourth
8//! order as full (unsymmetrized) tensors:
9//!
10//! ```text
11//! v ℓ
12//! g[a] ∂ℓ/∂p_a
13//! h[a][b] ∂²ℓ/∂p_a∂p_b
14//! t3[abc] ∂³ℓ/∂p_a∂p_b∂p_c
15//! t4[abcd] ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d
16//! ```
17//!
18//! Arithmetic (`+ − × ÷`, scalar mixes) propagates the tower by the exact
19//! Leibniz rule; unary transcendentals propagate by the exact multivariate
20//! Faà di Bruno formula given a `[f, f′, f″, f‴, f⁗]` stack evaluated at the
21//! inner value. This is truncated Taylor ALGEBRA — exact derivatives of the
22//! evaluated expression, not finite differences, not an approximation —
23//! fully compatible with the exact-REML-only policy.
24//!
25//! One evaluation of a row NLL program at seeded variables yields, in a
26//! single pass, every channel the [`super::row_kernel::RowKernel`] trait
27//! demands: `row_kernel` (value/∇/H), `row_third_contracted(dir)` (contract
28//! `t3` with `dir`), and `row_fourth_contracted(u, v)` (contract `t4` with
29//! `u` and `v`). The directional cross-channels that hand-written towers
30//! drop (#736's residual gap) cannot be dropped here: there is no separate
31//! "channel" to forget — every derivative of the one expression is carried.
32//!
33//! # Why this exists (the bug genus)
34//!
35//! Every family today hand-writes its tower: value in one function,
36//! gradient in another, `pdfthird_derivative`/`pdffourth_derivative`,
37//! entry/exit-specific cross blocks — thousands of lines of calculus that
38//! drift. #736 was a sign flip in a hand-written cross-Hessian block,
39//! invisible until a new consumer touched it; #948 is a derivative path
40//! that is not the derivative of the evaluated row loss (clamped-μ
41//! surrogate); the objective↔gradient desync class is the same disease at
42//! the criterion level. A tower-derived kernel is exact-by-construction:
43//! the value channel IS the production loss expression, so its derivative
44//! channels cannot desync from it.
45//!
46//! # Relation to `jet_partitions::MultiDirJet`
47//!
48//! The tree already carries a *directional* jet (bitmask coefficients over
49//! distinct seeded directions, heap-allocated, Bell-partition compose) used
50//! inside the marginal-slope and latent-survival families. It answers "the
51//! derivative along THESE specific directions" and must be re-seeded and
52//! re-evaluated per direction tuple (e.g. 10 symmetric `(a,b)` pairs for a
53//! K=4 fourth contraction). `Tower4` answers ALL of them from one
54//! evaluation: contraction happens AFTER differentiation, as plain linear
55//! algebra on the stored tensors. Use `MultiDirJet` when you need a handful
56//! of directions of a huge-K expression; use `Tower4` when you need the
57//! complete small-K tower — which is exactly the `RowKernel<K≤4>` shape.
58//! The `[f64; 5]` unary-derivative stacks
59//! (`unary_derivatives_neglog_phi`, …) are signature-compatible with
60//! [`Tower4::compose_unary`], so the families' existing special-function
61//! stacks are directly reusable.
62//!
63//! # Stability discipline (why this is NOT autodiff)
64//!
65//! Differentiating the primal code path inherits its instabilities: a jet
66//! pushed through a naive `ln(1 + e^η)` is garbage in the saturated tail
67//! even though the true derivative σ(η) is benign there. This module
68//! therefore splits responsibility: **humans own primitive stability,
69//! the algebra owns combinatorics**. Tail-critical special functions enter
70//! a program ONLY as hand-certified `[f64; 5]` derivative stacks through
71//! [`Tower4::compose_unary`] — the same stacks the families already write
72//! (`unary_derivatives_neglog_phi` and friends, built on erfcx/log_ndtr) —
73//! and the tower mechanizes only the Leibniz/Faà di Bruno composition,
74//! which is where hand-written towers actually fail (#736 was a
75//! composition sign flip, not a primitive error). Program authors must use
76//! a stable primitive stack wherever the f64 production loss does; the
77//! convenience methods (`exp`, `ln`, `sqrt`, …) are for expressions whose
78//! arguments are tame by construction.
79//!
80//! # Storage convention
81//!
82//! Tensors are stored FULL, not symmetric-packed: `t4` for K=4 is 256
83//! doubles where 35 would do. This is deliberate clarity-over-speed for the
84//! oracle role — indexing is trivially auditable, contraction loops are
85//! obvious, and the redundancy is itself a checked invariant (the algebra
86//! only ever writes symmetric values). Symmetric packing is a later,
87//! profile-justified optimization behind the same API.
88//!
89//! # Deployment ladder (#932)
90//!
91//! 1. This module: the algebra + the program seam + the oracle.
92//! 2. Universal oracle: every hand-written `RowKernel` gains a CI test
93//! asserting channel-by-channel agreement with a `RowNllProgram` written
94//! once — see [`verify_kernel_channels`]. This alone would have caught
95//! #736 at introduction.
96//! 3. Migrate error-dense / cold towers to [`derived_row_kernel`] et al.;
97//! keep hand-tuned hot paths, now verified against the single-expression
98//! truth instead of being the only definition.
//! 4. New families (#914/#916/#917 ZI/ordinal/expectile, #921's
//!    location-scale port) implement ONLY `RowNllProgram` and get an exact
//!    fourth-order tower for the price of writing the likelihood.
102
103use crate::jet_algebra;
104
/// Truncated fourth-order multivariate Taylor scalar in `K` variables.
///
/// See the module documentation for semantics and conventions. The struct
/// stores `1 + K + K² + K³ + K⁴` `f64`s (341 doubles, about 2.7 KiB, at K=4).
/// `Copy` is intentional despite that size: towers are per-row temporaries
/// that live entirely in registers/stack during a row program, and value
/// semantics keep program code readable (`a * b + c`).
///
/// All tensors are stored FULL (every index tuple has its own slot), not
/// symmetric-packed.
#[derive(Clone, Copy, Debug)]
pub struct Tower4<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a, indexed `g[a]`.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric), indexed `h[a][b]`.
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric), indexed
    /// `t3[a][b][c]`.
    pub t3: [[[f64; K]; K]; K],
    /// Fourth derivatives ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d (fully symmetric), indexed
    /// `t4[a][b][c][d]`.
    pub t4: [[[[f64; K]; K]; K]; K],
}
124
125impl<const K: usize> Tower4<K> {
126 /// The additive identity.
127 pub fn zero() -> Self {
128 Self {
129 v: 0.0,
130 g: [0.0; K],
131 h: [[0.0; K]; K],
132 t3: [[[0.0; K]; K]; K],
133 t4: [[[[0.0; K]; K]; K]; K],
134 }
135 }
136
137 /// A constant: value `c`, all derivatives zero.
138 pub fn constant(c: f64) -> Self {
139 let mut out = Self::zero();
140 out.v = c;
141 out
142 }
143
144 /// The seeded variable `p_idx` with current value `value`:
145 /// unit first derivative in slot `idx`, zero elsewhere and above.
146 pub fn variable(value: f64, idx: usize) -> Self {
147 let mut out = Self::constant(value);
148 out.g[idx] = 1.0;
149 out
150 }
151
152 /// Read the (fully symmetric) derivative tensor entry whose differentiation
153 /// axes are `labels` (length 0..=4): value, `g`, `h`, `t3`, `t4`.
154 #[inline]
155 fn deriv(&self, labels: &[usize]) -> f64 {
156 assert!(
157 labels.len() <= 4,
158 "Tower4 carries at most fourth-order derivatives"
159 );
160 match labels.len() {
161 0 => self.v,
162 1 => self.g[labels[0]],
163 2 => self.h[labels[0]][labels[1]],
164 3 => self.t3[labels[0]][labels[1]][labels[2]],
165 _ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
166 }
167 }
168
    /// Exact truncated Leibniz product `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`.
    ///
    /// `self` is the left factor `a`, `o` the right factor `b`. The returned
    /// tower holds the product's value and all of its partial derivatives
    /// through fourth order; every index tuple of every tensor is written
    /// (the tensors are stored full, not symmetric-packed).
    ///
    /// # Codegen
    ///
    /// Each output entry's `2^m` subset sum is written as a compact straight-line
    /// expression instead of the shared [`jet_algebra::leibniz_product`] subset
    /// walker (which, per entry, builds `SlotBuf`s and `match`-dispatches the
    /// `deriv` closure across all `2^m` subsets). The loop nest over `(i,j,k,l)`
    /// is unchanged — only the inner per-entry sum is unrolled — so this does NOT
    /// unroll over `K` and does NOT bloat code: on a `Tower4<9>` mul-and-read
    /// consumer the new form is faster AND smaller (asm: 34 outlined walker `bl`
    /// calls → 0, 21.1 KiB → 14.3 KiB, +100 NEON `.2d` ops).
    ///
    /// BIT-IDENTICAL to the walker: each entry's terms are in the walker's exact
    /// subset-enumeration order (subset bit `b` ↔ position `b`, `sub = 0..2^m`),
    /// and the per-entry `acc` accumulator mirrors the walker's `total = 0.0`
    /// start so a signed-zero leading product collapses to `+0.0` identically —
    /// which matters because real jets carry exact-`0.0` channels
    /// (`constant`/`variable` towers). Proven `to_bits`-identical on
    /// `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`, 5000 inputs each with ~30 %
    /// exact-`0.0` channels and signed values (a no-leading-`0.0` form fails this
    /// stress — the accumulator start is load-bearing).
    ///
    /// Do not reorder the `acc +=` lines or fold them into a single expression:
    /// floating-point addition is not associative and the `0.0` start affects
    /// the sign of an all-zero result.
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        // Order 0: plain product of the values.
        out.v = a.v * b.v;
        // Order 1: (ab)_i = a·b_i + a_i·b.
        for i in 0..K {
            // subsets of {i}: {} {i}
            let mut acc = 0.0;
            acc += a.v * b.g[i]; // T = {}
            acc += a.g[i] * b.v; // T = {i}
            out.g[i] = acc;
        }
        // Order 2: four subsets T of {i,j}; `a` takes T, `b` takes the rest.
        for i in 0..K {
            for j in 0..K {
                // subsets of {i,j}: {} {i} {j} {ij}
                let mut acc = 0.0;
                acc += a.v * b.h[i][j]; // T = {}
                acc += a.g[i] * b.g[j]; // T = {i}
                acc += a.g[j] * b.g[i]; // T = {j}
                acc += a.h[i][j] * b.v; // T = {ij}
                out.h[i][j] = acc;
            }
        }
        // Order 3: eight subsets T of {i,j,k}.
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                    let mut acc = 0.0;
                    acc += a.v * b.t3[i][j][k]; // T = {}
                    acc += a.g[i] * b.h[j][k]; // T = {i}
                    acc += a.g[j] * b.h[i][k]; // T = {j}
                    acc += a.h[i][j] * b.g[k]; // T = {ij}
                    acc += a.g[k] * b.h[i][j]; // T = {k}
                    acc += a.h[i][k] * b.g[j]; // T = {ik}
                    acc += a.h[j][k] * b.g[i]; // T = {jk}
                    acc += a.t3[i][j][k] * b.v; // T = {ijk}
                    out.t3[i][j][k] = acc;
                }
            }
        }
        // Order 4: sixteen subsets T of {i,j,k,l}, with i,j,k,l as bits 0..3.
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // subsets of {i,j,k,l} in bit order sub = 0..16
                        let mut acc = 0.0;
                        acc += a.v * b.t4[i][j][k][l]; // T = {}
                        acc += a.g[i] * b.t3[j][k][l]; // T = {i}
                        acc += a.g[j] * b.t3[i][k][l]; // T = {j}
                        acc += a.h[i][j] * b.h[k][l]; // T = {ij}
                        acc += a.g[k] * b.t3[i][j][l]; // T = {k}
                        acc += a.h[i][k] * b.h[j][l]; // T = {ik}
                        acc += a.h[j][k] * b.h[i][l]; // T = {jk}
                        acc += a.t3[i][j][k] * b.g[l]; // T = {ijk}
                        acc += a.g[l] * b.t3[i][j][k]; // T = {l}
                        acc += a.h[i][l] * b.h[j][k]; // T = {il}
                        acc += a.h[j][l] * b.h[i][k]; // T = {jl}
                        acc += a.t3[i][j][l] * b.g[k]; // T = {ijl}
                        acc += a.h[k][l] * b.h[i][j]; // T = {kl}
                        acc += a.t3[i][k][l] * b.g[j]; // T = {ikl}
                        acc += a.t3[j][k][l] * b.g[i]; // T = {jkl}
                        acc += a.t4[i][j][k][l] * b.v; // T = {ijkl}
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }
260
261 /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
262 /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
263 /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
264 /// out of the borrowed operands.
265 pub fn add(&self, o: &Self) -> Self {
266 *self + *o
267 }
268
269 /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
270 pub fn sub(&self, o: &Self) -> Self {
271 *self + o.scale(-1.0)
272 }
273
    /// Exact multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` evaluated at `u = self.v` —
    /// the SAME `[f64; 5]` stack shape the families' existing
    /// `unary_derivatives_*` helpers produce, so those special-function
    /// stacks (Φ, log-Φ, normal pdf, …) plug in directly.
    ///
    /// The order-m output sums over the set partitions of the m indices
    /// (Bell(3) = 5 terms at order 3, Bell(4) = 15 at order 4), grouped by
    /// block count: each partition into r blocks contributes
    /// `f⁽ʳ⁾ · Π_blocks D_block(u)`.
    ///
    /// Note that `self.v` itself is never read here: the output value is
    /// `d[0]`, so the caller is responsible for having evaluated the stack at
    /// `self.v`.
    ///
    /// # Codegen
    ///
    /// Evaluated as a compact closed form (the Bell(4)=15 set-partitions of
    /// `t4`, Bell(3)=5 of `t3`, …) instead of routing through the recursive
    /// [`jet_algebra::faa_di_bruno`] walker (per-output `for_each_partition`
    /// recursion + per-block `SlotBuf` + closure dispatch). The loop nest is
    /// identical to the walker's (`for i,j,k,l`); only the per-entry partition
    /// sum is straight-line, so this does NOT unroll over `K` and does NOT
    /// bloat code — measured on a `Tower4<9>` compose-and-read consumer the new
    /// form is both faster and SMALLER (asm: 94 outlined walker `bl` calls → 0,
    /// 47.5 KiB → 16.7 KiB, +197 NEON `.2d` ops).
    ///
    /// BIT-IDENTICAL to the walker: each channel's terms are emitted in the
    /// walker's exact partition-enumeration order, each term's block products
    /// are left-associated exactly as the walker's `prod *= block`, and the
    /// per-channel `acc` accumulator mirrors the walker's `total = 0.0` start
    /// (so signed-zero products collapse to `+0.0` identically). The order-4
    /// term sequence was generated from the walker's own enumeration. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3`/`t4` across `K ∈ {2,3,4,9}`,
    /// 5000 random inputs each (zeroed / sign-varied stacks included).
    ///
    /// Do not reorder the `acc +=` lines or re-associate the products inside a
    /// term; either change breaks the bit-identity described above.
    pub fn compose_unary(&self, d: [f64; 5]) -> Self {
        let mut out = Self::zero();
        // Order 0: the composed value is supplied by the caller's stack.
        out.v = d[0];
        // Order 1: single partition {i}.
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i]; // {i}
            out.g[i] = acc;
        }
        // Order 2: partitions {ij} and {i}{j}.
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j]; // {ij}
                acc += d[2] * self.g[i] * self.g[j]; // {i}{j}
                out.h[i][j] = acc;
            }
        }
        // Order 3: Bell(3) = 5 partitions.
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
                    let mut acc = 0.0;
                    acc += d[1] * self.t3[i][j][k]; // {ijk}
                    acc += d[2] * self.h[i][j] * self.g[k]; // {ij}{k}
                    acc += d[2] * self.h[i][k] * self.g[j]; // {ik}{j}
                    acc += d[2] * self.g[i] * self.h[j][k]; // {i}{jk}
                    acc += d[3] * self.g[i] * self.g[j] * self.g[k]; // {i}{j}{k}
                    out.t3[i][j][k] = acc;
                }
            }
        }
        // Order 4: Bell(4) = 15 partitions (1 one-block, 7 two-block,
        // 6 three-block, 1 four-block); `d[r]` multiplies r-block partitions.
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    for l in 0..K {
                        // Bell(4)=15 partitions, walker enumeration order.
                        let mut acc = 0.0;
                        acc += d[1] * self.t4[i][j][k][l]; // {ijkl}
                        acc += d[2] * self.t3[i][j][k] * self.g[l]; // {ijk}{l}
                        acc += d[2] * self.t3[i][j][l] * self.g[k]; // {ijl}{k}
                        acc += d[2] * self.h[i][j] * self.h[k][l]; // {ij}{kl}
                        acc += d[3] * self.h[i][j] * self.g[k] * self.g[l]; // {ij}{k}{l}
                        acc += d[2] * self.t3[i][k][l] * self.g[j]; // {ikl}{j}
                        acc += d[2] * self.h[i][k] * self.h[j][l]; // {ik}{jl}
                        acc += d[3] * self.h[i][k] * self.g[j] * self.g[l]; // {ik}{j}{l}
                        acc += d[2] * self.h[i][l] * self.h[j][k]; // {il}{jk}
                        acc += d[2] * self.g[i] * self.t3[j][k][l]; // {i}{jkl}
                        acc += d[3] * self.g[i] * self.h[j][k] * self.g[l]; // {i}{jk}{l}
                        acc += d[3] * self.h[i][l] * self.g[j] * self.g[k]; // {il}{j}{k}
                        acc += d[3] * self.g[i] * self.h[j][l] * self.g[k]; // {i}{jl}{k}
                        acc += d[3] * self.g[i] * self.g[j] * self.h[k][l]; // {i}{j}{kl}
                        acc += d[4] * self.g[i] * self.g[j] * self.g[k] * self.g[l]; // {i}{j}{k}{l}
                        out.t4[i][j][k][l] = acc;
                    }
                }
            }
        }
        out
    }
364
365 /// Compose with a unary special-function whose `[f64; 5]` derivative STACK is
366 /// built from the base value through `stack_fn` — the scalar arm of the
367 /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
368 /// [`Tower4Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
369 /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
370 /// `self.compose_unary(stack_fn(self.v))`. Writing a program against this seam
371 /// lets it re-instantiate, unchanged, at [`Tower4Lane`] (where each of the four
372 /// lanes carries a DISTINCT base value and `stack_fn` is re-run per lane).
373 #[inline]
374 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
375 self.compose_unary(stack_fn(self.v))
376 }
377
    /// Single-active-slot fast path for [`Self::compose_unary`].
    ///
    /// When the inner jet `self` has derivative support ONLY on the all-`slot`
    /// diagonal channels — i.e. it is a univariate jet in primary `slot`
    /// scattered into the `K`-wide layout (`g[a] = 0`, `h[a][b] = 0`,
    /// `t3 = 0`, `t4 = 0` for any axis `≠ slot`) — the multivariate Faà di
    /// Bruno walk collapses. Every output channel whose axis tuple contains an
    /// axis `≠ slot` is structurally `0`: each set-partition has a block
    /// covering that axis, that block reads an off-`slot` derivative of `self`
    /// (which is `0`), so the block product and the whole partition vanish, and
    /// the channel sums to the walker's `total = 0.0` start, i.e. `+0.0`. Only
    /// the five diagonal channels (`v`, `g[slot]`, `h[slot][slot]`,
    /// `t3[slot]³`, `t4[slot]⁴`) survive.
    ///
    /// This computes exactly those five as STRAIGHT-LINE accumulations, each in
    /// the EXACT term order of [`Self::compose_unary`]'s diagonal
    /// (`i = j = k = l = slot`) case — so they are BIT-IDENTICAL to
    /// [`Self::compose_unary`] on the diagonal — and leaves every other channel
    /// at the zero-init `+0.0`, which the full walk also produces (the
    /// off-`slot` collapse is `to_bits`-`+0.0`, signed-zero products included;
    /// proven across `K ∈ {2,3,4,9}`, 5000 single-slot inputs each). At any
    /// `K ≥ 2` this is far fewer floating-point operations than materialising
    /// the full `1 + K + K² + K³ + K⁴` channel set whose off-diagonal entries
    /// are all zero, and far cheaper than the recursive set-partition walker the
    /// diagonal channels previously routed through (a measured ~9.5× speedup vs
    /// the full `compose_unary`, recovering a 5.9× walker regression at the
    /// `K ∈ {2,3}` BMS tower widths).
    ///
    /// `#[inline]` so an adopting consumer pays no `bl` call (uninlined, the
    /// five-channel build does not amortise the call/spill overhead).
    ///
    /// # Parameters
    ///
    /// * `d` — `[f, f′, f″, f‴, f⁗]` evaluated at `self.v` (same stack as
    ///   [`Self::compose_unary`]; `self.v` itself is not read here).
    /// * `slot` — the single primary index on which `self` has derivative
    ///   support.
    ///
    /// # Precondition
    ///
    /// The caller guarantees the single-active-slot structure. If it does not
    /// hold, the off-`slot` channels would be wrongly zeroed; use the full
    /// [`Self::compose_unary`] in that case. The precondition is NOT checked.
    ///
    /// # Panics
    ///
    /// Panics if `slot >= K` (out-of-bounds tensor index).
    #[inline]
    pub fn compose_unary_single_slot(&self, d: [f64; 5], slot: usize) -> Self {
        let mut out = Self::zero();
        let s = slot;
        // The only inner derivatives that can be non-zero under the
        // precondition: the pure-`slot` diagonal of each tensor.
        let g = self.g[s];
        let h = self.h[s][s];
        let t3 = self.t3[s][s][s];
        let t4 = self.t4[s][s][s][s];
        out.v = d[0];
        // g (i=s): d1*g
        out.g[s] = {
            let mut acc = 0.0;
            acc += d[1] * g;
            acc
        };
        // h (i=j=s): d1*h + d2*g*g
        out.h[s][s] = {
            let mut acc = 0.0;
            acc += d[1] * h;
            acc += d[2] * g * g;
            acc
        };
        // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
        // The repeated-looking lines are distinct partitions that coincide on
        // the diagonal; they are kept separate (not merged into `3.0 * …`) to
        // preserve the summation order.
        out.t3[s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t3;
            acc += d[2] * h * g;
            acc += d[2] * h * g;
            acc += d[2] * g * h;
            acc += d[3] * g * g * g;
            acc
        };
        // t4 (i=j=k=l=s): exact term order of compose_unary's inner loop.
        // Fifteen terms, one per set partition of four indices, again left
        // unmerged so the accumulation matches the full walk term-for-term.
        out.t4[s][s][s][s] = {
            let mut acc = 0.0;
            acc += d[1] * t4;
            acc += d[2] * t3 * g;
            acc += d[2] * t3 * g;
            acc += d[2] * h * h;
            acc += d[3] * h * g * g;
            acc += d[2] * t3 * g;
            acc += d[2] * h * h;
            acc += d[3] * h * g * g;
            acc += d[2] * h * h;
            acc += d[2] * g * t3;
            acc += d[3] * g * h * g;
            acc += d[3] * h * g * g;
            acc += d[3] * g * h * g;
            acc += d[3] * g * g * h;
            acc += d[4] * g * g * g * g;
            acc
        };
        out
    }
468
469 /// Multiply every channel by a plain scalar.
470 pub fn scale(&self, s: f64) -> Self {
471 let mut out = *self;
472 out.v *= s;
473 for i in 0..K {
474 out.g[i] *= s;
475 for j in 0..K {
476 out.h[i][j] *= s;
477 for k in 0..K {
478 out.t3[i][j][k] *= s;
479 for l in 0..K {
480 out.t4[i][j][k][l] *= s;
481 }
482 }
483 }
484 }
485 out
486 }
487
488 /// e^self.
489 pub fn exp(&self) -> Self {
490 let e = self.v.exp();
491 self.compose_unary([e, e, e, e, e])
492 }
493
494 /// ln(self). Caller guarantees positivity (likelihood programs do).
495 pub fn ln(&self) -> Self {
496 let u = self.v;
497 let r = 1.0 / u;
498 self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
499 }
500
501 /// 1/self.
502 pub fn recip(&self) -> Self {
503 let r = 1.0 / self.v;
504 let r2 = r * r;
505 self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
506 }
507
508 /// √self. Caller guarantees positivity.
509 pub fn sqrt(&self) -> Self {
510 let u = self.v;
511 let s = u.sqrt();
512 self.compose_unary([
513 s,
514 0.5 / s,
515 -0.25 / (u * s),
516 0.375 / (u * u * s),
517 -0.9375 / (u * u * u * s),
518 ])
519 }
520
521 /// self^a for real exponent `a`. Caller guarantees a positive base.
522 pub fn powf(&self, a: f64) -> Self {
523 let u = self.v;
524 let f0 = u.powf(a);
525 let f1 = a * u.powf(a - 1.0);
526 let f2 = a * (a - 1.0) * u.powf(a - 2.0);
527 let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
528 let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
529 self.compose_unary([f0, f1, f2, f3, f4])
530 }
531
532 /// ln Γ(self). Caller guarantees positivity.
533 pub fn ln_gamma(&self) -> Self {
534 self.compose_unary(ln_gamma_derivative_stack(self.v))
535 }
536
537 /// ψ(self), the digamma function. Caller guarantees positivity.
538 pub fn digamma(&self) -> Self {
539 self.compose_unary(digamma_derivative_stack(self.v))
540 }
541
542 /// ψ′(self), the trigamma function. Caller guarantees positivity.
543 pub fn trigamma(&self) -> Self {
544 self.compose_unary(trigamma_derivative_stack(self.v))
545 }
546
547 /// Contract `t3` with one primary-space direction:
548 /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]` — exactly the
549 /// `row_third_contracted` shape.
550 pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
551 let mut out = [[0.0; K]; K];
552 for a in 0..K {
553 for b in 0..K {
554 let mut acc = 0.0;
555 for c in 0..K {
556 acc += self.t3[a][b][c] * dir[c];
557 }
558 out[a][b] = acc;
559 }
560 }
561 out
562 }
563
564 /// Contract `t4` with two primary-space directions:
565 /// `out[a][b] = Σ_{c,d} t4[a][b][c][d] · u[c] · v[d]` — exactly the
566 /// `row_fourth_contracted` shape.
567 pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
568 let mut out = [[0.0; K]; K];
569 for i in 0..K {
570 for j in 0..K {
571 let mut acc = 0.0;
572 for k in 0..K {
573 for l in 0..K {
574 acc += self.t4[i][j][k][l] * u[k] * w[l];
575 }
576 }
577 out[i][j] = acc;
578 }
579 }
580 out
581 }
582}
583
584impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
585 #[inline]
586 fn derivative(&self, labels: &[usize]) -> f64 {
587 self.deriv(labels)
588 }
589
590 fn map_derivatives<F>(&self, mut f: F) -> Self
591 where
592 F: FnMut(&[usize]) -> f64,
593 {
594 let mut out = Self::zero();
595 out.v = f(&[]);
596 for i in 0..K {
597 let labels = [i];
598 out.g[i] = f(&labels);
599 }
600 for i in 0..K {
601 for j in 0..K {
602 let labels = [i, j];
603 out.h[i][j] = f(&labels);
604 }
605 }
606 for i in 0..K {
607 for j in 0..K {
608 for k in 0..K {
609 let labels = [i, j, k];
610 out.t3[i][j][k] = f(&labels);
611 }
612 }
613 }
614 for i in 0..K {
615 for j in 0..K {
616 for k in 0..K {
617 for l in 0..K {
618 let labels = [i, j, k, l];
619 out.t4[i][j][k][l] = f(&labels);
620 }
621 }
622 }
623 }
624 out
625 }
626}
627
/// Truncated SECOND-order multivariate Taylor scalar in `K` variables.
///
/// This is the value/gradient/Hessian-only sibling of [`Tower4`]. Every
/// channel it carries (`v`, `g`, `h`) is computed by the SAME formulas
/// [`Tower4`] uses for those orders, so for any program written over both
/// towers the order-≤2 outputs are *bit-identical*: the order-2 Leibniz and
/// Faà-di-Bruno terms read only the order-≤2 channels of their inputs (see
/// [`Tower4::mul`] / [`Tower4::compose_unary`] — `out.h` never touches `t3`
/// or `t4`), so dropping the third/fourth tensors cannot perturb the value,
/// gradient, or Hessian.
///
/// It exists purely for performance: an inner Newton step (and the
/// value-only ρ-homotopy pre-warm) needs at most curvature, never the
/// outer-κ/ψ third/fourth derivatives. Evaluating a row likelihood over
/// `Tower2` skips the `K⁴` fourth-tensor product/composition arithmetic that
/// dominates the cold marginal-slope fit, while returning the exact same
/// `(v, g, h)`.
#[derive(Clone, Copy, Debug)]
pub struct Tower2<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
}

impl<const K: usize> Tower2<K> {
    /// The additive identity.
    pub fn zero() -> Self {
        Self {
            v: 0.0,
            g: [0.0; K],
            h: [[0.0; K]; K],
        }
    }

    /// A constant: value `c`, all derivatives zero.
    pub fn constant(c: f64) -> Self {
        let mut out = Self::zero();
        out.v = c;
        out
    }

    /// The seeded variable `p_idx` with current value `value`:
    /// unit first derivative in slot `idx`, zero elsewhere and above.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= K`.
    pub fn variable(value: f64, idx: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[idx] = 1.0;
        out
    }

    /// Read the derivative tensor entry whose differentiation axes are
    /// `labels` (length 0..=2): value, `g`, `h`.
    ///
    /// # Panics
    ///
    /// Panics if more than two labels are given, or if any label is `>= K`.
    #[inline]
    fn deriv(&self, labels: &[usize]) -> f64 {
        assert!(
            labels.len() <= 2,
            "Tower2 carries at most second-order derivatives"
        );
        match labels.len() {
            0 => self.v,
            1 => self.g[labels[0]],
            _ => self.h[labels[0]][labels[1]],
        }
    }

    /// Exact truncated (order ≤ 2) Leibniz product. The `v`/`g`/`h` channels
    /// match [`Tower4::mul`] term-for-term.
    ///
    /// Each `g`/`h` entry is accumulated into an explicit `acc` that starts at
    /// `0.0`, exactly as [`Tower4::mul`] does. The start value is load-bearing
    /// for bit-identity: when every product in an entry is `-0.0` (e.g. the
    /// gradient of a product of two negative constants), a bare
    /// `p1 + p2 + …` sum yields `-0.0`, whereas the `0.0`-seeded accumulator —
    /// and therefore `Tower4` — yields `+0.0`.
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            // subsets of {i}: {} {i}
            let mut acc = 0.0;
            acc += a.v * b.g[i];
            acc += a.g[i] * b.v;
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                // subsets of {i,j}: {} {i} {j} {ij}
                let mut acc = 0.0;
                acc += a.v * b.h[i][j];
                acc += a.g[i] * b.g[j];
                acc += a.g[j] * b.g[i];
                acc += a.h[i][j] * b.v;
                out.h[i][j] = acc;
            }
        }
        out
    }

    /// Exact (order ≤ 2) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u)]` evaluated at `u = self.v`. The `v`/`g`/`h`
    /// channels match [`Tower4::compose_unary`] term-for-term (which uses only
    /// `d[0..=2]` for those orders), so this is a strict truncation, not an
    /// approximation. The full-order `[f64; 5]` derivative stacks the families
    /// already produce can be passed by slicing their first three entries.
    ///
    /// # Codegen
    ///
    /// Order-≤2 Faà di Bruno is a tiny closed form, so this evaluates it
    /// directly instead of routing through the generic
    /// [`jet_algebra::faa_di_bruno`] set-partition walker (recursion + per-block
    /// closure dispatch). That matters because this is the kernel under EVERY
    /// packed scalar — [`crate::jet_scalar::Order2`] / `OneSeed` / `TwoSeed`
    /// composition all bottom out here — so the straight-line form (whose inner
    /// loops auto-vectorise to NEON/SSE 2-wide and which emits zero outlined
    /// walker calls) lifts all of them at once.
    ///
    /// The term and accumulation order is BIT-IDENTICAL to the walker it
    /// replaces: each output channel mirrors the walker's `total = 0.0` start
    /// (the explicit `acc` accumulator), so a signed-zero product collapses to
    /// `+0.0` exactly as `total += prod` does. Proven `to_bits`-identical on
    /// `v`/`g`/`h` across `K ∈ {2,3,4,9}`, 5000 random inputs each (incl.
    /// zeroed / sign-varied stacks). The order-≤2 walker partitions are:
    ///   `g[i]`    = `f′·u_i`                       (single block `{i}`)
    ///   `h[i][j]` = `f′·u_ij + (f″·u_i)·u_j`       (blocks `{ij}` then `{i}{j}`),
    /// with `f′ = d[1]`, `f″ = d[2]`, `u_* = self.{g,h}`.
    pub fn compose_unary(&self, d: [f64; 3]) -> Self {
        let mut out = Self::zero();
        // The composed value comes straight from the caller's stack;
        // `self.v` is not read here.
        out.v = d[0];
        for i in 0..K {
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        out
    }

    /// Multiply every channel by a plain scalar.
    pub fn scale(&self, s: f64) -> Self {
        let mut out = *self;
        out.v *= s;
        for i in 0..K {
            out.g[i] *= s;
            for j in 0..K {
                out.h[i][j] *= s;
            }
        }
        out
    }

    /// e^self. Every derivative of `exp` equals `exp`.
    pub fn exp(&self) -> Self {
        let e = self.v.exp();
        self.compose_unary([e, e, e])
    }

    /// √self. Caller guarantees positivity.
    pub fn sqrt(&self) -> Self {
        let u = self.v;
        let s = u.sqrt();
        self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
    }
}
786
787impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
788 #[inline]
789 fn derivative(&self, labels: &[usize]) -> f64 {
790 self.deriv(labels)
791 }
792
793 fn map_derivatives<F>(&self, mut f: F) -> Self
794 where
795 F: FnMut(&[usize]) -> f64,
796 {
797 let mut out = Self::zero();
798 out.v = f(&[]);
799 for i in 0..K {
800 let labels = [i];
801 out.g[i] = f(&labels);
802 }
803 for i in 0..K {
804 for j in 0..K {
805 let labels = [i, j];
806 out.h[i][j] = f(&labels);
807 }
808 }
809 out
810 }
811}
812
813impl<const K: usize> std::ops::Add for Tower2<K> {
814 type Output = Self;
815 fn add(self, o: Self) -> Self {
816 let mut out = self;
817 out.v += o.v;
818 for i in 0..K {
819 out.g[i] += o.g[i];
820 for j in 0..K {
821 out.h[i][j] += o.h[i][j];
822 }
823 }
824 out
825 }
826}
827
828impl<const K: usize> std::ops::Mul for Tower2<K> {
829 type Output = Self;
830 fn mul(self, o: Self) -> Self {
831 Tower2::mul(&self, &o)
832 }
833}
834
835impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
836 type Output = Self;
837 fn add(self, c: f64) -> Self {
838 let mut out = self;
839 out.v += c;
840 out
841 }
842}
843
844impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
845 type Output = Self;
846 fn mul(self, c: f64) -> Self {
847 self.scale(c)
848 }
849}
850
/// Truncated THIRD-order multivariate Taylor scalar in `K` variables.
///
/// The value/gradient/Hessian/third-derivative sibling of [`Tower4`], standing
/// between [`Tower2`] and [`Tower4`]. Every channel it carries (`v`, `g`, `h`,
/// `t3`) is computed by the SAME shared Leibniz / Faà-di-Bruno kernels
/// [`Tower4`] uses for those orders, and the order-≤3 terms of those kernels
/// read only the order-≤3 channels of their inputs (the order-3 Faà-di-Bruno
/// partitions never reach the f⁗ stack slot or the inner `t4` tensor — see
/// [`Tower4::compose_unary`]). So for any program written over both towers the
/// order-≤3 outputs are *bit-identical*: dropping the fourth tensor cannot
/// perturb the value, gradient, Hessian, or third derivatives.
///
/// It exists purely for performance, exactly like [`Tower2`]: a consumer that
/// needs up to third derivatives (the survival location-scale row kernel reads
/// `g`, the diagonal `h`, and the diagonal `t3`, but never `t4`) pays the
/// `K³` third-tensor arithmetic but skips the `K⁴` fourth-tensor
/// product/composition that otherwise dominates the per-row cost.
///
/// The struct stores `1 + K + K² + K³` `f64`s; tensors are stored full
/// (every index tuple has its own slot), not symmetric-packed.
#[derive(Clone, Copy, Debug)]
pub struct Tower3<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a, indexed `g[a]`.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric), indexed `h[a][b]`.
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric), indexed
    /// `t3[a][b][c]`.
    pub t3: [[[f64; K]; K]; K],
}
879
880impl<const K: usize> Tower3<K> {
881 /// The additive identity.
882 pub fn zero() -> Self {
883 Self {
884 v: 0.0,
885 g: [0.0; K],
886 h: [[0.0; K]; K],
887 t3: [[[0.0; K]; K]; K],
888 }
889 }
890
891 /// A constant: value `c`, all derivatives zero.
892 pub fn constant(c: f64) -> Self {
893 let mut out = Self::zero();
894 out.v = c;
895 out
896 }
897
898 /// The seeded variable `p_idx` with current value `value`:
899 /// unit first derivative in slot `idx`, zero elsewhere and above.
900 pub fn variable(value: f64, idx: usize) -> Self {
901 let mut out = Self::constant(value);
902 out.g[idx] = 1.0;
903 out
904 }
905
906 /// Read the (fully symmetric) derivative tensor entry whose differentiation
907 /// axes are `labels` (length 0..=3): value, `g`, `h`, `t3`.
908 #[inline]
909 fn deriv(&self, labels: &[usize]) -> f64 {
910 assert!(
911 labels.len() <= 3,
912 "Tower3 carries at most third-order derivatives"
913 );
914 match labels.len() {
915 0 => self.v,
916 1 => self.g[labels[0]],
917 2 => self.h[labels[0]][labels[1]],
918 _ => self.t3[labels[0]][labels[1]][labels[2]],
919 }
920 }
921
    /// Exact truncated (order ≤ 3) Leibniz product. The `v`/`g`/`h`/`t3`
    /// channels match [`Tower4::mul`] term-for-term.
    ///
    /// Each output entry is the sum, over every subset `S` of that entry's
    /// differentiation axes, of `self`'s derivative along `S` times `o`'s
    /// derivative along the complementary axes.
    ///
    /// # Codegen
    ///
    /// Straight-line per-entry subset sums instead of the
    /// [`jet_algebra::leibniz_product`] walker — the order-≤3 sibling of
    /// [`Tower4::mul`] (no `t4`). Loop nest unchanged, no unroll over `K`, no
    /// code bloat; auto-vectorises. BIT-IDENTICAL: terms in the walker's exact
    /// subset order with an `acc = 0.0` accumulator start (load-bearing for the
    /// signed-zero leading product on exact-`0.0` jet channels). Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// zero/sign-stressed inputs each (these channel formulas are exactly the
    /// `g`/`h`/`t3` of the [`Tower4::mul`] oracle, which passes that stress).
    ///
    /// Do not reorder the `acc +=` lines or fold them into an iterator `sum()`:
    /// the accumulation order and the `0.0` start are part of the contract
    /// stated above.
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let mut out = Self::zero();
        out.v = a.v * b.v;
        for i in 0..K {
            // subsets of {i} taken from `a`: {} {i}
            let mut acc = 0.0;
            acc += a.v * b.g[i];
            acc += a.g[i] * b.v;
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                // subsets of {i,j} taken from `a`: {} {i} {j} {ij}
                let mut acc = 0.0;
                acc += a.v * b.h[i][j];
                acc += a.g[i] * b.g[j];
                acc += a.g[j] * b.g[i];
                acc += a.h[i][j] * b.v;
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // subsets of {i,j,k}: {} {i} {j} {ij} {k} {ik} {jk} {ijk}
                    let mut acc = 0.0;
                    acc += a.v * b.t3[i][j][k];
                    acc += a.g[i] * b.h[j][k];
                    acc += a.g[j] * b.h[i][k];
                    acc += a.h[i][j] * b.g[k];
                    acc += a.g[k] * b.h[i][j];
                    acc += a.h[i][k] * b.g[j];
                    acc += a.h[j][k] * b.g[i];
                    acc += a.t3[i][j][k] * b.v;
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }
976
977 /// Ref-taking elementwise sum, the by-ref twin of the `std::ops::Add`
978 /// operator (which consumes by value). Mirrors the inherent `mul`/`scale`
979 /// API so a chain like `a.mul(&b).add(&c)` reads uniformly without moving
980 /// out of the borrowed operands.
981 pub fn add(&self, o: &Self) -> Self {
982 *self + *o
983 }
984
985 /// Ref-taking elementwise difference, the by-ref twin of `std::ops::Sub`.
986 pub fn sub(&self, o: &Self) -> Self {
987 *self + o.scale(-1.0)
988 }
989
    /// Exact (order ≤ 3) multivariate Faà di Bruno composition `f ∘ self`.
    ///
    /// `d = [f(u), f′(u), f″(u), f‴(u)]` evaluated at `u = self.v`. The
    /// `v`/`g`/`h`/`t3` channels match [`Tower4::compose_unary`] term-for-term
    /// (which uses only `d[0..=3]` for those orders), so this is a strict
    /// truncation, not an approximation. The full-order `[f64; 5]` derivative
    /// stacks the families already produce can be passed by slicing their first
    /// four entries.
    ///
    /// Note that `self.v` is never read here: the caller is responsible for
    /// having evaluated `d` at `self.v`.
    ///
    /// # Codegen
    ///
    /// Order-≤3 Faà di Bruno written as a compact closed form instead of the
    /// recursive [`jet_algebra::faa_di_bruno`] walker — the order-≤3 sibling of
    /// [`Tower4::compose_unary`], one tensor order shallower. The loop nest is
    /// unchanged (no unroll over `K`, no code bloat: measured on a `Tower3<9>`
    /// compose-and-read consumer the new form is faster and SMALLER — asm: 71
    /// walker `bl` calls → 0, 39.5 KiB → 13.9 KiB, +197 NEON `.2d` ops).
    /// BIT-IDENTICAL: terms in the walker's exact partition order, left-
    /// associated block products, `acc = 0.0` accumulator start. Proven
    /// `to_bits`-identical on `v`/`g`/`h`/`t3` across `K ∈ {2,3,4,9}`, 5000
    /// random inputs each.
    pub fn compose_unary(&self, d: [f64; 4]) -> Self {
        let mut out = Self::zero();
        out.v = d[0];
        for i in 0..K {
            // single partition {i}
            let mut acc = 0.0;
            acc += d[1] * self.g[i];
            out.g[i] = acc;
        }
        for i in 0..K {
            for j in 0..K {
                // partitions of {i,j}: {ij} {i}{j}
                let mut acc = 0.0;
                acc += d[1] * self.h[i][j];
                acc += d[2] * self.g[i] * self.g[j];
                out.h[i][j] = acc;
            }
        }
        for i in 0..K {
            for j in 0..K {
                for k in 0..K {
                    // walker partitions: {ijk} {ij}{k} {ik}{j} {i}{jk} {i}{j}{k}
                    let mut acc = 0.0;
                    acc += d[1] * self.t3[i][j][k];
                    acc += d[2] * self.h[i][j] * self.g[k];
                    acc += d[2] * self.h[i][k] * self.g[j];
                    acc += d[2] * self.g[i] * self.h[j][k];
                    acc += d[3] * self.g[i] * self.g[j] * self.g[k];
                    out.t3[i][j][k] = acc;
                }
            }
        }
        out
    }
1043
1044 /// Compose with a unary special-function whose `[f64; 4]` derivative STACK is
1045 /// built from the base value through `stack_fn` — the scalar arm of the
1046 /// generic-over-[`Lane`](crate::jet_scalar::Lane) compose seam (see
1047 /// [`Tower3Lane::compose_unary_with`]). Evaluates `stack_fn(self.v)` ONCE and
1048 /// forwards to [`Self::compose_unary`], so it is BIT-IDENTICAL to the explicit
1049 /// `self.compose_unary(stack_fn(self.v))`. The order-≤3 sibling of
1050 /// [`Tower4::compose_unary_with`].
1051 #[inline]
1052 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
1053 self.compose_unary(stack_fn(self.v))
1054 }
1055
1056 /// Single-active-slot fast path for [`Self::compose_unary`] — the order-≤3
1057 /// sibling of [`Tower4::compose_unary_single_slot`]. When `self` carries
1058 /// derivative support only on the all-`slot` diagonal, every output channel
1059 /// touching an axis `≠ slot` collapses to the walker's `total = 0.0` start
1060 /// (`+0.0`), so only `v`, `g[slot]`, `h[slot][slot]`, `t3[slot]³` survive.
1061 /// These four are computed as STRAIGHT-LINE accumulations, each in the EXACT
1062 /// term order of [`Self::compose_unary`]'s diagonal (`i = j = k = slot`)
1063 /// case (BIT-IDENTICAL to the full path on the diagonal); off-`slot`
1064 /// channels stay at the zero-init `+0.0` the full walk also yields (proven
1065 /// `to_bits` across `K ∈ {2,3,4,9}`). This drops the recursive
1066 /// set-partition walker the diagonal channels previously routed through,
1067 /// recovering its measured ~5.9× regression at the `K ∈ {2,3}` BMS tower
1068 /// widths. Caller guarantees the single-slot precondition; otherwise use
1069 /// [`Self::compose_unary`].
1070 #[inline]
1071 pub fn compose_unary_single_slot(&self, d: [f64; 4], slot: usize) -> Self {
1072 let mut out = Self::zero();
1073 let s = slot;
1074 let g = self.g[s];
1075 let h = self.h[s][s];
1076 let t3 = self.t3[s][s][s];
1077 out.v = d[0];
1078 // g (i=s): d1*g
1079 out.g[s] = {
1080 let mut acc = 0.0;
1081 acc += d[1] * g;
1082 acc
1083 };
1084 // h (i=j=s): d1*h + d2*g*g
1085 out.h[s][s] = {
1086 let mut acc = 0.0;
1087 acc += d[1] * h;
1088 acc += d[2] * g * g;
1089 acc
1090 };
1091 // t3 (i=j=k=s): exact term order of compose_unary's inner loop.
1092 out.t3[s][s][s] = {
1093 let mut acc = 0.0;
1094 acc += d[1] * t3;
1095 acc += d[2] * h * g;
1096 acc += d[2] * h * g;
1097 acc += d[2] * g * h;
1098 acc += d[3] * g * g * g;
1099 acc
1100 };
1101 out
1102 }
1103
1104 /// Multiply every channel by a plain scalar.
1105 pub fn scale(&self, s: f64) -> Self {
1106 let mut out = *self;
1107 out.v *= s;
1108 for i in 0..K {
1109 out.g[i] *= s;
1110 for j in 0..K {
1111 out.h[i][j] *= s;
1112 for k in 0..K {
1113 out.t3[i][j][k] *= s;
1114 }
1115 }
1116 }
1117 out
1118 }
1119}
1120
1121impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
1122 #[inline]
1123 fn derivative(&self, labels: &[usize]) -> f64 {
1124 self.deriv(labels)
1125 }
1126
1127 fn map_derivatives<F>(&self, mut f: F) -> Self
1128 where
1129 F: FnMut(&[usize]) -> f64,
1130 {
1131 let mut out = Self::zero();
1132 out.v = f(&[]);
1133 for i in 0..K {
1134 let labels = [i];
1135 out.g[i] = f(&labels);
1136 }
1137 for i in 0..K {
1138 for j in 0..K {
1139 let labels = [i, j];
1140 out.h[i][j] = f(&labels);
1141 }
1142 }
1143 for i in 0..K {
1144 for j in 0..K {
1145 for k in 0..K {
1146 let labels = [i, j, k];
1147 out.t3[i][j][k] = f(&labels);
1148 }
1149 }
1150 }
1151 out
1152 }
1153}
1154
1155impl<const K: usize> std::ops::Add for Tower3<K> {
1156 type Output = Self;
1157 fn add(self, o: Self) -> Self {
1158 let mut out = self;
1159 out.v += o.v;
1160 for i in 0..K {
1161 out.g[i] += o.g[i];
1162 for j in 0..K {
1163 out.h[i][j] += o.h[i][j];
1164 for k in 0..K {
1165 out.t3[i][j][k] += o.t3[i][j][k];
1166 }
1167 }
1168 }
1169 out
1170 }
1171}
1172
1173pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
1174 [
1175 statrs::function::gamma::ln_gamma(x),
1176 digamma_positive(x),
1177 polygamma_positive(1, x),
1178 polygamma_positive(2, x),
1179 polygamma_positive(3, x),
1180 ]
1181}
1182
1183pub fn ln_gamma_derivative_stack_order2(x: f64) -> [f64; 3] {
1184 [
1185 statrs::function::gamma::ln_gamma(x),
1186 digamma_positive(x),
1187 polygamma_positive(1, x),
1188 ]
1189}
1190
1191pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
1192 [
1193 digamma_positive(x),
1194 polygamma_positive(1, x),
1195 polygamma_positive(2, x),
1196 polygamma_positive(3, x),
1197 polygamma_positive(4, x),
1198 ]
1199}
1200
1201pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
1202 [
1203 polygamma_positive(1, x),
1204 polygamma_positive(2, x),
1205 polygamma_positive(3, x),
1206 polygamma_positive(4, x),
1207 polygamma_positive(5, x),
1208 ]
1209}
1210
1211fn digamma_positive(mut x: f64) -> f64 {
1212 if !(x.is_finite() && x > 0.0) {
1213 return f64::NAN;
1214 }
1215 let mut acc = 0.0;
1216 while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1217 acc -= 1.0 / x;
1218 x += 1.0;
1219 }
1220 acc + digamma_asymptotic(x)
1221}
1222
1223fn polygamma_positive(order: usize, mut x: f64) -> f64 {
1224 if !(x.is_finite() && x > 0.0) {
1225 return f64::NAN;
1226 }
1227 let mut acc = 0.0;
1228 while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
1229 acc += polygamma_recurrence_term(order, x);
1230 x += 1.0;
1231 }
1232 acc + polygamma_asymptotic(order, x)
1233}
1234
/// Argument threshold for the digamma/polygamma evaluators: below it
/// `digamma_positive` and `polygamma_positive` apply recurrence steps
/// (`x += 1.0`), and once the argument reaches it they switch to the
/// asymptotic series.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;
/// `(2k, B_{2k})` pairs for `k = 1..=10`: the even-index Bernoulli numbers
/// `B₂ … B₂₀`, iterated in this order by `digamma_asymptotic` and
/// `polygamma_asymptotic` to form their series corrections.
const BERNOULLI_EVEN: [(usize, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];
1248
1249fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
1250 let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1251 sign * factorial(order) / x.powi((order + 1) as i32)
1252}
1253
1254fn digamma_asymptotic(x: f64) -> f64 {
1255 let mut out = x.ln() - 0.5 / x;
1256 for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1257 out -= bernoulli / (bernoulli_order as f64 * x.powi(bernoulli_order as i32));
1258 }
1259 out
1260}
1261
1262fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
1263 if !(1..=5).contains(&order) {
1264 return f64::NAN;
1265 }
1266
1267 let order_factorial = factorial(order);
1268 let leading_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1269 let mut out = leading_sign * factorial(order - 1) / x.powi(order as i32)
1270 + leading_sign * order_factorial / (2.0 * x.powi((order + 1) as i32));
1271
1272 let bernoulli_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
1273 for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
1274 let rising = rising_factorial(bernoulli_order, order);
1275 out += bernoulli_sign * bernoulli * rising
1276 / bernoulli_order as f64
1277 / x.powi((bernoulli_order + order) as i32);
1278 }
1279 out
1280}
1281
/// `n!` as an `f64`, multiplying `1, 2, …, n` in ascending order
/// (`factorial(0)` is `1.0`).
fn factorial(n: usize) -> f64 {
    (1..=n).map(|k| k as f64).product()
}
1285
/// Product of the `len` consecutive integers starting at `start`, i.e.
/// `start · (start + 1) · … · (start + len − 1)`, as an `f64`. An empty run
/// (`len == 0`) gives `1.0`.
fn rising_factorial(start: usize, len: usize) -> f64 {
    (start..start + len).map(|k| k as f64).product()
}
1289
1290impl<const K: usize> std::ops::Add for Tower4<K> {
1291 type Output = Self;
1292 fn add(self, o: Self) -> Self {
1293 let mut out = self;
1294 out.v += o.v;
1295 for i in 0..K {
1296 out.g[i] += o.g[i];
1297 for j in 0..K {
1298 out.h[i][j] += o.h[i][j];
1299 for k in 0..K {
1300 out.t3[i][j][k] += o.t3[i][j][k];
1301 for l in 0..K {
1302 out.t4[i][j][k][l] += o.t4[i][j][k][l];
1303 }
1304 }
1305 }
1306 }
1307 out
1308 }
1309}
1310
1311impl<const K: usize> std::ops::Sub for Tower4<K> {
1312 type Output = Self;
1313 fn sub(self, o: Self) -> Self {
1314 self + o.scale(-1.0)
1315 }
1316}
1317
1318impl<const K: usize> std::ops::Neg for Tower4<K> {
1319 type Output = Self;
1320 fn neg(self) -> Self {
1321 self.scale(-1.0)
1322 }
1323}
1324
1325impl<const K: usize> std::ops::Mul for Tower4<K> {
1326 type Output = Self;
1327 fn mul(self, o: Self) -> Self {
1328 Tower4::mul(&self, &o)
1329 }
1330}
1331
1332impl<const K: usize> std::ops::Div for Tower4<K> {
1333 type Output = Self;
1334 fn div(self, o: Self) -> Self {
1335 Tower4::mul(&self, &o.recip())
1336 }
1337}
1338
1339impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
1340 type Output = Self;
1341 fn add(self, c: f64) -> Self {
1342 let mut out = self;
1343 out.v += c;
1344 out
1345 }
1346}
1347
1348impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
1349 type Output = Self;
1350 fn sub(self, c: f64) -> Self {
1351 self + (-c)
1352 }
1353}
1354
1355impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
1356 type Output = Self;
1357 fn mul(self, c: f64) -> Self {
1358 self.scale(c)
1359 }
1360}
1361
1362// ── Implicit-function and moving-boundary seams (#932 flex) ──────────
1363//
1364// The flexible survival marginal-slope row loss is NOT a free composition
1365// of the primaries: it threads an IMPLICIT calibration intercept `a(θ)`
1366// solving a constraint `F(a, θ) = 0`, and integrates a density over cells
1367// whose edges `z_L(θ), z_R(θ)` MOVE with θ through that intercept. Plain
1368// `Tower4` Faà di Bruno cannot express either — so the flex tower was the
1369// last hand-written one in the codebase, and the genus of #736-class
1370// drift bugs (the (g,w0) deviation-cross third was 3× short for exactly
1371// this reason). These two combinators close that gap: once the constraint
1372// `F` and the integrand/boundaries are themselves towers, the intercept's
1373// derivative tower and the integral's derivative tower come out EXACTLY at
1374// every order — there is no order left to hand-code and forget.
1375
1376/// Solve the implicit relation `F(a(θ), θ) ≡ 0` for the intercept tower
1377/// `a(θ)` over the `K` primaries θ, given the constraint tower `f` written
1378/// over `K + 1` variables (slot `0` is the intercept `a`, slots `1..=K`
1379/// are the primaries θ) evaluated at the SOLVED point — i.e. `f.v` is the
1380/// constraint residual at `(a₀, θ₀)` (≈ 0 from the production Newton solve)
1381/// and `a0` is that solved intercept value.
1382///
1383/// Returns the `Tower4<K>` whose value is `a0` and whose every derivative
1384/// tensor (∂a/∂θ, ∂²a/∂θ², …, ∂⁴a/∂θ⁴) is the exact implicit-function
1385/// derivative. This is the mechanical replacement for the hand-coded
1386/// `a_u = -f_u/f_a`, `a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a`
1387/// recursion (first_full.rs) and its third/fourth-order continuations.
1388///
1389/// Method: order-by-order substitution. We build `a` incrementally; at each
1390/// order `m` the composite `G(θ) = f(a(θ), θ)` has a top-order coefficient
1391/// that is linear in `a`'s order-`m` tensor with leading factor `F_a`
1392/// (= `f.g[0]`), plus terms in `a`'s lower orders already fixed. Setting the
1393/// order-`m` tensor of `a` to cancel the rest of `G`'s order-`m` coefficient
1394/// keeps `G ≡ 0` through that order. The substitution `G = f∘(a, θ)` reuses
1395/// only the exact [`substitute_intercept`] chain rule, so the recursion is
1396/// auditable and exact, not a hand-expanded formula per order.
1397///
1398/// `f.g[0]` (= ∂F/∂a) must be non-zero — guaranteed by the production
1399/// solve's strict monotonicity guard.
1400///
1401/// The expansion point `a0` must be a genuine root `F(a0, θ0) = 0`: the
1402/// substitution recursion below cancels orders 1..=4 of `G = F∘a` but never
1403/// touches order 0, so a non-root `a0` would yield the Taylor expansion of
1404/// the LEVEL SET `F = F(a0)` through `a0`, not the root curve `F = 0`. This
1405/// is guarded explicitly and re-verified by a composed-residual self-check.
1406pub fn implicit_solve<const K1: usize, const K: usize>(
1407 f: &Tower4<K1>,
1408 a0: f64,
1409) -> Result<Tower4<K>, String> {
1410 assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
1411 let f_a = f.g[0];
1412 if f_a == 0.0 || !f_a.is_finite() {
1413 return Err(format!(
1414 "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
1415 ));
1416 }
1417 // The expansion point must be a genuine root of F. The single Newton
1418 // correction that would move a0 onto the root is |f.v|/|f_a|; require it
1419 // to be negligible relative to the natural scale (1 + |a0|). Guarding the
1420 // Newton step (rather than f.v directly) makes the criterion invariant to
1421 // the magnitude of f_a / the units of F.
1422 let root_tol = 1e-9;
1423 if !f.v.is_finite() {
1424 return Err(format!(
1425 "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
1426 f.v
1427 ));
1428 }
1429 let newton_step = f.v.abs() / f_a.abs();
1430 if newton_step > root_tol * (1.0 + a0.abs()) {
1431 return Err(format!(
1432 "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
1433 F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
1434 root_tol {root_tol:.1e} · (1 + |a0|)",
1435 f.v
1436 ));
1437 }
1438 // Start with a = constant a0 (correct through order 0). Then lift each
1439 // order in turn. Because substitute_intercept reads `a`'s order-≤m
1440 // tensors when forming G's order-m coefficient, and the order-m
1441 // coefficient of G depends on a's order-m tensor ONLY through the linear
1442 // F_a·a_m term, a single corrective pass per order is exact.
1443 let mut a = Tower4::<K>::constant(a0);
1444 for order in 1..=4 {
1445 let g = substitute_intercept(f, &a);
1446 // Cancel G's order-`order` coefficient by adjusting a's order-`order`
1447 // tensor: a_m -= G_m / F_a (the F_a·a_m term is the only one carrying
1448 // a's order-m tensor, with unit chain coefficient since slot 0 seeds a
1449 // as a plain variable in the substitution's first-order part).
1450 match order {
1451 1 => {
1452 for i in 0..K {
1453 a.g[i] -= g.g[i] / f_a;
1454 }
1455 }
1456 2 => {
1457 for i in 0..K {
1458 for j in 0..K {
1459 a.h[i][j] -= g.h[i][j] / f_a;
1460 }
1461 }
1462 }
1463 3 => {
1464 for i in 0..K {
1465 for j in 0..K {
1466 for k in 0..K {
1467 a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
1468 }
1469 }
1470 }
1471 }
1472 _ => {
1473 for i in 0..K {
1474 for j in 0..K {
1475 for k in 0..K {
1476 for l in 0..K {
1477 a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
1478 }
1479 }
1480 }
1481 }
1482 }
1483 }
1484 }
1485 // Self-check: the composed residual G = F∘a must vanish through order 4.
1486 // By construction orders 1..=4 were cancelled; the value G.v == F(a0,θ0)
1487 // is exactly the root requirement guarded above. Re-verify all channels
1488 // against a scale-aware floor so any arithmetic regression in the
1489 // substitution recursion is loud rather than silently shipping a
1490 // level-set expansion.
1491 let g = substitute_intercept(f, &a);
1492 let resid_tol = 1e-7 * (1.0 + f_a.abs());
1493 let mut worst = g.v.abs();
1494 for i in 0..K {
1495 worst = worst.max(g.g[i].abs());
1496 for j in 0..K {
1497 worst = worst.max(g.h[i][j].abs());
1498 for k in 0..K {
1499 worst = worst.max(g.t3[i][j][k].abs());
1500 for l in 0..K {
1501 worst = worst.max(g.t4[i][j][k][l].abs());
1502 }
1503 }
1504 }
1505 }
1506 if !worst.is_finite() || worst > resid_tol {
1507 return Err(format!(
1508 "implicit_solve: composed residual G = F∘a does not vanish: \
1509 worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
1510 ));
1511 }
1512 Ok(a)
1513}
1514
1515/// Substitute the intercept tower `a(θ)` into slot `0` of a constraint
1516/// written over `K + 1` variables, returning the composite tower over the
1517/// `K` primaries θ: `G(θ) = f(a(θ), θ₁, …, θ_K)`.
1518///
1519/// This is the exact multivariate chain rule specialised to "slot 0 is a
1520/// dependent tower, slots 1..=K are the independent primaries". It evaluates
1521/// `f`'s fourth-order multivariate Taylor polynomial about the expansion
1522/// point, with the slot-0 increment being the non-constant part of `a` and
1523/// the slot-(i) increment being the unit-seeded primary `θ_i`. The sum is
1524/// assembled by the same subset/partition algebra `Tower4` arithmetic uses,
1525/// so it carries derivatives exactly through order four.
1526pub fn substitute_intercept<const K1: usize, const K: usize>(
1527 f: &Tower4<K1>,
1528 a: &Tower4<K>,
1529) -> Tower4<K> {
1530 assert_eq!(K1, K + 1);
1531 // Build the K+1 input towers in θ-space: slot 0 = a(θ), slot i+1 = θ_i.
1532 // The composite is Σ over ordered label tuples s (|s| ≤ 4) of input
1533 // indices: (1/|s|!) · f.deriv(s) · Π_{j in s} (inp[s_j] centred) — but
1534 // since f.deriv is the SYMMETRIC partial tensor and we enumerate ordered
1535 // tuples, the 1/|s|! exactly cancels the tuple multiplicity. We assemble
1536 // it directly as a Horner-free explicit sum over the (K+1)-ary tuples,
1537 // using tower products for the increment monomials so all θ-derivatives
1538 // propagate exactly.
1539 let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
1540 if slot == 0 {
1541 // slot 0: a(θ) minus its constant value (the increment δa(θ)).
1542 let mut d = *a;
1543 d.v = 0.0;
1544 d
1545 } else {
1546 // slot i: the increment δθ_{i-1} = seeded variable minus value.
1547 // θ centred at its expansion value has zero constant term and unit
1548 // first derivative in its own slot.
1549 let mut d = Tower4::<K>::zero();
1550 d.g[slot - 1] = 1.0;
1551 d
1552 }
1553 });
1554 // Accumulate the Taylor sum. order-0 term:
1555 let mut out = Tower4::<K>::constant(f.v);
1556 // order 1: Σ_a f.g[a] · inp[a]
1557 for a_idx in 0..K1 {
1558 out = out + inp[a_idx].scale(f.g[a_idx]);
1559 }
1560 // order 2: (1/2) Σ_{a,b} f.h[a][b] · inp[a]·inp[b]
1561 for a_idx in 0..K1 {
1562 for b_idx in 0..K1 {
1563 let prod = inp[a_idx].mul(&inp[b_idx]);
1564 out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
1565 }
1566 }
1567 // order 3: (1/6) Σ f.t3[a][b][c] · inp[a]·inp[b]·inp[c]
1568 for a_idx in 0..K1 {
1569 for b_idx in 0..K1 {
1570 for c_idx in 0..K1 {
1571 let prod = inp[a_idx].mul(&inp[b_idx]).mul(&inp[c_idx]);
1572 out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
1573 }
1574 }
1575 }
1576 // order 4: (1/24) Σ f.t4[a][b][c][d] · inp[a]·inp[b]·inp[c]·inp[d]
1577 for a_idx in 0..K1 {
1578 for b_idx in 0..K1 {
1579 for c_idx in 0..K1 {
1580 for d_idx in 0..K1 {
1581 let prod = inp[a_idx]
1582 .mul(&inp[b_idx])
1583 .mul(&inp[c_idx])
1584 .mul(&inp[d_idx]);
1585 out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
1586 }
1587 }
1588 }
1589 }
1590 out
1591}
1592
1593/// The exact θ-derivative tower of a moving-LIMIT integral's BOUNDARY
1594/// contribution: given the edge-position tower `z_edge(θ)` over the `K`
1595/// primaries and the integrand `B` evaluated-and-differentiated at the edge
1596/// value as the stack `b_stack = [B(z₀), B′(z₀), B″(z₀), B‴(z₀)]`
1597/// (`z₀ = z_edge.v`), returns the tower of `Φ(z_edge(θ))` where `Φ′ = B`.
1598///
1599/// Rationale: `∂_θ ∫^{z_edge(θ)} B(z) dz = Φ(z_edge(θ))` with `Φ` an
1600/// antiderivative of `B`, so the boundary part of every θ-derivative of the
1601/// integral is just the composition `Φ ∘ z_edge` — whose Faà di Bruno
1602/// expansion carries, at one stroke, EVERY Leibniz boundary term the
1603/// hand-written flux dropped: the first-order `B·z_u`, the second-order
1604/// `B′·z_u·z_v + B·z_uv` (the `G_z·z_u·z_v` self-flux AND the previously
1605/// dropped `G·z_uv`), and the full third/fourth-order continuations. The
1606/// VALUE channel of the returned tower is meaningless (`Φ` is only defined up
1607/// to a constant); callers read only the derivative channels and pair this
1608/// with the interior moment-integral value separately.
1609///
1610/// `b_stack` holds `B` and its first three z-derivatives; the antiderivative
1611/// `Φ` contributes only as the order-≥1 channels, so `compose_unary` receives
1612/// `[0, B, B′, B″, B‴]` — the leading `0` is the discarded `Φ(z₀)` slot.
1613pub fn moving_limit_boundary_tower<const K: usize>(
1614 z_edge: &Tower4<K>,
1615 b_stack: [f64; 4],
1616) -> Tower4<K> {
1617 z_edge.compose_unary([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]])
1618}
1619
1620/// The boundary-flux derivative tower of a single moving cell integral
1621/// `∫_{z_L(θ)}^{z_R(θ)} B dz`: `Φ(z_R(θ)) − Φ(z_L(θ))`, assembled from the
1622/// two edge towers and the integrand stacks at each edge. The returned
1623/// tower's derivative channels are the EXACT moving-boundary contribution to
1624/// every θ-derivative of the cell integral, to fourth order, with no term
1625/// hand-omitted. A `Fixed` (non-moving) edge passes a `z_edge` whose
1626/// derivative channels are all zero, contributing nothing — matching the
1627/// production `edge_vel = 0` short-circuit.
1628pub fn cell_moving_boundary_flux_tower<const K: usize>(
1629 z_right: &Tower4<K>,
1630 b_stack_right: [f64; 4],
1631 z_left: &Tower4<K>,
1632 b_stack_left: [f64; 4],
1633) -> Tower4<K> {
1634 moving_limit_boundary_tower(z_right, b_stack_right)
1635 - moving_limit_boundary_tower(z_left, b_stack_left)
1636}
1637
1638/// Moving-limit boundary tower for a θ-DEPENDENT integrand `G(z; θ)`.
1639///
1640/// [`moving_limit_boundary_tower`] assumes the integrand depends on θ only
1641/// through the moving edge `z_edge(θ)` (a fixed z-derivative `b_stack`). The
1642/// marginal-slope flex boundary is richer: the integrand `G(z; θ)` ALSO carries
1643/// its own θ-dependence (the density weight `w = e^{−q}/2π` and the cell
1644/// integrand coefficients move with η, hence with the primaries), so the
1645/// Leibniz expansion of `∂ⁿ_θ ∫^{z_edge(θ)} G(z;θ) dz` mixes edge-motion
1646/// derivatives of the limit with θ-derivatives of `G` itself — e.g. at second
1647/// order `G·z_uv + G_z·z_u·z_v + G_{θu}·z_v + G_{θv}·z_u` (the four
1648/// edge-motion-carrying terms the hand path assembles one by one, including the
1649/// `G·z_uv` term the directional path drops).
1650///
1651/// Mechanization: let `Φ(z; θ)` be the z-antiderivative of `G` (so `Φ_z = G`).
1652/// The full upper-limit contribution is `Φ(z_edge(θ); θ)`, and the BOUNDARY
1653/// part — everything carrying edge motion — is exactly
1654/// `Φ(z_edge(θ); θ) − Φ(z₀; θ)`,
1655/// the second term being the pure-integrand-θ part (`∫^{z₀} ∂ⁿ_θ G`) the
1656/// interior moment integral already supplies. Both are one
1657/// [`substitute_intercept`] of the SAME mixed `(z, θ)` jet of `Φ` (z in slot 0,
1658/// θ in slots 1..K): substituting the edge tower gives the full composite,
1659/// substituting a frozen constant edge isolates the pure-θ part, and their
1660/// difference is the exact boundary flux — every Leibniz term derived by the
1661/// substitution algebra, none hand-omitted.
1662///
1663/// `phi_jet` is the `(K+1)`-variable Taylor jet of `Φ` about `(z₀, θ₀)` with
1664/// `z₀ = z_edge.v`: slot 0 is the z-direction (so `phi_jet.g[0] = G(z₀;θ₀)`,
1665/// `phi_jet.h[0][0] = G_z`, …) and slots `1..=K` are the primaries θ (carrying
1666/// `Φ`'s own θ- and mixed z·θ-derivatives — i.e. the integrand's θ-derivatives
1667/// integrated in z, and `G_{θ…}` in the mixed slots). The returned tower's
1668/// VALUE channel is 0 by construction (the `Φ(z₀;θ₀)` constants cancel); only
1669/// the derivative channels are meaningful, matching the value-less convention of
1670/// [`moving_limit_boundary_tower`].
1671pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
1672 phi_jet: &Tower4<K1>,
1673 z_edge: &Tower4<K>,
1674) -> Tower4<K> {
1675 assert_eq!(
1676 K1,
1677 K + 1,
1678 "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
1679 );
1680 let frozen_edge = Tower4::<K>::constant(z_edge.v);
1681 let full = substitute_intercept(phi_jet, z_edge);
1682 let interior = substitute_intercept(phi_jet, &frozen_edge);
1683 full - interior
1684}
1685
1686/// Two-edge cell version of [`moving_limit_boundary_tower_theta_integrand`]:
1687/// the exact boundary-flux tower of `∫_{z_L(θ)}^{z_R(θ)} G(z;θ) dz` with a
1688/// θ-dependent integrand, `Φ(z_R;θ) − Φ(z_L;θ)` minus the pure-θ parts at each
1689/// frozen edge. A `Fixed` edge passes a `z_edge` with zero derivative channels,
1690/// so its `full` and `interior` substitutions coincide and it contributes
1691/// nothing — matching the production `edge_vel = 0` short-circuit.
1692pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
1693 phi_jet_right: &Tower4<K1>,
1694 z_right: &Tower4<K>,
1695 phi_jet_left: &Tower4<K1>,
1696 z_left: &Tower4<K>,
1697) -> Tower4<K> {
1698 moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right)
1699 - moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left)
1700}
1701
1702// ── The program seam ─────────────────────────────────────────────────
1703
/// A family's row negative log-likelihood written ONCE over tower scalars.
///
/// This is the single source of truth #932 asks for: the value channel of
/// the returned tower must BE the production row NLL (same branches, same
/// guards, same numerics), and every derivative channel is then exact by
/// construction. The linear Jacobian wiring (coefficients ↔ primaries) is
/// NOT part of this trait — it is family data, not calculus, and stays on
/// the `RowKernel` implementor.
///
/// [`evaluate_program`] is the driver: it calls [`Self::primaries`], seeds one
/// `Tower4::variable` per primary, and passes the seeded array to
/// [`Self::row_nll`].
pub trait RowNllProgram<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed the tower).
    ///
    /// An `Err` is propagated unchanged by [`evaluate_program`].
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on tower scalars. `p[a]` arrives pre-seeded as
    /// variable `a` at the current primary value; implementations combine
    /// them with `Tower4` arithmetic and per-row data (response, censoring
    /// indicators, offsets) entering as constants.
    fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
}
1725
1726/// Evaluate a program's full tower at the current primaries for one row.
1727///
1728/// One call yields every `RowKernel` calculus channel; callers that need
1729/// several contractions of the same row should hold the returned tower and
1730/// contract repeatedly rather than re-evaluating.
1731pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
1732 prog: &P,
1733 row: usize,
1734) -> Result<Tower4<K>, String> {
1735 let p = prog.primaries(row)?;
1736 let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
1737 prog.row_nll(row, &vars)
1738}
1739
1740/// Mechanically derived `row_kernel` channel: `(nll, ∇, H)`.
1741pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
1742 prog: &P,
1743 row: usize,
1744) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1745 let t = evaluate_program(prog, row)?;
1746 Ok((t.v, t.g, t.h))
1747}
1748
1749/// Mechanically derived `row_third_contracted` channel.
1750pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1751 prog: &P,
1752 row: usize,
1753 dir: &[f64; K],
1754) -> Result<[[f64; K]; K], String> {
1755 Ok(evaluate_program(prog, row)?.third_contracted(dir))
1756}
1757
1758/// Mechanically derived `row_fourth_contracted` channel.
1759pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1760 prog: &P,
1761 row: usize,
1762 dir_u: &[f64; K],
1763 dir_v: &[f64; K],
1764) -> Result<[[f64; K]; K], String> {
1765 Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
1766}
1767
1768// ── The generic program seam (#932 scalar cutover) ───────────────────
1769
1770/// A family's row negative log-likelihood written ONCE over the generic
1771/// [`crate::jet_scalar::JetScalar`] interface, so the SAME expression can be
1772/// re-instantiated at whatever order / representation a consumer needs
1773/// ([`crate::jet_scalar::Order2`] for `(v, g, H)`,
1774/// [`crate::jet_scalar::OneSeed`] for the contracted third,
1775/// [`crate::jet_scalar::TwoSeed`] for the contracted fourth, or the full
1776/// [`Tower4`] for every channel at once).
1777///
1778/// This is additive to [`RowNllProgram`] (which is `Tower4`-specialised): a
1779/// program implementing this generic trait gets the small contracted scalars for
1780/// free, dissolving the dense-`Tower4<9>` cost objection in the location-scale
1781/// gates (doc §A.4). An existing `Tower4`-only [`RowNllProgram`] continues to
1782/// work unchanged; new families should prefer this generic trait.
1783///
1784/// Because a `Tower4`-specialised `row_nll` body uses only
1785/// `add`/`sub`/`mul`/`scale`/`exp`/`ln`/… — all of which this trait also
1786/// provides — the same body is expressible directly over `S: JetScalar<K>`.
1787/// A program written that way needs no `Tower4`-specialised method and routes
1788/// the directional and joint-Hessian gates through the contracted scalars from
1789/// a single definition.
1790pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
1791 /// Number of observations the program covers.
1792 fn n_rows(&self) -> usize;
1793
1794 /// Current primary-scalar values for `row` (where to seed the scalar).
1795 fn primaries(&self, row: usize) -> Result<[f64; K], String>;
1796
1797 /// The row NLL evaluated on a generic jet scalar. `p[a]` arrives pre-seeded
1798 /// (base value + per-scalar nilpotent directions) by the caller; the body
1799 /// uses ONLY [`crate::jet_scalar::JetScalar`] ops and per-row data
1800 /// (response, censoring, offsets) entering as constants.
1801 fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
1802 &self,
1803 row: usize,
1804 p: &[S; K],
1805 ) -> Result<S, String>;
1806}
1807
1808/// Evaluate a generic program at the value/gradient/Hessian scalar
1809/// [`crate::jet_scalar::Order2`], returning `(nll, ∇, H)` — the
1810/// `row_kernel` channel — WITHOUT materialising any third / fourth tensor.
1811///
1812/// This is the production seam for the inner-Newton `(v, g, H)` path: the row
1813/// loss is written ONCE in `row_nll_generic`, and this routes it through the
1814/// cheap order-2 scalar. The single source of truth means the gradient and
1815/// Hessian cannot desync from the value (the #736 / #948 bug genus).
1816pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1817 prog: &P,
1818 row: usize,
1819) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1820 let base = prog.primaries(row)?;
1821 let vars: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|a| {
1822 <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(base[a], a)
1823 });
1824 let s = prog.row_nll_generic(row, &vars)?;
1825 Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
1826}
1827
1828/// Evaluate a generic program at the one-seed scalar
1829/// [`crate::jet_scalar::OneSeed`], returning the contracted third
1830/// `Σ_c ℓ_{abc} dir_c` — the `row_third_contracted(dir)` channel — WITHOUT
1831/// materialising the dense `t3` tensor. The contraction direction is folded
1832/// INTO the differentiation by the nilpotent ε seeded with `dir`.
1833pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1834 prog: &P,
1835 row: usize,
1836 dir: &[f64; K],
1837) -> Result<[[f64; K]; K], String> {
1838 let base = prog.primaries(row)?;
1839 let vars: [crate::jet_scalar::OneSeed<K>; K] =
1840 std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
1841 let s = prog.row_nll_generic(row, &vars)?;
1842 Ok(s.contracted_third())
1843}
1844
1845/// Evaluate a generic program at the two-seed scalar
1846/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
1847/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `row_fourth_contracted(u, v)` channel —
1848/// WITHOUT materialising the dense `t4` tensor.
1849pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1850 prog: &P,
1851 row: usize,
1852 dir_u: &[f64; K],
1853 dir_v: &[f64; K],
1854) -> Result<[[f64; K]; K], String> {
1855 let base = prog.primaries(row)?;
1856 let vars: [crate::jet_scalar::TwoSeed<K>; K] =
1857 std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
1858 let s = prog.row_nll_generic(row, &vars)?;
1859 Ok(s.contracted_fourth())
1860}
1861
1862/// Evaluate a generic program at the full dense [`Tower4`] scalar, returning
1863/// every channel `(v, g, h, t3, t4)` in one pass. Used where the UNCONTRACTED
1864/// third / fourth tensors are needed (the BMS rigid `third_full` / `fourth_full`
1865/// caches): the dense tensors come from the SAME `row_nll_generic` expression
1866/// the order-2 / contracted scalars consume, so there is a single source of
1867/// truth across every channel.
1868pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1869 prog: &P,
1870 row: usize,
1871) -> Result<Tower4<K>, String> {
1872 let base = prog.primaries(row)?;
1873 let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(base[a], a));
1874 prog.row_nll_generic(row, &vars)
1875}
1876
1877// ── The RowJet bridge: one row-NLL body over scalar jets AND lane towers ─
1878//
1879// `JetScalar<K>` (jet_scalar.rs) abstracts the SCALAR jets — its `value()`
1880// returns one `f64`, so the `f64x4` lane towers ([`Tower3Lane`] / [`Tower4Lane`])
1881// CANNOT implement it (their value channel is four rows). `compose_unary_with`
1882// exists as an inherent method on BOTH the scalar towers and the lane towers, but
1883// as separate inherent methods, not a shared trait bound — so a row-NLL body
1884// written `<S: JetScalar<K>>` could not be instantiated at `Tower4Lane`, and the
1885// 4-rows-per-pass SIMD batch path could not reuse the single source.
1886//
1887// [`RowJet<K>`] is that shared bound. It exposes exactly the ops a row-NLL body
1888// needs — `constant` / `variable` / `add` / `sub` / `mul` / `scale` / `neg`, the
1889// value-derived `compose_unary_with`, and a per-lane domain `guard` — over BOTH
1890// representations. A blanket impl makes every scalar `JetScalar<K>` a `RowJet<K>`
1891// (so the scalar call sites compile unchanged and bit-identically), and explicit
1892// impls route the `f64x4` lane towers through their existing per-lane methods. A
1893// body written once over `R: RowJet<K>` then instantiates at a scalar jet for the
1894// `(v, g, H)` / contracted-tensor channels AND at a lane tower for the batch.
1895
/// The verdict of a per-lane [`RowJet::guard`] domain check.
///
/// A scalar jet (a [`crate::jet_scalar::JetScalar`] via the blanket impl) carries
/// ONE value, so it reports `lanes == 1` and a one-bit mask. A lane tower
/// ([`Tower3Lane`] / [`Tower4Lane`] over `f64x4`) carries FOUR rows, so it reports
/// `lanes == 4` and one mask bit per lane. The mask lets a batched program bail
/// exactly the offending 4-group to the scalar tail ([`any_failed`](Self::any_failed)),
/// or inspect which lanes tripped ([`lane_failed`](Self::lane_failed)).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GuardVerdict {
    /// Number of lanes inspected: 1 from [`scalar`](Self::scalar), 4 from
    /// [`lanes4`](Self::lanes4).
    lanes: u8,
    /// Bit `i` set ⇒ lane `i` failed. Both constructors keep every bit at or
    /// above `lanes` clear.
    failed_mask: u8,
}

impl GuardVerdict {
    /// A scalar (1-lane) verdict: `pass == true` ⇒ no failure.
    #[inline]
    pub fn scalar(pass: bool) -> Self {
        Self {
            lanes: 1,
            failed_mask: u8::from(!pass),
        }
    }

    /// A 4-lane verdict from a per-lane failure mask (bit `i` ⇒ lane `i` failed).
    /// Bits above lane 3 are discarded.
    #[inline]
    pub fn lanes4(failed_mask: u8) -> Self {
        Self {
            lanes: 4,
            failed_mask: failed_mask & 0x0f,
        }
    }

    /// Number of active lanes inspected (1 scalar, 4 batch).
    #[inline]
    pub fn lanes(self) -> usize {
        usize::from(self.lanes)
    }

    /// True iff every inspected lane satisfied the predicate.
    #[inline]
    pub fn all_pass(self) -> bool {
        self.failed_mask == 0
    }

    /// True iff at least one inspected lane failed the predicate.
    #[inline]
    pub fn any_failed(self) -> bool {
        self.failed_mask != 0
    }

    /// True iff lane `i` failed the predicate.
    ///
    /// A lane index at or beyond [`lanes`](Self::lanes) was never inspected and
    /// therefore reports `false`. The range check runs BEFORE the shift: an
    /// unchecked `u8 >> i` with `i >= 8` panics when overflow checks are on and
    /// otherwise wraps the shift amount, which would report lane `i % 8`.
    #[inline]
    pub fn lane_failed(self, i: usize) -> bool {
        i < self.lanes() && (self.failed_mask >> i) & 1 == 1
    }

    /// The raw failure mask (bit `i` ⇒ lane `i` failed).
    #[inline]
    pub fn failed_mask(self) -> u8 {
        self.failed_mask
    }
}
1947
/// Convert a derivative stack of length `N` into one of length `M`.
///
/// Entry `i` of the result is `s[i]` whenever `i < N`, and `0.0` otherwise.
/// Consequently `M ≥ N` zero-pads the unseeded high derivatives, `M < N`
/// drops the unused tail, and `N == M` is a verbatim copy (so the common
/// `N == 5` case is bit-identical to passing the stack straight through).
/// Both directions are total, so an order-`(M−1)` tower reads exactly the
/// channels it needs and never an uninitialised entry.
///
/// The [`RowJet::compose_unary_with`] impls use this to bridge a program's
/// chosen stack length to each tower's native compose width
/// ([`Tower4Lane`]: 5, [`Tower3Lane`]: 4).
#[inline]
fn resize_stack<const N: usize, const M: usize>(s: [f64; N]) -> [f64; M] {
    std::array::from_fn(|i| if i < N { s[i] } else { 0.0_f64 })
}
1962
1963/// The shared row-NLL algebra over BOTH the scalar jets and the `f64x4` lane
1964/// towers — the bound that lets ONE single-source row-NLL body SIMD-batch 4
1965/// rows/pass without a dual-source copy (module §"The RowJet bridge").
1966///
1967/// Every scalar [`crate::jet_scalar::JetScalar<K>`] is a `RowJet<K>` via the
1968/// blanket impl below (`Value = f64`), bit-identically to its `JetScalar`
1969/// methods; [`Tower3Lane`] / [`Tower4Lane`] over `f64x4` are `RowJet<K>` with
1970/// `Value = [f64; 4]`, routing through their per-lane methods so lane `i` of a
1971/// batched evaluation is `to_bits`-identical to the scalar evaluation on row `i`.
1972pub trait RowJet<const K: usize>: Copy {
1973 /// The value channel(s) seen by [`guard`](Self::guard) and
1974 /// [`values`](Self::values): a single `f64` on a scalar jet, `[f64; 4]` on an
1975 /// `f64x4` lane tower.
1976 type Value: Copy;
1977
1978 /// A constant (value `c`, all derivatives zero), broadcast to every lane.
1979 fn constant(c: f64) -> Self;
1980 /// The seeded primary `slot` at value `x` (unit first derivative in `slot`),
1981 /// broadcast to every lane. Per-lane-DISTINCT seeding for the batch path is
1982 /// done by the lane instantiators ([`generic_batched_fourth_tower`] /
1983 /// [`generic_batched_third_tower`]), which build the tower variables directly
1984 /// from each row's primaries; this method is for any row-invariant auxiliary
1985 /// variable a body introduces.
1986 fn variable(x: f64, slot: usize) -> Self;
1987 /// The value channel(s): `f64` (scalar) or `[f64; 4]` (lane).
1988 fn values(&self) -> Self::Value;
1989
1990 /// Truncated Leibniz `self + o`.
1991 fn add(&self, o: &Self) -> Self;
1992 /// Truncated Leibniz `self − o`.
1993 fn sub(&self, o: &Self) -> Self;
1994 /// Truncated Leibniz `self · o`.
1995 fn mul(&self, o: &Self) -> Self;
1996 /// Multiply every channel by the plain scalar `s`.
1997 fn scale(&self, s: f64) -> Self;
1998 /// Negate every channel. Defaults to `scale(-1.0)`; the blanket overrides it
1999 /// to delegate to [`crate::jet_scalar::JetScalar::neg`].
2000 fn neg(&self) -> Self {
2001 self.scale(-1.0)
2002 }
2003
2004 /// Faà di Bruno compose with a unary special function whose `[f64; N]`
2005 /// derivative stack is built from the running base value PER LANE through
2006 /// `stack_fn`. This is the SHARED-TRAIT version of the `compose_unary_with`
2007 /// inherent method that already exists on both the scalar towers and the lane
2008 /// towers: on a scalar jet `stack_fn` is run once at the value; on an `f64x4`
2009 /// lane tower it is re-run per lane (the four rows carry four distinct base
2010 /// values), so lane `i` is `to_bits`-identical to the scalar result on row `i`.
2011 /// Making it a trait method is precisely what lets a body written once over
2012 /// `R: RowJet<K>` instantiate at the batch towers. `N` is widened/narrowed to
2013 /// the tower's native width by [`resize_stack`] (`N == 5` is a verbatim copy).
2014 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self;
2015
2016 /// Per-lane domain guard: evaluate `pred` on each active lane's value channel
2017 /// and report which lanes failed (see [`GuardVerdict`]). A scalar jet checks
2018 /// its one value; a lane tower checks all four. Lets a batched program detect
2019 /// an out-of-domain row in a 4-group and bail that group to the scalar tail.
2020 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict;
2021
2022 /// Per-lane scale: multiply every channel by the per-lane factor `s`
2023 /// ([`Self::Value`]). On a scalar jet `Self::Value = f64`, so this is exactly
2024 /// [`scale`](Self::scale) and the scalar call sites stay BIT-IDENTICAL when
2025 /// `.scale(x)` is rewritten to `.scale_rows(x)`; on an `f64x4` lane tower
2026 /// `Self::Value = [f64; 4]` and lane `i` is multiplied by `s[i]`. This is the
2027 /// primitive that lets a batched body carry CONTINUOUS per-row data — the
2028 /// survival `covariance_ones` / `z_sum` / observation-weight `wi` factors that
2029 /// enter the jet algebra as `.scale(per_row_value)` and that the single-`f64`
2030 /// [`scale`](Self::scale) would broadcast wrongly across the four rows. Build
2031 /// `s` from the lane→row map with [`pack_rows`](Self::pack_rows).
2032 fn scale_rows(&self, s: Self::Value) -> Self;
2033
2034 /// Gather a per-lane auxiliary datum from the lane→row map `rows`: `value_of(r)`
2035 /// is evaluated for each active lane's row and packed into [`Self::Value`] (a
2036 /// single `f64` on a scalar jet, `[f64; 4]` on an `f64x4` lane tower). This is
2037 /// how a body written once over [`RowJet`] feeds per-row CONTINUOUS data (the
2038 /// arguments to [`scale_rows`](Self::scale_rows)) into the batch path without
2039 /// knowing the concrete representation: the program holds the per-row data and
2040 /// the caller threads `rows` (length 1 scalar, length 4 batch) into
2041 /// [`RowNllProgramRowJet::row_nll`], so the body writes
2042 /// `x.scale_rows(R::pack_rows(rows, |r| self.cov(r)))`. A multiplicative weight
2043 /// buried in a `compose_unary_with` stack is pulled out the same way:
2044 /// `x.compose_unary_with(|u| stack(u, 1.0)).scale_rows(R::pack_rows(rows, |r| self.wi(r)))`.
2045 /// (Binary per-row branches such as the event indicator `di` are kept
2046 /// lane-uniform by grouping and the [`guard`](Self::guard) bail, not packed.)
2047 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> Self::Value;
2048
2049 // ── value-derived transcendental conveniences ───────────────────────
2050 // Each routes through `compose_unary_with` with the SAME derivative stack the
2051 // corresponding `JetScalar` method uses, so on a scalar jet (blanket) the
2052 // result is bit-identical to the `JetScalar` method, and on a lane tower lane
2053 // `i` is bit-identical to the scalar result on row `i`.
2054
2055 /// `e^self`.
2056 fn exp(&self) -> Self {
2057 self.compose_unary_with(|u| {
2058 let e = u.exp();
2059 [e, e, e, e, e]
2060 })
2061 }
2062 /// `ln(self)`. Caller guarantees positivity.
2063 fn ln(&self) -> Self {
2064 self.compose_unary_with(|u| {
2065 let r = 1.0 / u;
2066 [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
2067 })
2068 }
2069 /// `√self`. Caller guarantees positivity.
2070 fn sqrt(&self) -> Self {
2071 self.compose_unary_with(|u| {
2072 let s = u.sqrt();
2073 [s, 0.5 / s, -0.25 / (u * s), 0.375 / (u * u * s), -0.9375 / (u * u * u * s)]
2074 })
2075 }
2076 /// `1/self`.
2077 fn recip(&self) -> Self {
2078 self.compose_unary_with(|u| {
2079 let r = 1.0 / u;
2080 let r2 = r * r;
2081 [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
2082 })
2083 }
2084 /// `self^a` for real `a`. Caller guarantees a positive base.
2085 fn powf(&self, a: f64) -> Self {
2086 self.compose_unary_with(move |u| {
2087 [
2088 u.powf(a),
2089 a * u.powf(a - 1.0),
2090 a * (a - 1.0) * u.powf(a - 2.0),
2091 a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
2092 a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
2093 ]
2094 })
2095 }
2096 /// `ln Γ(self)`. Caller guarantees a positive argument.
2097 fn ln_gamma(&self) -> Self {
2098 self.compose_unary_with(ln_gamma_derivative_stack)
2099 }
2100 /// `ψ(self)` (digamma). Caller guarantees a positive argument.
2101 fn digamma(&self) -> Self {
2102 self.compose_unary_with(digamma_derivative_stack)
2103 }
2104}
2105
2106/// Blanket: every scalar [`crate::jet_scalar::JetScalar<K>`] is a [`RowJet<K>`]
2107/// with `Value = f64`. Each op delegates to the identical `JetScalar` method, so
2108/// the existing scalar call sites compile UNCHANGED and bit-identically — the
2109/// bridge adds the lane representation without churning the scalar path. (The
2110/// concrete lane impls below cannot overlap this: [`Tower3Lane`] / [`Tower4Lane`]
2111/// are local types that do not implement `JetScalar`, and the orphan rule forbids
2112/// any downstream impl, so the coherence checker proves the impls disjoint.)
2113impl<const K: usize, S: crate::jet_scalar::JetScalar<K>> RowJet<K> for S {
2114 type Value = f64;
2115 #[inline]
2116 fn constant(c: f64) -> Self {
2117 <S as crate::jet_scalar::JetScalar<K>>::constant(c)
2118 }
2119 #[inline]
2120 fn variable(x: f64, slot: usize) -> Self {
2121 <S as crate::jet_scalar::JetScalar<K>>::variable(x, slot)
2122 }
2123 #[inline]
2124 fn values(&self) -> f64 {
2125 crate::jet_scalar::JetScalar::value(self)
2126 }
2127 #[inline]
2128 fn add(&self, o: &Self) -> Self {
2129 crate::jet_scalar::JetScalar::add(self, o)
2130 }
2131 #[inline]
2132 fn sub(&self, o: &Self) -> Self {
2133 crate::jet_scalar::JetScalar::sub(self, o)
2134 }
2135 #[inline]
2136 fn mul(&self, o: &Self) -> Self {
2137 crate::jet_scalar::JetScalar::mul(self, o)
2138 }
2139 #[inline]
2140 fn scale(&self, s: f64) -> Self {
2141 crate::jet_scalar::JetScalar::scale(self, s)
2142 }
2143 #[inline]
2144 fn neg(&self) -> Self {
2145 crate::jet_scalar::JetScalar::neg(self)
2146 }
2147 #[inline]
2148 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2149 crate::jet_scalar::JetScalar::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2150 }
2151 #[inline]
2152 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2153 GuardVerdict::scalar(pred(crate::jet_scalar::JetScalar::value(self)))
2154 }
2155 #[inline]
2156 fn scale_rows(&self, s: f64) -> Self {
2157 // `Value == f64`, so per-lane scale is exactly `scale` — the rewrite
2158 // `.scale(x)` → `.scale_rows(x)` is bit-identical on the scalar path.
2159 crate::jet_scalar::JetScalar::scale(self, s)
2160 }
2161 #[inline]
2162 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> f64 {
2163 value_of(rows[0])
2164 }
2165}
2166
2167/// The `f64x4` lane [`Tower4Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2168/// routing each op through its existing per-lane method. Lane `i` of a batched
2169/// evaluation is `to_bits`-identical to the scalar [`Tower4`] evaluation on row
2170/// `i` (the per-lane methods are term-for-term lifts of the scalar tower).
2171impl<const K: usize> RowJet<K> for Tower4Lane<wide::f64x4, K> {
2172 type Value = [f64; 4];
2173 #[inline]
2174 fn constant(c: f64) -> Self {
2175 Tower4Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2176 }
2177 #[inline]
2178 fn variable(x: f64, slot: usize) -> Self {
2179 Tower4Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2180 }
2181 #[inline]
2182 fn values(&self) -> [f64; 4] {
2183 self.v.to_array()
2184 }
2185 #[inline]
2186 fn add(&self, o: &Self) -> Self {
2187 Tower4Lane::add(self, o)
2188 }
2189 #[inline]
2190 fn sub(&self, o: &Self) -> Self {
2191 Tower4Lane::sub(self, o)
2192 }
2193 #[inline]
2194 fn mul(&self, o: &Self) -> Self {
2195 Tower4Lane::mul(self, o)
2196 }
2197 #[inline]
2198 fn scale(&self, s: f64) -> Self {
2199 Tower4Lane::scale(self, s)
2200 }
2201 #[inline]
2202 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2203 Tower4Lane::compose_unary_with(self, |u| resize_stack::<N, 5>(stack_fn(u)))
2204 }
2205 #[inline]
2206 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2207 let vals = self.v.to_array();
2208 let mut mask = 0u8;
2209 for (i, &v) in vals.iter().enumerate() {
2210 if !pred(v) {
2211 mask |= 1 << i;
2212 }
2213 }
2214 GuardVerdict::lanes4(mask)
2215 }
2216 #[inline]
2217 fn scale_rows(&self, s: [f64; 4]) -> Self {
2218 // True per-lane scale: lane `i` of every channel is multiplied by `s[i]`,
2219 // so lane `i` matches the scalar `Tower4::scale(s[i])` on row `i`.
2220 let sl = wide::f64x4::new(s);
2221 let mut out = *self;
2222 out.v = self.v * sl;
2223 for i in 0..K {
2224 out.g[i] = self.g[i] * sl;
2225 for j in 0..K {
2226 out.h[i][j] = self.h[i][j] * sl;
2227 for k in 0..K {
2228 out.t3[i][j][k] = self.t3[i][j][k] * sl;
2229 for l in 0..K {
2230 out.t4[i][j][k][l] = self.t4[i][j][k][l] * sl;
2231 }
2232 }
2233 }
2234 }
2235 out
2236 }
2237 #[inline]
2238 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2239 [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2240 }
2241}
2242
2243/// The `f64x4` lane [`Tower3Lane`] is a [`RowJet<K>`] with `Value = [f64; 4]`,
2244/// the order-≤3 sibling of the [`Tower4Lane`] impl. A body that uses `N == 5`
2245/// stacks drops the (unused) fourth-derivative entry here, matching the scalar
2246/// [`Tower3`] which also carries only up to the third tensor.
2247impl<const K: usize> RowJet<K> for Tower3Lane<wide::f64x4, K> {
2248 type Value = [f64; 4];
2249 #[inline]
2250 fn constant(c: f64) -> Self {
2251 Tower3Lane::constant(<wide::f64x4 as crate::jet_scalar::Lane>::splat(c))
2252 }
2253 #[inline]
2254 fn variable(x: f64, slot: usize) -> Self {
2255 Tower3Lane::variable(<wide::f64x4 as crate::jet_scalar::Lane>::splat(x), slot)
2256 }
2257 #[inline]
2258 fn values(&self) -> [f64; 4] {
2259 self.v.to_array()
2260 }
2261 #[inline]
2262 fn add(&self, o: &Self) -> Self {
2263 Tower3Lane::add(self, o)
2264 }
2265 #[inline]
2266 fn sub(&self, o: &Self) -> Self {
2267 Tower3Lane::sub(self, o)
2268 }
2269 #[inline]
2270 fn mul(&self, o: &Self) -> Self {
2271 Tower3Lane::mul(self, o)
2272 }
2273 #[inline]
2274 fn scale(&self, s: f64) -> Self {
2275 Tower3Lane::scale(self, s)
2276 }
2277 #[inline]
2278 fn compose_unary_with<const N: usize>(&self, stack_fn: impl Fn(f64) -> [f64; N]) -> Self {
2279 Tower3Lane::compose_unary_with(self, |u| resize_stack::<N, 4>(stack_fn(u)))
2280 }
2281 #[inline]
2282 fn guard(&self, pred: impl Fn(f64) -> bool) -> GuardVerdict {
2283 let vals = self.v.to_array();
2284 let mut mask = 0u8;
2285 for (i, &v) in vals.iter().enumerate() {
2286 if !pred(v) {
2287 mask |= 1 << i;
2288 }
2289 }
2290 GuardVerdict::lanes4(mask)
2291 }
2292 #[inline]
2293 fn scale_rows(&self, s: [f64; 4]) -> Self {
2294 let sl = wide::f64x4::new(s);
2295 let mut out = *self;
2296 out.v = self.v * sl;
2297 for i in 0..K {
2298 out.g[i] = self.g[i] * sl;
2299 for j in 0..K {
2300 out.h[i][j] = self.h[i][j] * sl;
2301 for k in 0..K {
2302 out.t3[i][j][k] = self.t3[i][j][k] * sl;
2303 }
2304 }
2305 }
2306 out
2307 }
2308 #[inline]
2309 fn pack_rows(rows: &[usize], value_of: impl Fn(usize) -> f64) -> [f64; 4] {
2310 [value_of(rows[0]), value_of(rows[1]), value_of(rows[2]), value_of(rows[3])]
2311 }
2312}
2313
2314/// A family's row negative log-likelihood written ONCE over the [`RowJet`]
2315/// bridge, so the SAME body instantiates at the scalar jets (for the `(v, g, H)`
2316/// and contracted-tensor channels) AND at the `f64x4` lane towers (for the
2317/// 4-rows-per-pass SIMD batch). This is the lane-capable successor to
2318/// [`RowNllProgramGeneric`]: a body written here gets the scalar channels through
2319/// [`rowjet_row_kernel`] / [`rowjet_third_contracted`] / [`rowjet_fourth_contracted`]
2320/// and the batched channels through [`generic_batched_fourth_tower`] /
2321/// [`generic_batched_third_tower`], all from a single source.
2322pub trait RowNllProgramRowJet<const K: usize>: Send + Sync {
2323 /// Number of observations the program covers.
2324 fn n_rows(&self) -> usize;
2325
2326 /// Current primary-scalar values for `row` (where to seed each lane).
2327 fn primaries(&self, row: usize) -> Result<[f64; K], String>;
2328
2329 /// The row NLL evaluated on the [`RowJet`] bridge. `rows` is the lane→row map
2330 /// (length 1 for a scalar instantiation, length 4 for a batch); `p[a]` arrives
2331 /// pre-seeded by the caller (base value plus, for the directional scalars, the
2332 /// nilpotent contraction directions). The body uses ONLY [`RowJet`] ops and
2333 /// per-row data entering through `rows`/`self` as constants.
2334 fn row_nll<R: RowJet<K>>(&self, rows: &[usize], p: &[R; K]) -> Result<R, String>;
2335}
2336
2337/// Evaluate a [`RowNllProgramRowJet`] at the value/gradient/Hessian scalar
2338/// [`crate::jet_scalar::Order2`] (the `(v, g, H)` inner-Newton channel) — the
2339/// `RowJet` twin of [`generic_row_kernel`].
2340pub fn rowjet_row_kernel<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2341 prog: &P,
2342 row: usize,
2343) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
2344 let base = prog.primaries(row)?;
2345 let vars: [crate::jet_scalar::Order2<K>; K] =
2346 std::array::from_fn(|a| <crate::jet_scalar::Order2<K> as RowJet<K>>::variable(base[a], a));
2347 let s = prog.row_nll(&[row], &vars)?;
2348 Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
2349}
2350
2351/// Evaluate a [`RowNllProgramRowJet`] at the one-seed scalar
2352/// [`crate::jet_scalar::OneSeed`], returning the contracted third
2353/// `Σ_c ℓ_{abc} dir_c` — the `RowJet` twin of [`generic_third_contracted`].
2354pub fn rowjet_third_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2355 prog: &P,
2356 row: usize,
2357 dir: &[f64; K],
2358) -> Result<[[f64; K]; K], String> {
2359 let base = prog.primaries(row)?;
2360 let vars: [crate::jet_scalar::OneSeed<K>; K] =
2361 std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
2362 let s = prog.row_nll(&[row], &vars)?;
2363 Ok(s.contracted_third())
2364}
2365
2366/// Evaluate a [`RowNllProgramRowJet`] at the two-seed scalar
2367/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
2368/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `RowJet` twin of [`generic_fourth_contracted`].
2369pub fn rowjet_fourth_contracted<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2370 prog: &P,
2371 row: usize,
2372 dir_u: &[f64; K],
2373 dir_v: &[f64; K],
2374) -> Result<[[f64; K]; K], String> {
2375 let base = prog.primaries(row)?;
2376 let vars: [crate::jet_scalar::TwoSeed<K>; K] =
2377 std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
2378 let s = prog.row_nll(&[row], &vars)?;
2379 Ok(s.contracted_fourth())
2380}
2381
2382/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower4Batch`],
2383/// computing the FULL `(v, g, H, t3, t4)` for FOUR rows in one SIMD pass — the
2384/// lane twin of [`generic_full_tower`]. Each of the four lanes is seeded with its
2385/// own row's primaries, so [`Tower4Batch::lane`]`(i)` is `to_bits`-identical to
2386/// the scalar [`generic_full_tower`] on `rows[i]`.
2387pub fn generic_batched_fourth_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2388 prog: &P,
2389 rows: [usize; 4],
2390) -> Result<Tower4Batch<K>, String> {
2391 let bases: [[f64; K]; 4] = [
2392 prog.primaries(rows[0])?,
2393 prog.primaries(rows[1])?,
2394 prog.primaries(rows[2])?,
2395 prog.primaries(rows[3])?,
2396 ];
2397 let vars: [Tower4Batch<K>; K] = std::array::from_fn(|a| {
2398 let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2399 Tower4Batch::variable(lane_vals, a)
2400 });
2401 prog.row_nll(&rows, &vars)
2402}
2403
2404/// Evaluate a [`RowNllProgramRowJet`] at the `f64x4` lane [`Tower3Batch`],
2405/// computing `(v, g, H, t3)` for FOUR rows in one SIMD pass — the order-≤3 lane
2406/// twin of [`generic_full_tower`]. [`Tower3Batch::lane`]`(i)` is
2407/// `to_bits`-identical to the order-≤3 scalar evaluation on `rows[i]`.
2408pub fn generic_batched_third_tower<const K: usize, P: RowNllProgramRowJet<K> + ?Sized>(
2409 prog: &P,
2410 rows: [usize; 4],
2411) -> Result<Tower3Batch<K>, String> {
2412 let bases: [[f64; K]; 4] = [
2413 prog.primaries(rows[0])?,
2414 prog.primaries(rows[1])?,
2415 prog.primaries(rows[2])?,
2416 prog.primaries(rows[3])?,
2417 ];
2418 let vars: [Tower3Batch<K>; K] = std::array::from_fn(|a| {
2419 let lane_vals = wide::f64x4::new([bases[0][a], bases[1][a], bases[2][a], bases[3][a]]);
2420 Tower3Batch::variable(lane_vals, a)
2421 });
2422 prog.row_nll(&rows, &vars)
2423}
2424
2425// ── The oracle ───────────────────────────────────────────────────────
2426
/// One row's worth of hand-written kernel outputs, as claimed by a
/// `RowKernel` implementation, packaged for verification against the
/// tower truth. Plain data (no trait coupling) so any kernel — whatever
/// its visibility — can be audited from its own test module.
///
/// Derives `Clone`, `Debug` and `PartialEq` so a claim set can be copied,
/// printed in test failures, and compared directly.
#[derive(Clone, Debug, PartialEq)]
pub struct KernelChannels<const K: usize> {
    /// Claimed row NLL value: the scalar component of `row_kernel`'s
    /// `(nll, ∇, H)`.
    pub value: f64,
    /// Claimed gradient: the `∇` component of `row_kernel`.
    pub gradient: [f64; K],
    /// Claimed Hessian: the `H` component of `row_kernel`.
    pub hessian: [[f64; K]; K],
    /// Claimed `row_third_contracted(dir)` outputs as `(dir, claim)` pairs.
    pub third: Vec<([f64; K], [[f64; K]; K])>,
    /// Claimed `row_fourth_contracted(u, v)` outputs as `(u, v, claim)`.
    pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
}
2443
2444/// Channel-by-channel audit of a hand-written kernel against the
2445/// single-expression tower truth. Returns `Err` naming the first channel,
2446/// index, claimed and true values on disagreement — designed as the body
2447/// of the per-family CI oracle tests (#932 deployment step 2).
2448///
2449/// Tolerance is PER ENTRY, mixed absolute/relative: each comparison uses
2450/// `|claim − truth| ≤ atol + rel_tol · max(|claim|, |truth|)`. The absolute
2451/// floor `atol = rel_tol` lets exact-zero entries of structurally sparse
2452/// towers pass without demanding bit-equality, while a tiny cross-block
2453/// entry dropped next to a huge one is still caught (it is NOT measured
2454/// against the largest entry of the whole channel — there is no per-channel
2455/// magnitude floor). Genuine sign flips (#736) and dropped channels are loud.
2456///
2457/// Non-finite handling is strict: a NaN on either side always fails; an
2458/// infinity passes only when both sides are the SAME signed infinity.
2459pub fn verify_kernel_channels<const K: usize>(
2460 tower: &Tower4<K>,
2461 claims: &KernelChannels<K>,
2462 rel_tol: f64,
2463) -> Result<(), String> {
2464 // Absolute floor: reuse rel_tol so a single knob controls both the
2465 // relative band and the absolute floor for entries near zero.
2466 let atol = rel_tol;
2467 let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
2468 // Non-finite values never silently pass the algebraic comparison
2469 // below (any comparison with NaN is false). Handle them explicitly:
2470 // NaN on either side always errs; an infinity passes only if both
2471 // sides are the identical signed infinity.
2472 if !claim.is_finite() || !truth.is_finite() {
2473 let agree = claim.is_infinite()
2474 && truth.is_infinite()
2475 && claim.is_sign_positive() == truth.is_sign_positive();
2476 if agree {
2477 return Ok(());
2478 }
2479 return Err(format!(
2480 "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
2481 ));
2482 }
2483 let band = atol + rel_tol * claim.abs().max(truth.abs());
2484 if (claim - truth).abs() > band {
2485 return Err(format!(
2486 "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
2487 ));
2488 }
2489 Ok(())
2490 };
2491
2492 check("value", claims.value, tower.v)?;
2493
2494 for a in 0..K {
2495 check(&format!("gradient[{a}]"), claims.gradient[a], tower.g[a])?;
2496 }
2497
2498 for a in 0..K {
2499 for b in 0..K {
2500 check(
2501 &format!("hessian[{a}][{b}]"),
2502 claims.hessian[a][b],
2503 tower.h[a][b],
2504 )?;
2505 }
2506 }
2507
2508 for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
2509 let truth = tower.third_contracted(dir);
2510 for a in 0..K {
2511 for b in 0..K {
2512 check(
2513 &format!("third[{t_idx}][{a}][{b}]"),
2514 claim[a][b],
2515 truth[a][b],
2516 )?;
2517 }
2518 }
2519 }
2520
2521 for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
2522 let truth = tower.fourth_contracted(u, w);
2523 for a in 0..K {
2524 for b in 0..K {
2525 check(
2526 &format!("fourth[{f_idx}][{a}][{b}]"),
2527 claim[a][b],
2528 truth[a][b],
2529 )?;
2530 }
2531 }
2532 }
2533
2534 Ok(())
2535}
2536
2537// ===========================================================================
2538// SIMD row-batched towers (#1151 follow-up): Tower3Lane / Tower4Lane
2539// ===========================================================================
2540//
2541// `Tower{3,4}Lane<L: Lane, K>` re-type every channel of `Tower{3,4}<K>` from a
2542// scalar `f64` to a SIMD lane field `L`. With `L = wide::f64x4` one instance
2543// carries FOUR rows at once, so a per-row kernel (BMS `row_nll`, survival
2544// `row_kernel`, `marginal_slope` `build_row_*_towers`) can evaluate 4 rows per
2545// vector pass instead of one per scalar pass.
2546//
2547// Every floating-point op is a DIRECT, term-for-term lift of the scalar
2548// `Tower{3,4}<K>` body — `a * b` -> `a.mul(b)`, `a + b` -> `a.add(b)`, a literal
2549// `c` -> `L::splat(c)` — in the SAME accumulation order. `wide::f64x4`
2550// add/sub/mul are lane-wise IEEE-754 ops with NO fused-multiply-add (Rust
2551// performs no fp-contraction), so lane `i` of any channel of a
2552// `Tower{3,4}Lane<wide::f64x4, K>` is `to_bits`-IDENTICAL to the scalar
2553// `Tower{3,4}<K>` channel computed on row `i` — exactly the structural
2554// bit-identity the existing [`crate::jet_scalar::Order2Lane`] relies on. Proven
2555// by the in-tree `batch_tests` (real `wide::f64x4`) and a standalone
2556// f64x4-model oracle, `K ∈ {2,3,4,9}`.
2557//
2558// Only the pure-arithmetic ops are lifted (the transcendental `exp`/`ln`/`sqrt`/
2559// `…` route through scalar libm, which has no `f64x4` form; consumers build the
2560// per-lane derivative stack scalar-side and feed it to `compose_unary([L; _])`,
2561// exactly as the scalar path already does).
2562
2563use crate::jet_scalar::Lane;
2564
2565/// Lane-batched [`Tower4`]: value / gradient / Hessian / 3rd / 4th tensors
2566/// carried in a SIMD field `L`. `Tower4Lane<f64x4, K>` lane `i` is
2567/// `to_bits`-identical to [`Tower4<K>`] on row `i`.
2568#[derive(Clone, Copy)]
2569pub struct Tower4Lane<L: Lane, const K: usize> {
2570 /// Value channel (one entry per lane/row).
2571 pub v: L,
2572 /// Gradient `∂/∂p_a`.
2573 pub g: [L; K],
2574 /// Hessian `∂²/∂p_a∂p_b`.
2575 pub h: [[L; K]; K],
2576 /// Third tensor `∂³`.
2577 pub t3: [[[L; K]; K]; K],
2578 /// Fourth tensor `∂⁴`.
2579 pub t4: [[[[L; K]; K]; K]; K],
2580}
2581
2582/// The 4-rows-per-pass batched [`Tower4`] (`wide::f64x4` lanes).
2583pub type Tower4Batch<const K: usize> = Tower4Lane<wide::f64x4, K>;
2584
2585impl<L: Lane, const K: usize> Tower4Lane<L, K> {
2586 /// All-zero tower (every channel `+0.0` in every lane).
2587 #[inline]
2588 pub fn zero() -> Self {
2589 let z = L::splat(0.0);
2590 Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K], t4: [[[[z; K]; K]; K]; K] }
2591 }
2592 /// Constant `c` (per lane): value channel only.
2593 #[inline]
2594 pub fn constant(c: L) -> Self {
2595 let mut o = Self::zero();
2596 o.v = c;
2597 o
2598 }
2599 /// Seeded variable `p_idx` at per-lane `value`: unit first derivative in
2600 /// slot `idx` (mirrors [`Tower4::variable`]).
2601 #[inline]
2602 pub fn variable(value: L, idx: usize) -> Self {
2603 let mut o = Self::constant(value);
2604 o.g[idx] = L::splat(1.0);
2605 o
2606 }
2607 /// Extract lane `i` as a scalar [`Tower4<K>`] (channel-for-channel).
2608 #[inline]
2609 pub fn lane(&self, i: usize) -> Tower4<K> {
2610 let mut out = Tower4::<K>::zero();
2611 out.v = self.v.lane(i);
2612 for a in 0..K {
2613 out.g[a] = self.g[a].lane(i);
2614 for b in 0..K {
2615 out.h[a][b] = self.h[a][b].lane(i);
2616 for c in 0..K {
2617 out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2618 for d in 0..K {
2619 out.t4[a][b][c][d] = self.t4[a][b][c][d].lane(i);
2620 }
2621 }
2622 }
2623 }
2624 out
2625 }
2626 /// Per-channel lane-wise `self + o` (mirrors `Tower4` `Add`).
2627 #[inline]
2628 pub fn add(&self, o: &Self) -> Self {
2629 let mut out = *self;
2630 out.v = self.v.add(o.v);
2631 for i in 0..K {
2632 out.g[i] = self.g[i].add(o.g[i]);
2633 for j in 0..K {
2634 out.h[i][j] = self.h[i][j].add(o.h[i][j]);
2635 for k in 0..K {
2636 out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
2637 for l in 0..K {
2638 out.t4[i][j][k][l] = self.t4[i][j][k][l].add(o.t4[i][j][k][l]);
2639 }
2640 }
2641 }
2642 }
2643 out
2644 }
2645 /// Per-channel lane-wise `self - o` (mirrors `Tower4` `Sub`).
2646 #[inline]
2647 pub fn sub(&self, o: &Self) -> Self {
2648 let mut out = *self;
2649 out.v = self.v.sub(o.v);
2650 for i in 0..K {
2651 out.g[i] = self.g[i].sub(o.g[i]);
2652 for j in 0..K {
2653 out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
2654 for k in 0..K {
2655 out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
2656 for l in 0..K {
2657 out.t4[i][j][k][l] = self.t4[i][j][k][l].sub(o.t4[i][j][k][l]);
2658 }
2659 }
2660 }
2661 }
2662 out
2663 }
2664 /// Multiply every channel by the plain scalar `s` (mirrors `Tower4::scale`).
2665 #[inline]
2666 pub fn scale(&self, s: f64) -> Self {
2667 let sl = L::splat(s);
2668 let mut out = *self;
2669 out.v = self.v.mul(sl);
2670 for i in 0..K {
2671 out.g[i] = self.g[i].mul(sl);
2672 for j in 0..K {
2673 out.h[i][j] = self.h[i][j].mul(sl);
2674 for k in 0..K {
2675 out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
2676 for l in 0..K {
2677 out.t4[i][j][k][l] = self.t4[i][j][k][l].mul(sl);
2678 }
2679 }
2680 }
2681 }
2682 out
2683 }
2684 /// Leibniz product `self · o`, term-for-term lift of [`Tower4::mul`].
2685 #[inline]
2686 pub fn mul(&self, o: &Self) -> Self {
2687 let a = self;
2688 let b = o;
2689 let mut out = Self::zero();
2690 out.v = a.v.mul(b.v);
2691 for i in 0..K {
2692 let mut acc = L::splat(0.0);
2693 acc = acc.add(a.v.mul(b.g[i]));
2694 acc = acc.add(a.g[i].mul(b.v));
2695 out.g[i] = acc;
2696 }
2697 for i in 0..K {
2698 for j in 0..K {
2699 let mut acc = L::splat(0.0);
2700 acc = acc.add(a.v.mul(b.h[i][j]));
2701 acc = acc.add(a.g[i].mul(b.g[j]));
2702 acc = acc.add(a.g[j].mul(b.g[i]));
2703 acc = acc.add(a.h[i][j].mul(b.v));
2704 out.h[i][j] = acc;
2705 }
2706 }
2707 for i in 0..K {
2708 for j in 0..K {
2709 for k in 0..K {
2710 let mut acc = L::splat(0.0);
2711 acc = acc.add(a.v.mul(b.t3[i][j][k]));
2712 acc = acc.add(a.g[i].mul(b.h[j][k]));
2713 acc = acc.add(a.g[j].mul(b.h[i][k]));
2714 acc = acc.add(a.h[i][j].mul(b.g[k]));
2715 acc = acc.add(a.g[k].mul(b.h[i][j]));
2716 acc = acc.add(a.h[i][k].mul(b.g[j]));
2717 acc = acc.add(a.h[j][k].mul(b.g[i]));
2718 acc = acc.add(a.t3[i][j][k].mul(b.v));
2719 out.t3[i][j][k] = acc;
2720 }
2721 }
2722 }
2723 for i in 0..K {
2724 for j in 0..K {
2725 for k in 0..K {
2726 for l in 0..K {
2727 let mut acc = L::splat(0.0);
2728 acc = acc.add(a.v.mul(b.t4[i][j][k][l]));
2729 acc = acc.add(a.g[i].mul(b.t3[j][k][l]));
2730 acc = acc.add(a.g[j].mul(b.t3[i][k][l]));
2731 acc = acc.add(a.h[i][j].mul(b.h[k][l]));
2732 acc = acc.add(a.g[k].mul(b.t3[i][j][l]));
2733 acc = acc.add(a.h[i][k].mul(b.h[j][l]));
2734 acc = acc.add(a.h[j][k].mul(b.h[i][l]));
2735 acc = acc.add(a.t3[i][j][k].mul(b.g[l]));
2736 acc = acc.add(a.g[l].mul(b.t3[i][j][k]));
2737 acc = acc.add(a.h[i][l].mul(b.h[j][k]));
2738 acc = acc.add(a.h[j][l].mul(b.h[i][k]));
2739 acc = acc.add(a.t3[i][j][l].mul(b.g[k]));
2740 acc = acc.add(a.h[k][l].mul(b.h[i][j]));
2741 acc = acc.add(a.t3[i][k][l].mul(b.g[j]));
2742 acc = acc.add(a.t3[j][k][l].mul(b.g[i]));
2743 acc = acc.add(a.t4[i][j][k][l].mul(b.v));
2744 out.t4[i][j][k][l] = acc;
2745 }
2746 }
2747 }
2748 }
2749 out
2750 }
2751 /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
2752 /// [`Tower4::compose_unary`]. `d = [f, f′, f″, f‴, f⁗]` packed per lane
2753 /// (build via [`Lane::unary5`] from the scalar special-function stack).
2754 #[inline]
2755 pub fn compose_unary(&self, d: [L; 5]) -> Self {
2756 let mut out = Self::zero();
2757 out.v = d[0];
2758 for i in 0..K {
2759 let mut acc = L::splat(0.0);
2760 acc = acc.add(d[1].mul(self.g[i]));
2761 out.g[i] = acc;
2762 }
2763 for i in 0..K {
2764 for j in 0..K {
2765 let mut acc = L::splat(0.0);
2766 acc = acc.add(d[1].mul(self.h[i][j]));
2767 acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
2768 out.h[i][j] = acc;
2769 }
2770 }
2771 for i in 0..K {
2772 for j in 0..K {
2773 for k in 0..K {
2774 let mut acc = L::splat(0.0);
2775 acc = acc.add(d[1].mul(self.t3[i][j][k]));
2776 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
2777 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
2778 acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
2779 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
2780 out.t3[i][j][k] = acc;
2781 }
2782 }
2783 }
2784 for i in 0..K {
2785 for j in 0..K {
2786 for k in 0..K {
2787 for l in 0..K {
2788 let mut acc = L::splat(0.0);
2789 acc = acc.add(d[1].mul(self.t4[i][j][k][l]));
2790 acc = acc.add(d[2].mul(self.t3[i][j][k]).mul(self.g[l]));
2791 acc = acc.add(d[2].mul(self.t3[i][j][l]).mul(self.g[k]));
2792 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.h[k][l]));
2793 acc = acc.add(d[3].mul(self.h[i][j]).mul(self.g[k]).mul(self.g[l]));
2794 acc = acc.add(d[2].mul(self.t3[i][k][l]).mul(self.g[j]));
2795 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.h[j][l]));
2796 acc = acc.add(d[3].mul(self.h[i][k]).mul(self.g[j]).mul(self.g[l]));
2797 acc = acc.add(d[2].mul(self.h[i][l]).mul(self.h[j][k]));
2798 acc = acc.add(d[2].mul(self.g[i]).mul(self.t3[j][k][l]));
2799 acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][k]).mul(self.g[l]));
2800 acc = acc.add(d[3].mul(self.h[i][l]).mul(self.g[j]).mul(self.g[k]));
2801 acc = acc.add(d[3].mul(self.g[i]).mul(self.h[j][l]).mul(self.g[k]));
2802 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.h[k][l]));
2803 acc = acc.add(d[4].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]).mul(self.g[l]));
2804 out.t4[i][j][k][l] = acc;
2805 }
2806 }
2807 }
2808 }
2809 out
2810 }
2811 /// Compose with a unary special-function whose `[f64; 5]` derivative stack is
2812 /// built from the base value through `stack_fn`, evaluated PER LANE — the
2813 /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
2814 /// seam (the SIMD twin of [`Tower4::compose_unary_with`]).
2815 ///
2816 /// Each of the four lanes carries a DISTINCT base value, so the scalar
2817 /// `stack_fn` is run once per lane at that lane's own value (via
2818 /// [`Lane::unary_with`]) and the `[f64; 5]` results are packed into `[L; 5]`;
2819 /// the composition is then the existing per-lane [`Self::compose_unary`].
2820 /// Because `unary_with` runs the identical scalar closure per lane and
2821 /// `compose_unary` is a term-for-term lift of the scalar tower, lane `i` of
2822 /// the result is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`
2823 /// — which is exactly what lets a row program written against the scalar
2824 /// [`Tower4::compose_unary_with`] seam re-instantiate, unchanged, at `f64x4`.
2825 #[inline]
2826 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
2827 self.compose_unary(self.v.unary_with(stack_fn))
2828 }
2829
2830 /// Single-active-slot fast path, term-for-term lift of
2831 /// [`Tower4::compose_unary_single_slot`] (only the 5 diagonal channels).
2832 #[inline]
2833 pub fn compose_unary_single_slot(&self, d: [L; 5], slot: usize) -> Self {
2834 let mut out = Self::zero();
2835 let s = slot;
2836 let g = self.g[s];
2837 let h = self.h[s][s];
2838 let t3 = self.t3[s][s][s];
2839 let t4 = self.t4[s][s][s][s];
2840 out.v = d[0];
2841 out.g[s] = {
2842 let mut acc = L::splat(0.0);
2843 acc = acc.add(d[1].mul(g));
2844 acc
2845 };
2846 out.h[s][s] = {
2847 let mut acc = L::splat(0.0);
2848 acc = acc.add(d[1].mul(h));
2849 acc = acc.add(d[2].mul(g).mul(g));
2850 acc
2851 };
2852 out.t3[s][s][s] = {
2853 let mut acc = L::splat(0.0);
2854 acc = acc.add(d[1].mul(t3));
2855 acc = acc.add(d[2].mul(h).mul(g));
2856 acc = acc.add(d[2].mul(h).mul(g));
2857 acc = acc.add(d[2].mul(g).mul(h));
2858 acc = acc.add(d[3].mul(g).mul(g).mul(g));
2859 acc
2860 };
2861 out.t4[s][s][s][s] = {
2862 let mut acc = L::splat(0.0);
2863 acc = acc.add(d[1].mul(t4));
2864 acc = acc.add(d[2].mul(t3).mul(g));
2865 acc = acc.add(d[2].mul(t3).mul(g));
2866 acc = acc.add(d[2].mul(h).mul(h));
2867 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2868 acc = acc.add(d[2].mul(t3).mul(g));
2869 acc = acc.add(d[2].mul(h).mul(h));
2870 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2871 acc = acc.add(d[2].mul(h).mul(h));
2872 acc = acc.add(d[2].mul(g).mul(t3));
2873 acc = acc.add(d[3].mul(g).mul(h).mul(g));
2874 acc = acc.add(d[3].mul(h).mul(g).mul(g));
2875 acc = acc.add(d[3].mul(g).mul(h).mul(g));
2876 acc = acc.add(d[3].mul(g).mul(g).mul(h));
2877 acc = acc.add(d[4].mul(g).mul(g).mul(g).mul(g));
2878 acc
2879 };
2880 out
2881 }
2882 /// Contract `t3` with a primary-space direction (lift of
2883 /// [`Tower4::third_contracted`]).
2884 #[inline]
2885 pub fn third_contracted(&self, dir: &[L; K]) -> [[L; K]; K] {
2886 let mut out = [[L::splat(0.0); K]; K];
2887 for a in 0..K {
2888 for b in 0..K {
2889 let mut acc = L::splat(0.0);
2890 for c in 0..K {
2891 acc = acc.add(self.t3[a][b][c].mul(dir[c]));
2892 }
2893 out[a][b] = acc;
2894 }
2895 }
2896 out
2897 }
2898 /// Contract `t4` with two primary-space directions (lift of
2899 /// [`Tower4::fourth_contracted`]).
2900 #[inline]
2901 pub fn fourth_contracted(&self, u: &[L; K], w: &[L; K]) -> [[L; K]; K] {
2902 let mut out = [[L::splat(0.0); K]; K];
2903 for i in 0..K {
2904 for j in 0..K {
2905 let mut acc = L::splat(0.0);
2906 for k in 0..K {
2907 for l in 0..K {
2908 acc = acc.add(self.t4[i][j][k][l].mul(u[k]).mul(w[l]));
2909 }
2910 }
2911 out[i][j] = acc;
2912 }
2913 }
2914 out
2915 }
2916}
2917
2918/// Lane-batched [`Tower3`] (order-≤3 sibling of [`Tower4Lane`]).
2919#[derive(Clone, Copy)]
2920pub struct Tower3Lane<L: Lane, const K: usize> {
2921 /// Value channel.
2922 pub v: L,
2923 /// Gradient.
2924 pub g: [L; K],
2925 /// Hessian.
2926 pub h: [[L; K]; K],
2927 /// Third tensor.
2928 pub t3: [[[L; K]; K]; K],
2929}
2930
2931/// The 4-rows-per-pass batched [`Tower3`] (`wide::f64x4` lanes).
2932pub type Tower3Batch<const K: usize> = Tower3Lane<wide::f64x4, K>;
2933
2934impl<L: Lane, const K: usize> Tower3Lane<L, K> {
2935 /// All-zero tower.
2936 #[inline]
2937 pub fn zero() -> Self {
2938 let z = L::splat(0.0);
2939 Self { v: z, g: [z; K], h: [[z; K]; K], t3: [[[z; K]; K]; K] }
2940 }
2941 /// Constant `c` (per lane).
2942 #[inline]
2943 pub fn constant(c: L) -> Self {
2944 let mut o = Self::zero();
2945 o.v = c;
2946 o
2947 }
2948 /// Seeded variable `p_idx` at per-lane `value`.
2949 #[inline]
2950 pub fn variable(value: L, idx: usize) -> Self {
2951 let mut o = Self::constant(value);
2952 o.g[idx] = L::splat(1.0);
2953 o
2954 }
2955 /// Extract lane `i` as a scalar [`Tower3<K>`].
2956 #[inline]
2957 pub fn lane(&self, i: usize) -> Tower3<K> {
2958 let mut out = Tower3::<K>::zero();
2959 out.v = self.v.lane(i);
2960 for a in 0..K {
2961 out.g[a] = self.g[a].lane(i);
2962 for b in 0..K {
2963 out.h[a][b] = self.h[a][b].lane(i);
2964 for c in 0..K {
2965 out.t3[a][b][c] = self.t3[a][b][c].lane(i);
2966 }
2967 }
2968 }
2969 out
2970 }
2971 /// Per-channel lane-wise `self + o`.
2972 #[inline]
2973 pub fn add(&self, o: &Self) -> Self {
2974 let mut out = *self;
2975 out.v = self.v.add(o.v);
2976 for i in 0..K {
2977 out.g[i] = self.g[i].add(o.g[i]);
2978 for j in 0..K {
2979 out.h[i][j] = self.h[i][j].add(o.h[i][j]);
2980 for k in 0..K {
2981 out.t3[i][j][k] = self.t3[i][j][k].add(o.t3[i][j][k]);
2982 }
2983 }
2984 }
2985 out
2986 }
2987 /// Per-channel lane-wise `self - o`.
2988 #[inline]
2989 pub fn sub(&self, o: &Self) -> Self {
2990 let mut out = *self;
2991 out.v = self.v.sub(o.v);
2992 for i in 0..K {
2993 out.g[i] = self.g[i].sub(o.g[i]);
2994 for j in 0..K {
2995 out.h[i][j] = self.h[i][j].sub(o.h[i][j]);
2996 for k in 0..K {
2997 out.t3[i][j][k] = self.t3[i][j][k].sub(o.t3[i][j][k]);
2998 }
2999 }
3000 }
3001 out
3002 }
3003 /// Multiply every channel by the plain scalar `s` (mirrors `Tower3::scale`).
3004 #[inline]
3005 pub fn scale(&self, s: f64) -> Self {
3006 let sl = L::splat(s);
3007 let mut out = *self;
3008 out.v = self.v.mul(sl);
3009 for i in 0..K {
3010 out.g[i] = self.g[i].mul(sl);
3011 for j in 0..K {
3012 out.h[i][j] = self.h[i][j].mul(sl);
3013 for k in 0..K {
3014 out.t3[i][j][k] = self.t3[i][j][k].mul(sl);
3015 }
3016 }
3017 }
3018 out
3019 }
3020 /// Leibniz product `self · o`, term-for-term lift of [`Tower3::mul`].
3021 #[inline]
3022 pub fn mul(&self, o: &Self) -> Self {
3023 let a = self;
3024 let b = o;
3025 let mut out = Self::zero();
3026 out.v = a.v.mul(b.v);
3027 for i in 0..K {
3028 let mut acc = L::splat(0.0);
3029 acc = acc.add(a.v.mul(b.g[i]));
3030 acc = acc.add(a.g[i].mul(b.v));
3031 out.g[i] = acc;
3032 }
3033 for i in 0..K {
3034 for j in 0..K {
3035 let mut acc = L::splat(0.0);
3036 acc = acc.add(a.v.mul(b.h[i][j]));
3037 acc = acc.add(a.g[i].mul(b.g[j]));
3038 acc = acc.add(a.g[j].mul(b.g[i]));
3039 acc = acc.add(a.h[i][j].mul(b.v));
3040 out.h[i][j] = acc;
3041 }
3042 }
3043 for i in 0..K {
3044 for j in 0..K {
3045 for k in 0..K {
3046 let mut acc = L::splat(0.0);
3047 acc = acc.add(a.v.mul(b.t3[i][j][k]));
3048 acc = acc.add(a.g[i].mul(b.h[j][k]));
3049 acc = acc.add(a.g[j].mul(b.h[i][k]));
3050 acc = acc.add(a.h[i][j].mul(b.g[k]));
3051 acc = acc.add(a.g[k].mul(b.h[i][j]));
3052 acc = acc.add(a.h[i][k].mul(b.g[j]));
3053 acc = acc.add(a.h[j][k].mul(b.g[i]));
3054 acc = acc.add(a.t3[i][j][k].mul(b.v));
3055 out.t3[i][j][k] = acc;
3056 }
3057 }
3058 }
3059 out
3060 }
3061 /// Faà di Bruno composition `f ∘ self`, term-for-term lift of
3062 /// [`Tower3::compose_unary`]. `d = [f, f′, f″, f‴]` packed per lane.
3063 #[inline]
3064 pub fn compose_unary(&self, d: [L; 4]) -> Self {
3065 let mut out = Self::zero();
3066 out.v = d[0];
3067 for i in 0..K {
3068 let mut acc = L::splat(0.0);
3069 acc = acc.add(d[1].mul(self.g[i]));
3070 out.g[i] = acc;
3071 }
3072 for i in 0..K {
3073 for j in 0..K {
3074 let mut acc = L::splat(0.0);
3075 acc = acc.add(d[1].mul(self.h[i][j]));
3076 acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
3077 out.h[i][j] = acc;
3078 }
3079 }
3080 for i in 0..K {
3081 for j in 0..K {
3082 for k in 0..K {
3083 let mut acc = L::splat(0.0);
3084 acc = acc.add(d[1].mul(self.t3[i][j][k]));
3085 acc = acc.add(d[2].mul(self.h[i][j]).mul(self.g[k]));
3086 acc = acc.add(d[2].mul(self.h[i][k]).mul(self.g[j]));
3087 acc = acc.add(d[2].mul(self.g[i]).mul(self.h[j][k]));
3088 acc = acc.add(d[3].mul(self.g[i]).mul(self.g[j]).mul(self.g[k]));
3089 out.t3[i][j][k] = acc;
3090 }
3091 }
3092 }
3093 out
3094 }
3095 /// Compose with a unary special-function whose `[f64; 4]` derivative stack is
3096 /// built from the base value through `stack_fn`, evaluated PER LANE — the
3097 /// batch arm of the generic-over-[`Lane`](crate::jet_scalar::Lane) compose
3098 /// seam (the SIMD twin of [`Tower3::compose_unary_with`], order-≤3 sibling of
3099 /// [`Tower4Lane::compose_unary_with`]). The scalar `stack_fn` is run once per
3100 /// lane at that lane's own base value (via [`Lane::unary_with`]) and packed
3101 /// into `[L; 4]` for the existing per-lane [`Self::compose_unary`], so lane
3102 /// `i` is `to_bits`-identical to `self.lane(i).compose_unary_with(stack_fn)`.
3103 #[inline]
3104 pub fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 4]) -> Self {
3105 self.compose_unary(self.v.unary_with(stack_fn))
3106 }
3107
3108 /// Single-active-slot fast path, term-for-term lift of
3109 /// [`Tower3::compose_unary_single_slot`].
3110 #[inline]
3111 pub fn compose_unary_single_slot(&self, d: [L; 4], slot: usize) -> Self {
3112 let mut out = Self::zero();
3113 let s = slot;
3114 let g = self.g[s];
3115 let h = self.h[s][s];
3116 let t3 = self.t3[s][s][s];
3117 out.v = d[0];
3118 out.g[s] = {
3119 let mut acc = L::splat(0.0);
3120 acc = acc.add(d[1].mul(g));
3121 acc
3122 };
3123 out.h[s][s] = {
3124 let mut acc = L::splat(0.0);
3125 acc = acc.add(d[1].mul(h));
3126 acc = acc.add(d[2].mul(g).mul(g));
3127 acc
3128 };
3129 out.t3[s][s][s] = {
3130 let mut acc = L::splat(0.0);
3131 acc = acc.add(d[1].mul(t3));
3132 acc = acc.add(d[2].mul(h).mul(g));
3133 acc = acc.add(d[2].mul(h).mul(g));
3134 acc = acc.add(d[2].mul(g).mul(h));
3135 acc = acc.add(d[3].mul(g).mul(g).mul(g));
3136 acc
3137 };
3138 out
3139 }
3140}
3141
3142#[cfg(test)]
3143mod batch_tests {
3144 use super::*;
3145
3146 struct Rng(u64);
3147 impl Rng {
3148 fn f(&mut self) -> f64 {
3149 self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
3150 ((self.0 >> 11) as f64 / (1u64 << 53) as f64) * 4.0 - 2.0
3151 }
3152 }
3153
3154 // Fill every channel of a scalar Tower4<K> with random data.
3155 fn rand_t4<const K: usize>(r: &mut Rng) -> Tower4<K> {
3156 let mut t = Tower4::<K>::zero();
3157 t.v = r.f();
3158 for i in 0..K {
3159 t.g[i] = r.f();
3160 for j in 0..K {
3161 t.h[i][j] = r.f();
3162 for k in 0..K {
3163 t.t3[i][j][k] = r.f();
3164 for l in 0..K {
3165 t.t4[i][j][k][l] = r.f();
3166 }
3167 }
3168 }
3169 }
3170 t
3171 }
3172 fn rand_t3<const K: usize>(r: &mut Rng) -> Tower3<K> {
3173 let mut t = Tower3::<K>::zero();
3174 t.v = r.f();
3175 for i in 0..K {
3176 t.g[i] = r.f();
3177 for j in 0..K {
3178 t.h[i][j] = r.f();
3179 for k in 0..K {
3180 t.t3[i][j][k] = r.f();
3181 }
3182 }
3183 }
3184 t
3185 }
3186 fn pack4_t4<const K: usize>(rows: &[Tower4<K>; 4]) -> Tower4Batch<K> {
3187 let mut b = Tower4Batch::<K>::zero();
3188 let lane = |f: &dyn Fn(&Tower4<K>) -> f64| {
3189 wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3190 };
3191 b.v = lane(&|t| t.v);
3192 for i in 0..K {
3193 b.g[i] = lane(&|t| t.g[i]);
3194 for j in 0..K {
3195 b.h[i][j] = lane(&|t| t.h[i][j]);
3196 for k in 0..K {
3197 b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3198 for l in 0..K {
3199 b.t4[i][j][k][l] = lane(&|t| t.t4[i][j][k][l]);
3200 }
3201 }
3202 }
3203 }
3204 b
3205 }
3206 fn pack4_t3<const K: usize>(rows: &[Tower3<K>; 4]) -> Tower3Batch<K> {
3207 let mut b = Tower3Batch::<K>::zero();
3208 let lane = |f: &dyn Fn(&Tower3<K>) -> f64| {
3209 wide::f64x4::new([f(&rows[0]), f(&rows[1]), f(&rows[2]), f(&rows[3])])
3210 };
3211 b.v = lane(&|t| t.v);
3212 for i in 0..K {
3213 b.g[i] = lane(&|t| t.g[i]);
3214 for j in 0..K {
3215 b.h[i][j] = lane(&|t| t.h[i][j]);
3216 for k in 0..K {
3217 b.t3[i][j][k] = lane(&|t| t.t3[i][j][k]);
3218 }
3219 }
3220 }
3221 b
3222 }
3223 fn assert_t4_eq<const K: usize>(b: &Tower4<K>, s: &Tower4<K>, ctx: &str) {
3224 assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3225 for i in 0..K {
3226 assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3227 for j in 0..K {
3228 assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3229 for k in 0..K {
3230 assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3231 for l in 0..K {
3232 assert_eq!(b.t4[i][j][k][l].to_bits(), s.t4[i][j][k][l].to_bits(), "t4 {ctx}");
3233 }
3234 }
3235 }
3236 }
3237 }
3238 fn assert_t3_eq<const K: usize>(b: &Tower3<K>, s: &Tower3<K>, ctx: &str) {
3239 assert_eq!(b.v.to_bits(), s.v.to_bits(), "v {ctx}");
3240 for i in 0..K {
3241 assert_eq!(b.g[i].to_bits(), s.g[i].to_bits(), "g {ctx}");
3242 for j in 0..K {
3243 assert_eq!(b.h[i][j].to_bits(), s.h[i][j].to_bits(), "h {ctx}");
3244 for k in 0..K {
3245 assert_eq!(b.t3[i][j][k].to_bits(), s.t3[i][j][k].to_bits(), "t3 {ctx}");
3246 }
3247 }
3248 }
3249 }
3250
3251 // Run a representative op chain on 4 scalar rows and on the f64x4 batch,
3252 // then assert every channel of every lane is to_bits-identical.
3253 fn run4<const K: usize>(seed: u64, batches: usize) -> usize {
3254 let mut r = Rng(seed);
3255 let mut rows_checked = 0;
3256 for _ in 0..batches {
3257 let a: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3258 let b: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3259 let d: [[f64; 5]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3260 let dir: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3261 let dir2: [[f64; K]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3262 let s = r.f();
3263
3264 // scalar per-row reference
3265 let scal: [Tower4<K>; 4] = std::array::from_fn(|rw| {
3266 let prod = a[rw].mul(&b[rw]);
3267 let comp = prod.compose_unary(d[rw]);
3268 let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3269 summed.compose_unary_single_slot(d[rw], 0)
3270 });
3271 let third: [[[f64; K]; K]; 4] =
3272 std::array::from_fn(|rw| a[rw].third_contracted(&dir[rw]));
3273 let fourth: [[[f64; K]; K]; 4] =
3274 std::array::from_fn(|rw| a[rw].fourth_contracted(&dir[rw], &dir2[rw]));
3275
3276 // batched f64x4
3277 let ab = pack4_t4(&a);
3278 let bb = pack4_t4(&b);
3279 let db: [wide::f64x4; 5] = std::array::from_fn(|c| {
3280 wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
3281 });
3282 let dirb: [wide::f64x4; K] = std::array::from_fn(|c| {
3283 wide::f64x4::new([dir[0][c], dir[1][c], dir[2][c], dir[3][c]])
3284 });
3285 let dir2b: [wide::f64x4; K] = std::array::from_fn(|c| {
3286 wide::f64x4::new([dir2[0][c], dir2[1][c], dir2[2][c], dir2[3][c]])
3287 });
3288 let prodb = ab.mul(&bb);
3289 let compb = prodb.compose_unary(db);
3290 let summedb = compb.add(&ab).sub(&bb).scale(s);
3291 let finalb = summedb.compose_unary_single_slot(db, 0);
3292 let thirdb = ab.third_contracted(&dirb);
3293 let fourthb = ab.fourth_contracted(&dirb, &dir2b);
3294
3295 for rw in 0..4 {
3296 assert_t4_eq(&finalb.lane(rw), &scal[rw], "t4-chain");
3297 for i in 0..K {
3298 for j in 0..K {
3299 assert_eq!(thirdb[i][j].lane(rw).to_bits(), third[rw][i][j].to_bits(), "third");
3300 assert_eq!(fourthb[i][j].lane(rw).to_bits(), fourth[rw][i][j].to_bits(), "fourth");
3301 }
3302 }
3303 rows_checked += 1;
3304 }
3305 }
3306 rows_checked
3307 }
3308 fn run3<const K: usize>(seed: u64, batches: usize) -> usize {
3309 let mut r = Rng(seed);
3310 let mut rows_checked = 0;
3311 for _ in 0..batches {
3312 let a: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3313 let b: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3314 let d: [[f64; 4]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| r.f()));
3315 let s = r.f();
3316 let scal: [Tower3<K>; 4] = std::array::from_fn(|rw| {
3317 let prod = a[rw].mul(&b[rw]);
3318 let comp = prod.compose_unary(d[rw]);
3319 let summed = comp.add(&a[rw]).sub(&b[rw]).scale(s);
3320 summed.compose_unary_single_slot(d[rw], 0)
3321 });
3322 let ab = pack4_t3(&a);
3323 let bb = pack4_t3(&b);
3324 let db: [wide::f64x4; 4] = std::array::from_fn(|c| {
3325 wide::f64x4::new([d[0][c], d[1][c], d[2][c], d[3][c]])
3326 });
3327 let prodb = ab.mul(&bb);
3328 let compb = prodb.compose_unary(db);
3329 let summedb = compb.add(&ab).sub(&bb).scale(s);
3330 let finalb = summedb.compose_unary_single_slot(db, 0);
3331 for rw in 0..4 {
3332 assert_t3_eq(&finalb.lane(rw), &scal[rw], "t3-chain");
3333 rows_checked += 1;
3334 }
3335 }
3336 rows_checked
3337 }
3338
3339 // A `Tower4Batch<9>` carries a `9⁴ = 6561`-entry `t4` tensor in 4-wide
3340 // lanes (≈210 KiB by value); the op chain keeps several live, which can
3341 // exceed a test thread's default stack. Run each width on a large-stack
3342 // thread so K=9 is exercised without a stack overflow.
3343 fn big_stack<R: Send + 'static, F: FnOnce() -> R + Send + 'static>(f: F) -> R {
3344 std::thread::Builder::new()
3345 .stack_size(512 << 20)
3346 .spawn(f)
3347 .unwrap()
3348 .join()
3349 .unwrap()
3350 }
3351
3352 #[test]
3353 fn tower4_batch_lane_bit_identical() {
3354 let batches = 2000;
3355 let rows_checked = big_stack(move || run4::<2>(0x1111_2222_3333_4444, batches))
3356 + big_stack(move || run4::<3>(0x5555_6666_7777_8888, batches))
3357 + big_stack(move || run4::<4>(0x9999_aaaa_bbbb_cccc, batches))
3358 + big_stack(move || run4::<9>(0xdddd_eeee_ffff_0000, batches));
3359 // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3360 // worker threads against silently running zero comparisons.
3361 assert_eq!(rows_checked, 4 * batches * 4);
3362 }
3363
3364 #[test]
3365 fn tower3_batch_lane_bit_identical() {
3366 let batches = 2000;
3367 let rows_checked = big_stack(move || run3::<2>(0x0f0f_1e1e_2d2d_3c3c, batches))
3368 + big_stack(move || run3::<3>(0x4b4b_5a5a_6969_7878, batches))
3369 + big_stack(move || run3::<4>(0x8787_9696_a5a5_b4b4, batches))
3370 + big_stack(move || run3::<9>(0xc3c3_d2d2_e1e1_f0f0, batches));
3371 // 4 widths × `batches` batches × 4 rows each: guards the large-stack
3372 // worker threads against silently running zero comparisons.
3373 assert_eq!(rows_checked, 4 * batches * 4);
3374 }
3375
3376 // ── compose_unary_with seam (generic-over-Lane compose) ─────────────────
3377 //
3378 // The seam lets a single-sourced row program build its special-function
3379 // STACK from the base value through a closure, so the SAME expression
3380 // instantiates at a scalar tower (one base) AND a batch tower (four distinct
3381 // per-lane bases). These oracles pin both arms `to_bits`.
3382
3383 /// A base-value-dependent `[f64; 5]` derivative stack (finite for finite `u`),
3384 /// standing in for a family's hand-certified special-function stack. `stack4`
3385 /// is its order-≤3 truncation.
3386 fn seam_stack5(u: f64) -> [f64; 5] {
3387 [u.sin(), u.cos(), (2.0 * u).sin(), (0.5 * u).cos(), u * u - 0.3]
3388 }
3389 fn seam_stack4(u: f64) -> [f64; 4] {
3390 let s = seam_stack5(u);
3391 [s[0], s[1], s[2], s[3]]
3392 }
3393
3394 /// Force a distinct / edge per-lane base value (signed zeros included).
3395 fn seam_edge_base(r: &mut Rng, which: usize) -> f64 {
3396 match which {
3397 0 => -0.0,
3398 1 => 0.0,
3399 2 => r.f(),
3400 _ => r.f() + 3.0,
3401 }
3402 }
3403
3404 /// (a) scalar arm: `Tower4::compose_unary_with(f)` is `to_bits`-identical to
3405 /// the explicit `compose_unary(f(value))` on every channel.
3406 fn scalar_seam_t4<const K: usize>(seed: u64, n: usize) -> usize {
3407 let mut r = Rng(seed);
3408 for _ in 0..n {
3409 let mut t = rand_t4::<K>(&mut r);
3410 t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3411 assert_t4_eq(
3412 &t.compose_unary_with(seam_stack5),
3413 &t.compose_unary(seam_stack5(t.v)),
3414 "scalar t4 seam",
3415 );
3416 }
3417 n
3418 }
3419 fn scalar_seam_t3<const K: usize>(seed: u64, n: usize) -> usize {
3420 let mut r = Rng(seed);
3421 for _ in 0..n {
3422 let mut t = rand_t3::<K>(&mut r);
3423 t.v = seam_edge_base(&mut r, (t.v.to_bits() % 4) as usize);
3424 assert_t3_eq(
3425 &t.compose_unary_with(seam_stack4),
3426 &t.compose_unary(seam_stack4(t.v)),
3427 "scalar t3 seam",
3428 );
3429 }
3430 n
3431 }
3432
3433 /// (b) lane arm: `Tower4Lane::compose_unary_with` lane `i` is
3434 /// `to_bits`-identical to the scalar `Tower4::compose_unary_with` on row `i`,
3435 /// with the four lanes carrying DISTINCT base values (signed zeros included),
3436 /// so a buggy impl reusing one lane's base would fail.
3437 fn lane_seam_t4<const K: usize>(seed: u64, batches: usize) -> usize {
3438 let mut r = Rng(seed);
3439 let mut verified = 0usize;
3440 for _ in 0..batches {
3441 let mut rows: [Tower4<K>; 4] = std::array::from_fn(|_| rand_t4::<K>(&mut r));
3442 for (rw, row) in rows.iter_mut().enumerate() {
3443 row.v = seam_edge_base(&mut r, rw);
3444 }
3445 let batch_out = pack4_t4(&rows).compose_unary_with(seam_stack5);
3446 for (rw, row) in rows.iter().enumerate() {
3447 assert_t4_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack5), "lane t4 seam");
3448 verified += 1;
3449 }
3450 }
3451 verified
3452 }
3453 fn lane_seam_t3<const K: usize>(seed: u64, batches: usize) -> usize {
3454 let mut r = Rng(seed);
3455 let mut verified = 0usize;
3456 for _ in 0..batches {
3457 let mut rows: [Tower3<K>; 4] = std::array::from_fn(|_| rand_t3::<K>(&mut r));
3458 for (rw, row) in rows.iter_mut().enumerate() {
3459 row.v = seam_edge_base(&mut r, rw);
3460 }
3461 let batch_out = pack4_t3(&rows).compose_unary_with(seam_stack4);
3462 for (rw, row) in rows.iter().enumerate() {
3463 assert_t3_eq(&batch_out.lane(rw), &row.compose_unary_with(seam_stack4), "lane t3 seam");
3464 verified += 1;
3465 }
3466 }
3467 verified
3468 }
3469
3470 #[test]
3471 fn compose_unary_with_scalar_bit_identical() {
3472 let n = 1100;
3473 let total = scalar_seam_t4::<2>(0x2200_0001, n)
3474 + scalar_seam_t4::<3>(0x2200_0002, n)
3475 + scalar_seam_t4::<4>(0x2200_0003, n)
3476 + big_stack(move || scalar_seam_t4::<9>(0x2200_0004, n))
3477 + scalar_seam_t3::<2>(0x3300_0001, n)
3478 + scalar_seam_t3::<3>(0x3300_0002, n)
3479 + scalar_seam_t3::<4>(0x3300_0003, n)
3480 + big_stack(move || scalar_seam_t3::<9>(0x3300_0004, n));
3481 // 8 arms × 1100 = 8800 ≥ 4000 inputs.
3482 assert_eq!(total, 8 * n);
3483 }
3484
3485 #[test]
3486 fn compose_unary_with_lane_matches_scalar() {
3487 let b = 600;
3488 let total = lane_seam_t4::<2>(0x4400_0001, b)
3489 + lane_seam_t4::<3>(0x4400_0002, b)
3490 + lane_seam_t4::<4>(0x4400_0003, b)
3491 + big_stack(move || lane_seam_t4::<9>(0x4400_0004, b))
3492 + lane_seam_t3::<2>(0x5500_0001, b)
3493 + lane_seam_t3::<3>(0x5500_0002, b)
3494 + lane_seam_t3::<4>(0x5500_0003, b)
3495 + big_stack(move || lane_seam_t3::<9>(0x5500_0004, b));
3496 // 8 arms × 600 = 4800 batches ≥ 2000; each verifies 4 lanes (19200 checks).
3497 assert_eq!(total, 8 * b * 4);
3498 }
3499}
3500
3501#[cfg(test)]
3502mod tests {
3503 use super::*;
3504
3505 /// `Tower3<K>` must be bit-identical to `Tower4<K>` on every channel it
3506 /// carries (value, gradient, Hessian, third derivatives). The order-≤3
3507 /// Leibniz / Faà-di-Bruno terms read only order-≤3 inner channels, so
3508 /// dropping the fourth tensor cannot perturb them. Exercises products
3509 /// (Leibniz cross-terms), unary composition, scaling, and addition — the
3510 /// same operations the survival location-scale `nll_index_tower` composes —
3511 /// across all mixed partials, not just the diagonal entries that kernel reads.
3512 #[test]
3513 fn tower3_matches_tower4_through_third_order() {
3514 let s_a: [f64; 5] = [
3515 0.3_f64.sin(),
3516 0.3_f64.cos(),
3517 -0.3_f64.sin(),
3518 -0.3_f64.cos(),
3519 0.3_f64.sin(),
3520 ];
3521 let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
3522 let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];
3523
3524 let a4 = Tower4::<3>::variable(0.4, 0);
3525 let b4 = Tower4::<3>::variable(-0.7, 1);
3526 let c4 = Tower4::<3>::variable(0.9, 2);
3527 let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
3528 + a4.mul(&c4).scale(-0.7)
3529 + b4.compose_unary(s_b).scale(0.25);
3530
3531 let a3 = Tower3::<3>::variable(0.4, 0);
3532 let b3 = Tower3::<3>::variable(-0.7, 1);
3533 let c3 = Tower3::<3>::variable(0.9, 2);
3534 let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
3535 + a3.mul(&c3).scale(-0.7)
3536 + b3.compose_unary(s4(s_b)).scale(0.25);
3537
3538 assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
3539 for i in 0..3 {
3540 assert_eq!(
3541 prog3.g[i].to_bits(),
3542 prog4.g[i].to_bits(),
3543 "g[{i}] mismatch"
3544 );
3545 for j in 0..3 {
3546 assert_eq!(
3547 prog3.h[i][j].to_bits(),
3548 prog4.h[i][j].to_bits(),
3549 "h[{i}][{j}] mismatch"
3550 );
3551 for k in 0..3 {
3552 assert_eq!(
3553 prog3.t3[i][j][k].to_bits(),
3554 prog4.t3[i][j][k].to_bits(),
3555 "t3[{i}][{j}][{k}] mismatch"
3556 );
3557 }
3558 }
3559 }
3560 }
3561
3562 /// Binomial-logit row NLL, K=1: ℓ(η) = ln(1 + e^η) − y·η.
3563 /// The entire tower has textbook closed forms in μ = σ(η); this test
3564 /// pins the algebra (exp, ln, scalar mixes, Leibniz/Faà di Bruno) to
3565 /// analytic truth at near-machine precision.
3566 struct LogitProgram {
3567 eta: Vec<f64>,
3568 y: Vec<f64>,
3569 }
3570
3571 impl RowNllProgram<1> for LogitProgram {
3572 fn n_rows(&self) -> usize {
3573 self.eta.len()
3574 }
3575 fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
3576 Ok([self.eta[row]])
3577 }
3578 fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
3579 let eta = p[0];
3580 Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
3581 }
3582 }
3583
3584 #[test]
3585 fn logit_tower_matches_closed_forms() {
3586 let prog = LogitProgram {
3587 eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
3588 y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
3589 };
3590 for row in 0..prog.n_rows() {
3591 let t = evaluate_program(&prog, row).expect("logit program");
3592 let eta = prog.eta[row];
3593 let y = prog.y[row];
3594 let mu = 1.0 / (1.0 + (-eta).exp());
3595 let w = mu * (1.0 - mu);
3596 let expect = [
3597 (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
3598 (t.g[0], mu - y, "grad"),
3599 (t.h[0][0], w, "hess"),
3600 (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
3601 (
3602 t.t4[0][0][0][0],
3603 w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
3604 "fourth",
3605 ),
3606 ];
3607 for (got, want, label) in expect {
3608 assert!(
3609 (got - want).abs() <= 1e-12 * want.abs().max(1.0),
3610 "row {row} {label}: got {got:+.15e} want {want:+.15e}"
3611 );
3612 }
3613 }
3614 }
3615
3616 fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
3617 let diff = (got - want).abs();
3618 assert!(
3619 diff <= rel_tol * want.abs().max(1.0),
3620 "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
3621 );
3622 }
3623
3624 #[test]
3625 fn gamma_special_function_stacks_match_reference_values() {
3626 const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
3627 let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
3628 let cases = [
3629 (
3630 "x=0.1",
3631 0.1,
3632 -10.423_754_940_411_076,
3633 101.433_299_150_792_75,
3634 ),
3635 (
3636 "x=0.5",
3637 0.5,
3638 -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
3639 pi_sq / 2.0,
3640 ),
3641 ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
3642 (
3643 "x=2.5",
3644 2.5,
3645 -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
3646 pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
3647 ),
3648 (
3649 "x=50",
3650 50.0,
3651 3.901_989_673_427_892,
3652 0.020_201_333_226_697_128,
3653 ),
3654 ];
3655
3656 for (label, x, digamma_ref, trigamma_ref) in cases {
3657 let ln_gamma_stack = ln_gamma_derivative_stack(x);
3658 let digamma_stack = digamma_derivative_stack(x);
3659 let trigamma_stack = trigamma_derivative_stack(x);
3660 assert_close(
3661 &format!("{label} ln_gamma_stack digamma"),
3662 ln_gamma_stack[1],
3663 digamma_ref,
3664 1e-13,
3665 );
3666 assert_close(
3667 &format!("{label} digamma value"),
3668 digamma_stack[0],
3669 digamma_ref,
3670 1e-13,
3671 );
3672 assert_close(
3673 &format!("{label} ln_gamma_stack trigamma"),
3674 ln_gamma_stack[2],
3675 trigamma_ref,
3676 1e-13,
3677 );
3678 assert_close(
3679 &format!("{label} digamma_stack trigamma"),
3680 digamma_stack[1],
3681 trigamma_ref,
3682 1e-13,
3683 );
3684 assert_close(
3685 &format!("{label} trigamma value"),
3686 trigamma_stack[0],
3687 trigamma_ref,
3688 1e-13,
3689 );
3690 }
3691 }
3692
3693 #[test]
3694 fn gamma_special_function_stacks_obey_recurrences() {
3695 for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
3696 let digamma_x = digamma_derivative_stack(x)[0];
3697 let digamma_next = digamma_derivative_stack(x + 1.0)[0];
3698 let trigamma_x = trigamma_derivative_stack(x)[0];
3699 let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
3700 assert_close(
3701 &format!("digamma recurrence x={x}"),
3702 digamma_next,
3703 digamma_x + 1.0 / x,
3704 1e-13,
3705 );
3706 assert_close(
3707 &format!("trigamma recurrence x={x}"),
3708 trigamma_next,
3709 trigamma_x - 1.0 / (x * x),
3710 1e-13,
3711 );
3712 }
3713 }
3714
3715 /// Gaussian location-scale row NLL, K=2 primaries (η, s = log σ):
3716 /// ℓ = s + ½ e^{−2s} (y − η)². Mixed cross blocks — the #736 fragility
3717 /// shape — all have one-line closed forms here.
3718 struct LocScaleProgram {
3719 eta: Vec<f64>,
3720 s: Vec<f64>,
3721 y: Vec<f64>,
3722 }
3723
3724 impl RowNllProgram<2> for LocScaleProgram {
3725 fn n_rows(&self) -> usize {
3726 self.eta.len()
3727 }
3728 fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
3729 Ok([self.eta[row], self.s[row]])
3730 }
3731 fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
3732 let r = -(p[0] - self.y[row]);
3733 Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
3734 }
3735 }
3736
3737 #[test]
3738 fn locscale_tower_matches_closed_forms_including_cross_blocks() {
3739 let prog = LocScaleProgram {
3740 eta: vec![0.3, -1.1, 2.0],
3741 s: vec![-0.5, 0.2, 0.8],
3742 y: vec![1.0, -2.0, 2.5],
3743 };
3744 let tol = 1e-12;
3745 for row in 0..prog.n_rows() {
3746 let t = evaluate_program(&prog, row).expect("locscale program");
3747 let r = prog.y[row] - prog.eta[row];
3748 let w = (-2.0 * prog.s[row]).exp();
3749 // (η, s) = indices (0, 1).
3750 let truth_g = [-w * r, 1.0 - w * r * r];
3751 let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
3752 // Third tensor: distinct-entry closed forms.
3753 // ∂ηηη = 0, ∂ηηs = −2w, ∂ηss = −4wr, ∂sss = −4wr².
3754 let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
3755 match a + b + c {
3756 0 => 0.0,
3757 1 => -2.0 * w,
3758 2 => -4.0 * w * r,
3759 _ => -4.0 * w * r * r,
3760 }
3761 };
3762 // Fourth tensor: ∂ηηηη = 0, ∂ηηηs = 0? No: d/ds(∂ηηη)=0 ✓;
3763 // ∂ηηss = 4w, ∂ηsss = 8wr, ∂ssss = 8wr².
3764 let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
3765 match a + b + c + d {
3766 0 | 1 => 0.0,
3767 2 => 4.0 * w,
3768 3 => 8.0 * w * r,
3769 _ => 8.0 * w * r * r,
3770 }
3771 };
3772 for a in 0..2 {
3773 assert!(
3774 (t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
3775 "row {row} grad[{a}]"
3776 );
3777 for b in 0..2 {
3778 assert!(
3779 (t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
3780 "row {row} hess[{a}][{b}]: got {} want {}",
3781 t.h[a][b],
3782 truth_h[a][b]
3783 );
3784 for c in 0..2 {
3785 assert!(
3786 (t.t3[a][b][c] - t3_truth(a, b, c)).abs()
3787 <= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3788 "row {row} t3[{a}][{b}][{c}]: got {} want {}",
3789 t.t3[a][b][c],
3790 t3_truth(a, b, c)
3791 );
3792 for d in 0..2 {
3793 assert!(
3794 (t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
3795 <= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
3796 "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
3797 t.t4[a][b][c][d],
3798 t4_truth(a, b, c, d)
3799 );
3800 }
3801 }
3802 }
3803 }
3804 // The derived trait-surface helpers agree with direct contraction.
3805 let dir = [0.7, -1.3];
3806 let third = derived_third_contracted(&prog, row, &dir).expect("third");
3807 for a in 0..2 {
3808 for b in 0..2 {
3809 let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
3810 assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
3811 }
3812 }
3813 }
3814 }
3815
3816 /// FD cross-check on a deliberately gnarly composition (div, sqrt,
3817 /// powf, nested exp/ln) in K=3, where no closed form is consulted:
3818 /// every tower channel is checked against central finite differences
3819 /// of the channel one order below — value→grad, grad→hess, hess→t3,
3820 /// t3→t4 — so each order is independently anchored.
3821 ///
3822 /// The program carries a per-row primary fixture plus a per-row offset
3823 /// `tau[row]` that enters the loss as a constant, so `row` genuinely
3824 /// drives both the seed point and the evaluated expression.
3825 struct GnarlyProgram {
3826 primaries: Vec<[f64; 3]>,
3827 tau: Vec<f64>,
3828 }
3829
3830 impl GnarlyProgram {
3831 fn fixture() -> Self {
3832 Self {
3833 primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
3834 tau: vec![0.15, -0.35, 0.5],
3835 }
3836 }
3837 }
3838
3839 impl RowNllProgram<3> for GnarlyProgram {
3840 fn n_rows(&self) -> usize {
3841 self.primaries.len()
3842 }
3843 fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3844 self.primaries
3845 .get(row)
3846 .copied()
3847 .ok_or_else(|| format!("gnarly: row {row} out of range"))
3848 }
3849 fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3850 let tau = *self
3851 .tau
3852 .get(row)
3853 .ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
3854 let a = (p[0] * p[1]).exp();
3855 let b = (p[2] * p[2] + 1.0).sqrt();
3856 let c = (a + b + tau).ln();
3857 let d = (p[1] * 0.5 + 2.0).powf(1.7);
3858 Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
3859 }
3860 }
3861
3862 /// Evaluate the gnarly program's tower at an ARBITRARY seed point for
3863 /// `row` (used to drive central differences off the fixture grid),
3864 /// while keeping `row`'s per-row data (`tau`) in the loss.
3865 fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
3866 struct At<'a> {
3867 base: &'a GnarlyProgram,
3868 row: usize,
3869 p: [f64; 3],
3870 }
3871 impl RowNllProgram<3> for At<'_> {
3872 fn n_rows(&self) -> usize {
3873 1
3874 }
3875 fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
3876 if row != 0 {
3877 return Err(format!("gnarly-at: row {row} out of range"));
3878 }
3879 Ok(self.p)
3880 }
3881 fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
3882 if eval_row != 0 {
3883 return Err(format!("gnarly-at: eval row {eval_row} out of range"));
3884 }
3885 self.base.row_nll(self.row, vars)
3886 }
3887 }
3888 evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
3889 }
3890
3891 #[test]
3892 fn gnarly_tower_is_fd_consistent_order_by_order() {
3893 let prog = GnarlyProgram::fixture();
3894 for row in 0..prog.n_rows() {
3895 let base = prog.primaries(row).expect("primaries");
3896 let t = gnarly_tower_at(&prog, row, base);
3897 let h_step = 1e-5;
3898 let tol = 1e-6;
3899 for c in 0..3 {
3900 let mut up = base;
3901 let mut dn = base;
3902 up[c] += h_step;
3903 dn[c] -= h_step;
3904 let t_up = gnarly_tower_at(&prog, row, up);
3905 let t_dn = gnarly_tower_at(&prog, row, dn);
3906 // value → gradient.
3907 let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
3908 assert!(
3909 (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
3910 "grad[{c}]: analytic {} fd {}",
3911 t.g[c],
3912 fd_g
3913 );
3914 for a in 0..3 {
3915 // gradient → Hessian.
3916 let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
3917 assert!(
3918 (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
3919 "hess[{a}][{c}]: analytic {} fd {}",
3920 t.h[a][c],
3921 fd_h
3922 );
3923 for b in 0..3 {
3924 // Hessian → third.
3925 let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
3926 assert!(
3927 (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
3928 "t3[{a}][{b}][{c}]: analytic {} fd {}",
3929 t.t3[a][b][c],
3930 fd_t3
3931 );
3932 for d in 0..3 {
3933 // third → fourth.
3934 let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
3935 assert!(
3936 (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
3937 "t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
3938 t.t4[a][b][d][c],
3939 fd_t4
3940 );
3941 }
3942 }
3943 }
3944 }
3945 }
3946 }
3947
3948 /// `implicit_solve` reproduces the true implicit function `a(θ)` of a
3949 /// constraint `F(a, θ) = 0` to fourth order. The constraint here is the
3950 /// smooth, strictly-`a`-monotone
3951 /// F(a, θ) = a + θ₀·a² + θ₁·exp(a) − c
3952 /// whose root `a(θ)` is re-solved by scalar Newton at perturbed θ as the
3953 /// independent finite-difference oracle. Mirrors the survival flex
3954 /// calibration solve (one implicit intercept over the primaries) without
3955 /// any survival machinery, so a failure localises to the combinator.
3956 #[test]
3957 fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
3958 const C: f64 = 1.7;
3959 // The scalar constraint as a plain f64 closure (the production root
3960 // finder analogue) and its tower form in (a, θ₀, θ₁).
3961 let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
3962 let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
3963 let solve = |th: [f64; 2]| -> f64 {
3964 let mut a = 0.0_f64;
3965 for _ in 0..100 {
3966 let r = f_scalar(a, th);
3967 if r.abs() < 1e-14 {
3968 break;
3969 }
3970 a -= r / f_da(a, th);
3971 }
3972 a
3973 };
3974 // Tower constraint over K1 = 3 vars: slot 0 = a, slots 1,2 = θ₀, θ₁.
3975 let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
3976 let a = Tower4::<3>::variable(a0, 0);
3977 let t0 = Tower4::<3>::variable(th[0], 1);
3978 let t1 = Tower4::<3>::variable(th[1], 2);
3979 a + t0 * a.mul(&a) + t1 * a.exp() - C
3980 };
3981
3982 let th0 = [0.35, 0.2];
3983 let a0 = solve(th0);
3984 let f = f_tower(a0, th0);
3985 // Residual at the solved point is ~0 (the combinator tolerates the
3986 // production Newton residual; here it is machine-zero).
3987 assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
3988 let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");
3989
3990 // FD oracle: central differences of the scalar re-solve. Each order is
3991 // built from the previous via one more central difference, exactly the
3992 // gnarly order-by-order ladder.
3993 let h = 1e-4;
3994 let tol = 1e-5;
3995 let re = |th: [f64; 2]| solve(th);
3996 for i in 0..2 {
3997 let mut up = th0;
3998 let mut dn = th0;
3999 up[i] += h;
4000 dn[i] -= h;
4001 let fd_g = (re(up) - re(dn)) / (2.0 * h);
4002 assert!(
4003 (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4004 "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
4005 a_tower.g[i],
4006 fd_g
4007 );
4008 // second order: FD of the analytic gradient component would re-use
4009 // the combinator; instead difference a SCALAR gradient computed by
4010 // a nested re-solve so the oracle stays production-independent.
4011 let grad_at = |th: [f64; 2], j: usize| -> f64 {
4012 let mut up = th;
4013 let mut dn = th;
4014 up[j] += h;
4015 dn[j] -= h;
4016 (re(up) - re(dn)) / (2.0 * h)
4017 };
4018 for j in 0..2 {
4019 let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4020 assert!(
4021 (a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4022 "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
4023 a_tower.h[i][j],
4024 fd_h
4025 );
4026 }
4027 }
4028 }
4029
4030 /// `implicit_solve` degenerates to `a_θ = −F_θ / F_a` at first order on a
4031 /// linear-in-a constraint, and the second-order tensor matches the
4032 /// textbook IFT formula `a_uv = −(F_uv + F_au a_v + F_av a_u + F_aa a_u a_v)/F_a`.
4033 /// This pins the recursion against the hand-coded first_full.rs formula it
4034 /// replaces, independent of any FD step.
4035 #[test]
4036 fn implicit_solve_matches_textbook_ift_recursion() {
4037 // A constraint with non-trivial F_a, F_aa, F_au, F_uv all present.
4038 let a0 = 0.4_f64;
4039 let th = [0.25_f64, -0.15_f64];
4040 let f = {
4041 let a = Tower4::<3>::variable(a0, 0);
4042 let t0 = Tower4::<3>::variable(th[0], 1);
4043 let t1 = Tower4::<3>::variable(th[1], 2);
4044 // F = a·(1 + θ₀) + θ₁·a² + θ₀·θ₁ − 0.4385. The constant is chosen so
4045 // F(a0, θ0) = 0 exactly at a0 = 0.4, θ = [0.25, −0.15]:
4046 // 0.4·1.25 + (−0.15)·0.16 + 0.25·(−0.15) = 0.4385.
4047 // implicit_solve requires a genuine root; at the root the level-set
4048 // and root-curve derivatives coincide, so the textbook-IFT
4049 // assertions below are unaffected.
4050 a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
4051 };
4052 let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
4053 let f_a = f.g[0];
4054 // First order: a_u = −F_u / F_a.
4055 for u in 0..2 {
4056 let want = -f.g[u + 1] / f_a;
4057 assert!(
4058 (a_t.g[u] - want).abs() < 1e-12,
4059 "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
4060 a_t.g[u],
4061 want
4062 );
4063 }
4064 // Second order textbook IFT (indices shifted by 1 for the a-slot).
4065 for u in 0..2 {
4066 for v in 0..2 {
4067 let f_uv = f.h[u + 1][v + 1];
4068 let f_au = f.h[0][u + 1];
4069 let f_av = f.h[0][v + 1];
4070 let f_aa = f.h[0][0];
4071 let want =
4072 -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
4073 assert!(
4074 (a_t.h[u][v] - want).abs() < 1e-12,
4075 "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
4076 a_t.h[u][v],
4077 want
4078 );
4079 }
4080 }
4081 }
4082
4083 /// The moving-boundary flux tower reproduces every θ-derivative of a
4084 /// moving-limit integral, INCLUDING the second-order `B·z_uv` term the
4085 /// hand-written flux dropped (#932). The edge `z_R(θ) = θ₀ + θ₁²` has a
4086 /// genuinely nonzero `∂²z_R/∂θ₁² = 2`, so a combinator that omitted
4087 /// `B·z_uv` would miss the [1][1] Hessian entry. Truth = central FD of the
4088 /// closed-form integral `∫₀^{z_R} e^{−z²/2} dz = √(π/2)·erf(z_R/√2)`.
4089 #[test]
4090 fn moving_boundary_flux_carries_b_zuv_term() {
4091 use std::f64::consts::PI;
4092 let b = |z: f64| (-0.5 * z * z).exp(); // integrand B(z)
4093 // Antiderivative-based closed-form integral I(z_R) = ∫₀^{z_R} B dz.
4094 let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
4095 let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
4096 let th0 = [0.7_f64, 0.5_f64];
4097
4098 // Edge tower z_R(θ) over K=2 primaries: value + exact derivatives.
4099 let mut z_edge = Tower4::<2>::constant(z_r(th0));
4100 z_edge.g[0] = 1.0; // ∂z_R/∂θ₀ = 1
4101 z_edge.g[1] = 2.0 * th0[1]; // ∂z_R/∂θ₁ = 2θ₁
4102 z_edge.h[1][1] = 2.0; // ∂²z_R/∂θ₁² = 2 (the z_uv the old flux dropped)
4103
4104 // Integrand stack [B, B′, B″, B‴] at z₀: B′=−z·B, B″=(z²−1)·B,
4105 // B‴=(3z−z³)·B.
4106 let z0 = z_edge.v;
4107 let b0 = b(z0);
4108 let stack = [
4109 b0,
4110 -z0 * b0,
4111 (z0 * z0 - 1.0) * b0,
4112 (3.0 * z0 - z0 * z0 * z0) * b0,
4113 ];
4114 let flux = moving_limit_boundary_tower(&z_edge, stack);
4115
4116 // FD truth of the integral's derivatives.
4117 let h = 1e-4;
4118 let tol = 1e-6;
4119 for i in 0..2 {
4120 let mut up = th0;
4121 let mut dn = th0;
4122 up[i] += h;
4123 dn[i] -= h;
4124 let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
4125 assert!(
4126 (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4127 "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
4128 flux.g[i],
4129 fd_g
4130 );
4131 }
4132 // The decisive entry: ∂²I/∂θ₁² = B′·(z_θ₁)² + B·z_θ₁θ₁. With z_θ₁=2θ₁=1
4133 // and z_θ₁θ₁=2, the B·z_uv contribution is B(z₀)·2 — omitting it would
4134 // leave the [1][1] entry short by exactly 2·B(z₀).
4135 let grad1_at = |th: [f64; 2]| -> f64 {
4136 let mut up = th;
4137 let mut dn = th;
4138 up[1] += h;
4139 dn[1] -= h;
4140 (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
4141 };
4142 let mut up = th0;
4143 let mut dn = th0;
4144 up[1] += h;
4145 dn[1] -= h;
4146 let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
4147 assert!(
4148 (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
4149 "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
4150 flux.h[1][1],
4151 fd_h11
4152 );
4153 // Explicit witness that the B·z_uv term is present and material:
4154 // analytic h[1][1] minus the pure (z_u)² part must equal B·z_uv = 2·B₀.
4155 let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
4156 let b_zuv = flux.h[1][1] - pure_zu2;
4157 assert!(
4158 (b_zuv - b0 * 2.0).abs() < 1e-10,
4159 "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
4160 b_zuv,
4161 b0 * 2.0
4162 );
4163 }
4164
4165 /// `moving_limit_boundary_tower_theta_integrand` reproduces the marginal-slope
4166 /// flex boundary closure for a θ-DEPENDENT integrand `G(z;θ)` — the case the
4167 /// plain `moving_limit_boundary_tower` cannot express, and the case the
4168 /// survival directional/bidirectional paths hand-assemble term-by-term
4169 /// (`G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`, with the directional path
4170 /// dropping `G·z_uv`). Two independent oracles:
4171 /// (1) closed-form: the boundary flux of `∫ G dz` is exactly
4172 /// `Φ(z_edge(θ);θ) − Φ(z₀;θ)` (Φ = z-antiderivative of G), whose θ
4173 /// derivatives we take by central FD of the closed form — no jet code.
4174 /// (2) the explicit second-order hand closure, including the `G·z_uv` term,
4175 /// built from the integrand's own (z,θ) partials.
4176 /// G(z;θ) = exp(z·θ₀) is genuinely θ-dependent (G_θ₀ = z·e^{zθ₀} ≠ 0), and
4177 /// the edge z_edge = z₀ + θ₀ + θ₁² has a real z_uv = ∂²/∂θ₁² = 2, so a
4178 /// combinator that dropped either the integrand-θ terms or `G·z_uv` would
4179 /// miss a Hessian entry.
4180 #[test]
4181 fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
4182 // G(z;θ) = exp(z·θ₀); Φ(z;θ) = ∫₀^z G = (e^{zθ₀} − 1)/θ₀.
4183 let g = |z: f64, t0: f64| (z * t0).exp();
4184 let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
4185 let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
4186 let th0 = [0.4_f64, 0.5_f64];
4187 let z0 = z_r(th0);
4188
4189 // Edge tower z_edge(θ) over K=2 primaries.
4190 let mut z_edge = Tower4::<2>::constant(z0);
4191 z_edge.g[0] = 1.0; // ∂z/∂θ₀
4192 z_edge.g[1] = 2.0 * th0[1]; // ∂z/∂θ₁
4193 z_edge.h[1][1] = 2.0; // ∂²z/∂θ₁² (the z_uv the directional path drops)
4194
4195 // Φ's mixed (z, θ) jet over K1 = 3 vars: slot 0 = z, slots 1,2 = θ₀,θ₁.
4196 // Built ONCE in tower arithmetic so every (z^i θ^j) partial is exact.
4197 let z_var = Tower4::<3>::variable(z0, 0);
4198 let t0_var = Tower4::<3>::variable(th0[0], 1);
4199 // θ₁ does not enter G/Φ here, but seed it so the jet carries the full
4200 // K1 frame (its Φ-derivatives are zero; the z_edge chain supplies all θ₁
4201 // motion through slot 0).
4202 let _t1_var = Tower4::<3>::variable(th0[1], 2);
4203 let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
4204 // Sanity: slot-0 first derivative of Φ IS G(z₀;θ₀).
4205 assert!(
4206 (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
4207 "Φ_z {:+.8e} != G {:+.8e}",
4208 phi_jet.g[0],
4209 g(z0, th0[0])
4210 );
4211
4212 let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
4213
4214 // Value channel is 0 by construction (boundary, not the integral itself).
4215 assert!(
4216 flux.v.abs() < 1e-12,
4217 "boundary value channel {:+.3e}",
4218 flux.v
4219 );
4220
4221 // Oracle (1): central FD of the closed-form boundary flux
4222 // Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ) (z₀ FROZEN at the base edge).
4223 let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
4224 let h = 1e-4;
4225 let tol = 1e-6;
4226 for i in 0..2 {
4227 let mut up = th0;
4228 let mut dn = th0;
4229 up[i] += h;
4230 dn[i] -= h;
4231 let fd_g = (bnd(up) - bnd(dn)) / (2.0 * h);
4232 assert!(
4233 (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
4234 "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
4235 flux.g[i],
4236 fd_g
4237 );
4238 }
4239 let grad_at = |th: [f64; 2], j: usize| -> f64 {
4240 let mut up = th;
4241 let mut dn = th;
4242 up[j] += h;
4243 dn[j] -= h;
4244 (bnd(up) - bnd(dn)) / (2.0 * h)
4245 };
4246 for i in 0..2 {
4247 for j in 0..2 {
4248 let mut up = th0;
4249 let mut dn = th0;
4250 up[i] += h;
4251 dn[i] -= h;
4252 let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
4253 assert!(
4254 (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
4255 "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
4256 flux.h[i][j],
4257 fd_h
4258 );
4259 }
4260 }
4261
4262 // Oracle (2): the explicit second-order hand closure, term by term —
4263 // `G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`. Read G's partials at the
4264 // base point directly (no jet): G = e^{zθ₀}, G_z = θ₀·G, G_θ₀ = z·G,
4265 // G_θ₁ = 0.
4266 let gg = g(z0, th0[0]);
4267 let g_z = th0[0] * gg;
4268 let g_theta = [z0 * gg, 0.0]; // [G_θ₀, G_θ₁]
4269 for i in 0..2 {
4270 for j in 0..2 {
4271 let z_u = z_edge.g[i];
4272 let z_v = z_edge.g[j];
4273 let z_uv = z_edge.h[i][j];
4274 let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
4275 assert!(
4276 (flux.h[i][j] - hand).abs() < 1e-9,
4277 "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
4278 flux.h[i][j],
4279 hand
4280 );
4281 }
4282 }
4283
4284 // Decisive: the `G·z_uv` term the directional path DROPS is present and
4285 // material in the [1][1] entry (z_uv = 2 there).
4286 let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
4287 let g_zuv = flux.h[1][1] - pure_no_zuv;
4288 assert!(
4289 (g_zuv - gg * 2.0).abs() < 1e-9,
4290 "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
4291 g_zuv,
4292 gg * 2.0
4293 );
4294 }
4295
4296 /// The survival crossing-edge position tower `z_edge = (τ − a(θ)) / b`,
4297 /// `b = exp(g)`, built from the intercept tower `a(θ)` (here a stand-in)
4298 /// and the seeded slope `g`, reproduces taylor-jet's exact hand-path
4299 /// boundary-velocity formulas:
4300 /// z_u = −(a_u + [u==g]·z) / b
4301 /// z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
4302 /// This pins the bridge between `implicit_solve` and
4303 /// `cell_moving_boundary_flux_tower`: the boundary jet that the production
4304 /// flex path hand-codes (and dropped `z_uv` from) is exactly `∂²` of this
4305 /// tower. K=3 reduced frame: slot 0 = a-axis carrier (an arbitrary smooth
4306 /// a(θ) with nonzero a_u/a_uv), slot 1 = g (the log-slope), slot 2 unused.
#[test]
fn crossing_edge_tower_matches_handpath_velocity_formulas() {
    // Link-knot crossing threshold τ.
    const TAU: f64 = 1.3;
    // Closeness bound shared by the first- and second-order comparisons.
    const TOL: f64 = 1e-10;

    let g_idx = 1usize;
    // Slope value b: the g-primary is seeded as the slope itself.
    let g0 = 0.85_f64;

    // Planted intercept tower a(θ). Value, gradient and Hessian are filled in
    // on the two live axes so that both a_u and a_uv enter the formulas; the
    // higher-order channels stay at whatever `constant` leaves them.
    let mut a = Tower4::<3>::constant(0.45);
    for (slot, &da) in [0.7_f64, -0.3].iter().enumerate() {
        a.g[slot] = da;
    }
    for (r, row) in [[0.25_f64, 0.11], [0.11, -0.08]].iter().enumerate() {
        for (c, &d2a) in row.iter().enumerate() {
            a.h[r][c] = d2a;
        }
    }

    // z_edge = (τ − a) / b, with b seeded as the g-axis variable (∂b/∂g = 1).
    let b = Tower4::<3>::variable(g0, g_idx);
    let z = (Tower4::<3>::constant(TAU) - a) / b;

    let bv = g0;
    let z0 = z.v;
    assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);

    // Kronecker selector [axis == g]: passes `val` through on the g-axis only.
    let on_g = |axis: usize, val: f64| if axis == g_idx { val } else { 0.0 };

    // First order: z_u = −(a_u + [u==g]·z) / b.
    for u in 0..2 {
        let want = -(a.g[u] + on_g(u, z0)) / bv;
        assert!(
            (z.g[u] - want).abs() < TOL,
            "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
            z.g[u],
            want
        );
    }

    // Second order: z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b, fed with the
    // tower's own first-order entries, which the loop above has just checked.
    for u in 0..2 {
        for v in 0..2 {
            let cross = on_g(u, z.g[v]) + on_g(v, z.g[u]);
            let want = -(a.h[u][v] + cross) / bv;
            assert!(
                (z.h[u][v] - want).abs() < TOL,
                "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
                z.h[u][v],
                want
            );
        }
    }
}
4360
/// Constraint-frame version of the crossing-edge check: intercept `a`
/// (slot 0) and slope `b` (slot 1) are BOTH seeded as independent variables,
/// and `z = (τ − a) / b` must reproduce the bare boundary-velocity constants
///     z_a  = ∂z/∂a    = −1/b
///     z_ab = ∂²z/∂a∂b = +1/b²
///     z_aa = ∂²z/∂a²  = 0
///     z_bb = ∂²z/∂b²  = +2(τ−a)/b³
/// These are the `f_a`/`f_au`/`f_aa` constraint-jet boundary motions that the
/// production base path is reported to drop (added only in the dir twins —
/// the #932 desync). Because `a` has not yet been substituted with a(θ),
/// `z_aa` vanishes and no `a_uv` chain term appears; that substitution is
/// `implicit_solve`'s job. The test pins the constants ahead of the
/// constraint-tower wiring.
#[test]
fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
    const TAU: f64 = 1.3;
    const TOL: f64 = 1e-12;
    let (a0, b0) = (0.45_f64, 0.85_f64);

    // Two independent seeds: slot 0 carries a, slot 1 carries b.
    let a = Tower4::<2>::variable(a0, 0);
    let b = Tower4::<2>::variable(b0, 1);
    let z = (Tower4::<2>::constant(TAU) - a) / b;

    // Closed-form targets.
    let want_z = (TAU - a0) / b0;
    let want_za = -1.0 / b0;
    let want_zab = 1.0 / (b0 * b0);
    let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);

    assert!((z.v - want_z).abs() < TOL);
    assert!((z.g[0] - want_za).abs() < TOL, "z_a {:+.10e}", z.g[0]);
    assert!(
        (z.h[0][1] - want_zab).abs() < TOL,
        "z_ab {:+.10e} vs +1/b² {:+.10e}",
        z.h[0][1],
        want_zab
    );
    assert!(
        z.h[0][0].abs() < TOL,
        "z_aa must vanish, got {:+.10e}",
        z.h[0][0]
    );
    assert!(
        (z.h[1][1] - want_zbb).abs() < TOL,
        "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
        z.h[1][1],
        want_zbb
    );
}
4404
/// Negating a single mixed entry of the contracted third-order channel (a
/// #736-style sign flip) must make `verify_kernel_channels` fail, and the
/// error text must identify the offending channel by name.
#[test]
fn oracle_catches_planted_cross_block_sign_flip() {
    let prog = LocScaleProgram {
        eta: vec![0.3],
        s: vec![-0.5],
        y: vec![1.0],
    };
    let t = evaluate_program(&prog, 0).expect("tower");
    let dir = [0.6, -0.2];
    let probe = [1.0, 0.5];

    // Baseline: channels copied straight from the tower must be accepted.
    let honest_third = t.third_contracted(&dir);
    let honest = KernelChannels {
        value: t.v,
        gradient: t.g,
        hessian: t.h,
        third: vec![(dir, honest_third)],
        fourth: vec![(dir, probe, t.fourth_contracted(&dir, &probe))],
    };
    verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");

    // Plant the flip: the same channels with one mixed cross entry negated.
    let mut third = honest_third;
    third[0][1] = -third[0][1];
    let flipped = KernelChannels {
        value: t.v,
        gradient: t.g,
        hessian: t.h,
        third: vec![(dir, third)],
        fourth: vec![],
    };
    let err = verify_kernel_channels(&t, &flipped, 1e-10)
        .expect_err("planted sign flip must be caught");
    assert!(
        err.contains("third[0][0][1]"),
        "oracle must name the flipped channel, got: {err}"
    );
}
4442
/// Mixed partials commute, so `t3` and `t4` must be invariant under EVERY
/// permutation of their indices. The tower stores both tensors
/// unsymmetrized, which makes this equality a genuine invariant of the
/// Leibniz / Faà di Bruno writes rather than a storage artifact — a cheap
/// tripwire for index typos. It is checked on every row of the K=3
/// `GnarlyProgram` fixture, and deliberately lives in a test instead of the
/// hot per-op path.
#[test]
fn t3_t4_are_fully_index_symmetric() {
    // All 3! = 6 orderings of three index slots.
    const PERMS3: [[usize; 3]; 6] = [
        [0, 1, 2],
        [0, 2, 1],
        [1, 0, 2],
        [1, 2, 0],
        [2, 0, 1],
        [2, 1, 0],
    ];
    // All 4! = 24 orderings of four index slots.
    const PERMS4: [[usize; 4]; 24] = [
        [0, 1, 2, 3],
        [0, 1, 3, 2],
        [0, 2, 1, 3],
        [0, 2, 3, 1],
        [0, 3, 1, 2],
        [0, 3, 2, 1],
        [1, 0, 2, 3],
        [1, 0, 3, 2],
        [1, 2, 0, 3],
        [1, 2, 3, 0],
        [1, 3, 0, 2],
        [1, 3, 2, 0],
        [2, 0, 1, 3],
        [2, 0, 3, 1],
        [2, 1, 0, 3],
        [2, 1, 3, 0],
        [2, 3, 0, 1],
        [2, 3, 1, 0],
        [3, 0, 1, 2],
        [3, 0, 2, 1],
        [3, 1, 0, 2],
        [3, 1, 2, 0],
        [3, 2, 0, 1],
        [3, 2, 1, 0],
    ];

    let prog = GnarlyProgram::fixture();
    for row in 0..prog.n_rows() {
        let t = evaluate_program(&prog, row).expect("gnarly tower");

        // Tolerances are relative to the largest entry of each tensor, with
        // the scale floored at 1.0 so tiny tensors still get an absolute bound.
        let scale_t3 = t
            .t3
            .iter()
            .flatten()
            .flatten()
            .map(|x| x.abs())
            .fold(1.0_f64, f64::max);
        let scale_t4 = t
            .t4
            .iter()
            .flatten()
            .flatten()
            .flatten()
            .map(|x| x.abs())
            .fold(1.0_f64, f64::max);
        let tol3 = 1e-12 * scale_t3;
        let tol4 = 1e-12 * scale_t4;

        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    // Third order: compare entry (i, j, k) with each reordering.
                    let base = t.t3[i][j][k];
                    let idx = [i, j, k];
                    for p in PERMS3.iter() {
                        let [a, b, c] = [idx[p[0]], idx[p[1]], idx[p[2]]];
                        let permed = t.t3[a][b][c];
                        assert!(
                            (base - permed).abs() <= tol3,
                            "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
                             permuted {permed:+.15e} under {p:?}"
                        );
                    }
                    // Fourth order: same check with a trailing index l.
                    for l in 0..3 {
                        let base4 = t.t4[i][j][k][l];
                        let idx4 = [i, j, k, l];
                        for p in PERMS4.iter() {
                            let [a, b, c, d] =
                                [idx4[p[0]], idx4[p[1]], idx4[p[2]], idx4[p[3]]];
                            let permed = t.t4[a][b][c][d];
                            assert!(
                                (base4 - permed).abs() <= tol4,
                                "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
                                 permuted {permed:+.15e} under {p:?}"
                            );
                        }
                    }
                }
            }
        }
    }
}
4534}
4535
4536#[inline]
4537fn erfcx_nonnegative(x: f64) -> f64 {
4538 if !x.is_finite() {
4539 return if x.is_sign_positive() {
4540 0.0
4541 } else {
4542 f64::INFINITY
4543 };
4544 }
4545 if x <= 0.0 {
4546 return 1.0;
4547 }
4548 if x < 26.0 {
4549 ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
4550 } else {
4551 let inv = 1.0 / x;
4552 let inv2 = inv * inv;
4553 let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
4554 + 6.5625 * inv2 * inv2 * inv2 * inv2;
4555 inv * poly / std::f64::consts::PI.sqrt()
4556 }
4557}
4558
/// Computes `log(1 − exp(−a))` for `a ≥ 0`.
///
/// Two algebraically equivalent forms are used, switched at `a = ln 2`:
/// `ln_1p(−exp(−a))` above it and `ln(−exp_m1(−a))` at or below it.
/// `a == 0` (either sign of zero) yields `−∞`.
///
/// # Panics
///
/// Panics when `a >= 0.0` is false, which includes NaN.
#[inline]
fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    if a == 0.0 {
        // 1 − exp(0) = 0, so the logarithm diverges.
        return f64::NEG_INFINITY;
    }
    let neg_a = -a;
    if a <= core::f64::consts::LN_2 {
        (-neg_a.exp_m1()).ln()
    } else {
        (-neg_a.exp()).ln_1p()
    }
}
4570
4571#[inline]
4572fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
4573 if x == f64::INFINITY {
4574 return (0.0, 0.0);
4575 }
4576 if x == f64::NEG_INFINITY {
4577 return (f64::NEG_INFINITY, f64::INFINITY);
4578 }
4579 if x.is_nan() {
4580 return (f64::NAN, f64::NAN);
4581 }
4582 if x < 0.0 {
4583 let u = -x / std::f64::consts::SQRT_2;
4584 let ex = erfcx_nonnegative(u).max(1e-300);
4585 let log_cdf = -u * u + (0.5 * ex).ln();
4586 let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
4587 (log_cdf, lambda)
4588 } else {
4589 let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
4590 let lambda = crate::probability::normal_pdf(x) / cdf;
4591 (cdf.ln(), lambda)
4592 }
4593}
4594
4595/// Stable derivative stack for `log Phi(x)` through fourth order.
4596#[inline]
4597pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
4598 let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
4599 let lambda2 = lambda * lambda;
4600 let lambda3 = lambda2 * lambda;
4601 let x2 = x * x;
4602 [
4603 log_cdf,
4604 lambda,
4605 -lambda * (x + lambda),
4606 lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
4607 -lambda
4608 * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
4609 ]
4610}
4611
4612/// Stable derivative stack for `log(1 - exp(-x))`, `x > 0`, through fourth order.
4613#[inline]
4614pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
4615 let r = 1.0 / x.exp_m1();
4616 [
4617 log1mexp_positive(x),
4618 r,
4619 -r * (1.0 + r),
4620 r * (1.0 + r) * (1.0 + 2.0 * r),
4621 -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
4622 ]
4623}
4624// ── The RowJet bridge oracle (CI) ─────────────────────────────────────
4625#[cfg(test)]
4626mod rowjet_bridge_tests {
4627 use super::*;
4628 use crate::jet_scalar::{JetScalar, Order2};
4629
4630 /// A toy row-NLL written ONCE over the [`RowJet`] bridge: a product, a sum, a
4631 /// subtraction, a scale/neg, a constant, and two value-distinct
4632 /// `compose_unary_with` stacks (an exp stack and a smooth finite-everywhere
4633 /// stack), plus a domain `guard`. The body is generic over `R: RowJet<2>`, so
4634 /// the SAME source instantiates at the scalar jets and the `f64x4` lane towers.
4635 struct ToyProgram {
4636 primaries: Vec<[f64; 2]>,
4637 /// Per-row CONTINUOUS auxiliary data `[cov, z, wi]` — the survival
4638 /// `covariance_ones` / `z_sum` / observation-weight analogues that enter
4639 /// the jet algebra as `.scale_rows(per_row_value)`, distinct per lane.
4640 aux: Vec<[f64; 3]>,
4641 }
4642
4643 impl ToyProgram {
4644 /// The body uses `pack_rows` to gather the per-lane continuous data from
4645 /// the lane→row map and `scale_rows` to fold it in — so a 4-row batch
4646 /// carries four DISTINCT cov/z/wi, which the single-`f64` `scale` could not.
4647 fn body<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> R {
4648 let cov = R::pack_rows(rows, |r| self.aux[r][0]);
4649 let z = R::pack_rows(rows, |r| self.aux[r][1]);
4650 let wi = R::pack_rows(rows, |r| self.aux[r][2]);
4651
4652 let a = p[0].mul(&p[1]).scale_rows(cov);
4653 let b = a.add(&R::constant(0.5)).sub(&p[0].scale(0.25));
4654 let c = b
4655 .compose_unary_with(|u| {
4656 let e = u.exp();
4657 [e, e, e, e, e]
4658 })
4659 .scale_rows(z);
4660 let d = c.neg().add(&p[0]);
4661 let e = d
4662 .compose_unary_with(|u| {
4663 let s = (1.0 + u * u).sqrt();
4664 let s3 = s * s * s;
4665 let s5 = s3 * s * s;
4666 let s7 = s5 * s * s;
4667 [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
4668 })
4669 .scale_rows(wi);
4670 e.mul(&p[1]).add(&e)
4671 }
4672 }
4673
4674 impl RowNllProgramRowJet<2> for ToyProgram {
4675 fn n_rows(&self) -> usize {
4676 self.primaries.len()
4677 }
4678 fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
4679 Ok(self.primaries[row])
4680 }
4681 fn row_nll<R: RowJet<2>>(&self, rows: &[usize], p: &[R; 2]) -> Result<R, String> {
4682 assert!(rows.len() == 1 || rows.len() == 4, "lane→row map is 1 or 4 wide");
4683 Ok(self.body(rows, p))
4684 }
4685 }
4686
4687 fn assert_t4_bits_eq(a: &Tower4<2>, b: &Tower4<2>, ctx: &str) {
4688 assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4689 for i in 0..2 {
4690 assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4691 for j in 0..2 {
4692 assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4693 for k in 0..2 {
4694 assert_eq!(
4695 a.t3[i][j][k].to_bits(),
4696 b.t3[i][j][k].to_bits(),
4697 "{ctx}: t3[{i}][{j}][{k}]"
4698 );
4699 for l in 0..2 {
4700 assert_eq!(
4701 a.t4[i][j][k][l].to_bits(),
4702 b.t4[i][j][k][l].to_bits(),
4703 "{ctx}: t4[{i}][{j}][{k}][{l}]"
4704 );
4705 }
4706 }
4707 }
4708 }
4709 }
4710
4711 fn assert_t3_bits_eq(a: &Tower3<2>, b: &Tower3<2>, ctx: &str) {
4712 assert_eq!(a.v.to_bits(), b.v.to_bits(), "{ctx}: v");
4713 for i in 0..2 {
4714 assert_eq!(a.g[i].to_bits(), b.g[i].to_bits(), "{ctx}: g[{i}]");
4715 for j in 0..2 {
4716 assert_eq!(a.h[i][j].to_bits(), b.h[i][j].to_bits(), "{ctx}: h[{i}][{j}]");
4717 for k in 0..2 {
4718 assert_eq!(
4719 a.t3[i][j][k].to_bits(),
4720 b.t3[i][j][k].to_bits(),
4721 "{ctx}: t3[{i}][{j}][{k}]"
4722 );
4723 }
4724 }
4725 }
4726 }
4727
// Deterministic 64-bit linear congruential generator used to drive the
// randomized bit-identity tests. `val` deliberately injects both signed
// zeros so that sign-of-zero handling is exercised.
struct Lcg(u64);

impl Lcg {
    /// Advances the state and returns a uniform draw in `[0, 1)` built from
    /// the top 53 bits of the new state.
    fn next(&mut self) -> f64 {
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1442695040888963407;
        let state = self.0.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        self.0 = state;
        (state >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Returns a test value: `+0.0` when the selector draw is below 0.04,
    /// `-0.0` when it is below 0.08, otherwise a second draw mapped onto
    /// `[-2.5, 2.5)`. Consumes one draw for a zero, two otherwise.
    fn val(&mut self) -> f64 {
        let selector = self.next();
        if selector < 0.04 {
            0.0
        } else if selector < 0.08 {
            -0.0
        } else {
            (self.next() - 0.5) * 5.0
        }
    }
}
4749
/// For 2500 pseudo-random 4-row batches (signed zeros injected, every lane
/// given its own primaries and aux data), lane `i` of the batched order-4
/// and order-3 towers must be `to_bits`-identical to the scalar tower
/// evaluated on row `i` through the same program body.
#[test]
fn batched_lane_i_matches_scalar_row_i_bit_identical() {
    const N_BATCHES: usize = 2500;
    let mut rng = Lcg(0xA5A5_1234_DEAD_BEEF);
    let mut completed = 0usize;

    for _ in 0..N_BATCHES {
        // Draw order matters for reproducibility: primaries first, then the
        // per-lane-DISTINCT continuous aux (cov/z/wi).
        let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
        let aux: [[f64; 3]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
        let prog = ToyProgram {
            primaries: bases.to_vec(),
            aux: aux.to_vec(),
        };
        let rows = [0usize, 1, 2, 3];

        // Order 4: batch versus scalar `Tower4`, one row at a time.
        let batch4 = generic_batched_fourth_tower(&prog, rows).expect("batch4");
        for (row, base) in bases.iter().enumerate() {
            let seeded: [Tower4<2>; 2] = std::array::from_fn(|axis| {
                <Tower4<2> as RowJet<2>>::variable(base[axis], axis)
            });
            let scalar = prog.row_nll(&[row], &seeded).expect("scalar tower4");
            assert_t4_bits_eq(&batch4.lane(row), &scalar, "batched_fourth");
        }

        // Order 3: batch versus scalar `Tower3`, one row at a time.
        let batch3 = generic_batched_third_tower(&prog, rows).expect("batch3");
        for (row, base) in bases.iter().enumerate() {
            let seeded: [Tower3<2>; 2] = std::array::from_fn(|axis| {
                <Tower3<2> as RowJet<2>>::variable(base[axis], axis)
            });
            let scalar = prog.row_nll(&[row], &seeded).expect("scalar tower3");
            assert_t3_bits_eq(&batch3.lane(row), &scalar, "batched_third");
        }

        completed += 1;
    }

    // Guard against the loop being skipped or cut short.
    assert_eq!(completed, N_BATCHES);
}
4785
/// The blanket `RowJet` impl must not perturb the scalar path. For 3000
/// pseudo-random single-row programs this checks that
/// (a) the body driven through `RowJet` ops is `to_bits`-identical to the
///     same arithmetic written directly with `JetScalar` ops,
/// (b) `rowjet_row_kernel`'s `(v, g, H)` equals the lower channels of the
///     dense `Tower4`, and
/// (c) `Order2` can drive the generic body as a `RowJet`.
#[test]
fn blanket_scalar_path_is_unchanged_and_consistent() {
    let mut rng = Lcg(0x0BAD_F00D_1357_2468);
    for _ in 0..3000 {
        let base: [f64; 2] = std::array::from_fn(|_| rng.val());
        let aux0: [f64; 3] = std::array::from_fn(|_| rng.val());
        let prog = ToyProgram {
            primaries: vec![base],
            aux: vec![aux0],
        };

        // (a) RowJet-driven body versus a hand-expanded JetScalar reference.
        // The reference calls `scale(f64)` wherever the generic body calls
        // `scale_rows`, so bit equality also shows that `scale_rows(s)` and
        // `scale(s)` coincide when `Value = f64`.
        let rowjet_seeds: [Tower4<2>; 2] =
            std::array::from_fn(|axis| <Tower4<2> as RowJet<2>>::variable(base[axis], axis));
        let via_rowjet: Tower4<2> = prog.row_nll(&[0], &rowjet_seeds).expect("rowjet");

        let vars: [Tower4<2>; 2] =
            std::array::from_fn(|axis| <Tower4<2> as JetScalar<2>>::variable(base[axis], axis));
        let [cov, z, wi] = aux0;
        let via_jetscalar: Tower4<2> = {
            // Mirrors `ToyProgram::body` step by step.
            let product = vars[0].mul(&vars[1]).scale(cov);
            let shifted = product
                .add(&Tower4::constant(0.5))
                .sub(&vars[0].scale(0.25));
            let grown = shifted
                .compose_unary_with(|u| {
                    let e = u.exp();
                    [e; 5]
                })
                .scale(z);
            let offset = JetScalar::neg(&grown).add(&vars[0]);
            let hyp = offset
                .compose_unary_with(|u| {
                    let s = (1.0 + u * u).sqrt();
                    let s3 = s * s * s;
                    let s5 = s3 * s * s;
                    let s7 = s5 * s * s;
                    [s, u / s, 1.0 / s3, -3.0 * u / s5, (12.0 * u * u - 3.0) / s7]
                })
                .scale(wi);
            hyp.mul(&vars[1]).add(&hyp)
        };
        assert_t4_bits_eq(&via_rowjet, &via_jetscalar, "blanket_vs_direct");

        // (b) Kernel output versus the dense tower's lower channels. `Order2`
        // and `Tower4` use different internal representations, so a channel
        // that is exactly zero may come out as −0.0 in one and +0.0 in the
        // other; IEEE `==` treats those as equal, which is why the gradient
        // and Hessian use `==` while the value uses `to_bits`.
        let (v, g, h) = rowjet_row_kernel(&prog, 0).expect("kernel");
        assert_eq!(v.to_bits(), via_rowjet.v.to_bits(), "kernel v");
        for i in 0..2 {
            let (got_g, want_g) = (g[i], via_rowjet.g[i]);
            assert!(got_g == want_g, "kernel g[{i}]: {} vs {}", got_g, want_g);
            for j in 0..2 {
                let (got_h, want_h) = (h[i][j], via_rowjet.h[i][j]);
                assert!(
                    got_h == want_h,
                    "kernel h[{i}][{j}]: {} vs {}",
                    got_h,
                    want_h
                );
            }
        }

        // (c) The Order2 scalar is itself a RowJet through the blanket impl;
        // evaluating the body is the check, the result is discarded.
        let o2: [Order2<2>; 2] =
            std::array::from_fn(|axis| <Order2<2> as RowJet<2>>::variable(base[axis], axis));
        let _ = prog.body(&[0], &o2);
    }
}
4860
/// With `Value = f64`, `scale_rows(s)` must match `scale(s)` bit-for-bit in
/// EVERY channel, so replacing a survival `.scale(per_row)` with
/// `.scale_rows(per_row)` cannot disturb existing scalar fits. Checked on
/// 3000 pseudo-random towers.
#[test]
fn scale_rows_scalar_is_bit_identical_to_scale() {
    let mut rng = Lcg(0xFEED_FACE_0042_1001);
    for _ in 0..3000 {
        let base: [f64; 2] = std::array::from_fn(|_| rng.val());
        let s = rng.val();

        // exp(p0·p1) populates all derivative channels of the dense tower.
        let seeds: [Tower4<2>; 2] =
            std::array::from_fn(|axis| <Tower4<2> as RowJet<2>>::variable(base[axis], axis));
        let jet = seeds[0].mul(&seeds[1]).compose_unary_with(|u| {
            let e = u.exp();
            [e; 5]
        });

        let via_scale = RowJet::scale(&jet, s);
        let via_scale_rows = RowJet::scale_rows(&jet, s);
        assert_t4_bits_eq(&via_scale_rows, &via_scale, "scale_rows==scale");
    }
}
4882
/// On a batch, `scale_rows` multiplies lane `i` by `s[i]`. Lane `i` of the
/// per-lane-scaled batch must therefore be bit-identical to the scalar tower
/// of row `i` scaled by `s[i]` — the continuous per-row data path that a
/// single-`f64` `scale` cannot express. Checked on 2500 pseudo-random batches.
#[test]
fn batched_scale_rows_matches_per_row_scalar_scale() {
    let mut rng = Lcg(0x1357_9BDF_2468_ACE0);
    for _ in 0..2500 {
        let bases: [[f64; 2]; 4] = std::array::from_fn(|_| std::array::from_fn(|_| rng.val()));
        let s: [f64; 4] = std::array::from_fn(|_| rng.val());

        // Batched side: each variable packs one primary across the four rows.
        let batch: [Tower4Batch<2>; 2] = std::array::from_fn(|axis| {
            let lanes: [f64; 4] = std::array::from_fn(|lane| bases[lane][axis]);
            Tower4Batch::variable(wide::f64x4::new(lanes), axis)
        });
        let prod = batch[0].mul(&batch[1]).compose_unary_with(|u| {
            let e = u.exp();
            [e, e, e, e, e]
        });
        let scaled = prod.scale_rows(s);

        // Scalar side: rebuild each row on its own and scale by that row's factor.
        for (row, base) in bases.iter().enumerate() {
            let seeds: [Tower4<2>; 2] =
                std::array::from_fn(|axis| <Tower4<2> as RowJet<2>>::variable(base[axis], axis));
            let row_prod = seeds[0].mul(&seeds[1]).compose_unary_with(|u| {
                let e = u.exp();
                [e, e, e, e, e]
            });
            let reference = RowJet::scale(&row_prod, s[row]);
            assert_t4_bits_eq(&scaled.lane(row), &reference, "batched_scale_rows");
        }
    }
}
4915
/// `guard` on a batch must flag exactly the lanes whose value fails the
/// predicate, and on a scalar jet must report a single lane.
#[test]
fn guard_reports_per_lane_failures() {
    // Slot-0 lane values are 1.0, −2.0, 3.0, −0.0; under `v > 0.0` lanes 1
    // and 3 fail (−0.0 is not greater than zero).
    let cols: [[f64; 2]; 4] = [[1.0, 0.5], [-2.0, 0.5], [3.0, 0.5], [-0.0, 0.5]];
    let vars: [Tower4Batch<2>; 2] = std::array::from_fn(|slot| {
        let lanes: [f64; 4] = std::array::from_fn(|lane| cols[lane][slot]);
        Tower4Batch::variable(wide::f64x4::new(lanes), slot)
    });

    let verdict = vars[0].guard(|v| v > 0.0);
    assert_eq!(verdict.lanes(), 4);
    assert!(verdict.any_failed());
    assert!(!verdict.all_pass());
    assert!(!verdict.lane_failed(0));
    assert!(verdict.lane_failed(1));
    assert!(!verdict.lane_failed(2));
    assert!(verdict.lane_failed(3));
    // Bit i of the mask marks lane i as failed: lanes 1 and 3.
    assert_eq!(verdict.failed_mask(), 0b1010);

    // Scalar jets: one passing, one failing, each reporting a single lane.
    let s_ok = <Tower4<2> as RowJet<2>>::variable(1.0, 0);
    let s_bad = <Tower4<2> as RowJet<2>>::variable(-1.0, 0);
    assert!(RowJet::guard(&s_ok, |v| v > 0.0).all_pass());
    assert!(RowJet::guard(&s_bad, |v| v > 0.0).any_failed());
    assert_eq!(RowJet::guard(&s_ok, |v| v > 0.0).lanes(), 1);
}
4943}