Skip to main content

durable_rust/
ctx.rs

1use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, Statement};
2use serde::Serialize;
3use serde::de::DeserializeOwned;
4use std::sync::atomic::{AtomicI32, Ordering};
5use uuid::Uuid;
6
7use crate::error::DurableError;
8
9/// Context threaded through every workflow and step.
10///
11/// Users never create or manage task IDs. The SDK handles everything
12/// via `(parent_id, name)` lookups — the unique constraint in the schema
13/// guarantees exactly-once step creation.
14pub struct Ctx {
15    db: DatabaseConnection,
16    task_id: Uuid,
17    sequence: AtomicI32,
18}
19
20impl Ctx {
21    // ── Workflow lifecycle (user-facing) ──────────────────────────
22
23    /// Start or resume a root workflow by name.
24    ///
25    /// ```ignore
26    /// let ctx = Ctx::start(&db, "ingest", json!({"crawl": "CC-2026"})).await?;
27    /// ```
28    pub async fn start(
29        db: &DatabaseConnection,
30        name: &str,
31        input: Option<serde_json::Value>,
32    ) -> Result<Self, DurableError> {
33        let task_id = find_or_create_task(db, None, None, name, "WORKFLOW", input).await?;
34        set_status(db, task_id, "RUNNING").await?;
35        Ok(Self {
36            db: db.clone(),
37            task_id,
38            sequence: AtomicI32::new(0),
39        })
40    }
41
42    /// Run a step. If already completed, returns saved output. Otherwise executes
43    /// the closure, saves the result, and returns it.
44    ///
45    /// ```ignore
46    /// let count: u32 = ctx.step("fetch_count", || async { api.get_count().await }).await?;
47    /// ```
48    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
49    where
50        T: Serialize + DeserializeOwned,
51        F: FnOnce() -> Fut,
52        Fut: std::future::Future<Output = Result<T, DurableError>>,
53    {
54        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
55
56        // Find or create — idempotent via UNIQUE(parent_id, name)
57        let step_id =
58            find_or_create_task(&self.db, Some(self.task_id), Some(seq), name, "STEP", None)
59                .await?;
60
61        // If already completed, replay
62        if let Some(output) = get_output(&self.db, step_id).await? {
63            let val: T = serde_json::from_value(output)?;
64            tracing::debug!(step = name, seq, "replaying saved output");
65            return Ok(val);
66        }
67
68        // Execute
69        set_status(&self.db, step_id, "RUNNING").await?;
70        match f().await {
71            Ok(val) => {
72                let json = serde_json::to_value(&val)?;
73                complete_task(&self.db, step_id, json).await?;
74                tracing::debug!(step = name, seq, "step completed");
75                Ok(val)
76            }
77            Err(e) => {
78                fail_task(&self.db, step_id, &e.to_string()).await?;
79                Err(e)
80            }
81        }
82    }
83
84    /// Start or resume a child workflow. Returns a new `Ctx` scoped to the child.
85    ///
86    /// ```ignore
87    /// let child_ctx = ctx.child("embed_batch", json!({"vectors": 1000})).await?;
88    /// // use child_ctx.step(...) for steps inside the child
89    /// child_ctx.complete(json!({"done": true})).await?;
90    /// ```
91    pub async fn child(
92        &self,
93        name: &str,
94        input: Option<serde_json::Value>,
95    ) -> Result<Self, DurableError> {
96        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
97        let child_id = find_or_create_task(
98            &self.db,
99            Some(self.task_id),
100            Some(seq),
101            name,
102            "WORKFLOW",
103            input,
104        )
105        .await?;
106
107        // If child already completed, return a Ctx that will replay
108        // (the caller should check is_completed() or just run steps which will replay)
109        set_status(&self.db, child_id, "RUNNING").await?;
110
111        Ok(Self {
112            db: self.db.clone(),
113            task_id: child_id,
114            sequence: AtomicI32::new(0),
115        })
116    }
117
118    /// Check if this workflow/child was already completed (for skipping in parent).
119    pub async fn is_completed(&self) -> Result<bool, DurableError> {
120        let status = get_status(&self.db, self.task_id).await?;
121        Ok(status.as_deref() == Some("COMPLETED"))
122    }
123
124    /// Get the saved output if this task is completed.
125    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
126        match get_output(&self.db, self.task_id).await? {
127            Some(val) => Ok(Some(serde_json::from_value(val)?)),
128            None => Ok(None),
129        }
130    }
131
132    /// Mark this workflow as completed with an output value.
133    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
134        let json = serde_json::to_value(output)?;
135        complete_task(&self.db, self.task_id, json).await
136    }
137
138    /// Mark this workflow as failed.
139    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
140        fail_task(&self.db, self.task_id, error).await
141    }
142
143    // ── Accessors ────────────────────────────────────────────────
144
145    pub fn db(&self) -> &DatabaseConnection {
146        &self.db
147    }
148
149    pub fn task_id(&self) -> Uuid {
150        self.task_id
151    }
152
153    pub fn next_sequence(&self) -> i32 {
154        self.sequence.fetch_add(1, Ordering::SeqCst)
155    }
156}
157
158// ── Internal SQL helpers ─────────────────────────────────────────────
159
160/// Find an existing task by (parent_id, name) or create a new one.
161/// This is the core idempotency mechanism — the UNIQUE(parent_id, name)
162/// constraint in the schema makes this safe.
163async fn find_or_create_task(
164    db: &DatabaseConnection,
165    parent_id: Option<Uuid>,
166    sequence: Option<i32>,
167    name: &str,
168    kind: &str,
169    input: Option<serde_json::Value>,
170) -> Result<Uuid, DurableError> {
171    let parent_sql = match parent_id {
172        Some(p) => format!("'{p}'"),
173        None => "NULL".to_string(),
174    };
175
176    // Try to find existing
177    let find_sql = format!(
178        "SELECT id FROM durable.task WHERE parent_id {eq} AND name = '{name}'",
179        eq = match parent_id {
180            Some(p) => format!("= '{p}'"),
181            None => "IS NULL".to_string(),
182        }
183    );
184    let row = db
185        .query_one(Statement::from_string(DbBackend::Postgres, find_sql))
186        .await?;
187
188    if let Some(row) = row {
189        let id: Uuid = row
190            .try_get_by_index(0)
191            .map_err(|e| DurableError::custom(e.to_string()))?;
192        return Ok(id);
193    }
194
195    // Create new
196    let id = Uuid::new_v4();
197    let seq_sql = match sequence {
198        Some(s) => s.to_string(),
199        None => "NULL".to_string(),
200    };
201    let input_sql = match &input {
202        Some(v) => format!("'{}'", serde_json::to_string(v)?),
203        None => "NULL".to_string(),
204    };
205
206    let sql = format!(
207        "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input) \
208         VALUES ('{id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql})"
209    );
210    db.execute(Statement::from_string(DbBackend::Postgres, sql))
211        .await?;
212
213    Ok(id)
214}
215
216async fn get_output(
217    db: &DatabaseConnection,
218    task_id: Uuid,
219) -> Result<Option<serde_json::Value>, DurableError> {
220    let sql =
221        format!("SELECT output FROM durable.task WHERE id = '{task_id}' AND status = 'COMPLETED'");
222    let row = db
223        .query_one(Statement::from_string(DbBackend::Postgres, sql))
224        .await?;
225
226    match row {
227        Some(r) => Ok(r.try_get_by_index(0).ok()),
228        None => Ok(None),
229    }
230}
231
232async fn get_status(
233    db: &DatabaseConnection,
234    task_id: Uuid,
235) -> Result<Option<String>, DurableError> {
236    let sql = format!("SELECT status FROM durable.task WHERE id = '{task_id}'");
237    let row = db
238        .query_one(Statement::from_string(DbBackend::Postgres, sql))
239        .await?;
240
241    match row {
242        Some(r) => Ok(r.try_get_by_index(0).ok()),
243        None => Ok(None),
244    }
245}
246
247async fn set_status(
248    db: &DatabaseConnection,
249    task_id: Uuid,
250    status: &str,
251) -> Result<(), DurableError> {
252    let sql = format!(
253        "UPDATE durable.task SET status = '{status}', started_at = COALESCE(started_at, now()) \
254         WHERE id = '{task_id}'"
255    );
256    db.execute(Statement::from_string(DbBackend::Postgres, sql))
257        .await?;
258    Ok(())
259}
260
261async fn complete_task(
262    db: &DatabaseConnection,
263    task_id: Uuid,
264    output: serde_json::Value,
265) -> Result<(), DurableError> {
266    let sql = format!(
267        "UPDATE durable.task SET status = 'COMPLETED', output = '{}', completed_at = now() \
268         WHERE id = '{task_id}'",
269        output
270    );
271    db.execute(Statement::from_string(DbBackend::Postgres, sql))
272        .await?;
273    Ok(())
274}
275
276async fn fail_task(
277    db: &DatabaseConnection,
278    task_id: Uuid,
279    error: &str,
280) -> Result<(), DurableError> {
281    let sql = format!(
282        "UPDATE durable.task SET status = 'FAILED', error = '{}', completed_at = now() \
283         WHERE id = '{task_id}'",
284        error.replace('\'', "''")
285    );
286    db.execute(Statement::from_string(DbBackend::Postgres, sql))
287        .await?;
288    Ok(())
289}