1use super::transport::{McpMessage, ProcessTransport, Transport};
10use super::types::*;
11use anyhow::Result;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicI64, Ordering};
16use std::time::Duration;
17use tokio::sync::{RwLock, oneshot};
18use tokio::time::timeout;
19use tracing::{debug, error, info, warn};
20
21pub struct McpClient {
23 transport: Arc<dyn Transport>,
24 pending_requests: RwLock<HashMap<RequestId, oneshot::Sender<JsonRpcResponse>>>,
25 request_id: AtomicI64,
26 server_info: RwLock<Option<ServerInfo>>,
27 server_capabilities: RwLock<Option<ServerCapabilities>>,
28 available_tools: RwLock<Vec<McpTool>>,
29 registry: Arc<McpRegistry>,
31 server_name: RwLock<Option<String>>,
33}
34
35impl McpClient {
36 pub async fn connect_subprocess(command: &str, args: &[&str]) -> Result<Arc<Self>> {
38 let transport = Arc::new(ProcessTransport::spawn(command, args).await?);
39 let client = Arc::new(Self::new(transport));
40
41 let client_clone = Arc::clone(&client);
43 tokio::spawn(async move {
44 client_clone.receive_loop().await;
45 });
46
47 client.initialize().await?;
49
50 Ok(client)
51 }
52
53 pub fn new(transport: Arc<dyn Transport>) -> Self {
55 Self {
56 transport,
57 pending_requests: RwLock::new(HashMap::new()),
58 request_id: AtomicI64::new(1),
59 server_info: RwLock::new(None),
60 server_capabilities: RwLock::new(None),
61 available_tools: RwLock::new(Vec::new()),
62 registry: Arc::new(McpRegistry::new()),
63 server_name: RwLock::new(None),
64 }
65 }
66
67 pub fn with_registry(
69 transport: Arc<dyn Transport>,
70 registry: Arc<McpRegistry>,
71 name: Option<String>,
72 ) -> Self {
73 Self {
74 transport,
75 pending_requests: RwLock::new(HashMap::new()),
76 request_id: AtomicI64::new(1),
77 server_info: RwLock::new(None),
78 server_capabilities: RwLock::new(None),
79 available_tools: RwLock::new(Vec::new()),
80 registry,
81 server_name: RwLock::new(name),
82 }
83 }
84
85 pub async fn initialize(&self) -> Result<InitializeResult> {
87 let params = InitializeParams {
88 protocol_version: PROTOCOL_VERSION.to_string(),
89 capabilities: ClientCapabilities {
90 roots: Some(RootsCapability { list_changed: true }),
91 sampling: Some(SamplingCapability {}),
92 experimental: None,
93 },
94 client_info: ClientInfo {
95 name: "codetether".to_string(),
96 version: env!("CARGO_PKG_VERSION").to_string(),
97 },
98 };
99
100 let response = self
101 .request("initialize", Some(serde_json::to_value(¶ms)?))
102 .await?;
103 let result: InitializeResult = serde_json::from_value(response)?;
104
105 *self.server_info.write().await = Some(result.server_info.clone());
107 *self.server_capabilities.write().await = Some(result.capabilities.clone());
108
109 if let Some(name) = self.server_name.read().await.clone() {
111 debug!("Client initialized with server name: {}", name);
115 }
116
117 self.notify("notifications/initialized", None).await?;
119
120 info!(
121 "Connected to MCP server: {} v{}",
122 result.server_info.name, result.server_info.version
123 );
124
125 if result.capabilities.tools.is_some() {
127 self.refresh_tools().await?;
128 }
129
130 Ok(result)
131 }
132
133 pub fn registry(&self) -> Arc<McpRegistry> {
135 Arc::clone(&self.registry)
136 }
137
138 pub async fn server_name(&self) -> Option<String> {
140 self.server_name.read().await.clone()
141 }
142
143 pub async fn set_server_name(&self, name: String) {
145 *self.server_name.write().await = Some(name);
146 }
147
148 pub async fn has_capability(&self, capability: &str) -> bool {
150 let caps = self.server_capabilities.read().await;
151 match capability {
152 "tools" => caps.as_ref().map(|c| c.tools.is_some()).unwrap_or(false),
153 "resources" => caps
154 .as_ref()
155 .map(|c| c.resources.is_some())
156 .unwrap_or(false),
157 "prompts" => caps.as_ref().map(|c| c.prompts.is_some()).unwrap_or(false),
158 "logging" => caps.as_ref().map(|c| c.logging.is_some()).unwrap_or(false),
159 _ => false,
160 }
161 }
162
163 pub async fn capabilities(&self) -> Option<ServerCapabilities> {
165 self.server_capabilities.read().await.clone()
166 }
167
168 pub async fn discover_tools_from_registry(&self) -> Vec<(String, McpTool)> {
170 self.registry.all_tools().await
171 }
172
173 pub async fn find_tool_in_registry(&self, tool_name: &str) -> Option<(String, McpTool)> {
175 self.registry.find_tool(tool_name).await
176 }
177
178 pub async fn refresh_tools(&self) -> Result<Vec<McpTool>> {
180 let response = self.request("tools/list", None).await?;
181 let result: ListToolsResult = serde_json::from_value(response)?;
182
183 *self.available_tools.write().await = result.tools.clone();
184
185 info!("Loaded {} tools from MCP server", result.tools.len());
186
187 Ok(result.tools)
188 }
189
190 pub async fn tools(&self) -> Vec<McpTool> {
192 self.available_tools.read().await.clone()
193 }
194
195 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult> {
197 let params = CallToolParams {
198 name: name.to_string(),
199 arguments,
200 };
201
202 let response = self
203 .request("tools/call", Some(serde_json::to_value(¶ms)?))
204 .await?;
205 let result: CallToolResult = serde_json::from_value(response)?;
206
207 Ok(result)
208 }
209
210 pub async fn list_resources(&self) -> Result<Vec<McpResource>> {
212 let response = self.request("resources/list", None).await?;
213 let result: ListResourcesResult = serde_json::from_value(response)?;
214 Ok(result.resources)
215 }
216
217 pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
219 let params = ReadResourceParams {
220 uri: uri.to_string(),
221 };
222 let response = self
223 .request("resources/read", Some(serde_json::to_value(¶ms)?))
224 .await?;
225 let result: ReadResourceResult = serde_json::from_value(response)?;
226 Ok(result)
227 }
228
229 pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>> {
231 let response = self.request("prompts/list", None).await?;
232 let result: ListPromptsResult = serde_json::from_value(response)?;
233 Ok(result.prompts)
234 }
235
236 pub async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
238 let params = GetPromptParams {
239 name: name.to_string(),
240 arguments,
241 };
242 let response = self
243 .request("prompts/get", Some(serde_json::to_value(¶ms)?))
244 .await?;
245 let result: GetPromptResult = serde_json::from_value(response)?;
246 Ok(result)
247 }
248
249 async fn request(&self, method: &str, params: Option<Value>) -> Result<Value> {
251 let id = RequestId::Number(self.request_id.fetch_add(1, Ordering::SeqCst));
252 let request = JsonRpcRequest::new(id.clone(), method, params);
253
254 let (tx, rx) = oneshot::channel();
256 self.pending_requests.write().await.insert(id.clone(), tx);
257
258 self.transport.send_request(request).await?;
260
261 let response = timeout(Duration::from_secs(30), rx)
263 .await
264 .map_err(|_| anyhow::anyhow!("Request timed out"))??;
265
266 if let Some(error) = response.error {
267 return Err(anyhow::anyhow!(
268 "MCP error {}: {}",
269 error.code,
270 error.message
271 ));
272 }
273
274 response
275 .result
276 .ok_or_else(|| anyhow::anyhow!("Empty response"))
277 }
278
279 async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
281 let notification = JsonRpcNotification::new(method, params);
282 self.transport.send_notification(notification).await
283 }
284
285 async fn receive_loop(&self) {
287 loop {
288 match self.transport.receive().await {
289 Ok(Some(message)) => {
290 self.handle_message(message).await;
291 }
292 Ok(None) => {
293 info!("MCP connection closed");
294 break;
295 }
296 Err(e) => {
297 error!("Error receiving MCP message: {}", e);
298 break;
299 }
300 }
301 }
302 }
303
304 async fn handle_message(&self, message: McpMessage) {
306 match message {
307 McpMessage::Response(response) => {
308 if let Some(tx) = self.pending_requests.write().await.remove(&response.id) {
310 let _ = tx.send(response);
311 } else {
312 warn!("Received response for unknown request: {:?}", response.id);
313 }
314 }
315 McpMessage::Request(request) => {
316 debug!("Received request from server: {}", request.method);
318
319 let response = match request.method.as_str() {
320 "sampling/createMessage" => {
321 JsonRpcResponse::error(
324 request.id,
325 JsonRpcError::method_not_found("Sampling not yet implemented"),
326 )
327 }
328 _ => JsonRpcResponse::error(
329 request.id,
330 JsonRpcError::method_not_found(&request.method),
331 ),
332 };
333
334 if let Err(e) = self.transport.send_response(response).await {
335 error!("Failed to send response: {}", e);
336 }
337 }
338 McpMessage::Notification(notification) => {
339 debug!("Received notification: {}", notification.method);
340
341 match notification.method.as_str() {
342 "notifications/tools/list_changed" => {
343 info!("Tools list changed, refreshing...");
344 if let Err(e) = self.refresh_tools().await {
345 error!("Failed to refresh tools: {}", e);
346 }
347 }
348 "notifications/resources/list_changed" => {
349 info!("Resources list changed");
350 }
351 _ => {
352 debug!("Unknown notification: {}", notification.method);
353 }
354 }
355 }
356 }
357 }
358
359 pub async fn close(&self) -> Result<()> {
361 self.transport.close().await
362 }
363}
364
365pub struct McpRegistry {
371 clients: RwLock<HashMap<String, Arc<McpClient>>>,
372 server_capabilities: RwLock<HashMap<String, ServerCapabilities>>,
374 tool_index: RwLock<HashMap<String, String>>, }
377
378impl McpRegistry {
379 pub fn new() -> Self {
381 Self {
382 clients: RwLock::new(HashMap::new()),
383 server_capabilities: RwLock::new(HashMap::new()),
384 tool_index: RwLock::new(HashMap::new()),
385 }
386 }
387
388 pub async fn connect(
390 &self,
391 name: &str,
392 command: &str,
393 args: &[&str],
394 ) -> Result<Arc<McpClient>> {
395 let transport = Arc::new(ProcessTransport::spawn(command, args).await?);
396 let client = Arc::new(McpClient::with_registry(
397 transport,
398 Arc::new(McpRegistry::new()), Some(name.to_string()),
400 ));
401
402 let client_clone = Arc::clone(&client);
404 tokio::spawn(async move {
405 client_clone.receive_loop().await;
406 });
407
408 let init_result = client.initialize().await?;
410
411 self.register(name, Arc::clone(&client), init_result.capabilities)
413 .await;
414
415 Ok(client)
416 }
417
418 pub async fn register(
420 &self,
421 name: &str,
422 client: Arc<McpClient>,
423 capabilities: ServerCapabilities,
424 ) {
425 self.clients.write().await.insert(name.to_string(), client);
427
428 self.server_capabilities
430 .write()
431 .await
432 .insert(name.to_string(), capabilities);
433
434 info!("Registered MCP server '{}' with registry", name);
435 }
436
437 pub async fn get(&self, name: &str) -> Option<Arc<McpClient>> {
439 self.clients.read().await.get(name).cloned()
440 }
441
442 pub async fn list(&self) -> Vec<String> {
444 self.clients.read().await.keys().cloned().collect()
445 }
446
447 pub async fn get_capabilities(&self, name: &str) -> Option<ServerCapabilities> {
449 self.server_capabilities.read().await.get(name).cloned()
450 }
451
452 pub async fn has_capability(&self, name: &str, capability: &str) -> bool {
454 let caps = self.server_capabilities.read().await;
455 caps.get(name)
456 .map(|c| match capability {
457 "tools" => c.tools.is_some(),
458 "resources" => c.resources.is_some(),
459 "prompts" => c.prompts.is_some(),
460 "logging" => c.logging.is_some(),
461 _ => false,
462 })
463 .unwrap_or(false)
464 }
465
466 pub async fn list_by_capability(&self, capability: &str) -> Vec<String> {
468 let mut result = Vec::new();
469 let caps = self.server_capabilities.read().await;
470
471 for (name, caps) in caps.iter() {
472 let has_cap = match capability {
473 "tools" => caps.tools.is_some(),
474 "resources" => caps.resources.is_some(),
475 "prompts" => caps.prompts.is_some(),
476 "logging" => caps.logging.is_some(),
477 _ => false,
478 };
479 if has_cap {
480 result.push(name.clone());
481 }
482 }
483
484 result
485 }
486
487 pub async fn disconnect(&self, name: &str) -> Result<()> {
489 if let Some(client) = self.clients.write().await.remove(name) {
490 self.server_capabilities.write().await.remove(name);
492 let mut tool_index = self.tool_index.write().await;
494 tool_index.retain(|_, server| server != name);
495 client.close().await?;
497 }
498 Ok(())
499 }
500
501 pub async fn all_tools(&self) -> Vec<(String, McpTool)> {
503 let mut all_tools = Vec::new();
504
505 for (name, client) in self.clients.read().await.iter() {
506 for tool in client.tools().await {
507 all_tools.push((name.clone(), tool));
508 }
509 }
510
511 all_tools
512 }
513
514 pub async fn find_tool(&self, tool_name: &str) -> Option<(String, McpTool)> {
516 if let Some(server_name) = self.tool_index.read().await.get(tool_name) {
518 if let Some(client) = self.get(server_name).await {
519 if let Some(tool) = client.tools().await.iter().find(|t| t.name == tool_name) {
520 return Some((server_name.clone(), tool.clone()));
521 }
522 }
523 }
524
525 for (name, client) in self.clients.read().await.iter() {
527 if let Some(tool) = client.tools().await.iter().find(|t| t.name == tool_name) {
528 self.tool_index
530 .write()
531 .await
532 .insert(tool_name.to_string(), name.clone());
533 return Some((name.clone(), tool.clone()));
534 }
535 }
536
537 None
538 }
539
540 pub async fn refresh_tool_index(&self) {
542 let mut tool_index = self.tool_index.write().await;
543 tool_index.clear();
544
545 for (name, client) in self.clients.read().await.iter() {
546 for tool in client.tools().await {
547 tool_index.insert(tool.name.clone(), name.clone());
548 }
549 }
550
551 info!("Refreshed tool index with {} tools", tool_index.len());
552 }
553
554 pub async fn call_tool(
556 &self,
557 server: &str,
558 tool: &str,
559 arguments: Value,
560 ) -> Result<CallToolResult> {
561 let client = self
562 .get(server)
563 .await
564 .ok_or_else(|| anyhow::anyhow!("Server not found: {}", server))?;
565 client.call_tool(tool, arguments).await
566 }
567
568 pub async fn call_tool_auto(
570 &self,
571 tool_name: &str,
572 arguments: Value,
573 ) -> Result<CallToolResult> {
574 let (server, _) = self
575 .find_tool(tool_name)
576 .await
577 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", tool_name))?;
578 self.call_tool(&server, tool_name, arguments).await
579 }
580}
581
582impl Default for McpRegistry {
583 fn default() -> Self {
584 Self::new()
585 }
586}