// durable/ctx.rs
1use std::pin::Pin;
2
3use durable_db::entity::task::{
4    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
5};
6use sea_orm::{
7    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
8    DbBackend, EntityTrait, QueryFilter, Set, Statement, TransactionTrait,
9};
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use std::sync::atomic::{AtomicI32, Ordering};
13use std::time::Duration;
14use uuid::Uuid;
15
16use crate::error::DurableError;
17
// ── Retry constants ───────────────────────────────────────────────────────────

// Number of *additional* attempts after the first failed checkpoint write.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
// Base backoff in milliseconds; doubled each retry (100ms, 200ms, 400ms).
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
22
23/// Retry a fallible DB write with exponential backoff.
24///
25/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
26/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
27/// error if all retries are exhausted.
28async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
29where
30    F: FnMut() -> Fut,
31    Fut: std::future::Future<Output = Result<(), DurableError>>,
32{
33    match f().await {
34        Ok(()) => Ok(()),
35        Err(first_err) => {
36            for i in 0..MAX_CHECKPOINT_RETRIES {
37                tokio::time::sleep(Duration::from_millis(
38                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
39                ))
40                .await;
41                if f().await.is_ok() {
42                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
43                    return Ok(());
44                }
45            }
46            Err(first_err)
47        }
48    }
49}
50
/// Policy for retrying a step on failure.
///
/// Plain-old-data: all fields are `Copy`, so policies can be cheaply passed
/// by value, logged, and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of retries after the first failed attempt (0 = no retries).
    pub max_retries: u32,
    /// Backoff before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor applied to the backoff after each retry (1.0 = fixed backoff).
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries — fails immediately on error.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: std::time::Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff: backoff doubles each retry.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Fixed backoff: same duration between all retries.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
86
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Connection handle; cloned into each child ctx (see `start` / `child`).
    db: DatabaseConnection,
    // Row id of the task (workflow or child workflow) this ctx is scoped to.
    task_id: Uuid,
    // Monotonic per-ctx counter used to order the steps created under this task.
    sequence: AtomicI32,
}
97
impl Ctx {
    // ── Workflow lifecycle (user-facing) ──────────────────────────

    /// Start or resume a root workflow by name.
    ///
    /// Idempotent: calling `start` again with the same `name` resumes the
    /// existing root task rather than creating a new one.
    ///
    /// ```ignore
    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
    /// ```
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let txn = db.begin().await?;
        // Workflows use plain find-or-create without FOR UPDATE SKIP LOCKED.
        // Multiple workers resuming the same workflow is fine — they share the
        // same task row and each step will be individually locked.
        let (task_id, _saved) =
            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
        // NOTE(review): retrying inside an already-open Postgres transaction only
        // helps for transient client-side errors — after a server-side statement
        // error the transaction is aborted and every retry will fail. Confirm intent.
        retry_db_write(|| set_status(&txn, task_id, "RUNNING")).await?;
        txn.commit().await?;
        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Run a step. If already completed, returns saved output. Otherwise executes
    /// the closure, saves the result, and returns it.
    ///
    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
    ///
    /// ```ignore
    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
    /// ```
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if parent task's deadline has passed before executing
        check_deadline(&self.db, self.task_id).await?;

        // Begin transaction — the FOR UPDATE lock is held throughout step execution
        let txn = self.db.begin().await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Returns (step_id, Option<saved_output>) where saved_output is Some iff COMPLETED.
        // Use FOR UPDATE SKIP LOCKED so only one worker can execute a given step.
        // max_retries=0: step() does not retry; use step_with_retry() for retries.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Execute the step closure within the transaction (row lock is held)
        // NOTE(review): the row lock is held for the full duration of the user
        // closure — long-running steps hold a DB connection + lock the whole time.
        retry_db_write(|| set_status(&txn, step_id, "RUNNING")).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Persist the failure, then surface the original error.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }

    /// Run a DB-only step inside a single Postgres transaction.
    ///
    /// Both the user's DB work and the checkpoint save happen in the same transaction,
    /// ensuring atomicity. If the closure returns an error, both the user writes and
    /// the checkpoint are rolled back.
    ///
    /// NOTE(review): unlike `step()`, this path takes no row lock, so two workers
    /// could both run the closure concurrently — presumably acceptable because the
    /// work is transactional and the checkpoint is idempotent; confirm.
    ///
    /// ```ignore
    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
    ///     do_db_work(tx).await
    /// })).await?;
    /// ```
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Find or create the step task record OUTSIDE the transaction.
        // This is idempotent (UNIQUE constraint) and must exist before we begin.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // If already completed, replay from saved output.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
        let tx = self.db.begin().await?;

        set_status(&tx, step_id, "RUNNING").await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Rollback happens automatically when tx is dropped.
                // Record the failure on the main connection (outside the rolled-back tx).
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }

    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
    ///
    /// ```ignore
    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
    /// // use child_ctx.step(...) for steps inside the child
    /// child_ctx.complete(json!({"done": true})).await?;
    /// ```
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        let txn = self.db.begin().await?;
        // Child workflows also use plain find-or-create without locking.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        // If child already completed, return a Ctx that will replay
        // (the caller should check is_completed() or just run steps which will replay)
        // NOTE(review): set_status below unconditionally flips the child's status to
        // RUNNING — including a COMPLETED child — which makes the is_completed()
        // check suggested above return false after a resume. Confirm whether
        // set_status should skip COMPLETED rows.
        retry_db_write(|| set_status(&txn, child_id, "RUNNING")).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Check if this workflow/child was already completed (for skipping in parent).
    pub async fn is_completed(&self) -> Result<bool, DurableError> {
        let status = get_status(&self.db, self.task_id).await?;
        Ok(status.as_deref() == Some("COMPLETED"))
    }

    /// Get the saved output if this task is completed.
    ///
    /// Returns `Ok(None)` when the task is missing, not yet COMPLETED, or has
    /// no stored output.
    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
        match get_output(&self.db, self.task_id).await? {
            Some(val) => Ok(Some(serde_json::from_value(val)?)),
            None => Ok(None),
        }
    }

    /// Mark this workflow as completed with an output value.
    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
        let json = serde_json::to_value(output)?;
        // Borrow into locals so the retry closure can be re-invoked (FnMut).
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| complete_task(db, task_id, json.clone())).await
    }

    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts)
        // NOTE(review): max_retries comes from the stored row, not from `policy` —
        // an existing row created with a different policy wins. Confirm intent.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop
        loop {
            set_status(&self.db, step_id, "RUNNING").await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration
                        // backoff = initial * multiplier^(retry_count - 1), never
                        // shrinking below initial (factor clamped to >= 1.0).
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }

    /// Mark this workflow as failed.
    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
        // Borrow into locals so the retry closure can be re-invoked (FnMut).
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| fail_task(db, task_id, error)).await
    }

    /// Set the timeout for this task in milliseconds.
    ///
    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
    /// for recovery by `Executor::recover()`.
    ///
    /// If the task is currently RUNNING and has a `started_at`, this also
    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // String interpolation is safe here: both values are typed (i64, Uuid),
        // not user-controlled strings.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
                 deadline_epoch_ms = CASE \
                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                     ELSE deadline_epoch_ms \
                 END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }

    /// Start or resume a root workflow with a timeout in milliseconds.
    ///
    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }

    // ── Accessors ────────────────────────────────────────────────

    /// Database handle this context operates on.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// Id of the task (workflow) this context is scoped to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Reserve and return the next step sequence number (monotonic per ctx).
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
}
484
485// ── Internal SQL helpers ─────────────────────────────────────────────
486
487/// Find an existing task by (parent_id, name) or create a new one.
488///
489/// Returns `(task_id, Option<saved_output>)`:
490/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
491/// - `saved_output` is `None` when the task is new or in-progress.
492///
493/// When `lock` is `true` and an existing non-completed task is found, this
494/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
495/// another worker holds the lock, `DurableError::StepLocked` is returned so
496/// the caller can skip execution rather than double-firing side effects.
497///
498/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
499/// and child-workflow creation where concurrent start is safe).
500///
501/// When `lock` is `true`, the caller MUST call this within a transaction so
502/// the row lock is held throughout step execution.
503///
504/// `max_retries`: if Some, overrides the schema default when creating a new task.
505#[allow(clippy::too_many_arguments)]
506async fn find_or_create_task(
507    db: &impl ConnectionTrait,
508    parent_id: Option<Uuid>,
509    sequence: Option<i32>,
510    name: &str,
511    kind: &str,
512    input: Option<serde_json::Value>,
513    lock: bool,
514    max_retries: Option<u32>,
515) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
516    let parent_eq = match parent_id {
517        Some(p) => format!("= '{p}'"),
518        None => "IS NULL".to_string(),
519    };
520    let parent_sql = match parent_id {
521        Some(p) => format!("'{p}'"),
522        None => "NULL".to_string(),
523    };
524
525    if lock {
526        // Locking path (for steps): we need exactly-once execution.
527        //
528        // Strategy:
529        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
530        //    If another transaction is concurrently inserting the same row,
531        //    Postgres will block here until that transaction commits or rolls
532        //    back, ensuring we never see a phantom "not found" for a row being
533        //    inserted.
534        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
535        //    should get it back; if the row existed and is locked by another
536        //    worker we get nothing.
537        // 3. If SKIP LOCKED returns empty, return StepLocked.
538
539        let new_id = Uuid::new_v4();
540        let seq_sql = match sequence {
541            Some(s) => s.to_string(),
542            None => "NULL".to_string(),
543        };
544        let input_sql = match &input {
545            Some(v) => format!("'{}'", serde_json::to_string(v)?),
546            None => "NULL".to_string(),
547        };
548
549        let max_retries_sql = match max_retries {
550            Some(r) => r.to_string(),
551            None => "3".to_string(), // schema default
552        };
553
554        // Step 1: insert-or-skip
555        let insert_sql = format!(
556            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
557             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
558             ON CONFLICT (parent_id, name) DO NOTHING"
559        );
560        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
561            .await?;
562
563        // Step 2: lock the row (ours or pre-existing)
564        let lock_sql = format!(
565            "SELECT id, status, output FROM durable.task \
566             WHERE parent_id {parent_eq} AND name = '{name}' \
567             FOR UPDATE SKIP LOCKED"
568        );
569        let row = db
570            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
571            .await?;
572
573        if let Some(row) = row {
574            let id: Uuid = row
575                .try_get_by_index(0)
576                .map_err(|e| DurableError::custom(e.to_string()))?;
577            let status: String = row
578                .try_get_by_index(1)
579                .map_err(|e| DurableError::custom(e.to_string()))?;
580            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
581
582            if status == "COMPLETED" {
583                // Replay path — return saved output
584                return Ok((id, output));
585            }
586            // Task exists and we hold the lock — proceed to execute
587            return Ok((id, None));
588        }
589
590        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
591        Err(DurableError::StepLocked(name.to_string()))
592    } else {
593        // Plain find without locking — safe for workflow-level operations.
594        // Multiple workers resuming the same workflow is fine; individual
595        // steps will be locked when executed.
596        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
597        query = match parent_id {
598            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
599            None => query.filter(TaskColumn::ParentId.is_null()),
600        };
601        let existing = query.one(db).await?;
602
603        if let Some(model) = existing {
604            if model.status == "COMPLETED" {
605                return Ok((model.id, model.output));
606            }
607            return Ok((model.id, None));
608        }
609
610        // Task does not exist — create it
611        let id = Uuid::new_v4();
612        let new_task = TaskActiveModel {
613            id: Set(id),
614            parent_id: Set(parent_id),
615            sequence: Set(sequence),
616            name: Set(name.to_string()),
617            kind: Set(kind.to_string()),
618            status: Set("PENDING".to_string()),
619            input: Set(input),
620            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
621            ..Default::default()
622        };
623        new_task.insert(db).await?;
624
625        Ok((id, None))
626    }
627}
628
629async fn get_output(
630    db: &impl ConnectionTrait,
631    task_id: Uuid,
632) -> Result<Option<serde_json::Value>, DurableError> {
633    let model = Task::find_by_id(task_id)
634        .filter(TaskColumn::Status.eq("COMPLETED"))
635        .one(db)
636        .await?;
637
638    Ok(model.and_then(|m| m.output))
639}
640
641async fn get_status(
642    db: &impl ConnectionTrait,
643    task_id: Uuid,
644) -> Result<Option<String>, DurableError> {
645    let model = Task::find_by_id(task_id).one(db).await?;
646
647    Ok(model.map(|m| m.status))
648}
649
650/// Returns (retry_count, max_retries) for a task.
651async fn get_retry_info(
652    db: &DatabaseConnection,
653    task_id: Uuid,
654) -> Result<(u32, u32), DurableError> {
655    let model = Task::find_by_id(task_id).one(db).await?;
656
657    match model {
658        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
659        None => Err(DurableError::custom(format!(
660            "task {task_id} not found when reading retry info"
661        ))),
662    }
663}
664
665/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
666async fn increment_retry_count(
667    db: &DatabaseConnection,
668    task_id: Uuid,
669) -> Result<u32, DurableError> {
670    let model = Task::find_by_id(task_id).one(db).await?;
671
672    match model {
673        Some(m) => {
674            let new_count = m.retry_count + 1;
675            let mut active: TaskActiveModel = m.into();
676            active.retry_count = Set(new_count);
677            active.status = Set("PENDING".to_string());
678            active.error = Set(None);
679            active.completed_at = Set(None);
680            active.update(db).await?;
681            Ok(new_count as u32)
682        }
683        None => Err(DurableError::custom(format!(
684            "task {task_id} not found when incrementing retry count"
685        ))),
686    }
687}
688
// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
// for conditional deadline_epoch_ms computation or COALESCE on started_at.
//
// Transitions a task to `status`, stamping `started_at` on the first
// transition only (COALESCE) and refreshing deadline_epoch_ms when entering
// RUNNING with a configured timeout.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: &str,
) -> Result<(), DurableError> {
    // Interpolation is safe here only because `status` is an internal constant
    // ("RUNNING"/"PENDING"/…) and `task_id` is a typed Uuid — neither is user input.
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
711
712/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
713async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
714    let model = Task::find_by_id(task_id).one(db).await?;
715
716    if let Some(m) = model
717        && let Some(deadline_ms) = m.deadline_epoch_ms
718    {
719        let now_ms = std::time::SystemTime::now()
720            .duration_since(std::time::UNIX_EPOCH)
721            .map(|d| d.as_millis() as i64)
722            .unwrap_or(0);
723        if now_ms > deadline_ms {
724            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
725        }
726    }
727
728    Ok(())
729}
730
731async fn complete_task(
732    db: &impl ConnectionTrait,
733    task_id: Uuid,
734    output: serde_json::Value,
735) -> Result<(), DurableError> {
736    let model = Task::find_by_id(task_id).one(db).await?;
737
738    if let Some(m) = model {
739        let mut active: TaskActiveModel = m.into();
740        active.status = Set("COMPLETED".to_string());
741        active.output = Set(Some(output));
742        active.completed_at = Set(Some(chrono::Utc::now().into()));
743        active.update(db).await?;
744    }
745    Ok(())
746}
747
748async fn fail_task(
749    db: &impl ConnectionTrait,
750    task_id: Uuid,
751    error: &str,
752) -> Result<(), DurableError> {
753    let model = Task::find_by_id(task_id).one(db).await?;
754
755    if let Some(m) = model {
756        let mut active: TaskActiveModel = m.into();
757        active.status = Set("FAILED".to_string());
758        active.error = Set(Some(error.to_string()));
759        active.completed_at = Set(Some(chrono::Utc::now().into()));
760        active.update(db).await?;
761    }
762    Ok(())
763}
764
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// An always-successful closure is invoked exactly once and yields Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let result = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    /// A closure that fails twice then succeeds yields Ok after 3 calls total.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let result = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                match counter.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    ))),
                    _ => Ok(()),
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A permanently-failing closure is invoked 1 + MAX_CHECKPOINT_RETRIES
    /// times in total and the call reports an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let result = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        // 1 initial attempt + MAX_CHECKPOINT_RETRIES retry attempts
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    /// When every attempt fails, the FIRST error is surfaced, not the last.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let result = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                let n = counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        // The message of the first error contains "error-0"
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}