1use anyhow::{anyhow, Result};
8use rmcp::{
9 model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation, Tool},
10 service::RoleClient,
11 service::RunningService,
12 transport::{ConfigureCommandExt, SseClientTransport, TokioChildProcess},
13 ServiceExt,
14};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::path::PathBuf;
18use tokio::process::Command;
19
/// Transport kind recorded for a server entry in the persisted MCP config (`mcp.toml`).
///
/// NOTE(review): only `Stdio` and `Sse` have a matching [`SdkMcpTransport`] variant in this
/// file; how `Streamable` entries get connected is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum McpServerType {
    /// Presumably a local process speaking MCP over stdin/stdout (name-based; confirm).
    Stdio,
    /// Presumably a remote server reached over Server-Sent Events (name-based; confirm).
    Sse,
    /// Presumably the streamable-HTTP MCP transport (name-based; confirm).
    Streamable,
}
27
/// One server entry of the persisted MCP configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Server name; `McpConfig::add_server_with_env` sets it equal to the entry's map key.
    pub name: String,
    /// Transport kind used to reach the server.
    pub server_type: McpServerType,
    /// Presumably the launch command for `Stdio` servers and the endpoint URL for the
    /// network transports (inferred from the name — confirm against callers).
    pub command_or_url: String,
    /// Extra environment variables for the server; defaults to empty when the key is
    /// absent from the serialized config.
    #[serde(default)]
    pub env: HashMap<String, String>,
}
36
/// Persisted MCP configuration, stored as `mcp.toml` in the application config directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
    /// Configured servers keyed by server name.
    pub servers: HashMap<String, McpServerConfig>,
}
41
42impl McpConfig {
43 pub fn new() -> Self {
44 Self {
45 servers: HashMap::new(),
46 }
47 }
48
49 pub async fn load() -> Result<Self> {
50 let config_dir = crate::config::Config::config_dir()?;
51 let mcp_config_path = config_dir.join("mcp.toml");
52
53 if !mcp_config_path.exists() {
54 return Ok(Self::new());
55 }
56
57 let content = tokio::fs::read_to_string(&mcp_config_path).await?;
58 let config: McpConfig = toml::from_str(&content)?;
59 Ok(config)
60 }
61
62 pub async fn save(&self) -> Result<()> {
63 let config_dir = crate::config::Config::config_dir()?;
64 tokio::fs::create_dir_all(&config_dir).await?;
65
66 let mcp_config_path = config_dir.join("mcp.toml");
67 let content = toml::to_string_pretty(self)?;
68 tokio::fs::write(&mcp_config_path, content).await?;
69 Ok(())
70 }
71
72 #[allow(dead_code)]
73 pub fn add_server(
74 &mut self,
75 name: String,
76 command_or_url: String,
77 server_type: McpServerType,
78 ) -> Result<()> {
79 self.add_server_with_env(name, command_or_url, server_type, HashMap::new())
80 }
81
82 pub fn add_server_with_env(
83 &mut self,
84 name: String,
85 command_or_url: String,
86 server_type: McpServerType,
87 env: HashMap<String, String>,
88 ) -> Result<()> {
89 let server_config = McpServerConfig {
90 name: name.clone(),
91 server_type,
92 command_or_url,
93 env,
94 };
95 self.servers.insert(name, server_config);
96 Ok(())
97 }
98
99 pub fn delete_server(&mut self, name: &str) -> Result<()> {
100 if self.servers.remove(name).is_none() {
101 return Err(anyhow!("MCP server '{}' not found", name));
102 }
103 Ok(())
104 }
105
106 pub fn get_server(&self, name: &str) -> Option<&McpServerConfig> {
107 self.servers.get(name)
108 }
109
110 pub fn list_servers(&self) -> HashMap<String, &McpServerConfig> {
111 self.servers.iter().map(|(k, v)| (k.clone(), v)).collect()
112 }
113}
114
/// Owns the live MCP client sessions created through the rmcp SDK.
pub struct SdkMcpManager {
    /// Running client sessions, keyed by the `name` of the [`SdkMcpServerConfig`] passed to
    /// `add_server`.
    pub clients: HashMap<String, RunningService<RoleClient, ClientInfo>>,
}
121
122use std::sync::Arc;
124use tokio::sync::Mutex;
125
lazy_static::lazy_static! {
    /// Process-wide MCP manager shared by the `*_global_*` helper functions in this module.
    ///
    /// Guarded by a `tokio::sync::Mutex`, so all access is serialized and the guard may be
    /// held across `.await` points.
    static ref GLOBAL_MCP_MANAGER: Arc<Mutex<SdkMcpManager>> = Arc::new(Mutex::new(SdkMcpManager::new()));
}
129
130#[allow(dead_code)]
132pub async fn get_global_manager() -> Arc<Mutex<SdkMcpManager>> {
133 GLOBAL_MCP_MANAGER.clone()
134}
135
136#[allow(dead_code)]
137pub async fn ensure_server_connected(server_name: &str, config: SdkMcpServerConfig) -> Result<()> {
138 let manager = get_global_manager().await;
139 let mut manager_lock = manager.lock().await;
140
141 if !manager_lock.clients.contains_key(server_name) {
143 crate::debug_log!(
144 "GLOBAL_MANAGER: Connecting to MCP server '{}' (not already connected)",
145 server_name
146 );
147 manager_lock.add_server(config).await?;
148 crate::debug_log!(
149 "GLOBAL_MANAGER: Successfully connected to MCP server '{}'. Total connections: {}",
150 server_name,
151 manager_lock.clients.len()
152 );
153 } else {
154 crate::debug_log!(
155 "GLOBAL_MANAGER: MCP server '{}' already connected. Total connections: {}",
156 server_name,
157 manager_lock.clients.len()
158 );
159 }
160
161 Ok(())
162}
163
164#[allow(dead_code)]
165pub async fn call_global_tool(
166 server_name: &str,
167 tool_name: &str,
168 arguments: serde_json::Value,
169) -> Result<serde_json::Value> {
170 let manager = get_global_manager().await;
171 let manager_lock = manager.lock().await;
172
173 crate::debug_log!(
174 "GLOBAL_MANAGER: Calling tool '{}' on server '{}'. Total connections: {}",
175 tool_name,
176 server_name,
177 manager_lock.clients.len()
178 );
179
180 if !manager_lock.clients.contains_key(server_name) {
181 crate::debug_log!(
182 "GLOBAL_MANAGER: ERROR - Server '{}' not found in global manager!",
183 server_name
184 );
185 return Err(anyhow::anyhow!(
186 "Server '{}' not found in global manager",
187 server_name
188 ));
189 }
190
191 let result = manager_lock
192 .call_tool(server_name, tool_name, arguments)
193 .await;
194
195 crate::debug_log!(
196 "GLOBAL_MANAGER: Tool call completed. Connection still active: {}",
197 manager_lock.clients.contains_key(server_name)
198 );
199
200 result
201}
202
203#[allow(dead_code)]
204pub async fn list_global_tools() -> Result<HashMap<String, Vec<Tool>>> {
205 let manager = get_global_manager().await;
206 let manager_lock = manager.lock().await;
207 manager_lock.list_all_tools().await
208}
209
210#[allow(dead_code)]
211pub async fn close_global_server(server_name: &str) -> Result<()> {
212 let manager = get_global_manager().await;
213 let mut manager_lock = manager.lock().await;
214
215 if let Some(client) = manager_lock.clients.remove(server_name) {
216 let _ = client.cancel().await;
217 crate::debug_log!("Closed connection to MCP server '{}'", server_name);
218 }
219
220 Ok(())
221}
222
223impl SdkMcpManager {
224 pub fn new() -> Self {
225 Self {
226 clients: HashMap::new(),
227 }
228 }
229
230 pub async fn add_server(&mut self, config: SdkMcpServerConfig) -> Result<()> {
231 crate::debug_log!(
232 "SdkMcpManager: Adding server '{}' with transport: {:?}",
233 config.name,
234 config.transport
235 );
236
237 let client_info = ClientInfo {
238 protocol_version: Default::default(),
239 capabilities: ClientCapabilities::default(),
240 client_info: Implementation {
241 name: "lc-mcp-client".to_string(),
242 version: "0.1.0".to_string(),
243 },
244 };
245
246 let client = match config.transport {
247 SdkMcpTransport::Stdio {
248 command,
249 args,
250 env,
251 cwd,
252 } => {
253 crate::debug_log!(
254 "SdkMcpManager: Creating STDIO transport with command: {} args: {:?}",
255 command,
256 args
257 );
258
259 let mut cmd = Command::new(&command);
260 if let Some(args) = args {
261 cmd.args(&args);
262 crate::debug_log!("SdkMcpManager: Added args: {:?}", args);
263 }
264 if let Some(env) = env {
265 let env_count = env.len();
266 for (key, value) in env {
267 crate::debug_log!("SdkMcpManager: Setting env var {}={}", key, value);
268 cmd.env(key, value);
269 }
270 crate::debug_log!("SdkMcpManager: Added {} env vars", env_count);
271 } else {
272 crate::debug_log!("SdkMcpManager: No env vars to add");
273 }
274 if let Some(cwd) = cwd {
275 cmd.current_dir(cwd);
276 crate::debug_log!("SdkMcpManager: Set working directory");
277 }
278
279 cmd.stdin(std::process::Stdio::piped());
281 cmd.stdout(std::process::Stdio::piped());
282 cmd.stderr(std::process::Stdio::piped());
283
284 crate::debug_log!("SdkMcpManager: Creating TokioChildProcess transport");
285 let transport = TokioChildProcess::new(cmd.configure(|_| {}))?;
286 crate::debug_log!("SdkMcpManager: Starting client connection");
287 client_info.serve(transport).await?
288 }
289 SdkMcpTransport::Sse { url } => {
290 crate::debug_log!("SdkMcpManager: Creating SSE transport with URL: {}", url);
291 let transport = SseClientTransport::start(url.as_str()).await?;
292 crate::debug_log!("SdkMcpManager: Starting client connection");
293 client_info.serve(transport).await?
294 }
295 };
296
297 crate::debug_log!(
298 "SdkMcpManager: Successfully connected to server '{}'",
299 config.name
300 );
301 self.clients.insert(config.name, client);
302 Ok(())
303 }
304
305 pub async fn list_all_tools(&self) -> Result<HashMap<String, Vec<Tool>>> {
306 let mut all_tools = HashMap::new();
307
308 crate::debug_log!(
309 "SdkMcpManager: Listing tools from {} connected servers",
310 self.clients.len()
311 );
312
313 for (server_name, client) in &self.clients {
314 crate::debug_log!(
315 "SdkMcpManager: Requesting tools from server '{}'",
316 server_name
317 );
318 match client.list_tools(Default::default()).await {
319 Ok(tools_result) => {
320 crate::debug_log!(
321 "SdkMcpManager: Server '{}' returned {} tools",
322 server_name,
323 tools_result.tools.len()
324 );
325 all_tools.insert(server_name.clone(), tools_result.tools);
326 }
327 Err(e) => {
328 crate::debug_log!(
329 "SdkMcpManager: Failed to list tools from server '{}': {}",
330 server_name,
331 e
332 );
333 eprintln!(
334 "Warning: Failed to list tools from server '{}': {}",
335 server_name, e
336 );
337 }
338 }
339 }
340
341 crate::debug_log!(
342 "SdkMcpManager: Total tools collected from {} servers",
343 all_tools.len()
344 );
345 Ok(all_tools)
346 }
347
348 pub async fn call_tool(
349 &self,
350 server_name: &str,
351 tool_name: &str,
352 arguments: serde_json::Value,
353 ) -> Result<serde_json::Value> {
354 let client = self
355 .clients
356 .get(server_name)
357 .ok_or_else(|| anyhow!("Server '{}' not found", server_name))?;
358
359 let result = client
360 .call_tool(CallToolRequestParam {
361 name: tool_name.to_string().into(),
362 arguments: arguments.as_object().cloned(),
363 })
364 .await?;
365
366 Ok(serde_json::to_value(result)?)
368 }
369}
370
/// Everything `SdkMcpManager::add_server` needs to open one MCP client connection.
#[derive(Debug, Clone)]
pub struct SdkMcpServerConfig {
    /// Key under which the running client is stored in `SdkMcpManager::clients`.
    pub name: String,
    /// How to reach the server.
    pub transport: SdkMcpTransport,
}
377
/// Connection method for an MCP server, as consumed by `SdkMcpManager::add_server`.
///
/// Note: the derived `Debug` output of `Stdio` includes `env` values, which may be secrets.
#[derive(Debug, Clone)]
pub enum SdkMcpTransport {
    /// Spawn `command` as a child process and talk MCP over its stdin/stdout.
    Stdio {
        /// Executable to run.
        command: String,
        /// Arguments passed to `command`, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables set on the child process.
        env: Option<HashMap<String, String>>,
        /// Working directory for the child process; inherited when `None`.
        cwd: Option<PathBuf>,
    },
    /// Connect over Server-Sent Events to `url` via `SseClientTransport`.
    Sse {
        /// SSE endpoint URL.
        url: String,
    },
}
390
391#[allow(dead_code)]
393pub fn create_stdio_server_config(
394 name: String,
395 command_parts: Vec<String>,
396 env: Option<HashMap<String, String>>,
397 cwd: Option<PathBuf>,
398) -> SdkMcpServerConfig {
399 let (command, args) = if command_parts.is_empty() {
400 ("echo".to_string(), None)
401 } else if command_parts.len() == 1 {
402 (command_parts[0].clone(), None)
403 } else {
404 (command_parts[0].clone(), Some(command_parts[1..].to_vec()))
405 };
406
407 SdkMcpServerConfig {
408 name,
409 transport: SdkMcpTransport::Stdio {
410 command,
411 args,
412 env,
413 cwd,
414 },
415 }
416}
417
418#[allow(dead_code)]
419pub fn create_sse_server_config(name: String, url: String) -> SdkMcpServerConfig {
420 SdkMcpServerConfig {
421 name,
422 transport: SdkMcpTransport::Sse { url },
423 }
424}
425
#[cfg(test)]
mod tests {
    //! Unit tests for the config model and the pure helper constructors. Nothing here spawns
    //! processes or opens network connections.
    use super::*;

    #[test]
    fn test_mcp_config_creation() {
        let config = McpConfig::new();
        assert!(config.servers.is_empty());
    }

    #[test]
    fn test_add_server() {
        let mut config = McpConfig::new();
        config
            .add_server(
                "test-server".to_string(),
                "echo test".to_string(),
                McpServerType::Stdio,
            )
            .unwrap();

        assert_eq!(config.servers.len(), 1);
        let server = config.get_server("test-server").unwrap();
        assert_eq!(server.name, "test-server");
        assert_eq!(server.command_or_url, "echo test");
        assert_eq!(server.server_type, McpServerType::Stdio);
        assert!(server.env.is_empty());
    }

    #[test]
    fn test_add_server_with_env_keeps_env() {
        let mut config = McpConfig::new();
        let mut env = HashMap::new();
        env.insert("API_KEY".to_string(), "secret".to_string());
        config
            .add_server_with_env(
                "remote".to_string(),
                "http://localhost:8080/sse".to_string(),
                McpServerType::Sse,
                env,
            )
            .unwrap();

        let server = config.get_server("remote").unwrap();
        assert_eq!(server.server_type, McpServerType::Sse);
        assert_eq!(server.env.get("API_KEY").map(String::as_str), Some("secret"));
    }

    #[test]
    fn test_add_server_replaces_existing_entry() {
        let mut config = McpConfig::new();
        config
            .add_server("dup".to_string(), "first".to_string(), McpServerType::Stdio)
            .unwrap();
        config
            .add_server("dup".to_string(), "second".to_string(), McpServerType::Stdio)
            .unwrap();

        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.get_server("dup").unwrap().command_or_url, "second");
    }

    #[test]
    fn test_delete_server() {
        let mut config = McpConfig::new();
        config
            .add_server("gone".to_string(), "echo".to_string(), McpServerType::Stdio)
            .unwrap();

        assert!(config.delete_server("gone").is_ok());
        assert!(config.get_server("gone").is_none());
        // Deleting again reports the missing server instead of silently succeeding.
        assert!(config.delete_server("gone").is_err());
    }

    #[test]
    fn test_list_servers() {
        let mut config = McpConfig::new();
        config
            .add_server("a".to_string(), "cmd-a".to_string(), McpServerType::Stdio)
            .unwrap();
        config
            .add_server("b".to_string(), "cmd-b".to_string(), McpServerType::Stdio)
            .unwrap();

        let listed = config.list_servers();
        assert_eq!(listed.len(), 2);
        assert_eq!(listed["a"].command_or_url, "cmd-a");
        assert_eq!(listed["b"].command_or_url, "cmd-b");
    }

    #[test]
    fn test_sdk_manager_creation() {
        let manager = SdkMcpManager::new();
        assert!(manager.clients.is_empty());
    }

    #[test]
    fn test_create_stdio_config() {
        let config = create_stdio_server_config(
            "test".to_string(),
            vec!["echo".to_string(), "hello".to_string()],
            None,
            None,
        );
        assert_eq!(config.name, "test");
        match config.transport {
            SdkMcpTransport::Stdio { command, args, .. } => {
                assert_eq!(command, "echo");
                assert_eq!(args, Some(vec!["hello".to_string()]));
            }
            _ => panic!("Expected Stdio transport"),
        }
    }

    #[test]
    fn test_create_stdio_config_single_part_has_no_args() {
        let config =
            create_stdio_server_config("solo".to_string(), vec!["server".to_string()], None, None);
        match config.transport {
            SdkMcpTransport::Stdio { command, args, .. } => {
                assert_eq!(command, "server");
                assert_eq!(args, None);
            }
            _ => panic!("Expected Stdio transport"),
        }
    }

    #[test]
    fn test_create_sse_config() {
        let config =
            create_sse_server_config("test".to_string(), "http://localhost:8080/sse".to_string());
        assert_eq!(config.name, "test");
        match config.transport {
            SdkMcpTransport::Sse { url } => {
                assert_eq!(url, "http://localhost:8080/sse");
            }
            _ => panic!("Expected SSE transport"),
        }
    }
}