mcp_protocol_sdk/core/
tool.rs

1//! Tool system for MCP servers
2//!
3//! This module provides the abstraction for implementing and managing tools in MCP servers.
4//! Tools are functions that can be called by clients to perform specific operations.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{Content, ToolInfo, ToolInputSchema, ToolResult};
12
13/// Trait for implementing tool handlers
14#[async_trait]
15pub trait ToolHandler: Send + Sync {
16    /// Execute the tool with the given arguments
17    ///
18    /// # Arguments
19    /// * `arguments` - Tool arguments as key-value pairs
20    ///
21    /// # Returns
22    /// Result containing the tool execution result or an error
23    async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult>;
24}
25
26/// A registered tool with its handler
27pub struct Tool {
28    /// Information about the tool
29    pub info: ToolInfo,
30    /// Handler that implements the tool's functionality
31    pub handler: Box<dyn ToolHandler>,
32    /// Whether the tool is currently enabled
33    pub enabled: bool,
34}
35
36impl Tool {
37    /// Create a new tool with the given information and handler
38    ///
39    /// # Arguments
40    /// * `name` - Name of the tool
41    /// * `description` - Optional description of the tool
42    /// * `input_schema` - JSON schema describing the input parameters
43    /// * `handler` - Implementation of the tool's functionality
44    pub fn new<H>(
45        name: String,
46        description: Option<String>,
47        input_schema: Value,
48        handler: H,
49    ) -> Self
50    where
51        H: ToolHandler + 'static,
52    {
53        Self {
54            info: ToolInfo {
55                name,
56                description,
57                input_schema: ToolInputSchema {
58                    schema_type: "object".to_string(),
59                    properties: input_schema
60                        .get("properties")
61                        .and_then(|p| p.as_object())
62                        .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect()),
63                    required: input_schema
64                        .get("required")
65                        .and_then(|r| r.as_array())
66                        .map(|arr| {
67                            arr.iter()
68                                .filter_map(|v| v.as_str().map(String::from))
69                                .collect()
70                        }),
71                    additional_properties: input_schema
72                        .as_object()
73                        .unwrap_or(&serde_json::Map::new())
74                        .iter()
75                        .filter(|(k, _)| !["type", "properties", "required"].contains(&k.as_str()))
76                        .map(|(k, v)| (k.clone(), v.clone()))
77                        .collect(),
78                },
79                annotations: None,
80            },
81            handler: Box::new(handler),
82            enabled: true,
83        }
84    }
85
86    /// Enable the tool
87    pub fn enable(&mut self) {
88        self.enabled = true;
89    }
90
91    /// Disable the tool
92    pub fn disable(&mut self) {
93        self.enabled = false;
94    }
95
96    /// Check if the tool is enabled
97    pub fn is_enabled(&self) -> bool {
98        self.enabled
99    }
100
101    /// Execute the tool if it's enabled
102    ///
103    /// # Arguments
104    /// * `arguments` - Tool arguments as key-value pairs
105    ///
106    /// # Returns
107    /// Result containing the tool execution result or an error
108    pub async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
109        if !self.enabled {
110            return Err(McpError::validation(format!(
111                "Tool '{}' is disabled",
112                self.info.name
113            )));
114        }
115
116        self.handler.call(arguments).await
117    }
118}
119
120impl std::fmt::Debug for Tool {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("Tool")
123            .field("info", &self.info)
124            .field("enabled", &self.enabled)
125            .finish()
126    }
127}
128
/// Shorthand for building a `Tool` with `Tool::new`.
///
/// Forms:
/// - `tool!(name, schema, handler)` — the tool has no description;
/// - `tool!(name, description, schema, handler)` — the description is wrapped in `Some`.
///
/// `name` and `description` are converted with `to_string()`, so string literals
/// and `String`s both work. A trailing comma is accepted. The schema is passed
/// through to `Tool::new` as-is; no schema validation is performed.
///
/// # Examples
/// ```rust
/// use mcp_protocol_sdk::{tool, core::tool::ToolHandler};
/// use serde_json::json;
///
/// struct MyHandler;
/// #[async_trait::async_trait]
/// impl ToolHandler for MyHandler {
///     async fn call(&self, _args: std::collections::HashMap<String, serde_json::Value>) -> mcp_protocol_sdk::McpResult<mcp_protocol_sdk::protocol::types::ToolResult> {
///         // Implementation here
///         todo!()
///     }
/// }
///
/// let tool = tool!(
///     "my_tool",
///     "A sample tool",
///     json!({
///         "type": "object",
///         "properties": {
///             "input": { "type": "string" }
///         }
///     }),
///     MyHandler
/// );
/// ```
#[macro_export]
macro_rules! tool {
    ($name:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::core::tool::Tool::new($name.to_string(), None, $schema, $handler)
    };
    ($name:expr, $description:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::core::tool::Tool::new(
            $name.to_string(),
            Some($description.to_string()),
            $schema,
            $handler,
        )
    };
}
171
172// Common tool implementations
173
174/// Simple echo tool for testing
175pub struct EchoTool;
176
177#[async_trait]
178impl ToolHandler for EchoTool {
179    async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
180        let message = arguments
181            .get("message")
182            .and_then(|v| v.as_str())
183            .unwrap_or("Hello, World!");
184
185        Ok(ToolResult {
186            content: vec![Content::Text {
187                text: message.to_string(),
188                annotations: None,
189            }],
190            is_error: None,
191            meta: None,
192        })
193    }
194}
195
196/// Tool for adding two numbers
197pub struct AdditionTool;
198
199#[async_trait]
200impl ToolHandler for AdditionTool {
201    async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
202        let a = arguments
203            .get("a")
204            .and_then(|v| v.as_f64())
205            .ok_or_else(|| McpError::validation("Missing or invalid 'a' parameter"))?;
206
207        let b = arguments
208            .get("b")
209            .and_then(|v| v.as_f64())
210            .ok_or_else(|| McpError::validation("Missing or invalid 'b' parameter"))?;
211
212        let result = a + b;
213
214        Ok(ToolResult {
215            content: vec![Content::Text {
216                text: result.to_string(),
217                annotations: None,
218            }],
219            is_error: None,
220            meta: None,
221        })
222    }
223}
224
225/// Tool for getting current timestamp
226pub struct TimestampTool;
227
228#[async_trait]
229impl ToolHandler for TimestampTool {
230    async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
231        use std::time::{SystemTime, UNIX_EPOCH};
232
233        let timestamp = SystemTime::now()
234            .duration_since(UNIX_EPOCH)
235            .map_err(|e| McpError::internal(e.to_string()))?
236            .as_secs();
237
238        Ok(ToolResult {
239            content: vec![Content::Text {
240                text: timestamp.to_string(),
241                annotations: None,
242            }],
243            is_error: None,
244            meta: None,
245        })
246    }
247}
248
249/// Builder for creating tools with fluent API
250pub struct ToolBuilder {
251    name: String,
252    description: Option<String>,
253    input_schema: Option<Value>,
254}
255
256impl ToolBuilder {
257    /// Create a new tool builder with the given name
258    pub fn new<S: Into<String>>(name: S) -> Self {
259        Self {
260            name: name.into(),
261            description: None,
262            input_schema: None,
263        }
264    }
265
266    /// Set the tool description
267    pub fn description<S: Into<String>>(mut self, description: S) -> Self {
268        self.description = Some(description.into());
269        self
270    }
271
272    /// Set the input schema
273    pub fn schema(mut self, schema: Value) -> Self {
274        self.input_schema = Some(schema);
275        self
276    }
277
278    /// Build the tool with the given handler
279    pub fn build<H>(self, handler: H) -> McpResult<Tool>
280    where
281        H: ToolHandler + 'static,
282    {
283        let schema = self.input_schema.unwrap_or_else(|| {
284            serde_json::json!({
285                "type": "object",
286                "properties": {},
287                "additionalProperties": true
288            })
289        });
290
291        Ok(Tool::new(self.name, self.description, schema, handler))
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Return the text of the first content item of `result`, panicking if the
    /// result is empty or does not start with text content.
    fn first_text(result: &ToolResult) -> &str {
        match result.content.first() {
            Some(Content::Text { text, .. }) => text.as_str(),
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let tool = EchoTool;
        let mut args = HashMap::new();
        args.insert("message".to_string(), json!("test message"));

        let result = tool.call(args).await.unwrap();
        assert_eq!(first_text(&result), "test message");
    }

    #[tokio::test]
    async fn test_echo_tool_default_message() {
        let result = EchoTool.call(HashMap::new()).await.unwrap();
        assert_eq!(first_text(&result), "Hello, World!");
    }

    #[tokio::test]
    async fn test_addition_tool() {
        let tool = AdditionTool;
        let mut args = HashMap::new();
        args.insert("a".to_string(), json!(5.0));
        args.insert("b".to_string(), json!(3.0));

        let result = tool.call(args).await.unwrap();
        assert_eq!(first_text(&result), "8");
    }

    #[tokio::test]
    async fn test_addition_tool_missing_argument() {
        let mut args = HashMap::new();
        args.insert("a".to_string(), json!(1.0));

        match AdditionTool.call(args).await {
            Err(McpError::Validation(msg)) => assert!(msg.contains("'b'")),
            _ => panic!("Expected validation error"),
        }
    }

    #[test]
    fn test_tool_creation() {
        let tool = Tool::new(
            "test_tool".to_string(),
            Some("Test tool".to_string()),
            json!({"type": "object"}),
            EchoTool,
        );

        assert_eq!(tool.info.name, "test_tool");
        assert_eq!(tool.info.description, Some("Test tool".to_string()));
        assert!(tool.is_enabled());
    }

    #[test]
    fn test_tool_schema_parsing() {
        let tool = Tool::new(
            "schema_tool".to_string(),
            None,
            json!({
                "type": "object",
                "properties": {"x": {"type": "number"}},
                "required": ["x", 42],
                "additionalProperties": false
            }),
            EchoTool,
        );
        let schema = &tool.info.input_schema;

        assert_eq!(schema.schema_type, "object");

        let properties = schema.properties.as_ref().expect("properties should be parsed");
        assert!(properties.contains_key("x"));

        // Non-string entries of "required" are dropped.
        let required = schema.required.as_ref().expect("required should be parsed");
        assert_eq!(required.len(), 1);
        assert!(required.iter().any(|name| name == "x"));

        // Only keys other than type/properties/required are forwarded.
        assert!(schema.additional_properties.contains_key("additionalProperties"));
        assert!(!schema.additional_properties.contains_key("type"));
        assert!(!schema.additional_properties.contains_key("properties"));
        assert!(!schema.additional_properties.contains_key("required"));
    }

    #[test]
    fn test_tool_enable_disable() {
        let mut tool = Tool::new(
            "test_tool".to_string(),
            None,
            json!({"type": "object"}),
            EchoTool,
        );

        assert!(tool.is_enabled());

        tool.disable();
        assert!(!tool.is_enabled());

        tool.enable();
        assert!(tool.is_enabled());
    }

    #[tokio::test]
    async fn test_disabled_tool() {
        let mut tool = Tool::new(
            "test_tool".to_string(),
            None,
            json!({"type": "object"}),
            EchoTool,
        );

        tool.disable();

        let result = tool.call(HashMap::new()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("disabled")),
            _ => panic!("Expected validation error"),
        }
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test")
            .description("A test tool")
            .schema(json!({"type": "object", "properties": {"x": {"type": "number"}}}))
            .build(EchoTool)
            .unwrap();

        assert_eq!(tool.info.name, "test");
        assert_eq!(tool.info.description, Some("A test tool".to_string()));
    }

    #[test]
    fn test_tool_builder_default_schema() {
        let tool = ToolBuilder::new("defaults").build(EchoTool).unwrap();
        let schema = &tool.info.input_schema;

        assert_eq!(schema.schema_type, "object");
        assert!(schema.properties.as_ref().map_or(false, |props| props.is_empty()));
        assert!(schema.required.is_none());
        assert!(schema.additional_properties.contains_key("additionalProperties"));
    }
}