1use super::transport::{McpMessage, ProcessTransport, Transport};
10use super::types::*;
11use anyhow::Result;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicI64, Ordering};
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{oneshot, RwLock};
18use tokio::time::timeout;
19use tracing::{debug, error, info, warn};
20
21pub struct McpClient {
23 transport: Arc<dyn Transport>,
24 pending_requests: RwLock<HashMap<RequestId, oneshot::Sender<JsonRpcResponse>>>,
25 request_id: AtomicI64,
26 server_info: RwLock<Option<ServerInfo>>,
27 server_capabilities: RwLock<Option<ServerCapabilities>>,
28 available_tools: RwLock<Vec<McpTool>>,
29 registry: Arc<McpRegistry>,
31 server_name: RwLock<Option<String>>,
33}
34
35impl McpClient {
36 pub async fn connect_subprocess(command: &str, args: &[&str]) -> Result<Arc<Self>> {
38 let transport = Arc::new(ProcessTransport::spawn(command, args).await?);
39 let client = Arc::new(Self::new(transport));
40
41 let client_clone = Arc::clone(&client);
43 tokio::spawn(async move {
44 client_clone.receive_loop().await;
45 });
46
47 client.initialize().await?;
49
50 Ok(client)
51 }
52
53 pub fn new(transport: Arc<dyn Transport>) -> Self {
55 Self {
56 transport,
57 pending_requests: RwLock::new(HashMap::new()),
58 request_id: AtomicI64::new(1),
59 server_info: RwLock::new(None),
60 server_capabilities: RwLock::new(None),
61 available_tools: RwLock::new(Vec::new()),
62 registry: Arc::new(McpRegistry::new()),
63 server_name: RwLock::new(None),
64 }
65 }
66
67 pub fn with_registry(transport: Arc<dyn Transport>, registry: Arc<McpRegistry>, name: Option<String>) -> Self {
69 Self {
70 transport,
71 pending_requests: RwLock::new(HashMap::new()),
72 request_id: AtomicI64::new(1),
73 server_info: RwLock::new(None),
74 server_capabilities: RwLock::new(None),
75 available_tools: RwLock::new(Vec::new()),
76 registry,
77 server_name: RwLock::new(name),
78 }
79 }
80
81 pub async fn initialize(&self) -> Result<InitializeResult> {
83 let params = InitializeParams {
84 protocol_version: PROTOCOL_VERSION.to_string(),
85 capabilities: ClientCapabilities {
86 roots: Some(RootsCapability { list_changed: true }),
87 sampling: Some(SamplingCapability {}),
88 experimental: None,
89 },
90 client_info: ClientInfo {
91 name: "codetether".to_string(),
92 version: env!("CARGO_PKG_VERSION").to_string(),
93 },
94 };
95
96 let response = self.request("initialize", Some(serde_json::to_value(¶ms)?)).await?;
97 let result: InitializeResult = serde_json::from_value(response)?;
98
99 *self.server_info.write().await = Some(result.server_info.clone());
101 *self.server_capabilities.write().await = Some(result.capabilities.clone());
102
103 if let Some(name) = self.server_name.read().await.clone() {
105 debug!("Client initialized with server name: {}", name);
109 }
110
111 self.notify("notifications/initialized", None).await?;
113
114 info!(
115 "Connected to MCP server: {} v{}",
116 result.server_info.name, result.server_info.version
117 );
118
119 if result.capabilities.tools.is_some() {
121 self.refresh_tools().await?;
122 }
123
124 Ok(result)
125 }
126
127 pub fn registry(&self) -> Arc<McpRegistry> {
129 Arc::clone(&self.registry)
130 }
131
132 pub async fn server_name(&self) -> Option<String> {
134 self.server_name.read().await.clone()
135 }
136
137 pub async fn set_server_name(&self, name: String) {
139 *self.server_name.write().await = Some(name);
140 }
141
142 pub async fn has_capability(&self, capability: &str) -> bool {
144 let caps = self.server_capabilities.read().await;
145 match capability {
146 "tools" => caps.as_ref().map(|c| c.tools.is_some()).unwrap_or(false),
147 "resources" => caps.as_ref().map(|c| c.resources.is_some()).unwrap_or(false),
148 "prompts" => caps.as_ref().map(|c| c.prompts.is_some()).unwrap_or(false),
149 "logging" => caps.as_ref().map(|c| c.logging.is_some()).unwrap_or(false),
150 _ => false,
151 }
152 }
153
154 pub async fn capabilities(&self) -> Option<ServerCapabilities> {
156 self.server_capabilities.read().await.clone()
157 }
158
159 pub async fn discover_tools_from_registry(&self) -> Vec<(String, McpTool)> {
161 self.registry.all_tools().await
162 }
163
164 pub async fn find_tool_in_registry(&self, tool_name: &str) -> Option<(String, McpTool)> {
166 self.registry.find_tool(tool_name).await
167 }
168
169 pub async fn refresh_tools(&self) -> Result<Vec<McpTool>> {
171 let response = self.request("tools/list", None).await?;
172 let result: ListToolsResult = serde_json::from_value(response)?;
173
174 *self.available_tools.write().await = result.tools.clone();
175
176 info!("Loaded {} tools from MCP server", result.tools.len());
177
178 Ok(result.tools)
179 }
180
181 pub async fn tools(&self) -> Vec<McpTool> {
183 self.available_tools.read().await.clone()
184 }
185
186 pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult> {
188 let params = CallToolParams {
189 name: name.to_string(),
190 arguments,
191 };
192
193 let response = self.request("tools/call", Some(serde_json::to_value(¶ms)?)).await?;
194 let result: CallToolResult = serde_json::from_value(response)?;
195
196 Ok(result)
197 }
198
199 pub async fn list_resources(&self) -> Result<Vec<McpResource>> {
201 let response = self.request("resources/list", None).await?;
202 let result: ListResourcesResult = serde_json::from_value(response)?;
203 Ok(result.resources)
204 }
205
206 pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
208 let params = ReadResourceParams { uri: uri.to_string() };
209 let response = self.request("resources/read", Some(serde_json::to_value(¶ms)?)).await?;
210 let result: ReadResourceResult = serde_json::from_value(response)?;
211 Ok(result)
212 }
213
214 pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>> {
216 let response = self.request("prompts/list", None).await?;
217 let result: ListPromptsResult = serde_json::from_value(response)?;
218 Ok(result.prompts)
219 }
220
221 pub async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
223 let params = GetPromptParams {
224 name: name.to_string(),
225 arguments,
226 };
227 let response = self.request("prompts/get", Some(serde_json::to_value(¶ms)?)).await?;
228 let result: GetPromptResult = serde_json::from_value(response)?;
229 Ok(result)
230 }
231
232 async fn request(&self, method: &str, params: Option<Value>) -> Result<Value> {
234 let id = RequestId::Number(self.request_id.fetch_add(1, Ordering::SeqCst));
235 let request = JsonRpcRequest::new(id.clone(), method, params);
236
237 let (tx, rx) = oneshot::channel();
239 self.pending_requests.write().await.insert(id.clone(), tx);
240
241 self.transport.send_request(request).await?;
243
244 let response = timeout(Duration::from_secs(30), rx)
246 .await
247 .map_err(|_| anyhow::anyhow!("Request timed out"))??;
248
249 if let Some(error) = response.error {
250 return Err(anyhow::anyhow!("MCP error {}: {}", error.code, error.message));
251 }
252
253 response.result.ok_or_else(|| anyhow::anyhow!("Empty response"))
254 }
255
256 async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
258 let notification = JsonRpcNotification::new(method, params);
259 self.transport.send_notification(notification).await
260 }
261
262 async fn receive_loop(&self) {
264 loop {
265 match self.transport.receive().await {
266 Ok(Some(message)) => {
267 self.handle_message(message).await;
268 }
269 Ok(None) => {
270 info!("MCP connection closed");
271 break;
272 }
273 Err(e) => {
274 error!("Error receiving MCP message: {}", e);
275 break;
276 }
277 }
278 }
279 }
280
281 async fn handle_message(&self, message: McpMessage) {
283 match message {
284 McpMessage::Response(response) => {
285 if let Some(tx) = self.pending_requests.write().await.remove(&response.id) {
287 let _ = tx.send(response);
288 } else {
289 warn!("Received response for unknown request: {:?}", response.id);
290 }
291 }
292 McpMessage::Request(request) => {
293 debug!("Received request from server: {}", request.method);
295
296 let response = match request.method.as_str() {
297 "sampling/createMessage" => {
298 JsonRpcResponse::error(
301 request.id,
302 JsonRpcError::method_not_found("Sampling not yet implemented"),
303 )
304 }
305 _ => {
306 JsonRpcResponse::error(
307 request.id,
308 JsonRpcError::method_not_found(&request.method),
309 )
310 }
311 };
312
313 if let Err(e) = self.transport.send_response(response).await {
314 error!("Failed to send response: {}", e);
315 }
316 }
317 McpMessage::Notification(notification) => {
318 debug!("Received notification: {}", notification.method);
319
320 match notification.method.as_str() {
321 "notifications/tools/list_changed" => {
322 info!("Tools list changed, refreshing...");
323 if let Err(e) = self.refresh_tools().await {
324 error!("Failed to refresh tools: {}", e);
325 }
326 }
327 "notifications/resources/list_changed" => {
328 info!("Resources list changed");
329 }
330 _ => {
331 debug!("Unknown notification: {}", notification.method);
332 }
333 }
334 }
335 }
336 }
337
338 pub async fn close(&self) -> Result<()> {
340 self.transport.close().await
341 }
342}
343
344pub struct McpRegistry {
350 clients: RwLock<HashMap<String, Arc<McpClient>>>,
351 server_capabilities: RwLock<HashMap<String, ServerCapabilities>>,
353 tool_index: RwLock<HashMap<String, String>>, }
356
357impl McpRegistry {
358 pub fn new() -> Self {
360 Self {
361 clients: RwLock::new(HashMap::new()),
362 server_capabilities: RwLock::new(HashMap::new()),
363 tool_index: RwLock::new(HashMap::new()),
364 }
365 }
366
367 pub async fn connect(&self, name: &str, command: &str, args: &[&str]) -> Result<Arc<McpClient>> {
369 let transport = Arc::new(ProcessTransport::spawn(command, args).await?);
370 let client = Arc::new(McpClient::with_registry(
371 transport,
372 Arc::new(McpRegistry::new()), Some(name.to_string())
374 ));
375
376 let client_clone = Arc::clone(&client);
378 tokio::spawn(async move {
379 client_clone.receive_loop().await;
380 });
381
382 let init_result = client.initialize().await?;
384
385 self.register(name, Arc::clone(&client), init_result.capabilities).await;
387
388 Ok(client)
389 }
390
391 pub async fn register(&self, name: &str, client: Arc<McpClient>, capabilities: ServerCapabilities) {
393 self.clients.write().await.insert(name.to_string(), client);
395
396 self.server_capabilities.write().await.insert(name.to_string(), capabilities);
398
399 info!("Registered MCP server '{}' with registry", name);
400 }
401
402 pub async fn get(&self, name: &str) -> Option<Arc<McpClient>> {
404 self.clients.read().await.get(name).cloned()
405 }
406
407 pub async fn list(&self) -> Vec<String> {
409 self.clients.read().await.keys().cloned().collect()
410 }
411
412 pub async fn get_capabilities(&self, name: &str) -> Option<ServerCapabilities> {
414 self.server_capabilities.read().await.get(name).cloned()
415 }
416
417 pub async fn has_capability(&self, name: &str, capability: &str) -> bool {
419 let caps = self.server_capabilities.read().await;
420 caps.get(name).map(|c| {
421 match capability {
422 "tools" => c.tools.is_some(),
423 "resources" => c.resources.is_some(),
424 "prompts" => c.prompts.is_some(),
425 "logging" => c.logging.is_some(),
426 _ => false,
427 }
428 }).unwrap_or(false)
429 }
430
431 pub async fn list_by_capability(&self, capability: &str) -> Vec<String> {
433 let mut result = Vec::new();
434 let caps = self.server_capabilities.read().await;
435
436 for (name, caps) in caps.iter() {
437 let has_cap = match capability {
438 "tools" => caps.tools.is_some(),
439 "resources" => caps.resources.is_some(),
440 "prompts" => caps.prompts.is_some(),
441 "logging" => caps.logging.is_some(),
442 _ => false,
443 };
444 if has_cap {
445 result.push(name.clone());
446 }
447 }
448
449 result
450 }
451
452 pub async fn disconnect(&self, name: &str) -> Result<()> {
454 if let Some(client) = self.clients.write().await.remove(name) {
455 self.server_capabilities.write().await.remove(name);
457 let mut tool_index = self.tool_index.write().await;
459 tool_index.retain(|_, server| server != name);
460 client.close().await?;
462 }
463 Ok(())
464 }
465
466 pub async fn all_tools(&self) -> Vec<(String, McpTool)> {
468 let mut all_tools = Vec::new();
469
470 for (name, client) in self.clients.read().await.iter() {
471 for tool in client.tools().await {
472 all_tools.push((name.clone(), tool));
473 }
474 }
475
476 all_tools
477 }
478
479 pub async fn find_tool(&self, tool_name: &str) -> Option<(String, McpTool)> {
481 if let Some(server_name) = self.tool_index.read().await.get(tool_name) {
483 if let Some(client) = self.get(server_name).await {
484 if let Some(tool) = client.tools().await.iter().find(|t| t.name == tool_name) {
485 return Some((server_name.clone(), tool.clone()));
486 }
487 }
488 }
489
490 for (name, client) in self.clients.read().await.iter() {
492 if let Some(tool) = client.tools().await.iter().find(|t| t.name == tool_name) {
493 self.tool_index.write().await.insert(tool_name.to_string(), name.clone());
495 return Some((name.clone(), tool.clone()));
496 }
497 }
498
499 None
500 }
501
502 pub async fn refresh_tool_index(&self) {
504 let mut tool_index = self.tool_index.write().await;
505 tool_index.clear();
506
507 for (name, client) in self.clients.read().await.iter() {
508 for tool in client.tools().await {
509 tool_index.insert(tool.name.clone(), name.clone());
510 }
511 }
512
513 info!("Refreshed tool index with {} tools", tool_index.len());
514 }
515
516 pub async fn call_tool(&self, server: &str, tool: &str, arguments: Value) -> Result<CallToolResult> {
518 let client = self.get(server).await
519 .ok_or_else(|| anyhow::anyhow!("Server not found: {}", server))?;
520 client.call_tool(tool, arguments).await
521 }
522
523 pub async fn call_tool_auto(&self, tool_name: &str, arguments: Value) -> Result<CallToolResult> {
525 let (server, _) = self.find_tool(tool_name).await
526 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", tool_name))?;
527 self.call_tool(&server, tool_name, arguments).await
528 }
529}
530
531impl Default for McpRegistry {
532 fn default() -> Self {
533 Self::new()
534 }
535}