1use std::collections::HashMap;
19use std::hash::Hash;
20
/// Per-item bookkeeping stored in a [`UnionFind`].
#[derive(Clone, Copy, Debug)]
struct Node<T> {
    /// Link towards the representative of this item's set. An entry whose
    /// `root` is the item itself is the representative of its set.
    root: T,
    /// Number of items in the set. Only the representative's entry is kept up
    /// to date; other entries hold a copy taken when they were last compressed.
    size: u32,
}

/// A disjoint-set ("union-find") structure using union by size and path
/// compression. Its operations require `T: Copy + Eq + Hash`.
#[derive(Clone, Debug)]
pub struct UnionFind<T> {
    /// Maps every item seen so far to its `Node`.
    roots: HashMap<T, Node<T>>,
}
35
36impl<T> Default for UnionFind<T>
37where
38 T: Copy + Eq + Hash,
39{
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl<T> UnionFind<T>
46where
47 T: Copy + Eq + Hash,
48{
49 pub fn new() -> Self {
51 Self {
52 roots: HashMap::new(),
53 }
54 }
55
56 pub fn find(&mut self, item: T) -> T {
58 self.find_node(item).root
59 }
60
61 fn find_node(&mut self, item: T) -> Node<T> {
62 match self.roots.get(&item) {
63 Some(node) => {
64 if node.root != item {
65 let new_root = self.find_node(node.root);
66 self.roots.insert(item, new_root);
67 new_root
68 } else {
69 *node
70 }
71 }
72 None => {
73 let node = Node::<T> {
74 root: item,
75 size: 1,
76 };
77 self.roots.insert(item, node);
78 node
79 }
80 }
81 }
82
83 pub fn union(&mut self, a: T, b: T) {
85 let a = self.find_node(a);
86 let b = self.find_node(b);
87 if a.root == b.root {
88 return;
89 }
90
91 let new_node = Node::<T> {
92 root: if a.size < b.size { b.root } else { a.root },
93 size: a.size + b.size,
94 };
95 self.roots.insert(a.root, new_node);
96 self.roots.insert(b.root, new_node);
97 }
98}
99
#[cfg(test)]
mod tests {
    use itertools::Itertools as _;

    use super::*;

    #[test]
    fn test_basic() {
        let mut uf = UnionFind::<i32>::new();

        // Items that were never merged are their own representatives.
        for item in 1..=3 {
            assert_eq!(uf.find(item), item);
        }

        uf.union(1, 2);
        uf.union(3, 4);
        assert_eq!(uf.find(1), uf.find(2));
        assert_eq!(uf.find(3), uf.find(4));
        assert_ne!(uf.find(1), uf.find(3));

        // Joining the two pairs puts all four items in one set.
        uf.union(1, 3);
        let reps: Vec<i32> = (1..=4).map(|item| uf.find(item)).collect();
        assert!(reps.iter().all_equal());
    }

    #[test]
    fn test_union_by_size() {
        let mut uf = UnionFind::<i32>::new();

        // Build a three-item set {1, 2, 3} and a two-item set {4, 5}.
        uf.union(1, 2);
        uf.union(2, 3);
        uf.union(4, 5);
        let big_rep = uf.find(1);
        let small_rep = uf.find(4);
        assert_ne!(big_rep, small_rep);

        // Whichever argument order is used, the bigger set's root must win.
        for (first, second) in [(1, 4), (4, 1)] {
            let mut merged = uf.clone();
            merged.union(first, second);
            assert_eq!(merged.find(1), big_rep);
            assert_eq!(merged.find(4), big_rep);
        }
    }
}