cauly_rust_leetcode_utils/
union_find.rs1use std::collections::HashMap;
3
/// Index-based union-find (disjoint-set forest) over the elements `0..len()`.
///
/// The accompanying `impl` uses union by size and path halving. Invariant:
/// `id.len() == size.len() == length`, and `count` is the number of disjoint
/// sets among those elements. `Default` yields an empty structure, identical
/// to `UnionFind4Usize::new(0)`.
#[derive(Debug, Clone, Default)]
pub struct UnionFind4Usize {
    /// Parent pointer of each element; element `i` is a root when `id[i] == i`.
    id: Vec<usize>,
    /// Number of elements in the tree rooted at `i` (only meaningful for roots).
    size: Vec<usize>,
    /// Number of disjoint sets currently present.
    count: usize,
    /// Total number of elements ever added.
    length: usize,
}
10
11pub struct UnionFind<T>
12where
13 T: std::cmp::Eq,
14 T: std::hash::Hash,
15 T: std::fmt::Debug,
16{
17 map: HashMap<T, usize>,
18 uf: UnionFind4Usize,
19}
20
21impl UnionFind4Usize {
22 pub fn new(count: usize) -> Self {
23 UnionFind4Usize {
24 count,
25 length: count,
26 id: (0..count).collect(),
27 size: vec![1; count as usize],
28 }
29 }
30
31 pub fn add(&mut self) -> usize {
32 self.count += 1;
33 self.id.push(self.length);
34 self.size.push(1);
35 self.length += 1;
36 self.length - 1
37 }
38
39 pub fn is_connected(&mut self, p: usize, q: usize) -> bool {
40 self.find(p) == self.find(q)
41 }
42
43 pub fn find(&mut self, p: usize) -> usize {
44 let mut p = p;
45 while p != self.id[p] {
46 self.id[p] = self.id[self.id[p]];
47 p = self.id[p]
48 }
49 p
50 }
51
52 pub fn union(&mut self, p: usize, q: usize) {
53 let i = self.find(p);
54 let j = self.find(q);
55 if i == j {
56 return;
57 }
58 if self.size[i] < self.size[j] {
59 self.id[i] = j;
60 self.size[j] += self.size[i];
61 } else {
62 self.id[j] = i;
63 self.size[i] += self.size[j];
64 }
65 self.count -= 1;
66 }
67
68 pub fn union_count(&self) -> usize {
69 self.count
70 }
71
72 pub fn union_size(&mut self, p: usize) -> usize {
73 let root = self.find(p);
74 return self.size[root];
75 }
76
77 pub fn len(&self) -> usize {
78 self.length
79 }
80}
81
82impl<T> UnionFind<T>
83where
84 T: std::cmp::Eq,
85 T: std::hash::Hash,
86 T: std::fmt::Debug,
87{
88 pub fn new() -> Self {
89 UnionFind {
90 map: HashMap::new(),
91 uf: UnionFind4Usize {
92 count: 0,
93 length: 0,
94 id: Vec::new(),
95 size: Vec::new(),
96 },
97 }
98 }
99
100 pub fn from_iter<I>(iter: I) -> UnionFind<T>
101 where
102 I: IntoIterator<Item = T>,
103 {
104 let mut map = HashMap::new();
105 let mut index = 0;
106 for item in iter.into_iter() {
107 map.insert(item, index);
108 index += 1;
109 }
110 let len = map.len();
111 UnionFind {
112 map,
113 uf: UnionFind4Usize::new(len),
114 }
115 }
116
117 pub fn len(&self) -> usize {
118 self.uf.len()
119 }
120
121 pub fn union_count(&self) -> usize {
122 self.uf.union_count()
123 }
124
125 pub fn union_size(&mut self, p: T) -> Option<usize> {
126 if let Some(index) = self.map.get(&p) {
127 let root_index = self.uf.find(*index);
128 Some(self.uf.union_size(root_index))
129 } else {
130 None
131 }
132 }
133
134 pub fn find(&mut self, p: T) -> Option<&T> {
135 if let Some(index) = self.map.get(&p) {
136 let root_index = self.uf.find(*index);
137 self._find_by_index(root_index)
138 } else {
139 None
140 }
141 }
142
143 pub fn union(&mut self, p: T, q: T) -> Result<usize, String> {
144 if let Some(pindex) = self.map.get(&p) {
145 if let Some(qindex) = self.map.get(&q) {
146 self.uf.union(*pindex, *qindex);
147 return Ok(self.uf.union_size(*pindex));
148 } else {
149 return Err(format!("{:?} not found.", q));
150 }
151 } else {
152 return Err(format!("{:?} not found.", p));
153 }
154 }
155
156 pub fn is_connected(&mut self, p: T, q: T) -> Result<bool, String> {
157 if let Some(pindex) = self.map.get(&p) {
158 if let Some(qindex) = self.map.get(&q) {
159 return Ok(self.uf.find(*pindex) == self.uf.find(*qindex));
160 } else {
161 return Err(format!("{:?} not found.", q));
162 }
163 } else {
164 return Err(format!("{:?} not found.", p));
165 }
166 }
167
168 pub fn add(&mut self, p: T) {
169 let index = self.uf.add();
170 self.map.insert(p, index);
171 }
172
173 fn _find_by_index(&self, index: usize) -> Option<&T> {
174 for (k, v) in self.map.iter() {
175 if *v == index {
176 return Some(k);
177 }
178 }
179 None
180 }
181}
182
183