Skip to main content

agent_base/tool/
mcp.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde_json::{json, Value};
6
7use crate::types::{AgentError, AgentResult};
8use super::{Tool, ToolContext, ToolControlFlow, ToolOutput};
9
10pub struct McpToolInfo {
11    pub name: String,
12    pub description: String,
13    pub input_schema: Value,
14}
15
16pub struct McpClient {
17    server_url: String,
18    client: Client,
19}
20
21impl McpClient {
22    pub fn new(server_url: String) -> Self {
23        Self {
24            server_url,
25            client: Client::new(),
26        }
27    }
28
29    async fn send_request(&self, method: &str, params: Value) -> AgentResult<Value> {
30        let request = json!({
31            "jsonrpc": "2.0",
32            "id": 1,
33            "method": method,
34            "params": params,
35        });
36
37        let response = self
38            .client
39            .post(&self.server_url)
40            .header("Content-Type", "application/json")
41            .json(&request)
42            .send()
43            .await
44            .map_err(|e| AgentError::internal(format!("MCP request failed: {e}")))?;
45
46        let res: Value = response.json().await.map_err(|e| {
47            AgentError::json(format!("MCP response parse: {e}"))
48        })?;
49
50        if let Some(error) = res.get("error") {
51            return Err(AgentError::internal(format!("MCP error: {error}")));
52        }
53
54        Ok(res.get("result").cloned().unwrap_or(Value::Null))
55    }
56
57    pub async fn list_tools(&self) -> AgentResult<Vec<McpToolInfo>> {
58        let result = self.send_request("tools/list", json!({})).await?;
59        let tools = result
60            .get("tools")
61            .and_then(Value::as_array)
62            .ok_or_else(|| AgentError::internal("MCP: invalid tools/list response"))?;
63
64        let mut infos = Vec::new();
65        for tool in tools {
66            let name = tool
67                .get("name")
68                .and_then(Value::as_str)
69                .unwrap_or("unknown")
70                .to_string();
71            let description = tool
72                .get("description")
73                .and_then(Value::as_str)
74                .unwrap_or("")
75                .to_string();
76            let input_schema = tool
77                .get("inputSchema")
78                .cloned()
79                .unwrap_or_else(|| json!({"type": "object"}));
80            infos.push(McpToolInfo {
81                name,
82                description,
83                input_schema,
84            });
85        }
86        Ok(infos)
87    }
88
89    pub async fn call_tool(&self, tool_name: &str, arguments: &Value) -> AgentResult<Value> {
90        self.send_request(
91            "tools/call",
92            json!({
93                "name": tool_name,
94                "arguments": arguments,
95            }),
96        )
97        .await
98    }
99}
100
101struct McpToolAdapter {
102    name: &'static str,
103    description: String,
104    input_schema: Value,
105    mcp_client: Arc<McpClient>,
106}
107
108impl McpToolAdapter {
109    fn new(info: McpToolInfo, mcp_client: Arc<McpClient>) -> Self {
110        let static_name: &'static str = Box::leak(info.name.into_boxed_str());
111        Self {
112            name: static_name,
113            description: info.description,
114            input_schema: info.input_schema,
115            mcp_client,
116        }
117    }
118}
119
120#[async_trait]
121impl Tool for McpToolAdapter {
122    fn name(&self) -> &'static str {
123        self.name
124    }
125
126    fn definition(&self) -> Value {
127        json!({
128            "type": "function",
129            "function": {
130                "name": self.name,
131                "description": self.description,
132                "parameters": self.input_schema,
133            }
134        })
135    }
136
137    async fn call(&self, args: &Value, _ctx: &ToolContext) -> AgentResult<ToolOutput> {
138        let result = self.mcp_client.call_tool(self.name, args).await?;
139        let content = result
140            .get("content")
141            .and_then(|c| c.as_array())
142            .map(|arr| {
143                arr.iter()
144                    .filter_map(|item| item.get("text").and_then(Value::as_str))
145                    .collect::<Vec<_>>()
146                    .join("\n")
147            })
148            .filter(|s| !s.is_empty())
149            .unwrap_or_else(|| result.to_string());
150
151        Ok(ToolOutput {
152            summary: content,
153            raw: Some(result),
154            control_flow: ToolControlFlow::Break,
155            truncated: false,
156        })
157    }
158}
159
160pub struct McpToolRegistry {
161    mcp_client: Arc<McpClient>,
162}
163
164impl McpToolRegistry {
165    pub fn new(server_url: String) -> Self {
166        Self {
167            mcp_client: Arc::new(McpClient::new(server_url)),
168        }
169    }
170
171    pub async fn discover_tools(&self) -> AgentResult<Vec<Arc<dyn Tool>>> {
172        let infos = self.mcp_client.list_tools().await?;
173        let tools: Vec<Arc<dyn Tool>> = infos
174            .into_iter()
175            .map(|info| {
176                Arc::new(McpToolAdapter::new(info, self.mcp_client.clone())) as Arc<dyn Tool>
177            })
178            .collect();
179        Ok(tools)
180    }
181}