react_compiler_utils/disjoint_set.rs
1// Copyright (c) Meta Platforms, Inc. and affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6//! A generic disjoint-set (union-find) data structure.
7//!
8//! Ported from TypeScript `src/Utils/DisjointSet.ts`.
9
10use std::collections::HashSet;
11use std::hash::Hash;
12
13use indexmap::IndexMap;
14
15/// A Union-Find data structure for grouping items into disjoint sets.
16///
17/// Corresponds to TS `DisjointSet<T>` in `src/Utils/DisjointSet.ts`.
18/// Uses `IndexMap` to preserve insertion order (matching TS `Map` behavior).
19pub struct DisjointSet<K: Copy + Eq + Hash> {
20 entries: IndexMap<K, K>,
21}
22
23impl<K: Copy + Eq + Hash> DisjointSet<K> {
24 pub fn new() -> Self {
25 DisjointSet {
26 entries: IndexMap::new(),
27 }
28 }
29
30 /// Updates the graph to reflect that the given items form a set,
31 /// linking any previous sets that the items were part of into a single set.
32 ///
33 /// Corresponds to TS `union(items: Array<T>): void`.
34 pub fn union(&mut self, items: &[K]) {
35 if items.is_empty() {
36 return;
37 }
38 let root = self.find(items[0]);
39 for &item in &items[1..] {
40 let item_root = self.find(item);
41 if item_root != root {
42 self.entries.insert(item_root, root);
43 }
44 }
45 }
46
47 /// Find the root of the set containing `item`, with path compression.
48 /// If `item` is not in the set, it is inserted as its own root.
49 ///
50 /// Note: callers that need null/None semantics for missing items should
51 /// use `find_opt()` instead.
52 pub fn find(&mut self, item: K) -> K {
53 let parent = match self.entries.get(&item) {
54 Some(&p) => p,
55 None => {
56 self.entries.insert(item, item);
57 return item;
58 }
59 };
60 if parent == item {
61 return item;
62 }
63 let root = self.find(parent);
64 self.entries.insert(item, root);
65 root
66 }
67
68 /// Find the root of the set containing `item`, returning `None` if the item
69 /// was never added to the set.
70 ///
71 /// Corresponds to TS `find(item: T): T | null`.
72 pub fn find_opt(&mut self, item: K) -> Option<K> {
73 if !self.entries.contains_key(&item) {
74 return None;
75 }
76 Some(self.find(item))
77 }
78
79 /// Returns true if the item is present in the set.
80 ///
81 /// Corresponds to TS `has(item: T): boolean`.
82 pub fn has(&self, item: K) -> bool {
83 self.entries.contains_key(&item)
84 }
85
86 /// Forces the set into canonical form (all items pointing directly to their
87 /// root) and returns a map of items to their roots.
88 ///
89 /// Corresponds to TS `canonicalize(): Map<T, T>`.
90 pub fn canonicalize(&mut self) -> IndexMap<K, K> {
91 let mut result = IndexMap::new();
92 let keys: Vec<K> = self.entries.keys().copied().collect();
93 for item in keys {
94 let root = self.find(item);
95 result.insert(item, root);
96 }
97 result
98 }
99
100 /// Calls the provided callback once for each item in the disjoint set,
101 /// passing the item and the group root to which it belongs.
102 ///
103 /// Corresponds to TS `forEach(fn: (item: T, group: T) => void): void`.
104 pub fn for_each<F>(&mut self, mut f: F)
105 where
106 F: FnMut(K, K),
107 {
108 let keys: Vec<K> = self.entries.keys().copied().collect();
109 for item in keys {
110 let group = self.find(item);
111 f(item, group);
112 }
113 }
114
115 /// Groups all items by their root and returns the groups as a list of sets.
116 ///
117 /// Corresponds to TS `buildSets(): Array<Set<T>>`.
118 pub fn build_sets(&mut self) -> Vec<HashSet<K>> {
119 let mut group_to_index: IndexMap<K, usize> = IndexMap::new();
120 let mut sets: Vec<HashSet<K>> = Vec::new();
121 let keys: Vec<K> = self.entries.keys().copied().collect();
122 for item in keys {
123 let group = self.find(item);
124 let idx = match group_to_index.get(&group) {
125 Some(&idx) => idx,
126 None => {
127 let idx = sets.len();
128 group_to_index.insert(group, idx);
129 sets.push(HashSet::new());
130 idx
131 }
132 };
133 sets[idx].insert(item);
134 }
135 sets
136 }
137
138 /// Returns the number of items in the set.
139 ///
140 /// Corresponds to TS `get size(): number`.
141 pub fn len(&self) -> usize {
142 self.entries.len()
143 }
144
145 pub fn is_empty(&self) -> bool {
146 self.entries.is_empty()
147 }
148}