use crate::{Graph, Vertex, algorithms};
use crate::algorithms::{VisitorDFS, VisitorDFSAction, dfs_with_visitor};
use crate::error::{GraphError, ErrorKind};
/// DFS visitor that accumulates the Euler tour consumed by [`LCA::build`].
struct CustomVisitor {
    /// Vertex ids in tour order: a vertex is appended when the DFS enters it,
    /// and a parent is appended again each time the exit-from-white-vertex
    /// event fires for one of its children.
    orders: Vec<usize>,
}

impl<T> VisitorDFS<T> for CustomVisitor {
    /// Appends the id of the vertex being entered; never alters the traversal.
    fn entry_to_vertex_event(&mut self, vertex: &Vertex<T>) -> Result<VisitorDFSAction, GraphError> {
        let entered = vertex.id;
        self.orders.push(entered);
        Ok(VisitorDFSAction::Nothing)
    }

    /// Appends the id of `parent` (presumably the DFS is returning to it from
    /// the finished child `_vertex`); never alters the traversal.
    fn exit_from_white_vertex_event(
        &mut self,
        _vertex: &Vertex<T>,
        parent: &Vertex<T>,
        _grand_parent: Option<&Vertex<T>>,
    ) -> Result<VisitorDFSAction, GraphError> {
        let returned_to = parent.id;
        self.orders.push(returned_to);
        Ok(VisitorDFSAction::Nothing)
    }
}
/// Lowest-common-ancestor index over a rooted tree.
///
/// Built from an Euler tour of the tree plus a segment-tree range-minimum
/// query (by vertex depth) over that tour.
pub struct LCA<'a, T> {
    /// `borders[v]` is the index of the first occurrence of vertex `v` in the Euler tour.
    borders: Vec<usize>,
    /// Range-minimum structure over the Euler tour, keyed by vertex depth.
    rmq: RMQLCA,
    /// Tree whose vertices are returned by [`LCA::query`].
    tree: &'a Graph<T>,
}

impl<'a, T> LCA<'a, T> {
    /// Builds the LCA index for `tree` rooted at `root`.
    ///
    /// # Errors
    ///
    /// Returns `GraphError::Regular(ErrorKind::GraphNotConnected)` if any vertex
    /// id in `1..tree.size()` has no depth relative to `root`, and propagates any
    /// error from `find_vertices_depths` or `dfs_with_visitor`.
    pub fn build(tree: &'a Graph<T>, root: &'a Vertex<T>) -> Result<LCA<'a, T>, GraphError>
    where
        T: std::cmp::PartialOrd + Copy + Default,
    {
        let size = tree.size();
        let vertices_depths = algorithms::find_vertices_depths(tree, root)?;

        // Slot 0 is not checked and keeps depth 0; vertex ids appear to be
        // 1-based in this graph (TODO confirm).
        let mut depths = vec![0; size];
        for (idx, depth) in depths.iter_mut().enumerate().skip(1) {
            let vertex = tree
                .get_vertex(idx)
                .expect("every id in 1..tree.size() names a vertex");
            *depth = vertices_depths
                .get_vertex_depth(vertex)
                .ok_or(GraphError::Regular(ErrorKind::GraphNotConnected))?;
        }

        // The tour is seeded with the placeholder id 0 at index 0; the DFS
        // appends real vertices after it.
        let mut visitor = CustomVisitor { orders: vec![0] };
        dfs_with_visitor(tree, root, &mut visitor)?;
        let tour = visitor.orders;

        // Walking the tour backwards leaves the earliest index for each vertex.
        let mut borders = vec![0; size];
        for (idx, &vertex) in tour.iter().enumerate().rev() {
            borders[vertex] = idx;
        }

        let rmq = RMQLCA::build(&tour, depths);
        Ok(LCA { borders, rmq, tree })
    }

    /// Returns the lowest common ancestor of `first` and `second`.
    ///
    /// # Panics
    ///
    /// Panics if either vertex id is out of range for the tree this index was
    /// built from.
    pub fn query(&self, first: &Vertex<T>, second: &Vertex<T>) -> &Vertex<T> {
        let ancestor_id = self
            .rmq
            .query(self.borders[first.id], self.borders[second.id]);
        self.tree
            .get_vertex(ancestor_id)
            .expect("the RMQ only returns vertex ids taken from the Euler tour")
    }
}
/// Segment tree answering "which vertex has the smallest depth between two
/// Euler-tour positions".
// Private type, so clippy's `upper_case_acronyms` lint applies to it; the
// acronym-style name is kept to mirror the public `LCA` type.
#[allow(clippy::upper_case_acronyms)]
struct RMQLCA {
    /// Implicit binary tree: node `i` has children `2 * i` and `2 * i + 1`, and
    /// leaves start at `data.len() / 2`. Each node holds the vertex id of minimum
    /// depth within its range, or `RMQLCA::EMPTY` when the range is only padding.
    data: Vec<usize>,
    /// `depths[v]` is the depth of vertex id `v`.
    depths: Vec<usize>,
}

impl RMQLCA {
    /// Sentinel stored in slots whose range contains only padding leaves.
    const EMPTY: usize = usize::MAX;

    /// Builds the tree over the tour `src`.
    ///
    /// `depths[v]` must exist for every vertex id `v` appearing in `src`.
    /// Tours of any length are supported; the leaf level is padded with `EMPTY`.
    fn build(src: &[usize], depths: Vec<usize>) -> Self {
        if src.is_empty() {
            return RMQLCA { data: vec![], depths };
        }
        let n = calculate_size_array(src.len());
        let leaves = n / 2;
        let mut data = vec![Self::EMPTY; n];
        data[leaves..leaves + src.len()].copy_from_slice(src);
        for i in (1..leaves).rev() {
            let (left, right) = (data[2 * i], data[2 * i + 1]);
            // Padding never wins (it compares as infinitely deep); ties go to
            // the right child.
            data[i] = if Self::depth_of(&depths, left) < Self::depth_of(&depths, right) {
                left
            } else {
                right
            };
        }
        RMQLCA { data, depths }
    }

    /// Depth used for comparisons; `EMPTY` slots compare as `usize::MAX`
    /// instead of being used as an index (which would panic).
    fn depth_of(depths: &[usize], vertex: usize) -> usize {
        if vertex == Self::EMPTY {
            usize::MAX
        } else {
            depths[vertex]
        }
    }

    /// Returns the minimum-depth vertex among tour positions `l..=r`.
    ///
    /// `l` and `r` may be given in either order.
    ///
    /// # Panics
    ///
    /// Panics if a position is outside the tree (including on an empty tree).
    fn query(&self, l: usize, r: usize) -> usize {
        let leaves = self.data.len() / 2;
        let (mut l, mut r) = if l <= r {
            (l + leaves, r + leaves)
        } else {
            (r + leaves, l + leaves)
        };
        let mut best_depth = usize::MAX;
        let mut best = 0;
        while l <= r {
            // A left edge that is a right child is fully inside the range.
            if l % 2 != 0 {
                let depth = Self::depth_of(&self.depths, self.data[l]);
                if depth < best_depth {
                    best_depth = depth;
                    best = self.data[l];
                }
            }
            l = (l + 1) >> 1;
            // A right edge that is a left child is fully inside the range.
            if r % 2 == 0 {
                let depth = Self::depth_of(&self.depths, self.data[r]);
                if depth < best_depth {
                    best_depth = depth;
                    best = self.data[r];
                }
            }
            // r >= l >= 1 on entry, so this cannot underflow.
            r = (r - 1) >> 1;
        }
        best
    }
}

/// Returns the segment-tree array length for `n` leaves: twice the smallest
/// power of two that is at least `n` (so 2 when `n <= 1`).
fn calculate_size_array(n: usize) -> usize {
    n.next_power_of_two() << 1
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a graph of `$size` vertices holding the sample tree:
    //
    //         1
    //       /   \
    //      2     3
    //     / \   / \
    //    4   5 6   7
    //               \
    //                8
    //
    // A macro (rather than a fn) keeps the literal types inferred exactly as
    // in an inline `add_edge` call.
    macro_rules! sample_tree {
        ($size:expr) => {{
            let mut graph = Graph::new($size);
            for &(from, to) in &[(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7), (7, 8)] {
                graph.add_edge(from, to, 0).unwrap();
            }
            graph
        }};
    }

    #[test]
    fn test_lca() {
        let graph = sample_tree!(8);
        let vertex = |id| graph.get_vertex(id).unwrap();
        let lca = LCA::build(&graph, vertex(1)).unwrap();

        assert_eq!(2, lca.query(vertex(4), vertex(5)).id);
        assert_eq!(1, lca.query(vertex(4), vertex(8)).id);
        assert_eq!(2, lca.query(vertex(4), vertex(2)).id);
        assert_eq!(1, lca.query(vertex(2), vertex(8)).id);
        assert_eq!(3, lca.query(vertex(6), vertex(8)).id);
    }

    // Vertex 9 is isolated, so it is unreachable from the root and `build`
    // is expected to fail; `unwrap` turns that into the panic.
    #[test]
    #[should_panic]
    fn test_lca_not_connected_graph() {
        let graph = sample_tree!(9);
        LCA::build(&graph, graph.get_vertex(1).unwrap()).unwrap();
    }
}