Skip to main content

graphmind_graph_algorithms/
mst.rs

1//! Minimum Spanning Tree algorithms
2//!
3//! Implements Prim's algorithm for MST.
4
5use super::common::{GraphView, NodeId};
6use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashSet};
8
9pub struct MSTResult {
10    pub total_weight: f64,
11    pub edges: Vec<(NodeId, NodeId, f64)>, // (source, target, weight)
12}
13
/// Heap entry for Prim's algorithm: a candidate edge from a tree node `source`
/// (dense index) to `target`, which was outside the tree when the entry was pushed.
///
/// The ordering is reversed so that `BinaryHeap` (a max-heap) pops the *lightest*
/// edge first. Weights are compared with `f64::total_cmp`, which is a true total
/// order (NaN included). Ties are broken on `target` and then `source`, so
/// equal-weight edges pop in a deterministic order and `Eq` agrees with `Ord`.
#[derive(Copy, Clone, Debug)]
struct EdgeState {
    weight: f64,
    source: usize,
    target: usize,
}

impl PartialEq for EdgeState {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality matches the ordering exactly. A derived
        // impl would use IEEE `==` (NaN != NaN) and would disagree with `Ord` on ties.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for EdgeState {}

impl Ord for EdgeState {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` is compared against `self` on every key: reversed for a min-heap.
        other
            .weight
            .total_cmp(&self.weight)
            .then_with(|| other.target.cmp(&self.target))
            .then_with(|| other.source.cmp(&self.source))
    }
}

impl PartialOrd for EdgeState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
38
39/// Prim's Algorithm for Minimum Spanning Tree
40///
41/// Treats graph as undirected (ignores edge direction).
42/// If graph is disconnected, returns MST of the component containing `start_node`
43/// (or arbitrary node if not specified).
44pub fn prim_mst(view: &GraphView) -> MSTResult {
45    if view.node_count == 0 {
46        return MSTResult {
47            total_weight: 0.0,
48            edges: Vec::new(),
49        };
50    }
51
52    let start_idx = 0; // Start from first node
53    let mut visited = HashSet::new();
54    let mut heap = BinaryHeap::new();
55    let mut mst_edges = Vec::new();
56    let mut total_weight = 0.0;
57
58    visited.insert(start_idx);
59
60    // Add initial edges
61    add_edges(view, start_idx, &mut heap, &visited);
62
63    while let Some(EdgeState {
64        weight,
65        source,
66        target,
67    }) = heap.pop()
68    {
69        if visited.contains(&target) {
70            continue;
71        }
72
73        visited.insert(target);
74        mst_edges.push((
75            view.index_to_node[source],
76            view.index_to_node[target],
77            weight,
78        ));
79        total_weight += weight;
80
81        add_edges(view, target, &mut heap, &visited);
82    }
83
84    MSTResult {
85        total_weight,
86        edges: mst_edges,
87    }
88}
89
90fn add_edges(
91    view: &GraphView,
92    u: usize,
93    heap: &mut BinaryHeap<EdgeState>,
94    visited: &HashSet<usize>,
95) {
96    // Check outgoing edges
97    let u_out = view.successors(u);
98    for (i, &v) in u_out.iter().enumerate() {
99        if !visited.contains(&v) {
100            let weight = view.weights(u).map(|w| w[i]).unwrap_or(1.0);
101            heap.push(EdgeState {
102                weight,
103                source: u,
104                target: v,
105            });
106        }
107    }
108
109    // Check incoming edges (treat as undirected)
110    let u_in = view.predecessors(u);
111    for &_v in u_in.iter() {
112        let v = _v; // explicit copy
113        if !visited.contains(&v) {
114            // Need to find weight in incoming list?
115            // GraphView structure: incoming[u] contains v implies edge v->u exists.
116
117            let v_out = view.successors(v);
118            if let Some(idx) = v_out.iter().position(|&x| x == u) {
119                let weight = view.weights(v).map(|w| w[idx]).unwrap_or(1.0);
120                heap.push(EdgeState {
121                    weight,
122                    source: u,
123                    target: v,
124                }); // "source" here is just the connection point in MST
125            }
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Runs `prim_mst` on a graph whose dense index `i` maps to node id `i + 1`.
    /// The graph is built from directed `(source_index, target_index, weight)` edges.
    /// Each edge is stored once: in `outgoing[source]` and `incoming[target]`, with
    /// its weight at the matching position of `weights[source]`.
    fn mst_of(node_count: usize, edges: &[(usize, usize, f64)]) -> MSTResult {
        let index_to_node: Vec<NodeId> = (1..=node_count as NodeId).collect();
        let node_to_index: HashMap<_, _> = index_to_node
            .iter()
            .enumerate()
            .map(|(idx, &id)| (id, idx))
            .collect();

        let mut outgoing = vec![Vec::new(); node_count];
        let mut incoming = vec![Vec::new(); node_count];
        let mut weights = vec![Vec::new(); node_count];
        for &(source, target, weight) in edges {
            outgoing[source].push(target);
            incoming[target].push(source);
            weights[source].push(weight);
        }

        let view = GraphView::from_adjacency_list(
            node_count,
            index_to_node,
            node_to_index,
            outgoing,
            incoming,
            Some(weights),
        );
        prim_mst(&view)
    }

    #[test]
    fn test_prim_mst() {
        // Triangle stored in both directions: 1-2 (1), 2-3 (2), 1-3 (10).
        // The MST keeps 1-2 and 2-3, for a total weight of 3.
        let result = mst_of(
            3,
            &[
                (0, 1, 1.0),
                (1, 0, 1.0),
                (1, 2, 2.0),
                (2, 1, 2.0),
                (0, 2, 10.0),
                (2, 0, 10.0),
            ],
        );
        assert_eq!(result.total_weight, 3.0);
        assert_eq!(result.edges, vec![(1, 2, 1.0), (2, 3, 2.0)]);
    }

    #[test]
    fn test_prim_mst_follows_incoming_edges() {
        // Each edge is stored in one direction only: 3->1 (1), 1->2 (5), 2->3 (2).
        // From node 1 the cheapest link is the incoming edge 3->1.
        let result = mst_of(3, &[(2, 0, 1.0), (0, 1, 5.0), (1, 2, 2.0)]);
        assert_eq!(result.total_weight, 3.0);
        assert_eq!(result.edges, vec![(1, 3, 1.0), (3, 2, 2.0)]);
    }

    #[test]
    fn test_prim_mst_disconnected_spans_first_component() {
        // Edge 1-2 (4) plus an isolated node 3: only node 1's component is spanned.
        let result = mst_of(3, &[(0, 1, 4.0), (1, 0, 4.0)]);
        assert_eq!(result.total_weight, 4.0);
        assert_eq!(result.edges, vec![(1, 2, 4.0)]);
    }

    #[test]
    fn test_prim_mst_empty_graph() {
        let result = mst_of(0, &[]);
        assert_eq!(result.total_weight, 0.0);
        assert!(result.edges.is_empty());
    }
}