// durable/ctx.rs
1use std::pin::Pin;
2
3use durable_db::entity::task::{
4    ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
5};
6use sea_orm::{
7    ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
8    DbBackend, EntityTrait, QueryFilter, Set, Statement, TransactionTrait,
9};
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use std::sync::atomic::{AtomicI32, Ordering};
13use std::time::Duration;
14use uuid::Uuid;
15
16use crate::error::DurableError;
17
18// ── Retry constants ───────────────────────────────────────────────────────────
19
20const MAX_CHECKPOINT_RETRIES: u32 = 3;
21const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
22
23/// Retry a fallible DB write with exponential backoff.
24///
25/// Calls `f` once immediately. On failure, retries up to `MAX_CHECKPOINT_RETRIES`
26/// times with exponential backoff (100ms, 200ms, 400ms). Returns the FIRST
27/// error if all retries are exhausted.
28async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
29where
30    F: FnMut() -> Fut,
31    Fut: std::future::Future<Output = Result<(), DurableError>>,
32{
33    match f().await {
34        Ok(()) => Ok(()),
35        Err(first_err) => {
36            for i in 0..MAX_CHECKPOINT_RETRIES {
37                tokio::time::sleep(Duration::from_millis(
38                    CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
39                ))
40                .await;
41                if f().await.is_ok() {
42                    tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
43                    return Ok(());
44                }
45            }
46            Err(first_err)
47        }
48    }
49}
50
/// Policy for retrying a step on failure.
pub struct RetryPolicy {
    /// Maximum number of retry attempts after the initial failure (0 = no retries).
    pub max_retries: u32,
    /// Delay before the first retry; later delays are scaled by `backoff_multiplier`.
    pub initial_backoff: std::time::Duration,
    /// Multiplier applied to the backoff after each retry (1.0 = fixed backoff).
    pub backoff_multiplier: f64,
}
57
58impl RetryPolicy {
59    /// No retries — fails immediately on error.
60    pub fn none() -> Self {
61        Self {
62            max_retries: 0,
63            initial_backoff: std::time::Duration::from_secs(0),
64            backoff_multiplier: 1.0,
65        }
66    }
67
68    /// Exponential backoff: backoff doubles each retry.
69    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
70        Self {
71            max_retries,
72            initial_backoff,
73            backoff_multiplier: 2.0,
74        }
75    }
76
77    /// Fixed backoff: same duration between all retries.
78    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
79        Self {
80            max_retries,
81            initial_backoff: backoff,
82            backoff_multiplier: 1.0,
83        }
84    }
85}
86
/// Context threaded through every workflow and step.
///
/// Users never create or manage task IDs. The SDK handles everything
/// via `(parent_id, name)` lookups — the unique constraint in the schema
/// guarantees exactly-once step creation.
pub struct Ctx {
    // Connection handle; a clone of the caller's connection (see `start`).
    db: DatabaseConnection,
    // The durable.task row this context is scoped to (workflow or child workflow).
    task_id: Uuid,
    // Monotonic per-context counter assigning a sequence number to each step.
    sequence: AtomicI32,
}
97
98impl Ctx {
99    // ── Workflow lifecycle (user-facing) ──────────────────────────
100
101    /// Start or resume a root workflow by name.
102    ///
103    /// ```ignore
104    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
105    /// ```
106    pub async fn start(
107        db: &DatabaseConnection,
108        name: &str,
109        input: Option<serde_json::Value>,
110    ) -> Result<Self, DurableError> {
111        let txn = db.begin().await?;
112        // Workflows use plain find-or-create without FOR UPDATE SKIP LOCKED.
113        // Multiple workers resuming the same workflow is fine — they share the
114        // same task row and each step will be individually locked.
115        let (task_id, _saved) =
116            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
117        retry_db_write(|| set_status(&txn, task_id, "RUNNING")).await?;
118        txn.commit().await?;
119        Ok(Self {
120            db: db.clone(),
121            task_id,
122            sequence: AtomicI32::new(0),
123        })
124    }
125
126    /// Run a step. If already completed, returns saved output. Otherwise executes
127    /// the closure, saves the result, and returns it.
128    ///
129    /// This method uses no retries (max_retries=0). For retries, use `step_with_retry`.
130    ///
131    /// ```ignore
132    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
133    /// ```
134    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
135    where
136        T: Serialize + DeserializeOwned,
137        F: FnOnce() -> Fut,
138        Fut: std::future::Future<Output = Result<T, DurableError>>,
139    {
140        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
141
142        // Check if workflow is paused or cancelled before executing
143        check_status(&self.db, self.task_id).await?;
144
145        // Check if parent task's deadline has passed before executing
146        check_deadline(&self.db, self.task_id).await?;
147
148        // Begin transaction — the FOR UPDATE lock is held throughout step execution
149        let txn = self.db.begin().await?;
150
151        // Find or create — idempotent via UNIQUE(parent_id, name)
152        // Returns (step_id, Option<saved_output>) where saved_output is Some iff COMPLETED.
153        // Use FOR UPDATE SKIP LOCKED so only one worker can execute a given step.
154        // max_retries=0: step() does not retry; use step_with_retry() for retries.
155        let (step_id, saved_output) = find_or_create_task(
156            &txn,
157            Some(self.task_id),
158            Some(seq),
159            name,
160            "STEP",
161            None,
162            true,
163            Some(0),
164        )
165        .await?;
166
167        // If already completed, replay from saved output
168        if let Some(output) = saved_output {
169            txn.commit().await?;
170            let val: T = serde_json::from_value(output)?;
171            tracing::debug!(step = name, seq, "replaying saved output");
172            return Ok(val);
173        }
174
175        // Execute the step closure within the transaction (row lock is held)
176        retry_db_write(|| set_status(&txn, step_id, "RUNNING")).await?;
177        match f().await {
178            Ok(val) => {
179                let json = serde_json::to_value(&val)?;
180                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
181                txn.commit().await?;
182                tracing::debug!(step = name, seq, "step completed");
183                Ok(val)
184            }
185            Err(e) => {
186                let err_msg = e.to_string();
187                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
188                txn.commit().await?;
189                Err(e)
190            }
191        }
192    }
193
194    /// Run a DB-only step inside a single Postgres transaction.
195    ///
196    /// Both the user's DB work and the checkpoint save happen in the same transaction,
197    /// ensuring atomicity. If the closure returns an error, both the user writes and
198    /// the checkpoint are rolled back.
199    ///
200    /// ```ignore
201    /// let count: u32 = ctx.transaction("upsert_batch", |tx| Box::pin(async move {
202    ///     do_db_work(tx).await
203    /// })).await?;
204    /// ```
205    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
206    where
207        T: Serialize + DeserializeOwned + Send,
208        F: for<'tx> FnOnce(
209                &'tx DatabaseTransaction,
210            ) -> Pin<
211                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
212            > + Send,
213    {
214        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
215
216        // Check if workflow is paused or cancelled before executing
217        check_status(&self.db, self.task_id).await?;
218
219        // Find or create the step task record OUTSIDE the transaction.
220        // This is idempotent (UNIQUE constraint) and must exist before we begin.
221        let (step_id, saved_output) = find_or_create_task(
222            &self.db,
223            Some(self.task_id),
224            Some(seq),
225            name,
226            "TRANSACTION",
227            None,
228            false,
229            None,
230        )
231        .await?;
232
233        // If already completed, replay from saved output.
234        if let Some(output) = saved_output {
235            let val: T = serde_json::from_value(output)?;
236            tracing::debug!(step = name, seq, "replaying saved transaction output");
237            return Ok(val);
238        }
239
240        // Begin the transaction — set_status, user work, and complete_task all happen atomically.
241        let tx = self.db.begin().await?;
242
243        set_status(&tx, step_id, "RUNNING").await?;
244
245        match f(&tx).await {
246            Ok(val) => {
247                let json = serde_json::to_value(&val)?;
248                complete_task(&tx, step_id, json).await?;
249                tx.commit().await?;
250                tracing::debug!(step = name, seq, "transaction step committed");
251                Ok(val)
252            }
253            Err(e) => {
254                // Rollback happens automatically when tx is dropped.
255                // Record the failure on the main connection (outside the rolled-back tx).
256                drop(tx);
257                fail_task(&self.db, step_id, &e.to_string()).await?;
258                Err(e)
259            }
260        }
261    }
262
263    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
264    ///
265    /// ```ignore
266    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
267    /// // use child_ctx.step(...) for steps inside the child
268    /// child_ctx.complete(json!({"done": true})).await?;
269    /// ```
270    pub async fn child(
271        &self,
272        name: &str,
273        input: Option<serde_json::Value>,
274    ) -> Result<Self, DurableError> {
275        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
276
277        // Check if workflow is paused or cancelled before executing
278        check_status(&self.db, self.task_id).await?;
279
280        let txn = self.db.begin().await?;
281        // Child workflows also use plain find-or-create without locking.
282        let (child_id, _saved) = find_or_create_task(
283            &txn,
284            Some(self.task_id),
285            Some(seq),
286            name,
287            "WORKFLOW",
288            input,
289            false,
290            None,
291        )
292        .await?;
293
294        // If child already completed, return a Ctx that will replay
295        // (the caller should check is_completed() or just run steps which will replay)
296        retry_db_write(|| set_status(&txn, child_id, "RUNNING")).await?;
297        txn.commit().await?;
298
299        Ok(Self {
300            db: self.db.clone(),
301            task_id: child_id,
302            sequence: AtomicI32::new(0),
303        })
304    }
305
306    /// Check if this workflow/child was already completed (for skipping in parent).
307    pub async fn is_completed(&self) -> Result<bool, DurableError> {
308        let status = get_status(&self.db, self.task_id).await?;
309        Ok(status.as_deref() == Some("COMPLETED"))
310    }
311
312    /// Get the saved output if this task is completed.
313    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
314        match get_output(&self.db, self.task_id).await? {
315            Some(val) => Ok(Some(serde_json::from_value(val)?)),
316            None => Ok(None),
317        }
318    }
319
320    /// Mark this workflow as completed with an output value.
321    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
322        let json = serde_json::to_value(output)?;
323        let db = &self.db;
324        let task_id = self.task_id;
325        retry_db_write(|| complete_task(db, task_id, json.clone())).await
326    }
327
    /// Run a step with a configurable retry policy.
    ///
    /// Unlike `step()`, the closure must implement `Fn` (not `FnOnce`) since
    /// it may be called multiple times on retry. Retries happen in-process with
    /// configurable backoff between attempts.
    ///
    /// ```ignore
    /// let result: u32 = ctx
    ///     .step_with_retry("call_api", RetryPolicy::exponential(3, Duration::from_secs(1)), || async {
    ///         api.call().await
    ///     })
    ///     .await?;
    /// ```
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Check if workflow is paused or cancelled before executing
        check_status(&self.db, self.task_id).await?;

        // Find or create — idempotent via UNIQUE(parent_id, name)
        // Set max_retries from policy when creating.
        // No locking here — retry logic handles re-execution in-process.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // If already completed, replay from saved output
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Get current retry state from DB (for resume across restarts)
        // NOTE: max_retries is read back from the DB row, so for a pre-existing
        // step the stored value wins over `policy.max_retries` (which is only
        // applied at creation time above).
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        // Retry loop
        loop {
            // Re-check status before each retry attempt
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, "RUNNING").await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Increment retry count and reset to PENDING
                        // (persisted, so the count survives a process restart)
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff duration:
                        // initial_backoff * multiplier^(retry_count - 1),
                        // with the factor clamped to >= 1.0 so a sub-1.0
                        // multiplier can never shrink below the initial backoff.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Exhausted retries — mark FAILED
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
437
438    /// Mark this workflow as failed.
439    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
440        let db = &self.db;
441        let task_id = self.task_id;
442        retry_db_write(|| fail_task(db, task_id, error)).await
443    }
444
445    /// Set the timeout for this task in milliseconds.
446    ///
447    /// If a task stays RUNNING longer than timeout_ms, it will be eligible
448    /// for recovery by `Executor::recover()`.
449    ///
450    /// If the task is currently RUNNING and has a `started_at`, this also
451    /// computes and stores `deadline_epoch_ms = started_at_epoch_ms + timeout_ms`.
452    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
453        let sql = format!(
454            "UPDATE durable.task \
455             SET timeout_ms = {timeout_ms}, \
456                 deadline_epoch_ms = CASE \
457                     WHEN status = 'RUNNING' AND started_at IS NOT NULL \
458                     THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
459                     ELSE deadline_epoch_ms \
460                 END \
461             WHERE id = '{}'",
462            self.task_id
463        );
464        self.db
465            .execute(Statement::from_string(DbBackend::Postgres, sql))
466            .await?;
467        Ok(())
468    }
469
    /// Start or resume a root workflow with a timeout in milliseconds.
    ///
    /// Equivalent to calling `Ctx::start()` followed by `ctx.set_timeout(timeout_ms)`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        // Two separate writes: the task row is created/resumed first, then the
        // timeout is applied (which also computes deadline_epoch_ms if RUNNING).
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
483
484    // ── Workflow control (management API) ─────────────────────────
485
486    /// Pause a workflow by ID. Sets status to PAUSED and recursively
487    /// cascades to all PENDING/RUNNING descendants (children, grandchildren, etc.).
488    ///
489    /// Only workflows in PENDING or RUNNING status can be paused.
490    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
491        let model = Task::find_by_id(task_id).one(db).await?;
492        let model =
493            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
494
495        match model.status.as_str() {
496            "PENDING" | "RUNNING" => {}
497            status => {
498                return Err(DurableError::custom(format!(
499                    "cannot pause task in {status} status"
500                )));
501            }
502        }
503
504        // Pause the task itself and all descendants in one recursive CTE
505        let sql = format!(
506            "WITH RECURSIVE descendants AS ( \
507                 SELECT id FROM durable.task WHERE id = '{task_id}' \
508                 UNION ALL \
509                 SELECT t.id FROM durable.task t \
510                 INNER JOIN descendants d ON t.parent_id = d.id \
511             ) \
512             UPDATE durable.task SET status = 'PAUSED' \
513             WHERE id IN (SELECT id FROM descendants) \
514               AND status IN ('PENDING', 'RUNNING')"
515        );
516        db.execute(Statement::from_string(DbBackend::Postgres, sql))
517            .await?;
518
519        tracing::info!(%task_id, "workflow paused");
520        Ok(())
521    }
522
523    /// Resume a paused workflow by ID. Sets status back to RUNNING and
524    /// recursively cascades to all PAUSED descendants (resetting them to PENDING).
525    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
526        let model = Task::find_by_id(task_id).one(db).await?;
527        let model =
528            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
529
530        if model.status != "PAUSED" {
531            return Err(DurableError::custom(format!(
532                "cannot resume task in {} status (must be PAUSED)",
533                model.status
534            )));
535        }
536
537        // Resume the root task back to RUNNING
538        let sql = format!(
539            "UPDATE durable.task SET status = 'RUNNING' WHERE id = '{task_id}'"
540        );
541        db.execute(Statement::from_string(DbBackend::Postgres, sql))
542            .await?;
543
544        // Recursively resume all PAUSED descendants back to PENDING
545        let cascade_sql = format!(
546            "WITH RECURSIVE descendants AS ( \
547                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
548                 UNION ALL \
549                 SELECT t.id FROM durable.task t \
550                 INNER JOIN descendants d ON t.parent_id = d.id \
551             ) \
552             UPDATE durable.task SET status = 'PENDING' \
553             WHERE id IN (SELECT id FROM descendants) \
554               AND status = 'PAUSED'"
555        );
556        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
557            .await?;
558
559        tracing::info!(%task_id, "workflow resumed");
560        Ok(())
561    }
562
563    /// Cancel a workflow by ID. Sets status to CANCELLED and recursively
564    /// cascades to all non-terminal descendants.
565    ///
566    /// Cancellation is terminal — a cancelled workflow cannot be resumed.
567    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
568        let model = Task::find_by_id(task_id).one(db).await?;
569        let model =
570            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
571
572        match model.status.as_str() {
573            "COMPLETED" | "FAILED" | "CANCELLED" => {
574                return Err(DurableError::custom(format!(
575                    "cannot cancel task in {} status",
576                    model.status
577                )));
578            }
579            _ => {}
580        }
581
582        // Cancel the task itself and all non-terminal descendants in one recursive CTE
583        let sql = format!(
584            "WITH RECURSIVE descendants AS ( \
585                 SELECT id FROM durable.task WHERE id = '{task_id}' \
586                 UNION ALL \
587                 SELECT t.id FROM durable.task t \
588                 INNER JOIN descendants d ON t.parent_id = d.id \
589             ) \
590             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
591             WHERE id IN (SELECT id FROM descendants) \
592               AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
593        );
594        db.execute(Statement::from_string(DbBackend::Postgres, sql))
595            .await?;
596
597        tracing::info!(%task_id, "workflow cancelled");
598        Ok(())
599    }
600
601    // ── Accessors ────────────────────────────────────────────────
602
    /// Borrow the database connection this context was created with.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }
606
    /// The id of the task (workflow or child workflow) this context is scoped to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }
610
    /// Reserve and return the next step sequence number (monotonic per context).
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
614}
615
616// ── Internal SQL helpers ─────────────────────────────────────────────
617
618/// Find an existing task by (parent_id, name) or create a new one.
619///
620/// Returns `(task_id, Option<saved_output>)`:
621/// - `saved_output` is `Some(json)` when the task is COMPLETED (replay path).
622/// - `saved_output` is `None` when the task is new or in-progress.
623///
624/// When `lock` is `true` and an existing non-completed task is found, this
625/// function attempts to acquire a `FOR UPDATE SKIP LOCKED` row lock. If
626/// another worker holds the lock, `DurableError::StepLocked` is returned so
627/// the caller can skip execution rather than double-firing side effects.
628///
629/// When `lock` is `false`, a plain SELECT is used (appropriate for workflow
630/// and child-workflow creation where concurrent start is safe).
631///
632/// When `lock` is `true`, the caller MUST call this within a transaction so
633/// the row lock is held throughout step execution.
634///
635/// `max_retries`: if Some, overrides the schema default when creating a new task.
636#[allow(clippy::too_many_arguments)]
637async fn find_or_create_task(
638    db: &impl ConnectionTrait,
639    parent_id: Option<Uuid>,
640    sequence: Option<i32>,
641    name: &str,
642    kind: &str,
643    input: Option<serde_json::Value>,
644    lock: bool,
645    max_retries: Option<u32>,
646) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
647    let parent_eq = match parent_id {
648        Some(p) => format!("= '{p}'"),
649        None => "IS NULL".to_string(),
650    };
651    let parent_sql = match parent_id {
652        Some(p) => format!("'{p}'"),
653        None => "NULL".to_string(),
654    };
655
656    if lock {
657        // Locking path (for steps): we need exactly-once execution.
658        //
659        // Strategy:
660        // 1. INSERT the row with ON CONFLICT DO NOTHING — idempotent creation.
661        //    If another transaction is concurrently inserting the same row,
662        //    Postgres will block here until that transaction commits or rolls
663        //    back, ensuring we never see a phantom "not found" for a row being
664        //    inserted.
665        // 2. Attempt FOR UPDATE SKIP LOCKED — if we just inserted the row we
666        //    should get it back; if the row existed and is locked by another
667        //    worker we get nothing.
668        // 3. If SKIP LOCKED returns empty, return StepLocked.
669
670        let new_id = Uuid::new_v4();
671        let seq_sql = match sequence {
672            Some(s) => s.to_string(),
673            None => "NULL".to_string(),
674        };
675        let input_sql = match &input {
676            Some(v) => format!("'{}'", serde_json::to_string(v)?),
677            None => "NULL".to_string(),
678        };
679
680        let max_retries_sql = match max_retries {
681            Some(r) => r.to_string(),
682            None => "3".to_string(), // schema default
683        };
684
685        // Step 1: insert-or-skip
686        let insert_sql = format!(
687            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
688             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
689             ON CONFLICT (parent_id, name) DO NOTHING"
690        );
691        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
692            .await?;
693
694        // Step 2: lock the row (ours or pre-existing)
695        let lock_sql = format!(
696            "SELECT id, status, output FROM durable.task \
697             WHERE parent_id {parent_eq} AND name = '{name}' \
698             FOR UPDATE SKIP LOCKED"
699        );
700        let row = db
701            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
702            .await?;
703
704        if let Some(row) = row {
705            let id: Uuid = row
706                .try_get_by_index(0)
707                .map_err(|e| DurableError::custom(e.to_string()))?;
708            let status: String = row
709                .try_get_by_index(1)
710                .map_err(|e| DurableError::custom(e.to_string()))?;
711            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
712
713            if status == "COMPLETED" {
714                // Replay path — return saved output
715                return Ok((id, output));
716            }
717            // Task exists and we hold the lock — proceed to execute
718            return Ok((id, None));
719        }
720
721        // Step 3: SKIP LOCKED returned empty — another worker holds the lock
722        Err(DurableError::StepLocked(name.to_string()))
723    } else {
724        // Plain find without locking — safe for workflow-level operations.
725        // Multiple workers resuming the same workflow is fine; individual
726        // steps will be locked when executed.
727        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
728        query = match parent_id {
729            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
730            None => query.filter(TaskColumn::ParentId.is_null()),
731        };
732        let existing = query.one(db).await?;
733
734        if let Some(model) = existing {
735            if model.status == "COMPLETED" {
736                return Ok((model.id, model.output));
737            }
738            return Ok((model.id, None));
739        }
740
741        // Task does not exist — create it
742        let id = Uuid::new_v4();
743        let new_task = TaskActiveModel {
744            id: Set(id),
745            parent_id: Set(parent_id),
746            sequence: Set(sequence),
747            name: Set(name.to_string()),
748            kind: Set(kind.to_string()),
749            status: Set("PENDING".to_string()),
750            input: Set(input),
751            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
752            ..Default::default()
753        };
754        new_task.insert(db).await?;
755
756        Ok((id, None))
757    }
758}
759
760async fn get_output(
761    db: &impl ConnectionTrait,
762    task_id: Uuid,
763) -> Result<Option<serde_json::Value>, DurableError> {
764    let model = Task::find_by_id(task_id)
765        .filter(TaskColumn::Status.eq("COMPLETED"))
766        .one(db)
767        .await?;
768
769    Ok(model.and_then(|m| m.output))
770}
771
772async fn get_status(
773    db: &impl ConnectionTrait,
774    task_id: Uuid,
775) -> Result<Option<String>, DurableError> {
776    let model = Task::find_by_id(task_id).one(db).await?;
777
778    Ok(model.map(|m| m.status))
779}
780
781/// Returns (retry_count, max_retries) for a task.
782async fn get_retry_info(
783    db: &DatabaseConnection,
784    task_id: Uuid,
785) -> Result<(u32, u32), DurableError> {
786    let model = Task::find_by_id(task_id).one(db).await?;
787
788    match model {
789        Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
790        None => Err(DurableError::custom(format!(
791            "task {task_id} not found when reading retry info"
792        ))),
793    }
794}
795
796/// Increment retry_count and reset status to PENDING. Returns the new retry_count.
797async fn increment_retry_count(
798    db: &DatabaseConnection,
799    task_id: Uuid,
800) -> Result<u32, DurableError> {
801    let model = Task::find_by_id(task_id).one(db).await?;
802
803    match model {
804        Some(m) => {
805            let new_count = m.retry_count + 1;
806            let mut active: TaskActiveModel = m.into();
807            active.retry_count = Set(new_count);
808            active.status = Set("PENDING".to_string());
809            active.error = Set(None);
810            active.completed_at = Set(None);
811            active.update(db).await?;
812            Ok(new_count as u32)
813        }
814        None => Err(DurableError::custom(format!(
815            "task {task_id} not found when incrementing retry count"
816        ))),
817    }
818}
819
820// NOTE: set_status uses raw SQL because SeaORM cannot express the CASE expression
821// for conditional deadline_epoch_ms computation or COALESCE on started_at.
822async fn set_status(
823    db: &impl ConnectionTrait,
824    task_id: Uuid,
825    status: &str,
826) -> Result<(), DurableError> {
827    let sql = format!(
828        "UPDATE durable.task \
829         SET status = '{status}', \
830             started_at = COALESCE(started_at, now()), \
831             deadline_epoch_ms = CASE \
832                 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
833                 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
834                 ELSE deadline_epoch_ms \
835             END \
836         WHERE id = '{task_id}'"
837    );
838    db.execute(Statement::from_string(DbBackend::Postgres, sql))
839        .await?;
840    Ok(())
841}
842
843/// Check if the task is paused or cancelled. Returns an error if so.
844async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
845    let status = get_status(db, task_id).await?;
846    match status.as_deref() {
847        Some("PAUSED") => Err(DurableError::Paused(format!("task {task_id} is paused"))),
848        Some("CANCELLED") => Err(DurableError::Cancelled(format!("task {task_id} is cancelled"))),
849        _ => Ok(()),
850    }
851}
852
853/// Check if the parent task's deadline has passed. Returns `DurableError::Timeout` if so.
854async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
855    let model = Task::find_by_id(task_id).one(db).await?;
856
857    if let Some(m) = model
858        && let Some(deadline_ms) = m.deadline_epoch_ms
859    {
860        let now_ms = std::time::SystemTime::now()
861            .duration_since(std::time::UNIX_EPOCH)
862            .map(|d| d.as_millis() as i64)
863            .unwrap_or(0);
864        if now_ms > deadline_ms {
865            return Err(DurableError::Timeout("task deadline exceeded".to_string()));
866        }
867    }
868
869    Ok(())
870}
871
872async fn complete_task(
873    db: &impl ConnectionTrait,
874    task_id: Uuid,
875    output: serde_json::Value,
876) -> Result<(), DurableError> {
877    let model = Task::find_by_id(task_id).one(db).await?;
878
879    if let Some(m) = model {
880        let mut active: TaskActiveModel = m.into();
881        active.status = Set("COMPLETED".to_string());
882        active.output = Set(Some(output));
883        active.completed_at = Set(Some(chrono::Utc::now().into()));
884        active.update(db).await?;
885    }
886    Ok(())
887}
888
889async fn fail_task(
890    db: &impl ConnectionTrait,
891    task_id: Uuid,
892    error: &str,
893) -> Result<(), DurableError> {
894    let model = Task::find_by_id(task_id).one(db).await?;
895
896    if let Some(m) = model {
897        let mut active: TaskActiveModel = m.into();
898        active.status = Set("FAILED".to_string());
899        active.error = Set(Some(error.to_string()));
900        active.completed_at = Set(Some(chrono::Utc::now().into()));
901        active.update(db).await?;
902    }
903    Ok(())
904}
905
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // All tests run with `start_paused = true`: tokio's mock clock
    // auto-advances through the `tokio::time::sleep` backoff calls in
    // `retry_db_write`, so the exhaustion test no longer burns
    // 100 + 200 + 400 ms of real wall time. Assertions are unchanged.

    /// test_retry_db_write_succeeds_first_try: a closure that always succeeds
    /// should be called exactly once and return Ok.
    #[tokio::test(start_paused = true)]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// test_retry_db_write_succeeds_after_transient_failure: a closure that
    /// fails twice then succeeds should return Ok and be called 3 times.
    #[tokio::test(start_paused = true)]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// test_retry_db_write_exhausts_retries: a closure that always fails should
    /// be called 1 + MAX_CHECKPOINT_RETRIES times total then return an error.
    #[tokio::test(start_paused = true)]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        // 1 initial attempt + MAX_CHECKPOINT_RETRIES retry attempts
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    /// test_retry_db_write_returns_original_error: when all retries are
    /// exhausted the FIRST error is returned, not the last retry error.
    #[tokio::test(start_paused = true)]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        // The message of the first error contains "error-0"
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}