mcpkit_server/capability/
tools.rs

//! Tool capability implementation.
//!
//! This module provides utilities for managing and executing tools
//! in an MCP server.
5
6use crate::context::Context;
7use crate::handler::ToolHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::tool::{Tool, ToolOutput};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16/// A boxed async function for tool execution.
17pub type BoxedToolFn = Box<
18    dyn for<'a> Fn(
19            Value,
20            &'a Context<'a>,
21        ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
22        + Send
23        + Sync,
24>;
25
26/// A registered tool with metadata and handler.
27pub struct RegisteredTool {
28    /// Tool metadata.
29    pub tool: Tool,
30    /// Handler function.
31    pub handler: BoxedToolFn,
32}
33
34/// Service for managing tools.
35///
36/// This provides a registry for tools and handles dispatching
37/// tool calls to the appropriate handlers.
38pub struct ToolService {
39    tools: HashMap<String, RegisteredTool>,
40}
41
42impl Default for ToolService {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl ToolService {
49    /// Create a new empty tool service.
50    pub fn new() -> Self {
51        Self {
52            tools: HashMap::new(),
53        }
54    }
55
56    /// Register a tool with a handler function.
57    pub fn register<F, Fut>(&mut self, tool: Tool, handler: F)
58    where
59        F: Fn(Value, &Context<'_>) -> Fut + Send + Sync + 'static,
60        Fut: Future<Output = Result<ToolOutput, McpError>> + Send + 'static,
61    {
62        let name = tool.name.clone();
63        let boxed: BoxedToolFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
64        self.tools.insert(
65            name,
66            RegisteredTool {
67                tool,
68                handler: boxed,
69            },
70        );
71    }
72
73    /// Register a tool with an Arc'd handler (for shared state).
74    pub fn register_arc<H>(&mut self, tool: Tool, handler: Arc<H>)
75    where
76        H: for<'a> Fn(Value, &'a Context<'a>) -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
77            + Send
78            + Sync
79            + 'static,
80    {
81        let name = tool.name.clone();
82        let boxed: BoxedToolFn = Box::new(move |args, ctx| (handler)(args, ctx));
83        self.tools.insert(
84            name,
85            RegisteredTool {
86                tool,
87                handler: boxed,
88            },
89        );
90    }
91
92    /// Get a tool by name.
93    pub fn get(&self, name: &str) -> Option<&RegisteredTool> {
94        self.tools.get(name)
95    }
96
97    /// Check if a tool exists.
98    pub fn contains(&self, name: &str) -> bool {
99        self.tools.contains_key(name)
100    }
101
102    /// Get all registered tools.
103    pub fn list(&self) -> Vec<&Tool> {
104        self.tools.values().map(|r| &r.tool).collect()
105    }
106
107    /// Get the number of registered tools.
108    pub fn len(&self) -> usize {
109        self.tools.len()
110    }
111
112    /// Check if the service has no tools.
113    pub fn is_empty(&self) -> bool {
114        self.tools.is_empty()
115    }
116
117    /// Call a tool by name.
118    pub async fn call(
119        &self,
120        name: &str,
121        arguments: Value,
122        ctx: &Context<'_>,
123    ) -> Result<ToolOutput, McpError> {
124        let registered = self.tools.get(name).ok_or_else(|| {
125            McpError::invalid_params("tools/call", format!("Unknown tool: {name}"))
126        })?;
127
128        (registered.handler)(arguments, ctx).await
129    }
130}
131
132impl ToolHandler for ToolService {
133    async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
134        Ok(self.list().into_iter().cloned().collect())
135    }
136
137    async fn call_tool(
138        &self,
139        name: &str,
140        arguments: Value,
141        ctx: &Context<'_>,
142    ) -> Result<ToolOutput, McpError> {
143        self.call(name, arguments, ctx).await
144    }
145}
146
147/// Builder for creating tools with a fluent API.
148pub struct ToolBuilder {
149    name: String,
150    description: Option<String>,
151    input_schema: Value,
152}
153
154impl ToolBuilder {
155    /// Create a new tool builder.
156    pub fn new(name: impl Into<String>) -> Self {
157        Self {
158            name: name.into(),
159            description: None,
160            input_schema: serde_json::json!({
161                "type": "object",
162                "properties": {},
163            }),
164        }
165    }
166
167    /// Set the tool description.
168    pub fn description(mut self, desc: impl Into<String>) -> Self {
169        self.description = Some(desc.into());
170        self
171    }
172
173    /// Set the input schema.
174    pub fn input_schema(mut self, schema: Value) -> Self {
175        self.input_schema = schema;
176        self
177    }
178
179    /// Build the tool.
180    pub fn build(self) -> Tool {
181        Tool {
182            name: self.name,
183            description: self.description,
184            input_schema: self.input_schema,
185            annotations: None,
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::types::tool::CallToolResult;

    /// Owned values that a test `Context` borrows; bind them to locals that
    /// outlive the context built from them.
    fn make_context() -> (RequestId, ClientCapabilities, ServerCapabilities, NoOpPeer) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            NoOpPeer,
        )
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test")
            .description("A test tool")
            .input_schema(serde_json::json!({
                "type": "object",
                "properties": {
                    "query": { "type": "string" }
                }
            }))
            .build();

        assert_eq!(tool.name, "test");
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
        assert_eq!(tool.input_schema["properties"]["query"]["type"], "string");
    }

    #[test]
    fn test_tool_builder_defaults() {
        let tool = ToolBuilder::new("bare").build();

        assert_eq!(tool.name, "bare");
        assert!(tool.description.is_none());
        assert_eq!(
            tool.input_schema,
            serde_json::json!({ "type": "object", "properties": {} })
        );
    }

    #[tokio::test]
    async fn test_tool_service() {
        let mut service = ToolService::new();
        assert!(service.is_empty());

        let tool = ToolBuilder::new("echo")
            .description("Echo back input")
            .build();

        service.register(tool, |args, _ctx| async move {
            Ok(ToolOutput::text(args.to_string()))
        });

        assert!(service.contains("echo"));
        assert!(!service.is_empty());
        assert_eq!(service.len(), 1);

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .call("echo", serde_json::json!({"hello": "world"}), &ctx)
            .await
            .unwrap();

        // Convert to CallToolResult to check content
        let call_result: CallToolResult = result.into();
        assert!(!call_result.content.is_empty());
    }

    #[tokio::test]
    async fn test_unknown_tool_is_error() {
        let service = ToolService::new();
        assert!(service.get("missing").is_none());

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .call("missing", serde_json::json!({}), &ctx)
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn test_register_replaces_existing() {
        let mut service = ToolService::new();

        service.register(ToolBuilder::new("dup").build(), |_args, _ctx| async move {
            Ok(ToolOutput::text("first".to_string()))
        });
        service.register(
            ToolBuilder::new("dup").description("second").build(),
            |_args, _ctx| async move { Ok(ToolOutput::text("second".to_string())) },
        );

        assert_eq!(service.len(), 1);
        let registered = service.get("dup").expect("tool was just registered");
        assert_eq!(registered.tool.description.as_deref(), Some("second"));
    }
}