use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use bytes::Bytes;
use futures::FutureExt;
use sayiir_core::codec::{Codec, EnvelopeCodec, LoopDecision, sealed};
use sayiir_core::context::{TaskExecutionContext, with_task_context};
use sayiir_core::error::{BoxError, CodecError, WorkflowError};
use sayiir_core::registry::TaskRegistry;
use sayiir_core::snapshot::{
    ExecutionPosition, SignalKind, SignalRequest, TaskDeadline, WorkflowSnapshot,
};
use sayiir_core::task_claim::AvailableTask;
use sayiir_core::workflow::{Workflow, WorkflowContinuation, WorkflowStatus};
use sayiir_persistence::{PersistentBackend, TaskClaimStore};
use tokio::sync::mpsc;
use tokio::time;

pub type WorkflowRegistry<C, Input, M> = Vec<(String, Arc<Workflow<C, Input, M>>)>;

pub struct ExternalWorkflow {
    /// The workflow's continuation graph.
    pub continuation: Arc<WorkflowContinuation>,
    /// Stable identifier of the workflow definition.
    pub workflow_id: Arc<str>,
    /// Optional JSON-encoded workflow metadata, surfaced in task contexts.
    pub metadata_json: Option<Arc<str>>,
}

pub type WorkflowIndex = HashMap<String, ExternalWorkflow>;

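/// Callback used by [`PooledWorker::spawn_with_executor`] to run a task by id
/// against its raw input bytes.
///
/// A minimal sketch of building one, assuming a hypothetical `handle_task`
/// dispatcher supplied by the embedding application:
///
/// ```ignore
/// let executor: ExternalTaskExecutor = Arc::new(|task_id, input| {
///     let task_id = task_id.to_string();
///     Box::pin(async move { handle_task(&task_id, input).await })
/// });
/// ```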
pub type ExternalTaskExecutor = Arc<
    dyn Fn(
            &str,
            Bytes,
        ) -> Pin<Box<dyn std::future::Future<Output = Result<Bytes, BoxError>> + Send>>
        + Send
        + Sync,
>;

enum WorkerCommand {
    Shutdown,
}

struct WorkerHandleInner<B> {
    backend: Arc<B>,
    shutdown_tx: mpsc::Sender<WorkerCommand>,
    join_handle:
        tokio::sync::Mutex<Option<tokio::task::JoinHandle<Result<(), crate::error::RuntimeError>>>>,
}

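/// Cloneable handle to a spawned worker: it can request shutdown and await
/// the actor loop's exit.
///
/// A minimal lifecycle sketch, assuming a `worker` and `workflows` built
/// elsewhere:
///
/// ```ignore
/// let handle = worker.spawn(Duration::from_secs(1), workflows);
/// handle.shutdown(); // request a graceful stop
/// handle.join().await?; // wait for the actor loop to finish
/// ```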
pub struct WorkerHandle<B> {
    inner: Arc<WorkerHandleInner<B>>,
}

impl<B> Clone for WorkerHandle<B> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<B> WorkerHandle<B> {
    pub fn shutdown(&self) {
        let _ = self.inner.shutdown_tx.try_send(WorkerCommand::Shutdown);
    }

    pub async fn join(&self) -> Result<(), crate::error::RuntimeError> {
        let jh = self.inner.join_handle.lock().await.take();
        match jh {
            Some(jh) => Ok(jh.await??),
            None => Ok(()),
        }
    }

    #[must_use]
    pub fn backend(&self) -> &Arc<B> {
        &self.inner.backend
    }
}

struct ActiveTaskClaim<'a, B> {
    backend: &'a B,
    instance_id: String,
    task_id: String,
    worker_id: String,
}

impl<B: TaskClaimStore> ActiveTaskClaim<'_, B> {
    async fn release(self) -> Result<(), crate::error::RuntimeError> {
        self.backend
            .release_task_claim(&self.instance_id, &self.task_id, &self.worker_id)
            .await?;
        Ok(())
    }

    async fn release_quietly(self) {
        let _ = self.release().await;
    }
}

enum ExecutionOutcome {
    Success(Bytes),
    TaskError(crate::error::RuntimeError),
    Panic(Box<dyn std::any::Any + Send>),
    Timeout(crate::error::RuntimeError),
}

fn extract_panic_message(payload: &(dyn std::any::Any + Send)) -> String {
    if let Some(s) = payload.downcast_ref::<&str>() {
        s.to_string()
    } else if let Some(s) = payload.downcast_ref::<String>() {
        s.clone()
    } else {
        "Task panicked with unknown payload".to_string()
    }
}

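/// A polling worker that claims tasks from a shared backend and executes them.
///
/// A construction sketch via the builder, assuming the `InMemoryBackend` used
/// in this module's tests (unset options keep the defaults of
/// `PooledWorker::new`):
///
/// ```ignore
/// let worker = PooledWorker::builder(InMemoryBackend::new(), TaskRegistry::new())
///     .worker_id("worker-1")
///     .claim_ttl(Some(Duration::from_secs(300)))
///     .batch_size(NonZeroUsize::new(4).unwrap())
///     .build();
/// ```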
pub struct PooledWorker<B> {
    worker_id: String,
    backend: Arc<B>,
    #[allow(unused)]
    registry: Arc<TaskRegistry>,
    claim_ttl: Option<Duration>,
    batch_size: NonZeroUsize,
    aging_interval: Duration,
    tags: Vec<String>,
}

impl<B> PooledWorker<B>
where
    B: PersistentBackend + TaskClaimStore + 'static,
{
    pub fn new(worker_id: impl Into<String>, backend: B, registry: TaskRegistry) -> Self {
        Self {
            worker_id: worker_id.into(),
            backend: Arc::new(backend),
            registry: Arc::new(registry),
            claim_ttl: Some(Duration::from_secs(5 * 60)),
            batch_size: NonZeroUsize::MIN,
            aging_interval: Duration::from_secs(300),
            tags: vec![],
        }
    }

    #[must_use]
    pub fn with_claim_ttl(mut self, ttl: Option<Duration>) -> Self {
        self.claim_ttl = ttl;
        self
    }

    #[must_use]
    pub fn with_aging_interval(mut self, interval: Duration) -> Self {
        assert!(!interval.is_zero(), "aging interval must be non-zero");
        self.aging_interval = interval;
        self
    }

    #[must_use]
    pub fn with_batch_size(mut self, size: NonZeroUsize) -> Self {
        self.batch_size = size;
        self
    }

    #[must_use]
    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
        self.tags = tags;
        self
    }

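    /// Requests cancellation of a workflow instance by storing a cancel
    /// signal; workers observe the signal at the next post-claim or
    /// post-task guard check.
    ///
    /// A usage sketch with a hypothetical instance id:
    ///
    /// ```ignore
    /// worker
    ///     .cancel_workflow("wf-1", Some("superseded".into()), Some("ops".into()))
    ///     .await?;
    /// ```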
    pub async fn cancel_workflow(
        &self,
        instance_id: &str,
        reason: Option<String>,
        cancelled_by: Option<String>,
    ) -> Result<(), crate::error::RuntimeError> {
        self.backend
            .store_signal(
                instance_id,
                SignalKind::Cancel,
                SignalRequest::new(reason, cancelled_by),
            )
            .await?;

        Ok(())
    }

    pub async fn pause_workflow(
        &self,
        instance_id: &str,
        reason: Option<String>,
        paused_by: Option<String>,
    ) -> Result<(), crate::error::RuntimeError> {
        self.backend
            .store_signal(
                instance_id,
                SignalKind::Pause,
                SignalRequest::new(reason, paused_by),
            )
            .await?;

        Ok(())
    }

    #[must_use]
    pub fn backend(&self) -> &Arc<B> {
        &self.backend
    }

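    /// Spawns the worker's actor loop onto the tokio runtime and returns a
    /// handle for shutdown and joining.
    ///
    /// A sketch assuming a single workflow `wf` registered under its
    /// definition hash:
    ///
    /// ```ignore
    /// let workflows = vec![(wf.definition_hash().to_string(), Arc::new(wf))];
    /// let handle = worker.spawn(Duration::from_secs(1), workflows);
    /// ```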
    #[must_use]
    pub fn spawn<C, Input, M>(
        self,
        poll_interval: Duration,
        workflows: WorkflowRegistry<C, Input, M>,
    ) -> WorkerHandle<B>
    where
        Input: Send + Sync + 'static,
        M: Send + Sync + 'static,
        C: Codec
            + EnvelopeCodec
            + sealed::DecodeValue<Input>
            + sealed::EncodeValue<Input>
            + 'static,
    {
        let (tx, rx) = mpsc::channel(1);
        let backend = Arc::clone(&self.backend);
        let join_handle =
            tokio::spawn(async move { self.run_actor_loop(poll_interval, workflows, rx).await });
        WorkerHandle {
            inner: Arc::new(WorkerHandleInner {
                backend,
                shutdown_tx: tx,
                join_handle: tokio::sync::Mutex::new(Some(join_handle)),
            }),
        }
    }

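    /// Like [`Self::spawn`], but task bodies run through an external executor
    /// callback instead of in-process task functions.
    ///
    /// A sketch assuming an `index: WorkflowIndex` and an `executor` built as
    /// in the [`ExternalTaskExecutor`] docs:
    ///
    /// ```ignore
    /// let handle = worker.spawn_with_executor(Duration::from_secs(1), index, executor);
    /// ```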
    #[must_use]
    pub fn spawn_with_executor(
        self,
        poll_interval: Duration,
        workflows: WorkflowIndex,
        executor: ExternalTaskExecutor,
    ) -> WorkerHandle<B> {
        let (tx, rx) = mpsc::channel(1);
        let backend = Arc::clone(&self.backend);
        let join_handle = tokio::spawn(async move {
            self.run_external_actor_loop(poll_interval, workflows, executor, rx)
                .await
        });
        WorkerHandle {
            inner: Arc::new(WorkerHandleInner {
                backend,
                shutdown_tx: tx,
                join_handle: tokio::sync::Mutex::new(Some(join_handle)),
            }),
        }
    }

    async fn run_external_actor_loop(
        &self,
        poll_interval: Duration,
        workflows: WorkflowIndex,
        executor: ExternalTaskExecutor,
        mut cmd_rx: mpsc::Receiver<WorkerCommand>,
    ) -> Result<(), crate::error::RuntimeError> {
        let mut interval = time::interval(poll_interval);

        loop {
            tokio::select! {
                biased;

                cmd = cmd_rx.recv() => {
                    match cmd {
                        Some(WorkerCommand::Shutdown) | None => {
                            tracing::info!(worker_id = %self.worker_id, "Worker shutting down");
                            return Ok(());
                        }
                    }
                }

                _ = interval.tick() => {
                    tracing::trace!(worker_id = %self.worker_id, "Will poll for available tasks");
                }
            }

            let available_tasks = self
                .backend
                .find_available_tasks(
                    &self.worker_id,
                    self.batch_size.get(),
                    chrono::Duration::from_std(self.aging_interval)
                        .unwrap_or(chrono::Duration::MAX),
                    &self.tags,
                )
                .await?;

            for task in available_tasks {
                if let Ok(WorkerCommand::Shutdown) | Err(mpsc::error::TryRecvError::Disconnected) =
                    cmd_rx.try_recv()
                {
                    tracing::info!(worker_id = %self.worker_id, "Worker shutting down mid-batch");
                    return Ok(());
                }

                if let Some(ext_wf) = workflows.get(&task.workflow_definition_hash) {
                    match self
                        .execute_external_task(
                            ext_wf,
                            &task.workflow_definition_hash,
                            &executor,
                            &task,
                        )
                        .await
                    {
                        Err(ref e) if e.is_timeout() => {
                            tracing::error!(
                                worker_id = %self.worker_id,
                                error = %e,
                                "Task timed out — worker shutting down"
                            );
                            return Ok(());
                        }
                        Ok(_) => {
                            tracing::info!(worker_id = %self.worker_id, "completed task");
                        }
                        Err(e) => {
                            tracing::error!(
                                worker_id = %self.worker_id,
                                error = %e,
                                "task execution failed"
                            );
                        }
                    }
                }
            }
        }
    }

    #[tracing::instrument(
        name = "workflow",
        skip_all,
        fields(
            worker_id = %self.worker_id,
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            definition_hash = %definition_hash,
        ),
    )]
    async fn execute_external_task(
        &self,
        ext_wf: &ExternalWorkflow,
        definition_hash: &str,
        executor: &ExternalTaskExecutor,
        available_task: &AvailableTask,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError> {
        #[cfg(feature = "otel")]
        if let Some(ref tp) = available_task.trace_parent {
            use tracing_opentelemetry::OpenTelemetrySpanExt;
            let remote_ctx = crate::trace_context::context_from_trace_parent(tp);
            let _ = tracing::Span::current().set_parent(remote_ctx);
        }

        let continuation = &ext_wf.continuation;
        let mut snapshot = self
            .backend
            .load_snapshot(&available_task.instance_id)
            .await?;
        let already_completed = Self::validate_task_preconditions(
            definition_hash,
            continuation,
            available_task,
            &snapshot,
        )?;
        if already_completed {
            return Ok(WorkflowStatus::InProgress);
        }

        let Some(claim) = self.claim_task(available_task).await? else {
            return Ok(WorkflowStatus::InProgress);
        };

        if let Some(status) = self.check_post_claim_guards(available_task).await? {
            claim.release_quietly().await;
            return Ok(status);
        }

        tracing::debug!(
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            "Executing task (external)"
        );

        let execution_result = self
            .execute_with_deadline_ext(ext_wf, executor, available_task, &mut snapshot, &claim)
            .await;

        self.settle_execution_result_ext(
            execution_result,
            &ext_wf.continuation,
            available_task,
            &mut snapshot,
            claim,
        )
        .await
    }

    async fn execute_with_deadline_ext(
        &self,
        ext_wf: &ExternalWorkflow,
        executor: &ExternalTaskExecutor,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        claim: &ActiveTaskClaim<'_, B>,
    ) -> ExecutionOutcome {
        let continuation = &ext_wf.continuation;
        let task_id = available_task.task_id.clone();
        let input = available_task.input.clone();

        let deadline = if let Some(timeout) = continuation.get_task_timeout(&task_id) {
            snapshot.set_task_deadline(task_id.clone(), timeout);
            let _ = self.backend.save_snapshot(snapshot).await;
            snapshot.refresh_task_deadline();
            snapshot.task_deadline.clone()
        } else {
            None
        };

        let task_ctx = TaskExecutionContext {
            workflow_id: Arc::clone(&ext_wf.workflow_id),
            instance_id: Arc::from(available_task.instance_id.as_str()),
            task_id: Arc::from(available_task.task_id.as_str()),
            metadata: continuation.build_task_metadata(&available_task.task_id),
            workflow_metadata_json: ext_wf.metadata_json.clone(),
        };

        let execution_future = with_task_context(task_ctx, executor(&task_id, input));

        let heartbeat_result = self
            .run_with_heartbeat(
                claim,
                deadline.as_ref(),
                AssertUnwindSafe(execution_future).catch_unwind(),
            )
            .await;

        snapshot.clear_task_deadline();

        match heartbeat_result {
            Err(timeout_err) => ExecutionOutcome::Timeout(timeout_err),
            Ok(Err(panic_payload)) => ExecutionOutcome::Panic(panic_payload),
            Ok(Ok(Err(e))) => ExecutionOutcome::TaskError(e.into()),
            Ok(Ok(Ok(output))) => ExecutionOutcome::Success(output),
        }
    }

    #[tracing::instrument(
        name = "settle_result",
        skip_all,
        fields(worker_id = %self.worker_id, instance_id = %available_task.instance_id, task_id = %available_task.task_id),
    )]
    async fn settle_execution_result_ext(
        &self,
        outcome: ExecutionOutcome,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        claim: ActiveTaskClaim<'_, B>,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError> {
        tracing::debug!("settling execution result");
        match outcome {
            ExecutionOutcome::Timeout(err) => {
                if let Ok(Some(status)) = self
                    .try_schedule_retry(continuation, available_task, snapshot, &err.to_string())
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::warn!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    error = %err,
                    "Task timed out via heartbeat — marking workflow failed, shutting down"
                );
                snapshot.mark_failed(err.to_string());
                let _ = self.backend.save_snapshot(snapshot).await;
                claim.release_quietly().await;
                Err(err)
            }
            ExecutionOutcome::Panic(panic_payload) => {
                let panic_msg = extract_panic_message(panic_payload.as_ref());

                if let Ok(Some(status)) = self
                    .try_schedule_retry(continuation, available_task, snapshot, &panic_msg)
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::error!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    panic = %panic_msg,
                    "Task panicked - releasing claim"
                );
                claim.release_quietly().await;
                Err(WorkflowError::TaskPanicked(panic_msg).into())
            }
            ExecutionOutcome::TaskError(e) => {
                if let Ok(Some(status)) = self
                    .try_schedule_retry(continuation, available_task, snapshot, &e.to_string())
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::error!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    error = %e,
                    "Task execution failed"
                );
                claim.release_quietly().await;
                Err(e)
            }
            ExecutionOutcome::Success(output) => {
                snapshot.clear_retry_state(&available_task.task_id);
                self.commit_task_result(
                    continuation,
                    available_task,
                    snapshot,
                    output.clone(),
                    claim,
                )
                .await?;
                self.determine_post_task_status(continuation, available_task, snapshot, output)
                    .await
            }
        }
    }

    async fn run_actor_loop<C, Input, M>(
        &self,
        poll_interval: Duration,
        workflows: WorkflowRegistry<C, Input, M>,
        mut cmd_rx: mpsc::Receiver<WorkerCommand>,
    ) -> Result<(), crate::error::RuntimeError>
    where
        Input: Send + 'static,
        M: Send + Sync + 'static,
        C: Codec
            + EnvelopeCodec
            + sealed::DecodeValue<Input>
            + sealed::EncodeValue<Input>
            + 'static,
    {
        let mut interval = time::interval(poll_interval);

        loop {
            tokio::select! {
                biased;

                cmd = cmd_rx.recv() => {
                    match cmd {
                        Some(WorkerCommand::Shutdown) | None => {
                            tracing::info!(worker_id = %self.worker_id, "Worker shutting down");
                            return Ok(());
                        }
                    }
                }

                _ = interval.tick() => {
                    tracing::trace!(worker_id = %self.worker_id, "Will poll for available tasks");
                }
            }

            let available_tasks = self
                .backend
                .find_available_tasks(
                    &self.worker_id,
                    self.batch_size.get(),
                    chrono::Duration::from_std(self.aging_interval)
                        .unwrap_or(chrono::Duration::MAX),
                    &self.tags,
                )
                .await?;

            for task in available_tasks {
                if let Ok(WorkerCommand::Shutdown) | Err(mpsc::error::TryRecvError::Disconnected) =
                    cmd_rx.try_recv()
                {
                    tracing::info!(worker_id = %self.worker_id, "Worker shutting down mid-batch");
                    return Ok(());
                }

                if let Some((_, workflow)) = workflows
                    .iter()
                    .find(|(hash, _)| *hash == task.workflow_definition_hash)
                {
                    match self.execute_task(workflow.as_ref(), task).await {
                        Err(ref e) if e.is_timeout() => {
                            tracing::error!(
                                worker_id = %self.worker_id,
                                error = %e,
                                "Task timed out — worker shutting down"
                            );
                            return Ok(());
                        }
                        Ok(_) => {
                            tracing::info!(worker_id = %self.worker_id, "completed task");
                        }
                        Err(e) => {
                            tracing::error!(
                                worker_id = %self.worker_id,
                                error = %e,
                                "task execution failed"
                            );
                        }
                    }
                }
            }
        }
    }

    async fn load_cancelled_status(&self, instance_id: &str) -> WorkflowStatus {
        if let Ok(snapshot) = self.backend.load_snapshot(instance_id).await
            && let Some((reason, cancelled_by)) = snapshot.state.cancellation_details()
        {
            return WorkflowStatus::Cancelled {
                reason,
                cancelled_by,
            };
        }
        WorkflowStatus::Cancelled {
            reason: None,
            cancelled_by: None,
        }
    }

    async fn load_paused_status(&self, instance_id: &str) -> WorkflowStatus {
        if let Ok(snapshot) = self.backend.load_snapshot(instance_id).await
            && let Some((reason, paused_by)) = snapshot.state.pause_details()
        {
            return WorkflowStatus::Paused { reason, paused_by };
        }
        WorkflowStatus::Paused {
            reason: None,
            paused_by: None,
        }
    }

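    /// Executes a single claimed task against `workflow` and settles the
    /// outcome, returning the workflow's status afterwards.
    ///
    /// Normally `spawn` drives this from the actor loop; a direct-call sketch,
    /// assuming an `AvailableTask` fetched from the backend:
    ///
    /// ```ignore
    /// let status = worker.execute_task(&workflow, task).await?;
    /// ```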
    #[tracing::instrument(
        name = "workflow",
        skip_all,
        fields(
            worker_id = %self.worker_id,
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            definition_hash = %available_task.workflow_definition_hash,
        ),
    )]
    pub async fn execute_task<C, Input, M>(
        &self,
        workflow: &Workflow<C, Input, M>,
        available_task: AvailableTask,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError>
    where
        Input: Send + 'static,
        M: Send + Sync + 'static,
        C: Codec
            + EnvelopeCodec
            + sealed::DecodeValue<Input>
            + sealed::EncodeValue<Input>
            + 'static,
    {
        #[cfg(feature = "otel")]
        if let Some(ref tp) = available_task.trace_parent {
            use tracing_opentelemetry::OpenTelemetrySpanExt;
            let remote_ctx = crate::trace_context::context_from_trace_parent(tp);
            let _ = tracing::Span::current().set_parent(remote_ctx);
        }

        let mut snapshot = self
            .backend
            .load_snapshot(&available_task.instance_id)
            .await?;
        let already_completed = Self::validate_task_preconditions(
            workflow.definition_hash(),
            workflow.continuation(),
            &available_task,
            &snapshot,
        )?;
        if already_completed {
            return Ok(WorkflowStatus::InProgress);
        }

        let Some(claim) = self.claim_task(&available_task).await? else {
            return Ok(WorkflowStatus::InProgress);
        };

        if let Some(status) = self.check_post_claim_guards(&available_task).await? {
            claim.release_quietly().await;
            return Ok(status);
        }

        tracing::debug!(
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            "Executing task"
        );

        let execution_result = self
            .execute_with_deadline(workflow, &available_task, &mut snapshot, &claim)
            .await;

        self.settle_execution_result(
            execution_result,
            workflow,
            &available_task,
            &mut snapshot,
            claim,
        )
        .await
    }

    async fn execute_with_deadline<C, Input, M>(
        &self,
        workflow: &Workflow<C, Input, M>,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        claim: &ActiveTaskClaim<'_, B>,
    ) -> ExecutionOutcome
    where
        Input: Send + 'static,
        M: Send + Sync + 'static,
        C: Codec
            + EnvelopeCodec
            + sealed::DecodeValue<Input>
            + sealed::EncodeValue<Input>
            + 'static,
    {
        let continuation = workflow.continuation();
        let task_id = available_task.task_id.clone();
        let input = available_task.input.clone();

        let deadline = if let Some(timeout) = continuation.get_task_timeout(&task_id) {
            snapshot.set_task_deadline(task_id.clone(), timeout);
            let _ = self.backend.save_snapshot(snapshot).await;
            snapshot.refresh_task_deadline();
            snapshot.task_deadline.clone()
        } else {
            None
        };

        let task_ctx = TaskExecutionContext {
            workflow_id: Arc::clone(&workflow.context().workflow_id),
            instance_id: Arc::from(available_task.instance_id.as_str()),
            task_id: Arc::from(task_id.as_str()),
            metadata: continuation.build_task_metadata(&task_id),
            workflow_metadata_json: workflow.context().metadata_json.clone(),
        };

        let execution_future = with_task_context(task_ctx, async move {
            Self::execute_task_by_id(continuation, &task_id, input).await
        });

        let heartbeat_result = self
            .run_with_heartbeat(
                claim,
                deadline.as_ref(),
                AssertUnwindSafe(execution_future).catch_unwind(),
            )
            .await;

        snapshot.clear_task_deadline();

        match heartbeat_result {
            Err(timeout_err) => ExecutionOutcome::Timeout(timeout_err),
            Ok(Err(panic_payload)) => ExecutionOutcome::Panic(panic_payload),
            Ok(Ok(Err(e))) => ExecutionOutcome::TaskError(e),
            Ok(Ok(Ok(output))) => ExecutionOutcome::Success(output),
        }
    }

    async fn try_schedule_retry(
        &self,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        error_msg: &str,
    ) -> Result<Option<WorkflowStatus>, crate::error::RuntimeError> {
        let Some(policy) = continuation.get_task_retry_policy(&available_task.task_id) else {
            return Ok(None);
        };

        if snapshot.retries_exhausted(&available_task.task_id) {
            return Ok(None);
        }

        let next_retry_at = snapshot.record_retry(
            &available_task.task_id,
            policy,
            error_msg,
            Some(&self.worker_id),
        );
        snapshot.clear_task_deadline();
        let _ = self.backend.save_snapshot(snapshot).await;

        tracing::info!(
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            attempt = snapshot.get_retry_state(&available_task.task_id).map_or(0, |rs| rs.attempts),
            max_retries = policy.max_retries,
            %next_retry_at,
            "Scheduling retry"
        );

        Ok(Some(WorkflowStatus::InProgress))
    }

    #[tracing::instrument(
        name = "settle_result",
        skip_all,
        fields(worker_id = %self.worker_id, instance_id = %available_task.instance_id, task_id = %available_task.task_id),
    )]
    async fn settle_execution_result<C, Input, M>(
        &self,
        outcome: ExecutionOutcome,
        workflow: &Workflow<C, Input, M>,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        claim: ActiveTaskClaim<'_, B>,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError>
    where
        Input: Send + 'static,
        M: Send + Sync + 'static,
        C: Codec
            + EnvelopeCodec
            + sealed::DecodeValue<Input>
            + sealed::EncodeValue<Input>
            + 'static,
    {
        tracing::debug!("settling execution result");
        match outcome {
            ExecutionOutcome::Timeout(err) => {
                if let Ok(Some(status)) = self
                    .try_schedule_retry(
                        workflow.continuation(),
                        available_task,
                        snapshot,
                        &err.to_string(),
                    )
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::warn!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    error = %err,
                    "Task timed out via heartbeat — marking workflow failed, shutting down"
                );
                snapshot.mark_failed(err.to_string());
                let _ = self.backend.save_snapshot(snapshot).await;
                claim.release_quietly().await;
                Err(err)
            }
            ExecutionOutcome::Panic(panic_payload) => {
                let panic_msg = extract_panic_message(panic_payload.as_ref());

                if let Ok(Some(status)) = self
                    .try_schedule_retry(
                        workflow.continuation(),
                        available_task,
                        snapshot,
                        &panic_msg,
                    )
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::error!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    panic = %panic_msg,
                    "Task panicked - releasing claim"
                );
                claim.release_quietly().await;
                Err(WorkflowError::TaskPanicked(panic_msg).into())
            }
            ExecutionOutcome::TaskError(e) => {
                if let Ok(Some(status)) = self
                    .try_schedule_retry(
                        workflow.continuation(),
                        available_task,
                        snapshot,
                        &e.to_string(),
                    )
                    .await
                {
                    claim.release_quietly().await;
                    return Ok(status);
                }

                tracing::error!(
                    instance_id = %available_task.instance_id,
                    task_id = %available_task.task_id,
                    error = %e,
                    "Task execution failed"
                );
                claim.release_quietly().await;
                Err(e)
            }
            ExecutionOutcome::Success(output) => {
                snapshot.clear_retry_state(&available_task.task_id);
                self.commit_task_result(
                    workflow.continuation(),
                    available_task,
                    snapshot,
                    output.clone(),
                    claim,
                )
                .await?;
                Self::resolve_loop_completions(
                    workflow.continuation(),
                    snapshot,
                    self.backend.as_ref(),
                )
                .await?;
                self.determine_post_task_status(
                    workflow.continuation(),
                    available_task,
                    snapshot,
                    output,
                )
                .await
            }
        }
    }

    fn validate_task_preconditions(
        definition_hash: &str,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &WorkflowSnapshot,
    ) -> Result<bool, crate::error::RuntimeError> {
        if available_task.workflow_definition_hash != definition_hash {
            return Err(WorkflowError::DefinitionMismatch {
                expected: definition_hash.to_string(),
                found: available_task.workflow_definition_hash.clone(),
            }
            .into());
        }

        if !Self::find_task_id_in_continuation(continuation, &available_task.task_id) {
            tracing::error!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Task does not exist in workflow"
            );
            return Err(WorkflowError::TaskNotFound(available_task.task_id.clone()).into());
        }

        if snapshot.get_task_result(&available_task.task_id).is_some() {
            tracing::debug!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Task already completed, skipping"
            );
            return Ok(true);
        }

        Ok(false)
    }

    async fn claim_task(
        &self,
        available_task: &AvailableTask,
    ) -> Result<Option<ActiveTaskClaim<'_, B>>, crate::error::RuntimeError> {
        let claim = self
            .backend
            .claim_task(
                &available_task.instance_id,
                &available_task.task_id,
                &self.worker_id,
                self.claim_ttl
                    .and_then(|d| chrono::Duration::from_std(d).ok()),
            )
            .await?;

        if claim.is_some() {
            tracing::debug!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Claim successful"
            );
            Ok(Some(ActiveTaskClaim {
                backend: &self.backend,
                instance_id: available_task.instance_id.clone(),
                task_id: available_task.task_id.clone(),
                worker_id: self.worker_id.clone(),
            }))
        } else {
            tracing::debug!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Task was already claimed by another worker"
            );
            Ok(None)
        }
    }

    async fn check_post_claim_guards(
        &self,
        available_task: &AvailableTask,
    ) -> Result<Option<WorkflowStatus>, crate::error::RuntimeError> {
        if self
            .backend
            .check_and_cancel(&available_task.instance_id, Some(&available_task.task_id))
            .await?
        {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow was cancelled, releasing claim"
            );
            return Ok(Some(
                self.load_cancelled_status(&available_task.instance_id)
                    .await,
            ));
        }

        if self
            .backend
            .check_and_pause(&available_task.instance_id)
            .await?
        {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow was paused, releasing claim"
            );
            return Ok(Some(
                self.load_paused_status(&available_task.instance_id).await,
            ));
        }

        Ok(None)
    }

    #[tracing::instrument(
        name = "task",
        skip_all,
        fields(worker_id = %self.worker_id, instance_id = %claim.instance_id, task_id = %claim.task_id),
    )]
    async fn run_with_heartbeat<F, T>(
        &self,
        claim: &ActiveTaskClaim<'_, B>,
        deadline: Option<&TaskDeadline>,
        future: F,
    ) -> Result<T, crate::error::RuntimeError>
    where
        F: std::future::Future<Output = T>,
    {
        tracing::debug!("running task with heartbeat");
        let Some(ttl) = self.claim_ttl else {
            return Ok(future.await);
        };
        let Ok(chrono_ttl) = chrono::Duration::from_std(ttl) else {
            return Ok(future.await);
        };

        let interval_duration = ttl / 2;
        let mut heartbeat_timer = time::interval(interval_duration);
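        // The first tick of a tokio interval completes immediately; consume it
        // so the heartbeat only fires after a full interval has elapsed.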
        heartbeat_timer.tick().await;
        tokio::pin!(future);

        loop {
            tokio::select! {
                result = &mut future => break Ok(result),
                _ = heartbeat_timer.tick() => {
                    if let Some(dl) = deadline
                        && chrono::Utc::now() >= dl.deadline
                    {
                        tracing::warn!(
                            instance_id = %claim.instance_id,
                            task_id = %dl.task_id,
                            "Task deadline expired during heartbeat, cancelling"
                        );
                        return Err(WorkflowError::TaskTimedOut {
                            task_id: dl.task_id.clone(),
                            timeout: std::time::Duration::from_millis(dl.timeout_ms),
                        }
                        .into());
                    }

                    tracing::trace!(
                        instance_id = %claim.instance_id,
                        task_id = %claim.task_id,
                        "Extending task claim via heartbeat"
                    );
                    if let Err(e) = self.backend
                        .extend_task_claim(
                            &claim.instance_id,
                            &claim.task_id,
                            &claim.worker_id,
                            chrono_ttl,
                        )
                        .await
                    {
                        tracing::warn!(
                            instance_id = %claim.instance_id,
                            task_id = %claim.task_id,
                            error = %e,
                            "Failed to extend task claim"
                        );
                    }
                }
            }
        }
    }

    async fn commit_task_result(
        &self,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        output: Bytes,
        claim: ActiveTaskClaim<'_, B>,
    ) -> Result<(), crate::error::RuntimeError> {
        snapshot.mark_task_completed(available_task.task_id.clone(), output);
        tracing::debug!(
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            "Task completed"
        );

        Self::update_position_after_task(continuation, &available_task.task_id, snapshot);
        #[cfg(feature = "otel")]
        {
            snapshot.trace_parent = crate::trace_context::current_trace_parent();
        }
        self.backend.save_snapshot(snapshot).await?;
        claim.release().await?;
        Ok(())
    }

    async fn determine_post_task_status(
        &self,
        continuation: &WorkflowContinuation,
        available_task: &AvailableTask,
        snapshot: &mut WorkflowSnapshot,
        output: Bytes,
    ) -> Result<WorkflowStatus, crate::error::RuntimeError> {
        if self
            .backend
            .check_and_cancel(&available_task.instance_id, None)
            .await?
        {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow was cancelled after task completion"
            );
            return Ok(self
                .load_cancelled_status(&available_task.instance_id)
                .await);
        }

        if self
            .backend
            .check_and_pause(&available_task.instance_id)
            .await?
        {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow was paused after task completion"
            );
            return Ok(self.load_paused_status(&available_task.instance_id).await);
        }

        if Self::is_workflow_complete(continuation, snapshot) {
            tracing::info!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Workflow complete"
            );
            snapshot.mark_completed(output);
            self.backend.save_snapshot(snapshot).await?;
            Ok(WorkflowStatus::Completed)
        } else {
            tracing::debug!(
                instance_id = %available_task.instance_id,
                task_id = %available_task.task_id,
                "Task completed, workflow continues"
            );
            Ok(WorkflowStatus::InProgress)
        }
    }

    fn find_task_id_in_continuation(continuation: &WorkflowContinuation, task_id: &str) -> bool {
        match continuation {
            WorkflowContinuation::Task { id, next, .. }
            | WorkflowContinuation::Delay { id, next, .. }
            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
                if id == task_id {
                    return true;
                }
                next.as_ref()
                    .is_some_and(|n| Self::find_task_id_in_continuation(n, task_id))
            }
            WorkflowContinuation::Fork { branches, join, .. } => {
                for branch in branches {
                    if Self::find_task_id_in_continuation(branch, task_id) {
                        return true;
                    }
                }
                if let Some(join_cont) = join {
                    Self::find_task_id_in_continuation(join_cont, task_id)
                } else {
                    false
                }
            }
            WorkflowContinuation::Branch {
                branches,
                default,
                next,
                ..
            } => {
                for branch_cont in branches.values() {
                    if Self::find_task_id_in_continuation(branch_cont, task_id) {
                        return true;
                    }
                }
                if let Some(def) = default
                    && Self::find_task_id_in_continuation(def, task_id)
                {
                    return true;
                }
                next.as_ref()
                    .is_some_and(|n| Self::find_task_id_in_continuation(n, task_id))
            }
            WorkflowContinuation::Loop { body, next, .. } => {
                if Self::find_task_id_in_continuation(body, task_id) {
                    return true;
                }
                next.as_ref()
                    .is_some_and(|n| Self::find_task_id_in_continuation(n, task_id))
            }
            WorkflowContinuation::ChildWorkflow { child, next, .. } => {
                if Self::find_task_id_in_continuation(child, task_id) {
                    return true;
                }
                next.as_ref()
                    .is_some_and(|n| Self::find_task_id_in_continuation(n, task_id))
            }
        }
    }

    #[allow(clippy::manual_async_fn)]
    fn execute_task_by_id<'a>(
        continuation: &'a WorkflowContinuation,
        task_id: &'a str,
        input: Bytes,
    ) -> impl std::future::Future<Output = Result<Bytes, crate::error::RuntimeError>> + Send + 'a
    {
        async move {
            let mut current = continuation;

            loop {
                match current {
                    WorkflowContinuation::Task { id, func, next, .. } => {
                        if id == task_id {
                            let func = func
                                .as_ref()
                                .ok_or_else(|| WorkflowError::TaskNotImplemented(id.clone()))?;
                            return Ok(func.run(input).await?);
                        } else if let Some(next_cont) = next {
                            current = next_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                    WorkflowContinuation::Delay { next, .. }
                    | WorkflowContinuation::AwaitSignal { next, .. } => {
                        if let Some(next_cont) = next {
                            current = next_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                    WorkflowContinuation::Fork { branches, join, .. } => {
                        let mut found_in_branch = false;
                        for branch in branches {
                            if Self::find_task_id_in_continuation(branch, task_id) {
                                current = branch;
                                found_in_branch = true;
                                break;
                            }
                        }
                        if found_in_branch {
                            continue;
                        }
                        if let Some(join_cont) = join {
                            current = join_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                    WorkflowContinuation::Branch {
                        branches,
                        default,
                        next,
                        ..
                    } => {
                        let mut found = false;
                        for branch_cont in branches.values() {
                            if Self::find_task_id_in_continuation(branch_cont, task_id) {
                                current = branch_cont;
                                found = true;
                                break;
                            }
                        }
                        if found {
                            continue;
                        }
                        if let Some(def) = default
                            && Self::find_task_id_in_continuation(def, task_id)
                        {
                            current = def;
                            continue;
                        }
                        if let Some(next_cont) = next {
                            current = next_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                    WorkflowContinuation::Loop { body, next, .. } => {
                        if Self::find_task_id_in_continuation(body, task_id) {
                            current = body;
                            continue;
                        }
                        if let Some(next_cont) = next {
                            current = next_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                    WorkflowContinuation::ChildWorkflow { child, next, .. } => {
                        if Self::find_task_id_in_continuation(child, task_id) {
                            current = child;
                            continue;
                        }
                        if let Some(next_cont) = next {
                            current = next_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                }
            }
        }
    }

    fn update_position_after_task(
        continuation: &WorkflowContinuation,
        completed_task_id: &str,
        snapshot: &mut WorkflowSnapshot,
    ) {
        match continuation {
            WorkflowContinuation::Task { id, next, .. }
            | WorkflowContinuation::Delay { id, next, .. }
            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
                if id == completed_task_id {
                    if let Some(next_cont) = next {
                        let hint = next_cont.first_task_hint();
                        snapshot.update_position(ExecutionPosition::AtTask {
                            task_id: hint.id.clone(),
                        });
                        snapshot.set_task_hint(&hint);
                    }
                } else if let Some(next_cont) = next {
                    Self::update_position_after_task(next_cont, completed_task_id, snapshot);
                }
            }
            WorkflowContinuation::Fork { branches, join, .. } => {
                for branch in branches {
                    Self::update_position_after_task(branch, completed_task_id, snapshot);
                }
                if let Some(join_cont) = join {
                    Self::update_position_after_task(join_cont, completed_task_id, snapshot);
                }
            }
            WorkflowContinuation::Branch {
                branches,
                default,
                next,
                ..
            } => {
                for branch_cont in branches.values() {
                    Self::update_position_after_task(branch_cont, completed_task_id, snapshot);
                }
                if let Some(def) = default {
                    Self::update_position_after_task(def, completed_task_id, snapshot);
                }
                if let Some(next_cont) = next {
                    Self::update_position_after_task(next_cont, completed_task_id, snapshot);
                }
            }
            WorkflowContinuation::Loop { body, next, .. } => {
                Self::update_position_after_task(body, completed_task_id, snapshot);
                if let Some(next_cont) = next {
                    Self::update_position_after_task(next_cont, completed_task_id, snapshot);
                }
            }
            WorkflowContinuation::ChildWorkflow { child, next, .. } => {
                Self::update_position_after_task(child, completed_task_id, snapshot);
                if let Some(next_cont) = next {
                    Self::update_position_after_task(next_cont, completed_task_id, snapshot);
                }
            }
        }
    }

    pub fn builder(backend: B, registry: TaskRegistry) -> PooledWorkerBuilder<B> {
        PooledWorkerBuilder {
            worker_id: None,
            backend,
            registry,
            claim_ttl: Some(Duration::from_secs(5 * 60)),
            batch_size: NonZeroUsize::MIN,
            aging_interval: Duration::from_secs(300),
            tags: vec![],
        }
    }

    async fn resolve_loop_completions(
        continuation: &WorkflowContinuation,
        snapshot: &mut WorkflowSnapshot,
        backend: &B,
    ) -> Result<(), crate::error::RuntimeError> {
        Self::resolve_loops_recursive(continuation, snapshot, backend).await
    }

    fn resolve_loops_recursive<'a>(
        continuation: &'a WorkflowContinuation,
        snapshot: &'a mut WorkflowSnapshot,
        backend: &'a B,
    ) -> Pin<
        Box<dyn std::future::Future<Output = Result<(), crate::error::RuntimeError>> + Send + 'a>,
    > {
        Box::pin(async move {
            match continuation {
                WorkflowContinuation::Loop {
                    id,
                    body,
                    max_iterations,
                    on_max,
                    next,
                } => {
                    if snapshot.get_task_result(id).is_none() {
                        let terminal_id = body.terminal_task_id();
                        if let Some(result) = snapshot.get_task_result(terminal_id) {
                            let output = result.output.clone();
                            match crate::execution::decode_loop_envelope(&output) {
                                Ok((LoopDecision::Done, inner)) => {
                                    snapshot.clear_loop_iteration(id);
                                    snapshot.mark_task_completed(id.clone(), inner);
                                    backend.save_snapshot(snapshot).await?;
                                }
                                Ok((LoopDecision::Again, again_value)) => {
                                    let current_iter = snapshot.loop_iteration(id);
                                    let next_iter = current_iter + 1;
                                    if next_iter >= *max_iterations {
                                        match on_max {
                                            sayiir_core::workflow::MaxIterationsPolicy::Fail => {
                                                return Err(WorkflowError::MaxIterationsExceeded {
                                                    loop_id: id.clone(),
                                                    max_iterations: *max_iterations,
                                                }
                                                .into());
                                            }
                                            sayiir_core::workflow::MaxIterationsPolicy::ExitWithLast => {
                                                snapshot.clear_loop_iteration(id);
                                                snapshot.mark_task_completed(
                                                    id.clone(),
                                                    again_value,
                                                );
                                                backend.save_snapshot(snapshot).await?;
                                            }
                                        }
                                    } else {
                                        let body_ser = body.to_serializable();
                                        for tid in &body_ser.task_ids() {
                                            snapshot.remove_task_result(tid);
                                        }
                                        snapshot.set_loop_iteration(id, next_iter);
                                        backend.save_snapshot(snapshot).await?;
                                    }
                                }
                                Err(e) => {
                                    return Err(CodecError::DecodeFailed {
                                        task_id: id.clone(),
                                        expected_type: "LoopEnvelope",
                                        source: e,
                                    }
                                    .into());
                                }
                            }
                        }
                    }
                    Self::resolve_loops_recursive(body, snapshot, backend).await?;
                    if let Some(next) = next {
                        Self::resolve_loops_recursive(next, snapshot, backend).await?;
                    }
                }
                WorkflowContinuation::Task { next, .. }
                | WorkflowContinuation::Delay { next, .. }
                | WorkflowContinuation::AwaitSignal { next, .. }
                | WorkflowContinuation::Branch { next, .. } => {
                    if let Some(next) = next {
                        Self::resolve_loops_recursive(next, snapshot, backend).await?;
                    }
                }
                WorkflowContinuation::Fork { branches, join, .. } => {
                    for branch in branches {
                        Self::resolve_loops_recursive(branch, snapshot, backend).await?;
                    }
                    if let Some(join) = join {
                        Self::resolve_loops_recursive(join, snapshot, backend).await?;
                    }
                }
                WorkflowContinuation::ChildWorkflow { child, next, .. } => {
                    Self::resolve_loops_recursive(child, snapshot, backend).await?;
                    if let Some(next) = next {
                        Self::resolve_loops_recursive(next, snapshot, backend).await?;
                    }
                }
            }
            Ok(())
        })
    }

    fn is_workflow_complete(
        continuation: &WorkflowContinuation,
        snapshot: &WorkflowSnapshot,
    ) -> bool {
        match continuation {
            WorkflowContinuation::Task { id, next, .. }
            | WorkflowContinuation::Delay { id, next, .. }
            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
                if snapshot.get_task_result(id).is_none() {
                    return false;
                }
                next.as_ref()
                    .is_none_or(|n| Self::is_workflow_complete(n, snapshot))
            }
            WorkflowContinuation::Fork { branches, join, .. } => {
                for branch in branches {
                    if !Self::is_workflow_complete(branch, snapshot) {
                        return false;
                    }
                }
                if let Some(join_cont) = join {
                    Self::is_workflow_complete(join_cont, snapshot)
                } else {
                    true
                }
            }
            WorkflowContinuation::Branch { id, next, .. } => {
                if snapshot.get_task_result(id).is_none() {
                    return false;
                }
                next.as_ref()
                    .is_none_or(|n| Self::is_workflow_complete(n, snapshot))
            }
            WorkflowContinuation::Loop { id, next, .. } => {
                if snapshot.get_task_result(id).is_none() {
                    return false;
                }
                next.as_ref()
                    .is_none_or(|n| Self::is_workflow_complete(n, snapshot))
            }
            WorkflowContinuation::ChildWorkflow { id, next, .. } => {
                if snapshot.get_task_result(id).is_none() {
                    return false;
                }
                next.as_ref()
                    .is_none_or(|n| Self::is_workflow_complete(n, snapshot))
            }
        }
    }
}

fn default_worker_id() -> String {
    let host = hostname::get()
        .ok()
        .and_then(|h| h.into_string().ok())
        .unwrap_or_else(|| "unknown".to_string());
    format!("{host}-{}", std::process::id())
}

pub struct PooledWorkerBuilder<B> {
    worker_id: Option<String>,
    backend: B,
    registry: TaskRegistry,
    claim_ttl: Option<Duration>,
    batch_size: NonZeroUsize,
    aging_interval: Duration,
    tags: Vec<String>,
}

impl<B> PooledWorkerBuilder<B>
where
    B: PersistentBackend + TaskClaimStore + 'static,
{
    #[must_use]
    pub fn worker_id(mut self, id: impl Into<String>) -> Self {
        self.worker_id = Some(id.into());
        self
    }

    #[must_use]
    pub fn claim_ttl(mut self, ttl: Option<Duration>) -> Self {
        self.claim_ttl = ttl;
        self
    }

    #[must_use]
    pub fn batch_size(mut self, size: NonZeroUsize) -> Self {
        self.batch_size = size;
        self
    }

    #[must_use]
    pub fn aging_interval(mut self, interval: Duration) -> Self {
        assert!(!interval.is_zero(), "aging interval must be non-zero");
        self.aging_interval = interval;
        self
    }

    #[must_use]
    pub fn tags(mut self, tags: Vec<String>) -> Self {
        self.tags = tags;
        self
    }

    #[must_use]
    pub fn build(self) -> PooledWorker<B> {
        let worker_id = self.worker_id.unwrap_or_else(default_worker_id);
        PooledWorker {
            worker_id,
            backend: Arc::new(self.backend),
            registry: Arc::new(self.registry),
            claim_ttl: self.claim_ttl,
            batch_size: self.batch_size,
            aging_interval: self.aging_interval,
            tags: self.tags,
        }
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::serialization::JsonCodec;
    use sayiir_core::registry::TaskRegistry;
    use sayiir_core::snapshot::WorkflowSnapshot;
    use sayiir_persistence::{InMemoryBackend, SignalStore, SnapshotStore};

    type EmptyWorkflows = WorkflowRegistry<JsonCodec, (), ()>;

    fn make_worker() -> PooledWorker<InMemoryBackend> {
        let backend = InMemoryBackend::new();
        let registry = TaskRegistry::new();
        PooledWorker::new("test-worker", backend, registry)
    }

    #[tokio::test]
    async fn test_spawn_and_shutdown() {
        let worker = make_worker();
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        handle.shutdown();

        let result = tokio::time::timeout(Duration::from_secs(5), handle.join()).await;
        assert!(result.is_ok(), "Worker should exit cleanly after shutdown");
        assert!(result.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_handle_is_clone_and_send() {
        let worker = make_worker();
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        let handle2 = handle.clone();
        let remote = tokio::spawn(async move {
            handle2.shutdown();
        });
        remote.await.ok();

        let result = tokio::time::timeout(Duration::from_secs(5), handle.join()).await;
        assert!(result.is_ok_and(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_cancel_via_client() {
        let backend = InMemoryBackend::new();
        let registry = TaskRegistry::new();

        let snapshot = WorkflowSnapshot::new("wf-1".to_string(), "hash-1".to_string());
        backend.save_snapshot(&snapshot).await.ok();

        let worker = PooledWorker::new("test-worker", backend, registry);
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        let client = crate::WorkflowClient::from_shared(std::sync::Arc::clone(handle.backend()));
        client
            .cancel(
                "wf-1",
                Some("test reason".to_string()),
                Some("tester".to_string()),
            )
            .await
            .ok();

        let signal = handle
            .backend()
            .get_signal("wf-1", SignalKind::Cancel)
            .await;
        assert!(signal.is_ok_and(|s| s.is_some()));

        handle.shutdown();
        tokio::time::timeout(Duration::from_secs(5), handle.join())
            .await
            .ok();
    }

    #[test]
    fn test_builder_auto_generates_worker_id() {
        let backend = InMemoryBackend::new();
        let registry = TaskRegistry::new();
        let worker = PooledWorker::builder(backend, registry).build();

        let pid = std::process::id().to_string();
        assert!(
            worker.worker_id.contains(&pid),
            "Auto-generated ID '{}' should contain PID '{}'",
            worker.worker_id,
            pid
        );
    }

    #[test]
    fn test_builder_explicit_worker_id() {
        let backend = InMemoryBackend::new();
        let registry = TaskRegistry::new();
        let worker = PooledWorker::builder(backend, registry)
            .worker_id("my-worker")
            .build();

        assert_eq!(worker.worker_id, "my-worker");
    }

    #[test]
    fn test_builder_custom_settings() {
        let backend = InMemoryBackend::new();
        let registry = TaskRegistry::new();
        let worker = PooledWorker::builder(backend, registry)
            .worker_id("w1")
            .claim_ttl(Some(Duration::from_secs(120)))
            .batch_size(NonZeroUsize::new(8).unwrap())
            .build();

        assert_eq!(worker.worker_id, "w1");
        assert_eq!(worker.claim_ttl, Some(Duration::from_secs(120)));
        assert_eq!(worker.batch_size.get(), 8);
    }

    #[tokio::test]
    async fn test_dropped_handle_shuts_down_worker() {
        let worker = make_worker();
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        let join_handle = handle.inner.join_handle.lock().await.take().unwrap();
        drop(handle);

        let result = tokio::time::timeout(Duration::from_secs(5), join_handle)
            .await
            .ok()
            .and_then(Result::ok);
        assert!(
            result.is_some(),
            "Worker should exit when all handles are dropped"
        );
        assert!(result.is_some_and(|r| r.is_ok()));
    }
}