paragraphs/
graph.rs

1use std::borrow::Borrow;
2use crate::threadpool::{ThreadExecute, ThreadPool, WorkerStatus};
3use std::collections::{HashSet, HashMap};
4use std::iter::{FromIterator};
5use std::sync::Arc;
6
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;
    use std::collections::{HashSet, HashMap};
    use super::Graph;
    use crate::threadpool::ThreadExecute;
    use crate::threadpool::tests::{Adder};
    use std::sync::Arc;

    const NUM_THREADS: usize = 8;

    #[test]
    fn can_construct_graph() {
        let graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
        assert_eq!(graph.len(), 0);
    }

    #[test]
    fn can_add_node() {
        let mut graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
        let node_id = graph.add(Adder::new(), &[]);
        assert_eq!(node_id, 0);
        assert_eq!(graph.len(), 1);
    }

    #[test]
    #[should_panic(expected="does not yet exist in the graph")]
    fn cannot_set_node_input_to_itself() {
        let mut graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
        graph.add(Adder::new(), vec![0]);
    }

    #[test]
    #[should_panic(expected="does not yet exist in the graph")]
    fn cannot_set_node_input_to_later_node() {
        let mut graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
        let input = graph.add(Adder::new(), &[]);
        // Index `input + 2` has not been added yet, so this must be rejected.
        graph.add(Adder::new(), &[input, input + 2]);
    }

    /// Builds the shared fixture: one input node feeding two hidden nodes, which both feed
    /// two output nodes, plus a dead-end node (the last node added) hanging off `hidden1`.
    ///
    /// Returns `(graph, input, hidden1, hidden2, output1, output2)`.
    fn build_diamond_graph() -> (Graph<Adder, i32>, usize, usize, usize, usize, usize) {
        let mut graph = Graph::new(NUM_THREADS);
        let input = graph.add(Adder::new(), &[]);
        // Diamond graph.
        let hidden1 = graph.add(Adder::new(), &[input, input]);
        let hidden2 = graph.add(Adder::new(), &[input, input]);
        let output1 = graph.add(Adder::new(), &[hidden1, hidden2]);
        let output2 = graph.add(Adder::new(), &[hidden1, hidden2]);
        let _deadend = graph.add(Adder::new(), &[hidden1]);
        (graph, input, hidden1, hidden2, output1, output2)
    }

    #[test]
    fn can_get_nodes() {
        let mut graph = build_diamond_graph().0;
        let num_nodes = graph.len();
        assert!(graph.get(0).is_some());
        assert!(graph.get(num_nodes - 1).is_some());
        assert!(graph.get(num_nodes).is_none());
        assert!(graph.get_mut(0).is_some());
        assert!(graph.get_mut(num_nodes).is_none());
    }

    #[test]
    fn can_compile_graph() {
        let (graph, input, hidden1, hidden2, output1, output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);
        // Get recipe for the first output.
        let out1_recipe = graph.compile(&[output1]);
        println!("Output 1 Recipe: {:?}", out1_recipe);
        // Check
        assert_eq!(out1_recipe.runs, HashSet::from_iter(vec![input, hidden1, hidden2, output1]));
        assert_eq!(out1_recipe.inputs, HashSet::from_iter([input].iter().cloned()));
        assert_eq!(out1_recipe.outputs, HashSet::from_iter(vec![output1]));
        // Get recipe for the second output.
        let out2_recipe = graph.compile(vec![output2]);
        println!("Output 2 Recipe: {:?}", out2_recipe);
        // Check correctness.
        assert_eq!(out2_recipe.runs, HashSet::from_iter(vec![input, hidden1, hidden2, output2]));
        assert_eq!(out2_recipe.inputs, HashSet::from_iter([input].iter().cloned()));
        assert_eq!(out2_recipe.outputs, HashSet::from_iter(vec![output2]));
    }

    #[test]
    fn compile_deduplicates_fetches_and_skips_unneeded_nodes() {
        let (graph, input, hidden1, hidden2, output1, output2) = build_diamond_graph();
        let recipe = graph.compile(vec![output1, output2, output1]);
        assert_eq!(recipe.outputs, HashSet::from_iter(vec![output1, output2]));
        // The dead-end node feeds neither output, so it must not be scheduled.
        assert_eq!(recipe.runs, HashSet::from_iter(vec![input, hidden1, hidden2, output1, output2]));
        assert_eq!(recipe.inputs, HashSet::from_iter(vec![input]));
    }

    #[test]
    #[should_panic(expected="Found 0 inputs")]
    fn cannot_compile_empty_fetches() {
        let graph = build_diamond_graph().0;
        let _ = graph.compile(Vec::<usize>::new());
    }

    #[test]
    fn can_run_graph() {
        let (mut graph, input, _hidden1, _hidden2, output1, output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);
        let recipe = graph.compile(vec![output1, output2, output2]);
        println!("Recipe: {:?}", recipe);
        let inputs_map = HashMap::from_iter(vec!(
            (input, vec![1, 2, 3])
        ));
        let outputs = graph.run(&recipe, inputs_map);
        println!("Outputs: {:?}", outputs);
        assert_eq!(outputs.get(&output1), Some(&24));
        assert_eq!(outputs.get(&output2), Some(&24));
    }

    #[test]
    fn can_run_graph_all_nodes_outputs() {
        let (mut graph, input, hidden1, hidden2, output1, output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);
        let recipe = graph.compile(vec![input, hidden1, hidden2, output1, output2, output2]);
        println!("Recipe: {:?}", recipe);
        let inputs_map = HashMap::from_iter(vec!(
            (input, vec![1, 2, 3])
        ));
        println!("Input map: {:?}", inputs_map);
        let outputs = graph.run(&recipe, inputs_map);
        println!("Outputs: {:?}", outputs);
        // Every distinct fetched node is returned exactly once.
        assert_eq!(outputs.len(), 5);
        assert_eq!(outputs.get(&input), Some(&6));
        assert_eq!(outputs.get(&hidden1), Some(&12));
        assert_eq!(outputs.get(&hidden2), Some(&12));
        assert_eq!(outputs.get(&output1), Some(&24));
        assert_eq!(outputs.get(&output2), Some(&24));
    }

    #[test]
    fn can_run_graph_input_is_output() {
        let (mut graph, input, _hidden1, _hidden2, _output1, _output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);
        let recipe = graph.compile(vec![input]);
        println!("Recipe: {:?}", recipe);
        let inputs_map = HashMap::from_iter(vec!(
            (input, vec![1, 2, 3])
        ));
        println!("Input map: {:?}", inputs_map);
        let outputs = graph.run(&recipe, inputs_map);
        println!("Outputs: {:?}", outputs);
        assert_eq!(outputs.get(&input), Some(&6));
    }

    #[test]
    fn can_run_graph_input_and_node() {
        let (mut graph, input, hidden1, _hidden2, _output1, _output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);
        let recipe = graph.compile(vec![input, hidden1]);
        println!("Recipe: {:?}", recipe);
        let inputs_map = HashMap::from_iter(vec!(
            (input, vec![1, 2, 3])
        ));
        println!("Input map: {:?}", inputs_map);
        let outputs = graph.run(&recipe, inputs_map);
        println!("Outputs: {:?}", outputs);
        assert_eq!(outputs.get(&input), Some(&6));
        assert_eq!(outputs.get(&hidden1), Some(&12));
    }

    /// A node whose execution always fails.
    struct FailNode;

    impl ThreadExecute<i32> for FailNode {
        fn execute(&mut self, _inputs: Vec<Arc<i32>>) -> Option<i32> {
            None
        }
    }

    #[test]
    #[should_panic(expected="Graph failed to execute because node")]
    fn node_failure_causes_panic() {
        let mut graph = Graph::new(NUM_THREADS);
        let input = graph.add(FailNode{}, &[]);
        let recipe = graph.compile(&[input]);
        let mut inputs_map = HashMap::new();
        inputs_map.insert(input, vec!());
        let _ = graph.run(&recipe, inputs_map);
    }

    #[test]
    fn can_iterate_graph_nodes() {
        let graph = build_diamond_graph().0;
        let mut num_nodes = 0;
        let expected_num_nodes = graph.len();
        for node in graph {
            assert!(node.valid);
            num_nodes += 1;
        }
        assert_eq!(num_nodes, expected_num_nodes);
    }

    #[test]
    fn can_iterate_ref_graph_nodes() {
        let graph = build_diamond_graph().0;
        let mut num_nodes = 0;
        let expected_num_nodes = graph.len();
        for node in &graph {
            assert!(node.valid);
            num_nodes += 1;
        }
        assert_eq!(num_nodes, expected_num_nodes);
    }

    #[test]
    fn can_iterate_ref_mut_graph_nodes() {
        let mut graph = build_diamond_graph().0;
        let mut num_nodes = 0;
        let expected_num_nodes = graph.len();
        for node in &mut graph {
            assert!(node.valid);
            num_nodes += 1;
        }
        assert_eq!(num_nodes, expected_num_nodes);
    }
}
183
/// Describes a recipe for retrieving a particular set of outputs.
///
/// A recipe is produced by `Graph::compile` and executed by `Graph::run`. It records
/// only the part of the graph needed for the requested outputs, so nodes that do not
/// contribute to any output are never scheduled.
#[derive(Debug, Clone)]
pub struct Recipe {
    // The graph indices of every node that must execute to produce the outputs
    // (the outputs themselves plus all of their transitive inputs).
    runs: HashSet<usize>,
    /// The graph indices of the inputs of this recipe.
    /// Values for each of these must be provided at graph execution time.
    pub inputs: HashSet<usize>,
    /// The graph indices of the outputs of this recipe.
    /// Values for each of these will be returned after graph execution.
    pub outputs: HashSet<usize>,
    // Maps a node in `runs` to the nodes in `runs` that consume it as an input.
    // Nodes that feed nothing else in this recipe have no entry.
    node_outputs: HashMap<usize, HashSet<usize>>,
    // Maps every node in `runs` to its (deduplicated) input indices, all of which are
    // also in `runs`.
    node_inputs: HashMap<usize, HashSet<usize>>,
}
199
200impl Recipe {
201    fn new(runs: HashSet<usize>, inputs: HashSet<usize>, outputs: HashSet<usize>, node_outputs: HashMap<usize, HashSet<usize>>, node_inputs: HashMap<usize, HashSet<usize>>) -> Recipe {
202        if inputs.len() == 0 {
203            panic!("Invalid Recipe: Found 0 inputs. Recipes must have at least one input node.");
204        }
205        return Recipe{runs: runs, inputs: inputs, outputs: outputs, node_outputs: node_outputs, node_inputs: node_inputs};
206    }
207}
208
209/// A computation graph.
210/// Nodes are executed concurrently wherever possible.
211/// Each graph manages its own threadpool, so although it may be possible to creat
212/// higher-order graphs, it is generally not advisable.
213#[derive(Debug)]
214pub struct Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
215    // This needs to be an option so that we can take() from it.
216    nodes: Vec<Option<Node>>,
217    node_inputs: Vec<Vec<usize>>,
218    pool: ThreadPool<Node, Data>,
219}
220
221impl<Node, Data> IntoIterator for Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
222    type Item = Node;
223    type IntoIter = std::iter::Map<std::vec::IntoIter<std::option::Option<Node>>, fn(std::option::Option<Node>) -> Node>;
224
225
226    fn into_iter(self) -> Self::IntoIter {
227        fn expect_node<Node>(node: Option<Node>) -> Node {
228            return node.expect("Node has been moved out of the graph. Is the graph being executed?");
229        }
230        return self.nodes.into_iter().map(expect_node);
231    }
232}
233
234impl<'a, Node, Data> IntoIterator for &'a Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
235    type Item = &'a Node;
236    // type IntoIter = GraphIterator<Node>;
237    type IntoIter = std::iter::Map<std::slice::Iter<'a, std::option::Option<Node>>, fn(&std::option::Option<Node>) -> &Node>;
238
239    fn into_iter(self) -> Self::IntoIter {
240        fn expect_node<Node>(node: &Option<Node>) -> &Node {
241            return match node {
242                Some(n) => n,
243                None => panic!("Node has been moved out of the graph. Is the graph being executed?"),
244            };
245        }
246        return self.nodes.iter().map(expect_node);
247    }
248}
249
250impl<'a, Node, Data> IntoIterator for &'a mut Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
251    type Item = &'a mut Node;
252    // type IntoIter = GraphIterator<Node>;
253    type IntoIter = std::iter::Map<std::slice::IterMut<'a, std::option::Option<Node>>, fn(&mut std::option::Option<Node>) -> &mut Node>;
254
255    fn into_iter(self) -> Self::IntoIter {
256        fn expect_node<Node>(node: &mut Option<Node>) -> &mut Node {
257            return match node {
258                Some(n) => n,
259                None => panic!("Node has been moved out of the graph. Is the graph being executed?"),
260            };
261        }
262        return self.nodes.iter_mut().map(expect_node);
263    }
264}
265
266impl<Node: 'static, Data: 'static> Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
267
268    /// Creates a new graph.
269    ///
270    /// # Arguments
271    ///
272    /// * `num_threads` - The number of threads available to the graph.
273    ///
274    /// # Example
275    ///
276    /// ```
277    /// use paragraphs::{Graph, ThreadExecute};
278    /// use std::sync::Arc;
279    ///
280    /// struct MyNode;
281    /// struct MyData;
282    ///
283    /// impl ThreadExecute<MyData> for MyNode {
284    ///     fn execute(&mut self, inputs: Vec<Arc<MyData>>) -> Option<MyData> {
285    ///         return Some(MyData{});
286    ///     }
287    /// }
288    ///
289    /// let graph: Graph<MyNode, MyData> = Graph::new(8);
290    /// ```
291    pub fn new(num_threads: usize) -> Graph<Node, Data> {
292        return Graph{nodes: Vec::new(), node_inputs: Vec::new(), pool: ThreadPool::new(num_threads)};
293    }
294
295    /// Returns the number of nodes in the graph.
296    pub fn len(&self) -> usize {
297        return self.nodes.len();
298    }
299
300    /// Gets an Option containing a reference to the node at the specified index.
301    ///
302    /// # Arguments
303    ///
304    /// * `index` - The index of the node.
305    pub fn get(&self, index: usize) -> Option<&Node> {
306        if let Some(node) = self.nodes.get(index) {
307            return node.as_ref();
308        }
309        return None;
310    }
311
312    /// Gets an Option containing a mutable referenece to the node at the specified index.
313    ///
314    /// # Arguments
315    ///
316    /// * `index` - The index of the node.
317    pub fn get_mut(&mut self, index: usize) -> Option<&mut Node> {
318        if let Some(node) = self.nodes.get_mut(index) {
319            return node.as_mut();
320        }
321        return None;
322    }
323
324    /// Adds a new node to the graph.
325    ///
326    /// # Arguments
327    ///
328    /// * `node` - The node to add to the graph.
329    /// * `inputs` - The indices of the inputs to the node. All inputs specified must
330    ///             already be present in the graph, or the function panics.
331    ///             Inputs may be specified more than once.
332    ///
333    /// # Example
334    ///
335    /// ```
336    /// use paragraphs::Graph;
337    /// # use paragraphs::ThreadExecute;
338    /// # use std::sync::Arc;
339    /// # struct MyNode;
340    /// # struct MyData;
341    /// # impl ThreadExecute<MyData> for MyNode {
342    /// #     fn execute(&mut self, inputs: Vec<Arc<MyData>>) -> Option<MyData> {
343    /// #         return Some(MyData{});
344    /// #     }
345    /// # }
346    /// let mut graph = Graph::new(8);
347    /// let node0 = graph.add(MyNode{}, &[]);
348    /// let node1 = graph.add(MyNode{}, &[node0, node0]);
349    /// ```
350    pub fn add<Container, Elem>(&mut self, node: Node, inputs: Container) -> usize where Container: IntoIterator<Item=Elem>, Elem: Borrow<usize> {
351        let node_id = self.nodes.len();
352        self.nodes.push(Some(node));
353        // Push inputs at the end, so that the above will fail if this node is an input to itself.
354        let inputs = inputs.into_iter().map(|x| x.borrow().clone()).collect();
355        for &input in &inputs {
356            if input >= node_id {
357                panic!("Cannot add node {} as an input to node {} as it does not yet exist in the graph.", input, node_id);
358            }
359        }
360        self.node_inputs.push(inputs);
361        return node_id;
362    }
363
364    /// Generates a Recipe that can be used to compute the specified outputs.
365    ///
366    /// # Arguments
367    ///
368    /// * `fetches` - The indices of the nodes to fetch. May contain duplicates.
369    ///             Panics if no nodes are specified.
370    ///
371    /// # Example
372    ///
373    /// ```
374    /// # use paragraphs::{Graph, ThreadExecute};
375    /// # use std::sync::Arc;
376    /// # struct MyNode;
377    /// # struct MyData;
378    /// # impl ThreadExecute<MyData> for MyNode {
379    /// #     fn execute(&mut self, inputs: Vec<Arc<MyData>>) -> Option<MyData> {
380    /// #         return Some(MyData{});
381    /// #     }
382    /// # }
383    /// # let mut graph = Graph::new(8);
384    /// # let node0 = graph.add(MyNode{}, &[]);
385    /// # let node1 = graph.add(MyNode{}, &[node0, node0]);
386    /// // This recipe can be used to fetch the result of node1.
387    /// let node1_recipe = graph.compile(&[node1]);
388    /// ```
389    pub fn compile<Container, Elem>(&self, fetches: Container) -> Recipe
390        where Container: IntoIterator<Item=Elem>, Elem: Borrow<usize> {
391        let mut index = 0;
392        let mut recipe_inputs = HashSet::new();
393        let mut node_outputs: HashMap<usize, HashSet<usize>> = HashMap::new();
394        let mut node_inputs: HashMap<usize, HashSet<usize>> = HashMap::new();
395        // Remove unecessary duplicates, and then store as recipe_outputs.
396        let mut fetches: Vec<usize> = fetches.into_iter().map(|x| x.borrow().clone()).collect();
397        let recipe_outputs = HashSet::from_iter(fetches.iter().cloned());
398        // Walk over fetches, and append the inputs of each node in it to the end of the vector.
399        // This is a BFS for finding all nodes that need to be executed.
400        while index < fetches.len() {
401            let node_id = match fetches.get(index) {
402                Some(id) => id,
403                None => panic!("Could not get index {} index in fetches ({:?}) during BFS", index, fetches),
404            };
405            let inputs = match self.node_inputs.get(*node_id) {
406                Some(id) => id,
407                None => panic!("Could not get node inputs for node {}", node_id),
408            };
409            // Nodes with no inputs ARE inputs.
410            if inputs.len() == 0 {
411                recipe_inputs.insert(*node_id);
412            }
413            // Add node inputs.
414            node_inputs.insert(*node_id, inputs.iter().cloned().collect());
415            // Add node outputs.
416            for input in inputs {
417                match node_outputs.get_mut(input) {
418                    // If this node is already in the map, then append the output to it.
419                    Some(outputs) => { outputs.insert(*node_id); },
420                    // Otherwise insert it into the map.
421                    None => {
422                        node_outputs.insert(*input, HashSet::from_iter(vec![*node_id]));
423                    },
424                };
425            }
426            fetches.extend(inputs);
427            index += 1;
428        }
429        return Recipe::new(HashSet::from_iter(fetches), recipe_inputs, recipe_outputs, node_outputs, node_inputs);
430    }
431
432    // Runs the provided Recipe with the provded inputs (map of {node: inputs}).
433    // If inputs are missing, panics.
434    // TODO: Document all panic conditions.
435    /// Executes the nodes specified by the recipe using the provided inputs.
436    ///
437    /// # Arguments
438    ///
439    /// * `recipe` - The recipe to execute. All outputs of the recipe are fetched.
440    /// * `inputs_map` - Inputs for each input node. Panics if inputs are missing.
441    /// # Example
442    ///
443    /// ```
444    /// # use paragraphs::{Graph, ThreadExecute};
445    /// # use std::sync::Arc;
446    /// use std::collections::HashMap;
447    /// use std::iter::FromIterator;
448    /// # struct MyNode;
449    /// # struct MyData;
450    /// # impl ThreadExecute<MyData> for MyNode {
451    /// #     fn execute(&mut self, inputs: Vec<Arc<MyData>>) -> Option<MyData> {
452    /// #         return Some(MyData{});
453    /// #     }
454    /// # }
455    /// # let mut graph = Graph::new(8);
456    /// # let node0 = graph.add(MyNode{}, &[]);
457    /// # let node1 = graph.add(MyNode{}, &[node0, node0]);
458    /// # let node1_recipe = graph.compile(&[node1]);
459    /// let inputs_map = HashMap::from_iter(vec!(
460    ///     (node0, vec![MyData{}]),
461    /// ));
462    /// let outputs = graph.run(&node1_recipe, inputs_map);
463    /// ```
464    pub fn run(&mut self, recipe: &Recipe, mut inputs_map: HashMap<usize, Vec<Data>>) -> HashMap<usize, Data> {
465        fn execute_node<Node: 'static, Data: 'static>(graph: &mut Graph<Node, Data>, node_id: usize, inputs: Vec<Arc<Data>>) where Node: ThreadExecute<Data>, Data: Send + Sync {
466            let node_opt = match graph.nodes.get_mut(node_id) {
467                Some(id) => id,
468                None => panic!("While attempting to execute, could not retrieve node {}", node_id)
469            };
470            let node = match node_opt.take() {
471                Some(node) => node,
472                None => panic!("Could not retrieve node {} - is it currently being executed?", node_id)
473            };
474            graph.pool.execute(node, inputs, node_id);
475        }
476
477        fn assemble_inputs<Node, Data>(graph: &Graph<Node, Data>, intermediates: &HashMap<usize, Arc<Data>>, node_id: usize) -> Vec<Arc<Data>> where Node: ThreadExecute<Data>, Data: Send + Sync {
478            let mut inputs: Vec<Arc<Data>> = Vec::new();
479            let input_ids = match graph.node_inputs.get(node_id) {
480                Some(id) => id,
481                None => panic!("Could not find node {} in the graph", node_id)
482            };
483            for input_id in input_ids {
484                let intermediate = match intermediates.get(input_id) {
485                    Some(intermediate) => intermediate,
486                    None => panic!("Node {} attempted to execute, but input {} is missing", node_id, input_id)
487                };
488                inputs.push(Arc::clone(intermediate));
489            }
490            return inputs;
491        }
492
493        // Each time we receive a node back, we will decrement the nodes remaining count.
494        let mut num_nodes_remaining = recipe.runs.len();
495        // We also store intermediate outputs of nodes.
496        let mut intermediates: HashMap<usize, Arc<Data>> = HashMap::with_capacity(num_nodes_remaining);
497        // Maps each node in recipe.runs to its inputs. When there are no inputs remaining,
498        // it means the node can be executed.
499        let mut remaining_inputs_map = recipe.node_inputs.clone();
500
501        // First, launch all input nodes.
502        for input_node in &recipe.inputs {
503            match inputs_map.remove(input_node) {
504                Some(inputs) => {
505                    let arc_inputs: Vec<Arc<Data>> = inputs.into_iter().map(|input| Arc::new(input)).collect();
506                    // Queue up every input node.
507                    execute_node(self, *input_node, arc_inputs);
508                },
509                None => panic!("Input for {} is missing", input_node),
510            };
511        }
512
513        // Keep going until everything has finished executing.
514        while num_nodes_remaining > 0 {
515            if let Ok(wstatus) = self.pool.wstatus_receiver.recv() {
516                match wstatus {
517                    WorkerStatus::Complete(node, result, node_id) => {
518                        // Each time we receive a node back, we decrement the count of nodes
519                        // that haven't yet finnished executing.
520                        num_nodes_remaining -= 1;
521                        // Place the node back into the graph.
522                        match self.nodes.get_mut(node_id) {
523                            Some(node_option) => node_option.replace(node),
524                            None => panic!("Received WorkerStatus for node {}, but this node is not in the graph", node_id),
525                        };
526
527                        // Store the intermediate output. If it is already present,
528                        // it means the node was executed more than once.
529                        if intermediates.insert(node_id, Arc::new(result)).is_some() {
530                            panic!("Node {} was executed more than once, possibly due to a cycle", node_id);
531                        }
532
533                        // See if it is possible to queue other nodes by checking on this node's outputs.
534                        // If a node is not found in the output map, then it means it IS an output.
535                        if let Some(output_ids) = recipe.node_outputs.get(&node_id) {
536                            // Next, walk over all the output_ids of this node, decrementing
537                            // their input counts.
538                            for output_id in output_ids {
539                                let remaining_inputs = match remaining_inputs_map.get_mut(output_id) {
540                                    Some(id) => id,
541                                    None => panic!("Node {} is not registered in the remaining_inputs_map", output_id)
542                                };
543                                remaining_inputs.remove(&node_id);
544                                // If any hit 0, execute them.
545                                if remaining_inputs.len() == 0 {
546                                    // Assemble the required inputs and dispatch.
547                                    let inputs = assemble_inputs(&self, &intermediates, output_id.clone());
548                                    execute_node(self, *output_id, inputs);
549                                }
550                            } // for output_id in output_ids
551                        } // if let Some(output_ids)
552                    },
553                    WorkerStatus::Fail(node_id) => panic!("Graph failed to execute because node {} failed", node_id),
554                } // match wstatus
555            }
556        } // while num_nodes_remaining > 0
557
558        // Return the outputs.
559        let mut outputs_map = HashMap::new();
560        for output in &recipe.outputs {
561            // Unwrap the Arc to get the underlying Data.
562            match Arc::try_unwrap(intermediates.remove(output).unwrap()) {
563                Ok(data) => { outputs_map.insert(output.clone(), data); },
564                Err(_) => panic!("Could not retrieve output for node {}", output),
565            }
566        }
567        return outputs_map;
568    } // fn run
569} // impl