1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{Content, ToolInfo, ToolResult};
12
13#[async_trait]
15pub trait ToolHandler: Send + Sync {
16 async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult>;
24}
25
26pub struct Tool {
28 pub info: ToolInfo,
30 pub handler: Box<dyn ToolHandler>,
32 pub enabled: bool,
34}
35
36impl Tool {
37 pub fn new<H>(
45 name: String,
46 description: Option<String>,
47 input_schema: Value,
48 handler: H,
49 ) -> Self
50 where
51 H: ToolHandler + 'static,
52 {
53 Self {
54 info: ToolInfo {
55 name,
56 description,
57 input_schema,
58 },
59 handler: Box::new(handler),
60 enabled: true,
61 }
62 }
63
64 pub fn enable(&mut self) {
66 self.enabled = true;
67 }
68
69 pub fn disable(&mut self) {
71 self.enabled = false;
72 }
73
74 pub fn is_enabled(&self) -> bool {
76 self.enabled
77 }
78
79 pub async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
87 if !self.enabled {
88 return Err(McpError::validation(format!(
89 "Tool '{}' is disabled",
90 self.info.name
91 )));
92 }
93
94 self.handler.call(arguments).await
95 }
96}
97
98impl std::fmt::Debug for Tool {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("Tool")
101 .field("info", &self.info)
102 .field("enabled", &self.enabled)
103 .finish()
104 }
105}
106
/// Shorthand for constructing an enabled `Tool` through `Tool::new`.
///
/// * `tool!(name, schema, handler)` creates a tool without a description.
/// * `tool!(name, description, schema, handler)` creates a tool with a description.
///
/// `name` and `description` are converted with `.to_string()`. Both forms accept an
/// optional trailing comma.
#[macro_export]
macro_rules! tool {
    ($name:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::core::tool::Tool::new($name.to_string(), None, $schema, $handler)
    };
    ($name:expr, $description:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::core::tool::Tool::new(
            $name.to_string(),
            Some($description.to_string()),
            $schema,
            $handler,
        )
    };
}
149
/// Example handler that echoes its `message` argument back as text.
#[derive(Debug, Clone, Copy, Default)]
pub struct EchoTool;
154
155#[async_trait]
156impl ToolHandler for EchoTool {
157 async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
158 let message = arguments
159 .get("message")
160 .and_then(|v| v.as_str())
161 .unwrap_or("Hello, World!");
162
163 Ok(ToolResult {
164 content: vec![Content::Text {
165 text: message.to_string(),
166 }],
167 is_error: None,
168 })
169 }
170}
171
/// Example handler that adds the numeric arguments `a` and `b`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AdditionTool;
174
175#[async_trait]
176impl ToolHandler for AdditionTool {
177 async fn call(&self, arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
178 let a = arguments
179 .get("a")
180 .and_then(|v| v.as_f64())
181 .ok_or_else(|| McpError::validation("Missing or invalid 'a' parameter"))?;
182
183 let b = arguments
184 .get("b")
185 .and_then(|v| v.as_f64())
186 .ok_or_else(|| McpError::validation("Missing or invalid 'b' parameter"))?;
187
188 let result = a + b;
189
190 Ok(ToolResult {
191 content: vec![Content::Text {
192 text: result.to_string(),
193 }],
194 is_error: None,
195 })
196 }
197}
198
/// Example handler that reports the current Unix timestamp in seconds.
#[derive(Debug, Clone, Copy, Default)]
pub struct TimestampTool;
201
202#[async_trait]
203impl ToolHandler for TimestampTool {
204 async fn call(&self, _arguments: HashMap<String, Value>) -> McpResult<ToolResult> {
205 use std::time::{SystemTime, UNIX_EPOCH};
206
207 let timestamp = SystemTime::now()
208 .duration_since(UNIX_EPOCH)
209 .map_err(|e| McpError::internal(e.to_string()))?
210 .as_secs();
211
212 Ok(ToolResult {
213 content: vec![Content::Text {
214 text: timestamp.to_string(),
215 }],
216 is_error: None,
217 })
218 }
219}
220
221pub struct ToolBuilder {
223 name: String,
224 description: Option<String>,
225 input_schema: Option<Value>,
226}
227
228impl ToolBuilder {
229 pub fn new<S: Into<String>>(name: S) -> Self {
231 Self {
232 name: name.into(),
233 description: None,
234 input_schema: None,
235 }
236 }
237
238 pub fn description<S: Into<String>>(mut self, description: S) -> Self {
240 self.description = Some(description.into());
241 self
242 }
243
244 pub fn schema(mut self, schema: Value) -> Self {
246 self.input_schema = Some(schema);
247 self
248 }
249
250 pub fn build<H>(self, handler: H) -> McpResult<Tool>
252 where
253 H: ToolHandler + 'static,
254 {
255 let schema = self.input_schema.unwrap_or_else(|| {
256 serde_json::json!({
257 "type": "object",
258 "properties": {},
259 "additionalProperties": true
260 })
261 });
262
263 Ok(Tool::new(self.name, self.description, schema, handler))
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_echo_tool() {
        let tool = EchoTool;
        let mut args = HashMap::new();
        args.insert("message".to_string(), json!("test message"));

        let result = tool.call(args).await.unwrap();
        match &result.content[0] {
            Content::Text { text } => assert_eq!(text, "test message"),
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_echo_tool_default_message() {
        let result = EchoTool.call(HashMap::new()).await.unwrap();
        match &result.content[0] {
            Content::Text { text } => assert_eq!(text, "Hello, World!"),
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_addition_tool() {
        let tool = AdditionTool;
        let mut args = HashMap::new();
        args.insert("a".to_string(), json!(5.0));
        args.insert("b".to_string(), json!(3.0));

        let result = tool.call(args).await.unwrap();
        match &result.content[0] {
            Content::Text { text } => assert_eq!(text, "8"),
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_addition_tool_missing_parameter() {
        let mut args = HashMap::new();
        args.insert("a".to_string(), json!(1));

        let result = AdditionTool.call(args).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("'b'")),
            _ => panic!("Expected validation error"),
        }
    }

    #[tokio::test]
    async fn test_addition_tool_rejects_non_numeric() {
        let mut args = HashMap::new();
        args.insert("a".to_string(), json!("5"));
        args.insert("b".to_string(), json!(3));

        assert!(AdditionTool.call(args).await.is_err());
    }

    #[tokio::test]
    async fn test_timestamp_tool() {
        let result = TimestampTool.call(HashMap::new()).await.unwrap();
        match &result.content[0] {
            Content::Text { text } => {
                let secs: u64 = text.parse().expect("timestamp should be an integer");
                assert!(secs > 0);
            }
            _ => panic!("Expected text content"),
        }
    }

    #[test]
    fn test_tool_creation() {
        let tool = Tool::new(
            "test_tool".to_string(),
            Some("Test tool".to_string()),
            json!({"type": "object"}),
            EchoTool,
        );

        assert_eq!(tool.info.name, "test_tool");
        assert_eq!(tool.info.description, Some("Test tool".to_string()));
        assert!(tool.is_enabled());
    }

    #[test]
    fn test_tool_enable_disable() {
        let mut tool = Tool::new(
            "test_tool".to_string(),
            None,
            json!({"type": "object"}),
            EchoTool,
        );

        assert!(tool.is_enabled());

        tool.disable();
        assert!(!tool.is_enabled());

        tool.enable();
        assert!(tool.is_enabled());
    }

    #[tokio::test]
    async fn test_disabled_tool() {
        let mut tool = Tool::new(
            "test_tool".to_string(),
            None,
            json!({"type": "object"}),
            EchoTool,
        );

        tool.disable();

        let result = tool.call(HashMap::new()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("disabled")),
            _ => panic!("Expected validation error"),
        }
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test")
            .description("A test tool")
            .schema(json!({"type": "object", "properties": {"x": {"type": "number"}}}))
            .build(EchoTool)
            .unwrap();

        assert_eq!(tool.info.name, "test");
        assert_eq!(tool.info.description, Some("A test tool".to_string()));
    }

    #[test]
    fn test_tool_builder_default_schema() {
        let tool = ToolBuilder::new("defaults").build(EchoTool).unwrap();

        assert_eq!(tool.info.input_schema["type"], "object");
        assert_eq!(tool.info.description, None);
        assert!(tool.is_enabled());
    }
}