//! ironflow/executor/mod.rs — DAG execution engine: runs a `DagRun` by
//! spawning tasks whose dependencies are satisfied and recording results
//! in the metadata store.
use std::collections::{HashMap, HashSet};

use anyhow::Result;
use tracing::{info, warn};

use crate::dag::{DagDefinition, DagRun, TaskRunStatus};
use crate::operators::OperatorRegistry;
use crate::store::Store;
7
8pub struct DagExecutor {
9    store: std::sync::Arc<Store>,
10}
11
12impl DagExecutor {
13    pub fn new(store: std::sync::Arc<Store>) -> Self {
14        DagExecutor { store }
15    }
16
17    /// Execute a DAG run
18    pub async fn execute(&self, dag: &DagDefinition, dag_run: &DagRun) -> Result<()> {
19        info!("Starting execution of DAG run: {}", dag_run.id);
20
21        // Mark DAG run as running
22        self.store
23            .update_dag_run_status(&dag_run.id, crate::dag::DagRunStatus::Running)
24            .await?;
25
26        // Create task runs for all tasks
27        let mut task_runs = HashMap::new();
28        for task in &dag.tasks {
29            let task_run = self.store.create_task_run(&dag_run.id, &task.id).await?;
30            task_runs.insert(task.id.clone(), task_run);
31        }
32
33        // Execute tasks respecting dependencies
34        let mut completed_tasks = HashSet::new();
35        let mut failed_tasks = HashSet::new();
36        let mut running_tasks = HashSet::new();
37        let mut join_set = tokio::task::JoinSet::new();
38
39        loop {
40            // Find tasks that can run (all dependencies satisfied)
41            let runnable_tasks: Vec<String> = dag
42                .tasks
43                .iter()
44                .filter(|task| {
45                    !completed_tasks.contains(&task.id)
46                        && !failed_tasks.contains(&task.id)
47                        && !running_tasks.contains(&task.id)
48                        && dag.dependencies_satisfied(&task.id, &completed_tasks)
49                })
50                .map(|t| t.id.clone())
51                .collect();
52
53            // Spawn runnable tasks into the JoinSet
54            for task_id in runnable_tasks {
55                running_tasks.insert(task_id.clone());
56                
57                let task = dag.get_task(&task_id).unwrap();
58                let task_run = task_runs[&task_id].clone();
59                let store = std::sync::Arc::clone(&self.store);
60                let dag_def = dag.clone();
61                let task_def = task.clone();
62                let task_id_clone = task_id.clone();
63                let dag_run_id = dag_run.id.clone();
64
65                join_set.spawn(async move {
66                    (
67                        task_id_clone,
68                        Self::execute_task(&store, &dag_def, &dag_run_id, &task_run, &task_def).await,
69                    )
70                });
71            }
72
73            // If nothing is running and nothing is runnable, we are done
74            if join_set.is_empty() {
75                break;
76            }
77
78            // Wait for the next task to complete
79            if let Some(res) = join_set.join_next().await {
80                let (task_id, result) = res?;
81                running_tasks.remove(&task_id);
82                
83                if result.is_ok() {
84                    completed_tasks.insert(task_id);
85                } else {
86                    failed_tasks.insert(task_id);
87                }
88            }
89        }
90
91        // Determine overall DAG run status
92        let dag_status = if failed_tasks.is_empty() {
93            crate::dag::DagRunStatus::Success
94        } else {
95            crate::dag::DagRunStatus::Failed
96        };
97
98        self.store
99            .update_dag_run_status(&dag_run.id, dag_status)
100            .await?;
101
102        info!("Completed execution of DAG run: {}", dag_run.id);
103        Ok(())
104    }
105
106    async fn execute_task(
107        store: &std::sync::Arc<Store>,
108        _dag: &DagDefinition,
109        dag_run_id: &str,
110        task_run: &crate::dag::TaskRun,
111        task_def: &crate::dag::TaskDefinition,
112    ) -> Result<()> {
113        let mut attempt = task_run.attempt_number;
114        let max_attempts = task_def.retries.unwrap_or(0) + 1;
115
116        loop {
117            info!("Executing task: {} (attempt {}/{})", task_def.id, attempt, max_attempts);
118
119            // Mark task as running
120            store
121                .update_task_run(&task_run.id, TaskRunStatus::Running, None, None)
122                .await?;
123
124            // Prepare task config with XCom injections
125            let mut task_config = task_def.config.clone();
126            
127            // Inject XCom outputs from upstream tasks
128            for upstream_task_id in task_def.xcom_dependencies() {
129                if let Ok(Some(xcom_output)) = store.get_xcom(dag_run_id, &upstream_task_id).await {
130                    // Parse the XCom output as JSON
131                    if let Ok(xcom_json) = serde_json::from_str::<serde_json::Value>(&xcom_output) {
132                        // Inject under xcom.<task_id> key
133                        if !task_config.is_object() {
134                            task_config = serde_json::json!({});
135                        }
136                        if let Some(obj) = task_config.as_object_mut() {
137                            if !obj.contains_key("xcom") {
138                                obj.insert("xcom".to_string(), serde_json::json!({}));
139                            }
140                            if let Some(xcom_obj) = obj.get_mut("xcom").and_then(|x| x.as_object_mut()) {
141                                xcom_obj.insert(upstream_task_id.clone(), xcom_json);
142                            }
143                        }
144                    }
145                }
146            }
147
148            // Get the operator
149            let operator = OperatorRegistry::get_operator(&task_def.operator)
150                .ok_or_else(|| anyhow::anyhow!("Unknown operator: {}", task_def.operator))?;
151
152            // Execute the operator with timeout
153            let timeout_secs = task_def.timeout_secs.unwrap_or(3600); // 1 hour default
154            let execution_result = tokio::time::timeout(
155                tokio::time::Duration::from_secs(timeout_secs),
156                operator.execute(&task_config)
157            ).await;
158
159            let final_result = match execution_result {
160                Ok(res) => res,
161                Err(_) => Err(anyhow::anyhow!("Task execution timed out after {} seconds", timeout_secs)),
162            };
163
164            match final_result {
165                Ok(output) => {
166                    info!("Task {} succeeded", task_def.id);
167                    let output_clone = output.clone();
168                    store
169                        .update_task_run(
170                            &task_run.id,
171                            TaskRunStatus::Success,
172                            Some(&output),
173                            Some(output_clone),
174                        )
175                        .await?;
176                    return Ok(());
177                }
178                Err(e) => {
179                    warn!("Task {} failed (attempt {}/{}): {}", task_def.id, attempt, max_attempts, e);
180
181                    if attempt < max_attempts {
182                        store
183                            .update_task_run(
184                                &task_run.id,
185                                TaskRunStatus::Retried,
186                                Some(&e.to_string()),
187                                None,
188                            )
189                            .await?;
190
191                        // Wait before retrying
192                        let delay = task_def.retry_delay_secs.unwrap_or(60);
193                        tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await;
194
195                        // Increment attempt and retry
196                        attempt += 1;
197                        store.increment_task_run_attempt(&task_run.id).await?;
198                        continue;
199                    } else {
200                        store
201                            .update_task_run(
202                                &task_run.id,
203                                TaskRunStatus::Failed,
204                                Some(&e.to_string()),
205                                None,
206                            )
207                            .await?;
208                        return Err(e);
209                    }
210                }
211            }
212        }
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::{TaskDefinition, TriggerType};

    /// End-to-end smoke test: a single-task DAG (bash echo) executes through
    /// an in-memory SQLite store and the run completes without error.
    #[tokio::test]
    async fn test_executor_simple_dag() {
        let store = std::sync::Arc::new(Store::new("sqlite::memory:").await.unwrap());

        let dag = DagDefinition {
            id: "test_dag".to_string(),
            description: None,
            schedule: None,
            max_active_runs: None,
            catchup: None,
            tasks: vec![TaskDefinition {
                id: "simple_task".to_string(),
                operator: "bash".to_string(),
                depends_on: None,
                retries: None,
                retry_delay_secs: None,
                timeout_secs: None,
                xcom_inputs: None,
                config: serde_json::json!({
                    "command": "echo 'test'"
                }),
            }],
        };

        store.save_dag(&dag).await.unwrap();
        let dag_run = store.create_dag_run(&dag.id, TriggerType::Manual).await.unwrap();

        let executor = DagExecutor::new(std::sync::Arc::clone(&store));
        let result = executor.execute(&dag, &dag_run).await;

        // The whole run should report success for a trivially-passing task.
        assert!(result.is_ok());
    }
}