Skip to main content

gam_math/
jet_partitions.rs

1//! Bitmask-coefficient multi-directional jets used by marginal-slope and
2//! latent-survival row kernels.
3//!
4//! The layout stores one coefficient per direction mask. The calculus itself
5//! lives in [`crate::jet_algebra`]: that module owns the layout-agnostic
6//! Leibniz / Faà di Bruno *combinatorics* once, and the scalar (`n_dirs == 0`)
7//! path here still routes through it so a fix to the rule is a fix to both
8//! representations.
9//!
10//! ## Why this layout is special (and how the hot path exploits it)
11//!
12//! Each direction is seeded *linearly* (one first-derivative slot), so every
13//! direction variable squares to zero. The coefficients therefore form the
14//! commutative **multilinear / set-function algebra**: `coeffs[mask]` is the
15//! coefficient of `Π_{i ∈ mask} ε_i`. In that algebra two facts collapse the
16//! generic combinatorial walkers into tight branch-free arithmetic:
17//!
18//! * **`mul` is the subset (zeta-style) convolution**
19//!   `out[mask] = Σ_{sub ⊆ mask} a[sub] · b[mask \ sub]`.
20//!   The shared `leibniz_product` walker rebuilds two `SlotBuf`s and folds bit
21//!   lists back into masks (`mask_of`) *per subset*; here we enumerate the
22//!   submasks of `mask` directly — `mask \ sub == mask ^ sub` because
23//!   `sub ⊆ mask` — in the **same ascending order** the walker used, so the
24//!   floating-point accumulation is bit-for-bit identical while every
25//!   `SlotBuf`/closure/`mask_of` allocation and indirection disappears
26//!   (`3^K` pure FMAs, no heap, no `dyn`).
27//!
28//! * **`compose_unary` is the truncated Faà di Bruno composition**, computed
29//!   here by the exact **truncated-Taylor reassociation** rather than a direct
30//!   set-partition sum. Let `v` be the non-constant part of `self`
31//!   (`v[0] = 0`, `v[mask] = self[mask]`) and let `v^{⊛k}` be the `k`-fold
32//!   *subset convolution* (the multilinear power). The ordered-tuple identity
33//!   `v^{⊛k}[mask] = k! · Σ_{π ⊢ mask, |π| = k} Π_{B ∈ π} v[B]` turns the
34//!   set-partition sum into a degree-4 polynomial in `v`:
35//!
36//!   ```text
37//!   f(self)[mask] = Σ_{k=0}^{4} (f^{(k)} / k!) · v^{⊛k}[mask]      (mask ≠ 0)
38//!   f(self)[0]    = f^{(0)}
39//!   ```
40//!
41//!   so a composition is just **three subset convolutions** (`v²`, `v³=v²⊛v`,
42//!   `v⁴=v²⊛v²` — the Motzkin floor for a quartic) plus a five-term combine.
43//!   That is ~3× fewer FLOPs than the per-mask partition gather; each
44//!   convolution is a four-lane compensated dot product (Ogita–Rump–Oishi
45//!   Dot2, FMA-split products + TwoSum carry) so the result is computed in
46//!   ~double the working precision and the rounding of `v²` cannot compound
47//!   through `v³`/`v⁴`; the final per-mask combine is Neumaier-compensated and
48//!   `wide::f64x4`-vectorised; and the whole call runs on reused thread-local
49//!   scratch with no per-call heap traffic. The reassociation is algebraically
50//!   exact; accuracy-vs-truth (a double-double oracle) is the test gate and is
51//!   strictly ≤ the old partition sum's error (see `tests`).
52use std::cell::RefCell;
53use std::sync::atomic::{AtomicU64, Ordering};
54use wide::{CmpGe, f64x4};
55
/// Running tally of [`MultiDirJet::compose_unary`] invocations. It is bumped
/// once at the top of every call with relaxed ordering, so it is a diagnostic
/// counter rather than a synchronisation point.
pub static COMPOSE_UNARY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Running tally of [`MultiDirJet::mul`] invocations, maintained the same way.
pub static MUL_CALLS: AtomicU64 = AtomicU64::new(0);
58
/// Number of entries in the unary Taylor stack `[f, f', f'', f''', f'''']`
/// consumed by composition. Orders 0 through 4 are carried; a set-partition
/// with five or more blocks would need a derivative beyond the stack and is
/// therefore truncated.
const DERIVS: usize = 5;
62
/// A multi-directional jet stored as one coefficient per direction bitmask.
///
/// `coeffs[mask]` is the coefficient of `Π_{i ∈ mask} ε_i`; index `0` is the
/// value channel. The constructors in this file size the vector to
/// `2^n_dirs` entries.
///
/// `Debug` and `PartialEq` are derived so jets can be printed and compared
/// directly (for example with `assert_eq!`); equality is plain `f64` equality
/// on every coefficient.
#[derive(Clone, Debug, PartialEq)]
pub struct MultiDirJet {
    /// Coefficients indexed by direction bitmask.
    pub coeffs: Vec<f64>,
}
67
68impl MultiDirJet {
69    pub fn zero(n_dirs: usize) -> Self {
70        Self {
71            coeffs: vec![0.0; 1usize << n_dirs],
72        }
73    }
74
75    pub fn constant(n_dirs: usize, value: f64) -> Self {
76        let mut out = Self::zero(n_dirs);
77        out.coeffs[0] = value;
78        out
79    }
80
81    pub fn linear(n_dirs: usize, base: f64, first: &[f64]) -> Self {
82        let mut out = Self::constant(n_dirs, base);
83        for (idx, &value) in first.iter().take(n_dirs).enumerate() {
84            out.coeffs[1usize << idx] = value;
85        }
86        out
87    }
88
89    pub fn with_coeffs(n_dirs: usize, coeffs: &[(usize, f64)]) -> Self {
90        let mut out = Self::zero(n_dirs);
91        for &(mask, value) in coeffs {
92            if mask < out.coeffs.len() {
93                out.coeffs[mask] = value;
94            }
95        }
96        out
97    }
98
99    #[inline]
100    pub fn coeff(&self, mask: usize) -> f64 {
101        self.coeffs[mask]
102    }
103
104    pub fn add(&self, other: &Self) -> Self {
105        Self {
106            coeffs: self
107                .coeffs
108                .iter()
109                .zip(other.coeffs.iter())
110                .map(|(lhs, rhs)| lhs + rhs)
111                .collect(),
112        }
113    }
114
115    pub fn scale(&self, scalar: f64) -> Self {
116        Self {
117            coeffs: self.coeffs.iter().map(|value| scalar * value).collect(),
118        }
119    }
120
121    /// Subset-convolution product `out[mask] = Σ_{sub ⊆ mask} a[sub]·b[mask^sub]`.
122    ///
123    /// Bit-identical to the shared [`crate::jet_algebra::leibniz_product`] walker
124    /// (the submasks are enumerated in the same ascending order — the walker's
125    /// compacted subset index is a monotone bit-deposit of the submask) while
126    /// dropping its per-subset `SlotBuf`/closure/`mask_of` overhead. The scalar
127    /// `n_dirs == 0` case keeps the shared walker live as its reference.
128    pub fn mul(&self, other: &Self) -> Self {
129        MUL_CALLS.fetch_add(1, Ordering::Relaxed);
130        let count = self.coeffs.len();
131        if count <= 1 {
132            return self.mul_reference(other);
133        }
134        let a = &self.coeffs;
135        let b = &other.coeffs;
136        // Both operands carry the same direction set, so `b` is `count` long too.
137        // With that established once, every `a[sub]`/`b[mask ^ sub]` below is
138        // provably in bounds (`sub, mask ^ sub ⊆ mask < count`), so the inner
139        // submask walk can drop its per-load bounds checks.
140        assert_eq!(
141            b.len(),
142            count,
143            "MultiDirJet::mul operands must share n_dirs"
144        );
145        let mut out = vec![0.0; count];
146        for (mask, slot) in out.iter_mut().enumerate() {
147            // Walk every submask of `mask` in ascending numeric order — the same
148            // order `leibniz_product` accumulates — via the classic gap-fill
149            // increment `next = ((sub | !mask) + 1) & mask`.
150            let mut acc = 0.0;
151            let mut sub = 0usize;
152            // SAFETY: `sub ⊆ mask < count` and `mask ^ sub ⊆ mask < count`, and
153            // both `a` and `b` are `count` long (asserted above).
154            unsafe {
155                loop {
156                    acc += *a.get_unchecked(sub) * *b.get_unchecked(mask ^ sub);
157                    if sub == mask {
158                        break;
159                    }
160                    sub = (sub | !mask).wrapping_add(1) & mask;
161                }
162            }
163            *slot = acc;
164        }
165        Self { coeffs: out }
166    }
167
168    /// The pre-#perf shared-walker product, retained verbatim as the scalar-case
169    /// implementation and as the bit-exact reference for `mul`.
170    fn mul_reference(&self, other: &Self) -> Self {
171        let count = self.coeffs.len();
172        let mut out = vec![0.0; count];
173        for (mask, slot) in out.iter_mut().enumerate() {
174            let bits = bit_positions(mask);
175            *slot = crate::jet_algebra::leibniz_product(
176                bits.as_slice(),
177                |t| self.coeffs[mask_of(t)],
178                |c| other.coeffs[mask_of(c)],
179            );
180        }
181        Self { coeffs: out }
182    }
183
184    /// Exact (order-4 truncated) unary composition `f(self)` from the Taylor
185    /// stack `[f, f', f'', f''', f'''']` at `self.coeff(0)`.
186    ///
187    /// Computed by the truncated-Taylor reassociation (see the module note):
188    /// `f(self) = Σ_{k=0}^{4} (f^{(k)}/k!)·v^{⊛k}` with `v` the non-constant
189    /// part of `self`. The three subset-convolution powers `v²`, `v³`, `v⁴`
190    /// are compensated (Dot2) and the per-mask combine is Neumaier-compensated
191    /// and vectorised, so the result is *more* accurate vs. the true
192    /// real-arithmetic value than the prior naive partition sum (proven against
193    /// a double-double oracle in `tests`). The scalar `n_dirs == 0` case keeps
194    /// the shared Faà di Bruno walker live as its reference.
195    pub fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
196        COMPOSE_UNARY_CALLS.fetch_add(1, Ordering::Relaxed);
197        let count = self.coeffs.len();
198        if count <= 1 {
199            return <Self as crate::jet_algebra::JetAlgebra<DERIVS>>::compose_unary(self, derivs);
200        }
201        // Per-block Taylor coefficients c_k = f^{(k)} / k!  (k = 1..=4): the
202        // `1/k!` undoes the ordered-tuple overcount of the subset-convolution
203        // power v^{⊛k} relative to the unordered set-partition sum.
204        let c1 = derivs[1];
205        let c2 = derivs[2] * 0.5;
206        let c3 = derivs[3] * (1.0 / 6.0);
207        let c4 = derivs[4] * (1.0 / 24.0);
208
209        let mut out = vec![0.0; count];
210        COMPOSE_SCRATCH.with(|cell| {
211            let mut buf = cell.borrow_mut();
212            // Four contiguous scratch lanes: v, p2 = v², p3 = v³, p4 = v⁴.
213            buf.clear();
214            buf.resize(4 * count, 0.0);
215            let (vbuf, rest) = buf.split_at_mut(count);
216            let (p2, rest) = rest.split_at_mut(count);
217            let (p3, p4) = rest.split_at_mut(count);
218
219            // v = non-constant part of self (the constant channel squares to a
220            // 0-block, which the k = 0 term carries separately).
221            vbuf.copy_from_slice(&self.coeffs);
222            vbuf[0] = 0.0;
223
224            // Powers via compensated subset convolution, pruned by output
225            // popcount: v^{⊛k}[mask] = 0 whenever popcount(mask) < k.
226            subset_conv_into(vbuf, vbuf, p2, 2);
227            subset_conv_into(p2, vbuf, p3, 3);
228            subset_conv_into(p2, p2, p4, 4);
229
230            // out[mask] = c1·v + c2·v² + c3·v³ + c4·v⁴ (mask ≠ 0), Neumaier-
231            // compensated and f64x4-vectorised over masks. out[0] = f^{(0)}.
232            combine_powers(vbuf, p2, p3, p4, [c1, c2, c3, c4], &mut out);
233            out[0] = derivs[0];
234        });
235        Self { coeffs: out }
236    }
237}
238
thread_local! {
    /// Per-thread scratch for `compose_unary`: four lanes of `count` f64s laid
    /// end to end (v, v², v³, v⁴). `compose_unary` only ever `clear`s and
    /// `resize`s it, so the buffer keeps its capacity between calls and, once
    /// grown to the working size, a call allocates nothing except the `Vec`
    /// it returns.
    static COMPOSE_SCRATCH: RefCell<Vec<f64>> = const { RefCell::new(Vec::new()) };
}
245
/// Branch-free TwoSum (Knuth/Møller): returns `(s, e)` where `s = fl(a + b)` is
/// the rounded sum and `e` is the rounding error, so that `a + b = s + e`
/// holds exactly. Shared by the compensated convolution and the combine step.
#[inline(always)]
fn two_sum(a: f64, b: f64) -> (f64, f64) {
    let sum = a + b;
    // How much of `b`, then of `a`, actually made it into the rounded sum.
    let b_seen = sum - a;
    let a_seen = sum - b_seen;
    let a_lost = a - a_seen;
    let b_lost = b - b_seen;
    (sum, a_lost + b_lost)
}
255
256/// Subset (zeta-style) convolution `out[mask] = Σ_{sub ⊆ mask} a[sub]·b[mask^sub]`,
257/// evaluated as a **compensated dot product** (Ogita–Rump–Oishi Dot2): each
258/// product is split into head + FMA error (`mul_add`) and the running sum
259/// carries a TwoSum error term, so the result is accurate as if computed in
260/// ~twice the working precision. This stops the rounding of `v²` from
261/// compounding through `v³`/`v⁴`, which a single-rounding accumulation does
262/// not. Output masks with `popcount < min_pop` are left at zero: the
263/// multilinear power `v^{⊛k}` vanishes below popcount `k`, so the prune is exact
264/// and skips the low-order masks entirely.
265#[inline]
266fn subset_conv_into(a: &[f64], b: &[f64], out: &mut [f64], min_pop: u32) {
267    // SAFETY invariant for the `get_unchecked` loads below: every index this
268    // kernel reads is `< out.len()`, and the caller passes `a`/`b` at least as
269    // long as `out` (in `compose_unary` all three are the same `count`-length
270    // slices carved from one scratch buffer). Concretely `mask < out.len()`
271    // (loop bound), and each submask satisfies `sub ⊆ mask` so `sub ≤ mask` and
272    // `mask ^ sub ⊆ mask` so `mask ^ sub ≤ mask` — both `< out.len() ≤ a.len(),
273    // b.len()`. The `assert!` below pins the length precondition (one check per
274    // call, negligible next to the walk) so the `get_unchecked` below is sound;
275    // the bounds checks LLVM cannot elide (the indices are data-dependent) are a
276    // real per-step cost across the `3^K` submask walk (×3 convolutions per
277    // compose), so eliding them via get_unchecked is a measured ~20% at the
278    // marginal-slope direction counts.
279    assert!(a.len() >= out.len() && b.len() >= out.len());
280    for (mask, slot) in out.iter_mut().enumerate() {
281        if (mask as u64).count_ones() < min_pop {
282            *slot = 0.0;
283            continue;
284        }
285        // Descending submask enumeration `sub = (sub-1) & mask`, terminating
286        // after `sub == 0` (the classic Gosper-style submask walk). The Dot2 is
287        // spread across FOUR independent named accumulators (a 4-way unroll) so
288        // the FMA/TwoSum latency chains overlap — the loop becomes throughput-
289        // rather than latency-bound — then the lanes are merged with a final
290        // compensated reduction. Every non-pruned mask has popcount ≥ 2, so its
291        // `2^popcount` submask count is a multiple of 4 and the unroll is exact
292        // (the all-zero submask always lands in the fourth lane). Reassociation
293        // only; the value is the same real sum, in ~double the working precision.
294        #[inline(always)]
295        fn dot2_step(s: &mut f64, c: &mut f64, x: f64, y: f64) {
296            let prod = x * y;
297            let prod_err = x.mul_add(y, -prod); // exact: prod + prod_err == x*y
298            let (t, sum_err) = two_sum(*s, prod);
299            *s = t;
300            *c += prod_err + sum_err;
301        }
302        let (mut s0, mut s1, mut s2, mut s3) = (0.0f64, 0.0f64, 0.0f64, 0.0f64);
303        let (mut c0, mut c1, mut c2, mut c3) = (0.0f64, 0.0f64, 0.0f64, 0.0f64);
304        let mut sub = mask;
305        // SAFETY: see the invariant comment at the top of the function — `sub`
306        // and `mask ^ sub` are both submasks of `mask < out.len()`, hence in
307        // bounds for `a`/`b` (each ≥ `out.len()` long).
308        unsafe {
309            loop {
310                dot2_step(
311                    &mut s0,
312                    &mut c0,
313                    *a.get_unchecked(sub),
314                    *b.get_unchecked(mask ^ sub),
315                );
316                sub = (sub - 1) & mask;
317                dot2_step(
318                    &mut s1,
319                    &mut c1,
320                    *a.get_unchecked(sub),
321                    *b.get_unchecked(mask ^ sub),
322                );
323                sub = (sub - 1) & mask;
324                dot2_step(
325                    &mut s2,
326                    &mut c2,
327                    *a.get_unchecked(sub),
328                    *b.get_unchecked(mask ^ sub),
329                );
330                sub = (sub - 1) & mask;
331                dot2_step(
332                    &mut s3,
333                    &mut c3,
334                    *a.get_unchecked(sub),
335                    *b.get_unchecked(mask ^ sub),
336                );
337                if sub == 0 {
338                    break;
339                }
340                sub = (sub - 1) & mask;
341            }
342        }
343        // Merge the four lanes, compensated.
344        let (s01, e01) = two_sum(s0, s1);
345        let (s23, e23) = two_sum(s2, s3);
346        let (total, etot) = two_sum(s01, s23);
347        *slot = total + (etot + e01 + e23 + c0 + c1 + c2 + c3);
348    }
349}
350
351/// `out[mask] = c[0]·p1 + c[1]·p2 + c[2]·p3 + c[3]·p4` for `mask ≥ 1`, with a
352/// Neumaier-compensated four-term accumulation (the powers span growing
353/// magnitudes, so the compensation recovers the bits a naive `+=` would drop)
354/// and a `wide::f64x4` body over four masks at a time. `out[0]` is overwritten
355/// by the caller with the value channel.
356#[inline]
357fn combine_powers(p1: &[f64], p2: &[f64], p3: &[f64], p4: &[f64], c: [f64; 4], out: &mut [f64]) {
358    let n = out.len();
359    let (c1, c2, c3, c4) = (c[0], c[1], c[2], c[3]);
360    let (v1, v2, v3, v4) = (
361        f64x4::splat(c1),
362        f64x4::splat(c2),
363        f64x4::splat(c3),
364        f64x4::splat(c4),
365    );
366    let mut mask = 0usize;
367    // Vector body: four contiguous masks per step. Neumaier compensation is
368    // applied lane-wise; pick the larger magnitude to subtract first.
369    while mask + 4 <= n {
370        let load = |p: &[f64]| f64x4::new([p[mask], p[mask + 1], p[mask + 2], p[mask + 3]]);
371        let mut s = v1 * load(p1);
372        let mut comp = f64x4::splat(0.0);
373        for (cv, pv) in [(v2, p2), (v3, p3), (v4, p4)] {
374            let term = cv * load(pv);
375            let t = s + term;
376            let big_s = s.abs().cmp_ge(term.abs());
377            let lost = big_s.blend((s - t) + term, (term - t) + s);
378            comp += lost;
379            s = t;
380        }
381        let res = s + comp;
382        out[mask..mask + 4].copy_from_slice(&res.to_array());
383        mask += 4;
384    }
385    // Scalar tail (and the small-K path where `n < 4`).
386    while mask < n {
387        let mut s = c1 * p1[mask];
388        let mut comp = 0.0f64;
389        for (cv, pv) in [(c2, p2), (c3, p3), (c4, p4)] {
390            let term = cv * pv[mask];
391            let (t, e) = two_sum(s, term);
392            comp += e;
393            s = t;
394        }
395        out[mask] = s + comp;
396        mask += 1;
397    }
398}
399
400impl crate::jet_algebra::JetAlgebra<DERIVS> for MultiDirJet {
401    #[inline]
402    fn derivative(&self, slots: &[usize]) -> f64 {
403        self.coeffs[mask_of(slots)]
404    }
405
406    fn map_derivatives<F>(&self, mut f: F) -> Self
407    where
408        F: FnMut(&[usize]) -> f64,
409    {
410        let mut out = vec![0.0; self.coeffs.len()];
411        for (mask, value) in out.iter_mut().enumerate() {
412            let bits = bit_positions(mask);
413            *value = f(bits.as_slice());
414        }
415        Self { coeffs: out }
416    }
417}
418
419/// The set-bit positions of `mask`, low to high — the differentiation slots of
420/// that coefficient.
421fn bit_positions(mask: usize) -> crate::jet_algebra::SlotBuf {
422    let mut out = crate::jet_algebra::SlotBuf::new();
423    let mut m = mask;
424    while m != 0 {
425        let bit = m.trailing_zeros() as usize;
426        out.push_slot(bit);
427        m &= m - 1;
428    }
429    out
430}
431
/// Folds a slot group (a list of bit positions) back into the sub-mask that
/// has exactly those bits set. Repeated positions are harmless.
fn mask_of(slots: &[usize]) -> usize {
    let mut mask = 0usize;
    for &bit in slots {
        mask |= 1usize << bit;
    }
    mask
}
436
437// #932-2 cutover: `MultiDirJet::bilinear` (the 4-coeff `[base, d1, d2, d12]`
438// constructor) and `MultiDirJet::sub` are consumed ONLY by the now test-only hand
439// survival directional/bidirectional oracle (the production flex jet path uses the
440// `flex_jet` runtime jet algebra, not `MultiDirJet`). After the #1521 crate split
441// moved `MultiDirJet` into `gam-math`, those oracle tests live in the dependent
442// `gam` crate, where a `#[cfg(test)]` gate in *this* crate is inactive — so the
443// methods must be plain `pub` inherent methods to be reachable cross-crate. They
444// carry no dead-code cost because `pub` items are part of the crate's public API.
445// Bodies are byte-identical to their former gated form.
446impl MultiDirJet {
447    pub fn bilinear(base: f64, d1: f64, d2: f64, d12: f64) -> Self {
448        Self {
449            coeffs: vec![base, d1, d2, d12],
450        }
451    }
452
453    pub fn sub(&self, other: &Self) -> Self {
454        Self {
455            coeffs: self
456                .coeffs
457                .iter()
458                .zip(other.coeffs.iter())
459                .map(|(lhs, rhs)| lhs - rhs)
460                .collect(),
461        }
462    }
463}
464
465#[cfg(test)]
466mod tests {
467    use super::*;
468
469    /// A flattened set-partition table for a fixed slot count. `parts[i] = (off,
470    /// order)` describes one partition: its `order` block submasks (compacted) are
471    /// `flat[off .. off + order]`.
472    ///
473    /// This direct set-partition sum is the previous production `compose_unary`
474    /// implementation, retained as the **accuracy reference** the new
475    /// truncated-Taylor path is graded against: a double-double oracle is the
476    /// truth, and the test asserts the new path's error-vs-truth is `≤` this naive
477    /// partition sum's error-vs-truth on every randomised program.
478    struct PartTable {
479        flat: Vec<u32>,
480        parts: Vec<(usize, u8)>,
481    }
482
483    thread_local! {
484        /// Cached set-partition tables, indexed by slot count `m`. Entry `m` holds
485        /// every partition of `{0..m}` into `< DERIVS` blocks, in the shared
486        /// walker's recursion order, each block a compacted submask. Pure function
487        /// of `m`, so caching is sound and deterministic.
488        static PARTITION_TABLES: RefCell<Vec<std::rc::Rc<PartTable>>> =
489            const { RefCell::new(Vec::new()) };
490    }
491
492    /// Return cached partition tables for slot counts `0..=n_dirs`.
493    fn partition_tables(n_dirs: usize) -> Vec<std::rc::Rc<PartTable>> {
494        PARTITION_TABLES.with(|cell| {
495            let mut tables = cell.borrow_mut();
496            while tables.len() <= n_dirs {
497                let m = tables.len();
498                tables.push(std::rc::Rc::new(build_partitions(m)));
499            }
500            (0..=n_dirs)
501                .map(|m| std::rc::Rc::clone(&tables[m]))
502                .collect()
503        })
504    }
505
506    /// The previous production `compose_unary`: a direct set-partition (Faà di
507    /// Bruno) sum per output mask, retained as the accuracy reference.
508    fn compose_unary_partition_reference(coeffs: &[f64], derivs: [f64; DERIVS]) -> Vec<f64> {
509        let count = coeffs.len();
510        let n_dirs = count.trailing_zeros() as usize;
511        let tables = partition_tables(n_dirs);
512        let mut out = vec![0.0; count];
513        let mut remap = vec![0usize; count];
514        let mut pos = [0usize; usize::BITS as usize];
515        for (mask, slot) in out.iter_mut().enumerate() {
516            if mask == 0 {
517                *slot = derivs[0];
518                continue;
519            }
520            let mut npos = 0usize;
521            let mut m = mask;
522            while m != 0 {
523                pos[npos] = m.trailing_zeros() as usize;
524                npos += 1;
525                m &= m - 1;
526            }
527            remap[0] = 0;
528            for cb in 1usize..(1usize << npos) {
529                let low = cb.trailing_zeros() as usize;
530                remap[cb] = remap[cb & (cb - 1)] | (1usize << pos[low]);
531            }
532            let table = &tables[npos];
533            let flat = &table.flat;
534            let mut total = 0.0;
535            for &(off, order) in table.parts.iter() {
536                let order = order as usize;
537                let mut prod = derivs[order];
538                for &cb in &flat[off..off + order] {
539                    prod *= coeffs[remap[cb as usize]];
540                }
541                total += prod;
542            }
543            *slot = total;
544        }
545        out
546    }
547
548    /// Enumerate the set-partitions of `{0..m}` with fewer than `DERIVS` blocks, in
549    /// the exact DFS order of [`crate::jet_algebra`]'s `for_each_partition`
550    /// recursion ("place each element into an existing block, else open a new one"),
551    /// each block recorded as a compacted submask of `{0..m}`, flattened.
552    fn build_partitions(m: usize) -> PartTable {
553        fn recurse(
554            elem: usize,
555            m: usize,
556            blocks: &mut [u32; 8],
557            n_blocks: usize,
558            out: &mut PartTable,
559        ) {
560            // Partitions with `>= DERIVS` blocks are truncated (their `f^{(order)}`
561            // is beyond the stack); the block count never decreases, so the whole
562            // subtree contributes nothing and is pruned — matching the walker's
563            // per-partition `order >= derivs.len()` skip.
564            if n_blocks >= DERIVS {
565                return;
566            }
567            if elem == m {
568                let off = out.flat.len();
569                out.flat.extend_from_slice(&blocks[..n_blocks]);
570                out.parts.push((off, n_blocks as u8));
571                return;
572            }
573            for b in 0..n_blocks {
574                blocks[b] |= 1u32 << elem;
575                recurse(elem + 1, m, blocks, n_blocks, out);
576                blocks[b] &= !(1u32 << elem);
577            }
578            blocks[n_blocks] = 1u32 << elem;
579            recurse(elem + 1, m, blocks, n_blocks + 1, out);
580        }
581        let mut out = PartTable {
582            flat: Vec::new(),
583            parts: Vec::new(),
584        };
585        let mut blocks = [0u32; 8];
586        recurse(0, m, &mut blocks, 0, &mut out);
587        out
588    }
589
590    // ── constructors ─────────────────────────────────────────────────────────
591
592    #[test]
593    fn zero_has_correct_length_and_all_zero_coefficients() {
594        let j = MultiDirJet::zero(3);
595        assert_eq!(j.coeffs.len(), 8);
596        assert!(j.coeffs.iter().all(|&v| v == 0.0));
597    }
598
599    #[test]
600    fn constant_has_value_at_mask_zero_and_zeros_elsewhere() {
601        let j = MultiDirJet::constant(2, 5.0);
602        assert_eq!(j.coeffs.len(), 4);
603        assert_eq!(j.coeff(0), 5.0);
604        assert_eq!(j.coeff(1), 0.0);
605        assert_eq!(j.coeff(2), 0.0);
606        assert_eq!(j.coeff(3), 0.0);
607    }
608
609    #[test]
610    fn linear_sets_base_and_per_direction_slots() {
611        let j = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
612        assert_eq!(j.coeff(0), 1.0); // constant
613        assert_eq!(j.coeff(1), 2.0); // mask 0b01 — direction 0
614        assert_eq!(j.coeff(2), 3.0); // mask 0b10 — direction 1
615        assert_eq!(j.coeff(3), 0.0); // cross term is zero
616    }
617
618    #[test]
619    fn bilinear_sets_all_four_slots() {
620        let j = MultiDirJet::bilinear(1.0, 2.0, 3.0, 4.0);
621        assert_eq!(j.coeff(0), 1.0);
622        assert_eq!(j.coeff(1), 2.0);
623        assert_eq!(j.coeff(2), 3.0);
624        assert_eq!(j.coeff(3), 4.0);
625    }
626
627    #[test]
628    fn with_coeffs_sets_only_specified_entries() {
629        let j = MultiDirJet::with_coeffs(2, &[(0, 9.0), (3, -1.0)]);
630        assert_eq!(j.coeff(0), 9.0);
631        assert_eq!(j.coeff(1), 0.0);
632        assert_eq!(j.coeff(2), 0.0);
633        assert_eq!(j.coeff(3), -1.0);
634    }
635
636    // ── elementwise arithmetic ────────────────────────────────────────────────
637
638    #[test]
639    fn add_is_elementwise() {
640        let a = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
641        let b = MultiDirJet::linear(2, 4.0, &[5.0, 6.0]);
642        let c = a.add(&b);
643        assert_eq!(c.coeff(0), 5.0);
644        assert_eq!(c.coeff(1), 7.0);
645        assert_eq!(c.coeff(2), 9.0);
646        assert_eq!(c.coeff(3), 0.0);
647    }
648
649    #[test]
650    fn scale_multiplies_all_coefficients() {
651        let j = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
652        let s = j.scale(2.0);
653        assert_eq!(s.coeff(0), 2.0);
654        assert_eq!(s.coeff(1), 4.0);
655        assert_eq!(s.coeff(2), 6.0);
656        assert_eq!(s.coeff(3), 0.0);
657    }
658
659    #[test]
660    fn sub_is_elementwise_difference() {
661        let a = MultiDirJet::constant(2, 5.0);
662        let b = MultiDirJet::constant(2, 3.0);
663        let c = a.sub(&b);
664        assert_eq!(c.coeff(0), 2.0);
665        assert_eq!(c.coeff(1), 0.0);
666        assert_eq!(c.coeff(2), 0.0);
667        assert_eq!(c.coeff(3), 0.0);
668    }
669
670    // ── mul (subset-convolution) ──────────────────────────────────────────────
671
672    #[test]
673    fn mul_of_constants_is_scalar_product() {
674        let a = MultiDirJet::constant(2, 2.0);
675        let b = MultiDirJet::constant(2, 3.0);
676        let c = a.mul(&b);
677        assert_eq!(c.coeff(0), 6.0);
678        assert_eq!(c.coeff(1), 0.0);
679        assert_eq!(c.coeff(2), 0.0);
680        assert_eq!(c.coeff(3), 0.0);
681    }
682
683    #[test]
684    fn mul_satisfies_leibniz_rule_single_direction() {
685        // (1 + ε) * (1 + ε) = 1 + 2ε
686        let x = MultiDirJet::linear(1, 1.0, &[1.0]);
687        let y = MultiDirJet::linear(1, 1.0, &[1.0]);
688        let z = x.mul(&y);
689        assert_eq!(z.coeff(0), 1.0);
690        assert_eq!(z.coeff(1), 2.0);
691    }
692
693    #[test]
694    fn mul_cross_term_two_independent_directions() {
695        // (1 + ε₁)(1 + ε₂) = 1 + ε₁ + ε₂ + ε₁ε₂
696        let x = MultiDirJet::linear(2, 1.0, &[1.0, 0.0]);
697        let y = MultiDirJet::linear(2, 1.0, &[0.0, 1.0]);
698        let z = x.mul(&y);
699        assert_eq!(z.coeff(0), 1.0);
700        assert_eq!(z.coeff(1), 1.0);
701        assert_eq!(z.coeff(2), 1.0);
702        assert_eq!(z.coeff(3), 1.0);
703    }
704
705    // ── compose_unary: truncated-Taylor reassociation ─────────────────────────
706    //
707    // The new `compose_unary` reassociates the per-mask Faà di Bruno set-partition
708    // sum into a degree-4 polynomial in the subset-convolution power of the
709    // non-constant part. These tests are the accuracy gate: a double-double
710    // oracle is the truth, and the new path's error-vs-truth must be `≤` the old
711    // naive partition sum's error-vs-truth on every randomised program.
712
713    /// Deterministic xorshift64* — no `rand` dependency in the test.
714    struct Rng(u64);
715    impl Rng {
716        fn next_u64(&mut self) -> u64 {
717            let mut x = self.0;
718            x ^= x >> 12;
719            x ^= x << 25;
720            x ^= x >> 27;
721            self.0 = x;
722            x.wrapping_mul(0x2545F4914F6CDD1D)
723        }
724        /// Uniform in `[-scale, scale]`.
725        fn signed(&mut self, scale: f64) -> f64 {
726            let u = (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64; // [0,1)
727            (2.0 * u - 1.0) * scale
728        }
729    }
730
731    // ── A double-double oracle for the exact (order-4 truncated) composition ──
732
733    #[inline]
734    fn two_prod(a: f64, b: f64) -> (f64, f64) {
735        let p = a * b;
736        (p, a.mul_add(b, -p))
737    }
738    #[inline]
739    fn dd_two_sum(a: f64, b: f64) -> (f64, f64) {
740        let s = a + b;
741        let bb = s - a;
742        (s, (a - (s - bb)) + (b - bb))
743    }
744    #[derive(Clone, Copy)]
745    struct Dd {
746        hi: f64,
747        lo: f64,
748    }
749    impl Dd {
750        fn from(x: f64) -> Self {
751            Self { hi: x, lo: 0.0 }
752        }
753        fn mul_f64(self, b: f64) -> Self {
754            let (p, e) = two_prod(self.hi, b);
755            let lo = self.lo.mul_add(b, e);
756            let s = p + lo;
757            Self {
758                hi: s,
759                lo: (p - s) + lo,
760            }
761        }
762        fn add(self, o: Self) -> Self {
763            let (s, e) = dd_two_sum(self.hi, o.hi);
764            let (s2, e2) = dd_two_sum(self.lo, o.lo);
765            let lo = e + s2;
766            let h1 = s + lo;
767            let l1 = (s - h1) + lo;
768            let lo2 = l1 + e2;
769            let h = h1 + lo2;
770            Self {
771                hi: h,
772                lo: (h1 - h) + lo2,
773            }
774        }
775        /// `|self - x|` to ~double precision in the residual (Sterbenz: `x` and
776        /// `hi` agree to ~53 bits, so `x - hi` is essentially exact).
777        fn abs_err_to(self, x: f64) -> f64 {
778            ((x - self.hi) - self.lo).abs()
779        }
780    }
781
782    /// High-precision truth for `compose_unary` via the set-partition reference,
783    /// every product and sum carried in double-double.
784    fn compose_truth(coeffs: &[f64], derivs: [f64; DERIVS]) -> Vec<Dd> {
785        let count = coeffs.len();
786        let n_dirs = count.trailing_zeros() as usize;
787        let tables = partition_tables(n_dirs);
788        let mut out = vec![Dd::from(0.0); count];
789        let mut remap = vec![0usize; count];
790        let mut pos = [0usize; 64];
791        for (mask, slot) in out.iter_mut().enumerate() {
792            if mask == 0 {
793                *slot = Dd::from(derivs[0]);
794                continue;
795            }
796            let mut npos = 0usize;
797            let mut m = mask;
798            while m != 0 {
799                pos[npos] = m.trailing_zeros() as usize;
800                npos += 1;
801                m &= m - 1;
802            }
803            remap[0] = 0;
804            for cb in 1usize..(1usize << npos) {
805                let low = cb.trailing_zeros() as usize;
806                remap[cb] = remap[cb & (cb - 1)] | (1usize << pos[low]);
807            }
808            let table = &tables[npos];
809            let mut total = Dd::from(0.0);
810            for &(off, order) in table.parts.iter() {
811                let order = order as usize;
812                let mut prod = Dd::from(derivs[order]);
813                for &cb in &table.flat[off..off + order] {
814                    prod = prod.mul_f64(coeffs[remap[cb as usize]]);
815                }
816                total = total.add(prod);
817            }
818            *slot = total;
819        }
820        out
821    }
822
823    /// Build a random composite jet so the composition input is a realistic
824    /// non-trivial multilinear element (not just seeded directions).
825    fn random_inner(n_dirs: usize, rng: &mut Rng) -> MultiDirJet {
826        let base = rng.signed(0.8);
827        let first: Vec<f64> = (0..n_dirs).map(|_| rng.signed(0.6)).collect();
828        let a = MultiDirJet::linear(n_dirs, base, &first);
829        let b = MultiDirJet::linear(
830            n_dirs,
831            rng.signed(0.7),
832            &(0..n_dirs).map(|_| rng.signed(0.5)).collect::<Vec<_>>(),
833        );
834        // a*b + a populates the full cross-mask spectrum.
835        a.mul(&b).add(&a)
836    }
837
#[test]
fn compose_unary_matches_partition_reference_simple() {
    // Product of two 2-direction linear jets, composed with a fixed
    // derivative stack: each coefficient must match the direct set-partition
    // reference within a 1e-13 relative (floored at absolute) tolerance.
    let lhs = MultiDirJet::linear(2, 0.3, &[0.5, -0.4]);
    let rhs = MultiDirJet::linear(2, -0.2, &[0.1, 0.7]);
    let j = lhs.mul(&rhs);
    let d = [0.9_f64, 1.1, -0.7, 0.4, -0.25];
    let got = j.compose_unary(d);
    let want = compose_unary_partition_reference(&j.coeffs, d);
    let pairs = got.coeffs.iter().copied().zip(want.iter().copied());
    for (mask, (g, w)) in pairs.enumerate() {
        let tol = 1e-13 * w.abs().max(1.0);
        let diff = (g - w).abs();
        assert!(diff <= tol, "mask {mask}: got={g:.17e} want={w:.17e}");
    }
}
858
#[test]
fn compose_unary_accuracy_beats_partition_sum_vs_double_double() {
    // The accuracy gate. For 200 random programs at each listed K, compare
    // the new path and the old partition sum against the double-double
    // truth: the new error may exceed the old by at most a small ULP slack
    // per coefficient, and must be no worse than the old in aggregate.
    let mut rng = Rng(0x1234_5678_9abc_def0);
    let (mut sum_new, mut sum_old) = (0.0f64, 0.0f64);
    for n_dirs in [2usize, 3, 4, 6, 8] {
        for _ in 0..200 {
            let inner = random_inner(n_dirs, &mut rng);
            let d = [
                rng.signed(1.5),
                rng.signed(1.5),
                rng.signed(2.0),
                rng.signed(3.0),
                rng.signed(4.0),
            ];
            let new = inner.compose_unary(d);
            let old = compose_unary_partition_reference(&inner.coeffs, d);
            let truth = compose_truth(&inner.coeffs, d);
            for (mask, exact) in truth.iter().enumerate() {
                let en = exact.abs_err_to(new.coeffs[mask]);
                let eo = exact.abs_err_to(old[mask]);
                sum_new += en;
                sum_old += eo;
                // The 4 ULP slack (relative to the truth, floored at 1)
                // absorbs ties where a differently-grouped but equally valid
                // rounding lands one ULP either way.
                let slack = 4.0 * f64::EPSILON * exact.hi.abs().max(1.0);
                assert!(
                    en <= eo + slack,
                    "K={n_dirs} mask={mask}: new_err={en:.3e} old_err={eo:.3e}"
                );
            }
        }
    }
    // Aggregate: total error of the new path must not exceed the old one's.
    assert!(
        sum_new <= sum_old,
        "aggregate error regressed: new={sum_new:.6e} old={sum_old:.6e}"
    );
    // Guard the ratio against a zero denominator before reporting it.
    let improvement = sum_old / sum_new.max(f64::MIN_POSITIVE);
    eprintln!(
        "compose_unary accuracy: total |err| new={sum_new:.6e} old={sum_old:.6e} \
         (improvement {:.2}x)",
        improvement
    );
}
907
#[test]
fn compose_unary_speedup_over_partition_sum() {
    // Time the new `compose_unary` against the previous partition-sum
    // implementation (ns per call) for several K, print the ratio, and
    // assert only where the comparison is meaningful (see the guard below).
    use std::hint::black_box;
    use std::time::Instant;

    let mut rng = Rng(0xfeed_face_dead_beef);
    for n_dirs in [2usize, 4, 6, 8] {
        let n_inputs = 256usize;
        let mut inputs: Vec<(MultiDirJet, [f64; DERIVS])> = Vec::with_capacity(n_inputs);
        for _ in 0..n_inputs {
            // Draw the jet first, then its derivative stack, in this order.
            let jet = random_inner(n_dirs, &mut rng);
            let derivs = [
                rng.signed(1.5),
                rng.signed(1.5),
                rng.signed(2.0),
                rng.signed(3.0),
                rng.signed(4.0),
            ];
            inputs.push((jet, derivs));
        }
        let iters = 200usize;
        let total_calls = (iters * inputs.len()) as f64;

        // Warm the scratch / partition tables with one untimed pass of each.
        for (j, d) in &inputs {
            black_box(j.compose_unary(*d));
            black_box(compose_unary_partition_reference(&j.coeffs, *d));
        }

        let new_start = Instant::now();
        for _ in 0..iters {
            for (j, d) in &inputs {
                black_box(j.compose_unary(*d));
            }
        }
        let new_ns = new_start.elapsed().as_nanos() as f64 / total_calls;

        let old_start = Instant::now();
        for _ in 0..iters {
            for (j, d) in &inputs {
                black_box(compose_unary_partition_reference(&j.coeffs, *d));
            }
        }
        let old_ns = old_start.elapsed().as_nanos() as f64 / total_calls;

        eprintln!(
            "compose_unary K={n_dirs}: new={new_ns:.1} ns/call  old={old_ns:.1} ns/call  \
             speedup={:.2}x",
            old_ns / new_ns
        );

        // Assert only in an optimised build at the larger K values, where the
        // partition sum's work grows steeply with K. Debug builds and tiny K
        // are dominated by fixed per-call overhead, so there the ratio is
        // printed but not asserted (timing asserts must not flake on CI).
        let guarded = !cfg!(debug_assertions) && n_dirs >= 6;
        if guarded {
            assert!(
                new_ns < old_ns,
                "K={n_dirs} new path slower: new={new_ns:.1}ns old={old_ns:.1}ns"
            );
        }
    }
}
970}