use std::collections::{HashMap, HashSet};

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct DagDefinition {
8 pub id: String,
9 pub description: Option<String>,
10 pub schedule: Option<String>, pub max_active_runs: Option<u32>,
12 pub catchup: Option<bool>,
13 pub tasks: Vec<TaskDefinition>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub struct TaskDefinition {
19 pub id: String,
20 pub operator: String,
21 pub depends_on: Option<Vec<String>>,
22 pub retries: Option<u32>,
23 pub retry_delay_secs: Option<u64>,
24 pub timeout_secs: Option<u64>,
25 pub xcom_inputs: Option<Vec<String>>, #[serde(flatten)]
27 pub config: serde_json::Value, }
29
30impl TaskDefinition {
31 pub fn dependencies(&self) -> Vec<String> {
32 self.depends_on.clone().unwrap_or_default()
33 }
34
35 pub fn xcom_dependencies(&self) -> Vec<String> {
36 self.xcom_inputs.clone().unwrap_or_default()
37 }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct DagRun {
43 pub id: String,
44 pub dag_id: String,
45 pub status: DagRunStatus,
46 pub started_at: DateTime<Utc>,
47 pub ended_at: Option<DateTime<Utc>>,
48 pub triggered_by: TriggerType,
49 pub run_number: u32,
50}
51
52#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
53#[serde(rename_all = "lowercase")]
54pub enum DagRunStatus {
55 Queued,
56 Running,
57 Success,
58 Failed,
59}
60
61impl std::fmt::Display for DagRunStatus {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 DagRunStatus::Queued => write!(f, "queued"),
65 DagRunStatus::Running => write!(f, "running"),
66 DagRunStatus::Success => write!(f, "success"),
67 DagRunStatus::Failed => write!(f, "failed"),
68 }
69 }
70}
71
72#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
73#[serde(rename_all = "lowercase")]
74pub enum TriggerType {
75 Schedule,
76 Manual,
77}
78
79impl std::fmt::Display for TriggerType {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 match self {
82 TriggerType::Schedule => write!(f, "schedule"),
83 TriggerType::Manual => write!(f, "manual"),
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct TaskRun {
91 pub id: String,
92 pub dag_run_id: String,
93 pub task_id: String,
94 pub status: TaskRunStatus,
95 pub started_at: Option<DateTime<Utc>>,
96 pub ended_at: Option<DateTime<Utc>>,
97 pub attempt_number: u32,
98 pub log: String,
99 pub xcom_output: Option<String>, }
101
102#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
103#[serde(rename_all = "lowercase")]
104pub enum TaskRunStatus {
105 Pending,
106 Running,
107 Success,
108 Failed,
109 Retried,
110 Skipped,
111}
112
113impl std::fmt::Display for TaskRunStatus {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 match self {
116 TaskRunStatus::Pending => write!(f, "pending"),
117 TaskRunStatus::Running => write!(f, "running"),
118 TaskRunStatus::Success => write!(f, "success"),
119 TaskRunStatus::Failed => write!(f, "failed"),
120 TaskRunStatus::Retried => write!(f, "retried"),
121 TaskRunStatus::Skipped => write!(f, "skipped"),
122 }
123 }
124}
125
126impl DagDefinition {
127 pub fn task_execution_order(&self) -> Result<Vec<String>, String> {
129 let mut result = Vec::new();
130 let mut visited = HashSet::new();
131 let mut visiting = HashSet::new();
132
133 for task in &self.tasks {
134 self.dfs(&task.id, &mut result, &mut visited, &mut visiting)?;
135 }
136
137 Ok(result)
138 }
139
140 fn dfs(
141 &self,
142 task_id: &str,
143 result: &mut Vec<String>,
144 visited: &mut HashSet<String>,
145 visiting: &mut HashSet<String>,
146 ) -> Result<(), String> {
147 if visited.contains(task_id) {
148 return Ok(());
149 }
150
151 if visiting.contains(task_id) {
152 return Err(format!("Cycle detected involving task: {}", task_id));
153 }
154
155 visiting.insert(task_id.to_string());
156
157 if let Some(task) = self.tasks.iter().find(|t| t.id == task_id) {
158 for dep in &task.dependencies() {
159 self.dfs(dep, result, visited, visiting)?;
160 }
161 }
162
163 visiting.remove(task_id);
164 visited.insert(task_id.to_string());
165 result.push(task_id.to_string());
166
167 Ok(())
168 }
169
170 pub fn root_tasks(&self) -> Vec<String> {
172 self.tasks
173 .iter()
174 .filter(|t| t.dependencies().is_empty())
175 .map(|t| t.id.clone())
176 .collect()
177 }
178
179 pub fn dependents(&self, task_id: &str) -> Vec<String> {
181 self.tasks
182 .iter()
183 .filter(|t| t.dependencies().contains(&task_id.to_string()))
184 .map(|t| t.id.clone())
185 .collect()
186 }
187
188 pub fn get_task(&self, task_id: &str) -> Option<&TaskDefinition> {
190 self.tasks.iter().find(|t| t.id == task_id)
191 }
192
193 pub fn dependencies_satisfied(
195 &self,
196 task_id: &str,
197 completed_tasks: &HashSet<String>,
198 ) -> bool {
199 if let Some(task) = self.get_task(task_id) {
200 task.dependencies()
201 .iter()
202 .all(|dep| completed_tasks.contains(dep))
203 } else {
204 false
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `bash` task with the given id. An empty dependency list is
    /// encoded as `depends_on: None`, as in the hand-written fixtures.
    fn task(id: &str, deps: &[&str]) -> TaskDefinition {
        let depends_on = if deps.is_empty() {
            None
        } else {
            Some(deps.iter().map(|d| d.to_string()).collect())
        };
        TaskDefinition {
            id: id.to_string(),
            operator: "bash".to_string(),
            depends_on,
            retries: None,
            retry_delay_secs: None,
            timeout_secs: None,
            xcom_inputs: None,
            config: serde_json::json!({}),
        }
    }

    /// Builds a DAG with the given id and tasks and no optional settings.
    fn dag(id: &str, tasks: Vec<TaskDefinition>) -> DagDefinition {
        DagDefinition {
            id: id.to_string(),
            description: None,
            schedule: None,
            max_active_runs: None,
            catchup: None,
            tasks,
        }
    }

    #[test]
    fn test_task_execution_order() {
        let chain = dag(
            "test_dag",
            vec![task("a", &[]), task("b", &["a"]), task("c", &["b"])],
        );

        let order = chain.task_execution_order().unwrap();
        assert_eq!(order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_cycle_detection() {
        let cyclic = dag(
            "cyclic_dag",
            vec![task("a", &["c"]), task("b", &["a"]), task("c", &["b"])],
        );

        assert!(cyclic.task_execution_order().is_err());
    }

    #[test]
    fn test_root_tasks() {
        let graph = dag("test_dag", vec![task("a", &[]), task("b", &["a"])]);

        assert_eq!(graph.root_tasks(), vec!["a"]);
    }
}