#![warn(missing_docs)]
#[macro_use]
extern crate tracing;
extern crate self as temporalio_sdk;
pub mod activities;
pub mod error;
pub mod interceptors;
mod workflow_executor;
mod workflow_future;
mod workflow_registry;
#[cfg(feature = "wasm-workflows")]
mod workflow_wasm;
pub mod workflows;
pub use crate::error::{
ActivityExecutionError, ApplicationFailure, ChildWorkflowExecutionError,
ChildWorkflowStartError, OutgoingActivityError, OutgoingError, OutgoingWorkflowError,
WorkflowRegistrationError, WorkflowSignalError,
};
pub use temporalio_client::Namespace;
pub use temporalio_workflow::{
ActivityCloseTimeouts, ActivityOptions, BaseWorkflowContext, CancellableFuture,
ChildWorkflowOptions, ContinueAsNewOptions, ContinueAsNewVersioningBehavior,
ExternalWorkflowHandle, LocalActivityOptions, NexusOperationOptions, ParentWorkflowInfo,
RootWorkflowInfo, Signal, SignalData, StartChildWorkflowExecutionFailedCause,
StartedChildWorkflow, SyncWorkflowContext, TimerOptions, TimerResult, WorkflowContext,
WorkflowContextView, WorkflowResult, WorkflowTermination,
};
#[cfg(feature = "wasm-workflows")]
pub use workflow_wasm::WasmWorkflowComponent;
use crate::{
activities::{
ActivityContext, ActivityDefinitions, ActivityImplementer, ExecutableActivity,
activity_error_to_core_result,
},
interceptors::{ActivityInboundInterceptor, WorkerInterceptor},
workflow_executor::{TaskHandle, WorkflowExecutor},
workflow_future::start_workflow,
workflow_registry::WorkflowDefinitions,
};
use anyhow::{Context, anyhow, bail};
use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use std::{
any::{Any, TypeId},
cell::RefCell,
collections::{HashMap, HashSet},
fmt::{Debug, Display, Formatter},
future::Future,
sync::Arc,
time::Duration,
};
use temporalio_client::{Client, ClientOptions, NamespacedClient};
use temporalio_common::{
ActivityDefinition, WorkflowDefinition,
data_converters::{DataConverter, SerializationContext, SerializationContextData},
payload_visitor::{decode_payloads, encode_payloads},
protos::{
TaskToken,
coresdk::{
ActivityTaskCompletion, AsJsonPayloadExt,
activity_result::ActivityExecutionResult,
activity_task::{ActivityTask, activity_task},
workflow_activation::{WorkflowActivation, workflow_activation_job::Variant},
workflow_completion::WorkflowActivationCompletion,
},
temporal::api::{common::v1::Payload, enums::v1::WorkflowTaskFailedCause},
},
worker::{WorkerDeploymentOptions, WorkerTaskTypes, build_id_from_current_exe},
};
use temporalio_sdk_core::{
CoreRuntime, PollError, PollerBehavior, TunerBuilder, Worker as CoreWorker, WorkerConfig,
WorkerTuner, WorkerVersioningStrategy, WorkflowErrorType, init_worker,
};
use temporalio_workflow::runtime::entry::WorkflowImplementation;
use tokio::sync::{
Notify,
mpsc::{UnboundedSender, unbounded_channel},
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, Span, field};
use uuid::Uuid;
/// Configuration for a [`Worker`], including the activities and workflows it will host.
///
/// Construct with `WorkerOptions::new(task_queue)` (the generated `bon` builder), register
/// activities and workflows on the builder, then call `.build()`. Every field with a
/// `#[builder(default = ..)]` attribute uses that value when not set explicitly.
#[derive(bon::Builder, Clone)]
#[builder(start_fn = new, on(String, into), state_mod(vis = "pub"))]
#[non_exhaustive]
pub struct WorkerOptions {
    /// Task queue this worker polls. Passed positionally to `WorkerOptions::new`.
    #[builder(start_fn)]
    pub task_queue: String,
    /// Registered activities; filled through the `register_activity` / `register_activities`
    /// methods instead of a generated setter.
    #[builder(field)]
    activities: ActivityDefinitions,
    /// Registered workflows; filled through `register_workflow` /
    /// `register_workflow_with_factory`.
    #[builder(field)]
    workflows: WorkflowDefinitions,
    /// WASM workflow components added via `register_wasm_workflow`.
    #[cfg(feature = "wasm-workflows")]
    #[builder(field)]
    wasm_workflow_components: Vec<WasmWorkflowComponent>,
    /// Worker deployment / versioning options. Defaults to a build ID derived from the
    /// currently running executable.
    #[builder(default = def_build_id())]
    pub deployment_options: WorkerDeploymentOptions,
    /// Explicit worker identity. When unset and the connection identity is empty,
    /// `to_core_options` generates a `<pid>@<hostname>` identity; otherwise no override is sent.
    pub client_identity_override: Option<String>,
    /// Maximum number of cached workflow runs. Default: `1000`.
    #[builder(default = 1000)]
    pub max_cached_workflows: usize,
    /// Slot tuner for task concurrency. Default: `TunerBuilder::default().build()`.
    #[builder(default = Arc::new(TunerBuilder::default().build()))]
    pub tuner: Arc<dyn WorkerTuner + Send + Sync>,
    /// Poller behavior for workflow tasks. Default: `PollerBehavior::SimpleMaximum(5)`.
    #[builder(default = PollerBehavior::SimpleMaximum(5))]
    pub workflow_task_poller_behavior: PollerBehavior,
    /// Ratio of non-sticky to sticky workflow task pollers. Default: `0.2`.
    #[builder(default = 0.2)]
    pub nonsticky_to_sticky_poll_ratio: f32,
    /// Poller behavior for activity tasks. Default: `PollerBehavior::SimpleMaximum(5)`.
    #[builder(default = PollerBehavior::SimpleMaximum(5))]
    pub activity_task_poller_behavior: PollerBehavior,
    /// Poller behavior for Nexus tasks. Default: `PollerBehavior::SimpleMaximum(5)`.
    #[builder(default = PollerBehavior::SimpleMaximum(5))]
    pub nexus_task_poller_behavior: PollerBehavior,
    /// Which kinds of tasks this worker polls for. Default: `WorkerTaskTypes::all()`.
    #[builder(default = WorkerTaskTypes::all())]
    pub task_types: WorkerTaskTypes,
    /// Schedule-to-start timeout for the sticky queue. Default: 10 seconds.
    #[builder(default = Duration::from_secs(10))]
    pub sticky_queue_schedule_to_start_timeout: Duration,
    /// Upper bound for activity heartbeat throttling. Default: 60 seconds.
    #[builder(default = Duration::from_secs(60))]
    pub max_heartbeat_throttle_interval: Duration,
    /// Default activity heartbeat throttle interval. Default: 30 seconds.
    #[builder(default = Duration::from_secs(30))]
    pub default_heartbeat_throttle_interval: Duration,
    /// Optional task-queue-wide activity rate limit, in activities per second.
    pub max_task_queue_activities_per_second: Option<f64>,
    /// Optional per-worker activity rate limit, in activities per second.
    pub max_worker_activities_per_second: Option<f64>,
    /// Workflow error types to be treated as workflow failures, for every workflow type.
    /// Default: empty.
    #[builder(default)]
    pub workflow_failure_errors: HashSet<WorkflowErrorType>,
    /// Same as `workflow_failure_errors`, but keyed by workflow type name. Default: empty.
    #[builder(default)]
    pub workflow_types_to_failure_errors: HashMap<String, HashSet<WorkflowErrorType>>,
    /// Optional grace period given to in-flight work during shutdown.
    pub graceful_shutdown_period: Option<Duration>,
    /// Whether the SDK workflow executor should try to detect nondeterministic futures.
    /// Applied by `Worker::new_from_core_options`. Default: `true`.
    #[builder(default = true)]
    pub detect_nondeterministic_futures: bool,
}
/// Registration helpers usable at any state of the `WorkerOptions` builder.
impl<S: worker_options_builder::State> WorkerOptionsBuilder<S> {
    /// Register every activity implemented by `instance`.
    pub fn register_activities<AI: ActivityImplementer>(mut self, instance: AI) -> Self {
        let definitions = &mut self.activities;
        definitions.register_activities::<AI>(instance);
        self
    }

    /// Register the single activity definition `AD`, backed by `instance`.
    pub fn register_activity<AD>(mut self, instance: Arc<AD::Implementer>) -> Self
    where
        AD: ActivityDefinition + ExecutableActivity,
        AD::Input: Send + Sync,
        AD::Output: Send + Sync,
    {
        let definitions = &mut self.activities;
        definitions.register_activity::<AD>(instance);
        self
    }

    /// Register workflow `W`.
    ///
    /// # Errors
    /// Returns a [`WorkflowRegistrationError`] if registration is rejected, e.g. when the
    /// workflow type has already been registered.
    pub fn register_workflow<W>(mut self) -> Result<Self, WorkflowRegistrationError>
    where
        W: WorkflowImplementation,
        <W::Run as WorkflowDefinition>::Input: Send,
    {
        let registry = &mut self.workflows;
        registry.register_workflow::<W>()?;
        Ok(self)
    }

    /// Register workflow `W`, with instances produced by calling `factory`.
    ///
    /// # Errors
    /// Returns a [`WorkflowRegistrationError`] if registration is rejected, e.g. when `W`
    /// declares an `#[init]` method.
    pub fn register_workflow_with_factory<W, F>(
        mut self,
        factory: F,
    ) -> Result<Self, WorkflowRegistrationError>
    where
        W: WorkflowImplementation,
        <W::Run as WorkflowDefinition>::Input: Send,
        F: Fn() -> W + Send + Sync + 'static,
    {
        let registry = &mut self.workflows;
        registry.register_workflow_run_with_factory::<W, F>(factory)?;
        Ok(self)
    }

    /// Queue a WASM workflow component; it is registered when the [`Worker`] is constructed.
    #[cfg(feature = "wasm-workflows")]
    pub fn register_wasm_workflow(mut self, component: WasmWorkflowComponent) -> Self {
        let components = &mut self.wasm_workflow_components;
        components.push(component);
        self
    }
}
/// Default deployment options: a build ID derived from the running executable.
fn def_build_id() -> WorkerDeploymentOptions {
    let build_id = build_id_from_current_exe().to_owned();
    WorkerDeploymentOptions::from_build_id(build_id)
}
impl WorkerOptions {
    /// Register every activity implemented by `instance`.
    pub fn register_activities<AI: ActivityImplementer>(&mut self, instance: AI) -> &mut Self {
        self.activities.register_activities::<AI>(instance);
        self
    }

    /// Register the single activity definition `AD`, backed by `instance`.
    pub fn register_activity<AD>(&mut self, instance: Arc<AD::Implementer>) -> &mut Self
    where
        AD: ActivityDefinition + ExecutableActivity,
        AD::Input: Send + Sync,
        AD::Output: Send + Sync,
    {
        self.activities.register_activity::<AD>(instance);
        self
    }

    /// Returns a clone of the registered activity definitions.
    pub fn activities(&self) -> ActivityDefinitions {
        self.activities.clone()
    }

    /// Register workflow `W`.
    ///
    /// # Errors
    /// Returns a [`WorkflowRegistrationError`] if registration is rejected, e.g. when the
    /// workflow type has already been registered.
    pub fn register_workflow<W>(&mut self) -> Result<&mut Self, WorkflowRegistrationError>
    where
        W: WorkflowImplementation,
        <W::Run as WorkflowDefinition>::Input: Send,
    {
        self.workflows.register_workflow::<W>()?;
        Ok(self)
    }

    /// Register workflow `W`, with instances produced by calling `factory`.
    ///
    /// # Errors
    /// Returns a [`WorkflowRegistrationError`] if registration is rejected, e.g. when `W`
    /// declares an `#[init]` method.
    pub fn register_workflow_with_factory<W, F>(
        &mut self,
        factory: F,
    ) -> Result<&mut Self, WorkflowRegistrationError>
    where
        W: WorkflowImplementation,
        <W::Run as WorkflowDefinition>::Input: Send,
        F: Fn() -> W + Send + Sync + 'static,
    {
        self.workflows
            .register_workflow_run_with_factory::<W, F>(factory)?;
        Ok(self)
    }

    /// Queue a WASM workflow component; it is registered when the [`Worker`] is constructed.
    #[cfg(feature = "wasm-workflows")]
    pub fn register_wasm_workflow(&mut self, component: WasmWorkflowComponent) -> &mut Self {
        self.wasm_workflow_components.push(component);
        self
    }

    /// Returns a clone of the registered workflow definitions.
    pub fn workflows(&self) -> WorkflowDefinitions {
        self.workflows.clone()
    }

    /// Translate these options into a core [`WorkerConfig`].
    ///
    /// `connection_identity` is the identity of the client connection. If no
    /// `client_identity_override` is set and `connection_identity` is empty, a
    /// `<pid>@<hostname>` identity is generated. Otherwise the explicit override, if any, is
    /// used as-is.
    ///
    /// # Errors
    /// Propagates the `String` error returned by the `WorkerConfig` builder's `build()`.
    #[doc(hidden)]
    pub fn to_core_options(
        &self,
        namespace: String,
        connection_identity: String,
    ) -> Result<WorkerConfig, String> {
        WorkerConfig::builder()
            .namespace(namespace)
            .task_queue(self.task_queue.clone())
            .maybe_client_identity_override(self.client_identity_override.clone().or_else(|| {
                connection_identity.is_empty().then(|| {
                    format!(
                        "{}@{}",
                        std::process::id(),
                        gethostname::gethostname().to_string_lossy()
                    )
                })
            }))
            .max_cached_workflows(self.max_cached_workflows)
            .tuner(self.tuner.clone())
            .workflow_task_poller_behavior(self.workflow_task_poller_behavior)
            // Previously not forwarded, so a user-configured ratio was silently ignored.
            .nonsticky_to_sticky_poll_ratio(self.nonsticky_to_sticky_poll_ratio)
            .activity_task_poller_behavior(self.activity_task_poller_behavior)
            .nexus_task_poller_behavior(self.nexus_task_poller_behavior)
            .task_types(self.task_types)
            .sticky_queue_schedule_to_start_timeout(self.sticky_queue_schedule_to_start_timeout)
            .max_heartbeat_throttle_interval(self.max_heartbeat_throttle_interval)
            .default_heartbeat_throttle_interval(self.default_heartbeat_throttle_interval)
            .maybe_max_task_queue_activities_per_second(self.max_task_queue_activities_per_second)
            .maybe_max_worker_activities_per_second(self.max_worker_activities_per_second)
            .maybe_graceful_shutdown_period(self.graceful_shutdown_period)
            .versioning_strategy(WorkerVersioningStrategy::WorkerDeploymentBased(
                self.deployment_options.clone(),
            ))
            .workflow_failure_errors(self.workflow_failure_errors.clone())
            .workflow_types_to_failure_errors(self.workflow_types_to_failure_errors.clone())
            .build()
    }
}
/// A worker that polls one task queue and runs the workflows and activities registered on it.
///
/// Created with [`Worker::new`] (or one of the hidden `new_from_core*` constructors) and driven
/// by [`Worker::run`].
pub struct Worker {
    // State used by both the workflow and the activity processing loops.
    common: CommonWorker,
    // Workflow-side state: live runs and registered workflow definitions.
    workflow_half: WorkflowHalf,
    // Activity-side state: registered activities and per-task cancellation tokens.
    activity_half: ActivityHalf,
}
/// Worker state shared by the workflow and activity processing loops.
struct CommonWorker {
    /// The core worker used to poll for and complete tasks.
    worker: Arc<CoreWorker>,
    /// Task queue, copied from the core worker's config at construction time.
    task_queue: String,
    /// Optional interceptor notified of workflow activations, completions, and shutdown.
    worker_interceptor: Option<Box<dyn WorkerInterceptor>>,
    /// Interceptors passed to every activity invocation.
    activity_inbound_interceptors: Vec<Arc<dyn ActivityInboundInterceptor>>,
    /// Client options handed to each activity's context.
    client_options: ClientOptions,
    /// Converter/codec used to decode incoming and encode outgoing payloads.
    data_converter: DataConverter,
}
/// Workflow-side state of a [`Worker`].
struct WorkflowHalf {
    /// Live workflow runs keyed by run ID. Uses `RefCell` interior mutability; accessed from
    /// the workflow futures that `Worker::run` drives inside a `LocalSet`.
    workflows: RefCell<HashMap<String, WorkflowData>>,
    /// Registered workflow definitions, looked up by workflow type.
    workflow_definitions: WorkflowDefinitions,
    /// Notified whenever a run is removed from `workflows`, so a new run reusing the same run
    /// ID can wait for the old entry to disappear.
    workflow_removed_from_map: Notify,
    /// Passed to `start_workflow` for every newly started run.
    detect_nondeterministic_futures: bool,
}
/// Per-run bookkeeping for a live workflow.
struct WorkflowData {
    /// Sender used to forward activations from core to the run's future.
    activation_chan: UnboundedSender<WorkflowActivation>,
}
/// A spawned workflow run's completion future, tagged with its run ID so that the run can be
/// removed from the cache once the future resolves.
struct WorkflowFutureHandle<F: Future> {
    join_handle: F,
    run_id: String,
}
/// Activity-side state of a [`Worker`].
#[derive(Default)]
struct ActivityHalf {
    /// Registered activity implementations, looked up by activity type.
    activities: ActivityDefinitions,
    /// Cancellation token for each started activity, keyed by task token, so that `Cancel`
    /// tasks can reach the running activity.
    /// NOTE(review): entries are inserted when an activity starts and are not removed when it
    /// completes, so this map grows with every started activity.
    task_tokens_to_cancels: HashMap<TaskToken, CancellationToken>,
}
/// Errors produced while dispatching a single activity task.
#[derive(Debug, thiserror::Error)]
enum ActivityTaskHandlerError {
    /// The task named an activity type that is not registered. `Worker::run` treats this as
    /// non-fatal: it fails the task with a `NotFoundError` application failure and keeps
    /// polling.
    #[error("{source}")]
    UnregisteredActivity {
        source: ActivityNotRegisteredError,
        /// Token of the task that must be failed.
        task_token: Vec<u8>,
    },
    /// Any other error; `Worker::run` returns it and stops.
    #[error(transparent)]
    Fatal(#[from] anyhow::Error),
}
/// An activity task referenced an activity type this worker does not know.
///
/// The variant is chosen by `ActivityNotRegisteredError::new`, based on whether any activities
/// are registered at all, so the message can list the alternatives.
#[derive(Debug, thiserror::Error)]
enum ActivityNotRegisteredError {
    /// Other activities are registered; they are listed in the message.
    #[error(
        "Activity {activity_type} is not registered on this worker, available activities: {}",
        .available_activities.join(", ")
    )]
    HasAvailable {
        activity_type: String,
        available_activities: Vec<&'static str>,
    },
    /// No activities are registered on this worker.
    #[error("Activity {activity_type} is not registered on this worker, no available activities.")]
    NoAvailable { activity_type: String },
}
impl ActivityNotRegisteredError {
    /// Build the error for `activity_type`.
    ///
    /// Produces `HasAvailable` when `available_activities` is non-empty, and `NoAvailable`
    /// otherwise.
    fn new(activity_type: String, available_activities: Vec<&'static str>) -> Self {
        if !available_activities.is_empty() {
            return Self::HasAvailable {
                activity_type,
                available_activities,
            };
        }
        Self::NoAvailable { activity_type }
    }
}
impl Worker {
    /// Create a worker on `runtime` that uses `client`'s connection, namespace, identity, and
    /// options.
    ///
    /// # Errors
    /// Fails if `options` cannot be turned into a core worker config, if the core worker cannot
    /// be initialized, or (with the `wasm-workflows` feature) if registering the queued WASM
    /// workflow components fails.
    pub fn new(
        runtime: &CoreRuntime,
        client: Client,
        options: WorkerOptions,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let wc = options
            .to_core_options(client.namespace(), client.identity())
            .map_err(|s| anyhow::anyhow!("{s}"))?;
        let core = init_worker(runtime, wc, client.connection().clone())?;
        Self::new_from_core_options(Arc::new(core), client.options().clone(), options)
    }

    /// Wrap an existing core worker with no registered activities or workflows.
    ///
    /// Client options are built from the core worker's namespace and `data_converter`.
    #[doc(hidden)]
    pub fn new_from_core(worker: Arc<CoreWorker>, data_converter: DataConverter) -> Self {
        let client_options = ClientOptions::new(worker.get_config().namespace.clone())
            .data_converter(data_converter)
            .build();
        Self::new_from_core_definitions(
            worker,
            client_options,
            Default::default(),
            Default::default(),
        )
    }

    /// Wrap an existing core worker, taking the registered activities, workflows (and WASM
    /// components) out of `options` and applying `options.detect_nondeterministic_futures`.
    ///
    /// # Errors
    /// With the `wasm-workflows` feature, fails if registering the WASM components fails.
    #[doc(hidden)]
    pub fn new_from_core_options(
        worker: Arc<CoreWorker>,
        client_options: ClientOptions,
        mut options: WorkerOptions,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let acts = std::mem::take(&mut options.activities);
        let wfs = std::mem::take(&mut options.workflows);
        #[cfg(feature = "wasm-workflows")]
        let wasm_components = std::mem::take(&mut options.wasm_workflow_components);
        let mut me = Self::new_from_core_definitions(worker, client_options, acts, wfs);
        me.set_detect_nondeterministic_futures(options.detect_nondeterministic_futures);
        #[cfg(feature = "wasm-workflows")]
        me.workflow_half
            .workflow_definitions
            .register_wasm_workflows(wasm_components)?;
        Ok(me)
    }

    /// Shared constructor. Interceptors start empty and nondeterministic-future detection starts
    /// disabled; the data converter is taken from `client_options`.
    fn new_from_core_definitions(
        worker: Arc<CoreWorker>,
        client_options: ClientOptions,
        activities: ActivityDefinitions,
        workflows: WorkflowDefinitions,
    ) -> Self {
        let data_converter = client_options.data_converter.clone();
        Self {
            common: CommonWorker {
                task_queue: worker.get_config().task_queue.clone(),
                worker,
                worker_interceptor: None,
                activity_inbound_interceptors: Vec::new(),
                client_options,
                data_converter,
            },
            workflow_half: WorkflowHalf {
                workflows: Default::default(),
                workflow_definitions: workflows,
                workflow_removed_from_map: Default::default(),
                detect_nondeterministic_futures: false,
            },
            activity_half: ActivityHalf {
                activities,
                ..Default::default()
            },
        }
    }

    /// The task queue this worker polls.
    pub fn task_queue(&self) -> &str {
        &self.common.task_queue
    }

    /// Enable or disable nondeterministic-future detection for runs started after this call.
    #[doc(hidden)]
    pub fn set_detect_nondeterministic_futures(&mut self, enabled: bool) {
        self.workflow_half.detect_nondeterministic_futures = enabled;
    }

    /// Returns a closure that initiates shutdown of the underlying core worker.
    ///
    /// The closure holds its own `Arc` to the core worker and does not borrow `self`.
    pub fn shutdown_handle(&self) -> impl Fn() + use<> {
        let w = self.common.worker.clone();
        move || w.initiate_shutdown()
    }

    /// Register every activity implemented by `instance`.
    pub fn register_activities<AI: ActivityImplementer>(&mut self, instance: AI) -> &mut Self {
        self.activity_half
            .activities
            .register_activities::<AI>(instance);
        self
    }

    /// Register the single activity definition `AD`, backed by `instance`.
    pub fn register_activity<AD>(&mut self, instance: Arc<AD::Implementer>) -> &mut Self
    where
        AD: ActivityDefinition + ExecutableActivity,
        AD::Input: Send + Sync,
        AD::Output: Send + Sync,
    {
        self.activity_half
            .activities
            .register_activity::<AD>(instance);
        self
    }

    /// Register workflow `W`.
    ///
    /// # Errors
    /// Returns a [`WorkflowRegistrationError`] if registration is rejected.
    pub fn register_workflow<W>(&mut self) -> Result<&mut Self, WorkflowRegistrationError>
    where
        W: WorkflowImplementation,
        <W::Run as WorkflowDefinition>::Input: Send,
    {
        self.workflow_half
            .workflow_definitions
            .register_workflow::<W>()?;
        Ok(self)
    }

    /// Register workflow `W`, with instances produced by calling `factory`.
    ///
    /// # Errors
    /// Returns a [`WorkflowRegistrationError`] if registration is rejected.
    pub fn register_workflow_with_factory<W, F>(
        &mut self,
        factory: F,
    ) -> Result<&mut Self, WorkflowRegistrationError>
    where
        W: WorkflowImplementation,
        <W::Run as WorkflowDefinition>::Input: Send,
        F: Fn() -> W + Send + Sync + 'static,
    {
        self.workflow_half
            .workflow_definitions
            .register_workflow_run_with_factory::<W, F>(factory)?;
        Ok(self)
    }

    /// Poll and process workflow and activity tasks until the core worker reports shutdown.
    ///
    /// Three pieces run concurrently:
    /// * workflow activation polling plus the joiner for spawned workflow futures, inside a
    ///   `LocalSet`;
    /// * activity task polling (only if any activities are registered);
    /// * the processor that encodes workflow activation completions and sends them to core.
    ///
    /// Once polling ends, the worker interceptor (if any) gets `on_shutdown`, and then the core
    /// worker is shut down.
    ///
    /// # Errors
    /// Returns the first error from any of the concurrent pieces. In that case the interceptor
    /// and the core shutdown step are skipped.
    pub async fn run(&mut self) -> Result<(), anyhow::Error> {
        // Cancelled after workflow polling stops; still-running workflow futures then resolve
        // as `WorkflowTermination::Evicted` (see `workflow_activation_handler`).
        let shutdown_token = CancellationToken::new();
        let (common, wf_half, act_half) = self.split_apart();
        let (wf_future_tx, wf_future_rx) =
            unbounded_channel::<WorkflowFutureHandle<TaskHandle<WorkflowResult<Payload>>>>();
        let (completions_tx, completions_rx) = unbounded_channel();
        let workflow_local_set = tokio::task::LocalSet::new();
        let executor = WorkflowExecutor::new();
        // Awaits every spawned workflow run. An `Evicted` termination counts as a normal
        // end; any other termination error aborts `run`. Finished runs are removed from the
        // cache and waiters on `workflow_removed_from_map` are notified.
        let wf_future_joiner = async {
            UnboundedReceiverStream::new(wf_future_rx)
                .map(Result::<_, anyhow::Error>::Ok)
                .try_for_each_concurrent(
                    None,
                    |WorkflowFutureHandle {
                         join_handle,
                         run_id,
                     }| {
                        let wf_half = &*wf_half;
                        async move {
                            let result = join_handle.await.map_err(anyhow::Error::new)?;
                            if let Err(e) = result
                                && !matches!(e, WorkflowTermination::Evicted)
                            {
                                return Err(anyhow::Error::new(e));
                            }
                            debug!(run_id=%run_id, "Removing workflow from cache");
                            wf_half.workflows.borrow_mut().remove(&run_id);
                            wf_half.workflow_removed_from_map.notify_one();
                            Ok(())
                        }
                    },
                )
                .await
                .context("Workflow futures encountered an error")
        };
        // Encodes each workflow activation completion, shows it to the interceptor, and
        // hands it to core.
        let wf_completion_processor = async {
            UnboundedReceiverStream::new(completions_rx)
                .map(Ok)
                .try_for_each_concurrent(None, |mut completion| async {
                    encode_payloads(
                        &mut completion,
                        common.data_converter.codec(),
                        &SerializationContextData::Workflow,
                    )
                    .await;
                    if let Some(ref i) = common.worker_interceptor {
                        i.on_workflow_activation_completion(&completion).await;
                    }
                    common.worker.complete_workflow_activation(completion).await
                })
                .map_err(anyhow::Error::from)
                .await
                .context("Workflow completions processor encountered an error")
        };
        tokio::try_join!(
            async {
                workflow_local_set.run_until(async {
                    tokio::try_join!(
                        async {
                            loop {
                                let mut activation =
                                    match common.worker.poll_workflow_activation().await {
                                        Err(PollError::ShutDown) => {
                                            break;
                                        }
                                        o => o?,
                                    };
                                decode_payloads(
                                    &mut activation,
                                    common.data_converter.codec(),
                                    &SerializationContextData::Workflow,
                                )
                                .await;
                                if let Some(ref i) = common.worker_interceptor {
                                    i.on_workflow_activation(&activation).await?;
                                }
                                // A newly started run yields a handle that the joiner
                                // above must await.
                                if let Some(wf_fut) = wf_half
                                    .workflow_activation_handler(
                                        common,
                                        shutdown_token.clone(),
                                        activation,
                                        &completions_tx,
                                        &executor,
                                    )
                                    .await?
                                    && wf_future_tx.send(wf_fut).is_err()
                                {
                                    panic!(
                                        "Receive half of completion processor channel cannot be dropped"
                                    );
                                }
                                // Drive workflow tasks made runnable by this activation.
                                executor.process_tasks();
                            }
                            // Polling has ended: make the remaining workflow futures
                            // resolve, and drop our senders so the joiner and completion
                            // streams can end once all other clones are gone as well.
                            shutdown_token.cancel();
                            drop(wf_future_tx);
                            drop(completions_tx);
                            executor.shutdown().await;
                            Result::<_, anyhow::Error>::Ok(())
                        },
                        wf_future_joiner,
                    )
                }).await
            },
            async {
                // Activity polling is skipped entirely when no activities are registered.
                if !act_half.activities.is_empty() {
                    loop {
                        let activity = common.worker.poll_activity_task().await;
                        if matches!(activity, Err(PollError::ShutDown)) {
                            break;
                        }
                        let mut activity = activity?;
                        decode_payloads(
                            &mut activity,
                            common.data_converter.codec(),
                            &SerializationContextData::Activity,
                        )
                        .await;
                        match act_half.activity_task_handler(
                            common.worker.clone(),
                            common.client_options.clone(),
                            common.task_queue.clone(),
                            common.data_converter.clone(),
                            common.activity_inbound_interceptors.clone(),
                            activity,
                        ) {
                            Ok(()) => {}
                            // Unknown activity type: fail just this task with a
                            // `NotFoundError` application failure and keep polling.
                            Err(ActivityTaskHandlerError::UnregisteredActivity {
                                source,
                                task_token,
                            }) => {
                                let failure = common.data_converter.to_failure(
                                    &SerializationContextData::Activity,
                                    OutgoingError::Activity(OutgoingActivityError::Application(
                                        ApplicationFailure::builder(source)
                                            .type_name("NotFoundError".to_owned())
                                            .build()
                                            .into(),
                                    )),
                                );
                                let mut completion = ActivityTaskCompletion {
                                    task_token,
                                    result: Some(ActivityExecutionResult::fail(failure)),
                                };
                                encode_payloads(
                                    &mut completion,
                                    common.data_converter.codec(),
                                    &SerializationContextData::Activity,
                                )
                                .await;
                                common.worker.complete_activity_task(completion).await?;
                            }
                            Err(ActivityTaskHandlerError::Fatal(err)) => return Err(err),
                        };
                    }
                };
                Result::<_, anyhow::Error>::Ok(())
            },
            wf_completion_processor,
        )?;
        if let Some(i) = self.common.worker_interceptor.as_ref() {
            i.on_shutdown(self);
        }
        self.common.worker.shutdown().await;
        Ok(())
    }

    /// Install the worker interceptor, replacing any existing one.
    pub fn set_worker_interceptor(&mut self, interceptor: impl WorkerInterceptor + 'static) {
        self.common.worker_interceptor = Some(Box::new(interceptor));
    }

    /// Append an interceptor that is passed to every activity invocation.
    pub fn add_activity_inbound_interceptor(
        &mut self,
        interceptor: impl ActivityInboundInterceptor,
    ) {
        self.common
            .activity_inbound_interceptors
            .push(Arc::new(interceptor));
    }

    /// Replace the underlying core worker.
    ///
    /// NOTE(review): the cached `task_queue` is not refreshed from the new worker's config.
    pub fn with_new_core_worker(&mut self, new_core_worker: Arc<CoreWorker>) {
        self.common.worker = new_core_worker;
    }

    /// Number of workflow runs currently held in this worker's run map.
    pub fn cached_workflows(&self) -> usize {
        self.workflow_half.workflows.borrow().len()
    }

    /// Instance key of the underlying core worker.
    pub fn worker_instance_key(&self) -> Uuid {
        self.common.worker.worker_instance_key()
    }

    /// Returns a clone of the `Arc` to the underlying core worker.
    #[doc(hidden)]
    pub fn core_worker(&self) -> Arc<CoreWorker> {
        self.common.worker.clone()
    }

    /// Borrow the three halves disjointly so that `run` can use them concurrently.
    fn split_apart(&mut self) -> (&mut CommonWorker, &mut WorkflowHalf, &mut ActivityHalf) {
        (
            &mut self.common,
            &mut self.workflow_half,
            &mut self.activity_half,
        )
    }
}
impl WorkflowHalf {
    /// Route one workflow activation from core to the workflow run it belongs to.
    ///
    /// If the activation contains an `InitializeWorkflow` job:
    /// * the workflow is created from its registered definition;
    /// * it is spawned on `executor`, raced against `shutdown_token` (cancellation resolves it
    ///   as [`WorkflowTermination::Evicted`]);
    /// * the handle returned in `Ok(Some(..))` is what the caller must join on.
    ///
    /// Before inserting the new run, this waits until any existing entry with the same run ID
    /// has been removed from `workflows`.
    ///
    /// If the workflow type is unknown, or the workflow cannot be created, a failed completion
    /// is sent on `completions_tx` and `Ok(None)` is returned.
    ///
    /// For a run that is not in the map, an activation whose only job is `RemoveFromCache` is
    /// answered with an empty completion.
    ///
    /// # Errors
    /// Returns an error for an activation addressed to an unknown run, unless it is a lone
    /// `RemoveFromCache`.
    ///
    /// # Panics
    /// Panics if `completions_tx` is closed, or if the run's activation channel is closed when
    /// forwarding the activation.
    #[allow(clippy::type_complexity)]
    async fn workflow_activation_handler(
        &self,
        common: &CommonWorker,
        shutdown_token: CancellationToken,
        mut activation: WorkflowActivation,
        completions_tx: &UnboundedSender<WorkflowActivationCompletion>,
        executor: &WorkflowExecutor,
    ) -> Result<Option<WorkflowFutureHandle<TaskHandle<WorkflowResult<Payload>>>>, anyhow::Error>
    {
        let mut res = None;
        let run_id = activation.run_id.clone();
        // Look for an InitializeWorkflow job, which means a new run must be started.
        if let Some(sw) = activation.jobs.iter_mut().find_map(|j| match j.variant {
            Some(Variant::InitializeWorkflow(ref mut sw)) => Some(sw),
            _ => None,
        }) {
            let workflow_type = sw.workflow_type.clone();
            let (wff, activations) = {
                if let Some(factory) = self.workflow_definitions.get_workflow(&workflow_type) {
                    // `std::mem::take` moves the init data out; the (now defaulted) job stays
                    // in `activation.jobs` and is still forwarded below.
                    match start_workflow(
                        factory,
                        common.worker.get_config().namespace.clone(),
                        common.task_queue.clone(),
                        run_id.clone(),
                        std::mem::take(sw),
                        completions_tx.clone(),
                        common.data_converter.clone(),
                        self.detect_nondeterministic_futures,
                    ) {
                        Ok(result) => result,
                        Err(e) => {
                            warn!("Failed to create workflow {workflow_type}: {e}");
                            completions_tx
                                .send(WorkflowActivationCompletion::fail(
                                    run_id,
                                    format!("Failed to create workflow: {e}").into(),
                                    Some(WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure),
                                ))
                                .expect("Completion channel intact");
                            return Ok(None);
                        }
                    }
                } else {
                    warn!("Workflow type {workflow_type} not found");
                    completions_tx
                        .send(WorkflowActivationCompletion::fail(
                            run_id,
                            format!("Workflow type {workflow_type} not found").into(),
                            Some(WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure),
                        ))
                        .expect("Completion channel intact");
                    return Ok(None);
                }
            };
            // Race the workflow against worker shutdown; shutdown counts as eviction.
            let jh = executor.spawn(async move {
                tokio::select! {
                    r = wff.fuse() => r,
                    _ = shutdown_token.cancelled() => {
                        Err(WorkflowTermination::Evicted)
                    }
                }
            });
            res = Some(WorkflowFutureHandle {
                join_handle: jh,
                run_id: run_id.clone(),
            });
            // A previous run with this run ID may still be in the map (e.g. pending removal
            // after eviction). Wait for the joiner in `Worker::run` to remove it. The `RefCell`
            // borrow ends with the `if` condition, before the await.
            loop {
                if self.workflows.borrow_mut().contains_key(&run_id) {
                    self.workflow_removed_from_map.notified().await;
                } else {
                    break;
                }
            }
            self.workflows.borrow_mut().insert(
                run_id.clone(),
                WorkflowData {
                    activation_chan: activations,
                },
            );
        }
        // Forward the activation to its run, whether the run is new or existing.
        if let Some(dat) = self.workflows.borrow_mut().get_mut(&run_id) {
            dat.activation_chan
                .send(activation)
                .expect("Workflow should exist if we're sending it an activation");
        } else {
            // An eviction for a run we don't have is acknowledged with an empty completion.
            if activation.jobs.len() == 1
                && matches!(
                    activation.jobs.first().map(|j| &j.variant),
                    Some(Some(Variant::RemoveFromCache(_)))
                )
            {
                completions_tx
                    .send(WorkflowActivationCompletion::from_cmds(run_id, vec![]))
                    .expect("Completion channel intact");
                return Ok(None);
            }
            bail!("Got activation {activation:?} for unknown workflow {run_id}");
        };
        Ok(res)
    }
}
impl ActivityHalf {
fn activity_task_handler(
&mut self,
worker: Arc<CoreWorker>,
client_options: ClientOptions,
task_queue: String,
data_converter: DataConverter,
activity_inbound_interceptors: Vec<Arc<dyn ActivityInboundInterceptor>>,
activity: ActivityTask,
) -> Result<(), ActivityTaskHandlerError> {
match activity.variant {
Some(activity_task::Variant::Start(start)) => {
let Some(act_fn) = self.activities.get(&start.activity_type) else {
let activity_type = start.activity_type.clone();
let source =
ActivityNotRegisteredError::new(activity_type, self.activities.names());
return Err(ActivityTaskHandlerError::UnregisteredActivity {
source,
task_token: activity.task_token,
});
};
let span = info_span!(
"RunActivity",
"otel.name" = format!("RunActivity:{}", start.activity_type),
"otel.kind" = "server",
"temporalActivityID" = start.activity_id,
"temporalWorkflowID" = field::Empty,
"temporalRunID" = field::Empty,
);
let ct = CancellationToken::new();
let task_token = activity.task_token;
self.task_tokens_to_cancels
.insert(task_token.clone().into(), ct.clone());
let (ctx, args) = ActivityContext::new(
worker.clone(),
client_options,
ct,
task_queue,
task_token.clone(),
start,
);
let codec_data_converter = data_converter.clone();
tokio::spawn(async move {
let act_fut = async move {
if let Some(info) = &ctx.info().workflow_execution {
Span::current()
.record("temporalWorkflowID", &info.workflow_id)
.record("temporalRunID", &info.run_id);
}
(act_fn)(args, data_converter, ctx, activity_inbound_interceptors).await
}
.instrument(span);
let result = act_fut.await;
let result = match result {
Ok(output) => {
let pc = codec_data_converter.payload_converter();
let ctx = SerializationContext {
data: &SerializationContextData::Activity,
converter: pc,
};
match output.serialize_payload(&ctx) {
Ok(payload) => ActivityExecutionResult::ok(payload),
Err(err) => {
activity_error_to_core_result(&codec_data_converter, err.into())
}
}
}
Err(err) => activity_error_to_core_result(&codec_data_converter, err),
};
let mut completion = ActivityTaskCompletion {
task_token,
result: Some(result),
};
encode_payloads(
&mut completion,
codec_data_converter.codec(),
&SerializationContextData::Activity,
)
.await;
worker.complete_activity_task(completion).await?;
Ok::<_, anyhow::Error>(())
});
}
Some(activity_task::Variant::Cancel(_)) => {
if let Some(ct) = self
.task_tokens_to_cancels
.get(activity.task_token.as_slice())
{
ct.cancel();
}
}
None => {
return Err(anyhow!("Undefined activity task variant").into());
}
}
Ok(())
}
}
/// Exit value of an activity function: either a normal result, or a marker indicating that the
/// activity will complete asynchronously instead of through its return value.
#[derive(Debug)]
pub enum ActExitValue<T> {
    /// The activity will be completed asynchronously, presumably out of band later.
    /// TODO(review): confirm the completion mechanism.
    WillCompleteAsync,
    /// The activity finished normally with this value.
    Normal(T),
}
impl<T: AsJsonPayloadExt> From<T> for ActExitValue<T> {
fn from(t: T) -> Self {
Self::Normal(t)
}
}
/// Turn a caught panic payload into something displayable.
///
/// The payload is tried as a `&str` first, then as a `String`. Any other payload type yields a
/// fixed placeholder message.
fn panic_formatter(panic: Box<dyn Any>) -> Box<dyn Display> {
    _panic_formatter::<&str>(panic)
}

/// Try to downcast `panic` to `T`. On failure, continue with `T::NextType` until the
/// [`EndPrintingAttempts`] sentinel is reached.
fn _panic_formatter<T: 'static + PrintablePanicType>(panic: Box<dyn Any>) -> Box<dyn Display> {
    let unmatched = match panic.downcast::<T>() {
        Ok(printable) => return printable,
        Err(unmatched) => unmatched,
    };
    let chain_exhausted = TypeId::of::<<T as PrintablePanicType>::NextType>()
        == TypeId::of::<EndPrintingAttempts>();
    if chain_exhausted {
        return Box::new("Couldn't turn panic into a string");
    }
    _panic_formatter::<T::NextType>(unmatched)
}

/// A displayable panic payload type, plus the next type to try when downcasting to this one
/// fails. The chain ends at [`EndPrintingAttempts`].
trait PrintablePanicType: Display {
    type NextType: PrintablePanicType;
}

impl PrintablePanicType for &str {
    type NextType = String;
}

impl PrintablePanicType for String {
    type NextType = EndPrintingAttempts;
}

/// Sentinel type that terminates the [`PrintablePanicType`] chain; it is never displayed.
struct EndPrintingAttempts {}

impl PrintablePanicType for EndPrintingAttempts {
    type NextType = EndPrintingAttempts;
}

impl Display for EndPrintingAttempts {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("Will never be printed")
    }
}
// Unit tests for worker option building, activity/workflow registration, and client identity
// resolution. Some functions below are never executed and exist only so the macros' generated
// APIs are type-checked.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activities::ActivityError;
    use temporalio_macros::{activities, activity_definitions, workflow, workflow_methods};

    struct MyActivities {}

    // Holder for definition-only activities (no implementation on this type).
    struct SharedActivities;

    #[activity_definitions]
    impl SharedActivities {
        #[activity(name = "shared-greet")]
        fn greet(name: String) -> Result<String, ActivityError> {
            unimplemented!()
        }
    }

    #[activities]
    impl MyActivities {
        #[activity]
        async fn my_activity(_ctx: ActivityContext) -> Result<(), ActivityError> {
            Ok(())
        }

        // Implements the shared `shared_activities::Greet` definition declared above.
        #[activity(definition = shared_activities::Greet)]
        async fn greet(_ctx: ActivityContext, name: String) -> Result<String, ActivityError> {
            Ok(name)
        }

        // Exercises an activity that receives `self: Arc<Self>`.
        #[activity]
        async fn takes_self(
            self: Arc<Self>,
            _ctx: ActivityContext,
            _: String,
        ) -> Result<(), ActivityError> {
            Ok(())
        }
    }

    #[test]
    fn test_activity_registration() {
        let act_instance = MyActivities {};
        let _ = WorkerOptions::new("task_q").register_activities(act_instance);
    }

    // Never run (not a #[test]); it only makes the compiler type-check these
    // `start_activity` call shapes.
    #[allow(unused, clippy::diverging_sub_expression)]
    fn test_activity_via_workflow_context() {
        let wf_ctx: WorkflowContext<MyWorkflow> = unimplemented!();
        wf_ctx.start_activity(
            MyActivities::my_activity,
            (),
            ActivityOptions::start_to_close_timeout(Duration::from_secs(5)),
        );
        wf_ctx.start_activity(
            SharedActivities::greet,
            "Hi".to_owned(),
            ActivityOptions::start_to_close_timeout(Duration::from_secs(5)),
        );
        wf_ctx.start_activity(
            MyActivities::greet,
            "Hi".to_owned(),
            ActivityOptions::start_to_close_timeout(Duration::from_secs(5)),
        );
        wf_ctx.start_activity(
            MyActivities::takes_self,
            "Hi".to_owned(),
            ActivityOptions::start_to_close_timeout(Duration::from_secs(5)),
        );
    }

    // Never run; it checks that a generated activity can be invoked directly via `.run`.
    #[allow(dead_code, unreachable_code, unused, clippy::diverging_sub_expression)]
    async fn test_activity_direct_invocation() {
        let ctx: ActivityContext = unimplemented!();
        let _result = MyActivities::my_activity.run(ctx).await;
    }

    #[workflow]
    struct MyWorkflow {
        counter: u32,
    }

    // A workflow exercising each handler kind: init, run, sync/async signal, query,
    // sync/async update.
    #[allow(dead_code)]
    #[workflow_methods]
    impl MyWorkflow {
        #[init]
        fn new(_ctx: &WorkflowContextView, _input: String) -> Self {
            Self { counter: 0 }
        }

        #[run]
        async fn run(ctx: &mut WorkflowContext<Self>) -> WorkflowResult<String> {
            Ok(format!("Counter: {}", ctx.state(|s| s.counter)))
        }

        #[signal(name = "increment")]
        fn increment_counter(&mut self, _ctx: &mut SyncWorkflowContext<Self>, amount: u32) {
            self.counter += amount;
        }

        #[signal]
        async fn async_signal(_ctx: &mut WorkflowContext<Self>) {}

        #[query]
        fn get_counter(&self, _ctx: &WorkflowContextView) -> u32 {
            self.counter
        }

        #[update(name = "double")]
        fn double_counter(&mut self, _ctx: &mut SyncWorkflowContext<Self>) -> u32 {
            self.counter *= 2;
            self.counter
        }

        #[update]
        async fn async_update(_ctx: &mut WorkflowContext<Self>, val: i32) -> i32 {
            val * 2
        }
    }

    #[test]
    fn test_workflow_registration() {
        let _ = WorkerOptions::new("task_q")
            .register_workflow::<MyWorkflow>()
            .unwrap();
    }

    // Registering the same workflow type twice must be rejected.
    #[test]
    fn duplicate_workflow_registration_errors() {
        let result = WorkerOptions::new("task_q")
            .register_workflow::<MyWorkflow>()
            .unwrap()
            .register_workflow::<MyWorkflow>();
        let err = match result {
            Ok(_) => panic!("duplicate workflow registration should error"),
            Err(err) => err,
        };
        assert_eq!(
            err,
            WorkflowRegistrationError::DuplicateWorkflowType {
                workflow_type: "MyWorkflow".to_string()
            }
        );
    }

    // A workflow that declares `#[init]` cannot also be registered with a factory.
    #[test]
    fn factory_registration_with_init_errors() {
        let result = WorkerOptions::new("task_q")
            .register_workflow_with_factory(|| MyWorkflow { counter: 0 });
        let err = match result {
            Ok(_) => panic!("factory registration with #[init] should error"),
            Err(err) => err,
        };
        assert_eq!(
            err,
            WorkflowRegistrationError::FactoryRegistrationWithInit {
                workflow_type: "MyWorkflow".to_string()
            }
        );
    }

    // Mirrors the fallback identity format built in `WorkerOptions::to_core_options`.
    fn default_identity() -> String {
        format!(
            "{}@{}",
            std::process::id(),
            gethostname::gethostname().to_string_lossy()
        )
    }

    // An explicit worker override always wins. Without one, a generated identity is used only
    // when the connection identity is empty.
    #[rstest::rstest]
    #[case::default_when_none_provided(None, "", Some(default_identity()))]
    #[case::connection_identity_preserved(None, "conn-identity", None)]
    #[case::worker_override_takes_precedence(
        Some("worker-identity"),
        "conn-identity",
        Some("worker-identity".into())
    )]
    #[case::worker_override_with_empty_connection(
        Some("worker-identity"),
        "",
        Some("worker-identity".into())
    )]
    #[test]
    fn client_identity_resolution(
        #[case] worker_override: Option<&str>,
        #[case] connection_identity: &str,
        #[case] expected: Option<String>,
    ) {
        let opts = WorkerOptions::new("task_q")
            .task_types(WorkerTaskTypes::activity_only())
            .maybe_client_identity_override(worker_override.map(|s| s.to_owned()))
            .build();
        let config = opts
            .to_core_options("ns".into(), connection_identity.into())
            .unwrap();
        assert_eq!(config.client_identity_override, expected);
    }
}