gam_math/jet_scalar.rs
1//! Order-specific Taylor-jet SCALAR algebras (#932 cutover, doc §A).
2//!
3//! [`crate::jet_tower::Tower4`] carries the full value/gradient/Hessian/`t3`/`t4`
4//! tensor stack: it answers EVERY channel a [`super::row_kernel::RowKernel`]
5//! consumer can ask for, but at `K = 9` that is a ~50 KiB per-row object whose
6//! by-value copies overflowed the stack and timed out the location-scale fit —
7//! which is exactly why `row_kernel_directional_supported()` /
8//! `row_kernel_joint_hessian_supported()` still `return false`. The cutover does
9//! NOT need the dense `Tower4<9>` per row; it needs, per consumer, only the one
10//! channel that consumer serves:
11//!
12//! | consumer | channel | scalar here | K=9 size |
13//! |---|---|---|---|
14//! | inner Newton / `row_kernel` | `(v, g, H)` | [`Order2`] | 728 B |
//! | `row_third_contracted(dir)` | `Σ_c ℓ_{abc} dir_c` | [`OneSeed`] | 1.42 KiB |
16//! | `row_fourth_contracted(u, v)` | `Σ_{cd} ℓ_{abcd} u_c v_d` | [`TwoSeed`] | 2.8 KiB |
17//!
18//! Each is built on [`Order2`] (value/grad/Hessian), which is the production
19//! [`crate::jet_tower::Tower2`] re-expressed behind a generic interface: a row
20//! loss written ONCE against [`JetScalar`] re-instantiates at whatever order /
21//! representation a consumer needs, with the contraction folded INTO the
22//! differentiation (the nilpotent ε / δ directions), so `t3` / `t4` are never
23//! materialised. The single source of truth is the same one expression — the
24//! genus of #736 cross-block drift cannot reappear because there is no separate
25//! channel to forget.
26//!
27//! # Why each scalar is exact (doc §A.1–A.3)
28//!
29//! * [`Order2`] is the order-≤2 truncation of the Leibniz / Faà di Bruno rules.
30//! Those order-2 terms read ONLY the order-≤2 channels of their inputs (see
31//! [`crate::jet_tower::Tower4::mul`]: `out.h[i][j]` never touches `t3`/`t4`),
32//! so its `(v, g, H)` is BIT-IDENTICAL to a full `Tower4<K>` — and identical
33//! to [`crate::jet_tower::Tower2`], over which it is a thin newtype.
34//! * [`OneSeed`] carries an [`Order2`] base plus one nilpotent ε (`ε² = 0`)
35//! holding another [`Order2`]. Seeding ε with the fixed direction `u` makes the
36//! ε-component of the Hessian channel the contracted third `Σ_c ℓ_{abc} u_c`
37//! (the nilpotent implements `d/dτ|₀` of `ℓ_{ab}(p + τu)` exactly).
38//! * [`TwoSeed`] carries an [`Order2`] base plus ε, δ (`ε² = δ² = 0`, `εδ`
39//! retained) — four [`Order2`] parts. Seeding ε, δ with `u, v` makes the
40//! εδ-component of the Hessian channel the contracted fourth
41//! `Σ_{cd} ℓ_{abcd} u_c v_d` (the single mixed `∂_σ∂_ρ|₀` term, no `σ²`/`ρ²`
42//! contamination).
43//!
44//! # Stability discipline
45//!
46//! As in [`crate::jet_tower`], humans own primitive stability and the algebra
47//! owns combinatorics: tail-critical special functions enter ONLY as
48//! hand-certified `[f64; 5]` derivative stacks through [`JetScalar::compose_unary`]
49//! (each scalar consumes the leading entries its order needs), never by
50//! differentiating an unstable primal.
51//!
52//! # Production scalars and the test-only all-channels oracle
53//!
54//! The `JetScalar` trait below is production: it is the bound on
55//! [`crate::jet_tower::RowNllProgramGeneric::row_nll_generic`], the seam a family
56//! row loss is written against. The order-specific scalars that *consume* it —
57//! [`Order2`] (value/grad/Hessian), [`OneSeed`] (contracted third) and
58//! [`TwoSeed`] (contracted fourth) — are production: the survival location-scale
59//! `RowKernel<9>` builds its joint Hessian / directional derivatives through them
60//! (`survival::location_scale::row_kernel`), paying only the small packed scalar
61//! per row instead of the ~50 KiB dense [`crate::jet_tower::Tower4`].
62//!
63//! The [`crate::jet_tower::Tower4`] all-channels `JetScalar` impl is test-only: it
64//! is the oracle that pins the contracted scalars against the dense
65//! value/grad/Hessian/`t3`/`t4` truth, so it lives in the `#[cfg(test)]` module.
66
67/// A truncated-Taylor scalar carrying derivatives in `K` primaries.
68///
69/// All concrete scalars here ([`Order2`], [`OneSeed`], [`TwoSeed`]) and the full
70/// [`crate::jet_tower::Tower4`] implement the SAME algebra; only the carried
71/// channel set differs. A row loss written once against this interface yields a
72/// different channel set per instantiation, all exact for the channel they serve
73/// (doc §A.0).
74pub trait JetScalar<const K: usize>: Copy {
75 /// A constant: value `c`, every derivative channel zero.
76 fn constant(c: f64) -> Self;
77
78 /// The seeded variable `p_axis` at value `x`: unit first derivative in slot
79 /// `axis`, all higher channels zero. (The nilpotent / cross channels of the
80 /// directional scalars are seeded zero — callers set ε/δ directions through
81 /// the scalar-specific [`OneSeed::seed_direction`] / [`TwoSeed::seed`].)
82 fn variable(x: f64, axis: usize) -> Self;
83
84 /// The value channel `ℓ(p)`.
85 fn value(&self) -> f64;
86
87 /// Exact truncated Leibniz sum `self + o`.
88 fn add(&self, o: &Self) -> Self;
89 /// Exact truncated Leibniz difference `self − o`.
90 fn sub(&self, o: &Self) -> Self;
91 /// Exact truncated Leibniz product `self · o`.
92 fn mul(&self, o: &Self) -> Self;
93 /// Negate every channel.
94 fn neg(&self) -> Self;
95 /// Multiply every channel by a plain scalar `s`.
96 fn scale(&self, s: f64) -> Self;
97
98 /// Exact multivariate Faà di Bruno composition `f ∘ self`, given the outer
99 /// derivative stack `d = [f(u), f′(u), f″(u), f‴(u), f⁗(u)]` at
100 /// `u = self.value()`.
101 ///
102 /// This is the SAME `[f64; 5]` stack shape [`crate::jet_tower::Tower4`] and
103 /// the families' `unary_derivatives_*` helpers (built on erfcx / log_ndtr)
104 /// already produce, so those stacks plug in directly. Each scalar consumes
105 /// only the leading entries its order needs (order-2 reads `d[0..=2]`; the
106 /// directional scalars read one / two beyond their base) — the fixed-length
107 /// array makes that windowing total, no length guard required.
108 fn compose_unary(&self, d: [f64; 5]) -> Self;
109
110 /// Compose with a unary special-function whose derivative STACK is built
111 /// from the scalar base value through `stack_fn` — the generic-over-`Lane`
112 /// seam that lets a single-sourced row program instantiate at BOTH the scalar
113 /// `f64` jets and the SIMD `f64x4` batch towers from ONE expression.
114 ///
115 /// On a scalar jet this evaluates `stack_fn(self.value())` ONCE and forwards
116 /// to [`compose_unary`](Self::compose_unary), so it is BIT-IDENTICAL to the
117 /// hand-written `self.compose_unary(stack_fn(self.value()))` (default body
118 /// below). The lever is that the SAME call shape exists on
119 /// [`crate::jet_tower::Tower3Lane`] / [`crate::jet_tower::Tower4Lane`], where
120 /// the four lanes carry FOUR DISTINCT base values, so the batch
121 /// implementation re-runs `stack_fn` per lane — a thing the old
122 /// `compose_unary(stack_from(self.value()))` shape could not express on a
123 /// batch type (it has no single scalar `.value()`). Writing a row program
124 /// against this method instead of the explicit two-step is what makes it
125 /// instantiate, unchanged, at `f64x4` for the 4-rows-per-pass batch path.
126 fn compose_unary_with(&self, stack_fn: impl Fn(f64) -> [f64; 5]) -> Self {
127 self.compose_unary(stack_fn(self.value()))
128 }
129
130 /// `e^self`. Convenience for tame arguments (see module stability note).
131 fn exp(&self) -> Self {
132 let e = self.value().exp();
133 self.compose_unary([e, e, e, e, e])
134 }
135
136 /// `√self`. Caller guarantees positivity.
137 fn sqrt(&self) -> Self {
138 let u = self.value();
139 let s = u.sqrt();
140 self.compose_unary([
141 s,
142 0.5 / s,
143 -0.25 / (u * s),
144 0.375 / (u * u * s),
145 -0.9375 / (u * u * u * s),
146 ])
147 }
148
149 /// `ln(self)`. Caller guarantees positivity. Same derivative stack
150 /// [`crate::jet_tower::Tower4::ln`] uses, so any program written over both
151 /// matches term-for-term.
152 fn ln(&self) -> Self {
153 let u = self.value();
154 let r = 1.0 / u;
155 self.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
156 }
157
158 /// `1/self`.
159 fn recip(&self) -> Self {
160 let r = 1.0 / self.value();
161 let r2 = r * r;
162 self.compose_unary([r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r])
163 }
164
165 /// `self^a` for real exponent `a`. Caller guarantees a positive base.
166 /// Mirrors [`crate::jet_tower::Tower4::powf`] (falling-factorial stack).
167 fn powf(&self, a: f64) -> Self {
168 let u = self.value();
169 self.compose_unary([
170 u.powf(a),
171 a * u.powf(a - 1.0),
172 a * (a - 1.0) * u.powf(a - 2.0),
173 a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
174 a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
175 ])
176 }
177
178 /// `ln Γ(self)`. Caller guarantees a positive argument. Uses the SAME
179 /// hand-certified derivative stack [`crate::jet_tower::Tower4::ln_gamma`]
180 /// consumes ([`crate::jet_tower::ln_gamma_derivative_stack`]), so any
181 /// program written over both matches term-for-term.
182 fn ln_gamma(&self) -> Self {
183 self.compose_unary(crate::jet_tower::ln_gamma_derivative_stack(self.value()))
184 }
185
186 /// `ψ(self) = d/dx ln Γ(x)` (digamma). Caller guarantees a positive
187 /// argument. Same hand-certified stack
188 /// [`crate::jet_tower::digamma_derivative_stack`].
189 fn digamma(&self) -> Self {
190 self.compose_unary(crate::jet_tower::digamma_derivative_stack(self.value()))
191 }
192}
193
194// ── Order2<K> ergonomic operator overloads (doc §A.1) ───────────────────
195//
// The dispersion-family row NLLs are written with `+`/`-`/`*` operators over
// the primaries (mirroring how they read as `Tower4` expressions). These impls
// delegate channel-for-channel to the inner `Tower2` arithmetic: `+` (with a
// jet or an `f64`) forwards to `Tower2`'s `+`, `*` to `Tower2::mul`, and
// `* f64` to `Tower2::scale`. Since `Tower2` has no `Sub` op, `-` is expressed
// as `self + (-1)·rhs`, `- c` as `self + (-c)`, and unary `-` as `(-1)·self`,
// exactly as the `JetScalar::sub` / `JetScalar::neg` impls do — so an `Order2`
// expression is bit-identical to the same `Tower4` expression's order-≤2
// channels.
202
203impl<const K: usize> std::ops::Add for Order2<K> {
204 type Output = Self;
205 #[inline]
206 fn add(self, o: Self) -> Self {
207 Order2(self.0 + o.0)
208 }
209}
210
211impl<const K: usize> std::ops::Add<f64> for Order2<K> {
212 type Output = Self;
213 #[inline]
214 fn add(self, c: f64) -> Self {
215 Order2(self.0 + c)
216 }
217}
218
219impl<const K: usize> std::ops::Sub for Order2<K> {
220 type Output = Self;
221 #[inline]
222 fn sub(self, o: Self) -> Self {
223 Order2(self.0 + o.0.scale(-1.0))
224 }
225}
226
227impl<const K: usize> std::ops::Sub<f64> for Order2<K> {
228 type Output = Self;
229 #[inline]
230 fn sub(self, c: f64) -> Self {
231 Order2(self.0 + (-c))
232 }
233}
234
235impl<const K: usize> std::ops::Mul for Order2<K> {
236 type Output = Self;
237 #[inline]
238 fn mul(self, o: Self) -> Self {
239 Order2(crate::jet_tower::Tower2::mul(&self.0, &o.0))
240 }
241}
242
243impl<const K: usize> std::ops::Mul<f64> for Order2<K> {
244 type Output = Self;
245 #[inline]
246 fn mul(self, c: f64) -> Self {
247 Order2(self.0.scale(c))
248 }
249}
250
251impl<const K: usize> std::ops::Neg for Order2<K> {
252 type Output = Self;
253 #[inline]
254 fn neg(self) -> Self {
255 Order2(self.0.scale(-1.0))
256 }
257}
258
259/// Filtered Hensel lift of a SCALAR implicit state `a(θ)` defined by the
260/// constraint `F(a, θ) = 0`, evaluated in ANY [`JetScalar`] algebra `S` (doc
261/// §11, "A generic implicit-lift operator for every production scalar").
262///
263/// This is the perf-respecting alternative to lifting through a dense
264/// `Tower4<K+1>` (which carries the implicit variable as an extra dense axis):
265/// the state `a` lives directly in the consumer's own `K`-primary algebra
266/// `S` — `Order2<K>` for value/gradient/Hessian, `Tower4<K>` for the full
267/// `t3`/`t4` — never paying for an extra variable.
268///
269/// **Method.** Fixed-Jacobian Newton in the nilpotent algebra. By the
270/// filtered-lift theorem (doc §11.1), if `F_a := ∂F/∂a(a₀, θ₀)` is the primal
271/// Jacobian at the base point and `inv_fa = 1/F_a`, then the iteration
272/// `A ← A − inv_fa · F(A, θ)` raises the filtration degree of the residual by
273/// at least one per step: each step kills exactly one graded layer. Starting
274/// from `A = const(a₀)` (whose residual lies in `F¹` because `θ − θ₀ ∈ 𝔫`),
275/// `iters` equal to the algebra's nilpotency order returns the *exact* lifted
276/// jet (`Order2`: 2, `OneSeed`: 3, `Tower4`/`TwoSeed`: 4). The value channel of
277/// `A` never moves — `F(A, θ).value() = F(a₀, θ₀) = 0` at the certified root —
278/// so a caller may precompute every primitive's derivative stack at the fixed
279/// base index once and let the cheap polynomial composition repeat per step.
280///
281/// `f` evaluates the constraint `F(a, θ)` in `S` (capturing the seeded
282/// parameter jets `θ`); `a0` is the certified scalar root `F(a₀, θ₀) ≈ 0`.
283pub fn filtered_implicit_solve_scalar<const K: usize, S: JetScalar<K>>(
284 a0: f64,
285 inv_fa: f64,
286 iters: usize,
287 f: impl Fn(&S) -> S,
288) -> S {
289 let mut a = S::constant(a0);
290 for _ in 0..iters {
291 let residual = f(&a);
292 a = a.sub(&residual.scale(inv_fa));
293 }
294 a
295}
296
297// ── Order2<K>: value / gradient / Hessian (doc §A.1) ────────────────────
298
299/// Truncated SECOND-order scalar: value `v`, gradient `g_a`, Hessian `H_{ab}`.
300///
301/// This is a thin newtype over the production [`crate::jet_tower::Tower2`], so
302/// its `(v, g, H)` channels are obtained by the SAME formulas — and are
303/// therefore bit-identical to both [`crate::jet_tower::Tower2`] and the order-≤2
304/// channels of a full [`crate::jet_tower::Tower4`] (doc §A.1, "Bit-identity with
305/// the full tower"). The wrapper exists only to satisfy the generic
306/// [`JetScalar`] interface (the `compose_unary` / `add` / `sub` / `neg` /
307/// `recip` the trait demands, which `Tower2` does not expose by that shape) —
308/// every channel is delegated to `Tower2` arithmetic unchanged.
309#[derive(Clone, Copy, Debug)]
310pub struct Order2<const K: usize>(pub crate::jet_tower::Tower2<K>);
311
312impl<const K: usize> Order2<K> {
313 /// Read the gradient channel `g_a = ∂ℓ/∂p_a`.
314 #[inline]
315 pub fn g(&self) -> [f64; K] {
316 self.0.g
317 }
318
319 /// Read the Hessian channel.
320 #[inline]
321 pub fn h(&self) -> [[f64; K]; K] {
322 self.0.h
323 }
324}
325
326impl<const K: usize> JetScalar<K> for Order2<K> {
327 fn constant(c: f64) -> Self {
328 Order2(crate::jet_tower::Tower2::constant(c))
329 }
330 fn variable(x: f64, axis: usize) -> Self {
331 Order2(crate::jet_tower::Tower2::variable(x, axis))
332 }
333 fn value(&self) -> f64 {
334 self.0.v
335 }
336 fn add(&self, o: &Self) -> Self {
337 Order2(self.0 + o.0)
338 }
339 fn sub(&self, o: &Self) -> Self {
340 // Tower2 has no Sub op; subtract by adding the negation, matching
341 // Tower4::sub (self + o.scale(-1.0)).
342 Order2(self.0 + o.0.scale(-1.0))
343 }
344 fn mul(&self, o: &Self) -> Self {
345 Order2(crate::jet_tower::Tower2::mul(&self.0, &o.0))
346 }
347 fn neg(&self) -> Self {
348 Order2(self.0.scale(-1.0))
349 }
350 fn scale(&self, s: f64) -> Self {
351 Order2(self.0.scale(s))
352 }
353 fn compose_unary(&self, d: [f64; 5]) -> Self {
354 // Order-≤2 reads only [f, f', f''] of the stack.
355 Order2(self.0.compose_unary([d[0], d[1], d[2]]))
356 }
357}
358
359// ── Lane-batched Order-2 scalar: 4 rows per pass in SIMD lanes (perf) ────
360//
361// The hot per-row jet kernels evaluate ONE row's `(v, g, H)` tower at a time in
362// scalar `f64`. A hand-written scalar derivative does the same. The throughput
363// lever a jet has that scalar hand-code cannot is **row batching in SIMD
364// lanes**: the order-≤2 Leibniz product `Order2::mul` is `O(K²)` independent
365// per-channel float ops, and EVERY row runs the identical op graph on different
366// data — the textbook SPMD shape. Packing `LANES = 4` rows into a `wide::f64x4`
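// and running the algebra once per 4 rows replaces 4 scalar passes with one
// vector pass: each of the `K²` Hessian channel updates becomes one `f64x4` op
// covering 4 rows — a single 256-bit AVX instruction, or a pair of 128-bit
// NEON `.2d` / SSE2 `pd` instructions (2 rows each) — i.e. ~2–4× fewer FP
// instructions per row depending on the target's vector width.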
370//
371// The carried scalar field is abstracted by [`Lane`] so the SAME algebra body
372// instantiates at `f64` (1 row, used as the bit-identity oracle) or
373// [`wide::f64x4`] (4 rows). Bit-identity is structural, not approximate:
374//
375// * Every arithmetic op is a plain lane-wise `+` / `-` / `*` (NEVER a fused
376// `mul_add`), and IEEE-754 double `+`/`-`/`*`/`/` are correctly rounded and
377// deterministic, so lane `i` of an `f64x4` op equals the scalar `f64` op on
378// that lane's inputs bit-for-bit.
379// * The transcendental derivative STACKS (`exp`/`ln`/`sqrt`/…) are produced
380// **per lane by the identical scalar code** ([`Lane::unary3`] unpacks, runs
381// the same `[f64; 3]` stack closure the scalar path runs, repacks), so the
382// only thing vectorised is the cheap rational tensor composition — the
383// library transcendental itself is the exact same `f64::exp` call per lane.
384// * The op order mirrors [`crate::jet_tower::Tower2`] term-for-term, so
385// [`Order2Lane<f64, K>`] is `to_bits`-identical to the production
386// [`Order2<K>`] (= `Tower2<K>`), and [`Order2Lane<f64x4, K>`] lane `i` is
387// `to_bits`-identical to that — proven by the `batch_tests` oracle below
388// (≥2000 random 4-row batches across `K ∈ {2,3,4,9}`).
389
/// The scalar field an [`Order2Lane`] carries: a single `f64` (one row — the
/// bit-identity oracle) or a [`wide::f64x4`] (four rows in SIMD lanes).
///
/// Implementations use plain lane-wise IEEE arithmetic only (never a fused
/// multiply-add). A vector op therefore agrees with the scalar op on each lane
/// bit-for-bit.
pub trait Lane: Copy {
    /// Broadcast `x` into every lane.
    fn splat(x: f64) -> Self;
    /// The `f64` held in lane `i` (`i` below the lane count; `f64` ignores `i`).
    fn lane(self, i: usize) -> f64;

    /// Lane-wise sum `self + o`.
    fn add(self, o: Self) -> Self;
    /// Lane-wise difference `self - o`.
    fn sub(self, o: Self) -> Self;
    /// Lane-wise product `self * o`.
    fn mul(self, o: Self) -> Self;

    /// Build an `N`-entry derivative stack per lane, then pack it column-wise.
    ///
    /// The SAME scalar `stack` closure the per-row path runs is evaluated once
    /// per lane, at that lane's OWN base value (each of the four rows in an
    /// `f64x4` has a distinct base). Lane `i` of the packed result therefore
    /// equals the scalar `stack(value_i)` bit-for-bit; only the cheap pack is
    /// vectorised.
    ///
    /// This is the lane primitive behind the compose-with-stack seam
    /// ([`crate::jet_tower::Tower4Lane::compose_unary_with`] and its `Tower3`
    /// sibling). With `N = 3` / `N = 5` it is `to_bits`-identical to
    /// [`unary3`](Lane::unary3) / [`unary5`](Lane::unary5).
    fn unary_with<const N: usize>(self, stack: impl Fn(f64) -> [f64; N]) -> [Self; N];
    /// The order-≤2 stack `[f(u), f′(u), f″(u)]`, built per lane from the lane
    /// value `u` by the same scalar `stack` closure the per-row path runs. The
    /// transcendental/rational stack thus matches the scalar evaluation
    /// bit-for-bit; only the later tensor composition is vectorised.
    fn unary3(self, stack: impl Fn(f64) -> [f64; 3]) -> [Self; 3];
    /// The order-≤4 stack `[f, f′, f″, f‴, f⁗]`, built per lane the same way.
    ///
    /// The one- / two-seed scalars ([`OneSeedLane`] / [`TwoSeedLane`]) need
    /// outer derivatives one / two orders beyond their order-2 base, so they
    /// build their composition stack through this five-entry variant. As with
    /// [`unary3`](Lane::unary3), only the tensor composition is vectorised.
    fn unary5(self, stack: impl Fn(f64) -> [f64; 5]) -> [Self; 5];
}
433
434impl Lane for f64 {
435 #[inline]
436 fn splat(x: f64) -> Self {
437 x
438 }
439 #[inline]
440 fn add(self, o: Self) -> Self {
441 self + o
442 }
443 #[inline]
444 fn sub(self, o: Self) -> Self {
445 self - o
446 }
447 #[inline]
448 fn mul(self, o: Self) -> Self {
449 self * o
450 }
451 #[inline]
452 fn lane(self, _: usize) -> f64 {
453 self
454 }
455 #[inline]
456 fn unary3(self, stack: impl Fn(f64) -> [f64; 3]) -> [Self; 3] {
457 stack(self)
458 }
459 #[inline]
460 fn unary5(self, stack: impl Fn(f64) -> [f64; 5]) -> [Self; 5] {
461 stack(self)
462 }
463 #[inline]
464 fn unary_with<const N: usize>(self, stack: impl Fn(f64) -> [f64; N]) -> [Self; N] {
465 // One row: the packed result IS the scalar stack ([Self; N] = [f64; N]).
466 stack(self)
467 }
468}
469
470impl Lane for wide::f64x4 {
471 #[inline]
472 fn splat(x: f64) -> Self {
473 wide::f64x4::splat(x)
474 }
475 #[inline]
476 fn add(self, o: Self) -> Self {
477 self + o
478 }
479 #[inline]
480 fn sub(self, o: Self) -> Self {
481 self - o
482 }
483 #[inline]
484 fn mul(self, o: Self) -> Self {
485 self * o
486 }
487 #[inline]
488 fn lane(self, i: usize) -> f64 {
489 self.to_array()[i]
490 }
491 #[inline]
492 fn unary3(self, stack: impl Fn(f64) -> [f64; 3]) -> [Self; 3] {
493 let a = self.to_array();
494 let mut d0 = [0.0_f64; 4];
495 let mut d1 = [0.0_f64; 4];
496 let mut d2 = [0.0_f64; 4];
497 for i in 0..4 {
498 let s = stack(a[i]);
499 d0[i] = s[0];
500 d1[i] = s[1];
501 d2[i] = s[2];
502 }
503 [
504 wide::f64x4::new(d0),
505 wide::f64x4::new(d1),
506 wide::f64x4::new(d2),
507 ]
508 }
509 #[inline]
510 fn unary5(self, stack: impl Fn(f64) -> [f64; 5]) -> [Self; 5] {
511 let a = self.to_array();
512 let mut d = [[0.0_f64; 4]; 5];
513 for i in 0..4 {
514 let s = stack(a[i]);
515 for (k, dk) in d.iter_mut().enumerate() {
516 dk[i] = s[k];
517 }
518 }
519 [
520 wide::f64x4::new(d[0]),
521 wide::f64x4::new(d[1]),
522 wide::f64x4::new(d[2]),
523 wide::f64x4::new(d[3]),
524 wide::f64x4::new(d[4]),
525 ]
526 }
527 #[inline]
528 fn unary_with<const N: usize>(self, stack: impl Fn(f64) -> [f64; N]) -> [Self; N] {
529 // Evaluate the scalar stack PER LANE at that lane's own base value, then
530 // pack the N derivative columns lane-wise (the same shape `unary5` uses,
531 // generalised to N). Lane `i` of column `k` is `stack(base_i)[k]`.
532 let a = self.to_array();
533 let mut cols = [[0.0_f64; 4]; N];
534 for (i, &base) in a.iter().enumerate() {
535 let s = stack(base);
536 for (k, sk) in s.iter().enumerate() {
537 cols[k][i] = *sk;
538 }
539 }
540 std::array::from_fn(|k| wide::f64x4::new(cols[k]))
541 }
542}
543
544/// A lane-batched order-≤2 Taylor scalar: value / gradient / Hessian carried in
545/// a SIMD field [`L: Lane`](Lane). With `L = f64x4` one instance carries FOUR
546/// rows at once, so the row loop processes 4 rows per vector pass instead of one
547/// per scalar pass.
548///
549/// The channel layout and every float op mirror [`crate::jet_tower::Tower2`]
550/// term-for-term, so `Order2Lane<f64, K>` is `to_bits`-identical to the
551/// production [`Order2<K>`] and `Order2Lane<f64x4, K>` lane `i` is
552/// `to_bits`-identical to that (see the module note and `batch_tests`).
553#[derive(Clone, Copy, Debug)]
554pub struct Order2Lane<L: Lane, const K: usize> {
555 /// Value channel `ℓ` (one entry per lane/row).
556 pub v: L,
557 /// Gradient channel `∂ℓ/∂p_a`.
558 pub g: [L; K],
559 /// Hessian channel `∂²ℓ/∂p_a∂p_b` (symmetric).
560 pub h: [[L; K]; K],
561}
562
563/// The 4-rows-per-pass batched order-≤2 scalar (`wide::f64x4` lanes).
564pub type Order2Batch<const K: usize> = Order2Lane<wide::f64x4, K>;
565
566impl<L: Lane, const K: usize> Order2Lane<L, K> {
567 /// A constant: value `c` in every channel-zero slot.
568 #[inline]
569 pub fn constant(c: L) -> Self {
570 Order2Lane {
571 v: c,
572 g: [L::splat(0.0); K],
573 h: [[L::splat(0.0); K]; K],
574 }
575 }
576
577 /// The seeded variable `p_axis` at (per-lane) value `value`: unit first
578 /// derivative in slot `axis`. With `L = f64x4`, `value` packs the four
579 /// rows' values of primary `axis`.
580 #[inline]
581 pub fn variable(value: L, axis: usize) -> Self {
582 let mut out = Self::constant(value);
583 out.g[axis] = L::splat(1.0);
584 out
585 }
586
587 /// Lane-wise `self + o` (mirrors `Tower2` Add: per-channel add).
588 #[inline]
589 pub fn add(&self, o: &Self) -> Self {
590 let mut out = *self;
591 out.v = self.v.add(o.v);
592 for i in 0..K {
593 out.g[i] = self.g[i].add(o.g[i]);
594 for j in 0..K {
595 out.h[i][j] = self.h[i][j].add(o.h[i][j]);
596 }
597 }
598 out
599 }
600
601 /// Multiply every channel by the plain scalar `s` (mirrors `Tower2::scale`).
602 #[inline]
603 pub fn scale(&self, s: f64) -> Self {
604 let sl = L::splat(s);
605 let mut out = *self;
606 out.v = self.v.mul(sl);
607 for i in 0..K {
608 out.g[i] = self.g[i].mul(sl);
609 for j in 0..K {
610 out.h[i][j] = self.h[i][j].mul(sl);
611 }
612 }
613 out
614 }
615
616 /// Lane-wise `self - o`, expressed as `self + o·(-1)` exactly as
617 /// [`Order2::sub`] / `Tower4::sub` do, so signed-zero handling matches.
618 #[inline]
619 pub fn sub(&self, o: &Self) -> Self {
620 self.add(&o.scale(-1.0))
621 }
622
623 /// Negate every channel (= `scale(-1.0)`, matching [`Order2::neg`]).
624 #[inline]
625 pub fn neg(&self) -> Self {
626 self.scale(-1.0)
627 }
628
629 /// Exact order-≤2 Leibniz product, term-for-term identical to
630 /// [`crate::jet_tower::Tower2::mul`] (same factor order, no `mul_add`).
631 ///
632 /// The Hessian channel is symmetric under `i ↔ j` (see
633 /// [`crate::jet_tower::Tower2::mul`] for why the invariant always holds), so
634 /// we compute the upper triangle `j ≥ i` and mirror it — `K(K+1)/2` lane
635 /// entry-chains instead of `K²`. Because each lane entry is already a full
636 /// SIMD op (no cross-`j` lane packing to lose), halving the entry count is a
637 /// direct throughput win (~18 % on `Order2Lane<f64x4, 9>`, the survival batch
638 /// kernel, and ~2× on the `f64` oracle). The upper triangle uses the EXACT
639 /// term order of `Tower2::mul`, so `Order2Lane<f64>` stays `to_bits`-identical
640 /// to `Order2` (= `Tower2`) and `Order2Lane<f64x4>` lane `i` stays
641 /// `to_bits`-identical to that; the mirror makes the batch Hessian exactly
642 /// symmetric, matching the scalar `Tower2::mul` (which mirrors identically).
643 #[inline]
644 pub fn mul(&self, o: &Self) -> Self {
645 let a = self;
646 let b = o;
647 let mut out = Self::constant(a.v.mul(b.v));
648 for i in 0..K {
649 // a.v*b.g[i] + a.g[i]*b.v
650 out.g[i] = a.v.mul(b.g[i]).add(a.g[i].mul(b.v));
651 }
652 for i in 0..K {
653 for j in i..K {
654 // a.v*b.h + a.g[i]*b.g[j] + a.g[j]*b.g[i] + a.h*b.v
655 let hij = a
656 .v
657 .mul(b.h[i][j])
658 .add(a.g[i].mul(b.g[j]))
659 .add(a.g[j].mul(b.g[i]))
660 .add(a.h[i][j].mul(b.v));
661 out.h[i][j] = hij;
662 out.h[j][i] = hij;
663 }
664 }
665 out
666 }
667
668 /// Exact order-≤2 Faà di Bruno composition `f ∘ self`, given the per-lane
669 /// derivative stack `d = [f(u), f′(u), f″(u)]`. Mirrors
670 /// [`crate::jet_tower::Tower2::compose_unary`] term-for-term (`acc` starts at
671 /// `0` then accumulates, so signed-zero collapses identically).
672 #[inline]
673 pub fn compose_unary(&self, d: [L; 3]) -> Self {
674 let mut out = Self::constant(d[0]);
675 for i in 0..K {
676 let mut acc = L::splat(0.0);
677 acc = acc.add(d[1].mul(self.g[i]));
678 out.g[i] = acc;
679 }
680 for i in 0..K {
681 for j in 0..K {
682 let mut acc = L::splat(0.0);
683 acc = acc.add(d[1].mul(self.h[i][j]));
684 acc = acc.add(d[2].mul(self.g[i]).mul(self.g[j]));
685 out.h[i][j] = acc;
686 }
687 }
688 out
689 }
690
691 /// `e^self`, per-lane stack `[e, e, e]` (matches the [`JetScalar::exp`]
692 /// default forwarded through `Order2`).
693 #[inline]
694 pub fn exp(&self) -> Self {
695 let d = self.v.unary3(|u| {
696 let e = u.exp();
697 [e, e, e]
698 });
699 self.compose_unary(d)
700 }
701
702 /// `ln(self)`; caller guarantees positivity. Per-lane stack
703 /// `[ln u, 1/u, -1/u²]` (matches [`JetScalar::ln`] truncated to order 2).
704 #[inline]
705 pub fn ln(&self) -> Self {
706 let d = self.v.unary3(|u| {
707 let r = 1.0 / u;
708 [u.ln(), r, -r * r]
709 });
710 self.compose_unary(d)
711 }
712
713 /// `√self`; caller guarantees positivity. Per-lane stack
714 /// `[s, 0.5/s, -0.25/(u·s)]` (matches [`JetScalar::sqrt`]).
715 #[inline]
716 pub fn sqrt(&self) -> Self {
717 let d = self.v.unary3(|u| {
718 let s = u.sqrt();
719 [s, 0.5 / s, -0.25 / (u * s)]
720 });
721 self.compose_unary(d)
722 }
723
724 /// `1/self`. Per-lane stack `[r, -r², 2r³]` (matches [`JetScalar::recip`]).
725 #[inline]
726 pub fn recip(&self) -> Self {
727 let d = self.v.unary3(|u| {
728 let r = 1.0 / u;
729 let r2 = r * r;
730 [r, -r2, 2.0 * r2 * r]
731 });
732 self.compose_unary(d)
733 }
734
735 /// `self^a` for real `a`; caller guarantees a positive base. Per-lane
736 /// falling-factorial stack (matches [`JetScalar::powf`]).
737 #[inline]
738 pub fn powf(&self, a: f64) -> Self {
739 let d = self.v.unary3(|u| {
740 [
741 u.powf(a),
742 a * u.powf(a - 1.0),
743 a * (a - 1.0) * u.powf(a - 2.0),
744 ]
745 });
746 self.compose_unary(d)
747 }
748}
749
750impl<const K: usize> Order2Batch<K> {
751 /// Extract lane `i`'s `(v, g, H)` as a production [`Order2<K>`] scalar.
752 /// Lane `i` is `to_bits`-identical to evaluating the same program at
753 /// [`Order2<K>`] on row `i` (see `batch_tests`).
754 #[inline]
755 #[must_use]
756 pub fn lane(&self, i: usize) -> Order2<K> {
757 let mut t = crate::jet_tower::Tower2::<K>::constant(self.v.lane(i));
758 for a in 0..K {
759 t.g[a] = self.g[a].lane(i);
760 for b in 0..K {
761 t.h[a][b] = self.h[a][b].lane(i);
762 }
763 }
764 Order2(t)
765 }
766}
767
768// ── Order1<K>: value / gradient only (doc §A.1, first-order prune) ──────
769
/// Truncated FIRST-order scalar: value `v` and gradient `g_a` only — no Hessian.
///
/// This is [`Order2`] with the K×K Hessian channel removed. Value and gradient
/// come from the same order-≤1 truncation of the Leibniz / Faà di Bruno rules
/// that [`Order2`] applies to those two channels, with the float operations in
/// the same order. `(v, g)` is therefore intended to be BIT-IDENTICAL to
/// [`Order2`]'s, and to the order-≤1 channels of a full
/// [`crate::jet_tower::Tower4`].
///
/// Use it at a consumer that reads ONLY value + gradient, such as the SAE
/// β-border channel. The reconstruction there is linear in β, so the
/// Hessian-in-β vanishes and the dense K×K product `Tower2::mul` would build is
/// discarded work. Order-≤1 value/gradient never read any input's Hessian, so
/// dropping that channel removes only the `K²` arithmetic behind an unread
/// tensor.
#[derive(Copy, Clone, Debug)]
pub struct Order1<const K: usize> {
    /// Value channel `ℓ`.
    pub v: f64,
    /// Gradient channel `∂ℓ/∂p_a`.
    pub g: [f64; K],
}
790
791impl<const K: usize> Order1<K> {
792 /// Read the gradient channel `g_a = ∂ℓ/∂p_a`.
793 #[inline]
794 pub fn g(&self) -> [f64; K] {
795 self.g
796 }
797}
798
799impl<const K: usize> JetScalar<K> for Order1<K> {
800 fn constant(c: f64) -> Self {
801 // Order2::constant -> Tower2::constant: value c, all derivatives zero.
802 Order1 { v: c, g: [0.0; K] }
803 }
804 fn variable(x: f64, axis: usize) -> Self {
805 // Order2::variable -> Tower2::variable: unit first derivative in `axis`.
806 let mut g = [0.0; K];
807 g[axis] = 1.0;
808 Order1 { v: x, g }
809 }
810 fn value(&self) -> f64 {
811 self.v
812 }
813 fn add(&self, o: &Self) -> Self {
814 // Tower2 Add: out.v += o.v; out.g[i] += o.g[i] (same float order).
815 let mut g = self.g;
816 for i in 0..K {
817 g[i] += o.g[i];
818 }
819 Order1 { v: self.v + o.v, g }
820 }
821 fn sub(&self, o: &Self) -> Self {
822 // Mirror Order2::sub == self + o.scale(-1.0) exactly: scale then add.
823 self.add(&o.scale(-1.0))
824 }
825 fn mul(&self, o: &Self) -> Self {
826 // Tower2::mul value/grad terms, identical float order:
827 // v = a.v*b.v; g[i] = a.v*b.g[i] + a.g[i]*b.v.
828 // (The Hessian loop `a.v*b.h + a.g*b.g + ... + a.h*b.v` is the discarded
829 // work this type exists to skip; it never feeds v or g.)
830 let a = self;
831 let b = o;
832 let mut g = [0.0; K];
833 for i in 0..K {
834 g[i] = a.v * b.g[i] + a.g[i] * b.v;
835 }
836 Order1 { v: a.v * b.v, g }
837 }
838 fn neg(&self) -> Self {
839 // Order2::neg == self.0.scale(-1.0).
840 self.scale(-1.0)
841 }
842 fn scale(&self, s: f64) -> Self {
843 // Tower2::scale: out.v *= s; out.g[i] *= s (same float order).
844 let mut g = self.g;
845 for i in 0..K {
846 g[i] *= s;
847 }
848 Order1 { v: self.v * s, g }
849 }
850 fn compose_unary(&self, d: [f64; 5]) -> Self {
851 // Faà di Bruno truncated to order ≤ 1 (matches `faa_di_bruno` /
852 // `Tower2::compose_unary` for the value and gradient channels):
853 // value channel (m=0): d[0].
854 // grad channel (positions=[i], single partition {{0}}): d[1]·g[i].
855 // Order-≤1 reads only d[0], d[1]; trailing stack entries are unused.
856 let mut g = [0.0; K];
857 for i in 0..K {
858 g[i] = d[1] * self.g[i];
859 }
860 Order1 { v: d[0], g }
861 }
862}
863
864// ── OneSeed<K>: one-seed directional, contracted third (doc §A.2) ───────
865
866/// One-seed directional scalar: an [`Order2`] base plus ONE nilpotent ε
867/// (`ε² = 0`) whose coefficient is itself an [`Order2`].
868///
869/// A scalar is `s = base + ε·eps`. Arithmetic is the `ε² = 0` truncation of the
870/// product (doc §A.2): the base parts multiply as ordinary [`Order2`] products,
871/// and the ε-coefficient picks up `a.base·b.eps + a.eps·b.base`. Composition
872/// pushes ε through one extra outer derivative.
873///
874/// Seed each primary with [`seed_direction`](Self::seed_direction): the base is
875/// the usual seeded variable (carrying `e_a` for the Hessian channel) and the
876/// ε-coefficient is the FIXED contraction direction `u_a` (a constant). Then the
877/// ε-component of the evaluated Hessian channel is the contracted third
878/// `[eps.h][a][b] = Σ_c ℓ_{abc} u_c` — exactly `row_third_contracted(dir = u)`,
879/// without materialising `t3`.
880#[derive(Clone, Copy, Debug)]
881pub struct OneSeed<const K: usize> {
882 /// The `ε⁰` part: value / gradient / Hessian of `ℓ`.
883 pub base: Order2<K>,
884 /// The `ε¹` part: value / gradient / Hessian of the ε-coefficient. After a
885 /// `seed_direction(u)` evaluation, `eps.h[a][b] = Σ_c ℓ_{abc} u_c`.
886 pub eps: Order2<K>,
887}
888
889impl<const K: usize> OneSeed<K> {
890 /// Seed primary `axis` at value `x` with ε-direction component `u_axis`:
891 /// `p_axis = p_axis⁰ + x-seed + ε·u_axis`, i.e. base = `variable(x, axis)`
892 /// and eps = `constant(u_axis)` (doc §A.2 "Seeding").
893 pub fn seed_direction(x: f64, axis: usize, u_axis: f64) -> Self {
894 OneSeed {
895 base: Order2::variable(x, axis),
896 eps: Order2::constant(u_axis),
897 }
898 }
899
900 /// The contracted-third channel after a `seed_direction(u)` evaluation:
901 /// `out[a][b] = Σ_c ℓ_{abc} u_c`, i.e. the ε-coefficient's Hessian (doc §A.2).
902 pub fn contracted_third(&self) -> [[f64; K]; K] {
903 self.eps.h()
904 }
905}
906
907impl<const K: usize> JetScalar<K> for OneSeed<K> {
908 fn constant(c: f64) -> Self {
909 OneSeed {
910 base: Order2::constant(c),
911 eps: Order2::constant(0.0),
912 }
913 }
914 fn variable(x: f64, axis: usize) -> Self {
915 // No ε-direction unless seeded via `seed_direction`.
916 OneSeed {
917 base: Order2::variable(x, axis),
918 eps: Order2::constant(0.0),
919 }
920 }
921 fn value(&self) -> f64 {
922 self.base.value()
923 }
924 fn add(&self, o: &Self) -> Self {
925 OneSeed {
926 base: self.base.add(&o.base),
927 eps: self.eps.add(&o.eps),
928 }
929 }
930 fn sub(&self, o: &Self) -> Self {
931 OneSeed {
932 base: self.base.sub(&o.base),
933 eps: self.eps.sub(&o.eps),
934 }
935 }
936 fn mul(&self, o: &Self) -> Self {
937 // (a.base + ε a.eps)(b.base + ε b.eps), dropping ε².
938 OneSeed {
939 base: self.base.mul(&o.base),
940 eps: self.base.mul(&o.eps).add(&self.eps.mul(&o.base)),
941 }
942 }
943 fn neg(&self) -> Self {
944 OneSeed {
945 base: self.base.neg(),
946 eps: self.eps.neg(),
947 }
948 }
949 fn scale(&self, s: f64) -> Self {
950 OneSeed {
951 base: self.base.scale(s),
952 eps: self.eps.scale(s),
953 }
954 }
955 fn compose_unary(&self, d: [f64; 5]) -> Self {
956 // f(base + ε eps) = f(base) + ε · f'(base)·eps (ε² = 0). Each factor is
957 // an Order2 composition: the base composes with the f-stack, and the
958 // ε-coefficient is the Order2 of the SHIFTED stack (the chain rule
959 // `f'(base)` as an Order2) times eps. Order2 reads only the leading
960 // three entries of whatever stack it is handed, so the trailing slots
961 // are unused padding (the fixed-length array makes the windowing total).
962 let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
963 // f'(base) as an Order2 (consumes [f', f'', f''']).
964 let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]);
965 let eps = fprime.mul(&self.eps);
966 OneSeed { base, eps }
967 }
968}
969
970// ── OneSeedLane<L, K>: lane-batched one-seed directional (doc §A.2) ──────
971
972/// Lane-batched [`OneSeed`]: the same one-seed directional scalar with its two
973/// [`Order2`] parts re-typed to [`Order2Lane<L, K>`], so one `L = f64x4`
974/// instance carries FOUR rows' contracted-third evaluations per vector pass.
975///
976/// Every operation (`add`/`sub`/`mul`/`neg`/`scale`/`compose_unary` and the
977/// transcendentals) is a term-for-term structural re-type of the scalar
978/// [`OneSeed`] ops onto the lane-implemented [`Order2Lane`] algebra. With
979/// `L = f64`, `OneSeedLane<f64, K>` is `to_bits`-identical to [`OneSeed<K>`];
980/// with `L = f64x4`, lane `i` is `to_bits`-identical to that (see `batch_tests`).
981#[derive(Clone, Copy, Debug)]
982pub struct OneSeedLane<L: Lane, const K: usize> {
983 /// The `ε⁰` part (lane-batched value / gradient / Hessian of `ℓ`).
984 pub base: Order2Lane<L, K>,
985 /// The `ε¹` part. After a `seed_direction(u)` evaluation,
986 /// `eps.h[a][b]` lane `i` is row `i`'s `Σ_c ℓ_{abc} u_c`.
987 pub eps: Order2Lane<L, K>,
988}
989
990/// The 4-rows-per-pass batched one-seed scalar (`wide::f64x4` lanes).
991pub type OneSeedBatch<const K: usize> = OneSeedLane<wide::f64x4, K>;
992
// Inherent (non-trait) algebra for `OneSeedLane`. Each method re-types the
// same-named scalar `OneSeed` operation onto `Order2Lane` and keeps the float
// operations in the same sequence per lane. The `to_bits` identity with the
// scalar path depends on the two staying in lock-step, so edit them together.
impl<L: Lane, const K: usize> OneSeedLane<L, K> {
    /// A constant: base = `constant(c)`, ε-part zero (mirrors [`OneSeed::constant`]).
    #[inline]
    pub fn constant(c: L) -> Self {
        OneSeedLane {
            base: Order2Lane::constant(c),
            eps: Order2Lane::constant(L::splat(0.0)),
        }
    }

    /// The seeded variable `p_axis` at (per-lane) value `value`, no ε-direction
    /// (mirrors [`OneSeed::variable`]).
    #[inline]
    pub fn variable(value: L, axis: usize) -> Self {
        OneSeedLane {
            base: Order2Lane::variable(value, axis),
            eps: Order2Lane::constant(L::splat(0.0)),
        }
    }

    /// Seed primary `axis` at (per-lane) value `value` with ε-direction
    /// component `u_axis`: base = `variable(value, axis)`, eps = `constant(u_axis)`
    /// (mirrors [`OneSeed::seed_direction`]). With `L = f64x4`, `value` / `u_axis`
    /// pack the four rows' values / directions of primary `axis`.
    #[inline]
    pub fn seed_direction(value: L, axis: usize, u_axis: L) -> Self {
        OneSeedLane {
            base: Order2Lane::variable(value, axis),
            eps: Order2Lane::constant(u_axis),
        }
    }

    /// The contracted-third channel after a `seed_direction(u)` evaluation:
    /// `out[a][b]` lane `i` is row `i`'s `Σ_c ℓ_{abc} u_c` (the ε-part Hessian).
    #[inline]
    #[must_use]
    pub fn contracted_third(&self) -> [[L; K]; K] {
        // Returns a copy of the ε-part Hessian field directly.
        self.eps.h
    }

    /// Lane-wise `self + o` (mirrors [`OneSeed::add`]).
    #[inline]
    pub fn add(&self, o: &Self) -> Self {
        OneSeedLane {
            base: self.base.add(&o.base),
            eps: self.eps.add(&o.eps),
        }
    }

    /// Lane-wise `self - o` (mirrors [`OneSeed::sub`]).
    #[inline]
    pub fn sub(&self, o: &Self) -> Self {
        OneSeedLane {
            base: self.base.sub(&o.base),
            eps: self.eps.sub(&o.eps),
        }
    }

    /// Lane-wise `self · o`, ε² = 0 truncation (mirrors [`OneSeed::mul`]).
    #[inline]
    pub fn mul(&self, o: &Self) -> Self {
        OneSeedLane {
            base: self.base.mul(&o.base),
            // ε¹ coefficient a₀·b₁ + a₁·b₀, summed in the same order as OneSeed::mul.
            eps: self.base.mul(&o.eps).add(&self.eps.mul(&o.base)),
        }
    }

    /// Negate every part (mirrors [`OneSeed::neg`]).
    #[inline]
    pub fn neg(&self) -> Self {
        OneSeedLane {
            base: self.base.neg(),
            eps: self.eps.neg(),
        }
    }

    /// Multiply every part by the plain scalar `s` (mirrors [`OneSeed::scale`]).
    #[inline]
    pub fn scale(&self, s: f64) -> Self {
        OneSeedLane {
            base: self.base.scale(s),
            eps: self.eps.scale(s),
        }
    }

    /// Exact order-≤2-per-part Faà di Bruno composition `f ∘ self`, given the
    /// per-lane outer-derivative stack `d = [f, f′, f″, f‴, f⁗]`. Term-for-term
    /// identical to [`OneSeed::compose_unary`]: the base reads `d[0..=2]` and the
    /// ε-coefficient is `f′(base)` (reads `d[1..=3]`) times `eps`.
    #[inline]
    pub fn compose_unary(&self, d: [L; 5]) -> Self {
        // `Order2Lane::compose_unary` takes a 3-entry stack here: `[f, f′, f″]`
        // for the base, and the shifted window `[f′, f″, f‴]` for `f′(base)`.
        // `d[4]` is not read at one-seed order.
        let base = self.base.compose_unary([d[0], d[1], d[2]]);
        let fprime = self.base.compose_unary([d[1], d[2], d[3]]);
        // ε² = 0  ⇒  f(base + ε·eps) = f(base) + ε·f′(base)·eps.
        let eps = fprime.mul(&self.eps);
        OneSeedLane { base, eps }
    }

    /// `e^self`, per-lane stack `[e, e, e, e, e]` (matches [`JetScalar::exp`]).
    #[inline]
    pub fn exp(&self) -> Self {
        // NOTE(review): `unary5` is assumed to apply the closure to each lane of
        // the base value and pack the five results into `[L; 5]` — confirm in
        // the `Lane` trait. Every derivative of e^u is e^u.
        let d = self.base.v.unary5(|u| {
            let e = u.exp();
            [e, e, e, e, e]
        });
        self.compose_unary(d)
    }

    /// `ln(self)`; caller guarantees positivity (matches [`JetScalar::ln`]).
    #[inline]
    pub fn ln(&self) -> Self {
        // dⁿ/duⁿ ln u = (−1)ⁿ⁻¹ (n−1)! / uⁿ for n ≥ 1, written via r = 1/u.
        let d = self.base.v.unary5(|u| {
            let r = 1.0 / u;
            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
        });
        self.compose_unary(d)
    }

    /// `√self`; caller guarantees positivity (matches [`JetScalar::sqrt`]).
    #[inline]
    pub fn sqrt(&self) -> Self {
        // dⁿ/duⁿ √u = cₙ · u^{1/2 − n} with c = [1, 1/2, −1/4, 3/8, −15/16],
        // each power expressed through s = √u.
        let d = self.base.v.unary5(|u| {
            let s = u.sqrt();
            [
                s,
                0.5 / s,
                -0.25 / (u * s),
                0.375 / (u * u * s),
                -0.9375 / (u * u * u * s),
            ]
        });
        self.compose_unary(d)
    }

    /// `1/self` (matches [`JetScalar::recip`]).
    #[inline]
    pub fn recip(&self) -> Self {
        // dⁿ/duⁿ u⁻¹ = (−1)ⁿ n! / uⁿ⁺¹, built from r = 1/u and r2 = r².
        let d = self.base.v.unary5(|u| {
            let r = 1.0 / u;
            let r2 = r * r;
            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
        });
        self.compose_unary(d)
    }

    /// `self^a` for real `a`; caller guarantees a positive base (matches
    /// [`JetScalar::powf`]).
    #[inline]
    pub fn powf(&self, a: f64) -> Self {
        // dⁿ/duⁿ uᵃ = a(a−1)…(a−n+1) · u^{a−n}.
        let d = self.base.v.unary5(|u| {
            [
                u.powf(a),
                a * u.powf(a - 1.0),
                a * (a - 1.0) * u.powf(a - 2.0),
                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
            ]
        });
        self.compose_unary(d)
    }

    /// `ln Γ(self)`; caller guarantees positivity (matches [`JetScalar::ln_gamma`],
    /// same hand-certified stack).
    #[inline]
    pub fn ln_gamma(&self) -> Self {
        // The five-entry stack comes from `crate::jet_tower` (not shown here).
        let d = self
            .base
            .v
            .unary5(crate::jet_tower::ln_gamma_derivative_stack);
        self.compose_unary(d)
    }

    /// `ψ(self)` digamma; caller guarantees positivity (matches
    /// [`JetScalar::digamma`], same hand-certified stack).
    #[inline]
    pub fn digamma(&self) -> Self {
        // The five-entry stack comes from `crate::jet_tower` (not shown here).
        let d = self
            .base
            .v
            .unary5(crate::jet_tower::digamma_derivative_stack);
        self.compose_unary(d)
    }
}
1175
1176impl<const K: usize> OneSeedBatch<K> {
1177 /// Extract lane `i`'s parts as a production [`OneSeed<K>`]. Lane `i` is
1178 /// `to_bits`-identical to evaluating the same program at [`OneSeed<K>`] on
1179 /// row `i` (see `batch_tests`).
1180 #[inline]
1181 #[must_use]
1182 pub fn lane(&self, i: usize) -> OneSeed<K> {
1183 OneSeed {
1184 base: self.base.lane(i),
1185 eps: self.eps.lane(i),
1186 }
1187 }
1188}
1189
1190// ── TwoSeed<K>: two-seed, contracted fourth (doc §A.3) ──────────────────
1191
1192/// Two-seed scalar: an [`Order2`] base plus TWO nilpotents ε, δ
1193/// (`ε² = δ² = 0`, `εδ` retained) — four [`Order2`] parts
1194/// `s = base + ε·eps + δ·del + εδ·eps_del`.
1195///
1196/// Product truncates `ε² = δ² = 0` (doc §A.3): each part is built from
1197/// [`Order2`] products of the four input parts. Composition picks up
1198/// successively higher outer derivatives, the cross part carrying the second
1199/// Faà di Bruno term `f''·eps·del + f'·eps_del`.
1200///
1201/// Seed each primary with [`seed`](Self::seed): base = `variable(x, axis)`,
1202/// eps = `constant(u_axis)`, del = `constant(v_axis)`, eps_del = `constant(0)`.
1203/// Then the εδ-component of the evaluated Hessian channel is the contracted
1204/// fourth `[eps_del.h][a][b] = Σ_{cd} ℓ_{abcd} u_c v_d` — exactly
1205/// `row_fourth_contracted(u, v)`, without materialising `t4`.
1206#[derive(Clone, Copy, Debug)]
1207pub struct TwoSeed<const K: usize> {
1208 /// The `ε⁰δ⁰` part: value / grad / Hessian of `ℓ`.
1209 pub base: Order2<K>,
1210 /// The `ε¹δ⁰` part.
1211 pub eps: Order2<K>,
1212 /// The `ε⁰δ¹` part.
1213 pub del: Order2<K>,
1214 /// The `ε¹δ¹` part. After a `seed(u, v)` evaluation,
1215 /// `eps_del.h[a][b] = Σ_{cd} ℓ_{abcd} u_c v_d`.
1216 pub eps_del: Order2<K>,
1217}
1218
1219impl<const K: usize> TwoSeed<K> {
1220 /// Seed primary `axis` at value `x` with ε-direction `u_axis` and
1221 /// δ-direction `v_axis`:
1222 /// `p_axis = p_axis⁰ + x-seed + ε·u_axis + δ·v_axis` (doc §A.3 "Seeding").
1223 pub fn seed(x: f64, axis: usize, u_axis: f64, v_axis: f64) -> Self {
1224 TwoSeed {
1225 base: Order2::variable(x, axis),
1226 eps: Order2::constant(u_axis),
1227 del: Order2::constant(v_axis),
1228 eps_del: Order2::constant(0.0),
1229 }
1230 }
1231
1232 /// The contracted-fourth channel after a `seed(u, v)` evaluation:
1233 /// `out[a][b] = Σ_{cd} ℓ_{abcd} u_c v_d`, i.e. the εδ-coefficient's Hessian.
1234 pub fn contracted_fourth(&self) -> [[f64; K]; K] {
1235 self.eps_del.h()
1236 }
1237}
1238
1239impl<const K: usize> JetScalar<K> for TwoSeed<K> {
1240 fn constant(c: f64) -> Self {
1241 TwoSeed {
1242 base: Order2::constant(c),
1243 eps: Order2::constant(0.0),
1244 del: Order2::constant(0.0),
1245 eps_del: Order2::constant(0.0),
1246 }
1247 }
1248 fn variable(x: f64, axis: usize) -> Self {
1249 TwoSeed {
1250 base: Order2::variable(x, axis),
1251 eps: Order2::constant(0.0),
1252 del: Order2::constant(0.0),
1253 eps_del: Order2::constant(0.0),
1254 }
1255 }
1256 fn value(&self) -> f64 {
1257 self.base.value()
1258 }
1259 fn add(&self, o: &Self) -> Self {
1260 TwoSeed {
1261 base: self.base.add(&o.base),
1262 eps: self.eps.add(&o.eps),
1263 del: self.del.add(&o.del),
1264 eps_del: self.eps_del.add(&o.eps_del),
1265 }
1266 }
1267 fn sub(&self, o: &Self) -> Self {
1268 TwoSeed {
1269 base: self.base.sub(&o.base),
1270 eps: self.eps.sub(&o.eps),
1271 del: self.del.sub(&o.del),
1272 eps_del: self.eps_del.sub(&o.eps_del),
1273 }
1274 }
1275 fn mul(&self, o: &Self) -> Self {
1276 let a = self;
1277 let b = o;
1278 // Truncate ε² = δ² = 0 (doc §A.3 product table).
1279 let base = a.base.mul(&b.base);
1280 let eps = a.base.mul(&b.eps).add(&a.eps.mul(&b.base));
1281 let del = a.base.mul(&b.del).add(&a.del.mul(&b.base));
1282 let eps_del = a
1283 .base
1284 .mul(&b.eps_del)
1285 .add(&a.eps.mul(&b.del))
1286 .add(&a.del.mul(&b.eps))
1287 .add(&a.eps_del.mul(&b.base));
1288 TwoSeed {
1289 base,
1290 eps,
1291 del,
1292 eps_del,
1293 }
1294 }
1295 fn neg(&self) -> Self {
1296 TwoSeed {
1297 base: self.base.neg(),
1298 eps: self.eps.neg(),
1299 del: self.del.neg(),
1300 eps_del: self.eps_del.neg(),
1301 }
1302 }
1303 fn scale(&self, s: f64) -> Self {
1304 TwoSeed {
1305 base: self.base.scale(s),
1306 eps: self.eps.scale(s),
1307 del: self.del.scale(s),
1308 eps_del: self.eps_del.scale(s),
1309 }
1310 }
1311 fn compose_unary(&self, d: [f64; 5]) -> Self {
1312 // f(s) with s = base + ε eps + δ del + εδ eps_del, ε²=δ²=0:
1313 // f(s) = f(base)
1314 // + ε · f'(base)·eps
1315 // + δ · f'(base)·del
1316 // + εδ · ( f''(base)·eps·del + f'(base)·eps_del ).
1317 // Each f^{(r)}(base) is the Order2 composition of base with the stack
1318 // shifted r entries (doc §A.3 composition). Order2 reads only the
1319 // leading three entries of whatever stack it is handed, so the trailing
1320 // padding slots are unused (the fixed-length array makes this total).
1321 let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
1322 let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]); // f'(base) as Order2
1323 let fsecond = self.base.compose_unary([d[2], d[3], d[4], d[4], d[4]]); // f''(base) as Order2
1324 let eps = fprime.mul(&self.eps);
1325 let del = fprime.mul(&self.del);
1326 let eps_del = fsecond
1327 .mul(&self.eps)
1328 .mul(&self.del)
1329 .add(&fprime.mul(&self.eps_del));
1330 TwoSeed {
1331 base,
1332 eps,
1333 del,
1334 eps_del,
1335 }
1336 }
1337}
1338
1339// ── TwoSeedLane<L, K>: lane-batched two-seed, contracted fourth (doc §A.3) ─
1340
1341/// Lane-batched [`TwoSeed`]: the same two-seed scalar with its four [`Order2`]
1342/// parts re-typed to [`Order2Lane<L, K>`], so one `L = f64x4` instance carries
1343/// FOUR rows' contracted-fourth evaluations per vector pass.
1344///
1345/// Every operation is a term-for-term structural re-type of the scalar
1346/// [`TwoSeed`] ops onto the lane-implemented [`Order2Lane`] algebra. With
1347/// `L = f64`, `TwoSeedLane<f64, K>` is `to_bits`-identical to [`TwoSeed<K>`];
1348/// with `L = f64x4`, lane `i` is `to_bits`-identical to that (see `batch_tests`).
1349#[derive(Clone, Copy, Debug)]
1350pub struct TwoSeedLane<L: Lane, const K: usize> {
1351 /// The `ε⁰δ⁰` part.
1352 pub base: Order2Lane<L, K>,
1353 /// The `ε¹δ⁰` part.
1354 pub eps: Order2Lane<L, K>,
1355 /// The `ε⁰δ¹` part.
1356 pub del: Order2Lane<L, K>,
1357 /// The `ε¹δ¹` part. After a `seed(u, v)` evaluation, `eps_del.h[a][b]`
1358 /// lane `i` is row `i`'s `Σ_{cd} ℓ_{abcd} u_c v_d`.
1359 pub eps_del: Order2Lane<L, K>,
1360}
1361
1362/// The 4-rows-per-pass batched two-seed scalar (`wide::f64x4` lanes).
1363pub type TwoSeedBatch<const K: usize> = TwoSeedLane<wide::f64x4, K>;
1364
// Inherent (non-trait) algebra for `TwoSeedLane`. Each method re-types the
// same-named scalar `TwoSeed` operation onto `Order2Lane` and keeps the float
// operations in the same sequence per lane. The `to_bits` identity with the
// scalar path depends on the two staying in lock-step, so edit them together.
impl<L: Lane, const K: usize> TwoSeedLane<L, K> {
    /// A constant: base = `constant(c)`, all seed parts zero (mirrors
    /// [`TwoSeed::constant`]).
    #[inline]
    pub fn constant(c: L) -> Self {
        let z = Order2Lane::constant(L::splat(0.0));
        TwoSeedLane {
            base: Order2Lane::constant(c),
            eps: z,
            del: z,
            eps_del: z,
        }
    }

    /// The seeded variable `p_axis` at (per-lane) value `value`, no ε/δ direction
    /// (mirrors [`TwoSeed::variable`]).
    #[inline]
    pub fn variable(value: L, axis: usize) -> Self {
        let z = Order2Lane::constant(L::splat(0.0));
        TwoSeedLane {
            base: Order2Lane::variable(value, axis),
            eps: z,
            del: z,
            eps_del: z,
        }
    }

    /// Seed primary `axis` at (per-lane) value `value` with ε-direction `u_axis`
    /// and δ-direction `v_axis` (mirrors [`TwoSeed::seed`]). With `L = f64x4`,
    /// each argument packs the four rows' values for primary `axis`.
    #[inline]
    pub fn seed(value: L, axis: usize, u_axis: L, v_axis: L) -> Self {
        TwoSeedLane {
            base: Order2Lane::variable(value, axis),
            eps: Order2Lane::constant(u_axis),
            del: Order2Lane::constant(v_axis),
            eps_del: Order2Lane::constant(L::splat(0.0)),
        }
    }

    /// The contracted-fourth channel after a `seed(u, v)` evaluation:
    /// `out[a][b]` lane `i` is row `i`'s `Σ_{cd} ℓ_{abcd} u_c v_d`
    /// (the εδ-part Hessian).
    #[inline]
    #[must_use]
    pub fn contracted_fourth(&self) -> [[L; K]; K] {
        // Returns a copy of the εδ-part Hessian field directly.
        self.eps_del.h
    }

    /// Lane-wise `self + o` (mirrors [`TwoSeed::add`]).
    #[inline]
    pub fn add(&self, o: &Self) -> Self {
        TwoSeedLane {
            base: self.base.add(&o.base),
            eps: self.eps.add(&o.eps),
            del: self.del.add(&o.del),
            eps_del: self.eps_del.add(&o.eps_del),
        }
    }

    /// Lane-wise `self - o` (mirrors [`TwoSeed::sub`]).
    #[inline]
    pub fn sub(&self, o: &Self) -> Self {
        TwoSeedLane {
            base: self.base.sub(&o.base),
            eps: self.eps.sub(&o.eps),
            del: self.del.sub(&o.del),
            eps_del: self.eps_del.sub(&o.eps_del),
        }
    }

    /// Lane-wise `self · o`, ε² = δ² = 0 truncation (mirrors [`TwoSeed::mul`]).
    #[inline]
    pub fn mul(&self, o: &Self) -> Self {
        let a = self;
        let b = o;
        let base = a.base.mul(&b.base);
        // ε and δ coefficients: a₀·b₁ + a₁·b₀ in each direction.
        let eps = a.base.mul(&b.eps).add(&a.eps.mul(&b.base));
        let del = a.base.mul(&b.del).add(&a.del.mul(&b.base));
        // εδ coefficient: the four cross products, summed left to right in the
        // same order as the scalar `TwoSeed::mul`.
        let eps_del = a
            .base
            .mul(&b.eps_del)
            .add(&a.eps.mul(&b.del))
            .add(&a.del.mul(&b.eps))
            .add(&a.eps_del.mul(&b.base));
        TwoSeedLane {
            base,
            eps,
            del,
            eps_del,
        }
    }

    /// Negate every part (mirrors [`TwoSeed::neg`]).
    #[inline]
    pub fn neg(&self) -> Self {
        TwoSeedLane {
            base: self.base.neg(),
            eps: self.eps.neg(),
            del: self.del.neg(),
            eps_del: self.eps_del.neg(),
        }
    }

    /// Multiply every part by the plain scalar `s` (mirrors [`TwoSeed::scale`]).
    #[inline]
    pub fn scale(&self, s: f64) -> Self {
        TwoSeedLane {
            base: self.base.scale(s),
            eps: self.eps.scale(s),
            del: self.del.scale(s),
            eps_del: self.eps_del.scale(s),
        }
    }

    /// Exact composition `f ∘ self`, given the per-lane outer-derivative stack
    /// `d = [f, f′, f″, f‴, f⁗]`. Term-for-term identical to
    /// [`TwoSeed::compose_unary`]: base reads `d[0..=2]`, `f′(base)` reads
    /// `d[1..=3]`, `f″(base)` reads `d[2..=4]`, and the cross part carries
    /// `f″·eps·del + f′·eps_del`.
    #[inline]
    pub fn compose_unary(&self, d: [L; 5]) -> Self {
        // `Order2Lane::compose_unary` takes a 3-entry stack here; each call
        // reads a window of `d` shifted by the derivative order it represents.
        let base = self.base.compose_unary([d[0], d[1], d[2]]);
        let fprime = self.base.compose_unary([d[1], d[2], d[3]]);
        let fsecond = self.base.compose_unary([d[2], d[3], d[4]]);
        let eps = fprime.mul(&self.eps);
        let del = fprime.mul(&self.del);
        // (f″·eps)·del first, then + f′·eps_del: same association as the scalar path.
        let eps_del = fsecond
            .mul(&self.eps)
            .mul(&self.del)
            .add(&fprime.mul(&self.eps_del));
        TwoSeedLane {
            base,
            eps,
            del,
            eps_del,
        }
    }

    /// `e^self`, per-lane stack `[e; 5]` (matches [`JetScalar::exp`]).
    #[inline]
    pub fn exp(&self) -> Self {
        // NOTE(review): `unary5` is assumed to apply the closure to each lane of
        // the base value and pack the five results into `[L; 5]` — confirm in
        // the `Lane` trait. Every derivative of e^u is e^u.
        let d = self.base.v.unary5(|u| {
            let e = u.exp();
            [e, e, e, e, e]
        });
        self.compose_unary(d)
    }

    /// `ln(self)`; caller guarantees positivity (matches [`JetScalar::ln`]).
    #[inline]
    pub fn ln(&self) -> Self {
        // dⁿ/duⁿ ln u = (−1)ⁿ⁻¹ (n−1)! / uⁿ for n ≥ 1, written via r = 1/u.
        let d = self.base.v.unary5(|u| {
            let r = 1.0 / u;
            [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r]
        });
        self.compose_unary(d)
    }

    /// `√self`; caller guarantees positivity (matches [`JetScalar::sqrt`]).
    #[inline]
    pub fn sqrt(&self) -> Self {
        // dⁿ/duⁿ √u = cₙ · u^{1/2 − n} with c = [1, 1/2, −1/4, 3/8, −15/16],
        // each power expressed through s = √u.
        let d = self.base.v.unary5(|u| {
            let s = u.sqrt();
            [
                s,
                0.5 / s,
                -0.25 / (u * s),
                0.375 / (u * u * s),
                -0.9375 / (u * u * u * s),
            ]
        });
        self.compose_unary(d)
    }

    /// `1/self` (matches [`JetScalar::recip`]).
    #[inline]
    pub fn recip(&self) -> Self {
        // dⁿ/duⁿ u⁻¹ = (−1)ⁿ n! / uⁿ⁺¹, built from r = 1/u and r2 = r².
        let d = self.base.v.unary5(|u| {
            let r = 1.0 / u;
            let r2 = r * r;
            [r, -r2, 2.0 * r2 * r, -6.0 * r2 * r2, 24.0 * r2 * r2 * r]
        });
        self.compose_unary(d)
    }

    /// `self^a` for real `a`; caller guarantees a positive base (matches
    /// [`JetScalar::powf`]).
    #[inline]
    pub fn powf(&self, a: f64) -> Self {
        // dⁿ/duⁿ uᵃ = a(a−1)…(a−n+1) · u^{a−n}.
        let d = self.base.v.unary5(|u| {
            [
                u.powf(a),
                a * u.powf(a - 1.0),
                a * (a - 1.0) * u.powf(a - 2.0),
                a * (a - 1.0) * (a - 2.0) * u.powf(a - 3.0),
                a * (a - 1.0) * (a - 2.0) * (a - 3.0) * u.powf(a - 4.0),
            ]
        });
        self.compose_unary(d)
    }

    /// `ln Γ(self)`; caller guarantees positivity (matches [`JetScalar::ln_gamma`]).
    #[inline]
    pub fn ln_gamma(&self) -> Self {
        // The five-entry stack comes from `crate::jet_tower` (not shown here).
        let d = self
            .base
            .v
            .unary5(crate::jet_tower::ln_gamma_derivative_stack);
        self.compose_unary(d)
    }

    /// `ψ(self)` digamma; caller guarantees positivity (matches
    /// [`JetScalar::digamma`]).
    #[inline]
    pub fn digamma(&self) -> Self {
        // The five-entry stack comes from `crate::jet_tower` (not shown here).
        let d = self
            .base
            .v
            .unary5(crate::jet_tower::digamma_derivative_stack);
        self.compose_unary(d)
    }
}
1588
1589impl<const K: usize> TwoSeedBatch<K> {
1590 /// Extract lane `i`'s parts as a production [`TwoSeed<K>`]. Lane `i` is
1591 /// `to_bits`-identical to evaluating the same program at [`TwoSeed<K>`] on
1592 /// row `i` (see `batch_tests`).
1593 #[inline]
1594 #[must_use]
1595 pub fn lane(&self, i: usize) -> TwoSeed<K> {
1596 TwoSeed {
1597 base: self.base.lane(i),
1598 eps: self.eps.lane(i),
1599 del: self.del.lane(i),
1600 eps_del: self.eps_del.lane(i),
1601 }
1602 }
1603}
1604
1605// ── Tower3<K>: value / gradient / Hessian / third tensor ────────────────
1606
1607/// The order-≤3 [`crate::jet_tower::Tower3`] is also a [`JetScalar`]. It serves
1608/// consumers that read `.t3` but never `.t4`, avoiding the fourth-tensor
1609/// product/composition work while preserving the lower channels
1610/// bit-for-bit against [`crate::jet_tower::Tower4`].
1611impl<const K: usize> JetScalar<K> for crate::jet_tower::Tower3<K> {
1612 fn constant(c: f64) -> Self {
1613 crate::jet_tower::Tower3::constant(c)
1614 }
1615 fn variable(x: f64, axis: usize) -> Self {
1616 crate::jet_tower::Tower3::variable(x, axis)
1617 }
1618 fn value(&self) -> f64 {
1619 self.v
1620 }
1621 fn add(&self, o: &Self) -> Self {
1622 *self + *o
1623 }
1624 fn sub(&self, o: &Self) -> Self {
1625 *self + o.scale(-1.0)
1626 }
1627 fn mul(&self, o: &Self) -> Self {
1628 crate::jet_tower::Tower3::mul(self, o)
1629 }
1630 fn neg(&self) -> Self {
1631 self.scale(-1.0)
1632 }
1633 fn scale(&self, s: f64) -> Self {
1634 crate::jet_tower::Tower3::scale(self, s)
1635 }
1636 fn compose_unary(&self, d: [f64; 5]) -> Self {
1637 crate::jet_tower::Tower3::compose_unary(self, [d[0], d[1], d[2], d[3]])
1638 }
1639}
1640
1641// ── Tower4<K>: full dense tower as a JetScalar (the all-channels scalar) ─
1642
1643/// The full dense [`crate::jet_tower::Tower4`] is itself a [`JetScalar`]: it
1644/// carries EVERY channel, so a row expression written ONCE against [`JetScalar`]
1645/// can be evaluated at `Tower4` to obtain the full `(v, g, H, t3, t4)` in one
1646/// pass. This is BOTH the #932 oracle ground truth the packed [`Order2`] /
1647/// [`OneSeed`] / [`TwoSeed`] scalars are pinned against, AND a production scalar:
1648/// a family whose uncontracted third / fourth derivative tensors are needed
1649/// (the BMS rigid `third_full` / `fourth_full` caches) evaluates the SAME
1650/// generic row-NLL expression at `Tower4` and reads `.t3` / `.t4` off the
1651/// result — so the dense tensors come from the single source of truth, not a
1652/// separately hand-written jet. The packed scalars serve the consumers that
1653/// need only `(v, g, H)` (`Order2`) or one / two contractions
1654/// (`OneSeed` / `TwoSeed`) without paying for the dense tensors.
1655impl<const K: usize> JetScalar<K> for crate::jet_tower::Tower4<K> {
1656 fn constant(c: f64) -> Self {
1657 crate::jet_tower::Tower4::constant(c)
1658 }
1659 fn variable(x: f64, axis: usize) -> Self {
1660 crate::jet_tower::Tower4::variable(x, axis)
1661 }
1662 fn value(&self) -> f64 {
1663 self.v
1664 }
1665 fn add(&self, o: &Self) -> Self {
1666 *self + *o
1667 }
1668 fn sub(&self, o: &Self) -> Self {
1669 *self - *o
1670 }
1671 fn mul(&self, o: &Self) -> Self {
1672 crate::jet_tower::Tower4::mul(self, o)
1673 }
1674 fn neg(&self) -> Self {
1675 self.scale(-1.0)
1676 }
1677 fn scale(&self, s: f64) -> Self {
1678 crate::jet_tower::Tower4::scale(self, s)
1679 }
1680 fn compose_unary(&self, d: [f64; 5]) -> Self {
1681 crate::jet_tower::Tower4::compose_unary(self, d)
1682 }
1683}
1684
1685#[cfg(test)]
1686mod tests {
1687 use super::*;
1688 use crate::jet_tower::{RowNllProgram, Tower4, evaluate_program};
1689
1690 /// A small polynomial-plus-unary row expression written ONCE, generically
1691 /// over `S: JetScalar<2>`, so it can be evaluated against every scalar:
1692 /// `ℓ = (e^{p0·p1} + 2) · √(p0·p0 + 1) − p1·p1·0.5`.
1693 /// Exercises mul, add/sub, scale, exp, sqrt — every algebra op.
1694 fn row_expr<S: JetScalar<2>>(p: &[S; 2]) -> S {
1695 let g = p[0].mul(&p[1]).exp();
1696 let inner = g.add(&S::constant(2.0));
1697 let radic = p[0].mul(&p[0]).add(&S::constant(1.0)).sqrt();
1698 inner.mul(&radic).sub(&p[1].mul(&p[1]).scale(0.5))
1699 }
1700
1701 /// The same expression as a Tower4 `RowNllProgram`, the ground-truth tower.
1702 struct ExprProgram {
1703 p: [f64; 2],
1704 }
1705 impl RowNllProgram<2> for ExprProgram {
1706 fn n_rows(&self) -> usize {
1707 1
1708 }
1709 fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
1710 if row >= self.n_rows() {
1711 return Err(format!("ExprProgram: row {row} out of range"));
1712 }
1713 Ok(self.p)
1714 }
1715 fn row_nll(&self, row: usize, p: &[Tower4<2>; 2]) -> Result<Tower4<2>, String> {
1716 if row >= self.n_rows() {
1717 return Err(format!("ExprProgram: row {row} out of range"));
1718 }
1719 Ok(row_expr(p))
1720 }
1721 }
1722
1723 const SEED: [f64; 2] = [0.37, -0.81];
1724 const U: [f64; 2] = [0.6, -0.2];
1725 const V: [f64; 2] = [-0.4, 1.1];
1726 const TOL: f64 = 1e-10;
1727
1728 fn close(a: f64, b: f64, label: &str) {
1729 let band = TOL + TOL * a.abs().max(b.abs());
1730 assert!(
1731 (a - b).abs() <= band,
1732 "{label}: {a:+.15e} vs {b:+.15e} (band {band:.3e})"
1733 );
1734 }
1735
1736 fn tower() -> Tower4<2> {
1737 evaluate_program(&ExprProgram { p: SEED }, 0).expect("tower")
1738 }
1739
1740 /// Order2 reproduces Tower4's value/grad/Hessian channels exactly.
1741 #[test]
1742 fn order2_matches_tower_value_grad_hessian() {
1743 let t = tower();
1744 let vars: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
1745 let s = row_expr(&vars);
1746 close(s.value(), t.v, "value");
1747 for a in 0..2 {
1748 close(s.0.g[a], t.g[a], &format!("grad[{a}]"));
1749 for b in 0..2 {
1750 close(s.h()[a][b], t.h[a][b], &format!("hess[{a}][{b}]"));
1751 }
1752 }
1753 }
1754
1755 /// The `compose_unary_with` seam on a scalar jet is `to_bits`-identical to
1756 /// the explicit `compose_unary(stack_fn(value))` — the contract the batch
1757 /// arm (`Tower{3,4}Lane::compose_unary_with`) lane-matches. Exercised on
1758 /// [`Order2`] across `K ∈ {2,3,4,9}`, ≥ 4000 random seeded inputs.
1759 #[test]
1760 fn compose_unary_with_scalar_seam_bit_identical() {
1761 fn rand_unit(state: &mut u64) -> f64 {
1762 let mut x = *state;
1763 x ^= x << 13;
1764 x ^= x >> 7;
1765 x ^= x << 17;
1766 *state = x;
1767 2.0 * ((x >> 11) as f64 / ((1u64 << 53) as f64)) - 1.0
1768 }
1769 // A base-value-dependent finite stack standing in for a family stack.
1770 fn stack(u: f64) -> [f64; 5] {
1771 [u.sin(), u.cos(), (2.0 * u).sin(), (0.5 * u).cos(), u * u - 0.3]
1772 }
1773 fn run<const K: usize>(state: &mut u64, n: usize) -> usize {
1774 for _ in 0..n {
1775 // A non-trivial Order2<K> jet: a seeded variable pushed through a
1776 // couple of algebra ops so g/h are dense, then exercise the seam.
1777 let base = rand_unit(state);
1778 let mut s = Order2::<K>::variable(base, 0);
1779 for a in 1..K {
1780 s = JetScalar::mul(&s, &Order2::<K>::variable(rand_unit(state), a));
1781 }
1782 let with = s.compose_unary_with(stack);
1783 let explicit = s.compose_unary(stack(s.value()));
1784 assert_eq!(with.value().to_bits(), explicit.value().to_bits(), "value");
1785 for a in 0..K {
1786 assert_eq!(with.g()[a].to_bits(), explicit.g()[a].to_bits(), "g[{a}]");
1787 for b in 0..K {
1788 assert_eq!(
1789 with.h()[a][b].to_bits(),
1790 explicit.h()[a][b].to_bits(),
1791 "h[{a}][{b}]"
1792 );
1793 }
1794 }
1795 }
1796 n
1797 }
1798 let mut st = 0x9e37_79b9_7f4a_7c15u64;
1799 let total =
1800 run::<2>(&mut st, 1100) + run::<3>(&mut st, 1100) + run::<4>(&mut st, 1100) + run::<9>(&mut st, 1100);
1801 assert_eq!(total, 4400);
1802 }
1803
1804 /// OneSeed's ε-Hessian is the contracted third Σ_c ℓ_{abc} u_c, matching
1805 /// `Tower4::third_contracted(u)`. Base channels also match the tower.
1806 #[test]
1807 fn one_seed_matches_tower_third_contracted() {
1808 let t = tower();
1809 let truth = t.third_contracted(&U);
1810 let vars: [OneSeed<2>; 2] =
1811 std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
1812 let s = row_expr(&vars);
1813 // Base channels are the plain (v, g, H).
1814 close(s.value(), t.v, "value");
1815 for a in 0..2 {
1816 for b in 0..2 {
1817 close(s.base.h()[a][b], t.h[a][b], &format!("base hess[{a}][{b}]"));
1818 }
1819 }
1820 let third = s.contracted_third();
1821 for a in 0..2 {
1822 for b in 0..2 {
1823 close(third[a][b], truth[a][b], &format!("third[{a}][{b}]"));
1824 }
1825 }
1826 }
1827
1828 /// TwoSeed's εδ-Hessian is the contracted fourth Σ_{cd} ℓ_{abcd} u_c v_d,
1829 /// matching `Tower4::fourth_contracted(u, v)`. The ε / δ single-seed parts
1830 /// reproduce the two third contractions Σ_c ℓ_{abc} u_c and …v_d.
1831 #[test]
1832 fn two_seed_matches_tower_fourth_contracted() {
1833 let t = tower();
1834 let truth4 = t.fourth_contracted(&U, &V);
1835 let truth3_u = t.third_contracted(&U);
1836 let truth3_v = t.third_contracted(&V);
1837 let vars: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
1838 let s = row_expr(&vars);
1839 close(s.value(), t.v, "value");
1840 for a in 0..2 {
1841 close(s.base.0.g[a], t.g[a], &format!("grad[{a}]"));
1842 for b in 0..2 {
1843 close(s.base.h()[a][b], t.h[a][b], &format!("base hess[{a}][{b}]"));
1844 close(
1845 s.eps.h()[a][b],
1846 truth3_u[a][b],
1847 &format!("eps third_u[{a}][{b}]"),
1848 );
1849 close(
1850 s.del.h()[a][b],
1851 truth3_v[a][b],
1852 &format!("del third_v[{a}][{b}]"),
1853 );
1854 }
1855 }
1856 let fourth = s.contracted_fourth();
1857 for a in 0..2 {
1858 for b in 0..2 {
1859 close(fourth[a][b], truth4[a][b], &format!("fourth[{a}][{b}]"));
1860 }
1861 }
1862 }
1863
1864 /// The generic `row_nll_generic` seam (added to Tower4's program trait
1865 /// surface) evaluates the SAME expression on each scalar and extracts the
1866 /// channel a consumer asks for, agreeing with the direct Tower4 contraction.
1867 #[test]
1868 fn generic_program_seam_matches_tower_for_every_channel() {
1869 let t = tower();
1870 // Order2 via generic seam.
1871 let o2: [Order2<2>; 2] = std::array::from_fn(|a| Order2::variable(SEED[a], a));
1872 let so2 = row_expr(&o2);
1873 close(so2.value(), t.v, "seam order2 value");
1874 // OneSeed third.
1875 let os: [OneSeed<2>; 2] =
1876 std::array::from_fn(|a| OneSeed::seed_direction(SEED[a], a, U[a]));
1877 let third = row_expr(&os).contracted_third();
1878 let truth3 = t.third_contracted(&U);
1879 for a in 0..2 {
1880 for b in 0..2 {
1881 close(third[a][b], truth3[a][b], &format!("seam third[{a}][{b}]"));
1882 }
1883 }
1884 // TwoSeed fourth.
1885 let ts: [TwoSeed<2>; 2] = std::array::from_fn(|a| TwoSeed::seed(SEED[a], a, U[a], V[a]));
1886 let fourth = row_expr(&ts).contracted_fourth();
1887 let truth4 = t.fourth_contracted(&U, &V);
1888 for a in 0..2 {
1889 for b in 0..2 {
1890 close(
1891 fourth[a][b],
1892 truth4[a][b],
1893 &format!("seam fourth[{a}][{b}]"),
1894 );
1895 }
1896 }
1897 }
1898
1899 /// The (test-only) `Tower4: JetScalar` impl is the all-channels oracle scalar:
1900 /// evaluating the SAME generic `row_expr` at `S = Tower4` (through the
1901 /// `JetScalar` trait ops) must reproduce, channel-for-channel, the `Tower4`
1902 /// obtained from the `RowNllProgram` / inherent-operator path
1903 /// (`evaluate_program`). This pins that the trait impl delegates faithfully to
1904 /// the inherent `Tower4` arithmetic (so the contracted-scalar oracles above,
1905 /// which compare against `evaluate_program`'s tower, are comparing against the
1906 /// same algebra the `JetScalar` interface exposes).
1907 #[test]
1908 fn tower4_as_jetscalar_matches_program_tower_all_channels() {
1909 let t = tower();
1910 let vars: [Tower4<2>; 2] = std::array::from_fn(|a| Tower4::variable(SEED[a], a));
1911 let s = row_expr(&vars);
1912 close(s.v, t.v, "tower-jetscalar value");
1913 for a in 0..2 {
1914 close(s.g[a], t.g[a], &format!("tower-jetscalar grad[{a}]"));
1915 for b in 0..2 {
1916 close(
1917 s.h[a][b],
1918 t.h[a][b],
1919 &format!("tower-jetscalar hess[{a}][{b}]"),
1920 );
1921 for c in 0..2 {
1922 close(
1923 s.t3[a][b][c],
1924 t.t3[a][b][c],
1925 &format!("tower-jetscalar t3[{a}][{b}][{c}]"),
1926 );
1927 for d in 0..2 {
1928 close(
1929 s.t4[a][b][c][d],
1930 t.t4[a][b][c][d],
1931 &format!("tower-jetscalar t4[{a}][{b}][{c}][{d}]"),
1932 );
1933 }
1934 }
1935 }
1936 }
1937 }
1938}
1939
1940#[cfg(test)]
1941mod batch_tests {
1942 //! SIMD row-batching oracle: prove [`Order2Batch<K>`] (4 rows in
1943 //! `wide::f64x4` lanes) is `to_bits`-identical, on every value/gradient/
1944 //! Hessian channel, to the production [`Order2<K>`] evaluated per row — and
1945 //! that the new scalar field [`Order2Lane<f64, K>`] is too. Composing the two
1946 //! claims, batch lane `i` reproduces the production scalar for row `i` bit
1947 //! for bit, so the 4× throughput is a free lunch (no result change).
1948
1949 use super::{
1950 JetScalar, Lane, OneSeed, OneSeedBatch, OneSeedLane, Order2, Order2Batch, Order2Lane,
1951 TwoSeed, TwoSeedBatch, TwoSeedLane,
1952 };
1953
1954 /// The ops the witness row expression needs, so ONE generic body evaluates
1955 /// at the production [`Order2<K>`], the new scalar [`Order2Lane<f64, K>`],
1956 /// and the batched [`Order2Batch<K>`].
1957 trait RowAlg<const K: usize>: Copy {
1958 fn constant(c: f64) -> Self;
1959 fn add(&self, o: &Self) -> Self;
1960 fn sub(&self, o: &Self) -> Self;
1961 fn mul(&self, o: &Self) -> Self;
1962 fn scale(&self, s: f64) -> Self;
1963 fn exp(&self) -> Self;
1964 fn sqrt(&self) -> Self;
1965 fn recip(&self) -> Self;
1966 }
1967
1968 impl<const K: usize> RowAlg<K> for Order2<K> {
1969 fn constant(c: f64) -> Self {
1970 <Self as JetScalar<K>>::constant(c)
1971 }
1972 fn add(&self, o: &Self) -> Self {
1973 JetScalar::add(self, o)
1974 }
1975 fn sub(&self, o: &Self) -> Self {
1976 JetScalar::sub(self, o)
1977 }
1978 fn mul(&self, o: &Self) -> Self {
1979 JetScalar::mul(self, o)
1980 }
1981 fn scale(&self, s: f64) -> Self {
1982 JetScalar::scale(self, s)
1983 }
1984 fn exp(&self) -> Self {
1985 JetScalar::exp(self)
1986 }
1987 fn sqrt(&self) -> Self {
1988 JetScalar::sqrt(self)
1989 }
1990 fn recip(&self) -> Self {
1991 JetScalar::recip(self)
1992 }
1993 }
1994
1995 impl<L: Lane, const K: usize> RowAlg<K> for Order2Lane<L, K> {
1996 fn constant(c: f64) -> Self {
1997 Order2Lane::constant(L::splat(c))
1998 }
1999 fn add(&self, o: &Self) -> Self {
2000 Order2Lane::add(self, o)
2001 }
2002 fn sub(&self, o: &Self) -> Self {
2003 Order2Lane::sub(self, o)
2004 }
2005 fn mul(&self, o: &Self) -> Self {
2006 Order2Lane::mul(self, o)
2007 }
2008 fn scale(&self, s: f64) -> Self {
2009 Order2Lane::scale(self, s)
2010 }
2011 fn exp(&self) -> Self {
2012 Order2Lane::exp(self)
2013 }
2014 fn sqrt(&self) -> Self {
2015 Order2Lane::sqrt(self)
2016 }
2017 fn recip(&self) -> Self {
2018 Order2Lane::recip(self)
2019 }
2020 }
2021
2022 /// A dense witness row expression touching every algebra op (mul, add, sub,
2023 /// scale, exp, sqrt, recip) over ALL `K` primaries, so the gradient and the
2024 /// full `K×K` Hessian are dense (no trivially-zero channel). All transcend.
2025 /// arguments are kept finite/positive: `sqrt(s²+1) > 0`, `recip(exp+2) > 0`.
2026 fn row_expr<const K: usize, A: RowAlg<K>>(p: &[A; K]) -> A {
2027 let mut s = A::constant(0.3);
2028 for a in 0..K {
2029 let b = (a + 1) % K;
2030 s = s.add(&p[a].mul(&p[b]).scale(0.1 + 0.05 * a as f64));
2031 }
2032 let e = s.exp();
2033 let r = s.mul(&s).add(&A::constant(1.0)).sqrt();
2034 let denom = e.add(&A::constant(2.0));
2035 e.mul(&r).sub(&s.scale(0.5)).mul(&denom.recip())
2036 }
2037
2038 /// xorshift64 → `f64` in `[-1, 1)`.
2039 fn rand_unit(state: &mut u64) -> f64 {
2040 let mut x = *state;
2041 x ^= x << 13;
2042 x ^= x >> 7;
2043 x ^= x << 17;
2044 *state = x;
2045 let u = (x >> 11) as f64 / ((1u64 << 53) as f64); // [0, 1)
2046 2.0 * u - 1.0
2047 }
2048
2049 /// Returns the number of (batch, row) pairs whose every channel was
2050 /// verified bit-identical, so the caller can assert the expected total ran.
2051 fn check_k<const K: usize>(state: &mut u64, batches: usize) -> usize {
2052 let mut verified_rows = 0usize;
2053 for _ in 0..batches {
2054 // Four independent rows of K primary values.
2055 let rows: [[f64; K]; 4] =
2056 std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2057
2058 // Production ground truth, evaluated per row at Order2<K>.
2059 let prod: [Order2<K>; 4] = std::array::from_fn(|r| {
2060 let p: [Order2<K>; K] = std::array::from_fn(|a| Order2::variable(rows[r][a], a));
2061 row_expr(&p)
2062 });
2063
2064 // New scalar field (Order2Lane<f64>), per row.
2065 let scal: [Order2Lane<f64, K>; 4] = std::array::from_fn(|r| {
2066 let p: [Order2Lane<f64, K>; K] =
2067 std::array::from_fn(|a| Order2Lane::variable(rows[r][a], a));
2068 row_expr(&p)
2069 });
2070
2071 // Batched: 4 rows packed into f64x4 lanes, ONE vector pass.
2072 let pbatch: [Order2Batch<K>; K] = std::array::from_fn(|a| {
2073 let packed =
2074 wide::f64x4::new([rows[0][a], rows[1][a], rows[2][a], rows[3][a]]);
2075 Order2Batch::variable(packed, a)
2076 });
2077 let batch = row_expr(&pbatch);
2078
2079 for r in 0..4 {
2080 let g = prod[r].0;
2081 // Order2Lane<f64> == Order2<K> (bit-identical scalar field).
2082 assert_eq!(scal[r].v.to_bits(), g.v.to_bits(), "K={K} scalar v");
2083 // Batch lane r == Order2<K> for row r.
2084 let lr = batch.lane(r).0;
2085 assert_eq!(lr.v.to_bits(), g.v.to_bits(), "K={K} batch lane {r} v");
2086 for a in 0..K {
2087 assert_eq!(
2088 scal[r].g[a].to_bits(),
2089 g.g[a].to_bits(),
2090 "K={K} scalar g[{a}]"
2091 );
2092 assert_eq!(
2093 lr.g[a].to_bits(),
2094 g.g[a].to_bits(),
2095 "K={K} batch lane {r} g[{a}]"
2096 );
2097 for b in 0..K {
2098 assert_eq!(
2099 scal[r].h[a][b].to_bits(),
2100 g.h[a][b].to_bits(),
2101 "K={K} scalar h[{a}][{b}]"
2102 );
2103 assert_eq!(
2104 lr.h[a][b].to_bits(),
2105 g.h[a][b].to_bits(),
2106 "K={K} batch lane {r} h[{a}][{b}]"
2107 );
2108 }
2109 }
2110 verified_rows += 1;
2111 }
2112 }
2113 verified_rows
2114 }
2115
2116 /// ≥2000 random 4-row batches per K, across K ∈ {2,3,4,9}: every channel of
2117 /// every lane is `to_bits`-identical to the production scalar per row.
2118 #[test]
2119 fn batch_lanes_bit_identical_to_scalar_per_row() {
2120 let mut state = 0x9E37_79B9_7F4A_7C15_u64;
2121 let mut verified = 0usize;
2122 verified += check_k::<2>(&mut state, 2000);
2123 verified += check_k::<3>(&mut state, 2000);
2124 verified += check_k::<4>(&mut state, 2000);
2125 verified += check_k::<9>(&mut state, 2000);
2126 // 4 K-values × 2000 batches × 4 packed rows each, all bit-identical.
2127 assert_eq!(verified, 4 * 2000 * 4, "every batch row must be verified");
2128 }
2129
2130 // ── One-/two-seed lane oracles ──────────────────────────────────────────
2131 //
2132 // The same dense `row_expr` witness program runs over the SEEDED directional
2133 // scalars: the scalar `OneSeed`/`TwoSeed` per row, the `f64`-lane re-type
2134 // (`*SeedLane<f64>`), and the 4-rows-per-pass batch (`*SeedBatch`). The
2135 // headline claim is that the contracted-third / contracted-fourth channel of
2136 // every lane is `to_bits`-identical to the production scalar's per row.
2137
2138 impl<const K: usize> RowAlg<K> for OneSeed<K> {
2139 fn constant(c: f64) -> Self {
2140 <Self as JetScalar<K>>::constant(c)
2141 }
2142 fn add(&self, o: &Self) -> Self {
2143 JetScalar::add(self, o)
2144 }
2145 fn sub(&self, o: &Self) -> Self {
2146 JetScalar::sub(self, o)
2147 }
2148 fn mul(&self, o: &Self) -> Self {
2149 JetScalar::mul(self, o)
2150 }
2151 fn scale(&self, s: f64) -> Self {
2152 JetScalar::scale(self, s)
2153 }
2154 fn exp(&self) -> Self {
2155 JetScalar::exp(self)
2156 }
2157 fn sqrt(&self) -> Self {
2158 JetScalar::sqrt(self)
2159 }
2160 fn recip(&self) -> Self {
2161 JetScalar::recip(self)
2162 }
2163 }
2164
2165 impl<L: Lane, const K: usize> RowAlg<K> for OneSeedLane<L, K> {
2166 fn constant(c: f64) -> Self {
2167 OneSeedLane::constant(L::splat(c))
2168 }
2169 fn add(&self, o: &Self) -> Self {
2170 OneSeedLane::add(self, o)
2171 }
2172 fn sub(&self, o: &Self) -> Self {
2173 OneSeedLane::sub(self, o)
2174 }
2175 fn mul(&self, o: &Self) -> Self {
2176 OneSeedLane::mul(self, o)
2177 }
2178 fn scale(&self, s: f64) -> Self {
2179 OneSeedLane::scale(self, s)
2180 }
2181 fn exp(&self) -> Self {
2182 OneSeedLane::exp(self)
2183 }
2184 fn sqrt(&self) -> Self {
2185 OneSeedLane::sqrt(self)
2186 }
2187 fn recip(&self) -> Self {
2188 OneSeedLane::recip(self)
2189 }
2190 }
2191
2192 impl<const K: usize> RowAlg<K> for TwoSeed<K> {
2193 fn constant(c: f64) -> Self {
2194 <Self as JetScalar<K>>::constant(c)
2195 }
2196 fn add(&self, o: &Self) -> Self {
2197 JetScalar::add(self, o)
2198 }
2199 fn sub(&self, o: &Self) -> Self {
2200 JetScalar::sub(self, o)
2201 }
2202 fn mul(&self, o: &Self) -> Self {
2203 JetScalar::mul(self, o)
2204 }
2205 fn scale(&self, s: f64) -> Self {
2206 JetScalar::scale(self, s)
2207 }
2208 fn exp(&self) -> Self {
2209 JetScalar::exp(self)
2210 }
2211 fn sqrt(&self) -> Self {
2212 JetScalar::sqrt(self)
2213 }
2214 fn recip(&self) -> Self {
2215 JetScalar::recip(self)
2216 }
2217 }
2218
2219 impl<L: Lane, const K: usize> RowAlg<K> for TwoSeedLane<L, K> {
2220 fn constant(c: f64) -> Self {
2221 TwoSeedLane::constant(L::splat(c))
2222 }
2223 fn add(&self, o: &Self) -> Self {
2224 TwoSeedLane::add(self, o)
2225 }
2226 fn sub(&self, o: &Self) -> Self {
2227 TwoSeedLane::sub(self, o)
2228 }
2229 fn mul(&self, o: &Self) -> Self {
2230 TwoSeedLane::mul(self, o)
2231 }
2232 fn scale(&self, s: f64) -> Self {
2233 TwoSeedLane::scale(self, s)
2234 }
2235 fn exp(&self) -> Self {
2236 TwoSeedLane::exp(self)
2237 }
2238 fn sqrt(&self) -> Self {
2239 TwoSeedLane::sqrt(self)
2240 }
2241 fn recip(&self) -> Self {
2242 TwoSeedLane::recip(self)
2243 }
2244 }
2245
2246 fn check_oneseed<const K: usize>(state: &mut u64, batches: usize) -> usize {
2247 let mut rows_checked = 0;
2248 for _ in 0..batches {
2249 let rows: [[f64; K]; 4] =
2250 std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2251 // Per-row ε-direction.
2252 let u: [[f64; K]; 4] =
2253 std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2254
2255 // Production ground truth (scalar OneSeed per row).
2256 let prod: [OneSeed<K>; 4] = std::array::from_fn(|r| {
2257 let p: [OneSeed<K>; K] =
2258 std::array::from_fn(|a| OneSeed::seed_direction(rows[r][a], a, u[r][a]));
2259 row_expr(&p)
2260 });
2261
2262 // f64-lane re-type per row.
2263 let scal: [OneSeedLane<f64, K>; 4] = std::array::from_fn(|r| {
2264 let p: [OneSeedLane<f64, K>; K] =
2265 std::array::from_fn(|a| OneSeedLane::seed_direction(rows[r][a], a, u[r][a]));
2266 row_expr(&p)
2267 });
2268
2269 // 4-rows-per-pass batch.
2270 let pbatch: [OneSeedBatch<K>; K] = std::array::from_fn(|a| {
2271 let val = wide::f64x4::new([rows[0][a], rows[1][a], rows[2][a], rows[3][a]]);
2272 let uu = wide::f64x4::new([u[0][a], u[1][a], u[2][a], u[3][a]]);
2273 OneSeedBatch::seed_direction(val, a, uu)
2274 });
2275 let batch = row_expr(&pbatch);
2276
2277 for r in 0..4 {
2278 let want = prod[r].contracted_third();
2279 let got_scal = scal[r].contracted_third();
2280 let got_batch = batch.lane(r).contracted_third();
2281 // Value channel too (sanity that the base program agrees).
2282 assert_eq!(
2283 scal[r].base.v.to_bits(),
2284 prod[r].base.value().to_bits(),
2285 "OneSeed K={K} scalar value"
2286 );
2287 assert_eq!(
2288 batch.lane(r).base.value().to_bits(),
2289 prod[r].base.value().to_bits(),
2290 "OneSeed K={K} batch lane {r} value"
2291 );
2292 for a in 0..K {
2293 for b in 0..K {
2294 assert_eq!(
2295 got_scal[a][b].to_bits(),
2296 want[a][b].to_bits(),
2297 "OneSeed K={K} scalar third[{a}][{b}]"
2298 );
2299 assert_eq!(
2300 got_batch[a][b].to_bits(),
2301 want[a][b].to_bits(),
2302 "OneSeed K={K} batch lane {r} third[{a}][{b}]"
2303 );
2304 }
2305 }
2306 rows_checked += 1;
2307 }
2308 }
2309 rows_checked
2310 }
2311
2312 fn check_twoseed<const K: usize>(state: &mut u64, batches: usize) -> usize {
2313 let mut rows_checked = 0;
2314 for _ in 0..batches {
2315 let rows: [[f64; K]; 4] =
2316 std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2317 let u: [[f64; K]; 4] =
2318 std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2319 let v: [[f64; K]; 4] =
2320 std::array::from_fn(|_| std::array::from_fn(|_| rand_unit(state)));
2321
2322 let prod: [TwoSeed<K>; 4] = std::array::from_fn(|r| {
2323 let p: [TwoSeed<K>; K] =
2324 std::array::from_fn(|a| TwoSeed::seed(rows[r][a], a, u[r][a], v[r][a]));
2325 row_expr(&p)
2326 });
2327
2328 let scal: [TwoSeedLane<f64, K>; 4] = std::array::from_fn(|r| {
2329 let p: [TwoSeedLane<f64, K>; K] =
2330 std::array::from_fn(|a| TwoSeedLane::seed(rows[r][a], a, u[r][a], v[r][a]));
2331 row_expr(&p)
2332 });
2333
2334 let pbatch: [TwoSeedBatch<K>; K] = std::array::from_fn(|a| {
2335 let val = wide::f64x4::new([rows[0][a], rows[1][a], rows[2][a], rows[3][a]]);
2336 let uu = wide::f64x4::new([u[0][a], u[1][a], u[2][a], u[3][a]]);
2337 let vv = wide::f64x4::new([v[0][a], v[1][a], v[2][a], v[3][a]]);
2338 TwoSeedBatch::seed(val, a, uu, vv)
2339 });
2340 let batch = row_expr(&pbatch);
2341
2342 for r in 0..4 {
2343 let want = prod[r].contracted_fourth();
2344 let got_scal = scal[r].contracted_fourth();
2345 let got_batch = batch.lane(r).contracted_fourth();
2346 assert_eq!(
2347 scal[r].base.v.to_bits(),
2348 prod[r].base.value().to_bits(),
2349 "TwoSeed K={K} scalar value"
2350 );
2351 assert_eq!(
2352 batch.lane(r).base.value().to_bits(),
2353 prod[r].base.value().to_bits(),
2354 "TwoSeed K={K} batch lane {r} value"
2355 );
2356 for a in 0..K {
2357 for b in 0..K {
2358 assert_eq!(
2359 got_scal[a][b].to_bits(),
2360 want[a][b].to_bits(),
2361 "TwoSeed K={K} scalar fourth[{a}][{b}]"
2362 );
2363 assert_eq!(
2364 got_batch[a][b].to_bits(),
2365 want[a][b].to_bits(),
2366 "TwoSeed K={K} batch lane {r} fourth[{a}][{b}]"
2367 );
2368 }
2369 }
2370 rows_checked += 1;
2371 }
2372 }
2373 rows_checked
2374 }
2375
2376 /// ≥2000 random 4-row batches per K, across K ∈ {2,3,4,9}: the
2377 /// contracted-third channel of every `OneSeedLane` lane is `to_bits`-identical
2378 /// to the production [`OneSeed`] per row.
2379 #[test]
2380 fn oneseed_lanes_contracted_third_bit_identical() {
2381 let mut state = 0x1234_5678_9ABC_DEF0_u64;
2382 let batches = 2000;
2383 let rows_checked = check_oneseed::<2>(&mut state, batches)
2384 + check_oneseed::<3>(&mut state, batches)
2385 + check_oneseed::<4>(&mut state, batches)
2386 + check_oneseed::<9>(&mut state, batches);
2387 // 4 widths × `batches` batches × 4 rows each: a silently empty inner
2388 // loop would leave this at zero instead of passing as a no-op.
2389 assert_eq!(rows_checked, 4 * batches * 4);
2390 }
2391
2392 /// ≥2000 random 4-row batches per K, across K ∈ {2,3,4,9}: the
2393 /// contracted-fourth channel of every `TwoSeedLane` lane is `to_bits`-identical
2394 /// to the production [`TwoSeed`] per row.
2395 #[test]
2396 fn twoseed_lanes_contracted_fourth_bit_identical() {
2397 let mut state = 0x0FED_CBA9_8765_4321_u64;
2398 let batches = 2000;
2399 let rows_checked = check_twoseed::<2>(&mut state, batches)
2400 + check_twoseed::<3>(&mut state, batches)
2401 + check_twoseed::<4>(&mut state, batches)
2402 + check_twoseed::<9>(&mut state, batches);
2403 // 4 widths × `batches` batches × 4 rows each: a silently empty inner
2404 // loop would leave this at zero instead of passing as a no-op.
2405 assert_eq!(rows_checked, 4 * batches * 4);
2406 }
2407}
2408
2409#[cfg(test)]
2410mod unit_tests {
2411 use super::{JetScalar, Order1, Order2, filtered_implicit_solve_scalar};
2412
2413 // ── Order2 direct property tests ─────────────────────────────────────────
2414
2415 /// `Order2::constant(c)` carries value `c` and zero everywhere else.
2416 #[test]
2417 fn order2_constant_has_zero_derivatives() {
2418 let s = Order2::<3>::constant(7.5);
2419 assert_eq!(s.value(), 7.5);
2420 for a in 0..3 {
2421 assert_eq!(s.g()[a], 0.0, "grad[{a}] should be zero");
2422 for b in 0..3 {
2423 assert_eq!(s.h()[a][b], 0.0, "hess[{a}][{b}] should be zero");
2424 }
2425 }
2426 }
2427
2428 /// `Order2::variable(x, axis)` has unit gradient in slot `axis` and zero Hessian.
2429 #[test]
2430 fn order2_variable_has_unit_gradient_in_seeded_slot() {
2431 let x = -2.5_f64;
2432 let s = Order2::<4>::variable(x, 2);
2433 assert_eq!(s.value(), x);
2434 for a in 0..4 {
2435 let expected_g = if a == 2 { 1.0 } else { 0.0 };
2436 assert_eq!(s.g()[a], expected_g, "grad[{a}]");
2437 for b in 0..4 {
2438 assert_eq!(s.h()[a][b], 0.0, "hess[{a}][{b}] should be zero");
2439 }
2440 }
2441 }
2442
2443 /// `Order2::add` sums gradient channels; `sub` is the inverse on gradients.
2444 /// Uses integer-valued primaries so the value roundtrip is also exact.
2445 #[test]
2446 fn order2_add_sub_roundtrip() {
2447 let p = Order2::<2>::variable(3.0, 0);
2448 let q = Order2::<2>::variable(2.0, 1);
2449 let pq = JetScalar::add(&p, &q);
2450 // value = 3 + 2 = 5
2451 assert_eq!(pq.value(), 5.0, "add value");
2452 let back = JetScalar::sub(&pq, &q);
2453 // (p + q) - q gradient should equal p's gradient exactly
2454 for a in 0..2 {
2455 assert_eq!(back.g()[a], p.g()[a], "grad[{a}] roundtrip");
2456 }
2457 }
2458
2459 /// `Order2::mul` of two variables satisfies the Leibniz product rule:
2460 /// ∂(p·q)/∂p = q, ∂(p·q)/∂q = p, ∂²(p·q)/∂p∂q = 1.
2461 #[test]
2462 fn order2_mul_satisfies_leibniz_rule() {
2463 let pv = 3.0_f64;
2464 let qv = -2.0_f64;
2465 let p = Order2::<2>::variable(pv, 0);
2466 let q = Order2::<2>::variable(qv, 1);
2467 let pq = JetScalar::mul(&p, &q);
2468 assert_eq!(pq.value(), pv * qv, "value = p·q");
2469 assert_eq!(pq.g()[0], qv, "∂(p·q)/∂p = q");
2470 assert_eq!(pq.g()[1], pv, "∂(p·q)/∂q = p");
2471 assert_eq!(pq.h()[0][1], 1.0, "∂²(p·q)/∂p∂q = 1");
2472 assert_eq!(pq.h()[1][0], 1.0, "∂²(p·q)/∂q∂p = 1 (symmetric)");
2473 assert_eq!(pq.h()[0][0], 0.0, "∂²(p·q)/∂p² = 0");
2474 assert_eq!(pq.h()[1][1], 0.0, "∂²(p·q)/∂q² = 0");
2475 }
2476
2477 /// `Order2::scale(s)` multiplies every channel by `s`.
2478 #[test]
2479 fn order2_scale_multiplies_all_channels() {
2480 let p = Order2::<2>::variable(4.0, 0);
2481 let s = 2.5_f64;
2482 let ps = JetScalar::scale(&p, s);
2483 assert_eq!(ps.value(), 4.0 * s);
2484 assert_eq!(ps.g()[0], 1.0 * s);
2485 assert_eq!(ps.g()[1], 0.0);
2486 }
2487
2488 /// `Order2::exp` at a constant has value `e^c`, gradient `e^c * g`, Hessian `e^c * (g⊗g + H)`.
2489 /// At a seeded variable `p₀`, the first derivative is `e^{p₀}` and second is `e^{p₀}`.
2490 #[test]
2491 fn order2_exp_derivative_stack_correct() {
2492 let p0 = 1.0_f64;
2493 let p = Order2::<1>::variable(p0, 0);
2494 let ep = JetScalar::exp(&p);
2495 let e = p0.exp();
2496 assert!((ep.value() - e).abs() < 1e-15, "exp value");
2497 assert!((ep.g()[0] - e).abs() < 1e-15, "d/dp exp(p) = exp(p)");
2498 assert!((ep.h()[0][0] - e).abs() < 1e-15, "d²/dp² exp(p) = exp(p)");
2499 }
2500
2501 /// `Order2::ln` at a seeded variable: d/dp ln(p) = 1/p, d²/dp² ln(p) = -1/p².
2502 #[test]
2503 fn order2_ln_derivative_stack_correct() {
2504 let p0 = 2.0_f64;
2505 let p = Order2::<1>::variable(p0, 0);
2506 let lnp = JetScalar::ln(&p);
2507 assert!((lnp.value() - p0.ln()).abs() < 1e-15, "ln value");
2508 assert!((lnp.g()[0] - 1.0 / p0).abs() < 1e-15, "d/dp ln(p) = 1/p");
2509 assert!((lnp.h()[0][0] - (-1.0 / (p0 * p0))).abs() < 1e-15, "d²/dp² ln(p) = -1/p²");
2510 }
2511
2512 /// `exp` and `ln` are mutual inverses: `ln(exp(p)).value() == p` at the scalar.
2513 #[test]
2514 fn order2_exp_ln_roundtrip_at_value() {
2515 let p0 = 0.8_f64;
2516 let p = Order2::<1>::variable(p0, 0);
2517 let roundtrip = JetScalar::ln(&JetScalar::exp(&p));
2518 assert!((roundtrip.value() - p0).abs() < 1e-14, "ln(exp(p)) ≈ p");
2519 }
2520
2521 // ── Order1 tests ─────────────────────────────────────────────────────────
2522
2523 /// `Order1::constant` carries the correct value with all-zero gradient.
2524 #[test]
2525 fn order1_constant_has_zero_gradient() {
2526 let s = Order1::<3>::constant(-5.0);
2527 assert_eq!(s.value(), -5.0);
2528 for a in 0..3 {
2529 assert_eq!(s.g()[a], 0.0, "g[{a}] should be zero");
2530 }
2531 }
2532
2533 /// `Order1::variable(x, axis)` has unit gradient only in `axis`.
2534 #[test]
2535 fn order1_variable_has_unit_gradient_in_seeded_slot() {
2536 let s = Order1::<3>::variable(2.0, 1);
2537 assert_eq!(s.value(), 2.0);
2538 assert_eq!(s.g()[0], 0.0);
2539 assert_eq!(s.g()[1], 1.0);
2540 assert_eq!(s.g()[2], 0.0);
2541 }
2542
2543 /// `Order1::mul` satisfies the product rule (value and gradient, no Hessian).
2544 #[test]
2545 fn order1_mul_satisfies_product_rule() {
2546 let pv = 3.0_f64;
2547 let qv = -2.0_f64;
2548 let p = Order1::<2>::variable(pv, 0);
2549 let q = Order1::<2>::variable(qv, 1);
2550 let pq = JetScalar::mul(&p, &q);
2551 assert_eq!(pq.value(), pv * qv);
2552 assert_eq!(pq.g()[0], qv, "∂(p·q)/∂p = q");
2553 assert_eq!(pq.g()[1], pv, "∂(p·q)/∂q = p");
2554 }
2555
2556 /// `Order1::exp` carries the correct value and gradient `e^{p₀}`.
2557 #[test]
2558 fn order1_exp_has_correct_value_and_gradient() {
2559 let p0 = 0.5_f64;
2560 let p = Order1::<2>::variable(p0, 0);
2561 let ep = JetScalar::exp(&p);
2562 let e = p0.exp();
2563 assert!((ep.value() - e).abs() < 1e-15, "exp value");
2564 assert!((ep.g()[0] - e).abs() < 1e-15, "d/dp exp(p)");
2565 assert_eq!(ep.g()[1], 0.0, "irrelevant gradient slot is zero");
2566 }
2567
2568 /// `Order1` and `Order2` agree on value and gradient for the same expression.
2569 #[test]
2570 fn order1_and_order2_agree_on_value_and_gradient() {
2571 let p0 = 1.3_f64;
2572 let q0 = -0.7_f64;
2573 // evaluate (p * q + p).exp() at (p0, q0)
2574 let p1 = Order1::<2>::variable(p0, 0);
2575 let q1 = Order1::<2>::variable(q0, 1);
2576 let expr1 = JetScalar::exp(&JetScalar::add(&JetScalar::mul(&p1, &q1), &p1));
2577
2578 let p2 = Order2::<2>::variable(p0, 0);
2579 let q2 = Order2::<2>::variable(q0, 1);
2580 let expr2 = JetScalar::exp(&JetScalar::add(&JetScalar::mul(&p2, &q2), &p2));
2581
2582 assert!((expr1.value() - expr2.value()).abs() < 1e-14, "value mismatch");
2583 for a in 0..2 {
2584 assert!(
2585 (expr1.g()[a] - expr2.g()[a]).abs() < 1e-14,
2586 "gradient[{a}] mismatch"
2587 );
2588 }
2589 }
2590
2591 // ── filtered_implicit_solve_scalar ────────────────────────────────────────
2592
2593 /// Lift the trivial linear constraint F(a, θ) = a - θ = 0 through `Order2<1>`.
2594 /// The exact lifted jet is a(θ) = θ, so value=θ₀, gradient=1.
2595 #[test]
2596 fn filtered_implicit_solve_linear_constraint_gives_exact_jet() {
2597 let theta0 = 3.0_f64;
2598 let theta = Order2::<1>::variable(theta0, 0);
2599 // a0 = theta0, F_a = 1, inv_fa = 1; 2 iters suffice for Order2.
2600 let a = filtered_implicit_solve_scalar::<1, Order2<1>>(
2601 theta0,
2602 1.0,
2603 2,
2604 |a_jet| JetScalar::sub(a_jet, &theta),
2605 );
2606 assert!((a.value() - theta0).abs() < 1e-14, "value = theta0");
2607 // da/dtheta = 1 (identity)
2608 assert!((a.g()[0] - 1.0).abs() < 1e-14, "gradient = 1");
2609 // d²a/dtheta² = 0 (linear)
2610 assert!(a.h()[0][0].abs() < 1e-14, "hessian = 0");
2611 }
2612
2613 /// `filtered_implicit_solve_scalar` on a quadratic constraint F(a,θ)=a²-θ=0
2614 /// with primal root a₀=√θ₀, giving da/dθ = 1/(2√θ₀), d²a/dθ² = -1/(4θ₀^{3/2}).
2615 #[test]
2616 fn filtered_implicit_solve_quadratic_constraint_matches_analytic_derivatives() {
2617 let theta0 = 4.0_f64;
2618 let a0 = theta0.sqrt();
2619 let inv_fa = 1.0 / (2.0 * a0);
2620 let theta = Order2::<1>::variable(theta0, 0);
2621 // F(a,theta) = a*a - theta
2622 let a = filtered_implicit_solve_scalar::<1, Order2<1>>(a0, inv_fa, 2, |a_jet| {
2623 let aa = JetScalar::mul(a_jet, a_jet);
2624 JetScalar::sub(&aa, &theta)
2625 });
2626 let tol = 1e-12;
2627 assert!((a.value() - a0).abs() < tol, "value = sqrt(theta0)");
2628 let expected_g = 0.5 / a0;
2629 assert!((a.g()[0] - expected_g).abs() < tol, "da/dtheta = 1/(2*sqrt)");
2630 let expected_h = -0.25 / (theta0 * a0);
2631 assert!((a.h()[0][0] - expected_h).abs() < tol, "d2a/dtheta2 = -1/(4*theta^1.5)");
2632 }
2633}