use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
pub mod flow;
pub use flow::{Allocation, ExtId, FlowSpec, flow};
use std::hash::Hash;
use std::marker::PhantomData;
/// A signed quantity that flows through the strategy pipeline.
///
/// `original` is the amount the item was created with; `amount` is the
/// working value that strategies read and adjust.
#[derive(Clone)]
pub struct Item<E> {
    /// External identifier of the item.
    pub id: ExtId,
    /// Amount recorded at construction time.
    pub original: i64,
    /// Working amount.
    pub amount: i64,
    /// Caller-defined payload carried along with the item.
    pub data: E,
}

impl<E> Item<E> {
    /// Creates an item whose `original` and `amount` are both `amount`.
    pub fn new(id: ExtId, amount: i64, data: E) -> Self {
        let original = amount;
        Self {
            id,
            original,
            amount,
            data,
        }
    }
}
/// A set of allocations emitted together by a strategy.
#[derive(Debug, Clone)]
pub struct Group {
    /// The allocations that make up the group.
    pub members: Vec<Allocation>,
    /// Name of the strategy that produced the group.
    pub origin: String,
    /// Net amount recorded for the group by its producer.
    pub net: i64,
    /// Optional human-readable explanation.
    pub reason: Option<String>,
}

impl Group {
    /// Ids of all members, in member order.
    pub fn member_ids(&self) -> Vec<ExtId> {
        let mut ids = Vec::with_capacity(self.members.len());
        for alloc in &self.members {
            ids.push(alloc.id);
        }
        ids
    }

    /// Number of members.
    pub fn size(&self) -> usize {
        self.members.len()
    }

    /// Absolute value of the recorded net.
    pub fn abs_net(&self) -> i64 {
        self.net.abs()
    }

    /// Largest absolute member amount, or 0 for an empty group.
    pub fn max_abs(&self) -> i64 {
        let mut largest: Option<i64> = None;
        for alloc in &self.members {
            let mag = alloc.amount.abs();
            if largest.map_or(true, |seen| mag > seen) {
                largest = Some(mag);
            }
        }
        largest.unwrap_or(0)
    }

    /// Smallest strictly positive absolute member amount, or 0 if none.
    pub fn min_abs(&self) -> i64 {
        let mut smallest: Option<i64> = None;
        for alloc in &self.members {
            let mag = alloc.amount.abs();
            if mag > 0 && smallest.map_or(true, |seen| mag < seen) {
                smallest = Some(mag);
            }
        }
        smallest.unwrap_or(0)
    }

    /// Member count of the less populated sign (positive vs. negative).
    pub fn min_side(&self) -> usize {
        let (mut pos, mut neg) = (0usize, 0usize);
        for alloc in &self.members {
            match alloc.amount.signum() {
                1 => pos += 1,
                -1 => neg += 1,
                _ => {}
            }
        }
        pos.min(neg)
    }
}
/// Read-only view of a candidate group handed to acceptance predicates.
pub struct GroupView<'a, E> {
    members: Vec<MemberView<'a, E>>,
}

/// One member of a [`GroupView`].
pub struct MemberView<'a, E> {
    /// Id of the underlying item.
    pub id: ExtId,
    /// Amount this member contributes to the group.
    pub amount: i64,
    /// `original` amount of the underlying item.
    pub original: i64,
    /// Borrowed payload of the underlying item.
    pub data: &'a E,
}

impl<'a, E> GroupView<'a, E> {
    /// Builds a view in which every item contributes its current `amount`.
    fn from_refs(items: impl IntoIterator<Item = &'a Item<E>>) -> Self {
        let mut members = Vec::new();
        for item in items {
            members.push(MemberView {
                id: item.id,
                amount: item.amount,
                original: item.original,
                data: &item.data,
            });
        }
        Self { members }
    }

    /// Builds a view from a group's allocations, looking each id up in `src`.
    ///
    /// Allocations whose id is missing from `src` are left out of the view.
    fn from_group(g: &Group, src: &'a HashMap<ExtId, Item<E>>) -> Self {
        let mut members = Vec::with_capacity(g.members.len());
        for alloc in &g.members {
            let Some(item) = src.get(&alloc.id) else {
                continue;
            };
            members.push(MemberView {
                id: alloc.id,
                amount: alloc.amount,
                original: item.original,
                data: &item.data,
            });
        }
        Self { members }
    }

    /// Signed sum of member amounts.
    pub fn net(&self) -> i64 {
        let mut total = 0i64;
        for m in &self.members {
            total += m.amount;
        }
        total
    }

    /// Sum of absolute member amounts.
    pub fn gross(&self) -> i64 {
        let mut total = 0i64;
        for m in &self.members {
            total += m.amount.abs();
        }
        total
    }

    /// Largest absolute member amount, or 0 for an empty view.
    pub fn max_leg(&self) -> i64 {
        match self.members.iter().map(|m| m.amount.abs()).max() {
            Some(leg) => leg,
            None => 0,
        }
    }

    /// Smallest strictly positive absolute member amount, or 0 if none.
    pub fn min_leg(&self) -> i64 {
        let mut smallest: Option<i64> = None;
        for m in &self.members {
            let mag = m.amount.abs();
            if mag > 0 && smallest.map_or(true, |seen| mag < seen) {
                smallest = Some(mag);
            }
        }
        smallest.unwrap_or(0)
    }

    /// Sum of absolute `original` amounts, counting each id only once
    /// (the first occurrence wins).
    pub fn original_total(&self) -> i64 {
        let mut counted: HashSet<ExtId> = HashSet::new();
        let mut total = 0i64;
        for m in &self.members {
            if counted.insert(m.id) {
                total += m.original.abs();
            }
        }
        total
    }

    /// Number of members in the view.
    pub fn size(&self) -> usize {
        self.members.len()
    }

    /// Member count of the less populated sign (positive vs. negative).
    pub fn min_side(&self) -> usize {
        let (mut pos, mut neg) = (0usize, 0usize);
        for m in &self.members {
            match m.amount.signum() {
                1 => pos += 1,
                -1 => neg += 1,
                _ => {}
            }
        }
        pos.min(neg)
    }

    /// Iterates over the members in view order.
    pub fn members(&self) -> impl Iterator<Item = &MemberView<'a, E>> {
        self.members.iter()
    }
}
/// Outcome of running a strategy over a bag of items.
pub struct Resolution<E> {
    /// Groups the strategy formed.
    pub groups: Vec<Group>,
    /// Items the strategy handed back without grouping.
    pub residual: Vec<Item<E>>,
}

/// Something that consumes a bag of items and splits it into groups and a
/// residual.
pub trait Strategy<E> {
    /// Processes `bag` and returns the groups formed plus what is left over.
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E>;
}

/// Lets a boxed strategy be used wherever a `Strategy` is expected by
/// forwarding to the boxed value.
impl<E> Strategy<E> for Box<dyn Strategy<E>> {
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        // Deref past the box explicitly so the call dispatches to the inner
        // trait object rather than recursing into this impl.
        Strategy::run(&**self, bag)
    }
}
/// Runs a list of strategies one after another, feeding each step the
/// residual of the previous one.
struct Seq<E> {
    steps: Vec<Box<dyn Strategy<E>>>,
}

impl<E> Strategy<E> for Seq<E> {
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        // Per-step timing is printed to stderr when FLORECON_TIME is set;
        // it is always off on wasm32.
        #[cfg(not(target_arch = "wasm32"))]
        let timed = std::env::var_os("FLORECON_TIME").is_some();
        #[cfg(target_arch = "wasm32")]
        let timed = false;

        let mut out = Resolution {
            groups: Vec::new(),
            residual: bag,
        };
        for (i, step) in self.steps.iter().enumerate() {
            let pending = std::mem::take(&mut out.residual);
            let n_in = pending.len();
            let started = if timed {
                Some(std::time::Instant::now())
            } else {
                None
            };
            let Resolution {
                groups: found,
                residual: rest,
            } = step.run(pending);
            if let Some(started) = started {
                let grouped: usize = found.iter().map(|g| g.members.len()).sum();
                eprintln!(
                    " seq step {i}: {n_in:>7} in -> {:>7} grouped, {:>7} residual [{:>6.1} ms]",
                    grouped,
                    rest.len(),
                    started.elapsed().as_secs_f64() * 1000.0,
                );
            }
            out.groups.extend(found);
            out.residual = rest;
        }
        out
    }
}

/// Chains `steps` so that each one only sees what the earlier ones left over.
pub fn seq<E: 'static>(steps: Vec<Box<dyn Strategy<E>>>) -> Box<dyn Strategy<E>> {
    let chain = Seq { steps };
    Box::new(chain)
}
struct Labeled<E> {
tag: String,
inner: Box<dyn Strategy<E>>,
}
impl<E> Strategy<E> for Labeled<E> {
fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
let mut r = self.inner.run(bag);
for g in &mut r.groups {
g.reason = Some(match g.reason.take() {
Some(detail) => format!("{}: {}", self.tag, detail),
None => self.tag.clone(),
});
}
r
}
}
pub fn labeled<E: 'static>(
tag: impl Into<String>,
inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>> {
Box::new(Labeled {
tag: tag.into(),
inner,
})
}
/// Runs `inner` and keeps only the groups accepted by `pred`; members of
/// rejected groups are returned to the residual.
struct Filter<E, FP> {
    pred: FP,
    inner: Box<dyn Strategy<E>>,
}

impl<E, FP> Strategy<E> for Filter<E, FP>
where
    E: Clone,
    FP: Fn(&GroupView<E>) -> bool,
{
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        // Snapshot of the input by id, used to build views and to rebuild
        // items for rejected allocations.
        let mut src: HashMap<ExtId, Item<E>> = HashMap::with_capacity(bag.len());
        for item in &bag {
            src.insert(item.id, item.clone());
        }

        let Resolution {
            groups: proposed,
            mut residual,
        } = self.inner.run(bag);

        // Position of each id inside `residual`, kept current as items are appended.
        let mut residual_ix: HashMap<ExtId, usize> = HashMap::with_capacity(residual.len());
        for (ix, item) in residual.iter().enumerate() {
            residual_ix.insert(item.id, ix);
        }

        let mut kept = Vec::with_capacity(proposed.len());
        for g in proposed {
            if (self.pred)(&GroupView::from_group(&g, &src)) {
                kept.push(g);
                continue;
            }
            for a in &g.members {
                if let Some(&ix) = residual_ix.get(&a.id) {
                    // Already partly in the residual: fold the amount back in.
                    residual[ix].amount += a.amount;
                    continue;
                }
                let Some(orig) = src.get(&a.id) else {
                    continue;
                };
                residual_ix.insert(a.id, residual.len());
                residual.push(Item {
                    id: a.id,
                    original: orig.original,
                    amount: a.amount,
                    data: orig.data.clone(),
                });
            }
        }
        Resolution {
            groups: kept,
            residual,
        }
    }
}

/// Keeps only those groups of `inner` for which `pred` returns `true`.
pub fn accept_if<E: Clone + 'static, FP>(
    pred: FP,
    inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>>
where
    FP: Fn(&GroupView<E>) -> bool + 'static,
{
    let gate = Filter { pred, inner };
    Box::new(gate)
}
/// Wraps a strategy and rewrites its groups so that every member carries the
/// full `original` amount of its item instead of a partial allocation.
struct Reclaim<E> {
// Origin assigned to groups that result from merging several inner groups.
origin: String,
inner: Box<dyn Strategy<E>>,
}
impl<E: Clone> Strategy<E> for Reclaim<E> {
/// Runs `inner`, merges groups that share an id into one component, and
/// emits one group per component whose members are the whole items
/// (`original` amounts) looked up in the input bag. Ids that end up in a
/// group are removed from the residual.
fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
// Snapshot of the input by id; a later duplicate id replaces an earlier one.
let src: HashMap<ExtId, Item<E>> = bag.iter().map(|i| (i.id, i.clone())).collect();
let r = self.inner.run(bag);
// Residual keyed by id; entries with the same id have their amounts summed.
let mut resid: HashMap<ExtId, Item<E>> = HashMap::new();
for it in r.residual {
resid
.entry(it.id)
.and_modify(|e| e.amount += it.amount)
.or_insert(it);
}
let comps = group_components(&r.groups);
let mut out_groups: Vec<Group> = Vec::new();
for comp in comps {
// Distinct ids across all groups of this component.
let mut member_ids: Vec<ExtId> = Vec::new();
let mut seen: HashSet<ExtId> = HashSet::new();
for &gi in &comp {
for a in &r.groups[gi].members {
if seen.insert(a.id) {
member_ids.push(a.id);
}
}
}
member_ids.sort_unstable();
// Ids that are not present in the input snapshot are dropped here.
let wholes: Vec<(ExtId, i64)> = member_ids
.iter()
.filter_map(|&id| src.get(&id).map(|i| (id, i.original)))
.collect();
let net: i64 = wholes.iter().map(|&(_, o)| o).sum();
// The whole item is now grouped, so nothing of it may remain in the residual.
for &(id, _) in &wholes {
resid.remove(&id);
}
// A component made of one group keeps that group's provenance; a merged
// component takes this wrapper's origin and has no reason.
let (origin, reason) = if comp.len() == 1 {
let g = &r.groups[comp[0]];
(g.origin.clone(), g.reason.clone())
} else {
(self.origin.clone(), None)
};
out_groups.push(Group {
members: wholes
.iter()
.map(|&(id, o)| Allocation { id, amount: o })
.collect(),
origin,
net,
reason,
});
}
// Sort by id so the residual order does not depend on hash-map iteration.
let mut out_residual: Vec<Item<E>> = resid.into_values().collect();
out_residual.sort_by_key(|i| i.id);
Resolution {
groups: out_groups,
residual: out_residual,
}
}
}
/// Wraps `inner` in a `Reclaim` that uses `origin` for merged groups.
pub fn reclaim<E: Clone + 'static>(
origin: impl Into<String>,
inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>> {
Box::new(Reclaim {
origin: origin.into(),
inner,
})
}
/// Wraps a strategy and merges its groups that share an id into single groups.
struct Coalesce<E> {
// Origin written onto every group emitted after coalescing.
origin: String,
inner: Box<dyn Strategy<E>>,
}
/// Partitions group indices into connected components, where two groups are
/// connected when they contain a member with the same id.
///
/// Components are returned in order of their lowest group index, and the
/// indices inside each component are ascending.
fn group_components(groups: &[Group]) -> Vec<Vec<usize>> {
    /// Finds the representative of `start`, shortening the path on the way.
    fn root_of(links: &mut [usize], start: usize) -> usize {
        let mut node = start;
        while links[node] != node {
            let grand = links[links[node]];
            links[node] = grand;
            node = grand;
        }
        node
    }

    let total = groups.len();
    let mut links: Vec<usize> = (0..total).collect();

    // First group seen for each id; later groups with that id are unioned to it.
    let mut owner: HashMap<ExtId, usize> = HashMap::new();
    for (gi, group) in groups.iter().enumerate() {
        for alloc in &group.members {
            if let Some(&earlier) = owner.get(&alloc.id) {
                let here = root_of(&mut links, gi);
                let there = root_of(&mut links, earlier);
                if here != there {
                    links[here] = there;
                }
            } else {
                owner.insert(alloc.id, gi);
            }
        }
    }

    // Bucket indices by representative, preserving first-seen order.
    let mut slot_of_root: HashMap<usize, usize> = HashMap::new();
    let mut components: Vec<Vec<usize>> = Vec::new();
    for gi in 0..total {
        let root = root_of(&mut links, gi);
        let slot = *slot_of_root.entry(root).or_insert_with(|| {
            components.push(Vec::new());
            components.len() - 1
        });
        components[slot].push(gi);
    }
    components
}
/// Merges groups that share an id into one group per connected component.
///
/// Within a component the amounts are summed per id; ids that sum to zero
/// are dropped, and components left without members are omitted. A
/// component consisting of a single group keeps that group's reason.
fn coalesce_groups(groups: &[Group], origin: &str) -> Vec<Group> {
    group_components(groups)
        .into_iter()
        .filter_map(|comp| {
            let mut totals: BTreeMap<ExtId, i64> = BTreeMap::new();
            for a in comp.iter().flat_map(|&gi| groups[gi].members.iter()) {
                *totals.entry(a.id).or_default() += a.amount;
            }

            let mut members: Vec<Allocation> = Vec::new();
            for (id, amount) in totals {
                if amount != 0 {
                    members.push(Allocation { id, amount });
                }
            }
            if members.is_empty() {
                return None;
            }

            let net = members.iter().map(|a| a.amount).sum();
            let reason = match comp.as_slice() {
                [only] => groups[*only].reason.clone(),
                many => Some(format!("coalesced {} groups", many.len())),
            };
            Some(Group {
                members,
                origin: origin.to_string(),
                net,
                reason,
            })
        })
        .collect()
}
impl<E> Strategy<E> for Coalesce<E> {
    /// Runs the wrapped strategy and coalesces its groups; the residual is
    /// passed through untouched.
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        let Resolution { groups, residual } = self.inner.run(bag);
        let groups = coalesce_groups(&groups, &self.origin);
        Resolution { groups, residual }
    }
}

/// Wraps `inner` so that groups sharing an id are merged and tagged with `origin`.
pub fn coalesce<E: 'static>(
    origin: impl Into<String>,
    inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>> {
    let origin = origin.into();
    Box::new(Coalesce { origin, inner })
}
/// Re-runs a strategy on its own residual until nothing changes or a pass
/// limit is hit.
struct FixedPoint<E> {
    inner: Box<dyn Strategy<E>>,
    max_passes: usize,
}

/// Order-independent summary of a residual: its sorted `(id, amount)` pairs.
fn residual_fingerprint<E>(items: &[Item<E>]) -> Vec<(ExtId, i64)> {
    let mut print = Vec::with_capacity(items.len());
    for item in items {
        print.push((item.id, item.amount));
    }
    print.sort_unstable();
    print
}

impl<E> Strategy<E> for FixedPoint<E> {
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        let mut groups = Vec::new();
        let mut residual = bag;
        let mut last_print = residual_fingerprint(&residual);
        let mut passes_left = self.max_passes;

        while passes_left > 0 && !residual.is_empty() {
            passes_left -= 1;
            let pass = self.inner.run(std::mem::take(&mut residual));
            groups.extend(pass.groups);
            residual = pass.residual;

            // A pass that leaves the residual unchanged cannot make progress later.
            let print = residual_fingerprint(&residual);
            if print == last_print {
                break;
            }
            last_print = print;
        }
        Resolution { groups, residual }
    }
}

/// Repeats `inner` on its residual for at most `max_passes` passes
/// (a value of 0 is treated as 1).
pub fn fixed_point<E: 'static>(
    inner: Box<dyn Strategy<E>>,
    max_passes: usize,
) -> Box<dyn Strategy<E>> {
    let max_passes = std::cmp::max(max_passes, 1);
    Box::new(FixedPoint { inner, max_passes })
}
/// Factory that builds the strategy to run for a given shard key.
type ShardFactory<E, K> = dyn Fn(&K) -> Box<dyn Strategy<E>>;

/// Splits the bag into shards by key and runs a per-key strategy on each.
struct PartitionBy<E, K, FK> {
    key: FK,
    factory: Box<ShardFactory<E, K>>,
}

impl<E, K, FK> Strategy<E> for PartitionBy<E, K, FK>
where
    K: Hash + Eq + Clone,
    FK: Fn(&Item<E>) -> K,
{
    /// Shards `bag` by `key`, runs `factory(key)` on every shard and
    /// concatenates the results.
    ///
    /// Shards are processed in the order their key first appears in `bag`,
    /// so the output order is a pure function of the input order rather
    /// than of hash-map iteration order.
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        // `slot_of` maps a key to its position in `shards`; `shards` keeps
        // the first-seen order of keys.
        let mut slot_of: HashMap<K, usize> = HashMap::new();
        let mut shards: Vec<(K, Vec<Item<E>>)> = Vec::new();
        for item in bag {
            let k = (self.key)(&item);
            let slot = match slot_of.get(&k) {
                Some(&slot) => slot,
                None => {
                    let slot = shards.len();
                    slot_of.insert(k.clone(), slot);
                    shards.push((k, Vec::new()));
                    slot
                }
            };
            shards[slot].1.push(item);
        }

        let mut groups = Vec::new();
        let mut residual = Vec::new();
        for (k, items) in shards {
            let r = (self.factory)(&k).run(items);
            groups.extend(r.groups);
            residual.extend(r.residual);
        }
        Resolution { groups, residual }
    }
}

/// Routes each item to the strategy built by `factory` for the item's `key`.
pub fn partition_by<E: 'static, K, FK, FF>(key: FK, factory: FF) -> Box<dyn Strategy<E>>
where
    K: Hash + Eq + Clone + 'static,
    FK: Fn(&Item<E>) -> K + 'static,
    FF: Fn(&K) -> Box<dyn Strategy<E>> + 'static,
{
    Box::new(PartitionBy {
        key,
        factory: Box::new(factory),
    })
}
struct When<E, FP> {
pred: FP,
inner: Box<dyn Strategy<E>>,
}
impl<E, FP> Strategy<E> for When<E, FP>
where
FP: Fn(&Item<E>) -> bool,
{
fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
let mut yes = Vec::new();
let mut no = Vec::new();
for item in bag {
if (self.pred)(&item) {
yes.push(item);
} else {
no.push(item);
}
}
let mut r = self.inner.run(yes);
r.residual.extend(no);
r
}
}
pub fn when<E: 'static, FP>(pred: FP, inner: Box<dyn Strategy<E>>) -> Box<dyn Strategy<E>>
where
FP: Fn(&Item<E>) -> bool + 'static,
{
Box::new(When { pred, inner })
}
/// Strategy that forms no groups and returns the whole bag as residual.
struct Identity;

impl<E> Strategy<E> for Identity {
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        let groups = Vec::new();
        Resolution {
            groups,
            residual: bag,
        }
    }
}

/// Returns a strategy that leaves every item in the residual.
pub fn identity<E: 'static>() -> Box<dyn Strategy<E>> {
    let passthrough = Identity;
    Box::new(passthrough)
}
/// Runs a strategy over consecutive bands of items sorted by an ordering key,
/// carrying unmatched items forward while they stay within `width`.
struct Windowed<E, FO> {
    order: FO,
    width: i64,
    inner: Box<dyn Strategy<E>>,
    _e: PhantomData<E>,
}

impl<E, FO> Strategy<E> for Windowed<E, FO>
where
    FO: Fn(&Item<E>) -> i64,
{
    /// Sorts `bag` by `order`, then repeatedly takes a band of items whose
    /// key lies in `[bottom, bottom + width)`, adds residual items carried
    /// over from the previous band whose key is still within `width` of the
    /// new bottom, and runs `inner` on that set. Carried items that fell out
    /// of range go to the final residual. A width below 1 is treated as 1.
    ///
    /// All key arithmetic is done in `i128`: with `i64`, `bottom + width`
    /// could overflow for large keys or widths, which panics in debug builds
    /// and in release builds wraps to a bound that admits no item, so the
    /// loop would never advance.
    fn run(&self, mut bag: Vec<Item<E>>) -> Resolution<E> {
        let w = i128::from(self.width.max(1));
        bag.sort_by_key(|i| (self.order)(i));

        let mut groups = Vec::new();
        let mut residual = Vec::new();
        let mut carry: Vec<Item<E>> = Vec::new();
        let mut it = bag.into_iter().peekable();

        // Taking the first item unconditionally guarantees progress per band.
        while let Some(first) = it.next() {
            let band_bottom = i128::from((self.order)(&first));
            let band_top = band_bottom + w; // exclusive upper bound
            let mut band = vec![first];
            while let Some(item) = it.next_if(|item| i128::from((self.order)(item)) < band_top) {
                band.push(item);
            }

            // Carried items stay eligible while their key is within `w` of
            // the new band's bottom; the rest are given up for good.
            let mut keep = Vec::with_capacity(carry.len() + band.len());
            for item in carry.drain(..) {
                if i128::from((self.order)(&item)) + w >= band_bottom {
                    keep.push(item);
                } else {
                    residual.push(item);
                }
            }
            keep.extend(band);

            let r = self.inner.run(keep);
            groups.extend(r.groups);
            carry = r.residual;
        }
        residual.extend(carry);
        Resolution { groups, residual }
    }
}

/// Applies `inner` band by band over items ordered by `order`, so that only
/// items whose keys are close (as bounded by `width`) are offered together.
pub fn windowed<E: 'static, FO>(
    order: FO,
    width: i64,
    inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>>
where
    FO: Fn(&Item<E>) -> i64 + 'static,
{
    Box::new(Windowed {
        order,
        width,
        inner,
        _e: PhantomData,
    })
}
struct ExactOneToOne<E, FK> {
key: FK,
_e: PhantomData<E>,
}
impl<E, FK> Strategy<E> for ExactOneToOne<E, FK>
where
FK: Fn(&Item<E>) -> Option<u64>,
{
fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
let mut buckets: HashMap<u64, Vec<Item<E>>> = HashMap::new();
let mut residual = Vec::new();
for item in bag {
match (self.key)(&item) {
Some(k) if item.amount != 0 => buckets.entry(k).or_default().push(item),
_ => residual.push(item),
}
}
let mut groups = Vec::new();
for (_k, items) in buckets {
type Signed<E> = (Vec<Item<E>>, Vec<Item<E>>);
let mut by_mag: HashMap<i64, Signed<E>> = HashMap::new();
for item in items {
let a = item.amount;
let slot = by_mag.entry(a.abs()).or_default();
if a > 0 {
slot.0.push(item);
} else {
slot.1.push(item);
}
}
for (_mag, (mut pos, mut neg)) in by_mag {
pos.sort_unstable_by_key(|i| i.id);
neg.sort_unstable_by_key(|i| i.id);
let pairs = pos.len().min(neg.len());
for _ in 0..pairs {
let p = pos.pop().unwrap();
let n = neg.pop().unwrap();
groups.push(Group {
members: vec![
Allocation {
id: p.id,
amount: p.amount,
},
Allocation {
id: n.id,
amount: n.amount,
},
],
origin: "exact_1to1".to_string(),
net: 0,
reason: Some("exact 1:1 pair".to_string()),
});
}
residual.extend(pos);
residual.extend(neg);
}
}
Resolution { groups, residual }
}
}
pub fn exact_1to1<E: 'static, FK>(key: FK) -> Box<dyn Strategy<E>>
where
FK: Fn(&Item<E>) -> Option<u64> + 'static,
{
Box::new(ExactOneToOne {
key,
_e: PhantomData,
})
}
struct AggNet<E, FK, FP> {
key: FK,
accept: FP,
_e: PhantomData<E>,
}
impl<E, FK, FP> Strategy<E> for AggNet<E, FK, FP>
where
FK: Fn(&Item<E>) -> u64,
FP: Fn(&GroupView<E>) -> bool,
{
fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
let mut buckets: HashMap<u64, Vec<Item<E>>> = HashMap::new();
for item in bag {
buckets.entry((self.key)(&item)).or_default().push(item);
}
let mut groups = Vec::new();
let mut residual = Vec::new();
for (_k, items) in buckets {
let sum: i64 = items.iter().map(|i| i.amount).sum();
let signs = items.iter().fold((false, false), |(p, n), i| {
let a = i.amount;
(p || a > 0, n || a < 0)
});
let accept = items.len() >= 2
&& signs.0
&& signs.1
&& (self.accept)(&GroupView::from_refs(&items));
if accept {
groups.push(Group {
members: items
.iter()
.map(|i| Allocation {
id: i.id,
amount: i.amount,
})
.collect(),
origin: "agg_net".to_string(),
net: sum,
reason: Some("aggregate net".to_string()),
});
} else {
residual.extend(items);
}
}
Resolution { groups, residual }
}
}
pub fn agg_net<E: 'static, FK, FP>(key: FK, accept: FP) -> Box<dyn Strategy<E>>
where
FK: Fn(&Item<E>) -> u64 + 'static,
FP: Fn(&GroupView<E>) -> bool + 'static,
{
Box::new(AggNet {
key,
accept,
_e: PhantomData,
})
}
struct Cumulative<E, FO, FP> {
order: FO,
accept: FP,
_e: PhantomData<E>,
}
impl<E, FO, FP> Strategy<E> for Cumulative<E, FO, FP>
where
FO: Fn(&Item<E>) -> i64,
FP: Fn(&GroupView<E>) -> bool,
{
fn run(&self, mut bag: Vec<Item<E>>) -> Resolution<E> {
bag.sort_by_key(|i| (self.order)(i));
let mut groups = Vec::new();
let mut seg: Vec<Item<E>> = Vec::new();
for item in bag {
seg.push(item);
let close = seg.len() >= 2 && (self.accept)(&GroupView::from_refs(&seg));
if close {
let net: i64 = seg.iter().map(|i| i.amount).sum();
groups.push(Group {
members: seg
.iter()
.map(|i| Allocation {
id: i.id,
amount: i.amount,
})
.collect(),
origin: "cumulative".to_string(),
net,
reason: Some("cumulative segment".to_string()),
});
seg.clear();
}
}
Resolution {
groups,
residual: seg, }
}
}
pub fn cumulative<E: 'static, FO, FP>(order: FO, accept: FP) -> Box<dyn Strategy<E>>
where
FO: Fn(&Item<E>) -> i64 + 'static,
FP: Fn(&GroupView<E>) -> bool + 'static,
{
Box::new(Cumulative {
order,
accept,
_e: PhantomData,
})
}
/// Groups items that share a signal value, trying the rarest signals first.
struct SignalGroup<E, FS, FP> {
    signals: FS,
    accept: FP,
    cap: usize,
    _e: PhantomData<E>,
}

impl<E, FS, FP> Strategy<E> for SignalGroup<E, FS, FP>
where
    FS: Fn(&Item<E>) -> Vec<u64>,
    FP: Fn(&GroupView<E>) -> bool,
{
    /// Indexes items by every signal they emit, then visits signals by
    /// ascending `(item count, signal)`. For each signal, the still-unused
    /// items form a candidate; it becomes a group when it has between 2 and
    /// `cap` items, contains both signs, and `accept` approves it.
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        let n = bag.len();

        let mut index: HashMap<u64, Vec<usize>> = HashMap::new();
        for (pos, item) in bag.iter().enumerate() {
            for signal in (self.signals)(item) {
                index.entry(signal).or_default().push(pos);
            }
        }

        let mut order: Vec<(usize, u64)> = Vec::with_capacity(index.len());
        for (signal, holders) in &index {
            order.push((holders.len(), *signal));
        }
        order.sort_unstable();

        let mut used = vec![false; n];
        let mut groups = Vec::new();
        for (_, signal) in order {
            let members: Vec<usize> = index[&signal]
                .iter()
                .copied()
                .filter(|&i| !used[i])
                .collect();
            if members.len() < 2 || members.len() > self.cap {
                continue;
            }
            let sum: i64 = members.iter().map(|&i| bag[i].amount).sum();
            let has_pos = members.iter().any(|&i| bag[i].amount > 0);
            let has_neg = members.iter().any(|&i| bag[i].amount < 0);
            if !(has_pos && has_neg) {
                continue;
            }
            let view = GroupView::from_refs(members.iter().map(|&i| &bag[i]));
            if !(self.accept)(&view) {
                continue;
            }

            let mut allocations = Vec::with_capacity(members.len());
            for &i in &members {
                used[i] = true;
                allocations.push(Allocation {
                    id: bag[i].id,
                    amount: bag[i].amount,
                });
            }
            groups.push(Group {
                members: allocations,
                origin: "signal_group".to_string(),
                net: sum,
                reason: Some("shared reference".to_string()),
            });
        }

        let mut residual = Vec::new();
        for (i, item) in bag.into_iter().enumerate() {
            if !used[i] {
                residual.push(item);
            }
        }
        Resolution { groups, residual }
    }
}

/// Builds a strategy that groups items sharing a value from `signals`,
/// limited to `cap` members per group and filtered by `accept`.
pub fn signal_group<E: 'static, FS, FP>(
    signals: FS,
    accept: FP,
    cap: usize,
) -> Box<dyn Strategy<E>>
where
    FS: Fn(&Item<E>) -> Vec<u64> + 'static,
    FP: Fn(&GroupView<E>) -> bool + 'static,
{
    let grouper = SignalGroup {
        signals,
        accept,
        cap,
        _e: PhantomData,
    };
    Box::new(grouper)
}
/// One step of the SplitMix64 generator: adds the increment to `z` and
/// applies the two multiply-xorshift finalisation rounds.
fn splitmix64(z: u64) -> u64 {
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    const MUL_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MUL_B: u64 = 0x94D0_49BB_1331_11EB;

    let advanced = z.wrapping_add(INCREMENT);
    let round_a = (advanced ^ (advanced >> 30)).wrapping_mul(MUL_A);
    let round_b = (round_a ^ (round_a >> 27)).wrapping_mul(MUL_B);
    round_b ^ (round_b >> 31)
}

/// Combines two 64-bit values into one hash; the result depends on the
/// argument order.
fn seed_mix(a: u64, b: u64) -> u64 {
    let hashed_b = splitmix64(b);
    splitmix64(a ^ hashed_b)
}
/// Maximum number of candidate partners handed to `best_subset` per anchor.
const SUBSET_CAND_CAP: usize = 32;
/// Matches one item against a subset of opposite-sign items whose magnitudes
/// sum to the item's magnitude within a tolerance.
struct SubsetSum<E> {
// Allowed absolute difference between the anchor and the sum of its partners.
band: i64,
// Maximum group size, anchor included.
max_group: usize,
// Seed used for tie-breaking hashes.
seed: u64,
_e: PhantomData<E>,
}
impl<E> Strategy<E> for SubsetSum<E> {
/// Visits items as anchors from largest to smallest magnitude and, for each
/// unconsumed non-zero anchor, asks `best_subset` for partners among the
/// unconsumed opposite-sign items. Anchor and partners become one group;
/// everything never consumed is the residual, in input order.
fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
let n = bag.len();
let mut consumed = vec![false; n];
// One slot of the group is taken by the anchor itself.
let max_partners = self.max_group.saturating_sub(1);
let mut order: Vec<usize> = (0..n).collect();
// Largest magnitude first; equal magnitudes are ordered by a seeded hash of the id.
order.sort_by_key(|&i| {
(
std::cmp::Reverse(bag[i].amount.unsigned_abs()),
seed_mix(bag[i].id, self.seed),
)
});
let mut groups = Vec::new();
for &ai in &order {
if consumed[ai] || bag[ai].amount == 0 || max_partners == 0 {
continue;
}
let target = bag[ai].amount.unsigned_abs() as i64;
let want_sign = -bag[ai].amount.signum();
// Candidates: unconsumed, opposite sign, magnitude no larger than target + band.
// Each entry is (index into bag, id, magnitude).
let mut cands: Vec<(usize, ExtId, i64)> = (0..n)
.filter(|&j| {
!consumed[j]
&& j != ai
&& bag[j].amount.signum() == want_sign
&& bag[j].amount.unsigned_abs() as i64 <= target + self.band
})
.map(|j| (j, bag[j].id, bag[j].amount.unsigned_abs() as i64))
.collect();
if cands.is_empty() {
continue;
}
// Keep only the SUBSET_CAND_CAP largest candidates; the sort is stable, so
// equal magnitudes keep their bag order.
if cands.len() > SUBSET_CAND_CAP {
cands.sort_by_key(|&(_, _, mag)| std::cmp::Reverse(mag));
cands.truncate(SUBSET_CAND_CAP);
}
if let Some(chosen) = best_subset(target, self.band, max_partners, &cands, self.seed) {
consumed[ai] = true;
let mut members = vec![Allocation {
id: bag[ai].id,
amount: bag[ai].amount,
}];
// `chosen` holds positions in `cands`; map them back to bag indices.
for &ci in &chosen {
let j = cands[ci].0;
consumed[j] = true;
members.push(Allocation {
id: bag[j].id,
amount: bag[j].amount,
});
}
let net = members.iter().map(|m| m.amount).sum();
let size = members.len();
groups.push(Group {
members,
origin: "subset-sum".to_string(),
net,
reason: Some(format!("subset sum of {size} lots")),
});
}
}
let residual = bag
.into_iter()
.zip(consumed)
.filter_map(|(item, used)| (!used).then_some(item))
.collect();
Resolution { groups, residual }
}
}
/// Picks a non-empty subset of at most `k_max` candidates whose magnitudes sum
/// to within `band` of `target`, using a meet-in-the-middle search.
///
/// `cands` entries are `(caller index, id, magnitude)`. Returns the chosen
/// positions within `cands`, or `None` when no subset qualifies. Among
/// qualifying subsets the one with the smallest `(|sum - target|, size,
/// xor of seeded id hashes)` wins.
///
/// NOTE(review): the `1u32 << len` masks require each half to hold fewer than
/// 32 candidates; the visible caller truncates `cands` to `SUBSET_CAND_CAP`.
fn best_subset(
target: i64,
band: i64,
k_max: usize,
cands: &[(usize, ExtId, i64)],
seed: u64,
) -> Option<Vec<usize>> {
if k_max == 0 || cands.is_empty() {
return None;
}
let m = cands.len();
let mid = m / 2;
// Enumerates every subset of cands[lo..lo + len] with at most k_max elements
// as (sum, element count, bitmask, xor of seeded id hashes).
let enumerate = |lo: usize, len: usize| -> Vec<(i64, u32, u32, u64)> {
let mut out = Vec::new();
for mask in 0u32..(1u32 << len) {
let pc = mask.count_ones();
if pc as usize > k_max {
continue;
}
let mut sum = 0i64;
let mut key = 0u64;
for b in 0..len {
if mask & (1u32 << b) != 0 {
sum += cands[lo + b].2;
key ^= seed_mix(cands[lo + b].1, seed);
}
}
out.push((sum, pc, mask, key));
}
out
};
let left = enumerate(0, mid);
let mut right = enumerate(mid, m - mid);
// Sorting the right half by sum allows a binary search for the window start.
right.sort_by_key(|&(s, _, _, _)| s);
let rsums: Vec<i64> = right.iter().map(|&(s, _, _, _)| s).collect();
// Best so far: (distance to target, size, tie-break key, left mask, right mask).
let mut best: Option<(i64, u32, u64, u32, u32)> = None;
for &(sl, cl, ml, kl) in &left {
// Right-half sums must fall in [lo, hi] for the total to be within band.
let lo = target - band - sl;
let hi = target + band - sl;
let start = rsums.partition_point(|&s| s < lo);
for &(sr, cr, mr, kr) in &right[start..] {
if sr > hi {
break;
}
let card = cl + cr;
// Skip the empty subset and combinations that exceed the size limit.
if card == 0 || card as usize > k_max {
continue;
}
let cand = ((sl + sr - target).abs(), card, kl ^ kr, ml, mr);
if best.is_none_or(|b| (cand.0, cand.1, cand.2) < (b.0, b.1, b.2)) {
best = Some(cand);
}
}
}
// Decode the two winning bitmasks into positions within `cands`.
best.map(|(_, _, _, ml, mr)| {
let mut picks = Vec::new();
for b in 0..mid {
if ml & (1u32 << b) != 0 {
picks.push(b);
}
}
for b in 0..(m - mid) {
if mr & (1u32 << b) != 0 {
picks.push(mid + b);
}
}
picks
})
}
/// Builds a subset-sum strategy with tolerance `band`, at most `max_group`
/// members per group (anchor included) and tie-break seed `seed`.
pub fn subset_sum<E: 'static>(band: i64, max_group: usize, seed: u64) -> Box<dyn Strategy<E>> {
Box::new(SubsetSum {
band,
max_group,
seed,
_e: PhantomData,
})
}
/// Runs a seeded strategy several times and keeps the best-scoring result.
struct Restart<E, F> {
    n: usize,
    seed: u64,
    factory: F,
    _e: PhantomData<E>,
}

impl<E, F> Strategy<E> for Restart<E, F>
where
    E: Clone,
    F: Fn(u64) -> Box<dyn Strategy<E>>,
{
    /// Builds a strategy from each derived seed, runs it on a copy of `bag`,
    /// and returns the run with the highest `(matched amount, -residual
    /// count)` score. On ties the earliest run is kept. At least one run is
    /// always performed.
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        let runs = self.n.max(1);
        let mut winner: Option<((i64, i64), Resolution<E>)> = None;
        for i in 0..runs {
            let derived = seed_mix(self.seed, i as u64);
            let attempt = (self.factory)(derived).run(bag.clone());

            let mut matched = 0i64;
            for group in &attempt.groups {
                for alloc in &group.members {
                    matched += alloc.amount.abs();
                }
            }
            let score = (matched, -(attempt.residual.len() as i64));

            let improves = match &winner {
                Some((best_score, _)) => score > *best_score,
                None => true,
            };
            if improves {
                winner = Some((score, attempt));
            }
        }
        winner
            .map(|(_, resolution)| resolution)
            .expect("restart runs at least once")
    }
}

/// Builds a strategy that tries `factory` with `n` seeds derived from `seed`
/// and keeps the best outcome.
pub fn restart<E, F>(n: usize, seed: u64, factory: F) -> Box<dyn Strategy<E>>
where
    E: Clone + 'static,
    F: Fn(u64) -> Box<dyn Strategy<E>> + 'static,
{
    let retry = Restart {
        n,
        seed,
        factory,
        _e: PhantomData,
    };
    Box::new(retry)
}
/// Per-item bookkeeping kept by `Pivot` while the inner strategy runs.
#[derive(Clone)]
struct PivotMeta<E> {
// The item exactly as it was passed to `Pivot::run`.
outer: Item<E>,
// Value the pivot's `amount` closure returned for that item.
alt_original: i64,
}
/// Runs an inner strategy on alternative amounts computed per item, then
/// converts the result back to the items' own amounts.
struct Pivot<E, FA> {
// Maps an item to its alternative amount.
amount: FA,
inner: Box<dyn Strategy<E>>,
}
/// Returns `total * part / denom`, truncated toward zero.
///
/// Yields 0 when any argument is 0, which also covers a zero denominator.
/// The product is formed in `i128` so it cannot overflow; a quotient that
/// does not fit in `i64` saturates to `i64::MIN` / `i64::MAX` instead of
/// silently wrapping.
fn prorate(total: i64, part: i64, denom: i64) -> i64 {
    if denom == 0 || total == 0 || part == 0 {
        return 0;
    }
    let quotient = i128::from(part) * i128::from(total) / i128::from(denom);
    i64::try_from(quotient).unwrap_or(if quotient < 0 { i64::MIN } else { i64::MAX })
}
impl<E, FA> Strategy<E> for Pivot<E, FA>
where
E: Clone,
FA: Fn(&Item<E>) -> i64,
{
/// Re-expresses every item in the alternative amount given by `amount`, runs
/// `inner` on those items, and converts the resulting groups and residual
/// back to the items' own amounts pro rata.
///
/// For each id the converted parts are forced to add up to the item's
/// incoming `amount` by putting the rounding remainder on the last part.
fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
// Outer item and alternative original per id; a duplicate id overwrites the earlier entry.
let mut meta: BTreeMap<ExtId, PivotMeta<E>> = BTreeMap::new();
let inner_bag: Vec<Item<E>> = bag
.into_iter()
.map(|outer| {
let alt_original = (self.amount)(&outer);
// Scale the alternative amount by the fraction of the item still open.
let alt_amount = prorate(alt_original, outer.amount, outer.original);
let id = outer.id;
let data = outer.data.clone();
meta.insert(
id,
PivotMeta {
outer,
alt_original,
},
);
Item {
id,
original: alt_original,
amount: alt_amount,
data,
}
})
.collect();
let mut res = self.inner.run(inner_bag);
{
// Total alternative amount each id has inside groups.
let mut group_pivot: BTreeMap<ExtId, i64> = BTreeMap::new();
for g in &res.groups {
for a in &g.members {
*group_pivot.entry(a.id).or_insert(0) += a.amount;
}
}
// Ids whose grouped share is non-zero in alternative units but converts to
// zero in the item's own units are taken out of all groups.
let mut dissolve: BTreeSet<ExtId> = BTreeSet::new();
for (id, &gp) in &group_pivot {
if gp == 0 {
continue;
}
let Some(m) = meta.get(id) else { continue };
if prorate(m.outer.amount, gp, m.alt_original) == 0 {
dissolve.insert(*id);
}
}
if !dissolve.is_empty() {
// Alternative amount removed from groups, per id.
let mut moved: BTreeMap<ExtId, i64> = BTreeMap::new();
for g in &mut res.groups {
g.members.retain(|a| {
if dissolve.contains(&a.id) {
*moved.entry(a.id).or_insert(0) += a.amount;
false
} else {
true
}
});
g.net = g.members.iter().map(|a| a.amount).sum();
}
res.groups.retain(|g| !g.members.is_empty());
// Give the removed amounts back to the residual, still in alternative units.
for (id, amt) in moved {
if amt == 0 {
continue;
}
if let Some(item) = res.residual.iter_mut().find(|i| i.id == id) {
item.amount += amt;
} else if let Some(m) = meta.get(&id) {
res.residual.push(Item {
id,
original: m.alt_original,
amount: amt,
data: m.outer.data.clone(),
});
}
}
}
}
// Every place an id appears: (group index, Some(member index), alt amount) for
// group members, (residual index, None, alt amount) for residual entries.
let mut parts: BTreeMap<ExtId, Vec<(usize, Option<usize>, i64)>> = BTreeMap::new();
for (gi, g) in res.groups.iter().enumerate() {
for (mi, a) in g.members.iter().enumerate() {
parts
.entry(a.id)
.or_default()
.push((gi, Some(mi), a.amount));
}
}
for (ri, item) in res.residual.iter().enumerate() {
parts
.entry(item.id)
.or_default()
.push((ri, None, item.amount));
}
// Converted amounts, laid out like res.groups[..].members and res.residual.
let mut group_amounts: Vec<Vec<i64>> = res
.groups
.iter()
.map(|g| vec![0; g.members.len()])
.collect();
let mut residual_amounts: Vec<i64> = vec![0; res.residual.len()];
// Ids that appear in at least one group or residual entry.
let accounted: BTreeSet<ExtId> = parts.keys().copied().collect();
for (id, ps) in parts {
let Some(m) = meta.get(&id) else { continue };
let mut converted = Vec::with_capacity(ps.len());
let mut sum = 0i64;
for (_, _, amt) in &ps {
let v = prorate(m.outer.amount, *amt, m.alt_original);
converted.push(v);
sum += v;
}
// Put the rounding remainder on the last part so the parts add up to the
// item's incoming amount.
if let Some(last) = converted.last_mut() {
*last += m.outer.amount - sum;
}
for ((idx, mi, _), v) in ps.into_iter().zip(converted) {
if let Some(mi) = mi {
group_amounts[idx][mi] = v;
} else {
residual_amounts[idx] = v;
}
}
}
// Write converted amounts into the groups; drop zero members and empty groups.
let groups = res
.groups
.into_iter()
.enumerate()
.filter_map(|(gi, mut g)| {
for (mi, a) in g.members.iter_mut().enumerate() {
a.amount = group_amounts[gi][mi];
}
g.members.retain(|a| a.amount != 0);
if g.members.is_empty() {
return None;
}
g.net = g.members.iter().map(|a| a.amount).sum();
Some(g)
})
.collect();
// Restore residual items to their own units. `meta.remove` means a residual
// entry without metadata, or a second entry with the same id, is dropped.
let mut residual: Vec<Item<E>> = res
.residual
.into_iter()
.enumerate()
.filter_map(|(ri, mut i)| {
let m = meta.remove(&i.id)?;
i.original = m.outer.original;
i.amount = residual_amounts[ri];
(i.amount != 0).then_some(i)
})
.collect();
// Items the inner strategy returned neither in a group nor in the residual
// are handed back unchanged when their amount is non-zero.
for (id, m) in meta {
if !accounted.contains(&id) && m.outer.amount != 0 {
residual.push(Item {
id,
original: m.outer.original,
amount: m.outer.amount,
data: m.outer.data,
});
}
}
Resolution { groups, residual }
}
}
/// Wraps `inner` so that it operates on the alternative amounts produced by
/// `amount` instead of the items' own amounts.
pub fn pivot<E: Clone + 'static, FA>(
amount: FA,
inner: Box<dyn Strategy<E>>,
) -> Box<dyn Strategy<E>>
where
FA: Fn(&Item<E>) -> i64 + 'static,
{
Box::new(Pivot { amount, inner })
}
/// Terminal strategy that absorbs every non-zero item into a single group.
struct Soak<E> {
    origin: String,
    _e: PhantomData<E>,
}

impl<E> Strategy<E> for Soak<E> {
    /// Puts all non-zero items into one group tagged with `origin`; zero
    /// items are discarded and the residual is always empty. No group is
    /// emitted when there are no non-zero items.
    fn run(&self, bag: Vec<Item<E>>) -> Resolution<E> {
        let mut members: Vec<Allocation> = Vec::new();
        let mut net = 0i64;
        for item in &bag {
            if item.amount == 0 {
                continue;
            }
            net += item.amount;
            members.push(Allocation {
                id: item.id,
                amount: item.amount,
            });
        }

        let mut groups = Vec::new();
        if !members.is_empty() {
            groups.push(Group {
                members,
                origin: self.origin.clone(),
                net,
                reason: None,
            });
        }
        Resolution {
            groups,
            residual: Vec::new(),
        }
    }
}

/// Builds a strategy that sweeps everything left into one group named `origin`.
pub fn soak<E: 'static>(origin: impl Into<String>) -> Box<dyn Strategy<E>> {
    let origin = origin.into();
    Box::new(Soak {
        origin,
        _e: PhantomData,
    })
}
#[cfg(test)]
mod tests {
use super::*;
fn bag(items: &[(ExtId, i64)]) -> Vec<Item<i64>> {
items.iter().map(|&(id, a)| Item::new(id, a, a)).collect()
}
fn ids(g: &Group) -> Vec<ExtId> {
let mut m = g.member_ids();
m.sort();
m
}
fn conserves<E>(input: usize, r: &Resolution<E>) {
let g: usize = r.groups.iter().map(|g| g.members.len()).sum();
assert_eq!(g + r.residual.len(), input, "conservation violated");
}
#[test]
fn agg_net_relative_tolerance_scales_with_smallest_leg() {
let b = bag(&[(1, 10_000), (2, -9_991)]);
let s = agg_net(|_a: &Item<i64>| 0u64, |g| {
g.net().abs() <= 10 * g.min_leg() / 10_000
});
let r = s.run(b);
conserves(2, &r);
assert_eq!(r.groups.len(), 1, "9 <= 10bps of the smallest leg");
let b = bag(&[(1, 10_000), (2, -9_991)]);
let s = agg_net(|_a: &Item<i64>| 0u64, |g| g.net().abs() <= 5 * g.min_leg() / 10_000);
let r = s.run(b);
conserves(2, &r);
assert_eq!(r.groups.len(), 0, "9 > 5bps of the smallest leg");
}
#[test]
fn agg_net_relative_floor_applies_to_tiny_buckets() {
let b = bag(&[(1, 100), (2, -98)]);
let s = agg_net(|_a: &Item<i64>| 0u64, |g| {
g.net().abs() <= (10 * g.min_leg() / 10_000).max(3)
});
let r = s.run(b);
conserves(2, &r);
assert_eq!(r.groups.len(), 1);
}
#[test]
fn labeled_stamps_reason_on_groups_but_not_residual() {
let b = bag(&[(1, 5), (2, -5), (3, 7)]);
let s = labeled("S3a exact", exact_1to1(|_| Some(0)));
let r = s.run(b);
conserves(3, &r);
assert_eq!(r.groups.len(), 1);
assert_eq!(
r.groups[0].reason.as_deref(),
Some("S3a exact: exact 1:1 pair")
);
assert_eq!(r.residual.len(), 1);
assert_eq!(r.residual[0].id, 3);
}
#[test]
fn labeled_prepends_to_inner_detail() {
let b = bag(&[(1, 5), (2, -5)]);
let s = labeled("outer", labeled("inner", exact_1to1(|_| Some(0))));
let r = s.run(b);
assert_eq!(
r.groups[0].reason.as_deref(),
Some("outer: inner: exact 1:1 pair")
);
}
#[test]
fn exact_pairs_and_leaves_residual() {
let b = bag(&[(1, 5), (2, -5), (3, 5), (4, 3)]);
let s = exact_1to1(|_| Some(0));
let r = s.run(b);
conserves(4, &r);
assert_eq!(r.groups.len(), 1);
assert!(r.groups[0].member_ids().contains(&2));
assert_eq!(r.residual.len(), 2);
}
#[test]
fn agg_accepts_netting_bucket() {
let b = bag(&[(1, 100), (2, -60), (3, -40), (4, 7)]);
let s = agg_net(|_a: &Item<i64>| 0u64, |g| g.net() == 0);
let r = s.run(b);
conserves(4, &r);
assert_eq!(r.groups.len(), 0);
let b = bag(&[(1, 100), (2, -60), (3, -40), (4, 7)]);
let s = agg_net(|_a: &Item<i64>| 0u64, |g| g.net().abs() <= 10);
let r = s.run(b);
conserves(4, &r);
assert_eq!(r.groups.len(), 1);
assert_eq!(r.groups[0].members.len(), 4);
}
#[test]
fn signal_groups_net_and_cascade() {
let b = bag(&[(1, 50), (2, -50), (3, 9)]);
let s = signal_group(
|a: &Item<i64>| if a.amount == 9 { vec![] } else { vec![10] },
|g| g.net() == 0,
16,
);
let r = s.run(b);
conserves(3, &r);
assert_eq!(r.groups.len(), 1);
assert_eq!(ids(&r.groups[0]), vec![1, 2]);
assert_eq!(r.residual.len(), 1);
}
#[test]
fn signal_groups_accept_relative_tol() {
let b = bag(&[(1, 1000), (2, -995)]);
let tight = signal_group(
|_: &Item<i64>| vec![7u64],
|g| g.net().abs() <= 10 * g.min_leg() / 10_000,
16,
);
let r = tight.run(b.clone());
assert_eq!(r.groups.len(), 0);
assert_eq!(r.residual.len(), 2);
let loose = signal_group(
|_: &Item<i64>| vec![7u64],
|g| g.net().abs() <= 60 * g.min_leg() / 10_000,
16,
);
let r = loose.run(b);
conserves(2, &r);
assert_eq!(r.groups.len(), 1);
assert_eq!(ids(&r.groups[0]), vec![1, 2]);
}
struct OnePair;
impl Strategy<i64> for OnePair {
fn run(&self, bag: Vec<Item<i64>>) -> Resolution<i64> {
for i in 0..bag.len() {
for j in (i + 1)..bag.len() {
if bag[i].amount == -bag[j].amount && bag[i].amount != 0 {
let mut residual = Vec::new();
let mut members = Vec::new();
for (k, item) in bag.into_iter().enumerate() {
if k == i || k == j {
members.push(Allocation {
id: item.id,
amount: item.amount,
});
} else {
residual.push(item);
}
}
let g = Group {
members,
origin: "onepair".into(),
net: 0,
reason: None,
};
return Resolution {
groups: vec![g],
residual,
};
}
}
}
Resolution {
groups: vec![],
residual: bag,
}
}
}
/// A single `OnePair` call finds only one of the two available pairs; wrapped
/// in `fixed_point` with a cap of 16 it is expected to find both.
#[test]
fn fixed_point_drives_a_non_maximal_leaf_to_completion() {
    let rows = || bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]);

    // Bare leaf: one pair matched, the other two rows left over.
    let single = OnePair.run(rows());
    assert_eq!(single.groups.len(), 1);
    assert_eq!(single.residual.len(), 2);

    // Same leaf driven by `fixed_point`.
    let driver = fixed_point(Box::new(OnePair), 16);
    let full = driver.run(rows());
    conserves(4, &full);
    assert_eq!(full.groups.len(), 2, "both pairs found across passes");
    assert_eq!(full.residual.len(), 0);
}
/// Only ids 1 and 2 cancel; ids 3 and 4 (7 and 3) have no partner, so the
/// fixed point is expected to stop with one group and those two rows left.
#[test]
fn fixed_point_leaves_unmatchable_residual_and_terminates() {
    let driver = fixed_point(Box::new(OnePair), 16);
    let out = driver.run(bag(&[(1, 5), (2, -5), (3, 7), (4, 3)]));

    conserves(4, &out);
    assert_eq!(out.groups.len(), 1);

    // Residual order is not asserted, so compare the sorted ids.
    let mut unmatched: Vec<ExtId> = out.residual.iter().map(|item| item.id).collect();
    unmatched.sort_unstable();
    assert_eq!(unmatched, vec![3, 4]);
}
/// With the cap set to 1, `fixed_point` over `OnePair` is expected to match
/// only one of the two available pairs.
#[test]
fn fixed_point_respects_the_pass_cap() {
    let capped = fixed_point(Box::new(OnePair), 1);
    let out = capped.run(bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]));

    conserves(4, &out);
    assert_eq!(out.groups.len(), 1, "cap of 1 means one pass");
    assert_eq!(out.residual.len(), 2);
}
/// Sequences a `when`-guarded `agg_net` (rows whose absolute amount is 5)
/// with an unguarded `agg_net`.
///
/// Expects two groups and an empty residual for the two cancelling pairs.
#[test]
fn when_cascade_routes_to_different_children_and_conserves() {
    let rows = bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]);

    let pipeline = seq(vec![
        // First stage only sees the +/-5 rows.
        when(
            |a: &Item<i64>| a.amount.unsigned_abs() == 5,
            agg_net(|_a: &Item<i64>| 1u64, |g| g.net() == 0),
        ),
        // Second stage is not guarded by a predicate.
        agg_net(|_a: &Item<i64>| 2u64, |g| g.net() == 0),
    ]);

    let out = pipeline.run(rows);

    conserves(4, &out);
    assert_eq!(out.groups.len(), 2);
    assert_eq!(out.residual.len(), 0);
}
/// Partitions rows by whether their absolute amount is 5 (key `1`) or not
/// (key `0`); key `1` gets an `agg_net`, every other key gets `identity()`.
///
/// Expects one group, with ids 3 and 4 remaining in the residual.
#[test]
fn partition_by_picks_a_per_key_subtree() {
    let rows = bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]);

    let strategy = partition_by(
        |a: &Item<i64>| (a.amount.unsigned_abs() == 5) as u8,
        |k: &u8| match *k {
            1 => agg_net(|_a: &Item<i64>| 0u64, |g| g.net() == 0),
            _ => identity(),
        },
    );

    let out = strategy.run(rows);

    conserves(4, &out);
    assert_eq!(out.groups.len(), 1);

    let mut untouched: Vec<ExtId> = out.residual.iter().map(|item| item.id).collect();
    untouched.sort_unstable();
    assert_eq!(untouched, vec![3, 4]);
}
/// Two cancelling rows whose window keys (`data.0`) are 1 and 100; with a
/// window argument of 3 the inner exact matcher is expected to find nothing.
#[test]
fn windowed_blocks_far_matches() {
    let rows = vec![Item::new(1, 5, (1i64, 5i64)), Item::new(2, -5, (100, -5))];
    let inner = exact_1to1(|_| Some(0));

    let banded = windowed(|d: &Item<(i64, i64)>| d.data.0, 3, inner);
    let out = banded.run(rows);

    assert_eq!(out.groups.len(), 0);
    assert_eq!(out.residual.len(), 2);
}
/// Two cancelling rows whose window keys (`data.0`) are 4 and 7; with a window
/// argument of 3 the inner exact matcher is expected to pair them.
#[test]
fn windowed_finds_near_match_across_band_boundary() {
    let rows = vec![Item::new(1, 5, (4i64, 5i64)), Item::new(2, -5, (7, -5))];
    let inner = exact_1to1(|_| Some(0));

    let banded = windowed(|d: &Item<(i64, i64)>| d.data.0, 3, inner);
    let out = banded.run(rows);

    assert_eq!(out.groups.len(), 1);
    assert_eq!(out.residual.len(), 0);
}
/// Feeds five rows ordered by `data.0` into `cumulative` with a `net == 0`
/// predicate.
///
/// Expects two groups: ids `[1, 2]` (100 - 100) and ids `[3, 4, 5]`
/// (50 - 30 - 20).
#[test]
fn cumulative_segments_at_balance_clears() {
    // Each row carries `(order key, amount)` as its payload.
    let row = |id: ExtId, amount: i64, order: i64| Item::new(id, amount, (order, amount));
    let rows = vec![
        row(1, 100, 1),
        row(2, -100, 2),
        row(3, 50, 3),
        row(4, -30, 4),
        row(5, -20, 5),
    ];

    let strategy = cumulative(|d: &Item<(i64, i64)>| d.data.0, |g| g.net() == 0);
    let out = strategy.run(rows);

    conserves(5, &out);
    assert_eq!(out.groups.len(), 2);
    assert_eq!(out.groups[0].member_ids(), vec![1, 2]);
    assert_eq!(out.groups[1].member_ids(), vec![3, 4, 5]);
}
/// Ids 1 and 2 cancel; id 3 (amount 7) follows and never brings the running
/// net back to zero, so it is expected to stay in the residual.
#[test]
fn cumulative_leaves_uncleared_tail() {
    // Each row carries `(order key, amount)` as its payload.
    let row = |id: ExtId, amount: i64, order: i64| Item::new(id, amount, (order, amount));
    let rows = vec![row(1, 100, 1), row(2, -100, 2), row(3, 7, 3)];

    let strategy = cumulative(|d: &Item<(i64, i64)>| d.data.0, |g| g.net() == 0);
    let out = strategy.run(rows);

    assert_eq!(out.groups.len(), 1);
    assert_eq!(out.residual.len(), 1);
    assert_eq!(out.residual[0].id, 3);
}
/// Nests a one-stage `seq` inside `partition_by`.
///
/// The partition key is `|signum(amount)|`, which is 1 for every non-zero row
/// here, so all four rows share a key. Expects two groups.
#[test]
fn seq_then_partition_compose() {
    let composed = partition_by(
        |a: &Item<i64>| a.amount.signum().unsigned_abs(),
        |_| seq(vec![exact_1to1(|_| Some(0))]),
    );

    let rows = bag(&[(1, 4), (2, -4), (3, 4), (4, -4)]);
    let out = composed.run(rows);

    conserves(4, &out);
    assert_eq!(out.groups.len(), 2);
}
/// Wraps an exact matcher in `accept_if` with a predicate that only accepts
/// groups whose members all have absolute amount 5.
///
/// Expects the +/-5 pair to be kept and the +/-7 rows (ids 3 and 4) to end up
/// in the residual with their amounts intact.
#[test]
fn accept_if_rejects_groups_back_to_residual() {
    let rows = bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]);
    let gated = accept_if(
        |g: &GroupView<i64>| g.members().all(|m| m.amount.abs() == 5),
        exact_1to1(|_| Some(0)),
    );

    let out = gated.run(rows);

    conserves(4, &out);
    assert_eq!(out.groups.len(), 1);
    assert_eq!(ids(&out.groups[0]), vec![1, 2]);

    let mut returned: Vec<ExtId> = out.residual.iter().map(|item| item.id).collect();
    returned.sort_unstable();
    assert_eq!(returned, vec![3, 4]);

    // Every returned row still carries its 7-unit amount.
    out.residual
        .iter()
        .for_each(|item| assert_eq!(item.amount.abs(), 7));
}
/// Test strategy that partially matches id 1 against id 2.
///
/// It emits one `"flow"` group moving `moved` units between the two ids and
/// returns whatever is left of each row (relative to its `original`) as
/// residual.
struct Slice {
    /// Amount allocated from id 1 (and, negated, from id 2) into the group.
    moved: i64,
}

impl Strategy<i64> for Slice {
    /// Panics if the bag does not contain both id 1 and id 2.
    fn run(&self, bag: Vec<Item<i64>>) -> Resolution<i64> {
        let by_id: HashMap<ExtId, Item<i64>> = bag.into_iter().map(|i| (i.id, i)).collect();
        let moved = self.moved;

        let matched = Group {
            members: vec![
                Allocation {
                    id: 1,
                    amount: moved,
                },
                Allocation {
                    id: 2,
                    amount: -moved,
                },
            ],
            origin: "flow".into(),
            net: 0,
            reason: None,
        };

        // Copy of a row with its amount replaced by the unmatched remainder.
        let leftover = |id: ExtId, amount: i64| {
            let mut tail = by_id[&id].clone();
            tail.amount = amount;
            tail
        };
        let first_tail = leftover(1, by_id[&1].original - moved);
        let second_tail = leftover(2, by_id[&2].original + moved);

        // Fully consumed rows do not appear in the residual.
        let residual = [first_tail, second_tail]
            .into_iter()
            .filter(|i| i.amount != 0)
            .collect();

        Resolution {
            groups: vec![matched],
            residual,
        }
    }
}
/// A 30-unit slice of two 1000-unit rows has gross 60, which does not exceed
/// 5% (500 bps) of the 2000 original total.
///
/// Expects the group to be rejected and both rows to come back at their full
/// amounts.
#[test]
fn material_rel_prunes_a_sliver_of_a_big_row() {
    let gate = accept_if(
        |g: &GroupView<i64>| g.gross() > 500 * g.original_total() / 10_000,
        Box::new(Slice { moved: 30 }),
    );

    let out = gate.run(bag(&[(1, 1000), (2, -1000)]));

    assert!(out.groups.is_empty(), "60 <= 5% of 2000");

    let mut restored: Vec<(ExtId, i64)> = out
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    restored.sort();
    assert_eq!(restored, vec![(1, 1000), (2, -1000)], "rows return whole");
}
/// A 300-unit slice of two 1000-unit rows has gross 600, which exceeds 5%
/// (500 bps) of the 2000 original total.
///
/// Expects the group to be kept and 700 / -700 to remain on ids 1 and 2.
#[test]
fn material_rel_keeps_a_substantial_match() {
    let gate = accept_if(
        |g: &GroupView<i64>| g.gross() > 500 * g.original_total() / 10_000,
        Box::new(Slice { moved: 300 }),
    );

    let out = gate.run(bag(&[(1, 1000), (2, -1000)]));

    assert_eq!(out.groups.len(), 1, "600 > 5% of 2000");
    assert_eq!(ids(&out.groups[0]), vec![1, 2]);

    // Sum of the residual amounts still attributed to one id.
    let remaining = |id: ExtId| -> i64 {
        out.residual
            .iter()
            .filter(|item| item.id == id)
            .map(|item| item.amount)
            .sum()
    };
    assert_eq!(remaining(1), 700);
    assert_eq!(remaining(2), -700);
}
/// Applies an absolute gross floor to the same 30-unit slice (gross 60):
/// a floor of 100 rejects the group, a floor of 50 keeps it.
#[test]
fn material_abs_prunes_below_fixed_floor_ignoring_original() {
    let high_floor = accept_if(|g: &GroupView<i64>| g.gross() > 100, Box::new(Slice { moved: 30 }));
    let pruned = high_floor.run(bag(&[(1, 1000), (2, -1000)]));
    assert!(pruned.groups.is_empty(), "60 <= 100");

    let low_floor = accept_if(|g: &GroupView<i64>| g.gross() > 50, Box::new(Slice { moved: 30 }));
    let kept = low_floor.run(bag(&[(1, 1000), (2, -1000)]));
    assert_eq!(kept.groups.len(), 1, "60 > 50");
}
/// `Slice` returns both a group and partial residual rows for ids 1 and 2.
///
/// When `accept_if` rejects the group, the residual is expected to hold
/// exactly one row per id, each back at its full amount.
#[test]
fn material_dissolve_merges_onto_existing_residual_and_conserves() {
    let gate = accept_if(|g: &GroupView<i64>| g.gross() > 1000, Box::new(Slice { moved: 30 }));

    let out = gate.run(bag(&[(1, 1000), (2, -1000)]));

    assert!(out.groups.is_empty());
    assert_eq!(out.residual.len(), 2, "no duplicate lots");

    let mut restored: Vec<(ExtId, i64)> = out
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    restored.sort();
    assert_eq!(restored, vec![(1, 1000), (2, -1000)]);
}
/// `subset_sum(0, 8, 0)` over 100 / -60 / -40 / -25.
///
/// Expects one zero-net `"subset-sum"` group of ids 1, 2 and 3, with id 4
/// left in the residual at its full -25.
#[test]
fn subset_sum_clears_one_against_many_whole_lots() {
    let solver = subset_sum(0, 8, 0);
    let out = solver.run(bag(&[(1, 100), (2, -60), (3, -40), (4, -25)]));

    conserves(4, &out);

    // Exactly one group, built from whole lots.
    assert_eq!(out.groups.len(), 1);
    let cleared = &out.groups[0];
    assert_eq!(cleared.origin, "subset-sum");
    assert_eq!(cleared.net, 0);
    assert_eq!(ids(cleared), vec![1, 2, 3]);

    // The lot that does not fit is returned untouched.
    assert_eq!(out.residual.len(), 1);
    let spare = &out.residual[0];
    assert_eq!(spare.id, 4);
    assert_eq!(spare.amount, -25, "unmatched lot stays whole");
}
/// The rows 100 / -60 / -38 net to 2.
///
/// With a first argument of 2 the solver is expected to keep them as one
/// group carrying the break (`net == 2`); with a first argument of 1 it is
/// expected to leave all three rows unmatched.
#[test]
fn subset_sum_keeps_break_inside_within_tol() {
    let rows = || bag(&[(1, 100), (2, -60), (3, -38)]);

    let within = subset_sum(2, 8, 0);
    let out = within.run(rows());
    conserves(3, &out);
    assert_eq!(out.groups.len(), 1);
    assert_eq!(out.groups[0].net, 2);
    assert_eq!(ids(&out.groups[0]), vec![1, 2, 3]);
    assert!(out.residual.is_empty());

    let outside = subset_sum(1, 8, 0);
    let out = outside.run(rows());
    conserves(3, &out);
    assert!(out.groups.is_empty());
    assert_eq!(out.residual.len(), 3);
}
/// The only zero-net combination of 100 / -60 / -40 needs all three rows.
///
/// With a second argument of 2 no group is expected; with 3 the
/// three-member group is expected.
#[test]
fn subset_sum_respects_the_group_size_cap() {
    let rows = || bag(&[(1, 100), (2, -60), (3, -40)]);

    let too_small = subset_sum(0, 2, 0);
    let out = too_small.run(rows());
    conserves(3, &out);
    assert!(out.groups.is_empty());
    assert_eq!(out.residual.len(), 3);

    let big_enough = subset_sum(0, 3, 0);
    let out = big_enough.run(rows());
    assert_eq!(out.groups.len(), 1);
    assert_eq!(ids(&out.groups[0]), vec![1, 2, 3]);
}
/// Builds and runs the same `subset_sum(0, 8, 7)` twice on identical input
/// and expects identical `(member ids, net)` lists.
#[test]
fn subset_sum_is_reproducible_across_runs() {
    let attempt = || {
        let solver = subset_sum(0, 8, 7);
        let out = solver.run(bag(&[(1, 100), (2, -50), (3, -50), (4, -100), (5, 100)]));
        out.groups
            .iter()
            .map(|g| (ids(g), g.net))
            .collect::<Vec<_>>()
    };

    assert_eq!(attempt(), attempt());
}
/// Drives `restart` with a toy strategy that groups the whole bag when its
/// seed is odd and matches nothing when its seed is even.
///
/// Expects the final result to contain the single all-in group.
#[test]
fn restart_keeps_the_attempt_that_matches_most() {
    struct Toy {
        seed: u64,
    }

    impl Strategy<i64> for Toy {
        fn run(&self, bag: Vec<Item<i64>>) -> Resolution<i64> {
            // Even seeds match nothing.
            if self.seed % 2 != 1 {
                return Resolution {
                    groups: vec![],
                    residual: bag,
                };
            }

            // Odd seeds put every row into one group.
            let net: i64 = bag.iter().map(|i| i.amount).sum();
            let members = bag
                .iter()
                .map(|i| Allocation {
                    id: i.id,
                    amount: i.amount,
                })
                .collect();
            let everything = Group {
                members,
                origin: "toy".into(),
                net,
                reason: None,
            };
            Resolution {
                groups: vec![everything],
                residual: vec![],
            }
        }
    }

    let strategy = restart(6, 0, |seed| Box::new(Toy { seed }));
    let out = strategy.run(bag(&[(1, 5), (2, -5)]));

    conserves(2, &out);
    assert_eq!(out.groups.len(), 1, "the matching seed wins");
    assert!(out.residual.is_empty());
}
/// Runs `restart(8, 42, ..)` over seeded `subset_sum` solvers gated by a
/// `net == 0` acceptance predicate.
///
/// Asserts that every resulting group nets to zero and that the absolute
/// amounts in groups plus residual add up to the 300 units of input.
#[test]
fn restart_drives_subset_sum_to_a_full_clear() {
    let factory = |seed: u64| accept_if(|g: &GroupView<i64>| g.net() == 0, subset_sum(0, 8, seed));
    let strategy = restart(8, 42, factory);

    let out = strategy.run(bag(&[(1, 100), (2, -50), (3, -50), (4, -50), (5, 50)]));

    conserves(5, &out);
    assert!(out.groups.iter().all(|g| g.net == 0));

    let in_groups: i64 = out
        .groups
        .iter()
        .flat_map(|g| g.members.iter())
        .map(|a| a.amount.abs())
        .sum();
    let left_over: i64 = out.residual.iter().map(|i| i.amount.abs()).sum();
    assert_eq!(in_groups + left_over, 300, "conservation in amount");
}
/// Combines `reclaim` with `accept_if` on a partial match that leaves a
/// 3-unit tail on id 1.
///
/// With a tolerance of 5 the result is expected to be one group over ids 1
/// and 2 with `net == 3` and no residual; with a tolerance of 2 no group is
/// expected and both rows come back at their full amounts.
#[test]
fn reclaim_keeps_break_within_tol_via_accept_if() {
    /// Matches 97 units between ids 1 and 2 and leaves 3 units of id 1 over.
    struct Partial;

    impl Strategy<i64> for Partial {
        fn run(&self, bag: Vec<Item<i64>>) -> Resolution<i64> {
            let by_id: HashMap<ExtId, Item<i64>> = bag.into_iter().map(|i| (i.id, i)).collect();

            let matched = Group {
                members: vec![
                    Allocation { id: 1, amount: 97 },
                    Allocation { id: 2, amount: -97 },
                ],
                origin: "flow".into(),
                net: 0,
                reason: None,
            };

            let mut tail = by_id[&1].clone();
            tail.amount = 3;

            Resolution {
                groups: vec![matched],
                residual: vec![tail],
            }
        }
    }

    let rows = || bag(&[(1, 100), (2, -97)]);

    // Break of 3 is within a tolerance of 5.
    let lenient = accept_if(
        |g: &GroupView<i64>| g.net().abs() <= 5,
        reclaim("settlement", Box::new(Partial)),
    );
    let out = lenient.run(rows());
    conserves(2, &out);
    assert_eq!(out.groups.len(), 1);
    assert_eq!(out.groups[0].net, 3);
    assert_eq!(ids(&out.groups[0]), vec![1, 2]);
    assert!(out.residual.is_empty());

    // Break of 3 exceeds a tolerance of 2.
    let strict = accept_if(
        |g: &GroupView<i64>| g.net().abs() <= 2,
        reclaim("settlement", Box::new(Partial)),
    );
    let out = strict.run(rows());
    conserves(2, &out);
    assert!(out.groups.is_empty());

    let mut restored: Vec<(ExtId, i64)> = out
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    restored.sort();
    assert_eq!(restored, vec![(1, 100), (2, -97)]);
}
/// The inner strategy splits id 1 across two groups (60 against id 2, 40
/// against id 3).
///
/// After `reclaim("settlement", ..)` one zero-net `"settlement"` group with
/// members `(1, 100)`, `(2, -60)` and `(3, -40)` is expected.
#[test]
fn reclaim_collapses_groups_sharing_a_line() {
    /// Ignores its input and always emits the same two groups.
    struct Split;

    impl Strategy<i64> for Split {
        fn run(&self, _bag: Vec<Item<i64>>) -> Resolution<i64> {
            // Two-leg, zero-net group pairing id 1 with a counterparty.
            let pair = |origin: &str, amount: i64, counterparty: ExtId| Group {
                members: vec![
                    Allocation { id: 1, amount },
                    Allocation {
                        id: counterparty,
                        amount: -amount,
                    },
                ],
                origin: origin.into(),
                net: 0,
                reason: None,
            };

            Resolution {
                groups: vec![pair("a", 60, 2), pair("b", 40, 3)],
                residual: vec![],
            }
        }
    }

    let strategy = accept_if(
        |g: &GroupView<i64>| g.net() == 0,
        reclaim("settlement", Box::new(Split)),
    );
    let out = strategy.run(bag(&[(1, 100), (2, -60), (3, -40)]));

    conserves(3, &out);
    assert_eq!(out.groups.len(), 1);

    let merged = &out.groups[0];
    assert_eq!(merged.origin, "settlement");
    assert_eq!(merged.net, 0);

    let mut legs: Vec<(ExtId, i64)> = merged.members.iter().map(|a| (a.id, a.amount)).collect();
    legs.sort();
    assert_eq!(legs, vec![(1, 100), (2, -60), (3, -40)]);
    assert!(out.residual.is_empty());
}
/// `agg_net` buckets the +/-8 rows apart from the rest; `accept_if` then
/// keeps only groups with at most three members.
///
/// Expects the two-member group (ids 6 and 7) to be kept and the five other
/// rows to end up in the residual. Only the size predicate is exercised here.
#[test]
fn accept_if_size_cap_and_minority_side() {
    let rows = bag(&[
        (1, 40),
        (2, -10),
        (3, -10),
        (4, -10),
        (5, -10),
        (6, 8),
        (7, -8),
    ]);

    let capped = accept_if(
        |g: &GroupView<i64>| g.size() <= 3,
        agg_net(
            |a: &Item<i64>| if a.amount.unsigned_abs() == 8 { 1u64 } else { 0u64 },
            |g| g.net() == 0,
        ),
    );

    let out = capped.run(rows);

    conserves(7, &out);
    assert_eq!(out.groups.len(), 1, "only the small pair is accepted");
    assert_eq!(ids(&out.groups[0]), vec![6, 7]);
    assert_eq!(out.residual.len(), 5, "the over-large group is dissolved");
}
/// Test strategy that emits a fixed list of groups, each given as a list of
/// `(id, amount)` legs, regardless of the amounts in the bag.
struct EmitGroups(Vec<Vec<(ExtId, i64)>>);

impl Strategy<i64> for EmitGroups {
    /// Emits one `"emit"` group per configured leg list, with `net` equal to
    /// the sum of its leg amounts. Bag rows whose id appears in any leg list
    /// are dropped from the residual; all other rows pass through unchanged.
    fn run(&self, bag: Vec<Item<i64>>) -> Resolution<i64> {
        let mut claimed: BTreeSet<ExtId> = BTreeSet::new();
        let mut groups = Vec::with_capacity(self.0.len());

        for legs in &self.0 {
            claimed.extend(legs.iter().map(|&(id, _)| id));
            groups.push(Group {
                members: legs
                    .iter()
                    .map(|&(id, amount)| Allocation { id, amount })
                    .collect(),
                origin: "emit".into(),
                net: legs.iter().map(|&(_, amount)| amount).sum(),
                reason: None,
            });
        }

        let mut residual = bag;
        residual.retain(|item| !claimed.contains(&item.id));

        Resolution { groups, residual }
    }
}
/// The inner strategy emits two groups that both touch id 2 (-60 and -40).
///
/// `coalesce("settlement", ..)` is expected to merge them into one zero-net
/// group over ids 1-4 in which id 2 carries -100, leaving the untouched id 9
/// in the residual.
#[test]
fn coalesce_merges_groups_that_share_a_row() {
    let interlocked = EmitGroups(vec![
        vec![(1, 100), (2, -60)],
        vec![(2, -40), (3, 100), (4, -100)],
    ]);
    let rows = bag(&[(1, 100), (2, -100), (3, 100), (4, -100), (9, 7)]);

    let strategy = coalesce("settlement", Box::new(interlocked));
    let out = strategy.run(rows);

    conserves(5, &out);
    assert_eq!(out.groups.len(), 1, "the two interlocking groups merge");

    let merged = &out.groups[0];
    assert_eq!(merged.origin, "settlement");
    assert_eq!(ids(merged), vec![1, 2, 3, 4]);

    // The two partial legs of id 2 are summed into a single leg.
    let shared = merged.members.iter().find(|a| a.id == 2).unwrap();
    assert_eq!(shared.amount, -100);
    assert_eq!(merged.net, 0);

    assert_eq!(out.residual.len(), 1);
    assert_eq!(out.residual[0].id, 9);
}
/// Two inner groups with no id in common are expected to stay as two groups
/// after `coalesce`, both relabelled with the `"settlement"` origin.
#[test]
fn coalesce_keeps_disjoint_groups_separate() {
    let disjoint = EmitGroups(vec![vec![(1, 5), (2, -5)], vec![(3, 7), (4, -7)]]);
    let rows = bag(&[(1, 5), (2, -5), (3, 7), (4, -7)]);

    let strategy = coalesce("settlement", Box::new(disjoint));
    let out = strategy.run(rows);

    conserves(4, &out);
    assert_eq!(out.groups.len(), 2, "disjoint components stay separate");
    assert!(out.groups.iter().all(|g| g.origin == "settlement"));
}
// Runs a `flow` strategy (configured through `FlowSpec`) under
// `coalesce("flow", ..)` on two positive rows (100, 200) and one negative
// row (-250). Asserts that the result is a single zero-net "flow" group over
// ids 1, 2 and 3 in which id 3 appears once with its full -250, and that the
// residual amounts sum to the 50 units that cannot be offset.
#[test]
fn coalesce_folds_flow_arcs_into_one_settlement() {
// Minimal payload: only a date, used for the block key and the cost.
#[derive(Clone)]
struct Tx {
date: i64,
}
// NOTE(review): the exact meaning of `penalty`, `window` and `block_key` is
// defined by `FlowSpec` in the `flow` module, which is not visible here.
let spec = FlowSpec::<Tx>::new()
.penalty(1_000_000.0)
.window(3)
.block_key(|t: &Tx| t.date)
// Cost is 1.0 plus 0.1 per unit of net amount between the two lots plus
// the absolute date gap.
.cost_lot(|a: &Tx, a_amt, b: &Tx, b_amt| {
Some(1.0 + (a_amt + b_amt).abs() as f64 * 0.1 + (a.date - b.date).abs() as f64)
});
let s = coalesce("flow", flow(spec));
let r = s.run(vec![
Item::new(1, 100, Tx { date: 0 }),
Item::new(2, 200, Tx { date: 1 }),
Item::new(3, -250, Tx { date: 0 }),
]);
assert_eq!(r.groups.len(), 1, "arcs fold into one settlement");
let g = &r.groups[0];
assert_eq!(g.origin, "flow");
assert_eq!(g.net, 0);
assert_eq!(ids(g), vec![1, 2, 3]);
let three = g.members.iter().find(|a| a.id == 3).unwrap();
assert_eq!(three.amount, -250, "id 3's two arcs sum to one clean edge");
// 100 + 200 - 250 = 50 is left unmatched.
assert_eq!(r.residual.iter().map(|i| i.amount).sum::<i64>(), 50);
}
/// Rows carry outer amounts of +/-110 and pivot amounts (`data.0`) of +/-100.
///
/// The inner exact matcher runs under `pivot`; the resulting group is
/// expected to report its members in the outer amounts (110 / -110) with a
/// zero net.
#[test]
fn pivot_converts_back_to_outer_amount() {
    let rows = vec![Item::new(1, 110, (100i64,)), Item::new(2, -110, (-100i64,))];
    let strategy = pivot(|d: &Item<(i64,)>| d.data.0, exact_1to1(|_| Some(0)));

    let out = strategy.run(rows);

    assert_eq!(out.groups.len(), 1);
    assert_eq!(out.groups[0].net, 0);

    let expected = vec![
        Allocation { id: 1, amount: 110 },
        Allocation {
            id: 2,
            amount: -110,
        },
    ];
    assert_eq!(out.groups[0].members, expected);
}
/// Test strategy that allocates half of every row (integer division, so it
/// truncates toward zero) into a single `"half"` group and returns the
/// remainder of each row as residual.
struct HalfMatch;

impl Strategy<(i64,)> for HalfMatch {
    /// Always returns exactly one group, even for an empty bag. The group's
    /// `net` is the sum of the allocated halves.
    fn run(&self, bag: Vec<Item<(i64,)>>) -> Resolution<(i64,)> {
        let mut members = Vec::with_capacity(bag.len());
        let mut residual = Vec::with_capacity(bag.len());

        for mut it in bag {
            let half = it.amount / 2;
            members.push(Allocation {
                id: it.id,
                amount: half,
            });
            // The row is owned here, so shrink it in place and move it into
            // the residual instead of cloning it first.
            it.amount -= half;
            residual.push(it);
        }

        let net = members.iter().map(|a| a.amount).sum();
        let groups = vec![Group {
            members,
            origin: "half".into(),
            net,
            reason: None,
        }];
        Resolution { groups, residual }
    }
}
/// Id 1 has outer amount 1 and pivot amount 4; id 2 has outer amount 100 and
/// pivot amount 4. `HalfMatch` allocates half of each row in pivot space.
///
/// Expects only id 2 to remain in the group (at outer amount 50), with id 1
/// returned whole and id 2's other half in the residual.
#[test]
fn pivot_dissolves_rows_that_round_to_zero_parent() {
    let rows = vec![Item::new(1, 1, (4i64,)), Item::new(2, 100, (4i64,))];
    let strategy = pivot(|d: &Item<(i64,)>| d.data.0, Box::new(HalfMatch));

    let out = strategy.run(rows);

    assert_eq!(out.groups.len(), 1);
    assert_eq!(out.groups[0].members, vec![Allocation { id: 2, amount: 50 }]);

    let mut leftovers: Vec<(ExtId, i64)> = out
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    leftovers.sort();
    assert_eq!(leftovers, vec![(1, 1), (2, 50)]);
}
/// Id 1 has a pivot amount of 0 while its outer amount is 5.
///
/// Expects the run to complete with id 1 returned whole in the residual and
/// id 2 split 50 / 50 between the group and the residual.
#[test]
fn pivot_zero_target_is_safe() {
    let rows = vec![Item::new(1, 5, (0i64,)), Item::new(2, 100, (4i64,))];
    let strategy = pivot(|d: &Item<(i64,)>| d.data.0, Box::new(HalfMatch));

    let out = strategy.run(rows);

    assert_eq!(out.groups.len(), 1);
    assert_eq!(out.groups[0].members, vec![Allocation { id: 2, amount: 50 }]);

    let mut leftovers: Vec<(ExtId, i64)> = out
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    leftovers.sort();
    assert_eq!(leftovers, vec![(1, 5), (2, 50)]);
}
/// Test strategy that forms no groups and silently discards every row whose
/// amount is zero.
struct DropZeros;

impl Strategy<(i64,)> for DropZeros {
    /// Returns the non-zero rows, in their original order, as the residual.
    fn run(&self, bag: Vec<Item<(i64,)>>) -> Resolution<(i64,)> {
        let mut residual = bag;
        residual.retain(|item| item.amount != 0);
        Resolution {
            groups: Vec::new(),
            residual,
        }
    }
}
/// Id 1 has original 100, current amount 1 and pivot amount 3. The inner
/// `DropZeros` strategy discards any zero-amount row it receives.
///
/// Expects both rows to come back from `pivot` with their incoming amounts
/// (1 and 50) and no groups.
#[test]
fn pivot_reemits_forward_floored_rows() {
    // `Item::new` sets `original == amount`, so lower the amount afterwards.
    let mut nearly_spent = Item::new(1, 100, (3i64,));
    nearly_spent.amount = 1;
    let untouched = Item::new(2, 50, (50i64,));

    let strategy = pivot(|d: &Item<(i64,)>| d.data.0, Box::new(DropZeros));
    let out = strategy.run(vec![nearly_spent, untouched]);

    assert!(out.groups.is_empty());

    let mut leftovers: Vec<(ExtId, i64)> = out
        .residual
        .iter()
        .map(|item| (item.id, item.amount))
        .collect();
    leftovers.sort();
    assert_eq!(leftovers, vec![(1, 1), (2, 50)]);
}
/// Selects non-zero rows with absolute amount at most 5, partitions them by
/// id and soaks each partition as `"rounding"`.
///
/// Expects two single-member groups (ids 1 and 2) and id 3 in the residual.
#[test]
fn soak_immaterial_singletons_via_when_and_partition() {
    let strategy = when(
        |i: &Item<i64>| i.amount != 0 && i.amount.abs() <= 5,
        partition_by(|i: &Item<i64>| i.id, |_| soak("rounding")),
    );

    let out = strategy.run(bag(&[(1, 3), (2, -2), (3, 100)]));

    conserves(3, &out);
    assert_eq!(out.groups.len(), 2);

    let all_single_rounding = out
        .groups
        .iter()
        .all(|g| g.members.len() == 1 && g.origin == "rounding");
    assert!(all_single_rounding);

    assert_eq!(out.residual.len(), 1);
    assert_eq!(out.residual[0].id, 3);
}
/// Soaks rows whose remaining amount is at most 200 bps of their original.
///
/// Both rows have an original of 1000, so the threshold is 20: id 1
/// (amount 10) is expected to be soaked into its own `"var"` group and id 2
/// (amount 50) is expected to stay in the residual.
#[test]
fn soak_immaterial_bps_against_original() {
    let items = vec![
        Item {
            id: 1,
            original: 1000,
            amount: 10,
            data: 0,
        },
        Item {
            id: 2,
            original: 1000,
            amount: 50,
            data: 0,
        },
    ];
    let s = when(
        |i: &Item<i64>| i.amount != 0 && i.amount.abs() <= 200 * i.original / 10_000,
        partition_by(|i: &Item<i64>| i.id, |_| soak("var")),
    );
    let r = s.run(items);
    conserves(2, &r);
    // Pin the group count before indexing: a regression then reports a count
    // mismatch instead of an out-of-bounds panic, and extra groups are caught.
    assert_eq!(r.groups.len(), 1);
    assert_eq!(ids(&r.groups[0]), vec![1]);
    assert_eq!(r.residual.iter().map(|i| i.id).collect::<Vec<_>>(), vec![2]);
}
/// Selects non-zero rows with absolute amount at most 5, partitions them by
/// sign (`"pos"` / `"neg"`) and soaks each partition under `tail:<key>`.
///
/// Expects two groups with a `tail:` origin and id 4 in the residual.
#[test]
fn soak_buckets_by_key_via_partition() {
    let strategy = when(
        |i: &Item<i64>| i.amount != 0 && i.amount.abs() <= 5,
        partition_by(
            |i: &Item<i64>| String::from(if i.amount > 0 { "pos" } else { "neg" }),
            |k: &String| soak(format!("tail:{k}")),
        ),
    );

    let out = strategy.run(bag(&[(1, 3), (2, 4), (3, -2), (4, 100)]));

    conserves(4, &out);
    assert_eq!(out.groups.len(), 2);
    assert!(out.groups.iter().all(|g| g.origin.starts_with("tail:")));

    let kept: Vec<_> = out.residual.iter().map(|item| item.id).collect();
    assert_eq!(kept, vec![4]);
}
/// Soaks only the negative rows (per-id partitions, origin `"shorts"`).
///
/// Expects ids 2 and 3 to be soaked and id 1 to remain in the residual.
#[test]
fn soak_predicate_selects_via_when() {
    let strategy = when(
        |i: &Item<i64>| i.amount < 0,
        partition_by(|i: &Item<i64>| i.id, |_| soak("shorts")),
    );

    let out = strategy.run(bag(&[(1, 50), (2, -30), (3, -10)]));

    conserves(3, &out);

    let mut soaked_ids: Vec<ExtId> = out.groups.iter().flat_map(ids).collect();
    soaked_ids.sort_unstable();
    assert_eq!(soaked_ids, vec![2, 3]);

    let kept: Vec<_> = out.residual.iter().map(|item| item.id).collect();
    assert_eq!(kept, vec![1]);
}
/// Partitions every row by id and soaks each partition as `"unmatched"`.
///
/// Expects an empty residual, groups covering ids 1 and 2 only (the
/// zero-amount id 3 appears in no group), and every group to have a non-zero
/// net.
#[test]
fn soak_terminates_residual_one_group_per_lot() {
    let strategy = partition_by(|i: &Item<i64>| i.id, |_| soak("unmatched"));

    let out = strategy.run(bag(&[(1, 50), (2, -30), (3, 0)]));

    assert!(out.residual.is_empty());

    let mut soaked_ids: Vec<ExtId> = out.groups.iter().flat_map(ids).collect();
    soaked_ids.sort_unstable();
    assert_eq!(soaked_ids, vec![1, 2]);

    assert!(out.groups.iter().all(|g| g.net != 0));
}
/// Partitions rows by sign (key 1 for positive, key 2 otherwise) and soaks
/// each partition as `"class"`.
///
/// Expects an empty residual and group nets of 80 (50 + 30) and -20.
#[test]
fn soak_bucket_nets_per_key() {
    let strategy = partition_by(
        |i: &Item<i64>| if i.amount > 0 { 1u64 } else { 2u64 },
        |_| soak("class"),
    );

    let out = strategy.run(bag(&[(1, 50), (2, 30), (3, -20)]));

    assert!(out.residual.is_empty());

    // Group order is not asserted, so compare the sorted nets.
    let mut bucket_nets: Vec<i64> = out.groups.iter().map(|g| g.net).collect();
    bucket_nets.sort_unstable();
    assert_eq!(bucket_nets, vec![-20, 80]);
}
/// A bare `soak("variance")` over four rows, one of which has amount 0.
///
/// Expects an empty residual and a single `"variance"` group over ids 1, 2
/// and 3 with a net of 60 (50 + 30 - 20); the zero-amount id 4 is not listed
/// as a member.
#[test]
fn soak_collapses_all_into_one_group() {
    let strategy = soak("variance");

    let out = strategy.run(bag(&[(1, 50), (2, 30), (3, -20), (4, 0)]));

    assert!(out.residual.is_empty());
    assert_eq!(out.groups.len(), 1);

    let only = &out.groups[0];
    assert_eq!(only.origin, "variance");
    assert_eq!(ids(only), vec![1, 2, 3]);
    assert_eq!(only.net, 60);
}
/// Builds a `Group` with origin `"test"` from `(id, amount)` legs.
///
/// `net` is the sum of the leg amounts and `reason` is left as `None`.
fn group(members: &[(ExtId, i64)]) -> Group {
    let net: i64 = members.iter().map(|&(_, amount)| amount).sum();
    let legs: Vec<Allocation> = members
        .iter()
        .map(|&(id, amount)| Allocation { id, amount })
        .collect();

    Group {
        members: legs,
        origin: "test".into(),
        net,
        reason: None,
    }
}
/// Checks the `Group` summary accessors on one positive and two negative
/// legs (1_000_000, -999_000, -1_200; net -200).
#[test]
fn group_metrics() {
    let sample = group(&[(1, 1_000_000), (2, -999_000), (3, -1_200)]);

    assert_eq!(sample.size(), 3);
    assert_eq!(sample.abs_net(), 200);
    assert_eq!(sample.max_abs(), 1_000_000);
    assert_eq!(sample.min_abs(), 1_200);
    // One positive leg versus two negative legs: the smaller side has one.
    assert_eq!(sample.min_side(), 1);
}
/// Compares a 5 bps tolerance (floored at 100) measured against the smallest
/// leg with the same tolerance measured against the largest leg.
///
/// The rows net to -200. Presumably `min_leg()` is 1_200 and `max_leg()` is
/// 1_000_000, which gives bounds of 100 and 500: the first run is expected
/// to reject the bucket and the second to accept it.
#[test]
fn agg_net_largest_leg_accepts_what_smallest_leg_rejects() {
    let rows = bag(&[(1, 1_000_000), (2, -999_000), (3, -1_200)]);

    let by_smallest = agg_net(
        |_: &Item<i64>| 0u64,
        |g| g.net().abs() <= (5 * g.min_leg() / 10_000).max(100),
    );
    assert_eq!(by_smallest.run(rows.clone()).groups.len(), 0);

    let by_largest = agg_net(
        |_: &Item<i64>| 0u64,
        |g| g.net().abs() <= (5 * g.max_leg() / 10_000).max(100),
    );
    let out = by_largest.run(rows);
    assert_eq!(out.groups.len(), 1);
    assert_eq!(out.groups[0].net, -200);
}
}