cauly_rust_leetcode_utils/
union_find.rs

1// #region UnionFind
2use std::collections::HashMap;
3
/// Disjoint-set (union-find) structure over the dense indices `0..length`.
///
/// Derives `Debug` and `Clone` so instances can be printed while debugging and
/// snapshotted before a speculative sequence of unions.
#[derive(Debug, Clone)]
pub struct UnionFind4Usize {
    /// Parent link of every element; an element is a root when `id[i] == i`.
    id: Vec<usize>,
    /// Element count of the set rooted at each index (kept up to date for roots only).
    size: Vec<usize>,
    /// Current number of disjoint sets.
    count: usize,
    /// Total number of elements created so far.
    length: usize,
}
10
11pub struct UnionFind<T>
12where
13    T: std::cmp::Eq,
14    T: std::hash::Hash,
15    T: std::fmt::Debug,
16{
17    map: HashMap<T, usize>,
18    uf: UnionFind4Usize,
19}
20
21impl UnionFind4Usize {
22    pub fn new(count: usize) -> Self {
23        UnionFind4Usize {
24            count,
25            length: count,
26            id: (0..count).collect(),
27            size: vec![1; count as usize],
28        }
29    }
30
31    pub fn add(&mut self) -> usize {
32        self.count += 1;
33        self.id.push(self.length);
34        self.size.push(1);
35        self.length += 1;
36        self.length - 1
37    }
38
39    pub fn is_connected(&mut self, p: usize, q: usize) -> bool {
40        self.find(p) == self.find(q)
41    }
42
43    pub fn find(&mut self, p: usize) -> usize {
44        let mut p = p;
45        while p != self.id[p] {
46            self.id[p] = self.id[self.id[p]];
47            p = self.id[p]
48        }
49        p
50    }
51
52    pub fn union(&mut self, p: usize, q: usize) {
53        let i = self.find(p);
54        let j = self.find(q);
55        if i == j {
56            return;
57        }
58        if self.size[i] < self.size[j] {
59            self.id[i] = j;
60            self.size[j] += self.size[i];
61        } else {
62            self.id[j] = i;
63            self.size[i] += self.size[j];
64        }
65        self.count -= 1;
66    }
67
68    pub fn union_count(&self) -> usize {
69        self.count
70    }
71
72    pub fn union_size(&mut self, p: usize) -> usize {
73        let root = self.find(p);
74        return self.size[root];
75    }
76
77    pub fn len(&self) -> usize {
78        self.length
79    }
80}
81
82impl<T> UnionFind<T>
83where
84    T: std::cmp::Eq,
85    T: std::hash::Hash,
86    T: std::fmt::Debug,
87{
88    pub fn new() -> Self {
89        UnionFind {
90            map: HashMap::new(),
91            uf: UnionFind4Usize {
92                count: 0,
93                length: 0,
94                id: Vec::new(),
95                size: Vec::new(),
96            },
97        }
98    }
99
100    pub fn from_iter<I>(iter: I) -> UnionFind<T>
101    where
102        I: IntoIterator<Item = T>,
103    {
104        let mut map = HashMap::new();
105        let mut index = 0;
106        for item in iter.into_iter() {
107            map.insert(item, index);
108            index += 1;
109        }
110        let len = map.len();
111        UnionFind {
112            map,
113            uf: UnionFind4Usize::new(len),
114        }
115    }
116
117    pub fn len(&self) -> usize {
118        self.uf.len()
119    }
120
121    pub fn union_count(&self) -> usize {
122        self.uf.union_count()
123    }
124
125    pub fn union_size(&mut self, p: T) -> Option<usize> {
126        if let Some(index) = self.map.get(&p) {
127            let root_index = self.uf.find(*index);
128            Some(self.uf.union_size(root_index))
129        } else {
130            None
131        }
132    }
133
134    pub fn find(&mut self, p: T) -> Option<&T> {
135        if let Some(index) = self.map.get(&p) {
136            let root_index = self.uf.find(*index);
137            self._find_by_index(root_index)
138        } else {
139            None
140        }
141    }
142
143    pub fn union(&mut self, p: T, q: T) -> Result<usize, String> {
144        if let Some(pindex) = self.map.get(&p) {
145            if let Some(qindex) = self.map.get(&q) {
146                self.uf.union(*pindex, *qindex);
147                return Ok(self.uf.union_size(*pindex));
148            } else {
149                return Err(format!("{:?} not found.", q));
150            }
151        } else {
152            return Err(format!("{:?} not found.", p));
153        }
154    }
155
156    pub fn is_connected(&mut self, p: T, q: T) -> Result<bool, String> {
157        if let Some(pindex) = self.map.get(&p) {
158            if let Some(qindex) = self.map.get(&q) {
159                return Ok(self.uf.find(*pindex) == self.uf.find(*qindex));
160            } else {
161                return Err(format!("{:?} not found.", q));
162            }
163        } else {
164            return Err(format!("{:?} not found.", p));
165        }
166    }
167
168    pub fn add(&mut self, p: T) {
169        let index = self.uf.add();
170        self.map.insert(p, index);
171    }
172
173    fn _find_by_index(&self, index: usize) -> Option<&T> {
174        for (k, v) in self.map.iter() {
175            if *v == index {
176                return Some(k);
177            }
178        }
179        None
180    }
181}
182
183// #endregion