#[doc(inline)]
pub use crate::__temporal_select as select;
#[doc(inline)]
pub use crate::__temporal_join as join;
pub use futures_util::future::join_all;
use crate::{
BaseWorkflowContext, SyncWorkflowContext, WorkflowContext, WorkflowContextView,
WorkflowTermination,
};
use futures_util::future::{Fuse, FutureExt, LocalBoxFuture};
use std::{
cell::RefCell,
collections::HashMap,
fmt::Debug,
pin::Pin,
rc::Rc,
sync::Arc,
task::{Context as TaskContext, Poll},
};
use temporalio_common::{
QueryDefinition, SignalDefinition, UpdateDefinition, WorkflowDefinition,
data_converters::{
GenericPayloadConverter, PayloadConversionError, PayloadConverter, SerializationContext,
SerializationContextData, TemporalDeserializable, TemporalSerializable,
},
protos::temporal::api::{
common::v1::{Payload, Payloads},
failure::v1::Failure,
},
};
/// Error returned when dispatching a workflow handler (update, update validator, signal, or
/// query) fails, either while converting payloads or inside the handler itself.
#[derive(Debug, thiserror::Error)]
pub enum WorkflowError {
    /// A payload could not be converted to or from the handler's input/output type.
    #[error("Payload conversion error: {0}")]
    PayloadConversion(#[from] PayloadConversionError),
    /// Any other failure; handler errors are routed here through `wrap_handler_error`.
    #[error("Workflow execution error: {0}")]
    Execution(#[from] anyhow::Error),
}
impl From<WorkflowError> for Failure {
fn from(err: WorkflowError) -> Self {
Failure {
message: err.to_string(),
..Default::default()
}
}
}
/// Bridge between a user-defined workflow type and the workflow runtime.
///
/// NOTE(review): presumably implemented by the workflow macros rather than by hand — confirm.
/// All handler-dispatch methods below default to returning `None`.
#[doc(hidden)]
pub trait WorkflowImplementation: Sized + 'static {
    /// Definition that supplies the workflow's `Input` type.
    type Run: WorkflowDefinition;
    /// Whether the workflow defines an `#[init]` method.
    const HAS_INIT: bool;
    /// If `true`, the decoded workflow input is handed to `init` (and `run` gets `None`);
    /// if `false`, it is handed to `run` (and `init` gets `None`).
    const INIT_TAKES_INPUT: bool;
    /// Workflow type name; used as the key under which the workflow is registered.
    fn name() -> &'static str;
    /// Creates the workflow instance from a context view and the optional workflow input.
    fn init(
        ctx: WorkflowContextView,
        input: Option<<Self::Run as WorkflowDefinition>::Input>,
    ) -> Self;
    /// Builds the workflow's main body. The future resolves to the encoded result payload, or
    /// to a `WorkflowTermination` error.
    fn run(
        ctx: WorkflowContext<Self>,
        input: Option<<Self::Run as WorkflowDefinition>::Input>,
    ) -> LocalBoxFuture<'static, Result<Payload, WorkflowTermination>>;
    /// Starts the update handler named `name`, returning the future that completes it.
    /// The default returns `None` (presumably "no such handler" — confirm against callers).
    fn dispatch_update(
        _ctx: WorkflowContext<Self>,
        _name: &str,
        _payloads: Payloads,
        _converter: &PayloadConverter,
    ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>> {
        None
    }
    /// Runs the validator of the update named `name` against the current workflow state.
    /// The default returns `None`.
    fn validate_update(
        &self,
        _ctx: WorkflowContextView,
        _name: &str,
        _payloads: &Payloads,
        _converter: &PayloadConverter,
    ) -> Option<Result<(), WorkflowError>> {
        None
    }
    /// Starts the signal handler named `name`, returning the future that completes it.
    /// The default returns `None`.
    fn dispatch_signal(
        _ctx: WorkflowContext<Self>,
        _name: &str,
        _payloads: Payloads,
        _converter: &PayloadConverter,
    ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>> {
        None
    }
    /// Answers the query named `name` from a shared borrow of the workflow state.
    /// The default returns `None`.
    fn dispatch_query(
        &self,
        _ctx: WorkflowContextView,
        _name: &str,
        _payloads: &Payloads,
        _converter: &PayloadConverter,
    ) -> Option<Result<Payload, WorkflowError>> {
        None
    }
}
/// Synchronous signal handler for signal definition `S`.
///
/// The handler runs with mutable access to the workflow state.
#[doc(hidden)]
pub trait ExecutableSyncSignal<S: SignalDefinition>: WorkflowImplementation {
    /// Applies the signal to the workflow state.
    fn handle(&mut self, ctx: &mut SyncWorkflowContext<Self>, input: S::Input);

    /// Decodes `payloads` as `S::Input` and, on success, runs
    /// [`ExecutableSyncSignal::handle`] before returning.
    ///
    /// The returned future is already resolved. It yields `Ok(())` once the handler has run,
    /// or the decoding error, in which case the handler is never invoked.
    fn dispatch(
        ctx: WorkflowContext<Self>,
        payloads: Payloads,
        converter: &PayloadConverter,
    ) -> LocalBoxFuture<'static, Result<(), WorkflowError>> {
        let outcome = deserialize_input::<S::Input>(payloads.payloads, converter).map(|decoded| {
            let mut sync_ctx = ctx.sync_context();
            ctx.state_mut(|wf| Self::handle(wf, &mut sync_ctx, decoded));
        });
        std::future::ready(outcome).boxed_local()
    }
}
/// Asynchronous signal handler for signal definition `S`.
#[doc(hidden)]
pub trait ExecutableAsyncSignal<S: SignalDefinition>: WorkflowImplementation {
    /// Builds the future that processes the signal.
    fn handle(ctx: WorkflowContext<Self>, input: S::Input) -> LocalBoxFuture<'static, ()>;

    /// Decodes `payloads` as `S::Input` and returns the handler's future mapped to `Ok(())`.
    ///
    /// If decoding fails, an already-resolved future carrying the error is returned and the
    /// handler is never built.
    fn dispatch(
        ctx: WorkflowContext<Self>,
        payloads: Payloads,
        converter: &PayloadConverter,
    ) -> LocalBoxFuture<'static, Result<(), WorkflowError>> {
        let input = match deserialize_input::<S::Input>(payloads.payloads, converter) {
            Ok(input) => input,
            Err(err) => return std::future::ready(Err(err)).boxed_local(),
        };
        Self::handle(ctx, input).map(Ok).boxed_local()
    }
}
/// Query handler for query definition `Q`, evaluated against a shared borrow of the workflow.
#[doc(hidden)]
pub trait ExecutableQuery<Q: QueryDefinition>: WorkflowImplementation {
    /// Computes the query answer from the current workflow state.
    fn handle(
        &self,
        ctx: &WorkflowContextView,
        input: Q::Input,
    ) -> Result<Q::Output, Box<dyn std::error::Error + Send + Sync>>;

    /// Decodes `payloads`, runs [`ExecutableQuery::handle`], and encodes its output.
    ///
    /// # Errors
    /// Returns a conversion error if decoding or encoding fails. Returns
    /// [`WorkflowError::Execution`] if the handler fails.
    fn dispatch(
        &self,
        ctx: &WorkflowContextView,
        payloads: &Payloads,
        converter: &PayloadConverter,
    ) -> Result<Payload, WorkflowError> {
        // `deserialize_input` consumes a `Vec`, so the borrowed payloads are cloned.
        deserialize_input::<Q::Input>(payloads.payloads.clone(), converter)
            .and_then(|input| self.handle(ctx, input).map_err(wrap_handler_error))
            .and_then(|output| serialize_output(&output, converter))
    }
}
/// Synchronous update handler for update definition `U`.
///
/// The handler runs to completion with mutable access to the workflow state while
/// [`ExecutableSyncUpdate::dispatch`] executes. The future it returns is therefore already
/// resolved.
#[doc(hidden)]
pub trait ExecutableSyncUpdate<U: UpdateDefinition>: WorkflowImplementation {
    /// Applies the update to the workflow state and produces its output.
    ///
    /// # Errors
    /// Any error returned here is surfaced to the caller as [`WorkflowError::Execution`].
    fn handle(
        &mut self,
        ctx: &mut SyncWorkflowContext<Self>,
        input: U::Input,
    ) -> Result<U::Output, Box<dyn std::error::Error + Send + Sync>>;

    /// Checks whether the update should be accepted. The default accepts every input.
    fn validate(
        &self,
        _ctx: &WorkflowContextView,
        _input: &U::Input,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }

    /// Decodes `payloads`, runs [`ExecutableSyncUpdate::handle`] immediately, and encodes the
    /// output, all before returning.
    ///
    /// Because nothing outlives this call, the borrowed `converter` is used directly; no clone
    /// is needed. The handler is not invoked when decoding fails.
    ///
    /// # Errors
    /// The future resolves to a conversion error if the input cannot be decoded or the output
    /// cannot be encoded. It resolves to [`WorkflowError::Execution`] if the handler fails.
    fn dispatch(
        ctx: WorkflowContext<Self>,
        payloads: Payloads,
        converter: &PayloadConverter,
    ) -> LocalBoxFuture<'static, Result<Payload, WorkflowError>> {
        let result = deserialize_input::<U::Input>(payloads.payloads, converter).and_then(|input| {
            let mut sync_ctx = ctx.sync_context();
            let output = ctx
                .state_mut(|wf| Self::handle(wf, &mut sync_ctx, input))
                .map_err(wrap_handler_error)?;
            serialize_output(&output, converter)
        });
        std::future::ready(result).boxed_local()
    }

    /// Decodes `payloads` and runs [`ExecutableSyncUpdate::validate`] on the decoded input.
    ///
    /// # Errors
    /// Returns a conversion error if decoding fails. Returns [`WorkflowError::Execution`] if
    /// the validator rejects the input.
    fn dispatch_validate(
        &self,
        ctx: &WorkflowContextView,
        payloads: &Payloads,
        converter: &PayloadConverter,
    ) -> Result<(), WorkflowError> {
        let input = deserialize_input::<U::Input>(payloads.payloads.clone(), converter)?;
        self.validate(ctx, &input).map_err(wrap_handler_error)
    }
}
#[doc(hidden)]
pub trait ExecutableAsyncUpdate<U: UpdateDefinition>: WorkflowImplementation {
fn handle(
ctx: WorkflowContext<Self>,
input: U::Input,
) -> LocalBoxFuture<'static, Result<U::Output, Box<dyn std::error::Error + Send + Sync>>>;
fn validate(
&self,
_ctx: &WorkflowContextView,
_input: &U::Input,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
fn dispatch(
ctx: WorkflowContext<Self>,
payloads: Payloads,
converter: &PayloadConverter,
) -> LocalBoxFuture<'static, Result<Payload, WorkflowError>> {
let input = match deserialize_input::<U::Input>(payloads.payloads, converter) {
Ok(v) => v,
Err(e) => return std::future::ready(Err(e)).boxed_local(),
};
let converter = converter.clone();
async move {
let output = Self::handle(ctx, input).await.map_err(wrap_handler_error)?;
serialize_output(&output, &converter)
}
.boxed_local()
}
fn dispatch_validate(
&self,
ctx: &WorkflowContextView,
payloads: &Payloads,
converter: &PayloadConverter,
) -> Result<(), WorkflowError> {
let input = deserialize_input::<U::Input>(payloads.payloads.clone(), converter)?;
self.validate(ctx, &input).map_err(wrap_handler_error)
}
}
/// Inputs for one handler invocation.
///
/// Bundles the raw arguments, the headers that came with the request, and the converter used
/// to decode and encode payloads.
pub(crate) struct DispatchData<'a> {
    /// Raw handler arguments; decoded by the target handler.
    pub(crate) payloads: Payloads,
    /// Request headers. `WorkflowExecution` attaches them to the handler context for signals
    /// and updates; update validation and queries do not use them.
    pub(crate) headers: HashMap<String, Payload>,
    /// Converter borrowed for the duration of the dispatch call.
    pub(crate) converter: &'a PayloadConverter,
}
/// Implemented by workflow types that can register all of their entry points at once.
#[doc(hidden)]
pub trait WorkflowImplementer: WorkflowImplementation {
    /// Registers this workflow into `defs`; called by `WorkflowDefinitions::register_workflow`.
    fn register_all(defs: &mut WorkflowDefinitions);
}
/// Type-erased interface to a running workflow instance.
///
/// Executions of different workflow types can be held as `Box<dyn DynWorkflowExecution>`,
/// which is what `WorkflowExecutionFactory` produces.
///
/// In `WorkflowExecution`, the `Option`-returning methods forward to the matching
/// `WorkflowImplementation` methods. NOTE(review): `None` presumably means no handler is
/// registered under `name` — confirm.
pub(crate) trait DynWorkflowExecution {
    /// Polls the workflow's main `run` future.
    fn poll_run(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<Payload, WorkflowTermination>>;
    /// Runs the validator of update `name` against the current workflow state.
    fn validate_update(&self, name: &str, data: &DispatchData)
    -> Option<Result<(), WorkflowError>>;
    /// Starts update `name`, returning the future that completes it.
    fn start_update(
        &mut self,
        name: &str,
        data: DispatchData,
    ) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>>;
    /// Delivers signal `name`, returning the future that completes its handling.
    fn dispatch_signal(
        &mut self,
        name: &str,
        data: DispatchData,
    ) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>>;
    /// Answers query `name` from the current workflow state.
    fn dispatch_query(
        &self,
        name: &str,
        data: DispatchData,
    ) -> Option<Result<Payload, WorkflowError>>;
}
/// Concrete [`DynWorkflowExecution`] for workflow type `W`.
///
/// Holds the workflow context and the future produced by `W::run`.
pub(crate) struct WorkflowExecution<W: WorkflowImplementation> {
    /// Context that shares the workflow instance (an `Rc<RefCell<W>>`) with handlers.
    ctx: WorkflowContext<W>,
    /// The `W::run` future, wrapped in `Fuse` so that polling it again after it has completed
    /// is harmless.
    run_future: Fuse<LocalBoxFuture<'static, Result<Payload, WorkflowTermination>>>,
}
impl<W: WorkflowImplementation> WorkflowExecution<W>
where
    <W::Run as WorkflowDefinition>::Input: Send,
{
    /// Creates an execution whose workflow instance is built by `W::init`.
    ///
    /// `init_input` is passed to `init` and `run_input` to `run`.
    pub(crate) fn new(
        base_ctx: BaseWorkflowContext,
        init_input: Option<<W::Run as WorkflowDefinition>::Input>,
        run_input: Option<<W::Run as WorkflowDefinition>::Input>,
    ) -> Self {
        let instance = W::init(base_ctx.view(), init_input);
        Self::new_with_workflow(instance, base_ctx, run_input)
    }

    /// Creates an execution around an already-constructed workflow instance.
    ///
    /// The instance is placed in shared `Rc<RefCell<_>>` ownership inside the context.
    /// `W::run` is called right away to build the (fused) main future.
    pub(crate) fn new_with_workflow(
        workflow: W,
        base_ctx: BaseWorkflowContext,
        run_input: Option<<W::Run as WorkflowDefinition>::Input>,
    ) -> Self {
        let shared_state = Rc::new(RefCell::new(workflow));
        let ctx = WorkflowContext::from_base(base_ctx, shared_state);
        Self {
            run_future: W::run(ctx.clone(), run_input).fuse(),
            ctx,
        }
    }
}
impl<W: WorkflowImplementation> DynWorkflowExecution for WorkflowExecution<W> {
fn poll_run(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<Payload, WorkflowTermination>> {
Pin::new(&mut self.run_future).poll(cx)
}
fn validate_update(
&self,
name: &str,
data: &DispatchData,
) -> Option<Result<(), WorkflowError>> {
let view = self.ctx.view();
self.ctx
.state(|wf| wf.validate_update(view, name, &data.payloads, data.converter))
}
fn start_update(
&mut self,
name: &str,
data: DispatchData,
) -> Option<LocalBoxFuture<'static, Result<Payload, WorkflowError>>> {
let ctx = self.ctx.with_headers(data.headers);
W::dispatch_update(ctx, name, data.payloads, data.converter)
}
fn dispatch_signal(
&mut self,
name: &str,
data: DispatchData,
) -> Option<LocalBoxFuture<'static, Result<(), WorkflowError>>> {
let ctx = self.ctx.with_headers(data.headers);
W::dispatch_signal(ctx, name, data.payloads, data.converter)
}
fn dispatch_query(
&self,
name: &str,
data: DispatchData,
) -> Option<Result<Payload, WorkflowError>> {
let view = self.ctx.view();
self.ctx
.state(|wf| wf.dispatch_query(view, name, &data.payloads, data.converter))
}
}
/// Shareable (`Send + Sync`) constructor for the executions of one registered workflow type.
///
/// It is called with the workflow's raw input payloads, the payload converter, and the base
/// context. It yields a boxed execution, or a [`PayloadConversionError`] when the input cannot
/// be decoded.
pub(crate) type WorkflowExecutionFactory = Arc<
    dyn Fn(
            Vec<Payload>,
            PayloadConverter,
            BaseWorkflowContext,
        ) -> Result<Box<dyn DynWorkflowExecution>, PayloadConversionError>
        + Send
        + Sync,
>;
/// Registry mapping each workflow type name to the factory that creates its executions.
///
/// Clones share the underlying factory closures, since each factory is an `Arc`.
#[derive(Default, Clone)]
pub struct WorkflowDefinitions {
    /// Keyed by `WorkflowImplementation::name()`.
    workflows: HashMap<&'static str, WorkflowExecutionFactory>,
}
impl WorkflowDefinitions {
pub fn new() -> Self {
Self::default()
}
pub fn register_workflow<W: WorkflowImplementer>(&mut self) -> &mut Self {
W::register_all(self);
self
}
#[doc(hidden)]
pub fn register_workflow_run<W: WorkflowImplementation>(&mut self) -> &mut Self
where
<W::Run as WorkflowDefinition>::Input: Send,
{
let workflow_name = W::name();
let factory: WorkflowExecutionFactory =
Arc::new(move |payloads, converter: PayloadConverter, base_ctx| {
let ser_ctx = SerializationContext {
data: &SerializationContextData::Workflow,
converter: &converter,
};
let input = converter.from_payloads(&ser_ctx, payloads)?;
let (init_input, run_input) = if W::INIT_TAKES_INPUT {
(Some(input), None)
} else {
(None, Some(input))
};
Ok(
Box::new(WorkflowExecution::<W>::new(base_ctx, init_input, run_input))
as Box<dyn DynWorkflowExecution>,
)
});
self.workflows.insert(workflow_name, factory);
self
}
pub fn register_workflow_run_with_factory<W, F>(&mut self, user_factory: F) -> &mut Self
where
W: WorkflowImplementation,
<W::Run as WorkflowDefinition>::Input: Send,
F: Fn() -> W + Send + Sync + 'static,
{
assert!(
!W::HAS_INIT,
"Workflows registered with a factory must not define an #[init] method. \
The factory replaces init for instance creation."
);
let workflow_name = W::name();
let user_factory = Arc::new(user_factory);
let factory: WorkflowExecutionFactory =
Arc::new(move |payloads, converter: PayloadConverter, base_ctx| {
let ser_ctx = SerializationContext {
data: &SerializationContextData::Workflow,
converter: &converter,
};
let input: <W::Run as WorkflowDefinition>::Input =
converter.from_payloads(&ser_ctx, payloads)?;
let workflow = user_factory();
Ok(Box::new(WorkflowExecution::<W>::new_with_workflow(
workflow,
base_ctx,
Some(input),
)) as Box<dyn DynWorkflowExecution>)
});
self.workflows.insert(workflow_name, factory);
self
}
pub fn is_empty(&self) -> bool {
self.workflows.is_empty()
}
pub(crate) fn get_workflow(&self, workflow_type: &str) -> Option<WorkflowExecutionFactory> {
self.workflows.get(workflow_type).cloned()
}
pub fn workflow_types(&self) -> impl Iterator<Item = &'static str> + '_ {
self.workflows.keys().copied()
}
}
impl Debug for WorkflowDefinitions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The factories are opaque closures, so only the registered names are shown.
        let registered: Vec<&&'static str> = self.workflows.keys().collect();
        f.debug_struct("WorkflowDefinitions")
            .field("workflows", &registered)
            .finish()
    }
}
/// Decodes handler argument payloads into `I` using a workflow-scoped serialization context.
///
/// # Errors
/// Returns [`WorkflowError::PayloadConversion`] when the payloads cannot be converted.
pub fn deserialize_input<I: TemporalDeserializable + 'static>(
    payloads: Vec<Payload>,
    converter: &PayloadConverter,
) -> Result<I, WorkflowError> {
    let workflow_ctx = SerializationContext {
        data: &SerializationContextData::Workflow,
        converter,
    };
    let decoded: I = converter.from_payloads(&workflow_ctx, payloads)?;
    Ok(decoded)
}
/// Encodes a handler's output as a single [`Payload`] using a workflow-scoped serialization
/// context.
///
/// # Errors
/// Returns a [`WorkflowError`] when the value cannot be converted.
pub fn serialize_output<O: TemporalSerializable + 'static>(
    output: &O,
    converter: &PayloadConverter,
) -> Result<Payload, WorkflowError> {
    let workflow_ctx = SerializationContext {
        data: &SerializationContextData::Workflow,
        converter,
    };
    let encoded = converter.to_payload(&workflow_ctx, output)?;
    Ok(encoded)
}
pub fn wrap_handler_error(e: Box<dyn std::error::Error + Send + Sync>) -> WorkflowError {
WorkflowError::Execution(anyhow::anyhow!(e))
}
/// Encodes an owned handler result as a single [`Payload`].
///
/// Convenience wrapper over [`serialize_output`] for callers that hold the value by value.
pub fn serialize_result<T: TemporalSerializable + 'static>(
    result: T,
    converter: &PayloadConverter,
) -> Result<Payload, WorkflowError> {
    let borrowed = &result;
    serialize_output(borrowed, converter)
}