use std::{fs, io, path::PathBuf, rc::Rc, sync::Arc};
use anyhow::{Context, bail};
use prost::Message;
use temporalio_common::protos::{
coresdk::workflow_commands::WorkflowCommand, temporal::api::failure::v1::Failure,
};
use temporalio_workflow::runtime::{
guest::WorkflowInstance,
host::WorkflowHost,
types::{
ActivationJobResult, ActivationResult, MainRoutineCompletion, QueryResponse,
RoutineCompletion, RoutinePollResult, StartedRoutine, TaskFailure, TerminalOutcome,
UpdateRoutineCompletion, UpdateRoutineKind, WorkflowActivation,
WorkflowDefinitionDescriptor, WorkflowFailure,
},
};
use wasmtime::{
Config, Engine, Store,
component::{Component, HasSelf, Linker, ResourceAny},
};
use crate::workflow_registry::{
WorkflowDefinitions, WorkflowExecutionFactory, WorkflowExecutionInput,
};
wasmtime::component::bindgen!({
path: "../workflow/wit",
world: "workflow-module",
});
use self::{
exports::temporal::workflow_runtime::workflow_guest as wit_guest,
temporal::workflow_runtime::{types as wit_types, workflow_host as wit_host},
};
#[derive(Clone, Debug)]
pub struct WasmWorkflowComponent {
component_id: String,
source: WasmWorkflowComponentSource,
}
#[derive(Clone, Debug)]
enum WasmWorkflowComponentSource {
File(PathBuf),
Bytes(Arc<[u8]>),
}
impl WasmWorkflowComponent {
pub fn from_file(
component_id: impl Into<String>,
path: impl Into<PathBuf>,
) -> io::Result<Self> {
let path = path.into();
if !path.try_exists()? {
return Err(io::Error::new(
io::ErrorKind::NotFound,
format!(
"WASM workflow component file does not exist: {}",
path.display()
),
));
}
Ok(Self {
component_id: component_id.into(),
source: WasmWorkflowComponentSource::File(path),
})
}
pub fn from_bytes(
component_id: impl Into<String>,
bytes: impl Into<Arc<[u8]>>,
) -> io::Result<Self> {
let bytes = bytes.into();
if bytes.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"WASM workflow component bytes must not be empty",
));
}
Ok(Self {
component_id: component_id.into(),
source: WasmWorkflowComponentSource::Bytes(bytes),
})
}
}
impl WorkflowDefinitions {
pub(crate) fn register_wasm_workflows(
&mut self,
components: Vec<WasmWorkflowComponent>,
) -> Result<(), anyhow::Error> {
for component in components {
let module = Arc::new(CompiledWasmWorkflowModule::new(component)?);
for definition in &module.definitions {
let module = module.clone();
let factory: WorkflowExecutionFactory =
Arc::new(move |input| module.instantiate(input));
self.insert_workflow(definition.clone(), factory)?;
}
}
Ok(())
}
}
struct CompiledWasmWorkflowModule {
engine: Engine,
component: Component,
definitions: Vec<WorkflowDefinitionDescriptor>,
}
impl CompiledWasmWorkflowModule {
fn new(component: WasmWorkflowComponent) -> Result<Self, anyhow::Error> {
let mut config = Config::new();
config.wasm_component_model(true);
let engine = Engine::new(&config)?;
let bytes: Arc<[u8]> = match &component.source {
WasmWorkflowComponentSource::File(path) => fs::read(path)
.with_context(|| format!("Failed reading WASM component {}", path.display()))?
.into(),
WasmWorkflowComponentSource::Bytes(bytes) => bytes.clone(),
};
let component = Component::new(&engine, bytes.as_ref()).map_err(|err| {
anyhow::Error::msg(format!(
"Failed to compile WASM component {}: {err}",
component.component_id
))
})?;
let definitions = Self::read_definitions(&engine, &component)?;
if definitions.is_empty() {
bail!("WASM component exports no workflows");
}
Ok(Self {
engine,
component,
definitions,
})
}
fn read_definitions(
engine: &Engine,
component: &Component,
) -> Result<Vec<WorkflowDefinitionDescriptor>, anyhow::Error> {
let mut linker = Linker::new(engine);
WorkflowModule::add_to_linker::<_, HasSelf<_>>(&mut linker, |data| data)?;
let mut store = Store::new(
engine,
WasmWorkflowHostState::new(Rc::new(NoopWorkflowHost)),
);
let module = WorkflowModule::instantiate(&mut store, component, &linker)?;
module
.temporal_workflow_runtime_workflow_guest()
.call_list_workflows(&mut store)
.map(|defs| {
defs.into_iter()
.map(|def| WorkflowDefinitionDescriptor {
workflow_type: def.workflow_type,
has_init: def.has_init,
init_takes_input: def.init_takes_input,
signals: def.signals,
queries: def.queries,
updates: def
.updates
.into_iter()
.map(|u| {
temporalio_workflow::runtime::types::UpdateDefinitionDescriptor {
name: u.name,
has_validator: u.has_validator,
}
})
.collect(),
})
.collect()
})
.map_err(|err| {
anyhow::Error::msg(format!(
"Failed to list workflows exported by WASM component: {err}"
))
})
}
fn instantiate(
&self,
input: WorkflowExecutionInput,
) -> Result<Box<dyn WorkflowInstance>, anyhow::Error> {
let mut linker = Linker::new(&self.engine);
WorkflowModule::add_to_linker::<_, HasSelf<_>>(&mut linker, |data| data)?;
let mut store = Store::new(&self.engine, WasmWorkflowHostState::new(input.host.clone()));
let module = WorkflowModule::instantiate(&mut store, &self.component, &linker)?;
let guest = module.temporal_workflow_runtime_workflow_guest();
let workflow_init = wit_types::WorkflowInit {
namespace: input.namespace.clone(),
task_queue: input.task_queue.clone(),
run_id: input.run_id.clone(),
initialize_workflow: input.init_workflow_job.encode_to_vec(),
};
let workflow_instance = guest
.call_instantiate_workflow(&mut store, &workflow_init)
.map_err(|err| {
anyhow::Error::msg(format!("Failed to instantiate WASM workflow: {err}"))
})?
.map_err(convert_failure)
.map_err(|failure| {
anyhow::Error::msg(format!(
"WASM workflow initialization failed: {}",
failure.message
))
})?;
Ok(Box::new(WasmWorkflowInstance {
store,
guest: guest.clone(),
workflow_instance,
}))
}
}
struct WasmWorkflowInstance {
store: Store<WasmWorkflowHostState>,
guest: wit_guest::Guest,
workflow_instance: ResourceAny,
}
impl WorkflowInstance for WasmWorkflowInstance {
fn activate(
&mut self,
activation: WorkflowActivation,
) -> Result<ActivationResult, WorkflowFailure> {
let result = self.guest.workflow_instance().call_activate(
&mut self.store,
self.workflow_instance,
&activation.encode_to_vec(),
);
trap_to_failure(result, |result| ActivationResult {
job_results: result
.job_results
.into_iter()
.map(|job_result| match job_result {
wit_types::ActivationJobResult::None => ActivationJobResult::None,
wit_types::ActivationJobResult::StartedRoutine(routine) => {
ActivationJobResult::StartedRoutine(StartedRoutine {
routine_id: routine.routine_id,
kind: match routine.kind {
wit_types::RoutineKind::Main => {
temporalio_workflow::runtime::types::RoutineKind::Main
}
wit_types::RoutineKind::Signal(name) => {
temporalio_workflow::runtime::types::RoutineKind::Signal(name)
}
wit_types::RoutineKind::Update(update) => {
temporalio_workflow::runtime::types::RoutineKind::Update(
UpdateRoutineKind {
name: update.name,
update_id: update.update_id,
protocol_instance_id: update.protocol_instance_id,
},
)
}
},
})
}
wit_types::ActivationJobResult::QueryResponse(response) => {
ActivationJobResult::QueryResponse(Box::new(QueryResponse {
result: response
.response
.map(decode_proto)
.map_err(|failure| *convert_failure(failure)),
}))
}
wit_types::ActivationJobResult::UpdateRejected(failure) => {
ActivationJobResult::UpdateRejected(convert_failure(failure))
}
})
.collect(),
})
}
fn poll_routine(
&mut self,
routine_id: u64,
_waker: &std::task::Waker,
) -> Result<RoutinePollResult, WorkflowFailure> {
let result = self.guest.workflow_instance().call_poll_routine(
&mut self.store,
self.workflow_instance,
routine_id,
);
trap_to_failure(result, |result| RoutinePollResult {
completion: result.completion.map(|completion| match completion {
wit_types::RoutineCompletion::Main(completion) => {
RoutineCompletion::Main(match completion {
wit_types::MainRoutineCompletion::Blocked => MainRoutineCompletion::Blocked,
wit_types::MainRoutineCompletion::TaskFailed(task_failure) => {
MainRoutineCompletion::TaskFailed(TaskFailure {
failure: convert_failure(task_failure.failure),
force_cause: task_failure.force_cause,
})
}
wit_types::MainRoutineCompletion::Terminal(outcome) => {
MainRoutineCompletion::Terminal(Box::new(match outcome {
wit_types::TerminalOutcome::Completed(payload) => {
TerminalOutcome::Completed(decode_proto(payload))
}
wit_types::TerminalOutcome::Failed(failure) => {
TerminalOutcome::Failed(convert_failure(failure))
}
wit_types::TerminalOutcome::Cancelled => TerminalOutcome::Cancelled,
wit_types::TerminalOutcome::ContinueAsNew(req) => {
TerminalOutcome::ContinueAsNew(Box::new(decode_proto(req)))
}
}))
}
})
}
wit_types::RoutineCompletion::Signal(result) => {
RoutineCompletion::Signal(match result {
wit_types::SignalRoutineCompletion::Succeeded => Ok(()),
wit_types::SignalRoutineCompletion::Failed(failure) => {
Err(convert_failure(failure))
}
})
}
wit_types::RoutineCompletion::Update(completion) => {
RoutineCompletion::Update(match completion {
wit_types::UpdateRoutineCompletion::Completed(success) => {
UpdateRoutineCompletion::Completed {
protocol_instance_id: success.protocol_instance_id,
result: decode_proto(success.value),
}
}
wit_types::UpdateRoutineCompletion::Rejected(rejection) => {
UpdateRoutineCompletion::Rejected {
protocol_instance_id: rejection.protocol_instance_id,
failure: convert_failure(rejection.failure),
}
}
})
}
}),
made_progress: result.made_progress,
})
}
}
struct WasmWorkflowHostState {
host: Rc<dyn WorkflowHost>,
}
impl WasmWorkflowHostState {
fn new(host: Rc<dyn WorkflowHost>) -> Self {
Self { host }
}
}
impl wit_types::Host for WasmWorkflowHostState {}
impl wit_host::Host for WasmWorkflowHostState {
fn set_current_details(&mut self, details: String) {
self.host.set_current_details(details);
}
fn push_command(&mut self, command: wit_types::WorkflowCommand) {
self.host.push_command(decode_proto(command));
}
}
struct NoopWorkflowHost;
impl WorkflowHost for NoopWorkflowHost {
fn set_current_details(&self, _details: String) {}
fn push_command(&self, _command: WorkflowCommand) {}
}
fn trap_to_failure<T, U>(
result: Result<Result<T, wit_types::Failure>, wasmtime::Error>,
convert: impl FnOnce(T) -> U,
) -> Result<U, WorkflowFailure> {
result
.map_err(|trap| {
Box::new(Failure {
message: format!("WASM workflow trapped: {trap}"),
..Default::default()
})
})?
.map(convert)
.map_err(convert_failure)
}
fn convert_failure(failure: wit_types::Failure) -> WorkflowFailure {
Box::new(decode_proto(failure))
}
fn decode_proto<M: Message + prost::Name + Default>(bytes: Vec<u8>) -> M {
M::decode(bytes.as_slice()).unwrap_or_else(|err| {
let n = M::NAME;
panic!("failed to decode {n} from WASM boundary bytes: {err}")
})
}