Skip to main content

greentic_mcp/
executor.rs

1use std::fs;
2
3use tokio::task::JoinError;
4use tokio::time::{sleep, timeout};
5use tracing::instrument;
6use wasmtime::component::{Component, Linker, ResourceTable};
7use wasmtime::{Engine, Store, Trap};
8use wasmtime_wasi::p2;
9use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
10
11use crate::retry;
12use crate::types::{McpError, ToolInput, ToolOutput, ToolRef};
13
14/// Executes WASIX/WASI tools compiled to WebAssembly.
15#[derive(Clone)]
16pub struct WasixExecutor {
17    engine: Engine,
18}
19
20impl WasixExecutor {
21    /// Construct a new executor using a synchronous engine.
22    pub fn new() -> Result<Self, McpError> {
23        let mut config = wasmtime::Config::new();
24        config.wasm_component_model(true);
25        config.async_support(false);
26        config.epoch_interruption(true);
27        let engine = Engine::new(&config)
28            .map_err(|err| McpError::Internal(format!("failed to create engine: {err}")))?;
29        Ok(Self { engine })
30    }
31
32    /// Access the underlying Wasmtime engine.
33    pub fn engine(&self) -> &Engine {
34        &self.engine
35    }
36
37    /// Invoke the specified tool with the provided input payload.
38    #[instrument(skip(self, tool, input), fields(tool = %tool.name))]
39    pub async fn invoke(&self, tool: &ToolRef, input: &ToolInput) -> Result<ToolOutput, McpError> {
40        let input_bytes = serde_json::to_vec(&input.payload)
41            .map_err(|err| McpError::InvalidInput(err.to_string()))?;
42        let attempts = tool.max_retries().saturating_add(1);
43        let timeout_duration = tool.timeout();
44        let base_backoff = tool.retry_backoff();
45
46        for attempt in 0..attempts {
47            let exec = self.exec_once(tool.clone(), input_bytes.clone());
48            let result = if let Some(duration) = timeout_duration {
49                match timeout(duration, exec).await {
50                    Ok(res) => res,
51                    Err(_) => return Err(McpError::timeout(&tool.name, duration)),
52                }
53            } else {
54                exec.await
55            };
56
57            match result {
58                Ok(bytes) => {
59                    let payload = serde_json::from_slice(&bytes).map_err(|err| {
60                        McpError::ExecutionFailed(format!("invalid tool output JSON: {err}"))
61                    })?;
62                    let structured_content = match &payload {
63                        serde_json::Value::Object(map) => map.get("structuredContent").cloned(),
64                        _ => None,
65                    };
66                    return Ok(ToolOutput {
67                        payload,
68                        structured_content,
69                    });
70                }
71                Err(InvocationFailure::Transient(msg)) => {
72                    if attempt + 1 >= attempts {
73                        return Err(McpError::Transient(tool.name.clone(), msg));
74                    }
75                    let backoff = retry::backoff(base_backoff, attempt);
76                    tracing::debug!(attempt, ?backoff, "transient failure, retrying");
77                    sleep(backoff).await;
78                }
79                Err(InvocationFailure::Fatal(err)) => return Err(err),
80            }
81        }
82
83        Err(McpError::Internal("unreachable retry loop".into()))
84    }
85
86    async fn exec_once(&self, tool: ToolRef, input: Vec<u8>) -> Result<Vec<u8>, InvocationFailure> {
87        let engine = self.engine.clone();
88        tokio::task::spawn_blocking(move || invoke_blocking(engine, tool, input))
89            .await
90            .map_err(|err| join_error(err, "spawn_blocking failed"))?
91    }
92}
93
94impl Default for WasixExecutor {
95    fn default() -> Self {
96        Self::new().expect("engine construction should succeed")
97    }
98}
99
100fn join_error(err: JoinError, context: &str) -> InvocationFailure {
101    InvocationFailure::Fatal(McpError::Internal(format!("{context}: {err}")))
102}
103
104enum InvocationFailure {
105    Transient(String),
106    Fatal(McpError),
107}
108
109impl InvocationFailure {
110    fn transient(msg: impl Into<String>) -> Self {
111        Self::Transient(msg.into())
112    }
113
114    fn fatal(err: impl Into<McpError>) -> Self {
115        Self::Fatal(err.into())
116    }
117}
118
119fn invoke_blocking(
120    engine: Engine,
121    tool: ToolRef,
122    input: Vec<u8>,
123) -> Result<Vec<u8>, InvocationFailure> {
124    let component_bytes = fs::read(tool.component_path()).map_err(|err| {
125        InvocationFailure::fatal(McpError::ExecutionFailed(format!(
126            "failed to read `{}`: {err}",
127            tool.component
128        )))
129    })?;
130    let component = Component::from_binary(&engine, &component_bytes).map_err(|err| {
131        InvocationFailure::fatal(McpError::ExecutionFailed(format!(
132            "failed to compile `{}`: {err}",
133            tool.component
134        )))
135    })?;
136
137    let mut linker = Linker::new(&engine);
138    p2::add_to_linker_sync(&mut linker).map_err(|err| {
139        InvocationFailure::fatal(McpError::Internal(format!(
140            "failed to link WASI imports: {err}"
141        )))
142    })?;
143
144    let pre = linker.instantiate_pre(&component).map_err(|err| {
145        InvocationFailure::fatal(McpError::ExecutionFailed(format!(
146            "failed to prepare `{}`: {err}",
147            tool.component
148        )))
149    })?;
150
151    let mut store = Store::new(&engine, WasiState::new());
152    let instance = pre
153        .instantiate(&mut store)
154        .map_err(|err| classify(err, &tool))?;
155
156    let func = instance
157        .get_typed_func::<(String,), (String,)>(&mut store, &tool.entry)
158        .map_err(|err| {
159            InvocationFailure::fatal(McpError::ExecutionFailed(format!(
160                "missing entry `{}`: {err}",
161                tool.entry
162            )))
163        })?;
164
165    let input_str = String::from_utf8(input).map_err(|err| {
166        InvocationFailure::fatal(McpError::InvalidInput(format!(
167            "input is not valid UTF-8: {err}"
168        )))
169    })?;
170
171    let (output,) = func
172        .call(&mut store, (input_str,))
173        .map_err(|err| classify(err, &tool))?;
174
175    Ok(output.into_bytes())
176}
177
178fn classify(err: wasmtime::Error, tool: &ToolRef) -> InvocationFailure {
179    if err.downcast_ref::<Trap>().is_some() {
180        InvocationFailure::transient(err.to_string())
181    } else {
182        InvocationFailure::fatal(McpError::ExecutionFailed(format!(
183            "tool `{}` failed: {err}",
184            tool.name
185        )))
186    }
187}
188
189struct WasiState {
190    ctx: WasiCtx,
191    table: ResourceTable,
192}
193
194impl WasiState {
195    fn new() -> Self {
196        let mut builder = WasiCtxBuilder::new();
197        builder.inherit_stdio();
198        builder.inherit_env();
199        builder.allow_blocking_current_thread(true);
200        Self {
201            ctx: builder.build(),
202            table: ResourceTable::new(),
203        }
204    }
205}
206
207impl WasiView for WasiState {
208    fn ctx(&mut self) -> WasiCtxView<'_> {
209        WasiCtxView {
210            ctx: &mut self.ctx,
211            table: &mut self.table,
212        }
213    }
214}