1use std::pin::Pin;
2
3use durable_db::entity::task::{
4 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
5};
6use sea_orm::{
7 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
8 DbBackend, EntityTrait, QueryFilter, Set, Statement, TransactionTrait,
9};
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use std::sync::atomic::{AtomicI32, Ordering};
13use std::time::Duration;
14use uuid::Uuid;
15
16use crate::error::DurableError;
17
/// Extra attempts `retry_db_write` makes after the first failed checkpoint write.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
/// Base delay for the exponential backoff between checkpoint retries (100, 200, 400 ms).
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
22
23async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
29where
30 F: FnMut() -> Fut,
31 Fut: std::future::Future<Output = Result<(), DurableError>>,
32{
33 match f().await {
34 Ok(()) => Ok(()),
35 Err(first_err) => {
36 for i in 0..MAX_CHECKPOINT_RETRIES {
37 tokio::time::sleep(Duration::from_millis(
38 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
39 ))
40 .await;
41 if f().await.is_ok() {
42 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
43 return Ok(());
44 }
45 }
46 Err(first_err)
47 }
48 }
49}
50
/// Controls how `Ctx::step_with_retry` re-attempts a failing step.
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor applied per retry (1.0 = fixed delay, 2.0 = doubling).
    pub backoff_multiplier: f64,
}
57
58impl RetryPolicy {
59 pub fn none() -> Self {
61 Self {
62 max_retries: 0,
63 initial_backoff: std::time::Duration::from_secs(0),
64 backoff_multiplier: 1.0,
65 }
66 }
67
68 pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
70 Self {
71 max_retries,
72 initial_backoff,
73 backoff_multiplier: 2.0,
74 }
75 }
76
77 pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
79 Self {
80 max_retries,
81 initial_backoff: backoff,
82 backoff_multiplier: 1.0,
83 }
84 }
85}
86
/// Handle for one durable workflow run, backed by a row in `durable.task`.
pub struct Ctx {
    /// Connection used for all checkpoint reads and writes.
    db: DatabaseConnection,
    /// Primary key of this workflow's task row.
    task_id: Uuid,
    /// Monotonic step counter; each step/transaction/child claims the next value.
    sequence: AtomicI32,
}
97
impl Ctx {
    /// Starts (or re-attaches to) a top-level workflow named `name`.
    ///
    /// Finds or creates the root task row, marks it RUNNING, and returns a
    /// context whose step sequence starts at 0.
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let txn = db.begin().await?;
        // lock = false: the root row is looked up / created without a
        // FOR UPDATE claim.
        let (task_id, _saved) =
            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
        retry_db_write(|| set_status(&txn, task_id, "RUNNING")).await?;
        txn.commit().await?;
        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Executes one durable step.
    ///
    /// If the step already COMPLETED in a previous run, the stored output is
    /// deserialized and returned without calling `f` (replay). Otherwise `f`
    /// runs and its result (output or error message) is checkpointed inside
    /// a transaction. May return `DurableError::StepLocked` when another
    /// worker currently holds this step's row.
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // The sequence number is claimed even on replay, keeping ordering
        // stable across runs.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Bail out if the workflow was paused/cancelled externally...
        check_status(&self.db, self.task_id).await?;

        // ...or its deadline has passed.
        check_deadline(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;

        // lock = true: claim the row FOR UPDATE; max_retries = 0 because a
        // plain step never retries (see step_with_retry).
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // Replay path: a previous run already completed this step.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // NOTE(review): retry_db_write re-issues the statement on the same
        // open transaction; if the first attempt aborted the transaction,
        // Postgres rejects subsequent statements until rollback — confirm
        // the retries can actually succeed here.
        retry_db_write(|| set_status(&txn, step_id, "RUNNING")).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Persist the failure, then surface the original error.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }

    /// Executes a step whose body shares a single database transaction with
    /// the step's own checkpoint: either the user's writes and the COMPLETED
    /// marker both commit, or neither does.
    ///
    /// On error the transaction is rolled back (by dropping it) and the step
    /// is marked FAILED on the main connection.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // lock = false: transaction steps are not row-claimed.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // Replay path: output was committed in a previous run.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        set_status(&tx, step_id, "RUNNING").await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                // Checkpoint inside the same transaction as the user's writes.
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the transaction rolls it back; the FAILED marker is
                // then written outside it so it survives the rollback.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }

    /// Creates (or re-attaches to) a child workflow under this one.
    ///
    /// The child context gets its own sequence counter starting at 0.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, "RUNNING")).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Whether this workflow's task row is in COMPLETED status.
    pub async fn is_completed(&self) -> Result<bool, DurableError> {
        let status = get_status(&self.db, self.task_id).await?;
        Ok(status.as_deref() == Some("COMPLETED"))
    }

    /// Deserializes this workflow's stored output; `None` until the task is
    /// COMPLETED with a non-null output.
    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
        match get_output(&self.db, self.task_id).await? {
            Some(val) => Ok(Some(serde_json::from_value(val)?)),
            None => Ok(None),
        }
    }

    /// Marks this workflow COMPLETED with the given output.
    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
        let json = serde_json::to_value(output)?;
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| complete_task(db, task_id, json.clone())).await
    }

    /// Like `step`, but re-runs `f` according to `policy` before giving up.
    ///
    /// The retry count is persisted on the task row, so attempts survive
    /// process restarts; the effective `max_retries` is read back from the
    /// row rather than taken from `policy` directly.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // lock = false: retried steps are not row-claimed.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // Replay path.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Resume from whatever retry count a previous run persisted.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Pause/cancel can interrupt between attempts.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, "RUNNING").await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the bumped count (also resets the row to
                        // PENDING and clears prior failure fields).
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Backoff = initial * multiplier^(n-1); the factor is
                        // clamped at 1.0 so the delay never shrinks.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }

    /// Marks this workflow FAILED with the given error message.
    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| fail_task(db, task_id, error)).await
    }

    /// Sets a timeout (in milliseconds) on this workflow's task row.
    ///
    /// If the task is already RUNNING with a `started_at`, the absolute
    /// deadline is computed from that start time; otherwise the deadline is
    /// filled in by `set_status` on the next RUNNING transition.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // Interpolation is safe here: timeout_ms is an i64 and task_id a Uuid.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
             WHEN status = 'RUNNING' AND started_at IS NOT NULL \
             THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
             ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }

    /// Convenience wrapper: `start` followed by `set_timeout`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }

    /// Pauses a PENDING/RUNNING workflow and all of its in-flight descendants.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Only in-flight workflows can be paused.
        match model.status.as_str() {
            "PENDING" | "RUNNING" => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Recursive CTE walks the subtree rooted at task_id; only
        // PENDING/RUNNING rows flip to PAUSED.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('PENDING', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }

    /// Resumes a PAUSED workflow: the root goes back to RUNNING while paused
    /// descendants go back to PENDING (to be re-claimed by their executor).
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != "PAUSED" {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        let sql = format!(
            "UPDATE durable.task SET status = 'RUNNING' WHERE id = '{task_id}'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        // Descendants deliberately become PENDING, not RUNNING.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }

    /// Cancels a workflow and every non-terminal descendant, stamping
    /// `completed_at`.
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Terminal states cannot be cancelled.
        match model.status.as_str() {
            "COMPLETED" | "FAILED" | "CANCELLED" => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
             AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }

    /// Underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// Primary key of this workflow's task row.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Claims and returns the next step sequence number.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
}
615
616#[allow(clippy::too_many_arguments)]
637async fn find_or_create_task(
638 db: &impl ConnectionTrait,
639 parent_id: Option<Uuid>,
640 sequence: Option<i32>,
641 name: &str,
642 kind: &str,
643 input: Option<serde_json::Value>,
644 lock: bool,
645 max_retries: Option<u32>,
646) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
647 let parent_eq = match parent_id {
648 Some(p) => format!("= '{p}'"),
649 None => "IS NULL".to_string(),
650 };
651 let parent_sql = match parent_id {
652 Some(p) => format!("'{p}'"),
653 None => "NULL".to_string(),
654 };
655
656 if lock {
657 let new_id = Uuid::new_v4();
671 let seq_sql = match sequence {
672 Some(s) => s.to_string(),
673 None => "NULL".to_string(),
674 };
675 let input_sql = match &input {
676 Some(v) => format!("'{}'", serde_json::to_string(v)?),
677 None => "NULL".to_string(),
678 };
679
680 let max_retries_sql = match max_retries {
681 Some(r) => r.to_string(),
682 None => "3".to_string(), };
684
685 let insert_sql = format!(
687 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
688 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
689 ON CONFLICT (parent_id, name) DO NOTHING"
690 );
691 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
692 .await?;
693
694 let lock_sql = format!(
696 "SELECT id, status, output FROM durable.task \
697 WHERE parent_id {parent_eq} AND name = '{name}' \
698 FOR UPDATE SKIP LOCKED"
699 );
700 let row = db
701 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
702 .await?;
703
704 if let Some(row) = row {
705 let id: Uuid = row
706 .try_get_by_index(0)
707 .map_err(|e| DurableError::custom(e.to_string()))?;
708 let status: String = row
709 .try_get_by_index(1)
710 .map_err(|e| DurableError::custom(e.to_string()))?;
711 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
712
713 if status == "COMPLETED" {
714 return Ok((id, output));
716 }
717 return Ok((id, None));
719 }
720
721 Err(DurableError::StepLocked(name.to_string()))
723 } else {
724 let mut query = Task::find().filter(TaskColumn::Name.eq(name));
728 query = match parent_id {
729 Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
730 None => query.filter(TaskColumn::ParentId.is_null()),
731 };
732 let existing = query.one(db).await?;
733
734 if let Some(model) = existing {
735 if model.status == "COMPLETED" {
736 return Ok((model.id, model.output));
737 }
738 return Ok((model.id, None));
739 }
740
741 let id = Uuid::new_v4();
743 let new_task = TaskActiveModel {
744 id: Set(id),
745 parent_id: Set(parent_id),
746 sequence: Set(sequence),
747 name: Set(name.to_string()),
748 kind: Set(kind.to_string()),
749 status: Set("PENDING".to_string()),
750 input: Set(input),
751 max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
752 ..Default::default()
753 };
754 new_task.insert(db).await?;
755
756 Ok((id, None))
757 }
758}
759
760async fn get_output(
761 db: &impl ConnectionTrait,
762 task_id: Uuid,
763) -> Result<Option<serde_json::Value>, DurableError> {
764 let model = Task::find_by_id(task_id)
765 .filter(TaskColumn::Status.eq("COMPLETED"))
766 .one(db)
767 .await?;
768
769 Ok(model.and_then(|m| m.output))
770}
771
772async fn get_status(
773 db: &impl ConnectionTrait,
774 task_id: Uuid,
775) -> Result<Option<String>, DurableError> {
776 let model = Task::find_by_id(task_id).one(db).await?;
777
778 Ok(model.map(|m| m.status))
779}
780
781async fn get_retry_info(
783 db: &DatabaseConnection,
784 task_id: Uuid,
785) -> Result<(u32, u32), DurableError> {
786 let model = Task::find_by_id(task_id).one(db).await?;
787
788 match model {
789 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
790 None => Err(DurableError::custom(format!(
791 "task {task_id} not found when reading retry info"
792 ))),
793 }
794}
795
796async fn increment_retry_count(
798 db: &DatabaseConnection,
799 task_id: Uuid,
800) -> Result<u32, DurableError> {
801 let model = Task::find_by_id(task_id).one(db).await?;
802
803 match model {
804 Some(m) => {
805 let new_count = m.retry_count + 1;
806 let mut active: TaskActiveModel = m.into();
807 active.retry_count = Set(new_count);
808 active.status = Set("PENDING".to_string());
809 active.error = Set(None);
810 active.completed_at = Set(None);
811 active.update(db).await?;
812 Ok(new_count as u32)
813 }
814 None => Err(DurableError::custom(format!(
815 "task {task_id} not found when incrementing retry count"
816 ))),
817 }
818}
819
/// Updates a task's status, stamping `started_at` on the first status write
/// (COALESCE keeps an existing value) and computing an absolute deadline when
/// the task enters RUNNING with a `timeout_ms` configured.
///
/// `status` is only ever one of this module's internal constants
/// ("PENDING"/"RUNNING"/...), so interpolating it here is safe; `task_id`
/// is a Uuid.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: &str,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
         WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
         THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
         ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
842
843async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
845 let status = get_status(db, task_id).await?;
846 match status.as_deref() {
847 Some("PAUSED") => Err(DurableError::Paused(format!("task {task_id} is paused"))),
848 Some("CANCELLED") => Err(DurableError::Cancelled(format!("task {task_id} is cancelled"))),
849 _ => Ok(()),
850 }
851}
852
853async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
855 let model = Task::find_by_id(task_id).one(db).await?;
856
857 if let Some(m) = model
858 && let Some(deadline_ms) = m.deadline_epoch_ms
859 {
860 let now_ms = std::time::SystemTime::now()
861 .duration_since(std::time::UNIX_EPOCH)
862 .map(|d| d.as_millis() as i64)
863 .unwrap_or(0);
864 if now_ms > deadline_ms {
865 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
866 }
867 }
868
869 Ok(())
870}
871
872async fn complete_task(
873 db: &impl ConnectionTrait,
874 task_id: Uuid,
875 output: serde_json::Value,
876) -> Result<(), DurableError> {
877 let model = Task::find_by_id(task_id).one(db).await?;
878
879 if let Some(m) = model {
880 let mut active: TaskActiveModel = m.into();
881 active.status = Set("COMPLETED".to_string());
882 active.output = Set(Some(output));
883 active.completed_at = Set(Some(chrono::Utc::now().into()));
884 active.update(db).await?;
885 }
886 Ok(())
887}
888
889async fn fail_task(
890 db: &impl ConnectionTrait,
891 task_id: Uuid,
892 error: &str,
893) -> Result<(), DurableError> {
894 let model = Task::find_by_id(task_id).one(db).await?;
895
896 if let Some(m) = model {
897 let mut active: TaskActiveModel = m.into();
898 active.status = Set("FAILED".to_string());
899 active.error = Set(Some(error.to_string()));
900 active.completed_at = Set(Some(chrono::Utc::now().into()));
901 active.update(db).await?;
902 }
903 Ok(())
904}
905
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // A write that succeeds immediately must run the closure exactly once
    // (no retries, no backoff sleeps).
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Fails twice, then succeeds: the overall result is Ok after exactly
    // three calls (1 initial + 2 retries).
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // Persistent failure: total attempts are the initial call plus
    // MAX_CHECKPOINT_RETRIES retries, and the result is Err.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    // Each attempt produces a distinct error message; the error surfaced to
    // the caller must be the one from the FIRST attempt ("error-0").
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}