astrid_plugins/wasm/
tool.rs1use std::sync::{Arc, Mutex};
4
5use async_trait::async_trait;
6use serde_json::Value;
7
8use astrid_core::plugin_abi::{ToolInput, ToolOutput};
9
10use crate::context::PluginToolContext;
11use crate::error::{PluginError, PluginResult};
12use crate::tool::PluginTool;
13
14pub struct WasmPluginTool {
19 name: String,
21 description: String,
23 input_schema: Value,
25 plugin: Arc<Mutex<extism::Plugin>>,
27}
28
29impl WasmPluginTool {
30 pub(crate) fn new(
32 name: String,
33 description: String,
34 input_schema: Value,
35 plugin: Arc<Mutex<extism::Plugin>>,
36 ) -> Self {
37 Self {
38 name,
39 description,
40 input_schema,
41 plugin,
42 }
43 }
44}
45
46#[async_trait]
47impl PluginTool for WasmPluginTool {
48 fn name(&self) -> &str {
49 &self.name
50 }
51
52 fn description(&self) -> &str {
53 &self.description
54 }
55
56 fn input_schema(&self) -> Value {
57 self.input_schema.clone()
58 }
59
60 async fn execute(&self, args: Value, _ctx: &PluginToolContext) -> PluginResult<String> {
61 let tool_input = ToolInput {
62 name: self.name.clone(),
63 arguments: serde_json::to_string(&args).map_err(|e| {
64 PluginError::ExecutionFailed(format!("failed to serialize args: {e}"))
65 })?,
66 };
67
68 let input_json = serde_json::to_string(&tool_input).map_err(|e| {
69 PluginError::ExecutionFailed(format!("failed to serialize ToolInput: {e}"))
70 })?;
71
72 let result = tokio::task::block_in_place(|| {
74 let mut plugin = self
75 .plugin
76 .lock()
77 .map_err(|e| PluginError::WasmError(format!("plugin lock poisoned: {e}")))?;
78 plugin
79 .call::<&str, String>("execute-tool", &input_json)
80 .map_err(|e| PluginError::WasmError(format!("execute-tool call failed: {e}")))
81 })?;
82
83 let output: ToolOutput = serde_json::from_str(&result).map_err(|e| {
85 PluginError::ExecutionFailed(format!("failed to parse ToolOutput: {e}"))
86 })?;
87
88 if output.is_error {
89 Err(PluginError::ExecutionFailed(output.content))
90 } else {
91 Ok(output.content)
92 }
93 }
94}
95
96impl std::fmt::Debug for WasmPluginTool {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("WasmPluginTool")
99 .field("name", &self.name)
100 .field("description", &self.description)
101 .finish_non_exhaustive()
102 }
103}
104
105