use crate::index::IndexBase;
use crate::{Hierarchy, NodeIndex, PortView, UnmanagedDenseMap};
/// Builds an [`LCA`] lookup structure for the trees of `hierarchy`.
///
/// Convenience wrapper that forwards both arguments unchanged to
/// [`LCA::new`]; see there for how `graph` and `hierarchy` are used.
pub fn lca<N>(graph: impl PortView<NodeIndexBase = N>, hierarchy: &Hierarchy<N>) -> LCA<N>
where
    N: IndexBase,
{
    LCA::<N>::new(graph, hierarchy)
}
/// Precomputed tables for ancestor and lowest-common-ancestor queries on a
/// [`Hierarchy`].
///
/// The tables are filled by [`LCA::new`] during a depth-first traversal that
/// starts from every root of the hierarchy.
#[derive(Debug, Default, Clone)]
pub struct LCA<N: IndexBase = u32> {
/// Timestamp taken when the traversal first entered each node.
first_visit: UnmanagedDenseMap<NodeIndex<N>, usize>,
/// Timestamp taken when the traversal finished each node, i.e. after all
/// of its descendants had been processed.
last_visit: UnmanagedDenseMap<NodeIndex<N>, usize>,
/// Ancestors of each node at doubling distances: entry `0` is the parent
/// and entry `i` is the `2^i`-th ancestor, for as long as such an ancestor
/// exists. No list is ever assigned for a root.
climb_nodes: UnmanagedDenseMap<NodeIndex<N>, Vec<NodeIndex<N>>>,
}
impl<N: IndexBase> LCA<N> {
    /// Builds the lookup tables for every tree of `hierarchy`.
    ///
    /// Each node yielded by `graph.nodes_iter()` that `hierarchy` reports as a
    /// root starts an iterative depth-first traversal. For every node reached,
    /// the traversal records:
    ///
    /// * a timestamp on entry and one on exit, used by [`LCA::is_ancestor`];
    /// * the list of ancestors at distances `1, 2, 4, ...`, used by
    ///   [`LCA::root`] and [`LCA::lca`].
    pub fn new(graph: impl PortView<NodeIndexBase = N>, hierarchy: &Hierarchy<N>) -> Self {
        let capacity = graph.node_capacity();
        let mut lca = LCA {
            first_visit: UnmanagedDenseMap::with_capacity(capacity),
            last_visit: UnmanagedDenseMap::with_capacity(capacity),
            climb_nodes: UnmanagedDenseMap::with_capacity(capacity),
        };
        // One clock is shared by all trees, so the [first, last] intervals of
        // nodes in different trees are disjoint and never nest.
        let mut timestamp = 0;
        let mut stack = vec![];
        for root in graph.nodes_iter() {
            // The inner loop always drains the stack before the next root.
            debug_assert!(stack.is_empty());
            if !hierarchy.is_root(root) {
                continue;
            }
            stack.push(DFSState::<N>::Visit {
                node: root,
                parent: None,
            });
            while let Some(state) = stack.pop() {
                match state {
                    DFSState::<N>::Visit { node, parent } => {
                        lca.first_visit[node] = timestamp;
                        timestamp += 1;
                        // Build the doubling table of `node` from the tables of
                        // its ancestors, which were all filled in earlier:
                        // entry `i + 1` is entry `i` of this node's entry `i`.
                        // The scan stops as soon as such an ancestor is missing.
                        let climb: Vec<NodeIndex<N>> = (0..)
                            .scan(parent, |prev, i| {
                                let ith_parent = (*prev)?;
                                *prev = lca.climb_nodes[ith_parent].get(i).copied();
                                Some(ith_parent)
                            })
                            .collect();
                        // Roots have no ancestors; leave their entry untouched.
                        if !climb.is_empty() {
                            lca.climb_nodes[node] = climb;
                        }
                        // Pushed before the children so that it is popped only
                        // after every descendant has been processed.
                        stack.push(DFSState::<N>::Finish { node });
                        for child in hierarchy.children(node) {
                            stack.push(DFSState::<N>::Visit {
                                node: child,
                                parent: Some(node),
                            });
                        }
                    }
                    DFSState::<N>::Finish { node } => {
                        lca.last_visit[node] = timestamp;
                        timestamp += 1;
                    }
                }
            }
        }
        lca
    }

    /// Returns `true` if `a` is an ancestor of `b`.
    ///
    /// The check is inclusive: a node counts as its own ancestor. It compares
    /// the traversal intervals recorded by [`LCA::new`]: `a` is an ancestor of
    /// `b` when the interval of `a` encloses the interval of `b`.
    pub fn is_ancestor(&self, a: NodeIndex<N>, b: NodeIndex<N>) -> bool {
        self.first_visit[a] <= self.first_visit[b] && self.last_visit[a] >= self.last_visit[b]
    }

    /// Returns the root of the tree containing `node`.
    ///
    /// Repeatedly jumps to the farthest recorded ancestor until a node with
    /// no recorded ancestors is reached; a root is returned unchanged.
    pub fn root(&self, node: NodeIndex<N>) -> NodeIndex<N> {
        let mut u = node;
        while let Some(&v) = self.climb_nodes[u].last() {
            u = v;
        }
        u
    }

    /// Returns the lowest common ancestor of `a` and `b`.
    ///
    /// Ancestry is inclusive, so if one node is an ancestor of the other that
    /// node is returned. Returns `None` when the two nodes belong to different
    /// trees.
    pub fn lca(&self, a: NodeIndex<N>, b: NodeIndex<N>) -> Option<NodeIndex<N>> {
        if self.is_ancestor(a, b) {
            return Some(a);
        }
        if self.is_ancestor(b, a) {
            return Some(b);
        }
        if self.root(a) != self.root(b) {
            return None;
        }

        // Coarse phase: starting at `a`, keep taking the longest recorded jump
        // while the landing node is still not an ancestor of `b`. The last such
        // node is kept. Since `a` itself is not an ancestor of `b`, the
        // sequence is never empty.
        let mut u = itertools::iterate(Some(a), |u| {
            u.and_then(|u| self.climb_nodes[u].last().copied())
        })
        .take_while(|u| u.is_some_and(|u| !self.is_ancestor(u, b)))
        .last()??;

        // Fine phase. Invariant: `u` is not an ancestor of `b`, and
        // `climb_nodes[u][i]` exists and is an ancestor of `b`.
        //
        // It holds initially because the longest jump from `u` ended the
        // coarse phase. It is preserved when moving to `v = climb_nodes[u][i]`
        // because `climb_nodes[v][i]` is, by construction, the same node as
        // the previous `climb_nodes[u][i + 1]`. Hence every level needs to be
        // tested at most once and the index never has to be raised again.
        let mut i = self.climb_nodes[u].len() - 1;
        while i > 0 {
            i -= 1;
            let v = self.climb_nodes[u][i];
            if !self.is_ancestor(v, b) {
                u = v;
            }
        }

        // `u` is now the highest ancestor of `a` that is not an ancestor of
        // `b`, so its parent is the lowest common ancestor.
        Some(self.climb_nodes[u][0])
    }
}
/// Work item on the explicit stack driving the depth-first traversal in
/// [`LCA::new`].
#[derive(Debug, Clone, Copy, Hash)]
enum DFSState<N: IndexBase> {
/// Enter `node` for the first time. `parent` is the node it was reached
/// from, or `None` when `node` is the root the traversal started at.
Visit {
node: NodeIndex<N>,
parent: Option<NodeIndex<N>>,
},
/// Leave `node`; popped once every work item pushed for its children has
/// been processed.
Finish { node: NodeIndex<N> },
}
#[cfg(test)]
mod test {
    use crate::PortMut;
    use rstest::{fixture, rstest};

    type PortGraph = crate::PortGraph<u32, u32, u16>;
    type Hierarchy = crate::Hierarchy<u32>;
    type NodeIndex = crate::NodeIndex<u32>;
    #[allow(clippy::upper_case_acronyms)]
    type LCA = super::LCA<u32>;

    /// A forest over 16 port-less nodes:
    ///
    /// ```text
    /// 0 ─┬─ 1 ─┬─ 3 ── 4 ── 5 ── 6
    ///    │     └─ 7
    ///    └─ 2 ── 8 ─┬─ 9
    ///               └─ 10
    /// 11 ─┬─ 12
    ///     └─ 13
    /// 14
    /// 15
    /// ```
    #[fixture]
    fn test_hierarchy() -> (PortGraph, Hierarchy) {
        let mut graph = PortGraph::with_capacity(16, 0);
        for _ in 0..16 {
            graph.add_node(0, 0);
        }
        let mut hier = Hierarchy::with_capacity(16);
        // (parent, child) pairs.
        let edges = [
            (0, 1),
            (0, 2),
            (1, 3),
            (3, 4),
            (4, 5),
            (5, 6),
            (1, 7),
            (2, 8),
            (8, 9),
            (8, 10),
            (11, 12),
            (11, 13),
        ];
        for (parent, node) in edges {
            hier.push_child(NodeIndex::new(node), NodeIndex::new(parent))
                .unwrap();
        }
        (graph, hier)
    }

    #[rstest]
    fn lca(test_hierarchy: (PortGraph, Hierarchy)) {
        let lca = LCA::new(&test_hierarchy.0, &test_hierarchy.1);
        let n = NodeIndex::new;

        // Deep nodes in different subtrees of the root.
        assert_eq!(lca.lca(n(5), n(10)), Some(n(0)));
        assert_eq!(lca.lca(n(10), n(5)), Some(n(0)));
        assert_eq!(lca.lca(n(6), n(10)), Some(n(0)));
        assert_eq!(lca.lca(n(10), n(6)), Some(n(0)));

        // Root lookup for every tree, including the single-node trees.
        for node in 0..=10 {
            assert_eq!(lca.root(n(node)), n(0));
        }
        for node in 11..=13 {
            assert_eq!(lca.root(n(node)), n(11));
        }
        for node in 14..=15 {
            assert_eq!(lca.root(n(node)), n(node));
        }

        // One argument is the root of the other's tree.
        assert_eq!(lca.lca(n(0), n(0)), Some(n(0)));
        assert_eq!(lca.lca(n(0), n(1)), Some(n(0)));
        assert_eq!(lca.lca(n(0), n(9)), Some(n(0)));
        assert_eq!(lca.lca(n(1), n(0)), Some(n(0)));
        assert_eq!(lca.lca(n(9), n(0)), Some(n(0)));

        // Nodes in different trees have no common ancestor.
        assert_eq!(lca.lca(n(0), n(11)), None);
        assert_eq!(lca.lca(n(0), n(12)), None);
        assert_eq!(lca.lca(n(0), n(14)), None);
        assert_eq!(lca.lca(n(11), n(0)), None);
        assert_eq!(lca.lca(n(12), n(0)), None);
        assert_eq!(lca.lca(n(14), n(0)), None);
        assert_eq!(lca.lca(n(14), n(14)), Some(n(14)));
        assert_eq!(lca.lca(n(14), n(15)), None);

        // Common ancestor is the root.
        assert_eq!(lca.lca(n(1), n(2)), Some(n(0)));
        assert_eq!(lca.lca(n(7), n(8)), Some(n(0)));
        assert_eq!(lca.lca(n(7), n(10)), Some(n(0)));
        assert_eq!(lca.lca(n(10), n(7)), Some(n(0)));
        assert_eq!(lca.lca(n(5), n(9)), Some(n(0)));
        assert_eq!(lca.lca(n(9), n(5)), Some(n(0)));
        assert_eq!(lca.lca(n(6), n(9)), Some(n(0)));
        assert_eq!(lca.lca(n(9), n(6)), Some(n(0)));

        // One argument is a proper, non-root ancestor of the other.
        assert_eq!(lca.lca(n(2), n(10)), Some(n(2)));
        assert_eq!(lca.lca(n(10), n(2)), Some(n(2)));
        assert_eq!(lca.lca(n(9), n(12)), None);

        // Common ancestor is an inner node that neither argument equals, so
        // the answer can only come from the climbing phase of the search.
        assert_eq!(lca.lca(n(6), n(7)), Some(n(1)));
        assert_eq!(lca.lca(n(7), n(6)), Some(n(1)));
        assert_eq!(lca.lca(n(4), n(7)), Some(n(1)));
        assert_eq!(lca.lca(n(7), n(4)), Some(n(1)));
        assert_eq!(lca.lca(n(9), n(10)), Some(n(8)));
        assert_eq!(lca.lca(n(10), n(9)), Some(n(8)));
        assert_eq!(lca.lca(n(12), n(13)), Some(n(11)));
        assert_eq!(lca.lca(n(13), n(12)), Some(n(11)));

        // Direct ancestry checks: inclusive, directional, and tree-local.
        assert!(lca.is_ancestor(n(3), n(3)));
        assert!(lca.is_ancestor(n(1), n(6)));
        assert!(!lca.is_ancestor(n(6), n(1)));
        assert!(!lca.is_ancestor(n(3), n(7)));
        assert!(!lca.is_ancestor(n(0), n(11)));
        assert!(!lca.is_ancestor(n(14), n(15)));
    }
}