// durable/ctx.rs
1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9    DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10    Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
// ── Retry constants ───────────────────────────────────────────────────────────

/// Maximum number of retry attempts made after the initial checkpoint write fails.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
/// Base delay for checkpoint-write backoff; doubled each attempt (100ms, 200ms, 400ms).
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25/// Retry a fallible DB write with exponential backoff.
26///
27/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
28/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
29/// error if all retries are exhausted.
30async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32    F: FnMut() -> Fut,
33    Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35    match f().await {
36        Ok(()) => Ok(()),
37        Err(first_err) => {
38            for i in 0..MAX_CHECKPOINT_RETRIES {
39                tokio::time::sleep(Duration::from_millis(
40                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41                ))
42                .await;
43                if f().await.is_ok() {
44                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45                    return Ok(());
46                }
47            }
48            Err(first_err)
49        }
50    }
51}
52
/// Policy for retrying a step on failure.
pub struct RetryPolicy {
    pub max_retries: u32,
    pub initial_backoff: std::time::Duration,
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries — fails immediately on error.
    pub fn none() -> Self {
        RetryPolicy {
            max_retries: 0,
            initial_backoff: std::time::Duration::ZERO,
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff: backoff doubles each retry.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        RetryPolicy {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Fixed backoff: same duration between all retries.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        RetryPolicy {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
88
/// Sort order for task listing.
pub enum TaskSort {
    /// Sort on the `created_at` column.
    CreatedAt(Order),
    /// Sort on the `started_at` column.
    StartedAt(Order),
    /// Sort on the `completed_at` column.
    CompletedAt(Order),
    /// Sort on the `name` column.
    Name(Order),
    /// Sort on the `status` column.
    Status(Order),
}
97
/// Builder for filtering, sorting, and paginating task queries.
///
/// ```ignore
/// let query = TaskQuery::default()
///     .status(TaskStatus::Running)
///     .kind("WORKFLOW")
///     .root_only(true)
///     .sort(TaskSort::CreatedAt(Order::Desc))
///     .limit(20);
/// let tasks = Ctx::list(&db, query).await?;
/// ```
pub struct TaskQuery {
    /// Exact-match filter on status; `None` matches any status.
    pub status: Option<TaskStatus>,
    /// Exact-match filter on kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: Option<String>,
    /// Restrict results to direct children of this task.
    pub parent_id: Option<Uuid>,
    /// When true, only tasks with no parent are returned.
    pub root_only: bool,
    /// Exact-match filter on task name.
    pub name: Option<String>,
    /// Exact-match filter on queue name.
    pub queue_name: Option<String>,
    /// Sort column and direction (default: `created_at` descending).
    pub sort: TaskSort,
    /// Maximum number of rows to return.
    pub limit: Option<u64>,
    /// Number of rows to skip before returning results.
    pub offset: Option<u64>,
}
120
121impl Default for TaskQuery {
122    fn default() -> Self {
123        Self {
124            status: None,
125            kind: None,
126            parent_id: None,
127            root_only: false,
128            name: None,
129            queue_name: None,
130            sort: TaskSort::CreatedAt(Order::Desc),
131            limit: None,
132            offset: None,
133        }
134    }
135}
136
137impl TaskQuery {
138    /// Filter by status.
139    pub fn status(mut self, status: TaskStatus) -> Self {
140        self.status = Some(status);
141        self
142    }
143
144    /// Filter by kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
145    pub fn kind(mut self, kind: &str) -> Self {
146        self.kind = Some(kind.to_string());
147        self
148    }
149
150    /// Filter by parent task ID (direct children only).
151    pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152        self.parent_id = Some(parent_id);
153        self
154    }
155
156    /// Only return root tasks (no parent).
157    pub fn root_only(mut self, root_only: bool) -> Self {
158        self.root_only = root_only;
159        self
160    }
161
162    /// Filter by task name.
163    pub fn name(mut self, name: &str) -> Self {
164        self.name = Some(name.to_string());
165        self
166    }
167
168    /// Filter by queue name.
169    pub fn queue_name(mut self, queue: &str) -> Self {
170        self.queue_name = Some(queue.to_string());
171        self
172    }
173
174    /// Set the sort order.
175    pub fn sort(mut self, sort: TaskSort) -> Self {
176        self.sort = sort;
177        self
178    }
179
180    /// Limit the number of results.
181    pub fn limit(mut self, limit: u64) -> Self {
182        self.limit = Some(limit);
183        self
184    }
185
186    /// Skip the first N results.
187    pub fn offset(mut self, offset: u64) -> Self {
188        self.offset = Some(offset);
189        self
190    }
191}
192
/// Summary of a task returned by `Ctx::list()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    /// Task ID.
    pub id: Uuid,
    /// Parent task, if this is a child step/workflow (`None` for roots).
    pub parent_id: Option<Uuid>,
    /// Task name (unique per parent, per the schema's UNIQUE constraint).
    pub name: String,
    /// Current lifecycle status.
    pub status: TaskStatus,
    /// Task kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: String,
    /// JSON input the task was started with, if any.
    pub input: Option<serde_json::Value>,
    /// JSON output saved on completion, if any.
    pub output: Option<serde_json::Value>,
    /// Error message recorded on failure, if any.
    pub error: Option<String>,
    /// Queue this task was submitted to, if any.
    pub queue_name: Option<String>,
    /// When the task row was created.
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    /// When execution began, if it has started.
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    /// When the task reached a terminal state, if finished.
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
209
impl From<durable_db::entity::task::Model> for TaskSummary {
    /// Straight field-by-field projection of a full task row into its
    /// user-facing summary (internal columns such as executor/retry state
    /// are dropped).
    fn from(m: durable_db::entity::task::Model) -> Self {
        Self {
            id: m.id,
            parent_id: m.parent_id,
            name: m.name,
            status: m.status,
            kind: m.kind,
            input: m.input,
            output: m.output,
            error: m.error,
            queue_name: m.queue_name,
            created_at: m.created_at,
            started_at: m.started_at,
            completed_at: m.completed_at,
        }
    }
}
228
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Connection used for all checkpoint reads and writes.
    db: DatabaseConnection,
    // The task this context is scoped to.
    task_id: Uuid,
    // Monotonic counter handing out the (parent_id, sequence) slot for each
    // step/child created through this context.
    sequence: AtomicI32,
    // Executor tag inherited by child contexts; set via `start_with_executor`
    // so heartbeat-based recovery can find tasks owned by this process.
    executor_id: Option<String>,
}
240
241impl Ctx {
242    // ── Workflow lifecycle (user-facing) ──────────────────────────
243
    /// Start a new root workflow. Always creates a fresh task with a new ID.
    ///
    /// To reactivate a paused workflow, use [`Ctx::resume`] instead.
    ///
    /// Equivalent to [`Ctx::start_with_executor`] with no executor tag.
    ///
    /// ```ignore
    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
    /// ```
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        Self::start_with_executor(db, name, input, None).await
    }
258
259    /// Start a new root workflow, tagging the task with an `executor_id` so
260    /// heartbeat-based recovery can find it if this process crashes.
261    ///
262    /// Prefer [`DurableInstance::start_workflow`] which passes the executor_id
263    /// automatically.
264    pub async fn start_with_executor(
265        db: &DatabaseConnection,
266        name: &str,
267        input: Option<serde_json::Value>,
268        executor_id: Option<String>,
269    ) -> Result<Self, DurableError> {
270        let task_id = Uuid::new_v4();
271        let input_json = match &input {
272            Some(v) => serde_json::to_string(v)?,
273            None => "null".to_string(),
274        };
275
276        let executor_col = if executor_id.is_some() { ", executor_id" } else { "" };
277        let executor_val = match &executor_id {
278            Some(eid) => format!(", '{eid}'"),
279            None => String::new(),
280        };
281
282        let txn = db.begin().await?;
283        let sql = format!(
284            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{executor_col}) \
285             VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now(){executor_val})"
286        );
287        txn.execute(Statement::from_string(DbBackend::Postgres, sql))
288            .await?;
289        txn.commit().await?;
290
291        Ok(Self {
292            db: db.clone(),
293            task_id,
294            sequence: AtomicI32::new(0),
295            executor_id,
296        })
297    }
298
299    /// Attach to an existing workflow by task ID. Used to resume a running or
300    /// recovered workflow without creating a new task row.
301    ///
302    /// ```ignore
303    /// let ctx = Ctx::from_id(&db, task_id).await?;
304    /// ```
305    pub async fn from_id(
306        db: &DatabaseConnection,
307        task_id: Uuid,
308    ) -> Result<Self, DurableError> {
309        // Verify the task exists
310        let model = Task::find_by_id(task_id).one(db).await?;
311        let model =
312            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
313
314        // Start sequence at 0 so that replaying steps from the beginning works
315        // correctly. Steps are looked up by (parent_id, sequence), so the caller
316        // will replay completed steps (getting saved outputs) before executing
317        // any new steps.
318        Ok(Self {
319            db: db.clone(),
320            task_id,
321            sequence: AtomicI32::new(0),
322            executor_id: model.executor_id,
323        })
324    }
325
326    /// Run a step. If already completed, returns saved output. Otherwise executes
327    /// the closure, saves the result, and returns it.
328    ///
329    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
330    ///
331    /// ```ignore
332    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
333    /// ```
334    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
335    where
336        T: Serialize + DeserializeOwned,
337        F: FnOnce() -> Fut,
338        Fut: std::future::Future<Output = Result<T, DurableError>>,
339    {
340        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
341
342        // Check if workflow is paused or cancelled before executing
343        check_status(&self.db, self.task_id).await?;
344
345        // Check if parent task's deadline has passed before executing
346        check_deadline(&self.db, self.task_id).await?;
347
348        // Begin transaction — the FOR UPDATE lock is held throughout step execution
349        let txn = self.db.begin().await?;
350
351        // Find or create — idempotent via UNIQUE(parent_id, name)
352        // Returns (step_id, Option<saved_output>) where saved_output is Some iff COMPLETED.
353        // Use FOR UPDATE SKIP LOCKED so only one worker can execute a given step.
354        // max_retries=0: step() does not retry; use step_with_retry() for retries.
355        let (step_id, saved_output) = find_or_create_task(
356            &txn,
357            Some(self.task_id),
358            Some(seq),
359            name,
360            "STEP",
361            None,
362            true,
363            Some(0),
364        )
365        .await?;
366
367        // If already completed, replay from saved output
368        if let Some(output) = saved_output {
369            txn.commit().await?;
370            let val: T = serde_json::from_value(output)?;
371            tracing::debug!(step = name, seq, "replaying saved output");
372            return Ok(val);
373        }
374
375        // Execute the step closure within the transaction (row lock is held)
376        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
377        match f().await {
378            Ok(val) => {
379                let json = serde_json::to_value(&val)?;
380                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
381                txn.commit().await?;
382                tracing::debug!(step = name, seq, "step completed");
383                Ok(val)
384            }
385            Err(e) => {
386                let err_msg = e.to_string();
387                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
388                txn.commit().await?;
389                Err(e)
390            }
391        }
392    }
393
    /// Run a DB-only step inside a single Postgres transaction.
    ///
    /// Both the user's DB work and the checkpoint save happen in the same transaction,
    /// ensuring atomicity. If the closure returns an error, both the user writes and
    /// the checkpoint are rolled back.
    ///
    /// Unlike [`Ctx::step`], no deadline check is performed here, and the step
    /// row is created outside the transaction without requesting a row lock.
    ///
    /// ```ignore
    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
    ///     do_db_work(tx).await
    /// })).await?;
    /// ```
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        // Claim the next (parent_id, sequence) slot for this step.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create the step task record OUTSIDE the transaction.
        // This is idempotent (UNIQUE constraint) and must exist before we begin.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // If already completed, replay from saved output.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Rollback happens automatically when tx is dropped.
                // Record the failure on the main connection (outside the rolled-back tx).
                // NOTE(review): if this fail_task write errors, `?` propagates the
                // DB error instead of the closure's error `e` — confirm intentional.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
462
463    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
464    ///
465    /// ```ignore
466    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
467    /// // use child_ctx.step(...) for steps inside the child
468    /// child_ctx.complete(json!({"done": true})).await?;
469    /// ```
470    pub async fn child(
471        &self,
472        name: &str,
473        input: Option<serde_json::Value>,
474    ) -> Result<Self, DurableError> {
475        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
476
477        // Check if workflow is paused or cancelled before executing
478        check_status(&self.db, self.task_id).await?;
479
480        let txn = self.db.begin().await?;
481        // Child workflows also use plain find-or-create without locking.
482        let (child_id, _saved) = find_or_create_task(
483            &txn,
484            Some(self.task_id),
485            Some(seq),
486            name,
487            "WORKFLOW",
488            input,
489            false,
490            None,
491        )
492        .await?;
493
494        // If child already completed, return a Ctx that will replay
495        // (the caller should check is_completed() or just run steps which will replay)
496        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
497        txn.commit().await?;
498
499        Ok(Self {
500            db: self.db.clone(),
501            task_id: child_id,
502            sequence: AtomicI32::new(0),
503            executor_id: self.executor_id.clone(),
504        })
505    }
506
507    /// Check if this workflow/child was already completed (for skipping in parent).
508    pub async fn is_completed(&self) -> Result<bool, DurableError> {
509        let status = get_status(&self.db, self.task_id).await?;
510        Ok(status == Some(TaskStatus::Completed))
511    }
512
513    /// Get the saved output if this task is completed.
514    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
515        match get_output(&self.db, self.task_id).await? {
516            Some(val) => Ok(Some(serde_json::from_value(val)?)),
517            None => Ok(None),
518        }
519    }
520
521    /// Mark this workflow as completed with an output value.
522    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
523        let json = serde_json::to_value(output)?;
524        let db = &self.db;
525        let task_id = self.task_id;
526        retry_db_write(|| complete_task(db, task_id, json.clone())).await
527    }
528
    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts.
    ///
    /// The retry count is persisted on the task row, so a workflow resumed
    /// after a restart continues from the recorded attempt number rather than
    /// starting over.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts)
        // NOTE(review): max_retries comes from the stored row, which may predate
        // the `policy` passed on this call if the step was created by an earlier run.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop
        loop {
            // Re-check status before each retry attempt
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration.
                        // The factor is clamped to >= 1.0, so a multiplier below 1
                        // cannot shrink the delay under initial_backoff.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
638
639    /// Mark this workflow as failed.
640    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
641        let db = &self.db;
642        let task_id = self.task_id;
643        retry_db_write(|| fail_task(db, task_id, error)).await
644    }
645
646    /// Set the timeout for this task in milliseconds.
647    ///
648    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
649    /// for recovery by `Executor::recover()`.
650    ///
651    /// If the task is currently RUNNING and has a `started_at`, this also
652    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
653    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
654        let sql = format!(
655            "UPDATE durable.task \
656             SET timeout_ms = {timeout_ms}, \
657                 deadline_epoch_ms = CASE \
658                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
659                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
660                     ELSE deadline_epoch_ms \
661                 END \
662             WHERE id = '{}'",
663            self.task_id
664        );
665        self.db
666            .execute(Statement::from_string(DbBackend::Postgres, sql))
667            .await?;
668        Ok(())
669    }
670
    /// Start a new root workflow with a timeout in milliseconds.
    ///
    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
    /// Like `Ctx::start`, this always creates a fresh task — it does not resume.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
684
    /// Start a new root workflow with a timeout, tagging it with an executor_id.
    ///
    /// Equivalent to `Ctx::start_with_executor` followed by `set_timeout`.
    /// Like `Ctx::start`, this always creates a fresh task — it does not resume.
    pub async fn start_with_timeout_and_executor(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
        executor_id: Option<String>,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start_with_executor(db, name, input, executor_id).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
697
698    // ── Workflow control (management API) ─────────────────────────
699
700    /// Pause a workflow by ID. Sets status to PAUSED and recursively
701    /// cascades to all PENDING/RUNNING descendants (children, grandchildren, etc.).
702    ///
703    /// Only workflows in PENDING or RUNNING status can be paused.
704    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
705        let model = Task::find_by_id(task_id).one(db).await?;
706        let model =
707            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
708
709        match model.status {
710            TaskStatus::Pending | TaskStatus::Running => {}
711            status => {
712                return Err(DurableError::custom(format!(
713                    "cannot pause task in {status} status"
714                )));
715            }
716        }
717
718        // Pause the task itself and all descendants in one recursive CTE
719        let sql = format!(
720            "WITH RECURSIVE descendants AS ( \
721                 SELECT id FROM durable.task WHERE id = '{task_id}' \
722                 UNION ALL \
723                 SELECT t.id FROM durable.task t \
724                 INNER JOIN descendants d ON t.parent_id = d.id \
725             ) \
726             UPDATE durable.task SET status = 'PAUSED' \
727             WHERE id IN (SELECT id FROM descendants) \
728               AND status IN ('PENDING', 'RUNNING')"
729        );
730        db.execute(Statement::from_string(DbBackend::Postgres, sql))
731            .await?;
732
733        tracing::info!(%task_id, "workflow paused");
734        Ok(())
735    }
736
    /// Resume a paused workflow by ID. Sets status back to RUNNING and
    /// recursively cascades to all PAUSED descendants (resetting them to PENDING).
    ///
    /// Fails unless the task is currently PAUSED.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Resume the root task back to RUNNING
        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // Recursively resume all PAUSED descendants back to PENDING.
        // Unlike pause()/cancel(), the CTE seeds from the task's children
        // (parent_id = task_id) because the root itself was updated above.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
               AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
774
775    /// Cancel a workflow by ID. Sets status to CANCELLED and recursively
776    /// cascades to all non-terminal descendants.
777    ///
778    /// Cancellation is terminal — a cancelled workflow cannot be resumed.
779    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
780        let model = Task::find_by_id(task_id).one(db).await?;
781        let model =
782            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
783
784        match model.status {
785            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
786                return Err(DurableError::custom(format!(
787                    "cannot cancel task in {} status",
788                    model.status
789                )));
790            }
791            _ => {}
792        }
793
794        // Cancel the task itself and all non-terminal descendants in one recursive CTE
795        let sql = format!(
796            "WITH RECURSIVE descendants AS ( \
797                 SELECT id FROM durable.task WHERE id = '{task_id}' \
798                 UNION ALL \
799                 SELECT t.id FROM durable.task t \
800                 INNER JOIN descendants d ON t.parent_id = d.id \
801             ) \
802             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
803             WHERE id IN (SELECT id FROM descendants) \
804               AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
805        );
806        db.execute(Statement::from_string(DbBackend::Postgres, sql))
807            .await?;
808
809        tracing::info!(%task_id, "workflow cancelled");
810        Ok(())
811    }
812
813    // ── Query API ─────────────────────────────────────────────────
814
815    /// List tasks matching the given filter, with sorting and pagination.
816    ///
817    /// ```ignore
818    /// let tasks = Ctx::list(&db, TaskQuery::default().status("RUNNING").limit(10)).await?;
819    /// ```
820    pub async fn list(
821        db: &DatabaseConnection,
822        query: TaskQuery,
823    ) -> Result<Vec<TaskSummary>, DurableError> {
824        let mut select = Task::find();
825
826        // Filters
827        if let Some(status) = &query.status {
828            select = select.filter(TaskColumn::Status.eq(status.to_string()));
829        }
830        if let Some(kind) = &query.kind {
831            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
832        }
833        if let Some(parent_id) = query.parent_id {
834            select = select.filter(TaskColumn::ParentId.eq(parent_id));
835        }
836        if query.root_only {
837            select = select.filter(TaskColumn::ParentId.is_null());
838        }
839        if let Some(name) = &query.name {
840            select = select.filter(TaskColumn::Name.eq(name.as_str()));
841        }
842        if let Some(queue) = &query.queue_name {
843            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
844        }
845
846        // Sorting
847        let (col, order) = match query.sort {
848            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
849            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
850            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
851            TaskSort::Name(ord) => (TaskColumn::Name, ord),
852            TaskSort::Status(ord) => (TaskColumn::Status, ord),
853        };
854        select = select.order_by(col, order);
855
856        // Pagination
857        if let Some(offset) = query.offset {
858            select = select.offset(offset);
859        }
860        if let Some(limit) = query.limit {
861            select = select.limit(limit);
862        }
863
864        let models = select.all(db).await?;
865
866        Ok(models.into_iter().map(TaskSummary::from).collect())
867    }
868
869    /// Count tasks matching the given filter.
870    pub async fn count(
871        db: &DatabaseConnection,
872        query: TaskQuery,
873    ) -> Result<u64, DurableError> {
874        let mut select = Task::find();
875
876        if let Some(status) = &query.status {
877            select = select.filter(TaskColumn::Status.eq(status.to_string()));
878        }
879        if let Some(kind) = &query.kind {
880            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
881        }
882        if let Some(parent_id) = query.parent_id {
883            select = select.filter(TaskColumn::ParentId.eq(parent_id));
884        }
885        if query.root_only {
886            select = select.filter(TaskColumn::ParentId.is_null());
887        }
888        if let Some(name) = &query.name {
889            select = select.filter(TaskColumn::Name.eq(name.as_str()));
890        }
891        if let Some(queue) = &query.queue_name {
892            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
893        }
894
895        let count = select.count(db).await?;
896        Ok(count)
897    }
898
899    // ── Accessors ────────────────────────────────────────────────
900
    /// Borrow the underlying database connection handle.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }
904
    /// The id of the task this context is bound to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }
908
    /// Atomically reserve the next step sequence number.
    ///
    /// `fetch_add` returns the PRE-increment value, so consecutive calls yield
    /// consecutive integers starting from the counter's initial value (which
    /// is set outside this view).
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
912}
913
914// ── Internal SQL helpers ─────────────────────────────────────────────
915
916/// Find an existing task by (parent_id, name) or create a new one.
917///
918/// Returns `(task_id, Option<saved_output>)`:
919/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
920/// - `saved_output` is `None` when the task is new or in-progress.
921///
922/// When `lock` is `true` and an existing non-completed task is found, this
923/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
924/// another worker holds the lock, `DurableError::StepLocked` is returned so
925/// the caller can skip execution rather than double-firing side effects.
926///
927/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
928/// and child-workflow creation where concurrent start is safe).
929///
930/// When `lock` is `true`, the caller MUST call this within a transaction so
931/// the row lock is held throughout step execution.
932///
933/// `max_retries`: if Some, overrides the schema default when creating a new task.
934#[allow(clippy::too_many_arguments)]
935async fn find_or_create_task(
936    db: &impl ConnectionTrait,
937    parent_id: Option<Uuid>,
938    sequence: Option<i32>,
939    name: &str,
940    kind: &str,
941    input: Option<serde_json::Value>,
942    lock: bool,
943    max_retries: Option<u32>,
944) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
945    let parent_eq = match parent_id {
946        Some(p) => format!("= '{p}'"),
947        None => "IS NULL".to_string(),
948    };
949    let parent_sql = match parent_id {
950        Some(p) => format!("'{p}'"),
951        None => "NULL".to_string(),
952    };
953
954    if lock {
955        // Locking path (for steps): we need exactly-once execution.
956        //
957        // Strategy:
958        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
959        //    If another transaction is concurrently inserting the same row,
960        //    Postgres will block here until that transaction commits or rolls
961        //    back, ensuring we never see a phantom "not found" for a row being
962        //    inserted.
963        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
964        //    should get it back; if the row existed and is locked by another
965        //    worker we get nothing.
966        // 3. If SKIP LOCKED returns empty, return StepLocked.
967
968        let new_id = Uuid::new_v4();
969        let seq_sql = match sequence {
970            Some(s) => s.to_string(),
971            None => "NULL".to_string(),
972        };
973        let input_sql = match &input {
974            Some(v) => format!("'{}'", serde_json::to_string(v)?),
975            None => "NULL".to_string(),
976        };
977
978        let max_retries_sql = match max_retries {
979            Some(r) => r.to_string(),
980            None => "3".to_string(), // schema default
981        };
982
983        // Step 1: insert-or-skip
984        let insert_sql = format!(
985            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
986             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
987             ON CONFLICT (parent_id, sequence) DO NOTHING"
988        );
989        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
990            .await?;
991
992        // Step 2: lock the row (ours or pre-existing) by (parent_id, sequence)
993        let lock_sql = format!(
994            "SELECT id, status::text, output FROM durable.task \
995             WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
996             FOR UPDATE SKIP LOCKED"
997        );
998        let row = db
999            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
1000            .await?;
1001
1002        if let Some(row) = row {
1003            let id: Uuid = row
1004                .try_get_by_index(0)
1005                .map_err(|e| DurableError::custom(e.to_string()))?;
1006            let status: String = row
1007                .try_get_by_index(1)
1008                .map_err(|e| DurableError::custom(e.to_string()))?;
1009            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
1010
1011            if status == TaskStatus::Completed.to_string() {
1012                // Replay path — return saved output
1013                return Ok((id, output));
1014            }
1015            // Task exists and we hold the lock — proceed to execute
1016            return Ok((id, None));
1017        }
1018
1019        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
1020        Err(DurableError::StepLocked(name.to_string()))
1021    } else {
1022        // Plain find without locking — safe for workflow-level operations.
1023        // Multiple workers resuming the same workflow is fine; individual
1024        // steps will be locked when executed.
1025        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
1026        query = match parent_id {
1027            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
1028            None => query.filter(TaskColumn::ParentId.is_null()),
1029        };
1030        // Skip cancelled/failed tasks — a fresh row will be created instead.
1031        // Completed tasks are still returned so child tasks can replay saved output.
1032        let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
1033        let existing = query
1034            .filter(TaskColumn::Status.is_not_in(status_exclusions))
1035            .one(db)
1036            .await?;
1037
1038        if let Some(model) = existing {
1039            if model.status == TaskStatus::Completed {
1040                return Ok((model.id, model.output));
1041            }
1042            return Ok((model.id, None));
1043        }
1044
1045        // Task does not exist — create it
1046        let id = Uuid::new_v4();
1047        let new_task = TaskActiveModel {
1048            id: Set(id),
1049            parent_id: Set(parent_id),
1050            sequence: Set(sequence),
1051            name: Set(name.to_string()),
1052            kind: Set(kind.to_string()),
1053            status: Set(TaskStatus::Pending),
1054            input: Set(input),
1055            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1056            ..Default::default()
1057        };
1058        new_task.insert(db).await?;
1059
1060        Ok((id, None))
1061    }
1062}
1063
1064async fn get_output(
1065    db: &impl ConnectionTrait,
1066    task_id: Uuid,
1067) -> Result<Option<serde_json::Value>, DurableError> {
1068    let model = Task::find_by_id(task_id)
1069        .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1070        .one(db)
1071        .await?;
1072
1073    Ok(model.and_then(|m| m.output))
1074}
1075
1076async fn get_status(
1077    db: &impl ConnectionTrait,
1078    task_id: Uuid,
1079) -> Result<Option<TaskStatus>, DurableError> {
1080    let model = Task::find_by_id(task_id).one(db).await?;
1081
1082    Ok(model.map(|m| m.status))
1083}
1084
1085/// Returns (retry_count, max_retries) for a task.
1086async fn get_retry_info(
1087    db: &DatabaseConnection,
1088    task_id: Uuid,
1089) -> Result<(u32, u32), DurableError> {
1090    let model = Task::find_by_id(task_id).one(db).await?;
1091
1092    match model {
1093        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1094        None => Err(DurableError::custom(format!(
1095            "task {task_id} not found when reading retry info"
1096        ))),
1097    }
1098}
1099
1100/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
1101async fn increment_retry_count(
1102    db: &DatabaseConnection,
1103    task_id: Uuid,
1104) -> Result<u32, DurableError> {
1105    let model = Task::find_by_id(task_id).one(db).await?;
1106
1107    match model {
1108        Some(m) => {
1109            let new_count = m.retry_count + 1;
1110            let mut active: TaskActiveModel = m.into();
1111            active.retry_count = Set(new_count);
1112            active.status = Set(TaskStatus::Pending);
1113            active.error = Set(None);
1114            active.completed_at = Set(None);
1115            active.update(db).await?;
1116            Ok(new_count as u32)
1117        }
1118        None => Err(DurableError::custom(format!(
1119            "task {task_id} not found when incrementing retry count"
1120        ))),
1121    }
1122}
1123
// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
// for conditional deadline_epoch_ms computation or COALESCE on started_at.
//
// Behavior of the statement below:
// - started_at is stamped on the FIRST status transition only (COALESCE keeps
//   an existing value) — but it is stamped for EVERY target status, not just
//   RUNNING. NOTE(review): confirm that stamping started_at when jumping
//   straight to a terminal status is intended.
// - deadline_epoch_ms is set to now + timeout_ms only when entering RUNNING
//   with a timeout configured; in all other cases the existing value is kept.
// - `status` is interpolated via Display twice (SET and CASE); this relies on
//   TaskStatus's Display matching the DB enum labels (e.g. 'RUNNING').
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1146
1147/// Check if the task is paused or cancelled. Returns an error if so.
1148async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1149    let status = get_status(db, task_id).await?;
1150    match status {
1151        Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1152        Some(TaskStatus::Cancelled) => {
1153            Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1154        }
1155        _ => Ok(()),
1156    }
1157}
1158
1159/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
1160async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1161    let model = Task::find_by_id(task_id).one(db).await?;
1162
1163    if let Some(m) = model
1164        && let Some(deadline_ms) = m.deadline_epoch_ms
1165    {
1166        let now_ms = std::time::SystemTime::now()
1167            .duration_since(std::time::UNIX_EPOCH)
1168            .map(|d| d.as_millis() as i64)
1169            .unwrap_or(0);
1170        if now_ms > deadline_ms {
1171            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1172        }
1173    }
1174
1175    Ok(())
1176}
1177
1178async fn complete_task(
1179    db: &impl ConnectionTrait,
1180    task_id: Uuid,
1181    output: serde_json::Value,
1182) -> Result<(), DurableError> {
1183    let model = Task::find_by_id(task_id).one(db).await?;
1184
1185    if let Some(m) = model {
1186        let mut active: TaskActiveModel = m.into();
1187        active.status = Set(TaskStatus::Completed);
1188        active.output = Set(Some(output));
1189        active.completed_at = Set(Some(chrono::Utc::now().into()));
1190        active.update(db).await?;
1191    }
1192    Ok(())
1193}
1194
1195async fn fail_task(
1196    db: &impl ConnectionTrait,
1197    task_id: Uuid,
1198    error: &str,
1199) -> Result<(), DurableError> {
1200    let model = Task::find_by_id(task_id).one(db).await?;
1201
1202    if let Some(m) = model {
1203        let mut active: TaskActiveModel = m.into();
1204        active.status = Set(TaskStatus::Failed);
1205        active.error = Set(Some(error.to_string()));
1206        active.completed_at = Set(Some(chrono::Utc::now().into()));
1207        active.update(db).await?;
1208    }
1209    Ok(())
1210}
1211
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A closure that always succeeds must be invoked exactly once and
    /// return Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let outcome = retry_db_write(|| {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by success: the helper must return Ok
    /// after exactly three invocations.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let outcome = retry_db_write(|| {
            let counter = Arc::clone(&counter);
            async move {
                match counter.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    ))),
                    _ => Ok(()),
                }
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A closure that never succeeds must be invoked 1 + MAX_CHECKPOINT_RETRIES
    /// times in total and ultimately return an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let outcome = retry_db_write(|| {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(outcome.is_err());
        // one initial attempt plus MAX_CHECKPOINT_RETRIES retry attempts
        assert_eq!(attempts.load(Ordering::SeqCst), MAX_CHECKPOINT_RETRIES + 1);
    }

    /// When every attempt fails, the FIRST error must be surfaced — not the
    /// error from the last retry.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let outcome = retry_db_write(|| {
            let counter = Arc::clone(&counter);
            async move {
                let attempt = counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    attempt
                ))))
            }
        })
        .await;
        let err = outcome.unwrap_err();
        // The first attempt produced "error-0"; that is what must come back.
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}