// ironflow/dag/definition.rs

use std::collections::{HashMap, HashSet};

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
4
5/// Represents a complete DAG (Directed Acyclic Graph)
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct DagDefinition {
8    pub id: String,
9    pub description: Option<String>,
10    pub schedule: Option<String>, // cron expression
11    pub max_active_runs: Option<u32>,
12    pub catchup: Option<bool>,
13    pub tasks: Vec<TaskDefinition>,
14}
15
16/// Represents a single task within a DAG
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub struct TaskDefinition {
19    pub id: String,
20    pub operator: String,
21    pub depends_on: Option<Vec<String>>,
22    pub retries: Option<u32>,
23    pub retry_delay_secs: Option<u64>,
24    pub timeout_secs: Option<u64>,
25    pub xcom_inputs: Option<Vec<String>>, // Task IDs whose outputs to inject
26    #[serde(flatten)]
27    pub config: serde_json::Value, // Operator-specific config
28}
29
30impl TaskDefinition {
31    pub fn dependencies(&self) -> Vec<String> {
32        self.depends_on.clone().unwrap_or_default()
33    }
34
35    pub fn xcom_dependencies(&self) -> Vec<String> {
36        self.xcom_inputs.clone().unwrap_or_default()
37    }
38}
39
40/// Represents a DAG run (one execution instance)
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct DagRun {
43    pub id: String,
44    pub dag_id: String,
45    pub status: DagRunStatus,
46    pub started_at: DateTime<Utc>,
47    pub ended_at: Option<DateTime<Utc>>,
48    pub triggered_by: TriggerType,
49    pub run_number: u32,
50}
51
52#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
53#[serde(rename_all = "lowercase")]
54pub enum DagRunStatus {
55    Queued,
56    Running,
57    Success,
58    Failed,
59}
60
61impl std::fmt::Display for DagRunStatus {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        match self {
64            DagRunStatus::Queued => write!(f, "queued"),
65            DagRunStatus::Running => write!(f, "running"),
66            DagRunStatus::Success => write!(f, "success"),
67            DagRunStatus::Failed => write!(f, "failed"),
68        }
69    }
70}
71
72#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
73#[serde(rename_all = "lowercase")]
74pub enum TriggerType {
75    Schedule,
76    Manual,
77}
78
79impl std::fmt::Display for TriggerType {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        match self {
82            TriggerType::Schedule => write!(f, "schedule"),
83            TriggerType::Manual => write!(f, "manual"),
84        }
85    }
86}
87
88/// Represents a task run (one task execution in a DAG run)
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct TaskRun {
91    pub id: String,
92    pub dag_run_id: String,
93    pub task_id: String,
94    pub status: TaskRunStatus,
95    pub started_at: Option<DateTime<Utc>>,
96    pub ended_at: Option<DateTime<Utc>>,
97    pub attempt_number: u32,
98    pub log: String,
99    pub xcom_output: Option<String>, // JSON
100}
101
102#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
103#[serde(rename_all = "lowercase")]
104pub enum TaskRunStatus {
105    Pending,
106    Running,
107    Success,
108    Failed,
109    Retried,
110    Skipped,
111}
112
113impl std::fmt::Display for TaskRunStatus {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match self {
116            TaskRunStatus::Pending => write!(f, "pending"),
117            TaskRunStatus::Running => write!(f, "running"),
118            TaskRunStatus::Success => write!(f, "success"),
119            TaskRunStatus::Failed => write!(f, "failed"),
120            TaskRunStatus::Retried => write!(f, "retried"),
121            TaskRunStatus::Skipped => write!(f, "skipped"),
122        }
123    }
124}
125
126impl DagDefinition {
127    /// Get all task IDs in topological order for execution
128    pub fn task_execution_order(&self) -> Result<Vec<String>, String> {
129        let mut result = Vec::new();
130        let mut visited = HashSet::new();
131        let mut visiting = HashSet::new();
132
133        for task in &self.tasks {
134            self.dfs(&task.id, &mut result, &mut visited, &mut visiting)?;
135        }
136
137        Ok(result)
138    }
139
140    fn dfs(
141        &self,
142        task_id: &str,
143        result: &mut Vec<String>,
144        visited: &mut HashSet<String>,
145        visiting: &mut HashSet<String>,
146    ) -> Result<(), String> {
147        if visited.contains(task_id) {
148            return Ok(());
149        }
150
151        if visiting.contains(task_id) {
152            return Err(format!("Cycle detected involving task: {}", task_id));
153        }
154
155        visiting.insert(task_id.to_string());
156
157        if let Some(task) = self.tasks.iter().find(|t| t.id == task_id) {
158            for dep in &task.dependencies() {
159                self.dfs(dep, result, visited, visiting)?;
160            }
161        }
162
163        visiting.remove(task_id);
164        visited.insert(task_id.to_string());
165        result.push(task_id.to_string());
166
167        Ok(())
168    }
169
170    /// Get tasks that have no dependencies (can run immediately)
171    pub fn root_tasks(&self) -> Vec<String> {
172        self.tasks
173            .iter()
174            .filter(|t| t.dependencies().is_empty())
175            .map(|t| t.id.clone())
176            .collect()
177    }
178
179    /// Get tasks that depend on a given task
180    pub fn dependents(&self, task_id: &str) -> Vec<String> {
181        self.tasks
182            .iter()
183            .filter(|t| t.dependencies().contains(&task_id.to_string()))
184            .map(|t| t.id.clone())
185            .collect()
186    }
187
188    /// Get a task by ID
189    pub fn get_task(&self, task_id: &str) -> Option<&TaskDefinition> {
190        self.tasks.iter().find(|t| t.id == task_id)
191    }
192
193    /// Check if all dependencies of a task are satisfied
194    pub fn dependencies_satisfied(
195        &self,
196        task_id: &str,
197        completed_tasks: &HashSet<String>,
198    ) -> bool {
199        if let Some(task) = self.get_task(task_id) {
200            task.dependencies()
201                .iter()
202                .all(|dep| completed_tasks.contains(dep))
203        } else {
204            false
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `bash` task with the given ID and upstream task IDs.
    ///
    /// An empty `upstream` slice produces `depends_on: None`; every other
    /// optional field is left unset and the operator config is an empty object.
    fn bash_task(id: &str, upstream: &[&str]) -> TaskDefinition {
        let depends_on = if upstream.is_empty() {
            None
        } else {
            Some(upstream.iter().map(|dep| dep.to_string()).collect())
        };

        TaskDefinition {
            id: id.to_string(),
            operator: "bash".to_string(),
            depends_on,
            retries: None,
            retry_delay_secs: None,
            timeout_secs: None,
            xcom_inputs: None,
            config: serde_json::json!({}),
        }
    }

    /// Builds a DAG with the given ID and tasks and no optional metadata.
    fn dag_with(id: &str, tasks: Vec<TaskDefinition>) -> DagDefinition {
        DagDefinition {
            id: id.to_string(),
            description: None,
            schedule: None,
            max_active_runs: None,
            catchup: None,
            tasks,
        }
    }

    #[test]
    fn test_task_execution_order() {
        // Linear chain: a <- b <- c.
        let dag = dag_with(
            "test_dag",
            vec![
                bash_task("a", &[]),
                bash_task("b", &["a"]),
                bash_task("c", &["b"]),
            ],
        );

        let order = dag.task_execution_order().unwrap();
        assert_eq!(order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_cycle_detection() {
        // a -> c -> b -> a forms a cycle.
        let dag = dag_with(
            "cyclic_dag",
            vec![
                bash_task("a", &["c"]),
                bash_task("b", &["a"]),
                bash_task("c", &["b"]),
            ],
        );

        let result = dag.task_execution_order();
        assert!(result.is_err());
    }

    #[test]
    fn test_root_tasks() {
        let dag = dag_with(
            "test_dag",
            vec![bash_task("a", &[]), bash_task("b", &["a"])],
        );

        let roots = dag.root_tasks();
        assert_eq!(roots, vec!["a"]);
    }
}