Skip to main content

astrid_plugins/wasm/
plugin.rs

1//! WASM plugin implementation backed by Extism.
2//!
3//! [`WasmPlugin`] implements the [`Plugin`](crate::Plugin) trait, managing the
4//! lifecycle of an Extism WASM module. It loads `.wasm` files, verifies their
5//! blake3 hash (if provided), registers host functions, and discovers tools
6//! via the `describe-tools` guest export.
7
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10
11use async_trait::async_trait;
12use extism::{Manifest, PluginBuilder, UserData, Wasm};
13
14use astrid_core::plugin_abi::ToolDefinition;
15
16use crate::context::PluginContext;
17use crate::error::{PluginError, PluginResult};
18use crate::manifest::{PluginEntryPoint, PluginManifest};
19use crate::plugin::{Plugin, PluginId, PluginState};
20use crate::security::PluginSecurityGate;
21use crate::tool::PluginTool;
22use crate::wasm::host_functions::register_host_functions;
23use crate::wasm::host_state::HostState;
24use crate::wasm::tool::WasmPluginTool;
25
26/// Configuration from [`WasmPluginLoader`](super::loader::WasmPluginLoader).
27///
28/// Debug is implemented manually because `dyn PluginSecurityGate` is not `Debug`.
29#[derive(Clone)]
30pub struct WasmPluginConfig {
31    /// Optional security gate for host function authorization.
32    pub security: Option<Arc<dyn PluginSecurityGate>>,
33    /// Maximum WASM linear memory in bytes.
34    pub max_memory_bytes: u64,
35    /// Maximum execution time per call.
36    pub max_execution_time: Duration,
37    /// If true, reject WASM modules that don't specify a hash in their manifest.
38    pub require_hash: bool,
39}
40
41impl std::fmt::Debug for WasmPluginConfig {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("WasmPluginConfig")
44            .field("has_security", &self.security.is_some())
45            .field("max_memory_bytes", &self.max_memory_bytes)
46            .field("max_execution_time", &self.max_execution_time)
47            .field("require_hash", &self.require_hash)
48            .finish()
49    }
50}
51
52/// A plugin backed by an Extism WASM module.
53pub struct WasmPlugin {
54    id: PluginId,
55    manifest: PluginManifest,
56    state: PluginState,
57    config: WasmPluginConfig,
58    /// The Extism plugin instance (created during load).
59    extism_plugin: Option<Arc<Mutex<extism::Plugin>>>,
60    /// Tools discovered from the guest's `describe-tools` export.
61    tools: Vec<Box<dyn PluginTool>>,
62}
63
64impl WasmPlugin {
65    /// Create a new `WasmPlugin` in the `Unloaded` state.
66    pub(crate) fn new(manifest: PluginManifest, config: WasmPluginConfig) -> Self {
67        let id = manifest.id.clone();
68        Self {
69            id,
70            manifest,
71            state: PluginState::Unloaded,
72            config,
73            extism_plugin: None,
74            tools: Vec::new(),
75        }
76    }
77}
78
79#[async_trait]
80impl Plugin for WasmPlugin {
81    fn id(&self) -> &PluginId {
82        &self.id
83    }
84
85    fn manifest(&self) -> &PluginManifest {
86        &self.manifest
87    }
88
89    fn state(&self) -> PluginState {
90        self.state.clone()
91    }
92
93    async fn load(&mut self, ctx: &PluginContext) -> PluginResult<()> {
94        self.state = PluginState::Loading;
95
96        match self.do_load(ctx) {
97            Ok(()) => {
98                self.state = PluginState::Ready;
99                Ok(())
100            },
101            Err(e) => {
102                let msg = e.to_string();
103                self.state = PluginState::Failed(msg);
104                Err(e)
105            },
106        }
107    }
108
109    async fn unload(&mut self) -> PluginResult<()> {
110        self.state = PluginState::Unloading;
111        self.tools.clear();
112        self.extism_plugin = None;
113        self.state = PluginState::Unloaded;
114        Ok(())
115    }
116
117    fn tools(&self) -> &[Box<dyn PluginTool>] {
118        &self.tools
119    }
120}
121
122impl WasmPlugin {
123    /// Internal load logic. Separated so we can catch errors and set `Failed` state.
124    fn do_load(&mut self, ctx: &PluginContext) -> PluginResult<()> {
125        // 1. Resolve WASM file path
126        let (wasm_path, expected_hash) = match &self.manifest.entry_point {
127            PluginEntryPoint::Wasm { path, hash } => (path.clone(), hash.clone()),
128            other @ PluginEntryPoint::Mcp { .. } => {
129                return Err(PluginError::LoadFailed {
130                    plugin_id: self.id.clone(),
131                    message: format!("expected Wasm entry point, got: {other:?}"),
132                });
133            },
134        };
135
136        // If path is relative, resolve relative to workspace root
137        let resolved_path = if wasm_path.is_absolute() {
138            wasm_path
139        } else {
140            ctx.workspace_root.join(&wasm_path)
141        };
142
143        // 2. Read WASM bytes
144        let wasm_bytes = std::fs::read(&resolved_path).map_err(|e| PluginError::LoadFailed {
145            plugin_id: self.id.clone(),
146            message: format!("failed to read WASM file {}: {e}", resolved_path.display()),
147        })?;
148
149        // 3. Hash verification
150        verify_hash(
151            &wasm_bytes,
152            expected_hash.as_deref(),
153            &self.id,
154            self.config.require_hash,
155        )?;
156
157        // 4. Build HostState
158        let host_state = HostState {
159            plugin_id: self.id.clone(),
160            workspace_root: ctx.workspace_root.clone(),
161            kv: ctx.kv.clone(),
162            config: ctx.config.clone(),
163            security: self.config.security.clone(),
164            runtime_handle: tokio::runtime::Handle::current(),
165        };
166        let user_data = UserData::new(host_state);
167
168        // 5. Build Extism Manifest
169        let extism_wasm = Wasm::data(wasm_bytes);
170        let mut extism_manifest = Manifest::new([extism_wasm]);
171        extism_manifest = extism_manifest.with_timeout(self.config.max_execution_time);
172        // WASM pages are 64KB each; cap at u32::MAX pages if the byte limit is very large
173        let pages = self.config.max_memory_bytes / (64 * 1024);
174        let max_pages = u32::try_from(pages).unwrap_or(u32::MAX);
175        extism_manifest = extism_manifest.with_memory_max(max_pages);
176
177        // 6. Build Extism Plugin
178        let builder = PluginBuilder::new(extism_manifest).with_wasi(true);
179        let builder = register_host_functions(builder, user_data);
180        let mut plugin = builder
181            .build()
182            .map_err(|e| PluginError::WasmError(format!("failed to build Extism plugin: {e}")))?;
183
184        // 7. Discover tools via `describe-tools` export
185        let tools = discover_tools(&mut plugin)?;
186        let plugin_arc = Arc::new(Mutex::new(plugin));
187
188        let wasm_tools: Vec<Box<dyn PluginTool>> = tools
189            .into_iter()
190            .map(|td| {
191                let schema: serde_json::Value =
192                    serde_json::from_str(&td.input_schema).unwrap_or(serde_json::json!({}));
193                Box::new(WasmPluginTool::new(
194                    td.name,
195                    td.description,
196                    schema,
197                    Arc::clone(&plugin_arc),
198                )) as Box<dyn PluginTool>
199            })
200            .collect();
201
202        self.extism_plugin = Some(plugin_arc);
203        self.tools = wasm_tools;
204
205        Ok(())
206    }
207}
208
209/// Verify WASM module hash if an expected hash is provided.
210///
211/// If `require_hash` is true and no hash is specified in the manifest,
212/// loading is rejected. This enforces hash verification in production.
213fn verify_hash(
214    wasm_bytes: &[u8],
215    expected: Option<&str>,
216    plugin_id: &PluginId,
217    require_hash: bool,
218) -> PluginResult<()> {
219    match expected {
220        Some(expected_hex) => {
221            let actual_hex = blake3::hash(wasm_bytes).to_hex().to_string();
222            if actual_hex != expected_hex {
223                return Err(PluginError::HashMismatch {
224                    expected: expected_hex.to_string(),
225                    actual: actual_hex,
226                });
227            }
228            tracing::debug!(plugin = %plugin_id, "WASM module hash verified");
229        },
230        None if require_hash => {
231            return Err(PluginError::LoadFailed {
232                plugin_id: plugin_id.clone(),
233                message: "WASM module hash required but not specified in manifest".into(),
234            });
235        },
236        None => {
237            tracing::warn!(
238                plugin = %plugin_id,
239                "WASM module hash not specified — module integrity not verified"
240            );
241        },
242    }
243    Ok(())
244}
245
246/// Call the guest's `describe-tools` export and parse the result.
247fn discover_tools(plugin: &mut extism::Plugin) -> PluginResult<Vec<ToolDefinition>> {
248    // describe-tools takes no input (empty string) and returns JSON array
249    let result = plugin
250        .call::<&str, String>("describe-tools", "")
251        .map_err(|e| PluginError::WasmError(format!("describe-tools call failed: {e}")))?;
252
253    let definitions: Vec<ToolDefinition> = serde_json::from_str(&result).map_err(|e| {
254        PluginError::WasmError(format!("failed to parse describe-tools output: {e}"))
255    })?;
256
257    Ok(definitions)
258}
259
260impl std::fmt::Debug for WasmPlugin {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        f.debug_struct("WasmPlugin")
263            .field("id", &self.id)
264            .field("state", &self.state)
265            .field("tool_count", &self.tools.len())
266            .finish_non_exhaustive()
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A well-formed 64-character hex string that is not the payload's digest.
    const ZERO_HASH: &str = "0000000000000000000000000000000000000000000000000000000000000000";

    /// Bytes hashed by the `verify_hash` tests.
    const PAYLOAD: &[u8] = b"hello world";

    /// Plugin id shared by every test.
    fn test_id() -> PluginId {
        PluginId::from_static("test")
    }

    /// Manifest with a relative WASM entry point and no declared hash.
    fn unhashed_manifest() -> PluginManifest {
        PluginManifest {
            id: test_id(),
            name: "Test".into(),
            version: "0.1.0".into(),
            description: None,
            author: None,
            entry_point: PluginEntryPoint::Wasm {
                path: "plugin.wasm".into(),
                hash: None,
            },
            capabilities: vec![],
            config: HashMap::new(),
        }
    }

    /// Loader settings without a security gate or hash requirement.
    fn permissive_config() -> WasmPluginConfig {
        WasmPluginConfig {
            security: None,
            max_memory_bytes: 64 * 1024 * 1024,
            max_execution_time: Duration::from_secs(30),
            require_hash: false,
        }
    }

    #[test]
    fn hash_verification_match() {
        let digest = blake3::hash(PAYLOAD).to_hex().to_string();
        assert!(verify_hash(PAYLOAD, Some(digest.as_str()), &test_id(), false).is_ok());
    }

    #[test]
    fn hash_verification_mismatch() {
        let outcome = verify_hash(PAYLOAD, Some(ZERO_HASH), &test_id(), false);
        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            PluginError::HashMismatch { expected, actual } => {
                assert_eq!(expected, ZERO_HASH);
                assert!(!actual.is_empty());
            },
            unexpected => panic!("expected HashMismatch, got: {unexpected:?}"),
        }
    }

    #[test]
    fn hash_verification_none_is_ok() {
        assert!(verify_hash(PAYLOAD, None, &test_id(), false).is_ok());
    }

    #[test]
    fn hash_verification_none_rejected_when_required() {
        let outcome = verify_hash(PAYLOAD, None, &test_id(), true);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            PluginError::LoadFailed { .. }
        ));
    }

    #[test]
    fn wasm_plugin_starts_unloaded() {
        let plugin = WasmPlugin::new(unhashed_manifest(), permissive_config());
        assert_eq!(plugin.state(), PluginState::Unloaded);
        assert!(plugin.tools().is_empty());
    }
}