1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
/// SQL cast suffix appended after status string literals so Postgres coerces
/// them to the `durable.task_status` enum type
/// (e.g. `'RUNNING'::durable.task_status`).
pub(crate) const TS: &str = "::durable.task_status";

/// Extra attempts `retry_db_write` makes after the first failed write.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
/// Base delay for checkpoint-write retries; doubled on each attempt
/// (100 ms, 200 ms, 400 ms).
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
33
34async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
40where
41 F: FnMut() -> Fut,
42 Fut: std::future::Future<Output = Result<(), DurableError>>,
43{
44 match f().await {
45 Ok(()) => Ok(()),
46 Err(first_err) => {
47 for i in 0..MAX_CHECKPOINT_RETRIES {
48 tokio::time::sleep(Duration::from_millis(
49 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
50 ))
51 .await;
52 if f().await.is_ok() {
53 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
54 return Ok(());
55 }
56 }
57 Err(first_err)
58 }
59 }
60}
61
/// Controls how `step_with_retry` re-attempts a failing step.
pub struct RetryPolicy {
    /// Maximum number of re-attempts after the initial try.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor applied to the backoff after each retry (1.0 = fixed delay).
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries at all: the step runs exactly once.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff (doubling) starting from `initial_backoff`.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// The same fixed delay between every retry.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
97
/// Sort key plus direction for `Ctx::list` results.
pub enum TaskSort {
    /// Order by row creation timestamp.
    CreatedAt(Order),
    /// Order by when the task started running.
    StartedAt(Order),
    /// Order by completion timestamp.
    CompletedAt(Order),
    /// Order alphabetically by task name.
    Name(Order),
    /// Order by task status.
    Status(Order),
}
106
/// Filter/sort/paging options for listing and counting tasks.
/// Build with `TaskQuery::default()` plus the chainable setters below.
pub struct TaskQuery {
    // Only tasks with this status.
    pub status: Option<TaskStatus>,
    // Only tasks of this kind (e.g. "WORKFLOW", "STEP").
    pub kind: Option<String>,
    // Only direct children of this task.
    pub parent_id: Option<Uuid>,
    // Only root tasks (parent_id IS NULL).
    pub root_only: bool,
    // Exact-match task name filter.
    pub name: Option<String>,
    // Exact-match queue name filter.
    pub queue_name: Option<String>,
    // Sort key and direction (defaults to newest-created first).
    pub sort: TaskSort,
    // Page size; `None` returns everything.
    pub limit: Option<u64>,
    // Page offset; `None` starts from the beginning.
    pub offset: Option<u64>,
}
129
130impl Default for TaskQuery {
131 fn default() -> Self {
132 Self {
133 status: None,
134 kind: None,
135 parent_id: None,
136 root_only: false,
137 name: None,
138 queue_name: None,
139 sort: TaskSort::CreatedAt(Order::Desc),
140 limit: None,
141 offset: None,
142 }
143 }
144}
145
146impl TaskQuery {
147 pub fn status(mut self, status: TaskStatus) -> Self {
149 self.status = Some(status);
150 self
151 }
152
153 pub fn kind(mut self, kind: &str) -> Self {
155 self.kind = Some(kind.to_string());
156 self
157 }
158
159 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
161 self.parent_id = Some(parent_id);
162 self
163 }
164
165 pub fn root_only(mut self, root_only: bool) -> Self {
167 self.root_only = root_only;
168 self
169 }
170
171 pub fn name(mut self, name: &str) -> Self {
173 self.name = Some(name.to_string());
174 self
175 }
176
177 pub fn queue_name(mut self, queue: &str) -> Self {
179 self.queue_name = Some(queue.to_string());
180 self
181 }
182
183 pub fn sort(mut self, sort: TaskSort) -> Self {
185 self.sort = sort;
186 self
187 }
188
189 pub fn limit(mut self, limit: u64) -> Self {
191 self.limit = Some(limit);
192 self
193 }
194
195 pub fn offset(mut self, offset: u64) -> Self {
197 self.offset = Some(offset);
198 self
199 }
200}
201
/// Read-only projection of a task row as returned by `Ctx::list`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    pub id: Uuid,
    // `None` for root workflows.
    pub parent_id: Option<Uuid>,
    pub name: String,
    // Registered handler name, if the workflow was started with one.
    pub handler: Option<String>,
    pub status: TaskStatus,
    // e.g. "WORKFLOW", "STEP", "TRANSACTION" (values written by this module).
    pub kind: String,
    pub input: Option<serde_json::Value>,
    pub output: Option<serde_json::Value>,
    // Error message recorded when the task failed.
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
219
220impl From<durable_db::entity::task::Model> for TaskSummary {
221 fn from(m: durable_db::entity::task::Model) -> Self {
222 Self {
223 id: m.id,
224 parent_id: m.parent_id,
225 name: m.name,
226 handler: m.handler,
227 status: m.status,
228 kind: m.kind,
229 input: m.input,
230 output: m.output,
231 error: m.error,
232 queue_name: m.queue_name,
233 created_at: m.created_at,
234 started_at: m.started_at,
235 completed_at: m.completed_at,
236 }
237 }
238}
239
240pub enum StartResult {
255 Created(Ctx),
257 Attached(Ctx),
259}
260
261impl StartResult {
262 pub fn into_ctx(self) -> Ctx {
264 match self {
265 StartResult::Created(ctx) | StartResult::Attached(ctx) => ctx,
266 }
267 }
268
269 pub fn ctx(&self) -> &Ctx {
271 match self {
272 StartResult::Created(ctx) | StartResult::Attached(ctx) => ctx,
273 }
274 }
275
276 pub fn is_created(&self) -> bool {
278 matches!(self, StartResult::Created(_))
279 }
280
281 pub fn is_attached(&self) -> bool {
283 matches!(self, StartResult::Attached(_))
284 }
285}
286
/// Handle to one durable task row; drives checkpointing for a workflow run.
pub struct Ctx {
    // Connection handle; cloned into child contexts.
    db: DatabaseConnection,
    // Primary key of the task this context operates on.
    task_id: Uuid,
    // Monotonic per-context counter used as the child row `sequence`.
    // Replays rely on steps being executed in the same order every run.
    sequence: AtomicI32,
    // Identifier of this executor process, if one is configured.
    executor_id: Option<String>,
}
298
impl Ctx {
    /// Starts (or idempotently attaches to) a root workflow with no handler.
    ///
    /// Convenience wrapper over [`Ctx::start_with_handler`].
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<StartResult, DurableError> {
        Self::start_with_handler(db, name, input, None).await
    }

    /// Starts a root workflow, recording an optional handler name.
    ///
    /// If a RUNNING root task with the same `name` already exists, attaches
    /// to it instead of creating a duplicate (`StartResult::Attached`).
    ///
    /// NOTE(review): `name`, `input_json`, `handler`, and `executor_id` are
    /// interpolated directly into SQL strings; a value containing `'` breaks
    /// the statement. Consider `Statement::from_sql_and_values` — TODO confirm
    /// these values are always trusted/internal.
    ///
    /// NOTE(review): the SELECT-then-INSERT is not atomic; two concurrent
    /// starts with the same name could both pass the SELECT — confirm whether
    /// a unique index on (name, parent_id, status) guards this.
    pub async fn start_with_handler(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        handler: Option<&str>,
    ) -> Result<StartResult, DurableError> {
        // Idempotency probe: is a root task with this name already RUNNING?
        let existing_sql = format!(
            "SELECT id FROM durable.task \
             WHERE name = '{}' AND parent_id IS NULL AND status = 'RUNNING'{TS} \
             LIMIT 1",
            name
        );
        if let Some(row) = db
            .query_one(Statement::from_string(DbBackend::Postgres, existing_sql))
            .await?
        {
            let existing_id: Uuid = row
                .try_get_by_index(0)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            tracing::info!(
                workflow = name,
                id = %existing_id,
                "idempotent start: attaching to existing RUNNING task"
            );
            return Self::from_id(db, existing_id)
                .await
                .map(StartResult::Attached);
        }

        let task_id = Uuid::new_v4();
        // Serialize the input up front; absent input is stored as JSON null.
        let input_json = match &input {
            Some(v) => serde_json::to_string(v)?,
            None => "null".to_string(),
        };

        let executor_id = crate::executor_id();

        // Optional columns are appended textually so the INSERT only names
        // columns that actually have values.
        let mut extra_cols = String::new();
        let mut extra_vals = String::new();

        if let Some(eid) = &executor_id {
            extra_cols.push_str(", executor_id");
            extra_vals.push_str(&format!(", '{eid}'"));
        }
        if let Some(h) = handler {
            extra_cols.push_str(", handler");
            extra_vals.push_str(&format!(", '{h}'"));
        }

        let txn = db.begin().await?;
        let sql = format!(
            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{extra_cols}) \
             VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING'{TS}, '{input_json}', now(){extra_vals})"
        );
        txn.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        txn.commit().await?;

        Ok(StartResult::Created(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
            executor_id,
        }))
    }

    /// Builds a context for an existing task row.
    ///
    /// Verifies the row exists, then claims it for this executor (when an
    /// executor id is configured). The sequence counter restarts at 0, so
    /// deterministic step ordering is what makes replay line up.
    pub async fn from_id(db: &DatabaseConnection, task_id: Uuid) -> Result<Self, DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        // Existence check only; the model itself is not used further.
        let _model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        let executor_id = crate::executor_id();
        if let Some(eid) = &executor_id {
            db.execute(Statement::from_string(
                DbBackend::Postgres,
                format!("UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"),
            ))
            .await?;
        }

        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
            executor_id,
        })
    }

    /// Runs `f` as a checkpointed step.
    ///
    /// On replay (the step row is already COMPLETED) the saved output is
    /// deserialized and returned without running `f`. Otherwise the step row
    /// is locked (`FOR UPDATE SKIP LOCKED` inside `find_or_create_task`),
    /// marked RUNNING, and `f`'s result is persisted as COMPLETED or FAILED.
    ///
    /// # Errors
    /// Propagates `f`'s error after recording it, plus pause/cancel/timeout
    /// checks and any DB error.
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Abort early if the workflow was paused/cancelled or timed out.
        check_status(&self.db, self.task_id).await?;

        check_deadline(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // NOTE(review): retrying inside an open transaction — if the first
        // attempt errored, the txn may already be aborted and the retries
        // cannot succeed. Confirm whether this should retry at all here.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        txn.commit().await?;
        // The user function runs OUTSIDE the transaction so the RUNNING
        // marker is visible to other workers while `f` executes.
        let result = f().await;

        match result {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&self.db, step_id, json.clone())).await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&self.db, step_id, &err_msg)).await?;
                Err(e)
            }
        }
    }

    /// Runs `f` inside a single DB transaction together with the step's
    /// completion marker: either `f`'s writes AND the COMPLETED row commit
    /// atomically, or everything rolls back and the step is marked FAILED.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
            &'tx DatabaseTransaction,
        ) -> Pin<
            Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
        > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // Unlocked lookup (lock = false): replay returns the saved output.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the transaction rolls back `f`'s writes (and the
                // RUNNING marker); the failure is then recorded outside it.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }

    /// Creates (or re-attaches to) a child workflow under this task and
    /// returns a context for it. The child gets its own sequence counter.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // Saved output is ignored: a child context is returned either way.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
            executor_id: self.executor_id.clone(),
        })
    }

    /// Like [`Ctx::step`] but without the row lock, intended for steps that
    /// may run concurrently with other workers.
    ///
    /// NOTE(review): unlike `step`, a failure in `f` returns early via `?`
    /// WITHOUT marking the child row FAILED, leaving it RUNNING — confirm
    /// this is intentional.
    pub async fn concurrent_step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let (child_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            None,
            false,
            None,
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "concurrent_step: replaying saved output");
            return Ok(val);
        }

        retry_db_write(|| set_status(&self.db, child_id, TaskStatus::Running)).await?;

        let result = f().await?;

        let json = serde_json::to_value(&result)?;
        retry_db_write(|| complete_task(&self.db, child_id, json.clone())).await?;
        tracing::debug!(step = name, seq, "concurrent_step: completed");
        Ok(result)
    }

    /// True when this task's status is COMPLETED.
    pub async fn is_completed(&self) -> Result<bool, DurableError> {
        let status = get_status(&self.db, self.task_id).await?;
        Ok(status == Some(TaskStatus::Completed))
    }

    /// Returns this task's saved output deserialized as `T`, or `None` if
    /// the task has not completed (or completed without output).
    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
        match get_output(&self.db, self.task_id).await? {
            Some(val) => Ok(Some(serde_json::from_value(val)?)),
            None => Ok(None),
        }
    }

    /// Marks this task COMPLETED with the given output (retried on failure).
    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
        let json = serde_json::to_value(output)?;
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| complete_task(db, task_id, json.clone())).await
    }

    /// Runs `f` as a checkpointed step with application-level retries.
    ///
    /// The retry count is persisted on the step row, so a crashed-and-resumed
    /// workflow continues from the recorded attempt number rather than
    /// restarting at zero. Backoff = `initial_backoff * multiplier^(n-1)`,
    /// clamped so the factor never shrinks the delay.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Resume from whatever attempt number the row already records.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Re-check pause/cancel before every attempt.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the bump so a crash mid-retry is not lost.
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            // multiplier^(attempt-1), floored at 1.0 so a
                            // sub-1.0 multiplier never reduces the delay.
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }

    /// Marks this task FAILED with the given error message (retried).
    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| fail_task(db, task_id, error)).await
    }

    /// Marks an arbitrary task FAILED without constructing a context.
    pub async fn fail_by_id(
        db: &DatabaseConnection,
        task_id: Uuid,
        error: &str,
    ) -> Result<(), DurableError> {
        retry_db_write(|| fail_task(db, task_id, error)).await
    }

    /// Sets a timeout on this task. If the task is already RUNNING, the
    /// absolute deadline is recomputed from `started_at`; otherwise only
    /// `timeout_ms` is stored and the deadline is set when it starts running.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
                 WHEN status = 'RUNNING'{TS} AND started_at IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                 ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }

    /// Convenience: [`Ctx::start`] followed by [`Ctx::set_timeout`].
    /// Note the timeout is applied to the (possibly pre-existing) task,
    /// whether the start created or attached.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<StartResult, DurableError> {
        let result = Self::start(db, name, input).await?;
        result.ctx().set_timeout(timeout_ms).await?;
        Ok(result)
    }

    /// Pauses a PENDING/RUNNING task and all its PENDING/RUNNING descendants
    /// (recursive CTE). Steps notice via `check_status` on their next call.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Only live tasks can be paused.
        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED'{TS} \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('PENDING'{TS}, 'RUNNING'{TS})"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }

    /// Resumes a PAUSED workflow: the root goes back to RUNNING, while all
    /// PAUSED descendants return to PENDING so replay re-dispatches them.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING'{TS} \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'{TS}"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }

    /// Restarts a FAILED workflow for another recovery attempt.
    ///
    /// The root row is reset to RUNNING (recovery_count incremented, guarded
    /// by `max_recovery_attempts`); FAILED/RUNNING descendants are reset to
    /// PENDING with retry state cleared. COMPLETED descendants are left
    /// untouched, so replay skips them.
    ///
    /// Returns the stored handler name (if any) so the caller can re-invoke it.
    pub async fn resume_failed(
        db: &DatabaseConnection,
        task_id: Uuid,
    ) -> Result<Option<String>, DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Failed {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be FAILED)",
                model.status
            )));
        }

        if model.kind != "WORKFLOW" {
            return Err(DurableError::custom(format!(
                "cannot resume a {}, only workflows can be resumed",
                model.kind
            )));
        }

        if model.recovery_count >= model.max_recovery_attempts {
            return Err(DurableError::MaxRecoveryExceeded(task_id.to_string()));
        }

        let executor_id = crate::executor_id().unwrap_or_default();

        // The `AND status = 'FAILED'` guard makes the reset a no-op if
        // another worker already resumed this task.
        let reset_workflow_sql = format!(
            "UPDATE durable.task \
             SET status = 'RUNNING'{TS}, \
             error = NULL, \
             completed_at = NULL, \
             started_at = now(), \
             executor_id = '{executor_id}', \
             recovery_count = recovery_count + 1 \
             WHERE id = '{task_id}' AND status = 'FAILED'{TS}"
        );
        db.execute(Statement::from_string(
            DbBackend::Postgres,
            reset_workflow_sql,
        ))
        .await?;

        let reset_steps_sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task \
             SET status = 'PENDING'{TS}, \
             error = NULL, \
             retry_count = 0, \
             completed_at = NULL, \
             started_at = NULL \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('FAILED'{TS}, 'RUNNING'{TS})"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, reset_steps_sql))
            .await?;

        tracing::info!(
            %task_id,
            recovery_count = model.recovery_count + 1,
            "failed workflow resumed"
        );

        Ok(model.handler.clone())
    }

    /// Cancels a task and every not-yet-finished descendant (recursive CTE).
    /// Terminal tasks (COMPLETED/FAILED/CANCELLED) cannot be cancelled.
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED'{TS}, completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
             AND status NOT IN ('COMPLETED'{TS}, 'FAILED'{TS}, 'CANCELLED'{TS})"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }

    /// Lists tasks matching `query`, sorted and paged as requested.
    pub async fn list(
        db: &DatabaseConnection,
        query: TaskQuery,
    ) -> Result<Vec<TaskSummary>, DurableError> {
        let mut select = Task::find();

        if let Some(status) = &query.status {
            select = select.filter(TaskColumn::Status.eq(status.to_string()));
        }
        if let Some(kind) = &query.kind {
            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
        }
        if let Some(parent_id) = query.parent_id {
            select = select.filter(TaskColumn::ParentId.eq(parent_id));
        }
        if query.root_only {
            select = select.filter(TaskColumn::ParentId.is_null());
        }
        if let Some(name) = &query.name {
            select = select.filter(TaskColumn::Name.eq(name.as_str()));
        }
        if let Some(queue) = &query.queue_name {
            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
        }

        let (col, order) = match query.sort {
            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
            TaskSort::Name(ord) => (TaskColumn::Name, ord),
            TaskSort::Status(ord) => (TaskColumn::Status, ord),
        };
        select = select.order_by(col, order);

        if let Some(offset) = query.offset {
            select = select.offset(offset);
        }
        if let Some(limit) = query.limit {
            select = select.limit(limit);
        }

        let models = select.all(db).await?;

        Ok(models.into_iter().map(TaskSummary::from).collect())
    }

    /// Counts tasks matching `query` (filters only; sort/paging ignored).
    ///
    /// NOTE(review): the filter-building here duplicates `list` — a shared
    /// helper would keep the two from drifting.
    pub async fn count(db: &DatabaseConnection, query: TaskQuery) -> Result<u64, DurableError> {
        let mut select = Task::find();

        if let Some(status) = &query.status {
            select = select.filter(TaskColumn::Status.eq(status.to_string()));
        }
        if let Some(kind) = &query.kind {
            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
        }
        if let Some(parent_id) = query.parent_id {
            select = select.filter(TaskColumn::ParentId.eq(parent_id));
        }
        if query.root_only {
            select = select.filter(TaskColumn::ParentId.is_null());
        }
        if let Some(name) = &query.name {
            select = select.filter(TaskColumn::Name.eq(name.as_str()));
        }
        if let Some(queue) = &query.queue_name {
            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
        }

        let count = select.count(db).await?;
        Ok(count)
    }

    /// Borrow the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// The id of the task this context operates on.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Allocates and returns the next step sequence number.
    /// Note: this consumes a slot, which shifts replay alignment for all
    /// subsequent steps in this context.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }

    /// Fetches this task's `input` column and deserializes it as `T`.
    ///
    /// # Errors
    /// Fails when the row is missing, the input is SQL NULL, or
    /// deserialization into `T` fails.
    pub async fn input<T: DeserializeOwned>(&self) -> Result<T, DurableError> {
        let row = self
            .db
            .query_one(Statement::from_string(
                DbBackend::Postgres,
                format!(
                    "SELECT input FROM durable.task WHERE id = '{}'",
                    self.task_id
                ),
            ))
            .await?
            .ok_or_else(|| DurableError::custom(format!("task {} not found", self.task_id)))?;

        let input_json: Option<serde_json::Value> = row
            .try_get_by_index(0)
            .map_err(|e| DurableError::custom(e.to_string()))?;

        let value = input_json
            .ok_or_else(|| DurableError::custom(format!("task {} has no input", self.task_id)))?;

        serde_json::from_value(value)
            .map_err(|e| DurableError::custom(format!("failed to deserialize input: {e}")))
    }
}
1188
/// Finds the task row identified by (parent, sequence/name) or creates it,
/// returning `(id, saved_output)` where `saved_output` is `Some` only for an
/// already-COMPLETED row (the replay case).
///
/// Two strategies:
/// * `lock == true`  — INSERT .. ON CONFLICT DO NOTHING, then re-select the
///   row with `FOR UPDATE SKIP LOCKED`. A row held by another worker (or one
///   already RUNNING) surfaces as `DurableError::StepLocked`.
/// * `lock == false` — plain lookup by name + parent, excluding terminal
///   CANCELLED/FAILED rows, inserting via the ActiveModel when absent.
///
/// NOTE(review): `name`, `kind`, and the serialized `input` are interpolated
/// into SQL on the lock path; values containing `'` break the statement —
/// consider `Statement::from_sql_and_values`.
///
/// NOTE(review): the lock path's ON CONFLICT target is (parent_id, sequence);
/// it is only effective when a unique index covers those columns — confirm
/// the schema, especially for NULL parent_id.
#[allow(clippy::too_many_arguments)]
async fn find_or_create_task(
    db: &impl ConnectionTrait,
    parent_id: Option<Uuid>,
    sequence: Option<i32>,
    name: &str,
    kind: &str,
    input: Option<serde_json::Value>,
    lock: bool,
    max_retries: Option<u32>,
) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
    // SQL fragments for the two ways a nullable parent appears:
    // comparison (`= 'x'` / `IS NULL`) and literal value (`'x'` / `NULL`).
    let parent_eq = match parent_id {
        Some(p) => format!("= '{p}'"),
        None => "IS NULL".to_string(),
    };
    let parent_sql = match parent_id {
        Some(p) => format!("'{p}'"),
        None => "NULL".to_string(),
    };

    if lock {
        let new_id = Uuid::new_v4();
        let seq_sql = match sequence {
            Some(s) => s.to_string(),
            None => "NULL".to_string(),
        };
        let input_sql = match &input {
            Some(v) => format!("'{}'", serde_json::to_string(v)?),
            None => "NULL".to_string(),
        };

        // Default matches the non-lock path's ActiveModel default of 3.
        let max_retries_sql = match max_retries {
            Some(r) => r.to_string(),
            None => "3".to_string(),
        };

        // Insert is a no-op when the (parent_id, sequence) row already exists.
        let insert_sql = format!(
            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING'{TS}, {input_sql}, {max_retries_sql}) \
             ON CONFLICT (parent_id, sequence) DO NOTHING"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
            .await?;

        // Re-select and lock the row. SKIP LOCKED means a row currently held
        // by another transaction yields no row at all (handled below).
        let lock_sql = format!(
            "SELECT id, status::text, output FROM durable.task \
             WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
             FOR UPDATE SKIP LOCKED"
        );
        let row = db
            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
            .await?;

        if let Some(row) = row {
            let id: Uuid = row
                .try_get_by_index(0)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            let status: String = row
                .try_get_by_index(1)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();

            // COMPLETED -> replay the saved output.
            if status == TaskStatus::Completed.to_string() {
                return Ok((id, output));
            }
            // RUNNING -> another worker owns it right now.
            if status == TaskStatus::Running.to_string() {
                return Err(DurableError::StepLocked(name.to_string()));
            }
            return Ok((id, None));
        }

        // No row visible: it exists but is row-locked elsewhere (SKIP LOCKED).
        Err(DurableError::StepLocked(name.to_string()))
    } else {
        // Lock-free path: look the row up by name under the same parent,
        // ignoring terminal failures so they can be re-created/re-run.
        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
        query = match parent_id {
            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
            None => query.filter(TaskColumn::ParentId.is_null()),
        };
        let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
        let existing = query
            .filter(TaskColumn::Status.is_not_in(status_exclusions))
            .one(db)
            .await?;

        if let Some(model) = existing {
            if model.status == TaskStatus::Completed {
                return Ok((model.id, model.output));
            }
            return Ok((model.id, None));
        }

        let id = Uuid::new_v4();
        let new_task = TaskActiveModel {
            id: Set(id),
            parent_id: Set(parent_id),
            sequence: Set(sequence),
            name: Set(name.to_string()),
            kind: Set(kind.to_string()),
            status: Set(TaskStatus::Pending),
            input: Set(input),
            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
            ..Default::default()
        };
        new_task.insert(db).await?;

        Ok((id, None))
    }
}
1346
1347async fn get_output(
1348 db: &impl ConnectionTrait,
1349 task_id: Uuid,
1350) -> Result<Option<serde_json::Value>, DurableError> {
1351 let model = Task::find_by_id(task_id)
1352 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1353 .one(db)
1354 .await?;
1355
1356 Ok(model.and_then(|m| m.output))
1357}
1358
1359async fn get_status(
1360 db: &impl ConnectionTrait,
1361 task_id: Uuid,
1362) -> Result<Option<TaskStatus>, DurableError> {
1363 let model = Task::find_by_id(task_id).one(db).await?;
1364
1365 Ok(model.map(|m| m.status))
1366}
1367
1368async fn get_retry_info(
1370 db: &DatabaseConnection,
1371 task_id: Uuid,
1372) -> Result<(u32, u32), DurableError> {
1373 let model = Task::find_by_id(task_id).one(db).await?;
1374
1375 match model {
1376 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1377 None => Err(DurableError::custom(format!(
1378 "task {task_id} not found when reading retry info"
1379 ))),
1380 }
1381}
1382
1383async fn increment_retry_count(
1385 db: &DatabaseConnection,
1386 task_id: Uuid,
1387) -> Result<u32, DurableError> {
1388 let model = Task::find_by_id(task_id).one(db).await?;
1389
1390 match model {
1391 Some(m) => {
1392 let new_count = m.retry_count + 1;
1393 let mut active: TaskActiveModel = m.into();
1394 active.retry_count = Set(new_count);
1395 active.status = Set(TaskStatus::Pending);
1396 active.error = Set(None);
1397 active.completed_at = Set(None);
1398 active.update(db).await?;
1399 Ok(new_count as u32)
1400 }
1401 None => Err(DurableError::custom(format!(
1402 "task {task_id} not found when incrementing retry count"
1403 ))),
1404 }
1405}
1406
/// Sets a task's status, stamping `started_at` on first transition and
/// computing an absolute deadline when moving to RUNNING with a timeout set.
///
/// NOTE(review): `'{status}'` relies on `TaskStatus`'s `Display` producing
/// the exact uppercase DB enum label (as the rest of this module assumes) —
/// confirm against the entity's enum derivation.
///
/// NOTE(review): `started_at = COALESCE(started_at, now())` runs for EVERY
/// status transition, not just RUNNING — confirm that is intended.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}'{TS}, \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
             WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
             THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
             ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1429
1430async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1432 let status = get_status(db, task_id).await?;
1433 match status {
1434 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1435 Some(TaskStatus::Cancelled) => Err(DurableError::Cancelled(format!(
1436 "task {task_id} is cancelled"
1437 ))),
1438 _ => Ok(()),
1439 }
1440}
1441
1442async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1444 let model = Task::find_by_id(task_id).one(db).await?;
1445
1446 if let Some(m) = model
1447 && let Some(deadline_ms) = m.deadline_epoch_ms
1448 {
1449 let now_ms = std::time::SystemTime::now()
1450 .duration_since(std::time::UNIX_EPOCH)
1451 .map(|d| d.as_millis() as i64)
1452 .unwrap_or(0);
1453 if now_ms > deadline_ms {
1454 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1455 }
1456 }
1457
1458 Ok(())
1459}
1460
1461async fn complete_task(
1462 db: &impl ConnectionTrait,
1463 task_id: Uuid,
1464 output: serde_json::Value,
1465) -> Result<(), DurableError> {
1466 let model = Task::find_by_id(task_id).one(db).await?;
1467
1468 if let Some(m) = model {
1469 let mut active: TaskActiveModel = m.into();
1470 active.status = Set(TaskStatus::Completed);
1471 active.output = Set(Some(output));
1472 active.completed_at = Set(Some(chrono::Utc::now().into()));
1473 active.update(db).await?;
1474 }
1475 Ok(())
1476}
1477
1478async fn fail_task(
1479 db: &impl ConnectionTrait,
1480 task_id: Uuid,
1481 error: &str,
1482) -> Result<(), DurableError> {
1483 let model = Task::find_by_id(task_id).one(db).await?;
1484
1485 if let Some(m) = model {
1486 let mut active: TaskActiveModel = m.into();
1487 active.status = Set(TaskStatus::Failed);
1488 active.error = Set(Some(error.to_string()));
1489 active.completed_at = Set(Some(chrono::Utc::now().into()));
1490 active.update(db).await?;
1491 }
1492 Ok(())
1493}
1494
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A write that succeeds immediately runs exactly once — no retries.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures are absorbed; the third attempt succeeds.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                match counter.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    ))),
                    _ => Ok(()),
                }
            }
        })
        .await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A persistently failing write is attempted 1 + MAX_CHECKPOINT_RETRIES
    /// times in total before giving up.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1 + MAX_CHECKPOINT_RETRIES);
    }

    /// When every attempt fails, the surfaced error is the FIRST one seen,
    /// not the error from the final retry.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                let attempt = counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    attempt
                ))))
            }
        })
        .await;
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}
1591}