Skip to main content

grafeo_adapters/plugins/algorithms/
mst.rs

1//! Minimum Spanning Tree algorithms: Kruskal, Prim.
2//!
3//! These algorithms find a tree that connects all nodes in an undirected
4//! graph with minimum total edge weight.
5
6use std::collections::BinaryHeap;
7use std::sync::OnceLock;
8
9use grafeo_common::types::{EdgeId, NodeId, Value};
10use grafeo_common::utils::error::Result;
11use grafeo_common::utils::hash::FxHashMap;
12use grafeo_core::graph::Direction;
13use grafeo_core::graph::lpg::LpgStore;
14
15use super::super::{AlgorithmResult, ParameterDef, ParameterType, Parameters};
16use super::components::UnionFind;
17use super::traits::{GraphAlgorithm, MinScored};
18
19// ============================================================================
20// Edge Weight Extraction
21// ============================================================================
22
23/// Extracts edge weight from a property value.
24fn extract_weight(store: &LpgStore, edge_id: EdgeId, weight_prop: Option<&str>) -> f64 {
25    if let Some(prop_name) = weight_prop
26        && let Some(edge) = store.get_edge(edge_id)
27        && let Some(value) = edge.get_property(prop_name)
28    {
29        return match value {
30            Value::Int64(i) => *i as f64,
31            Value::Float64(f) => *f,
32            _ => 1.0,
33        };
34    }
35    1.0
36}
37
38// ============================================================================
39// MST Result
40// ============================================================================
41
42/// Result of MST algorithms.
43#[derive(Debug, Clone)]
44pub struct MstResult {
45    /// Edges in the MST: (source, target, edge_id, weight)
46    pub edges: Vec<(NodeId, NodeId, EdgeId, f64)>,
47    /// Total weight of the MST.
48    pub total_weight: f64,
49}
50
51impl MstResult {
52    /// Returns the number of edges in the MST.
53    pub fn edge_count(&self) -> usize {
54        self.edges.len()
55    }
56
57    /// Returns true if this is a valid spanning tree for n nodes.
58    pub fn is_spanning_tree(&self, node_count: usize) -> bool {
59        if node_count == 0 {
60            return self.edges.is_empty();
61        }
62        self.edges.len() == node_count - 1
63    }
64}
65
66// ============================================================================
67// Kruskal's Algorithm
68// ============================================================================
69
70/// Computes the Minimum Spanning Tree using Kruskal's algorithm.
71///
72/// Kruskal's algorithm sorts all edges by weight and greedily adds
73/// edges that don't create a cycle (using Union-Find).
74///
75/// # Arguments
76///
77/// * `store` - The graph store
78/// * `weight_property` - Optional property name for edge weights (defaults to 1.0)
79///
80/// # Returns
81///
82/// The MST edges and total weight.
83///
84/// # Complexity
85///
86/// O(E log E) for sorting edges
87pub fn kruskal(store: &LpgStore, weight_property: Option<&str>) -> MstResult {
88    let nodes = store.node_ids();
89    let n = nodes.len();
90
91    if n == 0 {
92        return MstResult {
93            edges: Vec::new(),
94            total_weight: 0.0,
95        };
96    }
97
98    // Build node index mapping
99    let mut node_to_idx: FxHashMap<NodeId, usize> = FxHashMap::default();
100    for (idx, &node) in nodes.iter().enumerate() {
101        node_to_idx.insert(node, idx);
102    }
103
104    // Collect all edges with weights (treating as undirected)
105    let mut edges: Vec<(f64, NodeId, NodeId, EdgeId)> = Vec::new();
106    let mut seen_edges: std::collections::HashSet<(usize, usize)> =
107        std::collections::HashSet::new();
108
109    for &node in &nodes {
110        let i = *node_to_idx.get(&node).unwrap();
111        for (neighbor, edge_id) in store.edges_from(node, Direction::Outgoing) {
112            if let Some(&j) = node_to_idx.get(&neighbor) {
113                // For undirected: only add each edge once
114                let key = if i < j { (i, j) } else { (j, i) };
115                if !seen_edges.contains(&key) {
116                    seen_edges.insert(key);
117                    let weight = extract_weight(store, edge_id, weight_property);
118                    edges.push((weight, node, neighbor, edge_id));
119                }
120            }
121        }
122    }
123
124    // Sort edges by weight
125    edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
126
127    // Initialize Union-Find
128    let mut uf = UnionFind::new(n);
129
130    let mut mst_edges: Vec<(NodeId, NodeId, EdgeId, f64)> = Vec::new();
131    let mut total_weight = 0.0;
132
133    for (weight, src, dst, edge_id) in edges {
134        let i = *node_to_idx.get(&src).unwrap();
135        let j = *node_to_idx.get(&dst).unwrap();
136
137        if uf.find(i) != uf.find(j) {
138            uf.union(i, j);
139            mst_edges.push((src, dst, edge_id, weight));
140            total_weight += weight;
141
142            // MST has n-1 edges
143            if mst_edges.len() == n - 1 {
144                break;
145            }
146        }
147    }
148
149    MstResult {
150        edges: mst_edges,
151        total_weight,
152    }
153}
154
155// ============================================================================
156// Prim's Algorithm
157// ============================================================================
158
159/// Computes the Minimum Spanning Tree using Prim's algorithm.
160///
161/// Prim's algorithm grows the MST from a starting node, always adding
162/// the minimum weight edge that connects a tree node to a non-tree node.
163///
164/// # Arguments
165///
166/// * `store` - The graph store
167/// * `weight_property` - Optional property name for edge weights (defaults to 1.0)
168/// * `start` - Optional starting node (defaults to first node)
169///
170/// # Returns
171///
172/// The MST edges and total weight.
173///
174/// # Complexity
175///
176/// O(E log V) using a binary heap
177pub fn prim(store: &LpgStore, weight_property: Option<&str>, start: Option<NodeId>) -> MstResult {
178    let nodes = store.node_ids();
179    let n = nodes.len();
180
181    if n == 0 {
182        return MstResult {
183            edges: Vec::new(),
184            total_weight: 0.0,
185        };
186    }
187
188    // Start from the first node or specified start
189    let start_node = start.unwrap_or(nodes[0]);
190
191    // Verify start node exists
192    if store.get_node(start_node).is_none() {
193        return MstResult {
194            edges: Vec::new(),
195            total_weight: 0.0,
196        };
197    }
198
199    let mut in_tree: FxHashMap<NodeId, bool> = FxHashMap::default();
200    let mut mst_edges: Vec<(NodeId, NodeId, EdgeId, f64)> = Vec::new();
201    let mut total_weight = 0.0;
202
203    // Priority queue: (weight, source, target, edge_id)
204    let mut heap: BinaryHeap<MinScored<f64, (NodeId, NodeId, EdgeId)>> = BinaryHeap::new();
205
206    // Start with the first node
207    in_tree.insert(start_node, true);
208
209    // Add edges from start node
210    for (neighbor, edge_id) in store.edges_from(start_node, Direction::Outgoing) {
211        let weight = extract_weight(store, edge_id, weight_property);
212        heap.push(MinScored::new(weight, (start_node, neighbor, edge_id)));
213    }
214
215    // Also consider incoming edges (for undirected behavior)
216    for &other in &nodes {
217        for (neighbor, edge_id) in store.edges_from(other, Direction::Outgoing) {
218            if neighbor == start_node {
219                let weight = extract_weight(store, edge_id, weight_property);
220                heap.push(MinScored::new(weight, (other, start_node, edge_id)));
221            }
222        }
223    }
224
225    while let Some(MinScored(weight, (src, dst, edge_id))) = heap.pop() {
226        // Skip if target already in tree
227        if *in_tree.get(&dst).unwrap_or(&false) {
228            continue;
229        }
230
231        // Add edge to MST
232        in_tree.insert(dst, true);
233        mst_edges.push((src, dst, edge_id, weight));
234        total_weight += weight;
235
236        // Add edges from new node
237        for (neighbor, new_edge_id) in store.edges_from(dst, Direction::Outgoing) {
238            if !*in_tree.get(&neighbor).unwrap_or(&false) {
239                let new_weight = extract_weight(store, new_edge_id, weight_property);
240                heap.push(MinScored::new(new_weight, (dst, neighbor, new_edge_id)));
241            }
242        }
243
244        // Also consider incoming edges
245        for &other in &nodes {
246            if !*in_tree.get(&other).unwrap_or(&false) {
247                for (neighbor, new_edge_id) in store.edges_from(other, Direction::Outgoing) {
248                    if neighbor == dst {
249                        let new_weight = extract_weight(store, new_edge_id, weight_property);
250                        heap.push(MinScored::new(new_weight, (other, dst, new_edge_id)));
251                    }
252                }
253            }
254        }
255
256        // MST has n-1 edges
257        if mst_edges.len() == n - 1 {
258            break;
259        }
260    }
261
262    MstResult {
263        edges: mst_edges,
264        total_weight,
265    }
266}
267
268// ============================================================================
269// Algorithm Wrappers for Plugin Registry
270// ============================================================================
271
272/// Static parameter definitions for Kruskal algorithm.
273static KRUSKAL_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
274
275fn kruskal_params() -> &'static [ParameterDef] {
276    KRUSKAL_PARAMS.get_or_init(|| {
277        vec![ParameterDef {
278            name: "weight".to_string(),
279            description: "Edge property name for weights (default: 1.0)".to_string(),
280            param_type: ParameterType::String,
281            required: false,
282            default: None,
283        }]
284    })
285}
286
287/// Kruskal's MST algorithm wrapper.
288pub struct KruskalAlgorithm;
289
290impl GraphAlgorithm for KruskalAlgorithm {
291    fn name(&self) -> &str {
292        "kruskal"
293    }
294
295    fn description(&self) -> &str {
296        "Kruskal's Minimum Spanning Tree algorithm"
297    }
298
299    fn parameters(&self) -> &[ParameterDef] {
300        kruskal_params()
301    }
302
303    fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
304        let weight_prop = params.get_string("weight");
305
306        let result = kruskal(store, weight_prop);
307
308        let mut output = AlgorithmResult::new(vec![
309            "source".to_string(),
310            "target".to_string(),
311            "weight".to_string(),
312            "total_weight".to_string(),
313        ]);
314
315        for (src, dst, _edge_id, weight) in result.edges {
316            output.add_row(vec![
317                Value::Int64(src.0 as i64),
318                Value::Int64(dst.0 as i64),
319                Value::Float64(weight),
320                Value::Float64(result.total_weight),
321            ]);
322        }
323
324        Ok(output)
325    }
326}
327
328/// Static parameter definitions for Prim algorithm.
329static PRIM_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
330
331fn prim_params() -> &'static [ParameterDef] {
332    PRIM_PARAMS.get_or_init(|| {
333        vec![
334            ParameterDef {
335                name: "weight".to_string(),
336                description: "Edge property name for weights (default: 1.0)".to_string(),
337                param_type: ParameterType::String,
338                required: false,
339                default: None,
340            },
341            ParameterDef {
342                name: "start".to_string(),
343                description: "Starting node ID (optional)".to_string(),
344                param_type: ParameterType::NodeId,
345                required: false,
346                default: None,
347            },
348        ]
349    })
350}
351
352/// Prim's MST algorithm wrapper.
353pub struct PrimAlgorithm;
354
355impl GraphAlgorithm for PrimAlgorithm {
356    fn name(&self) -> &str {
357        "prim"
358    }
359
360    fn description(&self) -> &str {
361        "Prim's Minimum Spanning Tree algorithm"
362    }
363
364    fn parameters(&self) -> &[ParameterDef] {
365        prim_params()
366    }
367
368    fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
369        let weight_prop = params.get_string("weight");
370        let start = params.get_int("start").map(|id| NodeId::new(id as u64));
371
372        let result = prim(store, weight_prop, start);
373
374        let mut output = AlgorithmResult::new(vec![
375            "source".to_string(),
376            "target".to_string(),
377            "weight".to_string(),
378            "total_weight".to_string(),
379        ]);
380
381        for (src, dst, _edge_id, weight) in result.edges {
382            output.add_row(vec![
383                Value::Int64(src.0 as i64),
384                Value::Int64(dst.0 as i64),
385                Value::Float64(weight),
386                Value::Float64(result.total_weight),
387            ]);
388        }
389
390        Ok(output)
391    }
392}
393
394// ============================================================================
395// Tests
396// ============================================================================
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point total weights.
    const EPSILON: f64 = 0.001;

    /// Builds a triangle over three nodes where every undirected link is
    /// stored as a pair of opposite directed edges with the same weight:
    /// 0-1 (weight 1), 1-2 (weight 2), 0-2 (weight 3).
    fn create_weighted_triangle() -> LpgStore {
        let store = LpgStore::new();

        let a = store.create_node(&["Node"]);
        let b = store.create_node(&["Node"]);
        let c = store.create_node(&["Node"]);

        for (from, to, weight) in [(a, b, 1.0), (b, c, 2.0), (a, c, 3.0)] {
            store.create_edge_with_props(from, to, "EDGE", [("weight", Value::Float64(weight))]);
            store.create_edge_with_props(to, from, "EDGE", [("weight", Value::Float64(weight))]);
        }

        store
    }

    /// Builds the unweighted chain 0 - 1 - 2 - 3, each link stored as a pair
    /// of opposite directed edges.
    fn create_simple_chain() -> LpgStore {
        let store = LpgStore::new();

        let a = store.create_node(&["Node"]);
        let b = store.create_node(&["Node"]);
        let c = store.create_node(&["Node"]);
        let d = store.create_node(&["Node"]);

        for (from, to) in [(a, b), (b, c), (c, d)] {
            store.create_edge(from, to, "EDGE");
            store.create_edge(to, from, "EDGE");
        }

        store
    }

    #[test]
    fn test_kruskal_triangle() {
        let graph = create_weighted_triangle();
        let mst = kruskal(&graph, Some("weight"));

        // Three nodes need two tree edges.
        assert_eq!(mst.edges.len(), 2);

        // The two lightest links (1 + 2) are chosen; the weight-3 link is not.
        assert!((mst.total_weight - 3.0).abs() < EPSILON);
    }

    #[test]
    fn test_kruskal_chain() {
        let graph = create_simple_chain();
        let mst = kruskal(&graph, None);

        // Four nodes need three tree edges.
        assert_eq!(mst.edges.len(), 3);

        // Without a weight property every edge counts as 1.0.
        assert!((mst.total_weight - 3.0).abs() < EPSILON);
    }

    #[test]
    fn test_kruskal_empty() {
        let graph = LpgStore::new();
        let mst = kruskal(&graph, None);

        assert!(mst.edges.is_empty());
        assert_eq!(mst.total_weight, 0.0);
    }

    #[test]
    fn test_kruskal_single_node() {
        let graph = LpgStore::new();
        graph.create_node(&["Node"]);

        let mst = kruskal(&graph, None);

        assert!(mst.edges.is_empty());
        assert!(mst.is_spanning_tree(1));
    }

    #[test]
    fn test_prim_triangle() {
        let graph = create_weighted_triangle();
        let mst = prim(&graph, Some("weight"), None);

        // Three nodes need two tree edges.
        assert_eq!(mst.edges.len(), 2);

        // The two lightest links (1 + 2) are chosen.
        assert!((mst.total_weight - 3.0).abs() < EPSILON);
    }

    #[test]
    fn test_prim_chain() {
        let graph = create_simple_chain();
        let mst = prim(&graph, None, None);

        // Four nodes need three tree edges.
        assert_eq!(mst.edges.len(), 3);
    }

    #[test]
    fn test_prim_with_start() {
        let graph = create_simple_chain();
        let mst = prim(&graph, None, Some(NodeId::new(2)));

        // Starting from node 2 must still span the whole chain.
        assert_eq!(mst.edges.len(), 3);
    }

    #[test]
    fn test_prim_empty() {
        let graph = LpgStore::new();
        let mst = prim(&graph, None, None);

        assert!(mst.edges.is_empty());
    }

    #[test]
    fn test_kruskal_prim_same_weight() {
        let graph = create_weighted_triangle();

        let via_kruskal = kruskal(&graph, Some("weight"));
        let via_prim = prim(&graph, Some("weight"), None);

        // Both algorithms must agree on the total weight.
        assert!((via_kruskal.total_weight - via_prim.total_weight).abs() < EPSILON);
    }

    #[test]
    fn test_mst_is_spanning_tree() {
        let graph = create_simple_chain();
        let mst = kruskal(&graph, None);

        assert!(mst.is_spanning_tree(4));
    }
}
542}