1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32 F: FnMut() -> Fut,
33 Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35 match f().await {
36 Ok(()) => Ok(()),
37 Err(first_err) => {
38 for i in 0..MAX_CHECKPOINT_RETRIES {
39 tokio::time::sleep(Duration::from_millis(
40 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41 ))
42 .await;
43 if f().await.is_ok() {
44 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45 return Ok(());
46 }
47 }
48 Err(first_err)
49 }
50 }
51}
52
/// Retry configuration consumed by [`Ctx::step_with_retry`].
pub struct RetryPolicy {
    /// Maximum number of retries *after* the first attempt.
    pub max_retries: u32,
    /// Backoff slept before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor applied per subsequent retry (backoff = initial * mult^(n-1)).
    pub backoff_multiplier: f64,
}
59
60impl RetryPolicy {
61 pub fn none() -> Self {
63 Self {
64 max_retries: 0,
65 initial_backoff: std::time::Duration::from_secs(0),
66 backoff_multiplier: 1.0,
67 }
68 }
69
70 pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
72 Self {
73 max_retries,
74 initial_backoff,
75 backoff_multiplier: 2.0,
76 }
77 }
78
79 pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
81 Self {
82 max_retries,
83 initial_backoff: backoff,
84 backoff_multiplier: 1.0,
85 }
86 }
87}
88
/// Sort key (with direction) used by [`TaskQuery`] listings.
pub enum TaskSort {
    CreatedAt(Order),
    StartedAt(Order),
    CompletedAt(Order),
    Name(Order),
    Status(Order),
}
97
/// Declarative filter/sort/paging specification for [`Ctx::list`] and
/// [`Ctx::count`]. Build with `TaskQuery::default()` plus the chainable
/// setter methods.
pub struct TaskQuery {
    /// Only tasks with this status.
    pub status: Option<TaskStatus>,
    /// Only tasks of this kind (this file uses "WORKFLOW", "STEP", "TRANSACTION").
    pub kind: Option<String>,
    /// Only direct children of this parent task.
    pub parent_id: Option<Uuid>,
    /// Only root tasks (parent_id IS NULL).
    pub root_only: bool,
    /// Only tasks with this exact name.
    pub name: Option<String>,
    /// Only tasks on this queue.
    pub queue_name: Option<String>,
    /// Sort column and direction (ignored by `count`).
    pub sort: TaskSort,
    /// Maximum number of rows returned (ignored by `count`).
    pub limit: Option<u64>,
    /// Number of rows to skip (ignored by `count`).
    pub offset: Option<u64>,
}
120
121impl Default for TaskQuery {
122 fn default() -> Self {
123 Self {
124 status: None,
125 kind: None,
126 parent_id: None,
127 root_only: false,
128 name: None,
129 queue_name: None,
130 sort: TaskSort::CreatedAt(Order::Desc),
131 limit: None,
132 offset: None,
133 }
134 }
135}
136
137impl TaskQuery {
138 pub fn status(mut self, status: TaskStatus) -> Self {
140 self.status = Some(status);
141 self
142 }
143
144 pub fn kind(mut self, kind: &str) -> Self {
146 self.kind = Some(kind.to_string());
147 self
148 }
149
150 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152 self.parent_id = Some(parent_id);
153 self
154 }
155
156 pub fn root_only(mut self, root_only: bool) -> Self {
158 self.root_only = root_only;
159 self
160 }
161
162 pub fn name(mut self, name: &str) -> Self {
164 self.name = Some(name.to_string());
165 self
166 }
167
168 pub fn queue_name(mut self, queue: &str) -> Self {
170 self.queue_name = Some(queue.to_string());
171 self
172 }
173
174 pub fn sort(mut self, sort: TaskSort) -> Self {
176 self.sort = sort;
177 self
178 }
179
180 pub fn limit(mut self, limit: u64) -> Self {
182 self.limit = Some(limit);
183 self
184 }
185
186 pub fn offset(mut self, offset: u64) -> Self {
188 self.offset = Some(offset);
189 self
190 }
191}
192
/// Read-only projection of a task row, as returned by [`Ctx::list`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    pub id: Uuid,
    /// Parent task, or `None` for a root workflow.
    pub parent_id: Option<Uuid>,
    pub name: String,
    /// Registered handler name, when one was recorded at start.
    pub handler: Option<String>,
    pub status: TaskStatus,
    pub kind: String,
    pub input: Option<serde_json::Value>,
    pub output: Option<serde_json::Value>,
    /// Error message recorded when the task failed.
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
210
/// Straight field-for-field conversion from the entity model.
impl From<durable_db::entity::task::Model> for TaskSummary {
    fn from(m: durable_db::entity::task::Model) -> Self {
        Self {
            id: m.id,
            parent_id: m.parent_id,
            name: m.name,
            handler: m.handler,
            status: m.status,
            kind: m.kind,
            input: m.input,
            output: m.output,
            error: m.error,
            queue_name: m.queue_name,
            created_at: m.created_at,
            started_at: m.started_at,
            completed_at: m.completed_at,
        }
    }
}
230
/// Execution context for one durable workflow task.
///
/// Wraps a database connection plus the id of the task being driven, and
/// hands out monotonically increasing step sequence numbers that make step
/// execution idempotent across replays.
pub struct Ctx {
    db: DatabaseConnection,
    // Id of the workflow task this context drives.
    task_id: Uuid,
    // Next step sequence number; bumped once per step/transaction/child.
    sequence: AtomicI32,
    // Executor identity recorded on the task row, if configured.
    executor_id: Option<String>,
}
242
243impl Ctx {
    /// Starts (or idempotently attaches to) a root workflow named `name`.
    ///
    /// Equivalent to [`Ctx::start_with_handler`] with no handler recorded.
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        Self::start_with_handler(db, name, input, None).await
    }
267
268 pub async fn start_with_handler(
276 db: &DatabaseConnection,
277 name: &str,
278 input: Option<serde_json::Value>,
279 handler: Option<&str>,
280 ) -> Result<Self, DurableError> {
281 let existing_sql = format!(
283 "SELECT id FROM durable.task \
284 WHERE name = '{}' AND parent_id IS NULL AND status = 'RUNNING' \
285 LIMIT 1",
286 name
287 );
288 if let Some(row) = db
289 .query_one(Statement::from_string(DbBackend::Postgres, existing_sql))
290 .await?
291 {
292 let existing_id: Uuid = row
293 .try_get_by_index(0)
294 .map_err(|e| DurableError::custom(e.to_string()))?;
295 tracing::info!(
296 workflow = name,
297 id = %existing_id,
298 "idempotent start: attaching to existing RUNNING task"
299 );
300 return Self::from_id(db, existing_id).await;
301 }
302
303 let task_id = Uuid::new_v4();
304 let input_json = match &input {
305 Some(v) => serde_json::to_string(v)?,
306 None => "null".to_string(),
307 };
308
309 let executor_id = crate::executor_id();
310
311 let mut extra_cols = String::new();
312 let mut extra_vals = String::new();
313
314 if let Some(eid) = &executor_id {
315 extra_cols.push_str(", executor_id");
316 extra_vals.push_str(&format!(", '{eid}'"));
317 }
318 if let Some(h) = handler {
319 extra_cols.push_str(", handler");
320 extra_vals.push_str(&format!(", '{h}'"));
321 }
322
323 let txn = db.begin().await?;
324 let sql = format!(
325 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{extra_cols}) \
326 VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now(){extra_vals})"
327 );
328 txn.execute(Statement::from_string(DbBackend::Postgres, sql))
329 .await?;
330 txn.commit().await?;
331
332 Ok(Self {
333 db: db.clone(),
334 task_id,
335 sequence: AtomicI32::new(0),
336 executor_id,
337 })
338 }
339
340 pub async fn from_id(
347 db: &DatabaseConnection,
348 task_id: Uuid,
349 ) -> Result<Self, DurableError> {
350 let model = Task::find_by_id(task_id).one(db).await?;
352 let _model =
353 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
354
355 let executor_id = crate::executor_id();
357 if let Some(eid) = &executor_id {
358 db.execute(Statement::from_string(
359 DbBackend::Postgres,
360 format!(
361 "UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"
362 ),
363 ))
364 .await?;
365 }
366
367 Ok(Self {
372 db: db.clone(),
373 task_id,
374 sequence: AtomicI32::new(0),
375 executor_id,
376 })
377 }
378
    /// Executes (or replays) one durable step named `name` at the next
    /// sequence number.
    ///
    /// If a COMPLETED row for this (parent, sequence) already exists, its
    /// saved output is deserialized and returned without running `f`.
    /// Otherwise the step row is locked, marked RUNNING, `f` runs outside
    /// any transaction, and the result (or error) is checkpointed with
    /// retries via `retry_db_write`.
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Each step consumes one sequence slot, replay included.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Abort early if the workflow was paused/cancelled...
        check_status(&self.db, self.task_id).await?;

        // ...or its deadline has already passed.
        check_deadline(&self.db, self.task_id).await?;

        // Create-or-claim the step row inside a transaction; `lock = true`
        // takes a FOR UPDATE SKIP LOCKED lock so two executors cannot run
        // the same step concurrently.
        let txn = self.db.begin().await?;
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // Replay path: the step completed in a previous run.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // NOTE(review): retries happen inside the still-open transaction; if
        // the first attempt errored, the transaction may already be aborted
        // and retries cannot succeed — confirm the intended semantics.
        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        txn.commit().await?;
        // Run the user code with no transaction held.
        let result = f().await;

        match result {
            Ok(val) => {
                // Checkpoint the output; retried on transient DB errors.
                let json = serde_json::to_value(&val)?;
                retry_db_write(|| complete_task(&self.db, step_id, json.clone())).await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Record the failure, then surface the original error.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&self.db, step_id, &err_msg)).await?;
                Err(e)
            }
        }
    }
446
    /// Runs `f` as a durable step whose database work and checkpoint commit
    /// atomically: the step's COMPLETED status and output are written inside
    /// the same transaction `f` receives.
    ///
    /// On error the transaction is rolled back (by dropping it) and the step
    /// is marked FAILED on the main connection.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
            &'tx DatabaseTransaction,
        ) -> Pin<
            Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
        > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Refuse to run if the workflow was paused/cancelled.
        check_status(&self.db, self.task_id).await?;

        // `lock = false`: no FOR UPDATE claim here, unlike `step`.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // Replay path: the step committed in a previous run.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        // RUNNING marker participates in the same transaction as `f`.
        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                // Output checkpoint and user writes commit atomically.
                let json = serde_json::to_value(&val)?;
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the transaction rolls back everything `f` did
                // (and the RUNNING marker); the FAILED status is then
                // recorded outside the transaction.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
515
    /// Creates (or re-attaches to) a child workflow under this task at the
    /// next sequence number and returns a `Ctx` bound to it.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Refuse to spawn children of a paused/cancelled workflow.
        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // `lock = false`: an existing non-failed child with this name under
        // this parent is reused rather than re-created.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        // The child context gets its own sequence counter starting at 0.
        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
            executor_id: self.executor_id.clone(),
        })
    }
559
560 pub async fn is_completed(&self) -> Result<bool, DurableError> {
562 let status = get_status(&self.db, self.task_id).await?;
563 Ok(status == Some(TaskStatus::Completed))
564 }
565
566 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
568 match get_output(&self.db, self.task_id).await? {
569 Some(val) => Ok(Some(serde_json::from_value(val)?)),
570 None => Ok(None),
571 }
572 }
573
574 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
576 let json = serde_json::to_value(output)?;
577 let db = &self.db;
578 let task_id = self.task_id;
579 retry_db_write(|| complete_task(db, task_id, json.clone())).await
580 }
581
    /// Like [`Ctx::step`], but re-runs `f` according to `policy` before the
    /// step is marked FAILED.
    ///
    /// The retry counter is persisted on the step row, so retries survive a
    /// process restart. Backoff before the n-th retry is
    /// `initial_backoff * backoff_multiplier^(n-1)`, with the factor clamped
    /// to at least 1.0.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // NOTE(review): unlike `step`, this passes `lock = false`, so there
        // is no FOR UPDATE guard against concurrent executors — confirm that
        // is intended.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // Replay path: step already completed in an earlier run.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Resume from the persisted retry count (may be non-zero after a
        // crash). max_retries is read back from the row, not the policy.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Re-check pause/cancel before every attempt.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the bump so a crash mid-retry is not lost.
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            // multiplier^(retry_count - 1), never shrinking
                            // below the initial backoff.
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Out of budget: record FAILED and surface the error.
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
691
692 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
694 let db = &self.db;
695 let task_id = self.task_id;
696 retry_db_write(|| fail_task(db, task_id, error)).await
697 }
698
    /// Marks an arbitrary task FAILED without needing a `Ctx` instance,
    /// retrying the write on transient database errors.
    pub async fn fail_by_id(
        db: &DatabaseConnection,
        task_id: Uuid,
        error: &str,
    ) -> Result<(), DurableError> {
        retry_db_write(|| fail_task(db, task_id, error)).await
    }
707
    /// Sets `timeout_ms` on this task. If the task is already RUNNING (with
    /// a recorded `started_at`), its deadline is recomputed from that start
    /// time; otherwise the existing deadline is left untouched — `set_status`
    /// computes it when the task transitions to RUNNING.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        // Only i64/Uuid values are interpolated here, so the SQL is safe.
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
             WHEN status = 'RUNNING' AND started_at IS NOT NULL \
             THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
             ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
732
    /// Convenience wrapper: starts the workflow via [`Ctx::start`], then
    /// applies `timeout_ms` via [`Ctx::set_timeout`].
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
746
    /// Pauses a PENDING or RUNNING workflow and, recursively, every
    /// descendant task that is itself PENDING or RUNNING.
    ///
    /// # Errors
    /// Fails if the task does not exist or is in any other status.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Only live tasks can be paused.
        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Recursive CTE walks the whole subtree rooted at task_id (the root
        // itself included) and pauses every live task in it. Only the Uuid
        // is interpolated, so the SQL is safe.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('PENDING', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }
785
    /// Resumes a PAUSED workflow: the root goes back to RUNNING, while every
    /// PAUSED descendant is reset to PENDING so steps re-execute in order.
    ///
    /// # Errors
    /// Fails if the task does not exist or is not PAUSED.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // The root task itself becomes RUNNING again.
        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // Descendants (root excluded — the CTE seeds from parent_id) go back
        // to PENDING. Only the Uuid is interpolated, so the SQL is safe.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
823
    /// Restarts a FAILED root workflow: the workflow row goes back to
    /// RUNNING (bumping `recovery_count`), and every FAILED or RUNNING
    /// descendant step is reset to PENDING for re-execution.
    ///
    /// Returns the workflow's recorded handler name (if any), so the caller
    /// can re-dispatch it.
    ///
    /// # Errors
    /// Fails if the task is missing, not FAILED, not a WORKFLOW, or has
    /// exhausted `max_recovery_attempts`.
    pub async fn resume_failed(
        db: &DatabaseConnection,
        task_id: Uuid,
    ) -> Result<Option<String>, DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Failed {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be FAILED)",
                model.status
            )));
        }

        if model.kind != "WORKFLOW" {
            return Err(DurableError::custom(format!(
                "cannot resume a {}, only workflows can be resumed",
                model.kind
            )));
        }

        // Hard cap on recovery attempts, enforced per row.
        if model.recovery_count >= model.max_recovery_attempts {
            return Err(DurableError::MaxRecoveryExceeded(task_id.to_string()));
        }

        // NOTE(review): with no executor id configured this writes an empty
        // string into executor_id rather than NULL — confirm intended.
        let executor_id = crate::executor_id().unwrap_or_default();

        // The `AND status = 'FAILED'` guard makes the reset a no-op if some
        // other process already transitioned the row.
        let reset_workflow_sql = format!(
            "UPDATE durable.task \
             SET status = 'RUNNING', \
             error = NULL, \
             completed_at = NULL, \
             started_at = now(), \
             executor_id = '{executor_id}', \
             recovery_count = recovery_count + 1 \
             WHERE id = '{task_id}' AND status = 'FAILED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, reset_workflow_sql))
            .await?;

        // Reset descendant steps so the workflow replays them; COMPLETED
        // descendants keep their checkpoints.
        let reset_steps_sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task \
             SET status = 'PENDING', \
             error = NULL, \
             retry_count = 0, \
             completed_at = NULL, \
             started_at = NULL \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('FAILED', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, reset_steps_sql))
            .await?;

        tracing::info!(
            %task_id,
            recovery_count = model.recovery_count + 1,
            "failed workflow resumed"
        );

        Ok(model.handler.clone())
    }
909
    /// Cancels a workflow and every descendant that has not already reached
    /// a terminal status (COMPLETED / FAILED / CANCELLED).
    ///
    /// # Errors
    /// Fails if the task does not exist or is already terminal.
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        // Terminal tasks cannot be cancelled.
        match model.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        // Recursive CTE covers the root and all descendants; only non-terminal
        // rows are touched. Only the Uuid is interpolated, so the SQL is safe.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
             AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }
947
948 pub async fn list(
956 db: &DatabaseConnection,
957 query: TaskQuery,
958 ) -> Result<Vec<TaskSummary>, DurableError> {
959 let mut select = Task::find();
960
961 if let Some(status) = &query.status {
963 select = select.filter(TaskColumn::Status.eq(status.to_string()));
964 }
965 if let Some(kind) = &query.kind {
966 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
967 }
968 if let Some(parent_id) = query.parent_id {
969 select = select.filter(TaskColumn::ParentId.eq(parent_id));
970 }
971 if query.root_only {
972 select = select.filter(TaskColumn::ParentId.is_null());
973 }
974 if let Some(name) = &query.name {
975 select = select.filter(TaskColumn::Name.eq(name.as_str()));
976 }
977 if let Some(queue) = &query.queue_name {
978 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
979 }
980
981 let (col, order) = match query.sort {
983 TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
984 TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
985 TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
986 TaskSort::Name(ord) => (TaskColumn::Name, ord),
987 TaskSort::Status(ord) => (TaskColumn::Status, ord),
988 };
989 select = select.order_by(col, order);
990
991 if let Some(offset) = query.offset {
993 select = select.offset(offset);
994 }
995 if let Some(limit) = query.limit {
996 select = select.limit(limit);
997 }
998
999 let models = select.all(db).await?;
1000
1001 Ok(models.into_iter().map(TaskSummary::from).collect())
1002 }
1003
1004 pub async fn count(
1006 db: &DatabaseConnection,
1007 query: TaskQuery,
1008 ) -> Result<u64, DurableError> {
1009 let mut select = Task::find();
1010
1011 if let Some(status) = &query.status {
1012 select = select.filter(TaskColumn::Status.eq(status.to_string()));
1013 }
1014 if let Some(kind) = &query.kind {
1015 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
1016 }
1017 if let Some(parent_id) = query.parent_id {
1018 select = select.filter(TaskColumn::ParentId.eq(parent_id));
1019 }
1020 if query.root_only {
1021 select = select.filter(TaskColumn::ParentId.is_null());
1022 }
1023 if let Some(name) = &query.name {
1024 select = select.filter(TaskColumn::Name.eq(name.as_str()));
1025 }
1026 if let Some(queue) = &query.queue_name {
1027 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
1028 }
1029
1030 let count = select.count(db).await?;
1031 Ok(count)
1032 }
1033
    /// Returns the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// Returns the id of the task this context is bound to.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Returns the next step sequence number, advancing the internal counter.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
1047
    /// Reads and deserializes this task's stored input column.
    ///
    /// # Errors
    /// Fails if the task row is missing, the input column is SQL NULL, or
    /// deserialization into `T` fails.
    pub async fn input<T: DeserializeOwned>(&self) -> Result<T, DurableError> {
        // Only the Uuid is interpolated, so the SQL is safe.
        let row = self
            .db
            .query_one(Statement::from_string(
                DbBackend::Postgres,
                format!(
                    "SELECT input FROM durable.task WHERE id = '{}'",
                    self.task_id
                ),
            ))
            .await?
            .ok_or_else(|| {
                DurableError::custom(format!("task {} not found", self.task_id))
            })?;

        let input_json: Option<serde_json::Value> = row
            .try_get_by_index(0)
            .map_err(|e| DurableError::custom(e.to_string()))?;

        let value = input_json.ok_or_else(|| {
            DurableError::custom(format!("task {} has no input", self.task_id))
        })?;

        serde_json::from_value(value)
            .map_err(|e| DurableError::custom(format!("failed to deserialize input: {e}")))
    }
1083}
1084
/// Finds an existing task row for this logical step, or creates a fresh
/// PENDING one. Returns `(task_row_id, saved_output)`, where `saved_output`
/// is `Some` only when the row is already COMPLETED (the replay case).
///
/// Two modes:
/// * `lock = true` — insert keyed on `(parent_id, sequence)` with
///   `ON CONFLICT DO NOTHING`, then take a `FOR UPDATE SKIP LOCKED` row
///   lock. A row held by another executor yields `StepLocked`.
/// * `lock = false` — plain lookup by `(parent_id, name)` excluding
///   CANCELLED/FAILED rows, creating via the entity API when absent.
///
/// NOTE(review): in the lock path, `name` and the serialized `input` are
/// interpolated into the INSERT string — values containing a single quote
/// will break the statement; consider binding them as parameters like the
/// other writes.
#[allow(clippy::too_many_arguments)]
async fn find_or_create_task(
    db: &impl ConnectionTrait,
    parent_id: Option<Uuid>,
    sequence: Option<i32>,
    name: &str,
    kind: &str,
    input: Option<serde_json::Value>,
    lock: bool,
    max_retries: Option<u32>,
) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
    // SQL fragments for the two spellings of a nullable parent predicate.
    let parent_eq = match parent_id {
        Some(p) => format!("= '{p}'"),
        None => "IS NULL".to_string(),
    };
    let parent_sql = match parent_id {
        Some(p) => format!("'{p}'"),
        None => "NULL".to_string(),
    };

    if lock {
        let new_id = Uuid::new_v4();
        let seq_sql = match sequence {
            Some(s) => s.to_string(),
            None => "NULL".to_string(),
        };
        let input_sql = match &input {
            Some(v) => format!("'{}'", serde_json::to_string(v)?),
            None => "NULL".to_string(),
        };

        // Default retry budget of 3 when the caller gives none.
        let max_retries_sql = match max_retries {
            Some(r) => r.to_string(),
            None => "3".to_string(),
        };

        // Idempotent insert: relies on a unique constraint over
        // (parent_id, sequence) — assumed to exist; TODO confirm schema.
        let insert_sql = format!(
            "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
             VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
             ON CONFLICT (parent_id, sequence) DO NOTHING"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
            .await?;

        // SKIP LOCKED: if another executor holds the row lock, we get no
        // row back and report the step as locked instead of blocking.
        let lock_sql = format!(
            "SELECT id, status::text, output FROM durable.task \
             WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
             FOR UPDATE SKIP LOCKED"
        );
        let row = db
            .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
            .await?;

        if let Some(row) = row {
            let id: Uuid = row
                .try_get_by_index(0)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            let status: String = row
                .try_get_by_index(1)
                .map_err(|e| DurableError::custom(e.to_string()))?;
            let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();

            // COMPLETED: hand the checkpointed output back for replay.
            if status == TaskStatus::Completed.to_string() {
                return Ok((id, output));
            }
            // RUNNING: someone else committed a RUNNING marker; refuse.
            if status == TaskStatus::Running.to_string() {
                return Err(DurableError::StepLocked(name.to_string()));
            }
            return Ok((id, None));
        }

        // No row visible => it is locked by a concurrent transaction.
        Err(DurableError::StepLocked(name.to_string()))
    } else {
        // Lookup by (parent, name); CANCELLED/FAILED rows are ignored so a
        // failed step gets a fresh row on re-execution.
        let mut query = Task::find().filter(TaskColumn::Name.eq(name));
        query = match parent_id {
            Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
            None => query.filter(TaskColumn::ParentId.is_null()),
        };
        let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
        let existing = query
            .filter(TaskColumn::Status.is_not_in(status_exclusions))
            .one(db)
            .await?;

        if let Some(model) = existing {
            if model.status == TaskStatus::Completed {
                return Ok((model.id, model.output));
            }
            return Ok((model.id, None));
        }

        let id = Uuid::new_v4();
        let new_task = TaskActiveModel {
            id: Set(id),
            parent_id: Set(parent_id),
            sequence: Set(sequence),
            name: Set(name.to_string()),
            kind: Set(kind.to_string()),
            status: Set(TaskStatus::Pending),
            input: Set(input),
            // Same default retry budget as the lock path.
            max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
            ..Default::default()
        };
        new_task.insert(db).await?;

        Ok((id, None))
    }
}
1242
1243async fn get_output(
1244 db: &impl ConnectionTrait,
1245 task_id: Uuid,
1246) -> Result<Option<serde_json::Value>, DurableError> {
1247 let model = Task::find_by_id(task_id)
1248 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1249 .one(db)
1250 .await?;
1251
1252 Ok(model.and_then(|m| m.output))
1253}
1254
1255async fn get_status(
1256 db: &impl ConnectionTrait,
1257 task_id: Uuid,
1258) -> Result<Option<TaskStatus>, DurableError> {
1259 let model = Task::find_by_id(task_id).one(db).await?;
1260
1261 Ok(model.map(|m| m.status))
1262}
1263
1264async fn get_retry_info(
1266 db: &DatabaseConnection,
1267 task_id: Uuid,
1268) -> Result<(u32, u32), DurableError> {
1269 let model = Task::find_by_id(task_id).one(db).await?;
1270
1271 match model {
1272 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1273 None => Err(DurableError::custom(format!(
1274 "task {task_id} not found when reading retry info"
1275 ))),
1276 }
1277}
1278
1279async fn increment_retry_count(
1281 db: &DatabaseConnection,
1282 task_id: Uuid,
1283) -> Result<u32, DurableError> {
1284 let model = Task::find_by_id(task_id).one(db).await?;
1285
1286 match model {
1287 Some(m) => {
1288 let new_count = m.retry_count + 1;
1289 let mut active: TaskActiveModel = m.into();
1290 active.retry_count = Set(new_count);
1291 active.status = Set(TaskStatus::Pending);
1292 active.error = Set(None);
1293 active.completed_at = Set(None);
1294 active.update(db).await?;
1295 Ok(new_count as u32)
1296 }
1297 None => Err(DurableError::custom(format!(
1298 "task {task_id} not found when incrementing retry count"
1299 ))),
1300 }
1301}
1302
/// Writes `status` to the task row, stamping `started_at` on first use
/// (COALESCE keeps an existing value) and computing the deadline from
/// `timeout_ms` when the new status is RUNNING.
///
/// `status` comes from the `TaskStatus` enum's Display impl, so only known
/// literals reach the SQL text.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
         WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
         THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
         ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1325
1326async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1328 let status = get_status(db, task_id).await?;
1329 match status {
1330 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1331 Some(TaskStatus::Cancelled) => {
1332 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1333 }
1334 _ => Ok(()),
1335 }
1336}
1337
1338async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1340 let model = Task::find_by_id(task_id).one(db).await?;
1341
1342 if let Some(m) = model
1343 && let Some(deadline_ms) = m.deadline_epoch_ms
1344 {
1345 let now_ms = std::time::SystemTime::now()
1346 .duration_since(std::time::UNIX_EPOCH)
1347 .map(|d| d.as_millis() as i64)
1348 .unwrap_or(0);
1349 if now_ms > deadline_ms {
1350 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1351 }
1352 }
1353
1354 Ok(())
1355}
1356
1357async fn complete_task(
1358 db: &impl ConnectionTrait,
1359 task_id: Uuid,
1360 output: serde_json::Value,
1361) -> Result<(), DurableError> {
1362 let model = Task::find_by_id(task_id).one(db).await?;
1363
1364 if let Some(m) = model {
1365 let mut active: TaskActiveModel = m.into();
1366 active.status = Set(TaskStatus::Completed);
1367 active.output = Set(Some(output));
1368 active.completed_at = Set(Some(chrono::Utc::now().into()));
1369 active.update(db).await?;
1370 }
1371 Ok(())
1372}
1373
1374async fn fail_task(
1375 db: &impl ConnectionTrait,
1376 task_id: Uuid,
1377 error: &str,
1378) -> Result<(), DurableError> {
1379 let model = Task::find_by_id(task_id).one(db).await?;
1380
1381 if let Some(m) = model {
1382 let mut active: TaskActiveModel = m.into();
1383 active.status = Set(TaskStatus::Failed);
1384 active.error = Set(Some(error.to_string()));
1385 active.completed_at = Set(Some(chrono::Utc::now().into()));
1386 active.update(db).await?;
1387 }
1388 Ok(())
1389}
1390
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // A write that succeeds immediately must be called exactly once.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Two transient failures then success: succeeds on the third call.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // Permanent failure: the initial attempt plus MAX_CHECKPOINT_RETRIES
    // retries are made before giving up.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    // The error surfaced after exhaustion is the one from the FIRST attempt,
    // not from any later retry.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}