// durable/ctx.rs — workflow/step execution context (`Ctx`).
1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9    DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10    Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20// ── Retry constants ───────────────────────────────────────────────────────────
21
22const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25/// Retry a fallible DB write with exponential backoff.
26///
27/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
28/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
29/// error if all retries are exhausted.
30async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32    F: FnMut() -> Fut,
33    Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35    match f().await {
36        Ok(()) => Ok(()),
37        Err(first_err) => {
38            for i in 0..MAX_CHECKPOINT_RETRIES {
39                tokio::time::sleep(Duration::from_millis(
40                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41                ))
42                .await;
43                if f().await.is_ok() {
44                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45                    return Ok(());
46                }
47            }
48            Err(first_err)
49        }
50    }
51}
52
53/// Policy for retrying a step on failure.
54pub struct RetryPolicy {
55    pub max_retries: u32,
56    pub initial_backoff: std::time::Duration,
57    pub backoff_multiplier: f64,
58}
59
60impl RetryPolicy {
61    /// No retries — fails immediately on error.
62    pub fn none() -> Self {
63        Self {
64            max_retries: 0,
65            initial_backoff: std::time::Duration::from_secs(0),
66            backoff_multiplier: 1.0,
67        }
68    }
69
70    /// Exponential backoff: backoff doubles each retry.
71    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
72        Self {
73            max_retries,
74            initial_backoff,
75            backoff_multiplier: 2.0,
76        }
77    }
78
79    /// Fixed backoff: same duration between all retries.
80    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
81        Self {
82            max_retries,
83            initial_backoff: backoff,
84            backoff_multiplier: 1.0,
85        }
86    }
87}
88
89/// Sort order for task listing.
90pub enum TaskSort {
91    CreatedAt(Order),
92    StartedAt(Order),
93    CompletedAt(Order),
94    Name(Order),
95    Status(Order),
96}
97
/// Builder for filtering, sorting, and paginating task queries.
///
/// All filters default to "off" (see [`Default`]); chain the builder
/// methods to narrow the result set.
///
/// ```ignore
/// let query = TaskQuery::default()
///     .status(TaskStatus::Running)
///     .kind("WORKFLOW")
///     .root_only(true)
///     .sort(TaskSort::CreatedAt(Order::Desc))
///     .limit(20);
/// let tasks = Ctx::list(&db, query).await?;
/// ```
pub struct TaskQuery {
    /// Only tasks with this status.
    pub status: Option<TaskStatus>,
    /// Only tasks of this kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: Option<String>,
    /// Only direct children of this task.
    pub parent_id: Option<Uuid>,
    /// When `true`, only root tasks (no parent).
    pub root_only: bool,
    /// Only tasks with this exact name.
    pub name: Option<String>,
    /// Only tasks with this queue name.
    pub queue_name: Option<String>,
    /// Sort order; defaults to newest-created first.
    pub sort: TaskSort,
    /// Maximum number of rows to return.
    pub limit: Option<u64>,
    /// Number of rows to skip (for pagination).
    pub offset: Option<u64>,
}
120
121impl Default for TaskQuery {
122    fn default() -> Self {
123        Self {
124            status: None,
125            kind: None,
126            parent_id: None,
127            root_only: false,
128            name: None,
129            queue_name: None,
130            sort: TaskSort::CreatedAt(Order::Desc),
131            limit: None,
132            offset: None,
133        }
134    }
135}
136
137impl TaskQuery {
138    /// Filter by status.
139    pub fn status(mut self, status: TaskStatus) -> Self {
140        self.status = Some(status);
141        self
142    }
143
144    /// Filter by kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
145    pub fn kind(mut self, kind: &str) -> Self {
146        self.kind = Some(kind.to_string());
147        self
148    }
149
150    /// Filter by parent task ID (direct children only).
151    pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152        self.parent_id = Some(parent_id);
153        self
154    }
155
156    /// Only return root tasks (no parent).
157    pub fn root_only(mut self, root_only: bool) -> Self {
158        self.root_only = root_only;
159        self
160    }
161
162    /// Filter by task name.
163    pub fn name(mut self, name: &str) -> Self {
164        self.name = Some(name.to_string());
165        self
166    }
167
168    /// Filter by queue name.
169    pub fn queue_name(mut self, queue: &str) -> Self {
170        self.queue_name = Some(queue.to_string());
171        self
172    }
173
174    /// Set the sort order.
175    pub fn sort(mut self, sort: TaskSort) -> Self {
176        self.sort = sort;
177        self
178    }
179
180    /// Limit the number of results.
181    pub fn limit(mut self, limit: u64) -> Self {
182        self.limit = Some(limit);
183        self
184    }
185
186    /// Skip the first N results.
187    pub fn offset(mut self, offset: u64) -> Self {
188        self.offset = Some(offset);
189        self
190    }
191}
192
/// Summary of a task returned by `Ctx::list()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    /// Task ID.
    pub id: Uuid,
    /// Parent task ID; `None` for root workflows.
    pub parent_id: Option<Uuid>,
    /// Task name.
    pub name: String,
    /// Handler function name, when recorded at start (see `start_with_handler`).
    pub handler: Option<String>,
    /// Current lifecycle status.
    pub status: TaskStatus,
    /// Task kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: String,
    /// Stored input payload, if any.
    pub input: Option<serde_json::Value>,
    /// Checkpointed output, if the task has produced one.
    pub output: Option<serde_json::Value>,
    /// Error message, if the task failed.
    pub error: Option<String>,
    /// Queue name, if any.
    pub queue_name: Option<String>,
    /// Row creation time.
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    /// When execution began, if it has started.
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    /// When execution finished, if it has.
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
210
impl From<durable_db::entity::task::Model> for TaskSummary {
    /// Straight field-for-field projection of the DB row into the summary.
    fn from(m: durable_db::entity::task::Model) -> Self {
        Self {
            id: m.id,
            parent_id: m.parent_id,
            name: m.name,
            handler: m.handler,
            status: m.status,
            kind: m.kind,
            input: m.input,
            output: m.output,
            error: m.error,
            queue_name: m.queue_name,
            created_at: m.created_at,
            started_at: m.started_at,
            completed_at: m.completed_at,
        }
    }
}
230
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Database handle; cloned into child contexts.
    db: DatabaseConnection,
    // The task (workflow or child workflow) this context is scoped to.
    task_id: Uuid,
    // Monotonic per-context counter; each step/child consumes one slot so
    // replay addresses steps by (task_id, sequence) deterministically.
    sequence: AtomicI32,
    // Executor identity recorded for crash recovery; `None` when `init()`
    // was never called.
    executor_id: Option<String>,
}
242
243impl Ctx {
244    // ── Workflow lifecycle (user-facing) ──────────────────────────
245
    /// Start a new root workflow (or attach to an existing one with the same name).
    ///
    /// If a RUNNING root task with the same `name` already exists (e.g. recovered
    /// after a crash), returns a `Ctx` attached to that task (idempotent start).
    /// Otherwise creates a fresh task.
    ///
    /// If [`init`](crate::init) was called, the task is automatically tagged
    /// with the executor ID for crash recovery. Otherwise `executor_id` is NULL.
    ///
    /// To reactivate a paused workflow, use [`Ctx::resume`] instead.
    ///
    /// ```ignore
    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
    /// ```
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        // Thin wrapper: identical to start_with_handler with no handler recorded.
        Self::start_with_handler(db, name, input, None).await
    }
267
268    /// Like [`start`](Self::start) but records the handler function name for
269    /// crash recovery. The `handler` is stored in the `durable.task` row so
270    /// that [`dispatch_recovered`](crate::dispatch_recovered) can look up the
271    /// registered `#[durable::workflow]` function even when `name` differs.
272    ///
273    /// Typically called from macro-generated `start_<fn>` functions rather
274    /// than directly.
275    pub async fn start_with_handler(
276        db: &DatabaseConnection,
277        name: &str,
278        input: Option<serde_json::Value>,
279        handler: Option<&str>,
280    ) -> Result<Self, DurableError> {
281        // ── Idempotent: reuse existing RUNNING root task with same name ──
282        let existing_sql = format!(
283            "SELECT id FROM durable.task \
284             WHERE name = '{}' AND parent_id IS NULL AND status = 'RUNNING' \
285             LIMIT 1",
286            name
287        );
288        if let Some(row) = db
289            .query_one(Statement::from_string(DbBackend::Postgres, existing_sql))
290            .await?
291        {
292            let existing_id: Uuid = row
293                .try_get_by_index(0)
294                .map_err(|e| DurableError::custom(e.to_string()))?;
295            tracing::info!(
296                workflow = name,
297                id = %existing_id,
298                "idempotent start: attaching to existing RUNNING task"
299            );
300            return Self::from_id(db, existing_id).await;
301        }
302
303        let task_id = Uuid::new_v4();
304        let input_json = match &input {
305            Some(v) => serde_json::to_string(v)?,
306            None => "null".to_string(),
307        };
308
309        let executor_id = crate::executor_id();
310
311        let mut extra_cols = String::new();
312        let mut extra_vals = String::new();
313
314        if let Some(eid) = &executor_id {
315            extra_cols.push_str(", executor_id");
316            extra_vals.push_str(&format!(", '{eid}'"));
317        }
318        if let Some(h) = handler {
319            extra_cols.push_str(", handler");
320            extra_vals.push_str(&format!(", '{h}'"));
321        }
322
323        let txn = db.begin().await?;
324        let sql = format!(
325            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{extra_cols}) \
326             VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now(){extra_vals})"
327        );
328        txn.execute(Statement::from_string(DbBackend::Postgres, sql))
329            .await?;
330        txn.commit().await?;
331
332        Ok(Self {
333            db: db.clone(),
334            task_id,
335            sequence: AtomicI32::new(0),
336            executor_id,
337        })
338    }
339
340    /// Attach to an existing workflow by task ID. Used to resume a running or
341    /// recovered workflow without creating a new task row.
342    ///
343    /// ```ignore
344    /// let ctx = Ctx::from_id(&db, task_id).await?;
345    /// ```
346    pub async fn from_id(
347        db: &DatabaseConnection,
348        task_id: Uuid,
349    ) -> Result<Self, DurableError> {
350        // Verify the task exists
351        let model = Task::find_by_id(task_id).one(db).await?;
352        let _model =
353            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
354
355        // Claim the task with the current executor_id (if init() was called)
356        let executor_id = crate::executor_id();
357        if let Some(eid) = &executor_id {
358            db.execute(Statement::from_string(
359                DbBackend::Postgres,
360                format!(
361                    "UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"
362                ),
363            ))
364            .await?;
365        }
366
367        // Start sequence at 0 so that replaying steps from the beginning works
368        // correctly. Steps are looked up by (parent_id, sequence), so the caller
369        // will replay completed steps (getting saved outputs) before executing
370        // any new steps.
371        Ok(Self {
372            db: db.clone(),
373            task_id,
374            sequence: AtomicI32::new(0),
375            executor_id,
376        })
377    }
378
    /// Run a step. If already completed, returns saved output. Otherwise executes
    /// the closure, saves the result, and returns it.
    ///
    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
    ///
    /// ```ignore
    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
    /// ```
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Every call consumes one sequence slot, including replays — the step
        // row is addressed by (task_id, sequence), so ordering must be stable.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Check if parent task's deadline has passed before executing
        check_deadline(&self.db, self.task_id).await?;

        // TX1: Find or create the step row, set RUNNING, release connection.
        // The RUNNING status acts as a mutex — recovery won't reclaim it
        // unless this executor's heartbeat goes stale.
        let txn = self.db.begin().await?;
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // Fast path: a previous execution already checkpointed this step —
        // deserialize and return the saved output without re-running `f`.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // NOTE(review): retrying inside the still-open `txn` assumes a failed
        // `set_status` leaves the transaction usable — confirm the backend
        // doesn't poison the transaction on error.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        txn.commit().await?;
        // ── connection returned to pool ──

        // Execute the closure with no DB connection held.
        let result = f().await;

        // TX2: Checkpoint the result (short transaction).
        match result {
            Ok(val) => {
                // Serialize before writing so a serialization error doesn't
                // leave the step marked completed.
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&self.db, step_id, json.clone())).await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&self.db, step_id, &err_msg)).await?;
                Err(e)
            }
        }
    }
446
    /// Run a DB-only step inside a single Postgres transaction.
    ///
    /// Both the user's DB work and the checkpoint save happen in the same transaction,
    /// ensuring atomicity. If the closure returns an error, both the user writes and
    /// the checkpoint are rolled back.
    ///
    /// ```ignore
    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
    ///     do_db_work(tx).await
    /// })).await?;
    /// ```
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        // Consume one sequence slot so the step has a stable replay address.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create the step task record OUTSIDE the transaction.
        // This is idempotent (UNIQUE constraint) and must exist before we begin.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // If already completed, replay from saved output.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                // Checkpoint inside the same transaction: the user's writes and
                // the completion record land together, or not at all.
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Rollback happens automatically when tx is dropped.
                // Record the failure on the main connection (outside the rolled-back tx).
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
515
    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
    ///
    /// ```ignore
    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
    /// // use child_ctx.step(...) for steps inside the child
    /// child_ctx.complete(json!({"done": true})).await?;
    /// ```
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        // A child workflow occupies one sequence slot in the parent, like a step.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // Child workflows also use plain find-or-create without locking.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        // If child already completed, return a Ctx that will replay
        // (the caller should check is_completed() or just run steps which will replay)
        // NOTE(review): this sets RUNNING even on an already-COMPLETED child —
        // confirm set_status preserves terminal states or that replay tolerates it.
        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        // The child context gets its own sequence counter starting at 0, so the
        // child's steps replay independently of the parent's numbering.
        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
            executor_id: self.executor_id.clone(),
        })
    }
559
560    /// Check if this workflow/child was already completed (for skipping in parent).
561    pub async fn is_completed(&self) -> Result<bool, DurableError> {
562        let status = get_status(&self.db, self.task_id).await?;
563        Ok(status == Some(TaskStatus::Completed))
564    }
565
566    /// Get the saved output if this task is completed.
567    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
568        match get_output(&self.db, self.task_id).await? {
569            Some(val) => Ok(Some(serde_json::from_value(val)?)),
570            None => Ok(None),
571        }
572    }
573
574    /// Mark this workflow as completed with an output value.
575    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
576        let json = serde_json::to_value(output)?;
577        let db = &self.db;
578        let task_id = self.task_id;
579        retry_db_write(|| complete_task(db, task_id, json.clone())).await
580    }
581
    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // One sequence slot per step, consumed even on replay.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts)
        // NOTE(review): max_retries is read back from the row, which may differ
        // from policy.max_retries when the row pre-existed — confirm intended.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop
        loop {
            // Re-check status before each retry attempt
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration. The `.max(1.0)` clamp means
                        // a multiplier < 1.0 can never shrink the delay below
                        // `initial_backoff`.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
691
692    /// Mark this workflow as failed.
693    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
694        let db = &self.db;
695        let task_id = self.task_id;
696        retry_db_write(|| fail_task(db, task_id, error)).await
697    }
698
699    /// Mark a workflow as failed by task ID (static version for use outside a Ctx).
700    pub async fn fail_by_id(
701        db: &DatabaseConnection,
702        task_id: Uuid,
703        error: &str,
704    ) -> Result<(), DurableError> {
705        retry_db_write(|| fail_task(db, task_id, error)).await
706    }
707
    /// Set the timeout for this task in milliseconds.
    ///
    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
    /// for recovery by `Executor::recover()`.
    ///
    /// If the task is currently RUNNING and has a `started_at`, this also
    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // Only numeric `timeout_ms` and the UUID are interpolated, so the
        // statement stays well-formed. The CASE leaves deadline_epoch_ms
        // unchanged for tasks that are not currently RUNNING.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
                 deadline_epoch_ms = CASE \
                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                     ELSE deadline_epoch_ms \
                 END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
732
733    /// Start or resume a root workflow with a timeout in milliseconds.
734    ///
735    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
736    pub async fn start_with_timeout(
737        db: &DatabaseConnection,
738        name: &str,
739        input: Option<serde_json::Value>,
740        timeout_ms: i64,
741    ) -> Result<Self, DurableError> {
742        let ctx = Self::start(db, name, input).await?;
743        ctx.set_timeout(timeout_ms).await?;
744        Ok(ctx)
745    }
746
747    // ── Workflow control (management API) ─────────────────────────
748
    /// Pause a workflow by ID. Sets status to PAUSED and recursively
    /// cascades to all PENDING/RUNNING descendants (children, grandchildren, etc.).
    ///
    /// Only workflows in PENDING or RUNNING status can be paused.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Reject non-pausable states up front with a descriptive error.
        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Pause the task itself and all descendants in one recursive CTE.
        // The status filter in the UPDATE skips descendants already in a
        // terminal or paused state. (task_id is a UUID, safe to interpolate.)
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED' \
             WHERE id IN (SELECT id FROM descendants) \
               AND status IN ('PENDING', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }
785
    /// Resume a paused workflow by ID. Sets status back to RUNNING and
    /// recursively cascades to all PAUSED descendants (resetting them to PENDING).
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Only PAUSED workflows resume via this path; FAILED workflows go
        // through resume_failed() instead.
        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Resume the root task back to RUNNING
        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // Recursively resume all PAUSED descendants back to PENDING.
        // Note descendants become PENDING, not RUNNING — only the root is
        // set RUNNING above.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
               AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
823
824    /// Resume a failed workflow from its last completed step (DBOS-style).
825    ///
826    /// Resets the workflow from FAILED to RUNNING, resets all FAILED child
827    /// steps to PENDING (with `retry_count` zeroed), and increments
828    /// `recovery_count`. Completed steps are left untouched so they replay
829    /// from their checkpointed outputs.
830    ///
831    /// Returns the workflow's `handler` name (if set) so the caller can
832    /// look up the `WorkflowRegistration` and re-dispatch.
833    ///
834    /// Fails with `MaxRecoveryExceeded` if the workflow has already been
835    /// resumed `max_recovery_attempts` times.
836    pub async fn resume_failed(
837        db: &DatabaseConnection,
838        task_id: Uuid,
839    ) -> Result<Option<String>, DurableError> {
840        // Validate the workflow exists and is FAILED
841        let model = Task::find_by_id(task_id).one(db).await?;
842        let model =
843            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
844
845        if model.status != TaskStatus::Failed {
846            return Err(DurableError::custom(format!(
847                "cannot resume task in {} status (must be FAILED)",
848                model.status
849            )));
850        }
851
852        if model.kind != "WORKFLOW" {
853            return Err(DurableError::custom(format!(
854                "cannot resume a {}, only workflows can be resumed",
855                model.kind
856            )));
857        }
858
859        // Check recovery limit
860        if model.recovery_count >= model.max_recovery_attempts {
861            return Err(DurableError::MaxRecoveryExceeded(task_id.to_string()));
862        }
863
864        let executor_id = crate::executor_id().unwrap_or_default();
865
866        // Atomically: reset workflow to RUNNING, increment recovery_count
867        let reset_workflow_sql = format!(
868            "UPDATE durable.task \
869             SET status = 'RUNNING', \
870                 error = NULL, \
871                 completed_at = NULL, \
872                 started_at = now(), \
873                 executor_id = '{executor_id}', \
874                 recovery_count = recovery_count + 1 \
875             WHERE id = '{task_id}' AND status = 'FAILED'"
876        );
877        db.execute(Statement::from_string(DbBackend::Postgres, reset_workflow_sql))
878            .await?;
879
880        // Reset FAILED descendant steps to PENDING with cleared retry state.
881        // Completed steps are preserved for replay.
882        let reset_steps_sql = format!(
883            "WITH RECURSIVE descendants AS ( \
884                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
885                 UNION ALL \
886                 SELECT t.id FROM durable.task t \
887                 INNER JOIN descendants d ON t.parent_id = d.id \
888             ) \
889             UPDATE durable.task \
890             SET status = 'PENDING', \
891                 error = NULL, \
892                 retry_count = 0, \
893                 completed_at = NULL, \
894                 started_at = NULL \
895             WHERE id IN (SELECT id FROM descendants) \
896               AND status IN ('FAILED', 'RUNNING')"
897        );
898        db.execute(Statement::from_string(DbBackend::Postgres, reset_steps_sql))
899            .await?;
900
901        tracing::info!(
902            %task_id,
903            recovery_count = model.recovery_count + 1,
904            "failed workflow resumed"
905        );
906
907        Ok(model.handler.clone())
908    }
909
910    /// Cancel a workflow by ID. Sets status to CANCELLED and recursively
911    /// cascades to all non-terminal descendants.
912    ///
913    /// Cancellation is terminal — a cancelled workflow cannot be resumed.
914    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
915        let model = Task::find_by_id(task_id).one(db).await?;
916        let model =
917            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
918
919        match model.status {
920            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
921                return Err(DurableError::custom(format!(
922                    "cannot cancel task in {} status",
923                    model.status
924                )));
925            }
926            _ => {}
927        }
928
929        // Cancel the task itself and all non-terminal descendants in one recursive CTE
930        let sql = format!(
931            "WITH RECURSIVE descendants AS ( \
932                 SELECT id FROM durable.task WHERE id = '{task_id}' \
933                 UNION ALL \
934                 SELECT t.id FROM durable.task t \
935                 INNER JOIN descendants d ON t.parent_id = d.id \
936             ) \
937             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
938             WHERE id IN (SELECT id FROM descendants) \
939               AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
940        );
941        db.execute(Statement::from_string(DbBackend::Postgres, sql))
942            .await?;
943
944        tracing::info!(%task_id, "workflow cancelled");
945        Ok(())
946    }
947
948    // ── Query API ─────────────────────────────────────────────────
949
950    /// List tasks matching the given filter, with sorting and pagination.
951    ///
952    /// ```ignore
953    /// let tasks = Ctx::list(&db, TaskQuery::default().status("RUNNING").limit(10)).await?;
954    /// ```
955    pub async fn list(
956        db: &DatabaseConnection,
957        query: TaskQuery,
958    ) -> Result<Vec<TaskSummary>, DurableError> {
959        let mut select = Task::find();
960
961        // Filters
962        if let Some(status) = &query.status {
963            select = select.filter(TaskColumn::Status.eq(status.to_string()));
964        }
965        if let Some(kind) = &query.kind {
966            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
967        }
968        if let Some(parent_id) = query.parent_id {
969            select = select.filter(TaskColumn::ParentId.eq(parent_id));
970        }
971        if query.root_only {
972            select = select.filter(TaskColumn::ParentId.is_null());
973        }
974        if let Some(name) = &query.name {
975            select = select.filter(TaskColumn::Name.eq(name.as_str()));
976        }
977        if let Some(queue) = &query.queue_name {
978            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
979        }
980
981        // Sorting
982        let (col, order) = match query.sort {
983            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
984            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
985            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
986            TaskSort::Name(ord) => (TaskColumn::Name, ord),
987            TaskSort::Status(ord) => (TaskColumn::Status, ord),
988        };
989        select = select.order_by(col, order);
990
991        // Pagination
992        if let Some(offset) = query.offset {
993            select = select.offset(offset);
994        }
995        if let Some(limit) = query.limit {
996            select = select.limit(limit);
997        }
998
999        let models = select.all(db).await?;
1000
1001        Ok(models.into_iter().map(TaskSummary::from).collect())
1002    }
1003
1004    /// Count tasks matching the given filter.
1005    pub async fn count(
1006        db: &DatabaseConnection,
1007        query: TaskQuery,
1008    ) -> Result<u64, DurableError> {
1009        let mut select = Task::find();
1010
1011        if let Some(status) = &query.status {
1012            select = select.filter(TaskColumn::Status.eq(status.to_string()));
1013        }
1014        if let Some(kind) = &query.kind {
1015            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
1016        }
1017        if let Some(parent_id) = query.parent_id {
1018            select = select.filter(TaskColumn::ParentId.eq(parent_id));
1019        }
1020        if query.root_only {
1021            select = select.filter(TaskColumn::ParentId.is_null());
1022        }
1023        if let Some(name) = &query.name {
1024            select = select.filter(TaskColumn::Name.eq(name.as_str()));
1025        }
1026        if let Some(queue) = &query.queue_name {
1027            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
1028        }
1029
1030        let count = select.count(db).await?;
1031        Ok(count)
1032    }
1033
1034    // ── Accessors ────────────────────────────────────────────────
1035
    /// Borrow the database connection this context operates on.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }
1039
    /// The ID of the task this context is bound to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }
1043
    /// Return the next step sequence number.
    ///
    /// Atomically post-increments the shared counter (SeqCst), so concurrent
    /// callers on the same context never observe the same value twice.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
1047
1048    /// Deserialize the stored input for this task.
1049    ///
1050    /// Returns `Ok(value)` when the task has a non-null `input` column that
1051    /// deserializes to `T`. Returns an error if the input is missing or
1052    /// cannot be deserialized.
1053    ///
1054    /// ```ignore
1055    /// let args: IngestInput = ctx.input().await?;
1056    /// ```
1057    pub async fn input<T: DeserializeOwned>(&self) -> Result<T, DurableError> {
1058        let row = self
1059            .db
1060            .query_one(Statement::from_string(
1061                DbBackend::Postgres,
1062                format!(
1063                    "SELECT input FROM durable.task WHERE id = '{}'",
1064                    self.task_id
1065                ),
1066            ))
1067            .await?
1068            .ok_or_else(|| {
1069                DurableError::custom(format!("task {} not found", self.task_id))
1070            })?;
1071
1072        let input_json: Option<serde_json::Value> = row
1073            .try_get_by_index(0)
1074            .map_err(|e| DurableError::custom(e.to_string()))?;
1075
1076        let value = input_json.ok_or_else(|| {
1077            DurableError::custom(format!("task {} has no input", self.task_id))
1078        })?;
1079
1080        serde_json::from_value(value)
1081            .map_err(|e| DurableError::custom(format!("failed to deserialize input: {e}")))
1082    }
1083}
1084
1085// ── Internal SQL helpers ─────────────────────────────────────────────
1086
/// Find an existing task by (parent_id, name) or create a new one.
///
/// Returns `(task_id, Option<saved_output>)`:
/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
/// - `saved_output` is `None` when the task is new or in-progress.
///
/// When `lock` is `true` and an existing non-completed task is found, this
/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
/// another worker holds the lock, `DurableError::StepLocked` is returned so
/// the caller can skip execution rather than double-firing side effects.
///
/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
/// and child-workflow creation where concurrent start is safe).
///
/// When `lock` is `true`, the caller MUST call this within a transaction.
/// The lock prevents concurrent creation/claiming of the same step. The
/// caller should set the step to RUNNING and commit promptly — the RUNNING
/// status itself prevents other executors from re-claiming the step.
///
/// `max_retries`: if Some, overrides the schema default when creating a new task.
///
/// NOTE(review): the locking path identifies a step by (parent_id, sequence)
/// while the non-locking path matches by (parent_id, name) — presumably
/// intentional (steps are sequence-keyed, workflows name-keyed), but confirm.
#[allow(clippy::too_many_arguments)]
async fn find_or_create_task(
    db: &impl ConnectionTrait,
    parent_id: Option<Uuid>,
    sequence: Option<i32>,
    name: &str,
    kind: &str,
    input: Option<serde_json::Value>,
    lock: bool,
    max_retries: Option<u32>,
) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
    // SQL comparison fragment for the nullable parent: `= '<uuid>'` or `IS NULL`.
    let parent_eq = match parent_id {
        Some(p) => format!("= '{p}'"),
        None => "IS NULL".to_string(),
    };
    // SQL value literal for inserting the parent: `'<uuid>'` or `NULL`.
    let parent_sql = match parent_id {
        Some(p) => format!("'{p}'"),
        None => "NULL".to_string(),
    };

    if lock {
        // Locking path (for steps): we need exactly-once execution.
        //
        // Strategy:
        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
        //    If another transaction is concurrently inserting the same row,
        //    Postgres will block here until that transaction commits or rolls
        //    back, ensuring we never see a phantom "not found" for a row being
        //    inserted.
        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
        //    should get it back; if the row existed and is locked by another
        //    worker we get nothing.
        // 3. If SKIP LOCKED returns empty, return StepLocked.

        let new_id = Uuid::new_v4();
        // Nullable sequence rendered as a SQL literal (number or NULL).
        let seq_sql = match sequence {
            Some(s) => s.to_string(),
            None => "NULL".to_string(),
        };
        // NOTE(review): the serialized JSON is wrapped in single quotes with no
        // escaping — an input value containing `'` would break the statement,
        // and is an injection vector if input is ever attacker-influenced.
        // TODO: switch to a parameterized Statement (from_sql_and_values).
        let input_sql = match &input {
            Some(v) => format!("'{}'", serde_json::to_string(v)?),
            None => "NULL".to_string(),
        };

        let max_retries_sql = match max_retries {
            Some(r) => r.to_string(),
            None => "3".to_string(), // schema default
        };

        // Step 1: insert-or-skip.
        // ON CONFLICT (parent_id, sequence) presumes a unique index on that
        // pair in the schema — TODO confirm. `name`/`kind` are interpolated
        // unescaped; assumed to be internal identifiers without quotes.
        let insert_sql = format!(
            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
             ON CONFLICT (parent_id, sequence) DO NOTHING"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
            .await?;

        // Step 2: lock the row (ours or pre-existing) by (parent_id, sequence)
        let lock_sql = format!(
            "SELECT id, status::text, output FROM durable.task \
             WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
             FOR UPDATE SKIP LOCKED"
        );
        let row = db
            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
            .await?;

        if let Some(row) = row {
            let id: Uuid = row
                .try_get_by_index(0)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            let status: String = row
                .try_get_by_index(1)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            // Output may be NULL for non-completed rows; errors are treated as absent.
            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();

            if status == TaskStatus::Completed.to_string() {
                // Replay path — return saved output
                return Ok((id, output));
            }
            if status == TaskStatus::Running.to_string() {
                // Another worker already claimed this step (set RUNNING and committed).
                // We got the row lock (no longer held in a long txn) but must not
                // double-execute.
                return Err(DurableError::StepLocked(name.to_string()));
            }
            // Task is PENDING — we're the first to claim it.
            return Ok((id, None));
        }

        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
        Err(DurableError::StepLocked(name.to_string()))
    } else {
        // Plain find without locking — safe for workflow-level operations.
        // Multiple workers resuming the same workflow is fine; individual
        // steps will be locked when executed.
        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
        query = match parent_id {
            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
            None => query.filter(TaskColumn::ParentId.is_null()),
        };
        // Skip cancelled/failed tasks — a fresh row will be created instead.
        // Completed tasks are still returned so child tasks can replay saved output.
        let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
        let existing = query
            .filter(TaskColumn::Status.is_not_in(status_exclusions))
            .one(db)
            .await?;

        if let Some(model) = existing {
            if model.status == TaskStatus::Completed {
                // Replay path — hand back the checkpointed output.
                return Ok((model.id, model.output));
            }
            // In-progress — caller resumes it with no saved output.
            return Ok((model.id, None));
        }

        // Task does not exist — create it
        let id = Uuid::new_v4();
        let new_task = TaskActiveModel {
            id: Set(id),
            parent_id: Set(parent_id),
            sequence: Set(sequence),
            name: Set(name.to_string()),
            kind: Set(kind.to_string()),
            status: Set(TaskStatus::Pending),
            input: Set(input),
            // Mirrors the schema default of 3 used in the locking path above.
            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
            ..Default::default()
        };
        new_task.insert(db).await?;

        Ok((id, None))
    }
}
1242
1243async fn get_output(
1244    db: &impl ConnectionTrait,
1245    task_id: Uuid,
1246) -> Result<Option<serde_json::Value>, DurableError> {
1247    let model = Task::find_by_id(task_id)
1248        .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1249        .one(db)
1250        .await?;
1251
1252    Ok(model.and_then(|m| m.output))
1253}
1254
1255async fn get_status(
1256    db: &impl ConnectionTrait,
1257    task_id: Uuid,
1258) -> Result<Option<TaskStatus>, DurableError> {
1259    let model = Task::find_by_id(task_id).one(db).await?;
1260
1261    Ok(model.map(|m| m.status))
1262}
1263
1264/// Returns (retry_count, max_retries) for a task.
1265async fn get_retry_info(
1266    db: &DatabaseConnection,
1267    task_id: Uuid,
1268) -> Result<(u32, u32), DurableError> {
1269    let model = Task::find_by_id(task_id).one(db).await?;
1270
1271    match model {
1272        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1273        None => Err(DurableError::custom(format!(
1274            "task {task_id} not found when reading retry info"
1275        ))),
1276    }
1277}
1278
1279/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
1280async fn increment_retry_count(
1281    db: &DatabaseConnection,
1282    task_id: Uuid,
1283) -> Result<u32, DurableError> {
1284    let model = Task::find_by_id(task_id).one(db).await?;
1285
1286    match model {
1287        Some(m) => {
1288            let new_count = m.retry_count + 1;
1289            let mut active: TaskActiveModel = m.into();
1290            active.retry_count = Set(new_count);
1291            active.status = Set(TaskStatus::Pending);
1292            active.error = Set(None);
1293            active.completed_at = Set(None);
1294            active.update(db).await?;
1295            Ok(new_count as u32)
1296        }
1297        None => Err(DurableError::custom(format!(
1298            "task {task_id} not found when incrementing retry count"
1299        ))),
1300    }
1301}
1302
// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
// for conditional deadline_epoch_ms computation or COALESCE on started_at.
/// Transition a task to `status`.
///
/// Side effects of the single UPDATE:
/// - stamps `started_at` on the first transition only (COALESCE preserves an
///   existing value);
/// - when moving to RUNNING with a configured `timeout_ms`, computes an
///   absolute `deadline_epoch_ms` from the DB clock; otherwise leaves the
///   existing deadline untouched.
///
/// `task_id` is a Uuid, so the interpolation is injection-safe. The `{status}`
/// interpolation relies on TaskStatus's Display rendering the DB enum label
/// (e.g. 'RUNNING') — presumably guaranteed by the generated enum; confirm
/// against `sea_orm_active_enums::TaskStatus`.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1325
1326/// Check if the task is paused or cancelled. Returns an error if so.
1327async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1328    let status = get_status(db, task_id).await?;
1329    match status {
1330        Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1331        Some(TaskStatus::Cancelled) => {
1332            Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1333        }
1334        _ => Ok(()),
1335    }
1336}
1337
1338/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
1339async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1340    let model = Task::find_by_id(task_id).one(db).await?;
1341
1342    if let Some(m) = model
1343        && let Some(deadline_ms) = m.deadline_epoch_ms
1344    {
1345        let now_ms = std::time::SystemTime::now()
1346            .duration_since(std::time::UNIX_EPOCH)
1347            .map(|d| d.as_millis() as i64)
1348            .unwrap_or(0);
1349        if now_ms > deadline_ms {
1350            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1351        }
1352    }
1353
1354    Ok(())
1355}
1356
1357async fn complete_task(
1358    db: &impl ConnectionTrait,
1359    task_id: Uuid,
1360    output: serde_json::Value,
1361) -> Result<(), DurableError> {
1362    let model = Task::find_by_id(task_id).one(db).await?;
1363
1364    if let Some(m) = model {
1365        let mut active: TaskActiveModel = m.into();
1366        active.status = Set(TaskStatus::Completed);
1367        active.output = Set(Some(output));
1368        active.completed_at = Set(Some(chrono::Utc::now().into()));
1369        active.update(db).await?;
1370    }
1371    Ok(())
1372}
1373
1374async fn fail_task(
1375    db: &impl ConnectionTrait,
1376    task_id: Uuid,
1377    error: &str,
1378) -> Result<(), DurableError> {
1379    let model = Task::find_by_id(task_id).one(db).await?;
1380
1381    if let Some(m) = model {
1382        let mut active: TaskActiveModel = m.into();
1383        active.status = Set(TaskStatus::Failed);
1384        active.error = Set(Some(error.to_string()));
1385        active.completed_at = Set(Some(chrono::Utc::now().into()));
1386        active.update(db).await?;
1387    }
1388    Ok(())
1389}
1390
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A closure that succeeds immediately must run exactly once and yield Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by success: returns Ok after exactly
    /// three invocations (one initial + two retries).
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                match counter.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    ))),
                    _ => Ok(()),
                }
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A permanently failing closure runs 1 + MAX_CHECKPOINT_RETRIES times
    /// total and ultimately surfaces an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(outcome.is_err());
        // one initial attempt plus every configured retry
        assert_eq!(calls.load(Ordering::SeqCst), 1 + MAX_CHECKPOINT_RETRIES);
    }

    /// When every attempt fails, the error surfaced is the FIRST one observed,
    /// not the error produced by the final retry.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                let attempt = counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{attempt}"
                ))))
            }
        })
        .await;
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}