1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32 F: FnMut() -> Fut,
33 Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35 match f().await {
36 Ok(()) => Ok(()),
37 Err(first_err) => {
38 for i in 0..MAX_CHECKPOINT_RETRIES {
39 tokio::time::sleep(Duration::from_millis(
40 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41 ))
42 .await;
43 if f().await.is_ok() {
44 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45 return Ok(());
46 }
47 }
48 Err(first_err)
49 }
50 }
51}
52
/// Retry configuration for [`Ctx::step_with_retry`].
///
/// The delay before retry number `k` (1-based) is
/// `initial_backoff * max(backoff_multiplier^(k - 1), 1.0)`; a zero `initial_backoff`
/// means retries happen immediately.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of retries after the first attempt (0 = run once).
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: Duration,
    /// Factor applied to the delay for each subsequent retry.
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// A policy that never retries: the first failure is final.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Up to `max_retries` retries with a delay that doubles each time, starting at
    /// `initial_backoff`.
    pub fn exponential(max_retries: u32, initial_backoff: Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Up to `max_retries` retries, each preceded by the same `backoff` delay.
    pub fn fixed(max_retries: u32, backoff: Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
88
89pub enum TaskSort {
91 CreatedAt(Order),
92 StartedAt(Order),
93 CompletedAt(Order),
94 Name(Order),
95 Status(Order),
96}
97
98pub struct TaskQuery {
110 pub status: Option<TaskStatus>,
111 pub kind: Option<String>,
112 pub parent_id: Option<Uuid>,
113 pub root_only: bool,
114 pub name: Option<String>,
115 pub queue_name: Option<String>,
116 pub sort: TaskSort,
117 pub limit: Option<u64>,
118 pub offset: Option<u64>,
119}
120
121impl Default for TaskQuery {
122 fn default() -> Self {
123 Self {
124 status: None,
125 kind: None,
126 parent_id: None,
127 root_only: false,
128 name: None,
129 queue_name: None,
130 sort: TaskSort::CreatedAt(Order::Desc),
131 limit: None,
132 offset: None,
133 }
134 }
135}
136
137impl TaskQuery {
138 pub fn status(mut self, status: TaskStatus) -> Self {
140 self.status = Some(status);
141 self
142 }
143
144 pub fn kind(mut self, kind: &str) -> Self {
146 self.kind = Some(kind.to_string());
147 self
148 }
149
150 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152 self.parent_id = Some(parent_id);
153 self
154 }
155
156 pub fn root_only(mut self, root_only: bool) -> Self {
158 self.root_only = root_only;
159 self
160 }
161
162 pub fn name(mut self, name: &str) -> Self {
164 self.name = Some(name.to_string());
165 self
166 }
167
168 pub fn queue_name(mut self, queue: &str) -> Self {
170 self.queue_name = Some(queue.to_string());
171 self
172 }
173
174 pub fn sort(mut self, sort: TaskSort) -> Self {
176 self.sort = sort;
177 self
178 }
179
180 pub fn limit(mut self, limit: u64) -> Self {
182 self.limit = Some(limit);
183 self
184 }
185
186 pub fn offset(mut self, offset: u64) -> Self {
188 self.offset = Some(offset);
189 self
190 }
191}
192
193#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
195pub struct TaskSummary {
196 pub id: Uuid,
197 pub parent_id: Option<Uuid>,
198 pub name: String,
199 pub handler: Option<String>,
200 pub status: TaskStatus,
201 pub kind: String,
202 pub input: Option<serde_json::Value>,
203 pub output: Option<serde_json::Value>,
204 pub error: Option<String>,
205 pub queue_name: Option<String>,
206 pub created_at: chrono::DateTime<chrono::FixedOffset>,
207 pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
208 pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
209}
210
211impl From<durable_db::entity::task::Model> for TaskSummary {
212 fn from(m: durable_db::entity::task::Model) -> Self {
213 Self {
214 id: m.id,
215 parent_id: m.parent_id,
216 name: m.name,
217 handler: m.handler,
218 status: m.status,
219 kind: m.kind,
220 input: m.input,
221 output: m.output,
222 error: m.error,
223 queue_name: m.queue_name,
224 created_at: m.created_at,
225 started_at: m.started_at,
226 completed_at: m.completed_at,
227 }
228 }
229}
230
231pub struct Ctx {
237 db: DatabaseConnection,
238 task_id: Uuid,
239 sequence: AtomicI32,
240 executor_id: Option<String>,
241}
242
243impl Ctx {
244 pub async fn start(
261 db: &DatabaseConnection,
262 name: &str,
263 input: Option<serde_json::Value>,
264 ) -> Result<Self, DurableError> {
265 Self::start_with_handler(db, name, input, None).await
266 }
267
268 pub async fn start_with_handler(
276 db: &DatabaseConnection,
277 name: &str,
278 input: Option<serde_json::Value>,
279 handler: Option<&str>,
280 ) -> Result<Self, DurableError> {
281 let existing_sql = format!(
283 "SELECT id FROM durable.task \
284 WHERE name = '{}' AND parent_id IS NULL AND status = 'RUNNING' \
285 LIMIT 1",
286 name
287 );
288 if let Some(row) = db
289 .query_one(Statement::from_string(DbBackend::Postgres, existing_sql))
290 .await?
291 {
292 let existing_id: Uuid = row
293 .try_get_by_index(0)
294 .map_err(|e| DurableError::custom(e.to_string()))?;
295 tracing::info!(
296 workflow = name,
297 id = %existing_id,
298 "idempotent start: attaching to existing RUNNING task"
299 );
300 return Self::from_id(db, existing_id).await;
301 }
302
303 let task_id = Uuid::new_v4();
304 let input_json = match &input {
305 Some(v) => serde_json::to_string(v)?,
306 None => "null".to_string(),
307 };
308
309 let executor_id = crate::executor_id();
310
311 let mut extra_cols = String::new();
312 let mut extra_vals = String::new();
313
314 if let Some(eid) = &executor_id {
315 extra_cols.push_str(", executor_id");
316 extra_vals.push_str(&format!(", '{eid}'"));
317 }
318 if let Some(h) = handler {
319 extra_cols.push_str(", handler");
320 extra_vals.push_str(&format!(", '{h}'"));
321 }
322
323 let txn = db.begin().await?;
324 let sql = format!(
325 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, started_at{extra_cols}) \
326 VALUES ('{task_id}', NULL, NULL, '{name}', 'WORKFLOW', 'RUNNING', '{input_json}', now(){extra_vals})"
327 );
328 txn.execute(Statement::from_string(DbBackend::Postgres, sql))
329 .await?;
330 txn.commit().await?;
331
332 Ok(Self {
333 db: db.clone(),
334 task_id,
335 sequence: AtomicI32::new(0),
336 executor_id,
337 })
338 }
339
340 pub async fn from_id(
347 db: &DatabaseConnection,
348 task_id: Uuid,
349 ) -> Result<Self, DurableError> {
350 let model = Task::find_by_id(task_id).one(db).await?;
352 let _model =
353 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
354
355 let executor_id = crate::executor_id();
357 if let Some(eid) = &executor_id {
358 db.execute(Statement::from_string(
359 DbBackend::Postgres,
360 format!(
361 "UPDATE durable.task SET executor_id = '{eid}' WHERE id = '{task_id}'"
362 ),
363 ))
364 .await?;
365 }
366
367 Ok(Self {
372 db: db.clone(),
373 task_id,
374 sequence: AtomicI32::new(0),
375 executor_id,
376 })
377 }
378
379 pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
388 where
389 T: Serialize + DeserializeOwned,
390 F: FnOnce() -> Fut,
391 Fut: std::future::Future<Output = Result<T, DurableError>>,
392 {
393 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
394
395 check_status(&self.db, self.task_id).await?;
397
398 check_deadline(&self.db, self.task_id).await?;
400
401 let txn = self.db.begin().await?;
403
404 let (step_id, saved_output) = find_or_create_task(
409 &txn,
410 Some(self.task_id),
411 Some(seq),
412 name,
413 "STEP",
414 None,
415 true,
416 Some(0),
417 )
418 .await?;
419
420 if let Some(output) = saved_output {
422 txn.commit().await?;
423 let val: T = serde_json::from_value(output)?;
424 tracing::debug!(step = name, seq, "replaying saved output");
425 return Ok(val);
426 }
427
428 retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
430 match f().await {
431 Ok(val) => {
432 let json = serde_json::to_value(&val)?;
433 retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
434 txn.commit().await?;
435 tracing::debug!(step = name, seq, "step completed");
436 Ok(val)
437 }
438 Err(e) => {
439 let err_msg = e.to_string();
440 retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
441 txn.commit().await?;
442 Err(e)
443 }
444 }
445 }
446
447 pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
459 where
460 T: Serialize + DeserializeOwned + Send,
461 F: for<'tx> FnOnce(
462 &'tx DatabaseTransaction,
463 ) -> Pin<
464 Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
465 > + Send,
466 {
467 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
468
469 check_status(&self.db, self.task_id).await?;
471
472 let (step_id, saved_output) = find_or_create_task(
475 &self.db,
476 Some(self.task_id),
477 Some(seq),
478 name,
479 "TRANSACTION",
480 None,
481 false,
482 None,
483 )
484 .await?;
485
486 if let Some(output) = saved_output {
488 let val: T = serde_json::from_value(output)?;
489 tracing::debug!(step = name, seq, "replaying saved transaction output");
490 return Ok(val);
491 }
492
493 let tx = self.db.begin().await?;
495
496 set_status(&tx, step_id, TaskStatus::Running).await?;
497
498 match f(&tx).await {
499 Ok(val) => {
500 let json = serde_json::to_value(&val)?;
501 complete_task(&tx, step_id, json).await?;
502 tx.commit().await?;
503 tracing::debug!(step = name, seq, "transaction step committed");
504 Ok(val)
505 }
506 Err(e) => {
507 drop(tx);
510 fail_task(&self.db, step_id, &e.to_string()).await?;
511 Err(e)
512 }
513 }
514 }
515
516 pub async fn child(
524 &self,
525 name: &str,
526 input: Option<serde_json::Value>,
527 ) -> Result<Self, DurableError> {
528 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
529
530 check_status(&self.db, self.task_id).await?;
532
533 let txn = self.db.begin().await?;
534 let (child_id, _saved) = find_or_create_task(
536 &txn,
537 Some(self.task_id),
538 Some(seq),
539 name,
540 "WORKFLOW",
541 input,
542 false,
543 None,
544 )
545 .await?;
546
547 retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
550 txn.commit().await?;
551
552 Ok(Self {
553 db: self.db.clone(),
554 task_id: child_id,
555 sequence: AtomicI32::new(0),
556 executor_id: self.executor_id.clone(),
557 })
558 }
559
560 pub async fn is_completed(&self) -> Result<bool, DurableError> {
562 let status = get_status(&self.db, self.task_id).await?;
563 Ok(status == Some(TaskStatus::Completed))
564 }
565
566 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
568 match get_output(&self.db, self.task_id).await? {
569 Some(val) => Ok(Some(serde_json::from_value(val)?)),
570 None => Ok(None),
571 }
572 }
573
574 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
576 let json = serde_json::to_value(output)?;
577 let db = &self.db;
578 let task_id = self.task_id;
579 retry_db_write(|| complete_task(db, task_id, json.clone())).await
580 }
581
582 pub async fn step_with_retry<T, F, Fut>(
596 &self,
597 name: &str,
598 policy: RetryPolicy,
599 f: F,
600 ) -> Result<T, DurableError>
601 where
602 T: Serialize + DeserializeOwned,
603 F: Fn() -> Fut,
604 Fut: std::future::Future<Output = Result<T, DurableError>>,
605 {
606 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
607
608 check_status(&self.db, self.task_id).await?;
610
611 let (step_id, saved_output) = find_or_create_task(
615 &self.db,
616 Some(self.task_id),
617 Some(seq),
618 name,
619 "STEP",
620 None,
621 false,
622 Some(policy.max_retries),
623 )
624 .await?;
625
626 if let Some(output) = saved_output {
628 let val: T = serde_json::from_value(output)?;
629 tracing::debug!(step = name, seq, "replaying saved output");
630 return Ok(val);
631 }
632
633 let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;
635
636 loop {
638 check_status(&self.db, self.task_id).await?;
640 set_status(&self.db, step_id, TaskStatus::Running).await?;
641 match f().await {
642 Ok(val) => {
643 let json = serde_json::to_value(&val)?;
644 complete_task(&self.db, step_id, json).await?;
645 tracing::debug!(step = name, seq, retry_count, "step completed");
646 return Ok(val);
647 }
648 Err(e) => {
649 if retry_count < max_retries {
650 retry_count = increment_retry_count(&self.db, step_id).await?;
652 tracing::debug!(
653 step = name,
654 seq,
655 retry_count,
656 max_retries,
657 "step failed, retrying"
658 );
659
660 let backoff = if policy.initial_backoff.is_zero() {
662 std::time::Duration::ZERO
663 } else {
664 let factor = policy
665 .backoff_multiplier
666 .powi((retry_count - 1) as i32)
667 .max(1.0);
668 let millis =
669 (policy.initial_backoff.as_millis() as f64 * factor) as u64;
670 std::time::Duration::from_millis(millis)
671 };
672
673 if !backoff.is_zero() {
674 tokio::time::sleep(backoff).await;
675 }
676 } else {
677 fail_task(&self.db, step_id, &e.to_string()).await?;
679 tracing::debug!(
680 step = name,
681 seq,
682 retry_count,
683 "step exhausted retries, marked FAILED"
684 );
685 return Err(e);
686 }
687 }
688 }
689 }
690 }
691
692 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
694 let db = &self.db;
695 let task_id = self.task_id;
696 retry_db_write(|| fail_task(db, task_id, error)).await
697 }
698
699 pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
707 let sql = format!(
708 "UPDATE durable.task \
709 SET timeout_ms = {timeout_ms}, \
710 deadline_epoch_ms = CASE \
711 WHEN status = 'RUNNING' AND started_at IS NOT NULL \
712 THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
713 ELSE deadline_epoch_ms \
714 END \
715 WHERE id = '{}'",
716 self.task_id
717 );
718 self.db
719 .execute(Statement::from_string(DbBackend::Postgres, sql))
720 .await?;
721 Ok(())
722 }
723
724 pub async fn start_with_timeout(
728 db: &DatabaseConnection,
729 name: &str,
730 input: Option<serde_json::Value>,
731 timeout_ms: i64,
732 ) -> Result<Self, DurableError> {
733 let ctx = Self::start(db, name, input).await?;
734 ctx.set_timeout(timeout_ms).await?;
735 Ok(ctx)
736 }
737
738 pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
745 let model = Task::find_by_id(task_id).one(db).await?;
746 let model =
747 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
748
749 match model.status {
750 TaskStatus::Pending | TaskStatus::Running => {}
751 status => {
752 return Err(DurableError::custom(format!(
753 "cannot pause task in {status} status"
754 )));
755 }
756 }
757
758 let sql = format!(
760 "WITH RECURSIVE descendants AS ( \
761 SELECT id FROM durable.task WHERE id = '{task_id}' \
762 UNION ALL \
763 SELECT t.id FROM durable.task t \
764 INNER JOIN descendants d ON t.parent_id = d.id \
765 ) \
766 UPDATE durable.task SET status = 'PAUSED' \
767 WHERE id IN (SELECT id FROM descendants) \
768 AND status IN ('PENDING', 'RUNNING')"
769 );
770 db.execute(Statement::from_string(DbBackend::Postgres, sql))
771 .await?;
772
773 tracing::info!(%task_id, "workflow paused");
774 Ok(())
775 }
776
777 pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
780 let model = Task::find_by_id(task_id).one(db).await?;
781 let model =
782 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
783
784 if model.status != TaskStatus::Paused {
785 return Err(DurableError::custom(format!(
786 "cannot resume task in {} status (must be PAUSED)",
787 model.status
788 )));
789 }
790
791 let mut active: TaskActiveModel = model.into();
793 active.status = Set(TaskStatus::Running);
794 active.update(db).await?;
795
796 let cascade_sql = format!(
798 "WITH RECURSIVE descendants AS ( \
799 SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
800 UNION ALL \
801 SELECT t.id FROM durable.task t \
802 INNER JOIN descendants d ON t.parent_id = d.id \
803 ) \
804 UPDATE durable.task SET status = 'PENDING' \
805 WHERE id IN (SELECT id FROM descendants) \
806 AND status = 'PAUSED'"
807 );
808 db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
809 .await?;
810
811 tracing::info!(%task_id, "workflow resumed");
812 Ok(())
813 }
814
815 pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
820 let model = Task::find_by_id(task_id).one(db).await?;
821 let model =
822 model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;
823
824 match model.status {
825 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
826 return Err(DurableError::custom(format!(
827 "cannot cancel task in {} status",
828 model.status
829 )));
830 }
831 _ => {}
832 }
833
834 let sql = format!(
836 "WITH RECURSIVE descendants AS ( \
837 SELECT id FROM durable.task WHERE id = '{task_id}' \
838 UNION ALL \
839 SELECT t.id FROM durable.task t \
840 INNER JOIN descendants d ON t.parent_id = d.id \
841 ) \
842 UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
843 WHERE id IN (SELECT id FROM descendants) \
844 AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
845 );
846 db.execute(Statement::from_string(DbBackend::Postgres, sql))
847 .await?;
848
849 tracing::info!(%task_id, "workflow cancelled");
850 Ok(())
851 }
852
853 pub async fn list(
861 db: &DatabaseConnection,
862 query: TaskQuery,
863 ) -> Result<Vec<TaskSummary>, DurableError> {
864 let mut select = Task::find();
865
866 if let Some(status) = &query.status {
868 select = select.filter(TaskColumn::Status.eq(status.to_string()));
869 }
870 if let Some(kind) = &query.kind {
871 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
872 }
873 if let Some(parent_id) = query.parent_id {
874 select = select.filter(TaskColumn::ParentId.eq(parent_id));
875 }
876 if query.root_only {
877 select = select.filter(TaskColumn::ParentId.is_null());
878 }
879 if let Some(name) = &query.name {
880 select = select.filter(TaskColumn::Name.eq(name.as_str()));
881 }
882 if let Some(queue) = &query.queue_name {
883 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
884 }
885
886 let (col, order) = match query.sort {
888 TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
889 TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
890 TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
891 TaskSort::Name(ord) => (TaskColumn::Name, ord),
892 TaskSort::Status(ord) => (TaskColumn::Status, ord),
893 };
894 select = select.order_by(col, order);
895
896 if let Some(offset) = query.offset {
898 select = select.offset(offset);
899 }
900 if let Some(limit) = query.limit {
901 select = select.limit(limit);
902 }
903
904 let models = select.all(db).await?;
905
906 Ok(models.into_iter().map(TaskSummary::from).collect())
907 }
908
909 pub async fn count(
911 db: &DatabaseConnection,
912 query: TaskQuery,
913 ) -> Result<u64, DurableError> {
914 let mut select = Task::find();
915
916 if let Some(status) = &query.status {
917 select = select.filter(TaskColumn::Status.eq(status.to_string()));
918 }
919 if let Some(kind) = &query.kind {
920 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
921 }
922 if let Some(parent_id) = query.parent_id {
923 select = select.filter(TaskColumn::ParentId.eq(parent_id));
924 }
925 if query.root_only {
926 select = select.filter(TaskColumn::ParentId.is_null());
927 }
928 if let Some(name) = &query.name {
929 select = select.filter(TaskColumn::Name.eq(name.as_str()));
930 }
931 if let Some(queue) = &query.queue_name {
932 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
933 }
934
935 let count = select.count(db).await?;
936 Ok(count)
937 }
938
939 pub fn db(&self) -> &DatabaseConnection {
942 &self.db
943 }
944
945 pub fn task_id(&self) -> Uuid {
946 self.task_id
947 }
948
949 pub fn next_sequence(&self) -> i32 {
950 self.sequence.fetch_add(1, Ordering::SeqCst)
951 }
952
953 pub async fn input<T: DeserializeOwned>(&self) -> Result<T, DurableError> {
963 let row = self
964 .db
965 .query_one(Statement::from_string(
966 DbBackend::Postgres,
967 format!(
968 "SELECT input FROM durable.task WHERE id = '{}'",
969 self.task_id
970 ),
971 ))
972 .await?
973 .ok_or_else(|| {
974 DurableError::custom(format!("task {} not found", self.task_id))
975 })?;
976
977 let input_json: Option<serde_json::Value> = row
978 .try_get_by_index(0)
979 .map_err(|e| DurableError::custom(e.to_string()))?;
980
981 let value = input_json.ok_or_else(|| {
982 DurableError::custom(format!("task {} has no input", self.task_id))
983 })?;
984
985 serde_json::from_value(value)
986 .map_err(|e| DurableError::custom(format!("failed to deserialize input: {e}")))
987 }
988}
989
990#[allow(clippy::too_many_arguments)]
1011async fn find_or_create_task(
1012 db: &impl ConnectionTrait,
1013 parent_id: Option<Uuid>,
1014 sequence: Option<i32>,
1015 name: &str,
1016 kind: &str,
1017 input: Option<serde_json::Value>,
1018 lock: bool,
1019 max_retries: Option<u32>,
1020) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
1021 let parent_eq = match parent_id {
1022 Some(p) => format!("= '{p}'"),
1023 None => "IS NULL".to_string(),
1024 };
1025 let parent_sql = match parent_id {
1026 Some(p) => format!("'{p}'"),
1027 None => "NULL".to_string(),
1028 };
1029
1030 if lock {
1031 let new_id = Uuid::new_v4();
1045 let seq_sql = match sequence {
1046 Some(s) => s.to_string(),
1047 None => "NULL".to_string(),
1048 };
1049 let input_sql = match &input {
1050 Some(v) => format!("'{}'", serde_json::to_string(v)?),
1051 None => "NULL".to_string(),
1052 };
1053
1054 let max_retries_sql = match max_retries {
1055 Some(r) => r.to_string(),
1056 None => "3".to_string(), };
1058
1059 let insert_sql = format!(
1061 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
1062 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
1063 ON CONFLICT (parent_id, sequence) DO NOTHING"
1064 );
1065 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
1066 .await?;
1067
1068 let lock_sql = format!(
1070 "SELECT id, status::text, output FROM durable.task \
1071 WHERE parent_id {parent_eq} AND sequence = {seq_sql} \
1072 FOR UPDATE SKIP LOCKED"
1073 );
1074 let row = db
1075 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
1076 .await?;
1077
1078 if let Some(row) = row {
1079 let id: Uuid = row
1080 .try_get_by_index(0)
1081 .map_err(|e| DurableError::custom(e.to_string()))?;
1082 let status: String = row
1083 .try_get_by_index(1)
1084 .map_err(|e| DurableError::custom(e.to_string()))?;
1085 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
1086
1087 if status == TaskStatus::Completed.to_string() {
1088 return Ok((id, output));
1090 }
1091 return Ok((id, None));
1093 }
1094
1095 Err(DurableError::StepLocked(name.to_string()))
1097 } else {
1098 let mut query = Task::find().filter(TaskColumn::Name.eq(name));
1102 query = match parent_id {
1103 Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
1104 None => query.filter(TaskColumn::ParentId.is_null()),
1105 };
1106 let status_exclusions = vec![TaskStatus::Cancelled, TaskStatus::Failed];
1109 let existing = query
1110 .filter(TaskColumn::Status.is_not_in(status_exclusions))
1111 .one(db)
1112 .await?;
1113
1114 if let Some(model) = existing {
1115 if model.status == TaskStatus::Completed {
1116 return Ok((model.id, model.output));
1117 }
1118 return Ok((model.id, None));
1119 }
1120
1121 let id = Uuid::new_v4();
1123 let new_task = TaskActiveModel {
1124 id: Set(id),
1125 parent_id: Set(parent_id),
1126 sequence: Set(sequence),
1127 name: Set(name.to_string()),
1128 kind: Set(kind.to_string()),
1129 status: Set(TaskStatus::Pending),
1130 input: Set(input),
1131 max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
1132 ..Default::default()
1133 };
1134 new_task.insert(db).await?;
1135
1136 Ok((id, None))
1137 }
1138}
1139
1140async fn get_output(
1141 db: &impl ConnectionTrait,
1142 task_id: Uuid,
1143) -> Result<Option<serde_json::Value>, DurableError> {
1144 let model = Task::find_by_id(task_id)
1145 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1146 .one(db)
1147 .await?;
1148
1149 Ok(model.and_then(|m| m.output))
1150}
1151
1152async fn get_status(
1153 db: &impl ConnectionTrait,
1154 task_id: Uuid,
1155) -> Result<Option<TaskStatus>, DurableError> {
1156 let model = Task::find_by_id(task_id).one(db).await?;
1157
1158 Ok(model.map(|m| m.status))
1159}
1160
1161async fn get_retry_info(
1163 db: &DatabaseConnection,
1164 task_id: Uuid,
1165) -> Result<(u32, u32), DurableError> {
1166 let model = Task::find_by_id(task_id).one(db).await?;
1167
1168 match model {
1169 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1170 None => Err(DurableError::custom(format!(
1171 "task {task_id} not found when reading retry info"
1172 ))),
1173 }
1174}
1175
1176async fn increment_retry_count(
1178 db: &DatabaseConnection,
1179 task_id: Uuid,
1180) -> Result<u32, DurableError> {
1181 let model = Task::find_by_id(task_id).one(db).await?;
1182
1183 match model {
1184 Some(m) => {
1185 let new_count = m.retry_count + 1;
1186 let mut active: TaskActiveModel = m.into();
1187 active.retry_count = Set(new_count);
1188 active.status = Set(TaskStatus::Pending);
1189 active.error = Set(None);
1190 active.completed_at = Set(None);
1191 active.update(db).await?;
1192 Ok(new_count as u32)
1193 }
1194 None => Err(DurableError::custom(format!(
1195 "task {task_id} not found when incrementing retry count"
1196 ))),
1197 }
1198}
1199
1200async fn set_status(
1203 db: &impl ConnectionTrait,
1204 task_id: Uuid,
1205 status: TaskStatus,
1206) -> Result<(), DurableError> {
1207 let sql = format!(
1208 "UPDATE durable.task \
1209 SET status = '{status}', \
1210 started_at = COALESCE(started_at, now()), \
1211 deadline_epoch_ms = CASE \
1212 WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
1213 THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
1214 ELSE deadline_epoch_ms \
1215 END \
1216 WHERE id = '{task_id}'"
1217 );
1218 db.execute(Statement::from_string(DbBackend::Postgres, sql))
1219 .await?;
1220 Ok(())
1221}
1222
1223async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1225 let status = get_status(db, task_id).await?;
1226 match status {
1227 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1228 Some(TaskStatus::Cancelled) => {
1229 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1230 }
1231 _ => Ok(()),
1232 }
1233}
1234
1235async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1237 let model = Task::find_by_id(task_id).one(db).await?;
1238
1239 if let Some(m) = model
1240 && let Some(deadline_ms) = m.deadline_epoch_ms
1241 {
1242 let now_ms = std::time::SystemTime::now()
1243 .duration_since(std::time::UNIX_EPOCH)
1244 .map(|d| d.as_millis() as i64)
1245 .unwrap_or(0);
1246 if now_ms > deadline_ms {
1247 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1248 }
1249 }
1250
1251 Ok(())
1252}
1253
1254async fn complete_task(
1255 db: &impl ConnectionTrait,
1256 task_id: Uuid,
1257 output: serde_json::Value,
1258) -> Result<(), DurableError> {
1259 let model = Task::find_by_id(task_id).one(db).await?;
1260
1261 if let Some(m) = model {
1262 let mut active: TaskActiveModel = m.into();
1263 active.status = Set(TaskStatus::Completed);
1264 active.output = Set(Some(output));
1265 active.completed_at = Set(Some(chrono::Utc::now().into()));
1266 active.update(db).await?;
1267 }
1268 Ok(())
1269}
1270
1271async fn fail_task(
1272 db: &impl ConnectionTrait,
1273 task_id: Uuid,
1274 error: &str,
1275) -> Result<(), DurableError> {
1276 let model = Task::find_by_id(task_id).one(db).await?;
1277
1278 if let Some(m) = model {
1279 let mut active: TaskActiveModel = m.into();
1280 active.status = Set(TaskStatus::Failed);
1281 active.error = Set(Some(error.to_string()));
1282 active.completed_at = Set(Some(chrono::Utc::now().into()));
1283 active.update(db).await?;
1284 }
1285 Ok(())
1286}
1287
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A write that succeeds immediately is attempted exactly once.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let calls = Arc::new(AtomicU32::new(0));
        let outcome = retry_db_write(|| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by a success: three attempts in total.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let calls = Arc::new(AtomicU32::new(0));
        let outcome = retry_db_write(|| {
            let calls = Arc::clone(&calls);
            async move {
                let attempt = calls.fetch_add(1, Ordering::SeqCst);
                if attempt >= 2 {
                    Ok(())
                } else {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                }
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A write that never succeeds is attempted once plus every retry, then errors.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let outcome = retry_db_write(|| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1 + MAX_CHECKPOINT_RETRIES);
    }

    /// When every attempt fails, the error surfaced is the one from the first attempt.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let outcome = retry_db_write(|| {
            let calls = Arc::clone(&calls);
            async move {
                let attempt = calls.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{attempt}"
                ))))
            }
        })
        .await;

        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}