dagx/runner.rs
//! DAG runner for task orchestration and execution.
//!
//! Provides DagRunner for building and executing directed acyclic graphs of async tasks
//! with compile-time type-safe dependencies.
//!
7
use std::collections::{HashMap, HashSet, VecDeque};
use std::panic::AssertUnwindSafe;
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

use futures::channel::mpsc;
use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt};
use parking_lot::Mutex;

use crate::builder::TaskBuilder;
use crate::error::{DagError, DagResult};
use crate::extract::ExtractInput;
use crate::node::{ExecutableNode, TypedNode};
use crate::task::Task;
use crate::types::{NodeId, Pending, TaskHandle};

// Guard to ensure run_lock is released even on early return or panic
struct RunGuard<'a> {
    lock: &'a AtomicBool,
}

impl<'a> Drop for RunGuard<'a> {
    fn drop(&mut self) {
        self.lock.store(false, Ordering::SeqCst);
    }
}

/// Build and execute a typed DAG of tasks.
///
/// A `DagRunner` is the main orchestrator for building and executing a directed acyclic graph
/// of async tasks with compile-time type-safe dependencies.
///
/// # Workflow
///
/// 1. Create a new DAG with [`DagRunner::new`]
/// 2. Add tasks with [`DagRunner::add_task`] to get [`TaskBuilder`] builders
/// 3. Wire dependencies with [`TaskBuilder::depends_on`]
/// 4. Execute all tasks with [`DagRunner::run`]
/// 5. Optionally retrieve outputs with [`DagRunner::get`]
///
/// # Examples
///
/// ```no_run
/// # use dagx::{task, DagRunner, Task};
/// // Task with state constructed via ::new()
/// struct LoadValue { value: i32 }
///
/// impl LoadValue {
///     fn new(value: i32) -> Self { Self { value } }
/// }
///
/// #[task]
/// impl LoadValue {
///     async fn run(&mut self) -> i32 { self.value }
/// }
///
/// // Unit struct - no fields needed
/// struct Add;
///
/// #[task]
/// impl Add {
///     async fn run(&mut self, a: &i32, b: &i32) -> i32 { a + b }
/// }
///
/// # async {
/// let dag = DagRunner::new();
///
/// // Construct instances using ::new() pattern
/// let x = dag.add_task(LoadValue::new(2));
/// let y = dag.add_task(LoadValue::new(3));
/// let sum = dag.add_task(Add).depends_on((&x, &y));
///
/// dag.run(|fut| { tokio::spawn(fut); }).await.unwrap();
///
/// assert_eq!(dag.get(sum).unwrap(), 5);
/// # };
/// ```
///
/// Uses Mutex for interior mutability to enable the builder pattern (`&self` not `&mut self`).
/// This allows fluent chaining of `add_task()` calls.
///
/// Nodes use Option to allow taking ownership during execution.
/// Outputs are Arc-wrapped and stored separately for retrieval via get().
/// Arc enables efficient sharing during fanout without cloning data.
pub struct DagRunner {
    pub(crate) nodes: Mutex<Vec<Option<Box<dyn ExecutableNode>>>>,
    pub(crate) outputs: Mutex<HashMap<NodeId, std::sync::Arc<dyn std::any::Any + Send + Sync>>>,
    pub(crate) edges: Mutex<HashMap<NodeId, Vec<NodeId>>>, // node -> dependencies
    pub(crate) dependents: Mutex<HashMap<NodeId, Vec<NodeId>>>, // node -> tasks that depend on it
    pub(crate) next_id: Mutex<usize>,
    pub(crate) run_lock: AtomicBool, // Ensures only one run() at a time
}

impl Default for DagRunner {
    fn default() -> Self {
        Self::new()
    }
}

impl DagRunner {
    /// Create a new empty DAG.
    ///
    /// # Examples
    ///
    /// ```
    /// use dagx::DagRunner;
    ///
    /// let dag = DagRunner::new();
    /// ```
    pub fn new() -> Self {
        Self {
            nodes: Mutex::new(Vec::new()),
            outputs: Mutex::new(HashMap::new()),
            edges: Mutex::new(HashMap::new()),
            dependents: Mutex::new(HashMap::new()),
            next_id: Mutex::new(0),
            run_lock: AtomicBool::new(false),
        }
    }
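
    /// Allocate the next sequential [`NodeId`] for a node being added to the DAG.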
    pub(crate) fn alloc_id(&self) -> NodeId {
        let mut next_id = self.next_id.lock();
        let id = NodeId(*next_id);
        *next_id += 1;
        id
    }

    /// Add a task instance to the DAG, returning a node builder for wiring dependencies.
    ///
    /// The returned [`TaskBuilder<Tk, Pending>`](TaskBuilder) can:
    /// - Specify dependencies via [`TaskBuilder::depends_on`]
    /// - Be used directly as a [`TaskHandle`] to the task's output
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use dagx::{task, DagRunner, Task};
    /// // Task with state - shows construction with a specific value
    /// struct LoadValue {
    ///     initial: i32,
    /// }
    ///
    /// impl LoadValue {
    ///     fn new(initial: i32) -> Self {
    ///         Self { initial }
    ///     }
    /// }
    ///
    /// #[task]
    /// impl LoadValue {
    ///     async fn run(&mut self) -> i32 { self.initial }
    /// }
    ///
    /// // Task with configuration - shows you can parameterize behavior
    /// struct AddOffset {
    ///     offset: i32,
    /// }
    ///
    /// impl AddOffset {
    ///     fn new(offset: i32) -> Self {
    ///         Self { offset }
    ///     }
    /// }
    ///
    /// #[task]
    /// impl AddOffset {
    ///     async fn run(&mut self, x: &i32) -> i32 { x + self.offset }
    /// }
    ///
    /// # async {
    /// let dag = DagRunner::new();
    ///
    /// // Construct task with initial value of 10
    /// let base = dag.add_task(LoadValue::new(10));
    ///
    /// // Construct task with offset of 1
    /// let inc = dag.add_task(AddOffset::new(1)).depends_on(&base);
    ///
    /// dag.run(|fut| { tokio::spawn(fut); }).await.unwrap();
    /// assert_eq!(dag.get(&inc).unwrap(), 11);
    /// # };
    /// ```
    pub fn add_task<Tk>(&self, task: Tk) -> TaskBuilder<'_, Tk, Pending>
    where
        Tk: Task + 'static,
        Tk::Input: 'static + Clone + ExtractInput,
        Tk::Output: 'static + Clone,
    {
        let id = self.alloc_id();
        let node = TypedNode::new(id, task);
        self.nodes.lock().push(Some(Box::new(node)));
        self.edges.lock().insert(id, Vec::new());
        self.dependents.lock().insert(id, Vec::new());

        TaskBuilder {
            id,
            dag: self,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Run the entire DAG to completion using the provided spawner.
    ///
    /// This method:
    /// - Executes tasks in topological order (respecting dependencies)
    /// - Runs ready tasks with maximum parallelism (executor-limited)
    /// - Executes each task at most once
    /// - Waits for **all sinks** (tasks with no dependents) to complete
    /// - Is runtime-agnostic via the spawner function
    ///
    /// # Parameters
    ///
    /// - `spawner`: A function that spawns futures on the async runtime. Examples:
    ///   - Tokio: `|fut| { tokio::spawn(fut); }`
    ///   - Smol: `|fut| { smol::spawn(fut).detach(); }`
    ///   - Async-std: `|fut| { async_std::task::spawn(fut); }`
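    ///
    /// For example, a minimal sketch (marked `ignore`; it assumes the `smol` crate is
    /// available) showing that only the spawner closure changes when switching executors:
    ///
    /// ```ignore
    /// # use dagx::{task, DagRunner, Task};
    /// struct Value(i32);
    ///
    /// #[task]
    /// impl Value {
    ///     async fn run(&mut self) -> i32 { self.0 }
    /// }
    ///
    /// let dag = DagRunner::new();
    /// let v = dag.add_task(Value(7));
    ///
    /// // Drive the same DAG on smol instead of Tokio.
    /// smol::block_on(async {
    ///     dag.run(|fut| { smol::spawn(fut).detach(); }).await.unwrap();
    ///     assert_eq!(dag.get(&v).unwrap(), 7);
    /// });
    /// ```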
    ///
    /// # Errors
    ///
    /// Returns `DagError::CycleDetected` if the DAG contains a cycle. The same error variant
    /// (with an explanatory description) is also returned if `run` is called while another
    /// `run` on this `DagRunner` is still in progress.
    ///
    /// # Examples
    ///
    /// ```
    /// # use dagx::{task, DagRunner, Task};
    /// // Tuple struct
    /// struct Value(i32);
    ///
    /// #[task]
    /// impl Value {
    ///     async fn run(&mut self) -> i32 { self.0 }
    /// }
    ///
    /// // Unit struct
    /// struct Add;
    ///
    /// #[task]
    /// impl Add {
    ///     async fn run(&mut self, a: &i32, b: &i32) -> i32 { a + b }
    /// }
    ///
    /// # async {
    /// let dag = DagRunner::new();
    ///
    /// let a = dag.add_task(Value(1));
    /// let b = dag.add_task(Value(2));
    /// let sum = dag.add_task(Add).depends_on((&a, &b));
    ///
    /// dag.run(|fut| { tokio::spawn(fut); }).await.unwrap(); // Executes all tasks
    /// # };
    /// ```
    ///
    /// # Implementation Note
    ///
    /// Tasks communicate via oneshot channels created fresh during run().
    /// This eliminates Mutex contention on outputs and enables true streaming execution.
    /// Type erasure occurs only at the ExecutableNode trait boundary - channels are created
    /// with full type information and type-erased just before passing to execute_with_channels.
    pub async fn run<S>(&self, spawner: S) -> DagResult<()>
    where
        S: Fn(BoxFuture<'static, ()>),
    {
        // Acquire run lock to prevent concurrent executions
        // Use atomic compare_exchange to check if already running
        if self
            .run_lock
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            return Err(DagError::CycleDetected {
                nodes: vec![],
                description: "DAG is already running - concurrent execution not supported"
                    .to_string(),
            });
        }

        // Guard ensures lock is released on drop (even on early return or panic)
        let _run_guard = RunGuard {
            lock: &self.run_lock,
        };

        // Build topological layers
        let layers = self.compute_layers()?;

        let edges = self.edges.lock().clone();

        // Create all channels upfront
        // Map: (producer_id, consumer_id) -> receiver index
        // We store receivers in a Vec per consumer to maintain order
        let mut consumer_receivers: HashMap<NodeId, Vec<Box<dyn std::any::Any + Send>>> =
            HashMap::new();
        let mut producer_senders: HashMap<NodeId, Vec<Box<dyn std::any::Any + Send>>> =
            HashMap::new();

        // Create channels for each producer-consumer relationship
        // We need to lock nodes temporarily to call create_output_channels
        {
            let nodes_lock = self.nodes.lock();

            for (consumer_id, producer_ids) in &edges {
                for &producer_id in producer_ids {
                    // Get the producer node (it's still in the vector at this point)
                    if let Some(Some(producer_node)) = nodes_lock.get(producer_id.0) {
                        // Ask the producer to create ONE channel for this consumer
                        let (mut senders, mut receivers) = producer_node.create_output_channels(1);

                        // Store the sender for the producer
                        producer_senders
                            .entry(producer_id)
                            .or_default()
                            .push(senders.pop().unwrap());

                        // Store the receiver for the consumer (order matters!)
                        consumer_receivers
                            .entry(*consumer_id)
                            .or_default()
                            .push(receivers.pop().unwrap());
                    }
                }
            }
        }

        // Create a channel to collect Arc-wrapped outputs from all tasks
        let (output_tx, mut output_rx) = mpsc::unbounded::<(
            NodeId,
            Result<std::sync::Arc<dyn std::any::Any + Send + Sync>, DagError>,
        )>();

        // Execute layer by layer
        for layer in layers {
            // ============================================================================
            // PERFORMANCE OPTIMIZATION: Inline execution for single-task layers
            // ============================================================================
            //
            // When a layer contains exactly one task (common in deep chains, linear
            // pipelines), we execute it inline rather than spawning it. This provides
            // 10-100x performance improvements for sequential workloads by eliminating:
            // - Task spawning overhead
            // - Channel creation/destruction for layer coordination
            // - Context switching to/from the runtime
            //
            // CRITICAL: Panic handling is required to maintain behavioral consistency.
            // All async runtimes (Tokio, async-std, smol, embassy-rs) catch panics in
            // spawned tasks and convert them to errors. We must do the same for inline
            // execution to ensure tasks behave identically regardless of execution path.
            //
            // Without panic catching:
            // - Spawned task panics → caught by runtime → becomes error (expected)
            // - Inline task panics → bubbles up → crashes program (surprising!)
            //
            // With panic catching (current implementation):
            // - Spawned task panics → caught by runtime → becomes error
            // - Inline task panics → caught by us → becomes error (consistent!)
            //
            // This ensures the optimization is transparent to users.
            // ============================================================================
            if layer.len() == 1 {
                let node_id = layer[0];
                let out_tx = output_tx.clone();

                // Take ownership of the node
                let node = {
                    let mut nodes_lock = self.nodes.lock();
                    nodes_lock[node_id.0].take()
                };

                if let Some(node) = node {
                    // Take the receivers for this consumer
                    let receivers = consumer_receivers.remove(&node_id).unwrap_or_default();

                    // Take the senders for this producer
                    let senders = producer_senders.remove(&node_id).unwrap_or_default();

                    // Execute inline with panic handling.
                    //
                    // FutureExt::catch_unwind() ensures panics are caught and converted to
                    // DagError::TaskPanicked, matching the behavior of async runtimes when
                    // they spawn tasks. This guarantees consistent error handling whether
                    // a task executes inline (single-task layer) or spawned (multi-task layer).
                    let result = AssertUnwindSafe(node.execute_with_channels(receivers, senders))
                        .catch_unwind()
                        .await
                        .unwrap_or_else(|panic_payload| {
                            // Convert panic to error
                            Err(DagError::TaskPanicked {
                                task_id: node_id.0,
                                panic_message: if let Some(s) = panic_payload.downcast_ref::<&str>()
                                {
                                    s.to_string()
                                } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                                    s.clone()
                                } else {
                                    "unknown panic".to_string()
                                },
                            })
                        });

                    // Send the output (ignore send errors - receiver may be dropped)
                    let _ = out_tx.unbounded_send((node_id, result));
                }
            } else {
                // Slow path: Multiple tasks require spawning and coordination
                // Create a channel to track task completion for this layer
                let (tx, mut rx) = mpsc::channel::<()>(layer.len());

                // Spawn each task in this layer
                for &node_id in &layer {
                    let mut task_tx = tx.clone();
                    let out_tx = output_tx.clone();

                    // Take ownership of the node
                    let node = {
                        let mut nodes_lock = self.nodes.lock();
                        nodes_lock[node_id.0].take()
                    };

                    if let Some(node) = node {
                        // Take the receivers for this consumer
                        let receivers = consumer_receivers.remove(&node_id).unwrap_or_default();

                        // Take the senders for this producer
                        let senders = producer_senders.remove(&node_id).unwrap_or_default();

                        // Create a 'static future by taking ownership
                        let inner_future = async move {
                            let task_future = node.execute_with_channels(receivers, senders);

                            let result = task_future.await;

                            // Send the output (ignore send errors - receiver may be dropped)
                            let _ = out_tx.unbounded_send((node_id, result));

                            // Signal completion (ignore send errors - receiver may be dropped)
                            let _ = task_tx.try_send(());
                        };

                        // Spawn the task using the provided spawner
                        spawner(Box::pin(inner_future));
                    }
                }

                // Drop the original sender so the channel closes when all tasks complete
                drop(tx);

                // Wait for all tasks in this layer to complete
                while rx.next().await.is_some() {}
            }
        }

        // Drop the output sender so the channel closes
        drop(output_tx);

        // Collect all outputs and check for errors
        let mut collected = Vec::new();
        let mut first_error = None;
        while let Some((node_id, result)) = output_rx.next().await {
            match result {
                Ok(output) => collected.push((node_id, output)),
                Err(e) if first_error.is_none() => first_error = Some(e),
                Err(_) => {} // Ignore subsequent errors
            }
        }

        // Insert all successful outputs (avoiding holding lock across await)
        let mut outputs = self.outputs.lock();
        for (node_id, output) in collected {
            outputs.insert(node_id, output);
        }
        drop(outputs);

        // Return first error if any
        if let Some(err) = first_error {
            return Err(err);
        }

        Ok(())
    }

    /// Retrieve a task's output after [`DagRunner::run`].
    ///
    /// # Behavior
    ///
    /// All task outputs are stored after execution and can be retrieved via get().
    ///
    /// # Errors
    ///
    /// Returns `DagError::ResultNotFound` if:
    /// - The task hasn't been executed yet
    /// - The handle is invalid
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use dagx::{task, DagRunner, Task};
    /// struct Configuration {
    ///     setting: i32,
    /// }
    ///
    /// impl Configuration {
    ///     fn new(setting: i32) -> Self {
    ///         Self { setting }
    ///     }
    /// }
    ///
    /// #[task]
    /// impl Configuration {
    ///     async fn run(&mut self) -> i32 { self.setting }
    /// }
    ///
    /// # async {
    /// let dag = DagRunner::new();
    ///
    /// // Construct task with specific setting value
    /// let task = dag.add_task(Configuration::new(42));
    ///
    /// dag.run(|fut| { tokio::spawn(fut); }).await.unwrap();
    ///
    /// assert_eq!(dag.get(task).unwrap(), 42);
    /// # };
    /// ```
    pub fn get<T: 'static + Clone + Send + Sync, H>(&self, handle: H) -> DagResult<T>
    where
        H: Into<TaskHandle<T>>,
    {
        let handle: TaskHandle<T> = handle.into();
        let outputs = self.outputs.lock();

        let arc_output = outputs.get(&handle.id).ok_or(DagError::ResultNotFound {
            task_id: handle.id.0,
        })?;

        // Downcast Arc<dyn Any> to Arc<T>, then clone the inner value
        // The Arc itself is stored for efficient sharing, but get() clones the data
        Arc::clone(arc_output)
            .downcast::<T>()
            .map(|arc| (*arc).clone())
            .map_err(|_| DagError::TypeMismatch {
                expected: std::any::type_name::<T>(),
                found: "unknown",
            })
    }
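
    /// Group nodes into topological layers using Kahn's algorithm: each layer holds the
    /// nodes whose dependencies were all satisfied by earlier layers. If some nodes can
    /// never be processed, the DAG contains a cycle and `DagError::CycleDetected` is returned.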
    fn compute_layers(&self) -> DagResult<Vec<Vec<NodeId>>> {
        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
        let mut layers = Vec::new();

        // Calculate in-degrees: for each node, count how many dependencies it has
        let edges = self.edges.lock();
        let total_nodes = edges.len();

        for (&node, deps) in edges.iter() {
            let degree = deps.len();
            in_degree.insert(node, degree);
        }
        drop(edges); // Release lock early

        // Find all nodes with in-degree 0 (sources - nodes with no dependencies)
        let mut queue: VecDeque<NodeId> = in_degree
            .iter()
            .filter(|&(_, deg)| *deg == 0)
            .map(|(&node, _)| node)
            .collect();

        let mut visited = HashSet::new();

        while !queue.is_empty() {
            let mut current_layer = Vec::new();
            let layer_size = queue.len();

            for _ in 0..layer_size {
                if let Some(node) = queue.pop_front() {
                    if visited.contains(&node) {
                        continue;
                    }

                    current_layer.push(node);
                    visited.insert(node);

                    // For each node that depends on the current node
                    let dependents = self.dependents.lock();
                    if let Some(deps) = dependents.get(&node) {
                        for &dependent in deps {
                            if let Some(degree) = in_degree.get_mut(&dependent) {
                                *degree -= 1;
                                if *degree == 0 {
                                    queue.push_back(dependent);
                                }
                            }
                        }
                    }
                    // Lock released here
                }
            }

            if !current_layer.is_empty() {
                current_layer.sort(); // Deterministic ordering
                layers.push(current_layer);
            }
        }

        // Cycle detection: if we haven't visited all nodes, there's a cycle
        if visited.len() != total_nodes {
            let unvisited: Vec<_> = in_degree
                .iter()
                .filter(|(node, _)| !visited.contains(node))
                .map(|(node, _)| node.0)
                .collect();
            let description = format!(
                "{} node(s) could not be processed: {:?}",
                unvisited.len(),
                unvisited
            );
            return Err(DagError::CycleDetected {
                nodes: unvisited,
                description,
            });
        }

        Ok(layers)
    }
}

#[cfg(test)]
mod tests;