1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
/// SQL cast suffix for the Postgres enum `durable.task_status`; appended to
/// string literals in hand-built SQL (e.g. `'RUNNING'::durable.task_status`).
pub(crate) const TS: &str = "::durable.task_status";

/// Extra attempts made by `retry_db_write` after the first failure.
const MAX_CHECKPOINT_RETRIES: u32 = 3;
/// Base backoff (milliseconds) between checkpoint-write retries; doubled per attempt.
const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
33
34async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
40where
41 F: FnMut() -> Fut,
42 Fut: std::future::Future<Output = Result<(), DurableError>>,
43{
44 match f().await {
45 Ok(()) => Ok(()),
46 Err(first_err) => {
47 for i in 0..MAX_CHECKPOINT_RETRIES {
48 tokio::time::sleep(Duration::from_millis(
49 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
50 ))
51 .await;
52 if f().await.is_ok() {
53 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
54 return Ok(());
55 }
56 }
57 Err(first_err)
58 }
59 }
60}
61
/// Retry configuration for `Ctx::step_with_retry`.
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt.
    pub max_retries: u32,
    /// Backoff before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Multiplier applied to the backoff on each subsequent retry.
    pub backoff_multiplier: f64,
}
68
69impl RetryPolicy {
70 pub fn none() -> Self {
72 Self {
73 max_retries: 0,
74 initial_backoff: std::time::Duration::from_secs(0),
75 backoff_multiplier: 1.0,
76 }
77 }
78
79 pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
81 Self {
82 max_retries,
83 initial_backoff,
84 backoff_multiplier: 2.0,
85 }
86 }
87
88 pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
90 Self {
91 max_retries,
92 initial_backoff: backoff,
93 backoff_multiplier: 1.0,
94 }
95 }
96}
97
/// Sort key (with direction) for task listing queries.
pub enum TaskSort {
    /// Order by creation timestamp.
    CreatedAt(Order),
    /// Order by start timestamp.
    StartedAt(Order),
    /// Order by completion timestamp.
    CompletedAt(Order),
    /// Order by task name.
    Name(Order),
    /// Order by task status.
    Status(Order),
}
106
/// Filter/sort/pagination options for `Ctx::list` and `Ctx::count`.
pub struct TaskQuery {
    /// Only tasks with this status.
    pub status: Option<TaskStatus>,
    /// Only tasks of this kind (e.g. "WORKFLOW", "STEP").
    pub kind: Option<String>,
    /// Only direct children of this parent task.
    pub parent_id: Option<Uuid>,
    /// Only root tasks (tasks without a parent).
    pub root_only: bool,
    /// Only tasks with this exact name.
    pub name: Option<String>,
    /// Only tasks on this queue.
    pub queue_name: Option<String>,
    /// Sort column and direction.
    pub sort: TaskSort,
    /// Maximum number of rows to return.
    pub limit: Option<u64>,
    /// Number of rows to skip.
    pub offset: Option<u64>,
}
129
130impl Default for TaskQuery {
131 fn default() -> Self {
132 Self {
133 status: None,
134 kind: None,
135 parent_id: None,
136 root_only: false,
137 name: None,
138 queue_name: None,
139 sort: TaskSort::CreatedAt(Order::Desc),
140 limit: None,
141 offset: None,
142 }
143 }
144}
145
146impl TaskQuery {
147 pub fn status(mut self, status: TaskStatus) -> Self {
149 self.status = Some(status);
150 self
151 }
152
153 pub fn kind(mut self, kind: &str) -> Self {
155 self.kind = Some(kind.to_string());
156 self
157 }
158
159 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
161 self.parent_id = Some(parent_id);
162 self
163 }
164
165 pub fn root_only(mut self, root_only: bool) -> Self {
167 self.root_only = root_only;
168 self
169 }
170
171 pub fn name(mut self, name: &str) -> Self {
173 self.name = Some(name.to_string());
174 self
175 }
176
177 pub fn queue_name(mut self, queue: &str) -> Self {
179 self.queue_name = Some(queue.to_string());
180 self
181 }
182
183 pub fn sort(mut self, sort: TaskSort) -> Self {
185 self.sort = sort;
186 self
187 }
188
189 pub fn limit(mut self, limit: u64) -> Self {
191 self.limit = Some(limit);
192 self
193 }
194
195 pub fn offset(mut self, offset: u64) -> Self {
197 self.offset = Some(offset);
198 self
199 }
200}
201
/// API-facing projection of a task row, serializable for listings and UIs.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    pub id: Uuid,
    /// `None` for root workflows.
    pub parent_id: Option<Uuid>,
    pub name: String,
    /// Registered handler name, when the workflow was started with one.
    pub handler: Option<String>,
    pub status: TaskStatus,
    /// Task kind, e.g. "WORKFLOW", "STEP", "TRANSACTION".
    pub kind: String,
    pub input: Option<serde_json::Value>,
    pub output: Option<serde_json::Value>,
    /// Error message recorded when the task failed.
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
219
220impl From<durable_db::entity::task::Model> for TaskSummary {
221 fn from(m: durable_db::entity::task::Model) -> Self {
222 Self {
223 id: m.id,
224 parent_id: m.parent_id,
225 name: m.name,
226 handler: m.handler,
227 status: m.status,
228 kind: m.kind,
229 input: m.input,
230 output: m.output,
231 error: m.error,
232 queue_name: m.queue_name,
233 created_at: m.created_at,
234 started_at: m.started_at,
235 completed_at: m.completed_at,
236 }
237 }
238}
239
/// Handle to a durable workflow task, used to run checkpointed steps.
pub struct Ctx {
    /// Database handle used for all checkpoint reads and writes.
    db: DatabaseConnection,
    /// Id of the workflow task this context operates on.
    task_id: Uuid,
    /// Monotonic per-context step counter; paired with `parent_id` in the
    /// task table it makes step creation idempotent across replays.
    sequence: AtomicI32,
    /// Identifier of this executor process, if one is configured.
    executor_id: Option<String>,
}
251
252impl Ctx {
    /// Starts (or idempotently re-attaches to) a root workflow task.
    ///
    /// Convenience wrapper around [`Ctx::start_with_handler`] with no handler.
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        Self::start_with_handler(db, name, input, None).await
    }
276
    /// Starts a root workflow task, recording `handler` when given.
    ///
    /// Idempotent: if a root task with the same name is already RUNNING, this
    /// attaches to it (via [`Ctx::from_id`]) instead of inserting a new row.
    /// Otherwise a new RUNNING WORKFLOW row is inserted and a fresh context is
    /// returned with its sequence counter at 0.
    ///
    /// NOTE(review): `name`, the serialized input, `executor_id` and `handler`
    /// are interpolated directly into SQL strings; a value containing a single
    /// quote breaks the statement. Consider parameterized statements if any of
    /// these can be untrusted.
    pub async fn start_with_handler(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        handler: Option<&str>,
    ) -> Result<Self, DurableError> {
        // Idempotent start: re-attach to an existing RUNNING root task of the
        // same name rather than creating a duplicate workflow.
        let existing_sql = format!(
            "SELECT id FROM durable.task \
             WHERE name = '{}' AND parent_id IS NULL AND status = 'RUNNING'{TS} \
             LIMIT 1",
            name
        );
        if let Some(row) = db
            .query_one(Statement::from_string(DbBackend::Postgres, existing_sql))
            .await?
        {
            let existing_id: Uuid = row
                .try_get_by_index(0)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            tracing::info!(
                workflow = name,
                id = %existing_id,
                "idempotent start: attaching to existing RUNNING task"
            );
            return Self::from_id(db, existing_id).await;
        }

        let task_id = Uuid::new_v4();
        // Serialize input once; absent input is stored as JSON null.
        let input_json = match &input {
            Some(v) => serde_json::to_string(v)?,
            None => "null".to_string(),
        };

        let executor_id = crate::executor_id();

        // Optional columns are appended textually so the INSERT stays a single
        // statement regardless of which of executor_id/handler are present.
        let mut extra_cols = String::new();
        let mut extra_vals = String::new();

        if let Some(eid) = &executor_id {
            extra_cols.push_str(", executor_id");
            extra_vals.push_str(&format!(", '{eid}'"));
        }
        if let Some(h) = handler {
            extra_cols.push_str(", handler");
            extra_vals.push_str(&format!(", '{h}'"));
        }

        let txn = db.begin().await?;
        let sql = format!(
            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{extra_cols}) \
             VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING'{TS}, '{input_json}', now(){extra_vals})"
        );
        txn.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        txn.commit().await?;

        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
            executor_id,
        })
    }
348
    /// Builds a context for an existing task id.
    ///
    /// Fails if the task does not exist. If this process has an executor id
    /// configured, it claims the task by writing that id onto the row. The
    /// in-memory sequence counter restarts at 0; replay safety relies on
    /// steps being looked up by (parent_id, sequence) rather than the counter
    /// surviving restarts.
    pub async fn from_id(
        db: &DatabaseConnection,
        task_id: Uuid,
    ) -> Result<Self, DurableError> {
        // Existence check only; the row itself is not otherwise needed.
        let model = Task::find_by_id(task_id).one(db).await?;
        let _model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        let executor_id = crate::executor_id();
        if let Some(eid) = &executor_id {
            db.execute(Statement::from_string(
                DbBackend::Postgres,
                format!(
                    "UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"
                ),
            ))
            .await?;
        }

        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
            executor_id,
        })
    }
387
    /// Runs `f` as a checkpointed step named `name`.
    ///
    /// On replay (the step row is already COMPLETED) the saved output is
    /// deserialized and returned without running `f`. Otherwise the step row
    /// is created/locked, marked RUNNING, `f` is executed, and the result
    /// (or error) is written back with retried writes.
    ///
    /// # Errors
    /// Propagates pause/cancel (`check_status`), deadline (`check_deadline`),
    /// lock contention (`StepLocked` from `find_or_create_task`), database,
    /// and serialization errors, plus whatever `f` returns.
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        check_deadline(&self.db, self.task_id).await?;

        // Create/claim the step row inside a transaction so the row lock
        // (lock = true path) is held until we commit.
        let txn = self.db.begin().await?;
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // Replay: a completed step returns its checkpointed output directly.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        txn.commit().await?;
        // The user closure runs outside the transaction, after the RUNNING
        // marker is committed (so other executors see the step as claimed).
        let result = f().await;

        match result {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&self.db, step_id, json.clone())).await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                let err_msg = e.to_string();
                // NOTE(review): if this write fails after retries, the DB
                // error replaces the original step error `e`.
                retry_db_write(|| fail_task(&self.db, step_id, &err_msg)).await?;
                Err(e)
            }
        }
    }
455
    /// Runs `f` as a checkpointed step whose work and checkpoint commit in
    /// the *same* database transaction.
    ///
    /// On replay the saved output is returned without running `f`. Otherwise
    /// a transaction is opened, `f` receives it, and the step's COMPLETED
    /// checkpoint commits atomically with `f`'s own writes. On error the
    /// transaction is dropped (rolling back `f`'s writes — presumed
    /// rollback-on-drop semantics, confirm against sea_orm) and the step is
    /// marked FAILED outside the transaction.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // Unlocked lookup path: matched by (parent, name), no row lock taken.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                // Checkpoint and user writes commit together.
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the uncommitted transaction discards `f`'s writes.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
524
    /// Creates (or re-attaches to) a child workflow under this task.
    ///
    /// The child gets its own `Ctx` with an independent sequence counter;
    /// its row is created as kind WORKFLOW and immediately marked RUNNING.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // Saved output is ignored: a child workflow is re-entered rather than
        // replayed as a value.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
            executor_id: self.executor_id.clone(),
        })
    }
568
569 pub async fn is_completed(&self) -> Result<bool, DurableError> {
571 let status = get_status(&self.db, self.task_id).await?;
572 Ok(status == Some(TaskStatus::Completed))
573 }
574
575 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
577 match get_output(&self.db, self.task_id).await? {
578 Some(val) => Ok(Some(serde_json::from_value(val)?)),
579 None => Ok(None),
580 }
581 }
582
583 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
585 let json = serde_json::to_value(output)?;
586 let db = &self.db;
587 let task_id = self.task_id;
588 retry_db_write(|| complete_task(db, task_id, json.clone())).await
589 }
590
    /// Runs a step with application-level retries governed by `policy`.
    ///
    /// Like [`Ctx::step`], the result is checkpointed and replayed on
    /// re-execution. Unlike `step`, this uses the unlocked lookup path and
    /// persists `retry_count` on the step row, so retry progress survives a
    /// process crash. While attempts remain, each failure resets the row to
    /// PENDING and sleeps the policy's backoff; after the final failure the
    /// step is marked FAILED and the last error is returned.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // Replay: a completed step returns its checkpointed output directly.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Resume the persisted retry budget (survives process restarts).
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Bail out promptly if the workflow was paused/cancelled mid-retry.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the new count and reset the row to PENDING.
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // Backoff = initial * multiplier^(retry-1); the factor
                        // is clamped to >= 1.0 so the delay never shrinks.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
700
701 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
703 let db = &self.db;
704 let task_id = self.task_id;
705 retry_db_write(|| fail_task(db, task_id, error)).await
706 }
707
    /// Marks an arbitrary task FAILED without needing a `Ctx` (retried write).
    pub async fn fail_by_id(
        db: &DatabaseConnection,
        task_id: Uuid,
        error: &str,
    ) -> Result<(), DurableError> {
        retry_db_write(|| fail_task(db, task_id, error)).await
    }
716
    /// Sets a timeout (in milliseconds) for this task.
    ///
    /// If the task is already RUNNING with a recorded start time, the absolute
    /// deadline (`deadline_epoch_ms`) is derived immediately from
    /// `started_at + timeout_ms`; otherwise only `timeout_ms` is stored and
    /// the deadline is computed on the next RUNNING transition (see
    /// `set_status`).
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
             WHEN status = 'RUNNING'{TS} AND started_at IS NOT NULL \
             THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
             ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
741
    /// Starts a root workflow and immediately applies a timeout.
    ///
    /// NOTE(review): not atomic — if `set_timeout` fails, the task has already
    /// been started without a deadline.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
755
    /// Pauses a PENDING or RUNNING workflow together with its whole subtree.
    ///
    /// Only PENDING/RUNNING descendants are flipped to PAUSED; terminal rows
    /// are left untouched. Rejected for tasks in any other state.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Recursive CTE gathers the task and every transitive descendant,
        // then pauses the live ones in a single statement.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED'{TS} \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('PENDING'{TS}, 'RUNNING'{TS})"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }
794
    /// Resumes a PAUSED workflow.
    ///
    /// The task itself goes back to RUNNING; every PAUSED descendant is reset
    /// to PENDING so steps re-enter the normal execution path.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // Walk the subtree and wake up any descendants paused alongside it.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING'{TS} \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'{TS}"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
832
833 pub async fn resume_failed(
846 db: &DatabaseConnection,
847 task_id: Uuid,
848 ) -> Result<Option<String>, DurableError> {
849 let model = Task::find_by_id(task_id).one(db).await?;
851 let model =
852 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
853
854 if model.status != TaskStatus::Failed {
855 return Err(DurableError::custom(format!(
856 "cannot resume task in {} status (must be FAILED)",
857 model.status
858 )));
859 }
860
861 if model.kind != "WORKFLOW" {
862 return Err(DurableError::custom(format!(
863 "cannot resume a {}, only workflows can be resumed",
864 model.kind
865 )));
866 }
867
868 if model.recovery_count >= model.max_recovery_attempts {
870 return Err(DurableError::MaxRecoveryExceeded(task_id.to_string()));
871 }
872
873 let executor_id = crate::executor_id().unwrap_or_default();
874
875 let reset_workflow_sql = format!(
877 "UPDATE durable.task \
878 SET status = 'RUNNING'{TS}, \
879 error = NULL, \
880 completed_at = NULL, \
881 started_at = now(), \
882 executor_id = '{executor_id}', \
883 recovery_count = recovery_count + 1 \
884 WHERE id = '{task_id}' AND status = 'FAILED'{TS}"
885 );
886 db.execute(Statement::from_string(DbBackend::Postgres, reset_workflow_sql))
887 .await?;
888
889 let reset_steps_sql = format!(
892 "WITH RECURSIVE descendants AS ( \
893 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
894 UNION ALL \
895 SELECT t.id FROM durable.task t \
896 INNER JOIN descendants d ON t.parent_id = d.id \
897 ) \
898 UPDATE durable.task \
899 SET status = 'PENDING'{TS}, \
900 error = NULL, \
901 retry_count = 0, \
902 completed_at = NULL, \
903 started_at = NULL \
904 WHERE id IN (SELECT id FROM descendants) \
905 AND status IN ('FAILED'{TS}, 'RUNNING'{TS})"
906 );
907 db.execute(Statement::from_string(DbBackend::Postgres, reset_steps_sql))
908 .await?;
909
910 tracing::info!(
911 %task_id,
912 recovery_count = model.recovery_count + 1,
913 "failed workflow resumed"
914 );
915
916 Ok(model.handler.clone())
917 }
918
    /// Cancels a workflow and its entire subtree.
    ///
    /// Rejected for tasks already in a terminal state (COMPLETED, FAILED,
    /// CANCELLED). All non-terminal descendants are marked CANCELLED with a
    /// completion timestamp in a single recursive statement.
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        // Recursive CTE covers the task plus every transitive descendant.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED'{TS}, completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
             AND status NOT IN ('COMPLETED'{TS}, 'FAILED'{TS}, 'CANCELLED'{TS})"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }
956
    /// Lists tasks matching `query`, with sorting and pagination applied.
    pub async fn list(
        db: &DatabaseConnection,
        query: TaskQuery,
    ) -> Result<Vec<TaskSummary>, DurableError> {
        let mut select = Task::find();

        // Apply each optional filter only when it was set on the query.
        if let Some(status) = &query.status {
            select = select.filter(TaskColumn::Status.eq(status.to_string()));
        }
        if let Some(kind) = &query.kind {
            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
        }
        if let Some(parent_id) = query.parent_id {
            select = select.filter(TaskColumn::ParentId.eq(parent_id));
        }
        if query.root_only {
            select = select.filter(TaskColumn::ParentId.is_null());
        }
        if let Some(name) = &query.name {
            select = select.filter(TaskColumn::Name.eq(name.as_str()));
        }
        if let Some(queue) = &query.queue_name {
            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
        }

        // Map the sort enum onto a concrete column + direction.
        let (col, order) = match query.sort {
            TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
            TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
            TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
            TaskSort::Name(ord) => (TaskColumn::Name, ord),
            TaskSort::Status(ord) => (TaskColumn::Status, ord),
        };
        select = select.order_by(col, order);

        if let Some(offset) = query.offset {
            select = select.offset(offset);
        }
        if let Some(limit) = query.limit {
            select = select.limit(limit);
        }

        let models = select.all(db).await?;

        Ok(models.into_iter().map(TaskSummary::from).collect())
    }
1012
    /// Counts tasks matching `query` (sort/limit/offset are ignored).
    ///
    /// NOTE(review): the filter-building here duplicates `Ctx::list`; a shared
    /// private helper would keep the two from drifting apart.
    pub async fn count(
        db: &DatabaseConnection,
        query: TaskQuery,
    ) -> Result<u64, DurableError> {
        let mut select = Task::find();

        if let Some(status) = &query.status {
            select = select.filter(TaskColumn::Status.eq(status.to_string()));
        }
        if let Some(kind) = &query.kind {
            select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
        }
        if let Some(parent_id) = query.parent_id {
            select = select.filter(TaskColumn::ParentId.eq(parent_id));
        }
        if query.root_only {
            select = select.filter(TaskColumn::ParentId.is_null());
        }
        if let Some(name) = &query.name {
            select = select.filter(TaskColumn::Name.eq(name.as_str()));
        }
        if let Some(queue) = &query.queue_name {
            select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
        }

        let count = select.count(db).await?;
        Ok(count)
    }
1042
    /// Returns the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }
1048
    /// Returns the id of the task this context operates on.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }
1052
    /// Returns the current step sequence number and advances the counter.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
1056
    /// Loads and deserializes this task's stored input.
    ///
    /// Errors if the task row is missing, the input column is NULL, or the
    /// JSON does not deserialize into `T`.
    pub async fn input<T: DeserializeOwned>(&self) -> Result<T, DurableError> {
        let row = self
            .db
            .query_one(Statement::from_string(
                DbBackend::Postgres,
                format!(
                    "SELECT input FROM durable.task WHERE id = '{}'",
                    self.task_id
                ),
            ))
            .await?
            .ok_or_else(|| {
                DurableError::custom(format!("task {} not found", self.task_id))
            })?;

        let input_json: Option<serde_json::Value> = row
            .try_get_by_index(0)
            .map_err(|e| DurableError::custom(e.to_string()))?;

        // A NULL input column is an error for this accessor, unlike `start`
        // which stores JSON null for absent input.
        let value = input_json.ok_or_else(|| {
            DurableError::custom(format!("task {} has no input", self.task_id))
        })?;

        serde_json::from_value(value)
            .map_err(|e| DurableError::custom(format!("failed to deserialize input: {e}")))
    }
1092}
1093
/// Finds the checkpoint row for a step under `parent_id`, creating it in
/// PENDING state if absent.
///
/// Returns `(task_id, saved_output)`; `saved_output` is `Some` only when the
/// step already COMPLETED, letting the caller replay without re-running.
///
/// Two lookup strategies:
/// * `lock == true`: race-safe path keyed on (parent_id, sequence) using
///   `INSERT ... ON CONFLICT DO NOTHING` plus `SELECT ... FOR UPDATE SKIP
///   LOCKED`; a row locked elsewhere (or currently RUNNING) yields
///   `StepLocked`.
/// * `lock == false`: best-effort ORM path keyed on (parent_id, name),
///   skipping CANCELLED/FAILED rows.
///   NOTE(review): this path matches by *name* rather than sequence, so two
///   live steps sharing a name under one parent would collide — confirm this
///   is intentional.
#[allow(clippy::too_many_arguments)]
async fn find_or_create_task(
    db: &impl ConnectionTrait,
    parent_id: Option<Uuid>,
    sequence: Option<i32>,
    name: &str,
    kind: &str,
    input: Option<serde_json::Value>,
    lock: bool,
    max_retries: Option<u32>,
) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
    // SQL fragments for the two spellings of a nullable parent comparison.
    let parent_eq = match parent_id {
        Some(p) => format!("= '{p}'"),
        None => "IS NULL".to_string(),
    };
    let parent_sql = match parent_id {
        Some(p) => format!("'{p}'"),
        None => "NULL".to_string(),
    };

    if lock {
        let new_id = Uuid::new_v4();
        let seq_sql = match sequence {
            Some(s) => s.to_string(),
            None => "NULL".to_string(),
        };
        let input_sql = match &input {
            Some(v) => format!("'{}'", serde_json::to_string(v)?),
            None => "NULL".to_string(),
        };

        // Default retry budget of 3 when the caller does not specify one.
        let max_retries_sql = match max_retries {
            Some(r) => r.to_string(),
            None => "3".to_string(),
        };

        // Idempotent insert: the (parent_id, sequence) conflict target means a
        // replay hits DO NOTHING and reuses the existing row.
        let insert_sql = format!(
            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING'{TS}, {input_sql}, {max_retries_sql}) \
             ON CONFLICT (parent_id, sequence) DO NOTHING"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
            .await?;

        // Lock the step row; SKIP LOCKED turns contention into a missing row,
        // which is reported as StepLocked below.
        // NOTE(review): if `sequence` is None this matches nothing
        // (`sequence = NULL`); all current lock=true callers pass Some(seq).
        let lock_sql = format!(
            "SELECT id, status::text, output FROM durable.task \
             WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
             FOR UPDATE SKIP LOCKED"
        );
        let row = db
            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
            .await?;

        if let Some(row) = row {
            let id: Uuid = row
                .try_get_by_index(0)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            let status: String = row
                .try_get_by_index(1)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();

            // Completed step: hand back the checkpointed output for replay.
            if status == TaskStatus::Completed.to_string() {
                return Ok((id, output));
            }
            // Running step: another session is executing it right now.
            if status == TaskStatus::Running.to_string() {
                return Err(DurableError::StepLocked(name.to_string()));
            }
            return Ok((id, None));
        }

        // No row came back: it exists but is locked by another session.
        Err(DurableError::StepLocked(name.to_string()))
    } else {
        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
        query = match parent_id {
            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
            None => query.filter(TaskColumn::ParentId.is_null()),
        };
        // Ignore dead rows so a retried step gets a fresh record.
        let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
        let existing = query
            .filter(TaskColumn::Status.is_not_in(status_exclusions))
            .one(db)
            .await?;

        if let Some(model) = existing {
            if model.status == TaskStatus::Completed {
                return Ok((model.id, model.output));
            }
            return Ok((model.id, None));
        }

        let id = Uuid::new_v4();
        let new_task = TaskActiveModel {
            id: Set(id),
            parent_id: Set(parent_id),
            sequence: Set(sequence),
            name: Set(name.to_string()),
            kind: Set(kind.to_string()),
            status: Set(TaskStatus::Pending),
            input: Set(input),
            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
            ..Default::default()
        };
        new_task.insert(db).await?;

        Ok((id, None))
    }
}
1251
1252async fn get_output(
1253 db: &impl ConnectionTrait,
1254 task_id: Uuid,
1255) -> Result<Option<serde_json::Value>, DurableError> {
1256 let model = Task::find_by_id(task_id)
1257 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1258 .one(db)
1259 .await?;
1260
1261 Ok(model.and_then(|m| m.output))
1262}
1263
1264async fn get_status(
1265 db: &impl ConnectionTrait,
1266 task_id: Uuid,
1267) -> Result<Option<TaskStatus>, DurableError> {
1268 let model = Task::find_by_id(task_id).one(db).await?;
1269
1270 Ok(model.map(|m| m.status))
1271}
1272
1273async fn get_retry_info(
1275 db: &DatabaseConnection,
1276 task_id: Uuid,
1277) -> Result<(u32, u32), DurableError> {
1278 let model = Task::find_by_id(task_id).one(db).await?;
1279
1280 match model {
1281 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1282 None => Err(DurableError::custom(format!(
1283 "task {task_id} not found when reading retry info"
1284 ))),
1285 }
1286}
1287
1288async fn increment_retry_count(
1290 db: &DatabaseConnection,
1291 task_id: Uuid,
1292) -> Result<u32, DurableError> {
1293 let model = Task::find_by_id(task_id).one(db).await?;
1294
1295 match model {
1296 Some(m) => {
1297 let new_count = m.retry_count + 1;
1298 let mut active: TaskActiveModel = m.into();
1299 active.retry_count = Set(new_count);
1300 active.status = Set(TaskStatus::Pending);
1301 active.error = Set(None);
1302 active.completed_at = Set(None);
1303 active.update(db).await?;
1304 Ok(new_count as u32)
1305 }
1306 None => Err(DurableError::custom(format!(
1307 "task {task_id} not found when incrementing retry count"
1308 ))),
1309 }
1310}
1311
/// Writes `status` onto the task row, stamping `started_at` on the first
/// transition and deriving an absolute deadline when entering RUNNING with a
/// timeout configured.
///
/// NOTE(review): the CASE arm compares the interpolated Display text against
/// the literal 'RUNNING' — it only fires if `TaskStatus::Running` displays
/// exactly as "RUNNING"; confirm against the enum's Display impl.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}'{TS}, \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
         WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
         THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
         ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1334
1335async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1337 let status = get_status(db, task_id).await?;
1338 match status {
1339 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1340 Some(TaskStatus::Cancelled) => {
1341 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1342 }
1343 _ => Ok(()),
1344 }
1345}
1346
1347async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1349 let model = Task::find_by_id(task_id).one(db).await?;
1350
1351 if let Some(m) = model
1352 && let Some(deadline_ms) = m.deadline_epoch_ms
1353 {
1354 let now_ms = std::time::SystemTime::now()
1355 .duration_since(std::time::UNIX_EPOCH)
1356 .map(|d| d.as_millis() as i64)
1357 .unwrap_or(0);
1358 if now_ms > deadline_ms {
1359 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1360 }
1361 }
1362
1363 Ok(())
1364}
1365
1366async fn complete_task(
1367 db: &impl ConnectionTrait,
1368 task_id: Uuid,
1369 output: serde_json::Value,
1370) -> Result<(), DurableError> {
1371 let model = Task::find_by_id(task_id).one(db).await?;
1372
1373 if let Some(m) = model {
1374 let mut active: TaskActiveModel = m.into();
1375 active.status = Set(TaskStatus::Completed);
1376 active.output = Set(Some(output));
1377 active.completed_at = Set(Some(chrono::Utc::now().into()));
1378 active.update(db).await?;
1379 }
1380 Ok(())
1381}
1382
1383async fn fail_task(
1384 db: &impl ConnectionTrait,
1385 task_id: Uuid,
1386 error: &str,
1387) -> Result<(), DurableError> {
1388 let model = Task::find_by_id(task_id).one(db).await?;
1389
1390 if let Some(m) = model {
1391 let mut active: TaskActiveModel = m.into();
1392 active.status = Set(TaskStatus::Failed);
1393 active.error = Set(Some(error.to_string()));
1394 active.completed_at = Set(Some(chrono::Utc::now().into()));
1395 active.update(db).await?;
1396 }
1397 Ok(())
1398}
1399
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // A write that succeeds immediately must be invoked exactly once.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Two transient failures followed by success: Ok after 3 total calls.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // A永-failing write is attempted 1 + MAX_CHECKPOINT_RETRIES times.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        // Initial attempt plus every retry.
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    // When all attempts fail, the error from the *first* attempt is returned.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}