1#![cfg(feature = "mcp")]
38
39use std::collections::HashMap;
40use std::path::Path;
41use std::sync::{Arc, OnceLock};
42
43use modular_agent_core::{AgentContext, AgentError, AgentValue, async_trait};
44use rmcp::{
45 model::{CallToolRequestParam, CallToolResult},
46 service::ServiceExt,
47 transport::{ConfigureCommandExt, TokioChildProcess},
48};
49use serde::Deserialize;
50use tokio::process::Command;
51use tokio::sync::Mutex as AsyncMutex;
52
53use crate::tool::{Tool, ToolInfo, register_tool};
54
55struct MCPTool {
59 server_name: String,
61 server_config: MCPServerConfig,
63 tool: rmcp::model::Tool,
65 info: ToolInfo,
67}
68
69impl MCPTool {
70 fn new(
79 name: String,
80 server_name: String,
81 server_config: MCPServerConfig,
82 tool: rmcp::model::Tool,
83 ) -> Self {
84 let info = ToolInfo {
85 name,
86 description: tool.description.clone().unwrap_or_default().into_owned(),
87 parameters: serde_json::to_value(&tool.input_schema).ok(),
88 };
89 Self {
90 server_name,
91 server_config,
92 tool,
93 info,
94 }
95 }
96
97 async fn tool_call(
101 &self,
102 _ctx: AgentContext,
103 value: AgentValue,
104 ) -> Result<AgentValue, AgentError> {
105 let conn = {
107 let mut pool = connection_pool().lock().await;
108 pool.get_or_create(&self.server_name, &self.server_config)
109 .await?
110 };
111
112 let arguments = value.as_object().map(|obj| {
113 obj.iter()
114 .map(|(k, v)| {
115 (
116 k.clone(),
117 serde_json::to_value(v).unwrap_or(serde_json::Value::Null),
118 )
119 })
120 .collect::<serde_json::Map<String, serde_json::Value>>()
121 });
122
123 let tool_result = {
124 let connection = conn.lock().await;
125 let service = connection.service.as_ref().ok_or_else(|| {
126 AgentError::Other(format!(
127 "MCP service for '{}' is not available",
128 self.server_name
129 ))
130 })?;
131 service
132 .call_tool(CallToolRequestParam {
133 name: self.tool.name.clone().into(),
134 arguments,
135 task: None,
136 })
137 .await
138 .map_err(|e| {
139 AgentError::Other(format!("Failed to call tool '{}': {e}", self.tool.name))
140 })?
141 };
142
143 Ok(call_tool_result_to_agent_value(tool_result)?)
144 }
145}
146
147#[async_trait]
148impl Tool for MCPTool {
149 fn info(&self) -> &ToolInfo {
150 &self.info
151 }
152
153 async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError> {
154 self.tool_call(ctx, args).await
155 }
156}
157
158#[derive(Debug, Deserialize)]
175pub struct MCPConfig {
176 #[serde(rename = "mcpServers")]
178 pub mcp_servers: HashMap<String, MCPServerConfig>,
179}
180
181#[derive(Debug, Clone, Deserialize)]
185pub struct MCPServerConfig {
186 pub command: String,
188
189 pub args: Vec<String>,
191
192 #[serde(default)]
194 pub env: Option<HashMap<String, String>>,
195}
196
197type MCPService = rmcp::service::RunningService<rmcp::service::RoleClient, ()>;
199
200struct MCPConnection {
202 service: Option<MCPService>,
204}
205
206struct MCPConnectionPool {
211 connections: HashMap<String, Arc<AsyncMutex<MCPConnection>>>,
213}
214
215impl MCPConnectionPool {
216 fn new() -> Self {
218 Self {
219 connections: HashMap::new(),
220 }
221 }
222
223 async fn get_or_create(
228 &mut self,
229 server_name: &str,
230 config: &MCPServerConfig,
231 ) -> Result<Arc<AsyncMutex<MCPConnection>>, AgentError> {
232 if let Some(conn) = self.connections.get(server_name) {
234 log::debug!("Reusing existing MCP connection for '{}'", server_name);
235 return Ok(conn.clone());
236 }
237
238 log::info!(
239 "Starting MCP server '{}' (command: {})",
240 server_name,
241 config.command
242 );
243
244 let service = ()
246 .serve(
247 TokioChildProcess::new(Command::new(&config.command).configure(|cmd| {
248 for arg in &config.args {
249 cmd.arg(arg);
250 }
251 if let Some(env) = &config.env {
252 for (key, value) in env {
253 cmd.env(key, value);
254 }
255 }
256 }))
257 .map_err(|e| {
258 log::error!("Failed to start MCP process for '{}': {}", server_name, e);
259 AgentError::Other(format!(
260 "Failed to start MCP process for '{}': {e}",
261 server_name
262 ))
263 })?,
264 )
265 .await
266 .map_err(|e| {
267 log::error!("Failed to start MCP service for '{}': {}", server_name, e);
268 AgentError::Other(format!(
269 "Failed to start MCP service for '{}': {e}",
270 server_name
271 ))
272 })?;
273
274 log::info!("Successfully started MCP server '{}'", server_name);
275
276 let connection = MCPConnection {
277 service: Some(service),
278 };
279
280 let conn_arc = Arc::new(AsyncMutex::new(connection));
281 self.connections
282 .insert(server_name.to_string(), conn_arc.clone());
283 Ok(conn_arc)
284 }
285
286 async fn shutdown_all(&mut self) -> Result<(), AgentError> {
290 let count = self.connections.len();
291 log::debug!("Shutting down {} MCP server connection(s)", count);
292
293 for (name, conn) in self.connections.drain() {
294 log::debug!("Shutting down MCP server '{}'", name);
295 let mut connection = conn.lock().await;
296 if let Some(service) = connection.service.take() {
297 service.cancel().await.map_err(|e| {
298 log::error!("Failed to cancel MCP service '{}': {}", name, e);
299 AgentError::Other(format!("Failed to cancel MCP service: {e}"))
300 })?;
301 log::debug!("Successfully shut down MCP server '{}'", name);
302 }
303 }
304 Ok(())
305 }
306}
307
308static CONNECTION_POOL: OnceLock<AsyncMutex<MCPConnectionPool>> = OnceLock::new();
310
311fn connection_pool() -> &'static AsyncMutex<MCPConnectionPool> {
313 CONNECTION_POOL.get_or_init(|| AsyncMutex::new(MCPConnectionPool::new()))
314}
315
316pub async fn shutdown_all_mcp_connections() -> Result<(), AgentError> {
335 log::info!("Shutting down all MCP server connections");
336 connection_pool().lock().await.shutdown_all().await?;
337 log::info!("All MCP server connections shut down successfully");
338 Ok(())
339}
340
341async fn register_tools_from_server(
355 server_name: String,
356 server_config: MCPServerConfig,
357) -> Result<Vec<String>, AgentError> {
358 log::debug!("Registering tools from MCP server '{}'", server_name);
359
360 let conn = {
362 let mut pool = connection_pool().lock().await;
363 pool.get_or_create(&server_name, &server_config).await?
364 };
365
366 log::debug!("Listing tools from MCP server '{}'", server_name);
368 let tools_list = {
369 let connection = conn.lock().await;
370 let service = connection.service.as_ref().ok_or_else(|| {
371 log::error!("MCP service for '{}' is not available", server_name);
372 AgentError::Other(format!(
373 "MCP service for '{}' is not available",
374 server_name
375 ))
376 })?;
377 service.list_tools(Default::default()).await.map_err(|e| {
378 log::error!("Failed to list MCP tools for '{}': {}", server_name, e);
379 AgentError::Other(format!(
380 "Failed to list MCP tools for '{}': {e}",
381 server_name
382 ))
383 })?
384 };
385
386 let mut registered_tool_names = Vec::new();
387
388 for tool_info in tools_list.tools {
390 let mcp_tool_name = format!("{}::{}", server_name, tool_info.name);
391 registered_tool_names.push(mcp_tool_name.clone());
392
393 register_tool(MCPTool::new(
394 mcp_tool_name.clone(),
395 server_name.clone(),
396 server_config.clone(),
397 tool_info,
398 ));
399 log::debug!("Registered MCP tool '{}'", mcp_tool_name);
400 }
401
402 log::info!(
403 "Registered {} tools from MCP server '{}'",
404 registered_tool_names.len(),
405 server_name
406 );
407
408 Ok(registered_tool_names)
409}
410
411pub async fn register_tools_from_mcp_json<P: AsRef<Path>>(
431 json_path: P,
432) -> Result<Vec<String>, AgentError> {
433 let path = json_path.as_ref();
434 log::info!("Loading MCP configuration from: {}", path.display());
435
436 let json_content = std::fs::read_to_string(path).map_err(|e| {
438 log::error!("Failed to read MCP config file '{}': {}", path.display(), e);
439 AgentError::Other(format!("Failed to read MCP config file: {e}"))
440 })?;
441
442 let config: MCPConfig = serde_json::from_str(&json_content).map_err(|e| {
444 log::error!("Failed to parse MCP config JSON: {}", e);
445 AgentError::Other(format!("Failed to parse MCP config JSON: {e}"))
446 })?;
447
448 log::info!("Found {} MCP servers in config", config.mcp_servers.len());
449
450 let mut registered_tool_names = Vec::new();
451
452 for (server_name, server_config) in config.mcp_servers {
454 let tools = register_tools_from_server(server_name, server_config).await?;
455 registered_tool_names.extend(tools);
456 }
457
458 log::info!(
459 "Successfully registered {} MCP tools total",
460 registered_tool_names.len()
461 );
462
463 Ok(registered_tool_names)
464}
465
466fn call_tool_result_to_agent_value(result: CallToolResult) -> Result<AgentValue, AgentError> {
471 let mut contents = Vec::new();
472 for c in result.content.iter() {
473 match &c.raw {
474 rmcp::model::RawContent::Text(text) => {
475 contents.push(AgentValue::string(text.text.clone()));
476 }
477 _ => {
478 }
480 }
481 }
482 let data = AgentValue::array(contents.into());
483 if result.is_error == Some(true) {
484 return Err(AgentError::Other(
485 serde_json::to_string(&data).map_err(|e| AgentError::InvalidValue(e.to_string()))?,
486 ));
487 }
488 Ok(data)
489}