#![warn(missing_docs)]
#[macro_use]
extern crate tracing;
extern crate self as temporalio_sdk;
pub mod activities;
pub mod interceptors;
mod workflow_context;
mod workflow_future;
pub mod workflows;
#[macro_export]
#[doc(hidden)]
/// Internal macro re-export used by SDK-generated code: forwards to
/// `futures_util::select_biased!` so user crates don't need their own
/// `futures_util` dependency.
macro_rules! __temporal_select {
    ($($tokens:tt)*) => {
        ::futures_util::select_biased! { $($tokens)* }
    };
}
#[macro_export]
#[doc(hidden)]
/// Internal macro re-export used by SDK-generated code: forwards to
/// `futures_util::join!` so user crates don't need their own
/// `futures_util` dependency.
macro_rules! __temporal_join {
    ($($tokens:tt)*) => {
        ::futures_util::join!($($tokens)*)
    };
}
use workflow_future::WorkflowFunction;
pub use temporalio_client::Namespace;
pub use workflow_context::{
ActivityExecutionError, ActivityOptions, BaseWorkflowContext, CancellableFuture, ChildWorkflow,
ChildWorkflowOptions, LocalActivityOptions, NexusOperationOptions, ParentWorkflowInfo,
PendingChildWorkflow, RootWorkflowInfo, Signal, SignalData, SignalWorkflowOptions,
StartedChildWorkflow, SyncWorkflowContext, TimerOptions, WorkflowContext, WorkflowContextView,
};
use crate::{
activities::{
ActivityContext, ActivityDefinitions, ActivityError, ActivityImplementer,
ExecutableActivity,
},
interceptors::WorkerInterceptor,
workflow_context::{ChildWfCommon, NexusUnblockData, StartedNexusOperation},
workflows::{WorkflowDefinitions, WorkflowImplementation, WorkflowImplementer},
};
use anyhow::{Context, anyhow, bail};
use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use std::{
any::{Any, TypeId},
cell::RefCell,
collections::{HashMap, HashSet},
fmt::{Debug, Display, Formatter},
future::Future,
panic::AssertUnwindSafe,
sync::Arc,
time::Duration,
};
use temporalio_client::{Client, NamespacedClient};
use temporalio_common::{
ActivityDefinition, WorkflowDefinition,
data_converters::{DataConverter, SerializationContextData},
payload_visitor::{decode_payloads, encode_payloads},
protos::{
TaskToken,
coresdk::{
ActivityTaskCompletion, AsJsonPayloadExt,
activity_result::{ActivityExecutionResult, ActivityResolution},
activity_task::{ActivityTask, activity_task},
child_workflow::ChildWorkflowResult,
common::NamespacedWorkflowExecution,
nexus::NexusOperationResult,
workflow_activation::{
WorkflowActivation,
resolve_child_workflow_execution_start::Status as ChildWorkflowStartStatus,
resolve_nexus_operation_start, workflow_activation_job::Variant,
},
workflow_commands::{
ContinueAsNewWorkflowExecution, WorkflowCommand, workflow_command,
},
workflow_completion::WorkflowActivationCompletion,
},
temporal::api::{
common::v1::Payload,
enums::v1::WorkflowTaskFailedCause,
failure::v1::{Failure, failure},
},
},
worker::{WorkerDeploymentOptions, WorkerTaskTypes, build_id_from_current_exe},
};
use temporalio_sdk_core::{
CoreRuntime, PollError, PollerBehavior, TunerBuilder, Worker as CoreWorker, WorkerConfig,
WorkerTuner, WorkerVersioningStrategy, WorkflowErrorType, init_worker,
};
use tokio::{
sync::{
Notify,
mpsc::{UnboundedSender, unbounded_channel},
oneshot,
},
task::JoinError,
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, Span, field};
use uuid::Uuid;
/// Configuration for constructing a [`Worker`].
///
/// Built with the `bon`-generated builder starting from
/// [`WorkerOptions::new`], which takes the task queue name. All other fields
/// are optional and have the defaults noted below.
#[derive(bon::Builder, Clone)]
#[builder(start_fn = new, on(String, into), state_mod(vis = "pub"))]
#[non_exhaustive]
pub struct WorkerOptions {
    /// The task queue this worker polls.
    #[builder(start_fn)]
    pub task_queue: String,
    // Activity implementations registered via the `register_activity*` methods.
    #[builder(field)]
    activities: ActivityDefinitions,
    // Workflow implementations registered via the `register_workflow*` methods.
    #[builder(field)]
    workflows: WorkflowDefinitions,
    /// Worker deployment/versioning options. Defaults to a build id derived
    /// from the currently running executable (see `def_build_id`).
    #[builder(default = def_build_id())]
    pub deployment_options: WorkerDeploymentOptions,
    /// If set, overrides the identity reported for this worker instead of the
    /// client connection's identity.
    pub client_identity_override: Option<String>,
    /// Maximum number of workflow executions kept in the sticky cache
    /// (default: 1000).
    #[builder(default = 1000)]
    pub max_cached_workflows: usize,
    /// Tuner controlling task slot allocation (default: `TunerBuilder`
    /// defaults).
    #[builder(default = Arc::new(TunerBuilder::default().build()))]
    pub tuner: Arc<dyn WorkerTuner + Send + Sync>,
    /// Poller behavior for workflow tasks (default: simple maximum of 5).
    #[builder(default = PollerBehavior::SimpleMaximum(5))]
    pub workflow_task_poller_behavior: PollerBehavior,
    /// Ratio of non-sticky to sticky polls (default: 0.2).
    #[builder(default = 0.2)]
    pub nonsticky_to_sticky_poll_ratio: f32,
    /// Poller behavior for activity tasks (default: simple maximum of 5).
    #[builder(default = PollerBehavior::SimpleMaximum(5))]
    pub activity_task_poller_behavior: PollerBehavior,
    /// Poller behavior for nexus tasks (default: simple maximum of 5).
    #[builder(default = PollerBehavior::SimpleMaximum(5))]
    pub nexus_task_poller_behavior: PollerBehavior,
    /// Which kinds of tasks this worker handles (default: all).
    #[builder(default = WorkerTaskTypes::all())]
    pub task_types: WorkerTaskTypes,
    /// Schedule-to-start timeout for the sticky queue (default: 10s).
    #[builder(default = Duration::from_secs(10))]
    pub sticky_queue_schedule_to_start_timeout: Duration,
    /// Maximum interval between activity heartbeat transmissions
    /// (default: 60s).
    #[builder(default = Duration::from_secs(60))]
    pub max_heartbeat_throttle_interval: Duration,
    /// Default interval between activity heartbeat transmissions
    /// (default: 30s).
    #[builder(default = Duration::from_secs(30))]
    pub default_heartbeat_throttle_interval: Duration,
    /// Optional task-queue-wide rate limit on activities per second.
    pub max_task_queue_activities_per_second: Option<f64>,
    /// Optional per-worker rate limit on activities per second.
    pub max_worker_activities_per_second: Option<f64>,
    /// Error types that should fail the workflow rather than the task.
    #[builder(default)]
    pub workflow_failure_errors: HashSet<WorkflowErrorType>,
    /// Per-workflow-type overrides of [`Self::workflow_failure_errors`].
    #[builder(default)]
    pub workflow_types_to_failure_errors: HashMap<String, HashSet<WorkflowErrorType>>,
    /// How long to allow outstanding tasks to finish during shutdown.
    pub graceful_shutdown_period: Option<Duration>,
}
impl<S: worker_options_builder::State> WorkerOptionsBuilder<S> {
    /// Register every activity defined by `AI` (consuming-builder form of
    /// [`WorkerOptions::register_activities`]).
    pub fn register_activities<AI: ActivityImplementer>(mut self, instance: AI) -> Self {
        self.activities.register_activities::<AI>(instance);
        self
    }
    /// Register a single activity definition backed by `instance`.
    pub fn register_activity<AD>(mut self, instance: Arc<AD::Implementer>) -> Self
    where
        AD: ActivityDefinition + ExecutableActivity,
        AD::Output: Send + Sync,
    {
        self.activities.register_activity::<AD>(instance);
        self
    }
    /// Register a workflow implementation type.
    pub fn register_workflow<WI: WorkflowImplementer>(mut self) -> Self {
        self.workflows.register_workflow::<WI>();
        self
    }
    /// Register a workflow whose instances are produced by `factory`.
    pub fn register_workflow_with_factory<W, F>(mut self, factory: F) -> Self
    where
        W: WorkflowImplementation,
        <W::Run as WorkflowDefinition>::Input: Send,
        F: Fn() -> W + Send + Sync + 'static,
    {
        self.workflows
            .register_workflow_run_with_factory::<W, F>(factory);
        self
    }
}
/// Default deployment options: versioning keyed by a build id derived from
/// the currently running executable.
fn def_build_id() -> WorkerDeploymentOptions {
    WorkerDeploymentOptions::from_build_id(build_id_from_current_exe().to_owned())
}
impl WorkerOptions {
    /// Register every activity defined by `AI` on an already-built options
    /// value.
    pub fn register_activities<AI: ActivityImplementer>(&mut self, instance: AI) -> &mut Self {
        self.activities.register_activities::<AI>(instance);
        self
    }
    /// Register a single activity definition backed by `instance`.
    pub fn register_activity<AD>(&mut self, instance: Arc<AD::Implementer>) -> &mut Self
    where
        AD: ActivityDefinition + ExecutableActivity,
        AD::Output: Send + Sync,
    {
        self.activities.register_activity::<AD>(instance);
        self
    }
    /// A clone of the currently registered activity definitions.
    pub fn activities(&self) -> ActivityDefinitions {
        self.activities.clone()
    }
    /// Register a workflow implementation type.
    pub fn register_workflow<WI: WorkflowImplementer>(&mut self) -> &mut Self {
        self.workflows.register_workflow::<WI>();
        self
    }
    /// Register a workflow whose instances are produced by `factory`.
    pub fn register_workflow_with_factory<W, F>(&mut self, factory: F) -> &mut Self
    where
        W: WorkflowImplementation,
        <W::Run as WorkflowDefinition>::Input: Send,
        F: Fn() -> W + Send + Sync + 'static,
    {
        self.workflows
            .register_workflow_run_with_factory::<W, F>(factory);
        self
    }
    /// A clone of the currently registered workflow definitions.
    pub fn workflows(&self) -> WorkflowDefinitions {
        self.workflows.clone()
    }
    #[doc(hidden)]
    /// Convert these options into a core-SDK `WorkerConfig`.
    ///
    /// Identity resolution: an explicit `client_identity_override` wins;
    /// otherwise, when the connection reports an *empty* identity, a default
    /// of `"{pid}@{hostname}"` is synthesized; a non-empty connection
    /// identity is left untouched (no override is set).
    pub fn to_core_options(
        &self,
        namespace: String,
        connection_identity: String,
    ) -> Result<WorkerConfig, String> {
        WorkerConfig::builder()
            .namespace(namespace)
            .task_queue(self.task_queue.clone())
            .maybe_client_identity_override(self.client_identity_override.clone().or_else(|| {
                connection_identity.is_empty().then(|| {
                    format!(
                        "{}@{}",
                        std::process::id(),
                        gethostname::gethostname().to_string_lossy()
                    )
                })
            }))
            .max_cached_workflows(self.max_cached_workflows)
            .tuner(self.tuner.clone())
            .workflow_task_poller_behavior(self.workflow_task_poller_behavior)
            .activity_task_poller_behavior(self.activity_task_poller_behavior)
            .nexus_task_poller_behavior(self.nexus_task_poller_behavior)
            .task_types(self.task_types)
            .sticky_queue_schedule_to_start_timeout(self.sticky_queue_schedule_to_start_timeout)
            .max_heartbeat_throttle_interval(self.max_heartbeat_throttle_interval)
            .default_heartbeat_throttle_interval(self.default_heartbeat_throttle_interval)
            .maybe_max_task_queue_activities_per_second(self.max_task_queue_activities_per_second)
            .maybe_max_worker_activities_per_second(self.max_worker_activities_per_second)
            .maybe_graceful_shutdown_period(self.graceful_shutdown_period)
            .versioning_strategy(WorkerVersioningStrategy::WorkerDeploymentBased(
                self.deployment_options.clone(),
            ))
            .workflow_failure_errors(self.workflow_failure_errors.clone())
            .workflow_types_to_failure_errors(self.workflow_types_to_failure_errors.clone())
            .build()
    }
}
/// A Temporal worker: polls a single task queue and drives registered
/// workflow and activity implementations.
pub struct Worker {
    common: CommonWorker,
    workflow_half: WorkflowHalf,
    activity_half: ActivityHalf,
}
/// State shared by both the workflow and activity halves of a [`Worker`].
struct CommonWorker {
    worker: Arc<CoreWorker>,
    task_queue: String,
    worker_interceptor: Option<Box<dyn WorkerInterceptor>>,
    data_converter: DataConverter,
}
/// Workflow-side state of a [`Worker`].
#[derive(Default)]
struct WorkflowHalf {
    // Maps run id -> the channel feeding activations to that run's future.
    // RefCell is fine here: workflow processing runs on a LocalSet.
    workflows: RefCell<HashMap<String, WorkflowData>>,
    workflow_definitions: WorkflowDefinitions,
    // Signalled whenever a run id is removed from `workflows`, so a new run
    // with the same id can wait for the old entry to disappear.
    workflow_removed_from_map: Notify,
}
/// Per-run bookkeeping: the sender used to deliver activations to the
/// workflow future.
struct WorkflowData {
    activation_chan: UnboundedSender<WorkflowActivation>,
}
/// Handle to a spawned workflow future, paired with its run id so the cache
/// entry can be removed when the future completes.
struct WorkflowFutureHandle<F: Future<Output = Result<WorkflowResult<Payload>, JoinError>>> {
    join_handle: F,
    run_id: String,
}
/// Activity-side state of a [`Worker`].
#[derive(Default)]
struct ActivityHalf {
    activities: ActivityDefinitions,
    // Maps an in-flight activity's task token to the token used to cancel it.
    task_tokens_to_cancels: HashMap<TaskToken, CancellationToken>,
}
impl Worker {
    /// Create a worker from a runtime, client, and options, initializing the
    /// underlying core worker.
    ///
    /// # Errors
    /// Returns an error if the options fail to convert to a core config or
    /// core worker initialization fails.
    pub fn new(
        runtime: &CoreRuntime,
        client: Client,
        mut options: WorkerOptions,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        // Take the registered definitions out of the options so they can be
        // moved into the worker halves after construction.
        let acts = std::mem::take(&mut options.activities);
        let wfs = std::mem::take(&mut options.workflows);
        let wc = options
            .to_core_options(client.namespace(), client.identity())
            .map_err(|s| anyhow::anyhow!("{s}"))?;
        let core = init_worker(runtime, wc, client.connection().clone())?;
        let mut me = Self::new_from_core_definitions(
            Arc::new(core),
            client.data_converter().clone(),
            Default::default(),
            Default::default(),
        );
        me.activity_half.activities = acts;
        me.workflow_half.workflow_definitions = wfs;
        Ok(me)
    }
    #[doc(hidden)]
    /// Wrap an existing core worker with empty definitions.
    pub fn new_from_core(worker: Arc<CoreWorker>, data_converter: DataConverter) -> Self {
        Self::new_from_core_definitions(
            worker,
            data_converter,
            Default::default(),
            Default::default(),
        )
    }
    #[doc(hidden)]
    /// Wrap an existing core worker with pre-built activity/workflow
    /// definitions.
    pub fn new_from_core_definitions(
        worker: Arc<CoreWorker>,
        data_converter: DataConverter,
        activities: ActivityDefinitions,
        workflows: WorkflowDefinitions,
    ) -> Self {
        Self {
            common: CommonWorker {
                task_queue: worker.get_config().task_queue.clone(),
                worker,
                worker_interceptor: None,
                data_converter,
            },
            workflow_half: WorkflowHalf {
                workflow_definitions: workflows,
                ..Default::default()
            },
            activity_half: ActivityHalf {
                activities,
                ..Default::default()
            },
        }
    }
    /// The task queue this worker polls.
    pub fn task_queue(&self) -> &str {
        &self.common.task_queue
    }
    /// Returns a closure that initiates shutdown of this worker when called;
    /// usable from another task/thread.
    pub fn shutdown_handle(&self) -> impl Fn() + use<> {
        let w = self.common.worker.clone();
        move || w.initiate_shutdown()
    }
    /// Register every activity defined by `AI`.
    pub fn register_activities<AI: ActivityImplementer>(&mut self, instance: AI) -> &mut Self {
        self.activity_half
            .activities
            .register_activities::<AI>(instance);
        self
    }
    /// Register a single activity definition backed by `instance`.
    pub fn register_activity<AD>(&mut self, instance: Arc<AD::Implementer>) -> &mut Self
    where
        AD: ActivityDefinition + ExecutableActivity,
        AD::Output: Send + Sync,
    {
        self.activity_half
            .activities
            .register_activity::<AD>(instance);
        self
    }
    /// Register a workflow implementation type.
    pub fn register_workflow<WI: WorkflowImplementer>(&mut self) -> &mut Self {
        self.workflow_half
            .workflow_definitions
            .register_workflow::<WI>();
        self
    }
    /// Register a workflow whose instances are produced by `factory`.
    pub fn register_workflow_with_factory<W, F>(&mut self, factory: F) -> &mut Self
    where
        W: WorkflowImplementation,
        <W::Run as WorkflowDefinition>::Input: Send,
        F: Fn() -> W + Send + Sync + 'static,
    {
        self.workflow_half
            .workflow_definitions
            .register_workflow_run_with_factory::<W, F>(factory);
        self
    }
    /// Run the worker until shutdown is initiated and all outstanding work
    /// drains.
    ///
    /// Internally drives three concurrent loops: workflow activation polling
    /// (plus a joiner that reaps finished workflow futures), activity task
    /// polling, and a processor that encodes and submits workflow activation
    /// completions. Workflow futures are `!Send` and run on a `LocalSet`.
    pub async fn run(&mut self) -> Result<(), anyhow::Error> {
        let shutdown_token = CancellationToken::new();
        let (common, wf_half, act_half) = self.split_apart();
        // Channel of spawned workflow futures awaiting completion.
        let (wf_future_tx, wf_future_rx) = unbounded_channel();
        // Channel of workflow activation completions headed back to core.
        let (completions_tx, completions_rx) = unbounded_channel();
        let workflow_local_set = tokio::task::LocalSet::new();
        // Awaits each workflow future; on completion (or eviction) removes
        // the run from the cache and notifies any waiter for that run id.
        let wf_future_joiner = async {
            UnboundedReceiverStream::new(wf_future_rx)
                .map(Result::<_, anyhow::Error>::Ok)
                .try_for_each_concurrent(
                    None,
                    |WorkflowFutureHandle {
                         join_handle,
                         run_id,
                     }| {
                        let wf_half = &*wf_half;
                        async move {
                            let result = join_handle.await?;
                            // Eviction is a normal termination, not an error.
                            if let Err(e) = result
                                && !matches!(e, WorkflowTermination::Evicted)
                            {
                                return Err(e.into());
                            }
                            debug!(run_id=%run_id, "Removing workflow from cache");
                            wf_half.workflows.borrow_mut().remove(&run_id);
                            wf_half.workflow_removed_from_map.notify_one();
                            Ok(())
                        }
                    },
                )
                .await
                .context("Workflow futures encountered an error")
        };
        // Encodes completion payloads, runs the interceptor hook, then hands
        // each completion to the core worker.
        let wf_completion_processor = async {
            UnboundedReceiverStream::new(completions_rx)
                .map(Ok)
                .try_for_each_concurrent(None, |mut completion| async {
                    encode_payloads(
                        &mut completion,
                        common.data_converter.codec(),
                        &SerializationContextData::Workflow,
                    )
                    .await;
                    if let Some(ref i) = common.worker_interceptor {
                        i.on_workflow_activation_completion(&completion).await;
                    }
                    common.worker.complete_workflow_activation(completion).await
                })
                .map_err(anyhow::Error::from)
                .await
                .context("Workflow completions processor encountered an error")
        };
        tokio::try_join!(
            async {
                workflow_local_set.run_until(async {
                    tokio::try_join!(
                        async {
                            // Workflow activation poll loop; exits cleanly on
                            // core shutdown.
                            loop {
                                let mut activation =
                                    match common.worker.poll_workflow_activation().await {
                                        Err(PollError::ShutDown) => {
                                            break;
                                        }
                                        o => o?,
                                    };
                                decode_payloads(
                                    &mut activation,
                                    common.data_converter.codec(),
                                    &SerializationContextData::Workflow,
                                )
                                .await;
                                if let Some(ref i) = common.worker_interceptor {
                                    i.on_workflow_activation(&activation).await?;
                                }
                                if let Some(wf_fut) = wf_half
                                    .workflow_activation_handler(
                                        common,
                                        shutdown_token.clone(),
                                        activation,
                                        &completions_tx,
                                    )
                                    .await?
                                    && wf_future_tx.send(wf_fut).is_err()
                                {
                                    panic!(
                                        "Receive half of completion processor channel cannot be dropped"
                                    );
                                }
                            }
                            shutdown_token.cancel();
                            // Dropping the senders ends the joiner and
                            // completion-processor streams so they can finish.
                            drop(wf_future_tx);
                            drop(completions_tx);
                            Result::<_, anyhow::Error>::Ok(())
                        },
                        wf_future_joiner,
                    )
                }).await
            },
            async {
                // Activity poll loop; only runs if activities are registered.
                if !act_half.activities.is_empty() {
                    loop {
                        let activity = common.worker.poll_activity_task().await;
                        if matches!(activity, Err(PollError::ShutDown)) {
                            break;
                        }
                        let mut activity = activity?;
                        decode_payloads(
                            &mut activity,
                            common.data_converter.codec(),
                            &SerializationContextData::Activity,
                        )
                        .await;
                        act_half.activity_task_handler(
                            common.worker.clone(),
                            common.task_queue.clone(),
                            common.data_converter.clone(),
                            activity,
                        )?;
                    }
                };
                Result::<_, anyhow::Error>::Ok(())
            },
            wf_completion_processor,
        )?;
        if let Some(i) = self.common.worker_interceptor.as_ref() {
            i.on_shutdown(self);
        }
        self.common.worker.shutdown().await;
        Ok(())
    }
    /// Install an interceptor to observe worker events.
    pub fn set_worker_interceptor(&mut self, interceptor: impl WorkerInterceptor + 'static) {
        self.common.worker_interceptor = Some(Box::new(interceptor));
    }
    /// Replace the underlying core worker.
    pub fn with_new_core_worker(&mut self, new_core_worker: Arc<CoreWorker>) {
        self.common.worker = new_core_worker;
    }
    /// Number of workflow runs currently cached.
    pub fn cached_workflows(&self) -> usize {
        self.workflow_half.workflows.borrow().len()
    }
    /// Unique key identifying this worker instance.
    pub fn worker_instance_key(&self) -> Uuid {
        self.common.worker.worker_instance_key()
    }
    #[doc(hidden)]
    /// Access the underlying core worker.
    pub fn core_worker(&self) -> Arc<CoreWorker> {
        self.common.worker.clone()
    }
    // Split disjoint mutable borrows of the three halves for `run`.
    fn split_apart(&mut self) -> (&mut CommonWorker, &mut WorkflowHalf, &mut ActivityHalf) {
        (
            &mut self.common,
            &mut self.workflow_half,
            &mut self.activity_half,
        )
    }
}
impl WorkflowHalf {
    /// Handle one workflow activation.
    ///
    /// If the activation initializes a new workflow run, instantiates the
    /// registered workflow (failing the task if the type is unknown or
    /// construction fails) and returns the spawned future's handle so the
    /// caller can track it. In all cases, forwards the activation to the
    /// run's activation channel. A lone `RemoveFromCache` job for an unknown
    /// run is acknowledged with an empty completion.
    #[allow(clippy::type_complexity)]
    async fn workflow_activation_handler(
        &self,
        common: &CommonWorker,
        shutdown_token: CancellationToken,
        mut activation: WorkflowActivation,
        completions_tx: &UnboundedSender<WorkflowActivationCompletion>,
    ) -> Result<
        Option<
            WorkflowFutureHandle<
                impl Future<Output = Result<WorkflowResult<Payload>, JoinError>> + use<>,
            >,
        >,
        anyhow::Error,
    > {
        let mut res = None;
        let run_id = activation.run_id.clone();
        // An InitializeWorkflow job means this is the first activation for a
        // new run: create and spawn its workflow future.
        if let Some(sw) = activation.jobs.iter_mut().find_map(|j| match j.variant {
            Some(Variant::InitializeWorkflow(ref mut sw)) => Some(sw),
            _ => None,
        }) {
            let workflow_type = sw.workflow_type.clone();
            let payload_converter = common.data_converter.payload_converter().clone();
            let (wff, activations) = {
                if let Some(factory) = self.workflow_definitions.get_workflow(&workflow_type) {
                    match WorkflowFunction::from_invocation(factory).start_workflow(
                        common.worker.get_config().namespace.clone(),
                        common.task_queue.clone(),
                        run_id.clone(),
                        std::mem::take(sw),
                        completions_tx.clone(),
                        payload_converter,
                    ) {
                        Ok(result) => result,
                        Err(e) => {
                            // Construction failed: fail the workflow task and
                            // stop here.
                            warn!("Failed to create workflow {workflow_type}: {e}");
                            completions_tx
                                .send(WorkflowActivationCompletion::fail(
                                    run_id,
                                    format!("Failed to create workflow: {e}").into(),
                                    Some(WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure),
                                ))
                                .expect("Completion channel intact");
                            return Ok(None);
                        }
                    }
                } else {
                    // No registered workflow for this type: fail the task.
                    warn!("Workflow type {workflow_type} not found");
                    completions_tx
                        .send(WorkflowActivationCompletion::fail(
                            run_id,
                            format!("Workflow type {workflow_type} not found").into(),
                            Some(WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure),
                        ))
                        .expect("Completion channel intact");
                    return Ok(None);
                }
            };
            // Workflow futures don't need cooperative budgeting; they only
            // progress when fed activations.
            let wff = tokio::task::unconstrained(wff);
            let jh = tokio::task::spawn_local(async move {
                tokio::select! {
                    r = wff.fuse() => r,
                    _ = shutdown_token.cancelled() => {
                        Err(WorkflowTermination::Evicted)
                    }
                }
            });
            res = Some(WorkflowFutureHandle {
                join_handle: jh,
                run_id: run_id.clone(),
            });
            // If a previous future for this run id is still in the map (e.g.
            // eviction not yet fully processed), wait for its removal before
            // inserting the new entry.
            loop {
                if self.workflows.borrow_mut().contains_key(&run_id) {
                    self.workflow_removed_from_map.notified().await;
                } else {
                    break;
                }
            }
            self.workflows.borrow_mut().insert(
                run_id.clone(),
                WorkflowData {
                    activation_chan: activations,
                },
            );
        }
        // Deliver the activation to the run's future.
        if let Some(dat) = self.workflows.borrow_mut().get_mut(&run_id) {
            dat.activation_chan
                .send(activation)
                .expect("Workflow should exist if we're sending it an activation");
        } else {
            // A bare RemoveFromCache for an already-gone run is benign; just
            // acknowledge it.
            if activation.jobs.len() == 1
                && matches!(
                    activation.jobs.first().map(|j| &j.variant),
                    Some(Some(Variant::RemoveFromCache(_)))
                )
            {
                completions_tx
                    .send(WorkflowActivationCompletion::from_cmds(run_id, vec![]))
                    .expect("Completion channel intact");
                return Ok(None);
            }
            bail!("Got activation {activation:?} for unknown workflow {run_id}");
        };
        Ok(res)
    }
}
impl ActivityHalf {
    /// Handle one activity task from core.
    ///
    /// `Start` spawns a task that runs the registered activity function
    /// (catching panics), maps its outcome to an `ActivityExecutionResult`,
    /// encodes the completion, and reports it back to core. `Cancel` fires
    /// the cancellation token registered for the task's token.
    ///
    /// # Errors
    /// Fails if the activity type is unregistered or the task variant is
    /// missing.
    fn activity_task_handler(
        &mut self,
        worker: Arc<CoreWorker>,
        task_queue: String,
        data_converter: DataConverter,
        activity: ActivityTask,
    ) -> Result<(), anyhow::Error> {
        match activity.variant {
            Some(activity_task::Variant::Start(start)) => {
                let act_fn = self.activities.get(&start.activity_type).ok_or_else(|| {
                    anyhow!(
                        "No function registered for activity type {}",
                        start.activity_type
                    )
                })?;
                let span = info_span!(
                    "RunActivity",
                    "otel.name" = format!("RunActivity:{}", start.activity_type),
                    "otel.kind" = "server",
                    "temporalActivityID" = start.activity_id,
                    // Workflow id/run id are recorded later once the context
                    // is available.
                    "temporalWorkflowID" = field::Empty,
                    "temporalRunID" = field::Empty,
                );
                // Register a cancellation token keyed by task token so a
                // later Cancel variant can abort this activity.
                let ct = CancellationToken::new();
                let task_token = activity.task_token;
                self.task_tokens_to_cancels
                    .insert(task_token.clone().into(), ct.clone());
                let (ctx, args) =
                    ActivityContext::new(worker.clone(), ct, task_queue, task_token.clone(), start);
                let codec_data_converter = data_converter.clone();
                tokio::spawn(async move {
                    let act_fut = async move {
                        if let Some(info) = &ctx.info().workflow_execution {
                            Span::current()
                                .record("temporalWorkflowID", &info.workflow_id)
                                .record("temporalRunID", &info.run_id);
                        }
                        (act_fn)(args, data_converter, ctx).await
                    }
                    .instrument(span);
                    // Catch panics so a panicking activity fails gracefully
                    // rather than killing the worker task.
                    let output = AssertUnwindSafe(act_fut).catch_unwind().await;
                    let result = match output {
                        Err(e) => ActivityExecutionResult::fail(Failure::application_failure(
                            format!("Activity function panicked: {}", panic_formatter(e)),
                            true,
                        )),
                        Ok(Ok(p)) => ActivityExecutionResult::ok(p),
                        Ok(Err(err)) => match err {
                            ActivityError::Retryable {
                                source,
                                explicit_delay,
                            } => ActivityExecutionResult::fail({
                                let mut f = Failure::application_failure_from_error(
                                    anyhow::Error::from_boxed(source),
                                    false,
                                );
                                // Propagate an explicitly-requested retry
                                // delay into the failure info when present.
                                if let Some(d) = explicit_delay
                                    && let Some(failure::FailureInfo::ApplicationFailureInfo(fi)) =
                                        f.failure_info.as_mut()
                                {
                                    fi.next_retry_delay = d.try_into().ok();
                                }
                                f
                            }),
                            ActivityError::Cancelled { details } => {
                                ActivityExecutionResult::cancel_from_details(details)
                            }
                            ActivityError::NonRetryable(nre) => ActivityExecutionResult::fail(
                                Failure::application_failure_from_error(
                                    anyhow::Error::from_boxed(nre),
                                    true,
                                ),
                            ),
                            ActivityError::WillCompleteAsync => {
                                ActivityExecutionResult::will_complete_async()
                            }
                        },
                    };
                    let mut completion = ActivityTaskCompletion {
                        task_token,
                        result: Some(result),
                    };
                    encode_payloads(
                        &mut completion,
                        codec_data_converter.codec(),
                        &SerializationContextData::Activity,
                    )
                    .await;
                    worker.complete_activity_task(completion).await?;
                    Ok::<_, anyhow::Error>(())
                });
            }
            Some(activity_task::Variant::Cancel(_)) => {
                if let Some(ct) = self
                    .task_tokens_to_cancels
                    .get(activity.task_token.as_slice())
                {
                    ct.cancel();
                }
            }
            None => bail!("Undefined activity task variant"),
        }
        Ok(())
    }
}
/// Events that resolve ("unblock") a pending workflow command future.
///
/// The leading `u32` in each variant appears to be the command sequence
/// number (matches the `seq`/`seqnum` fields used elsewhere in this file);
/// the `Unblockable` impls below ignore it.
#[derive(Debug)]
enum UnblockEvent {
    Timer(u32, TimerResult),
    Activity(u32, Box<ActivityResolution>),
    WorkflowStart(u32, Box<ChildWorkflowStartStatus>),
    WorkflowComplete(u32, Box<ChildWorkflowResult>),
    SignalExternal(u32, Option<Failure>),
    CancelExternal(u32, Option<Failure>),
    NexusOperationStart(u32, Box<resolve_nexus_operation_start::Status>),
    NexusOperationComplete(u32, Box<NexusOperationResult>),
}
/// Outcome of a workflow timer.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TimerResult {
    /// The timer was cancelled before it fired.
    Cancelled,
    /// The timer fired.
    Fired,
}
/// Marker for a successfully delivered external-workflow signal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SignalExternalOk;
/// Result of signalling an external workflow.
pub type SignalExternalWfResult = Result<SignalExternalOk, Failure>;
/// Marker for a successfully delivered external-workflow cancellation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CancelExternalOk;
/// Result of requesting cancellation of an external workflow.
pub type CancelExternalWfResult = Result<CancelExternalOk, Failure>;
/// Types that can be produced from an [`UnblockEvent`], optionally combined
/// with extra per-command data (`OtherDat`).
trait Unblockable {
    type OtherDat;
    fn unblock(ue: UnblockEvent, od: Self::OtherDat) -> Self;
}
impl Unblockable for TimerResult {
    type OtherDat = ();

    /// Extract the timer outcome from a [`UnblockEvent::Timer`]; any other
    /// event variant is a programming error.
    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
        let UnblockEvent::Timer(_, outcome) = ue else {
            panic!("Invalid unblock event for timer")
        };
        outcome
    }
}
impl Unblockable for ActivityResolution {
    type OtherDat = ();

    /// Extract the activity resolution from a [`UnblockEvent::Activity`];
    /// any other event variant is a programming error.
    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
        let UnblockEvent::Activity(_, resolution) = ue else {
            panic!("Invalid unblock event for activity")
        };
        *resolution
    }
}
impl Unblockable for PendingChildWorkflow {
    type OtherDat = ChildWfCommon;

    /// Build a pending child workflow from its start status plus the common
    /// child-workflow data; any other event variant is a programming error.
    fn unblock(ue: UnblockEvent, od: Self::OtherDat) -> Self {
        let UnblockEvent::WorkflowStart(_, start_status) = ue else {
            panic!("Invalid unblock event for child workflow start")
        };
        Self {
            status: *start_status,
            common: od,
        }
    }
}
impl Unblockable for ChildWorkflowResult {
    type OtherDat = ();

    /// Extract the child workflow's final result; any other event variant is
    /// a programming error.
    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
        let UnblockEvent::WorkflowComplete(_, outcome) = ue else {
            panic!("Invalid unblock event for child workflow complete")
        };
        *outcome
    }
}
impl Unblockable for SignalExternalWfResult {
    type OtherDat = ();

    /// A `SignalExternal` event with no failure means the signal was
    /// delivered; a failure becomes the `Err` side.
    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
        match ue {
            UnblockEvent::SignalExternal(_, None) => Ok(SignalExternalOk),
            UnblockEvent::SignalExternal(_, Some(failure)) => Err(failure),
            _ => panic!("Invalid unblock event for signal external workflow result"),
        }
    }
}
impl Unblockable for CancelExternalWfResult {
    type OtherDat = ();

    /// A `CancelExternal` event with no failure means the cancellation
    /// request was delivered; a failure becomes the `Err` side.
    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
        match ue {
            UnblockEvent::CancelExternal(_, maybefail) => {
                maybefail.map_or(Ok(CancelExternalOk), Err)
            }
            // Fixed copy-paste bug: this panic message previously said
            // "signal external" even though this is the cancel impl.
            _ => panic!("Invalid unblock event for cancel external workflow result"),
        }
    }
}
/// Result of attempting to start a nexus operation.
type NexusStartResult = Result<StartedNexusOperation, Failure>;
impl Unblockable for NexusStartResult {
    type OtherDat = NexusUnblockData;
    /// Map a nexus start status to a started-operation handle (with or
    /// without an operation token) or a failure.
    fn unblock(ue: UnblockEvent, od: Self::OtherDat) -> Self {
        match ue {
            UnblockEvent::NexusOperationStart(_, result) => match *result {
                resolve_nexus_operation_start::Status::OperationToken(op_token) => {
                    Ok(StartedNexusOperation {
                        operation_token: Some(op_token),
                        unblock_dat: od,
                    })
                }
                // Completed synchronously: no token to track.
                resolve_nexus_operation_start::Status::StartedSync(_) => {
                    Ok(StartedNexusOperation {
                        operation_token: None,
                        unblock_dat: od,
                    })
                }
                resolve_nexus_operation_start::Status::Failed(f) => Err(f),
            },
            _ => panic!("Invalid unblock event for nexus operation"),
        }
    }
}
impl Unblockable for NexusOperationResult {
    type OtherDat = ();

    /// Extract the nexus operation's final result; any other event variant
    /// is a programming error.
    fn unblock(ue: UnblockEvent, _: Self::OtherDat) -> Self {
        let UnblockEvent::NexusOperationComplete(_, outcome) = ue else {
            panic!("Invalid unblock event for nexus operation complete")
        };
        *outcome
    }
}
/// Identifies a cancellable workflow command by its sequence number, with
/// extra data (and an optional human-readable reason) where the command type
/// supports it.
#[derive(Debug, Clone)]
pub(crate) enum CancellableID {
    Timer(u32),
    Activity(u32),
    LocalActivity(u32),
    ChildWorkflow {
        seqnum: u32,
        reason: String,
    },
    SignalExternalWorkflow(u32),
    ExternalWorkflow {
        seqnum: u32,
        execution: NamespacedWorkflowExecution,
        reason: String,
    },
    NexusOp(u32),
}
/// Conversion from a reason-less cancellable id into a [`CancellableID`]
/// carrying a cancellation reason.
pub(crate) trait SupportsCancelReason {
    fn with_reason(self, reason: String) -> CancellableID;
}
/// The subset of [`CancellableID`] variants whose cancellation accepts a
/// reason string, before the reason has been attached.
#[derive(Debug, Clone)]
pub(crate) enum CancellableIDWithReason {
    ChildWorkflow {
        seqnum: u32,
    },
    ExternalWorkflow {
        seqnum: u32,
        execution: NamespacedWorkflowExecution,
    },
}
impl CancellableIDWithReason {
pub(crate) fn seq_num(&self) -> u32 {
match self {
CancellableIDWithReason::ChildWorkflow { seqnum } => *seqnum,
CancellableIDWithReason::ExternalWorkflow { seqnum, .. } => *seqnum,
}
}
}
impl SupportsCancelReason for CancellableIDWithReason {
    /// Attach `reason` to produce the corresponding [`CancellableID`]
    /// variant.
    fn with_reason(self, reason: String) -> CancellableID {
        match self {
            Self::ChildWorkflow { seqnum } => CancellableID::ChildWorkflow { seqnum, reason },
            Self::ExternalWorkflow { seqnum, execution } => CancellableID::ExternalWorkflow {
                seqnum,
                execution,
                reason,
            },
        }
    }
}
impl From<CancellableIDWithReason> for CancellableID {
    /// Convert without a cancellation reason (an empty string).
    fn from(v: CancellableIDWithReason) -> Self {
        v.with_reason(String::new())
    }
}
/// Commands sent from workflow code to the workflow state machine driver.
#[derive(derive_more::From)]
#[allow(clippy::large_enum_variant)]
enum RustWfCmd {
    /// Cancel a previously issued cancellable command.
    #[from(ignore)]
    Cancel(CancellableID),
    /// Force the current workflow task to fail with the given error.
    ForceWFTFailure(anyhow::Error),
    /// Issue a new command whose resolution unblocks a waiting future.
    NewCmd(CommandCreateRequest),
    /// Issue a command that does not wait for a resolution.
    NewNonblockingCmd(workflow_command::Variant),
    /// Subscribe to the completion of a child workflow by sequence number.
    SubscribeChildWorkflowCompletion(CommandSubscribeChildWorkflowCompletion),
    /// Subscribe to the completion of a nexus operation by sequence number.
    SubscribeNexusOperationCompletion {
        seq: u32,
        unblocker: oneshot::Sender<UnblockEvent>,
    },
}
/// A new workflow command paired with the sender used to unblock the future
/// awaiting its resolution.
struct CommandCreateRequest {
    cmd: WorkflowCommand,
    unblocker: oneshot::Sender<UnblockEvent>,
}
/// Subscription to a child workflow's completion, keyed by sequence number.
struct CommandSubscribeChildWorkflowCompletion {
    seq: u32,
    unblocker: oneshot::Sender<UnblockEvent>,
}
/// The outcome of a workflow run: a success value or a termination.
pub type WorkflowResult<T> = Result<T, WorkflowTermination>;
/// Ways a workflow run can end other than returning a value.
#[derive(Debug, thiserror::Error)]
pub enum WorkflowTermination {
    /// The workflow was cancelled.
    #[error("Workflow cancelled")]
    Cancelled,
    /// The workflow was evicted from the worker's cache (not an error for
    /// the worker — see the joiner in `Worker::run`).
    #[error("Workflow evicted from cache")]
    Evicted,
    /// The workflow requested continue-as-new. Boxed to keep the enum small.
    #[error("Continue as new")]
    ContinueAsNew(Box<ContinueAsNewWorkflowExecution>),
    /// The workflow failed with an application error.
    #[error("Workflow failed: {0}")]
    Failed(#[source] anyhow::Error),
}
impl WorkflowTermination {
    /// Terminate by continuing as new with the provided command.
    pub fn continue_as_new(can: ContinueAsNewWorkflowExecution) -> Self {
        let boxed = Box::new(can);
        Self::ContinueAsNew(boxed)
    }

    /// Terminate as failed, converting any compatible error type.
    pub fn failed(err: impl Into<anyhow::Error>) -> Self {
        let error = err.into();
        Self::Failed(error)
    }
}
impl From<anyhow::Error> for WorkflowTermination {
    /// Any `anyhow::Error` terminates the workflow as failed.
    fn from(err: anyhow::Error) -> Self {
        Self::Failed(err)
    }
}
impl From<ActivityExecutionError> for WorkflowTermination {
    /// An activity execution error terminates the workflow as failed.
    fn from(value: ActivityExecutionError) -> Self {
        Self::failed(value)
    }
}
/// How an activity exits: with a normal value, or deferring completion to a
/// later asynchronous report.
#[derive(Debug)]
pub enum ActExitValue<T> {
    /// The activity will be completed asynchronously, outside this worker
    /// invocation.
    WillCompleteAsync,
    /// The activity completed with this value.
    Normal(T),
}
impl<T: AsJsonPayloadExt> From<T> for ActExitValue<T> {
    /// Plain values convert to a normal completion.
    fn from(t: T) -> Self {
        Self::Normal(t)
    }
}
/// Best-effort conversion of a panic payload into something printable.
///
/// Tries each type in the `PrintablePanicType` chain (`&str`, then `String`)
/// and falls back to a fixed message when none match.
fn panic_formatter(panic: Box<dyn Any>) -> Box<dyn Display> {
    _panic_formatter::<&str>(panic)
}
/// Recursive step: attempt to downcast the payload to `T`; on failure, move
/// on to `T::NextType` until the chain's terminal marker is reached.
fn _panic_formatter<T: 'static + PrintablePanicType>(panic: Box<dyn Any>) -> Box<dyn Display> {
    let payload = match panic.downcast::<T>() {
        Ok(displayable) => return displayable,
        Err(payload) => payload,
    };
    // Stop recursing once the next link is the terminal marker type.
    let chain_exhausted =
        TypeId::of::<<T as PrintablePanicType>::NextType>() == TypeId::of::<EndPrintingAttempts>();
    if chain_exhausted {
        Box::new("Couldn't turn panic into a string")
    } else {
        _panic_formatter::<T::NextType>(payload)
    }
}
/// A displayable type to try downcasting a panic payload to, linked to the
/// next candidate type in the chain.
trait PrintablePanicType: Display {
    type NextType: PrintablePanicType;
}
impl PrintablePanicType for &str {
    type NextType = String;
}
impl PrintablePanicType for String {
    type NextType = EndPrintingAttempts;
}
/// Terminal marker for the downcast chain; never actually displayed.
struct EndPrintingAttempts {}
impl Display for EndPrintingAttempts {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "Will never be printed")
    }
}
impl PrintablePanicType for EndPrintingAttempts {
    type NextType = EndPrintingAttempts;
}
#[cfg(test)]
mod tests {
    use super::*;
    use temporalio_macros::{activities, workflow, workflow_methods};
    // Minimal activity implementer used to exercise registration and typed
    // invocation through the macros.
    struct MyActivities {}
    #[activities]
    impl MyActivities {
        #[activity]
        async fn my_activity(_ctx: ActivityContext) -> Result<(), ActivityError> {
            Ok(())
        }
        // Variant taking `self: Arc<Self>` plus an input argument.
        #[activity]
        async fn takes_self(
            self: Arc<Self>,
            _ctx: ActivityContext,
            _: String,
        ) -> Result<(), ActivityError> {
            Ok(())
        }
    }
    #[test]
    fn test_activity_registration() {
        let act_instance = MyActivities {};
        let _ = WorkerOptions::new("task_q").register_activities(act_instance);
    }
    // Compile-only check: typed activity starts from workflow context.
    #[allow(unused, clippy::diverging_sub_expression)]
    fn test_activity_via_workflow_context() {
        let wf_ctx: WorkflowContext<MyWorkflow> = unimplemented!();
        wf_ctx.start_activity(MyActivities::my_activity, (), ActivityOptions::default());
        wf_ctx.start_activity(
            MyActivities::takes_self,
            "Hi".to_owned(),
            ActivityOptions::default(),
        );
    }
    // Compile-only check: direct invocation of an activity function.
    #[allow(dead_code, unreachable_code, unused, clippy::diverging_sub_expression)]
    async fn test_activity_direct_invocation() {
        let ctx: ActivityContext = unimplemented!();
        let _result = MyActivities::my_activity.run(ctx).await;
    }
    // Minimal workflow exercising init/run/signal/query/update macro forms.
    #[workflow]
    struct MyWorkflow {
        counter: u32,
    }
    #[allow(dead_code)]
    #[workflow_methods]
    impl MyWorkflow {
        #[init]
        fn new(_ctx: &WorkflowContextView, _input: String) -> Self {
            Self { counter: 0 }
        }
        #[run]
        async fn run(ctx: &mut WorkflowContext<Self>) -> WorkflowResult<String> {
            Ok(format!("Counter: {}", ctx.state(|s| s.counter)))
        }
        #[signal(name = "increment")]
        fn increment_counter(&mut self, _ctx: &mut SyncWorkflowContext<Self>, amount: u32) {
            self.counter += amount;
        }
        #[signal]
        async fn async_signal(_ctx: &mut WorkflowContext<Self>) {}
        #[query]
        fn get_counter(&self, _ctx: &WorkflowContextView) -> u32 {
            self.counter
        }
        #[update(name = "double")]
        fn double_counter(&mut self, _ctx: &mut SyncWorkflowContext<Self>) -> u32 {
            self.counter *= 2;
            self.counter
        }
        #[update]
        async fn async_update(_ctx: &mut WorkflowContext<Self>, val: i32) -> i32 {
            val * 2
        }
    }
    #[test]
    fn test_workflow_registration() {
        let _ = WorkerOptions::new("task_q").register_workflow::<MyWorkflow>();
    }
    // Mirrors the default identity synthesized in `to_core_options`.
    fn default_identity() -> String {
        format!(
            "{}@{}",
            std::process::id(),
            gethostname::gethostname().to_string_lossy()
        )
    }
    // Pins the identity-resolution precedence of `to_core_options`:
    // explicit override > connection identity > synthesized default.
    #[rstest::rstest]
    #[case::default_when_none_provided(None, "", Some(default_identity()))]
    #[case::connection_identity_preserved(None, "conn-identity", None)]
    #[case::worker_override_takes_precedence(
        Some("worker-identity"),
        "conn-identity",
        Some("worker-identity".into())
    )]
    #[case::worker_override_with_empty_connection(
        Some("worker-identity"),
        "",
        Some("worker-identity".into())
    )]
    #[test]
    fn client_identity_resolution(
        #[case] worker_override: Option<&str>,
        #[case] connection_identity: &str,
        #[case] expected: Option<String>,
    ) {
        let opts = WorkerOptions::new("task_q")
            .task_types(WorkerTaskTypes::activity_only())
            .maybe_client_identity_override(worker_override.map(|s| s.to_owned()))
            .build();
        let config = opts
            .to_core_options("ns".into(), connection_identity.into())
            .unwrap();
        assert_eq!(config.client_identity_override, expected);
    }
}