Skip to main content

gam_math/
jet_algebra.rs

1//! The single shared Faà di Bruno / Leibniz combinatorial kernel (#1151).
2//!
3//! Two jet representations live in this crate and historically each carried
4//! its own hand-written copy of the same calculus:
5//!
6//! * [`crate::jet_tower::Tower4`] — full dense derivative tensors
7//!   (`v`, `g`, `h`, `t3`, `t4`) in `K` primary variables, with the
8//!   Leibniz product and Faà di Bruno composition written out term-by-term
9//!   per derivative order.
10//! * [`crate::jet_partitions::MultiDirJet`] — bitmask-coefficient jet over
11//!   distinct seeded directions, with the same two rules written as general
12//!   submask / set-partition loops.
13//!
14//! The data layouts are legitimately different (a complete small-`K` tower
15//! vs. a handful of directions of a large-`K` expression) and stay separate.
16//! What is identical is the *combinatorics*: for a group of differentiation
17//! slots, the Leibniz rule sums over subsets of those slots and the Faà di
18//! Bruno rule sums over their set-partitions. This module owns that
//! combinatorics once, as a layout-agnostic [`JetAlgebra`] trait plus walkers
//! parameterised by closures that read each representation's own derivative
//! for a slot-group. Both `Tower4` and `MultiDirJet` route their `mul` /
//! `compose_unary` through these walkers, so a fix to the rule is a fix to
//! both — and a bit-exact equivalence test (see `tests`) checks that the two
//! layouts agree.
25//!
26//! A "slot-group" is a list of positions `0..m` (the differentiation
27//! arguments of one output coefficient). Each representation maps a group to
28//! a derivative:
29//!
30//! * For a tensor index tuple `(i, j, k, l)`, the positions `0..m` carry
31//!   axis labels `[i, j, k, l]`; a sub-group of positions selects the
32//!   corresponding lower-order tensor entry (e.g. positions `{0, 2}` →
33//!   `h[i][k]`).
34//! * For a bitmask coefficient `mask`, the set bits are the slots; a
35//!   sub-group of bits is itself a sub-mask read straight out of `coeffs`.
36
37/// Walk the Leibniz product rule for an output of `m` differentiation slots.
38///
39/// `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`, summed over every subset `T`
40/// of the `m` positions. `left(t)` / `right(c)` receive the position lists of
41/// the chosen subset and its complement and must return the corresponding
42/// derivative of the two factors. Returns the summed output coefficient.
43///
44/// `m` is small (≤ 4 for the tower, ≤ 8 for the directional jet); the
45/// `2^m` subset walk is the exact rule, not a truncation.
46#[inline]
47pub(crate) fn leibniz_product<L, R>(positions: &[usize], mut left: L, mut right: R) -> f64
48where
49    L: FnMut(&[usize]) -> f64,
50    R: FnMut(&[usize]) -> f64,
51{
52    let m = positions.len();
53    assert!(
54        m <= usize::BITS as usize,
55        "too many differentiation slots for subset enumeration"
56    );
57    let mut subset = SlotBuf::new();
58    let mut complement = SlotBuf::new();
59    let mut total = 0.0;
60    for sub in 0u32..(1u32 << m) {
61        subset.clear();
62        complement.clear();
63        for (bit, &pos) in positions.iter().enumerate() {
64            if sub & (1u32 << bit) != 0 {
65                subset.push(pos);
66            } else {
67                complement.push(pos);
68            }
69        }
70        total += left(subset.as_slice()) * right(complement.as_slice());
71    }
72    total
73}
74
75/// Walk the multivariate Faà di Bruno rule for an output of `m` slots.
76///
77/// `D(f∘u) = Σ_{partitions π of the m slots} f^{(|π|)}(u) · Π_{B ∈ π} D_B(u)`.
78/// `derivs[r]` is `f^{(r)}` at the inner value; `inner(block)` returns the
79/// derivative of the inner expression for a block's position list. Returns the
80/// summed output coefficient. Blocks of order ≥ `derivs.len()` are skipped
81/// (their `f^{(r)}` is beyond the truncation), matching both legacy paths.
82#[inline]
83pub fn faa_di_bruno<F>(positions: &[usize], derivs: &[f64], mut inner: F) -> f64
84where
85    F: FnMut(&[usize]) -> f64,
86{
87    let m = positions.len();
88    if m == 0 {
89        return derivs[0];
90    }
91    let mut total = 0.0;
92    for_each_partition(m, &mut |blocks: &[SlotBuf]| {
93        let order = blocks.len();
94        if order >= derivs.len() {
95            return;
96        }
97        let mut prod = derivs[order];
98        for block in blocks {
99            // Translate block positions to their axis labels for `inner`.
100            let mut labelled = SlotBuf::new();
101            for &p in block.as_slice() {
102                labelled.push(positions[p]);
103            }
104            prod *= inner(labelled.as_slice());
105        }
106        total += prod;
107    });
108    total
109}
110
111/// Layout hook for jets that share the Faà di Bruno unary-composition kernel.
112///
113/// `DERIVS` is the length of the unary derivative stack: `5` for fourth-order
114/// jets (`[f, f′, f″, f‴, f⁗]`) and `3` for second-order jets. Implementors own
115/// how slot lists map to their storage; the kernel owns the set-partition rule.
116pub(crate) trait JetAlgebra<const DERIVS: usize>: Sized {
117    /// Read the derivative for a slot list. An empty list is the value channel.
118    fn derivative(&self, positions: &[usize]) -> f64;
119
120    /// Build a jet with every stored derivative filled by `f(positions)`.
121    fn map_derivatives<F>(&self, f: F) -> Self
122    where
123        F: FnMut(&[usize]) -> f64;
124
125    /// Exact multivariate Faà di Bruno composition.
126    fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
127        compose_unary_kernel(self, derivs)
128    }
129}
130
131/// The single unary-composition kernel shared by tower and bitmask jets.
132#[inline]
133pub(crate) fn compose_unary_kernel<J, const DERIVS: usize>(inner: &J, derivs: [f64; DERIVS]) -> J
134where
135    J: JetAlgebra<DERIVS>,
136{
137    inner.map_derivatives(|positions| {
138        faa_di_bruno(positions, &derivs, |block| inner.derivative(block))
139    })
140}
141
/// Inline capacity of [`SlotBuf`]: covers the deepest tower (order 4) and the
/// directional jet's eight-bit masks.
const SLOT_CAPACITY: usize = 8;

/// A tiny inline stack of slot indices — no heap traffic on the hot per-row
/// path. Holds at most [`SLOT_CAPACITY`] entries; `Default` is the empty
/// buffer, same as [`SlotBuf::new`].
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct SlotBuf {
    data: [usize; SLOT_CAPACITY],
    len: usize,
}

impl SlotBuf {
    /// An empty buffer.
    #[inline]
    pub(crate) fn new() -> Self {
        Self {
            data: [0; SLOT_CAPACITY],
            len: 0,
        }
    }

    /// Forget every entry; the inline storage is reused as-is.
    #[inline]
    fn clear(&mut self) {
        self.len = 0;
    }

    /// Append `v`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer already holds [`SLOT_CAPACITY`] entries.
    #[inline]
    fn push(&mut self, v: usize) {
        assert!(
            self.len < SLOT_CAPACITY,
            "SlotBuf capacity ({}) exceeded",
            SLOT_CAPACITY
        );
        self.data[self.len] = v;
        self.len += 1;
    }

    /// Append a slot index. Public to the crate so other jet layouts (the
    /// bitmask [`crate::jet_partitions`]) can build a slot list to hand the
    /// shared walkers.
    ///
    /// # Panics
    ///
    /// Panics if the buffer already holds [`SLOT_CAPACITY`] entries.
    #[inline]
    pub(crate) fn push_slot(&mut self, v: usize) {
        self.push(v);
    }

    /// The stored slot indices, in push order.
    #[inline]
    pub(crate) fn as_slice(&self) -> &[usize] {
        &self.data[..self.len]
    }
}
180
181/// Invoke `f` once per set-partition of positions `0..m`, passing the blocks
182/// as slot lists. Recursive "assign each element to an existing or new block"
183/// enumeration — allocation-free via the fixed-capacity [`SlotBuf`].
184fn for_each_partition(m: usize, f: &mut dyn FnMut(&[SlotBuf])) {
185    let mut blocks: [SlotBuf; 8] = [SlotBuf::new(); 8];
186    recurse(0, m, &mut blocks, 0, f);
187}
188
189fn recurse(
190    elem: usize,
191    m: usize,
192    blocks: &mut [SlotBuf; 8],
193    n_blocks: usize,
194    f: &mut dyn FnMut(&[SlotBuf]),
195) {
196    if elem == m {
197        f(&blocks[..n_blocks]);
198        return;
199    }
200    // Place `elem` into each existing block.
201    for b in 0..n_blocks {
202        blocks[b].push(elem);
203        recurse(elem + 1, m, blocks, n_blocks, f);
204        blocks[b].len -= 1;
205    }
206    // Or open a new block with `elem` alone.
207    blocks[n_blocks].clear();
208    blocks[n_blocks].push(elem);
209    recurse(elem + 1, m, blocks, n_blocks + 1, f);
210}
211
#[cfg(test)]
mod tests {
    use super::{faa_di_bruno, leibniz_product};
    use crate::jet_partitions::MultiDirJet;
    use crate::jet_tower::Tower4;

    /// Bit-exact equivalence proof: evaluate the SAME polynomial-plus-unary
    /// composition on both jet layouts and assert every shared derivative
    /// coefficient is identical to the last bit. Because both layouts now
    /// route through this module's [`leibniz_product`] / [`faa_di_bruno`]
    /// walkers, the equality is a statement that the two *data structures*
    /// expose the same single arithmetic kernel — the #1151 guarantee.
    ///
    /// Program (K=2 directions / primaries, seeded as variables `x = p0`,
    /// `z = p1`):  `g = exp(x * z + x)` then `f = ln(g + 2) * g`.
    /// Both `mul` and `compose_unary` (exp, ln) are exercised.
    #[test]
    fn tower_and_dirjet_agree_bit_exact() {
        let x = 0.37_f64;
        let z = -0.81_f64;

        // ── Tower4<2> path ──
        let tx = Tower4::<2>::variable(x, 0);
        let tz = Tower4::<2>::variable(z, 1);
        let tg = (tx * tz + tx).exp();
        let tf = (tg + 2.0).ln() * tg;

        // ── MultiDirJet (2 directions) path ──
        let jx = MultiDirJet::linear(2, x, &[1.0, 0.0]);
        let jz = MultiDirJet::linear(2, z, &[0.0, 1.0]);
        let jg = exp_dirjet(&jx.mul(&jz).add(&jx));
        let jf = ln_dirjet(&jg.add(&MultiDirJet::constant(2, 2.0))).mul(&jg);

        // The directional jet carries coefficients for masks
        //   0b00=value, 0b01=∂x, 0b10=∂z, 0b11=∂x∂z.
        // The tower carries the same derivatives as tensor entries:
        //   v, g[0], g[1], h[0][1].
        assert_eq!(jf.coeff(0b00), tf.v, "value");
        assert_eq!(jf.coeff(0b01), tf.g[0], "∂x");
        assert_eq!(jf.coeff(0b10), tf.g[1], "∂z");
        assert_eq!(jf.coeff(0b11), tf.h[0][1], "∂x∂z");
        // Symmetry of the tower's mixed second partial is also bit-exact.
        assert_eq!(tf.h[0][1], tf.h[1][0], "tower mixed-partial symmetry");
    }

    #[test]
    fn tower_contractions_match_dirjet_directional_coefficients() {
        const K: usize = 3;
        let p = [0.37_f64, -0.42_f64, 0.19_f64];
        let q = [0.25_f64, -0.7_f64, 1.3_f64];
        let u = [-0.4_f64, 0.9_f64, 0.15_f64];
        let w = [1.1_f64, -0.2_f64, 0.6_f64];

        let tower = nonlinear_tower_program(p);
        let third = tower.third_contracted(&q);
        let fourth = tower.fourth_contracted(&u, &w);

        for a in 0..K {
            for b in 0..K {
                let mut dirs3 = [[0.0; K]; 3];
                dirs3[0][a] = 1.0;
                dirs3[1][b] = 1.0;
                dirs3[2] = q;
                let jet3 = nonlinear_dirjet_program(p, &dirs3);
                assert_close(
                    jet3.coeff(jet3.coeffs.len() - 1),
                    third[a][b],
                    &format!("third contraction ({a},{b})"),
                );

                let mut dirs4 = [[0.0; K]; 4];
                dirs4[0][a] = 1.0;
                dirs4[1][b] = 1.0;
                dirs4[2] = u;
                dirs4[3] = w;
                let jet4 = nonlinear_dirjet_program(p, &dirs4);
                assert_close(
                    jet4.coeff(jet4.coeffs.len() - 1),
                    fourth[a][b],
                    &format!("fourth contraction ({a},{b})"),
                );
            }
        }
    }

    // ── Direct checks of the shared kernel, independent of either layout ──
    // All expected values are small integers, so exact comparison is safe.

    /// With every `f^{(r)}` and every inner derivative equal to one, the
    /// Faà di Bruno walk counts set partitions: the Bell numbers B_0..B_8.
    #[test]
    fn faa_di_bruno_counts_bell_numbers() {
        const BELL: [f64; 9] = [1.0, 1.0, 2.0, 5.0, 15.0, 52.0, 203.0, 877.0, 4140.0];
        let derivs = [1.0; 9];
        let slots: Vec<usize> = (0..BELL.len()).collect();
        for (m, &bell) in BELL.iter().enumerate() {
            let got = faa_di_bruno(&slots[..m], &derivs, |_| 1.0);
            assert_eq!(got, bell, "Bell number B_{m}");
        }
    }

    /// Weighting each partition by `x^{|π|}` gives the Touchard polynomial
    /// `Σ_k S(m, k) x^k`; for `m = 4`, `x = 2`: `2 + 7·4 + 6·8 + 16 = 94`.
    #[test]
    fn faa_di_bruno_weights_partitions_by_block_count() {
        let derivs = [1.0, 2.0, 4.0, 8.0, 16.0];
        assert_eq!(faa_di_bruno(&[0, 1, 2, 3], &derivs, |_| 1.0), 94.0);
    }

    /// Partitions needing `f^{(r)}` with `r ≥ derivs.len()` are dropped:
    /// three slots with a second-order stack keep S(3,1) + S(3,2) = 4 terms.
    #[test]
    fn faa_di_bruno_truncates_high_order_partitions() {
        assert_eq!(faa_di_bruno(&[0, 1, 2], &[1.0, 1.0, 1.0], |_| 1.0), 4.0);
    }

    /// Blocks reach `inner` as axis labels (not slot indices), and a repeated
    /// axis reproduces the one-variable rule `f′·u″ + f″·u′²`.
    #[test]
    fn faa_di_bruno_passes_axis_labels_and_matches_chain_rule() {
        // Labels 4 and 9: {4,9} contributes f′·13, {4}{9} contributes f″·4·9.
        let by_label = faa_di_bruno(&[4, 9], &[0.0, 1.0, 1.0], |block| {
            block.iter().map(|&p| p as f64).sum::<f64>()
        });
        assert_eq!(by_label, 13.0 + 36.0);

        // u′ = 3, u″ = 5, f′ = 11, f″ = 13 → 11·5 + 13·3² = 172.
        let chain = faa_di_bruno(&[0, 0], &[7.0, 11.0, 13.0], |block| match block.len() {
            1 => 3.0,
            2 => 5.0,
            _ => unreachable!("partition blocks are never empty"),
        });
        assert_eq!(chain, 172.0);
    }

    /// Constant-one factors count the `2^m` subsets; subsets and complements
    /// arrive in mask order with slot order preserved; and on a repeated axis
    /// the rule reproduces `a″b + 2a′b′ + ab″`.
    #[test]
    fn leibniz_product_sums_over_subsets() {
        for m in 0..=8usize {
            let slots: Vec<usize> = (0..m).collect();
            let got = leibniz_product(&slots, |_| 1.0, |_| 1.0);
            assert_eq!(got, (1u32 << m) as f64, "2^{m} subsets");
        }

        let mut lefts: Vec<Vec<usize>> = Vec::new();
        let mut rights: Vec<Vec<usize>> = Vec::new();
        let _ = leibniz_product(
            &[3, 8],
            |t| {
                lefts.push(t.to_vec());
                1.0
            },
            |c| {
                rights.push(c.to_vec());
                1.0
            },
        );
        assert_eq!(lefts, vec![vec![], vec![3], vec![8], vec![3, 8]]);
        assert_eq!(rights, vec![vec![3, 8], vec![8], vec![3], vec![]]);

        // a = [a, a′, a″] = [1, 2, 3], b = [5, 7, 11]: 3·5 + 2·2·7 + 1·11 = 54.
        let a = [1.0, 2.0, 3.0];
        let b = [5.0, 7.0, 11.0];
        let got = leibniz_product(&[0, 0], |t| a[t.len()], |c| b[c.len()]);
        assert_eq!(got, 54.0);
    }

    fn nonlinear_tower_program(p: [f64; 3]) -> Tower4<3> {
        let x = Tower4::<3>::variable(p[0], 0);
        let y = Tower4::<3>::variable(p[1], 1);
        let z = Tower4::<3>::variable(p[2], 2);
        let eta = x * y + x * z + z * 0.7;
        let g = eta.exp();
        (g + 2.0).ln() * g
    }

    fn nonlinear_dirjet_program(p: [f64; 3], dirs: &[[f64; 3]]) -> MultiDirJet {
        let n_dirs = dirs.len();
        let x = MultiDirJet::linear(n_dirs, p[0], &direction_components(dirs, 0));
        let y = MultiDirJet::linear(n_dirs, p[1], &direction_components(dirs, 1));
        let z = MultiDirJet::linear(n_dirs, p[2], &direction_components(dirs, 2));
        let eta = x.mul(&y).add(&x.mul(&z)).add(&z.scale(0.7));
        let g = exp_dirjet(&eta);
        ln_dirjet(&g.add(&MultiDirJet::constant(n_dirs, 2.0))).mul(&g)
    }

    fn direction_components(dirs: &[[f64; 3]], axis: usize) -> Vec<f64> {
        dirs.iter().map(|dir| dir[axis]).collect()
    }

    fn assert_close(got: f64, want: f64, label: &str) {
        let tol = 1.0e-12 * want.abs().max(1.0);
        assert!(
            (got - want).abs() <= tol,
            "{label}: got={got:.17e}, want={want:.17e}, diff={:.3e}, tol={tol:.3e}",
            (got - want).abs()
        );
    }

    fn exp_dirjet(j: &MultiDirJet) -> MultiDirJet {
        let e = j.coeff(0).exp();
        j.compose_unary([e, e, e, e, e])
    }

    fn ln_dirjet(j: &MultiDirJet) -> MultiDirJet {
        let u = j.coeff(0);
        let r = 1.0 / u;
        j.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
    }
}