1use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::error::{McpError, McpResult};
13use crate::protocol::{messages::*, types::*, validation::*};
14use crate::transport::traits::Transport;
15
/// Tunable settings for an [`McpClient`].
///
/// Construct with [`ClientConfig::default`] or through `McpClientBuilder`. Deriving
/// `PartialEq`/`Eq` lets callers and tests compare configurations directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientConfig {
    /// Per-request timeout in milliseconds.
    ///
    /// NOTE(review): `McpClient` in this module stores this value but does not read it
    /// when sending requests — confirm whether enforcement is expected elsewhere.
    pub request_timeout_ms: u64,
    /// Maximum number of retries for a failed request.
    ///
    /// NOTE(review): not currently read by `McpClient` in this module.
    pub max_retries: u32,
    /// Delay between retries, in milliseconds.
    ///
    /// NOTE(review): not currently read by `McpClient` in this module.
    pub retry_delay_ms: u64,
    /// When `true`, outgoing requests are validated (JSON-RPC shape, MCP method
    /// parameters, and per-method parameter checks) before they are sent.
    pub validate_requests: bool,
    /// When `true`, the JSON-RPC shape of every response is validated on receipt.
    pub validate_responses: bool,
}
30
31impl Default for ClientConfig {
32 fn default() -> Self {
33 Self {
34 request_timeout_ms: 30000,
35 max_retries: 3,
36 retry_delay_ms: 1000,
37 validate_requests: true,
38 validate_responses: true,
39 }
40 }
41}
42
43pub struct McpClient {
45 info: ClientInfo,
47 capabilities: ClientCapabilities,
49 config: ClientConfig,
51 transport: Arc<Mutex<Option<Box<dyn Transport>>>>,
53 server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
55 server_info: Arc<RwLock<Option<ServerInfo>>>,
57 request_counter: Arc<Mutex<u64>>,
59 connected: Arc<RwLock<bool>>,
61}
62
63impl McpClient {
64 pub fn new(name: String, version: String) -> Self {
66 Self {
67 info: ClientInfo { name, version },
68 capabilities: ClientCapabilities::default(),
69 config: ClientConfig::default(),
70 transport: Arc::new(Mutex::new(None)),
71 server_capabilities: Arc::new(RwLock::new(None)),
72 server_info: Arc::new(RwLock::new(None)),
73 request_counter: Arc::new(Mutex::new(0)),
74 connected: Arc::new(RwLock::new(false)),
75 }
76 }
77
78 pub fn with_config(name: String, version: String, config: ClientConfig) -> Self {
80 let mut client = Self::new(name, version);
81 client.config = config;
82 client
83 }
84
85 pub fn set_capabilities(&mut self, capabilities: ClientCapabilities) {
87 self.capabilities = capabilities;
88 }
89
90 pub fn info(&self) -> &ClientInfo {
92 &self.info
93 }
94
95 pub fn capabilities(&self) -> &ClientCapabilities {
97 &self.capabilities
98 }
99
100 pub fn config(&self) -> &ClientConfig {
102 &self.config
103 }
104
105 pub async fn server_capabilities(&self) -> Option<ServerCapabilities> {
107 let capabilities = self.server_capabilities.read().await;
108 capabilities.clone()
109 }
110
111 pub async fn server_info(&self) -> Option<ServerInfo> {
113 let info = self.server_info.read().await;
114 info.clone()
115 }
116
117 pub async fn is_connected(&self) -> bool {
119 let connected = self.connected.read().await;
120 *connected
121 }
122
123 pub async fn connect<T>(&mut self, transport: T) -> McpResult<InitializeResult>
129 where
130 T: Transport + 'static,
131 {
132 {
134 let mut transport_guard = self.transport.lock().await;
135 *transport_guard = Some(Box::new(transport));
136 }
137
138 let init_result = self.initialize().await?;
140
141 {
143 let mut connected = self.connected.write().await;
144 *connected = true;
145 }
146
147 Ok(init_result)
148 }
149
150 pub async fn disconnect(&self) -> McpResult<()> {
152 {
154 let mut transport_guard = self.transport.lock().await;
155 if let Some(transport) = transport_guard.as_mut() {
156 transport.close().await?;
157 }
158 *transport_guard = None;
159 }
160
161 {
163 let mut server_capabilities = self.server_capabilities.write().await;
164 *server_capabilities = None;
165 }
166 {
167 let mut server_info = self.server_info.write().await;
168 *server_info = None;
169 }
170
171 {
173 let mut connected = self.connected.write().await;
174 *connected = false;
175 }
176
177 Ok(())
178 }
179
180 async fn initialize(&self) -> McpResult<InitializeResult> {
182 let params = InitializeParams::new(
183 self.info.clone(),
184 self.capabilities.clone(),
185 MCP_PROTOCOL_VERSION.to_string(),
186 );
187
188 let request = JsonRpcRequest::new(
189 Value::from(self.next_request_id().await),
190 methods::INITIALIZE.to_string(),
191 Some(params),
192 )?;
193
194 let response = self.send_request(request).await?;
195
196 if let Some(error) = response.error {
197 return Err(McpError::Protocol(format!(
198 "Initialize failed: {}",
199 error.message
200 )));
201 }
202
203 let result: InitializeResult = serde_json::from_value(
204 response
205 .result
206 .ok_or_else(|| McpError::Protocol("Missing initialize result".to_string()))?,
207 )?;
208
209 {
211 let mut server_capabilities = self.server_capabilities.write().await;
212 *server_capabilities = Some(result.capabilities.clone());
213 }
214 {
215 let mut server_info = self.server_info.write().await;
216 *server_info = Some(result.server_info.clone());
217 }
218
219 Ok(result)
220 }
221
222 pub async fn list_tools(&self, cursor: Option<String>) -> McpResult<ListToolsResult> {
228 self.ensure_connected().await?;
229
230 let params = ListToolsParams { cursor };
231 let request = JsonRpcRequest::new(
232 Value::from(self.next_request_id().await),
233 methods::TOOLS_LIST.to_string(),
234 Some(params),
235 )?;
236
237 let response = self.send_request(request).await?;
238 self.handle_response(response)
239 }
240
241 pub async fn call_tool(
243 &self,
244 name: String,
245 arguments: Option<HashMap<String, Value>>,
246 ) -> McpResult<CallToolResult> {
247 self.ensure_connected().await?;
248
249 let params = CallToolParams::new(name, arguments);
250
251 if self.config.validate_requests {
252 validate_call_tool_params(¶ms)?;
253 }
254
255 let request = JsonRpcRequest::new(
256 Value::from(self.next_request_id().await),
257 methods::TOOLS_CALL.to_string(),
258 Some(params),
259 )?;
260
261 let response = self.send_request(request).await?;
262 self.handle_response(response)
263 }
264
265 pub async fn list_resources(&self, cursor: Option<String>) -> McpResult<ListResourcesResult> {
271 self.ensure_connected().await?;
272
273 let params = ListResourcesParams { cursor };
274 let request = JsonRpcRequest::new(
275 Value::from(self.next_request_id().await),
276 methods::RESOURCES_LIST.to_string(),
277 Some(params),
278 )?;
279
280 let response = self.send_request(request).await?;
281 self.handle_response(response)
282 }
283
284 pub async fn read_resource(&self, uri: String) -> McpResult<ReadResourceResult> {
286 self.ensure_connected().await?;
287
288 let params = ReadResourceParams::new(uri);
289
290 if self.config.validate_requests {
291 validate_read_resource_params(¶ms)?;
292 }
293
294 let request = JsonRpcRequest::new(
295 Value::from(self.next_request_id().await),
296 methods::RESOURCES_READ.to_string(),
297 Some(params),
298 )?;
299
300 let response = self.send_request(request).await?;
301 self.handle_response(response)
302 }
303
304 pub async fn subscribe_resource(&self, uri: String) -> McpResult<SubscribeResourceResult> {
306 self.ensure_connected().await?;
307
308 let params = SubscribeResourceParams { uri };
309 let request = JsonRpcRequest::new(
310 Value::from(self.next_request_id().await),
311 methods::RESOURCES_SUBSCRIBE.to_string(),
312 Some(params),
313 )?;
314
315 let response = self.send_request(request).await?;
316 self.handle_response(response)
317 }
318
319 pub async fn unsubscribe_resource(&self, uri: String) -> McpResult<UnsubscribeResourceResult> {
321 self.ensure_connected().await?;
322
323 let params = UnsubscribeResourceParams { uri };
324 let request = JsonRpcRequest::new(
325 Value::from(self.next_request_id().await),
326 methods::RESOURCES_UNSUBSCRIBE.to_string(),
327 Some(params),
328 )?;
329
330 let response = self.send_request(request).await?;
331 self.handle_response(response)
332 }
333
334 pub async fn list_prompts(&self, cursor: Option<String>) -> McpResult<ListPromptsResult> {
340 self.ensure_connected().await?;
341
342 let params = ListPromptsParams { cursor };
343 let request = JsonRpcRequest::new(
344 Value::from(self.next_request_id().await),
345 methods::PROMPTS_LIST.to_string(),
346 Some(params),
347 )?;
348
349 let response = self.send_request(request).await?;
350 self.handle_response(response)
351 }
352
353 pub async fn get_prompt(
355 &self,
356 name: String,
357 arguments: Option<HashMap<String, Value>>,
358 ) -> McpResult<GetPromptResult> {
359 self.ensure_connected().await?;
360
361 let params = GetPromptParams::new(name, arguments);
362
363 if self.config.validate_requests {
364 validate_get_prompt_params(¶ms)?;
365 }
366
367 let request = JsonRpcRequest::new(
368 Value::from(self.next_request_id().await),
369 methods::PROMPTS_GET.to_string(),
370 Some(params),
371 )?;
372
373 let response = self.send_request(request).await?;
374 self.handle_response(response)
375 }
376
377 pub async fn create_message(
383 &self,
384 params: CreateMessageParams,
385 ) -> McpResult<CreateMessageResult> {
386 self.ensure_connected().await?;
387
388 {
390 let server_capabilities = self.server_capabilities.read().await;
391 if let Some(capabilities) = server_capabilities.as_ref() {
392 if capabilities.sampling.is_none() {
393 return Err(McpError::Protocol(
394 "Server does not support sampling".to_string(),
395 ));
396 }
397 } else {
398 return Err(McpError::Protocol("Not connected to server".to_string()));
399 }
400 }
401
402 if self.config.validate_requests {
403 validate_create_message_params(¶ms)?;
404 }
405
406 let request = JsonRpcRequest::new(
407 Value::from(self.next_request_id().await),
408 methods::SAMPLING_CREATE_MESSAGE.to_string(),
409 Some(params),
410 )?;
411
412 let response = self.send_request(request).await?;
413 self.handle_response(response)
414 }
415
416 pub async fn ping(&self) -> McpResult<PingResult> {
422 self.ensure_connected().await?;
423
424 let request = JsonRpcRequest::new(
425 Value::from(self.next_request_id().await),
426 methods::PING.to_string(),
427 Some(PingParams {}),
428 )?;
429
430 let response = self.send_request(request).await?;
431 self.handle_response(response)
432 }
433
434 pub async fn set_logging_level(&self, level: LoggingLevel) -> McpResult<SetLoggingLevelResult> {
436 self.ensure_connected().await?;
437
438 let params = SetLoggingLevelParams { level };
439 let request = JsonRpcRequest::new(
440 Value::from(self.next_request_id().await),
441 methods::LOGGING_SET_LEVEL.to_string(),
442 Some(params),
443 )?;
444
445 let response = self.send_request(request).await?;
446 self.handle_response(response)
447 }
448
449 pub async fn receive_notification(&self) -> McpResult<Option<JsonRpcNotification>> {
455 let mut transport_guard = self.transport.lock().await;
456 if let Some(transport) = transport_guard.as_mut() {
457 transport.receive_notification().await
458 } else {
459 Err(McpError::Transport("Not connected".to_string()))
460 }
461 }
462
463 async fn send_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
469 if self.config.validate_requests {
470 validate_jsonrpc_request(&request)?;
471 validate_mcp_request(&request.method, request.params.as_ref())?;
472 }
473
474 let mut transport_guard = self.transport.lock().await;
475 if let Some(transport) = transport_guard.as_mut() {
476 let response = transport.send_request(request).await?;
477
478 if self.config.validate_responses {
479 validate_jsonrpc_response(&response)?;
480 }
481
482 Ok(response)
483 } else {
484 Err(McpError::Transport("Not connected".to_string()))
485 }
486 }
487
488 fn handle_response<T>(&self, response: JsonRpcResponse) -> McpResult<T>
490 where
491 T: serde::de::DeserializeOwned,
492 {
493 if let Some(error) = response.error {
494 return Err(McpError::Protocol(format!(
495 "Server error: {}",
496 error.message
497 )));
498 }
499
500 let result = response
501 .result
502 .ok_or_else(|| McpError::Protocol("Missing result in response".to_string()))?;
503
504 serde_json::from_value(result).map_err(McpError::Serialization)
505 }
506
507 async fn ensure_connected(&self) -> McpResult<()> {
509 if !self.is_connected().await {
510 return Err(McpError::Connection("Not connected to server".to_string()));
511 }
512 Ok(())
513 }
514
515 async fn next_request_id(&self) -> u64 {
517 let mut counter = self.request_counter.lock().await;
518 *counter += 1;
519 *counter
520 }
521}
522
523pub struct McpClientBuilder {
525 name: String,
526 version: String,
527 capabilities: ClientCapabilities,
528 config: ClientConfig,
529}
530
531impl McpClientBuilder {
532 pub fn new(name: String, version: String) -> Self {
534 Self {
535 name,
536 version,
537 capabilities: ClientCapabilities::default(),
538 config: ClientConfig::default(),
539 }
540 }
541
542 pub fn capabilities(mut self, capabilities: ClientCapabilities) -> Self {
544 self.capabilities = capabilities;
545 self
546 }
547
548 pub fn config(mut self, config: ClientConfig) -> Self {
550 self.config = config;
551 self
552 }
553
554 pub fn request_timeout(mut self, timeout_ms: u64) -> Self {
556 self.config.request_timeout_ms = timeout_ms;
557 self
558 }
559
560 pub fn max_retries(mut self, retries: u32) -> Self {
562 self.config.max_retries = retries;
563 self
564 }
565
566 pub fn validate_requests(mut self, validate: bool) -> Self {
568 self.config.validate_requests = validate;
569 self
570 }
571
572 pub fn validate_responses(mut self, validate: bool) -> Self {
574 self.config.validate_responses = validate;
575 self
576 }
577
578 pub fn build(self) -> McpClient {
580 let mut client = McpClient::new(self.name, self.version);
581 client.set_capabilities(self.capabilities);
582 client.config = self.config;
583 client
584 }
585}
586
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Transport double that replays a fixed queue of responses, one per request, and
    /// returns a transport error once the queue is exhausted.
    struct MockTransport {
        responses: Vec<JsonRpcResponse>,
        current: usize,
    }

    impl MockTransport {
        fn new(responses: Vec<JsonRpcResponse>) -> Self {
            Self {
                responses,
                current: 0,
            }
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            let response = self
                .responses
                .get(self.current)
                .cloned()
                .ok_or_else(|| McpError::Transport("No more responses".to_string()))?;
            self.current += 1;
            Ok(response)
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    /// Successful `initialize` response (id 1) from a server named `test-server`.
    fn init_response() -> JsonRpcResponse {
        let init_result = InitializeResult::new(
            ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
            },
            ServerCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );
        JsonRpcResponse::success(Value::from(1), init_result).unwrap()
    }

    /// Fresh, disconnected client used by most tests.
    fn new_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = new_client();
        assert_eq!(client.info().name, "test-client");
        assert_eq!(client.info().version, "1.0.0");
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_builder() {
        let client = McpClientBuilder::new("test-client".to_string(), "1.0.0".to_string())
            .request_timeout(5000)
            .max_retries(5)
            .validate_requests(false)
            .build();

        assert_eq!(client.config().request_timeout_ms, 5000);
        assert_eq!(client.config().max_retries, 5);
        assert!(!client.config().validate_requests);
    }

    #[tokio::test]
    async fn test_mock_connection() {
        let transport = MockTransport::new(vec![init_response()]);

        let mut client = new_client();
        let result = client.connect(transport).await.unwrap();

        assert_eq!(result.server_info.name, "test-server");
        assert!(client.is_connected().await);
    }

    #[tokio::test]
    async fn test_disconnect() {
        let transport = MockTransport::new(vec![init_response()]);

        let mut client = new_client();
        client.connect(transport).await.unwrap();
        assert!(client.is_connected().await);

        client.disconnect().await.unwrap();
        assert!(!client.is_connected().await);
        assert!(client.server_info().await.is_none());
        assert!(client.server_capabilities().await.is_none());
    }

    #[tokio::test]
    async fn test_connect_failure_leaves_client_disconnected() {
        // No queued responses, so the initialize request fails at the transport.
        let mut client = new_client();
        assert!(client.connect(MockTransport::new(Vec::new())).await.is_err());

        assert!(!client.is_connected().await);
        assert!(client.server_info().await.is_none());
        assert!(client.server_capabilities().await.is_none());
    }

    #[tokio::test]
    async fn test_requests_require_connection() {
        let client = new_client();
        assert!(matches!(
            client.list_tools(None).await,
            Err(McpError::Connection(_))
        ));
        assert!(matches!(client.ping().await, Err(McpError::Connection(_))));
    }

    #[tokio::test]
    async fn test_disconnect_without_connection_is_ok() {
        let client = new_client();
        assert!(client.disconnect().await.is_ok());
        assert!(!client.is_connected().await);
    }
}