Skip to main content

gam_math/
jet_scalar.rs

1//! Order-specific Taylor-jet SCALAR algebras (#932 cutover, doc §A).
2//!
3//! [`crate::jet_tower::Tower4`] carries the full value/gradient/Hessian/`t3`/`t4`
4//! tensor stack: it answers EVERY channel a [`super::row_kernel::RowKernel`]
5//! consumer can ask for, but at `K = 9` that is a ~50 KiB per-row object whose
6//! by-value copies overflowed the stack and timed out the location-scale fit —
7//! which is exactly why `row_kernel_directional_supported()` /
8//! `row_kernel_joint_hessian_supported()` still `return false`. The cutover does
9//! NOT need the dense `Tower4<9>` per row; it needs, per consumer, only the one
10//! channel that consumer serves:
11//!
12//! | consumer | channel | scalar here | K=9 size |
13//! |---|---|---|---|
14//! | inner Newton / `row_kernel` | `(v, g, H)` | [`Order2`] | 728 B |
15//! | `row_third_contracted(dir)` | `Σ_c ℓ_{abc} dir_c` | [`OneSeed`] | 1.42 KiB |
16//! | `row_fourth_contracted(u, v)` | `Σ_{cd} ℓ_{abcd} u_c v_d` | [`TwoSeed`] | 2.8 KiB |
17//!
18//! Each is built on [`Order2`] (value/grad/Hessian), which is the production
19//! [`crate::jet_tower::Tower2`] re-expressed behind a generic interface: a row
20//! loss written ONCE against [`JetScalar`] re-instantiates at whatever order /
21//! representation a consumer needs, with the contraction folded INTO the
22//! differentiation (the nilpotent ε / δ directions), so `t3` / `t4` are never
23//! materialised. The single source of truth is the same one expression — the
24//! genus of #736 cross-block drift cannot reappear because there is no separate
25//! channel to forget.
26//!
27//! # Why each scalar is exact (doc §A.1–A.3)
28//!
29//! * [`Order2`] is the order-≤2 truncation of the Leibniz / Faà di Bruno rules.
30//!   Those order-2 terms read ONLY the order-≤2 channels of their inputs (see
31//!   [`crate::jet_tower::Tower4::mul`]: `out.h[i][j]` never touches `t3`/`t4`),
32//!   so its `(v, g, H)` is BIT-IDENTICAL to a full `Tower4<K>` — and identical
33//!   to [`crate::jet_tower::Tower2`], over which it is a thin newtype.
34//! * [`OneSeed`] carries an [`Order2`] base plus one nilpotent ε (`ε² = 0`)
35//!   holding another [`Order2`]. Seeding ε with the fixed direction `u` makes the
36//!   ε-component of the Hessian channel the contracted third `Σ_c ℓ_{abc} u_c`
37//!   (the nilpotent implements `d/dτ|₀` of `ℓ_{ab}(p + τu)` exactly).
38//! * [`TwoSeed`] carries an [`Order2`] base plus ε, δ (`ε² = δ² = 0`, `εδ`
39//!   retained) — four [`Order2`] parts. Seeding ε, δ with `u, v` makes the
40//!   εδ-component of the Hessian channel the contracted fourth
41//!   `Σ_{cd} ℓ_{abcd} u_c v_d` (the single mixed `∂_σ∂_ρ|₀` term, no `σ²`/`ρ²`
42//!   contamination).
43//!
44//! # Stability discipline
45//!
46//! As in [`crate::jet_tower`], humans own primitive stability and the algebra
47//! owns combinatorics: tail-critical special functions enter ONLY as
48//! hand-certified `[f64; 5]` derivative stacks through [`JetScalar::compose_unary`]
49//! (each scalar consumes the leading entries its order needs), never by
50//! differentiating an unstable primal.
51//!
52//! # Production scalars and the test-only all-channels oracle
53//!
54//! The `JetScalar` trait below is production: it is the bound on
55//! [`crate::jet_tower::RowNllProgramGeneric::row_nll_generic`], the seam a family
56//! row loss is written against. The order-specific scalars that *consume* it —
57//! [`Order2`] (value/grad/Hessian), [`OneSeed`] (contracted third) and
58//! [`TwoSeed`] (contracted fourth) — are production: the survival location-scale
59//! `RowKernel<9>` builds its joint Hessian / directional derivatives through them
60//! (`survival::location_scale::row_kernel`), paying only the small packed scalar
61//! per row instead of the ~50 KiB dense [`crate::jet_tower::Tower4`].
62//!
63//! The [`crate::jet_tower::Tower4`] all-channels `JetScalar` impl is test-only: it
64//! is the oracle that pins the contracted scalars against the dense
65//! value/grad/Hessian/`t3`/`t4` truth, so it lives in the `#[cfg(test)]` module.
66
67/// A truncated-Taylor scalar carrying derivatives in `K` primaries.
68///
69/// All concrete scalars here ([`Order2`], [`OneSeed`], [`TwoSeed`]) and the full
70/// [`crate::jet_tower::Tower4`] implement the SAME algebra; only the carried
71/// channel set differs. A row loss written once against this interface yields a
72/// different channel set per instantiation, all exact for the channel they serve
73/// (doc §A.0).
74pub trait JetScalar<const K: usize>: Copy {
75    /// A constant: value `c`, every derivative channel zero.
76    fn constant(c: f64) -> Self;
77
78    /// The seeded variable `p_axis` at value `x`: unit first derivative in slot
79    /// `axis`, all higher channels zero. (The nilpotent / cross channels of the
80    /// directional scalars are seeded zero — callers set ε/δ directions through
81    /// the scalar-specific [`OneSeed::seed_direction`] / [`TwoSeed::seed`].)
82    fn variable(x: f64, axis: usize) -> Self;
83
84    /// The value channel `ℓ(p)`.
85    fn value(&self) -> f64;
86
87    /// Exact truncated Leibniz sum `self + o`.
88    fn add(&self, o: &Self) -> Self;
89    /// Exact truncated Leibniz difference `self − o`.
90    fn sub(&self, o: &Self) -> Self;
91    /// Exact truncated Leibniz product `self · o`.
92    fn mul(&self, o: &Self) -> Self;
93    /// Negate every channel.
94    fn neg(&self) -> Self;
95    /// Multiply every channel by a plain scalar `s`.
96    fn scale(&self, s: f64) -> Self;
97
98    /// Exact multivariate Faà di Bruno composition `f ∘ self`, given the outer
99    /// derivative stack `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` at
100    /// `u = self.value()`.
101    ///
102    /// This is the SAME `[f64; 5]` stack shape [`crate::jet_tower::Tower4`] and
103    /// the families' `unary_derivatives_*` helpers (built on erfcx / log_ndtr)
104    /// already produce, so those stacks plug in directly. Each scalar consumes
105    /// only the leading entries its order needs (order-2 reads `d[0..=2]`; the
106    /// directional scalars read one / two beyond their base) — the fixed-length
107    /// array makes that windowing total, no length guard required.
108    fn compose_unary(&self, d: [f64; 5]) -> Self;
109
110    /// Compose with a unary special-function whose derivative STACK is built
111    /// from the scalar base value through `stack_fn` — the generic-over-`Lane`
112    /// seam that lets a single-sourced row program instantiate at BOTH the scalar
113    /// `f64` jets and the SIMD `f64x4` batch towers from ONE expression.
114    ///
115    /// On a scalar jet this evaluates `stack_fn(self.value())` ONCE and forwards
116    /// to [`compose_unary`](Self::compose_unary), so it is BIT-IDENTICAL to the
117    /// hand-written `self.compose_unary(stack_fn(self.value()))` (default body
118    /// below). The lever is that the SAME call shape exists on
119    /// [`crate::jet_tower::Tower3Lane`] / [`crate::jet_tower::Tower4Lane`], where
120    /// the four lanes carry FOUR DISTINCT base values, so the batch
121    /// implementation re-runs `stack_fn` per lane — a thing the old
122    /// `compose_unary(stack_from(self.value()))` shape could not express on a
123    /// batch type (it has no single scalar `.value()`). Writing a row program
124    /// against this method instead of the explicit two-step is what makes it
125    /// instantiate, unchanged, at `f64x4` for the 4-rows-per-pass batch path.
126    fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
127        self.compose_unary(stack_fn(self.value()))
128    }
129
130    /// `e^self`. Convenience for tame arguments (see module stability note).
131    fn exp(&self) -> Self {
132        let e = self.value().exp();
133        self.compose_unary([e, e, e, e, e])
134    }
135
136    /// `√self`. Caller guarantees positivity.
137    fn sqrt(&self) -> Self {
138        let u = self.value();
139        let s = u.sqrt();
140        self.compose_unary([
141            s,
142            0.5 / s,
143            -0.25 / (u * s),
144            0.375 / (u * u * s),
145            -0.9375 / (u * u * u * s),
146        ])
147    }
148
149    /// `ln(self)`. Caller guarantees positivity. Same derivative stack
150    /// [`crate::jet_tower::Tower4::ln`] uses, so any program written over both
151    /// matches term-for-term.
152    fn ln(&self) -> Self {
153        let u = self.value();
154        let r = 1.0 / u;
155        self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
156    }
157
158    /// `1/self`.
159    fn recip(&self) -> Self {
160        let r = 1.0 / self.value();
161        let r2 = r * r;
162        self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
163    }
164
165    /// `self^a` for real exponent `a`. Caller guarantees a positive base.
166    /// Mirrors [`crate::jet_tower::Tower4::powf`] (falling-factorial stack).
167    fn powf(&self, a: f64) -> Self {
168        let u = self.value();
169        self.compose_unary([
170            u.powf(a),
171            a * u.powf(a - 1.0),
172            a * (a - 1.0) * u.powf(a - 2.0),
173            a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
174            a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
175        ])
176    }
177
178    /// `ln Γ(self)`. Caller guarantees a positive argument. Uses the SAME
179    /// hand-certified derivative stack [`crate::jet_tower::Tower4::ln_gamma`]
180    /// consumes ([`crate::jet_tower::ln_gamma_derivative_stack`]), so any
181    /// program written over both matches term-for-term.
182    fn ln_gamma(&self) -> Self {
183        self.compose_unary(crate::jet_tower::ln_gamma_derivative_stack(self.value()))
184    }
185
186    /// `ψ(self) = d/dx ln Γ(x)` (digamma). Caller guarantees a positive
187    /// argument. Same hand-certified stack
188    /// [`crate::jet_tower::digamma_derivative_stack`].
189    fn digamma(&self) -> Self {
190        self.compose_unary(crate::jet_tower::digamma_derivative_stack(self.value()))
191    }
192}
193
194// ── Order2<K> ergonomic operator overloads (doc §A.1) ───────────────────
195//
196// The dispersion-family row NLLs are written with `+`/`-`/`*` operators over
197// the primaries (mirroring how they read as `Tower4` expressions). These
198// delegate channel-for-channel to the inner `Tower2` arithmetic (which has
199// `Add`/`Mul`; `Sub`/`Neg` are expressed as `+ (-1)·rhs` exactly as the
200// `JetScalar::sub` / `JetScalar::neg` impls do), so an `Order2` expression is
201// bit-identical to the same `Tower4` expression's order-≤2 channels.
202
203impl<const K: usize> std::ops::Add for Order2<K> {
204    type Output = Self;
205    #[inline]
206    fn add(self, o: Self) -> Self {
207        Order2(self.0 + o.0)
208    }
209}
210
211impl<const K: usize> std::ops::Add<f64> for Order2<K> {
212    type Output = Self;
213    #[inline]
214    fn add(self, c: f64) -> Self {
215        Order2(self.0 + c)
216    }
217}
218
219impl<const K: usize> std::ops::Sub for Order2<K> {
220    type Output = Self;
221    #[inline]
222    fn sub(self, o: Self) -> Self {
223        Order2(self.0 + o.0.scale(-1.0))
224    }
225}
226
227impl<const K: usize> std::ops::Sub<f64> for Order2<K> {
228    type Output = Self;
229    #[inline]
230    fn sub(self, c: f64) -> Self {
231        Order2(self.0 + (-c))
232    }
233}
234
235impl<const K: usize> std::ops::Mul for Order2<K> {
236    type Output = Self;
237    #[inline]
238    fn mul(self, o: Self) -> Self {
239        Order2(crate::jet_tower::Tower2::mul(&self.0, &o.0))
240    }
241}
242
243impl<const K: usize> std::ops::Mul<f64> for Order2<K> {
244    type Output = Self;
245    #[inline]
246    fn mul(self, c: f64) -> Self {
247        Order2(self.0.scale(c))
248    }
249}
250
251impl<const K: usize> std::ops::Neg for Order2<K> {
252    type Output = Self;
253    #[inline]
254    fn neg(self) -> Self {
255        Order2(self.0.scale(-1.0))
256    }
257}
258
259/// Filtered Hensel lift of a SCALAR implicit state `a(θ)` defined by the
260/// constraint `F(a, θ) = 0`, evaluated in ANY [`JetScalar`] algebra `S` (doc
261/// §11, "A generic implicit-lift operator for every production scalar").
262///
263/// This is the perf-respecting alternative to lifting through a dense
264/// `Tower4<K+1>` (which carries the implicit variable as an extra dense axis):
265/// the state `a` lives directly in the consumer's own `K`-primary algebra
266/// `S` — `Order2<K>` for value/gradient/Hessian, `Tower4<K>` for the full
267/// `t3`/`t4` — never paying for an extra variable.
268///
269/// **Method.** Fixed-Jacobian Newton in the nilpotent algebra. By the
270/// filtered-lift theorem (doc §11.1), if `F_a := ∂F/∂a(a₀, θ₀)` is the primal
271/// Jacobian at the base point and `inv_fa = 1/F_a`, then the iteration
272/// `A ← A − inv_fa · F(A, θ)` raises the filtration degree of the residual by
273/// at least one per step: each step kills exactly one graded layer. Starting
274/// from `A = const(a₀)` (whose residual lies in `F¹` because `θ − θ₀ ∈ 𝔫`),
275/// `iters` equal to the algebra's nilpotency order returns the *exact* lifted
276/// jet (`Order2`: 2, `OneSeed`: 3, `Tower4`/`TwoSeed`: 4). The value channel of
277/// `A` never moves — `F(A, θ).value() = F(a₀, θ₀) = 0` at the certified root —
278/// so a caller may precompute every primitive's derivative stack at the fixed
279/// base index once and let the cheap polynomial composition repeat per step.
280///
281/// `f` evaluates the constraint `F(a, θ)` in `S` (capturing the seeded
282/// parameter jets `θ`); `a0` is the certified scalar root `F(a₀, θ₀) ≈ 0`.
283pub fn filtered_implicit_solve_scalar<const K: usize, S: JetScalar<K>>(
284    a0: f64,
285    inv_fa: f64,
286    iters: usize,
287    f: impl Fn(&S) -> S,
288) -> S {
289    let mut a = S::constant(a0);
290    for _ in 0..iters {
291        let residual = f(&a);
292        a = a.sub(&residual.scale(inv_fa));
293    }
294    a
295}
296
297// ── Order2<K>: value / gradient / Hessian (doc §A.1) ────────────────────
298
299/// Truncated SECOND-order scalar: value `v`, gradient `g_a`, Hessian `H_{ab}`.
300///
301/// This is a thin newtype over the production [`crate::jet_tower::Tower2`], so
302/// its `(v, g, H)` channels are obtained by the SAME formulas — and are
303/// therefore bit-identical to both [`crate::jet_tower::Tower2`] and the order-≤2
304/// channels of a full [`crate::jet_tower::Tower4`] (doc §A.1, "Bit-identity with
305/// the full tower"). The wrapper exists only to satisfy the generic
306/// [`JetScalar`] interface (the `compose_unary` / `add` / `sub` / `neg` /
307/// `recip` the trait demands, which `Tower2` does not expose by that shape) —
308/// every channel is delegated to `Tower2` arithmetic unchanged.
309#[derive(Clone, Copy, Debug)]
310pub struct Order2<const K: usize>(pub crate::jet_tower::Tower2<K>);
311
312impl<const K: usize> Order2<K> {
313    /// Read the gradient channel `g_a = ∂ℓ/∂p_a`.
314    #[inline]
315    pub fn g(&self) -> [f64; K] {
316        self.0.g
317    }
318
319    /// Read the Hessian channel.
320    #[inline]
321    pub fn h(&self) -> [[f64; K]; K] {
322        self.0.h
323    }
324}
325
326impl<const K: usize> JetScalar<K> for Order2<K> {
327    fn constant(c: f64) -> Self {
328        Order2(crate::jet_tower::Tower2::constant(c))
329    }
330    fn variable(x: f64, axis: usize) -> Self {
331        Order2(crate::jet_tower::Tower2::variable(x, axis))
332    }
333    fn value(&self) -> f64 {
334        self.0.v
335    }
336    fn add(&self, o: &Self) -> Self {
337        Order2(self.0 + o.0)
338    }
339    fn sub(&self, o: &Self) -> Self {
340        // Tower2 has no Sub op; subtract by adding the negation, matching
341        // Tower4::sub (self + o.scale(-1.0)).
342        Order2(self.0 + o.0.scale(-1.0))
343    }
344    fn mul(&self, o: &Self) -> Self {
345        Order2(crate::jet_tower::Tower2::mul(&self.0, &o.0))
346    }
347    fn neg(&self) -> Self {
348        Order2(self.0.scale(-1.0))
349    }
350    fn scale(&self, s: f64) -> Self {
351        Order2(self.0.scale(s))
352    }
353    fn compose_unary(&self, d: [f64; 5]) -> Self {
354        // Order-≤2 reads only [f, f', f''] of the stack.
355        Order2(self.0.compose_unary([d[0], d[1], d[2]]))
356    }
357}
358
359// ── Lane-batched Order-2 scalar: 4 rows per pass in SIMD lanes (perf) ────
360//
361// The hot per-row jet kernels evaluate ONE row's `(v, g, H)` tower at a time in
362// scalar `f64`. A hand-written scalar derivative does the same. The throughput
363// lever a jet has that scalar hand-code cannot is **row batching in SIMD
364// lanes**: the order-≤2 Leibniz product `Order2::mul` is `O(K²)` independent
365// per-channel float ops, and EVERY row runs the identical op graph on different
366// data — the textbook SPMD shape. Packing `LANES = 4` rows into a `wide::f64x4`
367// and running the algebra once per 4 rows replaces 4 scalar passes with one
368// vector pass: the `K²` Hessian channel updates become `K²` NEON `.2d` / SSE2
369// `pd` instructions covering 4 rows each, ~4× fewer FP instructions per row.
370//
371// The carried scalar field is abstracted by [`Lane`] so the SAME algebra body
372// instantiates at `f64` (1 row, used as the bit-identity oracle) or
373// [`wide::f64x4`] (4 rows). Bit-identity is structural, not approximate:
374//
375//   * Every arithmetic op is a plain lane-wise `+` / `-` / `*` (NEVER a fused
376//     `mul_add`), and IEEE-754 double `+`/`-`/`*`/`/` are correctly rounded and
377//     deterministic, so lane `i` of an `f64x4` op equals the scalar `f64` op on
378//     that lane's inputs bit-for-bit.
379//   * The transcendental derivative STACKS (`exp`/`ln`/`sqrt`/…) are produced
380//     **per lane by the identical scalar code** ([`Lane::unary3`] unpacks, runs
381//     the same `[f64; 3]` stack closure the scalar path runs, repacks), so the
382//     only thing vectorised is the cheap rational tensor composition — the
383//     library transcendental itself is the exact same `f64::exp` call per lane.
384//   * The op order mirrors [`crate::jet_tower::Tower2`] term-for-term, so
385//     [`Order2Lane<f64, K>`] is `to_bits`-identical to the production
386//     [`Order2<K>`] (= `Tower2<K>`), and [`Order2Lane<f64x4, K>`] lane `i` is
387//     `to_bits`-identical to that — proven by the `batch_tests` oracle below
388//     (≥2000 random 4-row batches across `K ∈ {2,3,4,9}`).
389
/// The scalar field an [`Order2Lane`] carries: either a single `f64` (one row,
/// the oracle) or a [`wide::f64x4`] (four rows evaluated in SIMD lanes). All
/// ops are plain lane-wise IEEE arithmetic, so a vector op equals the scalar
/// op on each lane bit-for-bit.
///
/// # Implementing
///
/// The required methods are [`splat`](Lane::splat), the three arithmetic ops,
/// [`lane`](Lane::lane) and [`unary_with`](Lane::unary_with).
/// [`unary3`](Lane::unary3) and [`unary5`](Lane::unary5) are the `N = 3` /
/// `N = 5` cases of `unary_with` and default to it; an implementation may
/// still override them with a hand-specialised body.
pub trait Lane: Copy {
    /// Broadcast a scalar to every lane.
    fn splat(x: f64) -> Self;
    /// Lane-wise `self + o`.
    fn add(self, o: Self) -> Self;
    /// Lane-wise `self - o`.
    fn sub(self, o: Self) -> Self;
    /// Lane-wise `self * o`.
    fn mul(self, o: Self) -> Self;
    /// The `f64` in lane `i` (`i < LANES`; `f64` ignores `i`).
    fn lane(self, i: usize) -> f64;

    /// Build the order-≤2 derivative stack `[f(u), f′(u), f″(u)]` **per lane**
    /// from the lane value `u`, via the SAME scalar `stack` closure the
    /// per-row path runs, so the transcendental/rational stack matches the
    /// scalar evaluation and only the subsequent tensor composition is
    /// vectorised.
    ///
    /// Defaults to [`unary_with`](Lane::unary_with) at `N = 3`.
    #[inline]
    fn unary3(self, stack: impl Fn(f64) -> [f64; 3]) -> [Self; 3] {
        self.unary_with(stack)
    }

    /// Build the order-≤4 derivative stack `[f, f′, f″, f‴, f⁗]` **per lane**
    /// from the lane value `u`, via the SAME scalar `stack` closure the
    /// per-row path runs. The one-/two-seed scalars ([`OneSeedLane`] /
    /// [`TwoSeedLane`]) need outer derivatives one / two orders beyond their
    /// order-2 base, so they build their composition stack through this
    /// five-entry variant.
    ///
    /// Defaults to [`unary_with`](Lane::unary_with) at `N = 5`.
    #[inline]
    fn unary5(self, stack: impl Fn(f64) -> [f64; 5]) -> [Self; 5] {
        self.unary_with(stack)
    }

    /// The general-`N` form of [`unary3`](Lane::unary3) /
    /// [`unary5`](Lane::unary5): evaluate `stack` once per lane at that lane's
    /// OWN base value, then pack the `N` resulting columns lane-wise, so lane
    /// `i` of column `k` is `stack(value_i)[k]`. This is the primitive the
    /// compose-with-stack seam
    /// ([`crate::jet_tower::Tower4Lane::compose_unary_with`] and its `Tower3`
    /// sibling) routes through. With `N = 3` / `N = 5` it is
    /// `to_bits`-identical to [`unary3`](Lane::unary3) /
    /// [`unary5`](Lane::unary5).
    fn unary_with<const N: usize>(self, stack: impl Fn(f64) -> [f64; N]) -> [Self; N];
}
433
434impl Lane for f64 {
435    #[inline]
436    fn splat(x: f64) -> Self {
437        x
438    }
439    #[inline]
440    fn add(self, o: Self) -> Self {
441        self + o
442    }
443    #[inline]
444    fn sub(self, o: Self) -> Self {
445        self - o
446    }
447    #[inline]
448    fn mul(self, o: Self) -> Self {
449        self * o
450    }
451    #[inline]
452    fn lane(self, _: usize) -> f64 {
453        self
454    }
455    #[inline]
456    fn unary3(self, stack: impl Fn(f64) -> [f64; 3]) -> [Self; 3] {
457        stack(self)
458    }
459    #[inline]
460    fn unary5(self, stack: impl Fn(f64) -> [f64; 5]) -> [Self; 5] {
461        stack(self)
462    }
463    #[inline]
464    fn unary_with<const N: usize>(self, stack: impl Fn(f64) -> [f64; N]) -> [Self; N] {
465        // One row: the packed result IS the scalar stack ([Self; N] = [f64; N]).
466        stack(self)
467    }
468}
469
470impl Lane for wide::f64x4 {
471    #[inline]
472    fn splat(x: f64) -> Self {
473        wide::f64x4::splat(x)
474    }
475    #[inline]
476    fn add(self, o: Self) -> Self {
477        self + o
478    }
479    #[inline]
480    fn sub(self, o: Self) -> Self {
481        self - o
482    }
483    #[inline]
484    fn mul(self, o: Self) -> Self {
485        self * o
486    }
487    #[inline]
488    fn lane(self, i: usize) -> f64 {
489        self.to_array()[i]
490    }
491    #[inline]
492    fn unary3(self, stack: impl Fn(f64) -> [f64; 3]) -> [Self; 3] {
493        let a = self.to_array();
494        let mut d0 = [0.0_f64; 4];
495        let mut d1 = [0.0_f64; 4];
496        let mut d2 = [0.0_f64; 4];
497        for i in 0..4 {
498            let s = stack(a[i]);
499            d0[i] = s[0];
500            d1[i] = s[1];
501            d2[i] = s[2];
502        }
503        [
504            wide::f64x4::new(d0),
505            wide::f64x4::new(d1),
506            wide::f64x4::new(d2),
507        ]
508    }
509    #[inline]
510    fn unary5(self, stack: impl Fn(f64) -> [f64; 5]) -> [Self; 5] {
511        let a = self.to_array();
512        let mut d = [[0.0_f64; 4]; 5];
513        for i in 0..4 {
514            let s = stack(a[i]);
515            for (k, dk) in d.iter_mut().enumerate() {
516                dk[i] = s[k];
517            }
518        }
519        [
520            wide::f64x4::new(d[0]),
521            wide::f64x4::new(d[1]),
522            wide::f64x4::new(d[2]),
523            wide::f64x4::new(d[3]),
524            wide::f64x4::new(d[4]),
525        ]
526    }
527    #[inline]
528    fn unary_with<const N: usize>(self, stack: impl Fn(f64) -> [f64; N]) -> [Self; N] {
529        // Evaluate the scalar stack PER LANE at that lane's own base value, then
530        // pack the N derivative columns lane-wise (the same shape `unary5` uses,
531        // generalised to N). Lane `i` of column `k` is `stack(base_i)[k]`.
532        let a = self.to_array();
533        let mut cols = [[0.0_f64; 4]; N];
534        for (i, &base) in a.iter().enumerate() {
535            let s = stack(base);
536            for (k, sk) in s.iter().enumerate() {
537                cols[k][i] = *sk;
538            }
539        }
540        std::array::from_fn(|k| wide::f64x4::new(cols[k]))
541    }
542}
543
/// A lane-batched order-≤2 Taylor scalar: value, gradient and Hessian, each
/// stored in a SIMD field [`L: Lane`](Lane). With `L = f64x4` one instance
/// carries FOUR rows at once, so the row loop processes 4 rows per vector pass
/// instead of one per scalar pass; with `L = f64` it carries a single row.
///
/// The channel layout and every float op mirror [`crate::jet_tower::Tower2`]
/// term-for-term, so `Order2Lane<f64, K>` is `to_bits`-identical to the
/// production [`Order2<K>`] and `Order2Lane<f64x4, K>` lane `i` is
/// `to_bits`-identical to that (see the module note and `batch_tests`).
#[derive(Clone, Copy, Debug)]
pub struct Order2Lane<L: Lane, const K: usize> {
    /// Value channel `ℓ` (one entry per lane/row).
    pub v: L,
    /// Gradient channel `∂ℓ/∂p_a`, one lane value per primary `a`.
    pub g: [L; K],
    /// Hessian channel `∂²ℓ/∂p_a∂p_b` (symmetric), stored as a full `K × K`
    /// array of lane values.
    pub h: [[L; K]; K],
}
562
/// The 4-rows-per-pass batched order-≤2 scalar: [`Order2Lane`] instantiated
/// with `wide::f64x4` lanes.
pub type Order2Batch<const K: usize> = Order2Lane<wide::f64x4, K>;
565
566impl<L: Lane, const K: usize> Order2Lane<L, K> {
567    /// A constant: value `c` in every channel-zero slot.
568    #[inline]
569    pub fn constant(c: L) -> Self {
570        Order2Lane {
571            v: c,
572            g: [L::splat(0.0); K],
573            h: [[L::splat(0.0); K]; K],
574        }
575    }
576
577    /// The seeded variable `p_axis` at (per-lane) value `value`: unit first
578    /// derivative in slot `axis`. With `L = f64x4`, `value` packs the four
579    /// rows' values of primary `axis`.
580    #[inline]
581    pub fn variable(value: L, axis: usize) -> Self {
582        let mut out = Self::constant(value);
583        out.g[axis] = L::splat(1.0);
584        out
585    }
586
587    /// Lane-wise `self + o` (mirrors `Tower2` Add: per-channel add).
588    #[inline]
589    pub fn add(&self, o: &Self) -> Self {
590        let mut out = *self;
591        out.v = self.v.add(o.v);
592        for i in 0..K {
593            out.g[i] = self.g[i].add(o.g[i]);
594            for j in 0..K {
595                out.h[i][j] = self.h[i][j].add(o.h[i][j]);
596            }
597        }
598        out
599    }
600
601    /// Multiply every channel by the plain scalar `s` (mirrors `Tower2::scale`).
602    #[inline]
603    pub fn scale(&self, s: f64) -> Self {
604        let sl = L::splat(s);
605        let mut out = *self;
606        out.v = self.v.mul(sl);
607        for i in 0..K {
608            out.g[i] = self.g[i].mul(sl);
609            for j in 0..K {
610                out.h[i][j] = self.h[i][j].mul(sl);
611            }
612        }
613        out
614    }
615
616    /// Lane-wise `self - o`, expressed as `self + o·(-1)` exactly as
617    /// [`Order2::sub`] / `Tower4::sub` do, so signed-zero handling matches.
618    #[inline]
619    pub fn sub(&self, o: &Self) -> Self {
620        self.add(&o.scale(-1.0))
621    }
622
623    /// Negate every channel (= `scale(-1.0)`, matching [`Order2::neg`]).
624    #[inline]
625    pub fn neg(&self) -> Self {
626        self.scale(-1.0)
627    }
628
629    /// Exact order-≤2 Leibniz product, term-for-term identical to
630    /// [`crate::jet_tower::Tower2::mul`] (same factor order, no `mul_add`).
631    ///
632    /// The Hessian channel is symmetric under `i ↔ j` (see
633    /// [`crate::jet_tower::Tower2::mul`] for why the invariant always holds), so
634    /// we compute the upper triangle `j ≥ i` and mirror it — `K(K+1)/2` lane
635    /// entry-chains instead of `K²`. Because each lane entry is already a full
636    /// SIMD op (no cross-`j` lane packing to lose), halving the entry count is a
637    /// direct throughput win (~18 % on `Order2Lane<f64x4, 9>`, the survival batch
638    /// kernel, and ~2× on the `f64` oracle). The upper triangle uses the EXACT
639    /// term order of `Tower2::mul`, so `Order2Lane<f64>` stays `to_bits`-identical
640    /// to `Order2` (= `Tower2`) and `Order2Lane<f64x4>` lane `i` stays
641    /// `to_bits`-identical to that; the mirror makes the batch Hessian exactly
642    /// symmetric, matching the scalar `Tower2::mul` (which mirrors identically).
643    #[inline]
644    pub fn mul(&self, o: &Self) -> Self {
645        let a = self;
646        let b = o;
647        let mut out = Self::constant(a.v.mul(b.v));
648        for i in 0..K {
649            // a.v*b.g[i] + a.g[i]*b.v
650            out.g[i] = a.v.mul(b.g[i]).add(a.g[i].mul(b.v));
651        }
652        for i in 0..K {
653            for j in i..K {
654                // a.v*b.h + a.g[i]*b.g[j] + a.g[j]*b.g[i] + a.h*b.v
655                let hij =
656                    a.v.mul(b.h[i][j])
657                        .add(a.g[i].mul(b.g[j]))
658                        .add(a.g[j].mul(b.g[i]))
659                        .add(a.h[i][j].mul(b.v));
660                out.h[i][j] = hij;
661                out.h[j][i] = hij;
662            }
663        }
664        out
665    }
666
667    /// Exact order-≤2 Faà di Bruno composition `f ∘ self`, given the per-lane
668    /// derivative stack `d = [f(u), f′(u), f″(u)]`. Mirrors
669    /// [`crate::jet_tower::Tower2::compose_unary`] term-for-term (`acc` starts at
670    /// `0` then accumulates, so signed-zero collapses identically).
671    #[inline]
672    pub fn compose_unary(&self, d: [L; 3]) -> Self {
673        let mut out = Self::constant(d[0]);
674        for i in 0..K {
675            let mut acc = L::splat(0.0);
676            acc = acc.add(d[1].mul(self.g[i]));
677            out.g[i] = acc;
678        }
679        for i in 0..K {
680            for j in 0..K {
681                let mut acc = L::splat(0.0);
682                acc = acc.add(d[1].mul(self.h[i][j]));
683                acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
684                out.h[i][j] = acc;
685            }
686        }
687        out
688    }
689
690    /// `e^self`, per-lane stack `[e, e, e]` (matches the [`JetScalar::exp`]
691    /// default forwarded through `Order2`).
692    #[inline]
693    pub fn exp(&self) -> Self {
694        let d = self.v.unary3(|u| {
695            let e = u.exp();
696            [e, e, e]
697        });
698        self.compose_unary(d)
699    }
700
701    /// `ln(self)`; caller guarantees positivity. Per-lane stack
702    /// `[ln u, 1/u, -1/u²]` (matches [`JetScalar::ln`] truncated to order 2).
703    #[inline]
704    pub fn ln(&self) -> Self {
705        let d = self.v.unary3(|u| {
706            let r = 1.0 / u;
707            [u.ln(), r, -r * r]
708        });
709        self.compose_unary(d)
710    }
711
712    /// `√self`; caller guarantees positivity. Per-lane stack
713    /// `[s, 0.5/s, -0.25/(u·s)]` (matches [`JetScalar::sqrt`]).
714    #[inline]
715    pub fn sqrt(&self) -> Self {
716        let d = self.v.unary3(|u| {
717            let s = u.sqrt();
718            [s, 0.5 / s, -0.25 / (u * s)]
719        });
720        self.compose_unary(d)
721    }
722
723    /// `1/self`. Per-lane stack `[r, -r², 2r³]` (matches [`JetScalar::recip`]).
724    #[inline]
725    pub fn recip(&self) -> Self {
726        let d = self.v.unary3(|u| {
727            let r = 1.0 / u;
728            let r2 = r * r;
729            [r, -r2, 2.0 * r2 * r]
730        });
731        self.compose_unary(d)
732    }
733
734    /// `self^a` for real `a`; caller guarantees a positive base. Per-lane
735    /// falling-factorial stack (matches [`JetScalar::powf`]).
736    #[inline]
737    pub fn powf(&self, a: f64) -> Self {
738        let d = self.v.unary3(|u| {
739            [
740                u.powf(a),
741                a * u.powf(a - 1.0),
742                a * (a - 1.0) * u.powf(a - 2.0),
743            ]
744        });
745        self.compose_unary(d)
746    }
747}
748
749impl<const K: usize> Order2Batch<K> {
750    /// Extract lane `i`'s `(v, g, H)` as a production [`Order2<K>`] scalar.
751    /// Lane `i` is `to_bits`-identical to evaluating the same program at
752    /// [`Order2<K>`] on row `i` (see `batch_tests`).
753    #[inline]
754    #[must_use]
755    pub fn lane(&self, i: usize) -> Order2<K> {
756        let mut t = crate::jet_tower::Tower2::<K>::constant(self.v.lane(i));
757        for a in 0..K {
758            t.g[a] = self.g[a].lane(i);
759            for b in 0..K {
760                t.h[a][b] = self.h[a][b].lane(i);
761            }
762        }
763        Order2(t)
764    }
765}
766
767// ── Order1<K>: value / gradient only (doc §A.1, first-order prune) ──────
768
/// First-order truncated scalar: carries the value `v` and the gradient `g_a`,
/// and NO Hessian.
///
/// Conceptually this is [`Order2`] with its K×K Hessian channel removed. The
/// value and gradient follow the SAME order-≤1 truncation of the Leibniz / Faà
/// di Bruno rules that [`Order2`] applies to those two channels, with the float
/// operations issued in the identical order, so `(v, g)` is BIT-IDENTICAL to
/// the order-≤1 channels of both [`Order2`] and a full
/// [`crate::jet_tower::Tower4`].
///
/// Reach for it at a consumer that reads ONLY value + gradient (the SAE
/// β-border channel: the reconstruction is linear in β, so the Hessian-in-β
/// vanishes and the dense K×K Hessian product `Tower2::mul` would build is pure
/// discarded work). Order-≤1 value/gradient never read any input's Hessian, so
/// dropping the channel changes neither the result nor the float-op order; it
/// only removes the `K²` arithmetic that produced an unread tensor.
#[derive(Clone, Copy, Debug)]
pub struct Order1<const K: usize> {
    /// The value ℓ.
    pub v: f64,
    /// The gradient, `g[a] = ∂ℓ/∂p_a`.
    pub g: [f64; K],
}
789
790impl<const K: usize> Order1<K> {
791    /// Read the gradient channel `g_a = ∂ℓ/∂p_a`.
792    #[inline]
793    pub fn g(&self) -> [f64; K] {
794        self.g
795    }
796}
797
798impl<const K: usize> JetScalar<K> for Order1<K> {
799    fn constant(c: f64) -> Self {
800        // Order2::constant -> Tower2::constant: value c, all derivatives zero.
801        Order1 { v: c, g: [0.0; K] }
802    }
803    fn variable(x: f64, axis: usize) -> Self {
804        // Order2::variable -> Tower2::variable: unit first derivative in `axis`.
805        let mut g = [0.0; K];
806        g[axis] = 1.0;
807        Order1 { v: x, g }
808    }
809    fn value(&self) -> f64 {
810        self.v
811    }
812    fn add(&self, o: &Self) -> Self {
813        // Tower2 Add: out.v += o.v; out.g[i] += o.g[i] (same float order).
814        let mut g = self.g;
815        for i in 0..K {
816            g[i] += o.g[i];
817        }
818        Order1 { v: self.v + o.v, g }
819    }
820    fn sub(&self, o: &Self) -> Self {
821        // Mirror Order2::sub == self + o.scale(-1.0) exactly: scale then add.
822        self.add(&o.scale(-1.0))
823    }
824    fn mul(&self, o: &Self) -> Self {
825        // Tower2::mul value/grad terms, identical float order:
826        //   v = a.v*b.v;  g[i] = a.v*b.g[i] + a.g[i]*b.v.
827        // (The Hessian loop `a.v*b.h + a.g*b.g + ... + a.h*b.v` is the discarded
828        //  work this type exists to skip; it never feeds v or g.)
829        let a = self;
830        let b = o;
831        let mut g = [0.0; K];
832        for i in 0..K {
833            g[i] = a.v * b.g[i] + a.g[i] * b.v;
834        }
835        Order1 { v: a.v * b.v, g }
836    }
837    fn neg(&self) -> Self {
838        // Order2::neg == self.0.scale(-1.0).
839        self.scale(-1.0)
840    }
841    fn scale(&self, s: f64) -> Self {
842        // Tower2::scale: out.v *= s; out.g[i] *= s (same float order).
843        let mut g = self.g;
844        for i in 0..K {
845            g[i] *= s;
846        }
847        Order1 { v: self.v * s, g }
848    }
849    fn compose_unary(&self, d: [f64; 5]) -> Self {
850        // Faà di Bruno truncated to order ≤ 1 (matches `faa_di_bruno` /
851        // `Tower2::compose_unary` for the value and gradient channels):
852        //   value channel (m=0): d[0].
853        //   grad channel (positions=[i], single partition {{0}}): d[1]·g[i].
854        // Order-≤1 reads only d[0], d[1]; trailing stack entries are unused.
855        let mut g = [0.0; K];
856        for i in 0..K {
857            g[i] = d[1] * self.g[i];
858        }
859        Order1 { v: d[0], g }
860    }
861}
862
863// ── OneSeed<K>: one-seed directional, contracted third (doc §A.2) ───────
864
865/// One-seed directional scalar: an [`Order2`] base plus ONE nilpotent ε
866/// (`ε² = 0`) whose coefficient is itself an [`Order2`].
867///
868/// A scalar is `s = base + ε·eps`. Arithmetic is the `ε² = 0` truncation of the
869/// product (doc §A.2): the base parts multiply as ordinary [`Order2`] products,
870/// and the ε-coefficient picks up `a.base·b.eps + a.eps·b.base`. Composition
871/// pushes ε through one extra outer derivative.
872///
873/// Seed each primary with [`seed_direction`](Self::seed_direction): the base is
874/// the usual seeded variable (carrying `e_a` for the Hessian channel) and the
875/// ε-coefficient is the FIXED contraction direction `u_a` (a constant). Then the
876/// ε-component of the evaluated Hessian channel is the contracted third
877/// `[eps.h][a][b] = Σ_c ℓ_{abc} u_c` — exactly `row_third_contracted(dir = u)`,
878/// without materialising `t3`.
879#[derive(Clone, Copy, Debug)]
880pub struct OneSeed<const K: usize> {
881    /// The `ε⁰` part: value / gradient / Hessian of `ℓ`.
882    pub base: Order2<K>,
883    /// The `ε¹` part: value / gradient / Hessian of the ε-coefficient. After a
884    /// `seed_direction(u)` evaluation, `eps.h[a][b] = Σ_c ℓ_{abc} u_c`.
885    pub eps: Order2<K>,
886}
887
888impl<const K: usize> OneSeed<K> {
889    /// Seed primary `axis` at value `x` with ε-direction component `u_axis`:
890    /// `p_axis = p_axis⁰ + x-seed + ε·u_axis`, i.e. base = `variable(x, axis)`
891    /// and eps = `constant(u_axis)` (doc §A.2 "Seeding").
892    pub fn seed_direction(x: f64, axis: usize, u_axis: f64) -> Self {
893        OneSeed {
894            base: Order2::variable(x, axis),
895            eps: Order2::constant(u_axis),
896        }
897    }
898
899    /// The contracted-third channel after a `seed_direction(u)` evaluation:
900    /// `out[a][b] = Σ_c ℓ_{abc} u_c`, i.e. the ε-coefficient's Hessian (doc §A.2).
901    pub fn contracted_third(&self) -> [[f64; K]; K] {
902        self.eps.h()
903    }
904}
905
906impl<const K: usize> JetScalar<K> for OneSeed<K> {
907    fn constant(c: f64) -> Self {
908        OneSeed {
909            base: Order2::constant(c),
910            eps: Order2::constant(0.0),
911        }
912    }
913    fn variable(x: f64, axis: usize) -> Self {
914        // No ε-direction unless seeded via `seed_direction`.
915        OneSeed {
916            base: Order2::variable(x, axis),
917            eps: Order2::constant(0.0),
918        }
919    }
920    fn value(&self) -> f64 {
921        self.base.value()
922    }
923    fn add(&self, o: &Self) -> Self {
924        OneSeed {
925            base: self.base.add(&o.base),
926            eps: self.eps.add(&o.eps),
927        }
928    }
929    fn sub(&self, o: &Self) -> Self {
930        OneSeed {
931            base: self.base.sub(&o.base),
932            eps: self.eps.sub(&o.eps),
933        }
934    }
935    fn mul(&self, o: &Self) -> Self {
936        // (a.base + ε a.eps)(b.base + ε b.eps), dropping ε².
937        OneSeed {
938            base: self.base.mul(&o.base),
939            eps: self.base.mul(&o.eps).add(&self.eps.mul(&o.base)),
940        }
941    }
942    fn neg(&self) -> Self {
943        OneSeed {
944            base: self.base.neg(),
945            eps: self.eps.neg(),
946        }
947    }
948    fn scale(&self, s: f64) -> Self {
949        OneSeed {
950            base: self.base.scale(s),
951            eps: self.eps.scale(s),
952        }
953    }
954    fn compose_unary(&self, d: [f64; 5]) -> Self {
955        // f(base + ε eps) = f(base) + ε · f'(base)·eps  (ε² = 0). Each factor is
956        // an Order2 composition: the base composes with the f-stack, and the
957        // ε-coefficient is the Order2 of the SHIFTED stack (the chain rule
958        // `f'(base)` as an Order2) times eps. Order2 reads only the leading
959        // three entries of whatever stack it is handed, so the trailing slots
960        // are unused padding (the fixed-length array makes the windowing total).
961        let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
962        // f'(base) as an Order2 (consumes [f', f'', f''']).
963        let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]);
964        let eps = fprime.mul(&self.eps);
965        OneSeed { base, eps }
966    }
967}
968
969// ── OneSeedLane<L, K>: lane-batched one-seed directional (doc §A.2) ──────
970
971/// Lane-batched [`OneSeed`]: the same one-seed directional scalar with its two
972/// [`Order2`] parts re-typed to [`Order2Lane<L, K>`], so one `L = f64x4`
973/// instance carries FOUR rows' contracted-third evaluations per vector pass.
974///
975/// Every operation (`add`/`sub`/`mul`/`neg`/`scale`/`compose_unary` and the
976/// transcendentals) is a term-for-term structural re-type of the scalar
977/// [`OneSeed`] ops onto the lane-implemented [`Order2Lane`] algebra. With
978/// `L = f64`, `OneSeedLane<f64, K>` is `to_bits`-identical to [`OneSeed<K>`];
979/// with `L = f64x4`, lane `i` is `to_bits`-identical to that (see `batch_tests`).
980#[derive(Clone, Copy, Debug)]
981pub struct OneSeedLane<L: Lane, const K: usize> {
982    /// The `ε⁰` part (lane-batched value / gradient / Hessian of `ℓ`).
983    pub base: Order2Lane<L, K>,
984    /// The `ε¹` part. After a `seed_direction(u)` evaluation,
985    /// `eps.h[a][b]` lane `i` is row `i`'s `Σ_c ℓ_{abc} u_c`.
986    pub eps: Order2Lane<L, K>,
987}
988
989/// The 4-rows-per-pass batched one-seed scalar (`wide::f64x4` lanes).
990pub type OneSeedBatch<const K: usize> = OneSeedLane<wide::f64x4, K>;
991
992impl<L: Lane, const K: usize> OneSeedLane<L, K> {
993    /// A constant: base = `constant(c)`, ε-part zero (mirrors [`OneSeed::constant`]).
994    #[inline]
995    pub fn constant(c: L) -> Self {
996        OneSeedLane {
997            base: Order2Lane::constant(c),
998            eps: Order2Lane::constant(L::splat(0.0)),
999        }
1000    }
1001
1002    /// The seeded variable `p_axis` at (per-lane) value `value`, no ε-direction
1003    /// (mirrors [`OneSeed::variable`]).
1004    #[inline]
1005    pub fn variable(value: L, axis: usize) -> Self {
1006        OneSeedLane {
1007            base: Order2Lane::variable(value, axis),
1008            eps: Order2Lane::constant(L::splat(0.0)),
1009        }
1010    }
1011
1012    /// Seed primary `axis` at (per-lane) value `value` with ε-direction
1013    /// component `u_axis`: base = `variable(value, axis)`, eps = `constant(u_axis)`
1014    /// (mirrors [`OneSeed::seed_direction`]). With `L = f64x4`, `value` / `u_axis`
1015    /// pack the four rows' values / directions of primary `axis`.
1016    #[inline]
1017    pub fn seed_direction(value: L, axis: usize, u_axis: L) -> Self {
1018        OneSeedLane {
1019            base: Order2Lane::variable(value, axis),
1020            eps: Order2Lane::constant(u_axis),
1021        }
1022    }
1023
1024    /// The contracted-third channel after a `seed_direction(u)` evaluation:
1025    /// `out[a][b]` lane `i` is row `i`'s `Σ_c ℓ_{abc} u_c` (the ε-part Hessian).
1026    #[inline]
1027    #[must_use]
1028    pub fn contracted_third(&self) -> [[L; K]; K] {
1029        self.eps.h
1030    }
1031
1032    /// Lane-wise `self + o` (mirrors [`OneSeed::add`]).
1033    #[inline]
1034    pub fn add(&self, o: &Self) -> Self {
1035        OneSeedLane {
1036            base: self.base.add(&o.base),
1037            eps: self.eps.add(&o.eps),
1038        }
1039    }
1040
1041    /// Lane-wise `self - o` (mirrors [`OneSeed::sub`]).
1042    #[inline]
1043    pub fn sub(&self, o: &Self) -> Self {
1044        OneSeedLane {
1045            base: self.base.sub(&o.base),
1046            eps: self.eps.sub(&o.eps),
1047        }
1048    }
1049
1050    /// Lane-wise `self · o`, ε² = 0 truncation (mirrors [`OneSeed::mul`]).
1051    #[inline]
1052    pub fn mul(&self, o: &Self) -> Self {
1053        OneSeedLane {
1054            base: self.base.mul(&o.base),
1055            eps: self.base.mul(&o.eps).add(&self.eps.mul(&o.base)),
1056        }
1057    }
1058
1059    /// Negate every part (mirrors [`OneSeed::neg`]).
1060    #[inline]
1061    pub fn neg(&self) -> Self {
1062        OneSeedLane {
1063            base: self.base.neg(),
1064            eps: self.eps.neg(),
1065        }
1066    }
1067
1068    /// Multiply every part by the plain scalar `s` (mirrors [`OneSeed::scale`]).
1069    #[inline]
1070    pub fn scale(&self, s: f64) -> Self {
1071        OneSeedLane {
1072            base: self.base.scale(s),
1073            eps: self.eps.scale(s),
1074        }
1075    }
1076
1077    /// Exact order-≤2-per-part Faà di Bruno composition `f ∘ self`, given the
1078    /// per-lane outer-derivative stack `d = [f, f′, f″, f‴, f⁗]`. Term-for-term
1079    /// identical to [`OneSeed::compose_unary`]: the base reads `d[0..=2]` and the
1080    /// ε-coefficient is `f′(base)` (reads `d[1..=3]`) times `eps`.
1081    #[inline]
1082    pub fn compose_unary(&self, d: [L; 5]) -> Self {
1083        let base = self.base.compose_unary([d[0], d[1], d[2]]);
1084        let fprime = self.base.compose_unary([d[1], d[2], d[3]]);
1085        let eps = fprime.mul(&self.eps);
1086        OneSeedLane { base, eps }
1087    }
1088
1089    /// `e^self`, per-lane stack `[e, e, e, e, e]` (matches [`JetScalar::exp`]).
1090    #[inline]
1091    pub fn exp(&self) -> Self {
1092        let d = self.base.v.unary5(|u| {
1093            let e = u.exp();
1094            [e, e, e, e, e]
1095        });
1096        self.compose_unary(d)
1097    }
1098
1099    /// `ln(self)`; caller guarantees positivity (matches [`JetScalar::ln`]).
1100    #[inline]
1101    pub fn ln(&self) -> Self {
1102        let d = self.base.v.unary5(|u| {
1103            let r = 1.0 / u;
1104            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
1105        });
1106        self.compose_unary(d)
1107    }
1108
1109    /// `√self`; caller guarantees positivity (matches [`JetScalar::sqrt`]).
1110    #[inline]
1111    pub fn sqrt(&self) -> Self {
1112        let d = self.base.v.unary5(|u| {
1113            let s = u.sqrt();
1114            [
1115                s,
1116                0.5 / s,
1117                -0.25 / (u * s),
1118                0.375 / (u * u * s),
1119                -0.9375 / (u * u * u * s),
1120            ]
1121        });
1122        self.compose_unary(d)
1123    }
1124
1125    /// `1/self` (matches [`JetScalar::recip`]).
1126    #[inline]
1127    pub fn recip(&self) -> Self {
1128        let d = self.base.v.unary5(|u| {
1129            let r = 1.0 / u;
1130            let r2 = r * r;
1131            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
1132        });
1133        self.compose_unary(d)
1134    }
1135
1136    /// `self^a` for real `a`; caller guarantees a positive base (matches
1137    /// [`JetScalar::powf`]).
1138    #[inline]
1139    pub fn powf(&self, a: f64) -> Self {
1140        let d = self.base.v.unary5(|u| {
1141            [
1142                u.powf(a),
1143                a * u.powf(a - 1.0),
1144                a * (a - 1.0) * u.powf(a - 2.0),
1145                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
1146                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
1147            ]
1148        });
1149        self.compose_unary(d)
1150    }
1151
1152    /// `ln Γ(self)`; caller guarantees positivity (matches [`JetScalar::ln_gamma`],
1153    /// same hand-certified stack).
1154    #[inline]
1155    pub fn ln_gamma(&self) -> Self {
1156        let d = self
1157            .base
1158            .v
1159            .unary5(crate::jet_tower::ln_gamma_derivative_stack);
1160        self.compose_unary(d)
1161    }
1162
1163    /// `ψ(self)` digamma; caller guarantees positivity (matches
1164    /// [`JetScalar::digamma`], same hand-certified stack).
1165    #[inline]
1166    pub fn digamma(&self) -> Self {
1167        let d = self
1168            .base
1169            .v
1170            .unary5(crate::jet_tower::digamma_derivative_stack);
1171        self.compose_unary(d)
1172    }
1173}
1174
1175impl<const K: usize> OneSeedBatch<K> {
1176    /// Extract lane `i`'s parts as a production [`OneSeed<K>`]. Lane `i` is
1177    /// `to_bits`-identical to evaluating the same program at [`OneSeed<K>`] on
1178    /// row `i` (see `batch_tests`).
1179    #[inline]
1180    #[must_use]
1181    pub fn lane(&self, i: usize) -> OneSeed<K> {
1182        OneSeed {
1183            base: self.base.lane(i),
1184            eps: self.eps.lane(i),
1185        }
1186    }
1187}
1188
1189// ── TwoSeed<K>: two-seed, contracted fourth (doc §A.3) ──────────────────
1190
1191/// Two-seed scalar: an [`Order2`] base plus TWO nilpotents ε, δ
1192/// (`ε² = δ² = 0`, `εδ` retained) — four [`Order2`] parts
1193/// `s = base + ε·eps + δ·del + εδ·eps_del`.
1194///
1195/// Product truncates `ε² = δ² = 0` (doc §A.3): each part is built from
1196/// [`Order2`] products of the four input parts. Composition picks up
1197/// successively higher outer derivatives, the cross part carrying the second
1198/// Faà di Bruno term `f''·eps·del + f'·eps_del`.
1199///
1200/// Seed each primary with [`seed`](Self::seed): base = `variable(x, axis)`,
1201/// eps = `constant(u_axis)`, del = `constant(v_axis)`, eps_del = `constant(0)`.
1202/// Then the εδ-component of the evaluated Hessian channel is the contracted
1203/// fourth `[eps_del.h][a][b] = Σ_{cd} ℓ_{abcd} u_c v_d` — exactly
1204/// `row_fourth_contracted(u, v)`, without materialising `t4`.
1205#[derive(Clone, Copy, Debug)]
1206pub struct TwoSeed<const K: usize> {
1207    /// The `ε⁰δ⁰` part: value / grad / Hessian of `ℓ`.
1208    pub base: Order2<K>,
1209    /// The `ε¹δ⁰` part.
1210    pub eps: Order2<K>,
1211    /// The `ε⁰δ¹` part.
1212    pub del: Order2<K>,
1213    /// The `ε¹δ¹` part. After a `seed(u, v)` evaluation,
1214    /// `eps_del.h[a][b] = Σ_{cd} ℓ_{abcd} u_c v_d`.
1215    pub eps_del: Order2<K>,
1216}
1217
1218impl<const K: usize> TwoSeed<K> {
1219    /// Seed primary `axis` at value `x` with ε-direction `u_axis` and
1220    /// δ-direction `v_axis`:
1221    /// `p_axis = p_axis⁰ + x-seed + ε·u_axis + δ·v_axis` (doc §A.3 "Seeding").
1222    pub fn seed(x: f64, axis: usize, u_axis: f64, v_axis: f64) -> Self {
1223        TwoSeed {
1224            base: Order2::variable(x, axis),
1225            eps: Order2::constant(u_axis),
1226            del: Order2::constant(v_axis),
1227            eps_del: Order2::constant(0.0),
1228        }
1229    }
1230
1231    /// The contracted-fourth channel after a `seed(u, v)` evaluation:
1232    /// `out[a][b] = Σ_{cd} ℓ_{abcd} u_c v_d`, i.e. the εδ-coefficient's Hessian.
1233    pub fn contracted_fourth(&self) -> [[f64; K]; K] {
1234        self.eps_del.h()
1235    }
1236}
1237
1238impl<const K: usize> JetScalar<K> for TwoSeed<K> {
1239    fn constant(c: f64) -> Self {
1240        TwoSeed {
1241            base: Order2::constant(c),
1242            eps: Order2::constant(0.0),
1243            del: Order2::constant(0.0),
1244            eps_del: Order2::constant(0.0),
1245        }
1246    }
1247    fn variable(x: f64, axis: usize) -> Self {
1248        TwoSeed {
1249            base: Order2::variable(x, axis),
1250            eps: Order2::constant(0.0),
1251            del: Order2::constant(0.0),
1252            eps_del: Order2::constant(0.0),
1253        }
1254    }
1255    fn value(&self) -> f64 {
1256        self.base.value()
1257    }
1258    fn add(&self, o: &Self) -> Self {
1259        TwoSeed {
1260            base: self.base.add(&o.base),
1261            eps: self.eps.add(&o.eps),
1262            del: self.del.add(&o.del),
1263            eps_del: self.eps_del.add(&o.eps_del),
1264        }
1265    }
1266    fn sub(&self, o: &Self) -> Self {
1267        TwoSeed {
1268            base: self.base.sub(&o.base),
1269            eps: self.eps.sub(&o.eps),
1270            del: self.del.sub(&o.del),
1271            eps_del: self.eps_del.sub(&o.eps_del),
1272        }
1273    }
1274    fn mul(&self, o: &Self) -> Self {
1275        let a = self;
1276        let b = o;
1277        // Truncate ε² = δ² = 0 (doc §A.3 product table).
1278        let base = a.base.mul(&b.base);
1279        let eps = a.base.mul(&b.eps).add(&a.eps.mul(&b.base));
1280        let del = a.base.mul(&b.del).add(&a.del.mul(&b.base));
1281        let eps_del = a
1282            .base
1283            .mul(&b.eps_del)
1284            .add(&a.eps.mul(&b.del))
1285            .add(&a.del.mul(&b.eps))
1286            .add(&a.eps_del.mul(&b.base));
1287        TwoSeed {
1288            base,
1289            eps,
1290            del,
1291            eps_del,
1292        }
1293    }
1294    fn neg(&self) -> Self {
1295        TwoSeed {
1296            base: self.base.neg(),
1297            eps: self.eps.neg(),
1298            del: self.del.neg(),
1299            eps_del: self.eps_del.neg(),
1300        }
1301    }
1302    fn scale(&self, s: f64) -> Self {
1303        TwoSeed {
1304            base: self.base.scale(s),
1305            eps: self.eps.scale(s),
1306            del: self.del.scale(s),
1307            eps_del: self.eps_del.scale(s),
1308        }
1309    }
1310    fn compose_unary(&self, d: [f64; 5]) -> Self {
1311        // f(s) with s = base + ε eps + δ del + εδ eps_del, ε²=δ²=0:
1312        //   f(s) = f(base)
1313        //        + ε · f'(base)·eps
1314        //        + δ · f'(base)·del
1315        //        + εδ · ( f''(base)·eps·del + f'(base)·eps_del ).
1316        // Each f^{(r)}(base) is the Order2 composition of base with the stack
1317        // shifted r entries (doc §A.3 composition). Order2 reads only the
1318        // leading three entries of whatever stack it is handed, so the trailing
1319        // padding slots are unused (the fixed-length array makes this total).
1320        let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
1321        let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]); // f'(base) as Order2
1322        let fsecond = self.base.compose_unary([d[2], d[3], d[4], d[4], d[4]]); // f''(base) as Order2
1323        let eps = fprime.mul(&self.eps);
1324        let del = fprime.mul(&self.del);
1325        let eps_del = fsecond
1326            .mul(&self.eps)
1327            .mul(&self.del)
1328            .add(&fprime.mul(&self.eps_del));
1329        TwoSeed {
1330            base,
1331            eps,
1332            del,
1333            eps_del,
1334        }
1335    }
1336}
1337
1338// ── TwoSeedLane<L, K>: lane-batched two-seed, contracted fourth (doc §A.3) ─
1339
1340/// Lane-batched [`TwoSeed`]: the same two-seed scalar with its four [`Order2`]
1341/// parts re-typed to [`Order2Lane<L, K>`], so one `L = f64x4` instance carries
1342/// FOUR rows' contracted-fourth evaluations per vector pass.
1343///
1344/// Every operation is a term-for-term structural re-type of the scalar
1345/// [`TwoSeed`] ops onto the lane-implemented [`Order2Lane`] algebra. With
1346/// `L = f64`, `TwoSeedLane<f64, K>` is `to_bits`-identical to [`TwoSeed<K>`];
1347/// with `L = f64x4`, lane `i` is `to_bits`-identical to that (see `batch_tests`).
1348#[derive(Clone, Copy, Debug)]
1349pub struct TwoSeedLane<L: Lane, const K: usize> {
1350    /// The `ε⁰δ⁰` part.
1351    pub base: Order2Lane<L, K>,
1352    /// The `ε¹δ⁰` part.
1353    pub eps: Order2Lane<L, K>,
1354    /// The `ε⁰δ¹` part.
1355    pub del: Order2Lane<L, K>,
1356    /// The `ε¹δ¹` part. After a `seed(u, v)` evaluation, `eps_del.h[a][b]`
1357    /// lane `i` is row `i`'s `Σ_{cd} ℓ_{abcd} u_c v_d`.
1358    pub eps_del: Order2Lane<L, K>,
1359}
1360
1361/// The 4-rows-per-pass batched two-seed scalar (`wide::f64x4` lanes).
1362pub type TwoSeedBatch<const K: usize> = TwoSeedLane<wide::f64x4, K>;
1363
1364impl<L: Lane, const K: usize> TwoSeedLane<L, K> {
1365    /// A constant: base = `constant(c)`, all seed parts zero (mirrors
1366    /// [`TwoSeed::constant`]).
1367    #[inline]
1368    pub fn constant(c: L) -> Self {
1369        let z = Order2Lane::constant(L::splat(0.0));
1370        TwoSeedLane {
1371            base: Order2Lane::constant(c),
1372            eps: z,
1373            del: z,
1374            eps_del: z,
1375        }
1376    }
1377
1378    /// The seeded variable `p_axis` at (per-lane) value `value`, no ε/δ direction
1379    /// (mirrors [`TwoSeed::variable`]).
1380    #[inline]
1381    pub fn variable(value: L, axis: usize) -> Self {
1382        let z = Order2Lane::constant(L::splat(0.0));
1383        TwoSeedLane {
1384            base: Order2Lane::variable(value, axis),
1385            eps: z,
1386            del: z,
1387            eps_del: z,
1388        }
1389    }
1390
1391    /// Seed primary `axis` at (per-lane) value `value` with ε-direction `u_axis`
1392    /// and δ-direction `v_axis` (mirrors [`TwoSeed::seed`]). With `L = f64x4`,
1393    /// each argument packs the four rows' values for primary `axis`.
1394    #[inline]
1395    pub fn seed(value: L, axis: usize, u_axis: L, v_axis: L) -> Self {
1396        TwoSeedLane {
1397            base: Order2Lane::variable(value, axis),
1398            eps: Order2Lane::constant(u_axis),
1399            del: Order2Lane::constant(v_axis),
1400            eps_del: Order2Lane::constant(L::splat(0.0)),
1401        }
1402    }
1403
1404    /// The contracted-fourth channel after a `seed(u, v)` evaluation:
1405    /// `out[a][b]` lane `i` is row `i`'s `Σ_{cd} ℓ_{abcd} u_c v_d`
1406    /// (the εδ-part Hessian).
1407    #[inline]
1408    #[must_use]
1409    pub fn contracted_fourth(&self) -> [[L; K]; K] {
1410        self.eps_del.h
1411    }
1412
1413    /// Lane-wise `self + o` (mirrors [`TwoSeed::add`]).
1414    #[inline]
1415    pub fn add(&self, o: &Self) -> Self {
1416        TwoSeedLane {
1417            base: self.base.add(&o.base),
1418            eps: self.eps.add(&o.eps),
1419            del: self.del.add(&o.del),
1420            eps_del: self.eps_del.add(&o.eps_del),
1421        }
1422    }
1423
1424    /// Lane-wise `self - o` (mirrors [`TwoSeed::sub`]).
1425    #[inline]
1426    pub fn sub(&self, o: &Self) -> Self {
1427        TwoSeedLane {
1428            base: self.base.sub(&o.base),
1429            eps: self.eps.sub(&o.eps),
1430            del: self.del.sub(&o.del),
1431            eps_del: self.eps_del.sub(&o.eps_del),
1432        }
1433    }
1434
1435    /// Lane-wise `self · o`, ε² = δ² = 0 truncation (mirrors [`TwoSeed::mul`]).
1436    #[inline]
1437    pub fn mul(&self, o: &Self) -> Self {
1438        let a = self;
1439        let b = o;
1440        let base = a.base.mul(&b.base);
1441        let eps = a.base.mul(&b.eps).add(&a.eps.mul(&b.base));
1442        let del = a.base.mul(&b.del).add(&a.del.mul(&b.base));
1443        let eps_del = a
1444            .base
1445            .mul(&b.eps_del)
1446            .add(&a.eps.mul(&b.del))
1447            .add(&a.del.mul(&b.eps))
1448            .add(&a.eps_del.mul(&b.base));
1449        TwoSeedLane {
1450            base,
1451            eps,
1452            del,
1453            eps_del,
1454        }
1455    }
1456
1457    /// Negate every part (mirrors [`TwoSeed::neg`]).
1458    #[inline]
1459    pub fn neg(&self) -> Self {
1460        TwoSeedLane {
1461            base: self.base.neg(),
1462            eps: self.eps.neg(),
1463            del: self.del.neg(),
1464            eps_del: self.eps_del.neg(),
1465        }
1466    }
1467
1468    /// Multiply every part by the plain scalar `s` (mirrors [`TwoSeed::scale`]).
1469    #[inline]
1470    pub fn scale(&self, s: f64) -> Self {
1471        TwoSeedLane {
1472            base: self.base.scale(s),
1473            eps: self.eps.scale(s),
1474            del: self.del.scale(s),
1475            eps_del: self.eps_del.scale(s),
1476        }
1477    }
1478
1479    /// Exact composition `f ∘ self`, given the per-lane outer-derivative stack
1480    /// `d = [f, f′, f″, f‴, f⁗]`. Term-for-term identical to
1481    /// [`TwoSeed::compose_unary`]: base reads `d[0..=2]`, `f′(base)` reads
1482    /// `d[1..=3]`, `f″(base)` reads `d[2..=4]`, and the cross part carries
1483    /// `f″·eps·del + f′·eps_del`.
1484    #[inline]
1485    pub fn compose_unary(&self, d: [L; 5]) -> Self {
1486        let base = self.base.compose_unary([d[0], d[1], d[2]]);
1487        let fprime = self.base.compose_unary([d[1], d[2], d[3]]);
1488        let fsecond = self.base.compose_unary([d[2], d[3], d[4]]);
1489        let eps = fprime.mul(&self.eps);
1490        let del = fprime.mul(&self.del);
1491        let eps_del = fsecond
1492            .mul(&self.eps)
1493            .mul(&self.del)
1494            .add(&fprime.mul(&self.eps_del));
1495        TwoSeedLane {
1496            base,
1497            eps,
1498            del,
1499            eps_del,
1500        }
1501    }
1502
1503    /// `e^self`, per-lane stack `[e; 5]` (matches [`JetScalar::exp`]).
1504    #[inline]
1505    pub fn exp(&self) -> Self {
1506        let d = self.base.v.unary5(|u| {
1507            let e = u.exp();
1508            [e, e, e, e, e]
1509        });
1510        self.compose_unary(d)
1511    }
1512
1513    /// `ln(self)`; caller guarantees positivity (matches [`JetScalar::ln`]).
1514    #[inline]
1515    pub fn ln(&self) -> Self {
1516        let d = self.base.v.unary5(|u| {
1517            let r = 1.0 / u;
1518            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
1519        });
1520        self.compose_unary(d)
1521    }
1522
1523    /// `√self`; caller guarantees positivity (matches [`JetScalar::sqrt`]).
1524    #[inline]
1525    pub fn sqrt(&self) -> Self {
1526        let d = self.base.v.unary5(|u| {
1527            let s = u.sqrt();
1528            [
1529                s,
1530                0.5 / s,
1531                -0.25 / (u * s),
1532                0.375 / (u * u * s),
1533                -0.9375 / (u * u * u * s),
1534            ]
1535        });
1536        self.compose_unary(d)
1537    }
1538
1539    /// `1/self` (matches [`JetScalar::recip`]).
1540    #[inline]
1541    pub fn recip(&self) -> Self {
1542        let d = self.base.v.unary5(|u| {
1543            let r = 1.0 / u;
1544            let r2 = r * r;
1545            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
1546        });
1547        self.compose_unary(d)
1548    }
1549
1550    /// `self^a` for real `a`; caller guarantees a positive base (matches
1551    /// [`JetScalar::powf`]).
1552    #[inline]
1553    pub fn powf(&self, a: f64) -> Self {
1554        let d = self.base.v.unary5(|u| {
1555            [
1556                u.powf(a),
1557                a * u.powf(a - 1.0),
1558                a * (a - 1.0) * u.powf(a - 2.0),
1559                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
1560                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
1561            ]
1562        });
1563        self.compose_unary(d)
1564    }
1565
1566    /// `ln Γ(self)`; caller guarantees positivity (matches [`JetScalar::ln_gamma`]).
1567    #[inline]
1568    pub fn ln_gamma(&self) -> Self {
1569        let d = self
1570            .base
1571            .v
1572            .unary5(crate::jet_tower::ln_gamma_derivative_stack);
1573        self.compose_unary(d)
1574    }
1575
1576    /// `ψ(self)` digamma; caller guarantees positivity (matches
1577    /// [`JetScalar::digamma`]).
1578    #[inline]
1579    pub fn digamma(&self) -> Self {
1580        let d = self
1581            .base
1582            .v
1583            .unary5(crate::jet_tower::digamma_derivative_stack);
1584        self.compose_unary(d)
1585    }
1586}
1587
1588impl<const K: usize> TwoSeedBatch<K> {
1589    /// Extract lane `i`'s parts as a production [`TwoSeed<K>`]. Lane `i` is
1590    /// `to_bits`-identical to evaluating the same program at [`TwoSeed<K>`] on
1591    /// row `i` (see `batch_tests`).
1592    #[inline]
1593    #[must_use]
1594    pub fn lane(&self, i: usize) -> TwoSeed<K> {
1595        TwoSeed {
1596            base: self.base.lane(i),
1597            eps: self.eps.lane(i),
1598            del: self.del.lane(i),
1599            eps_del: self.eps_del.lane(i),
1600        }
1601    }
1602}
1603
1604// ── Tower3<K>: value / gradient / Hessian / third tensor ────────────────
1605
1606/// The order-≤3 [`crate::jet_tower::Tower3`] is also a [`JetScalar`]. It serves
1607/// consumers that read `.t3` but never `.t4`, avoiding the fourth-tensor
1608/// product/composition work while preserving the lower channels
1609/// bit-for-bit against [`crate::jet_tower::Tower4`].
1610impl<const K: usize> JetScalar<K> for crate::jet_tower::Tower3<K> {
1611    fn constant(c: f64) -> Self {
1612        crate::jet_tower::Tower3::constant(c)
1613    }
1614    fn variable(x: f64, axis: usize) -> Self {
1615        crate::jet_tower::Tower3::variable(x, axis)
1616    }
1617    fn value(&self) -> f64 {
1618        self.v
1619    }
1620    fn add(&self, o: &Self) -> Self {
1621        *self + *o
1622    }
1623    fn sub(&self, o: &Self) -> Self {
1624        *self + o.scale(-1.0)
1625    }
1626    fn mul(&self, o: &Self) -> Self {
1627        crate::jet_tower::Tower3::mul(self, o)
1628    }
1629    fn neg(&self) -> Self {
1630        self.scale(-1.0)
1631    }
1632    fn scale(&self, s: f64) -> Self {
1633        crate::jet_tower::Tower3::scale(self, s)
1634    }
1635    fn compose_unary(&self, d: [f64; 5]) -> Self {
1636        crate::jet_tower::Tower3::compose_unary(self, [d[0], d[1], d[2], d[3]])
1637    }
1638}
1639
1640// ── Tower4<K>: full dense tower as a JetScalar (the all-channels scalar) ─
1641
1642/// The full dense [`crate::jet_tower::Tower4`] is itself a [`JetScalar`]: it
1643/// carries EVERY channel, so a row expression written ONCE against [`JetScalar`]
1644/// can be evaluated at `Tower4` to obtain the full `(v, g, H, t3, t4)` in one
1645/// pass. This is BOTH the #932 oracle ground truth the packed [`Order2`] /
1646/// [`OneSeed`] / [`TwoSeed`] scalars are pinned against, AND a production scalar:
1647/// a family whose uncontracted third / fourth derivative tensors are needed
1648/// (the BMS rigid `third_full` / `fourth_full` caches) evaluates the SAME
1649/// generic row-NLL expression at `Tower4` and reads `.t3` / `.t4` off the
1650/// result — so the dense tensors come from the single source of truth, not a
1651/// separately hand-written jet. The packed scalars serve the consumers that
1652/// need only `(v, g, H)` (`Order2`) or one / two contractions
1653/// (`OneSeed` / `TwoSeed`) without paying for the dense tensors.
1654impl<const K: usize> JetScalar<K> for crate::jet_tower::Tower4<K> {
1655    fn constant(c: f64) -> Self {
1656        crate::jet_tower::Tower4::constant(c)
1657    }
1658    fn variable(x: f64, axis: usize) -> Self {
1659        crate::jet_tower::Tower4::variable(x, axis)
1660    }
1661    fn value(&self) -> f64 {
1662        self.v
1663    }
1664    fn add(&self, o: &Self) -> Self {
1665        *self + *o
1666    }
1667    fn sub(&self, o: &Self) -> Self {
1668        *self - *o
1669    }
1670    fn mul(&self, o: &Self) -> Self {
1671        crate::jet_tower::Tower4::mul(self, o)
1672    }
1673    fn neg(&self) -> Self {
1674        self.scale(-1.0)
1675    }
1676    fn scale(&self, s: f64) -> Self {
1677        crate::jet_tower::Tower4::scale(self, s)
1678    }
1679    fn compose_unary(&self, d: [f64; 5]) -> Self {
1680        crate::jet_tower::Tower4::compose_unary(self, d)
1681    }
1682}
1683
1684#[cfg(test)]
1685mod tests {
1686    use super::*;
1687    use crate::jet_tower::{RowNllProgram, Tower4, evaluate_program};
1688
1689    /// A small polynomial-plus-unary row expression written ONCE, generically
1690    /// over `S: JetScalar<2>`, so it can be evaluated against every scalar:
1691    /// `ℓ = (e^{p0·p1} + 2) · √(p0·p0 + 1) − p1·p1·0.5`.
1692    /// Exercises mul, add/sub, scale, exp, sqrt — every algebra op.
1693    fn row_expr<S: JetScalar<2>>(p: &[S; 2]) -> S {
1694        let g = p[0].mul(&p[1]).exp();
1695        let inner = g.add(&S::constant(2.0));
1696        let radic = p[0].mul(&p[0]).add(&S::constant(1.0)).sqrt();
1697        inner.mul(&radic).sub(&p[1].mul(&p[1]).scale(0.5))
1698    }
1699
1700    /// The same expression as a Tower4 `RowNllProgram`, the ground-truth tower.
1701    struct ExprProgram {
1702        p: [f64; 2],
1703    }
1704    impl RowNllProgram<2> for ExprProgram {
1705        fn n_rows(&self) -> usize {
1706            1
1707        }
1708        fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
1709            if row >= self.n_rows() {
1710                return Err(format!("ExprProgram: row {row} out of range"));
1711            }
1712            Ok(self.p)
1713        }
1714        fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
1715            if row >= self.n_rows() {
1716                return Err(format!("ExprProgram: row {row} out of range"));
1717            }
1718            Ok(row_expr(p))
1719        }
1720    }
1721
1722    const SEED: [f64; 2] = [0.37, -0.81];
1723    const U: [f64; 2] = [0.6, -0.2];
1724    const V: [f64; 2] = [-0.4, 1.1];
1725    const TOL: f64 = 1e-10;
1726
1727    fn close(a: f64, b: f64, label: &str) {
1728        let band = TOL + TOL * a.abs().max(b.abs());
1729        assert!(
1730            (a - b).abs() <= band,
1731            "{label}: {a:+.15e} vs {b:+.15e} (band {band:.3e})"
1732        );
1733    }
1734
1735    fn tower() -> Tower4<2> {
1736        evaluate_program(&ExprProgram { p: SEED }, 0).expect("tower")
1737    }
1738
1739    /// Order2 reproduces Tower4's value/grad/Hessian channels exactly.
1740    #[test]
1741    fn order2_matches_tower_value_grad_hessian() {
1742        let t = tower();
1743        let vars: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
1744        let s = row_expr(&vars);
1745        close(s.value(), t.v, "value");
1746        for a in 0..2 {
1747            close(s.0.g[a], t.g[a], &format!("grad[{a}]"));
1748            for b in 0..2 {
1749                close(s.h()[a][b], t.h[a][b], &format!("hess[{a}][{b}]"));
1750            }
1751        }
1752    }
1753
1754    /// The `compose_unary_with` seam on a scalar jet is `to_bits`-identical to
1755    /// the explicit `compose_unary(stack_fn(value))` — the contract the batch
1756    /// arm (`Tower{3,4}Lane::compose_unary_with`) lane-matches. Exercised on
1757    /// [`Order2`] across `K ∈ {2,3,4,9}`, ≥ 4000 random seeded inputs.
1758    #[test]
1759    fn compose_unary_with_scalar_seam_bit_identical() {
1760        fn rand_unit(state: &mut u64) -> f64 {
1761            let mut x = *state;
1762            x ^= x << 13;
1763            x ^= x >> 7;
1764            x ^= x << 17;
1765            *state = x;
1766            2.0 * ((x >> 11) as f64 / ((1u64 << 53) as f64)) - 1.0
1767        }
1768        // A base-value-dependent finite stack standing in for a family stack.
1769        fn stack(u: f64) -> [f64; 5] {
1770            [
1771                u.sin(),
1772                u.cos(),
1773                (2.0 * u).sin(),
1774                (0.5 * u).cos(),
1775                u * u - 0.3,
1776            ]
1777        }
1778        fn run<const K: usize>(state: &mut u64, n: usize) -> usize {
1779            for _ in 0..n {
1780                // A non-trivial Order2<K> jet: a seeded variable pushed through a
1781                // couple of algebra ops so g/h are dense, then exercise the seam.
1782                let base = rand_unit(state);
1783                let mut s = Order2::<K>::variable(base, 0);
1784                for a in 1..K {
1785                    s = JetScalar::mul(&s, &Order2::<K>::variable(rand_unit(state), a));
1786                }
1787                let with = s.compose_unary_with(stack);
1788                let explicit = s.compose_unary(stack(s.value()));
1789                assert_eq!(with.value().to_bits(), explicit.value().to_bits(), "value");
1790                for a in 0..K {
1791                    assert_eq!(with.g()[a].to_bits(), explicit.g()[a].to_bits(), "g[{a}]");
1792                    for b in 0..K {
1793                        assert_eq!(
1794                            with.h()[a][b].to_bits(),
1795                            explicit.h()[a][b].to_bits(),
1796                            "h[{a}][{b}]"
1797                        );
1798                    }
1799                }
1800            }
1801            n
1802        }
1803        let mut st = 0x9e37_79b9_7f4a_7c15u64;
1804        let total = run::<2>(&mut st, 1100)
1805            + run::<3>(&mut st, 1100)
1806            + run::<4>(&mut st, 1100)
1807            + run::<9>(&mut st, 1100);
1808        assert_eq!(total, 4400);
1809    }
1810
1811    /// OneSeed's ε-Hessian is the contracted third Σ_c ℓ_{abc} u_c, matching
1812    /// `Tower4::third_contracted(u)`. Base channels also match the tower.
1813    #[test]
1814    fn one_seed_matches_tower_third_contracted() {
1815        let t = tower();
1816        let truth = t.third_contracted(&U);
1817        let vars: [OneSeed<2>; 2] =
1818            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
1819        let s = row_expr(&vars);
1820        // Base channels are the plain (v, g, H).
1821        close(s.value(), t.v, "value");
1822        for a in 0..2 {
1823            for b in 0..2 {
1824                close(s.base.h()[a][b], t.h[a][b], &format!("base hess[{a}][{b}]"));
1825            }
1826        }
1827        let third = s.contracted_third();
1828        for a in 0..2 {
1829            for b in 0..2 {
1830                close(third[a][b], truth[a][b], &format!("third[{a}][{b}]"));
1831            }
1832        }
1833    }
1834
1835    /// TwoSeed's εδ-Hessian is the contracted fourth Σ_{cd} ℓ_{abcd} u_c v_d,
1836    /// matching `Tower4::fourth_contracted(u, v)`. The ε / δ single-seed parts
1837    /// reproduce the two third contractions Σ_c ℓ_{abc} u_c and …v_d.
1838    #[test]
1839    fn two_seed_matches_tower_fourth_contracted() {
1840        let t = tower();
1841        let truth4 = t.fourth_contracted(&U, &V);
1842        let truth3_u = t.third_contracted(&U);
1843        let truth3_v = t.third_contracted(&V);
1844        let vars: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
1845        let s = row_expr(&vars);
1846        close(s.value(), t.v, "value");
1847        for a in 0..2 {
1848            close(s.base.0.g[a], t.g[a], &format!("grad[{a}]"));
1849            for b in 0..2 {
1850                close(s.base.h()[a][b], t.h[a][b], &format!("base hess[{a}][{b}]"));
1851                close(
1852                    s.eps.h()[a][b],
1853                    truth3_u[a][b],
1854                    &format!("eps third_u[{a}][{b}]"),
1855                );
1856                close(
1857                    s.del.h()[a][b],
1858                    truth3_v[a][b],
1859                    &format!("del third_v[{a}][{b}]"),
1860                );
1861            }
1862        }
1863        let fourth = s.contracted_fourth();
1864        for a in 0..2 {
1865            for b in 0..2 {
1866                close(fourth[a][b], truth4[a][b], &format!("fourth[{a}][{b}]"));
1867            }
1868        }
1869    }
1870
1871    /// The generic `row_nll_generic` seam (added to Tower4's program trait
1872    /// surface) evaluates the SAME expression on each scalar and extracts the
1873    /// channel a consumer asks for, agreeing with the direct Tower4 contraction.
1874    #[test]
1875    fn generic_program_seam_matches_tower_for_every_channel() {
1876        let t = tower();
1877        // Order2 via generic seam.
1878        let o2: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
1879        let so2 = row_expr(&o2);
1880        close(so2.value(), t.v, "seam order2 value");
1881        // OneSeed third.
1882        let os: [OneSeed<2>; 2] =
1883            std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
1884        let third = row_expr(&os).contracted_third();
1885        let truth3 = t.third_contracted(&U);
1886        for a in 0..2 {
1887            for b in 0..2 {
1888                close(third[a][b], truth3[a][b], &format!("seam third[{a}][{b}]"));
1889            }
1890        }
1891        // TwoSeed fourth.
1892        let ts: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
1893        let fourth = row_expr(&ts).contracted_fourth();
1894        let truth4 = t.fourth_contracted(&U, &V);
1895        for a in 0..2 {
1896            for b in 0..2 {
1897                close(
1898                    fourth[a][b],
1899                    truth4[a][b],
1900                    &format!("seam fourth[{a}][{b}]"),
1901                );
1902            }
1903        }
1904    }
1905
1906    /// The (test-only) `Tower4: JetScalar` impl is the all-channels oracle scalar:
1907    /// evaluating the SAME generic `row_expr` at `S = Tower4` (through the
1908    /// `JetScalar` trait ops) must reproduce, channel-for-channel, the `Tower4`
1909    /// obtained from the `RowNllProgram` / inherent-operator path
1910    /// (`evaluate_program`). This pins that the trait impl delegates faithfully to
1911    /// the inherent `Tower4` arithmetic (so the contracted-scalar oracles above,
1912    /// which compare against `evaluate_program`'s tower, are comparing against the
1913    /// same algebra the `JetScalar` interface exposes).
1914    #[test]
1915    fn tower4_as_jetscalar_matches_program_tower_all_channels() {
1916        let t = tower();
1917        let vars: [Tower4<2>; 2] = std::array::from_fn(|a| Tower4::variable(SEED[a], a));
1918        let s = row_expr(&vars);
1919        close(s.v, t.v, "tower-jetscalar value");
1920        for a in 0..2 {
1921            close(s.g[a], t.g[a], &format!("tower-jetscalar grad[{a}]"));
1922            for b in 0..2 {
1923                close(
1924                    s.h[a][b],
1925                    t.h[a][b],
1926                    &format!("tower-jetscalar hess[{a}][{b}]"),
1927                );
1928                for c in 0..2 {
1929                    close(
1930                        s.t3[a][b][c],
1931                        t.t3[a][b][c],
1932                        &format!("tower-jetscalar t3[{a}][{b}][{c}]"),
1933                    );
1934                    for d in 0..2 {
1935                        close(
1936                            s.t4[a][b][c][d],
1937                            t.t4[a][b][c][d],
1938                            &format!("tower-jetscalar t4[{a}][{b}][{c}][{d}]"),
1939                        );
1940                    }
1941                }
1942            }
1943        }
1944    }
1945}
1946
1947#[cfg(test)]
1948mod batch_tests {
1949    //! SIMD row-batching oracle: prove [`Order2Batch<K>`] (4 rows in
1950    //! `wide::f64x4` lanes) is `to_bits`-identical, on every value/gradient/
1951    //! Hessian channel, to the production [`Order2<K>`] evaluated per row — and
1952    //! that the new scalar field [`Order2Lane<f64, K>`] is too. Composing the two
1953    //! claims, batch lane `i` reproduces the production scalar for row `i` bit
1954    //! for bit, so the 4× throughput is a free lunch (no result change).
1955
1956    use super::{
1957        JetScalar, Lane, OneSeed, OneSeedBatch, OneSeedLane, Order2, Order2Batch, Order2Lane,
1958        TwoSeed, TwoSeedBatch, TwoSeedLane,
1959    };
1960
1961    /// The ops the witness row expression needs, so ONE generic body evaluates
1962    /// at the production [`Order2<K>`], the new scalar [`Order2Lane<f64, K>`],
1963    /// and the batched [`Order2Batch<K>`].
1964    trait RowAlg<const K: usize>: Copy {
1965        fn constant(c: f64) -> Self;
1966        fn add(&self, o: &Self) -> Self;
1967        fn sub(&self, o: &Self) -> Self;
1968        fn mul(&self, o: &Self) -> Self;
1969        fn scale(&self, s: f64) -> Self;
1970        fn exp(&self) -> Self;
1971        fn sqrt(&self) -> Self;
1972        fn recip(&self) -> Self;
1973    }
1974
1975    impl<const K: usize> RowAlg<K> for Order2<K> {
1976        fn constant(c: f64) -> Self {
1977            <Self as JetScalar<K>>::constant(c)
1978        }
1979        fn add(&self, o: &Self) -> Self {
1980            JetScalar::add(self, o)
1981        }
1982        fn sub(&self, o: &Self) -> Self {
1983            JetScalar::sub(self, o)
1984        }
1985        fn mul(&self, o: &Self) -> Self {
1986            JetScalar::mul(self, o)
1987        }
1988        fn scale(&self, s: f64) -> Self {
1989            JetScalar::scale(self, s)
1990        }
1991        fn exp(&self) -> Self {
1992            JetScalar::exp(self)
1993        }
1994        fn sqrt(&self) -> Self {
1995            JetScalar::sqrt(self)
1996        }
1997        fn recip(&self) -> Self {
1998            JetScalar::recip(self)
1999        }
2000    }
2001
2002    impl<L: Lane, const K: usize> RowAlg<K> for Order2Lane<L, K> {
2003        fn constant(c: f64) -> Self {
2004            Order2Lane::constant(L::splat(c))
2005        }
2006        fn add(&self, o: &Self) -> Self {
2007            Order2Lane::add(self, o)
2008        }
2009        fn sub(&self, o: &Self) -> Self {
2010            Order2Lane::sub(self, o)
2011        }
2012        fn mul(&self, o: &Self) -> Self {
2013            Order2Lane::mul(self, o)
2014        }
2015        fn scale(&self, s: f64) -> Self {
2016            Order2Lane::scale(self, s)
2017        }
2018        fn exp(&self) -> Self {
2019            Order2Lane::exp(self)
2020        }
2021        fn sqrt(&self) -> Self {
2022            Order2Lane::sqrt(self)
2023        }
2024        fn recip(&self) -> Self {
2025            Order2Lane::recip(self)
2026        }
2027    }
2028
2029    /// A dense witness row expression touching every algebra op (mul, add, sub,
2030    /// scale, exp, sqrt, recip) over ALL `K` primaries, so the gradient and the
2031    /// full `K×K` Hessian are dense (no trivially-zero channel). All transcend.
2032    /// arguments are kept finite/positive: `sqrt(s²+1) > 0`, `recip(exp+2) > 0`.
2033    fn row_expr<const K: usize, A: RowAlg<K>>(p: &[A; K]) -> A {
2034        let mut s = A::constant(0.3);
2035        for a in 0..K {
2036            let b = (a + 1) % K;
2037            s = s.add(&p[a].mul(&p[b]).scale(0.1 + 0.05 * a as f64));
2038        }
2039        let e = s.exp();
2040        let r = s.mul(&s).add(&A::constant(1.0)).sqrt();
2041        let denom = e.add(&A::constant(2.0));
2042        e.mul(&r).sub(&s.scale(0.5)).mul(&denom.recip())
2043    }
2044
2045    /// xorshift64 → `f64` in `[-1, 1)`.
2046    fn rand_unit(state: &mut u64) -> f64 {
2047        let mut x = *state;
2048        x ^= x << 13;
2049        x ^= x >> 7;
2050        x ^= x << 17;
2051        *state = x;
2052        let u = (x >> 11) as f64 / ((1u64 << 53) as f64); // [0, 1)
2053        2.0 * u - 1.0
2054    }
2055
2056    /// Returns the number of (batch, row) pairs whose every channel was
2057    /// verified bit-identical, so the caller can assert the expected total ran.
2058    fn check_k<const K: usize>(state: &mut u64, batches: usize) -> usize {
2059        let mut verified_rows = 0usize;
2060        for _ in 0..batches {
2061            // Four independent rows of K primary values.
2062            let rows: [[f64; K]; 4] =
2063                std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2064
2065            // Production ground truth, evaluated per row at Order2<K>.
2066            let prod: [Order2<K>; 4] = std::array::from_fn(|r| {
2067                let p: [Order2<K>; K] = std::array::from_fn(|a| Order2::variable(rows[r][a], a));
2068                row_expr(&p)
2069            });
2070
2071            // New scalar field (Order2Lane<f64>), per row.
2072            let scal: [Order2Lane<f64, K>; 4] = std::array::from_fn(|r| {
2073                let p: [Order2Lane<f64, K>; K] =
2074                    std::array::from_fn(|a| Order2Lane::variable(rows[r][a], a));
2075                row_expr(&p)
2076            });
2077
2078            // Batched: 4 rows packed into f64x4 lanes, ONE vector pass.
2079            let pbatch: [Order2Batch<K>; K] = std::array::from_fn(|a| {
2080                let packed = wide::f64x4::new([rows[0][a], rows[1][a], rows[2][a], rows[3][a]]);
2081                Order2Batch::variable(packed, a)
2082            });
2083            let batch = row_expr(&pbatch);
2084
2085            for r in 0..4 {
2086                let g = prod[r].0;
2087                // Order2Lane<f64> == Order2<K> (bit-identical scalar field).
2088                assert_eq!(scal[r].v.to_bits(), g.v.to_bits(), "K={K} scalar v");
2089                // Batch lane r == Order2<K> for row r.
2090                let lr = batch.lane(r).0;
2091                assert_eq!(lr.v.to_bits(), g.v.to_bits(), "K={K} batch lane {r} v");
2092                for a in 0..K {
2093                    assert_eq!(
2094                        scal[r].g[a].to_bits(),
2095                        g.g[a].to_bits(),
2096                        "K={K} scalar g[{a}]"
2097                    );
2098                    assert_eq!(
2099                        lr.g[a].to_bits(),
2100                        g.g[a].to_bits(),
2101                        "K={K} batch lane {r} g[{a}]"
2102                    );
2103                    for b in 0..K {
2104                        assert_eq!(
2105                            scal[r].h[a][b].to_bits(),
2106                            g.h[a][b].to_bits(),
2107                            "K={K} scalar h[{a}][{b}]"
2108                        );
2109                        assert_eq!(
2110                            lr.h[a][b].to_bits(),
2111                            g.h[a][b].to_bits(),
2112                            "K={K} batch lane {r} h[{a}][{b}]"
2113                        );
2114                    }
2115                }
2116                verified_rows += 1;
2117            }
2118        }
2119        verified_rows
2120    }
2121
2122    /// ≥2000 random 4-row batches per K, across K ∈ {2,3,4,9}: every channel of
2123    /// every lane is `to_bits`-identical to the production scalar per row.
2124    #[test]
2125    fn batch_lanes_bit_identical_to_scalar_per_row() {
2126        let mut state = 0x9E37_79B9_7F4A_7C15_u64;
2127        let mut verified = 0usize;
2128        verified += check_k::<2>(&mut state, 2000);
2129        verified += check_k::<3>(&mut state, 2000);
2130        verified += check_k::<4>(&mut state, 2000);
2131        verified += check_k::<9>(&mut state, 2000);
2132        // 4 K-values × 2000 batches × 4 packed rows each, all bit-identical.
2133        assert_eq!(verified, 4 * 2000 * 4, "every batch row must be verified");
2134    }
2135
2136    // ── One-/two-seed lane oracles ──────────────────────────────────────────
2137    //
2138    // The same dense `row_expr` witness program runs over the SEEDED directional
2139    // scalars: the scalar `OneSeed`/`TwoSeed` per row, the `f64`-lane re-type
2140    // (`*SeedLane<f64>`), and the 4-rows-per-pass batch (`*SeedBatch`). The
2141    // headline claim is that the contracted-third / contracted-fourth channel of
2142    // every lane is `to_bits`-identical to the production scalar's per row.
2143
2144    impl<const K: usize> RowAlg<K> for OneSeed<K> {
2145        fn constant(c: f64) -> Self {
2146            <Self as JetScalar<K>>::constant(c)
2147        }
2148        fn add(&self, o: &Self) -> Self {
2149            JetScalar::add(self, o)
2150        }
2151        fn sub(&self, o: &Self) -> Self {
2152            JetScalar::sub(self, o)
2153        }
2154        fn mul(&self, o: &Self) -> Self {
2155            JetScalar::mul(self, o)
2156        }
2157        fn scale(&self, s: f64) -> Self {
2158            JetScalar::scale(self, s)
2159        }
2160        fn exp(&self) -> Self {
2161            JetScalar::exp(self)
2162        }
2163        fn sqrt(&self) -> Self {
2164            JetScalar::sqrt(self)
2165        }
2166        fn recip(&self) -> Self {
2167            JetScalar::recip(self)
2168        }
2169    }
2170
2171    impl<L: Lane, const K: usize> RowAlg<K> for OneSeedLane<L, K> {
2172        fn constant(c: f64) -> Self {
2173            OneSeedLane::constant(L::splat(c))
2174        }
2175        fn add(&self, o: &Self) -> Self {
2176            OneSeedLane::add(self, o)
2177        }
2178        fn sub(&self, o: &Self) -> Self {
2179            OneSeedLane::sub(self, o)
2180        }
2181        fn mul(&self, o: &Self) -> Self {
2182            OneSeedLane::mul(self, o)
2183        }
2184        fn scale(&self, s: f64) -> Self {
2185            OneSeedLane::scale(self, s)
2186        }
2187        fn exp(&self) -> Self {
2188            OneSeedLane::exp(self)
2189        }
2190        fn sqrt(&self) -> Self {
2191            OneSeedLane::sqrt(self)
2192        }
2193        fn recip(&self) -> Self {
2194            OneSeedLane::recip(self)
2195        }
2196    }
2197
2198    impl<const K: usize> RowAlg<K> for TwoSeed<K> {
2199        fn constant(c: f64) -> Self {
2200            <Self as JetScalar<K>>::constant(c)
2201        }
2202        fn add(&self, o: &Self) -> Self {
2203            JetScalar::add(self, o)
2204        }
2205        fn sub(&self, o: &Self) -> Self {
2206            JetScalar::sub(self, o)
2207        }
2208        fn mul(&self, o: &Self) -> Self {
2209            JetScalar::mul(self, o)
2210        }
2211        fn scale(&self, s: f64) -> Self {
2212            JetScalar::scale(self, s)
2213        }
2214        fn exp(&self) -> Self {
2215            JetScalar::exp(self)
2216        }
2217        fn sqrt(&self) -> Self {
2218            JetScalar::sqrt(self)
2219        }
2220        fn recip(&self) -> Self {
2221            JetScalar::recip(self)
2222        }
2223    }
2224
2225    impl<L: Lane, const K: usize> RowAlg<K> for TwoSeedLane<L, K> {
2226        fn constant(c: f64) -> Self {
2227            TwoSeedLane::constant(L::splat(c))
2228        }
2229        fn add(&self, o: &Self) -> Self {
2230            TwoSeedLane::add(self, o)
2231        }
2232        fn sub(&self, o: &Self) -> Self {
2233            TwoSeedLane::sub(self, o)
2234        }
2235        fn mul(&self, o: &Self) -> Self {
2236            TwoSeedLane::mul(self, o)
2237        }
2238        fn scale(&self, s: f64) -> Self {
2239            TwoSeedLane::scale(self, s)
2240        }
2241        fn exp(&self) -> Self {
2242            TwoSeedLane::exp(self)
2243        }
2244        fn sqrt(&self) -> Self {
2245            TwoSeedLane::sqrt(self)
2246        }
2247        fn recip(&self) -> Self {
2248            TwoSeedLane::recip(self)
2249        }
2250    }
2251
2252    fn check_oneseed<const K: usize>(state: &mut u64, batches: usize) -> usize {
2253        let mut rows_checked = 0;
2254        for _ in 0..batches {
2255            let rows: [[f64; K]; 4] =
2256                std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2257            // Per-row ε-direction.
2258            let u: [[f64; K]; 4] =
2259                std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2260
2261            // Production ground truth (scalar OneSeed per row).
2262            let prod: [OneSeed<K>; 4] = std::array::from_fn(|r| {
2263                let p: [OneSeed<K>; K] =
2264                    std::array::from_fn(|a| OneSeed::seed_direction(rows[r][a], a, u[r][a]));
2265                row_expr(&p)
2266            });
2267
2268            // f64-lane re-type per row.
2269            let scal: [OneSeedLane<f64, K>; 4] = std::array::from_fn(|r| {
2270                let p: [OneSeedLane<f64, K>; K] =
2271                    std::array::from_fn(|a| OneSeedLane::seed_direction(rows[r][a], a, u[r][a]));
2272                row_expr(&p)
2273            });
2274
2275            // 4-rows-per-pass batch.
2276            let pbatch: [OneSeedBatch<K>; K] = std::array::from_fn(|a| {
2277                let val = wide::f64x4::new([rows[0][a], rows[1][a], rows[2][a], rows[3][a]]);
2278                let uu = wide::f64x4::new([u[0][a], u[1][a], u[2][a], u[3][a]]);
2279                OneSeedBatch::seed_direction(val, a, uu)
2280            });
2281            let batch = row_expr(&pbatch);
2282
2283            for r in 0..4 {
2284                let want = prod[r].contracted_third();
2285                let got_scal = scal[r].contracted_third();
2286                let got_batch = batch.lane(r).contracted_third();
2287                // Value channel too (sanity that the base program agrees).
2288                assert_eq!(
2289                    scal[r].base.v.to_bits(),
2290                    prod[r].base.value().to_bits(),
2291                    "OneSeed K={K} scalar value"
2292                );
2293                assert_eq!(
2294                    batch.lane(r).base.value().to_bits(),
2295                    prod[r].base.value().to_bits(),
2296                    "OneSeed K={K} batch lane {r} value"
2297                );
2298                for a in 0..K {
2299                    for b in 0..K {
2300                        assert_eq!(
2301                            got_scal[a][b].to_bits(),
2302                            want[a][b].to_bits(),
2303                            "OneSeed K={K} scalar third[{a}][{b}]"
2304                        );
2305                        assert_eq!(
2306                            got_batch[a][b].to_bits(),
2307                            want[a][b].to_bits(),
2308                            "OneSeed K={K} batch lane {r} third[{a}][{b}]"
2309                        );
2310                    }
2311                }
2312                rows_checked += 1;
2313            }
2314        }
2315        rows_checked
2316    }
2317
2318    fn check_twoseed<const K: usize>(state: &mut u64, batches: usize) -> usize {
2319        let mut rows_checked = 0;
2320        for _ in 0..batches {
2321            let rows: [[f64; K]; 4] =
2322                std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2323            let u: [[f64; K]; 4] =
2324                std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2325            let v: [[f64; K]; 4] =
2326                std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2327
2328            let prod: [TwoSeed<K>; 4] = std::array::from_fn(|r| {
2329                let p: [TwoSeed<K>; K] =
2330                    std::array::from_fn(|a| TwoSeed::seed(rows[r][a], a, u[r][a], v[r][a]));
2331                row_expr(&p)
2332            });
2333
2334            let scal: [TwoSeedLane<f64, K>; 4] = std::array::from_fn(|r| {
2335                let p: [TwoSeedLane<f64, K>; K] =
2336                    std::array::from_fn(|a| TwoSeedLane::seed(rows[r][a], a, u[r][a], v[r][a]));
2337                row_expr(&p)
2338            });
2339
2340            let pbatch: [TwoSeedBatch<K>; K] = std::array::from_fn(|a| {
2341                let val = wide::f64x4::new([rows[0][a], rows[1][a], rows[2][a], rows[3][a]]);
2342                let uu = wide::f64x4::new([u[0][a], u[1][a], u[2][a], u[3][a]]);
2343                let vv = wide::f64x4::new([v[0][a], v[1][a], v[2][a], v[3][a]]);
2344                TwoSeedBatch::seed(val, a, uu, vv)
2345            });
2346            let batch = row_expr(&pbatch);
2347
2348            for r in 0..4 {
2349                let want = prod[r].contracted_fourth();
2350                let got_scal = scal[r].contracted_fourth();
2351                let got_batch = batch.lane(r).contracted_fourth();
2352                assert_eq!(
2353                    scal[r].base.v.to_bits(),
2354                    prod[r].base.value().to_bits(),
2355                    "TwoSeed K={K} scalar value"
2356                );
2357                assert_eq!(
2358                    batch.lane(r).base.value().to_bits(),
2359                    prod[r].base.value().to_bits(),
2360                    "TwoSeed K={K} batch lane {r} value"
2361                );
2362                for a in 0..K {
2363                    for b in 0..K {
2364                        assert_eq!(
2365                            got_scal[a][b].to_bits(),
2366                            want[a][b].to_bits(),
2367                            "TwoSeed K={K} scalar fourth[{a}][{b}]"
2368                        );
2369                        assert_eq!(
2370                            got_batch[a][b].to_bits(),
2371                            want[a][b].to_bits(),
2372                            "TwoSeed K={K} batch lane {r} fourth[{a}][{b}]"
2373                        );
2374                    }
2375                }
2376                rows_checked += 1;
2377            }
2378        }
2379        rows_checked
2380    }
2381
2382    /// ≥2000 random 4-row batches per K, across K ∈ {2,3,4,9}: the
2383    /// contracted-third channel of every `OneSeedLane` lane is `to_bits`-identical
2384    /// to the production [`OneSeed`] per row.
2385    #[test]
2386    fn oneseed_lanes_contracted_third_bit_identical() {
2387        let mut state = 0x1234_5678_9ABC_DEF0_u64;
2388        let batches = 2000;
2389        let rows_checked = check_oneseed::<2>(&mut state, batches)
2390            + check_oneseed::<3>(&mut state, batches)
2391            + check_oneseed::<4>(&mut state, batches)
2392            + check_oneseed::<9>(&mut state, batches);
2393        // 4 widths × `batches` batches × 4 rows each: a silently empty inner
2394        // loop would leave this at zero instead of passing as a no-op.
2395        assert_eq!(rows_checked, 4 * batches * 4);
2396    }
2397
2398    /// ≥2000 random 4-row batches per K, across K ∈ {2,3,4,9}: the
2399    /// contracted-fourth channel of every `TwoSeedLane` lane is `to_bits`-identical
2400    /// to the production [`TwoSeed`] per row.
2401    #[test]
2402    fn twoseed_lanes_contracted_fourth_bit_identical() {
2403        let mut state = 0x0FED_CBA9_8765_4321_u64;
2404        let batches = 2000;
2405        let rows_checked = check_twoseed::<2>(&mut state, batches)
2406            + check_twoseed::<3>(&mut state, batches)
2407            + check_twoseed::<4>(&mut state, batches)
2408            + check_twoseed::<9>(&mut state, batches);
2409        // 4 widths × `batches` batches × 4 rows each: a silently empty inner
2410        // loop would leave this at zero instead of passing as a no-op.
2411        assert_eq!(rows_checked, 4 * batches * 4);
2412    }
2413}
2414
#[cfg(test)]
mod unit_tests {
    //! Small closed-form property tests for [`Order1`], [`Order2`] and
    //! [`filtered_implicit_solve_scalar`]. Inputs are chosen so that expected
    //! values are either exactly representable (checked with `assert_eq!`) or
    //! known analytically (checked against a tight absolute tolerance).

    use super::{JetScalar, Order1, Order2, filtered_implicit_solve_scalar};

    // ── Order2 direct property tests ─────────────────────────────────────────

    /// `Order2::constant(c)` carries value `c` and zero everywhere else.
    #[test]
    fn order2_constant_has_zero_derivatives() {
        let s = Order2::<3>::constant(7.5);
        assert_eq!(s.value(), 7.5);
        for a in 0..3 {
            assert_eq!(s.g()[a], 0.0, "grad[{a}] should be zero");
            for b in 0..3 {
                assert_eq!(s.h()[a][b], 0.0, "hess[{a}][{b}] should be zero");
            }
        }
    }

    /// `Order2::variable(x, axis)` has unit gradient in slot `axis` and zero Hessian.
    #[test]
    fn order2_variable_has_unit_gradient_in_seeded_slot() {
        let x = -2.5_f64;
        let s = Order2::<4>::variable(x, 2);
        assert_eq!(s.value(), x);
        for a in 0..4 {
            let expected_g = if a == 2 { 1.0 } else { 0.0 };
            assert_eq!(s.g()[a], expected_g, "grad[{a}]");
            for b in 0..4 {
                assert_eq!(s.h()[a][b], 0.0, "hess[{a}][{b}] should be zero");
            }
        }
    }

    /// `Order2::add` sums every channel; `sub` undoes it exactly.
    ///
    /// Uses small integer-valued primaries so `(3 + 2) - 2` is exact in `f64`
    /// and the value, gradient and Hessian can all be compared with `assert_eq!`.
    #[test]
    fn order2_add_sub_roundtrip() {
        let p = Order2::<2>::variable(3.0, 0);
        let q = Order2::<2>::variable(2.0, 1);
        let pq = JetScalar::add(&p, &q);
        // value = 3 + 2 = 5
        assert_eq!(pq.value(), 5.0, "add value");
        // p seeds axis 0 and q seeds axis 1, so the summed gradient is (1, 1).
        for a in 0..2 {
            assert_eq!(pq.g()[a], 1.0, "add grad[{a}]");
        }

        let back = JetScalar::sub(&pq, &q);
        // (p + q) - q must reproduce p in every channel, not only the gradient.
        assert_eq!(back.value(), p.value(), "value roundtrip");
        for a in 0..2 {
            assert_eq!(back.g()[a], p.g()[a], "grad[{a}] roundtrip");
            for b in 0..2 {
                assert_eq!(back.h()[a][b], p.h()[a][b], "hess[{a}][{b}] roundtrip");
            }
        }
    }

    /// `Order2::mul` of two variables satisfies the Leibniz product rule:
    ///   ∂(p·q)/∂p = q,  ∂(p·q)/∂q = p,  ∂²(p·q)/∂p∂q = 1.
    #[test]
    fn order2_mul_satisfies_leibniz_rule() {
        let pv = 3.0_f64;
        let qv = -2.0_f64;
        let p = Order2::<2>::variable(pv, 0);
        let q = Order2::<2>::variable(qv, 1);
        let pq = JetScalar::mul(&p, &q);
        assert_eq!(pq.value(), pv * qv, "value = p·q");
        assert_eq!(pq.g()[0], qv, "∂(p·q)/∂p = q");
        assert_eq!(pq.g()[1], pv, "∂(p·q)/∂q = p");
        assert_eq!(pq.h()[0][1], 1.0, "∂²(p·q)/∂p∂q = 1");
        assert_eq!(pq.h()[1][0], 1.0, "∂²(p·q)/∂q∂p = 1 (symmetric)");
        assert_eq!(pq.h()[0][0], 0.0, "∂²(p·q)/∂p² = 0");
        assert_eq!(pq.h()[1][1], 0.0, "∂²(p·q)/∂q² = 0");
    }

    /// `Order2::scale(s)` multiplies every channel by `s`.
    ///
    /// A bare variable has an all-zero Hessian, so the second-order channel is
    /// exercised on `p·p` (value 16, first derivative 8, second derivative 2 at
    /// `p = 4`); every expected product with `s = 2.5` is exact in `f64`.
    #[test]
    fn order2_scale_multiplies_all_channels() {
        let p = Order2::<2>::variable(4.0, 0);
        let s = 2.5_f64;

        let ps = JetScalar::scale(&p, s);
        assert_eq!(ps.value(), 4.0 * s);
        assert_eq!(ps.g()[0], 1.0 * s);
        assert_eq!(ps.g()[1], 0.0);

        // Non-zero Hessian case: scale(p²) must scale ∂²/∂p² = 2 as well.
        let sq = JetScalar::mul(&p, &p);
        let sq_scaled = JetScalar::scale(&sq, s);
        assert_eq!(sq_scaled.value(), 16.0 * s, "scaled value of p²");
        assert_eq!(sq_scaled.g()[0], 8.0 * s, "scaled ∂(p²)/∂p");
        assert_eq!(sq_scaled.h()[0][0], 2.0 * s, "scaled ∂²(p²)/∂p²");
    }

    /// `Order2::exp` follows the chain rule: for an input with value `c`,
    /// gradient `g` and Hessian `H` the result is `e^c`, `e^c · g`,
    /// `e^c · (g⊗g + H)`. At a seeded variable `p₀` (unit gradient, zero
    /// Hessian) all three channels therefore equal `e^{p₀}`.
    #[test]
    fn order2_exp_derivative_stack_correct() {
        let p0 = 1.0_f64;
        let p = Order2::<1>::variable(p0, 0);
        let ep = JetScalar::exp(&p);
        let e = p0.exp();
        assert!((ep.value() - e).abs() < 1e-15, "exp value");
        assert!((ep.g()[0] - e).abs() < 1e-15, "d/dp exp(p) = exp(p)");
        assert!((ep.h()[0][0] - e).abs() < 1e-15, "d²/dp² exp(p) = exp(p)");
    }

    /// `Order2::ln` at a seeded variable: d/dp ln(p) = 1/p, d²/dp² ln(p) = -1/p².
    #[test]
    fn order2_ln_derivative_stack_correct() {
        let p0 = 2.0_f64;
        let p = Order2::<1>::variable(p0, 0);
        let lnp = JetScalar::ln(&p);
        assert!((lnp.value() - p0.ln()).abs() < 1e-15, "ln value");
        assert!((lnp.g()[0] - 1.0 / p0).abs() < 1e-15, "d/dp ln(p) = 1/p");
        assert!(
            (lnp.h()[0][0] - (-1.0 / (p0 * p0))).abs() < 1e-15,
            "d²/dp² ln(p) = -1/p²"
        );
    }

    /// `exp` and `ln` are mutual inverses: `ln(exp(p)).value()` recovers `p`
    /// up to rounding. Only the value channel is checked here.
    #[test]
    fn order2_exp_ln_roundtrip_at_value() {
        let p0 = 0.8_f64;
        let p = Order2::<1>::variable(p0, 0);
        let roundtrip = JetScalar::ln(&JetScalar::exp(&p));
        assert!((roundtrip.value() - p0).abs() < 1e-14, "ln(exp(p)) ≈ p");
    }

    // ── Order1 tests ─────────────────────────────────────────────────────────

    /// `Order1::constant` carries the correct value with all-zero gradient.
    #[test]
    fn order1_constant_has_zero_gradient() {
        let s = Order1::<3>::constant(-5.0);
        assert_eq!(s.value(), -5.0);
        for a in 0..3 {
            assert_eq!(s.g()[a], 0.0, "g[{a}] should be zero");
        }
    }

    /// `Order1::variable(x, axis)` has unit gradient only in `axis`.
    #[test]
    fn order1_variable_has_unit_gradient_in_seeded_slot() {
        let s = Order1::<3>::variable(2.0, 1);
        assert_eq!(s.value(), 2.0);
        assert_eq!(s.g()[0], 0.0);
        assert_eq!(s.g()[1], 1.0);
        assert_eq!(s.g()[2], 0.0);
    }

    /// `Order1::mul` satisfies the product rule (value and gradient, no Hessian).
    #[test]
    fn order1_mul_satisfies_product_rule() {
        let pv = 3.0_f64;
        let qv = -2.0_f64;
        let p = Order1::<2>::variable(pv, 0);
        let q = Order1::<2>::variable(qv, 1);
        let pq = JetScalar::mul(&p, &q);
        assert_eq!(pq.value(), pv * qv);
        assert_eq!(pq.g()[0], qv, "∂(p·q)/∂p = q");
        assert_eq!(pq.g()[1], pv, "∂(p·q)/∂q = p");
    }

    /// `Order1::exp` carries the correct value and gradient `e^{p₀}`; the
    /// gradient slot of the axis that was not seeded stays exactly zero.
    #[test]
    fn order1_exp_has_correct_value_and_gradient() {
        let p0 = 0.5_f64;
        let p = Order1::<2>::variable(p0, 0);
        let ep = JetScalar::exp(&p);
        let e = p0.exp();
        assert!((ep.value() - e).abs() < 1e-15, "exp value");
        assert!((ep.g()[0] - e).abs() < 1e-15, "d/dp exp(p)");
        assert_eq!(ep.g()[1], 0.0, "irrelevant gradient slot is zero");
    }

    /// `Order1` and `Order2` agree on value and gradient for the same expression.
    #[test]
    fn order1_and_order2_agree_on_value_and_gradient() {
        let p0 = 1.3_f64;
        let q0 = -0.7_f64;
        // evaluate (p * q + p).exp() at (p0, q0)
        let p1 = Order1::<2>::variable(p0, 0);
        let q1 = Order1::<2>::variable(q0, 1);
        let expr1 = JetScalar::exp(&JetScalar::add(&JetScalar::mul(&p1, &q1), &p1));

        let p2 = Order2::<2>::variable(p0, 0);
        let q2 = Order2::<2>::variable(q0, 1);
        let expr2 = JetScalar::exp(&JetScalar::add(&JetScalar::mul(&p2, &q2), &p2));

        assert!(
            (expr1.value() - expr2.value()).abs() < 1e-14,
            "value mismatch"
        );
        for a in 0..2 {
            assert!(
                (expr1.g()[a] - expr2.g()[a]).abs() < 1e-14,
                "gradient[{a}] mismatch"
            );
        }
    }

    // ── filtered_implicit_solve_scalar ────────────────────────────────────────

    /// Lift the trivial linear constraint F(a, θ) = a - θ = 0 through `Order2<1>`.
    /// The exact lifted jet is a(θ) = θ, so value=θ₀, gradient=1, Hessian=0.
    #[test]
    fn filtered_implicit_solve_linear_constraint_gives_exact_jet() {
        let theta0 = 3.0_f64;
        let theta = Order2::<1>::variable(theta0, 0);
        // a0 = theta0, F_a = 1, inv_fa = 1; 2 iters suffice for Order2.
        let a = filtered_implicit_solve_scalar::<1, Order2<1>>(theta0, 1.0, 2, |a_jet| {
            JetScalar::sub(a_jet, &theta)
        });
        assert!((a.value() - theta0).abs() < 1e-14, "value = theta0");
        // da/dtheta = 1 (identity)
        assert!((a.g()[0] - 1.0).abs() < 1e-14, "gradient = 1");
        // d²a/dtheta² = 0 (linear)
        assert!(a.h()[0][0].abs() < 1e-14, "hessian = 0");
    }

    /// `filtered_implicit_solve_scalar` on a quadratic constraint F(a,θ)=a²-θ=0
    /// with primal root a₀=√θ₀, giving da/dθ = 1/(2√θ₀), d²a/dθ² = -1/(4θ₀^{3/2}).
    #[test]
    fn filtered_implicit_solve_quadratic_constraint_matches_analytic_derivatives() {
        let theta0 = 4.0_f64;
        let a0 = theta0.sqrt();
        // F_a = 2a at the root, so the inverse passed to the solver is 1/(2a₀).
        let inv_fa = 1.0 / (2.0 * a0);
        let theta = Order2::<1>::variable(theta0, 0);
        // F(a,theta) = a*a - theta
        let a = filtered_implicit_solve_scalar::<1, Order2<1>>(a0, inv_fa, 2, |a_jet| {
            let aa = JetScalar::mul(a_jet, a_jet);
            JetScalar::sub(&aa, &theta)
        });
        let tol = 1e-12;
        assert!((a.value() - a0).abs() < tol, "value = sqrt(theta0)");
        let expected_g = 0.5 / a0;
        assert!(
            (a.g()[0] - expected_g).abs() < tol,
            "da/dtheta = 1/(2*sqrt)"
        );
        // θ₀^{3/2} = θ₀ · √θ₀ = theta0 * a0.
        let expected_h = -0.25 / (theta0 * a0);
        assert!(
            (a.h()[0][0] - expected_h).abs() < tol,
            "d2a/dtheta2 = -1/(4*theta^1.5)"
        );
    }
}