mod graph;
mod scheduler;
mod worker_pool;
pub use graph::Dag;
pub use scheduler::Scheduler;
pub use worker_pool::{TaskResult, WorkerPool};
use crate::advanced::{CircuitBreaker, DeadLetterQueue, RetryPolicy};
use crate::context::Context;
use crate::error::{DagExecutorError, Result, TaskError};
use crate::metrics::{MetricsCollector, MetricsSnapshot};
use crate::state::{StateValidator, TaskRecord, TaskState};
use crate::storage::{FileStorage, MemoryStorage, Storage};
use crate::utils::Config;
use futures::stream::{FuturesUnordered, StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Result of one DAG run: the final task records plus a metrics snapshot.
#[derive(Debug, Clone)]
pub struct ExecutionReport {
/// Run identifier, copied from the execution context's `run_id`.
pub run_id: String,
/// Task records cloned from the scheduler when the run finished.
/// Presumably keyed by task id, as they are during recovery — confirm against `Scheduler`.
pub records: HashMap<String, TaskRecord>,
/// Snapshot of the executor's metrics collector taken when the report was built.
pub metrics: MetricsSnapshot,
}
impl ExecutionReport {
    /// Returns `true` only when the report holds at least one record and every
    /// record is in `TaskState::Completed`. A report without records is not a success.
    pub fn is_success(&self) -> bool {
        if self.records.is_empty() {
            return false;
        }
        self.records
            .values()
            .all(|record| record.state == TaskState::Completed)
    }

    /// Collects the ids of all records whose state reports `is_failure()`.
    ///
    /// The order follows `HashMap` iteration and is therefore unspecified.
    pub fn failed_tasks(&self) -> Vec<String> {
        let mut failed = Vec::new();
        for record in self.records.values() {
            if record.state.is_failure() {
                failed.push(record.id.clone());
            }
        }
        failed
    }

    /// Counts the records whose state equals `state`.
    pub fn count_in(&self, state: TaskState) -> usize {
        let mut matching = 0;
        for record in self.records.values() {
            if record.state == state {
                matching += 1;
            }
        }
        matching
    }
}
/// Builder for `DagExecutor`; options left unset are filled in by `build`.
pub struct DagExecutorBuilder {
/// Executor configuration; starts as `Config::default()`.
config: Config,
/// Explicit storage backend. When `None`, `build` chooses one based on `config.persist_state`.
storage: Option<Arc<dyn Storage>>,
/// Explicit retry policy. When `None`, `build` derives one from `config.max_attempts`.
retry: Option<RetryPolicy>,
/// Optional circuit breaker, passed through to the executor unchanged.
breaker: Option<Arc<CircuitBreaker>>,
}
impl DagExecutorBuilder {
    /// Starts a builder with the default configuration and no overrides.
    fn new() -> Self {
        Self {
            config: Config::default(),
            storage: None,
            retry: None,
            breaker: None,
        }
    }

    /// Replaces the whole configuration.
    pub fn config(self, config: Config) -> Self {
        Self { config, ..self }
    }

    /// Sets `max_concurrency` on the current configuration.
    pub fn concurrency(self, n: usize) -> Self {
        let mut builder = self;
        builder.config.max_concurrency = n;
        builder
    }

    /// Uses `storage` instead of the backend `build` would pick on its own.
    pub fn storage(self, storage: Arc<dyn Storage>) -> Self {
        Self {
            storage: Some(storage),
            ..self
        }
    }

    /// Uses `retry` instead of the policy derived from the configuration.
    pub fn retry(self, retry: RetryPolicy) -> Self {
        Self {
            retry: Some(retry),
            ..self
        }
    }

    /// Installs a circuit breaker; it is wrapped in an `Arc` for sharing.
    pub fn circuit_breaker(self, breaker: CircuitBreaker) -> Self {
        Self {
            breaker: Some(Arc::new(breaker)),
            ..self
        }
    }

    /// Sets `persist_state` on the current configuration.
    pub fn persist(self, persist: bool) -> Self {
        let mut builder = self;
        builder.config.persist_state = persist;
        builder
    }

    /// Assembles the executor from the collected options.
    ///
    /// Without an explicit storage backend, a `FileStorage` opened at
    /// `config.storage_dir` is used when `config.persist_state` is set and a
    /// `MemoryStorage` otherwise. Without an explicit retry policy, the default
    /// policy is used with `max_attempts` taken from the configuration.
    ///
    /// # Panics
    ///
    /// Panics if the file storage has to be opened and `FileStorage::open` fails.
    pub fn build(self) -> DagExecutor {
        let Self {
            config,
            storage,
            retry,
            breaker,
        } = self;
        let config = Arc::new(config);

        let storage: Arc<dyn Storage> = match storage {
            Some(custom) => custom,
            None if config.persist_state => Arc::new(
                FileStorage::open(&config.storage_dir).expect("failed to open storage directory"),
            ),
            None => Arc::new(MemoryStorage::new()),
        };

        // Built eagerly, exactly like the fallback argument it replaces.
        let derived_retry = RetryPolicy {
            max_attempts: config.max_attempts,
            ..RetryPolicy::default()
        };
        let retry = retry.unwrap_or(derived_retry);

        let timeout = config.task_timeout;
        let metrics = Arc::new(MetricsCollector::new());
        // The dead-letter queue shares the executor's storage backend.
        let dead_letter = DeadLetterQueue::new(storage.clone());
        let validator = StateValidator::new();

        DagExecutor {
            config,
            storage,
            metrics,
            dead_letter,
            validator,
            breaker,
            retry,
            timeout,
        }
    }
}
/// Runs a `Dag` to completion through a worker pool, optionally persisting task records.
pub struct DagExecutor {
/// Shared configuration; read here for `max_concurrency` and `persist_state`.
config: Arc<Config>,
/// Backend used to load and save task records under `record:<id>` keys.
storage: Arc<dyn Storage>,
/// Collector updated as tasks complete, fail, are skipped or are dead-lettered.
metrics: Arc<MetricsCollector>,
/// Queue that receives tasks whose execution ended in a non-cancellation error.
dead_letter: DeadLetterQueue,
/// Used during recovery to repair records restored from storage.
validator: StateValidator,
/// Optional circuit breaker handed to every spawned task.
breaker: Option<Arc<CircuitBreaker>>,
/// Retry policy handed to the worker pool.
retry: RetryPolicy,
/// Optional timeout handed to the worker pool; copied from `config.task_timeout` by the builder.
timeout: Option<Duration>,
}
impl DagExecutor {
/// Returns a `DagExecutorBuilder` with default settings for configuring an executor.
pub fn builder() -> DagExecutorBuilder {
DagExecutorBuilder::new()
}
/// Creates an executor with every builder option left at its default.
///
/// # Panics
///
/// Inherits the panic of `DagExecutorBuilder::build`, which fires when the
/// default configuration enables persistence and the file storage cannot be opened.
pub fn new() -> Self {
DagExecutorBuilder::new().build()
}
/// Returns the shared metrics collector that this executor updates while running.
pub fn metrics(&self) -> &Arc<MetricsCollector> {
&self.metrics
}
/// Returns the dead-letter queue that failed tasks are pushed to.
pub fn dead_letter(&self) -> &DeadLetterQueue {
&self.dead_letter
}
/// Builds the storage key for the record of task `id`: the id prefixed with `record:`.
fn record_key(id: &str) -> String {
    let mut key = String::from("record:");
    key.push_str(id);
    key
}
/// Executes `dag` with a fresh `Context` created from this executor's configuration.
///
/// All work is delegated to `run_with_context`, whose result is returned unchanged.
pub async fn run(&self, dag: Dag) -> Result<ExecutionReport> {
    let config = Arc::clone(&self.config);
    self.run_with_context(dag, Arc::new(Context::new(config)))
        .await
}
/// Executes `dag` against the supplied context and reports the final task records.
///
/// The DAG is validated, scheduler state is rebuilt through `recover_scheduler`,
/// dependents of tasks already in a failure state are skipped, and then ready
/// tasks are dispatched to a `WorkerPool` until nothing is ready or in flight.
///
/// # Errors
///
/// Propagates errors from `dag.validate()`, recovery, persistence, skip
/// cascading and result handling. Returns `DagExecutorError::Executor` when
/// awaiting an in-flight task yields an `Err` (reported as a worker panic).
pub async fn run_with_context(&self, dag: Dag, ctx: Arc<Context>) -> Result<ExecutionReport> {
dag.validate()?;
let mut scheduler = self.recover_scheduler(&dag, &ctx).await?;
// For every task, the number of dependencies that are not (yet) Completed.
let mut remaining: HashMap<String, usize> = HashMap::new();
for id in dag.task_ids() {
let pending_deps = dag
.dependencies_of(&id)
.into_iter()
.filter(|d| scheduler.state(d) != Some(TaskState::Completed))
.count();
remaining.insert(id, pending_deps);
}
let pool = WorkerPool::new(
self.config.max_concurrency,
self.retry,
self.timeout,
self.metrics.clone(),
);
// Tasks already in a failure state cannot unblock their dependents, so
// those dependents are skipped before anything is scheduled.
let initial_failures: Vec<String> = dag
.task_ids()
.into_iter()
.filter(|id| scheduler.state(id).map(|s| s.is_failure()).unwrap_or(false))
.collect();
for id in initial_failures {
self.cascade_skip(&dag, &mut scheduler, &id).await?;
}
// Seed the ready queue with Pending tasks that have no outstanding dependencies.
for id in dag.task_ids() {
if scheduler.state(&id) == Some(TaskState::Pending)
&& remaining.get(&id).copied().unwrap_or(0) == 0
{
let prio = dag.task(&id).map(|t| t.priority()).unwrap_or(0);
scheduler.mark_ready(&id, prio);
}
}
let mut in_flight: FuturesUnordered<_> = FuturesUnordered::new();
// Each iteration dispatches everything that is ready, then waits for one
// in-flight task to finish; handling its result may make more tasks ready.
loop {
while let Some(id) = scheduler.next_ready() {
let task = match dag.task(&id) {
Some(t) => t,
// The DAG has no task for this id, so there is nothing to run.
None => continue,
};
scheduler.transition(&id, TaskState::Running);
// Save the Running state before the task is handed to the pool.
self.persist(&scheduler, &id).await?;
in_flight.push(pool.spawn(task, ctx.clone(), self.breaker.clone()));
}
// Nothing is running and nothing became ready: the run is over.
if in_flight.is_empty() {
break;
}
let joined = match in_flight.next().await {
Some(Ok(result)) => result,
// Awaiting the spawned task failed; the whole run is aborted.
Some(Err(join_err)) => {
return Err(DagExecutorError::Executor(format!(
"worker task panicked: {join_err}"
)))
}
None => break,
};
self.handle_result(&dag, &mut scheduler, &mut remaining, &ctx, joined)
.await?;
}
Ok(ExecutionReport {
run_id: ctx.run_id.clone(),
records: scheduler.records().clone(),
metrics: self.metrics.snapshot(),
})
}
async fn recover_scheduler(&self, dag: &Dag, ctx: &Arc<Context>) -> Result<Scheduler> {
let mut records: HashMap<String, TaskRecord> = HashMap::new();
if self.config.persist_state {
for id in dag.task_ids() {
let value = match self.storage.load(&Self::record_key(&id)).await {
Ok(v) => v,
Err(e) => {
tracing::warn!(task = %id, error = %e, "ignoring unreadable record during recovery");
None
}
};
if let Some(value) = value {
if let Ok(record) = serde_json::from_value::<TaskRecord>(value) {
if record.state == TaskState::Completed {
if let Some(out) = &record.output {
ctx.set(record.id.clone(), out.clone());
}
}
records.insert(id, record);
}
}
}
self.validator.repair(&mut records, self.retry.max_attempts);
}
let mut scheduler = Scheduler::with_records(records);
for id in dag.task_ids() {
scheduler.ensure_record(&id);
}
Ok(scheduler)
}
/// Applies the result of one finished task to the scheduler, metrics and storage.
///
/// * Success: the output is stored in `ctx` and on the record, the record moves
///   to `Completed`, and every dependent whose outstanding-dependency counter
///   reaches zero while still `Pending` is marked ready.
/// * `TaskError::Cancelled`: the record moves to `Cancelled` and its dependents
///   are skipped.
/// * Any other error: the message is stored on the record, the task is pushed to
///   the dead-letter queue, the record ends in `DeadLettered`, and its dependents
///   are skipped.
///
/// # Errors
///
/// Propagates failures from persistence, the dead-letter queue and `cascade_skip`.
async fn handle_result(
    &self,
    dag: &Dag,
    scheduler: &mut Scheduler,
    remaining: &mut HashMap<String, usize>,
    ctx: &Arc<Context>,
    result: TaskResult,
) -> Result<()> {
    let TaskResult {
        id: task_id,
        attempts: attempt_count,
        outcome,
    } = result;

    // The attempt count is recorded regardless of how the task ended.
    if let Some(entry) = scheduler.records_mut().get_mut(&task_id) {
        entry.attempts = attempt_count;
    }

    let failure = match outcome {
        Ok(output) => {
            ctx.set(task_id.clone(), output.clone());
            if let Some(entry) = scheduler.records_mut().get_mut(&task_id) {
                entry.output = Some(output);
                entry.transition(TaskState::Completed);
            }
            let elapsed_ms = match scheduler.record(&task_id) {
                Some(entry) => entry.duration_millis().unwrap_or(0),
                None => 0,
            };
            self.metrics.task_completed(&task_id, elapsed_ms);
            self.persist(scheduler, &task_id).await?;

            for child in dag.dependents_of(&task_id) {
                // One fewer dependency is outstanding for this dependent.
                let outstanding = {
                    let slot = remaining.entry(child.clone()).or_insert(0);
                    *slot = slot.saturating_sub(1);
                    *slot
                };
                if outstanding != 0 || scheduler.state(&child) != Some(TaskState::Pending) {
                    continue;
                }
                let priority = dag.task(&child).map(|t| t.priority()).unwrap_or(0);
                scheduler.mark_ready(&child, priority);
            }
            return Ok(());
        }
        Err(failure) => failure,
    };

    match failure {
        TaskError::Cancelled => {
            if let Some(entry) = scheduler.records_mut().get_mut(&task_id) {
                entry.transition(TaskState::Cancelled);
            }
            self.persist(scheduler, &task_id).await?;
        }
        other => {
            let reason = other.to_string();
            if let Some(entry) = scheduler.records_mut().get_mut(&task_id) {
                entry.error = Some(reason.clone());
                entry.transition(TaskState::Failed);
            }
            self.metrics.task_failed();
            self.dead_letter
                .push(&task_id, attempt_count, reason)
                .await?;
            // The record only moves on to DeadLettered once the push has succeeded.
            if let Some(entry) = scheduler.records_mut().get_mut(&task_id) {
                entry.transition(TaskState::DeadLettered);
            }
            self.metrics.task_dead_lettered();
            self.persist(scheduler, &task_id).await?;
        }
    }

    // Both failure paths end by skipping everything downstream of the task.
    self.cascade_skip(dag, scheduler, &task_id).await?;
    Ok(())
}
/// Marks every still-`Pending` task downstream of `failed_id` as `Skipped`.
///
/// Dependents are walked depth-first using an explicit stack. A task is only
/// processed when it is `Pending` and the scheduler accepts the transition to
/// `Skipped`; each skip is counted in the metrics and persisted before the
/// task's own dependents are queued.
///
/// NOTE(review): traversal does not continue through tasks that are not
/// `Pending` (for example ones already `Skipped`), so their dependents are not
/// visited by this call — confirm callers do not rely on that.
///
/// # Errors
///
/// Propagates the first persistence failure.
async fn cascade_skip(
    &self,
    dag: &Dag,
    scheduler: &mut Scheduler,
    failed_id: &str,
) -> Result<()> {
    let mut frontier: Vec<String> = dag.dependents_of(failed_id);
    loop {
        let candidate = match frontier.pop() {
            Some(next) => next,
            None => return Ok(()),
        };
        if scheduler.state(&candidate) != Some(TaskState::Pending) {
            continue;
        }
        if !scheduler.transition(&candidate, TaskState::Skipped) {
            continue;
        }
        self.metrics.task_skipped();
        self.persist(scheduler, &candidate).await?;
        frontier.extend(dag.dependents_of(&candidate));
    }
}
async fn persist(&self, scheduler: &Scheduler, id: &str) -> Result<()> {
if !self.config.persist_state {
return Ok(());
}
if let Some(record) = scheduler.record(id) {
let value = serde_json::to_value(record).map_err(crate::error::StorageError::from)?;
self.storage.save(&Self::record_key(id), &value).await?;
}
Ok(())
}
}
impl Default for DagExecutor {
fn default() -> Self {
DagExecutor::new()
}
}