Skip to main content

gam_math/
jet_tower.rs

1//! Taylor-jet tower algebra: write each family's row log-likelihood ONCE,
2//! derive the entire `RowKernel<K>` derivative tower mechanically (#932).
3//!
4//! # The object
5//!
6//! [`Tower4<K>`] is a truncated multivariate Taylor scalar in `K` primary
7//! variables, carrying the value and ALL partial derivatives through fourth
8//! order as full (unsymmetrized) tensors:
9//!
10//! ```text
11//!   v        ℓ
12//!   g[a]     ∂ℓ/∂p_a
13//!   h[a][b]  ∂²ℓ/∂p_a∂p_b
14//!   t3[abc]  ∂³ℓ/∂p_a∂p_b∂p_c
15//!   t4[abcd] ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d
16//! ```
17//!
18//! Arithmetic (`+ − × ÷`, scalar mixes) propagates the tower by the exact
19//! Leibniz rule; unary transcendentals propagate by the exact multivariate
20//! Faà di Bruno formula given a `[f, f′, f″, f‴, f⁗]` stack evaluated at the
21//! inner value. This is truncated Taylor ALGEBRA — exact derivatives of the
22//! evaluated expression, not finite differences, not an approximation —
23//! fully compatible with the exact-REML-only policy.
24//!
25//! One evaluation of a row NLL program at seeded variables yields, in a
26//! single pass, every channel the [`super::row_kernel::RowKernel`] trait
27//! demands: `row_kernel` (value/∇/H), `row_third_contracted(dir)` (contract
28//! `t3` with `dir`), and `row_fourth_contracted(u, v)` (contract `t4` with
29//! `u` and `v`). The directional cross-channels that hand-written towers
30//! drop (#736's residual gap) cannot be dropped here: there is no separate
31//! "channel" to forget — every derivative of the one expression is carried.
32//!
33//! # Why this exists (the bug genus)
34//!
35//! Every family today hand-writes its tower: value in one function,
36//! gradient in another, `pdfthird_derivative`/`pdffourth_derivative`,
37//! entry/exit-specific cross blocks — thousands of lines of calculus that
38//! drift. #736 was a sign flip in a hand-written cross-Hessian block,
39//! invisible until a new consumer touched it; #948 is a derivative path
40//! that is not the derivative of the evaluated row loss (clamped-μ
41//! surrogate); the objective↔gradient desync class is the same disease at
42//! the criterion level. A tower-derived kernel is exact-by-construction:
43//! the value channel IS the production loss expression, so its derivative
44//! channels cannot desync from it.
45//!
46//! # Relation to `jet_partitions::MultiDirJet`
47//!
48//! The tree already carries a *directional* jet (bitmask coefficients over
49//! distinct seeded directions, heap-allocated, Bell-partition compose) used
50//! inside the marginal-slope and latent-survival families. It answers "the
51//! derivative along THESE specific directions" and must be re-seeded and
52//! re-evaluated per direction tuple (e.g. 10 symmetric `(a,b)` pairs for a
53//! K=4 fourth contraction). `Tower4` answers ALL of them from one
54//! evaluation: contraction happens AFTER differentiation, as plain linear
55//! algebra on the stored tensors. Use `MultiDirJet` when you need a handful
56//! of directions of a huge-K expression; use `Tower4` when you need the
57//! complete small-K tower — which is exactly the `RowKernel<K≤4>` shape.
58//! The `[f64; 5]` unary-derivative stacks
59//! (`unary_derivatives_neglog_phi`, …) are signature-compatible with
60//! [`Tower4::compose_unary`], so the families' existing special-function
61//! stacks are directly reusable.
62//!
63//! # Stability discipline (why this is NOT autodiff)
64//!
65//! Differentiating the primal code path inherits its instabilities: a jet
66//! pushed through a naive `ln(1 + e^η)` is garbage in the saturated tail
67//! even though the true derivative σ(η) is benign there. This module
68//! therefore splits responsibility: **humans own primitive stability,
69//! the algebra owns combinatorics**. Tail-critical special functions enter
70//! a program ONLY as hand-certified `[f64; 5]` derivative stacks through
71//! [`Tower4::compose_unary`] — the same stacks the families already write
72//! (`unary_derivatives_neglog_phi` and friends, built on erfcx/log_ndtr) —
73//! and the tower mechanizes only the Leibniz/Faà di Bruno composition,
74//! which is where hand-written towers actually fail (#736 was a
75//! composition sign flip, not a primitive error). Program authors must use
76//! a stable primitive stack wherever the f64 production loss does; the
77//! convenience methods (`exp`, `ln`, `sqrt`, …) are for expressions whose
78//! arguments are tame by construction.
79//!
80//! # Storage convention
81//!
82//! Tensors are stored FULL, not symmetric-packed: `t4` for K=4 is 256
83//! doubles where 35 would do. This is deliberate clarity-over-speed for the
84//! oracle role — indexing is trivially auditable, contraction loops are
85//! obvious, and the redundancy is itself a checked invariant (the algebra
86//! only ever writes symmetric values). Symmetric packing is a later,
87//! profile-justified optimization behind the same API.
88//!
89//! # Deployment ladder (#932)
90//!
91//! 1. This module: the algebra + the program seam + the oracle.
92//! 2. Universal oracle: every hand-written `RowKernel` gains a CI test
93//!    asserting channel-by-channel agreement with a `RowNllProgram` written
94//!    once — see [`verify_kernel_channels`]. This alone would have caught
95//!    #736 at introduction.
96//! 3. Migrate error-dense / cold towers to [`derived_row_kernel`] et al.;
97//!    keep hand-tuned hot paths, now verified against the single-expression
98//!    truth instead of being the only definition.
99//! 4. New families (#914/#916/#917 ZI/ordinal/expectile, #921's location-
100//!    scale port) implement ONLY `RowNllProgram` and get an exact
101//!    fourth-order tower for the price of writing the likelihood.
102
103use crate::jet_algebra;
104
/// Truncated fourth-order multivariate Taylor scalar in `K` variables.
///
/// Carries the value and every partial derivative through fourth order as
/// full (unsymmetrized) tensors; see the module documentation for semantics
/// and conventions.
///
/// `Copy` is intentional despite the size — `1 + K + K² + K³ + K⁴` doubles,
/// i.e. 341 `f64` = 2728 bytes (≈ 2.7 KiB) at K=4, far too large for
/// registers: towers are per-row temporaries that are short-lived locals
/// during a row program, and value semantics keep program code readable
/// (`a * b + c`).
#[derive(Clone, Copy, Debug)]
pub struct Tower4<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// Gradient ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Hessian ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
    /// Third derivatives ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
    pub t3: [[[f64; K]; K]; K],
    /// Fourth derivatives ∂⁴ℓ/∂p_a∂p_b∂p_c∂p_d (fully symmetric).
    pub t4: [[[[f64; K]; K]; K]; K],
}
124
125impl<const K: usize> Tower4<K> {
126    /// The additive identity.
127    pub fn zero() -> Self {
128        Self {
129            v: 0.0,
130            g: [0.0; K],
131            h: [[0.0; K]; K],
132            t3: [[[0.0; K]; K]; K],
133            t4: [[[[0.0; K]; K]; K]; K],
134        }
135    }
136
137    /// A constant: value `c`, all derivatives zero.
138    pub fn constant(c: f64) -> Self {
139        let mut out = Self::zero();
140        out.v = c;
141        out
142    }
143
144    /// The seeded variable `p_idx` with current value `value`:
145    /// unit first derivative in slot `idx`, zero elsewhere and above.
146    pub fn variable(value: f64, idx: usize) -> Self {
147        let mut out = Self::constant(value);
148        out.g[idx] = 1.0;
149        out
150    }
151
152    /// Read the (fully symmetric) derivative tensor entry whose differentiation
153    /// axes are `labels` (length 0..=4): value, `g`, `h`, `t3`, `t4`.
154    #[inline]
155    fn deriv(&self, labels: &[usize]) -> f64 {
156        assert!(
157            labels.len() <= 4,
158            "Tower4 carries at most fourth-order derivatives"
159        );
160        match labels.len() {
161            0 => self.v,
162            1 => self.g[labels[0]],
163            2 => self.h[labels[0]][labels[1]],
164            3 => self.t3[labels[0]][labels[1]][labels[2]],
165            _ => self.t4[labels[0]][labels[1]][labels[2]][labels[3]],
166        }
167    }
168
169    /// Exact truncated Leibniz product.
170    ///
171    /// Every output entry `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)` is summed
172    /// by the shared [`jet_algebra::leibniz_product`] subset walker (#1151),
173    /// the same kernel `MultiDirJet::mul` uses; the two layouts differ only in
174    /// how a slot-group selects a derivative.
175    pub fn mul(&self, o: &Self) -> Self {
176        let a = self;
177        let b = o;
178        let mut out = Self::zero();
179        out.v = a.v * b.v;
180        for i in 0..K {
181            let labels = [i];
182            out.g[i] = jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
183        }
184        for i in 0..K {
185            for j in 0..K {
186                let labels = [i, j];
187                out.h[i][j] = jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
188            }
189        }
190        for i in 0..K {
191            for j in 0..K {
192                for k in 0..K {
193                    let labels = [i, j, k];
194                    out.t3[i][j][k] =
195                        jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
196                }
197            }
198        }
199        for i in 0..K {
200            for j in 0..K {
201                for k in 0..K {
202                    for l in 0..K {
203                        let labels = [i, j, k, l];
204                        out.t4[i][j][k][l] =
205                            jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
206                    }
207                }
208            }
209        }
210        out
211    }
212
213    /// Exact multivariate Faà di Bruno composition `f ∘ self`.
214    ///
215    /// `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` evaluated at `u = self.v` —
216    /// the SAME `[f64; 5]` stack shape the families' existing
217    /// `unary_derivatives_*` helpers produce, so those special-function
218    /// stacks (Φ, log-Φ, normal pdf, …) plug in directly.
219    ///
220    /// The order-m output sums over the set partitions of the m indices
221    /// (Bell(3) = 5 terms at order 3, Bell(4) = 15 at order 4), grouped by
222    /// block count: each partition into r blocks contributes
223    /// `f⁽ʳ⁾ · Π_blocks D_block(u)`.
224    pub fn compose_unary(&self, d: [f64; 5]) -> Self {
225        <Self as jet_algebra::JetAlgebra<5>>::compose_unary(self, d)
226    }
227
228    /// Multiply every channel by a plain scalar.
229    pub fn scale(&self, s: f64) -> Self {
230        let mut out = *self;
231        out.v *= s;
232        for i in 0..K {
233            out.g[i] *= s;
234            for j in 0..K {
235                out.h[i][j] *= s;
236                for k in 0..K {
237                    out.t3[i][j][k] *= s;
238                    for l in 0..K {
239                        out.t4[i][j][k][l] *= s;
240                    }
241                }
242            }
243        }
244        out
245    }
246
247    /// e^self.
248    pub fn exp(&self) -> Self {
249        let e = self.v.exp();
250        self.compose_unary([e, e, e, e, e])
251    }
252
253    /// ln(self). Caller guarantees positivity (likelihood programs do).
254    pub fn ln(&self) -> Self {
255        let u = self.v;
256        let r = 1.0 / u;
257        self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
258    }
259
260    /// 1/self.
261    pub fn recip(&self) -> Self {
262        let r = 1.0 / self.v;
263        let r2 = r * r;
264        self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
265    }
266
267    /// √self. Caller guarantees positivity.
268    pub fn sqrt(&self) -> Self {
269        let u = self.v;
270        let s = u.sqrt();
271        self.compose_unary([
272            s,
273            0.5 / s,
274            -0.25 / (u * s),
275            0.375 / (u * u * s),
276            -0.9375 / (u * u * u * s),
277        ])
278    }
279
280    /// self^a for real exponent `a`. Caller guarantees a positive base.
281    pub fn powf(&self, a: f64) -> Self {
282        let u = self.v;
283        let f0 = u.powf(a);
284        let f1 = a * u.powf(a - 1.0);
285        let f2 = a * (a - 1.0) * u.powf(a - 2.0);
286        let f3 = a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0);
287        let f4 = a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0);
288        self.compose_unary([f0, f1, f2, f3, f4])
289    }
290
291    /// ln Γ(self). Caller guarantees positivity.
292    pub fn ln_gamma(&self) -> Self {
293        self.compose_unary(ln_gamma_derivative_stack(self.v))
294    }
295
296    /// ψ(self), the digamma function. Caller guarantees positivity.
297    pub fn digamma(&self) -> Self {
298        self.compose_unary(digamma_derivative_stack(self.v))
299    }
300
301    /// ψ′(self), the trigamma function. Caller guarantees positivity.
302    pub fn trigamma(&self) -> Self {
303        self.compose_unary(trigamma_derivative_stack(self.v))
304    }
305
306    /// Contract `t3` with one primary-space direction:
307    /// `out[a][b] = Σ_c t3[a][b][c] · dir[c]` — exactly the
308    /// `row_third_contracted` shape.
309    pub fn third_contracted(&self, dir: &[f64; K]) -> [[f64; K]; K] {
310        let mut out = [[0.0; K]; K];
311        for a in 0..K {
312            for b in 0..K {
313                let mut acc = 0.0;
314                for c in 0..K {
315                    acc += self.t3[a][b][c] * dir[c];
316                }
317                out[a][b] = acc;
318            }
319        }
320        out
321    }
322
323    /// Contract `t4` with two primary-space directions:
324    /// `out[a][b] = Σ_{c,d} t4[a][b][c][d] · u[c] · v[d]` — exactly the
325    /// `row_fourth_contracted` shape.
326    pub fn fourth_contracted(&self, u: &[f64; K], w: &[f64; K]) -> [[f64; K]; K] {
327        let mut out = [[0.0; K]; K];
328        for i in 0..K {
329            for j in 0..K {
330                let mut acc = 0.0;
331                for k in 0..K {
332                    for l in 0..K {
333                        acc += self.t4[i][j][k][l] * u[k] * w[l];
334                    }
335                }
336                out[i][j] = acc;
337            }
338        }
339        out
340    }
341}
342
343impl<const K: usize> jet_algebra::JetAlgebra<5> for Tower4<K> {
344    #[inline]
345    fn derivative(&self, labels: &[usize]) -> f64 {
346        self.deriv(labels)
347    }
348
349    fn map_derivatives<F>(&self, mut f: F) -> Self
350    where
351        F: FnMut(&[usize]) -> f64,
352    {
353        let mut out = Self::zero();
354        out.v = f(&[]);
355        for i in 0..K {
356            let labels = [i];
357            out.g[i] = f(&labels);
358        }
359        for i in 0..K {
360            for j in 0..K {
361                let labels = [i, j];
362                out.h[i][j] = f(&labels);
363            }
364        }
365        for i in 0..K {
366            for j in 0..K {
367                for k in 0..K {
368                    let labels = [i, j, k];
369                    out.t3[i][j][k] = f(&labels);
370                }
371            }
372        }
373        for i in 0..K {
374            for j in 0..K {
375                for k in 0..K {
376                    for l in 0..K {
377                        let labels = [i, j, k, l];
378                        out.t4[i][j][k][l] = f(&labels);
379                    }
380                }
381            }
382        }
383        out
384    }
385}
386
/// Second-order truncated multivariate Taylor scalar over `K` primaries:
/// value, gradient and Hessian only.
///
/// Sibling of [`Tower4`] restricted to orders ≤ 2. Its `v`/`g`/`h` channels
/// are produced by the SAME order-≤2 formulas [`Tower4`] applies. The order-2
/// Leibniz and Faà-di-Bruno terms (see [`Tower4::mul`] /
/// [`Tower4::compose_unary`]) never read `t3` or `t4`, so running one program
/// over both towers yields *bit-identical* value, gradient and Hessian.
///
/// The type is a pure performance device: an inner Newton step (and the
/// value-only ρ-homotopy pre-warm) uses curvature at most, never the
/// outer-κ/ψ third/fourth derivatives. Evaluating a row likelihood over
/// `Tower2` avoids the `K⁴` fourth-tensor product/composition arithmetic that
/// dominates the cold marginal-slope fit, while returning the same
/// `(v, g, h)`.
#[derive(Clone, Copy, Debug)]
pub struct Tower2<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// First partials ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Second partials ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
}
413
414impl<const K: usize> Tower2<K> {
415    /// The additive identity.
416    pub fn zero() -> Self {
417        Self {
418            v: 0.0,
419            g: [0.0; K],
420            h: [[0.0; K]; K],
421        }
422    }
423
424    /// A constant: value `c`, all derivatives zero.
425    pub fn constant(c: f64) -> Self {
426        let mut out = Self::zero();
427        out.v = c;
428        out
429    }
430
431    /// The seeded variable `p_idx` with current value `value`:
432    /// unit first derivative in slot `idx`, zero elsewhere and above.
433    pub fn variable(value: f64, idx: usize) -> Self {
434        let mut out = Self::constant(value);
435        out.g[idx] = 1.0;
436        out
437    }
438
439    /// Read the derivative tensor entry whose differentiation axes are
440    /// `labels` (length 0..=2): value, `g`, `h`.
441    #[inline]
442    fn deriv(&self, labels: &[usize]) -> f64 {
443        assert!(
444            labels.len() <= 2,
445            "Tower2 carries at most second-order derivatives"
446        );
447        match labels.len() {
448            0 => self.v,
449            1 => self.g[labels[0]],
450            _ => self.h[labels[0]][labels[1]],
451        }
452    }
453
454    /// Exact truncated (order ≤ 2) Leibniz product. The `v`/`g`/`h` channels
455    /// match [`Tower4::mul`] term-for-term.
456    pub fn mul(&self, o: &Self) -> Self {
457        let a = self;
458        let b = o;
459        let mut out = Self::zero();
460        out.v = a.v * b.v;
461        for i in 0..K {
462            out.g[i] = a.v * b.g[i] + a.g[i] * b.v;
463        }
464        for i in 0..K {
465            for j in 0..K {
466                out.h[i][j] = a.v * b.h[i][j] + a.g[i] * b.g[j] + a.g[j] * b.g[i] + a.h[i][j] * b.v;
467            }
468        }
469        out
470    }
471
472    /// Exact (order ≤ 2) multivariate Faà di Bruno composition `f ∘ self`.
473    ///
474    /// `d = [f(u), f′(u), f″(u)]` evaluated at `u = self.v`. The `v`/`g`/`h`
475    /// channels match [`Tower4::compose_unary`] term-for-term (which uses only
476    /// `d[0..=2]` for those orders), so this is a strict truncation, not an
477    /// approximation. The full-order `[f64; 5]` derivative stacks the families
478    /// already produce can be passed by slicing their first three entries.
479    pub fn compose_unary(&self, d: [f64; 3]) -> Self {
480        <Self as jet_algebra::JetAlgebra<3>>::compose_unary(self, d)
481    }
482
483    /// Multiply every channel by a plain scalar.
484    pub fn scale(&self, s: f64) -> Self {
485        let mut out = *self;
486        out.v *= s;
487        for i in 0..K {
488            out.g[i] *= s;
489            for j in 0..K {
490                out.h[i][j] *= s;
491            }
492        }
493        out
494    }
495
496    /// e^self.
497    pub fn exp(&self) -> Self {
498        let e = self.v.exp();
499        self.compose_unary([e, e, e])
500    }
501
502    /// √self. Caller guarantees positivity.
503    pub fn sqrt(&self) -> Self {
504        let u = self.v;
505        let s = u.sqrt();
506        self.compose_unary([s, 0.5 / s, -0.25 / (u * s)])
507    }
508}
509
510impl<const K: usize> jet_algebra::JetAlgebra<3> for Tower2<K> {
511    #[inline]
512    fn derivative(&self, labels: &[usize]) -> f64 {
513        self.deriv(labels)
514    }
515
516    fn map_derivatives<F>(&self, mut f: F) -> Self
517    where
518        F: FnMut(&[usize]) -> f64,
519    {
520        let mut out = Self::zero();
521        out.v = f(&[]);
522        for i in 0..K {
523            let labels = [i];
524            out.g[i] = f(&labels);
525        }
526        for i in 0..K {
527            for j in 0..K {
528                let labels = [i, j];
529                out.h[i][j] = f(&labels);
530            }
531        }
532        out
533    }
534}
535
536impl<const K: usize> std::ops::Add for Tower2<K> {
537    type Output = Self;
538    fn add(self, o: Self) -> Self {
539        let mut out = self;
540        out.v += o.v;
541        for i in 0..K {
542            out.g[i] += o.g[i];
543            for j in 0..K {
544                out.h[i][j] += o.h[i][j];
545            }
546        }
547        out
548    }
549}
550
551impl<const K: usize> std::ops::Mul for Tower2<K> {
552    type Output = Self;
553    fn mul(self, o: Self) -> Self {
554        Tower2::mul(&self, &o)
555    }
556}
557
558impl<const K: usize> std::ops::Add<f64> for Tower2<K> {
559    type Output = Self;
560    fn add(self, c: f64) -> Self {
561        let mut out = self;
562        out.v += c;
563        out
564    }
565}
566
567impl<const K: usize> std::ops::Mul<f64> for Tower2<K> {
568    type Output = Self;
569    fn mul(self, c: f64) -> Self {
570        self.scale(c)
571    }
572}
573
/// Third-order truncated multivariate Taylor scalar over `K` primaries:
/// value, gradient, Hessian and third-derivative tensor.
///
/// Sits between [`Tower2`] and [`Tower4`]. Every carried channel (`v`, `g`,
/// `h`, `t3`) comes from the SAME shared Leibniz / Faà-di-Bruno kernels
/// [`Tower4`] uses. Their order-≤3 terms read only order-≤3 input channels:
/// no order-3 partition reaches the f⁗ stack slot or an inner `t4` tensor (see
/// [`Tower4::compose_unary`]). A program evaluated over both towers therefore
/// yields *bit-identical* order-≤3 outputs.
///
/// Like [`Tower2`], it exists only for speed. A consumer that needs third
/// derivatives but never `t4` pays the `K³` third-tensor arithmetic while
/// skipping the dominant `K⁴` fourth-tensor product/composition. One such
/// consumer is the survival location-scale row kernel, which reads `g`, the
/// diagonal of `h` and the diagonal of `t3`.
#[derive(Clone, Copy, Debug)]
pub struct Tower3<const K: usize> {
    /// Value ℓ.
    pub v: f64,
    /// First partials ∂ℓ/∂p_a.
    pub g: [f64; K],
    /// Second partials ∂²ℓ/∂p_a∂p_b (symmetric).
    pub h: [[f64; K]; K],
    /// Third partials ∂³ℓ/∂p_a∂p_b∂p_c (fully symmetric).
    pub t3: [[[f64; K]; K]; K],
}
602
603impl<const K: usize> Tower3<K> {
604    /// The additive identity.
605    pub fn zero() -> Self {
606        Self {
607            v: 0.0,
608            g: [0.0; K],
609            h: [[0.0; K]; K],
610            t3: [[[0.0; K]; K]; K],
611        }
612    }
613
614    /// A constant: value `c`, all derivatives zero.
615    pub fn constant(c: f64) -> Self {
616        let mut out = Self::zero();
617        out.v = c;
618        out
619    }
620
621    /// The seeded variable `p_idx` with current value `value`:
622    /// unit first derivative in slot `idx`, zero elsewhere and above.
623    pub fn variable(value: f64, idx: usize) -> Self {
624        let mut out = Self::constant(value);
625        out.g[idx] = 1.0;
626        out
627    }
628
629    /// Read the (fully symmetric) derivative tensor entry whose differentiation
630    /// axes are `labels` (length 0..=3): value, `g`, `h`, `t3`.
631    #[inline]
632    fn deriv(&self, labels: &[usize]) -> f64 {
633        assert!(
634            labels.len() <= 3,
635            "Tower3 carries at most third-order derivatives"
636        );
637        match labels.len() {
638            0 => self.v,
639            1 => self.g[labels[0]],
640            2 => self.h[labels[0]][labels[1]],
641            _ => self.t3[labels[0]][labels[1]][labels[2]],
642        }
643    }
644
645    /// Exact truncated (order ≤ 3) Leibniz product. The `v`/`g`/`h`/`t3`
646    /// channels match [`Tower4::mul`] term-for-term.
647    pub fn mul(&self, o: &Self) -> Self {
648        let a = self;
649        let b = o;
650        let mut out = Self::zero();
651        out.v = a.v * b.v;
652        for i in 0..K {
653            let labels = [i];
654            out.g[i] = jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
655        }
656        for i in 0..K {
657            for j in 0..K {
658                let labels = [i, j];
659                out.h[i][j] = jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
660            }
661        }
662        for i in 0..K {
663            for j in 0..K {
664                for k in 0..K {
665                    let labels = [i, j, k];
666                    out.t3[i][j][k] =
667                        jet_algebra::leibniz_product(&labels, |t| a.deriv(t), |c| b.deriv(c));
668                }
669            }
670        }
671        out
672    }
673
674    /// Exact (order ≤ 3) multivariate Faà di Bruno composition `f ∘ self`.
675    ///
676    /// `d = [f(u), f′(u), f″(u), f‴(u)]` evaluated at `u = self.v`. The
677    /// `v`/`g`/`h`/`t3` channels match [`Tower4::compose_unary`] term-for-term
678    /// (which uses only `d[0..=3]` for those orders), so this is a strict
679    /// truncation, not an approximation. The full-order `[f64; 5]` derivative
680    /// stacks the families already produce can be passed by slicing their first
681    /// four entries.
682    pub fn compose_unary(&self, d: [f64; 4]) -> Self {
683        <Self as jet_algebra::JetAlgebra<4>>::compose_unary(self, d)
684    }
685
686    /// Multiply every channel by a plain scalar.
687    pub fn scale(&self, s: f64) -> Self {
688        let mut out = *self;
689        out.v *= s;
690        for i in 0..K {
691            out.g[i] *= s;
692            for j in 0..K {
693                out.h[i][j] *= s;
694                for k in 0..K {
695                    out.t3[i][j][k] *= s;
696                }
697            }
698        }
699        out
700    }
701}
702
703impl<const K: usize> jet_algebra::JetAlgebra<4> for Tower3<K> {
704    #[inline]
705    fn derivative(&self, labels: &[usize]) -> f64 {
706        self.deriv(labels)
707    }
708
709    fn map_derivatives<F>(&self, mut f: F) -> Self
710    where
711        F: FnMut(&[usize]) -> f64,
712    {
713        let mut out = Self::zero();
714        out.v = f(&[]);
715        for i in 0..K {
716            let labels = [i];
717            out.g[i] = f(&labels);
718        }
719        for i in 0..K {
720            for j in 0..K {
721                let labels = [i, j];
722                out.h[i][j] = f(&labels);
723            }
724        }
725        for i in 0..K {
726            for j in 0..K {
727                for k in 0..K {
728                    let labels = [i, j, k];
729                    out.t3[i][j][k] = f(&labels);
730                }
731            }
732        }
733        out
734    }
735}
736
737impl<const K: usize> std::ops::Add for Tower3<K> {
738    type Output = Self;
739    fn add(self, o: Self) -> Self {
740        let mut out = self;
741        out.v += o.v;
742        for i in 0..K {
743            out.g[i] += o.g[i];
744            for j in 0..K {
745                out.h[i][j] += o.h[i][j];
746                for k in 0..K {
747                    out.t3[i][j][k] += o.t3[i][j][k];
748                }
749            }
750        }
751        out
752    }
753}
754
755pub fn ln_gamma_derivative_stack(x: f64) -> [f64; 5] {
756    [
757        statrs::function::gamma::ln_gamma(x),
758        digamma_positive(x),
759        polygamma_positive(1, x),
760        polygamma_positive(2, x),
761        polygamma_positive(3, x),
762    ]
763}
764
765pub fn digamma_derivative_stack(x: f64) -> [f64; 5] {
766    [
767        digamma_positive(x),
768        polygamma_positive(1, x),
769        polygamma_positive(2, x),
770        polygamma_positive(3, x),
771        polygamma_positive(4, x),
772    ]
773}
774
775pub fn trigamma_derivative_stack(x: f64) -> [f64; 5] {
776    [
777        polygamma_positive(1, x),
778        polygamma_positive(2, x),
779        polygamma_positive(3, x),
780        polygamma_positive(4, x),
781        polygamma_positive(5, x),
782    ]
783}
784
/// Digamma ψ(x) = d/dx ln Γ(x) for a finite, strictly positive argument.
///
/// Returns NaN when `x` is not finite or `x <= 0`. Smaller arguments are first
/// shifted upward with the recurrence ψ(x) = ψ(x + 1) − 1/x until
/// `x >= POLYGAMMA_ASYMPTOTIC_MIN_X`; the remainder comes from the asymptotic
/// series in [`digamma_asymptotic`].
fn digamma_positive(mut x: f64) -> f64 {
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
        acc -= 1.0 / x;
        x += 1.0;
    }
    acc + digamma_asymptotic(x)
}

/// Polygamma ψ⁽ᵐ⁾(x), with `m = order`, for a finite, strictly positive
/// argument.
///
/// Supported orders are `0..=5`.
/// - Order 0 is digamma and is delegated to [`digamma_positive`].
/// - Orders 1..=5 shift `x` with
///   ψ⁽ᵐ⁾(x) = ψ⁽ᵐ⁾(x + 1) + (−1)ᵐ⁺¹ m!/xᵐ⁺¹ up to
///   `POLYGAMMA_ASYMPTOTIC_MIN_X`, then add [`polygamma_asymptotic`].
///
/// Returns NaN for a non-finite or non-positive `x` and for orders above 5.
fn polygamma_positive(order: usize, mut x: f64) -> f64 {
    if order == 0 {
        // The asymptotic tail below only covers m >= 1; ψ⁽⁰⁾ is digamma itself.
        return digamma_positive(x);
    }
    if !(x.is_finite() && x > 0.0) {
        return f64::NAN;
    }
    let mut acc = 0.0;
    while x < POLYGAMMA_ASYMPTOTIC_MIN_X {
        acc += polygamma_recurrence_term(order, x);
        x += 1.0;
    }
    acc + polygamma_asymptotic(order, x)
}

/// Argument threshold at which [`digamma_positive`] / [`polygamma_positive`]
/// stop applying the upward recurrence and switch to the truncated asymptotic
/// series.
const POLYGAMMA_ASYMPTOTIC_MIN_X: f64 = 20.0;

/// Even-index Bernoulli numbers as `(2k, B₂ₖ)` for k = 1..=10: the
/// coefficients of the truncated asymptotic series.
const BERNOULLI_EVEN: [(usize, f64); 10] = [
    (2, 1.0 / 6.0),
    (4, -1.0 / 30.0),
    (6, 1.0 / 42.0),
    (8, -1.0 / 30.0),
    (10, 5.0 / 66.0),
    (12, -691.0 / 2730.0),
    (14, 7.0 / 6.0),
    (16, -3617.0 / 510.0),
    (18, 43867.0 / 798.0),
    (20, -174611.0 / 330.0),
];

/// Recurrence correction (−1)ᵐ⁺¹ m!/xᵐ⁺¹ for m = `order`: the amount by which
/// ψ⁽ᵐ⁾(x) differs from ψ⁽ᵐ⁾(x + 1).
fn polygamma_recurrence_term(order: usize, x: f64) -> f64 {
    let sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    sign * factorial(order) / x.powi((order + 1) as i32)
}

/// Asymptotic series ψ(x) ≈ ln x − 1/(2x) − Σ_{k=1}^{10} B₂ₖ / (2k · x²ᵏ),
/// intended for `x >= POLYGAMMA_ASYMPTOTIC_MIN_X`.
fn digamma_asymptotic(x: f64) -> f64 {
    let mut out = x.ln() - 0.5 / x;
    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
        out -= bernoulli / (bernoulli_order as f64 * x.powi(bernoulli_order as i32));
    }
    out
}

/// Asymptotic series for the polygamma function, with m = `order` in 1..=5
/// (NaN otherwise):
///
/// ψ⁽ᵐ⁾(x) ≈ (−1)ᵐ⁺¹ [ (m−1)!/xᵐ + m!/(2xᵐ⁺¹)
///                      + Σ_{k=1}^{10} B₂ₖ (2k+m−1)! / ((2k)! x²ᵏ⁺ᵐ) ]
///
/// The factorial ratio is evaluated as `rising_factorial(2k, m) / (2k)`.
fn polygamma_asymptotic(order: usize, x: f64) -> f64 {
    if !(1..=5).contains(&order) {
        return f64::NAN;
    }

    let order_factorial = factorial(order);
    // (−1)^(m+1): the same sign multiplies the leading terms and the Bernoulli sum.
    let leading_sign = if order % 2 == 1 { 1.0 } else { -1.0 };
    let mut out = leading_sign * factorial(order - 1) / x.powi(order as i32)
        + leading_sign * order_factorial / (2.0 * x.powi((order + 1) as i32));

    for (bernoulli_order, bernoulli) in BERNOULLI_EVEN {
        let rising = rising_factorial(bernoulli_order, order);
        out += leading_sign * bernoulli * rising
            / bernoulli_order as f64
            / x.powi((bernoulli_order + order) as i32);
    }
    out
}

/// n! as an `f64` (1 for n = 0).
fn factorial(n: usize) -> f64 {
    (1..=n).fold(1.0, |acc, k| acc * k as f64)
}

/// Rising factorial start · (start+1) ⋯ (start+len−1) as an `f64`
/// (1 for `len == 0`).
fn rising_factorial(start: usize, len: usize) -> f64 {
    (start..start + len).fold(1.0, |acc, k| acc * k as f64)
}
863
864impl<const K: usize> std::ops::Add for Tower4<K> {
865    type Output = Self;
866    fn add(self, o: Self) -> Self {
867        let mut out = self;
868        out.v += o.v;
869        for i in 0..K {
870            out.g[i] += o.g[i];
871            for j in 0..K {
872                out.h[i][j] += o.h[i][j];
873                for k in 0..K {
874                    out.t3[i][j][k] += o.t3[i][j][k];
875                    for l in 0..K {
876                        out.t4[i][j][k][l] += o.t4[i][j][k][l];
877                    }
878                }
879            }
880        }
881        out
882    }
883}
884
885impl<const K: usize> std::ops::Sub for Tower4<K> {
886    type Output = Self;
887    fn sub(self, o: Self) -> Self {
888        self + o.scale(-1.0)
889    }
890}
891
892impl<const K: usize> std::ops::Neg for Tower4<K> {
893    type Output = Self;
894    fn neg(self) -> Self {
895        self.scale(-1.0)
896    }
897}
898
899impl<const K: usize> std::ops::Mul for Tower4<K> {
900    type Output = Self;
901    fn mul(self, o: Self) -> Self {
902        Tower4::mul(&self, &o)
903    }
904}
905
906impl<const K: usize> std::ops::Div for Tower4<K> {
907    type Output = Self;
908    fn div(self, o: Self) -> Self {
909        Tower4::mul(&self, &o.recip())
910    }
911}
912
913impl<const K: usize> std::ops::Add<f64> for Tower4<K> {
914    type Output = Self;
915    fn add(self, c: f64) -> Self {
916        let mut out = self;
917        out.v += c;
918        out
919    }
920}
921
922impl<const K: usize> std::ops::Sub<f64> for Tower4<K> {
923    type Output = Self;
924    fn sub(self, c: f64) -> Self {
925        self + (-c)
926    }
927}
928
929impl<const K: usize> std::ops::Mul<f64> for Tower4<K> {
930    type Output = Self;
931    fn mul(self, c: f64) -> Self {
932        self.scale(c)
933    }
934}
935
936// ── Implicit-function and moving-boundary seams (#932 flex) ──────────
937//
938// The flexible survival marginal-slope row loss is NOT a free composition
939// of the primaries: it threads an IMPLICIT calibration intercept `a(θ)`
940// solving a constraint `F(a, θ) = 0`, and integrates a density over cells
941// whose edges `z_L(θ), z_R(θ)` MOVE with θ through that intercept. Plain
942// `Tower4` Faà di Bruno cannot express either — so the flex tower was the
943// last hand-written one in the codebase, and the genus of #736-class
944// drift bugs (the (g,w0) deviation-cross third was 3× short for exactly
945// this reason). These two combinators close that gap: once the constraint
946// `F` and the integrand/boundaries are themselves towers, the intercept's
947// derivative tower and the integral's derivative tower come out EXACTLY at
948// every order — there is no order left to hand-code and forget.
949
950/// Solve the implicit relation `F(a(θ), θ) ≡ 0` for the intercept tower
951/// `a(θ)` over the `K` primaries θ, given the constraint tower `f` written
952/// over `K + 1` variables (slot `0` is the intercept `a`, slots `1..=K`
953/// are the primaries θ) evaluated at the SOLVED point — i.e. `f.v` is the
954/// constraint residual at `(a₀, θ₀)` (≈ 0 from the production Newton solve)
955/// and `a0` is that solved intercept value.
956///
957/// Returns the `Tower4<K>` whose value is `a0` and whose every derivative
958/// tensor (∂a/∂θ, ∂²a/∂θ², …, ∂⁴a/∂θ⁴) is the exact implicit-function
959/// derivative. This is the mechanical replacement for the hand-coded
960/// `a_u = -f_u/f_a`, `a_uv = -(f_uv + f_au·a_v + f_av·a_u + f_aa·a_u·a_v)/f_a`
961/// recursion (first_full.rs) and its third/fourth-order continuations.
962///
963/// Method: order-by-order substitution. We build `a` incrementally; at each
964/// order `m` the composite `G(θ) = f(a(θ), θ)` has a top-order coefficient
965/// that is linear in `a`'s order-`m` tensor with leading factor `F_a`
966/// (= `f.g[0]`), plus terms in `a`'s lower orders already fixed. Setting the
967/// order-`m` tensor of `a` to cancel the rest of `G`'s order-`m` coefficient
968/// keeps `G ≡ 0` through that order. The substitution `G = f∘(a, θ)` reuses
969/// only the exact [`substitute_intercept`] chain rule, so the recursion is
970/// auditable and exact, not a hand-expanded formula per order.
971///
972/// `f.g[0]` (= ∂F/∂a) must be non-zero — guaranteed by the production
973/// solve's strict monotonicity guard.
974///
975/// The expansion point `a0` must be a genuine root `F(a0, θ0) = 0`: the
976/// substitution recursion below cancels orders 1..=4 of `G = F∘a` but never
977/// touches order 0, so a non-root `a0` would yield the Taylor expansion of
978/// the LEVEL SET `F = F(a0)` through `a0`, not the root curve `F = 0`. This
979/// is guarded explicitly and re-verified by a composed-residual self-check.
980pub fn implicit_solve<const K1: usize, const K: usize>(
981    f: &Tower4<K1>,
982    a0: f64,
983) -> Result<Tower4<K>, String> {
984    assert_eq!(K1, K + 1, "implicit_solve: constraint must carry K+1 vars");
985    let f_a = f.g[0];
986    if f_a == 0.0 || !f_a.is_finite() {
987        return Err(format!(
988            "implicit_solve: ∂F/∂a = {f_a:+.3e} is not invertible"
989        ));
990    }
991    // The expansion point must be a genuine root of F. The single Newton
992    // correction that would move a0 onto the root is |f.v|/|f_a|; require it
993    // to be negligible relative to the natural scale (1 + |a0|). Guarding the
994    // Newton step (rather than f.v directly) makes the criterion invariant to
995    // the magnitude of f_a / the units of F.
996    let root_tol = 1e-9;
997    if !f.v.is_finite() {
998        return Err(format!(
999            "implicit_solve: F(a0, θ0) = {:+.3e} is not finite",
1000            f.v
1001        ));
1002    }
1003    let newton_step = f.v.abs() / f_a.abs();
1004    if newton_step > root_tol * (1.0 + a0.abs()) {
1005        return Err(format!(
1006            "implicit_solve: expansion point a0 = {a0:+.6e} is not a root of F: \
1007             F(a0, θ0) = {:+.3e}, Newton correction {newton_step:+.3e} exceeds \
1008             root_tol {root_tol:.1e} · (1 + |a0|)",
1009            f.v
1010        ));
1011    }
1012    // Start with a = constant a0 (correct through order 0). Then lift each
1013    // order in turn. Because substitute_intercept reads `a`'s order-≤m
1014    // tensors when forming G's order-m coefficient, and the order-m
1015    // coefficient of G depends on a's order-m tensor ONLY through the linear
1016    // F_a·a_m term, a single corrective pass per order is exact.
1017    let mut a = Tower4::<K>::constant(a0);
1018    for order in 1..=4 {
1019        let g = substitute_intercept(f, &a);
1020        // Cancel G's order-`order` coefficient by adjusting a's order-`order`
1021        // tensor: a_m -= G_m / F_a (the F_a·a_m term is the only one carrying
1022        // a's order-m tensor, with unit chain coefficient since slot 0 seeds a
1023        // as a plain variable in the substitution's first-order part).
1024        match order {
1025            1 => {
1026                for i in 0..K {
1027                    a.g[i] -= g.g[i] / f_a;
1028                }
1029            }
1030            2 => {
1031                for i in 0..K {
1032                    for j in 0..K {
1033                        a.h[i][j] -= g.h[i][j] / f_a;
1034                    }
1035                }
1036            }
1037            3 => {
1038                for i in 0..K {
1039                    for j in 0..K {
1040                        for k in 0..K {
1041                            a.t3[i][j][k] -= g.t3[i][j][k] / f_a;
1042                        }
1043                    }
1044                }
1045            }
1046            _ => {
1047                for i in 0..K {
1048                    for j in 0..K {
1049                        for k in 0..K {
1050                            for l in 0..K {
1051                                a.t4[i][j][k][l] -= g.t4[i][j][k][l] / f_a;
1052                            }
1053                        }
1054                    }
1055                }
1056            }
1057        }
1058    }
1059    // Self-check: the composed residual G = F∘a must vanish through order 4.
1060    // By construction orders 1..=4 were cancelled; the value G.v == F(a0,θ0)
1061    // is exactly the root requirement guarded above. Re-verify all channels
1062    // against a scale-aware floor so any arithmetic regression in the
1063    // substitution recursion is loud rather than silently shipping a
1064    // level-set expansion.
1065    let g = substitute_intercept(f, &a);
1066    let resid_tol = 1e-7 * (1.0 + f_a.abs());
1067    let mut worst = g.v.abs();
1068    for i in 0..K {
1069        worst = worst.max(g.g[i].abs());
1070        for j in 0..K {
1071            worst = worst.max(g.h[i][j].abs());
1072            for k in 0..K {
1073                worst = worst.max(g.t3[i][j][k].abs());
1074                for l in 0..K {
1075                    worst = worst.max(g.t4[i][j][k][l].abs());
1076                }
1077            }
1078        }
1079    }
1080    if !worst.is_finite() || worst > resid_tol {
1081        return Err(format!(
1082            "implicit_solve: composed residual G = F∘a does not vanish: \
1083             worst channel magnitude {worst:+.3e} exceeds tol {resid_tol:.1e}"
1084        ));
1085    }
1086    Ok(a)
1087}
1088
1089/// Substitute the intercept tower `a(θ)` into slot `0` of a constraint
1090/// written over `K + 1` variables, returning the composite tower over the
1091/// `K` primaries θ: `G(θ) = f(a(θ), θ₁, …, θ_K)`.
1092///
1093/// This is the exact multivariate chain rule specialised to "slot 0 is a
1094/// dependent tower, slots 1..=K are the independent primaries". It evaluates
1095/// `f`'s fourth-order multivariate Taylor polynomial about the expansion
1096/// point, with the slot-0 increment being the non-constant part of `a` and
1097/// the slot-(i) increment being the unit-seeded primary `θ_i`. The sum is
1098/// assembled by the same subset/partition algebra `Tower4` arithmetic uses,
1099/// so it carries derivatives exactly through order four.
1100pub fn substitute_intercept<const K1: usize, const K: usize>(
1101    f: &Tower4<K1>,
1102    a: &Tower4<K>,
1103) -> Tower4<K> {
1104    assert_eq!(K1, K + 1);
1105    // Build the K+1 input towers in θ-space: slot 0 = a(θ), slot i+1 = θ_i.
1106    // The composite is Σ over ordered label tuples s (|s| ≤ 4) of input
1107    // indices: (1/|s|!) · f.deriv(s) · Π_{j in s} (inp[s_j] centred) — but
1108    // since f.deriv is the SYMMETRIC partial tensor and we enumerate ordered
1109    // tuples, the 1/|s|! exactly cancels the tuple multiplicity. We assemble
1110    // it directly as a Horner-free explicit sum over the (K+1)-ary tuples,
1111    // using tower products for the increment monomials so all θ-derivatives
1112    // propagate exactly.
1113    let inp: [Tower4<K>; K1] = std::array::from_fn(|slot| {
1114        if slot == 0 {
1115            // slot 0: a(θ) minus its constant value (the increment δa(θ)).
1116            let mut d = *a;
1117            d.v = 0.0;
1118            d
1119        } else {
1120            // slot i: the increment δθ_{i-1} = seeded variable minus value.
1121            // θ centred at its expansion value has zero constant term and unit
1122            // first derivative in its own slot.
1123            let mut d = Tower4::<K>::zero();
1124            d.g[slot - 1] = 1.0;
1125            d
1126        }
1127    });
1128    // Accumulate the Taylor sum. order-0 term:
1129    let mut out = Tower4::<K>::constant(f.v);
1130    // order 1: Σ_a f.g[a] · inp[a]
1131    for a_idx in 0..K1 {
1132        out = out + inp[a_idx].scale(f.g[a_idx]);
1133    }
1134    // order 2: (1/2) Σ_{a,b} f.h[a][b] · inp[a]·inp[b]
1135    for a_idx in 0..K1 {
1136        for b_idx in 0..K1 {
1137            let prod = inp[a_idx].mul(&inp[b_idx]);
1138            out = out + prod.scale(0.5 * f.h[a_idx][b_idx]);
1139        }
1140    }
1141    // order 3: (1/6) Σ f.t3[a][b][c] · inp[a]·inp[b]·inp[c]
1142    for a_idx in 0..K1 {
1143        for b_idx in 0..K1 {
1144            for c_idx in 0..K1 {
1145                let prod = inp[a_idx].mul(&inp[b_idx]).mul(&inp[c_idx]);
1146                out = out + prod.scale(f.t3[a_idx][b_idx][c_idx] / 6.0);
1147            }
1148        }
1149    }
1150    // order 4: (1/24) Σ f.t4[a][b][c][d] · inp[a]·inp[b]·inp[c]·inp[d]
1151    for a_idx in 0..K1 {
1152        for b_idx in 0..K1 {
1153            for c_idx in 0..K1 {
1154                for d_idx in 0..K1 {
1155                    let prod = inp[a_idx]
1156                        .mul(&inp[b_idx])
1157                        .mul(&inp[c_idx])
1158                        .mul(&inp[d_idx]);
1159                    out = out + prod.scale(f.t4[a_idx][b_idx][c_idx][d_idx] / 24.0);
1160                }
1161            }
1162        }
1163    }
1164    out
1165}
1166
1167/// The exact θ-derivative tower of a moving-LIMIT integral's BOUNDARY
1168/// contribution: given the edge-position tower `z_edge(θ)` over the `K`
1169/// primaries and the integrand `B` evaluated-and-differentiated at the edge
1170/// value as the stack `b_stack = [B(z₀), B′(z₀), B″(z₀), B‴(z₀)]`
1171/// (`z₀ = z_edge.v`), returns the tower of `Φ(z_edge(θ))` where `Φ′ = B`.
1172///
1173/// Rationale: `∂_θ ∫^{z_edge(θ)} B(z) dz = Φ(z_edge(θ))` with `Φ` an
1174/// antiderivative of `B`, so the boundary part of every θ-derivative of the
1175/// integral is just the composition `Φ ∘ z_edge` — whose Faà di Bruno
1176/// expansion carries, at one stroke, EVERY Leibniz boundary term the
1177/// hand-written flux dropped: the first-order `B·z_u`, the second-order
1178/// `B′·z_u·z_v + B·z_uv` (the `G_z·z_u·z_v` self-flux AND the previously
1179/// dropped `G·z_uv`), and the full third/fourth-order continuations. The
1180/// VALUE channel of the returned tower is meaningless (`Φ` is only defined up
1181/// to a constant); callers read only the derivative channels and pair this
1182/// with the interior moment-integral value separately.
1183///
1184/// `b_stack` holds `B` and its first three z-derivatives; the antiderivative
1185/// `Φ` contributes only as the order-≥1 channels, so `compose_unary` receives
1186/// `[0, B, B′, B″, B‴]` — the leading `0` is the discarded `Φ(z₀)` slot.
1187pub fn moving_limit_boundary_tower<const K: usize>(
1188    z_edge: &Tower4<K>,
1189    b_stack: [f64; 4],
1190) -> Tower4<K> {
1191    z_edge.compose_unary([0.0, b_stack[0], b_stack[1], b_stack[2], b_stack[3]])
1192}
1193
1194/// The boundary-flux derivative tower of a single moving cell integral
1195/// `∫_{z_L(θ)}^{z_R(θ)} B dz`: `Φ(z_R(θ)) − Φ(z_L(θ))`, assembled from the
1196/// two edge towers and the integrand stacks at each edge. The returned
1197/// tower's derivative channels are the EXACT moving-boundary contribution to
1198/// every θ-derivative of the cell integral, to fourth order, with no term
1199/// hand-omitted. A `Fixed` (non-moving) edge passes a `z_edge` whose
1200/// derivative channels are all zero, contributing nothing — matching the
1201/// production `edge_vel = 0` short-circuit.
1202pub fn cell_moving_boundary_flux_tower<const K: usize>(
1203    z_right: &Tower4<K>,
1204    b_stack_right: [f64; 4],
1205    z_left: &Tower4<K>,
1206    b_stack_left: [f64; 4],
1207) -> Tower4<K> {
1208    moving_limit_boundary_tower(z_right, b_stack_right)
1209        - moving_limit_boundary_tower(z_left, b_stack_left)
1210}
1211
1212/// Moving-limit boundary tower for a θ-DEPENDENT integrand `G(z; θ)`.
1213///
1214/// [`moving_limit_boundary_tower`] assumes the integrand depends on θ only
1215/// through the moving edge `z_edge(θ)` (a fixed z-derivative `b_stack`). The
1216/// marginal-slope flex boundary is richer: the integrand `G(z; θ)` ALSO carries
1217/// its own θ-dependence (the density weight `w = e^{−q}/2π` and the cell
1218/// integrand coefficients move with η, hence with the primaries), so the
1219/// Leibniz expansion of `∂ⁿ_θ ∫^{z_edge(θ)} G(z;θ) dz` mixes edge-motion
1220/// derivatives of the limit with θ-derivatives of `G` itself — e.g. at second
1221/// order `G·z_uv + G_z·z_u·z_v + G_{θu}·z_v + G_{θv}·z_u` (the four
1222/// edge-motion-carrying terms the hand path assembles one by one, including the
1223/// `G·z_uv` term the directional path drops).
1224///
1225/// Mechanization: let `Φ(z; θ)` be the z-antiderivative of `G` (so `Φ_z = G`).
1226/// The full upper-limit contribution is `Φ(z_edge(θ); θ)`, and the BOUNDARY
1227/// part — everything carrying edge motion — is exactly
1228///   `Φ(z_edge(θ); θ) − Φ(z₀; θ)`,
1229/// the second term being the pure-integrand-θ part (`∫^{z₀} ∂ⁿ_θ G`) the
1230/// interior moment integral already supplies. Both are one
1231/// [`substitute_intercept`] of the SAME mixed `(z, θ)` jet of `Φ` (z in slot 0,
1232/// θ in slots 1..K): substituting the edge tower gives the full composite,
1233/// substituting a frozen constant edge isolates the pure-θ part, and their
1234/// difference is the exact boundary flux — every Leibniz term derived by the
1235/// substitution algebra, none hand-omitted.
1236///
1237/// `phi_jet` is the `(K+1)`-variable Taylor jet of `Φ` about `(z₀, θ₀)` with
1238/// `z₀ = z_edge.v`: slot 0 is the z-direction (so `phi_jet.g[0] = G(z₀;θ₀)`,
1239/// `phi_jet.h[0][0] = G_z`, …) and slots `1..=K` are the primaries θ (carrying
1240/// `Φ`'s own θ- and mixed z·θ-derivatives — i.e. the integrand's θ-derivatives
1241/// integrated in z, and `G_{θ…}` in the mixed slots). The returned tower's
1242/// VALUE channel is 0 by construction (the `Φ(z₀;θ₀)` constants cancel); only
1243/// the derivative channels are meaningful, matching the value-less convention of
1244/// [`moving_limit_boundary_tower`].
1245pub fn moving_limit_boundary_tower_theta_integrand<const K1: usize, const K: usize>(
1246    phi_jet: &Tower4<K1>,
1247    z_edge: &Tower4<K>,
1248) -> Tower4<K> {
1249    assert_eq!(
1250        K1,
1251        K + 1,
1252        "moving_limit_boundary_tower_theta_integrand: Φ jet must carry z + K θ-vars"
1253    );
1254    let frozen_edge = Tower4::<K>::constant(z_edge.v);
1255    let full = substitute_intercept(phi_jet, z_edge);
1256    let interior = substitute_intercept(phi_jet, &frozen_edge);
1257    full - interior
1258}
1259
1260/// Two-edge cell version of [`moving_limit_boundary_tower_theta_integrand`]:
1261/// the exact boundary-flux tower of `∫_{z_L(θ)}^{z_R(θ)} G(z;θ) dz` with a
1262/// θ-dependent integrand, `Φ(z_R;θ) − Φ(z_L;θ)` minus the pure-θ parts at each
1263/// frozen edge. A `Fixed` edge passes a `z_edge` with zero derivative channels,
1264/// so its `full` and `interior` substitutions coincide and it contributes
1265/// nothing — matching the production `edge_vel = 0` short-circuit.
1266pub fn cell_moving_boundary_flux_tower_theta_integrand<const K1: usize, const K: usize>(
1267    phi_jet_right: &Tower4<K1>,
1268    z_right: &Tower4<K>,
1269    phi_jet_left: &Tower4<K1>,
1270    z_left: &Tower4<K>,
1271) -> Tower4<K> {
1272    moving_limit_boundary_tower_theta_integrand(phi_jet_right, z_right)
1273        - moving_limit_boundary_tower_theta_integrand(phi_jet_left, z_left)
1274}
1275
1276// ── The program seam ─────────────────────────────────────────────────
1277
/// A family's row negative log-likelihood written ONCE over tower scalars.
///
/// This is the single source of truth #932 asks for: the value channel of
/// the returned tower must BE the production row NLL (same branches, same
/// guards, same numerics), and every derivative channel is then exact by
/// construction. The linear Jacobian wiring (coefficients ↔ primaries) is
/// NOT part of this trait — it is family data, not calculus, and stays on
/// the `RowKernel` implementor.
///
/// In this module the trait is consumed by [`evaluate_program`], which seeds
/// `p[a] = Tower4::variable(primaries[a], a)` and calls
/// [`RowNllProgram::row_nll`] once; [`derived_row_kernel`],
/// [`derived_third_contracted`] and [`derived_fourth_contracted`] read their
/// channels from that single tower.
pub trait RowNllProgram<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    ///
    /// NOTE(review): presumably rows are indexed `0..n_rows()`; nothing in the
    /// visible code checks `row` against this bound before `primaries` /
    /// `row_nll` are called — confirm with implementors.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed the tower).
    /// An `Err` is propagated unchanged by [`evaluate_program`].
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on tower scalars. `p[a]` arrives pre-seeded as
    /// variable `a` at the current primary value; implementations combine
    /// them with `Tower4` arithmetic and per-row data (response, censoring
    /// indicators, offsets) entering as constants.
    fn row_nll(&self, row: usize, p: &[Tower4<K>; K]) -> Result<Tower4<K>, String>;
}
1299
1300/// Evaluate a program's full tower at the current primaries for one row.
1301///
1302/// One call yields every `RowKernel` calculus channel; callers that need
1303/// several contractions of the same row should hold the returned tower and
1304/// contract repeatedly rather than re-evaluating.
1305pub fn evaluate_program<const K: usize, P: RowNllProgram<K> + ?Sized>(
1306    prog: &P,
1307    row: usize,
1308) -> Result<Tower4<K>, String> {
1309    let p = prog.primaries(row)?;
1310    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(p[a], a));
1311    prog.row_nll(row, &vars)
1312}
1313
1314/// Mechanically derived `row_kernel` channel: `(nll, ∇, H)`.
1315pub fn derived_row_kernel<const K: usize, P: RowNllProgram<K> + ?Sized>(
1316    prog: &P,
1317    row: usize,
1318) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1319    let t = evaluate_program(prog, row)?;
1320    Ok((t.v, t.g, t.h))
1321}
1322
1323/// Mechanically derived `row_third_contracted` channel.
1324pub fn derived_third_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1325    prog: &P,
1326    row: usize,
1327    dir: &[f64; K],
1328) -> Result<[[f64; K]; K], String> {
1329    Ok(evaluate_program(prog, row)?.third_contracted(dir))
1330}
1331
1332/// Mechanically derived `row_fourth_contracted` channel.
1333pub fn derived_fourth_contracted<const K: usize, P: RowNllProgram<K> + ?Sized>(
1334    prog: &P,
1335    row: usize,
1336    dir_u: &[f64; K],
1337    dir_v: &[f64; K],
1338) -> Result<[[f64; K]; K], String> {
1339    Ok(evaluate_program(prog, row)?.fourth_contracted(dir_u, dir_v))
1340}
1341
1342// ── The generic program seam (#932 scalar cutover) ───────────────────
1343
/// A family's row negative log-likelihood written ONCE over the generic
/// [`crate::jet_scalar::JetScalar`] interface, so the SAME expression can be
/// re-instantiated at whatever order / representation a consumer needs
/// ([`crate::jet_scalar::Order2`] for `(v, g, H)`,
/// [`crate::jet_scalar::OneSeed`] for the contracted third,
/// [`crate::jet_scalar::TwoSeed`] for the contracted fourth, or the full
/// [`Tower4`] for every channel at once).
///
/// This is additive to [`RowNllProgram`] (which is `Tower4`-specialised): a
/// program implementing this generic trait gets the small contracted scalars for
/// free, dissolving the dense-`Tower4<9>` cost objection in the location-scale
/// gates (doc §A.4). An existing `Tower4`-only [`RowNllProgram`] continues to
/// work unchanged; new families should prefer this generic trait.
///
/// Because a `Tower4`-specialised `row_nll` body uses only
/// `add`/`sub`/`mul`/`scale`/`exp`/`ln`/… — all of which this trait also
/// provides — the same body is expressible directly over `S: JetScalar<K>`.
/// A program written that way needs no `Tower4`-specialised method and routes
/// the directional and joint-Hessian gates through the contracted scalars from
/// a single definition.
///
/// In this module the entry points are [`generic_row_kernel`] (`Order2`),
/// [`generic_third_contracted`] (`OneSeed`), [`generic_fourth_contracted`]
/// (`TwoSeed`) and [`generic_full_tower`] (`Tower4`); each one calls
/// `primaries(row)` and then `row_nll_generic(row, …)` exactly once.
pub trait RowNllProgramGeneric<const K: usize>: Send + Sync {
    /// Number of observations the program covers.
    ///
    /// NOTE(review): presumably rows are indexed `0..n_rows()`; the visible
    /// entry points do not check `row` against this bound.
    fn n_rows(&self) -> usize;

    /// Current primary-scalar values for `row` (where to seed the scalar).
    /// An `Err` is propagated unchanged by every `generic_*` entry point.
    fn primaries(&self, row: usize) -> Result<[f64; K], String>;

    /// The row NLL evaluated on a generic jet scalar. `p[a]` arrives pre-seeded
    /// (base value + per-scalar nilpotent directions) by the caller; the body
    /// uses ONLY [`crate::jet_scalar::JetScalar`] ops and per-row data
    /// (response, censoring, offsets) entering as constants.
    fn row_nll_generic<S: crate::jet_scalar::JetScalar<K>>(
        &self,
        row: usize,
        p: &[S; K],
    ) -> Result<S, String>;
}
1381
1382/// Evaluate a generic program at the value/gradient/Hessian scalar
1383/// [`crate::jet_scalar::Order2`], returning `(nll, ∇, H)` — the
1384/// `row_kernel` channel — WITHOUT materialising any third / fourth tensor.
1385///
1386/// This is the production seam for the inner-Newton `(v, g, H)` path: the row
1387/// loss is written ONCE in `row_nll_generic`, and this routes it through the
1388/// cheap order-2 scalar. The single source of truth means the gradient and
1389/// Hessian cannot desync from the value (the #736 / #948 bug genus).
1390pub fn generic_row_kernel<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1391    prog: &P,
1392    row: usize,
1393) -> Result<(f64, [f64; K], [[f64; K]; K]), String> {
1394    let base = prog.primaries(row)?;
1395    let vars: [crate::jet_scalar::Order2<K>; K] = std::array::from_fn(|a| {
1396        <crate::jet_scalar::Order2<K> as crate::jet_scalar::JetScalar<K>>::variable(base[a], a)
1397    });
1398    let s = prog.row_nll_generic(row, &vars)?;
1399    Ok((crate::jet_scalar::JetScalar::value(&s), s.g(), s.h()))
1400}
1401
1402/// Evaluate a generic program at the one-seed scalar
1403/// [`crate::jet_scalar::OneSeed`], returning the contracted third
1404/// `Σ_c ℓ_{abc} dir_c` — the `row_third_contracted(dir)` channel — WITHOUT
1405/// materialising the dense `t3` tensor. The contraction direction is folded
1406/// INTO the differentiation by the nilpotent ε seeded with `dir`.
1407pub fn generic_third_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1408    prog: &P,
1409    row: usize,
1410    dir: &[f64; K],
1411) -> Result<[[f64; K]; K], String> {
1412    let base = prog.primaries(row)?;
1413    let vars: [crate::jet_scalar::OneSeed<K>; K] =
1414        std::array::from_fn(|a| crate::jet_scalar::OneSeed::seed_direction(base[a], a, dir[a]));
1415    let s = prog.row_nll_generic(row, &vars)?;
1416    Ok(s.contracted_third())
1417}
1418
1419/// Evaluate a generic program at the two-seed scalar
1420/// [`crate::jet_scalar::TwoSeed`], returning the contracted fourth
1421/// `Σ_{cd} ℓ_{abcd} u_c v_d` — the `row_fourth_contracted(u, v)` channel —
1422/// WITHOUT materialising the dense `t4` tensor.
1423pub fn generic_fourth_contracted<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1424    prog: &P,
1425    row: usize,
1426    dir_u: &[f64; K],
1427    dir_v: &[f64; K],
1428) -> Result<[[f64; K]; K], String> {
1429    let base = prog.primaries(row)?;
1430    let vars: [crate::jet_scalar::TwoSeed<K>; K] =
1431        std::array::from_fn(|a| crate::jet_scalar::TwoSeed::seed(base[a], a, dir_u[a], dir_v[a]));
1432    let s = prog.row_nll_generic(row, &vars)?;
1433    Ok(s.contracted_fourth())
1434}
1435
1436/// Evaluate a generic program at the full dense [`Tower4`] scalar, returning
1437/// every channel `(v, g, h, t3, t4)` in one pass. Used where the UNCONTRACTED
1438/// third / fourth tensors are needed (the BMS rigid `third_full` / `fourth_full`
1439/// caches): the dense tensors come from the SAME `row_nll_generic` expression
1440/// the order-2 / contracted scalars consume, so there is a single source of
1441/// truth across every channel.
1442pub fn generic_full_tower<const K: usize, P: RowNllProgramGeneric<K> + ?Sized>(
1443    prog: &P,
1444    row: usize,
1445) -> Result<Tower4<K>, String> {
1446    let base = prog.primaries(row)?;
1447    let vars: [Tower4<K>; K] = std::array::from_fn(|a| Tower4::variable(base[a], a));
1448    prog.row_nll_generic(row, &vars)
1449}
1450
1451// ── The oracle ───────────────────────────────────────────────────────
1452
/// One row's worth of hand-written kernel outputs, as claimed by a
/// `RowKernel` implementation, packaged for verification against the
/// tower truth by [`verify_kernel_channels`]. Plain data (no trait coupling)
/// so any kernel — whatever its visibility — can be audited from its own
/// test module.
///
/// Derives `Debug`, `Clone` and `PartialEq` so an audit fixture can be
/// printed in a failing test, reused across rows, and compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct KernelChannels<const K: usize> {
    /// Claimed NLL value (the scalar part of `row_kernel`'s `(nll, ∇, H)`).
    pub value: f64,
    /// Claimed gradient `∇` from `row_kernel`.
    pub gradient: [f64; K],
    /// Claimed Hessian `H` from `row_kernel`.
    pub hessian: [[f64; K]; K],
    /// Claimed `row_third_contracted(dir)` outputs as `(dir, claim)` pairs.
    pub third: Vec<([f64; K], [[f64; K]; K])>,
    /// Claimed `row_fourth_contracted(u, v)` outputs as `(u, v, claim)`.
    pub fourth: Vec<([f64; K], [f64; K], [[f64; K]; K])>,
}
1469
1470/// Channel-by-channel audit of a hand-written kernel against the
1471/// single-expression tower truth. Returns `Err` naming the first channel,
1472/// index, claimed and true values on disagreement — designed as the body
1473/// of the per-family CI oracle tests (#932 deployment step 2).
1474///
1475/// Tolerance is PER ENTRY, mixed absolute/relative: each comparison uses
1476/// `|claim − truth| ≤ atol + rel_tol · max(|claim|, |truth|)`. The absolute
1477/// floor `atol = rel_tol` lets exact-zero entries of structurally sparse
1478/// towers pass without demanding bit-equality, while a tiny cross-block
1479/// entry dropped next to a huge one is still caught (it is NOT measured
1480/// against the largest entry of the whole channel — there is no per-channel
1481/// magnitude floor). Genuine sign flips (#736) and dropped channels are loud.
1482///
1483/// Non-finite handling is strict: a NaN on either side always fails; an
1484/// infinity passes only when both sides are the SAME signed infinity.
1485pub fn verify_kernel_channels<const K: usize>(
1486    tower: &Tower4<K>,
1487    claims: &KernelChannels<K>,
1488    rel_tol: f64,
1489) -> Result<(), String> {
1490    // Absolute floor: reuse rel_tol so a single knob controls both the
1491    // relative band and the absolute floor for entries near zero.
1492    let atol = rel_tol;
1493    let check = |label: &str, claim: f64, truth: f64| -> Result<(), String> {
1494        // Non-finite values never silently pass the algebraic comparison
1495        // below (any comparison with NaN is false). Handle them explicitly:
1496        // NaN on either side always errs; an infinity passes only if both
1497        // sides are the identical signed infinity.
1498        if !claim.is_finite() || !truth.is_finite() {
1499            let agree = claim.is_infinite()
1500                && truth.is_infinite()
1501                && claim.is_sign_positive() == truth.is_sign_positive();
1502            if agree {
1503                return Ok(());
1504            }
1505            return Err(format!(
1506                "row-kernel oracle: {label} non-finite mismatch: claimed {claim:+.12e}, tower {truth:+.12e}"
1507            ));
1508        }
1509        let band = atol + rel_tol * claim.abs().max(truth.abs());
1510        if (claim - truth).abs() > band {
1511            return Err(format!(
1512                "row-kernel oracle: {label} disagrees: claimed {claim:+.12e}, tower {truth:+.12e} (rel_tol {rel_tol:.1e}, atol {atol:.1e}, band {band:.3e})"
1513            ));
1514        }
1515        Ok(())
1516    };
1517
1518    check("value", claims.value, tower.v)?;
1519
1520    for a in 0..K {
1521        check(&format!("gradient[{a}]"), claims.gradient[a], tower.g[a])?;
1522    }
1523
1524    for a in 0..K {
1525        for b in 0..K {
1526            check(
1527                &format!("hessian[{a}][{b}]"),
1528                claims.hessian[a][b],
1529                tower.h[a][b],
1530            )?;
1531        }
1532    }
1533
1534    for (t_idx, (dir, claim)) in claims.third.iter().enumerate() {
1535        let truth = tower.third_contracted(dir);
1536        for a in 0..K {
1537            for b in 0..K {
1538                check(
1539                    &format!("third[{t_idx}][{a}][{b}]"),
1540                    claim[a][b],
1541                    truth[a][b],
1542                )?;
1543            }
1544        }
1545    }
1546
1547    for (f_idx, (u, w, claim)) in claims.fourth.iter().enumerate() {
1548        let truth = tower.fourth_contracted(u, w);
1549        for a in 0..K {
1550            for b in 0..K {
1551                check(
1552                    &format!("fourth[{f_idx}][{a}][{b}]"),
1553                    claim[a][b],
1554                    truth[a][b],
1555                )?;
1556            }
1557        }
1558    }
1559
1560    Ok(())
1561}
1562
1563#[cfg(test)]
1564mod tests {
1565    use super::*;
1566
1567    /// `Tower3<K>` must be bit-identical to `Tower4<K>` on every channel it
1568    /// carries (value, gradient, Hessian, third derivatives). The order-≤3
1569    /// Leibniz / Faà-di-Bruno terms read only order-≤3 inner channels, so
1570    /// dropping the fourth tensor cannot perturb them. Exercises products
1571    /// (Leibniz cross-terms), unary composition, scaling, and addition — the
1572    /// same operations the survival location-scale `nll_index_tower` composes —
1573    /// across all mixed partials, not just the diagonal entries that kernel reads.
1574    #[test]
1575    fn tower3_matches_tower4_through_third_order() {
1576        let s_a: [f64; 5] = [
1577            0.3_f64.sin(),
1578            0.3_f64.cos(),
1579            -0.3_f64.sin(),
1580            -0.3_f64.cos(),
1581            0.3_f64.sin(),
1582        ];
1583        let s_b: [f64; 5] = [1.1, -0.4, 0.8, -0.2, 0.05];
1584        let s4 = |s: [f64; 5]| [s[0], s[1], s[2], s[3]];
1585
1586        let a4 = Tower4::<3>::variable(0.4, 0);
1587        let b4 = Tower4::<3>::variable(-0.7, 1);
1588        let c4 = Tower4::<3>::variable(0.9, 2);
1589        let prog4 = (a4.mul(&b4) + c4).compose_unary(s_a).scale(1.3)
1590            + a4.mul(&c4).scale(-0.7)
1591            + b4.compose_unary(s_b).scale(0.25);
1592
1593        let a3 = Tower3::<3>::variable(0.4, 0);
1594        let b3 = Tower3::<3>::variable(-0.7, 1);
1595        let c3 = Tower3::<3>::variable(0.9, 2);
1596        let prog3 = (a3.mul(&b3) + c3).compose_unary(s4(s_a)).scale(1.3)
1597            + a3.mul(&c3).scale(-0.7)
1598            + b3.compose_unary(s4(s_b)).scale(0.25);
1599
1600        assert_eq!(prog3.v.to_bits(), prog4.v.to_bits(), "value mismatch");
1601        for i in 0..3 {
1602            assert_eq!(
1603                prog3.g[i].to_bits(),
1604                prog4.g[i].to_bits(),
1605                "g[{i}] mismatch"
1606            );
1607            for j in 0..3 {
1608                assert_eq!(
1609                    prog3.h[i][j].to_bits(),
1610                    prog4.h[i][j].to_bits(),
1611                    "h[{i}][{j}] mismatch"
1612                );
1613                for k in 0..3 {
1614                    assert_eq!(
1615                        prog3.t3[i][j][k].to_bits(),
1616                        prog4.t3[i][j][k].to_bits(),
1617                        "t3[{i}][{j}][{k}] mismatch"
1618                    );
1619                }
1620            }
1621        }
1622    }
1623
1624    /// Binomial-logit row NLL, K=1: ℓ(η) = ln(1 + e^η) − y·η.
1625    /// The entire tower has textbook closed forms in μ = σ(η); this test
1626    /// pins the algebra (exp, ln, scalar mixes, Leibniz/Faà di Bruno) to
1627    /// analytic truth at near-machine precision.
1628    struct LogitProgram {
1629        eta: Vec<f64>,
1630        y: Vec<f64>,
1631    }
1632
1633    impl RowNllProgram<1> for LogitProgram {
1634        fn n_rows(&self) -> usize {
1635            self.eta.len()
1636        }
1637        fn primaries(&self, row: usize) -> Result<[f64; 1], String> {
1638            Ok([self.eta[row]])
1639        }
1640        fn row_nll(&self, row: usize, p: &[Tower4<1>; 1]) -> Result<Tower4<1>, String> {
1641            let eta = p[0];
1642            Ok((eta.exp() + 1.0).ln() - eta * self.y[row])
1643        }
1644    }
1645
1646    #[test]
1647    fn logit_tower_matches_closed_forms() {
1648        let prog = LogitProgram {
1649            eta: vec![-2.3, -0.4, 0.0, 0.9, 3.1],
1650            y: vec![1.0, 0.0, 1.0, 0.0, 1.0],
1651        };
1652        for row in 0..prog.n_rows() {
1653            let t = evaluate_program(&prog, row).expect("logit program");
1654            let eta = prog.eta[row];
1655            let y = prog.y[row];
1656            let mu = 1.0 / (1.0 + (-eta).exp());
1657            let w = mu * (1.0 - mu);
1658            let expect = [
1659                (t.v, (1.0 + eta.exp()).ln() - y * eta, "value"),
1660                (t.g[0], mu - y, "grad"),
1661                (t.h[0][0], w, "hess"),
1662                (t.t3[0][0][0], w * (1.0 - 2.0 * mu), "third"),
1663                (
1664                    t.t4[0][0][0][0],
1665                    w * (1.0 - 6.0 * mu + 6.0 * mu * mu),
1666                    "fourth",
1667                ),
1668            ];
1669            for (got, want, label) in expect {
1670                assert!(
1671                    (got - want).abs() <= 1e-12 * want.abs().max(1.0),
1672                    "row {row} {label}: got {got:+.15e} want {want:+.15e}"
1673                );
1674            }
1675        }
1676    }
1677
1678    fn assert_close(label: &str, got: f64, want: f64, rel_tol: f64) {
1679        let diff = (got - want).abs();
1680        assert!(
1681            diff <= rel_tol * want.abs().max(1.0),
1682            "{label}: got {got:+.17e} want {want:+.17e} diff {diff:.3e}"
1683        );
1684    }
1685
1686    #[test]
1687    fn gamma_special_function_stacks_match_reference_values() {
1688        const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;
1689        let pi_sq = std::f64::consts::PI * std::f64::consts::PI;
1690        let cases = [
1691            (
1692                "x=0.1",
1693                0.1,
1694                -10.423_754_940_411_076,
1695                101.433_299_150_792_75,
1696            ),
1697            (
1698                "x=0.5",
1699                0.5,
1700                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2,
1701                pi_sq / 2.0,
1702            ),
1703            ("x=1", 1.0, -EULER_GAMMA, pi_sq / 6.0),
1704            (
1705                "x=2.5",
1706                2.5,
1707                -EULER_GAMMA - 2.0 * std::f64::consts::LN_2 + 2.0 + 2.0 / 3.0,
1708                pi_sq / 2.0 - 4.0 - 4.0 / 9.0,
1709            ),
1710            (
1711                "x=50",
1712                50.0,
1713                3.901_989_673_427_892,
1714                0.020_201_333_226_697_128,
1715            ),
1716        ];
1717
1718        for (label, x, digamma_ref, trigamma_ref) in cases {
1719            let ln_gamma_stack = ln_gamma_derivative_stack(x);
1720            let digamma_stack = digamma_derivative_stack(x);
1721            let trigamma_stack = trigamma_derivative_stack(x);
1722            assert_close(
1723                &format!("{label} ln_gamma_stack digamma"),
1724                ln_gamma_stack[1],
1725                digamma_ref,
1726                1e-13,
1727            );
1728            assert_close(
1729                &format!("{label} digamma value"),
1730                digamma_stack[0],
1731                digamma_ref,
1732                1e-13,
1733            );
1734            assert_close(
1735                &format!("{label} ln_gamma_stack trigamma"),
1736                ln_gamma_stack[2],
1737                trigamma_ref,
1738                1e-13,
1739            );
1740            assert_close(
1741                &format!("{label} digamma_stack trigamma"),
1742                digamma_stack[1],
1743                trigamma_ref,
1744                1e-13,
1745            );
1746            assert_close(
1747                &format!("{label} trigamma value"),
1748                trigamma_stack[0],
1749                trigamma_ref,
1750                1e-13,
1751            );
1752        }
1753    }
1754
1755    #[test]
1756    fn gamma_special_function_stacks_obey_recurrences() {
1757        for x in [0.1, 0.5, 1.0, 2.5, 50.0] {
1758            let digamma_x = digamma_derivative_stack(x)[0];
1759            let digamma_next = digamma_derivative_stack(x + 1.0)[0];
1760            let trigamma_x = trigamma_derivative_stack(x)[0];
1761            let trigamma_next = trigamma_derivative_stack(x + 1.0)[0];
1762            assert_close(
1763                &format!("digamma recurrence x={x}"),
1764                digamma_next,
1765                digamma_x + 1.0 / x,
1766                1e-13,
1767            );
1768            assert_close(
1769                &format!("trigamma recurrence x={x}"),
1770                trigamma_next,
1771                trigamma_x - 1.0 / (x * x),
1772                1e-13,
1773            );
1774        }
1775    }
1776
1777    /// Gaussian location-scale row NLL, K=2 primaries (η, s = log σ):
1778    /// ℓ = s + ½ e^{−2s} (y − η)². Mixed cross blocks — the #736 fragility
1779    /// shape — all have one-line closed forms here.
1780    struct LocScaleProgram {
1781        eta: Vec<f64>,
1782        s: Vec<f64>,
1783        y: Vec<f64>,
1784    }
1785
1786    impl RowNllProgram<2> for LocScaleProgram {
1787        fn n_rows(&self) -> usize {
1788            self.eta.len()
1789        }
1790        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
1791            Ok([self.eta[row], self.s[row]])
1792        }
1793        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
1794            let r = -(p[0] - self.y[row]);
1795            Ok(p[1] + (p[1] * (-2.0)).exp() * r * r * 0.5)
1796        }
1797    }
1798
1799    #[test]
1800    fn locscale_tower_matches_closed_forms_including_cross_blocks() {
1801        let prog = LocScaleProgram {
1802            eta: vec![0.3, -1.1, 2.0],
1803            s: vec![-0.5, 0.2, 0.8],
1804            y: vec![1.0, -2.0, 2.5],
1805        };
1806        let tol = 1e-12;
1807        for row in 0..prog.n_rows() {
1808            let t = evaluate_program(&prog, row).expect("locscale program");
1809            let r = prog.y[row] - prog.eta[row];
1810            let w = (-2.0 * prog.s[row]).exp();
1811            // (η, s) = indices (0, 1).
1812            let truth_g = [-w * r, 1.0 - w * r * r];
1813            let truth_h = [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]];
1814            // Third tensor: distinct-entry closed forms.
1815            // ∂ηηη = 0, ∂ηηs = −2w, ∂ηss = −4wr, ∂sss = −4wr².
1816            let t3_truth = |a: usize, b: usize, c: usize| -> f64 {
1817                match a + b + c {
1818                    0 => 0.0,
1819                    1 => -2.0 * w,
1820                    2 => -4.0 * w * r,
1821                    _ => -4.0 * w * r * r,
1822                }
1823            };
1824            // Fourth tensor: ∂ηηηη = 0, ∂ηηηs = 0? No: d/ds(∂ηηη)=0 ✓;
1825            // ∂ηηss = 4w, ∂ηsss = 8wr, ∂ssss = 8wr².
1826            let t4_truth = |a: usize, b: usize, c: usize, d: usize| -> f64 {
1827                match a + b + c + d {
1828                    0 | 1 => 0.0,
1829                    2 => 4.0 * w,
1830                    3 => 8.0 * w * r,
1831                    _ => 8.0 * w * r * r,
1832                }
1833            };
1834            for a in 0..2 {
1835                assert!(
1836                    (t.g[a] - truth_g[a]).abs() <= tol * truth_g[a].abs().max(1.0),
1837                    "row {row} grad[{a}]"
1838                );
1839                for b in 0..2 {
1840                    assert!(
1841                        (t.h[a][b] - truth_h[a][b]).abs() <= tol * w.max(1.0) * (1.0 + r.abs()),
1842                        "row {row} hess[{a}][{b}]: got {} want {}",
1843                        t.h[a][b],
1844                        truth_h[a][b]
1845                    );
1846                    for c in 0..2 {
1847                        assert!(
1848                            (t.t3[a][b][c] - t3_truth(a, b, c)).abs()
1849                                <= tol * 8.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
1850                            "row {row} t3[{a}][{b}][{c}]: got {} want {}",
1851                            t.t3[a][b][c],
1852                            t3_truth(a, b, c)
1853                        );
1854                        for d in 0..2 {
1855                            assert!(
1856                                (t.t4[a][b][c][d] - t4_truth(a, b, c, d)).abs()
1857                                    <= tol * 16.0 * w.max(1.0) * (1.0 + r.abs() + r * r),
1858                                "row {row} t4[{a}][{b}][{c}][{d}]: got {} want {}",
1859                                t.t4[a][b][c][d],
1860                                t4_truth(a, b, c, d)
1861                            );
1862                        }
1863                    }
1864                }
1865            }
1866            // The derived trait-surface helpers agree with direct contraction.
1867            let dir = [0.7, -1.3];
1868            let third = derived_third_contracted(&prog, row, &dir).expect("third");
1869            for a in 0..2 {
1870                for b in 0..2 {
1871                    let want = t.t3[a][b][0] * dir[0] + t.t3[a][b][1] * dir[1];
1872                    assert!((third[a][b] - want).abs() <= 1e-13 * want.abs().max(1.0));
1873                }
1874            }
1875        }
1876    }
1877
1878    /// FD cross-check on a deliberately gnarly composition (div, sqrt,
1879    /// powf, nested exp/ln) in K=3, where no closed form is consulted:
1880    /// every tower channel is checked against central finite differences
1881    /// of the channel one order below — value→grad, grad→hess, hess→t3,
1882    /// t3→t4 — so each order is independently anchored.
1883    ///
1884    /// The program carries a per-row primary fixture plus a per-row offset
1885    /// `tau[row]` that enters the loss as a constant, so `row` genuinely
1886    /// drives both the seed point and the evaluated expression.
1887    struct GnarlyProgram {
1888        primaries: Vec<[f64; 3]>,
1889        tau: Vec<f64>,
1890    }
1891
1892    impl GnarlyProgram {
1893        fn fixture() -> Self {
1894            Self {
1895                primaries: vec![[0.4, -0.7, 1.2], [-0.9, 0.6, 0.3], [1.1, -0.2, -0.8]],
1896                tau: vec![0.15, -0.35, 0.5],
1897            }
1898        }
1899    }
1900
1901    impl RowNllProgram<3> for GnarlyProgram {
1902        fn n_rows(&self) -> usize {
1903            self.primaries.len()
1904        }
1905        fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
1906            self.primaries
1907                .get(row)
1908                .copied()
1909                .ok_or_else(|| format!("gnarly: row {row} out of range"))
1910        }
1911        fn row_nll(&self, row: usize, p: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
1912            let tau = *self
1913                .tau
1914                .get(row)
1915                .ok_or_else(|| format!("gnarly: tau row {row} out of range"))?;
1916            let a = (p[0] * p[1]).exp();
1917            let b = (p[2] * p[2] + 1.0).sqrt();
1918            let c = (a + b + tau).ln();
1919            let d = (p[1] * 0.5 + 2.0).powf(1.7);
1920            Ok(c / d + (p[0] - p[2]) * (p[0] - p[2]) * 0.25)
1921        }
1922    }
1923
1924    /// Evaluate the gnarly program's tower at an ARBITRARY seed point for
1925    /// `row` (used to drive central differences off the fixture grid),
1926    /// while keeping `row`'s per-row data (`tau`) in the loss.
1927    fn gnarly_tower_at(prog: &GnarlyProgram, row: usize, p: [f64; 3]) -> Tower4<3> {
1928        struct At<'a> {
1929            base: &'a GnarlyProgram,
1930            row: usize,
1931            p: [f64; 3],
1932        }
1933        impl RowNllProgram<3> for At<'_> {
1934            fn n_rows(&self) -> usize {
1935                1
1936            }
1937            fn primaries(&self, row: usize) -> Result<[f64; 3], String> {
1938                if row != 0 {
1939                    return Err(format!("gnarly-at: row {row} out of range"));
1940                }
1941                Ok(self.p)
1942            }
1943            fn row_nll(&self, eval_row: usize, vars: &[Tower4<3>; 3]) -> Result<Tower4<3>, String> {
1944                if eval_row != 0 {
1945                    return Err(format!("gnarly-at: eval row {eval_row} out of range"));
1946                }
1947                self.base.row_nll(self.row, vars)
1948            }
1949        }
1950        evaluate_program(&At { base: prog, row, p }, 0).expect("gnarly tower")
1951    }
1952
1953    #[test]
1954    fn gnarly_tower_is_fd_consistent_order_by_order() {
1955        let prog = GnarlyProgram::fixture();
1956        for row in 0..prog.n_rows() {
1957            let base = prog.primaries(row).expect("primaries");
1958            let t = gnarly_tower_at(&prog, row, base);
1959            let h_step = 1e-5;
1960            let tol = 1e-6;
1961            for c in 0..3 {
1962                let mut up = base;
1963                let mut dn = base;
1964                up[c] += h_step;
1965                dn[c] -= h_step;
1966                let t_up = gnarly_tower_at(&prog, row, up);
1967                let t_dn = gnarly_tower_at(&prog, row, dn);
1968                // value → gradient.
1969                let fd_g = (t_up.v - t_dn.v) / (2.0 * h_step);
1970                assert!(
1971                    (t.g[c] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
1972                    "grad[{c}]: analytic {} fd {}",
1973                    t.g[c],
1974                    fd_g
1975                );
1976                for a in 0..3 {
1977                    // gradient → Hessian.
1978                    let fd_h = (t_up.g[a] - t_dn.g[a]) / (2.0 * h_step);
1979                    assert!(
1980                        (t.h[a][c] - fd_h).abs() <= tol * fd_h.abs().max(1.0),
1981                        "hess[{a}][{c}]: analytic {} fd {}",
1982                        t.h[a][c],
1983                        fd_h
1984                    );
1985                    for b in 0..3 {
1986                        // Hessian → third.
1987                        let fd_t3 = (t_up.h[a][b] - t_dn.h[a][b]) / (2.0 * h_step);
1988                        assert!(
1989                            (t.t3[a][b][c] - fd_t3).abs() <= tol * fd_t3.abs().max(1.0),
1990                            "t3[{a}][{b}][{c}]: analytic {} fd {}",
1991                            t.t3[a][b][c],
1992                            fd_t3
1993                        );
1994                        for d in 0..3 {
1995                            // third → fourth.
1996                            let fd_t4 = (t_up.t3[a][b][d] - t_dn.t3[a][b][d]) / (2.0 * h_step);
1997                            assert!(
1998                                (t.t4[a][b][d][c] - fd_t4).abs() <= tol * fd_t4.abs().max(1.0),
1999                                "t4[{a}][{b}][{d}][{c}]: analytic {} fd {}",
2000                                t.t4[a][b][d][c],
2001                                fd_t4
2002                            );
2003                        }
2004                    }
2005                }
2006            }
2007        }
2008    }
2009
2010    /// `implicit_solve` reproduces the true implicit function `a(θ)` of a
2011    /// constraint `F(a, θ) = 0` to fourth order. The constraint here is the
2012    /// smooth, strictly-`a`-monotone
2013    ///   F(a, θ) = a + θ₀·a² + θ₁·exp(a) − c
2014    /// whose root `a(θ)` is re-solved by scalar Newton at perturbed θ as the
2015    /// independent finite-difference oracle. Mirrors the survival flex
2016    /// calibration solve (one implicit intercept over the primaries) without
2017    /// any survival machinery, so a failure localises to the combinator.
2018    #[test]
2019    fn implicit_solve_matches_scalar_resolve_to_fourth_order() {
2020        const C: f64 = 1.7;
2021        // The scalar constraint as a plain f64 closure (the production root
2022        // finder analogue) and its tower form in (a, θ₀, θ₁).
2023        let f_scalar = |a: f64, th: [f64; 2]| a + th[0] * a * a + th[1] * a.exp() - C;
2024        let f_da = |a: f64, th: [f64; 2]| 1.0 + 2.0 * th[0] * a + th[1] * a.exp();
2025        let solve = |th: [f64; 2]| -> f64 {
2026            let mut a = 0.0_f64;
2027            for _ in 0..100 {
2028                let r = f_scalar(a, th);
2029                if r.abs() < 1e-14 {
2030                    break;
2031                }
2032                a -= r / f_da(a, th);
2033            }
2034            a
2035        };
2036        // Tower constraint over K1 = 3 vars: slot 0 = a, slots 1,2 = θ₀, θ₁.
2037        let f_tower = |a0: f64, th: [f64; 2]| -> Tower4<3> {
2038            let a = Tower4::<3>::variable(a0, 0);
2039            let t0 = Tower4::<3>::variable(th[0], 1);
2040            let t1 = Tower4::<3>::variable(th[1], 2);
2041            a + t0 * a.mul(&a) + t1 * a.exp() - C
2042        };
2043
2044        let th0 = [0.35, 0.2];
2045        let a0 = solve(th0);
2046        let f = f_tower(a0, th0);
2047        // Residual at the solved point is ~0 (the combinator tolerates the
2048        // production Newton residual; here it is machine-zero).
2049        assert!(f.v.abs() < 1e-12, "constraint residual {:+.3e}", f.v);
2050        let a_tower: Tower4<2> = implicit_solve::<3, 2>(&f, a0).expect("implicit solve");
2051
2052        // FD oracle: central differences of the scalar re-solve. Each order is
2053        // built from the previous via one more central difference, exactly the
2054        // gnarly order-by-order ladder.
2055        let h = 1e-4;
2056        let tol = 1e-5;
2057        let re = |th: [f64; 2]| solve(th);
2058        for i in 0..2 {
2059            let mut up = th0;
2060            let mut dn = th0;
2061            up[i] += h;
2062            dn[i] -= h;
2063            let fd_g = (re(up) - re(dn)) / (2.0 * h);
2064            assert!(
2065                (a_tower.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
2066                "a_θ[{i}]: analytic {:+.6e} fd {:+.6e}",
2067                a_tower.g[i],
2068                fd_g
2069            );
2070            // second order: FD of the analytic gradient component would re-use
2071            // the combinator; instead difference a SCALAR gradient computed by
2072            // a nested re-solve so the oracle stays production-independent.
2073            let grad_at = |th: [f64; 2], j: usize| -> f64 {
2074                let mut up = th;
2075                let mut dn = th;
2076                up[j] += h;
2077                dn[j] -= h;
2078                (re(up) - re(dn)) / (2.0 * h)
2079            };
2080            for j in 0..2 {
2081                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
2082                assert!(
2083                    (a_tower.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
2084                    "a_θθ[{i}][{j}]: analytic {:+.6e} fd {:+.6e}",
2085                    a_tower.h[i][j],
2086                    fd_h
2087                );
2088            }
2089        }
2090    }
2091
2092    /// `implicit_solve` degenerates to `a_θ = −F_θ / F_a` at first order on a
2093    /// linear-in-a constraint, and the second-order tensor matches the
2094    /// textbook IFT formula `a_uv = −(F_uv + F_au a_v + F_av a_u + F_aa a_u a_v)/F_a`.
2095    /// This pins the recursion against the hand-coded first_full.rs formula it
2096    /// replaces, independent of any FD step.
2097    #[test]
2098    fn implicit_solve_matches_textbook_ift_recursion() {
2099        // A constraint with non-trivial F_a, F_aa, F_au, F_uv all present.
2100        let a0 = 0.4_f64;
2101        let th = [0.25_f64, -0.15_f64];
2102        let f = {
2103            let a = Tower4::<3>::variable(a0, 0);
2104            let t0 = Tower4::<3>::variable(th[0], 1);
2105            let t1 = Tower4::<3>::variable(th[1], 2);
2106            // F = a·(1 + θ₀) + θ₁·a² + θ₀·θ₁ − 0.4385. The constant is chosen so
2107            // F(a0, θ0) = 0 exactly at a0 = 0.4, θ = [0.25, −0.15]:
2108            //   0.4·1.25 + (−0.15)·0.16 + 0.25·(−0.15) = 0.4385.
2109            // implicit_solve requires a genuine root; at the root the level-set
2110            // and root-curve derivatives coincide, so the textbook-IFT
2111            // assertions below are unaffected.
2112            a * (t0 + 1.0) + t1 * a.mul(&a) + t0 * t1 - 0.4385
2113        };
2114        let a_t = implicit_solve::<3, 2>(&f, a0).expect("solve");
2115        let f_a = f.g[0];
2116        // First order: a_u = −F_u / F_a.
2117        for u in 0..2 {
2118            let want = -f.g[u + 1] / f_a;
2119            assert!(
2120                (a_t.g[u] - want).abs() < 1e-12,
2121                "a_u[{u}] {:+.6e} vs −F_u/F_a {:+.6e}",
2122                a_t.g[u],
2123                want
2124            );
2125        }
2126        // Second order textbook IFT (indices shifted by 1 for the a-slot).
2127        for u in 0..2 {
2128            for v in 0..2 {
2129                let f_uv = f.h[u + 1][v + 1];
2130                let f_au = f.h[0][u + 1];
2131                let f_av = f.h[0][v + 1];
2132                let f_aa = f.h[0][0];
2133                let want =
2134                    -(f_uv + f_au * a_t.g[v] + f_av * a_t.g[u] + f_aa * a_t.g[u] * a_t.g[v]) / f_a;
2135                assert!(
2136                    (a_t.h[u][v] - want).abs() < 1e-12,
2137                    "a_uv[{u}][{v}] {:+.6e} vs IFT {:+.6e}",
2138                    a_t.h[u][v],
2139                    want
2140                );
2141            }
2142        }
2143    }
2144
2145    /// The moving-boundary flux tower reproduces every θ-derivative of a
2146    /// moving-limit integral, INCLUDING the second-order `B·z_uv` term the
2147    /// hand-written flux dropped (#932). The edge `z_R(θ) = θ₀ + θ₁²` has a
2148    /// genuinely nonzero `∂²z_R/∂θ₁² = 2`, so a combinator that omitted
2149    /// `B·z_uv` would miss the [1][1] Hessian entry. Truth = central FD of the
2150    /// closed-form integral `∫₀^{z_R} e^{−z²/2} dz = √(π/2)·erf(z_R/√2)`.
2151    #[test]
2152    fn moving_boundary_flux_carries_b_zuv_term() {
2153        use std::f64::consts::PI;
2154        let b = |z: f64| (-0.5 * z * z).exp(); // integrand B(z)
2155        // Antiderivative-based closed-form integral I(z_R) = ∫₀^{z_R} B dz.
2156        let integral = |z_r: f64| (PI / 2.0).sqrt() * libm::erf(z_r / 2.0_f64.sqrt());
2157        let z_r = |th: [f64; 2]| th[0] + th[1] * th[1];
2158        let th0 = [0.7_f64, 0.5_f64];
2159
2160        // Edge tower z_R(θ) over K=2 primaries: value + exact derivatives.
2161        let mut z_edge = Tower4::<2>::constant(z_r(th0));
2162        z_edge.g[0] = 1.0; // ∂z_R/∂θ₀ = 1
2163        z_edge.g[1] = 2.0 * th0[1]; // ∂z_R/∂θ₁ = 2θ₁
2164        z_edge.h[1][1] = 2.0; // ∂²z_R/∂θ₁² = 2  (the z_uv the old flux dropped)
2165
2166        // Integrand stack [B, B′, B″, B‴] at z₀: B′=−z·B, B″=(z²−1)·B,
2167        // B‴=(3z−z³)·B.
2168        let z0 = z_edge.v;
2169        let b0 = b(z0);
2170        let stack = [
2171            b0,
2172            -z0 * b0,
2173            (z0 * z0 - 1.0) * b0,
2174            (3.0 * z0 - z0 * z0 * z0) * b0,
2175        ];
2176        let flux = moving_limit_boundary_tower(&z_edge, stack);
2177
2178        // FD truth of the integral's derivatives.
2179        let h = 1e-4;
2180        let tol = 1e-6;
2181        for i in 0..2 {
2182            let mut up = th0;
2183            let mut dn = th0;
2184            up[i] += h;
2185            dn[i] -= h;
2186            let fd_g = (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h);
2187            assert!(
2188                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
2189                "flux_g[{i}]: analytic {:+.8e} fd {:+.8e}",
2190                flux.g[i],
2191                fd_g
2192            );
2193        }
2194        // The decisive entry: ∂²I/∂θ₁² = B′·(z_θ₁)² + B·z_θ₁θ₁. With z_θ₁=2θ₁=1
2195        // and z_θ₁θ₁=2, the B·z_uv contribution is B(z₀)·2 — omitting it would
2196        // leave the [1][1] entry short by exactly 2·B(z₀).
2197        let grad1_at = |th: [f64; 2]| -> f64 {
2198            let mut up = th;
2199            let mut dn = th;
2200            up[1] += h;
2201            dn[1] -= h;
2202            (integral(z_r(up)) - integral(z_r(dn))) / (2.0 * h)
2203        };
2204        let mut up = th0;
2205        let mut dn = th0;
2206        up[1] += h;
2207        dn[1] -= h;
2208        let fd_h11 = (grad1_at(up) - grad1_at(dn)) / (2.0 * h);
2209        assert!(
2210            (flux.h[1][1] - fd_h11).abs() <= 1e-3 * fd_h11.abs().max(1.0),
2211            "flux_h[1][1] (carries B·z_uv): analytic {:+.8e} fd {:+.8e}",
2212            flux.h[1][1],
2213            fd_h11
2214        );
2215        // Explicit witness that the B·z_uv term is present and material:
2216        // analytic h[1][1] minus the pure (z_u)² part must equal B·z_uv = 2·B₀.
2217        let pure_zu2 = stack[1] * z_edge.g[1] * z_edge.g[1];
2218        let b_zuv = flux.h[1][1] - pure_zu2;
2219        assert!(
2220            (b_zuv - b0 * 2.0).abs() < 1e-10,
2221            "B·z_uv term {:+.8e} != B₀·z_uv {:+.8e}",
2222            b_zuv,
2223            b0 * 2.0
2224        );
2225    }
2226
2227    /// `moving_limit_boundary_tower_theta_integrand` reproduces the marginal-slope
2228    /// flex boundary closure for a θ-DEPENDENT integrand `G(z;θ)` — the case the
2229    /// plain `moving_limit_boundary_tower` cannot express, and the case the
2230    /// survival directional/bidirectional paths hand-assemble term-by-term
2231    /// (`G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`, with the directional path
2232    /// dropping `G·z_uv`). Two independent oracles:
2233    ///   (1) closed-form: the boundary flux of `∫ G dz` is exactly
2234    ///       `Φ(z_edge(θ);θ) − Φ(z₀;θ)` (Φ = z-antiderivative of G), whose θ
2235    ///       derivatives we take by central FD of the closed form — no jet code.
2236    ///   (2) the explicit second-order hand closure, including the `G·z_uv` term,
2237    ///       built from the integrand's own (z,θ) partials.
2238    /// G(z;θ) = exp(z·θ₀) is genuinely θ-dependent (G_θ₀ = z·e^{zθ₀} ≠ 0), and
2239    /// the edge z_edge = z₀ + θ₀ + θ₁² has a real z_uv = ∂²/∂θ₁² = 2, so a
2240    /// combinator that dropped either the integrand-θ terms or `G·z_uv` would
2241    /// miss a Hessian entry.
2242    #[test]
2243    fn moving_boundary_theta_integrand_matches_handpath_and_closed_form() {
2244        // G(z;θ) = exp(z·θ₀);  Φ(z;θ) = ∫₀^z G = (e^{zθ₀} − 1)/θ₀.
2245        let g = |z: f64, t0: f64| (z * t0).exp();
2246        let phi = |z: f64, t0: f64| ((z * t0).exp() - 1.0) / t0;
2247        let z_r = |th: [f64; 2]| 0.6 + th[0] + th[1] * th[1];
2248        let th0 = [0.4_f64, 0.5_f64];
2249        let z0 = z_r(th0);
2250
2251        // Edge tower z_edge(θ) over K=2 primaries.
2252        let mut z_edge = Tower4::<2>::constant(z0);
2253        z_edge.g[0] = 1.0; // ∂z/∂θ₀
2254        z_edge.g[1] = 2.0 * th0[1]; // ∂z/∂θ₁
2255        z_edge.h[1][1] = 2.0; // ∂²z/∂θ₁² (the z_uv the directional path drops)
2256
2257        // Φ's mixed (z, θ) jet over K1 = 3 vars: slot 0 = z, slots 1,2 = θ₀,θ₁.
2258        // Built ONCE in tower arithmetic so every (z^i θ^j) partial is exact.
2259        let z_var = Tower4::<3>::variable(z0, 0);
2260        let t0_var = Tower4::<3>::variable(th0[0], 1);
2261        // θ₁ does not enter G/Φ here, but seed it so the jet carries the full
2262        // K1 frame (its Φ-derivatives are zero; the z_edge chain supplies all θ₁
2263        // motion through slot 0).
2264        let _t1_var = Tower4::<3>::variable(th0[1], 2);
2265        let phi_jet = ((z_var * t0_var).exp() - 1.0) / t0_var;
2266        // Sanity: slot-0 first derivative of Φ IS G(z₀;θ₀).
2267        assert!(
2268            (phi_jet.g[0] - g(z0, th0[0])).abs() < 1e-12,
2269            "Φ_z {:+.8e} != G {:+.8e}",
2270            phi_jet.g[0],
2271            g(z0, th0[0])
2272        );
2273
2274        let flux = moving_limit_boundary_tower_theta_integrand::<3, 2>(&phi_jet, &z_edge);
2275
2276        // Value channel is 0 by construction (boundary, not the integral itself).
2277        assert!(
2278            flux.v.abs() < 1e-12,
2279            "boundary value channel {:+.3e}",
2280            flux.v
2281        );
2282
2283        // Oracle (1): central FD of the closed-form boundary flux
2284        //   Bnd(θ) = Φ(z_edge(θ); θ) − Φ(z₀; θ)   (z₀ FROZEN at the base edge).
2285        let bnd = |th: [f64; 2]| phi(z_r(th), th[0]) - phi(z0, th[0]);
2286        let h = 1e-4;
2287        let tol = 1e-6;
2288        for i in 0..2 {
2289            let mut up = th0;
2290            let mut dn = th0;
2291            up[i] += h;
2292            dn[i] -= h;
2293            let fd_g = (bnd(up) - bnd(dn)) / (2.0 * h);
2294            assert!(
2295                (flux.g[i] - fd_g).abs() <= tol * fd_g.abs().max(1.0),
2296                "boundary_g[{i}] analytic {:+.8e} fd {:+.8e}",
2297                flux.g[i],
2298                fd_g
2299            );
2300        }
2301        let grad_at = |th: [f64; 2], j: usize| -> f64 {
2302            let mut up = th;
2303            let mut dn = th;
2304            up[j] += h;
2305            dn[j] -= h;
2306            (bnd(up) - bnd(dn)) / (2.0 * h)
2307        };
2308        for i in 0..2 {
2309            for j in 0..2 {
2310                let mut up = th0;
2311                let mut dn = th0;
2312                up[i] += h;
2313                dn[i] -= h;
2314                let fd_h = (grad_at(up, j) - grad_at(dn, j)) / (2.0 * h);
2315                assert!(
2316                    (flux.h[i][j] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0),
2317                    "boundary_h[{i}][{j}] analytic {:+.8e} fd {:+.8e}",
2318                    flux.h[i][j],
2319                    fd_h
2320                );
2321            }
2322        }
2323
2324        // Oracle (2): the explicit second-order hand closure, term by term —
2325        // `G·z_uv + G_z·z_u·z_v + G_θu·z_v + G_θv·z_u`. Read G's partials at the
2326        // base point directly (no jet): G = e^{zθ₀}, G_z = θ₀·G, G_θ₀ = z·G,
2327        // G_θ₁ = 0.
2328        let gg = g(z0, th0[0]);
2329        let g_z = th0[0] * gg;
2330        let g_theta = [z0 * gg, 0.0]; // [G_θ₀, G_θ₁]
2331        for i in 0..2 {
2332            for j in 0..2 {
2333                let z_u = z_edge.g[i];
2334                let z_v = z_edge.g[j];
2335                let z_uv = z_edge.h[i][j];
2336                let hand = gg * z_uv + g_z * z_u * z_v + g_theta[i] * z_v + g_theta[j] * z_u;
2337                assert!(
2338                    (flux.h[i][j] - hand).abs() < 1e-9,
2339                    "boundary_h[{i}][{j}] {:+.8e} != hand closure {:+.8e}",
2340                    flux.h[i][j],
2341                    hand
2342                );
2343            }
2344        }
2345
2346        // Decisive: the `G·z_uv` term the directional path DROPS is present and
2347        // material in the [1][1] entry (z_uv = 2 there).
2348        let pure_no_zuv = g_z * z_edge.g[1] * z_edge.g[1] + 2.0 * g_theta[1] * z_edge.g[1];
2349        let g_zuv = flux.h[1][1] - pure_no_zuv;
2350        assert!(
2351            (g_zuv - gg * 2.0).abs() < 1e-9,
2352            "G·z_uv term {:+.8e} != G₀·z_uv {:+.8e}",
2353            g_zuv,
2354            gg * 2.0
2355        );
2356    }
2357
2358    /// The survival crossing-edge position tower `z_edge = (τ − a(θ)) / b`,
2359    /// `b = exp(g)`, built from the intercept tower `a(θ)` (here a stand-in)
2360    /// and the seeded slope `g`, reproduces taylor-jet's exact hand-path
2361    /// boundary-velocity formulas:
2362    ///   z_u   = −(a_u + [u==g]·z) / b
2363    ///   z_uv  = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b
2364    /// This pins the bridge between `implicit_solve` and
2365    /// `cell_moving_boundary_flux_tower`: the boundary jet that the production
2366    /// flex path hand-codes (and dropped `z_uv` from) is exactly `∂²` of this
2367    /// tower. K=3 reduced frame: slot 0 = a-axis carrier (an arbitrary smooth
2368    /// a(θ) with nonzero a_u/a_uv), slot 1 = g (the log-slope), slot 2 unused.
2369    #[test]
2370    fn crossing_edge_tower_matches_handpath_velocity_formulas() {
2371        const TAU: f64 = 1.3; // the link-knot crossing threshold τ
2372        let g_idx = 1usize;
2373        let g0 = 0.85_f64; // the slope value b (the g-primary IS the slope)
2374        // Stand-in intercept tower a(θ): nonzero value, gradient, Hessian in the
2375        // two live axes so a_u and a_uv are both exercised. (In production this
2376        // comes from implicit_solve; here we plant known derivatives.)
2377        let mut a = Tower4::<3>::constant(0.45);
2378        a.g[0] = 0.7;
2379        a.g[1] = -0.3;
2380        a.h[0][0] = 0.25;
2381        a.h[0][1] = 0.11;
2382        a.h[1][0] = 0.11;
2383        a.h[1][1] = -0.08;
2384
2385        // In the survival flex frame the slope `b` IS the g-primary directly
2386        // (the directional code passes `g` as `b`, and ∂z/∂g uses ∂b/∂g = 1):
2387        // z_edge = (τ − a) / b with b seeded as the g-axis variable.
2388        let b = Tower4::<3>::variable(g0, g_idx);
2389        let z_edge = (Tower4::<3>::constant(TAU) - a) / b;
2390
2391        let bv = g0;
2392        let z0 = z_edge.v;
2393        assert!((z0 - (TAU - 0.45) / bv).abs() < 1e-12);
2394
2395        // z_u = −(a_u + [u==g]·z) / b.
2396        for u in 0..2 {
2397            let direct = if u == g_idx { z0 } else { 0.0 };
2398            let want = -(a.g[u] + direct) / bv;
2399            assert!(
2400                (z_edge.g[u] - want).abs() < 1e-10,
2401                "z_u[{u}] {:+.8e} vs hand formula {:+.8e}",
2402                z_edge.g[u],
2403                want
2404            );
2405        }
2406        // z_uv = −(a_uv + [u==g]·z_v + [v==g]·z_u) / b, using the tower's own
2407        // first-order z_v/z_u (already verified above).
2408        for u in 0..2 {
2409            for v in 0..2 {
2410                let cross = if u == g_idx { z_edge.g[v] } else { 0.0 }
2411                    + if v == g_idx { z_edge.g[u] } else { 0.0 };
2412                let want = -(a.h[u][v] + cross) / bv;
2413                assert!(
2414                    (z_edge.h[u][v] - want).abs() < 1e-10,
2415                    "z_uv[{u}][{v}] {:+.8e} vs hand formula {:+.8e}",
2416                    z_edge.h[u][v],
2417                    want
2418                );
2419            }
2420        }
2421    }
2422
2423    /// The crossing-edge tower in the CONSTRAINT frame (intercept `a` and
2424    /// slope `b` BOTH independent — slots 0 and 1) reproduces taylor-jet's
2425    /// FD-certified bare boundary-velocity constants exactly:
2426    ///   z_a  = ∂z/∂a   = −1/b
2427    ///   z_ab = ∂²z/∂a∂b = +1/b²
2428    ///   z_aa = ∂²z/∂a²  = 0
2429    ///   z_bb = ∂²z/∂b²  = +2(τ−a)/b³
2430    /// These are the `f_a`/`f_au`/`f_aa` constraint-jet boundary motions the
2431    /// production base path drops (and only adds in the dir twins, causing the
2432    /// #932 desync). Here `a` is independent (NOT yet substituted with a(θ)),
2433    /// so `z_aa = 0` and there is no `a_uv` chain — `implicit_solve` introduces
2434    /// that later. Pins the constant before the constraint-tower wiring.
2435    #[test]
2436    fn crossing_edge_constraint_frame_matches_bare_velocity_constants() {
2437        const TAU: f64 = 1.3;
2438        let a0 = 0.45_f64;
2439        let b0 = 0.85_f64;
2440        // Slot 0 = a, slot 1 = b, both seeded independent.
2441        let a = Tower4::<2>::variable(a0, 0);
2442        let b = Tower4::<2>::variable(b0, 1);
2443        let z = (Tower4::<2>::constant(TAU) - a) / b;
2444
2445        assert!((z.v - (TAU - a0) / b0).abs() < 1e-12);
2446        assert!((z.g[0] - (-1.0 / b0)).abs() < 1e-12, "z_a {:+.10e}", z.g[0]);
2447        assert!(
2448            (z.h[0][1] - 1.0 / (b0 * b0)).abs() < 1e-12,
2449            "z_ab {:+.10e} vs +1/b² {:+.10e}",
2450            z.h[0][1],
2451            1.0 / (b0 * b0)
2452        );
2453        assert!(
2454            z.h[0][0].abs() < 1e-12,
2455            "z_aa must vanish, got {:+.10e}",
2456            z.h[0][0]
2457        );
2458        let want_zbb = 2.0 * (TAU - a0) / (b0 * b0 * b0);
2459        assert!(
2460            (z.h[1][1] - want_zbb).abs() < 1e-12,
2461            "z_bb {:+.10e} vs 2(τ−a)/b³ {:+.10e}",
2462            z.h[1][1],
2463            want_zbb
2464        );
2465    }
2466
2467    /// The oracle harness catches a planted #736-style sign flip in a
2468    /// cross block and reports the channel by name.
2469    #[test]
2470    fn oracle_catches_planted_cross_block_sign_flip() {
2471        let prog = LocScaleProgram {
2472            eta: vec![0.3],
2473            s: vec![-0.5],
2474            y: vec![1.0],
2475        };
2476        let t = evaluate_program(&prog, 0).expect("tower");
2477        let dir = [0.6, -0.2];
2478        let mut third = t.third_contracted(&dir);
2479        let honest = KernelChannels {
2480            value: t.v,
2481            gradient: t.g,
2482            hessian: t.h,
2483            third: vec![(dir, third)],
2484            fourth: vec![(dir, [1.0, 0.5], t.fourth_contracted(&dir, &[1.0, 0.5]))],
2485        };
2486        verify_kernel_channels(&t, &honest, 1e-10).expect("honest kernel must pass");
2487
2488        // Plant the #736 flip: negate one mixed cross entry.
2489        third[0][1] = -third[0][1];
2490        let flipped = KernelChannels {
2491            value: t.v,
2492            gradient: t.g,
2493            hessian: t.h,
2494            third: vec![(dir, third)],
2495            fourth: vec![],
2496        };
2497        let err = verify_kernel_channels(&t, &flipped, 1e-10)
2498            .expect_err("planted sign flip must be caught");
2499        assert!(
2500            err.contains("third[0][0][1]"),
2501            "oracle must name the flipped channel, got: {err}"
2502        );
2503    }
2504
2505    /// The third- and fourth-order tensors must be FULLY symmetric under
2506    /// index permutation (mixed partials commute). The tower stores them
2507    /// unsymmetrized, so equal-by-construction is a real invariant of the
2508    /// Leibniz/Faà di Bruno writes — a cheap typo tripwire. Asserted on a
2509    /// nontrivial K=3 tower with all of div/sqrt/powf/exp/ln exercised, so
2510    /// every composition path contributes. Lives in a test (not the hot
2511    /// per-op path) on purpose.
2512    #[test]
2513    fn t3_t4_are_fully_index_symmetric() {
2514        let prog = GnarlyProgram::fixture();
2515        // 3! = 6 permutations of three indices.
2516        let perms3: [[usize; 3]; 6] = [
2517            [0, 1, 2],
2518            [0, 2, 1],
2519            [1, 0, 2],
2520            [1, 2, 0],
2521            [2, 0, 1],
2522            [2, 1, 0],
2523        ];
2524        // 4! = 24 permutations of four indices.
2525        let perms4: [[usize; 4]; 24] = [
2526            [0, 1, 2, 3],
2527            [0, 1, 3, 2],
2528            [0, 2, 1, 3],
2529            [0, 2, 3, 1],
2530            [0, 3, 1, 2],
2531            [0, 3, 2, 1],
2532            [1, 0, 2, 3],
2533            [1, 0, 3, 2],
2534            [1, 2, 0, 3],
2535            [1, 2, 3, 0],
2536            [1, 3, 0, 2],
2537            [1, 3, 2, 0],
2538            [2, 0, 1, 3],
2539            [2, 0, 3, 1],
2540            [2, 1, 0, 3],
2541            [2, 1, 3, 0],
2542            [2, 3, 0, 1],
2543            [2, 3, 1, 0],
2544            [3, 0, 1, 2],
2545            [3, 0, 2, 1],
2546            [3, 1, 0, 2],
2547            [3, 1, 2, 0],
2548            [3, 2, 0, 1],
2549            [3, 2, 1, 0],
2550        ];
2551        for row in 0..prog.n_rows() {
2552            let t = evaluate_program(&prog, row).expect("gnarly tower");
2553            let scale_t3 =
2554                t.t3.iter()
2555                    .flatten()
2556                    .flatten()
2557                    .fold(0.0_f64, |m, x| m.max(x.abs()))
2558                    .max(1.0);
2559            let scale_t4 =
2560                t.t4.iter()
2561                    .flatten()
2562                    .flatten()
2563                    .flatten()
2564                    .fold(0.0_f64, |m, x| m.max(x.abs()))
2565                    .max(1.0);
2566            for i in 0..3 {
2567                for j in 0..3 {
2568                    for k in 0..3 {
2569                        let base = t.t3[i][j][k];
2570                        let idx = [i, j, k];
2571                        for p in &perms3 {
2572                            let permed = t.t3[idx[p[0]]][idx[p[1]]][idx[p[2]]];
2573                            assert!(
2574                                (base - permed).abs() <= 1e-12 * scale_t3,
2575                                "row {row}: t3[{i}][{j}][{k}]={base:+.15e} != \
2576                                 permuted {permed:+.15e} under {p:?}"
2577                            );
2578                        }
2579                        for l in 0..3 {
2580                            let base4 = t.t4[i][j][k][l];
2581                            let idx4 = [i, j, k, l];
2582                            for p in &perms4 {
2583                                let permed = t.t4[idx4[p[0]]][idx4[p[1]]][idx4[p[2]]][idx4[p[3]]];
2584                                assert!(
2585                                    (base4 - permed).abs() <= 1e-12 * scale_t4,
2586                                    "row {row}: t4[{i}][{j}][{k}][{l}]={base4:+.15e} != \
2587                                     permuted {permed:+.15e} under {p:?}"
2588                                );
2589                            }
2590                        }
2591                    }
2592                }
2593            }
2594        }
2595    }
2596}
2597
2598#[inline]
2599fn erfcx_nonnegative(x: f64) -> f64 {
2600    if !x.is_finite() {
2601        return if x.is_sign_positive() {
2602            0.0
2603        } else {
2604            f64::INFINITY
2605        };
2606    }
2607    if x <= 0.0 {
2608        return 1.0;
2609    }
2610    if x < 26.0 {
2611        ((x * x).min(700.0)).exp() * statrs::function::erf::erfc(x)
2612    } else {
2613        let inv = 1.0 / x;
2614        let inv2 = inv * inv;
2615        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
2616            + 6.5625 * inv2 * inv2 * inv2 * inv2;
2617        inv * poly / std::f64::consts::PI.sqrt()
2618    }
2619}
2620
/// `log(1 − e^{−a})` for `a ≥ 0`, evaluated without catastrophic
/// cancellation.
///
/// The formulation switches at `a = ln 2`:
/// * For small `a`, `1 − e^{−a}` is computed as `−expm1(−a)`.
/// * For large `a`, the log is taken as `log1p(−e^{−a})`.
///
/// `a = 0` returns `−∞`.
///
/// # Panics
/// Panics when `a` is negative or NaN, since the `a >= 0.0` precondition
/// fails.
#[inline]
fn log1mexp_positive(a: f64) -> f64 {
    assert!(a >= 0.0, "log1mexp_positive requires a >= 0: a={a}");
    // Past the precondition `a` is a non-NaN value ≥ 0, so `a <= 0.0` is
    // exactly the a == ±0 case.
    if a <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if a <= core::f64::consts::LN_2 {
        // Near zero: e^{−a} ≈ 1, so form 1 − e^{−a} via expm1.
        let one_minus_tail = -(-a).exp_m1();
        return one_minus_tail.ln();
    }
    // Far from zero: e^{−a} is small, so log1p keeps full relative accuracy.
    let tail = (-a).exp();
    (-tail).ln_1p()
}
2632
2633#[inline]
2634fn signed_probit_logcdf_and_mills_ratio(x: f64) -> (f64, f64) {
2635    if x == f64::INFINITY {
2636        return (0.0, 0.0);
2637    }
2638    if x == f64::NEG_INFINITY {
2639        return (f64::NEG_INFINITY, f64::INFINITY);
2640    }
2641    if x.is_nan() {
2642        return (f64::NAN, f64::NAN);
2643    }
2644    if x < 0.0 {
2645        let u = -x / std::f64::consts::SQRT_2;
2646        let ex = erfcx_nonnegative(u).max(1e-300);
2647        let log_cdf = -u * u + (0.5 * ex).ln();
2648        let lambda = (2.0 / std::f64::consts::PI).sqrt() / ex;
2649        (log_cdf, lambda)
2650    } else {
2651        let cdf = crate::probability::normal_cdf(x).clamp(1e-300, 1.0);
2652        let lambda = crate::probability::normal_pdf(x) / cdf;
2653        (cdf.ln(), lambda)
2654    }
2655}
2656
2657/// Stable derivative stack for `log Phi(x)` through fourth order.
2658#[inline]
2659pub fn unary_derivatives_normal_logcdf(x: f64) -> [f64; 5] {
2660    let (log_cdf, lambda) = signed_probit_logcdf_and_mills_ratio(x);
2661    let lambda2 = lambda * lambda;
2662    let lambda3 = lambda2 * lambda;
2663    let x2 = x * x;
2664    [
2665        log_cdf,
2666        lambda,
2667        -lambda * (x + lambda),
2668        lambda * (x2 - 1.0 + 3.0 * x * lambda + 2.0 * lambda2),
2669        -lambda
2670            * ((x * x2 - 3.0 * x) + (7.0 * x2 - 4.0) * lambda + 12.0 * x * lambda2 + 6.0 * lambda3),
2671    ]
2672}
2673
2674/// Stable derivative stack for `log(1 - exp(-x))`, `x > 0`, through fourth order.
2675#[inline]
2676pub fn unary_derivatives_log1mexp_positive(x: f64) -> [f64; 5] {
2677    let r = 1.0 / x.exp_m1();
2678    [
2679        log1mexp_positive(x),
2680        r,
2681        -r * (1.0 + r),
2682        r * (1.0 + r) * (1.0 + 2.0 * r),
2683        -r * (1.0 + r) * (1.0 + 6.0 * r + 6.0 * r * r),
2684    ]
2685}