1use sea_orm::{ConnectionTrait, DatabaseConnection, DbBackend, Statement};
2use serde::Serialize;
3use serde::de::DeserializeOwned;
4use std::sync::atomic::{AtomicI32, Ordering};
5use uuid::Uuid;
6
7use crate::error::DurableError;
8
9pub struct Ctx {
15 db: DatabaseConnection,
16 task_id: Uuid,
17 sequence: AtomicI32,
18}
19
20impl Ctx {
21 pub async fn start(
29 db: &DatabaseConnection,
30 name: &str,
31 input: Option<serde_json::Value>,
32 ) -> Result<Self, DurableError> {
33 let task_id = find_or_create_task(db, None, None, name, "WORKFLOW", input).await?;
34 set_status(db, task_id, "RUNNING").await?;
35 Ok(Self {
36 db: db.clone(),
37 task_id,
38 sequence: AtomicI32::new(0),
39 })
40 }
41
42 pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
49 where
50 T: Serialize + DeserializeOwned,
51 F: FnOnce() -> Fut,
52 Fut: std::future::Future<Output = Result<T, DurableError>>,
53 {
54 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
55
56 let step_id =
58 find_or_create_task(&self.db, Some(self.task_id), Some(seq), name, "STEP", None)
59 .await?;
60
61 if let Some(output) = get_output(&self.db, step_id).await? {
63 let val: T = serde_json::from_value(output)?;
64 tracing::debug!(step = name, seq, "replaying saved output");
65 return Ok(val);
66 }
67
68 set_status(&self.db, step_id, "RUNNING").await?;
70 match f().await {
71 Ok(val) => {
72 let json = serde_json::to_value(&val)?;
73 complete_task(&self.db, step_id, json).await?;
74 tracing::debug!(step = name, seq, "step completed");
75 Ok(val)
76 }
77 Err(e) => {
78 fail_task(&self.db, step_id, &e.to_string()).await?;
79 Err(e)
80 }
81 }
82 }
83
84 pub async fn child(
92 &self,
93 name: &str,
94 input: Option<serde_json::Value>,
95 ) -> Result<Self, DurableError> {
96 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
97 let child_id = find_or_create_task(
98 &self.db,
99 Some(self.task_id),
100 Some(seq),
101 name,
102 "WORKFLOW",
103 input,
104 )
105 .await?;
106
107 set_status(&self.db, child_id, "RUNNING").await?;
110
111 Ok(Self {
112 db: self.db.clone(),
113 task_id: child_id,
114 sequence: AtomicI32::new(0),
115 })
116 }
117
118 pub async fn is_completed(&self) -> Result<bool, DurableError> {
120 let status = get_status(&self.db, self.task_id).await?;
121 Ok(status.as_deref() == Some("COMPLETED"))
122 }
123
124 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
126 match get_output(&self.db, self.task_id).await? {
127 Some(val) => Ok(Some(serde_json::from_value(val)?)),
128 None => Ok(None),
129 }
130 }
131
132 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
134 let json = serde_json::to_value(output)?;
135 complete_task(&self.db, self.task_id, json).await
136 }
137
138 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
140 fail_task(&self.db, self.task_id, error).await
141 }
142
143 pub fn db(&self) -> &DatabaseConnection {
146 &self.db
147 }
148
149 pub fn task_id(&self) -> Uuid {
150 self.task_id
151 }
152
153 pub fn next_sequence(&self) -> i32 {
154 self.sequence.fetch_add(1, Ordering::SeqCst)
155 }
156}
157
158async fn find_or_create_task(
164 db: &DatabaseConnection,
165 parent_id: Option<Uuid>,
166 sequence: Option<i32>,
167 name: &str,
168 kind: &str,
169 input: Option<serde_json::Value>,
170) -> Result<Uuid, DurableError> {
171 let parent_sql = match parent_id {
172 Some(p) => format!("'{p}'"),
173 None => "NULL".to_string(),
174 };
175
176 let find_sql = format!(
178 "SELECT id FROM durable.task WHERE parent_id {eq} AND name = '{name}'",
179 eq = match parent_id {
180 Some(p) => format!("= '{p}'"),
181 None => "IS NULL".to_string(),
182 }
183 );
184 let row = db
185 .query_one(Statement::from_string(DbBackend::Postgres, find_sql))
186 .await?;
187
188 if let Some(row) = row {
189 let id: Uuid = row
190 .try_get_by_index(0)
191 .map_err(|e| DurableError::custom(e.to_string()))?;
192 return Ok(id);
193 }
194
195 let id = Uuid::new_v4();
197 let seq_sql = match sequence {
198 Some(s) => s.to_string(),
199 None => "NULL".to_string(),
200 };
201 let input_sql = match &input {
202 Some(v) => format!("'{}'", serde_json::to_string(v)?),
203 None => "NULL".to_string(),
204 };
205
206 let sql = format!(
207 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input) \
208 VALUES ('{id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql})"
209 );
210 db.execute(Statement::from_string(DbBackend::Postgres, sql))
211 .await?;
212
213 Ok(id)
214}
215
216async fn get_output(
217 db: &DatabaseConnection,
218 task_id: Uuid,
219) -> Result<Option<serde_json::Value>, DurableError> {
220 let sql =
221 format!("SELECT output FROM durable.task WHERE id = '{task_id}' AND status = 'COMPLETED'");
222 let row = db
223 .query_one(Statement::from_string(DbBackend::Postgres, sql))
224 .await?;
225
226 match row {
227 Some(r) => Ok(r.try_get_by_index(0).ok()),
228 None => Ok(None),
229 }
230}
231
232async fn get_status(
233 db: &DatabaseConnection,
234 task_id: Uuid,
235) -> Result<Option<String>, DurableError> {
236 let sql = format!("SELECT status FROM durable.task WHERE id = '{task_id}'");
237 let row = db
238 .query_one(Statement::from_string(DbBackend::Postgres, sql))
239 .await?;
240
241 match row {
242 Some(r) => Ok(r.try_get_by_index(0).ok()),
243 None => Ok(None),
244 }
245}
246
247async fn set_status(
248 db: &DatabaseConnection,
249 task_id: Uuid,
250 status: &str,
251) -> Result<(), DurableError> {
252 let sql = format!(
253 "UPDATE durable.task SET status = '{status}', started_at = COALESCE(started_at, now()) \
254 WHERE id = '{task_id}'"
255 );
256 db.execute(Statement::from_string(DbBackend::Postgres, sql))
257 .await?;
258 Ok(())
259}
260
261async fn complete_task(
262 db: &DatabaseConnection,
263 task_id: Uuid,
264 output: serde_json::Value,
265) -> Result<(), DurableError> {
266 let sql = format!(
267 "UPDATE durable.task SET status = 'COMPLETED', output = '{}', completed_at = now() \
268 WHERE id = '{task_id}'",
269 output
270 );
271 db.execute(Statement::from_string(DbBackend::Postgres, sql))
272 .await?;
273 Ok(())
274}
275
276async fn fail_task(
277 db: &DatabaseConnection,
278 task_id: Uuid,
279 error: &str,
280) -> Result<(), DurableError> {
281 let sql = format!(
282 "UPDATE durable.task SET status = 'FAILED', error = '{}', completed_at = now() \
283 WHERE id = '{task_id}'",
284 error.replace('\'', "''")
285 );
286 db.execute(Statement::from_string(DbBackend::Postgres, sql))
287 .await?;
288 Ok(())
289}