use std::prelude::v1::*;
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use serde::Serialize;
use crate::pipeline::{DatasetEvent, DatasetRef, RunnableStep, Steps};
use crate::error::PondError;
use crate::graph::build_pipeline_graph;
use crate::hooks::Hooks;
use super::Runner;
/// Runner that executes independent pipeline nodes concurrently on a rayon
/// thread pool.
///
/// The [`Default`] value has `num_threads == 0`.
#[derive(Debug, Clone, Default)]
pub struct ParallelRunner {
    /// Number of worker threads requested from rayon's `ThreadPoolBuilder`.
    ///
    /// Per rayon's documentation, `0` lets rayon pick the thread count
    /// automatically.
    pub num_threads: usize,
}

impl ParallelRunner {
    /// Creates a runner that requests `num_threads` worker threads.
    #[must_use]
    pub fn new(num_threads: usize) -> Self {
        Self { num_threads }
    }
}
/// Appends `item` to `items`, then recursively appends every step nested
/// beneath it, in pre-order (a parent always precedes its children).
///
/// Leaf steps contribute only themselves; their children are never visited.
fn collect_items<'a, E>(items: &mut Vec<&'a dyn RunnableStep<E>>, item: &'a dyn RunnableStep<E>) {
    items.push(item);
    if item.is_leaf() {
        return;
    }
    item.for_each_child_step(&mut |nested| collect_items(items, nested));
}
impl Runner for ParallelRunner {
    /// Identifier of this runner implementation.
    fn name(&self) -> &'static str {
        "parallel"
    }

    /// Runs every node of `pipe`, starting each node as soon as all of its
    /// input datasets have been produced, and executing ready nodes
    /// concurrently on rayon's global thread pool.
    ///
    /// Hooks are notified before/after each node and each nested pipeline, and
    /// on errors (for the failing node and every enclosing pipeline).
    ///
    /// # Errors
    ///
    /// Returns the first error recorded by a failing node. Once an error is
    /// observed no further nodes are started; nodes already in flight are
    /// allowed to finish.
    fn run<E>(
        &self,
        pipe: &impl Steps<E>,
        catalog: &impl Serialize,
        params: &impl Serialize,
        hooks: &impl Hooks,
    ) -> Result<(), E>
    where
        E: From<PondError> + Send + Sync + core::fmt::Display + core::fmt::Debug + 'static,
    {
        // Best effort: the result is deliberately discarded, so a failure to
        // install the global pool (e.g. one already exists) is not an error.
        rayon::ThreadPoolBuilder::new()
            .num_threads(self.num_threads)
            .build_global()
            .ok();

        let graph = build_pipeline_graph(pipe, catalog, params);
        if graph.node_indices.is_empty() {
            return Ok(());
        }

        // Flatten the pipeline tree (pre-order) so steps can be looked up by
        // graph node index.
        // NOTE(review): this assumes `graph.nodes` uses the same pre-order
        // numbering as `collect_items` — confirm against `build_pipeline_graph`.
        let mut callable_items: Vec<&dyn RunnableStep<E>> = Vec::new();
        pipe.for_each_item(&mut |item| {
            collect_items(&mut callable_items, item);
        });

        // One "already spawned" flag per runnable node.
        let started: Vec<AtomicBool> = graph
            .node_indices
            .iter()
            .map(|_| AtomicBool::new(false))
            .collect();

        // Graph nodes flagged `is_pipe` only drive the pipeline-level hooks.
        let pipe_indices: Vec<usize> = graph
            .nodes
            .iter()
            .enumerate()
            .filter(|(_, n)| n.is_pipe)
            .map(|(i, _)| i)
            .collect();
        let pipe_started: Vec<AtomicBool> =
            pipe_indices.iter().map(|_| AtomicBool::new(false)).collect();
        let pipe_completed: Vec<AtomicBool> =
            pipe_indices.iter().map(|_| AtomicBool::new(false)).collect();

        // Dataset ids that are available; seeded with the graph's source datasets.
        let produced = Mutex::new(graph.source_datasets.clone());
        let first_error: Mutex<Option<E>> = Mutex::new(None);
        let has_error = AtomicBool::new(false);

        // `in_place_scope` keeps this scheduling loop on the calling thread and
        // sends only the spawned node tasks to the pool. With `rayon::scope` the
        // loop itself would occupy a pool worker while busy-waiting, so a
        // one-thread pool could never run a spawned node and any pipeline that
        // needs a second scheduling pass would hang.
        // NOTE(review): if `run` is itself invoked from a rayon worker thread,
        // the loop still occupies that worker.
        rayon::in_place_scope(|s| {
            // NOTE(review): the loop exits only on error or once every node has
            // been started; a node whose inputs are never produced keeps it
            // spinning.
            loop {
                if has_error.load(Ordering::Acquire) {
                    break;
                }

                // Work on a snapshot so the lock is not held while scheduling.
                let produced_snapshot: HashSet<_> = produced.lock().unwrap().clone();

                // Pipeline-level hooks: "before" once all inputs exist, "after"
                // once all outputs exist.
                for (pi, &pipe_idx) in pipe_indices.iter().enumerate() {
                    let pipe_node = &graph.nodes[pipe_idx];
                    if !pipe_started[pi].load(Ordering::Acquire)
                        && pipe_node.inputs.iter().all(|d| produced_snapshot.contains(&d.id))
                    {
                        pipe_started[pi].store(true, Ordering::Release);
                        hooks.for_each_hook(&mut |h| h.before_pipeline_run(pipe_node.item));
                    }
                    if pipe_started[pi].load(Ordering::Acquire)
                        && !pipe_completed[pi].load(Ordering::Acquire)
                        && pipe_node.outputs.iter().all(|d| produced_snapshot.contains(&d.id))
                    {
                        pipe_completed[pi].store(true, Ordering::Release);
                        hooks.for_each_hook(&mut |h| h.after_pipeline_run(pipe_node.item));
                    }
                }

                // Spawn every not-yet-started node whose inputs are all available.
                let mut any_started = false;
                for (si, &node_idx) in graph.node_indices.iter().enumerate() {
                    if started[si].load(Ordering::Acquire) {
                        continue;
                    }
                    let node = &graph.nodes[node_idx];
                    if node.inputs.iter().all(|d| produced_snapshot.contains(&d.id)) {
                        started[si].store(true, Ordering::Release);
                        any_started = true;

                        // Re-bind as references so the `move` closure captures
                        // borrows rather than the owning values.
                        let produced = &produced;
                        let output_ids: Vec<usize> = node.outputs.iter().map(|d| d.id).collect();
                        let item = callable_items[node_idx];
                        let first_error = &first_error;
                        let has_error = &has_error;
                        let graph_nodes = &graph.nodes;
                        let names = &graph.dataset_names;

                        s.spawn(move |_| {
                            hooks.for_each_hook(&mut |h| h.before_node_run(item));
                            let mut on_event = |ds: &DatasetRef<'_>, event: DatasetEvent| {
                                super::dispatch_dataset_event(item, ds, event, names, hooks);
                            };
                            match item.call(&mut on_event) {
                                Ok(()) => {
                                    hooks.for_each_hook(&mut |h| h.after_node_run(item));
                                    // Publishing the outputs is what unblocks
                                    // downstream nodes.
                                    produced.lock().unwrap().extend(output_ids);
                                }
                                Err(e) => {
                                    let msg = e.to_string();
                                    hooks.for_each_hook(&mut |h| h.on_node_error(item, &msg));
                                    // Report the failure to every enclosing
                                    // pipeline, innermost first.
                                    let mut parent = graph_nodes[node_idx].parent_pipe;
                                    while let Some(pipe_idx) = parent {
                                        let pipe = &graph_nodes[pipe_idx];
                                        hooks.for_each_hook(&mut |h| {
                                            h.on_pipeline_error(pipe.item, &msg)
                                        });
                                        parent = pipe.parent_pipe;
                                    }
                                    // Keep only the first error that gets recorded.
                                    let mut guard = first_error.lock().unwrap();
                                    if guard.is_none() {
                                        *guard = Some(e);
                                    }
                                    has_error.store(true, Ordering::Release);
                                }
                            }
                        });
                    }
                }

                if started.iter().all(|s| s.load(Ordering::Acquire)) {
                    break;
                }
                if !any_started {
                    // Nothing became ready this pass; let other threads make
                    // progress before polling again.
                    std::thread::yield_now();
                }
            }
        });

        // The scope has joined all spawned nodes. Fire `after_pipeline_run` for
        // pipelines whose outputs were completed after the loop's last pass.
        {
            let produced_snapshot = produced.lock().unwrap();
            for (pi, &pipe_idx) in pipe_indices.iter().enumerate() {
                let pipe_node = &graph.nodes[pipe_idx];
                if pipe_started[pi].load(Ordering::Acquire)
                    && !pipe_completed[pi].load(Ordering::Acquire)
                    && pipe_node.outputs.iter().all(|d| produced_snapshot.contains(&d.id))
                {
                    hooks.for_each_hook(&mut |h| h.after_pipeline_run(pipe_node.item));
                }
            }
        }

        match first_error.into_inner().unwrap() {
            Some(e) => Err(e),
            None => Ok(()),
        }
    }
}