1#![allow(clippy::cast_precision_loss)] use std::collections::HashMap;
8
9use serde::{Deserialize, Serialize};
10
11use crate::{GraphEngine, PropertyValue, Result};
12
/// Options controlling [`GraphEngine::minimum_spanning_tree`].
///
/// Build one with [`MstConfig::new`] (or [`Default`]) and refine it with the
/// chained `default_weight` / `compute_forest` setters.
#[derive(Debug, Clone)]
pub struct MstConfig {
    /// Name of the edge property that holds each edge's weight.
    pub weight_property: String,
    /// Weight assigned to edges whose weight property is missing or is not an
    /// `Int`/`Float` value.
    pub default_weight: f64,
    /// When `false`, the edge scan may stop as soon as a spanning tree over
    /// all nodes is complete; when `true`, every edge is examined.
    pub compute_forest: bool,
}

impl Default for MstConfig {
    /// Reads weights from the `"weight"` property, falls back to `1.0`, and
    /// scans all edges.
    fn default() -> Self {
        Self::new("weight")
    }
}

impl MstConfig {
    /// Creates a config reading weights from `weight_property`, with a default
    /// weight of `1.0` and `compute_forest` enabled.
    #[must_use]
    pub fn new(weight_property: impl Into<String>) -> Self {
        Self {
            weight_property: weight_property.into(),
            default_weight: 1.0,
            compute_forest: true,
        }
    }

    /// Replaces the fallback weight used for edges without a numeric weight.
    #[must_use]
    pub const fn default_weight(mut self, weight: f64) -> Self {
        self.default_weight = weight;
        self
    }

    /// Sets whether every edge is scanned (`true`) or the scan may stop once a
    /// full spanning tree has been found (`false`).
    #[must_use]
    pub const fn compute_forest(mut self, compute: bool) -> Self {
        self.compute_forest = compute;
        self
    }
}
55
/// One edge selected for a minimum spanning tree/forest.
///
/// Field order is kept stable because it determines the serialized layout.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MstEdge {
    /// Id of the stored edge (the numeric suffix of its `edge:` store key).
    pub edge_id: u64,
    /// Source endpoint as stored on the edge; the MST treats edges as undirected.
    pub from: u64,
    /// Target endpoint as stored on the edge.
    pub to: u64,
    /// Weight used by the algorithm: the configured property value, or the
    /// config's default weight when that property is absent or non-numeric.
    pub weight: f64,
}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
67pub struct MstResult {
68 pub edges: Vec<MstEdge>,
70 pub total_weight: f64,
72 pub tree_count: usize,
74 pub nodes: Vec<u64>,
76}
77
78impl MstResult {
79 #[must_use]
80 pub const fn empty() -> Self {
81 Self {
82 edges: Vec::new(),
83 total_weight: 0.0,
84 tree_count: 0,
85 nodes: Vec::new(),
86 }
87 }
88
89 #[must_use]
90 pub const fn is_connected(&self) -> bool {
91 self.tree_count == 1
92 }
93
94 #[must_use]
95 pub const fn edge_count(&self) -> usize {
96 self.edges.len()
97 }
98}
99
100impl Default for MstResult {
101 fn default() -> Self {
102 Self::empty()
103 }
104}
105
/// Disjoint-set (union-find) over a fixed set of `u64` ids, using union by
/// rank and path compression.
///
/// Every id passed to `find`/`union` must have been given to `new`; unknown
/// ids panic on map indexing.
struct UnionFind {
    /// Parent pointer per id; an id is a root when it is its own parent.
    parent: HashMap<u64, u64>,
    /// Rank per id; only consulted for roots when merging.
    rank: HashMap<u64, usize>,
}

impl UnionFind {
    /// Creates one singleton set of rank 0 for every id in `nodes`.
    fn new(nodes: &[u64]) -> Self {
        let mut parent = HashMap::with_capacity(nodes.len());
        let mut rank = HashMap::with_capacity(nodes.len());
        for &id in nodes {
            parent.insert(id, id);
            rank.insert(id, 0);
        }
        Self { parent, rank }
    }

    /// Returns the root of `x`'s set and re-points every node on the path
    /// from `x` directly at that root.
    fn find(&mut self, x: u64) -> u64 {
        // Pass 1: climb to the root.
        let mut root = x;
        loop {
            let up = self.parent[&root];
            if up == root {
                break;
            }
            root = up;
        }

        // Pass 2: compress the path just walked.
        let mut cursor = x;
        while cursor != root {
            let next = self.parent[&cursor];
            self.parent.insert(cursor, root);
            cursor = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `false` if both were already in the same set, `true` otherwise.
    fn union(&mut self, x: u64, y: u64) -> bool {
        let (root_x, root_y) = (self.find(x), self.find(y));
        if root_x == root_y {
            return false;
        }

        let (rank_x, rank_y) = (self.rank[&root_x], self.rank[&root_y]);
        if rank_x < rank_y {
            self.parent.insert(root_x, root_y);
        } else {
            // Attach the lower-or-equal-rank root under `root_x`; only a tie
            // can increase the height of the merged tree.
            self.parent.insert(root_y, root_x);
            if rank_x == rank_y {
                self.rank.insert(root_x, rank_x + 1);
            }
        }
        true
    }
}
155
156impl GraphEngine {
157 pub fn minimum_spanning_tree(&self, config: &MstConfig) -> Result<MstResult> {
165 let nodes = self.get_all_node_ids()?;
166 if nodes.is_empty() {
167 return Ok(MstResult::empty());
168 }
169
170 let mut weighted_edges: Vec<(u64, u64, u64, f64)> = Vec::new(); for key in self.store().scan("edge:") {
174 if let Some(id_str) = key.strip_prefix("edge:") {
175 if let Ok(edge_id) = id_str.parse::<u64>() {
176 if let Ok(edge) = self.get_edge(edge_id) {
177 let weight = match edge.properties.get(&config.weight_property) {
178 Some(PropertyValue::Float(w)) => *w,
179 Some(PropertyValue::Int(w)) => *w as f64,
180 _ => config.default_weight,
181 };
182 weighted_edges.push((edge.from, edge.to, edge_id, weight));
183 }
184 }
185 }
186 }
187
188 weighted_edges.sort_by(|a, b| a.3.partial_cmp(&b.3).unwrap_or(std::cmp::Ordering::Equal));
190
191 let mut uf = UnionFind::new(&nodes);
193 let mut mst_edges = Vec::new();
194 let mut total_weight = 0.0;
195
196 for (from, to, edge_id, weight) in weighted_edges {
197 if uf.union(from, to) {
198 mst_edges.push(MstEdge {
199 edge_id,
200 from,
201 to,
202 weight,
203 });
204 total_weight += weight;
205
206 if !config.compute_forest && mst_edges.len() == nodes.len() - 1 {
208 break;
209 }
210 }
211 }
212
213 let mut roots = std::collections::HashSet::new();
215 for &node in &nodes {
216 roots.insert(uf.find(node));
217 }
218 let tree_count = roots.len();
219
220 Ok(MstResult {
221 edges: mst_edges,
222 total_weight,
223 tree_count,
224 nodes,
225 })
226 }
227
228 pub fn minimum_spanning_forest(&self, weight_property: &str) -> Result<Vec<MstResult>> {
234 let result =
235 self.minimum_spanning_tree(&MstConfig::new(weight_property).compute_forest(true))?;
236
237 if result.tree_count <= 1 {
238 return Ok(vec![result]);
239 }
240
241 let mut uf = UnionFind::new(&result.nodes);
243 for edge in &result.edges {
244 uf.union(edge.from, edge.to);
245 }
246
247 let mut components: HashMap<u64, Vec<MstEdge>> = HashMap::new();
248 let mut component_nodes: HashMap<u64, Vec<u64>> = HashMap::new();
249
250 for edge in result.edges {
251 let root = uf.find(edge.from);
252 components.entry(root).or_default().push(edge);
253 }
254
255 for &node in &result.nodes {
256 let root = uf.find(node);
257 component_nodes.entry(root).or_default().push(node);
258 }
259
260 let mut forests = Vec::new();
261 for (root, edges) in components {
262 let total_weight = edges.iter().map(|e| e.weight).sum();
263 let nodes = component_nodes.remove(&root).unwrap_or_default();
264 forests.push(MstResult {
265 edges,
266 total_weight,
267 tree_count: 1,
268 nodes,
269 });
270 }
271
272 for (_, nodes) in component_nodes {
274 for node in nodes {
275 forests.push(MstResult {
276 edges: Vec::new(),
277 total_weight: 0.0,
278 tree_count: 1,
279 nodes: vec![node],
280 });
281 }
282 }
283
284 Ok(forests)
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a property-less node labelled `label` and returns its id.
    fn node(engine: &GraphEngine, label: &str) -> u64 {
        engine.create_node(label, HashMap::new()).unwrap()
    }

    /// Creates an edge `from -> to` carrying a `Float` "weight" property.
    fn create_weighted_edge(engine: &GraphEngine, from: u64, to: u64, weight: f64) -> u64 {
        let mut props = HashMap::new();
        props.insert("weight".to_string(), PropertyValue::Float(weight));
        engine.create_edge(from, to, "EDGE", props, false).unwrap()
    }

    #[test]
    fn test_mst_empty_graph() {
        let engine = GraphEngine::new();
        let result = engine.minimum_spanning_tree(&MstConfig::default()).unwrap();
        assert!(result.edges.is_empty());
        assert_eq!(result.tree_count, 0);
    }

    #[test]
    fn test_mst_single_node() {
        let engine = GraphEngine::new();
        node(&engine, "A");

        let result = engine.minimum_spanning_tree(&MstConfig::default()).unwrap();
        assert!(result.edges.is_empty());
        assert_eq!(result.tree_count, 1);
        assert_eq!(result.nodes.len(), 1);
    }

    #[test]
    fn test_mst_simple_triangle() {
        let engine = GraphEngine::new();
        let a = node(&engine, "A");
        let b = node(&engine, "B");
        let c = node(&engine, "C");

        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, b, c, 2.0);
        create_weighted_edge(&engine, a, c, 3.0);

        let result = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();

        // The heaviest edge (a-c) closes a cycle and must be rejected.
        assert_eq!(result.edge_count(), 2);
        assert!((result.total_weight - 3.0).abs() < f64::EPSILON);
        assert!(result.is_connected());
    }

    #[test]
    fn test_mst_selects_minimum_edges() {
        let engine = GraphEngine::new();
        let a = node(&engine, "A");
        let b = node(&engine, "B");
        let c = node(&engine, "C");
        let d = node(&engine, "D");

        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, b, c, 2.0);
        create_weighted_edge(&engine, c, d, 3.0);
        create_weighted_edge(&engine, a, d, 10.0);

        let result = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();

        assert_eq!(result.edge_count(), 3);
        assert!((result.total_weight - 6.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_mst_tree_mode_matches_forest_mode_on_connected_graph() {
        let engine = GraphEngine::new();
        let a = node(&engine, "A");
        let b = node(&engine, "B");
        let c = node(&engine, "C");
        let d = node(&engine, "D");

        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, b, c, 2.0);
        create_weighted_edge(&engine, c, d, 3.0);
        create_weighted_edge(&engine, a, d, 10.0);

        let early_stop = engine
            .minimum_spanning_tree(&MstConfig::new("weight").compute_forest(false))
            .unwrap();
        let full_scan = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();

        assert_eq!(early_stop.edge_count(), 3);
        assert!(early_stop.is_connected());
        assert!((early_stop.total_weight - full_scan.total_weight).abs() < f64::EPSILON);
    }

    #[test]
    fn test_mst_forest() {
        let engine = GraphEngine::new();
        let a = node(&engine, "A");
        let b = node(&engine, "B");
        let c = node(&engine, "C");
        let d = node(&engine, "D");

        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, c, d, 2.0);

        let result = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();

        assert_eq!(result.edge_count(), 2);
        assert_eq!(result.tree_count, 2);
        assert!(!result.is_connected());
    }

    #[test]
    fn test_mst_forest_split() {
        let engine = GraphEngine::new();
        let a = node(&engine, "A");
        let b = node(&engine, "B");
        let c = node(&engine, "C");
        let d = node(&engine, "D");

        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, c, d, 2.0);

        let forests = engine.minimum_spanning_forest("weight").unwrap();
        assert_eq!(forests.len(), 2);
        for forest in &forests {
            assert_eq!(forest.tree_count, 1);
            assert_eq!(forest.edge_count(), 1);
            assert_eq!(forest.nodes.len(), 2);
        }

        let mut totals: Vec<f64> = forests.iter().map(|f| f.total_weight).collect();
        totals.sort_by(f64::total_cmp);
        assert!((totals[0] - 1.0).abs() < f64::EPSILON);
        assert!((totals[1] - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_mst_forest_isolated_node() {
        let engine = GraphEngine::new();
        let a = node(&engine, "A");
        let b = node(&engine, "B");
        let c = node(&engine, "C");

        create_weighted_edge(&engine, a, b, 1.0);

        let forests = engine.minimum_spanning_forest("weight").unwrap();
        assert_eq!(forests.len(), 2);

        let singleton = forests
            .iter()
            .find(|f| f.nodes == vec![c])
            .expect("isolated node should form its own component");
        assert!(singleton.edges.is_empty());
        assert_eq!(singleton.tree_count, 1);
        assert!(singleton.total_weight.abs() < f64::EPSILON);

        let pair = forests
            .iter()
            .find(|f| f.nodes.len() == 2)
            .expect("a-b should form one component");
        assert_eq!(pair.edge_count(), 1);
        assert!((pair.total_weight - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_mst_forest_connected_graph_single_entry() {
        let engine = GraphEngine::new();
        let a = node(&engine, "A");
        let b = node(&engine, "B");
        let c = node(&engine, "C");

        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, b, c, 2.0);
        create_weighted_edge(&engine, a, c, 3.0);

        let forests = engine.minimum_spanning_forest("weight").unwrap();
        assert_eq!(forests.len(), 1);
        assert!(forests[0].is_connected());
        assert_eq!(forests[0].edge_count(), 2);
    }

    #[test]
    fn test_mst_default_weight() {
        let engine = GraphEngine::new();
        let a = node(&engine, "A");
        let b = node(&engine, "B");

        // No "weight" property: the configured default must be used.
        engine
            .create_edge(a, b, "EDGE", HashMap::new(), false)
            .unwrap();

        let config = MstConfig::new("weight").default_weight(5.0);
        let result = engine.minimum_spanning_tree(&config).unwrap();

        assert_eq!(result.edge_count(), 1);
        assert!((result.total_weight - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_mst_integer_weight() {
        let engine = GraphEngine::new();
        let a = node(&engine, "A");
        let b = node(&engine, "B");

        let mut props = HashMap::new();
        props.insert("weight".to_string(), PropertyValue::Int(42));
        engine.create_edge(a, b, "EDGE", props, false).unwrap();

        let result = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();
        assert!((result.total_weight - 42.0).abs() < f64::EPSILON);
    }
}