/// Category of a node stored in `UnionFind`.
///
/// Only `File` nodes are reported by `UnionFind::component_count` and
/// `UnionFind::get_components`. `Function` nodes can take part in merges, but
/// they are never counted on their own.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NodeKind {
    /// A file-level node; its `file_index` is what component queries report.
    File,
    /// A function-level node; only counted through the file it is merged with.
    Function,
}
26
27#[derive(Debug, Clone)]
29pub struct NodeDescriptor {
30 pub file_index: u16,
31 pub name: String,
32 pub kind: NodeKind,
33}
34
35#[derive(Debug)]
40pub struct UnionFind {
41 nodes: Vec<NodeDescriptor>,
42 parent: Vec<u32>,
43 rank: Vec<u8>,
44}
45
46impl UnionFind {
47 pub fn new() -> Self {
48 Self {
49 nodes: Vec::new(),
50 parent: Vec::new(),
51 rank: Vec::new(),
52 }
53 }
54
55 pub fn add_node(&mut self, file_index: u16, name: &str, kind: NodeKind) -> u32 {
57 for (i, n) in self.nodes.iter().enumerate() {
59 if n.file_index == file_index && n.name == name {
60 return i as u32;
61 }
62 }
63 let id = self.nodes.len() as u32;
64 self.nodes.push(NodeDescriptor {
65 file_index,
66 name: name.to_string(),
67 kind,
68 });
69 self.parent.push(id); self.rank.push(0);
71 id
72 }
73
74 pub fn find(&mut self, x: u32) -> u32 {
79 let mut current = x;
80 while self.parent[current as usize] != current {
81 let parent = self.parent[current as usize];
83 self.parent[current as usize] = self.parent[parent as usize];
84 current = self.parent[current as usize];
85 }
86 current
87 }
88
89 pub fn merge(&mut self, a: u32, b: u32) {
94 let ra = self.find(a);
95 let rb = self.find(b);
96 if ra == rb {
97 return;
98 }
99 if self.rank[ra as usize] < self.rank[rb as usize] {
100 self.parent[ra as usize] = rb;
101 } else if self.rank[ra as usize] > self.rank[rb as usize] {
102 self.parent[rb as usize] = ra;
103 } else {
104 self.parent[rb as usize] = ra;
105 self.rank[ra as usize] += 1;
106 }
107 }
108
109 pub fn component_count(&mut self) -> usize {
111 let file_indices: Vec<usize> = self
113 .nodes
114 .iter()
115 .enumerate()
116 .filter(|(_, n)| n.kind == NodeKind::File)
117 .map(|(i, _)| i)
118 .collect();
119
120 let mut roots = std::collections::HashSet::new();
121 for i in file_indices {
122 let root = self.find(i as u32);
123 roots.insert(root);
124 }
125 roots.len()
126 }
127
128 pub fn get_components(&mut self) -> Vec<Vec<u16>> {
130 let file_nodes: Vec<(usize, u16)> = self
132 .nodes
133 .iter()
134 .enumerate()
135 .filter(|(_, n)| n.kind == NodeKind::File)
136 .map(|(i, n)| (i, n.file_index))
137 .collect();
138
139 let mut comp_map: std::collections::HashMap<u32, Vec<u16>> =
140 std::collections::HashMap::new();
141
142 for (i, file_index) in file_nodes {
143 let root = self.find(i as u32);
144 comp_map.entry(root).or_default().push(file_index);
145 }
146
147 comp_map.into_values().collect()
148 }
149
150 pub fn get_node(&self, id: u32) -> Option<&NodeDescriptor> {
152 self.nodes.get(id as usize)
153 }
154
155 pub fn len(&self) -> usize {
157 self.nodes.len()
158 }
159
160 pub fn is_empty(&self) -> bool {
162 self.nodes.is_empty()
163 }
164}
165
166impl Default for UnionFind {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
/// Unit tests for `UnionFind`, including checks on its private parallel vectors.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_graph() {
        let mut uf = UnionFind::new();
        assert_eq!(uf.component_count(), 0);
        assert!(uf.get_components().is_empty());
        assert!(uf.is_empty());
        assert_eq!(uf.len(), 0);
    }

    #[test]
    fn single_node() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "main.rs", NodeKind::File);
        assert_eq!(uf.find(a), a);
        assert_eq!(uf.component_count(), 1);
        assert!(!uf.is_empty());
    }

    #[test]
    fn deduplication() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "main.rs", NodeKind::File);
        let b = uf.add_node(0, "main.rs", NodeKind::File);
        assert_eq!(a, b);
        assert_eq!(uf.len(), 1);
    }

    #[test]
    fn same_name_in_different_files_is_distinct() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "mod.rs", NodeKind::File);
        let b = uf.add_node(1, "mod.rs", NodeKind::File);
        assert_ne!(a, b);
        assert_eq!(uf.len(), 2);
        assert_eq!(uf.component_count(), 2);
    }

    #[test]
    fn get_node_returns_descriptor() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(3, "lib.rs", NodeKind::File);
        let node = uf.get_node(a).expect("id returned by add_node must resolve");
        assert_eq!(node.file_index, 3);
        assert_eq!(node.name, "lib.rs");
        assert_eq!(node.kind, NodeKind::File);
        assert!(uf.get_node(a + 1).is_none());
    }

    #[test]
    fn merge_reduces_components() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "a.rs", NodeKind::File);
        let b = uf.add_node(1, "b.rs", NodeKind::File);
        assert_eq!(uf.component_count(), 2);

        uf.merge(a, b);
        assert_eq!(uf.component_count(), 1);
        assert_eq!(uf.find(a), uf.find(b));
    }

    #[test]
    fn find_is_idempotent() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "a.rs", NodeKind::File);
        let b = uf.add_node(1, "b.rs", NodeKind::File);
        uf.merge(a, b);

        let root1 = uf.find(a);
        let root2 = uf.find(root1);
        assert_eq!(root1, root2);
    }

    #[test]
    fn merge_is_symmetric() {
        let mut uf1 = UnionFind::new();
        let a1 = uf1.add_node(0, "a.rs", NodeKind::File);
        let b1 = uf1.add_node(1, "b.rs", NodeKind::File);
        uf1.merge(a1, b1);

        let mut uf2 = UnionFind::new();
        let a2 = uf2.add_node(0, "a.rs", NodeKind::File);
        let b2 = uf2.add_node(1, "b.rs", NodeKind::File);
        uf2.merge(b2, a2);

        // The chosen root may differ, but the partition must be the same.
        assert_eq!(uf1.component_count(), uf2.component_count());
        assert_eq!(uf1.find(a1), uf1.find(b1));
        assert_eq!(uf2.find(a2), uf2.find(b2));
    }

    #[test]
    fn merge_is_transitive() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "a.rs", NodeKind::File);
        let b = uf.add_node(1, "b.rs", NodeKind::File);
        let c = uf.add_node(2, "c.rs", NodeKind::File);

        uf.merge(a, b);
        uf.merge(b, c);

        assert_eq!(uf.find(a), uf.find(c));
        assert_eq!(uf.component_count(), 1);
    }

    #[test]
    fn function_nodes_dont_count_as_components() {
        let mut uf = UnionFind::new();
        let file_a = uf.add_node(0, "a.rs", NodeKind::File);
        let fn_a = uf.add_node(0, "foo", NodeKind::Function);
        let file_b = uf.add_node(1, "b.rs", NodeKind::File);

        // The unmerged function node is its own set but is not counted.
        assert_eq!(uf.component_count(), 2);

        // Attaching the function to its file leaves the file count unchanged.
        uf.merge(file_a, fn_a);
        assert_eq!(uf.component_count(), 2);

        // Joining the two files collapses everything into one component.
        uf.merge(file_a, file_b);
        assert_eq!(uf.component_count(), 1);
    }

    #[test]
    fn get_components_returns_grouped_indices() {
        let mut uf = UnionFind::new();
        let a = uf.add_node(0, "a.rs", NodeKind::File);
        let b = uf.add_node(1, "b.rs", NodeKind::File);
        uf.add_node(2, "c.rs", NodeKind::File);

        uf.merge(a, b);

        // Look groups up by member so the test does not depend on component order.
        let components = uf.get_components();
        assert_eq!(components.len(), 2);
        let group_with_a = components.iter().find(|g| g.contains(&0)).unwrap();
        assert_eq!(group_with_a.len(), 2);
        assert!(group_with_a.contains(&1));
        let group_with_c = components.iter().find(|g| g.contains(&2)).unwrap();
        assert_eq!(group_with_c, &vec![2]);
    }

    #[test]
    fn get_components_skips_function_nodes() {
        let mut uf = UnionFind::new();
        let file = uf.add_node(0, "a.rs", NodeKind::File);
        let func = uf.add_node(0, "foo", NodeKind::Function);
        uf.add_node(0, "bar", NodeKind::Function);
        uf.merge(file, func);
        assert_eq!(uf.get_components(), vec![vec![0u16]]);
    }

    /// `nodes`, `parent` and `rank` must stay the same length through adds and merges.
    #[test]
    fn structural_invariant_maintained() {
        let mut uf = UnionFind::new();
        for i in 0..100 {
            uf.add_node(i, &format!("file_{i}"), NodeKind::File);
            assert_eq!(uf.nodes.len(), uf.parent.len());
            assert_eq!(uf.nodes.len(), uf.rank.len());
        }

        for i in 0..99 {
            uf.merge(i, i + 1);
            assert_eq!(uf.nodes.len(), uf.parent.len());
            assert_eq!(uf.nodes.len(), uf.rank.len());
        }
    }

    /// Every parent pointer must remain a valid node id after a mix of merges.
    #[test]
    fn parent_bounds_invariant() {
        let mut uf = UnionFind::new();
        for i in 0..50u16 {
            uf.add_node(i, &format!("f{i}"), NodeKind::File);
        }
        // Pair up neighbours, then join pairs of pairs.
        for i in (0..49).step_by(2) {
            uf.merge(i, i + 1);
        }
        for i in (0..48).step_by(4) {
            uf.merge(i, i + 2);
        }

        for (i, &p) in uf.parent.iter().enumerate() {
            assert!(
                (p as usize) < uf.parent.len(),
                "parent[{i}] = {p} >= len {}",
                uf.parent.len()
            );
        }
    }
}