1use std::pin::Pin;
2
3use durable_db::entity::sea_orm_active_enums::TaskStatus;
4use durable_db::entity::task::{
5 ActiveModel as TaskActiveModel, Column as TaskColumn, Entity as Task,
6};
7use sea_orm::{
8 ActiveModelTrait, ColumnTrait, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
9 DbBackend, EntityTrait, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set,
10 Statement, TransactionTrait,
11};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::sync::atomic::{AtomicI32, Ordering};
15use std::time::Duration;
16use uuid::Uuid;
17
18use crate::error::DurableError;
19
20const MAX_CHECKPOINT_RETRIES: u32 = 3;
23const CHECKPOINT_RETRY_BASE_MS: u64 = 100;
24
25async fn retry_db_write<F, Fut>(mut f: F) -> Result<(), DurableError>
31where
32 F: FnMut() -> Fut,
33 Fut: std::future::Future<Output = Result<(), DurableError>>,
34{
35 match f().await {
36 Ok(()) => Ok(()),
37 Err(first_err) => {
38 for i in 0..MAX_CHECKPOINT_RETRIES {
39 tokio::time::sleep(Duration::from_millis(
40 CHECKPOINT_RETRY_BASE_MS * 2u64.pow(i),
41 ))
42 .await;
43 if f().await.is_ok() {
44 tracing::warn!(retry = i + 1, "checkpoint write succeeded on retry");
45 return Ok(());
46 }
47 }
48 Err(first_err)
49 }
50 }
51}
52
/// Retry configuration for `Ctx::step_with_retry`.
pub struct RetryPolicy {
    /// How many times a failed step may be re-attempted after the first try.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_backoff: std::time::Duration,
    /// Factor applied to the backoff after each retry (1.0 = fixed delay).
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// No retries: the step runs exactly once.
    pub fn none() -> Self {
        Self {
            max_retries: 0,
            initial_backoff: std::time::Duration::ZERO,
            backoff_multiplier: 1.0,
        }
    }

    /// Exponential backoff: the delay doubles after each failed attempt.
    pub fn exponential(max_retries: u32, initial_backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff,
            backoff_multiplier: 2.0,
        }
    }

    /// Fixed backoff: the same delay before every retry.
    pub fn fixed(max_retries: u32, backoff: std::time::Duration) -> Self {
        Self {
            max_retries,
            initial_backoff: backoff,
            backoff_multiplier: 1.0,
        }
    }
}
88
/// Sort column + direction for task listings (`Ctx::list`); each variant
/// maps to one task-table column.
pub enum TaskSort {
    CreatedAt(Order),
    StartedAt(Order),
    CompletedAt(Order),
    Name(Order),
    Status(Order),
}
97
/// Filter, sort, and paging options for `Ctx::list` and `Ctx::count`.
///
/// All filters are optional and combined with AND.
pub struct TaskQuery {
    pub status: Option<TaskStatus>,
    pub kind: Option<String>,
    // Matches direct children of this task only (not the whole subtree).
    pub parent_id: Option<Uuid>,
    // When true, only tasks without a parent (workflow roots) match.
    pub root_only: bool,
    pub name: Option<String>,
    pub queue_name: Option<String>,
    pub sort: TaskSort,
    // Paging: `limit`/`offset` are ignored by `Ctx::count`.
    pub limit: Option<u64>,
    pub offset: Option<u64>,
}
120
121impl Default for TaskQuery {
122 fn default() -> Self {
123 Self {
124 status: None,
125 kind: None,
126 parent_id: None,
127 root_only: false,
128 name: None,
129 queue_name: None,
130 sort: TaskSort::CreatedAt(Order::Desc),
131 limit: None,
132 offset: None,
133 }
134 }
135}
136
137impl TaskQuery {
138 pub fn status(mut self, status: TaskStatus) -> Self {
140 self.status = Some(status);
141 self
142 }
143
144 pub fn kind(mut self, kind: &str) -> Self {
146 self.kind = Some(kind.to_string());
147 self
148 }
149
150 pub fn parent_id(mut self, parent_id: Uuid) -> Self {
152 self.parent_id = Some(parent_id);
153 self
154 }
155
156 pub fn root_only(mut self, root_only: bool) -> Self {
158 self.root_only = root_only;
159 self
160 }
161
162 pub fn name(mut self, name: &str) -> Self {
164 self.name = Some(name.to_string());
165 self
166 }
167
168 pub fn queue_name(mut self, queue: &str) -> Self {
170 self.queue_name = Some(queue.to_string());
171 self
172 }
173
174 pub fn sort(mut self, sort: TaskSort) -> Self {
176 self.sort = sort;
177 self
178 }
179
180 pub fn limit(mut self, limit: u64) -> Self {
182 self.limit = Some(limit);
183 self
184 }
185
186 pub fn offset(mut self, offset: u64) -> Self {
188 self.offset = Some(offset);
189 self
190 }
191}
192
/// Read-only projection of one task row, as returned by `Ctx::list`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskSummary {
    pub id: Uuid,
    // `None` for workflow roots.
    pub parent_id: Option<Uuid>,
    pub name: String,
    pub status: TaskStatus,
    pub kind: String,
    pub input: Option<serde_json::Value>,
    pub output: Option<serde_json::Value>,
    // Last recorded failure message, if any.
    pub error: Option<String>,
    pub queue_name: Option<String>,
    pub created_at: chrono::DateTime<chrono::FixedOffset>,
    pub started_at: Option<chrono::DateTime<chrono::FixedOffset>>,
    pub completed_at: Option<chrono::DateTime<chrono::FixedOffset>>,
}
209
/// Straight field-for-field mapping from the ORM model to the summary DTO.
impl From<durable_db::entity::task::Model> for TaskSummary {
    fn from(m: durable_db::entity::task::Model) -> Self {
        Self {
            id: m.id,
            parent_id: m.parent_id,
            name: m.name,
            status: m.status,
            kind: m.kind,
            input: m.input,
            output: m.output,
            error: m.error,
            queue_name: m.queue_name,
            created_at: m.created_at,
            started_at: m.started_at,
            completed_at: m.completed_at,
        }
    }
}
228
/// Handle for executing one durable workflow task.
///
/// Holds the connection, the id of the task row this context drives, and a
/// monotonically increasing sequence counter used to order child steps.
pub struct Ctx {
    db: DatabaseConnection,
    task_id: Uuid,
    // Next step/child sequence number; incremented with SeqCst on each use.
    sequence: AtomicI32,
}
239
240impl Ctx {
241 pub async fn start(
249 db: &DatabaseConnection,
250 name: &str,
251 input: Option<serde_json::Value>,
252 ) -> Result<Self, DurableError> {
253 let txn = db.begin().await?;
254 let (task_id, _saved) =
258 find_or_create_task(&txn, None, None, name, "WORKFLOW", input, false, None).await?;
259 retry_db_write(|| set_status(&txn, task_id, TaskStatus::Running)).await?;
260 txn.commit().await?;
261 Ok(Self {
262 db: db.clone(),
263 task_id,
264 sequence: AtomicI32::new(0),
265 })
266 }
267
268 pub async fn step<T, F, Fut>(&self, name: &str, f: F) -> Result<T, DurableError>
277 where
278 T: Serialize + DeserializeOwned,
279 F: FnOnce() -> Fut,
280 Fut: std::future::Future<Output = Result<T, DurableError>>,
281 {
282 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
283
284 check_status(&self.db, self.task_id).await?;
286
287 check_deadline(&self.db, self.task_id).await?;
289
290 let txn = self.db.begin().await?;
292
293 let (step_id, saved_output) = find_or_create_task(
298 &txn,
299 Some(self.task_id),
300 Some(seq),
301 name,
302 "STEP",
303 None,
304 true,
305 Some(0),
306 )
307 .await?;
308
309 if let Some(output) = saved_output {
311 txn.commit().await?;
312 let val: T = serde_json::from_value(output)?;
313 tracing::debug!(step = name, seq, "replaying saved output");
314 return Ok(val);
315 }
316
317 retry_db_write(|| set_status(&txn, step_id, TaskStatus::Running)).await?;
319 match f().await {
320 Ok(val) => {
321 let json = serde_json::to_value(&val)?;
322 retry_db_write(|| complete_task(&txn, step_id, json.clone())).await?;
323 txn.commit().await?;
324 tracing::debug!(step = name, seq, "step completed");
325 Ok(val)
326 }
327 Err(e) => {
328 let err_msg = e.to_string();
329 retry_db_write(|| fail_task(&txn, step_id, &err_msg)).await?;
330 txn.commit().await?;
331 Err(e)
332 }
333 }
334 }
335
    /// Runs `f` as a durable step whose user work shares a database
    /// transaction with the step's own checkpoint.
    ///
    /// `f` receives the open transaction and its writes commit atomically
    /// with the step's COMPLETED marker, so the step either fully happened
    /// (including its side effects) or not at all.
    ///
    /// The boxed-future HRTB signature lets `f` borrow the transaction for
    /// exactly the future's lifetime.
    pub async fn transaction<T, F>(&self, name: &str, f: F) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned + Send,
        F: for<'tx> FnOnce(
            &'tx DatabaseTransaction,
        ) -> Pin<
            Box<dyn std::future::Future<Output = Result<T, DurableError>> + Send + 'tx>,
        > + Send,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        // Step row is created outside the user transaction (no lock here,
        // unlike `step`).
        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "TRANSACTION",
            None,
            false,
            None,
        )
        .await?;

        // Replay: a completed step returns its stored output immediately.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved transaction output");
            return Ok(val);
        }

        let tx = self.db.begin().await?;

        set_status(&tx, step_id, TaskStatus::Running).await?;

        match f(&tx).await {
            Ok(val) => {
                let json = serde_json::to_value(&val)?;
                // Checkpoint and user writes commit together.
                complete_task(&tx, step_id, json).await?;
                tx.commit().await?;
                tracing::debug!(step = name, seq, "transaction step committed");
                Ok(val)
            }
            Err(e) => {
                // Dropping the transaction discards the user's writes
                // (NOTE(review): relies on rollback-on-drop semantics of
                // the transaction type -- confirm against the ORM docs);
                // the FAILED marker is then written outside it so it
                // survives the rollback.
                drop(tx);
                fail_task(&self.db, step_id, &e.to_string()).await?;
                Err(e)
            }
        }
    }
404
405 pub async fn child(
413 &self,
414 name: &str,
415 input: Option<serde_json::Value>,
416 ) -> Result<Self, DurableError> {
417 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
418
419 check_status(&self.db, self.task_id).await?;
421
422 let txn = self.db.begin().await?;
423 let (child_id, _saved) = find_or_create_task(
425 &txn,
426 Some(self.task_id),
427 Some(seq),
428 name,
429 "WORKFLOW",
430 input,
431 false,
432 None,
433 )
434 .await?;
435
436 retry_db_write(|| set_status(&txn, child_id, TaskStatus::Running)).await?;
439 txn.commit().await?;
440
441 Ok(Self {
442 db: self.db.clone(),
443 task_id: child_id,
444 sequence: AtomicI32::new(0),
445 })
446 }
447
448 pub async fn is_completed(&self) -> Result<bool, DurableError> {
450 let status = get_status(&self.db, self.task_id).await?;
451 Ok(status == Some(TaskStatus::Completed))
452 }
453
454 pub async fn get_output<T: DeserializeOwned>(&self) -> Result<Option<T>, DurableError> {
456 match get_output(&self.db, self.task_id).await? {
457 Some(val) => Ok(Some(serde_json::from_value(val)?)),
458 None => Ok(None),
459 }
460 }
461
462 pub async fn complete<T: Serialize>(&self, output: &T) -> Result<(), DurableError> {
464 let json = serde_json::to_value(output)?;
465 let db = &self.db;
466 let task_id = self.task_id;
467 retry_db_write(|| complete_task(db, task_id, json.clone())).await
468 }
469
    /// Like `step`, but re-runs `f` on failure according to `policy`.
    ///
    /// The retry count is persisted on the step row, so the remaining
    /// budget survives a process restart. Backoff before retry `n` is
    /// `initial_backoff * multiplier^(n-1)`, with the multiplier clamped
    /// to at least 1.0.
    ///
    /// NOTE(review): unlike `step`, the checkpoint writes here run
    /// directly on the pooled connection rather than inside a transaction.
    pub async fn step_with_retry<T, F, Fut>(
        &self,
        name: &str,
        policy: RetryPolicy,
        f: F,
    ) -> Result<T, DurableError>
    where
        T: Serialize + DeserializeOwned,
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, DurableError>>,
    {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);

        check_status(&self.db, self.task_id).await?;

        let (step_id, saved_output) = find_or_create_task(
            &self.db,
            Some(self.task_id),
            Some(seq),
            name,
            "STEP",
            None,
            false,
            Some(policy.max_retries),
        )
        .await?;

        // Replay: a completed step returns its stored output immediately.
        if let Some(output) = saved_output {
            let val: T = serde_json::from_value(output)?;
            tracing::debug!(step = name, seq, "replaying saved output");
            return Ok(val);
        }

        // Resume from whatever retry budget was already consumed.
        let (mut retry_count, max_retries) = get_retry_info(&self.db, step_id).await?;

        loop {
            // Pause/cancel can interrupt the retry loop between attempts.
            check_status(&self.db, self.task_id).await?;
            set_status(&self.db, step_id, TaskStatus::Running).await?;
            match f().await {
                Ok(val) => {
                    let json = serde_json::to_value(&val)?;
                    complete_task(&self.db, step_id, json).await?;
                    tracing::debug!(step = name, seq, retry_count, "step completed");
                    return Ok(val);
                }
                Err(e) => {
                    if retry_count < max_retries {
                        // Persist the new count so a crash mid-backoff
                        // does not reset the budget.
                        retry_count = increment_retry_count(&self.db, step_id).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            max_retries,
                            "step failed, retrying"
                        );

                        let backoff = if policy.initial_backoff.is_zero() {
                            std::time::Duration::ZERO
                        } else {
                            // retry_count >= 1 here, so the exponent is
                            // never negative.
                            let factor = policy
                                .backoff_multiplier
                                .powi((retry_count - 1) as i32)
                                .max(1.0);
                            let millis =
                                (policy.initial_backoff.as_millis() as f64 * factor) as u64;
                            std::time::Duration::from_millis(millis)
                        };

                        if !backoff.is_zero() {
                            tokio::time::sleep(backoff).await;
                        }
                    } else {
                        // Budget exhausted: record the terminal failure.
                        fail_task(&self.db, step_id, &e.to_string()).await?;
                        tracing::debug!(
                            step = name,
                            seq,
                            retry_count,
                            "step exhausted retries, marked FAILED"
                        );
                        return Err(e);
                    }
                }
            }
        }
    }
579
580 pub async fn fail(&self, error: &str) -> Result<(), DurableError> {
582 let db = &self.db;
583 let task_id = self.task_id;
584 retry_db_write(|| fail_task(db, task_id, error)).await
585 }
586
587 pub async fn set_timeout(&self, timeout_ms: i64) -> Result<(), DurableError> {
595 let sql = format!(
596 "UPDATE durable.task \
597 SET timeout_ms = {timeout_ms}, \
598 deadline_epoch_ms = CASE \
599 WHEN status = 'RUNNING' AND started_at IS NOT NULL \
600 THEN EXTRACT(EPOCH FROM started_at) * 1000 + {timeout_ms} \
601 ELSE deadline_epoch_ms \
602 END \
603 WHERE id = '{}'",
604 self.task_id
605 );
606 self.db
607 .execute(Statement::from_string(DbBackend::Postgres, sql))
608 .await?;
609 Ok(())
610 }
611
612 pub async fn start_with_timeout(
616 db: &DatabaseConnection,
617 name: &str,
618 input: Option<serde_json::Value>,
619 timeout_ms: i64,
620 ) -> Result<Self, DurableError> {
621 let ctx = Self::start(db, name, input).await?;
622 ctx.set_timeout(timeout_ms).await?;
623 Ok(ctx)
624 }
625
    /// Pauses the workflow rooted at `task_id` together with its whole
    /// descendant subtree.
    ///
    /// Only PENDING or RUNNING tasks can be paused; any other state is an
    /// error. The cascade likewise only touches PENDING/RUNNING
    /// descendants, so finished children keep their terminal status.
    pub async fn pause(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Pending | TaskStatus::Running => {}
            status => {
                return Err(DurableError::custom(format!(
                    "cannot pause task in {status} status"
                )));
            }
        }

        // Recursive CTE walks the task tree from the root downwards.
        // `task_id` is a Uuid rendered via Display, so this interpolation
        // cannot inject SQL.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PAUSED' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status IN ('PENDING', 'RUNNING')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow paused");
        Ok(())
    }
664
    /// Resumes a PAUSED workflow.
    ///
    /// The root task goes back to RUNNING; paused descendants are reset to
    /// PENDING so the replay machinery re-executes them in order.
    pub async fn resume(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        if model.status != TaskStatus::Paused {
            return Err(DurableError::custom(format!(
                "cannot resume task in {} status (must be PAUSED)",
                model.status
            )));
        }

        // Root first: back to RUNNING. (`task_id` is a Uuid, so the
        // interpolation is injection-safe.)
        let sql = format!(
            "UPDATE durable.task SET status = 'RUNNING' WHERE id = '{task_id}'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        // Then every PAUSED descendant becomes PENDING again.
        // NOTE(review): the two updates are not in one transaction, so a
        // crash between them leaves descendants paused under a running
        // root -- confirm callers tolerate re-invoking resume.
        let cascade_sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE parent_id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'PENDING' \
             WHERE id IN (SELECT id FROM descendants) \
             AND status = 'PAUSED'"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, cascade_sql))
            .await?;

        tracing::info!(%task_id, "workflow resumed");
        Ok(())
    }
704
    /// Cancels the workflow rooted at `task_id` and its whole subtree.
    ///
    /// Tasks already in a terminal state (COMPLETED/FAILED/CANCELLED)
    /// cannot be cancelled; the cascade likewise leaves already-terminal
    /// descendants untouched.
    pub async fn cancel(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
        let model = Task::find_by_id(task_id).one(db).await?;
        let model =
            model.ok_or_else(|| DurableError::custom(format!("task {task_id} not found")))?;

        match model.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                return Err(DurableError::custom(format!(
                    "cannot cancel task in {} status",
                    model.status
                )));
            }
            _ => {}
        }

        // Recursive CTE cancels the root and every non-terminal descendant
        // in one statement; `task_id` is a Uuid, so the interpolation is
        // injection-safe.
        let sql = format!(
            "WITH RECURSIVE descendants AS ( \
             SELECT id FROM durable.task WHERE id = '{task_id}' \
             UNION ALL \
             SELECT t.id FROM durable.task t \
             INNER JOIN descendants d ON t.parent_id = d.id \
             ) \
             UPDATE durable.task SET status = 'CANCELLED', completed_at = now() \
             WHERE id IN (SELECT id FROM descendants) \
             AND status NOT IN ('COMPLETED', 'FAILED', 'CANCELLED')"
        );
        db.execute(Statement::from_string(DbBackend::Postgres, sql))
            .await?;

        tracing::info!(%task_id, "workflow cancelled");
        Ok(())
    }
742
743 pub async fn list(
751 db: &DatabaseConnection,
752 query: TaskQuery,
753 ) -> Result<Vec<TaskSummary>, DurableError> {
754 let mut select = Task::find();
755
756 if let Some(status) = &query.status {
758 select = select.filter(TaskColumn::Status.eq(status.to_string()));
759 }
760 if let Some(kind) = &query.kind {
761 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
762 }
763 if let Some(parent_id) = query.parent_id {
764 select = select.filter(TaskColumn::ParentId.eq(parent_id));
765 }
766 if query.root_only {
767 select = select.filter(TaskColumn::ParentId.is_null());
768 }
769 if let Some(name) = &query.name {
770 select = select.filter(TaskColumn::Name.eq(name.as_str()));
771 }
772 if let Some(queue) = &query.queue_name {
773 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
774 }
775
776 let (col, order) = match query.sort {
778 TaskSort::CreatedAt(ord) => (TaskColumn::CreatedAt, ord),
779 TaskSort::StartedAt(ord) => (TaskColumn::StartedAt, ord),
780 TaskSort::CompletedAt(ord) => (TaskColumn::CompletedAt, ord),
781 TaskSort::Name(ord) => (TaskColumn::Name, ord),
782 TaskSort::Status(ord) => (TaskColumn::Status, ord),
783 };
784 select = select.order_by(col, order);
785
786 if let Some(offset) = query.offset {
788 select = select.offset(offset);
789 }
790 if let Some(limit) = query.limit {
791 select = select.limit(limit);
792 }
793
794 let models = select.all(db).await?;
795
796 Ok(models.into_iter().map(TaskSummary::from).collect())
797 }
798
799 pub async fn count(
801 db: &DatabaseConnection,
802 query: TaskQuery,
803 ) -> Result<u64, DurableError> {
804 let mut select = Task::find();
805
806 if let Some(status) = &query.status {
807 select = select.filter(TaskColumn::Status.eq(status.to_string()));
808 }
809 if let Some(kind) = &query.kind {
810 select = select.filter(TaskColumn::Kind.eq(kind.as_str()));
811 }
812 if let Some(parent_id) = query.parent_id {
813 select = select.filter(TaskColumn::ParentId.eq(parent_id));
814 }
815 if query.root_only {
816 select = select.filter(TaskColumn::ParentId.is_null());
817 }
818 if let Some(name) = &query.name {
819 select = select.filter(TaskColumn::Name.eq(name.as_str()));
820 }
821 if let Some(queue) = &query.queue_name {
822 select = select.filter(TaskColumn::QueueName.eq(queue.as_str()));
823 }
824
825 let count = select.count(db).await?;
826 Ok(count)
827 }
828
    /// Borrow the underlying connection (e.g. for ad-hoc queries).
    pub fn db(&self) -> &DatabaseConnection {
        &self.db
    }

    /// Id of the task row this context drives.
    pub fn task_id(&self) -> Uuid {
        self.task_id
    }

    /// Reserve and return the next step sequence number.
    pub fn next_sequence(&self) -> i32 {
        self.sequence.fetch_add(1, Ordering::SeqCst)
    }
842}
843
844#[allow(clippy::too_many_arguments)]
865async fn find_or_create_task(
866 db: &impl ConnectionTrait,
867 parent_id: Option<Uuid>,
868 sequence: Option<i32>,
869 name: &str,
870 kind: &str,
871 input: Option<serde_json::Value>,
872 lock: bool,
873 max_retries: Option<u32>,
874) -> Result<(Uuid, Option<serde_json::Value>), DurableError> {
875 let parent_eq = match parent_id {
876 Some(p) => format!("= '{p}'"),
877 None => "IS NULL".to_string(),
878 };
879 let parent_sql = match parent_id {
880 Some(p) => format!("'{p}'"),
881 None => "NULL".to_string(),
882 };
883
884 if lock {
885 let new_id = Uuid::new_v4();
899 let seq_sql = match sequence {
900 Some(s) => s.to_string(),
901 None => "NULL".to_string(),
902 };
903 let input_sql = match &input {
904 Some(v) => format!("'{}'", serde_json::to_string(v)?),
905 None => "NULL".to_string(),
906 };
907
908 let max_retries_sql = match max_retries {
909 Some(r) => r.to_string(),
910 None => "3".to_string(), };
912
913 let insert_sql = format!(
915 "INSERT INTO durable.task (id, parent_id, sequence, name, kind, status, input, max_retries) \
916 VALUES ('{new_id}', {parent_sql}, {seq_sql}, '{name}', '{kind}', 'PENDING', {input_sql}, {max_retries_sql}) \
917 ON CONFLICT (parent_id, name) DO NOTHING"
918 );
919 db.execute(Statement::from_string(DbBackend::Postgres, insert_sql))
920 .await?;
921
922 let lock_sql = format!(
924 "SELECT id, status::text, output FROM durable.task \
925 WHERE parent_id {parent_eq} AND name = '{name}' \
926 FOR UPDATE SKIP LOCKED"
927 );
928 let row = db
929 .query_one(Statement::from_string(DbBackend::Postgres, lock_sql))
930 .await?;
931
932 if let Some(row) = row {
933 let id: Uuid = row
934 .try_get_by_index(0)
935 .map_err(|e| DurableError::custom(e.to_string()))?;
936 let status: String = row
937 .try_get_by_index(1)
938 .map_err(|e| DurableError::custom(e.to_string()))?;
939 let output: Option<serde_json::Value> = row.try_get_by_index(2).ok();
940
941 if status == TaskStatus::Completed.to_string() {
942 return Ok((id, output));
944 }
945 return Ok((id, None));
947 }
948
949 Err(DurableError::StepLocked(name.to_string()))
951 } else {
952 let mut query = Task::find().filter(TaskColumn::Name.eq(name));
956 query = match parent_id {
957 Some(p) => query.filter(TaskColumn::ParentId.eq(p)),
958 None => query.filter(TaskColumn::ParentId.is_null()),
959 };
960 let existing = query.one(db).await?;
961
962 if let Some(model) = existing {
963 if model.status == TaskStatus::Completed {
964 return Ok((model.id, model.output));
965 }
966 return Ok((model.id, None));
967 }
968
969 let id = Uuid::new_v4();
971 let new_task = TaskActiveModel {
972 id: Set(id),
973 parent_id: Set(parent_id),
974 sequence: Set(sequence),
975 name: Set(name.to_string()),
976 kind: Set(kind.to_string()),
977 status: Set(TaskStatus::Pending),
978 input: Set(input),
979 max_retries: Set(max_retries.map(|r| r as i32).unwrap_or(3)),
980 ..Default::default()
981 };
982 new_task.insert(db).await?;
983
984 Ok((id, None))
985 }
986}
987
988async fn get_output(
989 db: &impl ConnectionTrait,
990 task_id: Uuid,
991) -> Result<Option<serde_json::Value>, DurableError> {
992 let model = Task::find_by_id(task_id)
993 .filter(TaskColumn::Status.eq(TaskStatus::Completed.to_string()))
994 .one(db)
995 .await?;
996
997 Ok(model.and_then(|m| m.output))
998}
999
1000async fn get_status(
1001 db: &impl ConnectionTrait,
1002 task_id: Uuid,
1003) -> Result<Option<TaskStatus>, DurableError> {
1004 let model = Task::find_by_id(task_id).one(db).await?;
1005
1006 Ok(model.map(|m| m.status))
1007}
1008
1009async fn get_retry_info(
1011 db: &DatabaseConnection,
1012 task_id: Uuid,
1013) -> Result<(u32, u32), DurableError> {
1014 let model = Task::find_by_id(task_id).one(db).await?;
1015
1016 match model {
1017 Some(m) => Ok((m.retry_count as u32, m.max_retries as u32)),
1018 None => Err(DurableError::custom(format!(
1019 "task {task_id} not found when reading retry info"
1020 ))),
1021 }
1022}
1023
1024async fn increment_retry_count(
1026 db: &DatabaseConnection,
1027 task_id: Uuid,
1028) -> Result<u32, DurableError> {
1029 let model = Task::find_by_id(task_id).one(db).await?;
1030
1031 match model {
1032 Some(m) => {
1033 let new_count = m.retry_count + 1;
1034 let mut active: TaskActiveModel = m.into();
1035 active.retry_count = Set(new_count);
1036 active.status = Set(TaskStatus::Pending);
1037 active.error = Set(None);
1038 active.completed_at = Set(None);
1039 active.update(db).await?;
1040 Ok(new_count as u32)
1041 }
1042 None => Err(DurableError::custom(format!(
1043 "task {task_id} not found when incrementing retry count"
1044 ))),
1045 }
1046}
1047
/// Sets a task's status, stamping `started_at` on the first transition and
/// recomputing the absolute deadline when entering RUNNING with a timeout.
///
/// `status` is an enum rendered via Display, so the interpolated value is
/// a closed set of known strings, not user input.
/// NOTE(review): `started_at = COALESCE(started_at, now())` fires for any
/// target status, not just RUNNING -- current callers only pass RUNNING
/// here, but confirm this is intended before adding new callers.
async fn set_status(
    db: &impl ConnectionTrait,
    task_id: Uuid,
    status: TaskStatus,
) -> Result<(), DurableError> {
    let sql = format!(
        "UPDATE durable.task \
         SET status = '{status}', \
         started_at = COALESCE(started_at, now()), \
         deadline_epoch_ms = CASE \
             WHEN '{status}' = 'RUNNING' AND timeout_ms IS NOT NULL \
             THEN EXTRACT(EPOCH FROM now()) * 1000 + timeout_ms \
             ELSE deadline_epoch_ms \
         END \
         WHERE id = '{task_id}'"
    );
    db.execute(Statement::from_string(DbBackend::Postgres, sql))
        .await?;
    Ok(())
}
1070
1071async fn check_status(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1073 let status = get_status(db, task_id).await?;
1074 match status {
1075 Some(TaskStatus::Paused) => Err(DurableError::Paused(format!("task {task_id} is paused"))),
1076 Some(TaskStatus::Cancelled) => {
1077 Err(DurableError::Cancelled(format!("task {task_id} is cancelled")))
1078 }
1079 _ => Ok(()),
1080 }
1081}
1082
1083async fn check_deadline(db: &DatabaseConnection, task_id: Uuid) -> Result<(), DurableError> {
1085 let model = Task::find_by_id(task_id).one(db).await?;
1086
1087 if let Some(m) = model
1088 && let Some(deadline_ms) = m.deadline_epoch_ms
1089 {
1090 let now_ms = std::time::SystemTime::now()
1091 .duration_since(std::time::UNIX_EPOCH)
1092 .map(|d| d.as_millis() as i64)
1093 .unwrap_or(0);
1094 if now_ms > deadline_ms {
1095 return Err(DurableError::Timeout("task deadline exceeded".to_string()));
1096 }
1097 }
1098
1099 Ok(())
1100}
1101
1102async fn complete_task(
1103 db: &impl ConnectionTrait,
1104 task_id: Uuid,
1105 output: serde_json::Value,
1106) -> Result<(), DurableError> {
1107 let model = Task::find_by_id(task_id).one(db).await?;
1108
1109 if let Some(m) = model {
1110 let mut active: TaskActiveModel = m.into();
1111 active.status = Set(TaskStatus::Completed);
1112 active.output = Set(Some(output));
1113 active.completed_at = Set(Some(chrono::Utc::now().into()));
1114 active.update(db).await?;
1115 }
1116 Ok(())
1117}
1118
1119async fn fail_task(
1120 db: &impl ConnectionTrait,
1121 task_id: Uuid,
1122 error: &str,
1123) -> Result<(), DurableError> {
1124 let model = Task::find_by_id(task_id).one(db).await?;
1125
1126 if let Some(m) = model {
1127 let mut active: TaskActiveModel = m.into();
1128 active.status = Set(TaskStatus::Failed);
1129 active.error = Set(Some(error.to_string()));
1130 active.completed_at = Set(Some(chrono::Utc::now().into()));
1131 active.update(db).await?;
1132 }
1133 Ok(())
1134}
1135
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // A successful first attempt must not trigger any retry.
    #[tokio::test]
    async fn test_retry_db_write_succeeds_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), DurableError>(())
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Failing twice then succeeding should recover within the retry budget
    // (3 calls total: the initial attempt plus two retries).
    #[tokio::test]
    async fn test_retry_db_write_succeeds_after_transient_failure() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(DurableError::Db(sea_orm::DbErr::Custom(
                        "transient".to_string(),
                    )))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // A permanently failing write is attempted exactly
    // 1 + MAX_CHECKPOINT_RETRIES times before giving up.
    #[tokio::test]
    async fn test_retry_db_write_exhausts_retries() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(
                    "always fails".to_string(),
                )))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1 + MAX_CHECKPOINT_RETRIES
        );
    }

    // When every attempt fails with a different error, the error from the
    // FIRST attempt is the one surfaced to the caller.
    #[tokio::test]
    async fn test_retry_db_write_returns_original_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_db_write(|| {
            let c = cc.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst);
                Err::<(), DurableError>(DurableError::Db(sea_orm::DbErr::Custom(format!(
                    "error-{}",
                    n
                ))))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("error-0"),
            "expected first error (error-0), got: {err}"
        );
    }
}