1use crate::dag::{DagDefinition, DagRun, TaskRunStatus};
2use crate::operators::OperatorRegistry;
3use crate::store::Store;
4use anyhow::Result;
5use std::collections::{HashMap, HashSet};
6use tracing::{info, warn};
7
8pub struct DagExecutor {
9 store: std::sync::Arc<Store>,
10}
11
12impl DagExecutor {
13 pub fn new(store: std::sync::Arc<Store>) -> Self {
14 DagExecutor { store }
15 }
16
17 pub async fn execute(&self, dag: &DagDefinition, dag_run: &DagRun) -> Result<()> {
19 info!("Starting execution of DAG run: {}", dag_run.id);
20
21 self.store
23 .update_dag_run_status(&dag_run.id, crate::dag::DagRunStatus::Running)
24 .await?;
25
26 let mut task_runs = HashMap::new();
28 for task in &dag.tasks {
29 let task_run = self.store.create_task_run(&dag_run.id, &task.id).await?;
30 task_runs.insert(task.id.clone(), task_run);
31 }
32
33 let mut completed_tasks = HashSet::new();
35 let mut failed_tasks = HashSet::new();
36 let mut running_tasks = HashSet::new();
37 let mut join_set = tokio::task::JoinSet::new();
38
39 loop {
40 let runnable_tasks: Vec<String> = dag
42 .tasks
43 .iter()
44 .filter(|task| {
45 !completed_tasks.contains(&task.id)
46 && !failed_tasks.contains(&task.id)
47 && !running_tasks.contains(&task.id)
48 && dag.dependencies_satisfied(&task.id, &completed_tasks)
49 })
50 .map(|t| t.id.clone())
51 .collect();
52
53 for task_id in runnable_tasks {
55 running_tasks.insert(task_id.clone());
56
57 let task = dag.get_task(&task_id).unwrap();
58 let task_run = task_runs[&task_id].clone();
59 let store = std::sync::Arc::clone(&self.store);
60 let dag_def = dag.clone();
61 let task_def = task.clone();
62 let task_id_clone = task_id.clone();
63 let dag_run_id = dag_run.id.clone();
64
65 join_set.spawn(async move {
66 (
67 task_id_clone,
68 Self::execute_task(&store, &dag_def, &dag_run_id, &task_run, &task_def).await,
69 )
70 });
71 }
72
73 if join_set.is_empty() {
75 break;
76 }
77
78 if let Some(res) = join_set.join_next().await {
80 let (task_id, result) = res?;
81 running_tasks.remove(&task_id);
82
83 if result.is_ok() {
84 completed_tasks.insert(task_id);
85 } else {
86 failed_tasks.insert(task_id);
87 }
88 }
89 }
90
91 let dag_status = if failed_tasks.is_empty() {
93 crate::dag::DagRunStatus::Success
94 } else {
95 crate::dag::DagRunStatus::Failed
96 };
97
98 self.store
99 .update_dag_run_status(&dag_run.id, dag_status)
100 .await?;
101
102 info!("Completed execution of DAG run: {}", dag_run.id);
103 Ok(())
104 }
105
106 async fn execute_task(
107 store: &std::sync::Arc<Store>,
108 _dag: &DagDefinition,
109 dag_run_id: &str,
110 task_run: &crate::dag::TaskRun,
111 task_def: &crate::dag::TaskDefinition,
112 ) -> Result<()> {
113 let mut attempt = task_run.attempt_number;
114 let max_attempts = task_def.retries.unwrap_or(0) + 1;
115
116 loop {
117 info!("Executing task: {} (attempt {}/{})", task_def.id, attempt, max_attempts);
118
119 store
121 .update_task_run(&task_run.id, TaskRunStatus::Running, None, None)
122 .await?;
123
124 let mut task_config = task_def.config.clone();
126
127 for upstream_task_id in task_def.xcom_dependencies() {
129 if let Ok(Some(xcom_output)) = store.get_xcom(dag_run_id, &upstream_task_id).await {
130 if let Ok(xcom_json) = serde_json::from_str::<serde_json::Value>(&xcom_output) {
132 if !task_config.is_object() {
134 task_config = serde_json::json!({});
135 }
136 if let Some(obj) = task_config.as_object_mut() {
137 if !obj.contains_key("xcom") {
138 obj.insert("xcom".to_string(), serde_json::json!({}));
139 }
140 if let Some(xcom_obj) = obj.get_mut("xcom").and_then(|x| x.as_object_mut()) {
141 xcom_obj.insert(upstream_task_id.clone(), xcom_json);
142 }
143 }
144 }
145 }
146 }
147
148 let operator = OperatorRegistry::get_operator(&task_def.operator)
150 .ok_or_else(|| anyhow::anyhow!("Unknown operator: {}", task_def.operator))?;
151
152 let timeout_secs = task_def.timeout_secs.unwrap_or(3600); let execution_result = tokio::time::timeout(
155 tokio::time::Duration::from_secs(timeout_secs),
156 operator.execute(&task_config)
157 ).await;
158
159 let final_result = match execution_result {
160 Ok(res) => res,
161 Err(_) => Err(anyhow::anyhow!("Task execution timed out after {} seconds", timeout_secs)),
162 };
163
164 match final_result {
165 Ok(output) => {
166 info!("Task {} succeeded", task_def.id);
167 let output_clone = output.clone();
168 store
169 .update_task_run(
170 &task_run.id,
171 TaskRunStatus::Success,
172 Some(&output),
173 Some(output_clone),
174 )
175 .await?;
176 return Ok(());
177 }
178 Err(e) => {
179 warn!("Task {} failed (attempt {}/{}): {}", task_def.id, attempt, max_attempts, e);
180
181 if attempt < max_attempts {
182 store
183 .update_task_run(
184 &task_run.id,
185 TaskRunStatus::Retried,
186 Some(&e.to_string()),
187 None,
188 )
189 .await?;
190
191 let delay = task_def.retry_delay_secs.unwrap_or(60);
193 tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await;
194
195 attempt += 1;
197 store.increment_task_run_attempt(&task_run.id).await?;
198 continue;
199 } else {
200 store
201 .update_task_run(
202 &task_run.id,
203 TaskRunStatus::Failed,
204 Some(&e.to_string()),
205 None,
206 )
207 .await?;
208 return Err(e);
209 }
210 }
211 }
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{TaskDefinition, TriggerType};

    /// Builds a one-task DAG whose only task runs `echo 'test'` via the `bash` operator.
    fn single_bash_task_dag() -> DagDefinition {
        let echo_task = TaskDefinition {
            id: "simple_task".to_string(),
            operator: "bash".to_string(),
            depends_on: None,
            retries: None,
            retry_delay_secs: None,
            timeout_secs: None,
            xcom_inputs: None,
            config: serde_json::json!({
                "command": "echo 'test'"
            }),
        };

        DagDefinition {
            id: "test_dag".to_string(),
            description: None,
            schedule: None,
            max_active_runs: None,
            catchup: None,
            tasks: vec![echo_task],
        }
    }

    /// A manually triggered single-task DAG executes without an executor error.
    #[tokio::test]
    async fn test_executor_simple_dag() {
        let store = std::sync::Arc::new(Store::new("sqlite::memory:").await.unwrap());
        let dag = single_bash_task_dag();

        store.save_dag(&dag).await.unwrap();
        let dag_run = store
            .create_dag_run(&dag.id, TriggerType::Manual)
            .await
            .unwrap();

        let outcome = DagExecutor::new(std::sync::Arc::clone(&store))
            .execute(&dag, &dag_run)
            .await;

        assert!(outcome.is_ok());
    }
}