graphmind_graph_algorithms/
mst.rs1use super::common::{GraphView, NodeId};
6use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashSet};
8
9pub struct MSTResult {
10 pub total_weight: f64,
11 pub edges: Vec<(NodeId, NodeId, f64)>, }
13
/// Candidate edge kept in the priority queue while the tree is grown.
///
/// The ordering is deliberately reversed: `BinaryHeap` is a max-heap, so the
/// edge that compares greatest must be the one with the smallest weight.
#[derive(Debug, Copy, Clone)]
struct EdgeState {
    /// Edge weight; the primary sort key.
    weight: f64,
    /// Index of the endpoint from which the edge was discovered.
    source: usize,
    /// Index of the endpoint the edge leads to.
    target: usize,
}

impl PartialEq for EdgeState {
    /// Equality is defined through `cmp` so that `==` and `Ordering::Equal`
    /// always agree, as the `Ord` contract requires (this also makes a NaN
    /// weight equal to itself, which `Eq` needs).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for EdgeState {}

impl Ord for EdgeState {
    /// Reversed total order: lighter edges compare greater.
    ///
    /// `f64::total_cmp` gives NaN a fixed position instead of treating it as
    /// equal to everything. Ties on weight are broken by `source`, then
    /// `target` (also reversed, so lower indices pop first), which makes the
    /// pop order deterministic.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .weight
            .total_cmp(&self.weight)
            .then_with(|| other.source.cmp(&self.source))
            .then_with(|| other.target.cmp(&self.target))
    }
}

impl PartialOrd for EdgeState {
    /// Delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
38
39pub fn prim_mst(view: &GraphView) -> MSTResult {
45 if view.node_count == 0 {
46 return MSTResult {
47 total_weight: 0.0,
48 edges: Vec::new(),
49 };
50 }
51
52 let start_idx = 0; let mut visited = HashSet::new();
54 let mut heap = BinaryHeap::new();
55 let mut mst_edges = Vec::new();
56 let mut total_weight = 0.0;
57
58 visited.insert(start_idx);
59
60 add_edges(view, start_idx, &mut heap, &visited);
62
63 while let Some(EdgeState {
64 weight,
65 source,
66 target,
67 }) = heap.pop()
68 {
69 if visited.contains(&target) {
70 continue;
71 }
72
73 visited.insert(target);
74 mst_edges.push((
75 view.index_to_node[source],
76 view.index_to_node[target],
77 weight,
78 ));
79 total_weight += weight;
80
81 add_edges(view, target, &mut heap, &visited);
82 }
83
84 MSTResult {
85 total_weight,
86 edges: mst_edges,
87 }
88}
89
90fn add_edges(
91 view: &GraphView,
92 u: usize,
93 heap: &mut BinaryHeap<EdgeState>,
94 visited: &HashSet<usize>,
95) {
96 let u_out = view.successors(u);
98 for (i, &v) in u_out.iter().enumerate() {
99 if !visited.contains(&v) {
100 let weight = view.weights(u).map(|w| w[i]).unwrap_or(1.0);
101 heap.push(EdgeState {
102 weight,
103 source: u,
104 target: v,
105 });
106 }
107 }
108
109 let u_in = view.predecessors(u);
111 for &_v in u_in.iter() {
112 let v = _v; if !visited.contains(&v) {
114 let v_out = view.successors(v);
118 if let Some(idx) = v_out.iter().position(|&x| x == u) {
119 let weight = view.weights(v).map(|w| w[idx]).unwrap_or(1.0);
120 heap.push(EdgeState {
121 weight,
122 source: u,
123 target: v,
124 }); }
126 }
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Triangle on nodes 1, 2, 3 with every edge stored in both directions:
    /// 1-2 costs 1, 2-3 costs 2, 1-3 costs 10. The tree must keep the two
    /// cheap edges.
    #[test]
    fn test_prim_mst() {
        let index_to_node = vec![1, 2, 3];
        let node_to_index: HashMap<_, _> = vec![(1, 0), (2, 1), (3, 2)].into_iter().collect();

        // Row i lists the neighbours of index i; `weights` is parallel to
        // `outgoing`.
        let outgoing = vec![vec![1, 2], vec![0, 2], vec![1, 0]];
        let incoming = vec![vec![1, 2], vec![0, 2], vec![1, 0]];
        let weights = vec![vec![1.0, 10.0], vec![1.0, 2.0], vec![2.0, 10.0]];

        let view = GraphView::from_adjacency_list(
            3,
            index_to_node,
            node_to_index,
            outgoing,
            incoming,
            Some(weights),
        );

        let mst = prim_mst(&view);
        assert_eq!(mst.total_weight, 3.0);
        assert_eq!(mst.edges.len(), 2);
    }
}