1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
/// Extra attempts `retry_db_write` makes after the initial call fails.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
/// Base backoff (ms) before the first retry; doubles on each subsequent retry.
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32 F: FnMut() -> Fut,
33 Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35 match f().await {
36 Ok(()) => Ok(()),
37 Err(first_err) => {
38 for i in 0..MAX_CHECKPOINT_RETRIES {
39 tokio::time::sleep(Duration::from_millis(
40 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41 ))
42 .await;
43 if f().await.is_ok() {
44 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45 return Ok(());
46 }
47 }
48 Err(first_err)
49 }
50 }
51}
52
/// Retry policy for [`Ctx::step_with_retry`].
///
/// `max_retries` counts *additional* attempts after the first failure;
/// `initial_backoff` is the delay before the first retry, scaled by
/// `backoff_multiplier` on each subsequent retry.
// Derives added: a public config type should be loggable, copyable, and
// comparable (all fields are plain value types).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RetryPolicy {
    /// Additional attempts allowed after the first failure.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Multiplier applied to the delay on each subsequent retry.
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries: the first failure is final.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: std::time::Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff: the delay doubles after each failed attempt.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Fixed backoff: the same delay between every attempt.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
88
/// Sort key (with direction) for task listing queries.
pub enum TaskSort {
    /// Order by creation timestamp.
    CreatedAt(Order),
    /// Order by when the task started running.
    StartedAt(Order),
    /// Order by completion timestamp.
    CompletedAt(Order),
    /// Order by task name.
    Name(Order),
    /// Order by status value.
    Status(Order),
}
97
/// Filter, sort, and paging options for [`Ctx::list`] and [`Ctx::count`].
///
/// All set filters are combined with AND; `None` fields are ignored.
pub struct TaskQuery {
    // Only tasks with exactly this status.
    pub status: Option<TaskStatus>,
    // Task kind, e.g. "WORKFLOW" / "STEP" / "TRANSACTION" — matched exactly.
    pub kind: Option<String>,
    // Only direct children of this task.
    pub parent_id: Option<Uuid>,
    // When true, only root tasks (parent_id IS NULL).
    pub root_only: bool,
    // Exact task name.
    pub name: Option<String>,
    // Exact queue name.
    pub queue_name: Option<String>,
    // Sort key and direction.
    pub sort: TaskSort,
    // Maximum rows to return (unlimited when None).
    pub limit: Option<u64>,
    // Rows to skip before the first result.
    pub offset: Option<u64>,
}
120
121impl Default for TaskQuery {
122 fn default() -> Self {
123 Self {
124 status: None,
125 kind: None,
126 parent_id: None,
127 root_only: false,
128 name: None,
129 queue_name: None,
130 sort: TaskSort::CreatedAt(Order::Desc),
131 limit: None,
132 offset: None,
133 }
134 }
135}
136
137impl TaskQuery {
138 pub fn status(mut self, status: TaskStatus) -> Self {
140 self.status = Some(status);
141 self
142 }
143
144 pub fn kind(mut self, kind: &str) -> Self {
146 self.kind = Some(kind.to_string());
147 self
148 }
149
150 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152 self.parent_id = Some(parent_id);
153 self
154 }
155
156 pub fn root_only(mut self, root_only: bool) -> Self {
158 self.root_only = root_only;
159 self
160 }
161
162 pub fn name(mut self, name: &str) -> Self {
164 self.name = Some(name.to_string());
165 self
166 }
167
168 pub fn queue_name(mut self, queue: &str) -> Self {
170 self.queue_name = Some(queue.to_string());
171 self
172 }
173
174 pub fn sort(mut self, sort: TaskSort) -> Self {
176 self.sort = sort;
177 self
178 }
179
180 pub fn limit(mut self, limit: u64) -> Self {
182 self.limit = Some(limit);
183 self
184 }
185
186 pub fn offset(mut self, offset: u64) -> Self {
188 self.offset = Some(offset);
189 self
190 }
191}
192
/// Flat, serializable projection of a task row for listings and APIs.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    pub id: Uuid,
    // None for root workflows.
    pub parent_id: Option<Uuid>,
    pub name: String,
    // Registered handler that owns the workflow, if any.
    pub handler: Option<String>,
    pub status: TaskStatus,
    // "WORKFLOW" / "STEP" / "TRANSACTION".
    pub kind: String,
    pub input: Option<serde_json::Value>,
    pub output: Option<serde_json::Value>,
    // Error message recorded when the task FAILED.
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
210
211impl From<durable_db::entity::task::Model> for TaskSummary {
212 fn from(m: durable_db::entity::task::Model) -> Self {
213 Self {
214 id: m.id,
215 parent_id: m.parent_id,
216 name: m.name,
217 handler: m.handler,
218 status: m.status,
219 kind: m.kind,
220 input: m.input,
221 output: m.output,
222 error: m.error,
223 queue_name: m.queue_name,
224 created_at: m.created_at,
225 started_at: m.started_at,
226 completed_at: m.completed_at,
227 }
228 }
229}
230
/// Handle for executing a durable workflow task.
///
/// Wraps a database connection plus the task's id, a monotonically
/// increasing step sequence counter, and (optionally) the id of the
/// executor process that owns the task.
pub struct Ctx {
    db: DatabaseConnection,
    task_id: Uuid,
    // Next step slot; incremented once per step/transaction/child call.
    sequence: AtomicI32,
    // Identity of the owning executor process, if configured.
    executor_id: Option<String>,
}
242
243impl Ctx {
    /// Starts (or idempotently attaches to) a root workflow with no
    /// registered handler. See [`Ctx::start_with_handler`] for details.
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        Self::start_with_handler(db, name, input, None).await
    }
267
268 pub async fn start_with_handler(
276 db: &DatabaseConnection,
277 name: &str,
278 input: Option<serde_json::Value>,
279 handler: Option<&str>,
280 ) -> Result<Self, DurableError> {
281 let existing_sql = format!(
283 "SELECT id FROM durable.task \
284 WHERE name = '{}' AND parent_id IS NULL AND status = 'RUNNING' \
285 LIMIT 1",
286 name
287 );
288 if let Some(row) = db
289 .query_one(Statement::from_string(DbBackend::Postgres, existing_sql))
290 .await?
291 {
292 let existing_id: Uuid = row
293 .try_get_by_index(0)
294 .map_err(|e| DurableError::custom(e.to_string()))?;
295 tracing::info!(
296 workflow = name,
297 id = %existing_id,
298 "idempotent start: attaching to existing RUNNING task"
299 );
300 return Self::from_id(db, existing_id).await;
301 }
302
303 let task_id = Uuid::new_v4();
304 let input_json = match &input {
305 Some(v) => serde_json::to_string(v)?,
306 None => "null".to_string(),
307 };
308
309 let executor_id = crate::executor_id();
310
311 let mut extra_cols = String::new();
312 let mut extra_vals = String::new();
313
314 if let Some(eid) = &executor_id {
315 extra_cols.push_str(", executor_id");
316 extra_vals.push_str(&format!(", '{eid}'"));
317 }
318 if let Some(h) = handler {
319 extra_cols.push_str(", handler");
320 extra_vals.push_str(&format!(", '{h}'"));
321 }
322
323 let txn = db.begin().await?;
324 let sql = format!(
325 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{extra_cols}) \
326 VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now(){extra_vals})"
327 );
328 txn.execute(Statement::from_string(DbBackend::Postgres, sql))
329 .await?;
330 txn.commit().await?;
331
332 Ok(Self {
333 db: db.clone(),
334 task_id,
335 sequence: AtomicI32::new(0),
336 executor_id,
337 })
338 }
339
340 pub async fn from_id(
347 db: &DatabaseConnection,
348 task_id: Uuid,
349 ) -> Result<Self, DurableError> {
350 let model = Task::find_by_id(task_id).one(db).await?;
352 let _model =
353 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
354
355 let executor_id = crate::executor_id();
357 if let Some(eid) = &executor_id {
358 db.execute(Statement::from_string(
359 DbBackend::Postgres,
360 format!(
361 "UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"
362 ),
363 ))
364 .await?;
365 }
366
367 Ok(Self {
372 db: db.clone(),
373 task_id,
374 sequence: AtomicI32::new(0),
375 executor_id,
376 })
377 }
378
    /// Runs (or replays) a durable step inside a single DB transaction.
    ///
    /// Each call claims the next sequence number for this workflow; on
    /// re-execution the step's saved output is deserialized and returned
    /// instead of re-running `f`. Concurrent executors race in the database
    /// and the loser gets `DurableError::StepLocked`.
    ///
    /// NOTE(review): `retry_db_write` re-issues statements on the *same*
    /// transaction; if an attempt failed because the transaction aborted,
    /// Postgres rejects the retries too — confirm the retry wrapper is
    /// useful inside `txn`.
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim the next deterministic sequence slot for this step.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Refuse to run if the workflow was paused/cancelled or timed out.
        check_status(&self.db, self.task_id).await?;

        check_deadline(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;

        // Locking mode (`lock = true`): the row is claimed with
        // FOR UPDATE SKIP LOCKED inside `txn`. Steps get max_retries = 0.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // Replay path: the step already completed in a previous run.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        // The user closure runs while `txn` still holds the row lock.
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Record the failure, then commit so it is durable.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }
446
    /// Runs a step whose effects share one DB transaction with its
    /// checkpoint: `f` receives the transaction, and the step's COMPLETED
    /// marker commits atomically with whatever `f` wrote.
    ///
    /// On error the transaction is rolled back (by dropping it) and the
    /// failure is recorded outside it on the main connection.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
            &'tx DatabaseTransaction,
        ) -> Pin<
            Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
        > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // Non-locking lookup on the plain connection: the checkpoint row
        // must exist before the user transaction begins.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // Replay path: already committed in an earlier run.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                // COMPLETED marker and the user's writes commit atomically.
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the transaction rolls back the user's writes;
                // the failure is then recorded on the main connection.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
515
    /// Creates (or re-attaches to) a child workflow under this task.
    ///
    /// The child occupies one sequence slot of the parent, so replays
    /// deterministically re-attach to the same child.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // Non-locking mode: child workflows are looked up by (parent, name).
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            // The child gets its own sequence counter starting at 0.
            sequence: AtomicI32::new(0),
            executor_id: self.executor_id.clone(),
        })
    }
559
560 pub async fn is_completed(&self) -> Result<bool, DurableError> {
562 let status = get_status(&self.db, self.task_id).await?;
563 Ok(status == Some(TaskStatus::Completed))
564 }
565
566 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
568 match get_output(&self.db, self.task_id).await? {
569 Some(val) => Ok(Some(serde_json::from_value(val)?)),
570 None => Ok(None),
571 }
572 }
573
574 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
576 let json = serde_json::to_value(output)?;
577 let db = &self.db;
578 let task_id = self.task_id;
579 retry_db_write(|| complete_task(db, task_id, json.clone())).await
580 }
581
    /// Like [`Ctx::step`], but retries `f` according to `policy`.
    ///
    /// The retry counter is persisted on the step row, so the retry budget
    /// survives process restarts. Unlike `step`, this does not take a row
    /// lock: every attempt runs against the plain connection.
    ///
    /// NOTE(review): the effective budget is read back from the task row;
    /// for a pre-existing row it may differ from `policy.max_retries` —
    /// confirm that is intended.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // Non-locking create; the policy's retry budget is stored on the row.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // Replay path: the step already completed in a previous run.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Resume from the persisted retry count after a crash.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Pause/cancel is honored between attempts.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the bump (also resets the row to PENDING
                        // and clears the previous error).
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // backoff = initial * multiplier^(attempt - 1),
                        // clamped so the delay never shrinks below initial.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Budget exhausted: record the terminal failure.
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
691
692 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
694 let db = &self.db;
695 let task_id = self.task_id;
696 retry_db_write(|| fail_task(db, task_id, error)).await
697 }
698
    /// Sets this task's timeout in milliseconds.
    ///
    /// If the task is already RUNNING with a `started_at`, the absolute
    /// deadline is computed from that start time; otherwise the deadline is
    /// left for `set_status` to arm when the task enters RUNNING.
    ///
    /// Only an i64 and a Uuid are interpolated, so the statement text
    /// cannot be corrupted by the values.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
             WHEN status = 'RUNNING' AND started_at IS NOT NULL \
             THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
             ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
723
    /// Convenience: starts a root workflow and immediately arms a timeout.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
737
    /// Pauses a PENDING/RUNNING workflow and its whole subtree.
    ///
    /// Only descendants that are themselves PENDING/RUNNING are touched, so
    /// already-finished children keep their terminal status.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Pausing a finished/cancelled/already-paused task is an error.
        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Recursive CTE walks the subtree rooted at `task_id`; the Uuid's
        // Display form is the only interpolated value.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('PENDING', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }
776
    /// Resumes a PAUSED workflow: the root returns to RUNNING and paused
    /// descendants are reset to PENDING so executors pick them up again.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Root returns to RUNNING via the ORM.
        // NOTE(review): a task paused while still PENDING also comes back as
        // RUNNING — confirm that is intended.
        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // Descendants (but not the root) flip PAUSED -> PENDING.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
814
    /// Cancels a workflow and every non-terminal descendant, stamping
    /// `completed_at`. Tasks already in a terminal status
    /// (COMPLETED/FAILED/CANCELLED) are left untouched, and cancelling a
    /// terminal root is an error.
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        // Recursive CTE over the subtree; only the Uuid is interpolated.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
             AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }
852
853 pub async fn list(
861 db: &DatabaseConnection,
862 query: TaskQuery,
863 ) -> Result<Vec<TaskSummary>, DurableError> {
864 let mut select = Task::find();
865
866 if let Some(status) = &query.status {
868 select = select.filter(TaskColumn::Status.eq(status.to_string()));
869 }
870 if let Some(kind) = &query.kind {
871 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
872 }
873 if let Some(parent_id) = query.parent_id {
874 select = select.filter(TaskColumn::ParentId.eq(parent_id));
875 }
876 if query.root_only {
877 select = select.filter(TaskColumn::ParentId.is_null());
878 }
879 if let Some(name) = &query.name {
880 select = select.filter(TaskColumn::Name.eq(name.as_str()));
881 }
882 if let Some(queue) = &query.queue_name {
883 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
884 }
885
886 let (col, order) = match query.sort {
888 TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
889 TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
890 TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
891 TaskSort::Name(ord) => (TaskColumn::Name, ord),
892 TaskSort::Status(ord) => (TaskColumn::Status, ord),
893 };
894 select = select.order_by(col, order);
895
896 if let Some(offset) = query.offset {
898 select = select.offset(offset);
899 }
900 if let Some(limit) = query.limit {
901 select = select.limit(limit);
902 }
903
904 let models = select.all(db).await?;
905
906 Ok(models.into_iter().map(TaskSummary::from).collect())
907 }
908
909 pub async fn count(
911 db: &DatabaseConnection,
912 query: TaskQuery,
913 ) -> Result<u64, DurableError> {
914 let mut select = Task::find();
915
916 if let Some(status) = &query.status {
917 select = select.filter(TaskColumn::Status.eq(status.to_string()));
918 }
919 if let Some(kind) = &query.kind {
920 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
921 }
922 if let Some(parent_id) = query.parent_id {
923 select = select.filter(TaskColumn::ParentId.eq(parent_id));
924 }
925 if query.root_only {
926 select = select.filter(TaskColumn::ParentId.is_null());
927 }
928 if let Some(name) = &query.name {
929 select = select.filter(TaskColumn::Name.eq(name.as_str()));
930 }
931 if let Some(queue) = &query.queue_name {
932 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
933 }
934
935 let count = select.count(db).await?;
936 Ok(count)
937 }
938
    /// Borrows the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// This context's task id.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Claims and returns the next step sequence number.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
952}
953
954#[allow(clippy::too_many_arguments)]
975async fn find_or_create_task(
976 db: &impl ConnectionTrait,
977 parent_id: Option<Uuid>,
978 sequence: Option<i32>,
979 name: &str,
980 kind: &str,
981 input: Option<serde_json::Value>,
982 lock: bool,
983 max_retries: Option<u32>,
984) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
985 let parent_eq = match parent_id {
986 Some(p) => format!("= '{p}'"),
987 None => "IS NULL".to_string(),
988 };
989 let parent_sql = match parent_id {
990 Some(p) => format!("'{p}'"),
991 None => "NULL".to_string(),
992 };
993
994 if lock {
995 let new_id = Uuid::new_v4();
1009 let seq_sql = match sequence {
1010 Some(s) => s.to_string(),
1011 None => "NULL".to_string(),
1012 };
1013 let input_sql = match &input {
1014 Some(v) => format!("'{}'", serde_json::to_string(v)?),
1015 None => "NULL".to_string(),
1016 };
1017
1018 let max_retries_sql = match max_retries {
1019 Some(r) => r.to_string(),
1020 None => "3".to_string(), };
1022
1023 let insert_sql = format!(
1025 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
1026 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
1027 ON CONFLICT (parent_id, sequence) DO NOTHING"
1028 );
1029 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
1030 .await?;
1031
1032 let lock_sql = format!(
1034 "SELECT id, status::text, output FROM durable.task \
1035 WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
1036 FOR UPDATE SKIP LOCKED"
1037 );
1038 let row = db
1039 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
1040 .await?;
1041
1042 if let Some(row) = row {
1043 let id: Uuid = row
1044 .try_get_by_index(0)
1045 .map_err(|e| DurableError::custom(e.to_string()))?;
1046 let status: String = row
1047 .try_get_by_index(1)
1048 .map_err(|e| DurableError::custom(e.to_string()))?;
1049 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
1050
1051 if status == TaskStatus::Completed.to_string() {
1052 return Ok((id, output));
1054 }
1055 return Ok((id, None));
1057 }
1058
1059 Err(DurableError::StepLocked(name.to_string()))
1061 } else {
1062 let mut query = Task::find().filter(TaskColumn::Name.eq(name));
1066 query = match parent_id {
1067 Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
1068 None => query.filter(TaskColumn::ParentId.is_null()),
1069 };
1070 let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
1073 let existing = query
1074 .filter(TaskColumn::Status.is_not_in(status_exclusions))
1075 .one(db)
1076 .await?;
1077
1078 if let Some(model) = existing {
1079 if model.status == TaskStatus::Completed {
1080 return Ok((model.id, model.output));
1081 }
1082 return Ok((model.id, None));
1083 }
1084
1085 let id = Uuid::new_v4();
1087 let new_task = TaskActiveModel {
1088 id: Set(id),
1089 parent_id: Set(parent_id),
1090 sequence: Set(sequence),
1091 name: Set(name.to_string()),
1092 kind: Set(kind.to_string()),
1093 status: Set(TaskStatus::Pending),
1094 input: Set(input),
1095 max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1096 ..Default::default()
1097 };
1098 new_task.insert(db).await?;
1099
1100 Ok((id, None))
1101 }
1102}
1103
/// Returns the stored output of `task_id` only if it is COMPLETED;
/// `None` if the row is missing, not finished, or finished with no output.
async fn get_output(
    db: &impl ConnectionTrait,
    task_id: Uuid,
) -> Result<Option<serde_json::Value>, DurableError> {
    let model = Task::find_by_id(task_id)
        .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
        .one(db)
        .await?;

    Ok(model.and_then(|m| m.output))
}
1115
/// Reads a task's current status; `None` if the row does not exist.
async fn get_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
) -> Result<Option<TaskStatus>, DurableError> {
    let model = Task::find_by_id(task_id).one(db).await?;

    Ok(model.map(|m| m.status))
}
1124
1125async fn get_retry_info(
1127 db: &DatabaseConnection,
1128 task_id: Uuid,
1129) -> Result<(u32, u32), DurableError> {
1130 let model = Task::find_by_id(task_id).one(db).await?;
1131
1132 match model {
1133 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1134 None => Err(DurableError::custom(format!(
1135 "task {task_id} not found when reading retry info"
1136 ))),
1137 }
1138}
1139
1140async fn increment_retry_count(
1142 db: &DatabaseConnection,
1143 task_id: Uuid,
1144) -> Result<u32, DurableError> {
1145 let model = Task::find_by_id(task_id).one(db).await?;
1146
1147 match model {
1148 Some(m) => {
1149 let new_count = m.retry_count + 1;
1150 let mut active: TaskActiveModel = m.into();
1151 active.retry_count = Set(new_count);
1152 active.status = Set(TaskStatus::Pending);
1153 active.error = Set(None);
1154 active.completed_at = Set(None);
1155 active.update(db).await?;
1156 Ok(new_count as u32)
1157 }
1158 None => Err(DurableError::custom(format!(
1159 "task {task_id} not found when incrementing retry count"
1160 ))),
1161 }
1162}
1163
/// Transitions `task_id` to `status` via raw SQL, additionally:
/// * stamping `started_at` on the first transition (COALESCE keeps an
///   existing value),
/// * arming `deadline_epoch_ms` when entering RUNNING with a timeout set.
///
/// Only a typed enum and a Uuid are interpolated, so the statement text
/// cannot be corrupted by the values.
/// NOTE(review): `started_at` is also stamped for non-RUNNING transitions
/// (e.g. back to PENDING) — confirm that is intended.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
         WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
         THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
         ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1186
1187async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1189 let status = get_status(db, task_id).await?;
1190 match status {
1191 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1192 Some(TaskStatus::Cancelled) => {
1193 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1194 }
1195 _ => Ok(()),
1196 }
1197}
1198
1199async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1201 let model = Task::find_by_id(task_id).one(db).await?;
1202
1203 if let Some(m) = model
1204 && let Some(deadline_ms) = m.deadline_epoch_ms
1205 {
1206 let now_ms = std::time::SystemTime::now()
1207 .duration_since(std::time::UNIX_EPOCH)
1208 .map(|d| d.as_millis() as i64)
1209 .unwrap_or(0);
1210 if now_ms > deadline_ms {
1211 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1212 }
1213 }
1214
1215 Ok(())
1216}
1217
1218async fn complete_task(
1219 db: &impl ConnectionTrait,
1220 task_id: Uuid,
1221 output: serde_json::Value,
1222) -> Result<(), DurableError> {
1223 let model = Task::find_by_id(task_id).one(db).await?;
1224
1225 if let Some(m) = model {
1226 let mut active: TaskActiveModel = m.into();
1227 active.status = Set(TaskStatus::Completed);
1228 active.output = Set(Some(output));
1229 active.completed_at = Set(Some(chrono::Utc::now().into()));
1230 active.update(db).await?;
1231 }
1232 Ok(())
1233}
1234
1235async fn fail_task(
1236 db: &impl ConnectionTrait,
1237 task_id: Uuid,
1238 error: &str,
1239) -> Result<(), DurableError> {
1240 let model = Task::find_by_id(task_id).one(db).await?;
1241
1242 if let Some(m) = model {
1243 let mut active: TaskActiveModel = m.into();
1244 active.status = Set(TaskStatus::Failed);
1245 active.error = Set(Some(error.to_string()));
1246 active.completed_at = Set(Some(chrono::Utc::now().into()));
1247 active.update(db).await?;
1248 }
1249 Ok(())
1250}
1251
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // Happy path: the closure succeeds immediately, so it runs exactly once.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Fails twice (initial attempt + first retry), succeeds on the second
    // retry: three calls total, overall Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // Always failing: one initial attempt plus MAX_CHECKPOINT_RETRIES
    // retries, then the error is surfaced.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    // The error reported after exhaustion is the one from the FIRST
    // attempt, not the last ("error-0", not "error-3").
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}
1348}