egg/
unionfind.rs

1use crate::Id;
2use std::fmt::Debug;
3
4#[derive(Debug, Clone, Default)]
5#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
6pub struct UnionFind {
7    parents: Vec<Id>,
8}
9
10impl UnionFind {
11    pub fn make_set(&mut self) -> Id {
12        let id = Id::from(self.parents.len());
13        self.parents.push(id);
14        id
15    }
16
17    pub fn size(&self) -> usize {
18        self.parents.len()
19    }
20
21    fn parent(&self, query: Id) -> Id {
22        self.parents[usize::from(query)]
23    }
24
25    fn parent_mut(&mut self, query: Id) -> &mut Id {
26        &mut self.parents[usize::from(query)]
27    }
28
29    pub fn find(&self, mut current: Id) -> Id {
30        while current != self.parent(current) {
31            current = self.parent(current)
32        }
33        current
34    }
35
36    pub fn find_mut(&mut self, mut current: Id) -> Id {
37        while current != self.parent(current) {
38            let grandparent = self.parent(self.parent(current));
39            *self.parent_mut(current) = grandparent;
40            current = grandparent;
41        }
42        current
43    }
44
45    /// Given two leader ids, unions the two eclasses making root1 the leader.
46    pub fn union(&mut self, root1: Id, root2: Id) -> Id {
47        *self.parent_mut(root2) = root1;
48        root1
49    }
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts a sequence of raw indices into the corresponding `Id`s.
    fn ids(us: impl IntoIterator<Item = usize>) -> Vec<Id> {
        us.into_iter().map(Id::from).collect()
    }

    /// Builds two four-element sets out of ten singletons and checks the
    /// parent table before and after running `find_mut` on every id.
    #[test]
    fn union_find() {
        const N: usize = 10;
        let id = |raw: usize| Id::from(raw);

        let mut uf = UnionFind::default();
        for _ in 0..N {
            uf.make_set();
        }

        // Freshly created ids are all their own parents.
        assert_eq!(uf.parents, ids(0..N));

        // First set: 1, 2 and 3 join 0.
        for member in 1..=3 {
            uf.union(id(0), id(member));
        }

        // Second set: 7, 8 and 9 join 6.
        for member in 7..=9 {
            uf.union(id(6), id(member));
        }

        // Looking up every id should leave each one pointing straight at its leader.
        for raw in 0..N {
            uf.find_mut(id(raw));
        }

        // index:                       0  1  2  3  4  5  6  7  8  9
        let expected_parents = vec![0, 0, 0, 0, 4, 5, 6, 6, 6, 6];
        assert_eq!(uf.parents, ids(expected_parents));
    }
}