use std::sync::Arc;
use clap::Parser;
use dag_executor::context::Context;
use dag_executor::prelude::*;
use dag_executor::utils::tracing as dag_tracing;
// Command-line options for the demonstration workflow binary.
//
// This is deliberately a plain `//` comment so clap's derive does not pick it
// up as help text; the `about` string below remains the program description.
// The `///` comments on the fields, by contrast, are the per-flag help shown
// by `--help`.
#[derive(Parser, Debug)]
#[command(
    name = "dag-executor",
    version,
    about = "Run a demonstration DAG workflow"
)]
struct Cli {
    /// Path to a configuration file; built-in defaults are used when omitted
    #[arg(long)]
    config: Option<String>,
    /// Maximum concurrency written into the executor configuration
    #[arg(long, default_value_t = 64)]
    concurrency: usize,
    /// Number of workers in the demo workflow's fan-out stage
    #[arg(long, default_value_t = 8)]
    workers: usize,
    /// Disable state persistence for this run
    #[arg(long)]
    no_persist: bool,
    /// Log level used to initialise tracing and stored in the configuration
    #[arg(long, default_value = "info")]
    log_level: String,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
dag_tracing::init(&cli.log_level);
let mut config = match &cli.config {
Some(path) => Config::from_file(path)?,
None => Config::default(),
}
.with_env_overrides();
config.max_concurrency = cli.concurrency;
config.persist_state = !cli.no_persist;
config.log_level = cli.log_level.clone();
let executor = DagExecutor::builder().config(config).build();
let dag = build_demo_dag(cli.workers)?;
let ctx = Arc::new(Context::new(Arc::new(Config::default())));
let cancel_ctx = ctx.clone();
tokio::spawn(async move {
if tokio::signal::ctrl_c().await.is_ok() {
tracing::warn!("received Ctrl-C, cancelling run");
cancel_ctx.cancel();
}
});
tracing::info!(workers = cli.workers, "starting demo workflow");
let report = executor.run_with_context(dag, ctx).await?;
let snap = &report.metrics;
tracing::info!(
run_id = %report.run_id,
completed = snap.tasks_completed,
failed = snap.tasks_failed,
retries = snap.retries,
skipped = snap.tasks_skipped,
success_rate = snap.success_rate(),
"run finished"
);
if report.is_success() {
println!(
"✓ workflow completed successfully ({} tasks)",
report.records.len()
);
Ok(())
} else {
let failed = report.failed_tasks();
println!("✗ workflow finished with failures: {failed:?}");
std::process::exit(1);
}
}
/// Assembles the demonstration workflow graph.
///
/// The graph consists of three parts, added in this order:
///
/// 1. `setup` — stores `{ "ready": true }` under the context key `config` and
///    yields `"setup-done"`.
/// 2. The tasks produced by `patterns::fan_out_in` for the name `stage`, with
///    `workers` as the count and `setup` as the upstream dependency. Each
///    worker sleeps 5, 10, 15 or 20 ms depending on its index and yields
///    `{ "worker": i, "value": i * i }`; the aggregation closure sums the
///    integer `value` fields into `{ "sum_of_squares": … }`.
/// 3. `summary` — depends on `stage.aggregate`, reads the context key of the
///    same name (falling back to JSON `null`), logs it and yields it.
///
/// # Errors
///
/// Propagates the first error returned by `Dag::add_task`.
fn build_demo_dag(workers: usize) -> anyhow::Result<Dag> {
    let mut graph = Dag::new();

    // Root task every stage worker depends on.
    let setup = BasicTask::new("setup", |ctx: Arc<Context>| async move {
        ctx.set("config", serde_json::json!({ "ready": true }));
        Ok(serde_json::json!("setup-done"))
    });
    graph.add_task(Arc::new(setup))?;

    // Fan-out / fan-in stage; the closures stay inline so their parameter
    // types are inferred from `fan_out_in`'s signature.
    let stage_tasks = patterns::fan_out_in(
        "stage",
        workers,
        Some("setup"),
        |_ctx, idx| async move {
            // Staggered delay: 5 ms times (index mod 4, plus one).
            let delay_ms = 5 * (idx as u64 % 4 + 1);
            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
            Ok(serde_json::json!({ "worker": idx, "value": idx * idx }))
        },
        |_ctx, outputs| async move {
            // Entries without an integer `value` field are skipped.
            let total: i64 = outputs
                .iter()
                .filter_map(|out| out.get("value")?.as_i64())
                .sum();
            Ok(serde_json::json!({ "sum_of_squares": total }))
        },
    );
    for task in stage_tasks {
        graph.add_task(task)?;
    }

    // Final task: report whatever is stored under `stage.aggregate`.
    let summary = BasicTask::new("summary", |ctx: Arc<Context>| async move {
        let aggregate = match ctx.get("stage.aggregate") {
            Some(value) => value,
            None => serde_json::Value::Null,
        };
        tracing::info!(agg = ?aggregate, "aggregated result");
        Ok(aggregate)
    })
    .with_deps(["stage.aggregate"]);
    graph.add_task(Arc::new(summary))?;

    Ok(graph)
}