gam_math/jet_scalar.rs
1//! Order-specific Taylor-jet SCALAR algebras (#932 cutover, doc §A).
2//!
3//! [`crate::jet_tower::Tower4`] carries the full value/gradient/Hessian/`t3`/`t4`
4//! tensor stack: it answers EVERY channel a [`super::row_kernel::RowKernel`]
5//! consumer can ask for, but at `K = 9` that is a ~50 KiB per-row object whose
6//! by-value copies overflowed the stack and timed out the location-scale fit —
7//! which is exactly why `row_kernel_directional_supported()` /
8//! `row_kernel_joint_hessian_supported()` still `return false`. The cutover does
9//! NOT need the dense `Tower4<9>` per row; it needs, per consumer, only the one
10//! channel that consumer serves:
11//!
12//! | consumer | channel | scalar here | K=9 size |
13//! |---|---|---|---|
14//! | inner Newton / `row_kernel` | `(v, g, H)` | [`Order2`] | 728 B |
//! | `row_third_contracted(dir)` | `Σ_c ℓ_{abc} dir_c` | [`OneSeed`] | 1.42 KiB |
16//! | `row_fourth_contracted(u, v)` | `Σ_{cd} ℓ_{abcd} u_c v_d` | [`TwoSeed`] | 2.8 KiB |
17//!
18//! Each is built on [`Order2`] (value/grad/Hessian), which is the production
19//! [`crate::jet_tower::Tower2`] re-expressed behind a generic interface: a row
20//! loss written ONCE against [`JetScalar`] re-instantiates at whatever order /
21//! representation a consumer needs, with the contraction folded INTO the
22//! differentiation (the nilpotent ε / δ directions), so `t3` / `t4` are never
23//! materialised. The single source of truth is the same one expression — the
24//! genus of #736 cross-block drift cannot reappear because there is no separate
25//! channel to forget.
26//!
27//! # Why each scalar is exact (doc §A.1–A.3)
28//!
29//! * [`Order2`] is the order-≤2 truncation of the Leibniz / Faà di Bruno rules.
30//! Those order-2 terms read ONLY the order-≤2 channels of their inputs (see
31//! [`crate::jet_tower::Tower4::mul`]: `out.h[i][j]` never touches `t3`/`t4`),
32//! so its `(v, g, H)` is BIT-IDENTICAL to a full `Tower4<K>` — and identical
33//! to [`crate::jet_tower::Tower2`], over which it is a thin newtype.
34//! * [`OneSeed`] carries an [`Order2`] base plus one nilpotent ε (`ε² = 0`)
35//! holding another [`Order2`]. Seeding ε with the fixed direction `u` makes the
36//! ε-component of the Hessian channel the contracted third `Σ_c ℓ_{abc} u_c`
37//! (the nilpotent implements `d/dτ|₀` of `ℓ_{ab}(p + τu)` exactly).
38//! * [`TwoSeed`] carries an [`Order2`] base plus ε, δ (`ε² = δ² = 0`, `εδ`
39//! retained) — four [`Order2`] parts. Seeding ε, δ with `u, v` makes the
40//! εδ-component of the Hessian channel the contracted fourth
41//! `Σ_{cd} ℓ_{abcd} u_c v_d` (the single mixed `∂_σ∂_ρ|₀` term, no `σ²`/`ρ²`
42//! contamination).
43//!
44//! # Stability discipline
45//!
46//! As in [`crate::jet_tower`], humans own primitive stability and the algebra
47//! owns combinatorics: tail-critical special functions enter ONLY as
48//! hand-certified `[f64; 5]` derivative stacks through [`JetScalar::compose_unary`]
49//! (each scalar consumes the leading entries its order needs), never by
50//! differentiating an unstable primal.
51//!
//! # Production scalars and the all-channels `Tower4` oracle
53//!
54//! The `JetScalar` trait below is production: it is the bound on
55//! [`crate::jet_tower::RowNllProgramGeneric::row_nll_generic`], the seam a family
56//! row loss is written against. The order-specific scalars that *consume* it —
57//! [`Order2`] (value/grad/Hessian), [`OneSeed`] (contracted third) and
58//! [`TwoSeed`] (contracted fourth) — are production: the survival location-scale
59//! `RowKernel<9>` builds its joint Hessian / directional derivatives through them
60//! (`survival::location_scale::row_kernel`), paying only the small packed scalar
61//! per row instead of the ~50 KiB dense [`crate::jet_tower::Tower4`].
62//!
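//! The [`crate::jet_tower::Tower4`] all-channels `JetScalar` impl is compiled
//! unconditionally (it is NOT `#[cfg(test)]`-gated). It is the oracle the tests
//! pin the contracted scalars against (the dense value/grad/Hessian/`t3`/`t4`
//! truth), and it also serves consumers that need the uncontracted `t3` / `t4`
//! tensors from the same generic row expression.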
66
67/// A truncated-Taylor scalar carrying derivatives in `K` primaries.
68///
69/// All concrete scalars here ([`Order2`], [`OneSeed`], [`TwoSeed`]) and the full
70/// [`crate::jet_tower::Tower4`] implement the SAME algebra; only the carried
71/// channel set differs. A row loss written once against this interface yields a
72/// different channel set per instantiation, all exact for the channel they serve
73/// (doc §A.0).
74pub trait JetScalar<const K: usize>: Copy {
75 /// A constant: value `c`, every derivative channel zero.
76 fn constant(c: f64) -> Self;
77
78 /// The seeded variable `p_axis` at value `x`: unit first derivative in slot
79 /// `axis`, all higher channels zero. (The nilpotent / cross channels of the
80 /// directional scalars are seeded zero — callers set ε/δ directions through
81 /// the scalar-specific [`OneSeed::seed_direction`] / [`TwoSeed::seed`].)
82 fn variable(x: f64, axis: usize) -> Self;
83
84 /// The value channel `ℓ(p)`.
85 fn value(&self) -> f64;
86
87 /// Exact truncated Leibniz sum `self + o`.
88 fn add(&self, o: &Self) -> Self;
89 /// Exact truncated Leibniz difference `self − o`.
90 fn sub(&self, o: &Self) -> Self;
91 /// Exact truncated Leibniz product `self · o`.
92 fn mul(&self, o: &Self) -> Self;
93 /// Negate every channel.
94 fn neg(&self) -> Self;
95 /// Multiply every channel by a plain scalar `s`.
96 fn scale(&self, s: f64) -> Self;
97
98 /// Exact multivariate Faà di Bruno composition `f ∘ self`, given the outer
99 /// derivative stack `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` at
100 /// `u = self.value()`.
101 ///
102 /// This is the SAME `[f64; 5]` stack shape [`crate::jet_tower::Tower4`] and
103 /// the families' `unary_derivatives_*` helpers (built on erfcx / log_ndtr)
104 /// already produce, so those stacks plug in directly. Each scalar consumes
105 /// only the leading entries its order needs (order-2 reads `d[0..=2]`; the
106 /// directional scalars read one / two beyond their base) — the fixed-length
107 /// array makes that windowing total, no length guard required.
108 fn compose_unary(&self, d: [f64; 5]) -> Self;
109
110 /// `e^self`. Convenience for tame arguments (see module stability note).
111 fn exp(&self) -> Self {
112 let e = self.value().exp();
113 self.compose_unary([e, e, e, e, e])
114 }
115
116 /// `√self`. Caller guarantees positivity.
117 fn sqrt(&self) -> Self {
118 let u = self.value();
119 let s = u.sqrt();
120 self.compose_unary([
121 s,
122 0.5 / s,
123 -0.25 / (u * s),
124 0.375 / (u * u * s),
125 -0.9375 / (u * u * u * s),
126 ])
127 }
128
129 /// `ln(self)`. Caller guarantees positivity. Same derivative stack
130 /// [`crate::jet_tower::Tower4::ln`] uses, so any program written over both
131 /// matches term-for-term.
132 fn ln(&self) -> Self {
133 let u = self.value();
134 let r = 1.0 / u;
135 self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
136 }
137
138 /// `1/self`.
139 fn recip(&self) -> Self {
140 let r = 1.0 / self.value();
141 let r2 = r * r;
142 self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
143 }
144
145 /// `self^a` for real exponent `a`. Caller guarantees a positive base.
146 /// Mirrors [`crate::jet_tower::Tower4::powf`] (falling-factorial stack).
147 fn powf(&self, a: f64) -> Self {
148 let u = self.value();
149 self.compose_unary([
150 u.powf(a),
151 a * u.powf(a - 1.0),
152 a * (a - 1.0) * u.powf(a - 2.0),
153 a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
154 a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
155 ])
156 }
157
158 /// `ln Γ(self)`. Caller guarantees a positive argument. Uses the SAME
159 /// hand-certified derivative stack [`crate::jet_tower::Tower4::ln_gamma`]
160 /// consumes ([`crate::jet_tower::ln_gamma_derivative_stack`]), so any
161 /// program written over both matches term-for-term.
162 fn ln_gamma(&self) -> Self {
163 self.compose_unary(crate::jet_tower::ln_gamma_derivative_stack(self.value()))
164 }
165
166 /// `ψ(self) = d/dx ln Γ(x)` (digamma). Caller guarantees a positive
167 /// argument. Same hand-certified stack
168 /// [`crate::jet_tower::digamma_derivative_stack`].
169 fn digamma(&self) -> Self {
170 self.compose_unary(crate::jet_tower::digamma_derivative_stack(self.value()))
171 }
172}
173
174// ── Order2<K> ergonomic operator overloads (doc §A.1) ───────────────────
175//
176// The dispersion-family row NLLs are written with `+`/`-`/`*` operators over
177// the primaries (mirroring how they read as `Tower4` expressions). These
178// delegate channel-for-channel to the inner `Tower2` arithmetic (which has
179// `Add`/`Mul`; `Sub`/`Neg` are expressed as `+ (-1)·rhs` exactly as the
180// `JetScalar::sub` / `JetScalar::neg` impls do), so an `Order2` expression is
181// bit-identical to the same `Tower4` expression's order-≤2 channels.
182
183impl<const K: usize> std::ops::Add for Order2<K> {
184 type Output = Self;
185 #[inline]
186 fn add(self, o: Self) -> Self {
187 Order2(self.0 + o.0)
188 }
189}
190
191impl<const K: usize> std::ops::Add<f64> for Order2<K> {
192 type Output = Self;
193 #[inline]
194 fn add(self, c: f64) -> Self {
195 Order2(self.0 + c)
196 }
197}
198
199impl<const K: usize> std::ops::Sub for Order2<K> {
200 type Output = Self;
201 #[inline]
202 fn sub(self, o: Self) -> Self {
203 Order2(self.0 + o.0.scale(-1.0))
204 }
205}
206
207impl<const K: usize> std::ops::Sub<f64> for Order2<K> {
208 type Output = Self;
209 #[inline]
210 fn sub(self, c: f64) -> Self {
211 Order2(self.0 + (-c))
212 }
213}
214
215impl<const K: usize> std::ops::Mul for Order2<K> {
216 type Output = Self;
217 #[inline]
218 fn mul(self, o: Self) -> Self {
219 Order2(crate::jet_tower::Tower2::mul(&self.0, &o.0))
220 }
221}
222
223impl<const K: usize> std::ops::Mul<f64> for Order2<K> {
224 type Output = Self;
225 #[inline]
226 fn mul(self, c: f64) -> Self {
227 Order2(self.0.scale(c))
228 }
229}
230
231impl<const K: usize> std::ops::Neg for Order2<K> {
232 type Output = Self;
233 #[inline]
234 fn neg(self) -> Self {
235 Order2(self.0.scale(-1.0))
236 }
237}
238
239/// Filtered Hensel lift of a SCALAR implicit state `a(θ)` defined by the
240/// constraint `F(a, θ) = 0`, evaluated in ANY [`JetScalar`] algebra `S` (doc
241/// §11, "A generic implicit-lift operator for every production scalar").
242///
243/// This is the perf-respecting alternative to lifting through a dense
244/// `Tower4<K+1>` (which carries the implicit variable as an extra dense axis):
245/// the state `a` lives directly in the consumer's own `K`-primary algebra
246/// `S` — `Order2<K>` for value/gradient/Hessian, `Tower4<K>` for the full
247/// `t3`/`t4` — never paying for an extra variable.
248///
249/// **Method.** Fixed-Jacobian Newton in the nilpotent algebra. By the
250/// filtered-lift theorem (doc §11.1), if `F_a := ∂F/∂a(a₀, θ₀)` is the primal
251/// Jacobian at the base point and `inv_fa = 1/F_a`, then the iteration
252/// `A ← A − inv_fa · F(A, θ)` raises the filtration degree of the residual by
253/// at least one per step: each step kills exactly one graded layer. Starting
254/// from `A = const(a₀)` (whose residual lies in `F¹` because `θ − θ₀ ∈ 𝔫`),
255/// `iters` equal to the algebra's nilpotency order returns the *exact* lifted
256/// jet (`Order2`: 2, `OneSeed`: 3, `Tower4`/`TwoSeed`: 4). The value channel of
257/// `A` never moves — `F(A, θ).value() = F(a₀, θ₀) = 0` at the certified root —
258/// so a caller may precompute every primitive's derivative stack at the fixed
259/// base index once and let the cheap polynomial composition repeat per step.
260///
261/// `f` evaluates the constraint `F(a, θ)` in `S` (capturing the seeded
262/// parameter jets `θ`); `a0` is the certified scalar root `F(a₀, θ₀) ≈ 0`.
263pub fn filtered_implicit_solve_scalar<const K: usize, S: JetScalar<K>>(
264 a0: f64,
265 inv_fa: f64,
266 iters: usize,
267 f: impl Fn(&S) -> S,
268) -> S {
269 let mut a = S::constant(a0);
270 for _ in 0..iters {
271 let residual = f(&a);
272 a = a.sub(&residual.scale(inv_fa));
273 }
274 a
275}
276
277// ── Order2<K>: value / gradient / Hessian (doc §A.1) ────────────────────
278
/// Truncated SECOND-order scalar: value `v`, gradient `g_a`, Hessian `H_{ab}`.
///
/// This is a thin newtype over the production [`crate::jet_tower::Tower2`], so
/// its `(v, g, H)` channels are obtained by the SAME formulas — and are
/// therefore bit-identical to both [`crate::jet_tower::Tower2`] and the order-≤2
/// channels of a full [`crate::jet_tower::Tower4`] (doc §A.1, "Bit-identity with
/// the full tower"). The wrapper exists only to satisfy the generic
/// [`JetScalar`] interface (the `compose_unary` / `add` / `sub` / `neg` /
/// `recip` the trait demands, which `Tower2` does not expose by that shape) —
/// every channel is delegated to `Tower2` arithmetic unchanged.
///
/// The single public field is the wrapped tower itself; its `v` / `g` / `h`
/// fields hold the value, gradient and Hessian channels (read through
/// [`JetScalar::value`], [`Order2::g`] and [`Order2::h`]).
#[derive(Clone, Copy, Debug)]
pub struct Order2<const K: usize>(pub crate::jet_tower::Tower2<K>);
291
292impl<const K: usize> Order2<K> {
293 /// Read the gradient channel `g_a = ∂ℓ/∂p_a`.
294 #[inline]
295 pub fn g(&self) -> [f64; K] {
296 self.0.g
297 }
298
299 /// Read the Hessian channel.
300 #[inline]
301 pub fn h(&self) -> [[f64; K]; K] {
302 self.0.h
303 }
304}
305
306impl<const K: usize> JetScalar<K> for Order2<K> {
307 fn constant(c: f64) -> Self {
308 Order2(crate::jet_tower::Tower2::constant(c))
309 }
310 fn variable(x: f64, axis: usize) -> Self {
311 Order2(crate::jet_tower::Tower2::variable(x, axis))
312 }
313 fn value(&self) -> f64 {
314 self.0.v
315 }
316 fn add(&self, o: &Self) -> Self {
317 Order2(self.0 + o.0)
318 }
319 fn sub(&self, o: &Self) -> Self {
320 // Tower2 has no Sub op; subtract by adding the negation, matching
321 // Tower4::sub (self + o.scale(-1.0)).
322 Order2(self.0 + o.0.scale(-1.0))
323 }
324 fn mul(&self, o: &Self) -> Self {
325 Order2(crate::jet_tower::Tower2::mul(&self.0, &o.0))
326 }
327 fn neg(&self) -> Self {
328 Order2(self.0.scale(-1.0))
329 }
330 fn scale(&self, s: f64) -> Self {
331 Order2(self.0.scale(s))
332 }
333 fn compose_unary(&self, d: [f64; 5]) -> Self {
334 // Order-≤2 reads only [f, f', f''] of the stack.
335 Order2(self.0.compose_unary([d[0], d[1], d[2]]))
336 }
337}
338
339// ── OneSeed<K>: one-seed directional, contracted third (doc §A.2) ───────
340
/// One-seed directional scalar: an [`Order2`] base plus ONE nilpotent ε
/// (`ε² = 0`) whose coefficient is itself an [`Order2`].
///
/// A scalar is `s = base + ε·eps`. Arithmetic is the `ε² = 0` truncation of the
/// product (doc §A.2): the base parts multiply as ordinary [`Order2`] products,
/// and the ε-coefficient picks up `a.base·b.eps + a.eps·b.base`. Composition
/// pushes ε through one extra outer derivative
/// (`f(base + ε·eps) = f(base) + ε·f′(base)·eps`).
///
/// Seed each primary with [`seed_direction`](Self::seed_direction): the base is
/// the usual seeded variable (carrying `e_a` for the Hessian channel) and the
/// ε-coefficient is the FIXED contraction direction `u_a` (a constant). Then the
/// ε-component of the evaluated Hessian channel is the contracted third
/// `[eps.h][a][b] = Σ_c ℓ_{abc} u_c` — exactly `row_third_contracted(dir = u)`,
/// without materialising `t3`. Read it with
/// [`contracted_third`](Self::contracted_third).
#[derive(Clone, Copy, Debug)]
pub struct OneSeed<const K: usize> {
    /// The `ε⁰` part: value / gradient / Hessian of `ℓ`.
    pub base: Order2<K>,
    /// The `ε¹` part: value / gradient / Hessian of the ε-coefficient. After a
    /// `seed_direction(u)` evaluation, `eps.h[a][b] = Σ_c ℓ_{abc} u_c`.
    pub eps: Order2<K>,
}
363
364impl<const K: usize> OneSeed<K> {
365 /// Seed primary `axis` at value `x` with ε-direction component `u_axis`:
366 /// `p_axis = p_axis⁰ + x-seed + ε·u_axis`, i.e. base = `variable(x, axis)`
367 /// and eps = `constant(u_axis)` (doc §A.2 "Seeding").
368 pub fn seed_direction(x: f64, axis: usize, u_axis: f64) -> Self {
369 OneSeed {
370 base: Order2::variable(x, axis),
371 eps: Order2::constant(u_axis),
372 }
373 }
374
375 /// The contracted-third channel after a `seed_direction(u)` evaluation:
376 /// `out[a][b] = Σ_c ℓ_{abc} u_c`, i.e. the ε-coefficient's Hessian (doc §A.2).
377 pub fn contracted_third(&self) -> [[f64; K]; K] {
378 self.eps.h()
379 }
380}
381
382impl<const K: usize> JetScalar<K> for OneSeed<K> {
383 fn constant(c: f64) -> Self {
384 OneSeed {
385 base: Order2::constant(c),
386 eps: Order2::constant(0.0),
387 }
388 }
389 fn variable(x: f64, axis: usize) -> Self {
390 // No ε-direction unless seeded via `seed_direction`.
391 OneSeed {
392 base: Order2::variable(x, axis),
393 eps: Order2::constant(0.0),
394 }
395 }
396 fn value(&self) -> f64 {
397 self.base.value()
398 }
399 fn add(&self, o: &Self) -> Self {
400 OneSeed {
401 base: self.base.add(&o.base),
402 eps: self.eps.add(&o.eps),
403 }
404 }
405 fn sub(&self, o: &Self) -> Self {
406 OneSeed {
407 base: self.base.sub(&o.base),
408 eps: self.eps.sub(&o.eps),
409 }
410 }
411 fn mul(&self, o: &Self) -> Self {
412 // (a.base + ε a.eps)(b.base + ε b.eps), dropping ε².
413 OneSeed {
414 base: self.base.mul(&o.base),
415 eps: self.base.mul(&o.eps).add(&self.eps.mul(&o.base)),
416 }
417 }
418 fn neg(&self) -> Self {
419 OneSeed {
420 base: self.base.neg(),
421 eps: self.eps.neg(),
422 }
423 }
424 fn scale(&self, s: f64) -> Self {
425 OneSeed {
426 base: self.base.scale(s),
427 eps: self.eps.scale(s),
428 }
429 }
430 fn compose_unary(&self, d: [f64; 5]) -> Self {
431 // f(base + ε eps) = f(base) + ε · f'(base)·eps (ε² = 0). Each factor is
432 // an Order2 composition: the base composes with the f-stack, and the
433 // ε-coefficient is the Order2 of the SHIFTED stack (the chain rule
434 // `f'(base)` as an Order2) times eps. Order2 reads only the leading
435 // three entries of whatever stack it is handed, so the trailing slots
436 // are unused padding (the fixed-length array makes the windowing total).
437 let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
438 // f'(base) as an Order2 (consumes [f', f'', f''']).
439 let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]);
440 let eps = fprime.mul(&self.eps);
441 OneSeed { base, eps }
442 }
443}
444
445// ── TwoSeed<K>: two-seed, contracted fourth (doc §A.3) ──────────────────
446
/// Two-seed scalar: an [`Order2`] base plus TWO nilpotents ε, δ
/// (`ε² = δ² = 0`, `εδ` retained) — four [`Order2`] parts
/// `s = base + ε·eps + δ·del + εδ·eps_del`.
///
/// Product truncates `ε² = δ² = 0` (doc §A.3): each part is built from
/// [`Order2`] products of the four input parts. Composition picks up
/// successively higher outer derivatives, the cross part carrying the second
/// Faà di Bruno term `f''·eps·del + f'·eps_del`.
///
/// Seed each primary with [`seed`](Self::seed): base = `variable(x, axis)`,
/// eps = `constant(u_axis)`, del = `constant(v_axis)`, eps_del = `constant(0)`.
/// Then the εδ-component of the evaluated Hessian channel is the contracted
/// fourth `[eps_del.h][a][b] = Σ_{cd} ℓ_{abcd} u_c v_d` — exactly
/// `row_fourth_contracted(u, v)`, without materialising `t4`. Read it with
/// [`contracted_fourth`](Self::contracted_fourth).
#[derive(Clone, Copy, Debug)]
pub struct TwoSeed<const K: usize> {
    /// The `ε⁰δ⁰` part: value / grad / Hessian of `ℓ`.
    pub base: Order2<K>,
    /// The `ε¹δ⁰` part: the coefficient of ε alone.
    pub eps: Order2<K>,
    /// The `ε⁰δ¹` part: the coefficient of δ alone.
    pub del: Order2<K>,
    /// The `ε¹δ¹` part. After a `seed(u, v)` evaluation,
    /// `eps_del.h[a][b] = Σ_{cd} ℓ_{abcd} u_c v_d`.
    pub eps_del: Order2<K>,
}
473
474impl<const K: usize> TwoSeed<K> {
475 /// Seed primary `axis` at value `x` with ε-direction `u_axis` and
476 /// δ-direction `v_axis`:
477 /// `p_axis = p_axis⁰ + x-seed + ε·u_axis + δ·v_axis` (doc §A.3 "Seeding").
478 pub fn seed(x: f64, axis: usize, u_axis: f64, v_axis: f64) -> Self {
479 TwoSeed {
480 base: Order2::variable(x, axis),
481 eps: Order2::constant(u_axis),
482 del: Order2::constant(v_axis),
483 eps_del: Order2::constant(0.0),
484 }
485 }
486
487 /// The contracted-fourth channel after a `seed(u, v)` evaluation:
488 /// `out[a][b] = Σ_{cd} ℓ_{abcd} u_c v_d`, i.e. the εδ-coefficient's Hessian.
489 pub fn contracted_fourth(&self) -> [[f64; K]; K] {
490 self.eps_del.h()
491 }
492}
493
494impl<const K: usize> JetScalar<K> for TwoSeed<K> {
495 fn constant(c: f64) -> Self {
496 TwoSeed {
497 base: Order2::constant(c),
498 eps: Order2::constant(0.0),
499 del: Order2::constant(0.0),
500 eps_del: Order2::constant(0.0),
501 }
502 }
503 fn variable(x: f64, axis: usize) -> Self {
504 TwoSeed {
505 base: Order2::variable(x, axis),
506 eps: Order2::constant(0.0),
507 del: Order2::constant(0.0),
508 eps_del: Order2::constant(0.0),
509 }
510 }
511 fn value(&self) -> f64 {
512 self.base.value()
513 }
514 fn add(&self, o: &Self) -> Self {
515 TwoSeed {
516 base: self.base.add(&o.base),
517 eps: self.eps.add(&o.eps),
518 del: self.del.add(&o.del),
519 eps_del: self.eps_del.add(&o.eps_del),
520 }
521 }
522 fn sub(&self, o: &Self) -> Self {
523 TwoSeed {
524 base: self.base.sub(&o.base),
525 eps: self.eps.sub(&o.eps),
526 del: self.del.sub(&o.del),
527 eps_del: self.eps_del.sub(&o.eps_del),
528 }
529 }
530 fn mul(&self, o: &Self) -> Self {
531 let a = self;
532 let b = o;
533 // Truncate ε² = δ² = 0 (doc §A.3 product table).
534 let base = a.base.mul(&b.base);
535 let eps = a.base.mul(&b.eps).add(&a.eps.mul(&b.base));
536 let del = a.base.mul(&b.del).add(&a.del.mul(&b.base));
537 let eps_del = a
538 .base
539 .mul(&b.eps_del)
540 .add(&a.eps.mul(&b.del))
541 .add(&a.del.mul(&b.eps))
542 .add(&a.eps_del.mul(&b.base));
543 TwoSeed {
544 base,
545 eps,
546 del,
547 eps_del,
548 }
549 }
550 fn neg(&self) -> Self {
551 TwoSeed {
552 base: self.base.neg(),
553 eps: self.eps.neg(),
554 del: self.del.neg(),
555 eps_del: self.eps_del.neg(),
556 }
557 }
558 fn scale(&self, s: f64) -> Self {
559 TwoSeed {
560 base: self.base.scale(s),
561 eps: self.eps.scale(s),
562 del: self.del.scale(s),
563 eps_del: self.eps_del.scale(s),
564 }
565 }
566 fn compose_unary(&self, d: [f64; 5]) -> Self {
567 // f(s) with s = base + ε eps + δ del + εδ eps_del, ε²=δ²=0:
568 // f(s) = f(base)
569 // + ε · f'(base)·eps
570 // + δ · f'(base)·del
571 // + εδ · ( f''(base)·eps·del + f'(base)·eps_del ).
572 // Each f^{(r)}(base) is the Order2 composition of base with the stack
573 // shifted r entries (doc §A.3 composition). Order2 reads only the
574 // leading three entries of whatever stack it is handed, so the trailing
575 // padding slots are unused (the fixed-length array makes this total).
576 let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
577 let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]); // f'(base) as Order2
578 let fsecond = self.base.compose_unary([d[2], d[3], d[4], d[4], d[4]]); // f''(base) as Order2
579 let eps = fprime.mul(&self.eps);
580 let del = fprime.mul(&self.del);
581 let eps_del = fsecond
582 .mul(&self.eps)
583 .mul(&self.del)
584 .add(&fprime.mul(&self.eps_del));
585 TwoSeed {
586 base,
587 eps,
588 del,
589 eps_del,
590 }
591 }
592}
593
594// ── Tower4<K>: full dense tower as a JetScalar (the all-channels scalar) ─
595
596/// The full dense [`crate::jet_tower::Tower4`] is itself a [`JetScalar`]: it
597/// carries EVERY channel, so a row expression written ONCE against [`JetScalar`]
598/// can be evaluated at `Tower4` to obtain the full `(v, g, H, t3, t4)` in one
599/// pass. This is BOTH the #932 oracle ground truth the packed [`Order2`] /
600/// [`OneSeed`] / [`TwoSeed`] scalars are pinned against, AND a production scalar:
601/// a family whose uncontracted third / fourth derivative tensors are needed
602/// (the BMS rigid `third_full` / `fourth_full` caches) evaluates the SAME
603/// generic row-NLL expression at `Tower4` and reads `.t3` / `.t4` off the
604/// result — so the dense tensors come from the single source of truth, not a
605/// separately hand-written jet. The packed scalars serve the consumers that
606/// need only `(v, g, H)` (`Order2`) or one / two contractions
607/// (`OneSeed` / `TwoSeed`) without paying for the dense tensors.
608impl<const K: usize> JetScalar<K> for crate::jet_tower::Tower4<K> {
609 fn constant(c: f64) -> Self {
610 crate::jet_tower::Tower4::constant(c)
611 }
612 fn variable(x: f64, axis: usize) -> Self {
613 crate::jet_tower::Tower4::variable(x, axis)
614 }
615 fn value(&self) -> f64 {
616 self.v
617 }
618 fn add(&self, o: &Self) -> Self {
619 *self + *o
620 }
621 fn sub(&self, o: &Self) -> Self {
622 *self - *o
623 }
624 fn mul(&self, o: &Self) -> Self {
625 crate::jet_tower::Tower4::mul(self, o)
626 }
627 fn neg(&self) -> Self {
628 self.scale(-1.0)
629 }
630 fn scale(&self, s: f64) -> Self {
631 crate::jet_tower::Tower4::scale(self, s)
632 }
633 fn compose_unary(&self, d: [f64; 5]) -> Self {
634 crate::jet_tower::Tower4::compose_unary(self, d)
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jet_tower::{RowNllProgram, Tower4, evaluate_program};

    /// A small polynomial-plus-unary row expression written ONCE, generically
    /// over `S: JetScalar<2>`, so it can be evaluated against every scalar:
    /// `ℓ = (e^{p0·p1} + 2) · √(p0·p0 + 1) − p1·p1·0.5`.
    /// Exercises `mul`, `add`, `sub`, `scale`, `exp` and `sqrt`.
    fn row_expr<S: JetScalar<2>>(p: &[S; 2]) -> S {
        let g = p[0].mul(&p[1]).exp();
        let inner = g.add(&S::constant(2.0));
        let radic = p[0].mul(&p[0]).add(&S::constant(1.0)).sqrt();
        inner.mul(&radic).sub(&p[1].mul(&p[1]).scale(0.5))
    }

    /// The same expression as a Tower4 `RowNllProgram`, the ground-truth tower.
    struct ExprProgram {
        p: [f64; 2],
    }

    impl RowNllProgram<2> for ExprProgram {
        fn n_rows(&self) -> usize {
            1
        }
        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
            if row >= self.n_rows() {
                return Err(format!("ExprProgram: row {row} out of range"));
            }
            Ok(self.p)
        }
        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
            if row >= self.n_rows() {
                return Err(format!("ExprProgram: row {row} out of range"));
            }
            Ok(row_expr(p))
        }
    }

    /// Evaluation point for the two primaries.
    const SEED: [f64; 2] = [0.37, -0.81];
    /// First contraction direction (ε seed).
    const U: [f64; 2] = [0.6, -0.2];
    /// Second contraction direction (δ seed).
    const V: [f64; 2] = [-0.4, 1.1];
    /// Absolute-plus-relative comparison tolerance used by [`close`].
    const TOL: f64 = 1e-10;

    /// Assert `a ≈ b` within `TOL + TOL·max(|a|, |b|)`.
    fn close(a: f64, b: f64, label: &str) {
        let band = TOL + TOL * a.abs().max(b.abs());
        assert!(
            (a - b).abs() <= band,
            "{label}: {a:+.15e} vs {b:+.15e} (band {band:.3e})"
        );
    }

    /// Assert two 2×2 channels agree entry by entry. The entries are supplied
    /// as `(row, col) -> f64` readers so any indexable channel can be compared;
    /// a failure is reported as `"{label}[a][b]"`.
    fn close_matrix(
        got: impl Fn(usize, usize) -> f64,
        want: impl Fn(usize, usize) -> f64,
        label: &str,
    ) {
        for a in 0..2 {
            for b in 0..2 {
                close(got(a, b), want(a, b), &format!("{label}[{a}][{b}]"));
            }
        }
    }

    /// The ground-truth dense tower of `row_expr` at `SEED`.
    fn tower() -> Tower4<2> {
        evaluate_program(&ExprProgram { p: SEED }, 0).expect("tower")
    }

    /// Order2 reproduces Tower4's value/grad/Hessian channels.
    #[test]
    fn order2_matches_tower_value_grad_hessian() {
        let t = tower();
        let vars: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        for a in 0..2 {
            close(s.g()[a], t.g[a], &format!("grad[{a}]"));
        }
        close_matrix(|a, b| s.h()[a][b], |a, b| t.h[a][b], "hess");
    }

    /// OneSeed's ε-Hessian is the contracted third Σ_c ℓ_{abc} u_c, matching
    /// `Tower4::third_contracted(u)`. The base part's value, gradient and
    /// Hessian also match the tower.
    #[test]
    fn one_seed_matches_tower_third_contracted() {
        let t = tower();
        let truth = t.third_contracted(&U);
        let vars: [OneSeed<2>; 2] =
            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
        let s = row_expr(&vars);
        // Base channels are the plain (v, g, H).
        close(s.value(), t.v, "value");
        for a in 0..2 {
            close(s.base.g()[a], t.g[a], &format!("base grad[{a}]"));
        }
        close_matrix(|a, b| s.base.h()[a][b], |a, b| t.h[a][b], "base hess");
        let third = s.contracted_third();
        close_matrix(|a, b| third[a][b], |a, b| truth[a][b], "third");
    }

    /// TwoSeed's εδ-Hessian is the contracted fourth Σ_{cd} ℓ_{abcd} u_c v_d,
    /// matching `Tower4::fourth_contracted(u, v)`. The ε / δ single-seed parts
    /// reproduce the two third contractions Σ_c ℓ_{abc} u_c and Σ_d ℓ_{abd} v_d.
    #[test]
    fn two_seed_matches_tower_fourth_contracted() {
        let t = tower();
        let truth4 = t.fourth_contracted(&U, &V);
        let truth3_u = t.third_contracted(&U);
        let truth3_v = t.third_contracted(&V);
        let vars: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
        let s = row_expr(&vars);
        close(s.value(), t.v, "value");
        for a in 0..2 {
            close(s.base.g()[a], t.g[a], &format!("grad[{a}]"));
        }
        close_matrix(|a, b| s.base.h()[a][b], |a, b| t.h[a][b], "base hess");
        close_matrix(|a, b| s.eps.h()[a][b], |a, b| truth3_u[a][b], "eps third_u");
        close_matrix(|a, b| s.del.h()[a][b], |a, b| truth3_v[a][b], "del third_v");
        let fourth = s.contracted_fourth();
        close_matrix(|a, b| fourth[a][b], |a, b| truth4[a][b], "fourth");
    }

    /// One generic expression, three scalars: `row_expr` is evaluated at
    /// `Order2`, `OneSeed` and `TwoSeed`, and the channel each consumer would
    /// ask for (value, contracted third, contracted fourth) agrees with the
    /// direct `Tower4` contraction. Note this calls `row_expr` directly; it
    /// does not route through a `row_nll_generic` program.
    #[test]
    fn generic_program_seam_matches_tower_for_every_channel() {
        let t = tower();
        // Order2: value channel.
        let o2: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
        let so2 = row_expr(&o2);
        close(so2.value(), t.v, "seam order2 value");
        // OneSeed: contracted third.
        let os: [OneSeed<2>; 2] =
            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
        let third = row_expr(&os).contracted_third();
        let truth3 = t.third_contracted(&U);
        close_matrix(|a, b| third[a][b], |a, b| truth3[a][b], "seam third");
        // TwoSeed: contracted fourth.
        let ts: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
        let fourth = row_expr(&ts).contracted_fourth();
        let truth4 = t.fourth_contracted(&U, &V);
        close_matrix(|a, b| fourth[a][b], |a, b| truth4[a][b], "seam fourth");
    }

    /// The `Tower4: JetScalar` impl is the all-channels scalar: evaluating the
    /// SAME generic `row_expr` at `S = Tower4` (through the `JetScalar` trait
    /// ops) must reproduce, channel-for-channel, the `Tower4` obtained from the
    /// `RowNllProgram` path (`evaluate_program`). This pins that the trait impl
    /// delegates faithfully to `Tower4`'s own arithmetic, so the
    /// contracted-scalar oracles above — which compare against
    /// `evaluate_program`'s tower — are comparing against the same algebra the
    /// `JetScalar` interface exposes.
    #[test]
    fn tower4_as_jetscalar_matches_program_tower_all_channels() {
        let t = tower();
        let vars: [Tower4<2>; 2] = std::array::from_fn(|a| Tower4::variable(SEED[a], a));
        let s = row_expr(&vars);
        close(s.v, t.v, "tower-jetscalar value");
        for a in 0..2 {
            close(s.g[a], t.g[a], &format!("tower-jetscalar grad[{a}]"));
            for b in 0..2 {
                close(
                    s.h[a][b],
                    t.h[a][b],
                    &format!("tower-jetscalar hess[{a}][{b}]"),
                );
                for c in 0..2 {
                    close(
                        s.t3[a][b][c],
                        t.t3[a][b][c],
                        &format!("tower-jetscalar t3[{a}][{b}][{c}]"),
                    );
                    for d in 0..2 {
                        close(
                            s.t4[a][b][c][d],
                            t.t4[a][b][c][d],
                            &format!("tower-jetscalar t4[{a}][{b}][{c}][{d}]"),
                        );
                    }
                }
            }
        }
    }
}