use super::{Graph, GraphError, GraphResult, NodeId, Weight};
use std::collections::{HashMap, HashSet, VecDeque};
/// Residual-graph state for the augmenting-path max-flow search in this module.
///
/// Both maps are keyed `[u][v]` for the ordered pair `u -> v`. [`max_flow`] keeps the flow
/// skew-symmetric (pushing `x` along `u -> v` adds `x` to `flow[u][v]` and subtracts `x`
/// from `flow[v][u]`), so the residual capacity of any ordered pair is simply
/// `capacity - flow`, with missing entries read as `0.0`.
#[derive(Debug, Clone)]
pub struct FlowNetwork {
    /// Capacity of each directed edge; parallel edges are summed into a single entry.
    capacity: HashMap<NodeId, HashMap<NodeId, Weight>>,
    /// Current flow per ordered pair. Every pair joined by an edge in either direction has
    /// an entry, so each row doubles as the residual adjacency list of its node.
    flow: HashMap<NodeId, HashMap<NodeId, Weight>>,
    /// All nodes of the graph this network was built from.
    nodes: HashSet<NodeId>,
}

impl FlowNetwork {
    /// Builds a zero-flow network whose capacities are the edge weights of `graph`.
    ///
    /// Parallel edges `u -> v` contribute their combined capacity. Every edge also gets a
    /// zero-flow entry in the reverse direction so the residual graph can later be walked
    /// from the `flow` rows alone.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::NodeNotFound`] if an edge references a node that
    /// `graph.nodes()` did not report.
    pub fn from_graph(graph: &Graph) -> GraphResult<Self> {
        let mut capacity: HashMap<NodeId, HashMap<NodeId, Weight>> = HashMap::new();
        let mut flow: HashMap<NodeId, HashMap<NodeId, Weight>> = HashMap::new();
        let mut nodes = HashSet::new();
        for &node in &graph.nodes() {
            nodes.insert(node);
            capacity.insert(node, HashMap::new());
            flow.insert(node, HashMap::new());
        }
        for edge in graph.edges() {
            // Accumulate instead of overwriting so parallel edges are not silently dropped.
            *capacity
                .get_mut(&edge.from)
                .ok_or(GraphError::NodeNotFound(edge.from))?
                .entry(edge.to)
                .or_insert(0.0) += edge.weight;
            flow.get_mut(&edge.from)
                .ok_or(GraphError::NodeNotFound(edge.from))?
                .entry(edge.to)
                .or_insert(0.0);
            // Reverse slot: holds the negated flow once flow is pushed along `from -> to`.
            flow.get_mut(&edge.to)
                .ok_or(GraphError::NodeNotFound(edge.to))?
                .entry(edge.from)
                .or_insert(0.0);
        }
        Ok(Self {
            capacity,
            flow,
            nodes,
        })
    }

    /// Reads `table[u][v]`, treating a missing row or entry as `0.0`.
    fn lookup(table: &HashMap<NodeId, HashMap<NodeId, Weight>>, u: NodeId, v: NodeId) -> Weight {
        table
            .get(&u)
            .and_then(|row| row.get(&v))
            .copied()
            .unwrap_or(0.0)
    }

    /// Remaining capacity on the ordered pair `u -> v`: `capacity[u][v] - flow[u][v]`.
    ///
    /// For a pair that only exists as a reverse edge the capacity reads as `0.0`, so the
    /// result is the amount of flow that could be cancelled.
    fn residual_capacity(&self, u: NodeId, v: NodeId) -> Weight {
        Self::lookup(&self.capacity, u, v) - Self::lookup(&self.flow, u, v)
    }

    /// Breadth-first search for a shortest `source -> sink` path in the residual graph.
    ///
    /// Returns the BFS parent map (`parent[v]` is the predecessor of `v` on the path; the
    /// source has no entry) once `sink` is dequeued, or `None` if `sink` is unreachable.
    fn bfs_residual(&self, source: NodeId, sink: NodeId) -> Option<HashMap<NodeId, NodeId>> {
        let mut visited = HashSet::new();
        let mut parent = HashMap::new();
        let mut queue = VecDeque::new();
        visited.insert(source);
        queue.push_back(source);
        while let Some(u) = queue.pop_front() {
            if u == sink {
                return Some(parent);
            }
            // Only pairs sharing an edge (in either direction) can have positive residual
            // capacity, and each such pair has a `flow` entry, so scanning this row is
            // equivalent to scanning every node but costs O(deg(u)) instead of O(V).
            if let Some(row) = self.flow.get(&u) {
                for &v in row.keys() {
                    if !visited.contains(&v) && self.residual_capacity(u, v) > 0.0 {
                        visited.insert(v);
                        parent.insert(v, u);
                        queue.push_back(v);
                    }
                }
            }
        }
        None
    }

    /// Sum of the flow values stored in `source`'s row.
    ///
    /// Because reverse entries hold negated flow, this is the net flow leaving `source`;
    /// an unknown node yields `0.0`.
    pub fn total_flow(&self, source: NodeId) -> Weight {
        self.flow
            .get(&source)
            .map(|outgoing| outgoing.values().sum::<Weight>())
            .unwrap_or(0.0)
    }

    /// Current flow on the ordered pair `u -> v` (negative for the reverse side of an edge
    /// carrying flow); `0.0` if the pair is unknown.
    pub fn get_flow(&self, u: NodeId, v: NodeId) -> Weight {
        Self::lookup(&self.flow, u, v)
    }
}
#[derive(Debug, Clone)]
pub struct MaxFlowResult {
pub max_flow: Weight,
pub flows: HashMap<(NodeId, NodeId), Weight>,
pub source_side: HashSet<NodeId>,
}
pub fn max_flow(graph: &Graph, source: NodeId, sink: NodeId) -> GraphResult<MaxFlowResult> {
if !graph.has_node(source) {
return Err(GraphError::NodeNotFound(source));
}
if !graph.has_node(sink) {
return Err(GraphError::NodeNotFound(sink));
}
if source == sink {
return Err(GraphError::InvalidOperation(
"Source and sink must be different nodes".to_string(),
));
}
for edge in graph.edges() {
if edge.weight < 0.0 {
return Err(GraphError::InvalidWeight(
"Flow network cannot have negative capacities".to_string(),
));
}
}
let mut network = FlowNetwork::from_graph(graph)?;
while let Some(parent) = network.bfs_residual(source, sink) {
let mut path_flow = Weight::INFINITY;
let mut v = sink;
while let Some(&u) = parent.get(&v) {
let residual = network.residual_capacity(u, v);
path_flow = path_flow.min(residual);
v = u;
}
v = sink;
while let Some(&u) = parent.get(&v) {
let current = network
.flow
.get(&u)
.and_then(|m| m.get(&v))
.copied()
.unwrap_or(0.0);
network
.flow
.get_mut(&u)
.ok_or(GraphError::NodeNotFound(u))?
.insert(v, current + path_flow);
let reverse = network
.flow
.get(&v)
.and_then(|m| m.get(&u))
.copied()
.unwrap_or(0.0);
network
.flow
.get_mut(&v)
.ok_or(GraphError::NodeNotFound(v))?
.insert(u, reverse - path_flow);
v = u;
}
}
let mut flows = HashMap::new();
for (&from, outgoing) in &network.flow {
for (&to, &flow_val) in outgoing {
if flow_val > 0.0 {
flows.insert((from, to), flow_val);
}
}
}
let mut source_side = HashSet::new();
let mut queue = VecDeque::new();
source_side.insert(source);
queue.push_back(source);
while let Some(u) = queue.pop_front() {
for &v in &network.nodes {
if !source_side.contains(&v) && network.residual_capacity(u, v) > 0.0 {
source_side.insert(v);
queue.push_back(v);
}
}
}
Ok(MaxFlowResult {
max_flow: network.total_flow(source),
flows,
source_side,
})
}
/// Lists the edges of `graph` that cross the cut described by `max_flow_result`.
///
/// An edge crosses the cut when its tail lies in `max_flow_result.source_side` and its head
/// does not. Each crossing edge is reported as `(from, to, capacity)`. When
/// `max_flow_result` comes from [`max_flow`] on the same `graph`, these capacities should
/// add up to its `max_flow` (max-flow/min-cut theorem).
///
/// # Errors
///
/// Currently never fails; the `Result` return type leaves room for validation.
pub fn min_cut(
    graph: &Graph,
    max_flow_result: &MaxFlowResult,
) -> GraphResult<Vec<(NodeId, NodeId, Weight)>> {
    let source_side = &max_flow_result.source_side;
    let crossing: Vec<(NodeId, NodeId, Weight)> = graph
        .edges()
        .into_iter()
        .filter(|edge| source_side.contains(&edge.from) && !source_side.contains(&edge.to))
        .map(|edge| (edge.from, edge.to, edge.weight))
        .collect();
    Ok(crossing)
}
/// A bipartite graph given by explicit left and right node sets.
///
/// Edges always go from a left node to a right node and are stored as one adjacency set
/// per left node.
#[derive(Debug, Clone, Default)]
pub struct BipartiteGraph {
    /// Nodes on the left side of the bipartition.
    pub left: HashSet<NodeId>,
    /// Nodes on the right side of the bipartition.
    pub right: HashSet<NodeId>,
    /// `edges[l]` holds the right-side neighbours of left node `l`.
    pub edges: HashMap<NodeId, HashSet<NodeId>>,
}

impl BipartiteGraph {
    /// Creates an empty bipartite graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `node` on the left side, giving it an empty adjacency set if it has none yet.
    pub fn add_left(&mut self, node: NodeId) {
        self.left.insert(node);
        self.edges.entry(node).or_default();
    }

    /// Registers `node` on the right side.
    pub fn add_right(&mut self, node: NodeId) {
        self.right.insert(node);
    }

    /// Connects left node `left` to right node `right`; adding an existing edge is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::NodeNotFound`] if `left` is not registered on the left side or
    /// `right` is not registered on the right side.
    pub fn add_edge(&mut self, left: NodeId, right: NodeId) -> GraphResult<()> {
        if !self.left.contains(&left) {
            return Err(GraphError::NodeNotFound(left));
        }
        if !self.right.contains(&right) {
            return Err(GraphError::NodeNotFound(right));
        }
        // `entry` keeps this working even when `left` was inserted into the public `left`
        // set directly rather than through `add_left`.
        self.edges.entry(left).or_default().insert(right);
        Ok(())
    }

    /// Builds the unit-capacity flow network used for maximum matching.
    ///
    /// Layout: `source -> each left node -> its right neighbours -> sink`, every edge with
    /// capacity `1.0`. Left and right nodes receive flow-graph ids in ascending order of
    /// their own ids. Returns `(graph, source, sink, reverse_left, reverse_right)`, where
    /// the reverse maps translate flow-graph ids back to the caller's left/right ids.
    fn to_flow_network(
        &self,
    ) -> (
        Graph,
        NodeId,
        NodeId,
        HashMap<NodeId, NodeId>,
        HashMap<NodeId, NodeId>,
    ) {
        let mut graph = Graph::new(true);
        let source = graph.add_node();
        let sink = graph.add_node();

        let mut left_sorted: Vec<NodeId> = self.left.iter().copied().collect();
        left_sorted.sort_unstable();
        let mut left_id_map = HashMap::new();
        let mut reverse_left = HashMap::new();
        for &left_node in &left_sorted {
            let id = graph.add_node();
            left_id_map.insert(left_node, id);
            reverse_left.insert(id, left_node);
            let added = graph.add_edge(source, id, 1.0);
            // Both endpoints were created above, so a failure here would be a bug, not input.
            debug_assert!(added.is_ok(), "source -> left edge must be insertable");
        }

        let mut right_sorted: Vec<NodeId> = self.right.iter().copied().collect();
        right_sorted.sort_unstable();
        let mut right_id_map = HashMap::new();
        let mut reverse_right = HashMap::new();
        for &right_node in &right_sorted {
            let id = graph.add_node();
            right_id_map.insert(right_node, id);
            reverse_right.insert(id, right_node);
            let added = graph.add_edge(id, sink, 1.0);
            debug_assert!(added.is_ok(), "right -> sink edge must be insertable");
        }

        // Only edges whose endpoints are registered on the correct sides are kept; entries in
        // the public `edges` map that reference unregistered nodes are skipped.
        for left_node in &left_sorted {
            if let (Some(&left_id), Some(neighbors)) =
                (left_id_map.get(left_node), self.edges.get(left_node))
            {
                for right_node in neighbors {
                    if let Some(&right_id) = right_id_map.get(right_node) {
                        let added = graph.add_edge(left_id, right_id, 1.0);
                        debug_assert!(added.is_ok(), "left -> right edge must be insertable");
                    }
                }
            }
        }

        (graph, source, sink, reverse_left, reverse_right)
    }
}
/// Computes a maximum matching of `bipartite` by reduction to unit-capacity max flow.
///
/// Returns the matched `(left, right)` pairs sorted ascending. Several maximum matchings
/// may exist; only the ordering of the returned pairs is normalised.
///
/// # Errors
///
/// Propagates any error from [`max_flow`] on the auxiliary flow network.
pub fn max_bipartite_matching(bipartite: &BipartiteGraph) -> GraphResult<Vec<(NodeId, NodeId)>> {
    // With an empty side no edge can be matched, so skip building the network entirely.
    if bipartite.left.is_empty() || bipartite.right.is_empty() {
        return Ok(Vec::new());
    }
    let (flow_graph, source, sink, reverse_left, reverse_right) = bipartite.to_flow_network();
    let flow_result = max_flow(&flow_graph, source, sink)?;
    // Middle edges have capacity 1, so positive flow on a left -> right edge marks a matched
    // pair. Source and sink edges drop out because neither endpoint map contains them.
    let mut matching: Vec<(NodeId, NodeId)> = flow_result
        .flows
        .iter()
        .filter(|&(_, &amount)| amount > 0.0)
        .filter_map(|(&(from, to), _)| {
            let left = *reverse_left.get(&from)?;
            let right = *reverse_right.get(&to)?;
            Some((left, right))
        })
        .collect();
    // `flows` is a HashMap, so impose an order before handing the result out.
    matching.sort_unstable();
    Ok(matching)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds every `(from, to, capacity)` triple to `graph`, failing the test on error.
    fn add_edges(graph: &mut Graph, edges: &[(NodeId, NodeId, Weight)]) {
        for &(from, to, capacity) in edges {
            graph
                .add_edge(from, to, capacity)
                .expect("test: valid edge addition");
        }
    }

    /// Adds every `(left, right)` pair to `bipartite`, failing the test on error.
    fn connect(bipartite: &mut BipartiteGraph, pairs: &[(NodeId, NodeId)]) {
        for &(left, right) in pairs {
            bipartite
                .add_edge(left, right)
                .expect("test: valid bipartite edge addition");
        }
    }

    #[test]
    fn test_max_flow_simple() {
        // Two disjoint paths of capacity 10 each: their flows add up.
        let mut graph = Graph::new(true);
        let source = graph.add_node();
        let n1 = graph.add_node();
        let n2 = graph.add_node();
        let sink = graph.add_node();
        add_edges(
            &mut graph,
            &[
                (source, n1, 10.0),
                (source, n2, 10.0),
                (n1, sink, 10.0),
                (n2, sink, 10.0),
            ],
        );
        let outcome = max_flow(&graph, source, sink).expect("test: valid max flow");
        assert_eq!(outcome.max_flow, 20.0);
    }

    #[test]
    fn test_max_flow_bottleneck() {
        // A single path is limited by its smallest capacity.
        let mut graph = Graph::new(true);
        let source = graph.add_node();
        let n1 = graph.add_node();
        let sink = graph.add_node();
        add_edges(&mut graph, &[(source, n1, 10.0), (n1, sink, 5.0)]);
        let outcome = max_flow(&graph, source, sink).expect("test: valid max flow");
        assert_eq!(outcome.max_flow, 5.0);
    }

    #[test]
    fn test_max_flow_negative_capacity() {
        let mut graph = Graph::new(true);
        let source = graph.add_node();
        let sink = graph.add_node();
        add_edges(&mut graph, &[(source, sink, -1.0)]);
        assert!(max_flow(&graph, source, sink).is_err());
    }

    #[test]
    fn test_min_cut() {
        // Max-flow/min-cut: the crossing capacities must equal the flow value.
        let mut graph = Graph::new(true);
        let source = graph.add_node();
        let n1 = graph.add_node();
        let sink = graph.add_node();
        add_edges(
            &mut graph,
            &[(source, n1, 10.0), (n1, sink, 5.0), (source, sink, 15.0)],
        );
        let flow_outcome = max_flow(&graph, source, sink).expect("test: valid max flow");
        let cut = min_cut(&graph, &flow_outcome).expect("test: valid min cut");
        let cut_capacity: Weight = cut.iter().map(|(_, _, w)| w).sum();
        assert_eq!(cut_capacity, flow_outcome.max_flow);
    }

    #[test]
    fn test_bipartite_matching() {
        let mut bipartite = BipartiteGraph::new();
        (0..3).for_each(|node| bipartite.add_left(node));
        (3..6).for_each(|node| bipartite.add_right(node));
        connect(&mut bipartite, &[(0, 3), (0, 4), (1, 4), (2, 5)]);
        let matching = max_bipartite_matching(&bipartite).expect("test: valid bipartite matching");
        assert_eq!(matching.len(), 3);
        let matched_left: HashSet<_> = matching.iter().map(|&(l, _)| l).collect();
        let matched_right: HashSet<_> = matching.iter().map(|&(_, r)| r).collect();
        assert_eq!(matched_left.len(), 3);
        assert_eq!(matched_right.len(), 3);
    }

    #[test]
    fn test_bipartite_matching_incomplete() {
        // Only two right nodes exist, so at most two pairs can be matched.
        let mut bipartite = BipartiteGraph::new();
        (0..3).for_each(|node| bipartite.add_left(node));
        (3..5).for_each(|node| bipartite.add_right(node));
        connect(&mut bipartite, &[(0, 3), (1, 3), (1, 4), (2, 4)]);
        let matching = max_bipartite_matching(&bipartite).expect("test: valid bipartite matching");
        assert_eq!(matching.len(), 2);
    }

    #[test]
    fn test_flow_network_creation() {
        let mut graph = Graph::new(true);
        let n0 = graph.add_node();
        let n1 = graph.add_node();
        add_edges(&mut graph, &[(n0, n1, 10.0)]);
        let network = FlowNetwork::from_graph(&graph).expect("test: valid flow network creation");
        assert_eq!(network.residual_capacity(n0, n1), 10.0);
        assert_eq!(network.get_flow(n0, n1), 0.0);
    }
}