use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
pub mod flow;
pub use flow::{Allocation, ExtId, FlowSpec, flow};
use std::hash::Hash;
use std::marker::PhantomData;
/// One entry handed to a strategy: an external id, a signed amount and a caller payload.
///
/// `original` is the amount the item was created with, while `amount` is the
/// value strategies in this file read and adjust.
#[derive(Clone)]
pub struct Item<E> {
    /// External identifier of the entry.
    pub id: ExtId,
    /// Amount the item was constructed with.
    pub original: i64,
    /// Working amount of the item.
    pub amount: i64,
    /// Caller-supplied payload passed to key / order / predicate closures.
    pub data: E,
}

impl<E> Item<E> {
    /// Creates an item whose `original` and `amount` both start at `amount`.
    pub fn new(id: ExtId, amount: i64, data: E) -> Self {
        let original = amount;
        Self {
            id,
            original,
            amount,
            data,
        }
    }
}
/// A set of allocations that a strategy decided belong together.
#[derive(Debug, Clone)]
pub struct Group {
    /// The allocations making up the group.
    pub members: Vec<Allocation>,
    /// Name of the strategy (or stage) that produced the group.
    pub origin: String,
    /// Net amount recorded for the group by its producer.
    pub net: i64,
    /// Optional human-readable explanation.
    pub reason: Option<String>,
}

impl Group {
    /// Ids of all members, in member order.
    pub fn member_ids(&self) -> Vec<ExtId> {
        let mut ids = Vec::with_capacity(self.members.len());
        for member in &self.members {
            ids.push(member.id);
        }
        ids
    }

    /// Number of members.
    pub fn size(&self) -> usize {
        self.members.len()
    }

    /// Absolute value of the recorded net.
    pub fn abs_net(&self) -> i64 {
        self.net.abs()
    }

    /// Largest absolute member amount, or 0 for an empty group.
    pub fn max_abs(&self) -> i64 {
        self.members
            .iter()
            .fold(0, |largest, member| largest.max(member.amount.abs()))
    }

    /// Smallest non-zero absolute member amount, or 0 when there is none.
    pub fn min_abs(&self) -> i64 {
        let mut smallest: Option<i64> = None;
        for member in &self.members {
            let magnitude = member.amount.abs();
            if magnitude > 0 && smallest.map_or(true, |current| magnitude < current) {
                smallest = Some(magnitude);
            }
        }
        smallest.unwrap_or(0)
    }

    /// Member count of the smaller side (positive vs. negative amounts).
    pub fn min_side(&self) -> usize {
        let (mut positives, mut negatives) = (0usize, 0usize);
        for member in &self.members {
            if member.amount > 0 {
                positives += 1;
            } else if member.amount < 0 {
                negatives += 1;
            }
        }
        positives.min(negatives)
    }

    /// True when the absolute net is within the slack `tol` grants for the member amounts.
    pub fn clean(&self, tol: Tol) -> bool {
        let net = self.abs_net();
        let legs = self.members.iter().map(|member| member.amount);
        net <= tol.slack_for(legs)
    }
}
/// Tolerance used to decide whether a net amount is close enough to zero.
///
/// * `Abs(t)` — a fixed slack of `t`.
/// * `Rel { bps, floor }` — `bps` basis points of a scale, but never less than
///   `floor`; [`Tol::slack_for`] uses the smallest non-zero leg as the scale.
/// * `RelMax { bps, floor }` — as `Rel`, but [`Tol::slack_for`] uses the
///   largest leg as the scale.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(untagged))]
pub enum Tol {
    Abs(i64),
    Rel { bps: i64, floor: i64 },
    RelMax { bps: i64, floor: i64 },
}

impl Tol {
    /// Slack for an already-absolute scale.
    ///
    /// The product is computed in `i128` and saturated at `i64::MAX`, so a
    /// huge scale or `bps` can no longer wrap into a negative value that the
    /// floor would then silently replace.
    fn slack_for_magnitude(&self, magnitude: u64) -> i64 {
        match *self {
            Tol::Abs(t) => t,
            Tol::Rel { bps, floor } | Tol::RelMax { bps, floor } => {
                // Negative `bps` / `floor` are treated as zero.
                let rel = magnitude as i128 * bps.max(0) as i128 / 10_000;
                let rel = rel.min(i64::MAX as i128) as i64;
                rel.max(floor.max(0))
            }
        }
    }

    /// Slack granted for a given `scale` (its sign is ignored).
    ///
    /// `Abs` returns its fixed value; the relative variants return
    /// `|scale| * bps / 10_000` (integer division), raised to `floor` and
    /// saturated at `i64::MAX`.
    pub fn slack(&self, scale: i64) -> i64 {
        self.slack_for_magnitude(scale.unsigned_abs())
    }

    /// Slack granted for a set of leg amounts.
    ///
    /// `Rel` scales by the smallest non-zero absolute leg, `RelMax` by the
    /// largest absolute leg; `Abs` ignores the legs. An empty iterator gives a
    /// scale of zero. Magnitudes are taken with `unsigned_abs`, so a leg of
    /// `i64::MIN` does not overflow.
    pub fn slack_for(&self, legs: impl Iterator<Item = i64>) -> i64 {
        let magnitude = match self {
            Tol::Abs(_) => 0,
            Tol::Rel { .. } => legs
                .map(i64::unsigned_abs)
                .filter(|&v| v > 0)
                .min()
                .unwrap_or(0),
            Tol::RelMax { .. } => legs.map(i64::unsigned_abs).max().unwrap_or(0),
        };
        self.slack_for_magnitude(magnitude)
    }
}

/// A bare integer is an absolute tolerance.
impl From<i64> for Tol {
    fn from(t: i64) -> Self {
        Tol::Abs(t)
    }
}
/// Outcome of running a [`Strategy`]: the groups it formed plus whatever it left over.
pub struct Resolution<E> {
/// Groups produced by the strategy.
pub groups: Vec<Group>,
/// Items that were not (or not fully) placed into a group.
pub residual: Vec<Item<E>>,
}
/// A matching step: consumes a bag of items and returns groups plus residual items.
///
/// `run` takes `&mut self`, so implementations may keep state between calls.
pub trait Strategy<E> {
/// Processes `bag` and reports the groups formed and the items left over.
fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E>;
}
impl<E> Strategy<E> for Box<dyn Strategy<E>> {
fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
(**self).run(bag)
}
}
/// Runs its steps in order, feeding each step the residual of the previous one.
struct Seq<E> {
    steps: Vec<Box<dyn Strategy<E>>>,
}

impl<E> Strategy<E> for Seq<E> {
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        // Per-step timing goes to stderr when FLORECON_TIME is set; it is
        // compiled out on wasm32.
        #[cfg(not(target_arch = "wasm32"))]
        let timed = std::env::var_os("FLORECON_TIME").is_some();
        #[cfg(target_arch = "wasm32")]
        let timed = false;

        let mut collected = Vec::new();
        let mut remaining = bag;
        for (i, step) in self.steps.iter_mut().enumerate() {
            let n_in = remaining.len();
            let started = if timed {
                Some(std::time::Instant::now())
            } else {
                None
            };
            let outcome = step.run(remaining);
            if let Some(started) = started {
                let grouped: usize = outcome.groups.iter().map(|g| g.members.len()).sum();
                let left = outcome.residual.len();
                let millis = started.elapsed().as_secs_f64() * 1000.0;
                eprintln!(
                    " seq step {i}: {n_in:>7} in -> {:>7} grouped, {:>7} residual [{:>6.1} ms]",
                    grouped, left, millis,
                );
            }
            let Resolution { groups, residual } = outcome;
            collected.extend(groups);
            remaining = residual;
        }
        Resolution {
            groups: collected,
            residual: remaining,
        }
    }
}
/// Builds a strategy that runs `steps` one after another, each on the
/// residual left by the previous step.
pub fn seq<E: 'static>(steps: Vec<Box<dyn Strategy<E>>>) -> Box<dyn Strategy<E>> {
    let pipeline = Seq { steps };
    Box::new(pipeline)
}
/// Wraps a strategy and stamps `tag` onto the reason of every group it produces.
struct Labeled<E> {
    tag: String,
    inner: Box<dyn Strategy<E>>,
}

impl<E> Strategy<E> for Labeled<E> {
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let mut outcome = self.inner.run(bag);
        for group in outcome.groups.iter_mut() {
            // An existing reason becomes "<tag>: <reason>"; otherwise the tag
            // alone is the reason.
            let stamped = if let Some(detail) = group.reason.take() {
                format!("{}: {}", self.tag, detail)
            } else {
                self.tag.clone()
            };
            group.reason = Some(stamped);
        }
        outcome
    }
}
/// Wraps `inner` so that every group it emits has `tag` prepended to its reason.
pub fn labeled<E: 'static>(
    tag: impl Into<String>,
    inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>> {
    let tag: String = tag.into();
    Box::new(Labeled { tag, inner })
}
/// Runs `inner`, keeps only the groups accepted by `pred`, and hands the
/// members of rejected groups back as residual.
struct Filter<E, FP> {
    pred: FP,
    inner: Box<dyn Strategy<E>>,
}

impl<E, FP> Strategy<E> for Filter<E, FP>
where
    E: Clone,
    FP: Fn(&Group) -> bool,
{
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        // Snapshot the input so a rejected member that is absent from the
        // inner residual can be rebuilt.
        let mut snapshot: HashMap<ExtId, Item<E>> = HashMap::with_capacity(bag.len());
        for item in &bag {
            snapshot.insert(item.id, item.clone());
        }

        let Resolution {
            groups,
            mut residual,
        } = self.inner.run(bag);

        // Position of each id inside `residual` (the last occurrence wins).
        let mut slot_of: HashMap<ExtId, usize> = HashMap::with_capacity(residual.len());
        for (pos, item) in residual.iter().enumerate() {
            slot_of.insert(item.id, pos);
        }

        let mut accepted = Vec::with_capacity(groups.len());
        for group in groups {
            if (self.pred)(&group) {
                accepted.push(group);
                continue;
            }
            for alloc in &group.members {
                if let Some(&pos) = slot_of.get(&alloc.id) {
                    // Already partly in the residual: give the amount back.
                    residual[pos].amount += alloc.amount;
                } else if let Some(source) = snapshot.get(&alloc.id) {
                    slot_of.insert(alloc.id, residual.len());
                    residual.push(Item {
                        id: alloc.id,
                        original: source.original,
                        amount: alloc.amount,
                        data: source.data.clone(),
                    });
                }
            }
        }

        Resolution {
            groups: accepted,
            residual,
        }
    }
}
/// Wraps `inner` so that only groups satisfying `pred` are kept; members of
/// rejected groups are returned to the residual.
pub fn accept_if<E: Clone + 'static, FP>(
    pred: FP,
    inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>>
where
    FP: Fn(&Group) -> bool + 'static,
{
    let filter = Filter { pred, inner };
    Box::new(filter)
}
struct WholeNet<E> {
tol: Tol,
inner: Box<dyn Strategy<E>>,
}
impl<E: Clone> Strategy<E> for WholeNet<E> {
fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
let src: HashMap<ExtId, Item<E>> = bag.iter().map(|i| (i.id, i.clone())).collect();
let r = self.inner.run(bag);
let mut resid: HashMap<ExtId, Item<E>> = HashMap::new();
for it in r.residual {
resid
.entry(it.id)
.and_modify(|e| e.amount += it.amount)
.or_insert(it);
}
let comps = group_components(&r.groups);
let mut out_groups: Vec<Group> = Vec::new();
let mut out_residual: Vec<Item<E>> = Vec::new();
for comp in comps {
let mut member_ids: Vec<ExtId> = Vec::new();
let mut seen: HashSet<ExtId> = HashSet::new();
for &gi in &comp {
for a in &r.groups[gi].members {
if seen.insert(a.id) {
member_ids.push(a.id);
}
}
}
member_ids.sort_unstable();
let wholes: Vec<(ExtId, i64)> = member_ids
.iter()
.filter_map(|&id| src.get(&id).map(|i| (id, i.original)))
.collect();
let net: i64 = wholes.iter().map(|&(_, o)| o).sum();
let tol = self.tol.slack_for(wholes.iter().map(|&(_, o)| o));
if net.abs() <= tol {
for &(id, _) in &wholes {
resid.remove(&id);
}
let (origin, reason) = if comp.len() == 1 {
let g = &r.groups[comp[0]];
(g.origin.clone(), g.reason.clone())
} else {
("settlement".to_string(), None)
};
out_groups.push(Group {
members: wholes
.iter()
.map(|&(id, o)| Allocation { id, amount: o })
.collect(),
origin,
net,
reason,
});
} else {
for &(id, o) in &wholes {
resid.remove(&id);
if let Some(it) = src.get(&id) {
out_residual.push(Item {
id,
original: it.original,
amount: o,
data: it.data.clone(),
});
}
}
}
}
out_residual.extend(resid.into_values());
out_residual.sort_by_key(|i| i.id);
Resolution {
groups: out_groups,
residual: out_residual,
}
}
}
/// Wraps `inner` so that its groups are only accepted when the whole
/// (`original`) amounts of the involved items net to within `tol`.
pub fn whole_net<E: Clone + 'static>(
    tol: impl Into<Tol>,
    inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>> {
    let tol: Tol = tol.into();
    Box::new(WholeNet { tol, inner })
}
/// Strategy wrapper holding an `origin` label and the inner strategy whose
/// groups are to be coalesced.
struct Coalesce<E> {
// Origin string given to every coalesced group.
origin: String,
// Strategy whose output groups are merged.
inner: Box<dyn Strategy<E>>,
}
/// Partitions group indices into connected components, where two groups are
/// connected when they share a member id.
///
/// Components are returned in order of their first group index, and the
/// indices inside each component are ascending.
fn group_components(groups: &[Group]) -> Vec<Vec<usize>> {
    /// Union-find lookup that shortens the path as it walks up.
    fn root_of(parent: &mut [usize], mut node: usize) -> usize {
        loop {
            let up = parent[node];
            if up == node {
                return node;
            }
            let grand = parent[up];
            parent[node] = grand;
            node = grand;
        }
    }

    let mut parent: Vec<usize> = (0..groups.len()).collect();

    // First group in which each id was seen; later sightings are unioned with it.
    let mut first_seen: HashMap<ExtId, usize> = HashMap::new();
    for (gi, group) in groups.iter().enumerate() {
        for alloc in &group.members {
            match first_seen.entry(alloc.id) {
                std::collections::hash_map::Entry::Vacant(slot) => {
                    slot.insert(gi);
                }
                std::collections::hash_map::Entry::Occupied(slot) => {
                    let here = root_of(&mut parent, gi);
                    let there = root_of(&mut parent, *slot.get());
                    if here != there {
                        parent[here] = there;
                    }
                }
            }
        }
    }

    // Bucket group indices by root, keeping roots in first-appearance order.
    let mut slot_of_root: HashMap<usize, usize> = HashMap::new();
    let mut components: Vec<Vec<usize>> = Vec::new();
    for gi in 0..groups.len() {
        let root = root_of(&mut parent, gi);
        let slot = *slot_of_root.entry(root).or_insert_with(|| {
            components.push(Vec::new());
            components.len() - 1
        });
        components[slot].push(gi);
    }
    components
}
/// Merges groups that share member ids into one group per connected component.
///
/// Allocations of the same id are summed, ids that sum to zero are dropped,
/// and a component with no remaining members yields no group. Members are
/// ordered by id. A single-group component keeps its reason; a merged one
/// gets "coalesced N groups". Every output group carries `origin`.
fn coalesce_groups(groups: &[Group], origin: &str) -> Vec<Group> {
    let components = group_components(groups);
    let mut merged = Vec::with_capacity(components.len());
    for component in components.iter() {
        let mut totals: BTreeMap<ExtId, i64> = BTreeMap::new();
        let allocations = component
            .iter()
            .flat_map(|&gi| groups[gi].members.iter());
        for alloc in allocations {
            *totals.entry(alloc.id).or_default() += alloc.amount;
        }

        let mut members: Vec<Allocation> = Vec::new();
        let mut net: i64 = 0;
        for (id, amount) in totals {
            if amount != 0 {
                net += amount;
                members.push(Allocation { id, amount });
            }
        }
        if members.is_empty() {
            continue;
        }

        let reason = match component.as_slice() {
            [only] => groups[*only].reason.clone(),
            many => Some(format!("coalesced {} groups", many.len())),
        };
        merged.push(Group {
            members,
            origin: origin.to_string(),
            net,
            reason,
        });
    }
    merged
}
/// Runs the inner strategy and coalesces its groups; the residual passes
/// through untouched.
impl<E> Strategy<E> for Coalesce<E> {
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let Resolution { groups, residual } = self.inner.run(bag);
        let groups = coalesce_groups(&groups, &self.origin);
        Resolution { groups, residual }
    }
}
/// Wraps `inner` so that its groups sharing member ids are merged into single
/// groups labelled with `origin`.
pub fn coalesce<E: 'static>(
    origin: impl Into<String>,
    inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>> {
    let origin: String = origin.into();
    Box::new(Coalesce { origin, inner })
}
/// Builds the `flow` strategy for `spec` and wraps it in [`coalesce`] with the
/// origin "flow".
pub fn settle<E: Clone + 'static>(spec: FlowSpec<E>) -> Box<dyn Strategy<E>> {
    let solver = flow(spec);
    coalesce("flow", solver)
}
/// Strategy wrapper that re-runs `inner` on its own residual, bounded by `max_passes`.
struct FixedPoint<E> {
// Strategy applied on every pass.
inner: Box<dyn Strategy<E>>,
// Upper bound on the number of passes.
max_passes: usize,
}
/// Order-independent summary of a residual: its `(id, amount)` pairs, sorted.
fn residual_fingerprint<E>(items: &[Item<E>]) -> Vec<(ExtId, i64)> {
    let mut print = Vec::with_capacity(items.len());
    for item in items {
        print.push((item.id, item.amount));
    }
    print.sort_unstable();
    print
}
/// Repeats the inner strategy on its own residual until the residual is
/// empty, stops changing, or `max_passes` passes have run.
impl<E> Strategy<E> for FixedPoint<E> {
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let mut groups = Vec::new();
        let mut residual = bag;
        let mut before = residual_fingerprint(&residual);
        let mut passes_left = self.max_passes;
        while passes_left > 0 && !residual.is_empty() {
            passes_left -= 1;
            let outcome = self.inner.run(std::mem::take(&mut residual));
            groups.extend(outcome.groups);
            residual = outcome.residual;

            // A pass that leaves the residual fingerprint unchanged made no
            // progress, so further passes are skipped.
            let after = residual_fingerprint(&residual);
            let stalled = after == before;
            before = after;
            if stalled {
                break;
            }
        }
        Resolution { groups, residual }
    }
}
/// Wraps `inner` so it is re-applied to its own residual until nothing
/// changes, for at most `max_passes` passes (a value of 0 is raised to 1).
pub fn fixed_point<E: 'static>(
    inner: Box<dyn Strategy<E>>,
    max_passes: usize,
) -> Box<dyn Strategy<E>> {
    let max_passes = std::cmp::max(max_passes, 1);
    Box::new(FixedPoint { inner, max_passes })
}
/// Factory that builds the child strategy for a shard key.
type ShardFactory<E, K> = dyn Fn(&K) -> Box<dyn Strategy<E>>;

/// Splits the bag by `key` and runs one child strategy per shard.
///
/// Children are created lazily through `factory` and kept in `children`, so
/// the same child instance serves its key on every later run.
struct PartitionBy<E, K, FK> {
    key: FK,
    factory: Box<ShardFactory<E, K>>,
    children: HashMap<K, Box<dyn Strategy<E>>>,
}

impl<E, K, FK> Strategy<E> for PartitionBy<E, K, FK>
where
    K: Hash + Eq + Clone,
    FK: Fn(&E) -> K,
{
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let mut shards: HashMap<K, Vec<Item<E>>> = HashMap::new();
        for item in bag {
            shards.entry((self.key)(&item.data)).or_default().push(item);
        }
        // Children created on earlier runs are still invoked (with an empty
        // bag) when no item maps to their key this time.
        for k in self.children.keys() {
            shards.entry(k.clone()).or_default();
        }

        let mut groups = Vec::new();
        let mut residual = Vec::new();
        for (k, items) in shards {
            // One lookup through the Entry API: reuse the existing child or
            // build and store a fresh one, with no unwrap and no key clone.
            let child = match self.children.entry(k) {
                std::collections::hash_map::Entry::Occupied(slot) => slot.into_mut(),
                std::collections::hash_map::Entry::Vacant(slot) => {
                    let fresh = (self.factory)(slot.key());
                    slot.insert(fresh)
                }
            };
            let r = child.run(items);
            groups.extend(r.groups);
            residual.extend(r.residual);
        }
        Resolution { groups, residual }
    }
}
pub fn partition_by<E: 'static, K, FK, FF>(key: FK, factory: FF) -> Box<dyn Strategy<E>>
where
K: Hash + Eq + Clone + 'static,
FK: Fn(&E) -> K + 'static,
FF: Fn() -> Box<dyn Strategy<E>> + 'static,
{
Box::new(PartitionBy {
key,
factory: Box::new(move |_k| factory()),
children: HashMap::new(),
})
}
pub fn partition_by_with<E: 'static, K, FK, FF>(key: FK, factory: FF) -> Box<dyn Strategy<E>>
where
K: Hash + Eq + Clone + 'static,
FK: Fn(&E) -> K + 'static,
FF: Fn(&K) -> Box<dyn Strategy<E>> + 'static,
{
Box::new(PartitionBy {
key,
factory: Box::new(factory),
children: HashMap::new(),
})
}
/// Applies `inner` only to the items whose payload satisfies `pred`; all
/// other items go straight to the residual.
struct When<E, FP> {
    pred: FP,
    inner: Box<dyn Strategy<E>>,
}

impl<E, FP> Strategy<E> for When<E, FP>
where
    FP: Fn(&E) -> bool,
{
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let (selected, bypassed): (Vec<Item<E>>, Vec<Item<E>>) = bag
            .into_iter()
            .partition(|item| (self.pred)(&item.data));
        let mut outcome = self.inner.run(selected);
        // Bypassed items follow whatever the inner strategy left over.
        outcome.residual.extend(bypassed);
        outcome
    }
}
/// Restricts `inner` to the items whose payload satisfies `pred`.
pub fn when<E: 'static, FP>(pred: FP, inner: Box<dyn Strategy<E>>) -> Box<dyn Strategy<E>>
where
    FP: Fn(&E) -> bool + 'static,
{
    let gate = When { pred, inner };
    Box::new(gate)
}
/// Strategy that forms no groups and returns the whole bag as residual.
struct Identity;

impl<E> Strategy<E> for Identity {
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let groups = Vec::new();
        Resolution {
            groups,
            residual: bag,
        }
    }
}
/// Returns the no-op strategy: no groups, everything left as residual.
pub fn identity<E: 'static>() -> Box<dyn Strategy<E>> {
    let noop = Identity;
    Box::new(noop)
}
/// Runs `inner` over bands of items that are close together on an ordering key.
///
/// Items are sorted by `order`, cut into bands that start at an item's key and
/// cover `width` key units, and `inner` is run on each band together with the
/// residual carried over from earlier bands that is still within `width` of
/// the band's first key.
struct Windowed<E, FO> {
    order: FO,
    width: i64,
    inner: Box<dyn Strategy<E>>,
    _e: PhantomData<E>,
}

impl<E, FO> Strategy<E> for Windowed<E, FO>
where
    FO: Fn(&E) -> i64,
{
    fn run(&mut self, mut bag: Vec<Item<E>>) -> Resolution<E> {
        // Key arithmetic is widened to i128. In i64, `key + width` can
        // overflow for large keys or widths, which panics in debug builds and,
        // when it wraps in release builds, leaves the band empty so the outer
        // loop never advances.
        let w = i128::from(self.width.max(1));
        bag.sort_by_key(|i| (self.order)(&i.data));
        let mut groups = Vec::new();
        let mut residual = Vec::new();
        let mut carry: Vec<Item<E>> = Vec::new();
        let mut it = bag.into_iter().peekable();
        while let Some(first) = it.peek() {
            let band_bottom = i128::from((self.order)(&first.data));
            let band_top = band_bottom + w;

            // The band is every upcoming item whose key is below `band_top`.
            // The first item always qualifies because `w >= 1`, so each
            // iteration consumes at least one item.
            let mut band = Vec::new();
            while let Some(item) =
                it.next_if(|item| i128::from((self.order)(&item.data)) < band_top)
            {
                band.push(item);
            }

            // Carried items that have fallen more than `w` behind this band
            // are final residual; the rest get another chance.
            let mut keep = Vec::new();
            for item in carry.drain(..) {
                if i128::from((self.order)(&item.data)) + w >= band_bottom {
                    keep.push(item);
                } else {
                    residual.push(item);
                }
            }
            keep.extend(band);

            let r = self.inner.run(keep);
            groups.extend(r.groups);
            carry = r.residual;
        }
        residual.extend(carry);
        Resolution { groups, residual }
    }
}
pub fn windowed<E: 'static, FO>(
order: FO,
width: i64,
inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>>
where
FO: Fn(&E) -> i64 + 'static,
{
Box::new(Windowed {
order,
width,
inner,
_e: PhantomData,
})
}
/// Pairs one positive with one negative item of equal magnitude inside each key bucket.
struct ExactOneToOne<E, FK> {
    key: FK,
    _e: PhantomData<E>,
}

impl<E, FK> Strategy<E> for ExactOneToOne<E, FK>
where
    FK: Fn(&E) -> Option<u64>,
{
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        /// Positive items first, negative items second.
        type BySign<T> = (Vec<Item<T>>, Vec<Item<T>>);

        // Items without a key, or with a zero amount, are never paired.
        let mut residual = Vec::new();
        let mut by_key: HashMap<u64, Vec<Item<E>>> = HashMap::new();
        for item in bag {
            let key = (self.key)(&item.data).filter(|_| item.amount != 0);
            match key {
                Some(k) => by_key.entry(k).or_default().push(item),
                None => residual.push(item),
            }
        }

        let mut groups = Vec::new();
        for items in by_key.into_values() {
            let mut by_magnitude: HashMap<i64, BySign<E>> = HashMap::new();
            for item in items {
                let (credits, debits) = by_magnitude.entry(item.amount.abs()).or_default();
                let side = if item.amount > 0 { credits } else { debits };
                side.push(item);
            }

            for (mut pos, mut neg) in by_magnitude.into_values() {
                pos.sort_unstable_by_key(|i| i.id);
                neg.sort_unstable_by_key(|i| i.id);

                // Pairs are taken from the high-id end of each side; the
                // low-id surplus stays unmatched.
                // NOTE(review): confirm that preferring the largest ids is intended.
                let pairs = pos.len().min(neg.len());
                let matched_pos = pos.split_off(pos.len() - pairs);
                let matched_neg = neg.split_off(neg.len() - pairs);
                let matched = matched_pos
                    .into_iter()
                    .rev()
                    .zip(matched_neg.into_iter().rev());
                groups.extend(matched.map(|(p, n)| Group {
                    members: vec![
                        Allocation {
                            id: p.id,
                            amount: p.amount,
                        },
                        Allocation {
                            id: n.id,
                            amount: n.amount,
                        },
                    ],
                    origin: "exact_1to1".to_string(),
                    net: 0,
                    reason: Some("exact 1:1 pair".to_string()),
                }));

                residual.extend(pos);
                residual.extend(neg);
            }
        }
        Resolution { groups, residual }
    }
}
pub fn exact_1to1<E: 'static, FK>(key: FK) -> Box<dyn Strategy<E>>
where
FK: Fn(&E) -> Option<u64> + 'static,
{
Box::new(ExactOneToOne {
key,
_e: PhantomData,
})
}
/// Groups a whole key bucket when its items net to within tolerance.
struct AggNet<E, FK> {
    key: FK,
    tol: Tol,
    _e: PhantomData<E>,
}

impl<E, FK> Strategy<E> for AggNet<E, FK>
where
    FK: Fn(&E) -> u64,
{
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let mut buckets: HashMap<u64, Vec<Item<E>>> = HashMap::new();
        for item in bag {
            let k = (self.key)(&item.data);
            buckets.entry(k).or_default().push(item);
        }

        let mut groups = Vec::new();
        let mut residual = Vec::new();
        for items in buckets.into_values() {
            let mut sum: i64 = 0;
            let mut has_pos = false;
            let mut has_neg = false;
            for item in &items {
                sum += item.amount;
                has_pos |= item.amount > 0;
                has_neg |= item.amount < 0;
            }
            let slack = self.tol.slack_for(items.iter().map(|i| i.amount));

            // A bucket nets only if it has at least two items, both signs are
            // present, and the sum is within the slack.
            let nets = items.len() >= 2 && sum.abs() <= slack && has_pos && has_neg;
            if !nets {
                residual.extend(items);
                continue;
            }

            let mut members = Vec::with_capacity(items.len());
            for item in &items {
                members.push(Allocation {
                    id: item.id,
                    amount: item.amount,
                });
            }
            groups.push(Group {
                members,
                origin: "agg_net".to_string(),
                net: sum,
                reason: Some("aggregate net".to_string()),
            });
        }
        Resolution { groups, residual }
    }
}
pub fn agg_net<E: 'static, FK>(key: FK, tol: impl Into<Tol>) -> Box<dyn Strategy<E>>
where
FK: Fn(&E) -> u64 + 'static,
{
Box::new(AggNet {
key,
tol: tol.into(),
_e: PhantomData,
})
}
/// Cuts the ordered item sequence into segments wherever the running balance
/// returns to within `tol` of zero.
struct RunningZero<E, FO> {
    order: FO,
    tol: i64,
    _e: PhantomData<E>,
}

impl<E, FO> Strategy<E> for RunningZero<E, FO>
where
    FO: Fn(&E) -> i64,
{
    fn run(&mut self, mut bag: Vec<Item<E>>) -> Resolution<E> {
        bag.sort_by_key(|item| (self.order)(&item.data));

        let mut groups = Vec::new();
        let mut open: Vec<Item<E>> = Vec::new();
        let mut balance: i64 = 0;
        for item in bag {
            balance += item.amount;
            open.push(item);

            let balanced = balance.abs() <= self.tol;
            if !(balanced && open.len() >= 2) {
                continue;
            }

            // Close the segment: its items become one group and the running
            // balance restarts from zero.
            let net = std::mem::replace(&mut balance, 0);
            let members = open
                .drain(..)
                .map(|closed| Allocation {
                    id: closed.id,
                    amount: closed.amount,
                })
                .collect();
            groups.push(Group {
                members,
                origin: "running_zero".to_string(),
                net,
                reason: Some("running-balance zero".to_string()),
            });
        }

        // Whatever is still open at the end never cleared.
        Resolution {
            groups,
            residual: open,
        }
    }
}
pub fn running_zero<E: 'static, FO>(order: FO, tol: i64) -> Box<dyn Strategy<E>>
where
FO: Fn(&E) -> i64 + 'static,
{
Box::new(RunningZero {
order,
tol,
_e: PhantomData,
})
}
/// Groups items that share a signal value when their amounts net to within tolerance.
struct SignalGroup<E, FS> {
    signals: FS,
    tol: Tol,
    cap: usize,
    _e: PhantomData<E>,
}

impl<E, FS> Strategy<E> for SignalGroup<E, FS>
where
    FS: Fn(&E) -> Vec<u64>,
{
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let amounts: Vec<i64> = bag.iter().map(|item| item.amount).collect();

        // signal value -> positions of the items carrying it
        let mut holders: HashMap<u64, Vec<usize>> = HashMap::new();
        for (pos, item) in bag.iter().enumerate() {
            for signal in (self.signals)(&item.data) {
                holders.entry(signal).or_default().push(pos);
            }
        }

        // Signals are visited from the fewest holders to the most, with ties
        // broken by signal value.
        let mut visit: Vec<(usize, u64)> = holders
            .iter()
            .map(|(signal, who)| (who.len(), *signal))
            .collect();
        visit.sort_unstable();

        let mut taken = vec![false; bag.len()];
        let mut groups = Vec::new();
        for (_, signal) in visit {
            let free: Vec<usize> = holders[&signal]
                .iter()
                .copied()
                .filter(|&pos| !taken[pos])
                .collect();
            if free.len() < 2 || free.len() > self.cap {
                continue;
            }

            let mut sum: i64 = 0;
            let mut has_pos = false;
            let mut has_neg = false;
            for &pos in &free {
                let amount = amounts[pos];
                sum += amount;
                has_pos |= amount > 0;
                has_neg |= amount < 0;
            }
            let slack = self.tol.slack_for(free.iter().map(|&pos| amounts[pos]));
            if !(sum.abs() <= slack && has_pos && has_neg) {
                continue;
            }

            let mut members = Vec::with_capacity(free.len());
            for &pos in &free {
                taken[pos] = true;
                members.push(Allocation {
                    id: bag[pos].id,
                    amount: amounts[pos],
                });
            }
            groups.push(Group {
                members,
                origin: "signal_group".to_string(),
                net: sum,
                reason: Some("shared reference".to_string()),
            });
        }

        let residual = bag
            .into_iter()
            .zip(taken)
            .filter_map(|(item, used)| if used { None } else { Some(item) })
            .collect();
        Resolution { groups, residual }
    }
}
pub fn signal_group<E: 'static, FS>(
signals: FS,
tol: impl Into<Tol>,
cap: usize,
) -> Box<dyn Strategy<E>>
where
FS: Fn(&E) -> Vec<u64> + 'static,
{
Box::new(SignalGroup {
signals,
tol: tol.into(),
cap,
_e: PhantomData,
})
}
/// Per-item bookkeeping kept by `Pivot` while the inner strategy runs on the
/// alternative amounts.
#[derive(Clone)]
struct PivotMeta<E> {
// The item exactly as it was handed to `Pivot`.
outer: Item<E>,
// Value returned by the pivot's `amount` closure for this item.
alt_original: i64,
}
/// Strategy wrapper that runs `inner` on amounts produced by the `amount`
/// closure and converts the result back to the items' own amounts.
struct Pivot<E, FA> {
// Closure giving the alternative amount for a payload.
amount: FA,
// Strategy run on the alternative amounts.
inner: Box<dyn Strategy<E>>,
}
/// Scales `total` by the fraction `part / denom`, truncating toward zero.
///
/// Returns 0 when `denom`, `total` or `part` is 0. The product is formed in
/// `i128`, and the quotient is clamped to the `i64` range instead of being
/// truncated by an `as` cast, so an out-of-range result saturates rather than
/// wrapping to an unrelated value.
fn prorate(total: i64, part: i64, denom: i64) -> i64 {
    if denom == 0 || total == 0 || part == 0 {
        return 0;
    }
    let scaled = i128::from(part) * i128::from(total) / i128::from(denom);
    scaled.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64
}
/// Runs the inner strategy on alternative amounts and maps the outcome back.
///
/// Every item is re-expressed with the value from the `amount` closure as its
/// `original` and a prorated share of it as its `amount`. After the inner
/// strategy has run, each allocation and residual is converted back into the
/// item's own amount, and the rounding remainder is assigned so that the
/// parts of an item add up to its outer `amount`.
impl<E, FA> Strategy<E> for Pivot<E, FA>
where
E: Clone,
FA: Fn(&E) -> i64,
{
fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
// Outer item and alternative original per id, kept for the conversion back.
let mut meta: BTreeMap<ExtId, PivotMeta<E>> = BTreeMap::new();
// Build the bag seen by the inner strategy: same ids and payloads,
// amounts expressed in the alternative measure.
let inner_bag: Vec<Item<E>> = bag
.into_iter()
.map(|outer| {
let alt_original = (self.amount)(&outer.data);
// Scale the alternative amount by the item's amount/original ratio.
let alt_amount = prorate(alt_original, outer.amount, outer.original);
let id = outer.id;
let data = outer.data.clone();
meta.insert(
id,
PivotMeta {
outer,
alt_original,
},
);
Item {
id,
original: alt_original,
amount: alt_amount,
data,
}
})
.collect();
let mut res = self.inner.run(inner_bag);
{
// Total alternative amount each id received across all groups.
let mut group_pivot: BTreeMap<ExtId, i64> = BTreeMap::new();
for g in &res.groups {
for a in &g.members {
*group_pivot.entry(a.id).or_insert(0) += a.amount;
}
}
// Ids whose non-zero grouped amount converts back to zero in outer
// terms: their allocations are taken out of the groups again.
let mut dissolve: BTreeSet<ExtId> = BTreeSet::new();
for (id, &gp) in &group_pivot {
if gp == 0 {
continue;
}
let Some(m) = meta.get(id) else { continue };
if prorate(m.outer.amount, gp, m.alt_original) == 0 {
dissolve.insert(*id);
}
}
if !dissolve.is_empty() {
// Alternative amount removed from groups, per dissolved id.
let mut moved: BTreeMap<ExtId, i64> = BTreeMap::new();
for g in &mut res.groups {
g.members.retain(|a| {
if dissolve.contains(&a.id) {
*moved.entry(a.id).or_insert(0) += a.amount;
false
} else {
true
}
});
// Keep the group's net in step with its remaining members.
g.net = g.members.iter().map(|a| a.amount).sum();
}
res.groups.retain(|g| !g.members.is_empty());
// Hand the removed amounts back to the inner residual.
for (id, amt) in moved {
if amt == 0 {
continue;
}
if let Some(item) = res.residual.iter_mut().find(|i| i.id == id) {
item.amount += amt;
} else if let Some(m) = meta.get(&id) {
res.residual.push(Item {
id,
original: m.alt_original,
amount: amt,
data: m.outer.data.clone(),
});
}
}
}
}
// For every id, the places its amount ended up:
// (group index, Some(member index), amount) for a group member, or
// (residual index, None, amount) for a residual entry.
// Group parts are recorded before residual parts.
let mut parts: BTreeMap<ExtId, Vec<(usize, Option<usize>, i64)>> = BTreeMap::new();
for (gi, g) in res.groups.iter().enumerate() {
for (mi, a) in g.members.iter().enumerate() {
parts
.entry(a.id)
.or_default()
.push((gi, Some(mi), a.amount));
}
}
for (ri, item) in res.residual.iter().enumerate() {
parts
.entry(item.id)
.or_default()
.push((ri, None, item.amount));
}
// Converted (outer) amounts, laid out like the groups' members and the residual.
let mut group_amounts: Vec<Vec<i64>> = res
.groups
.iter()
.map(|g| vec![0; g.members.len()])
.collect();
let mut residual_amounts: Vec<i64> = vec![0; res.residual.len()];
// Ids that appear in at least one group or residual entry.
let accounted: BTreeSet<ExtId> = parts.keys().copied().collect();
for (id, ps) in parts {
let Some(m) = meta.get(&id) else { continue };
let mut converted = Vec::with_capacity(ps.len());
let mut sum = 0i64;
for (_, _, amt) in &ps {
let v = prorate(m.outer.amount, *amt, m.alt_original);
converted.push(v);
sum += v;
}
// The last part absorbs the difference, so the converted parts add
// up to the item's outer amount.
if let Some(last) = converted.last_mut() {
*last += m.outer.amount - sum;
}
for ((idx, mi, _), v) in ps.into_iter().zip(converted) {
if let Some(mi) = mi {
group_amounts[idx][mi] = v;
} else {
residual_amounts[idx] = v;
}
}
}
// Rewrite group members with their converted amounts; members that became
// zero and groups that became empty are dropped.
let groups = res
.groups
.into_iter()
.enumerate()
.filter_map(|(gi, mut g)| {
for (mi, a) in g.members.iter_mut().enumerate() {
a.amount = group_amounts[gi][mi];
}
g.members.retain(|a| a.amount != 0);
if g.members.is_empty() {
return None;
}
g.net = g.members.iter().map(|a| a.amount).sum();
Some(g)
})
.collect();
// Rewrite residual items with the outer original and converted amount.
// An entry whose id is no longer in `meta` (unknown, or already consumed
// by an earlier entry with the same id) is dropped, as are zero amounts.
let mut residual: Vec<Item<E>> = res
.residual
.into_iter()
.enumerate()
.filter_map(|(ri, mut i)| {
let m = meta.remove(&i.id)?;
i.original = m.outer.original;
i.amount = residual_amounts[ri];
(i.amount != 0).then_some(i)
})
.collect();
// Items that appear in neither a group nor the inner residual are returned
// unchanged, unless their amount is zero.
for (id, m) in meta {
if !accounted.contains(&id) && m.outer.amount != 0 {
residual.push(Item {
id,
original: m.outer.original,
amount: m.outer.amount,
data: m.outer.data,
});
}
}
Resolution { groups, residual }
}
}
/// Wraps `inner` so that it matches on the value returned by `amount` instead
/// of the items' own amounts; results are converted back afterwards.
pub fn pivot<E: Clone + 'static, FA>(
    amount: FA,
    inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>>
where
    FA: Fn(&E) -> i64 + 'static,
{
    let strategy = Pivot { amount, inner };
    Box::new(strategy)
}
/// How the soak strategies turn absorbed items into groups.
#[derive(Clone, Copy)]
pub enum SoakMode {
/// Each absorbed item becomes its own one-member group.
Singleton,
/// Absorbed items are collected per key and emitted as one group per key.
Bucket,
}
/// Absorbs one item according to `mode`.
///
/// In `Bucket` mode the item is appended to the bucket for `key`, which must
/// be `Some` (the function panics otherwise). In `Singleton` mode a
/// one-member group with the item's amount as its net is pushed to `groups`.
fn soak_emit<E, K>(
    groups: &mut Vec<Group>,
    buckets: &mut HashMap<K, Vec<Item<E>>>,
    mode: SoakMode,
    origin: &str,
    key: Option<K>,
    item: Item<E>,
) where
    K: Hash + Eq,
{
    match mode {
        SoakMode::Bucket => {
            let bucket = buckets.entry(key.unwrap()).or_default();
            bucket.push(item);
        }
        SoakMode::Singleton => {
            let amount = item.amount;
            let only = Allocation {
                id: item.id,
                amount,
            };
            groups.push(Group {
                members: vec![only],
                origin: origin.to_string(),
                net: amount,
                reason: None,
            });
        }
    }
}
/// Turns every bucket into one group whose origin is "<origin>:<key>" and
/// whose net is the sum of the bucket's amounts.
fn soak_flush<E, K: ToString>(
    groups: &mut Vec<Group>,
    buckets: HashMap<K, Vec<Item<E>>>,
    origin: &str,
) {
    groups.extend(buckets.into_iter().map(|(k, items)| {
        let mut net: i64 = 0;
        let mut members = Vec::with_capacity(items.len());
        for item in &items {
            net += item.amount;
            members.push(Allocation {
                id: item.id,
                amount: item.amount,
            });
        }
        Group {
            members,
            origin: format!("{}:{}", origin, k.to_string()),
            net,
            reason: None,
        }
    }));
}
/// Absorbs items whose non-zero amount is within the tolerance computed from
/// their `original`.
struct SoakSmall<E, FK> {
    tol: Tol,
    key: FK,
    mode: SoakMode,
    origin: String,
    _e: PhantomData<E>,
}

impl<E, K, FK> Strategy<E> for SoakSmall<E, FK>
where
    K: Hash + Eq + Clone + ToString,
    FK: Fn(&Item<E>) -> K,
{
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let mut groups = Vec::new();
        let mut residual = Vec::new();
        let mut buckets: HashMap<K, Vec<Item<E>>> = HashMap::new();
        for item in bag {
            let small =
                item.amount != 0 && item.amount.abs() <= self.tol.slack(item.original);
            if !small {
                residual.push(item);
                continue;
            }
            // The key closure is only evaluated in bucket mode.
            let k = match self.mode {
                SoakMode::Bucket => Some((self.key)(&item)),
                SoakMode::Singleton => None,
            };
            soak_emit(&mut groups, &mut buckets, self.mode, &self.origin, k, item);
        }
        soak_flush(&mut groups, buckets, &self.origin);
        Resolution { groups, residual }
    }
}
pub fn soak_small<E: 'static, K, FK>(
tol: impl Into<Tol>,
mode: SoakMode,
origin: impl Into<String>,
key: FK,
) -> Box<dyn Strategy<E>>
where
K: Hash + Eq + Clone + ToString + 'static,
FK: Fn(&Item<E>) -> K + 'static,
{
Box::new(SoakSmall {
tol: tol.into(),
key,
mode,
origin: origin.into(),
_e: PhantomData,
})
}
/// Absorbs every non-zero item accepted by `pred`.
struct SoakIf<E, FP, FK> {
    pred: FP,
    key: FK,
    mode: SoakMode,
    origin: String,
    _e: PhantomData<E>,
}

impl<E, K, FP, FK> Strategy<E> for SoakIf<E, FP, FK>
where
    K: Hash + Eq + Clone + ToString,
    FP: Fn(&Item<E>) -> bool,
    FK: Fn(&Item<E>) -> K,
{
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let mut groups = Vec::new();
        let mut residual = Vec::new();
        let mut buckets: HashMap<K, Vec<Item<E>>> = HashMap::new();
        for item in bag {
            // Zero-amount items are never offered to the predicate.
            let absorb = item.amount != 0 && (self.pred)(&item);
            if !absorb {
                residual.push(item);
                continue;
            }
            let k = match self.mode {
                SoakMode::Bucket => Some((self.key)(&item)),
                SoakMode::Singleton => None,
            };
            soak_emit(&mut groups, &mut buckets, self.mode, &self.origin, k, item);
        }
        soak_flush(&mut groups, buckets, &self.origin);
        Resolution { groups, residual }
    }
}
pub fn soak_if<E: 'static, K, FP, FK>(
pred: FP,
mode: SoakMode,
origin: impl Into<String>,
key: FK,
) -> Box<dyn Strategy<E>>
where
K: Hash + Eq + Clone + ToString + 'static,
FP: Fn(&Item<E>) -> bool + 'static,
FK: Fn(&Item<E>) -> K + 'static,
{
Box::new(SoakIf {
pred,
key,
mode,
origin: origin.into(),
_e: PhantomData,
})
}
/// Absorbs every non-zero item; nothing is left as residual.
struct SoakAll<E, FK> {
    key: FK,
    mode: SoakMode,
    origin: String,
    _e: PhantomData<E>,
}

impl<E, K, FK> Strategy<E> for SoakAll<E, FK>
where
    K: Hash + Eq + Clone + ToString,
    FK: Fn(&Item<E>) -> K,
{
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        let mut groups = Vec::new();
        let mut buckets: HashMap<K, Vec<Item<E>>> = HashMap::new();
        // Zero-amount items are discarded: they are neither grouped nor
        // returned as residual.
        for item in bag.into_iter().filter(|item| item.amount != 0) {
            let k = match self.mode {
                SoakMode::Bucket => Some((self.key)(&item)),
                SoakMode::Singleton => None,
            };
            soak_emit(&mut groups, &mut buckets, self.mode, &self.origin, k, item);
        }
        soak_flush(&mut groups, buckets, &self.origin);
        let residual = Vec::new();
        Resolution { groups, residual }
    }
}
pub fn soak_all<E: 'static, K, FK>(
mode: SoakMode,
origin: impl Into<String>,
key: FK,
) -> Box<dyn Strategy<E>>
where
K: Hash + Eq + Clone + ToString + 'static,
FK: Fn(&Item<E>) -> K + 'static,
{
Box::new(SoakAll {
key,
mode,
origin: origin.into(),
_e: PhantomData,
})
}
#[cfg(test)]
mod tests {
use super::*;
fn bag(items: &[(ExtId, i64)]) -> Vec<Item<i64>> {
items.iter().map(|&(id, a)| Item::new(id, a, a)).collect()
}
fn ids(g: &Group) -> Vec<ExtId> {
let mut m = g.member_ids();
m.sort();
m
}
fn conserves<E>(input: usize, r: &Resolution<E>) {
let g: usize = r.groups.iter().map(|g| g.members.len()).sum();
assert_eq!(g + r.residual.len(), input, "conservation violated");
}
#[test]
fn agg_net_relative_tolerance_scales_with_smallest_leg() {
let b = bag(&[(1, 10_000), (2, -9_991)]);
let mut s = agg_net(|_a: &i64| 0u64, Tol::Rel { bps: 10, floor: 0 });
let r = s.run(b);
conserves(2, &r);
assert_eq!(r.groups.len(), 1, "9 <= 10 (10bps of 10_000)");
let b = bag(&[(1, 10_000), (2, -9_991)]);
let mut s = agg_net(|_a: &i64| 0u64, Tol::Rel { bps: 5, floor: 0 });
let r = s.run(b);
conserves(2, &r);
assert_eq!(r.groups.len(), 0, "9 > 5 (5bps of 10_000)");
}
#[test]
fn agg_net_relative_floor_applies_to_tiny_buckets() {
let b = bag(&[(1, 100), (2, -98)]);
let mut s = agg_net(|_a: &i64| 0u64, Tol::Rel { bps: 10, floor: 3 });
let r = s.run(b);
conserves(2, &r);
assert_eq!(r.groups.len(), 1);
}
#[test]
fn labeled_stamps_reason_on_groups_but_not_residual() {
let b = bag(&[(1, 5), (2, -5), (3, 7)]);
let mut s = labeled("S3a exact", exact_1to1(|_| Some(0)));
let r = s.run(b);
conserves(3, &r);
assert_eq!(r.groups.len(), 1);
assert_eq!(
r.groups[0].reason.as_deref(),
Some("S3a exact: exact 1:1 pair")
);
assert_eq!(r.residual.len(), 1);
assert_eq!(r.residual[0].id, 3);
}
#[test]
fn labeled_prepends_to_inner_detail() {
let b = bag(&[(1, 5), (2, -5)]);
let mut s = labeled("outer", labeled("inner", exact_1to1(|_| Some(0))));
let r = s.run(b);
assert_eq!(
r.groups[0].reason.as_deref(),
Some("outer: inner: exact 1:1 pair")
);
}
#[test]
fn exact_pairs_and_leaves_residual() {
let b = bag(&[(1, 5), (2, -5), (3, 5), (4, 3)]);
let mut s = exact_1to1(|_| Some(0));
let r = s.run(b);
conserves(4, &r);
assert_eq!(r.groups.len(), 1);
assert!(r.groups[0].member_ids().contains(&2));
assert_eq!(r.residual.len(), 2);
}
#[test]
fn agg_accepts_netting_bucket() {
let b = bag(&[(1, 100), (2, -60), (3, -40), (4, 7)]);
let mut s = agg_net(|_a: &i64| 0u64, 0);
let r = s.run(b);
conserves(4, &r);
assert_eq!(r.groups.len(), 0);
let b = bag(&[(1, 100), (2, -60), (3, -40), (4, 7)]);
let mut s = agg_net(|_a: &i64| 0u64, 10);
let r = s.run(b);
conserves(4, &r);
assert_eq!(r.groups.len(), 1);
assert_eq!(r.groups[0].members.len(), 4);
}
#[test]
fn signal_groups_net_and_cascade() {
let b = bag(&[(1, 50), (2, -50), (3, 9)]);
let mut s = signal_group(
|a: &i64| if *a == 9 { vec![] } else { vec![10] },
Tol::Abs(0),
16,
);
let r = s.run(b);
conserves(3, &r);
assert_eq!(r.groups.len(), 1);
assert_eq!(ids(&r.groups[0]), vec![1, 2]);
assert_eq!(r.residual.len(), 1);
}
#[test]
fn signal_groups_accept_relative_tol() {
let b = bag(&[(1, 1000), (2, -995)]);
let mut tight = signal_group(|_: &i64| vec![7u64], Tol::Rel { bps: 10, floor: 0 }, 16);
let r = tight.run(b.clone());
assert_eq!(r.groups.len(), 0);
assert_eq!(r.residual.len(), 2);
let mut loose = signal_group(|_: &i64| vec![7u64], Tol::Rel { bps: 60, floor: 0 }, 16);
let r = loose.run(b);
conserves(2, &r);
assert_eq!(r.groups.len(), 1);
assert_eq!(ids(&r.groups[0]), vec![1, 2]);
}
struct OnePair;
impl Strategy<i64> for OnePair {
fn run(&mut self, bag: Vec<Item<i64>>) -> Resolution<i64> {
for i in 0..bag.len() {
for j in (i + 1)..bag.len() {
if bag[i].amount == -bag[j].amount && bag[i].amount != 0 {
let mut residual = Vec::new();
let mut members = Vec::new();
for (k, item) in bag.into_iter().enumerate() {
if k == i || k == j {
members.push(Allocation {
id: item.id,
amount: item.amount,
});
} else {
residual.push(item);
}
}
let g = Group {
members,
origin: "onepair".into(),
net: 0,
reason: None,
};
return Resolution {
groups: vec![g],
residual,
};
}
}
}
Resolution {
groups: vec![],
residual: bag,
}
}
}
#[test]
fn fixed_point_drives_a_non_maximal_leaf_to_completion() {
let mut once = OnePair;
let r = once.run(bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]));
assert_eq!(r.groups.len(), 1);
assert_eq!(r.residual.len(), 2);
let mut fp = fixed_point(Box::new(OnePair), 16);
let r = fp.run(bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]));
conserves(4, &r);
assert_eq!(r.groups.len(), 2, "both pairs found across passes");
assert_eq!(r.residual.len(), 0);
}
#[test]
fn fixed_point_leaves_unmatchable_residual_and_terminates() {
let mut fp = fixed_point(Box::new(OnePair), 16);
let r = fp.run(bag(&[(1, 5), (2, -5), (3, 7), (4, 3)]));
conserves(4, &r);
assert_eq!(r.groups.len(), 1);
let mut left: Vec<ExtId> = r.residual.iter().map(|i| i.id).collect();
left.sort();
assert_eq!(left, vec![3, 4]);
}
#[test]
fn fixed_point_respects_the_pass_cap() {
let mut fp = fixed_point(Box::new(OnePair), 1);
let r = fp.run(bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]));
conserves(4, &r);
assert_eq!(r.groups.len(), 1, "cap of 1 means one pass");
assert_eq!(r.residual.len(), 2);
}
#[test]
fn when_cascade_routes_to_different_children_and_conserves() {
let b = bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]);
let mut s = seq(vec![
when(|a: &i64| a.unsigned_abs() == 5, agg_net(|_a: &i64| 1u64, 0)),
agg_net(|_a: &i64| 2u64, 0),
]);
let r = s.run(b);
conserves(4, &r);
assert_eq!(r.groups.len(), 2);
assert_eq!(r.residual.len(), 0);
}
/// Partitions on "magnitude is 5" (key 1 or 0) and builds a different child
/// per key: `agg_net` for key 1, `identity` otherwise. Asserts one group and
/// that ids 3 and 4 stay in the residual.
#[test]
fn partition_by_with_picks_a_per_key_subtree() {
    let input = bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]);
    let mut routed = partition_by_with(
        |a: &i64| u8::from(a.unsigned_abs() == 5),
        |k: &u8| match *k {
            1 => agg_net(|_a: &i64| 0u64, 0),
            _ => identity(),
        },
    );
    let outcome = routed.run(input);
    conserves(4, &outcome);
    assert_eq!(outcome.groups.len(), 1);

    let mut untouched: Vec<ExtId> = outcome.residual.iter().map(|item| item.id).collect();
    untouched.sort_unstable();
    assert_eq!(untouched, vec![3, 4]);
}
/// Two opposite amounts whose key values are 1 and 100, run through
/// `windowed` with a window argument of 3: asserts that no group forms.
#[test]
fn windowed_blocks_far_matches() {
    let early = Item::new(1, 5, (1i64, 5i64));
    let late = Item::new(2, -5, (100i64, -5i64));
    let matcher = exact_1to1(|_| Some(0));
    let outcome = {
        let mut banded = windowed(|d: &(i64, i64)| d.0, 3, matcher);
        banded.run(vec![early, late])
    };
    assert_eq!(outcome.groups.len(), 0);
    assert_eq!(outcome.residual.len(), 2);
}
/// Two opposite amounts whose key values are 4 and 7, run through `windowed`
/// with a window argument of 3: asserts that they are paired.
#[test]
fn windowed_finds_near_match_across_band_boundary() {
    let left = Item::new(1, 5, (4i64, 5i64));
    let right = Item::new(2, -5, (7i64, -5i64));
    let matcher = exact_1to1(|_| Some(0));
    let outcome = {
        let mut banded = windowed(|d: &(i64, i64)| d.0, 3, matcher);
        banded.run(vec![left, right])
    };
    assert_eq!(outcome.groups.len(), 1);
    assert_eq!(outcome.residual.len(), 0);
}
/// Feeds five rows keyed 1..=5 whose running sum returns to zero after row 2
/// and again after row 5; asserts the two groups `[1, 2]` and `[3, 4, 5]`.
#[test]
fn running_zero_segments_at_balance_clears() {
    // Each row carries `(key, amount)` as its data, mirroring the item amount.
    let row = |id: ExtId, key: i64, amount: i64| Item::new(id, amount, (key, amount));
    let input = vec![
        row(1, 1, 100),
        row(2, 2, -100),
        row(3, 3, 50),
        row(4, 4, -30),
        row(5, 5, -20),
    ];
    let mut segmenter = running_zero(|d: &(i64, i64)| d.0, 0);
    let outcome = segmenter.run(input);
    conserves(5, &outcome);
    assert_eq!(outcome.groups.len(), 2);
    assert_eq!(outcome.groups[0].member_ids(), vec![1, 2]);
    assert_eq!(outcome.groups[1].member_ids(), vec![3, 4, 5]);
}
/// The first two rows cancel; the trailing +7 never brings the sum back to
/// zero. Asserts one group and that id 3 is the only residual item.
#[test]
fn running_zero_leaves_uncleared_tail() {
    let row = |id: ExtId, key: i64, amount: i64| Item::new(id, amount, (key, amount));
    let input = vec![row(1, 1, 100), row(2, 2, -100), row(3, 3, 7)];
    let mut segmenter = running_zero(|d: &(i64, i64)| d.0, 0);
    let outcome = segmenter.run(input);
    assert_eq!(outcome.groups.len(), 1);
    assert_eq!(outcome.residual.len(), 1);
    assert_eq!(outcome.residual[0].id, 3);
}
/// Nests a one-element `seq` of `exact_1to1` inside `partition_by`; the key
/// is `|signum|`, which is 1 for every nonzero amount used here. Asserts two
/// groups come out of the composition.
#[test]
fn seq_then_partition_compose() {
    let mut composed = partition_by(
        |a: &i64| a.signum().unsigned_abs(),
        || seq(vec![exact_1to1(|_| Some(0))]),
    );
    let outcome = composed.run(bag(&[(1, 4), (2, -4), (3, 4), (4, -4)]));
    conserves(4, &outcome);
    assert_eq!(outcome.groups.len(), 2);
}
/// Wraps `exact_1to1` in `accept_if` with a predicate that only admits groups
/// whose legs all have magnitude 5. Asserts the 5/-5 pair survives and the
/// 7/-7 items end up in the residual.
#[test]
fn accept_if_rejects_groups_back_to_residual() {
    let input = bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]);
    let mut gated = accept_if(
        |g: &Group| g.members.iter().all(|leg| leg.amount.abs() == 5),
        exact_1to1(|_| Some(0)),
    );
    let outcome = gated.run(input);
    conserves(4, &outcome);
    assert_eq!(outcome.groups.len(), 1);
    assert_eq!(ids(&outcome.groups[0]), vec![1, 2]);

    let mut rejected_ids: Vec<ExtId> = outcome.residual.iter().map(|item| item.id).collect();
    rejected_ids.sort();
    assert_eq!(rejected_ids, vec![3, 4]);
    // Every item handed back must still carry its magnitude of 7.
    for item in &outcome.residual {
        assert_eq!(item.amount.abs(), 7);
    }
}
/// Drives `whole_net` with a stub child that matches 97 of id 1 against id 2
/// and returns a 3-unit tail of id 1. With `Tol::Abs(5)` the test expects one
/// group with net 3 and no residual; with `Tol::Abs(2)` it expects no groups
/// and both items back at their full amounts.
#[test]
fn whole_net_reclaims_tail_and_keeps_break_within_tol() {
    struct Partial;
    impl Strategy<i64> for Partial {
        fn run(&mut self, bag: Vec<Item<i64>>) -> Resolution<i64> {
            let by_id: HashMap<ExtId, Item<i64>> =
                bag.into_iter().map(|item| (item.id, item)).collect();
            let matched = Group {
                members: vec![
                    Allocation { id: 1, amount: 97 },
                    Allocation { id: 2, amount: -97 },
                ],
                origin: "flow".into(),
                net: 0,
                reason: None,
            };
            // The tail is a copy of item 1 with only its amount replaced.
            let tail = Item {
                amount: 3,
                ..by_id[&1].clone()
            };
            Resolution {
                groups: vec![matched],
                residual: vec![tail],
            }
        }
    }

    // Tolerance 5: the 3-unit break is within tolerance.
    let mut loose = whole_net(Tol::Abs(5), Box::new(Partial));
    let kept = loose.run(bag(&[(1, 100), (2, -97)]));
    conserves(2, &kept);
    assert_eq!(kept.groups.len(), 1);
    assert_eq!(kept.groups[0].net, 3);
    assert_eq!(ids(&kept.groups[0]), vec![1, 2]);
    assert!(kept.residual.is_empty());

    // Tolerance 2: the same break is too large, so nothing is grouped.
    let mut strict = whole_net(Tol::Abs(2), Box::new(Partial));
    let dissolved = strict.run(bag(&[(1, 100), (2, -97)]));
    conserves(2, &dissolved);
    assert!(dissolved.groups.is_empty());
    let leftover: Vec<(ExtId, i64)> = dissolved
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    assert_eq!(leftover, vec![(1, 100), (2, -97)]);
}
/// Drives `whole_net` with a stub child that splits id 1 across two groups
/// (60 against id 2, 40 against id 3). Asserts a single "settlement" group
/// with net 0 whose members are the three full-amount lines.
#[test]
fn whole_net_collapses_groups_sharing_a_line() {
    struct Split;
    impl Strategy<i64> for Split {
        fn run(&mut self, _bag: Vec<Item<i64>>) -> Resolution<i64> {
            let leg = |id: ExtId, amount: i64| Allocation { id, amount };
            let first = Group {
                members: vec![leg(1, 60), leg(2, -60)],
                origin: "a".into(),
                net: 0,
                reason: None,
            };
            let second = Group {
                members: vec![leg(1, 40), leg(3, -40)],
                origin: "b".into(),
                net: 0,
                reason: None,
            };
            Resolution {
                groups: vec![first, second],
                residual: Vec::new(),
            }
        }
    }

    let mut netted = whole_net(Tol::Abs(0), Box::new(Split));
    let outcome = netted.run(bag(&[(1, 100), (2, -60), (3, -40)]));
    conserves(3, &outcome);
    assert_eq!(outcome.groups.len(), 1);

    let merged = &outcome.groups[0];
    assert_eq!(merged.origin, "settlement");
    assert_eq!(merged.net, 0);
    let mut legs: Vec<(ExtId, i64)> = merged
        .members
        .iter()
        .map(|a| (a.id, a.amount))
        .collect();
    legs.sort();
    assert_eq!(legs, vec![(1, 100), (2, -60), (3, -40)]);
    assert!(outcome.residual.is_empty());
}
/// `agg_net` keys magnitude-8 amounts to bucket 1 and everything else to
/// bucket 0; `accept_if` then admits only groups of at most three members.
/// Asserts the 8/-8 pair is kept and the five other items are left over.
#[test]
fn accept_if_size_cap_and_minority_side() {
    let input = bag(&[
        (1, 40),
        (2, -10),
        (3, -10),
        (4, -10),
        (5, -10),
        (6, 8),
        (7, -8),
    ]);
    let mut capped = accept_if(
        |g: &Group| g.size() <= 3,
        agg_net(|a: &i64| u64::from(a.unsigned_abs() == 8), 0),
    );
    let outcome = capped.run(input);
    conserves(7, &outcome);
    assert_eq!(outcome.groups.len(), 1, "only the small pair is accepted");
    assert_eq!(ids(&outcome.groups[0]), vec![6, 7]);
    assert_eq!(outcome.residual.len(), 5, "the over-large group is dissolved");
}
/// Test strategy that ignores matching entirely and emits a scripted set of
/// groups: one `Group` per inner list of `(id, amount)` legs.
struct EmitGroups(Vec<Vec<(ExtId, i64)>>);
impl Strategy<i64> for EmitGroups {
    /// Emits the scripted groups (origin `"emit"`, net = sum of leg amounts)
    /// and returns as residual every input item whose id no scripted leg names.
    fn run(&mut self, bag: Vec<Item<i64>>) -> Resolution<i64> {
        let mut taken: BTreeSet<ExtId> = BTreeSet::new();
        let mut groups = Vec::with_capacity(self.0.len());
        for legs in &self.0 {
            let mut members = Vec::with_capacity(legs.len());
            let mut net = 0i64;
            for &(id, amount) in legs {
                taken.insert(id);
                net += amount;
                members.push(Allocation { id, amount });
            }
            groups.push(Group {
                members,
                origin: "emit".into(),
                net,
                reason: None,
            });
        }
        let residual = bag
            .into_iter()
            .filter(|item| !taken.contains(&item.id))
            .collect();
        Resolution { groups, residual }
    }
}
/// Two scripted groups both reference id 2 (-60 and -40). After `coalesce`
/// the test expects a single "settlement" group over ids 1..=4 in which id 2
/// carries -100, with id 9 left in the residual.
#[test]
fn coalesce_merges_groups_that_share_a_row() {
    let scripted = EmitGroups(vec![
        vec![(1, 100), (2, -60)],
        vec![(2, -40), (3, 100), (4, -100)],
    ]);
    let input = bag(&[(1, 100), (2, -100), (3, 100), (4, -100), (9, 7)]);
    let mut merged = coalesce("settlement", Box::new(scripted));
    let outcome = merged.run(input);
    conserves(5, &outcome);
    assert_eq!(outcome.groups.len(), 1, "the two interlocking groups merge");

    let only = &outcome.groups[0];
    assert_eq!(only.origin, "settlement");
    assert_eq!(ids(only), vec![1, 2, 3, 4]);
    // The shared row's two partial legs must be summed into one member.
    let shared = only.members.iter().find(|a| a.id == 2).unwrap();
    assert_eq!(shared.amount, -100);
    assert_eq!(only.net, 0);

    assert_eq!(outcome.residual.len(), 1);
    assert_eq!(outcome.residual[0].id, 9);
}
/// Two scripted groups with no id in common: `coalesce` is expected to keep
/// them as two groups, each with origin "settlement".
#[test]
fn coalesce_keeps_disjoint_groups_separate() {
    let scripted = EmitGroups(vec![vec![(1, 5), (2, -5)], vec![(3, 7), (4, -7)]]);
    let input = bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]);
    let mut merged = coalesce("settlement", Box::new(scripted));
    let outcome = merged.run(input);
    conserves(4, &outcome);
    assert_eq!(outcome.groups.len(), 2, "disjoint components stay separate");
    assert!(outcome.groups.iter().all(|g| g.origin == "settlement"));
}
/// Builds a `FlowSpec` over dated rows and runs `settle` on +100, +200 and
/// -250. Asserts a single "flow" group with net 0 over ids 1, 2, 3 in which
/// id 3 appears once at -250, and a residual totalling 50.
#[test]
fn settle_folds_flow_arcs_into_one_settlement() {
    #[derive(Clone)]
    struct Tx {
        date: i64,
    }
    let spec = FlowSpec::<Tx>::new()
        .penalty(1_000_000.0)
        .window(3)
        .block_key(|t: &Tx| t.date)
        .cost_lot(|lhs: &Tx, lhs_amt, rhs: &Tx, rhs_amt| {
            Some(1.0 + (lhs_amt + rhs_amt).abs() as f64 * 0.1 + (lhs.date - rhs.date).abs() as f64)
        });
    let mut settler = settle(spec);
    let outcome = settler.run(vec![
        Item::new(1, 100, Tx { date: 0 }),
        Item::new(2, 200, Tx { date: 1 }),
        Item::new(3, -250, Tx { date: 0 }),
    ]);
    assert_eq!(outcome.groups.len(), 1, "arcs fold into one settlement");

    let settlement = &outcome.groups[0];
    assert_eq!(settlement.origin, "flow");
    assert_eq!(settlement.net, 0);
    assert_eq!(ids(settlement), vec![1, 2, 3]);
    let payer = settlement.members.iter().find(|a| a.id == 3).unwrap();
    assert_eq!(payer.amount, -250, "id 3's two arcs sum to one clean edge");

    assert_eq!(outcome.residual.iter().map(|item| item.amount).sum::<i64>(), 50);
}
/// Items carry outer amounts of +/-110 and data values of +/-100. After
/// `pivot` over the data value, the reported members are asserted to be in
/// the outer amounts (110 / -110) with net 0.
#[test]
fn pivot_converts_back_to_outer_amount() {
    let input = vec![Item::new(1, 110, (100i64,)), Item::new(2, -110, (-100i64,))];
    let mut pivoted = pivot(|d: &(i64,)| d.0, exact_1to1(|_| Some(0)));
    let outcome = pivoted.run(input);
    assert_eq!(outcome.groups.len(), 1);
    assert_eq!(outcome.groups[0].net, 0);

    let expected = vec![
        Allocation { id: 1, amount: 110 },
        Allocation {
            id: 2,
            amount: -110,
        },
    ];
    assert_eq!(outcome.groups[0].members, expected);
}
/// Test strategy that claims half of every item and hands the rest back.
struct HalfMatch;
impl Strategy<(i64,)> for HalfMatch {
    /// Puts `amount / 2` (integer division, truncating toward zero) of each
    /// item into a single group with origin `"half"`, and returns each item
    /// in the residual with the unclaimed remainder as its amount. The
    /// group's net is the sum of the claimed halves.
    fn run(&mut self, bag: Vec<Item<(i64,)>>) -> Resolution<(i64,)> {
        let mut members = Vec::with_capacity(bag.len());
        let mut residual = Vec::with_capacity(bag.len());
        // Each item is owned here, so it is adjusted in place and moved into
        // the residual rather than cloned first.
        for mut it in bag {
            let half = it.amount / 2;
            members.push(Allocation {
                id: it.id,
                amount: half,
            });
            it.amount -= half;
            residual.push(it);
        }
        let net = members.iter().map(|a| a.amount).sum();
        let groups = vec![Group {
            members,
            origin: "half".into(),
            net,
            reason: None,
        }];
        Resolution { groups, residual }
    }
}
/// Runs `HalfMatch` under `pivot` on an item with outer amount 1 and one with
/// outer amount 100 (both with data value 4). Asserts that only id 2 remains
/// a group member (at 50) and the residual is `(1, 1)` and `(2, 50)`.
#[test]
fn pivot_dissolves_rows_that_round_to_zero_parent() {
    let input = vec![Item::new(1, 1, (4i64,)), Item::new(2, 100, (4i64,))];
    let mut pivoted = pivot(|d: &(i64,)| d.0, Box::new(HalfMatch));
    let outcome = pivoted.run(input);
    assert_eq!(outcome.groups.len(), 1);
    assert_eq!(
        outcome.groups[0].members,
        vec![Allocation { id: 2, amount: 50 }]
    );

    let mut leftover: Vec<(ExtId, i64)> = outcome
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    leftover.sort();
    assert_eq!(leftover, vec![(1, 1), (2, 50)]);
}
/// Item 1 has outer amount 5 but a data value of 0. Asserts that `pivot`
/// completes, keeps only id 2 in the group (at 50), and returns item 1 whole
/// alongside the remaining 50 of item 2.
#[test]
fn pivot_zero_target_is_safe() {
    let input = vec![Item::new(1, 5, (0i64,)), Item::new(2, 100, (4i64,))];
    let mut pivoted = pivot(|d: &(i64,)| d.0, Box::new(HalfMatch));
    let outcome = pivoted.run(input);
    assert_eq!(outcome.groups.len(), 1);
    assert_eq!(
        outcome.groups[0].members,
        vec![Allocation { id: 2, amount: 50 }]
    );

    let mut leftover: Vec<(ExtId, i64)> = outcome
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    leftover.sort();
    assert_eq!(leftover, vec![(1, 5), (2, 50)]);
}
/// Test strategy that forms no groups and discards zero-amount items.
struct DropZeros;
impl Strategy<(i64,)> for DropZeros {
    /// Returns an empty group list and, as residual, the input items whose
    /// `amount` is nonzero, in their original order.
    fn run(&mut self, mut bag: Vec<Item<(i64,)>>) -> Resolution<(i64,)> {
        bag.retain(|item| item.amount != 0);
        Resolution {
            groups: Vec::new(),
            residual: bag,
        }
    }
}
/// Item 1 has original 100 but a remaining amount of only 1 (data value 3);
/// the child `DropZeros` discards zero-amount items. Asserts that after
/// `pivot` both rows are still present in the residual at `(1, 1)` and
/// `(2, 50)`, with no groups.
#[test]
fn pivot_reemits_forward_floored_rows() {
    // `Item::new` sets `original == amount`; row 1 then overrides `amount`.
    let nearly_spent = Item {
        amount: 1,
        ..Item::new(1, 100, (3i64,))
    };
    let untouched = Item::new(2, 50, (50i64,));
    let mut pivoted = pivot(|d: &(i64,)| d.0, Box::new(DropZeros));
    let outcome = pivoted.run(vec![nearly_spent, untouched]);
    assert!(outcome.groups.is_empty());

    let mut leftover: Vec<(ExtId, i64)> = outcome
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    leftover.sort();
    assert_eq!(leftover, vec![(1, 1), (2, 50)]);
}
/// `soak_small` with threshold argument 5 in `Singleton` mode: asserts that
/// the items of amount 3 and -2 each become a one-member "rounding" group
/// and the item of amount 100 (id 3) stays in the residual.
#[test]
fn soak_small_abs_threshold_singletons() {
    let mut soaker = soak_small(5, SoakMode::Singleton, "rounding", |_: &Item<i64>| 0u64);
    let outcome = soaker.run(bag(&[(1, 3), (2, -2), (3, 100)]));
    conserves(3, &outcome);
    assert_eq!(outcome.groups.len(), 2);
    assert!(
        outcome
            .groups
            .iter()
            .all(|g| g.size() == 1 && g.origin == "rounding")
    );
    assert_eq!(outcome.residual.len(), 1);
    assert_eq!(outcome.residual[0].id, 3);
}
/// Both items have original 1000; their remaining amounts are 10 and 50.
/// With `Tol::Rel { bps: 200, floor: 0 }` the test expects id 1 to be soaked
/// and id 2 to remain in the residual.
#[test]
fn soak_small_bps_against_original() {
    // `Item::new` sets `original == amount`; only `amount` is overridden here.
    let small_tail = Item {
        amount: 10,
        ..Item::new(1, 1000, 0i64)
    };
    let large_tail = Item {
        amount: 50,
        ..Item::new(2, 1000, 0i64)
    };
    let mut soaker = soak_small(
        Tol::Rel { bps: 200, floor: 0 },
        SoakMode::Singleton,
        "var",
        |_: &Item<i64>| 0u64,
    );
    let outcome = soaker.run(vec![small_tail, large_tail]);
    conserves(2, &outcome);
    assert_eq!(ids(&outcome.groups[0]), vec![1]);
    assert_eq!(
        outcome.residual.iter().map(|item| item.id).collect::<Vec<_>>(),
        vec![2]
    );
}
/// `soak_small` in `Bucket` mode with a sign-based key ("pos" / "neg"):
/// asserts two groups whose origins start with "tail:" and that only the
/// item of amount 100 (id 4) stays in the residual.
#[test]
fn soak_small_bucket_groups_by_key() {
    let by_sign = |item: &Item<i64>| if item.amount > 0 { "pos" } else { "neg" };
    let mut soaker = soak_small(5, SoakMode::Bucket, "tail", by_sign);
    let outcome = soaker.run(bag(&[(1, 3), (2, 4), (3, -2), (4, 100)]));
    conserves(4, &outcome);
    assert_eq!(outcome.groups.len(), 2);
    assert!(outcome.groups.iter().all(|g| g.origin.starts_with("tail:")));
    assert_eq!(
        outcome.residual.iter().map(|item| item.id).collect::<Vec<_>>(),
        vec![4]
    );
}
/// `soak_if` with the predicate "amount is negative": asserts ids 2 and 3
/// are soaked into groups and id 1 remains in the residual.
#[test]
fn soak_if_predicate_selects() {
    let mut soaker = soak_if(
        |item: &Item<i64>| item.amount < 0,
        SoakMode::Singleton,
        "shorts",
        |_: &Item<i64>| 0u64,
    );
    let outcome = soaker.run(bag(&[(1, 50), (2, -30), (3, -10)]));
    conserves(3, &outcome);

    let mut soaked_ids: Vec<ExtId> = outcome.groups.iter().flat_map(ids).collect();
    soaked_ids.sort();
    assert_eq!(soaked_ids, vec![2, 3]);
    assert_eq!(
        outcome.residual.iter().map(|item| item.id).collect::<Vec<_>>(),
        vec![1]
    );
}
/// `soak_all` in `Singleton` mode: asserts an empty residual, that only ids
/// 1 and 2 appear in groups (the zero-amount id 3 appears in neither output),
/// and that every emitted group has a nonzero net.
#[test]
fn soak_all_terminates_residual() {
    let mut soaker = soak_all(SoakMode::Singleton, "unmatched", |_: &Item<i64>| 0u64);
    let outcome = soaker.run(bag(&[(1, 50), (2, -30), (3, 0)]));
    assert!(outcome.residual.is_empty());

    let mut soaked_ids: Vec<ExtId> = outcome.groups.iter().flat_map(ids).collect();
    soaked_ids.sort();
    assert_eq!(soaked_ids, vec![1, 2]);
    assert!(outcome.groups.iter().all(|g| g.net != 0));
}
/// `soak_all` in `Bucket` mode with a sign-based key: asserts an empty
/// residual and group nets of 80 (50 + 30) and -20.
#[test]
fn soak_all_bucket_nets_per_key() {
    let by_sign = |item: &Item<i64>| if item.amount > 0 { 1u64 } else { 2u64 };
    let mut soaker = soak_all(SoakMode::Bucket, "class", by_sign);
    let outcome = soaker.run(bag(&[(1, 50), (2, 30), (3, -20)]));
    assert!(outcome.residual.is_empty());

    let mut bucket_nets: Vec<i64> = outcome.groups.iter().map(|g| g.net).collect();
    bucket_nets.sort();
    assert_eq!(bucket_nets, vec![-20, 80]);
}
/// Builds a `Group` fixture from `(id, amount)` pairs: one `Allocation` per
/// pair in the given order, origin `"test"`, no reason, and `net` equal to
/// the sum of the amounts.
fn group(members: &[(ExtId, i64)]) -> Group {
    let mut legs = Vec::with_capacity(members.len());
    let mut net = 0i64;
    for &(id, amount) in members {
        net += amount;
        legs.push(Allocation { id, amount });
    }
    Group {
        members: legs,
        origin: "test".into(),
        net,
        reason: None,
    }
}
/// Checks the `Group` accessors on a three-leg fixture with one positive and
/// two negative legs (net -200).
#[test]
fn group_metrics() {
    let fixture = group(&[(1, 1_000_000), (2, -999_000), (3, -1_200)]);
    assert_eq!(fixture.size(), 3);
    assert_eq!(fixture.abs_net(), 200);
    assert_eq!(fixture.max_abs(), 1_000_000);
    assert_eq!(fixture.min_abs(), 1_200);
    // One positive leg against two negative ones: the smaller side has 1.
    assert_eq!(fixture.min_side(), 1);
}
/// The fixture nets to -200. `Tol::Rel` scales by the smallest nonzero leg
/// (1_200, so 5 bps rounds to 0 and the floor of 100 applies), while
/// `Tol::RelMax` scales by the largest leg (1_000_000, so 5 bps is 500).
/// Absolute tolerances of 200 and 199 bracket the break exactly.
#[test]
fn clean_rel_vs_relmax_pick_different_scale_legs() {
    let fixture = group(&[(1, 1_000_000), (2, -999_000), (3, -1_200)]);
    let by_smallest = Tol::Rel { bps: 5, floor: 100 };
    let by_largest = Tol::RelMax { bps: 5, floor: 100 };
    assert!(!fixture.clean(by_smallest), "200 > 100 -> dirty");
    assert!(fixture.clean(by_largest), "200 <= 500 -> clean");
    assert!(fixture.clean(Tol::Abs(200)));
    assert!(!fixture.clean(Tol::Abs(199)));
}
/// Runs `agg_net` over the same three items with two tolerances: under
/// `Tol::Rel` no group is expected, under `Tol::RelMax` one group with
/// net -200 is expected.
#[test]
fn agg_net_relmax_accepts_what_rel_rejects() {
    let input = bag(&[(1, 1_000_000), (2, -999_000), (3, -1_200)]);

    let mut strict = agg_net(|_: &i64| 0u64, Tol::Rel { bps: 5, floor: 100 });
    assert_eq!(strict.run(input.clone()).groups.len(), 0);

    let mut lenient = agg_net(|_: &i64| 0u64, Tol::RelMax { bps: 5, floor: 100 });
    let outcome = lenient.run(input);
    assert_eq!(outcome.groups.len(), 1);
    assert_eq!(outcome.groups[0].net, -200);
}
}