use std::prelude::v1::*;
use std::collections::{HashMap, HashSet};
use serde::Serialize;
use crate::catalog_indexer::index_catalog_with_params;
use crate::pipeline::{StepInfo, PipelineInfo, ptr_to_id};
use super::types::{Edge, GraphNode, PipelineGraph};
/// Builds a [`PipelineGraph`] describing `pipe`.
///
/// * `nodes` holds every step reachable from `pipe` in pre-order. This includes
///   nested pipeline containers (`is_pipe == true`) as well as leaf steps.
/// * `edges` are the data-flow edges between leaf steps (see `build_edges`).
/// * `node_indices` lists the positions in `nodes` of the leaf steps only.
/// * `source_datasets` holds the ids of datasets that some leaf step reads but
///   no leaf step writes.
/// * `dataset_names` is the table produced by `index_catalog_with_params`
///   for `catalog` and `params`. The tests look it up by dataset id.
pub fn build_pipeline_graph<'a>(
    pipe: &'a impl PipelineInfo,
    catalog: &impl Serialize,
    params: &impl Serialize,
) -> PipelineGraph<'a> {
    let dataset_names = index_catalog_with_params(catalog, params).into_inner();

    let mut nodes: Vec<GraphNode<'a>> = Vec::new();
    pipe.for_each_info(&mut |step| collect_node(step, None, &mut nodes));

    let edges = build_edges(&nodes);

    // One pass over the leaf steps gathers their positions and the ids of
    // every dataset they read or write. Pipeline containers are skipped.
    let mut node_indices: Vec<usize> = Vec::new();
    let mut produced: HashSet<usize> = HashSet::new();
    let mut consumed: HashSet<usize> = HashSet::new();
    for (position, node) in nodes.iter().enumerate() {
        if node.is_pipe {
            continue;
        }
        node_indices.push(position);
        produced.extend(node.outputs.iter().map(|dataset| dataset.id));
        consumed.extend(node.inputs.iter().map(|dataset| dataset.id));
    }

    // A dataset is a source when it is read but never written by a leaf step.
    let source_datasets = consumed.difference(&produced).copied().collect();

    PipelineGraph {
        nodes,
        edges,
        node_indices,
        source_datasets,
        dataset_names,
    }
}
/// Appends `item` to `nodes` and, when it is a pipeline, recursively appends
/// all of its descendants after it (pre-order).
///
/// The new node stores `parent` as its `parent_pipe`. When `parent` is `Some`,
/// the new node's index is also appended to that parent's `pipe_children`.
/// Input and output dataset descriptors are cloned from the step through
/// `for_each_input` / `for_each_output`.
fn collect_node<'a>(
    item: &'a dyn StepInfo,
    parent: Option<usize>,
    nodes: &mut Vec<GraphNode<'a>>,
) {
    let own_index = nodes.len();
    let is_pipe = !item.is_leaf();

    let mut consumed = Vec::new();
    item.for_each_input(&mut |dataset| consumed.push(dataset.clone()));
    let mut produced = Vec::new();
    item.for_each_output(&mut |dataset| produced.push(dataset.clone()));

    let node = GraphNode {
        id: ptr_to_id(item),
        name: item.name(),
        is_pipe,
        inputs: consumed,
        outputs: produced,
        pipe_children: Vec::new(),
        parent_pipe: parent,
        item,
    };
    nodes.push(node);

    // The parent was pushed before recursing into its children, so its index
    // is already valid here.
    if let Some(parent_index) = parent {
        nodes[parent_index].pipe_children.push(own_index);
    }

    // Leaf steps have nothing to descend into.
    if !is_pipe {
        return;
    }
    item.for_each_child(&mut |child| collect_node(child, Some(own_index), nodes));
}
/// Derives the data-flow edges between leaf (non-pipeline) nodes.
///
/// For each dataset a leaf node reads, an [`Edge`] is emitted from the leaf
/// node that writes that dataset (matched by dataset `id`) to the reader. The
/// edge is labelled with a clone of the reader's input descriptor. Inputs that
/// no leaf node writes produce no edge. If several leaf nodes write the same
/// dataset, the one appearing last in `nodes` is treated as its producer.
/// Pipeline container nodes are ignored on both ends. Edges come out in node
/// order, then input order.
fn build_edges<'a>(nodes: &'_ [GraphNode<'a>]) -> Vec<Edge<'a>> {
    // dataset id -> index of the leaf node writing it. When the HashMap is
    // collected, later entries overwrite earlier ones, giving last-writer-wins.
    let producer_of: HashMap<usize, usize> = nodes
        .iter()
        .enumerate()
        .filter(|(_, node)| !node.is_pipe)
        .flat_map(|(idx, node)| node.outputs.iter().map(move |out| (out.id, idx)))
        .collect();
    let producer_of = &producer_of;

    nodes
        .iter()
        .enumerate()
        .filter(|(_, node)| !node.is_pipe)
        .flat_map(|(consumer, node)| {
            node.inputs.iter().filter_map(move |input| {
                producer_of.get(&input.id).map(|&producer| Edge {
                    from_node: producer,
                    to_node: consumer,
                    dataset: input.clone(),
                })
            })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    //! Tests for `build_pipeline_graph` over linear, diamond, nested and
    //! sink-only pipeline shapes.
    use super::*;
    use crate::datasets::{MemoryDataset, Param};
    use crate::pipeline::{Node, Pipeline};
    use serde::Serialize;

    #[derive(Serialize)]
    struct TestCatalog {
        a: MemoryDataset<i32>,
        b: MemoryDataset<i32>,
        c: MemoryDataset<i32>,
    }

    #[derive(Serialize)]
    struct TestParams {
        initial_value: Param<i32>,
    }

    /// Fresh catalog (three empty in-memory datasets) and params shared by every test.
    fn fixtures() -> (TestCatalog, TestParams) {
        let catalog = TestCatalog {
            a: MemoryDataset::new(),
            b: MemoryDataset::new(),
            c: MemoryDataset::new(),
        };
        let params = TestParams {
            initial_value: Param(1),
        };
        (catalog, params)
    }

    #[test]
    fn test_linear_pipeline() {
        let (catalog, params) = fixtures();
        let pipe = (
            Node {
                name: "n1",
                func: |v| (v,),
                input: (&params.initial_value,),
                output: (&catalog.a,),
            },
            Node {
                name: "n2",
                func: |v| (v,),
                input: (&catalog.a,),
                output: (&catalog.b,),
            },
            Node {
                name: "n3",
                func: |v| (v,),
                input: (&catalog.b,),
                output: (&catalog.c,),
            },
        );
        let graph = build_pipeline_graph(&pipe, &catalog, &params);
        assert_eq!(graph.nodes.len(), 3);
        assert_eq!(graph.node_indices.len(), 3);
        assert!(graph.nodes.iter().all(|n| !n.is_pipe));
        assert_eq!(graph.nodes[0].name, "n1");
        assert_eq!(graph.nodes[1].name, "n2");
        assert_eq!(graph.nodes[2].name, "n3");
        assert_eq!(graph.edges.len(), 2);
        assert_eq!(graph.edges[0].from_node, 0);
        assert_eq!(graph.edges[0].to_node, 1);
        assert_eq!(graph.edges[1].from_node, 1);
        assert_eq!(graph.edges[1].to_node, 2);
        assert_eq!(graph.source_datasets.len(), 1);
        let source_id = *graph.source_datasets.iter().next().unwrap();
        assert_eq!(
            graph.dataset_names.get(&source_id).map(|s| s.as_str()),
            Some("params.initial_value")
        );
        assert!(graph.dataset_names.values().any(|n| n == "catalog.a"));
        assert!(graph.dataset_names.values().any(|n| n == "catalog.b"));
        assert!(graph.dataset_names.values().any(|n| n == "catalog.c"));
        assert!(graph.nodes[0].inputs[0].meta.is_param());
        assert!(!graph.nodes[1].inputs[0].meta.is_param());
    }

    #[test]
    fn test_diamond_pipeline() {
        let (catalog, params) = fixtures();
        let pipe = (
            Node {
                name: "n1",
                func: |v| (v,),
                input: (&params.initial_value,),
                output: (&catalog.a,),
            },
            Node {
                name: "n2",
                func: |v| (v,),
                input: (&params.initial_value,),
                output: (&catalog.b,),
            },
            Node {
                name: "n3",
                func: |a, b| (a + b,),
                input: (&catalog.a, &catalog.b),
                output: (&catalog.c,),
            },
        );
        let graph = build_pipeline_graph(&pipe, &catalog, &params);
        assert_eq!(graph.nodes.len(), 3);
        assert_eq!(graph.edges.len(), 2);
        assert!(graph.edges.iter().all(|e| e.to_node == 2));
        let from_nodes: HashSet<_> = graph.edges.iter().map(|e| e.from_node).collect();
        assert!(from_nodes.contains(&0));
        assert!(from_nodes.contains(&1));
        // The shared parameter is read twice but never written: one source.
        assert_eq!(graph.source_datasets.len(), 1);
    }

    #[test]
    fn test_nested_pipeline() {
        let (catalog, params) = fixtures();
        let pipe = (
            Node {
                name: "n1",
                func: |v| (v,),
                input: (&params.initial_value,),
                output: (&catalog.a,),
            },
            Pipeline {
                name: "inner",
                steps: (
                    Node {
                        name: "n2",
                        func: |v| (v,),
                        input: (&catalog.a,),
                        output: (&catalog.b,),
                    },
                    Node {
                        name: "n3",
                        func: |v| (v,),
                        input: (&catalog.b,),
                        output: (&catalog.c,),
                    },
                ),
                input: (&catalog.a,),
                output: (&catalog.c,),
            },
        );
        let graph = build_pipeline_graph(&pipe, &catalog, &params);
        assert_eq!(graph.nodes.len(), 4);
        assert_eq!(graph.node_indices.len(), 3);
        // Leaf steps only; index 1 is the `inner` pipeline container.
        assert_eq!(graph.node_indices, vec![0, 2, 3]);
        assert!(graph.nodes[0].parent_pipe.is_none());
        let inner = &graph.nodes[1];
        assert_eq!(inner.name, "inner");
        assert!(inner.is_pipe);
        assert_eq!(inner.pipe_children.len(), 2);
        assert_eq!(inner.pipe_children, vec![2, 3]);
        assert!(inner.parent_pipe.is_none());
        let n2 = &graph.nodes[2];
        let n3 = &graph.nodes[3];
        assert_eq!(n2.parent_pipe, Some(1));
        assert_eq!(n3.parent_pipe, Some(1));
        // Edges connect leaf steps directly, bypassing the container.
        assert_eq!(graph.edges.len(), 2);
        assert_eq!((graph.edges[0].from_node, graph.edges[0].to_node), (0, 2));
        assert_eq!((graph.edges[1].from_node, graph.edges[1].to_node), (2, 3));
    }

    #[test]
    fn test_no_output_node() {
        let (catalog, params) = fixtures();
        let pipe = (
            Node {
                name: "n1",
                func: |v| (v,),
                input: (&params.initial_value,),
                output: (&catalog.a,),
            },
            Node {
                name: "n2",
                func: |v| println!("{v}"),
                input: (&catalog.a,),
                output: (),
            },
        );
        let graph = build_pipeline_graph(&pipe, &catalog, &params);
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.nodes[1].outputs.len(), 0);
    }
}