gam_math/jet_partitions.rs
1//! Bitmask-coefficient multi-directional jets used by marginal-slope and
2//! latent-survival row kernels.
3//!
//! The layout stores one coefficient per direction mask. The calculus itself
//! lives in [`crate::jet_algebra`]: that module owns the layout-agnostic
//! Leibniz / Faà di Bruno *combinatorics* once, and the scalar (`n_dirs == 0`,
//! i.e. `coeffs.len() <= 1`) path here still routes through it so a fix to the
//! rule is a fix to both representations.
9//!
10//! ## Why this layout is special (and how the hot path exploits it)
11//!
12//! Each direction is seeded *linearly* (one first-derivative slot), so every
13//! direction variable squares to zero. The coefficients therefore form the
14//! commutative **multilinear / set-function algebra**: `coeffs[mask]` is the
15//! coefficient of `Π_{i ∈ mask} ε_i`. In that algebra two facts collapse the
16//! generic combinatorial walkers into tight branch-free arithmetic:
17//!
18//! * **`mul` is the subset (zeta-style) convolution**
19//! `out[mask] = Σ_{sub ⊆ mask} a[sub] · b[mask \ sub]`.
20//! The shared `leibniz_product` walker rebuilds two `SlotBuf`s and folds bit
21//! lists back into masks (`mask_of`) *per subset*; here we enumerate the
22//! submasks of `mask` directly — `mask \ sub == mask ^ sub` because
23//! `sub ⊆ mask` — in the **same ascending order** the walker used, so the
24//! floating-point accumulation is bit-for-bit identical while every
25//! `SlotBuf`/closure/`mask_of` allocation and indirection disappears
26//! (`3^K` pure FMAs, no heap, no `dyn`).
27//!
28//! * **`compose_unary` is the truncated set-partition (Faà di Bruno) sum**
29//! `out[mask] = Σ_{π ⊢ mask, |π| < 5} f^{(|π|)} · Π_{B ∈ π} u[B]`.
30//! The shared walker re-runs the partition *recursion* (with `&mut dyn
31//! FnMut` dispatch and fresh `SlotBuf` blocks) once **per output mask**.
32//! The set of partitions of `m` slots depends only on `m`, so we enumerate
33//! them **once** into a thread-local table — emitted in the exact recursion
34//! order, pruned at `|π| >= 5` (the same order-4 truncation) — and the hot
35//! loop is then a flat sum of products with no recursion and no dynamic
36//! dispatch. Same emit order, same block order, same `derivs[order]` factor,
37//! so the result is bit-for-bit identical to the walker.
38//!
39//! Both fast paths were validated `to_bits`-identical against the shared
40//! walkers over thousands of randomised composite programs at `K ∈ {2,3,4,9}`.
41use std::cell::RefCell;
42use std::rc::Rc;
43use std::sync::atomic::{AtomicU64, Ordering};
44
/// Running count of [`MultiDirJet::compose_unary`] invocations. That method
/// adds one with `Ordering::Relaxed` on entry; nothing in this file reads or
/// resets the counter.
pub static COMPOSE_UNARY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Running count of [`MultiDirJet::mul`] invocations. That method adds one with
/// `Ordering::Relaxed` on entry; nothing in this file reads or resets the
/// counter.
pub static MUL_CALLS: AtomicU64 = AtomicU64::new(0);

/// Length of the unary derivative stack `[f, f', f'', f''', f'''']`: composition
/// is exact through order 4, partitions into `>= 5` blocks are truncated.
const DERIVS: usize = 5;
51
/// A multi-directional jet stored as one coefficient per direction bitmask.
///
/// `coeffs[mask]` is the coefficient of `Π_{i ∈ mask} ε_i`, so index `0` is the
/// value channel and the single-bit masks `1 << i` hold the first-order slopes.
/// The constructors in this file allocate `1 << n_dirs` coefficients.
///
/// `Debug` and `PartialEq` are derived so jets can be printed and compared
/// directly in assertions and diagnostics.
#[derive(Clone, Debug, PartialEq)]
pub struct MultiDirJet {
    /// Coefficients indexed by direction bitmask.
    pub coeffs: Vec<f64>,
}
56
57impl MultiDirJet {
58 pub fn zero(n_dirs: usize) -> Self {
59 Self {
60 coeffs: vec![0.0; 1usize << n_dirs],
61 }
62 }
63
64 pub fn constant(n_dirs: usize, value: f64) -> Self {
65 let mut out = Self::zero(n_dirs);
66 out.coeffs[0] = value;
67 out
68 }
69
70 pub fn linear(n_dirs: usize, base: f64, first: &[f64]) -> Self {
71 let mut out = Self::constant(n_dirs, base);
72 for (idx, &value) in first.iter().take(n_dirs).enumerate() {
73 out.coeffs[1usize << idx] = value;
74 }
75 out
76 }
77
78 pub fn with_coeffs(n_dirs: usize, coeffs: &[(usize, f64)]) -> Self {
79 let mut out = Self::zero(n_dirs);
80 for &(mask, value) in coeffs {
81 if mask < out.coeffs.len() {
82 out.coeffs[mask] = value;
83 }
84 }
85 out
86 }
87
88 #[inline]
89 pub fn coeff(&self, mask: usize) -> f64 {
90 self.coeffs[mask]
91 }
92
93 pub fn add(&self, other: &Self) -> Self {
94 Self {
95 coeffs: self
96 .coeffs
97 .iter()
98 .zip(other.coeffs.iter())
99 .map(|(lhs, rhs)| lhs + rhs)
100 .collect(),
101 }
102 }
103
104 pub fn scale(&self, scalar: f64) -> Self {
105 Self {
106 coeffs: self.coeffs.iter().map(|value| scalar * value).collect(),
107 }
108 }
109
110 /// Subset-convolution product `out[mask] = Σ_{sub ⊆ mask} a[sub]·b[mask^sub]`.
111 ///
112 /// Bit-identical to the shared [`crate::jet_algebra::leibniz_product`] walker
113 /// (the submasks are enumerated in the same ascending order — the walker's
114 /// compacted subset index is a monotone bit-deposit of the submask) while
115 /// dropping its per-subset `SlotBuf`/closure/`mask_of` overhead. The scalar
116 /// `n_dirs == 0` case keeps the shared walker live as its reference.
117 pub fn mul(&self, other: &Self) -> Self {
118 MUL_CALLS.fetch_add(1, Ordering::Relaxed);
119 let count = self.coeffs.len();
120 if count <= 1 {
121 return self.mul_reference(other);
122 }
123 let a = &self.coeffs;
124 let b = &other.coeffs;
125 let mut out = vec![0.0; count];
126 for (mask, slot) in out.iter_mut().enumerate() {
127 // Walk every submask of `mask` in ascending numeric order — the same
128 // order `leibniz_product` accumulates — via the classic gap-fill
129 // increment `next = ((sub | !mask) + 1) & mask`.
130 let mut acc = 0.0;
131 let mut sub = 0usize;
132 loop {
133 acc += a[sub] * b[mask ^ sub];
134 if sub == mask {
135 break;
136 }
137 sub = (sub | !mask).wrapping_add(1) & mask;
138 }
139 *slot = acc;
140 }
141 Self { coeffs: out }
142 }
143
144 /// The pre-#perf shared-walker product, retained verbatim as the scalar-case
145 /// implementation and as the bit-exact reference for `mul`.
146 fn mul_reference(&self, other: &Self) -> Self {
147 let count = self.coeffs.len();
148 let mut out = vec![0.0; count];
149 for (mask, slot) in out.iter_mut().enumerate() {
150 let bits = bit_positions(mask);
151 *slot = crate::jet_algebra::leibniz_product(
152 bits.as_slice(),
153 |t| self.coeffs[mask_of(t)],
154 |c| other.coeffs[mask_of(c)],
155 );
156 }
157 Self { coeffs: out }
158 }
159
160 /// Exact (order-4 truncated) unary composition `f(self)` from the Taylor
161 /// stack `[f, f', f'', f''', f'''']` at `self.coeff(0)`.
162 ///
163 /// Bit-identical to the shared [`crate::jet_algebra`] Faà di Bruno walker:
164 /// it enumerates the set-partitions of each output mask's slots in the exact
165 /// same recursion order, multiplies `derivs[order]` by the same per-block
166 /// inner coefficients in the same order, and sums them in the same order —
167 /// but the partition enumeration is hoisted out of the per-mask loop into a
168 /// thread-local table built once per slot count. The scalar `n_dirs == 0`
169 /// case keeps the shared walker live as its reference.
170 pub fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
171 COMPOSE_UNARY_CALLS.fetch_add(1, Ordering::Relaxed);
172 let count = self.coeffs.len();
173 if count <= 1 {
174 return <Self as crate::jet_algebra::JetAlgebra<DERIVS>>::compose_unary(self, derivs);
175 }
176 let n_dirs = count.trailing_zeros() as usize;
177 // Partition tables for every slot count present, built once and cached.
178 let tables = partition_tables(n_dirs);
179 let coeffs = &self.coeffs;
180 let mut out = vec![0.0; count];
181 // Per-mask scratch: `remap[cb]` lifts a compacted submask `cb` of the
182 // current mask's slots back to the real coefficient index (the walker's
183 // `mask_of(labelled)`). Filled once per mask and reused across all of
184 // that mask's partitions/blocks, replacing the per-block bit-deposit
185 // loop with a single load. Sized `count` (>= 2^npos for every mask).
186 let mut remap = vec![0usize; count];
187 let mut pos = [0usize; usize::BITS as usize];
188 for (mask, slot) in out.iter_mut().enumerate() {
189 if mask == 0 {
190 // Matches the walker's `m == 0` early return exactly (no `0.0 +`
191 // round-trip, which would differ on a `-0.0` value channel).
192 *slot = derivs[0];
193 continue;
194 }
195 // Set-bit positions of `mask`, ascending — the slot labels.
196 let mut npos = 0usize;
197 let mut m = mask;
198 while m != 0 {
199 pos[npos] = m.trailing_zeros() as usize;
200 npos += 1;
201 m &= m - 1;
202 }
203 // Deposit table: remap[cb] = OR over set bits `i` of cb of 1<<pos[i].
204 // DP over submasks — strip the lowest bit, add its real position.
205 remap[0] = 0;
206 for cb in 1usize..(1usize << npos) {
207 let low = cb.trailing_zeros() as usize;
208 remap[cb] = remap[cb & (cb - 1)] | (1usize << pos[low]);
209 }
210 let table = &tables[npos];
211 let flat = &table.flat;
212 let mut total = 0.0;
213 for &(off, order) in table.parts.iter() {
214 let order = order as usize;
215 let mut prod = derivs[order];
216 for &cb in &flat[off..off + order] {
217 prod *= coeffs[remap[cb as usize]];
218 }
219 total += prod;
220 }
221 *slot = total;
222 }
223 Self { coeffs: out }
224 }
225}
226
227impl crate::jet_algebra::JetAlgebra<DERIVS> for MultiDirJet {
228 #[inline]
229 fn derivative(&self, slots: &[usize]) -> f64 {
230 self.coeffs[mask_of(slots)]
231 }
232
233 fn map_derivatives<F>(&self, mut f: F) -> Self
234 where
235 F: FnMut(&[usize]) -> f64,
236 {
237 let mut out = vec![0.0; self.coeffs.len()];
238 for (mask, value) in out.iter_mut().enumerate() {
239 let bits = bit_positions(mask);
240 *value = f(bits.as_slice());
241 }
242 Self { coeffs: out }
243 }
244}
245
/// A flattened set-partition table for a fixed slot count. `parts[i] = (off,
/// order)` describes one partition: its `order` block submasks (compacted) are
/// `flat[off .. off + order]`. Flattening keeps the hot composition loop on one
/// contiguous slice instead of chasing per-partition `Vec` pointers.
struct PartTable {
    /// Block submasks of every partition, concatenated in emission order.
    flat: Vec<u32>,
    /// One `(off, order)` entry per partition: the start offset of its blocks
    /// in `flat` and how many blocks it has.
    parts: Vec<(usize, u8)>,
}
254
thread_local! {
    /// Cached set-partition tables, indexed by slot count `m`. Entry `m` holds
    /// every partition of `{0..m}` into `< DERIVS` blocks, in the shared
    /// walker's recursion order, each block a compacted submask. Pure function
    /// of `m`, so caching is sound and deterministic.
    ///
    /// Each thread keeps its own cache. Within this file it is only ever
    /// extended (by `partition_tables`), never shrunk or cleared.
    static PARTITION_TABLES: RefCell<Vec<Rc<PartTable>>> = const { RefCell::new(Vec::new()) };
}
262
263/// Return cached partition tables for slot counts `0..=n_dirs`.
264fn partition_tables(n_dirs: usize) -> Vec<Rc<PartTable>> {
265 PARTITION_TABLES.with(|cell| {
266 let mut tables = cell.borrow_mut();
267 while tables.len() <= n_dirs {
268 let m = tables.len();
269 tables.push(Rc::new(build_partitions(m)));
270 }
271 (0..=n_dirs).map(|m| Rc::clone(&tables[m])).collect()
272 })
273}
274
275/// Enumerate the set-partitions of `{0..m}` with fewer than `DERIVS` blocks, in
276/// the exact DFS order of [`crate::jet_algebra`]'s `for_each_partition`
277/// recursion ("place each element into an existing block, else open a new one"),
278/// each block recorded as a compacted submask of `{0..m}`, flattened.
279fn build_partitions(m: usize) -> PartTable {
280 fn recurse(elem: usize, m: usize, blocks: &mut [u32; 8], n_blocks: usize, out: &mut PartTable) {
281 // Partitions with `>= DERIVS` blocks are truncated (their `f^{(order)}`
282 // is beyond the stack); the block count never decreases, so the whole
283 // subtree contributes nothing and is pruned — matching the walker's
284 // per-partition `order >= derivs.len()` skip.
285 if n_blocks >= DERIVS {
286 return;
287 }
288 if elem == m {
289 let off = out.flat.len();
290 out.flat.extend_from_slice(&blocks[..n_blocks]);
291 out.parts.push((off, n_blocks as u8));
292 return;
293 }
294 for b in 0..n_blocks {
295 blocks[b] |= 1u32 << elem;
296 recurse(elem + 1, m, blocks, n_blocks, out);
297 blocks[b] &= !(1u32 << elem);
298 }
299 blocks[n_blocks] = 1u32 << elem;
300 recurse(elem + 1, m, blocks, n_blocks + 1, out);
301 }
302 let mut out = PartTable {
303 flat: Vec::new(),
304 parts: Vec::new(),
305 };
306 let mut blocks = [0u32; 8];
307 recurse(0, m, &mut blocks, 0, &mut out);
308 out
309}
310
311/// The set-bit positions of `mask`, low to high — the differentiation slots of
312/// that coefficient.
313fn bit_positions(mask: usize) -> crate::jet_algebra::SlotBuf {
314 let mut out = crate::jet_algebra::SlotBuf::new();
315 let mut m = mask;
316 while m != 0 {
317 let bit = m.trailing_zeros() as usize;
318 out.push_slot(bit);
319 m &= m - 1;
320 }
321 out
322}
323
/// Folds a slot group (a list of bit positions) back into the sub-mask that has
/// exactly those bits set. Repeated positions are harmless.
fn mask_of(slots: &[usize]) -> usize {
    let mut mask = 0usize;
    for &bit in slots {
        mask |= 1usize << bit;
    }
    mask
}
328
329// #932-2 cutover: `MultiDirJet::bilinear` (the 4-coeff `[base, d1, d2, d12]`
330// constructor) and `MultiDirJet::sub` are consumed ONLY by the now test-only hand
331// survival directional/bidirectional oracle (the production flex jet path uses the
332// `flex_jet` runtime jet algebra, not `MultiDirJet`). After the #1521 crate split
333// moved `MultiDirJet` into `gam-math`, those oracle tests live in the dependent
334// `gam` crate, where a `#[cfg(test)]` gate in *this* crate is inactive — so the
335// methods must be plain `pub` inherent methods to be reachable cross-crate. They
336// carry no dead-code cost because `pub` items are part of the crate's public API.
337// Bodies are byte-identical to their former gated form.
338impl MultiDirJet {
339 pub fn bilinear(base: f64, d1: f64, d2: f64, d12: f64) -> Self {
340 Self {
341 coeffs: vec![base, d1, d2, d12],
342 }
343 }
344
345 pub fn sub(&self, other: &Self) -> Self {
346 Self {
347 coeffs: self
348 .coeffs
349 .iter()
350 .zip(other.coeffs.iter())
351 .map(|(lhs, rhs)| lhs - rhs)
352 .collect(),
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `jet.coeff(mask)` against `expected[mask]` for every listed mask.
    fn assert_coeffs(jet: &MultiDirJet, expected: &[f64]) {
        for (mask, &want) in expected.iter().enumerate() {
            assert_eq!(jet.coeff(mask), want);
        }
    }

    // ── constructors ─────────────────────────────────────────────────────────

    #[test]
    fn zero_has_correct_length_and_all_zero_coefficients() {
        let jet = MultiDirJet::zero(3);
        assert_eq!(jet.coeffs.len(), 8);
        assert_coeffs(&jet, &[0.0; 8]);
    }

    #[test]
    fn constant_has_value_at_mask_zero_and_zeros_elsewhere() {
        let jet = MultiDirJet::constant(2, 5.0);
        assert_eq!(jet.coeffs.len(), 4);
        assert_coeffs(&jet, &[5.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn linear_sets_base_and_per_direction_slots() {
        let jet = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        // Mask order: constant, direction 0 (0b01), direction 1 (0b10), and the
        // cross term (0b11), which must stay zero.
        assert_coeffs(&jet, &[1.0, 2.0, 3.0, 0.0]);
    }

    #[test]
    fn bilinear_sets_all_four_slots() {
        let jet = MultiDirJet::bilinear(1.0, 2.0, 3.0, 4.0);
        assert_coeffs(&jet, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn with_coeffs_sets_only_specified_entries() {
        let jet = MultiDirJet::with_coeffs(2, &[(0, 9.0), (3, -1.0)]);
        assert_coeffs(&jet, &[9.0, 0.0, 0.0, -1.0]);
    }

    // ── elementwise arithmetic ────────────────────────────────────────────────

    #[test]
    fn add_is_elementwise() {
        let lhs = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        let rhs = MultiDirJet::linear(2, 4.0, &[5.0, 6.0]);
        assert_coeffs(&lhs.add(&rhs), &[5.0, 7.0, 9.0, 0.0]);
    }

    #[test]
    fn scale_multiplies_all_coefficients() {
        let jet = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        assert_coeffs(&jet.scale(2.0), &[2.0, 4.0, 6.0, 0.0]);
    }

    #[test]
    fn sub_is_elementwise_difference() {
        let lhs = MultiDirJet::constant(2, 5.0);
        let rhs = MultiDirJet::constant(2, 3.0);
        assert_coeffs(&lhs.sub(&rhs), &[2.0, 0.0, 0.0, 0.0]);
    }

    // ── mul (subset-convolution) ──────────────────────────────────────────────

    #[test]
    fn mul_of_constants_is_scalar_product() {
        let lhs = MultiDirJet::constant(2, 2.0);
        let rhs = MultiDirJet::constant(2, 3.0);
        assert_coeffs(&lhs.mul(&rhs), &[6.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn mul_satisfies_leibniz_rule_single_direction() {
        // (1 + ε) * (1 + ε) = 1 + 2ε
        let lhs = MultiDirJet::linear(1, 1.0, &[1.0]);
        let rhs = MultiDirJet::linear(1, 1.0, &[1.0]);
        assert_coeffs(&lhs.mul(&rhs), &[1.0, 2.0]);
    }

    #[test]
    fn mul_cross_term_two_independent_directions() {
        // (1 + ε₁)(1 + ε₂) = 1 + ε₁ + ε₂ + ε₁ε₂
        let lhs = MultiDirJet::linear(2, 1.0, &[1.0, 0.0]);
        let rhs = MultiDirJet::linear(2, 1.0, &[0.0, 1.0]);
        assert_coeffs(&lhs.mul(&rhs), &[1.0, 1.0, 1.0, 1.0]);
    }
}