1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32 F: FnMut() -> Fut,
33 Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35 match f().await {
36 Ok(()) => Ok(()),
37 Err(first_err) => {
38 for i in 0..MAX_CHECKPOINT_RETRIES {
39 tokio::time::sleep(Duration::from_millis(
40 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41 ))
42 .await;
43 if f().await.is_ok() {
44 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45 return Ok(());
46 }
47 }
48 Err(first_err)
49 }
50 }
51}
52
/// Retry configuration for `Ctx::step_with_retry`.
///
/// A step is attempted once and then retried up to `max_retries` times.
/// Before retry `n` (1-based) the caller waits
/// `initial_backoff * max(backoff_multiplier^(n - 1), 1.0)`, and no wait at all
/// when `initial_backoff` is zero.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of retries after the first failed attempt.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor the delay is multiplied by for each subsequent retry.
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// A policy that never retries: the first failure is final.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: std::time::Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Retries up to `max_retries` times, doubling the delay after each retry,
    /// starting from `initial_backoff`.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Retries up to `max_retries` times, waiting the same `backoff` before each retry.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
88
89pub enum TaskSort {
91 CreatedAt(Order),
92 StartedAt(Order),
93 CompletedAt(Order),
94 Name(Order),
95 Status(Order),
96}
97
98pub struct TaskQuery {
110 pub status: Option<TaskStatus>,
111 pub kind: Option<String>,
112 pub parent_id: Option<Uuid>,
113 pub root_only: bool,
114 pub name: Option<String>,
115 pub queue_name: Option<String>,
116 pub sort: TaskSort,
117 pub limit: Option<u64>,
118 pub offset: Option<u64>,
119}
120
121impl Default for TaskQuery {
122 fn default() -> Self {
123 Self {
124 status: None,
125 kind: None,
126 parent_id: None,
127 root_only: false,
128 name: None,
129 queue_name: None,
130 sort: TaskSort::CreatedAt(Order::Desc),
131 limit: None,
132 offset: None,
133 }
134 }
135}
136
137impl TaskQuery {
138 pub fn status(mut self, status: TaskStatus) -> Self {
140 self.status = Some(status);
141 self
142 }
143
144 pub fn kind(mut self, kind: &str) -> Self {
146 self.kind = Some(kind.to_string());
147 self
148 }
149
150 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152 self.parent_id = Some(parent_id);
153 self
154 }
155
156 pub fn root_only(mut self, root_only: bool) -> Self {
158 self.root_only = root_only;
159 self
160 }
161
162 pub fn name(mut self, name: &str) -> Self {
164 self.name = Some(name.to_string());
165 self
166 }
167
168 pub fn queue_name(mut self, queue: &str) -> Self {
170 self.queue_name = Some(queue.to_string());
171 self
172 }
173
174 pub fn sort(mut self, sort: TaskSort) -> Self {
176 self.sort = sort;
177 self
178 }
179
180 pub fn limit(mut self, limit: u64) -> Self {
182 self.limit = Some(limit);
183 self
184 }
185
186 pub fn offset(mut self, offset: u64) -> Self {
188 self.offset = Some(offset);
189 self
190 }
191}
192
193#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
195pub struct TaskSummary {
196 pub id: Uuid,
197 pub parent_id: Option<Uuid>,
198 pub name: String,
199 pub status: TaskStatus,
200 pub kind: String,
201 pub input: Option<serde_json::Value>,
202 pub output: Option<serde_json::Value>,
203 pub error: Option<String>,
204 pub queue_name: Option<String>,
205 pub created_at: chrono::DateTime<chrono::FixedOffset>,
206 pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
207 pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
208}
209
210impl From<durable_db::entity::task::Model> for TaskSummary {
211 fn from(m: durable_db::entity::task::Model) -> Self {
212 Self {
213 id: m.id,
214 parent_id: m.parent_id,
215 name: m.name,
216 status: m.status,
217 kind: m.kind,
218 input: m.input,
219 output: m.output,
220 error: m.error,
221 queue_name: m.queue_name,
222 created_at: m.created_at,
223 started_at: m.started_at,
224 completed_at: m.completed_at,
225 }
226 }
227}
228
229pub struct Ctx {
235 db: DatabaseConnection,
236 task_id: Uuid,
237 sequence: AtomicI32,
238 executor_id: Option<String>,
239}
240
241impl Ctx {
242 pub async fn start(
255 db: &DatabaseConnection,
256 name: &str,
257 input: Option<serde_json::Value>,
258 ) -> Result<Self, DurableError> {
259 let task_id = Uuid::new_v4();
260 let input_json = match &input {
261 Some(v) => serde_json::to_string(v)?,
262 None => "null".to_string(),
263 };
264
265 let executor_id = crate::executor_id();
267
268 let executor_col = if executor_id.is_some() { ", executor_id" } else { "" };
269 let executor_val = match &executor_id {
270 Some(eid) => format!(", '{eid}'"),
271 None => String::new(),
272 };
273
274 let txn = db.begin().await?;
275 let sql = format!(
276 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{executor_col}) \
277 VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now(){executor_val})"
278 );
279 txn.execute(Statement::from_string(DbBackend::Postgres, sql))
280 .await?;
281 txn.commit().await?;
282
283 Ok(Self {
284 db: db.clone(),
285 task_id,
286 sequence: AtomicI32::new(0),
287 executor_id,
288 })
289 }
290
291 pub async fn from_id(
298 db: &DatabaseConnection,
299 task_id: Uuid,
300 ) -> Result<Self, DurableError> {
301 let model = Task::find_by_id(task_id).one(db).await?;
303 let _model =
304 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
305
306 let executor_id = crate::executor_id();
308 if let Some(eid) = &executor_id {
309 db.execute(Statement::from_string(
310 DbBackend::Postgres,
311 format!(
312 "UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"
313 ),
314 ))
315 .await?;
316 }
317
318 Ok(Self {
323 db: db.clone(),
324 task_id,
325 sequence: AtomicI32::new(0),
326 executor_id,
327 })
328 }
329
330 pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
339 where
340 T: Serialize + DeserializeOwned,
341 F: FnOnce() -> Fut,
342 Fut: std::future::Future<Output = Result<T, DurableError>>,
343 {
344 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
345
346 check_status(&self.db, self.task_id).await?;
348
349 check_deadline(&self.db, self.task_id).await?;
351
352 let txn = self.db.begin().await?;
354
355 let (step_id, saved_output) = find_or_create_task(
360 &txn,
361 Some(self.task_id),
362 Some(seq),
363 name,
364 "STEP",
365 None,
366 true,
367 Some(0),
368 )
369 .await?;
370
371 if let Some(output) = saved_output {
373 txn.commit().await?;
374 let val: T = serde_json::from_value(output)?;
375 tracing::debug!(step = name, seq, "replaying saved output");
376 return Ok(val);
377 }
378
379 retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
381 match f().await {
382 Ok(val) => {
383 let json = serde_json::to_value(&val)?;
384 retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
385 txn.commit().await?;
386 tracing::debug!(step = name, seq, "step completed");
387 Ok(val)
388 }
389 Err(e) => {
390 let err_msg = e.to_string();
391 retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
392 txn.commit().await?;
393 Err(e)
394 }
395 }
396 }
397
398 pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
410 where
411 T: Serialize + DeserializeOwned + Send,
412 F: for<'tx> FnOnce(
413 &'tx DatabaseTransaction,
414 ) -> Pin<
415 Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
416 > + Send,
417 {
418 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
419
420 check_status(&self.db, self.task_id).await?;
422
423 let (step_id, saved_output) = find_or_create_task(
426 &self.db,
427 Some(self.task_id),
428 Some(seq),
429 name,
430 "TRANSACTION",
431 None,
432 false,
433 None,
434 )
435 .await?;
436
437 if let Some(output) = saved_output {
439 let val: T = serde_json::from_value(output)?;
440 tracing::debug!(step = name, seq, "replaying saved transaction output");
441 return Ok(val);
442 }
443
444 let tx = self.db.begin().await?;
446
447 set_status(&tx, step_id, TaskStatus::Running).await?;
448
449 match f(&tx).await {
450 Ok(val) => {
451 let json = serde_json::to_value(&val)?;
452 complete_task(&tx, step_id, json).await?;
453 tx.commit().await?;
454 tracing::debug!(step = name, seq, "transaction step committed");
455 Ok(val)
456 }
457 Err(e) => {
458 drop(tx);
461 fail_task(&self.db, step_id, &e.to_string()).await?;
462 Err(e)
463 }
464 }
465 }
466
467 pub async fn child(
475 &self,
476 name: &str,
477 input: Option<serde_json::Value>,
478 ) -> Result<Self, DurableError> {
479 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
480
481 check_status(&self.db, self.task_id).await?;
483
484 let txn = self.db.begin().await?;
485 let (child_id, _saved) = find_or_create_task(
487 &txn,
488 Some(self.task_id),
489 Some(seq),
490 name,
491 "WORKFLOW",
492 input,
493 false,
494 None,
495 )
496 .await?;
497
498 retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
501 txn.commit().await?;
502
503 Ok(Self {
504 db: self.db.clone(),
505 task_id: child_id,
506 sequence: AtomicI32::new(0),
507 executor_id: self.executor_id.clone(),
508 })
509 }
510
511 pub async fn is_completed(&self) -> Result<bool, DurableError> {
513 let status = get_status(&self.db, self.task_id).await?;
514 Ok(status == Some(TaskStatus::Completed))
515 }
516
517 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
519 match get_output(&self.db, self.task_id).await? {
520 Some(val) => Ok(Some(serde_json::from_value(val)?)),
521 None => Ok(None),
522 }
523 }
524
525 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
527 let json = serde_json::to_value(output)?;
528 let db = &self.db;
529 let task_id = self.task_id;
530 retry_db_write(|| complete_task(db, task_id, json.clone())).await
531 }
532
533 pub async fn step_with_retry<T, F, Fut>(
547 &self,
548 name: &str,
549 policy: RetryPolicy,
550 f: F,
551 ) -> Result<T, DurableError>
552 where
553 T: Serialize + DeserializeOwned,
554 F: Fn() -> Fut,
555 Fut: std::future::Future<Output = Result<T, DurableError>>,
556 {
557 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
558
559 check_status(&self.db, self.task_id).await?;
561
562 let (step_id, saved_output) = find_or_create_task(
566 &self.db,
567 Some(self.task_id),
568 Some(seq),
569 name,
570 "STEP",
571 None,
572 false,
573 Some(policy.max_retries),
574 )
575 .await?;
576
577 if let Some(output) = saved_output {
579 let val: T = serde_json::from_value(output)?;
580 tracing::debug!(step = name, seq, "replaying saved output");
581 return Ok(val);
582 }
583
584 let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;
586
587 loop {
589 check_status(&self.db, self.task_id).await?;
591 set_status(&self.db, step_id, TaskStatus::Running).await?;
592 match f().await {
593 Ok(val) => {
594 let json = serde_json::to_value(&val)?;
595 complete_task(&self.db, step_id, json).await?;
596 tracing::debug!(step = name, seq, retry_count, "step completed");
597 return Ok(val);
598 }
599 Err(e) => {
600 if retry_count < max_retries {
601 retry_count = increment_retry_count(&self.db, step_id).await?;
603 tracing::debug!(
604 step = name,
605 seq,
606 retry_count,
607 max_retries,
608 "step failed, retrying"
609 );
610
611 let backoff = if policy.initial_backoff.is_zero() {
613 std::time::Duration::ZERO
614 } else {
615 let factor = policy
616 .backoff_multiplier
617 .powi((retry_count - 1) as i32)
618 .max(1.0);
619 let millis =
620 (policy.initial_backoff.as_millis() as f64 * factor) as u64;
621 std::time::Duration::from_millis(millis)
622 };
623
624 if !backoff.is_zero() {
625 tokio::time::sleep(backoff).await;
626 }
627 } else {
628 fail_task(&self.db, step_id, &e.to_string()).await?;
630 tracing::debug!(
631 step = name,
632 seq,
633 retry_count,
634 "step exhausted retries, marked FAILED"
635 );
636 return Err(e);
637 }
638 }
639 }
640 }
641 }
642
643 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
645 let db = &self.db;
646 let task_id = self.task_id;
647 retry_db_write(|| fail_task(db, task_id, error)).await
648 }
649
650 pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
658 let sql = format!(
659 "UPDATE durable.task \
660 SET timeout_ms = {timeout_ms}, \
661 deadline_epoch_ms = CASE \
662 WHEN status = 'RUNNING' AND started_at IS NOT NULL \
663 THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
664 ELSE deadline_epoch_ms \
665 END \
666 WHERE id = '{}'",
667 self.task_id
668 );
669 self.db
670 .execute(Statement::from_string(DbBackend::Postgres, sql))
671 .await?;
672 Ok(())
673 }
674
675 pub async fn start_with_timeout(
679 db: &DatabaseConnection,
680 name: &str,
681 input: Option<serde_json::Value>,
682 timeout_ms: i64,
683 ) -> Result<Self, DurableError> {
684 let ctx = Self::start(db, name, input).await?;
685 ctx.set_timeout(timeout_ms).await?;
686 Ok(ctx)
687 }
688
689 pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
696 let model = Task::find_by_id(task_id).one(db).await?;
697 let model =
698 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
699
700 match model.status {
701 TaskStatus::Pending | TaskStatus::Running => {}
702 status => {
703 return Err(DurableError::custom(format!(
704 "cannot pause task in {status} status"
705 )));
706 }
707 }
708
709 let sql = format!(
711 "WITH RECURSIVE descendants AS ( \
712 SELECT id FROM durable.task WHERE id = '{task_id}' \
713 UNION ALL \
714 SELECT t.id FROM durable.task t \
715 INNER JOIN descendants d ON t.parent_id = d.id \
716 ) \
717 UPDATE durable.task SET status = 'PAUSED' \
718 WHERE id IN (SELECT id FROM descendants) \
719 AND status IN ('PENDING', 'RUNNING')"
720 );
721 db.execute(Statement::from_string(DbBackend::Postgres, sql))
722 .await?;
723
724 tracing::info!(%task_id, "workflow paused");
725 Ok(())
726 }
727
728 pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
731 let model = Task::find_by_id(task_id).one(db).await?;
732 let model =
733 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
734
735 if model.status != TaskStatus::Paused {
736 return Err(DurableError::custom(format!(
737 "cannot resume task in {} status (must be PAUSED)",
738 model.status
739 )));
740 }
741
742 let mut active: TaskActiveModel = model.into();
744 active.status = Set(TaskStatus::Running);
745 active.update(db).await?;
746
747 let cascade_sql = format!(
749 "WITH RECURSIVE descendants AS ( \
750 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
751 UNION ALL \
752 SELECT t.id FROM durable.task t \
753 INNER JOIN descendants d ON t.parent_id = d.id \
754 ) \
755 UPDATE durable.task SET status = 'PENDING' \
756 WHERE id IN (SELECT id FROM descendants) \
757 AND status = 'PAUSED'"
758 );
759 db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
760 .await?;
761
762 tracing::info!(%task_id, "workflow resumed");
763 Ok(())
764 }
765
766 pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
771 let model = Task::find_by_id(task_id).one(db).await?;
772 let model =
773 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
774
775 match model.status {
776 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
777 return Err(DurableError::custom(format!(
778 "cannot cancel task in {} status",
779 model.status
780 )));
781 }
782 _ => {}
783 }
784
785 let sql = format!(
787 "WITH RECURSIVE descendants AS ( \
788 SELECT id FROM durable.task WHERE id = '{task_id}' \
789 UNION ALL \
790 SELECT t.id FROM durable.task t \
791 INNER JOIN descendants d ON t.parent_id = d.id \
792 ) \
793 UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
794 WHERE id IN (SELECT id FROM descendants) \
795 AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
796 );
797 db.execute(Statement::from_string(DbBackend::Postgres, sql))
798 .await?;
799
800 tracing::info!(%task_id, "workflow cancelled");
801 Ok(())
802 }
803
804 pub async fn list(
812 db: &DatabaseConnection,
813 query: TaskQuery,
814 ) -> Result<Vec<TaskSummary>, DurableError> {
815 let mut select = Task::find();
816
817 if let Some(status) = &query.status {
819 select = select.filter(TaskColumn::Status.eq(status.to_string()));
820 }
821 if let Some(kind) = &query.kind {
822 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
823 }
824 if let Some(parent_id) = query.parent_id {
825 select = select.filter(TaskColumn::ParentId.eq(parent_id));
826 }
827 if query.root_only {
828 select = select.filter(TaskColumn::ParentId.is_null());
829 }
830 if let Some(name) = &query.name {
831 select = select.filter(TaskColumn::Name.eq(name.as_str()));
832 }
833 if let Some(queue) = &query.queue_name {
834 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
835 }
836
837 let (col, order) = match query.sort {
839 TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
840 TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
841 TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
842 TaskSort::Name(ord) => (TaskColumn::Name, ord),
843 TaskSort::Status(ord) => (TaskColumn::Status, ord),
844 };
845 select = select.order_by(col, order);
846
847 if let Some(offset) = query.offset {
849 select = select.offset(offset);
850 }
851 if let Some(limit) = query.limit {
852 select = select.limit(limit);
853 }
854
855 let models = select.all(db).await?;
856
857 Ok(models.into_iter().map(TaskSummary::from).collect())
858 }
859
860 pub async fn count(
862 db: &DatabaseConnection,
863 query: TaskQuery,
864 ) -> Result<u64, DurableError> {
865 let mut select = Task::find();
866
867 if let Some(status) = &query.status {
868 select = select.filter(TaskColumn::Status.eq(status.to_string()));
869 }
870 if let Some(kind) = &query.kind {
871 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
872 }
873 if let Some(parent_id) = query.parent_id {
874 select = select.filter(TaskColumn::ParentId.eq(parent_id));
875 }
876 if query.root_only {
877 select = select.filter(TaskColumn::ParentId.is_null());
878 }
879 if let Some(name) = &query.name {
880 select = select.filter(TaskColumn::Name.eq(name.as_str()));
881 }
882 if let Some(queue) = &query.queue_name {
883 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
884 }
885
886 let count = select.count(db).await?;
887 Ok(count)
888 }
889
890 pub fn db(&self) -> &DatabaseConnection {
893 &self.db
894 }
895
896 pub fn task_id(&self) -> Uuid {
897 self.task_id
898 }
899
900 pub fn next_sequence(&self) -> i32 {
901 self.sequence.fetch_add(1, Ordering::SeqCst)
902 }
903}
904
905#[allow(clippy::too_many_arguments)]
926async fn find_or_create_task(
927 db: &impl ConnectionTrait,
928 parent_id: Option<Uuid>,
929 sequence: Option<i32>,
930 name: &str,
931 kind: &str,
932 input: Option<serde_json::Value>,
933 lock: bool,
934 max_retries: Option<u32>,
935) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
936 let parent_eq = match parent_id {
937 Some(p) => format!("= '{p}'"),
938 None => "IS NULL".to_string(),
939 };
940 let parent_sql = match parent_id {
941 Some(p) => format!("'{p}'"),
942 None => "NULL".to_string(),
943 };
944
945 if lock {
946 let new_id = Uuid::new_v4();
960 let seq_sql = match sequence {
961 Some(s) => s.to_string(),
962 None => "NULL".to_string(),
963 };
964 let input_sql = match &input {
965 Some(v) => format!("'{}'", serde_json::to_string(v)?),
966 None => "NULL".to_string(),
967 };
968
969 let max_retries_sql = match max_retries {
970 Some(r) => r.to_string(),
971 None => "3".to_string(), };
973
974 let insert_sql = format!(
976 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
977 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
978 ON CONFLICT (parent_id, sequence) DO NOTHING"
979 );
980 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
981 .await?;
982
983 let lock_sql = format!(
985 "SELECT id, status::text, output FROM durable.task \
986 WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
987 FOR UPDATE SKIP LOCKED"
988 );
989 let row = db
990 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
991 .await?;
992
993 if let Some(row) = row {
994 let id: Uuid = row
995 .try_get_by_index(0)
996 .map_err(|e| DurableError::custom(e.to_string()))?;
997 let status: String = row
998 .try_get_by_index(1)
999 .map_err(|e| DurableError::custom(e.to_string()))?;
1000 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
1001
1002 if status == TaskStatus::Completed.to_string() {
1003 return Ok((id, output));
1005 }
1006 return Ok((id, None));
1008 }
1009
1010 Err(DurableError::StepLocked(name.to_string()))
1012 } else {
1013 let mut query = Task::find().filter(TaskColumn::Name.eq(name));
1017 query = match parent_id {
1018 Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
1019 None => query.filter(TaskColumn::ParentId.is_null()),
1020 };
1021 let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
1024 let existing = query
1025 .filter(TaskColumn::Status.is_not_in(status_exclusions))
1026 .one(db)
1027 .await?;
1028
1029 if let Some(model) = existing {
1030 if model.status == TaskStatus::Completed {
1031 return Ok((model.id, model.output));
1032 }
1033 return Ok((model.id, None));
1034 }
1035
1036 let id = Uuid::new_v4();
1038 let new_task = TaskActiveModel {
1039 id: Set(id),
1040 parent_id: Set(parent_id),
1041 sequence: Set(sequence),
1042 name: Set(name.to_string()),
1043 kind: Set(kind.to_string()),
1044 status: Set(TaskStatus::Pending),
1045 input: Set(input),
1046 max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1047 ..Default::default()
1048 };
1049 new_task.insert(db).await?;
1050
1051 Ok((id, None))
1052 }
1053}
1054
1055async fn get_output(
1056 db: &impl ConnectionTrait,
1057 task_id: Uuid,
1058) -> Result<Option<serde_json::Value>, DurableError> {
1059 let model = Task::find_by_id(task_id)
1060 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1061 .one(db)
1062 .await?;
1063
1064 Ok(model.and_then(|m| m.output))
1065}
1066
1067async fn get_status(
1068 db: &impl ConnectionTrait,
1069 task_id: Uuid,
1070) -> Result<Option<TaskStatus>, DurableError> {
1071 let model = Task::find_by_id(task_id).one(db).await?;
1072
1073 Ok(model.map(|m| m.status))
1074}
1075
1076async fn get_retry_info(
1078 db: &DatabaseConnection,
1079 task_id: Uuid,
1080) -> Result<(u32, u32), DurableError> {
1081 let model = Task::find_by_id(task_id).one(db).await?;
1082
1083 match model {
1084 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1085 None => Err(DurableError::custom(format!(
1086 "task {task_id} not found when reading retry info"
1087 ))),
1088 }
1089}
1090
1091async fn increment_retry_count(
1093 db: &DatabaseConnection,
1094 task_id: Uuid,
1095) -> Result<u32, DurableError> {
1096 let model = Task::find_by_id(task_id).one(db).await?;
1097
1098 match model {
1099 Some(m) => {
1100 let new_count = m.retry_count + 1;
1101 let mut active: TaskActiveModel = m.into();
1102 active.retry_count = Set(new_count);
1103 active.status = Set(TaskStatus::Pending);
1104 active.error = Set(None);
1105 active.completed_at = Set(None);
1106 active.update(db).await?;
1107 Ok(new_count as u32)
1108 }
1109 None => Err(DurableError::custom(format!(
1110 "task {task_id} not found when incrementing retry count"
1111 ))),
1112 }
1113}
1114
1115async fn set_status(
1118 db: &impl ConnectionTrait,
1119 task_id: Uuid,
1120 status: TaskStatus,
1121) -> Result<(), DurableError> {
1122 let sql = format!(
1123 "UPDATE durable.task \
1124 SET status = '{status}', \
1125 started_at = COALESCE(started_at, now()), \
1126 deadline_epoch_ms = CASE \
1127 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
1128 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
1129 ELSE deadline_epoch_ms \
1130 END \
1131 WHERE id = '{task_id}'"
1132 );
1133 db.execute(Statement::from_string(DbBackend::Postgres, sql))
1134 .await?;
1135 Ok(())
1136}
1137
1138async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1140 let status = get_status(db, task_id).await?;
1141 match status {
1142 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1143 Some(TaskStatus::Cancelled) => {
1144 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1145 }
1146 _ => Ok(()),
1147 }
1148}
1149
1150async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1152 let model = Task::find_by_id(task_id).one(db).await?;
1153
1154 if let Some(m) = model
1155 && let Some(deadline_ms) = m.deadline_epoch_ms
1156 {
1157 let now_ms = std::time::SystemTime::now()
1158 .duration_since(std::time::UNIX_EPOCH)
1159 .map(|d| d.as_millis() as i64)
1160 .unwrap_or(0);
1161 if now_ms > deadline_ms {
1162 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1163 }
1164 }
1165
1166 Ok(())
1167}
1168
1169async fn complete_task(
1170 db: &impl ConnectionTrait,
1171 task_id: Uuid,
1172 output: serde_json::Value,
1173) -> Result<(), DurableError> {
1174 let model = Task::find_by_id(task_id).one(db).await?;
1175
1176 if let Some(m) = model {
1177 let mut active: TaskActiveModel = m.into();
1178 active.status = Set(TaskStatus::Completed);
1179 active.output = Set(Some(output));
1180 active.completed_at = Set(Some(chrono::Utc::now().into()));
1181 active.update(db).await?;
1182 }
1183 Ok(())
1184}
1185
1186async fn fail_task(
1187 db: &impl ConnectionTrait,
1188 task_id: Uuid,
1189 error: &str,
1190) -> Result<(), DurableError> {
1191 let model = Task::find_by_id(task_id).one(db).await?;
1192
1193 if let Some(m) = model {
1194 let mut active: TaskActiveModel = m.into();
1195 active.status = Set(TaskStatus::Failed);
1196 active.error = Set(Some(error.to_string()));
1197 active.completed_at = Set(Some(chrono::Utc::now().into()));
1198 active.update(db).await?;
1199 }
1200 Ok(())
1201}
1202
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds the kind of error a failed checkpoint write produces.
    fn db_error(message: String) -> DurableError {
        DurableError::Db(sea_orm::DbErr::Custom(message))
    }

    /// A first attempt that succeeds is not retried.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by a success end in `Ok` after 3 calls.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                let previous = counter.fetch_add(1, Ordering::SeqCst);
                if previous >= 2 {
                    Ok(())
                } else {
                    Err(db_error("transient".to_string()))
                }
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A write that never succeeds is attempted once plus every retry.
    /// Note: this sleeps through the full backoff schedule in real time.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(db_error("always fails".to_string()))
            }
        })
        .await;

        assert!(outcome.is_err());
        // The initial attempt plus MAX_CHECKPOINT_RETRIES retries.
        assert_eq!(
            attempts.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    /// When every attempt fails, the error reported is the first one.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);
        let outcome = retry_db_write(move || {
            let counter = Arc::clone(&counter);
            async move {
                let attempt_index = counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(db_error(format!("error-{attempt_index}")))
            }
        })
        .await;

        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}