gam_math/jet_algebra.rs
1//! The single shared Faà di Bruno / Leibniz combinatorial kernel (#1151).
2//!
3//! Two jet representations live in this crate and historically each carried
4//! its own hand-written copy of the same calculus:
5//!
6//! * [`crate::jet_tower::Tower4`] — full dense derivative tensors
7//! (`v`, `g`, `h`, `t3`, `t4`) in `K` primary variables, with the
8//! Leibniz product and Faà di Bruno composition written out term-by-term
9//! per derivative order.
10//! * [`crate::jet_partitions::MultiDirJet`] — bitmask-coefficient jet over
11//! distinct seeded directions, with the same two rules written as general
12//! submask / set-partition loops.
13//!
14//! The data layouts are legitimately different (a complete small-`K` tower
15//! vs. a handful of directions of a large-`K` expression) and stay separate.
16//! What is identical is the *combinatorics*: for a group of differentiation
17//! slots, the Leibniz rule sums over subsets of those slots and the Faà di
18//! Bruno rule sums over their set-partitions. This module owns that
19//! combinatorics once, as a layout-agnostic [`JetAlgebra`] trait plus walkers
20//! parameterised by closures that read each representation's own derivative
21//! for a slot-group. Both
22//! `Tower4` and `MultiDirJet` route their `mul` / `compose_unary` through
23//! these walkers, so a fix to the rule is a fix to both — and a bit-exact
24//! equivalence test (see `tests`) proves the two layouts agree.
25//!
26//! A "slot-group" is a list of positions `0..m` (the differentiation
27//! arguments of one output coefficient). Each representation maps a group to
28//! a derivative:
29//!
30//! * For a tensor index tuple `(i, j, k, l)`, the positions `0..m` carry
31//! axis labels `[i, j, k, l]`; a sub-group of positions selects the
32//! corresponding lower-order tensor entry (e.g. positions `{0, 2}` →
33//! `h[i][k]`).
34//! * For a bitmask coefficient `mask`, the set bits are the slots; a
35//! sub-group of bits is itself a sub-mask read straight out of `coeffs`.
36
37/// Walk the Leibniz product rule for an output of `m` differentiation slots.
38///
39/// `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`, summed over every subset `T`
40/// of the `m` positions. `left(t)` / `right(c)` receive the position lists of
41/// the chosen subset and its complement and must return the corresponding
42/// derivative of the two factors. Returns the summed output coefficient.
43///
44/// `m` is small (≤ 4 for the tower, ≤ 8 for the directional jet); the
45/// `2^m` subset walk is the exact rule, not a truncation.
46#[inline]
47pub(crate) fn leibniz_product<L, R>(positions: &[usize], mut left: L, mut right: R) -> f64
48where
49 L: FnMut(&[usize]) -> f64,
50 R: FnMut(&[usize]) -> f64,
51{
52 let m = positions.len();
53 assert!(
54 m <= usize::BITS as usize,
55 "too many differentiation slots for subset enumeration"
56 );
57 let mut subset = SlotBuf::new();
58 let mut complement = SlotBuf::new();
59 let mut total = 0.0;
60 for sub in 0u32..(1u32 << m) {
61 subset.clear();
62 complement.clear();
63 for (bit, &pos) in positions.iter().enumerate() {
64 if sub & (1u32 << bit) != 0 {
65 subset.push(pos);
66 } else {
67 complement.push(pos);
68 }
69 }
70 total += left(subset.as_slice()) * right(complement.as_slice());
71 }
72 total
73}
74
75/// Walk the multivariate Faà di Bruno rule for an output of `m` slots.
76///
77/// `D(f∘u) = Σ_{partitions π of the m slots} f^{(|π|)}(u) · Π_{B ∈ π} D_B(u)`.
78/// `derivs[r]` is `f^{(r)}` at the inner value; `inner(block)` returns the
79/// derivative of the inner expression for a block's position list. Returns the
80/// summed output coefficient. Blocks of order ≥ `derivs.len()` are skipped
81/// (their `f^{(r)}` is beyond the truncation), matching both legacy paths.
82#[inline]
83pub fn faa_di_bruno<F>(positions: &[usize], derivs: &[f64], mut inner: F) -> f64
84where
85 F: FnMut(&[usize]) -> f64,
86{
87 let m = positions.len();
88 if m == 0 {
89 return derivs[0];
90 }
91 let mut total = 0.0;
92 for_each_partition(m, &mut |blocks: &[SlotBuf]| {
93 let order = blocks.len();
94 if order >= derivs.len() {
95 return;
96 }
97 let mut prod = derivs[order];
98 for block in blocks {
99 // Translate block positions to their axis labels for `inner`.
100 let mut labelled = SlotBuf::new();
101 for &p in block.as_slice() {
102 labelled.push(positions[p]);
103 }
104 prod *= inner(labelled.as_slice());
105 }
106 total += prod;
107 });
108 total
109}
110
111/// Layout hook for jets that share the Faà di Bruno unary-composition kernel.
112///
113/// `DERIVS` is the length of the unary derivative stack: `5` for fourth-order
114/// jets (`[f, f′, f″, f‴, f⁗]`) and `3` for second-order jets. Implementors own
115/// how slot lists map to their storage; the kernel owns the set-partition rule.
116pub(crate) trait JetAlgebra<const DERIVS: usize>: Sized {
117 /// Read the derivative for a slot list. An empty list is the value channel.
118 fn derivative(&self, positions: &[usize]) -> f64;
119
120 /// Build a jet with every stored derivative filled by `f(positions)`.
121 fn map_derivatives<F>(&self, f: F) -> Self
122 where
123 F: FnMut(&[usize]) -> f64;
124
125 /// Exact multivariate Faà di Bruno composition.
126 fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
127 compose_unary_kernel(self, derivs)
128 }
129}
130
131/// The single unary-composition kernel shared by tower and bitmask jets.
132#[inline]
133pub(crate) fn compose_unary_kernel<J, const DERIVS: usize>(inner: &J, derivs: [f64; DERIVS]) -> J
134where
135 J: JetAlgebra<DERIVS>,
136{
137 inner.map_derivatives(|positions| {
138 faa_di_bruno(positions, &derivs, |block| inner.derivative(block))
139 })
140}
141
/// Fixed-capacity, stack-resident list of slot indices, so the hot per-row
/// path never touches the heap.
///
/// The capacity of 8 covers the deepest tower (order 4) as well as the
/// directional jet's eight-bit masks.
#[derive(Clone, Copy)]
pub(crate) struct SlotBuf {
    /// Backing storage; only the first `len` entries are meaningful.
    data: [usize; 8],
    /// Number of live entries at the front of `data`.
    len: usize,
}

impl SlotBuf {
    /// Create an empty list.
    #[inline]
    pub(crate) fn new() -> Self {
        SlotBuf {
            len: 0,
            data: [0usize; 8],
        }
    }

    /// Drop every stored index; the backing array is left as it was.
    #[inline]
    fn clear(&mut self) {
        self.len = 0;
    }

    /// Append `v` after the current last entry.
    ///
    /// Panics (out-of-bounds index) when the list already holds 8 entries.
    #[inline]
    fn push(&mut self, v: usize) {
        let at = self.len;
        self.data[at] = v;
        self.len = at + 1;
    }

    /// Crate-visible form of the append operation, so that other jet layouts
    /// (the bitmask [`crate::jet_partitions`]) can assemble a slot list to
    /// pass to the shared walkers.
    #[inline]
    pub(crate) fn push_slot(&mut self, v: usize) {
        self.push(v)
    }

    /// Borrow the live entries, in insertion order.
    #[inline]
    pub(crate) fn as_slice(&self) -> &[usize] {
        let live = self.len;
        &self.data[..live]
    }
}
180
181/// Invoke `f` once per set-partition of positions `0..m`, passing the blocks
182/// as slot lists. Recursive "assign each element to an existing or new block"
183/// enumeration — allocation-free via the fixed-capacity [`SlotBuf`].
184fn for_each_partition(m: usize, f: &mut dyn FnMut(&[SlotBuf])) {
185 let mut blocks: [SlotBuf; 8] = [SlotBuf::new(); 8];
186 recurse(0, m, &mut blocks, 0, f);
187}
188
189fn recurse(
190 elem: usize,
191 m: usize,
192 blocks: &mut [SlotBuf; 8],
193 n_blocks: usize,
194 f: &mut dyn FnMut(&[SlotBuf]),
195) {
196 if elem == m {
197 f(&blocks[..n_blocks]);
198 return;
199 }
200 // Place `elem` into each existing block.
201 for b in 0..n_blocks {
202 blocks[b].push(elem);
203 recurse(elem + 1, m, blocks, n_blocks, f);
204 blocks[b].len -= 1;
205 }
206 // Or open a new block with `elem` alone.
207 blocks[n_blocks].clear();
208 blocks[n_blocks].push(elem);
209 recurse(elem + 1, m, blocks, n_blocks + 1, f);
210}
211
#[cfg(test)]
mod tests {
    use crate::jet_partitions::MultiDirJet;
    use crate::jet_tower::Tower4;

    /// Bit-exact equivalence check between the two jet layouts.
    ///
    /// The same polynomial-plus-unary program is evaluated once as a
    /// `Tower4<2>` and once as a two-direction `MultiDirJet`, and every
    /// derivative coefficient the two share must be equal to the last bit.
    /// Since both layouts route through this module's `leibniz_product` /
    /// `faa_di_bruno` walkers, equality here says the two data structures
    /// expose one arithmetic kernel — the #1151 guarantee.
    ///
    /// Program (variables `x = p0`, `z = p1`): `g = exp(x * z + x)`, then
    /// `f = ln(g + 2) * g`, which exercises `mul` and `compose_unary`
    /// (exp, ln).
    #[test]
    fn tower_and_dirjet_agree_bit_exact() {
        let (x, z) = (0.37_f64, -0.81_f64);

        // Dense-tensor layout.
        let tower_x = Tower4::<2>::variable(x, 0);
        let tower_z = Tower4::<2>::variable(z, 1);
        let tower_g = (tower_x * tower_z + tower_x).exp();
        let tower_f = (tower_g + 2.0).ln() * tower_g;

        // Bitmask layout, one seeded direction per variable.
        let jet_x = MultiDirJet::linear(2, x, &[1.0, 0.0]);
        let jet_z = MultiDirJet::linear(2, z, &[0.0, 1.0]);
        let jet_g = exp_dirjet(&jet_x.mul(&jet_z).add(&jet_x));
        let jet_f = ln_dirjet(&jet_g.add(&MultiDirJet::constant(2, 2.0))).mul(&jet_g);

        // Jet masks 0b00 / 0b01 / 0b10 / 0b11 are value, ∂x, ∂z and ∂x∂z;
        // the tower keeps the same quantities as v, g[0], g[1] and h[0][1].
        let shared = [
            (0b00, tower_f.v, "value"),
            (0b01, tower_f.g[0], "∂x"),
            (0b10, tower_f.g[1], "∂z"),
            (0b11, tower_f.h[0][1], "∂x∂z"),
        ];
        for (mask, want, label) in shared {
            assert_eq!(jet_f.coeff(mask), want, "{label}");
        }

        // The tower's mixed second partial must also be symmetric bit for bit.
        assert_eq!(
            tower_f.h[0][1], tower_f.h[1][0],
            "tower mixed-partial symmetry"
        );
    }

    /// The tower's contracted third and fourth derivatives must match, within
    /// `assert_close`'s tolerance, the last coefficient of a directional jet
    /// seeded with unit axes `a`, `b` plus the contraction directions.
    #[test]
    fn tower_contractions_match_dirjet_directional_coefficients() {
        const K: usize = 3;
        let p = [0.37_f64, -0.42_f64, 0.19_f64];
        let q = [0.25_f64, -0.7_f64, 1.3_f64];
        let u = [-0.4_f64, 0.9_f64, 0.15_f64];
        let w = [1.1_f64, -0.2_f64, 0.6_f64];

        // Unit direction along one primary axis.
        let unit = |axis: usize| -> [f64; K] {
            let mut e = [0.0; K];
            e[axis] = 1.0;
            e
        };
        // Last stored coefficient of a jet (presumably the all-directions
        // mask, given the mask indexing used in the test above).
        let top = |jet: &MultiDirJet| jet.coeff(jet.coeffs.len() - 1);

        let tower = nonlinear_tower_program(p);
        let third = tower.third_contracted(&q);
        let fourth = tower.fourth_contracted(&u, &w);

        for (a, b) in (0..K).flat_map(|a| (0..K).map(move |b| (a, b))) {
            let jet3 = nonlinear_dirjet_program(p, &[unit(a), unit(b), q]);
            assert_close(
                top(&jet3),
                third[a][b],
                &format!("third contraction ({a},{b})"),
            );

            let jet4 = nonlinear_dirjet_program(p, &[unit(a), unit(b), u, w]);
            assert_close(
                top(&jet4),
                fourth[a][b],
                &format!("fourth contraction ({a},{b})"),
            );
        }
    }

    /// Tower form of the shared test program: `eta = x*y + x*z + 0.7*z`,
    /// `g = exp(eta)`, result `ln(g + 2) * g`.
    fn nonlinear_tower_program(p: [f64; 3]) -> Tower4<3> {
        let [px, py, pz] = p;
        let x = Tower4::<3>::variable(px, 0);
        let y = Tower4::<3>::variable(py, 1);
        let z = Tower4::<3>::variable(pz, 2);
        let g = (x * y + x * z + z * 0.7).exp();
        (g + 2.0).ln() * g
    }

    /// Directional-jet form of the same program, with one jet direction per
    /// entry of `dirs`.
    fn nonlinear_dirjet_program(p: [f64; 3], dirs: &[[f64; 3]]) -> MultiDirJet {
        let n_dirs = dirs.len();
        // Seed a primary variable with its component of every direction.
        let seed = |axis: usize| {
            MultiDirJet::linear(n_dirs, p[axis], &direction_components(dirs, axis))
        };
        let (x, y, z) = (seed(0), seed(1), seed(2));
        let eta = x.mul(&y).add(&x.mul(&z)).add(&z.scale(0.7));
        let g = exp_dirjet(&eta);
        ln_dirjet(&g.add(&MultiDirJet::constant(n_dirs, 2.0))).mul(&g)
    }

    /// Collect the `axis` component of every direction, in order.
    fn direction_components(dirs: &[[f64; 3]], axis: usize) -> Vec<f64> {
        let mut column = Vec::with_capacity(dirs.len());
        for dir in dirs {
            column.push(dir[axis]);
        }
        column
    }

    /// Assert `got` equals `want` up to `1e-12`, relative to `max(|want|, 1)`.
    fn assert_close(got: f64, want: f64, label: &str) {
        let tol = 1.0e-12 * want.abs().max(1.0);
        let diff = (got - want).abs();
        assert!(
            diff <= tol,
            "{label}: got={got:.17e}, want={want:.17e}, diff={:.3e}, tol={tol:.3e}",
            diff
        );
    }

    /// `exp` on a directional jet: every derivative of `exp` is `exp` itself.
    fn exp_dirjet(j: &MultiDirJet) -> MultiDirJet {
        j.compose_unary([j.coeff(0).exp(); 5])
    }

    /// `ln` on a directional jet, from the stack
    /// `[ln u, 1/u, -1/u², 2/u³, -6/u⁴]`.
    fn ln_dirjet(j: &MultiDirJet) -> MultiDirJet {
        let u = j.coeff(0);
        let r = 1.0 / u;
        // The products are kept in this exact order: the bit-exact test above
        // compares against the tower's own `ln`.
        let stack = [u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r];
        j.compose_unary(stack)
    }
}