1#![cfg(feature = "mcp")]
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::{Arc, OnceLock};
6
7use agent_stream_kit::{AgentContext, AgentError, AgentValue, async_trait};
8use rmcp::{
9 model::{CallToolRequestParam, CallToolResult},
10 service::ServiceExt,
11 transport::{ConfigureCommandExt, TokioChildProcess},
12};
13use serde::Deserialize;
14use tokio::process::Command;
15use tokio::sync::Mutex as AsyncMutex;
16
17use crate::tool::{Tool, ToolInfo, register_tool};
18
19struct MCPTool {
21 server_name: String,
22 server_config: MCPServerConfig,
23 tool: rmcp::model::Tool,
24 info: ToolInfo,
25}
26
27impl MCPTool {
28 fn new(
29 name: String,
30 server_name: String,
31 server_config: MCPServerConfig,
32 tool: rmcp::model::Tool,
33 ) -> Self {
34 let info = ToolInfo {
35 name,
36 description: tool.description.clone().unwrap_or_default().into_owned(),
37 parameters: serde_json::to_value(&tool.input_schema).ok(),
38 };
39 Self {
40 server_name,
41 server_config,
42 tool,
43 info,
44 }
45 }
46
47 async fn tool_call(
48 &self,
49 _ctx: AgentContext,
50 value: AgentValue,
51 ) -> Result<AgentValue, AgentError> {
52 let conn = {
54 let mut pool = connection_pool().lock().await;
55 pool.get_or_create(&self.server_name, &self.server_config)
56 .await?
57 };
58
59 let arguments = value.as_object().map(|obj| {
60 obj.iter()
61 .map(|(k, v)| {
62 (
63 k.clone(),
64 serde_json::to_value(v).unwrap_or(serde_json::Value::Null),
65 )
66 })
67 .collect::<serde_json::Map<String, serde_json::Value>>()
68 });
69
70 let tool_result = {
71 let connection = conn.lock().await;
72 let service = connection.service.as_ref().ok_or_else(|| {
73 AgentError::Other(format!(
74 "MCP service for '{}' is not available",
75 self.server_name
76 ))
77 })?;
78 service
79 .call_tool(CallToolRequestParam {
80 name: self.tool.name.clone().into(),
81 arguments,
82 task: None,
83 })
84 .await
85 .map_err(|e| {
86 AgentError::Other(format!("Failed to call tool '{}': {e}", self.tool.name))
87 })?
88 };
89
90 Ok(call_tool_result_to_agent_value(tool_result)?)
91 }
92}
93
94#[async_trait]
95impl Tool for MCPTool {
96 fn info(&self) -> &ToolInfo {
97 &self.info
98 }
99
100 async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError> {
101 self.tool_call(ctx, args).await
102 }
103}
104
105#[derive(Debug, Deserialize)]
107pub struct MCPConfig {
108 #[serde(rename = "mcpServers")]
109 pub mcp_servers: HashMap<String, MCPServerConfig>,
110}
111
112#[derive(Debug, Clone, Deserialize)]
113pub struct MCPServerConfig {
114 pub command: String,
115 pub args: Vec<String>,
116 #[serde(default)]
117 pub env: Option<HashMap<String, String>>,
118}
119
120type MCPService = rmcp::service::RunningService<rmcp::service::RoleClient, ()>;
121
122struct MCPConnection {
124 service: Option<MCPService>,
125}
126
127struct MCPConnectionPool {
129 connections: HashMap<String, Arc<AsyncMutex<MCPConnection>>>,
130}
131
132impl MCPConnectionPool {
133 fn new() -> Self {
134 Self {
135 connections: HashMap::new(),
136 }
137 }
138
139 async fn get_or_create(
140 &mut self,
141 server_name: &str,
142 config: &MCPServerConfig,
143 ) -> Result<Arc<AsyncMutex<MCPConnection>>, AgentError> {
144 if let Some(conn) = self.connections.get(server_name) {
146 log::debug!("Reusing existing MCP connection for '{}'", server_name);
147 return Ok(conn.clone());
148 }
149
150 log::info!(
151 "Starting MCP server '{}' (command: {})",
152 server_name,
153 config.command
154 );
155
156 let service = ()
158 .serve(
159 TokioChildProcess::new(Command::new(&config.command).configure(|cmd| {
160 for arg in &config.args {
161 cmd.arg(arg);
162 }
163 if let Some(env) = &config.env {
164 for (key, value) in env {
165 cmd.env(key, value);
166 }
167 }
168 }))
169 .map_err(|e| {
170 log::error!("Failed to start MCP process for '{}': {}", server_name, e);
171 AgentError::Other(format!(
172 "Failed to start MCP process for '{}': {e}",
173 server_name
174 ))
175 })?,
176 )
177 .await
178 .map_err(|e| {
179 log::error!("Failed to start MCP service for '{}': {}", server_name, e);
180 AgentError::Other(format!(
181 "Failed to start MCP service for '{}': {e}",
182 server_name
183 ))
184 })?;
185
186 log::info!("Successfully started MCP server '{}'", server_name);
187
188 let connection = MCPConnection {
189 service: Some(service),
190 };
191
192 let conn_arc = Arc::new(AsyncMutex::new(connection));
193 self.connections
194 .insert(server_name.to_string(), conn_arc.clone());
195 Ok(conn_arc)
196 }
197
198 async fn shutdown_all(&mut self) -> Result<(), AgentError> {
199 let count = self.connections.len();
200 log::debug!("Shutting down {} MCP server connection(s)", count);
201
202 for (name, conn) in self.connections.drain() {
203 log::debug!("Shutting down MCP server '{}'", name);
204 let mut connection = conn.lock().await;
205 if let Some(service) = connection.service.take() {
206 service.cancel().await.map_err(|e| {
207 log::error!("Failed to cancel MCP service '{}': {}", name, e);
208 AgentError::Other(format!("Failed to cancel MCP service: {e}"))
209 })?;
210 log::debug!("Successfully shut down MCP server '{}'", name);
211 }
212 }
213 Ok(())
214 }
215}
216
217static CONNECTION_POOL: OnceLock<AsyncMutex<MCPConnectionPool>> = OnceLock::new();
219
220fn connection_pool() -> &'static AsyncMutex<MCPConnectionPool> {
221 CONNECTION_POOL.get_or_init(|| AsyncMutex::new(MCPConnectionPool::new()))
222}
223
224pub async fn shutdown_all_mcp_connections() -> Result<(), AgentError> {
226 log::info!("Shutting down all MCP server connections");
227 connection_pool().lock().await.shutdown_all().await?;
228 log::info!("All MCP server connections shut down successfully");
229 Ok(())
230}
231
232async fn register_tools_from_server(
241 server_name: String,
242 server_config: MCPServerConfig,
243) -> Result<Vec<String>, AgentError> {
244 log::debug!("Registering tools from MCP server '{}'", server_name);
245
246 let conn = {
248 let mut pool = connection_pool().lock().await;
249 pool.get_or_create(&server_name, &server_config).await?
250 };
251
252 log::debug!("Listing tools from MCP server '{}'", server_name);
254 let tools_list = {
255 let connection = conn.lock().await;
256 let service = connection.service.as_ref().ok_or_else(|| {
257 log::error!("MCP service for '{}' is not available", server_name);
258 AgentError::Other(format!(
259 "MCP service for '{}' is not available",
260 server_name
261 ))
262 })?;
263 service.list_tools(Default::default()).await.map_err(|e| {
264 log::error!("Failed to list MCP tools for '{}': {}", server_name, e);
265 AgentError::Other(format!(
266 "Failed to list MCP tools for '{}': {e}",
267 server_name
268 ))
269 })?
270 };
271
272 let mut registered_tool_names = Vec::new();
273
274 for tool_info in tools_list.tools {
276 let mcp_tool_name = format!("{}::{}", server_name, tool_info.name);
277 registered_tool_names.push(mcp_tool_name.clone());
278
279 register_tool(MCPTool::new(
280 mcp_tool_name.clone(),
281 server_name.clone(),
282 server_config.clone(),
283 tool_info,
284 ));
285 log::debug!("Registered MCP tool '{}'", mcp_tool_name);
286 }
287
288 log::info!(
289 "Registered {} tools from MCP server '{}'",
290 registered_tool_names.len(),
291 server_name
292 );
293
294 Ok(registered_tool_names)
295}
296
297pub async fn register_tools_from_mcp_json<P: AsRef<Path>>(
317 json_path: P,
318) -> Result<Vec<String>, AgentError> {
319 let path = json_path.as_ref();
320 log::info!("Loading MCP configuration from: {}", path.display());
321
322 let json_content = std::fs::read_to_string(path).map_err(|e| {
324 log::error!("Failed to read MCP config file '{}': {}", path.display(), e);
325 AgentError::Other(format!("Failed to read MCP config file: {e}"))
326 })?;
327
328 let config: MCPConfig = serde_json::from_str(&json_content).map_err(|e| {
330 log::error!("Failed to parse MCP config JSON: {}", e);
331 AgentError::Other(format!("Failed to parse MCP config JSON: {e}"))
332 })?;
333
334 log::info!("Found {} MCP servers in config", config.mcp_servers.len());
335
336 let mut registered_tool_names = Vec::new();
337
338 for (server_name, server_config) in config.mcp_servers {
340 let tools = register_tools_from_server(server_name, server_config).await?;
341 registered_tool_names.extend(tools);
342 }
343
344 log::info!(
345 "Successfully registered {} MCP tools total",
346 registered_tool_names.len()
347 );
348
349 Ok(registered_tool_names)
350}
351
352fn call_tool_result_to_agent_value(result: CallToolResult) -> Result<AgentValue, AgentError> {
353 let mut contents = Vec::new();
354 for c in result.content.iter() {
355 match &c.raw {
356 rmcp::model::RawContent::Text(text) => {
357 contents.push(AgentValue::string(text.text.clone()));
358 }
359 _ => {
360 }
362 }
363 }
364 let data = AgentValue::array(contents.into());
365 if result.is_error == Some(true) {
366 return Err(AgentError::Other(
367 serde_json::to_string(&data).map_err(|e| AgentError::InvalidValue(e.to_string()))?,
368 ));
369 }
370 Ok(data)
371}