// algorithms-edu 0.2.7
//
// Algorithms for pedagogical demonstration
// Documentation
//! Implementation of finding the Lowest Common Ancestor (LCA) of a tree. This impl first finds an
//! Euler tour from the root node which visits all the nodes in the tree. The node height values
//! obtained from the Euler tour can then be used in combination with a sparse table to find the LCA
//! in O(1).
//!
//! # Resources
//!
//! - [W. Fiset's video](https://www.youtube.com/watch?v=sD1IoalFomA)

use super::Node;

/// Answers lowest-common-ancestor queries on a rooted tree.
///
/// Construction records an Euler tour of the tree (a node is written to the tour when it is first
/// reached and once more after each of its children has been processed) together with the depth
/// at every tour position, and indexes those depths with a [`MinSparseTable`].
pub struct LcaSolver {
    // Range-minimum index over the depths recorded along the Euler tour.
    sparse_table: MinSparseTable,
    // `node_order[i]` is the id of the node sitting at position `i` of the Euler tour.
    node_order: Vec<usize>,
    // `last[id]` is the position of the final occurrence of node `id` in the Euler tour.
    last: Vec<usize>,
}

impl LcaSolver {
    /// Builds a solver for the tree rooted at `root`.
    ///
    /// `size` is used to size every internal buffer: the tour buffers get `2 * size - 1` slots
    /// and the per-node buffers get `size` slots, all indexed by node id or tour position.
    ///
    /// # Panics
    ///
    /// Panics on an out-of-bounds index if a node id is `>= size` or if the tour needs more than
    /// `2 * size - 1` slots.
    pub fn new(root: &Node, size: usize) -> Self {
        let mut node_depth = vec![0usize; size * 2 - 1];
        let mut node_order = vec![0usize; size * 2 - 1];
        let mut last = vec![0usize; size];
        let mut seen = vec![false; size];

        // Iterative depth-first traversal; `cursor` is the next free tour position.
        let mut cursor = 0;
        let mut pending = vec![(root, 0usize)];
        while let Some((current, level)) = pending.pop() {
            let id = current.id;

            // Every pop, first visit or revisit, contributes one tour entry.
            node_order[cursor] = id;
            node_depth[cursor] = level;
            last[id] = cursor;
            cursor += 1;

            // Children are scheduled only the first time a node is reached.
            if seen[id] {
                continue;
            }
            seen[id] = true;
            for child in &current.children {
                // The parent is pushed beneath each child so that it is popped (and recorded in
                // the tour again) right after that child's subtree is finished.
                pending.push((current, level));
                pending.push((child, level + 1));
            }
        }

        Self {
            sparse_table: MinSparseTable::new(&node_depth),
            node_order,
            last,
        }
    }

    /// Returns the id of the lowest common ancestor of nodes `a` and `b`.
    ///
    /// The answer is the node found at the minimum-depth tour position between the last
    /// occurrences of `a` and `b`.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is not a valid index into the last-occurrence table.
    pub fn lca(&self, a: usize, b: usize) -> usize {
        let pos_a = self.last[a];
        let pos_b = self.last[b];
        let lo = pos_a.min(pos_b);
        let hi = pos_a.max(pos_b);
        let shallowest = self.sparse_table.query_index(lo, hi);
        self.node_order[shallowest]
    }
}

/// Sparse table over a sequence of depths that reports the *position* of the minimum depth inside
/// an inclusive index range.
pub struct MinSparseTable {
    // `min_depth[i][j]` is the minimum of the `2^i` depths starting at position `j`; cells whose
    // window would run past the end of the input stay `None`.
    min_depth: Vec<Vec<Option<usize>>>,
    // `index[i][j]` is the position at which `min_depth[i][j]` occurs (leftmost on ties).
    index: Vec<Vec<Option<usize>>>,
    // `log2[k]` is `floor(log2(k))` for `k >= 1`; `log2[0]` is 0.
    log2: Vec<usize>,
}

impl MinSparseTable {
    /// Builds the table for `node_depth`.
    ///
    /// Level 0 holds the input itself; level `i` combines two adjacent level `i - 1` windows,
    /// preferring the left window on ties.
    pub fn new(node_depth: &[usize]) -> Self {
        let n = node_depth.len();
        let log2 = Self::build_log2(n);
        let m = log2[n];
        let mut min_depth = vec![vec![None; n]; m + 1];
        let mut index = vec![vec![None; n]; m + 1];
        for (i, &depth) in node_depth.iter().enumerate() {
            min_depth[0][i] = Some(depth);
            index[0][i] = Some(i);
        }
        // Build each level by combining the two half-width windows of the previous level.
        for i in 1..=m {
            let half = 1 << (i - 1);
            // `i <= m = floor(log2(n))`, so `1 << i <= n` and the subtraction cannot underflow.
            for j in 0..=(n - (1 << i)) {
                let left_interval = min_depth[i - 1][j];
                let right_interval = min_depth[i - 1][j + half];
                // Propagate the value and the position of the smaller half.
                if left_interval <= right_interval {
                    min_depth[i][j] = left_interval;
                    index[i][j] = index[i - 1][j];
                } else {
                    min_depth[i][j] = right_interval;
                    index[i][j] = index[i - 1][j + half];
                }
            }
        }
        Self {
            min_depth,
            index,
            log2,
        }
    }

    /// Returns a table with `n + 1` entries where entry `k >= 1` is `floor(log2(k))`.
    fn build_log2(n: usize) -> Vec<usize> {
        let mut log2 = vec![0usize; n + 1];
        for i in 2..=n {
            log2[i] = log2[i / 2] + 1;
        }
        log2
    }

    /// Returns the position of the minimum depth in the inclusive range `[l, r]`.
    ///
    /// The range is covered by two (possibly overlapping) windows of length `2^i`, where
    /// `i = floor(log2(r - l + 1))`: one starting at `l` and one ending at `r`. On ties the left
    /// window wins.
    ///
    /// # Panics
    ///
    /// Panics if `l > r` or if `r` is outside the sequence the table was built from.
    fn query_index(&self, l: usize, r: usize) -> usize {
        debug_assert!(l <= r, "query_index requires l <= r");
        let len = r - l + 1;
        let i = self.log2[len];
        // Start of the window of length `2^i` that ends at `r`. Since `2^i <= len <= r + 1`, this
        // cannot underflow. The same start must be used for the value and the position lookup.
        let right_start = r + 1 - (1 << i);
        let left_interval = self.min_depth[i][l];
        let right_interval = self.min_depth[i][right_start];
        let best = if left_interval <= right_interval {
            self.index[i][l]
        } else {
            self.index[i][right_start]
        };
        best.expect("both windows lie inside the indexed sequence, so their cells are populated")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::graph::UnweightedAdjacencyList;

    /// Builds a fixed 17-node tree rooted at node 0 and checks a handful of LCA queries on it.
    #[test]
    fn test_tree_lowest_commmon_ancestor() {
        let adjacency = UnweightedAdjacencyList::new_undirected(
            17,
            &[
                [0, 1],
                [0, 2],
                [1, 3],
                [1, 4],
                [2, 5],
                [2, 6],
                [2, 7],
                [3, 8],
                [3, 9],
                [5, 10],
                [5, 11],
                [7, 12],
                [7, 13],
                [11, 14],
                [11, 15],
                [11, 16],
            ],
        );
        let root = Node::from_adjacency_list(&adjacency, 0);
        let solver = LcaSolver::new(&root, 17);

        // Each entry is (first node, second node, expected lowest common ancestor).
        let cases = [(14, 13, 2), (9, 11, 0), (12, 12, 12)];
        for &(a, b, expected) in cases.iter() {
            assert_eq!(solver.lca(a, b), expected);
        }
    }
}

pub mod with_generic_sparse_table {

    use super::super::Node;
    use crate::data_structures::sparse_table::SparseTable;

    /// A `(tour position, depth)` pair; this is the element type stored in the sparse table.
    type IndexAndDepth = (usize, usize);

    /// Boxed combiner that picks one of two `(tour position, depth)` pairs.
    type Combiner = Box<dyn Fn(IndexAndDepth, IndexAndDepth) -> IndexAndDepth>;

    /// Lowest-common-ancestor solver backed by the crate's generic `SparseTable`.
    ///
    /// Construction records an Euler tour of the tree (a node is written to the tour when it is
    /// first reached and once more after each of its children has been processed) and hands the
    /// `(tour position, depth)` pairs to the sparse table.
    pub struct LcaSolver {
        // Sparse table over `(tour position, depth)` pairs, combined by smaller depth.
        sparse_table: SparseTable<IndexAndDepth, Combiner>,
        // `node_order[i]` is the id of the node sitting at position `i` of the Euler tour.
        node_order: Vec<usize>,
        // `last[id]` is the position of the final occurrence of node `id` in the Euler tour.
        last: Vec<usize>,
    }

    impl LcaSolver {
        /// Builds a solver for the tree rooted at `root`.
        ///
        /// `size` is used to size every internal buffer: the tour buffers get `2 * size - 1`
        /// slots and the per-node buffers get `size` slots.
        ///
        /// # Panics
        ///
        /// Panics on an out-of-bounds index if a node id is `>= size` or if the tour needs more
        /// than `2 * size - 1` slots.
        pub fn new(root: &Node, size: usize) -> Self {
            let mut node_depth = vec![0usize; size * 2 - 1];
            let mut node_order = vec![0usize; size * 2 - 1];
            let mut last = vec![0usize; size];
            let mut seen = vec![false; size];

            // Iterative depth-first traversal; `cursor` is the next free tour position.
            let mut cursor = 0;
            let mut pending = vec![(root, 0usize)];
            while let Some((current, level)) = pending.pop() {
                let id = current.id;

                // Every pop, first visit or revisit, contributes one tour entry.
                node_order[cursor] = id;
                node_depth[cursor] = level;
                last[id] = cursor;
                cursor += 1;

                // Children are scheduled only the first time a node is reached.
                if seen[id] {
                    continue;
                }
                seen[id] = true;
                for child in &current.children {
                    // The parent is pushed beneath each child so that it is popped (and recorded
                    // in the tour again) right after that child's subtree is finished.
                    pending.push((current, level));
                    pending.push((child, level + 1));
                }
            }

            let index_and_depth: Vec<IndexAndDepth> =
                node_depth.into_iter().enumerate().collect();
            // Keep the pair with the strictly smaller depth; on equal depths the second wins.
            let shallower: Combiner =
                Box::new(|a: IndexAndDepth, b: IndexAndDepth| if b.1 <= a.1 { b } else { a });
            // NOTE(review): the trailing `true` is forwarded to `SparseTable::new` as-is;
            // presumably it marks the combiner as safe for overlapping ranges. Confirm against
            // `SparseTable`.
            Self {
                sparse_table: SparseTable::new(&index_and_depth, shallower, true),
                node_order,
                last,
            }
        }

        /// Returns the id of the lowest common ancestor of nodes `a` and `b`.
        ///
        /// The answer is the node found at the tour position the sparse table reports for the
        /// range between the last occurrences of `a` and `b`.
        ///
        /// # Panics
        ///
        /// Panics if `a` or `b` is not a valid index into the last-occurrence table.
        pub fn lca(&self, a: usize, b: usize) -> usize {
            let pos_a = self.last[a];
            let pos_b = self.last[b];
            let lo = pos_a.min(pos_b);
            let hi = pos_a.max(pos_b);
            let (shallowest, _depth) = self.sparse_table.query(lo, hi);
            self.node_order[shallowest]
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::algo::graph::UnweightedAdjacencyList;

        /// Builds a fixed 17-node tree rooted at node 0 and checks a handful of LCA queries
        /// against the solver that uses the generic sparse table.
        #[test]
        fn test_tree_lowest_commmon_ancestor_with_generic_sparse_table() {
            let adjacency = UnweightedAdjacencyList::new_undirected(
                17,
                &[
                    [0, 1],
                    [0, 2],
                    [1, 3],
                    [1, 4],
                    [2, 5],
                    [2, 6],
                    [2, 7],
                    [3, 8],
                    [3, 9],
                    [5, 10],
                    [5, 11],
                    [7, 12],
                    [7, 13],
                    [11, 14],
                    [11, 15],
                    [11, 16],
                ],
            );
            let root = Node::from_adjacency_list(&adjacency, 0);
            let solver = LcaSolver::new(&root, 17);

            // Each entry is (first node, second node, expected lowest common ancestor).
            let cases = [(14, 13, 2), (9, 11, 0), (12, 12, 12)];
            for &(a, b, expected) in cases.iter() {
                assert_eq!(solver.lca(a, b), expected);
            }
        }
    }
}