1use std::pin::Pin;
2
3use sea_orm::{
4 ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, Statement,
5 TransactionTrait,
6};
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use std::sync::atomic::{AtomicI32, Ordering};
10use std::time::Duration;
11use uuid::Uuid;
12
13use crate::error::DurableError;
14
15const MAX_CHECKPOINT_RETRIES: u32 = 3;
18const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
19
20async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
26where
27 F: FnMut() -> Fut,
28 Fut: std::future::Future<Output = Result<(), DurableError>>,
29{
30 match f().await {
31 Ok(()) => Ok(()),
32 Err(first_err) => {
33 for i in 0..MAX_CHECKPOINT_RETRIES {
34 tokio::time::sleep(Duration::from_millis(
35 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
36 ))
37 .await;
38 if f().await.is_ok() {
39 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
40 return Ok(());
41 }
42 }
43 Err(first_err)
44 }
45 }
46}
47
/// Declarative retry configuration consumed by `Ctx::step_with_retry`.
pub struct RetryPolicy {
    /// Number of re-attempts allowed after the initial failure.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Multiplier applied to the backoff on each successive retry.
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries at all: the first failure is final.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: std::time::Duration::ZERO,
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff (doubling) starting at `initial_backoff`.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Constant delay of `backoff` between every retry.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
83
/// Execution context for one durable workflow task.
///
/// Work performed through a `Ctx` is checkpointed in the `durable.task`
/// table, so re-running the same workflow replays completed steps from
/// their saved output instead of re-executing them.
pub struct Ctx {
    // Connection used for all checkpoint reads and writes.
    db: DatabaseConnection,
    // Row id of this workflow's record in durable.task.
    task_id: Uuid,
    // Monotonic per-workflow counter; each step/transaction/child claims
    // the next value to key its checkpoint row deterministically.
    sequence: AtomicI32,
}
94
impl Ctx {
    /// Starts (or resumes) a top-level workflow named `name`.
    ///
    /// Creates the WORKFLOW task row if absent, marks it RUNNING, and
    /// commits both operations in a single transaction. `input` is stored
    /// on the row only at creation time.
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let txn = db.begin().await?;
        let (task_id, _saved) =
            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
        retry_db_write(|| set_status(&txn, task_id, "RUNNING")).await?;
        txn.commit().await?;
        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Runs a checkpointed step.
    ///
    /// If the step (keyed by this workflow's id and `name`) already
    /// COMPLETED, its saved output is deserialized and returned without
    /// running `f`. Otherwise the step row is claimed with a row lock
    /// (`DurableError::StepLocked` if another worker holds it), `f` is
    /// executed, and its result — success or failure — is committed to
    /// the checkpoint table.
    ///
    /// Errors with `DurableError::Timeout` if the workflow's deadline has
    /// already passed.
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim this step's position in the workflow's deterministic order.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_deadline(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;

        // lock = true: take the row FOR UPDATE so concurrent workers
        // cannot execute the same step twice. max_retries = 0: plain
        // steps have no retry budget (see step_with_retry).
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        if let Some(output) = saved_output {
            // Replay path: release the row lock and return the saved value.
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        retry_db_write(|| set_status(&txn, step_id, "RUNNING")).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Persist the failure, then surface the original error.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }

    /// Runs a checkpointed step whose body shares a database transaction
    /// with the checkpoint itself: `f` receives the open transaction, and
    /// the step's COMPLETED marker commits atomically with `f`'s writes.
    ///
    /// On failure the transaction is rolled back (via drop) and the step
    /// is marked FAILED on a separate connection, outside the aborted
    /// transaction.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
            &'tx DatabaseTransaction,
        ) -> Pin<
            Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
        > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // The step row is created outside the work transaction so a
        // rollback of `f` does not erase the checkpoint record.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        set_status(&tx, step_id, "RUNNING").await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the transaction rolls it back; the FAILED
                // status must be written on the base connection.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }

    /// Creates (or resumes) a child workflow nested under this one.
    ///
    /// The child is keyed by this workflow's id plus the next sequence
    /// number, so re-running the parent deterministically reattaches to
    /// the same child. The returned `Ctx` has its own sequence counter.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        let txn = self.db.begin().await?;
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, "RUNNING")).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Returns true if this workflow's task row is COMPLETED.
    pub async fn is_completed(&self) -> Result<bool, DurableError> {
        let status = get_status(&self.db, self.task_id).await?;
        Ok(status.as_deref() == Some("COMPLETED"))
    }

    /// Deserializes this workflow's stored output, if it has completed
    /// and recorded one.
    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
        match get_output(&self.db, self.task_id).await? {
            Some(val) => Ok(Some(serde_json::from_value(val)?)),
            None => Ok(None),
        }
    }

    /// Marks this workflow COMPLETED with the given output.
    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
        let json = serde_json::to_value(output)?;
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| complete_task(db, task_id, json.clone())).await
    }

    /// Runs a checkpointed step with a retry budget that survives process
    /// restarts: the retry count is persisted on the step row, so a
    /// crashed worker resumes counting where it left off.
    ///
    /// Unlike `step`, the row is created without a lock (note: concurrent
    /// workers on the same step are not excluded here) and statuses are
    /// written on the base connection rather than inside a transaction.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Read the persisted budget: max_retries comes from the row (it
        // may predate this call), retry_count from any previous attempts.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            set_status(&self.db, step_id, "RUNNING").await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Bump the persisted counter first so a crash
                        // during the backoff still consumes the attempt.
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            // multiplier^(retry-1), clamped so backoff
                            // never shrinks below the initial delay.
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }

    /// Marks this workflow FAILED with the given error message.
    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| fail_task(db, task_id, error)).await
    }

    /// Sets (or replaces) this workflow's timeout in milliseconds.
    ///
    /// If the task is already RUNNING, the absolute deadline is recomputed
    /// from its recorded start time; otherwise the existing deadline is
    /// kept and `set_status` will compute one on the next RUNNING
    /// transition. NOTE(review): here the deadline is based on started_at,
    /// while set_status bases it on now() — these differ for re-entered
    /// tasks; confirm which is intended.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
             WHEN status = 'RUNNING' AND started_at IS NOT NULL \
             THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
             ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }

    /// Convenience: `start` followed by `set_timeout`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }

    /// The underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// This workflow's task row id.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Claims and returns the next sequence number, advancing the counter.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
}
481
482#[allow(clippy::too_many_arguments)]
503async fn find_or_create_task(
504 db: &impl ConnectionTrait,
505 parent_id: Option<Uuid>,
506 sequence: Option<i32>,
507 name: &str,
508 kind: &str,
509 input: Option<serde_json::Value>,
510 lock: bool,
511 max_retries: Option<u32>,
512) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
513 let parent_eq = match parent_id {
514 Some(p) => format!("= '{p}'"),
515 None => "IS NULL".to_string(),
516 };
517 let parent_sql = match parent_id {
518 Some(p) => format!("'{p}'"),
519 None => "NULL".to_string(),
520 };
521
522 if lock {
523 let new_id = Uuid::new_v4();
537 let seq_sql = match sequence {
538 Some(s) => s.to_string(),
539 None => "NULL".to_string(),
540 };
541 let input_sql = match &input {
542 Some(v) => format!("'{}'", serde_json::to_string(v)?),
543 None => "NULL".to_string(),
544 };
545
546 let max_retries_sql = match max_retries {
547 Some(r) => r.to_string(),
548 None => "3".to_string(), };
550
551 let insert_sql = format!(
553 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
554 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
555 ON CONFLICT (parent_id, name) DO NOTHING"
556 );
557 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
558 .await?;
559
560 let lock_sql = format!(
562 "SELECT id, status, output FROM durable.task \
563 WHERE parent_id {parent_eq} AND name = '{name}' \
564 FOR UPDATE SKIP LOCKED"
565 );
566 let row = db
567 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
568 .await?;
569
570 if let Some(row) = row {
571 let id: Uuid = row
572 .try_get_by_index(0)
573 .map_err(|e| DurableError::custom(e.to_string()))?;
574 let status: String = row
575 .try_get_by_index(1)
576 .map_err(|e| DurableError::custom(e.to_string()))?;
577 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
578
579 if status == "COMPLETED" {
580 return Ok((id, output));
582 }
583 return Ok((id, None));
585 }
586
587 Err(DurableError::StepLocked(name.to_string()))
589 } else {
590 let find_sql = format!(
594 "SELECT id, status, output FROM durable.task \
595 WHERE parent_id {parent_eq} AND name = '{name}'"
596 );
597 let row = db
598 .query_one(Statement::from_string(DbBackend::Postgres, find_sql))
599 .await?;
600
601 if let Some(row) = row {
602 let id: Uuid = row
603 .try_get_by_index(0)
604 .map_err(|e| DurableError::custom(e.to_string()))?;
605 let status: String = row
606 .try_get_by_index(1)
607 .map_err(|e| DurableError::custom(e.to_string()))?;
608 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
609
610 if status == "COMPLETED" {
611 return Ok((id, output));
612 }
613 return Ok((id, None));
614 }
615
616 let id = Uuid::new_v4();
618 let seq_sql = match sequence {
619 Some(s) => s.to_string(),
620 None => "NULL".to_string(),
621 };
622 let input_sql = match &input {
623 Some(v) => format!("'{}'", serde_json::to_string(v)?),
624 None => "NULL".to_string(),
625 };
626 let max_retries_sql = match max_retries {
627 Some(r) => r.to_string(),
628 None => "3".to_string(), };
630
631 let sql = format!(
632 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
633 VALUES ('{id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql})"
634 );
635 db.execute(Statement::from_string(DbBackend::Postgres, sql))
636 .await?;
637
638 Ok((id, None))
639 }
640}
641
642async fn get_output(
643 db: &impl ConnectionTrait,
644 task_id: Uuid,
645) -> Result<Option<serde_json::Value>, DurableError> {
646 let sql =
647 format!("SELECT output FROM durable.task WHERE id = '{task_id}' AND status = 'COMPLETED'");
648 let row = db
649 .query_one(Statement::from_string(DbBackend::Postgres, sql))
650 .await?;
651
652 match row {
653 Some(r) => Ok(r.try_get_by_index(0).ok()),
654 None => Ok(None),
655 }
656}
657
658async fn get_status(
659 db: &impl ConnectionTrait,
660 task_id: Uuid,
661) -> Result<Option<String>, DurableError> {
662 let sql = format!("SELECT status FROM durable.task WHERE id = '{task_id}'");
663 let row = db
664 .query_one(Statement::from_string(DbBackend::Postgres, sql))
665 .await?;
666
667 match row {
668 Some(r) => Ok(r.try_get_by_index(0).ok()),
669 None => Ok(None),
670 }
671}
672
673async fn get_retry_info(
675 db: &DatabaseConnection,
676 task_id: Uuid,
677) -> Result<(u32, u32), DurableError> {
678 let sql = format!("SELECT retry_count, max_retries FROM durable.task WHERE id = '{task_id}'");
679 let row = db
680 .query_one(Statement::from_string(DbBackend::Postgres, sql))
681 .await?;
682
683 match row {
684 Some(r) => {
685 let retry_count: i32 = r
686 .try_get_by_index(0)
687 .map_err(|e| DurableError::custom(e.to_string()))?;
688 let max_retries: i32 = r
689 .try_get_by_index(1)
690 .map_err(|e| DurableError::custom(e.to_string()))?;
691 Ok((retry_count as u32, max_retries as u32))
692 }
693 None => Err(DurableError::custom(format!(
694 "task {task_id} not found when reading retry info"
695 ))),
696 }
697}
698
699async fn increment_retry_count(
701 db: &DatabaseConnection,
702 task_id: Uuid,
703) -> Result<u32, DurableError> {
704 let sql = format!(
705 "UPDATE durable.task \
706 SET retry_count = retry_count + 1, status = 'PENDING', error = NULL, completed_at = NULL \
707 WHERE id = '{task_id}' \
708 RETURNING retry_count"
709 );
710 let row = db
711 .query_one(Statement::from_string(DbBackend::Postgres, sql))
712 .await?;
713
714 match row {
715 Some(r) => {
716 let count: i32 = r
717 .try_get_by_index(0)
718 .map_err(|e| DurableError::custom(e.to_string()))?;
719 Ok(count as u32)
720 }
721 None => Err(DurableError::custom(format!(
722 "task {task_id} not found when incrementing retry count"
723 ))),
724 }
725}
726
727async fn set_status(
728 db: &impl ConnectionTrait,
729 task_id: Uuid,
730 status: &str,
731) -> Result<(), DurableError> {
732 let sql = format!(
733 "UPDATE durable.task \
734 SET status = '{status}', \
735 started_at = COALESCE(started_at, now()), \
736 deadline_epoch_ms = CASE \
737 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
738 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
739 ELSE deadline_epoch_ms \
740 END \
741 WHERE id = '{task_id}'"
742 );
743 db.execute(Statement::from_string(DbBackend::Postgres, sql))
744 .await?;
745 Ok(())
746}
747
748async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
750 let sql = format!("SELECT deadline_epoch_ms FROM durable.task WHERE id = '{task_id}'");
751 let row = db
752 .query_one(Statement::from_string(DbBackend::Postgres, sql))
753 .await?;
754
755 if let Some(row) = row {
756 let deadline: Option<i64> = row.try_get_by_index(0).ok().flatten();
757 if let Some(deadline_ms) = deadline {
758 let now_ms = std::time::SystemTime::now()
759 .duration_since(std::time::UNIX_EPOCH)
760 .map(|d| d.as_millis() as i64)
761 .unwrap_or(0);
762 if now_ms > deadline_ms {
763 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
764 }
765 }
766 }
767
768 Ok(())
769}
770
771async fn complete_task(
772 db: &impl ConnectionTrait,
773 task_id: Uuid,
774 output: serde_json::Value,
775) -> Result<(), DurableError> {
776 let sql = format!(
777 "UPDATE durable.task SET status = 'COMPLETED', output = '{}', completed_at = now() \
778 WHERE id = '{task_id}'",
779 output
780 );
781 db.execute(Statement::from_string(DbBackend::Postgres, sql))
782 .await?;
783 Ok(())
784}
785
786async fn fail_task(
787 db: &impl ConnectionTrait,
788 task_id: Uuid,
789 error: &str,
790) -> Result<(), DurableError> {
791 let sql = format!(
792 "UPDATE durable.task SET status = 'FAILED', error = '{}', completed_at = now() \
793 WHERE id = '{task_id}'",
794 error.replace('\'', "''")
795 );
796 db.execute(Statement::from_string(DbBackend::Postgres, sql))
797 .await?;
798 Ok(())
799}
800
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // A closure that succeeds immediately must be called exactly once,
    // with no backoff sleeps.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Fails twice, then succeeds: the overall result is Ok and the
    // closure ran three times (initial attempt + two retries).
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // A permanently-failing closure is attempted 1 + MAX_CHECKPOINT_RETRIES
    // times in total before the error is surfaced.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    // When every attempt fails, the error returned is the FIRST one
    // (error-0), not the error from the last retry.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}