mcp_core/tools/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{collections::HashMap, sync::Arc};
use test_tool::{PingTool, TestTool};
use tokio::sync::RwLock;

pub mod calculator;
pub mod file_system;
pub mod test_tool;

use crate::{client::Client, error::McpError};

/// Identifies one of the built-in tool implementations declared in this module.
///
/// Serialized as a snake_case string: `"calculator"`, `"test_tool"`,
/// `"ping_tool"` or `"file_system"`. Use [`ToolType::to_tool_provider`] to
/// obtain the corresponding provider.
// The type is fieldless, so `Copy`/`PartialEq`/`Eq`/`Hash` are free. They let
// callers compare tool kinds and use them as `HashSet`/`HashMap` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolType {
    /// The calculator tool (`calculator::CalculatorTool`).
    Calculator,
    /// The test tool (`test_tool::TestTool`).
    TestTool,
    /// The ping tool (`test_tool::PingTool`).
    PingTool,
    /// The file-system tools (`file_system::FileSystemTools`).
    FileSystem,
}

impl ToolType {
    /// Constructs the provider implementation matching this variant.
    ///
    /// Every call builds a brand-new provider via its `new()` constructor and
    /// hands it back as a shared `Arc<dyn ToolProvider>` handle; nothing is cached.
    pub fn to_tool_provider(&self) -> Arc<dyn ToolProvider> {
        match *self {
            Self::Calculator => Arc::new(calculator::CalculatorTool::new()),
            Self::TestTool => Arc::new(TestTool::new()),
            Self::PingTool => Arc::new(PingTool::new()),
            Self::FileSystem => Arc::new(file_system::FileSystemTools::new()),
        }
    }
}

// Tool Types
/// Definition of a tool: its name, a description, and the schema of its arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    /// Tool name; `ToolManager::register_tool` uses it as the registry key.
    pub name: String,
    /// Human-readable description of what the tool does.
    // `description` is optional in the MCP tool schema. Default it so that a
    // definition omitting it deserializes (as "") instead of failing outright.
    #[serde(default)]
    pub description: String,
    /// Schema for the tool's `arguments` (serialized as `inputSchema`).
    pub input_schema: ToolInputSchema,
}

/// JSON Schema object describing the arguments a tool accepts.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolInputSchema {
    /// Schema type, serialized under the key `"type"` (MCP expects `"object"`).
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Per-argument schemas keyed by argument name.
    // `properties` and `required` are both optional in JSON Schema. Default them
    // to empty so schemas that omit either key still deserialize.
    #[serde(default)]
    pub properties: HashMap<String, Value>,
    /// Names of the arguments that must be supplied.
    #[serde(default)]
    pub required: Vec<String>,
}

/// One item of content produced by a tool call.
///
/// Serialized as an internally tagged object whose `"type"` field is
/// `"text"`, `"image"` or `"resource"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
// NOTE: on an enum, a container-level `rename_all` renames only the *variants*
// (which are all explicitly renamed below anyway). It does NOT rename fields
// inside struct variants, so those need field-level renames.
#[serde(rename_all = "camelCase")]
pub enum ToolContent {
    /// Plain text output.
    #[serde(rename = "text")]
    Text { text: String },
    /// Image output: encoded image `data` plus its MIME type.
    /// NOTE(review): `data` is presumably base64 per the MCP spec — confirm with producers.
    #[serde(rename = "image")]
    Image {
        data: String,
        // The MCP wire name is `mimeType`. Without this rename the field was
        // emitted as `mime_type`. `alias` keeps accepting the old spelling on input.
        #[serde(rename = "mimeType", alias = "mime_type")]
        mime_type: String,
    },
    /// A resource embedded directly in the result.
    #[serde(rename = "resource")]
    Resource { resource: ResourceContent },
}

/// Contents of a resource embedded in tool output (see `ToolContent::Resource`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceContent {
    /// URI identifying the resource.
    pub uri: String,
    /// MIME type of the resource, if known (serialized as `mimeType`).
    // Omit absent optional fields instead of emitting `null`: the MCP schema types
    // them as optional strings, and this matches `ListToolsResponse::next_cursor`.
    // A missing key still deserializes to `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
    /// Textual contents of the resource, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
}

/// Outcome of executing a tool: its content items and an error flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResult {
    /// Content items produced by the tool.
    pub content: Vec<ToolContent>,
    /// Whether the tool reported a failure (serialized as `isError`).
    // `isError` is optional in the MCP schema and defaults to false. Mirror that
    // so results omitting it deserialize instead of erroring.
    #[serde(default)]
    pub is_error: bool,
}

// Request/Response types
/// Parameters for listing the available tools.
#[derive(Debug, Deserialize, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ListToolsRequest {
    /// Opaque pagination cursor taken from a previous response's `next_cursor`.
    // Omit the key when `None` instead of sending `"cursor": null`, consistent
    // with `ListToolsResponse::next_cursor`. A missing key still deserializes to `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor: Option<String>,
}

/// Result of listing tools, as produced by `ToolManager::list_tools`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListToolsResponse {
    /// Definitions of the available tools.
    pub tools: Vec<Tool>,
    /// Cursor for the next page (serialized as `nextCursor`); the key is
    /// omitted from the output when `None`.
    /// `ToolManager::list_tools` currently always sets this to `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub next_cursor: Option<String>,
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolRequest {
    pub name: String,
    pub arguments: Value,
}

// Tool Provider trait
/// A single tool that can describe itself and be executed with JSON arguments.
///
/// The `Send + Sync` bound lets implementors be stored and shared as
/// `Arc<dyn ToolProvider>`, which is how `ToolManager` holds them.
#[async_trait]
pub trait ToolProvider: Send + Sync {
    /// Get tool definition.
    ///
    /// `ToolManager::register_tool` uses the returned `name` as the registry key,
    /// and `ToolManager::list_tools` calls this for every registered provider.
    async fn get_tool(&self) -> Tool;

    /// Execute tool.
    ///
    /// `arguments` is the JSON value passed to `ToolManager::call_tool`,
    /// forwarded unchanged.
    /// NOTE(review): unclear from here whether tool-level failures should be an
    /// `Err` or an `Ok(ToolResult { is_error: true, .. })` — confirm convention.
    async fn execute(&self, arguments: Value) -> Result<ToolResult, McpError>;
}

// Tool capabilities
/// Capability flags for the tool subsystem, held by `ToolManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCapabilities {
    /// Whether tool-list change notifications are supported (serialized as `listChanged`).
    // `listChanged` is optional in the MCP capabilities object, so treat a
    // missing key as `false` rather than failing to deserialize.
    #[serde(default)]
    pub list_changed: bool,
}

/// Registry of tool providers, together with the tool capabilities it was created with.
pub struct ToolManager {
    /// Registered providers, keyed by the `name` that each provider's `get_tool`
    /// reported when it was registered. Wrapped in `Arc<RwLock<..>>` (tokio's
    /// async `RwLock`) so the `&self` methods can both read and modify it.
    pub tools: Arc<RwLock<HashMap<String, Arc<dyn ToolProvider>>>>,
    /// Capabilities passed to `ToolManager::new`.
    pub capabilities: ToolCapabilities,
}

impl ToolManager {
    /// Creates a manager with no registered tools that carries `capabilities`.
    pub fn new(capabilities: ToolCapabilities) -> Self {
        Self {
            tools: Arc::new(RwLock::new(HashMap::new())),
            capabilities,
        }
    }

    /// Registers `provider` under the name reported by its `get_tool`.
    ///
    /// A provider already registered under the same name is replaced.
    pub async fn register_tool(&self, provider: Arc<dyn ToolProvider>) {
        // Fetch the definition before taking the write lock so the async
        // `get_tool` call never runs while writers hold the map.
        let tool = provider.get_tool().await;
        self.tools.write().await.insert(tool.name, provider);
    }

    /// Returns the definitions of all registered tools, sorted by name.
    ///
    /// Pagination is not implemented: `_cursor` is ignored and `next_cursor`
    /// is always `None`.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub async fn list_tools(&self, _cursor: Option<String>) -> Result<ListToolsResponse, McpError> {
        // Snapshot the provider handles and release the read lock immediately,
        // so it is not held across the `get_tool` awaits below.
        let providers: Vec<Arc<dyn ToolProvider>> =
            self.tools.read().await.values().cloned().collect();

        let mut tool_list = Vec::with_capacity(providers.len());
        for provider in &providers {
            tool_list.push(provider.get_tool().await);
        }
        // HashMap iteration order is unspecified; sort so listings are stable.
        tool_list.sort_by(|a, b| a.name.cmp(&b.name));

        Ok(ListToolsResponse {
            tools: tool_list,
            next_cursor: None, // Implement pagination if needed
        })
    }

    /// Executes the tool registered under `name` with the given JSON `arguments`.
    ///
    /// # Errors
    ///
    /// Returns [`McpError::InvalidRequest`] if no tool is registered under
    /// `name`; otherwise returns whatever the provider's `execute` returns.
    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<ToolResult, McpError> {
        // Clone the provider handle and drop the read guard before executing.
        // Tool calls can be slow, and tokio's RwLock is write-preferring: while a
        // writer waits, new readers queue behind it. Holding the guard here would
        // therefore let one long call stall registration and every other
        // list/call. It would also deadlock a tool that calls `register_tool`.
        let provider = self
            .tools
            .read()
            .await
            .get(name)
            .cloned()
            .ok_or_else(|| McpError::InvalidRequest(format!("Unknown tool: {}", name)))?;

        provider.execute(arguments).await
    }
}