Skip to main content

oxicuda_graphalg/mst/
prim.rs

//! Prim's minimum-spanning-tree (MST) algorithm using a lazy-deletion binary heap.
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::error::{GraphalgError, GraphalgResult};
7use crate::repr::weighted_graph::WeightedGraph;
8
/// Heap entry for Prim's frontier: a candidate edge `from -> node` of `weight`.
///
/// The ordering is *reversed* on weight so that [`BinaryHeap`] (a max-heap) pops
/// the lightest edge first. Ties are broken by the smaller `node`, then the
/// smaller `from`, which makes the pop order fully deterministic.
///
/// Weights are compared with [`f64::total_cmp`], a genuine total order that also
/// covers NaN. Equality is defined through the same comparison, so `PartialEq`,
/// `Eq`, `PartialOrd` and `Ord` all agree, as `BinaryHeap` requires.
#[derive(Debug, Clone, Copy)]
struct Item {
    weight: f64,
    node: usize,
    from: usize,
}

impl PartialEq for Item {
    fn eq(&self, o: &Self) -> bool {
        // Derived from `cmp` so that `a == b` exactly when `a.cmp(&b) == Equal`.
        self.cmp(o) == Ordering::Equal
    }
}

impl Eq for Item {}

impl PartialOrd for Item {
    fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
        Some(self.cmp(o))
    }
}

impl Ord for Item {
    fn cmp(&self, o: &Self) -> Ordering {
        // Operands are swapped on purpose. A smaller weight (or node/from on a
        // tie) compares as "greater", so the max-heap yields it first.
        o.weight
            .total_cmp(&self.weight)
            .then_with(|| o.node.cmp(&self.node))
            .then_with(|| o.from.cmp(&self.from))
    }
}
29
30/// Run Prim's MST starting from `source`. Returns the list of MST edges `(u, v, w)`.
31/// Assumes graph is undirected and connected. Negative weights are allowed for MST.
32pub fn prim_mst(g: &WeightedGraph, source: usize) -> GraphalgResult<Vec<(usize, usize, f64)>> {
33    if source >= g.n {
34        return Err(GraphalgError::SourceOutOfRange {
35            node: source,
36            n: g.n,
37        });
38    }
39    let mut in_tree = vec![false; g.n];
40    let mut edges: Vec<(usize, usize, f64)> = Vec::new();
41    let mut heap: BinaryHeap<Item> = BinaryHeap::new();
42    in_tree[source] = true;
43    for &(v, w) in g.neighbors(source)? {
44        heap.push(Item {
45            weight: w,
46            node: v,
47            from: source,
48        });
49    }
50    while let Some(Item { weight, node, from }) = heap.pop() {
51        if in_tree[node] {
52            continue;
53        }
54        in_tree[node] = true;
55        edges.push((from, node, weight));
56        for &(v, w) in g.neighbors(node)? {
57            if !in_tree[v] {
58                heap.push(Item {
59                    weight: w,
60                    node: v,
61                    from: node,
62                });
63            }
64        }
65    }
66    if edges.len() != g.n - 1 {
67        return Err(GraphalgError::DisconnectedGraph(
68            "Prim could not span the graph".to_string(),
69        ));
70    }
71    Ok(edges)
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// 4-node graph whose unique MST is {0-1 (1), 1-2 (2), 2-3 (1)}, total 4.
    fn small() -> WeightedGraph {
        let mut g = WeightedGraph::new(4);
        g.add_undirected_edge(0, 1, 1.0).expect("ok");
        g.add_undirected_edge(0, 2, 4.0).expect("ok");
        g.add_undirected_edge(1, 2, 2.0).expect("ok");
        g.add_undirected_edge(1, 3, 5.0).expect("ok");
        g.add_undirected_edge(2, 3, 1.0).expect("ok");
        g
    }

    fn total(mst: &[(usize, usize, f64)]) -> f64 {
        mst.iter().map(|e| e.2).sum()
    }

    #[test]
    fn prim_total_weight() {
        let g = small();
        let mst = prim_mst(&g, 0).expect("ok");
        assert_eq!(mst.len(), 3);
        assert!((total(&mst) - 4.0).abs() < 1e-12);
    }

    #[test]
    fn prim_weight_independent_of_source() {
        let g = small();
        for s in 0..4 {
            let mst = prim_mst(&g, s).expect("ok");
            assert_eq!(mst.len(), 3);
            assert!((total(&mst) - 4.0).abs() < 1e-12);
        }
    }

    #[test]
    fn prim_negative_weights() {
        let mut g = WeightedGraph::new(3);
        g.add_undirected_edge(0, 1, -3.0).expect("ok");
        g.add_undirected_edge(1, 2, 2.0).expect("ok");
        g.add_undirected_edge(0, 2, 1.0).expect("ok");
        let mst = prim_mst(&g, 0).expect("ok");
        assert_eq!(mst.len(), 2);
        assert!((total(&mst) - (-2.0)).abs() < 1e-12);
    }

    #[test]
    fn prim_single_node() {
        let g = WeightedGraph::new(1);
        let mst = prim_mst(&g, 0).expect("ok");
        assert!(mst.is_empty());
    }

    #[test]
    fn prim_source_out_of_range() {
        let g = small();
        assert!(matches!(
            prim_mst(&g, 7),
            Err(GraphalgError::SourceOutOfRange { node: 7, n: 4 })
        ));
    }

    #[test]
    fn prim_disconnected_err() {
        let mut g = WeightedGraph::new(4);
        g.add_undirected_edge(0, 1, 1.0).expect("ok");
        g.add_undirected_edge(2, 3, 1.0).expect("ok");
        assert!(matches!(
            prim_mst(&g, 0),
            Err(GraphalgError::DisconnectedGraph(_))
        ));
    }
}