use crate::{Graph, Vertex};
use crate::error::GraphError;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::dsu::DSU;
/// One edge of the spanning tree (or forest) built by [`kruskal`].
///
/// Both endpoint references borrow from the [`Graph`] handed to [`kruskal`],
/// so a value of this type cannot outlive that graph.
pub struct EdgeSpanningTreeEdge<'a, T> {
    /// Vertex whose adjacency list the edge was taken from.
    pub from: &'a Vertex<T>,
    /// Vertex the edge points to.
    pub to: &'a Vertex<T>,
    /// Weight of the edge, copied out of the graph.
    pub weight: T,
}
/// Heap entry used internally by Kruskal's algorithm: the edge `from -> to`
/// (vertex indices) with weight `dist`.
///
/// Equality and ordering look only at `dist`. The ordering is *reversed* so
/// that `std::collections::BinaryHeap`, which is a max-heap, pops the lightest
/// edge first.
struct KruskalEdge<U> {
    from: usize,
    to: usize,
    dist: U,
}

impl<U> PartialEq for KruskalEdge<U>
where
    U: PartialOrd + Copy,
{
    fn eq(&self, other: &Self) -> bool {
        self.dist == other.dist
    }
}

// `Eq` is required by `Ord`/`BinaryHeap`. For float weights it only holds when
// no NaN is present; `cmp` below panics instead of producing an inconsistent
// order if a NaN shows up.
impl<U> Eq for KruskalEdge<U> where U: PartialOrd + Copy {}

impl<U> Ord for KruskalEdge<U>
where
    U: PartialOrd + Copy,
{
    /// Compares by weight in reverse order (a heavier edge compares as "less").
    ///
    /// # Panics
    /// Panics if the two weights are incomparable (e.g. a floating-point NaN),
    /// because `Ord` must describe a total order.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .dist
            .partial_cmp(&self.dist)
            .expect("KruskalEdge: edge weights must be comparable (NaN is not allowed)")
    }
}

impl<U> PartialOrd for KruskalEdge<U>
where
    U: PartialOrd + Copy,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to `Ord` so the two orderings can never disagree.
        Some(self.cmp(other))
    }
}
/// Computes a minimum spanning forest of `graph` with Kruskal's algorithm.
///
/// How it works:
/// - Every edge in the graph's adjacency lists is pushed onto a min-heap ordered
///   by weight. Slot 0 of `adj` is skipped, so vertices appear to be 1-indexed.
/// - Edges are then taken lightest-first.
/// - An edge is kept only when it joins two different disjoint-set components,
///   so the result never contains a cycle.
///
/// Connectivity is not required. The result is therefore a spanning *forest*,
/// with one tree per connected component.
///
/// Returned edges are in acceptance order, i.e. by non-decreasing weight. For
/// an undirected graph stored as two directed edges, which direction of a pair
/// is reported is not specified.
///
/// # Errors
/// Currently never returns `Err`; the `Result` is kept for API compatibility.
///
/// # Panics
/// Panics in either of these cases:
/// - Two weights cannot be compared (e.g. a floating-point NaN).
/// - A disjoint-set operation or vertex lookup fails. This presumably means an
///   edge references a vertex index the graph does not hold.
pub fn kruskal<T>(graph: &Graph<T>) -> Result<Vec<EdgeSpanningTreeEdge<'_, T>>, GraphError>
where
    T: PartialOrd + Copy,
{
    let mut edges = vec![];
    let mut heap = BinaryHeap::new();
    let mut dsu = DSU::new(graph.size());

    // One singleton set per vertex; queue every stored edge. The push order is
    // kept as-is because it decides which of several equal-weight edges wins.
    for (from, entry) in graph.adj.iter().enumerate().skip(1) {
        dsu.make_set(from)
            .expect("kruskal: failed to create a disjoint set for a graph vertex");
        for edge in entry.edges.iter() {
            heap.push(KruskalEdge {
                from,
                to: edge.to,
                dist: edge.weight,
            });
        }
    }

    // `KruskalEdge` orders in reverse, so the max-heap yields the lightest edge first.
    while let Some(edge) = heap.pop() {
        // Endpoints already connected: taking this edge would close a cycle.
        if dsu.find_set(edge.from) == dsu.find_set(edge.to) {
            continue;
        }
        dsu.union_sets(edge.from, edge.to)
            .expect("kruskal: failed to merge the disjoint sets of an edge's endpoints");
        edges.push(EdgeSpanningTreeEdge {
            from: graph
                .get_vertex(edge.from)
                .expect("kruskal: edge source is not a vertex of the graph"),
            to: graph
                .get_vertex(edge.to)
                .expect("kruskal: edge target is not a vertex of the graph"),
            weight: edge.dist,
        });
    }
    Ok(edges)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kruskal() {
        // Undirected edges (u, v, weight). Each one is inserted as u -> v and
        // then v -> u, in the same order as before.
        let undirected = [
            (1, 2, 7.0),
            (1, 4, 5.0),
            (2, 3, 8.0),
            (2, 5, 7.0),
            (2, 4, 9.0),
            (3, 5, 5.0),
            (5, 7, 9.0),
            (5, 6, 8.0),
            (5, 4, 15.0),
            (6, 7, 11.0),
            (6, 4, 6.0),
        ];
        let mut graph = Graph::new(20);
        for &(u, v, weight) in undirected.iter() {
            graph.add_edge(u, v, weight).unwrap();
            graph.add_edge(v, u, weight).unwrap();
        }

        let edges = kruskal(&graph).unwrap();

        let summary_weight: f64 = edges.iter().map(|value| value.weight).sum();
        assert_eq!(39.0, summary_weight);

        // Compare as an order- and direction-independent edge set. Which of two
        // equal-weight heap entries pops first is not specified by `BinaryHeap`,
        // so the exact sequence and direction must not be asserted.
        let mut res = edges
            .iter()
            .map(|value| {
                let (a, b) = (value.from.id(), value.to.id());
                (a.min(b), a.max(b))
            })
            .collect::<Vec<(usize, usize)>>();
        res.sort_unstable();
        assert_eq!(res, vec![(1, 2), (1, 4), (2, 5), (3, 5), (4, 6), (5, 7)]);
    }

    #[test]
    fn test_kruskal_without_edges() {
        let graph: Graph<f64> = Graph::new(5);
        assert!(kruskal(&graph).unwrap().is_empty());
    }
}