agent_stream_kit/
mcp.rs

1#![cfg(feature = "mcp")]
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::{Arc, OnceLock};
6
7use agent_stream_kit::{AgentContext, AgentError, AgentValue, async_trait};
8use rmcp::{
9    model::{CallToolRequestParam, CallToolResult},
10    service::ServiceExt,
11    transport::{ConfigureCommandExt, TokioChildProcess},
12};
13use serde::Deserialize;
14use tokio::process::Command;
15use tokio::sync::Mutex as AsyncMutex;
16
17use crate::tool::{Tool, ToolInfo, register_tool};
18
19/// MCP Tool with connection pool support
20struct MCPTool {
21    server_name: String,
22    server_config: MCPServerConfig,
23    tool: rmcp::model::Tool,
24    info: ToolInfo,
25}
26
27impl MCPTool {
28    fn new(
29        name: String,
30        server_name: String,
31        server_config: MCPServerConfig,
32        tool: rmcp::model::Tool,
33    ) -> Self {
34        let info = ToolInfo {
35            name,
36            description: tool.description.clone().unwrap_or_default().into_owned(),
37            parameters: serde_json::to_value(&tool.input_schema).ok(),
38        };
39        Self {
40            server_name,
41            server_config,
42            tool,
43            info,
44        }
45    }
46
47    async fn tool_call(
48        &self,
49        _ctx: AgentContext,
50        value: AgentValue,
51    ) -> Result<AgentValue, AgentError> {
52        // Get or create connection from pool
53        let conn = {
54            let mut pool = connection_pool().lock().await;
55            pool.get_or_create(&self.server_name, &self.server_config)
56                .await?
57        };
58
59        let arguments = value.as_object().map(|obj| {
60            obj.iter()
61                .map(|(k, v)| {
62                    (
63                        k.clone(),
64                        serde_json::to_value(v).unwrap_or(serde_json::Value::Null),
65                    )
66                })
67                .collect::<serde_json::Map<String, serde_json::Value>>()
68        });
69
70        let tool_result = {
71            let connection = conn.lock().await;
72            let service = connection.service.as_ref().ok_or_else(|| {
73                AgentError::Other(format!(
74                    "MCP service for '{}' is not available",
75                    self.server_name
76                ))
77            })?;
78            service
79                .call_tool(CallToolRequestParam {
80                    name: self.tool.name.clone().into(),
81                    arguments,
82                    task: None,
83                })
84                .await
85                .map_err(|e| {
86                    AgentError::Other(format!("Failed to call tool '{}': {e}", self.tool.name))
87                })?
88        };
89
90        Ok(call_tool_result_to_agent_value(tool_result)?)
91    }
92}
93
94#[async_trait]
95impl Tool for MCPTool {
96    fn info(&self) -> &ToolInfo {
97        &self.info
98    }
99
100    async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError> {
101        self.tool_call(ctx, args).await
102    }
103}
104
105/// Structure representing the Claude Desktop MCP configuration format
106#[derive(Debug, Deserialize)]
107pub struct MCPConfig {
108    #[serde(rename = "mcpServers")]
109    pub mcp_servers: HashMap<String, MCPServerConfig>,
110}
111
112#[derive(Debug, Clone, Deserialize)]
113pub struct MCPServerConfig {
114    pub command: String,
115    pub args: Vec<String>,
116    #[serde(default)]
117    pub env: Option<HashMap<String, String>>,
118}
119
120type MCPService = rmcp::service::RunningService<rmcp::service::RoleClient, ()>;
121
122/// Connection pool entry for an MCP server
123struct MCPConnection {
124    service: Option<MCPService>,
125}
126
127/// Connection pool for MCP servers
128struct MCPConnectionPool {
129    connections: HashMap<String, Arc<AsyncMutex<MCPConnection>>>,
130}
131
132impl MCPConnectionPool {
133    fn new() -> Self {
134        Self {
135            connections: HashMap::new(),
136        }
137    }
138
139    async fn get_or_create(
140        &mut self,
141        server_name: &str,
142        config: &MCPServerConfig,
143    ) -> Result<Arc<AsyncMutex<MCPConnection>>, AgentError> {
144        // Check if connection already exists
145        if let Some(conn) = self.connections.get(server_name) {
146            log::debug!("Reusing existing MCP connection for '{}'", server_name);
147            return Ok(conn.clone());
148        }
149
150        log::info!(
151            "Starting MCP server '{}' (command: {})",
152            server_name,
153            config.command
154        );
155
156        // Start new MCP service
157        let service = ()
158            .serve(
159                TokioChildProcess::new(Command::new(&config.command).configure(|cmd| {
160                    for arg in &config.args {
161                        cmd.arg(arg);
162                    }
163                    if let Some(env) = &config.env {
164                        for (key, value) in env {
165                            cmd.env(key, value);
166                        }
167                    }
168                }))
169                .map_err(|e| {
170                    log::error!("Failed to start MCP process for '{}': {}", server_name, e);
171                    AgentError::Other(format!(
172                        "Failed to start MCP process for '{}': {e}",
173                        server_name
174                    ))
175                })?,
176            )
177            .await
178            .map_err(|e| {
179                log::error!("Failed to start MCP service for '{}': {}", server_name, e);
180                AgentError::Other(format!(
181                    "Failed to start MCP service for '{}': {e}",
182                    server_name
183                ))
184            })?;
185
186        log::info!("Successfully started MCP server '{}'", server_name);
187
188        let connection = MCPConnection {
189            service: Some(service),
190        };
191
192        let conn_arc = Arc::new(AsyncMutex::new(connection));
193        self.connections
194            .insert(server_name.to_string(), conn_arc.clone());
195        Ok(conn_arc)
196    }
197
198    async fn shutdown_all(&mut self) -> Result<(), AgentError> {
199        let count = self.connections.len();
200        log::debug!("Shutting down {} MCP server connection(s)", count);
201
202        for (name, conn) in self.connections.drain() {
203            log::debug!("Shutting down MCP server '{}'", name);
204            let mut connection = conn.lock().await;
205            if let Some(service) = connection.service.take() {
206                service.cancel().await.map_err(|e| {
207                    log::error!("Failed to cancel MCP service '{}': {}", name, e);
208                    AgentError::Other(format!("Failed to cancel MCP service: {e}"))
209                })?;
210                log::debug!("Successfully shut down MCP server '{}'", name);
211            }
212        }
213        Ok(())
214    }
215}
216
217// Global connection pool
218static CONNECTION_POOL: OnceLock<AsyncMutex<MCPConnectionPool>> = OnceLock::new();
219
220fn connection_pool() -> &'static AsyncMutex<MCPConnectionPool> {
221    CONNECTION_POOL.get_or_init(|| AsyncMutex::new(MCPConnectionPool::new()))
222}
223
224/// Shuts down all MCP server connections in the pool
225pub async fn shutdown_all_mcp_connections() -> Result<(), AgentError> {
226    log::info!("Shutting down all MCP server connections");
227    connection_pool().lock().await.shutdown_all().await?;
228    log::info!("All MCP server connections shut down successfully");
229    Ok(())
230}
231
232/// Registers tools from a single MCP server
233///
234/// # Arguments
235/// * `server_name` - Name of the MCP server
236/// * `server_config` - Configuration for the MCP server
237///
238/// # Returns
239/// A vector of registered tool names in the format "server_name::tool_name"
240async fn register_tools_from_server(
241    server_name: String,
242    server_config: MCPServerConfig,
243) -> Result<Vec<String>, AgentError> {
244    log::debug!("Registering tools from MCP server '{}'", server_name);
245
246    // Get or create connection from pool
247    let conn = {
248        let mut pool = connection_pool().lock().await;
249        pool.get_or_create(&server_name, &server_config).await?
250    };
251
252    // List all available tools from this server
253    log::debug!("Listing tools from MCP server '{}'", server_name);
254    let tools_list = {
255        let connection = conn.lock().await;
256        let service = connection.service.as_ref().ok_or_else(|| {
257            log::error!("MCP service for '{}' is not available", server_name);
258            AgentError::Other(format!(
259                "MCP service for '{}' is not available",
260                server_name
261            ))
262        })?;
263        service.list_tools(Default::default()).await.map_err(|e| {
264            log::error!("Failed to list MCP tools for '{}': {}", server_name, e);
265            AgentError::Other(format!(
266                "Failed to list MCP tools for '{}': {e}",
267                server_name
268            ))
269        })?
270    };
271
272    let mut registered_tool_names = Vec::new();
273
274    // Register all tools from this server using connection pool
275    for tool_info in tools_list.tools {
276        let mcp_tool_name = format!("{}::{}", server_name, tool_info.name);
277        registered_tool_names.push(mcp_tool_name.clone());
278
279        register_tool(MCPTool::new(
280            mcp_tool_name.clone(),
281            server_name.clone(),
282            server_config.clone(),
283            tool_info,
284        ));
285        log::debug!("Registered MCP tool '{}'", mcp_tool_name);
286    }
287
288    log::info!(
289        "Registered {} tools from MCP server '{}'",
290        registered_tool_names.len(),
291        server_name
292    );
293
294    Ok(registered_tool_names)
295}
296
297/// Loads MCP configuration from a JSON file and registers all tools
298///
299/// # Arguments
300/// * `json_path` - Path to the mcp.json file
301///
302/// # Returns
303/// A vector of registered tool names in the format "server_name::tool_name"
304///
305/// # Example
306/// ```no_run
307/// use agent_stream_kit::mcp::register_tools_from_mcp_json;
308///
309/// #[tokio::main]
310/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
311///     let tool_names = register_tools_from_mcp_json("mcp.json").await?;
312///     println!("Registered {} tools", tool_names.len());
313///     Ok(())
314/// }
315/// ```
316pub async fn register_tools_from_mcp_json<P: AsRef<Path>>(
317    json_path: P,
318) -> Result<Vec<String>, AgentError> {
319    let path = json_path.as_ref();
320    log::info!("Loading MCP configuration from: {}", path.display());
321
322    // Read the JSON file
323    let json_content = std::fs::read_to_string(path).map_err(|e| {
324        log::error!("Failed to read MCP config file '{}': {}", path.display(), e);
325        AgentError::Other(format!("Failed to read MCP config file: {e}"))
326    })?;
327
328    // Parse the JSON
329    let config: MCPConfig = serde_json::from_str(&json_content).map_err(|e| {
330        log::error!("Failed to parse MCP config JSON: {}", e);
331        AgentError::Other(format!("Failed to parse MCP config JSON: {e}"))
332    })?;
333
334    log::info!("Found {} MCP servers in config", config.mcp_servers.len());
335
336    let mut registered_tool_names = Vec::new();
337
338    // Iterate through each MCP server
339    for (server_name, server_config) in config.mcp_servers {
340        let tools = register_tools_from_server(server_name, server_config).await?;
341        registered_tool_names.extend(tools);
342    }
343
344    log::info!(
345        "Successfully registered {} MCP tools total",
346        registered_tool_names.len()
347    );
348
349    Ok(registered_tool_names)
350}
351
352fn call_tool_result_to_agent_value(result: CallToolResult) -> Result<AgentValue, AgentError> {
353    let mut contents = Vec::new();
354    for c in result.content.iter() {
355        match &c.raw {
356            rmcp::model::RawContent::Text(text) => {
357                contents.push(AgentValue::string(text.text.clone()));
358            }
359            _ => {
360                // Handle other content types as needed
361            }
362        }
363    }
364    let data = AgentValue::array(contents.into());
365    if result.is_error == Some(true) {
366        return Err(AgentError::Other(
367            serde_json::to_string(&data).map_err(|e| AgentError::InvalidValue(e.to_string()))?,
368        ));
369    }
370    Ok(data)
371}