use gdsl::ungraph::*;
use gdsl::*;
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};
// Graph node: keys are `usize`, edge values are `u64` (used as weights below);
// the `()` parameter is presumably the (empty) per-node payload.
type N = Node<usize, (), u64>;
// Edge between two `N` nodes; its third field holds the `u64` weight.
type E = Edge<usize, (), u64>;
// `Reverse` flips `BinaryHeap`'s max-first order, so the smallest edge pops
// first (assuming `Edge`'s `Ord` compares by weight — NOTE(review): confirm in gdsl).
type Heap = BinaryHeap<Reverse<E>>;
/// Computes a minimum spanning tree of the component containing `s` with a
/// heap-based variant of Prim's algorithm.
///
/// Every edge the BFS from `s` hands to its `for_each` callback is loaded into
/// a min-heap up front. Edges are then popped cheapest-first:
///
/// * `(u, v)` with `u` in the tree and `v` outside it is accepted: `v` joins
///   the tree, and all deferred edges are moved back onto the heap because the
///   newly added node may have turned some of them into crossing edges.
/// * `(u, v)` with both endpoints outside the tree is deferred.
/// * Any other edge is discarded.
///
/// NOTE(review): acceptance only checks the `(in-tree, outside)` orientation,
/// so this relies on gdsl's BFS reporting each undirected edge from both of its
/// endpoints — confirm against the gdsl version in use.
///
/// Returns the accepted edges in the order they were added to the tree.
fn prim_minimum_spanning_tree(s: &N) -> Vec<E> {
    let mut mst: Vec<E> = vec![];
    let mut in_mst: HashSet<usize> = HashSet::new();
    let mut heap = Heap::new();

    in_mst.insert(*s.key());

    // Load every edge reported by the BFS from `s` into the min-heap.
    s.bfs()
        .for_each(&mut |edge| {
            heap.push(Reverse(edge.clone()));
        })
        .search();

    // Edges whose endpoints were both outside the tree when they were popped.
    let mut deferred: Vec<E> = vec![];
    while let Some(Reverse(edge)) = heap.pop() {
        let Edge(u, v, _) = &edge;
        let u_in = in_mst.contains(u.key());
        let v_key = *v.key();
        let v_in = in_mst.contains(&v_key);

        match (u_in, v_in) {
            (true, false) => {
                in_mst.insert(v_key);
                mst.push(edge);
                // Drain rather than iterate: leaving the edges in `deferred`
                // would push them again on every later addition, duplicating
                // the backlog each time (exponential growth).
                heap.extend(deferred.drain(..).map(Reverse));
            }
            (false, false) => deferred.push(edge),
            // Both endpoints already in the tree (the edge would close a
            // cycle), or the edge is in the (outside, in-tree) orientation.
            (true, true) | (false, true) => {}
        }
    }
    mst
}
fn main() {
let g1 = ungraph![
(usize) => [u64]
(0) => [ (1, 1), (3, 4), (4, 3) ]
(1) => [ (3, 4), (4, 2) ]
(2) => [ (4, 4), (5, 5) ]
(3) => [ (4, 4) ]
(4) => [ (5, 7) ]
(5) => []
];
let forest = prim_minimum_spanning_tree(&g1[0]);
let sum = forest.iter().fold(0, |acc, e| acc + e.2);
assert!(sum == 16);
let g2 = ungraph![
(usize) => [u64]
(0) => [ (1, 8), (2, 5) ]
(1) => [ (2, 10), (3, 2), (4, 18) ]
(2) => [ (3, 3), (5, 16) ]
(3) => [ (4, 12), (5, 30) ]
(4) => [ (6, 4) ]
(5) => [ (6, 26) ]
(6) => []
];
let forest = prim_minimum_spanning_tree(&g2[0]);
let sum = forest.iter().fold(0, |acc, e| acc + e.2);
assert!(sum == 42);
}