dagx/
runner.rs

1//! DAG runner for task orchestration and execution.
2//!
3//! Provides DagRunner for building and executing directed acyclic graphs of async tasks
4//! with compile-time type-safe dependencies.
5//!
6//! Uses Mutex for interior mutability to enable builder pattern (`&self` instead of `&mut self`).
7
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::panic::AssertUnwindSafe;
10use std::sync::{
11    atomic::{AtomicBool, Ordering},
12    Arc,
13};
14
15use futures::channel::mpsc;
16use futures::future::BoxFuture;
17use futures::{FutureExt, StreamExt};
18use parking_lot::Mutex;
19
20use crate::builder::TaskBuilder;
21use crate::error::{DagError, DagResult};
22use crate::extract::ExtractInput;
23use crate::node::{ExecutableNode, TypedNode};
24use crate::task::Task;
25use crate::types::{NodeId, Pending, TaskHandle};
26
// Guard to ensure run_lock is released even on early return or panic
/// RAII guard over the runner's `run_lock` flag.
///
/// `run()` creates one of these immediately after winning the
/// `compare_exchange` on `run_lock`; dropping it (normally, on early
/// return, or during unwind) clears the flag so the DAG is never left
/// permanently marked as "running".
struct RunGuard<'a> {
    // Borrowed flag owned by `DagRunner::run_lock`.
    lock: &'a AtomicBool,
}
31
32impl<'a> Drop for RunGuard<'a> {
33    fn drop(&mut self) {
34        self.lock.store(false, Ordering::SeqCst);
35    }
36}
37
/// Build and execute a typed DAG of tasks.
///
/// A `DagRunner` is the main orchestrator for building and executing a directed acyclic graph
/// of async tasks with compile-time type-safe dependencies.
///
/// # Workflow
///
/// 1. Create a new DAG with [`DagRunner::new`]
/// 2. Add tasks with [`DagRunner::add_task`] to get [`TaskBuilder`] builders
/// 3. Wire dependencies with [`TaskBuilder::depends_on`]
/// 4. Execute all tasks with [`DagRunner::run`]
/// 5. Optionally retrieve outputs with [`DagRunner::get`]
///
/// # Examples
///
/// ```no_run
/// # use dagx::{task, DagRunner, Task};
/// // Task with state constructed via ::new()
/// struct LoadValue { value: i32 }
///
/// impl LoadValue {
///     fn new(value: i32) -> Self { Self { value } }
/// }
///
/// #[task]
/// impl LoadValue {
///     async fn run(&mut self) -> i32 { self.value }
/// }
///
/// // Unit struct - no fields needed
/// struct Add;
///
/// #[task]
/// impl Add {
///     async fn run(&mut self, a: &i32, b: &i32) -> i32 { a + b }
/// }
///
/// # async {
/// let dag = DagRunner::new();
///
/// // Construct instances using ::new() pattern
/// let x = dag.add_task(LoadValue::new(2));
/// let y = dag.add_task(LoadValue::new(3));
/// let sum = dag.add_task(Add).depends_on((&x, &y));
///
/// dag.run(|fut| { tokio::spawn(fut); }).await.unwrap();
///
/// assert_eq!(dag.get(sum).unwrap(), 5);
/// # };
/// ```
///
/// # Implementation notes
///
/// Uses Mutex for interior mutability to enable the builder pattern (`&self` not `&mut self`).
/// This allows fluent chaining of `add_task()` calls.
///
/// Nodes use Option to allow taking ownership during execution.
/// Outputs are Arc-wrapped and stored separately for retrieval via get().
/// Arc enables efficient sharing during fanout without cloning data.
pub struct DagRunner {
    // Task nodes, indexed by `NodeId.0` (ids are allocated sequentially in
    // `alloc_id`). `Option` lets `run()` take ownership of each node exactly once.
    pub(crate) nodes: Mutex<Vec<Option<Box<dyn ExecutableNode>>>>,
    // Type-erased, Arc-wrapped task outputs; populated at the end of `run()`
    // and read (downcast + clone) by `get()`.
    pub(crate) outputs: Mutex<HashMap<NodeId, std::sync::Arc<dyn std::any::Any + Send + Sync>>>,
    pub(crate) edges: Mutex<HashMap<NodeId, Vec<NodeId>>>, // node -> dependencies
    pub(crate) dependents: Mutex<HashMap<NodeId, Vec<NodeId>>>, // node -> tasks that depend on it
    // Monotonic counter backing `alloc_id`.
    pub(crate) next_id: Mutex<usize>,
    pub(crate) run_lock: AtomicBool, // Ensures only one run() at a time
}
103
104impl Default for DagRunner {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
impl DagRunner {
    /// Create a new empty DAG.
    ///
    /// # Examples
    ///
    /// ```
    /// use dagx::DagRunner;
    ///
    /// let dag = DagRunner::new();
    /// ```
    pub fn new() -> Self {
        Self {
            nodes: Mutex::new(Vec::new()),
            outputs: Mutex::new(HashMap::new()),
            edges: Mutex::new(HashMap::new()),
            dependents: Mutex::new(HashMap::new()),
            next_id: Mutex::new(0),
            run_lock: AtomicBool::new(false),
        }
    }

    /// Allocate the next sequential node id.
    ///
    /// Ids are dense and start at 0, so `NodeId.0` doubles as the node's index
    /// into the `nodes` vector (see `add_task`, which pushes in the same order).
    pub(crate) fn alloc_id(&self) -> NodeId {
        let mut next_id = self.next_id.lock();
        let id = NodeId(*next_id);
        *next_id += 1;
        id
    }

    /// Add a task instance to the DAG, returning a node builder for wiring dependencies.
    ///
    /// The returned [`TaskBuilder<Tk, Pending>`](TaskBuilder) can be used to:
    /// - Specify dependencies via [`TaskBuilder::depends_on`]
    /// - Used directly as a [`TaskHandle`] to the task's output
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use dagx::{task, DagRunner, Task};
    /// // Task with state - shows you construct with specific value
    /// struct LoadValue {
    ///     initial: i32,
    /// }
    ///
    /// impl LoadValue {
    ///     fn new(initial: i32) -> Self {
    ///         Self { initial }
    ///     }
    /// }
    ///
    /// #[task]
    /// impl LoadValue {
    ///     async fn run(&mut self) -> i32 { self.initial }
    /// }
    ///
    /// // Task with configuration - shows you can parameterize behavior
    /// struct AddOffset {
    ///     offset: i32,
    /// }
    ///
    /// impl AddOffset {
    ///     fn new(offset: i32) -> Self {
    ///         Self { offset }
    ///     }
    /// }
    ///
    /// #[task]
    /// impl AddOffset {
    ///     async fn run(&mut self, x: &i32) -> i32 { x + self.offset }
    /// }
    ///
    /// # async {
    /// let dag = DagRunner::new();
    ///
    /// // Construct task with initial value of 10
    /// let base = dag.add_task(LoadValue::new(10));
    ///
    /// // Construct task with offset of 1
    /// let inc = dag.add_task(AddOffset::new(1)).depends_on(&base);
    ///
    /// dag.run(|fut| { tokio::spawn(fut); }).await.unwrap();
    /// assert_eq!(dag.get(&inc).unwrap(), 11);
    /// # };
    /// ```
    pub fn add_task<Tk>(&self, task: Tk) -> TaskBuilder<'_, Tk, Pending>
    where
        Tk: Task + 'static,
        Tk::Input: 'static + Clone + ExtractInput,
        Tk::Output: 'static + Clone,
    {
        let id = self.alloc_id();
        let node = TypedNode::new(id, task);
        // The push index matches `id.0` because ids are handed out sequentially,
        // so `nodes[id.0]` is this node.
        self.nodes.lock().push(Some(Box::new(node)));
        // Register empty dependency/dependent lists up front so every node has
        // an entry in both maps even if it is never wired to anything.
        self.edges.lock().insert(id, Vec::new());
        self.dependents.lock().insert(id, Vec::new());

        TaskBuilder {
            id,
            dag: self,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Run the entire DAG to completion using the provided spawner.
    ///
    /// This method:
    /// - Executes tasks in topological order (respecting dependencies)
    /// - Runs ready tasks with maximum parallelism (executor-limited)
    /// - Executes each task at most once
    /// - Waits for **all sinks** (tasks with no dependents) to complete
    /// - Is runtime-agnostic via the spawner function
    ///
    /// # Parameters
    ///
    /// - `spawner`: A function that spawns futures on the async runtime. Examples:
    ///   - Tokio: `|fut| { tokio::spawn(fut); }`
    ///   - Smol: `|fut| { smol::spawn(fut).detach(); }`
    ///   - Async-std: `|fut| { async_std::task::spawn(fut); }`
    ///
    /// # Errors
    ///
    /// Returns `DagError::CycleDetected` if the DAG contains a cycle, or if a
    /// `run()` is already in progress (see the note on the run lock below).
    ///
    /// # Examples
    ///
    /// ```
    /// # use dagx::{task, DagRunner, Task};
    /// // Tuple struct
    /// struct Value(i32);
    ///
    /// #[task]
    /// impl Value {
    ///     async fn run(&mut self) -> i32 { self.0 }
    /// }
    ///
    /// // Unit struct
    /// struct Add;
    ///
    /// #[task]
    /// impl Add {
    ///     async fn run(&mut self, a: &i32, b: &i32) -> i32 { a + b }
    /// }
    ///
    /// # async {
    /// let dag = DagRunner::new();
    ///
    /// let a = dag.add_task(Value(1));
    /// let b = dag.add_task(Value(2));
    /// let sum = dag.add_task(Add).depends_on((&a, &b));
    ///
    /// dag.run(|fut| { tokio::spawn(fut); }).await.unwrap(); // Executes all tasks
    /// # };
    /// ```
    ///
    /// # Implementation Note
    ///
    /// Tasks communicate via oneshot channels created fresh during run().
    /// This eliminates Mutex contention on outputs and enables true streaming execution.
    /// Type erasure occurs only at the ExecutableNode trait boundary - channels are created
    /// with full type information and type-erased just before passing to execute_with_channels.
    pub async fn run<S>(&self, spawner: S) -> DagResult<()>
    where
        S: Fn(BoxFuture<'static, ()>),
    {
        // Acquire run lock to prevent concurrent executions
        // Use atomic compare_exchange to check if already running
        //
        // NOTE(review): the "already running" failure reuses the CycleDetected
        // variant with an explanatory description. A dedicated error variant
        // would be more accurate, but adding one changes the public `DagError`
        // enum - confirm with callers before changing.
        if self
            .run_lock
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            return Err(DagError::CycleDetected {
                nodes: vec![],
                description: "DAG is already running - concurrent execution not supported"
                    .to_string(),
            });
        }

        // Guard ensures lock is released on drop (even on early return or panic)
        let _run_guard = RunGuard {
            lock: &self.run_lock,
        };

        // Build topological layers (also performs cycle detection)
        let layers = self.compute_layers()?;

        // Snapshot the dependency map so the edges lock is not held while
        // creating channels below.
        let edges = self.edges.lock().clone();

        // Create all channels upfront
        // Map: (producer_id, consumer_id) -> receiver index
        // We store receivers in a Vec per consumer to maintain order
        let mut consumer_receivers: HashMap<NodeId, Vec<Box<dyn std::any::Any + Send>>> =
            HashMap::new();
        let mut producer_senders: HashMap<NodeId, Vec<Box<dyn std::any::Any + Send>>> =
            HashMap::new();

        // Create channels for each producer-consumer relationship
        // We need to lock nodes temporarily to call create_output_channels
        {
            let nodes_lock = self.nodes.lock();

            for (consumer_id, producer_ids) in &edges {
                for &producer_id in producer_ids {
                    // Get the producer node (it's still in the vector at this point)
                    if let Some(Some(producer_node)) = nodes_lock.get(producer_id.0) {
                        // Ask the producer to create ONE channel for this consumer
                        let (mut senders, mut receivers) = producer_node.create_output_channels(1);

                        // Store the sender for the producer
                        producer_senders
                            .entry(producer_id)
                            .or_default()
                            .push(senders.pop().unwrap());

                        // Store the receiver for the consumer (order matters!)
                        // The consumer's receivers end up in the same order as its
                        // dependency list in `edges`.
                        consumer_receivers
                            .entry(*consumer_id)
                            .or_default()
                            .push(receivers.pop().unwrap());
                    }
                }
            }
        }

        // Create a channel to collect Arc-wrapped outputs from all tasks
        let (output_tx, mut output_rx) = mpsc::unbounded::<(
            NodeId,
            Result<std::sync::Arc<dyn std::any::Any + Send + Sync>, DagError>,
        )>();

        // Execute layer by layer
        for layer in layers {
            // ============================================================================
            // PERFORMANCE OPTIMIZATION: Inline execution for single-task layers
            // ============================================================================
            //
            // When a layer contains exactly one task (common in deep chains, linear
            // pipelines), we execute it inline rather than spawning it. This provides
            // 10-100x performance improvements for sequential workloads by eliminating:
            //   - Task spawning overhead
            //   - Channel creation/destruction for layer coordination
            //   - Context switching to/from the runtime
            //
            // CRITICAL: Panic handling is required to maintain behavioral consistency.
            // All async runtimes (Tokio, async-std, smol, embassy-rs) catch panics in
            // spawned tasks and convert them to errors. We must do the same for inline
            // execution to ensure tasks behave identically regardless of execution path.
            //
            // Without panic catching:
            //   - Spawned task panics -> caught by runtime -> becomes error (expected)
            //   - Inline task panics -> bubbles up -> crashes program (surprising!)
            //
            // With panic catching (current implementation):
            //   - Spawned task panics -> caught by runtime -> becomes error
            //   - Inline task panics -> caught by us -> becomes error (consistent!)
            //
            // This ensures the optimization is transparent to users.
            // ============================================================================
            if layer.len() == 1 {
                let node_id = layer[0];
                let out_tx = output_tx.clone();

                // Take ownership of the node (leaves `None` behind; a node is
                // executed at most once)
                let node = {
                    let mut nodes_lock = self.nodes.lock();
                    nodes_lock[node_id.0].take()
                };

                if let Some(node) = node {
                    // Take the receivers for this consumer
                    let receivers = consumer_receivers.remove(&node_id).unwrap_or_default();

                    // Take the senders for this producer
                    let senders = producer_senders.remove(&node_id).unwrap_or_default();

                    // Execute inline with panic handling.
                    //
                    // FutureExt::catch_unwind() ensures panics are caught and converted to
                    // DagError::TaskPanicked, matching the behavior of async runtimes when
                    // they spawn tasks. This guarantees consistent error handling whether
                    // a task executes inline (single-task layer) or spawned (multi-task layer).
                    let result = AssertUnwindSafe(node.execute_with_channels(receivers, senders))
                        .catch_unwind()
                        .await
                        .unwrap_or_else(|panic_payload| {
                            // Convert panic to error, extracting the message when the
                            // payload is a &str or String (the two common panic! forms)
                            Err(DagError::TaskPanicked {
                                task_id: node_id.0,
                                panic_message: if let Some(s) = panic_payload.downcast_ref::<&str>()
                                {
                                    s.to_string()
                                } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                                    s.clone()
                                } else {
                                    "unknown panic".to_string()
                                },
                            })
                        });

                    // Send the output (ignore send errors - receiver may be dropped)
                    let _ = out_tx.unbounded_send((node_id, result));
                }
            } else {
                // Slow path: Multiple tasks require spawning and coordination
                // Create a channel to track task completion for this layer
                let (tx, mut rx) = mpsc::channel::<()>(layer.len());

                // Spawn each task in this layer
                for &node_id in &layer {
                    let mut task_tx = tx.clone();
                    let out_tx = output_tx.clone();

                    // Take ownership of the node
                    let node = {
                        let mut nodes_lock = self.nodes.lock();
                        nodes_lock[node_id.0].take()
                    };

                    if let Some(node) = node {
                        // Take the receivers for this consumer
                        let receivers = consumer_receivers.remove(&node_id).unwrap_or_default();

                        // Take the senders for this producer
                        let senders = producer_senders.remove(&node_id).unwrap_or_default();

                        // Create a 'static future by taking ownership
                        // NOTE(review): unlike the inline path, panics here are left to
                        // the runtime's own catch (see the big comment above).
                        let inner_future = async move {
                            let task_future = node.execute_with_channels(receivers, senders);

                            let result = task_future.await;

                            // Send the output (ignore send errors - receiver may be dropped)
                            let _ = out_tx.unbounded_send((node_id, result));

                            // Signal completion (ignore send errors - receiver may be dropped)
                            let _ = task_tx.try_send(());
                        };

                        // Spawn the task using the provided spawner
                        spawner(Box::pin(inner_future));
                    }
                }

                // Drop the original sender so the channel closes when all tasks complete
                drop(tx);

                // Wait for all tasks in this layer to complete
                while rx.next().await.is_some() {}
            }
        }

        // Drop the output sender so the channel closes
        drop(output_tx);

        // Collect all outputs and check for errors
        let mut collected = Vec::new();
        let mut first_error = None;
        while let Some((node_id, result)) = output_rx.next().await {
            match result {
                Ok(output) => collected.push((node_id, output)),
                Err(e) if first_error.is_none() => first_error = Some(e),
                Err(_) => {} // Ignore subsequent errors
            }
        }

        // Insert all successful outputs (avoiding holding lock across await)
        let mut outputs = self.outputs.lock();
        for (node_id, output) in collected {
            outputs.insert(node_id, output);
        }
        drop(outputs);

        // Return first error if any; successful outputs above are still stored
        if let Some(err) = first_error {
            return Err(err);
        }

        Ok(())
    }

    /// Retrieve a task's output after [`DagRunner::run`].
    ///
    /// # Behavior
    ///
    /// All task outputs are stored after execution and can be retrieved via get().
    ///
    /// # Errors
    ///
    /// Returns `DagError::ResultNotFound` if:
    /// - The task hasn't been executed yet
    /// - The handle is invalid
    ///
    /// Returns `DagError::TypeMismatch` if `T` does not match the task's actual
    /// output type.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use dagx::{task, DagRunner, Task};
    /// struct Configuration {
    ///     setting: i32,
    /// }
    ///
    /// impl Configuration {
    ///     fn new(setting: i32) -> Self {
    ///         Self { setting }
    ///     }
    /// }
    ///
    /// #[task]
    /// impl Configuration {
    ///     async fn run(&mut self) -> i32 { self.setting }
    /// }
    ///
    /// # async {
    /// let dag = DagRunner::new();
    ///
    /// // Construct task with specific setting value
    /// let task = dag.add_task(Configuration::new(42));
    ///
    /// dag.run(|fut| { tokio::spawn(fut); }).await.unwrap();
    ///
    /// assert_eq!(dag.get(task).unwrap(), 42);
    /// # };
    /// ```
    pub fn get<T: 'static + Clone + Send + Sync, H>(&self, handle: H) -> DagResult<T>
    where
        H: Into<TaskHandle<T>>,
    {
        let handle: TaskHandle<T> = handle.into();
        let outputs = self.outputs.lock();

        let arc_output = outputs.get(&handle.id).ok_or(DagError::ResultNotFound {
            task_id: handle.id.0,
        })?;

        // Downcast Arc<dyn Any> to Arc<T>, then clone the inner value
        // The Arc itself is stored for efficient sharing, but get() clones the data
        Arc::clone(arc_output)
            .downcast::<T>()
            .map(|arc| (*arc).clone())
            .map_err(|_| DagError::TypeMismatch {
                expected: std::any::type_name::<T>(),
                // The concrete type behind `dyn Any` can't be named here
                found: "unknown",
            })
    }

    /// Partition the graph into topological layers (Kahn's algorithm).
    ///
    /// Every node in a layer has all of its dependencies in earlier layers, so
    /// the nodes of one layer can run concurrently. Layers are sorted by id for
    /// deterministic ordering.
    ///
    /// # Errors
    ///
    /// Returns `DagError::CycleDetected` listing the nodes that never reached
    /// in-degree 0 when the graph contains a cycle.
    fn compute_layers(&self) -> DagResult<Vec<Vec<NodeId>>> {
        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
        let mut layers = Vec::new();

        // Calculate in-degrees: for each node, count how many dependencies it has
        let edges = self.edges.lock();
        let total_nodes = edges.len();

        for (&node, deps) in edges.iter() {
            let degree = deps.len();
            in_degree.insert(node, degree);
        }
        drop(edges); // Release lock early

        // Find all nodes with in-degree 0 (sources - nodes with no dependencies)
        let mut queue: VecDeque<NodeId> = in_degree
            .iter()
            .filter(|&(_, deg)| *deg == 0)
            .map(|(&node, _)| node)
            .collect();

        let mut visited = HashSet::new();

        while !queue.is_empty() {
            let mut current_layer = Vec::new();
            // Everything currently queued has in-degree 0 and forms one layer
            let layer_size = queue.len();

            for _ in 0..layer_size {
                if let Some(node) = queue.pop_front() {
                    if visited.contains(&node) {
                        continue;
                    }

                    current_layer.push(node);
                    visited.insert(node);

                    // For each node that depends on the current node, decrement its
                    // remaining-dependency count; once it hits 0 it becomes ready.
                    // NOTE(review): the dependents lock is re-acquired once per node;
                    // harmless, but it could be hoisted out of the loop.
                    let dependents = self.dependents.lock();
                    if let Some(deps) = dependents.get(&node) {
                        for &dependent in deps {
                            if let Some(degree) = in_degree.get_mut(&dependent) {
                                *degree -= 1;
                                if *degree == 0 {
                                    queue.push_back(dependent);
                                }
                            }
                        }
                    }
                    // Lock released here
                }
            }

            if !current_layer.is_empty() {
                current_layer.sort(); // Deterministic ordering
                layers.push(current_layer);
            }
        }

        // Cycle detection: if we haven't visited all nodes, there's a cycle
        if visited.len() != total_nodes {
            let unvisited: Vec<_> = in_degree
                .iter()
                .filter(|(node, _)| !visited.contains(node))
                .map(|(node, _)| node.0)
                .collect();
            let description = format!(
                "{} node(s) could not be processed: {:?}",
                unvisited.len(),
                unvisited
            );
            return Err(DagError::CycleDetected {
                nodes: unvisited,
                description,
            });
        }

        Ok(layers)
    }
}
632
// Note: `DagRunner::new()` returns `Self` (not `Arc<Self>`), and `Default` is
// implemented above by delegating to `new()`.
635
636#[cfg(test)]
637mod tests;