1use std::borrow::Borrow;
2use crate::threadpool::{ThreadExecute, ThreadPool, WorkerStatus};
3use std::collections::{HashSet, HashMap};
4use std::iter::{FromIterator};
5use std::sync::Arc;
6
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;
    use std::collections::{HashSet, HashMap};
    use super::Graph;
    use crate::threadpool::ThreadExecute;
    use crate::threadpool::tests::Adder;
    use std::sync::Arc;

    const NUM_THREADS: usize = 8;

    #[test]
    fn can_construct_graph() {
        let _graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
    }

    #[test]
    fn can_add_node() {
        let mut graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
        let node_id = graph.add(Adder::new(), &[]);
        assert_eq!(node_id, 0);
    }

    #[test]
    #[should_panic(expected = "does not yet exist in the graph")]
    fn cannot_set_node_input_to_itself() {
        let mut graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
        graph.add(Adder::new(), vec![0]);
    }

    /// Builds a diamond: `input` feeds `hidden1` and `hidden2` (each takes `input` twice), and
    /// both hidden nodes feed `output1` and `output2`. An extra node hanging off `hidden1` is
    /// never fetched, so recipes that include it would indicate `compile` over-collects.
    ///
    /// Returns `(graph, input, hidden1, hidden2, output1, output2)`.
    fn build_diamond_graph() -> (Graph<Adder, i32>, usize, usize, usize, usize, usize) {
        let mut graph = Graph::new(NUM_THREADS);
        let input = graph.add(Adder::new(), &[]);
        let hidden1 = graph.add(Adder::new(), &[input, input]);
        let hidden2 = graph.add(Adder::new(), &[input, input]);
        let output1 = graph.add(Adder::new(), &[hidden1, hidden2]);
        let output2 = graph.add(Adder::new(), &[hidden1, hidden2]);
        let _deadend = graph.add(Adder::new(), &[hidden1]);
        (graph, input, hidden1, hidden2, output1, output2)
    }

    #[test]
    fn can_compile_graph() {
        let (graph, input, hidden1, hidden2, output1, output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);

        let out1_recipe = graph.compile(&[output1]);
        println!("Output 1 Recipe: {:?}", out1_recipe);
        assert_eq!(out1_recipe.runs, HashSet::from_iter(vec![input, hidden1, hidden2, output1]));
        assert_eq!(out1_recipe.inputs, HashSet::from_iter([input].iter().cloned()));
        assert_eq!(out1_recipe.outputs, HashSet::from_iter(vec![output1]));

        let out2_recipe = graph.compile(vec![output2]);
        println!("Output 2 Recipe: {:?}", out2_recipe);
        assert_eq!(out2_recipe.runs, HashSet::from_iter(vec![input, hidden1, hidden2, output2]));
        assert_eq!(out2_recipe.inputs, HashSet::from_iter([input].iter().cloned()));
        assert_eq!(out2_recipe.outputs, HashSet::from_iter(vec![output2]));
    }

    #[test]
    fn can_run_graph() {
        let (mut graph, input, _hidden1, _hidden2, output1, output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);
        // Duplicate fetches must collapse into a single output entry.
        let recipe = graph.compile(vec![output1, output2, output2]);
        println!("Recipe: {:?}", recipe);
        let inputs_map = HashMap::from_iter(vec![(input, vec![1, 2, 3])]);
        let outputs = graph.run(&recipe, inputs_map);
        println!("Outputs: {:?}", outputs);
        assert_eq!(outputs.get(&output1), Some(&24));
        assert_eq!(outputs.get(&output2), Some(&24));
    }

    #[test]
    fn can_run_graph_all_nodes_outputs() {
        let (mut graph, input, hidden1, hidden2, output1, output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);
        let recipe = graph.compile(vec![input, hidden1, hidden2, output1, output2, output2]);
        println!("Recipe: {:?}", recipe);
        let inputs_map = HashMap::from_iter(vec![(input, vec![1, 2, 3])]);
        println!("Input map: {:?}", inputs_map);
        let outputs = graph.run(&recipe, inputs_map);
        println!("Outputs: {:?}", outputs);
        // Every fetched node, including intermediates that feed other nodes, is returned.
        assert_eq!(outputs.get(&input), Some(&6));
        assert_eq!(outputs.get(&hidden1), Some(&12));
        assert_eq!(outputs.get(&hidden2), Some(&12));
        assert_eq!(outputs.get(&output1), Some(&24));
        assert_eq!(outputs.get(&output2), Some(&24));
    }

    #[test]
    fn can_run_graph_input_is_output() {
        let (mut graph, input, _hidden1, _hidden2, _output1, _output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);
        let recipe = graph.compile(vec![input]);
        println!("Recipe: {:?}", recipe);
        let inputs_map = HashMap::from_iter(vec![(input, vec![1, 2, 3])]);
        println!("Input map: {:?}", inputs_map);
        let outputs = graph.run(&recipe, inputs_map);
        println!("Outputs: {:?}", outputs);
        assert_eq!(outputs.get(&input), Some(&6));
    }

    #[test]
    fn can_run_graph_input_and_node() {
        let (mut graph, input, hidden1, _hidden2, _output1, _output2) = build_diamond_graph();
        println!("Graph: {:?}", graph);
        let recipe = graph.compile(vec![input, hidden1]);
        println!("Recipe: {:?}", recipe);
        let inputs_map = HashMap::from_iter(vec![(input, vec![1, 2, 3])]);
        println!("Input map: {:?}", inputs_map);
        let outputs = graph.run(&recipe, inputs_map);
        println!("Outputs: {:?}", outputs);
        assert_eq!(outputs.get(&hidden1), Some(&12));
    }

    /// A node whose `execute` always reports failure by returning `None`.
    struct FailNode;

    impl ThreadExecute<i32> for FailNode {
        fn execute(&mut self, _inputs: Vec<Arc<i32>>) -> Option<i32> {
            None
        }
    }

    #[test]
    #[should_panic(expected = "Graph failed to execute because node")]
    fn node_failure_causes_panic() {
        let mut graph = Graph::new(NUM_THREADS);
        let input = graph.add(FailNode, &[]);
        let recipe = graph.compile(&[input]);
        let mut inputs_map = HashMap::new();
        inputs_map.insert(input, vec![]);
        let _ = graph.run(&recipe, inputs_map);
    }

    #[test]
    fn can_iterate_graph_nodes() {
        let graph = build_diamond_graph().0;
        let mut num_nodes = 0;
        let expected_num_nodes = graph.len();
        for node in graph {
            assert!(node.valid);
            num_nodes += 1;
        }
        assert_eq!(num_nodes, expected_num_nodes);
    }

    #[test]
    fn can_iterate_ref_graph_nodes() {
        let graph = build_diamond_graph().0;
        let mut num_nodes = 0;
        let expected_num_nodes = graph.len();
        for node in &graph {
            assert!(node.valid);
            num_nodes += 1;
        }
        assert_eq!(num_nodes, expected_num_nodes);
    }

    #[test]
    fn can_iterate_ref_mut_graph_nodes() {
        let mut graph = build_diamond_graph().0;
        let mut num_nodes = 0;
        let expected_num_nodes = graph.len();
        for node in &mut graph {
            assert!(node.valid);
            num_nodes += 1;
        }
        assert_eq!(num_nodes, expected_num_nodes);
    }
}
183
/// An execution plan produced by `Graph::compile` and consumed by `Graph::run`.
///
/// It describes only the subgraph needed to produce the requested outputs: which nodes must
/// run, which of them are entry points, and how they are wired together.
#[derive(Debug)]
pub struct Recipe {
    /// Every node that must execute to produce `outputs`.
    runs: HashSet<usize>,
    /// Nodes in the plan that have no inputs; their data must be supplied to `Graph::run`.
    pub inputs: HashSet<usize>,
    /// The nodes whose results were requested.
    pub outputs: HashSet<usize>,
    /// For each node that feeds another node in the plan, the set of nodes consuming it.
    node_outputs: HashMap<usize, HashSet<usize>>,
    /// For each node in the plan, the deduplicated set of nodes it takes input from.
    node_inputs: HashMap<usize, HashSet<usize>>,
}

impl Recipe {
    /// Assembles a recipe from its precomputed parts.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is empty: without at least one input node nothing could ever be
    /// dispatched, so executing the recipe would never make progress.
    fn new(
        runs: HashSet<usize>,
        inputs: HashSet<usize>,
        outputs: HashSet<usize>,
        node_outputs: HashMap<usize, HashSet<usize>>,
        node_inputs: HashMap<usize, HashSet<usize>>,
    ) -> Recipe {
        if inputs.is_empty() {
            panic!("Invalid Recipe: Found 0 inputs. Recipes must have at least one input node.");
        }
        Recipe { runs, inputs, outputs, node_outputs, node_inputs }
    }
}
208
/// A directed graph of nodes whose results flow into the nodes that list them as inputs.
///
/// Nodes are identified by the index returned from `Graph::add`. `add` rejects any input id
/// that is not smaller than the new node's id, so inputs always refer to earlier nodes and the
/// graph cannot contain cycles. Execution is delegated to `pool`.
#[derive(Debug)]
pub struct Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
    // Nodes indexed by id. A slot is `None` while `run` has moved its node into the pool for
    // execution; it is refilled when the pool reports the node complete.
    nodes: Vec<Option<Node>>,
    // Parallel to `nodes`: for each node id, the ids of its inputs in the order given to
    // `add` (duplicates kept, so the same result can be passed more than once).
    node_inputs: Vec<Vec<usize>>,
    // Pool that executes nodes and reports their status back to `run`.
    pool: ThreadPool<Node, Data>,
}
220
221impl<Node, Data> IntoIterator for Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
222 type Item = Node;
223 type IntoIter = std::iter::Map<std::vec::IntoIter<std::option::Option<Node>>, fn(std::option::Option<Node>) -> Node>;
224
225
226 fn into_iter(self) -> Self::IntoIter {
227 fn expect_node<Node>(node: Option<Node>) -> Node {
228 return node.expect("Node has been moved out of the graph. Is the graph being executed?");
229 }
230 return self.nodes.into_iter().map(expect_node);
231 }
232}
233
234impl<'a, Node, Data> IntoIterator for &'a Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
235 type Item = &'a Node;
236 type IntoIter = std::iter::Map<std::slice::Iter<'a, std::option::Option<Node>>, fn(&std::option::Option<Node>) -> &Node>;
238
239 fn into_iter(self) -> Self::IntoIter {
240 fn expect_node<Node>(node: &Option<Node>) -> &Node {
241 return match node {
242 Some(n) => n,
243 None => panic!("Node has been moved out of the graph. Is the graph being executed?"),
244 };
245 }
246 return self.nodes.iter().map(expect_node);
247 }
248}
249
250impl<'a, Node, Data> IntoIterator for &'a mut Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
251 type Item = &'a mut Node;
252 type IntoIter = std::iter::Map<std::slice::IterMut<'a, std::option::Option<Node>>, fn(&mut std::option::Option<Node>) -> &mut Node>;
254
255 fn into_iter(self) -> Self::IntoIter {
256 fn expect_node<Node>(node: &mut Option<Node>) -> &mut Node {
257 return match node {
258 Some(n) => n,
259 None => panic!("Node has been moved out of the graph. Is the graph being executed?"),
260 };
261 }
262 return self.nodes.iter_mut().map(expect_node);
263 }
264}
265
266impl<Node: 'static, Data: 'static> Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
267
268 pub fn new(num_threads: usize) -> Graph<Node, Data> {
292 return Graph{nodes: Vec::new(), node_inputs: Vec::new(), pool: ThreadPool::new(num_threads)};
293 }
294
295 pub fn len(&self) -> usize {
297 return self.nodes.len();
298 }
299
300 pub fn get(&self, index: usize) -> Option<&Node> {
306 if let Some(node) = self.nodes.get(index) {
307 return node.as_ref();
308 }
309 return None;
310 }
311
312 pub fn get_mut(&mut self, index: usize) -> Option<&mut Node> {
318 if let Some(node) = self.nodes.get_mut(index) {
319 return node.as_mut();
320 }
321 return None;
322 }
323
324 pub fn add<Container, Elem>(&mut self, node: Node, inputs: Container) -> usize where Container: IntoIterator<Item=Elem>, Elem: Borrow<usize> {
351 let node_id = self.nodes.len();
352 self.nodes.push(Some(node));
353 let inputs = inputs.into_iter().map(|x| x.borrow().clone()).collect();
355 for &input in &inputs {
356 if input >= node_id {
357 panic!("Cannot add node {} as an input to node {} as it does not yet exist in the graph.", input, node_id);
358 }
359 }
360 self.node_inputs.push(inputs);
361 return node_id;
362 }
363
364 pub fn compile<Container, Elem>(&self, fetches: Container) -> Recipe
390 where Container: IntoIterator<Item=Elem>, Elem: Borrow<usize> {
391 let mut index = 0;
392 let mut recipe_inputs = HashSet::new();
393 let mut node_outputs: HashMap<usize, HashSet<usize>> = HashMap::new();
394 let mut node_inputs: HashMap<usize, HashSet<usize>> = HashMap::new();
395 let mut fetches: Vec<usize> = fetches.into_iter().map(|x| x.borrow().clone()).collect();
397 let recipe_outputs = HashSet::from_iter(fetches.iter().cloned());
398 while index < fetches.len() {
401 let node_id = match fetches.get(index) {
402 Some(id) => id,
403 None => panic!("Could not get index {} index in fetches ({:?}) during BFS", index, fetches),
404 };
405 let inputs = match self.node_inputs.get(*node_id) {
406 Some(id) => id,
407 None => panic!("Could not get node inputs for node {}", node_id),
408 };
409 if inputs.len() == 0 {
411 recipe_inputs.insert(*node_id);
412 }
413 node_inputs.insert(*node_id, inputs.iter().cloned().collect());
415 for input in inputs {
417 match node_outputs.get_mut(input) {
418 Some(outputs) => { outputs.insert(*node_id); },
420 None => {
422 node_outputs.insert(*input, HashSet::from_iter(vec![*node_id]));
423 },
424 };
425 }
426 fetches.extend(inputs);
427 index += 1;
428 }
429 return Recipe::new(HashSet::from_iter(fetches), recipe_inputs, recipe_outputs, node_outputs, node_inputs);
430 }
431
432 pub fn run(&mut self, recipe: &Recipe, mut inputs_map: HashMap<usize, Vec<Data>>) -> HashMap<usize, Data> {
465 fn execute_node<Node: 'static, Data: 'static>(graph: &mut Graph<Node, Data>, node_id: usize, inputs: Vec<Arc<Data>>) where Node: ThreadExecute<Data>, Data: Send + Sync {
466 let node_opt = match graph.nodes.get_mut(node_id) {
467 Some(id) => id,
468 None => panic!("While attempting to execute, could not retrieve node {}", node_id)
469 };
470 let node = match node_opt.take() {
471 Some(node) => node,
472 None => panic!("Could not retrieve node {} - is it currently being executed?", node_id)
473 };
474 graph.pool.execute(node, inputs, node_id);
475 }
476
477 fn assemble_inputs<Node, Data>(graph: &Graph<Node, Data>, intermediates: &HashMap<usize, Arc<Data>>, node_id: usize) -> Vec<Arc<Data>> where Node: ThreadExecute<Data>, Data: Send + Sync {
478 let mut inputs: Vec<Arc<Data>> = Vec::new();
479 let input_ids = match graph.node_inputs.get(node_id) {
480 Some(id) => id,
481 None => panic!("Could not find node {} in the graph", node_id)
482 };
483 for input_id in input_ids {
484 let intermediate = match intermediates.get(input_id) {
485 Some(intermediate) => intermediate,
486 None => panic!("Node {} attempted to execute, but input {} is missing", node_id, input_id)
487 };
488 inputs.push(Arc::clone(intermediate));
489 }
490 return inputs;
491 }
492
493 let mut num_nodes_remaining = recipe.runs.len();
495 let mut intermediates: HashMap<usize, Arc<Data>> = HashMap::with_capacity(num_nodes_remaining);
497 let mut remaining_inputs_map = recipe.node_inputs.clone();
500
501 for input_node in &recipe.inputs {
503 match inputs_map.remove(input_node) {
504 Some(inputs) => {
505 let arc_inputs: Vec<Arc<Data>> = inputs.into_iter().map(|input| Arc::new(input)).collect();
506 execute_node(self, *input_node, arc_inputs);
508 },
509 None => panic!("Input for {} is missing", input_node),
510 };
511 }
512
513 while num_nodes_remaining > 0 {
515 if let Ok(wstatus) = self.pool.wstatus_receiver.recv() {
516 match wstatus {
517 WorkerStatus::Complete(node, result, node_id) => {
518 num_nodes_remaining -= 1;
521 match self.nodes.get_mut(node_id) {
523 Some(node_option) => node_option.replace(node),
524 None => panic!("Received WorkerStatus for node {}, but this node is not in the graph", node_id),
525 };
526
527 if intermediates.insert(node_id, Arc::new(result)).is_some() {
530 panic!("Node {} was executed more than once, possibly due to a cycle", node_id);
531 }
532
533 if let Some(output_ids) = recipe.node_outputs.get(&node_id) {
536 for output_id in output_ids {
539 let remaining_inputs = match remaining_inputs_map.get_mut(output_id) {
540 Some(id) => id,
541 None => panic!("Node {} is not registered in the remaining_inputs_map", output_id)
542 };
543 remaining_inputs.remove(&node_id);
544 if remaining_inputs.len() == 0 {
546 let inputs = assemble_inputs(&self, &intermediates, output_id.clone());
548 execute_node(self, *output_id, inputs);
549 }
550 } } },
553 WorkerStatus::Fail(node_id) => panic!("Graph failed to execute because node {} failed", node_id),
554 } }
556 } let mut outputs_map = HashMap::new();
560 for output in &recipe.outputs {
561 match Arc::try_unwrap(intermediates.remove(output).unwrap()) {
563 Ok(data) => { outputs_map.insert(output.clone(), data); },
564 Err(_) => panic!("Could not retrieve output for node {}", output),
565 }
566 }
567 return outputs_map;
568 } }