gam_math/jet_algebra.rs
1//! The single shared Faà di Bruno / Leibniz combinatorial kernel (#1151).
2//!
3//! Two jet representations live in this crate and historically each carried
4//! its own hand-written copy of the same calculus:
5//!
6//! * [`crate::jet_tower::Tower4`] — full dense derivative tensors
7//! (`v`, `g`, `h`, `t3`, `t4`) in `K` primary variables, with the
8//! Leibniz product and Faà di Bruno composition written out term-by-term
9//! per derivative order.
10//! * [`crate::jet_partitions::MultiDirJet`] — bitmask-coefficient jet over
11//! distinct seeded directions, with the same two rules written as general
12//! submask / set-partition loops.
13//!
14//! The data layouts are legitimately different (a complete small-`K` tower
15//! vs. a handful of directions of a large-`K` expression) and stay separate.
16//! What is identical is the *combinatorics*: for a group of differentiation
17//! slots, the Leibniz rule sums over subsets of those slots and the Faà di
18//! Bruno rule sums over their set-partitions. This module owns that
19//! combinatorics once, as a layout-agnostic [`JetAlgebra`] trait plus walkers
20//! parameterised by closures that read each representation's own derivative
21//! for a slot-group. Both
22//! `Tower4` and `MultiDirJet` route their `mul` / `compose_unary` through
23//! these walkers, so a fix to the rule is a fix to both — and a bit-exact
24//! equivalence test (see `tests`) proves the two layouts agree.
25//!
26//! A "slot-group" is a list of positions `0..m` (the differentiation
27//! arguments of one output coefficient). Each representation maps a group to
28//! a derivative:
29//!
30//! * For a tensor index tuple `(i, j, k, l)`, the positions `0..m` carry
31//! axis labels `[i, j, k, l]`; a sub-group of positions selects the
32//! corresponding lower-order tensor entry (e.g. positions `{0, 2}` →
33//! `h[i][k]`).
34//! * For a bitmask coefficient `mask`, the set bits are the slots; a
35//! sub-group of bits is itself a sub-mask read straight out of `coeffs`.
36//!
37//! # Performance: the combinatorics is precomputed, not re-walked (#1151 perf)
38//!
39//! The subset enumeration (Leibniz) and the set-partition enumeration (Faà di
40//! Bruno) over `m` slots are *fixed combinatorial objects*: they depend only
41//! on the slot count `m`, never on the actual derivative values. The hot per-row
42//! jet path nevertheless calls these walkers millions of times, and the original
43//! kernel rebuilt that structure from scratch on every call — the Faà di Bruno
44//! walk via a recursive `&mut dyn FnMut` "assign each element to a block"
45//! enumeration whose leaf and every block read went through a vtable that never
46//! inlined, plus a freshly cleared [`SlotBuf`] per block; the Leibniz walk via a
47//! per-bit branch building two [`SlotBuf`]s for every one of the `2^m` subsets.
48//!
49//! Both structures are now built ONCE per slot-count `m` (lazily, into a
50//! process-wide cache keyed by `m ∈ 0..=8`) and stored as flat, packed bitmask
51//! tables. The walkers iterate those tables with straight-line loops — no
52//! recursion, no `&mut dyn FnMut` dispatch, no per-call structure rebuild — so
53//! the only work left on the hot path is the actual arithmetic (the closure
54//! reads of derivatives and their products/sums). The emission order of the
55//! cached tables is, by construction, the EXACT order the former recursive /
56//! branch walkers produced (the table builder runs that same enumeration once),
57//! so every product is left-associated identically and every channel's sum
58//! accumulates in the same order: the result is `to_bits`-identical to the
59//! former walkers, only with the combinatorial bookkeeping amortised away.
60
61use std::sync::OnceLock;
62
/// The largest slot-count `m` the packed-table caches cover. A slot list is
/// built in a [`SlotBuf`] (capacity 8), so `m ≤ 8` is expected on every path
/// that reaches these walkers. The caches hold `MAX_SLOTS + 1` entries and are
/// indexed directly by `m ∈ 0..=MAX_SLOTS`.
const MAX_SLOTS: usize = 8;
67
68/// Walk the Leibniz product rule for an output of `m` differentiation slots.
69///
70/// `D_S(ab) = Σ_{T ⊆ S} D_T(a) · D_{S∖T}(b)`, summed over every subset `T`
71/// of the `m` positions. `left(t)` / `right(c)` receive the position lists of
72/// the chosen subset and its complement and must return the corresponding
73/// derivative of the two factors. Returns the summed output coefficient.
74///
75/// `m` is small (≤ 4 for the tower, ≤ 8 for the directional jet); the
76/// `2^m` subset walk is the exact rule, not a truncation.
77///
78/// # Performance
79///
80/// The `(subset, complement)` index split for each of the `2^m` subsets depends
81/// only on `m`, so it is computed once per `m` (see [`subset_split_table`]) and
82/// cached as packed bit lists. Per call this loop only maps those cached indices
83/// through `positions` and invokes the two closures — no per-bit branch, no
84/// per-subset structure rebuild. BIT-IDENTICAL to the former branch walker:
85/// subsets are enumerated in the same `sub = 0..2^m` order (subset bit `b` ↔
86/// position `b`), the subset/complement position lists are in the same
87/// increasing-bit order, and the running `total` starts at `0.0` so a
88/// signed-zero leading product collapses to `+0.0` identically.
89#[inline]
90pub(crate) fn leibniz_product<L, R>(positions: &[usize], mut left: L, mut right: R) -> f64
91where
92 L: FnMut(&[usize]) -> f64,
93 R: FnMut(&[usize]) -> f64,
94{
95 let m = positions.len();
96 assert!(
97 m <= MAX_SLOTS,
98 "too many differentiation slots for subset enumeration"
99 );
100 let table = subset_split_table(m);
101 let mut subset = SlotBuf::new();
102 let mut complement = SlotBuf::new();
103 let mut total = 0.0;
104 for split in table {
105 subset.len = 0;
106 for &bit in split.subset.as_slice() {
107 subset.push(positions[bit]);
108 }
109 complement.len = 0;
110 for &bit in split.complement.as_slice() {
111 complement.push(positions[bit]);
112 }
113 total += left(subset.as_slice()) * right(complement.as_slice());
114 }
115 total
116}
117
118/// Walk the multivariate Faà di Bruno rule for an output of `m` slots.
119///
120/// `D(f∘u) = Σ_{partitions π of the m slots} f^{(|π|)}(u) · Π_{B ∈ π} D_B(u)`.
121/// `derivs[r]` is `f^{(r)}` at the inner value; `inner(block)` returns the
122/// derivative of the inner expression for a block's position list. Returns the
123/// summed output coefficient. Blocks of order ≥ `derivs.len()` are skipped
124/// (their `f^{(r)}` is beyond the truncation), matching both legacy paths.
125///
126/// # Performance
127///
128/// The set partitions of `m` slots depend only on `m`, so the full partition
129/// list is built once per `m` (see [`partition_table`]) and cached as packed
130/// per-block bitmasks. This walk iterates that flat table directly — no
131/// recursive enumeration, no `&mut dyn FnMut` leaf dispatch, no per-block
132/// [`SlotBuf`] churn beyond translating a block's bitmask to labelled positions.
133/// BIT-IDENTICAL to the former recursive walker: partitions are emitted in the
134/// same order, each partition's blocks are in the same first-appearance order,
135/// each block's positions are in the same increasing order, every block product
136/// is left-associated from `derivs[order]`, and the channel `total` starts at
137/// `0.0` (signed-zero products collapse to `+0.0` identically).
138///
139/// For `m ≥ 4` a second lever caches each DISTINCT block's `inner` value once
140/// (a block recurs across many partitions), turning the partition sum into pure
141/// cached multiplies — see the body comment; bit-identical, and the dominant
142/// per-call cost (the `inner` gather) drops by the distinct/incidence ratio,
143/// which grows with `m`.
144#[inline]
145pub fn faa_di_bruno<F>(positions: &[usize], derivs: &[f64], mut inner: F) -> f64
146where
147 F: FnMut(&[usize]) -> f64,
148{
149 let m = positions.len();
150 if m == 0 {
151 return derivs[0];
152 }
153 let table = partition_table(m);
154 let mut labelled = SlotBuf::new();
155
156 // Block-value cache (the dominant-cost lever). The `inner` derivative gather
157 // — not the combinatorial bookkeeping — dominates this walk's wall clock, and
158 // a single block (an element-index submask of `0..m`) recurs across many
159 // partitions. So for `m ≥ 4` the `Σ_π |π|` gathers of the direct walk below
160 // collapse to the `2^m − 1` DISTINCT blocks: gather each block's `inner` value
161 // ONCE into `block_val[submask]`, then the partition sum is pure cached
162 // multiplies (a branch-light multiply-accumulate). The distinct/incidence
163 // ratio — and so the speed-up — grows with `m`: 37→15 gathers at `m=4`,
164 // 151→31 at `m=5`, 877→63 at `m=6` (measured ~1.2×/2.0×/3.9× over the direct
165 // table walk, ~1.7×/3.0×/6.6× over the original recursive walker). For `m ≤ 3`
166 // the ratio is ≈1 and the scratch-array init does not amortise, so the direct
167 // walk is kept (the `m=2` cache path measured a regression).
168 //
169 // BIT-IDENTICAL to the direct walk: `block_val[bm]` is `inner` of the SAME
170 // labelled positions, decoded in the SAME increasing-bit order, so every
171 // partition's left-associated `derivs[order] · Π block` product and the
172 // channel `total` accumulate the identical f64s in the identical order
173 // (proven `to_bits` across `K ∈ {2,3,4,9}`, ≥5000 inputs). `inner` is a pure
174 // per-block derivative read — the documented contract — for every consumer;
175 // a block that occurs only in an order-≥`derivs.len()` (skipped) partition is
176 // still gathered but never contributes, so the result is unchanged.
177 if m >= 4 {
178 let full = 1usize << m;
179 let mut block_val = [0.0f64; 1 << MAX_SLOTS];
180 for submask in 1..full {
181 labelled.len = 0;
182 let mut bits = submask;
183 while bits != 0 {
184 let bit = bits.trailing_zeros() as usize;
185 labelled.push(positions[bit]);
186 bits &= bits - 1;
187 }
188 block_val[submask] = inner(labelled.as_slice());
189 }
190 let mut total = 0.0;
191 for part in table {
192 let order = part.n_blocks as usize;
193 if order >= derivs.len() {
194 continue;
195 }
196 let mut prod = derivs[order];
197 for &block_mask in &part.blocks[..order] {
198 prod *= block_val[block_mask as usize];
199 }
200 total += prod;
201 }
202 return total;
203 }
204
205 let mut total = 0.0;
206 for part in table {
207 let order = part.n_blocks as usize;
208 if order >= derivs.len() {
209 continue;
210 }
211 let mut prod = derivs[order];
212 for &block_mask in &part.blocks[..order] {
213 // Translate the block's element-index bitmask to axis labels for
214 // `inner`, in increasing element order (the walker's block order).
215 labelled.len = 0;
216 let mut bits = block_mask;
217 while bits != 0 {
218 let bit = bits.trailing_zeros() as usize;
219 labelled.push(positions[bit]);
220 bits &= bits - 1;
221 }
222 prod *= inner(labelled.as_slice());
223 }
224 total += prod;
225 }
226 total
227}
228
/// Layout hook for jets that share the Faà di Bruno unary-composition kernel.
///
/// `DERIVS` is the length of the unary derivative stack: `5` for fourth-order
/// jets (`[f, f′, f″, f‴, f⁗]`) and `3` for second-order jets. Implementors own
/// how slot lists map to their storage; the kernel owns the set-partition rule.
///
/// Implementors supply the two storage accessors ([`Self::derivative`] and
/// [`Self::map_derivatives`]); [`Self::compose_unary`] is provided and simply
/// forwards to [`compose_unary_kernel`].
pub(crate) trait JetAlgebra<const DERIVS: usize>: Sized {
    /// Read the derivative for a slot list. An empty list is the value channel.
    fn derivative(&self, positions: &[usize]) -> f64;

    /// Build a jet with every stored derivative filled by `f(positions)`.
    ///
    /// `f` receives the slot list of the coefficient being filled and returns
    /// the value to store there.
    fn map_derivatives<F>(&self, f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64;

    /// Exact multivariate Faà di Bruno composition.
    ///
    /// `derivs[r]` is the `r`-th derivative of the outer unary function at this
    /// jet's value. The default body delegates to the shared kernel so every
    /// layout applies the same rule.
    fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
        compose_unary_kernel(self, derivs)
    }
}
248
249/// The single unary-composition kernel shared by tower and bitmask jets.
250#[inline]
251pub(crate) fn compose_unary_kernel<J, const DERIVS: usize>(inner: &J, derivs: [f64; DERIVS]) -> J
252where
253 J: JetAlgebra<DERIVS>,
254{
255 inner.map_derivatives(|positions| {
256 faa_di_bruno(positions, &derivs, |block| inner.derivative(block))
257 })
258}
259
/// Fixed-capacity, stack-resident list of slot indices, so the hot per-row
/// path never touches the heap. The capacity of 8 covers the deepest tower
/// (order 4) and the directional jet's eight-bit masks.
#[derive(Clone, Copy)]
pub(crate) struct SlotBuf {
    /// Backing storage; only the first `len` entries are live.
    data: [usize; 8],
    /// Number of live entries in `data`.
    len: usize,
}

impl SlotBuf {
    /// An empty buffer.
    #[inline]
    pub(crate) fn new() -> Self {
        let data = [0usize; 8];
        Self { data, len: 0 }
    }

    /// Module-internal append; panics (index out of bounds) once the buffer
    /// already holds 8 entries.
    #[inline]
    fn push(&mut self, v: usize) {
        let at = self.len;
        self.data[at] = v;
        self.len = at + 1;
    }

    /// Append a slot index. Crate-visible so other jet layouts (the bitmask
    /// [`crate::jet_partitions`]) can assemble a slot list for the shared
    /// walkers.
    #[inline]
    pub(crate) fn push_slot(&mut self, v: usize) {
        self.push(v);
    }

    /// The live entries, in insertion order.
    #[inline]
    pub(crate) fn as_slice(&self) -> &[usize] {
        let live = ..self.len;
        &self.data[live]
    }
}
294
295// ───────────────────────── precomputed combinatorial tables ─────────────────
296//
297// Both tables are keyed by slot-count `m ∈ 0..=MAX_SLOTS` and built lazily on
298// first use of that `m`. The build enumerations are the SAME recursions / loops
299// the walkers formerly ran inline, so the cached emission order is identical and
300// the walkers stay `to_bits`-exact. After the first call for a given `m` the hot
301// path does zero structural work.
302
/// One subset of an `m`-slot Leibniz product: the bit indices `0..m` that fall
/// in the subset `T`, and those in its complement `S∖T`, each in increasing
/// order. Mirrors the former per-bit branch (`sub & (1<<bit) != 0`).
///
/// The stored values are bit indices, not caller labels; the walker maps them
/// through its `positions` argument on every call.
#[derive(Clone)]
struct SubsetSplit {
    /// Bit indices belonging to the subset `T`, ascending.
    subset: SlotBuf,
    /// Bit indices belonging to the complement `S∖T`, ascending.
    complement: SlotBuf,
}
311
/// One set-partition of `m` slots: each block stored as a bitmask over the
/// element indices `0..m`, in the blocks' first-appearance order (the order the
/// former recursion appended them). `n_blocks` is the partition's order `|π|`.
#[derive(Clone, Copy)]
struct PackedPartition {
    /// Block bitmasks; only the first `n_blocks` entries are meaningful, the
    /// remaining entries are left at zero by the table builder.
    blocks: [u8; MAX_SLOTS],
    /// Number of blocks in this partition (its order `|π|`).
    n_blocks: u8,
}
320
/// Lazily built Leibniz subset-split tables, one cell per slot-count
/// `m ∈ 0..=MAX_SLOTS`; filled on first use by [`subset_split_table`].
static SUBSET_TABLES: [OnceLock<Vec<SubsetSplit>>; MAX_SLOTS + 1] =
    [const { OnceLock::new() }; MAX_SLOTS + 1];
/// Lazily built set-partition tables, one cell per slot-count
/// `m ∈ 0..=MAX_SLOTS`; filled on first use by [`partition_table`].
static PARTITION_TABLES: [OnceLock<Vec<PackedPartition>>; MAX_SLOTS + 1] =
    [const { OnceLock::new() }; MAX_SLOTS + 1];
325
326/// The cached `(subset, complement)` index splits for `m` slots, in the former
327/// `sub = 0..2^m` enumeration order (subset bit `b` ↔ position `b`).
328#[inline]
329fn subset_split_table(m: usize) -> &'static [SubsetSplit] {
330 SUBSET_TABLES[m].get_or_init(|| {
331 let mut out = Vec::with_capacity(1usize << m);
332 for sub in 0u32..(1u32 << m) {
333 let mut subset = SlotBuf::new();
334 let mut complement = SlotBuf::new();
335 for bit in 0..m {
336 if sub & (1u32 << bit) != 0 {
337 subset.push(bit);
338 } else {
339 complement.push(bit);
340 }
341 }
342 out.push(SubsetSplit { subset, complement });
343 }
344 out
345 })
346}
347
348/// The cached set-partition list for `m` slots, in the former recursive
349/// "assign each element to an existing or new block" emission order.
350#[inline]
351fn partition_table(m: usize) -> &'static [PackedPartition] {
352 PARTITION_TABLES[m].get_or_init(|| {
353 let mut out = Vec::new();
354 let mut blocks = [0u8; MAX_SLOTS];
355 build_partitions(0, m, &mut blocks, 0, &mut out);
356 out
357 })
358}
359
360/// Enumerate the set-partitions of `0..m` exactly as the former `recurse` did:
361/// element `elem` is placed into each existing block (in block order) before a
362/// fresh block is opened with it alone. Records each completed partition's
363/// block bitmasks in first-appearance order. Runs once per `m`.
364fn build_partitions(
365 elem: usize,
366 m: usize,
367 blocks: &mut [u8; MAX_SLOTS],
368 n_blocks: usize,
369 out: &mut Vec<PackedPartition>,
370) {
371 if elem == m {
372 let mut packed = PackedPartition {
373 blocks: [0u8; MAX_SLOTS],
374 n_blocks: n_blocks as u8,
375 };
376 packed.blocks[..n_blocks].copy_from_slice(&blocks[..n_blocks]);
377 out.push(packed);
378 return;
379 }
380 let bit = 1u8 << elem;
381 // Place `elem` into each existing block.
382 for b in 0..n_blocks {
383 blocks[b] |= bit;
384 build_partitions(elem + 1, m, blocks, n_blocks, out);
385 blocks[b] &= !bit;
386 }
387 // Or open a new block with `elem` alone.
388 blocks[n_blocks] = bit;
389 build_partitions(elem + 1, m, blocks, n_blocks + 1, out);
390}
391
#[cfg(test)]
mod tests {
    use crate::jet_partitions::MultiDirJet;
    use crate::jet_tower::Tower4;

    /// Bit-exact equivalence proof: evaluate the SAME polynomial-plus-unary
    /// composition on both jet layouts and assert every shared derivative
    /// coefficient is identical to the last bit. Because both layouts now
    /// route through this module's [`leibniz_product`] / [`faa_di_bruno`]
    /// walkers, the equality is a statement that the two *data structures*
    /// expose the same single arithmetic kernel — the #1151 guarantee.
    ///
    /// Program (K=2 directions / primaries, seeded as variables `x = p0`,
    /// `z = p1`): `g = exp(x * z + x)` then `f = ln(g + 2) * g`.
    /// Both `mul` and `compose_unary` (exp, ln) are exercised.
    #[test]
    fn tower_and_dirjet_agree_bit_exact() {
        let x = 0.37_f64;
        let z = -0.81_f64;

        // ── Tower4<2> path ──
        let tx = Tower4::<2>::variable(x, 0);
        let tz = Tower4::<2>::variable(z, 1);
        let tg = (tx * tz + tx).exp();
        let tf = (tg + 2.0).ln() * tg;

        // ── MultiDirJet (2 directions) path ──
        let jx = MultiDirJet::linear(2, x, &[1.0, 0.0]);
        let jz = MultiDirJet::linear(2, z, &[0.0, 1.0]);
        let jg = exp_dirjet(&jx.mul(&jz).add(&jx));
        let jf = ln_dirjet(&jg.add(&MultiDirJet::constant(2, 2.0))).mul(&jg);

        // The directional jet carries coefficients for masks
        // 0b00=value, 0b01=∂x, 0b10=∂z, 0b11=∂x∂z.
        // The tower carries the same derivatives as tensor entries:
        // v, g[0], g[1], h[0][1].
        assert_eq!(jf.coeff(0b00), tf.v, "value");
        assert_eq!(jf.coeff(0b01), tf.g[0], "∂x");
        assert_eq!(jf.coeff(0b10), tf.g[1], "∂z");
        assert_eq!(jf.coeff(0b11), tf.h[0][1], "∂x∂z");
        // Symmetry of the tower's mixed second partial is also bit-exact.
        assert_eq!(tf.h[0][1], tf.h[1][0], "tower mixed-partial symmetry");
    }

    /// Cross-checks the tower's `third_contracted` / `fourth_contracted`
    /// entries against a directional jet of the same program (K=3). For each
    /// index pair `(a, b)` the jet is seeded with the unit directions `e_a`,
    /// `e_b` plus the contraction direction(s), and its last coefficient
    /// (presumably the all-directions mask — confirm against `MultiDirJet`) is
    /// compared with the tower entry to a `1e-12` relative tolerance.
    #[test]
    fn tower_contractions_match_dirjet_directional_coefficients() {
        const K: usize = 3;
        let p = [0.37_f64, -0.42_f64, 0.19_f64];
        let q = [0.25_f64, -0.7_f64, 1.3_f64];
        let u = [-0.4_f64, 0.9_f64, 0.15_f64];
        let w = [1.1_f64, -0.2_f64, 0.6_f64];

        let tower = nonlinear_tower_program(p);
        let third = tower.third_contracted(&q);
        let fourth = tower.fourth_contracted(&u, &w);

        for a in 0..K {
            for b in 0..K {
                // Three directions: e_a, e_b, q.
                let mut dirs3 = [[0.0; K]; 3];
                dirs3[0][a] = 1.0;
                dirs3[1][b] = 1.0;
                dirs3[2] = q;
                let jet3 = nonlinear_dirjet_program(p, &dirs3);
                assert_close(
                    jet3.coeff(jet3.coeffs.len() - 1),
                    third[a][b],
                    &format!("third contraction ({a},{b})"),
                );

                // Four directions: e_a, e_b, u, w.
                let mut dirs4 = [[0.0; K]; 4];
                dirs4[0][a] = 1.0;
                dirs4[1][b] = 1.0;
                dirs4[2] = u;
                dirs4[3] = w;
                let jet4 = nonlinear_dirjet_program(p, &dirs4);
                assert_close(
                    jet4.coeff(jet4.coeffs.len() - 1),
                    fourth[a][b],
                    &format!("fourth contraction ({a},{b})"),
                );
            }
        }
    }

    /// Tower form of the test program: `η = x·y + x·z + 0.7·z`, `g = exp(η)`,
    /// result `ln(g + 2) · g`, with `x, y, z` seeded as the three primaries at
    /// `p`.
    fn nonlinear_tower_program(p: [f64; 3]) -> Tower4<3> {
        let x = Tower4::<3>::variable(p[0], 0);
        let y = Tower4::<3>::variable(p[1], 1);
        let z = Tower4::<3>::variable(p[2], 2);
        let eta = x * y + x * z + z * 0.7;
        let g = eta.exp();
        (g + 2.0).ln() * g
    }

    /// Directional-jet form of the same program as
    /// [`nonlinear_tower_program`], with one seeded direction per entry of
    /// `dirs`.
    fn nonlinear_dirjet_program(p: [f64; 3], dirs: &[[f64; 3]]) -> MultiDirJet {
        let n_dirs = dirs.len();
        let x = MultiDirJet::linear(n_dirs, p[0], &direction_components(dirs, 0));
        let y = MultiDirJet::linear(n_dirs, p[1], &direction_components(dirs, 1));
        let z = MultiDirJet::linear(n_dirs, p[2], &direction_components(dirs, 2));
        let eta = x.mul(&y).add(&x.mul(&z)).add(&z.scale(0.7));
        let g = exp_dirjet(&eta);
        ln_dirjet(&g.add(&MultiDirJet::constant(n_dirs, 2.0))).mul(&g)
    }

    /// The `axis` component of every direction, in direction order.
    fn direction_components(dirs: &[[f64; 3]], axis: usize) -> Vec<f64> {
        dirs.iter().map(|dir| dir[axis]).collect()
    }

    /// Assert `|got − want| ≤ 1e-12 · max(|want|, 1)`.
    fn assert_close(got: f64, want: f64, label: &str) {
        let tol = 1.0e-12 * want.abs().max(1.0);
        assert!(
            (got - want).abs() <= tol,
            "{label}: got={got:.17e}, want={want:.17e}, diff={:.3e}, tol={tol:.3e}",
            (got - want).abs()
        );
    }

    // ── Direct faa_di_bruno / leibniz_product unit tests ────────────────────

    use super::{faa_di_bruno, leibniz_product};

    /// `faa_di_bruno` with m=0 (constant output) returns `derivs[0]`.
    #[test]
    fn faa_di_bruno_m_zero_returns_f_of_u() {
        let result = faa_di_bruno(&[], &[7.5, 1.0, 2.0, 3.0, 4.0], |_| 0.0);
        assert_eq!(result, 7.5, "m=0 should return derivs[0]");
    }

    /// `faa_di_bruno` with m=1, single variable: d/dx f(u(x)) = f'(u) * u'(x).
    /// Uses u = 2, u' = 3 and f = exp, so every f^(r)(u) = e^2.
    #[test]
    fn faa_di_bruno_m_one_chain_rule() {
        let e2 = 2.0_f64.exp();
        let derivs = [e2, e2, e2, e2, e2]; // f^(r)(u) = e^2 for all r
        let u_prime = 3.0_f64;
        let result = faa_di_bruno(&[0], &derivs, |_| u_prime);
        // Chain rule: f'(u) * u'(x) = e^2 * 3
        let expected = e2 * u_prime;
        assert!((result - expected).abs() < 1e-12, "m=1: {result} vs {expected}");
    }

    /// `faa_di_bruno` with m=2 (mixed second partial). For u = x*y (so
    /// u_x = y, u_y = x, u_xx = u_yy = 0, u_xy = 1), and f = exp with
    /// f'(0)=1, f''(0)=1, the second partial d²/dx dy exp(x*y)|_(0,0) = 1.
    #[test]
    fn faa_di_bruno_m_two_mixed_partial_of_exp_at_zero() {
        // u = x*y at (0,0): value=0, u_x=0, u_y=0, u_xy=1.
        // f = exp, f'(0)=1, f''(0)=1.
        let derivs = [1.0_f64, 1.0, 1.0, 1.0, 1.0]; // all f^(r)(0)=1
        let result = faa_di_bruno(&[0, 1], &derivs, |positions| match positions {
            [] => 0.0,     // u(0,0) = 0 (unused by the formula for m=2)
            [0] => 0.0,    // u_x = 0
            [1] => 0.0,    // u_y = 0
            [0, 1] => 1.0, // u_xy = 1
            _ => panic!("unexpected positions"),
        });
        // d²/dx dy exp(x*y)|_(0,0) = exp(0)*u_xy + exp(0)*u_x*u_y = 1*1 + 1*0*0 = 1
        assert!((result - 1.0).abs() < 1e-14, "m=2 mixed: {result}");
    }

    /// `leibniz_product` with m=0 (constant * constant) = left([]) * right([]).
    #[test]
    fn leibniz_product_m_zero_is_product_of_values() {
        let result = leibniz_product(&[], |_| 3.0, |_| 4.0);
        assert_eq!(result, 12.0, "m=0: 3*4=12");
    }

    /// `leibniz_product` with m=1: d/dx (a(x)*b(x)) = a'*b + a*b'.
    /// Uses a = 2, a' = 5, b = 3, b' = 7, so the expected derivative is
    /// 5*3 + 2*7 = 29.
    #[test]
    fn leibniz_product_m_one_product_rule() {
        let av = 2.0_f64; // a(x0) = 2
        let ad = 5.0_f64; // a'(x0) = 5
        let bv = 3.0_f64; // b(x0) = 3
        let bd = 7.0_f64; // b'(x0) = 7
        let result = leibniz_product(
            &[0],
            |pos| if pos.is_empty() { av } else { ad },
            |pos| if pos.is_empty() { bv } else { bd },
        );
        // (a*b)' = a'*b + a*b' = 5*3 + 2*7 = 15+14 = 29
        let expected = ad * bv + av * bd;
        assert_eq!(result, expected, "m=1: {result} vs {expected}");
    }

    /// `leibniz_product` with m=2 (mixed second partial of a product).
    /// d²/dx₀ dx₁ (a * b) = a_{01}*b + a_0*b_1 + a_1*b_0 + a*b_{01}.
    #[test]
    fn leibniz_product_m_two_mixed_second_partial() {
        // Simple concrete values for a and b derivatives.
        let a = |pos: &[usize]| -> f64 {
            match pos {
                [] => 2.0,  // a(x0)
                [0] => 3.0, // a_{x0}
                [1] => 5.0, // a_{x1}
                _ => 7.0,   // a_{x0,x1}
            }
        };
        let b = |pos: &[usize]| -> f64 {
            match pos {
                [] => 11.0,  // b(x0)
                [0] => 13.0, // b_{x0}
                [1] => 17.0, // b_{x1}
                _ => 19.0,   // b_{x0,x1}
            }
        };
        let result = leibniz_product(&[0, 1], a, b);
        // Leibniz: sum over all subsets T of {0,1}; `a` reads T, `b` reads the
        // complement. Terms are listed here in the order `expected` adds them
        // (all values are small integers, so the summation order is exact):
        // T={0,1}: a({0,1})*b({}) = 7*11 = 77
        // T={1}  : a({1})*b({0}) = 5*13 = 65
        // T={0}  : a({0})*b({1}) = 3*17 = 51
        // T={}   : a({})*b({0,1}) = 2*19 = 38
        let expected = 7.0 * 11.0 + 5.0 * 13.0 + 3.0 * 17.0 + 2.0 * 19.0;
        assert_eq!(result, expected, "m=2: {result} vs {expected}");
    }

    /// `exp ∘ j` on a directional jet: every derivative of `exp` at the jet's
    /// value equals `exp(value)`.
    fn exp_dirjet(j: &MultiDirJet) -> MultiDirJet {
        let e = j.coeff(0).exp();
        j.compose_unary([e, e, e, e, e])
    }

    /// `ln ∘ j` on a directional jet, using the derivative stack of `ln` at
    /// the jet's value `u`: `ln u, 1/u, −1/u², 2/u³, −6/u⁴`.
    fn ln_dirjet(j: &MultiDirJet) -> MultiDirJet {
        let u = j.coeff(0);
        let r = 1.0 / u;
        j.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
    }
}