Skip to main content

durable/
ctx.rs

1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9    DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10    Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20// ── Retry constants ───────────────────────────────────────────────────────────
21
22const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25/// Retry a fallible DB write with exponential backoff.
26///
27/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
28/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
29/// error if all retries are exhausted.
30async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32    F: FnMut() -> Fut,
33    Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35    match f().await {
36        Ok(()) => Ok(()),
37        Err(first_err) => {
38            for i in 0..MAX_CHECKPOINT_RETRIES {
39                tokio::time::sleep(Duration::from_millis(
40                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41                ))
42                .await;
43                if f().await.is_ok() {
44                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45                    return Ok(());
46                }
47            }
48            Err(first_err)
49        }
50    }
51}
52
/// Policy for retrying a step on failure.
///
/// The delay before retry `k` (1-based) is
/// `initial_backoff * backoff_multiplier^(k-1)`.
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor applied to the delay after each retry (1.0 = fixed).
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries — fails immediately on error.
    pub fn none() -> Self {
        Self::fixed(0, std::time::Duration::from_secs(0))
    }

    /// Exponential backoff: backoff doubles each retry.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            backoff_multiplier: 2.0,
            max_retries,
            initial_backoff,
        }
    }

    /// Fixed backoff: same duration between all retries.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            backoff_multiplier: 1.0,
            max_retries,
            initial_backoff: backoff,
        }
    }
}
88
/// Sort order for task listing.
///
/// Each variant names the column to sort by and carries the direction
/// (`Order::Asc` / `Order::Desc`).
pub enum TaskSort {
    CreatedAt(Order),
    StartedAt(Order),
    CompletedAt(Order),
    Name(Order),
    Status(Order),
}
97
/// Builder for filtering, sorting, and paginating task queries.
///
/// All filters are ANDed together. Defaults (see `Default` impl): no
/// filters, sorted by `created_at` descending, no pagination.
///
/// ```ignore
/// let query = TaskQuery::default()
///     .status(TaskStatus::Running)
///     .kind("WORKFLOW")
///     .root_only(true)
///     .sort(TaskSort::CreatedAt(Order::Desc))
///     .limit(20);
/// let tasks = Ctx::list(&db, query).await?;
/// ```
pub struct TaskQuery {
    /// Filter: exact task status.
    pub status: Option<TaskStatus>,
    /// Filter: task kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: Option<String>,
    /// Filter: direct children of this parent task.
    pub parent_id: Option<Uuid>,
    /// Filter: only tasks with no parent.
    pub root_only: bool,
    /// Filter: exact task name.
    pub name: Option<String>,
    /// Filter: exact queue name.
    pub queue_name: Option<String>,
    /// Sort column and direction.
    pub sort: TaskSort,
    /// Maximum number of rows returned (used by `Ctx::list`; ignored by `Ctx::count`).
    pub limit: Option<u64>,
    /// Number of rows skipped (used by `Ctx::list`; ignored by `Ctx::count`).
    pub offset: Option<u64>,
}
120
121impl Default for TaskQuery {
122    fn default() -> Self {
123        Self {
124            status: None,
125            kind: None,
126            parent_id: None,
127            root_only: false,
128            name: None,
129            queue_name: None,
130            sort: TaskSort::CreatedAt(Order::Desc),
131            limit: None,
132            offset: None,
133        }
134    }
135}
136
137impl TaskQuery {
138    /// Filter by status.
139    pub fn status(mut self, status: TaskStatus) -> Self {
140        self.status = Some(status);
141        self
142    }
143
144    /// Filter by kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
145    pub fn kind(mut self, kind: &str) -> Self {
146        self.kind = Some(kind.to_string());
147        self
148    }
149
150    /// Filter by parent task ID (direct children only).
151    pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152        self.parent_id = Some(parent_id);
153        self
154    }
155
156    /// Only return root tasks (no parent).
157    pub fn root_only(mut self, root_only: bool) -> Self {
158        self.root_only = root_only;
159        self
160    }
161
162    /// Filter by task name.
163    pub fn name(mut self, name: &str) -> Self {
164        self.name = Some(name.to_string());
165        self
166    }
167
168    /// Filter by queue name.
169    pub fn queue_name(mut self, queue: &str) -> Self {
170        self.queue_name = Some(queue.to_string());
171        self
172    }
173
174    /// Set the sort order.
175    pub fn sort(mut self, sort: TaskSort) -> Self {
176        self.sort = sort;
177        self
178    }
179
180    /// Limit the number of results.
181    pub fn limit(mut self, limit: u64) -> Self {
182        self.limit = Some(limit);
183        self
184    }
185
186    /// Skip the first N results.
187    pub fn offset(mut self, offset: u64) -> Self {
188        self.offset = Some(offset);
189        self
190    }
191}
192
/// Summary of a task returned by `Ctx::list()`.
///
/// A plain-data projection of `durable_db::entity::task::Model`
/// (see the `From` impl below); serializable for API responses.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    pub id: Uuid,
    pub parent_id: Option<Uuid>,
    pub name: String,
    pub status: TaskStatus,
    pub kind: String,
    pub input: Option<serde_json::Value>,
    pub output: Option<serde_json::Value>,
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
209
// Field-for-field conversion from the DB row model to the API-facing summary.
impl From<durable_db::entity::task::Model> for TaskSummary {
    fn from(m: durable_db::entity::task::Model) -> Self {
        Self {
            id: m.id,
            parent_id: m.parent_id,
            name: m.name,
            status: m.status,
            kind: m.kind,
            input: m.input,
            output: m.output,
            error: m.error,
            queue_name: m.queue_name,
            created_at: m.created_at,
            started_at: m.started_at,
            completed_at: m.completed_at,
        }
    }
}
228
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Connection handle; cloned into each child context (see `Ctx::child`).
    db: DatabaseConnection,
    // The task this context is scoped to (a root workflow or a child).
    task_id: Uuid,
    // Monotonic per-context step counter, reserved via fetch_add. Together
    // with `task_id` it forms the (parent_id, sequence) identity used to
    // replay steps deterministically after a restart.
    sequence: AtomicI32,
}
239
240impl Ctx {
241    // ── Workflow lifecycle (user-facing) ──────────────────────────
242
243    /// Start a new root workflow. Always creates a fresh task with a new ID.
244    ///
245    /// To reactivate a paused workflow, use [`Ctx::resume`] instead.
246    ///
247    /// ```ignore
248    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
249    /// ```
250    pub async fn start(
251        db: &DatabaseConnection,
252        name: &str,
253        input: Option<serde_json::Value>,
254    ) -> Result<Self, DurableError> {
255        let task_id = Uuid::new_v4();
256        let input_json = match &input {
257            Some(v) => serde_json::to_string(v)?,
258            None => "null".to_string(),
259        };
260
261        let txn = db.begin().await?;
262        let sql = format!(
263            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at) \
264             VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now())"
265        );
266        txn.execute(Statement::from_string(DbBackend::Postgres, sql))
267            .await?;
268        txn.commit().await?;
269
270        Ok(Self {
271            db: db.clone(),
272            task_id,
273            sequence: AtomicI32::new(0),
274        })
275    }
276
277    /// Attach to an existing workflow by task ID. Used to resume a running or
278    /// recovered workflow without creating a new task row.
279    ///
280    /// ```ignore
281    /// let ctx = Ctx::from_id(&db, task_id).await?;
282    /// ```
283    pub async fn from_id(
284        db: &DatabaseConnection,
285        task_id: Uuid,
286    ) -> Result<Self, DurableError> {
287        // Verify the task exists
288        let model = Task::find_by_id(task_id).one(db).await?;
289        let _model =
290            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
291
292        // Start sequence at 0 so that replaying steps from the beginning works
293        // correctly. Steps are looked up by (parent_id, sequence), so the caller
294        // will replay completed steps (getting saved outputs) before executing
295        // any new steps.
296        Ok(Self {
297            db: db.clone(),
298            task_id,
299            sequence: AtomicI32::new(0),
300        })
301    }
302
    /// Run a step. If already completed, returns saved output. Otherwise executes
    /// the closure, saves the result, and returns it.
    ///
    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
    ///
    /// # Errors
    /// Returns an error if the workflow is paused/cancelled, the task deadline
    /// has passed, another worker holds the step's row lock, (de)serialization
    /// fails, or the closure itself fails.
    ///
    /// ```ignore
    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
    /// ```
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Reserve this step's position in the replay order.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Check if parent task's deadline has passed before executing
        check_deadline(&self.db, self.task_id).await?;

        // Begin transaction — the FOR UPDATE lock is held throughout step execution
        let txn = self.db.begin().await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Returns (step_id, Option<saved_output>) where saved_output is Some iff COMPLETED.
        // Use FOR UPDATE SKIP LOCKED so only one worker can execute a given step.
        // max_retries=0: step() does not retry; use step_with_retry() for retries.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Execute the step closure within the transaction (row lock is held)
        // NOTE(review): the closure runs while the row lock AND the open txn are
        // held, so a long-running step pins a pooled connection for its whole
        // duration — confirm pool sizing accounts for this.
        // NOTE(review): retry_db_write inside an open transaction looks
        // ineffective — once a statement errors, Postgres aborts the txn and
        // every retry should fail until rollback; confirm it is useful here.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        match f().await {
            Ok(val) => {
                // Persist the output and release the lock in one commit.
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Record the failure, then commit so FAILED status survives.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }
370
    /// Run a DB-only step inside a single Postgres transaction.
    ///
    /// Both the user's DB work and the checkpoint save happen in the same transaction,
    /// ensuring atomicity. If the closure returns an error, both the user writes and
    /// the checkpoint are rolled back.
    ///
    /// NOTE(review): unlike `step()`, this path takes no row lock and performs
    /// no `check_deadline` call — confirm both omissions are intentional.
    ///
    /// ```ignore
    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
    ///     do_db_work(tx).await
    /// })).await?;
    /// ```
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        // Reserve this step's position in the replay order.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create the step task record OUTSIDE the transaction.
        // This is idempotent (UNIQUE constraint) and must exist before we begin.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // If already completed, replay from saved output.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                // Checkpoint and user writes commit together.
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Rollback happens automatically when tx is dropped.
                // Record the failure on the main connection (outside the rolled-back tx).
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
439
440    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
441    ///
442    /// ```ignore
443    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
444    /// // use child_ctx.step(...) for steps inside the child
445    /// child_ctx.complete(json!({"done": true})).await?;
446    /// ```
447    pub async fn child(
448        &self,
449        name: &str,
450        input: Option<serde_json::Value>,
451    ) -> Result<Self, DurableError> {
452        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
453
454        // Check if workflow is paused or cancelled before executing
455        check_status(&self.db, self.task_id).await?;
456
457        let txn = self.db.begin().await?;
458        // Child workflows also use plain find-or-create without locking.
459        let (child_id, _saved) = find_or_create_task(
460            &txn,
461            Some(self.task_id),
462            Some(seq),
463            name,
464            "WORKFLOW",
465            input,
466            false,
467            None,
468        )
469        .await?;
470
471        // If child already completed, return a Ctx that will replay
472        // (the caller should check is_completed() or just run steps which will replay)
473        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
474        txn.commit().await?;
475
476        Ok(Self {
477            db: self.db.clone(),
478            task_id: child_id,
479            sequence: AtomicI32::new(0),
480        })
481    }
482
483    /// Check if this workflow/child was already completed (for skipping in parent).
484    pub async fn is_completed(&self) -> Result<bool, DurableError> {
485        let status = get_status(&self.db, self.task_id).await?;
486        Ok(status == Some(TaskStatus::Completed))
487    }
488
489    /// Get the saved output if this task is completed.
490    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
491        match get_output(&self.db, self.task_id).await? {
492            Some(val) => Ok(Some(serde_json::from_value(val)?)),
493            None => Ok(None),
494        }
495    }
496
497    /// Mark this workflow as completed with an output value.
498    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
499        let json = serde_json::to_value(output)?;
500        let db = &self.db;
501        let task_id = self.task_id;
502        retry_db_write(|| complete_task(db, task_id, json.clone())).await
503    }
504
    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts.
    ///
    /// The retry count is persisted, so a step resumed after a process restart
    /// continues from its stored count rather than starting over.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Reserve this step's position in the replay order.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        // NOTE(review): unlike `step()`, there is no `check_deadline` here —
        // confirm deadline enforcement was intentionally skipped for retried steps.
        check_status(&self.db, self.task_id).await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts)
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop
        loop {
            // Re-check status before each retry attempt
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration:
                        // initial * multiplier^(retry_count - 1), with the factor
                        // clamped to >= 1.0 so a sub-1 multiplier never shrinks
                        // the delay below `initial_backoff`.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
614
615    /// Mark this workflow as failed.
616    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
617        let db = &self.db;
618        let task_id = self.task_id;
619        retry_db_write(|| fail_task(db, task_id, error)).await
620    }
621
    /// Set the timeout for this task in milliseconds.
    ///
    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
    /// for recovery by `Executor::recover()`.
    ///
    /// If the task is currently RUNNING and has a `started_at`, this also
    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // format! interpolation is injection-safe here: the only interpolated
        // values are a typed i64 and a Uuid, neither of which can carry quotes.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
                 deadline_epoch_ms = CASE \
                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                     ELSE deadline_epoch_ms \
                 END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
646
    /// Start a new root workflow with a timeout in milliseconds. Like
    /// [`Ctx::start`], this always creates a fresh task (it does not resume).
    ///
    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
660
661    // ── Workflow control (management API) ─────────────────────────
662
663    /// Pause a workflow by ID. Sets status to PAUSED and recursively
664    /// cascades to all PENDING/RUNNING descendants (children, grandchildren, etc.).
665    ///
666    /// Only workflows in PENDING or RUNNING status can be paused.
667    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
668        let model = Task::find_by_id(task_id).one(db).await?;
669        let model =
670            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
671
672        match model.status {
673            TaskStatus::Pending | TaskStatus::Running => {}
674            status => {
675                return Err(DurableError::custom(format!(
676                    "cannot pause task in {status} status"
677                )));
678            }
679        }
680
681        // Pause the task itself and all descendants in one recursive CTE
682        let sql = format!(
683            "WITH RECURSIVE descendants AS ( \
684                 SELECT id FROM durable.task WHERE id = '{task_id}' \
685                 UNION ALL \
686                 SELECT t.id FROM durable.task t \
687                 INNER JOIN descendants d ON t.parent_id = d.id \
688             ) \
689             UPDATE durable.task SET status = 'PAUSED' \
690             WHERE id IN (SELECT id FROM descendants) \
691               AND status IN ('PENDING', 'RUNNING')"
692        );
693        db.execute(Statement::from_string(DbBackend::Postgres, sql))
694            .await?;
695
696        tracing::info!(%task_id, "workflow paused");
697        Ok(())
698    }
699
    /// Resume a paused workflow by ID. Sets status back to RUNNING and
    /// recursively cascades to all PAUSED descendants (resetting them to PENDING).
    ///
    /// Only the root goes straight to RUNNING; descendants become PENDING so
    /// they are re-executed (or replayed) when the workflow reaches them again.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Resume the root task back to RUNNING
        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // Recursively resume all PAUSED descendants back to PENDING.
        // Note: this CTE seeds from `parent_id = task_id` (the children),
        // unlike pause/cancel which seed from the task itself — the root was
        // already handled by the ActiveModel update above.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
               AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
737
738    /// Cancel a workflow by ID. Sets status to CANCELLED and recursively
739    /// cascades to all non-terminal descendants.
740    ///
741    /// Cancellation is terminal — a cancelled workflow cannot be resumed.
742    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
743        let model = Task::find_by_id(task_id).one(db).await?;
744        let model =
745            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
746
747        match model.status {
748            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
749                return Err(DurableError::custom(format!(
750                    "cannot cancel task in {} status",
751                    model.status
752                )));
753            }
754            _ => {}
755        }
756
757        // Cancel the task itself and all non-terminal descendants in one recursive CTE
758        let sql = format!(
759            "WITH RECURSIVE descendants AS ( \
760                 SELECT id FROM durable.task WHERE id = '{task_id}' \
761                 UNION ALL \
762                 SELECT t.id FROM durable.task t \
763                 INNER JOIN descendants d ON t.parent_id = d.id \
764             ) \
765             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
766             WHERE id IN (SELECT id FROM descendants) \
767               AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
768        );
769        db.execute(Statement::from_string(DbBackend::Postgres, sql))
770            .await?;
771
772        tracing::info!(%task_id, "workflow cancelled");
773        Ok(())
774    }
775
776    // ── Query API ─────────────────────────────────────────────────
777
    /// List tasks matching the given filter, with sorting and pagination.
    ///
    /// Filters are ANDed together; results are ordered per `query.sort` and
    /// then paginated with `query.offset` / `query.limit`.
    ///
    /// ```ignore
    /// let tasks = Ctx::list(&db, TaskQuery::default().status(TaskStatus::Running).limit(10)).await?;
    /// ```
    pub async fn list(
        db: &DatabaseConnection,
        query: TaskQuery,
    ) -> Result<Vec<TaskSummary>, DurableError> {
        let mut select = Task::find();

        // Filters
        if let Some(status) = &query.status {
            // Status is compared via its string form rather than the enum value.
            select = select.filter(TaskColumn::Status.eq(status.to_string()));
        }
        if let Some(kind) = &query.kind {
            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
        }
        if let Some(parent_id) = query.parent_id {
            select = select.filter(TaskColumn::ParentId.eq(parent_id));
        }
        if query.root_only {
            select = select.filter(TaskColumn::ParentId.is_null());
        }
        if let Some(name) = &query.name {
            select = select.filter(TaskColumn::Name.eq(name.as_str()));
        }
        if let Some(queue) = &query.queue_name {
            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
        }

        // Sorting
        let (col, order) = match query.sort {
            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
            TaskSort::Name(ord) => (TaskColumn::Name, ord),
            TaskSort::Status(ord) => (TaskColumn::Status, ord),
        };
        select = select.order_by(col, order);

        // Pagination
        if let Some(offset) = query.offset {
            select = select.offset(offset);
        }
        if let Some(limit) = query.limit {
            select = select.limit(limit);
        }

        let models = select.all(db).await?;

        Ok(models.into_iter().map(TaskSummary::from).collect())
    }
831
    /// Count tasks matching the given filter.
    ///
    /// Applies the same filters as [`Ctx::list`]; `sort`, `limit`, and
    /// `offset` are ignored — every matching row is counted.
    pub async fn count(
        db: &DatabaseConnection,
        query: TaskQuery,
    ) -> Result<u64, DurableError> {
        let mut select = Task::find();

        if let Some(status) = &query.status {
            // Status is compared via its string form rather than the enum value.
            select = select.filter(TaskColumn::Status.eq(status.to_string()));
        }
        if let Some(kind) = &query.kind {
            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
        }
        if let Some(parent_id) = query.parent_id {
            select = select.filter(TaskColumn::ParentId.eq(parent_id));
        }
        if query.root_only {
            select = select.filter(TaskColumn::ParentId.is_null());
        }
        if let Some(name) = &query.name {
            select = select.filter(TaskColumn::Name.eq(name.as_str()));
        }
        if let Some(queue) = &query.queue_name {
            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
        }

        let count = select.count(db).await?;
        Ok(count)
    }
861
862    // ── Accessors ────────────────────────────────────────────────
863
    /// Borrow the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }
867
    /// The id of the task this context is bound to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }
871
    /// Return the current sequence number and advance the internal counter.
    ///
    /// SeqCst fetch-add means concurrent callers each observe a distinct,
    /// monotonically increasing value.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
875}
876
877// ── Internal SQL helpers ─────────────────────────────────────────────
878
879/// Find an existing task by (parent_id, name) or create a new one.
880///
881/// Returns `(task_id, Option<saved_output>)`:
882/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
883/// - `saved_output` is `None` when the task is new or in-progress.
884///
885/// When `lock` is `true` and an existing non-completed task is found, this
886/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
887/// another worker holds the lock, `DurableError::StepLocked` is returned so
888/// the caller can skip execution rather than double-firing side effects.
889///
890/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
891/// and child-workflow creation where concurrent start is safe).
892///
893/// When `lock` is `true`, the caller MUST call this within a transaction so
894/// the row lock is held throughout step execution.
895///
896/// `max_retries`: if Some, overrides the schema default when creating a new task.
897#[allow(clippy::too_many_arguments)]
898async fn find_or_create_task(
899    db: &impl ConnectionTrait,
900    parent_id: Option<Uuid>,
901    sequence: Option<i32>,
902    name: &str,
903    kind: &str,
904    input: Option<serde_json::Value>,
905    lock: bool,
906    max_retries: Option<u32>,
907) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
908    let parent_eq = match parent_id {
909        Some(p) => format!("= '{p}'"),
910        None => "IS NULL".to_string(),
911    };
912    let parent_sql = match parent_id {
913        Some(p) => format!("'{p}'"),
914        None => "NULL".to_string(),
915    };
916
917    if lock {
918        // Locking path (for steps): we need exactly-once execution.
919        //
920        // Strategy:
921        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
922        //    If another transaction is concurrently inserting the same row,
923        //    Postgres will block here until that transaction commits or rolls
924        //    back, ensuring we never see a phantom "not found" for a row being
925        //    inserted.
926        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
927        //    should get it back; if the row existed and is locked by another
928        //    worker we get nothing.
929        // 3. If SKIP LOCKED returns empty, return StepLocked.
930
931        let new_id = Uuid::new_v4();
932        let seq_sql = match sequence {
933            Some(s) => s.to_string(),
934            None => "NULL".to_string(),
935        };
936        let input_sql = match &input {
937            Some(v) => format!("'{}'", serde_json::to_string(v)?),
938            None => "NULL".to_string(),
939        };
940
941        let max_retries_sql = match max_retries {
942            Some(r) => r.to_string(),
943            None => "3".to_string(), // schema default
944        };
945
946        // Step 1: insert-or-skip
947        let insert_sql = format!(
948            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
949             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
950             ON CONFLICT (parent_id, sequence) DO NOTHING"
951        );
952        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
953            .await?;
954
955        // Step 2: lock the row (ours or pre-existing) by (parent_id, sequence)
956        let lock_sql = format!(
957            "SELECT id, status::text, output FROM durable.task \
958             WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
959             FOR UPDATE SKIP LOCKED"
960        );
961        let row = db
962            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
963            .await?;
964
965        if let Some(row) = row {
966            let id: Uuid = row
967                .try_get_by_index(0)
968                .map_err(|e| DurableError::custom(e.to_string()))?;
969            let status: String = row
970                .try_get_by_index(1)
971                .map_err(|e| DurableError::custom(e.to_string()))?;
972            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
973
974            if status == TaskStatus::Completed.to_string() {
975                // Replay path — return saved output
976                return Ok((id, output));
977            }
978            // Task exists and we hold the lock — proceed to execute
979            return Ok((id, None));
980        }
981
982        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
983        Err(DurableError::StepLocked(name.to_string()))
984    } else {
985        // Plain find without locking — safe for workflow-level operations.
986        // Multiple workers resuming the same workflow is fine; individual
987        // steps will be locked when executed.
988        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
989        query = match parent_id {
990            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
991            None => query.filter(TaskColumn::ParentId.is_null()),
992        };
993        // Skip cancelled/failed tasks — a fresh row will be created instead.
994        // Completed tasks are still returned so child tasks can replay saved output.
995        let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
996        let existing = query
997            .filter(TaskColumn::Status.is_not_in(status_exclusions))
998            .one(db)
999            .await?;
1000
1001        if let Some(model) = existing {
1002            if model.status == TaskStatus::Completed {
1003                return Ok((model.id, model.output));
1004            }
1005            return Ok((model.id, None));
1006        }
1007
1008        // Task does not exist — create it
1009        let id = Uuid::new_v4();
1010        let new_task = TaskActiveModel {
1011            id: Set(id),
1012            parent_id: Set(parent_id),
1013            sequence: Set(sequence),
1014            name: Set(name.to_string()),
1015            kind: Set(kind.to_string()),
1016            status: Set(TaskStatus::Pending),
1017            input: Set(input),
1018            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1019            ..Default::default()
1020        };
1021        new_task.insert(db).await?;
1022
1023        Ok((id, None))
1024    }
1025}
1026
1027async fn get_output(
1028    db: &impl ConnectionTrait,
1029    task_id: Uuid,
1030) -> Result<Option<serde_json::Value>, DurableError> {
1031    let model = Task::find_by_id(task_id)
1032        .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1033        .one(db)
1034        .await?;
1035
1036    Ok(model.and_then(|m| m.output))
1037}
1038
1039async fn get_status(
1040    db: &impl ConnectionTrait,
1041    task_id: Uuid,
1042) -> Result<Option<TaskStatus>, DurableError> {
1043    let model = Task::find_by_id(task_id).one(db).await?;
1044
1045    Ok(model.map(|m| m.status))
1046}
1047
1048/// Returns (retry_count, max_retries) for a task.
1049async fn get_retry_info(
1050    db: &DatabaseConnection,
1051    task_id: Uuid,
1052) -> Result<(u32, u32), DurableError> {
1053    let model = Task::find_by_id(task_id).one(db).await?;
1054
1055    match model {
1056        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1057        None => Err(DurableError::custom(format!(
1058            "task {task_id} not found when reading retry info"
1059        ))),
1060    }
1061}
1062
1063/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
1064async fn increment_retry_count(
1065    db: &DatabaseConnection,
1066    task_id: Uuid,
1067) -> Result<u32, DurableError> {
1068    let model = Task::find_by_id(task_id).one(db).await?;
1069
1070    match model {
1071        Some(m) => {
1072            let new_count = m.retry_count + 1;
1073            let mut active: TaskActiveModel = m.into();
1074            active.retry_count = Set(new_count);
1075            active.status = Set(TaskStatus::Pending);
1076            active.error = Set(None);
1077            active.completed_at = Set(None);
1078            active.update(db).await?;
1079            Ok(new_count as u32)
1080        }
1081        None => Err(DurableError::custom(format!(
1082            "task {task_id} not found when incrementing retry count"
1083        ))),
1084    }
1085}
1086
// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
// for conditional deadline_epoch_ms computation or COALESCE on started_at.
//
// Interpolating values into the SQL text is injection-safe here: `task_id` is
// a `Uuid` and `status` formats to a fixed enum label via `Display` (the same
// labels compared against `status::text` elsewhere in this file) — neither can
// carry quotes or attacker-controlled text.
/// Transition `task_id` to `status`.
///
/// Side effects baked into the UPDATE:
/// - `started_at` is stamped on the first status change only (COALESCE keeps
///   any earlier value), regardless of which status is being set.
/// - `deadline_epoch_ms` is armed (now + timeout_ms) only on the transition
///   to RUNNING and only when the task has a timeout; otherwise it is left
///   untouched.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1109
1110/// Check if the task is paused or cancelled. Returns an error if so.
1111async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1112    let status = get_status(db, task_id).await?;
1113    match status {
1114        Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1115        Some(TaskStatus::Cancelled) => {
1116            Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1117        }
1118        _ => Ok(()),
1119    }
1120}
1121
1122/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
1123async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1124    let model = Task::find_by_id(task_id).one(db).await?;
1125
1126    if let Some(m) = model
1127        && let Some(deadline_ms) = m.deadline_epoch_ms
1128    {
1129        let now_ms = std::time::SystemTime::now()
1130            .duration_since(std::time::UNIX_EPOCH)
1131            .map(|d| d.as_millis() as i64)
1132            .unwrap_or(0);
1133        if now_ms > deadline_ms {
1134            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1135        }
1136    }
1137
1138    Ok(())
1139}
1140
1141async fn complete_task(
1142    db: &impl ConnectionTrait,
1143    task_id: Uuid,
1144    output: serde_json::Value,
1145) -> Result<(), DurableError> {
1146    let model = Task::find_by_id(task_id).one(db).await?;
1147
1148    if let Some(m) = model {
1149        let mut active: TaskActiveModel = m.into();
1150        active.status = Set(TaskStatus::Completed);
1151        active.output = Set(Some(output));
1152        active.completed_at = Set(Some(chrono::Utc::now().into()));
1153        active.update(db).await?;
1154    }
1155    Ok(())
1156}
1157
1158async fn fail_task(
1159    db: &impl ConnectionTrait,
1160    task_id: Uuid,
1161    error: &str,
1162) -> Result<(), DurableError> {
1163    let model = Task::find_by_id(task_id).one(db).await?;
1164
1165    if let Some(m) = model {
1166        let mut active: TaskActiveModel = m.into();
1167        active.status = Set(TaskStatus::Failed);
1168        active.error = Set(Some(error.to_string()));
1169        active.completed_at = Set(Some(chrono::Utc::now().into()));
1170        active.update(db).await?;
1171    }
1172    Ok(())
1173}
1174
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Fresh shared counter for tracking how many times a closure ran.
    fn attempt_counter() -> Arc<AtomicU32> {
        Arc::new(AtomicU32::new(0))
    }

    /// Build the DB-flavored error the retry helper is exercised with.
    fn db_err(msg: String) -> DurableError {
        DurableError::Db(sea_orm::DbErr::Custom(msg))
    }

    /// An always-successful closure runs exactly once and yields `Ok`.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let calls = attempt_counter();
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by success: `Ok`, three calls total.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let calls = attempt_counter();
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                match counter.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(db_err("transient".to_string())),
                    _ => Ok(()),
                }
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A permanently failing closure is attempted 1 + MAX_CHECKPOINT_RETRIES
    /// times in total and then reports an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let calls = attempt_counter();
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(db_err("always fails".to_string()))
            }
        })
        .await;
        assert!(outcome.is_err());
        // One initial attempt plus MAX_CHECKPOINT_RETRIES backed-off retries.
        assert_eq!(calls.load(Ordering::SeqCst), 1 + MAX_CHECKPOINT_RETRIES);
    }

    /// On exhaustion the error from the FIRST attempt is surfaced, not the
    /// error produced by the final retry.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let calls = attempt_counter();
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                let attempt = counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(db_err(format!("error-{}", attempt)))
            }
        })
        .await;
        let err = outcome.unwrap_err();
        // The message of the first error contains "error-0"
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}