mcpkit_rs/wasm/
integration.rs1use std::sync::Arc;
7
8use super::{WasmToolExecutor, WasmToolRegistry};
9use crate::{
10 ErrorData,
11 handler::server::ServerHandler,
12 model::{CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams},
13 service::{RequestContext, RoleServer},
14};
15
16#[derive(Clone)]
18pub struct WasmToolHandler {
19 executor: WasmToolExecutor,
21
22 registry: Arc<WasmToolRegistry>,
24
25 #[cfg(feature = "config")]
26 _config: Option<Arc<crate::config::ServerConfig>>,
28}
29
30impl WasmToolHandler {
31 pub fn new(registry: Arc<WasmToolRegistry>) -> Self {
33 let executor = WasmToolExecutor::new(registry.clone());
34 Self {
35 executor,
36 registry,
37 #[cfg(feature = "config")]
38 _config: None,
39 }
40 }
41
42 #[cfg(feature = "config")]
43 pub fn with_config(
45 registry: Arc<WasmToolRegistry>,
46 config: Arc<crate::config::ServerConfig>,
47 ) -> Self {
48 let executor = WasmToolExecutor::with_config(registry.clone(), config.clone());
49 Self {
50 executor,
51 registry,
52 _config: Some(config),
53 }
54 }
55
56 pub fn executor(&self) -> &WasmToolExecutor {
58 &self.executor
59 }
60
61 pub fn registry(&self) -> &Arc<WasmToolRegistry> {
63 &self.registry
64 }
65}
66
67impl ServerHandler for WasmToolHandler {
68 async fn list_tools(
69 &self,
70 _request: Option<PaginatedRequestParams>,
71 _context: RequestContext<RoleServer>,
72 ) -> Result<ListToolsResult, ErrorData> {
73 let tools = self.executor.list_tools();
74 Ok(ListToolsResult {
75 tools,
76 next_cursor: None,
77 meta: None,
78 })
79 }
80
81 async fn call_tool(
82 &self,
83 request: CallToolRequestParams,
84 _context: RequestContext<RoleServer>,
85 ) -> Result<CallToolResult, ErrorData> {
86 let arguments = request.arguments.unwrap_or_default();
87 self.executor.execute(&request.name, arguments).await
88 }
89}
90
91#[derive(Clone)]
93pub struct CompositeToolHandler<H: ServerHandler> {
94 native_handler: H,
96
97 wasm_handler: WasmToolHandler,
99}
100
101impl<H: ServerHandler> CompositeToolHandler<H> {
102 pub fn new(native_handler: H, wasm_registry: Arc<WasmToolRegistry>) -> Self {
104 let wasm_handler = WasmToolHandler::new(wasm_registry);
105 Self {
106 native_handler,
107 wasm_handler,
108 }
109 }
110
111 #[cfg(feature = "config")]
112 pub fn with_config(
114 native_handler: H,
115 wasm_registry: Arc<WasmToolRegistry>,
116 config: Arc<crate::config::ServerConfig>,
117 ) -> Self {
118 let wasm_handler = WasmToolHandler::with_config(wasm_registry, config);
119 Self {
120 native_handler,
121 wasm_handler,
122 }
123 }
124}
125
126impl<H: ServerHandler + Send + Sync> ServerHandler for CompositeToolHandler<H> {
127 async fn list_tools(
128 &self,
129 request: Option<PaginatedRequestParams>,
130 context: RequestContext<RoleServer>,
131 ) -> Result<ListToolsResult, ErrorData> {
132 let mut native_result = self
134 .native_handler
135 .list_tools(request.clone(), context.clone())
136 .await?;
137 let wasm_result = self.wasm_handler.list_tools(request, context).await?;
138
139 native_result.tools.extend(wasm_result.tools);
141 Ok(native_result)
142 }
143
144 async fn call_tool(
145 &self,
146 request: CallToolRequestParams,
147 context: RequestContext<RoleServer>,
148 ) -> Result<CallToolResult, ErrorData> {
149 if self.wasm_handler.executor.has_tool(&request.name) {
151 self.wasm_handler.call_tool(request, context).await
152 } else {
153 self.native_handler.call_tool(request, context).await
155 }
156 }
157}
158
159pub async fn load_wasm_tools_from_directory(
161 tool_dir: impl AsRef<std::path::Path>,
162 credential_provider: Arc<dyn super::CredentialProvider>,
163) -> Result<WasmToolHandler, super::WasmError> {
164 let registry = Arc::new(WasmToolRegistry::load_from_directory(
165 tool_dir,
166 credential_provider,
167 )?);
168 Ok(WasmToolHandler::new(registry))
169}
170
/// Loads a server configuration from `config_path` and builds a configured
/// [`WasmToolHandler`] from the tools in `tool_dir`.
///
/// The configuration is read with `ServerConfig::from_file`. The registry is created
/// with `WasmToolRegistry::load_from_directory`, using `credential_provider`.
///
/// # Errors
///
/// Returns an error if:
/// - `config_path` is not valid UTF-8. `ServerConfig::from_file` takes a `&str`, so such a
///   path cannot be passed through. Previously this case panicked via `unwrap()`.
/// - `ServerConfig::from_file` fails.
/// - `WasmToolRegistry::load_from_directory` fails.
#[cfg(feature = "config")]
pub async fn load_wasm_tools_with_config(
    tool_dir: impl AsRef<std::path::Path>,
    config_path: impl AsRef<std::path::Path>,
    credential_provider: Arc<dyn super::CredentialProvider>,
) -> Result<WasmToolHandler, Box<dyn std::error::Error>> {
    let config_path = config_path.as_ref();
    let config_path_str = config_path.to_str().ok_or_else(|| {
        format!(
            "config path is not valid UTF-8: {}",
            config_path.display()
        )
    })?;

    let config = Arc::new(crate::config::ServerConfig::from_file(config_path_str).await?);

    let registry = Arc::new(WasmToolRegistry::load_from_directory(
        tool_dir,
        credential_provider,
    )?);

    Ok(WasmToolHandler::with_config(registry, config))
}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wasm::credentials::InMemoryCredentialProvider;

    /// A handler over a freshly created, empty registry should report zero tools
    /// through both the registry and the executor.
    #[tokio::test]
    async fn test_wasm_handler_creation() {
        let wasm_runtime = super::super::runtime::WasmRuntime::new().unwrap();
        let credentials = InMemoryCredentialProvider::new();
        let empty_registry = WasmToolRegistry::new(Arc::new(credentials), Arc::new(wasm_runtime));
        let handler = WasmToolHandler::new(Arc::new(empty_registry));

        assert_eq!(handler.registry().tool_count(), 0);
        assert!(handler.executor().list_tools().is_empty());
    }
}