Skip to main content

durable/
ctx.rs

1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9    DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10    Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20// ── SQL constants ────────────────────────────────────────────────────────────
21
22/// Explicit cast suffix for the `task_status` enum in raw SQL.
23///
24/// PostgreSQL cannot resolve `task_status` if the `durable` schema is not in
25/// the connection's `search_path`. Appending this cast to every status literal
26/// makes queries work regardless of the caller's search_path configuration.
pub(crate) const TS: &str = "::durable.task_status"; // e.g. "'RUNNING'" + TS → "'RUNNING'::durable.task_status"
28
29// ── Retry constants ───────────────────────────────────────────────────────────
30
/// Maximum number of *additional* attempts after the initial failed write.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
/// Base backoff in milliseconds; doubles each attempt (100ms, 200ms, 400ms).
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
33
34/// Retry a fallible DB write with exponential backoff.
35///
36/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
37/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
38/// error if all retries are exhausted.
39async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
40where
41    F: FnMut() -> Fut,
42    Fut: std::future::Future<Output = Result<(), DurableError>>,
43{
44    match f().await {
45        Ok(()) => Ok(()),
46        Err(first_err) => {
47            for i in 0..MAX_CHECKPOINT_RETRIES {
48                tokio::time::sleep(Duration::from_millis(
49                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
50                ))
51                .await;
52                if f().await.is_ok() {
53                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
54                    return Ok(());
55                }
56            }
57            Err(first_err)
58        }
59    }
60}
61
62/// Policy for retrying a step on failure.
/// Policy for retrying a step on failure.
pub struct RetryPolicy {
    /// Number of retries allowed after the first failed attempt (0 = fail fast).
    pub max_retries: u32,
    /// Backoff before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor applied to the backoff after each retry (1.0 = fixed backoff).
    pub backoff_multiplier: f64,
}
68
69impl RetryPolicy {
70    /// No retries — fails immediately on error.
71    pub fn none() -> Self {
72        Self {
73            max_retries: 0,
74            initial_backoff: std::time::Duration::from_secs(0),
75            backoff_multiplier: 1.0,
76        }
77    }
78
79    /// Exponential backoff: backoff doubles each retry.
80    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
81        Self {
82            max_retries,
83            initial_backoff,
84            backoff_multiplier: 2.0,
85        }
86    }
87
88    /// Fixed backoff: same duration between all retries.
89    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
90        Self {
91            max_retries,
92            initial_backoff: backoff,
93            backoff_multiplier: 1.0,
94        }
95    }
96}
97
98/// Sort order for task listing.
/// Sort order for task listing.
pub enum TaskSort {
    /// Order by creation timestamp.
    CreatedAt(Order),
    /// Order by start timestamp.
    StartedAt(Order),
    /// Order by completion timestamp.
    CompletedAt(Order),
    /// Order by task name.
    Name(Order),
    /// Order by task status.
    Status(Order),
}
106
107/// Builder for filtering, sorting, and paginating task queries.
108///
109/// ```ignore
110/// let query = TaskQuery::default()
111///     .status(TaskStatus::Running)
112///     .kind("WORKFLOW")
113///     .root_only(true)
114///     .sort(TaskSort::CreatedAt(Order::Desc))
115///     .limit(20);
116/// let tasks = Ctx::list(&db, query).await?;
117/// ```
pub struct TaskQuery {
    /// Restrict to tasks in this status.
    pub status: Option<TaskStatus>,
    /// Restrict to tasks of this kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: Option<String>,
    /// Restrict to direct children of this task.
    pub parent_id: Option<Uuid>,
    /// When true, only root tasks (those with no parent) are returned.
    pub root_only: bool,
    /// Restrict to tasks with this exact name.
    pub name: Option<String>,
    /// Restrict to tasks on this queue.
    pub queue_name: Option<String>,
    /// Sort order (default: newest first by `created_at`).
    pub sort: TaskSort,
    /// Maximum number of rows returned.
    pub limit: Option<u64>,
    /// Number of leading rows skipped.
    pub offset: Option<u64>,
}
129
130impl Default for TaskQuery {
131    fn default() -> Self {
132        Self {
133            status: None,
134            kind: None,
135            parent_id: None,
136            root_only: false,
137            name: None,
138            queue_name: None,
139            sort: TaskSort::CreatedAt(Order::Desc),
140            limit: None,
141            offset: None,
142        }
143    }
144}
145
146impl TaskQuery {
147    /// Filter by status.
148    pub fn status(mut self, status: TaskStatus) -> Self {
149        self.status = Some(status);
150        self
151    }
152
153    /// Filter by kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
154    pub fn kind(mut self, kind: &str) -> Self {
155        self.kind = Some(kind.to_string());
156        self
157    }
158
159    /// Filter by parent task ID (direct children only).
160    pub fn parent_id(mut self, parent_id: Uuid) -> Self {
161        self.parent_id = Some(parent_id);
162        self
163    }
164
165    /// Only return root tasks (no parent).
166    pub fn root_only(mut self, root_only: bool) -> Self {
167        self.root_only = root_only;
168        self
169    }
170
171    /// Filter by task name.
172    pub fn name(mut self, name: &str) -> Self {
173        self.name = Some(name.to_string());
174        self
175    }
176
177    /// Filter by queue name.
178    pub fn queue_name(mut self, queue: &str) -> Self {
179        self.queue_name = Some(queue.to_string());
180        self
181    }
182
183    /// Set the sort order.
184    pub fn sort(mut self, sort: TaskSort) -> Self {
185        self.sort = sort;
186        self
187    }
188
189    /// Limit the number of results.
190    pub fn limit(mut self, limit: u64) -> Self {
191        self.limit = Some(limit);
192        self
193    }
194
195    /// Skip the first N results.
196    pub fn offset(mut self, offset: u64) -> Self {
197        self.offset = Some(offset);
198        self
199    }
200}
201
202/// Summary of a task returned by `Ctx::list()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    /// Task row ID.
    pub id: Uuid,
    /// Parent task ID; `None` for root workflows.
    pub parent_id: Option<Uuid>,
    /// Task name.
    pub name: String,
    /// Registered handler function name, if recorded at start.
    pub handler: Option<String>,
    /// Current lifecycle status.
    pub status: TaskStatus,
    /// Task kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: String,
    /// JSON input the task was started with, if any.
    pub input: Option<serde_json::Value>,
    /// Checkpointed JSON output, present once completed.
    pub output: Option<serde_json::Value>,
    /// Error message, present if the task failed.
    pub error: Option<String>,
    /// Queue the task belongs to, if any.
    pub queue_name: Option<String>,
    /// Row creation time.
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    /// When execution started, if it has.
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    /// When execution finished, if it has.
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
219
220impl From<durable_db::entity::task::Model> for TaskSummary {
221    fn from(m: durable_db::entity::task::Model) -> Self {
222        Self {
223            id: m.id,
224            parent_id: m.parent_id,
225            name: m.name,
226            handler: m.handler,
227            status: m.status,
228            kind: m.kind,
229            input: m.input,
230            output: m.output,
231            error: m.error,
232            queue_name: m.queue_name,
233            created_at: m.created_at,
234            started_at: m.started_at,
235            completed_at: m.completed_at,
236        }
237    }
238}
239
240/// Context threaded through every workflow and step.
241///
242/// Users never create or manage task IDs. The SDK handles everything
243/// via `(parent_id, name)` lookups — the unique constraint in the schema
244/// guarantees exactly-once step creation.
pub struct Ctx {
    // Connection handle used for every task read/write issued by this context.
    db: DatabaseConnection,
    // ID of the task (workflow or child workflow) this context is scoped to.
    task_id: Uuid,
    // Monotonic per-context counter; each step()/transaction()/child() call
    // claims the next value so steps are found by (parent_id, sequence).
    sequence: AtomicI32,
    // Executor ID captured at start/attach time; Some only when init() was called.
    executor_id: Option<String>,
}
251
252impl Ctx {
253    // ── Workflow lifecycle (user-facing) ──────────────────────────
254
255    /// Start a new root workflow (or attach to an existing one with the same name).
256    ///
257    /// If a RUNNING root task with the same `name` already exists (e.g. recovered
258    /// after a crash), returns a `Ctx` attached to that task (idempotent start).
259    /// Otherwise creates a fresh task.
260    ///
261    /// If [`init`](crate::init) was called, the task is automatically tagged
262    /// with the executor ID for crash recovery. Otherwise `executor_id` is NULL.
263    ///
264    /// To reactivate a paused workflow, use [`Ctx::resume`] instead.
265    ///
266    /// ```ignore
267    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
268    /// ```
269    pub async fn start(
270        db: &DatabaseConnection,
271        name: &str,
272        input: Option<serde_json::Value>,
273    ) -> Result<Self, DurableError> {
274        Self::start_with_handler(db, name, input, None).await
275    }
276
277    /// Like [`start`](Self::start) but records the handler function name for
278    /// crash recovery. The `handler` is stored in the `durable.task` row so
279    /// that [`dispatch_recovered`](crate::dispatch_recovered) can look up the
280    /// registered `#[durable::workflow]` function even when `name` differs.
281    ///
282    /// Typically called from macro-generated `start_<fn>` functions rather
283    /// than directly.
284    pub async fn start_with_handler(
285        db: &DatabaseConnection,
286        name: &str,
287        input: Option<serde_json::Value>,
288        handler: Option<&str>,
289    ) -> Result<Self, DurableError> {
290        // ── Idempotent: reuse existing RUNNING root task with same name ──
291        let existing_sql = format!(
292            "SELECT id FROM durable.task \
293             WHERE name = '{}' AND parent_id IS NULL AND status = 'RUNNING'{TS} \
294             LIMIT 1",
295            name
296        );
297        if let Some(row) = db
298            .query_one(Statement::from_string(DbBackend::Postgres, existing_sql))
299            .await?
300        {
301            let existing_id: Uuid = row
302                .try_get_by_index(0)
303                .map_err(|e| DurableError::custom(e.to_string()))?;
304            tracing::info!(
305                workflow = name,
306                id = %existing_id,
307                "idempotent start: attaching to existing RUNNING task"
308            );
309            return Self::from_id(db, existing_id).await;
310        }
311
312        let task_id = Uuid::new_v4();
313        let input_json = match &input {
314            Some(v) => serde_json::to_string(v)?,
315            None => "null".to_string(),
316        };
317
318        let executor_id = crate::executor_id();
319
320        let mut extra_cols = String::new();
321        let mut extra_vals = String::new();
322
323        if let Some(eid) = &executor_id {
324            extra_cols.push_str(", executor_id");
325            extra_vals.push_str(&format!(", '{eid}'"));
326        }
327        if let Some(h) = handler {
328            extra_cols.push_str(", handler");
329            extra_vals.push_str(&format!(", '{h}'"));
330        }
331
332        let txn = db.begin().await?;
333        let sql = format!(
334            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{extra_cols}) \
335             VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING'{TS}, '{input_json}', now(){extra_vals})"
336        );
337        txn.execute(Statement::from_string(DbBackend::Postgres, sql))
338            .await?;
339        txn.commit().await?;
340
341        Ok(Self {
342            db: db.clone(),
343            task_id,
344            sequence: AtomicI32::new(0),
345            executor_id,
346        })
347    }
348
349    /// Attach to an existing workflow by task ID. Used to resume a running or
350    /// recovered workflow without creating a new task row.
351    ///
352    /// ```ignore
353    /// let ctx = Ctx::from_id(&db, task_id).await?;
354    /// ```
    pub async fn from_id(
        db: &DatabaseConnection,
        task_id: Uuid,
    ) -> Result<Self, DurableError> {
        // Verify the task exists — fail fast with a clear error instead of
        // letting later step lookups fail obscurely.
        let model = Task::find_by_id(task_id).one(db).await?;
        let _model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Claim the task with the current executor_id (if init() was called)
        // so crash recovery can attribute the task to this process.
        let executor_id = crate::executor_id();
        if let Some(eid) = &executor_id {
            db.execute(Statement::from_string(
                DbBackend::Postgres,
                format!(
                    "UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"
                ),
            ))
            .await?;
        }

        // Start sequence at 0 so that replaying steps from the beginning works
        // correctly. Steps are looked up by (parent_id, sequence), so the caller
        // will replay completed steps (getting saved outputs) before executing
        // any new steps.
        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
            executor_id,
        })
    }
387
388    /// Run a step. If already completed, returns saved output. Otherwise executes
389    /// the closure, saves the result, and returns it.
390    ///
391    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
392    ///
393    /// ```ignore
394    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
395    /// ```
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim the next sequence slot up front; steps are identified by
        // (parent_id, sequence), so call order must match the original run.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Check if parent task's deadline has passed before executing
        check_deadline(&self.db, self.task_id).await?;

        // TX1: Find or create the step row, set RUNNING, release connection.
        // The RUNNING status acts as a mutex — recovery won't reclaim it
        // unless this executor's heartbeat goes stale.
        let txn = self.db.begin().await?;
        // NOTE(review): the trailing `true` / `Some(0)` arguments presumably
        // mean "lock the row" and "initial retry_count" — confirm against
        // `find_or_create_task`'s signature.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // Already checkpointed on a previous run: deserialize and return the
        // saved output instead of re-executing the closure.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // NOTE(review): retrying on the same open transaction looks dubious —
        // if the first `set_status` attempt aborts the txn, later attempts on
        // that txn cannot succeed. Confirm whether retry is effective here.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        txn.commit().await?;
        // ── connection returned to pool ──

        // Execute the closure with no DB connection held.
        let result = f().await;

        // TX2: Checkpoint the result (short transaction).
        match result {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&self.db, step_id, json.clone())).await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Record the failure before propagating so the row does not
                // stay RUNNING forever.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&self.db, step_id, &err_msg)).await?;
                Err(e)
            }
        }
    }
455
456    /// Run a DB-only step inside a single Postgres transaction.
457    ///
458    /// Both the user's DB work and the checkpoint save happen in the same transaction,
459    /// ensuring atomicity. If the closure returns an error, both the user writes and
460    /// the checkpoint are rolled back.
461    ///
462    /// ```ignore
463    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
464    ///     do_db_work(tx).await
465    /// })).await?;
466    /// ```
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        // Claim the next sequence slot for the (parent_id, sequence) lookup.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create the step task record OUTSIDE the transaction.
        // This is idempotent (UNIQUE constraint) and must exist before we begin.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // If already completed, replay from saved output.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Rollback happens automatically when tx is dropped.
                // Record the failure on the main connection (outside the rolled-back tx).
                // NOTE(review): if `fail_task` itself errors, that DB error is
                // returned instead of the user's original error `e` — confirm
                // this masking is intended.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
524
525    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
526    ///
527    /// ```ignore
528    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
529    /// // use child_ctx.step(...) for steps inside the child
530    /// child_ctx.complete(json!({"done": true})).await?;
531    /// ```
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        // Claim the next sequence slot; children are found by (parent_id, sequence).
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // Child workflows also use plain find-or-create without locking.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        // If child already completed, return a Ctx that will replay
        // (the caller should check is_completed() or just run steps which will replay)
        // NOTE(review): this sets the child RUNNING even when it already
        // completed — confirm `set_status` is a no-op for terminal rows or
        // that replay tolerates the status flip.
        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        // The child context restarts its own sequence at 0 so the child's
        // steps replay deterministically from the beginning.
        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
            executor_id: self.executor_id.clone(),
        })
    }
568
569    /// Check if this workflow/child was already completed (for skipping in parent).
570    pub async fn is_completed(&self) -> Result<bool, DurableError> {
571        let status = get_status(&self.db, self.task_id).await?;
572        Ok(status == Some(TaskStatus::Completed))
573    }
574
575    /// Get the saved output if this task is completed.
576    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
577        match get_output(&self.db, self.task_id).await? {
578            Some(val) => Ok(Some(serde_json::from_value(val)?)),
579            None => Ok(None),
580        }
581    }
582
583    /// Mark this workflow as completed with an output value.
584    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
585        let json = serde_json::to_value(output)?;
586        let db = &self.db;
587        let task_id = self.task_id;
588        retry_db_write(|| complete_task(db, task_id, json.clone())).await
589    }
590
591    /// Run a step with a configurable retry policy.
592    ///
593    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
594    /// it may be called multiple times on retry. Retries happen in-process with
595    /// configurable backoff between attempts.
596    ///
597    /// ```ignore
598    /// let result: u32 = ctx
599    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
600    ///         api.call().await
601    ///     })
602    ///     .await?;
603    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim the next sequence slot; steps are found by (parent_id, sequence).
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts).
        // NOTE(review): `max_retries` read from the DB row — presumably the
        // value persisted at creation from the policy; confirm it cannot
        // diverge from `policy.max_retries` on replay with a changed policy.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop: each iteration marks the step RUNNING, runs the closure,
        // and either checkpoints the success or records the failure.
        loop {
            // Re-check status before each retry attempt
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING.
                        // Persisting the count means a process restart resumes
                        // counting rather than starting over.
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration:
                        // initial * multiplier^(attempt-1), with the factor
                        // clamped to >= 1.0 so backoff never shrinks.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED and surface the last error.
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
700
701    /// Mark this workflow as failed.
702    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
703        let db = &self.db;
704        let task_id = self.task_id;
705        retry_db_write(|| fail_task(db, task_id, error)).await
706    }
707
708    /// Mark a workflow as failed by task ID (static version for use outside a Ctx).
709    pub async fn fail_by_id(
710        db: &DatabaseConnection,
711        task_id: Uuid,
712        error: &str,
713    ) -> Result<(), DurableError> {
714        retry_db_write(|| fail_task(db, task_id, error)).await
715    }
716
717    /// Set the timeout for this task in milliseconds.
718    ///
719    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
720    /// for recovery by `Executor::recover()`.
721    ///
722    /// If the task is currently RUNNING and has a `started_at`, this also
723    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
724    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
725        let sql = format!(
726            "UPDATE durable.task \
727             SET timeout_ms = {timeout_ms}, \
728                 deadline_epoch_ms = CASE \
729                     WHEN status = 'RUNNING'{TS} AND started_at IS NOT NULL \
730                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
731                     ELSE deadline_epoch_ms \
732                 END \
733             WHERE id = '{}'",
734            self.task_id
735        );
736        self.db
737            .execute(Statement::from_string(DbBackend::Postgres, sql))
738            .await?;
739        Ok(())
740    }
741
742    /// Start or resume a root workflow with a timeout in milliseconds.
743    ///
744    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
745    pub async fn start_with_timeout(
746        db: &DatabaseConnection,
747        name: &str,
748        input: Option<serde_json::Value>,
749        timeout_ms: i64,
750    ) -> Result<Self, DurableError> {
751        let ctx = Self::start(db, name, input).await?;
752        ctx.set_timeout(timeout_ms).await?;
753        Ok(ctx)
754    }
755
756    // ── Workflow control (management API) ─────────────────────────
757
758    /// Pause a workflow by ID. Sets status to PAUSED and recursively
759    /// cascades to all PENDING/RUNNING descendants (children, grandchildren, etc.).
760    ///
761    /// Only workflows in PENDING or RUNNING status can be paused.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        // Load the task first so we can validate its status with a clear error.
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Only live workflows can be paused; terminal or already-paused
        // states are rejected.
        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Pause the task itself and all descendants in one recursive CTE.
        // The status filter means only PENDING/RUNNING rows are flipped, so
        // already-finished descendants keep their terminal status.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED'{TS} \
             WHERE id IN (SELECT id FROM descendants) \
               AND status IN ('PENDING'{TS}, 'RUNNING'{TS})"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }
794
795    /// Resume a paused workflow by ID. Sets status back to RUNNING and
796    /// recursively cascades to all PAUSED descendants (resetting them to PENDING).
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        // Load the task first so we can validate its status with a clear error.
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Only PAUSED workflows can be resumed.
        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Resume the root task back to RUNNING
        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // Recursively resume all PAUSED descendants back to PENDING.
        // NOTE(review): descendants that were RUNNING before the pause come
        // back as PENDING (not RUNNING) — the replay machinery presumably
        // re-dispatches them; confirm this is intended.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING'{TS} \
             WHERE id IN (SELECT id FROM descendants) \
               AND status = 'PAUSED'{TS}"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
832
833    /// Resume a failed workflow from its last completed step (DBOS-style).
834    ///
835    /// Resets the workflow from FAILED to RUNNING, resets all FAILED child
836    /// steps to PENDING (with `retry_count` zeroed), and increments
837    /// `recovery_count`. Completed steps are left untouched so they replay
838    /// from their checkpointed outputs.
839    ///
840    /// Returns the workflow's `handler` name (if set) so the caller can
841    /// look up the `WorkflowRegistration` and re-dispatch.
842    ///
843    /// Fails with `MaxRecoveryExceeded` if the workflow has already been
844    /// resumed `max_recovery_attempts` times.
845    pub async fn resume_failed(
846        db: &DatabaseConnection,
847        task_id: Uuid,
848    ) -> Result<Option<String>, DurableError> {
849        // Validate the workflow exists and is FAILED
850        let model = Task::find_by_id(task_id).one(db).await?;
851        let model =
852            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
853
854        if model.status != TaskStatus::Failed {
855            return Err(DurableError::custom(format!(
856                "cannot resume task in {} status (must be FAILED)",
857                model.status
858            )));
859        }
860
861        if model.kind != "WORKFLOW" {
862            return Err(DurableError::custom(format!(
863                "cannot resume a {}, only workflows can be resumed",
864                model.kind
865            )));
866        }
867
868        // Check recovery limit
869        if model.recovery_count >= model.max_recovery_attempts {
870            return Err(DurableError::MaxRecoveryExceeded(task_id.to_string()));
871        }
872
873        let executor_id = crate::executor_id().unwrap_or_default();
874
875        // Atomically: reset workflow to RUNNING, increment recovery_count
876        let reset_workflow_sql = format!(
877            "UPDATE durable.task \
878             SET status = 'RUNNING'{TS}, \
879                 error = NULL, \
880                 completed_at = NULL, \
881                 started_at = now(), \
882                 executor_id = '{executor_id}', \
883                 recovery_count = recovery_count + 1 \
884             WHERE id = '{task_id}' AND status = 'FAILED'{TS}"
885        );
886        db.execute(Statement::from_string(DbBackend::Postgres, reset_workflow_sql))
887            .await?;
888
889        // Reset FAILED descendant steps to PENDING with cleared retry state.
890        // Completed steps are preserved for replay.
891        let reset_steps_sql = format!(
892            "WITH RECURSIVE descendants AS ( \
893                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
894                 UNION ALL \
895                 SELECT t.id FROM durable.task t \
896                 INNER JOIN descendants d ON t.parent_id = d.id \
897             ) \
898             UPDATE durable.task \
899             SET status = 'PENDING'{TS}, \
900                 error = NULL, \
901                 retry_count = 0, \
902                 completed_at = NULL, \
903                 started_at = NULL \
904             WHERE id IN (SELECT id FROM descendants) \
905               AND status IN ('FAILED'{TS}, 'RUNNING'{TS})"
906        );
907        db.execute(Statement::from_string(DbBackend::Postgres, reset_steps_sql))
908            .await?;
909
910        tracing::info!(
911            %task_id,
912            recovery_count = model.recovery_count + 1,
913            "failed workflow resumed"
914        );
915
916        Ok(model.handler.clone())
917    }
918
919    /// Cancel a workflow by ID. Sets status to CANCELLED and recursively
920    /// cascades to all non-terminal descendants.
921    ///
922    /// Cancellation is terminal — a cancelled workflow cannot be resumed.
923    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
924        let model = Task::find_by_id(task_id).one(db).await?;
925        let model =
926            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
927
928        match model.status {
929            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
930                return Err(DurableError::custom(format!(
931                    "cannot cancel task in {} status",
932                    model.status
933                )));
934            }
935            _ => {}
936        }
937
938        // Cancel the task itself and all non-terminal descendants in one recursive CTE
939        let sql = format!(
940            "WITH RECURSIVE descendants AS ( \
941                 SELECT id FROM durable.task WHERE id = '{task_id}' \
942                 UNION ALL \
943                 SELECT t.id FROM durable.task t \
944                 INNER JOIN descendants d ON t.parent_id = d.id \
945             ) \
946             UPDATE durable.task SET status = 'CANCELLED'{TS}, completed_at = now() \
947             WHERE id IN (SELECT id FROM descendants) \
948               AND status NOT IN ('COMPLETED'{TS}, 'FAILED'{TS}, 'CANCELLED'{TS})"
949        );
950        db.execute(Statement::from_string(DbBackend::Postgres, sql))
951            .await?;
952
953        tracing::info!(%task_id, "workflow cancelled");
954        Ok(())
955    }
956
957    // ── Query API ─────────────────────────────────────────────────
958
959    /// List tasks matching the given filter, with sorting and pagination.
960    ///
961    /// ```ignore
962    /// let tasks = Ctx::list(&db, TaskQuery::default().status("RUNNING").limit(10)).await?;
963    /// ```
964    pub async fn list(
965        db: &DatabaseConnection,
966        query: TaskQuery,
967    ) -> Result<Vec<TaskSummary>, DurableError> {
968        let mut select = Task::find();
969
970        // Filters
971        if let Some(status) = &query.status {
972            select = select.filter(TaskColumn::Status.eq(status.to_string()));
973        }
974        if let Some(kind) = &query.kind {
975            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
976        }
977        if let Some(parent_id) = query.parent_id {
978            select = select.filter(TaskColumn::ParentId.eq(parent_id));
979        }
980        if query.root_only {
981            select = select.filter(TaskColumn::ParentId.is_null());
982        }
983        if let Some(name) = &query.name {
984            select = select.filter(TaskColumn::Name.eq(name.as_str()));
985        }
986        if let Some(queue) = &query.queue_name {
987            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
988        }
989
990        // Sorting
991        let (col, order) = match query.sort {
992            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
993            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
994            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
995            TaskSort::Name(ord) => (TaskColumn::Name, ord),
996            TaskSort::Status(ord) => (TaskColumn::Status, ord),
997        };
998        select = select.order_by(col, order);
999
1000        // Pagination
1001        if let Some(offset) = query.offset {
1002            select = select.offset(offset);
1003        }
1004        if let Some(limit) = query.limit {
1005            select = select.limit(limit);
1006        }
1007
1008        let models = select.all(db).await?;
1009
1010        Ok(models.into_iter().map(TaskSummary::from).collect())
1011    }
1012
1013    /// Count tasks matching the given filter.
1014    pub async fn count(
1015        db: &DatabaseConnection,
1016        query: TaskQuery,
1017    ) -> Result<u64, DurableError> {
1018        let mut select = Task::find();
1019
1020        if let Some(status) = &query.status {
1021            select = select.filter(TaskColumn::Status.eq(status.to_string()));
1022        }
1023        if let Some(kind) = &query.kind {
1024            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
1025        }
1026        if let Some(parent_id) = query.parent_id {
1027            select = select.filter(TaskColumn::ParentId.eq(parent_id));
1028        }
1029        if query.root_only {
1030            select = select.filter(TaskColumn::ParentId.is_null());
1031        }
1032        if let Some(name) = &query.name {
1033            select = select.filter(TaskColumn::Name.eq(name.as_str()));
1034        }
1035        if let Some(queue) = &query.queue_name {
1036            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
1037        }
1038
1039        let count = select.count(db).await?;
1040        Ok(count)
1041    }
1042
1043    // ── Accessors ────────────────────────────────────────────────
1044
    /// The database connection this context operates on.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }
1048
    /// The UUID of the task this context belongs to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }
1052
    /// Allocate the next step sequence number.
    ///
    /// Returns the counter's current value and atomically post-increments
    /// it, so successive calls yield strictly increasing numbers.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
1056
    /// Deserialize the stored input for this task.
    ///
    /// Returns `Ok(value)` when the task has a non-null `input` column that
    /// deserializes to `T`. Returns an error if the input is missing or
    /// cannot be deserialized.
    ///
    /// ```ignore
    /// let args: IngestInput = ctx.input().await?;
    /// ```
    pub async fn input<T: DeserializeOwned>(&self) -> Result<T, DurableError> {
        // Raw SQL keeps this to a single-column fetch; interpolating the id
        // is safe because `self.task_id` is a `Uuid` (no quotable characters).
        let row = self
            .db
            .query_one(Statement::from_string(
                DbBackend::Postgres,
                format!(
                    "SELECT input FROM durable.task WHERE id = '{}'",
                    self.task_id
                ),
            ))
            .await?
            .ok_or_else(|| {
                DurableError::custom(format!("task {} not found", self.task_id))
            })?;

        // Column 0 is `input`; a SQL NULL maps to `None`.
        let input_json: Option<serde_json::Value> = row
            .try_get_by_index(0)
            .map_err(|e| DurableError::custom(e.to_string()))?;

        // A task created without input cannot satisfy this call.
        let value = input_json.ok_or_else(|| {
            DurableError::custom(format!("task {} has no input", self.task_id))
        })?;

        serde_json::from_value(value)
            .map_err(|e| DurableError::custom(format!("failed to deserialize input: {e}")))
    }
1092}
1093
1094// ── Internal SQL helpers ─────────────────────────────────────────────
1095
1096/// Find an existing task by (parent_id, name) or create a new one.
1097///
1098/// Returns `(task_id, Option<saved_output>)`:
1099/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
1100/// - `saved_output` is `None` when the task is new or in-progress.
1101///
1102/// When `lock` is `true` and an existing non-completed task is found, this
1103/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
1104/// another worker holds the lock, `DurableError::StepLocked` is returned so
1105/// the caller can skip execution rather than double-firing side effects.
1106///
1107/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
1108/// and child-workflow creation where concurrent start is safe).
1109///
1110/// When `lock` is `true`, the caller MUST call this within a transaction.
1111/// The lock prevents concurrent creation/claiming of the same step. The
1112/// caller should set the step to RUNNING and commit promptly — the RUNNING
1113/// status itself prevents other executors from re-claiming the step.
1114///
1115/// `max_retries`: if Some, overrides the schema default when creating a new task.
1116#[allow(clippy::too_many_arguments)]
1117async fn find_or_create_task(
1118    db: &impl ConnectionTrait,
1119    parent_id: Option<Uuid>,
1120    sequence: Option<i32>,
1121    name: &str,
1122    kind: &str,
1123    input: Option<serde_json::Value>,
1124    lock: bool,
1125    max_retries: Option<u32>,
1126) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
1127    let parent_eq = match parent_id {
1128        Some(p) => format!("= '{p}'"),
1129        None => "IS NULL".to_string(),
1130    };
1131    let parent_sql = match parent_id {
1132        Some(p) => format!("'{p}'"),
1133        None => "NULL".to_string(),
1134    };
1135
1136    if lock {
1137        // Locking path (for steps): we need exactly-once execution.
1138        //
1139        // Strategy:
1140        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
1141        //    If another transaction is concurrently inserting the same row,
1142        //    Postgres will block here until that transaction commits or rolls
1143        //    back, ensuring we never see a phantom "not found" for a row being
1144        //    inserted.
1145        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
1146        //    should get it back; if the row existed and is locked by another
1147        //    worker we get nothing.
1148        // 3. If SKIP LOCKED returns empty, return StepLocked.
1149
1150        let new_id = Uuid::new_v4();
1151        let seq_sql = match sequence {
1152            Some(s) => s.to_string(),
1153            None => "NULL".to_string(),
1154        };
1155        let input_sql = match &input {
1156            Some(v) => format!("'{}'", serde_json::to_string(v)?),
1157            None => "NULL".to_string(),
1158        };
1159
1160        let max_retries_sql = match max_retries {
1161            Some(r) => r.to_string(),
1162            None => "3".to_string(), // schema default
1163        };
1164
1165        // Step 1: insert-or-skip
1166        let insert_sql = format!(
1167            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
1168             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING'{TS}, {input_sql}, {max_retries_sql}) \
1169             ON CONFLICT (parent_id, sequence) DO NOTHING"
1170        );
1171        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
1172            .await?;
1173
1174        // Step 2: lock the row (ours or pre-existing) by (parent_id, sequence)
1175        let lock_sql = format!(
1176            "SELECT id, status::text, output FROM durable.task \
1177             WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
1178             FOR UPDATE SKIP LOCKED"
1179        );
1180        let row = db
1181            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
1182            .await?;
1183
1184        if let Some(row) = row {
1185            let id: Uuid = row
1186                .try_get_by_index(0)
1187                .map_err(|e| DurableError::custom(e.to_string()))?;
1188            let status: String = row
1189                .try_get_by_index(1)
1190                .map_err(|e| DurableError::custom(e.to_string()))?;
1191            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
1192
1193            if status == TaskStatus::Completed.to_string() {
1194                // Replay path — return saved output
1195                return Ok((id, output));
1196            }
1197            if status == TaskStatus::Running.to_string() {
1198                // Another worker already claimed this step (set RUNNING and committed).
1199                // We got the row lock (no longer held in a long txn) but must not
1200                // double-execute.
1201                return Err(DurableError::StepLocked(name.to_string()));
1202            }
1203            // Task is PENDING — we're the first to claim it.
1204            return Ok((id, None));
1205        }
1206
1207        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
1208        Err(DurableError::StepLocked(name.to_string()))
1209    } else {
1210        // Plain find without locking — safe for workflow-level operations.
1211        // Multiple workers resuming the same workflow is fine; individual
1212        // steps will be locked when executed.
1213        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
1214        query = match parent_id {
1215            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
1216            None => query.filter(TaskColumn::ParentId.is_null()),
1217        };
1218        // Skip cancelled/failed tasks — a fresh row will be created instead.
1219        // Completed tasks are still returned so child tasks can replay saved output.
1220        let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
1221        let existing = query
1222            .filter(TaskColumn::Status.is_not_in(status_exclusions))
1223            .one(db)
1224            .await?;
1225
1226        if let Some(model) = existing {
1227            if model.status == TaskStatus::Completed {
1228                return Ok((model.id, model.output));
1229            }
1230            return Ok((model.id, None));
1231        }
1232
1233        // Task does not exist — create it
1234        let id = Uuid::new_v4();
1235        let new_task = TaskActiveModel {
1236            id: Set(id),
1237            parent_id: Set(parent_id),
1238            sequence: Set(sequence),
1239            name: Set(name.to_string()),
1240            kind: Set(kind.to_string()),
1241            status: Set(TaskStatus::Pending),
1242            input: Set(input),
1243            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1244            ..Default::default()
1245        };
1246        new_task.insert(db).await?;
1247
1248        Ok((id, None))
1249    }
1250}
1251
1252async fn get_output(
1253    db: &impl ConnectionTrait,
1254    task_id: Uuid,
1255) -> Result<Option<serde_json::Value>, DurableError> {
1256    let model = Task::find_by_id(task_id)
1257        .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1258        .one(db)
1259        .await?;
1260
1261    Ok(model.and_then(|m| m.output))
1262}
1263
1264async fn get_status(
1265    db: &impl ConnectionTrait,
1266    task_id: Uuid,
1267) -> Result<Option<TaskStatus>, DurableError> {
1268    let model = Task::find_by_id(task_id).one(db).await?;
1269
1270    Ok(model.map(|m| m.status))
1271}
1272
1273/// Returns (retry_count, max_retries) for a task.
1274async fn get_retry_info(
1275    db: &DatabaseConnection,
1276    task_id: Uuid,
1277) -> Result<(u32, u32), DurableError> {
1278    let model = Task::find_by_id(task_id).one(db).await?;
1279
1280    match model {
1281        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1282        None => Err(DurableError::custom(format!(
1283            "task {task_id} not found when reading retry info"
1284        ))),
1285    }
1286}
1287
1288/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
1289async fn increment_retry_count(
1290    db: &DatabaseConnection,
1291    task_id: Uuid,
1292) -> Result<u32, DurableError> {
1293    let model = Task::find_by_id(task_id).one(db).await?;
1294
1295    match model {
1296        Some(m) => {
1297            let new_count = m.retry_count + 1;
1298            let mut active: TaskActiveModel = m.into();
1299            active.retry_count = Set(new_count);
1300            active.status = Set(TaskStatus::Pending);
1301            active.error = Set(None);
1302            active.completed_at = Set(None);
1303            active.update(db).await?;
1304            Ok(new_count as u32)
1305        }
1306        None => Err(DurableError::custom(format!(
1307            "task {task_id} not found when incrementing retry count"
1308        ))),
1309    }
1310}
1311
// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
// for conditional deadline_epoch_ms computation or COALESCE on started_at.
/// Set a task's status, stamping `started_at` on the first transition and
/// computing `deadline_epoch_ms` when the task starts RUNNING with a timeout.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    // COALESCE keeps the original `started_at` once it has been set.
    // The CASE arm compares the interpolated status text against 'RUNNING';
    // for any other status the existing deadline is carried forward unchanged.
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}'{TS}, \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1334
1335/// Check if the task is paused or cancelled. Returns an error if so.
1336async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1337    let status = get_status(db, task_id).await?;
1338    match status {
1339        Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1340        Some(TaskStatus::Cancelled) => {
1341            Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1342        }
1343        _ => Ok(()),
1344    }
1345}
1346
1347/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
1348async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1349    let model = Task::find_by_id(task_id).one(db).await?;
1350
1351    if let Some(m) = model
1352        && let Some(deadline_ms) = m.deadline_epoch_ms
1353    {
1354        let now_ms = std::time::SystemTime::now()
1355            .duration_since(std::time::UNIX_EPOCH)
1356            .map(|d| d.as_millis() as i64)
1357            .unwrap_or(0);
1358        if now_ms > deadline_ms {
1359            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1360        }
1361    }
1362
1363    Ok(())
1364}
1365
1366async fn complete_task(
1367    db: &impl ConnectionTrait,
1368    task_id: Uuid,
1369    output: serde_json::Value,
1370) -> Result<(), DurableError> {
1371    let model = Task::find_by_id(task_id).one(db).await?;
1372
1373    if let Some(m) = model {
1374        let mut active: TaskActiveModel = m.into();
1375        active.status = Set(TaskStatus::Completed);
1376        active.output = Set(Some(output));
1377        active.completed_at = Set(Some(chrono::Utc::now().into()));
1378        active.update(db).await?;
1379    }
1380    Ok(())
1381}
1382
1383async fn fail_task(
1384    db: &impl ConnectionTrait,
1385    task_id: Uuid,
1386    error: &str,
1387) -> Result<(), DurableError> {
1388    let model = Task::find_by_id(task_id).one(db).await?;
1389
1390    if let Some(m) = model {
1391        let mut active: TaskActiveModel = m.into();
1392        active.status = Set(TaskStatus::Failed);
1393        active.error = Set(Some(error.to_string()));
1394        active.completed_at = Set(Some(chrono::Utc::now().into()));
1395        active.update(db).await?;
1396    }
1397    Ok(())
1398}
1399
#[cfg(test)]
mod tests {
    //! Unit tests for `retry_db_write`. No database is required — the
    //! closures simulate success/failure and a shared atomic counter
    //! records how many times the closure was invoked.
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// test_retry_db_write_succeeds_first_try: a closure that always succeeds
    /// should be called exactly once and return Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// test_retry_db_write_succeeds_after_transient_failure: a closure that
    /// fails twice then succeeds should return Ok and be called 3 times.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                // fetch_add returns the PRE-increment value: 0, 1, 2, ...
                // so attempts 0 and 1 fail and attempt 2 succeeds.
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// test_retry_db_write_exhausts_retries: a closure that always fails should
    /// be called 1 + MAX_CHECKPOINT_RETRIES times total then return an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        // 1 initial attempt + MAX_CHECKPOINT_RETRIES retry attempts
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    /// test_retry_db_write_returns_original_error: when all retries are
    /// exhausted the FIRST error is returned, not the last retry error.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                // Embed the attempt index in the message so we can tell
                // WHICH error came back.
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        // The message of the first error contains "error-0"
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}