mcpkit_server/capability/
tools.rs

1//! Tool capability implementation.
2//!
3//! This module provides utilities for managing and executing tools
4//! in an MCP server.
5
6use crate::context::Context;
7use crate::handler::ToolHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::tool::{Tool, ToolOutput};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16/// A boxed async function for tool execution.
17pub type BoxedToolFn = Box<
18    dyn for<'a> Fn(
19            Value,
20            &'a Context<'a>,
21        )
22            -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
23        + Send
24        + Sync,
25>;
26
27/// A registered tool with metadata and handler.
28pub struct RegisteredTool {
29    /// Tool metadata.
30    pub tool: Tool,
31    /// Handler function.
32    pub handler: BoxedToolFn,
33}
34
35/// Service for managing tools.
36///
37/// This provides a registry for tools and handles dispatching
38/// tool calls to the appropriate handlers.
39pub struct ToolService {
40    tools: HashMap<String, RegisteredTool>,
41}
42
43impl Default for ToolService {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl ToolService {
50    /// Create a new empty tool service.
51    #[must_use]
52    pub fn new() -> Self {
53        Self {
54            tools: HashMap::new(),
55        }
56    }
57
58    /// Register a tool with a handler function.
59    pub fn register<F, Fut>(&mut self, tool: Tool, handler: F)
60    where
61        F: Fn(Value, &Context<'_>) -> Fut + Send + Sync + 'static,
62        Fut: Future<Output = Result<ToolOutput, McpError>> + Send + 'static,
63    {
64        let name = tool.name.clone();
65        let boxed: BoxedToolFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
66        self.tools.insert(
67            name,
68            RegisteredTool {
69                tool,
70                handler: boxed,
71            },
72        );
73    }
74
75    /// Register a tool with an Arc'd handler (for shared state).
76    pub fn register_arc<H>(&mut self, tool: Tool, handler: Arc<H>)
77    where
78        H: for<'a> Fn(
79                Value,
80                &'a Context<'a>,
81            )
82                -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
83            + Send
84            + Sync
85            + 'static,
86    {
87        let name = tool.name.clone();
88        let boxed: BoxedToolFn = Box::new(move |args, ctx| (handler)(args, ctx));
89        self.tools.insert(
90            name,
91            RegisteredTool {
92                tool,
93                handler: boxed,
94            },
95        );
96    }
97
98    /// Get a tool by name.
99    #[must_use]
100    pub fn get(&self, name: &str) -> Option<&RegisteredTool> {
101        self.tools.get(name)
102    }
103
104    /// Check if a tool exists.
105    #[must_use]
106    pub fn contains(&self, name: &str) -> bool {
107        self.tools.contains_key(name)
108    }
109
110    /// Get all registered tools.
111    #[must_use]
112    pub fn list(&self) -> Vec<&Tool> {
113        self.tools.values().map(|r| &r.tool).collect()
114    }
115
116    /// Get the number of registered tools.
117    #[must_use]
118    pub fn len(&self) -> usize {
119        self.tools.len()
120    }
121
122    /// Check if the service has no tools.
123    #[must_use]
124    pub fn is_empty(&self) -> bool {
125        self.tools.is_empty()
126    }
127
128    /// Call a tool by name.
129    pub async fn call(
130        &self,
131        name: &str,
132        arguments: Value,
133        ctx: &Context<'_>,
134    ) -> Result<ToolOutput, McpError> {
135        let registered = self.tools.get(name).ok_or_else(|| {
136            McpError::invalid_params("tools/call", format!("Unknown tool: {name}"))
137        })?;
138
139        (registered.handler)(arguments, ctx).await
140    }
141}
142
143impl ToolHandler for ToolService {
144    async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
145        Ok(self.list().into_iter().cloned().collect())
146    }
147
148    async fn call_tool(
149        &self,
150        name: &str,
151        arguments: Value,
152        ctx: &Context<'_>,
153    ) -> Result<ToolOutput, McpError> {
154        self.call(name, arguments, ctx).await
155    }
156}
157
158/// Builder for creating tools with a fluent API.
159pub struct ToolBuilder {
160    name: String,
161    description: Option<String>,
162    input_schema: Value,
163    destructive: Option<bool>,
164    idempotent: Option<bool>,
165    read_only: Option<bool>,
166}
167
168impl ToolBuilder {
169    /// Create a new tool builder.
170    pub fn new(name: impl Into<String>) -> Self {
171        Self {
172            name: name.into(),
173            description: None,
174            input_schema: serde_json::json!({
175                "type": "object",
176                "properties": {},
177            }),
178            destructive: None,
179            idempotent: None,
180            read_only: None,
181        }
182    }
183
184    /// Set the tool description.
185    pub fn description(mut self, desc: impl Into<String>) -> Self {
186        self.description = Some(desc.into());
187        self
188    }
189
190    /// Set the input schema.
191    #[must_use]
192    pub fn input_schema(mut self, schema: Value) -> Self {
193        self.input_schema = schema;
194        self
195    }
196
197    /// Mark this tool as destructive.
198    ///
199    /// Destructive tools modify data or state in ways that cannot be easily undone.
200    /// When set to true, clients should warn users before executing.
201    #[must_use]
202    pub fn destructive(mut self, value: bool) -> Self {
203        self.destructive = Some(value);
204        self
205    }
206
207    /// Mark this tool as idempotent.
208    ///
209    /// Idempotent tools produce the same result when called multiple times
210    /// with the same arguments.
211    #[must_use]
212    pub fn idempotent(mut self, value: bool) -> Self {
213        self.idempotent = Some(value);
214        self
215    }
216
217    /// Mark this tool as read-only.
218    ///
219    /// Read-only tools do not modify any data or state.
220    #[must_use]
221    pub fn read_only(mut self, value: bool) -> Self {
222        self.read_only = Some(value);
223        self
224    }
225
226    /// Build the tool.
227    #[must_use]
228    pub fn build(self) -> Tool {
229        let has_annotations =
230            self.destructive.is_some() || self.idempotent.is_some() || self.read_only.is_some();
231
232        let annotations = if has_annotations {
233            Some(mcpkit_core::types::tool::ToolAnnotations {
234                title: None,
235                read_only_hint: self.read_only.or(Some(false)),
236                destructive_hint: self.destructive.or(Some(false)),
237                idempotent_hint: self.idempotent.or(Some(false)),
238                open_world_hint: None,
239            })
240        } else {
241            None
242        };
243
244        Tool {
245            name: self.name,
246            description: self.description,
247            input_schema: self.input_schema,
248            annotations,
249        }
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::protocol_version::ProtocolVersion;
    use mcpkit_core::types::tool::CallToolResult;

    /// Owned fixtures that a test `Context` borrows from.
    fn make_context() -> (
        RequestId,
        ClientCapabilities,
        ServerCapabilities,
        ProtocolVersion,
        NoOpPeer,
    ) {
        let request_id = RequestId::Number(1);
        let client = ClientCapabilities::default();
        let server = ServerCapabilities::default();
        (request_id, client, server, ProtocolVersion::LATEST, NoOpPeer)
    }

    #[test]
    fn test_tool_builder() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" }
            }
        });
        let built = ToolBuilder::new("test")
            .description("A test tool")
            .input_schema(schema)
            .build();

        assert_eq!(built.name, "test");
        assert_eq!(built.description.as_deref(), Some("A test tool"));
    }

    #[tokio::test]
    async fn test_tool_service() -> Result<(), Box<dyn std::error::Error>> {
        let mut registry = ToolService::new();

        let echo = ToolBuilder::new("echo")
            .description("Echo back input")
            .build();
        registry.register(echo, |args, _ctx| async move {
            Ok(ToolOutput::text(args.to_string()))
        });

        assert!(registry.contains("echo"));
        assert_eq!(registry.len(), 1);

        let (id, client_caps, server_caps, version, peer) = make_context();
        let ctx = Context::new(&id, None, &client_caps, &server_caps, version, &peer);

        let payload = serde_json::json!({"hello": "world"});
        let output = registry.call("echo", payload, &ctx).await?;

        // Convert to CallToolResult so the produced content can be inspected.
        let converted: CallToolResult = output.into();
        assert!(!converted.content.is_empty());

        Ok(())
    }
}
329}