Skip to main content

gam_math/
jet_partitions.rs

1//! Bitmask-coefficient multi-directional jets used by marginal-slope and
2//! latent-survival row kernels.
3//!
4//! The layout stores one coefficient per direction mask. The calculus itself
5//! lives in [`crate::jet_algebra`]: this module only maps slot lists to masks.
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Instrumentation counter: incremented (relaxed ordering) by every call to
/// [`MultiDirJet::compose_unary`]. Not suitable for synchronization.
pub static COMPOSE_UNARY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Instrumentation counter: incremented (relaxed ordering) by every call to
/// [`MultiDirJet::mul`]. Not suitable for synchronization.
pub static MUL_CALLS: AtomicU64 = AtomicU64::new(0);

/// A multi-directional jet with one coefficient per subset of directions.
///
/// For `n_dirs` directions, `coeffs` has `2^n_dirs` entries. The entry at a
/// bitmask `mask` is the mixed derivative along the directions whose bits are
/// set in `mask`; `coeffs[0]` is the base value.
#[derive(Clone, Debug)]
pub struct MultiDirJet {
    /// Coefficients indexed by direction bitmask.
    pub coeffs: Vec<f64>,
}
15
16impl MultiDirJet {
17    pub fn zero(n_dirs: usize) -> Self {
18        Self {
19            coeffs: vec![0.0; 1usize << n_dirs],
20        }
21    }
22
23    pub fn constant(n_dirs: usize, value: f64) -> Self {
24        let mut out = Self::zero(n_dirs);
25        out.coeffs[0] = value;
26        out
27    }
28
29    pub fn linear(n_dirs: usize, base: f64, first: &[f64]) -> Self {
30        let mut out = Self::constant(n_dirs, base);
31        for (idx, &value) in first.iter().take(n_dirs).enumerate() {
32            out.coeffs[1usize << idx] = value;
33        }
34        out
35    }
36
37    pub fn with_coeffs(n_dirs: usize, coeffs: &[(usize, f64)]) -> Self {
38        let mut out = Self::zero(n_dirs);
39        for &(mask, value) in coeffs {
40            if mask < out.coeffs.len() {
41                out.coeffs[mask] = value;
42            }
43        }
44        out
45    }
46
47    #[inline]
48    pub fn coeff(&self, mask: usize) -> f64 {
49        self.coeffs[mask]
50    }
51
52    pub fn add(&self, other: &Self) -> Self {
53        Self {
54            coeffs: self
55                .coeffs
56                .iter()
57                .zip(other.coeffs.iter())
58                .map(|(lhs, rhs)| lhs + rhs)
59                .collect(),
60        }
61    }
62
63    pub fn scale(&self, scalar: f64) -> Self {
64        Self {
65            coeffs: self.coeffs.iter().map(|value| scalar * value).collect(),
66        }
67    }
68
69    pub fn mul(&self, other: &Self) -> Self {
70        MUL_CALLS.fetch_add(1, Ordering::Relaxed);
71        let count = self.coeffs.len();
72        let mut out = vec![0.0; count];
73        for (mask, slot) in out.iter_mut().enumerate() {
74            // The differentiation slots of coefficient `mask` are its set bits;
75            // the shared Leibniz walker sums over subsets of those bits. A
76            // slot-group (list of bit positions) maps back to a sub-mask, the
77            // same submask enumeration the hand loop used — now one kernel
78            // shared with `Tower4::mul` (#1151).
79            let bits = bit_positions(mask);
80            *slot = crate::jet_algebra::leibniz_product(
81                bits.as_slice(),
82                |t| self.coeffs[mask_of(t)],
83                |c| other.coeffs[mask_of(c)],
84            );
85        }
86        Self { coeffs: out }
87    }
88
89    pub fn compose_unary(&self, derivs: [f64; 5]) -> Self {
90        COMPOSE_UNARY_CALLS.fetch_add(1, Ordering::Relaxed);
91        <Self as crate::jet_algebra::JetAlgebra<5>>::compose_unary(self, derivs)
92    }
93}
94
95impl crate::jet_algebra::JetAlgebra<5> for MultiDirJet {
96    #[inline]
97    fn derivative(&self, slots: &[usize]) -> f64 {
98        self.coeffs[mask_of(slots)]
99    }
100
101    fn map_derivatives<F>(&self, mut f: F) -> Self
102    where
103        F: FnMut(&[usize]) -> f64,
104    {
105        let mut out = vec![0.0; self.coeffs.len()];
106        for (mask, value) in out.iter_mut().enumerate() {
107            let bits = bit_positions(mask);
108            *value = f(bits.as_slice());
109        }
110        Self { coeffs: out }
111    }
112}
113
114/// The set-bit positions of `mask`, low to high — the differentiation slots of
115/// that coefficient.
116fn bit_positions(mask: usize) -> crate::jet_algebra::SlotBuf {
117    let mut out = crate::jet_algebra::SlotBuf::new();
118    let mut m = mask;
119    while m != 0 {
120        let bit = m.trailing_zeros() as usize;
121        out.push_slot(bit);
122        m &= m - 1;
123    }
124    out
125}
126
/// Folds a slot group (a list of bit positions) back into its sub-mask by
/// OR-ing together one bit per slot; repeated slots set the same bit.
fn mask_of(slots: &[usize]) -> usize {
    let mut mask = 0usize;
    for &bit in slots {
        mask |= 1usize << bit;
    }
    mask
}
131
132// #932-2 cutover: `MultiDirJet::bilinear` (the 4-coeff `[base, d1, d2, d12]`
133// constructor) and `MultiDirJet::sub` are consumed ONLY by the now test-only hand
134// survival directional/bidirectional oracle (the production flex jet path uses the
135// `flex_jet` runtime jet algebra, not `MultiDirJet`). After the #1521 crate split
136// moved `MultiDirJet` into `gam-math`, those oracle tests live in the dependent
137// `gam` crate, where a `#[cfg(test)]` gate in *this* crate is inactive — so the
138// methods must be plain `pub` inherent methods to be reachable cross-crate. They
139// carry no dead-code cost because `pub` items are part of the crate's public API.
140// Bodies are byte-identical to their former gated form.
141impl MultiDirJet {
142    pub fn bilinear(base: f64, d1: f64, d2: f64, d12: f64) -> Self {
143        Self {
144            coeffs: vec![base, d1, d2, d12],
145        }
146    }
147
148    pub fn sub(&self, other: &Self) -> Self {
149        Self {
150            coeffs: self
151                .coeffs
152                .iter()
153                .zip(other.coeffs.iter())
154                .map(|(lhs, rhs)| lhs - rhs)
155                .collect(),
156        }
157    }
158}