use hashbrown::HashMap;
use num_traits::Zero;
use std::{hash::Hash, ops::AddAssign};
use priority_queue::PriorityQueue;
use petgraph::{
stable_graph::StableUnGraph,
visit::{Bfs, EdgeCount, EdgeRef, GraphProp, IntoEdges, IntoNodeIdentifiers, NodeCount},
Undirected,
};
/// Outcome of one Stoer–Wagner phase: the last two nodes reached `(s, t)` and the
/// weight of the cut-of-the-phase, or `None` when fewer than two nodes were visited.
type StCut<Weight, Node> = Option<((Node, Node), Weight)>;
/// Outcome of the whole min-cut search: the cut weight together with the nodes on one
/// side of the cut (`None` for graphs with fewer than two nodes), or the caller's error.
type MinCut<Weight, Node, Error> = Result<Option<(Weight, Vec<Node>)>, Error>;
/// Pairs two options, yielding `Some((a, b))` only when both are `Some`.
///
/// Thin wrapper over the standard library's [`Option::zip`]; kept as a free function so
/// existing call sites in this module read unchanged.
fn zip<T, U>(a: Option<T>, b: Option<U>) -> Option<(T, U)> {
    a.zip(b)
}
/// Runs a single "minimum cut phase" of the Stoer–Wagner algorithm.
///
/// Every node starts in a priority queue with priority zero. The node with the highest
/// accumulated priority is removed repeatedly. After each removal, `edge_cost` of each
/// of its edges is added to the neighbour on the other end, if that neighbour is still
/// queued. A node's priority when it is removed is therefore the total cost of its edges
/// to the nodes removed before it.
///
/// Returns the last two nodes removed, `(s, t)`, and the priority `t` had when it was
/// removed (the cut-of-the-phase). Returns `None` when the graph has fewer than two
/// nodes.
fn stoer_wagner_phase<G, F, K>(graph: G, mut edge_cost: F) -> StCut<K, G::NodeId>
where
    G: GraphProp<EdgeType = Undirected> + IntoEdges + IntoNodeIdentifiers,
    G::NodeId: Hash + Eq,
    F: FnMut(G::EdgeRef) -> K,
    K: Copy + Ord + Zero + AddAssign,
{
    let seeds: Vec<(G::NodeId, K)> = graph
        .node_identifiers()
        .map(|node| (node, K::zero()))
        .collect();
    let mut frontier = PriorityQueue::<G::NodeId, K, foldhash::fast::RandomState>::from(seeds);

    // `previous` and `last` trail the two most recently removed nodes.
    let mut previous = None;
    let mut last = None;
    let mut last_weight = None;
    while let Some((node, weight)) = frontier.pop() {
        previous = last.replace(node);
        last_weight = Some(weight);
        for edge in graph.edges(node) {
            let neighbor = edge.target();
            // Neighbours already removed are no longer queued and are left untouched.
            frontier.change_priority_by(&neighbor, |priority| *priority += edge_cost(edge));
        }
    }
    zip(zip(previous, last), last_weight)
}
/// Computes a minimum cut of an undirected graph with the Stoer–Wagner algorithm.
///
/// `edge_cost` is called once per edge of `graph` to obtain that edge's weight.
/// Stoer–Wagner assumes non-negative weights; this is not checked here.
///
/// Returns `Ok(Some((value, partition)))`:
/// - `value` is the smallest cut-of-the-phase weight found.
/// - `partition` lists the nodes of `graph` on one side of that cut; every other node
///   is on the other side.
///
/// Returns `Ok(None)` when the graph has fewer than two nodes.
///
/// # Errors
///
/// Returns the first error produced by `edge_cost`.
pub fn stoer_wagner_min_cut<G, F, K, E>(graph: G, mut edge_cost: F) -> MinCut<K, G::NodeId, E>
where
    G: GraphProp<EdgeType = Undirected> + IntoEdges + IntoNodeIdentifiers + NodeCount + EdgeCount,
    G::NodeId: Hash + Eq,
    F: FnMut(G::EdgeRef) -> Result<K, E>,
    K: Copy + Ord + Zero + AddAssign,
{
    // Build an owned, weighted copy of the input in which nodes can be merged in place.
    let mut graph_with_super_nodes =
        StableUnGraph::with_capacity(graph.node_count(), graph.edge_count());
    let mut node_map = HashMap::with_capacity(graph.node_count());
    let mut rev_node_map = HashMap::with_capacity(graph.node_count());
    for node in graph.node_identifiers() {
        let index = graph_with_super_nodes.add_node(());
        node_map.insert(node, index);
        rev_node_map.insert(index, node);
    }
    for edge in graph.edge_references() {
        let cost = edge_cost(edge)?;
        let source = node_map[&edge.source()];
        let target = node_map[&edge.target()];
        graph_with_super_nodes.add_edge(source, target, cost);
    }
    if graph_with_super_nodes.node_count() == 0 {
        return Ok(None);
    }

    // (phase index, weight) of the lightest cut-of-the-phase seen so far. Ties keep the
    // earliest phase.
    let mut best: Option<(usize, K)> = None;
    // contractions[i] = (s, t): during phase i, super node `t` was merged into `s`.
    let mut contractions = Vec::new();
    for phase in 0..(graph_with_super_nodes.node_count() - 1) {
        if let Some(((s, t), cut_w)) =
            stoer_wagner_phase(&graph_with_super_nodes, |edge| *edge.weight())
        {
            let improves = match best {
                Some((_, best_w)) => cut_w < best_w,
                None => true,
            };
            if improves {
                best = Some((phase, cut_w));
            }
            contractions.push((s, t));
            // Re-attach t's edges to s, summing weights into existing s-neighbour edges.
            // Edges between s and t, and self-loops on t, become internal to the merged
            // node. They are skipped so they do not turn into a meaningless self-loop on s.
            let edges = graph_with_super_nodes
                .edges(t)
                .filter(|edge| edge.target() != s && edge.target() != t)
                .map(|edge| (s, edge.target(), *edge.weight()))
                .collect::<Vec<_>>();
            for (source, target, cost) in edges {
                if let Some(edge_index) = graph_with_super_nodes.find_edge(source, target) {
                    graph_with_super_nodes[edge_index] += cost;
                } else {
                    graph_with_super_nodes.add_edge(source, target, cost);
                }
            }
            graph_with_super_nodes.remove_node(t);
        }
    }

    Ok(best.map(|(phase, cut_w)| {
        // The cut of the best phase isolates super node `t`. Its original members are
        // the nodes reachable from `t` through the contractions performed in earlier
        // phases.
        let node = contractions[phase].1;
        let mut clustered_graph = StableUnGraph::<(), ()>::default();
        clustered_graph.extend_with_edges(&contractions[..phase]);
        let partition = if clustered_graph.contains_node(node) {
            let mut cluster = Vec::new();
            let mut bfs = Bfs::new(&clustered_graph, node);
            while let Some(nx) = bfs.next(&clustered_graph) {
                cluster.push(rev_node_map[&nx]);
            }
            cluster
        } else {
            // `t` was not an endpoint of any earlier contraction, so it is a single
            // original node.
            vec![rev_node_map[&node]]
        };
        (cut_w, partition)
    }))
}