use std::sync::OnceLock;
/// Largest number of differentiation slots the combinatorial tables support.
///
/// Partition blocks are stored as `u8` bitmasks (`PackedPartition::blocks`, built with
/// `1u8 << elem` in `build_partitions`), so there is at most one slot per bit of a `u8`.
/// `SlotBuf`'s inline capacity is also 8, so this must not grow past that either.
const MAX_SLOTS: usize = u8::BITS as usize;
/// Generalised Leibniz rule for a single coefficient of a product of two jets.
///
/// Visits all `2^m` ways (`m = positions.len()`) of splitting `positions` into a subset `S`
/// and its complement `C`, including the empty subset and the full set. It returns the sum
/// of `left(S) * right(C)`. Both slot lists keep the relative order the slots have in
/// `positions`.
///
/// # Panics
///
/// Panics if `positions.len() > MAX_SLOTS`.
#[inline]
pub(crate) fn leibniz_product<L, R>(positions: &[usize], mut left: L, mut right: R) -> f64
where
    L: FnMut(&[usize]) -> f64,
    R: FnMut(&[usize]) -> f64,
{
    let slots = positions.len();
    assert!(
        slots <= MAX_SLOTS,
        "too many differentiation slots for subset enumeration"
    );
    // Scratch buffers reused across every split to avoid per-split allocation.
    let mut lhs = SlotBuf::new();
    let mut rhs = SlotBuf::new();
    subset_split_table(slots).iter().fold(0.0, |acc, split| {
        // Translate the table's slot indices into the caller's positions.
        lhs.len = 0;
        split
            .subset
            .as_slice()
            .iter()
            .for_each(|&i| lhs.push(positions[i]));
        rhs.len = 0;
        split
            .complement
            .as_slice()
            .iter()
            .for_each(|&i| rhs.push(positions[i]));
        acc + left(lhs.as_slice()) * right(rhs.as_slice())
    })
}
#[inline]
pub fn faa_di_bruno<F>(positions: &[usize], derivs: &[f64], mut inner: F) -> f64
where
F: FnMut(&[usize]) -> f64,
{
let m = positions.len();
if m == 0 {
return derivs[0];
}
let table = partition_table(m);
let mut labelled = SlotBuf::new();
if m >= 4 {
let full = 1usize << m;
let mut block_val = [0.0f64; 1 << MAX_SLOTS];
for submask in 1..full {
labelled.len = 0;
let mut bits = submask;
while bits != 0 {
let bit = bits.trailing_zeros() as usize;
labelled.push(positions[bit]);
bits &= bits - 1;
}
block_val[submask] = inner(labelled.as_slice());
}
let mut total = 0.0;
for part in table {
let order = part.n_blocks as usize;
if order >= derivs.len() {
continue;
}
let mut prod = derivs[order];
for &block_mask in &part.blocks[..order] {
prod *= block_val[block_mask as usize];
}
total += prod;
}
return total;
}
let mut total = 0.0;
for part in table {
let order = part.n_blocks as usize;
if order >= derivs.len() {
continue;
}
let mut prod = derivs[order];
for &block_mask in &part.blocks[..order] {
labelled.len = 0;
let mut bits = block_mask;
while bits != 0 {
let bit = bits.trailing_zeros() as usize;
labelled.push(positions[bit]);
bits &= bits - 1;
}
prod *= inner(labelled.as_slice());
}
total += prod;
}
total
}
/// Common interface for jet types whose coefficients are addressed by lists of
/// differentiation slot positions.
///
/// `DERIVS` is the length of the outer-derivative array accepted by `compose_unary`.
pub(crate) trait JetAlgebra<const DERIVS: usize>: Sized {
    /// Returns the coefficient addressed by `positions`.
    fn derivative(&self, positions: &[usize]) -> f64;

    /// Produces a jet of the same type whose coefficients come from `f`.
    ///
    /// NOTE(review): `compose_unary_kernel` relies on `f(positions)` being used as the new
    /// coefficient at `positions`. Confirm this against the implementors.
    fn map_derivatives<F>(&self, f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64;

    /// Composes an outer univariate function with this jet.
    ///
    /// `derivs[k]` should be the `k`-th derivative of the outer function at this jet's
    /// value. The default delegates to `compose_unary_kernel`.
    fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
        compose_unary_kernel::<Self, DERIVS>(self, derivs)
    }
}
/// Shared implementation of `JetAlgebra::compose_unary`.
///
/// Each output coefficient is the Faà di Bruno sum (`faa_di_bruno`) over the set partitions
/// of its slot list. Block coefficients are read from `inner`. `derivs[k]` is expected to be
/// the `k`-th derivative of the outer function at `inner`'s value (`k = 0` is the function
/// value itself).
#[inline]
pub(crate) fn compose_unary_kernel<J, const DERIVS: usize>(inner: &J, derivs: [f64; DERIVS]) -> J
where
    J: JetAlgebra<DERIVS>,
{
    let outer: &[f64] = &derivs;
    let coefficient = |positions: &[usize]| {
        faa_di_bruno(positions, outer, |block: &[usize]| inner.derivative(block))
    };
    inner.map_derivatives(coefficient)
}
/// Inline, fixed-capacity list of up to eight slot indices.
///
/// Used as scratch storage so that slot lists can be built and handed out as `&[usize]`
/// without heap allocation. Pushing a ninth entry panics with an out-of-bounds index.
#[derive(Clone, Copy)]
pub(crate) struct SlotBuf {
    /// Backing storage; only the first `len` entries are in use.
    data: [usize; 8],
    /// Number of entries currently in use at the front of `data`.
    len: usize,
}

impl SlotBuf {
    /// Creates an empty buffer.
    #[inline]
    pub(crate) fn new() -> Self {
        SlotBuf {
            len: 0,
            data: [0usize; 8],
        }
    }

    /// Appends `v` after the current entries.
    #[inline]
    fn push(&mut self, v: usize) {
        let next = self.len;
        self.data[next] = v;
        self.len = next + 1;
    }

    /// Crate-visible wrapper around the private `push`.
    #[inline]
    pub(crate) fn push_slot(&mut self, v: usize) {
        self.push(v)
    }

    /// Borrows the entries pushed so far.
    #[inline]
    pub(crate) fn as_slice(&self) -> &[usize] {
        let (filled, _unused) = self.data.split_at(self.len);
        filled
    }
}
/// One split of the slot indices `0..m` into a subset and its complement.
///
/// Both halves list their indices in increasing order.
#[derive(Clone)]
struct SubsetSplit {
    /// Indices whose bit is set in the subset mask.
    subset: SlotBuf,
    /// Indices whose bit is clear in the subset mask.
    complement: SlotBuf,
}
/// One set partition of the slot indices `0..m`, stored as bitmasks.
///
/// `blocks[..n_blocks]` holds one mask per block; bit `i` set means slot `i` belongs to that
/// block. Entries from `n_blocks` onward are zero and unused.
#[derive(Copy, Clone)]
struct PackedPartition {
    /// Block bitmasks; only the first `n_blocks` entries are meaningful.
    blocks: [u8; MAX_SLOTS],
    /// Number of blocks in this partition.
    n_blocks: u8,
}
/// Lazily built subset/complement tables, one per slot count `m` in `0..=MAX_SLOTS`.
static SUBSET_TABLES: [OnceLock<Vec<SubsetSplit>>; MAX_SLOTS + 1] =
    [const { OnceLock::<Vec<SubsetSplit>>::new() }; MAX_SLOTS + 1];

/// Lazily built set-partition tables, one per slot count `m` in `0..=MAX_SLOTS`.
static PARTITION_TABLES: [OnceLock<Vec<PackedPartition>>; MAX_SLOTS + 1] =
    [const { OnceLock::<Vec<PackedPartition>>::new() }; MAX_SLOTS + 1];
/// Returns all `2^m` subset/complement splits of `0..m`, building and caching them on first
/// use.
///
/// Entry `k` corresponds to subset mask `k`: slot `i` is in the subset iff bit `i` of `k` is
/// set. Panics if `m > MAX_SLOTS`, because the index into `SUBSET_TABLES` is out of bounds.
#[inline]
fn subset_split_table(m: usize) -> &'static [SubsetSplit] {
    SUBSET_TABLES[m].get_or_init(|| {
        (0u32..(1u32 << m))
            .map(|mask| {
                let mut split = SubsetSplit {
                    subset: SlotBuf::new(),
                    complement: SlotBuf::new(),
                };
                for bit in 0..m {
                    let side = if (mask >> bit) & 1 == 1 {
                        &mut split.subset
                    } else {
                        &mut split.complement
                    };
                    side.push(bit);
                }
                split
            })
            .collect()
    })
}
/// Returns every set partition of `0..m` (in `build_partitions` order), building and caching
/// the list on first use.
///
/// Panics if `m > MAX_SLOTS`, because the index into `PARTITION_TABLES` is out of bounds.
#[inline]
fn partition_table(m: usize) -> &'static [PackedPartition] {
    let cell = &PARTITION_TABLES[m];
    cell.get_or_init(|| {
        let mut partitions = Vec::new();
        build_partitions(0, m, &mut [0u8; MAX_SLOTS], 0, &mut partitions);
        partitions
    })
}
/// Recursively enumerates the set partitions of `0..m`, appending each complete one to `out`.
///
/// On entry, elements `0..elem` have already been placed into the `n_blocks` block bitmasks
/// `blocks[..n_blocks]`. Element `elem` is first added to each existing block in turn, then
/// to a new block of its own. Partitions are therefore emitted in a fixed, deterministic
/// order; callers summing over the table rely on that order for reproducible results.
fn build_partitions(
    elem: usize,
    m: usize,
    blocks: &mut [u8; MAX_SLOTS],
    n_blocks: usize,
    out: &mut Vec<PackedPartition>,
) {
    if elem < m {
        let member = 1u8 << elem;
        for idx in 0..n_blocks {
            blocks[idx] |= member;
            build_partitions(elem + 1, m, blocks, n_blocks, out);
            // Undo the placement before trying the next block.
            blocks[idx] &= !member;
        }
        // Open a new block containing only `elem`.
        blocks[n_blocks] = member;
        build_partitions(elem + 1, m, blocks, n_blocks + 1, out);
    } else {
        // Every element is placed; snapshot the active blocks, zero-filling the rest.
        let mut packed = [0u8; MAX_SLOTS];
        packed[..n_blocks].copy_from_slice(&blocks[..n_blocks]);
        out.push(PackedPartition {
            blocks: packed,
            n_blocks: n_blocks as u8,
        });
    }
}
#[cfg(test)]
mod tests {
    // Cross-checks `MultiDirJet`, whose `exp`/`ln` here go through `compose_unary`, against
    // the separately implemented `Tower4` type from `crate::jet_tower`.
    use crate::jet_partitions::MultiDirJet;
    use crate::jet_tower::Tower4;

    /// Builds f(x, z) = g · ln(g + 2) with g = exp(x·z + x), once with `Tower4<2>` and once
    /// with two-direction `MultiDirJet`s. The value, both first partials and the mixed
    /// partial must match exactly: `assert_eq!` on `f64` is a bit-for-bit comparison.
    #[test]
    fn tower_and_dirjet_agree_bit_exact() {
        let x = 0.37_f64;
        let z = -0.81_f64;
        // NOTE(review): presumably `variable(value, index)` seeds the derivative of
        // variable `index`. Confirm in `jet_tower`.
        let tx = Tower4::<2>::variable(x, 0);
        let tz = Tower4::<2>::variable(z, 1);
        let tg = (tx * tz + tx).exp();
        let tf = (tg + 2.0).ln() * tg;
        // Direction 0 seeds x and direction 1 seeds z. Per the assertion labels below,
        // coefficient masks map as 0b01 -> ∂x, 0b10 -> ∂z, 0b11 -> ∂x∂z.
        let jx = MultiDirJet::linear(2, x, &[1.0, 0.0]);
        let jz = MultiDirJet::linear(2, z, &[0.0, 1.0]);
        let jg = exp_dirjet(&jx.mul(&jz).add(&jx));
        let jf = ln_dirjet(&jg.add(&MultiDirJet::constant(2, 2.0))).mul(&jg);
        assert_eq!(jf.coeff(0b00), tf.v, "value");
        assert_eq!(jf.coeff(0b01), tf.g[0], "∂x");
        assert_eq!(jf.coeff(0b10), tf.g[1], "∂z");
        assert_eq!(jf.coeff(0b11), tf.h[0][1], "∂x∂z");
        assert_eq!(tf.h[0][1], tf.h[1][0], "tower mixed-partial symmetry");
    }

    /// Compares `Tower4<3>`'s third- and fourth-order contractions with mixed directional
    /// coefficients of `MultiDirJet`. The function is f = g · ln(g + 2) with
    /// g = exp(x·y + x·z + 0.7·z).
    ///
    /// For every axis pair (a, b), the jet is seeded with directions (e_a, e_b, q), or with
    /// (e_a, e_b, u, w). Its last coefficient is then compared, at a 1e-12 relative
    /// tolerance, with `third[a][b]` or `fourth[a][b]`.
    #[test]
    fn tower_contractions_match_dirjet_directional_coefficients() {
        const K: usize = 3;
        let p = [0.37_f64, -0.42_f64, 0.19_f64];
        let q = [0.25_f64, -0.7_f64, 1.3_f64];
        let u = [-0.4_f64, 0.9_f64, 0.15_f64];
        let w = [1.1_f64, -0.2_f64, 0.6_f64];
        let tower = nonlinear_tower_program(p);
        // NOTE(review): presumably third[a][b] = Σ_c T3[a][b][c]·q[c] and
        // fourth[a][b] = Σ_{c,d} T4[a][b][c][d]·u[c]·w[d]. Confirm in `jet_tower`.
        let third = tower.third_contracted(&q);
        let fourth = tower.fourth_contracted(&u, &w);
        for a in 0..K {
            for b in 0..K {
                // Directions: unit vector along `a`, unit vector along `b`, then `q`.
                let mut dirs3 = [[0.0; K]; 3];
                dirs3[0][a] = 1.0;
                dirs3[1][b] = 1.0;
                dirs3[2] = q;
                let jet3 = nonlinear_dirjet_program(p, &dirs3);
                // The last coefficient is taken to be the one mixing all seeded directions.
                // This assumes coefficients are indexed by direction bitmask.
                assert_close(
                    jet3.coeff(jet3.coeffs.len() - 1),
                    third[a][b],
                    &format!("third contraction ({a},{b})"),
                );
                // Directions: unit vectors along `a` and `b`, then `u` and `w`.
                let mut dirs4 = [[0.0; K]; 4];
                dirs4[0][a] = 1.0;
                dirs4[1][b] = 1.0;
                dirs4[2] = u;
                dirs4[3] = w;
                let jet4 = nonlinear_dirjet_program(p, &dirs4);
                assert_close(
                    jet4.coeff(jet4.coeffs.len() - 1),
                    fourth[a][b],
                    &format!("fourth contraction ({a},{b})"),
                );
            }
        }
    }

    /// Evaluates f = g · ln(g + 2), g = exp(x·y + x·z + 0.7·z) at `p` with `Tower4<3>`.
    fn nonlinear_tower_program(p: [f64; 3]) -> Tower4<3> {
        let x = Tower4::<3>::variable(p[0], 0);
        let y = Tower4::<3>::variable(p[1], 1);
        let z = Tower4::<3>::variable(p[2], 2);
        let eta = x * y + x * z + z * 0.7;
        let g = eta.exp();
        (g + 2.0).ln() * g
    }

    /// Evaluates the same f as `nonlinear_tower_program` with `MultiDirJet`.
    ///
    /// Uses one jet direction per entry of `dirs`. Each input variable is seeded with its
    /// component along every direction.
    fn nonlinear_dirjet_program(p: [f64; 3], dirs: &[[f64; 3]]) -> MultiDirJet {
        let n_dirs = dirs.len();
        let x = MultiDirJet::linear(n_dirs, p[0], &direction_components(dirs, 0));
        let y = MultiDirJet::linear(n_dirs, p[1], &direction_components(dirs, 1));
        let z = MultiDirJet::linear(n_dirs, p[2], &direction_components(dirs, 2));
        let eta = x.mul(&y).add(&x.mul(&z)).add(&z.scale(0.7));
        let g = exp_dirjet(&eta);
        ln_dirjet(&g.add(&MultiDirJet::constant(n_dirs, 2.0))).mul(&g)
    }

    /// Collects the `axis` component of every direction vector.
    fn direction_components(dirs: &[[f64; 3]], axis: usize) -> Vec<f64> {
        dirs.iter().map(|dir| dir[axis]).collect()
    }

    /// Asserts |got − want| <= 1e-12 · max(|want|, 1).
    ///
    /// The check is relative for large values and absolute near zero.
    fn assert_close(got: f64, want: f64, label: &str) {
        let tol = 1.0e-12 * want.abs().max(1.0);
        assert!(
            (got - want).abs() <= tol,
            "{label}: got={got:.17e}, want={want:.17e}, diff={:.3e}, tol={tol:.3e}",
            (got - want).abs()
        );
    }

    /// exp composed with `j`: every derivative of exp at u equals exp(u).
    fn exp_dirjet(j: &MultiDirJet) -> MultiDirJet {
        let e = j.coeff(0).exp();
        j.compose_unary([e, e, e, e, e])
    }

    /// ln composed with `j`: derivatives at u are ln u, 1/u, −1/u², 2/u³, −6/u⁴.
    fn ln_dirjet(j: &MultiDirJet) -> MultiDirJet {
        let u = j.coeff(0);
        let r = 1.0 / u;
        j.compose_unary([u.ln(), r, -r * r, 2.0 * r * r * r, -6.0 * r * r * r * r])
    }
}