mcpkit_server/capability/
tools.rs1use crate::context::Context;
7use crate::handler::ToolHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::tool::{Tool, ToolOutput};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16pub type BoxedToolFn = Box<
18 dyn for<'a> Fn(
19 Value,
20 &'a Context<'a>,
21 )
22 -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
23 + Send
24 + Sync,
25>;
26
27pub struct RegisteredTool {
29 pub tool: Tool,
31 pub handler: BoxedToolFn,
33}
34
35pub struct ToolService {
40 tools: HashMap<String, RegisteredTool>,
41}
42
43impl Default for ToolService {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl ToolService {
50 #[must_use]
52 pub fn new() -> Self {
53 Self {
54 tools: HashMap::new(),
55 }
56 }
57
58 pub fn register<F, Fut>(&mut self, tool: Tool, handler: F)
60 where
61 F: Fn(Value, &Context<'_>) -> Fut + Send + Sync + 'static,
62 Fut: Future<Output = Result<ToolOutput, McpError>> + Send + 'static,
63 {
64 let name = tool.name.clone();
65 let boxed: BoxedToolFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
66 self.tools.insert(
67 name,
68 RegisteredTool {
69 tool,
70 handler: boxed,
71 },
72 );
73 }
74
75 pub fn register_arc<H>(&mut self, tool: Tool, handler: Arc<H>)
77 where
78 H: for<'a> Fn(
79 Value,
80 &'a Context<'a>,
81 )
82 -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
83 + Send
84 + Sync
85 + 'static,
86 {
87 let name = tool.name.clone();
88 let boxed: BoxedToolFn = Box::new(move |args, ctx| (handler)(args, ctx));
89 self.tools.insert(
90 name,
91 RegisteredTool {
92 tool,
93 handler: boxed,
94 },
95 );
96 }
97
98 #[must_use]
100 pub fn get(&self, name: &str) -> Option<&RegisteredTool> {
101 self.tools.get(name)
102 }
103
104 #[must_use]
106 pub fn contains(&self, name: &str) -> bool {
107 self.tools.contains_key(name)
108 }
109
110 #[must_use]
112 pub fn list(&self) -> Vec<&Tool> {
113 self.tools.values().map(|r| &r.tool).collect()
114 }
115
116 #[must_use]
118 pub fn len(&self) -> usize {
119 self.tools.len()
120 }
121
122 #[must_use]
124 pub fn is_empty(&self) -> bool {
125 self.tools.is_empty()
126 }
127
128 pub async fn call(
130 &self,
131 name: &str,
132 arguments: Value,
133 ctx: &Context<'_>,
134 ) -> Result<ToolOutput, McpError> {
135 let registered = self.tools.get(name).ok_or_else(|| {
136 McpError::invalid_params("tools/call", format!("Unknown tool: {name}"))
137 })?;
138
139 (registered.handler)(arguments, ctx).await
140 }
141}
142
143impl ToolHandler for ToolService {
144 async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
145 Ok(self.list().into_iter().cloned().collect())
146 }
147
148 async fn call_tool(
149 &self,
150 name: &str,
151 arguments: Value,
152 ctx: &Context<'_>,
153 ) -> Result<ToolOutput, McpError> {
154 self.call(name, arguments, ctx).await
155 }
156}
157
158pub struct ToolBuilder {
160 name: String,
161 description: Option<String>,
162 input_schema: Value,
163 destructive: Option<bool>,
164 idempotent: Option<bool>,
165 read_only: Option<bool>,
166}
167
168impl ToolBuilder {
169 pub fn new(name: impl Into<String>) -> Self {
171 Self {
172 name: name.into(),
173 description: None,
174 input_schema: serde_json::json!({
175 "type": "object",
176 "properties": {},
177 }),
178 destructive: None,
179 idempotent: None,
180 read_only: None,
181 }
182 }
183
184 pub fn description(mut self, desc: impl Into<String>) -> Self {
186 self.description = Some(desc.into());
187 self
188 }
189
190 #[must_use]
192 pub fn input_schema(mut self, schema: Value) -> Self {
193 self.input_schema = schema;
194 self
195 }
196
197 #[must_use]
202 pub fn destructive(mut self, value: bool) -> Self {
203 self.destructive = Some(value);
204 self
205 }
206
207 #[must_use]
212 pub fn idempotent(mut self, value: bool) -> Self {
213 self.idempotent = Some(value);
214 self
215 }
216
217 #[must_use]
221 pub fn read_only(mut self, value: bool) -> Self {
222 self.read_only = Some(value);
223 self
224 }
225
226 #[must_use]
228 pub fn build(self) -> Tool {
229 let has_annotations =
230 self.destructive.is_some() || self.idempotent.is_some() || self.read_only.is_some();
231
232 let annotations = if has_annotations {
233 Some(mcpkit_core::types::tool::ToolAnnotations {
234 title: None,
235 read_only_hint: self.read_only.or(Some(false)),
236 destructive_hint: self.destructive.or(Some(false)),
237 idempotent_hint: self.idempotent.or(Some(false)),
238 open_world_hint: None,
239 })
240 } else {
241 None
242 };
243
244 Tool {
245 name: self.name,
246 description: self.description,
247 input_schema: self.input_schema,
248 annotations,
249 }
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::protocol_version::ProtocolVersion;
    use mcpkit_core::types::tool::CallToolResult;

    /// Owned values needed to borrow into a [`Context`] inside a test.
    fn make_context() -> (
        RequestId,
        ClientCapabilities,
        ServerCapabilities,
        ProtocolVersion,
        NoOpPeer,
    ) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            ProtocolVersion::LATEST,
            NoOpPeer,
        )
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test")
            .description("A test tool")
            .input_schema(serde_json::json!({
                "type": "object",
                "properties": {
                    "query": { "type": "string" }
                }
            }))
            .build();

        assert_eq!(tool.name, "test");
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
    }

    #[test]
    fn test_tool_builder_defaults_have_no_annotations() {
        let tool = ToolBuilder::new("plain").build();

        assert!(tool.description.is_none());
        assert!(tool.annotations.is_none());
        assert_eq!(
            tool.input_schema,
            serde_json::json!({ "type": "object", "properties": {} })
        );
    }

    #[test]
    fn test_register_replaces_tool_with_same_name() {
        let mut service = ToolService::new();
        assert!(service.is_empty());

        service.register(
            ToolBuilder::new("dup").description("first").build(),
            |_args, _ctx| async move { Ok(ToolOutput::text("1".to_string())) },
        );
        service.register(
            ToolBuilder::new("dup").description("second").build(),
            |_args, _ctx| async move { Ok(ToolOutput::text("2".to_string())) },
        );

        assert_eq!(service.len(), 1);
        let description = service
            .get("dup")
            .and_then(|registered| registered.tool.description.as_deref());
        assert_eq!(description, Some("second"));
    }

    #[tokio::test]
    async fn test_tool_service() -> Result<(), Box<dyn std::error::Error>> {
        let mut service = ToolService::new();

        let tool = ToolBuilder::new("echo")
            .description("Echo back input")
            .build();

        service.register(tool, |args, _ctx| async move {
            Ok(ToolOutput::text(args.to_string()))
        });

        assert!(service.contains("echo"));
        assert_eq!(service.len(), 1);

        let (req_id, client_caps, server_caps, protocol_version, peer) = make_context();
        let ctx = Context::new(
            &req_id,
            None,
            &client_caps,
            &server_caps,
            protocol_version,
            &peer,
        );

        let result = service
            .call("echo", serde_json::json!({"hello": "world"}), &ctx)
            .await?;

        let call_result: CallToolResult = result.into();
        assert!(!call_result.content.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_call_unknown_tool_is_error() {
        let service = ToolService::new();

        let (req_id, client_caps, server_caps, protocol_version, peer) = make_context();
        let ctx = Context::new(
            &req_id,
            None,
            &client_caps,
            &server_caps,
            protocol_version,
            &peer,
        );

        let result = service
            .call("missing", serde_json::json!({}), &ctx)
            .await;
        assert!(result.is_err());
    }
}