// path_finding/union_find.rs

/// Disjoint-set (union–find) structure over the nodes `0..node_count`.
///
/// Merging uses union by size, and [`UnionFind::find`] applies path
/// compression, so every node visited during a lookup is re-pointed directly
/// at its root.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// For a root node, the number of nodes in its component. Entries for
    /// non-root nodes keep whatever value they had when the node stopped being
    /// a root; they are not updated afterwards.
    sizes: Vec<usize>,
    /// Parent pointer for each node; a node is a root iff `ids[i] == i`.
    ids: Vec<usize>,
    /// Number of disjoint components currently tracked.
    components: usize,
}

impl UnionFind {
    /// Creates a union–find over `node_count` nodes, each starting in its own
    /// singleton component.
    pub fn from(node_count: usize) -> UnionFind {
        UnionFind {
            sizes: vec![1; node_count],
            ids: (0..node_count).collect(),
            components: node_count,
        }
    }

    /// Returns the root of the component containing `p`.
    ///
    /// Every node on the path from `p` to the root is re-pointed directly at
    /// the root, which keeps later lookups short.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not a valid node index.
    fn find(&mut self, p: usize) -> usize {
        let mut root = p;
        while self.ids[root] != root {
            root = self.ids[root];
        }

        // Second pass: path compression.
        let mut node = p;
        while node != root {
            let next = self.ids[node];
            self.ids[node] = root;
            node = next;
        }

        root
    }

    /// Returns `true` if `p` and `q` belong to the same component.
    ///
    /// Takes `&mut self` because the lookups compress paths.
    ///
    /// # Panics
    ///
    /// Panics if either index is out of range.
    pub fn connected(&mut self, p: usize, q: usize) -> bool {
        self.find(p) == self.find(q)
    }

    /// Merges the components containing `p` and `q`.
    ///
    /// The root of the smaller component is attached beneath the root of the
    /// larger one. On a size tie, `q`'s root is attached beneath `p`'s root.
    /// Does nothing if the nodes are already connected.
    ///
    /// # Panics
    ///
    /// Panics if either index is out of range.
    pub fn unify(&mut self, p: usize, q: usize) {
        let p_root = self.find(p);
        let q_root = self.find(q);
        if p_root == q_root {
            return;
        }

        let (child, parent) = if self.sizes[p_root] < self.sizes[q_root] {
            (p_root, q_root)
        } else {
            (q_root, p_root)
        };
        // `parent` is a root, so linking to it directly is the same as linking
        // to `ids[parent]`.
        self.ids[child] = parent;
        self.sizes[parent] += self.sizes[child];
        self.components -= 1;
    }

    /// Returns the number of disjoint components.
    pub fn components(&self) -> usize {
        self.components
    }

    /// Returns the number of nodes in the component containing `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is out of range.
    pub fn component_size(&mut self, p: usize) -> usize {
        let root = self.find(p);
        self.sizes[root]
    }

    /// Returns the raw size entry stored for `id`.
    ///
    /// This is the component size only when `id` is currently a root. For a
    /// non-root node it is the stale size from before the node was absorbed.
    /// Use [`UnionFind::component_size`] to get the size for any node.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range.
    pub fn size(&self, id: usize) -> usize {
        self.sizes[id]
    }

    /// Returns the raw parent pointer stored for `id`; equals `id` for roots.
    ///
    /// Path compression during lookups may change this value.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range.
    pub fn parent(&self, id: usize) -> usize {
        self.ids[id]
    }
}
71
#[test]
fn union_find_with_zero_edges_should_succeed() {
    // A structure built over no nodes has no components at all.
    let empty = UnionFind::from(0);
    assert_eq!(empty.components, 0);
}
78
#[test]
fn unify_should_decrease_components() {
    let mut uf = UnionFind::from(2);
    assert_eq!(uf.components, 2);

    // Joining the only two nodes leaves a single component rooted at node 0.
    uf.unify(0, 1);
    assert_eq!(uf.components, 1);
    assert_eq!(uf.size(0), 2);
    assert_eq!(uf.parent(1), 0);
}
90
#[test]
fn test_find() {
    let mut uf = UnionFind::from(5);

    // Chain 0-1-2 together; node 0 should end up as the shared root.
    for &(a, b) in [(0, 1), (1, 2)].iter() {
        uf.unify(a, b);
    }

    assert_eq!(uf.find(2), 0);
}
99
#[test]
fn test_connected() {
    let mut uf = UnionFind::from(5);
    uf.unify(0, 1);
    uf.unify(3, 4);

    // Pairs that were explicitly merged must report as connected.
    let linked = [(0, 1), (3, 4)];
    for &(a, b) in linked.iter() {
        assert!(uf.connected(a, b));
    }

    // Nodes from the two separate groups must not be connected.
    assert!(!uf.connected(0, 4));
}
110
#[test]
fn test_unify() {
    let mut uf = UnionFind::from(4);
    uf.unify(0, 1);

    // Only the surviving root's size entry grows; the rest stay at 1.
    let expected: [usize; 4] = [2, 1, 1, 1];
    for (node, &size) in expected.iter().enumerate() {
        assert_eq!(uf.size(node), size);
    }
}
121
#[test]
fn test_unify_multiple_groups() {
    let mut uf = UnionFind::from(6);
    for &(a, b) in [(0, 1), (1, 2), (3, 4)].iter() {
        uf.unify(a, b);
    }

    // Roots 0 and 3 carry their group sizes; absorbed and untouched nodes stay at 1.
    let expected: [usize; 6] = [3, 1, 1, 2, 1, 1];
    for (node, &size) in expected.iter().enumerate() {
        assert_eq!(uf.size(node), size);
    }
}
136
#[test]
fn test_components() {
    let mut uf = UnionFind::from(5);
    assert_eq!(uf.components, 5);

    // Each merge of two distinct components should drop the count by one.
    let steps = [((0, 1), 4), ((1, 2), 3), ((3, 4), 2)];
    for &((p, q), remaining) in steps.iter() {
        uf.unify(p, q);
        assert_eq!(uf.components, remaining);
    }
}