// aoc_core/graph.rs

/// Disjoint-set union-find data structure over the elements `0..n`.
///
/// Merges use union by size, and lookups apply path compression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dsu {
    /// Parent pointer for each element.
    ///
    /// A root element is its own parent.
    parent: Vec<usize>,

    /// Size of each component (valid only for roots).
    size: Vec<usize>,

    /// Current number of disjoint components.
    components: usize,
}

15impl Dsu {
16    /// Creates a new DSU with `n` singleton components.
17    #[must_use]
18    pub fn new(n: usize) -> Self {
19        Self {
20            parent: (0..n).collect(),
21            size: vec![1; n],
22            components: n,
23        }
24    }
25
26    /// Parent point for each element.
27    ///
28    /// A root element is its own parent.
29    #[inline]
30    #[must_use]
31    pub fn parent(&self) -> &[usize] {
32        &self.parent
33    }
34
35    /// Size of each component (valid only for roots).
36    #[inline]
37    #[must_use]
38    pub fn size(&self) -> &[usize] {
39        &self.size
40    }
41
42    /// Current number of disjoint components.
43    #[inline]
44    #[must_use]
45    pub const fn components(&self) -> usize {
46        self.components
47    }
48
49    /// Returns the root of the component containing `x`.
50    ///
51    /// Applies path compression, making future queries faster.
52    #[inline]
53    pub fn find(&mut self, mut x: usize) -> usize {
54        let mut root = x;
55
56        while self.parent[root] != root {
57            root = self.parent[root];
58        }
59
60        while self.parent[x] != x {
61            let next = self.parent[x];
62            self.parent[x] = root;
63            x = next;
64        }
65
66        root
67    }
68
69    /// Unites the components containing `x` and `y`,
70    /// merging the smaller component into the bigger one.
71    ///
72    /// Returns `true` if a merge occurred,
73    /// or `false` if they were already in the same component.
74    #[inline]
75    pub fn union(&mut self, x: usize, y: usize) -> bool {
76        let rx = self.find(x);
77        let ry = self.find(y);
78
79        if rx == ry {
80            return false;
81        }
82
83        if self.size[rx] < self.size[ry] {
84            self.parent[rx] = ry;
85            self.size[ry] += self.size[rx];
86        } else {
87            self.parent[ry] = rx;
88            self.size[rx] += self.size[ry];
89        }
90
91        self.components -= 1;
92
93        true
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::Dsu;

    /// Consistent fixture: component {0, 1, 2} (chain 2 -> 1 -> 0, size 3) and singleton {3}.
    fn chain_fixture() -> Dsu {
        Dsu {
            parent: vec![0, 0, 1, 3],
            size: vec![3, 1, 1, 1],
            components: 2,
        }
    }

    #[test]
    fn new() {
        let input = 8;
        let expected = (vec![0, 1, 2, 3, 4, 5, 6, 7], vec![1; 8], 8);
        let dsu = Dsu::new(input);
        let output = (dsu.parent, dsu.size, dsu.components);
        assert_eq!(expected, output, "\n input: {input:?}");
    }

    // ------------------------------------------------------------------------------------------------
    // Find

    #[test]
    fn find() {
        let input = [0, 3, 1, 2];
        let expected = [0, 3, 0, 0];
        let mut dsu = chain_fixture();
        let output = input.map(|n| dsu.find(n));
        assert_eq!(expected, output, "\n input: {input:?}");
    }

    #[test]
    fn find_compression() {
        let input = 2;
        let expected = vec![0, 0, 0, 3];
        let mut dsu = chain_fixture();
        let _ = dsu.find(input);
        let output = dsu.parent().to_vec();
        assert_eq!(expected, output, "\n input: {input:?}");
    }

    #[test]
    #[should_panic]
    fn find_out_of_range() {
        let mut dsu = Dsu::new(4);
        let _ = dsu.find(4);
    }

    // ------------------------------------------------------------------------------------------------
    // Union

    #[test]
    fn union_size() {
        let input = [((0, 1), 0), ((2, 0), 0), ((3, 0), 0), ((3, 1), 0)];
        let expected = [(true, 2), (true, 3), (true, 4), (false, 4)];
        let mut dsu = Dsu::new(4);
        let output = input.map(|((x, y), r)| {
            let union = dsu.union(x, y);
            let root = dsu.find(r);
            (union, dsu.size()[root])
        });
        assert_eq!(expected, output, "\n input: {input:?}");
    }

    #[test]
    fn union_attaches_smaller_under_larger() {
        // The first merge is a tie (first root wins); the second attaches the singleton {2}
        // under the larger component's root.
        let input = [(0, 1), (2, 1)];
        let expected = (vec![0, 0, 0], 3, 1);
        let mut dsu = Dsu::new(3);
        for (x, y) in input {
            assert!(dsu.union(x, y));
        }
        let output = (dsu.parent().to_vec(), dsu.size()[0], dsu.components());
        assert_eq!(expected, output, "\n input: {input:?}");
    }

    #[test]
    fn union_components() {
        let input = [(0, 1), (2, 3), (1, 3), (0, 2)];
        let expected = [(true, 3), (true, 2), (true, 1), (false, 1)];
        let mut dsu = Dsu::new(4);
        let output = input.map(|(x, y)| (dsu.union(x, y), dsu.components()));
        assert_eq!(expected, output, "\n input: {input:?}");
    }
}