use super::{Group, Item, Resolution, Strategy};
use crate::engine::{ArcId, Network, NodeId, SolveStatus};
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::sync::Arc;
/// External identifier of an item, as carried in `Item::id` and echoed back in
/// every `Allocation` this strategy emits.
pub type ExtId = u64;
/// A signed amount attributed to one external item.
///
/// In groups emitted by the flow strategy the source side carries the matched
/// flow as a positive amount and the sink side the same flow negated; for
/// unmatched leftovers `amount` is the item's amount minus whatever was matched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Allocation {
    /// External id of the item this allocation refers to.
    pub id: ExtId,
    /// Signed amount allocated to that item.
    pub amount: i64,
}
/// Match-key callback: `(item data, signed item amount) -> keys`. Items that
/// share a key become candidate partners.
type MatchKeysFn<E> = dyn Fn(&E, i64) -> Vec<u64>;
/// Arc-cost callback: `(source data, source amount, sink data, sink amount)`,
/// where the source is the positive-amount side. Returning `None` forbids the
/// pair (no arc is added).
type CostFn<E> = dyn Fn(&E, i64, &E, i64) -> Option<f64>;
/// Configuration for the min-cost-flow matching strategy, assembled with the
/// chained setters on `FlowSpec` starting from `FlowSpec::new()`.
///
/// Callbacks are reference-counted, so cloning a spec shares them.
pub struct FlowSpec<E> {
    /// Per-item penalty handed to the item's network node
    /// (`add_node` / `set_penalty`). NOTE(review): presumably the cost of
    /// leaving amount unmatched — the exact semantics live in `Network`.
    penalty: Arc<dyn Fn(&E) -> f64>,
    /// Ordering key used for window-based partner search.
    block_key: Arc<dyn Fn(&E) -> i64>,
    /// Items whose block keys lie within `±window` of each other are candidate
    /// partners; a negative window disables window-based search.
    window: i64,
    /// Produces the match-key buckets an item belongs to.
    match_keys: Arc<MatchKeysFn<E>>,
    /// Cost of a source → sink arc, or `None` to forbid the pair.
    cost: Arc<CostFn<E>>,
}
impl<E> Clone for FlowSpec<E> {
fn clone(&self) -> Self {
FlowSpec {
penalty: self.penalty.clone(),
block_key: self.block_key.clone(),
window: self.window,
match_keys: self.match_keys.clone(),
cost: self.cost.clone(),
}
}
}
impl<E> Default for FlowSpec<E> {
fn default() -> Self {
FlowSpec {
penalty: Arc::new(|_| 0.0),
block_key: Arc::new(|_| 0),
window: -1,
match_keys: Arc::new(|_, _| Vec::new()),
cost: Arc::new(|_, _, _, _| None),
}
}
}
impl<E> FlowSpec<E> {
    /// Starts from the default spec: zero penalty, block key `0`, window
    /// search disabled, no match keys, and every pair forbidden.
    pub fn new() -> Self {
        Self::default()
    }

    /// Uses the same penalty `p` for every item.
    pub fn penalty(self, p: f64) -> Self {
        Self {
            penalty: Arc::new(move |_| p),
            ..self
        }
    }

    /// Computes each item's penalty with `f`.
    pub fn penalty_fn(self, f: impl Fn(&E) -> f64 + 'static) -> Self {
        Self {
            penalty: Arc::new(f),
            ..self
        }
    }

    /// Sets the block-key window: items whose block keys differ by at most `w`
    /// are considered as partners. A negative `w` disables window search.
    pub fn window(self, w: i64) -> Self {
        Self { window: w, ..self }
    }

    /// Sets the block key used for window-based partner search.
    pub fn block_key(self, f: impl Fn(&E) -> i64 + 'static) -> Self {
        Self {
            block_key: Arc::new(f),
            ..self
        }
    }

    /// Sets amount-independent match keys; items that share a key are
    /// considered as partners.
    pub fn match_keys(self, f: impl Fn(&E) -> Vec<u64> + 'static) -> Self {
        Self {
            match_keys: Arc::new(move |e, _amount| f(e)),
            ..self
        }
    }

    /// Like `match_keys`, but the callback also receives the item's signed
    /// amount.
    pub fn match_keys_lot(self, f: impl Fn(&E, i64) -> Vec<u64> + 'static) -> Self {
        Self {
            match_keys: Arc::new(f),
            ..self
        }
    }

    /// Sets an amount-independent arc cost `f(source, sink)`; returning `None`
    /// forbids the pair.
    pub fn cost(self, f: impl Fn(&E, &E) -> Option<f64> + 'static) -> Self {
        Self {
            cost: Arc::new(move |src, _src_amt, snk, _snk_amt| f(src, snk)),
            ..self
        }
    }

    /// Like `cost`, but the callback also receives both signed amounts:
    /// `f(source, source_amount, sink, sink_amount)`.
    pub fn cost_lot(self, f: impl Fn(&E, i64, &E, i64) -> Option<f64> + 'static) -> Self {
        Self {
            cost: Arc::new(f),
            ..self
        }
    }
}
/// Match-key buckets with more members than this are skipped during partner
/// search, bounding how many candidate arcs a very common key can create.
const MATCH_BUCKET_CAP: usize = 256;
/// Per-item state held by the flow model.
struct Entry<E> {
    /// Network node representing this item.
    node: NodeId,
    /// The item's payload, handed to the spec callbacks.
    tx: E,
    /// Block key under which the item is indexed in `Flow::by_key`.
    key: i64,
    /// Signed amount, used as the node's supply; positive entries act as
    /// sources and negative entries as sinks.
    base: i64,
    /// Match keys as last returned by the spec (not sorted); kept so the entry
    /// can be removed from `Flow::by_match_key`.
    keys: Vec<u64>,
    /// `(partner id, arc id)` for every arc touching this node. Each arc is
    /// recorded on both of its endpoints.
    arcs: Vec<(ExtId, ArcId)>,
}
/// Snapshot of everything derived from an item by the spec callbacks other
/// than `cost`; `run` re-upserts an item only when its signature changed.
///
/// NOTE(review): inputs that only the cost callback reads are not captured, so
/// a payload change that affects cost alone is not picked up by a warm run.
#[derive(Clone, PartialEq, Eq)]
struct FlowSig {
    /// Signed item amount.
    amount: i64,
    /// Penalty stored as raw bits so the struct can derive `Eq`.
    penalty_bits: u64,
    /// Block key.
    key: i64,
    /// Match keys, sorted.
    keys: Vec<u64>,
}
/// Incremental min-cost-flow matcher that keeps its network between runs.
struct Flow<E> {
    /// Callbacks and parameters.
    spec: FlowSpec<E>,
    /// Underlying network: one node per entry, at most one arc per candidate
    /// pair.
    net: Network,
    /// Entries by external id.
    entries: HashMap<ExtId, Entry<E>>,
    /// Block key -> ids; ordered so window searches are a range scan.
    by_key: BTreeMap<i64, Vec<ExtId>>,
    /// Match key -> ids that produced that key.
    by_match_key: HashMap<u64, Vec<ExtId>>,
    /// Signatures of the items loaded by the previous `run`.
    loaded: HashMap<ExtId, FlowSig>,
}
impl<E> Flow<E> {
    /// Creates an empty model driven by `spec`.
    fn new(spec: FlowSpec<E>) -> Self {
        Flow {
            spec,
            net: Network::new(),
            entries: HashMap::new(),
            by_key: BTreeMap::new(),
            by_match_key: HashMap::new(),
            loaded: HashMap::new(),
        }
    }

    /// Computes the change-detection signature of `item` under the current
    /// spec. Match keys are sorted so the same key set returned in a different
    /// order yields an equal signature.
    fn flow_sig(&self, item: &Item<E>) -> FlowSig {
        let amount = item.amount;
        let mut keys = (self.spec.match_keys)(&item.data, amount);
        keys.sort_unstable();
        FlowSig {
            amount,
            penalty_bits: (self.spec.penalty)(&item.data).to_bits(),
            key: (self.spec.block_key)(&item.data),
            keys,
        }
    }

    /// Inserts item `id` (payload `tx`, signed amount `base`), or updates it in
    /// place if it already exists.
    ///
    /// An update detaches all of the entry's arcs, refreshes its indexes,
    /// supply and penalty, then regenerates its arcs. An insert creates a new
    /// network node, indexes it, and generates arcs to the current partners.
    fn upsert(&mut self, id: ExtId, tx: E, base: i64) {
        let key = (self.spec.block_key)(&tx);
        let keys = (self.spec.match_keys)(&tx, base);
        if self.entries.contains_key(&id) {
            self.detach_arcs(id);
            let (old_node, old_key, old_base, old_keys) = {
                let e = &self.entries[&id];
                (e.node, e.key, e.base, e.keys.clone())
            };
            if old_key != key {
                self.unindex_key(old_key, id);
                self.by_key.entry(key).or_default().push(id);
            }
            if old_keys != keys {
                self.unindex_match_keys(id, &old_keys);
                self.index_match_keys(id, &keys);
            }
            if old_base != base {
                self.net.set_supply(old_node, base);
            }
            self.net.set_penalty(old_node, (self.spec.penalty)(&tx));
            {
                let e = self
                    .entries
                    .get_mut(&id)
                    .expect("entry presence checked above");
                e.tx = tx;
                e.key = key;
                e.base = base;
                e.keys = keys;
            }
            self.generate_arcs(id);
        } else {
            let node = self.net.add_node(base, (self.spec.penalty)(&tx));
            self.by_key.entry(key).or_default().push(id);
            self.index_match_keys(id, &keys);
            self.entries.insert(
                id,
                Entry {
                    node,
                    tx,
                    key,
                    base,
                    keys,
                    arcs: Vec::new(),
                },
            );
            self.generate_arcs(id);
        }
    }

    /// Removes item `id` if present. The entry is unindexed, dropped from
    /// every partner's arc list, and its network node is removed.
    fn remove(&mut self, id: ExtId) {
        if let Some(e) = self.entries.remove(&id) {
            self.unindex_key(e.key, id);
            self.unindex_match_keys(id, &e.keys);
            for (other, _) in &e.arcs {
                if let Some(oe) = self.entries.get_mut(other) {
                    oe.arcs.retain(|(x, _)| *x != id);
                }
            }
            // NOTE(review): unlike `detach_arcs`, the arcs are not removed
            // from `net` one by one here; this relies on `remove_node`
            // discarding incident arcs — confirm against `Network`.
            self.net.remove_node(e.node);
        }
    }

    /// Re-solves the network, returning the engine's status.
    fn solve(&mut self) -> SolveStatus {
        self.net.solve()
    }

    /// Total cost of the network's current solution.
    fn objective(&self) -> f64 {
        self.net.total_cost()
    }

    /// Number of arcs currently tracked. Each arc is recorded on both
    /// endpoints, hence the halving.
    fn arc_count(&self) -> usize {
        self.entries.values().map(|e| e.arcs.len()).sum::<usize>() / 2
    }

    /// Returns `(source id, sink id, flow)` for every flow the network reports
    /// between a positive-amount and a negative-amount entry, sorted.
    ///
    /// Flows whose endpoints are not both known entries of opposite sign are
    /// skipped.
    fn matched_arcs(&self) -> Vec<(ExtId, ExtId, i64)> {
        let mut slot_to_ext: HashMap<NodeId, ExtId> = HashMap::new();
        for (id, e) in &self.entries {
            slot_to_ext.insert(e.node, *id);
        }
        let mut arcs: Vec<(ExtId, ExtId, i64)> = Vec::new();
        for (from, to, f) in self.net.matches() {
            if let (Some(&a), Some(&b)) = (slot_to_ext.get(&from), slot_to_ext.get(&to)) {
                let ea = &self.entries[&a];
                let eb = &self.entries[&b];
                let (src, snk) = if ea.base > 0 && eb.base < 0 {
                    (a, b)
                } else if eb.base > 0 && ea.base < 0 {
                    (b, a)
                } else {
                    continue;
                };
                arcs.push((src, snk, f));
            }
        }
        arcs.sort_unstable();
        arcs
    }

    /// Net matched amount per id: sources accumulate `+flow`, sinks `-flow`.
    fn matched_by_id(&self) -> HashMap<ExtId, i64> {
        let mut m: HashMap<ExtId, i64> = HashMap::new();
        for (src, snk, f) in self.matched_arcs() {
            *m.entry(src).or_insert(0) += f;
            *m.entry(snk).or_insert(0) -= f;
        }
        m
    }

    /// Remainder `base - matched` for every entry whose remainder is nonzero,
    /// sorted by id.
    fn unmatched_allocations(&self) -> Vec<Allocation> {
        let matched_by_id = self.matched_by_id();
        let mut out = Vec::new();
        for (&id, e) in &self.entries {
            let matched = *matched_by_id.get(&id).unwrap_or(&0);
            let rem = e.base - matched;
            if rem != 0 {
                out.push(Allocation { id, amount: rem });
            }
        }
        out.sort_by_key(|a| a.id);
        out
    }

    /// Adds arcs between entry `id` and every eligible partner.
    ///
    /// A partner is eligible if either of these holds:
    /// - its block key lies within `±window` of this entry's (only when
    ///   `window >= 0`);
    /// - it shares a match key, counting only buckets no larger than
    ///   `MATCH_BUCKET_CAP`.
    ///
    /// In addition, only entries of strictly opposite sign qualify. Arcs run
    /// from the positive (source) entry to the negative (sink) entry. A pair is
    /// skipped when `spec.cost` returns `None` or `add_arc` returns `None`.
    /// Partners are processed in id order, so arc creation is deterministic,
    /// and each created arc is recorded on both endpoints.
    fn generate_arcs(&mut self, id: ExtId) {
        let window = self.spec.window;
        let (key, base, node, keys) = {
            let e = &self.entries[&id];
            (e.key, e.base, e.node, e.keys.clone())
        };
        // A zero-amount entry has nothing to match.
        if base == 0 {
            return;
        }
        let mut partners: HashSet<ExtId> = HashSet::new();
        // Zero-amount partners are rejected as well. This keeps eligibility
        // symmetric with the early return above, and no arc ever ends at an
        // entry that `matched_arcs` would skip.
        let consider = |this: &Self, other: ExtId, set: &mut HashSet<ExtId>| {
            if other == id {
                return;
            }
            let ob = this.entries[&other].base;
            if ob == 0 || (base > 0) == (ob > 0) {
                return;
            }
            set.insert(other);
        };
        if window >= 0 {
            // Saturating bounds: `key ± window` must not overflow for extreme
            // block keys or very large windows. Since `window >= 0`, the range
            // is never inverted, which `BTreeMap::range` would reject with a
            // panic.
            let lo = key.saturating_sub(window);
            let hi = key.saturating_add(window);
            for (_k, ids) in self.by_key.range(lo..=hi) {
                for &other in ids {
                    consider(self, other, &mut partners);
                }
            }
        }
        for k in &keys {
            if let Some(ids) = self.by_match_key.get(k) {
                // Oversized buckets are skipped to bound arc fan-out.
                if ids.len() > MATCH_BUCKET_CAP {
                    continue;
                }
                for &other in ids {
                    consider(self, other, &mut partners);
                }
            }
        }
        let mut partners: Vec<ExtId> = partners.into_iter().collect();
        partners.sort_unstable();
        for other in partners {
            let (src_id, snk_id) = if base > 0 { (id, other) } else { (other, id) };
            let (src_node, snk_node) = if base > 0 {
                (node, self.entries[&other].node)
            } else {
                (self.entries[&other].node, node)
            };
            let cost = {
                let s = &self.entries[&src_id];
                let t = &self.entries[&snk_id];
                (self.spec.cost)(&s.tx, s.base, &t.tx, t.base)
            };
            if let Some(cost) = cost
                && let Some(arc) = self.net.add_arc(src_node, snk_node, cost)
            {
                self.entries
                    .get_mut(&id)
                    .expect("generating entry exists")
                    .arcs
                    .push((other, arc));
                self.entries
                    .get_mut(&other)
                    .expect("partner came from the live indexes")
                    .arcs
                    .push((id, arc));
            }
        }
    }

    /// Removes every arc touching `id` from the network and from both
    /// endpoints' arc lists.
    ///
    /// # Panics
    /// Panics if `id` is not a current entry.
    fn detach_arcs(&mut self, id: ExtId) {
        let arcs = std::mem::take(
            &mut self
                .entries
                .get_mut(&id)
                .expect("detach_arcs called for a missing entry")
                .arcs,
        );
        for (other, arc) in arcs {
            self.net.remove_arc(arc);
            if let Some(oe) = self.entries.get_mut(&other) {
                oe.arcs.retain(|(x, _)| *x != id);
            }
        }
    }

    /// Removes `id` from the `key` bucket of `by_key`, dropping the bucket once
    /// it is empty.
    fn unindex_key(&mut self, key: i64, id: ExtId) {
        if let Some(v) = self.by_key.get_mut(&key) {
            v.retain(|x| *x != id);
            if v.is_empty() {
                self.by_key.remove(&key);
            }
        }
    }

    /// Adds `id` to the bucket of each match key in `keys`.
    fn index_match_keys(&mut self, id: ExtId, keys: &[u64]) {
        for &k in keys {
            self.by_match_key.entry(k).or_default().push(id);
        }
    }

    /// Removes `id` from the bucket of each match key in `keys`, dropping
    /// buckets once they are empty.
    fn unindex_match_keys(&mut self, id: ExtId, keys: &[u64]) {
        for &k in keys {
            if let Some(v) = self.by_match_key.get_mut(&k) {
                v.retain(|x| *x != id);
                if v.is_empty() {
                    self.by_match_key.remove(&k);
                }
            }
        }
    }
}
impl<E> Strategy<E> for Flow<E>
where
    E: Clone,
{
    /// Reconciles `bag` against the network kept from the previous run.
    ///
    /// Steps:
    /// - Items no longer present are removed.
    /// - Items whose `FlowSig` changed, or that are new, are upserted.
    /// - The network is re-solved warm.
    ///
    /// Output:
    /// - Every flow reported between a positive and a negative item becomes a
    ///   two-member `Group` with origin `"flow"`.
    /// - Items with a nonzero unmatched remainder are returned in `residual`,
    ///   with `amount` replaced by that remainder.
    ///
    /// Diagnostics:
    /// - Setting `FLORECON_TIME` prints build and solve timings to stderr
    ///   (never on wasm32).
    /// - In debug builds, or when `FLORECON_VERIFY_WARM` is set, the warm
    ///   objective is checked against a cold rebuild of the same bag, and a
    ///   divergence panics.
    fn run(&mut self, bag: Vec<Item<E>>) -> Resolution<E> {
        #[cfg(not(target_arch = "wasm32"))]
        let timed = std::env::var_os("FLORECON_TIME").is_some();
        #[cfg(target_arch = "wasm32")]
        let timed = false;
        let want: BTreeSet<ExtId> = bag.iter().map(|i| i.id).collect();
        let data: HashMap<ExtId, &Item<E>> = bag.iter().map(|i| (i.id, i)).collect();
        let sigs: HashMap<ExtId, FlowSig> =
            bag.iter().map(|i| (i.id, self.flow_sig(i))).collect();
        // Only new or changed items touch the network. They are applied in the
        // pseudo-random but deterministic order given by `flow_upsert_rank`.
        let mut upserts: Vec<ExtId> = sigs
            .iter()
            .filter_map(|(&id, sig)| (self.loaded.get(&id) != Some(sig)).then_some(id))
            .collect();
        upserts.sort_by_key(|&id| flow_upsert_rank(id));
        // Sorted so removals happen in a reproducible order rather than in
        // `HashMap` iteration order.
        let mut drops: Vec<ExtId> = self
            .loaded
            .keys()
            .copied()
            .filter(|id| !want.contains(id))
            .collect();
        drops.sort_unstable();
        let tb = timed.then(std::time::Instant::now);
        // Remove stale items before upserting. Otherwise upserted items would
        // first grow arcs to nodes that are about to disappear, and the stale
        // ids would still count toward `MATCH_BUCKET_CAP`, so partner
        // selection would differ from a cold build of the same bag.
        for id in drops {
            self.remove(id);
        }
        for id in upserts {
            if let Some(item) = data.get(&id) {
                self.upsert(id, item.data.clone(), item.amount);
            }
        }
        let build = tb.map(|t| t.elapsed().as_secs_f64() * 1000.0);
        let ts = timed.then(std::time::Instant::now);
        let status = self.solve();
        if let (Some(build), Some(ts)) = (build, ts) {
            eprintln!(
                " flow: delta {build:>6.1} ms ({} arcs), solve {:>6.1} ms",
                self.arc_count(),
                ts.elapsed().as_secs_f64() * 1000.0,
            );
        }
        debug_assert_eq!(status, SolveStatus::Optimal);
        self.loaded = sigs;
        if cfg!(debug_assertions) || std::env::var_os("FLORECON_VERIFY_WARM").is_some() {
            let mut cold = Flow::new(self.spec.clone());
            let mut ids: Vec<ExtId> = data.keys().copied().collect();
            ids.sort_unstable();
            for id in ids {
                if let Some(item) = data.get(&id) {
                    cold.upsert(id, item.data.clone(), item.amount);
                }
            }
            cold.solve();
            let (warm_obj, cold_obj) = (self.objective(), cold.objective());
            // Warm and cold networks reach the same optimum but may sum its
            // cost in a different order. With large penalties a single ulp of
            // the objective exceeds any fixed absolute bound, so the tolerance
            // scales with the magnitude (keeping 1e-6 as a floor).
            let scale = warm_obj.abs().max(cold_obj.abs());
            let tol = 1e-6_f64.max(scale * 1e-12);
            assert!(
                (warm_obj - cold_obj).abs() < tol,
                "warm flow solve diverged from a fresh cold rebuild: \
                 warm objective {warm_obj} != cold objective {cold_obj}"
            );
        }
        let groups = self
            .matched_arcs()
            .into_iter()
            .map(|(src, snk, f)| Group {
                members: vec![
                    Allocation { id: src, amount: f },
                    Allocation {
                        id: snk,
                        amount: -f,
                    },
                ],
                origin: "flow".to_string(),
                net: 0,
                reason: Some("min-cost flow".to_string()),
            })
            .collect();
        let unmatched: HashMap<ExtId, i64> = self
            .unmatched_allocations()
            .into_iter()
            .map(|a| (a.id, a.amount))
            .collect();
        let residual = bag
            .into_iter()
            .filter_map(|mut i| {
                unmatched.get(&i.id).map(|&amount| {
                    i.amount = amount;
                    i
                })
            })
            .collect();
        Resolution { groups, residual }
    }
}
pub fn flow<E>(spec: FlowSpec<E>) -> Box<dyn Strategy<E>>
where
E: Clone + 'static,
{
Box::new(Flow::new(spec))
}
/// Pseudo-random but deterministic sort key used to order upserts.
///
/// Applies the SplitMix64 output mixer to `id` offset by the golden-ratio
/// increment. Every step (wrapping add, xor-shift, multiply by an odd
/// constant) is invertible on `u64`, so distinct ids always get distinct ranks.
fn flow_upsert_rank(id: ExtId) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    // (shift, multiplier) for the two xor-shift-multiply rounds.
    const ROUNDS: [(u32, u64); 2] = [(30, 0xBF58_476D_1CE4_E5B9), (27, 0x94D0_49BB_1331_11EB)];
    let mixed = ROUNDS
        .iter()
        .fold(id.wrapping_add(GOLDEN_GAMMA), |z, &(shift, mul)| {
            (z ^ (z >> shift)).wrapping_mul(mul)
        });
    mixed ^ (mixed >> 31)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal payload: a transaction identified only by its date.
    #[derive(Clone)]
    struct Tx {
        date: i64,
    }

    /// Shared spec: penalty 1e6, a +/-3 window on `date`, and the lot-aware
    /// cost `1 + 0.1*|a_amt + b_amt| + |a.date - b.date|`.
    fn demo() -> FlowSpec<Tx> {
        FlowSpec::new()
            .penalty(1_000_000.0)
            .window(3)
            .block_key(|tx: &Tx| tx.date)
            .cost_lot(|a: &Tx, a_amt, b: &Tx, b_amt| {
                let imbalance = (a_amt + b_amt).abs() as f64 * 0.1;
                let distance = (a.date - b.date).abs() as f64;
                Some(1.0 + imbalance + distance)
            })
    }

    /// Builds an item with the given id, signed amount and date.
    fn item(id: ExtId, amount: i64, date: i64) -> Item<Tx> {
        let tx = Tx { date };
        Item::new(id, amount, tx)
    }

    /// Member ids of a group.
    fn ids(g: &Group) -> Vec<ExtId> {
        g.member_ids()
    }

    #[test]
    fn basic_recon() {
        let mut strategy = flow(demo());
        let res = strategy.run(vec![item(1, 100, 0), item(2, -100, 1)]);
        assert_eq!(res.groups.len(), 1);
        let group = &res.groups[0];
        assert_eq!(group.net, 0);
        assert_eq!(ids(group), vec![1, 2]);
        assert!(res.residual.is_empty());
    }

    #[test]
    fn bare_flow_emits_raw_arcs_not_settlements() {
        let mut strategy = flow(demo());
        let res = strategy.run(vec![item(1, 100, 0), item(2, 200, 1), item(3, -250, 0)]);
        assert_eq!(res.groups.len(), 2, "one group per positive-flow arc");
        for group in res.groups.iter() {
            assert!(group.members.len() == 2 && group.net == 0);
            assert!(ids(group).contains(&3));
        }
        let matched: i64 = res
            .groups
            .iter()
            .flat_map(|g| g.members.iter())
            .map(|a| a.amount)
            .filter(|&amount| amount > 0)
            .sum();
        assert_eq!(matched, 250);
        let leftover: i64 = res.residual.iter().map(|i| i.amount).sum();
        assert_eq!(leftover, 50);
        assert_eq!(res.residual.len(), 1);
    }

    #[test]
    fn streaming_add_is_warm() {
        let mut strategy = flow(demo());
        let first = strategy.run(vec![item(1, 100, 0), item(2, -100, 0)]);
        assert_eq!(first.groups.len(), 1);
        let second = strategy.run(vec![
            item(1, 100, 0),
            item(2, -100, 0),
            item(3, 70, 5),
            item(4, -70, 5),
        ]);
        assert_eq!(second.groups.len(), 2);
        assert!(second.groups.iter().all(|g| g.net == 0));
    }

    #[test]
    fn out_of_window_unmatched() {
        let mut strategy = flow(demo());
        let res = strategy.run(vec![item(1, 100, 0), item(2, -100, 100)]);
        assert_eq!(res.groups.len(), 0);
        let mut leftover: Vec<ExtId> = res.residual.iter().map(|i| i.id).collect();
        leftover.sort_unstable();
        assert_eq!(leftover, vec![1, 2]);
    }

    #[test]
    fn correction_reprice_is_warm() {
        let mut strategy = flow(demo());
        let pairs = |g: &Group, a: ExtId, b: ExtId| {
            let members = ids(g);
            members.contains(&a) && members.contains(&b)
        };
        let res = strategy.run(vec![item(1, 100, 0), item(2, -100, 0), item(3, -50, 0)]);
        assert!(res.groups.iter().any(|g| pairs(g, 1, 2)));
        let res = strategy.run(vec![item(1, 50, 0), item(2, -100, 0), item(3, -50, 0)]);
        assert!(res.groups.iter().any(|g| g.net == 0 && pairs(g, 1, 3)));
    }

    #[test]
    fn remove_is_warm() {
        let mut strategy = flow(demo());
        let before = strategy.run(vec![item(1, 100, 0), item(2, -100, 0)]);
        assert_eq!(before.groups.len(), 1);
        let after = strategy.run(vec![item(1, 100, 0)]);
        assert_eq!(after.groups.len(), 0);
        let left: Vec<ExtId> = after.residual.iter().map(|i| i.id).collect();
        assert_eq!(left, vec![1]);
    }

    #[test]
    fn lot_cost_sees_residual_amount() {
        let spec = FlowSpec::new()
            .penalty(1e9)
            .window(5)
            .block_key(|t: &Tx| t.date)
            .cost_lot(|_a: &Tx, a_amt, _b: &Tx, b_amt| {
                if a_amt.abs() == b_amt.abs() {
                    Some(1.0)
                } else {
                    None
                }
            });
        let mut strategy = flow(spec);
        let res = strategy.run(vec![item(1, 100, 0), item(2, -100, 0)]);
        assert_eq!(res.groups.len(), 1);
        assert_eq!(res.groups[0].net, 0);
    }
}