//! ironflow/store/mod.rs — SQLite-backed persistence layer for DAGs,
//! DAG runs, task runs, and XCom outputs.
use anyhow::Result;
use chrono::{DateTime, Utc};
use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
use sqlx::Row;
use uuid::Uuid;

use crate::dag::{DagDefinition, DagRun, DagRunStatus, TaskRun, TaskRunStatus, TriggerType};
7
/// SQLite-backed persistence layer for DAG definitions, DAG runs,
/// task runs, and their XCom outputs.
pub struct Store {
    // Shared connection pool; sized in `Store::new` (max 5 connections).
    pool: SqlitePool,
}
11
12impl Store {
13    /// Create a new store and initialize the database
14    pub async fn new(database_url: &str) -> Result<Self> {
15        // Add mode=rwc if not already present to allow creating new databases
16        let db_url = if database_url.contains("mode=") {
17            database_url.to_string()
18        } else {
19            format!("{}?mode=rwc", database_url)
20        };
21        
22        let pool = SqlitePoolOptions::new()
23            .max_connections(5)
24            .connect(&db_url)
25            .await?;
26
27        // Enable Write-Ahead Logging for better concurrency
28        sqlx::query("PRAGMA journal_mode = WAL;")
29            .execute(&pool)
30            .await?;
31        sqlx::query("PRAGMA synchronous = NORMAL;")
32            .execute(&pool)
33            .await?;
34
35        sqlx::query(
36            "CREATE TABLE IF NOT EXISTS dags (
37                id TEXT PRIMARY KEY,
38                definition TEXT NOT NULL,
39                is_paused BOOLEAN NOT NULL DEFAULT 0,
40                created_at DATETIME NOT NULL,
41                updated_at DATETIME NOT NULL
42            )",
43        )
44        .execute(&pool)
45        .await?;
46
47        sqlx::query(
48            "CREATE TABLE IF NOT EXISTS dag_runs (
49                id TEXT PRIMARY KEY,
50                dag_id TEXT NOT NULL,
51                status TEXT NOT NULL,
52                started_at DATETIME NOT NULL,
53                ended_at DATETIME,
54                triggered_by TEXT NOT NULL,
55                run_number INTEGER NOT NULL,
56                FOREIGN KEY (dag_id) REFERENCES dags(id)
57            )",
58        )
59        .execute(&pool)
60        .await?;
61
62        sqlx::query(
63            "CREATE TABLE IF NOT EXISTS task_runs (
64                id TEXT PRIMARY KEY,
65                dag_run_id TEXT NOT NULL,
66                task_id TEXT NOT NULL,
67                status TEXT NOT NULL,
68                started_at DATETIME,
69                ended_at DATETIME,
70                attempt_number INTEGER NOT NULL,
71                log TEXT NOT NULL DEFAULT '',
72                xcom_output TEXT,
73                FOREIGN KEY (dag_run_id) REFERENCES dag_runs(id)
74            )",
75        )
76        .execute(&pool)
77        .await?;
78
79        Ok(Store { pool })
80    }
81
    // ===== DAG Operations =====

84    pub async fn save_dag(&self, dag: &DagDefinition) -> Result<()> {
85        let definition = serde_json::to_string(dag)?;
86        let now = Utc::now();
87        let now_str = now.to_rfc3339();
88
89        sqlx::query(
90            "INSERT OR REPLACE INTO dags (id, definition, is_paused, created_at, updated_at)
91             VALUES (?, ?, 0, ?, ?)",
92        )
93        .bind(&dag.id)
94        .bind(&definition)
95        .bind(&now_str)
96        .bind(&now_str)
97        .execute(&self.pool)
98        .await?;
99
100        Ok(())
101    }
102
103    pub async fn get_dag(&self, dag_id: &str) -> Result<Option<DagDefinition>> {
104        let row = sqlx::query("SELECT definition FROM dags WHERE id = ?")
105            .bind(dag_id)
106            .fetch_optional(&self.pool)
107            .await?;
108
109        Ok(row.and_then(|r| {
110            let definition_str: String = r.get("definition");
111            serde_json::from_str(&definition_str).ok()
112        }))
113    }
114
115    pub async fn get_all_dags(&self) -> Result<Vec<DagDefinition>> {
116        let rows = sqlx::query("SELECT definition FROM dags")
117            .fetch_all(&self.pool)
118            .await?;
119
120        let dags = rows
121            .into_iter()
122            .filter_map(|row| {
123                let definition_str: String = row.get("definition");
124                serde_json::from_str(&definition_str).ok()
125            })
126            .collect();
127
128        Ok(dags)
129    }
130
131    pub async fn pause_dag(&self, dag_id: &str) -> Result<()> {
132        sqlx::query("UPDATE dags SET is_paused = 1 WHERE id = ?")
133            .bind(dag_id)
134            .execute(&self.pool)
135            .await?;
136
137        Ok(())
138    }
139
140    pub async fn unpause_dag(&self, dag_id: &str) -> Result<()> {
141        sqlx::query("UPDATE dags SET is_paused = 0 WHERE id = ?")
142            .bind(dag_id)
143            .execute(&self.pool)
144            .await?;
145
146        Ok(())
147    }
148
149    pub async fn is_dag_paused(&self, dag_id: &str) -> Result<bool> {
150        let row = sqlx::query("SELECT is_paused FROM dags WHERE id = ?")
151            .bind(dag_id)
152            .fetch_optional(&self.pool)
153            .await?;
154
155        Ok(row.map(|r| r.get::<bool, _>("is_paused")).unwrap_or(false))
156    }
157
158    /// Recover orphaned DAG and task runs from a previous crash
159    /// Marks any Running tasks/runs as Failed with a system message
160    pub async fn recover_orphaned_runs(&self) -> Result<()> {
161        let now = Utc::now().to_rfc3339();
162        
163        // Find and mark orphaned task runs as Failed
164        let orphaned_tasks = sqlx::query(
165            "SELECT id FROM task_runs WHERE status = ?",
166        )
167        .bind(TaskRunStatus::Running.to_string())
168        .fetch_all(&self.pool)
169        .await?;
170
171        for task_row in orphaned_tasks {
172            let task_run_id: String = task_row.get("id");
173            let recovery_msg = "Orphaned by executor crash — marked failed on restart";
174            
175            sqlx::query(
176                "UPDATE task_runs SET status = ?, ended_at = ?, log = log || '\n' || ? WHERE id = ?",
177            )
178            .bind(TaskRunStatus::Failed.to_string())
179            .bind(&now)
180            .bind(recovery_msg)
181            .bind(&task_run_id)
182            .execute(&self.pool)
183            .await?;
184            
185            tracing::info!("Recovered orphaned task run: {}", task_run_id);
186        }
187
188        // Find and mark orphaned DAG runs as Failed
189        let orphaned_runs = sqlx::query(
190            "SELECT id FROM dag_runs WHERE status = ?",
191        )
192        .bind(DagRunStatus::Running.to_string())
193        .fetch_all(&self.pool)
194        .await?;
195
196        for run_row in orphaned_runs {
197            let dag_run_id: String = run_row.get("id");
198            
199            sqlx::query(
200                "UPDATE dag_runs SET status = ?, ended_at = ? WHERE id = ?",
201            )
202            .bind(DagRunStatus::Failed.to_string())
203            .bind(&now)
204            .bind(&dag_run_id)
205            .execute(&self.pool)
206            .await?;
207            
208            tracing::info!("Recovered orphaned DAG run: {}", dag_run_id);
209        }
210
211        Ok(())
212    }
213
    // ===== DAG Run Operations =====

216    pub async fn create_dag_run(
217        &self,
218        dag_id: &str,
219        triggered_by: TriggerType,
220    ) -> Result<DagRun> {
221        let run_id = Uuid::new_v4().to_string();
222        let now = Utc::now();
223        let now_str = now.to_rfc3339();
224
225        // Get the next run number for this DAG
226        let run_number: i64 = sqlx::query_scalar(
227            "SELECT COALESCE(MAX(run_number), 0) + 1 FROM dag_runs WHERE dag_id = ?",
228        )
229        .bind(dag_id)
230        .fetch_one(&self.pool)
231        .await?;
232
233        sqlx::query(
234            "INSERT INTO dag_runs (id, dag_id, status, started_at, triggered_by, run_number)
235             VALUES (?, ?, ?, ?, ?, ?)",
236        )
237        .bind(&run_id)
238        .bind(dag_id)
239        .bind(DagRunStatus::Queued.to_string())
240        .bind(&now_str)
241        .bind(triggered_by.to_string())
242        .bind(run_number)
243        .execute(&self.pool)
244        .await?;
245
246        Ok(DagRun {
247            id: run_id,
248            dag_id: dag_id.to_string(),
249            status: DagRunStatus::Queued,
250            started_at: now,
251            ended_at: None,
252            triggered_by,
253            run_number: run_number as u32,
254        })
255    }
256
257    pub async fn get_dag_run(&self, run_id: &str) -> Result<Option<DagRun>> {
258        let row = sqlx::query(
259            "SELECT id, dag_id, status, started_at, ended_at, triggered_by, run_number
260             FROM dag_runs WHERE id = ?",
261        )
262        .bind(run_id)
263        .fetch_optional(&self.pool)
264        .await?;
265
266        Ok(row.and_then(|r| {
267            let status_str: String = r.get("status");
268            let triggered_str: String = r.get("triggered_by");
269            let started_at_str: String = r.get("started_at");
270            let started_at = DateTime::parse_from_rfc3339(&started_at_str)
271                .ok()?
272                .with_timezone(&Utc);
273
274            Some(DagRun {
275                id: r.get("id"),
276                dag_id: r.get("dag_id"),
277                status: match status_str.as_str() {
278                    "queued" => DagRunStatus::Queued,
279                    "running" => DagRunStatus::Running,
280                    "success" => DagRunStatus::Success,
281                    "failed" => DagRunStatus::Failed,
282                    _ => DagRunStatus::Queued,
283                },
284                started_at,
285                ended_at: {
286                    let ended_at_str: Option<String> = r.get("ended_at");
287                    ended_at_str.and_then(|s| {
288                        DateTime::parse_from_rfc3339(&s)
289                            .ok()
290                            .map(|dt| dt.with_timezone(&Utc))
291                    })
292                },
293                triggered_by: match triggered_str.as_str() {
294                    "schedule" => TriggerType::Schedule,
295                    "manual" => TriggerType::Manual,
296                    _ => TriggerType::Manual,
297                },
298                run_number: r.get::<i64, _>("run_number") as u32,
299            })
300        }))
301    }
302
303    pub async fn get_dag_runs(&self, dag_id: &str, limit: i64) -> Result<Vec<DagRun>> {
304        let rows = sqlx::query(
305            "SELECT id, dag_id, status, started_at, ended_at, triggered_by, run_number
306             FROM dag_runs WHERE dag_id = ? ORDER BY started_at DESC LIMIT ?",
307        )
308        .bind(dag_id)
309        .bind(limit)
310        .fetch_all(&self.pool)
311        .await?;
312
313        let runs = rows
314            .into_iter()
315            .filter_map(|r| {
316                let status_str: String = r.get("status");
317                let triggered_str: String = r.get("triggered_by");
318                let started_at_str: String = r.get("started_at");
319                let started_at = DateTime::parse_from_rfc3339(&started_at_str)
320                    .ok()?
321                    .with_timezone(&Utc);
322
323                Some(DagRun {
324                    id: r.get("id"),
325                    dag_id: r.get("dag_id"),
326                    status: match status_str.as_str() {
327                        "queued" => DagRunStatus::Queued,
328                        "running" => DagRunStatus::Running,
329                        "success" => DagRunStatus::Success,
330                        "failed" => DagRunStatus::Failed,
331                        _ => DagRunStatus::Queued,
332                    },
333                    started_at,
334                    ended_at: {
335                        let ended_at_str: Option<String> = r.get("ended_at");
336                        ended_at_str.and_then(|s| {
337                            DateTime::parse_from_rfc3339(&s)
338                                .ok()
339                                .map(|dt| dt.with_timezone(&Utc))
340                        })
341                    },
342                    triggered_by: match triggered_str.as_str() {
343                        "schedule" => TriggerType::Schedule,
344                        "manual" => TriggerType::Manual,
345                        _ => TriggerType::Manual,
346                    },
347                    run_number: r.get::<i64, _>("run_number") as u32,
348                })
349            })
350            .collect();
351
352        Ok(runs)
353    }
354
355    pub async fn update_dag_run_status(&self, run_id: &str, status: DagRunStatus) -> Result<()> {
356        let ended_at = if matches!(status, DagRunStatus::Success | DagRunStatus::Failed) {
357            Some(Utc::now().to_rfc3339())
358        } else {
359            None
360        };
361
362        sqlx::query(
363            "UPDATE dag_runs SET status = ?, ended_at = ? WHERE id = ?",
364        )
365        .bind(status.to_string())
366        .bind(ended_at)
367        .bind(run_id)
368        .execute(&self.pool)
369        .await?;
370
371        Ok(())
372    }
373
    // ===== Task Run Operations =====

376    pub async fn create_task_run(
377        &self,
378        dag_run_id: &str,
379        task_id: &str,
380    ) -> Result<TaskRun> {
381        let task_run_id = Uuid::new_v4().to_string();
382
383        sqlx::query(
384            "INSERT INTO task_runs (id, dag_run_id, task_id, status, attempt_number, log)
385             VALUES (?, ?, ?, ?, 1, '')",
386        )
387        .bind(&task_run_id)
388        .bind(dag_run_id)
389        .bind(task_id)
390        .bind(TaskRunStatus::Pending.to_string())
391        .execute(&self.pool)
392        .await?;
393
394        Ok(TaskRun {
395            id: task_run_id,
396            dag_run_id: dag_run_id.to_string(),
397            task_id: task_id.to_string(),
398            status: TaskRunStatus::Pending,
399            started_at: None,
400            ended_at: None,
401            attempt_number: 1,
402            log: String::new(),
403            xcom_output: None,
404        })
405    }
406
407    pub async fn get_task_run(&self, task_run_id: &str) -> Result<Option<TaskRun>> {
408        let row = sqlx::query(
409            "SELECT id, dag_run_id, task_id, status, started_at, ended_at, attempt_number, log, xcom_output
410             FROM task_runs WHERE id = ?",
411        )
412        .bind(task_run_id)
413        .fetch_optional(&self.pool)
414        .await?;
415
416        Ok(row.map(|r| {
417            let status_str: String = r.get("status");
418            let started_at_str: Option<String> = r.get("started_at");
419            let ended_at_str: Option<String> = r.get("ended_at");
420
421            TaskRun {
422                id: r.get("id"),
423                dag_run_id: r.get("dag_run_id"),
424                task_id: r.get("task_id"),
425                status: match status_str.as_str() {
426                    "pending" => TaskRunStatus::Pending,
427                    "running" => TaskRunStatus::Running,
428                    "success" => TaskRunStatus::Success,
429                    "failed" => TaskRunStatus::Failed,
430                    "retried" => TaskRunStatus::Retried,
431                    "skipped" => TaskRunStatus::Skipped,
432                    _ => TaskRunStatus::Pending,
433                },
434                started_at: started_at_str.and_then(|s| {
435                    DateTime::parse_from_rfc3339(&s)
436                        .ok()
437                        .map(|dt| dt.with_timezone(&Utc))
438                }),
439                ended_at: ended_at_str.and_then(|s| {
440                    DateTime::parse_from_rfc3339(&s)
441                        .ok()
442                        .map(|dt| dt.with_timezone(&Utc))
443                }),
444                attempt_number: r.get::<i64, _>("attempt_number") as u32,
445                log: r.get("log"),
446                xcom_output: r.get("xcom_output"),
447            }
448        }))
449    }
450
451    pub async fn get_task_runs_for_dag_run(&self, dag_run_id: &str) -> Result<Vec<TaskRun>> {
452        let rows = sqlx::query(
453            "SELECT id, dag_run_id, task_id, status, started_at, ended_at, attempt_number, log, xcom_output
454             FROM task_runs WHERE dag_run_id = ?",
455        )
456        .bind(dag_run_id)
457        .fetch_all(&self.pool)
458        .await?;
459
460        let task_runs = rows
461            .into_iter()
462            .map(|r| {
463                let status_str: String = r.get("status");
464                let started_at_str: Option<String> = r.get("started_at");
465                let ended_at_str: Option<String> = r.get("ended_at");
466
467                TaskRun {
468                    id: r.get("id"),
469                    dag_run_id: r.get("dag_run_id"),
470                    task_id: r.get("task_id"),
471                    status: match status_str.as_str() {
472                        "pending" => TaskRunStatus::Pending,
473                        "running" => TaskRunStatus::Running,
474                        "success" => TaskRunStatus::Success,
475                        "failed" => TaskRunStatus::Failed,
476                        "retried" => TaskRunStatus::Retried,
477                        "skipped" => TaskRunStatus::Skipped,
478                        _ => TaskRunStatus::Pending,
479                    },
480                    started_at: started_at_str.and_then(|s| {
481                        DateTime::parse_from_rfc3339(&s)
482                            .ok()
483                            .map(|dt| dt.with_timezone(&Utc))
484                    }),
485                    ended_at: ended_at_str.and_then(|s| {
486                        DateTime::parse_from_rfc3339(&s)
487                            .ok()
488                            .map(|dt| dt.with_timezone(&Utc))
489                    }),
490                    attempt_number: r.get::<i64, _>("attempt_number") as u32,
491                    log: r.get("log"),
492                    xcom_output: r.get("xcom_output"),
493                }
494            })
495            .collect();
496
497        Ok(task_runs)
498    }
499
500    pub async fn update_task_run(
501        &self,
502        task_run_id: &str,
503        status: TaskRunStatus,
504        log_append: Option<&str>,
505        xcom_output: Option<String>,
506    ) -> Result<()> {
507        let started_at = if matches!(status, TaskRunStatus::Running) {
508            Some(Utc::now().to_rfc3339())
509        } else {
510            None
511        };
512
513        let ended_at = if matches!(
514            status,
515            TaskRunStatus::Success | TaskRunStatus::Failed | TaskRunStatus::Skipped
516        ) {
517            Some(Utc::now().to_rfc3339())
518        } else {
519            None
520        };
521
522        // Get current log and append
523        let mut new_log = String::new();
524        if let Some(append) = log_append {
525            if let Ok(Some(task_run)) = self.get_task_run(task_run_id).await {
526                new_log = format!("{}\n{}", task_run.log, append);
527            } else {
528                new_log = append.to_string();
529            }
530        }
531
532        sqlx::query(
533            "UPDATE task_runs SET status = ?, started_at = COALESCE(started_at, ?), 
534             ended_at = ?, log = CASE WHEN ? THEN ? ELSE log END, xcom_output = COALESCE(?, xcom_output)
535             WHERE id = ?",
536        )
537        .bind(status.to_string())
538        .bind(&started_at)
539        .bind(&ended_at)
540        .bind(!new_log.is_empty())
541        .bind(&new_log)
542        .bind(&xcom_output)
543        .bind(task_run_id)
544        .execute(&self.pool)
545        .await?;
546
547        Ok(())
548    }
549
550    pub async fn increment_task_run_attempt(&self, task_run_id: &str) -> Result<u32> {
551        let new_attempt: i64 = sqlx::query_scalar(
552            "UPDATE task_runs SET attempt_number = attempt_number + 1 WHERE id = ? RETURNING attempt_number",
553        )
554        .bind(task_run_id)
555        .fetch_one(&self.pool)
556        .await?;
557
558        Ok(new_attempt as u32)
559    }
560
561    pub async fn append_task_log(&self, task_run_id: &str, log_line: &str) -> Result<()> {
562        if let Ok(Some(task_run)) = self.get_task_run(task_run_id).await {
563            let new_log = format!("{}\n{}", task_run.log, log_line);
564            sqlx::query("UPDATE task_runs SET log = ? WHERE id = ?")
565                .bind(&new_log)
566                .bind(task_run_id)
567                .execute(&self.pool)
568                .await?;
569        }
570
571        Ok(())
572    }
573
    // ===== XCom Operations =====

576    /// Get XCom output for a task in a specific run
577    pub async fn get_xcom(&self, run_id: &str, task_id: &str) -> Result<Option<String>> {
578        let row = sqlx::query(
579            "SELECT xcom_output FROM task_runs WHERE dag_run_id = ? AND task_id = ? AND status = ?",
580        )
581        .bind(run_id)
582        .bind(task_id)
583        .bind(TaskRunStatus::Success.to_string())
584        .fetch_optional(&self.pool)
585        .await?;
586
587        Ok(row.map(|r| r.get::<Option<String>, _>("xcom_output")).flatten())
588    }
589
590    /// Get all XCom outputs for a DAG run, organized by task_id
591    pub async fn get_all_xcoms_for_run(&self, run_id: &str) -> Result<serde_json::Map<String, serde_json::Value>> {
592        let rows = sqlx::query(
593            "SELECT task_id, xcom_output FROM task_runs WHERE dag_run_id = ? AND status = ? AND xcom_output IS NOT NULL",
594        )
595        .bind(run_id)
596        .bind(TaskRunStatus::Success.to_string())
597        .fetch_all(&self.pool)
598        .await?;
599
600        let mut xcoms = serde_json::Map::new();
601        for row in rows {
602            let task_id: String = row.get("task_id");
603            let xcom_output: Option<String> = row.get("xcom_output");
604            if let Some(output) = xcom_output {
605                xcoms.insert(task_id, serde_json::json!(output));
606            }
607        }
608
609        Ok(xcoms)
610    }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal DAG definition for tests.
    /// (Every test previously duplicated this struct literal inline.)
    fn sample_dag(id: &str, description: Option<&str>) -> DagDefinition {
        DagDefinition {
            id: id.to_string(),
            description: description.map(str::to_string),
            schedule: None,
            max_active_runs: None,
            catchup: None,
            tasks: vec![],
        }
    }

    #[tokio::test]
    async fn test_store_creation() {
        let store = Store::new("sqlite::memory:").await;
        assert!(store.is_ok());
    }

    #[tokio::test]
    async fn test_save_and_get_dag() {
        let store = Store::new("sqlite::memory:").await.unwrap();
        store
            .save_dag(&sample_dag("test_dag", Some("Test")))
            .await
            .unwrap();

        let retrieved = store.get_dag("test_dag").await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, "test_dag");
    }

    #[tokio::test]
    async fn test_dag_run_creation_and_retrieval() {
        let store = Store::new("sqlite::memory:").await.unwrap();
        store.save_dag(&sample_dag("test_dag", None)).await.unwrap();

        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        assert_eq!(dag_run.dag_id, "test_dag");
        assert_eq!(dag_run.status, DagRunStatus::Queued);

        let retrieved = store.get_dag_run(&dag_run.id).await.unwrap();
        assert!(retrieved.is_some());
    }

    #[tokio::test]
    async fn test_task_run_creation_and_update() {
        let store = Store::new("sqlite::memory:").await.unwrap();
        store.save_dag(&sample_dag("test_dag", None)).await.unwrap();

        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        let task_run = store
            .create_task_run(&dag_run.id, "task_1")
            .await
            .unwrap();

        assert_eq!(task_run.status, TaskRunStatus::Pending);

        store
            .update_task_run(&task_run.id, TaskRunStatus::Running, None, None)
            .await
            .unwrap();

        let retrieved = store.get_task_run(&task_run.id).await.unwrap().unwrap();
        assert_eq!(retrieved.status, TaskRunStatus::Running);
    }

    #[tokio::test]
    async fn test_crash_recovery_orphaned_runs() {
        let store = Store::new("sqlite::memory:").await.unwrap();
        store.save_dag(&sample_dag("test_dag", None)).await.unwrap();

        // Create a DAG run and mark it as Running (simulating a crash).
        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        store
            .update_dag_run_status(&dag_run.id, DagRunStatus::Running)
            .await
            .unwrap();

        let task_run = store
            .create_task_run(&dag_run.id, "task_1")
            .await
            .unwrap();

        store
            .update_task_run(&task_run.id, TaskRunStatus::Running, None, None)
            .await
            .unwrap();

        // Verify they are Running.
        let dag_run_before = store.get_dag_run(&dag_run.id).await.unwrap().unwrap();
        assert_eq!(dag_run_before.status, DagRunStatus::Running);

        let task_run_before = store.get_task_run(&task_run.id).await.unwrap().unwrap();
        assert_eq!(task_run_before.status, TaskRunStatus::Running);

        // Now recover from crash.
        store.recover_orphaned_runs().await.unwrap();

        // Verify they are now Failed with end timestamps and a log note.
        let dag_run_after = store.get_dag_run(&dag_run.id).await.unwrap().unwrap();
        assert_eq!(dag_run_after.status, DagRunStatus::Failed);
        assert!(dag_run_after.ended_at.is_some());

        let task_run_after = store.get_task_run(&task_run.id).await.unwrap().unwrap();
        assert_eq!(task_run_after.status, TaskRunStatus::Failed);
        assert!(task_run_after.ended_at.is_some());
        assert!(task_run_after.log.contains("Orphaned by executor crash"));
    }

    #[tokio::test]
    async fn test_xcom_get_and_retrieve() {
        let store = Store::new("sqlite::memory:").await.unwrap();
        store.save_dag(&sample_dag("test_dag", None)).await.unwrap();

        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        let task_run = store
            .create_task_run(&dag_run.id, "task_1")
            .await
            .unwrap();

        // Simulate task execution with XCom output.
        let xcom_output = r#"{"result": "success", "count": 42}"#.to_string();
        store
            .update_task_run(&task_run.id, TaskRunStatus::Success, Some("Task completed"), Some(xcom_output.clone()))
            .await
            .unwrap();

        // Retrieve XCom output.
        let retrieved_xcom = store.get_xcom(&dag_run.id, "task_1").await.unwrap();
        assert!(retrieved_xcom.is_some());
        assert_eq!(retrieved_xcom.unwrap(), xcom_output);
    }

    #[tokio::test]
    async fn test_get_all_xcoms_for_run() {
        let store = Store::new("sqlite::memory:").await.unwrap();
        store.save_dag(&sample_dag("test_dag", None)).await.unwrap();

        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        // Create multiple tasks with XCom outputs.
        let task1 = store.create_task_run(&dag_run.id, "task_1").await.unwrap();
        let task2 = store.create_task_run(&dag_run.id, "task_2").await.unwrap();

        store
            .update_task_run(&task1.id, TaskRunStatus::Success, None, Some(r#"{"value": 1}"#.to_string()))
            .await
            .unwrap();

        store
            .update_task_run(&task2.id, TaskRunStatus::Success, None, Some(r#"{"value": 2}"#.to_string()))
            .await
            .unwrap();

        // Retrieve all XComs.
        let all_xcoms = store.get_all_xcoms_for_run(&dag_run.id).await.unwrap();
        assert_eq!(all_xcoms.len(), 2);
        assert!(all_xcoms.contains_key("task_1"));
        assert!(all_xcoms.contains_key("task_2"));
    }
}