Skip to main content

graphos_adapters/plugins/algorithms/
mst.rs

//! Minimum Spanning Tree algorithms: Kruskal, Prim.
//!
//! These algorithms find a tree that connects all nodes in an undirected
//! graph with minimum total edge weight.
5
6use std::collections::BinaryHeap;
7use std::sync::OnceLock;
8
9use graphos_common::types::{EdgeId, NodeId, Value};
10use graphos_common::utils::error::Result;
11use graphos_common::utils::hash::FxHashMap;
12use graphos_core::graph::Direction;
13use graphos_core::graph::lpg::LpgStore;
14
15use super::super::{AlgorithmResult, ParameterDef, ParameterType, Parameters};
16use super::components::UnionFind;
17use super::traits::{GraphAlgorithm, MinScored};
18
// ============================================================================
// Edge Weight Extraction
// ============================================================================
22
23/// Extracts edge weight from a property value.
24fn extract_weight(store: &LpgStore, edge_id: EdgeId, weight_prop: Option<&str>) -> f64 {
25    if let Some(prop_name) = weight_prop {
26        if let Some(edge) = store.get_edge(edge_id) {
27            if let Some(value) = edge.get_property(prop_name) {
28                return match value {
29                    Value::Int64(i) => *i as f64,
30                    Value::Float64(f) => *f,
31                    _ => 1.0,
32                };
33            }
34        }
35    }
36    1.0
37}
38
// ============================================================================
// MST Result
// ============================================================================
42
43/// Result of MST algorithms.
44#[derive(Debug, Clone)]
45pub struct MstResult {
46    /// Edges in the MST: (source, target, edge_id, weight)
47    pub edges: Vec<(NodeId, NodeId, EdgeId, f64)>,
48    /// Total weight of the MST.
49    pub total_weight: f64,
50}
51
52impl MstResult {
53    /// Returns the number of edges in the MST.
54    pub fn edge_count(&self) -> usize {
55        self.edges.len()
56    }
57
58    /// Returns true if this is a valid spanning tree for n nodes.
59    pub fn is_spanning_tree(&self, node_count: usize) -> bool {
60        if node_count == 0 {
61            return self.edges.is_empty();
62        }
63        self.edges.len() == node_count - 1
64    }
65}
66
// ============================================================================
// Kruskal's Algorithm
// ============================================================================
70
71/// Computes the Minimum Spanning Tree using Kruskal's algorithm.
72///
73/// Kruskal's algorithm sorts all edges by weight and greedily adds
74/// edges that don't create a cycle (using Union-Find).
75///
76/// # Arguments
77///
78/// * `store` - The graph store
79/// * `weight_property` - Optional property name for edge weights (defaults to 1.0)
80///
81/// # Returns
82///
83/// The MST edges and total weight.
84///
85/// # Complexity
86///
87/// O(E log E) for sorting edges
88pub fn kruskal(store: &LpgStore, weight_property: Option<&str>) -> MstResult {
89    let nodes = store.node_ids();
90    let n = nodes.len();
91
92    if n == 0 {
93        return MstResult {
94            edges: Vec::new(),
95            total_weight: 0.0,
96        };
97    }
98
99    // Build node index mapping
100    let mut node_to_idx: FxHashMap<NodeId, usize> = FxHashMap::default();
101    for (idx, &node) in nodes.iter().enumerate() {
102        node_to_idx.insert(node, idx);
103    }
104
105    // Collect all edges with weights (treating as undirected)
106    let mut edges: Vec<(f64, NodeId, NodeId, EdgeId)> = Vec::new();
107    let mut seen_edges: std::collections::HashSet<(usize, usize)> =
108        std::collections::HashSet::new();
109
110    for &node in &nodes {
111        let i = *node_to_idx.get(&node).unwrap();
112        for (neighbor, edge_id) in store.edges_from(node, Direction::Outgoing) {
113            if let Some(&j) = node_to_idx.get(&neighbor) {
114                // For undirected: only add each edge once
115                let key = if i < j { (i, j) } else { (j, i) };
116                if !seen_edges.contains(&key) {
117                    seen_edges.insert(key);
118                    let weight = extract_weight(store, edge_id, weight_property);
119                    edges.push((weight, node, neighbor, edge_id));
120                }
121            }
122        }
123    }
124
125    // Sort edges by weight
126    edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
127
128    // Initialize Union-Find
129    let mut uf = UnionFind::new(n);
130
131    let mut mst_edges: Vec<(NodeId, NodeId, EdgeId, f64)> = Vec::new();
132    let mut total_weight = 0.0;
133
134    for (weight, src, dst, edge_id) in edges {
135        let i = *node_to_idx.get(&src).unwrap();
136        let j = *node_to_idx.get(&dst).unwrap();
137
138        if uf.find(i) != uf.find(j) {
139            uf.union(i, j);
140            mst_edges.push((src, dst, edge_id, weight));
141            total_weight += weight;
142
143            // MST has n-1 edges
144            if mst_edges.len() == n - 1 {
145                break;
146            }
147        }
148    }
149
150    MstResult {
151        edges: mst_edges,
152        total_weight,
153    }
154}
155
// ============================================================================
// Prim's Algorithm
// ============================================================================
159
160/// Computes the Minimum Spanning Tree using Prim's algorithm.
161///
162/// Prim's algorithm grows the MST from a starting node, always adding
163/// the minimum weight edge that connects a tree node to a non-tree node.
164///
165/// # Arguments
166///
167/// * `store` - The graph store
168/// * `weight_property` - Optional property name for edge weights (defaults to 1.0)
169/// * `start` - Optional starting node (defaults to first node)
170///
171/// # Returns
172///
173/// The MST edges and total weight.
174///
175/// # Complexity
176///
177/// O(E log V) using a binary heap
178pub fn prim(store: &LpgStore, weight_property: Option<&str>, start: Option<NodeId>) -> MstResult {
179    let nodes = store.node_ids();
180    let n = nodes.len();
181
182    if n == 0 {
183        return MstResult {
184            edges: Vec::new(),
185            total_weight: 0.0,
186        };
187    }
188
189    // Start from the first node or specified start
190    let start_node = start.unwrap_or(nodes[0]);
191
192    // Verify start node exists
193    if store.get_node(start_node).is_none() {
194        return MstResult {
195            edges: Vec::new(),
196            total_weight: 0.0,
197        };
198    }
199
200    let mut in_tree: FxHashMap<NodeId, bool> = FxHashMap::default();
201    let mut mst_edges: Vec<(NodeId, NodeId, EdgeId, f64)> = Vec::new();
202    let mut total_weight = 0.0;
203
204    // Priority queue: (weight, source, target, edge_id)
205    let mut heap: BinaryHeap<MinScored<f64, (NodeId, NodeId, EdgeId)>> = BinaryHeap::new();
206
207    // Start with the first node
208    in_tree.insert(start_node, true);
209
210    // Add edges from start node
211    for (neighbor, edge_id) in store.edges_from(start_node, Direction::Outgoing) {
212        let weight = extract_weight(store, edge_id, weight_property);
213        heap.push(MinScored::new(weight, (start_node, neighbor, edge_id)));
214    }
215
216    // Also consider incoming edges (for undirected behavior)
217    for &other in &nodes {
218        for (neighbor, edge_id) in store.edges_from(other, Direction::Outgoing) {
219            if neighbor == start_node {
220                let weight = extract_weight(store, edge_id, weight_property);
221                heap.push(MinScored::new(weight, (other, start_node, edge_id)));
222            }
223        }
224    }
225
226    while let Some(MinScored(weight, (src, dst, edge_id))) = heap.pop() {
227        // Skip if target already in tree
228        if *in_tree.get(&dst).unwrap_or(&false) {
229            continue;
230        }
231
232        // Add edge to MST
233        in_tree.insert(dst, true);
234        mst_edges.push((src, dst, edge_id, weight));
235        total_weight += weight;
236
237        // Add edges from new node
238        for (neighbor, new_edge_id) in store.edges_from(dst, Direction::Outgoing) {
239            if !*in_tree.get(&neighbor).unwrap_or(&false) {
240                let new_weight = extract_weight(store, new_edge_id, weight_property);
241                heap.push(MinScored::new(new_weight, (dst, neighbor, new_edge_id)));
242            }
243        }
244
245        // Also consider incoming edges
246        for &other in &nodes {
247            if !*in_tree.get(&other).unwrap_or(&false) {
248                for (neighbor, new_edge_id) in store.edges_from(other, Direction::Outgoing) {
249                    if neighbor == dst {
250                        let new_weight = extract_weight(store, new_edge_id, weight_property);
251                        heap.push(MinScored::new(new_weight, (other, dst, new_edge_id)));
252                    }
253                }
254            }
255        }
256
257        // MST has n-1 edges
258        if mst_edges.len() == n - 1 {
259            break;
260        }
261    }
262
263    MstResult {
264        edges: mst_edges,
265        total_weight,
266    }
267}
268
// ============================================================================
// Algorithm Wrappers for Plugin Registry
// ============================================================================
272
273/// Static parameter definitions for Kruskal algorithm.
274static KRUSKAL_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
275
276fn kruskal_params() -> &'static [ParameterDef] {
277    KRUSKAL_PARAMS.get_or_init(|| {
278        vec![ParameterDef {
279            name: "weight".to_string(),
280            description: "Edge property name for weights (default: 1.0)".to_string(),
281            param_type: ParameterType::String,
282            required: false,
283            default: None,
284        }]
285    })
286}
287
288/// Kruskal's MST algorithm wrapper.
289pub struct KruskalAlgorithm;
290
291impl GraphAlgorithm for KruskalAlgorithm {
292    fn name(&self) -> &str {
293        "kruskal"
294    }
295
296    fn description(&self) -> &str {
297        "Kruskal's Minimum Spanning Tree algorithm"
298    }
299
300    fn parameters(&self) -> &[ParameterDef] {
301        kruskal_params()
302    }
303
304    fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
305        let weight_prop = params.get_string("weight");
306
307        let result = kruskal(store, weight_prop);
308
309        let mut output = AlgorithmResult::new(vec![
310            "source".to_string(),
311            "target".to_string(),
312            "weight".to_string(),
313            "total_weight".to_string(),
314        ]);
315
316        for (src, dst, _edge_id, weight) in result.edges {
317            output.add_row(vec![
318                Value::Int64(src.0 as i64),
319                Value::Int64(dst.0 as i64),
320                Value::Float64(weight),
321                Value::Float64(result.total_weight),
322            ]);
323        }
324
325        Ok(output)
326    }
327}
328
329/// Static parameter definitions for Prim algorithm.
330static PRIM_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
331
332fn prim_params() -> &'static [ParameterDef] {
333    PRIM_PARAMS.get_or_init(|| {
334        vec![
335            ParameterDef {
336                name: "weight".to_string(),
337                description: "Edge property name for weights (default: 1.0)".to_string(),
338                param_type: ParameterType::String,
339                required: false,
340                default: None,
341            },
342            ParameterDef {
343                name: "start".to_string(),
344                description: "Starting node ID (optional)".to_string(),
345                param_type: ParameterType::NodeId,
346                required: false,
347                default: None,
348            },
349        ]
350    })
351}
352
353/// Prim's MST algorithm wrapper.
354pub struct PrimAlgorithm;
355
356impl GraphAlgorithm for PrimAlgorithm {
357    fn name(&self) -> &str {
358        "prim"
359    }
360
361    fn description(&self) -> &str {
362        "Prim's Minimum Spanning Tree algorithm"
363    }
364
365    fn parameters(&self) -> &[ParameterDef] {
366        prim_params()
367    }
368
369    fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
370        let weight_prop = params.get_string("weight");
371        let start = params.get_int("start").map(|id| NodeId::new(id as u64));
372
373        let result = prim(store, weight_prop, start);
374
375        let mut output = AlgorithmResult::new(vec![
376            "source".to_string(),
377            "target".to_string(),
378            "weight".to_string(),
379            "total_weight".to_string(),
380        ]);
381
382        for (src, dst, _edge_id, weight) in result.edges {
383            output.add_row(vec![
384                Value::Int64(src.0 as i64),
385                Value::Int64(dst.0 as i64),
386                Value::Float64(weight),
387                Value::Float64(result.total_weight),
388            ]);
389        }
390
391        Ok(output)
392    }
393}
394
// ============================================================================
// Tests
// ============================================================================
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point totals.
    const EPSILON: f64 = 0.001;

    /// Builds a triangle on nodes 0, 1, 2 where every undirected edge is
    /// stored as two opposite directed edges:
    /// 0-1 has weight 1, 1-2 has weight 2, 0-2 has weight 3.
    fn create_weighted_triangle() -> LpgStore {
        let store = LpgStore::new();

        let a = store.create_node(&["Node"]);
        let b = store.create_node(&["Node"]);
        let c = store.create_node(&["Node"]);

        for (from, to, weight) in [(a, b, 1.0), (b, c, 2.0), (a, c, 3.0)] {
            store.create_edge_with_props(from, to, "EDGE", [("weight", Value::Float64(weight))]);
            store.create_edge_with_props(to, from, "EDGE", [("weight", Value::Float64(weight))]);
        }

        store
    }

    /// Builds the unweighted chain 0 - 1 - 2 - 3, again with each link
    /// stored in both directions.
    fn create_simple_chain() -> LpgStore {
        let store = LpgStore::new();

        let chain = [
            store.create_node(&["Node"]),
            store.create_node(&["Node"]),
            store.create_node(&["Node"]),
            store.create_node(&["Node"]),
        ];

        for link in chain.windows(2) {
            store.create_edge(link[0], link[1], "EDGE");
            store.create_edge(link[1], link[0], "EDGE");
        }

        store
    }

    #[test]
    fn test_kruskal_triangle() {
        let graph = create_weighted_triangle();
        let mst = kruskal(&graph, Some("weight"));

        // Three nodes need two tree edges.
        assert_eq!(mst.edges.len(), 2);

        // The two cheapest edges (1 + 2) are chosen; the weight-3 edge is left out.
        assert!((mst.total_weight - 3.0).abs() < EPSILON);
    }

    #[test]
    fn test_kruskal_chain() {
        let graph = create_simple_chain();
        let mst = kruskal(&graph, None);

        // Four nodes need three tree edges.
        assert_eq!(mst.edges.len(), 3);

        // Without a weight property every edge counts as 1.0.
        assert!((mst.total_weight - 3.0).abs() < EPSILON);
    }

    #[test]
    fn test_kruskal_empty() {
        let graph = LpgStore::new();
        let mst = kruskal(&graph, None);

        assert!(mst.edges.is_empty());
        assert_eq!(mst.total_weight, 0.0);
    }

    #[test]
    fn test_kruskal_single_node() {
        let graph = LpgStore::new();
        graph.create_node(&["Node"]);

        let mst = kruskal(&graph, None);

        assert!(mst.edges.is_empty());
        assert!(mst.is_spanning_tree(1));
    }

    #[test]
    fn test_prim_triangle() {
        let graph = create_weighted_triangle();
        let mst = prim(&graph, Some("weight"), None);

        // Three nodes need two tree edges.
        assert_eq!(mst.edges.len(), 2);

        // The two cheapest edges (1 + 2) are chosen.
        assert!((mst.total_weight - 3.0).abs() < EPSILON);
    }

    #[test]
    fn test_prim_chain() {
        let graph = create_simple_chain();
        let mst = prim(&graph, None, None);

        // Four nodes need three tree edges.
        assert_eq!(mst.edges.len(), 3);
    }

    #[test]
    fn test_prim_with_start() {
        let graph = create_simple_chain();
        let mst = prim(&graph, None, Some(NodeId::new(2)));

        // Growing the tree from node 2 must still reach every node.
        assert_eq!(mst.edges.len(), 3);
    }

    #[test]
    fn test_prim_empty() {
        let graph = LpgStore::new();
        let mst = prim(&graph, None, None);

        assert!(mst.edges.is_empty());
    }

    #[test]
    fn test_kruskal_prim_same_weight() {
        let graph = create_weighted_triangle();

        let from_kruskal = kruskal(&graph, Some("weight"));
        let from_prim = prim(&graph, Some("weight"), None);

        // Either algorithm must arrive at the same minimum total.
        let difference = (from_kruskal.total_weight - from_prim.total_weight).abs();
        assert!(difference < EPSILON);
    }

    #[test]
    fn test_mst_is_spanning_tree() {
        let graph = create_simple_chain();
        let mst = kruskal(&graph, None);

        assert!(mst.is_spanning_tree(4));
    }
}
543}