1#![forbid(unsafe_code)]
2#![doc(html_root_url = "https://docs.rs/dag_compute/0.1.0")]
3
4use slotmap::{SlotMap, SecondaryMap, new_key_type};
5use slotmap::Key as KeyTrait;
6
7use std::collections::{HashSet, HashMap, VecDeque};
8use std::sync::Arc;
9use std::ops::Deref;
10use std::marker::PhantomData;
11use std::fmt;
12
13use log::{info, debug, trace};
14
15new_key_type!{struct ComputeGraphKey;}
16
17type BoxedEvalFn<T> = Box<dyn Fn(&[&T]) -> T + Send + Sync>;
18
19pub(crate) struct Node<T> {
20 name: String,
21 func: BoxedEvalFn<T>,
22 input_nodes: Vec<ComputeGraphKey>,
23 output_cache: Option<Arc<T>>
24}
25impl<T> Node<T> {
26 fn new(name: String, func: BoxedEvalFn<T>) -> Node<T> {
27 Node {
28 name,
29 func,
30 input_nodes: Vec::default(),
31 output_cache: None
32 }
33 }
34 pub fn eval(&mut self, args: &[&T]) {
37 if self.output_cache.is_none() {
38 self.output_cache = Some(Arc::new((self.func)(args)));
39 } else {
40 panic!("Node is already evaluated");
41 }
42 }
43 pub fn computed_val(&self) -> Arc<T> {
44 if let Some(ref val) = self.output_cache {
45 val.clone()
46 } else {
47 panic!("Node has not yet been evaluated");
48 }
49 }
50}
51impl<T: fmt::Debug> fmt::Debug for Node<T> {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 write!(f, "NodeHandle {{ ")?;
54 write!(f, "name: {:?}, ", self.name)?;
55 write!(f, "func: ..., ")?;
56 write!(f, "input_nodes: {:?}, ", self.input_nodes)?;
57 write!(f, "output_cache: {:?}", self.output_cache)?;
58 write!(f, " }}")
59 }
60}
61
62#[derive(Debug, PartialEq, Eq, Hash)]
64pub struct NodeHandle {
66 node_key: ComputeGraphKey,
67 graph_id: usize
68}
69
70#[derive(Debug)]
72pub struct ComputationGraph<T> {
73 node_storage: SlotMap<ComputeGraphKey, Node<T>>,
74 node_refcount: SecondaryMap<ComputeGraphKey, u32>,
75 output_node: Option<ComputeGraphKey>,
76 graph_id: usize
77}
78impl<T> Default for ComputationGraph<T> {
79 fn default() -> Self {
80 let mut obj = ComputationGraph {
81 node_storage: SlotMap::default(),
82 node_refcount: SecondaryMap::default(),
83 output_node: None,
84 graph_id: 0
85 };
86 obj.graph_id = (&obj.node_storage as *const SlotMap<_,_>) as usize;
89 obj
90 }
91}
92impl<T> ComputationGraph<T> {
93 pub fn new() -> ComputationGraph<T>{
94 ComputationGraph::default()
95 }
96 pub fn insert_node(&mut self, name: String, func: BoxedEvalFn<T>) -> NodeHandle {
101 let node = Node::new(name, func);
102 let node_key = self.node_storage.insert(node);
103 self.node_refcount.insert(node_key, 0);
104 NodeHandle {
105 node_key,
106 graph_id: self.graph_id
107 }
108 }
109 pub fn node_name(&self, node: &NodeHandle) -> &str {
111 assert_eq!(node.graph_id, self.graph_id,
112 "Received NodeHandle for different graph");
113 &self.node_storage.get(node.node_key).unwrap().name
114 }
115 pub fn designate_output(&mut self, node: &NodeHandle) {
117 self.output_node.ok_or(()).expect_err("Output was already designated");
118 assert_eq!(node.graph_id, self.graph_id,
119 "Received NodeHandle for different graph");
120 let node_key = node.node_key;
121 assert!(self.node_storage.contains_key(node_key));
122 self.output_node = Some(node_key);
123 *self.node_refcount.get_mut(node_key).unwrap() += 1;
124 }
125 pub fn set_inputs(&mut self, node: &mut NodeHandle, inputs: &[&NodeHandle]) {
130 assert_eq!(node.graph_id, self.graph_id,
131 "Received NodeHandle for different graph");
132 let input_keys: Vec<_> = inputs.iter().map(|handle| handle.node_key).collect();
133 assert!(!input_keys.contains(&node.node_key), "Inputs would create self-loop");
136 for key in input_keys.iter() {
139 *self.node_refcount.get_mut(*key).unwrap() += 1;
140 }
141 self.node_storage.get_mut(node.node_key).unwrap().input_nodes = input_keys;
142 }
143 pub fn dot_graph(&self) -> impl fmt::Display + '_ {
147 DAGComputeDisplay::new(self)
148 }
149
150 fn computation_order(&mut self) -> impl IntoIterator<Item = ComputeGraphKey> {
152 debug!("Computing node evaluation order");
153 let out_node = self.output_node.expect("Output not yet designated");
154
155 let mut sort_list = VecDeque::new();
157 let mut temporary_set = HashSet::new();
158 self.toposort_helper(out_node, &mut sort_list, &mut temporary_set);
159 debug_assert!(temporary_set.is_empty());
160
161 self.node_storage.retain(|k, del_node| {
163 let keep = sort_list.contains(&k);
164 if !keep {
165 trace!("Sweeping node {}", del_node.name);
166 for input_key in &del_node.input_nodes {
167 *self.node_refcount.get_mut(*input_key).unwrap() -= 1;
168 }
169 self.node_refcount.remove(k);
170 } else {
171 trace!("Keeping node {}", del_node.name)
172 }
173 keep
174 });
175 sort_list.make_contiguous().reverse();
181 sort_list
182 }
183 fn toposort_helper(&self, node: ComputeGraphKey,
185 final_list: &mut VecDeque<ComputeGraphKey>,
186 temporary_set: &mut HashSet<ComputeGraphKey>) {
187 if final_list.contains(&node) {
188 return;
189 }
190 assert!(!temporary_set.contains(&node), "Computation graph contains cycle");
191 temporary_set.insert(node);
192 for input in self.node_storage.get(node).unwrap().input_nodes.iter() {
193 self.toposort_helper(*input, final_list, temporary_set);
194 }
195 temporary_set.remove(&node);
196 final_list.insert(0, node);
197 }
198
199 pub fn compute(mut self) -> T {
201 self.output_node.expect("Output not yet designated");
202 info!("Evaluating DAG");
203 let compute_order = self.computation_order();
204 debug!("Computing node values");
205 for node_key in compute_order {
206 let node = self.node_storage.get(node_key).unwrap();
207 trace!("Evaluating node {}", node.name);
208
209 let node_input_keyvec = node.input_nodes.clone();
210 let mut nodes_cleanup = Vec::with_capacity(node_input_keyvec.len());
211 let node_input_arcs: Vec<_> = node_input_keyvec.into_iter().map(|key| {
212 let in_refcnt = self.node_refcount.get_mut(key).unwrap();
213 assert!(*in_refcnt > 0);
214 *in_refcnt -= 1;
215 if *in_refcnt == 0 {
216 nodes_cleanup.push(key);
217 }
218 self.node_storage.get(key).unwrap().computed_val()
220 }).collect();
221 let mut node_inputs = Vec::with_capacity(node_input_arcs.len());
223 for arc in node_input_arcs.iter() {
224 node_inputs.push(arc.deref());
225 }
226
227 for old_key in nodes_cleanup {
228 self.node_storage.remove(old_key);
229 self.node_refcount.remove(old_key);
230 }
231 let node = self.node_storage.get_mut(node_key).unwrap();
233 node.eval(node_inputs.as_slice());
234 }
235 assert_eq!(self.node_storage.len(), 1);
237 let output_key = self.output_node.take().unwrap();
238 let output_node = self.node_storage.remove(output_key).unwrap();
240 let output_val_arc = output_node.computed_val();
241 drop(output_node);
242 Arc::try_unwrap(output_val_arc).ok().unwrap()
248 }
249}
250
251struct DAGComputeDisplay<'a, T> {
252 slotmap_ref: PhantomData<&'a SlotMap<ComputeGraphKey, Node<T>>>,
258 names: HashMap<ComputeGraphKey, &'a str>,
259 output_node: Option<ComputeGraphKey>,
260 edge_list: Vec<(ComputeGraphKey, ComputeGraphKey)>
261}
262impl<'a, T> DAGComputeDisplay<'a, T> {
263 fn new(map: &'a ComputationGraph<T>) -> DAGComputeDisplay<'a, T> {
264 let true_keyset: HashMap<ComputeGraphKey, &'a str> = map.node_storage
265 .keys()
266 .map(|key| (key, map.node_storage.get(key).unwrap().name.as_str()))
267 .collect();
268 let mut explored_keyset: HashSet<ComputeGraphKey> = HashSet::new();
269 let mut edge_list = Vec::new();
270 while true_keyset.len() > explored_keyset.len() {
273 debug_assert!(explored_keyset.is_subset(
274 &true_keyset.keys().copied().collect()));
275 let mut bfs_queue: VecDeque<ComputeGraphKey> = VecDeque::new();
277 let mut bfs_root: Option<ComputeGraphKey> = None;
278 for key in true_keyset.keys() {
279 if !explored_keyset.contains(key) {
280 bfs_root = Some(*key);
281 break;
282 }
283 }
284 let bfs_root = bfs_root.unwrap(); bfs_queue.push_back(bfs_root);
287 explored_keyset.insert(bfs_root);
288 while !bfs_queue.is_empty() {
289 let current = bfs_queue.pop_front().unwrap();
290 for input in map.node_storage.get(current).unwrap()
291 .input_nodes.iter() {
292 edge_list.push((*input, current));
293 if explored_keyset.insert(*input) {
295 bfs_queue.push_back(*input);
296 }
297 }
298 }
299 }
300 debug_assert_eq!(true_keyset.keys().copied().collect::<HashSet<_>>(),
301 explored_keyset);
302 DAGComputeDisplay {
303 slotmap_ref: PhantomData::default(),
304 names: true_keyset,
305 output_node: map.output_node,
306 edge_list
307 }
308 }
309}
310impl<'a, T> fmt::Display for DAGComputeDisplay<'a, T> {
311 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
312 writeln!(fmt, "strict digraph {{")?;
313 for (node, name) in self.names.iter() {
314 let node_id = node.data().as_ffi();
315 let escaped_name: String = name.chars().map(|c| {
316 match c {
317 '"' => r#"\""#.to_owned(),
318 c => c.to_string()
319 }
320 }).collect();
321 write!(fmt, "{} [label=\"{}\"", node_id, escaped_name)?;
322 if let Some(out) = self.output_node {
323 if out == *node {
324 write!(fmt, ", shape=box")?;
325 }
326 }
327 writeln!(fmt, "];")?;
328 }
329 for edge in self.edge_list.iter() {
330 let from_id = edge.0.data().as_ffi();
332 let to_id = edge.1.data().as_ffi();
333 writeln!(fmt, "{}->{};", from_id, to_id)?;
334 }
335 writeln!(fmt, "}}")
336 }
337}