Skip to main content

sayiir_runtime/
worker.rs

1//! Pooled worker for distributed, multi-worker workflow execution.
2//!
3//! A pooled worker is part of a worker pool that collaboratively executes workflows.
4//! Each worker polls the backend for available tasks, claims them (to prevent duplicates),
5//! executes them, and updates the snapshot. Multiple workers can process tasks from
6//! the same workflow instance in parallel.
7//!
8//! **Use this when**: You need horizontal scaling with multiple workers processing
9//! tasks concurrently across machines or processes.
10//!
11//! **Use [`CheckpointingRunner`](crate::runner::distributed::CheckpointingRunner) instead when**:
12//! You want a single process to run an entire workflow with crash recovery.
13
14use std::collections::HashMap;
15
16use bytes::Bytes;
17use chrono;
18use futures::FutureExt;
19use sayiir_core::codec::Codec;
20use sayiir_core::codec::sealed;
21use sayiir_core::context::with_context;
22use sayiir_core::error::{BoxError, WorkflowError};
23use sayiir_core::registry::TaskRegistry;
24use sayiir_core::snapshot::{
25    ExecutionPosition, SignalKind, SignalRequest, TaskDeadline, WorkflowSnapshot,
26};
27use sayiir_core::task_claim::AvailableTask;
28use sayiir_core::workflow::{Workflow, WorkflowContinuation, WorkflowStatus};
29use sayiir_persistence::{PersistentBackend, SignalStore, TaskClaimStore};
30use std::num::NonZeroUsize;
31use std::panic::AssertUnwindSafe;
32use std::pin::Pin;
33use std::sync::Arc;
34use std::time::Duration;
35use tokio::sync::mpsc;
36use tokio::time;
37
38/// A list of workflow definitions keyed by their definition hash.
39pub type WorkflowRegistry<C, Input, M> = Vec<(String, Arc<Workflow<C, Input, M>>)>;
40
41/// Workflow definition for binding-friendly worker API.
42///
43/// Contains only the structural information (definition hash + continuation tree)
44/// needed by `PooledWorker` for position tracking, completion detection, retry
45/// policies, and timeouts. Task execution is delegated to an external executor.
46pub struct ExternalWorkflow {
47    /// The continuation tree describing the workflow structure.
48    pub continuation: Arc<WorkflowContinuation>,
49}
50
51/// Workflow index keyed by definition hash for O(1) lookup during task dispatch.
52pub type WorkflowIndex = HashMap<String, ExternalWorkflow>;
53
54/// External task executor function signature.
55///
56/// Receives the task ID and input bytes, returns the output bytes.
57/// Used by language bindings (Python, Node.js) to delegate task execution
58/// to the host language's runtime.
59pub type ExternalTaskExecutor = Arc<
60    dyn Fn(
61            &str,
62            Bytes,
63        ) -> Pin<Box<dyn std::future::Future<Output = Result<Bytes, BoxError>> + Send>>
64        + Send
65        + Sync,
66>;
67
68/// Internal command sent from [`WorkerHandle`] to the actor loop.
69enum WorkerCommand {
70    Shutdown,
71}
72
73struct WorkerHandleInner<B> {
74    backend: Arc<B>,
75    shutdown_tx: mpsc::Sender<WorkerCommand>,
76    join_handle:
77        tokio::sync::Mutex<Option<tokio::task::JoinHandle<Result<(), crate::error::RuntimeError>>>>,
78}
79
80/// A cloneable handle for interacting with a running [`PooledWorker`].
81///
82/// Obtained from [`PooledWorker::spawn`]. The handle is cheap to clone and can
83/// be shared across tasks. Dropping **all** handles triggers a graceful
84/// shutdown of the worker (equivalent to calling [`shutdown`](Self::shutdown)).
85pub struct WorkerHandle<B> {
86    inner: Arc<WorkerHandleInner<B>>,
87}
88
89impl<B> Clone for WorkerHandle<B> {
90    fn clone(&self) -> Self {
91        Self {
92            inner: Arc::clone(&self.inner),
93        }
94    }
95}
96
97impl<B> WorkerHandle<B> {
98    /// Request a graceful shutdown of the worker.
99    ///
100    /// The worker will finish its current task (if any) and then exit.
101    /// This is a non-async, fire-and-forget operation — errors are ignored
102    /// (the actor may have already stopped).
103    pub fn shutdown(&self) {
104        let _ = self.inner.shutdown_tx.try_send(WorkerCommand::Shutdown);
105    }
106
107    /// Wait for the worker task to finish.
108    ///
109    /// The first caller gets the real result; subsequent callers get `Ok(())`.
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if the worker task panicked or returned an error.
114    pub async fn join(&self) -> Result<(), crate::error::RuntimeError> {
115        let jh = self.inner.join_handle.lock().await.take();
116        match jh {
117            Some(jh) => Ok(jh.await??),
118            None => Ok(()),
119        }
120    }
121
122    /// Get a reference to the backend.
123    #[must_use]
124    pub fn backend(&self) -> &Arc<B> {
125        &self.inner.backend
126    }
127}
128
129impl<B: SignalStore> WorkerHandle<B> {
130    /// Request cancellation of a workflow.
131    ///
132    /// This stores a cancel signal directly in the backend. The worker will
133    /// pick it up at the next guard check (task boundary).
134    ///
135    /// # Errors
136    ///
137    /// Returns an error if the signal cannot be stored (workflow not found or in terminal state).
138    pub async fn cancel_workflow(
139        &self,
140        instance_id: &str,
141        reason: Option<String>,
142        cancelled_by: Option<String>,
143    ) -> Result<(), crate::error::RuntimeError> {
144        self.inner
145            .backend
146            .store_signal(
147                instance_id,
148                SignalKind::Cancel,
149                SignalRequest::new(reason, cancelled_by),
150            )
151            .await?;
152        Ok(())
153    }
154
155    /// Request pausing of a workflow.
156    ///
157    /// This stores a pause signal directly in the backend. The worker will
158    /// pick it up at the next guard check (task boundary).
159    ///
160    /// # Errors
161    ///
162    /// Returns an error if the signal cannot be stored (workflow not found or in terminal/paused state).
163    pub async fn pause_workflow(
164        &self,
165        instance_id: &str,
166        reason: Option<String>,
167        paused_by: Option<String>,
168    ) -> Result<(), crate::error::RuntimeError> {
169        self.inner
170            .backend
171            .store_signal(
172                instance_id,
173                SignalKind::Pause,
174                SignalRequest::new(reason, paused_by),
175            )
176            .await?;
177        Ok(())
178    }
179}
180
181/// Owns a claimed task and provides explicit release methods.
182///
183/// No `Drop` impl — callers must explicitly call `release()` or `release_quietly()`.
184struct ActiveTaskClaim<'a, B> {
185    backend: &'a B,
186    instance_id: String,
187    task_id: String,
188    worker_id: String,
189}
190
191impl<B: TaskClaimStore> ActiveTaskClaim<'_, B> {
192    /// Release the claim, propagating backend errors.
193    async fn release(self) -> Result<(), crate::error::RuntimeError> {
194        self.backend
195            .release_task_claim(&self.instance_id, &self.task_id, &self.worker_id)
196            .await?;
197        Ok(())
198    }
199
200    /// Release the claim, silently ignoring errors. Use for error/panic paths.
201    async fn release_quietly(self) {
202        let _ = self.release().await;
203    }
204}
205
206/// Outcome of running a task through `execute_with_deadline`.
207enum ExecutionOutcome {
208    /// Task completed successfully.
209    Success(Bytes),
210    /// Task execution returned an error.
211    TaskError(crate::error::RuntimeError),
212    /// Task panicked.
213    Panic(Box<dyn std::any::Any + Send>),
214    /// Heartbeat detected an expired deadline (active cancellation).
215    Timeout(crate::error::RuntimeError),
216}
217
218/// Extract a human-readable message from a panic payload.
219fn extract_panic_message(payload: &Box<dyn std::any::Any + Send>) -> String {
220    if let Some(s) = payload.downcast_ref::<&str>() {
221        s.to_string()
222    } else if let Some(s) = payload.downcast_ref::<String>() {
223        s.clone()
224    } else {
225        "Task panicked with unknown payload".to_string()
226    }
227}
228
229/// A pooled worker that claims and executes tasks from a shared backend.
230///
231/// `PooledWorker` is designed for horizontal scaling: multiple workers can run
232/// across different machines/processes, all polling the same backend for tasks.
233/// Task claiming with TTL prevents duplicate execution while allowing automatic
234/// recovery when workers crash.
235///
236/// # When to Use
237///
238/// - **Horizontal scaling**: Multiple workers process tasks concurrently
239/// - **Fault tolerance**: Failed workers' tasks are automatically reclaimed
240/// - **Load balancing**: Tasks distributed across available workers
241///
242/// For single-process execution with checkpointing, use
243/// [`CheckpointingRunner`](crate::runner::distributed::CheckpointingRunner).
244///
245/// # Example
246///
247/// ```rust,no_run
248/// # use sayiir_runtime::prelude::*;
249/// # use std::sync::Arc;
250/// # use std::time::Duration;
251/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
252/// let backend = InMemoryBackend::new();
253/// let registry = TaskRegistry::new();
254/// let worker = PooledWorker::new("worker-1", backend, registry);
255///
256/// let ctx = WorkflowContext::new("my-wf", Arc::new(JsonCodec), Arc::new(()));
257/// let workflow = WorkflowBuilder::new(ctx)
258///     .then("step1", |i: u32| async move { Ok(i + 1) })
259///     .build()?;
260/// let workflows = vec![(workflow.definition_hash().to_string(), Arc::new(workflow))];
261///
262/// // Spawn the worker and get a handle for lifecycle control
263/// let handle = worker.spawn(Duration::from_secs(1), workflows);
264/// // ... later ...
265/// handle.shutdown();
266/// handle.join().await?;
267/// # Ok(())
268/// # }
269/// ```
270pub struct PooledWorker<B> {
271    worker_id: String,
272    backend: Arc<B>,
273    #[allow(unused)]
274    registry: Arc<TaskRegistry>,
275    claim_ttl: Option<Duration>,
276    batch_size: NonZeroUsize,
277}
278
279impl<B> PooledWorker<B>
280where
281    B: PersistentBackend + TaskClaimStore + 'static,
282{
283    /// Create a new worker node.
284    ///
285    /// # Parameters
286    ///
287    /// - `worker_id`: Unique identifier for this worker node
288    /// - `backend`: The persistent backend to use
289    /// - `registry`: Task registry containing all task implementations
290    ///
291    /// # Heartbeat
292    ///
293    /// Heartbeats are derived automatically from `claim_ttl` (TTL / 2).
294    /// With the default 5-minute TTL, heartbeats fire every 2.5 minutes.
295    ///
296    pub fn new(worker_id: impl Into<String>, backend: B, registry: TaskRegistry) -> Self {
297        Self {
298            worker_id: worker_id.into(),
299            backend: Arc::new(backend),
300            registry: Arc::new(registry),
301            claim_ttl: Some(Duration::from_secs(5 * 60)), // Default 5 minutes
302            batch_size: NonZeroUsize::MIN,                // Default: fetch one task at a time (1)
303        }
304    }
305
306    /// Set the TTL for task claims.
307    #[must_use]
308    pub fn with_claim_ttl(mut self, ttl: Option<Duration>) -> Self {
309        self.claim_ttl = ttl;
310        self
311    }
312
313    /// Set the number of tasks to fetch per poll (default: 1).
314    ///
315    /// With `batch_size=1`, the worker fetches one task, executes it, then polls again.
316    /// Other workers can pick up remaining tasks immediately.
317    ///
318    /// Higher values reduce polling overhead but may cause workers to hold task IDs
319    /// they won't process immediately (though other workers can still claim them).
320    #[must_use]
321    pub fn with_batch_size(mut self, size: NonZeroUsize) -> Self {
322        self.batch_size = size;
323        self
324    }
325
326    /// Request cancellation of a workflow.
327    ///
328    /// This requests cancellation of the specified workflow instance.
329    /// Running tasks will complete, but no new tasks will be started.
330    ///
331    /// # Parameters
332    ///
333    /// - `instance_id`: The workflow instance ID to cancel
334    /// - `reason`: Optional reason for the cancellation
335    /// - `cancelled_by`: Optional identifier of who requested the cancellation
336    ///
337    /// # Errors
338    ///
339    /// Returns an error if the workflow cannot be cancelled (not found or in terminal state).
340    pub async fn cancel_workflow(
341        &self,
342        instance_id: &str,
343        reason: Option<String>,
344        cancelled_by: Option<String>,
345    ) -> Result<(), crate::error::RuntimeError> {
346        self.backend
347            .store_signal(
348                instance_id,
349                SignalKind::Cancel,
350                SignalRequest::new(reason, cancelled_by),
351            )
352            .await?;
353
354        Ok(())
355    }
356
357    /// Request pausing of a workflow.
358    ///
359    /// This requests pausing of the specified workflow instance.
360    /// Running tasks will complete, but no new tasks will be started.
361    ///
362    /// # Parameters
363    ///
364    /// - `instance_id`: The workflow instance ID to pause
365    /// - `reason`: Optional reason for the pause
366    /// - `paused_by`: Optional identifier of who requested the pause
367    ///
368    /// # Errors
369    ///
370    /// Returns an error if the workflow cannot be paused (not found or in terminal/paused state).
371    pub async fn pause_workflow(
372        &self,
373        instance_id: &str,
374        reason: Option<String>,
375        paused_by: Option<String>,
376    ) -> Result<(), crate::error::RuntimeError> {
377        self.backend
378            .store_signal(
379                instance_id,
380                SignalKind::Pause,
381                SignalRequest::new(reason, paused_by),
382            )
383            .await?;
384
385        Ok(())
386    }
387
388    /// Get a reference to the backend.
389    #[must_use]
390    pub fn backend(&self) -> &Arc<B> {
391        &self.backend
392    }
393
394    /// Spawn the worker as a background task and return a handle.
395    ///
396    /// Consumes `self`, creates an internal command channel, and spawns the
397    /// actor loop on the Tokio runtime. Returns a cloneable [`WorkerHandle`]
398    /// for lifecycle control — call [`WorkerHandle::shutdown`] to request
399    /// graceful shutdown and [`WorkerHandle::wait`] to await completion.
400    ///
401    /// The worker runs until:
402    /// - [`WorkerHandle::shutdown`] is called, or
403    /// - All clones of the handle are dropped, or
404    /// - A fatal backend error occurs.
405    ///
406    /// # Parameters
407    ///
408    /// - `poll_interval`: How often to poll for new tasks
409    /// - `workflows`: Map of workflow definition hash to workflow
410    #[must_use]
411    pub fn spawn<C, Input, M>(
412        self,
413        poll_interval: Duration,
414        workflows: WorkflowRegistry<C, Input, M>,
415    ) -> WorkerHandle<B>
416    where
417        Input: Send + Sync + 'static,
418        M: Send + Sync + 'static,
419        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
420    {
421        let (tx, rx) = mpsc::channel(1);
422        let backend = Arc::clone(&self.backend);
423        let join_handle =
424            tokio::spawn(async move { self.run_actor_loop(poll_interval, workflows, rx).await });
425        WorkerHandle {
426            inner: Arc::new(WorkerHandleInner {
427                backend,
428                shutdown_tx: tx,
429                join_handle: tokio::sync::Mutex::new(Some(join_handle)),
430            }),
431        }
432    }
433
434    /// Spawn the worker with an external executor and return a handle.
435    ///
436    /// Like [`spawn`](Self::spawn) but instead of executing tasks via typed
437    /// `Workflow` closures, delegates all task execution to the provided
438    /// `executor`. This is used by language bindings (Python, Node.js) where
439    /// task functions live in the host language.
440    ///
441    /// # Parameters
442    ///
443    /// - `poll_interval`: How often to poll for new tasks
444    /// - `workflows`: Workflow definitions (hash + continuation tree)
445    /// - `executor`: Closure that executes a task by ID given input bytes
446    #[must_use]
447    pub fn spawn_with_executor(
448        self,
449        poll_interval: Duration,
450        workflows: WorkflowIndex,
451        executor: ExternalTaskExecutor,
452    ) -> WorkerHandle<B> {
453        let (tx, rx) = mpsc::channel(1);
454        let backend = Arc::clone(&self.backend);
455        let join_handle = tokio::spawn(async move {
456            self.run_external_actor_loop(poll_interval, workflows, executor, rx)
457                .await
458        });
459        WorkerHandle {
460            inner: Arc::new(WorkerHandleInner {
461                backend,
462                shutdown_tx: tx,
463                join_handle: tokio::sync::Mutex::new(Some(join_handle)),
464            }),
465        }
466    }
467
468    /// Actor loop for external executor mode.
469    async fn run_external_actor_loop(
470        &self,
471        poll_interval: Duration,
472        workflows: WorkflowIndex,
473        executor: ExternalTaskExecutor,
474        mut cmd_rx: mpsc::Receiver<WorkerCommand>,
475    ) -> Result<(), crate::error::RuntimeError> {
476        let mut interval = time::interval(poll_interval);
477
478        loop {
479            tokio::select! {
480                biased;
481
482                cmd = cmd_rx.recv() => {
483                    match cmd {
484                        Some(WorkerCommand::Shutdown) | None => {
485                            tracing::info!(worker_id = %self.worker_id, "Worker shutting down");
486                            return Ok(());
487                        }
488                    }
489                }
490
491                _ = interval.tick() => {
492                    tracing::trace!(worker_id = %self.worker_id, "Will poll for available tasks");
493                }
494            }
495
496            let available_tasks = self
497                .backend
498                .find_available_tasks(&self.worker_id, self.batch_size.get())
499                .await?;
500
501            for task in available_tasks {
502                if let Ok(WorkerCommand::Shutdown) | Err(mpsc::error::TryRecvError::Disconnected) =
503                    cmd_rx.try_recv()
504                {
505                    tracing::info!(worker_id = %self.worker_id, "Worker shutting down mid-batch");
506                    return Ok(());
507                }
508
509                if let Some(ext_wf) = workflows.get(&task.workflow_definition_hash) {
510                    match self
511                        .execute_external_task(
512                            &ext_wf.continuation,
513                            &task.workflow_definition_hash,
514                            &executor,
515                            &task,
516                        )
517                        .await
518                    {
519                        Err(ref e) if e.is_timeout() => {
520                            tracing::error!(
521                                worker_id = %self.worker_id,
522                                error = %e,
523                                "Task timed out — worker shutting down"
524                            );
525                            return Ok(());
526                        }
527                        Ok(_) => {
528                            tracing::info!("Worker {} completed a task", self.worker_id);
529                        }
530                        Err(e) => {
531                            tracing::error!(
532                                "Worker {} task execution failed: {}",
533                                self.worker_id,
534                                e
535                            );
536                        }
537                    }
538                }
539            }
540        }
541    }
542
543    /// Execute a single task using an external executor.
544    async fn execute_external_task(
545        &self,
546        continuation: &WorkflowContinuation,
547        definition_hash: &str,
548        executor: &ExternalTaskExecutor,
549        available_task: &AvailableTask,
550    ) -> Result<WorkflowStatus, crate::error::RuntimeError> {
551        let mut snapshot = self
552            .backend
553            .load_snapshot(&available_task.instance_id)
554            .await?;
555        let already_completed = Self::validate_task_preconditions(
556            definition_hash,
557            continuation,
558            available_task,
559            &snapshot,
560        )?;
561        if already_completed {
562            return Ok(WorkflowStatus::InProgress);
563        }
564
565        let Some(claim) = self.claim_task(available_task).await? else {
566            return Ok(WorkflowStatus::InProgress);
567        };
568
569        if let Some(status) = self.check_post_claim_guards(available_task).await? {
570            claim.release_quietly().await;
571            return Ok(status);
572        }
573
574        tracing::debug!(
575            instance_id = %available_task.instance_id,
576            task_id = %available_task.task_id,
577            "Executing task (external)"
578        );
579
580        let execution_result = self
581            .execute_with_deadline_ext(
582                continuation,
583                executor,
584                available_task,
585                &mut snapshot,
586                &claim,
587            )
588            .await;
589
590        self.settle_execution_result_ext(
591            execution_result,
592            continuation,
593            available_task,
594            &mut snapshot,
595            claim,
596        )
597        .await
598    }
599
600    /// Run the external executor with an optional deadline.
601    async fn execute_with_deadline_ext(
602        &self,
603        continuation: &WorkflowContinuation,
604        executor: &ExternalTaskExecutor,
605        available_task: &AvailableTask,
606        snapshot: &mut WorkflowSnapshot,
607        claim: &ActiveTaskClaim<'_, B>,
608    ) -> ExecutionOutcome {
609        let task_id = available_task.task_id.clone();
610        let input = available_task.input.clone();
611
612        let deadline = if let Some(timeout) = continuation.get_task_timeout(&task_id) {
613            snapshot.set_task_deadline(task_id.clone(), timeout);
614            let _ = self.backend.save_snapshot(snapshot).await;
615            snapshot.refresh_task_deadline();
616            snapshot.task_deadline.clone()
617        } else {
618            None
619        };
620
621        let execution_future = executor(&task_id, input);
622
623        let heartbeat_result = self
624            .run_with_heartbeat(
625                claim,
626                deadline.as_ref(),
627                AssertUnwindSafe(execution_future).catch_unwind(),
628            )
629            .await;
630
631        snapshot.clear_task_deadline();
632
633        match heartbeat_result {
634            Err(timeout_err) => ExecutionOutcome::Timeout(timeout_err),
635            Ok(Err(panic_payload)) => ExecutionOutcome::Panic(panic_payload),
636            Ok(Ok(Err(e))) => ExecutionOutcome::TaskError(e.into()),
637            Ok(Ok(Ok(output))) => ExecutionOutcome::Success(output),
638        }
639    }
640
641    /// Settle execution result for external executor mode.
642    async fn settle_execution_result_ext(
643        &self,
644        outcome: ExecutionOutcome,
645        continuation: &WorkflowContinuation,
646        available_task: &AvailableTask,
647        snapshot: &mut WorkflowSnapshot,
648        claim: ActiveTaskClaim<'_, B>,
649    ) -> Result<WorkflowStatus, crate::error::RuntimeError> {
650        match outcome {
651            ExecutionOutcome::Timeout(err) => {
652                if let Ok(Some(status)) = self
653                    .try_schedule_retry(continuation, available_task, snapshot, &err.to_string())
654                    .await
655                {
656                    claim.release_quietly().await;
657                    return Ok(status);
658                }
659
660                tracing::warn!(
661                    instance_id = %available_task.instance_id,
662                    task_id = %available_task.task_id,
663                    error = %err,
664                    "Task timed out via heartbeat — marking workflow failed, shutting down"
665                );
666                snapshot.mark_failed(err.to_string());
667                let _ = self.backend.save_snapshot(snapshot).await;
668                claim.release_quietly().await;
669                Err(err)
670            }
671            ExecutionOutcome::Panic(panic_payload) => {
672                let panic_msg = extract_panic_message(&panic_payload);
673
674                if let Ok(Some(status)) = self
675                    .try_schedule_retry(continuation, available_task, snapshot, &panic_msg)
676                    .await
677                {
678                    claim.release_quietly().await;
679                    return Ok(status);
680                }
681
682                tracing::error!(
683                    instance_id = %available_task.instance_id,
684                    task_id = %available_task.task_id,
685                    panic = %panic_msg,
686                    "Task panicked - releasing claim"
687                );
688                claim.release_quietly().await;
689                Err(WorkflowError::TaskPanicked(panic_msg).into())
690            }
691            ExecutionOutcome::TaskError(e) => {
692                if let Ok(Some(status)) = self
693                    .try_schedule_retry(continuation, available_task, snapshot, &e.to_string())
694                    .await
695                {
696                    claim.release_quietly().await;
697                    return Ok(status);
698                }
699
700                tracing::error!(
701                    instance_id = %available_task.instance_id,
702                    task_id = %available_task.task_id,
703                    error = %e,
704                    "Task execution failed"
705                );
706                claim.release_quietly().await;
707                Err(e)
708            }
709            ExecutionOutcome::Success(output) => {
710                snapshot.clear_retry_state(&available_task.task_id);
711                self.commit_task_result(
712                    continuation,
713                    available_task,
714                    snapshot,
715                    output.clone(),
716                    claim,
717                )
718                .await?;
719                self.determine_post_task_status(continuation, available_task, snapshot, output)
720                    .await
721            }
722        }
723    }
724
725    /// The actor loop: poll for tasks, execute them, respond to shutdown.
726    ///
727    async fn run_actor_loop<C, Input, M>(
728        &self,
729        poll_interval: Duration,
730        workflows: WorkflowRegistry<C, Input, M>,
731        mut cmd_rx: mpsc::Receiver<WorkerCommand>,
732    ) -> Result<(), crate::error::RuntimeError>
733    where
734        Input: Send + 'static,
735        M: Send + Sync + 'static,
736        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
737    {
738        let mut interval = time::interval(poll_interval);
739
740        loop {
741            tokio::select! {
742                biased;
743
744                cmd = cmd_rx.recv() => {
745                    // None (all handles dropped) or Some(Shutdown) → exit
746                    match cmd {
747                        Some(WorkerCommand::Shutdown) | None => {
748                            tracing::info!(worker_id = %self.worker_id, "Worker shutting down");
749                            return Ok(());
750                        }
751                    }
752                }
753
754                _ = interval.tick() => {
755                    tracing::trace!(worker_id = %self.worker_id, "Will poll for available tasks");
756                }
757            }
758
759            let available_tasks = self
760                .backend
761                .find_available_tasks(&self.worker_id, self.batch_size.get())
762                .await?;
763
764            for task in available_tasks {
765                if let Ok(WorkerCommand::Shutdown) | Err(mpsc::error::TryRecvError::Disconnected) =
766                    cmd_rx.try_recv()
767                {
768                    tracing::info!(worker_id = %self.worker_id, "Worker shutting down mid-batch");
769                    return Ok(());
770                }
771
772                if let Some((_, workflow)) = workflows
773                    .iter()
774                    .find(|(hash, _)| *hash == task.workflow_definition_hash)
775                {
776                    match self.execute_task(workflow.as_ref(), task).await {
777                        Err(ref e) if e.is_timeout() => {
778                            tracing::error!(
779                                worker_id = %self.worker_id,
780                                error = %e,
781                                "Task timed out — worker shutting down"
782                            );
783                            return Ok(());
784                        }
785                        Ok(_) => {
786                            tracing::info!("Worker {} completed a task", self.worker_id);
787                        }
788                        Err(e) => {
789                            tracing::error!(
790                                "Worker {} task execution failed: {}",
791                                self.worker_id,
792                                e
793                            );
794                        }
795                    }
796                }
797            }
798        }
799    }
800
801    /// Load cancellation status from a snapshot.
802    ///
803    /// Attempts to load the snapshot and extract cancellation details.
804    /// Returns `WorkflowStatus::Cancelled` with either the extracted details or defaults.
805    async fn load_cancelled_status(&self, instance_id: &str) -> WorkflowStatus {
806        if let Ok(snapshot) = self.backend.load_snapshot(instance_id).await
807            && let Some((reason, cancelled_by)) = snapshot.state.cancellation_details()
808        {
809            return WorkflowStatus::Cancelled {
810                reason,
811                cancelled_by,
812            };
813        }
814        WorkflowStatus::Cancelled {
815            reason: None,
816            cancelled_by: None,
817        }
818    }
819
820    /// Load paused status from a snapshot.
821    ///
822    /// Attempts to load the snapshot and extract pause details.
823    /// Returns `WorkflowStatus::Paused` with either the extracted details or defaults.
824    async fn load_paused_status(&self, instance_id: &str) -> WorkflowStatus {
825        if let Ok(snapshot) = self.backend.load_snapshot(instance_id).await
826            && let Some((reason, paused_by)) = snapshot.state.pause_details()
827        {
828            return WorkflowStatus::Paused { reason, paused_by };
829        }
830        WorkflowStatus::Paused {
831            reason: None,
832            paused_by: None,
833        }
834    }
835
836    /// Execute a single task from an available task.
837    ///
838    /// This claims the task, executes it, updates the snapshot, and releases the claim.
839    ///
840    /// # Errors
841    ///
842    /// Returns an error if:
843    /// - The task cannot be claimed
844    /// - The workflow definition hash doesn't match
845    /// - Task execution fails
846    /// - Snapshot update fails
847    pub async fn execute_task<C, Input, M>(
848        &self,
849        workflow: &Workflow<C, Input, M>,
850        available_task: AvailableTask,
851    ) -> Result<WorkflowStatus, crate::error::RuntimeError>
852    where
853        Input: Send + 'static,
854        M: Send + Sync + 'static,
855        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
856    {
857        // 1. Load snapshot + pure validation
858        let mut snapshot = self
859            .backend
860            .load_snapshot(&available_task.instance_id)
861            .await?;
862        let already_completed = Self::validate_task_preconditions(
863            workflow.definition_hash(),
864            workflow.continuation(),
865            &available_task,
866            &snapshot,
867        )?;
868        if already_completed {
869            return Ok(WorkflowStatus::InProgress);
870        }
871
872        let Some(claim) = self.claim_task(&available_task).await? else {
873            return Ok(WorkflowStatus::InProgress);
874        };
875
876        // 3. Post-claim guards (cancel/pause)
877        if let Some(status) = self.check_post_claim_guards(&available_task).await? {
878            claim.release_quietly().await;
879            return Ok(status);
880        }
881
882        tracing::debug!(
883            instance_id = %available_task.instance_id,
884            task_id = %available_task.task_id,
885            "Executing task"
886        );
887
888        // 4. Execute with deadline + heartbeat, then settle the result
889        let execution_result = self
890            .execute_with_deadline(workflow, &available_task, &mut snapshot, &claim)
891            .await;
892
893        self.settle_execution_result(
894            execution_result,
895            workflow,
896            &available_task,
897            &mut snapshot,
898            claim,
899        )
900        .await
901    }
902
903    /// Run the task future with an optional deadline, returning the panic-wrapped result.
904    ///
905    /// Sets a deadline on the snapshot (if the task has a timeout), persists it,
906    /// then runs the future inside `run_with_heartbeat`. On heartbeat-level timeout
907    /// the task future is dropped and an `Err` is returned.
908    async fn execute_with_deadline<C, Input, M>(
909        &self,
910        workflow: &Workflow<C, Input, M>,
911        available_task: &AvailableTask,
912        snapshot: &mut WorkflowSnapshot,
913        claim: &ActiveTaskClaim<'_, B>,
914    ) -> ExecutionOutcome
915    where
916        Input: Send + 'static,
917        M: Send + Sync + 'static,
918        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
919    {
920        let continuation = workflow.continuation();
921        let task_id = available_task.task_id.clone();
922        let input = available_task.input.clone();
923
924        // Set deadline if task has a timeout configured
925        let deadline = if let Some(timeout) = continuation.get_task_timeout(&task_id) {
926            snapshot.set_task_deadline(task_id.clone(), timeout);
927            let _ = self.backend.save_snapshot(snapshot).await;
928            // Refresh deadline to now + timeout so it measures actual execution
929            // time, not time spent on the snapshot save above.
930            snapshot.refresh_task_deadline();
931            snapshot.task_deadline.clone()
932        } else {
933            None
934        };
935
936        let context = workflow.context().clone();
937        let execution_future = with_context(context, || async move {
938            Self::execute_task_by_id(continuation, &task_id, input).await
939        });
940
941        let heartbeat_result = self
942            .run_with_heartbeat(
943                claim,
944                deadline.as_ref(),
945                AssertUnwindSafe(execution_future).catch_unwind(),
946            )
947            .await;
948
949        snapshot.clear_task_deadline();
950
951        match heartbeat_result {
952            Err(timeout_err) => ExecutionOutcome::Timeout(timeout_err),
953            Ok(Err(panic_payload)) => ExecutionOutcome::Panic(panic_payload),
954            Ok(Ok(Err(e))) => ExecutionOutcome::TaskError(e),
955            Ok(Ok(Ok(output))) => ExecutionOutcome::Success(output),
956        }
957    }
958
959    /// Try to schedule a retry for a failed task.
960    ///
961    /// Looks up the retry policy on the continuation. If retries are available,
962    /// records the retry state on the snapshot, clears the deadline, saves the
963    /// snapshot, releases the claim, and returns `Ok(Some(InProgress))`.
964    /// Otherwise returns `Ok(None)` (caller falls through to existing error handling).
965    async fn try_schedule_retry(
966        &self,
967        continuation: &WorkflowContinuation,
968        available_task: &AvailableTask,
969        snapshot: &mut WorkflowSnapshot,
970        error_msg: &str,
971    ) -> Result<Option<WorkflowStatus>, crate::error::RuntimeError> {
972        let Some(policy) = continuation.get_task_retry_policy(&available_task.task_id) else {
973            return Ok(None);
974        };
975
976        if snapshot.retries_exhausted(&available_task.task_id) {
977            return Ok(None);
978        }
979
980        let next_retry_at = snapshot.record_retry(
981            &available_task.task_id,
982            policy,
983            error_msg,
984            Some(&self.worker_id),
985        );
986        snapshot.clear_task_deadline();
987        let _ = self.backend.save_snapshot(snapshot).await;
988
989        tracing::info!(
990            instance_id = %available_task.instance_id,
991            task_id = %available_task.task_id,
992            attempt = snapshot.get_retry_state(&available_task.task_id).map_or(0, |rs| rs.attempts),
993            max_retries = policy.max_retries,
994            %next_retry_at,
995            "Scheduling retry"
996        );
997
998        Ok(Some(WorkflowStatus::InProgress))
999    }
1000
1001    /// Settle the outcome of task execution: persist results or errors, release claim.
1002    async fn settle_execution_result<C, Input, M>(
1003        &self,
1004        outcome: ExecutionOutcome,
1005        workflow: &Workflow<C, Input, M>,
1006        available_task: &AvailableTask,
1007        snapshot: &mut WorkflowSnapshot,
1008        claim: ActiveTaskClaim<'_, B>,
1009    ) -> Result<WorkflowStatus, crate::error::RuntimeError>
1010    where
1011        Input: Send + 'static,
1012        M: Send + Sync + 'static,
1013        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
1014    {
1015        match outcome {
1016            ExecutionOutcome::Timeout(err) => {
1017                if let Ok(Some(status)) = self
1018                    .try_schedule_retry(
1019                        workflow.continuation(),
1020                        available_task,
1021                        snapshot,
1022                        &err.to_string(),
1023                    )
1024                    .await
1025                {
1026                    claim.release_quietly().await;
1027                    return Ok(status);
1028                }
1029
1030                tracing::warn!(
1031                    instance_id = %available_task.instance_id,
1032                    task_id = %available_task.task_id,
1033                    error = %err,
1034                    "Task timed out via heartbeat — marking workflow failed, shutting down"
1035                );
1036                snapshot.mark_failed(err.to_string());
1037                let _ = self.backend.save_snapshot(snapshot).await;
1038                claim.release_quietly().await;
1039                Err(err)
1040            }
1041            ExecutionOutcome::Panic(panic_payload) => {
1042                let panic_msg = extract_panic_message(&panic_payload);
1043
1044                if let Ok(Some(status)) = self
1045                    .try_schedule_retry(
1046                        workflow.continuation(),
1047                        available_task,
1048                        snapshot,
1049                        &panic_msg,
1050                    )
1051                    .await
1052                {
1053                    claim.release_quietly().await;
1054                    return Ok(status);
1055                }
1056
1057                tracing::error!(
1058                    instance_id = %available_task.instance_id,
1059                    task_id = %available_task.task_id,
1060                    panic = %panic_msg,
1061                    "Task panicked - releasing claim"
1062                );
1063                claim.release_quietly().await;
1064                Err(WorkflowError::TaskPanicked(panic_msg).into())
1065            }
1066            ExecutionOutcome::TaskError(e) => {
1067                if let Ok(Some(status)) = self
1068                    .try_schedule_retry(
1069                        workflow.continuation(),
1070                        available_task,
1071                        snapshot,
1072                        &e.to_string(),
1073                    )
1074                    .await
1075                {
1076                    claim.release_quietly().await;
1077                    return Ok(status);
1078                }
1079
1080                tracing::error!(
1081                    instance_id = %available_task.instance_id,
1082                    task_id = %available_task.task_id,
1083                    error = %e,
1084                    "Task execution failed"
1085                );
1086                claim.release_quietly().await;
1087                Err(e)
1088            }
1089            ExecutionOutcome::Success(output) => {
1090                snapshot.clear_retry_state(&available_task.task_id);
1091                self.commit_task_result(
1092                    workflow.continuation(),
1093                    available_task,
1094                    snapshot,
1095                    output.clone(),
1096                    claim,
1097                )
1098                .await?;
1099                self.determine_post_task_status(
1100                    workflow.continuation(),
1101                    available_task,
1102                    snapshot,
1103                    output,
1104                )
1105                .await
1106            }
1107        }
1108    }
1109
1110    /// Validate task preconditions without side effects.
1111    ///
1112    /// Checks definition hash match, task existence in continuation,
1113    /// and that the task is not already completed in the snapshot.
1114    /// Returns `Ok(true)` if the task should be skipped (already completed).
1115    fn validate_task_preconditions(
1116        definition_hash: &str,
1117        continuation: &WorkflowContinuation,
1118        available_task: &AvailableTask,
1119        snapshot: &WorkflowSnapshot,
1120    ) -> Result<bool, crate::error::RuntimeError> {
1121        if available_task.workflow_definition_hash != definition_hash {
1122            return Err(WorkflowError::DefinitionMismatch {
1123                expected: definition_hash.to_string(),
1124                found: available_task.workflow_definition_hash.clone(),
1125            }
1126            .into());
1127        }
1128
1129        if !Self::find_task_id_in_continuation(continuation, &available_task.task_id) {
1130            tracing::error!(
1131                instance_id = %available_task.instance_id,
1132                task_id = %available_task.task_id,
1133                "Task does not exist in workflow"
1134            );
1135            return Err(WorkflowError::TaskNotFound(available_task.task_id.clone()).into());
1136        }
1137
1138        if snapshot.get_task_result(&available_task.task_id).is_some() {
1139            tracing::debug!(
1140                instance_id = %available_task.instance_id,
1141                task_id = %available_task.task_id,
1142                "Task already completed, skipping"
1143            );
1144            return Ok(true);
1145        }
1146
1147        Ok(false)
1148    }
1149
1150    /// Acquire a claim on the task, returning an `ActiveTaskClaim`.
1151    ///
1152    /// Returns `None` if already claimed by another worker.
1153    async fn claim_task(
1154        &self,
1155        available_task: &AvailableTask,
1156    ) -> Result<Option<ActiveTaskClaim<'_, B>>, crate::error::RuntimeError> {
1157        let claim = self
1158            .backend
1159            .claim_task(
1160                &available_task.instance_id,
1161                &available_task.task_id,
1162                &self.worker_id,
1163                self.claim_ttl
1164                    .and_then(|d| chrono::Duration::from_std(d).ok()),
1165            )
1166            .await?;
1167
1168        if claim.is_some() {
1169            tracing::debug!(
1170                instance_id = %available_task.instance_id,
1171                task_id = %available_task.task_id,
1172                "Claim successful"
1173            );
1174            Ok(Some(ActiveTaskClaim {
1175                backend: &self.backend,
1176                instance_id: available_task.instance_id.clone(),
1177                task_id: available_task.task_id.clone(),
1178                worker_id: self.worker_id.clone(),
1179            }))
1180        } else {
1181            tracing::debug!(
1182                instance_id = %available_task.instance_id,
1183                task_id = %available_task.task_id,
1184                "Task was already claimed by another worker"
1185            );
1186            Ok(None)
1187        }
1188    }
1189
1190    /// Check cancel/pause guards after claiming.
1191    ///
1192    /// Returns `Some(status)` if the workflow is cancelled or paused
1193    /// (caller should release claim and return status).
1194    /// Returns `None` if execution should proceed.
1195    async fn check_post_claim_guards(
1196        &self,
1197        available_task: &AvailableTask,
1198    ) -> Result<Option<WorkflowStatus>, crate::error::RuntimeError> {
1199        if self
1200            .backend
1201            .check_and_cancel(&available_task.instance_id, Some(&available_task.task_id))
1202            .await?
1203        {
1204            tracing::info!(
1205                instance_id = %available_task.instance_id,
1206                task_id = %available_task.task_id,
1207                "Workflow was cancelled, releasing claim"
1208            );
1209            return Ok(Some(
1210                self.load_cancelled_status(&available_task.instance_id)
1211                    .await,
1212            ));
1213        }
1214
1215        if self
1216            .backend
1217            .check_and_pause(&available_task.instance_id)
1218            .await?
1219        {
1220            tracing::info!(
1221                instance_id = %available_task.instance_id,
1222                task_id = %available_task.task_id,
1223                "Workflow was paused, releasing claim"
1224            );
1225            return Ok(Some(
1226                self.load_paused_status(&available_task.instance_id).await,
1227            ));
1228        }
1229
1230        Ok(None)
1231    }
1232
1233    /// Execute a future while periodically extending the task claim.
1234    ///
1235    /// If a `deadline` is provided, the heartbeat tick also checks whether the
1236    /// deadline has expired. If it has, the task future is dropped (active
1237    /// cancellation) and a `TaskTimedOut` error is returned.
1238    async fn run_with_heartbeat<F, T>(
1239        &self,
1240        claim: &ActiveTaskClaim<'_, B>,
1241        deadline: Option<&TaskDeadline>,
1242        future: F,
1243    ) -> Result<T, crate::error::RuntimeError>
1244    where
1245        F: std::future::Future<Output = T>,
1246    {
1247        let Some(ttl) = self.claim_ttl else {
1248            return Ok(future.await);
1249        };
1250        let Some(chrono_ttl) = chrono::Duration::from_std(ttl).ok() else {
1251            return Ok(future.await);
1252        };
1253
1254        let interval_duration = ttl / 2;
1255        let mut heartbeat_timer = time::interval(interval_duration);
1256        heartbeat_timer.tick().await; // skip first immediate tick
1257
1258        tokio::pin!(future);
1259
1260        loop {
1261            tokio::select! {
1262                result = &mut future => break Ok(result),
1263                _ = heartbeat_timer.tick() => {
1264                    // Check deadline during heartbeat
1265                    if let Some(dl) = deadline
1266                        && chrono::Utc::now() >= dl.deadline
1267                    {
1268                        tracing::warn!(
1269                            instance_id = %claim.instance_id,
1270                            task_id = %dl.task_id,
1271                            "Task deadline expired during heartbeat, cancelling"
1272                        );
1273                        return Err(WorkflowError::TaskTimedOut {
1274                            task_id: dl.task_id.clone(),
1275                            timeout: std::time::Duration::from_millis(dl.timeout_ms),
1276                        }
1277                        .into());
1278                    }
1279
1280                    tracing::trace!(
1281                        instance_id = %claim.instance_id,
1282                        task_id = %claim.task_id,
1283                        "Extending task claim via heartbeat"
1284                    );
1285                    if let Err(e) = self.backend
1286                        .extend_task_claim(
1287                            &claim.instance_id,
1288                            &claim.task_id,
1289                            &claim.worker_id,
1290                            chrono_ttl,
1291                        )
1292                        .await
1293                    {
1294                        tracing::warn!(
1295                            instance_id = %claim.instance_id,
1296                            task_id = %claim.task_id,
1297                            error = %e,
1298                            "Failed to extend task claim"
1299                        );
1300                    }
1301                }
1302            }
1303        }
1304    }
1305
1306    /// Persist task result and release the claim.
1307    async fn commit_task_result(
1308        &self,
1309        continuation: &WorkflowContinuation,
1310        available_task: &AvailableTask,
1311        snapshot: &mut WorkflowSnapshot,
1312        output: Bytes,
1313        claim: ActiveTaskClaim<'_, B>,
1314    ) -> Result<(), crate::error::RuntimeError> {
1315        snapshot.mark_task_completed(available_task.task_id.clone(), output);
1316        tracing::debug!(
1317            instance_id = %available_task.instance_id,
1318            task_id = %available_task.task_id,
1319            "Task completed"
1320        );
1321
1322        Self::update_position_after_task(continuation, &available_task.task_id, snapshot);
1323        self.backend.save_snapshot(snapshot).await?;
1324        claim.release().await?;
1325        Ok(())
1326    }
1327
1328    /// Determine workflow status after a task completes.
1329    ///
1330    /// Checks cancel/pause guards and workflow completion.
1331    async fn determine_post_task_status(
1332        &self,
1333        continuation: &WorkflowContinuation,
1334        available_task: &AvailableTask,
1335        snapshot: &mut WorkflowSnapshot,
1336        output: Bytes,
1337    ) -> Result<WorkflowStatus, crate::error::RuntimeError> {
1338        // Check for cancellation after task completion
1339        if self
1340            .backend
1341            .check_and_cancel(&available_task.instance_id, None)
1342            .await?
1343        {
1344            tracing::info!(
1345                instance_id = %available_task.instance_id,
1346                task_id = %available_task.task_id,
1347                "Workflow was cancelled after task completion"
1348            );
1349            return Ok(self
1350                .load_cancelled_status(&available_task.instance_id)
1351                .await);
1352        }
1353
1354        // Check for pause after task completion
1355        if self
1356            .backend
1357            .check_and_pause(&available_task.instance_id)
1358            .await?
1359        {
1360            tracing::info!(
1361                instance_id = %available_task.instance_id,
1362                task_id = %available_task.task_id,
1363                "Workflow was paused after task completion"
1364            );
1365            return Ok(self.load_paused_status(&available_task.instance_id).await);
1366        }
1367
1368        if Self::is_workflow_complete(continuation, snapshot) {
1369            tracing::info!(
1370                instance_id = %available_task.instance_id,
1371                task_id = %available_task.task_id,
1372                "Workflow complete"
1373            );
1374            snapshot.mark_completed(output);
1375            self.backend.save_snapshot(snapshot).await?;
1376            Ok(WorkflowStatus::Completed)
1377        } else {
1378            tracing::debug!(
1379                instance_id = %available_task.instance_id,
1380                task_id = %available_task.task_id,
1381                "Task completed, workflow continues"
1382            );
1383            Ok(WorkflowStatus::InProgress)
1384        }
1385    }
1386
1387    /// Find a task function in the workflow continuation and return a reference.
1388    ///
1389    /// Note: We can't clone `UntypedCoreTask`, so we need to execute it directly
1390    /// from the continuation structure. This method returns the task ID if found.
1391    fn find_task_id_in_continuation(continuation: &WorkflowContinuation, task_id: &str) -> bool {
1392        match continuation {
1393            WorkflowContinuation::Task { id, next, .. }
1394            | WorkflowContinuation::Delay { id, next, .. }
1395            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
1396                if id == task_id {
1397                    return true;
1398                }
1399                next.as_ref()
1400                    .is_some_and(|n| Self::find_task_id_in_continuation(n, task_id))
1401            }
1402            WorkflowContinuation::Fork { branches, join, .. } => {
1403                // Check branches
1404                for branch in branches {
1405                    if Self::find_task_id_in_continuation(branch, task_id) {
1406                        return true;
1407                    }
1408                }
1409                // Check join
1410                if let Some(join_cont) = join {
1411                    Self::find_task_id_in_continuation(join_cont, task_id)
1412                } else {
1413                    false
1414                }
1415            }
1416        }
1417    }
1418
1419    /// Execute a task by ID from the workflow continuation (iterative, no boxing).
1420    #[allow(clippy::manual_async_fn)]
1421    fn execute_task_by_id<'a>(
1422        continuation: &'a WorkflowContinuation,
1423        task_id: &'a str,
1424        input: Bytes,
1425    ) -> impl std::future::Future<Output = Result<Bytes, crate::error::RuntimeError>> + Send + 'a
1426    {
        async move {
            let mut current = continuation;

            loop {
                match current {
                    WorkflowContinuation::Task { id, func, next, .. } => {
                        if id == task_id {
                            let func = func
                                .as_ref()
                                .ok_or_else(|| WorkflowError::TaskNotImplemented(id.clone()))?;
                            return Ok(func.run(input).await?);
                        } else if let Some(next_cont) = next {
                            current = next_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                    WorkflowContinuation::Delay { next, .. }
                    | WorkflowContinuation::AwaitSignal { next, .. } => {
                        // Skip over delay/signal nodes when searching for a task
                        if let Some(next_cont) = next {
                            current = next_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                    WorkflowContinuation::Fork { branches, join, .. } => {
                        // Descend into the branch that contains the task, if any
                        let mut found_in_branch = false;
                        for branch in branches {
                            if Self::find_task_id_in_continuation(branch, task_id) {
                                current = branch;
                                found_in_branch = true;
                                break;
                            }
                        }
                        if found_in_branch {
                            continue;
                        }
                        // Not in any branch: fall through to the join
                        if let Some(join_cont) = join {
                            current = join_cont;
                        } else {
                            return Err(WorkflowError::TaskNotFound(task_id.to_string()).into());
                        }
                    }
                }
            }
        }
    }

    /// Update execution position after a task completes.
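    ///
    /// When the completed task has a `next` node, the snapshot position advances to
    /// that node's first task ID; when it is the last node in its chain, the position
    /// is left unchanged (overall completion is checked by `is_workflow_complete`).
    ///
    /// Illustrative sketch (not a doctest; the continuation and snapshot are
    /// hypothetical):
    ///
    /// ```ignore
    /// // Given Task("a") -> Task("b"), after "a" completes the position becomes
    /// // ExecutionPosition::AtTask { task_id: "b".to_string() }.
    /// Self::update_position_after_task(&chain, "a", &mut snapshot);
    /// ```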
    fn update_position_after_task(
        continuation: &WorkflowContinuation,
        completed_task_id: &str,
        snapshot: &mut WorkflowSnapshot,
    ) {
        match continuation {
            WorkflowContinuation::Task { id, next, .. }
            | WorkflowContinuation::Delay { id, next, .. }
            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
                if id == completed_task_id {
                    if let Some(next_cont) = next {
                        snapshot.update_position(ExecutionPosition::AtTask {
                            task_id: next_cont.first_task_id().to_string(),
                        });
                    }
                } else if let Some(next_cont) = next {
                    Self::update_position_after_task(next_cont, completed_task_id, snapshot);
                }
            }
            WorkflowContinuation::Fork { branches, join, .. } => {
                // Check if any branch task completed
                for branch in branches {
                    Self::update_position_after_task(branch, completed_task_id, snapshot);
                }
                // Check join
                if let Some(join_cont) = join {
                    Self::update_position_after_task(join_cont, completed_task_id, snapshot);
                }
            }
        }
    }

    /// Check if the workflow is complete based on the snapshot.
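    ///
    /// A node counts as complete when the snapshot holds a result for its ID; a
    /// `Fork` is complete only when every branch chain and the join (if any) are
    /// complete.
    ///
    /// Illustrative sketch (not a doctest; the continuation and snapshot are
    /// hypothetical):
    ///
    /// ```ignore
    /// // Given Task("a") -> Task("b"), this is true only once the snapshot
    /// // contains results for both "a" and "b".
    /// let done = Self::is_workflow_complete(&chain, &snapshot);
    /// ```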
    fn is_workflow_complete(
        continuation: &WorkflowContinuation,
        snapshot: &WorkflowSnapshot,
    ) -> bool {
        // Check if all tasks in the continuation are completed
        match continuation {
            WorkflowContinuation::Task { id, next, .. } => {
                if snapshot.get_task_result(id).is_none() {
                    return false;
                }
                if let Some(next_cont) = next {
                    Self::is_workflow_complete(next_cont, snapshot)
                } else {
                    true // Last task completed
                }
            }
            WorkflowContinuation::Delay { id, next, .. }
            | WorkflowContinuation::AwaitSignal { id, next, .. } => {
                if snapshot.get_task_result(id).is_none() {
                    return false;
                }
                next.as_ref()
                    .is_none_or(|n| Self::is_workflow_complete(n, snapshot))
            }
            WorkflowContinuation::Fork { branches, join, .. } => {
                // All branches must be completed (recursively check entire branch chain)
                for branch in branches {
                    if !Self::is_workflow_complete(branch, snapshot) {
                        return false;
                    }
                }
                // Join must be completed if it exists
                if let Some(join_cont) = join {
                    Self::is_workflow_complete(join_cont, snapshot)
                } else {
                    true
                }
            }
        }
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::serialization::JsonCodec;
    use sayiir_core::registry::TaskRegistry;
    use sayiir_core::snapshot::WorkflowSnapshot;
    use sayiir_persistence::{InMemoryBackend, SignalStore, SnapshotStore};

    type EmptyWorkflows = WorkflowRegistry<JsonCodec, (), ()>;

    fn make_worker() -> PooledWorker<InMemoryBackend> {
        let backend = InMemoryBackend::new();
        let registry = TaskRegistry::new();
        PooledWorker::new("test-worker", backend, registry)
    }

    #[tokio::test]
    async fn test_spawn_and_shutdown() {
        let worker = make_worker();
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        handle.shutdown();

        let result = tokio::time::timeout(Duration::from_secs(5), handle.join()).await;
        assert!(result.is_ok(), "Worker should exit cleanly after shutdown");
        assert!(result.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_handle_is_clone_and_send() {
        let worker = make_worker();
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        let handle2 = handle.clone();
        let remote = tokio::spawn(async move {
            handle2.shutdown();
        });
        remote.await.ok();

        let result = tokio::time::timeout(Duration::from_secs(5), handle.join()).await;
        assert!(result.is_ok_and(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_cancel_via_handle() {
        let backend = InMemoryBackend::new();
        let registry = TaskRegistry::new();

        // Create a workflow snapshot so store_signal can validate it
        let snapshot = WorkflowSnapshot::new("wf-1".to_string(), "hash-1".to_string());
        backend.save_snapshot(&snapshot).await.ok();

        let worker = PooledWorker::new("test-worker", backend, registry);
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        handle
            .cancel_workflow(
                "wf-1",
                Some("test reason".to_string()),
                Some("tester".to_string()),
            )
            .await
            .ok();

        // Verify the signal was stored
        let signal = handle
            .backend()
            .get_signal("wf-1", SignalKind::Cancel)
            .await;
        assert!(signal.is_ok_and(|s| s.is_some()));

        handle.shutdown();
        tokio::time::timeout(Duration::from_secs(5), handle.join())
            .await
            .ok();
    }

    #[tokio::test]
    async fn test_dropped_handle_shuts_down_worker() {
        let worker = make_worker();
        let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());

        // Extract the join handle before dropping so we can still await completion
        let join_handle = handle.inner.join_handle.lock().await.take().unwrap();
        drop(handle);

        let result = tokio::time::timeout(Duration::from_secs(5), join_handle)
            .await
            .ok()
            .and_then(Result::ok);
        assert!(
            result.is_some(),
            "Worker should exit when all handles are dropped"
        );
        assert!(result.is_some_and(|r| r.is_ok()));
    }
}