use crate::core::error::{GraphinaError, Result};
use crate::core::types::{BaseGraph, GraphConstructor, NodeId};
use rayon::prelude::*;
use std::cmp::Ordering;
use std::convert::From;
use std::ops::{Add, AddAssign, Sub};
/// Edge-count threshold for `boruvka_mst`: graphs with at least this many
/// edges scan for each component's cheapest edge through rayon's parallel
/// iterator; smaller graphs use a plain sequential loop.
const BORUVKA_PARALLEL_MIN_EDGES: usize = 10_000;
/// Returns one past the largest node index present in `graph`, or `0` when
/// the graph has no nodes.
///
/// The result is the length needed for a table addressed directly by
/// `NodeId::index()`, even when the indices in use are not contiguous.
fn index_bound<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> usize
where
    Ty: GraphConstructor<A, W>,
{
    graph
        .node_ids()
        .fold(0_usize, |bound, node| bound.max(node.index() + 1))
}
/// Disjoint-set forest over the indices `0..n`, using union by rank and
/// path compression.
struct UnionFind {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Rank of each root, used to decide which root survives a union.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one per index in `0..n`.
    fn new(n: usize) -> Self {
        Self {
            parent: Vec::from_iter(0..n),
            rank: vec![0; n],
        }
    }

    /// Returns the representative of the set containing `i` and re-points
    /// every element on the walked path directly at that representative.
    ///
    /// Panics if `i` is out of range.
    fn find(&mut self, i: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = i;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: compress the path that was just walked.
        let mut cursor = i;
        while cursor != root {
            let next = self.parent[cursor];
            self.parent[cursor] = root;
            cursor = next;
        }

        root
    }

    /// Merges the sets containing `i` and `j`; does nothing if they are
    /// already the same set.
    fn union(&mut self, i: usize, j: usize) {
        let (a, b) = (self.find(i), self.find(j));
        if a == b {
            return;
        }

        // The lower-ranked root is attached beneath the higher-ranked one.
        // On a tie `a` stays the root and its rank grows by one.
        let (child, root, tied) = match self.rank[a].cmp(&self.rank[b]) {
            Ordering::Less => (a, b, false),
            Ordering::Greater => (b, a, false),
            Ordering::Equal => (b, a, true),
        };
        self.parent[child] = root;
        if tied {
            self.rank[root] += 1;
        }
    }
}
/// One edge of a spanning tree or forest, as returned by the MST functions
/// in this module.
#[derive(Debug, Clone, Copy)]
pub struct MstEdge<W> {
/// First endpoint of the edge.
pub u: NodeId,
/// Second endpoint of the edge.
pub v: NodeId,
/// Weight of the edge, copied from the graph.
pub weight: W,
}
/// Computes a minimum spanning forest of `graph` with Borůvka's algorithm.
///
/// Each round finds, for every current component, the lightest edge leaving
/// it, then merges components along those edges. Rounds repeat until one
/// component is left or a round merges nothing (the graph is disconnected).
/// Graphs with at least `BORUVKA_PARALLEL_MIN_EDGES` edges run the per-round
/// edge scan through rayon.
///
/// Returns the selected edges together with the sum of their weights.
///
/// # Errors
///
/// Returns an invalid-graph error when `graph` has no nodes.
pub fn boruvka_mst<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> Result<(Vec<MstEdge<W>>, W)>
where
    W: Copy + PartialOrd + Add<Output = W> + AddAssign + Sub<Output = W> + From<u8> + Send + Sync,
    Ty: GraphConstructor<A, W>,
{
    /// Lightest outgoing edge recorded so far for one component root.
    type Candidate<T> = Option<(NodeId, NodeId, T)>;

    /// One empty candidate slot per node index.
    fn empty_slots<T: Copy>(n: usize) -> Vec<Candidate<T>> {
        vec![None; n]
    }

    /// Stores `cand` in `slot` if the slot is empty or `cand` is strictly
    /// lighter; on equal weights the earlier candidate is kept.
    fn offer<T: PartialOrd>(slot: &mut Candidate<T>, cand: (NodeId, NodeId, T)) {
        let improves = slot.as_ref().map_or(true, |(_, _, best)| cand.2 < *best);
        if improves {
            *slot = Some(cand);
        }
    }

    /// Offers `edge` to the components of both endpoints, unless the
    /// endpoints already share a component.
    fn record<T: Copy + PartialOrd>(
        roots: &[usize],
        best: &mut [Candidate<T>],
        edge: (NodeId, NodeId, T),
    ) {
        let (ra, rb) = (roots[edge.0.index()], roots[edge.1.index()]);
        if ra != rb {
            offer(&mut best[ra], edge);
            offer(&mut best[rb], edge);
        }
    }

    if graph.node_count() == 0 {
        return Err(GraphinaError::invalid_graph(
            "Graph is empty, cannot compute MST.",
        ));
    }

    let slots = index_bound(graph);
    let edge_list: Vec<(NodeId, NodeId, W)> =
        graph.edges().map(|(a, b, w)| (a, b, *w)).collect();
    let scan_in_parallel = edge_list.len() >= BORUVKA_PARALLEL_MIN_EDGES;

    let mut sets = UnionFind::new(slots);
    let mut forest = Vec::new();
    let mut weight_sum = W::from(0u8);
    let mut remaining = graph.node_count();

    while remaining > 1 {
        // Snapshot the component root of every index so the scan is read-only.
        let root_of: Vec<usize> = (0..slots).map(|i| sets.find(i)).collect();

        let cheapest: Vec<Candidate<W>> = if scan_in_parallel {
            edge_list
                .par_iter()
                .fold(
                    || empty_slots::<W>(slots),
                    |mut best, &edge| {
                        record(&root_of, &mut best, edge);
                        best
                    },
                )
                .reduce(
                    || empty_slots::<W>(slots),
                    |mut merged, partial| {
                        for (slot, cand) in merged.iter_mut().zip(partial) {
                            if let Some(cand) = cand {
                                offer(slot, cand);
                            }
                        }
                        merged
                    },
                )
        } else {
            let mut best = empty_slots::<W>(slots);
            for &edge in &edge_list {
                record(&root_of, &mut best, edge);
            }
            best
        };

        // Merge along the chosen edges. Roots are re-checked because an
        // earlier merge in this round may already have joined the endpoints.
        let mut merged_any = false;
        for (a, b, w) in cheapest.into_iter().flatten() {
            let (ra, rb) = (sets.find(a.index()), sets.find(b.index()));
            if ra == rb {
                continue;
            }
            sets.union(ra, rb);
            forest.push(MstEdge { u: a, v: b, weight: w });
            weight_sum += w;
            remaining -= 1;
            merged_any = true;
        }

        // No progress means the remaining components are not connected.
        if !merged_any {
            break;
        }
    }

    Ok((forest, weight_sum))
}
/// Computes a minimum spanning forest of `graph` with Kruskal's algorithm.
///
/// Edges are visited in ascending weight order (equal weights keep their
/// original relative order). An edge is kept when its endpoints lie in
/// different components of the forest built so far.
///
/// Returns the selected edges together with the sum of their weights.
///
/// # Errors
///
/// Returns an invalid-graph error when `graph` has no nodes.
pub fn kruskal_mst<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> Result<(Vec<MstEdge<W>>, W)>
where
    W: Copy + PartialOrd + Add<Output = W> + AddAssign + From<u8> + Ord,
    Ty: GraphConstructor<A, W>,
{
    if graph.node_count() == 0 {
        return Err(GraphinaError::invalid_graph(
            "Graph is empty, cannot compute MST.",
        ));
    }

    let mut by_weight: Vec<(NodeId, NodeId, W)> =
        graph.edges().map(|(a, b, w)| (a, b, *w)).collect();
    // Stable sort, so ties are processed in the graph's edge order.
    by_weight.sort_by_key(|&(_, _, w)| w);

    let mut sets = UnionFind::new(index_bound(graph));
    let mut forest = Vec::new();
    let mut weight_sum = W::from(0u8);

    for (a, b, w) in by_weight {
        let (ra, rb) = (sets.find(a.index()), sets.find(b.index()));
        if ra == rb {
            // Both endpoints already connected; the edge would close a cycle.
            continue;
        }
        sets.union(ra, rb);
        forest.push(MstEdge { u: a, v: b, weight: w });
        weight_sum += w;
    }

    Ok((forest, weight_sum))
}
/// Computes a minimum spanning forest of `graph` with Prim's algorithm.
///
/// A tree is grown from every node not yet reached, so a disconnected graph
/// yields one tree per component. Each edge is treated as usable in both
/// directions. In every returned `MstEdge`, `u` is the endpoint that was
/// already in the tree and `v` is the node the edge added.
///
/// Returns the selected edges together with the sum of their weights.
///
/// # Errors
///
/// Returns an invalid-graph error when `graph` has no nodes.
pub fn prim_mst<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> Result<(Vec<MstEdge<W>>, W)>
where
    W: Copy + PartialOrd + Add<Output = W> + AddAssign + From<u8> + Ord,
    Ty: GraphConstructor<A, W>,
    NodeId: Ord,
{
    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    if graph.node_count() == 0 {
        return Err(GraphinaError::invalid_graph(
            "Graph is empty, cannot compute MST.",
        ));
    }

    // Index-addressed adjacency lists, with every edge listed at both ends.
    let slots = index_bound(graph);
    let mut neighbors: Vec<Vec<(NodeId, W)>> = vec![Vec::new(); slots];
    for (a, b, w) in graph.edges() {
        neighbors[a.index()].push((b, *w));
        neighbors[b.index()].push((a, *w));
    }

    let mut visited = vec![false; slots];
    let mut forest = Vec::new();
    let mut weight_sum = W::from(0u8);

    for root in graph.node_ids() {
        // Mark the node; if it was already marked, an earlier tree covers it.
        if std::mem::replace(&mut visited[root.index()], true) {
            continue;
        }

        // Min-heap (via `Reverse`) of candidate edges leaving the tree.
        let mut frontier: BinaryHeap<_> = neighbors[root.index()]
            .iter()
            .map(|&(next, w)| Reverse((w, root, next)))
            .collect();

        while let Some(Reverse((w, a, b))) = frontier.pop() {
            let (from, to) = match (visited[a.index()], visited[b.index()]) {
                // Stale entry: both ends joined the tree after it was queued.
                (true, true) => continue,
                (true, false) => (a, b),
                (false, _) => (b, a),
            };

            visited[to.index()] = true;
            forest.push(MstEdge {
                u: from,
                v: to,
                weight: w,
            });
            weight_sum += w;

            frontier.extend(
                neighbors[to.index()]
                    .iter()
                    .filter(|&&(next, _)| !visited[next.index()])
                    .map(|&(next, next_weight)| Reverse((next_weight, to, next))),
            );
        }
    }

    Ok((forest, weight_sum))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::types::Graph;
    use ordered_float::OrderedFloat;

    /// Graph type shared by every test in this module.
    type TestGraph = Graph<i32, OrderedFloat<f64>>;

    /// Builds a graph with nodes `0..node_count` (inserted in that order) and
    /// the given `(from, to, weight)` edges, where `from` and `to` are
    /// positions in insertion order.
    fn graph_from_edges(node_count: i32, edges: &[(usize, usize, f64)]) -> TestGraph {
        let mut g = TestGraph::new();
        let ids: Vec<_> = (0..node_count).map(|i| g.add_node(i)).collect();
        for &(a, b, w) in edges {
            g.add_edge(ids[a], ids[b], OrderedFloat(w));
        }
        g
    }

    /// Prim on a six-node tree keeps all five edges and agrees with Kruskal.
    #[test]
    fn test_prim_mst_undirected_target_edges() {
        let g = graph_from_edges(
            6,
            &[
                (0, 4, 5.0),
                (0, 5, 2.0),
                (1, 5, 1.0),
                (2, 4, 10.0),
                (3, 4, 1.0),
            ],
        );

        let (prim_edges, prim_weight) = prim_mst(&g).unwrap();
        assert_eq!(prim_edges.len(), 5);
        assert_eq!(prim_weight, OrderedFloat(19.0));

        let (kruskal_edges, kruskal_weight) = kruskal_mst(&g).unwrap();
        assert_eq!(prim_edges.len(), kruskal_edges.len());
        assert_eq!(prim_weight, kruskal_weight);
    }

    /// Borůvka on a dense eleven-node graph with many tied weights matches
    /// Kruskal in both edge count and total weight.
    #[test]
    fn test_boruvka_mst_canonical_root_grouping() {
        let g = graph_from_edges(
            11,
            &[
                (0, 2, 4.0),
                (0, 3, 1.0),
                (0, 4, 4.0),
                (0, 5, 4.0),
                (0, 6, 3.0),
                (1, 2, 8.0),
                (1, 3, 6.0),
                (1, 4, 5.0),
                (1, 5, 4.0),
                (1, 6, 10.0),
                (1, 7, 1.0),
                (1, 8, 7.0),
                (2, 3, 7.0),
                (2, 4, 7.0),
                (2, 5, 9.0),
                (2, 8, 8.0),
                (2, 9, 1.0),
                (2, 10, 3.0),
                (3, 4, 9.0),
                (3, 5, 10.0),
                (3, 10, 5.0),
                (4, 6, 5.0),
                (4, 9, 7.0),
                (4, 10, 5.0),
                (5, 6, 7.0),
                (5, 7, 7.0),
                (5, 8, 5.0),
                (5, 9, 4.0),
                (5, 10, 5.0),
                (6, 7, 6.0),
                (6, 8, 2.0),
                (6, 9, 5.0),
                (6, 10, 4.0),
                (7, 9, 2.0),
                (7, 10, 9.0),
                (8, 10, 9.0),
                (9, 10, 10.0),
            ],
        );

        let (boruvka_edges, boruvka_weight) = boruvka_mst(&g).unwrap();
        assert_eq!(boruvka_edges.len(), 10);
        assert_eq!(boruvka_weight, OrderedFloat(25.0));

        let (kruskal_edges, kruskal_weight) = kruskal_mst(&g).unwrap();
        assert_eq!(boruvka_edges.len(), kruskal_edges.len());
        assert_eq!(boruvka_weight, kruskal_weight);
    }

    /// Kruskal on a small four-node graph picks three edges of total weight 7.
    #[test]
    fn test_kruskal_mst() {
        let mut graph = TestGraph::new();
        let n1 = graph.add_node(1);
        let n2 = graph.add_node(2);
        let n3 = graph.add_node(3);
        let n4 = graph.add_node(4);
        graph.add_edge(n1, n2, OrderedFloat(1.0));
        graph.add_edge(n1, n3, OrderedFloat(3.0));
        graph.add_edge(n2, n3, OrderedFloat(2.0));
        graph.add_edge(n2, n4, OrderedFloat(4.0));
        graph.add_edge(n3, n4, OrderedFloat(5.0));

        let (edges, _) = kruskal_mst(&graph).expect("MST should exist");
        assert_eq!(edges.len(), 3);

        let total_weight: f64 = edges.iter().map(|edge| edge.weight.0).sum();
        assert!((total_weight - 7.0).abs() < 1e-6);
    }

    /// All three algorithms return the same spanning forest size and weight
    /// on a graph with two components.
    #[test]
    fn test_mst_disconnected_forest() {
        let g = graph_from_edges(
            5,
            &[(0, 1, 1.0), (1, 2, 2.0), (0, 2, 5.0), (3, 4, 3.0)],
        );

        let (k_edges, k_weight) = kruskal_mst(&g).unwrap();
        let (p_edges, p_weight) = prim_mst(&g).unwrap();
        let (b_edges, b_weight) = boruvka_mst(&g).unwrap();

        assert_eq!(k_edges.len(), 3);
        assert_eq!(k_weight, OrderedFloat(6.0));
        assert_eq!(p_edges.len(), 3);
        assert_eq!(p_weight, k_weight);
        assert_eq!(b_edges.len(), 3);
        assert_eq!(b_weight, k_weight);
    }

    /// All three algorithms still work after a node has been removed from
    /// the graph before the edges are added.
    #[test]
    fn test_mst_sparse_indices_after_removal() {
        let mut g = TestGraph::new();
        let nodes: Vec<_> = (0..4).map(|i| g.add_node(i)).collect();
        g.remove_node(nodes[1]);
        g.add_edge(nodes[0], nodes[2], OrderedFloat(1.0));
        g.add_edge(nodes[2], nodes[3], OrderedFloat(2.0));

        let (k_edges, k_weight) = kruskal_mst(&g).unwrap();
        let (p_edges, p_weight) = prim_mst(&g).unwrap();
        let (b_edges, b_weight) = boruvka_mst(&g).unwrap();

        assert_eq!(k_edges.len(), 2);
        assert_eq!(k_weight, OrderedFloat(3.0));
        assert_eq!(p_edges.len(), 2);
        assert_eq!(p_weight, k_weight);
        assert_eq!(b_edges.len(), 2);
        assert_eq!(b_weight, k_weight);
    }

    /// Prim on a triangle keeps exactly two edges.
    #[test]
    fn test_prim_mst() {
        let mut graph = TestGraph::new();
        let n1 = graph.add_node(1);
        let n2 = graph.add_node(2);
        let n3 = graph.add_node(3);
        graph.add_edge(n1, n2, OrderedFloat(1.0));
        graph.add_edge(n1, n3, OrderedFloat(3.0));
        graph.add_edge(n2, n3, OrderedFloat(2.0));

        let (edges, _) = prim_mst(&graph).expect("MST should exist");
        assert_eq!(edges.len(), 2);
    }
}