use std::any::Any;
use std::collections::{HashMap, HashSet, VecDeque};
use std::future::Future;
use std::hash::{BuildHasher, Hasher};
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use futures_util::future::BoxFuture;
use futures_util::{stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt};
#[cfg(feature = "tracing")]
use tracing::{debug, error, info, trace};
use crate::builder::{NodeId, TaskWire};
use crate::error::{DagError, DagResult};
use crate::node::{ExecutableNode, TypedNode};
use crate::DagOutput;
/// A [`Hasher`] whose hash is simply the last `u32` written to it.
///
/// Intended for maps keyed by small integer ids that hash themselves through a
/// single `write_u32` call. Every other `write*` path funnels into [`Hasher::write`],
/// which panics, so using this hasher with any other key type fails loudly.
#[derive(Clone, Default)]
pub(crate) struct PassThroughHasher {
    /// Value reported by `finish`; zero until `write_u32` is called.
    hash: u64,
}

impl Hasher for PassThroughHasher {
    fn write_u32(&mut self, value: u32) {
        // Each call overwrites the previous value instead of mixing into it.
        self.hash = u64::from(value);
    }

    fn write(&mut self, _bytes: &[u8]) {
        panic!("PassThroughHasher used on invalid type");
    }

    fn finish(&self) -> u64 {
        self.hash
    }
}

/// The hasher doubles as its own builder: every built hasher starts from zero.
impl BuildHasher for PassThroughHasher {
    type Hasher = Self;

    fn build_hasher(&self) -> Self {
        Self::default()
    }
}

/// A `HashMap` that hashes its (`u32`-hashing) keys with [`PassThroughHasher`].
pub(crate) type PassThroughHashMap<K, V> = HashMap<K, V, PassThroughHasher>;
/// A graph of tasks that is executed layer by layer by [`DagRunner::run`].
///
/// Tasks are registered with [`DagRunner::add_task`]; each one is identified by a
/// `NodeId` equal to its index in `nodes`.
// NOTE(review): acyclicity is assumed, not enforced here — presumably the builder API
// only lets a task depend on previously added tasks; confirm.
pub struct DagRunner {
    /// Type-erased tasks indexed by `NodeId`. `run` takes each entry out (leaving
    /// `None`) when it executes that task.
    pub(crate) nodes: Vec<Option<Box<dyn ExecutableNode + Sync>>>,
    /// For each node, the nodes it depends on, in the order their outputs are passed
    /// to the task.
    pub(crate) edges: PassThroughHashMap<NodeId, Vec<NodeId>>,
    /// For each node, the nodes that depend on it. Expected to mirror `edges`
    /// (populated outside this file — verify in the builder).
    pub(crate) dependents: PassThroughHashMap<NodeId, Vec<NodeId>>,
}
impl Default for DagRunner {
fn default() -> Self {
Self::new()
}
}
impl DagRunner {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
edges: HashMap::default(),
dependents: HashMap::default(),
}
}
pub fn add_task<'dag, Input, Tk>(&'dag mut self, task: Tk) -> Tk::Retval<'dag>
where
Tk: TaskWire<Input>,
Input: Send + Sync + 'static,
{
let id = NodeId(self.nodes.len() as u32);
#[cfg(feature = "tracing")]
debug!(
task_id = id.0,
task_type = std::any::type_name::<Tk>(),
"adding task to DAG"
);
let node = TypedNode::new(task);
self.nodes.push(Some(Box::new(node)));
self.edges.insert(id, Vec::new());
self.dependents.insert(id, Vec::new());
Tk::new_from_dag(id, self)
}
#[inline]
#[cfg_attr(feature = "tracing", tracing::instrument(skip(self, spawner)))]
pub async fn run<S, F>(mut self, spawner: S) -> DagResult<DagOutput>
where
S: Fn(BoxFuture<'static, DagResult<Arc<dyn Any + Send + Sync>>>) -> F,
F: Future<Output = DagResult<Arc<dyn Any + Send + Sync>>>,
{
#[cfg(feature = "tracing")]
info!("starting DAG execution");
let layers = self.compute_layers()?;
let total_tasks = layers.iter().map(|l| l.len()).sum::<usize>();
#[cfg(feature = "tracing")]
debug!(
layer_count = layers.len(),
total_tasks, "computed topological layers"
);
let mut outputs: PassThroughHashMap<NodeId, Arc<dyn Any + Send + Sync>> =
HashMap::with_capacity_and_hasher(total_tasks, PassThroughHasher::default());
let mut first_error = None;
for layer in layers {
#[cfg(feature = "tracing")]
{
debug!(task_count = layer.len(), "executing layer");
}
if layer.len() == 1 {
let node_id = layer[0];
#[cfg(feature = "tracing")]
trace!(
task_id = node_id.0,
"executing task inline (single-task layer optimization)"
);
let node = self.nodes[node_id.0 as usize].take();
if let Some(node) = node {
let dependencies: Vec<_> = self.edges[&node_id]
.iter()
.flat_map(|dep| outputs.get(dep))
.cloned()
.collect();
let result = AssertUnwindSafe(node.execute_with_deps(dependencies))
.catch_unwind()
.await
.unwrap_or_else(|panic_payload| {
let panic_message =
if let Some(s) = panic_payload.downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) = panic_payload.downcast_ref::<String>() {
s.clone()
} else {
"unknown panic".to_string()
};
#[cfg(feature = "tracing")]
error!(
task_id = node_id.0,
panic_message = %panic_message,
"task panicked during inline execution"
);
Err(DagError::TaskPanicked {
task_id: node_id.0,
panic_message,
})
});
match result {
Ok(output) => {
outputs.insert(node_id, output);
}
Err(e) => {
first_error.get_or_insert(e);
}
}
}
} else {
let mut futures: FuturesUnordered<_> = layer
.into_iter()
.filter_map(|node_id| {
#[cfg(feature = "tracing")]
trace!(task_id = node_id.0, "spawning task");
let node = self.nodes[node_id.0 as usize].take();
if let Some(node) = node {
let dependencies: Vec<_> = self.edges[&node_id]
.iter()
.flat_map(|dep| outputs.get(dep))
.cloned()
.collect();
let inner_future = spawner(node.execute_with_deps(dependencies));
let inner_future = async move {
let result = inner_future.await?;
Ok((node_id, result))
};
Some(
AssertUnwindSafe(inner_future)
.catch_unwind()
.unwrap_or_else(move |panic_payload| {
let panic_message =
if let Some(s) = panic_payload.downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) =
panic_payload.downcast_ref::<String>()
{
s.clone()
} else {
"unknown panic".to_string()
};
#[cfg(feature = "tracing")]
error!(
task_id = node_id.0,
panic_message = %panic_message,
"task panicked during inline execution"
);
Err(DagError::TaskPanicked {
task_id: node_id.0,
panic_message,
})
}),
)
} else {
None
}
})
.collect();
while let Some(out) = futures.next().await {
match out {
Ok(output) => {
outputs.insert(output.0, output.1);
}
Err(e) => {
first_error.get_or_insert(e);
}
}
}
}
if let Some(err) = first_error {
#[cfg(feature = "tracing")]
error!(?err, "DAG execution failed");
return Err(err);
}
}
#[cfg(feature = "tracing")]
info!("DAG execution completed successfully");
Ok(DagOutput::new(outputs))
}
fn compute_layers(&self) -> DagResult<Vec<Vec<NodeId>>> {
#[cfg(feature = "tracing")]
debug!("computing topological layers");
let mut in_degree: PassThroughHashMap<NodeId, usize> = HashMap::default();
let mut layers = Vec::new();
for (&node, deps) in self.edges.iter() {
let degree = deps.len();
in_degree.insert(node, degree);
}
let mut queue: VecDeque<NodeId> = in_degree
.iter()
.filter(|&(_, deg)| *deg == 0)
.map(|(&node, _)| node)
.collect();
let mut visited = HashSet::new();
while !queue.is_empty() {
let mut current_layer = Vec::new();
let layer_size = queue.len();
for _ in 0..layer_size {
if let Some(node) = queue.pop_front() {
if visited.contains(&node) {
continue;
}
current_layer.push(node);
visited.insert(node);
if let Some(deps) = self.dependents.get(&node) {
for &dependent in deps {
if let Some(degree) = in_degree.get_mut(&dependent) {
*degree -= 1;
if *degree == 0 {
queue.push_back(dependent);
}
}
}
}
}
}
if !current_layer.is_empty() {
layers.push(current_layer);
}
}
debug_assert!(!visited.is_empty() || layers.is_empty());
#[cfg(feature = "tracing")]
debug!(layer_count = layers.len(), "topological layers computed");
Ok(layers)
}
}
#[cfg(test)]
mod tests;