use crate::bytecode_tape::BytecodeTape;
use crate::taylor_dyn::{TaylorArenaLocal, TaylorDyn, TaylorDynGuard};
use crate::Float;
/// A multi-index `α = (α_0, …, α_{n-1})` selecting the mixed partial derivative
/// `∂^|α| / (∂x_0^{α_0} ⋯ ∂x_{n-1}^{α_{n-1}})`.
///
/// Entry `i` is the differentiation order with respect to variable `i`; a zero entry means
/// the variable is not differentiated ("inactive").
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MultiIndex {
    orders: Vec<u8>,
}

impl MultiIndex {
    /// Builds a multi-index from per-variable orders (all zeros is allowed).
    ///
    /// # Panics
    ///
    /// Panics if `orders` is empty.
    #[must_use]
    pub fn new(orders: &[u8]) -> Self {
        assert!(
            !orders.is_empty(),
            "multi-index must have at least one variable"
        );
        Self {
            orders: Vec::from(orders),
        }
    }

    /// Pure derivative of order `order` in variable `var` out of `num_vars` variables.
    ///
    /// # Panics
    ///
    /// Panics if `var >= num_vars` or `order == 0`.
    #[must_use]
    pub fn diagonal(num_vars: usize, var: usize, order: u8) -> Self {
        assert!(var < num_vars, "var ({}) >= num_vars ({})", var, num_vars);
        assert!(order > 0, "order must be > 0");
        let orders = (0..num_vars)
            .map(|i| if i == var { order } else { 0 })
            .collect();
        Self { orders }
    }

    /// First partial derivative with respect to `var`; shorthand for `diagonal(num_vars, var, 1)`.
    #[must_use]
    pub fn partial(num_vars: usize, var: usize) -> Self {
        Self::diagonal(num_vars, var, 1)
    }

    /// Total differentiation order `|α| = Σ α_i`.
    #[must_use]
    pub fn total_order(&self) -> usize {
        self.orders.iter().map(|&o| usize::from(o)).sum()
    }

    /// `(variable, order)` pairs for every variable with non-zero order, in index order.
    #[must_use]
    pub fn active_vars(&self) -> Vec<(usize, u8)> {
        self.orders
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, o)| o > 0)
            .collect()
    }

    /// Number of variables this multi-index is defined over.
    #[must_use]
    pub fn num_vars(&self) -> usize {
        self.orders.len()
    }

    /// Raw per-variable orders.
    #[must_use]
    pub fn orders(&self) -> &[u8] {
        &self.orders
    }

    /// Indices of the variables with non-zero order, ascending.
    fn active_var_set(&self) -> Vec<usize> {
        self.active_vars()
            .into_iter()
            .map(|(var, _)| var)
            .collect()
    }
}
/// Enumerates every way to write `k = Σ slot · mult` using distinct entries of `slots`
/// (each at most once) with multiplicities `mult >= 1`.
///
/// Each result lists its `(slot, multiplicity)` pairs in the order the slots appear in
/// `slots`. `k == 0` yields exactly one (empty) partition.
fn partitions_with_support(k: usize, slots: &[usize]) -> Vec<Vec<(usize, usize)>> {
    let mut found: Vec<Vec<(usize, usize)>> = Vec::new();
    partitions_recurse(k, slots, 0, &mut Vec::new(), &mut found);
    found
}

/// Depth-first helper for `partitions_with_support`.
///
/// Extends the partial partition `current` so that it covers `remaining`, only choosing
/// slots at positions `>= start_idx`, and appends each completed partition to `results`.
fn partitions_recurse(
    remaining: usize,
    slots: &[usize],
    start_idx: usize,
    current: &mut Vec<(usize, usize)>,
    results: &mut Vec<Vec<(usize, usize)>>,
) {
    if remaining == 0 {
        results.push(current.to_vec());
        return;
    }
    for (idx, &slot) in slots.iter().enumerate().skip(start_idx) {
        // Zero when `slot > remaining`, which makes the range below empty.
        let max_mult = remaining / slot;
        for mult in 1..=max_mult {
            current.push((slot, mult));
            partitions_recurse(remaining - slot * mult, slots, idx + 1, current, results);
            current.pop();
        }
    }
}
/// Scale factor turning an output Taylor coefficient back into a mixed partial derivative.
///
/// For each `(slot, order)` pair the factor `order! · (slot!)^order` is multiplied in, giving
/// `α! · Π (s_j!)^{α_j}`. With inputs seeded as `t^{s_j} / s_j!` (see `plan_group`), that is
/// the inverse of the weight with which `∂^α f` appears in output coefficient `Σ α_j s_j`.
///
/// If the direct product overflows to a non-finite value, the factor is recomputed in log
/// space and exponentiated (which may itself still overflow).
fn extraction_prefactor<F: Float>(slot_assignments: &[(usize, u8)]) -> F {
    let factorial = |n: usize| -> F { (2..=n).fold(F::one(), |acc, i| acc * F::from(i).unwrap()) };

    let direct = slot_assignments
        .iter()
        .fold(F::one(), |acc, &(slot, order)| {
            let order_fact = factorial(usize::from(order));
            let slot_fact = factorial(slot);
            let slot_fact_pow = (0..order).fold(F::one(), |pow, _| pow * slot_fact);
            acc * order_fact * slot_fact_pow
        });
    if direct.is_finite() {
        return direct;
    }

    // Log-space fallback: the ln(i) terms of order! are added straight into the running
    // total, while ln(slot!) is summed separately and then scaled by `order`.
    let mut log_total = F::zero();
    for &(slot, order) in slot_assignments {
        log_total = (2..=usize::from(order)).fold(log_total, |acc, i| acc + F::from(i).unwrap().ln());
        let log_slot_fact = (2..=slot).fold(F::zero(), |acc, i| acc + F::from(i).unwrap().ln());
        log_total = log_total + F::from(usize::from(order)).unwrap() * log_slot_fact;
    }
    log_total.exp()
}
/// Recipe for reading one requested derivative out of a pushforward group's output jet.
#[derive(Clone, Debug)]
struct Extraction<F> {
/// Position of this derivative in the result vector (`DiffOpResult::derivatives`).
result_index: usize,
/// Which Taylor coefficient of the tape output holds the (scaled) derivative.
output_coeff_index: usize,
/// Multiplier applied to that coefficient to recover the derivative value.
prefactor: F,
}
/// One forward Taylor sweep shared by all multi-indices with the same active-variable set.
#[derive(Clone, Debug)]
struct PushforwardGroup<F> {
/// Number of Taylor coefficients propagated per input (set to max extracted index + 1 by `plan_group`).
jet_order: usize,
/// Input seeds as `(variable, slot, 1 / slot!)`: the value is written at coefficient `slot` of that variable's input jet.
input_coeffs: Vec<(usize, usize, F)>,
/// Derivatives read from the output jet of this sweep.
extractions: Vec<Extraction<F>>,
}
/// Precomputed evaluation plan for a list of multi-indices; built by `JetPlan::plan` and run by `eval_dyn`.
#[derive(Clone, Debug)]
pub struct JetPlan<F> {
/// Largest `jet_order` over all groups (at least 1).
max_jet_order: usize,
/// One pushforward per distinct active-variable set.
groups: Vec<PushforwardGroup<F>>,
/// The requested multi-indices, in caller order; results are reported in this order.
multi_indices: Vec<MultiIndex>,
}
/// Candidate slots (input Taylor-coefficient indices) assigned to the active variables of a
/// pushforward group. `plan_group` hands out consecutive entries starting at some offset, so
/// a group can have at most `PRIMES.len()` (20) active variables. Every assignment is still
/// verified collision-free by `try_slots`; distinct primes presumably just make such
/// assignments easy to find — NOTE(review): confirm the rationale.
const PRIMES: [usize; 20] = [
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
];
/// Checks whether the slot assignment `var_slot` lets every multi-index be read back
/// unambiguously, and if so builds the extractions.
///
/// A multi-index `α` lands in output coefficient `k = Σ α_j · slot_j`. That coefficient also
/// receives every other `β` with `Σ β_j · slot_j = k`. The assignment is accepted only if the
/// target is the *only* partition of `k` over the group's slots.
///
/// Returns the extractions (in input order) and the largest coefficient index used, or
/// `Err(())` on the first collision. An all-zero multi-index reads coefficient 0 with
/// prefactor 1.
fn try_slots<F: Float>(
    var_slot: &[(usize, usize)],
    multi_indices_with_idx: &[(usize, &MultiIndex)],
) -> Result<(Vec<Extraction<F>>, usize), ()> {
    let slots_in_group: Vec<usize> = var_slot.iter().map(|pair| pair.1).collect();
    let mut planned = Vec::with_capacity(multi_indices_with_idx.len());
    let mut highest_coeff = 0usize;

    for &(result_index, mi) in multi_indices_with_idx {
        let active = mi.active_vars();
        if active.is_empty() {
            planned.push(Extraction {
                result_index,
                output_coeff_index: 0,
                prefactor: F::one(),
            });
            continue;
        }

        let mut slot_orders: Vec<(usize, u8)> = Vec::with_capacity(active.len());
        for &(var, order) in &active {
            let slot = var_slot.iter().find(|(v, _)| *v == var).unwrap().1;
            slot_orders.push((slot, order));
        }

        let coeff_index: usize = slot_orders
            .iter()
            .map(|&(slot, order)| slot * usize::from(order))
            .sum();

        let mut target: Vec<(usize, usize)> = slot_orders
            .iter()
            .map(|&(slot, order)| (slot, usize::from(order)))
            .collect();
        target.sort_by_key(|&(slot, _)| slot);

        let unique = partitions_with_support(coeff_index, &slots_in_group)
            .iter()
            .all(|partition| {
                let mut normalized = partition.clone();
                normalized.sort_by_key(|&(slot, _)| slot);
                normalized == target
            });
        if !unique {
            return Err(());
        }

        highest_coeff = highest_coeff.max(coeff_index);
        planned.push(Extraction {
            result_index,
            output_coeff_index: coeff_index,
            prefactor: extraction_prefactor::<F>(&slot_orders),
        });
    }

    Ok((planned, highest_coeff))
}
/// Plans the pushforward for one active-variable set.
///
/// Variables are ranked by their highest order in the group (descending, ties by index) and
/// receive consecutive entries of `PRIMES`, the highest-order variable getting the smallest.
/// Successive starting offsets are tried until `try_slots` reports a collision-free
/// assignment. Each active variable is seeded with `1 / slot!` at coefficient `slot`.
///
/// # Panics
///
/// Panics if there are more active variables than `PRIMES.len()`, or if no offset works.
fn plan_group<F: Float>(
    active_var_set: &[usize],
    multi_indices_with_idx: &[(usize, &MultiIndex)],
) -> PushforwardGroup<F> {
    let num_active = active_var_set.len();
    assert!(
        num_active <= PRIMES.len(),
        "too many active variables ({}) — max supported is {}",
        num_active,
        PRIMES.len()
    );

    let mut ranked: Vec<(usize, u8)> = Vec::with_capacity(num_active);
    for &var in active_var_set {
        let highest = multi_indices_with_idx
            .iter()
            .map(|&(_, mi)| mi.orders()[var])
            .max()
            .unwrap_or(0);
        ranked.push((var, highest));
    }
    ranked.sort_by_key(|&(var, order)| (std::cmp::Reverse(order), var));

    let (var_slot, extractions, max_k) = (0..=PRIMES.len() - num_active)
        .find_map(|offset| {
            let var_slot: Vec<(usize, usize)> = ranked
                .iter()
                .zip(&PRIMES[offset..])
                .map(|(&(var, _), &prime)| (var, prime))
                .collect();
            try_slots::<F>(&var_slot, multi_indices_with_idx)
                .ok()
                .map(|(extractions, max_k)| (var_slot, extractions, max_k))
        })
        .unwrap_or_else(|| {
            panic!(
                "failed to find collision-free slot assignment for active vars {:?}",
                active_var_set
            )
        });

    let input_coeffs: Vec<(usize, usize, F)> = var_slot
        .iter()
        .map(|&(var, slot)| {
            let slot_factorial = (2..=slot).fold(F::one(), |acc, i| acc * F::from(i).unwrap());
            (var, slot, F::one() / slot_factorial)
        })
        .collect();

    PushforwardGroup {
        jet_order: max_k + 1,
        input_coeffs,
        extractions,
    }
}
impl<F: Float> JetPlan<F> {
#[must_use]
pub fn plan(num_vars: usize, multi_indices: &[MultiIndex]) -> Self {
assert!(
!multi_indices.is_empty(),
"must provide at least one multi-index"
);
for mi in multi_indices {
assert_eq!(
mi.num_vars(),
num_vars,
"multi-index num_vars ({}) != expected ({})",
mi.num_vars(),
num_vars
);
}
type GroupEntry<'a> = (Vec<usize>, Vec<(usize, &'a MultiIndex)>);
let mut group_map: Vec<GroupEntry<'_>> = Vec::new();
for (i, mi) in multi_indices.iter().enumerate() {
let active_set = mi.active_var_set();
if let Some(entry) = group_map.iter_mut().find(|(set, _)| *set == active_set) {
entry.1.push((i, mi));
} else {
group_map.push((active_set, vec![(i, mi)]));
}
}
let mut groups = Vec::with_capacity(group_map.len());
let mut max_jet_order = 1;
for (active_set, members) in &group_map {
let group = plan_group::<F>(active_set, members);
max_jet_order = max_jet_order.max(group.jet_order);
groups.push(group);
}
JetPlan {
max_jet_order,
groups,
multi_indices: multi_indices.to_vec(),
}
}
#[must_use]
pub fn jet_order(&self) -> usize {
self.max_jet_order
}
#[must_use]
pub fn multi_indices(&self) -> Vec<MultiIndex> {
self.multi_indices.clone()
}
}
/// Output of `eval_dyn`: the function value plus one derivative per planned multi-index.
#[derive(Clone, Debug)]
pub struct DiffOpResult<F> {
/// Function value at the evaluation point (Taylor coefficient 0 of the output).
pub value: F,
/// `derivatives[i]` is the derivative selected by `multi_indices[i]`.
pub derivatives: Vec<F>,
/// The multi-indices the derivatives correspond to, in the plan's order.
pub multi_indices: Vec<MultiIndex>,
}
/// Evaluates all derivatives described by `plan` for `tape` at the point `x`.
///
/// Each pushforward group runs one forward Taylor sweep. Input `i` starts as the jet
/// `[x[i], 0, …, 0]` of length `group.jet_order`, and every active variable additionally
/// receives its seed `1 / slot!` at coefficient `slot`. Each requested derivative is then
/// read from output coefficient `output_coeff_index` and multiplied by its prefactor.
///
/// # Panics
///
/// Panics if `x.len() != tape.num_inputs()`, or if the plan seeds a variable index that is
/// not an input of `tape` (for example, a plan built for more variables than the tape has).
pub fn eval_dyn<F: Float + TaylorArenaLocal>(
    plan: &JetPlan<F>,
    tape: &BytecodeTape<F>,
    x: &[F],
) -> DiffOpResult<F> {
    let n = tape.num_inputs();
    assert_eq!(
        x.len(),
        n,
        "x.len() ({}) must match tape.num_inputs() ({})",
        x.len(),
        n
    );
    // A seed for a variable index >= n never matches any tape input below. It would be
    // dropped silently and the affected derivatives would be wrong instead of failing.
    for group in &plan.groups {
        for &(var, _, _) in &group.input_coeffs {
            assert!(
                var < n,
                "plan seeds variable {} but tape.num_inputs() is {}",
                var,
                n
            );
        }
    }

    let mut derivatives = vec![F::zero(); plan.multi_indices.len()];
    let mut value = F::zero();
    for group in &plan.groups {
        // NOTE: the guard is declared before `inputs`/`buf` so it is dropped after them.
        // It presumably scopes the Taylor arena those values use; confirm in taylor_dyn.
        let _guard = TaylorDynGuard::<F>::new(group.jet_order);
        let inputs: Vec<TaylorDyn<F>> = (0..n)
            .map(|i| {
                let mut coeffs = vec![F::zero(); group.jet_order];
                coeffs[0] = x[i];
                for &(var, slot, inv_fact) in &group.input_coeffs {
                    if var == i && slot < group.jet_order {
                        coeffs[slot] = inv_fact;
                    }
                }
                TaylorDyn::from_coeffs(&coeffs)
            })
            .collect();
        let mut buf = Vec::new();
        tape.forward_tangent(&inputs, &mut buf);
        let out_coeffs = buf[tape.output_index()].coeffs();
        // Coefficient 0 does not depend on the seeds, so every group yields the same value.
        value = out_coeffs[0];
        for extraction in &group.extractions {
            derivatives[extraction.result_index] =
                out_coeffs[extraction.output_coeff_index] * extraction.prefactor;
        }
    }

    DiffOpResult {
        value,
        derivatives,
        multi_indices: plan.multi_indices.clone(),
    }
}
pub fn mixed_partial<F: Float + TaylorArenaLocal>(
tape: &BytecodeTape<F>,
x: &[F],
orders: &[u8],
) -> (F, F) {
assert_eq!(
orders.len(),
tape.num_inputs(),
"mixed_partial: orders.len() must equal tape.num_inputs() \
(got orders.len()={}, tape.num_inputs()={})",
orders.len(),
tape.num_inputs(),
);
let mi = MultiIndex::new(orders);
let plan = JetPlan::plan(orders.len(), &[mi]);
let result = eval_dyn(&plan, tape, x);
(result.value, result.derivatives[0])
}
/// Computes `f(x)`, the gradient, and the dense symmetric Hessian of the tape at `x`.
///
/// Only the first partials and the upper triangle (`i <= j`) of second partials are
/// requested from the jet planner; each off-diagonal value is mirrored into the lower
/// triangle.
///
/// # Panics
///
/// Panics if `x.len() != tape.num_inputs()`.
pub fn hessian<F: Float + TaylorArenaLocal>(
    tape: &BytecodeTape<F>,
    x: &[F],
) -> (F, Vec<F>, Vec<Vec<F>>) {
    let n = tape.num_inputs();
    assert_eq!(x.len(), n, "x.len() must match tape.num_inputs()");

    // Upper-triangle pairs in row-major order; results come back in this same order.
    let upper: Vec<(usize, usize)> = (0..n)
        .flat_map(|i| (i..n).map(move |j| (i, j)))
        .collect();

    let mut indices = Vec::with_capacity(n + upper.len());
    indices.extend((0..n).map(|i| MultiIndex::partial(n, i)));
    indices.extend(upper.iter().map(|&(i, j)| {
        // i == j yields a pure second derivative (order 2); otherwise a mixed (1, 1) partial.
        let mut orders = vec![0u8; n];
        orders[i] += 1;
        orders[j] += 1;
        MultiIndex::new(&orders)
    }));

    let plan = JetPlan::plan(n, &indices);
    let result = eval_dyn(&plan, tape, x);
    let (gradient, second) = result.derivatives.split_at(n);

    let mut hess = vec![vec![F::zero(); n]; n];
    for (&(i, j), &val) in upper.iter().zip(second) {
        hess[i][j] = val;
        hess[j][i] = val;
    }
    (result.value, gradient.to_vec(), hess)
}
/// Linear differential operator `Σ c_i ∂^{α_i}` with constant coefficients.
#[derive(Clone, Debug)]
pub struct DiffOp<F> {
/// `(coefficient, multi-index)` terms, in construction order.
terms: Vec<(F, MultiIndex)>,
/// Number of variables every term's multi-index is defined over.
num_vars: usize,
}
impl<F: Float> DiffOp<F> {
    /// Creates an operator from explicit `(coefficient, multi-index)` terms.
    ///
    /// # Panics
    ///
    /// Panics if `terms` is empty or any multi-index does not have exactly `num_vars` entries.
    #[must_use]
    pub fn new(num_vars: usize, terms: Vec<(F, MultiIndex)>) -> Self {
        assert!(!terms.is_empty(), "DiffOp must have at least one term");
        for (_, mi) in &terms {
            assert_eq!(
                mi.num_vars(),
                num_vars,
                "multi-index num_vars ({}) != expected ({})",
                mi.num_vars(),
                num_vars
            );
        }
        DiffOp { terms, num_vars }
    }

    /// Like [`DiffOp::new`], but each term's multi-index is given as raw per-variable orders.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`DiffOp::new`], or if any order slice is empty.
    #[must_use]
    pub fn from_orders(num_vars: usize, terms: &[(F, &[u8])]) -> Self {
        let terms: Vec<(F, MultiIndex)> = terms
            .iter()
            .map(|&(c, orders)| (c, MultiIndex::new(orders)))
            .collect();
        Self::new(num_vars, terms)
    }

    /// Laplacian `Σ_j ∂²/∂x_j²` over `n` variables.
    ///
    /// Note: `n == 0` yields an operator with no terms (this bypasses the check in `new`);
    /// `eval` and `sparse_distribution` panic on such an operator.
    #[must_use]
    pub fn laplacian(n: usize) -> Self {
        let terms = (0..n)
            .map(|j| (F::one(), MultiIndex::diagonal(n, j, 2)))
            .collect();
        DiffOp { terms, num_vars: n }
    }

    /// Biharmonic operator `Σ_j ∂⁴/∂x_j⁴ + 2 Σ_{j<k} ∂⁴/(∂x_j² ∂x_k²)` over `n` variables.
    ///
    /// Pure quartic terms come first, followed by the mixed terms in `(j, k)` order.
    #[must_use]
    pub fn biharmonic(n: usize) -> Self {
        let two = F::one() + F::one();
        let mut terms: Vec<(F, MultiIndex)> = (0..n)
            .map(|j| (F::one(), MultiIndex::diagonal(n, j, 4)))
            .collect();
        for j in 0..n {
            for k in (j + 1)..n {
                let mut orders = vec![0u8; n];
                orders[j] = 2;
                orders[k] = 2;
                terms.push((two, MultiIndex::new(&orders)));
            }
        }
        DiffOp { terms, num_vars: n }
    }

    /// Sum of pure `k`-th derivatives `Σ_j ∂^k/∂x_j^k` over `n` variables.
    ///
    /// # Panics
    ///
    /// Panics if `k == 0`.
    #[must_use]
    pub fn diagonal(n: usize, k: u8) -> Self {
        assert!(k >= 1, "diagonal order must be >= 1");
        let terms = (0..n)
            .map(|j| (F::one(), MultiIndex::diagonal(n, j, k)))
            .collect();
        DiffOp { terms, num_vars: n }
    }

    /// The operator's `(coefficient, multi-index)` terms.
    #[must_use]
    pub fn terms(&self) -> &[(F, MultiIndex)] {
        &self.terms
    }

    /// Number of variables the operator acts on.
    #[must_use]
    pub fn num_vars(&self) -> usize {
        self.num_vars
    }

    /// Highest total order among the terms (0 if there are no terms).
    #[must_use]
    pub fn order(&self) -> usize {
        self.terms
            .iter()
            .map(|(_, mi)| mi.total_order())
            .max()
            .unwrap_or(0)
    }

    /// `true` if no term mixes variables, i.e. every term has at most one active variable.
    #[must_use]
    pub fn is_diagonal(&self) -> bool {
        // Count active variables in place instead of allocating `active_vars()` per term.
        self.terms
            .iter()
            .all(|(_, mi)| mi.orders().iter().filter(|&&o| o > 0).count() <= 1)
    }

    /// Splits the operator into homogeneous parts, one per distinct total order.
    ///
    /// Parts are sorted by ascending order; within a part, terms keep their original order.
    #[must_use]
    pub fn split_by_order(&self) -> Vec<DiffOp<F>> {
        let mut order_map: Vec<(usize, Vec<(F, MultiIndex)>)> = Vec::new();
        for (c, mi) in &self.terms {
            let ord = mi.total_order();
            if let Some(entry) = order_map.iter_mut().find(|(o, _)| *o == ord) {
                entry.1.push((*c, mi.clone()));
            } else {
                order_map.push((ord, vec![(*c, mi.clone())]));
            }
        }
        order_map.sort_by_key(|(o, _)| *o);
        order_map
            .into_iter()
            .map(|(_, terms)| DiffOp {
                terms,
                num_vars: self.num_vars,
            })
            .collect()
    }
}
impl<F: Float + TaylorArenaLocal> DiffOp<F> {
pub fn eval(&self, tape: &BytecodeTape<F>, x: &[F]) -> (F, F) {
let multi_indices: Vec<MultiIndex> = self.terms.iter().map(|(_, mi)| mi.clone()).collect();
let plan = JetPlan::plan(self.num_vars, &multi_indices);
let result = eval_dyn(&plan, tape, x);
let mut op_value = F::zero();
for (i, (c, _)) in self.terms.iter().enumerate() {
op_value = op_value + *c * result.derivatives[i];
}
(result.value, op_value)
}
#[must_use]
pub fn sparse_distribution(&self) -> SparseSamplingDistribution<F> {
let k = self.terms[0].1.total_order();
for (_, mi) in &self.terms {
assert_eq!(
mi.total_order(),
k,
"sparse_distribution requires homogeneous operator: \
found order {} and order {}",
k,
mi.total_order()
);
}
let mut entries = Vec::with_capacity(self.terms.len());
let mut cumulative = F::zero();
for (coeff, mi) in &self.terms {
let abs_c = coeff.abs();
cumulative = cumulative + abs_c;
let active_set = mi.active_vars().iter().map(|&(v, _)| v).collect::<Vec<_>>();
let group = plan_group::<F>(&active_set, &[(0, mi)]);
let extraction = &group.extractions[0];
entries.push(SparseJetEntry {
cumulative_weight: cumulative,
input_coeffs: group.input_coeffs.clone(),
output_coeff_index: extraction.output_coeff_index,
extraction_prefactor: extraction.prefactor,
sign: coeff.signum(),
});
}
SparseSamplingDistribution {
jet_order: entries
.iter()
.map(|e| e.output_coeff_index)
.max()
.unwrap_or(1),
entries,
total_weight: cumulative,
}
}
}
/// Discrete distribution over the terms of a homogeneous `DiffOp`, weighted by `|c_i|`.
#[derive(Clone, Debug)]
pub struct SparseSamplingDistribution<F> {
/// Jet length reported by `jet_order()`; computed in `DiffOp::sparse_distribution`.
jet_order: usize,
/// One entry per operator term, in term order, with running cumulative weights.
entries: Vec<SparseJetEntry<F>>,
/// Sum of `|c_i|` over all terms (the last entry's cumulative weight).
total_weight: F,
}
/// Per-term sampling record: its weight plus everything needed to evaluate the term with one jet.
#[derive(Clone, Debug)]
struct SparseJetEntry<F> {
/// Running sum of `|c_j|` over this and all earlier terms (used by `sample_index`).
cumulative_weight: F,
/// Input seeds `(variable, slot, 1 / slot!)` for this term's pushforward.
input_coeffs: Vec<(usize, usize, F)>,
/// Output Taylor coefficient holding the scaled derivative.
output_coeff_index: usize,
/// Multiplier turning that coefficient into the derivative value.
extraction_prefactor: F,
/// `signum` of the term's coefficient (note: `+0` yields `1`).
sign: F,
}
/// Read-only view of one entry of a `SparseSamplingDistribution`, obtained via `entry()`.
pub struct SparseJetEntryRef<'a, F> {
entry: &'a SparseJetEntry<F>,
}
impl<F: Float> SparseJetEntryRef<'_, F> {
    /// Input seeds `(variable, slot, 1 / slot!)` to place on the input jets.
    #[must_use]
    pub fn input_coeffs(&self) -> &[(usize, usize, F)] {
        self.entry.input_coeffs.as_slice()
    }

    /// Index of the output Taylor coefficient to read.
    #[must_use]
    pub fn output_coeff_index(&self) -> usize {
        self.entry.output_coeff_index
    }

    /// Factor converting the read coefficient into the derivative value.
    #[must_use]
    pub fn extraction_prefactor(&self) -> F {
        self.entry.extraction_prefactor
    }

    /// Sign of the sampled term's coefficient.
    #[must_use]
    pub fn sign(&self) -> F {
        self.entry.sign
    }
}
impl<F: Float> SparseSamplingDistribution<F> {
    /// Maps a uniform draw (expected in `[0, 1)`) to an entry index.
    ///
    /// Returns the first entry whose cumulative weight exceeds `uniform_01 * total_weight`,
    /// so zero-weight entries are skipped. The result is clamped to the last index, which
    /// is what a draw of exactly `1.0` hits.
    pub fn sample_index(&self, uniform_01: F) -> usize {
        let threshold = uniform_01 * self.total_weight;
        // Cumulative weights are non-decreasing, so the predicate is partitioned.
        let first_above = self
            .entries
            .partition_point(|entry| entry.cumulative_weight <= threshold);
        first_above.min(self.entries.len() - 1)
    }

    /// Total weight `Σ |c_i|`, the normaliser for importance-weighted estimates.
    pub fn normalization(&self) -> F {
        self.total_weight
    }

    /// Number of entries (operator terms).
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` if there are no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Jet length stored for this distribution.
    pub fn jet_order(&self) -> usize {
        self.jet_order
    }

    /// Borrowed view of entry `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    pub fn entry(&self, index: usize) -> SparseJetEntryRef<'_, F> {
        let entry = &self.entries[index];
        SparseJetEntryRef { entry }
    }
}