Skip to main content

durable/
ctx.rs

1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9    DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10    Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20// ── Retry constants ───────────────────────────────────────────────────────────
21
22const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25/// Retry a fallible DB write with exponential backoff.
26///
27/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
28/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
29/// error if all retries are exhausted.
30async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32    F: FnMut() -> Fut,
33    Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35    match f().await {
36        Ok(()) => Ok(()),
37        Err(first_err) => {
38            for i in 0..MAX_CHECKPOINT_RETRIES {
39                tokio::time::sleep(Duration::from_millis(
40                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41                ))
42                .await;
43                if f().await.is_ok() {
44                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45                    return Ok(());
46                }
47            }
48            Err(first_err)
49        }
50    }
51}
52
/// Policy for retrying a step on failure.
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt; 0 means fail fast.
    pub max_retries: u32,
    /// Delay before the first retry; later delays are scaled by `backoff_multiplier`.
    pub initial_backoff: std::time::Duration,
    /// Factor applied per retry (1.0 = fixed delay, 2.0 = doubling).
    pub backoff_multiplier: f64,
}
59
60impl RetryPolicy {
61    /// No retries — fails immediately on error.
62    pub fn none() -> Self {
63        Self {
64            max_retries: 0,
65            initial_backoff: std::time::Duration::from_secs(0),
66            backoff_multiplier: 1.0,
67        }
68    }
69
70    /// Exponential backoff: backoff doubles each retry.
71    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
72        Self {
73            max_retries,
74            initial_backoff,
75            backoff_multiplier: 2.0,
76        }
77    }
78
79    /// Fixed backoff: same duration between all retries.
80    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
81        Self {
82            max_retries,
83            initial_backoff: backoff,
84            backoff_multiplier: 1.0,
85        }
86    }
87}
88
/// Sort order for task listing.
pub enum TaskSort {
    /// Order by creation timestamp.
    CreatedAt(Order),
    /// Order by start timestamp (NULL for never-started tasks).
    StartedAt(Order),
    /// Order by completion timestamp (NULL for unfinished tasks).
    CompletedAt(Order),
    /// Order by task name.
    Name(Order),
    /// Order by status.
    Status(Order),
}
97
/// Builder for filtering, sorting, and paginating task queries.
///
/// ```ignore
/// let query = TaskQuery::default()
///     .status(TaskStatus::Running)
///     .kind("WORKFLOW")
///     .root_only(true)
///     .sort(TaskSort::CreatedAt(Order::Desc))
///     .limit(20);
/// let tasks = Ctx::list(&db, query).await?;
/// ```
pub struct TaskQuery {
    /// Only tasks in this status (None = any status).
    pub status: Option<TaskStatus>,
    /// Only tasks of this kind, e.g. "WORKFLOW" / "STEP" / "TRANSACTION".
    pub kind: Option<String>,
    /// Only direct children of this task.
    pub parent_id: Option<Uuid>,
    /// If true, only root tasks (parent_id IS NULL).
    pub root_only: bool,
    /// Only tasks with this exact name.
    pub name: Option<String>,
    /// Only tasks on this queue.
    pub queue_name: Option<String>,
    /// Sort column and direction.
    pub sort: TaskSort,
    /// Maximum number of rows to return (None = unlimited).
    pub limit: Option<u64>,
    /// Number of rows to skip before returning results.
    pub offset: Option<u64>,
}
120
121impl Default for TaskQuery {
122    fn default() -> Self {
123        Self {
124            status: None,
125            kind: None,
126            parent_id: None,
127            root_only: false,
128            name: None,
129            queue_name: None,
130            sort: TaskSort::CreatedAt(Order::Desc),
131            limit: None,
132            offset: None,
133        }
134    }
135}
136
137impl TaskQuery {
138    /// Filter by status.
139    pub fn status(mut self, status: TaskStatus) -> Self {
140        self.status = Some(status);
141        self
142    }
143
144    /// Filter by kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
145    pub fn kind(mut self, kind: &str) -> Self {
146        self.kind = Some(kind.to_string());
147        self
148    }
149
150    /// Filter by parent task ID (direct children only).
151    pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152        self.parent_id = Some(parent_id);
153        self
154    }
155
156    /// Only return root tasks (no parent).
157    pub fn root_only(mut self, root_only: bool) -> Self {
158        self.root_only = root_only;
159        self
160    }
161
162    /// Filter by task name.
163    pub fn name(mut self, name: &str) -> Self {
164        self.name = Some(name.to_string());
165        self
166    }
167
168    /// Filter by queue name.
169    pub fn queue_name(mut self, queue: &str) -> Self {
170        self.queue_name = Some(queue.to_string());
171        self
172    }
173
174    /// Set the sort order.
175    pub fn sort(mut self, sort: TaskSort) -> Self {
176        self.sort = sort;
177        self
178    }
179
180    /// Limit the number of results.
181    pub fn limit(mut self, limit: u64) -> Self {
182        self.limit = Some(limit);
183        self
184    }
185
186    /// Skip the first N results.
187    pub fn offset(mut self, offset: u64) -> Self {
188        self.offset = Some(offset);
189        self
190    }
191}
192
/// Summary of a task returned by `Ctx::list()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    /// Task primary key.
    pub id: Uuid,
    /// Parent task, or None for a root workflow.
    pub parent_id: Option<Uuid>,
    /// Task name as supplied at creation.
    pub name: String,
    /// Current lifecycle status.
    pub status: TaskStatus,
    /// Task kind, e.g. "WORKFLOW" / "STEP" / "TRANSACTION".
    pub kind: String,
    /// Serialized input payload, if any.
    pub input: Option<serde_json::Value>,
    /// Saved output (set when the task completed).
    pub output: Option<serde_json::Value>,
    /// Error message (set when the task failed).
    pub error: Option<String>,
    /// Queue the task was scheduled on, if any.
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
209
210impl From<durable_db::entity::task::Model> for TaskSummary {
211    fn from(m: durable_db::entity::task::Model) -> Self {
212        Self {
213            id: m.id,
214            parent_id: m.parent_id,
215            name: m.name,
216            status: m.status,
217            kind: m.kind,
218            input: m.input,
219            output: m.output,
220            error: m.error,
221            queue_name: m.queue_name,
222            created_at: m.created_at,
223            started_at: m.started_at,
224            completed_at: m.completed_at,
225        }
226    }
227}
228
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Shared connection handle; cloned cheaply into child contexts.
    db: DatabaseConnection,
    // ID of the task this context is scoped to (workflow or child workflow).
    task_id: Uuid,
    // Monotonic per-context step counter; steps are keyed by (task_id, sequence).
    sequence: AtomicI32,
    // Executor registered by init(), used to tag tasks for crash recovery.
    executor_id: Option<String>,
}
240
241impl Ctx {
242    // ── Workflow lifecycle (user-facing) ──────────────────────────
243
244    /// Start a new root workflow. Always creates a fresh task with a new ID.
245    ///
246    /// If [`init`](crate::init) was called, the task is automatically tagged
247    /// with the executor ID for crash recovery. Otherwise `executor_id` is NULL.
248    ///
249    /// To reactivate a paused workflow, use [`Ctx::resume`] instead.
250    ///
251    /// ```ignore
252    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
253    /// ```
254    pub async fn start(
255        db: &DatabaseConnection,
256        name: &str,
257        input: Option<serde_json::Value>,
258    ) -> Result<Self, DurableError> {
259        let task_id = Uuid::new_v4();
260        let input_json = match &input {
261            Some(v) => serde_json::to_string(v)?,
262            None => "null".to_string(),
263        };
264
265        // Read executor_id from the global set by init()
266        let executor_id = crate::executor_id();
267
268        let executor_col = if executor_id.is_some() { ", executor_id" } else { "" };
269        let executor_val = match &executor_id {
270            Some(eid) => format!(", '{eid}'"),
271            None => String::new(),
272        };
273
274        let txn = db.begin().await?;
275        let sql = format!(
276            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{executor_col}) \
277             VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now(){executor_val})"
278        );
279        txn.execute(Statement::from_string(DbBackend::Postgres, sql))
280            .await?;
281        txn.commit().await?;
282
283        Ok(Self {
284            db: db.clone(),
285            task_id,
286            sequence: AtomicI32::new(0),
287            executor_id,
288        })
289    }
290
291    /// Attach to an existing workflow by task ID. Used to resume a running or
292    /// recovered workflow without creating a new task row.
293    ///
294    /// ```ignore
295    /// let ctx = Ctx::from_id(&db, task_id).await?;
296    /// ```
297    pub async fn from_id(
298        db: &DatabaseConnection,
299        task_id: Uuid,
300    ) -> Result<Self, DurableError> {
301        // Verify the task exists
302        let model = Task::find_by_id(task_id).one(db).await?;
303        let _model =
304            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
305
306        // Claim the task with the current executor_id (if init() was called)
307        let executor_id = crate::executor_id();
308        if let Some(eid) = &executor_id {
309            db.execute(Statement::from_string(
310                DbBackend::Postgres,
311                format!(
312                    "UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"
313                ),
314            ))
315            .await?;
316        }
317
318        // Start sequence at 0 so that replaying steps from the beginning works
319        // correctly. Steps are looked up by (parent_id, sequence), so the caller
320        // will replay completed steps (getting saved outputs) before executing
321        // any new steps.
322        Ok(Self {
323            db: db.clone(),
324            task_id,
325            sequence: AtomicI32::new(0),
326            executor_id,
327        })
328    }
329
    /// Run a step. If already completed, returns saved output. Otherwise executes
    /// the closure, saves the result, and returns it.
    ///
    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
    ///
    /// The step row is found-or-created under a row lock (per the SKIP LOCKED
    /// comment below) that is held for the whole closure execution.
    ///
    /// # Errors
    /// Fails if the workflow is paused/cancelled, the deadline has passed, the
    /// closure errors, or a checkpoint write cannot be persisted.
    ///
    /// ```ignore
    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
    /// ```
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Sequence is bumped before the status checks so replay ordering stays
        // stable regardless of where execution bails out.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Check if parent task's deadline has passed before executing
        check_deadline(&self.db, self.task_id).await?;

        // Begin transaction — the FOR UPDATE lock is held throughout step execution
        let txn = self.db.begin().await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Returns (step_id, Option<saved_output>) where saved_output is Some iff COMPLETED.
        // Use FOR UPDATE SKIP LOCKED so only one worker can execute a given step.
        // max_retries=0: step() does not retry; use step_with_retry() for retries.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Execute the step closure within the transaction (row lock is held)
        // NOTE(review): retry_db_write re-issues the write on this SAME txn; if
        // Postgres aborted the txn on the first failure, subsequent attempts
        // will also fail — confirm the helper only sees transient client errors.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                let err_msg = e.to_string();
                // Persist the failure record, then COMMIT — the failure must
                // survive even though the step itself failed.
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }
397
398    /// Run a DB-only step inside a single Postgres transaction.
399    ///
400    /// Both the user's DB work and the checkpoint save happen in the same transaction,
401    /// ensuring atomicity. If the closure returns an error, both the user writes and
402    /// the checkpoint are rolled back.
403    ///
404    /// ```ignore
405    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
406    ///     do_db_work(tx).await
407    /// })).await?;
408    /// ```
409    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
410    where
411        T: Serialize + DeserializeOwned + Send,
412        F: for<'tx> FnOnce(
413                &'tx DatabaseTransaction,
414            ) -> Pin<
415                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
416            > + Send,
417    {
418        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
419
420        // Check if workflow is paused or cancelled before executing
421        check_status(&self.db, self.task_id).await?;
422
423        // Find or create the step task record OUTSIDE the transaction.
424        // This is idempotent (UNIQUE constraint) and must exist before we begin.
425        let (step_id, saved_output) = find_or_create_task(
426            &self.db,
427            Some(self.task_id),
428            Some(seq),
429            name,
430            "TRANSACTION",
431            None,
432            false,
433            None,
434        )
435        .await?;
436
437        // If already completed, replay from saved output.
438        if let Some(output) = saved_output {
439            let val: T = serde_json::from_value(output)?;
440            tracing::debug!(step = name, seq, "replaying saved transaction output");
441            return Ok(val);
442        }
443
444        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
445        let tx = self.db.begin().await?;
446
447        set_status(&tx, step_id, TaskStatus::Running).await?;
448
449        match f(&tx).await {
450            Ok(val) => {
451                let json = serde_json::to_value(&val)?;
452                complete_task(&tx, step_id, json).await?;
453                tx.commit().await?;
454                tracing::debug!(step = name, seq, "transaction step committed");
455                Ok(val)
456            }
457            Err(e) => {
458                // Rollback happens automatically when tx is dropped.
459                // Record the failure on the main connection (outside the rolled-back tx).
460                drop(tx);
461                fail_task(&self.db, step_id, &e.to_string()).await?;
462                Err(e)
463            }
464        }
465    }
466
467    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
468    ///
469    /// ```ignore
470    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
471    /// // use child_ctx.step(...) for steps inside the child
472    /// child_ctx.complete(json!({"done": true})).await?;
473    /// ```
474    pub async fn child(
475        &self,
476        name: &str,
477        input: Option<serde_json::Value>,
478    ) -> Result<Self, DurableError> {
479        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
480
481        // Check if workflow is paused or cancelled before executing
482        check_status(&self.db, self.task_id).await?;
483
484        let txn = self.db.begin().await?;
485        // Child workflows also use plain find-or-create without locking.
486        let (child_id, _saved) = find_or_create_task(
487            &txn,
488            Some(self.task_id),
489            Some(seq),
490            name,
491            "WORKFLOW",
492            input,
493            false,
494            None,
495        )
496        .await?;
497
498        // If child already completed, return a Ctx that will replay
499        // (the caller should check is_completed() or just run steps which will replay)
500        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
501        txn.commit().await?;
502
503        Ok(Self {
504            db: self.db.clone(),
505            task_id: child_id,
506            sequence: AtomicI32::new(0),
507            executor_id: self.executor_id.clone(),
508        })
509    }
510
511    /// Check if this workflow/child was already completed (for skipping in parent).
512    pub async fn is_completed(&self) -> Result<bool, DurableError> {
513        let status = get_status(&self.db, self.task_id).await?;
514        Ok(status == Some(TaskStatus::Completed))
515    }
516
517    /// Get the saved output if this task is completed.
518    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
519        match get_output(&self.db, self.task_id).await? {
520            Some(val) => Ok(Some(serde_json::from_value(val)?)),
521            None => Ok(None),
522        }
523    }
524
525    /// Mark this workflow as completed with an output value.
526    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
527        let json = serde_json::to_value(output)?;
528        let db = &self.db;
529        let task_id = self.task_id;
530        retry_db_write(|| complete_task(db, task_id, json.clone())).await
531    }
532
    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts.
    ///
    /// Retry state (`retry_count`, `max_retries`) is persisted in the task row,
    /// so a process restart resumes counting where it left off rather than
    /// restarting the budget.
    ///
    /// # Errors
    /// Returns the closure's last error once the retry budget is exhausted
    /// (the step row is then marked FAILED), or any checkpointing DB error.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts)
        // NOTE(review): the DB's max_retries governs the loop below, which may
        // differ from policy.max_retries if the row pre-existed — confirm intended.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop
        loop {
            // Re-check status before each retry attempt
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration
                        // multiplier^(retry_count - 1), clamped to >= 1.0 so a
                        // sub-1.0 multiplier can never shrink below the initial
                        // backoff.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
642
643    /// Mark this workflow as failed.
644    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
645        let db = &self.db;
646        let task_id = self.task_id;
647        retry_db_write(|| fail_task(db, task_id, error)).await
648    }
649
650    /// Set the timeout for this task in milliseconds.
651    ///
652    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
653    /// for recovery by `Executor::recover()`.
654    ///
655    /// If the task is currently RUNNING and has a `started_at`, this also
656    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
657    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
658        let sql = format!(
659            "UPDATE durable.task \
660             SET timeout_ms = {timeout_ms}, \
661                 deadline_epoch_ms = CASE \
662                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
663                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
664                     ELSE deadline_epoch_ms \
665                 END \
666             WHERE id = '{}'",
667            self.task_id
668        );
669        self.db
670            .execute(Statement::from_string(DbBackend::Postgres, sql))
671            .await?;
672        Ok(())
673    }
674
675    /// Start or resume a root workflow with a timeout in milliseconds.
676    ///
677    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
678    pub async fn start_with_timeout(
679        db: &DatabaseConnection,
680        name: &str,
681        input: Option<serde_json::Value>,
682        timeout_ms: i64,
683    ) -> Result<Self, DurableError> {
684        let ctx = Self::start(db, name, input).await?;
685        ctx.set_timeout(timeout_ms).await?;
686        Ok(ctx)
687    }
688
689    // ── Workflow control (management API) ─────────────────────────
690
691    /// Pause a workflow by ID. Sets status to PAUSED and recursively
692    /// cascades to all PENDING/RUNNING descendants (children, grandchildren, etc.).
693    ///
694    /// Only workflows in PENDING or RUNNING status can be paused.
695    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
696        let model = Task::find_by_id(task_id).one(db).await?;
697        let model =
698            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
699
700        match model.status {
701            TaskStatus::Pending | TaskStatus::Running => {}
702            status => {
703                return Err(DurableError::custom(format!(
704                    "cannot pause task in {status} status"
705                )));
706            }
707        }
708
709        // Pause the task itself and all descendants in one recursive CTE
710        let sql = format!(
711            "WITH RECURSIVE descendants AS ( \
712                 SELECT id FROM durable.task WHERE id = '{task_id}' \
713                 UNION ALL \
714                 SELECT t.id FROM durable.task t \
715                 INNER JOIN descendants d ON t.parent_id = d.id \
716             ) \
717             UPDATE durable.task SET status = 'PAUSED' \
718             WHERE id IN (SELECT id FROM descendants) \
719               AND status IN ('PENDING', 'RUNNING')"
720        );
721        db.execute(Statement::from_string(DbBackend::Postgres, sql))
722            .await?;
723
724        tracing::info!(%task_id, "workflow paused");
725        Ok(())
726    }
727
728    /// Resume a paused workflow by ID. Sets status back to RUNNING and
729    /// recursively cascades to all PAUSED descendants (resetting them to PENDING).
730    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
731        let model = Task::find_by_id(task_id).one(db).await?;
732        let model =
733            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
734
735        if model.status != TaskStatus::Paused {
736            return Err(DurableError::custom(format!(
737                "cannot resume task in {} status (must be PAUSED)",
738                model.status
739            )));
740        }
741
742        // Resume the root task back to RUNNING
743        let mut active: TaskActiveModel = model.into();
744        active.status = Set(TaskStatus::Running);
745        active.update(db).await?;
746
747        // Recursively resume all PAUSED descendants back to PENDING
748        let cascade_sql = format!(
749            "WITH RECURSIVE descendants AS ( \
750                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
751                 UNION ALL \
752                 SELECT t.id FROM durable.task t \
753                 INNER JOIN descendants d ON t.parent_id = d.id \
754             ) \
755             UPDATE durable.task SET status = 'PENDING' \
756             WHERE id IN (SELECT id FROM descendants) \
757               AND status = 'PAUSED'"
758        );
759        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
760            .await?;
761
762        tracing::info!(%task_id, "workflow resumed");
763        Ok(())
764    }
765
766    /// Cancel a workflow by ID. Sets status to CANCELLED and recursively
767    /// cascades to all non-terminal descendants.
768    ///
769    /// Cancellation is terminal — a cancelled workflow cannot be resumed.
770    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
771        let model = Task::find_by_id(task_id).one(db).await?;
772        let model =
773            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
774
775        match model.status {
776            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
777                return Err(DurableError::custom(format!(
778                    "cannot cancel task in {} status",
779                    model.status
780                )));
781            }
782            _ => {}
783        }
784
785        // Cancel the task itself and all non-terminal descendants in one recursive CTE
786        let sql = format!(
787            "WITH RECURSIVE descendants AS ( \
788                 SELECT id FROM durable.task WHERE id = '{task_id}' \
789                 UNION ALL \
790                 SELECT t.id FROM durable.task t \
791                 INNER JOIN descendants d ON t.parent_id = d.id \
792             ) \
793             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
794             WHERE id IN (SELECT id FROM descendants) \
795               AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
796        );
797        db.execute(Statement::from_string(DbBackend::Postgres, sql))
798            .await?;
799
800        tracing::info!(%task_id, "workflow cancelled");
801        Ok(())
802    }
803
804    // ── Query API ─────────────────────────────────────────────────
805
806    /// List tasks matching the given filter, with sorting and pagination.
807    ///
808    /// ```ignore
809    /// let tasks = Ctx::list(&db, TaskQuery::default().status("RUNNING").limit(10)).await?;
810    /// ```
811    pub async fn list(
812        db: &DatabaseConnection,
813        query: TaskQuery,
814    ) -> Result<Vec<TaskSummary>, DurableError> {
815        let mut select = Task::find();
816
817        // Filters
818        if let Some(status) = &query.status {
819            select = select.filter(TaskColumn::Status.eq(status.to_string()));
820        }
821        if let Some(kind) = &query.kind {
822            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
823        }
824        if let Some(parent_id) = query.parent_id {
825            select = select.filter(TaskColumn::ParentId.eq(parent_id));
826        }
827        if query.root_only {
828            select = select.filter(TaskColumn::ParentId.is_null());
829        }
830        if let Some(name) = &query.name {
831            select = select.filter(TaskColumn::Name.eq(name.as_str()));
832        }
833        if let Some(queue) = &query.queue_name {
834            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
835        }
836
837        // Sorting
838        let (col, order) = match query.sort {
839            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
840            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
841            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
842            TaskSort::Name(ord) => (TaskColumn::Name, ord),
843            TaskSort::Status(ord) => (TaskColumn::Status, ord),
844        };
845        select = select.order_by(col, order);
846
847        // Pagination
848        if let Some(offset) = query.offset {
849            select = select.offset(offset);
850        }
851        if let Some(limit) = query.limit {
852            select = select.limit(limit);
853        }
854
855        let models = select.all(db).await?;
856
857        Ok(models.into_iter().map(TaskSummary::from).collect())
858    }
859
860    /// Count tasks matching the given filter.
861    pub async fn count(
862        db: &DatabaseConnection,
863        query: TaskQuery,
864    ) -> Result<u64, DurableError> {
865        let mut select = Task::find();
866
867        if let Some(status) = &query.status {
868            select = select.filter(TaskColumn::Status.eq(status.to_string()));
869        }
870        if let Some(kind) = &query.kind {
871            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
872        }
873        if let Some(parent_id) = query.parent_id {
874            select = select.filter(TaskColumn::ParentId.eq(parent_id));
875        }
876        if query.root_only {
877            select = select.filter(TaskColumn::ParentId.is_null());
878        }
879        if let Some(name) = &query.name {
880            select = select.filter(TaskColumn::Name.eq(name.as_str()));
881        }
882        if let Some(queue) = &query.queue_name {
883            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
884        }
885
886        let count = select.count(db).await?;
887        Ok(count)
888    }
889
890    // ── Accessors ────────────────────────────────────────────────
891
    /// Borrow the database connection this context operates on.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }
895
    /// The id of the task this context is bound to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }
899
    /// Hand out the next step sequence number for this context.
    ///
    /// Returns the counter's current value and atomically increments it
    /// (`fetch_add` with `SeqCst` ordering), so concurrent callers each
    /// receive a distinct value.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
903}
904
905// ── Internal SQL helpers ─────────────────────────────────────────────
906
907/// Find an existing task by (parent_id, name) or create a new one.
908///
909/// Returns `(task_id, Option<saved_output>)`:
910/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
911/// - `saved_output` is `None` when the task is new or in-progress.
912///
913/// When `lock` is `true` and an existing non-completed task is found, this
914/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
915/// another worker holds the lock, `DurableError::StepLocked` is returned so
916/// the caller can skip execution rather than double-firing side effects.
917///
918/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
919/// and child-workflow creation where concurrent start is safe).
920///
921/// When `lock` is `true`, the caller MUST call this within a transaction so
922/// the row lock is held throughout step execution.
923///
924/// `max_retries`: if Some, overrides the schema default when creating a new task.
925#[allow(clippy::too_many_arguments)]
926async fn find_or_create_task(
927    db: &impl ConnectionTrait,
928    parent_id: Option<Uuid>,
929    sequence: Option<i32>,
930    name: &str,
931    kind: &str,
932    input: Option<serde_json::Value>,
933    lock: bool,
934    max_retries: Option<u32>,
935) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
936    let parent_eq = match parent_id {
937        Some(p) => format!("= '{p}'"),
938        None => "IS NULL".to_string(),
939    };
940    let parent_sql = match parent_id {
941        Some(p) => format!("'{p}'"),
942        None => "NULL".to_string(),
943    };
944
945    if lock {
946        // Locking path (for steps): we need exactly-once execution.
947        //
948        // Strategy:
949        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
950        //    If another transaction is concurrently inserting the same row,
951        //    Postgres will block here until that transaction commits or rolls
952        //    back, ensuring we never see a phantom "not found" for a row being
953        //    inserted.
954        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
955        //    should get it back; if the row existed and is locked by another
956        //    worker we get nothing.
957        // 3. If SKIP LOCKED returns empty, return StepLocked.
958
959        let new_id = Uuid::new_v4();
960        let seq_sql = match sequence {
961            Some(s) => s.to_string(),
962            None => "NULL".to_string(),
963        };
964        let input_sql = match &input {
965            Some(v) => format!("'{}'", serde_json::to_string(v)?),
966            None => "NULL".to_string(),
967        };
968
969        let max_retries_sql = match max_retries {
970            Some(r) => r.to_string(),
971            None => "3".to_string(), // schema default
972        };
973
974        // Step 1: insert-or-skip
975        let insert_sql = format!(
976            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
977             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
978             ON CONFLICT (parent_id, sequence) DO NOTHING"
979        );
980        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
981            .await?;
982
983        // Step 2: lock the row (ours or pre-existing) by (parent_id, sequence)
984        let lock_sql = format!(
985            "SELECT id, status::text, output FROM durable.task \
986             WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
987             FOR UPDATE SKIP LOCKED"
988        );
989        let row = db
990            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
991            .await?;
992
993        if let Some(row) = row {
994            let id: Uuid = row
995                .try_get_by_index(0)
996                .map_err(|e| DurableError::custom(e.to_string()))?;
997            let status: String = row
998                .try_get_by_index(1)
999                .map_err(|e| DurableError::custom(e.to_string()))?;
1000            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
1001
1002            if status == TaskStatus::Completed.to_string() {
1003                // Replay path — return saved output
1004                return Ok((id, output));
1005            }
1006            // Task exists and we hold the lock — proceed to execute
1007            return Ok((id, None));
1008        }
1009
1010        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
1011        Err(DurableError::StepLocked(name.to_string()))
1012    } else {
1013        // Plain find without locking — safe for workflow-level operations.
1014        // Multiple workers resuming the same workflow is fine; individual
1015        // steps will be locked when executed.
1016        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
1017        query = match parent_id {
1018            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
1019            None => query.filter(TaskColumn::ParentId.is_null()),
1020        };
1021        // Skip cancelled/failed tasks — a fresh row will be created instead.
1022        // Completed tasks are still returned so child tasks can replay saved output.
1023        let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
1024        let existing = query
1025            .filter(TaskColumn::Status.is_not_in(status_exclusions))
1026            .one(db)
1027            .await?;
1028
1029        if let Some(model) = existing {
1030            if model.status == TaskStatus::Completed {
1031                return Ok((model.id, model.output));
1032            }
1033            return Ok((model.id, None));
1034        }
1035
1036        // Task does not exist — create it
1037        let id = Uuid::new_v4();
1038        let new_task = TaskActiveModel {
1039            id: Set(id),
1040            parent_id: Set(parent_id),
1041            sequence: Set(sequence),
1042            name: Set(name.to_string()),
1043            kind: Set(kind.to_string()),
1044            status: Set(TaskStatus::Pending),
1045            input: Set(input),
1046            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1047            ..Default::default()
1048        };
1049        new_task.insert(db).await?;
1050
1051        Ok((id, None))
1052    }
1053}
1054
1055async fn get_output(
1056    db: &impl ConnectionTrait,
1057    task_id: Uuid,
1058) -> Result<Option<serde_json::Value>, DurableError> {
1059    let model = Task::find_by_id(task_id)
1060        .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1061        .one(db)
1062        .await?;
1063
1064    Ok(model.and_then(|m| m.output))
1065}
1066
1067async fn get_status(
1068    db: &impl ConnectionTrait,
1069    task_id: Uuid,
1070) -> Result<Option<TaskStatus>, DurableError> {
1071    let model = Task::find_by_id(task_id).one(db).await?;
1072
1073    Ok(model.map(|m| m.status))
1074}
1075
1076/// Returns (retry_count, max_retries) for a task.
1077async fn get_retry_info(
1078    db: &DatabaseConnection,
1079    task_id: Uuid,
1080) -> Result<(u32, u32), DurableError> {
1081    let model = Task::find_by_id(task_id).one(db).await?;
1082
1083    match model {
1084        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1085        None => Err(DurableError::custom(format!(
1086            "task {task_id} not found when reading retry info"
1087        ))),
1088    }
1089}
1090
1091/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
1092async fn increment_retry_count(
1093    db: &DatabaseConnection,
1094    task_id: Uuid,
1095) -> Result<u32, DurableError> {
1096    let model = Task::find_by_id(task_id).one(db).await?;
1097
1098    match model {
1099        Some(m) => {
1100            let new_count = m.retry_count + 1;
1101            let mut active: TaskActiveModel = m.into();
1102            active.retry_count = Set(new_count);
1103            active.status = Set(TaskStatus::Pending);
1104            active.error = Set(None);
1105            active.completed_at = Set(None);
1106            active.update(db).await?;
1107            Ok(new_count as u32)
1108        }
1109        None => Err(DurableError::custom(format!(
1110            "task {task_id} not found when incrementing retry count"
1111        ))),
1112    }
1113}
1114
// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
// for conditional deadline_epoch_ms computation or COALESCE on started_at.
/// Set a task's status, stamping `started_at` on the first transition and
/// computing an absolute deadline when the task starts RUNNING.
///
/// Behavior, all evaluated on the database clock (`now()`):
/// - `started_at` uses COALESCE, so only the first call populates it; later
///   transitions keep the original timestamp.
/// - `deadline_epoch_ms` is (re)computed only when moving to RUNNING with a
///   non-NULL `timeout_ms`; every other transition keeps the existing value.
/// - `status` (an enum, via Display) and `task_id` (a Uuid) are the only
///   interpolated values, so the formatted SQL cannot be broken by caller
///   strings.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1137
1138/// Check if the task is paused or cancelled. Returns an error if so.
1139async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1140    let status = get_status(db, task_id).await?;
1141    match status {
1142        Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1143        Some(TaskStatus::Cancelled) => {
1144            Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1145        }
1146        _ => Ok(()),
1147    }
1148}
1149
1150/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
1151async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1152    let model = Task::find_by_id(task_id).one(db).await?;
1153
1154    if let Some(m) = model
1155        && let Some(deadline_ms) = m.deadline_epoch_ms
1156    {
1157        let now_ms = std::time::SystemTime::now()
1158            .duration_since(std::time::UNIX_EPOCH)
1159            .map(|d| d.as_millis() as i64)
1160            .unwrap_or(0);
1161        if now_ms > deadline_ms {
1162            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1163        }
1164    }
1165
1166    Ok(())
1167}
1168
1169async fn complete_task(
1170    db: &impl ConnectionTrait,
1171    task_id: Uuid,
1172    output: serde_json::Value,
1173) -> Result<(), DurableError> {
1174    let model = Task::find_by_id(task_id).one(db).await?;
1175
1176    if let Some(m) = model {
1177        let mut active: TaskActiveModel = m.into();
1178        active.status = Set(TaskStatus::Completed);
1179        active.output = Set(Some(output));
1180        active.completed_at = Set(Some(chrono::Utc::now().into()));
1181        active.update(db).await?;
1182    }
1183    Ok(())
1184}
1185
1186async fn fail_task(
1187    db: &impl ConnectionTrait,
1188    task_id: Uuid,
1189    error: &str,
1190) -> Result<(), DurableError> {
1191    let model = Task::find_by_id(task_id).one(db).await?;
1192
1193    if let Some(m) = model {
1194        let mut active: TaskActiveModel = m.into();
1195        active.status = Set(TaskStatus::Failed);
1196        active.error = Set(Some(error.to_string()));
1197        active.completed_at = Set(Some(chrono::Utc::now().into()));
1198        active.update(db).await?;
1199    }
1200    Ok(())
1201}
1202
1203#[cfg(test)]
1204mod tests {
1205    use super::*;
1206    use std::sync::Arc;
1207    use std::sync::atomic::{AtomicU32, Ordering};
1208
1209    /// test_retry_db_write_succeeds_first_try: a closure that always succeeds
1210    /// should be called exactly once and return Ok.
1211    #[tokio::test]
1212    async fn test_retry_db_write_succeeds_first_try() {
1213        let call_count = Arc::new(AtomicU32::new(0));
1214        let cc = call_count.clone();
1215        let result = retry_db_write(|| {
1216            let c = cc.clone();
1217            async move {
1218                c.fetch_add(1, Ordering::SeqCst);
1219                Ok::<(), DurableError>(())
1220            }
1221        })
1222        .await;
1223        assert!(result.is_ok());
1224        assert_eq!(call_count.load(Ordering::SeqCst), 1);
1225    }
1226
1227    /// test_retry_db_write_succeeds_after_transient_failure: a closure that
1228    /// fails twice then succeeds should return Ok and be called 3 times.
1229    #[tokio::test]
1230    async fn test_retry_db_write_succeeds_after_transient_failure() {
1231        let call_count = Arc::new(AtomicU32::new(0));
1232        let cc = call_count.clone();
1233        let result = retry_db_write(|| {
1234            let c = cc.clone();
1235            async move {
1236                let n = c.fetch_add(1, Ordering::SeqCst);
1237                if n < 2 {
1238                    Err(DurableError::Db(sea_orm::DbErr::Custom(
1239                        "transient".to_string(),
1240                    )))
1241                } else {
1242                    Ok(())
1243                }
1244            }
1245        })
1246        .await;
1247        assert!(result.is_ok());
1248        assert_eq!(call_count.load(Ordering::SeqCst), 3);
1249    }
1250
1251    /// test_retry_db_write_exhausts_retries: a closure that always fails should
1252    /// be called 1 + MAX_CHECKPOINT_RETRIES times total then return an error.
1253    #[tokio::test]
1254    async fn test_retry_db_write_exhausts_retries() {
1255        let call_count = Arc::new(AtomicU32::new(0));
1256        let cc = call_count.clone();
1257        let result = retry_db_write(|| {
1258            let c = cc.clone();
1259            async move {
1260                c.fetch_add(1, Ordering::SeqCst);
1261                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
1262                    "always fails".to_string(),
1263                )))
1264            }
1265        })
1266        .await;
1267        assert!(result.is_err());
1268        // 1 initial attempt + MAX_CHECKPOINT_RETRIES retry attempts
1269        assert_eq!(
1270            call_count.load(Ordering::SeqCst),
1271            1 + MAX_CHECKPOINT_RETRIES
1272        );
1273    }
1274
1275    /// test_retry_db_write_returns_original_error: when all retries are
1276    /// exhausted the FIRST error is returned, not the last retry error.
1277    #[tokio::test]
1278    async fn test_retry_db_write_returns_original_error() {
1279        let call_count = Arc::new(AtomicU32::new(0));
1280        let cc = call_count.clone();
1281        let result = retry_db_write(|| {
1282            let c = cc.clone();
1283            async move {
1284                let n = c.fetch_add(1, Ordering::SeqCst);
1285                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
1286                    "error-{}",
1287                    n
1288                ))))
1289            }
1290        })
1291        .await;
1292        let err = result.unwrap_err();
1293        // The message of the first error contains "error-0"
1294        assert!(
1295            err.to_string().contains("error-0"),
1296            "expected first error (error-0), got: {err}"
1297        );
1298    }
1299}