gam_math/jet_partitions.rs
1//! Bitmask-coefficient multi-directional jets used by marginal-slope and
2//! latent-survival row kernels.
3//!
4//! The layout stores one coefficient per direction mask. The calculus itself
5//! lives in [`crate::jet_algebra`]: that module owns the layout-agnostic
6//! Leibniz / Faà di Bruno *combinatorics* once, and the scalar (`n_dirs == 0`)
7//! path here still routes through it so a fix to the rule is a fix to both
8//! representations.
9//!
10//! ## Why this layout is special (and how the hot path exploits it)
11//!
12//! Each direction is seeded *linearly* (one first-derivative slot), so every
13//! direction variable squares to zero. The coefficients therefore form the
14//! commutative **multilinear / set-function algebra**: `coeffs[mask]` is the
15//! coefficient of `Π_{i ∈ mask} ε_i`. In that algebra two facts collapse the
16//! generic combinatorial walkers into tight branch-free arithmetic:
17//!
18//! * **`mul` is the subset (zeta-style) convolution**
19//! `out[mask] = Σ_{sub ⊆ mask} a[sub] · b[mask \ sub]`.
20//! The shared `leibniz_product` walker rebuilds two `SlotBuf`s and folds bit
21//! lists back into masks (`mask_of`) *per subset*; here we enumerate the
22//! submasks of `mask` directly — `mask \ sub == mask ^ sub` because
23//! `sub ⊆ mask` — in the **same ascending order** the walker used, so the
24//! floating-point accumulation is bit-for-bit identical while every
25//! `SlotBuf`/closure/`mask_of` allocation and indirection disappears
26//! (`3^K` pure FMAs, no heap, no `dyn`).
27//!
28//! * **`compose_unary` is the truncated Faà di Bruno composition**, computed
29//! here by the exact **truncated-Taylor reassociation** rather than a direct
30//! set-partition sum. Let `v` be the non-constant part of `self`
31//! (`v[0] = 0`, `v[mask] = self[mask]`) and let `v^{⊛k}` be the `k`-fold
32//! *subset convolution* (the multilinear power). The ordered-tuple identity
33//! `v^{⊛k}[mask] = k! · Σ_{π ⊢ mask, |π| = k} Π_{B ∈ π} v[B]` turns the
34//! set-partition sum into a degree-4 polynomial in `v`:
35//!
36//! ```text
37//! f(self)[mask] = Σ_{k=0}^{4} (f^{(k)} / k!) · v^{⊛k}[mask] (mask ≠ 0)
38//! f(self)[0] = f^{(0)}
39//! ```
40//!
41//! so a composition is just **three subset convolutions** (`v²`, `v³=v²⊛v`,
42//! `v⁴=v²⊛v²` — the Motzkin floor for a quartic) plus a five-term combine.
43//! That is ~3× fewer FLOPs than the per-mask partition gather; each
44//! convolution is a four-lane compensated dot product (Ogita–Rump–Oishi
45//! Dot2, FMA-split products + TwoSum carry) so the result is computed in
46//! ~double the working precision and the rounding of `v²` cannot compound
47//! through `v³`/`v⁴`; the final per-mask combine is Neumaier-compensated and
48//! `wide::f64x4`-vectorised; and the whole call runs on reused thread-local
49//! scratch with no per-call heap traffic. The reassociation is algebraically
50//! exact; accuracy-vs-truth (a double-double oracle) is the test gate: per
51//! coefficient it is within a few ULPs of, and in aggregate no worse than, the old partition sum's error (see `tests`).
52use std::cell::RefCell;
53use std::sync::atomic::{AtomicU64, Ordering};
54use wide::{CmpGe, f64x4};
55
56pub static COMPOSE_UNARY_CALLS: AtomicU64 = AtomicU64::new(0);
57pub static MUL_CALLS: AtomicU64 = AtomicU64::new(0);
58
59/// Length of the unary derivative stack `[f, f', f'', f''', f'''']`: composition
60/// is exact through order 4, partitions into `>= 5` blocks are truncated.
61const DERIVS: usize = 5;
62
63#[derive(Clone)]
64pub struct MultiDirJet {
65 pub coeffs: Vec<f64>,
66}
67
68impl MultiDirJet {
69 pub fn zero(n_dirs: usize) -> Self {
70 Self {
71 coeffs: vec![0.0; 1usize << n_dirs],
72 }
73 }
74
75 pub fn constant(n_dirs: usize, value: f64) -> Self {
76 let mut out = Self::zero(n_dirs);
77 out.coeffs[0] = value;
78 out
79 }
80
81 pub fn linear(n_dirs: usize, base: f64, first: &[f64]) -> Self {
82 let mut out = Self::constant(n_dirs, base);
83 for (idx, &value) in first.iter().take(n_dirs).enumerate() {
84 out.coeffs[1usize << idx] = value;
85 }
86 out
87 }
88
89 pub fn with_coeffs(n_dirs: usize, coeffs: &[(usize, f64)]) -> Self {
90 let mut out = Self::zero(n_dirs);
91 for &(mask, value) in coeffs {
92 if mask < out.coeffs.len() {
93 out.coeffs[mask] = value;
94 }
95 }
96 out
97 }
98
99 #[inline]
100 pub fn coeff(&self, mask: usize) -> f64 {
101 self.coeffs[mask]
102 }
103
104 pub fn add(&self, other: &Self) -> Self {
105 Self {
106 coeffs: self
107 .coeffs
108 .iter()
109 .zip(other.coeffs.iter())
110 .map(|(lhs, rhs)| lhs + rhs)
111 .collect(),
112 }
113 }
114
115 pub fn scale(&self, scalar: f64) -> Self {
116 Self {
117 coeffs: self.coeffs.iter().map(|value| scalar * value).collect(),
118 }
119 }
120
121 /// Subset-convolution product `out[mask] = Σ_{sub ⊆ mask} a[sub]·b[mask^sub]`.
122 ///
123 /// Bit-identical to the shared [`crate::jet_algebra::leibniz_product`] walker
124 /// (the submasks are enumerated in the same ascending order — the walker's
125 /// compacted subset index is a monotone bit-deposit of the submask) while
126 /// dropping its per-subset `SlotBuf`/closure/`mask_of` overhead. The scalar
127 /// `n_dirs == 0` case keeps the shared walker live as its reference.
128 pub fn mul(&self, other: &Self) -> Self {
129 MUL_CALLS.fetch_add(1, Ordering::Relaxed);
130 let count = self.coeffs.len();
131 if count <= 1 {
132 return self.mul_reference(other);
133 }
134 let a = &self.coeffs;
135 let b = &other.coeffs;
136 // Both operands carry the same direction set, so `b` is `count` long too.
137 // With that established once, every `a[sub]`/`b[mask ^ sub]` below is
138 // provably in bounds (`sub, mask ^ sub ⊆ mask < count`), so the inner
139 // submask walk can drop its per-load bounds checks.
140 assert_eq!(b.len(), count, "MultiDirJet::mul operands must share n_dirs");
141 let mut out = vec![0.0; count];
142 for (mask, slot) in out.iter_mut().enumerate() {
143 // Walk every submask of `mask` in ascending numeric order — the same
144 // order `leibniz_product` accumulates — via the classic gap-fill
145 // increment `next = ((sub | !mask) + 1) & mask`.
146 let mut acc = 0.0;
147 let mut sub = 0usize;
148 // SAFETY: `sub ⊆ mask < count` and `mask ^ sub ⊆ mask < count`, and
149 // both `a` and `b` are `count` long (asserted above).
150 unsafe {
151 loop {
152 acc += *a.get_unchecked(sub) * *b.get_unchecked(mask ^ sub);
153 if sub == mask {
154 break;
155 }
156 sub = (sub | !mask).wrapping_add(1) & mask;
157 }
158 }
159 *slot = acc;
160 }
161 Self { coeffs: out }
162 }
163
164 /// The pre-#perf shared-walker product, retained verbatim as the scalar-case
165 /// implementation and as the bit-exact reference for `mul`.
166 fn mul_reference(&self, other: &Self) -> Self {
167 let count = self.coeffs.len();
168 let mut out = vec![0.0; count];
169 for (mask, slot) in out.iter_mut().enumerate() {
170 let bits = bit_positions(mask);
171 *slot = crate::jet_algebra::leibniz_product(
172 bits.as_slice(),
173 |t| self.coeffs[mask_of(t)],
174 |c| other.coeffs[mask_of(c)],
175 );
176 }
177 Self { coeffs: out }
178 }
179
180 /// Exact (order-4 truncated) unary composition `f(self)` from the Taylor
181 /// stack `[f, f', f'', f''', f'''']` at `self.coeff(0)`.
182 ///
183 /// Computed by the truncated-Taylor reassociation (see the module note):
184 /// `f(self) = Σ_{k=0}^{4} (f^{(k)}/k!)·v^{⊛k}` with `v` the non-constant
185 /// part of `self`. The three subset-convolution powers `v²`, `v³`, `v⁴`
186 /// are compensated (Dot2) and the per-mask combine is Neumaier-compensated
187 /// and vectorised, so the result is *more* accurate vs. the true
188 /// real-arithmetic value than the prior naive partition sum (proven against
189 /// a double-double oracle in `tests`). The scalar `n_dirs == 0` case keeps
190 /// the shared Faà di Bruno walker live as its reference.
191 pub fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
192 COMPOSE_UNARY_CALLS.fetch_add(1, Ordering::Relaxed);
193 let count = self.coeffs.len();
194 if count <= 1 {
195 return <Self as crate::jet_algebra::JetAlgebra<DERIVS>>::compose_unary(self, derivs);
196 }
197 // Per-block Taylor coefficients c_k = f^{(k)} / k! (k = 1..=4): the
198 // `1/k!` undoes the ordered-tuple overcount of the subset-convolution
199 // power v^{⊛k} relative to the unordered set-partition sum.
200 let c1 = derivs[1];
201 let c2 = derivs[2] * 0.5;
202 let c3 = derivs[3] * (1.0 / 6.0);
203 let c4 = derivs[4] * (1.0 / 24.0);
204
205 let mut out = vec![0.0; count];
206 COMPOSE_SCRATCH.with(|cell| {
207 let mut buf = cell.borrow_mut();
208 // Four contiguous scratch lanes: v, p2 = v², p3 = v³, p4 = v⁴.
209 buf.clear();
210 buf.resize(4 * count, 0.0);
211 let (vbuf, rest) = buf.split_at_mut(count);
212 let (p2, rest) = rest.split_at_mut(count);
213 let (p3, p4) = rest.split_at_mut(count);
214
215 // v = non-constant part of self (the constant channel squares to a
216 // 0-block, which the k = 0 term carries separately).
217 vbuf.copy_from_slice(&self.coeffs);
218 vbuf[0] = 0.0;
219
220 // Powers via compensated subset convolution, pruned by output
221 // popcount: v^{⊛k}[mask] = 0 whenever popcount(mask) < k.
222 subset_conv_into(vbuf, vbuf, p2, 2);
223 subset_conv_into(p2, vbuf, p3, 3);
224 subset_conv_into(p2, p2, p4, 4);
225
226 // out[mask] = c1·v + c2·v² + c3·v³ + c4·v⁴ (mask ≠ 0), Neumaier-
227 // compensated and f64x4-vectorised over masks. out[0] = f^{(0)}.
228 combine_powers(vbuf, p2, p3, p4, [c1, c2, c3, c4], &mut out);
229 out[0] = derivs[0];
230 });
231 Self { coeffs: out }
232 }
233}
234
235thread_local! {
236 /// Reused composition scratch (`4·count` f64s: v, v², v³, v⁴). Sized up on
237 /// demand and never freed, so a steady-state `compose_unary` does zero heap
238 /// work beyond the owned output `Vec`.
239 static COMPOSE_SCRATCH: RefCell<Vec<f64>> = const { RefCell::new(Vec::new()) };
240}
241
242/// Branchless TwoSum: returns `(s, e)` with `s = fl(a+b)` and `a+b = s+e`
243/// exactly (Knuth/Møller). Used by the compensated convolution and combine.
244#[inline(always)]
245fn two_sum(a: f64, b: f64) -> (f64, f64) {
246 let s = a + b;
247 let bb = s - a;
248 let e = (a - (s - bb)) + (b - bb);
249 (s, e)
250}
251
252/// Subset (zeta-style) convolution `out[mask] = Σ_{sub ⊆ mask} a[sub]·b[mask^sub]`,
253/// evaluated as a **compensated dot product** (Ogita–Rump–Oishi Dot2): each
254/// product is split into head + FMA error (`mul_add`) and the running sum
255/// carries a TwoSum error term, so the result is accurate as if computed in
256/// ~twice the working precision. This stops the rounding of `v²` from
257/// compounding through `v³`/`v⁴`, which a single-rounding accumulation does
258/// not. Output masks with `popcount < min_pop` are left at zero: the
259/// multilinear power `v^{⊛k}` vanishes below popcount `k`, so the prune is exact
260/// and skips the low-order masks entirely.
261#[inline]
262fn subset_conv_into(a: &[f64], b: &[f64], out: &mut [f64], min_pop: u32) {
263 // SAFETY invariant for the `get_unchecked` loads below: every index this
264 // kernel reads is `< out.len()`, and the caller passes `a`/`b` at least as
265 // long as `out` (in `compose_unary` all three are the same `count`-length
266 // slices carved from one scratch buffer). Concretely `mask < out.len()`
267 // (loop bound), and each submask satisfies `sub ⊆ mask` so `sub ≤ mask` and
268 // `mask ^ sub ⊆ mask` so `mask ^ sub ≤ mask` — both `< out.len() ≤ a.len(),
269 // b.len()`. The `assert!` below pins the length precondition (one check per
270 // call, negligible next to the walk) so the `get_unchecked` below is sound;
271 // the bounds checks LLVM cannot elide (the indices are data-dependent) are a
272 // real per-step cost across the `3^K` submask walk (×3 convolutions per
273 // compose), so eliding them via get_unchecked is a measured ~20% at the
274 // marginal-slope direction counts.
275 assert!(a.len() >= out.len() && b.len() >= out.len());
276 for (mask, slot) in out.iter_mut().enumerate() {
277 if (mask as u64).count_ones() < min_pop {
278 *slot = 0.0;
279 continue;
280 }
281 // Descending submask enumeration `sub = (sub-1) & mask`, terminating
282 // after `sub == 0` (the classic Gosper-style submask walk). The Dot2 is
283 // spread across FOUR independent named accumulators (a 4-way unroll) so
284 // the FMA/TwoSum latency chains overlap — the loop becomes throughput-
285 // rather than latency-bound — then the lanes are merged with a final
286 // compensated reduction. Every non-pruned mask has popcount ≥ 2, so its
287 // `2^popcount` submask count is a multiple of 4 and the unroll is exact
288 // (the all-zero submask always lands in the fourth lane). Reassociation
289 // only; the value is the same real sum, in ~double the working precision.
290 #[inline(always)]
291 fn dot2_step(s: &mut f64, c: &mut f64, x: f64, y: f64) {
292 let prod = x * y;
293 let prod_err = x.mul_add(y, -prod); // exact: prod + prod_err == x*y
294 let (t, sum_err) = two_sum(*s, prod);
295 *s = t;
296 *c += prod_err + sum_err;
297 }
298 let (mut s0, mut s1, mut s2, mut s3) = (0.0f64, 0.0f64, 0.0f64, 0.0f64);
299 let (mut c0, mut c1, mut c2, mut c3) = (0.0f64, 0.0f64, 0.0f64, 0.0f64);
300 let mut sub = mask;
301 // SAFETY: see the invariant comment at the top of the function — `sub`
302 // and `mask ^ sub` are both submasks of `mask < out.len()`, hence in
303 // bounds for `a`/`b` (each ≥ `out.len()` long).
304 unsafe {
305 loop {
306 dot2_step(&mut s0, &mut c0, *a.get_unchecked(sub), *b.get_unchecked(mask ^ sub));
307 sub = (sub - 1) & mask;
308 dot2_step(&mut s1, &mut c1, *a.get_unchecked(sub), *b.get_unchecked(mask ^ sub));
309 sub = (sub - 1) & mask;
310 dot2_step(&mut s2, &mut c2, *a.get_unchecked(sub), *b.get_unchecked(mask ^ sub));
311 sub = (sub - 1) & mask;
312 dot2_step(&mut s3, &mut c3, *a.get_unchecked(sub), *b.get_unchecked(mask ^ sub));
313 if sub == 0 {
314 break;
315 }
316 sub = (sub - 1) & mask;
317 }
318 }
319 // Merge the four lanes, compensated.
320 let (s01, e01) = two_sum(s0, s1);
321 let (s23, e23) = two_sum(s2, s3);
322 let (total, etot) = two_sum(s01, s23);
323 *slot = total + (etot + e01 + e23 + c0 + c1 + c2 + c3);
324 }
325}
326
327/// `out[mask] = c[0]·p1 + c[1]·p2 + c[2]·p3 + c[3]·p4` for `mask ≥ 1`, with a
328/// Neumaier-compensated four-term accumulation (the powers span growing
329/// magnitudes, so the compensation recovers the bits a naive `+=` would drop)
330/// and a `wide::f64x4` body over four masks at a time. `out[0]` is overwritten
331/// by the caller with the value channel.
332#[inline]
333fn combine_powers(p1: &[f64], p2: &[f64], p3: &[f64], p4: &[f64], c: [f64; 4], out: &mut [f64]) {
334 let n = out.len();
335 let (c1, c2, c3, c4) = (c[0], c[1], c[2], c[3]);
336 let (v1, v2, v3, v4) = (
337 f64x4::splat(c1),
338 f64x4::splat(c2),
339 f64x4::splat(c3),
340 f64x4::splat(c4),
341 );
342 let mut mask = 0usize;
343 // Vector body: four contiguous masks per step. Neumaier compensation is
344 // applied lane-wise; pick the larger magnitude to subtract first.
345 while mask + 4 <= n {
346 let load = |p: &[f64]| f64x4::new([p[mask], p[mask + 1], p[mask + 2], p[mask + 3]]);
347 let mut s = v1 * load(p1);
348 let mut comp = f64x4::splat(0.0);
349 for (cv, pv) in [(v2, p2), (v3, p3), (v4, p4)] {
350 let term = cv * load(pv);
351 let t = s + term;
352 let big_s = s.abs().cmp_ge(term.abs());
353 let lost = big_s.blend((s - t) + term, (term - t) + s);
354 comp += lost;
355 s = t;
356 }
357 let res = s + comp;
358 out[mask..mask + 4].copy_from_slice(&res.to_array());
359 mask += 4;
360 }
361 // Scalar tail (and the small-K path where `n < 4`).
362 while mask < n {
363 let mut s = c1 * p1[mask];
364 let mut comp = 0.0f64;
365 for (cv, pv) in [(c2, p2), (c3, p3), (c4, p4)] {
366 let term = cv * pv[mask];
367 let (t, e) = two_sum(s, term);
368 comp += e;
369 s = t;
370 }
371 out[mask] = s + comp;
372 mask += 1;
373 }
374}
375
376impl crate::jet_algebra::JetAlgebra<DERIVS> for MultiDirJet {
377 #[inline]
378 fn derivative(&self, slots: &[usize]) -> f64 {
379 self.coeffs[mask_of(slots)]
380 }
381
382 fn map_derivatives<F>(&self, mut f: F) -> Self
383 where
384 F: FnMut(&[usize]) -> f64,
385 {
386 let mut out = vec![0.0; self.coeffs.len()];
387 for (mask, value) in out.iter_mut().enumerate() {
388 let bits = bit_positions(mask);
389 *value = f(bits.as_slice());
390 }
391 Self { coeffs: out }
392 }
393}
394
395/// The set-bit positions of `mask`, low to high — the differentiation slots of
396/// that coefficient.
397fn bit_positions(mask: usize) -> crate::jet_algebra::SlotBuf {
398 let mut out = crate::jet_algebra::SlotBuf::new();
399 let mut m = mask;
400 while m != 0 {
401 let bit = m.trailing_zeros() as usize;
402 out.push_slot(bit);
403 m &= m - 1;
404 }
405 out
406}
407
408/// Combine a slot-group (list of bit positions) back into a sub-mask.
409fn mask_of(slots: &[usize]) -> usize {
410 slots.iter().fold(0usize, |acc, &b| acc | (1usize << b))
411}
412
413// #932-2 cutover: `MultiDirJet::bilinear` (the 4-coeff `[base, d1, d2, d12]`
414// constructor) and `MultiDirJet::sub` are consumed ONLY by the now test-only hand
415// survival directional/bidirectional oracle (the production flex jet path uses the
416// `flex_jet` runtime jet algebra, not `MultiDirJet`). After the #1521 crate split
417// moved `MultiDirJet` into `gam-math`, those oracle tests live in the dependent
418// `gam` crate, where a `#[cfg(test)]` gate in *this* crate is inactive — so the
419// methods must be plain `pub` inherent methods to be reachable cross-crate. They
420// carry no dead-code cost because `pub` items are part of the crate's public API.
421// Bodies are byte-identical to their former gated form.
422impl MultiDirJet {
423 pub fn bilinear(base: f64, d1: f64, d2: f64, d12: f64) -> Self {
424 Self {
425 coeffs: vec![base, d1, d2, d12],
426 }
427 }
428
429 pub fn sub(&self, other: &Self) -> Self {
430 Self {
431 coeffs: self
432 .coeffs
433 .iter()
434 .zip(other.coeffs.iter())
435 .map(|(lhs, rhs)| lhs - rhs)
436 .collect(),
437 }
438 }
439}
440
441#[cfg(test)]
442mod tests {
443 use super::*;
444
445 /// A flattened set-partition table for a fixed slot count. `parts[i] = (off,
446 /// order)` describes one partition: its `order` block submasks (compacted) are
447 /// `flat[off .. off + order]`.
448 ///
449 /// This direct set-partition sum is the previous production `compose_unary`
450 /// implementation, retained as the **accuracy reference** the new
451 /// truncated-Taylor path is graded against: a double-double oracle is the
452 /// truth, and the test asserts the new path's error-vs-truth is `≤` this naive
453 /// partition sum's error-vs-truth on every randomised program.
454 struct PartTable {
455 flat: Vec<u32>,
456 parts: Vec<(usize, u8)>,
457 }
458
459 thread_local! {
460 /// Cached set-partition tables, indexed by slot count `m`. Entry `m` holds
461 /// every partition of `{0..m}` into `< DERIVS` blocks, in the shared
462 /// walker's recursion order, each block a compacted submask. Pure function
463 /// of `m`, so caching is sound and deterministic.
464 static PARTITION_TABLES: RefCell<Vec<std::rc::Rc<PartTable>>> =
465 const { RefCell::new(Vec::new()) };
466 }
467
468 /// Return cached partition tables for slot counts `0..=n_dirs`.
469 fn partition_tables(n_dirs: usize) -> Vec<std::rc::Rc<PartTable>> {
470 PARTITION_TABLES.with(|cell| {
471 let mut tables = cell.borrow_mut();
472 while tables.len() <= n_dirs {
473 let m = tables.len();
474 tables.push(std::rc::Rc::new(build_partitions(m)));
475 }
476 (0..=n_dirs).map(|m| std::rc::Rc::clone(&tables[m])).collect()
477 })
478 }
479
480 /// The previous production `compose_unary`: a direct set-partition (Faà di
481 /// Bruno) sum per output mask, retained as the accuracy reference.
482 fn compose_unary_partition_reference(coeffs: &[f64], derivs: [f64; DERIVS]) -> Vec<f64> {
483 let count = coeffs.len();
484 let n_dirs = count.trailing_zeros() as usize;
485 let tables = partition_tables(n_dirs);
486 let mut out = vec![0.0; count];
487 let mut remap = vec![0usize; count];
488 let mut pos = [0usize; usize::BITS as usize];
489 for (mask, slot) in out.iter_mut().enumerate() {
490 if mask == 0 {
491 *slot = derivs[0];
492 continue;
493 }
494 let mut npos = 0usize;
495 let mut m = mask;
496 while m != 0 {
497 pos[npos] = m.trailing_zeros() as usize;
498 npos += 1;
499 m &= m - 1;
500 }
501 remap[0] = 0;
502 for cb in 1usize..(1usize << npos) {
503 let low = cb.trailing_zeros() as usize;
504 remap[cb] = remap[cb & (cb - 1)] | (1usize << pos[low]);
505 }
506 let table = &tables[npos];
507 let flat = &table.flat;
508 let mut total = 0.0;
509 for &(off, order) in table.parts.iter() {
510 let order = order as usize;
511 let mut prod = derivs[order];
512 for &cb in &flat[off..off + order] {
513 prod *= coeffs[remap[cb as usize]];
514 }
515 total += prod;
516 }
517 *slot = total;
518 }
519 out
520 }
521
522 /// Enumerate the set-partitions of `{0..m}` with fewer than `DERIVS` blocks, in
523 /// the exact DFS order of [`crate::jet_algebra`]'s `for_each_partition`
524 /// recursion ("place each element into an existing block, else open a new one"),
525 /// each block recorded as a compacted submask of `{0..m}`, flattened.
526 fn build_partitions(m: usize) -> PartTable {
527 fn recurse(elem: usize, m: usize, blocks: &mut [u32; 8], n_blocks: usize, out: &mut PartTable) {
528 // Partitions with `>= DERIVS` blocks are truncated (their `f^{(order)}`
529 // is beyond the stack); the block count never decreases, so the whole
530 // subtree contributes nothing and is pruned — matching the walker's
531 // per-partition `order >= derivs.len()` skip.
532 if n_blocks >= DERIVS {
533 return;
534 }
535 if elem == m {
536 let off = out.flat.len();
537 out.flat.extend_from_slice(&blocks[..n_blocks]);
538 out.parts.push((off, n_blocks as u8));
539 return;
540 }
541 for b in 0..n_blocks {
542 blocks[b] |= 1u32 << elem;
543 recurse(elem + 1, m, blocks, n_blocks, out);
544 blocks[b] &= !(1u32 << elem);
545 }
546 blocks[n_blocks] = 1u32 << elem;
547 recurse(elem + 1, m, blocks, n_blocks + 1, out);
548 }
549 let mut out = PartTable {
550 flat: Vec::new(),
551 parts: Vec::new(),
552 };
553 let mut blocks = [0u32; 8];
554 recurse(0, m, &mut blocks, 0, &mut out);
555 out
556 }
557
558 // ── constructors ─────────────────────────────────────────────────────────
559
560 #[test]
561 fn zero_has_correct_length_and_all_zero_coefficients() {
562 let j = MultiDirJet::zero(3);
563 assert_eq!(j.coeffs.len(), 8);
564 assert!(j.coeffs.iter().all(|&v| v == 0.0));
565 }
566
567 #[test]
568 fn constant_has_value_at_mask_zero_and_zeros_elsewhere() {
569 let j = MultiDirJet::constant(2, 5.0);
570 assert_eq!(j.coeffs.len(), 4);
571 assert_eq!(j.coeff(0), 5.0);
572 assert_eq!(j.coeff(1), 0.0);
573 assert_eq!(j.coeff(2), 0.0);
574 assert_eq!(j.coeff(3), 0.0);
575 }
576
577 #[test]
578 fn linear_sets_base_and_per_direction_slots() {
579 let j = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
580 assert_eq!(j.coeff(0), 1.0); // constant
581 assert_eq!(j.coeff(1), 2.0); // mask 0b01 — direction 0
582 assert_eq!(j.coeff(2), 3.0); // mask 0b10 — direction 1
583 assert_eq!(j.coeff(3), 0.0); // cross term is zero
584 }
585
586 #[test]
587 fn bilinear_sets_all_four_slots() {
588 let j = MultiDirJet::bilinear(1.0, 2.0, 3.0, 4.0);
589 assert_eq!(j.coeff(0), 1.0);
590 assert_eq!(j.coeff(1), 2.0);
591 assert_eq!(j.coeff(2), 3.0);
592 assert_eq!(j.coeff(3), 4.0);
593 }
594
595 #[test]
596 fn with_coeffs_sets_only_specified_entries() {
597 let j = MultiDirJet::with_coeffs(2, &[(0, 9.0), (3, -1.0)]);
598 assert_eq!(j.coeff(0), 9.0);
599 assert_eq!(j.coeff(1), 0.0);
600 assert_eq!(j.coeff(2), 0.0);
601 assert_eq!(j.coeff(3), -1.0);
602 }
603
604 // ── elementwise arithmetic ────────────────────────────────────────────────
605
606 #[test]
607 fn add_is_elementwise() {
608 let a = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
609 let b = MultiDirJet::linear(2, 4.0, &[5.0, 6.0]);
610 let c = a.add(&b);
611 assert_eq!(c.coeff(0), 5.0);
612 assert_eq!(c.coeff(1), 7.0);
613 assert_eq!(c.coeff(2), 9.0);
614 assert_eq!(c.coeff(3), 0.0);
615 }
616
617 #[test]
618 fn scale_multiplies_all_coefficients() {
619 let j = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
620 let s = j.scale(2.0);
621 assert_eq!(s.coeff(0), 2.0);
622 assert_eq!(s.coeff(1), 4.0);
623 assert_eq!(s.coeff(2), 6.0);
624 assert_eq!(s.coeff(3), 0.0);
625 }
626
627 #[test]
628 fn sub_is_elementwise_difference() {
629 let a = MultiDirJet::constant(2, 5.0);
630 let b = MultiDirJet::constant(2, 3.0);
631 let c = a.sub(&b);
632 assert_eq!(c.coeff(0), 2.0);
633 assert_eq!(c.coeff(1), 0.0);
634 assert_eq!(c.coeff(2), 0.0);
635 assert_eq!(c.coeff(3), 0.0);
636 }
637
638 // ── mul (subset-convolution) ──────────────────────────────────────────────
639
640 #[test]
641 fn mul_of_constants_is_scalar_product() {
642 let a = MultiDirJet::constant(2, 2.0);
643 let b = MultiDirJet::constant(2, 3.0);
644 let c = a.mul(&b);
645 assert_eq!(c.coeff(0), 6.0);
646 assert_eq!(c.coeff(1), 0.0);
647 assert_eq!(c.coeff(2), 0.0);
648 assert_eq!(c.coeff(3), 0.0);
649 }
650
651 #[test]
652 fn mul_satisfies_leibniz_rule_single_direction() {
653 // (1 + ε) * (1 + ε) = 1 + 2ε
654 let x = MultiDirJet::linear(1, 1.0, &[1.0]);
655 let y = MultiDirJet::linear(1, 1.0, &[1.0]);
656 let z = x.mul(&y);
657 assert_eq!(z.coeff(0), 1.0);
658 assert_eq!(z.coeff(1), 2.0);
659 }
660
661 #[test]
662 fn mul_cross_term_two_independent_directions() {
663 // (1 + ε₁)(1 + ε₂) = 1 + ε₁ + ε₂ + ε₁ε₂
664 let x = MultiDirJet::linear(2, 1.0, &[1.0, 0.0]);
665 let y = MultiDirJet::linear(2, 1.0, &[0.0, 1.0]);
666 let z = x.mul(&y);
667 assert_eq!(z.coeff(0), 1.0);
668 assert_eq!(z.coeff(1), 1.0);
669 assert_eq!(z.coeff(2), 1.0);
670 assert_eq!(z.coeff(3), 1.0);
671 }
672
673 // ── compose_unary: truncated-Taylor reassociation ─────────────────────────
674 //
675 // The new `compose_unary` reassociates the per-mask Faà di Bruno set-partition
676 // sum into a degree-4 polynomial in the subset-convolution power of the
677 // non-constant part. These tests are the accuracy gate: a double-double
678 // oracle is the truth, and the new path's error-vs-truth must be `≤` the old
679 // naive partition sum's error-vs-truth on every randomised program.
680
681 /// Deterministic xorshift64* — no `rand` dependency in the test.
682 struct Rng(u64);
683 impl Rng {
684 fn next_u64(&mut self) -> u64 {
685 let mut x = self.0;
686 x ^= x >> 12;
687 x ^= x << 25;
688 x ^= x >> 27;
689 self.0 = x;
690 x.wrapping_mul(0x2545F4914F6CDD1D)
691 }
692 /// Uniform in `[-scale, scale]`.
693 fn signed(&mut self, scale: f64) -> f64 {
694 let u = (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64; // [0,1)
695 (2.0 * u - 1.0) * scale
696 }
697 }
698
699 // ── A double-double oracle for the exact (order-4 truncated) composition ──
700
701 #[inline]
702 fn two_prod(a: f64, b: f64) -> (f64, f64) {
703 let p = a * b;
704 (p, a.mul_add(b, -p))
705 }
706 #[inline]
707 fn dd_two_sum(a: f64, b: f64) -> (f64, f64) {
708 let s = a + b;
709 let bb = s - a;
710 (s, (a - (s - bb)) + (b - bb))
711 }
712 #[derive(Clone, Copy)]
713 struct Dd {
714 hi: f64,
715 lo: f64,
716 }
717 impl Dd {
718 fn from(x: f64) -> Self {
719 Self { hi: x, lo: 0.0 }
720 }
721 fn mul_f64(self, b: f64) -> Self {
722 let (p, e) = two_prod(self.hi, b);
723 let lo = self.lo.mul_add(b, e);
724 let s = p + lo;
725 Self {
726 hi: s,
727 lo: (p - s) + lo,
728 }
729 }
730 fn add(self, o: Self) -> Self {
731 let (s, e) = dd_two_sum(self.hi, o.hi);
732 let (s2, e2) = dd_two_sum(self.lo, o.lo);
733 let lo = e + s2;
734 let h1 = s + lo;
735 let l1 = (s - h1) + lo;
736 let lo2 = l1 + e2;
737 let h = h1 + lo2;
738 Self {
739 hi: h,
740 lo: (h1 - h) + lo2,
741 }
742 }
743 /// `|self - x|` to ~double precision in the residual (Sterbenz: `x` and
744 /// `hi` agree to ~53 bits, so `x - hi` is essentially exact).
745 fn abs_err_to(self, x: f64) -> f64 {
746 ((x - self.hi) - self.lo).abs()
747 }
748 }
749
750 /// High-precision truth for `compose_unary` via the set-partition reference,
751 /// every product and sum carried in double-double.
752 fn compose_truth(coeffs: &[f64], derivs: [f64; DERIVS]) -> Vec<Dd> {
753 let count = coeffs.len();
754 let n_dirs = count.trailing_zeros() as usize;
755 let tables = partition_tables(n_dirs);
756 let mut out = vec![Dd::from(0.0); count];
757 let mut remap = vec![0usize; count];
758 let mut pos = [0usize; 64];
759 for (mask, slot) in out.iter_mut().enumerate() {
760 if mask == 0 {
761 *slot = Dd::from(derivs[0]);
762 continue;
763 }
764 let mut npos = 0usize;
765 let mut m = mask;
766 while m != 0 {
767 pos[npos] = m.trailing_zeros() as usize;
768 npos += 1;
769 m &= m - 1;
770 }
771 remap[0] = 0;
772 for cb in 1usize..(1usize << npos) {
773 let low = cb.trailing_zeros() as usize;
774 remap[cb] = remap[cb & (cb - 1)] | (1usize << pos[low]);
775 }
776 let table = &tables[npos];
777 let mut total = Dd::from(0.0);
778 for &(off, order) in table.parts.iter() {
779 let order = order as usize;
780 let mut prod = Dd::from(derivs[order]);
781 for &cb in &table.flat[off..off + order] {
782 prod = prod.mul_f64(coeffs[remap[cb as usize]]);
783 }
784 total = total.add(prod);
785 }
786 *slot = total;
787 }
788 out
789 }
790
791 /// Build a random composite jet so the composition input is a realistic
792 /// non-trivial multilinear element (not just seeded directions).
793 fn random_inner(n_dirs: usize, rng: &mut Rng) -> MultiDirJet {
794 let base = rng.signed(0.8);
795 let first: Vec<f64> = (0..n_dirs).map(|_| rng.signed(0.6)).collect();
796 let a = MultiDirJet::linear(n_dirs, base, &first);
797 let b = MultiDirJet::linear(
798 n_dirs,
799 rng.signed(0.7),
800 &(0..n_dirs).map(|_| rng.signed(0.5)).collect::<Vec<_>>(),
801 );
802 // a*b + a populates the full cross-mask spectrum.
803 a.mul(&b).add(&a)
804 }
805
806 #[test]
807 fn compose_unary_matches_partition_reference_simple() {
808 // exp-like stack on a 2-direction cross jet: every coeff agrees with the
809 // direct set-partition reference to a tight tolerance.
810 let j = MultiDirJet::linear(2, 0.3, &[0.5, -0.4])
811 .mul(&MultiDirJet::linear(2, -0.2, &[0.1, 0.7]));
812 let d = [0.9_f64, 1.1, -0.7, 0.4, -0.25];
813 let got = j.compose_unary(d);
814 let want = compose_unary_partition_reference(&j.coeffs, d);
815 for (mask, (&g, &w)) in got.coeffs.iter().zip(want.iter()).enumerate() {
816 let tol = 1e-13 * w.abs().max(1.0);
817 assert!(
818 (g - w).abs() <= tol,
819 "mask {mask}: got={g:.17e} want={w:.17e}"
820 );
821 }
822 }
823
824 #[test]
825 fn compose_unary_accuracy_beats_partition_sum_vs_double_double() {
826 // The accuracy gate. Over many random programs at every K used in
827 // production, the new path's error-vs-truth is never worse than the old
828 // naive partition sum's, and is a strict improvement in aggregate.
829 let mut rng = Rng(0x1234_5678_9abc_def0);
830 let mut sum_new = 0.0f64;
831 let mut sum_old = 0.0f64;
832 for &n_dirs in &[2usize, 3, 4, 6, 8] {
833 for _ in 0..200 {
834 let inner = random_inner(n_dirs, &mut rng);
835 let d = [
836 rng.signed(1.5),
837 rng.signed(1.5),
838 rng.signed(2.0),
839 rng.signed(3.0),
840 rng.signed(4.0),
841 ];
842 let new = inner.compose_unary(d);
843 let old = compose_unary_partition_reference(&inner.coeffs, d);
844 let truth = compose_truth(&inner.coeffs, d);
845 for mask in 0..inner.coeffs.len() {
846 let en = truth[mask].abs_err_to(new.coeffs[mask]);
847 let eo = truth[mask].abs_err_to(old[mask]);
848 sum_new += en;
849 sum_old += eo;
850 // Per-coefficient: new is never materially worse. The 4 ULP
851 // slack absorbs the rare tie where a differently-grouped but
852 // equally-valid rounding lands one ULP either way.
853 let scale = truth[mask].hi.abs().max(1.0);
854 assert!(
855 en <= eo + 4.0 * f64::EPSILON * scale,
856 "K={n_dirs} mask={mask}: new_err={en:.3e} old_err={eo:.3e}"
857 );
858 }
859 }
860 }
861 // Aggregate: the compensated reassociation is a real improvement.
862 assert!(
863 sum_new <= sum_old,
864 "aggregate error regressed: new={sum_new:.6e} old={sum_old:.6e}"
865 );
866 eprintln!(
867 "compose_unary accuracy: total |err| new={sum_new:.6e} old={sum_old:.6e} \
868 (improvement {:.2}x)",
869 sum_old / sum_new.max(f64::MIN_POSITIVE)
870 );
871 }
872
873 #[test]
874 fn compose_unary_speedup_over_partition_sum() {
875 // Measure ns/call new vs. the previous partition-sum implementation
876 // across the production K range. Prints the multiple; asserts a
877 // conservative floor so CI noise can't make it flaky.
878 use std::time::Instant;
879 let mut rng = Rng(0xfeed_face_dead_beef);
880 for &n_dirs in &[2usize, 4, 6, 8] {
881 let n_inputs = 256usize;
882 let inputs: Vec<(MultiDirJet, [f64; DERIVS])> = (0..n_inputs)
883 .map(|_| {
884 (
885 random_inner(n_dirs, &mut rng),
886 [
887 rng.signed(1.5),
888 rng.signed(1.5),
889 rng.signed(2.0),
890 rng.signed(3.0),
891 rng.signed(4.0),
892 ],
893 )
894 })
895 .collect();
896 let iters = 200usize;
897 // Warm the scratch / partition tables.
898 for (j, d) in &inputs {
899 std::hint::black_box(j.compose_unary(*d));
900 std::hint::black_box(compose_unary_partition_reference(&j.coeffs, *d));
901 }
902 let t0 = Instant::now();
903 for _ in 0..iters {
904 for (j, d) in &inputs {
905 std::hint::black_box(j.compose_unary(*d));
906 }
907 }
908 let new_ns = t0.elapsed().as_nanos() as f64 / (iters * inputs.len()) as f64;
909 let t1 = Instant::now();
910 for _ in 0..iters {
911 for (j, d) in &inputs {
912 std::hint::black_box(compose_unary_partition_reference(&j.coeffs, *d));
913 }
914 }
915 let old_ns = t1.elapsed().as_nanos() as f64 / (iters * inputs.len()) as f64;
916 eprintln!(
917 "compose_unary K={n_dirs}: new={new_ns:.1} ns/call old={old_ns:.1} ns/call \
918 speedup={:.2}x",
919 old_ns / new_ns
920 );
921 // Guard only where the algorithmic win is robust: an optimised build
922 // at the production-dominant K (the partition sum's `Σ_π |π|` work
923 // grows steeply with K, while the new path is three convolutions).
924 // Debug builds and tiny K are dominated by fixed per-call overhead
925 // and the ratio there is not a meaningful guard, so it is printed
926 // but not asserted (and timing asserts must not flake on CI).
927 if !cfg!(debug_assertions) && n_dirs >= 6 {
928 assert!(
929 new_ns < old_ns,
930 "K={n_dirs} new path slower: new={new_ns:.1}ns old={old_ns:.1}ns"
931 );
932 }
933 }
934 }
935}