algorithms_edu/algo/graph/tree/
lca.rs

//! Implementation of finding the Lowest Common Ancestor (LCA) of two nodes in a rooted tree. This
//! impl first computes an Euler tour from the root node, which visits every node in the tree
//! (`2n - 1` visits in total for `n` nodes). The node depths recorded along the Euler tour are then
//! combined with a sparse table, so that after O(n log n) preprocessing each LCA query runs in O(1).
5//!
6//! # Resources
7//!
8//! - [W. Fiset's video](https://www.youtube.com/watch?v=sD1IoalFomA)
9
10use super::Node;
11
12pub struct LcaSolver {
13    sparse_table: MinSparseTable,
14    node_order: Vec<usize>,
15    // The last occurrence mapping. This mapping keeps track of the last occurrence of a TreeNode in
16    // the Euler tour for easy indexing.
17    last: Vec<usize>,
18}
19
20impl LcaSolver {
21    pub fn new(root: &Node, size: usize) -> Self {
22        let mut node_depth = vec![0usize; size * 2 - 1]; // Vec::<usize>::new();
23        let mut node_order = vec![0usize; size * 2 - 1]; // Vec::<usize>::new();
24        let mut last = vec![0usize; size];
25        let mut tour_index = 0;
26
27        let mut visit = |node: usize, depth: usize| {
28            node_order[tour_index] = node;
29            node_depth[tour_index] = depth;
30            last[node] = tour_index;
31            tour_index += 1;
32        };
33
34        //dfs
35        let mut stack = vec![(root, 0usize)];
36        let mut visited = vec![false; size];
37        while let Some((node, depth)) = stack.pop() {
38            visit(node.id, depth);
39            if !visited[node.id] {
40                visited[node.id] = true;
41                for child in &node.children {
42                    stack.push((node, depth)); // revisit the current node after visiting each child
43                    stack.push((child, depth + 1));
44                }
45            }
46        }
47
48        let sparse_table = MinSparseTable::new(&node_depth);
49        Self {
50            sparse_table,
51            node_order,
52            last,
53        }
54    }
55    pub fn lca(&self, a: usize, b: usize) -> usize {
56        let (a, b) = (self.last[a], self.last[b]);
57        let (l, r) = if a < b { (a, b) } else { (b, a) };
58        let idx = self.sparse_table.query_index(l, r);
59        self.node_order[idx]
60    }
61}
62
/// Sparse table over a slice of depths that answers "at which position is the minimum depth in
/// the inclusive range `l..=r`?".
///
/// Construction takes O(n log n) time and space; each [`MinSparseTable::query_index`] call is
/// O(1).
pub struct MinSparseTable {
    // `min_depth[p][j]` is the minimum depth over the `2^p` positions starting at `j`.
    // Row `p` has exactly `n + 1 - 2^p` entries.
    min_depth: Vec<Vec<usize>>,
    // `index[p][j]` is a position in `j..j + 2^p` holding the depth `min_depth[p][j]`.
    index: Vec<Vec<usize>>,
    // `log2[k]` is `floor(log2(k))` for `k >= 1`; `log2[0]` is 0.
    log2: Vec<usize>,
}

impl MinSparseTable {
    /// Builds the table over `node_depth`.
    ///
    /// An empty slice produces a table that cannot be queried.
    pub fn new(node_depth: &[usize]) -> Self {
        let n = node_depth.len();
        let log2 = Self::build_log2(n);
        let levels = log2[n];

        let mut min_depth: Vec<Vec<usize>> = Vec::with_capacity(levels + 1);
        let mut index: Vec<Vec<usize>> = Vec::with_capacity(levels + 1);
        // Level 0: every single-position interval is its own minimum.
        min_depth.push(node_depth.to_vec());
        index.push((0..n).collect());

        // Level `p` combines two adjacent level `p - 1` intervals of length `2^(p - 1)`.
        for p in 1..=levels {
            let half = 1 << (p - 1);
            let width = n + 1 - (1 << p);
            let (prev_depth, prev_index) = (&min_depth[p - 1], &index[p - 1]);
            let mut depth_row = Vec::with_capacity(width);
            let mut index_row = Vec::with_capacity(width);
            for j in 0..width {
                // Ties keep the left half's position.
                let best = if prev_depth[j] <= prev_depth[j + half] {
                    j
                } else {
                    j + half
                };
                depth_row.push(prev_depth[best]);
                index_row.push(prev_index[best]);
            }
            min_depth.push(depth_row);
            index.push(index_row);
        }

        Self {
            min_depth,
            index,
            log2,
        }
    }

    /// Returns a table whose entry `k` is `floor(log2(k))` for `1 <= k <= n` (entry 0 is 0).
    fn build_log2(n: usize) -> Vec<usize> {
        let mut log2 = vec![0usize; n + 1];
        for k in 2..=n {
            log2[k] = log2[k / 2] + 1;
        }
        log2
    }

    /// Returns a position holding the minimum depth within the inclusive range `l..=r`.
    ///
    /// The range is covered by two (possibly overlapping) windows of length `2^p`, one starting
    /// at `l` and one ending at `r`. The better of their precomputed minima is the answer, and
    /// ties resolve to the left window.
    ///
    /// # Panics
    ///
    /// Panics if `l > r` or if `r` is not a valid position.
    pub fn query_index(&self, l: usize, r: usize) -> usize {
        assert!(l <= r, "query_index: l ({}) must not exceed r ({})", l, r);
        let p = self.log2[r - l + 1];
        // Computed as `r + 1 - 2^p` rather than `r - 2^p + 1`: when the right window starts at
        // position 0, the latter underflows.
        let right_start = r + 1 - (1 << p);
        if self.min_depth[p][l] <= self.min_depth[p][right_start] {
            self.index[p][l]
        } else {
            self.index[p][right_start]
        }
    }
}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::graph::UnweightedAdjacencyList;

    #[test]
    fn test_tree_lowest_commmon_ancestor() {
        // 17-node tree rooted at 0.
        let adjacency = UnweightedAdjacencyList::new_undirected(
            17,
            &[
                [0, 1],
                [0, 2],
                [1, 3],
                [1, 4],
                [2, 5],
                [2, 6],
                [2, 7],
                [3, 8],
                [3, 9],
                [5, 10],
                [5, 11],
                [7, 12],
                [7, 13],
                [11, 14],
                [11, 15],
                [11, 16],
            ],
        );
        let root = Node::from_adjacency_list(&adjacency, 0);
        let solver = LcaSolver::new(&root, 17);

        // (a, b, expected lowest common ancestor)
        let cases: [(usize, usize, usize); 3] = [(14, 13, 2), (9, 11, 0), (12, 12, 12)];
        for &(a, b, expected) in &cases {
            assert_eq!(solver.lca(a, b), expected);
        }
    }
}
159
160pub mod with_generic_sparse_table {
161
162    use super::super::Node;
163    use crate::data_structures::sparse_table::SparseTable;
164
165    type IndexAndDepth = (usize, usize);
166
167    pub struct LcaSolver {
168        sparse_table:
169            SparseTable<IndexAndDepth, Box<dyn Fn(IndexAndDepth, IndexAndDepth) -> IndexAndDepth>>,
170        node_order: Vec<usize>,
171        // The last occurrence mapping. This mapping keeps track of the last occurrence of a TreeNode in
172        // the Euler tour for easy indexing.
173        last: Vec<usize>,
174    }
175
176    impl LcaSolver {
177        pub fn new(root: &Node, size: usize) -> Self {
178            let mut node_depth = vec![0usize; size * 2 - 1];
179            let mut node_order = vec![0usize; size * 2 - 1];
180            let mut last = vec![0usize; size];
181            let mut tour_index = 0;
182
183            let mut visit = |node: usize, depth: usize| {
184                node_order[tour_index] = node;
185                node_depth[tour_index] = depth;
186                last[node] = tour_index;
187                tour_index += 1;
188            };
189
190            //dfs
191            let mut stack = vec![(root, 0usize)];
192            let mut visited = vec![false; size];
193            while let Some((node, depth)) = stack.pop() {
194                visit(node.id, depth);
195                if !visited[node.id] {
196                    visited[node.id] = true;
197                    for child in &node.children {
198                        stack.push((node, depth)); // revisit the current node after visiting each child
199                        stack.push((child, depth + 1));
200                    }
201                }
202            }
203            let index_and_depth = node_depth.into_iter().enumerate().collect::<Vec<_>>();
204            let f: Box<dyn Fn(IndexAndDepth, IndexAndDepth) -> IndexAndDepth> =
205                Box::new(|a: IndexAndDepth, b: IndexAndDepth| if a.1 < b.1 { a } else { b });
206            let sparse_table = SparseTable::new(&index_and_depth, f, true);
207            Self {
208                sparse_table,
209                node_order,
210                last,
211            }
212        }
213        pub fn lca(&self, a: usize, b: usize) -> usize {
214            let (a, b) = (self.last[a], self.last[b]);
215            let (l, r) = if a < b { (a, b) } else { (b, a) };
216            let idx = self.sparse_table.query(l, r).0;
217            self.node_order[idx]
218        }
219    }
220
221    #[cfg(test)]
222    mod tests {
223        use super::*;
224        use crate::algo::graph::UnweightedAdjacencyList;
225        #[test]
226        fn test_tree_lowest_commmon_ancestor_with_generic_sparse_table() {
227            let tree = UnweightedAdjacencyList::new_undirected(
228                17,
229                &[
230                    [0, 1],
231                    [0, 2],
232                    [1, 3],
233                    [1, 4],
234                    [2, 5],
235                    [2, 6],
236                    [2, 7],
237                    [3, 8],
238                    [3, 9],
239                    [5, 10],
240                    [5, 11],
241                    [7, 12],
242                    [7, 13],
243                    [11, 14],
244                    [11, 15],
245                    [11, 16],
246                ],
247            );
248            let tree = Node::from_adjacency_list(&tree, 0);
249            let lca_solver = LcaSolver::new(&tree, 17);
250            assert_eq!(lca_solver.lca(14, 13), 2);
251            assert_eq!(lca_solver.lca(9, 11), 0);
252            assert_eq!(lca_solver.lca(12, 12), 12);
253        }
254    }
255}