1use std::sync::Arc;
24
25use wasmtime::{Engine, Instance, Linker, Memory, Module, Store, TypedFunc};
26
27use crate::tools::{ToolDefinition, ToolImpl, ToolRegistry, ToolResult, ToolResultValue};
28
29#[derive(Debug, thiserror::Error)]
35pub enum PluginError {
36 #[error("Failed to compile WASM module: {0}")]
38 Compile(String),
39
40 #[error("Failed to instantiate WASM module: {0}")]
42 Instantiate(String),
43
44 #[error("Missing required export '{export}' in plugin")]
46 MissingExport { export: String },
47
48 #[error("Export '{export}' has wrong type: {details}")]
50 WrongExportType { export: String, details: String },
51
52 #[error("Plugin tool '{tool}' execution failed: {details}")]
54 ToolExecution { tool: String, details: String },
55
56 #[error("Plugin version '{version}' is not compatible with host ABI v1")]
58 VersionMismatch { version: String },
59
60 #[error(transparent)]
62 Other(#[from] anyhow::Error),
63}
64
/// Metadata for one tool advertised by a WASM plugin.
#[derive(Clone, Debug)]
pub struct PluginTool {
    /// Name returned by the plugin's `plugin_tool_name` export.
    pub name: String,
    /// Description returned by the plugin's `plugin_tool_description` export.
    pub description: String,
    /// Index passed back to the plugin when this tool is queried or executed.
    pub index: u32,
}
79
80pub struct WasmPlugin {
86 pub name: String,
88 pub version: String,
90 pub description: String,
92 pub tools: Vec<PluginTool>,
94 store: Store<()>,
96 memory: Memory,
98 tool_execute: TypedFunc<(i32, i32), i32>,
100}
101
102impl std::fmt::Debug for WasmPlugin {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("WasmPlugin")
105 .field("name", &self.name)
106 .field("version", &self.version)
107 .field("description", &self.description)
108 .field("tools", &self.tools)
109 .finish_non_exhaustive()
110 }
111}
112
113impl WasmPlugin {
114 pub fn load(path: &std::path::Path) -> Result<Self, PluginError> {
116 let wasm_bytes = std::fs::read(path).map_err(|e| {
117 PluginError::Other(anyhow::anyhow!("Failed to read {}: {}", path.display(), e))
118 })?;
119 Self::load_from_bytes(&wasm_bytes)
120 }
121
122 pub fn load_from_bytes(wasm_bytes: &[u8]) -> Result<Self, PluginError> {
124 let engine = Engine::default();
125 let module =
126 Module::new(&engine, wasm_bytes).map_err(|e| PluginError::Compile(e.to_string()))?;
127
128 let mut store = Store::new(&engine, ());
129 let linker = Linker::new(&engine);
130 let instance = linker
131 .instantiate(&mut store, &module)
132 .map_err(|e| PluginError::Instantiate(e.to_string()))?;
133
134 let name = Self::read_export_string(&instance, &mut store, "plugin_name")?;
136 let version = Self::read_export_string(&instance, &mut store, "plugin_version")?;
137 let description = Self::read_export_string(&instance, &mut store, "plugin_description")?;
138
139 if version.is_empty() {
141 return Err(PluginError::VersionMismatch {
142 version: version.clone(),
143 });
144 }
145
146 let tools_count: TypedFunc<(), i32> = instance
148 .get_export(&mut store, "plugin_tools_count")
149 .and_then(|e| e.into_func())
150 .ok_or_else(|| PluginError::MissingExport {
151 export: "plugin_tools_count".into(),
152 })?
153 .typed(&store)
154 .map_err(|e| PluginError::WrongExportType {
155 export: "plugin_tools_count".into(),
156 details: e.to_string(),
157 })?;
158 let count: u32 =
159 tools_count
160 .call(&mut store, ())
161 .map_err(|e| PluginError::ToolExecution {
162 tool: "plugin_tools_count".into(),
163 details: e.to_string(),
164 })? as u32;
165
166 let tool_name_fn: TypedFunc<i32, i32> = instance
168 .get_export(&mut store, "plugin_tool_name")
169 .and_then(|e| e.into_func())
170 .ok_or_else(|| PluginError::MissingExport {
171 export: "plugin_tool_name".into(),
172 })?
173 .typed(&store)
174 .map_err(|e| PluginError::WrongExportType {
175 export: "plugin_tool_name".into(),
176 details: e.to_string(),
177 })?;
178
179 let tool_desc_fn: TypedFunc<i32, i32> = instance
180 .get_export(&mut store, "plugin_tool_description")
181 .and_then(|e| e.into_func())
182 .ok_or_else(|| PluginError::MissingExport {
183 export: "plugin_tool_description".into(),
184 })?
185 .typed(&store)
186 .map_err(|e| PluginError::WrongExportType {
187 export: "plugin_tool_description".into(),
188 details: e.to_string(),
189 })?;
190
191 let tool_execute: TypedFunc<(i32, i32), i32> = instance
193 .get_export(&mut store, "plugin_tool_execute")
194 .and_then(|e| e.into_func())
195 .ok_or_else(|| PluginError::MissingExport {
196 export: "plugin_tool_execute".into(),
197 })?
198 .typed(&store)
199 .map_err(|e| PluginError::WrongExportType {
200 export: "plugin_tool_execute".into(),
201 details: e.to_string(),
202 })?;
203
204 let memory: Memory = instance
206 .get_export(&mut store, "memory")
207 .and_then(|e| e.into_memory())
208 .ok_or_else(|| PluginError::MissingExport {
209 export: "memory".into(),
210 })?;
211
212 let mut tools = Vec::with_capacity(count as usize);
214 for i in 0..count {
215 let idx = i as i32;
216 let name_ptr =
217 tool_name_fn
218 .call(&mut store, idx)
219 .map_err(|e| PluginError::ToolExecution {
220 tool: "plugin_tool_name".into(),
221 details: e.to_string(),
222 })?;
223 let desc_ptr =
224 tool_desc_fn
225 .call(&mut store, idx)
226 .map_err(|e| PluginError::ToolExecution {
227 tool: "plugin_tool_description".into(),
228 details: e.to_string(),
229 })?;
230
231 let t_name = Self::read_string_from_memory(&memory, &store, name_ptr as u32);
232 let t_desc = Self::read_string_from_memory(&memory, &store, desc_ptr as u32);
233
234 tools.push(PluginTool {
235 name: t_name,
236 description: t_desc,
237 index: i,
238 });
239 }
240
241 Ok(Self {
242 name,
243 version,
244 description,
245 tools,
246 store,
247 memory,
248 tool_execute,
249 })
250 }
251
252 pub fn execute_tool(&mut self, tool_index: u32, input: &str) -> Result<String, PluginError> {
255 let input_offset = self.reserve_memory(input.len())?;
257 self.memory
258 .write(&mut self.store, input_offset, input.as_bytes())
259 .map_err(|e| PluginError::ToolExecution {
260 tool: self.tools[tool_index as usize].name.clone(),
261 details: format!("Failed to write input to plugin memory: {e}"),
262 })?;
263
264 let result_ptr = self
265 .tool_execute
266 .call(&mut self.store, (tool_index as i32, input_offset as i32))
267 .map_err(|e| PluginError::ToolExecution {
268 tool: self.tools[tool_index as usize].name.clone(),
269 details: e.to_string(),
270 })?;
271
272 let output = Self::read_string_from_memory(&self.memory, &self.store, result_ptr as u32);
273 Ok(output)
274 }
275
276 pub fn register_into_registry(&mut self, registry: &mut ToolRegistry) {
278 let plugin_name = self.name.clone();
279 for i in 0..self.tools.len() as u32 {
280 let tool = self.tools[i as usize].clone();
281 let wrapper = WasmPluginToolWrapper {
282 plugin_name: plugin_name.clone(),
283 tool_index: i,
284 tool_name: tool.name.clone(),
285 tool_description: tool.description.clone(),
286 };
287 registry.register(Arc::new(wrapper));
288 }
289 }
290
291 fn read_string_from_memory(memory: &Memory, store: &Store<()>, offset: u32) -> String {
295 let data = memory.data(store);
296 let mut end = offset as usize;
297 while end < data.len() && data[end] != 0 {
298 end += 1;
299 }
300 String::from_utf8_lossy(&data[offset as usize..end]).to_string()
301 }
302
303 fn read_export_string(
306 instance: &Instance,
307 store: &mut Store<()>,
308 export_name: &str,
309 ) -> Result<String, PluginError> {
310 let func: TypedFunc<(), i32> = instance
311 .get_export(&mut *store, export_name)
312 .and_then(|e| e.into_func())
313 .ok_or_else(|| PluginError::MissingExport {
314 export: export_name.into(),
315 })?
316 .typed(&*store)
317 .map_err(|e| PluginError::WrongExportType {
318 export: export_name.into(),
319 details: e.to_string(),
320 })?;
321
322 let ptr = func
323 .call(&mut *store, ())
324 .map_err(|e| PluginError::ToolExecution {
325 tool: export_name.into(),
326 details: e.to_string(),
327 })?;
328
329 let memory: Memory = instance
331 .get_export(&mut *store, "memory")
332 .and_then(|e| e.into_memory())
333 .ok_or_else(|| PluginError::MissingExport {
334 export: "memory".into(),
335 })?;
336
337 Ok(Self::read_string_from_memory(&memory, &*store, ptr as u32))
338 }
339
340 fn reserve_memory(&mut self, size: usize) -> Result<usize, PluginError> {
343 let current_size = self.memory.data_size(&self.store);
344 let needed = current_size + size;
345 let pages_needed = needed.div_ceil(65536);
346 let current_pages = self.memory.size(&self.store) as usize;
347 if pages_needed > current_pages {
348 let delta = (pages_needed - current_pages) as u64;
349 self.memory
350 .grow(&mut self.store, delta)
351 .map_err(PluginError::Other)?;
352 }
353 Ok(current_size)
354 }
355}
356
357struct WasmPluginToolWrapper {
362 plugin_name: String,
363 tool_index: u32,
364 tool_name: String,
365 #[allow(dead_code)]
366 tool_description: String,
367}
368
369impl std::fmt::Debug for WasmPluginToolWrapper {
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 f.debug_struct("WasmPluginToolWrapper")
372 .field("plugin_name", &self.plugin_name)
373 .field("tool_index", &self.tool_index)
374 .field("tool_name", &self.tool_name)
375 .finish()
376 }
377}
378
379#[async_trait::async_trait]
380impl ToolImpl for WasmPluginToolWrapper {
381 fn definition(&self) -> &ToolDefinition {
382 use std::sync::OnceLock;
385 static DEF: OnceLock<ToolDefinition> = OnceLock::new();
386 DEF.get_or_init(|| ToolDefinition {
387 name: self.tool_name.clone(),
388 description: self.tool_description.clone(),
389 parameters: crate::tools::JsonSchema::object(
390 std::collections::HashMap::new(),
391 Vec::new(),
392 ),
393 requires_approval: false,
394 category: crate::tools::ToolCategory::General,
395 })
396 }
397
398 async fn execute(&self, _args: serde_json::Value) -> ToolResultValue<ToolResult> {
399 Err(crate::tools::ToolError::ExecutionFailed(
403 self.tool_name.clone(),
404 "WASM plugin tool requires re-instantiation — not yet implemented".into(),
405 ))
406 }
407
408 fn name(&self) -> &str {
409 &self.tool_name
410 }
411}
412
413pub struct WasmPluginManager {
419 plugins: Vec<WasmPlugin>,
420}
421
422impl WasmPluginManager {
423 pub fn new() -> Self {
425 Self {
426 plugins: Vec::new(),
427 }
428 }
429
430 pub fn load(&mut self, path: &std::path::Path) -> Result<&WasmPlugin, PluginError> {
432 let plugin = WasmPlugin::load(path)?;
433 self.plugins.push(plugin);
434 Ok(self.plugins.last().unwrap())
435 }
436
437 pub fn load_from_bytes(&mut self, bytes: &[u8]) -> Result<&WasmPlugin, PluginError> {
439 let plugin = WasmPlugin::load_from_bytes(bytes)?;
440 self.plugins.push(plugin);
441 Ok(self.plugins.last().unwrap())
442 }
443
444 pub fn get(&self, name: &str) -> Option<&WasmPlugin> {
446 self.plugins.iter().find(|p| p.name == name)
447 }
448
449 pub fn get_mut(&mut self, name: &str) -> Option<&mut WasmPlugin> {
451 self.plugins.iter_mut().find(|p| p.name == name)
452 }
453
454 pub fn unload(&mut self, name: &str) -> bool {
456 let idx = self.plugins.iter().position(|p| p.name == name);
457 if let Some(i) = idx {
458 self.plugins.remove(i);
459 true
460 } else {
461 false
462 }
463 }
464
465 pub fn iter(&self) -> impl Iterator<Item = &WasmPlugin> {
467 self.plugins.iter()
468 }
469
470 pub fn register_all(&mut self, registry: &mut ToolRegistry) {
472 for plugin in self.plugins.iter_mut() {
473 plugin.register_into_registry(registry);
474 }
475 }
476
477 pub fn len(&self) -> usize {
479 self.plugins.len()
480 }
481
482 pub fn is_empty(&self) -> bool {
484 self.plugins.is_empty()
485 }
486}
487
488impl Default for WasmPluginManager {
489 fn default() -> Self {
490 Self::new()
491 }
492}
493
494impl std::fmt::Debug for WasmPluginManager {
495 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496 f.debug_struct("WasmPluginManager")
497 .field("plugin_count", &self.plugins.len())
498 .finish_non_exhaustive()
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// One-tool plugin. `plugin_tool_execute` always returns offset 0,
    /// i.e. the plugin-name string.
    const HELLO_PLUGIN_WAT: &str = r#"
        (module
          (memory (export "memory") 1 256)
          (data (i32.const 0) "Hello from WASM!\00")
          (data (i32.const 32) "0.1.0\00")
          (data (i32.const 64) "A friendly WASM plugin\00")
          (data (i32.const 128) "hello\00")
          (data (i32.const 160) "Returns a friendly greeting\00")
          (func (export "plugin_name") (result i32) i32.const 0)
          (func (export "plugin_version") (result i32) i32.const 32)
          (func (export "plugin_description") (result i32) i32.const 64)
          (func (export "plugin_tools_count") (result i32) i32.const 1)
          (func (export "plugin_tool_name") (param i32) (result i32)
            i32.const 128)
          (func (export "plugin_tool_description") (param i32) (result i32)
            i32.const 160)
          (func (export "plugin_tool_execute") (param i32 i32) (result i32)
            i32.const 0)
        )
    "#;

    /// Two-tool plugin. A non-zero index selects "echo" and index 0 selects
    /// "hello". Execute returns the input pointer, so output equals input.
    const ECHO_PLUGIN_WAT: &str = r#"
        (module
          (memory (export "memory") 1 256)
          (data (i32.const 0) "Echo Plugin\00")
          (data (i32.const 32) "0.1.0\00")
          (data (i32.const 64) "A plugin that echoes input\00")
          (data (i32.const 128) "echo\00")
          (data (i32.const 160) "Echoes back the input string\00")
          (data (i32.const 192) "hello\00")
          (data (i32.const 224) "Returns a friendly greeting\00")
          (func (export "plugin_name") (result i32) i32.const 0)
          (func (export "plugin_version") (result i32) i32.const 32)
          (func (export "plugin_description") (result i32) i32.const 64)
          (func (export "plugin_tools_count") (result i32) i32.const 2)
          (func (export "plugin_tool_name") (param i32) (result i32)
            local.get 0
            if (result i32)
              i32.const 128
            else
              i32.const 192
            end)
          (func (export "plugin_tool_description") (param i32) (result i32)
            local.get 0
            if (result i32)
              i32.const 160
            else
              i32.const 224
            end)
          (func (export "plugin_tool_execute") (param i32 i32) (result i32)
            local.get 1)
        )
    "#;

    fn compile_wat(wat: &str) -> Vec<u8> {
        wat::parse_str(wat).expect("Failed to parse WAT")
    }

    /// Loads a plugin that is expected to be valid.
    fn load_plugin(wat: &str) -> WasmPlugin {
        WasmPlugin::load_from_bytes(&compile_wat(wat)).expect("Failed to load plugin")
    }

    /// Builds a manager holding exactly the plugin described by `wat`.
    fn manager_with(wat: &str) -> WasmPluginManager {
        let mut manager = WasmPluginManager::new();
        manager.load_from_bytes(&compile_wat(wat)).expect("Failed to load");
        manager
    }

    /// Loads a plugin that is expected to be rejected and returns the error.
    fn load_error(wat: &str) -> PluginError {
        let outcome = WasmPlugin::load_from_bytes(&compile_wat(wat));
        assert!(outcome.is_err());
        outcome.unwrap_err()
    }

    #[test]
    fn test_plugin_load() {
        let plugin = load_plugin(HELLO_PLUGIN_WAT);
        assert_eq!(
            (
                plugin.name.as_str(),
                plugin.version.as_str(),
                plugin.description.as_str()
            ),
            ("Hello from WASM!", "0.1.0", "A friendly WASM plugin")
        );
    }

    #[test]
    fn test_plugin_tools() {
        let plugin = load_plugin(HELLO_PLUGIN_WAT);
        let [tool] = plugin.tools.as_slice() else {
            panic!("expected exactly one tool, got {}", plugin.tools.len());
        };
        assert_eq!(tool.name, "hello");
        assert_eq!(tool.description, "Returns a friendly greeting");
    }

    #[test]
    fn test_plugin_execute_hello() {
        let mut plugin = load_plugin(HELLO_PLUGIN_WAT);
        let output = plugin.execute_tool(0, "").expect("Failed to execute tool");
        assert_eq!(output, "Hello from WASM!");
    }

    #[test]
    fn test_plugin_execute_echo() {
        let mut plugin = load_plugin(ECHO_PLUGIN_WAT);
        let output = plugin
            .execute_tool(1, "Hello world")
            .expect("Failed to execute tool");
        assert_eq!(output, "Hello world");
    }

    #[test]
    fn test_plugin_manager() {
        let manager = manager_with(HELLO_PLUGIN_WAT);
        assert_eq!(manager.len(), 1);
        let found = manager.get("Hello from WASM!").expect("Plugin not found");
        assert_eq!(found.tools.len(), 1);
    }

    #[test]
    fn test_plugin_manager_execute() {
        let mut manager = manager_with(ECHO_PLUGIN_WAT);
        let output = manager
            .get_mut("Echo Plugin")
            .expect("Plugin not found")
            .execute_tool(0, "test echo")
            .expect("Failed to execute");
        assert_eq!(output, "test echo");
    }

    #[test]
    fn test_plugin_unload() {
        let mut manager = manager_with(HELLO_PLUGIN_WAT);
        assert_eq!(manager.len(), 1);
        assert!(manager.unload("Hello from WASM!"));
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_plugin_unload_nonexistent() {
        assert!(!WasmPluginManager::new().unload("nonexistent"));
    }

    #[test]
    fn test_plugin_version_mismatch() {
        let err = load_error(
            r#"
            (module
              (memory (export "memory") 1)
              (data (i32.const 0) "Empty\00")
              (data (i32.const 32) "\00")
              (data (i32.const 64) "No version\00")
              (func (export "plugin_name") (result i32) i32.const 0)
              (func (export "plugin_version") (result i32) i32.const 32)
              (func (export "plugin_description") (result i32) i32.const 64)
              (func (export "plugin_tools_count") (result i32) i32.const 0)
            )
            "#,
        );
        assert!(
            matches!(&err, PluginError::VersionMismatch { .. }),
            "Expected VersionMismatch, got {err}"
        );
    }

    #[test]
    fn test_plugin_missing_export() {
        let err = load_error(
            r#"
            (module
              (memory (export "memory") 1)
            )
            "#,
        );
        assert!(
            matches!(&err, PluginError::MissingExport { .. }),
            "Expected MissingExport, got {err}"
        );
    }

    #[test]
    fn test_plugin_default_manager() {
        let manager = WasmPluginManager::default();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }
}