use std::collections::{HashMap, HashSet};
use tokio::{select, task::JoinSet};
use tracing::info;
use crate::{
error::Error,
job::Job,
mermaid::mermaid,
node::{Node, NodeBuilder, Payload},
outputs::Outputs,
validation::{validate_job, validate_nodes},
};
/// Marker trait for the user-supplied context value handed to every node.
///
/// Blanket-implemented below for any `Clone + Send + 'static` type, so
/// callers never implement it by hand.
pub trait Ctx: Clone + Send + 'static {}
impl<T> Ctx for T where T: Clone + Send + 'static {}
/// Marker trait for the user-supplied error type produced by nodes.
///
/// Blanket-implemented below for any cloneable, sendable `std::error::Error`.
pub trait Er: Clone + Send + 'static + std::error::Error {}
impl<T> Er for T where T: Clone + Send + 'static + std::error::Error {}
/// Builds a [`graph::Graph`] from a comma-separated list of node types.
///
/// Expands to `Graph::builder()` followed by one `add_node::<T>()` per
/// listed type and a final `build()`; the whole invocation evaluates to
/// the `Result` returned by `build()`. A trailing comma is accepted.
#[macro_export]
macro_rules! build_graph {
    ( $( $ty:ty ),* $(,)? ) => {{
        let mut graph = $crate::graph::Graph::builder();
        $(
            graph.add_node::<$ty>();
        )*
        graph.build()
    }};
}
/// Accumulates nodes prior to validation; obtained via [`Graph::builder`].
pub struct Builder<C: Ctx, E: Er> {
    /// Nodes registered so far; may contain duplicate ids until `build`
    /// deduplicates them.
    nodes: Vec<Node<C, E>>,
}
impl<C: Ctx, E: Er> Builder<C, E> {
    /// Registers the node produced by `T`'s [`NodeBuilder`] implementation.
    pub fn add_node<T: NodeBuilder<C, E>>(&mut self) {
        let node = T::node();
        self.nodes.push(node);
    }
    /// Chainable (by-value) variant of [`Self::add_node`].
    #[must_use]
    pub fn with_node<T: NodeBuilder<C, E>>(mut self) -> Self {
        self.add_node::<T>();
        self
    }
    /// Finalizes the builder: drops duplicate node ids (keeping the first
    /// registration of each), sorts the survivors by id, validates the
    /// resulting graph, and returns it.
    ///
    /// # Errors
    /// Propagates any validation error reported by [`validate_nodes`].
    pub fn build(self) -> Result<Graph<C, E>, Error<E>> {
        let mut seen_ids = HashSet::new();
        // Keep only the first occurrence of each id, then order by id so
        // later lookups can binary-search on it.
        let mut unique: Vec<_> = self
            .nodes
            .into_iter()
            .filter(|node| seen_ids.insert(node.id))
            .collect();
        unique.sort_by_key(|node| node.id);
        let adj = validate_nodes(&unique)?;
        Ok(Graph { nodes: unique, adj })
    }
}
/// A validated graph of nodes, ready to execute jobs.
pub struct Graph<C: Ctx, E: Er> {
    /// Nodes sorted by id (established in `Builder::build`), addressed by
    /// index everywhere inside this module.
    pub(crate) nodes: Vec<Node<C, E>>,
    /// Dependency lists: `adj[i]` holds the indices of the nodes whose
    /// outputs feed node `i` (see its use in `execute`).
    pub(crate) adj: Vec<Vec<usize>>,
}
impl<C: Ctx, E: Er> Graph<C, E> {
    /// Creates an empty [`Builder`].
    #[must_use]
    pub fn builder() -> Builder<C, E> {
        Builder { nodes: vec![] }
    }
    /// Returns the static name of the node at index `i`.
    ///
    /// # Panics
    /// Panics if `i` is out of bounds for `self.nodes`.
    #[must_use]
    pub fn node_name(&self, i: usize) -> &'static str {
        self.nodes[i].name
    }
    /// Checks that `job` is consistent with this graph's nodes.
    ///
    /// # Errors
    /// Propagates any error from [`validate_job`].
    pub fn validate_job(&self, job: &Job<C, E>) -> Result<(), Error<E>> {
        validate_job(&self.nodes, job)
    }
    /// Re-keys index-keyed execution results by node id and wraps them in
    /// [`Outputs`]; used for both success and failure returns of `execute`.
    fn outputs(&self, results: HashMap<usize, Payload>) -> Outputs {
        let id_to_payload = results
            .into_iter()
            .map(|(i, payload)| (self.nodes[i].id, payload))
            .collect();
        Outputs::new(id_to_payload)
    }
    /// Renders this graph, annotated with `job`, as a Mermaid diagram.
    #[must_use]
    pub fn mermaid(&self, job: &Job<C, E>) -> String {
        mermaid(self, job)
    }
    /// Executes `job` against this graph, spawning each node onto a tokio
    /// [`JoinSet`] as soon as all of its dependencies have a payload.
    ///
    /// Results are accumulated in a map keyed by node *index*; seed
    /// payloads come from `job.inputs`, which is keyed by node *id* and
    /// translated to indices via binary search (valid because `nodes` is
    /// sorted by id in `Builder::build`).
    ///
    /// # Errors
    /// Returns `Error::NodeFailed` if a node's `execute` returns an error,
    /// `Error::NodePanic` if a spawned task fails to join, and
    /// `Error::Cancelled` if the job's cancellation token fires; each
    /// variant carries the outputs gathered so far.
    pub async fn execute(&self, job: Job<C, E>, ctx: C) -> Result<Outputs, Error<E>> {
        let mut handles = JoinSet::new();
        // NOTE(review): assumed to be the set of node indices this job still
        // needs to run — confirm against Job::pending.
        let mut pending = job.pending(self);
        // Seed results from the job's inputs. Inputs whose id has no
        // matching node are silently dropped here; presumably validate_job
        // rejects them up front — confirm.
        let mut results = job
            .inputs
            .into_iter()
            .filter_map(|(id, payload)| {
                let i = self.nodes.binary_search_by_key(&id, |n| n.id).ok()?;
                Some((i, payload))
            })
            .collect::<HashMap<_, _>>();
        info!(count = pending.len(), "Job start");
        loop {
            // A node is ready when every dependency index already has a
            // result (seeded input or completed execution).
            let is_done = |i| results.contains_key(i);
            let ready = pending.extract_if(|i| self.adj[*i].iter().all(is_done));
            for i in ready {
                info!(node = self.node_name(i), "Node start");
                let node = self.nodes[i].clone();
                // Gather dependency payloads in adj order and let the node
                // fold them into its own input payload.
                let payloads = self.adj[i].iter().map(|i| &results[i]).collect();
                let payload = (node.prepare)(payloads);
                let ctx = ctx.clone();
                // Tag each task with its index so completions can be routed.
                handles.spawn(async move { (i, (node.execute)(ctx, payload).await) });
            }
            select! {
                res = handles.join_next() => {
                    match res {
                        Some(Ok((i, Ok(r)))) => {
                            info!(node = self.node_name(i), "Node done");
                            results.insert(i, r);
                        }
                        Some(Ok((i, Err(error)))) => {
                            // Node ran but returned its own error. Dropping
                            // `handles` on return aborts in-flight siblings
                            // (JoinSet aborts tasks on drop).
                            info!(node = self.node_name(i), ?error, "Node failed");
                            let outputs = self.outputs(results);
                            return Err(Error::NodeFailed { outputs, i, error });
                        }
                        Some(Err(error)) => {
                            // Join error: the task panicked or was aborted.
                            info!(?error, "Node panicked");
                            let outputs = self.outputs(results);
                            return Err(Error::NodePanic { outputs, error });
                        }
                        None => {
                            // JoinSet drained with nothing newly ready — the
                            // job is complete. NOTE(review): this also fires
                            // if `pending` is non-empty but nothing can ever
                            // become ready; presumably graph/job validation
                            // rules that out — confirm.
                            info!("Job done");
                            let outputs = self.outputs(results);
                            return Ok(outputs);
                        }
                    }
                }
                () = job.cancellation_token.cancelled() => {
                    info!("Job cancelled");
                    let outputs = self.outputs(results);
                    return Err(Error::Cancelled { outputs });
                }
            }
        }
    }
}