use std::cmp::Reverse;
use std::collections::BinaryHeap;
/// Sentinel "infinite" distance (10^18).
///
/// `min_cost_flow` initialises every shortest-path distance to this value and
/// treats a sink whose distance is still at or above it as unreachable.
pub const INF_COST: i64 = 1_000_000_000_000_000_000;
/// Errors reported by the min-cost-flow solver.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McfError {
    /// The requested amount of flow could not be routed.
    ///
    /// `sent` is the amount that was actually pushed before no further
    /// augmenting path was found; `requested` is the amount asked for.
    InsufficientFlow { sent: i64, requested: i64 },
}

impl std::fmt::Display for McfError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            McfError::InsufficientFlow { sent, requested } => {
                write!(
                    f,
                    "Could not send full flow: sent={sent} requested={requested}"
                )
            }
        }
    }
}

impl std::error::Error for McfError {}
/// One directed residual edge, stored in the adjacency list of its tail node.
#[derive(Debug, Clone)]
struct Edge {
/// Index of the node this edge points to.
to: usize,
/// Position of the paired opposite-direction edge inside `graph[to]`.
rev: usize,
/// Remaining (residual) capacity; decreased as flow is pushed along the edge.
cap: i64,
/// Cost per unit of flow; the paired reverse edge is created with the negated cost.
cost: i64,
}
/// Residual graph for min-cost-flow computations, stored as adjacency lists.
pub struct MinCostFlowGraph {
/// Number of nodes the graph was created with.
n: usize,
/// `graph[v]` holds every edge leaving node `v`, including the reverse
/// (residual) edges that `add_edge` inserts for edges pointing at `v`.
graph: Vec<Vec<Edge>>,
}
impl MinCostFlowGraph {
    /// Creates a graph with `n` nodes (indexed `0..n`) and no edges.
    pub fn new(n: usize) -> Self {
        Self {
            n,
            graph: vec![Vec::new(); n],
        }
    }

    /// Adds a directed edge `from -> to` with capacity `cap` and per-unit
    /// cost `cost`, together with its zero-capacity reverse edge of cost
    /// `-cost`.
    ///
    /// The forward edge is appended to `from`'s adjacency list and the
    /// reverse edge to `to`'s list, so per-node edge indices (as used by
    /// [`flow_on_edge`](Self::flow_on_edge) and [`edge_to`](Self::edge_to))
    /// count both kinds of edges.
    ///
    /// # Panics
    ///
    /// Panics if `from` or `to` is not a valid node index.
    pub fn add_edge(&mut self, from: usize, to: usize, cap: i64, cost: i64) {
        // For a self-loop both edges land in the same adjacency list, so the
        // reverse edge sits one slot after the forward edge. Without this
        // offset the forward edge would name itself as its own reverse.
        let rev_idx = self.graph[to].len() + usize::from(from == to);
        let fwd_idx = self.graph[from].len();
        self.graph[from].push(Edge {
            to,
            rev: rev_idx,
            cap,
            cost,
        });
        self.graph[to].push(Edge {
            to: from,
            rev: fwd_idx,
            cap: 0,
            cost: -cost,
        });
    }

    /// Pushes up to `max_flow` units from `s` to `t` along successive
    /// cheapest augmenting paths and returns `(flow, total_cost)`.
    ///
    /// Shortest paths are found with a heap-based search over reduced costs
    /// (edge cost adjusted by node potentials). The residual capacities
    /// stored in the graph are updated in place, so repeated calls continue
    /// from the current residual state.
    ///
    /// The total cost is accumulated in `i128` and saturated into the `i64`
    /// range on return. A path whose length reaches `INF_COST` is treated
    /// the same as "no path". When `s == t` the path is empty, so the
    /// requested amount is reported as sent at zero cost.
    ///
    /// NOTE(review): negative-cost cycles with positive capacity are not
    /// detected; the shortest-path search presumably does not terminate on
    /// such input — confirm callers never build one.
    ///
    /// # Errors
    ///
    /// Returns [`McfError::InsufficientFlow`] when fewer than `max_flow`
    /// units can be routed. The flow that was pushed before giving up stays
    /// applied to the graph.
    ///
    /// # Panics
    ///
    /// Panics if `max_flow > 0` and `s` or `t` is not a valid node index.
    pub fn min_cost_flow(
        &mut self,
        s: usize,
        t: usize,
        max_flow: i64,
    ) -> Result<(i64, i64), McfError> {
        let n = self.n;
        let inf = i128::from(INF_COST);
        let mut pot = vec![0i128; n];
        let mut prevv = vec![0usize; n];
        let mut preve = vec![0usize; n];
        // Scratch buffers reused across augmentation rounds.
        let mut dist = vec![inf; n];
        let mut pq: BinaryHeap<Reverse<(i128, usize)>> = BinaryHeap::new();
        let mut flow: i64 = 0;
        let mut cost: i128 = 0;

        while flow < max_flow {
            dist.fill(inf);
            dist[s] = 0;
            // The heap is always drained by the previous round, so it only
            // needs the source pushed again.
            pq.push(Reverse((0, s)));
            while let Some(Reverse((d, v))) = pq.pop() {
                // Skip stale heap entries superseded by a shorter distance.
                if d != dist[v] {
                    continue;
                }
                for (i, e) in self.graph[v].iter().enumerate() {
                    if e.cap <= 0 {
                        continue;
                    }
                    // Reduced cost of the edge under the current potentials.
                    let nd = d + i128::from(e.cost) + pot[v] - pot[e.to];
                    if nd < dist[e.to] {
                        dist[e.to] = nd;
                        prevv[e.to] = v;
                        preve[e.to] = i;
                        pq.push(Reverse((nd, e.to)));
                    }
                }
            }

            // No augmenting path left: stop with whatever was sent so far.
            if dist[t] >= inf {
                break;
            }

            // Only nodes reached in this round get their potential advanced.
            for (p, &d) in pot.iter_mut().zip(&dist) {
                if d < inf {
                    *p += d;
                }
            }

            // Bottleneck capacity along the path, capped by the remaining demand.
            let mut add_flow = max_flow - flow;
            let mut v = t;
            while v != s {
                add_flow = add_flow.min(self.graph[prevv[v]][preve[v]].cap);
                v = prevv[v];
            }

            // Apply the augmentation to forward and reverse residual edges.
            v = t;
            while v != s {
                let pv = prevv[v];
                let pe = preve[v];
                self.graph[pv][pe].cap -= add_flow;
                let rev = self.graph[pv][pe].rev;
                self.graph[v][rev].cap += add_flow;
                v = pv;
            }

            flow += add_flow;
            // pot[s] stays 0, so pot[t] is the true s -> t path cost.
            cost += i128::from(add_flow) * pot[t];
        }

        if flow < max_flow {
            return Err(McfError::InsufficientFlow {
                sent: flow,
                requested: max_flow,
            });
        }
        // Saturate instead of wrapping; after the clamp the cast is lossless.
        let cost_i64 = cost.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64;
        Ok((flow, cost_i64))
    }

    /// Returns the residual capacity of the edge paired with edge `edge_idx`
    /// in `from`'s adjacency list.
    ///
    /// For a forward edge created by [`add_edge`](Self::add_edge) this is the
    /// amount of flow currently routed through it.
    ///
    /// # Panics
    ///
    /// Panics if `from` or `edge_idx` is out of range.
    pub fn flow_on_edge(&self, from: usize, edge_idx: usize) -> i64 {
        let e = &self.graph[from][edge_idx];
        self.graph[e.to][e.rev].cap
    }

    /// Returns the number of entries in `node`'s adjacency list, counting
    /// both forward edges leaving `node` and reverse edges of edges entering it.
    ///
    /// # Panics
    ///
    /// Panics if `node` is out of range.
    pub fn edge_count(&self, node: usize) -> usize {
        self.graph[node].len()
    }

    /// Returns the head node of edge `edge_idx` in `node`'s adjacency list.
    ///
    /// # Panics
    ///
    /// Panics if `node` or `edge_idx` is out of range.
    pub fn edge_to(&self, node: usize, edge_idx: usize) -> usize {
        self.graph[node][edge_idx].to
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single edge carries the full demand at its per-unit cost.
    #[test]
    fn test_simple_flow() {
        let mut g = MinCostFlowGraph::new(2);
        g.add_edge(0, 1, 2, 1);
        let (flow, cost) = g.min_cost_flow(0, 1, 2).unwrap();
        assert_eq!(flow, 2);
        assert_eq!(cost, 2);
    }

    /// With two disjoint routes and one unit of demand, the cheaper route wins.
    #[test]
    fn test_multi_path_chooses_cheapest() {
        let mut g = MinCostFlowGraph::new(4);
        g.add_edge(0, 1, 1, 10);
        g.add_edge(1, 3, 1, 0);
        g.add_edge(0, 2, 1, 100);
        g.add_edge(2, 3, 1, 0);
        let (flow, cost) = g.min_cost_flow(0, 3, 1).unwrap();
        assert_eq!(flow, 1);
        assert_eq!(cost, 10);
    }

    /// Two units of demand force both routes to be used.
    #[test]
    fn test_multi_path_both_needed() {
        let mut g = MinCostFlowGraph::new(4);
        g.add_edge(0, 1, 1, 10);
        g.add_edge(1, 3, 1, 0);
        g.add_edge(0, 2, 1, 100);
        g.add_edge(2, 3, 1, 0);
        let (flow, cost) = g.min_cost_flow(0, 3, 2).unwrap();
        assert_eq!(flow, 2);
        assert_eq!(cost, 110);
    }

    /// Asking for more than the network can carry reports the partial amount.
    #[test]
    fn test_insufficient_flow() {
        let mut g = MinCostFlowGraph::new(2);
        g.add_edge(0, 1, 1, 1);
        let result = g.min_cost_flow(0, 1, 5);
        assert!(result.is_err());
        match result {
            Err(McfError::InsufficientFlow { sent, requested }) => {
                assert_eq!(sent, 1);
                assert_eq!(requested, 5);
            }
            _ => panic!("expected InsufficientFlow"),
        }
    }

    /// A zero demand succeeds immediately with zero cost.
    #[test]
    fn test_zero_flow() {
        let mut g = MinCostFlowGraph::new(2);
        g.add_edge(0, 1, 10, 5);
        let (flow, cost) = g.min_cost_flow(0, 1, 0).unwrap();
        assert_eq!(flow, 0);
        assert_eq!(cost, 0);
    }

    /// 2x2 assignment problem: the optimal matching costs 5 + 5.
    #[test]
    fn test_bipartite_assignment() {
        let mut g = MinCostFlowGraph::new(6);
        g.add_edge(0, 1, 1, 0);
        g.add_edge(0, 2, 1, 0);
        g.add_edge(1, 3, 1, 5);
        g.add_edge(1, 4, 1, 10);
        g.add_edge(2, 3, 1, 10);
        g.add_edge(2, 4, 1, 5);
        g.add_edge(3, 5, 2, 0);
        g.add_edge(4, 5, 2, 0);
        let (flow, cost) = g.min_cost_flow(0, 5, 2).unwrap();
        assert_eq!(flow, 2);
        assert_eq!(cost, 10);
    }

    /// A zero-capacity self-loop must not disturb routing on the other edges.
    #[test]
    fn test_zero_capacity_self_loop_is_ignored() {
        let mut g = MinCostFlowGraph::new(3);
        g.add_edge(0, 1, 1, 5);
        g.add_edge(0, 2, 1, 10);
        g.add_edge(1, 2, 1, 0);
        g.add_edge(2, 2, 0, 0);
        let (flow, cost) = g.min_cost_flow(0, 2, 2).unwrap();
        assert_eq!(flow, 2);
        assert_eq!(cost, 15);
    }

    /// Per-edge flow is reported for the used edge and stays zero elsewhere.
    #[test]
    fn test_flow_on_edge() {
        let mut g = MinCostFlowGraph::new(4);
        g.add_edge(0, 1, 1, 5);
        g.add_edge(0, 2, 1, 10);
        g.add_edge(1, 3, 1, 0);
        g.add_edge(2, 3, 1, 0);
        let (flow, _cost) = g.min_cost_flow(0, 3, 1).unwrap();
        assert_eq!(flow, 1);
        assert_eq!(g.flow_on_edge(0, 0), 1);
        assert_eq!(g.flow_on_edge(0, 1), 0);
    }

    /// A negative-cost edge is preferred and its reward shows up in the total.
    #[test]
    fn test_negative_costs_keep_reward() {
        let mut g = MinCostFlowGraph::new(5);
        g.add_edge(0, 1, 1, 0);
        g.add_edge(1, 2, 1, -100);
        g.add_edge(1, 3, 1, 50);
        g.add_edge(2, 4, 1, 0);
        g.add_edge(3, 4, 1, 0);
        let (flow, cost) = g.min_cost_flow(0, 4, 1).unwrap();
        assert_eq!(flow, 1);
        assert_eq!(cost, -100);
    }
}