use std::{
    collections::{HashMap, hash_map::Entry},
    fmt::Debug,
    rc::Rc,
    sync::Arc,
};
use anyhow::Context;
use temporalio_common::{
WorkflowDefinition,
data_converters::{
DataConverter, GenericPayloadConverter, PayloadConverter, SerializationContext,
SerializationContextData,
},
protos::{
coresdk::workflow_activation::InitializeWorkflow, temporal::api::common::v1::Payload,
},
};
use temporalio_workflow::{
BaseWorkflowContext,
runtime::{
entry::WorkflowImplementation,
guest::WorkflowInstance,
host::WorkflowHost,
instance::{GuestWorkflowInstance, instantiate_workflow},
types::WorkflowDefinitionDescriptor,
},
};
/// Everything a [`WorkflowExecutionFactory`] is given to build one workflow instance.
///
/// Consumed by `workflow_input_parts`, which extracts the initial arguments and payload
/// converter and moves every field into a new [`BaseWorkflowContext`].
pub(crate) struct WorkflowExecutionInput {
/// Forwarded to `BaseWorkflowContext::new` (presumably the run's namespace).
pub namespace: String,
/// Forwarded to `BaseWorkflowContext::new` (presumably the run's task queue).
pub task_queue: String,
/// Forwarded to `BaseWorkflowContext::new` (presumably the workflow run id).
pub run_id: String,
/// Initialization job; its `arguments` payloads are used as the workflow's input.
pub init_workflow_job: InitializeWorkflow,
/// Source of the payload converter used to decode arguments; also moved into the context.
pub data_converter: DataConverter,
/// Host handle moved into the workflow's [`BaseWorkflowContext`].
pub host: Rc<dyn WorkflowHost>,
}
/// Constructor stored for each registered workflow type.
///
/// Consumes a [`WorkflowExecutionInput`] and returns the boxed [`WorkflowInstance`], or an
/// error if instantiation fails. It is reference-counted and bound `Send + Sync`, so the
/// registry can hand out cheap clones that may be shared across threads.
pub(crate) type WorkflowExecutionFactory = Arc<
dyn Fn(WorkflowExecutionInput) -> Result<Box<dyn WorkflowInstance>, anyhow::Error>
+ Send
+ Sync,
>;
/// One registry entry: a workflow's descriptor plus the factory that instantiates it.
#[derive(Clone)]
struct RegisteredWorkflow {
/// Descriptor returned by `W::definition()` at registration time.
definition: WorkflowDefinitionDescriptor,
/// Factory invoked to create a new instance of this workflow.
factory: WorkflowExecutionFactory,
}
/// Reasons a workflow cannot be added to [`WorkflowDefinitions`].
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum WorkflowRegistrationError {
/// A workflow with the same type name was registered earlier; the earlier entry is kept.
#[error("Workflow type {workflow_type} is already registered")]
DuplicateWorkflowType {
/// The conflicting workflow type name.
workflow_type: String,
},
/// `register_workflow_run_with_factory` was used for a workflow that declares `#[init]`
/// (i.e. `WorkflowImplementation::HAS_INIT` is true).
#[error(
"Workflow type {workflow_type} must not define an #[init] method when registered with a factory"
)]
FactoryRegistrationWithInit {
/// The rejected workflow type name.
workflow_type: String,
},
}
/// Registry of workflow implementations, keyed by workflow type name.
///
/// Populate it with `register_workflow` or `register_workflow_run_with_factory`.
#[derive(Default, Clone)]
pub struct WorkflowDefinitions {
/// Map from a descriptor's `workflow_type` to its registry entry.
workflows: HashMap<String, RegisteredWorkflow>,
}
impl WorkflowDefinitions {
pub fn new() -> Self {
Self::default()
}
pub fn register_workflow<W: WorkflowImplementation>(
&mut self,
) -> Result<&mut Self, WorkflowRegistrationError>
where
<W::Run as WorkflowDefinition>::Input: Send,
{
let factory = Arc::new(move |input| {
let (payloads, payload_converter, base_ctx) = workflow_input_parts(input);
instantiate_workflow::<W>(payloads, payload_converter, base_ctx)
.context("Failed to instantiate native workflow")
});
self.insert_workflow(W::definition(), factory)?;
Ok(self)
}
pub fn register_workflow_run_with_factory<W, F>(
&mut self,
user_factory: F,
) -> Result<&mut Self, WorkflowRegistrationError>
where
W: WorkflowImplementation,
<W::Run as WorkflowDefinition>::Input: Send,
F: Fn() -> W + Send + Sync + 'static,
{
if W::HAS_INIT {
return Err(WorkflowRegistrationError::FactoryRegistrationWithInit {
workflow_type: W::definition().workflow_type,
});
}
let factory = Arc::new(move |input| {
let (payloads, payload_converter, base_ctx) = workflow_input_parts(input);
let ser_ctx = SerializationContext {
data: &SerializationContextData::Workflow,
converter: &payload_converter,
};
let input: <W::Run as WorkflowDefinition>::Input =
payload_converter.from_payloads(&ser_ctx, payloads)?;
let workflow = user_factory();
Ok(Box::new(GuestWorkflowInstance::<W>::new_with_workflow(
workflow,
base_ctx,
Some(input),
)) as Box<dyn WorkflowInstance>)
});
self.insert_workflow(W::definition(), factory)?;
Ok(self)
}
pub fn is_empty(&self) -> bool {
self.workflows.is_empty()
}
pub(crate) fn insert_workflow(
&mut self,
definition: WorkflowDefinitionDescriptor,
factory: WorkflowExecutionFactory,
) -> Result<(), WorkflowRegistrationError> {
let workflow_type = definition.workflow_type.clone();
if self.workflows.contains_key(&workflow_type) {
return Err(WorkflowRegistrationError::DuplicateWorkflowType { workflow_type });
}
self.workflows.insert(
workflow_type,
RegisteredWorkflow {
definition,
factory,
},
);
Ok(())
}
pub(crate) fn get_workflow(&self, workflow_type: &str) -> Option<WorkflowExecutionFactory> {
self.workflows
.get(workflow_type)
.map(|wf| wf.factory.clone())
}
pub fn workflow_definitions(&self) -> impl Iterator<Item = &WorkflowDefinitionDescriptor> + '_ {
self.workflows.values().map(|wf| &wf.definition)
}
}
/// Splits a [`WorkflowExecutionInput`] into the pieces a workflow factory needs.
///
/// Returns a tuple of:
/// - a copy of the initialization job's argument payloads,
/// - a clone of the data converter's payload converter,
/// - a [`BaseWorkflowContext`] that takes ownership of every input field.
fn workflow_input_parts(
    input: WorkflowExecutionInput,
) -> (Vec<Payload>, PayloadConverter, BaseWorkflowContext) {
    // Copy out what the caller needs before the fields are moved into the context below.
    let args = input.init_workflow_job.arguments.clone();
    let converter = input.data_converter.payload_converter().clone();

    let ctx = BaseWorkflowContext::new(
        input.namespace,
        input.task_queue,
        input.run_id,
        input.init_workflow_job,
        input.data_converter,
        input.host,
    );

    (args, converter, ctx)
}
/// Debug output lists only the registered workflow type names, because the stored factories
/// (`dyn Fn`) are not `Debug`.
///
/// Names are sorted so the output is stable: `HashMap` iteration order is arbitrary and can
/// differ between runs.
impl Debug for WorkflowDefinitions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut workflow_types: Vec<&String> = self.workflows.keys().collect();
        workflow_types.sort_unstable();
        f.debug_struct("WorkflowDefinitions")
            .field("workflows", &workflow_types)
            .finish()
    }
}