use std::borrow::Borrow;
use crate::threadpool::{ThreadExecute, ThreadPool, WorkerStatus};
use std::collections::{HashSet, HashMap};
use std::iter::{FromIterator};
use std::sync::Arc;
// Unit tests for `Graph` and `Recipe`.
// The expected numeric results below are consistent with `Adder` (defined in
// `crate::threadpool::tests`, not visible here) summing its inputs:
// [1, 2, 3] -> 6 at the input node, 6 + 6 = 12 at each hidden node,
// 12 + 12 = 24 at each output node. NOTE(review): confirm against `Adder`.
#[cfg(test)]
mod tests {
use std::iter::FromIterator;
use std::collections::{HashSet, HashMap};
use super::Graph;
use crate::threadpool::ThreadExecute;
use crate::threadpool::tests::{Adder};
use std::sync::Arc;
// Worker-thread count used for every test graph.
const NUM_THREADS: usize = 8;
// An empty graph can be constructed without panicking.
#[test]
fn can_construct_graph() {
let _graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
}
// Node ids are assigned sequentially starting at 0.
#[test]
fn can_add_node() {
let mut graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
let node_id = graph.add(Adder::new(), &[]);
assert_eq!(node_id, 0);
}
// A node may only take inputs from nodes that were added before it, so
// referencing its own (not yet assigned) id must panic.
#[test]
#[should_panic(expected="does not yet exist in the graph")]
fn cannot_set_node_input_to_itself() {
let mut graph: Graph<Adder, i32> = Graph::new(NUM_THREADS);
graph.add(Adder::new(), vec![0]);
}
// Builds the shared test graph:
//   input   -> hidden1 (input passed twice), hidden2 (input passed twice)
//   hidden1, hidden2 -> output1, output2
//   hidden1 -> deadend (not an ancestor of output1/output2, so it must not
//              appear in recipes compiled for those outputs)
// Returns the graph plus the ids (input, hidden1, hidden2, output1, output2).
fn build_diamond_graph() -> (Graph<Adder, i32> ,usize, usize, usize, usize, usize) {
let mut graph = Graph::new(NUM_THREADS);
let input = graph.add(Adder::new(), &[]);
let hidden1 = graph.add(Adder::new(), &[input, input]);
let hidden2 = graph.add(Adder::new(), &[input, input]);
let output1 = graph.add(Adder::new(), &[hidden1, hidden2]);
let output2 = graph.add(Adder::new(), &[hidden1, hidden2]);
let _deadend = graph.add(Adder::new(), &[hidden1]);
return (graph, input, hidden1, hidden2, output1, output2);
}
// Compiling for one output collects exactly its ancestors (excluding the
// dead-end node) and identifies `input` as the sole recipe input. Both slice
// and Vec fetch containers are exercised.
#[test]
fn can_compile_graph() {
let (graph, input, hidden1, hidden2, output1, output2) = build_diamond_graph();
println!("Graph: {:?}", graph);
let out1_recipe = graph.compile(&[output1]);
println!("Output 1 Recipe: {:?}", out1_recipe);
assert_eq!(out1_recipe.runs, HashSet::from_iter(vec![input, hidden1, hidden2, output1]));
assert_eq!(out1_recipe.inputs, HashSet::from_iter([input].iter().cloned()));
let out2_recipe = graph.compile(vec![output2]);
println!("Output 2 Recipe: {:?}", out2_recipe);
assert_eq!(out2_recipe.runs, HashSet::from_iter(vec![input, hidden1, hidden2, output2]));
assert_eq!(out2_recipe.inputs, HashSet::from_iter([input].iter().cloned()));
}
// End-to-end run fetching both outputs; the duplicated `output2` fetch checks
// that repeated fetch ids are tolerated.
#[test]
fn can_run_graph() {
let (mut graph, input, _hidden1, _hidden2, output1, output2) = build_diamond_graph();
println!("Graph: {:?}", graph);
let recipe = graph.compile(vec![output1, output2, output2]);
println!("Recipe: {:?}", recipe);
let inputs_map = HashMap::from_iter(vec!(
(input, vec![1, 2, 3])
));
let outputs = graph.run(&recipe, inputs_map);
println!("Outputs: {:?}", outputs);
assert_eq!(outputs.get(&output1), Some(&24));
assert_eq!(outputs.get(&output2), Some(&24));
}
// Fetching intermediate nodes as outputs as well: their results are both
// returned and consumed by downstream nodes.
#[test]
fn can_run_graph_all_nodes_outputs() {
let (mut graph, input, _hidden1, _hidden2, output1, output2) = build_diamond_graph();
println!("Graph: {:?}", graph);
let recipe = graph.compile(vec![input, _hidden1, _hidden2, output1, output2, output2]);
println!("Recipe: {:?}", recipe);
let inputs_map = HashMap::from_iter(vec!(
(input, vec![1, 2, 3])
));
println!("Input map: {:?}", inputs_map);
let outputs = graph.run(&recipe, inputs_map);
println!("Outputs: {:?}", outputs);
assert_eq!(outputs.get(&output1), Some(&24));
assert_eq!(outputs.get(&output2), Some(&24));
}
// A recipe whose only fetch is an input node runs just that node.
#[test]
fn can_run_graph_input_is_output() {
let (mut graph, input, _hidden1, _hidden2, _output1, _output2) = build_diamond_graph();
println!("Graph: {:?}", graph);
let recipe = graph.compile(vec![input]);
println!("Recipe: {:?}", recipe);
let inputs_map = HashMap::from_iter(vec!(
(input, vec![1, 2, 3])
));
println!("Input map: {:?}", inputs_map);
let outputs = graph.run(&recipe, inputs_map);
println!("Outputs: {:?}", outputs);
assert_eq!(outputs.get(&input), Some(&6));
}
// Fetching an input node together with one of its consumers.
#[test]
fn can_run_graph_input_and_node() {
let (mut graph, input, hidden1, _hidden2, _output1, _output2) = build_diamond_graph();
println!("Graph: {:?}", graph);
let recipe = graph.compile(vec![input, hidden1]);
println!("Recipe: {:?}", recipe);
let inputs_map = HashMap::from_iter(vec!(
(input, vec![1, 2, 3])
));
println!("Input map: {:?}", inputs_map);
let outputs = graph.run(&recipe, inputs_map);
println!("Outputs: {:?}", outputs);
assert_eq!(outputs.get(&hidden1), Some(&12));
}
// A node whose `execute` always returns `None`; presumably reported by the
// thread pool as `WorkerStatus::Fail` (TODO confirm in threadpool).
struct FailNode;
impl ThreadExecute<i32> for FailNode {
fn execute(&mut self, _inputs: Vec<Arc<i32>>) -> Option<i32> {
return None;
}
}
// A failing node must abort `Graph::run` with a panic naming the node.
#[test]
#[should_panic(expected="Graph failed to execute because node")]
fn node_failure_causes_panic() {
let mut graph = Graph::new(NUM_THREADS);
let input = graph.add(FailNode{}, &[]);
let recipe = graph.compile(&[input]);
let mut inputs_map = HashMap::new();
inputs_map.insert(input, vec!());
let _ = graph.run(&recipe, inputs_map);
}
// Consuming iteration yields every node exactly once.
// NOTE(review): `valid` is a field of `Adder`, defined outside this file.
#[test]
fn can_iterate_graph_nodes() {
let graph = build_diamond_graph().0;
let mut num_nodes = 0;
let expected_num_nodes = graph.len();
for node in graph {
assert!(node.valid);
num_nodes += 1;
}
assert_eq!(num_nodes, expected_num_nodes);
}
// Shared-reference iteration yields every node exactly once.
#[test]
fn can_iterate_ref_graph_nodes() {
let graph = build_diamond_graph().0;
let mut num_nodes = 0;
let expected_num_nodes = graph.len();
for node in &graph {
assert!(node.valid);
num_nodes += 1;
}
assert_eq!(num_nodes, expected_num_nodes);
}
// Mutable-reference iteration yields every node exactly once.
#[test]
fn can_iterate_ref_mut_graph_nodes() {
let mut graph = build_diamond_graph().0;
let mut num_nodes = 0;
let expected_num_nodes = graph.len();
for node in &mut graph {
assert!(node.valid);
num_nodes += 1;
}
assert_eq!(num_nodes, expected_num_nodes);
}
}
/// Execution plan produced by `Graph::compile` and consumed by `Graph::run`.
///
/// Every id stored here is a node id of the graph the recipe was compiled from.
#[derive(Debug)]
pub struct Recipe {
    /// All nodes that must execute: the fetched nodes plus every ancestor of them.
    runs: HashSet<usize>,
    /// Nodes in `runs` that have no graph inputs; `Graph::run` feeds them from its input map.
    pub inputs: HashSet<usize>,
    /// The fetched nodes whose results `Graph::run` hands back to the caller.
    pub outputs: HashSet<usize>,
    /// For each producer node, the set of nodes in `runs` that consume its result.
    node_outputs: HashMap<usize, HashSet<usize>>,
    /// For each node in `runs`, the set of distinct nodes it reads from.
    node_inputs: HashMap<usize, HashSet<usize>>,
}

impl Recipe {
    /// Assembles a recipe from its component sets.
    ///
    /// # Panics
    /// Panics if `inputs` is empty, because execution has to start from at least one input node.
    fn new(
        runs: HashSet<usize>,
        inputs: HashSet<usize>,
        outputs: HashSet<usize>,
        node_outputs: HashMap<usize, HashSet<usize>>,
        node_inputs: HashMap<usize, HashSet<usize>>,
    ) -> Recipe {
        assert!(
            !inputs.is_empty(),
            "Invalid Recipe: Found 0 inputs. Recipes must have at least one input node."
        );
        Recipe { runs, inputs, outputs, node_outputs, node_inputs }
    }
}
/// A directed acyclic graph of `Node`s executed on a thread pool.
///
/// Acyclicity is enforced by `Graph::add`, which only accepts inputs whose ids are
/// smaller than the id of the node being added.
#[derive(Debug)]
pub struct Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
/// Node slots indexed by node id. `Graph::run` takes a node out (leaving `None`) while it
/// is handed to the pool, and puts it back when the pool reports completion.
nodes: Vec<Option<Node>>,
/// `node_inputs[i]` holds the ids of node `i`'s inputs, in the order they are passed to the
/// node. Ids may repeat.
node_inputs: Vec<Vec<usize>>,
/// Worker pool that executes nodes and reports back through `WorkerStatus` messages.
pool: ThreadPool<Node, Data>,
}
impl<Node, Data> IntoIterator for Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
    type Item = Node;
    type IntoIter = std::iter::Map<std::vec::IntoIter<Option<Node>>, fn(Option<Node>) -> Node>;

    /// Consumes the graph and yields its nodes in id order.
    ///
    /// # Panics
    /// The iterator panics when it reaches an empty node slot.
    fn into_iter(self) -> Self::IntoIter {
        fn unwrap_slot<N>(slot: Option<N>) -> N {
            match slot {
                Some(node) => node,
                None => panic!("Node has been moved out of the graph. Is the graph being executed?"),
            }
        }
        let slots = self.nodes;
        slots.into_iter().map(unwrap_slot)
    }
}
impl<'a, Node, Data> IntoIterator for &'a Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
    type Item = &'a Node;
    type IntoIter = std::iter::Map<std::slice::Iter<'a, Option<Node>>, fn(&Option<Node>) -> &Node>;

    /// Borrows each node of the graph in id order.
    ///
    /// # Panics
    /// The iterator panics when it reaches an empty node slot.
    fn into_iter(self) -> Self::IntoIter {
        fn borrow_slot<N>(slot: &Option<N>) -> &N {
            slot.as_ref()
                .expect("Node has been moved out of the graph. Is the graph being executed?")
        }
        self.nodes.iter().map(borrow_slot)
    }
}
impl<'a, Node, Data> IntoIterator for &'a mut Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
    type Item = &'a mut Node;
    type IntoIter = std::iter::Map<std::slice::IterMut<'a, Option<Node>>, fn(&mut Option<Node>) -> &mut Node>;

    /// Mutably borrows each node of the graph in id order.
    ///
    /// # Panics
    /// The iterator panics when it reaches an empty node slot.
    fn into_iter(self) -> Self::IntoIter {
        fn borrow_slot_mut<N>(slot: &mut Option<N>) -> &mut N {
            slot.as_mut()
                .expect("Node has been moved out of the graph. Is the graph being executed?")
        }
        self.nodes.iter_mut().map(borrow_slot_mut)
    }
}
impl<Node: 'static, Data: 'static> Graph<Node, Data> where Node: ThreadExecute<Data>, Data: Send + Sync {
    /// Creates an empty graph backed by a thread pool with `num_threads` workers.
    pub fn new(num_threads: usize) -> Graph<Node, Data> {
        Graph { nodes: Vec::new(), node_inputs: Vec::new(), pool: ThreadPool::new(num_threads) }
    }

    /// Returns the number of nodes that have been added to the graph.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns node `index`, or `None` if the index is out of range or the node's slot is empty.
    ///
    /// A slot is only empty while its node is out on a worker. That can be observed here only
    /// after a `run` that panicked part-way.
    pub fn get(&self, index: usize) -> Option<&Node> {
        self.nodes.get(index).and_then(Option::as_ref)
    }

    /// Mutable counterpart of [`Graph::get`].
    pub fn get_mut(&mut self, index: usize) -> Option<&mut Node> {
        self.nodes.get_mut(index).and_then(Option::as_mut)
    }

    /// Adds `node` with the given input node ids and returns the new node's id.
    ///
    /// Ids are assigned sequentially from 0. Inputs may repeat; they are passed to the node in
    /// the order given.
    ///
    /// # Panics
    /// Panics if any input id is not already in the graph, i.e. it is `>=` the new node's id.
    /// This is what keeps the graph acyclic. The graph is left unchanged when this happens.
    pub fn add<Container, Elem>(&mut self, node: Node, inputs: Container) -> usize
    where Container: IntoIterator<Item=Elem>, Elem: Borrow<usize> {
        let node_id = self.nodes.len();
        let inputs: Vec<usize> = inputs.into_iter().map(|input| *input.borrow()).collect();
        // Validate before mutating anything, so a rejected node cannot leave `nodes` and
        // `node_inputs` with different lengths.
        if let Some(&input) = inputs.iter().find(|&&input| input >= node_id) {
            panic!("Cannot add node {} as an input to node {} as it does not yet exist in the graph.", input, node_id);
        }
        self.nodes.push(Some(node));
        self.node_inputs.push(inputs);
        node_id
    }

    /// Builds the [`Recipe`] needed to compute the nodes listed in `fetches`.
    ///
    /// The recipe covers every fetched node and all of its ancestors. Duplicate fetch ids are
    /// allowed.
    ///
    /// # Panics
    /// - Panics if a fetched id is not in the graph.
    /// - Panics (via `Recipe::new`) if `fetches` is empty, since no input node is then reachable.
    pub fn compile<Container, Elem>(&self, fetches: Container) -> Recipe
    where Container: IntoIterator<Item=Elem>, Elem: Borrow<usize> {
        let fetches: Vec<usize> = fetches.into_iter().map(|fetch| *fetch.borrow()).collect();
        let recipe_outputs: HashSet<usize> = HashSet::from_iter(fetches.iter().copied());
        let mut recipe_inputs = HashSet::new();
        let mut node_outputs: HashMap<usize, HashSet<usize>> = HashMap::new();
        let mut node_inputs: HashMap<usize, HashSet<usize>> = HashMap::new();

        // Breadth-first walk towards the inputs. `runs` doubles as the visited set, so each node
        // is expanded exactly once. Re-enqueueing shared ancestors (e.g. diamonds) would
        // otherwise grow the worklist exponentially with the depth of the graph.
        let mut runs: HashSet<usize> = HashSet::new();
        let mut queue: Vec<usize> = Vec::with_capacity(fetches.len());
        for node_id in fetches {
            if runs.insert(node_id) {
                queue.push(node_id);
            }
        }
        let mut cursor = 0;
        while cursor < queue.len() {
            let node_id = queue[cursor];
            cursor += 1;
            let inputs = self.node_inputs.get(node_id)
                .unwrap_or_else(|| panic!("Could not get node inputs for node {}", node_id));
            if inputs.is_empty() {
                recipe_inputs.insert(node_id);
            }
            node_inputs.insert(node_id, inputs.iter().copied().collect());
            for &input in inputs {
                node_outputs.entry(input).or_default().insert(node_id);
                if runs.insert(input) {
                    queue.push(input);
                }
            }
        }
        Recipe::new(runs, recipe_inputs, recipe_outputs, node_outputs, node_inputs)
    }

    /// Executes `recipe` on the thread pool and returns the results of its output nodes.
    ///
    /// `inputs_map` maps each recipe input node to the data it receives. Each other node runs
    /// as soon as all of its inputs have completed. The returned map is keyed by output node id.
    ///
    /// # Panics
    /// - Panics if `inputs_map` lacks data for any recipe input node. This is checked before any
    ///   node is dispatched.
    /// - Panics if a node reports failure.
    /// - Panics if the worker status channel disconnects before every node has completed.
    /// - Panics if an output's data is still shared when results are collected. This assumes the
    ///   pool drops each node's input `Arc`s before reporting completion (TODO confirm in
    ///   threadpool).
    pub fn run(&mut self, recipe: &Recipe, mut inputs_map: HashMap<usize, Vec<Data>>) -> HashMap<usize, Data> {
        /// Takes node `node_id` out of its slot and hands it to the pool together with `inputs`.
        fn execute_node<Node: 'static, Data: 'static>(graph: &mut Graph<Node, Data>, node_id: usize, inputs: Vec<Arc<Data>>)
        where Node: ThreadExecute<Data>, Data: Send + Sync {
            // `unwrap_or_else` + `panic!` formats the message only on failure, instead of on
            // every call as `expect(&format!(..))` did.
            let node = graph.nodes.get_mut(node_id)
                .unwrap_or_else(|| panic!("While attempting to execute, could not retrieve node {}", node_id))
                .take()
                .unwrap_or_else(|| panic!("Could not retrieve node {} - is it currently being executed?", node_id));
            graph.pool.execute(node, inputs, node_id);
        }

        /// Collects shared handles to node `node_id`'s input results, in the node's declared
        /// input order.
        fn assemble_inputs<Node, Data>(graph: &Graph<Node, Data>, intermediates: &HashMap<usize, Arc<Data>>, node_id: usize) -> Vec<Arc<Data>>
        where Node: ThreadExecute<Data>, Data: Send + Sync {
            let input_ids = graph.node_inputs.get(node_id)
                .unwrap_or_else(|| panic!("Could not find node {} in the graph", node_id));
            input_ids.iter()
                .map(|input_id| {
                    let intermediate = intermediates.get(input_id)
                        .unwrap_or_else(|| panic!("Node {} attempted to execute, but input {} is missing", node_id, input_id));
                    Arc::clone(intermediate)
                })
                .collect()
        }

        let mut num_nodes_remaining = recipe.runs.len();
        let mut intermediates: HashMap<usize, Arc<Data>> = HashMap::with_capacity(num_nodes_remaining);
        // Per node: the distinct inputs that have not completed yet.
        let mut remaining_inputs_map = recipe.node_inputs.clone();

        // Gather every input batch before dispatching anything. A missing input then panics
        // while the graph is still intact, rather than with nodes already moved out to workers
        // whose completion messages would be left pending.
        let mut input_batches: Vec<(usize, Vec<Data>)> = Vec::with_capacity(recipe.inputs.len());
        for &input_node in &recipe.inputs {
            match inputs_map.remove(&input_node) {
                Some(inputs) => input_batches.push((input_node, inputs)),
                None => panic!("Input for {} is missing", input_node),
            }
        }
        for (input_node, inputs) in input_batches {
            let arc_inputs: Vec<Arc<Data>> = inputs.into_iter().map(Arc::new).collect();
            execute_node(self, input_node, arc_inputs);
        }

        while num_nodes_remaining > 0 {
            // A receive error used to be ignored, which turned this loop into an endless busy
            // spin once the channel disconnected.
            let wstatus = match self.pool.wstatus_receiver.recv() {
                Ok(wstatus) => wstatus,
                Err(_) => panic!(
                    "Graph failed to execute because the worker status channel disconnected with {} node(s) still pending",
                    num_nodes_remaining),
            };
            match wstatus {
                WorkerStatus::Complete(node, result, node_id) => {
                    num_nodes_remaining -= 1;
                    match self.nodes.get_mut(node_id) {
                        Some(slot) => { slot.replace(node); },
                        None => panic!("Received WorkerStatus for node {}, but this node is not in the graph", node_id),
                    }
                    if intermediates.insert(node_id, Arc::new(result)).is_some() {
                        panic!("Node {} was executed more than once, possibly due to a cycle", node_id);
                    }
                    // Notify each consumer. Once a consumer's last pending input has completed,
                    // dispatch it.
                    if let Some(output_ids) = recipe.node_outputs.get(&node_id) {
                        for &output_id in output_ids {
                            let remaining_inputs = remaining_inputs_map.get_mut(&output_id)
                                .unwrap_or_else(|| panic!("Node {} is not registered in the remaining_inputs_map", output_id));
                            remaining_inputs.remove(&node_id);
                            if remaining_inputs.is_empty() {
                                let inputs = assemble_inputs(&*self, &intermediates, output_id);
                                execute_node(self, output_id, inputs);
                            }
                        }
                    }
                },
                WorkerStatus::Fail(node_id) => panic!("Graph failed to execute because node {} failed", node_id),
            }
        }

        let mut outputs_map = HashMap::with_capacity(recipe.outputs.len());
        for &output in &recipe.outputs {
            let intermediate = intermediates.remove(&output)
                .unwrap_or_else(|| panic!("Could not retrieve output for node {}", output));
            match Arc::try_unwrap(intermediate) {
                Ok(data) => { outputs_map.insert(output, data); },
                Err(_) => panic!("Could not retrieve output for node {}", output),
            }
        }
        outputs_map
    }
}