// durable/ctx.rs — workflow/step context for the durable execution SDK.
1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9    DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10    Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
// ── SQL constants ────────────────────────────────────────────────────────────

/// Explicit cast suffix for the `task_status` enum in raw SQL.
///
/// PostgreSQL cannot resolve `task_status` if the `durable` schema is not in
/// the connection's `search_path`. Appending this cast to every status literal
/// makes queries work regardless of the caller's search_path configuration.
pub(crate) const TS: &str = "::durable.task_status";

// ── Retry constants ───────────────────────────────────────────────────────────

/// Number of retry attempts for checkpoint writes, after the initial attempt.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
/// Base backoff in milliseconds; doubled on each subsequent retry (100, 200, 400).
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
33
34/// Retry a fallible DB write with exponential backoff.
35///
36/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
37/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
38/// error if all retries are exhausted.
39async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
40where
41    F: FnMut() -> Fut,
42    Fut: std::future::Future<Output = Result<(), DurableError>>,
43{
44    match f().await {
45        Ok(()) => Ok(()),
46        Err(first_err) => {
47            for i in 0..MAX_CHECKPOINT_RETRIES {
48                tokio::time::sleep(Duration::from_millis(
49                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
50                ))
51                .await;
52                if f().await.is_ok() {
53                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
54                    return Ok(());
55                }
56            }
57            Err(first_err)
58        }
59    }
60}
61
/// Policy for retrying a step on failure.
pub struct RetryPolicy {
    pub max_retries: u32,
    pub initial_backoff: std::time::Duration,
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries — the first error is returned immediately.
    pub fn none() -> Self {
        // Zero attempts with a zero wait is just a degenerate fixed policy.
        Self::fixed(0, std::time::Duration::ZERO)
    }

    /// Exponential backoff: the wait doubles after every failed attempt.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Fixed backoff: the same wait between all attempts.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
97
/// Sort order for task listing.
pub enum TaskSort {
    /// Order by creation timestamp.
    CreatedAt(Order),
    /// Order by the time the task started running.
    StartedAt(Order),
    /// Order by completion timestamp.
    CompletedAt(Order),
    /// Order by task name.
    Name(Order),
    /// Order by task status.
    Status(Order),
}
106
/// Builder for filtering, sorting, and paginating task queries.
///
/// ```ignore
/// let query = TaskQuery::default()
///     .status(TaskStatus::Running)
///     .kind("WORKFLOW")
///     .root_only(true)
///     .sort(TaskSort::CreatedAt(Order::Desc))
///     .limit(20);
/// let tasks = Ctx::list(&db, query).await?;
/// ```
pub struct TaskQuery {
    /// Filter: only tasks with this status.
    pub status: Option<TaskStatus>,
    /// Filter: only tasks of this kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: Option<String>,
    /// Filter: only direct children of this parent task.
    pub parent_id: Option<Uuid>,
    /// Filter: when `true`, only root tasks (those with no parent).
    pub root_only: bool,
    /// Filter: only tasks with this exact name.
    pub name: Option<String>,
    /// Filter: only tasks on this queue.
    pub queue_name: Option<String>,
    /// Sort order applied to the results.
    pub sort: TaskSort,
    /// Maximum number of rows to return.
    pub limit: Option<u64>,
    /// Number of rows to skip (for pagination).
    pub offset: Option<u64>,
}
129
130impl Default for TaskQuery {
131    fn default() -> Self {
132        Self {
133            status: None,
134            kind: None,
135            parent_id: None,
136            root_only: false,
137            name: None,
138            queue_name: None,
139            sort: TaskSort::CreatedAt(Order::Desc),
140            limit: None,
141            offset: None,
142        }
143    }
144}
145
146impl TaskQuery {
147    /// Filter by status.
148    pub fn status(mut self, status: TaskStatus) -> Self {
149        self.status = Some(status);
150        self
151    }
152
153    /// Filter by kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
154    pub fn kind(mut self, kind: &str) -> Self {
155        self.kind = Some(kind.to_string());
156        self
157    }
158
159    /// Filter by parent task ID (direct children only).
160    pub fn parent_id(mut self, parent_id: Uuid) -> Self {
161        self.parent_id = Some(parent_id);
162        self
163    }
164
165    /// Only return root tasks (no parent).
166    pub fn root_only(mut self, root_only: bool) -> Self {
167        self.root_only = root_only;
168        self
169    }
170
171    /// Filter by task name.
172    pub fn name(mut self, name: &str) -> Self {
173        self.name = Some(name.to_string());
174        self
175    }
176
177    /// Filter by queue name.
178    pub fn queue_name(mut self, queue: &str) -> Self {
179        self.queue_name = Some(queue.to_string());
180        self
181    }
182
183    /// Set the sort order.
184    pub fn sort(mut self, sort: TaskSort) -> Self {
185        self.sort = sort;
186        self
187    }
188
189    /// Limit the number of results.
190    pub fn limit(mut self, limit: u64) -> Self {
191        self.limit = Some(limit);
192        self
193    }
194
195    /// Skip the first N results.
196    pub fn offset(mut self, offset: u64) -> Self {
197        self.offset = Some(offset);
198        self
199    }
200}
201
/// Summary of a task returned by `Ctx::list()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    pub id: Uuid,
    /// Parent task, if this is a child/step (None for root workflows).
    pub parent_id: Option<Uuid>,
    pub name: String,
    /// Registered handler function name, if recorded at start (used for recovery).
    pub handler: Option<String>,
    pub status: TaskStatus,
    /// Task kind, e.g. "WORKFLOW", "STEP", "TRANSACTION".
    pub kind: String,
    pub input: Option<serde_json::Value>,
    /// Saved output, present once the task completed.
    pub output: Option<serde_json::Value>,
    /// Error message, present if the task failed.
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
219
impl From<durable_db::entity::task::Model> for TaskSummary {
    // Straight field-by-field move from the entity model; no conversion logic.
    fn from(m: durable_db::entity::task::Model) -> Self {
        Self {
            id: m.id,
            parent_id: m.parent_id,
            name: m.name,
            handler: m.handler,
            status: m.status,
            kind: m.kind,
            input: m.input,
            output: m.output,
            error: m.error,
            queue_name: m.queue_name,
            created_at: m.created_at,
            started_at: m.started_at,
            completed_at: m.completed_at,
        }
    }
}
239
/// Result of [`Ctx::start`] / [`Ctx::start_with_handler`] / [`Ctx::start_with_timeout`].
///
/// Distinguishes whether a new workflow was **created** from scratch or an
/// existing RUNNING workflow was **attached to** (idempotent start).
///
/// ```ignore
/// match Ctx::start(&db, "my-workflow", None).await? {
///     StartResult::Created(ctx) => { /* fresh run */ },
///     StartResult::Attached(ctx) => { /* resumed existing RUNNING task */ },
/// }
///
/// // Or just unwrap the Ctx when you don't care which variant:
/// let ctx = Ctx::start(&db, "my-workflow", None).await?.into_ctx();
/// ```
pub enum StartResult {
    /// A fresh workflow task was inserted and started.
    Created(Ctx),
    /// An existing RUNNING workflow with the same name was found and reused.
    Attached(Ctx),
}
260
261impl StartResult {
262    /// Unwrap the inner [`Ctx`], discarding the variant information.
263    pub fn into_ctx(self) -> Ctx {
264        match self {
265            StartResult::Created(ctx) | StartResult::Attached(ctx) => ctx,
266        }
267    }
268
269    /// Borrow the inner [`Ctx`] without consuming `self`.
270    pub fn ctx(&self) -> &Ctx {
271        match self {
272            StartResult::Created(ctx) | StartResult::Attached(ctx) => ctx,
273        }
274    }
275
276    /// Returns `true` if a new workflow task was created.
277    pub fn is_created(&self) -> bool {
278        matches!(self, StartResult::Created(_))
279    }
280
281    /// Returns `true` if an existing RUNNING workflow was reused.
282    pub fn is_attached(&self) -> bool {
283        matches!(self, StartResult::Attached(_))
284    }
285}
286
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Database handle used for all checkpoint reads/writes; cloned per child ctx.
    db: DatabaseConnection,
    // The task this context is scoped to (root workflow or child).
    task_id: Uuid,
    // Monotonic per-context step counter; every step/child call claims one slot.
    sequence: AtomicI32,
    // Executor identity recorded for crash recovery; None if init() was never called.
    executor_id: Option<String>,
}
298
299impl Ctx {
300    // ── Workflow lifecycle (user-facing) ──────────────────────────
301
    /// Start a new root workflow (or attach to an existing one with the same name).
    ///
    /// If a RUNNING root task with the same `name` already exists (e.g. recovered
    /// after a crash), returns a `Ctx` attached to that task (idempotent start).
    /// Otherwise creates a fresh task.
    ///
    /// If [`init`](crate::init) was called, the task is automatically tagged
    /// with the executor ID for crash recovery. Otherwise `executor_id` is NULL.
    ///
    /// To reactivate a paused workflow, use [`Ctx::resume`] instead.
    ///
    /// ```ignore
    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
    /// ```
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<StartResult, DurableError> {
        // Delegates to the general form with no handler recorded.
        Self::start_with_handler(db, name, input, None).await
    }
323
324    /// Like [`start`](Self::start) but records the handler function name for
325    /// crash recovery. The `handler` is stored in the `durable.task` row so
326    /// that [`dispatch_recovered`](crate::dispatch_recovered) can look up the
327    /// registered `#[durable::workflow]` function even when `name` differs.
328    ///
329    /// Typically called from macro-generated `start_<fn>` functions rather
330    /// than directly.
331    pub async fn start_with_handler(
332        db: &DatabaseConnection,
333        name: &str,
334        input: Option<serde_json::Value>,
335        handler: Option<&str>,
336    ) -> Result<StartResult, DurableError> {
337        // ── Idempotent: reuse existing RUNNING root task with same name ──
338        let existing_sql = format!(
339            "SELECT id FROM durable.task \
340             WHERE name = '{}' AND parent_id IS NULL AND status = 'RUNNING'{TS} \
341             LIMIT 1",
342            name
343        );
344        if let Some(row) = db
345            .query_one(Statement::from_string(DbBackend::Postgres, existing_sql))
346            .await?
347        {
348            let existing_id: Uuid = row
349                .try_get_by_index(0)
350                .map_err(|e| DurableError::custom(e.to_string()))?;
351            tracing::info!(
352                workflow = name,
353                id = %existing_id,
354                "idempotent start: attaching to existing RUNNING task"
355            );
356            return Self::from_id(db, existing_id)
357                .await
358                .map(StartResult::Attached);
359        }
360
361        let task_id = Uuid::new_v4();
362        let input_json = match &input {
363            Some(v) => serde_json::to_string(v)?,
364            None => "null".to_string(),
365        };
366
367        let executor_id = crate::executor_id();
368
369        let mut extra_cols = String::new();
370        let mut extra_vals = String::new();
371
372        if let Some(eid) = &executor_id {
373            extra_cols.push_str(", executor_id");
374            extra_vals.push_str(&format!(", '{eid}'"));
375        }
376        if let Some(h) = handler {
377            extra_cols.push_str(", handler");
378            extra_vals.push_str(&format!(", '{h}'"));
379        }
380
381        let txn = db.begin().await?;
382        let sql = format!(
383            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{extra_cols}) \
384             VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING'{TS}, '{input_json}', now(){extra_vals})"
385        );
386        txn.execute(Statement::from_string(DbBackend::Postgres, sql))
387            .await?;
388        txn.commit().await?;
389
390        Ok(StartResult::Created(Self {
391            db: db.clone(),
392            task_id,
393            sequence: AtomicI32::new(0),
394            executor_id,
395        }))
396    }
397
398    /// Attach to an existing workflow by task ID. Used to resume a running or
399    /// recovered workflow without creating a new task row.
400    ///
401    /// ```ignore
402    /// let ctx = Ctx::from_id(&db, task_id).await?;
403    /// ```
404    pub async fn from_id(db: &DatabaseConnection, task_id: Uuid) -> Result<Self, DurableError> {
405        // Verify the task exists
406        let model = Task::find_by_id(task_id).one(db).await?;
407        let _model =
408            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
409
410        // Claim the task with the current executor_id (if init() was called)
411        let executor_id = crate::executor_id();
412        if let Some(eid) = &executor_id {
413            db.execute(Statement::from_string(
414                DbBackend::Postgres,
415                format!("UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"),
416            ))
417            .await?;
418        }
419
420        // Start sequence at 0 so that replaying steps from the beginning works
421        // correctly. Steps are looked up by (parent_id, sequence), so the caller
422        // will replay completed steps (getting saved outputs) before executing
423        // any new steps.
424        Ok(Self {
425            db: db.clone(),
426            task_id,
427            sequence: AtomicI32::new(0),
428            executor_id,
429        })
430    }
431
    /// Run a step. If already completed, returns saved output. Otherwise executes
    /// the closure, saves the result, and returns it.
    ///
    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
    ///
    /// ```ignore
    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
    /// ```
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim the next sequence slot up front — replay identity is
        // (parent_id, sequence), so each step call must consume exactly one.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Check if parent task's deadline has passed before executing
        check_deadline(&self.db, self.task_id).await?;

        // TX1: Find or create the step row, set RUNNING, release connection.
        // The RUNNING status acts as a mutex — recovery won't reclaim it
        // unless this executor's heartbeat goes stale.
        let txn = self.db.begin().await?;
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,    // plain steps record no input
            true,    // presumably requests a row lock — confirm in find_or_create_task
            Some(0), // max_retries = 0: this variant never retries
        )
        .await?;

        // Previously checkpointed: commit TX1 and replay the saved output.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // NOTE(review): retry_db_write re-invokes the closure on the SAME txn;
        // if the first set_status attempt aborts the transaction, subsequent
        // attempts on it will likely fail too — confirm this is intended.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        txn.commit().await?;
        // ── connection returned to pool ──

        // Execute the closure with no DB connection held.
        let result = f().await;

        // TX2: Checkpoint the result (short transaction).
        match result {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&self.db, step_id, json.clone())).await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Record the failure so the step row gets a terminal state.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&self.db, step_id, &err_msg)).await?;
                Err(e)
            }
        }
    }
499
    /// Run a DB-only step inside a single Postgres transaction.
    ///
    /// Both the user's DB work and the checkpoint save happen in the same transaction,
    /// ensuring atomicity. If the closure returns an error, both the user writes and
    /// the checkpoint are rolled back.
    ///
    /// ```ignore
    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
    ///     do_db_work(tx).await
    /// })).await?;
    /// ```
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        // Each call consumes one sequence slot, same as `step()`.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create the step task record OUTSIDE the transaction.
        // This is idempotent (UNIQUE constraint) and must exist before we begin,
        // so the row survives even if the user transaction rolls back.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // If already completed, replay from saved output.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Rollback happens automatically when tx is dropped.
                // Record the failure on the main connection (outside the rolled-back tx)
                // so the step row still ends in a terminal FAILED state.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
568
    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
    ///
    /// ```ignore
    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
    /// // use child_ctx.step(...) for steps inside the child
    /// child_ctx.complete(json!({"done": true})).await?;
    /// ```
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        // The child occupies one sequence slot in the parent's replay order.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // Child workflows also use plain find-or-create without locking.
        // The saved output (if any) is deliberately ignored here.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        // If child already completed, return a Ctx that will replay
        // (the caller should check is_completed() or just run steps which will replay)
        // NOTE(review): set_status runs unconditionally, even when `_saved`
        // indicates the child already completed — confirm set_status does not
        // overwrite a terminal COMPLETED status back to RUNNING.
        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        // The child's sequence restarts at 0 so its own steps replay correctly.
        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
            executor_id: self.executor_id.clone(),
        })
    }
612
613    /// Run a step that can execute concurrently with other concurrent steps.
614    ///
615    /// Same semantics as `step()` — idempotent, checkpointed, replayable — but
616    /// designed to be composed with `tokio::try_join!` for parallel execution.
617    /// The child ctx is an internal implementation detail, not exposed to the caller.
618    ///
619    /// ```ignore
620    /// let (a, b) = tokio::try_join!(
621    ///     ctx.concurrent_step("fetch_a", || async { fetch_a().await }),
622    ///     ctx.concurrent_step("fetch_b", || async { fetch_b().await }),
623    /// )?;
624    /// ```
625    pub async fn concurrent_step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
626    where
627        T: Serialize + DeserializeOwned,
628        F: FnOnce() -> Fut,
629        Fut: std::future::Future<Output = Result<T, DurableError>>,
630    {
631        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
632
633        // Check if the parent workflow is paused or cancelled.
634        check_status(&self.db, self.task_id).await?;
635
636        // Find or create the child task without holding a row lock (same as child()).
637        let (child_id, saved_output) = find_or_create_task(
638            &self.db,
639            Some(self.task_id),
640            Some(seq),
641            name,
642            "WORKFLOW",
643            None,
644            false,
645            None,
646        )
647        .await?;
648
649        // If already completed, replay saved output without touching the task row.
650        if let Some(output) = saved_output {
651            let val: T = serde_json::from_value(output)?;
652            tracing::debug!(step = name, seq, "concurrent_step: replaying saved output");
653            return Ok(val);
654        }
655
656        // Not yet completed — set to RUNNING, execute, checkpoint result.
657        retry_db_write(|| set_status(&self.db, child_id, TaskStatus::Running)).await?;
658
659        let result = f().await?;
660
661        let json = serde_json::to_value(&result)?;
662        retry_db_write(|| complete_task(&self.db, child_id, json.clone())).await?;
663        tracing::debug!(step = name, seq, "concurrent_step: completed");
664        Ok(result)
665    }
666
667    /// Check if this workflow/child was already completed (for skipping in parent).
668    pub async fn is_completed(&self) -> Result<bool, DurableError> {
669        let status = get_status(&self.db, self.task_id).await?;
670        Ok(status == Some(TaskStatus::Completed))
671    }
672
673    /// Get the saved output if this task is completed.
674    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
675        match get_output(&self.db, self.task_id).await? {
676            Some(val) => Ok(Some(serde_json::from_value(val)?)),
677            None => Ok(None),
678        }
679    }
680
681    /// Mark this workflow as completed with an output value.
682    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
683        let json = serde_json::to_value(output)?;
684        let db = &self.db;
685        let task_id = self.task_id;
686        retry_db_write(|| complete_task(db, task_id, json.clone())).await
687    }
688
    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // One sequence slot per step call, as in `step()`.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts).
        // NOTE(review): max_retries comes from the DB row, which may have been
        // created by an earlier run with a different policy — confirm intended.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop — each iteration is one attempt.
        loop {
            // Re-check status before each retry attempt
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration.
                        // multiplier^(retry_count-1), clamped to >= 1.0 so the
                        // wait never shrinks below the initial backoff.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
798
799    /// Mark this workflow as failed.
800    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
801        let db = &self.db;
802        let task_id = self.task_id;
803        retry_db_write(|| fail_task(db, task_id, error)).await
804    }
805
    /// Mark a workflow as failed by task ID (static version for use outside a Ctx).
    pub async fn fail_by_id(
        db: &DatabaseConnection,
        task_id: Uuid,
        error: &str,
    ) -> Result<(), DurableError> {
        // Same retry/backoff semantics as `Ctx::fail`.
        retry_db_write(|| fail_task(db, task_id, error)).await
    }
814
    /// Set the timeout for this task in milliseconds.
    ///
    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
    /// for recovery by `Executor::recover()`.
    ///
    /// If the task is currently RUNNING and has a `started_at`, this also
    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // Only an i64 and a Uuid are interpolated here, so the raw SQL cannot
        // be injected into; the {TS} cast keeps the enum comparison valid
        // regardless of search_path.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
                 deadline_epoch_ms = CASE \
                     WHEN status = 'RUNNING'{TS} AND started_at IS NOT NULL \
                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                     ELSE deadline_epoch_ms \
                 END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
839
840    /// Start or resume a root workflow with a timeout in milliseconds.
841    ///
842    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
843    pub async fn start_with_timeout(
844        db: &DatabaseConnection,
845        name: &str,
846        input: Option<serde_json::Value>,
847        timeout_ms: i64,
848    ) -> Result<StartResult, DurableError> {
849        let result = Self::start(db, name, input).await?;
850        result.ctx().set_timeout(timeout_ms).await?;
851        Ok(result)
852    }
853
854    // ── Workflow control (management API) ─────────────────────────
855
856    /// Pause a workflow by ID. Sets status to PAUSED and recursively
857    /// cascades to all PENDING/RUNNING descendants (children, grandchildren, etc.).
858    ///
859    /// Only workflows in PENDING or RUNNING status can be paused.
860    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
861        let model = Task::find_by_id(task_id).one(db).await?;
862        let model =
863            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
864
865        match model.status {
866            TaskStatus::Pending | TaskStatus::Running => {}
867            status => {
868                return Err(DurableError::custom(format!(
869                    "cannot pause task in {status} status"
870                )));
871            }
872        }
873
874        // Pause the task itself and all descendants in one recursive CTE
875        let sql = format!(
876            "WITH RECURSIVE descendants AS ( \
877                 SELECT id FROM durable.task WHERE id = '{task_id}' \
878                 UNION ALL \
879                 SELECT t.id FROM durable.task t \
880                 INNER JOIN descendants d ON t.parent_id = d.id \
881             ) \
882             UPDATE durable.task SET status = 'PAUSED'{TS} \
883             WHERE id IN (SELECT id FROM descendants) \
884               AND status IN ('PENDING'{TS}, 'RUNNING'{TS})"
885        );
886        db.execute(Statement::from_string(DbBackend::Postgres, sql))
887            .await?;
888
889        tracing::info!(%task_id, "workflow paused");
890        Ok(())
891    }
892
893    /// Resume a paused workflow by ID. Sets status back to RUNNING and
894    /// recursively cascades to all PAUSED descendants (resetting them to PENDING).
895    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
896        let model = Task::find_by_id(task_id).one(db).await?;
897        let model =
898            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
899
900        if model.status != TaskStatus::Paused {
901            return Err(DurableError::custom(format!(
902                "cannot resume task in {} status (must be PAUSED)",
903                model.status
904            )));
905        }
906
907        // Resume the root task back to RUNNING
908        let mut active: TaskActiveModel = model.into();
909        active.status = Set(TaskStatus::Running);
910        active.update(db).await?;
911
912        // Recursively resume all PAUSED descendants back to PENDING
913        let cascade_sql = format!(
914            "WITH RECURSIVE descendants AS ( \
915                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
916                 UNION ALL \
917                 SELECT t.id FROM durable.task t \
918                 INNER JOIN descendants d ON t.parent_id = d.id \
919             ) \
920             UPDATE durable.task SET status = 'PENDING'{TS} \
921             WHERE id IN (SELECT id FROM descendants) \
922               AND status = 'PAUSED'{TS}"
923        );
924        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
925            .await?;
926
927        tracing::info!(%task_id, "workflow resumed");
928        Ok(())
929    }
930
931    /// Resume a failed workflow from its last completed step (DBOS-style).
932    ///
933    /// Resets the workflow from FAILED to RUNNING, resets all FAILED child
934    /// steps to PENDING (with `retry_count` zeroed), and increments
935    /// `recovery_count`. Completed steps are left untouched so they replay
936    /// from their checkpointed outputs.
937    ///
938    /// Returns the workflow's `handler` name (if set) so the caller can
939    /// look up the `WorkflowRegistration` and re-dispatch.
940    ///
941    /// Fails with `MaxRecoveryExceeded` if the workflow has already been
942    /// resumed `max_recovery_attempts` times.
943    pub async fn resume_failed(
944        db: &DatabaseConnection,
945        task_id: Uuid,
946    ) -> Result<Option<String>, DurableError> {
947        // Validate the workflow exists and is FAILED
948        let model = Task::find_by_id(task_id).one(db).await?;
949        let model =
950            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
951
952        if model.status != TaskStatus::Failed {
953            return Err(DurableError::custom(format!(
954                "cannot resume task in {} status (must be FAILED)",
955                model.status
956            )));
957        }
958
959        if model.kind != "WORKFLOW" {
960            return Err(DurableError::custom(format!(
961                "cannot resume a {}, only workflows can be resumed",
962                model.kind
963            )));
964        }
965
966        // Check recovery limit
967        if model.recovery_count >= model.max_recovery_attempts {
968            return Err(DurableError::MaxRecoveryExceeded(task_id.to_string()));
969        }
970
971        let executor_id = crate::executor_id().unwrap_or_default();
972
973        // Atomically: reset workflow to RUNNING, increment recovery_count
974        let reset_workflow_sql = format!(
975            "UPDATE durable.task \
976             SET status = 'RUNNING'{TS}, \
977                 error = NULL, \
978                 completed_at = NULL, \
979                 started_at = now(), \
980                 executor_id = '{executor_id}', \
981                 recovery_count = recovery_count + 1 \
982             WHERE id = '{task_id}' AND status = 'FAILED'{TS}"
983        );
984        db.execute(Statement::from_string(
985            DbBackend::Postgres,
986            reset_workflow_sql,
987        ))
988        .await?;
989
990        // Reset FAILED descendant steps to PENDING with cleared retry state.
991        // Completed steps are preserved for replay.
992        let reset_steps_sql = format!(
993            "WITH RECURSIVE descendants AS ( \
994                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
995                 UNION ALL \
996                 SELECT t.id FROM durable.task t \
997                 INNER JOIN descendants d ON t.parent_id = d.id \
998             ) \
999             UPDATE durable.task \
1000             SET status = 'PENDING'{TS}, \
1001                 error = NULL, \
1002                 retry_count = 0, \
1003                 completed_at = NULL, \
1004                 started_at = NULL \
1005             WHERE id IN (SELECT id FROM descendants) \
1006               AND status IN ('FAILED'{TS}, 'RUNNING'{TS})"
1007        );
1008        db.execute(Statement::from_string(DbBackend::Postgres, reset_steps_sql))
1009            .await?;
1010
1011        tracing::info!(
1012            %task_id,
1013            recovery_count = model.recovery_count + 1,
1014            "failed workflow resumed"
1015        );
1016
1017        Ok(model.handler.clone())
1018    }
1019
1020    /// Cancel a workflow by ID. Sets status to CANCELLED and recursively
1021    /// cascades to all non-terminal descendants.
1022    ///
1023    /// Cancellation is terminal — a cancelled workflow cannot be resumed.
1024    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1025        let model = Task::find_by_id(task_id).one(db).await?;
1026        let model =
1027            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
1028
1029        match model.status {
1030            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
1031                return Err(DurableError::custom(format!(
1032                    "cannot cancel task in {} status",
1033                    model.status
1034                )));
1035            }
1036            _ => {}
1037        }
1038
1039        // Cancel the task itself and all non-terminal descendants in one recursive CTE
1040        let sql = format!(
1041            "WITH RECURSIVE descendants AS ( \
1042                 SELECT id FROM durable.task WHERE id = '{task_id}' \
1043                 UNION ALL \
1044                 SELECT t.id FROM durable.task t \
1045                 INNER JOIN descendants d ON t.parent_id = d.id \
1046             ) \
1047             UPDATE durable.task SET status = 'CANCELLED'{TS}, completed_at = now() \
1048             WHERE id IN (SELECT id FROM descendants) \
1049               AND status NOT IN ('COMPLETED'{TS}, 'FAILED'{TS}, 'CANCELLED'{TS})"
1050        );
1051        db.execute(Statement::from_string(DbBackend::Postgres, sql))
1052            .await?;
1053
1054        tracing::info!(%task_id, "workflow cancelled");
1055        Ok(())
1056    }
1057
1058    // ── Query API ─────────────────────────────────────────────────
1059
1060    /// List tasks matching the given filter, with sorting and pagination.
1061    ///
1062    /// ```ignore
1063    /// let tasks = Ctx::list(&db, TaskQuery::default().status("RUNNING").limit(10)).await?;
1064    /// ```
1065    pub async fn list(
1066        db: &DatabaseConnection,
1067        query: TaskQuery,
1068    ) -> Result<Vec<TaskSummary>, DurableError> {
1069        let mut select = Task::find();
1070
1071        // Filters
1072        if let Some(status) = &query.status {
1073            select = select.filter(TaskColumn::Status.eq(status.to_string()));
1074        }
1075        if let Some(kind) = &query.kind {
1076            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
1077        }
1078        if let Some(parent_id) = query.parent_id {
1079            select = select.filter(TaskColumn::ParentId.eq(parent_id));
1080        }
1081        if query.root_only {
1082            select = select.filter(TaskColumn::ParentId.is_null());
1083        }
1084        if let Some(name) = &query.name {
1085            select = select.filter(TaskColumn::Name.eq(name.as_str()));
1086        }
1087        if let Some(queue) = &query.queue_name {
1088            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
1089        }
1090
1091        // Sorting
1092        let (col, order) = match query.sort {
1093            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
1094            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
1095            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
1096            TaskSort::Name(ord) => (TaskColumn::Name, ord),
1097            TaskSort::Status(ord) => (TaskColumn::Status, ord),
1098        };
1099        select = select.order_by(col, order);
1100
1101        // Pagination
1102        if let Some(offset) = query.offset {
1103            select = select.offset(offset);
1104        }
1105        if let Some(limit) = query.limit {
1106            select = select.limit(limit);
1107        }
1108
1109        let models = select.all(db).await?;
1110
1111        Ok(models.into_iter().map(TaskSummary::from).collect())
1112    }
1113
1114    /// Count tasks matching the given filter.
1115    pub async fn count(db: &DatabaseConnection, query: TaskQuery) -> Result<u64, DurableError> {
1116        let mut select = Task::find();
1117
1118        if let Some(status) = &query.status {
1119            select = select.filter(TaskColumn::Status.eq(status.to_string()));
1120        }
1121        if let Some(kind) = &query.kind {
1122            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
1123        }
1124        if let Some(parent_id) = query.parent_id {
1125            select = select.filter(TaskColumn::ParentId.eq(parent_id));
1126        }
1127        if query.root_only {
1128            select = select.filter(TaskColumn::ParentId.is_null());
1129        }
1130        if let Some(name) = &query.name {
1131            select = select.filter(TaskColumn::Name.eq(name.as_str()));
1132        }
1133        if let Some(queue) = &query.queue_name {
1134            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
1135        }
1136
1137        let count = select.count(db).await?;
1138        Ok(count)
1139    }
1140
1141    // ── Accessors ────────────────────────────────────────────────
1142
    /// Borrow the database connection this context operates on.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// The ID of the task this context is bound to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Hand out the next step sequence number (atomic post-increment, so
    /// concurrent callers sharing this Ctx never observe duplicates).
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
1154
1155    /// Deserialize the stored input for this task.
1156    ///
1157    /// Returns `Ok(value)` when the task has a non-null `input` column that
1158    /// deserializes to `T`. Returns an error if the input is missing or
1159    /// cannot be deserialized.
1160    ///
1161    /// ```ignore
1162    /// let args: IngestInput = ctx.input().await?;
1163    /// ```
1164    pub async fn input<T: DeserializeOwned>(&self) -> Result<T, DurableError> {
1165        let row = self
1166            .db
1167            .query_one(Statement::from_string(
1168                DbBackend::Postgres,
1169                format!(
1170                    "SELECT input FROM durable.task WHERE id = '{}'",
1171                    self.task_id
1172                ),
1173            ))
1174            .await?
1175            .ok_or_else(|| DurableError::custom(format!("task {} not found", self.task_id)))?;
1176
1177        let input_json: Option<serde_json::Value> = row
1178            .try_get_by_index(0)
1179            .map_err(|e| DurableError::custom(e.to_string()))?;
1180
1181        let value = input_json
1182            .ok_or_else(|| DurableError::custom(format!("task {} has no input", self.task_id)))?;
1183
1184        serde_json::from_value(value)
1185            .map_err(|e| DurableError::custom(format!("failed to deserialize input: {e}")))
1186    }
1187}
1188
1189// ── Internal SQL helpers ─────────────────────────────────────────────
1190
1191/// Find an existing task by (parent_id, name) or create a new one.
1192///
1193/// Returns `(task_id, Option<saved_output>)`:
1194/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
1195/// - `saved_output` is `None` when the task is new or in-progress.
1196///
1197/// When `lock` is `true` and an existing non-completed task is found, this
1198/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
1199/// another worker holds the lock, `DurableError::StepLocked` is returned so
1200/// the caller can skip execution rather than double-firing side effects.
1201///
1202/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
1203/// and child-workflow creation where concurrent start is safe).
1204///
1205/// When `lock` is `true`, the caller MUST call this within a transaction.
1206/// The lock prevents concurrent creation/claiming of the same step. The
1207/// caller should set the step to RUNNING and commit promptly — the RUNNING
1208/// status itself prevents other executors from re-claiming the step.
1209///
1210/// `max_retries`: if Some, overrides the schema default when creating a new task.
1211#[allow(clippy::too_many_arguments)]
1212async fn find_or_create_task(
1213    db: &impl ConnectionTrait,
1214    parent_id: Option<Uuid>,
1215    sequence: Option<i32>,
1216    name: &str,
1217    kind: &str,
1218    input: Option<serde_json::Value>,
1219    lock: bool,
1220    max_retries: Option<u32>,
1221) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
1222    let parent_eq = match parent_id {
1223        Some(p) => format!("= '{p}'"),
1224        None => "IS NULL".to_string(),
1225    };
1226    let parent_sql = match parent_id {
1227        Some(p) => format!("'{p}'"),
1228        None => "NULL".to_string(),
1229    };
1230
1231    if lock {
1232        // Locking path (for steps): we need exactly-once execution.
1233        //
1234        // Strategy:
1235        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
1236        //    If another transaction is concurrently inserting the same row,
1237        //    Postgres will block here until that transaction commits or rolls
1238        //    back, ensuring we never see a phantom "not found" for a row being
1239        //    inserted.
1240        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
1241        //    should get it back; if the row existed and is locked by another
1242        //    worker we get nothing.
1243        // 3. If SKIP LOCKED returns empty, return StepLocked.
1244
1245        let new_id = Uuid::new_v4();
1246        let seq_sql = match sequence {
1247            Some(s) => s.to_string(),
1248            None => "NULL".to_string(),
1249        };
1250        let input_sql = match &input {
1251            Some(v) => format!("'{}'", serde_json::to_string(v)?),
1252            None => "NULL".to_string(),
1253        };
1254
1255        let max_retries_sql = match max_retries {
1256            Some(r) => r.to_string(),
1257            None => "3".to_string(), // schema default
1258        };
1259
1260        // Step 1: insert-or-skip
1261        let insert_sql = format!(
1262            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
1263             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING'{TS}, {input_sql}, {max_retries_sql}) \
1264             ON CONFLICT (parent_id, sequence) DO NOTHING"
1265        );
1266        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
1267            .await?;
1268
1269        // Step 2: lock the row (ours or pre-existing) by (parent_id, sequence)
1270        let lock_sql = format!(
1271            "SELECT id, status::text, output FROM durable.task \
1272             WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
1273             FOR UPDATE SKIP LOCKED"
1274        );
1275        let row = db
1276            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
1277            .await?;
1278
1279        if let Some(row) = row {
1280            let id: Uuid = row
1281                .try_get_by_index(0)
1282                .map_err(|e| DurableError::custom(e.to_string()))?;
1283            let status: String = row
1284                .try_get_by_index(1)
1285                .map_err(|e| DurableError::custom(e.to_string()))?;
1286            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
1287
1288            if status == TaskStatus::Completed.to_string() {
1289                // Replay path — return saved output
1290                return Ok((id, output));
1291            }
1292            if status == TaskStatus::Running.to_string() {
1293                // Another worker already claimed this step (set RUNNING and committed).
1294                // We got the row lock (no longer held in a long txn) but must not
1295                // double-execute.
1296                return Err(DurableError::StepLocked(name.to_string()));
1297            }
1298            // Task is PENDING — we're the first to claim it.
1299            return Ok((id, None));
1300        }
1301
1302        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
1303        Err(DurableError::StepLocked(name.to_string()))
1304    } else {
1305        // Plain find without locking — safe for workflow-level operations.
1306        // Multiple workers resuming the same workflow is fine; individual
1307        // steps will be locked when executed.
1308        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
1309        query = match parent_id {
1310            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
1311            None => query.filter(TaskColumn::ParentId.is_null()),
1312        };
1313        // Skip cancelled/failed tasks — a fresh row will be created instead.
1314        // Completed tasks are still returned so child tasks can replay saved output.
1315        let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
1316        let existing = query
1317            .filter(TaskColumn::Status.is_not_in(status_exclusions))
1318            .one(db)
1319            .await?;
1320
1321        if let Some(model) = existing {
1322            if model.status == TaskStatus::Completed {
1323                return Ok((model.id, model.output));
1324            }
1325            return Ok((model.id, None));
1326        }
1327
1328        // Task does not exist — create it
1329        let id = Uuid::new_v4();
1330        let new_task = TaskActiveModel {
1331            id: Set(id),
1332            parent_id: Set(parent_id),
1333            sequence: Set(sequence),
1334            name: Set(name.to_string()),
1335            kind: Set(kind.to_string()),
1336            status: Set(TaskStatus::Pending),
1337            input: Set(input),
1338            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1339            ..Default::default()
1340        };
1341        new_task.insert(db).await?;
1342
1343        Ok((id, None))
1344    }
1345}
1346
1347async fn get_output(
1348    db: &impl ConnectionTrait,
1349    task_id: Uuid,
1350) -> Result<Option<serde_json::Value>, DurableError> {
1351    let model = Task::find_by_id(task_id)
1352        .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1353        .one(db)
1354        .await?;
1355
1356    Ok(model.and_then(|m| m.output))
1357}
1358
1359async fn get_status(
1360    db: &impl ConnectionTrait,
1361    task_id: Uuid,
1362) -> Result<Option<TaskStatus>, DurableError> {
1363    let model = Task::find_by_id(task_id).one(db).await?;
1364
1365    Ok(model.map(|m| m.status))
1366}
1367
1368/// Returns (retry_count, max_retries) for a task.
1369async fn get_retry_info(
1370    db: &DatabaseConnection,
1371    task_id: Uuid,
1372) -> Result<(u32, u32), DurableError> {
1373    let model = Task::find_by_id(task_id).one(db).await?;
1374
1375    match model {
1376        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1377        None => Err(DurableError::custom(format!(
1378            "task {task_id} not found when reading retry info"
1379        ))),
1380    }
1381}
1382
1383/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
1384async fn increment_retry_count(
1385    db: &DatabaseConnection,
1386    task_id: Uuid,
1387) -> Result<u32, DurableError> {
1388    let model = Task::find_by_id(task_id).one(db).await?;
1389
1390    match model {
1391        Some(m) => {
1392            let new_count = m.retry_count + 1;
1393            let mut active: TaskActiveModel = m.into();
1394            active.retry_count = Set(new_count);
1395            active.status = Set(TaskStatus::Pending);
1396            active.error = Set(None);
1397            active.completed_at = Set(None);
1398            active.update(db).await?;
1399            Ok(new_count as u32)
1400        }
1401        None => Err(DurableError::custom(format!(
1402            "task {task_id} not found when incrementing retry count"
1403        ))),
1404    }
1405}
1406
// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
// for conditional deadline_epoch_ms computation or COALESCE on started_at.
/// Transition `task_id` to `status`, with two in-statement side effects:
///
/// - `started_at` is stamped on first transition (COALESCE keeps an existing
///   value). Note it is stamped for EVERY status change, not only RUNNING —
///   presumably intentional; confirm if PAUSED/terminal transitions should
///   also set it.
/// - `deadline_epoch_ms` is recomputed only when entering RUNNING with a
///   configured `timeout_ms`. The `'{status}' = 'RUNNING'` test compares two
///   SQL text literals, i.e. it is a per-statement constant.
///
/// Interpolation is injection-safe: `status` renders from the enum's Display
/// impl and `task_id` is a Uuid.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}'{TS}, \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1429
1430/// Check if the task is paused or cancelled. Returns an error if so.
1431async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1432    let status = get_status(db, task_id).await?;
1433    match status {
1434        Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1435        Some(TaskStatus::Cancelled) => Err(DurableError::Cancelled(format!(
1436            "task {task_id} is cancelled"
1437        ))),
1438        _ => Ok(()),
1439    }
1440}
1441
1442/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
1443async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1444    let model = Task::find_by_id(task_id).one(db).await?;
1445
1446    if let Some(m) = model
1447        && let Some(deadline_ms) = m.deadline_epoch_ms
1448    {
1449        let now_ms = std::time::SystemTime::now()
1450            .duration_since(std::time::UNIX_EPOCH)
1451            .map(|d| d.as_millis() as i64)
1452            .unwrap_or(0);
1453        if now_ms > deadline_ms {
1454            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1455        }
1456    }
1457
1458    Ok(())
1459}
1460
1461async fn complete_task(
1462    db: &impl ConnectionTrait,
1463    task_id: Uuid,
1464    output: serde_json::Value,
1465) -> Result<(), DurableError> {
1466    let model = Task::find_by_id(task_id).one(db).await?;
1467
1468    if let Some(m) = model {
1469        let mut active: TaskActiveModel = m.into();
1470        active.status = Set(TaskStatus::Completed);
1471        active.output = Set(Some(output));
1472        active.completed_at = Set(Some(chrono::Utc::now().into()));
1473        active.update(db).await?;
1474    }
1475    Ok(())
1476}
1477
1478async fn fail_task(
1479    db: &impl ConnectionTrait,
1480    task_id: Uuid,
1481    error: &str,
1482) -> Result<(), DurableError> {
1483    let model = Task::find_by_id(task_id).one(db).await?;
1484
1485    if let Some(m) = model {
1486        let mut active: TaskActiveModel = m.into();
1487        active.status = Set(TaskStatus::Failed);
1488        active.error = Set(Some(error.to_string()));
1489        active.completed_at = Set(Some(chrono::Utc::now().into()));
1490        active.update(db).await?;
1491    }
1492    Ok(())
1493}
1494
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // NOTE(review): the exhaustion tests below sleep real wall-clock time
    // (100 + 200 + 400 ms); `#[tokio::test(start_paused = true)]` would make
    // them instant if the runtime's test-util feature is enabled.

    /// A closure that always succeeds must run exactly once and yield Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let calls = Arc::new(AtomicU32::new(0));
        let tracker = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let tracker = Arc::clone(&tracker);
            async move {
                tracker.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by success: Ok after exactly 3 calls.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let calls = Arc::new(AtomicU32::new(0));
        let tracker = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let tracker = Arc::clone(&tracker);
            async move {
                match tracker.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    ))),
                    _ => Ok(()),
                }
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// An always-failing closure runs the initial attempt plus
    /// MAX_CHECKPOINT_RETRIES retries, then surfaces an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let tracker = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let tracker = Arc::clone(&tracker);
            async move {
                tracker.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(outcome.is_err());
        // one immediate attempt, then MAX_CHECKPOINT_RETRIES backoff retries
        assert_eq!(calls.load(Ordering::SeqCst), 1 + MAX_CHECKPOINT_RETRIES);
    }

    /// When every attempt fails, the error from the FIRST attempt is the one
    /// surfaced to the caller, not the last retry's.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let tracker = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let tracker = Arc::clone(&tracker);
            async move {
                let attempt = tracker.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{attempt}"
                ))))
            }
        })
        .await;
        let err = outcome.unwrap_err();
        // The message of the first error contains "error-0"
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}