//! `durable/ctx.rs` — the `Ctx` execution context threaded through workflows and steps.
1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9    DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10    Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20// ── Retry constants ───────────────────────────────────────────────────────────
21
22const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25/// Retry a fallible DB write with exponential backoff.
26///
27/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
28/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
29/// error if all retries are exhausted.
30async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32    F: FnMut() -> Fut,
33    Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35    match f().await {
36        Ok(()) => Ok(()),
37        Err(first_err) => {
38            for i in 0..MAX_CHECKPOINT_RETRIES {
39                tokio::time::sleep(Duration::from_millis(
40                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41                ))
42                .await;
43                if f().await.is_ok() {
44                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45                    return Ok(());
46                }
47            }
48            Err(first_err)
49        }
50    }
51}
52
/// Policy for retrying a step on failure.
///
/// All fields are plain value types, so the policy derives `Debug`, `Clone`,
/// `Copy`, and `PartialEq` for ergonomic logging, reuse, and testing.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt.
    pub max_retries: u32,
    /// Backoff before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor applied to the backoff after each retry (e.g. 2.0 doubles it).
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries — fails immediately on error.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: std::time::Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff: backoff doubles each retry.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Fixed backoff: same duration between all retries.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
88
/// Sort order for task listing.
///
/// Each variant names the column to sort by; the wrapped `Order` selects
/// ascending or descending.
pub enum TaskSort {
    /// Sort by creation timestamp.
    CreatedAt(Order),
    /// Sort by execution start timestamp.
    StartedAt(Order),
    /// Sort by completion timestamp.
    CompletedAt(Order),
    /// Sort by task name.
    Name(Order),
    /// Sort by task status.
    Status(Order),
}
97
/// Builder for filtering, sorting, and paginating task queries.
///
/// ```ignore
/// let query = TaskQuery::default()
///     .status(TaskStatus::Running)
///     .kind("WORKFLOW")
///     .root_only(true)
///     .sort(TaskSort::CreatedAt(Order::Desc))
///     .limit(20);
/// let tasks = Ctx::list(&db, query).await?;
/// ```
pub struct TaskQuery {
    /// Only tasks with this status.
    pub status: Option<TaskStatus>,
    /// Only tasks of this kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: Option<String>,
    /// Only direct children of this task.
    pub parent_id: Option<Uuid>,
    /// When true, only root tasks (no parent).
    pub root_only: bool,
    /// Only tasks with this exact name.
    pub name: Option<String>,
    /// Only tasks on this queue.
    pub queue_name: Option<String>,
    /// Sort column and direction (default: newest-first by creation time).
    pub sort: TaskSort,
    /// Maximum number of rows to return.
    pub limit: Option<u64>,
    /// Number of rows to skip before returning results.
    pub offset: Option<u64>,
}
120
121impl Default for TaskQuery {
122    fn default() -> Self {
123        Self {
124            status: None,
125            kind: None,
126            parent_id: None,
127            root_only: false,
128            name: None,
129            queue_name: None,
130            sort: TaskSort::CreatedAt(Order::Desc),
131            limit: None,
132            offset: None,
133        }
134    }
135}
136
137impl TaskQuery {
138    /// Filter by status.
139    pub fn status(mut self, status: TaskStatus) -> Self {
140        self.status = Some(status);
141        self
142    }
143
144    /// Filter by kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
145    pub fn kind(mut self, kind: &str) -> Self {
146        self.kind = Some(kind.to_string());
147        self
148    }
149
150    /// Filter by parent task ID (direct children only).
151    pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152        self.parent_id = Some(parent_id);
153        self
154    }
155
156    /// Only return root tasks (no parent).
157    pub fn root_only(mut self, root_only: bool) -> Self {
158        self.root_only = root_only;
159        self
160    }
161
162    /// Filter by task name.
163    pub fn name(mut self, name: &str) -> Self {
164        self.name = Some(name.to_string());
165        self
166    }
167
168    /// Filter by queue name.
169    pub fn queue_name(mut self, queue: &str) -> Self {
170        self.queue_name = Some(queue.to_string());
171        self
172    }
173
174    /// Set the sort order.
175    pub fn sort(mut self, sort: TaskSort) -> Self {
176        self.sort = sort;
177        self
178    }
179
180    /// Limit the number of results.
181    pub fn limit(mut self, limit: u64) -> Self {
182        self.limit = Some(limit);
183        self
184    }
185
186    /// Skip the first N results.
187    pub fn offset(mut self, offset: u64) -> Self {
188        self.offset = Some(offset);
189        self
190    }
191}
192
/// Summary of a task returned by `Ctx::list()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    /// Task ID.
    pub id: Uuid,
    /// Parent task ID; `None` for root workflows.
    pub parent_id: Option<Uuid>,
    /// Task name (unique per parent via the schema's `(parent_id, name)` constraint).
    pub name: String,
    /// Current lifecycle status.
    pub status: TaskStatus,
    /// Task kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: String,
    /// Input payload supplied at creation, if any.
    pub input: Option<serde_json::Value>,
    /// Saved output checkpoint, if any.
    pub output: Option<serde_json::Value>,
    /// Error message recorded on failure, if any.
    pub error: Option<String>,
    /// Queue the task was scheduled on, if any.
    pub queue_name: Option<String>,
    /// Creation timestamp.
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    /// When execution started, if it has.
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    /// When the task reached a terminal state, if it has.
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
209
210impl From<durable_db::entity::task::Model> for TaskSummary {
211    fn from(m: durable_db::entity::task::Model) -> Self {
212        Self {
213            id: m.id,
214            parent_id: m.parent_id,
215            name: m.name,
216            status: m.status,
217            kind: m.kind,
218            input: m.input,
219            output: m.output,
220            error: m.error,
221            queue_name: m.queue_name,
222            created_at: m.created_at,
223            started_at: m.started_at,
224            completed_at: m.completed_at,
225        }
226    }
227}
228
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Connection used for all task reads/writes; cloned from the caller's.
    db: DatabaseConnection,
    // ID of the task this context is scoped to (a workflow or child workflow).
    task_id: Uuid,
    // Monotonic per-context counter; each step/child consumes the next value.
    sequence: AtomicI32,
}
239
240impl Ctx {
241    // ── Workflow lifecycle (user-facing) ──────────────────────────
242
243    /// Start or resume a root workflow by name.
244    ///
245    /// ```ignore
246    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
247    /// ```
248    pub async fn start(
249        db: &DatabaseConnection,
250        name: &str,
251        input: Option<serde_json::Value>,
252    ) -> Result<Self, DurableError> {
253        let txn = db.begin().await?;
254        // Workflows use plain find-or-create without FOR UPDATE SKIP LOCKED.
255        // Multiple workers resuming the same workflow is fine — they share the
256        // same task row and each step will be individually locked.
257        let (task_id, _saved) =
258            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
259        retry_db_write(|| set_status(&txn, task_id, TaskStatus::Running)).await?;
260        txn.commit().await?;
261        Ok(Self {
262            db: db.clone(),
263            task_id,
264            sequence: AtomicI32::new(0),
265        })
266    }
267
    /// Run a step. If already completed, returns saved output. Otherwise executes
    /// the closure, saves the result, and returns it.
    ///
    /// Exactly-once execution is enforced by `find_or_create_task` with
    /// `lock = true` (`FOR UPDATE SKIP LOCKED` per its docs) inside a
    /// transaction held open for the whole closure run, so a concurrent
    /// worker cannot double-fire the same step's side effects.
    ///
    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
    ///
    /// ```ignore
    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
    /// ```
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Each step call consumes one sequence slot, replayed or not.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Check if parent task's deadline has passed before executing
        check_deadline(&self.db, self.task_id).await?;

        // Begin transaction — the FOR UPDATE lock is held throughout step execution
        let txn = self.db.begin().await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Returns (step_id, Option<saved_output>) where saved_output is Some iff COMPLETED.
        // Use FOR UPDATE SKIP LOCKED so only one worker can execute a given step.
        // max_retries=0: step() does not retry; use step_with_retry() for retries.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // If already completed, replay from saved output (commit first to
        // release the row lock before deserializing).
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Execute the step closure within the transaction (row lock is held)
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Persist and commit the failure record before surfacing the
                // closure's error to the caller.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }
335
336    /// Run a DB-only step inside a single Postgres transaction.
337    ///
338    /// Both the user's DB work and the checkpoint save happen in the same transaction,
339    /// ensuring atomicity. If the closure returns an error, both the user writes and
340    /// the checkpoint are rolled back.
341    ///
342    /// ```ignore
343    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
344    ///     do_db_work(tx).await
345    /// })).await?;
346    /// ```
347    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
348    where
349        T: Serialize + DeserializeOwned + Send,
350        F: for<'tx> FnOnce(
351                &'tx DatabaseTransaction,
352            ) -> Pin<
353                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
354            > + Send,
355    {
356        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
357
358        // Check if workflow is paused or cancelled before executing
359        check_status(&self.db, self.task_id).await?;
360
361        // Find or create the step task record OUTSIDE the transaction.
362        // This is idempotent (UNIQUE constraint) and must exist before we begin.
363        let (step_id, saved_output) = find_or_create_task(
364            &self.db,
365            Some(self.task_id),
366            Some(seq),
367            name,
368            "TRANSACTION",
369            None,
370            false,
371            None,
372        )
373        .await?;
374
375        // If already completed, replay from saved output.
376        if let Some(output) = saved_output {
377            let val: T = serde_json::from_value(output)?;
378            tracing::debug!(step = name, seq, "replaying saved transaction output");
379            return Ok(val);
380        }
381
382        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
383        let tx = self.db.begin().await?;
384
385        set_status(&tx, step_id, TaskStatus::Running).await?;
386
387        match f(&tx).await {
388            Ok(val) => {
389                let json = serde_json::to_value(&val)?;
390                complete_task(&tx, step_id, json).await?;
391                tx.commit().await?;
392                tracing::debug!(step = name, seq, "transaction step committed");
393                Ok(val)
394            }
395            Err(e) => {
396                // Rollback happens automatically when tx is dropped.
397                // Record the failure on the main connection (outside the rolled-back tx).
398                drop(tx);
399                fail_task(&self.db, step_id, &e.to_string()).await?;
400                Err(e)
401            }
402        }
403    }
404
405    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
406    ///
407    /// ```ignore
408    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
409    /// // use child_ctx.step(...) for steps inside the child
410    /// child_ctx.complete(json!({"done": true})).await?;
411    /// ```
412    pub async fn child(
413        &self,
414        name: &str,
415        input: Option<serde_json::Value>,
416    ) -> Result<Self, DurableError> {
417        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
418
419        // Check if workflow is paused or cancelled before executing
420        check_status(&self.db, self.task_id).await?;
421
422        let txn = self.db.begin().await?;
423        // Child workflows also use plain find-or-create without locking.
424        let (child_id, _saved) = find_or_create_task(
425            &txn,
426            Some(self.task_id),
427            Some(seq),
428            name,
429            "WORKFLOW",
430            input,
431            false,
432            None,
433        )
434        .await?;
435
436        // If child already completed, return a Ctx that will replay
437        // (the caller should check is_completed() or just run steps which will replay)
438        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
439        txn.commit().await?;
440
441        Ok(Self {
442            db: self.db.clone(),
443            task_id: child_id,
444            sequence: AtomicI32::new(0),
445        })
446    }
447
448    /// Check if this workflow/child was already completed (for skipping in parent).
449    pub async fn is_completed(&self) -> Result<bool, DurableError> {
450        let status = get_status(&self.db, self.task_id).await?;
451        Ok(status == Some(TaskStatus::Completed))
452    }
453
454    /// Get the saved output if this task is completed.
455    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
456        match get_output(&self.db, self.task_id).await? {
457            Some(val) => Ok(Some(serde_json::from_value(val)?)),
458            None => Ok(None),
459        }
460    }
461
462    /// Mark this workflow as completed with an output value.
463    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
464        let json = serde_json::to_value(output)?;
465        let db = &self.db;
466        let task_id = self.task_id;
467        retry_db_write(|| complete_task(db, task_id, json.clone())).await
468    }
469
    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts. The retry count is persisted per
    /// step, so a restarted process resumes with the remaining retry budget.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts)
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop
        loop {
            // Re-check status before each retry attempt
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the incremented retry count so a process
                        // restart resumes with the correct attempt number.
                        // (The status is left RUNNING here; it is re-set at the
                        // top of the next loop iteration.)
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration: initial * multiplier^(retry-1),
                        // clamping the factor at 1.0 so sub-1 multipliers never
                        // shrink the delay below the initial backoff.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
579
580    /// Mark this workflow as failed.
581    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
582        let db = &self.db;
583        let task_id = self.task_id;
584        retry_db_write(|| fail_task(db, task_id, error)).await
585    }
586
    /// Set the timeout for this task in milliseconds.
    ///
    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
    /// for recovery by `Executor::recover()`.
    ///
    /// If the task is currently RUNNING and has a `started_at`, this also
    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // NOTE(review): values are interpolated into the SQL string directly.
        // This is injection-safe here because `timeout_ms` is an i64 and
        // `self.task_id` is a Uuid (both render as fixed-format text), but a
        // parameterized `Statement::from_sql_and_values` would be more robust.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
                 deadline_epoch_ms = CASE \
                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                     ELSE deadline_epoch_ms \
                 END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
611
612    /// Start or resume a root workflow with a timeout in milliseconds.
613    ///
614    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
615    pub async fn start_with_timeout(
616        db: &DatabaseConnection,
617        name: &str,
618        input: Option<serde_json::Value>,
619        timeout_ms: i64,
620    ) -> Result<Self, DurableError> {
621        let ctx = Self::start(db, name, input).await?;
622        ctx.set_timeout(timeout_ms).await?;
623        Ok(ctx)
624    }
625
626    // ── Workflow control (management API) ─────────────────────────
627
628    /// Pause a workflow by ID. Sets status to PAUSED and recursively
629    /// cascades to all PENDING/RUNNING descendants (children, grandchildren, etc.).
630    ///
631    /// Only workflows in PENDING or RUNNING status can be paused.
632    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
633        let model = Task::find_by_id(task_id).one(db).await?;
634        let model =
635            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
636
637        match model.status {
638            TaskStatus::Pending | TaskStatus::Running => {}
639            status => {
640                return Err(DurableError::custom(format!(
641                    "cannot pause task in {status} status"
642                )));
643            }
644        }
645
646        // Pause the task itself and all descendants in one recursive CTE
647        let sql = format!(
648            "WITH RECURSIVE descendants AS ( \
649                 SELECT id FROM durable.task WHERE id = '{task_id}' \
650                 UNION ALL \
651                 SELECT t.id FROM durable.task t \
652                 INNER JOIN descendants d ON t.parent_id = d.id \
653             ) \
654             UPDATE durable.task SET status = 'PAUSED' \
655             WHERE id IN (SELECT id FROM descendants) \
656               AND status IN ('PENDING', 'RUNNING')"
657        );
658        db.execute(Statement::from_string(DbBackend::Postgres, sql))
659            .await?;
660
661        tracing::info!(%task_id, "workflow paused");
662        Ok(())
663    }
664
665    /// Resume a paused workflow by ID. Sets status back to RUNNING and
666    /// recursively cascades to all PAUSED descendants (resetting them to PENDING).
667    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
668        let model = Task::find_by_id(task_id).one(db).await?;
669        let model =
670            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
671
672        if model.status != TaskStatus::Paused {
673            return Err(DurableError::custom(format!(
674                "cannot resume task in {} status (must be PAUSED)",
675                model.status
676            )));
677        }
678
679        // Resume the root task back to RUNNING
680        let mut active: TaskActiveModel = model.into();
681        active.status = Set(TaskStatus::Running);
682        active.update(db).await?;
683
684        // Recursively resume all PAUSED descendants back to PENDING
685        let cascade_sql = format!(
686            "WITH RECURSIVE descendants AS ( \
687                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
688                 UNION ALL \
689                 SELECT t.id FROM durable.task t \
690                 INNER JOIN descendants d ON t.parent_id = d.id \
691             ) \
692             UPDATE durable.task SET status = 'PENDING' \
693             WHERE id IN (SELECT id FROM descendants) \
694               AND status = 'PAUSED'"
695        );
696        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
697            .await?;
698
699        tracing::info!(%task_id, "workflow resumed");
700        Ok(())
701    }
702
703    /// Cancel a workflow by ID. Sets status to CANCELLED and recursively
704    /// cascades to all non-terminal descendants.
705    ///
706    /// Cancellation is terminal — a cancelled workflow cannot be resumed.
707    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
708        let model = Task::find_by_id(task_id).one(db).await?;
709        let model =
710            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
711
712        match model.status {
713            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
714                return Err(DurableError::custom(format!(
715                    "cannot cancel task in {} status",
716                    model.status
717                )));
718            }
719            _ => {}
720        }
721
722        // Cancel the task itself and all non-terminal descendants in one recursive CTE
723        let sql = format!(
724            "WITH RECURSIVE descendants AS ( \
725                 SELECT id FROM durable.task WHERE id = '{task_id}' \
726                 UNION ALL \
727                 SELECT t.id FROM durable.task t \
728                 INNER JOIN descendants d ON t.parent_id = d.id \
729             ) \
730             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
731             WHERE id IN (SELECT id FROM descendants) \
732               AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
733        );
734        db.execute(Statement::from_string(DbBackend::Postgres, sql))
735            .await?;
736
737        tracing::info!(%task_id, "workflow cancelled");
738        Ok(())
739    }
740
741    // ── Query API ─────────────────────────────────────────────────
742
743    /// List tasks matching the given filter, with sorting and pagination.
744    ///
745    /// ```ignore
746    /// let tasks = Ctx::list(&db, TaskQuery::default().status("RUNNING").limit(10)).await?;
747    /// ```
748    pub async fn list(
749        db: &DatabaseConnection,
750        query: TaskQuery,
751    ) -> Result<Vec<TaskSummary>, DurableError> {
752        let mut select = Task::find();
753
754        // Filters
755        if let Some(status) = &query.status {
756            select = select.filter(TaskColumn::Status.eq(status.to_string()));
757        }
758        if let Some(kind) = &query.kind {
759            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
760        }
761        if let Some(parent_id) = query.parent_id {
762            select = select.filter(TaskColumn::ParentId.eq(parent_id));
763        }
764        if query.root_only {
765            select = select.filter(TaskColumn::ParentId.is_null());
766        }
767        if let Some(name) = &query.name {
768            select = select.filter(TaskColumn::Name.eq(name.as_str()));
769        }
770        if let Some(queue) = &query.queue_name {
771            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
772        }
773
774        // Sorting
775        let (col, order) = match query.sort {
776            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
777            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
778            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
779            TaskSort::Name(ord) => (TaskColumn::Name, ord),
780            TaskSort::Status(ord) => (TaskColumn::Status, ord),
781        };
782        select = select.order_by(col, order);
783
784        // Pagination
785        if let Some(offset) = query.offset {
786            select = select.offset(offset);
787        }
788        if let Some(limit) = query.limit {
789            select = select.limit(limit);
790        }
791
792        let models = select.all(db).await?;
793
794        Ok(models.into_iter().map(TaskSummary::from).collect())
795    }
796
797    /// Count tasks matching the given filter.
798    pub async fn count(
799        db: &DatabaseConnection,
800        query: TaskQuery,
801    ) -> Result<u64, DurableError> {
802        let mut select = Task::find();
803
804        if let Some(status) = &query.status {
805            select = select.filter(TaskColumn::Status.eq(status.to_string()));
806        }
807        if let Some(kind) = &query.kind {
808            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
809        }
810        if let Some(parent_id) = query.parent_id {
811            select = select.filter(TaskColumn::ParentId.eq(parent_id));
812        }
813        if query.root_only {
814            select = select.filter(TaskColumn::ParentId.is_null());
815        }
816        if let Some(name) = &query.name {
817            select = select.filter(TaskColumn::Name.eq(name.as_str()));
818        }
819        if let Some(queue) = &query.queue_name {
820            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
821        }
822
823        let count = select.count(db).await?;
824        Ok(count)
825    }
826
827    // ── Accessors ────────────────────────────────────────────────
828
    /// Borrow the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// The ID of the task this context is scoped to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Reserve and return the next sequence number (increments the counter).
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
840}
841
842// ── Internal SQL helpers ─────────────────────────────────────────────
843
844/// Find an existing task by (parent_id, name) or create a new one.
845///
846/// Returns `(task_id, Option<saved_output>)`:
847/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
848/// - `saved_output` is `None` when the task is new or in-progress.
849///
850/// When `lock` is `true` and an existing non-completed task is found, this
851/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
852/// another worker holds the lock, `DurableError::StepLocked` is returned so
853/// the caller can skip execution rather than double-firing side effects.
854///
855/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
856/// and child-workflow creation where concurrent start is safe).
857///
858/// When `lock` is `true`, the caller MUST call this within a transaction so
859/// the row lock is held throughout step execution.
860///
861/// `max_retries`: if Some, overrides the schema default when creating a new task.
862#[allow(clippy::too_many_arguments)]
863async fn find_or_create_task(
864    db: &impl ConnectionTrait,
865    parent_id: Option<Uuid>,
866    sequence: Option<i32>,
867    name: &str,
868    kind: &str,
869    input: Option<serde_json::Value>,
870    lock: bool,
871    max_retries: Option<u32>,
872) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
873    let parent_eq = match parent_id {
874        Some(p) => format!("= '{p}'"),
875        None => "IS NULL".to_string(),
876    };
877    let parent_sql = match parent_id {
878        Some(p) => format!("'{p}'"),
879        None => "NULL".to_string(),
880    };
881
882    if lock {
883        // Locking path (for steps): we need exactly-once execution.
884        //
885        // Strategy:
886        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
887        //    If another transaction is concurrently inserting the same row,
888        //    Postgres will block here until that transaction commits or rolls
889        //    back, ensuring we never see a phantom "not found" for a row being
890        //    inserted.
891        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
892        //    should get it back; if the row existed and is locked by another
893        //    worker we get nothing.
894        // 3. If SKIP LOCKED returns empty, return StepLocked.
895
896        let new_id = Uuid::new_v4();
897        let seq_sql = match sequence {
898            Some(s) => s.to_string(),
899            None => "NULL".to_string(),
900        };
901        let input_sql = match &input {
902            Some(v) => format!("'{}'", serde_json::to_string(v)?),
903            None => "NULL".to_string(),
904        };
905
906        let max_retries_sql = match max_retries {
907            Some(r) => r.to_string(),
908            None => "3".to_string(), // schema default
909        };
910
911        // Step 1: insert-or-skip
912        let insert_sql = format!(
913            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
914             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
915             ON CONFLICT (parent_id, sequence) DO NOTHING"
916        );
917        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
918            .await?;
919
920        // Step 2: lock the row (ours or pre-existing)
921        let lock_sql = format!(
922            "SELECT id, status::text, output FROM durable.task \
923             WHERE parent_id {parent_eq} AND name = '{name}' \
924             FOR UPDATE SKIP LOCKED"
925        );
926        let row = db
927            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
928            .await?;
929
930        if let Some(row) = row {
931            let id: Uuid = row
932                .try_get_by_index(0)
933                .map_err(|e| DurableError::custom(e.to_string()))?;
934            let status: String = row
935                .try_get_by_index(1)
936                .map_err(|e| DurableError::custom(e.to_string()))?;
937            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
938
939            if status == TaskStatus::Completed.to_string() {
940                // Replay path — return saved output
941                return Ok((id, output));
942            }
943            // Task exists and we hold the lock — proceed to execute
944            return Ok((id, None));
945        }
946
947        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
948        Err(DurableError::StepLocked(name.to_string()))
949    } else {
950        // Plain find without locking — safe for workflow-level operations.
951        // Multiple workers resuming the same workflow is fine; individual
952        // steps will be locked when executed.
953        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
954        query = match parent_id {
955            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
956            None => query.filter(TaskColumn::ParentId.is_null()),
957        };
958        // Skip cancelled/failed tasks — a fresh row will be created instead.
959        // Completed tasks are still returned so steps can replay saved output.
960        let existing = query
961            .filter(TaskColumn::Status.is_not_in([
962                TaskStatus::Cancelled,
963                TaskStatus::Failed,
964            ]))
965            .one(db)
966            .await?;
967
968        if let Some(model) = existing {
969            if model.status == TaskStatus::Completed {
970                return Ok((model.id, model.output));
971            }
972            return Ok((model.id, None));
973        }
974
975        // Task does not exist — create it
976        let id = Uuid::new_v4();
977        let new_task = TaskActiveModel {
978            id: Set(id),
979            parent_id: Set(parent_id),
980            sequence: Set(sequence),
981            name: Set(name.to_string()),
982            kind: Set(kind.to_string()),
983            status: Set(TaskStatus::Pending),
984            input: Set(input),
985            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
986            ..Default::default()
987        };
988        new_task.insert(db).await?;
989
990        Ok((id, None))
991    }
992}
993
994async fn get_output(
995    db: &impl ConnectionTrait,
996    task_id: Uuid,
997) -> Result<Option<serde_json::Value>, DurableError> {
998    let model = Task::find_by_id(task_id)
999        .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1000        .one(db)
1001        .await?;
1002
1003    Ok(model.and_then(|m| m.output))
1004}
1005
1006async fn get_status(
1007    db: &impl ConnectionTrait,
1008    task_id: Uuid,
1009) -> Result<Option<TaskStatus>, DurableError> {
1010    let model = Task::find_by_id(task_id).one(db).await?;
1011
1012    Ok(model.map(|m| m.status))
1013}
1014
1015/// Returns (retry_count, max_retries) for a task.
1016async fn get_retry_info(
1017    db: &DatabaseConnection,
1018    task_id: Uuid,
1019) -> Result<(u32, u32), DurableError> {
1020    let model = Task::find_by_id(task_id).one(db).await?;
1021
1022    match model {
1023        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1024        None => Err(DurableError::custom(format!(
1025            "task {task_id} not found when reading retry info"
1026        ))),
1027    }
1028}
1029
1030/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
1031async fn increment_retry_count(
1032    db: &DatabaseConnection,
1033    task_id: Uuid,
1034) -> Result<u32, DurableError> {
1035    let model = Task::find_by_id(task_id).one(db).await?;
1036
1037    match model {
1038        Some(m) => {
1039            let new_count = m.retry_count + 1;
1040            let mut active: TaskActiveModel = m.into();
1041            active.retry_count = Set(new_count);
1042            active.status = Set(TaskStatus::Pending);
1043            active.error = Set(None);
1044            active.completed_at = Set(None);
1045            active.update(db).await?;
1046            Ok(new_count as u32)
1047        }
1048        None => Err(DurableError::custom(format!(
1049            "task {task_id} not found when incrementing retry count"
1050        ))),
1051    }
1052}
1053
// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
// for conditional deadline_epoch_ms computation or COALESCE on started_at.
/// Transition `task_id` to `status`, maintaining timing bookkeeping in the
/// same UPDATE:
/// - `started_at`: COALESCE preserves the first-ever start timestamp, so
///   later transitions never overwrite it.
/// - `deadline_epoch_ms`: armed (as absolute wall-clock epoch millis) only
///   when entering RUNNING on a task that has a `timeout_ms`; every other
///   transition leaves any existing deadline untouched.
///
/// Interpolating `{task_id}` is safe (a `Uuid` renders as hex digits and
/// dashes). `{status}` is the enum's Display output — assumed to be the bare
/// DB label (e.g. RUNNING) with no quotable characters; NOTE(review) confirm
/// against the generated `sea_orm_active_enums` Display impl.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
             started_at = COALESCE(started_at, now()), \
             deadline_epoch_ms = CASE \
                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
                 ELSE deadline_epoch_ms \
             END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1076
1077/// Check if the task is paused or cancelled. Returns an error if so.
1078async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1079    let status = get_status(db, task_id).await?;
1080    match status {
1081        Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1082        Some(TaskStatus::Cancelled) => {
1083            Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1084        }
1085        _ => Ok(()),
1086    }
1087}
1088
1089/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
1090async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1091    let model = Task::find_by_id(task_id).one(db).await?;
1092
1093    if let Some(m) = model
1094        && let Some(deadline_ms) = m.deadline_epoch_ms
1095    {
1096        let now_ms = std::time::SystemTime::now()
1097            .duration_since(std::time::UNIX_EPOCH)
1098            .map(|d| d.as_millis() as i64)
1099            .unwrap_or(0);
1100        if now_ms > deadline_ms {
1101            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1102        }
1103    }
1104
1105    Ok(())
1106}
1107
1108async fn complete_task(
1109    db: &impl ConnectionTrait,
1110    task_id: Uuid,
1111    output: serde_json::Value,
1112) -> Result<(), DurableError> {
1113    let model = Task::find_by_id(task_id).one(db).await?;
1114
1115    if let Some(m) = model {
1116        let mut active: TaskActiveModel = m.into();
1117        active.status = Set(TaskStatus::Completed);
1118        active.output = Set(Some(output));
1119        active.completed_at = Set(Some(chrono::Utc::now().into()));
1120        active.update(db).await?;
1121    }
1122    Ok(())
1123}
1124
1125async fn fail_task(
1126    db: &impl ConnectionTrait,
1127    task_id: Uuid,
1128    error: &str,
1129) -> Result<(), DurableError> {
1130    let model = Task::find_by_id(task_id).one(db).await?;
1131
1132    if let Some(m) = model {
1133        let mut active: TaskActiveModel = m.into();
1134        active.status = Set(TaskStatus::Failed);
1135        active.error = Set(Some(error.to_string()));
1136        active.completed_at = Set(Some(chrono::Utc::now().into()));
1137        active.update(db).await?;
1138    }
1139    Ok(())
1140}
1141
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Fresh shared counter so each test can observe how many times
    /// `retry_db_write` invoked its closure.
    fn counter() -> Arc<AtomicU32> {
        Arc::new(AtomicU32::new(0))
    }

    /// test_retry_db_write_succeeds_first_try: a closure that always succeeds
    /// should be called exactly once and return Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let calls = counter();
        let result = retry_db_write(|| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// test_retry_db_write_succeeds_after_transient_failure: a closure that
    /// fails twice then succeeds should return Ok and be called 3 times.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let calls = counter();
        let result = retry_db_write(|| {
            let calls = Arc::clone(&calls);
            async move {
                // First two attempts (0 and 1) fail; everything after succeeds.
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    ))),
                    _ => Ok(()),
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// test_retry_db_write_exhausts_retries: a closure that always fails should
    /// be called 1 + MAX_CHECKPOINT_RETRIES times total then return an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let calls = counter();
        let outcome = retry_db_write(|| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(outcome.is_err());
        // One initial attempt plus MAX_CHECKPOINT_RETRIES retries.
        assert_eq!(calls.load(Ordering::SeqCst), MAX_CHECKPOINT_RETRIES + 1);
    }

    /// test_retry_db_write_returns_original_error: when all retries are
    /// exhausted the FIRST error is returned, not the last retry error.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let calls = counter();
        let err = retry_db_write(|| {
            let calls = Arc::clone(&calls);
            async move {
                // Embed the attempt index so the surfaced error identifies
                // which attempt produced it.
                let attempt = calls.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    attempt
                ))))
            }
        })
        .await
        .unwrap_err();
        // The message of the first error contains "error-0"
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}