// durable/ctx.rs — workflow/step execution context for the durable SDK.
1use std::pin::Pin;
2
3use durable_db::entity::task::{
4    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task, TaskStatus,
5};
6use sea_orm::{
7    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
8    DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
9    Statement, TransactionTrait,
10};
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13use std::sync::atomic::{AtomicI32, Ordering};
14use std::time::Duration;
15use uuid::Uuid;
16
17use crate::error::DurableError;
18
19// ── Retry constants ───────────────────────────────────────────────────────────
20
/// Number of additional attempts after the initial failed checkpoint write.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
/// Base backoff in milliseconds; doubled each retry (100ms, 200ms, 400ms).
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
23
24/// Retry a fallible DB write with exponential backoff.
25///
26/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
27/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
28/// error if all retries are exhausted.
29async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
30where
31    F: FnMut() -> Fut,
32    Fut: std::future::Future<Output = Result<(), DurableError>>,
33{
34    match f().await {
35        Ok(()) => Ok(()),
36        Err(first_err) => {
37            for i in 0..MAX_CHECKPOINT_RETRIES {
38                tokio::time::sleep(Duration::from_millis(
39                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
40                ))
41                .await;
42                if f().await.is_ok() {
43                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
44                    return Ok(());
45                }
46            }
47            Err(first_err)
48        }
49    }
50}
51
/// Policy for retrying a step on failure.
pub struct RetryPolicy {
    /// Number of retries allowed after the first failed attempt (0 = fail immediately).
    pub max_retries: u32,
    /// Backoff before the first retry; subsequent retries scale by `backoff_multiplier`.
    pub initial_backoff: std::time::Duration,
    /// Factor applied to the backoff on each retry (1.0 = fixed, 2.0 = exponential).
    pub backoff_multiplier: f64,
}
58
59impl RetryPolicy {
60    /// No retries — fails immediately on error.
61    pub fn none() -> Self {
62        Self {
63            max_retries: 0,
64            initial_backoff: std::time::Duration::from_secs(0),
65            backoff_multiplier: 1.0,
66        }
67    }
68
69    /// Exponential backoff: backoff doubles each retry.
70    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
71        Self {
72            max_retries,
73            initial_backoff,
74            backoff_multiplier: 2.0,
75        }
76    }
77
78    /// Fixed backoff: same duration between all retries.
79    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
80        Self {
81            max_retries,
82            initial_backoff: backoff,
83            backoff_multiplier: 1.0,
84        }
85    }
86}
87
/// Sort order for task listing.
pub enum TaskSort {
    /// Order by creation timestamp.
    CreatedAt(Order),
    /// Order by the time the task started running.
    StartedAt(Order),
    /// Order by completion timestamp.
    CompletedAt(Order),
    /// Order alphabetically by task name.
    Name(Order),
    /// Order by status value.
    Status(Order),
}
96
/// Builder for filtering, sorting, and paginating task queries.
///
/// ```ignore
/// let query = TaskQuery::default()
///     .status(TaskStatus::Running)
///     .kind("WORKFLOW")
///     .root_only(true)
///     .sort(TaskSort::CreatedAt(Order::Desc))
///     .limit(20);
/// let tasks = Ctx::list(&db, query).await?;
/// ```
pub struct TaskQuery {
    /// Only tasks with this status.
    pub status: Option<TaskStatus>,
    /// Only tasks of this kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: Option<String>,
    /// Only direct children of this task.
    pub parent_id: Option<Uuid>,
    /// When true, only root tasks (no parent).
    pub root_only: bool,
    /// Only tasks with this exact name.
    pub name: Option<String>,
    /// Only tasks on this queue.
    pub queue_name: Option<String>,
    /// Sort order (default: newest first by creation time).
    pub sort: TaskSort,
    /// Maximum number of rows to return.
    pub limit: Option<u64>,
    /// Number of rows to skip.
    pub offset: Option<u64>,
}
119
impl Default for TaskQuery {
    /// No filters, newest-first by `created_at`, no pagination.
    fn default() -> Self {
        Self {
            status: None,
            kind: None,
            parent_id: None,
            root_only: false,
            name: None,
            queue_name: None,
            sort: TaskSort::CreatedAt(Order::Desc),
            limit: None,
            offset: None,
        }
    }
}
135
136impl TaskQuery {
137    /// Filter by status.
138    pub fn status(mut self, status: TaskStatus) -> Self {
139        self.status = Some(status);
140        self
141    }
142
143    /// Filter by kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
144    pub fn kind(mut self, kind: &str) -> Self {
145        self.kind = Some(kind.to_string());
146        self
147    }
148
149    /// Filter by parent task ID (direct children only).
150    pub fn parent_id(mut self, parent_id: Uuid) -> Self {
151        self.parent_id = Some(parent_id);
152        self
153    }
154
155    /// Only return root tasks (no parent).
156    pub fn root_only(mut self, root_only: bool) -> Self {
157        self.root_only = root_only;
158        self
159    }
160
161    /// Filter by task name.
162    pub fn name(mut self, name: &str) -> Self {
163        self.name = Some(name.to_string());
164        self
165    }
166
167    /// Filter by queue name.
168    pub fn queue_name(mut self, queue: &str) -> Self {
169        self.queue_name = Some(queue.to_string());
170        self
171    }
172
173    /// Set the sort order.
174    pub fn sort(mut self, sort: TaskSort) -> Self {
175        self.sort = sort;
176        self
177    }
178
179    /// Limit the number of results.
180    pub fn limit(mut self, limit: u64) -> Self {
181        self.limit = Some(limit);
182        self
183    }
184
185    /// Skip the first N results.
186    pub fn offset(mut self, offset: u64) -> Self {
187        self.offset = Some(offset);
188        self
189    }
190}
191
/// Summary of a task returned by `Ctx::list()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    /// Task primary key.
    pub id: Uuid,
    /// Parent task, or `None` for a root workflow.
    pub parent_id: Option<Uuid>,
    /// Human-readable task name (unique per parent).
    pub name: String,
    /// Current lifecycle status.
    pub status: TaskStatus,
    /// Task kind, e.g. "WORKFLOW", "STEP", "TRANSACTION".
    pub kind: String,
    /// Input payload recorded at creation, if any.
    pub input: Option<serde_json::Value>,
    /// Saved output, present once the task completed.
    pub output: Option<serde_json::Value>,
    /// Error message, present when the task failed.
    pub error: Option<String>,
    /// Queue the task was dispatched on, if any.
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
208
impl From<durable_db::entity::task::Model> for TaskSummary {
    /// Direct field-for-field projection of the DB row into the public summary.
    fn from(m: durable_db::entity::task::Model) -> Self {
        Self {
            id: m.id,
            parent_id: m.parent_id,
            name: m.name,
            status: m.status,
            kind: m.kind,
            input: m.input,
            output: m.output,
            error: m.error,
            queue_name: m.queue_name,
            created_at: m.created_at,
            started_at: m.started_at,
            completed_at: m.completed_at,
        }
    }
}
227
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Connection used for all checkpoint reads/writes; cloned into child contexts.
    db: DatabaseConnection,
    // The workflow (or child workflow) this context is scoped to.
    task_id: Uuid,
    // Monotonic per-context counter; each step/child claims the next slot in
    // call order, which is what makes deterministic replay possible.
    sequence: AtomicI32,
}
238
239impl Ctx {
240    // ── Workflow lifecycle (user-facing) ──────────────────────────
241
    /// Start or resume a root workflow by name.
    ///
    /// Idempotent: a root task with this `name` is found or created (via the
    /// `(parent_id, name)` unique constraint), then marked RUNNING.
    ///
    /// ```ignore
    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
    /// ```
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let txn = db.begin().await?;
        // Workflows use plain find-or-create without FOR UPDATE SKIP LOCKED.
        // Multiple workers resuming the same workflow is fine — they share the
        // same task row and each step will be individually locked.
        let (task_id, _saved) =
            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
        // NOTE(review): retrying a statement inside an already-aborted Postgres
        // transaction cannot succeed — these retries presumably target
        // transient client-side failures only; confirm the intent.
        retry_db_write(|| set_status(&txn, task_id, TaskStatus::Running)).await?;
        txn.commit().await?;
        Ok(Self {
            db: db.clone(),
            task_id,
            // Sequence numbers restart at 0 on every (re)start; replay relies
            // on steps being reached in the same order each run.
            sequence: AtomicI32::new(0),
        })
    }
266
    /// Run a step. If already completed, returns saved output. Otherwise executes
    /// the closure, saves the result, and returns it.
    ///
    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
    ///
    /// # Errors
    /// - `DurableError::StepLocked` when another worker holds the row lock
    ///   (see `find_or_create_task`).
    /// - Any error from the closure; the step is then durably marked FAILED.
    ///
    /// ```ignore
    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
    /// ```
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Post-increment: the step claims its sequence slot even if a later
        // check bails out, keeping the numbering stable across replays.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Check if parent task's deadline has passed before executing
        check_deadline(&self.db, self.task_id).await?;

        // Begin transaction — the FOR UPDATE lock is held throughout step execution
        let txn = self.db.begin().await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Returns (step_id, Option<saved_output>) where saved_output is Some iff COMPLETED.
        // Use FOR UPDATE SKIP LOCKED so only one worker can execute a given step.
        // max_retries=0: step() does not retry; use step_with_retry() for retries.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Execute the step closure within the transaction (row lock is held)
        // NOTE(review): a server-side failure here aborts the transaction, in
        // which case the in-txn retries below cannot succeed — confirm they
        // are only meant for transient client-side errors.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Commit the FAILED status so the failure is durable, then
                // propagate the closure's original error to the caller.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }
334
    /// Run a DB-only step inside a single Postgres transaction.
    ///
    /// Both the user's DB work and the checkpoint save happen in the same transaction,
    /// ensuring atomicity. If the closure returns an error, both the user writes and
    /// the checkpoint are rolled back.
    ///
    /// NOTE(review): unlike `step()`, this path takes no FOR UPDATE row lock
    /// (lock=false) and performs no deadline check, so two workers reaching
    /// the same transaction step concurrently could both run the closure —
    /// confirm this is acceptable because the work itself is transactional.
    ///
    /// ```ignore
    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
    ///     do_db_work(tx).await
    /// })).await?;
    /// ```
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create the step task record OUTSIDE the transaction.
        // This is idempotent (UNIQUE constraint) and must exist before we begin.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // If already completed, replay from saved output.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Rollback happens automatically when tx is dropped.
                // Record the failure on the main connection (outside the rolled-back tx).
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
403
    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
    ///
    /// Like steps, children are found-or-created via the `(parent_id, name)`
    /// unique constraint, so re-running the parent resumes the same child row.
    ///
    /// ```ignore
    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
    /// // use child_ctx.step(...) for steps inside the child
    /// child_ctx.complete(json!({"done": true})).await?;
    /// ```
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // Child workflows also use plain find-or-create without locking.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        // If child already completed, return a Ctx that will replay
        // (the caller should check is_completed() or just run steps which will replay)
        // NOTE(review): this sets the child to RUNNING even when it already
        // completed — confirm set_status guards terminal states.
        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            // The child gets its own step numbering, starting at 0.
            sequence: AtomicI32::new(0),
        })
    }
446
447    /// Check if this workflow/child was already completed (for skipping in parent).
448    pub async fn is_completed(&self) -> Result<bool, DurableError> {
449        let status = get_status(&self.db, self.task_id).await?;
450        Ok(status == Some(TaskStatus::Completed))
451    }
452
453    /// Get the saved output if this task is completed.
454    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
455        match get_output(&self.db, self.task_id).await? {
456            Some(val) => Ok(Some(serde_json::from_value(val)?)),
457            None => Ok(None),
458        }
459    }
460
461    /// Mark this workflow as completed with an output value.
462    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
463        let json = serde_json::to_value(output)?;
464        let db = &self.db;
465        let task_id = self.task_id;
466        retry_db_write(|| complete_task(db, task_id, json.clone())).await
467    }
468
    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts.
    ///
    /// NOTE(review): also unlike `step()`, no FOR UPDATE row lock is taken
    /// (lock=false) and no deadline check is performed — confirm both are
    /// intentional for retried steps.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts).
        // `max_retries` comes from the row, so a pre-existing step keeps the
        // budget it was created with rather than the policy passed here.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop
        loop {
            // Re-check status before each retry attempt
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration. After the first failure
                        // retry_count == 1, so factor == multiplier^0 == 1 and
                        // the wait equals initial_backoff.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
578
579    /// Mark this workflow as failed.
580    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
581        let db = &self.db;
582        let task_id = self.task_id;
583        retry_db_write(|| fail_task(db, task_id, error)).await
584    }
585
    /// Set the timeout for this task in milliseconds.
    ///
    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
    /// for recovery by `Executor::recover()`.
    ///
    /// If the task is currently RUNNING and has a `started_at`, this also
    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
    ///
    /// Interpolating into the SQL string is safe here: `timeout_ms` is an
    /// `i64` and `self.task_id` a `Uuid`, neither of which can carry SQL
    /// metacharacters.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
                 deadline_epoch_ms = CASE \
                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                     ELSE deadline_epoch_ms \
                 END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
610
611    /// Start or resume a root workflow with a timeout in milliseconds.
612    ///
613    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
614    pub async fn start_with_timeout(
615        db: &DatabaseConnection,
616        name: &str,
617        input: Option<serde_json::Value>,
618        timeout_ms: i64,
619    ) -> Result<Self, DurableError> {
620        let ctx = Self::start(db, name, input).await?;
621        ctx.set_timeout(timeout_ms).await?;
622        Ok(ctx)
623    }
624
625    // ── Workflow control (management API) ─────────────────────────
626
    /// Pause a workflow by ID. Sets status to PAUSED and recursively
    /// cascades to all PENDING/RUNNING descendants (children, grandchildren, etc.).
    ///
    /// Only workflows in PENDING or RUNNING status can be paused.
    ///
    /// # Errors
    /// Fails if the task does not exist or is not in a pausable status.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Pause the task itself and all descendants in one recursive CTE.
        // The UPDATE re-checks status, so rows that moved to a terminal state
        // between the read above and this write are left untouched.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED' \
             WHERE id IN (SELECT id FROM descendants) \
               AND status IN ('PENDING', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }
663
    /// Resume a paused workflow by ID. Sets status back to RUNNING and
    /// recursively cascades to all PAUSED descendants (resetting them to PENDING).
    ///
    /// Note the asymmetry: the root goes straight to RUNNING, while
    /// descendants are reset to PENDING so they can be re-dispatched.
    ///
    /// # Errors
    /// Fails if the task does not exist or is not currently PAUSED.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Resume the root task back to RUNNING
        let sql = format!(
            "UPDATE durable.task SET status = 'RUNNING' WHERE id = '{task_id}'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        // Recursively resume all PAUSED descendants back to PENDING
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
               AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
703
    /// Cancel a workflow by ID. Sets status to CANCELLED and recursively
    /// cascades to all non-terminal descendants.
    ///
    /// Cancellation is terminal — a cancelled workflow cannot be resumed.
    ///
    /// # Errors
    /// Fails if the task does not exist or is already in a terminal status
    /// (COMPLETED, FAILED, or CANCELLED).
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        // Cancel the task itself and all non-terminal descendants in one
        // recursive CTE; completed_at is stamped so the cancellation is dated.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
               AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }
741
742    // ── Query API ─────────────────────────────────────────────────
743
    /// List tasks matching the given filter, with sorting and pagination.
    ///
    /// ```ignore
    /// let tasks = Ctx::list(&db, TaskQuery::default().status(TaskStatus::Running).limit(10)).await?;
    /// ```
    pub async fn list(
        db: &DatabaseConnection,
        query: TaskQuery,
    ) -> Result<Vec<TaskSummary>, DurableError> {
        let mut select = Task::find();

        // Filters — each one is applied only when set, and they combine with AND.
        if let Some(status) = &query.status {
            select = select.filter(TaskColumn::Status.eq(status.to_string()));
        }
        if let Some(kind) = &query.kind {
            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
        }
        if let Some(parent_id) = query.parent_id {
            select = select.filter(TaskColumn::ParentId.eq(parent_id));
        }
        if query.root_only {
            select = select.filter(TaskColumn::ParentId.is_null());
        }
        if let Some(name) = &query.name {
            select = select.filter(TaskColumn::Name.eq(name.as_str()));
        }
        if let Some(queue) = &query.queue_name {
            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
        }

        // Sorting
        let (col, order) = match query.sort {
            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
            TaskSort::Name(ord) => (TaskColumn::Name, ord),
            TaskSort::Status(ord) => (TaskColumn::Status, ord),
        };
        select = select.order_by(col, order);

        // Pagination
        if let Some(offset) = query.offset {
            select = select.offset(offset);
        }
        if let Some(limit) = query.limit {
            select = select.limit(limit);
        }

        let models = select.all(db).await?;

        Ok(models.into_iter().map(TaskSummary::from).collect())
    }
797
798    /// Count tasks matching the given filter.
799    pub async fn count(
800        db: &DatabaseConnection,
801        query: TaskQuery,
802    ) -> Result<u64, DurableError> {
803        let mut select = Task::find();
804
805        if let Some(status) = &query.status {
806            select = select.filter(TaskColumn::Status.eq(status.to_string()));
807        }
808        if let Some(kind) = &query.kind {
809            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
810        }
811        if let Some(parent_id) = query.parent_id {
812            select = select.filter(TaskColumn::ParentId.eq(parent_id));
813        }
814        if query.root_only {
815            select = select.filter(TaskColumn::ParentId.is_null());
816        }
817        if let Some(name) = &query.name {
818            select = select.filter(TaskColumn::Name.eq(name.as_str()));
819        }
820        if let Some(queue) = &query.queue_name {
821            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
822        }
823
824        let count = select.count(db).await?;
825        Ok(count)
826    }
827
828    // ── Accessors ────────────────────────────────────────────────
829
    /// Borrow the underlying database connection (for ad-hoc queries).
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// The ID of the task this context is scoped to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Reserve and return the next step sequence number (post-increment).
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
841}
842
843// ── Internal SQL helpers ─────────────────────────────────────────────
844
845/// Find an existing task by (parent_id, name) or create a new one.
846///
847/// Returns `(task_id, Option<saved_output>)`:
848/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
849/// - `saved_output` is `None` when the task is new or in-progress.
850///
851/// When `lock` is `true` and an existing non-completed task is found, this
852/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
853/// another worker holds the lock, `DurableError::StepLocked` is returned so
854/// the caller can skip execution rather than double-firing side effects.
855///
856/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
857/// and child-workflow creation where concurrent start is safe).
858///
859/// When `lock` is `true`, the caller MUST call this within a transaction so
860/// the row lock is held throughout step execution.
861///
862/// `max_retries`: if Some, overrides the schema default when creating a new task.
863#[allow(clippy::too_many_arguments)]
864async fn find_or_create_task(
865    db: &impl ConnectionTrait,
866    parent_id: Option<Uuid>,
867    sequence: Option<i32>,
868    name: &str,
869    kind: &str,
870    input: Option<serde_json::Value>,
871    lock: bool,
872    max_retries: Option<u32>,
873) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
874    let parent_eq = match parent_id {
875        Some(p) => format!("= '{p}'"),
876        None => "IS NULL".to_string(),
877    };
878    let parent_sql = match parent_id {
879        Some(p) => format!("'{p}'"),
880        None => "NULL".to_string(),
881    };
882
883    if lock {
884        // Locking path (for steps): we need exactly-once execution.
885        //
886        // Strategy:
887        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
888        //    If another transaction is concurrently inserting the same row,
889        //    Postgres will block here until that transaction commits or rolls
890        //    back, ensuring we never see a phantom "not found" for a row being
891        //    inserted.
892        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
893        //    should get it back; if the row existed and is locked by another
894        //    worker we get nothing.
895        // 3. If SKIP LOCKED returns empty, return StepLocked.
896
897        let new_id = Uuid::new_v4();
898        let seq_sql = match sequence {
899            Some(s) => s.to_string(),
900            None => "NULL".to_string(),
901        };
902        let input_sql = match &input {
903            Some(v) => format!("'{}'", serde_json::to_string(v)?),
904            None => "NULL".to_string(),
905        };
906
907        let max_retries_sql = match max_retries {
908            Some(r) => r.to_string(),
909            None => "3".to_string(), // schema default
910        };
911
912        // Step 1: insert-or-skip
913        let insert_sql = format!(
914            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
915             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
916             ON CONFLICT (parent_id, name) DO NOTHING"
917        );
918        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
919            .await?;
920
921        // Step 2: lock the row (ours or pre-existing)
922        let lock_sql = format!(
923            "SELECT id, status, output FROM durable.task \
924             WHERE parent_id {parent_eq} AND name = '{name}' \
925             FOR UPDATE SKIP LOCKED"
926        );
927        let row = db
928            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
929            .await?;
930
931        if let Some(row) = row {
932            let id: Uuid = row
933                .try_get_by_index(0)
934                .map_err(|e| DurableError::custom(e.to_string()))?;
935            let status: String = row
936                .try_get_by_index(1)
937                .map_err(|e| DurableError::custom(e.to_string()))?;
938            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
939
940            if status == TaskStatus::Completed.to_string() {
941                // Replay path — return saved output
942                return Ok((id, output));
943            }
944            // Task exists and we hold the lock — proceed to execute
945            return Ok((id, None));
946        }
947
948        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
949        Err(DurableError::StepLocked(name.to_string()))
950    } else {
951        // Plain find without locking — safe for workflow-level operations.
952        // Multiple workers resuming the same workflow is fine; individual
953        // steps will be locked when executed.
954        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
955        query = match parent_id {
956            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
957            None => query.filter(TaskColumn::ParentId.is_null()),
958        };
959        let existing = query.one(db).await?;
960
961        if let Some(model) = existing {
962            if model.status == TaskStatus::Completed {
963                return Ok((model.id, model.output));
964            }
965            return Ok((model.id, None));
966        }
967
968        // Task does not exist — create it
969        let id = Uuid::new_v4();
970        let new_task = TaskActiveModel {
971            id: Set(id),
972            parent_id: Set(parent_id),
973            sequence: Set(sequence),
974            name: Set(name.to_string()),
975            kind: Set(kind.to_string()),
976            status: Set(TaskStatus::Pending),
977            input: Set(input),
978            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
979            ..Default::default()
980        };
981        new_task.insert(db).await?;
982
983        Ok((id, None))
984    }
985}
986
987async fn get_output(
988    db: &impl ConnectionTrait,
989    task_id: Uuid,
990) -> Result<Option<serde_json::Value>, DurableError> {
991    let model = Task::find_by_id(task_id)
992        .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
993        .one(db)
994        .await?;
995
996    Ok(model.and_then(|m| m.output))
997}
998
999async fn get_status(
1000    db: &impl ConnectionTrait,
1001    task_id: Uuid,
1002) -> Result<Option<TaskStatus>, DurableError> {
1003    let model = Task::find_by_id(task_id).one(db).await?;
1004
1005    Ok(model.map(|m| m.status))
1006}
1007
1008/// Returns (retry_count, max_retries) for a task.
1009async fn get_retry_info(
1010    db: &DatabaseConnection,
1011    task_id: Uuid,
1012) -> Result<(u32, u32), DurableError> {
1013    let model = Task::find_by_id(task_id).one(db).await?;
1014
1015    match model {
1016        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1017        None => Err(DurableError::custom(format!(
1018            "task {task_id} not found when reading retry info"
1019        ))),
1020    }
1021}
1022
1023/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
1024async fn increment_retry_count(
1025    db: &DatabaseConnection,
1026    task_id: Uuid,
1027) -> Result<u32, DurableError> {
1028    let model = Task::find_by_id(task_id).one(db).await?;
1029
1030    match model {
1031        Some(m) => {
1032            let new_count = m.retry_count + 1;
1033            let mut active: TaskActiveModel = m.into();
1034            active.retry_count = Set(new_count);
1035            active.status = Set(TaskStatus::Pending);
1036            active.error = Set(None);
1037            active.completed_at = Set(None);
1038            active.update(db).await?;
1039            Ok(new_count as u32)
1040        }
1041        None => Err(DurableError::custom(format!(
1042            "task {task_id} not found when incrementing retry count"
1043        ))),
1044    }
1045}
1046
// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
// for conditional deadline_epoch_ms computation or COALESCE on started_at.
//
// The interpolation below is not an injection risk: `status` renders a
// TaskStatus variant (fixed textual values such as 'RUNNING', per the CASE
// comparison) and `task_id` is a Uuid, so neither can contain quoting
// characters.
/// Set a task's status, stamping `started_at` on the first transition and
/// arming `deadline_epoch_ms` when the task enters RUNNING with a timeout.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    // - COALESCE preserves the original started_at once it has been set.
    // - The CASE arms the deadline (now + timeout_ms, in epoch millis) only
    //   when moving to RUNNING on a task that has a timeout configured;
    //   otherwise the existing deadline value is kept unchanged.
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1069
1070/// Check if the task is paused or cancelled. Returns an error if so.
1071async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1072    let status = get_status(db, task_id).await?;
1073    match status {
1074        Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1075        Some(TaskStatus::Cancelled) => {
1076            Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1077        }
1078        _ => Ok(()),
1079    }
1080}
1081
1082/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
1083async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1084    let model = Task::find_by_id(task_id).one(db).await?;
1085
1086    if let Some(m) = model
1087        && let Some(deadline_ms) = m.deadline_epoch_ms
1088    {
1089        let now_ms = std::time::SystemTime::now()
1090            .duration_since(std::time::UNIX_EPOCH)
1091            .map(|d| d.as_millis() as i64)
1092            .unwrap_or(0);
1093        if now_ms > deadline_ms {
1094            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1095        }
1096    }
1097
1098    Ok(())
1099}
1100
1101async fn complete_task(
1102    db: &impl ConnectionTrait,
1103    task_id: Uuid,
1104    output: serde_json::Value,
1105) -> Result<(), DurableError> {
1106    let model = Task::find_by_id(task_id).one(db).await?;
1107
1108    if let Some(m) = model {
1109        let mut active: TaskActiveModel = m.into();
1110        active.status = Set(TaskStatus::Completed);
1111        active.output = Set(Some(output));
1112        active.completed_at = Set(Some(chrono::Utc::now().into()));
1113        active.update(db).await?;
1114    }
1115    Ok(())
1116}
1117
1118async fn fail_task(
1119    db: &impl ConnectionTrait,
1120    task_id: Uuid,
1121    error: &str,
1122) -> Result<(), DurableError> {
1123    let model = Task::find_by_id(task_id).one(db).await?;
1124
1125    if let Some(m) = model {
1126        let mut active: TaskActiveModel = m.into();
1127        active.status = Set(TaskStatus::Failed);
1128        active.error = Set(Some(error.to_string()));
1129        active.completed_at = Set(Some(chrono::Utc::now().into()));
1130        active.update(db).await?;
1131    }
1132    Ok(())
1133}
1134
#[cfg(test)]
mod tests {
    //! Unit tests for `retry_db_write`: first-try success, recovery after
    //! transient failures, retry exhaustion, and first-error reporting.
    //!
    //! NOTE(review): these tests sleep through the real backoff schedule
    //! (100/200/400ms), so the exhaustion tests each take ~700ms of wall
    //! time. If tokio's `test-util` feature is enabled in this crate,
    //! `#[tokio::test(start_paused = true)]` would make them near-instant —
    //! confirm the feature flag before changing.
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// test_retry_db_write_succeeds_first_try: a closure that always succeeds
    /// should be called exactly once and return Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// test_retry_db_write_succeeds_after_transient_failure: a closure that
    /// fails twice then succeeds should return Ok and be called 3 times.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                // fetch_add returns the PRE-increment value, so attempts 0
                // and 1 fail and attempt 2 succeeds.
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// test_retry_db_write_exhausts_retries: a closure that always fails should
    /// be called 1 + MAX_CHECKPOINT_RETRIES times total then return an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        // 1 initial attempt + MAX_CHECKPOINT_RETRIES retry attempts
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    /// test_retry_db_write_returns_original_error: when all retries are
    /// exhausted the FIRST error is returned, not the last retry error.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                // Embed the attempt index in the message so we can tell
                // WHICH attempt's error was propagated.
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        // The message of the first error contains "error-0"
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}