1use anyhow::Result;
2use schemars::{schema_for, JsonSchema};
3use serde::Deserialize;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use crate::types::mcp::{McpToolDefinition, McpToolResult};
9
10pub type ToolHandler = Arc<
12 dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = Result<McpToolResult>> + Send>>
13 + Send
14 + Sync,
15>;
16
17pub struct McpTool {
19 pub definition: McpToolDefinition,
20 pub handler: ToolHandler,
21}
22
23impl McpTool {
24 pub fn new<F, Fut>(
26 name: &str,
27 description: &str,
28 input_schema: serde_json::Value,
29 handler: F,
30 ) -> Self
31 where
32 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
33 Fut: Future<Output = Result<McpToolResult>> + Send + 'static,
34 {
35 McpTool {
36 definition: McpToolDefinition {
37 name: name.to_string(),
38 description: description.to_string(),
39 input_schema,
40 },
41 handler: Arc::new(move |input| Box::pin(handler(input))),
42 }
43 }
44
45 pub async fn execute(&self, input: serde_json::Value) -> Result<McpToolResult> {
47 (self.handler)(input).await
48 }
49}
50
/// Builds an `McpTool` whose handler receives typed, already-deserialized arguments.
///
/// # Arguments
///
/// * `$name` - tool name (`&str`).
/// * `$description` - tool description (`&str`).
/// * `$args_type` - argument type; it must satisfy the bounds of `create_tool`
///   (`Deserialize` for any lifetime, `JsonSchema`, `'static`).
/// * `$handler` - expression evaluating to a callable that takes `$args_type`
///   and returns a future resolving to `Result<McpToolResult>`.
///
/// The handler expression is evaluated exactly once, when the tool is built,
/// and then shared across invocations, so it may own captured state.
///
/// # Errors
///
/// The built tool returns an error from `execute` when the JSON input cannot
/// be deserialized into `$args_type`.
///
/// # Panics
///
/// Panics if the generated schema cannot be serialized to JSON (see `create_tool`).
#[macro_export]
macro_rules! tool {
    ($name:expr, $description:expr, $args_type:ty, $handler:expr) => {{
        // `use` items are not hygienic, so these stay visible to `$handler` and
        // let handler bodies name the result types without importing them.
        // A handler that does not need them must not trigger a warning.
        #[allow(unused_imports)]
        use $crate::mcp::tool::McpTool;
        #[allow(unused_imports)]
        use $crate::types::mcp::McpToolResult;
        #[allow(unused_imports)]
        use $crate::types::mcp::ToolContent;

        // Delegate to the function form: it derives the schema, evaluates the
        // handler once and shares it, and only uses paths rooted in `$crate`.
        $crate::mcp::tool::create_tool::<_, _, $args_type>($name, $description, $handler)
    }};
}
103
104pub fn create_tool<F, Fut, Args>(name: &str, description: &str, handler: F) -> McpTool
106where
107 F: Fn(Args) -> Fut + Send + Sync + 'static,
108 Fut: Future<Output = Result<McpToolResult>> + Send + 'static,
109 Args: for<'de> Deserialize<'de> + JsonSchema + 'static,
110{
111 let schema = schema_for!(Args);
112 let schema_json = serde_json::to_value(schema).expect("Failed to serialize schema");
113 let handler = Arc::new(handler);
114
115 McpTool::new(
116 name,
117 description,
118 schema_json,
119 move |input: serde_json::Value| {
120 let handler = handler.clone();
121 async move {
122 let args: Args = serde_json::from_value(input)?;
123 handler(args).await
124 }
125 },
126 )
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::mcp::{McpToolResult, ToolContent};
    use serde::Deserialize;

    /// Typed argument fixture used by the `create_tool` tests.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct TestArgs {
        value: i32,
    }

    /// `McpTool::new` copies the name and description into the definition.
    #[tokio::test]
    async fn test_mcp_tool_creation() {
        let tool = McpTool::new(
            "test_tool",
            "A test tool",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "value": {"type": "integer"}
                }
            }),
            |input: serde_json::Value| async move {
                let value = input["value"].as_i64().unwrap_or(0);
                Ok(McpToolResult {
                    content: vec![ToolContent::Text {
                        text: format!("Got value: {}", value),
                    }],
                    is_error: false,
                })
            },
        );

        assert_eq!(tool.definition.name, "test_tool");
        assert_eq!(tool.definition.description, "A test tool");
    }

    /// `execute` forwards the raw JSON input to the handler and returns its result.
    #[tokio::test]
    async fn test_mcp_tool_execution() {
        let tool = McpTool::new(
            "test_tool",
            "A test tool",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "value": {"type": "integer"}
                }
            }),
            |input: serde_json::Value| async move {
                let value = input["value"].as_i64().unwrap_or(0);
                Ok(McpToolResult {
                    content: vec![ToolContent::Text {
                        text: format!("Result: {}", value * 2),
                    }],
                    is_error: false,
                })
            },
        );

        let result = tool
            .execute(serde_json::json!({"value": 21}))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.content.len(), 1);

        match &result.content[0] {
            ToolContent::Text { text } => assert_eq!(text, "Result: 42"),
            _ => panic!("Expected Text content"),
        }
    }

    /// The definition keeps the schema value it was given.
    #[test]
    fn test_tool_definition_structure() {
        let tool = McpTool::new(
            "calc",
            "Calculate",
            serde_json::json!({"type": "object"}),
            |_input: serde_json::Value| async move {
                Ok(McpToolResult {
                    content: vec![],
                    is_error: false,
                })
            },
        );

        assert_eq!(tool.definition.name, "calc");
        assert_eq!(tool.definition.description, "Calculate");
        assert!(tool.definition.input_schema.is_object());
    }

    /// A handler can report failure through `is_error` while still returning `Ok`.
    #[tokio::test]
    async fn test_tool_with_error() {
        let tool = McpTool::new(
            "failing_tool",
            "Always fails",
            serde_json::json!({}),
            |_input: serde_json::Value| async move {
                Ok(McpToolResult {
                    content: vec![ToolContent::Text {
                        text: "Error occurred".to_string(),
                    }],
                    is_error: true,
                })
            },
        );

        let result = tool.execute(serde_json::json!({})).await.unwrap();

        assert!(result.is_error);
        assert_eq!(result.content.len(), 1);
    }

    /// `create_tool` derives a schema and hands deserialized arguments to the handler.
    #[tokio::test]
    async fn test_create_tool_typed_execution() {
        let tool = create_tool("double", "Doubles a value", |args: TestArgs| async move {
            Ok(McpToolResult {
                content: vec![ToolContent::Text {
                    text: format!("Result: {}", args.value * 2),
                }],
                is_error: false,
            })
        });

        assert_eq!(tool.definition.name, "double");
        assert_eq!(tool.definition.description, "Doubles a value");
        assert!(tool.definition.input_schema.is_object());

        let result = tool
            .execute(serde_json::json!({"value": 21}))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.content.len(), 1);

        match &result.content[0] {
            ToolContent::Text { text } => assert_eq!(text, "Result: 42"),
            _ => panic!("Expected Text content"),
        }
    }

    /// Input that does not deserialize into the argument type surfaces as an `Err`.
    #[tokio::test]
    async fn test_create_tool_rejects_invalid_input() {
        let tool = create_tool("double", "Doubles a value", |args: TestArgs| async move {
            Ok(McpToolResult {
                content: vec![ToolContent::Text {
                    text: args.value.to_string(),
                }],
                is_error: false,
            })
        });

        let outcome = tool
            .execute(serde_json::json!({"value": "not a number"}))
            .await;

        assert!(outcome.is_err());
    }
}