1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32 F: FnMut() -> Fut,
33 Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35 match f().await {
36 Ok(()) => Ok(()),
37 Err(first_err) => {
38 for i in 0..MAX_CHECKPOINT_RETRIES {
39 tokio::time::sleep(Duration::from_millis(
40 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41 ))
42 .await;
43 if f().await.is_ok() {
44 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45 return Ok(());
46 }
47 }
48 Err(first_err)
49 }
50 }
51}
52
/// Retry configuration for retry-aware steps.
pub struct RetryPolicy {
    /// Maximum number of retries allowed after the first attempt fails.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor applied to the backoff after each successive retry.
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// Shared constructor used by the public policy builders.
    fn build(
        max_retries: u32,
        initial_backoff: std::time::Duration,
        backoff_multiplier: f64,
    ) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier,
        }
    }

    /// No retries at all: the first failure is final.
    pub fn none() -> Self {
        Self::build(0, std::time::Duration::from_secs(0), 1.0)
    }

    /// Backoff that doubles after every retry, starting at `initial_backoff`.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self::build(max_retries, initial_backoff, 2.0)
    }

    /// Constant backoff: the same `backoff` delay between every retry.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self::build(max_retries, backoff, 1.0)
    }
}
88
/// Sort key (with direction) for task listing queries.
pub enum TaskSort {
    /// Order by creation timestamp.
    CreatedAt(Order),
    /// Order by the time the task started running.
    StartedAt(Order),
    /// Order by completion timestamp.
    CompletedAt(Order),
    /// Order by task name.
    Name(Order),
    /// Order by task status.
    Status(Order),
}
97
/// Filter/sort/pagination parameters for listing and counting tasks.
///
/// All filters are optional; `Default` yields an unfiltered query sorted
/// newest-first by creation time.
pub struct TaskQuery {
    /// Only tasks in this status.
    pub status: Option<TaskStatus>,
    /// Only tasks of this kind (e.g. "WORKFLOW", "STEP").
    pub kind: Option<String>,
    /// Only direct children of this task.
    pub parent_id: Option<Uuid>,
    /// Only root tasks (no parent). Combines with the other filters.
    pub root_only: bool,
    /// Only tasks with this exact name.
    pub name: Option<String>,
    /// Only tasks on this queue.
    pub queue_name: Option<String>,
    /// Sort column and direction.
    pub sort: TaskSort,
    /// Maximum number of rows to return (list only).
    pub limit: Option<u64>,
    /// Number of rows to skip (list only).
    pub offset: Option<u64>,
}
120
impl Default for TaskQuery {
    /// Unfiltered query: no status/kind/name/queue constraints, includes
    /// child tasks, sorted newest-first by creation time, unpaginated.
    fn default() -> Self {
        Self {
            status: None,
            kind: None,
            parent_id: None,
            root_only: false,
            name: None,
            queue_name: None,
            sort: TaskSort::CreatedAt(Order::Desc),
            limit: None,
            offset: None,
        }
    }
}
136
137impl TaskQuery {
138 pub fn status(mut self, status: TaskStatus) -> Self {
140 self.status = Some(status);
141 self
142 }
143
144 pub fn kind(mut self, kind: &str) -> Self {
146 self.kind = Some(kind.to_string());
147 self
148 }
149
150 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152 self.parent_id = Some(parent_id);
153 self
154 }
155
156 pub fn root_only(mut self, root_only: bool) -> Self {
158 self.root_only = root_only;
159 self
160 }
161
162 pub fn name(mut self, name: &str) -> Self {
164 self.name = Some(name.to_string());
165 self
166 }
167
168 pub fn queue_name(mut self, queue: &str) -> Self {
170 self.queue_name = Some(queue.to_string());
171 self
172 }
173
174 pub fn sort(mut self, sort: TaskSort) -> Self {
176 self.sort = sort;
177 self
178 }
179
180 pub fn limit(mut self, limit: u64) -> Self {
182 self.limit = Some(limit);
183 self
184 }
185
186 pub fn offset(mut self, offset: u64) -> Self {
188 self.offset = Some(offset);
189 self
190 }
191}
192
/// Read-only snapshot of a task row, used by the listing API.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    pub id: Uuid,
    /// `None` for root workflows.
    pub parent_id: Option<Uuid>,
    pub name: String,
    pub status: TaskStatus,
    /// Task kind, e.g. "WORKFLOW" / "STEP" / "TRANSACTION".
    pub kind: String,
    pub input: Option<serde_json::Value>,
    pub output: Option<serde_json::Value>,
    /// Failure message, set when the task is FAILED.
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
209
impl From<durable_db::entity::task::Model> for TaskSummary {
    /// Straight field-for-field projection of a task row into the summary DTO.
    fn from(m: durable_db::entity::task::Model) -> Self {
        Self {
            id: m.id,
            parent_id: m.parent_id,
            name: m.name,
            status: m.status,
            kind: m.kind,
            input: m.input,
            output: m.output,
            error: m.error,
            queue_name: m.queue_name,
            created_at: m.created_at,
            started_at: m.started_at,
            completed_at: m.completed_at,
        }
    }
}
228
/// Execution handle for a durable workflow: wraps the database connection,
/// the workflow's task row id, and a monotonically increasing step sequence
/// used to give each step a stable position for replay.
pub struct Ctx {
    /// Cloned connection handle used for all checkpoint reads/writes.
    db: DatabaseConnection,
    /// Primary key of this workflow's row in `durable.task`.
    task_id: Uuid,
    /// Next step/child sequence number; incremented once per step call.
    sequence: AtomicI32,
}
239
240impl Ctx {
    /// Starts (or re-attaches to) a root workflow named `name`.
    ///
    /// Finds an existing root task with this name or inserts a new one, marks
    /// it RUNNING, and returns a context whose step sequence starts at 0.
    ///
    /// # Errors
    /// Any database error from the lookup/insert, status update, or commit.
    pub async fn start(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let txn = db.begin().await?;
        // Root workflow: no parent, no sequence, non-locking lookup path.
        let (task_id, _saved) =
            find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
        // NOTE(review): retrying inside an already-failed transaction is
        // likely a no-op on Postgres (the txn aborts on first error) —
        // confirm retry_db_write is useful here.
        retry_db_write(|| set_status(&txn, task_id, TaskStatus::Running)).await?;
        txn.commit().await?;
        Ok(Self {
            db: db.clone(),
            task_id,
            sequence: AtomicI32::new(0),
        })
    }
267
    /// Executes one durable step named `name`, with exactly-once replay.
    ///
    /// If a previous run already completed this step, its saved output is
    /// deserialized and returned without running `f`. Otherwise `f` runs and
    /// its result (success or failure) is checkpointed inside a transaction
    /// that also holds a row lock (via the locking path of
    /// `find_or_create_task`), so two workers cannot execute the same step
    /// concurrently.
    ///
    /// # Errors
    /// Propagates pause/cancel/deadline checks, database errors, JSON
    /// (de)serialization errors, and `f`'s own error after it is recorded.
    pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        // Claim this step's position in the workflow.
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        // Abort early if the workflow was paused/cancelled...
        check_status(&self.db, self.task_id).await?;

        // ...or if its deadline has already passed.
        check_deadline(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;

        // Locking path: inserts the step row if absent and takes FOR UPDATE
        // on it; steps get max_retries = 0 (retries belong to step_with_retry).
        let (step_id, saved_output) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            true,
            Some(0),
        )
        .await?;

        // Replay: a completed step returns its checkpointed output directly.
        if let Some(output) = saved_output {
            txn.commit().await?;
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
        match f().await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                // complete_task may be attempted several times; each attempt
                // needs its own clone of the payload.
                retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
                txn.commit().await?;
                tracing::debug!(step = name, seq, "step completed");
                Ok(val)
            }
            Err(e) => {
                // Record the failure, then surface the original error.
                let err_msg = e.to_string();
                retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
                txn.commit().await?;
                Err(e)
            }
        }
    }
335
    /// Runs a step whose body shares one database transaction with its
    /// checkpoint: the step's writes and its COMPLETED marker commit (or roll
    /// back) atomically.
    ///
    /// Replays the saved output when the step already completed. On failure
    /// the transaction is dropped (rolled back) and the failure is recorded
    /// on a fresh connection so the FAILED marker survives the rollback.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
                &'tx DatabaseTransaction,
            ) -> Pin<
                Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
            > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // Non-locking lookup/insert of the step row, outside the body's txn.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                // Checkpoint rides in the same transaction as the body's writes.
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the transaction rolls back the body's writes;
                // the FAILED marker is written on the plain connection.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
404
    /// Starts (or re-attaches to) a child workflow under this one.
    ///
    /// The child occupies one sequence slot of the parent and gets its own
    /// context with an independent step sequence starting at 0.
    pub async fn child(
        &self,
        name: &str,
        input: Option<serde_json::Value>,
    ) -> Result<Self, DurableError> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let txn = self.db.begin().await?;
        // Non-locking path: reuses an existing non-failed child row if present.
        let (child_id, _saved) = find_or_create_task(
            &txn,
            Some(self.task_id),
            Some(seq),
            name,
            "WORKFLOW",
            input,
            false,
            None,
        )
        .await?;

        retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
        txn.commit().await?;

        Ok(Self {
            db: self.db.clone(),
            task_id: child_id,
            sequence: AtomicI32::new(0),
        })
    }
447
448 pub async fn is_completed(&self) -> Result<bool, DurableError> {
450 let status = get_status(&self.db, self.task_id).await?;
451 Ok(status == Some(TaskStatus::Completed))
452 }
453
454 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
456 match get_output(&self.db, self.task_id).await? {
457 Some(val) => Ok(Some(serde_json::from_value(val)?)),
458 None => Ok(None),
459 }
460 }
461
462 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
464 let json = serde_json::to_value(output)?;
465 let db = &self.db;
466 let task_id = self.task_id;
467 retry_db_write(|| complete_task(db, task_id, json.clone())).await
468 }
469
    /// Executes a durable step that is retried according to `policy`.
    ///
    /// The retry counter lives in the step's task row, so retries survive a
    /// worker crash: on restart the loop resumes from the persisted
    /// `retry_count`. Once retries are exhausted the step is marked FAILED
    /// and the last error is returned.
    ///
    /// NOTE(review): unlike `step`, this uses the non-locking row path (no
    /// FOR UPDATE) and performs no deadline check — confirm both are
    /// intentional.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // Persist the policy's retry budget on the step row itself.
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // Replay a previously completed run.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Resume from whatever retry state the row already carries.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Re-check pause/cancel before every attempt.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Bump the persisted counter and reset the row to PENDING.
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        // backoff = initial * multiplier^(retry-1), clamped so the
                        // factor never shrinks the delay below the initial value.
                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Budget spent: record the terminal failure.
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
579
580 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
582 let db = &self.db;
583 let task_id = self.task_id;
584 retry_db_write(|| fail_task(db, task_id, error)).await
585 }
586
    /// Sets this workflow's timeout in milliseconds.
    ///
    /// If the task is already RUNNING, the absolute deadline is recomputed
    /// from `started_at`; otherwise the existing deadline is left untouched
    /// (it will be computed when the task transitions to RUNNING).
    ///
    /// The interpolated values are an `i64` and a `Uuid`, so this string-built
    /// SQL cannot carry injected text.
    pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
        let sql = format!(
            "UPDATE durable.task \
             SET timeout_ms = {timeout_ms}, \
             deadline_epoch_ms = CASE \
             WHEN status = 'RUNNING' AND started_at IS NOT NULL \
             THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
             ELSE deadline_epoch_ms \
             END \
             WHERE id = '{}'",
            self.task_id
        );
        self.db
            .execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;
        Ok(())
    }
611
    /// Convenience wrapper: starts a root workflow and immediately applies a
    /// timeout, so the deadline is computed from the fresh `started_at`.
    pub async fn start_with_timeout(
        db: &DatabaseConnection,
        name: &str,
        input: Option<serde_json::Value>,
        timeout_ms: i64,
    ) -> Result<Self, DurableError> {
        let ctx = Self::start(db, name, input).await?;
        ctx.set_timeout(timeout_ms).await?;
        Ok(ctx)
    }
625
    /// Pauses a PENDING or RUNNING workflow and all of its descendants.
    ///
    /// Only descendants that are themselves PENDING or RUNNING are touched;
    /// finished subtasks keep their terminal status. The `task_id` is a
    /// `Uuid`, so the interpolation below cannot inject SQL.
    ///
    /// NOTE(review): the status check and the recursive UPDATE are separate
    /// statements, so a concurrent transition can slip between them — confirm
    /// this race is acceptable.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Recursive CTE walks the whole subtree rooted at task_id.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('PENDING', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }
664
    /// Resumes a PAUSED workflow: the root goes back to RUNNING and every
    /// PAUSED descendant goes to PENDING.
    ///
    /// NOTE(review): descendants that were RUNNING before the pause come back
    /// as PENDING, not RUNNING — their pre-pause state is not preserved.
    /// Confirm the step replay machinery makes this safe.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Root task transitions back to RUNNING via the ORM.
        let mut active: TaskActiveModel = model.into();
        active.status = Set(TaskStatus::Running);
        active.update(db).await?;

        // Descendant subtree (children downward) returns to PENDING.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
702
    /// Cancels a workflow and its whole subtree.
    ///
    /// Refuses to cancel tasks that already reached a terminal status
    /// (COMPLETED / FAILED / CANCELLED); descendants already terminal are
    /// left untouched by the recursive UPDATE's status guard.
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        // Recursive CTE covers the root plus all descendants.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
             AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }
740
741 pub async fn list(
749 db: &DatabaseConnection,
750 query: TaskQuery,
751 ) -> Result<Vec<TaskSummary>, DurableError> {
752 let mut select = Task::find();
753
754 if let Some(status) = &query.status {
756 select = select.filter(TaskColumn::Status.eq(status.to_string()));
757 }
758 if let Some(kind) = &query.kind {
759 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
760 }
761 if let Some(parent_id) = query.parent_id {
762 select = select.filter(TaskColumn::ParentId.eq(parent_id));
763 }
764 if query.root_only {
765 select = select.filter(TaskColumn::ParentId.is_null());
766 }
767 if let Some(name) = &query.name {
768 select = select.filter(TaskColumn::Name.eq(name.as_str()));
769 }
770 if let Some(queue) = &query.queue_name {
771 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
772 }
773
774 let (col, order) = match query.sort {
776 TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
777 TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
778 TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
779 TaskSort::Name(ord) => (TaskColumn::Name, ord),
780 TaskSort::Status(ord) => (TaskColumn::Status, ord),
781 };
782 select = select.order_by(col, order);
783
784 if let Some(offset) = query.offset {
786 select = select.offset(offset);
787 }
788 if let Some(limit) = query.limit {
789 select = select.limit(limit);
790 }
791
792 let models = select.all(db).await?;
793
794 Ok(models.into_iter().map(TaskSummary::from).collect())
795 }
796
797 pub async fn count(
799 db: &DatabaseConnection,
800 query: TaskQuery,
801 ) -> Result<u64, DurableError> {
802 let mut select = Task::find();
803
804 if let Some(status) = &query.status {
805 select = select.filter(TaskColumn::Status.eq(status.to_string()));
806 }
807 if let Some(kind) = &query.kind {
808 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
809 }
810 if let Some(parent_id) = query.parent_id {
811 select = select.filter(TaskColumn::ParentId.eq(parent_id));
812 }
813 if query.root_only {
814 select = select.filter(TaskColumn::ParentId.is_null());
815 }
816 if let Some(name) = &query.name {
817 select = select.filter(TaskColumn::Name.eq(name.as_str()));
818 }
819 if let Some(queue) = &query.queue_name {
820 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
821 }
822
823 let count = select.count(db).await?;
824 Ok(count)
825 }
826
    /// Borrows the underlying database connection.
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// The id of this workflow's task row.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Returns the current step sequence number and advances it by one.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
840}
841
842#[allow(clippy::too_many_arguments)]
863async fn find_or_create_task(
864 db: &impl ConnectionTrait,
865 parent_id: Option<Uuid>,
866 sequence: Option<i32>,
867 name: &str,
868 kind: &str,
869 input: Option<serde_json::Value>,
870 lock: bool,
871 max_retries: Option<u32>,
872) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
873 let parent_eq = match parent_id {
874 Some(p) => format!("= '{p}'"),
875 None => "IS NULL".to_string(),
876 };
877 let parent_sql = match parent_id {
878 Some(p) => format!("'{p}'"),
879 None => "NULL".to_string(),
880 };
881
882 if lock {
883 let new_id = Uuid::new_v4();
897 let seq_sql = match sequence {
898 Some(s) => s.to_string(),
899 None => "NULL".to_string(),
900 };
901 let input_sql = match &input {
902 Some(v) => format!("'{}'", serde_json::to_string(v)?),
903 None => "NULL".to_string(),
904 };
905
906 let max_retries_sql = match max_retries {
907 Some(r) => r.to_string(),
908 None => "3".to_string(), };
910
911 let insert_sql = format!(
913 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
914 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
915 ON CONFLICT (parent_id, sequence) DO NOTHING"
916 );
917 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
918 .await?;
919
920 let lock_sql = format!(
922 "SELECT id, status::text, output FROM durable.task \
923 WHERE parent_id {parent_eq} AND name = '{name}' \
924 FOR UPDATE SKIP LOCKED"
925 );
926 let row = db
927 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
928 .await?;
929
930 if let Some(row) = row {
931 let id: Uuid = row
932 .try_get_by_index(0)
933 .map_err(|e| DurableError::custom(e.to_string()))?;
934 let status: String = row
935 .try_get_by_index(1)
936 .map_err(|e| DurableError::custom(e.to_string()))?;
937 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
938
939 if status == TaskStatus::Completed.to_string() {
940 return Ok((id, output));
942 }
943 return Ok((id, None));
945 }
946
947 Err(DurableError::StepLocked(name.to_string()))
949 } else {
950 let mut query = Task::find().filter(TaskColumn::Name.eq(name));
954 query = match parent_id {
955 Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
956 None => query.filter(TaskColumn::ParentId.is_null()),
957 };
958 let existing = query
961 .filter(TaskColumn::Status.is_not_in([
962 TaskStatus::Cancelled,
963 TaskStatus::Failed,
964 ]))
965 .one(db)
966 .await?;
967
968 if let Some(model) = existing {
969 if model.status == TaskStatus::Completed {
970 return Ok((model.id, model.output));
971 }
972 return Ok((model.id, None));
973 }
974
975 let id = Uuid::new_v4();
977 let new_task = TaskActiveModel {
978 id: Set(id),
979 parent_id: Set(parent_id),
980 sequence: Set(sequence),
981 name: Set(name.to_string()),
982 kind: Set(kind.to_string()),
983 status: Set(TaskStatus::Pending),
984 input: Set(input),
985 max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
986 ..Default::default()
987 };
988 new_task.insert(db).await?;
989
990 Ok((id, None))
991 }
992}
993
994async fn get_output(
995 db: &impl ConnectionTrait,
996 task_id: Uuid,
997) -> Result<Option<serde_json::Value>, DurableError> {
998 let model = Task::find_by_id(task_id)
999 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
1000 .one(db)
1001 .await?;
1002
1003 Ok(model.and_then(|m| m.output))
1004}
1005
1006async fn get_status(
1007 db: &impl ConnectionTrait,
1008 task_id: Uuid,
1009) -> Result<Option<TaskStatus>, DurableError> {
1010 let model = Task::find_by_id(task_id).one(db).await?;
1011
1012 Ok(model.map(|m| m.status))
1013}
1014
1015async fn get_retry_info(
1017 db: &DatabaseConnection,
1018 task_id: Uuid,
1019) -> Result<(u32, u32), DurableError> {
1020 let model = Task::find_by_id(task_id).one(db).await?;
1021
1022 match model {
1023 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1024 None => Err(DurableError::custom(format!(
1025 "task {task_id} not found when reading retry info"
1026 ))),
1027 }
1028}
1029
1030async fn increment_retry_count(
1032 db: &DatabaseConnection,
1033 task_id: Uuid,
1034) -> Result<u32, DurableError> {
1035 let model = Task::find_by_id(task_id).one(db).await?;
1036
1037 match model {
1038 Some(m) => {
1039 let new_count = m.retry_count + 1;
1040 let mut active: TaskActiveModel = m.into();
1041 active.retry_count = Set(new_count);
1042 active.status = Set(TaskStatus::Pending);
1043 active.error = Set(None);
1044 active.completed_at = Set(None);
1045 active.update(db).await?;
1046 Ok(new_count as u32)
1047 }
1048 None => Err(DurableError::custom(format!(
1049 "task {task_id} not found when incrementing retry count"
1050 ))),
1051 }
1052}
1053
/// Sets `status` on the task row; also stamps `started_at` on the first
/// transition and computes `deadline_epoch_ms` when moving to RUNNING with a
/// configured timeout.
///
/// `status` renders from the enum's Display and `task_id` is a `Uuid`, so the
/// string interpolation here cannot inject SQL.
///
/// NOTE(review): `started_at = COALESCE(started_at, now())` fires for *any*
/// target status, not just RUNNING — every caller in this file passes
/// RUNNING today, but confirm before reusing this with other statuses.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
         WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
         THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
         ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1076
1077async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1079 let status = get_status(db, task_id).await?;
1080 match status {
1081 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1082 Some(TaskStatus::Cancelled) => {
1083 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1084 }
1085 _ => Ok(()),
1086 }
1087}
1088
1089async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1091 let model = Task::find_by_id(task_id).one(db).await?;
1092
1093 if let Some(m) = model
1094 && let Some(deadline_ms) = m.deadline_epoch_ms
1095 {
1096 let now_ms = std::time::SystemTime::now()
1097 .duration_since(std::time::UNIX_EPOCH)
1098 .map(|d| d.as_millis() as i64)
1099 .unwrap_or(0);
1100 if now_ms > deadline_ms {
1101 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1102 }
1103 }
1104
1105 Ok(())
1106}
1107
1108async fn complete_task(
1109 db: &impl ConnectionTrait,
1110 task_id: Uuid,
1111 output: serde_json::Value,
1112) -> Result<(), DurableError> {
1113 let model = Task::find_by_id(task_id).one(db).await?;
1114
1115 if let Some(m) = model {
1116 let mut active: TaskActiveModel = m.into();
1117 active.status = Set(TaskStatus::Completed);
1118 active.output = Set(Some(output));
1119 active.completed_at = Set(Some(chrono::Utc::now().into()));
1120 active.update(db).await?;
1121 }
1122 Ok(())
1123}
1124
1125async fn fail_task(
1126 db: &impl ConnectionTrait,
1127 task_id: Uuid,
1128 error: &str,
1129) -> Result<(), DurableError> {
1130 let model = Task::find_by_id(task_id).one(db).await?;
1131
1132 if let Some(m) = model {
1133 let mut active: TaskActiveModel = m.into();
1134 active.status = Set(TaskStatus::Failed);
1135 active.error = Set(Some(error.to_string()));
1136 active.completed_at = Set(Some(chrono::Utc::now().into()));
1137 active.update(db).await?;
1138 }
1139 Ok(())
1140}
1141
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // A write that succeeds immediately runs exactly once — no retries, no sleeps.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Fails twice, succeeds on the second retry: 3 calls total, Ok overall.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // Permanent failure: one initial attempt plus MAX_CHECKPOINT_RETRIES retries.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    // When every attempt fails, the FIRST error (error-0) is the one returned.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}