Skip to main content

libverify_core/
union_find.rs

1//! Union-Find (Disjoint Set Union) data structure for call graph connectivity.
2//!
3//! # Invariants (Creusot)
4//!
5//! ```text
6//! #[invariant(self.parent.len() == self.rank.len())]
7//! #[invariant(forall(|i| i < self.parent.len() ==> self.parent[i] < self.parent.len()))]
8//! ```
9//!
10//! # Properties
11//!
12//! - `find` is idempotent: `find(find(x)) == find(x)`
13//! - `merge` establishes equivalence: after `merge(x, y)`, `find(x) == find(y)`
14//! - `component_count` returns the number of distinct roots among file-kind nodes
15
16// Creusot struct-level #[invariant] requires Creusot nightly compiler.
17// The structural invariants are documented and tested but not yet machine-proved.
18// See doc comments on UnionFind for the intended invariants.
19
/// The kind of node in the call graph.
///
/// Only `File` nodes are counted by `UnionFind::component_count` and reported by
/// `UnionFind::get_components`; `Function` nodes matter only through the file
/// nodes they connect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeKind {
    /// A file node (e.g. `"main.rs"`).
    File,
    /// A function node, keyed by the `file_index` of the file it belongs to.
    Function,
}
26
27/// Descriptor for a node in the call graph.
28#[derive(Debug, Clone)]
29pub struct NodeDescriptor {
30    pub file_index: u16,
31    pub name: String,
32    pub kind: NodeKind,
33}
34
35/// Union-Find data structure for tracking connected components in a call graph.
36///
37/// Invariant: `parent.len() == rank.len() == nodes.len()`
38/// Invariant: `parent[i] < parent.len()` for all valid `i`
39#[derive(Debug)]
40pub struct UnionFind {
41    nodes: Vec<NodeDescriptor>,
42    parent: Vec<u32>,
43    rank: Vec<u8>,
44}
45
46impl UnionFind {
47    pub fn new() -> Self {
48        Self {
49            nodes: Vec::new(),
50            parent: Vec::new(),
51            rank: Vec::new(),
52        }
53    }
54
55    /// Add a node. Returns existing ID if `(file_index, name)` already present.
56    pub fn add_node(&mut self, file_index: u16, name: &str, kind: NodeKind) -> u32 {
57        // Deduplicate by (file_index, name)
58        for (i, n) in self.nodes.iter().enumerate() {
59            if n.file_index == file_index && n.name == name {
60                return i as u32;
61            }
62        }
63        let id = self.nodes.len() as u32;
64        self.nodes.push(NodeDescriptor {
65            file_index,
66            name: name.to_string(),
67            kind,
68        });
69        self.parent.push(id); // self-parent
70        self.rank.push(0);
71        id
72    }
73
74    /// Find root with path compression.
75    ///
76    /// # Postcondition
77    /// `find(find(x)) == find(x)` (idempotent)
78    pub fn find(&mut self, x: u32) -> u32 {
79        let mut current = x;
80        while self.parent[current as usize] != current {
81            // Path splitting (each node points to its grandparent)
82            let parent = self.parent[current as usize];
83            self.parent[current as usize] = self.parent[parent as usize];
84            current = self.parent[current as usize];
85        }
86        current
87    }
88
89    /// Union by rank.
90    ///
91    /// # Postcondition
92    /// After `merge(a, b)`: `find(a) == find(b)`
93    pub fn merge(&mut self, a: u32, b: u32) {
94        let ra = self.find(a);
95        let rb = self.find(b);
96        if ra == rb {
97            return;
98        }
99        if self.rank[ra as usize] < self.rank[rb as usize] {
100            self.parent[ra as usize] = rb;
101        } else if self.rank[ra as usize] > self.rank[rb as usize] {
102            self.parent[rb as usize] = ra;
103        } else {
104            self.parent[rb as usize] = ra;
105            self.rank[ra as usize] += 1;
106        }
107    }
108
109    /// Count distinct connected components among file-kind nodes only.
110    pub fn component_count(&mut self) -> usize {
111        // Collect file node indices first to avoid borrow conflict
112        let file_indices: Vec<usize> = self
113            .nodes
114            .iter()
115            .enumerate()
116            .filter(|(_, n)| n.kind == NodeKind::File)
117            .map(|(i, _)| i)
118            .collect();
119
120        let mut roots = std::collections::HashSet::new();
121        for i in file_indices {
122            let root = self.find(i as u32);
123            roots.insert(root);
124        }
125        roots.len()
126    }
127
128    /// Return file indices grouped by component.
129    pub fn get_components(&mut self) -> Vec<Vec<u16>> {
130        // Collect file node info first to avoid borrow conflict
131        let file_nodes: Vec<(usize, u16)> = self
132            .nodes
133            .iter()
134            .enumerate()
135            .filter(|(_, n)| n.kind == NodeKind::File)
136            .map(|(i, n)| (i, n.file_index))
137            .collect();
138
139        let mut comp_map: std::collections::HashMap<u32, Vec<u16>> =
140            std::collections::HashMap::new();
141
142        for (i, file_index) in file_nodes {
143            let root = self.find(i as u32);
144            comp_map.entry(root).or_default().push(file_index);
145        }
146
147        comp_map.into_values().collect()
148    }
149
150    /// Get the node descriptor for a given ID.
151    pub fn get_node(&self, id: u32) -> Option<&NodeDescriptor> {
152        self.nodes.get(id as usize)
153    }
154
155    /// Number of nodes in the graph.
156    pub fn len(&self) -> usize {
157        self.nodes.len()
158    }
159
160    /// Returns true if the graph has no nodes.
161    pub fn is_empty(&self) -> bool {
162        self.nodes.is_empty()
163    }
164}
165
166impl Default for UnionFind {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_graph() {
        let mut uf = UnionFind::new();
        assert_eq!(uf.component_count(), 0);
        assert!(uf.is_empty());
        assert!(uf.get_components().is_empty());
    }

    #[test]
    fn single_node() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "main.rs", NodeKind::File);
        assert_eq!(uf.find(a), a);
        assert_eq!(uf.component_count(), 1);
    }

    #[test]
    fn deduplication() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "main.rs", NodeKind::File);
        let b = uf.add_node(0, "main.rs", NodeKind::File);
        assert_eq!(a, b);
        assert_eq!(uf.len(), 1);
    }

    /// The dedup key is `(file_index, name)`; a repeated add keeps the first kind.
    #[test]
    fn deduplication_key_is_file_and_name() {
        let mut uf = UnionFind::new();
        let in_file_0 = uf.add_node(0, "foo", NodeKind::Function);
        let in_file_1 = uf.add_node(1, "foo", NodeKind::Function);
        assert_ne!(in_file_0, in_file_1);

        let again = uf.add_node(0, "foo", NodeKind::File);
        assert_eq!(again, in_file_0);
        assert_eq!(uf.len(), 2);
        assert_eq!(
            uf.get_node(in_file_0).map(|n| n.kind),
            Some(NodeKind::Function)
        );
    }

    #[test]
    fn get_node_out_of_range_is_none() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "a.rs", NodeKind::File);
        assert_eq!(uf.get_node(a).map(|n| n.name.as_str()), Some("a.rs"));
        assert!(uf.get_node(a + 1).is_none());
    }

    #[test]
    fn merge_reduces_components() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "a.rs", NodeKind::File);
        let b = uf.add_node(1, "b.rs", NodeKind::File);
        assert_eq!(uf.component_count(), 2);

        uf.merge(a, b);
        assert_eq!(uf.component_count(), 1);
        assert_eq!(uf.find(a), uf.find(b));
    }

    #[test]
    fn find_is_idempotent() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "a.rs", NodeKind::File);
        let b = uf.add_node(1, "b.rs", NodeKind::File);
        uf.merge(a, b);

        let root1 = uf.find(a);
        let root2 = uf.find(root1);
        assert_eq!(root1, root2);
    }

    #[test]
    fn merge_is_symmetric() {
        let mut uf1 = UnionFind::new();
        let a1 = uf1.add_node(0, "a.rs", NodeKind::File);
        let b1 = uf1.add_node(1, "b.rs", NodeKind::File);
        uf1.merge(a1, b1);

        let mut uf2 = UnionFind::new();
        let a2 = uf2.add_node(0, "a.rs", NodeKind::File);
        let b2 = uf2.add_node(1, "b.rs", NodeKind::File);
        uf2.merge(b2, a2);

        // Both should have same component count
        assert_eq!(uf1.component_count(), uf2.component_count());
        assert_eq!(uf1.find(a1), uf1.find(b1));
        assert_eq!(uf2.find(a2), uf2.find(b2));
    }

    #[test]
    fn merge_is_transitive() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "a.rs", NodeKind::File);
        let b = uf.add_node(1, "b.rs", NodeKind::File);
        let c = uf.add_node(2, "c.rs", NodeKind::File);

        uf.merge(a, b);
        uf.merge(b, c);

        assert_eq!(uf.find(a), uf.find(c));
        assert_eq!(uf.component_count(), 1);
    }

    #[test]
    fn function_nodes_dont_count_as_components() {
        let mut uf = UnionFind::new();
        let file_a = uf.add_node(0, "a.rs", NodeKind::File);
        let fn_a = uf.add_node(0, "foo", NodeKind::Function);
        let file_b = uf.add_node(1, "b.rs", NodeKind::File);

        // 2 file components (function node doesn't count)
        assert_eq!(uf.component_count(), 2);

        // Merging function with its file doesn't change file component count
        uf.merge(file_a, fn_a);
        assert_eq!(uf.component_count(), 2);

        // But merging files does
        uf.merge(file_a, file_b);
        assert_eq!(uf.component_count(), 1);
    }

    #[test]
    fn get_components_returns_grouped_indices() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "a.rs", NodeKind::File);
        let b = uf.add_node(1, "b.rs", NodeKind::File);
        uf.add_node(2, "c.rs", NodeKind::File);

        uf.merge(a, b);

        // Sort groups by their first file index so the assertion does not depend
        // on component order; within a group, indices follow node insertion order.
        let mut components = uf.get_components();
        components.sort_by_key(|group| group[0]);
        assert_eq!(components, vec![vec![0, 1], vec![2]]);
    }

    /// Files connected only through a function node form one component, and the
    /// function node itself is not reported.
    #[test]
    fn get_components_connects_files_through_functions() {
        let mut uf = UnionFind::new();
        let file_a = uf.add_node(0, "a.rs", NodeKind::File);
        let callee = uf.add_node(0, "foo", NodeKind::Function);
        let file_b = uf.add_node(1, "b.rs", NodeKind::File);

        uf.merge(file_a, callee);
        uf.merge(callee, file_b);

        assert_eq!(uf.component_count(), 1);
        assert_eq!(uf.get_components(), vec![vec![0, 1]]);
    }

    /// Invariant: parent.len() == rank.len() == nodes.len()
    #[test]
    fn structural_invariant_maintained() {
        let mut uf = UnionFind::new();
        for i in 0..100 {
            uf.add_node(i, &format!("file_{i}"), NodeKind::File);
            assert_eq!(uf.nodes.len(), uf.parent.len());
            assert_eq!(uf.nodes.len(), uf.rank.len());
        }

        // After merges, invariant still holds
        for i in 0..99 {
            uf.merge(i, i + 1);
            assert_eq!(uf.nodes.len(), uf.parent.len());
            assert_eq!(uf.nodes.len(), uf.rank.len());
        }

        // The chain of merges joined everything into a single component.
        assert_eq!(uf.component_count(), 1);
    }

    /// Invariant: parent[i] < parent.len() for all i
    #[test]
    fn parent_bounds_invariant() {
        let mut uf = UnionFind::new();
        for i in 0..50u16 {
            uf.add_node(i, &format!("f{i}"), NodeKind::File);
        }
        // Random-ish merges
        for i in (0..49).step_by(2) {
            uf.merge(i, i + 1);
        }
        for i in (0..48).step_by(4) {
            uf.merge(i, i + 2);
        }

        for (i, &p) in uf.parent.iter().enumerate() {
            assert!(
                (p as usize) < uf.parent.len(),
                "parent[{i}] = {p} >= len {}",
                uf.parent.len()
            );
        }
    }
}