use std::cell::RefCell;
use std::sync::atomic::{AtomicU64, Ordering};
use wide::{CmpGe, f64x4};
/// Process-wide count of `MultiDirJet::compose_unary` invocations
/// (incremented with `Ordering::Relaxed`).
pub static COMPOSE_UNARY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Process-wide count of `MultiDirJet::mul` invocations
/// (incremented with `Ordering::Relaxed`).
pub static MUL_CALLS: AtomicU64 = AtomicU64::new(0);
/// Number of outer-function derivatives (orders 0 through 4) that
/// `compose_unary` accepts.
const DERIVS: usize = 5;
/// A truncated multilinear jet over `n_dirs` independent directions.
///
/// Coefficients are addressed by a bitmask of directions:
/// - `coeffs[0]` is the value.
/// - `coeffs[1 << i]` is the first-order coefficient along direction `i`.
/// - A mask with several bits set holds the mixed coefficient across exactly
///   those directions.
///
/// Each direction appears at most once in a product term, so the vector has
/// `1 << n_dirs` entries.
#[derive(Clone)]
pub struct MultiDirJet {
    /// Coefficients indexed by direction bitmask.
    pub coeffs: Vec<f64>,
}
impl MultiDirJet {
    /// All-zero jet over `n_dirs` directions (`1 << n_dirs` coefficients).
    pub fn zero(n_dirs: usize) -> Self {
        let len = 1usize << n_dirs;
        Self {
            coeffs: vec![0.0; len],
        }
    }

    /// Jet whose value is `value` and whose every derivative coefficient is zero.
    pub fn constant(n_dirs: usize, value: f64) -> Self {
        let mut jet = Self::zero(n_dirs);
        jet.coeffs[0] = value;
        jet
    }

    /// Jet with value `base` and first-order coefficient `first[i]` along
    /// direction `i`.
    ///
    /// Directions without an entry in `first` keep a zero slope. Entries of
    /// `first` beyond `n_dirs` are ignored.
    pub fn linear(n_dirs: usize, base: f64, first: &[f64]) -> Self {
        let mut jet = Self::constant(n_dirs, base);
        for (dir, &slope) in (0..n_dirs).zip(first) {
            jet.coeffs[1usize << dir] = slope;
        }
        jet
    }

    /// Zero jet with the listed `(mask, value)` coefficients filled in.
    ///
    /// Masks outside `0..1 << n_dirs` are silently ignored.
    pub fn with_coeffs(n_dirs: usize, coeffs: &[(usize, f64)]) -> Self {
        let mut jet = Self::zero(n_dirs);
        for &(mask, value) in coeffs {
            if let Some(slot) = jet.coeffs.get_mut(mask) {
                *slot = value;
            }
        }
        jet
    }

    /// Coefficient stored at direction bitmask `mask`; panics if out of range.
    #[inline]
    pub fn coeff(&self, mask: usize) -> f64 {
        self.coeffs[mask]
    }

    /// Coefficient-wise sum. If lengths differ, the result has the shorter length.
    pub fn add(&self, other: &Self) -> Self {
        let coeffs = self
            .coeffs
            .iter()
            .zip(&other.coeffs)
            .map(|(&lhs, &rhs)| lhs + rhs)
            .collect();
        Self { coeffs }
    }

    /// Multiplies every coefficient by `scalar`.
    pub fn scale(&self, scalar: f64) -> Self {
        let coeffs = self.coeffs.iter().map(|&value| scalar * value).collect();
        Self { coeffs }
    }

    /// Product of two jets in which every direction squares to zero:
    /// `out[mask] = Σ_{sub ⊆ mask} self[sub] * other[mask ^ sub]`, summed over
    /// submasks in increasing order.
    ///
    /// `other` must have at least as many coefficients as `self`.
    /// Single-coefficient jets go through the slot-based `mul_reference`.
    /// Each call bumps `MUL_CALLS`.
    pub fn mul(&self, other: &Self) -> Self {
        MUL_CALLS.fetch_add(1, Ordering::Relaxed);
        let len = self.coeffs.len();
        if len <= 1 {
            return self.mul_reference(other);
        }
        let (lhs, rhs) = (&self.coeffs, &other.coeffs);
        let coeffs = (0..len)
            .map(|mask| {
                let mut acc = 0.0;
                let mut part = 0usize;
                loop {
                    acc += lhs[part] * rhs[mask ^ part];
                    if part == mask {
                        break acc;
                    }
                    // Advance to the next submask of `mask` in increasing order.
                    part = (part | !mask).wrapping_add(1) & mask;
                }
            })
            .collect();
        Self { coeffs }
    }

    /// Slot-based product: for each mask, hands its direction indices to
    /// `jet_algebra::leibniz_product` together with lookups into both operands.
    fn mul_reference(&self, other: &Self) -> Self {
        let coeffs = (0..self.coeffs.len())
            .map(|mask| -> f64 {
                let slots = bit_positions(mask);
                crate::jet_algebra::leibniz_product(
                    slots.as_slice(),
                    |t| self.coeffs[mask_of(t)],
                    |c| other.coeffs[mask_of(c)],
                )
            })
            .collect();
        Self { coeffs }
    }

    /// Composes an outer scalar function with this jet and returns `f(self)`.
    ///
    /// `derivs[k]` is weighted by `1/k!`, i.e. it is treated as the k-th
    /// derivative of the outer function at the base value `self.coeffs[0]`.
    /// NOTE(review): confirm this against callers.
    ///
    /// Write the jet as `x0 + v`, where `v` equals it with mask 0 zeroed. The
    /// result is `derivs[0] + Σ_{k=1..4} derivs[k] / k! · v^k`, where the
    /// powers of `v` are compensated subset convolutions. Powers above four are
    /// not formed, so coefficients whose mask spans five or more directions
    /// omit those higher-order terms.
    ///
    /// Single-coefficient jets defer to the generic
    /// `JetAlgebra::compose_unary`. The four intermediate arrays live in this
    /// thread's `COMPOSE_SCRATCH` buffer. Each call bumps `COMPOSE_UNARY_CALLS`.
    pub fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
        COMPOSE_UNARY_CALLS.fetch_add(1, Ordering::Relaxed);
        let len = self.coeffs.len();
        if len <= 1 {
            return <Self as crate::jet_algebra::JetAlgebra<DERIVS>>::compose_unary(self, derivs);
        }
        // Taylor weights derivs[k] / k! for k = 1..=4.
        let weights = [
            derivs[1],
            derivs[2] * 0.5,
            derivs[3] * (1.0 / 6.0),
            derivs[4] * (1.0 / 24.0),
        ];
        let mut coeffs = vec![0.0; len];
        COMPOSE_SCRATCH.with(|cell| {
            let mut scratch = cell.borrow_mut();
            scratch.clear();
            scratch.resize(4 * len, 0.0);
            let (nil, tail) = scratch.split_at_mut(len);
            let (square, tail) = tail.split_at_mut(len);
            let (cube, fourth) = tail.split_at_mut(len);
            nil.copy_from_slice(&self.coeffs);
            // Drop the constant term so only the nilpotent part is raised to powers.
            nil[0] = 0.0;
            // v^k is nonzero only on masks with at least k directions.
            subset_conv_into(nil, nil, square, 2);
            subset_conv_into(square, nil, cube, 3);
            subset_conv_into(square, square, fourth, 4);
            combine_powers(nil, square, cube, fourth, weights, &mut coeffs);
            coeffs[0] = derivs[0];
        });
        Self { coeffs }
    }
}
thread_local! {
    /// Per-thread scratch for `MultiDirJet::compose_unary`. It holds the
    /// nilpotent part and its 2nd/3rd/4th subset-convolution powers. The
    /// vector is cleared and resized on each call, so its allocation is reused.
    static COMPOSE_SCRATCH: RefCell<Vec<f64>> = const { RefCell::new(Vec::new()) };
}
/// Error-free addition (TwoSum): returns `(s, e)` with `s = fl(a + b)` and
/// `a + b == s + e` exactly, barring overflow.
#[inline(always)]
fn two_sum(a: f64, b: f64) -> (f64, f64) {
    let s = a + b;
    let bb = s - a;
    let e = (a - (s - bb)) + (b - bb);
    (s, e)
}

/// Writes the subset convolution of `a` and `b` into `out`:
/// `out[mask] = Σ_{sub ⊆ mask} a[sub] * b[mask ^ sub]`.
///
/// The sum is compensated: each product's rounding error (recovered with an
/// FMA) and each addition's rounding error (via `two_sum`) are gathered into
/// a correction term that is added back at the end. Masks with fewer than
/// `min_pop` set bits are written as `0.0` without being evaluated, which lets
/// callers skip entries they know to vanish.
///
/// `a` and `b` must be at least as long as `out`. Every submask of an index
/// is <= that index, so no other length condition is needed.
#[inline]
fn subset_conv_into(a: &[f64], b: &[f64], out: &mut [f64], min_pop: u32) {
    /// Adds `x * y` to the running sum `s`, accumulating the exact rounding
    /// errors of the product and of the addition into `c`.
    #[inline(always)]
    fn dot2_step(s: &mut f64, c: &mut f64, x: f64, y: f64) {
        let prod = x * y;
        let prod_err = x.mul_add(y, -prod);
        let (t, sum_err) = two_sum(*s, prod);
        *s = t;
        *c += prod_err + sum_err;
    }

    for (mask, slot) in out.iter_mut().enumerate() {
        let pop = mask.count_ones();
        if pop < min_pop {
            *slot = 0.0;
            continue;
        }
        if pop < 2 {
            // Only one or two submasks exist. The 4-way unrolled walk below
            // needs a multiple of four submasks; otherwise it would evaluate
            // `0 - 1`. Sum these terms directly instead.
            let (mut s, mut c) = (0.0f64, 0.0f64);
            let mut sub = mask;
            loop {
                dot2_step(&mut s, &mut c, a[sub], b[mask ^ sub]);
                if sub == 0 {
                    break;
                }
                sub = (sub - 1) & mask;
            }
            *slot = s + c;
            continue;
        }
        // pop >= 2, so there are 2^pop submasks: a multiple of four. Visit
        // them in decreasing order. Term i goes to accumulator i % 4 (breaks
        // the serial dependency), and submask 0 always lands on the fourth lane.
        let (mut s0, mut s1, mut s2, mut s3) = (0.0f64, 0.0f64, 0.0f64, 0.0f64);
        let (mut c0, mut c1, mut c2, mut c3) = (0.0f64, 0.0f64, 0.0f64, 0.0f64);
        let mut sub = mask;
        loop {
            dot2_step(&mut s0, &mut c0, a[sub], b[mask ^ sub]);
            sub = (sub - 1) & mask;
            dot2_step(&mut s1, &mut c1, a[sub], b[mask ^ sub]);
            sub = (sub - 1) & mask;
            dot2_step(&mut s2, &mut c2, a[sub], b[mask ^ sub]);
            sub = (sub - 1) & mask;
            dot2_step(&mut s3, &mut c3, a[sub], b[mask ^ sub]);
            if sub == 0 {
                break;
            }
            sub = (sub - 1) & mask;
        }
        // Merge the four partial sums error-free, then add all corrections once.
        let (s01, e01) = two_sum(s0, s1);
        let (s23, e23) = two_sum(s2, s3);
        let (total, etot) = two_sum(s01, s23);
        *slot = total + (etot + e01 + e23 + c0 + c1 + c2 + c3);
    }
}
/// Computes `out[i] = c[0]*p1[i] + c[1]*p2[i] + c[2]*p3[i] + c[3]*p4[i]` for
/// every index of `out`.
///
/// The four products are added left to right, and the exact rounding error of
/// each addition is collected into a compensation term that is added at the
/// end. Full groups of four indices use `f64x4` lanes, picking the Fast2Sum
/// error formula per lane by comparing magnitudes. The leftover tail uses
/// scalar `two_sum`.
///
/// Every `p*` slice must be at least as long as `out`.
#[inline]
fn combine_powers(p1: &[f64], p2: &[f64], p3: &[f64], p4: &[f64], c: [f64; 4], out: &mut [f64]) {
    let [w1, w2, w3, w4] = c;
    let (lane1, lane2, lane3, lane4) = (
        f64x4::splat(w1),
        f64x4::splat(w2),
        f64x4::splat(w3),
        f64x4::splat(w4),
    );
    let mut groups = out.chunks_exact_mut(4);
    let mut base = 0usize;
    for group in &mut groups {
        let gather =
            |src: &[f64]| f64x4::new([src[base], src[base + 1], src[base + 2], src[base + 3]]);
        let mut sum = lane1 * gather(p1);
        let mut err = f64x4::splat(0.0);
        for (weight, src) in [(lane2, p2), (lane3, p3), (lane4, p4)] {
            let addend = weight * gather(src);
            let next = sum + addend;
            // Fast2Sum is exact only when the larger-magnitude operand comes first.
            let sum_dominates = sum.abs().cmp_ge(addend.abs());
            err += sum_dominates.blend((sum - next) + addend, (addend - next) + sum);
            sum = next;
        }
        group.copy_from_slice(&(sum + err).to_array());
        base += 4;
    }
    for (offset, dst) in groups.into_remainder().iter_mut().enumerate() {
        let idx = base + offset;
        let mut sum = w1 * p1[idx];
        let mut err = 0.0f64;
        for (weight, src) in [(w2, p2), (w3, p3), (w4, p4)] {
            let (next, lost) = two_sum(sum, weight * src[idx]);
            err += lost;
            sum = next;
        }
        *dst = sum + err;
    }
}
impl crate::jet_algebra::JetAlgebra<DERIVS> for MultiDirJet {
    /// Returns the coefficient whose bitmask has a bit set for every direction
    /// index in `slots`.
    ///
    /// NOTE(review): repeated indices collapse into one bit, so `[i, i]` reads
    /// the same coefficient as `[i]`. Confirm callers never pass repeats.
    #[inline]
    fn derivative(&self, slots: &[usize]) -> f64 {
        let mask = mask_of(slots);
        self.coeffs[mask]
    }

    /// Builds a same-sized jet whose coefficient at each mask is `f` applied to
    /// that mask's direction indices in ascending order. Masks are visited in
    /// increasing order.
    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let coeffs = (0..self.coeffs.len())
            .map(|mask| {
                let slots = bit_positions(mask);
                f(slots.as_slice())
            })
            .collect();
        Self { coeffs }
    }
}
/// Test-only list of set partitions of `{0, .., m-1}`.
///
/// Entry `k` of `parts` is `(offset, n_blocks)`. The blocks of partition `k`
/// are `flat[offset..offset + n_blocks]`, each stored as a bitmask over local
/// element indices.
#[cfg(test)]
struct PartTable {
    /// Concatenated block bitmasks of every partition.
    flat: Vec<u32>,
    /// `(offset into flat, number of blocks)` for each partition.
    parts: Vec<(usize, u8)>,
}

#[cfg(test)]
thread_local! {
    /// Per-thread memo: index `m` holds the shared table for `m` elements.
    static PARTITION_TABLES: RefCell<Vec<std::rc::Rc<PartTable>>> =
        const { RefCell::new(Vec::new()) };
}
/// Test-only: returns shared handles to the partition tables for `0..=n_dirs`
/// elements (index `m` is the table for `m` elements). Any tables missing from
/// this thread's `PARTITION_TABLES` memo are built first.
#[cfg(test)]
fn partition_tables(n_dirs: usize) -> Vec<std::rc::Rc<PartTable>> {
    PARTITION_TABLES.with(|cell| {
        let mut cache = cell.borrow_mut();
        for m in cache.len()..=n_dirs {
            cache.push(std::rc::Rc::new(build_partitions(m)));
        }
        cache[..=n_dirs].iter().map(std::rc::Rc::clone).collect()
    })
}
/// Test-only reference for `MultiDirJet::compose_unary`. It evaluates the
/// partition formula directly:
/// `out[mask] = Σ_{partitions π of mask} derivs[|π|] · Π_{B ∈ π} coeffs[B]`
/// with `out[0] = derivs[0]`.
///
/// The formula ranges over the partitions listed in the partition table.
/// `coeffs.len()` is assumed to be a power of two.
#[cfg(test)]
fn compose_unary_partition_reference(coeffs: &[f64], derivs: [f64; DERIVS]) -> Vec<f64> {
    let len = coeffs.len();
    let tables = partition_tables(len.trailing_zeros() as usize);
    // remap[local] turns a block mask over the k set bits of the current mask
    // (renumbered 0..k) back into a global direction mask.
    let mut remap = vec![0usize; len];
    let mut bit_index = [0usize; usize::BITS as usize];
    let mut out = vec![0.0; len];
    for (mask, slot) in out.iter_mut().enumerate().skip(1) {
        let mut k = 0usize;
        let mut rest = mask;
        while rest != 0 {
            bit_index[k] = rest.trailing_zeros() as usize;
            k += 1;
            rest &= rest - 1;
        }
        remap[0] = 0;
        for local in 1usize..(1usize << k) {
            let lowest = local.trailing_zeros() as usize;
            remap[local] = remap[local & (local - 1)] | (1usize << bit_index[lowest]);
        }
        let table = &tables[k];
        *slot = table.parts.iter().fold(0.0, |total, &(off, order)| {
            let order = usize::from(order);
            let prod = table.flat[off..off + order]
                .iter()
                .fold(derivs[order], |acc, &local| acc * coeffs[remap[local as usize]]);
            total + prod
        });
    }
    if let Some(value) = out.first_mut() {
        *value = derivs[0];
    }
    out
}
/// Test-only: enumerates every set partition of `{0, .., m-1}` with at most
/// `DERIVS - 1` blocks.
///
/// The order is fixed: each element is first tried in each open block, in
/// block order, and then placed in a new block. Blocks are `u32` bitmasks,
/// and at most eight blocks can be staged.
#[cfg(test)]
fn build_partitions(m: usize) -> PartTable {
    /// Places `elem` and every later element, given `used` open blocks in
    /// `staged`, and records each completed partition into `table`.
    fn place(elem: usize, m: usize, staged: &mut [u32; 8], used: usize, table: &mut PartTable) {
        // Partitions needing DERIVS or more blocks have no matching derivative.
        if used >= DERIVS {
            return;
        }
        if elem == m {
            table.parts.push((table.flat.len(), used as u8));
            table.flat.extend_from_slice(&staged[..used]);
            return;
        }
        let bit = 1u32 << elem;
        for block in 0..used {
            staged[block] |= bit;
            place(elem + 1, m, staged, used, table);
            staged[block] &= !bit;
        }
        staged[used] = bit;
        place(elem + 1, m, staged, used + 1, table);
    }

    let mut table = PartTable {
        flat: Vec::new(),
        parts: Vec::new(),
    };
    place(0, m, &mut [0u32; 8], 0, &mut table);
    table
}
/// Collects the indices of the set bits of `mask`, lowest first, into a
/// `SlotBuf`. Used to turn a coefficient mask into derivative slots.
fn bit_positions(mask: usize) -> crate::jet_algebra::SlotBuf {
    let mut slots = crate::jet_algebra::SlotBuf::new();
    let mut remaining = mask;
    while remaining != 0 {
        slots.push_slot(remaining.trailing_zeros() as usize);
        // Clear the lowest set bit.
        remaining &= remaining - 1;
    }
    slots
}
/// Bitmask with one bit set per direction index in `slots`. Repeated indices
/// set the same bit.
fn mask_of(slots: &[usize]) -> usize {
    let mut mask = 0usize;
    for &slot in slots {
        mask |= 1usize << slot;
    }
    mask
}
impl MultiDirJet {
    /// Two-direction jet built from its four coefficients, in mask order:
    /// value (0), direction 0 (1), direction 1 (2), mixed term (3).
    pub fn bilinear(base: f64, d1: f64, d2: f64, d12: f64) -> Self {
        let coeffs = [base, d1, d2, d12].to_vec();
        Self { coeffs }
    }

    /// Coefficient-wise difference `self - other`. If lengths differ, the
    /// result has the shorter length.
    pub fn sub(&self, other: &Self) -> Self {
        let coeffs = self
            .coeffs
            .iter()
            .zip(&other.coeffs)
            .map(|(&lhs, &rhs)| lhs - rhs)
            .collect();
        Self { coeffs }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for MultiDirJet constructors and arithmetic, plus accuracy and
    // timing comparisons of the fast `compose_unary` against the partition-sum
    // reference (`compose_unary_partition_reference`).
    use super::*;
    // --- Constructors ------------------------------------------------------
    #[test]
    fn zero_has_correct_length_and_all_zero_coefficients() {
        let j = MultiDirJet::zero(3);
        assert_eq!(j.coeffs.len(), 8);
        assert!(j.coeffs.iter().all(|&v| v == 0.0));
    }
    #[test]
    fn constant_has_value_at_mask_zero_and_zeros_elsewhere() {
        let j = MultiDirJet::constant(2, 5.0);
        assert_eq!(j.coeffs.len(), 4);
        assert_eq!(j.coeff(0), 5.0);
        assert_eq!(j.coeff(1), 0.0);
        assert_eq!(j.coeff(2), 0.0);
        assert_eq!(j.coeff(3), 0.0);
    }
    #[test]
    fn linear_sets_base_and_per_direction_slots() {
        let j = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        // Direction i lives at mask 1 << i; the mixed slot (mask 3) stays zero.
        assert_eq!(j.coeff(0), 1.0);
        assert_eq!(j.coeff(1), 2.0);
        assert_eq!(j.coeff(2), 3.0);
        assert_eq!(j.coeff(3), 0.0);
    }
    #[test]
    fn bilinear_sets_all_four_slots() {
        let j = MultiDirJet::bilinear(1.0, 2.0, 3.0, 4.0);
        assert_eq!(j.coeff(0), 1.0);
        assert_eq!(j.coeff(1), 2.0);
        assert_eq!(j.coeff(2), 3.0);
        assert_eq!(j.coeff(3), 4.0);
    }
    #[test]
    fn with_coeffs_sets_only_specified_entries() {
        let j = MultiDirJet::with_coeffs(2, &[(0, 9.0), (3, -1.0)]);
        assert_eq!(j.coeff(0), 9.0);
        assert_eq!(j.coeff(1), 0.0);
        assert_eq!(j.coeff(2), 0.0);
        assert_eq!(j.coeff(3), -1.0);
    }
    // --- Elementwise operations --------------------------------------------
    #[test]
    fn add_is_elementwise() {
        let a = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        let b = MultiDirJet::linear(2, 4.0, &[5.0, 6.0]);
        let c = a.add(&b);
        assert_eq!(c.coeff(0), 5.0);
        assert_eq!(c.coeff(1), 7.0);
        assert_eq!(c.coeff(2), 9.0);
        assert_eq!(c.coeff(3), 0.0);
    }
    #[test]
    fn scale_multiplies_all_coefficients() {
        let j = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        let s = j.scale(2.0);
        assert_eq!(s.coeff(0), 2.0);
        assert_eq!(s.coeff(1), 4.0);
        assert_eq!(s.coeff(2), 6.0);
        assert_eq!(s.coeff(3), 0.0);
    }
    #[test]
    fn sub_is_elementwise_difference() {
        let a = MultiDirJet::constant(2, 5.0);
        let b = MultiDirJet::constant(2, 3.0);
        let c = a.sub(&b);
        assert_eq!(c.coeff(0), 2.0);
        assert_eq!(c.coeff(1), 0.0);
        assert_eq!(c.coeff(2), 0.0);
        assert_eq!(c.coeff(3), 0.0);
    }
    // --- Multiplication ------------------------------------------------------
    #[test]
    fn mul_of_constants_is_scalar_product() {
        let a = MultiDirJet::constant(2, 2.0);
        let b = MultiDirJet::constant(2, 3.0);
        let c = a.mul(&b);
        assert_eq!(c.coeff(0), 6.0);
        assert_eq!(c.coeff(1), 0.0);
        assert_eq!(c.coeff(2), 0.0);
        assert_eq!(c.coeff(3), 0.0);
    }
    #[test]
    fn mul_satisfies_leibniz_rule_single_direction() {
        // (1 + e)(1 + e) = 1 + 2e because e * e = 0 in this algebra.
        let x = MultiDirJet::linear(1, 1.0, &[1.0]);
        let y = MultiDirJet::linear(1, 1.0, &[1.0]);
        let z = x.mul(&y);
        assert_eq!(z.coeff(0), 1.0);
        assert_eq!(z.coeff(1), 2.0);
    }
    #[test]
    fn mul_cross_term_two_independent_directions() {
        // (1 + e0)(1 + e1) = 1 + e0 + e1 + e0*e1: the mixed term lands at mask 3.
        let x = MultiDirJet::linear(2, 1.0, &[1.0, 0.0]);
        let y = MultiDirJet::linear(2, 1.0, &[0.0, 1.0]);
        let z = x.mul(&y);
        assert_eq!(z.coeff(0), 1.0);
        assert_eq!(z.coeff(1), 1.0);
        assert_eq!(z.coeff(2), 1.0);
        assert_eq!(z.coeff(3), 1.0);
    }
    // --- Accuracy-test helpers -----------------------------------------------
    // xorshift64* generator: deterministic pseudo-random inputs without crates.
    struct Rng(u64);
    impl Rng {
        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x >> 12;
            x ^= x << 25;
            x ^= x >> 27;
            self.0 = x;
            x.wrapping_mul(0x2545F4914F6CDD1D)
        }
        // Uniform sample in [-scale, scale), built from the top 53 bits of a draw.
        fn signed(&mut self, scale: f64) -> f64 {
            let u = (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64;
            (2.0 * u - 1.0) * scale
        }
    }
    // Error-free product: returns (p, e) with a * b == p + e exactly (via FMA).
    #[inline]
    fn two_prod(a: f64, b: f64) -> (f64, f64) {
        let p = a * b;
        (p, a.mul_add(b, -p))
    }
    // Error-free addition (TwoSum): returns (s, e) with a + b == s + e exactly.
    #[inline]
    fn dd_two_sum(a: f64, b: f64) -> (f64, f64) {
        let s = a + b;
        let bb = s - a;
        (s, (a - (s - bb)) + (b - bb))
    }
    // Double-double number hi + lo, used as a higher-precision ground truth.
    #[derive(Clone, Copy)]
    struct Dd {
        hi: f64,
        lo: f64,
    }
    impl Dd {
        fn from(x: f64) -> Self {
            Self { hi: x, lo: 0.0 }
        }
        // (hi + lo) * b: exact hi * b via two_prod, the lo * b contribution folded
        // in by FMA, then renormalized so lo is the rounding error of hi.
        fn mul_f64(self, b: f64) -> Self {
            let (p, e) = two_prod(self.hi, b);
            let lo = self.lo.mul_add(b, e);
            let s = p + lo;
            Self {
                hi: s,
                lo: (p - s) + lo,
            }
        }
        // Double-double addition: TwoSum on the hi parts and on the lo parts,
        // followed by two renormalization steps.
        fn add(self, o: Self) -> Self {
            let (s, e) = dd_two_sum(self.hi, o.hi);
            let (s2, e2) = dd_two_sum(self.lo, o.lo);
            let lo = e + s2;
            let h1 = s + lo;
            let l1 = (s - h1) + lo;
            let lo2 = l1 + e2;
            let h = h1 + lo2;
            Self {
                hi: h,
                lo: (h1 - h) + lo2,
            }
        }
        // Absolute error |x - (hi + lo)|, subtracting hi first to keep lo's bits.
        fn abs_err_to(self, x: f64) -> f64 {
            ((x - self.hi) - self.lo).abs()
        }
    }
    // Same partition formula as `compose_unary_partition_reference`, but every
    // product and sum is carried in double-double, giving the "truth" values.
    fn compose_truth(coeffs: &[f64], derivs: [f64; DERIVS]) -> Vec<Dd> {
        let count = coeffs.len();
        let n_dirs = count.trailing_zeros() as usize;
        let tables = partition_tables(n_dirs);
        let mut out = vec![Dd::from(0.0); count];
        let mut remap = vec![0usize; count];
        let mut pos = [0usize; 64];
        for (mask, slot) in out.iter_mut().enumerate() {
            if mask == 0 {
                *slot = Dd::from(derivs[0]);
                continue;
            }
            // Positions of the set bits of `mask`, lowest first.
            let mut npos = 0usize;
            let mut m = mask;
            while m != 0 {
                pos[npos] = m.trailing_zeros() as usize;
                npos += 1;
                m &= m - 1;
            }
            // Map local block masks (over the npos bits) to global direction masks.
            remap[0] = 0;
            for cb in 1usize..(1usize << npos) {
                let low = cb.trailing_zeros() as usize;
                remap[cb] = remap[cb & (cb - 1)] | (1usize << pos[low]);
            }
            let table = &tables[npos];
            let mut total = Dd::from(0.0);
            for &(off, order) in table.parts.iter() {
                let order = order as usize;
                let mut prod = Dd::from(derivs[order]);
                for &cb in &table.flat[off..off + order] {
                    prod = prod.mul_f64(coeffs[remap[cb as usize]]);
                }
                total = total.add(prod);
            }
            *slot = total;
        }
        out
    }
    // Random jet a * b + a, where a and b are random linear jets. The product
    // populates the mixed (multi-direction) coefficients.
    fn random_inner(n_dirs: usize, rng: &mut Rng) -> MultiDirJet {
        let base = rng.signed(0.8);
        let first: Vec<f64> = (0..n_dirs).map(|_| rng.signed(0.6)).collect();
        let a = MultiDirJet::linear(n_dirs, base, &first);
        let b = MultiDirJet::linear(
            n_dirs,
            rng.signed(0.7),
            &(0..n_dirs).map(|_| rng.signed(0.5)).collect::<Vec<_>>(),
        );
        a.mul(&b).add(&a)
    }
    // --- compose_unary vs the partition-sum reference ------------------------
    #[test]
    fn compose_unary_matches_partition_reference_simple() {
        let j = MultiDirJet::linear(2, 0.3, &[0.5, -0.4])
            .mul(&MultiDirJet::linear(2, -0.2, &[0.1, 0.7]));
        let d = [0.9_f64, 1.1, -0.7, 0.4, -0.25];
        let got = j.compose_unary(d);
        let want = compose_unary_partition_reference(&j.coeffs, d);
        for (mask, (&g, &w)) in got.coeffs.iter().zip(want.iter()).enumerate() {
            // Relative tolerance, with an absolute floor for small coefficients.
            let tol = 1e-13 * w.abs().max(1.0);
            assert!(
                (g - w).abs() <= tol,
                "mask {mask}: got={g:.17e} want={w:.17e}"
            );
        }
    }
    // Per coefficient, the fast path's error against the double-double truth
    // must not exceed the reference's error by more than 4 * EPSILON * scale.
    // The summed error over all cases must not regress either.
    #[test]
    fn compose_unary_accuracy_beats_partition_sum_vs_double_double() {
        let mut rng = Rng(0x1234_5678_9abc_def0);
        let mut sum_new = 0.0f64;
        let mut sum_old = 0.0f64;
        for &n_dirs in &[2usize, 3, 4, 6, 8] {
            for _ in 0..200 {
                let inner = random_inner(n_dirs, &mut rng);
                let d = [
                    rng.signed(1.5),
                    rng.signed(1.5),
                    rng.signed(2.0),
                    rng.signed(3.0),
                    rng.signed(4.0),
                ];
                let new = inner.compose_unary(d);
                let old = compose_unary_partition_reference(&inner.coeffs, d);
                let truth = compose_truth(&inner.coeffs, d);
                for mask in 0..inner.coeffs.len() {
                    let en = truth[mask].abs_err_to(new.coeffs[mask]);
                    let eo = truth[mask].abs_err_to(old[mask]);
                    sum_new += en;
                    sum_old += eo;
                    let scale = truth[mask].hi.abs().max(1.0);
                    assert!(
                        en <= eo + 4.0 * f64::EPSILON * scale,
                        "K={n_dirs} mask={mask}: new_err={en:.3e} old_err={eo:.3e}"
                    );
                }
            }
        }
        assert!(
            sum_new <= sum_old,
            "aggregate error regressed: new={sum_new:.6e} old={sum_old:.6e}"
        );
        eprintln!(
            "compose_unary accuracy: total |err| new={sum_new:.6e} old={sum_old:.6e} \
            (improvement {:.2}x)",
            sum_old / sum_new.max(f64::MIN_POSITIVE)
        );
    }
    // Wall-clock comparison of the fast path against the partition-sum reference.
    // NOTE(review): the timing loops run in every build profile (only the
    // assertion is gated to non-debug builds with K >= 6), so this test is slow
    // under debug builds. A wall-clock assertion can also be flaky on loaded CI
    // machines; consider moving it to a benchmark harness.
    #[test]
    fn compose_unary_speedup_over_partition_sum() {
        use std::time::Instant;
        let mut rng = Rng(0xfeed_face_dead_beef);
        for &n_dirs in &[2usize, 4, 6, 8] {
            let n_inputs = 256usize;
            let inputs: Vec<(MultiDirJet, [f64; DERIVS])> = (0..n_inputs)
                .map(|_| {
                    (
                        random_inner(n_dirs, &mut rng),
                        [
                            rng.signed(1.5),
                            rng.signed(1.5),
                            rng.signed(2.0),
                            rng.signed(3.0),
                            rng.signed(4.0),
                        ],
                    )
                })
                .collect();
            let iters = 200usize;
            // Warm-up pass: fills the thread-local scratch and partition caches
            // before timing starts.
            for (j, d) in &inputs {
                std::hint::black_box(j.compose_unary(*d));
                std::hint::black_box(compose_unary_partition_reference(&j.coeffs, *d));
            }
            let t0 = Instant::now();
            for _ in 0..iters {
                for (j, d) in &inputs {
                    std::hint::black_box(j.compose_unary(*d));
                }
            }
            let new_ns = t0.elapsed().as_nanos() as f64 / (iters * inputs.len()) as f64;
            let t1 = Instant::now();
            for _ in 0..iters {
                for (j, d) in &inputs {
                    std::hint::black_box(compose_unary_partition_reference(&j.coeffs, *d));
                }
            }
            let old_ns = t1.elapsed().as_nanos() as f64 / (iters * inputs.len()) as f64;
            eprintln!(
                "compose_unary K={n_dirs}: new={new_ns:.1} ns/call old={old_ns:.1} ns/call \
                speedup={:.2}x",
                old_ns / new_ns
            );
            if !cfg!(debug_assertions) && n_dirs >= 6 {
                assert!(
                    new_ns < old_ns,
                    "K={n_dirs} new path slower: new={new_ns:.1}ns old={old_ns:.1}ns"
                );
            }
        }
    }
}