// gam_math/jet_partitions.rs
use std::sync::atomic::{AtomicU64, Ordering};

/// Number of times `MultiDirJet::compose_unary` has been called.
///
/// Incremented with `Ordering::Relaxed`, so it is a plain event counter and
/// establishes no ordering with other memory operations.
pub static COMPOSE_UNARY_CALLS: AtomicU64 = AtomicU64::new(0);

/// Number of times `MultiDirJet::mul` has been called (relaxed counter).
pub static MUL_CALLS: AtomicU64 = AtomicU64::new(0);

/// A jet over several independent directions, stored densely by bitmask.
///
/// `coeffs` has `2^n_dirs` entries. Entry `mask` holds the coefficient for the
/// set of directions whose bits are set in `mask`: `0` is the base value and
/// `1 << i` is the first-order coefficient along direction `i`.
#[derive(Clone, Debug, PartialEq)]
pub struct MultiDirJet {
    /// Coefficients indexed by direction bitmask.
    pub coeffs: Vec<f64>,
}

16impl MultiDirJet {
17 pub fn zero(n_dirs: usize) -> Self {
18 Self {
19 coeffs: vec![0.0; 1usize << n_dirs],
20 }
21 }
22
23 pub fn constant(n_dirs: usize, value: f64) -> Self {
24 let mut out = Self::zero(n_dirs);
25 out.coeffs[0] = value;
26 out
27 }
28
29 pub fn linear(n_dirs: usize, base: f64, first: &[f64]) -> Self {
30 let mut out = Self::constant(n_dirs, base);
31 for (idx, &value) in first.iter().take(n_dirs).enumerate() {
32 out.coeffs[1usize << idx] = value;
33 }
34 out
35 }
36
37 pub fn with_coeffs(n_dirs: usize, coeffs: &[(usize, f64)]) -> Self {
38 let mut out = Self::zero(n_dirs);
39 for &(mask, value) in coeffs {
40 if mask < out.coeffs.len() {
41 out.coeffs[mask] = value;
42 }
43 }
44 out
45 }
46
47 #[inline]
48 pub fn coeff(&self, mask: usize) -> f64 {
49 self.coeffs[mask]
50 }
51
52 pub fn add(&self, other: &Self) -> Self {
53 Self {
54 coeffs: self
55 .coeffs
56 .iter()
57 .zip(other.coeffs.iter())
58 .map(|(lhs, rhs)| lhs + rhs)
59 .collect(),
60 }
61 }
62
63 pub fn scale(&self, scalar: f64) -> Self {
64 Self {
65 coeffs: self.coeffs.iter().map(|value| scalar * value).collect(),
66 }
67 }
68
69 pub fn mul(&self, other: &Self) -> Self {
70 MUL_CALLS.fetch_add(1, Ordering::Relaxed);
71 let count = self.coeffs.len();
72 let mut out = vec![0.0; count];
73 for (mask, slot) in out.iter_mut().enumerate() {
74 let bits = bit_positions(mask);
80 *slot = crate::jet_algebra::leibniz_product(
81 bits.as_slice(),
82 |t| self.coeffs[mask_of(t)],
83 |c| other.coeffs[mask_of(c)],
84 );
85 }
86 Self { coeffs: out }
87 }
88
89 pub fn compose_unary(&self, derivs: [f64; 5]) -> Self {
90 COMPOSE_UNARY_CALLS.fetch_add(1, Ordering::Relaxed);
91 <Self as crate::jet_algebra::JetAlgebra<5>>::compose_unary(self, derivs)
92 }
93}
94
95impl crate::jet_algebra::JetAlgebra<5> for MultiDirJet {
96 #[inline]
97 fn derivative(&self, slots: &[usize]) -> f64 {
98 self.coeffs[mask_of(slots)]
99 }
100
101 fn map_derivatives<F>(&self, mut f: F) -> Self
102 where
103 F: FnMut(&[usize]) -> f64,
104 {
105 let mut out = vec![0.0; self.coeffs.len()];
106 for (mask, value) in out.iter_mut().enumerate() {
107 let bits = bit_positions(mask);
108 *value = f(bits.as_slice());
109 }
110 Self { coeffs: out }
111 }
112}
113
114fn bit_positions(mask: usize) -> crate::jet_algebra::SlotBuf {
117 let mut out = crate::jet_algebra::SlotBuf::new();
118 let mut m = mask;
119 while m != 0 {
120 let bit = m.trailing_zeros() as usize;
121 out.push_slot(bit);
122 m &= m - 1;
123 }
124 out
125}
126
/// Packs direction indices into a bitmask; repeated indices set the same bit.
fn mask_of(slots: &[usize]) -> usize {
    let mut mask = 0usize;
    for &slot in slots {
        mask |= 1usize << slot;
    }
    mask
}

132impl MultiDirJet {
142 pub fn bilinear(base: f64, d1: f64, d2: f64, d12: f64) -> Self {
143 Self {
144 coeffs: vec![base, d1, d2, d12],
145 }
146 }
147
148 pub fn sub(&self, other: &Self) -> Self {
149 Self {
150 coeffs: self
151 .coeffs
152 .iter()
153 .zip(other.coeffs.iter())
154 .map(|(lhs, rhs)| lhs - rhs)
155 .collect(),
156 }
157 }
158}