// mcp_protocol_sdk/core/tool.rs

//! Tool system for MCP servers
//!
//! This module provides the abstractions for implementing and managing tools in MCP servers.
//! Tools are functions that clients can call to perform specific operations.

6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{Content, ToolInfo, ToolResult};
12
13/// Trait for implementing tool handlers
14#[async_trait]
15pub trait ToolHandler: Send + Sync {
16    /// Execute the tool with the given arguments
17    ///
18    /// # Arguments
19    /// * `arguments` - Tool arguments as key-value pairs
20    ///
21    /// # Returns
22    /// Result containing the tool execution result or an error
23    async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult>;
24}
25
26/// A registered tool with its handler
27pub struct Tool {
28    /// Information about the tool
29    pub info: ToolInfo,
30    /// Handler that implements the tool's functionality
31    pub handler: Box<dyn ToolHandler>,
32    /// Whether the tool is currently enabled
33    pub enabled: bool,
34}
35
36impl Tool {
37    /// Create a new tool with the given information and handler
38    ///
39    /// # Arguments
40    /// * `name` - Name of the tool
41    /// * `description` - Optional description of the tool
42    /// * `input_schema` - JSON schema describing the input parameters
43    /// * `handler` - Implementation of the tool's functionality
44    pub fn new<H>(
45        name: String,
46        description: Option<String>,
47        input_schema: Value,
48        handler: H,
49    ) -> Self
50    where
51        H: ToolHandler + 'static,
52    {
53        Self {
54            info: ToolInfo {
55                name,
56                description,
57                input_schema,
58            },
59            handler: Box::new(handler),
60            enabled: true,
61        }
62    }
63
64    /// Enable the tool
65    pub fn enable(&mut self) {
66        self.enabled = true;
67    }
68
69    /// Disable the tool
70    pub fn disable(&mut self) {
71        self.enabled = false;
72    }
73
74    /// Check if the tool is enabled
75    pub fn is_enabled(&self) -> bool {
76        self.enabled
77    }
78
79    /// Execute the tool if it's enabled
80    ///
81    /// # Arguments
82    /// * `arguments` - Tool arguments as key-value pairs
83    ///
84    /// # Returns
85    /// Result containing the tool execution result or an error
86    pub async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
87        if !self.enabled {
88            return Err(McpError::validation(format!(
89                "Tool '{}' is disabled",
90                self.info.name
91            )));
92        }
93
94        self.handler.call(arguments).await
95    }
96}
97
98impl std::fmt::Debug for Tool {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct("Tool")
101            .field("info", &self.info)
102            .field("enabled", &self.enabled)
103            .finish()
104    }
105}
106
/// Creates a [`Tool`](crate::core::tool::Tool) from a name, an optional
/// description, an input schema and a handler.
///
/// This is a thin wrapper around `Tool::new`. The name (and description, if
/// given) are converted with `to_string()`. The schema is stored as-is and is
/// **not** validated. Two forms are accepted, each with an optional trailing
/// comma:
///
/// * `tool!(name, schema, handler)`: no description
/// * `tool!(name, description, schema, handler)`
///
/// # Examples
/// ```rust
/// use mcp_protocol_sdk::{tool, core::tool::ToolHandler};
/// use serde_json::json;
///
/// struct MyHandler;
/// #[async_trait::async_trait]
/// impl ToolHandler for MyHandler {
///     async fn call(&self, _args: std::collections::HashMap<String, serde_json::Value>) -> mcp_protocol_sdk::McpResult<mcp_protocol_sdk::protocol::types::ToolResult> {
///         // Implementation here
///         todo!()
///     }
/// }
///
/// let tool = tool!(
///     "my_tool",
///     "A sample tool",
///     json!({
///         "type": "object",
///         "properties": {
///             "input": { "type": "string" }
///         }
///     }),
///     MyHandler
/// );
/// assert!(tool.is_enabled());
///
/// // Without a description:
/// let bare = tool!("bare_tool", json!({ "type": "object" }), MyHandler);
/// assert!(bare.is_enabled());
/// ```
#[macro_export]
macro_rules! tool {
    ($name:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::core::tool::Tool::new($name.to_string(), None, $schema, $handler)
    };
    ($name:expr, $description:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::core::tool::Tool::new(
            $name.to_string(),
            Some($description.to_string()),
            $schema,
            $handler,
        )
    };
}
149
// Common tool implementations (ready-made `ToolHandler`s)
151
152/// Simple echo tool for testing
153pub struct EchoTool;
154
155#[async_trait]
156impl ToolHandler for EchoTool {
157    async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
158        let message = arguments
159            .get("message")
160            .and_then(|v| v.as_str())
161            .unwrap_or("Hello, World!");
162
163        Ok(ToolResult {
164            content: vec![Content::Text {
165                text: message.to_string(),
166            }],
167            is_error: None,
168        })
169    }
170}
171
172/// Tool for adding two numbers
173pub struct AdditionTool;
174
175#[async_trait]
176impl ToolHandler for AdditionTool {
177    async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
178        let a = arguments
179            .get("a")
180            .and_then(|v| v.as_f64())
181            .ok_or_else(|| McpError::validation("Missing or invalid 'a' parameter"))?;
182
183        let b = arguments
184            .get("b")
185            .and_then(|v| v.as_f64())
186            .ok_or_else(|| McpError::validation("Missing or invalid 'b' parameter"))?;
187
188        let result = a + b;
189
190        Ok(ToolResult {
191            content: vec![Content::Text {
192                text: result.to_string(),
193            }],
194            is_error: None,
195        })
196    }
197}
198
199/// Tool for getting current timestamp
200pub struct TimestampTool;
201
202#[async_trait]
203impl ToolHandler for TimestampTool {
204    async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
205        use std::time::{SystemTime, UNIX_EPOCH};
206
207        let timestamp = SystemTime::now()
208            .duration_since(UNIX_EPOCH)
209            .map_err(|e| McpError::internal(e.to_string()))?
210            .as_secs();
211
212        Ok(ToolResult {
213            content: vec![Content::Text {
214                text: timestamp.to_string(),
215            }],
216            is_error: None,
217        })
218    }
219}
220
221/// Builder for creating tools with fluent API
222pub struct ToolBuilder {
223    name: String,
224    description: Option<String>,
225    input_schema: Option<Value>,
226}
227
228impl ToolBuilder {
229    /// Create a new tool builder with the given name
230    pub fn new<S: Into<String>>(name: S) -> Self {
231        Self {
232            name: name.into(),
233            description: None,
234            input_schema: None,
235        }
236    }
237
238    /// Set the tool description
239    pub fn description<S: Into<String>>(mut self, description: S) -> Self {
240        self.description = Some(description.into());
241        self
242    }
243
244    /// Set the input schema
245    pub fn schema(mut self, schema: Value) -> Self {
246        self.input_schema = Some(schema);
247        self
248    }
249
250    /// Build the tool with the given handler
251    pub fn build<H>(self, handler: H) -> McpResult<Tool>
252    where
253        H: ToolHandler + 'static,
254    {
255        let schema = self.input_schema.unwrap_or_else(|| {
256            serde_json::json!({
257                "type": "object",
258                "properties": {},
259                "additionalProperties": true
260            })
261        });
262
263        Ok(Tool::new(self.name, self.description, schema, handler))
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Returns the text of the first content item, panicking if it is not text.
    fn text_of(result: &ToolResult) -> &str {
        match &result.content[0] {
            Content::Text { text } => text.as_str(),
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let tool = EchoTool;
        let mut args = HashMap::new();
        args.insert("message".to_string(), json!("test message"));

        let result = tool.call(args).await.unwrap();
        assert_eq!(text_of(&result), "test message");
    }

    #[tokio::test]
    async fn test_echo_tool_falls_back_to_default_message() {
        // Missing argument.
        let result = EchoTool.call(HashMap::new()).await.unwrap();
        assert_eq!(text_of(&result), "Hello, World!");

        // Present but not a string.
        let mut args = HashMap::new();
        args.insert("message".to_string(), json!(42));
        let result = EchoTool.call(args).await.unwrap();
        assert_eq!(text_of(&result), "Hello, World!");
    }

    #[tokio::test]
    async fn test_addition_tool() {
        let tool = AdditionTool;
        let mut args = HashMap::new();
        args.insert("a".to_string(), json!(5.0));
        args.insert("b".to_string(), json!(3.0));

        let result = tool.call(args).await.unwrap();
        assert_eq!(text_of(&result), "8");
    }

    #[tokio::test]
    async fn test_addition_tool_accepts_integers() {
        let mut args = HashMap::new();
        args.insert("a".to_string(), json!(2));
        args.insert("b".to_string(), json!(-7));

        let result = AdditionTool.call(args).await.unwrap();
        assert_eq!(text_of(&result), "-5");
    }

    #[tokio::test]
    async fn test_addition_tool_rejects_missing_operand() {
        let mut args = HashMap::new();
        args.insert("a".to_string(), json!(1));

        match AdditionTool.call(args).await {
            Err(McpError::Validation(msg)) => assert!(msg.contains("'b'")),
            _ => panic!("Expected validation error"),
        }
    }

    #[tokio::test]
    async fn test_addition_tool_rejects_non_numeric_operand() {
        let mut args = HashMap::new();
        args.insert("a".to_string(), json!("five"));
        args.insert("b".to_string(), json!(3));

        match AdditionTool.call(args).await {
            Err(McpError::Validation(msg)) => assert!(msg.contains("'a'")),
            _ => panic!("Expected validation error"),
        }
    }

    #[tokio::test]
    async fn test_timestamp_tool() {
        let result = TimestampTool.call(HashMap::new()).await.unwrap();
        let secs: u64 = text_of(&result)
            .parse()
            .expect("timestamp should be an integer number of seconds");
        assert!(secs > 0);
    }

    #[test]
    fn test_tool_creation() {
        let tool = Tool::new(
            "test_tool".to_string(),
            Some("Test tool".to_string()),
            json!({"type": "object"}),
            EchoTool,
        );

        assert_eq!(tool.info.name, "test_tool");
        assert_eq!(tool.info.description, Some("Test tool".to_string()));
        assert!(tool.is_enabled());
    }

    #[test]
    fn test_tool_enable_disable() {
        let mut tool = Tool::new(
            "test_tool".to_string(),
            None,
            json!({"type": "object"}),
            EchoTool,
        );

        assert!(tool.is_enabled());

        tool.disable();
        assert!(!tool.is_enabled());

        tool.enable();
        assert!(tool.is_enabled());
    }

    #[tokio::test]
    async fn test_enabled_tool_delegates_to_handler() {
        let tool = Tool::new(
            "echo".to_string(),
            None,
            json!({"type": "object"}),
            EchoTool,
        );
        let mut args = HashMap::new();
        args.insert("message".to_string(), json!("hi"));

        let result = tool.call(args).await.unwrap();
        assert_eq!(text_of(&result), "hi");
    }

    #[tokio::test]
    async fn test_disabled_tool() {
        let mut tool = Tool::new(
            "test_tool".to_string(),
            None,
            json!({"type": "object"}),
            EchoTool,
        );

        tool.disable();

        let result = tool.call(HashMap::new()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("disabled")),
            _ => panic!("Expected validation error"),
        }
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test")
            .description("A test tool")
            .schema(json!({"type": "object", "properties": {"x": {"type": "number"}}}))
            .build(EchoTool)
            .unwrap();

        assert_eq!(tool.info.name, "test");
        assert_eq!(tool.info.description, Some("A test tool".to_string()));
    }

    #[test]
    fn test_tool_builder_default_schema() {
        let tool = ToolBuilder::new("plain").build(EchoTool).unwrap();

        assert_eq!(tool.info.description, None);
        assert_eq!(
            tool.info.input_schema,
            json!({"type": "object", "properties": {}, "additionalProperties": true})
        );
        assert!(tool.is_enabled());
    }
}