use crate::core::error::{GraphinaError, Result};
use crate::core::types::{BaseGraph, GraphConstructor, NodeId};
use rayon::prelude::*;
use std::cmp::Ordering;
use std::convert::From;
use std::ops::{Add, AddAssign, Sub};
/// Disjoint-set forest over the indices `0..n`, with union by rank and path
/// compression. The MST routines use it to track which nodes are already
/// connected.
struct UnionFind {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Union-by-rank counter; only the values stored at roots are consulted.
    rank: Vec<usize>,
}
impl UnionFind {
    /// Creates `n` singleton sets, one per index in `0..n`.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        Self { parent, rank }
    }

    /// Returns the representative (root) of the set containing `i`, and
    /// re-points every node visited on the way directly at that root.
    ///
    /// Panics if `i` is not below the size passed to [`UnionFind::new`].
    fn find(&mut self, i: usize) -> usize {
        // Pass 1: climb to the root.
        let mut root = i;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Pass 2: compress the path just walked.
        let mut cursor = i;
        while cursor != root {
            let up = self.parent[cursor];
            self.parent[cursor] = root;
            cursor = up;
        }
        root
    }

    /// Merges the sets containing `i` and `j`; a no-op if they already share
    /// a root. The lower-rank root is attached under the higher-rank one. On a
    /// tie, `j`'s root goes under `i`'s root and that root's rank grows by one.
    fn union(&mut self, i: usize, j: usize) {
        let root_i = self.find(i);
        let root_j = self.find(j);
        if root_i == root_j {
            return;
        }
        let (keep, attach) = match self.rank[root_i].cmp(&self.rank[root_j]) {
            Ordering::Less => (root_j, root_i),
            Ordering::Greater | Ordering::Equal => (root_i, root_j),
        };
        if self.rank[keep] == self.rank[attach] {
            self.rank[keep] += 1;
        }
        self.parent[attach] = keep;
    }
}
/// One edge selected for a minimum spanning tree (or forest).
///
/// `weight` is the weight stored on the corresponding graph edge. For edges
/// produced by [`prim_mst`], `u` is the endpoint that was already in the tree
/// and `v` is the node it attached. The Borůvka and Kruskal routines report the
/// endpoints in the order yielded by the graph's edge iterator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MstEdge<W> {
    /// First endpoint of the edge.
    pub u: NodeId,
    /// Second endpoint of the edge.
    pub v: NodeId,
    /// Weight of the edge.
    pub weight: W,
}
/// Computes a minimum spanning tree (forest, for disconnected input) with
/// Borůvka's algorithm.
///
/// Edge direction is ignored. In each round, every component picks its
/// cheapest edge leaving the component. Ties go to the edge that appears first
/// in `graph.edges()`, so all components agree on one total order. Chosen
/// edges are then merged through a union-find that skips any edge whose
/// endpoints were already joined earlier in the same round. Rounds stop when
/// one component remains or no edge crosses between components.
///
/// # Errors
///
/// Returns an `invalid_graph` error if the graph has no nodes.
///
/// # Returns
///
/// The selected edges and the sum of their weights (starting from `W::from(0)`).
pub fn boruvka_mst<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> Result<(Vec<MstEdge<W>>, W)>
where
    W: Copy + PartialOrd + Add<Output = W> + AddAssign + Sub<Output = W> + From<u8> + Send + Sync,
    Ty: GraphConstructor<A, W>,
{
    if graph.node_count() == 0 {
        return Err(GraphinaError::invalid_graph(
            "Graph is empty, cannot compute MST.",
        ));
    }
    // The union-find is indexed by `NodeId::index()`, so size it by the largest
    // index rather than `node_count()`. The two differ whenever indices are not
    // contiguous (presumably possible after node removal — not visible here).
    let bound = graph
        .nodes()
        .map(|(node, _)| node.index() + 1)
        .max()
        .unwrap_or(0);
    let all_edges: Vec<(NodeId, NodeId, W)> = graph.edges().map(|(u, v, w)| (u, v, *w)).collect();
    let mut uf = UnionFind::new(bound);
    let mut mst_edges = Vec::new();
    let mut total_weight = W::from(0u8);
    // Count only real nodes; unused index slots stay isolated and never merge.
    let mut components = graph.node_count();
    while components > 1 {
        let roots: Vec<usize> = (0..bound).map(|i| uf.find(i)).collect();
        let cheapest = boruvka_cheapest_per_root(&roots, &all_edges);
        let mut found = false;
        for (u, v, w) in cheapest.into_iter().flatten() {
            let ru = uf.find(u.index());
            let rv = uf.find(v.index());
            // Two components may have picked the same edge, or an earlier merge
            // in this round may already have joined them.
            if ru != rv {
                uf.union(ru, rv);
                mst_edges.push(MstEdge { u, v, weight: w });
                total_weight += w;
                components -= 1;
                found = true;
            }
        }
        if !found {
            // No edge crosses between components: the graph is disconnected.
            break;
        }
    }
    Ok((mst_edges, total_weight))
}

/// For every union-find root `r` (given `roots[i]` = root of index `i`), returns
/// at position `r` the cheapest edge with exactly one endpoint in `r`'s
/// component, or `None` at positions that are not roots or have no such edge.
///
/// Each component's candidates are scanned in `edges` order and a candidate
/// replaces the current best only if strictly lighter. So ties go to the
/// earliest edge, which keeps the choice consistent across components. Edges
/// inside a single component, including self-loops, are ignored.
fn boruvka_cheapest_per_root<W>(
    roots: &[usize],
    edges: &[(NodeId, NodeId, W)],
) -> Vec<Option<(NodeId, NodeId, W)>>
where
    W: Copy + PartialOrd + Send + Sync,
{
    // One O(E) pass: file each crossing edge under both endpoint components,
    // preserving edge order within each bucket.
    let mut incident: Vec<Vec<usize>> = vec![Vec::new(); roots.len()];
    for (idx, &(u, v, _)) in edges.iter().enumerate() {
        let comp_u = roots[u.index()];
        let comp_v = roots[v.index()];
        if comp_u != comp_v {
            incident[comp_u].push(idx);
            incident[comp_v].push(idx);
        }
    }
    incident
        .par_iter()
        .map(|bucket| {
            let mut min_edge: Option<(NodeId, NodeId, W)> = None;
            for &idx in bucket {
                let (u, v, w) = edges[idx];
                match min_edge {
                    Some((_, _, current)) if w < current => min_edge = Some((u, v, w)),
                    None => min_edge = Some((u, v, w)),
                    _ => {}
                }
            }
            min_edge
        })
        .collect()
}
/// Computes a minimum spanning tree (forest, for disconnected input) with
/// Kruskal's algorithm.
///
/// Edge direction is ignored. Edges are sorted by weight with a stable sort,
/// so equal-weight edges keep their `graph.edges()` order. Each edge that
/// joins two different union-find components is accepted.
///
/// # Errors
///
/// Returns an `invalid_graph` error if the graph has no nodes.
///
/// # Returns
///
/// The selected edges and the sum of their weights (starting from `W::from(0)`).
pub fn kruskal_mst<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> Result<(Vec<MstEdge<W>>, W)>
where
    W: Copy + PartialOrd + Add<Output = W> + AddAssign + From<u8> + Ord,
    Ty: GraphConstructor<A, W>,
{
    if graph.node_count() == 0 {
        return Err(GraphinaError::invalid_graph(
            "Graph is empty, cannot compute MST.",
        ));
    }
    let node_count = graph.node_count();
    // The union-find is indexed by `NodeId::index()`, so size it by the largest
    // index rather than `node_count()`. The two differ whenever indices are not
    // contiguous (presumably possible after node removal — not visible here).
    let bound = graph
        .nodes()
        .map(|(node, _)| node.index() + 1)
        .max()
        .unwrap_or(0);
    let mut edges: Vec<(NodeId, NodeId, W)> = graph.edges().map(|(u, v, w)| (u, v, *w)).collect();
    // `sort_by_key` is stable: ties keep the graph's edge order.
    edges.sort_by_key(|&(_, _, w)| w);
    let mut uf = UnionFind::new(bound);
    let mut mst_edges = Vec::with_capacity(node_count - 1);
    let mut total_weight = W::from(0u8);
    for (u, v, w) in edges {
        let ru = uf.find(u.index());
        let rv = uf.find(v.index());
        if ru != rv {
            uf.union(ru, rv);
            mst_edges.push(MstEdge { u, v, weight: w });
            total_weight += w;
            // A spanning tree has exactly `node_count - 1` edges. Once that many
            // are chosen, every remaining edge would close a cycle.
            if mst_edges.len() + 1 == node_count {
                break;
            }
        }
    }
    Ok((mst_edges, total_weight))
}
/// Computes a minimum spanning tree (forest, for disconnected input) with
/// Prim's algorithm.
///
/// Edge direction is ignored: every edge is entered into the adjacency in both
/// directions. A tree is grown from each node of `graph.nodes()` that is not
/// yet covered. A min-heap keyed on `(weight, from, to)` supplies the frontier,
/// and stale entries whose target is already in the tree are discarded on pop.
/// In each emitted [`MstEdge`], `u` is the tree-side endpoint and `v` the newly
/// attached node.
///
/// # Errors
///
/// Returns an `invalid_graph` error if the graph has no nodes.
///
/// # Returns
///
/// The selected edges and the sum of their weights (starting from `W::from(0)`).
pub fn prim_mst<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> Result<(Vec<MstEdge<W>>, W)>
where
    W: Copy + PartialOrd + Add<Output = W> + AddAssign + From<u8> + Ord,
    Ty: GraphConstructor<A, W>,
    NodeId: Ord,
{
    use std::cmp::Reverse;
    use std::collections::{BinaryHeap, HashMap, HashSet};

    if graph.node_count() == 0 {
        return Err(GraphinaError::invalid_graph(
            "Graph is empty, cannot compute MST.",
        ));
    }

    // Symmetric adjacency lists: each edge is reachable from both endpoints.
    let mut neighbours: HashMap<NodeId, Vec<(NodeId, W)>> = HashMap::new();
    for (a, b, weight) in graph.edges() {
        let weight = *weight;
        neighbours.entry(a).or_default().push((b, weight));
        neighbours.entry(b).or_default().push((a, weight));
    }

    let mut visited: HashSet<NodeId> = HashSet::new();
    let mut tree: Vec<MstEdge<W>> = Vec::new();
    let mut total = W::from(0u8);

    for root in graph.nodes().map(|(node, _)| node) {
        // `insert` returns false when the node already belongs to some tree.
        if !visited.insert(root) {
            continue;
        }
        let mut frontier = BinaryHeap::new();
        if let Some(list) = neighbours.get(&root) {
            frontier.extend(
                list.iter()
                    .map(|&(next, weight)| Reverse((weight, root, next))),
            );
        }
        // Every queued `from` is already in the tree, so an entry is stale
        // exactly when its `to` has been reached in the meantime.
        while let Some(Reverse((weight, from, to))) = frontier.pop() {
            if !visited.insert(to) {
                continue;
            }
            tree.push(MstEdge {
                u: from,
                v: to,
                weight,
            });
            total += weight;
            if let Some(list) = neighbours.get(&to) {
                frontier.extend(
                    list.iter()
                        .filter(|(next, _)| !visited.contains(next))
                        .map(|&(next, weight)| Reverse((weight, to, next))),
                );
            }
        }
    }

    Ok((tree, total))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::types::Graph;
    use ordered_float::OrderedFloat;

    /// Builds a graph whose nodes carry payloads `0..node_count` and whose
    /// edges are given as `(u, v, weight)` with `u`/`v` being positions in node
    /// insertion order.
    fn build_graph(
        node_count: i32,
        edges: &[(usize, usize, f64)],
    ) -> Graph<i32, OrderedFloat<f64>> {
        let mut g: Graph<i32, OrderedFloat<f64>> = Graph::new();
        let nodes: Vec<_> = (0..node_count).map(|i| g.add_node(i)).collect();
        for &(u, v, w) in edges {
            g.add_edge(nodes[u], nodes[v], OrderedFloat(w));
        }
        g
    }

    #[test]
    fn test_prim_mst_undirected_target_edges() {
        let g = build_graph(
            6,
            &[
                (0, 4, 5.0),
                (0, 5, 2.0),
                (1, 5, 1.0),
                (2, 4, 10.0),
                (3, 4, 1.0),
            ],
        );
        let (prim_edges, prim_weight) = prim_mst(&g).unwrap();
        assert_eq!(prim_edges.len(), 5);
        assert_eq!(prim_weight, OrderedFloat(19.0));
        let (kruskal_edges, kruskal_weight) = kruskal_mst(&g).unwrap();
        assert_eq!(prim_edges.len(), kruskal_edges.len());
        assert_eq!(prim_weight, kruskal_weight);
    }

    #[test]
    fn test_boruvka_mst_canonical_root_grouping() {
        let g = build_graph(
            11,
            &[
                (0, 2, 4.0),
                (0, 3, 1.0),
                (0, 4, 4.0),
                (0, 5, 4.0),
                (0, 6, 3.0),
                (1, 2, 8.0),
                (1, 3, 6.0),
                (1, 4, 5.0),
                (1, 5, 4.0),
                (1, 6, 10.0),
                (1, 7, 1.0),
                (1, 8, 7.0),
                (2, 3, 7.0),
                (2, 4, 7.0),
                (2, 5, 9.0),
                (2, 8, 8.0),
                (2, 9, 1.0),
                (2, 10, 3.0),
                (3, 4, 9.0),
                (3, 5, 10.0),
                (3, 10, 5.0),
                (4, 6, 5.0),
                (4, 9, 7.0),
                (4, 10, 5.0),
                (5, 6, 7.0),
                (5, 7, 7.0),
                (5, 8, 5.0),
                (5, 9, 4.0),
                (5, 10, 5.0),
                (6, 7, 6.0),
                (6, 8, 2.0),
                (6, 9, 5.0),
                (6, 10, 4.0),
                (7, 9, 2.0),
                (7, 10, 9.0),
                (8, 10, 9.0),
                (9, 10, 10.0),
            ],
        );
        let (boruvka_edges, boruvka_weight) = boruvka_mst(&g).unwrap();
        assert_eq!(boruvka_edges.len(), 10);
        assert_eq!(boruvka_weight, OrderedFloat(25.0));
        let (kruskal_edges, kruskal_weight) = kruskal_mst(&g).unwrap();
        assert_eq!(boruvka_edges.len(), kruskal_edges.len());
        assert_eq!(boruvka_weight, kruskal_weight);
    }

    #[test]
    fn test_kruskal_mst() {
        let mut graph = Graph::<i32, OrderedFloat<f64>>::new();
        let n1 = graph.add_node(1);
        let n2 = graph.add_node(2);
        let n3 = graph.add_node(3);
        let n4 = graph.add_node(4);
        graph.add_edge(n1, n2, OrderedFloat(1.0));
        graph.add_edge(n1, n3, OrderedFloat(3.0));
        graph.add_edge(n2, n3, OrderedFloat(2.0));
        graph.add_edge(n2, n4, OrderedFloat(4.0));
        graph.add_edge(n3, n4, OrderedFloat(5.0));
        let mst = kruskal_mst(&graph).expect("MST should exist");
        assert_eq!(mst.0.len(), 3);
        let total_weight: f64 = mst.0.iter().map(|e| e.weight.0).sum();
        assert!((total_weight - 7.0).abs() < 1e-6);
        assert_eq!(mst.1, OrderedFloat(7.0));
    }

    #[test]
    fn test_prim_mst() {
        let mut graph = Graph::<i32, OrderedFloat<f64>>::new();
        let n1 = graph.add_node(1);
        let n2 = graph.add_node(2);
        let n3 = graph.add_node(3);
        graph.add_edge(n1, n2, OrderedFloat(1.0));
        graph.add_edge(n1, n3, OrderedFloat(3.0));
        graph.add_edge(n2, n3, OrderedFloat(2.0));
        let mst = prim_mst(&graph).expect("MST should exist");
        assert_eq!(mst.0.len(), 2);
        // The 1.0 and 2.0 edges form the unique MST.
        assert_eq!(mst.1, OrderedFloat(3.0));
    }

    #[test]
    fn test_empty_graph_is_rejected() {
        let g = build_graph(0, &[]);
        assert!(boruvka_mst(&g).is_err());
        assert!(kruskal_mst(&g).is_err());
        assert!(prim_mst(&g).is_err());
    }

    #[test]
    fn test_single_node_has_empty_mst() {
        let g = build_graph(1, &[]);
        for (edges, weight) in [
            boruvka_mst(&g).unwrap(),
            kruskal_mst(&g).unwrap(),
            prim_mst(&g).unwrap(),
        ] {
            assert!(edges.is_empty());
            assert_eq!(weight, OrderedFloat(0.0));
        }
    }

    #[test]
    fn test_disconnected_graph_yields_spanning_forest() {
        let g = build_graph(4, &[(0, 1, 1.0), (2, 3, 2.0)]);
        for (edges, weight) in [
            boruvka_mst(&g).unwrap(),
            kruskal_mst(&g).unwrap(),
            prim_mst(&g).unwrap(),
        ] {
            assert_eq!(edges.len(), 2);
            assert_eq!(weight, OrderedFloat(3.0));
        }
    }
}