#![forbid(unsafe_code)]
#![doc(html_root_url = "https://docs.rs/dag_compute/0.1.0")]
use slotmap::{SlotMap, SecondaryMap, new_key_type};
use slotmap::Key as KeyTrait;
use std::collections::{HashSet, HashMap, VecDeque};
use std::sync::Arc;
use std::ops::Deref;
use std::marker::PhantomData;
use std::fmt;
use log::{info, debug, trace};
// Key type used to address nodes in the graph's `SlotMap` / `SecondaryMap` storage.
new_key_type!{struct ComputeGraphKey;}
/// Boxed evaluation function of a node: receives the borrowed values of the
/// node's inputs as a slice and returns the node's own value.
type BoxedEvalFn<T> = Box<dyn Fn(&[&T]) -> T + Send + Sync>;
/// A single vertex of the computation graph together with its cached result.
pub(crate) struct Node<T> {
/// Human-readable name; used in log messages and as the label in DOT output.
name: String,
/// Function that produces this node's value from the values of its inputs.
func: BoxedEvalFn<T>,
/// Keys of the nodes whose values are passed to `func`, in argument order.
input_nodes: Vec<ComputeGraphKey>,
/// Result of running `func`; `None` until `eval` has been called.
output_cache: Option<Arc<T>>
}
impl<T> Node<T> {
fn new(name: String, func: BoxedEvalFn<T>) -> Node<T> {
Node {
name,
func,
input_nodes: Vec::default(),
output_cache: None
}
}
pub fn eval(&mut self, args: &[&T]) {
if self.output_cache.is_none() {
self.output_cache = Some(Arc::new((self.func)(args)));
} else {
panic!("Node is already evaluated");
}
}
pub fn computed_val(&self) -> Arc<T> {
if let Some(ref val) = self.output_cache {
val.clone()
} else {
panic!("Node has not yet been evaluated");
}
}
}
/// Debug formatting for [`Node`].
///
/// The evaluation closure cannot be printed, so it is rendered as `...`.
/// The struct is labelled `Node` (it was previously mislabelled as
/// `NodeHandle`), and using `debug_struct` means `{:#?}` pretty-printing is
/// honoured.
impl<T: fmt::Debug> fmt::Debug for Node<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Node")
            .field("name", &self.name)
            // `format_args!` implements `Debug` by printing its text verbatim,
            // which yields the unquoted placeholder `...`.
            .field("func", &format_args!("..."))
            .field("input_nodes", &self.input_nodes)
            .field("output_cache", &self.output_cache)
            .finish()
    }
}
/// Handle to a node of a [`ComputationGraph`], returned when the node is inserted.
///
/// Graph methods that accept a handle assert that its `graph_id` matches the
/// graph's own id before using the key.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NodeHandle {
/// Key of the node inside the owning graph's storage.
node_key: ComputeGraphKey,
/// Id of the graph that issued this handle.
graph_id: usize
}
/// A directed graph of nodes whose values are computed from the values of their inputs.
#[derive(Debug)]
pub struct ComputationGraph<T> {
/// All nodes currently held by the graph.
node_storage: SlotMap<ComputeGraphKey, Node<T>>,
/// Per-node use count: incremented for each time the node is listed as an
/// input, and once more when the node is designated as the output.
node_refcount: SecondaryMap<ComputeGraphKey, u32>,
/// Key of the designated output node; `None` until one is designated.
output_node: Option<ComputeGraphKey>,
/// Identifier copied into every `NodeHandle` this graph issues.
graph_id: usize
}
impl<T> Default for ComputationGraph<T> {
    /// Creates an empty graph with no output designated.
    ///
    /// Each graph receives its id from a process-wide counter. The address of
    /// a local value is not a usable identity: the value is moved out of this
    /// function, and consecutive constructions can reuse the same address,
    /// which would let handles from one graph pass the id check of another.
    fn default() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};

        // A `static` inside a generic function is shared by all
        // instantiations, so ids are unique across every `T`.
        static NEXT_GRAPH_ID: AtomicUsize = AtomicUsize::new(0);

        ComputationGraph {
            node_storage: SlotMap::default(),
            node_refcount: SecondaryMap::default(),
            output_node: None,
            // Only uniqueness matters; no ordering with other memory is needed.
            graph_id: NEXT_GRAPH_ID.fetch_add(1, Ordering::Relaxed),
        }
    }
}
impl<T> ComputationGraph<T> {
pub fn new() -> ComputationGraph<T>{
ComputationGraph::default()
}
/// Adds a node with the given `name` and evaluation function and returns a
/// handle to it.
///
/// The new node has no inputs and a use count of zero.
pub fn insert_node(&mut self, name: String, func: BoxedEvalFn<T>) -> NodeHandle {
    let graph_id = self.graph_id;
    let node_key = self.node_storage.insert(Node::new(name, func));
    self.node_refcount.insert(node_key, 0);
    NodeHandle { node_key, graph_id }
}
/// Returns the name the node was inserted with.
///
/// # Panics
/// Panics if `node` was issued by a different graph, or if its key is not
/// present in this graph's storage.
pub fn node_name(&self, node: &NodeHandle) -> &str {
    assert_eq!(node.graph_id, self.graph_id,
               "Received NodeHandle for different graph");
    let stored = self.node_storage.get(node.node_key).unwrap();
    stored.name.as_str()
}
/// Marks `node` as the graph's output, i.e. the node whose value `compute`
/// returns.
///
/// Designating the output adds one to the node's use count.
///
/// # Panics
/// Panics if an output was already designated, if `node` was issued by a
/// different graph, or if `node` is not present in this graph.
pub fn designate_output(&mut self, node: &NodeHandle) {
    assert!(self.output_node.is_none(), "Output was already designated");
    assert_eq!(node.graph_id, self.graph_id,
               "Received NodeHandle for different graph");
    let node_key = node.node_key;
    assert!(self.node_storage.contains_key(node_key));
    self.output_node = Some(node_key);
    *self.node_refcount.get_mut(node_key).unwrap() += 1;
}
/// Sets the inputs of `node`, replacing any inputs set previously.
///
/// The values of `inputs` are passed to the node's function in the given
/// order. Each listed input has its use count incremented; the use counts
/// held on behalf of the node's previous inputs are released first, so
/// calling this method more than once on the same node does not leak counts.
///
/// # Panics
/// Panics if `node` or any of `inputs` was issued by a different graph, if
/// `node` appears among its own inputs, or if a handle's key is not present
/// in this graph.
pub fn set_inputs(&mut self, node: &mut NodeHandle, inputs: &[&NodeHandle]) {
    assert_eq!(node.graph_id, self.graph_id,
               "Received NodeHandle for different graph");
    for input in inputs {
        assert_eq!(input.graph_id, self.graph_id,
                   "Received NodeHandle for different graph");
    }
    let input_keys: Vec<_> = inputs.iter().map(|handle| handle.node_key).collect();
    assert!(!input_keys.contains(&node.node_key), "Inputs would create self-loop");

    let target = self.node_storage.get_mut(node.node_key).unwrap();
    // Release the counts taken by an earlier `set_inputs` call on this node.
    let previous_keys = std::mem::take(&mut target.input_nodes);
    for key in &previous_keys {
        *self.node_refcount.get_mut(*key).unwrap() -= 1;
    }
    for key in &input_keys {
        *self.node_refcount.get_mut(*key).unwrap() += 1;
    }
    target.input_nodes = input_keys;
}
pub fn dot_graph(&self) -> impl fmt::Display + '_ {
DAGComputeDisplay::new(self)
}
/// Determines the evaluation order and drops nodes the output does not depend on.
///
/// Returns the keys of all nodes reachable from the output node, ordered so
/// that every node comes after all of its inputs. Every other node is removed
/// from the graph, and the use counts it held on its inputs are released.
///
/// # Panics
/// Panics if no output has been designated or if the graph contains a cycle.
fn computation_order(&mut self) -> impl IntoIterator<Item = ComputeGraphKey> {
    debug!("Computing node evaluation order");
    let out_node = self.output_node.expect("Output not yet designated");
    let mut sort_list = VecDeque::new();
    let mut temporary_set = HashSet::new();
    self.toposort_helper(out_node, &mut sort_list, &mut temporary_set);
    debug_assert!(temporary_set.is_empty());

    // Hash-set membership keeps the sweep linear in the number of nodes.
    let reachable: HashSet<ComputeGraphKey> = sort_list.iter().copied().collect();
    let refcounts = &mut self.node_refcount;
    self.node_storage.retain(|key, candidate| {
        let keep = reachable.contains(&key);
        if keep {
            trace!("Keeping node {}", candidate.name);
        } else {
            trace!("Sweeping node {}", candidate.name);
            for input_key in &candidate.input_nodes {
                // An unreachable input may already have been swept earlier in
                // this pass; its count entry is gone and needs no update.
                if let Some(count) = refcounts.get_mut(*input_key) {
                    *count -= 1;
                }
            }
            refcounts.remove(key);
        }
        keep
    });

    // The helper lists each node before its inputs; evaluation needs the reverse.
    sort_list.make_contiguous().reverse();
    sort_list
}
/// Depth-first step of the topological sort.
///
/// Visits `node` and, recursively, its inputs, then places `node` at the
/// front of `final_list`, so each node ends up ahead of its inputs.
/// `temporary_set` holds the nodes on the current recursion path and is used
/// to detect cycles. Nodes already in `final_list` are skipped.
///
/// # Panics
/// Panics if `node` is already on the current recursion path (a cycle), or if
/// a key is missing from the graph's storage.
fn toposort_helper(&self, node: ComputeGraphKey,
                   final_list: &mut VecDeque<ComputeGraphKey>,
                   temporary_set: &mut HashSet<ComputeGraphKey>) {
    if !final_list.contains(&node) {
        let first_visit_on_path = !temporary_set.contains(&node);
        assert!(first_visit_on_path, "Computation graph contains cycle");
        temporary_set.insert(node);
        let inputs = &self.node_storage.get(node).unwrap().input_nodes;
        for &input in inputs {
            self.toposort_helper(input, final_list, temporary_set);
        }
        temporary_set.remove(&node);
        final_list.push_front(node);
    }
}
/// Consumes the graph, evaluates every node the output depends on, and
/// returns the output node's value.
///
/// Nodes are evaluated in dependency order. A node is removed from the graph
/// as soon as its use count reaches zero, i.e. once its last consumer has
/// fetched its value.
///
/// # Panics
/// Panics if no output has been designated, if the graph contains a cycle, or
/// if any node other than the output remains once evaluation has finished.
pub fn compute(mut self) -> T {
    self.output_node.expect("Output not yet designated");
    info!("Evaluating DAG");
    let compute_order = self.computation_order();
    debug!("Computing node values");
    for node_key in compute_order {
        let input_keys = {
            let node = self.node_storage.get(node_key).unwrap();
            trace!("Evaluating node {}", node.name);
            node.input_nodes.clone()
        };

        // Fetch each input's value and note which inputs this was the last use of.
        let mut exhausted = Vec::with_capacity(input_keys.len());
        let mut input_values = Vec::with_capacity(input_keys.len());
        for key in input_keys {
            let remaining = self.node_refcount.get_mut(key).unwrap();
            assert!(*remaining > 0);
            *remaining -= 1;
            if *remaining == 0 {
                exhausted.push(key);
            }
            input_values.push(self.node_storage.get(key).unwrap().computed_val());
        }
        let arg_refs: Vec<&T> = input_values.iter().map(|value| value.deref()).collect();

        // The values stay alive through `input_values`, so the exhausted
        // nodes can leave the graph before the evaluation runs.
        for key in exhausted {
            self.node_storage.remove(key);
            self.node_refcount.remove(key);
        }

        self.node_storage.get_mut(node_key).unwrap().eval(arg_refs.as_slice());
    }

    assert_eq!(self.node_storage.len(), 1);
    let output_key = self.output_node.take().unwrap();
    let output_node = self.node_storage.remove(output_key).unwrap();
    let result = output_node.computed_val();
    // Dropping the node releases its cached handle, leaving `result` as the sole owner.
    drop(output_node);
    Arc::try_unwrap(result).ok().unwrap()
}
}
/// Snapshot of a graph's node names and edges that is written out as a DOT
/// graph by its `Display` implementation.
struct DAGComputeDisplay<'a, T> {
/// Marker for the otherwise unused type parameter `T` and for the borrow of
/// the graph's node storage; holds no data.
slotmap_ref: PhantomData<&'a SlotMap<ComputeGraphKey, Node<T>>>,
/// Name of every node, keyed by node key.
names: HashMap<ComputeGraphKey, &'a str>,
/// Key of the designated output node, if any.
output_node: Option<ComputeGraphKey>,
/// Directed edges as `(input, consumer)` pairs.
edge_list: Vec<(ComputeGraphKey, ComputeGraphKey)>
}
impl<'a, T> DAGComputeDisplay<'a, T> {
    /// Captures the names and edges of `map` for later rendering.
    ///
    /// Each node contributes its name and one `(input, node)` edge per entry
    /// in its input list. A single pass over the node storage visits every
    /// node exactly once, which yields the complete edge set without a graph
    /// traversal.
    fn new(map: &'a ComputationGraph<T>) -> DAGComputeDisplay<'a, T> {
        let mut names = HashMap::with_capacity(map.node_storage.len());
        let mut edge_list = Vec::new();
        for (key, node) in map.node_storage.iter() {
            names.insert(key, node.name.as_str());
            edge_list.extend(node.input_nodes.iter().map(|&input| (input, key)));
        }
        DAGComputeDisplay {
            slotmap_ref: PhantomData,
            names,
            output_node: map.output_node,
            edge_list,
        }
    }
}
/// Renders the captured graph as a DOT `strict digraph`.
///
/// Every node is written as `<id> [label="<name>"];`, with `shape=box` added
/// for the output node, followed by one `<from>-><to>;` line per edge. Node
/// ids are the numeric form of the node keys.
impl<'a, T> fmt::Display for DAGComputeDisplay<'a, T> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(fmt, "strict digraph {{")?;
        for (node, name) in &self.names {
            let node_id = node.data().as_ffi();
            write!(fmt, "{} [label=\"{}\"", node_id, escape_dot_label(name))?;
            if self.output_node == Some(*node) {
                write!(fmt, ", shape=box")?;
            }
            writeln!(fmt, "];")?;
        }
        for (from, to) in &self.edge_list {
            writeln!(fmt, "{}->{};", from.data().as_ffi(), to.data().as_ffi())?;
        }
        writeln!(fmt, "}}")
    }
}

/// Escapes `name` for use inside a double-quoted DOT label.
///
/// Both `"` and `\` are prefixed with a backslash. Leaving backslashes alone
/// would let a name ending in `\` escape the closing quote of the label.
fn escape_dot_label(name: &str) -> String {
    let mut escaped = String::with_capacity(name.len());
    for c in name.chars() {
        if matches!(c, '"' | '\\') {
            escaped.push('\\');
        }
        escaped.push(c);
    }
    escaped
}