use std::marker::PhantomData;
use crate::{Vertex, Graph};
use crate::algorithms::{VisitorDFSAction, VisitorDFS, dfs_with_visitor};
use crate::error::{GraphError, ErrorKind};
/// Subtree sizes produced by [`find_subtrees_size`], stored by vertex id.
///
/// The lifetime `'a` ties the result to the graph it was computed from (the graph is
/// borrowed as `&'a Graph<T>` by `find_subtrees_size`); no graph data is stored here.
#[derive(Debug)]
pub struct SubTreesSize<'a, T> {
    /// `values[id]` is the subtree size of the vertex with that id, or `None` if the
    /// traversal never entered that vertex.
    values: Vec<Option<usize>>,
    /// Carries the `'a` / `T` parameters without holding an actual reference.
    phantom: PhantomData<&'a T>,
}
impl<'a, T> SubTreesSize<'a, T> where T: Copy {
    /// Returns the number of vertices in the DFS subtree rooted at `target`, counting
    /// `target` itself.
    ///
    /// Returns `None` if the traversal never entered `target`, or if `target`'s id lies
    /// outside the range of vertices this result covers (e.g. a vertex taken from a
    /// different, larger graph). Such a lookup yields `None` rather than panicking.
    pub fn get_subtree_size(&self, target: &Vertex<T>) -> Option<usize> {
        // `get` turns an out-of-range id into `None`; `flatten` merges it with the
        // "never visited" `None` already stored in the slot.
        self.values.get(target.id()).copied().flatten()
    }
}
/// DFS visitor that accumulates per-vertex subtree sizes and records whether a cycle
/// was encountered during the traversal.
struct CustomVisitor {
    // Indexed by vertex id. Starts as `None`; set to `Some(1)` when the vertex is entered,
    // then grows as finished children add their sizes into it.
    values: Vec<Option<usize>>,
    // Set to `true` the first time a grey (already-entered) vertex other than the
    // grandparent is reached; the traversal is then asked to stop.
    cycle: bool
}
impl<T> VisitorDFS<T> for CustomVisitor {
    /// Seeds every entered vertex with a subtree size of 1 (the vertex itself).
    fn entry_to_vertex_event(&mut self, vertex: &Vertex<T>) -> Result<VisitorDFSAction, GraphError> {
        self.values[vertex.id] = Some(1);
        Ok(VisitorDFSAction::Nothing)
    }

    /// Adds the size of the finished child `vertex` to the running total of `parent`.
    fn exit_from_white_vertex_event(&mut self, vertex: &Vertex<T>, parent: &Vertex<T>, _grand_parent: Option<&Vertex<T>>) -> Result<VisitorDFSAction, GraphError> {
        // Assumes the DFS fired `entry_to_vertex_event` for both vertices before this
        // event; a `None` here would be a traversal bug, not a user error.
        let parent_size = self.values[parent.id]
            .expect("parent vertex was not entered before its child exited");
        let child_size = self.values[vertex.id]
            .expect("child vertex exited without having been entered");
        self.values[parent.id] = Some(parent_size + child_size);
        Ok(VisitorDFSAction::Nothing)
    }

    /// Flags a cycle when the DFS reaches an already-entered (grey) vertex.
    ///
    /// Reaching `grand_parent` again is not a cycle: that is the edge leading straight
    /// back to the DFS parent of `parent` (presumably an undirected edge stored in both
    /// directions; confirm against `Graph::add_edge`). When there is no grandparent,
    /// `parent` is the DFS root, so the grey vertex cannot be a tree parent and is
    /// counted as a cycle. Previously this case panicked on `unwrap()`.
    fn entry_to_grey_vertex_event(&mut self, vertex: &Vertex<T>, _parent: &Vertex<T>, grand_parent: Option<&Vertex<T>>) -> Result<VisitorDFSAction, GraphError> {
        let leads_back_to_tree_parent = grand_parent.map_or(false, |gp| gp.id == vertex.id);
        if !self.cycle && !leads_back_to_tree_parent {
            self.cycle = true;
            return Ok(VisitorDFSAction::Break);
        }
        Ok(VisitorDFSAction::Nothing)
    }
}
/// Runs a DFS over `graph` starting at `from` and returns, for each vertex the DFS
/// enters, the size of its DFS subtree (the vertex itself included).
///
/// Vertices the DFS never enters keep no size (`get_subtree_size` yields `None` for them).
///
/// # Errors
///
/// * Any error returned by `dfs_with_visitor` is propagated unchanged.
/// * `GraphError::Regular(ErrorKind::TreeContainsCycle)` if the traversal reaches an
///   already-entered vertex that is not the immediate tree parent, i.e. the graph
///   reachable from `from` is not a tree.
pub fn find_subtrees_size<'a, T>(graph: &'a Graph<T>, from: &'a Vertex<T>) -> Result<SubTreesSize<'a, T>, GraphError> where T: Default + Copy {
    let vertex_count = graph.size();
    let mut collector = CustomVisitor {
        values: vec![None; vertex_count],
        cycle: false,
    };

    dfs_with_visitor(graph, from, &mut collector)?;

    let CustomVisitor { values, cycle } = collector;
    if cycle {
        Err(GraphError::Regular(ErrorKind::TreeContainsCycle))
    } else {
        Ok(SubTreesSize { values, phantom: PhantomData })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_subtrees_size() {
        // Tree rooted at 1:
        //   1 -> 4 -> {11, 12 -> 3}
        //   1 -> 2 -> {5 -> {9, 10}, 6 -> 7 -> 8}
        let tree_edges = [
            (1, 4), (1, 2),
            (4, 11), (4, 12), (12, 3),
            (2, 5), (2, 6),
            (5, 9), (5, 10),
            (6, 7), (7, 8),
        ];

        let mut graph = Graph::new(12);
        for &(src, dst) in tree_edges.iter() {
            graph.add_edge(src, dst, 0).unwrap();
        }

        let root = graph.get_vertex(1).unwrap();
        let sizes = find_subtrees_size(&graph, root).unwrap();

        let expected = [(1, 12), (2, 7), (4, 4)];
        for &(vertex_number, size) in expected.iter() {
            let vertex = graph.get_vertex(vertex_number).unwrap();
            assert_eq!(sizes.get_subtree_size(vertex), Some(size));
        }
    }
}