1use std::sync::{Arc, Mutex};
9use std::time::Duration;
10
11use async_trait::async_trait;
12use extism::{Manifest, PluginBuilder, UserData, Wasm};
13
14use astrid_core::plugin_abi::ToolDefinition;
15
16use crate::context::PluginContext;
17use crate::error::{PluginError, PluginResult};
18use crate::manifest::{PluginEntryPoint, PluginManifest};
19use crate::plugin::{Plugin, PluginId, PluginState};
20use crate::security::PluginSecurityGate;
21use crate::tool::PluginTool;
22use crate::wasm::host_functions::register_host_functions;
23use crate::wasm::host_state::HostState;
24use crate::wasm::tool::WasmPluginTool;
25
26#[derive(Clone)]
30pub struct WasmPluginConfig {
31 pub security: Option<Arc<dyn PluginSecurityGate>>,
33 pub max_memory_bytes: u64,
35 pub max_execution_time: Duration,
37 pub require_hash: bool,
39}
40
41impl std::fmt::Debug for WasmPluginConfig {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("WasmPluginConfig")
44 .field("has_security", &self.security.is_some())
45 .field("max_memory_bytes", &self.max_memory_bytes)
46 .field("max_execution_time", &self.max_execution_time)
47 .field("require_hash", &self.require_hash)
48 .finish()
49 }
50}
51
52pub struct WasmPlugin {
54 id: PluginId,
55 manifest: PluginManifest,
56 state: PluginState,
57 config: WasmPluginConfig,
58 extism_plugin: Option<Arc<Mutex<extism::Plugin>>>,
60 tools: Vec<Box<dyn PluginTool>>,
62}
63
64impl WasmPlugin {
65 pub(crate) fn new(manifest: PluginManifest, config: WasmPluginConfig) -> Self {
67 let id = manifest.id.clone();
68 Self {
69 id,
70 manifest,
71 state: PluginState::Unloaded,
72 config,
73 extism_plugin: None,
74 tools: Vec::new(),
75 }
76 }
77}
78
79#[async_trait]
80impl Plugin for WasmPlugin {
81 fn id(&self) -> &PluginId {
82 &self.id
83 }
84
85 fn manifest(&self) -> &PluginManifest {
86 &self.manifest
87 }
88
89 fn state(&self) -> PluginState {
90 self.state.clone()
91 }
92
93 async fn load(&mut self, ctx: &PluginContext) -> PluginResult<()> {
94 self.state = PluginState::Loading;
95
96 match self.do_load(ctx) {
97 Ok(()) => {
98 self.state = PluginState::Ready;
99 Ok(())
100 },
101 Err(e) => {
102 let msg = e.to_string();
103 self.state = PluginState::Failed(msg);
104 Err(e)
105 },
106 }
107 }
108
109 async fn unload(&mut self) -> PluginResult<()> {
110 self.state = PluginState::Unloading;
111 self.tools.clear();
112 self.extism_plugin = None;
113 self.state = PluginState::Unloaded;
114 Ok(())
115 }
116
117 fn tools(&self) -> &[Box<dyn PluginTool>] {
118 &self.tools
119 }
120}
121
122impl WasmPlugin {
123 fn do_load(&mut self, ctx: &PluginContext) -> PluginResult<()> {
125 let (wasm_path, expected_hash) = match &self.manifest.entry_point {
127 PluginEntryPoint::Wasm { path, hash } => (path.clone(), hash.clone()),
128 other @ PluginEntryPoint::Mcp { .. } => {
129 return Err(PluginError::LoadFailed {
130 plugin_id: self.id.clone(),
131 message: format!("expected Wasm entry point, got: {other:?}"),
132 });
133 },
134 };
135
136 let resolved_path = if wasm_path.is_absolute() {
138 wasm_path
139 } else {
140 ctx.workspace_root.join(&wasm_path)
141 };
142
143 let wasm_bytes = std::fs::read(&resolved_path).map_err(|e| PluginError::LoadFailed {
145 plugin_id: self.id.clone(),
146 message: format!("failed to read WASM file {}: {e}", resolved_path.display()),
147 })?;
148
149 verify_hash(
151 &wasm_bytes,
152 expected_hash.as_deref(),
153 &self.id,
154 self.config.require_hash,
155 )?;
156
157 let host_state = HostState {
159 plugin_id: self.id.clone(),
160 workspace_root: ctx.workspace_root.clone(),
161 kv: ctx.kv.clone(),
162 config: ctx.config.clone(),
163 security: self.config.security.clone(),
164 runtime_handle: tokio::runtime::Handle::current(),
165 };
166 let user_data = UserData::new(host_state);
167
168 let extism_wasm = Wasm::data(wasm_bytes);
170 let mut extism_manifest = Manifest::new([extism_wasm]);
171 extism_manifest = extism_manifest.with_timeout(self.config.max_execution_time);
172 let pages = self.config.max_memory_bytes / (64 * 1024);
174 let max_pages = u32::try_from(pages).unwrap_or(u32::MAX);
175 extism_manifest = extism_manifest.with_memory_max(max_pages);
176
177 let builder = PluginBuilder::new(extism_manifest).with_wasi(true);
179 let builder = register_host_functions(builder, user_data);
180 let mut plugin = builder
181 .build()
182 .map_err(|e| PluginError::WasmError(format!("failed to build Extism plugin: {e}")))?;
183
184 let tools = discover_tools(&mut plugin)?;
186 let plugin_arc = Arc::new(Mutex::new(plugin));
187
188 let wasm_tools: Vec<Box<dyn PluginTool>> = tools
189 .into_iter()
190 .map(|td| {
191 let schema: serde_json::Value =
192 serde_json::from_str(&td.input_schema).unwrap_or(serde_json::json!({}));
193 Box::new(WasmPluginTool::new(
194 td.name,
195 td.description,
196 schema,
197 Arc::clone(&plugin_arc),
198 )) as Box<dyn PluginTool>
199 })
200 .collect();
201
202 self.extism_plugin = Some(plugin_arc);
203 self.tools = wasm_tools;
204
205 Ok(())
206 }
207}
208
209fn verify_hash(
214 wasm_bytes: &[u8],
215 expected: Option<&str>,
216 plugin_id: &PluginId,
217 require_hash: bool,
218) -> PluginResult<()> {
219 match expected {
220 Some(expected_hex) => {
221 let actual_hex = blake3::hash(wasm_bytes).to_hex().to_string();
222 if actual_hex != expected_hex {
223 return Err(PluginError::HashMismatch {
224 expected: expected_hex.to_string(),
225 actual: actual_hex,
226 });
227 }
228 tracing::debug!(plugin = %plugin_id, "WASM module hash verified");
229 },
230 None if require_hash => {
231 return Err(PluginError::LoadFailed {
232 plugin_id: plugin_id.clone(),
233 message: "WASM module hash required but not specified in manifest".into(),
234 });
235 },
236 None => {
237 tracing::warn!(
238 plugin = %plugin_id,
239 "WASM module hash not specified — module integrity not verified"
240 );
241 },
242 }
243 Ok(())
244}
245
246fn discover_tools(plugin: &mut extism::Plugin) -> PluginResult<Vec<ToolDefinition>> {
248 let result = plugin
250 .call::<&str, String>("describe-tools", "")
251 .map_err(|e| PluginError::WasmError(format!("describe-tools call failed: {e}")))?;
252
253 let definitions: Vec<ToolDefinition> = serde_json::from_str(&result).map_err(|e| {
254 PluginError::WasmError(format!("failed to parse describe-tools output: {e}"))
255 })?;
256
257 Ok(definitions)
258}
259
260impl std::fmt::Debug for WasmPlugin {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 f.debug_struct("WasmPlugin")
263 .field("id", &self.id)
264 .field("state", &self.state)
265 .field("tool_count", &self.tools.len())
266 .finish_non_exhaustive()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Bytes hashed by every hash-verification test.
    const PAYLOAD: &[u8] = b"hello world";
    /// A well-formed digest that cannot match `PAYLOAD`.
    const ZERO_HASH: &str = "0000000000000000000000000000000000000000000000000000000000000000";

    /// Plugin id shared by all tests.
    fn test_id() -> PluginId {
        PluginId::from_static("test")
    }

    // A digest computed from the same bytes is accepted.
    #[test]
    fn hash_verification_match() {
        let expected = blake3::hash(PAYLOAD).to_hex().to_string();
        let outcome = verify_hash(PAYLOAD, Some(&expected), &test_id(), false);
        assert!(outcome.is_ok());
    }

    // A wrong digest yields `HashMismatch` carrying both values.
    #[test]
    fn hash_verification_mismatch() {
        let outcome = verify_hash(PAYLOAD, Some(ZERO_HASH), &test_id(), false);
        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            PluginError::HashMismatch { expected, actual } => {
                assert_eq!(expected, ZERO_HASH);
                assert!(!actual.is_empty());
            },
            other => panic!("expected HashMismatch, got: {other:?}"),
        }
    }

    // A missing hash is tolerated when it is not required.
    #[test]
    fn hash_verification_none_is_ok() {
        assert!(verify_hash(PAYLOAD, None, &test_id(), false).is_ok());
    }

    // A missing hash is a `LoadFailed` error when it is required.
    #[test]
    fn hash_verification_none_rejected_when_required() {
        let outcome = verify_hash(PAYLOAD, None, &test_id(), true);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            PluginError::LoadFailed { .. }
        ));
    }

    // A freshly constructed plugin is unloaded and exposes no tools.
    #[test]
    fn wasm_plugin_starts_unloaded() {
        let entry_point = PluginEntryPoint::Wasm {
            path: "plugin.wasm".into(),
            hash: None,
        };
        let manifest = PluginManifest {
            id: test_id(),
            name: "Test".into(),
            version: "0.1.0".into(),
            description: None,
            author: None,
            entry_point,
            capabilities: vec![],
            config: HashMap::new(),
        };
        let config = WasmPluginConfig {
            security: None,
            max_memory_bytes: 64 * 1024 * 1024,
            max_execution_time: Duration::from_secs(30),
            require_hash: false,
        };

        let plugin = WasmPlugin::new(manifest, config);

        assert_eq!(plugin.state(), PluginState::Unloaded);
        assert!(plugin.tools().is_empty());
    }
}