Skip to main content

gam_math/
jet_scalar.rs

1//! Order-specific Taylor-jet SCALAR algebras (#932 cutover, doc §A).
2//!
3//! [`crate::jet_tower::Tower4`] carries the full value/gradient/Hessian/`t3`/`t4`
4//! tensor stack: it answers EVERY channel a [`super::row_kernel::RowKernel`]
5//! consumer can ask for, but at `K = 9` that is a ~50 KiB per-row object whose
6//! by-value copies overflowed the stack and timed out the location-scale fit —
7//! which is exactly why `row_kernel_directional_supported()` /
8//! `row_kernel_joint_hessian_supported()` still `return false`. The cutover does
9//! NOT need the dense `Tower4<9>` per row; it needs, per consumer, only the one
10//! channel that consumer serves:
11//!
12//! | consumer | channel | scalar here | K=9 size |
13//! |---|---|---|---|
14//! | inner Newton / `row_kernel` | `(v, g, H)` | [`Order2`] | 728 B |
15//! | `row_third_contracted(dir)` | `Σ_c ℓ_{abc} dir_c` | [`OneSeed`] | 1.46 KiB |
16//! | `row_fourth_contracted(u, v)` | `Σ_{cd} ℓ_{abcd} u_c v_d` | [`TwoSeed`] | 2.8 KiB |
17//!
18//! Each is built on [`Order2`] (value/grad/Hessian), which is the production
19//! [`crate::jet_tower::Tower2`] re-expressed behind a generic interface: a row
20//! loss written ONCE against [`JetScalar`] re-instantiates at whatever order /
21//! representation a consumer needs, with the contraction folded INTO the
22//! differentiation (the nilpotent ε / δ directions), so `t3` / `t4` are never
23//! materialised. The single source of truth is the same one expression — the
24//! genus of #736 cross-block drift cannot reappear because there is no separate
25//! channel to forget.
26//!
27//! # Why each scalar is exact (doc §A.1–A.3)
28//!
29//! * [`Order2`] is the order-≤2 truncation of the Leibniz / Faà di Bruno rules.
30//!   Those order-2 terms read ONLY the order-≤2 channels of their inputs (see
31//!   [`crate::jet_tower::Tower4::mul`]: `out.h[i][j]` never touches `t3`/`t4`),
32//!   so its `(v, g, H)` is BIT-IDENTICAL to a full `Tower4<K>` — and identical
33//!   to [`crate::jet_tower::Tower2`], over which it is a thin newtype.
34//! * [`OneSeed`] carries an [`Order2`] base plus one nilpotent ε (`ε² = 0`)
35//!   holding another [`Order2`]. Seeding ε with the fixed direction `u` makes the
36//!   ε-component of the Hessian channel the contracted third `Σ_c ℓ_{abc} u_c`
37//!   (the nilpotent implements `d/dτ|₀` of `ℓ_{ab}(p + τu)` exactly).
38//! * [`TwoSeed`] carries an [`Order2`] base plus ε, δ (`ε² = δ² = 0`, `εδ`
39//!   retained) — four [`Order2`] parts. Seeding ε, δ with `u, v` makes the
40//!   εδ-component of the Hessian channel the contracted fourth
41//!   `Σ_{cd} ℓ_{abcd} u_c v_d` (the single mixed `∂_σ∂_ρ|₀` term, no `σ²`/`ρ²`
42//!   contamination).
43//!
44//! # Stability discipline
45//!
46//! As in [`crate::jet_tower`], humans own primitive stability and the algebra
47//! owns combinatorics: tail-critical special functions enter ONLY as
48//! hand-certified `[f64; 5]` derivative stacks through [`JetScalar::compose_unary`]
49//! (each scalar consumes the leading entries its order needs), never by
50//! differentiating an unstable primal.
51//!
//! # Production scalars and the all-channels oracle
53//!
54//! The `JetScalar` trait below is production: it is the bound on
55//! [`crate::jet_tower::RowNllProgramGeneric::row_nll_generic`], the seam a family
56//! row loss is written against. The order-specific scalars that *consume* it —
57//! [`Order2`] (value/grad/Hessian), [`OneSeed`] (contracted third) and
58//! [`TwoSeed`] (contracted fourth) — are production: the survival location-scale
59//! `RowKernel<9>` builds its joint Hessian / directional derivatives through them
60//! (`survival::location_scale::row_kernel`), paying only the small packed scalar
61//! per row instead of the ~50 KiB dense [`crate::jet_tower::Tower4`].
62//!
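//! The [`crate::jet_tower::Tower4`] all-channels `JetScalar` impl is the oracle
//! that pins the contracted scalars against the dense
//! value/grad/Hessian/`t3`/`t4` truth in the `#[cfg(test)]` module. The impl
//! itself is defined at module level (it is not `#[cfg(test)]`-gated), so it is
//! also available to any consumer that needs the uncontracted `t3` / `t4`.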
66
67/// A truncated-Taylor scalar carrying derivatives in `K` primaries.
68///
69/// All concrete scalars here ([`Order2`], [`OneSeed`], [`TwoSeed`]) and the full
70/// [`crate::jet_tower::Tower4`] implement the SAME algebra; only the carried
71/// channel set differs. A row loss written once against this interface yields a
72/// different channel set per instantiation, all exact for the channel they serve
73/// (doc §A.0).
74pub trait JetScalar<const K: usize>: Copy {
75    /// A constant: value `c`, every derivative channel zero.
76    fn constant(c: f64) -> Self;
77
78    /// The seeded variable `p_axis` at value `x`: unit first derivative in slot
79    /// `axis`, all higher channels zero. (The nilpotent / cross channels of the
80    /// directional scalars are seeded zero — callers set ε/δ directions through
81    /// the scalar-specific [`OneSeed::seed_direction`] / [`TwoSeed::seed`].)
82    fn variable(x: f64, axis: usize) -> Self;
83
84    /// The value channel `ℓ(p)`.
85    fn value(&self) -> f64;
86
87    /// Exact truncated Leibniz sum `self + o`.
88    fn add(&self, o: &Self) -> Self;
89    /// Exact truncated Leibniz difference `self − o`.
90    fn sub(&self, o: &Self) -> Self;
91    /// Exact truncated Leibniz product `self · o`.
92    fn mul(&self, o: &Self) -> Self;
93    /// Negate every channel.
94    fn neg(&self) -> Self;
95    /// Multiply every channel by a plain scalar `s`.
96    fn scale(&self, s: f64) -> Self;
97
98    /// Exact multivariate Faà di Bruno composition `f ∘ self`, given the outer
99    /// derivative stack `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` at
100    /// `u = self.value()`.
101    ///
102    /// This is the SAME `[f64; 5]` stack shape [`crate::jet_tower::Tower4`] and
103    /// the families' `unary_derivatives_*` helpers (built on erfcx / log_ndtr)
104    /// already produce, so those stacks plug in directly. Each scalar consumes
105    /// only the leading entries its order needs (order-2 reads `d[0..=2]`; the
106    /// directional scalars read one / two beyond their base) — the fixed-length
107    /// array makes that windowing total, no length guard required.
108    fn compose_unary(&self, d: [f64; 5]) -> Self;
109
110    /// `e^self`. Convenience for tame arguments (see module stability note).
111    fn exp(&self) -> Self {
112        let e = self.value().exp();
113        self.compose_unary([e, e, e, e, e])
114    }
115
116    /// `√self`. Caller guarantees positivity.
117    fn sqrt(&self) -> Self {
118        let u = self.value();
119        let s = u.sqrt();
120        self.compose_unary([
121            s,
122            0.5 / s,
123            -0.25 / (u * s),
124            0.375 / (u * u * s),
125            -0.9375 / (u * u * u * s),
126        ])
127    }
128
129    /// `ln(self)`. Caller guarantees positivity. Same derivative stack
130    /// [`crate::jet_tower::Tower4::ln`] uses, so any program written over both
131    /// matches term-for-term.
132    fn ln(&self) -> Self {
133        let u = self.value();
134        let r = 1.0 / u;
135        self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
136    }
137
138    /// `1/self`.
139    fn recip(&self) -> Self {
140        let r = 1.0 / self.value();
141        let r2 = r * r;
142        self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
143    }
144
145    /// `self^a` for real exponent `a`. Caller guarantees a positive base.
146    /// Mirrors [`crate::jet_tower::Tower4::powf`] (falling-factorial stack).
147    fn powf(&self, a: f64) -> Self {
148        let u = self.value();
149        self.compose_unary([
150            u.powf(a),
151            a * u.powf(a - 1.0),
152            a * (a - 1.0) * u.powf(a - 2.0),
153            a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
154            a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
155        ])
156    }
157
158    /// `ln Γ(self)`. Caller guarantees a positive argument. Uses the SAME
159    /// hand-certified derivative stack [`crate::jet_tower::Tower4::ln_gamma`]
160    /// consumes ([`crate::jet_tower::ln_gamma_derivative_stack`]), so any
161    /// program written over both matches term-for-term.
162    fn ln_gamma(&self) -> Self {
163        self.compose_unary(crate::jet_tower::ln_gamma_derivative_stack(self.value()))
164    }
165
166    /// `ψ(self) = d/dx ln Γ(x)` (digamma). Caller guarantees a positive
167    /// argument. Same hand-certified stack
168    /// [`crate::jet_tower::digamma_derivative_stack`].
169    fn digamma(&self) -> Self {
170        self.compose_unary(crate::jet_tower::digamma_derivative_stack(self.value()))
171    }
172}
173
174// ── Order2<K> ergonomic operator overloads (doc §A.1) ───────────────────
175//
176// The dispersion-family row NLLs are written with `+`/`-`/`*` operators over
177// the primaries (mirroring how they read as `Tower4` expressions). These
178// delegate channel-for-channel to the inner `Tower2` arithmetic (which has
179// `Add`/`Mul`; `Sub`/`Neg` are expressed as `+ (-1)·rhs` exactly as the
180// `JetScalar::sub` / `JetScalar::neg` impls do), so an `Order2` expression is
181// bit-identical to the same `Tower4` expression's order-≤2 channels.
182
183impl<const K: usize> std::ops::Add for Order2<K> {
184    type Output = Self;
185    #[inline]
186    fn add(self, o: Self) -> Self {
187        Order2(self.0 + o.0)
188    }
189}
190
191impl<const K: usize> std::ops::Add<f64> for Order2<K> {
192    type Output = Self;
193    #[inline]
194    fn add(self, c: f64) -> Self {
195        Order2(self.0 + c)
196    }
197}
198
199impl<const K: usize> std::ops::Sub for Order2<K> {
200    type Output = Self;
201    #[inline]
202    fn sub(self, o: Self) -> Self {
203        Order2(self.0 + o.0.scale(-1.0))
204    }
205}
206
207impl<const K: usize> std::ops::Sub<f64> for Order2<K> {
208    type Output = Self;
209    #[inline]
210    fn sub(self, c: f64) -> Self {
211        Order2(self.0 + (-c))
212    }
213}
214
215impl<const K: usize> std::ops::Mul for Order2<K> {
216    type Output = Self;
217    #[inline]
218    fn mul(self, o: Self) -> Self {
219        Order2(crate::jet_tower::Tower2::mul(&self.0, &o.0))
220    }
221}
222
223impl<const K: usize> std::ops::Mul<f64> for Order2<K> {
224    type Output = Self;
225    #[inline]
226    fn mul(self, c: f64) -> Self {
227        Order2(self.0.scale(c))
228    }
229}
230
231impl<const K: usize> std::ops::Neg for Order2<K> {
232    type Output = Self;
233    #[inline]
234    fn neg(self) -> Self {
235        Order2(self.0.scale(-1.0))
236    }
237}
238
239/// Filtered Hensel lift of a SCALAR implicit state `a(θ)` defined by the
240/// constraint `F(a, θ) = 0`, evaluated in ANY [`JetScalar`] algebra `S` (doc
241/// §11, "A generic implicit-lift operator for every production scalar").
242///
243/// This is the perf-respecting alternative to lifting through a dense
244/// `Tower4<K+1>` (which carries the implicit variable as an extra dense axis):
245/// the state `a` lives directly in the consumer's own `K`-primary algebra
246/// `S` — `Order2<K>` for value/gradient/Hessian, `Tower4<K>` for the full
247/// `t3`/`t4` — never paying for an extra variable.
248///
249/// **Method.** Fixed-Jacobian Newton in the nilpotent algebra. By the
250/// filtered-lift theorem (doc §11.1), if `F_a := ∂F/∂a(a₀, θ₀)` is the primal
251/// Jacobian at the base point and `inv_fa = 1/F_a`, then the iteration
252/// `A ← A − inv_fa · F(A, θ)` raises the filtration degree of the residual by
253/// at least one per step: each step kills exactly one graded layer. Starting
254/// from `A = const(a₀)` (whose residual lies in `F¹` because `θ − θ₀ ∈ 𝔫`),
255/// `iters` equal to the algebra's nilpotency order returns the *exact* lifted
256/// jet (`Order2`: 2, `OneSeed`: 3, `Tower4`/`TwoSeed`: 4). The value channel of
257/// `A` never moves — `F(A, θ).value() = F(a₀, θ₀) = 0` at the certified root —
258/// so a caller may precompute every primitive's derivative stack at the fixed
259/// base index once and let the cheap polynomial composition repeat per step.
260///
261/// `f` evaluates the constraint `F(a, θ)` in `S` (capturing the seeded
262/// parameter jets `θ`); `a0` is the certified scalar root `F(a₀, θ₀) ≈ 0`.
263pub fn filtered_implicit_solve_scalar<const K: usize, S: JetScalar<K>>(
264    a0: f64,
265    inv_fa: f64,
266    iters: usize,
267    f: impl Fn(&S) -> S,
268) -> S {
269    let mut a = S::constant(a0);
270    for _ in 0..iters {
271        let residual = f(&a);
272        a = a.sub(&residual.scale(inv_fa));
273    }
274    a
275}
276
277// ── Order2<K>: value / gradient / Hessian (doc §A.1) ────────────────────
278
279/// Truncated SECOND-order scalar: value `v`, gradient `g_a`, Hessian `H_{ab}`.
280///
281/// This is a thin newtype over the production [`crate::jet_tower::Tower2`], so
282/// its `(v, g, H)` channels are obtained by the SAME formulas — and are
283/// therefore bit-identical to both [`crate::jet_tower::Tower2`] and the order-≤2
284/// channels of a full [`crate::jet_tower::Tower4`] (doc §A.1, "Bit-identity with
285/// the full tower"). The wrapper exists only to satisfy the generic
286/// [`JetScalar`] interface (the `compose_unary` / `add` / `sub` / `neg` /
287/// `recip` the trait demands, which `Tower2` does not expose by that shape) —
288/// every channel is delegated to `Tower2` arithmetic unchanged.
289#[derive(Clone, Copy, Debug)]
290pub struct Order2<const K: usize>(pub crate::jet_tower::Tower2<K>);
291
292impl<const K: usize> Order2<K> {
293    /// Read the gradient channel `g_a = ∂ℓ/∂p_a`.
294    #[inline]
295    pub fn g(&self) -> [f64; K] {
296        self.0.g
297    }
298
299    /// Read the Hessian channel.
300    #[inline]
301    pub fn h(&self) -> [[f64; K]; K] {
302        self.0.h
303    }
304}
305
306impl<const K: usize> JetScalar<K> for Order2<K> {
307    fn constant(c: f64) -> Self {
308        Order2(crate::jet_tower::Tower2::constant(c))
309    }
310    fn variable(x: f64, axis: usize) -> Self {
311        Order2(crate::jet_tower::Tower2::variable(x, axis))
312    }
313    fn value(&self) -> f64 {
314        self.0.v
315    }
316    fn add(&self, o: &Self) -> Self {
317        Order2(self.0 + o.0)
318    }
319    fn sub(&self, o: &Self) -> Self {
320        // Tower2 has no Sub op; subtract by adding the negation, matching
321        // Tower4::sub (self + o.scale(-1.0)).
322        Order2(self.0 + o.0.scale(-1.0))
323    }
324    fn mul(&self, o: &Self) -> Self {
325        Order2(crate::jet_tower::Tower2::mul(&self.0, &o.0))
326    }
327    fn neg(&self) -> Self {
328        Order2(self.0.scale(-1.0))
329    }
330    fn scale(&self, s: f64) -> Self {
331        Order2(self.0.scale(s))
332    }
333    fn compose_unary(&self, d: [f64; 5]) -> Self {
334        // Order-≤2 reads only [f, f', f''] of the stack.
335        Order2(self.0.compose_unary([d[0], d[1], d[2]]))
336    }
337}
338
339// ── OneSeed<K>: one-seed directional, contracted third (doc §A.2) ───────
340
341/// One-seed directional scalar: an [`Order2`] base plus ONE nilpotent ε
342/// (`ε² = 0`) whose coefficient is itself an [`Order2`].
343///
344/// A scalar is `s = base + ε·eps`. Arithmetic is the `ε² = 0` truncation of the
345/// product (doc §A.2): the base parts multiply as ordinary [`Order2`] products,
346/// and the ε-coefficient picks up `a.base·b.eps + a.eps·b.base`. Composition
347/// pushes ε through one extra outer derivative.
348///
349/// Seed each primary with [`seed_direction`](Self::seed_direction): the base is
350/// the usual seeded variable (carrying `e_a` for the Hessian channel) and the
351/// ε-coefficient is the FIXED contraction direction `u_a` (a constant). Then the
352/// ε-component of the evaluated Hessian channel is the contracted third
353/// `[eps.h][a][b] = Σ_c ℓ_{abc} u_c` — exactly `row_third_contracted(dir = u)`,
354/// without materialising `t3`.
355#[derive(Clone, Copy, Debug)]
356pub struct OneSeed<const K: usize> {
357    /// The `ε⁰` part: value / gradient / Hessian of `ℓ`.
358    pub base: Order2<K>,
359    /// The `ε¹` part: value / gradient / Hessian of the ε-coefficient. After a
360    /// `seed_direction(u)` evaluation, `eps.h[a][b] = Σ_c ℓ_{abc} u_c`.
361    pub eps: Order2<K>,
362}
363
364impl<const K: usize> OneSeed<K> {
365    /// Seed primary `axis` at value `x` with ε-direction component `u_axis`:
366    /// `p_axis = p_axis⁰ + x-seed + ε·u_axis`, i.e. base = `variable(x, axis)`
367    /// and eps = `constant(u_axis)` (doc §A.2 "Seeding").
368    pub fn seed_direction(x: f64, axis: usize, u_axis: f64) -> Self {
369        OneSeed {
370            base: Order2::variable(x, axis),
371            eps: Order2::constant(u_axis),
372        }
373    }
374
375    /// The contracted-third channel after a `seed_direction(u)` evaluation:
376    /// `out[a][b] = Σ_c ℓ_{abc} u_c`, i.e. the ε-coefficient's Hessian (doc §A.2).
377    pub fn contracted_third(&self) -> [[f64; K]; K] {
378        self.eps.h()
379    }
380}
381
382impl<const K: usize> JetScalar<K> for OneSeed<K> {
383    fn constant(c: f64) -> Self {
384        OneSeed {
385            base: Order2::constant(c),
386            eps: Order2::constant(0.0),
387        }
388    }
389    fn variable(x: f64, axis: usize) -> Self {
390        // No ε-direction unless seeded via `seed_direction`.
391        OneSeed {
392            base: Order2::variable(x, axis),
393            eps: Order2::constant(0.0),
394        }
395    }
396    fn value(&self) -> f64 {
397        self.base.value()
398    }
399    fn add(&self, o: &Self) -> Self {
400        OneSeed {
401            base: self.base.add(&o.base),
402            eps: self.eps.add(&o.eps),
403        }
404    }
405    fn sub(&self, o: &Self) -> Self {
406        OneSeed {
407            base: self.base.sub(&o.base),
408            eps: self.eps.sub(&o.eps),
409        }
410    }
411    fn mul(&self, o: &Self) -> Self {
412        // (a.base + ε a.eps)(b.base + ε b.eps), dropping ε².
413        OneSeed {
414            base: self.base.mul(&o.base),
415            eps: self.base.mul(&o.eps).add(&self.eps.mul(&o.base)),
416        }
417    }
418    fn neg(&self) -> Self {
419        OneSeed {
420            base: self.base.neg(),
421            eps: self.eps.neg(),
422        }
423    }
424    fn scale(&self, s: f64) -> Self {
425        OneSeed {
426            base: self.base.scale(s),
427            eps: self.eps.scale(s),
428        }
429    }
430    fn compose_unary(&self, d: [f64; 5]) -> Self {
431        // f(base + ε eps) = f(base) + ε · f'(base)·eps  (ε² = 0). Each factor is
432        // an Order2 composition: the base composes with the f-stack, and the
433        // ε-coefficient is the Order2 of the SHIFTED stack (the chain rule
434        // `f'(base)` as an Order2) times eps. Order2 reads only the leading
435        // three entries of whatever stack it is handed, so the trailing slots
436        // are unused padding (the fixed-length array makes the windowing total).
437        let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
438        // f'(base) as an Order2 (consumes [f', f'', f''']).
439        let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]);
440        let eps = fprime.mul(&self.eps);
441        OneSeed { base, eps }
442    }
443}
444
445// ── TwoSeed<K>: two-seed, contracted fourth (doc §A.3) ──────────────────
446
447/// Two-seed scalar: an [`Order2`] base plus TWO nilpotents ε, δ
448/// (`ε² = δ² = 0`, `εδ` retained) — four [`Order2`] parts
449/// `s = base + ε·eps + δ·del + εδ·eps_del`.
450///
451/// Product truncates `ε² = δ² = 0` (doc §A.3): each part is built from
452/// [`Order2`] products of the four input parts. Composition picks up
453/// successively higher outer derivatives, the cross part carrying the second
454/// Faà di Bruno term `f''·eps·del + f'·eps_del`.
455///
456/// Seed each primary with [`seed`](Self::seed): base = `variable(x, axis)`,
457/// eps = `constant(u_axis)`, del = `constant(v_axis)`, eps_del = `constant(0)`.
458/// Then the εδ-component of the evaluated Hessian channel is the contracted
459/// fourth `[eps_del.h][a][b] = Σ_{cd} ℓ_{abcd} u_c v_d` — exactly
460/// `row_fourth_contracted(u, v)`, without materialising `t4`.
461#[derive(Clone, Copy, Debug)]
462pub struct TwoSeed<const K: usize> {
463    /// The `ε⁰δ⁰` part: value / grad / Hessian of `ℓ`.
464    pub base: Order2<K>,
465    /// The `ε¹δ⁰` part.
466    pub eps: Order2<K>,
467    /// The `ε⁰δ¹` part.
468    pub del: Order2<K>,
469    /// The `ε¹δ¹` part. After a `seed(u, v)` evaluation,
470    /// `eps_del.h[a][b] = Σ_{cd} ℓ_{abcd} u_c v_d`.
471    pub eps_del: Order2<K>,
472}
473
474impl<const K: usize> TwoSeed<K> {
475    /// Seed primary `axis` at value `x` with ε-direction `u_axis` and
476    /// δ-direction `v_axis`:
477    /// `p_axis = p_axis⁰ + x-seed + ε·u_axis + δ·v_axis` (doc §A.3 "Seeding").
478    pub fn seed(x: f64, axis: usize, u_axis: f64, v_axis: f64) -> Self {
479        TwoSeed {
480            base: Order2::variable(x, axis),
481            eps: Order2::constant(u_axis),
482            del: Order2::constant(v_axis),
483            eps_del: Order2::constant(0.0),
484        }
485    }
486
487    /// The contracted-fourth channel after a `seed(u, v)` evaluation:
488    /// `out[a][b] = Σ_{cd} ℓ_{abcd} u_c v_d`, i.e. the εδ-coefficient's Hessian.
489    pub fn contracted_fourth(&self) -> [[f64; K]; K] {
490        self.eps_del.h()
491    }
492}
493
494impl<const K: usize> JetScalar<K> for TwoSeed<K> {
495    fn constant(c: f64) -> Self {
496        TwoSeed {
497            base: Order2::constant(c),
498            eps: Order2::constant(0.0),
499            del: Order2::constant(0.0),
500            eps_del: Order2::constant(0.0),
501        }
502    }
503    fn variable(x: f64, axis: usize) -> Self {
504        TwoSeed {
505            base: Order2::variable(x, axis),
506            eps: Order2::constant(0.0),
507            del: Order2::constant(0.0),
508            eps_del: Order2::constant(0.0),
509        }
510    }
511    fn value(&self) -> f64 {
512        self.base.value()
513    }
514    fn add(&self, o: &Self) -> Self {
515        TwoSeed {
516            base: self.base.add(&o.base),
517            eps: self.eps.add(&o.eps),
518            del: self.del.add(&o.del),
519            eps_del: self.eps_del.add(&o.eps_del),
520        }
521    }
522    fn sub(&self, o: &Self) -> Self {
523        TwoSeed {
524            base: self.base.sub(&o.base),
525            eps: self.eps.sub(&o.eps),
526            del: self.del.sub(&o.del),
527            eps_del: self.eps_del.sub(&o.eps_del),
528        }
529    }
530    fn mul(&self, o: &Self) -> Self {
531        let a = self;
532        let b = o;
533        // Truncate ε² = δ² = 0 (doc §A.3 product table).
534        let base = a.base.mul(&b.base);
535        let eps = a.base.mul(&b.eps).add(&a.eps.mul(&b.base));
536        let del = a.base.mul(&b.del).add(&a.del.mul(&b.base));
537        let eps_del = a
538            .base
539            .mul(&b.eps_del)
540            .add(&a.eps.mul(&b.del))
541            .add(&a.del.mul(&b.eps))
542            .add(&a.eps_del.mul(&b.base));
543        TwoSeed {
544            base,
545            eps,
546            del,
547            eps_del,
548        }
549    }
550    fn neg(&self) -> Self {
551        TwoSeed {
552            base: self.base.neg(),
553            eps: self.eps.neg(),
554            del: self.del.neg(),
555            eps_del: self.eps_del.neg(),
556        }
557    }
558    fn scale(&self, s: f64) -> Self {
559        TwoSeed {
560            base: self.base.scale(s),
561            eps: self.eps.scale(s),
562            del: self.del.scale(s),
563            eps_del: self.eps_del.scale(s),
564        }
565    }
566    fn compose_unary(&self, d: [f64; 5]) -> Self {
567        // f(s) with s = base + ε eps + δ del + εδ eps_del, ε²=δ²=0:
568        //   f(s) = f(base)
569        //        + ε · f'(base)·eps
570        //        + δ · f'(base)·del
571        //        + εδ · ( f''(base)·eps·del + f'(base)·eps_del ).
572        // Each f^{(r)}(base) is the Order2 composition of base with the stack
573        // shifted r entries (doc §A.3 composition). Order2 reads only the
574        // leading three entries of whatever stack it is handed, so the trailing
575        // padding slots are unused (the fixed-length array makes this total).
576        let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
577        let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]); // f'(base) as Order2
578        let fsecond = self.base.compose_unary([d[2], d[3], d[4], d[4], d[4]]); // f''(base) as Order2
579        let eps = fprime.mul(&self.eps);
580        let del = fprime.mul(&self.del);
581        let eps_del = fsecond
582            .mul(&self.eps)
583            .mul(&self.del)
584            .add(&fprime.mul(&self.eps_del));
585        TwoSeed {
586            base,
587            eps,
588            del,
589            eps_del,
590        }
591    }
592}
593
594// ── Tower4<K>: full dense tower as a JetScalar (the all-channels scalar) ─
595
/// The full dense [`crate::jet_tower::Tower4`] is itself a [`JetScalar`]: it
/// carries EVERY channel, so a row expression written ONCE against [`JetScalar`]
/// can be evaluated at `Tower4` to obtain the full `(v, g, H, t3, t4)` in one
/// pass. This is BOTH the #932 oracle ground truth the packed [`Order2`] /
/// [`OneSeed`] / [`TwoSeed`] scalars are pinned against, AND a production scalar:
/// a family whose uncontracted third / fourth derivative tensors are needed
/// (the BMS rigid `third_full` / `fourth_full` caches) evaluates the SAME
/// generic row-NLL expression at `Tower4` and reads `.t3` / `.t4` off the
/// result — so the dense tensors come from the single source of truth, not a
/// separately hand-written jet. The packed scalars serve the consumers that
/// need only `(v, g, H)` (`Order2`) or one / two contractions
/// (`OneSeed` / `TwoSeed`) without paying for the dense tensors.
///
/// Every method below is a direct forward to the corresponding `Tower4`
/// constructor, operator or method; no channel is computed in this impl.
impl<const K: usize> JetScalar<K> for crate::jet_tower::Tower4<K> {
    // Forwards to `Tower4::constant`.
    fn constant(c: f64) -> Self {
        crate::jet_tower::Tower4::constant(c)
    }
    // Forwards to `Tower4::variable`.
    fn variable(x: f64, axis: usize) -> Self {
        crate::jet_tower::Tower4::variable(x, axis)
    }
    // The value channel is the `v` field.
    fn value(&self) -> f64 {
        self.v
    }
    // Uses the `+` operator on by-value copies of both towers.
    fn add(&self, o: &Self) -> Self {
        *self + *o
    }
    // Uses the `-` operator on by-value copies of both towers.
    fn sub(&self, o: &Self) -> Self {
        *self - *o
    }
    // Forwards to `Tower4::mul`, taking both operands by reference.
    fn mul(&self, o: &Self) -> Self {
        crate::jet_tower::Tower4::mul(self, o)
    }
    // Negation is scaling by −1.
    // NOTE(review): presumably `self.scale` resolves to the same `scale` the
    // method below forwards to — confirm against `jet_tower`.
    fn neg(&self) -> Self {
        self.scale(-1.0)
    }
    // Forwards to `Tower4::scale`.
    fn scale(&self, s: f64) -> Self {
        crate::jet_tower::Tower4::scale(self, s)
    }
    // Forwards the full five-entry stack to `Tower4::compose_unary`.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        crate::jet_tower::Tower4::compose_unary(self, d)
    }
}
637
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jet_tower::{RowNllProgram, Tower4, evaluate_program};

    /// A small polynomial-plus-unary row expression written ONCE, generically
    /// over `S: JetScalar<2>`, so it can be evaluated against every scalar:
    /// `ℓ = (e^{p0·p1} + 2) · √(p0·p0 + 1) − p1·p1·0.5`.
    /// Exercises mul, add, sub, scale, exp and sqrt.
    fn row_expr<S: JetScalar<2>>(p: &[S; 2]) -> S {
        let g = p[0].mul(&p[1]).exp();
        let inner = g.add(&S::constant(2.0));
        let radic = p[0].mul(&p[0]).add(&S::constant(1.0)).sqrt();
        inner.mul(&radic).sub(&p[1].mul(&p[1]).scale(0.5))
    }

    /// The same expression as a Tower4 `RowNllProgram`, the ground-truth tower.
    struct ExprProgram {
        p: [f64; 2],
    }
    impl RowNllProgram<2> for ExprProgram {
        fn n_rows(&self) -> usize {
            1
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            if row >= self.n_rows() {
                return Err(format!("ExprProgram: row {row} out of range"));
            }
            Ok(self.p)
        }
        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
            if row >= self.n_rows() {
                return Err(format!("ExprProgram: row {row} out of range"));
            }
            Ok(row_expr(p))
        }
    }

    const SEED: [f64; 2] = [0.37, -0.81];
    const U: [f64; 2] = [0.6, -0.2];
    const V: [f64; 2] = [-0.4, 1.1];
    const TOL: f64 = 1e-10;

    /// Assert `a ≈ b` within the mixed absolute/relative band
    /// `TOL + TOL·max(|a|, |b|)`. A NaN on either side fails the comparison.
    fn close(a: f64, b: f64, label: &str) {
        let band = TOL + TOL * a.abs().max(b.abs());
        assert!(
            (a - b).abs() <= band,
            "{label}: {a:+.15e} vs {b:+.15e} (band {band:.3e})"
        );
    }

    /// Assert two 2×2 channels agree entry-by-entry. Entries are supplied as
    /// index closures so any `[a][b]`-indexable channel can be compared; the
    /// failure label is `"{label}[{a}][{b}]"`.
    fn close_grid(
        got: impl Fn(usize, usize) -> f64,
        want: impl Fn(usize, usize) -> f64,
        label: &str,
    ) {
        for a in 0..2 {
            for b in 0..2 {
                close(got(a, b), want(a, b), &format!("{label}[{a}][{b}]"));
            }
        }
    }

    fn tower() -> Tower4<2> {
        evaluate_program(&ExprProgram { p: SEED }, 0).expect("tower")
    }

    /// Order2 reproduces Tower4's value/grad/Hessian channels.
    #[test]
    fn order2_matches_tower_value_grad_hessian() {
        let t = tower();
        let vars: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        let grad = s.g();
        for a in 0..2 {
            close(grad[a], t.g[a], &format!("grad[{a}]"));
        }
        let hess = s.h();
        close_grid(|a, b| hess[a][b], |a, b| t.h[a][b], "hess");
    }

    /// OneSeed's ε-Hessian is the contracted third Σ_c ℓ_{abc} u_c, matching
    /// `Tower4::third_contracted(u)`. The base channels (value, gradient AND
    /// Hessian) also match the tower.
    #[test]
    fn one_seed_matches_tower_third_contracted() {
        let t = tower();
        let truth = t.third_contracted(&U);
        let vars: [OneSeed<2>; 2] =
            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
        let s = row_expr(&vars);
        // Base channels are the plain (v, g, H).
        close(s.value(), t.v, "value");
        let base_grad = s.base.g();
        for a in 0..2 {
            close(base_grad[a], t.g[a], &format!("grad[{a}]"));
        }
        let base_hess = s.base.h();
        close_grid(|a, b| base_hess[a][b], |a, b| t.h[a][b], "base hess");
        let third = s.contracted_third();
        close_grid(|a, b| third[a][b], |a, b| truth[a][b], "third");
    }

    /// TwoSeed's εδ-Hessian is the contracted fourth Σ_{cd} ℓ_{abcd} u_c v_d,
    /// matching `Tower4::fourth_contracted(u, v)`. The ε / δ single-seed parts
    /// reproduce the two third contractions Σ_c ℓ_{abc} u_c and Σ_d ℓ_{abd} v_d.
    #[test]
    fn two_seed_matches_tower_fourth_contracted() {
        let t = tower();
        let truth4 = t.fourth_contracted(&U, &V);
        let truth3_u = t.third_contracted(&U);
        let truth3_v = t.third_contracted(&V);
        let vars: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        let base_grad = s.base.g();
        for a in 0..2 {
            close(base_grad[a], t.g[a], &format!("grad[{a}]"));
        }
        let base_hess = s.base.h();
        let eps_hess = s.eps.h();
        let del_hess = s.del.h();
        close_grid(|a, b| base_hess[a][b], |a, b| t.h[a][b], "base hess");
        close_grid(|a, b| eps_hess[a][b], |a, b| truth3_u[a][b], "eps third_u");
        close_grid(|a, b| del_hess[a][b], |a, b| truth3_v[a][b], "del third_v");
        let fourth = s.contracted_fourth();
        close_grid(|a, b| fourth[a][b], |a, b| truth4[a][b], "fourth");
    }

    /// Evaluates the SAME generic `row_expr` on each packed scalar — the way an
    /// expression written against the `JetScalar` seam is instantiated — and
    /// extracts the channel a consumer asks for, checking it against the direct
    /// Tower4 contraction.
    #[test]
    fn generic_program_seam_matches_tower_for_every_channel() {
        let t = tower();
        // Order2 via generic seam.
        let o2: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
        let so2 = row_expr(&o2);
        close(so2.value(), t.v, "seam order2 value");
        // OneSeed third.
        let os: [OneSeed<2>; 2] =
            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
        let third = row_expr(&os).contracted_third();
        let truth3 = t.third_contracted(&U);
        close_grid(|a, b| third[a][b], |a, b| truth3[a][b], "seam third");
        // TwoSeed fourth.
        let ts: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
        let fourth = row_expr(&ts).contracted_fourth();
        let truth4 = t.fourth_contracted(&U, &V);
        close_grid(|a, b| fourth[a][b], |a, b| truth4[a][b], "seam fourth");
    }

    /// The `Tower4: JetScalar` impl (defined at module level, outside this test
    /// module) is the all-channels oracle scalar: evaluating the SAME generic
    /// `row_expr` at `S = Tower4` (through the `JetScalar` trait ops) must
    /// reproduce, channel-for-channel, the `Tower4` obtained from the
    /// `RowNllProgram` path (`evaluate_program`). This pins that the trait impl
    /// delegates faithfully to the `Tower4` arithmetic, so the contracted-scalar
    /// oracles above, which compare against `evaluate_program`'s tower, are
    /// comparing against the same algebra the `JetScalar` interface exposes.
    #[test]
    fn tower4_as_jetscalar_matches_program_tower_all_channels() {
        let t = tower();
        let vars: [Tower4<2>; 2] = std::array::from_fn(|a| Tower4::variable(SEED[a], a));
        let s = row_expr(&vars);
        close(s.v, t.v, "tower-jetscalar value");
        for a in 0..2 {
            close(s.g[a], t.g[a], &format!("tower-jetscalar grad[{a}]"));
        }
        close_grid(|a, b| s.h[a][b], |a, b| t.h[a][b], "tower-jetscalar hess");
        for a in 0..2 {
            for b in 0..2 {
                for c in 0..2 {
                    close(
                        s.t3[a][b][c],
                        t.t3[a][b][c],
                        &format!("tower-jetscalar t3[{a}][{b}][{c}]"),
                    );
                    for d in 0..2 {
                        close(
                            s.t4[a][b][c][d],
                            t.t4[a][b][c][d],
                            &format!("tower-jetscalar t4[{a}][{b}][{c}][{d}]"),
                        );
                    }
                }
            }
        }
    }
}