algorithms_edu/algo/graph/minimum_spanning_tree/
prim.rs

1//! An implementation of the eager version of Prim's algorithm which relies on using an indexed
2//! priority queue data structure to query the next best edge.
3//!
4//! # Resources
5//!
6//! - [W. Fiset's video 1](https://www.youtube.com/watch?v=jsmMtJpPnhU&list=PLDV1Zeh2NRsDGO4--qE8yH72HFL1Km93P&index=30)
7//! - [W. Fiset's video 2](https://www.youtube.com/watch?v=xq3ABa-px_g&list=PLDV1Zeh2NRsDGO4--qE8yH72HFL1Km93P&index=31)
8//! - [W. Fiset's video 3](https://www.youtube.com/watch?v=CI5Fvk-dGVs&list=PLDV1Zeh2NRsDGO4--qE8yH72HFL1Km93P&index=32)
9//! - [Wikipedia](https://www.wikiwand.com/en/Prim%27s_algorithm)
10
11use crate::algo::graph::{Edge, WeightedAdjacencyList};
12use ordered_float::OrderedFloat;
13use priority_queue::PriorityQueue;
14
15impl WeightedAdjacencyList {
16    pub fn prim(&self) -> Option<(f64, WeightedAdjacencyList)> {
17        let n = self.node_count();
18        // the number of edges in the MST (a tree with `n` vertices has `n - 1` edges)
19        let m = n - 1;
20
21        let mut visited = vec![false; n];
22        let mut pq = PriorityQueue::new();
23
24        let add_edges = |from, visited: &mut [bool], pq: &mut PriorityQueue<_, _>| {
25            visited[from] = true;
26            // iterate over all edges going outwards from the current node.
27            // Add edges to the PQ which point to unvisited nodes.
28            for &Edge { to, weight } in &self[from] {
29                if !visited[to] {
30                    // `push_increase` queues an element if it's not already present.
31                    // Otherwise, it updates the element's priority if the new priority is higher.
32                    pq.push_increase((from, to), OrderedFloat(-weight));
33                }
34            }
35        };
36
37        let mut min_mst_cost = f64::INFINITY;
38        let mut best_mst_edges = Vec::new();
39        for i in 0..n {
40            let mut mst_cost = 0.;
41            let mut mst_edges = Vec::with_capacity(m);
42            add_edges(i, &mut visited, &mut pq);
43
44            while let Some(((from, to), cost)) = pq.pop() {
45                if mst_edges.len() == m {
46                    break;
47                };
48                if visited[to] {
49                    continue;
50                }
51                mst_edges.push((from, to, -cost.into_inner()));
52                mst_cost += -cost.into_inner();
53
54                add_edges(to, &mut visited, &mut pq);
55            }
56            if mst_edges.len() != m {
57                continue;
58            }
59            if mst_cost < min_mst_cost {
60                min_mst_cost = mst_cost;
61                best_mst_edges = mst_edges
62            }
63        }
64        if min_mst_cost == f64::INFINITY {
65            None
66        } else {
67            Some((
68                min_mst_cost,
69                WeightedAdjacencyList::new_directed(n, &best_mst_edges),
70            ))
71        }
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the number of edges stored in `g` and the sum of their weights.
    fn edge_stats(g: &WeightedAdjacencyList) -> (usize, f64) {
        let mut count = 0;
        let mut total = 0.;
        for v in 0..g.node_count() {
            for &Edge { weight, .. } in &g[v] {
                count += 1;
                total += weight;
            }
        }
        (count, total)
    }

    #[test]
    fn test_prim1() {
        // from https://www.youtube.com/watch?v=jsmMtJpPnhU&list=PLDV1Zeh2NRsDGO4--qE8yH72HFL1Km93P&index=30
        // at 10:05
        let g = WeightedAdjacencyList::new_directed(
            8,
            &[
                (0, 1, 10.),
                (0, 2, 1.),
                (0, 3, 4.),
                (2, 1, 3.),
                (2, 5, 8.),
                (2, 3, 2.),
                (2, 0, 1.),
                (3, 2, 2.),
                (3, 5, 2.),
                (3, 6, 7.),
                (3, 0, 4.),
                (5, 2, 8.),
                (5, 4, 1.),
                (5, 7, 9.),
                (5, 6, 6.),
                (5, 3, 2.),
                (4, 1, 0.),
                (4, 5, 1.),
                (4, 7, 8.),
                (1, 0, 10.),
                (1, 2, 3.),
                (1, 4, 0.),
                (6, 3, 7.),
                (6, 5, 6.),
                (6, 7, 12.),
                (7, 4, 8.),
                (7, 5, 9.),
                (7, 6, 12.),
            ],
        );
        let (cost, mst) = g.prim().unwrap();
        println!("{}", mst);
        assert_eq!(cost, 20.);
        // A spanning tree over 8 vertices has exactly 7 edges whose weights add up to the cost.
        assert_eq!(mst.node_count(), 8);
        assert_eq!(edge_stats(&mst), (7, 20.));
    }

    #[test]
    fn test_prim2() {
        // from https://www.youtube.com/watch?v=xq3ABa-px_g&list=PLDV1Zeh2NRsDGO4--qE8yH72HFL1Km93P&index=31
        // at 08:31
        let g = WeightedAdjacencyList::new_directed(
            7,
            &[
                (0, 2, 0.),
                (0, 5, 7.),
                (0, 3, 5.),
                (0, 1, 9.),
                (2, 0, 0.),
                (2, 5, 6.),
                (3, 0, 5.),
                (3, 1, -2.),
                (3, 6, 3.),
                (3, 5, 2.),
                (1, 0, 9.),
                (1, 3, -2.),
                (1, 6, 4.),
                (1, 4, 3.),
                (5, 2, 6.),
                (5, 0, 7.),
                (5, 3, 2.),
                (5, 6, 1.),
                (6, 5, 1.),
                (6, 3, 3.),
                (6, 1, 4.),
                (6, 4, 6.),
                (4, 1, 3.),
                (4, 6, 6.),
            ],
        );
        let (cost, mst) = g.prim().unwrap();
        println!("{}", mst);
        assert_eq!(cost, 9.);
        assert_eq!(mst.node_count(), 7);
        assert_eq!(edge_stats(&mst), (6, 9.));
    }

    #[test]
    fn test_prim_disconnected() {
        // Two components {0, 1} and {2, 3}: no spanning tree exists.
        let g = WeightedAdjacencyList::new_directed(
            4,
            &[(0, 1, 1.), (1, 0, 1.), (2, 3, 1.), (3, 2, 1.)],
        );
        assert!(g.prim().is_none());
    }

    #[test]
    fn test_prim_single_vertex() {
        // A lone vertex is its own spanning tree, with no edges and zero cost.
        let g = WeightedAdjacencyList::new_directed(1, &[]);
        let (cost, mst) = g.prim().unwrap();
        assert_eq!(cost, 0.);
        assert_eq!(mst.node_count(), 1);
        assert_eq!(edge_stats(&mst), (0, 0.));
    }
}