oxicuda_graphalg/mst/
prim.rs1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::error::{GraphalgError, GraphalgResult};
7use crate::repr::weighted_graph::WeightedGraph;
8
/// A candidate edge in Prim's priority queue: vertex `node` can be attached
/// to the tree through vertex `from` at cost `weight`.
///
/// The ordering is deliberately *reversed* so that `BinaryHeap` (a max-heap)
/// pops the lightest edge first. It is a genuine total order:
///
/// * NaN weights are treated as heavier than every real weight (popped last),
///   regardless of the NaN's sign bit;
/// * remaining weights are compared with `f64::total_cmp`;
/// * ties are broken by smaller `node`, then smaller `from`, so the pop order
///   does not depend on heap insertion order.
///
/// Equality is defined through the same comparison, which keeps `PartialEq`,
/// `Eq`, `PartialOrd` and `Ord` mutually consistent as the std traits require.
#[derive(Debug, Clone, Copy)]
struct Item {
    /// Weight of the edge between `from` and `node`.
    weight: f64,
    /// Vertex this edge would add to the tree.
    node: usize,
    /// Vertex the edge was discovered from.
    from: usize,
}

impl PartialEq for Item {
    fn eq(&self, o: &Self) -> bool {
        self.cmp(o) == Ordering::Equal
    }
}

impl Eq for Item {}

impl PartialOrd for Item {
    fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
        Some(self.cmp(o))
    }
}

impl Ord for Item {
    fn cmp(&self, o: &Self) -> Ordering {
        // Operands are swapped throughout to turn the max-heap into a min-heap.
        // `false < true`, so a NaN-weighted `self` compares as `Less` than any
        // non-NaN `o` and is therefore popped after every real-weighted edge.
        o.weight
            .is_nan()
            .cmp(&self.weight.is_nan())
            .then_with(|| o.weight.total_cmp(&self.weight))
            .then_with(|| o.node.cmp(&self.node))
            .then_with(|| o.from.cmp(&self.from))
    }
}
29
30pub fn prim_mst(g: &WeightedGraph, source: usize) -> GraphalgResult<Vec<(usize, usize, f64)>> {
33 if source >= g.n {
34 return Err(GraphalgError::SourceOutOfRange {
35 node: source,
36 n: g.n,
37 });
38 }
39 let mut in_tree = vec![false; g.n];
40 let mut edges: Vec<(usize, usize, f64)> = Vec::new();
41 let mut heap: BinaryHeap<Item> = BinaryHeap::new();
42 in_tree[source] = true;
43 for &(v, w) in g.neighbors(source)? {
44 heap.push(Item {
45 weight: w,
46 node: v,
47 from: source,
48 });
49 }
50 while let Some(Item { weight, node, from }) = heap.pop() {
51 if in_tree[node] {
52 continue;
53 }
54 in_tree[node] = true;
55 edges.push((from, node, weight));
56 for &(v, w) in g.neighbors(node)? {
57 if !in_tree[v] {
58 heap.push(Item {
59 weight: w,
60 node: v,
61 from: node,
62 });
63 }
64 }
65 }
66 if edges.len() != g.n - 1 {
67 return Err(GraphalgError::DisconnectedGraph(
68 "Prim could not span the graph".to_string(),
69 ));
70 }
71 Ok(edges)
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an undirected graph on `n` vertices from `(u, v, weight)` triples.
    fn graph_from(n: usize, edge_list: &[(usize, usize, f64)]) -> WeightedGraph {
        let mut graph = WeightedGraph::new(n);
        for &(u, v, w) in edge_list {
            graph.add_undirected_edge(u, v, w).expect("ok");
        }
        graph
    }

    /// Connected 4-vertex fixture whose minimum spanning tree weighs 4.0
    /// (edges 0-1, 1-2 and 2-3).
    fn small() -> WeightedGraph {
        graph_from(
            4,
            &[
                (0, 1, 1.0),
                (0, 2, 4.0),
                (1, 2, 2.0),
                (1, 3, 5.0),
                (2, 3, 1.0),
            ],
        )
    }

    #[test]
    fn prim_total_weight() {
        let tree = prim_mst(&small(), 0).expect("ok");
        let total: f64 = tree.iter().map(|&(_, _, w)| w).sum();
        assert!((total - 4.0).abs() < 1e-12);
    }

    #[test]
    fn prim_disconnected_err() {
        // Two separate components: {0, 1} and {2, 3}.
        let split = graph_from(4, &[(0, 1, 1.0), (2, 3, 1.0)]);
        let outcome = prim_mst(&split, 0);
        assert!(outcome.is_err());
    }
}
103}