mcpkit_server/capability/
tools.rs1use crate::context::Context;
7use crate::handler::ToolHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::tool::{Tool, ToolOutput};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16pub type BoxedToolFn = Box<
18 dyn for<'a> Fn(
19 Value,
20 &'a Context<'a>,
21 )
22 -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
23 + Send
24 + Sync,
25>;
26
27pub struct RegisteredTool {
29 pub tool: Tool,
31 pub handler: BoxedToolFn,
33}
34
35pub struct ToolService {
40 tools: HashMap<String, RegisteredTool>,
41}
42
43impl Default for ToolService {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl ToolService {
50 #[must_use]
52 pub fn new() -> Self {
53 Self {
54 tools: HashMap::new(),
55 }
56 }
57
58 pub fn register<F, Fut>(&mut self, tool: Tool, handler: F)
60 where
61 F: Fn(Value, &Context<'_>) -> Fut + Send + Sync + 'static,
62 Fut: Future<Output = Result<ToolOutput, McpError>> + Send + 'static,
63 {
64 let name = tool.name.clone();
65 let boxed: BoxedToolFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
66 self.tools.insert(
67 name,
68 RegisteredTool {
69 tool,
70 handler: boxed,
71 },
72 );
73 }
74
75 pub fn register_arc<H>(&mut self, tool: Tool, handler: Arc<H>)
77 where
78 H: for<'a> Fn(
79 Value,
80 &'a Context<'a>,
81 )
82 -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
83 + Send
84 + Sync
85 + 'static,
86 {
87 let name = tool.name.clone();
88 let boxed: BoxedToolFn = Box::new(move |args, ctx| (handler)(args, ctx));
89 self.tools.insert(
90 name,
91 RegisteredTool {
92 tool,
93 handler: boxed,
94 },
95 );
96 }
97
98 #[must_use]
100 pub fn get(&self, name: &str) -> Option<&RegisteredTool> {
101 self.tools.get(name)
102 }
103
104 #[must_use]
106 pub fn contains(&self, name: &str) -> bool {
107 self.tools.contains_key(name)
108 }
109
110 #[must_use]
112 pub fn list(&self) -> Vec<&Tool> {
113 self.tools.values().map(|r| &r.tool).collect()
114 }
115
116 #[must_use]
118 pub fn len(&self) -> usize {
119 self.tools.len()
120 }
121
122 #[must_use]
124 pub fn is_empty(&self) -> bool {
125 self.tools.is_empty()
126 }
127
128 pub async fn call(
130 &self,
131 name: &str,
132 arguments: Value,
133 ctx: &Context<'_>,
134 ) -> Result<ToolOutput, McpError> {
135 let registered = self.tools.get(name).ok_or_else(|| {
136 McpError::invalid_params("tools/call", format!("Unknown tool: {name}"))
137 })?;
138
139 (registered.handler)(arguments, ctx).await
140 }
141}
142
143impl ToolHandler for ToolService {
144 async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
145 Ok(self.list().into_iter().cloned().collect())
146 }
147
148 async fn call_tool(
149 &self,
150 name: &str,
151 arguments: Value,
152 ctx: &Context<'_>,
153 ) -> Result<ToolOutput, McpError> {
154 self.call(name, arguments, ctx).await
155 }
156}
157
158pub struct ToolBuilder {
160 name: String,
161 description: Option<String>,
162 input_schema: Value,
163}
164
165impl ToolBuilder {
166 pub fn new(name: impl Into<String>) -> Self {
168 Self {
169 name: name.into(),
170 description: None,
171 input_schema: serde_json::json!({
172 "type": "object",
173 "properties": {},
174 }),
175 }
176 }
177
178 pub fn description(mut self, desc: impl Into<String>) -> Self {
180 self.description = Some(desc.into());
181 self
182 }
183
184 #[must_use]
186 pub fn input_schema(mut self, schema: Value) -> Self {
187 self.input_schema = schema;
188 self
189 }
190
191 #[must_use]
193 pub fn build(self) -> Tool {
194 Tool {
195 name: self.name,
196 description: self.description,
197 input_schema: self.input_schema,
198 annotations: None,
199 }
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::types::tool::CallToolResult;

    /// Builds the owned values a [`Context`] borrows from.
    fn make_context() -> (RequestId, ClientCapabilities, ServerCapabilities, NoOpPeer) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            NoOpPeer,
        )
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test")
            .description("A test tool")
            .input_schema(serde_json::json!({
                "type": "object",
                "properties": {
                    "query": { "type": "string" }
                }
            }))
            .build();

        assert_eq!(tool.name, "test");
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
        assert_eq!(tool.input_schema["properties"]["query"]["type"], "string");
    }

    #[test]
    fn test_tool_builder_defaults() {
        let tool = ToolBuilder::new("bare").build();

        assert_eq!(tool.name, "bare");
        assert_eq!(tool.description, None);
        assert_eq!(
            tool.input_schema,
            serde_json::json!({ "type": "object", "properties": {} })
        );
    }

    #[tokio::test]
    async fn test_tool_service() {
        let mut service = ToolService::new();

        let tool = ToolBuilder::new("echo")
            .description("Echo back input")
            .build();

        service.register(tool, |args, _ctx| async move {
            Ok(ToolOutput::text(args.to_string()))
        });

        assert!(service.contains("echo"));
        assert_eq!(service.len(), 1);

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .call("echo", serde_json::json!({"hello": "world"}), &ctx)
            .await
            .unwrap();

        let call_result: CallToolResult = result.into();
        assert!(!call_result.content.is_empty());
    }

    #[tokio::test]
    async fn test_unknown_tool_is_error() {
        let service = ToolService::new();
        assert!(service.is_empty());

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .call("missing", serde_json::json!({}), &ctx)
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn test_register_replaces_same_name() {
        let mut service = ToolService::new();

        for desc in ["first", "second"] {
            let tool = ToolBuilder::new("dup").description(desc).build();
            service.register(tool, |_args, _ctx| async move {
                Ok(ToolOutput::text("ok".to_string()))
            });
        }

        assert_eq!(service.len(), 1);
        assert_eq!(
            service
                .get("dup")
                .and_then(|r| r.tool.description.as_deref()),
            Some("second")
        );
    }
}