1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32 F: FnMut() -> Fut,
33 Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35 match f().await {
36 Ok(()) => Ok(()),
37 Err(first_err) => {
38 for i in 0..MAX_CHECKPOINT_RETRIES {
39 tokio::time::sleep(Duration::from_millis(
40 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41 ))
42 .await;
43 if f().await.is_ok() {
44 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45 return Ok(());
46 }
47 }
48 Err(first_err)
49 }
50 }
51}
52
/// Retry configuration for `Ctx::step_with_retry`.
///
/// Backoff before retry `k` (1-based) is
/// `initial_backoff * backoff_multiplier^(k-1)`.
// Derives added: a plain-data public config type should be Debug/Clone/PartialEq
// so callers can log, copy, and compare policies.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt (0 = no retries).
    pub max_retries: u32,
    /// Delay before the first retry; later delays are scaled by the multiplier.
    pub initial_backoff: std::time::Duration,
    /// Growth factor applied per retry (1.0 = fixed backoff, 2.0 = doubling).
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries: the first failure is final.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: std::time::Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff (doubling) starting at `initial_backoff`.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Constant backoff of `backoff` between every retry.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
88
/// Sort key and direction for task listing queries (see `Ctx::list`).
pub enum TaskSort {
    /// Order by row creation timestamp.
    CreatedAt(Order),
    /// Order by execution start timestamp.
    StartedAt(Order),
    /// Order by completion timestamp.
    CompletedAt(Order),
    /// Order lexicographically by task name.
    Name(Order),
    /// Order by task status.
    Status(Order),
}
97
/// Filter/sort/paging options for `Ctx::list` and `Ctx::count`.
///
/// All filters are conjunctive; `None` fields leave that dimension
/// unconstrained. Build with `TaskQuery::default()` plus the fluent setters.
pub struct TaskQuery {
    /// Only tasks with this status.
    pub status: Option<TaskStatus>,
    /// Only tasks of this kind (e.g. "STEP", "WORKFLOW").
    pub kind: Option<String>,
    /// Only direct children of this parent task.
    pub parent_id: Option<Uuid>,
    /// Only root tasks (parent_id IS NULL).
    pub root_only: bool,
    /// Only tasks with this exact name.
    pub name: Option<String>,
    /// Only tasks on this queue.
    pub queue_name: Option<String>,
    /// Sort key + direction applied to the result set.
    pub sort: TaskSort,
    /// Maximum number of rows returned (ignored by `count`).
    pub limit: Option<u64>,
    /// Number of rows skipped (ignored by `count`).
    pub offset: Option<u64>,
}
120
121impl Default for TaskQuery {
122 fn default() -> Self {
123 Self {
124 status: None,
125 kind: None,
126 parent_id: None,
127 root_only: false,
128 name: None,
129 queue_name: None,
130 sort: TaskSort::CreatedAt(Order::Desc),
131 limit: None,
132 offset: None,
133 }
134 }
135}
136
137impl TaskQuery {
138 pub fn status(mut self, status: TaskStatus) -> Self {
140 self.status = Some(status);
141 self
142 }
143
144 pub fn kind(mut self, kind: &str) -> Self {
146 self.kind = Some(kind.to_string());
147 self
148 }
149
150 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152 self.parent_id = Some(parent_id);
153 self
154 }
155
156 pub fn root_only(mut self, root_only: bool) -> Self {
158 self.root_only = root_only;
159 self
160 }
161
162 pub fn name(mut self, name: &str) -> Self {
164 self.name = Some(name.to_string());
165 self
166 }
167
168 pub fn queue_name(mut self, queue: &str) -> Self {
170 self.queue_name = Some(queue.to_string());
171 self
172 }
173
174 pub fn sort(mut self, sort: TaskSort) -> Self {
176 self.sort = sort;
177 self
178 }
179
180 pub fn limit(mut self, limit: u64) -> Self {
182 self.limit = Some(limit);
183 self
184 }
185
186 pub fn offset(mut self, offset: u64) -> Self {
188 self.offset = Some(offset);
189 self
190 }
191}
192
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Read-model projection of a task row, returned by `Ctx::list`.
pub struct TaskSummary {
    /// Task row primary key.
    pub id: Uuid,
    /// Parent task, or `None` for root workflows.
    pub parent_id: Option<Uuid>,
    /// Caller-assigned task name.
    pub name: String,
    /// Current lifecycle status.
    pub status: TaskStatus,
    /// Task kind (e.g. "STEP", "WORKFLOW", "TRANSACTION").
    pub kind: String,
    /// JSON input the task was started with, if any.
    pub input: Option<serde_json::Value>,
    /// JSON output recorded on completion, if any.
    pub output: Option<serde_json::Value>,
    /// Failure message recorded when the task failed.
    pub error: Option<String>,
    /// Queue this task was routed to, if any.
    pub queue_name: Option<String>,
    /// Row creation time.
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    /// Execution start time, once the task began running.
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    /// Completion/failure/cancellation time.
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
209
210impl From<durable_db::entity::task::Model> for TaskSummary {
211 fn from(m: durable_db::entity::task::Model) -> Self {
212 Self {
213 id: m.id,
214 parent_id: m.parent_id,
215 name: m.name,
216 status: m.status,
217 kind: m.kind,
218 input: m.input,
219 output: m.output,
220 error: m.error,
221 queue_name: m.queue_name,
222 created_at: m.created_at,
223 started_at: m.started_at,
224 completed_at: m.completed_at,
225 }
226 }
227}
228
/// Execution context for a durable workflow task.
///
/// Wraps a database handle, the id of the task row this context drives, and a
/// monotonically increasing in-memory step counter used to key checkpoints.
pub struct Ctx {
    // Cloned connection handle; all checkpoint reads/writes go through it.
    db: DatabaseConnection,
    // Row id of the workflow task this context belongs to.
    task_id: Uuid,
    // Next step sequence number; incremented once per step/transaction/child.
    sequence: AtomicI32,
}
239
240impl Ctx {
241 pub async fn start(
251 db: &DatabaseConnection,
252 name: &str,
253 input: Option<serde_json::Value>,
254 ) -> Result<Self, DurableError> {
255 let task_id = Uuid::new_v4();
256 let input_json = match &input {
257 Some(v) => serde_json::to_string(v)?,
258 None => "null".to_string(),
259 };
260
261 let txn = db.begin().await?;
262 let sql = format!(
263 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at) \
264 VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now())"
265 );
266 txn.execute(Statement::from_string(DbBackend::Postgres, sql))
267 .await?;
268 txn.commit().await?;
269
270 Ok(Self {
271 db: db.clone(),
272 task_id,
273 sequence: AtomicI32::new(0),
274 })
275 }
276
277 pub async fn from_id(
284 db: &DatabaseConnection,
285 task_id: Uuid,
286 ) -> Result<Self, DurableError> {
287 let model = Task::find_by_id(task_id).one(db).await?;
289 let _model =
290 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
291
292 Ok(Self {
297 db: db.clone(),
298 task_id,
299 sequence: AtomicI32::new(0),
300 })
301 }
302
    /// Executes `f` as a checkpointed step, or replays its saved output if
    /// this (parent, sequence) step already completed.
    ///
    /// The step row is claimed with a row lock (`FOR UPDATE SKIP LOCKED` in
    /// `find_or_create_task`), so a concurrently executing worker gets
    /// `DurableError::StepLocked` instead of running the step twice.
    ///
    /// # Errors
    /// Propagates pause/cancel (`check_status`), deadline expiry
    /// (`check_deadline`), DB errors, serialization errors, and any error
    /// returned by `f` (which is also persisted on the step row).
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim the next step slot; this is what keys replay.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Bail out early if the workflow was paused/cancelled…
        check_status(&self.db, self.task_id).await?;

        // …or its deadline has passed.
        check_deadline(&self.db, self.task_id).await?;

        // The whole step (claim + status writes) runs inside one transaction;
        // the row lock taken below is held until commit.
        let txn = self.db.begin().await?;

        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // Replay path: a previous execution already completed this step.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // NOTE(review): retry_db_write is used on statements inside this open
        // transaction. In Postgres a failed statement aborts the transaction,
        // so subsequent retries on the same txn would keep failing — TODO
        // confirm whether these retries can ever succeed here.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        match f().await {
            Ok(val) => {
                // Persist the output before committing so replay can find it.
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Record the failure on the step row, then surface the error.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }
370
    /// Executes `f` inside a database transaction as a checkpointed step:
    /// the user's writes and the step's COMPLETED marker commit atomically.
    ///
    /// Unlike `step`, the step row itself is created/looked up *outside* the
    /// transaction (no row lock), so the bookkeeping row survives even if the
    /// user transaction rolls back.
    ///
    /// # Errors
    /// Propagates pause/cancel, DB errors, serialization errors, and any
    /// error from `f`; on `f` failing, the transaction is rolled back (via
    /// drop) and the failure is recorded on the step row outside it.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
            &'tx DatabaseTransaction,
        ) -> Pin<
            Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
        > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Respect pause/cancel before doing any work.
        check_status(&self.db, self.task_id).await?;

        // Checkpoint row lookup/creation happens on the plain connection.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // Replay path: step already completed in a previous execution.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        // User work and the COMPLETED marker share this transaction.
        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the transaction rolls back the user's writes;
                // the failure is then recorded outside the transaction so
                // it is not rolled back with them.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
439
    /// Creates (or re-attaches to) a child workflow under this task and
    /// returns a context for it.
    ///
    /// The child is keyed by this task's next sequence number, so on replay
    /// the same child row is reused instead of a duplicate being created.
    ///
    /// # Errors
    /// Propagates pause/cancel of the parent and any DB error.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Respect pause/cancel on the parent before spawning.
        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // No row lock here (lock = false): child creation is keyed by
        // (parent, sequence) via the ORM lookup path.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        // NOTE(review): retry_db_write runs inside this open transaction; in
        // Postgres a failed statement aborts the txn, so the retries likely
        // cannot succeed here — TODO confirm.
        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
        })
    }
482
483 pub async fn is_completed(&self) -> Result<bool, DurableError> {
485 let status = get_status(&self.db, self.task_id).await?;
486 Ok(status == Some(TaskStatus::Completed))
487 }
488
489 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
491 match get_output(&self.db, self.task_id).await? {
492 Some(val) => Ok(Some(serde_json::from_value(val)?)),
493 None => Ok(None),
494 }
495 }
496
497 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
499 let json = serde_json::to_value(output)?;
500 let db = &self.db;
501 let task_id = self.task_id;
502 retry_db_write(|| complete_task(db, task_id, json.clone())).await
503 }
504
    /// Executes `f` as a checkpointed step with a persistent retry budget.
    ///
    /// Unlike `step`, the retry counter lives on the task row, so retries
    /// survive process restarts: on resume the loop picks up where the
    /// persisted `retry_count` left off. Backoff before retry `k` is
    /// `initial_backoff * backoff_multiplier^(k-1)` (multiplier clamped to
    /// at least 1.0).
    ///
    /// # Errors
    /// Replays saved output if already COMPLETED; otherwise retries `f` up
    /// to the row's `max_retries` and returns the final error after marking
    /// the step FAILED.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // No row lock: the step row carries the retry budget
        // (policy.max_retries is persisted on creation).
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // Replay path: step already completed in a previous execution.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Read the persisted retry state; max_retries comes from the row,
        // not the policy, so a pre-existing row keeps its original budget.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Pause/cancel is honored between attempts.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the bumped counter (also resets the row to
                        // PENDING and clears error/completed_at).
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Compute backoff for this retry; multiplier is
                        // clamped to >= 1.0 so delays never shrink.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Budget exhausted: persist the failure and give up.
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
614
615 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
617 let db = &self.db;
618 let task_id = self.task_id;
619 retry_db_write(|| fail_task(db, task_id, error)).await
620 }
621
    /// Sets this task's timeout and, if it is already RUNNING, computes its
    /// absolute deadline from `started_at` in the same statement.
    ///
    /// For tasks not yet running, only `timeout_ms` is stored; the deadline
    /// is derived later when the task transitions to RUNNING (see
    /// `set_status`).
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // timeout_ms is an i64 and task_id a Uuid, so this interpolation
        // cannot inject arbitrary SQL.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
             WHEN status = 'RUNNING' AND started_at IS NOT NULL \
             THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
             ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
646
    /// Convenience constructor: `start` followed by `set_timeout`.
    ///
    /// The task is created RUNNING, so the deadline is computed immediately
    /// from its `started_at`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
660
    /// Pauses a task and, recursively, all of its PENDING/RUNNING
    /// descendants.
    ///
    /// Only tasks currently PENDING or RUNNING may be paused; descendants in
    /// terminal states are left untouched by the status filter in the UPDATE.
    ///
    /// # Errors
    /// Fails if the task does not exist or is not in a pausable status.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Only live tasks can be paused.
        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Recursive CTE walks the task tree rooted at task_id; task_id is a
        // Uuid, so the interpolation cannot inject SQL.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('PENDING', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }
699
    /// Resumes a PAUSED task: the root goes back to RUNNING and all PAUSED
    /// descendants go back to PENDING.
    ///
    /// NOTE(review): the root is set to RUNNING unconditionally, even if it
    /// was PENDING when paused — confirm this asymmetry with descendants
    /// (which resume to PENDING) is intended.
    ///
    /// # Errors
    /// Fails if the task does not exist or is not PAUSED.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Root task returns to RUNNING via the ORM.
        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // Descendants (children downward) return to PENDING; the CTE starts
        // at the children because the root was handled above.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
737
    /// Cancels a task and, recursively, all non-terminal descendants,
    /// stamping `completed_at` on every row it transitions.
    ///
    /// # Errors
    /// Fails if the task does not exist or is already in a terminal status
    /// (COMPLETED, FAILED, CANCELLED).
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Terminal tasks cannot be cancelled.
        match model.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        // Recursive CTE walks the subtree; only non-terminal rows are
        // touched. task_id is a Uuid, so interpolation is injection-safe.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
             AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }
775
776 pub async fn list(
784 db: &DatabaseConnection,
785 query: TaskQuery,
786 ) -> Result<Vec<TaskSummary>, DurableError> {
787 let mut select = Task::find();
788
789 if let Some(status) = &query.status {
791 select = select.filter(TaskColumn::Status.eq(status.to_string()));
792 }
793 if let Some(kind) = &query.kind {
794 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
795 }
796 if let Some(parent_id) = query.parent_id {
797 select = select.filter(TaskColumn::ParentId.eq(parent_id));
798 }
799 if query.root_only {
800 select = select.filter(TaskColumn::ParentId.is_null());
801 }
802 if let Some(name) = &query.name {
803 select = select.filter(TaskColumn::Name.eq(name.as_str()));
804 }
805 if let Some(queue) = &query.queue_name {
806 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
807 }
808
809 let (col, order) = match query.sort {
811 TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
812 TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
813 TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
814 TaskSort::Name(ord) => (TaskColumn::Name, ord),
815 TaskSort::Status(ord) => (TaskColumn::Status, ord),
816 };
817 select = select.order_by(col, order);
818
819 if let Some(offset) = query.offset {
821 select = select.offset(offset);
822 }
823 if let Some(limit) = query.limit {
824 select = select.limit(limit);
825 }
826
827 let models = select.all(db).await?;
828
829 Ok(models.into_iter().map(TaskSummary::from).collect())
830 }
831
832 pub async fn count(
834 db: &DatabaseConnection,
835 query: TaskQuery,
836 ) -> Result<u64, DurableError> {
837 let mut select = Task::find();
838
839 if let Some(status) = &query.status {
840 select = select.filter(TaskColumn::Status.eq(status.to_string()));
841 }
842 if let Some(kind) = &query.kind {
843 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
844 }
845 if let Some(parent_id) = query.parent_id {
846 select = select.filter(TaskColumn::ParentId.eq(parent_id));
847 }
848 if query.root_only {
849 select = select.filter(TaskColumn::ParentId.is_null());
850 }
851 if let Some(name) = &query.name {
852 select = select.filter(TaskColumn::Name.eq(name.as_str()));
853 }
854 if let Some(queue) = &query.queue_name {
855 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
856 }
857
858 let count = select.count(db).await?;
859 Ok(count)
860 }
861
    /// Borrow the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// The id of the task row this context drives.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Claims and returns the next step sequence number (post-increment).
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
875}
876
877#[allow(clippy::too_many_arguments)]
898async fn find_or_create_task(
899 db: &impl ConnectionTrait,
900 parent_id: Option<Uuid>,
901 sequence: Option<i32>,
902 name: &str,
903 kind: &str,
904 input: Option<serde_json::Value>,
905 lock: bool,
906 max_retries: Option<u32>,
907) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
908 let parent_eq = match parent_id {
909 Some(p) => format!("= '{p}'"),
910 None => "IS NULL".to_string(),
911 };
912 let parent_sql = match parent_id {
913 Some(p) => format!("'{p}'"),
914 None => "NULL".to_string(),
915 };
916
917 if lock {
918 let new_id = Uuid::new_v4();
932 let seq_sql = match sequence {
933 Some(s) => s.to_string(),
934 None => "NULL".to_string(),
935 };
936 let input_sql = match &input {
937 Some(v) => format!("'{}'", serde_json::to_string(v)?),
938 None => "NULL".to_string(),
939 };
940
941 let max_retries_sql = match max_retries {
942 Some(r) => r.to_string(),
943 None => "3".to_string(), };
945
946 let insert_sql = format!(
948 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
949 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
950 ON CONFLICT (parent_id, sequence) DO NOTHING"
951 );
952 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
953 .await?;
954
955 let lock_sql = format!(
957 "SELECT id, status::text, output FROM durable.task \
958 WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
959 FOR UPDATE SKIP LOCKED"
960 );
961 let row = db
962 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
963 .await?;
964
965 if let Some(row) = row {
966 let id: Uuid = row
967 .try_get_by_index(0)
968 .map_err(|e| DurableError::custom(e.to_string()))?;
969 let status: String = row
970 .try_get_by_index(1)
971 .map_err(|e| DurableError::custom(e.to_string()))?;
972 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
973
974 if status == TaskStatus::Completed.to_string() {
975 return Ok((id, output));
977 }
978 return Ok((id, None));
980 }
981
982 Err(DurableError::StepLocked(name.to_string()))
984 } else {
985 let mut query = Task::find().filter(TaskColumn::Name.eq(name));
989 query = match parent_id {
990 Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
991 None => query.filter(TaskColumn::ParentId.is_null()),
992 };
993 let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
996 let existing = query
997 .filter(TaskColumn::Status.is_not_in(status_exclusions))
998 .one(db)
999 .await?;
1000
1001 if let Some(model) = existing {
1002 if model.status == TaskStatus::Completed {
1003 return Ok((model.id, model.output));
1004 }
1005 return Ok((model.id, None));
1006 }
1007
1008 let id = Uuid::new_v4();
1010 let new_task = TaskActiveModel {
1011 id: Set(id),
1012 parent_id: Set(parent_id),
1013 sequence: Set(sequence),
1014 name: Set(name.to_string()),
1015 kind: Set(kind.to_string()),
1016 status: Set(TaskStatus::Pending),
1017 input: Set(input),
1018 max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1019 ..Default::default()
1020 };
1021 new_task.insert(db).await?;
1022
1023 Ok((id, None))
1024 }
1025}
1026
1027async fn get_output(
1028 db: &impl ConnectionTrait,
1029 task_id: Uuid,
1030) -> Result<Option<serde_json::Value>, DurableError> {
1031 let model = Task::find_by_id(task_id)
1032 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1033 .one(db)
1034 .await?;
1035
1036 Ok(model.and_then(|m| m.output))
1037}
1038
1039async fn get_status(
1040 db: &impl ConnectionTrait,
1041 task_id: Uuid,
1042) -> Result<Option<TaskStatus>, DurableError> {
1043 let model = Task::find_by_id(task_id).one(db).await?;
1044
1045 Ok(model.map(|m| m.status))
1046}
1047
1048async fn get_retry_info(
1050 db: &DatabaseConnection,
1051 task_id: Uuid,
1052) -> Result<(u32, u32), DurableError> {
1053 let model = Task::find_by_id(task_id).one(db).await?;
1054
1055 match model {
1056 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1057 None => Err(DurableError::custom(format!(
1058 "task {task_id} not found when reading retry info"
1059 ))),
1060 }
1061}
1062
1063async fn increment_retry_count(
1065 db: &DatabaseConnection,
1066 task_id: Uuid,
1067) -> Result<u32, DurableError> {
1068 let model = Task::find_by_id(task_id).one(db).await?;
1069
1070 match model {
1071 Some(m) => {
1072 let new_count = m.retry_count + 1;
1073 let mut active: TaskActiveModel = m.into();
1074 active.retry_count = Set(new_count);
1075 active.status = Set(TaskStatus::Pending);
1076 active.error = Set(None);
1077 active.completed_at = Set(None);
1078 active.update(db).await?;
1079 Ok(new_count as u32)
1080 }
1081 None => Err(DurableError::custom(format!(
1082 "task {task_id} not found when incrementing retry count"
1083 ))),
1084 }
1085}
1086
/// Writes a task's status in one UPDATE, stamping `started_at` on the first
/// transition (COALESCE keeps an existing value) and deriving the absolute
/// deadline when moving to RUNNING with a configured timeout.
///
/// `status` is this crate's enum (rendered via Display) and `task_id` a Uuid,
/// so the interpolation cannot inject SQL.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
         WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
         THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
         ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1109
1110async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1112 let status = get_status(db, task_id).await?;
1113 match status {
1114 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1115 Some(TaskStatus::Cancelled) => {
1116 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1117 }
1118 _ => Ok(()),
1119 }
1120}
1121
1122async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1124 let model = Task::find_by_id(task_id).one(db).await?;
1125
1126 if let Some(m) = model
1127 && let Some(deadline_ms) = m.deadline_epoch_ms
1128 {
1129 let now_ms = std::time::SystemTime::now()
1130 .duration_since(std::time::UNIX_EPOCH)
1131 .map(|d| d.as_millis() as i64)
1132 .unwrap_or(0);
1133 if now_ms > deadline_ms {
1134 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1135 }
1136 }
1137
1138 Ok(())
1139}
1140
1141async fn complete_task(
1142 db: &impl ConnectionTrait,
1143 task_id: Uuid,
1144 output: serde_json::Value,
1145) -> Result<(), DurableError> {
1146 let model = Task::find_by_id(task_id).one(db).await?;
1147
1148 if let Some(m) = model {
1149 let mut active: TaskActiveModel = m.into();
1150 active.status = Set(TaskStatus::Completed);
1151 active.output = Set(Some(output));
1152 active.completed_at = Set(Some(chrono::Utc::now().into()));
1153 active.update(db).await?;
1154 }
1155 Ok(())
1156}
1157
1158async fn fail_task(
1159 db: &impl ConnectionTrait,
1160 task_id: Uuid,
1161 error: &str,
1162) -> Result<(), DurableError> {
1163 let model = Task::find_by_id(task_id).one(db).await?;
1164
1165 if let Some(m) = model {
1166 let mut active: TaskActiveModel = m.into();
1167 active.status = Set(TaskStatus::Failed);
1168 active.error = Set(Some(error.to_string()));
1169 active.completed_at = Set(Some(chrono::Utc::now().into()));
1170 active.update(db).await?;
1171 }
1172 Ok(())
1173}
1174
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A write that succeeds immediately must run the closure exactly once
    /// and schedule no retries.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by success: the helper must keep
    /// retrying and ultimately return Ok after 3 total calls.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                // Fail the first two attempts, succeed on the third.
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// A permanently failing write must be attempted 1 + MAX_CHECKPOINT_RETRIES
    /// times and then return an error.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        // Initial attempt plus every retry.
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    /// When all attempts fail, the error surfaced to the caller must be the
    /// one from the *first* attempt, not a later retry's.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                // Each attempt produces a distinguishable error message.
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}