1#![allow(
7 clippy::cast_sign_loss,
8 clippy::cast_possible_truncation,
9 clippy::cast_possible_wrap
10)]
11
12use std::path::Path;
13
14use async_trait::async_trait;
15use wasmtime::{Config, Engine, Instance, Linker, Module, Store, TypedFunc};
16
17use crate::api::{Plugin, PluginError, PluginHook};
18
/// Descriptive metadata for a WASM plugin.
///
/// `PartialEq`/`Eq` are derived so metadata can be compared directly (e.g. in tests or
/// when de-duplicating plugins).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WasmPluginMetadata {
    /// Plugin name. The loaders seed it from the file stem or a caller-supplied name;
    /// the plugin may replace it during initialization.
    pub name: String,
    /// Plugin version string. The loaders seed it with `"0.0.0"`; the plugin may
    /// replace it during initialization.
    pub version: String,
    /// Hook names. NOTE(review): the loaders always create this empty — confirm whether
    /// anything populates it.
    pub hooks: Vec<String>,
}
29
30pub struct PluginState {
32 pub metadata: WasmPluginMetadata,
34 result_buffer: Vec<u8>,
36}
37
38impl PluginState {
39 fn new(metadata: WasmPluginMetadata) -> Self {
40 Self {
41 metadata,
42 result_buffer: Vec::with_capacity(4096),
43 }
44 }
45}
46
47pub struct WasmPlugin {
49 engine: Engine,
50 module: Module,
51 instance: Instance,
52 store: Store<PluginState>,
53 metadata: WasmPluginMetadata,
54}
55
56impl WasmPlugin {
57 pub fn load(path: &Path) -> Result<Self, PluginError> {
63 let mut config = Config::new();
64 config.wasm_backtrace_details(wasmtime::WasmBacktraceDetails::Enable);
65
66 let engine =
67 Engine::new(&config).map_err(|e| PluginError::LoadFailed(format!("Engine: {e}")))?;
68
69 let module = Module::from_file(&engine, path)
70 .map_err(|e| PluginError::LoadFailed(format!("Module load: {e}")))?;
71
72 let metadata = WasmPluginMetadata {
74 name: path
75 .file_stem()
76 .and_then(|s| s.to_str())
77 .unwrap_or("unknown")
78 .to_string(),
79 version: "0.0.0".to_string(),
80 hooks: Vec::new(),
81 };
82
83 let mut store = Store::new(&engine, PluginState::new(metadata.clone()));
84
85 let mut linker = Linker::new(&engine);
87 Self::define_host_functions(&mut linker)?;
88
89 let instance = linker
91 .instantiate(&mut store, &module)
92 .map_err(|e| PluginError::LoadFailed(format!("Instantiate: {e}")))?;
93
94 let mut plugin = Self {
95 engine,
96 module,
97 instance,
98 store,
99 metadata,
100 };
101
102 plugin.init()?;
104
105 Ok(plugin)
106 }
107
108 pub fn load_bytes(name: &str, bytes: &[u8]) -> Result<Self, PluginError> {
114 let mut config = Config::new();
115 config.wasm_backtrace_details(wasmtime::WasmBacktraceDetails::Enable);
116
117 let engine =
118 Engine::new(&config).map_err(|e| PluginError::LoadFailed(format!("Engine: {e}")))?;
119
120 let module = Module::new(&engine, bytes)
121 .map_err(|e| PluginError::LoadFailed(format!("Module load: {e}")))?;
122
123 let metadata = WasmPluginMetadata {
124 name: name.to_string(),
125 version: "0.0.0".to_string(),
126 hooks: Vec::new(),
127 };
128
129 let mut store = Store::new(&engine, PluginState::new(metadata.clone()));
130
131 let mut linker = Linker::new(&engine);
132 Self::define_host_functions(&mut linker)?;
133
134 let instance = linker
135 .instantiate(&mut store, &module)
136 .map_err(|e| PluginError::LoadFailed(format!("Instantiate: {e}")))?;
137
138 let mut plugin = Self {
139 engine,
140 module,
141 instance,
142 store,
143 metadata,
144 };
145
146 plugin.init()?;
147
148 Ok(plugin)
149 }
150
151 fn define_host_functions(linker: &mut Linker<PluginState>) -> Result<(), PluginError> {
153 linker
155 .func_wrap(
156 "env",
157 "plugin_log",
158 |mut caller: wasmtime::Caller<'_, PluginState>,
159 level: i32,
160 msg_ptr: i32,
161 msg_len: i32| {
162 let Some(wasmtime::Extern::Memory(memory)) = caller.get_export("memory") else {
163 return;
164 };
165
166 let data = memory.data(&caller);
167 let start = msg_ptr as usize;
168 let end = start + msg_len as usize;
169 if end > data.len() {
170 return;
171 }
172
173 let msg = String::from_utf8_lossy(&data[start..end]);
174 match level {
175 0 => tracing::trace!(plugin = "wasm", "{msg}"),
176 1 => tracing::debug!(plugin = "wasm", "{msg}"),
177 2 => tracing::info!(plugin = "wasm", "{msg}"),
178 3 => tracing::warn!(plugin = "wasm", "{msg}"),
179 _ => tracing::error!(plugin = "wasm", "{msg}"),
180 }
181 },
182 )
183 .map_err(|e| PluginError::LoadFailed(format!("Link plugin_log: {e}")))?;
184
185 linker
187 .func_wrap(
188 "env",
189 "plugin_get_config",
190 |_caller: wasmtime::Caller<'_, PluginState>, _key_ptr: i32, _key_len: i32| -> i32 {
191 0
194 },
195 )
196 .map_err(|e| PluginError::LoadFailed(format!("Link plugin_get_config: {e}")))?;
197
198 linker
200 .func_wrap(
201 "env",
202 "plugin_set_result",
203 |mut caller: wasmtime::Caller<'_, PluginState>, ptr: i32, len: i32| {
204 let Some(wasmtime::Extern::Memory(memory)) = caller.get_export("memory") else {
205 return;
206 };
207
208 let data = memory.data(&caller);
209 let start = ptr as usize;
210 let end = start + len as usize;
211 if end > data.len() {
212 return;
213 }
214
215 let result_data = data[start..end].to_vec();
216 caller.data_mut().result_buffer = result_data;
217 },
218 )
219 .map_err(|e| PluginError::LoadFailed(format!("Link plugin_set_result: {e}")))?;
220
221 Ok(())
222 }
223
224 fn init(&mut self) -> Result<(), PluginError> {
226 let init_fn: Option<TypedFunc<(), i32>> = self
228 .instance
229 .get_typed_func::<(), i32>(&mut self.store, "plugin_init")
230 .ok();
231
232 if let Some(init) = init_fn {
233 let result = init
234 .call(&mut self.store, ())
235 .map_err(|e| PluginError::ExecutionError(format!("Init failed: {e}")))?;
236
237 if result != 0 {
238 return Err(PluginError::ExecutionError(format!(
239 "Plugin init returned error code: {result}"
240 )));
241 }
242 }
243
244 if let Ok(get_name) = self
246 .instance
247 .get_typed_func::<(), i32>(&mut self.store, "plugin_get_name")
248 {
249 let _ = get_name.call(&mut self.store, ());
250 if !self.store.data().result_buffer.is_empty() {
251 if let Ok(name) = String::from_utf8(self.store.data().result_buffer.clone()) {
252 self.metadata.name = name;
253 }
254 self.store.data_mut().result_buffer.clear();
255 }
256 }
257
258 if let Ok(get_version) = self
259 .instance
260 .get_typed_func::<(), i32>(&mut self.store, "plugin_get_version")
261 {
262 let _ = get_version.call(&mut self.store, ());
263 if !self.store.data().result_buffer.is_empty() {
264 if let Ok(version) = String::from_utf8(self.store.data().result_buffer.clone()) {
265 self.metadata.version = version;
266 }
267 self.store.data_mut().result_buffer.clear();
268 }
269 }
270
271 tracing::info!(
272 name = %self.metadata.name,
273 version = %self.metadata.version,
274 "WASM plugin loaded"
275 );
276
277 Ok(())
278 }
279
280 pub fn call_export(&mut self, method: &str, params: &[u8]) -> Result<Vec<u8>, PluginError> {
286 let memory = self
288 .instance
289 .get_memory(&mut self.store, "memory")
290 .ok_or_else(|| PluginError::ExecutionError("No memory export".to_string()))?;
291
292 let alloc_fn: TypedFunc<i32, i32> = self
294 .instance
295 .get_typed_func(&mut self.store, "plugin_alloc")
296 .map_err(|e| PluginError::ExecutionError(format!("No alloc function: {e}")))?;
297
298 let params_ptr = alloc_fn
300 .call(&mut self.store, params.len() as i32)
301 .map_err(|e| PluginError::ExecutionError(format!("Alloc failed: {e}")))?;
302
303 memory
305 .write(&mut self.store, params_ptr as usize, params)
306 .map_err(|e| PluginError::ExecutionError(format!("Memory write failed: {e}")))?;
307
308 let export_fn: TypedFunc<(i32, i32), i32> = self
310 .instance
311 .get_typed_func(&mut self.store, method)
312 .map_err(|e| PluginError::ExecutionError(format!("Export not found: {e}")))?;
313
314 self.store.data_mut().result_buffer.clear();
316
317 let result = export_fn
319 .call(&mut self.store, (params_ptr, params.len() as i32))
320 .map_err(|e| PluginError::ExecutionError(format!("Call failed: {e}")))?;
321
322 if result != 0 {
323 return Err(PluginError::ExecutionError(format!(
324 "Export returned error code: {result}"
325 )));
326 }
327
328 Ok(self.store.data().result_buffer.clone())
330 }
331
332 #[must_use]
334 pub const fn metadata(&self) -> &WasmPluginMetadata {
335 &self.metadata
336 }
337}
338
339#[async_trait]
340impl Plugin for WasmPlugin {
341 fn id(&self) -> &str {
342 &self.metadata.name
343 }
344
345 fn name(&self) -> &str {
346 &self.metadata.name
347 }
348
349 fn version(&self) -> &str {
350 &self.metadata.version
351 }
352
353 fn hooks(&self) -> &[PluginHook] {
354 &[
356 PluginHook::BeforeMessage,
357 PluginHook::AfterMessage,
358 PluginHook::BeforeToolCall,
359 PluginHook::AfterToolCall,
360 PluginHook::SessionStart,
361 PluginHook::SessionEnd,
362 PluginHook::AgentResponse,
363 PluginHook::Error,
364 ]
365 }
366
367 async fn execute_hook(
368 &self,
369 hook: PluginHook,
370 data: serde_json::Value,
371 ) -> Result<serde_json::Value, PluginError> {
372 let _ = hook;
376 Ok(data)
377 }
378
379 async fn activate(&self) -> Result<(), PluginError> {
380 Ok(())
381 }
382
383 async fn deactivate(&self) -> Result<(), PluginError> {
384 Ok(())
385 }
386}
387
388pub struct WasmPluginManager {
390 plugins: Vec<WasmPlugin>,
391}
392
393impl WasmPluginManager {
394 #[must_use]
396 pub const fn new() -> Self {
397 Self {
398 plugins: Vec::new(),
399 }
400 }
401
402 pub fn load(&mut self, path: &Path) -> Result<(), PluginError> {
408 let plugin = WasmPlugin::load(path)?;
409 self.plugins.push(plugin);
410 Ok(())
411 }
412
413 pub fn load_dir(&mut self, dir: &Path) -> Result<usize, PluginError> {
419 let entries = std::fs::read_dir(dir)
420 .map_err(|e| PluginError::LoadFailed(format!("Read dir: {e}")))?;
421
422 let mut loaded = 0;
423 for entry in entries.flatten() {
424 let path = entry.path();
425 if path.extension().is_some_and(|ext| ext == "wasm") {
426 match WasmPlugin::load(&path) {
427 Ok(plugin) => {
428 tracing::info!(path = %path.display(), "Loaded WASM plugin");
429 self.plugins.push(plugin);
430 loaded += 1;
431 }
432 Err(e) => {
433 tracing::warn!(path = %path.display(), error = %e, "Failed to load WASM plugin");
434 }
435 }
436 }
437 }
438
439 Ok(loaded)
440 }
441
442 #[must_use]
444 pub fn plugins(&self) -> &[WasmPlugin] {
445 &self.plugins
446 }
447
448 pub fn plugins_mut(&mut self) -> &mut [WasmPlugin] {
450 &mut self.plugins
451 }
452}
453
454impl Default for WasmPluginManager {
455 fn default() -> Self {
456 Self::new()
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest valid module: the `\0asm` magic followed by binary format version 1.
    const EMPTY_MODULE: &[u8] = b"\0asm\x01\0\0\0";

    #[test]
    fn test_manager_creation() {
        let manager = WasmPluginManager::new();
        assert!(manager.plugins().is_empty());
    }

    #[test]
    fn test_manager_default_is_empty() {
        assert!(WasmPluginManager::default().plugins().is_empty());
    }

    #[test]
    fn test_load_dir_missing_directory_fails() {
        let mut manager = WasmPluginManager::new();
        let missing = std::env::temp_dir().join(format!(
            "wasm-plugin-manager-missing-{}",
            std::process::id()
        ));
        assert!(manager.load_dir(&missing).is_err());
        assert!(manager.plugins().is_empty());
    }

    #[test]
    fn test_load_bytes_rejects_invalid_module() {
        assert!(WasmPlugin::load_bytes("bad", b"not wasm").is_err());
    }

    #[test]
    fn test_load_bytes_empty_module_uses_defaults() {
        let Ok(plugin) = WasmPlugin::load_bytes("empty", EMPTY_MODULE) else {
            panic!("empty module should load");
        };
        assert_eq!(plugin.metadata().name, "empty");
        assert_eq!(plugin.metadata().version, "0.0.0");
    }

    #[test]
    fn test_call_export_without_memory_fails() {
        let Ok(mut plugin) = WasmPlugin::load_bytes("empty", EMPTY_MODULE) else {
            panic!("empty module should load");
        };
        assert!(plugin.call_export("anything", b"{}").is_err());
    }
}