Skip to main content

gam_math/
jet_partitions.rs

//! Bitmask-coefficient multi-directional jets used by marginal-slope and
//! latent-survival row kernels.
//!
//! The layout stores one coefficient per direction mask. The calculus itself
//! lives in [`crate::jet_algebra`]: that module owns the layout-agnostic
//! Leibniz / Faà di Bruno *combinatorics* once, and the scalar (`n_dirs == 0`,
//! i.e. single-coefficient) path here still routes through it so a fix to the
//! rule is a fix to both representations.
//!
//! ## Why this layout is special (and how the hot path exploits it)
//!
//! Each direction is seeded *linearly* (one first-derivative slot), so every
//! direction variable squares to zero. The coefficients therefore form the
//! commutative **multilinear / set-function algebra**: `coeffs[mask]` is the
//! coefficient of `Π_{i ∈ mask} ε_i`. In that algebra two facts collapse the
//! generic combinatorial walkers into tight branch-free arithmetic:
//!
//! * **`mul` is the subset (zeta-style) convolution**
//!   `out[mask] = Σ_{sub ⊆ mask} a[sub] · b[mask \ sub]`.
//!   The shared `leibniz_product` walker rebuilds two `SlotBuf`s and folds bit
//!   lists back into masks (`mask_of`) *per subset*; here we enumerate the
//!   submasks of `mask` directly — `mask \ sub == mask ^ sub` because
//!   `sub ⊆ mask` — in the **same ascending order** the walker used, so the
//!   floating-point accumulation is bit-for-bit identical while every
//!   `SlotBuf`/closure/`mask_of` allocation and indirection disappears
//!   (`3^K` plain multiply-adds, no per-subset heap traffic, no `dyn`).
//!
//! * **`compose_unary` is the truncated set-partition (Faà di Bruno) sum**
//!   `out[mask] = Σ_{π ⊢ mask, |π| < 5} f^{(|π|)} · Π_{B ∈ π} u[B]`.
//!   The shared walker re-runs the partition *recursion* (with `&mut dyn
//!   FnMut` dispatch and fresh `SlotBuf` blocks) once **per output mask**.
//!   The set of partitions of `m` slots depends only on `m`, so we enumerate
//!   them **once** into a thread-local table — emitted in the exact recursion
//!   order, pruned at `|π| >= 5` (the same order-4 truncation) — and the hot
//!   loop is then a flat sum of products with no recursion and no dynamic
//!   dispatch. Same emit order, same block order, same `derivs[order]` factor,
//!   so the result is bit-for-bit identical to the walker.
//!
//! Both fast paths were validated `to_bits`-identical against the shared
//! walkers over thousands of randomised composite programs at `K ∈ {2,3,4,9}`.
41use std::cell::RefCell;
42use std::rc::Rc;
43use std::sync::atomic::{AtomicU64, Ordering};
44
/// Running count of [`MultiDirJet::compose_unary`] invocations. Incremented
/// with a relaxed atomic add at the top of every call, before the scalar
/// fallback check, so both the fast path and the fallback are counted.
pub static COMPOSE_UNARY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Running count of [`MultiDirJet::mul`] invocations. Incremented with a
/// relaxed atomic add at the top of every call, before the scalar fallback
/// check, so both the fast path and the fallback are counted.
pub static MUL_CALLS: AtomicU64 = AtomicU64::new(0);

/// Length of the unary derivative stack `[f, f', f'', f''', f'''']`: composition
/// is exact through order 4, partitions into `>= 5` blocks are truncated.
const DERIVS: usize = 5;
51
/// A multi-directional jet stored as one coefficient per direction mask.
///
/// `coeffs[mask]` is the coefficient of `Π_{i ∈ mask} ε_i`; the constructors
/// in this module always allocate `2^n_dirs` entries, with `coeffs[0]` holding
/// the value channel.
///
/// `Debug` is derived so jets can be printed in diagnostics and test failure
/// messages.
#[derive(Clone, Debug)]
pub struct MultiDirJet {
    /// Coefficients indexed by direction bitmask.
    pub coeffs: Vec<f64>,
}
56
57impl MultiDirJet {
58    pub fn zero(n_dirs: usize) -> Self {
59        Self {
60            coeffs: vec![0.0; 1usize << n_dirs],
61        }
62    }
63
64    pub fn constant(n_dirs: usize, value: f64) -> Self {
65        let mut out = Self::zero(n_dirs);
66        out.coeffs[0] = value;
67        out
68    }
69
70    pub fn linear(n_dirs: usize, base: f64, first: &[f64]) -> Self {
71        let mut out = Self::constant(n_dirs, base);
72        for (idx, &value) in first.iter().take(n_dirs).enumerate() {
73            out.coeffs[1usize << idx] = value;
74        }
75        out
76    }
77
78    pub fn with_coeffs(n_dirs: usize, coeffs: &[(usize, f64)]) -> Self {
79        let mut out = Self::zero(n_dirs);
80        for &(mask, value) in coeffs {
81            if mask < out.coeffs.len() {
82                out.coeffs[mask] = value;
83            }
84        }
85        out
86    }
87
88    #[inline]
89    pub fn coeff(&self, mask: usize) -> f64 {
90        self.coeffs[mask]
91    }
92
93    pub fn add(&self, other: &Self) -> Self {
94        Self {
95            coeffs: self
96                .coeffs
97                .iter()
98                .zip(other.coeffs.iter())
99                .map(|(lhs, rhs)| lhs + rhs)
100                .collect(),
101        }
102    }
103
104    pub fn scale(&self, scalar: f64) -> Self {
105        Self {
106            coeffs: self.coeffs.iter().map(|value| scalar * value).collect(),
107        }
108    }
109
110    /// Subset-convolution product `out[mask] = Σ_{sub ⊆ mask} a[sub]·b[mask^sub]`.
111    ///
112    /// Bit-identical to the shared [`crate::jet_algebra::leibniz_product`] walker
113    /// (the submasks are enumerated in the same ascending order — the walker's
114    /// compacted subset index is a monotone bit-deposit of the submask) while
115    /// dropping its per-subset `SlotBuf`/closure/`mask_of` overhead. The scalar
116    /// `n_dirs == 0` case keeps the shared walker live as its reference.
117    pub fn mul(&self, other: &Self) -> Self {
118        MUL_CALLS.fetch_add(1, Ordering::Relaxed);
119        let count = self.coeffs.len();
120        if count <= 1 {
121            return self.mul_reference(other);
122        }
123        let a = &self.coeffs;
124        let b = &other.coeffs;
125        let mut out = vec![0.0; count];
126        for (mask, slot) in out.iter_mut().enumerate() {
127            // Walk every submask of `mask` in ascending numeric order — the same
128            // order `leibniz_product` accumulates — via the classic gap-fill
129            // increment `next = ((sub | !mask) + 1) & mask`.
130            let mut acc = 0.0;
131            let mut sub = 0usize;
132            loop {
133                acc += a[sub] * b[mask ^ sub];
134                if sub == mask {
135                    break;
136                }
137                sub = (sub | !mask).wrapping_add(1) & mask;
138            }
139            *slot = acc;
140        }
141        Self { coeffs: out }
142    }
143
144    /// The pre-#perf shared-walker product, retained verbatim as the scalar-case
145    /// implementation and as the bit-exact reference for `mul`.
146    fn mul_reference(&self, other: &Self) -> Self {
147        let count = self.coeffs.len();
148        let mut out = vec![0.0; count];
149        for (mask, slot) in out.iter_mut().enumerate() {
150            let bits = bit_positions(mask);
151            *slot = crate::jet_algebra::leibniz_product(
152                bits.as_slice(),
153                |t| self.coeffs[mask_of(t)],
154                |c| other.coeffs[mask_of(c)],
155            );
156        }
157        Self { coeffs: out }
158    }
159
160    /// Exact (order-4 truncated) unary composition `f(self)` from the Taylor
161    /// stack `[f, f', f'', f''', f'''']` at `self.coeff(0)`.
162    ///
163    /// Bit-identical to the shared [`crate::jet_algebra`] Faà di Bruno walker:
164    /// it enumerates the set-partitions of each output mask's slots in the exact
165    /// same recursion order, multiplies `derivs[order]` by the same per-block
166    /// inner coefficients in the same order, and sums them in the same order —
167    /// but the partition enumeration is hoisted out of the per-mask loop into a
168    /// thread-local table built once per slot count. The scalar `n_dirs == 0`
169    /// case keeps the shared walker live as its reference.
170    pub fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
171        COMPOSE_UNARY_CALLS.fetch_add(1, Ordering::Relaxed);
172        let count = self.coeffs.len();
173        if count <= 1 {
174            return <Self as crate::jet_algebra::JetAlgebra<DERIVS>>::compose_unary(self, derivs);
175        }
176        let n_dirs = count.trailing_zeros() as usize;
177        // Partition tables for every slot count present, built once and cached.
178        let tables = partition_tables(n_dirs);
179        let coeffs = &self.coeffs;
180        let mut out = vec![0.0; count];
181        // Per-mask scratch: `remap[cb]` lifts a compacted submask `cb` of the
182        // current mask's slots back to the real coefficient index (the walker's
183        // `mask_of(labelled)`). Filled once per mask and reused across all of
184        // that mask's partitions/blocks, replacing the per-block bit-deposit
185        // loop with a single load. Sized `count` (>= 2^npos for every mask).
186        let mut remap = vec![0usize; count];
187        let mut pos = [0usize; usize::BITS as usize];
188        for (mask, slot) in out.iter_mut().enumerate() {
189            if mask == 0 {
190                // Matches the walker's `m == 0` early return exactly (no `0.0 +`
191                // round-trip, which would differ on a `-0.0` value channel).
192                *slot = derivs[0];
193                continue;
194            }
195            // Set-bit positions of `mask`, ascending — the slot labels.
196            let mut npos = 0usize;
197            let mut m = mask;
198            while m != 0 {
199                pos[npos] = m.trailing_zeros() as usize;
200                npos += 1;
201                m &= m - 1;
202            }
203            // Deposit table: remap[cb] = OR over set bits `i` of cb of 1<<pos[i].
204            // DP over submasks — strip the lowest bit, add its real position.
205            remap[0] = 0;
206            for cb in 1usize..(1usize << npos) {
207                let low = cb.trailing_zeros() as usize;
208                remap[cb] = remap[cb & (cb - 1)] | (1usize << pos[low]);
209            }
210            let table = &tables[npos];
211            let flat = &table.flat;
212            let mut total = 0.0;
213            for &(off, order) in table.parts.iter() {
214                let order = order as usize;
215                let mut prod = derivs[order];
216                for &cb in &flat[off..off + order] {
217                    prod *= coeffs[remap[cb as usize]];
218                }
219                total += prod;
220            }
221            *slot = total;
222        }
223        Self { coeffs: out }
224    }
225}
226
227impl crate::jet_algebra::JetAlgebra<DERIVS> for MultiDirJet {
228    #[inline]
229    fn derivative(&self, slots: &[usize]) -> f64 {
230        self.coeffs[mask_of(slots)]
231    }
232
233    fn map_derivatives<F>(&self, mut f: F) -> Self
234    where
235        F: FnMut(&[usize]) -> f64,
236    {
237        let mut out = vec![0.0; self.coeffs.len()];
238        for (mask, value) in out.iter_mut().enumerate() {
239            let bits = bit_positions(mask);
240            *value = f(bits.as_slice());
241        }
242        Self { coeffs: out }
243    }
244}
245
/// A flattened set-partition table for a fixed slot count. `parts[i] = (off,
/// order)` describes one partition: its `order` block submasks (compacted) are
/// `flat[off .. off + order]`. Flattening keeps the hot composition loop on one
/// contiguous slice instead of chasing per-partition `Vec` pointers.
struct PartTable {
    /// Block submasks of every partition, concatenated in emission order.
    flat: Vec<u32>,
    /// One `(off, order)` entry per partition: the start offset into `flat`
    /// and the number of blocks, which is also the derivative order used.
    parts: Vec<(usize, u8)>,
}
254
thread_local! {
    /// Cached set-partition tables, indexed by slot count `m`. Entry `m` holds
    /// every partition of `{0..m}` into `< DERIVS` blocks, in the shared
    /// walker's recursion order, each block a compacted submask. Pure function
    /// of `m`, so caching is sound and deterministic. Starts empty and is
    /// extended on demand by `partition_tables`.
    static PARTITION_TABLES: RefCell<Vec<Rc<PartTable>>> = const { RefCell::new(Vec::new()) };
}
262
263/// Return cached partition tables for slot counts `0..=n_dirs`.
264fn partition_tables(n_dirs: usize) -> Vec<Rc<PartTable>> {
265    PARTITION_TABLES.with(|cell| {
266        let mut tables = cell.borrow_mut();
267        while tables.len() <= n_dirs {
268            let m = tables.len();
269            tables.push(Rc::new(build_partitions(m)));
270        }
271        (0..=n_dirs).map(|m| Rc::clone(&tables[m])).collect()
272    })
273}
274
275/// Enumerate the set-partitions of `{0..m}` with fewer than `DERIVS` blocks, in
276/// the exact DFS order of [`crate::jet_algebra`]'s `for_each_partition`
277/// recursion ("place each element into an existing block, else open a new one"),
278/// each block recorded as a compacted submask of `{0..m}`, flattened.
279fn build_partitions(m: usize) -> PartTable {
280    fn recurse(elem: usize, m: usize, blocks: &mut [u32; 8], n_blocks: usize, out: &mut PartTable) {
281        // Partitions with `>= DERIVS` blocks are truncated (their `f^{(order)}`
282        // is beyond the stack); the block count never decreases, so the whole
283        // subtree contributes nothing and is pruned — matching the walker's
284        // per-partition `order >= derivs.len()` skip.
285        if n_blocks >= DERIVS {
286            return;
287        }
288        if elem == m {
289            let off = out.flat.len();
290            out.flat.extend_from_slice(&blocks[..n_blocks]);
291            out.parts.push((off, n_blocks as u8));
292            return;
293        }
294        for b in 0..n_blocks {
295            blocks[b] |= 1u32 << elem;
296            recurse(elem + 1, m, blocks, n_blocks, out);
297            blocks[b] &= !(1u32 << elem);
298        }
299        blocks[n_blocks] = 1u32 << elem;
300        recurse(elem + 1, m, blocks, n_blocks + 1, out);
301    }
302    let mut out = PartTable {
303        flat: Vec::new(),
304        parts: Vec::new(),
305    };
306    let mut blocks = [0u32; 8];
307    recurse(0, m, &mut blocks, 0, &mut out);
308    out
309}
310
311/// The set-bit positions of `mask`, low to high — the differentiation slots of
312/// that coefficient.
313fn bit_positions(mask: usize) -> crate::jet_algebra::SlotBuf {
314    let mut out = crate::jet_algebra::SlotBuf::new();
315    let mut m = mask;
316    while m != 0 {
317        let bit = m.trailing_zeros() as usize;
318        out.push_slot(bit);
319        m &= m - 1;
320    }
321    out
322}
323
/// Fold a slot group (a list of bit positions) into the sub-mask that has
/// exactly those bits set. Repeated positions collapse to a single bit.
fn mask_of(slots: &[usize]) -> usize {
    let mut mask = 0usize;
    for &slot in slots {
        mask |= 1usize << slot;
    }
    mask
}
328
// #932-2 cutover: `MultiDirJet::bilinear` (the 4-coeff `[base, d1, d2, d12]`
// constructor) and `MultiDirJet::sub` are consumed ONLY by the now test-only hand
// survival directional/bidirectional oracle (the production flex jet path uses the
// `flex_jet` runtime jet algebra, not `MultiDirJet`). After the #1521 crate split
// moved `MultiDirJet` into `gam-math`, those oracle tests live in the dependent
// `gam` crate, where a `#[cfg(test)]` gate in *this* crate is inactive — so the
// methods must be plain `pub` inherent methods to be reachable cross-crate. They
// carry no dead-code cost because `pub` items are part of the crate's public API.
// Bodies are byte-identical to their former gated form.
338impl MultiDirJet {
339    pub fn bilinear(base: f64, d1: f64, d2: f64, d12: f64) -> Self {
340        Self {
341            coeffs: vec![base, d1, d2, d12],
342        }
343    }
344
345    pub fn sub(&self, other: &Self) -> Self {
346        Self {
347            coeffs: self
348                .coeffs
349                .iter()
350                .zip(other.coeffs.iter())
351                .map(|(lhs, rhs)| lhs - rhs)
352                .collect(),
353        }
354    }
355}
356
/// Unit tests for the `MultiDirJet` constructors, elementwise arithmetic, the
/// subset-convolution product and the truncated unary composition.
///
/// All inputs are small integers, so every expected value is exactly
/// representable and the assertions do not depend on summation order.
#[cfg(test)]
mod tests {
    use super::*;

    // ── constructors ─────────────────────────────────────────────────────────

    #[test]
    fn zero_has_correct_length_and_all_zero_coefficients() {
        let j = MultiDirJet::zero(3);
        assert_eq!(j.coeffs.len(), 8);
        assert!(j.coeffs.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn constant_has_value_at_mask_zero_and_zeros_elsewhere() {
        let j = MultiDirJet::constant(2, 5.0);
        assert_eq!(j.coeffs.len(), 4);
        assert_eq!(j.coeff(0), 5.0);
        assert_eq!(j.coeff(1), 0.0);
        assert_eq!(j.coeff(2), 0.0);
        assert_eq!(j.coeff(3), 0.0);
    }

    #[test]
    fn linear_sets_base_and_per_direction_slots() {
        let j = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        assert_eq!(j.coeff(0), 1.0); // constant
        assert_eq!(j.coeff(1), 2.0); // mask 0b01 — direction 0
        assert_eq!(j.coeff(2), 3.0); // mask 0b10 — direction 1
        assert_eq!(j.coeff(3), 0.0); // cross term is zero
    }

    #[test]
    fn bilinear_sets_all_four_slots() {
        let j = MultiDirJet::bilinear(1.0, 2.0, 3.0, 4.0);
        assert_eq!(j.coeff(0), 1.0);
        assert_eq!(j.coeff(1), 2.0);
        assert_eq!(j.coeff(2), 3.0);
        assert_eq!(j.coeff(3), 4.0);
    }

    #[test]
    fn with_coeffs_sets_only_specified_entries() {
        let j = MultiDirJet::with_coeffs(2, &[(0, 9.0), (3, -1.0)]);
        assert_eq!(j.coeff(0), 9.0);
        assert_eq!(j.coeff(1), 0.0);
        assert_eq!(j.coeff(2), 0.0);
        assert_eq!(j.coeff(3), -1.0);
    }

    // ── elementwise arithmetic ────────────────────────────────────────────────

    #[test]
    fn add_is_elementwise() {
        let a = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        let b = MultiDirJet::linear(2, 4.0, &[5.0, 6.0]);
        let c = a.add(&b);
        assert_eq!(c.coeff(0), 5.0);
        assert_eq!(c.coeff(1), 7.0);
        assert_eq!(c.coeff(2), 9.0);
        assert_eq!(c.coeff(3), 0.0);
    }

    #[test]
    fn scale_multiplies_all_coefficients() {
        let j = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        let s = j.scale(2.0);
        assert_eq!(s.coeff(0), 2.0);
        assert_eq!(s.coeff(1), 4.0);
        assert_eq!(s.coeff(2), 6.0);
        assert_eq!(s.coeff(3), 0.0);
    }

    #[test]
    fn sub_is_elementwise_difference() {
        let a = MultiDirJet::constant(2, 5.0);
        let b = MultiDirJet::constant(2, 3.0);
        let c = a.sub(&b);
        assert_eq!(c.coeff(0), 2.0);
        assert_eq!(c.coeff(1), 0.0);
        assert_eq!(c.coeff(2), 0.0);
        assert_eq!(c.coeff(3), 0.0);
    }

    // ── mul (subset-convolution) ──────────────────────────────────────────────

    #[test]
    fn mul_of_constants_is_scalar_product() {
        let a = MultiDirJet::constant(2, 2.0);
        let b = MultiDirJet::constant(2, 3.0);
        let c = a.mul(&b);
        assert_eq!(c.coeff(0), 6.0);
        assert_eq!(c.coeff(1), 0.0);
        assert_eq!(c.coeff(2), 0.0);
        assert_eq!(c.coeff(3), 0.0);
    }

    #[test]
    fn mul_satisfies_leibniz_rule_single_direction() {
        // (1 + ε) * (1 + ε) = 1 + 2ε
        let x = MultiDirJet::linear(1, 1.0, &[1.0]);
        let y = MultiDirJet::linear(1, 1.0, &[1.0]);
        let z = x.mul(&y);
        assert_eq!(z.coeff(0), 1.0);
        assert_eq!(z.coeff(1), 2.0);
    }

    #[test]
    fn mul_cross_term_two_independent_directions() {
        // (1 + ε₁)(1 + ε₂) = 1 + ε₁ + ε₂ + ε₁ε₂
        let x = MultiDirJet::linear(2, 1.0, &[1.0, 0.0]);
        let y = MultiDirJet::linear(2, 1.0, &[0.0, 1.0]);
        let z = x.mul(&y);
        assert_eq!(z.coeff(0), 1.0);
        assert_eq!(z.coeff(1), 1.0);
        assert_eq!(z.coeff(2), 1.0);
        assert_eq!(z.coeff(3), 1.0);
    }

    #[test]
    fn mul_three_directions_expands_product_of_binomials() {
        // (1 + ε₁)(1 + ε₂)(1 + ε₃) has coefficient 1 on every monomial.
        let x = MultiDirJet::linear(3, 1.0, &[1.0, 0.0, 0.0]);
        let y = MultiDirJet::linear(3, 1.0, &[0.0, 1.0, 0.0]);
        let z = MultiDirJet::linear(3, 1.0, &[0.0, 0.0, 1.0]);
        let product = x.mul(&y).mul(&z);
        assert_eq!(product.coeffs, vec![1.0; 8]);
    }

    #[test]
    fn mul_fast_path_matches_shared_walker_reference() {
        // Integer coefficients keep every partial sum exact, so the fast path
        // and the walker-backed reference must agree coefficient for
        // coefficient.
        let a = MultiDirJet {
            coeffs: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        };
        let b = MultiDirJet {
            coeffs: vec![-2.0, 3.0, 1.0, -4.0, 2.0, 5.0, -1.0, 6.0],
        };
        let fast = a.mul(&b);
        let reference = a.mul_reference(&b);
        assert_eq!(fast.coeffs, reference.coeffs);
    }

    // ── compose_unary (truncated Faà di Bruno) ────────────────────────────────

    #[test]
    fn compose_unary_square_matches_self_product() {
        // f(x) = x² at x = 3: derivative stack [9, 6, 2, 0, 0].
        let x = MultiDirJet::linear(2, 3.0, &[1.0, 1.0]);
        let squared = x.compose_unary([9.0, 6.0, 2.0, 0.0, 0.0]);
        assert_eq!(squared.coeffs, vec![9.0, 6.0, 6.0, 2.0]);
        assert_eq!(squared.coeffs, x.mul(&x).coeffs);
    }

    #[test]
    fn compose_unary_exponential_of_linear_jet_is_all_ones() {
        // exp at 0 has every derivative equal to 1, and
        // exp(ε₁ + ε₂ + ε₃) = Π (1 + εᵢ), whose coefficients are all 1.
        let x = MultiDirJet::linear(3, 0.0, &[1.0, 1.0, 1.0]);
        let composed = x.compose_unary([1.0; DERIVS]);
        assert_eq!(composed.coeffs, vec![1.0; 8]);
    }

    #[test]
    fn compose_unary_truncates_partitions_with_five_or_more_blocks() {
        // With linear seeds only the all-singletons partition is non-zero. For
        // the full 5-bit mask that partition has 5 blocks and is truncated, so
        // the coefficient is 0; every mask with at most 4 bits keeps it.
        let x = MultiDirJet::linear(5, 0.0, &[1.0; 5]);
        let composed = x.compose_unary([1.0; DERIVS]);
        assert_eq!(composed.coeffs.len(), 32);
        for (mask, &value) in composed.coeffs.iter().enumerate() {
            let expected = if mask.count_ones() <= 4 { 1.0 } else { 0.0 };
            assert_eq!(value, expected, "mask {mask:#07b}");
        }
    }
}