1use std::pin::Pin;
2
3use durable_db::entity::task::{
4 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task, TaskStatus,
5};
6use sea_orm::{
7 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
8 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
9 Statement, TransactionTrait,
10};
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13use std::sync::atomic::{AtomicI32, Ordering};
14use std::time::Duration;
15use uuid::Uuid;
16
17use crate::error::DurableError;
18
19const MAX_CHECKPOINT_RETRIES: u32 = 3;
22const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
23
24async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
30where
31 F: FnMut() -> Fut,
32 Fut: std::future::Future<Output = Result<(), DurableError>>,
33{
34 match f().await {
35 Ok(()) => Ok(()),
36 Err(first_err) => {
37 for i in 0..MAX_CHECKPOINT_RETRIES {
38 tokio::time::sleep(Duration::from_millis(
39 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
40 ))
41 .await;
42 if f().await.is_ok() {
43 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
44 return Ok(());
45 }
46 }
47 Err(first_err)
48 }
49 }
50}
51
/// Retry configuration for `Ctx::step_with_retry`.
///
/// `max_retries` counts retries *after* the initial attempt. The backoff for
/// retry `n` (1-based) is roughly `initial_backoff * backoff_multiplier^(n-1)`
/// (the consumer clamps the factor to at least 1.0).
///
/// All fields are cheap plain data, so the standard value-type derives are
/// provided for ergonomic passing, logging, and comparison in tests.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of retries after the first attempt.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Growth factor applied to the backoff on each subsequent retry.
    pub backoff_multiplier: f64,
}
58
59impl RetryPolicy {
60 pub fn none() -> Self {
62 Self {
63 max_retries: 0,
64 initial_backoff: std::time::Duration::from_secs(0),
65 backoff_multiplier: 1.0,
66 }
67 }
68
69 pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
71 Self {
72 max_retries,
73 initial_backoff,
74 backoff_multiplier: 2.0,
75 }
76 }
77
78 pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
80 Self {
81 max_retries,
82 initial_backoff: backoff,
83 backoff_multiplier: 1.0,
84 }
85 }
86}
87
/// Sort key (with direction) for task listings; consumed by `Ctx::list`.
pub enum TaskSort {
    /// Order by creation timestamp.
    CreatedAt(Order),
    /// Order by when the task started running.
    StartedAt(Order),
    /// Order by completion timestamp.
    CompletedAt(Order),
    /// Order by task name.
    Name(Order),
    /// Order by status column.
    Status(Order),
}
96
/// Filter/sort/pagination options for listing or counting tasks.
///
/// Build with `TaskQuery::default()` plus the chained builder methods;
/// every filter left as `None`/`false` is simply not applied.
pub struct TaskQuery {
    // Only tasks with this exact status.
    pub status: Option<TaskStatus>,
    // Only tasks of this kind (e.g. "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: Option<String>,
    // Only direct children of this task.
    pub parent_id: Option<Uuid>,
    // When true, only tasks with no parent (workflow roots).
    pub root_only: bool,
    // Only tasks with this exact name.
    pub name: Option<String>,
    // Only tasks assigned to this queue.
    pub queue_name: Option<String>,
    // Sort key and direction for the result set.
    pub sort: TaskSort,
    // Page size; no limit when None.
    pub limit: Option<u64>,
    // Rows to skip before the first result; 0 when None.
    pub offset: Option<u64>,
}
119
120impl Default for TaskQuery {
121 fn default() -> Self {
122 Self {
123 status: None,
124 kind: None,
125 parent_id: None,
126 root_only: false,
127 name: None,
128 queue_name: None,
129 sort: TaskSort::CreatedAt(Order::Desc),
130 limit: None,
131 offset: None,
132 }
133 }
134}
135
136impl TaskQuery {
137 pub fn status(mut self, status: TaskStatus) -> Self {
139 self.status = Some(status);
140 self
141 }
142
143 pub fn kind(mut self, kind: &str) -> Self {
145 self.kind = Some(kind.to_string());
146 self
147 }
148
149 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
151 self.parent_id = Some(parent_id);
152 self
153 }
154
155 pub fn root_only(mut self, root_only: bool) -> Self {
157 self.root_only = root_only;
158 self
159 }
160
161 pub fn name(mut self, name: &str) -> Self {
163 self.name = Some(name.to_string());
164 self
165 }
166
167 pub fn queue_name(mut self, queue: &str) -> Self {
169 self.queue_name = Some(queue.to_string());
170 self
171 }
172
173 pub fn sort(mut self, sort: TaskSort) -> Self {
175 self.sort = sort;
176 self
177 }
178
179 pub fn limit(mut self, limit: u64) -> Self {
181 self.limit = Some(limit);
182 self
183 }
184
185 pub fn offset(mut self, offset: u64) -> Self {
187 self.offset = Some(offset);
188 self
189 }
190}
191
/// Read-only projection of a task row, suitable for listing APIs and
/// serialization over the wire.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    // Primary key of the task row.
    pub id: Uuid,
    // Parent task, None for workflow roots.
    pub parent_id: Option<Uuid>,
    pub name: String,
    pub status: TaskStatus,
    // Task kind, e.g. "WORKFLOW", "STEP", "TRANSACTION".
    pub kind: String,
    // JSON input supplied when the task was created, if any.
    pub input: Option<serde_json::Value>,
    // JSON output saved on completion, if any.
    pub output: Option<serde_json::Value>,
    // Error message recorded when the task failed, if any.
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
208
209impl From<durable_db::entity::task::Model> for TaskSummary {
210 fn from(m: durable_db::entity::task::Model) -> Self {
211 Self {
212 id: m.id,
213 parent_id: m.parent_id,
214 name: m.name,
215 status: m.status,
216 kind: m.kind,
217 input: m.input,
218 output: m.output,
219 error: m.error,
220 queue_name: m.queue_name,
221 created_at: m.created_at,
222 started_at: m.started_at,
223 completed_at: m.completed_at,
224 }
225 }
226}
227
/// Execution context for one durable task (workflow) run.
///
/// Each checkpointed step claims the next value of `sequence`, so a replayed
/// run visits steps in the same order and finds its memoized results.
pub struct Ctx {
    // Connection used for all checkpoint reads and writes.
    db: DatabaseConnection,
    // Row id of the task this context drives.
    task_id: Uuid,
    // Monotonic per-context step counter (atomic: &self methods may race).
    sequence: AtomicI32,
}
238
impl Ctx {
    /// Begins (or resumes) a root workflow named `name`.
    ///
    /// Finds or creates the root task row (kind WORKFLOW, no parent) inside a
    /// transaction, marks it RUNNING, and returns a context whose step
    /// sequence starts at 0. Re-running with the same name resumes the
    /// existing row rather than creating a duplicate.
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let txn = db.begin().await?;
        let (task_id, _saved) =
            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
        // NOTE(review): retrying on the same open transaction may be futile if
        // the first failure aborted it — confirm intended.
        retry_db_write(|| set_status(&txn, task_id, TaskStatus::Running)).await?;
        txn.commit().await?;
        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// Executes step `name` exactly once, memoizing its JSON-serialized result.
    ///
    /// On replay (step row already COMPLETED) the saved output is deserialized
    /// and returned without running `f`. Otherwise the step row is locked via
    /// `find_or_create_task(lock = true)`, `f` runs, and the outcome is
    /// committed as COMPLETED or FAILED.
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim the next sequence number up front so ordering stays stable
        // even if this step fails.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Refuse to run when the workflow is paused/cancelled...
        check_status(&self.db, self.task_id).await?;

        // ...or past its deadline.
        check_deadline(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;

        // lock = true takes the row lock; Some(0) means plain steps get no
        // automatic retries.
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        if let Some(output) = saved_output {
            // Replay path: release the row lock before deserializing.
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // NOTE(review): as in `start`, retries reuse the same transaction.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                let err_msg = e.to_string();
                // The failure marker is committed so replays see FAILED.
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }

    /// Runs `f` inside a database transaction that also carries the step's
    /// checkpoint: the closure's writes and the COMPLETED marker commit
    /// atomically. On error the transaction is dropped (rolled back) and the
    /// FAILED marker is written on the base connection.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
            &'tx DatabaseTransaction,
        ) -> Pin<
            Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
        > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // lock = false: plain lookup/insert, no row lock taken here.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        // RUNNING marker rides in the same transaction as the user's writes.
        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the uncommitted transaction discards the closure's
                // writes; the FAILED marker is then recorded outside it.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }

    /// Creates (or resumes) a child workflow under this task.
    ///
    /// The child gets its own `Ctx` with an independent sequence counter
    /// starting at 0; its row is kind WORKFLOW with this task as parent.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
        })
    }

    /// True when this task's status is COMPLETED.
    pub async fn is_completed(&self) -> Result<bool, DurableError> {
        let status = get_status(&self.db, self.task_id).await?;
        Ok(status == Some(TaskStatus::Completed))
    }

    /// Deserializes this task's saved output, if it completed with one.
    pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
        match get_output(&self.db, self.task_id).await? {
            Some(val) => Ok(Some(serde_json::from_value(val)?)),
            None => Ok(None),
        }
    }

    /// Marks this task COMPLETED with the given serialized output.
    pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
        let json = serde_json::to_value(output)?;
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| complete_task(db, task_id, json.clone())).await
    }

    /// Like `step`, but re-runs `f` on failure according to `policy`.
    ///
    /// The retry count is persisted on the step row, so retries survive a
    /// process restart. Unlike `step`, this path takes no row lock and writes
    /// checkpoints on the base connection, not in a wrapping transaction.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Resume from whatever retry count a previous run persisted.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Pause/cancel is honored between attempts.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the bump (also resets row to PENDING).
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // backoff = initial * multiplier^(retry-1), with the
                        // factor clamped to at least 1.0.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }

    /// Marks this task FAILED with the given error message.
    pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
        let db = &self.db;
        let task_id = self.task_id;
        retry_db_write(|| fail_task(db, task_id, error)).await
    }

    /// Sets this task's timeout. If the task is already RUNNING, the absolute
    /// deadline is recomputed from `started_at`; otherwise it is armed later
    /// by `set_status` on the transition to RUNNING.
    /// (`timeout_ms` is an i64 and `task_id` a Uuid, so the interpolation
    /// cannot inject SQL.)
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
                 WHEN status = 'RUNNING' AND started_at IS NOT NULL \
                 THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
                 ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }

    /// Convenience: `start` followed by `set_timeout`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }

    /// Pauses a PENDING/RUNNING workflow and all its active descendants.
    ///
    /// Uses a recursive CTE so the whole subtree flips to PAUSED in one
    /// statement; already-terminal descendants are left untouched.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('PENDING', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }

    /// Resumes a PAUSED workflow: the root returns to RUNNING and paused
    /// descendants return to PENDING.
    ///
    /// NOTE(review): descendants that were RUNNING when paused come back as
    /// PENDING, not RUNNING — confirm that is the intended semantics.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        let sql = format!(
            "UPDATE durable.task SET status = 'RUNNING' WHERE id = '{task_id}'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        // Cascade to the subtree (root excluded: the CTE seeds from children).
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }

    /// Cancels a non-terminal workflow and every non-terminal descendant,
    /// stamping `completed_at` on each cancelled row.
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
                 SELECT id FROM durable.task WHERE id = '{task_id}' \
                 UNION ALL \
                 SELECT t.id FROM durable.task t \
                 INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
             AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }

    /// Lists tasks matching `query`, applying its filters, sort, and paging.
    pub async fn list(
        db: &DatabaseConnection,
        query: TaskQuery,
    ) -> Result<Vec<TaskSummary>, DurableError> {
        let mut select = Task::find();

        // Each filter is applied only when set; keep in sync with `count`.
        if let Some(status) = &query.status {
            select = select.filter(TaskColumn::Status.eq(status.to_string()));
        }
        if let Some(kind) = &query.kind {
            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
        }
        if let Some(parent_id) = query.parent_id {
            select = select.filter(TaskColumn::ParentId.eq(parent_id));
        }
        if query.root_only {
            select = select.filter(TaskColumn::ParentId.is_null());
        }
        if let Some(name) = &query.name {
            select = select.filter(TaskColumn::Name.eq(name.as_str()));
        }
        if let Some(queue) = &query.queue_name {
            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
        }

        let (col, order) = match query.sort {
            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
            TaskSort::Name(ord) => (TaskColumn::Name, ord),
            TaskSort::Status(ord) => (TaskColumn::Status, ord),
        };
        select = select.order_by(col, order);

        if let Some(offset) = query.offset {
            select = select.offset(offset);
        }
        if let Some(limit) = query.limit {
            select = select.limit(limit);
        }

        let models = select.all(db).await?;

        Ok(models.into_iter().map(TaskSummary::from).collect())
    }

    /// Counts tasks matching `query`'s filters (sort and paging are ignored).
    pub async fn count(
        db: &DatabaseConnection,
        query: TaskQuery,
    ) -> Result<u64, DurableError> {
        let mut select = Task::find();

        // Same filter set as `list`; keep the two in sync.
        if let Some(status) = &query.status {
            select = select.filter(TaskColumn::Status.eq(status.to_string()));
        }
        if let Some(kind) = &query.kind {
            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
        }
        if let Some(parent_id) = query.parent_id {
            select = select.filter(TaskColumn::ParentId.eq(parent_id));
        }
        if query.root_only {
            select = select.filter(TaskColumn::ParentId.is_null());
        }
        if let Some(name) = &query.name {
            select = select.filter(TaskColumn::Name.eq(name.as_str()));
        }
        if let Some(queue) = &query.queue_name {
            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
        }

        let count = select.count(db).await?;
        Ok(count)
    }

    /// The underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// Id of the task this context drives.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Reserves and returns the next step sequence number.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
}
842
843#[allow(clippy::too_many_arguments)]
864async fn find_or_create_task(
865 db: &impl ConnectionTrait,
866 parent_id: Option<Uuid>,
867 sequence: Option<i32>,
868 name: &str,
869 kind: &str,
870 input: Option<serde_json::Value>,
871 lock: bool,
872 max_retries: Option<u32>,
873) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
874 let parent_eq = match parent_id {
875 Some(p) => format!("= '{p}'"),
876 None => "IS NULL".to_string(),
877 };
878 let parent_sql = match parent_id {
879 Some(p) => format!("'{p}'"),
880 None => "NULL".to_string(),
881 };
882
883 if lock {
884 let new_id = Uuid::new_v4();
898 let seq_sql = match sequence {
899 Some(s) => s.to_string(),
900 None => "NULL".to_string(),
901 };
902 let input_sql = match &input {
903 Some(v) => format!("'{}'", serde_json::to_string(v)?),
904 None => "NULL".to_string(),
905 };
906
907 let max_retries_sql = match max_retries {
908 Some(r) => r.to_string(),
909 None => "3".to_string(), };
911
912 let insert_sql = format!(
914 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
915 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
916 ON CONFLICT (parent_id, name) DO NOTHING"
917 );
918 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
919 .await?;
920
921 let lock_sql = format!(
923 "SELECT id, status, output FROM durable.task \
924 WHERE parent_id {parent_eq} AND name = '{name}' \
925 FOR UPDATE SKIP LOCKED"
926 );
927 let row = db
928 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
929 .await?;
930
931 if let Some(row) = row {
932 let id: Uuid = row
933 .try_get_by_index(0)
934 .map_err(|e| DurableError::custom(e.to_string()))?;
935 let status: String = row
936 .try_get_by_index(1)
937 .map_err(|e| DurableError::custom(e.to_string()))?;
938 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
939
940 if status == TaskStatus::Completed.to_string() {
941 return Ok((id, output));
943 }
944 return Ok((id, None));
946 }
947
948 Err(DurableError::StepLocked(name.to_string()))
950 } else {
951 let mut query = Task::find().filter(TaskColumn::Name.eq(name));
955 query = match parent_id {
956 Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
957 None => query.filter(TaskColumn::ParentId.is_null()),
958 };
959 let existing = query.one(db).await?;
960
961 if let Some(model) = existing {
962 if model.status == TaskStatus::Completed {
963 return Ok((model.id, model.output));
964 }
965 return Ok((model.id, None));
966 }
967
968 let id = Uuid::new_v4();
970 let new_task = TaskActiveModel {
971 id: Set(id),
972 parent_id: Set(parent_id),
973 sequence: Set(sequence),
974 name: Set(name.to_string()),
975 kind: Set(kind.to_string()),
976 status: Set(TaskStatus::Pending),
977 input: Set(input),
978 max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
979 ..Default::default()
980 };
981 new_task.insert(db).await?;
982
983 Ok((id, None))
984 }
985}
986
987async fn get_output(
988 db: &impl ConnectionTrait,
989 task_id: Uuid,
990) -> Result<Option<serde_json::Value>, DurableError> {
991 let model = Task::find_by_id(task_id)
992 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
993 .one(db)
994 .await?;
995
996 Ok(model.and_then(|m| m.output))
997}
998
999async fn get_status(
1000 db: &impl ConnectionTrait,
1001 task_id: Uuid,
1002) -> Result<Option<TaskStatus>, DurableError> {
1003 let model = Task::find_by_id(task_id).one(db).await?;
1004
1005 Ok(model.map(|m| m.status))
1006}
1007
1008async fn get_retry_info(
1010 db: &DatabaseConnection,
1011 task_id: Uuid,
1012) -> Result<(u32, u32), DurableError> {
1013 let model = Task::find_by_id(task_id).one(db).await?;
1014
1015 match model {
1016 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1017 None => Err(DurableError::custom(format!(
1018 "task {task_id} not found when reading retry info"
1019 ))),
1020 }
1021}
1022
1023async fn increment_retry_count(
1025 db: &DatabaseConnection,
1026 task_id: Uuid,
1027) -> Result<u32, DurableError> {
1028 let model = Task::find_by_id(task_id).one(db).await?;
1029
1030 match model {
1031 Some(m) => {
1032 let new_count = m.retry_count + 1;
1033 let mut active: TaskActiveModel = m.into();
1034 active.retry_count = Set(new_count);
1035 active.status = Set(TaskStatus::Pending);
1036 active.error = Set(None);
1037 active.completed_at = Set(None);
1038 active.update(db).await?;
1039 Ok(new_count as u32)
1040 }
1041 None => Err(DurableError::custom(format!(
1042 "task {task_id} not found when incrementing retry count"
1043 ))),
1044 }
1045}
1046
/// Updates a task's status, stamping `started_at` on the first transition and
/// arming `deadline_epoch_ms` when moving to RUNNING with a timeout set.
/// (`status` renders via the enum's Display and `task_id` is a Uuid, so the
/// interpolation cannot inject SQL.)
///
/// NOTE(review): `started_at = COALESCE(started_at, now())` runs for every
/// status value, not only RUNNING — all visible callers pass RUNNING, but
/// confirm before using with other statuses.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
             WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
             THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
             ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1069
1070async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1072 let status = get_status(db, task_id).await?;
1073 match status {
1074 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1075 Some(TaskStatus::Cancelled) => {
1076 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1077 }
1078 _ => Ok(()),
1079 }
1080}
1081
1082async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1084 let model = Task::find_by_id(task_id).one(db).await?;
1085
1086 if let Some(m) = model
1087 && let Some(deadline_ms) = m.deadline_epoch_ms
1088 {
1089 let now_ms = std::time::SystemTime::now()
1090 .duration_since(std::time::UNIX_EPOCH)
1091 .map(|d| d.as_millis() as i64)
1092 .unwrap_or(0);
1093 if now_ms > deadline_ms {
1094 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1095 }
1096 }
1097
1098 Ok(())
1099}
1100
1101async fn complete_task(
1102 db: &impl ConnectionTrait,
1103 task_id: Uuid,
1104 output: serde_json::Value,
1105) -> Result<(), DurableError> {
1106 let model = Task::find_by_id(task_id).one(db).await?;
1107
1108 if let Some(m) = model {
1109 let mut active: TaskActiveModel = m.into();
1110 active.status = Set(TaskStatus::Completed);
1111 active.output = Set(Some(output));
1112 active.completed_at = Set(Some(chrono::Utc::now().into()));
1113 active.update(db).await?;
1114 }
1115 Ok(())
1116}
1117
1118async fn fail_task(
1119 db: &impl ConnectionTrait,
1120 task_id: Uuid,
1121 error: &str,
1122) -> Result<(), DurableError> {
1123 let model = Task::find_by_id(task_id).one(db).await?;
1124
1125 if let Some(m) = model {
1126 let mut active: TaskActiveModel = m.into();
1127 active.status = Set(TaskStatus::Failed);
1128 active.error = Set(Some(error.to_string()));
1129 active.completed_at = Set(Some(chrono::Utc::now().into()));
1130 active.update(db).await?;
1131 }
1132 Ok(())
1133}
1134
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A write that succeeds immediately must be invoked exactly once
    /// (no gratuitous retries).
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by a success: the helper keeps
    /// retrying and ultimately reports Ok after 3 total calls.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                // Fail the first two attempts, succeed on the third.
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// A permanently failing write is attempted once plus
    /// MAX_CHECKPOINT_RETRIES retries, then surfaces an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    /// On exhaustion the error from the FIRST attempt is returned, not the
    /// last one — each closure invocation embeds its call index to prove it.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}