1use anyhow::Result;
2use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
3use sqlx::Row;
4use crate::dag::{DagDefinition, DagRun, DagRunStatus, TaskRun, TaskRunStatus, TriggerType};
5use chrono::{DateTime, Utc};
6use uuid::Uuid;
7
8pub struct Store {
9 pool: SqlitePool,
10}
11
12impl Store {
13 pub async fn new(database_url: &str) -> Result<Self> {
15 let db_url = if database_url.contains("mode=") {
17 database_url.to_string()
18 } else {
19 format!("{}?mode=rwc", database_url)
20 };
21
22 let pool = SqlitePoolOptions::new()
23 .max_connections(5)
24 .connect(&db_url)
25 .await?;
26
27 sqlx::query("PRAGMA journal_mode = WAL;")
29 .execute(&pool)
30 .await?;
31 sqlx::query("PRAGMA synchronous = NORMAL;")
32 .execute(&pool)
33 .await?;
34
35 sqlx::query(
36 "CREATE TABLE IF NOT EXISTS dags (
37 id TEXT PRIMARY KEY,
38 definition TEXT NOT NULL,
39 is_paused BOOLEAN NOT NULL DEFAULT 0,
40 created_at DATETIME NOT NULL,
41 updated_at DATETIME NOT NULL
42 )",
43 )
44 .execute(&pool)
45 .await?;
46
47 sqlx::query(
48 "CREATE TABLE IF NOT EXISTS dag_runs (
49 id TEXT PRIMARY KEY,
50 dag_id TEXT NOT NULL,
51 status TEXT NOT NULL,
52 started_at DATETIME NOT NULL,
53 ended_at DATETIME,
54 triggered_by TEXT NOT NULL,
55 run_number INTEGER NOT NULL,
56 FOREIGN KEY (dag_id) REFERENCES dags(id)
57 )",
58 )
59 .execute(&pool)
60 .await?;
61
62 sqlx::query(
63 "CREATE TABLE IF NOT EXISTS task_runs (
64 id TEXT PRIMARY KEY,
65 dag_run_id TEXT NOT NULL,
66 task_id TEXT NOT NULL,
67 status TEXT NOT NULL,
68 started_at DATETIME,
69 ended_at DATETIME,
70 attempt_number INTEGER NOT NULL,
71 log TEXT NOT NULL DEFAULT '',
72 xcom_output TEXT,
73 FOREIGN KEY (dag_run_id) REFERENCES dag_runs(id)
74 )",
75 )
76 .execute(&pool)
77 .await?;
78
79 Ok(Store { pool })
80 }
81
82 pub async fn save_dag(&self, dag: &DagDefinition) -> Result<()> {
85 let definition = serde_json::to_string(dag)?;
86 let now = Utc::now();
87 let now_str = now.to_rfc3339();
88
89 sqlx::query(
90 "INSERT OR REPLACE INTO dags (id, definition, is_paused, created_at, updated_at)
91 VALUES (?, ?, 0, ?, ?)",
92 )
93 .bind(&dag.id)
94 .bind(&definition)
95 .bind(&now_str)
96 .bind(&now_str)
97 .execute(&self.pool)
98 .await?;
99
100 Ok(())
101 }
102
103 pub async fn get_dag(&self, dag_id: &str) -> Result<Option<DagDefinition>> {
104 let row = sqlx::query("SELECT definition FROM dags WHERE id = ?")
105 .bind(dag_id)
106 .fetch_optional(&self.pool)
107 .await?;
108
109 Ok(row.and_then(|r| {
110 let definition_str: String = r.get("definition");
111 serde_json::from_str(&definition_str).ok()
112 }))
113 }
114
115 pub async fn get_all_dags(&self) -> Result<Vec<DagDefinition>> {
116 let rows = sqlx::query("SELECT definition FROM dags")
117 .fetch_all(&self.pool)
118 .await?;
119
120 let dags = rows
121 .into_iter()
122 .filter_map(|row| {
123 let definition_str: String = row.get("definition");
124 serde_json::from_str(&definition_str).ok()
125 })
126 .collect();
127
128 Ok(dags)
129 }
130
131 pub async fn pause_dag(&self, dag_id: &str) -> Result<()> {
132 sqlx::query("UPDATE dags SET is_paused = 1 WHERE id = ?")
133 .bind(dag_id)
134 .execute(&self.pool)
135 .await?;
136
137 Ok(())
138 }
139
140 pub async fn unpause_dag(&self, dag_id: &str) -> Result<()> {
141 sqlx::query("UPDATE dags SET is_paused = 0 WHERE id = ?")
142 .bind(dag_id)
143 .execute(&self.pool)
144 .await?;
145
146 Ok(())
147 }
148
149 pub async fn is_dag_paused(&self, dag_id: &str) -> Result<bool> {
150 let row = sqlx::query("SELECT is_paused FROM dags WHERE id = ?")
151 .bind(dag_id)
152 .fetch_optional(&self.pool)
153 .await?;
154
155 Ok(row.map(|r| r.get::<bool, _>("is_paused")).unwrap_or(false))
156 }
157
158 pub async fn recover_orphaned_runs(&self) -> Result<()> {
161 let now = Utc::now().to_rfc3339();
162
163 let orphaned_tasks = sqlx::query(
165 "SELECT id FROM task_runs WHERE status = ?",
166 )
167 .bind(TaskRunStatus::Running.to_string())
168 .fetch_all(&self.pool)
169 .await?;
170
171 for task_row in orphaned_tasks {
172 let task_run_id: String = task_row.get("id");
173 let recovery_msg = "Orphaned by executor crash — marked failed on restart";
174
175 sqlx::query(
176 "UPDATE task_runs SET status = ?, ended_at = ?, log = log || '\n' || ? WHERE id = ?",
177 )
178 .bind(TaskRunStatus::Failed.to_string())
179 .bind(&now)
180 .bind(recovery_msg)
181 .bind(&task_run_id)
182 .execute(&self.pool)
183 .await?;
184
185 tracing::info!("Recovered orphaned task run: {}", task_run_id);
186 }
187
188 let orphaned_runs = sqlx::query(
190 "SELECT id FROM dag_runs WHERE status = ?",
191 )
192 .bind(DagRunStatus::Running.to_string())
193 .fetch_all(&self.pool)
194 .await?;
195
196 for run_row in orphaned_runs {
197 let dag_run_id: String = run_row.get("id");
198
199 sqlx::query(
200 "UPDATE dag_runs SET status = ?, ended_at = ? WHERE id = ?",
201 )
202 .bind(DagRunStatus::Failed.to_string())
203 .bind(&now)
204 .bind(&dag_run_id)
205 .execute(&self.pool)
206 .await?;
207
208 tracing::info!("Recovered orphaned DAG run: {}", dag_run_id);
209 }
210
211 Ok(())
212 }
213
214 pub async fn create_dag_run(
217 &self,
218 dag_id: &str,
219 triggered_by: TriggerType,
220 ) -> Result<DagRun> {
221 let run_id = Uuid::new_v4().to_string();
222 let now = Utc::now();
223 let now_str = now.to_rfc3339();
224
225 let run_number: i64 = sqlx::query_scalar(
227 "SELECT COALESCE(MAX(run_number), 0) + 1 FROM dag_runs WHERE dag_id = ?",
228 )
229 .bind(dag_id)
230 .fetch_one(&self.pool)
231 .await?;
232
233 sqlx::query(
234 "INSERT INTO dag_runs (id, dag_id, status, started_at, triggered_by, run_number)
235 VALUES (?, ?, ?, ?, ?, ?)",
236 )
237 .bind(&run_id)
238 .bind(dag_id)
239 .bind(DagRunStatus::Queued.to_string())
240 .bind(&now_str)
241 .bind(triggered_by.to_string())
242 .bind(run_number)
243 .execute(&self.pool)
244 .await?;
245
246 Ok(DagRun {
247 id: run_id,
248 dag_id: dag_id.to_string(),
249 status: DagRunStatus::Queued,
250 started_at: now,
251 ended_at: None,
252 triggered_by,
253 run_number: run_number as u32,
254 })
255 }
256
257 pub async fn get_dag_run(&self, run_id: &str) -> Result<Option<DagRun>> {
258 let row = sqlx::query(
259 "SELECT id, dag_id, status, started_at, ended_at, triggered_by, run_number
260 FROM dag_runs WHERE id = ?",
261 )
262 .bind(run_id)
263 .fetch_optional(&self.pool)
264 .await?;
265
266 Ok(row.and_then(|r| {
267 let status_str: String = r.get("status");
268 let triggered_str: String = r.get("triggered_by");
269 let started_at_str: String = r.get("started_at");
270 let started_at = DateTime::parse_from_rfc3339(&started_at_str)
271 .ok()?
272 .with_timezone(&Utc);
273
274 Some(DagRun {
275 id: r.get("id"),
276 dag_id: r.get("dag_id"),
277 status: match status_str.as_str() {
278 "queued" => DagRunStatus::Queued,
279 "running" => DagRunStatus::Running,
280 "success" => DagRunStatus::Success,
281 "failed" => DagRunStatus::Failed,
282 _ => DagRunStatus::Queued,
283 },
284 started_at,
285 ended_at: {
286 let ended_at_str: Option<String> = r.get("ended_at");
287 ended_at_str.and_then(|s| {
288 DateTime::parse_from_rfc3339(&s)
289 .ok()
290 .map(|dt| dt.with_timezone(&Utc))
291 })
292 },
293 triggered_by: match triggered_str.as_str() {
294 "schedule" => TriggerType::Schedule,
295 "manual" => TriggerType::Manual,
296 _ => TriggerType::Manual,
297 },
298 run_number: r.get::<i64, _>("run_number") as u32,
299 })
300 }))
301 }
302
303 pub async fn get_dag_runs(&self, dag_id: &str, limit: i64) -> Result<Vec<DagRun>> {
304 let rows = sqlx::query(
305 "SELECT id, dag_id, status, started_at, ended_at, triggered_by, run_number
306 FROM dag_runs WHERE dag_id = ? ORDER BY started_at DESC LIMIT ?",
307 )
308 .bind(dag_id)
309 .bind(limit)
310 .fetch_all(&self.pool)
311 .await?;
312
313 let runs = rows
314 .into_iter()
315 .filter_map(|r| {
316 let status_str: String = r.get("status");
317 let triggered_str: String = r.get("triggered_by");
318 let started_at_str: String = r.get("started_at");
319 let started_at = DateTime::parse_from_rfc3339(&started_at_str)
320 .ok()?
321 .with_timezone(&Utc);
322
323 Some(DagRun {
324 id: r.get("id"),
325 dag_id: r.get("dag_id"),
326 status: match status_str.as_str() {
327 "queued" => DagRunStatus::Queued,
328 "running" => DagRunStatus::Running,
329 "success" => DagRunStatus::Success,
330 "failed" => DagRunStatus::Failed,
331 _ => DagRunStatus::Queued,
332 },
333 started_at,
334 ended_at: {
335 let ended_at_str: Option<String> = r.get("ended_at");
336 ended_at_str.and_then(|s| {
337 DateTime::parse_from_rfc3339(&s)
338 .ok()
339 .map(|dt| dt.with_timezone(&Utc))
340 })
341 },
342 triggered_by: match triggered_str.as_str() {
343 "schedule" => TriggerType::Schedule,
344 "manual" => TriggerType::Manual,
345 _ => TriggerType::Manual,
346 },
347 run_number: r.get::<i64, _>("run_number") as u32,
348 })
349 })
350 .collect();
351
352 Ok(runs)
353 }
354
355 pub async fn update_dag_run_status(&self, run_id: &str, status: DagRunStatus) -> Result<()> {
356 let ended_at = if matches!(status, DagRunStatus::Success | DagRunStatus::Failed) {
357 Some(Utc::now().to_rfc3339())
358 } else {
359 None
360 };
361
362 sqlx::query(
363 "UPDATE dag_runs SET status = ?, ended_at = ? WHERE id = ?",
364 )
365 .bind(status.to_string())
366 .bind(ended_at)
367 .bind(run_id)
368 .execute(&self.pool)
369 .await?;
370
371 Ok(())
372 }
373
374 pub async fn create_task_run(
377 &self,
378 dag_run_id: &str,
379 task_id: &str,
380 ) -> Result<TaskRun> {
381 let task_run_id = Uuid::new_v4().to_string();
382
383 sqlx::query(
384 "INSERT INTO task_runs (id, dag_run_id, task_id, status, attempt_number, log)
385 VALUES (?, ?, ?, ?, 1, '')",
386 )
387 .bind(&task_run_id)
388 .bind(dag_run_id)
389 .bind(task_id)
390 .bind(TaskRunStatus::Pending.to_string())
391 .execute(&self.pool)
392 .await?;
393
394 Ok(TaskRun {
395 id: task_run_id,
396 dag_run_id: dag_run_id.to_string(),
397 task_id: task_id.to_string(),
398 status: TaskRunStatus::Pending,
399 started_at: None,
400 ended_at: None,
401 attempt_number: 1,
402 log: String::new(),
403 xcom_output: None,
404 })
405 }
406
407 pub async fn get_task_run(&self, task_run_id: &str) -> Result<Option<TaskRun>> {
408 let row = sqlx::query(
409 "SELECT id, dag_run_id, task_id, status, started_at, ended_at, attempt_number, log, xcom_output
410 FROM task_runs WHERE id = ?",
411 )
412 .bind(task_run_id)
413 .fetch_optional(&self.pool)
414 .await?;
415
416 Ok(row.map(|r| {
417 let status_str: String = r.get("status");
418 let started_at_str: Option<String> = r.get("started_at");
419 let ended_at_str: Option<String> = r.get("ended_at");
420
421 TaskRun {
422 id: r.get("id"),
423 dag_run_id: r.get("dag_run_id"),
424 task_id: r.get("task_id"),
425 status: match status_str.as_str() {
426 "pending" => TaskRunStatus::Pending,
427 "running" => TaskRunStatus::Running,
428 "success" => TaskRunStatus::Success,
429 "failed" => TaskRunStatus::Failed,
430 "retried" => TaskRunStatus::Retried,
431 "skipped" => TaskRunStatus::Skipped,
432 _ => TaskRunStatus::Pending,
433 },
434 started_at: started_at_str.and_then(|s| {
435 DateTime::parse_from_rfc3339(&s)
436 .ok()
437 .map(|dt| dt.with_timezone(&Utc))
438 }),
439 ended_at: ended_at_str.and_then(|s| {
440 DateTime::parse_from_rfc3339(&s)
441 .ok()
442 .map(|dt| dt.with_timezone(&Utc))
443 }),
444 attempt_number: r.get::<i64, _>("attempt_number") as u32,
445 log: r.get("log"),
446 xcom_output: r.get("xcom_output"),
447 }
448 }))
449 }
450
451 pub async fn get_task_runs_for_dag_run(&self, dag_run_id: &str) -> Result<Vec<TaskRun>> {
452 let rows = sqlx::query(
453 "SELECT id, dag_run_id, task_id, status, started_at, ended_at, attempt_number, log, xcom_output
454 FROM task_runs WHERE dag_run_id = ?",
455 )
456 .bind(dag_run_id)
457 .fetch_all(&self.pool)
458 .await?;
459
460 let task_runs = rows
461 .into_iter()
462 .map(|r| {
463 let status_str: String = r.get("status");
464 let started_at_str: Option<String> = r.get("started_at");
465 let ended_at_str: Option<String> = r.get("ended_at");
466
467 TaskRun {
468 id: r.get("id"),
469 dag_run_id: r.get("dag_run_id"),
470 task_id: r.get("task_id"),
471 status: match status_str.as_str() {
472 "pending" => TaskRunStatus::Pending,
473 "running" => TaskRunStatus::Running,
474 "success" => TaskRunStatus::Success,
475 "failed" => TaskRunStatus::Failed,
476 "retried" => TaskRunStatus::Retried,
477 "skipped" => TaskRunStatus::Skipped,
478 _ => TaskRunStatus::Pending,
479 },
480 started_at: started_at_str.and_then(|s| {
481 DateTime::parse_from_rfc3339(&s)
482 .ok()
483 .map(|dt| dt.with_timezone(&Utc))
484 }),
485 ended_at: ended_at_str.and_then(|s| {
486 DateTime::parse_from_rfc3339(&s)
487 .ok()
488 .map(|dt| dt.with_timezone(&Utc))
489 }),
490 attempt_number: r.get::<i64, _>("attempt_number") as u32,
491 log: r.get("log"),
492 xcom_output: r.get("xcom_output"),
493 }
494 })
495 .collect();
496
497 Ok(task_runs)
498 }
499
500 pub async fn update_task_run(
501 &self,
502 task_run_id: &str,
503 status: TaskRunStatus,
504 log_append: Option<&str>,
505 xcom_output: Option<String>,
506 ) -> Result<()> {
507 let started_at = if matches!(status, TaskRunStatus::Running) {
508 Some(Utc::now().to_rfc3339())
509 } else {
510 None
511 };
512
513 let ended_at = if matches!(
514 status,
515 TaskRunStatus::Success | TaskRunStatus::Failed | TaskRunStatus::Skipped
516 ) {
517 Some(Utc::now().to_rfc3339())
518 } else {
519 None
520 };
521
522 let mut new_log = String::new();
524 if let Some(append) = log_append {
525 if let Ok(Some(task_run)) = self.get_task_run(task_run_id).await {
526 new_log = format!("{}\n{}", task_run.log, append);
527 } else {
528 new_log = append.to_string();
529 }
530 }
531
532 sqlx::query(
533 "UPDATE task_runs SET status = ?, started_at = COALESCE(started_at, ?),
534 ended_at = ?, log = CASE WHEN ? THEN ? ELSE log END, xcom_output = COALESCE(?, xcom_output)
535 WHERE id = ?",
536 )
537 .bind(status.to_string())
538 .bind(&started_at)
539 .bind(&ended_at)
540 .bind(!new_log.is_empty())
541 .bind(&new_log)
542 .bind(&xcom_output)
543 .bind(task_run_id)
544 .execute(&self.pool)
545 .await?;
546
547 Ok(())
548 }
549
550 pub async fn increment_task_run_attempt(&self, task_run_id: &str) -> Result<u32> {
551 let new_attempt: i64 = sqlx::query_scalar(
552 "UPDATE task_runs SET attempt_number = attempt_number + 1 WHERE id = ? RETURNING attempt_number",
553 )
554 .bind(task_run_id)
555 .fetch_one(&self.pool)
556 .await?;
557
558 Ok(new_attempt as u32)
559 }
560
561 pub async fn append_task_log(&self, task_run_id: &str, log_line: &str) -> Result<()> {
562 if let Ok(Some(task_run)) = self.get_task_run(task_run_id).await {
563 let new_log = format!("{}\n{}", task_run.log, log_line);
564 sqlx::query("UPDATE task_runs SET log = ? WHERE id = ?")
565 .bind(&new_log)
566 .bind(task_run_id)
567 .execute(&self.pool)
568 .await?;
569 }
570
571 Ok(())
572 }
573
574 pub async fn get_xcom(&self, run_id: &str, task_id: &str) -> Result<Option<String>> {
578 let row = sqlx::query(
579 "SELECT xcom_output FROM task_runs WHERE dag_run_id = ? AND task_id = ? AND status = ?",
580 )
581 .bind(run_id)
582 .bind(task_id)
583 .bind(TaskRunStatus::Success.to_string())
584 .fetch_optional(&self.pool)
585 .await?;
586
587 Ok(row.map(|r| r.get::<Option<String>, _>("xcom_output")).flatten())
588 }
589
590 pub async fn get_all_xcoms_for_run(&self, run_id: &str) -> Result<serde_json::Map<String, serde_json::Value>> {
592 let rows = sqlx::query(
593 "SELECT task_id, xcom_output FROM task_runs WHERE dag_run_id = ? AND status = ? AND xcom_output IS NOT NULL",
594 )
595 .bind(run_id)
596 .bind(TaskRunStatus::Success.to_string())
597 .fetch_all(&self.pool)
598 .await?;
599
600 let mut xcoms = serde_json::Map::new();
601 for row in rows {
602 let task_id: String = row.get("task_id");
603 let xcom_output: Option<String> = row.get("xcom_output");
604 if let Some(output) = xcom_output {
605 xcoms.insert(task_id, serde_json::json!(output));
606 }
607 }
608
609 Ok(xcoms)
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task-less DAG definition with the given id and no optional settings.
    fn sample_dag(id: &str) -> DagDefinition {
        DagDefinition {
            id: id.to_string(),
            description: None,
            schedule: None,
            max_active_runs: None,
            catchup: None,
            tasks: vec![],
        }
    }

    /// Opens a fresh in-memory store with `dag_id` already saved, so runs can
    /// reference it through the `dag_runs.dag_id` foreign key.
    async fn store_with_dag(dag_id: &str) -> Store {
        let store = Store::new("sqlite::memory:").await.unwrap();
        store.save_dag(&sample_dag(dag_id)).await.unwrap();
        store
    }

    #[tokio::test]
    async fn test_store_creation() {
        let store = Store::new("sqlite::memory:").await;
        assert!(store.is_ok());
    }

    #[tokio::test]
    async fn test_save_and_get_dag() {
        let store = Store::new("sqlite::memory:").await.unwrap();
        let dag = DagDefinition {
            description: Some("Test".to_string()),
            ..sample_dag("test_dag")
        };

        store.save_dag(&dag).await.unwrap();
        let retrieved = store.get_dag("test_dag").await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, "test_dag");
    }

    #[tokio::test]
    async fn test_get_missing_dag_returns_none() {
        let store = Store::new("sqlite::memory:").await.unwrap();
        assert!(store.get_dag("missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_pause_and_unpause_dag() {
        let store = store_with_dag("test_dag").await;
        assert!(!store.is_dag_paused("test_dag").await.unwrap());

        store.pause_dag("test_dag").await.unwrap();
        assert!(store.is_dag_paused("test_dag").await.unwrap());

        store.unpause_dag("test_dag").await.unwrap();
        assert!(!store.is_dag_paused("test_dag").await.unwrap());

        // Unknown DAGs are reported as not paused.
        assert!(!store.is_dag_paused("missing").await.unwrap());
    }

    #[tokio::test]
    async fn test_dag_run_creation_and_retrieval() {
        let store = store_with_dag("test_dag").await;

        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        assert_eq!(dag_run.dag_id, "test_dag");
        assert_eq!(dag_run.status, DagRunStatus::Queued);

        let retrieved = store.get_dag_run(&dag_run.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, dag_run.id);
    }

    #[tokio::test]
    async fn test_run_numbers_increment_per_dag() {
        let store = store_with_dag("test_dag").await;
        store.save_dag(&sample_dag("other_dag")).await.unwrap();

        let first = store.create_dag_run("test_dag", TriggerType::Manual).await.unwrap();
        let second = store.create_dag_run("test_dag", TriggerType::Manual).await.unwrap();
        let other = store.create_dag_run("other_dag", TriggerType::Manual).await.unwrap();

        assert_eq!(first.run_number, 1);
        assert_eq!(second.run_number, 2);
        // Numbering is independent per DAG.
        assert_eq!(other.run_number, 1);
    }

    #[tokio::test]
    async fn test_get_dag_runs_respects_limit() {
        let store = store_with_dag("test_dag").await;
        for _ in 0..3 {
            store
                .create_dag_run("test_dag", TriggerType::Manual)
                .await
                .unwrap();
        }

        assert_eq!(store.get_dag_runs("test_dag", 2).await.unwrap().len(), 2);
        assert_eq!(store.get_dag_runs("test_dag", 10).await.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_task_run_creation_and_update() {
        let store = store_with_dag("test_dag").await;

        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        let task_run = store
            .create_task_run(&dag_run.id, "task_1")
            .await
            .unwrap();

        assert_eq!(task_run.status, TaskRunStatus::Pending);

        store
            .update_task_run(&task_run.id, TaskRunStatus::Running, None, None)
            .await
            .unwrap();

        let retrieved = store.get_task_run(&task_run.id).await.unwrap().unwrap();
        assert_eq!(retrieved.status, TaskRunStatus::Running);
    }

    #[tokio::test]
    async fn test_get_task_runs_for_dag_run() {
        let store = store_with_dag("test_dag").await;
        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        store.create_task_run(&dag_run.id, "task_1").await.unwrap();
        store.create_task_run(&dag_run.id, "task_2").await.unwrap();

        let task_runs = store.get_task_runs_for_dag_run(&dag_run.id).await.unwrap();
        assert_eq!(task_runs.len(), 2);
    }

    #[tokio::test]
    async fn test_append_task_log() {
        let store = store_with_dag("test_dag").await;
        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();
        let task_run = store
            .create_task_run(&dag_run.id, "task_1")
            .await
            .unwrap();

        store.append_task_log(&task_run.id, "line one").await.unwrap();
        store.append_task_log(&task_run.id, "line two").await.unwrap();

        let retrieved = store.get_task_run(&task_run.id).await.unwrap().unwrap();
        // Every appended line, including the first, is preceded by a newline.
        assert_eq!(retrieved.log, "\nline one\nline two");
    }

    #[tokio::test]
    async fn test_crash_recovery_orphaned_runs() {
        let store = store_with_dag("test_dag").await;

        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        store
            .update_dag_run_status(&dag_run.id, DagRunStatus::Running)
            .await
            .unwrap();

        let task_run = store
            .create_task_run(&dag_run.id, "task_1")
            .await
            .unwrap();

        store
            .update_task_run(&task_run.id, TaskRunStatus::Running, None, None)
            .await
            .unwrap();

        let dag_run_before = store.get_dag_run(&dag_run.id).await.unwrap().unwrap();
        assert_eq!(dag_run_before.status, DagRunStatus::Running);

        let task_run_before = store.get_task_run(&task_run.id).await.unwrap().unwrap();
        assert_eq!(task_run_before.status, TaskRunStatus::Running);

        store.recover_orphaned_runs().await.unwrap();

        let dag_run_after = store.get_dag_run(&dag_run.id).await.unwrap().unwrap();
        assert_eq!(dag_run_after.status, DagRunStatus::Failed);
        assert!(dag_run_after.ended_at.is_some());

        let task_run_after = store.get_task_run(&task_run.id).await.unwrap().unwrap();
        assert_eq!(task_run_after.status, TaskRunStatus::Failed);
        assert!(task_run_after.ended_at.is_some());
        assert!(task_run_after.log.contains("Orphaned by executor crash"));
    }

    #[tokio::test]
    async fn test_xcom_get_and_retrieve() {
        let store = store_with_dag("test_dag").await;
        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        let task_run = store
            .create_task_run(&dag_run.id, "task_1")
            .await
            .unwrap();

        let xcom_output = r#"{"result": "success", "count": 42}"#.to_string();
        store
            .update_task_run(
                &task_run.id,
                TaskRunStatus::Success,
                Some("Task completed"),
                Some(xcom_output.clone()),
            )
            .await
            .unwrap();

        let retrieved_xcom = store.get_xcom(&dag_run.id, "task_1").await.unwrap();
        assert!(retrieved_xcom.is_some());
        assert_eq!(retrieved_xcom.unwrap(), xcom_output);
    }

    #[tokio::test]
    async fn test_get_all_xcoms_for_run() {
        let store = store_with_dag("test_dag").await;
        let dag_run = store
            .create_dag_run("test_dag", TriggerType::Manual)
            .await
            .unwrap();

        let task1 = store.create_task_run(&dag_run.id, "task_1").await.unwrap();
        let task2 = store.create_task_run(&dag_run.id, "task_2").await.unwrap();

        store
            .update_task_run(&task1.id, TaskRunStatus::Success, None, Some(r#"{"value": 1}"#.to_string()))
            .await
            .unwrap();

        store
            .update_task_run(&task2.id, TaskRunStatus::Success, None, Some(r#"{"value": 2}"#.to_string()))
            .await
            .unwrap();

        let all_xcoms = store.get_all_xcoms_for_run(&dag_run.id).await.unwrap();
        assert_eq!(all_xcoms.len(), 2);
        assert!(all_xcoms.contains_key("task_1"));
        assert!(all_xcoms.contains_key("task_2"));
    }
}