use std::pin::Pin;
use std::sync::atomic::{AtomicI32, Ordering};
use std::time::Duration;

use durable_db::entity::task::{
    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
};
use sea_orm::{
    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
    DbBackend, EntityTrait, QueryFilter, Set, Statement, TransactionTrait,
};
use serde::Serialize;
use serde::de::DeserializeOwned;
use uuid::Uuid;

use crate::error::DurableError;

const MAX_CHECKPOINT_RETRIES: u32 = 3;
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;

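/// Retries a checkpoint write with exponential backoff (100ms, 200ms, 400ms)
/// after an initial failure. If every attempt fails, the *first* error is
/// returned, since it is usually the most informative one.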
async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<(), DurableError>>,
{
    match f().await {
        Ok(()) => Ok(()),
        Err(first_err) => {
            for i in 0..MAX_CHECKPOINT_RETRIES {
                tokio::time::sleep(Duration::from_millis(
                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
                ))
                .await;
                if f().await.is_ok() {
                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
                    return Ok(());
                }
            }
            Err(first_err)
        }
    }
}

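/// Controls how [`Ctx::step_with_retry`] re-runs a failing step: how many
/// retries are allowed, how long the first backoff lasts, and how the backoff
/// grows between attempts.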
pub struct RetryPolicy {
    pub max_retries: u32,
    pub initial_backoff: std::time::Duration,
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// Run the step exactly once; any failure is final.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: std::time::Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff: each retry waits twice as long as the previous one.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Fixed backoff: every retry waits the same amount of time.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}

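/// A durable execution context bound to one task row. Each `Ctx` tracks its
/// task id and a monotonically increasing step sequence; completed steps are
/// checkpointed to Postgres, so a restarted workflow replays saved outputs
/// instead of re-running side effects.
///
/// A usage sketch (assumes a reachable Postgres [`DatabaseConnection`]; the
/// workflow name and step body are placeholders):
///
/// ```ignore
/// let ctx = Ctx::start(&db, "send-invoice", None).await?;
/// let invoice_id: String = ctx
///     .step("create-invoice", || async {
///         Ok::<_, DurableError>("inv_123".to_string())
///     })
///     .await?;
/// ctx.complete(&invoice_id).await?;
/// ```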
pub struct Ctx {
    db: DatabaseConnection,
    task_id: Uuid,
    sequence: AtomicI32,
}

impl Ctx {
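    /// Starts (or resumes) a root workflow task with the given name, creating
    /// the task row if it does not exist and marking it RUNNING.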
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let txn = db.begin().await?;
        let (task_id, _saved) =
            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
        retry_db_write(|| set_status(&txn, task_id, "RUNNING")).await?;
        txn.commit().await?;
        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
        })
    }

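    /// Runs a checkpointed step. If a previous run already completed this step
    /// (matched by name under this task), its saved output is deserialized and
    /// returned without re-running `f`. Otherwise `f` runs, and its output (or
    /// error) is persisted before this call returns. The lookup takes a row
    /// lock, so a step concurrently held by another worker fails fast with
    /// [`DurableError::StepLocked`].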
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_deadline(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;

        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        retry_db_write(|| set_status(&txn, step_id, "RUNNING")).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }

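    /// Runs a step whose body executes inside the same database transaction as
    /// its checkpoint write: the step output and the caller's own writes commit
    /// atomically, or not at all. On failure the transaction is rolled back and
    /// the step is marked FAILED on a fresh connection.
    ///
    /// A usage sketch (the table and update statement are placeholders):
    ///
    /// ```ignore
    /// let n: u64 = ctx
    ///     .transaction("debit-account", |tx| {
    ///         Box::pin(async move {
    ///             let res = tx
    ///                 .execute(Statement::from_string(
    ///                     DbBackend::Postgres,
    ///                     "UPDATE accounts SET balance = balance - 10 WHERE id = 1",
    ///                 ))
    ///                 .await?;
    ///             Ok(res.rows_affected())
    ///         })
    ///     })
    ///     .await?;
    /// ```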
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        set_status(&tx, step_id, "RUNNING").await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the uncommitted transaction rolls it back, discarding
                // the caller's writes; record the failure on a fresh connection.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }

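    /// Spawns a child workflow context under this task. The child gets its own
    /// task row (kind WORKFLOW) and its own step sequence starting at zero.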
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        let txn = self.db.begin().await?;
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, "RUNNING")).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Returns `true` if this task has already been marked COMPLETED.
    pub async fn is_completed(&self) -> Result<bool, DurableError> {
        let status = get_status(&self.db, self.task_id).await?;
        Ok(status.as_deref() == Some("COMPLETED"))
    }

    /// Returns the deserialized output of this task, if it has completed.
    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
        match get_output(&self.db, self.task_id).await? {
            Some(val) => Ok(Some(serde_json::from_value(val)?)),
            None => Ok(None),
        }
    }

    /// Marks this task COMPLETED and stores its serialized output.
    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
        let json = serde_json::to_value(output)?;
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| complete_task(db, task_id, json.clone())).await
    }

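    /// Like [`Ctx::step`], but re-runs a failing step according to `policy`.
    /// The retry count is persisted, so a crashed worker resumes from the
    /// recorded attempt rather than starting over. Note that, unlike
    /// [`Ctx::step`], no `FOR UPDATE` row lock is taken here.
    ///
    /// A usage sketch (`call_flaky_api` is a placeholder):
    ///
    /// ```ignore
    /// let policy = RetryPolicy::exponential(5, Duration::from_millis(200));
    /// let resp: String = ctx
    ///     .step_with_retry("call-api", policy, || async { call_flaky_api().await })
    ///     .await?;
    /// ```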
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // The stored retry count survives worker crashes, so resume from it
        // rather than from zero.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            set_status(&self.db, step_id, "RUNNING").await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }

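    /// Marks this task FAILED with the given error message.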
    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| fail_task(db, task_id, error)).await
    }

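    /// Sets a timeout (in milliseconds) on this task. If the task is already
    /// RUNNING, the deadline is computed immediately from `started_at`;
    /// otherwise it is computed when the task next transitions to RUNNING
    /// (see `set_status`). Steps enforce the deadline via `check_deadline`.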
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
                 deadline_epoch_ms = CASE \
                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                     ELSE deadline_epoch_ms \
                 END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }

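    /// Convenience wrapper: starts a root workflow and immediately applies a
    /// timeout to it.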
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }

    /// The underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// The id of the task this context is bound to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Reserves and returns the next step sequence number.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
}

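/// Finds an existing task by (parent, name) or creates it in PENDING state.
/// Returns the task id plus its saved output when the task already COMPLETED
/// (the replay case). With `lock = true` the lookup takes a `FOR UPDATE SKIP
/// LOCKED` row lock, so a step held by another worker surfaces as
/// [`DurableError::StepLocked`] instead of blocking.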
#[allow(clippy::too_many_arguments)]
async fn find_or_create_task(
    db: &impl ConnectionTrait,
    parent_id: Option<Uuid>,
    sequence: Option<i32>,
    name: &str,
    kind: &str,
    input: Option<serde_json::Value>,
    lock: bool,
    max_retries: Option<u32>,
) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
    let parent_eq = match parent_id {
        Some(p) => format!("= '{p}'"),
        None => "IS NULL".to_string(),
    };
    let parent_sql = match parent_id {
        Some(p) => format!("'{p}'"),
        None => "NULL".to_string(),
    };

    if lock {
        // Insert-if-absent, then lock the row. Task names and kinds are
        // expected to be developer-supplied identifiers, not end-user input,
        // since they are interpolated directly into SQL; the ON CONFLICT
        // target assumes a unique index on (parent_id, name).
        let new_id = Uuid::new_v4();
        let seq_sql = match sequence {
            Some(s) => s.to_string(),
            None => "NULL".to_string(),
        };
        let input_sql = match &input {
            Some(v) => format!("'{}'", serde_json::to_string(v)?),
            None => "NULL".to_string(),
        };

        let max_retries_sql = match max_retries {
            Some(r) => r.to_string(),
            None => "3".to_string(),
        };

        let insert_sql = format!(
            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
             ON CONFLICT (parent_id, name) DO NOTHING"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
            .await?;

        let lock_sql = format!(
            "SELECT id, status, output FROM durable.task \
             WHERE parent_id {parent_eq} AND name = '{name}' \
             FOR UPDATE SKIP LOCKED"
        );
        let row = db
            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
            .await?;

        if let Some(row) = row {
            let id: Uuid = row
                .try_get_by_index(0)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            let status: String = row
                .try_get_by_index(1)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();

            if status == "COMPLETED" {
                return Ok((id, output));
            }
            return Ok((id, None));
        }

        // The row exists but SKIP LOCKED hid it: another worker holds the lock.
        Err(DurableError::StepLocked(name.to_string()))
    } else {
        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
        query = match parent_id {
            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
            None => query.filter(TaskColumn::ParentId.is_null()),
        };
        let existing = query.one(db).await?;

        if let Some(model) = existing {
            if model.status == "COMPLETED" {
                return Ok((model.id, model.output));
            }
            return Ok((model.id, None));
        }

        let id = Uuid::new_v4();
        let new_task = TaskActiveModel {
            id: Set(id),
            parent_id: Set(parent_id),
            sequence: Set(sequence),
            name: Set(name.to_string()),
            kind: Set(kind.to_string()),
            status: Set("PENDING".to_string()),
            input: Set(input),
            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
            ..Default::default()
        };
        new_task.insert(db).await?;

        Ok((id, None))
    }
}

async fn get_output(
    db: &impl ConnectionTrait,
    task_id: Uuid,
) -> Result<Option<serde_json::Value>, DurableError> {
    let model = Task::find_by_id(task_id)
        .filter(TaskColumn::Status.eq("COMPLETED"))
        .one(db)
        .await?;

    Ok(model.and_then(|m| m.output))
}

async fn get_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
) -> Result<Option<String>, DurableError> {
    let model = Task::find_by_id(task_id).one(db).await?;

    Ok(model.map(|m| m.status))
}

/// Reads `(retry_count, max_retries)` for a task.
async fn get_retry_info(
    db: &DatabaseConnection,
    task_id: Uuid,
) -> Result<(u32, u32), DurableError> {
    let model = Task::find_by_id(task_id).one(db).await?;

    match model {
        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
        None => Err(DurableError::custom(format!(
            "task {task_id} not found when reading retry info"
        ))),
    }
}

/// Bumps a task's retry count and resets it to PENDING for the next attempt.
async fn increment_retry_count(
    db: &DatabaseConnection,
    task_id: Uuid,
) -> Result<u32, DurableError> {
    let model = Task::find_by_id(task_id).one(db).await?;

    match model {
        Some(m) => {
            let new_count = m.retry_count + 1;
            let mut active: TaskActiveModel = m.into();
            active.retry_count = Set(new_count);
            active.status = Set("PENDING".to_string());
            active.error = Set(None);
            active.completed_at = Set(None);
            active.update(db).await?;
            Ok(new_count as u32)
        }
        None => Err(DurableError::custom(format!(
            "task {task_id} not found when incrementing retry count"
        ))),
    }
}

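/// Sets a task's status, recording `started_at` on the first transition and
/// computing `deadline_epoch_ms` when a task with a timeout starts RUNNING.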
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: &str,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}

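/// Errors with [`DurableError::Timeout`] if the task's deadline has passed.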
async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
    let model = Task::find_by_id(task_id).one(db).await?;

    if let Some(m) = model
        && let Some(deadline_ms) = m.deadline_epoch_ms
    {
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as i64)
            .unwrap_or(0);
        if now_ms > deadline_ms {
            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
        }
    }

    Ok(())
}

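/// Marks a task COMPLETED with its output and a completion timestamp.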
async fn complete_task(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    output: serde_json::Value,
) -> Result<(), DurableError> {
    let model = Task::find_by_id(task_id).one(db).await?;

    if let Some(m) = model {
        let mut active: TaskActiveModel = m.into();
        active.status = Set("COMPLETED".to_string());
        active.output = Set(Some(output));
        active.completed_at = Set(Some(chrono::Utc::now().into()));
        active.update(db).await?;
    }
    Ok(())
}

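/// Marks a task FAILED with the error message and a completion timestamp.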
async fn fail_task(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    error: &str,
) -> Result<(), DurableError> {
    let model = Task::find_by_id(task_id).one(db).await?;

    if let Some(m) = model {
        let mut active: TaskActiveModel = m.into();
        active.status = Set("FAILED".to_string());
        active.error = Set(Some(error.to_string()));
        active.completed_at = Set(Some(chrono::Utc::now().into()));
        active.update(db).await?;
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // A write that succeeds immediately should be called exactly once.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Two transient failures followed by a success: three calls total, Ok result.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // A write that always fails is attempted once, plus MAX_CHECKPOINT_RETRIES times.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    // When every attempt fails, the error from the *first* attempt is returned.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{n}"
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}
861}