Skip to main content

react_compiler_utils/
disjoint_set.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! A generic disjoint-set (union-find) data structure.
//!
//! Ported from TypeScript `src/Utils/DisjointSet.ts`.
9
10use std::collections::HashSet;
11use std::hash::Hash;
12
13use indexmap::IndexMap;
14
15/// A Union-Find data structure for grouping items into disjoint sets.
16///
17/// Corresponds to TS `DisjointSet<T>` in `src/Utils/DisjointSet.ts`.
18/// Uses `IndexMap` to preserve insertion order (matching TS `Map` behavior).
19pub struct DisjointSet<K: Copy + Eq + Hash> {
20    entries: IndexMap<K, K>,
21}
22
23impl<K: Copy + Eq + Hash> DisjointSet<K> {
24    pub fn new() -> Self {
25        DisjointSet {
26            entries: IndexMap::new(),
27        }
28    }
29
30    /// Updates the graph to reflect that the given items form a set,
31    /// linking any previous sets that the items were part of into a single set.
32    ///
33    /// Corresponds to TS `union(items: Array<T>): void`.
34    pub fn union(&mut self, items: &[K]) {
35        if items.is_empty() {
36            return;
37        }
38        let root = self.find(items[0]);
39        for &item in &items[1..] {
40            let item_root = self.find(item);
41            if item_root != root {
42                self.entries.insert(item_root, root);
43            }
44        }
45    }
46
47    /// Find the root of the set containing `item`, with path compression.
48    /// If `item` is not in the set, it is inserted as its own root.
49    ///
50    /// Note: callers that need null/None semantics for missing items should
51    /// use `find_opt()` instead.
52    pub fn find(&mut self, item: K) -> K {
53        let parent = match self.entries.get(&item) {
54            Some(&p) => p,
55            None => {
56                self.entries.insert(item, item);
57                return item;
58            }
59        };
60        if parent == item {
61            return item;
62        }
63        let root = self.find(parent);
64        self.entries.insert(item, root);
65        root
66    }
67
68    /// Find the root of the set containing `item`, returning `None` if the item
69    /// was never added to the set.
70    ///
71    /// Corresponds to TS `find(item: T): T | null`.
72    pub fn find_opt(&mut self, item: K) -> Option<K> {
73        if !self.entries.contains_key(&item) {
74            return None;
75        }
76        Some(self.find(item))
77    }
78
79    /// Returns true if the item is present in the set.
80    ///
81    /// Corresponds to TS `has(item: T): boolean`.
82    pub fn has(&self, item: K) -> bool {
83        self.entries.contains_key(&item)
84    }
85
86    /// Forces the set into canonical form (all items pointing directly to their
87    /// root) and returns a map of items to their roots.
88    ///
89    /// Corresponds to TS `canonicalize(): Map<T, T>`.
90    pub fn canonicalize(&mut self) -> IndexMap<K, K> {
91        let mut result = IndexMap::new();
92        let keys: Vec<K> = self.entries.keys().copied().collect();
93        for item in keys {
94            let root = self.find(item);
95            result.insert(item, root);
96        }
97        result
98    }
99
100    /// Calls the provided callback once for each item in the disjoint set,
101    /// passing the item and the group root to which it belongs.
102    ///
103    /// Corresponds to TS `forEach(fn: (item: T, group: T) => void): void`.
104    pub fn for_each<F>(&mut self, mut f: F)
105    where
106        F: FnMut(K, K),
107    {
108        let keys: Vec<K> = self.entries.keys().copied().collect();
109        for item in keys {
110            let group = self.find(item);
111            f(item, group);
112        }
113    }
114
115    /// Groups all items by their root and returns the groups as a list of sets.
116    ///
117    /// Corresponds to TS `buildSets(): Array<Set<T>>`.
118    pub fn build_sets(&mut self) -> Vec<HashSet<K>> {
119        let mut group_to_index: IndexMap<K, usize> = IndexMap::new();
120        let mut sets: Vec<HashSet<K>> = Vec::new();
121        let keys: Vec<K> = self.entries.keys().copied().collect();
122        for item in keys {
123            let group = self.find(item);
124            let idx = match group_to_index.get(&group) {
125                Some(&idx) => idx,
126                None => {
127                    let idx = sets.len();
128                    group_to_index.insert(group, idx);
129                    sets.push(HashSet::new());
130                    idx
131                }
132            };
133            sets[idx].insert(item);
134        }
135        sets
136    }
137
138    /// Returns the number of items in the set.
139    ///
140    /// Corresponds to TS `get size(): number`.
141    pub fn len(&self) -> usize {
142        self.entries.len()
143    }
144
145    pub fn is_empty(&self) -> bool {
146        self.entries.is_empty()
147    }
148}