jj_lib/
union_find.rs

1// Copyright 2024 The Jujutsu Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! This module implements a [`UnionFind<T>`] type which can be used to
//! efficiently calculate disjoint sets of any `Copy + Eq + Hash` data type.
17
18use std::collections::HashMap;
19use std::hash::Hash;
20
/// Map entry describing the set that a tracked item currently belongs to.
#[derive(Clone, Copy, Debug)]
struct Node<T> {
    /// Item this entry points at on the way to the representative of its set.
    /// An entry whose `root` equals its own map key is a representative.
    root: T,
    /// Number of items in the set. Only the representative's own entry is kept
    /// current; entries of other items may hold an outdated value.
    size: u32,
}

/// Implementation of the union-find algorithm:
/// <https://en.wikipedia.org/wiki/Disjoint-set_data_structure>
///
/// Joins disjoint sets by size to amortize cost.
#[derive(Clone, Debug)]
pub struct UnionFind<T> {
    /// One entry per item seen so far, keyed by the item itself.
    roots: HashMap<T, Node<T>>,
}
35
36impl<T> Default for UnionFind<T>
37where
38    T: Copy + Eq + Hash,
39{
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl<T> UnionFind<T>
46where
47    T: Copy + Eq + Hash,
48{
49    /// Creates a new empty UnionFind data structure.
50    pub fn new() -> Self {
51        Self {
52            roots: HashMap::new(),
53        }
54    }
55
56    /// Returns the root identifying the union this item is a part of.
57    pub fn find(&mut self, item: T) -> T {
58        self.find_node(item).root
59    }
60
61    fn find_node(&mut self, item: T) -> Node<T> {
62        match self.roots.get(&item) {
63            Some(node) => {
64                if node.root != item {
65                    let new_root = self.find_node(node.root);
66                    self.roots.insert(item, new_root);
67                    new_root
68                } else {
69                    *node
70                }
71            }
72            None => {
73                let node = Node::<T> {
74                    root: item,
75                    size: 1,
76                };
77                self.roots.insert(item, node);
78                node
79            }
80        }
81    }
82
83    /// Unions the disjoint sets connected to `a` and `b`.
84    pub fn union(&mut self, a: T, b: T) {
85        let a = self.find_node(a);
86        let b = self.find_node(b);
87        if a.root == b.root {
88            return;
89        }
90
91        let new_node = Node::<T> {
92            root: if a.size < b.size { b.root } else { a.root },
93            size: a.size + b.size,
94        };
95        self.roots.insert(a.root, new_node);
96        self.roots.insert(b.root, new_node);
97    }
98}
99
#[cfg(test)]
mod tests {
    use itertools::Itertools as _;

    use super::*;

    /// Singletons, two pairwise unions, then a merge of everything.
    #[test]
    fn test_basic() {
        let mut sets = UnionFind::<i32>::new();

        // An item that was never unioned is its own root.
        for item in 1..=3 {
            assert_eq!(sets.find(item), item);
        }

        // Build {1, 2} and {3, 4}; item 4 is registered implicitly here.
        sets.union(1, 2);
        sets.union(3, 4);
        assert_eq!(sets.find(1), sets.find(2));
        assert_eq!(sets.find(3), sets.find(4));
        assert_ne!(sets.find(1), sets.find(3));

        // Joining the two pairs leaves one set with one shared root.
        sets.union(1, 3);
        let roots = [1, 2, 3, 4].map(|item| sets.find(item));
        assert!(roots.iter().all_equal());
    }

    /// The root of the larger set survives a union, whatever the argument
    /// order.
    #[test]
    fn test_union_by_size() {
        let mut sets = UnionFind::<i32>::new();

        // Build {1, 2, 3} and {4, 5}.
        sets.union(1, 2);
        sets.union(2, 3);
        sets.union(4, 5);
        let root_of_three = sets.find(1);
        let root_of_two = sets.find(4);
        assert_ne!(root_of_three, root_of_two);

        // Larger set passed as the first argument.
        let mut larger_lhs = sets.clone();
        larger_lhs.union(1, 4);
        assert_eq!(larger_lhs.find(1), root_of_three);
        assert_eq!(larger_lhs.find(4), root_of_three);

        // Larger set passed as the second argument.
        let mut larger_rhs = sets.clone();
        larger_rhs.union(4, 1);
        assert_eq!(larger_rhs.find(1), root_of_three);
        assert_eq!(larger_rhs.find(4), root_of_three);
    }
}
159}