use std::collections::HashMap;
use std::hash::Hash;
use petgraph::algo::toposort;
use petgraph::Direction;
use petgraph::dot::{Config, Dot};
use petgraph::graph::NodeIndex;
use petgraph::stable_graph::StableDiGraph;
use petgraph::visit::{Control, depth_first_search, DfsEvent, NodeRef, Visitable, VisitMap};
use crate::dep::state::State;
use crate::task::TaskId;
/// Directed graph of task states, addressed through caller-supplied keys.
///
/// Each key of type `K` is associated with one node of the graph; the node's
/// weight is the current [`State`] of the corresponding task.
pub(crate) struct Tree<K> {
    /// Graph whose node weights are task states; edges carry no payload.
    inner: StableDiGraph<State, (), TaskId>,
    /// Lookup table from a caller key to the id of its node in `inner`.
    to_id: HashMap<K, TaskId>,
}
impl<K> Tree<K>
where
    K: Clone + Eq + Hash,
{
    /// Creates a tree with no nodes, no edges and no key mappings.
    pub(crate) fn new() -> Self {
        Self {
            inner: Default::default(),
            to_id: HashMap::new(),
        }
    }

    /// Returns the id of the node registered for `k`.
    ///
    /// If `k` is not known yet, a new node in [`State::Waiting`] is added to
    /// the graph and remembered for `k`.
    pub(crate) fn find_or_add_node(&mut self, k: K) -> TaskId {
        // Borrow the graph separately so the closure does not need all of `self`.
        let graph = &mut self.inner;
        *self
            .to_id
            .entry(k)
            .or_insert_with(|| graph.add_node(State::Waiting).index())
    }

    /// Ensures there is an edge from the node of `k0` to the node of `k1`.
    ///
    /// Both nodes are created on demand. An existing edge is reused rather
    /// than duplicated. Note that the returned value is the index of the
    /// *edge*, not of a node, even though it is typed as [`TaskId`].
    pub(crate) fn find_or_add_edge(&mut self, k0: K, k1: K) -> TaskId {
        let source = NodeIndex::from(self.find_or_add_node(k0));
        let target = NodeIndex::from(self.find_or_add_node(k1));
        let edge = match self.inner.find_edge(source, target) {
            Some(existing) => existing,
            None => self.inner.add_edge(source, target, ()),
        };
        edge.index()
    }

    /// Drops every node that is not reachable from one of the keys in `ks`.
    ///
    /// Reachability follows outgoing edges, starting at (and including) the
    /// nodes of `ks`. Keys in `ks` that are not known yet get a fresh node,
    /// which is then retained. Keys whose node is removed are forgotten as
    /// well, so a later lookup creates a new node instead of handing out a
    /// stale id.
    pub(crate) fn retain_dependencies(&mut self, ks: Vec<K>) {
        // Resolve all roots first: looking up a key may add a node, and the
        // visit map below must be sized for the graph *after* those additions.
        let roots: Vec<NodeIndex<TaskId>> = ks
            .into_iter()
            .map(|k| NodeIndex::from(self.find_or_add_node(k)))
            .collect();

        let mut keep = self.inner.visit_map();
        // A single traversal over all roots; the search tracks discovered
        // nodes itself, so shared sub-graphs are only walked once.
        depth_first_search(&self.inner, roots, |event| {
            if let DfsEvent::Discover(n, _time) = event {
                keep.visit(n);
            }
            Control::<NodeIndex>::Continue
        });

        self.inner.retain_nodes(|_frozen, n| keep.is_visited(&n));
        // Keep the key table in sync with the graph; otherwise removed keys
        // would keep pointing at ids that no longer exist (or that get reused).
        self.to_id
            .retain(|_, id| keep.is_visited(&NodeIndex::<TaskId>::from(*id)));
    }

    /// Picks a node that can be worked on next.
    ///
    /// A node is *ready* when it is [`State::Waiting`] and every outgoing
    /// neighbour is [`State::Success`].
    ///
    /// # Returns
    ///
    /// * `Ok(Some(id))` - a ready node.
    /// * `Ok(None)` - there are waiting nodes, but none of them is ready.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when the graph is empty, when no node is waiting,
    /// or when the topological sort of the ready nodes fails.
    pub(crate) fn find_next_node(&self) -> Result<Option<TaskId>, ()> {
        if self.inner.node_count() == 0 {
            return Err(());
        }

        // Sub-graph containing only the nodes that are still waiting.
        let waiting_map = |_, s: &State| (*s == State::Waiting).then_some(*s);
        let waiting_nodes = self.inner.filter_map(waiting_map, |_, e| Some(e));
        if waiting_nodes.node_count() == 0 {
            return Err(());
        }

        // Narrow further to nodes whose outgoing neighbours all succeeded.
        // The index `i` comes from the filtered graph but is looked up in
        // `self.inner`; this relies on `filter_map` preserving node indices.
        let ready_map = |i: NodeIndex<TaskId>, s: &State| {
            let mut neighbors = self.inner.neighbors_directed(i, Direction::Outgoing);
            neighbors.all(|i| self.inner[i] == State::Success).then_some(*s)
        };
        let ready_nodes = waiting_nodes.filter_map(ready_map, |_, e| Some(e));
        if ready_nodes.node_count() == 0 {
            return Ok(None);
        }

        // Take the last node in topological order of the ready sub-graph.
        Ok(toposort(&ready_nodes, None)
            .map_err(|_| ())?
            .last()
            .map(|ni| TaskId::from(ni.index())))
    }

    /// Returns the ids of all nodes that `id` has an outgoing edge to.
    pub(crate) fn children(&self, id: TaskId) -> Vec<TaskId> {
        self.inner
            .neighbors_directed(NodeIndex::from(id), Direction::Outgoing)
            .map(|n| n.index())
            .collect()
    }

    /// Returns the ids of all nodes that have an edge pointing at `id`.
    pub(crate) fn parents(&self, id: TaskId) -> Vec<TaskId> {
        self.inner
            .neighbors_directed(NodeIndex::from(id), Direction::Incoming)
            .map(|n| n.index())
            .collect()
    }

    /// Sets the state of node `id`.
    ///
    /// When `state` is [`State::Failure`], the failure is also written to
    /// every node from which `id` can be reached (its parents, their parents,
    /// and so on). Other states only affect `id` itself.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not refer to a node of the graph.
    pub(crate) fn update_node(&mut self, id: TaskId, state: State) {
        let start = NodeIndex::from(id);
        self.inner[start] = state;
        if state != State::Failure {
            return;
        }

        // Iterative walk with a visited set: each ancestor is updated exactly
        // once, and the walk terminates even if the graph contains a cycle.
        let mut seen = self.inner.visit_map();
        seen.visit(start);
        let mut pending = vec![start];
        while let Some(node) = pending.pop() {
            let first_seen: Vec<NodeIndex<TaskId>> = self
                .inner
                .neighbors_directed(node, Direction::Incoming)
                .filter(|parent| seen.visit(*parent))
                .collect();
            for parent in first_seen {
                self.inner[parent] = state;
                pending.push(parent);
            }
        }
    }
}
/// Renders the graph through petgraph's [`Dot`] writer.
///
/// Nodes are labelled with their index, edges carry no label, and every node
/// gets a `color=` attribute taken from its state's `color()`.
///
/// `Display` is implemented instead of `ToString` so that `to_string()` keeps
/// working through the standard blanket implementation while the type can
/// also be used directly in formatting macros.
impl<K> std::fmt::Display for Tree<K> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dot = Dot::with_attr_getters(
            &self.inner,
            &[Config::NodeIndexLabel, Config::EdgeNoLabel],
            &|_g, _e| String::new(),
            &|_g, n| format!("color={}", n.weight().color()),
        );
        // Format with a fresh `{:?}` so caller-supplied formatter flags do not
        // alter the rendered output.
        write!(f, "{:?}", dot)
    }
}