1use std::fs;
2
3use tokio::task::JoinError;
4use tokio::time::{sleep, timeout};
5use tracing::instrument;
6use wasmtime::component::{Component, Linker, ResourceTable};
7use wasmtime::{Engine, Store, Trap};
8use wasmtime_wasi::p2;
9use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
10
11use crate::retry;
12use crate::types::{McpError, ToolInput, ToolOutput, ToolRef};
13
14#[derive(Clone)]
16pub struct WasixExecutor {
17 engine: Engine,
18}
19
20impl WasixExecutor {
21 pub fn new() -> Result<Self, McpError> {
23 let mut config = wasmtime::Config::new();
24 config.wasm_component_model(true);
25 config.async_support(false);
26 config.epoch_interruption(true);
27 let engine = Engine::new(&config)
28 .map_err(|err| McpError::Internal(format!("failed to create engine: {err}")))?;
29 Ok(Self { engine })
30 }
31
32 pub fn engine(&self) -> &Engine {
34 &self.engine
35 }
36
37 #[instrument(skip(self, tool, input), fields(tool = %tool.name))]
39 pub async fn invoke(&self, tool: &ToolRef, input: &ToolInput) -> Result<ToolOutput, McpError> {
40 let input_bytes = serde_json::to_vec(&input.payload)
41 .map_err(|err| McpError::InvalidInput(err.to_string()))?;
42 let attempts = tool.max_retries().saturating_add(1);
43 let timeout_duration = tool.timeout();
44 let base_backoff = tool.retry_backoff();
45
46 for attempt in 0..attempts {
47 let exec = self.exec_once(tool.clone(), input_bytes.clone());
48 let result = if let Some(duration) = timeout_duration {
49 match timeout(duration, exec).await {
50 Ok(res) => res,
51 Err(_) => return Err(McpError::timeout(&tool.name, duration)),
52 }
53 } else {
54 exec.await
55 };
56
57 match result {
58 Ok(bytes) => {
59 let payload = serde_json::from_slice(&bytes).map_err(|err| {
60 McpError::ExecutionFailed(format!("invalid tool output JSON: {err}"))
61 })?;
62 let structured_content = match &payload {
63 serde_json::Value::Object(map) => map.get("structuredContent").cloned(),
64 _ => None,
65 };
66 return Ok(ToolOutput {
67 payload,
68 structured_content,
69 });
70 }
71 Err(InvocationFailure::Transient(msg)) => {
72 if attempt + 1 >= attempts {
73 return Err(McpError::Transient(tool.name.clone(), msg));
74 }
75 let backoff = retry::backoff(base_backoff, attempt);
76 tracing::debug!(attempt, ?backoff, "transient failure, retrying");
77 sleep(backoff).await;
78 }
79 Err(InvocationFailure::Fatal(err)) => return Err(err),
80 }
81 }
82
83 Err(McpError::Internal("unreachable retry loop".into()))
84 }
85
86 async fn exec_once(&self, tool: ToolRef, input: Vec<u8>) -> Result<Vec<u8>, InvocationFailure> {
87 let engine = self.engine.clone();
88 tokio::task::spawn_blocking(move || invoke_blocking(engine, tool, input))
89 .await
90 .map_err(|err| join_error(err, "spawn_blocking failed"))?
91 }
92}
93
94impl Default for WasixExecutor {
95 fn default() -> Self {
96 Self::new().expect("engine construction should succeed")
97 }
98}
99
100fn join_error(err: JoinError, context: &str) -> InvocationFailure {
101 InvocationFailure::Fatal(McpError::Internal(format!("{context}: {err}")))
102}
103
104enum InvocationFailure {
105 Transient(String),
106 Fatal(McpError),
107}
108
109impl InvocationFailure {
110 fn transient(msg: impl Into<String>) -> Self {
111 Self::Transient(msg.into())
112 }
113
114 fn fatal(err: impl Into<McpError>) -> Self {
115 Self::Fatal(err.into())
116 }
117}
118
119fn invoke_blocking(
120 engine: Engine,
121 tool: ToolRef,
122 input: Vec<u8>,
123) -> Result<Vec<u8>, InvocationFailure> {
124 let component_bytes = fs::read(tool.component_path()).map_err(|err| {
125 InvocationFailure::fatal(McpError::ExecutionFailed(format!(
126 "failed to read `{}`: {err}",
127 tool.component
128 )))
129 })?;
130 let component = Component::from_binary(&engine, &component_bytes).map_err(|err| {
131 InvocationFailure::fatal(McpError::ExecutionFailed(format!(
132 "failed to compile `{}`: {err}",
133 tool.component
134 )))
135 })?;
136
137 let mut linker = Linker::new(&engine);
138 p2::add_to_linker_sync(&mut linker).map_err(|err| {
139 InvocationFailure::fatal(McpError::Internal(format!(
140 "failed to link WASI imports: {err}"
141 )))
142 })?;
143
144 let pre = linker.instantiate_pre(&component).map_err(|err| {
145 InvocationFailure::fatal(McpError::ExecutionFailed(format!(
146 "failed to prepare `{}`: {err}",
147 tool.component
148 )))
149 })?;
150
151 let mut store = Store::new(&engine, WasiState::new());
152 let instance = pre
153 .instantiate(&mut store)
154 .map_err(|err| classify(err, &tool))?;
155
156 let func = instance
157 .get_typed_func::<(String,), (String,)>(&mut store, &tool.entry)
158 .map_err(|err| {
159 InvocationFailure::fatal(McpError::ExecutionFailed(format!(
160 "missing entry `{}`: {err}",
161 tool.entry
162 )))
163 })?;
164
165 let input_str = String::from_utf8(input).map_err(|err| {
166 InvocationFailure::fatal(McpError::InvalidInput(format!(
167 "input is not valid UTF-8: {err}"
168 )))
169 })?;
170
171 let (output,) = func
172 .call(&mut store, (input_str,))
173 .map_err(|err| classify(err, &tool))?;
174
175 Ok(output.into_bytes())
176}
177
178fn classify(err: wasmtime::Error, tool: &ToolRef) -> InvocationFailure {
179 if err.downcast_ref::<Trap>().is_some() {
180 InvocationFailure::transient(err.to_string())
181 } else {
182 InvocationFailure::fatal(McpError::ExecutionFailed(format!(
183 "tool `{}` failed: {err}",
184 tool.name
185 )))
186 }
187}
188
189struct WasiState {
190 ctx: WasiCtx,
191 table: ResourceTable,
192}
193
194impl WasiState {
195 fn new() -> Self {
196 let mut builder = WasiCtxBuilder::new();
197 builder.inherit_stdio();
198 builder.inherit_env();
199 builder.allow_blocking_current_thread(true);
200 Self {
201 ctx: builder.build(),
202 table: ResourceTable::new(),
203 }
204 }
205}
206
207impl WasiView for WasiState {
208 fn ctx(&mut self) -> WasiCtxView<'_> {
209 WasiCtxView {
210 ctx: &mut self.ctx,
211 table: &mut self.table,
212 }
213 }
214}