use rand::seq::SliceRandom;
use rand::thread_rng;
/// A weighted edge `(u, v, weight)`: two vertex indices followed by the edge's
/// weight. `randomized_kruskal` treats it as undirected (endpoint order does not
/// affect connectivity).
pub type Edge = (usize, usize, f64);
/// Disjoint-set forest over the elements `0..n`, using union by rank and
/// path compression.
struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Rank of each element; only consulted for roots when merging.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one for each element in `0..n`.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        UnionFind { parent, rank }
    }

    /// Returns the root of the set containing `x`, re-pointing every node on
    /// the walked path directly at that root.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid element index.
    fn find(&mut self, x: usize) -> usize {
        // Pass 1: climb to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Pass 2: compress the path so later lookups are shallow.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if the two sets were distinct and have been merged, or
    /// `false` if `x` and `y` already had the same root.
    fn union(&mut self, x: usize, y: usize) -> bool {
        let root_x = self.find(x);
        let root_y = self.find(y);
        if root_x == root_y {
            return false;
        }
        let rank_x = self.rank[root_x];
        let rank_y = self.rank[root_y];
        if rank_x < rank_y {
            self.parent[root_x] = root_y;
        } else if rank_x > rank_y {
            self.parent[root_y] = root_x;
        } else {
            // Equal ranks: hang `root_y` under `root_x` and bump its rank.
            self.parent[root_y] = root_x;
            self.rank[root_x] += 1;
        }
        true
    }
}
/// Computes a minimum spanning tree of a weighted graph using Kruskal's
/// algorithm with randomized tie-breaking.
///
/// The edges are shuffled and then *stably* sorted by weight. Among edges of
/// equal weight, the one examined first (and therefore possibly selected) is
/// random. The total weight does not depend on that choice.
///
/// If the graph is disconnected, the result is a minimum spanning forest and
/// contains fewer than `num_vertices - 1` edges.
///
/// # Arguments
///
/// * `num_vertices` - number of vertices; edge endpoints index into `0..num_vertices`.
/// * `edges` - candidate edges as `(u, v, weight)`; consumed and reordered.
///
/// # Returns
///
/// The selected edges in non-decreasing weight order, and the sum of their weights.
///
/// Weights are ordered with [`f64::total_cmp`], so a NaN weight does not panic.
/// If a NaN-weighted edge is selected, the returned total is NaN.
///
/// # Panics
///
/// Panics (index out of bounds) if an examined edge has an endpoint `>= num_vertices`.
pub fn randomized_kruskal(num_vertices: usize, mut edges: Vec<Edge>) -> (Vec<Edge>, f64) {
    let mut rng = thread_rng();
    // Shuffle first: the stable sort below then keeps equal-weight edges in
    // random relative order, which is what randomizes the chosen tree.
    edges.shuffle(&mut rng);
    // `total_cmp` is a total order on f64, so a NaN weight cannot make the
    // sort panic the way `partial_cmp(..).unwrap()` would.
    edges.sort_by(|a, b| a.2.total_cmp(&b.2));

    // A spanning tree on `n` vertices has `n - 1` edges. Saturate so an
    // empty graph cannot underflow.
    let target = num_vertices.saturating_sub(1);
    let mut uf = UnionFind::new(num_vertices);
    let mut mst_edges = Vec::with_capacity(target);
    let mut total_weight = 0.0;

    for edge in edges {
        let (u, v, weight) = edge;
        // `union` returns false when `u` and `v` are already connected, i.e.
        // the edge would close a cycle.
        if uf.union(u, v) {
            mst_edges.push(edge);
            total_weight += weight;
            // The tree is complete; no remaining edge can be added.
            if mst_edges.len() == target {
                break;
            }
        }
    }
    (mst_edges, total_weight)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing floating-point weight sums.
    const EPS: f64 = 1e-9;

    /// Sorts edges by weight so results can be compared regardless of the
    /// order in which they were selected. Test inputs contain no NaN weights.
    fn by_weight(mut edges: Vec<Edge>) -> Vec<Edge> {
        edges.sort_by(|a, b| a.2.partial_cmp(&b.2).expect("test weights are not NaN"));
        edges
    }

    #[test]
    fn test_randomized_kruskal() {
        let num_vertices = 4;
        let edges = vec![
            (0, 1, 1.0),
            (0, 2, 4.0),
            (0, 3, 3.0),
            (1, 2, 2.0),
            (1, 3, 5.0),
            (2, 3, 1.5),
        ];
        let (mst, total_weight) = randomized_kruskal(num_vertices, edges);
        assert_eq!(mst.len(), num_vertices - 1);
        // All weights are distinct, so the MST is unique: 1.0 + 1.5 + 2.0 = 4.5.
        assert!(
            (total_weight - 4.5).abs() < EPS,
            "unexpected total weight {}",
            total_weight
        );
        let expected: Vec<Edge> = vec![(0, 1, 1.0), (2, 3, 1.5), (1, 2, 2.0)];
        assert_eq!(by_weight(mst), expected);
    }

    #[test]
    fn disconnected_graph_yields_spanning_forest() {
        let edges = vec![(0, 1, 1.0), (2, 3, 2.0)];
        let (forest, total_weight) = randomized_kruskal(4, edges);
        assert_eq!(forest.len(), 2);
        assert!((total_weight - 3.0).abs() < EPS);
    }

    #[test]
    fn empty_and_single_vertex_graphs() {
        let (mst, total_weight) = randomized_kruskal(0, Vec::new());
        assert!(mst.is_empty());
        assert_eq!(total_weight, 0.0);

        // A self-loop can never join a spanning tree.
        let (mst, total_weight) = randomized_kruskal(1, vec![(0, 0, 7.0)]);
        assert!(mst.is_empty());
        assert_eq!(total_weight, 0.0);
    }

    #[test]
    fn equal_weights_still_give_a_spanning_tree() {
        // Every spanning tree of this triangle weighs 2.0, whichever edges the
        // random tie-break picks.
        let edges: Vec<Edge> = vec![(0, 1, 1.0), (1, 2, 1.0), (0, 2, 1.0)];
        for _ in 0..20 {
            let (mst, total_weight) = randomized_kruskal(3, edges.clone());
            assert_eq!(mst.len(), 2);
            assert!((total_weight - 2.0).abs() < EPS);
        }
    }
}