1mod graph;
10mod scheduler;
11mod worker_pool;
12
13pub use graph::Dag;
14pub use scheduler::Scheduler;
15pub use worker_pool::{TaskResult, WorkerPool};
16
17use crate::advanced::{CircuitBreaker, DeadLetterQueue, RetryPolicy};
18use crate::context::Context;
19use crate::error::{DagExecutorError, Result, TaskError};
20use crate::metrics::{MetricsCollector, MetricsSnapshot};
21use crate::state::{StateValidator, TaskRecord, TaskState};
22use crate::storage::{FileStorage, MemoryStorage, Storage};
23use crate::utils::Config;
24use futures::stream::{FuturesUnordered, StreamExt};
25use std::collections::HashMap;
26use std::sync::Arc;
27use std::time::Duration;
28
29#[derive(Debug, Clone)]
31pub struct ExecutionReport {
32 pub run_id: String,
34 pub records: HashMap<String, TaskRecord>,
36 pub metrics: MetricsSnapshot,
38}
39
40impl ExecutionReport {
41 pub fn is_success(&self) -> bool {
43 !self.records.is_empty()
44 && self
45 .records
46 .values()
47 .all(|r| r.state == TaskState::Completed)
48 }
49
50 pub fn failed_tasks(&self) -> Vec<String> {
52 self.records
53 .values()
54 .filter(|r| r.state.is_failure())
55 .map(|r| r.id.clone())
56 .collect()
57 }
58
59 pub fn count_in(&self, state: TaskState) -> usize {
61 self.records.values().filter(|r| r.state == state).count()
62 }
63}
64
65pub struct DagExecutorBuilder {
67 config: Config,
68 storage: Option<Arc<dyn Storage>>,
69 retry: Option<RetryPolicy>,
70 breaker: Option<Arc<CircuitBreaker>>,
71}
72
73impl DagExecutorBuilder {
74 fn new() -> Self {
75 DagExecutorBuilder {
76 config: Config::default(),
77 storage: None,
78 retry: None,
79 breaker: None,
80 }
81 }
82
83 pub fn config(mut self, config: Config) -> Self {
85 self.config = config;
86 self
87 }
88
89 pub fn concurrency(mut self, n: usize) -> Self {
91 self.config.max_concurrency = n;
92 self
93 }
94
95 pub fn storage(mut self, storage: Arc<dyn Storage>) -> Self {
98 self.storage = Some(storage);
99 self
100 }
101
102 pub fn retry(mut self, retry: RetryPolicy) -> Self {
104 self.retry = Some(retry);
105 self
106 }
107
108 pub fn circuit_breaker(mut self, breaker: CircuitBreaker) -> Self {
110 self.breaker = Some(Arc::new(breaker));
111 self
112 }
113
114 pub fn persist(mut self, persist: bool) -> Self {
116 self.config.persist_state = persist;
117 self
118 }
119
120 pub fn build(self) -> DagExecutor {
122 let config = Arc::new(self.config);
123 let storage: Arc<dyn Storage> = self.storage.unwrap_or_else(|| {
124 if config.persist_state {
125 Arc::new(
126 FileStorage::open(&config.storage_dir)
127 .expect("failed to open storage directory"),
128 )
129 } else {
130 Arc::new(MemoryStorage::new())
131 }
132 });
133 let retry = self.retry.unwrap_or(RetryPolicy {
134 max_attempts: config.max_attempts,
135 ..RetryPolicy::default()
136 });
137
138 DagExecutor {
139 metrics: Arc::new(MetricsCollector::new()),
140 dead_letter: DeadLetterQueue::new(storage.clone()),
141 validator: StateValidator::new(),
142 breaker: self.breaker,
143 timeout: config.task_timeout,
144 retry,
145 storage,
146 config,
147 }
148 }
149}
150
151pub struct DagExecutor {
153 config: Arc<Config>,
154 storage: Arc<dyn Storage>,
155 metrics: Arc<MetricsCollector>,
156 dead_letter: DeadLetterQueue,
157 validator: StateValidator,
158 breaker: Option<Arc<CircuitBreaker>>,
159 retry: RetryPolicy,
160 timeout: Option<Duration>,
161}
162
163impl DagExecutor {
164 pub fn builder() -> DagExecutorBuilder {
166 DagExecutorBuilder::new()
167 }
168
169 pub fn new() -> Self {
171 DagExecutorBuilder::new().build()
172 }
173
174 pub fn metrics(&self) -> &Arc<MetricsCollector> {
176 &self.metrics
177 }
178
179 pub fn dead_letter(&self) -> &DeadLetterQueue {
181 &self.dead_letter
182 }
183
184 fn record_key(id: &str) -> String {
185 format!("record:{id}")
186 }
187
188 pub async fn run(&self, dag: Dag) -> Result<ExecutionReport> {
190 let ctx = Arc::new(Context::new(self.config.clone()));
191 self.run_with_context(dag, ctx).await
192 }
193
194 pub async fn run_with_context(&self, dag: Dag, ctx: Arc<Context>) -> Result<ExecutionReport> {
199 dag.validate()?;
200
201 let mut scheduler = self.recover_scheduler(&dag, &ctx).await?;
202
203 let mut remaining: HashMap<String, usize> = HashMap::new();
205 for id in dag.task_ids() {
206 let pending_deps = dag
207 .dependencies_of(&id)
208 .into_iter()
209 .filter(|d| scheduler.state(d) != Some(TaskState::Completed))
210 .count();
211 remaining.insert(id, pending_deps);
212 }
213
214 let pool = WorkerPool::new(
215 self.config.max_concurrency,
216 self.retry,
217 self.timeout,
218 self.metrics.clone(),
219 );
220
221 let initial_failures: Vec<String> = dag
224 .task_ids()
225 .into_iter()
226 .filter(|id| scheduler.state(id).map(|s| s.is_failure()).unwrap_or(false))
227 .collect();
228 for id in initial_failures {
229 self.cascade_skip(&dag, &mut scheduler, &id).await?;
230 }
231 for id in dag.task_ids() {
232 if scheduler.state(&id) == Some(TaskState::Pending)
233 && remaining.get(&id).copied().unwrap_or(0) == 0
234 {
235 let prio = dag.task(&id).map(|t| t.priority()).unwrap_or(0);
236 scheduler.mark_ready(&id, prio);
237 }
238 }
239
240 let mut in_flight: FuturesUnordered<_> = FuturesUnordered::new();
241
242 loop {
243 while let Some(id) = scheduler.next_ready() {
246 let task = match dag.task(&id) {
247 Some(t) => t,
248 None => continue,
249 };
250 scheduler.transition(&id, TaskState::Running);
251 self.persist(&scheduler, &id).await?;
252 in_flight.push(pool.spawn(task, ctx.clone(), self.breaker.clone()));
253 }
254
255 if in_flight.is_empty() {
256 break;
257 }
258
259 let joined = match in_flight.next().await {
260 Some(Ok(result)) => result,
261 Some(Err(join_err)) => {
262 return Err(DagExecutorError::Executor(format!(
263 "worker task panicked: {join_err}"
264 )))
265 }
266 None => break,
267 };
268
269 self.handle_result(&dag, &mut scheduler, &mut remaining, &ctx, joined)
270 .await?;
271 }
272
273 Ok(ExecutionReport {
274 run_id: ctx.run_id.clone(),
275 records: scheduler.records().clone(),
276 metrics: self.metrics.snapshot(),
277 })
278 }
279
280 async fn recover_scheduler(&self, dag: &Dag, ctx: &Arc<Context>) -> Result<Scheduler> {
282 let mut records: HashMap<String, TaskRecord> = HashMap::new();
283
284 if self.config.persist_state {
285 for id in dag.task_ids() {
286 let value = match self.storage.load(&Self::record_key(&id)).await {
290 Ok(v) => v,
291 Err(e) => {
292 tracing::warn!(task = %id, error = %e, "ignoring unreadable record during recovery");
293 None
294 }
295 };
296 if let Some(value) = value {
297 if let Ok(record) = serde_json::from_value::<TaskRecord>(value) {
298 if record.state == TaskState::Completed {
300 if let Some(out) = &record.output {
301 ctx.set(record.id.clone(), out.clone());
302 }
303 }
304 records.insert(id, record);
305 }
306 }
307 }
308 self.validator.repair(&mut records, self.retry.max_attempts);
309 }
310
311 let mut scheduler = Scheduler::with_records(records);
312 for id in dag.task_ids() {
314 scheduler.ensure_record(&id);
315 }
316 Ok(scheduler)
317 }
318
319 async fn handle_result(
321 &self,
322 dag: &Dag,
323 scheduler: &mut Scheduler,
324 remaining: &mut HashMap<String, usize>,
325 ctx: &Arc<Context>,
326 result: TaskResult,
327 ) -> Result<()> {
328 let TaskResult {
329 id,
330 attempts,
331 outcome,
332 } = result;
333
334 if let Some(record) = scheduler.records_mut().get_mut(&id) {
335 record.attempts = attempts;
336 }
337
338 match outcome {
339 Ok(output) => {
340 ctx.set(id.clone(), output.clone());
342 if let Some(record) = scheduler.records_mut().get_mut(&id) {
343 record.output = Some(output);
344 record.transition(TaskState::Completed);
345 }
346 let duration = scheduler
347 .record(&id)
348 .and_then(|r| r.duration_millis())
349 .unwrap_or(0);
350 self.metrics.task_completed(&id, duration);
351 self.persist(scheduler, &id).await?;
352
353 for dep in dag.dependents_of(&id) {
355 let count = remaining.entry(dep.clone()).or_insert(0);
356 *count = count.saturating_sub(1);
357 if *count == 0 && scheduler.state(&dep) == Some(TaskState::Pending) {
358 let prio = dag.task(&dep).map(|t| t.priority()).unwrap_or(0);
359 scheduler.mark_ready(&dep, prio);
360 }
361 }
362 }
363 Err(TaskError::Cancelled) => {
364 if let Some(record) = scheduler.records_mut().get_mut(&id) {
365 record.transition(TaskState::Cancelled);
366 }
367 self.persist(scheduler, &id).await?;
368 self.cascade_skip(dag, scheduler, &id).await?;
369 }
370 Err(err) => {
371 let msg = err.to_string();
372 if let Some(record) = scheduler.records_mut().get_mut(&id) {
373 record.error = Some(msg.clone());
374 record.transition(TaskState::Failed);
375 }
376 self.metrics.task_failed();
377
378 self.dead_letter.push(&id, attempts, msg).await?;
380 if let Some(record) = scheduler.records_mut().get_mut(&id) {
381 record.transition(TaskState::DeadLettered);
382 }
383 self.metrics.task_dead_lettered();
384 self.persist(scheduler, &id).await?;
385
386 self.cascade_skip(dag, scheduler, &id).await?;
387 }
388 }
389 Ok(())
390 }
391
392 async fn cascade_skip(
394 &self,
395 dag: &Dag,
396 scheduler: &mut Scheduler,
397 failed_id: &str,
398 ) -> Result<()> {
399 let mut stack: Vec<String> = dag.dependents_of(failed_id);
400 while let Some(id) = stack.pop() {
401 if scheduler.state(&id) == Some(TaskState::Pending)
402 && scheduler.transition(&id, TaskState::Skipped)
403 {
404 self.metrics.task_skipped();
405 self.persist(scheduler, &id).await?;
406 stack.extend(dag.dependents_of(&id));
407 }
408 }
409 Ok(())
410 }
411
412 async fn persist(&self, scheduler: &Scheduler, id: &str) -> Result<()> {
414 if !self.config.persist_state {
415 return Ok(());
416 }
417 if let Some(record) = scheduler.record(id) {
418 let value = serde_json::to_value(record).map_err(crate::error::StorageError::from)?;
419 self.storage.save(&Self::record_key(id), &value).await?;
420 }
421 Ok(())
422 }
423}
424
425impl Default for DagExecutor {
426 fn default() -> Self {
427 DagExecutor::new()
428 }
429}