Skip to main content

graph_engine/algorithms/
mst.rs

1//! Minimum Spanning Tree using Kruskal's algorithm.
2//!
3//! Computes the minimum spanning tree (or forest) of a graph using edge weights.
4
5#![allow(clippy::cast_precision_loss)] // Acceptable for graph algorithm metrics
6
7use std::collections::HashMap;
8
9use serde::{Deserialize, Serialize};
10
11use crate::{GraphEngine, PropertyValue, Result};
12
/// Settings controlling how [`GraphEngine::minimum_spanning_tree`] weighs edges.
///
/// Start from [`MstConfig::new`] (or `Default`) and refine with the chained setters.
#[derive(Debug, Clone)]
pub struct MstConfig {
    /// Name of the edge property read as the edge weight.
    pub weight_property: String,
    /// Weight used for edges that do not carry a usable weight property.
    pub default_weight: f64,
    /// Whether to compute a spanning tree for each connected component (a forest).
    pub compute_forest: bool,
}

impl MstConfig {
    /// Builds a config that reads weights from `weight_property`; every other field keeps
    /// its `Default` value.
    #[must_use]
    pub fn new(weight_property: impl Into<String>) -> Self {
        let weight_property = weight_property.into();
        Self {
            weight_property,
            ..Self::default()
        }
    }

    /// Replaces the weight used for edges lacking the weight property.
    #[must_use]
    pub const fn default_weight(mut self, weight: f64) -> Self {
        self.default_weight = weight;
        self
    }

    /// Sets the `compute_forest` flag.
    #[must_use]
    pub const fn compute_forest(mut self, compute: bool) -> Self {
        self.compute_forest = compute;
        self
    }
}

impl Default for MstConfig {
    /// Weights come from the `"weight"` property, missing weights count as `1.0`, and
    /// forest computation is enabled.
    fn default() -> Self {
        let weight_property = String::from("weight");
        Self {
            weight_property,
            default_weight: 1.0,
            compute_forest: true,
        }
    }
}
55
56/// An edge in the MST result.
57#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
58pub struct MstEdge {
59    pub edge_id: u64,
60    pub from: u64,
61    pub to: u64,
62    pub weight: f64,
63}
64
65/// Result of MST computation.
66#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
67pub struct MstResult {
68    /// Edges in the minimum spanning tree (or forest).
69    pub edges: Vec<MstEdge>,
70    /// Total weight of the MST.
71    pub total_weight: f64,
72    /// Number of trees in the forest (1 for connected graphs).
73    pub tree_count: usize,
74    /// Nodes included in the MST.
75    pub nodes: Vec<u64>,
76}
77
78impl MstResult {
79    #[must_use]
80    pub const fn empty() -> Self {
81        Self {
82            edges: Vec::new(),
83            total_weight: 0.0,
84            tree_count: 0,
85            nodes: Vec::new(),
86        }
87    }
88
89    #[must_use]
90    pub const fn is_connected(&self) -> bool {
91        self.tree_count == 1
92    }
93
94    #[must_use]
95    pub const fn edge_count(&self) -> usize {
96        self.edges.len()
97    }
98}
99
100impl Default for MstResult {
101    fn default() -> Self {
102        Self::empty()
103    }
104}
105
/// Disjoint-set forest (union by rank with path compression) used by Kruskal's algorithm.
///
/// Every id passed to [`UnionFind::find`] or [`UnionFind::union`] must have been given to
/// [`UnionFind::new`]; an unknown id panics on map indexing.
struct UnionFind {
    /// Parent pointer of each element; a root is its own parent.
    parent: HashMap<u64, u64>,
    /// Rank of each element; only consulted for roots.
    rank: HashMap<u64, usize>,
}

impl UnionFind {
    /// Creates one singleton set (rank 0) per id in `nodes`.
    fn new(nodes: &[u64]) -> Self {
        let mut parent = HashMap::new();
        let mut rank = HashMap::new();
        for &id in nodes {
            parent.insert(id, id);
            rank.insert(id, 0);
        }
        Self { parent, rank }
    }

    /// Returns the root of the set containing `x`, re-pointing every node visited on the
    /// way directly at that root.
    fn find(&mut self, x: u64) -> u64 {
        // Pass 1: climb parent pointers until reaching a self-parented root.
        let mut root = x;
        loop {
            let up = self.parent[&root];
            if up == root {
                break;
            }
            root = up;
        }

        // Pass 2: path compression along the same walk.
        let mut node = x;
        while node != root {
            let next = self.parent[&node];
            self.parent.insert(node, root);
            node = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `false` (and changes nothing) when both are already in the same set.
    fn union(&mut self, x: u64, y: u64) -> bool {
        let (rx, ry) = (self.find(x), self.find(y));
        if rx == ry {
            return false;
        }

        let (rank_x, rank_y) = (self.rank[&rx], self.rank[&ry]);
        if rank_x < rank_y {
            self.parent.insert(rx, ry);
        } else {
            // `ry` goes under `rx`; equal ranks make the merged tree one level taller.
            self.parent.insert(ry, rx);
            if rank_x == rank_y {
                self.rank.insert(rx, rank_x + 1);
            }
        }
        true
    }
}
155
156impl GraphEngine {
157    /// Compute the minimum spanning tree (or forest) using Kruskal's algorithm.
158    ///
159    /// Time complexity: O(E log E) for sorting edges.
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if edge retrieval fails.
164    pub fn minimum_spanning_tree(&self, config: &MstConfig) -> Result<MstResult> {
165        let nodes = self.get_all_node_ids()?;
166        if nodes.is_empty() {
167            return Ok(MstResult::empty());
168        }
169
170        // Collect all edges with weights
171        let mut weighted_edges: Vec<(u64, u64, u64, f64)> = Vec::new(); // (from, to, edge_id, weight)
172
173        for key in self.store().scan("edge:") {
174            if let Some(id_str) = key.strip_prefix("edge:") {
175                if let Ok(edge_id) = id_str.parse::<u64>() {
176                    if let Ok(edge) = self.get_edge(edge_id) {
177                        let weight = match edge.properties.get(&config.weight_property) {
178                            Some(PropertyValue::Float(w)) => *w,
179                            Some(PropertyValue::Int(w)) => *w as f64,
180                            _ => config.default_weight,
181                        };
182                        weighted_edges.push((edge.from, edge.to, edge_id, weight));
183                    }
184                }
185            }
186        }
187
188        // Sort edges by weight
189        weighted_edges.sort_by(|a, b| a.3.partial_cmp(&b.3).unwrap_or(std::cmp::Ordering::Equal));
190
191        // Kruskal's algorithm
192        let mut uf = UnionFind::new(&nodes);
193        let mut mst_edges = Vec::new();
194        let mut total_weight = 0.0;
195
196        for (from, to, edge_id, weight) in weighted_edges {
197            if uf.union(from, to) {
198                mst_edges.push(MstEdge {
199                    edge_id,
200                    from,
201                    to,
202                    weight,
203                });
204                total_weight += weight;
205
206                // Early termination if not computing forest
207                if !config.compute_forest && mst_edges.len() == nodes.len() - 1 {
208                    break;
209                }
210            }
211        }
212
213        // Count number of trees (connected components)
214        let mut roots = std::collections::HashSet::new();
215        for &node in &nodes {
216            roots.insert(uf.find(node));
217        }
218        let tree_count = roots.len();
219
220        Ok(MstResult {
221            edges: mst_edges,
222            total_weight,
223            tree_count,
224            nodes,
225        })
226    }
227
228    /// Compute minimum spanning forest (MST for each connected component).
229    ///
230    /// # Errors
231    ///
232    /// Returns an error if MST computation fails.
233    pub fn minimum_spanning_forest(&self, weight_property: &str) -> Result<Vec<MstResult>> {
234        let result =
235            self.minimum_spanning_tree(&MstConfig::new(weight_property).compute_forest(true))?;
236
237        if result.tree_count <= 1 {
238            return Ok(vec![result]);
239        }
240
241        // Group edges by component
242        let mut uf = UnionFind::new(&result.nodes);
243        for edge in &result.edges {
244            uf.union(edge.from, edge.to);
245        }
246
247        let mut components: HashMap<u64, Vec<MstEdge>> = HashMap::new();
248        let mut component_nodes: HashMap<u64, Vec<u64>> = HashMap::new();
249
250        for edge in result.edges {
251            let root = uf.find(edge.from);
252            components.entry(root).or_default().push(edge);
253        }
254
255        for &node in &result.nodes {
256            let root = uf.find(node);
257            component_nodes.entry(root).or_default().push(node);
258        }
259
260        let mut forests = Vec::new();
261        for (root, edges) in components {
262            let total_weight = edges.iter().map(|e| e.weight).sum();
263            let nodes = component_nodes.remove(&root).unwrap_or_default();
264            forests.push(MstResult {
265                edges,
266                total_weight,
267                tree_count: 1,
268                nodes,
269            });
270        }
271
272        // Add isolated nodes as separate trees
273        for (_, nodes) in component_nodes {
274            for node in nodes {
275                forests.push(MstResult {
276                    edges: Vec::new(),
277                    total_weight: 0.0,
278                    tree_count: 1,
279                    nodes: vec![node],
280                });
281            }
282        }
283
284        Ok(forests)
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an `EDGE` from `from` to `to` carrying a `Float` `weight` property.
    fn create_weighted_edge(engine: &GraphEngine, from: u64, to: u64, weight: f64) -> u64 {
        let mut props = HashMap::new();
        props.insert("weight".to_string(), PropertyValue::Float(weight));
        engine.create_edge(from, to, "EDGE", props, false).unwrap()
    }

    #[test]
    fn test_mst_empty_graph() {
        let engine = GraphEngine::new();
        let result = engine.minimum_spanning_tree(&MstConfig::default()).unwrap();
        assert!(result.edges.is_empty());
        assert_eq!(result.tree_count, 0);
    }

    #[test]
    fn test_mst_single_node() {
        let engine = GraphEngine::new();
        engine.create_node("A", HashMap::new()).unwrap();

        let result = engine.minimum_spanning_tree(&MstConfig::default()).unwrap();
        assert!(result.edges.is_empty());
        assert_eq!(result.tree_count, 1);
        assert_eq!(result.nodes.len(), 1);
    }

    #[test]
    fn test_mst_simple_triangle() {
        let engine = GraphEngine::new();
        let a = engine.create_node("A", HashMap::new()).unwrap();
        let b = engine.create_node("B", HashMap::new()).unwrap();
        let c = engine.create_node("C", HashMap::new()).unwrap();

        // Triangle with weights: A-B:1, B-C:2, A-C:3
        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, b, c, 2.0);
        create_weighted_edge(&engine, a, c, 3.0);

        let result = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();

        assert_eq!(result.edge_count(), 2); // MST has n-1 edges
        assert!((result.total_weight - 3.0).abs() < f64::EPSILON); // 1 + 2 = 3
        assert!(result.is_connected());
    }

    #[test]
    fn test_mst_triangle_without_forest_flag() {
        let engine = GraphEngine::new();
        let a = engine.create_node("A", HashMap::new()).unwrap();
        let b = engine.create_node("B", HashMap::new()).unwrap();
        let c = engine.create_node("C", HashMap::new()).unwrap();

        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, b, c, 2.0);
        create_weighted_edge(&engine, a, c, 3.0);

        // Early termination must not change the selected tree.
        let config = MstConfig::new("weight").compute_forest(false);
        let result = engine.minimum_spanning_tree(&config).unwrap();

        assert_eq!(result.edge_count(), 2);
        assert!((result.total_weight - 3.0).abs() < f64::EPSILON);
        assert!(result.is_connected());
    }

    #[test]
    fn test_mst_selects_minimum_edges() {
        let engine = GraphEngine::new();
        let a = engine.create_node("A", HashMap::new()).unwrap();
        let b = engine.create_node("B", HashMap::new()).unwrap();
        let c = engine.create_node("C", HashMap::new()).unwrap();
        let d = engine.create_node("D", HashMap::new()).unwrap();

        // Create a graph where MST should pick specific edges
        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, b, c, 2.0);
        create_weighted_edge(&engine, c, d, 3.0);
        create_weighted_edge(&engine, a, d, 10.0); // Should not be selected

        let result = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();

        assert_eq!(result.edge_count(), 3);
        assert!((result.total_weight - 6.0).abs() < f64::EPSILON); // 1 + 2 + 3
    }

    #[test]
    fn test_mst_forest() {
        let engine = GraphEngine::new();
        let a = engine.create_node("A", HashMap::new()).unwrap();
        let b = engine.create_node("B", HashMap::new()).unwrap();
        let c = engine.create_node("C", HashMap::new()).unwrap();
        let d = engine.create_node("D", HashMap::new()).unwrap();

        // Two disconnected components
        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, c, d, 2.0);

        let result = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();

        assert_eq!(result.edge_count(), 2);
        assert_eq!(result.tree_count, 2);
        assert!(!result.is_connected());
    }

    #[test]
    fn test_mst_forest_split() {
        let engine = GraphEngine::new();
        let a = engine.create_node("A", HashMap::new()).unwrap();
        let b = engine.create_node("B", HashMap::new()).unwrap();
        let c = engine.create_node("C", HashMap::new()).unwrap();
        let d = engine.create_node("D", HashMap::new()).unwrap();

        create_weighted_edge(&engine, a, b, 1.0);
        create_weighted_edge(&engine, c, d, 2.0);

        let forests = engine.minimum_spanning_forest("weight").unwrap();
        assert_eq!(forests.len(), 2);
        assert!(forests.iter().all(|f| f.tree_count == 1 && f.edge_count() == 1));

        // Per-component totals, compared order-independently.
        let mut weights: Vec<f64> = forests.iter().map(|f| f.total_weight).collect();
        weights.sort_by(f64::total_cmp);
        assert_eq!(weights, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mst_forest_isolated_node() {
        let engine = GraphEngine::new();
        let a = engine.create_node("A", HashMap::new()).unwrap();
        let b = engine.create_node("B", HashMap::new()).unwrap();
        let c = engine.create_node("C", HashMap::new()).unwrap();

        create_weighted_edge(&engine, a, b, 4.0);

        let result = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();
        assert_eq!(result.tree_count, 2);
        assert_eq!(result.nodes.len(), 3);

        let forests = engine.minimum_spanning_forest("weight").unwrap();
        assert_eq!(forests.len(), 2);
        assert!(forests.iter().all(|f| f.tree_count == 1));

        let singleton = forests
            .iter()
            .find(|f| f.nodes == vec![c])
            .expect("isolated node forms its own tree");
        assert!(singleton.edges.is_empty());
        assert!(singleton.total_weight.abs() < f64::EPSILON);
    }

    #[test]
    fn test_mst_default_weight() {
        let engine = GraphEngine::new();
        let a = engine.create_node("A", HashMap::new()).unwrap();
        let b = engine.create_node("B", HashMap::new()).unwrap();

        // Edge without weight property
        engine
            .create_edge(a, b, "EDGE", HashMap::new(), false)
            .unwrap();

        let config = MstConfig::new("weight").default_weight(5.0);
        let result = engine.minimum_spanning_tree(&config).unwrap();

        assert_eq!(result.edge_count(), 1);
        assert!((result.total_weight - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_mst_integer_weight() {
        let engine = GraphEngine::new();
        let a = engine.create_node("A", HashMap::new()).unwrap();
        let b = engine.create_node("B", HashMap::new()).unwrap();

        let mut props = HashMap::new();
        props.insert("weight".to_string(), PropertyValue::Int(42));
        engine.create_edge(a, b, "EDGE", props, false).unwrap();

        let result = engine
            .minimum_spanning_tree(&MstConfig::new("weight"))
            .unwrap();
        assert!((result.total_weight - 42.0).abs() < f64::EPSILON);
    }
}