mcp_sdk/
tools.rs

1use crate::types::{CallToolRequest, CallToolResponse, ToolDefinition, ToolResponseContent};
2use anyhow::Result;
3use std::{collections::HashMap, sync::Arc};
4
5pub trait Tool: Send + Sync + 'static {
6    fn name(&self) -> String;
7    fn description(&self) -> String;
8    fn input_schema(&self) -> serde_json::Value;
9    fn call(&self, input: Option<serde_json::Value>) -> Result<CallToolResponse>;
10    fn as_definition(&self) -> ToolDefinition {
11        ToolDefinition {
12            name: self.name(),
13            description: Some(self.description()),
14            input_schema: self.input_schema(),
15        }
16    }
17}
18
19#[derive(Default)]
20pub struct Tools {
21    tools: HashMap<String, Arc<dyn Tool>>,
22}
23
24impl Tools {
25    pub fn add_tool(&mut self, tool: impl Tool) {
26        self.tools.insert(tool.name(), Arc::new(tool));
27    }
28
29    pub fn list_tools(&self) -> Vec<ToolDefinition> {
30        self.tools
31            .values()
32            .map(|tool| tool.as_definition())
33            .collect()
34    }
35
36    pub fn call_tool(&self, request: CallToolRequest) -> CallToolResponse {
37        let tool = self.tools.get(&request.name);
38        if tool.is_none() {
39            return CallToolResponse {
40                content: vec![ToolResponseContent::Text {
41                    text: format!("Tool {} not found", request.name),
42                }],
43                is_error: Some(true),
44                meta: None,
45            };
46        }
47        let arguments = request.arguments;
48        let result = tool.unwrap().call(arguments);
49        if result.is_err() {
50            return CallToolResponse {
51                content: vec![ToolResponseContent::Text {
52                    text: format!(
53                        "Error calling tool {}: {}",
54                        &request.name,
55                        result.err().unwrap()
56                    ),
57                }],
58                is_error: Some(true),
59                meta: None,
60            };
61        }
62        result.unwrap()
63    }
64}