use std::cell::RefCell;
use std::rc::Rc;
use std::sync::atomic::{AtomicU64, Ordering};
/// Counter bumped once (relaxed ordering) at the start of every `MultiDirJet::compose_unary` call.
pub static COMPOSE_UNARY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Counter bumped once (relaxed ordering) at the start of every `MultiDirJet::mul` call.
pub static MUL_CALLS: AtomicU64 = AtomicU64::new(0);
/// Length of the derivative array taken by `MultiDirJet::compose_unary` (indices `0..DERIVS`).
/// Also used as the `JetAlgebra` const parameter and as the block-count cutoff in
/// `build_partitions`, which drops any partition with `DERIVS` or more blocks.
const DERIVS: usize = 5;
/// Jet whose coefficients are indexed by a bitmask of directions.
///
/// `coeffs[mask]` is the coefficient attached to the set of directions whose bits are set in
/// `mask`; index `0` holds the base value and index `1 << i` the entry for direction `i` alone.
/// The constructors in this file allocate `1 << n_dirs` entries.
///
/// `Debug` is derived so jets can be printed in logs and in assertion failure messages.
#[derive(Clone, Debug)]
pub struct MultiDirJet {
    /// Coefficients in mask order; `coeffs[0]` is the base value.
    pub coeffs: Vec<f64>,
}
impl MultiDirJet {
/// Builds a jet over `n_dirs` directions with every coefficient equal to `0.0`.
///
/// The coefficient vector has `1 << n_dirs` entries, one per bitmask of directions.
pub fn zero(n_dirs: usize) -> Self {
    let len = 1usize << n_dirs;
    let coeffs = vec![0.0; len];
    Self { coeffs }
}
/// Builds a jet over `n_dirs` directions whose base coefficient (mask `0`) is `value`
/// and whose remaining coefficients are `0.0`.
pub fn constant(n_dirs: usize, value: f64) -> Self {
    // Mask 0 is always in range because the vector has at least one entry.
    Self::with_coeffs(n_dirs, &[(0, value)])
}
/// Builds a jet with base value `base` and single-direction coefficients taken from `first`.
///
/// `first[i]` is stored at mask `1 << i`. Entries of `first` beyond the first `n_dirs` are
/// ignored; directions without a matching entry keep a `0.0` coefficient. All coefficients
/// for masks with more than one bit set stay `0.0`.
pub fn linear(n_dirs: usize, base: f64, first: &[f64]) -> Self {
    let mut jet = Self::constant(n_dirs, base);
    let slopes = first.iter().copied().take(n_dirs);
    for (dir, slope) in slopes.enumerate() {
        jet.coeffs[1usize << dir] = slope;
    }
    jet
}
/// Builds a jet over `n_dirs` directions from explicit `(mask, value)` pairs.
///
/// Every coefficient not named in `coeffs` is `0.0`. Pairs whose mask is outside the
/// `1 << n_dirs` coefficient range are skipped; when a mask is listed more than once the
/// last value wins.
pub fn with_coeffs(n_dirs: usize, coeffs: &[(usize, f64)]) -> Self {
    let mut jet = Self::zero(n_dirs);
    for (mask, value) in coeffs.iter().copied() {
        if let Some(slot) = jet.coeffs.get_mut(mask) {
            *slot = value;
        }
    }
    jet
}
/// Returns the coefficient stored at `mask`.
///
/// # Panics
///
/// Panics if `mask` is not a valid index into the coefficient vector.
#[inline]
pub fn coeff(&self, mask: usize) -> f64 {
    let coeffs = &self.coeffs;
    coeffs[mask]
}
/// Returns the coefficient-wise sum `self + other`.
///
/// Coefficients are paired by index; if the two jets differ in length the result only has
/// as many coefficients as the shorter one.
pub fn add(&self, other: &Self) -> Self {
    let pairs = self.coeffs.iter().zip(&other.coeffs);
    let mut coeffs = Vec::with_capacity(pairs.len());
    for (lhs, rhs) in pairs {
        coeffs.push(lhs + rhs);
    }
    Self { coeffs }
}
/// Returns a copy of this jet with every coefficient multiplied by `scalar`.
pub fn scale(&self, scalar: f64) -> Self {
    let mut coeffs = self.coeffs.clone();
    for value in &mut coeffs {
        *value = scalar * *value;
    }
    Self { coeffs }
}
/// Returns the product of two jets.
///
/// For every output mask the coefficient is the sum of `self[sub] * other[mask ^ sub]` over
/// all submasks `sub` of `mask`, visited in increasing numeric order. Each call bumps
/// `MUL_CALLS`. Jets with at most one coefficient are handed to `mul_reference` instead.
///
/// # Panics
///
/// Panics on an out-of-bounds index if `other` has fewer coefficients than `self`.
pub fn mul(&self, other: &Self) -> Self {
    MUL_CALLS.fetch_add(1, Ordering::Relaxed);
    let len = self.coeffs.len();
    if len <= 1 {
        return self.mul_reference(other);
    }
    let (lhs, rhs) = (&self.coeffs, &other.coeffs);
    let coeffs: Vec<f64> = (0..len)
        .map(|mask| {
            let mut sum = 0.0;
            let mut subset = 0usize;
            loop {
                sum += lhs[subset] * rhs[mask ^ subset];
                if subset == mask {
                    break sum;
                }
                // Filling the bits outside `mask` lets the +1 carry ripple through them,
                // which steps to the next submask of `mask` in increasing order.
                subset = (subset | !mask).wrapping_add(1) & mask;
            }
        })
        .collect();
    Self { coeffs }
}
/// Product computed through `crate::jet_algebra::leibniz_product`, one output mask at a time.
///
/// `mul` falls back to this path when the jet has at most one coefficient. For each mask the
/// helper receives the mask's bit positions plus two lookups that map a slot list back to a
/// coefficient of `self` and of `other` respectively.
// NOTE(review): the exact contract of `leibniz_product` lives in `jet_algebra`; confirm there.
fn mul_reference(&self, other: &Self) -> Self {
    let coeffs: Vec<f64> = (0..self.coeffs.len())
        .map(|mask| {
            let slots = bit_positions(mask);
            crate::jet_algebra::leibniz_product(
                slots.as_slice(),
                |t| self.coeffs[mask_of(t)],
                |c| other.coeffs[mask_of(c)],
            )
        })
        .collect();
    Self { coeffs }
}
pub fn compose_unary(&self, derivs: [f64; DERIVS]) -> Self {
COMPOSE_UNARY_CALLS.fetch_add(1, Ordering::Relaxed);
let count = self.coeffs.len();
if count <= 1 {
return <Self as crate::jet_algebra::JetAlgebra<DERIVS>>::compose_unary(self, derivs);
}
let n_dirs = count.trailing_zeros() as usize;
let tables = partition_tables(n_dirs);
let coeffs = &self.coeffs;
let mut out = vec![0.0; count];
let mut remap = vec![0usize; count];
let mut pos = [0usize; usize::BITS as usize];
for (mask, slot) in out.iter_mut().enumerate() {
if mask == 0 {
*slot = derivs[0];
continue;
}
let mut npos = 0usize;
let mut m = mask;
while m != 0 {
pos[npos] = m.trailing_zeros() as usize;
npos += 1;
m &= m - 1;
}
remap[0] = 0;
for cb in 1usize..(1usize << npos) {
let low = cb.trailing_zeros() as usize;
remap[cb] = remap[cb & (cb - 1)] | (1usize << pos[low]);
}
let table = &tables[npos];
let flat = &table.flat;
let mut total = 0.0;
for &(off, order) in table.parts.iter() {
let order = order as usize;
let mut prod = derivs[order];
for &cb in &flat[off..off + order] {
prod *= coeffs[remap[cb as usize]];
}
total += prod;
}
*slot = total;
}
Self { coeffs: out }
}
}
/// Connects `MultiDirJet` to the generic jet algebra by translating between slot lists and
/// bitmask indices.
impl crate::jet_algebra::JetAlgebra<DERIVS> for MultiDirJet {
    /// Returns the coefficient whose mask has exactly the bits listed in `slots` set.
    #[inline]
    fn derivative(&self, slots: &[usize]) -> f64 {
        let mask = mask_of(slots);
        self.coeffs[mask]
    }

    /// Builds a jet of the same length by calling `f` once per mask, in increasing mask
    /// order, with the positions of that mask's set bits.
    fn map_derivatives<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&[usize]) -> f64,
    {
        let coeffs: Vec<f64> = (0..self.coeffs.len())
            .map(|mask| {
                let slots = bit_positions(mask);
                f(slots.as_slice())
            })
            .collect();
        Self { coeffs }
    }
}
/// Table of set partitions produced by `build_partitions` for a fixed element count.
struct PartTable {
/// Block masks of all partitions, stored back to back; bit `e` of a mask marks element `e`.
flat: Vec<u32>,
/// One `(offset into flat, number of blocks)` entry per partition.
parts: Vec<(usize, u8)>,
}
// Per-thread cache of partition tables: entry `m` holds `build_partitions(m)`.
// `partition_tables` grows it on demand and hands out `Rc` clones of the entries.
thread_local! {
static PARTITION_TABLES: RefCell<Vec<Rc<PartTable>>> = const { RefCell::new(Vec::new()) };
}
/// Returns shared handles to the partition tables for element counts `0..=n_dirs`.
///
/// Tables missing from the thread-local cache are built first; the returned vector is indexed
/// by element count.
fn partition_tables(n_dirs: usize) -> Vec<Rc<PartTable>> {
    PARTITION_TABLES.with(|cache| {
        let mut cached = cache.borrow_mut();
        // The cache is dense, so the next table to build always has size `cached.len()`.
        for size in cached.len()..=n_dirs {
            cached.push(Rc::new(build_partitions(size)));
        }
        let handles: Vec<Rc<PartTable>> = cached[..=n_dirs].iter().map(Rc::clone).collect();
        handles
    })
}
/// Enumerates the set partitions of `m` elements that use fewer than `DERIVS` blocks.
///
/// Each partition is recorded as a run of block masks in `flat` (bit `e` set means element
/// `e` belongs to that block) together with a `(start offset, block count)` entry in `parts`.
fn build_partitions(m: usize) -> PartTable {
    /// Places element `elem` into each open block in turn and then into a new block,
    /// recording the partition once all `m` elements have been placed.
    fn assign(elem: usize, m: usize, open: &mut Vec<u32>, table: &mut PartTable) {
        // Branches that already hold `DERIVS` blocks are abandoned.
        if open.len() >= DERIVS {
            return;
        }
        if elem == m {
            table.parts.push((table.flat.len(), open.len() as u8));
            table.flat.extend_from_slice(open);
            return;
        }
        let bit = 1u32 << elem;
        for idx in 0..open.len() {
            open[idx] |= bit;
            assign(elem + 1, m, open, table);
            open[idx] &= !bit;
        }
        open.push(bit);
        assign(elem + 1, m, open, table);
        open.pop();
    }

    let mut table = PartTable {
        flat: Vec::new(),
        parts: Vec::new(),
    };
    let mut open = Vec::with_capacity(DERIVS);
    assign(0, m, &mut open, &mut table);
    table
}
/// Collects the positions of the set bits of `mask`, lowest first, into a `SlotBuf`.
fn bit_positions(mask: usize) -> crate::jet_algebra::SlotBuf {
    let mut slots = crate::jet_algebra::SlotBuf::new();
    // Each successor clears the lowest set bit of the previous value; stop at zero.
    let remaining = std::iter::successors(Some(mask), |&m| Some(m & m.wrapping_sub(1)))
        .take_while(|&m| m != 0);
    for m in remaining {
        slots.push_slot(m.trailing_zeros() as usize);
    }
    slots
}
/// Turns a list of bit positions into the bitmask with exactly those bits set.
///
/// Repeated positions are harmless; an empty list gives `0`.
fn mask_of(slots: &[usize]) -> usize {
    let mut mask = 0usize;
    for &slot in slots {
        mask |= 1usize << slot;
    }
    mask
}
impl MultiDirJet {
    /// Builds a two-direction jet from its four coefficients: base value (mask `0`),
    /// the two single-direction entries (masks `1` and `2`) and the mixed entry (mask `3`).
    pub fn bilinear(base: f64, d1: f64, d2: f64, d12: f64) -> Self {
        let coeffs = [base, d1, d2, d12].to_vec();
        Self { coeffs }
    }

    /// Returns the coefficient-wise difference `self - other`.
    ///
    /// Coefficients are paired by index; if the two jets differ in length the result only
    /// has as many coefficients as the shorter one.
    pub fn sub(&self, other: &Self) -> Self {
        let shared = self.coeffs.len().min(other.coeffs.len());
        let mut coeffs = Vec::with_capacity(shared);
        coeffs.extend(
            self.coeffs
                .iter()
                .zip(&other.coeffs)
                .map(|(lhs, rhs)| lhs - rhs),
        );
        Self { coeffs }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `jet.coeff(mask)` against `expected[mask]` for every mask listed in `expected`.
    fn assert_coeffs(jet: &MultiDirJet, expected: &[f64]) {
        for (mask, &want) in expected.iter().enumerate() {
            assert_eq!(jet.coeff(mask), want, "coefficient at mask {}", mask);
        }
    }

    #[test]
    fn zero_has_correct_length_and_all_zero_coefficients() {
        let jet = MultiDirJet::zero(3);
        assert_eq!(jet.coeffs.len(), 8);
        assert_coeffs(&jet, &[0.0; 8]);
    }

    #[test]
    fn constant_has_value_at_mask_zero_and_zeros_elsewhere() {
        let jet = MultiDirJet::constant(2, 5.0);
        assert_eq!(jet.coeffs.len(), 4);
        assert_coeffs(&jet, &[5.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn linear_sets_base_and_per_direction_slots() {
        let jet = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        assert_coeffs(&jet, &[1.0, 2.0, 3.0, 0.0]);
    }

    #[test]
    fn bilinear_sets_all_four_slots() {
        let jet = MultiDirJet::bilinear(1.0, 2.0, 3.0, 4.0);
        assert_coeffs(&jet, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn with_coeffs_sets_only_specified_entries() {
        let jet = MultiDirJet::with_coeffs(2, &[(0, 9.0), (3, -1.0)]);
        assert_coeffs(&jet, &[9.0, 0.0, 0.0, -1.0]);
    }

    #[test]
    fn add_is_elementwise() {
        let lhs = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        let rhs = MultiDirJet::linear(2, 4.0, &[5.0, 6.0]);
        assert_coeffs(&lhs.add(&rhs), &[5.0, 7.0, 9.0, 0.0]);
    }

    #[test]
    fn scale_multiplies_all_coefficients() {
        let jet = MultiDirJet::linear(2, 1.0, &[2.0, 3.0]);
        assert_coeffs(&jet.scale(2.0), &[2.0, 4.0, 6.0, 0.0]);
    }

    #[test]
    fn sub_is_elementwise_difference() {
        let lhs = MultiDirJet::constant(2, 5.0);
        let rhs = MultiDirJet::constant(2, 3.0);
        assert_coeffs(&lhs.sub(&rhs), &[2.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn mul_of_constants_is_scalar_product() {
        let lhs = MultiDirJet::constant(2, 2.0);
        let rhs = MultiDirJet::constant(2, 3.0);
        assert_coeffs(&lhs.mul(&rhs), &[6.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn mul_satisfies_leibniz_rule_single_direction() {
        // (1 + t) * (1 + t) has value 1 and first-order coefficient 2.
        let x = MultiDirJet::linear(1, 1.0, &[1.0]);
        let y = MultiDirJet::linear(1, 1.0, &[1.0]);
        assert_coeffs(&x.mul(&y), &[1.0, 2.0]);
    }

    #[test]
    fn mul_cross_term_two_independent_directions() {
        // (1 + s) * (1 + t) = 1 + s + t + s*t, so every coefficient is 1.
        let x = MultiDirJet::linear(2, 1.0, &[1.0, 0.0]);
        let y = MultiDirJet::linear(2, 1.0, &[0.0, 1.0]);
        assert_coeffs(&x.mul(&y), &[1.0, 1.0, 1.0, 1.0]);
    }
}