use std::collections::{BTreeMap, BTreeSet, HashMap};
/// Name of a dimension; a `Key` pairs each axis with a coordinate.
pub type Axis = &'static str;
/// Position along an axis.
pub type Coord = u64;
/// Signed integer amount stored in a `Measure` cell.
pub type Money = i128;
/// Reserved coordinate (`u64::MAX`) used as a placeholder instead of a concrete
/// position: `Key::with_any` writes it, `Key::has_any` detects it, and
/// `Measure::rake` treats cells carrying it on the raked axis as pending.
pub const ANY: Coord = u64::MAX;
/// A sparse coordinate tuple: a list of `(axis, coordinate)` pairs.
///
/// Every constructor in this impl emits the pairs in ascending order, so two
/// keys built from the same pairs compare and hash equal regardless of the
/// order the pairs were supplied in.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
pub struct Key(Vec<(Axis, Coord)>);

impl Key {
    /// Builds a key from `pairs`, sorted ascending with exact duplicates removed.
    ///
    /// Only identical `(axis, coord)` pairs collapse; two different coordinates
    /// given for the same axis are both retained.
    pub fn of(pairs: &[(Axis, Coord)]) -> Key {
        // A BTreeSet yields its elements sorted and unique, which is exactly
        // "sort, then drop adjacent duplicates".
        let unique: BTreeSet<(Axis, Coord)> = pairs.iter().copied().collect();
        Key(unique.into_iter().collect())
    }

    /// Returns the stored `(axis, coordinate)` pairs.
    pub fn coords(&self) -> &[(Axis, Coord)] {
        self.0.as_slice()
    }

    /// Returns the coordinate of the first entry for `axis`, or `None` if the
    /// key has no entry for that axis.
    pub fn get(&self, axis: Axis) -> Option<Coord> {
        for &(name, coord) in &self.0 {
            if name == axis {
                return Some(coord);
            }
        }
        None
    }

    /// Returns a key holding only the entries whose axis is listed in `axes`.
    pub fn project(&self, axes: &[Axis]) -> Key {
        let wanted: BTreeSet<Axis> = axes.iter().copied().collect();
        self.project_set(&wanted)
    }

    /// Returns a copy of this key with `axis` set to `coord`, replacing any
    /// existing coordinate for that axis.
    pub fn with(&self, axis: Axis, coord: Coord) -> Key {
        let mut by_axis: BTreeMap<Axis, Coord> = BTreeMap::new();
        // Later entries overwrite earlier ones for the same axis.
        by_axis.extend(self.0.iter().copied());
        by_axis.insert(axis, coord);
        Key(by_axis.into_iter().collect())
    }

    /// Returns a copy of this key with every entry for `axis` removed.
    pub fn without(&self, axis: Axis) -> Key {
        let mut kept = self.0.clone();
        kept.retain(|&(name, _)| name != axis);
        Key(kept)
    }

    /// Set-based form of [`Key::project`]: keeps entries whose axis is in `axes`.
    fn project_set(&self, axes: &BTreeSet<Axis>) -> Key {
        let mut kept = Vec::new();
        for &(name, coord) in &self.0 {
            if axes.contains(&name) {
                kept.push((name, coord));
            }
        }
        Key(kept)
    }

    /// Unions two keys; when both name the same axis, this key's coordinate wins.
    fn merge(&self, other: &Key) -> Key {
        let mut by_axis: BTreeMap<Axis, Coord> = self.0.iter().copied().collect();
        for &(name, coord) in &other.0 {
            // Only fills axes that are still missing.
            by_axis.entry(name).or_insert(coord);
        }
        Key(by_axis.into_iter().collect())
    }

    /// Returns a copy of this key with every axis in `axes` set to [`ANY`],
    /// overwriting existing coordinates on those axes.
    fn with_any(&self, axes: &BTreeSet<Axis>) -> Key {
        let mut by_axis: BTreeMap<Axis, Coord> = self.0.iter().copied().collect();
        by_axis.extend(axes.iter().map(|&name| (name, ANY)));
        Key(by_axis.into_iter().collect())
    }

    /// Reports whether any entry carries the [`ANY`] coordinate.
    fn has_any(&self) -> bool {
        self.0.iter().map(|&(_, coord)| coord).any(|coord| coord == ANY)
    }
}
#[derive(Clone, Debug, Default)]
pub struct Measure {
axes: BTreeSet<Axis>,
cells: HashMap<Key, Money>,
}
impl Measure {
pub fn with_axes(axes: BTreeSet<Axis>) -> Measure {
Measure {
axes,
cells: HashMap::new(),
}
}
pub fn build(axes: &[Axis], cells: &[(&[(Axis, Coord)], Money)]) -> Measure {
let mut m = Measure::with_axes(axes.iter().copied().collect());
for (pairs, v) in cells {
m.add(Key::of(pairs), *v);
}
m
}
pub fn axes(&self) -> &BTreeSet<Axis> {
&self.axes
}
pub fn get(&self, k: &Key) -> Money {
self.cells.get(k).copied().unwrap_or(0)
}
pub fn cells(&self) -> impl Iterator<Item = (&Key, Money)> {
self.cells.iter().map(|(k, v)| (k, *v))
}
pub fn len(&self) -> usize {
self.cells.len()
}
pub fn is_empty(&self) -> bool {
self.cells.is_empty()
}
pub fn total(&self) -> Money {
self.cells.values().sum()
}
fn add(&mut self, k: Key, v: Money) {
if v == 0 {
return;
}
for &(a, _) in &k.0 {
self.axes.insert(a);
}
use std::collections::hash_map::Entry;
match self.cells.entry(k) {
Entry::Occupied(mut o) => {
*o.get_mut() += v;
if *o.get() == 0 {
o.remove();
}
}
Entry::Vacant(va) => {
va.insert(v);
}
}
}
pub fn rekey(&self, f: impl Fn(&Key) -> Key) -> Measure {
let mut out = Measure::default();
for (k, v) in &self.cells {
out.add(f(k), *v);
}
out
}
pub fn marginalize(&self, keep: &[Axis]) -> Measure {
let keep: BTreeSet<Axis> = keep.iter().copied().collect();
self.rekey(|k| k.project_set(&keep))
}
pub fn combine(&self, rhs: &Measure, f: impl Fn(Money, Money) -> Money) -> Measure {
let shared: BTreeSet<Axis> = self.axes.intersection(&rhs.axes).copied().collect();
let out_axes: BTreeSet<Axis> = self.axes.union(&rhs.axes).copied().collect();
let lb = self.bucket(&shared);
let rb = rhs.bucket(&shared);
let mut out = Measure::with_axes(out_axes);
let shared_keys: BTreeSet<&Key> = lb.keys().chain(rb.keys()).collect();
for s in shared_keys {
match (lb.get(s), rb.get(s)) {
(Some(ls), Some(rs)) => {
for (lk, a) in ls {
for (rk, b) in rs {
out.add(lk.merge(rk), f(*a, *b));
}
}
}
(Some(ls), None) if rhs.axes.is_subset(&shared) => {
for (lk, a) in ls {
out.add(lk.clone(), f(*a, 0));
}
}
(None, Some(rs)) if self.axes.is_subset(&shared) => {
for (rk, b) in rs {
out.add(rk.clone(), f(0, *b));
}
}
_ => {}
}
}
out
}
pub fn allocate(&self, driver: &Measure) -> Measure {
let shared: BTreeSet<Axis> = self.axes.intersection(&driver.axes).copied().collect();
let out_axes: BTreeSet<Axis> = self.axes.union(&driver.axes).copied().collect();
let new_axes: BTreeSet<Axis> = driver.axes.difference(&self.axes).copied().collect();
let pool_b = self.bucket(&shared);
let drv_b = driver.bucket(&shared);
let mut out = Measure::with_axes(out_axes);
for (s, prows) in &pool_b {
let drows = drv_b
.get(s)
.filter(|d| d.iter().map(|(_, w)| *w).sum::<i128>() > 0);
let Some(drows) = drows else {
for (pk, amt) in prows {
out.add(pk.with_any(&new_axes), *amt); }
continue;
};
let weights: Vec<Money> = drows.iter().map(|(_, w)| *w).collect();
let wtotal: i128 = weights.iter().sum();
let dkeys: Vec<&Key> = drows.iter().map(|(dk, _)| dk).collect();
for (pk, amt) in prows {
let shares = largest_remainder(*amt, &weights, wtotal, &dkeys);
for ((dk, _), share) in drows.iter().zip(&shares) {
out.add(pk.merge(dk), *share);
}
}
}
out
}
pub fn select(&self, keep: impl Fn(&Key, Money) -> bool) -> Measure {
let mut out = Measure::with_axes(self.axes.clone());
for (k, v) in &self.cells {
if keep(k, *v) {
out.add(k.clone(), *v);
}
}
out
}
pub fn partition(&self, pred: impl Fn(&Key, Money) -> bool) -> (Measure, Measure) {
let mut yes = Measure::with_axes(self.axes.clone());
let mut no = Measure::with_axes(self.axes.clone());
for (k, v) in &self.cells {
if pred(k, *v) {
yes.add(k.clone(), *v);
} else {
no.add(k.clone(), *v);
}
}
(yes, no)
}
pub fn slice(&self, at: &[(Axis, Coord)]) -> Measure {
let at = at.to_vec();
let drop: BTreeSet<Axis> = at.iter().map(|(a, _)| *a).collect();
let keep: Vec<Axis> = self
.axes
.iter()
.copied()
.filter(|a| !drop.contains(a))
.collect();
self.select(move |k, _| at.iter().all(|&(a, c)| k.get(a) == Some(c)))
.marginalize(&keep)
}
pub fn group_by(&self, axis: Axis, f: impl Fn(Coord, &Measure) -> Measure) -> Measure {
let mut groups: BTreeMap<Coord, Measure> = BTreeMap::new();
for (k, v) in &self.cells {
let c = k.get(axis).unwrap_or(ANY);
groups
.entry(c)
.or_insert_with(|| Measure::with_axes(self.axes.clone()))
.add(k.clone(), *v);
}
let mut out = Measure::default();
for (c, sub) in &groups {
out = out.combine(&f(*c, sub), |a, b| a + b);
}
out
}
pub fn pending(&self) -> Measure {
self.select(|k, _| k.has_any())
}
pub fn rake(&self, dim: Axis, driver: &Measure) -> Measure {
let (pending, resolved) = self.partition(|k, _| k.get(dim) == Some(ANY));
if pending.is_empty() {
return resolved;
}
let raked = pending.rekey(|k| k.without(dim)).allocate(driver);
resolved.combine(&raked, |a, b| a + b)
}
pub fn vacuum(&self, eps: Money, order: &[Axis]) -> Measure {
let mut current = self.clone();
for &dim in order {
if !current.axes.contains(&dim) {
continue;
}
let (small, big) = current.partition(|_, v| v.abs() < eps);
let coarsened = small.rekey(|k| k.with(dim, ANY));
current = big.combine(&coarsened, |a, b| a + b);
}
current
}
fn bucket(&self, shared: &BTreeSet<Axis>) -> HashMap<Key, Vec<(Key, Money)>> {
let mut m: HashMap<Key, Vec<(Key, Money)>> = HashMap::new();
for (k, v) in &self.cells {
m.entry(k.project_set(shared))
.or_default()
.push((k.clone(), *v));
}
m
}
}
/// Splits `amt` into integer shares proportional to `weights[i] / wtotal`
/// using the largest-remainder method.
///
/// Each share starts at the Euclidean floor of `amt * weights[i] / wtotal`.
/// The units still owed (`amt` minus the sum of the floors) are then handed
/// out one at a time, largest Euclidean remainder first; ties go to the
/// smaller key so the result does not depend on the order of the inputs.
///
/// `keys` is parallel to `weights` and is used only for tie-breaking. When
/// `wtotal` is positive and equals the sum of `weights`, the shares sum to
/// exactly `amt`.
///
/// # Panics
///
/// Panics if `amt * weights[i]` overflows `i128`, or if `wtotal` is zero while
/// `weights` is non-empty (division by zero).
fn largest_remainder(amt: Money, weights: &[Money], wtotal: i128, keys: &[&Key]) -> Vec<Money> {
    let mut shares: Vec<Money> = Vec::with_capacity(weights.len());
    let mut remainders: Vec<i128> = Vec::with_capacity(weights.len());
    for &weight in weights {
        let scaled = amt
            .checked_mul(weight)
            .expect("allocate: weight*amount overflow");
        shares.push(scaled.div_euclid(wtotal));
        remainders.push(scaled.rem_euclid(wtotal));
    }
    let floor_sum: Money = shares.iter().sum();

    // Rank slots by remainder (descending), then by key (ascending).
    let mut order: Vec<usize> = (0..shares.len()).collect();
    order.sort_by(|&a, &b| {
        remainders[b]
            .cmp(&remainders[a])
            .then_with(|| keys[a].cmp(keys[b]))
    });

    // Count the owed units in i128 rather than casting to usize: a negative
    // value (only possible when `wtotal` disagrees with `weights`) must hand
    // out nothing instead of wrapping into a huge count.
    let mut owed = amt - floor_sum;
    for &slot in &order {
        if owed <= 0 {
            break;
        }
        shares[slot] += 1;
        owed -= 1;
    }
    shares
}
/// Repeatedly applies `step`, starting from `init`, until the result converges
/// or the pass budget runs out.
///
/// After each step, `converged(next, previous)` is consulted; the first `next`
/// for which it returns `true` is returned immediately. If `max_passes` steps
/// run without convergence, the state produced by the last step is returned.
/// With `max_passes == 0`, `init` is returned untouched.
pub fn fixed_point<S>(
    init: S,
    step: impl Fn(&S) -> S,
    converged: impl Fn(&S, &S) -> bool,
    max_passes: usize,
) -> S {
    let mut current = init;
    let mut passes_left = max_passes;
    while passes_left > 0 {
        passes_left -= 1;
        let candidate = step(&current);
        if converged(&candidate, &current) {
            return candidate;
        }
        current = candidate;
    }
    current
}
// Unit tests are compiled only for test builds; the `tests` module body is
// declared here and defined in a separate file.
#[cfg(test)]
mod tests;