mcpkit_server/capability/
tools.rs

//! Tool capability implementation.
//!
//! This module provides utilities for managing and executing tools
//! in an MCP server.
5
6use crate::context::Context;
7use crate::handler::ToolHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::tool::{Tool, ToolOutput};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16/// A boxed async function for tool execution.
17pub type BoxedToolFn = Box<
18    dyn for<'a> Fn(
19            Value,
20            &'a Context<'a>,
21        )
22            -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
23        + Send
24        + Sync,
25>;
26
27/// A registered tool with metadata and handler.
28pub struct RegisteredTool {
29    /// Tool metadata.
30    pub tool: Tool,
31    /// Handler function.
32    pub handler: BoxedToolFn,
33}
34
35/// Service for managing tools.
36///
37/// This provides a registry for tools and handles dispatching
38/// tool calls to the appropriate handlers.
39pub struct ToolService {
40    tools: HashMap<String, RegisteredTool>,
41}
42
43impl Default for ToolService {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl ToolService {
50    /// Create a new empty tool service.
51    #[must_use]
52    pub fn new() -> Self {
53        Self {
54            tools: HashMap::new(),
55        }
56    }
57
58    /// Register a tool with a handler function.
59    pub fn register<F, Fut>(&mut self, tool: Tool, handler: F)
60    where
61        F: Fn(Value, &Context<'_>) -> Fut + Send + Sync + 'static,
62        Fut: Future<Output = Result<ToolOutput, McpError>> + Send + 'static,
63    {
64        let name = tool.name.clone();
65        let boxed: BoxedToolFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
66        self.tools.insert(
67            name,
68            RegisteredTool {
69                tool,
70                handler: boxed,
71            },
72        );
73    }
74
75    /// Register a tool with an Arc'd handler (for shared state).
76    pub fn register_arc<H>(&mut self, tool: Tool, handler: Arc<H>)
77    where
78        H: for<'a> Fn(
79                Value,
80                &'a Context<'a>,
81            )
82                -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
83            + Send
84            + Sync
85            + 'static,
86    {
87        let name = tool.name.clone();
88        let boxed: BoxedToolFn = Box::new(move |args, ctx| (handler)(args, ctx));
89        self.tools.insert(
90            name,
91            RegisteredTool {
92                tool,
93                handler: boxed,
94            },
95        );
96    }
97
98    /// Get a tool by name.
99    #[must_use]
100    pub fn get(&self, name: &str) -> Option<&RegisteredTool> {
101        self.tools.get(name)
102    }
103
104    /// Check if a tool exists.
105    #[must_use]
106    pub fn contains(&self, name: &str) -> bool {
107        self.tools.contains_key(name)
108    }
109
110    /// Get all registered tools.
111    #[must_use]
112    pub fn list(&self) -> Vec<&Tool> {
113        self.tools.values().map(|r| &r.tool).collect()
114    }
115
116    /// Get the number of registered tools.
117    #[must_use]
118    pub fn len(&self) -> usize {
119        self.tools.len()
120    }
121
122    /// Check if the service has no tools.
123    #[must_use]
124    pub fn is_empty(&self) -> bool {
125        self.tools.is_empty()
126    }
127
128    /// Call a tool by name.
129    pub async fn call(
130        &self,
131        name: &str,
132        arguments: Value,
133        ctx: &Context<'_>,
134    ) -> Result<ToolOutput, McpError> {
135        let registered = self.tools.get(name).ok_or_else(|| {
136            McpError::invalid_params("tools/call", format!("Unknown tool: {name}"))
137        })?;
138
139        (registered.handler)(arguments, ctx).await
140    }
141}
142
143impl ToolHandler for ToolService {
144    async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
145        Ok(self.list().into_iter().cloned().collect())
146    }
147
148    async fn call_tool(
149        &self,
150        name: &str,
151        arguments: Value,
152        ctx: &Context<'_>,
153    ) -> Result<ToolOutput, McpError> {
154        self.call(name, arguments, ctx).await
155    }
156}
157
158/// Builder for creating tools with a fluent API.
159pub struct ToolBuilder {
160    name: String,
161    description: Option<String>,
162    input_schema: Value,
163}
164
165impl ToolBuilder {
166    /// Create a new tool builder.
167    pub fn new(name: impl Into<String>) -> Self {
168        Self {
169            name: name.into(),
170            description: None,
171            input_schema: serde_json::json!({
172                "type": "object",
173                "properties": {},
174            }),
175        }
176    }
177
178    /// Set the tool description.
179    pub fn description(mut self, desc: impl Into<String>) -> Self {
180        self.description = Some(desc.into());
181        self
182    }
183
184    /// Set the input schema.
185    #[must_use]
186    pub fn input_schema(mut self, schema: Value) -> Self {
187        self.input_schema = schema;
188        self
189    }
190
191    /// Build the tool.
192    #[must_use]
193    pub fn build(self) -> Tool {
194        Tool {
195            name: self.name,
196            description: self.description,
197            input_schema: self.input_schema,
198            annotations: None,
199        }
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::types::tool::CallToolResult;

    /// Owned values needed to build a `Context`.
    ///
    /// The context borrows these values, so the caller keeps them alive and
    /// constructs the context itself.
    fn make_context() -> (RequestId, ClientCapabilities, ServerCapabilities, NoOpPeer) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            NoOpPeer,
        )
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test")
            .description("A test tool")
            .input_schema(serde_json::json!({
                "type": "object",
                "properties": {
                    "query": { "type": "string" }
                }
            }))
            .build();

        assert_eq!(tool.name, "test");
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
    }

    #[test]
    fn test_tool_builder_defaults() {
        let tool = ToolBuilder::new("bare").build();

        assert_eq!(tool.name, "bare");
        assert!(tool.description.is_none());
        assert!(tool.annotations.is_none());
        assert_eq!(
            tool.input_schema,
            serde_json::json!({ "type": "object", "properties": {} })
        );
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut service = ToolService::new();

        service.register(
            ToolBuilder::new("dup").description("first").build(),
            |_args, _ctx| async move { Ok(ToolOutput::text("first".to_string())) },
        );
        service.register(
            ToolBuilder::new("dup").description("second").build(),
            |_args, _ctx| async move { Ok(ToolOutput::text("second".to_string())) },
        );

        assert_eq!(service.len(), 1);
        let registered = service.get("dup").expect("tool was just registered");
        assert_eq!(registered.tool.description.as_deref(), Some("second"));
    }

    #[tokio::test]
    async fn test_tool_service() {
        let mut service = ToolService::new();

        let tool = ToolBuilder::new("echo")
            .description("Echo back input")
            .build();

        service.register(tool, |args, _ctx| async move {
            Ok(ToolOutput::text(args.to_string()))
        });

        assert!(service.contains("echo"));
        assert_eq!(service.len(), 1);

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .call("echo", serde_json::json!({"hello": "world"}), &ctx)
            .await
            .unwrap();

        // Convert to CallToolResult to check content
        let call_result: CallToolResult = result.into();
        assert!(!call_result.content.is_empty());
    }

    #[tokio::test]
    async fn test_unknown_tool_is_error() {
        let service = ToolService::new();
        assert!(service.is_empty());

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .call("missing", serde_json::json!({}), &ctx)
            .await;
        assert!(result.is_err());
    }
}