Skip to main content

gam_math/
jet_algebra.rs

1//! The single shared Faà di Bruno / Leibniz combinatorial kernel (#1151).
2//!
3//! Two jet representations live in this crate and historically each carried
4//! its own hand-written copy of the same calculus:
5//!
6//! * [`crate::jet_tower::Tower4`] — full dense derivative tensors
7//!   (`v`, `g`, `h`, `t3`, `t4`) in `K` primary variables, with the
8//!   Leibniz product and Faà di Bruno composition written out term-by-term
9//!   per derivative order.
10//! * [`crate::jet_partitions::MultiDirJet`] — bitmask-coefficient jet over
11//!   distinct seeded directions, with the same two rules written as general
12//!   submask / set-partition loops.
13//!
14//! The data layouts are legitimately different (a complete small-`K` tower
15//! vs. a handful of directions of a large-`K` expression) and stay separate.
16//! What is identical is the *combinatorics*: for a group of differentiation
17//! slots, the Leibniz rule sums over subsets of those slots and the Faà di
18//! Bruno rule sums over their set-partitions. This module owns that
19//! combinatorics once, as a layout-agnostic [`JetAlgebra`] trait plus walkers
20//! parameterised by closures that read each representation's own derivative
21//! for a slot-group. Both
22//! `Tower4` and `MultiDirJet` route their `mul` / `compose_unary` through
23//! these walkers, so a fix to the rule is a fix to both — and a bit-exact
24//! equivalence test (see `tests`) proves the two layouts agree.
25//!
26//! A "slot-group" is a list of positions `0..m` (the differentiation
27//! arguments of one output coefficient). Each representation maps a group to
28//! a derivative:
29//!
30//! * For a tensor index tuple `(i, j, k, l)`, the positions `0..m` carry
31//!   axis labels `[i, j, k, l]`; a sub-group of positions selects the
32//!   corresponding lower-order tensor entry (e.g. positions `{0, 2}` →
33//!   `h[i][k]`).
34//! * For a bitmask coefficient `mask`, the set bits are the slots; a
35//!   sub-group of bits is itself a sub-mask read straight out of `coeffs`.
36//!
37//! # Performance: the combinatorics is precomputed, not re-walked (#1151 perf)
38//!
39//! The subset enumeration (Leibniz) and the set-partition enumeration (Faà di
40//! Bruno) over `m` slots are *fixed combinatorial objects*: they depend only
41//! on the slot count `m`, never on the actual derivative values. The hot per-row
42//! jet path nevertheless calls these walkers millions of times, and the original
43//! kernel rebuilt that structure from scratch on every call — the Faà di Bruno
44//! walk via a recursive `&mut dyn FnMut` "assign each element to a block"
45//! enumeration whose leaf and every block read went through a vtable that never
46//! inlined, plus a freshly cleared [`SlotBuf`] per block; the Leibniz walk via a
47//! per-bit branch building two [`SlotBuf`]s for every one of the `2^m` subsets.
48//!
49//! Both structures are now built ONCE per slot-count `m` (lazily, into a
50//! process-wide cache keyed by `m ∈ 0..=8`) and stored as flat, packed bitmask
51//! tables. The walkers iterate those tables with straight-line loops — no
52//! recursion, no `&mut dyn FnMut` dispatch, no per-call structure rebuild — so
53//! the only work left on the hot path is the actual arithmetic (the closure
54//! reads of derivatives and their products/sums). The emission order of the
55//! cached tables is, by construction, the EXACT order the former recursive /
56//! branch walkers produced (the table builder runs that same enumeration once),
57//! so every product is left-associated identically and every channel's sum
58//! accumulates in the same order: the result is `to_bits`-identical to the
59//! former walkers, only with the combinatorial bookkeeping amortised away.
60
61use std::sync::OnceLock;
62
/// The largest slot-count `m` the packed-table caches cover. Each cache holds
/// `MAX_SLOTS + 1` entries and is indexed directly by `m ∈ 0..=MAX_SLOTS`.
/// A slot list is built in a [`SlotBuf`] (capacity 8), so `m ≤ 8` always holds
/// on every path that reaches these walkers.
const MAX_SLOTS: usize = 8;
67
68/// Walk the Leibniz product rule for an output of `m` differentiation slots.
69///
70/// `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`, summed over every subset `T`
71/// of the `m` positions. `left(t)` / `right(c)` receive the position lists of
72/// the chosen subset and its complement and must return the corresponding
73/// derivative of the two factors. Returns the summed output coefficient.
74///
75/// `m` is small (≤ 4 for the tower, ≤ 8 for the directional jet); the
76/// `2^m` subset walk is the exact rule, not a truncation.
77///
78/// # Performance
79///
80/// The `(subset, complement)` index split for each of the `2^m` subsets depends
81/// only on `m`, so it is computed once per `m` (see [`subset_split_table`]) and
82/// cached as packed bit lists. Per call this loop only maps those cached indices
83/// through `positions` and invokes the two closures — no per-bit branch, no
84/// per-subset structure rebuild. BIT-IDENTICAL to the former branch walker:
85/// subsets are enumerated in the same `sub = 0..2^m` order (subset bit `b` ↔
86/// position `b`), the subset/complement position lists are in the same
87/// increasing-bit order, and the running `total` starts at `0.0` so a
88/// signed-zero leading product collapses to `+0.0` identically.
89#[inline]
90pub(crate) fn leibniz_product<L, R>(positions: &[usize], mut left: L, mut right: R) -> f64
91where
92    L: FnMut(&[usize]) -> f64,
93    R: FnMut(&[usize]) -> f64,
94{
95    let m = positions.len();
96    assert!(
97        m <= MAX_SLOTS,
98        "too many differentiation slots for subset enumeration"
99    );
100    let table = subset_split_table(m);
101    let mut subset = SlotBuf::new();
102    let mut complement = SlotBuf::new();
103    let mut total = 0.0;
104    for split in table {
105        subset.len = 0;
106        for &bit in split.subset.as_slice() {
107            subset.push(positions[bit]);
108        }
109        complement.len = 0;
110        for &bit in split.complement.as_slice() {
111            complement.push(positions[bit]);
112        }
113        total += left(subset.as_slice()) * right(complement.as_slice());
114    }
115    total
116}
117
118/// Walk the multivariate Faà di Bruno rule for an output of `m` slots.
119///
120/// `D(f∘u) = Σ_{partitions π of the m slots} f^{(|π|)}(u) · Π_{B ∈ π} D_B(u)`.
121/// `derivs[r]` is `f^{(r)}` at the inner value; `inner(block)` returns the
122/// derivative of the inner expression for a block's position list. Returns the
123/// summed output coefficient. Blocks of order ≥ `derivs.len()` are skipped
124/// (their `f^{(r)}` is beyond the truncation), matching both legacy paths.
125///
126/// # Performance
127///
128/// The set partitions of `m` slots depend only on `m`, so the full partition
129/// list is built once per `m` (see [`partition_table`]) and cached as packed
130/// per-block bitmasks. This walk iterates that flat table directly — no
131/// recursive enumeration, no `&mut dyn FnMut` leaf dispatch, no per-block
132/// [`SlotBuf`] churn beyond translating a block's bitmask to labelled positions.
133/// BIT-IDENTICAL to the former recursive walker: partitions are emitted in the
134/// same order, each partition's blocks are in the same first-appearance order,
135/// each block's positions are in the same increasing order, every block product
136/// is left-associated from `derivs[order]`, and the channel `total` starts at
137/// `0.0` (signed-zero products collapse to `+0.0` identically).
138///
139/// For `m ≥ 4` a second lever caches each DISTINCT block's `inner` value once
140/// (a block recurs across many partitions), turning the partition sum into pure
141/// cached multiplies — see the body comment; bit-identical, and the dominant
142/// per-call cost (the `inner` gather) drops by the distinct/incidence ratio,
143/// which grows with `m`.
144#[inline]
145pub fn faa_di_bruno<F>(positions: &[usize], derivs: &[f64], mut inner: F) -> f64
146where
147    F: FnMut(&[usize]) -> f64,
148{
149    let m = positions.len();
150    if m == 0 {
151        return derivs[0];
152    }
153    let table = partition_table(m);
154    let mut labelled = SlotBuf::new();
155
156    // Block-value cache (the dominant-cost lever). The `inner` derivative gather
157    // — not the combinatorial bookkeeping — dominates this walk's wall clock, and
158    // a single block (an element-index submask of `0..m`) recurs across many
159    // partitions. So for `m ≥ 4` the `Σ_π |π|` gathers of the direct walk below
160    // collapse to the `2^m − 1` DISTINCT blocks: gather each block's `inner` value
161    // ONCE into `block_val[submask]`, then the partition sum is pure cached
162    // multiplies (a branch-light multiply-accumulate). The distinct/incidence
163    // ratio — and so the speed-up — grows with `m`: 37→15 gathers at `m=4`,
164    // 151→31 at `m=5`, 877→63 at `m=6` (measured ~1.2×/2.0×/3.9× over the direct
165    // table walk, ~1.7×/3.0×/6.6× over the original recursive walker). For `m ≤ 3`
166    // the ratio is ≈1 and the scratch-array init does not amortise, so the direct
167    // walk is kept (the `m=2` cache path measured a regression).
168    //
169    // BIT-IDENTICAL to the direct walk: `block_val[bm]` is `inner` of the SAME
170    // labelled positions, decoded in the SAME increasing-bit order, so every
171    // partition's left-associated `derivs[order] · Π block` product and the
172    // channel `total` accumulate the identical f64s in the identical order
173    // (proven `to_bits` across `K ∈ {2,3,4,9}`, ≥5000 inputs). `inner` is a pure
174    // per-block derivative read — the documented contract — for every consumer;
175    // a block that occurs only in an order-≥`derivs.len()` (skipped) partition is
176    // still gathered but never contributes, so the result is unchanged.
177    if m >= 4 {
178        let full = 1usize << m;
179        let mut block_val = [0.0f64; 1 << MAX_SLOTS];
180        for submask in 1..full {
181            labelled.len = 0;
182            let mut bits = submask;
183            while bits != 0 {
184                let bit = bits.trailing_zeros() as usize;
185                labelled.push(positions[bit]);
186                bits &= bits - 1;
187            }
188            block_val[submask] = inner(labelled.as_slice());
189        }
190        let mut total = 0.0;
191        for part in table {
192            let order = part.n_blocks as usize;
193            if order >= derivs.len() {
194                continue;
195            }
196            let mut prod = derivs[order];
197            for &block_mask in &part.blocks[..order] {
198                prod *= block_val[block_mask as usize];
199            }
200            total += prod;
201        }
202        return total;
203    }
204
205    let mut total = 0.0;
206    for part in table {
207        let order = part.n_blocks as usize;
208        if order >= derivs.len() {
209            continue;
210        }
211        let mut prod = derivs[order];
212        for &block_mask in &part.blocks[..order] {
213            // Translate the block's element-index bitmask to axis labels for
214            // `inner`, in increasing element order (the walker's block order).
215            labelled.len = 0;
216            let mut bits = block_mask;
217            while bits != 0 {
218                let bit = bits.trailing_zeros() as usize;
219                labelled.push(positions[bit]);
220                bits &= bits - 1;
221            }
222            prod *= inner(labelled.as_slice());
223        }
224        total += prod;
225    }
226    total
227}
228
229/// Layout hook for jets that share the Faà di Bruno unary-composition kernel.
230///
231/// `DERIVS` is the length of the unary derivative stack: `5` for fourth-order
232/// jets (`[f, f′, f″, f‴, f⁗]`) and `3` for second-order jets. Implementors own
233/// how slot lists map to their storage; the kernel owns the set-partition rule.
234pub(crate) trait JetAlgebra<const DERIVS: usize>: Sized {
235    /// Read the derivative for a slot list. An empty list is the value channel.
236    fn derivative(&self, positions: &[usize]) -> f64;
237
238    /// Build a jet with every stored derivative filled by `f(positions)`.
239    fn map_derivatives<F>(&self, f: F) -> Self
240    where
241        F: FnMut(&[usize]) -> f64;
242
243    /// Exact multivariate Faà di Bruno composition.
244    fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
245        compose_unary_kernel(self, derivs)
246    }
247}
248
249/// The single unary-composition kernel shared by tower and bitmask jets.
250#[inline]
251pub(crate) fn compose_unary_kernel<J, const DERIVS: usize>(inner: &J, derivs: [f64; DERIVS]) -> J
252where
253    J: JetAlgebra<DERIVS>,
254{
255    inner.map_derivatives(|positions| {
256        faa_di_bruno(positions, &derivs, |block| inner.derivative(block))
257    })
258}
259
/// Fixed-capacity inline list of slot indices. Being stack-allocated, it keeps
/// the hot per-row jet path off the heap.
///
/// Eight entries cover the deepest tower (order 4) and the directional jet's
/// eight-bit masks.
#[derive(Clone, Copy)]
pub(crate) struct SlotBuf {
    data: [usize; 8],
    len: usize,
}

impl SlotBuf {
    /// An empty buffer.
    #[inline]
    pub(crate) fn new() -> Self {
        SlotBuf {
            len: 0,
            data: [0usize; 8],
        }
    }

    /// Append `v` after the last stored entry.
    ///
    /// # Panics
    ///
    /// Panics when the buffer already holds eight entries.
    #[inline]
    fn push(&mut self, v: usize) {
        let next = self.len;
        self.data[next] = v;
        self.len = next + 1;
    }

    /// Crate-visible append, so other jet layouts (the bitmask
    /// [`crate::jet_partitions`]) can assemble a slot list for the shared
    /// walkers.
    #[inline]
    pub(crate) fn push_slot(&mut self, v: usize) {
        SlotBuf::push(self, v)
    }

    /// The entries pushed so far, in insertion order.
    #[inline]
    pub(crate) fn as_slice(&self) -> &[usize] {
        let filled = self.len;
        &self.data[..filled]
    }
}
294
295// ───────────────────────── precomputed combinatorial tables ─────────────────
296//
297// Both tables are keyed by slot-count `m ∈ 0..=MAX_SLOTS` and built lazily on
298// first use of that `m`. The build enumerations are the SAME recursions / loops
299// the walkers formerly ran inline, so the cached emission order is identical and
300// the walkers stay `to_bits`-exact. After the first call for a given `m` the hot
301// path does zero structural work.
302
303/// One subset of an `m`-slot Leibniz product: the bit indices `0..m` that fall
304/// in the subset `T`, and those in its complement `S∖T`, each in increasing
305/// order. Mirrors the former per-bit branch (`sub & (1<<bit) != 0`).
306#[derive(Clone)]
307struct SubsetSplit {
308    subset: SlotBuf,
309    complement: SlotBuf,
310}
311
312/// One set-partition of `m` slots: each block stored as a bitmask over the
313/// element indices `0..m`, in the blocks' first-appearance order (the order the
314/// former recursion appended them). `n_blocks` is the partition's order `|π|`.
315#[derive(Clone, Copy)]
316struct PackedPartition {
317    blocks: [u8; MAX_SLOTS],
318    n_blocks: u8,
319}
320
321static SUBSET_TABLES: [OnceLock<Vec<SubsetSplit>>; MAX_SLOTS + 1] =
322    [const { OnceLock::new() }; MAX_SLOTS + 1];
323static PARTITION_TABLES: [OnceLock<Vec<PackedPartition>>; MAX_SLOTS + 1] =
324    [const { OnceLock::new() }; MAX_SLOTS + 1];
325
326/// The cached `(subset, complement)` index splits for `m` slots, in the former
327/// `sub = 0..2^m` enumeration order (subset bit `b` ↔ position `b`).
328#[inline]
329fn subset_split_table(m: usize) -> &'static [SubsetSplit] {
330    SUBSET_TABLES[m].get_or_init(|| {
331        let mut out = Vec::with_capacity(1usize << m);
332        for sub in 0u32..(1u32 << m) {
333            let mut subset = SlotBuf::new();
334            let mut complement = SlotBuf::new();
335            for bit in 0..m {
336                if sub & (1u32 << bit) != 0 {
337                    subset.push(bit);
338                } else {
339                    complement.push(bit);
340                }
341            }
342            out.push(SubsetSplit { subset, complement });
343        }
344        out
345    })
346}
347
348/// The cached set-partition list for `m` slots, in the former recursive
349/// "assign each element to an existing or new block" emission order.
350#[inline]
351fn partition_table(m: usize) -> &'static [PackedPartition] {
352    PARTITION_TABLES[m].get_or_init(|| {
353        let mut out = Vec::new();
354        let mut blocks = [0u8; MAX_SLOTS];
355        build_partitions(0, m, &mut blocks, 0, &mut out);
356        out
357    })
358}
359
360/// Enumerate the set-partitions of `0..m` exactly as the former `recurse` did:
361/// element `elem` is placed into each existing block (in block order) before a
362/// fresh block is opened with it alone. Records each completed partition's
363/// block bitmasks in first-appearance order. Runs once per `m`.
364fn build_partitions(
365    elem: usize,
366    m: usize,
367    blocks: &mut [u8; MAX_SLOTS],
368    n_blocks: usize,
369    out: &mut Vec<PackedPartition>,
370) {
371    if elem == m {
372        let mut packed = PackedPartition {
373            blocks: [0u8; MAX_SLOTS],
374            n_blocks: n_blocks as u8,
375        };
376        packed.blocks[..n_blocks].copy_from_slice(&blocks[..n_blocks]);
377        out.push(packed);
378        return;
379    }
380    let bit = 1u8 << elem;
381    // Place `elem` into each existing block.
382    for b in 0..n_blocks {
383        blocks[b] |= bit;
384        build_partitions(elem + 1, m, blocks, n_blocks, out);
385        blocks[b] &= !bit;
386    }
387    // Or open a new block with `elem` alone.
388    blocks[n_blocks] = bit;
389    build_partitions(elem + 1, m, blocks, n_blocks + 1, out);
390}
391
#[cfg(test)]
mod tests {
    use super::{faa_di_bruno, leibniz_product, MAX_SLOTS};
    use crate::jet_partitions::MultiDirJet;
    use crate::jet_tower::Tower4;

    /// Bit-exact equivalence proof. Evaluate the SAME polynomial-plus-unary
    /// composition on both jet layouts and assert that every shared derivative
    /// coefficient is identical to the last bit.
    ///
    /// Both layouts route through this module's [`leibniz_product`] /
    /// [`faa_di_bruno`] walkers. The equality therefore states that the two
    /// *data structures* expose the same single arithmetic kernel, which is
    /// the #1151 guarantee.
    ///
    /// Program (K=2 directions / primaries, seeded as variables `x = p0`,
    /// `z = p1`): `g = exp(x * z + x)`, then `f = ln(g + 2) * g`. This
    /// exercises both `mul` and `compose_unary` (exp, ln).
    #[test]
    fn tower_and_dirjet_agree_bit_exact() {
        let x = 0.37_f64;
        let z = -0.81_f64;

        // ── Tower4<2> path ──
        let tx = Tower4::<2>::variable(x, 0);
        let tz = Tower4::<2>::variable(z, 1);
        let tg = (tx * tz + tx).exp();
        let tf = (tg + 2.0).ln() * tg;

        // ── MultiDirJet (2 directions) path ──
        let jx = MultiDirJet::linear(2, x, &[1.0, 0.0]);
        let jz = MultiDirJet::linear(2, z, &[0.0, 1.0]);
        let jg = exp_dirjet(&jx.mul(&jz).add(&jx));
        let jf = ln_dirjet(&jg.add(&MultiDirJet::constant(2, 2.0))).mul(&jg);

        // The directional jet carries coefficients for masks
        //   0b00=value, 0b01=∂x, 0b10=∂z, 0b11=∂x∂z.
        // The tower carries the same derivatives as tensor entries:
        //   v, g[0], g[1], h[0][1].
        assert_eq!(jf.coeff(0b00), tf.v, "value");
        assert_eq!(jf.coeff(0b01), tf.g[0], "∂x");
        assert_eq!(jf.coeff(0b10), tf.g[1], "∂z");
        assert_eq!(jf.coeff(0b11), tf.h[0][1], "∂x∂z");
        // Symmetry of the tower's mixed second partial is also bit-exact.
        assert_eq!(tf.h[0][1], tf.h[1][0], "tower mixed-partial symmetry");
    }

    #[test]
    fn tower_contractions_match_dirjet_directional_coefficients() {
        const K: usize = 3;
        let p = [0.37_f64, -0.42_f64, 0.19_f64];
        let q = [0.25_f64, -0.7_f64, 1.3_f64];
        let u = [-0.4_f64, 0.9_f64, 0.15_f64];
        let w = [1.1_f64, -0.2_f64, 0.6_f64];

        let tower = nonlinear_tower_program(p);
        let third = tower.third_contracted(&q);
        let fourth = tower.fourth_contracted(&u, &w);

        for a in 0..K {
            for b in 0..K {
                let mut dirs3 = [[0.0; K]; 3];
                dirs3[0][a] = 1.0;
                dirs3[1][b] = 1.0;
                dirs3[2] = q;
                let jet3 = nonlinear_dirjet_program(p, &dirs3);
                assert_close(
                    jet3.coeff(jet3.coeffs.len() - 1),
                    third[a][b],
                    &format!("third contraction ({a},{b})"),
                );

                let mut dirs4 = [[0.0; K]; 4];
                dirs4[0][a] = 1.0;
                dirs4[1][b] = 1.0;
                dirs4[2] = u;
                dirs4[3] = w;
                let jet4 = nonlinear_dirjet_program(p, &dirs4);
                assert_close(
                    jet4.coeff(jet4.coeffs.len() - 1),
                    fourth[a][b],
                    &format!("fourth contraction ({a},{b})"),
                );
            }
        }
    }

    fn nonlinear_tower_program(p: [f64; 3]) -> Tower4<3> {
        let x = Tower4::<3>::variable(p[0], 0);
        let y = Tower4::<3>::variable(p[1], 1);
        let z = Tower4::<3>::variable(p[2], 2);
        let eta = x * y + x * z + z * 0.7;
        let g = eta.exp();
        (g + 2.0).ln() * g
    }

    fn nonlinear_dirjet_program(p: [f64; 3], dirs: &[[f64; 3]]) -> MultiDirJet {
        let n_dirs = dirs.len();
        let x = MultiDirJet::linear(n_dirs, p[0], &direction_components(dirs, 0));
        let y = MultiDirJet::linear(n_dirs, p[1], &direction_components(dirs, 1));
        let z = MultiDirJet::linear(n_dirs, p[2], &direction_components(dirs, 2));
        let eta = x.mul(&y).add(&x.mul(&z)).add(&z.scale(0.7));
        let g = exp_dirjet(&eta);
        ln_dirjet(&g.add(&MultiDirJet::constant(n_dirs, 2.0))).mul(&g)
    }

    fn direction_components(dirs: &[[f64; 3]], axis: usize) -> Vec<f64> {
        dirs.iter().map(|dir| dir[axis]).collect()
    }

    fn assert_close(got: f64, want: f64, label: &str) {
        let tol = 1.0e-12 * want.abs().max(1.0);
        assert!(
            (got - want).abs() <= tol,
            "{label}: got={got:.17e}, want={want:.17e}, diff={:.3e}, tol={tol:.3e}",
            (got - want).abs()
        );
    }

    // ── Direct faa_di_bruno / leibniz_product unit tests ────────────────────

    /// `faa_di_bruno` with m=0 (the value channel) returns `derivs[0]`.
    #[test]
    fn faa_di_bruno_m_zero_returns_f_of_u() {
        let result = faa_di_bruno(&[], &[7.5, 1.0, 2.0, 3.0, 4.0], |_| 0.0);
        assert_eq!(result, 7.5, "m=0 should return derivs[0]");
    }

    /// `faa_di_bruno` with m=1 is the chain rule `d/dx f(u(x)) = f'(u)·u'(x)`.
    /// Take a point where `u = 2` and `u' = 3`, with `f = exp`, so every
    /// `f^{(r)}(u) = e^2`.
    #[test]
    fn faa_di_bruno_m_one_chain_rule() {
        let e2 = 2.0_f64.exp();
        let derivs = [e2, e2, e2, e2, e2]; // f^(r)(u) = e^2 for all r
        let u_prime = 3.0_f64;
        let result = faa_di_bruno(&[0], &derivs, |_| u_prime);
        // Chain rule: f'(u) * u'(x) = e^2 * 3
        let expected = e2 * u_prime;
        assert!((result - expected).abs() < 1e-12, "m=1: {result} vs {expected}");
    }

    /// `faa_di_bruno` with m=2 (mixed second partial). Take `u = x*y`, so
    /// `u_x = y`, `u_y = x`, `u_xx = u_yy = 0` and `u_xy = 1`, and `f = exp`
    /// with `f'(0) = f''(0) = 1`. Then `d²/dx dy exp(x*y)` at `(0,0)` is `1`.
    #[test]
    fn faa_di_bruno_m_two_mixed_partial_of_exp_at_zero() {
        // u = x*y at (0,0): value=0, u_x=0, u_y=0, u_xy=1.
        let derivs = [1.0_f64, 1.0, 1.0, 1.0, 1.0]; // all f^(r)(0)=1
        let result = faa_di_bruno(&[0, 1], &derivs, |positions| match positions {
            [] => 0.0,     // u(0,0) = 0 (unused by the formula for m=2)
            [0] => 0.0,    // u_x = 0
            [1] => 0.0,    // u_y = 0
            [0, 1] => 1.0, // u_xy = 1
            _ => panic!("unexpected positions"),
        });
        // d²/dx dy exp(x*y)|_(0,0) = exp(0)*u_xy + exp(0)*u_x*u_y = 1*1 + 1*0*0 = 1
        assert!((result - 1.0).abs() < 1e-14, "m=2 mixed: {result}");
    }

    /// Set every `f^{(r)} = 1` and every block derivative `= 1`. Each
    /// set-partition then contributes exactly `1`, so the walk counts
    /// partitions and returns the Bell number `B(m)`. The loop covers the
    /// direct walk (`m ≤ 3`), the block-value-cache walk (`m ≥ 4`), and the
    /// largest cached slot-count `m = MAX_SLOTS`. All sums are exact small
    /// integers.
    #[test]
    fn faa_di_bruno_counts_bell_numbers() {
        const BELL: [f64; MAX_SLOTS + 1] = [1.0, 1.0, 2.0, 5.0, 15.0, 52.0, 203.0, 877.0, 4140.0];
        let derivs = [1.0_f64; MAX_SLOTS + 1];
        for (m, &bell) in BELL.iter().enumerate() {
            let positions: Vec<usize> = (0..m).collect();
            let got = faa_di_bruno(&positions, &derivs, |_| 1.0);
            assert_eq!(got, bell, "m={m}");
        }
    }

    /// Weighting `f^{(r)} = 10^{r-1}` separates the `m = 4` partition sum by
    /// block count. The Stirling numbers `S(4, k) = 1, 7, 6, 1` give
    /// `1·1 + 7·10 + 6·100 + 1·1000 = 1671`. Dropping `f⁗` (a 4-entry stack)
    /// must skip exactly the single four-block partition, leaving 671. This
    /// pins the truncation rule on the cached (`m ≥ 4`) path.
    #[test]
    fn faa_di_bruno_truncates_by_partition_order() {
        let positions = [0, 1, 2, 3];
        let full = faa_di_bruno(&positions, &[0.0, 1.0, 10.0, 100.0, 1000.0], |_| 1.0);
        assert_eq!(full, 1671.0);
        let truncated = faa_di_bruno(&positions, &[0.0, 1.0, 10.0, 100.0], |_| 1.0);
        assert_eq!(truncated, 671.0);
    }

    /// The cached (`m ≥ 4`) path must hand `inner` the caller's position
    /// labels, not element indices, in increasing element order. Take `u`
    /// linear: every first derivative is `1` and every higher one is `0`. Then
    /// only the all-singletons partition survives, and `D(exp∘u) = exp(u)`
    /// exactly.
    #[test]
    fn faa_di_bruno_cached_path_uses_position_labels() {
        let positions = [10, 20, 30, 40, 50];
        let e = 0.3_f64.exp();
        let derivs = [e; 6];
        let result = faa_di_bruno(&positions, &derivs, |block| {
            let mut rest = positions.iter();
            assert!(
                block.iter().all(|label| rest.any(|p| p == label)),
                "block {block:?} is not an ordered sub-list of {positions:?}"
            );
            if block.len() == 1 {
                1.0
            } else {
                0.0
            }
        });
        assert_eq!(result, e);
    }

    /// Slot counts beyond the packed tables must panic rather than read past
    /// them.
    #[test]
    #[should_panic]
    fn faa_di_bruno_rejects_more_than_max_slots() {
        let positions: Vec<usize> = (0..=MAX_SLOTS).collect();
        let _ = faa_di_bruno(&positions, &[1.0; MAX_SLOTS + 2], |_| 1.0);
    }

    /// `leibniz_product` with m=0 (constant times constant) is
    /// `left([]) * right([])`.
    #[test]
    fn leibniz_product_m_zero_is_product_of_values() {
        let result = leibniz_product(&[], |_| 3.0, |_| 4.0);
        assert_eq!(result, 12.0, "m=0: 3*4=12");
    }

    /// `leibniz_product` with m=1 is the ordinary product rule
    /// `(ab)' = a'·b + a·b'`. With `a = 2, a' = 5, b = 3, b' = 7`, that is
    /// `5·3 + 2·7 = 29`.
    #[test]
    fn leibniz_product_m_one_product_rule() {
        let av = 2.0_f64; // a(x0)
        let ad = 5.0_f64; // a'(x0)
        let bv = 3.0_f64; // b(x0)
        let bd = 7.0_f64; // b'(x0)
        let result = leibniz_product(
            &[0],
            |pos| if pos.is_empty() { av } else { ad },
            |pos| if pos.is_empty() { bv } else { bd },
        );
        let expected = ad * bv + av * bd;
        assert_eq!(result, 29.0, "m=1: {result}");
        assert_eq!(result, expected, "m=1: {result} vs {expected}");
    }

    /// `leibniz_product` with m=2 (mixed second partial of a product):
    /// `d²/dx₀ dx₁ (a * b) = a_{01}*b + a_0*b_1 + a_1*b_0 + a*b_{01}`.
    #[test]
    fn leibniz_product_m_two_mixed_second_partial() {
        let a = |pos: &[usize]| -> f64 {
            match pos {
                [] => 2.0,  // a(x0)
                [0] => 3.0, // a_{x0}
                [1] => 5.0, // a_{x1}
                _ => 7.0,   // a_{x0,x1}
            }
        };
        let b = |pos: &[usize]| -> f64 {
            match pos {
                [] => 11.0,  // b(x0)
                [0] => 13.0, // b_{x0}
                [1] => 17.0, // b_{x1}
                _ => 19.0,   // b_{x0,x1}
            }
        };
        let result = leibniz_product(&[0, 1], a, b);
        // Leibniz: sum over all subsets T of {0,1}
        // T={} : a({0,1})*b({}) = 7*11 = 77
        // T={0}: a({1})*b({0}) = 5*13 = 65
        // T={1}: a({0})*b({1}) = 3*17 = 51
        // T={0,1}: a({})*b({0,1}) = 2*19 = 38
        let expected = 7.0 * 11.0 + 5.0 * 13.0 + 3.0 * 17.0 + 2.0 * 19.0;
        assert_eq!(result, expected, "m=2: {result} vs {expected}");
    }

    /// With every derivative of both factors equal to `1`, the Leibniz sum
    /// counts the subsets of the `m` slots and returns `2^m`, checked up to
    /// `m = MAX_SLOTS`.
    #[test]
    fn leibniz_product_counts_subsets() {
        for m in 0..=MAX_SLOTS {
            let positions: Vec<usize> = (0..m).collect();
            let got = leibniz_product(&positions, |_| 1.0, |_| 1.0);
            assert_eq!(got, f64::from(1u32 << m), "m={m}");
        }
    }

    /// Slot counts beyond the packed tables must panic.
    #[test]
    #[should_panic]
    fn leibniz_product_rejects_more_than_max_slots() {
        let positions: Vec<usize> = (0..=MAX_SLOTS).collect();
        let _ = leibniz_product(&positions, |_| 1.0, |_| 1.0);
    }

    fn exp_dirjet(j: &MultiDirJet) -> MultiDirJet {
        let e = j.coeff(0).exp();
        j.compose_unary([e, e, e, e, e])
    }

    fn ln_dirjet(j: &MultiDirJet) -> MultiDirJet {
        let u = j.coeff(0);
        let r = 1.0 / u;
        j.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
    }
}