operese-dagx 0.4.1

A minimal, type-safe, runtime-agnostic async DAG (Directed Acyclic Graph) executor with compile-time cycle prevention and true parallel execution
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
//! DAG runner for task orchestration and execution.
//!
//! Provides DagRunner for building and executing directed acyclic graphs of async tasks
//! with compile-time type-safe dependencies.

use std::any::Any;
use std::collections::{HashMap, HashSet, VecDeque};
use std::future::Future;
use std::hash::{BuildHasher, Hasher};
use std::panic::AssertUnwindSafe;
use std::sync::Arc;

use futures_util::future::BoxFuture;
use futures_util::{stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt};

#[cfg(feature = "tracing")]
use tracing::{debug, error, info, trace};

use crate::builder::{NodeId, TaskWire};
use crate::error::{DagError, DagResult};
use crate::node::{ExecutableNode, TypedNode};
use crate::DagOutput;

/// Identity hasher: the most recently written `u32` is used verbatim as the hash.
///
/// The type also implements [`BuildHasher`], so it serves as its own hasher factory
/// when used as the third type parameter of [`HashMap`].
#[derive(Default, Clone)]
pub(crate) struct PassThroughHasher {
    hash: u64,
}

impl BuildHasher for PassThroughHasher {
    type Hasher = Self;

    /// Produce a fresh hasher whose stored value starts at zero.
    fn build_hasher(&self) -> Self::Hasher {
        Self::default()
    }
}

impl Hasher for PassThroughHasher {
    /// Byte-oriented input is rejected: only keys that hash through `write_u32`
    /// are supported. The standard library's default `write_*` methods forward
    /// here, so any other key type ends up panicking.
    fn write(&mut self, _unsupported: &[u8]) {
        panic!("PassThroughHasher used on invalid type");
    }

    /// Store the value itself (widened to `u64`) as the hash.
    fn write_u32(&mut self, value: u32) {
        self.hash = u64::from(value);
    }

    /// Return whatever value was stored last (zero if nothing was written).
    fn finish(&self) -> u64 {
        self.hash
    }
}

/// A [`HashMap`] keyed through [`PassThroughHasher`].
pub(crate) type PassThroughHashMap<K, V> = HashMap<K, V, PassThroughHasher>;

/// Build and execute a typed DAG of tasks.
///
/// A `DagRunner` is the main orchestrator for building and executing a directed acyclic graph
/// of async tasks with compile-time type-safe dependencies.
///
/// # Workflow
///
/// 1. Create a new DAG with [`DagRunner::new`]
/// 2. Add tasks with [`DagRunner::add_task`] to get builders
/// 3. Wire dependencies with [`crate::TaskBuilder::depends_on`]
/// 4. Execute all tasks with [`DagRunner::run`]
/// 5. Optionally retrieve outputs with [`DagOutput::get`]
///
/// # Examples
///
/// ```no_run
/// # use operese_dagx::{task, DagRunner, Task};
/// #
/// // Task with state constructed via ::new()
/// struct LoadValue { value: i32 }
///
/// impl LoadValue {
///     fn new(value: i32) -> Self { Self { value } }
/// }
///
/// #[task]
/// impl LoadValue {
///     async fn run(&mut self) -> i32 { self.value }
/// }
///
/// struct Add;
///
/// #[task]
/// impl Add {
///     async fn run(&mut self, a: &i32, b: &i32) -> i32 { a + b }
/// }
///
/// # async {
/// let mut dag = DagRunner::new();
///
/// // Construct instances using ::new() pattern
/// let x = dag.add_task(LoadValue::new(2));
/// let y = dag.add_task(LoadValue::new(3));
/// let sum = dag.add_task(Add).depends_on((&x, &y));
///
/// let mut output = dag.run(|fut| async move { tokio::spawn(fut).await.unwrap() }).await.unwrap();
///
/// assert_eq!(output.get(sum), 5);
/// # };
/// ```
pub struct DagRunner {
    /// Task nodes, indexed by the numeric value of their `NodeId`.
    ///
    /// A slot is emptied (left as `None`) when [`DagRunner::run`] takes the node out
    /// to execute it.
    pub(crate) nodes: Vec<Option<Box<dyn ExecutableNode + Sync>>>,
    /// node -> dependencies (the nodes whose outputs it consumes)
    pub(crate) edges: PassThroughHashMap<NodeId, Vec<NodeId>>,
    /// node -> tasks that depend on it
    pub(crate) dependents: PassThroughHashMap<NodeId, Vec<NodeId>>,
}

impl Default for DagRunner {
    fn default() -> Self {
        Self::new()
    }
}

impl DagRunner {
    /// Create a new empty DAG.
    ///
    /// # Examples
    ///
    /// ```
    /// use operese_dagx::DagRunner;
    ///
    /// let mut dag = DagRunner::new();
    /// ```
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            edges: HashMap::default(),
            dependents: HashMap::default(),
        }
    }

    /// Add a task instance to the DAG, returning a node builder for wiring dependencies.
    ///
    /// If the task has no dependencies, a [`crate::TaskHandle`] will be returned.
    /// If not, dependencies should be specified for the returned [`crate::TaskBuilder`]
    /// using [`crate::TaskBuilder::depends_on`].
    ///
    /// The new task is assigned the next free node ID (its index in the node list) and
    /// starts out with empty dependency and dependent lists.
    ///
    /// # Panics
    ///
    /// Panics if the number of tasks already in the DAG does not fit in a `u32`,
    /// because node IDs are `u32` values and wrapping around would alias an existing task.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use operese_dagx::{task, DagRunner, Task};
    /// #
    /// // Task with state - shows you construct with specific value
    /// struct LoadValue {
    ///     initial: i32,
    /// }
    ///
    /// impl LoadValue {
    ///     fn new(initial: i32) -> Self {
    ///         Self { initial }
    ///     }
    /// }
    ///
    /// #[task]
    /// impl LoadValue {
    ///     async fn run(&mut self) -> i32 { self.initial }
    /// }
    ///
    /// // Task with configuration - shows you can parameterize behavior
    /// struct AddOffset {
    ///     offset: i32,
    /// }
    ///
    /// impl AddOffset {
    ///     fn new(offset: i32) -> Self {
    ///         Self { offset }
    ///     }
    /// }
    ///
    /// #[task]
    /// impl AddOffset {
    ///     async fn run(&mut self, x: &i32) -> i32 { x + self.offset }
    /// }
    ///
    /// # async {
    /// let mut dag = DagRunner::new();
    ///
    /// // Construct task with initial value of 10
    /// let base = dag.add_task(LoadValue::new(10));
    ///
    /// // Construct task with offset of 1
    /// let inc = dag.add_task(AddOffset::new(1)).depends_on(&base);
    ///
    /// let mut output = dag.run(|fut| async move { tokio::spawn(fut).await.unwrap() }).await.unwrap();
    /// assert_eq!(output.get(inc), 11);
    /// # };
    /// ```
    pub fn add_task<'dag, Input, Tk>(&'dag mut self, task: Tk) -> Tk::Retval<'dag>
    where
        Tk: TaskWire<Input>,
        Input: Send + Sync + 'static,
    {
        // Node IDs are `u32` indices into `self.nodes`. A plain `as u32` cast would wrap
        // silently and hand out the ID of an existing task, overwriting its edge lists,
        // so the conversion is checked before anything is mutated.
        let index = u32::try_from(self.nodes.len())
            .expect("DagRunner task count exceeds the u32 NodeId range");
        let id = NodeId(index);

        #[cfg(feature = "tracing")]
        debug!(
            task_id = id.0,
            task_type = std::any::type_name::<Tk>(),
            "adding task to DAG"
        );

        let node = TypedNode::new(task);
        self.nodes.push(Some(Box::new(node)));
        self.edges.insert(id, Vec::new());
        self.dependents.insert(id, Vec::new());

        Tk::new_from_dag(id, self)
    }

    /// Run the entire DAG to completion using the provided spawner.
    ///
    /// - Executes tasks in topological order (respecting dependencies)
    /// - Runs ready tasks with maximum parallelism (executor-limited)
    /// - Executes each task at most once
    /// - Is runtime-agnostic via the spawner function
    ///
    /// # Parameters
    ///
    /// - `spawner`: A function that spawns futures on the async runtime
    ///   and returns a handle to the task. This is the only way to run tasks on separate threads. Examples:
    ///   - Tokio: `|fut| { tokio::spawn(fut).await.unwrap() }`
    ///   - Smol: `|fut| { smol::spawn(fut) }`
    ///   - Single-threaded on invoking runtime: `|fut| fut`
    ///     - Can be faster in situations where waiting time dominates
    ///
    /// # Errors
    ///
    /// - Returns `DagError::TaskPanicked` if any task panics during execution
    ///
    /// # Examples
    ///
    /// ```
    /// # use operese_dagx::{task, DagRunner, Task};
    /// #
    /// // Tuple struct
    /// struct Value(i32);
    ///
    /// #[task]
    /// impl Value {
    ///     async fn run(&mut self) -> i32 { self.0 }
    /// }
    ///
    /// // Unit struct
    /// struct Add;
    ///
    /// #[task]
    /// impl Add {
    ///     async fn run(&mut self, a: &i32, b: &i32) -> i32 { a + b }
    /// }
    ///
    /// # async {
    /// let mut dag = DagRunner::new();
    ///
    /// let a = dag.add_task(Value(1));
    /// let b = dag.add_task(Value(2));
    /// let sum = dag.add_task(Add).depends_on((&a, &b));
    ///
    ///let mut output = dag.run(|fut| async move { tokio::spawn(fut).await.unwrap() }).await.unwrap(); // Executes all tasks
    /// # };
    /// ```
    #[inline]
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, spawner)))]
    pub async fn run<S, F>(mut self, spawner: S) -> DagResult<DagOutput>
    where
        S: Fn(BoxFuture<'static, DagResult<Arc<dyn Any + Send + Sync>>>) -> F,
        F: Future<Output = DagResult<Arc<dyn Any + Send + Sync>>>,
    {
        #[cfg(feature = "tracing")]
        info!("starting DAG execution");

        // Build topological layers
        let layers = self.compute_layers()?;

        let total_tasks = layers.iter().map(|l| l.len()).sum::<usize>();

        #[cfg(feature = "tracing")]
        debug!(
            layer_count = layers.len(),
            total_tasks, "computed topological layers"
        );

        let mut outputs: PassThroughHashMap<NodeId, Arc<dyn Any + Send + Sync>> =
            HashMap::with_capacity_and_hasher(total_tasks, PassThroughHasher::default());
        let mut first_error = None;

        // Panic handling is required to maintain behavioral consistency, as
        // different async runtimes (Tokio, async-std, smol, embassy-rs) handle panics in
        // spawned tasks differently.
        //
        // This ensures tasks behave in the same way whether executed inline or on the spawner
        // across every runtime.
        for layer in layers {
            #[cfg(feature = "tracing")]
            {
                debug!(task_count = layer.len(), "executing layer");
            }
            // Performance optimization: Inline execution for single-task layers
            //
            // When a layer contains exactly one task (common in deep chains, linear
            // pipelines), we execute it inline rather than spawning it. This provides
            // 10-100x performance improvements for sequential workloads by eliminating:
            //   - Task spawning overhead
            //   - Context switching to/from the runtime
            if layer.len() == 1 {
                let node_id = layer[0];

                #[cfg(feature = "tracing")]
                trace!(
                    task_id = node_id.0,
                    "executing task inline (single-task layer optimization)"
                );

                // Take ownership of the node
                let node = self.nodes[node_id.0 as usize].take();

                if let Some(node) = node {
                    let dependencies: Vec<_> = self.edges[&node_id]
                        .iter()
                        .flat_map(|dep| outputs.get(dep))
                        .cloned()
                        .collect();

                    // Execute inline with panic handling.
                    let result = AssertUnwindSafe(node.execute_with_deps(dependencies))
                        .catch_unwind()
                        .await
                        .unwrap_or_else(|panic_payload| {
                            // Convert panic to error
                            let panic_message =
                                if let Some(s) = panic_payload.downcast_ref::<&str>() {
                                    s.to_string()
                                } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                                    s.clone()
                                } else {
                                    "unknown panic".to_string()
                                };

                            #[cfg(feature = "tracing")]
                            error!(
                                task_id = node_id.0,
                                panic_message = %panic_message,
                                "task panicked during inline execution"
                            );

                            Err(DagError::TaskPanicked {
                                task_id: node_id.0,
                                panic_message,
                            })
                        });

                    match result {
                        Ok(output) => {
                            outputs.insert(node_id, output);
                        }
                        Err(e) => {
                            first_error.get_or_insert(e);
                        }
                    }
                }
            } else {
                // Slow path: Multiple tasks require spawning and coordination
                // Spawn each task in this layer
                let mut futures: FuturesUnordered<_> = layer
                    .into_iter()
                    .filter_map(|node_id| {
                        #[cfg(feature = "tracing")]
                        trace!(task_id = node_id.0, "spawning task");

                        // Take ownership of the node
                        let node = self.nodes[node_id.0 as usize].take();
                        if let Some(node) = node {
                            let dependencies: Vec<_> = self.edges[&node_id]
                                .iter()
                                .flat_map(|dep| outputs.get(dep))
                                .cloned()
                                .collect();

                            let inner_future = spawner(node.execute_with_deps(dependencies));
                            // Spawn the task using the provided spawner
                            let inner_future = async move {
                                let result = inner_future.await?;
                                Ok((node_id, result))
                            };

                            Some(
                                AssertUnwindSafe(inner_future)
                                    .catch_unwind()
                                    .unwrap_or_else(move |panic_payload| {
                                        // Convert panic to error
                                        let panic_message =
                                            if let Some(s) = panic_payload.downcast_ref::<&str>() {
                                                s.to_string()
                                            } else if let Some(s) =
                                                panic_payload.downcast_ref::<String>()
                                            {
                                                s.clone()
                                            } else {
                                                "unknown panic".to_string()
                                            };

                                        #[cfg(feature = "tracing")]
                                        error!(
                                            task_id = node_id.0,
                                            panic_message = %panic_message,
                                            "task panicked during inline execution"
                                        );

                                        Err(DagError::TaskPanicked {
                                            task_id: node_id.0,
                                            panic_message,
                                        })
                                    }),
                            )
                        } else {
                            None
                        }
                    })
                    .collect();

                while let Some(out) = futures.next().await {
                    match out {
                        Ok(output) => {
                            outputs.insert(output.0, output.1);
                        }
                        Err(e) => {
                            first_error.get_or_insert(e);
                        }
                    }
                }
            }

            // Return first error if any, aborting execution after this layer
            if let Some(err) = first_error {
                #[cfg(feature = "tracing")]
                error!(?err, "DAG execution failed");
                return Err(err);
            }
        }

        #[cfg(feature = "tracing")]
        info!("DAG execution completed successfully");

        Ok(DagOutput::new(outputs))
    }

    /// Split the graph into layers such that every node appears after all of its
    /// dependencies.
    ///
    /// The first layer holds the nodes that have no dependencies. A node joins the
    /// layer after the one in which its last outstanding dependency was placed.
    /// Empty layers are never emitted. The current implementation always returns `Ok`.
    fn compute_layers(&self) -> DagResult<Vec<Vec<NodeId>>> {
        #[cfg(feature = "tracing")]
        debug!("computing topological layers");

        // Count of dependencies that have not been placed in a layer yet, per node.
        let mut pending: PassThroughHashMap<NodeId, usize> = HashMap::default();
        for (&id, deps) in self.edges.iter() {
            pending.insert(id, deps.len());
        }

        // Start from the sources: nodes that wait on nothing.
        let mut frontier: VecDeque<NodeId> = pending
            .iter()
            .filter_map(|(&id, &count)| if count == 0 { Some(id) } else { None })
            .collect();

        let mut layers = Vec::new();
        let mut scheduled = HashSet::new();

        while !frontier.is_empty() {
            // Everything queued at this point forms one layer; nodes unlocked while
            // walking it are queued for the following round.
            let batch = std::mem::take(&mut frontier);
            let mut layer = Vec::new();

            for id in batch {
                // `insert` reports `false` for a node that was already placed.
                if !scheduled.insert(id) {
                    continue;
                }
                layer.push(id);

                // One dependency of each dependent is now satisfied.
                for &dependent in self.dependents.get(&id).into_iter().flatten() {
                    if let Some(count) = pending.get_mut(&dependent) {
                        *count -= 1;
                        if *count == 0 {
                            frontier.push_back(dependent);
                        }
                    }
                }
            }

            if !layer.is_empty() {
                layers.push(layer);
            }
        }

        // No runtime cycle detection: the public API cannot express a cycle.
        //
        // The type-state pattern enforces acyclic structure at compile time through:
        // 1. TaskBuilder::depends_on() consumes the builder (move semantics)
        // 2. TaskHandle has no methods to add dependencies (immutability)
        // 3. Strict topological ordering requirement for dependency wiring
        //
        // Skipping the check keeps the safety guarantee while avoiding validation overhead.
        debug_assert!(!scheduled.is_empty() || layers.is_empty());

        #[cfg(feature = "tracing")]
        debug!(layer_count = layers.len(), "topological layers computed");

        Ok(layers)
    }
}

#[cfg(test)]
mod tests;