1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{Content, ToolInfo, ToolInputSchema, ToolResult};
12
13#[async_trait]
15pub trait ToolHandler: Send + Sync {
16 async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult>;
24}
25
26pub struct Tool {
28 pub info: ToolInfo,
30 pub handler: Box<dyn ToolHandler>,
32 pub enabled: bool,
34}
35
36impl Tool {
37 pub fn new<H>(
45 name: String,
46 description: Option<String>,
47 input_schema: Value,
48 handler: H,
49 ) -> Self
50 where
51 H: ToolHandler + 'static,
52 {
53 Self {
54 info: ToolInfo {
55 name,
56 description,
57 input_schema: ToolInputSchema {
58 schema_type: "object".to_string(),
59 properties: input_schema
60 .get("properties")
61 .and_then(|p| p.as_object())
62 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect()),
63 required: input_schema
64 .get("required")
65 .and_then(|r| r.as_array())
66 .map(|arr| {
67 arr.iter()
68 .filter_map(|v| v.as_str().map(String::from))
69 .collect()
70 }),
71 additional_properties: input_schema
72 .as_object()
73 .unwrap_or(&serde_json::Map::new())
74 .iter()
75 .filter(|(k, _)| !["type", "properties", "required"].contains(&k.as_str()))
76 .map(|(k, v)| (k.clone(), v.clone()))
77 .collect(),
78 },
79 annotations: None,
80 },
81 handler: Box::new(handler),
82 enabled: true,
83 }
84 }
85
86 pub fn enable(&mut self) {
88 self.enabled = true;
89 }
90
91 pub fn disable(&mut self) {
93 self.enabled = false;
94 }
95
96 pub fn is_enabled(&self) -> bool {
98 self.enabled
99 }
100
101 pub async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
109 if !self.enabled {
110 return Err(McpError::validation(format!(
111 "Tool '{}' is disabled",
112 self.info.name
113 )));
114 }
115
116 self.handler.call(arguments).await
117 }
118}
119
120impl std::fmt::Debug for Tool {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("Tool")
123 .field("info", &self.info)
124 .field("enabled", &self.enabled)
125 .finish()
126 }
127}
128
/// Builds a [`Tool`](crate::core::tool::Tool) from a name, an optional
/// description, a JSON input schema and a handler.
///
/// Two forms are accepted; both allow a trailing comma:
///
/// ```ignore
/// // name, schema, handler (no description)
/// let echo = tool!("echo", json!({"type": "object"}), EchoTool);
/// // name, description, schema, handler
/// let echo = tool!("echo", "Echoes its input", json!({"type": "object"}), EchoTool);
/// ```
///
/// `name` and `description` may be any expression implementing `ToString`.
/// The invocation expands to a call to `Tool::new`.
#[macro_export]
macro_rules! tool {
    ($name:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::core::tool::Tool::new($name.to_string(), None, $schema, $handler)
    };
    ($name:expr, $description:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::core::tool::Tool::new(
            $name.to_string(),
            Some($description.to_string()),
            $schema,
            $handler,
        )
    };
}
171
172pub struct EchoTool;
176
177#[async_trait]
178impl ToolHandler for EchoTool {
179 async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
180 let message = arguments
181 .get("message")
182 .and_then(|v| v.as_str())
183 .unwrap_or("Hello, World!");
184
185 Ok(ToolResult {
186 content: vec![Content::Text {
187 text: message.to_string(),
188 annotations: None,
189 }],
190 is_error: None,
191 meta: None,
192 })
193 }
194}
195
196pub struct AdditionTool;
198
199#[async_trait]
200impl ToolHandler for AdditionTool {
201 async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
202 let a = arguments
203 .get("a")
204 .and_then(|v| v.as_f64())
205 .ok_or_else(|| McpError::validation("Missing or invalid 'a' parameter"))?;
206
207 let b = arguments
208 .get("b")
209 .and_then(|v| v.as_f64())
210 .ok_or_else(|| McpError::validation("Missing or invalid 'b' parameter"))?;
211
212 let result = a + b;
213
214 Ok(ToolResult {
215 content: vec![Content::Text {
216 text: result.to_string(),
217 annotations: None,
218 }],
219 is_error: None,
220 meta: None,
221 })
222 }
223}
224
225pub struct TimestampTool;
227
228#[async_trait]
229impl ToolHandler for TimestampTool {
230 async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
231 use std::time::{SystemTime, UNIX_EPOCH};
232
233 let timestamp = SystemTime::now()
234 .duration_since(UNIX_EPOCH)
235 .map_err(|e| McpError::internal(e.to_string()))?
236 .as_secs();
237
238 Ok(ToolResult {
239 content: vec![Content::Text {
240 text: timestamp.to_string(),
241 annotations: None,
242 }],
243 is_error: None,
244 meta: None,
245 })
246 }
247}
248
249pub struct ToolBuilder {
251 name: String,
252 description: Option<String>,
253 input_schema: Option<Value>,
254}
255
256impl ToolBuilder {
257 pub fn new<S: Into<String>>(name: S) -> Self {
259 Self {
260 name: name.into(),
261 description: None,
262 input_schema: None,
263 }
264 }
265
266 pub fn description<S: Into<String>>(mut self, description: S) -> Self {
268 self.description = Some(description.into());
269 self
270 }
271
272 pub fn schema(mut self, schema: Value) -> Self {
274 self.input_schema = Some(schema);
275 self
276 }
277
278 pub fn build<H>(self, handler: H) -> McpResult<Tool>
280 where
281 H: ToolHandler + 'static,
282 {
283 let schema = self.input_schema.unwrap_or_else(|| {
284 serde_json::json!({
285 "type": "object",
286 "properties": {},
287 "additionalProperties": true
288 })
289 });
290
291 Ok(Tool::new(self.name, self.description, schema, handler))
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a handler argument map from `(name, value)` pairs.
    fn args(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        pairs
            .iter()
            .map(|(name, value)| ((*name).to_owned(), value.clone()))
            .collect()
    }

    /// Returns the text of the first content item, panicking on any other kind.
    fn first_text(result: &ToolResult) -> &str {
        match &result.content[0] {
            Content::Text { text, .. } => text.as_str(),
            _ => panic!("Expected text content"),
        }
    }

    /// An `EchoTool`-backed tool named `test_tool` with an empty object schema.
    fn echo_tool(description: Option<&str>) -> Tool {
        Tool::new(
            "test_tool".to_string(),
            description.map(String::from),
            json!({"type": "object"}),
            EchoTool,
        )
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let handler = EchoTool;
        let result = handler
            .call(args(&[("message", json!("test message"))]))
            .await
            .unwrap();

        assert_eq!(first_text(&result), "test message");
    }

    #[tokio::test]
    async fn test_addition_tool() {
        let handler = AdditionTool;
        let result = handler
            .call(args(&[("a", json!(5.0)), ("b", json!(3.0))]))
            .await
            .unwrap();

        assert_eq!(first_text(&result), "8");
    }

    #[test]
    fn test_tool_creation() {
        let tool = echo_tool(Some("Test tool"));

        assert_eq!(tool.info.name, "test_tool");
        assert_eq!(tool.info.description.as_deref(), Some("Test tool"));
        assert!(tool.is_enabled());
    }

    #[test]
    fn test_tool_enable_disable() {
        let mut tool = echo_tool(None);
        assert!(tool.is_enabled());

        tool.disable();
        assert!(!tool.is_enabled());

        tool.enable();
        assert!(tool.is_enabled());
    }

    #[tokio::test]
    async fn test_disabled_tool() {
        let mut tool = echo_tool(None);
        tool.disable();

        match tool.call(HashMap::new()).await {
            Err(McpError::Validation(msg)) => assert!(msg.contains("disabled")),
            Err(_) => panic!("Expected validation error"),
            Ok(_) => panic!("Expected a disabled tool to refuse the call"),
        }
    }

    #[test]
    fn test_tool_builder() {
        let schema = json!({"type": "object", "properties": {"x": {"type": "number"}}});
        let tool = ToolBuilder::new("test")
            .description("A test tool")
            .schema(schema)
            .build(EchoTool)
            .unwrap();

        assert_eq!(tool.info.name, "test");
        assert_eq!(tool.info.description.as_deref(), Some("A test tool"));
    }
}