use crate::core::error::{GraphinaError, Result};
use crate::core::types::{BaseGraph, GraphConstructor, NodeId};
use rayon::prelude::*;
use std::cmp::Ordering;
use std::convert::From;
use std::ops::{Add, AddAssign, Sub};
/// Disjoint-set forest over the indices `0..n`, using union by rank and
/// path compression.
struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Rank (an upper bound on tree height) stored per element; only the
    /// value at a root is consulted when merging.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one per index in `0..n`.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        Self { parent, rank }
    }

    /// Returns the representative (root) of the set containing `i`.
    ///
    /// Every element visited on the way up is re-pointed directly at the
    /// root. Panics if `i` is out of range, as slice indexing does.
    fn find(&mut self, i: usize) -> usize {
        // First pass: climb to the root.
        let mut root = i;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: compress the path that was just walked.
        let mut node = i;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `i` and `j`; a no-op when they already
    /// share a root.
    fn union(&mut self, i: usize, j: usize) {
        let root_i = self.find(i);
        let root_j = self.find(j);
        if root_i == root_j {
            return;
        }
        // Attach the lower-rank root beneath the higher-rank one. On a tie
        // `root_i` stays the root and its rank grows by one.
        match self.rank[root_i].cmp(&self.rank[root_j]) {
            Ordering::Less => {
                self.parent[root_i] = root_j;
            }
            Ordering::Greater => {
                self.parent[root_j] = root_i;
            }
            Ordering::Equal => {
                self.parent[root_j] = root_i;
                self.rank[root_i] += 1;
            }
        }
    }
}
/// A single edge reported by the spanning-tree routines in this file.
///
/// `W` is the edge-weight type of the graph the edge was taken from.
#[derive(Debug, Clone, Copy)]
pub struct MstEdge<W> {
/// One endpoint of the edge.
pub u: NodeId,
/// The other endpoint of the edge.
pub v: NodeId,
/// Weight of the edge, copied out of the graph.
pub weight: W,
}
/// Computes a minimum spanning tree with Borůvka's algorithm.
///
/// Each round finds, for every current component, the lightest edge that
/// leaves it (the per-component scans run in parallel via rayon), then merges
/// components along those edges sequentially. Rounds repeat until a single
/// component remains or no merge is possible.
///
/// Edge direction is ignored: an edge `(u, v)` connects the components of
/// `u` and `v` regardless of the order in which it is stored.
///
/// # Returns
///
/// The selected edges together with the sum of their weights. If the graph
/// is disconnected the loop stops once no component has an outgoing edge, so
/// the result is a minimum spanning forest rather than a single tree.
///
/// # Errors
///
/// Returns `GraphinaError::invalid_graph` when the graph has no nodes.
///
/// # Panics
///
/// Node indices are used directly as union-find slots of a table sized by
/// `graph.node_count()`.
/// NOTE(review): this assumes `NodeId::index()` is always below
/// `node_count()`; confirm this holds for graphs that have had nodes removed.
pub fn boruvka_mst<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> Result<(Vec<MstEdge<W>>, W)>
where
    W: Copy + PartialOrd + Add<Output = W> + AddAssign + Sub<Output = W> + From<u8> + Send + Sync,
    Ty: GraphConstructor<A, W>,
{
    if graph.node_count() == 0 {
        return Err(GraphinaError::invalid_graph(
            "Graph is empty, cannot compute MST.",
        ));
    }
    let n = graph.node_count();
    let all_edges: Vec<(NodeId, NodeId, W)> = graph.edges().map(|(u, v, w)| (u, v, *w)).collect();
    let mut uf = UnionFind::new(n);
    let mut mst_edges = Vec::new();
    let mut total_weight = W::from(0u8);
    let mut components = n;

    while components > 1 {
        // Resolve every node to its *root* before the parallel phase. The raw
        // `parent` table is only compressed lazily, so it can still point at
        // stale intermediate nodes; using it as a component map splits real
        // components apart and can stall the algorithm before it is finished.
        let roots: Vec<usize> = (0..n).map(|i| uf.find(i)).collect();

        let cheapest: Vec<Option<(NodeId, NodeId, W)>> = (0..n)
            .into_par_iter()
            .map(|comp| {
                // Only roots label a component; other indices own no nodes.
                if roots[comp] != comp {
                    return None;
                }
                let mut min_edge: Option<(NodeId, NodeId, W)> = None;
                for &(u, v, w) in &all_edges {
                    let comp_u = roots[u.index()];
                    let comp_v = roots[v.index()];
                    // Keep only edges with exactly one endpoint in `comp`.
                    let leaves_comp = (comp_u == comp && comp_v != comp)
                        || (comp_v == comp && comp_u != comp);
                    if !leaves_comp {
                        continue;
                    }
                    // Strict `<` keeps the first-seen edge among equal weights.
                    if min_edge.map_or(true, |(_, _, best)| w < best) {
                        min_edge = Some((u, v, w));
                    }
                }
                min_edge
            })
            .collect();

        // Merge sequentially. Re-checking the roots here rejects candidates
        // that an earlier merge in this same round has made redundant, which
        // also prevents cycles when weights tie.
        let mut found = false;
        for (u, v, w) in cheapest.into_iter().flatten() {
            let ru = uf.find(u.index());
            let rv = uf.find(v.index());
            if ru != rv {
                uf.union(ru, rv);
                mst_edges.push(MstEdge { u, v, weight: w });
                total_weight += w;
                components -= 1;
                found = true;
            }
        }
        if !found {
            // No component has an outgoing edge: the graph is disconnected.
            break;
        }
    }
    Ok((mst_edges, total_weight))
}
/// Computes a minimum spanning tree with Kruskal's algorithm.
///
/// All edges are copied out of the graph, sorted by ascending weight with a
/// stable sort, and then accepted one by one whenever their endpoints still
/// lie in different union-find sets.
///
/// # Returns
///
/// The accepted edges, in the order they were accepted, and the sum of their
/// weights. Edges that would close a cycle are skipped, so a disconnected
/// graph yields a spanning forest.
///
/// # Errors
///
/// Returns `GraphinaError::invalid_graph` when the graph has no nodes.
pub fn kruskal_mst<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> Result<(Vec<MstEdge<W>>, W)>
where
    W: Copy + PartialOrd + Add<Output = W> + AddAssign + From<u8> + Ord,
    Ty: GraphConstructor<A, W>,
{
    let node_total = graph.node_count();
    if node_total == 0 {
        return Err(GraphinaError::invalid_graph(
            "Graph is empty, cannot compute MST.",
        ));
    }

    let mut by_weight: Vec<(NodeId, NodeId, W)> = graph
        .edges()
        .map(|(src, dst, wt)| (src, dst, *wt))
        .collect();
    // Stable sort: edges of equal weight keep the graph's iteration order.
    by_weight.sort_by_key(|&(_, _, wt)| wt);

    let mut forest = UnionFind::new(node_total);
    let mut chosen = Vec::new();
    let mut weight_sum = W::from(0u8);

    for (src, dst, wt) in by_weight {
        let root_src = forest.find(src.index());
        let root_dst = forest.find(dst.index());
        if root_src == root_dst {
            // Both endpoints are already connected; this edge would form a cycle.
            continue;
        }
        forest.union(root_src, root_dst);
        chosen.push(MstEdge {
            u: src,
            v: dst,
            weight: wt,
        });
        weight_sum += wt;
    }

    Ok((chosen, weight_sum))
}
/// Computes a minimum spanning tree (or forest) with Prim's algorithm.
///
/// An adjacency list is built once from `graph.edges()`, recording every edge
/// under *both* of its endpoints so that edges are followed regardless of the
/// direction in which they are stored. A tree is then grown from each node
/// that has not been reached yet, always taking the lightest edge in a
/// min-heap of candidate `(weight, from, to)` entries.
///
/// # Returns
///
/// The selected edges and the sum of their weights. In each [`MstEdge`], `u`
/// is the endpoint that was already in the tree and `v` is the node the edge
/// attached. Because a new tree is started from every unreached node, a
/// disconnected graph yields a spanning forest.
///
/// # Errors
///
/// Returns `GraphinaError::invalid_graph` when the graph has no nodes.
///
/// # Panics
///
/// Node indices are used directly as slots in tables sized by
/// `graph.node_count()`.
/// NOTE(review): this assumes `NodeId::index()` is always below
/// `node_count()`; confirm this holds for graphs that have had nodes removed.
pub fn prim_mst<A, W, Ty>(graph: &BaseGraph<A, W, Ty>) -> Result<(Vec<MstEdge<W>>, W)>
where
    W: Copy + PartialOrd + Add<Output = W> + AddAssign + From<u8> + Ord,
    Ty: GraphConstructor<A, W>,
    NodeId: Ord,
{
    if graph.node_count() == 0 {
        return Err(GraphinaError::invalid_graph(
            "Graph is empty, cannot compute MST.",
        ));
    }
    let n = graph.node_count();

    // One pass over the edges instead of rescanning them for every node that
    // joins the tree. Each edge is listed under both endpoints.
    let mut adjacency: Vec<Vec<(NodeId, W)>> = (0..n).map(|_| Vec::new()).collect();
    for (u, v, w) in graph.edges().map(|(u, v, w)| (u, v, *w)) {
        adjacency[u.index()].push((v, w));
        adjacency[v.index()].push((u, w));
    }

    let mut mst_edges = Vec::new();
    let mut total_weight = W::from(0u8);
    let mut in_tree = vec![false; n];
    // Min-heap of candidate edges; it is always drained before the next start
    // node is considered, so it can be shared across trees.
    let mut heap: std::collections::BinaryHeap<std::cmp::Reverse<(W, NodeId, NodeId)>> =
        std::collections::BinaryHeap::new();

    for start in graph.nodes().map(|(node, _)| node) {
        if in_tree[start.index()] {
            continue;
        }
        in_tree[start.index()] = true;
        for &(neighbor, weight) in &adjacency[start.index()] {
            if !in_tree[neighbor.index()] {
                heap.push(std::cmp::Reverse((weight, start, neighbor)));
            }
        }

        while let Some(std::cmp::Reverse((w, from, to))) = heap.pop() {
            // `from` was in the tree when the entry was pushed; the entry is
            // stale if `to` has since been reached through a lighter edge.
            if in_tree[to.index()] {
                continue;
            }
            in_tree[to.index()] = true;
            mst_edges.push(MstEdge {
                u: from,
                v: to,
                weight: w,
            });
            total_weight += w;
            for &(neighbor, weight) in &adjacency[to.index()] {
                if !in_tree[neighbor.index()] {
                    heap.push(std::cmp::Reverse((weight, to, neighbor)));
                }
            }
        }
    }
    Ok((mst_edges, total_weight))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::types::Graph;
    use ordered_float::OrderedFloat;

    /// Builds the 4-node, 5-edge graph shared by the Kruskal and Borůvka
    /// tests. Its minimum spanning tree uses the edges of weight 1, 2 and 4,
    /// for a total of 7.
    fn sample_graph() -> Graph<i32, OrderedFloat<f64>> {
        let mut graph = Graph::<i32, OrderedFloat<f64>>::new();
        let n1 = graph.add_node(1);
        let n2 = graph.add_node(2);
        let n3 = graph.add_node(3);
        let n4 = graph.add_node(4);
        graph.add_edge(n1, n2, OrderedFloat(1.0));
        graph.add_edge(n1, n3, OrderedFloat(3.0));
        graph.add_edge(n2, n3, OrderedFloat(2.0));
        graph.add_edge(n2, n4, OrderedFloat(4.0));
        graph.add_edge(n3, n4, OrderedFloat(5.0));
        graph
    }

    #[test]
    fn test_kruskal_mst() {
        let graph = sample_graph();
        let (edges, total) = kruskal_mst(&graph).expect("MST should exist");
        assert_eq!(edges.len(), 3);
        let summed: f64 = edges.iter().map(|e| e.weight.0).sum();
        assert!((summed - 7.0).abs() < 1e-6);
        // The returned total must agree with the per-edge sum.
        assert!((total.0 - 7.0).abs() < 1e-6);
    }

    #[test]
    fn test_boruvka_mst() {
        let graph = sample_graph();
        let (edges, total) = boruvka_mst(&graph).expect("MST should exist");
        assert_eq!(edges.len(), 3);
        let summed: f64 = edges.iter().map(|e| e.weight.0).sum();
        assert!((summed - 7.0).abs() < 1e-6);
        assert!((total.0 - 7.0).abs() < 1e-6);
    }

    #[test]
    fn test_prim_mst() {
        let mut graph = Graph::<i32, OrderedFloat<f64>>::new();
        let n1 = graph.add_node(1);
        let n2 = graph.add_node(2);
        let n3 = graph.add_node(3);
        graph.add_edge(n1, n2, OrderedFloat(1.0));
        graph.add_edge(n1, n3, OrderedFloat(3.0));
        graph.add_edge(n2, n3, OrderedFloat(2.0));
        let (edges, total) = prim_mst(&graph).expect("MST should exist");
        assert_eq!(edges.len(), 2);
        // Edges of weight 1 and 2 form the tree.
        assert!((total.0 - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_graph_is_rejected() {
        let graph = Graph::<i32, OrderedFloat<f64>>::new();
        assert!(kruskal_mst(&graph).is_err());
        assert!(boruvka_mst(&graph).is_err());
        assert!(prim_mst(&graph).is_err());
    }
}