//! durable/ctx.rs — workflow/step execution context for the durable-task SDK.
1use std::pin::Pin;
2
3use sea_orm::{
4    ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, Statement,
5    TransactionTrait,
6};
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use std::sync::atomic::{AtomicI32, Ordering};
10use std::time::Duration;
11use uuid::Uuid;
12
13use crate::error::DurableError;
14
15// ── Retry constants ───────────────────────────────────────────────────────────
16
17const MAX_CHECKPOINT_RETRIES: u32 = 3;
18const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
19
20/// Retry a fallible DB write with exponential backoff.
21///
22/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
23/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
24/// error if all retries are exhausted.
25async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
26where
27    F: FnMut() -> Fut,
28    Fut: std::future::Future<Output = Result<(), DurableError>>,
29{
30    match f().await {
31        Ok(()) => Ok(()),
32        Err(first_err) => {
33            for i in 0..MAX_CHECKPOINT_RETRIES {
34                tokio::time::sleep(Duration::from_millis(
35                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
36                ))
37                .await;
38                if f().await.is_ok() {
39                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
40                    return Ok(());
41                }
42            }
43            Err(first_err)
44        }
45    }
46}
47
/// Policy for retrying a step on failure.
///
/// A step is attempted once, then up to `max_retries` additional times; the
/// delay before retry `k` (1-based) is `initial_backoff * backoff_multiplier^(k-1)`.
pub struct RetryPolicy {
    /// Maximum number of retry attempts after the initial try.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Growth factor applied to the delay between consecutive retries.
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries — fails immediately on error.
    pub fn none() -> Self {
        // A zero-retry policy is just a fixed policy that never fires.
        Self::fixed(0, std::time::Duration::from_secs(0))
    }

    /// Exponential backoff: backoff doubles each retry.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            backoff_multiplier: 2.0,
            max_retries,
            initial_backoff,
        }
    }

    /// Fixed backoff: same duration between all retries.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            backoff_multiplier: 1.0,
            max_retries,
            initial_backoff: backoff,
        }
    }
}
83
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Connection handle; cloned per context in `start()` / `child()`.
    db: DatabaseConnection,
    // The workflow (or child-workflow) task this context is scoped to.
    task_id: Uuid,
    // Monotonic per-context counter; assigns an ordering sequence to each step.
    sequence: AtomicI32,
}
94
impl Ctx {
    // ── Workflow lifecycle (user-facing) ──────────────────────────

    /// Start or resume a root workflow by name.
    ///
    /// Finds or creates the root task row inside a transaction, marks it
    /// RUNNING, and returns a `Ctx` scoped to it with the step sequence at 0.
    ///
    /// ```ignore
    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
    /// ```
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let txn = db.begin().await?;
        // Workflows use plain find-or-create without FOR UPDATE SKIP LOCKED.
        // Multiple workers resuming the same workflow is fine — they share the
        // same task row and each step will be individually locked.
        let (task_id, _saved) =
            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
        // NOTE(review): retry_db_write sleeps and re-issues the statement on the
        // SAME open transaction; in Postgres a failed statement aborts the
        // transaction, so later retries here presumably cannot succeed — confirm.
        retry_db_write(|| set_status(&txn, task_id, "RUNNING")).await?;
        txn.commit().await?;
        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Run a step. If already completed, returns saved output. Otherwise executes
    /// the closure, saves the result, and returns it.
    ///
    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
    ///
    /// ```ignore
    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
    /// ```
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if parent task's deadline has passed before executing
        check_deadline(&self.db, self.task_id).await?;

        // Begin transaction — the FOR UPDATE lock is held throughout step execution.
        // NOTE(review): this means a long-running step holds both a row lock and
        // a pooled connection for its entire duration.
        let txn = self.db.begin().await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Returns (step_id, Option<saved_output>) where saved_output is Some iff COMPLETED.
        // Use FOR UPDATE SKIP LOCKED so only one worker can execute a given step.
        // max_retries=0: step() does not retry; use step_with_retry() for retries.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Execute the step closure within the transaction (row lock is held)
        retry_db_write(|| set_status(&txn, step_id, "RUNNING")).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Record the failure but still commit the transaction so the
                // FAILED status (and the task row itself) is persisted.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }

    /// Run a DB-only step inside a single Postgres transaction.
    ///
    /// Both the user's DB work and the checkpoint save happen in the same transaction,
    /// ensuring atomicity. If the closure returns an error, both the user writes and
    /// the checkpoint are rolled back.
    ///
    /// ```ignore
    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
    ///     do_db_work(tx).await
    /// })).await?;
    /// ```
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Find or create the step task record OUTSIDE the transaction.
        // This is idempotent (UNIQUE constraint) and must exist before we begin.
        // NOTE(review): unlike step(), no FOR UPDATE lock is taken here, so two
        // workers could race into the same transaction step — verify intended.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // If already completed, replay from saved output.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
        let tx = self.db.begin().await?;

        set_status(&tx, step_id, "RUNNING").await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Rollback happens automatically when tx is dropped.
                // Record the failure on the main connection (outside the rolled-back tx).
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }

    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
    ///
    /// The child gets its own sequence counter starting at 0; the child task row
    /// is created under this context's task via `(parent_id, name)`.
    ///
    /// ```ignore
    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
    /// // use child_ctx.step(...) for steps inside the child
    /// child_ctx.complete(json!({"done": true})).await?;
    /// ```
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        let txn = self.db.begin().await?;
        // Child workflows also use plain find-or-create without locking.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        // If child already completed, return a Ctx that will replay
        // (the caller should check is_completed() or just run steps which will replay)
        retry_db_write(|| set_status(&txn, child_id, "RUNNING")).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Check if this workflow/child was already completed (for skipping in parent).
    pub async fn is_completed(&self) -> Result<bool, DurableError> {
        let status = get_status(&self.db, self.task_id).await?;
        Ok(status.as_deref() == Some("COMPLETED"))
    }

    /// Get the saved output if this task is completed.
    ///
    /// Returns `None` when the task has not completed (or has no saved output).
    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
        match get_output(&self.db, self.task_id).await? {
            Some(val) => Ok(Some(serde_json::from_value(val)?)),
            None => Ok(None),
        }
    }

    /// Mark this workflow as completed with an output value.
    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
        let json = serde_json::to_value(output)?;
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| complete_task(db, task_id, json.clone())).await
    }

    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts)
        // NOTE(review): max_retries comes from the DB row, so an existing task's
        // stored limit takes precedence over `policy.max_retries` — confirm intended.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop
        loop {
            set_status(&self.db, step_id, "RUNNING").await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            // Clamp the factor at 1.0 so a sub-unit multiplier
                            // never shrinks the delay below initial_backoff.
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }

    /// Mark this workflow as failed.
    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| fail_task(db, task_id, error)).await
    }

    /// Set the timeout for this task in milliseconds.
    ///
    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
    /// for recovery by `Executor::recover()`.
    ///
    /// If the task is currently RUNNING and has a `started_at`, this also
    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // Interpolating here is injection-safe: timeout_ms is an i64 and
        // task_id a Uuid, so neither can contain SQL metacharacters.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
                 deadline_epoch_ms = CASE \
                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                     ELSE deadline_epoch_ms \
                 END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }

    /// Start or resume a root workflow with a timeout in milliseconds.
    ///
    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }

    // ── Accessors ────────────────────────────────────────────────

    /// The underlying database connection this context operates on.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// The task ID this context is scoped to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Reserve and return the next step sequence number (increments the counter).
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
}
481
482// ── Internal SQL helpers ─────────────────────────────────────────────
483
484/// Find an existing task by (parent_id, name) or create a new one.
485///
486/// Returns `(task_id, Option<saved_output>)`:
487/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
488/// - `saved_output` is `None` when the task is new or in-progress.
489///
490/// When `lock` is `true` and an existing non-completed task is found, this
491/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
492/// another worker holds the lock, `DurableError::StepLocked` is returned so
493/// the caller can skip execution rather than double-firing side effects.
494///
495/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
496/// and child-workflow creation where concurrent start is safe).
497///
498/// When `lock` is `true`, the caller MUST call this within a transaction so
499/// the row lock is held throughout step execution.
500///
501/// `max_retries`: if Some, overrides the schema default when creating a new task.
502#[allow(clippy::too_many_arguments)]
503async fn find_or_create_task(
504    db: &impl ConnectionTrait,
505    parent_id: Option<Uuid>,
506    sequence: Option<i32>,
507    name: &str,
508    kind: &str,
509    input: Option<serde_json::Value>,
510    lock: bool,
511    max_retries: Option<u32>,
512) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
513    let parent_eq = match parent_id {
514        Some(p) => format!("= '{p}'"),
515        None => "IS NULL".to_string(),
516    };
517    let parent_sql = match parent_id {
518        Some(p) => format!("'{p}'"),
519        None => "NULL".to_string(),
520    };
521
522    if lock {
523        // Locking path (for steps): we need exactly-once execution.
524        //
525        // Strategy:
526        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
527        //    If another transaction is concurrently inserting the same row,
528        //    Postgres will block here until that transaction commits or rolls
529        //    back, ensuring we never see a phantom "not found" for a row being
530        //    inserted.
531        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
532        //    should get it back; if the row existed and is locked by another
533        //    worker we get nothing.
534        // 3. If SKIP LOCKED returns empty, return StepLocked.
535
536        let new_id = Uuid::new_v4();
537        let seq_sql = match sequence {
538            Some(s) => s.to_string(),
539            None => "NULL".to_string(),
540        };
541        let input_sql = match &input {
542            Some(v) => format!("'{}'", serde_json::to_string(v)?),
543            None => "NULL".to_string(),
544        };
545
546        let max_retries_sql = match max_retries {
547            Some(r) => r.to_string(),
548            None => "3".to_string(), // schema default
549        };
550
551        // Step 1: insert-or-skip
552        let insert_sql = format!(
553            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
554             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
555             ON CONFLICT (parent_id, name) DO NOTHING"
556        );
557        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
558            .await?;
559
560        // Step 2: lock the row (ours or pre-existing)
561        let lock_sql = format!(
562            "SELECT id, status, output FROM durable.task \
563             WHERE parent_id {parent_eq} AND name = '{name}' \
564             FOR UPDATE SKIP LOCKED"
565        );
566        let row = db
567            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
568            .await?;
569
570        if let Some(row) = row {
571            let id: Uuid = row
572                .try_get_by_index(0)
573                .map_err(|e| DurableError::custom(e.to_string()))?;
574            let status: String = row
575                .try_get_by_index(1)
576                .map_err(|e| DurableError::custom(e.to_string()))?;
577            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
578
579            if status == "COMPLETED" {
580                // Replay path — return saved output
581                return Ok((id, output));
582            }
583            // Task exists and we hold the lock — proceed to execute
584            return Ok((id, None));
585        }
586
587        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
588        Err(DurableError::StepLocked(name.to_string()))
589    } else {
590        // Plain find without locking — safe for workflow-level operations.
591        // Multiple workers resuming the same workflow is fine; individual
592        // steps will be locked when executed.
593        let find_sql = format!(
594            "SELECT id, status, output FROM durable.task \
595             WHERE parent_id {parent_eq} AND name = '{name}'"
596        );
597        let row = db
598            .query_one(Statement::from_string(DbBackend::Postgres, find_sql))
599            .await?;
600
601        if let Some(row) = row {
602            let id: Uuid = row
603                .try_get_by_index(0)
604                .map_err(|e| DurableError::custom(e.to_string()))?;
605            let status: String = row
606                .try_get_by_index(1)
607                .map_err(|e| DurableError::custom(e.to_string()))?;
608            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
609
610            if status == "COMPLETED" {
611                return Ok((id, output));
612            }
613            return Ok((id, None));
614        }
615
616        // Task does not exist — create it
617        let id = Uuid::new_v4();
618        let seq_sql = match sequence {
619            Some(s) => s.to_string(),
620            None => "NULL".to_string(),
621        };
622        let input_sql = match &input {
623            Some(v) => format!("'{}'", serde_json::to_string(v)?),
624            None => "NULL".to_string(),
625        };
626        let max_retries_sql = match max_retries {
627            Some(r) => r.to_string(),
628            None => "3".to_string(), // schema default
629        };
630
631        let sql = format!(
632            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
633             VALUES ('{id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql})"
634        );
635        db.execute(Statement::from_string(DbBackend::Postgres, sql))
636            .await?;
637
638        Ok((id, None))
639    }
640}
641
642async fn get_output(
643    db: &impl ConnectionTrait,
644    task_id: Uuid,
645) -> Result<Option<serde_json::Value>, DurableError> {
646    let sql =
647        format!("SELECT output FROM durable.task WHERE id = '{task_id}' AND status = 'COMPLETED'");
648    let row = db
649        .query_one(Statement::from_string(DbBackend::Postgres, sql))
650        .await?;
651
652    match row {
653        Some(r) => Ok(r.try_get_by_index(0).ok()),
654        None => Ok(None),
655    }
656}
657
658async fn get_status(
659    db: &impl ConnectionTrait,
660    task_id: Uuid,
661) -> Result<Option<String>, DurableError> {
662    let sql = format!("SELECT status FROM durable.task WHERE id = '{task_id}'");
663    let row = db
664        .query_one(Statement::from_string(DbBackend::Postgres, sql))
665        .await?;
666
667    match row {
668        Some(r) => Ok(r.try_get_by_index(0).ok()),
669        None => Ok(None),
670    }
671}
672
673/// Returns (retry_count, max_retries) for a task.
674async fn get_retry_info(
675    db: &DatabaseConnection,
676    task_id: Uuid,
677) -> Result<(u32, u32), DurableError> {
678    let sql = format!("SELECT retry_count, max_retries FROM durable.task WHERE id = '{task_id}'");
679    let row = db
680        .query_one(Statement::from_string(DbBackend::Postgres, sql))
681        .await?;
682
683    match row {
684        Some(r) => {
685            let retry_count: i32 = r
686                .try_get_by_index(0)
687                .map_err(|e| DurableError::custom(e.to_string()))?;
688            let max_retries: i32 = r
689                .try_get_by_index(1)
690                .map_err(|e| DurableError::custom(e.to_string()))?;
691            Ok((retry_count as u32, max_retries as u32))
692        }
693        None => Err(DurableError::custom(format!(
694            "task {task_id} not found when reading retry info"
695        ))),
696    }
697}
698
699/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
700async fn increment_retry_count(
701    db: &DatabaseConnection,
702    task_id: Uuid,
703) -> Result<u32, DurableError> {
704    let sql = format!(
705        "UPDATE durable.task \
706         SET retry_count = retry_count + 1, status = 'PENDING', error = NULL, completed_at = NULL \
707         WHERE id = '{task_id}' \
708         RETURNING retry_count"
709    );
710    let row = db
711        .query_one(Statement::from_string(DbBackend::Postgres, sql))
712        .await?;
713
714    match row {
715        Some(r) => {
716            let count: i32 = r
717                .try_get_by_index(0)
718                .map_err(|e| DurableError::custom(e.to_string()))?;
719            Ok(count as u32)
720        }
721        None => Err(DurableError::custom(format!(
722            "task {task_id} not found when incrementing retry count"
723        ))),
724    }
725}
726
/// Set a task's status; backfills `started_at` and, when moving to RUNNING on a
/// task with a timeout, refreshes `deadline_epoch_ms` from the current time.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: &str,
) -> Result<(), DurableError> {
    // `status` is only ever an internal constant ('RUNNING' etc.), so the direct
    // interpolation below is not reachable by user-controlled input.
    // NOTE(review): `started_at` is backfilled for ANY status value, not only
    // RUNNING — presumably intentional; verify against recovery logic.
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
747
748/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
749async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
750    let sql = format!("SELECT deadline_epoch_ms FROM durable.task WHERE id = '{task_id}'");
751    let row = db
752        .query_one(Statement::from_string(DbBackend::Postgres, sql))
753        .await?;
754
755    if let Some(row) = row {
756        let deadline: Option<i64> = row.try_get_by_index(0).ok().flatten();
757        if let Some(deadline_ms) = deadline {
758            let now_ms = std::time::SystemTime::now()
759                .duration_since(std::time::UNIX_EPOCH)
760                .map(|d| d.as_millis() as i64)
761                .unwrap_or(0);
762            if now_ms > deadline_ms {
763                return Err(DurableError::Timeout("task deadline exceeded".to_string()));
764            }
765        }
766    }
767
768    Ok(())
769}
770
771async fn complete_task(
772    db: &impl ConnectionTrait,
773    task_id: Uuid,
774    output: serde_json::Value,
775) -> Result<(), DurableError> {
776    let sql = format!(
777        "UPDATE durable.task SET status = 'COMPLETED', output = '{}', completed_at = now() \
778         WHERE id = '{task_id}'",
779        output
780    );
781    db.execute(Statement::from_string(DbBackend::Postgres, sql))
782        .await?;
783    Ok(())
784}
785
/// Mark a task FAILED, storing the error message and completion timestamp.
async fn fail_task(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    error: &str,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task SET status = 'FAILED', error = '{}', completed_at = now() \
         WHERE id = '{task_id}'",
        // The error text can come from arbitrary failures, so escape embedded
        // single quotes ('' is the SQL-standard escape inside a quoted literal).
        error.replace('\'', "''")
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
800
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// test_retry_db_write_succeeds_first_try: a closure that always succeeds
    /// should be called exactly once and return Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// test_retry_db_write_succeeds_after_transient_failure: a closure that
    /// fails twice then succeeds should return Ok and be called 3 times.
    /// (Real sleeps: incurs the 100ms + 200ms backoff delays.)
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                // fetch_add returns the PREVIOUS value, so attempts 0 and 1 fail.
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// test_retry_db_write_exhausts_retries: a closure that always fails should
    /// be called 1 + MAX_CHECKPOINT_RETRIES times total then return an error.
    /// (Real sleeps: incurs the full 100+200+400ms backoff schedule.)
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        // 1 initial attempt + MAX_CHECKPOINT_RETRIES retry attempts
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    /// test_retry_db_write_returns_original_error: when all retries are
    /// exhausted the FIRST error is returned, not the last retry error.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                // Each attempt produces a distinct error message ("error-<n>").
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        // The message of the first error contains "error-0"
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}
897}