//! MCP tool support (`qwencode_rs/mcp/tool.rs`): the `McpTool` type, the `tool!` macro and
//! the `create_tool` helper.
1use anyhow::Result;
2use schemars::{schema_for, JsonSchema};
3use serde::Deserialize;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use crate::types::mcp::{McpToolDefinition, McpToolResult};
9
10/// Tool handler function type
11pub type ToolHandler = Arc<
12    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<McpToolResult>> + Send>>
13        + Send
14        + Sync,
15>;
16
17/// MCP Tool with name, description, schema, and handler
18pub struct McpTool {
19    pub definition: McpToolDefinition,
20    pub handler: ToolHandler,
21}
22
23impl McpTool {
24    /// Create a new MCP tool
25    pub fn new<F, Fut>(
26        name: &str,
27        description: &str,
28        input_schema: serde_json::Value,
29        handler: F,
30    ) -> Self
31    where
32        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
33        Fut: Future<Output = Result<McpToolResult>> + Send + 'static,
34    {
35        McpTool {
36            definition: McpToolDefinition {
37                name: name.to_string(),
38                description: description.to_string(),
39                input_schema,
40            },
41            handler: Arc::new(move |input| Box::pin(handler(input))),
42        }
43    }
44
45    /// Execute the tool with given input
46    pub async fn execute(&self, input: serde_json::Value) -> Result<McpToolResult> {
47        (self.handler)(input).await
48    }
49}
50
/// Creates an [`McpTool`] whose input schema is generated from a typed argument struct.
///
/// - `$args_type` must implement `serde::Deserialize` and `schemars::JsonSchema`. Its JSON
///   schema becomes the tool's `input_schema`.
/// - On every call, the incoming JSON is deserialized into `$args_type` and passed by value
///   to `$handler`. Malformed input is returned as an `Err` from [`McpTool::execute`]; it
///   does not cause a panic.
/// - `$handler` is evaluated exactly once and shared between calls. It can therefore be any
///   `Fn($args_type) -> Fut` value: a closure literal, a function name, or a variable
///   holding a closure. `Fut` must resolve to `anyhow::Result<McpToolResult>`.
/// - `McpToolResult` and `ToolContent` are brought into scope for the handler expression.
/// - The expansion refers to `schemars` and `serde_json` by absolute path, so the calling
///   crate must depend on both.
///
/// # Panics
///
/// Panics if the generated schema cannot be serialized to JSON.
///
/// # Example
/// ```ignore
/// use qwencode_rs::tool;
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Deserialize, JsonSchema)]
/// struct AddArgs {
///     a: i32,
///     b: i32,
/// }
///
/// let tool = tool!(
///     "add",
///     "Add two numbers",
///     AddArgs,
///     |args: AddArgs| async move {
///         Ok(McpToolResult {
///             content: vec![ToolContent::Text {
///                 text: format!("{}", args.a + args.b),
///             }],
///             is_error: false,
///         })
///     }
/// );
/// ```
#[macro_export]
macro_rules! tool {
    ($name:expr, $description:expr, $args_type:ty, $handler:expr) => {{
        use $crate::mcp::tool::McpTool;
        // `use` items in `macro_rules!` are not hygienic, so these names are visible to
        // `$handler`. A given handler may not need both, so unused ones are allowed.
        #[allow(unused_imports)]
        use $crate::types::mcp::{McpToolResult, ToolContent};

        let schema = ::schemars::schema_for!($args_type);
        let schema_json = ::serde_json::to_value(schema).expect("Failed to serialize schema");

        // Evaluate the handler once and share it across calls, as `create_tool` does.
        // Evaluating it inside the `Fn` closure would re-run the expression on every call
        // and would fail to compile for non-`Copy` handlers stored in a variable.
        let handler = ::std::sync::Arc::new($handler);

        McpTool::new(
            $name,
            $description,
            schema_json,
            move |input: ::serde_json::Value| {
                let handler = ::std::sync::Arc::clone(&handler);
                async move {
                    let args: $args_type = ::serde_json::from_value(input)?;
                    handler(args).await
                }
            },
        )
    }};
}
103
104/// Helper function to create a tool without the macro (for dynamic tool creation)
105pub fn create_tool<F, Fut, Args>(name: &str, description: &str, handler: F) -> McpTool
106where
107    F: Fn(Args) -> Fut + Send + Sync + 'static,
108    Fut: Future<Output = Result<McpToolResult>> + Send + 'static,
109    Args: for<'de> Deserialize<'de> + JsonSchema + 'static,
110{
111    let schema = schema_for!(Args);
112    let schema_json = serde_json::to_value(schema).expect("Failed to serialize schema");
113    let handler = Arc::new(handler);
114
115    McpTool::new(
116        name,
117        description,
118        schema_json,
119        move |input: serde_json::Value| {
120            let handler = handler.clone();
121            async move {
122                let args: Args = serde_json::from_value(input)?;
123                handler(args).await
124            }
125        },
126    )
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::mcp::{McpToolResult, ToolContent};
    use serde::Deserialize;

    /// Typed arguments for the `create_tool` / `tool!` tests.
    // The allow stays because this struct is only built through the derived `Deserialize`
    // impl, which the dead-code lint may not count as a construction.
    #[derive(Debug, Deserialize, JsonSchema)]
    #[allow(dead_code)]
    struct TestArgs {
        value: i32,
    }

    /// Asserts that `result` holds exactly one text item and returns its text.
    fn single_text(result: &McpToolResult) -> &str {
        assert_eq!(result.content.len(), 1);
        match &result.content[0] {
            ToolContent::Text { text } => text.as_str(),
            _ => panic!("Expected Text content"),
        }
    }

    /// Builds a successful result containing a single text item.
    fn text_result(text: String) -> McpToolResult {
        McpToolResult {
            content: vec![ToolContent::Text { text }],
            is_error: false,
        }
    }

    #[tokio::test]
    async fn test_mcp_tool_creation() {
        let tool = McpTool::new(
            "test_tool",
            "A test tool",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "value": {"type": "integer"}
                }
            }),
            |input: serde_json::Value| async move {
                let value = input["value"].as_i64().unwrap_or(0);
                Ok(text_result(format!("Got value: {}", value)))
            },
        );

        assert_eq!(tool.definition.name, "test_tool");
        assert_eq!(tool.definition.description, "A test tool");
    }

    #[tokio::test]
    async fn test_mcp_tool_execution() {
        let tool = McpTool::new(
            "test_tool",
            "A test tool",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "value": {"type": "integer"}
                }
            }),
            |input: serde_json::Value| async move {
                let value = input["value"].as_i64().unwrap_or(0);
                Ok(text_result(format!("Result: {}", value * 2)))
            },
        );

        let result = tool
            .execute(serde_json::json!({"value": 21}))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(single_text(&result), "Result: 42");
    }

    #[test]
    fn test_tool_definition_structure() {
        let tool = McpTool::new(
            "calc",
            "Calculate",
            serde_json::json!({"type": "object"}),
            |_input: serde_json::Value| async move {
                Ok(McpToolResult {
                    content: vec![],
                    is_error: false,
                })
            },
        );

        assert_eq!(tool.definition.name, "calc");
        assert_eq!(tool.definition.description, "Calculate");
        assert!(tool.definition.input_schema.is_object());
    }

    #[tokio::test]
    async fn test_tool_with_error() {
        let tool = McpTool::new(
            "failing_tool",
            "Always fails",
            serde_json::json!({}),
            |_input: serde_json::Value| async move {
                Ok(McpToolResult {
                    content: vec![ToolContent::Text {
                        text: "Error occurred".to_string(),
                    }],
                    is_error: true,
                })
            },
        );

        let result = tool.execute(serde_json::json!({})).await.unwrap();

        assert!(result.is_error);
        assert_eq!(single_text(&result), "Error occurred");
    }

    #[tokio::test]
    async fn test_create_tool_uses_typed_args() {
        let tool = create_tool("double", "Doubles a value", |args: TestArgs| async move {
            Ok(text_result(format!("Result: {}", args.value * 2)))
        });

        // The schema is generated from `TestArgs`, so it must describe the `value` field.
        assert!(tool.definition.input_schema["properties"]["value"].is_object());

        let result = tool
            .execute(serde_json::json!({"value": 21}))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(single_text(&result), "Result: 42");
    }

    #[tokio::test]
    async fn test_create_tool_rejects_invalid_input() {
        let tool = create_tool("double", "Doubles a value", |args: TestArgs| async move {
            Ok(text_result(format!("Result: {}", args.value * 2)))
        });

        let outcome = tool
            .execute(serde_json::json!({"value": "not a number"}))
            .await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_tool_macro_generates_schema_and_parses_args() {
        let tool = tool!(
            "double",
            "Doubles a value",
            TestArgs,
            |args: TestArgs| async move {
                Ok::<_, anyhow::Error>(text_result(format!("Result: {}", args.value * 2)))
            }
        );

        assert_eq!(tool.definition.name, "double");
        assert_eq!(tool.definition.description, "Doubles a value");
        assert!(tool.definition.input_schema["properties"]["value"].is_object());

        let result = tool
            .execute(serde_json::json!({"value": 21}))
            .await
            .unwrap();
        assert_eq!(single_text(&result), "Result: 42");

        // A missing required field must surface as an error, not a panic.
        assert!(tool.execute(serde_json::json!({})).await.is_err());
    }
}