Skip to main content

dag_executor/dag/
mod.rs

1//! The DAG engine: graph, scheduler, worker pool, and the executor that drives
2//! them.
3//!
4//! [`DagExecutor`] is the entry point. It validates a [`Dag`], optionally
5//! recovers persisted state, then runs tasks concurrently — respecting
6//! dependencies, priorities, retries, the circuit breaker, and the dead-letter
7//! queue — and returns an [`ExecutionReport`].
8
9mod graph;
10mod scheduler;
11mod worker_pool;
12
13pub use graph::Dag;
14pub use scheduler::Scheduler;
15pub use worker_pool::{TaskResult, WorkerPool};
16
17use crate::advanced::{CircuitBreaker, DeadLetterQueue, RetryPolicy};
18use crate::context::Context;
19use crate::error::{DagExecutorError, Result, TaskError};
20use crate::metrics::{MetricsCollector, MetricsSnapshot};
21use crate::state::{StateValidator, TaskRecord, TaskState};
22use crate::storage::{FileStorage, MemoryStorage, Storage};
23use crate::utils::Config;
24use futures::stream::{FuturesUnordered, StreamExt};
25use std::collections::HashMap;
26use std::sync::Arc;
27use std::time::Duration;
28
29/// Summary of a completed (or cancelled) run.
30#[derive(Debug, Clone)]
31pub struct ExecutionReport {
32    /// Unique id of the run.
33    pub run_id: String,
34    /// Final record for every task.
35    pub records: HashMap<String, TaskRecord>,
36    /// Metrics snapshot taken at the end of the run.
37    pub metrics: MetricsSnapshot,
38}
39
40impl ExecutionReport {
41    /// Whether every task completed successfully.
42    pub fn is_success(&self) -> bool {
43        !self.records.is_empty()
44            && self
45                .records
46                .values()
47                .all(|r| r.state == TaskState::Completed)
48    }
49
50    /// Ids of tasks that ended in a failure state.
51    pub fn failed_tasks(&self) -> Vec<String> {
52        self.records
53            .values()
54            .filter(|r| r.state.is_failure())
55            .map(|r| r.id.clone())
56            .collect()
57    }
58
59    /// Count of tasks in a given state.
60    pub fn count_in(&self, state: TaskState) -> usize {
61        self.records.values().filter(|r| r.state == state).count()
62    }
63}
64
65/// Fluent builder for [`DagExecutor`].
66pub struct DagExecutorBuilder {
67    config: Config,
68    storage: Option<Arc<dyn Storage>>,
69    retry: Option<RetryPolicy>,
70    breaker: Option<Arc<CircuitBreaker>>,
71}
72
73impl DagExecutorBuilder {
74    fn new() -> Self {
75        DagExecutorBuilder {
76            config: Config::default(),
77            storage: None,
78            retry: None,
79            breaker: None,
80        }
81    }
82
83    /// Use a fully custom configuration.
84    pub fn config(mut self, config: Config) -> Self {
85        self.config = config;
86        self
87    }
88
89    /// Set the maximum number of concurrently executing tasks.
90    pub fn concurrency(mut self, n: usize) -> Self {
91        self.config.max_concurrency = n;
92        self
93    }
94
95    /// Provide a custom storage backend (defaults to [`FileStorage`] at the
96    /// configured `storage_dir`, or [`MemoryStorage`] when persistence is off).
97    pub fn storage(mut self, storage: Arc<dyn Storage>) -> Self {
98        self.storage = Some(storage);
99        self
100    }
101
102    /// Set the retry policy applied to every task.
103    pub fn retry(mut self, retry: RetryPolicy) -> Self {
104        self.retry = Some(retry);
105        self
106    }
107
108    /// Enable a shared circuit breaker guarding task execution.
109    pub fn circuit_breaker(mut self, breaker: CircuitBreaker) -> Self {
110        self.breaker = Some(Arc::new(breaker));
111        self
112    }
113
114    /// Toggle state persistence.
115    pub fn persist(mut self, persist: bool) -> Self {
116        self.config.persist_state = persist;
117        self
118    }
119
120    /// Finalize the executor.
121    pub fn build(self) -> DagExecutor {
122        let config = Arc::new(self.config);
123        let storage: Arc<dyn Storage> = self.storage.unwrap_or_else(|| {
124            if config.persist_state {
125                Arc::new(
126                    FileStorage::open(&config.storage_dir)
127                        .expect("failed to open storage directory"),
128                )
129            } else {
130                Arc::new(MemoryStorage::new())
131            }
132        });
133        let retry = self.retry.unwrap_or(RetryPolicy {
134            max_attempts: config.max_attempts,
135            ..RetryPolicy::default()
136        });
137
138        DagExecutor {
139            metrics: Arc::new(MetricsCollector::new()),
140            dead_letter: DeadLetterQueue::new(storage.clone()),
141            validator: StateValidator::new(),
142            breaker: self.breaker,
143            timeout: config.task_timeout,
144            retry,
145            storage,
146            config,
147        }
148    }
149}
150
151/// Executes DAGs of tasks with persistence, fault-tolerance and observability.
152pub struct DagExecutor {
153    config: Arc<Config>,
154    storage: Arc<dyn Storage>,
155    metrics: Arc<MetricsCollector>,
156    dead_letter: DeadLetterQueue,
157    validator: StateValidator,
158    breaker: Option<Arc<CircuitBreaker>>,
159    retry: RetryPolicy,
160    timeout: Option<Duration>,
161}
162
163impl DagExecutor {
164    /// Start building an executor.
165    pub fn builder() -> DagExecutorBuilder {
166        DagExecutorBuilder::new()
167    }
168
169    /// Build an executor with all defaults.
170    pub fn new() -> Self {
171        DagExecutorBuilder::new().build()
172    }
173
174    /// Access the metrics collector (live during a run).
175    pub fn metrics(&self) -> &Arc<MetricsCollector> {
176        &self.metrics
177    }
178
179    /// Access the dead-letter queue.
180    pub fn dead_letter(&self) -> &DeadLetterQueue {
181        &self.dead_letter
182    }
183
184    fn record_key(id: &str) -> String {
185        format!("record:{id}")
186    }
187
188    /// Run `dag` to completion with a fresh context.
189    pub async fn run(&self, dag: Dag) -> Result<ExecutionReport> {
190        let ctx = Arc::new(Context::new(self.config.clone()));
191        self.run_with_context(dag, ctx).await
192    }
193
194    /// Run `dag` using a caller-supplied [`Context`].
195    ///
196    /// This is the hook for graceful shutdown: hold a clone of `ctx` and call
197    /// [`Context::cancel`] (e.g. on `SIGINT`) to stop scheduling new work.
198    pub async fn run_with_context(&self, dag: Dag, ctx: Arc<Context>) -> Result<ExecutionReport> {
199        dag.validate()?;
200
201        let mut scheduler = self.recover_scheduler(&dag, &ctx).await?;
202
203        // remaining[id] = number of dependencies not yet completed.
204        let mut remaining: HashMap<String, usize> = HashMap::new();
205        for id in dag.task_ids() {
206            let pending_deps = dag
207                .dependencies_of(&id)
208                .into_iter()
209                .filter(|d| scheduler.state(d) != Some(TaskState::Completed))
210                .count();
211            remaining.insert(id, pending_deps);
212        }
213
214        let pool = WorkerPool::new(
215            self.config.max_concurrency,
216            self.retry,
217            self.timeout,
218            self.metrics.clone(),
219        );
220
221        // Seed the schedule: cascade-skip anything blocked by an already-failed
222        // dependency, then enqueue everything whose deps are all satisfied.
223        let initial_failures: Vec<String> = dag
224            .task_ids()
225            .into_iter()
226            .filter(|id| scheduler.state(id).map(|s| s.is_failure()).unwrap_or(false))
227            .collect();
228        for id in initial_failures {
229            self.cascade_skip(&dag, &mut scheduler, &id).await?;
230        }
231        for id in dag.task_ids() {
232            if scheduler.state(&id) == Some(TaskState::Pending)
233                && remaining.get(&id).copied().unwrap_or(0) == 0
234            {
235                let prio = dag.task(&id).map(|t| t.priority()).unwrap_or(0);
236                scheduler.mark_ready(&id, prio);
237            }
238        }
239
240        let mut in_flight: FuturesUnordered<_> = FuturesUnordered::new();
241
242        loop {
243            // Launch every currently-ready task; the worker pool's semaphore
244            // bounds how many actually run at once.
245            while let Some(id) = scheduler.next_ready() {
246                let task = match dag.task(&id) {
247                    Some(t) => t,
248                    None => continue,
249                };
250                scheduler.transition(&id, TaskState::Running);
251                self.persist(&scheduler, &id).await?;
252                in_flight.push(pool.spawn(task, ctx.clone(), self.breaker.clone()));
253            }
254
255            if in_flight.is_empty() {
256                break;
257            }
258
259            let joined = match in_flight.next().await {
260                Some(Ok(result)) => result,
261                Some(Err(join_err)) => {
262                    return Err(DagExecutorError::Executor(format!(
263                        "worker task panicked: {join_err}"
264                    )))
265                }
266                None => break,
267            };
268
269            self.handle_result(&dag, &mut scheduler, &mut remaining, &ctx, joined)
270                .await?;
271        }
272
273        Ok(ExecutionReport {
274            run_id: ctx.run_id.clone(),
275            records: scheduler.records().clone(),
276            metrics: self.metrics.snapshot(),
277        })
278    }
279
280    /// Build a scheduler, recovering and repairing any persisted records.
281    async fn recover_scheduler(&self, dag: &Dag, ctx: &Arc<Context>) -> Result<Scheduler> {
282        let mut records: HashMap<String, TaskRecord> = HashMap::new();
283
284        if self.config.persist_state {
285            for id in dag.task_ids() {
286                // A record that fails to load (e.g. a checksum mismatch from a
287                // crash during a Fast in-place write) is treated as absent, so
288                // the task simply re-runs rather than aborting recovery.
289                let value = match self.storage.load(&Self::record_key(&id)).await {
290                    Ok(v) => v,
291                    Err(e) => {
292                        tracing::warn!(task = %id, error = %e, "ignoring unreadable record during recovery");
293                        None
294                    }
295                };
296                if let Some(value) = value {
297                    if let Ok(record) = serde_json::from_value::<TaskRecord>(value) {
298                        // Re-publish recovered outputs so dependents can read them.
299                        if record.state == TaskState::Completed {
300                            if let Some(out) = &record.output {
301                                ctx.set(record.id.clone(), out.clone());
302                            }
303                        }
304                        records.insert(id, record);
305                    }
306                }
307            }
308            self.validator.repair(&mut records, self.retry.max_attempts);
309        }
310
311        let mut scheduler = Scheduler::with_records(records);
312        // Make sure every task has a record (recovered runs may add new tasks).
313        for id in dag.task_ids() {
314            scheduler.ensure_record(&id);
315        }
316        Ok(scheduler)
317    }
318
319    /// Apply the outcome of a finished task and unblock/skip dependents.
320    async fn handle_result(
321        &self,
322        dag: &Dag,
323        scheduler: &mut Scheduler,
324        remaining: &mut HashMap<String, usize>,
325        ctx: &Arc<Context>,
326        result: TaskResult,
327    ) -> Result<()> {
328        let TaskResult {
329            id,
330            attempts,
331            outcome,
332        } = result;
333
334        if let Some(record) = scheduler.records_mut().get_mut(&id) {
335            record.attempts = attempts;
336        }
337
338        match outcome {
339            Ok(output) => {
340                // Publish output for downstream consumers.
341                ctx.set(id.clone(), output.clone());
342                if let Some(record) = scheduler.records_mut().get_mut(&id) {
343                    record.output = Some(output);
344                    record.transition(TaskState::Completed);
345                }
346                let duration = scheduler
347                    .record(&id)
348                    .and_then(|r| r.duration_millis())
349                    .unwrap_or(0);
350                self.metrics.task_completed(&id, duration);
351                self.persist(scheduler, &id).await?;
352
353                // Unblock dependents whose last dependency just completed.
354                for dep in dag.dependents_of(&id) {
355                    let count = remaining.entry(dep.clone()).or_insert(0);
356                    *count = count.saturating_sub(1);
357                    if *count == 0 && scheduler.state(&dep) == Some(TaskState::Pending) {
358                        let prio = dag.task(&dep).map(|t| t.priority()).unwrap_or(0);
359                        scheduler.mark_ready(&dep, prio);
360                    }
361                }
362            }
363            Err(TaskError::Cancelled) => {
364                if let Some(record) = scheduler.records_mut().get_mut(&id) {
365                    record.transition(TaskState::Cancelled);
366                }
367                self.persist(scheduler, &id).await?;
368                self.cascade_skip(dag, scheduler, &id).await?;
369            }
370            Err(err) => {
371                let msg = err.to_string();
372                if let Some(record) = scheduler.records_mut().get_mut(&id) {
373                    record.error = Some(msg.clone());
374                    record.transition(TaskState::Failed);
375                }
376                self.metrics.task_failed();
377
378                // Retries are exhausted by the time we get here: dead-letter it.
379                self.dead_letter.push(&id, attempts, msg).await?;
380                if let Some(record) = scheduler.records_mut().get_mut(&id) {
381                    record.transition(TaskState::DeadLettered);
382                }
383                self.metrics.task_dead_lettered();
384                self.persist(scheduler, &id).await?;
385
386                self.cascade_skip(dag, scheduler, &id).await?;
387            }
388        }
389        Ok(())
390    }
391
392    /// Mark every (transitive) dependent of a failed/cancelled task as skipped.
393    async fn cascade_skip(
394        &self,
395        dag: &Dag,
396        scheduler: &mut Scheduler,
397        failed_id: &str,
398    ) -> Result<()> {
399        let mut stack: Vec<String> = dag.dependents_of(failed_id);
400        while let Some(id) = stack.pop() {
401            if scheduler.state(&id) == Some(TaskState::Pending)
402                && scheduler.transition(&id, TaskState::Skipped)
403            {
404                self.metrics.task_skipped();
405                self.persist(scheduler, &id).await?;
406                stack.extend(dag.dependents_of(&id));
407            }
408        }
409        Ok(())
410    }
411
412    /// Persist a single task's record if persistence is enabled.
413    async fn persist(&self, scheduler: &Scheduler, id: &str) -> Result<()> {
414        if !self.config.persist_state {
415            return Ok(());
416        }
417        if let Some(record) = scheduler.record(id) {
418            let value = serde_json::to_value(record).map_err(crate::error::StorageError::from)?;
419            self.storage.save(&Self::record_key(id), &value).await?;
420        }
421        Ok(())
422    }
423}
424
425impl Default for DagExecutor {
426    fn default() -> Self {
427        DagExecutor::new()
428    }
429}