1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
/// Maximum number of retry attempts made by [`retry_db_write`] after the
/// initial checkpoint write fails.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
/// Base delay (milliseconds) for the exponential backoff between checkpoint
/// write retries; attempt `i` sleeps `base * 2^i`.
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32 F: FnMut() -> Fut,
33 Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35 match f().await {
36 Ok(()) => Ok(()),
37 Err(first_err) => {
38 for i in 0..MAX_CHECKPOINT_RETRIES {
39 tokio::time::sleep(Duration::from_millis(
40 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41 ))
42 .await;
43 if f().await.is_ok() {
44 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45 return Ok(());
46 }
47 }
48 Err(first_err)
49 }
50 }
51}
52
/// Retry configuration used by [`Ctx::step_with_retry`].
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: Duration,
    /// Factor applied to the backoff after each retry (1.0 = constant).
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries: the step fails on its first error.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff that doubles after every retry.
    pub fn exponential(max_retries: u32, initial_backoff: Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Constant backoff between retries.
    pub fn fixed(max_retries: u32, backoff: Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
88
/// Sort key (and direction) accepted by [`Ctx::list`]; each variant maps to
/// one column of `durable.task`.
pub enum TaskSort {
    CreatedAt(Order),
    StartedAt(Order),
    CompletedAt(Order),
    Name(Order),
    Status(Order),
}
97
/// Filter, sort, and pagination options for [`Ctx::list`] and [`Ctx::count`].
///
/// Build with `TaskQuery::default()` plus the chainable setters; every `Some`
/// field (or `root_only = true`) adds one SQL filter, `None` means "no filter".
pub struct TaskQuery {
    pub status: Option<TaskStatus>,
    pub kind: Option<String>,
    pub parent_id: Option<Uuid>,
    // When true, only tasks with no parent (workflow roots) match.
    pub root_only: bool,
    pub name: Option<String>,
    pub queue_name: Option<String>,
    // Sort key and direction; defaults to newest-first by creation time.
    pub sort: TaskSort,
    pub limit: Option<u64>,
    pub offset: Option<u64>,
}
120
impl Default for TaskQuery {
    /// No filters, newest-first by `created_at`, no pagination.
    fn default() -> Self {
        Self {
            status: None,
            kind: None,
            parent_id: None,
            root_only: false,
            name: None,
            queue_name: None,
            sort: TaskSort::CreatedAt(Order::Desc),
            limit: None,
            offset: None,
        }
    }
}
136
137impl TaskQuery {
138 pub fn status(mut self, status: TaskStatus) -> Self {
140 self.status = Some(status);
141 self
142 }
143
144 pub fn kind(mut self, kind: &str) -> Self {
146 self.kind = Some(kind.to_string());
147 self
148 }
149
150 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152 self.parent_id = Some(parent_id);
153 self
154 }
155
156 pub fn root_only(mut self, root_only: bool) -> Self {
158 self.root_only = root_only;
159 self
160 }
161
162 pub fn name(mut self, name: &str) -> Self {
164 self.name = Some(name.to_string());
165 self
166 }
167
168 pub fn queue_name(mut self, queue: &str) -> Self {
170 self.queue_name = Some(queue.to_string());
171 self
172 }
173
174 pub fn sort(mut self, sort: TaskSort) -> Self {
176 self.sort = sort;
177 self
178 }
179
180 pub fn limit(mut self, limit: u64) -> Self {
182 self.limit = Some(limit);
183 self
184 }
185
186 pub fn offset(mut self, offset: u64) -> Self {
188 self.offset = Some(offset);
189 self
190 }
191}
192
/// Flat, serializable projection of a `durable.task` row, suitable for APIs
/// and listings; produced via `From<durable_db::entity::task::Model>`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    pub id: Uuid,
    pub parent_id: Option<Uuid>,
    pub name: String,
    pub status: TaskStatus,
    pub kind: String,
    pub input: Option<serde_json::Value>,
    pub output: Option<serde_json::Value>,
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
209
210impl From<durable_db::entity::task::Model> for TaskSummary {
211 fn from(m: durable_db::entity::task::Model) -> Self {
212 Self {
213 id: m.id,
214 parent_id: m.parent_id,
215 name: m.name,
216 status: m.status,
217 kind: m.kind,
218 input: m.input,
219 output: m.output,
220 error: m.error,
221 queue_name: m.queue_name,
222 created_at: m.created_at,
223 started_at: m.started_at,
224 completed_at: m.completed_at,
225 }
226 }
227}
228
/// Handle to a durable workflow task: wraps a database connection, the task's
/// id, and the atomic step-sequence counter used to checkpoint steps in order.
pub struct Ctx {
    // Cloned connection handle used for all checkpoint reads/writes.
    db: DatabaseConnection,
    // Id of the task this context drives.
    task_id: Uuid,
    // Next step sequence number; incremented atomically per step.
    sequence: AtomicI32,
    // Optional id of the executor that started this task.
    executor_id: Option<String>,
}
240
241impl Ctx {
    /// Starts a new root workflow task in RUNNING state and returns a context
    /// bound to it.
    ///
    /// Convenience wrapper over [`Ctx::start_with_executor`] with no
    /// executor id.
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        Self::start_with_executor(db, name, input, None).await
    }
258
259 pub async fn start_with_executor(
265 db: &DatabaseConnection,
266 name: &str,
267 input: Option<serde_json::Value>,
268 executor_id: Option<String>,
269 ) -> Result<Self, DurableError> {
270 let task_id = Uuid::new_v4();
271 let input_json = match &input {
272 Some(v) => serde_json::to_string(v)?,
273 None => "null".to_string(),
274 };
275
276 let executor_col = if executor_id.is_some() { ", executor_id" } else { "" };
277 let executor_val = match &executor_id {
278 Some(eid) => format!(", '{eid}'"),
279 None => String::new(),
280 };
281
282 let txn = db.begin().await?;
283 let sql = format!(
284 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{executor_col}) \
285 VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now(){executor_val})"
286 );
287 txn.execute(Statement::from_string(DbBackend::Postgres, sql))
288 .await?;
289 txn.commit().await?;
290
291 Ok(Self {
292 db: db.clone(),
293 task_id,
294 sequence: AtomicI32::new(0),
295 executor_id,
296 })
297 }
298
299 pub async fn from_id(
306 db: &DatabaseConnection,
307 task_id: Uuid,
308 ) -> Result<Self, DurableError> {
309 let model = Task::find_by_id(task_id).one(db).await?;
311 let model =
312 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
313
314 Ok(Self {
319 db: db.clone(),
320 task_id,
321 sequence: AtomicI32::new(0),
322 executor_id: model.executor_id,
323 })
324 }
325
    /// Executes a durable step: runs `f` at most once and persists its output,
    /// so a replayed workflow returns the saved result instead of re-running
    /// the closure.
    ///
    /// Steps are identified by `(parent task, sequence)`; the sequence number
    /// comes from this context's atomic counter, so step order must be
    /// deterministic across replays.
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim the next deterministic sequence slot for this step.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Abort early if the workflow was paused/cancelled ...
        check_status(&self.db, self.task_id).await?;

        // ... or its deadline has passed.
        check_deadline(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;

        // Locked lookup (`lock = true`): inserts the step row if missing and
        // takes a row lock; returns the saved output when already COMPLETED.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // Replay path: a previous run already completed this step.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // NOTE(review): these writes are retried via `retry_db_write` while
        // still inside `txn`; if a statement errors, Postgres typically aborts
        // the whole transaction and the retries would keep failing — confirm
        // the intended recovery semantics here.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        match f().await {
            Ok(val) => {
                // Persist the output and commit both step row and status.
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Record the failure on the step row, then surface the error.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }
393
    /// Durable step whose user code runs *inside* a database transaction, so
    /// business writes and the step's COMPLETED checkpoint commit atomically.
    ///
    /// On failure the transaction is dropped (rolled back) and the error is
    /// recorded on the step row via the main connection.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
            &'tx DatabaseTransaction,
        ) -> Pin<
            Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
        > + Send,
    {
        // Claim the next deterministic sequence slot for this step.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // Unlocked bookkeeping lookup on the main connection (not in `tx`).
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // Replay path: this transaction step already committed previously.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        // The RUNNING marker is written inside `tx`, so it is rolled back
        // together with the user's writes if the closure fails.
        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                // Checkpoint output and user writes commit atomically.
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the transaction rolls it back; then persist the
                // failure on the step row outside the transaction.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
462
    /// Creates (or, on replay, re-attaches to) a child workflow task under
    /// this context, marks it RUNNING, and returns a context rooted at it.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        // The child occupies the next sequence slot of this workflow.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // Unlocked lookup: reuses a prior non-FAILED/CANCELLED child row.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        // Child inherits the executor id; its own step counter starts at 0.
        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
            executor_id: self.executor_id.clone(),
        })
    }
506
507 pub async fn is_completed(&self) -> Result<bool, DurableError> {
509 let status = get_status(&self.db, self.task_id).await?;
510 Ok(status == Some(TaskStatus::Completed))
511 }
512
513 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
515 match get_output(&self.db, self.task_id).await? {
516 Some(val) => Ok(Some(serde_json::from_value(val)?)),
517 None => Ok(None),
518 }
519 }
520
521 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
523 let json = serde_json::to_value(output)?;
524 let db = &self.db;
525 let task_id = self.task_id;
526 retry_db_write(|| complete_task(db, task_id, json.clone())).await
527 }
528
    /// Durable step with automatic retries governed by `policy`.
    ///
    /// The step row persists `retry_count`/`max_retries`, so retries survive
    /// process restarts: a resumed workflow continues from the stored count.
    /// Backoff is `initial_backoff * multiplier^(retry-1)`, slept locally
    /// between attempts.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim the next deterministic sequence slot for this step.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // Unlocked lookup; the row stores the policy's max_retries when it is
        // first created.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // Replay path: completed in an earlier run.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Resume from the persisted counter; the row's max_retries may come
        // from an earlier run's policy.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Re-check pause/cancel between attempts.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the bumped counter (also resets the row to
                        // PENDING and clears the previous error).
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // multiplier^(retry-1), clamped at 1.0 so the backoff
                        // never shrinks below the configured initial value.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Out of retries: record the terminal failure and bail.
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
638
639 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
641 let db = &self.db;
642 let task_id = self.task_id;
643 retry_db_write(|| fail_task(db, task_id, error)).await
644 }
645
    /// Sets a timeout (in milliseconds) on this task.
    ///
    /// If the task is already RUNNING with a `started_at`, the absolute
    /// `deadline_epoch_ms` is backfilled from that start time; otherwise the
    /// existing deadline is left untouched (it is computed when the task next
    /// enters RUNNING — see `set_status`).
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // `timeout_ms` is an i64 and the id a Uuid, so direct interpolation
        // cannot break the SQL literal.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
             WHEN status = 'RUNNING' AND started_at IS NOT NULL \
             THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
             ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
670
    /// Starts a root workflow task and immediately applies a timeout.
    ///
    /// The task is created RUNNING first; `set_timeout` then backfills the
    /// deadline from the recorded `started_at`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
684
    /// Starts a root workflow task with both an executor id and a timeout.
    ///
    /// Combination of [`Ctx::start_with_executor`] and [`Ctx::set_timeout`].
    pub async fn start_with_timeout_and_executor(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
        executor_id: Option<String>,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start_with_executor(db, name, input, executor_id).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
697
698 pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
705 let model = Task::find_by_id(task_id).one(db).await?;
706 let model =
707 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
708
709 match model.status {
710 TaskStatus::Pending | TaskStatus::Running => {}
711 status => {
712 return Err(DurableError::custom(format!(
713 "cannot pause task in {status} status"
714 )));
715 }
716 }
717
718 let sql = format!(
720 "WITH RECURSIVE descendants AS ( \
721 SELECT id FROM durable.task WHERE id = '{task_id}' \
722 UNION ALL \
723 SELECT t.id FROM durable.task t \
724 INNER JOIN descendants d ON t.parent_id = d.id \
725 ) \
726 UPDATE durable.task SET status = 'PAUSED' \
727 WHERE id IN (SELECT id FROM descendants) \
728 AND status IN ('PENDING', 'RUNNING')"
729 );
730 db.execute(Statement::from_string(DbBackend::Postgres, sql))
731 .await?;
732
733 tracing::info!(%task_id, "workflow paused");
734 Ok(())
735 }
736
    /// Resumes a PAUSED workflow: the root task returns to RUNNING and all
    /// PAUSED descendants are reset to PENDING so they can be re-driven.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Only PAUSED tasks may be resumed.
        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Flip the root back to RUNNING first ...
        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // ... then cascade to descendants. Note the recursion starts at the
        // children (parent_id = task_id), not the root row itself.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
774
775 pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
780 let model = Task::find_by_id(task_id).one(db).await?;
781 let model =
782 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
783
784 match model.status {
785 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
786 return Err(DurableError::custom(format!(
787 "cannot cancel task in {} status",
788 model.status
789 )));
790 }
791 _ => {}
792 }
793
794 let sql = format!(
796 "WITH RECURSIVE descendants AS ( \
797 SELECT id FROM durable.task WHERE id = '{task_id}' \
798 UNION ALL \
799 SELECT t.id FROM durable.task t \
800 INNER JOIN descendants d ON t.parent_id = d.id \
801 ) \
802 UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
803 WHERE id IN (SELECT id FROM descendants) \
804 AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
805 );
806 db.execute(Statement::from_string(DbBackend::Postgres, sql))
807 .await?;
808
809 tracing::info!(%task_id, "workflow cancelled");
810 Ok(())
811 }
812
813 pub async fn list(
821 db: &DatabaseConnection,
822 query: TaskQuery,
823 ) -> Result<Vec<TaskSummary>, DurableError> {
824 let mut select = Task::find();
825
826 if let Some(status) = &query.status {
828 select = select.filter(TaskColumn::Status.eq(status.to_string()));
829 }
830 if let Some(kind) = &query.kind {
831 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
832 }
833 if let Some(parent_id) = query.parent_id {
834 select = select.filter(TaskColumn::ParentId.eq(parent_id));
835 }
836 if query.root_only {
837 select = select.filter(TaskColumn::ParentId.is_null());
838 }
839 if let Some(name) = &query.name {
840 select = select.filter(TaskColumn::Name.eq(name.as_str()));
841 }
842 if let Some(queue) = &query.queue_name {
843 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
844 }
845
846 let (col, order) = match query.sort {
848 TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
849 TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
850 TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
851 TaskSort::Name(ord) => (TaskColumn::Name, ord),
852 TaskSort::Status(ord) => (TaskColumn::Status, ord),
853 };
854 select = select.order_by(col, order);
855
856 if let Some(offset) = query.offset {
858 select = select.offset(offset);
859 }
860 if let Some(limit) = query.limit {
861 select = select.limit(limit);
862 }
863
864 let models = select.all(db).await?;
865
866 Ok(models.into_iter().map(TaskSummary::from).collect())
867 }
868
869 pub async fn count(
871 db: &DatabaseConnection,
872 query: TaskQuery,
873 ) -> Result<u64, DurableError> {
874 let mut select = Task::find();
875
876 if let Some(status) = &query.status {
877 select = select.filter(TaskColumn::Status.eq(status.to_string()));
878 }
879 if let Some(kind) = &query.kind {
880 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
881 }
882 if let Some(parent_id) = query.parent_id {
883 select = select.filter(TaskColumn::ParentId.eq(parent_id));
884 }
885 if query.root_only {
886 select = select.filter(TaskColumn::ParentId.is_null());
887 }
888 if let Some(name) = &query.name {
889 select = select.filter(TaskColumn::Name.eq(name.as_str()));
890 }
891 if let Some(queue) = &query.queue_name {
892 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
893 }
894
895 let count = select.count(db).await?;
896 Ok(count)
897 }
898
    /// Returns the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// Returns the id of the task this context drives.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Reserves and returns the next step sequence number.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
912}
913
914#[allow(clippy::too_many_arguments)]
935async fn find_or_create_task(
936 db: &impl ConnectionTrait,
937 parent_id: Option<Uuid>,
938 sequence: Option<i32>,
939 name: &str,
940 kind: &str,
941 input: Option<serde_json::Value>,
942 lock: bool,
943 max_retries: Option<u32>,
944) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
945 let parent_eq = match parent_id {
946 Some(p) => format!("= '{p}'"),
947 None => "IS NULL".to_string(),
948 };
949 let parent_sql = match parent_id {
950 Some(p) => format!("'{p}'"),
951 None => "NULL".to_string(),
952 };
953
954 if lock {
955 let new_id = Uuid::new_v4();
969 let seq_sql = match sequence {
970 Some(s) => s.to_string(),
971 None => "NULL".to_string(),
972 };
973 let input_sql = match &input {
974 Some(v) => format!("'{}'", serde_json::to_string(v)?),
975 None => "NULL".to_string(),
976 };
977
978 let max_retries_sql = match max_retries {
979 Some(r) => r.to_string(),
980 None => "3".to_string(), };
982
983 let insert_sql = format!(
985 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
986 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
987 ON CONFLICT (parent_id, sequence) DO NOTHING"
988 );
989 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
990 .await?;
991
992 let lock_sql = format!(
994 "SELECT id, status::text, output FROM durable.task \
995 WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
996 FOR UPDATE SKIP LOCKED"
997 );
998 let row = db
999 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
1000 .await?;
1001
1002 if let Some(row) = row {
1003 let id: Uuid = row
1004 .try_get_by_index(0)
1005 .map_err(|e| DurableError::custom(e.to_string()))?;
1006 let status: String = row
1007 .try_get_by_index(1)
1008 .map_err(|e| DurableError::custom(e.to_string()))?;
1009 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
1010
1011 if status == TaskStatus::Completed.to_string() {
1012 return Ok((id, output));
1014 }
1015 return Ok((id, None));
1017 }
1018
1019 Err(DurableError::StepLocked(name.to_string()))
1021 } else {
1022 let mut query = Task::find().filter(TaskColumn::Name.eq(name));
1026 query = match parent_id {
1027 Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
1028 None => query.filter(TaskColumn::ParentId.is_null()),
1029 };
1030 let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
1033 let existing = query
1034 .filter(TaskColumn::Status.is_not_in(status_exclusions))
1035 .one(db)
1036 .await?;
1037
1038 if let Some(model) = existing {
1039 if model.status == TaskStatus::Completed {
1040 return Ok((model.id, model.output));
1041 }
1042 return Ok((model.id, None));
1043 }
1044
1045 let id = Uuid::new_v4();
1047 let new_task = TaskActiveModel {
1048 id: Set(id),
1049 parent_id: Set(parent_id),
1050 sequence: Set(sequence),
1051 name: Set(name.to_string()),
1052 kind: Set(kind.to_string()),
1053 status: Set(TaskStatus::Pending),
1054 input: Set(input),
1055 max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1056 ..Default::default()
1057 };
1058 new_task.insert(db).await?;
1059
1060 Ok((id, None))
1061 }
1062}
1063
1064async fn get_output(
1065 db: &impl ConnectionTrait,
1066 task_id: Uuid,
1067) -> Result<Option<serde_json::Value>, DurableError> {
1068 let model = Task::find_by_id(task_id)
1069 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1070 .one(db)
1071 .await?;
1072
1073 Ok(model.and_then(|m| m.output))
1074}
1075
1076async fn get_status(
1077 db: &impl ConnectionTrait,
1078 task_id: Uuid,
1079) -> Result<Option<TaskStatus>, DurableError> {
1080 let model = Task::find_by_id(task_id).one(db).await?;
1081
1082 Ok(model.map(|m| m.status))
1083}
1084
1085async fn get_retry_info(
1087 db: &DatabaseConnection,
1088 task_id: Uuid,
1089) -> Result<(u32, u32), DurableError> {
1090 let model = Task::find_by_id(task_id).one(db).await?;
1091
1092 match model {
1093 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1094 None => Err(DurableError::custom(format!(
1095 "task {task_id} not found when reading retry info"
1096 ))),
1097 }
1098}
1099
1100async fn increment_retry_count(
1102 db: &DatabaseConnection,
1103 task_id: Uuid,
1104) -> Result<u32, DurableError> {
1105 let model = Task::find_by_id(task_id).one(db).await?;
1106
1107 match model {
1108 Some(m) => {
1109 let new_count = m.retry_count + 1;
1110 let mut active: TaskActiveModel = m.into();
1111 active.retry_count = Set(new_count);
1112 active.status = Set(TaskStatus::Pending);
1113 active.error = Set(None);
1114 active.completed_at = Set(None);
1115 active.update(db).await?;
1116 Ok(new_count as u32)
1117 }
1118 None => Err(DurableError::custom(format!(
1119 "task {task_id} not found when incrementing retry count"
1120 ))),
1121 }
1122}
1123
/// Marks `task_id` with `status` via raw SQL, stamping `started_at` on the
/// first transition and computing `deadline_epoch_ms` when entering RUNNING
/// with a configured `timeout_ms`.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    // `status` is interpolated via the enum's Display impl (not user input),
    // and `task_id` is a Uuid, so the literals are safe to embed.
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
         WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
         THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
         ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1146
1147async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1149 let status = get_status(db, task_id).await?;
1150 match status {
1151 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1152 Some(TaskStatus::Cancelled) => {
1153 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1154 }
1155 _ => Ok(()),
1156 }
1157}
1158
1159async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1161 let model = Task::find_by_id(task_id).one(db).await?;
1162
1163 if let Some(m) = model
1164 && let Some(deadline_ms) = m.deadline_epoch_ms
1165 {
1166 let now_ms = std::time::SystemTime::now()
1167 .duration_since(std::time::UNIX_EPOCH)
1168 .map(|d| d.as_millis() as i64)
1169 .unwrap_or(0);
1170 if now_ms > deadline_ms {
1171 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1172 }
1173 }
1174
1175 Ok(())
1176}
1177
1178async fn complete_task(
1179 db: &impl ConnectionTrait,
1180 task_id: Uuid,
1181 output: serde_json::Value,
1182) -> Result<(), DurableError> {
1183 let model = Task::find_by_id(task_id).one(db).await?;
1184
1185 if let Some(m) = model {
1186 let mut active: TaskActiveModel = m.into();
1187 active.status = Set(TaskStatus::Completed);
1188 active.output = Set(Some(output));
1189 active.completed_at = Set(Some(chrono::Utc::now().into()));
1190 active.update(db).await?;
1191 }
1192 Ok(())
1193}
1194
1195async fn fail_task(
1196 db: &impl ConnectionTrait,
1197 task_id: Uuid,
1198 error: &str,
1199) -> Result<(), DurableError> {
1200 let model = Task::find_by_id(task_id).one(db).await?;
1201
1202 if let Some(m) = model {
1203 let mut active: TaskActiveModel = m.into();
1204 active.status = Set(TaskStatus::Failed);
1205 active.error = Set(Some(error.to_string()));
1206 active.completed_at = Set(Some(chrono::Utc::now().into()));
1207 active.update(db).await?;
1208 }
1209 Ok(())
1210}
1211
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // Happy path: the closure succeeds immediately, so it runs exactly once.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Transient failure: fails twice, succeeds on the second retry
    // (three calls total), and overall result is Ok.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // Permanent failure: the closure runs once plus MAX_CHECKPOINT_RETRIES
    // retries, then the error is surfaced.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    // When every attempt fails, the error from the FIRST attempt (error-0)
    // is returned, not the error from the last retry.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}
1308}