dag_compute/
lib.rs

1#![forbid(unsafe_code)]
2#![doc(html_root_url = "https://docs.rs/dag_compute/0.1.0")]
3
4use slotmap::{SlotMap, SecondaryMap, new_key_type};
5use slotmap::Key as KeyTrait;
6
7use std::collections::{HashSet, HashMap, VecDeque};
8use std::sync::Arc;
9use std::ops::Deref;
10use std::marker::PhantomData;
11use std::fmt;
12
13use log::{info, debug, trace};
14
// Key type used to address nodes in the graph's `SlotMap` / `SecondaryMap`.
new_key_type!{struct ComputeGraphKey;}
16
/// Boxed evaluation function of a node: receives the input values as a
/// slice of references and returns the node's own value.
type BoxedEvalFn<T> = Box<dyn Fn(&[&T]) -> T + Send + Sync>;
18
/// A single vertex of the computation graph.
pub(crate) struct Node<T> {
    /// Human-readable label, used in log messages and DOT output.
    name: String,
    /// Function that produces this node's value from its inputs' values.
    func: BoxedEvalFn<T>,
    /// Keys of the nodes whose values are handed to `func`, in argument order.
    input_nodes: Vec<ComputeGraphKey>,
    /// Value produced by `func`; `None` until `eval` has run.
    output_cache: Option<Arc<T>>
}
25impl<T> Node<T> {
26    fn new(name: String, func: BoxedEvalFn<T>) -> Node<T> {
27        Node {
28            name,
29            func,
30            input_nodes: Vec::default(),
31            output_cache: None
32        }
33    }
34    // Passing arg slice instead of node handles is a leaky encapsulation
35    // Doesn't seem to be possible to remove leakiness safely though?
36    pub fn eval(&mut self, args: &[&T]) {
37        if self.output_cache.is_none() {
38            self.output_cache = Some(Arc::new((self.func)(args)));
39        } else {
40            panic!("Node is already evaluated");
41        }
42    }
43    pub fn computed_val(&self) -> Arc<T> {
44        if let Some(ref val) = self.output_cache {
45            val.clone()
46        } else {
47            panic!("Node has not yet been evaluated");
48        }
49    }
50}
51impl<T: fmt::Debug> fmt::Debug for Node<T> {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        write!(f, "NodeHandle {{ ")?;
54        write!(f, "name: {:?}, ", self.name)?;
55        write!(f, "func: ..., ")?;
56        write!(f, "input_nodes: {:?}, ", self.input_nodes)?;
57        write!(f, "output_cache: {:?}", self.output_cache)?;
58        write!(f, " }}")
59    }
60}
61
// DO NOT DERIVE Copy OR Clone: HANDLE MUST BE NON-FUNGIBLE
// (`set_inputs` takes the target by `&mut` and the inputs by `&`, so a unique
// handle per node lets the borrow rules exclude a node being its own input)
#[derive(Debug, PartialEq, Eq, Hash)]
/// An opaque handle to a node in a [`ComputationGraph`].
///
/// A handle is only accepted by the graph that issued it; graph methods
/// panic when handed a handle carrying a different graph id.
pub struct NodeHandle {
    /// Key of the node inside the issuing graph's storage.
    node_key: ComputeGraphKey,
    /// Id of the graph that issued this handle.
    graph_id: usize
}
69
/// A DAG that expresses a computation flow between nodes.
///
/// Nodes are added with `insert_node`, wired together with `set_inputs`,
/// one node is marked with `designate_output`, and `compute` consumes the
/// graph to produce the output node's value.
#[derive(Debug)]
pub struct ComputationGraph<T> {
    /// Owning storage for all nodes.
    node_storage: SlotMap<ComputeGraphKey, Node<T>>,
    /// Per-node count of outstanding uses: one per consumer edge, plus one
    /// for the designated output node.
    node_refcount: SecondaryMap<ComputeGraphKey, u32>,
    /// Key of the designated output node, if one has been chosen.
    output_node: Option<ComputeGraphKey>,
    /// Token copied into every issued `NodeHandle` and checked on use.
    graph_id: usize
}
78impl<T> Default for ComputationGraph<T> {
79    fn default() -> Self {
80        let mut obj = ComputationGraph {
81            node_storage: SlotMap::default(),
82            node_refcount: SecondaryMap::default(),
83            output_node: None,
84            graph_id: 0
85        };
86        // Use pointer numerical value to tie NodeHandles to ComputationGraphs
87        // No potential risks here as we only need this as an opaque token
88        obj.graph_id = (&obj.node_storage as *const SlotMap<_,_>) as usize;
89        obj
90    }
91}
92impl<T> ComputationGraph<T> {
93    pub fn new() -> ComputationGraph<T>{
94        ComputationGraph::default()
95    }
96    /// Inserts a new node, returning an opaque node handle.
97    /// 
98    /// While the library does not enforce name uniqueness, this is
99    /// highly recommended to make debugging easier.
100    pub fn insert_node(&mut self, name: String, func: BoxedEvalFn<T>) -> NodeHandle {
101        let node = Node::new(name, func);
102        let node_key = self.node_storage.insert(node);
103        self.node_refcount.insert(node_key, 0);
104        NodeHandle {
105            node_key,
106            graph_id: self.graph_id
107        }
108    }
109    /// Returns a reference to a node's name.
110    pub fn node_name(&self, node: &NodeHandle) -> &str {
111        assert_eq!(node.graph_id, self.graph_id,
112            "Received NodeHandle for different graph");
113        &self.node_storage.get(node.node_key).unwrap().name
114    }
115    /// Designates the given node as the output node.
116    pub fn designate_output(&mut self, node: &NodeHandle) {
117        self.output_node.ok_or(()).expect_err("Output was already designated");
118        assert_eq!(node.graph_id, self.graph_id,
119            "Received NodeHandle for different graph");
120        let node_key = node.node_key;
121        assert!(self.node_storage.contains_key(node_key));
122        self.output_node = Some(node_key);
123        *self.node_refcount.get_mut(node_key).unwrap() += 1;
124    }
125    /// Sets the given node's inputs.
126    /// 
127    /// It is the caller's responsibility to avoid creating loops,
128    /// which are only detected at computation time.
129    pub fn set_inputs(&mut self, node: &mut NodeHandle, inputs: &[&NodeHandle]) {
130        assert_eq!(node.graph_id, self.graph_id,
131            "Received NodeHandle for different graph");
132        let input_keys: Vec<_> = inputs.iter().map(|handle| handle.node_key).collect();
133        // Mutability rules actually enforce the non-circular-loop case
134        // Keep assert in case duplication happens elsewhere
135        assert!(!input_keys.contains(&node.node_key), "Inputs would create self-loop");
136        // Other cycles would be caught at computation time
137
138        for key in input_keys.iter() {
139            *self.node_refcount.get_mut(*key).unwrap() += 1;
140        }
141        self.node_storage.get_mut(node.node_key).unwrap().input_nodes = input_keys;
142    }
143    /// Emits a DOT graph of the computation graph.
144    /// 
145    /// Nodes are labeled with names, and the output node is rectangular.
146    pub fn dot_graph(&self) -> impl fmt::Display + '_ {
147        DAGComputeDisplay::new(self)
148    }
149
150    /// Determines a valid order for node evaluation.
151    fn computation_order(&mut self) -> impl IntoIterator<Item = ComputeGraphKey> {
152        debug!("Computing node evaluation order");
153        let out_node = self.output_node.expect("Output not yet designated");
154
155        // Toposort the graph, marking used nodes
156        let mut sort_list = VecDeque::new();
157        let mut temporary_set = HashSet::new();
158        self.toposort_helper(out_node, &mut sort_list, &mut temporary_set);
159        debug_assert!(temporary_set.is_empty());
160
161        // Sweep phase of mark-and-sweep GC
162        self.node_storage.retain(|k, del_node| {
163            let keep = sort_list.contains(&k);
164            if !keep {
165                trace!("Sweeping node {}", del_node.name);
166                for input_key in &del_node.input_nodes {
167                    *self.node_refcount.get_mut(*input_key).unwrap() -= 1;
168                }
169                self.node_refcount.remove(k);
170            } else {
171                trace!("Keeping node {}", del_node.name)
172            }
173            keep
174        });
175        /*
176         * We traversed the edge in the opposite direction of the dataflow
177         * Reverse now to get the correct directions
178         * WARNING: this is valid for DFS-obtained toposort but not in general
179         */
180        sort_list.make_contiguous().reverse();
181        sort_list
182    }
183    // Adapted from the DFS-based toposort of https://en.wikipedia.org/wiki/Topological_sorting
184    fn toposort_helper(&self, node: ComputeGraphKey,
185            final_list: &mut VecDeque<ComputeGraphKey>,
186            temporary_set: &mut HashSet<ComputeGraphKey>) {
187        if final_list.contains(&node) {
188            return;
189        }
190        assert!(!temporary_set.contains(&node), "Computation graph contains cycle");
191        temporary_set.insert(node);
192        for input in self.node_storage.get(node).unwrap().input_nodes.iter() {
193            self.toposort_helper(*input, final_list, temporary_set);
194        }
195        temporary_set.remove(&node);
196        final_list.insert(0, node);
197    }
198
199    /// Computes and returns the value of the output node.
200    pub fn compute(mut self) -> T {
201        self.output_node.expect("Output not yet designated");
202        info!("Evaluating DAG");
203        let compute_order = self.computation_order();
204        debug!("Computing node values");
205        for node_key in compute_order {
206            let node = self.node_storage.get(node_key).unwrap();
207            trace!("Evaluating node {}", node.name);
208
209            let node_input_keyvec = node.input_nodes.clone();
210            let mut nodes_cleanup = Vec::with_capacity(node_input_keyvec.len());
211            let node_input_arcs: Vec<_> = node_input_keyvec.into_iter().map(|key| {
212                let in_refcnt = self.node_refcount.get_mut(key).unwrap();
213                assert!(*in_refcnt > 0);
214                *in_refcnt -= 1;
215                if *in_refcnt == 0 {
216                    nodes_cleanup.push(key);
217                }
218                // Toposort guarantees that inputs will be ready when needed
219                self.node_storage.get(key).unwrap().computed_val()
220            }).collect();
221            // The refs in node_inputs are live as long as node_input_arcs is
222            let mut node_inputs = Vec::with_capacity(node_input_arcs.len());
223            for arc in node_input_arcs.iter() {
224                node_inputs.push(arc.deref());
225            }
226
227            for old_key in nodes_cleanup {
228                self.node_storage.remove(old_key);
229                self.node_refcount.remove(old_key);
230            }
231            // Rebind node as &mut to perform calculation
232            let node = self.node_storage.get_mut(node_key).unwrap();
233            node.eval(node_inputs.as_slice());
234        }
235        // Assert checks that only the output node is left
236        assert_eq!(self.node_storage.len(), 1);
237        let output_key = self.output_node.take().unwrap();
238        // Remove instead of get because we want an owned Node
239        let output_node = self.node_storage.remove(output_key).unwrap();
240        let output_val_arc = output_node.computed_val();
241        drop(output_node);
242        /*
243         * We just computed the output value and didn't hand it to anyone else
244         * We dropped the output node, which would have held the only other copy
245         * There is exactly one copy of the Arc, so try_unwrap must succeed
246         */
247        Arc::try_unwrap(output_val_arc).ok().unwrap()
248    }
249}
250
/// Snapshot of a [`ComputationGraph`] that formats itself as a DOT graph.
struct DAGComputeDisplay<'a, T> {
    /*
     * We only really need edge_list, but hold a PhantomData to slotmap_ref
     * This prevents changes to the DAG so we only need to compute stuff once
     */
    // TODO: make this an actual reference?
    slotmap_ref: PhantomData<&'a SlotMap<ComputeGraphKey, Node<T>>>,
    /// Name of every node in the graph, borrowed from the graph itself.
    names: HashMap<ComputeGraphKey, &'a str>,
    /// Key of the designated output node, if any.
    output_node: Option<ComputeGraphKey>,
    /// Dataflow edges as `(input, consumer)` pairs.
    edge_list: Vec<(ComputeGraphKey, ComputeGraphKey)>
}
262impl<'a, T> DAGComputeDisplay<'a, T> {
263    fn new(map: &'a ComputationGraph<T>) -> DAGComputeDisplay<'a, T> {
264        let true_keyset: HashMap<ComputeGraphKey, &'a str> = map.node_storage
265            .keys()
266            .map(|key| (key, map.node_storage.get(key).unwrap().name.as_str()))
267            .collect();
268        let mut explored_keyset: HashSet<ComputeGraphKey> = HashSet::new();
269        let mut edge_list = Vec::new();
270        // len is more efficient than full equality
271        // We need this to account for ill-formed graphs (don't reject here)
272        while true_keyset.len() > explored_keyset.len() {
273            debug_assert!(explored_keyset.is_subset(
274                &true_keyset.keys().copied().collect()));
275            // Do BFS to make the final dot file more human-readable
276            let mut bfs_queue: VecDeque<ComputeGraphKey> = VecDeque::new();
277            let mut bfs_root: Option<ComputeGraphKey> = None;
278            for key in true_keyset.keys() {
279                if !explored_keyset.contains(key) {
280                    bfs_root = Some(*key);
281                    break;
282                }
283            }
284            let bfs_root = bfs_root.unwrap(); // Rebind and assert
285
286            bfs_queue.push_back(bfs_root);
287            explored_keyset.insert(bfs_root);
288            while !bfs_queue.is_empty() {
289                let current = bfs_queue.pop_front().unwrap();
290                for input in map.node_storage.get(current).unwrap()
291                        .input_nodes.iter() {
292                    edge_list.push((*input, current));
293                    // Insert returns true if new element was added
294                    if explored_keyset.insert(*input) {
295                        bfs_queue.push_back(*input);
296                    }
297                }
298            }
299        }
300        debug_assert_eq!(true_keyset.keys().copied().collect::<HashSet<_>>(),
301                explored_keyset);
302        DAGComputeDisplay {
303            slotmap_ref: PhantomData::default(),
304            names: true_keyset,
305            output_node: map.output_node,
306            edge_list
307        }
308    }
309}
310impl<'a, T> fmt::Display for DAGComputeDisplay<'a, T> {
311    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
312        writeln!(fmt, "strict digraph {{")?;
313        for (node, name) in self.names.iter() {
314            let node_id = node.data().as_ffi();
315            let escaped_name: String = name.chars().map(|c| {
316                match c {
317                    '"' => r#"\""#.to_owned(),
318                    c => c.to_string()
319                }
320            }).collect();
321            write!(fmt, "{} [label=\"{}\"", node_id, escaped_name)?;
322            if let Some(out) = self.output_node {
323                if out == *node {
324                    write!(fmt, ", shape=box")?;
325                }
326            }
327            writeln!(fmt, "];")?;
328        }
329        for edge in self.edge_list.iter() {
330            // Use the u64 as_ffi to handle duplicate names
331            let from_id = edge.0.data().as_ffi();
332            let to_id = edge.1.data().as_ffi();
333            writeln!(fmt, "{}->{};", from_id, to_id)?;
334        }
335        writeln!(fmt, "}}")
336    }
337}