use std::collections::HashMap;
use bytes::Bytes;
use chrono;
use futures::FutureExt;
use sayiir_core::codec::sealed;
use sayiir_core::codec::{Codec, EnvelopeCodec, LoopDecision};
use sayiir_core::context::{TaskExecutionContext, with_task_context};
use sayiir_core::error::{BoxError, CodecError, WorkflowError};
use sayiir_core::registry::TaskRegistry;
use sayiir_core::snapshot::{
ExecutionPosition, SignalKind, SignalRequest, TaskDeadline, TaskHint, WorkflowSnapshot,
};
use sayiir_core::task_claim::AvailableTask;
use sayiir_core::workflow::{Workflow, WorkflowContinuation, WorkflowStatus};
use sayiir_persistence::{PersistentBackend, TaskClaimStore, TaskWakeupHint};
use std::num::NonZeroUsize;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time;
/// Builds the fallback polling interval for a worker, with a per-worker phase offset.
///
/// The first tick is delayed by `hash(worker_id) % poll_interval`, so different
/// worker ids generally start polling at different phases of the period.
/// Later ticks fire every `poll_interval`.
///
/// A zero `poll_interval` is raised to one millisecond, because
/// `tokio::time::interval_at` panics when given a zero period.
fn staggered_poll_interval(worker_id: &str, poll_interval: Duration) -> time::Interval {
    use std::hash::{Hash, Hasher};

    const MIN_POLL_INTERVAL: Duration = Duration::from_millis(1);

    // Only the zero case is adjusted; every non-zero period is used as given.
    let period = if poll_interval.is_zero() {
        MIN_POLL_INTERVAL
    } else {
        poll_interval
    };

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    worker_id.hash(&mut hasher);

    // `period` is non-zero, so `period_ns` is at least 1 and the modulo is safe.
    // Periods that do not fit in a `u64` nanosecond count saturate to `u64::MAX`.
    let period_ns = u64::try_from(period.as_nanos()).unwrap_or(u64::MAX);
    let offset = Duration::from_nanos(hasher.finish() % period_ns);
    time::interval_at(time::Instant::now() + offset, period)
}
/// Workflows handed to [`PooledWorker::spawn`], stored as
/// `(definition hash, workflow)` pairs and searched linearly by hash.
pub type WorkflowRegistry<C, Input, M> =
Vec<(sayiir_core::DefinitionHash, Arc<Workflow<C, Input, M>>)>;
/// Description of a workflow whose tasks are run through an
/// [`ExternalTaskExecutor`] rather than by the worker itself.
pub struct ExternalWorkflow {
/// Continuation tree of the workflow.
pub continuation: Arc<WorkflowContinuation>,
/// Index used to look up per-task metadata (name, timeout, retry policy) by task id.
pub task_index: Arc<sayiir_core::TaskIndex>,
/// Workflow identifier copied into each task's execution context.
pub workflow_id: Arc<str>,
/// Optional workflow metadata copied into each task's execution context
/// (presumably JSON text, going by the field name).
pub metadata_json: Option<Arc<str>>,
}
/// External workflows keyed by their definition hash; handed to
/// [`PooledWorker::spawn_with_executor`].
pub type WorkflowIndex = HashMap<sayiir_core::DefinitionHash, ExternalWorkflow>;
/// Membership test over a collection of workflows, by definition hash.
///
/// Implemented for both [`WorkflowIndex`] and [`WorkflowRegistry`] so that
/// wakeup-hint filtering can be shared between the two worker loops.
pub(crate) trait WorkflowLookup {
/// Returns `true` when a workflow with the given definition hash is present.
fn contains_definition_hash(&self, hash: &sayiir_core::DefinitionHash) -> bool;
}
/// Map-backed lookup: a definition hash is known when it is a key of the index.
impl WorkflowLookup for WorkflowIndex {
    fn contains_definition_hash(&self, hash: &sayiir_core::DefinitionHash) -> bool {
        self.get(hash).is_some()
    }
}
/// List-backed lookup: scans the `(hash, workflow)` pairs for a matching hash.
impl<C, Input, M> WorkflowLookup for WorkflowRegistry<C, Input, M> {
    fn contains_definition_hash(&self, hash: &sayiir_core::DefinitionHash) -> bool {
        for (known, _) in self {
            if known == hash {
                return true;
            }
        }
        false
    }
}
/// Shared callback that runs a task on behalf of the worker.
///
/// It is called with the task name and the task's input bytes, and returns a
/// boxed future resolving to the output bytes or a [`BoxError`].
pub type ExternalTaskExecutor = Arc<
dyn Fn(
&str,
Bytes,
) -> Pin<Box<dyn std::future::Future<Output = Result<Bytes, BoxError>> + Send>>
+ Send
+ Sync,
>;
/// Control messages sent from a [`WorkerHandle`] to the worker's actor loop.
enum WorkerCommand {
/// Ask the loop to stop and return.
Shutdown,
}
/// State shared by every clone of a [`WorkerHandle`].
struct WorkerHandleInner<B> {
/// Backend the spawned worker operates on.
backend: Arc<B>,
/// Sending half of the command channel read by the actor loop.
shutdown_tx: mpsc::Sender<WorkerCommand>,
/// Join handle of the spawned worker task; taken (left as `None`) by the
/// first `join` call.
join_handle:
tokio::sync::Mutex<Option<tokio::task::JoinHandle<Result<(), crate::error::RuntimeError>>>>,
}
/// Cloneable handle to a spawned worker: request shutdown, wait for it to
/// finish, or reach its backend.
pub struct WorkerHandle<B> {
/// Reference-counted shared state; cloning the handle clones this `Arc`.
inner: Arc<WorkerHandleInner<B>>,
}
/// Cloning is a reference-count bump; no bound on `B` is required.
impl<B> Clone for WorkerHandle<B> {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
impl<B> WorkerHandle<B> {
    /// Requests that the worker stop.
    ///
    /// The request is sent without waiting; a failed send (channel full or
    /// receiver gone) is ignored.
    pub fn shutdown(&self) {
        self.inner
            .shutdown_tx
            .try_send(WorkerCommand::Shutdown)
            .ok();
    }

    /// Waits for the worker task to finish and returns its result.
    ///
    /// The join handle is taken by the first caller; later calls (from this or
    /// any cloned handle) find no handle and return `Ok(())` immediately.
    ///
    /// # Errors
    ///
    /// Returns the worker loop's error, or the task's join error converted
    /// into a `RuntimeError`.
    pub async fn join(&self) -> Result<(), crate::error::RuntimeError> {
        // Release the lock before awaiting the worker task.
        let pending = {
            let mut slot = self.inner.join_handle.lock().await;
            slot.take()
        };
        let Some(task) = pending else {
            return Ok(());
        };
        task.await?
    }

    /// Returns the backend the worker was spawned with.
    #[must_use]
    pub fn backend(&self) -> &Arc<B> {
        let shared = &self.inner;
        &shared.backend
    }
}
/// A task claim currently held by this worker.
///
/// Carries everything needed to extend or release the claim on the backend.
struct ActiveTaskClaim<'a, B> {
/// Backend the claim was taken on.
backend: &'a B,
/// Workflow instance the claimed task belongs to.
instance_id: std::sync::Arc<str>,
/// Identifier of the claimed task.
task_id: sayiir_core::TaskId,
/// Identifier of the worker holding the claim.
worker_id: String,
}
impl<B: TaskClaimStore> ActiveTaskClaim<'_, B> {
    /// Releases the claim on the backend, consuming it.
    ///
    /// # Errors
    ///
    /// Returns the backend's error, converted into a `RuntimeError`.
    async fn release(self) -> Result<(), crate::error::RuntimeError> {
        let Self {
            backend,
            instance_id,
            task_id,
            worker_id,
        } = &self;
        backend
            .release_task_claim(instance_id, task_id, worker_id)
            .await?;
        Ok(())
    }

    /// Releases the claim and discards any error from the backend.
    async fn release_quietly(self) {
        self.release().await.ok();
    }
}
/// Result of running one task under the heartbeat loop.
enum ExecutionOutcome {
/// The task returned output bytes.
Success(Bytes),
/// The task (or input preparation for it) returned an error.
TaskError(crate::error::RuntimeError),
/// The task panicked; holds the payload captured by `catch_unwind`.
Panic(Box<dyn std::any::Any + Send>),
/// The task's deadline passed while it was still running.
Timeout(crate::error::RuntimeError),
}
/// Turns a panic payload into a printable message.
///
/// Payloads that are a `&str` or a `String` are returned as-is; any other
/// payload type yields a fixed placeholder message.
fn extract_panic_message(payload: &Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "Task panicked with unknown payload".to_string())
}
/// A worker that polls a backend for available workflow tasks, claims them,
/// runs them, and persists the outcome.
pub struct PooledWorker<B> {
/// Identifier used when finding, claiming, extending and releasing tasks.
worker_id: String,
/// Persistence and claim backend.
backend: Arc<B>,
// NOTE(review): stored at construction but not read by any code visible in
// this part of the file, hence the lint allowance — confirm it is still needed.
#[allow(unused)]
registry: Arc<TaskRegistry>,
/// Time-to-live requested for task claims; `None` requests claims with no expiry.
claim_ttl: Option<Duration>,
/// Maximum number of tasks requested from the backend per poll.
batch_size: NonZeroUsize,
/// Aging interval passed to the backend when searching for available tasks.
aging_interval: Duration,
/// Tags this worker serves; passed to the backend and used to filter wakeup hints.
tags: Vec<String>,
}
impl<B> PooledWorker<B>
where
B: PersistentBackend + TaskClaimStore + 'static,
{
/// Creates a worker with default settings: a five-minute claim TTL, a batch
/// size of one, a five-minute aging interval, and no tags.
pub fn new(worker_id: impl Into<String>, backend: B, registry: TaskRegistry) -> Self {
    let default_claim_ttl = Duration::from_mins(5);
    let default_aging_interval = Duration::from_mins(5);
    Self {
        worker_id: worker_id.into(),
        backend: Arc::new(backend),
        registry: Arc::new(registry),
        claim_ttl: Some(default_claim_ttl),
        batch_size: NonZeroUsize::MIN,
        aging_interval: default_aging_interval,
        tags: Vec::new(),
    }
}
/// Sets the time-to-live for task claims.
///
/// Passing `None` disables claim expiry and emits a warning.
#[must_use]
pub fn with_claim_ttl(self, ttl: Option<Duration>) -> Self {
    let mut worker = self;
    worker.claim_ttl = ttl;
    if worker.claim_ttl.is_none() {
        tracing::warn!(
            "PooledWorker::with_claim_ttl(None) disables claim expiry; \
             a crashed worker will pin its workflow until manual release"
        );
    }
    worker
}
/// Sets the aging interval passed to the backend when looking for tasks.
///
/// # Panics
///
/// Panics if `interval` is zero.
#[must_use]
pub fn with_aging_interval(self, interval: Duration) -> Self {
    assert!(!interval.is_zero(), "aging interval must be non-zero");
    let mut worker = self;
    worker.aging_interval = interval;
    worker
}
/// Sets how many tasks are requested from the backend per poll.
#[must_use]
pub fn with_batch_size(self, size: NonZeroUsize) -> Self {
    let mut worker = self;
    worker.batch_size = size;
    worker
}
/// Replaces the set of tags this worker serves.
#[must_use]
pub fn with_tags(self, tags: Vec<String>) -> Self {
    let mut worker = self;
    worker.tags = tags;
    worker
}
/// Stores a cancel signal for the given workflow instance.
///
/// # Errors
///
/// Returns the backend's error if the signal cannot be stored.
pub async fn cancel_workflow(
    &self,
    instance_id: &str,
    reason: Option<String>,
    cancelled_by: Option<String>,
) -> Result<(), crate::error::RuntimeError> {
    let request = SignalRequest::new(reason, cancelled_by);
    self.backend
        .store_signal(instance_id, SignalKind::Cancel, request)
        .await?;
    Ok(())
}
/// Stores a pause signal for the given workflow instance.
///
/// # Errors
///
/// Returns the backend's error if the signal cannot be stored.
pub async fn pause_workflow(
    &self,
    instance_id: &str,
    reason: Option<String>,
    paused_by: Option<String>,
) -> Result<(), crate::error::RuntimeError> {
    let request = SignalRequest::new(reason, paused_by);
    self.backend
        .store_signal(instance_id, SignalKind::Pause, request)
        .await?;
    Ok(())
}
#[must_use]
pub fn backend(&self) -> &Arc<B> {
&self.backend
}
/// Moves the worker onto a new tokio task running the in-process actor loop
/// and returns a handle to it.
///
/// `poll_interval` is the fallback polling period; `workflows` lists the
/// workflows this worker is able to execute.
#[must_use]
pub fn spawn<C, Input, M>(
    self,
    poll_interval: Duration,
    workflows: WorkflowRegistry<C, Input, M>,
) -> WorkerHandle<B>
where
    Input: Send + Sync + 'static,
    M: Send + Sync + 'static,
    C: Codec
        + EnvelopeCodec
        + sealed::DecodeValue<Input>
        + sealed::EncodeValue<Input>
        + 'static,
{
    let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
    // Clone the backend before `self` moves into the spawned task.
    let backend = Arc::clone(&self.backend);
    let worker_task = tokio::spawn(async move {
        self.run_actor_loop(poll_interval, workflows, shutdown_rx)
            .await
    });
    let inner = WorkerHandleInner {
        backend,
        shutdown_tx,
        join_handle: tokio::sync::Mutex::new(Some(worker_task)),
    };
    WorkerHandle {
        inner: Arc::new(inner),
    }
}
/// Moves the worker onto a new tokio task running the external-executor actor
/// loop and returns a handle to it.
///
/// Tasks are run by calling `executor` with the task name and input bytes.
#[must_use]
pub fn spawn_with_executor(
    self,
    poll_interval: Duration,
    workflows: WorkflowIndex,
    executor: ExternalTaskExecutor,
) -> WorkerHandle<B> {
    let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
    // Clone the backend before `self` moves into the spawned task.
    let backend = Arc::clone(&self.backend);
    let worker_task = tokio::spawn(async move {
        self.run_external_actor_loop(poll_interval, workflows, executor, shutdown_rx)
            .await
    });
    let inner = WorkerHandleInner {
        backend,
        shutdown_tx,
        join_handle: tokio::sync::Mutex::new(Some(worker_task)),
    };
    WorkerHandle {
        inner: Arc::new(inner),
    }
}
/// Main loop of a worker that runs tasks through an external executor.
///
/// Each iteration waits for a shutdown command, a fallback poll tick, or a
/// backend wakeup. A wakeup hint this worker can handle is tried directly
/// first; otherwise (or if the hinted task was not runnable) a batch of
/// available tasks is fetched and executed one by one.
///
/// Returns `Ok(())` on shutdown and also when a task times out (the timeout is
/// logged and the worker stops). Backend errors from waiting for a wakeup or
/// from fetching tasks are returned as `Err`.
async fn run_external_actor_loop(
&self,
poll_interval: Duration,
workflows: WorkflowIndex,
executor: ExternalTaskExecutor,
mut cmd_rx: mpsc::Receiver<WorkerCommand>,
) -> Result<(), crate::error::RuntimeError> {
let mut interval = staggered_poll_interval(&self.worker_id, poll_interval);
loop {
let hint = tokio::select! {
// `biased` makes the branches be polled in the order written, so a
// pending shutdown command is seen before a tick or a wakeup.
biased;
cmd = cmd_rx.recv() => {
match cmd {
// An explicit shutdown and a closed channel (all handles dropped)
// are treated the same way.
Some(WorkerCommand::Shutdown) | None => {
tracing::info!(worker_id = %self.worker_id, "Worker shutting down");
return Ok(());
}
}
}
_ = interval.tick() => {
tracing::trace!(worker_id = %self.worker_id, "fallback poll tick");
None
}
hint = self.backend.wait_for_wakeup(poll_interval) => {
let hint = hint?;
tracing::debug!(
worker_id = %self.worker_id,
has_hint = hint.is_some(),
"wakeup notification",
);
hint
}
};
// A hint for a workflow or tag set this worker does not serve is dropped
// without polling the backend.
if let Some(h) = hint.as_ref()
&& !self.can_handle_hint(h, &workflows)
{
tracing::trace!(
worker_id = %self.worker_id,
instance_id = %h.instance_id,
"skipping wakeup, hint not handleable here",
);
continue;
}
// Fast path: try the hinted task directly before doing a batch poll.
if let Some(h) = hint.as_ref() {
match self.try_hinted_execute(h, &workflows, &executor).await {
Ok(true) => continue,
// Nothing was run for the hint; fall through to the batch poll.
Ok(false) => {}
Err(ref e) if e.is_timeout() => {
tracing::error!(worker_id = %self.worker_id, error = %e, "Task timed out — worker shutting down");
return Ok(());
}
Err(e) => return Err(e),
}
}
let available_tasks = self
.backend
.find_available_tasks(
&self.worker_id,
self.batch_size.get(),
// An aging interval too large for chrono falls back to the maximum.
chrono::Duration::from_std(self.aging_interval)
.unwrap_or(chrono::Duration::MAX),
&self.tags,
)
.await?;
for task in available_tasks {
// Check for shutdown between tasks so a long batch can be interrupted.
if let Ok(WorkerCommand::Shutdown) | Err(mpsc::error::TryRecvError::Disconnected) =
cmd_rx.try_recv()
{
tracing::info!(worker_id = %self.worker_id, "Worker shutting down mid-batch");
return Ok(());
}
// Tasks whose workflow is not in `workflows` are skipped silently.
if let Some(ext_wf) = workflows.get(&task.workflow_definition_hash) {
match self
.execute_external_task(
ext_wf,
&task.workflow_definition_hash,
&executor,
&task,
)
.await
{
Err(ref e) if e.is_timeout() => {
tracing::error!(
worker_id = %self.worker_id,
error = %e,
"Task timed out — worker shutting down"
);
return Ok(());
}
Ok(_) => {
tracing::info!(worker_id = %self.worker_id, "completed task");
}
// Non-timeout failures are logged and the batch continues.
Err(e) => {
tracing::error!(
worker_id = %self.worker_id,
error = %e,
"task execution failed"
);
}
}
}
}
}
}
/// Decides whether a wakeup hint is relevant to this worker.
///
/// The hint's workflow definition must be present in `workflows`, and every
/// tag on the hint must be one of this worker's tags.
fn can_handle_hint(&self, hint: &TaskWakeupHint, workflows: &impl WorkflowLookup) -> bool {
    let hash = sayiir_core::DefinitionHash::from_bytes(hint.definition_hash);
    workflows.contains_definition_hash(&hash)
        && hint.tags.iter().all(|tag| self.tags.contains(tag))
}
/// Tries to run the task a wakeup hint points at.
///
/// Returns `Ok(false)` when the backend has no task for the hint or the
/// task's workflow is not in `workflows`. Returns `Ok(true)` once the task has
/// been attempted, whether it succeeded or failed with a non-timeout error
/// (which is logged).
///
/// # Errors
///
/// Returns backend errors from the hinted lookup, and timeout errors from the
/// task execution.
async fn try_hinted_execute(
    &self,
    hint: &TaskWakeupHint,
    workflows: &WorkflowIndex,
    executor: &ExternalTaskExecutor,
) -> Result<bool, crate::error::RuntimeError> {
    let Some(task) = self.backend.find_hinted_task(hint).await? else {
        return Ok(false);
    };
    let definition_hash = &task.workflow_definition_hash;
    let Some(ext_wf) = workflows.get(definition_hash) else {
        return Ok(false);
    };

    let outcome = self
        .execute_external_task(ext_wf, definition_hash, executor, &task)
        .await;
    match outcome {
        Ok(_) => tracing::info!(worker_id = %self.worker_id, "completed hinted task"),
        // Timeouts propagate so the caller can stop the worker.
        Err(e) if e.is_timeout() => return Err(e),
        Err(e) => {
            tracing::error!(worker_id = %self.worker_id, error = %e, "hinted task execution failed");
        }
    }
    Ok(true)
}
/// Claims and runs one task through the external executor, then settles the
/// result.
///
/// Returns `Ok(WorkflowStatus::InProgress)` without running anything when the
/// task already has a result or the claim could not be obtained. If the
/// workflow turns out to be cancelled or paused right after claiming, the
/// claim is released and that status is returned.
///
/// # Errors
///
/// Returns precondition failures (definition mismatch, unknown task), backend
/// errors, and task failures as reported by the settle step.
#[tracing::instrument(
name = "workflow",
skip_all,
fields(
worker_id = %self.worker_id,
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
definition_hash = %definition_hash,
),
)]
async fn execute_external_task(
&self,
ext_wf: &ExternalWorkflow,
definition_hash: &sayiir_core::DefinitionHash,
executor: &ExternalTaskExecutor,
available_task: &AvailableTask,
) -> Result<WorkflowStatus, crate::error::RuntimeError> {
// With the `otel` feature, attach this span to the trace context carried
// by the task, if it has one.
#[cfg(feature = "otel")]
if let Some(ref tp) = available_task.trace_parent {
use tracing_opentelemetry::OpenTelemetrySpanExt;
let remote_ctx = crate::trace_context::context_from_trace_parent(tp);
let _ = tracing::Span::current().set_parent(remote_ctx);
}
let already_completed = Self::validate_task_preconditions(
definition_hash,
&ext_wf.task_index,
available_task,
&available_task.snapshot,
)?;
if already_completed {
return Ok(WorkflowStatus::InProgress);
}
// `None` means another worker holds the claim; nothing to do here.
let Some(claim) = self.claim_task(available_task).await? else {
return Ok(WorkflowStatus::InProgress);
};
if let Some(status) = self.check_post_claim_guards(available_task).await? {
claim.release_quietly().await;
return Ok(status);
}
// Work on a private copy of the snapshot delivered with the task.
let mut snapshot = (*available_task.snapshot).clone();
tracing::debug!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
"Executing task (external)"
);
let execution_result = self
.execute_with_deadline_ext(ext_wf, executor, available_task, &mut snapshot, &claim)
.await;
// The settle step takes ownership of the claim and releases it.
self.settle_execution_result_ext(
execution_result,
&ext_wf.continuation,
&ext_wf.task_index,
available_task,
&mut snapshot,
claim,
)
.await
}
/// Runs one task through the external executor under the heartbeat loop and
/// classifies the result.
///
/// If the task's indexed metadata has a timeout, a deadline is set on
/// `snapshot` before running and cleared afterwards; within this function the
/// snapshot is only changed in memory. Panics inside the executor future are
/// caught and reported as `ExecutionOutcome::Panic`.
async fn execute_with_deadline_ext(
&self,
ext_wf: &ExternalWorkflow,
executor: &ExternalTaskExecutor,
available_task: &AvailableTask,
snapshot: &mut WorkflowSnapshot,
claim: &ActiveTaskClaim<'_, B>,
) -> ExecutionOutcome {
let task_id = available_task.task_id;
let input = available_task.input.clone();
let indexed_meta = ext_wf.task_index.get(&task_id);
// Use the indexed task name when available, else the task id in hex.
let task_name: Arc<str> =
indexed_meta.map_or_else(|| Arc::from(task_id.to_hex()), |m| Arc::clone(m.name()));
let deadline =
if let Some(timeout) = indexed_meta.and_then(sayiir_core::TaskNodeMetadata::timeout) {
snapshot.set_task_deadline(task_id, timeout);
snapshot.refresh_task_deadline();
snapshot.task_deadline.clone()
} else {
None
};
let task_ctx = TaskExecutionContext {
workflow_id: Arc::clone(&ext_wf.workflow_id),
instance_id: Arc::clone(&available_task.instance_id),
task_id: Arc::clone(&task_name),
metadata: ext_wf.task_index.build_task_metadata(&task_id),
workflow_metadata_json: ext_wf.metadata_json.clone(),
};
let execution_future = with_task_context(task_ctx, executor(&task_name, input));
let heartbeat_result = self
.run_with_heartbeat(
claim,
deadline.as_ref(),
AssertUnwindSafe(execution_future).catch_unwind(),
)
.await;
snapshot.clear_task_deadline();
// Result layers, outermost first: heartbeat timeout, caught panic, task result.
match heartbeat_result {
Err(timeout_err) => ExecutionOutcome::Timeout(timeout_err),
Ok(Err(panic_payload)) => ExecutionOutcome::Panic(panic_payload),
Ok(Ok(Err(e))) => ExecutionOutcome::TaskError(e.into()),
Ok(Ok(Ok(output))) => ExecutionOutcome::Success(output),
}
}
/// Applies the outcome of an externally executed task: schedules a retry,
/// records the failure, or commits the result. The claim is released on every
/// path.
///
/// - `Timeout`: if a retry is scheduled, returns its status; otherwise marks
///   the snapshot failed, saves it (save errors ignored), and returns the
///   timeout error.
/// - `Panic` / `TaskError`: if a retry is scheduled, returns its status;
///   otherwise returns the error. The snapshot is not marked failed on these
///   two paths.
/// - `Success`: clears retry state, commits the result, and returns the
///   post-task workflow status.
#[tracing::instrument(
name = "settle_result",
skip_all,
fields(worker_id = %self.worker_id, instance_id = %available_task.instance_id, task_id = %available_task.task_id),
)]
async fn settle_execution_result_ext(
&self,
outcome: ExecutionOutcome,
continuation: &WorkflowContinuation,
task_index: &sayiir_core::TaskIndex,
available_task: &AvailableTask,
snapshot: &mut WorkflowSnapshot,
claim: ActiveTaskClaim<'_, B>,
) -> Result<WorkflowStatus, crate::error::RuntimeError> {
tracing::debug!("settling execution result");
match outcome {
ExecutionOutcome::Timeout(err) => {
if let Ok(Some(status)) = self
.try_schedule_retry(task_index, available_task, snapshot, &err.to_string())
.await
{
claim.release_quietly().await;
return Ok(status);
}
tracing::warn!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
error = %err,
"Task timed out via heartbeat — marking workflow failed, shutting down"
);
snapshot.mark_failed(err.to_string());
// Best-effort save: the timeout error is returned either way.
let _ = self.backend.save_snapshot(snapshot).await;
claim.release_quietly().await;
Err(err)
}
ExecutionOutcome::Panic(panic_payload) => {
let panic_msg = extract_panic_message(&panic_payload);
if let Ok(Some(status)) = self
.try_schedule_retry(task_index, available_task, snapshot, &panic_msg)
.await
{
claim.release_quietly().await;
return Ok(status);
}
tracing::error!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
panic = %panic_msg,
"Task panicked - releasing claim"
);
claim.release_quietly().await;
Err(WorkflowError::TaskPanicked(panic_msg).into())
}
ExecutionOutcome::TaskError(e) => {
if let Ok(Some(status)) = self
.try_schedule_retry(task_index, available_task, snapshot, &e.to_string())
.await
{
claim.release_quietly().await;
return Ok(status);
}
tracing::error!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
error = %e,
"Task execution failed"
);
claim.release_quietly().await;
Err(e)
}
ExecutionOutcome::Success(output) => {
snapshot.clear_retry_state(&available_task.task_id);
// `commit_task_result` consumes the claim and releases it after saving.
self.commit_task_result(
continuation,
available_task,
snapshot,
output.clone(),
claim,
)
.await?;
self.determine_post_task_status(continuation, available_task, snapshot, output)
.await
}
}
}
/// Main loop of a worker that runs tasks in-process from a workflow registry.
///
/// Each iteration waits for a shutdown command, a fallback poll tick, or a
/// backend wakeup, then fetches a batch of available tasks and executes them
/// one by one. A wakeup hint this worker cannot handle skips the poll.
///
/// Returns `Ok(())` on shutdown and also when a task times out (the timeout is
/// logged and the worker stops). Backend errors from waiting for a wakeup or
/// from fetching tasks are returned as `Err`.
async fn run_actor_loop<C, Input, M>(
&self,
poll_interval: Duration,
workflows: WorkflowRegistry<C, Input, M>,
mut cmd_rx: mpsc::Receiver<WorkerCommand>,
) -> Result<(), crate::error::RuntimeError>
where
Input: Send + 'static,
M: Send + Sync + 'static,
C: Codec
+ EnvelopeCodec
+ sealed::DecodeValue<Input>
+ sealed::EncodeValue<Input>
+ 'static,
{
let mut interval = staggered_poll_interval(&self.worker_id, poll_interval);
loop {
let hint = tokio::select! {
// `biased` makes the branches be polled in the order written, so a
// pending shutdown command is seen before a tick or a wakeup.
biased;
cmd = cmd_rx.recv() => {
match cmd {
// An explicit shutdown and a closed channel (all handles dropped)
// are treated the same way.
Some(WorkerCommand::Shutdown) | None => {
tracing::info!(worker_id = %self.worker_id, "Worker shutting down");
return Ok(());
}
}
}
_ = interval.tick() => {
tracing::trace!(worker_id = %self.worker_id, "fallback poll tick");
None
}
hint = self.backend.wait_for_wakeup(poll_interval) => {
let hint = hint?;
tracing::debug!(
worker_id = %self.worker_id,
has_hint = hint.is_some(),
"wakeup notification",
);
hint
}
};
// The hint is used only as a filter here; a handleable hint still leads to
// the ordinary batch poll below.
if let Some(h) = hint.as_ref()
&& !self.can_handle_hint(h, &workflows)
{
tracing::trace!(
worker_id = %self.worker_id,
instance_id = %h.instance_id,
"skipping wakeup, hint not handleable here",
);
continue;
}
let available_tasks = self
.backend
.find_available_tasks(
&self.worker_id,
self.batch_size.get(),
// An aging interval too large for chrono falls back to the maximum.
chrono::Duration::from_std(self.aging_interval)
.unwrap_or(chrono::Duration::MAX),
&self.tags,
)
.await?;
for task in available_tasks {
// Check for shutdown between tasks so a long batch can be interrupted.
if let Ok(WorkerCommand::Shutdown) | Err(mpsc::error::TryRecvError::Disconnected) =
cmd_rx.try_recv()
{
tracing::info!(worker_id = %self.worker_id, "Worker shutting down mid-batch");
return Ok(());
}
// Tasks whose workflow is not registered here are skipped silently.
if let Some((_, workflow)) = workflows
.iter()
.find(|(hash, _)| *hash == task.workflow_definition_hash)
{
match self.execute_task(workflow.as_ref(), task).await {
Err(ref e) if e.is_timeout() => {
tracing::error!(
worker_id = %self.worker_id,
error = %e,
"Task timed out — worker shutting down"
);
return Ok(());
}
Ok(_) => {
tracing::info!(worker_id = %self.worker_id, "completed task");
}
// Non-timeout failures are logged and the batch continues.
Err(e) => {
tracing::error!(
worker_id = %self.worker_id,
error = %e,
"task execution failed"
);
}
}
}
}
}
}
/// Builds a `Cancelled` status for the instance, filled in from the stored
/// snapshot when possible.
///
/// If the snapshot cannot be loaded or carries no cancellation details, the
/// status has neither a reason nor an actor.
async fn load_cancelled_status(&self, instance_id: &str) -> WorkflowStatus {
    let loaded = self.backend.load_snapshot(instance_id).await;
    match loaded {
        Ok(snapshot) => match snapshot.state.cancellation_details() {
            Some((reason, cancelled_by)) => WorkflowStatus::Cancelled {
                reason,
                cancelled_by,
            },
            None => WorkflowStatus::Cancelled {
                reason: None,
                cancelled_by: None,
            },
        },
        Err(_) => WorkflowStatus::Cancelled {
            reason: None,
            cancelled_by: None,
        },
    }
}
/// Builds a `Paused` status for the instance, filled in from the stored
/// snapshot when possible.
///
/// If the snapshot cannot be loaded or carries no pause details, the status
/// has neither a reason nor an actor.
async fn load_paused_status(&self, instance_id: &str) -> WorkflowStatus {
    let loaded = self.backend.load_snapshot(instance_id).await;
    match loaded {
        Ok(snapshot) => match snapshot.state.pause_details() {
            Some((reason, paused_by)) => WorkflowStatus::Paused { reason, paused_by },
            None => WorkflowStatus::Paused {
                reason: None,
                paused_by: None,
            },
        },
        Err(_) => WorkflowStatus::Paused {
            reason: None,
            paused_by: None,
        },
    }
}
/// Claims and runs one task of an in-process workflow, then settles the
/// result.
///
/// Returns `Ok(WorkflowStatus::InProgress)` without running anything when the
/// task already has a result or the claim could not be obtained. If the
/// workflow turns out to be cancelled or paused right after claiming, the
/// claim is released and that status is returned.
///
/// # Errors
///
/// Returns precondition failures (definition mismatch, unknown task), backend
/// errors, and task failures as reported by the settle step.
#[tracing::instrument(
name = "workflow",
skip_all,
fields(
worker_id = %self.worker_id,
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
definition_hash = %available_task.workflow_definition_hash,
),
)]
pub async fn execute_task<C, Input, M>(
&self,
workflow: &Workflow<C, Input, M>,
available_task: AvailableTask,
) -> Result<WorkflowStatus, crate::error::RuntimeError>
where
Input: Send + 'static,
M: Send + Sync + 'static,
C: Codec
+ EnvelopeCodec
+ sealed::DecodeValue<Input>
+ sealed::EncodeValue<Input>
+ 'static,
{
// With the `otel` feature, attach this span to the trace context carried
// by the task, if it has one.
#[cfg(feature = "otel")]
if let Some(ref tp) = available_task.trace_parent {
use tracing_opentelemetry::OpenTelemetrySpanExt;
let remote_ctx = crate::trace_context::context_from_trace_parent(tp);
let _ = tracing::Span::current().set_parent(remote_ctx);
}
let already_completed = Self::validate_task_preconditions(
workflow.definition_hash(),
workflow.task_index(),
&available_task,
&available_task.snapshot,
)?;
if already_completed {
return Ok(WorkflowStatus::InProgress);
}
// `None` means another worker holds the claim; nothing to do here.
let Some(claim) = self.claim_task(&available_task).await? else {
return Ok(WorkflowStatus::InProgress);
};
if let Some(status) = self.check_post_claim_guards(&available_task).await? {
claim.release_quietly().await;
return Ok(status);
}
// Work on a private copy of the snapshot delivered with the task.
let mut snapshot = (*available_task.snapshot).clone();
tracing::debug!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
"Executing task"
);
let execution_result = self
.execute_with_deadline(workflow, &available_task, &mut snapshot, &claim)
.await;
// The settle step takes ownership of the claim and releases it.
self.settle_execution_result(
execution_result,
workflow,
&available_task,
&mut snapshot,
claim,
)
.await
}
/// Runs one in-process task under the heartbeat loop and classifies the
/// result.
///
/// When the task is the first task of a fork's join, its input is assembled
/// from the stored results of the fork's branches; otherwise the input that
/// came with the task is used. If the task's indexed metadata has a timeout, a
/// deadline is set on `snapshot` before running and cleared afterwards; within
/// this function the snapshot is only changed in memory. Panics inside the
/// task future are caught and reported as `ExecutionOutcome::Panic`.
async fn execute_with_deadline<C, Input, M>(
&self,
workflow: &Workflow<C, Input, M>,
available_task: &AvailableTask,
snapshot: &mut WorkflowSnapshot,
claim: &ActiveTaskClaim<'_, B>,
) -> ExecutionOutcome
where
Input: Send + 'static,
M: Send + Sync + 'static,
C: Codec
+ EnvelopeCodec
+ sealed::DecodeValue<Input>
+ sealed::EncodeValue<Input>
+ 'static,
{
let continuation = workflow.continuation();
let task_index = workflow.task_index();
let task_id = available_task.task_id;
let input = match Self::find_fork_branches_for_join(continuation, &task_id) {
Some(branches) => {
// Collect each branch's terminal-task output and serialize the set
// as the join's input. A missing branch result is a task error.
let build = || -> Result<Bytes, crate::error::RuntimeError> {
let mut results = Vec::with_capacity(branches.len());
for branch in branches {
let branch_name = branch.id().to_string();
let terminal_tid = sayiir_core::TaskId::from(branch.terminal_task_id());
let output = snapshot
.get_task_result_bytes(&terminal_tid)
.ok_or_else(|| WorkflowError::TaskNotFound(branch_name.clone()))?;
results.push((branch_name, output));
}
crate::execution::serialize_branch_results(&results, workflow.codec().as_ref())
};
match build() {
Ok(bytes) => bytes,
Err(e) => return ExecutionOutcome::TaskError(e),
}
}
None => available_task.input.clone(),
};
let indexed_meta = task_index.get(&task_id);
// Use the indexed task name when available, else the task id in hex.
let task_name: Arc<str> =
indexed_meta.map_or_else(|| Arc::from(task_id.to_hex()), |m| Arc::clone(m.name()));
let deadline =
if let Some(timeout) = indexed_meta.and_then(sayiir_core::TaskNodeMetadata::timeout) {
snapshot.set_task_deadline(task_id, timeout);
snapshot.refresh_task_deadline();
snapshot.task_deadline.clone()
} else {
None
};
let task_ctx = TaskExecutionContext {
workflow_id: Arc::from(workflow.context().workflow_id()),
instance_id: Arc::clone(&available_task.instance_id),
task_id: Arc::clone(&task_name),
metadata: task_index.build_task_metadata(&task_id),
workflow_metadata_json: workflow.context().metadata_json.clone(),
};
let execution_future = with_task_context(task_ctx, async move {
Self::execute_task_by_id(continuation, &task_name, input).await
});
let heartbeat_result = self
.run_with_heartbeat(
claim,
deadline.as_ref(),
AssertUnwindSafe(execution_future).catch_unwind(),
)
.await;
snapshot.clear_task_deadline();
// Result layers, outermost first: heartbeat timeout, caught panic, task result.
match heartbeat_result {
Err(timeout_err) => ExecutionOutcome::Timeout(timeout_err),
Ok(Err(panic_payload)) => ExecutionOutcome::Panic(panic_payload),
Ok(Ok(Err(e))) => ExecutionOutcome::TaskError(e),
Ok(Ok(Ok(output))) => ExecutionOutcome::Success(output),
}
}
async fn try_schedule_retry(
&self,
task_index: &sayiir_core::TaskIndex,
available_task: &AvailableTask,
snapshot: &mut WorkflowSnapshot,
error_msg: &str,
) -> Result<Option<WorkflowStatus>, crate::error::RuntimeError> {
let Some(policy) = task_index.retry_policy(&available_task.task_id) else {
return Ok(None);
};
if snapshot.retries_exhausted(&available_task.task_id) {
return Ok(None);
}
let next_retry_at = snapshot.record_retry(
available_task.task_id,
policy,
error_msg,
Some(&self.worker_id),
);
snapshot.clear_task_deadline();
let _ = self.backend.save_snapshot(snapshot).await;
tracing::info!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
attempt = snapshot.get_retry_state(&available_task.task_id).map_or(0, |rs| rs.attempts),
max_retries = policy.max_retries,
%next_retry_at,
"Scheduling retry"
);
Ok(Some(WorkflowStatus::InProgress))
}
/// Applies the outcome of an in-process task: schedules a retry, records the
/// failure, or commits the result. The claim is released on every path.
///
/// - `Timeout`: if a retry is scheduled, returns its status; otherwise marks
///   the snapshot failed, saves it (save errors ignored), and returns the
///   timeout error.
/// - `Panic` / `TaskError`: if a retry is scheduled, returns its status;
///   otherwise returns the error. The snapshot is not marked failed on these
///   two paths.
/// - `Success`: clears retry state, commits the result, resolves loop
///   completions, and returns the post-task workflow status.
#[tracing::instrument(
name = "settle_result",
skip_all,
fields(worker_id = %self.worker_id, instance_id = %available_task.instance_id, task_id = %available_task.task_id),
)]
async fn settle_execution_result<C, Input, M>(
&self,
outcome: ExecutionOutcome,
workflow: &Workflow<C, Input, M>,
available_task: &AvailableTask,
snapshot: &mut WorkflowSnapshot,
claim: ActiveTaskClaim<'_, B>,
) -> Result<WorkflowStatus, crate::error::RuntimeError>
where
Input: Send + 'static,
M: Send + Sync + 'static,
C: Codec
+ EnvelopeCodec
+ sealed::DecodeValue<Input>
+ sealed::EncodeValue<Input>
+ 'static,
{
tracing::debug!("settling execution result");
match outcome {
ExecutionOutcome::Timeout(err) => {
if let Ok(Some(status)) = self
.try_schedule_retry(
workflow.task_index(),
available_task,
snapshot,
&err.to_string(),
)
.await
{
claim.release_quietly().await;
return Ok(status);
}
tracing::warn!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
error = %err,
"Task timed out via heartbeat — marking workflow failed, shutting down"
);
snapshot.mark_failed(err.to_string());
// Best-effort save: the timeout error is returned either way.
let _ = self.backend.save_snapshot(snapshot).await;
claim.release_quietly().await;
Err(err)
}
ExecutionOutcome::Panic(panic_payload) => {
let panic_msg = extract_panic_message(&panic_payload);
if let Ok(Some(status)) = self
.try_schedule_retry(workflow.task_index(), available_task, snapshot, &panic_msg)
.await
{
claim.release_quietly().await;
return Ok(status);
}
tracing::error!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
panic = %panic_msg,
"Task panicked - releasing claim"
);
claim.release_quietly().await;
Err(WorkflowError::TaskPanicked(panic_msg).into())
}
ExecutionOutcome::TaskError(e) => {
if let Ok(Some(status)) = self
.try_schedule_retry(
workflow.task_index(),
available_task,
snapshot,
&e.to_string(),
)
.await
{
claim.release_quietly().await;
return Ok(status);
}
tracing::error!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
error = %e,
"Task execution failed"
);
claim.release_quietly().await;
Err(e)
}
ExecutionOutcome::Success(output) => {
snapshot.clear_retry_state(&available_task.task_id);
// `commit_task_result` consumes the claim and releases it after saving.
self.commit_task_result(
workflow.continuation(),
available_task,
snapshot,
output.clone(),
claim,
)
.await?;
// NOTE(review): the external-executor settle path has no matching
// `resolve_loop_completions` call — confirm that difference is intended.
Self::resolve_loop_completions(
workflow.continuation(),
snapshot,
self.backend.as_ref(),
)
.await?;
self.determine_post_task_status(
workflow.continuation(),
available_task,
snapshot,
output,
)
.await
}
}
}
/// Checks that a task can be executed against the given workflow definition.
///
/// Returns `Ok(true)` when the snapshot already holds a result for the task
/// (nothing to do), and `Ok(false)` when the task should be executed.
///
/// # Errors
///
/// Returns `DefinitionMismatch` when the task was issued for a different
/// workflow definition, and `TaskNotFound` when the task id is not in
/// `task_index`.
fn validate_task_preconditions(
    definition_hash: &sayiir_core::DefinitionHash,
    task_index: &sayiir_core::TaskIndex,
    available_task: &AvailableTask,
    snapshot: &WorkflowSnapshot,
) -> Result<bool, crate::error::RuntimeError> {
    let instance_id = &available_task.instance_id;
    let task_id = &available_task.task_id;

    let found = available_task.workflow_definition_hash;
    if found != *definition_hash {
        let mismatch = WorkflowError::DefinitionMismatch {
            expected: *definition_hash,
            found,
        };
        return Err(mismatch.into());
    }

    if !task_index.contains(task_id) {
        tracing::error!(
            instance_id = %instance_id,
            task_id = %task_id,
            "Task does not exist in workflow"
        );
        return Err(WorkflowError::TaskNotFound(task_id.to_hex()).into());
    }

    let already_completed = snapshot.get_task_result(task_id).is_some();
    if already_completed {
        tracing::debug!(
            instance_id = %instance_id,
            task_id = %task_id,
            "Task already completed, skipping"
        );
    }
    Ok(already_completed)
}
/// Tries to claim the task for this worker.
///
/// Returns `Ok(Some(claim))` when the backend granted the claim and `Ok(None)`
/// when it did not. A configured TTL that cannot be represented as a
/// `chrono::Duration` is passed to the backend as `None`.
///
/// # Errors
///
/// Returns the backend's error if the claim request itself fails.
async fn claim_task(
    &self,
    available_task: &AvailableTask,
) -> Result<Option<ActiveTaskClaim<'_, B>>, crate::error::RuntimeError> {
    let ttl = self
        .claim_ttl
        .and_then(|d| chrono::Duration::from_std(d).ok());
    let granted = self
        .backend
        .claim_task(
            &available_task.instance_id,
            &available_task.task_id,
            &self.worker_id,
            ttl,
        )
        .await?;

    if granted.is_none() {
        tracing::debug!(
            instance_id = %available_task.instance_id,
            task_id = %available_task.task_id,
            "Task was already claimed by another worker"
        );
        return Ok(None);
    }

    tracing::debug!(
        instance_id = %available_task.instance_id,
        task_id = %available_task.task_id,
        "Claim successful"
    );
    let active = ActiveTaskClaim {
        backend: &self.backend,
        instance_id: Arc::clone(&available_task.instance_id),
        task_id: available_task.task_id,
        worker_id: self.worker_id.clone(),
    };
    Ok(Some(active))
}
/// Checks, right after a claim is taken, whether the workflow has been
/// cancelled or paused.
///
/// Cancellation is checked first. Returns `Ok(Some(status))` with the
/// cancelled or paused status when one applies, and `Ok(None)` when the task
/// may run. Releasing the claim is left to the caller.
///
/// # Errors
///
/// Returns the backend's error from either check.
async fn check_post_claim_guards(
    &self,
    available_task: &AvailableTask,
) -> Result<Option<WorkflowStatus>, crate::error::RuntimeError> {
    let instance_id = &available_task.instance_id;
    let task_id = available_task.task_id;

    let cancelled = self
        .backend
        .check_and_cancel(instance_id, Some(task_id))
        .await?;
    if cancelled {
        tracing::info!(
            instance_id = %instance_id,
            task_id = %task_id,
            "Workflow was cancelled, releasing claim"
        );
        let status = self.load_cancelled_status(instance_id).await;
        return Ok(Some(status));
    }

    let paused = self.backend.check_and_pause(instance_id).await?;
    if paused {
        tracing::info!(
            instance_id = %instance_id,
            task_id = %task_id,
            "Workflow was paused, releasing claim"
        );
        let status = self.load_paused_status(instance_id).await;
        return Ok(Some(status));
    }

    Ok(None)
}
#[tracing::instrument(
name = "task",
skip_all,
fields(worker_id = %self.worker_id, instance_id = %claim.instance_id, task_id = %claim.task_id),
)]
async fn run_with_heartbeat<F, T>(
&self,
claim: &ActiveTaskClaim<'_, B>,
deadline: Option<&TaskDeadline>,
future: F,
) -> Result<T, crate::error::RuntimeError>
where
F: std::future::Future<Output = T>,
{
tracing::debug!("running task with heartbeat");
let Some(ttl) = self.claim_ttl else {
return Ok(future.await);
};
let Some(chrono_ttl) = chrono::Duration::from_std(ttl).ok() else {
return Ok(future.await);
};
let interval_duration = ttl / 2;
let mut heartbeat_timer = time::interval(interval_duration);
heartbeat_timer.tick().await;
tokio::pin!(future);
loop {
tokio::select! {
result = &mut future => break Ok(result),
_ = heartbeat_timer.tick() => {
if let Some(dl) = deadline
&& chrono::Utc::now() >= dl.deadline
{
tracing::warn!(
instance_id = %claim.instance_id,
task_id = %dl.task_id,
"Task deadline expired during heartbeat, cancelling"
);
return Err(WorkflowError::TaskTimedOut {
task_id: dl.task_id,
timeout: std::time::Duration::from_millis(dl.timeout_ms),
}
.into());
}
tracing::trace!(
instance_id = %claim.instance_id,
task_id = %claim.task_id,
"Extending task claim via heartbeat"
);
if let Err(e) = self.backend
.extend_task_claim(
&claim.instance_id,
&claim.task_id,
&claim.worker_id,
chrono_ttl,
)
.await
{
tracing::warn!(
instance_id = %claim.instance_id,
task_id = %claim.task_id,
error = %e,
"Failed to extend task claim"
);
}
}
}
}
}
/// Records a successful task result and persists it, then releases the claim.
///
/// Order of operations: mark the task completed on `snapshot`, advance the
/// execution position, save the snapshot, drain any signal already buffered
/// for the new position, and finally release the claim.
///
/// # Errors
///
/// Returns the first error from updating the position, saving, draining, or
/// releasing. On an early error the claim is dropped without an explicit
/// release call.
async fn commit_task_result(
&self,
continuation: &WorkflowContinuation,
available_task: &AvailableTask,
snapshot: &mut WorkflowSnapshot,
output: Bytes,
claim: ActiveTaskClaim<'_, B>,
) -> Result<(), crate::error::RuntimeError> {
snapshot.mark_task_completed(available_task.task_id, output);
tracing::debug!(
instance_id = %available_task.instance_id,
task_id = %available_task.task_id,
"Task completed"
);
Self::update_position_after_task(continuation, &available_task.task_id, snapshot)?;
// With the `otel` feature, store the current trace parent on the snapshot
// before it is saved.
#[cfg(feature = "otel")]
{
snapshot.trace_parent = crate::trace_context::current_trace_parent();
}
self.backend.save_snapshot(snapshot).await?;
self.drain_pending_signal(&available_task.instance_id, snapshot)
.await?;
// The claim is released only after the result has been saved.
claim.release().await?;
Ok(())
}
/// Consumes signals that are already buffered for the position the workflow
/// has just reached.
///
/// While the snapshot is in progress at an `AtSignal` position, tries to
/// consume an event with that signal's name. If one is available, the signal
/// is marked completed with the event payload, the position moves to the next
/// task (or the workflow is marked completed when there is none), and the
/// snapshot is saved; the loop then re-checks the new position. Returns as
/// soon as the position is not `AtSignal` or no event is buffered.
///
/// # Errors
///
/// Returns the backend's error from consuming an event or saving.
async fn drain_pending_signal(
&self,
instance_id: &Arc<str>,
snapshot: &mut WorkflowSnapshot,
) -> Result<(), crate::error::RuntimeError> {
loop {
// Copy the fields out so the borrow of `snapshot.state` ends before the
// snapshot is mutated below.
let (signal_id, signal_name, next_task_id) = match &snapshot.state {
sayiir_core::snapshot::WorkflowSnapshotState::InProgress {
position:
sayiir_core::snapshot::ExecutionPosition::AtSignal {
signal_id,
signal_name,
next_task_id,
..
},
..
} => (*signal_id, signal_name.clone(), *next_task_id),
_ => return Ok(()),
};
let Some(payload) = self
.backend
.consume_event(instance_id, &signal_name)
.await?
else {
return Ok(());
};
tracing::debug!(
instance_id = %instance_id,
%signal_name,
"draining buffered signal that landed during the AtTask→AtSignal transition"
);
snapshot.mark_task_completed(signal_id, payload.clone());
if let Some(next_id) = next_task_id {
snapshot.update_position(sayiir_core::snapshot::ExecutionPosition::AtTask {
task_id: next_id,
});
} else {
// No task follows the signal, so the payload becomes the workflow result.
snapshot.mark_completed(payload);
}
self.backend.save_snapshot(snapshot).await?;
}
}
/// Works out the workflow status after a task result has been committed.
///
/// Checks run in this order: cancelled, paused, complete. When the workflow is
/// complete, the snapshot is marked completed with `output` and saved.
/// Otherwise the workflow is reported as still in progress.
///
/// # Errors
///
/// Returns the backend's error from the cancel/pause checks or from saving
/// the completed snapshot.
async fn determine_post_task_status(
    &self,
    continuation: &WorkflowContinuation,
    available_task: &AvailableTask,
    snapshot: &mut WorkflowSnapshot,
    output: Bytes,
) -> Result<WorkflowStatus, crate::error::RuntimeError> {
    let instance_id = &available_task.instance_id;
    let task_id = &available_task.task_id;

    let cancelled = self.backend.check_and_cancel(instance_id, None).await?;
    if cancelled {
        tracing::info!(
            instance_id = %instance_id,
            task_id = %task_id,
            "Workflow was cancelled after task completion"
        );
        let status = self.load_cancelled_status(instance_id).await;
        return Ok(status);
    }

    let paused = self.backend.check_and_pause(instance_id).await?;
    if paused {
        tracing::info!(
            instance_id = %instance_id,
            task_id = %task_id,
            "Workflow was paused after task completion"
        );
        let status = self.load_paused_status(instance_id).await;
        return Ok(status);
    }

    if !Self::is_workflow_complete(continuation, snapshot) {
        tracing::debug!(
            instance_id = %instance_id,
            task_id = %task_id,
            "Task completed, workflow continues"
        );
        return Ok(WorkflowStatus::InProgress);
    }

    tracing::info!(
        instance_id = %instance_id,
        task_id = %task_id,
        "Workflow complete"
    );
    snapshot.mark_completed(output);
    self.backend.save_snapshot(snapshot).await?;
    Ok(WorkflowStatus::Completed)
}
/// Searches `continuation` for the fork whose join continuation starts with
/// `task_id` and returns that fork's branches.
///
/// Within a fork the join is inspected first (its first task, then its
/// sub-tree) and the branches afterwards. Returns `None` when no fork in the
/// tree has a join beginning with `task_id`.
fn find_fork_branches_for_join<'a>(
    continuation: &'a WorkflowContinuation,
    task_id: &sayiir_core::TaskId,
) -> Option<&'a [Arc<WorkflowContinuation>]> {
    match continuation {
        WorkflowContinuation::Task { next, .. }
        | WorkflowContinuation::Delay { next, .. }
        | WorkflowContinuation::AwaitSignal { next, .. } => {
            // These nodes only have a successor to search; no successor, no match.
            Self::find_fork_branches_for_join(next.as_deref()?, task_id)
        }
        WorkflowContinuation::Fork { branches, join, .. } => {
            if let Some(join_cont) = join {
                let join_entry = sayiir_core::TaskId::from(join_cont.first_task_id());
                let via_join = if join_entry == *task_id {
                    Some(&branches[..])
                } else {
                    Self::find_fork_branches_for_join(join_cont, task_id)
                };
                if via_join.is_some() {
                    return via_join;
                }
            }
            branches
                .iter()
                .find_map(|branch| Self::find_fork_branches_for_join(branch, task_id))
        }
        WorkflowContinuation::Branch {
            branches,
            default,
            next,
            ..
        } => {
            let in_arms = branches
                .values()
                .find_map(|arm| Self::find_fork_branches_for_join(arm, task_id));
            if in_arms.is_some() {
                return in_arms;
            }
            if let Some(def) = default
                && let Some(found) = Self::find_fork_branches_for_join(def, task_id)
            {
                return Some(found);
            }
            Self::find_fork_branches_for_join(next.as_deref()?, task_id)
        }
        WorkflowContinuation::Loop { body, next, .. } => {
            if let Some(found) = Self::find_fork_branches_for_join(body, task_id) {
                return Some(found);
            }
            Self::find_fork_branches_for_join(next.as_deref()?, task_id)
        }
        WorkflowContinuation::ChildWorkflow { child, next, .. } => {
            if let Some(found) = Self::find_fork_branches_for_join(child, task_id) {
                return Some(found);
            }
            Self::find_fork_branches_for_join(next.as_deref()?, task_id)
        }
    }
}
/// Reports whether a `Task`, `Delay` or `AwaitSignal` node whose id hashes to
/// `task_id` exists anywhere in `continuation`.
///
/// Container nodes (`Fork`, `Branch`, `Loop`, `ChildWorkflow`) are searched
/// through their children and successors; their own ids are not compared.
fn find_task_id_in_continuation(
    continuation: &WorkflowContinuation,
    task_id: &sayiir_core::TaskId,
) -> bool {
    match continuation {
        WorkflowContinuation::Task { id, next, .. }
        | WorkflowContinuation::Delay { id, next, .. }
        | WorkflowContinuation::AwaitSignal { id, next, .. } => {
            sayiir_core::TaskId::from(id.as_str()) == *task_id
                || next
                    .as_ref()
                    .is_some_and(|succ| Self::find_task_id_in_continuation(succ, task_id))
        }
        WorkflowContinuation::Fork { branches, join, .. } => {
            branches
                .iter()
                .any(|branch| Self::find_task_id_in_continuation(branch, task_id))
                || join
                    .as_ref()
                    .is_some_and(|join_cont| Self::find_task_id_in_continuation(join_cont, task_id))
        }
        WorkflowContinuation::Branch {
            branches,
            default,
            next,
            ..
        } => {
            branches
                .values()
                .any(|arm| Self::find_task_id_in_continuation(arm, task_id))
                || default
                    .as_ref()
                    .is_some_and(|def| Self::find_task_id_in_continuation(def, task_id))
                || next
                    .as_ref()
                    .is_some_and(|succ| Self::find_task_id_in_continuation(succ, task_id))
        }
        WorkflowContinuation::Loop { body, next, .. } => {
            Self::find_task_id_in_continuation(body, task_id)
                || next
                    .as_ref()
                    .is_some_and(|succ| Self::find_task_id_in_continuation(succ, task_id))
        }
        WorkflowContinuation::ChildWorkflow { child, next, .. } => {
            Self::find_task_id_in_continuation(child, task_id)
                || next
                    .as_ref()
                    .is_some_and(|succ| Self::find_task_id_in_continuation(succ, task_id))
        }
    }
}
/// Locates the task named `task_id` inside `continuation` and runs it with
/// `input`, returning the task's output bytes.
///
/// The tree is walked iteratively: at each container node the walk descends
/// into the first child that contains the task (per
/// `find_task_id_in_continuation`), otherwise it moves on to the node's
/// successor.
///
/// # Errors
///
/// * `WorkflowError::TaskNotFound` when the walk runs out of nodes.
/// * `WorkflowError::TaskNotImplemented` when the matching task has no `func`.
/// * Whatever error the task's `run` call produces.
#[allow(clippy::manual_async_fn)]
fn execute_task_by_id<'a>(
    continuation: &'a WorkflowContinuation,
    task_id: &'a str,
    input: Bytes,
) -> impl std::future::Future<Output = Result<Bytes, crate::error::RuntimeError>> + Send + 'a
{
    async move {
        let wanted = sayiir_core::TaskId::from(task_id);
        let mut cursor = continuation;
        loop {
            match cursor {
                WorkflowContinuation::Task { id, func, next, .. } => {
                    if sayiir_core::TaskId::from(id.as_str()) == wanted {
                        let Some(runnable) = func.as_ref() else {
                            return Err(WorkflowError::TaskNotImplemented(id.clone()).into());
                        };
                        return Ok(runnable.run(input).await?);
                    }
                    let Some(next_cont) = next else {
                        return Err(WorkflowError::TaskNotFound(wanted.to_string()).into());
                    };
                    cursor = next_cont;
                }
                WorkflowContinuation::Delay { next, .. }
                | WorkflowContinuation::AwaitSignal { next, .. } => {
                    // Not executable here; just step over them.
                    let Some(next_cont) = next else {
                        return Err(WorkflowError::TaskNotFound(wanted.to_string()).into());
                    };
                    cursor = next_cont;
                }
                WorkflowContinuation::Fork { branches, join, .. } => {
                    let containing = branches
                        .iter()
                        .find(|branch| Self::find_task_id_in_continuation(branch, &wanted));
                    if let Some(branch) = containing {
                        cursor = branch;
                        continue;
                    }
                    let Some(join_cont) = join else {
                        return Err(WorkflowError::TaskNotFound(wanted.to_string()).into());
                    };
                    cursor = join_cont;
                }
                WorkflowContinuation::Branch {
                    branches,
                    default,
                    next,
                    ..
                } => {
                    let containing = branches
                        .values()
                        .find(|arm| Self::find_task_id_in_continuation(arm, &wanted));
                    if let Some(arm) = containing {
                        cursor = arm;
                        continue;
                    }
                    if let Some(def) = default
                        && Self::find_task_id_in_continuation(def, &wanted)
                    {
                        cursor = def;
                        continue;
                    }
                    let Some(next_cont) = next else {
                        return Err(WorkflowError::TaskNotFound(wanted.to_string()).into());
                    };
                    cursor = next_cont;
                }
                WorkflowContinuation::Loop { body, next, .. } => {
                    if Self::find_task_id_in_continuation(body, &wanted) {
                        cursor = body;
                        continue;
                    }
                    let Some(next_cont) = next else {
                        return Err(WorkflowError::TaskNotFound(wanted.to_string()).into());
                    };
                    cursor = next_cont;
                }
                WorkflowContinuation::ChildWorkflow { child, next, .. } => {
                    if Self::find_task_id_in_continuation(child, &wanted) {
                        cursor = child;
                        continue;
                    }
                    let Some(next_cont) = next else {
                        return Err(WorkflowError::TaskNotFound(wanted.to_string()).into());
                    };
                    cursor = next_cont;
                }
            }
        }
    }
}
/// Points the snapshot at the node `cont`, picking the position kind from the
/// node type.
///
/// * `Delay` — records the hint of the successor's first task (or a default
///   hint when there is no successor), sets an `AtDelay` position, and stores
///   the last task output (or an empty default) as the delay's own result.
/// * `AwaitSignal` — records the successor hint the same way and sets an
///   `AtSignal` position.
/// * anything else — sets `AtTask` on the node's first task and records its
///   hint.
///
/// # Errors
///
/// Propagates the error from `compute_wake_at` for `Delay` nodes.
fn set_position_at(
    cont: &WorkflowContinuation,
    snapshot: &mut WorkflowSnapshot,
) -> Result<(), crate::error::RuntimeError> {
    use crate::execution::control_flow::{compute_signal_timeout, compute_wake_at};

    match cont {
        WorkflowContinuation::AwaitSignal {
            id,
            signal_name,
            timeout,
            next,
        } => {
            let wake_at = compute_signal_timeout(timeout.as_ref());
            let successor = next.as_deref().map(WorkflowContinuation::first_task_hint);
            let fallback = TaskHint::default();
            snapshot.set_task_hint(successor.as_ref().unwrap_or(&fallback));
            snapshot.update_position(ExecutionPosition::AtSignal {
                signal_id: sayiir_core::TaskId::from(id.as_str()),
                signal_name: signal_name.clone(),
                wake_at,
                next_task_id: successor.as_ref().map(|hint| hint.id),
            });
        }
        WorkflowContinuation::Delay { id, duration, next } => {
            // The fallible wake-up computation runs first so nothing is
            // mutated when it fails.
            let wake_at = compute_wake_at(duration)?;
            let entered_at = chrono::Utc::now();
            let successor = next.as_deref().map(WorkflowContinuation::first_task_hint);
            let fallback = TaskHint::default();
            snapshot.set_task_hint(successor.as_ref().unwrap_or(&fallback));

            let delay_id = sayiir_core::TaskId::from(id.as_str());
            snapshot.update_position(ExecutionPosition::AtDelay {
                delay_id,
                entered_at,
                wake_at,
                next_task_id: successor.as_ref().map(|hint| hint.id),
            });

            // The delay's result is whatever the previous task produced.
            let carried_over = snapshot.get_last_task_output().unwrap_or_default();
            snapshot.mark_task_completed(delay_id, carried_over);
        }
        _ => {
            let first = cont.first_task_hint();
            snapshot.update_position(ExecutionPosition::AtTask { task_id: first.id });
            snapshot.set_task_hint(&first);
        }
    }
    Ok(())
}
/// Advances the snapshot's position past the node identified by
/// `completed_task_id`.
///
/// `continuation` is walked recursively. When a `Task`, `Delay` or
/// `AwaitSignal` node whose id hashes to `completed_task_id` is found and it
/// has a successor, that successor is installed through `set_position_at`.
/// A matching node without a successor leaves the position untouched.
///
/// # Errors
///
/// Propagates any error returned by `set_position_at`.
fn update_position_after_task(
continuation: &WorkflowContinuation,
completed_task_id: &sayiir_core::TaskId,
snapshot: &mut WorkflowSnapshot,
) -> Result<(), crate::error::RuntimeError> {
match continuation {
WorkflowContinuation::Task { id, next, .. }
| WorkflowContinuation::Delay { id, next, .. }
| WorkflowContinuation::AwaitSignal { id, next, .. } => {
if sayiir_core::TaskId::from(id.as_str()) == *completed_task_id {
// This is the completed node: move to its successor, if it has one.
if let Some(next_cont) = next.as_deref() {
Self::set_position_at(next_cont, snapshot)?;
}
} else if let Some(next_cont) = next {
// Not this node; keep looking further down the chain.
Self::update_position_after_task(next_cont, completed_task_id, snapshot)?;
}
}
WorkflowContinuation::Fork { branches, join, .. } => {
// The completed node may sit in any branch or in the join, so all of
// them are visited.
for branch in branches {
Self::update_position_after_task(branch, completed_task_id, snapshot)?;
}
if let Some(join_cont) = join {
Self::update_position_after_task(join_cont, completed_task_id, snapshot)?;
}
// If the position no longer points at the completed task, one of the
// nested calls above already advanced it and there is nothing left to do.
let still_at_completed = snapshot
.current_task_id()
.is_some_and(|c| c == *completed_task_id);
if !still_at_completed {
return Ok(());
}
// Otherwise move to the first branch whose first task has no recorded
// result yet.
for branch in branches {
let first_tid = sayiir_core::TaskId::from(branch.first_task_id());
if snapshot.get_task_result(&first_tid).is_none() {
Self::set_position_at(branch, snapshot)?;
return Ok(());
}
}
// Every branch's first task has a result: continue with the join, if any.
if let Some(join_cont) = join {
Self::set_position_at(join_cont, snapshot)?;
}
}
WorkflowContinuation::Branch {
branches,
default,
next,
..
} => {
// Visit every arm, the default arm and the successor.
for branch_cont in branches.values() {
Self::update_position_after_task(branch_cont, completed_task_id, snapshot)?;
}
if let Some(def) = default {
Self::update_position_after_task(def, completed_task_id, snapshot)?;
}
if let Some(next_cont) = next {
Self::update_position_after_task(next_cont, completed_task_id, snapshot)?;
}
}
WorkflowContinuation::Loop { body, next, .. } => {
Self::update_position_after_task(body, completed_task_id, snapshot)?;
if let Some(next_cont) = next {
Self::update_position_after_task(next_cont, completed_task_id, snapshot)?;
}
}
WorkflowContinuation::ChildWorkflow { child, next, .. } => {
Self::update_position_after_task(child, completed_task_id, snapshot)?;
if let Some(next_cont) = next {
Self::update_position_after_task(next_cont, completed_task_id, snapshot)?;
}
}
}
Ok(())
}
/// Starts a [`PooledWorkerBuilder`] for the given backend and task registry.
///
/// Initial settings: no explicit worker id, a claim TTL of five minutes, a
/// batch size of `NonZeroUsize::MIN` (one), an aging interval of five
/// minutes, and no tags.
pub fn builder(backend: B, registry: TaskRegistry) -> PooledWorkerBuilder<B> {
    let five_minutes = Duration::from_mins(5);
    PooledWorkerBuilder {
        backend,
        registry,
        worker_id: None,
        tags: Vec::new(),
        batch_size: NonZeroUsize::MIN,
        claim_ttl: Some(five_minutes),
        aging_interval: five_minutes,
    }
}
/// Resolves pending loop decisions across the whole continuation.
///
/// Delegates directly to `resolve_loops_recursive`, starting at
/// `continuation`.
///
/// # Errors
///
/// Returns whatever error `resolve_loops_recursive` reports.
async fn resolve_loop_completions(
continuation: &WorkflowContinuation,
snapshot: &mut WorkflowSnapshot,
backend: &B,
) -> Result<(), crate::error::RuntimeError> {
Self::resolve_loops_recursive(continuation, snapshot, backend).await
}
/// Walks `continuation` and settles every `Loop` node whose body has
/// produced a terminal result but which has no result of its own yet.
///
/// For such a loop the terminal task's output is decoded with
/// `decode_loop_envelope`:
/// * `LoopDecision::Done` — the iteration counter is cleared and the inner
///   value is stored as the loop's result.
/// * `LoopDecision::Again` — the counter is incremented. When the incremented
///   value reaches `max_iterations`, `on_max` decides between failing and
///   exiting with the last value; otherwise the results of all body tasks are
///   removed and the new counter is stored.
///
/// The snapshot is saved through `backend` after each of these changes. The
/// future is boxed, which lets the function call itself recursively.
///
/// # Errors
///
/// * `WorkflowError::MaxIterationsExceeded` under `MaxIterationsPolicy::Fail`.
/// * `CodecError::DecodeFailed` when the loop envelope cannot be decoded.
/// * Any error from `backend.save_snapshot`.
#[allow(clippy::too_many_lines)]
fn resolve_loops_recursive<'a>(
continuation: &'a WorkflowContinuation,
snapshot: &'a mut WorkflowSnapshot,
backend: &'a B,
) -> Pin<
Box<dyn std::future::Future<Output = Result<(), crate::error::RuntimeError>> + Send + 'a>,
> {
Box::pin(async move {
match continuation {
WorkflowContinuation::Loop {
id,
body,
max_iterations,
on_max,
next,
} => {
// Only loops without a recorded result still need a decision.
if snapshot
.get_task_result(&sayiir_core::TaskId::from(id))
.is_none()
{
let terminal_id = body.terminal_task_id();
// A decision is possible only once the body's terminal task has a result.
if let Some(result) =
snapshot.get_task_result(&sayiir_core::TaskId::from(terminal_id))
{
// Cloned so the borrow of `snapshot` ends before it is mutated below.
let output = result.output.clone();
match crate::execution::decode_loop_envelope(&output) {
Ok((LoopDecision::Done, inner)) => {
// Loop finished: drop the counter and record the inner value.
snapshot.clear_loop_iteration(&sayiir_core::TaskId::from(id));
snapshot
.mark_task_completed(sayiir_core::TaskId::from(id), inner);
backend.save_snapshot(snapshot).await?;
}
Ok((LoopDecision::Again, again_value)) => {
let current_iter =
snapshot.loop_iteration(&sayiir_core::TaskId::from(id));
let next_iter = current_iter + 1;
// Another pass is refused once the incremented counter reaches the cap.
if next_iter >= *max_iterations {
match on_max {
sayiir_core::workflow::MaxIterationsPolicy::Fail => {
return Err(WorkflowError::MaxIterationsExceeded {
loop_id: sayiir_core::TaskId::from(id),
max_iterations: *max_iterations,
}
.into());
}
sayiir_core::workflow::MaxIterationsPolicy::ExitWithLast => {
// Finish the loop using the value of the last iteration.
snapshot.clear_loop_iteration(&sayiir_core::TaskId::from(id));
snapshot.mark_task_completed(
sayiir_core::TaskId::from(id.as_str()),
again_value,
);
backend.save_snapshot(snapshot).await?;
}
}
} else {
// Prepare the next pass: remove the results of every task id
// listed for the body, then store the new counter.
let body_ser = body.to_serializable();
for tid in &body_ser.task_ids() {
snapshot.remove_task_result(
&sayiir_core::TaskId::from(*tid),
);
}
snapshot.set_loop_iteration(
sayiir_core::TaskId::from(id),
next_iter,
);
backend.save_snapshot(snapshot).await?;
}
}
Err(e) => {
return Err(CodecError::DecodeFailed {
task_id: sayiir_core::TaskId::from(id),
expected_type: "LoopEnvelope",
source: e,
}
.into());
}
}
}
}
// Nested loops in the body and loops after this one are handled next.
Self::resolve_loops_recursive(body, snapshot, backend).await?;
if let Some(next) = next {
Self::resolve_loops_recursive(next, snapshot, backend).await?;
}
}
// NOTE(review): for `Branch` only `next` is followed here; the branch arms
// and `default` are not visited, unlike the other traversals in this file.
// Confirm that loops nested inside branch arms are resolved elsewhere.
WorkflowContinuation::Task { next, .. }
| WorkflowContinuation::Delay { next, .. }
| WorkflowContinuation::AwaitSignal { next, .. }
| WorkflowContinuation::Branch { next, .. } => {
if let Some(next) = next {
Self::resolve_loops_recursive(next, snapshot, backend).await?;
}
}
WorkflowContinuation::Fork { branches, join, .. } => {
for branch in branches {
Self::resolve_loops_recursive(branch, snapshot, backend).await?;
}
if let Some(join) = join {
Self::resolve_loops_recursive(join, snapshot, backend).await?;
}
}
WorkflowContinuation::ChildWorkflow { child, next, .. } => {
Self::resolve_loops_recursive(child, snapshot, backend).await?;
if let Some(next) = next {
Self::resolve_loops_recursive(next, snapshot, backend).await?;
}
}
}
Ok(())
})
}
/// Reports whether the snapshot holds results for the whole `continuation`.
///
/// * `Task`, `Delay`, `AwaitSignal`, `Branch`, `Loop` and `ChildWorkflow`
///   nodes are complete when the snapshot has a result for the node's own id
///   and the successor (if any) is complete. The children of `Branch`,
///   `Loop` and `ChildWorkflow` are not inspected here.
/// * A `Fork` is complete when every branch is complete and the join (if any)
///   is complete.
fn is_workflow_complete(
    continuation: &WorkflowContinuation,
    snapshot: &WorkflowSnapshot,
) -> bool {
    match continuation {
        WorkflowContinuation::Task { id, next, .. }
        | WorkflowContinuation::Delay { id, next, .. }
        | WorkflowContinuation::AwaitSignal { id, next, .. } => {
            snapshot
                .get_task_result(&sayiir_core::TaskId::from(id))
                .is_some()
                && next
                    .as_ref()
                    .is_none_or(|succ| Self::is_workflow_complete(succ, snapshot))
        }
        WorkflowContinuation::Fork { branches, join, .. } => {
            branches
                .iter()
                .all(|branch| Self::is_workflow_complete(branch, snapshot))
                && join
                    .as_ref()
                    .is_none_or(|join_cont| Self::is_workflow_complete(join_cont, snapshot))
        }
        WorkflowContinuation::Branch { id, next, .. } => {
            snapshot
                .get_task_result(&sayiir_core::TaskId::from(id))
                .is_some()
                && next
                    .as_ref()
                    .is_none_or(|succ| Self::is_workflow_complete(succ, snapshot))
        }
        WorkflowContinuation::Loop { id, next, .. } => {
            snapshot
                .get_task_result(&sayiir_core::TaskId::from(id))
                .is_some()
                && next
                    .as_ref()
                    .is_none_or(|succ| Self::is_workflow_complete(succ, snapshot))
        }
        WorkflowContinuation::ChildWorkflow { id, next, .. } => {
            snapshot
                .get_task_result(&sayiir_core::TaskId::from(id))
                .is_some()
                && next
                    .as_ref()
                    .is_none_or(|succ| Self::is_workflow_complete(succ, snapshot))
        }
    }
}
}
/// Builds the fallback worker id in the form `<hostname>-<pid>`.
///
/// The literal `unknown` replaces the hostname when it cannot be read or is
/// not valid Unicode.
fn default_worker_id() -> String {
    const UNKNOWN_HOST: &str = "unknown";

    let host = match hostname::get() {
        Ok(raw) => raw
            .into_string()
            .unwrap_or_else(|_| UNKNOWN_HOST.to_string()),
        Err(_) => UNKNOWN_HOST.to_string(),
    };
    let pid = std::process::id();
    format!("{host}-{pid}")
}
/// Builder for [`PooledWorker`]; obtained from `PooledWorker::builder` and
/// finished with `build`.
pub struct PooledWorkerBuilder<B> {
/// Explicit worker id; when `None`, `build` falls back to `default_worker_id()`.
worker_id: Option<String>,
/// Backend handed to the worker; `build` wraps it in an `Arc`.
backend: B,
/// Task registry handed to the worker; `build` wraps it in an `Arc`.
registry: TaskRegistry,
/// Copied unchanged into the worker by `build`.
/// NOTE(review): presumably the time-to-live of a task claim — confirm.
claim_ttl: Option<Duration>,
/// Copied unchanged into the worker by `build`; never zero by construction.
batch_size: NonZeroUsize,
/// Copied unchanged into the worker by `build`; the `aging_interval` setter
/// rejects a zero duration.
aging_interval: Duration,
/// Copied unchanged into the worker by `build`.
tags: Vec<String>,
}
impl<B> PooledWorkerBuilder<B>
where
    B: PersistentBackend + TaskClaimStore + 'static,
{
    /// Sets an explicit worker id, replacing the auto-generated fallback.
    #[must_use]
    pub fn worker_id(self, id: impl Into<String>) -> Self {
        Self {
            worker_id: Some(id.into()),
            ..self
        }
    }

    /// Replaces the claim TTL setting with `ttl`.
    #[must_use]
    pub fn claim_ttl(self, ttl: Option<Duration>) -> Self {
        Self {
            claim_ttl: ttl,
            ..self
        }
    }

    /// Replaces the batch size setting with `size`.
    #[must_use]
    pub fn batch_size(self, size: NonZeroUsize) -> Self {
        Self {
            batch_size: size,
            ..self
        }
    }

    /// Replaces the aging interval setting with `interval`.
    ///
    /// # Panics
    ///
    /// Panics when `interval` is zero.
    #[must_use]
    pub fn aging_interval(self, interval: Duration) -> Self {
        assert!(!interval.is_zero(), "aging interval must be non-zero");
        Self {
            aging_interval: interval,
            ..self
        }
    }

    /// Replaces the tag list with `tags`.
    #[must_use]
    pub fn tags(self, tags: Vec<String>) -> Self {
        Self { tags, ..self }
    }

    /// Consumes the builder and produces the configured [`PooledWorker`].
    ///
    /// A missing worker id is filled in by `default_worker_id`; the backend
    /// and the registry are each wrapped in an `Arc`.
    #[must_use]
    pub fn build(self) -> PooledWorker<B> {
        let Self {
            worker_id,
            backend,
            registry,
            claim_ttl,
            batch_size,
            aging_interval,
            tags,
        } = self;
        PooledWorker {
            worker_id: worker_id.unwrap_or_else(default_worker_id),
            backend: Arc::new(backend),
            registry: Arc::new(registry),
            claim_ttl,
            batch_size,
            aging_interval,
            tags,
        }
    }
}
// Unit tests for spawning and shutting down a pooled worker and for the
// builder's settings, all against the in-memory backend.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use crate::serialization::JsonCodec;
use sayiir_core::registry::TaskRegistry;
use sayiir_core::snapshot::WorkflowSnapshot;
use sayiir_persistence::{InMemoryBackend, SignalStore, SnapshotStore};
// Workflow registry type used to spawn a worker with no workflows registered.
type EmptyWorkflows = WorkflowRegistry<JsonCodec, (), ()>;
// Builds a worker named "test-worker" over a fresh in-memory backend and an
// empty task registry.
fn make_worker() -> PooledWorker<InMemoryBackend> {
let backend = InMemoryBackend::new();
let registry = TaskRegistry::new();
PooledWorker::new("test-worker", backend, registry)
}
// A spawned worker must finish within five seconds after `shutdown`.
#[tokio::test]
async fn test_spawn_and_shutdown() {
let worker = make_worker();
let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());
handle.shutdown();
let result = tokio::time::timeout(Duration::from_secs(5), handle.join()).await;
assert!(result.is_ok(), "Worker should exit cleanly after shutdown");
assert!(result.unwrap().is_ok());
}
// A cloned handle can be moved into another task and used to shut the worker down.
#[tokio::test]
async fn test_handle_is_clone_and_send() {
let worker = make_worker();
let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());
let handle2 = handle.clone();
let remote = tokio::spawn(async move {
handle2.shutdown();
});
remote.await.ok();
let result = tokio::time::timeout(Duration::from_secs(5), handle.join()).await;
assert!(result.is_ok_and(|r| r.is_ok()));
}
// Cancelling through a `WorkflowClient` that shares the worker's backend must
// leave a cancel signal stored for the instance.
#[tokio::test]
async fn test_cancel_via_client() {
let backend = InMemoryBackend::new();
let registry = TaskRegistry::new();
let mut snapshot = WorkflowSnapshot::new("wf-1", "hash-1".into());
backend.save_snapshot(&mut snapshot).await.ok();
let worker = PooledWorker::new("test-worker", backend, registry);
let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());
let client = crate::WorkflowClient::from_shared(std::sync::Arc::clone(handle.backend()));
client
.cancel(
"wf-1",
Some("test reason".to_string()),
Some("tester".to_string()),
)
.await
.ok();
let signal = handle
.backend()
.get_signal("wf-1", SignalKind::Cancel)
.await;
assert!(signal.is_ok_and(|s| s.is_some()));
// Clean-up only: the outcome of the shutdown is not asserted here.
handle.shutdown();
tokio::time::timeout(Duration::from_secs(5), handle.join())
.await
.ok();
}
// Without an explicit id the builder's generated id must contain the process id.
#[test]
fn test_builder_auto_generates_worker_id() {
let backend = InMemoryBackend::new();
let registry = TaskRegistry::new();
let worker = PooledWorker::builder(backend, registry).build();
let pid = std::process::id().to_string();
assert!(
worker.worker_id.contains(&pid),
"Auto-generated ID '{}' should contain PID '{}'",
worker.worker_id,
pid
);
}
// An explicit worker id is used verbatim.
#[test]
fn test_builder_explicit_worker_id() {
let backend = InMemoryBackend::new();
let registry = TaskRegistry::new();
let worker = PooledWorker::builder(backend, registry)
.worker_id("my-worker")
.build();
assert_eq!(worker.worker_id, "my-worker");
}
// Settings passed to the builder end up on the built worker.
#[test]
fn test_builder_custom_settings() {
let backend = InMemoryBackend::new();
let registry = TaskRegistry::new();
let worker = PooledWorker::builder(backend, registry)
.worker_id("w1")
.claim_ttl(Some(Duration::from_mins(2)))
.batch_size(NonZeroUsize::new(8).unwrap())
.build();
assert_eq!(worker.worker_id, "w1");
assert_eq!(worker.claim_ttl, Some(Duration::from_mins(2)));
assert_eq!(worker.batch_size.get(), 8);
}
// Dropping the last handle must make the worker task finish on its own.
#[tokio::test]
async fn test_dropped_handle_shuts_down_worker() {
let worker = make_worker();
let handle = worker.spawn(Duration::from_millis(50), EmptyWorkflows::new());
// Take the join handle out first so the task can still be awaited after the drop.
let join_handle = handle.inner.join_handle.lock().await.take().unwrap();
drop(handle);
let result = tokio::time::timeout(Duration::from_secs(5), join_handle)
.await
.ok()
.and_then(Result::ok);
assert!(
result.is_some(),
"Worker should exit when all handles are dropped"
);
assert!(result.is_some_and(|r| r.is_ok()));
}
}