1use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{Mutex, RwLock};
12
13use crate::core::error::{McpError, McpResult};
14use crate::protocol::{messages::*, types::*, validation::*};
15use crate::transport::traits::Transport;
16
/// Tunable settings for an [`McpClient`].
///
/// Compared field-by-field, so two configurations are equal exactly when
/// every setting matches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientConfig {
    /// Request timeout, in milliseconds according to the field name.
    ///
    /// NOTE(review): nothing in this file reads this value; confirm where
    /// (or whether) it is enforced.
    pub request_timeout_ms: u64,
    /// Maximum number of retries according to the field name.
    ///
    /// NOTE(review): not read anywhere in this file; confirm the consumer.
    pub max_retries: u32,
    /// Delay between retries, in milliseconds according to the field name.
    ///
    /// NOTE(review): not read anywhere in this file; confirm the consumer.
    pub retry_delay_ms: u64,
    /// When `true`, the client validates outgoing requests before handing
    /// them to the transport.
    pub validate_requests: bool,
    /// When `true`, the client validates each response returned by the
    /// transport before passing it on.
    pub validate_responses: bool,
}

impl Default for ClientConfig {
    /// Returns the stock configuration: `30000` for the request timeout, `3`
    /// retries, `1000` for the retry delay, and both validation switches on.
    fn default() -> Self {
        Self {
            request_timeout_ms: 30_000,
            max_retries: 3,
            retry_delay_ms: 1_000,
            validate_requests: true,
            validate_responses: true,
        }
    }
}
43
44pub struct McpClient {
46 info: ClientInfo,
48 capabilities: ClientCapabilities,
50 config: ClientConfig,
52 transport: Arc<Mutex<Option<Box<dyn Transport>>>>,
54 server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
56 server_info: Arc<RwLock<Option<ServerInfo>>>,
58 request_counter: Arc<Mutex<u64>>,
60 connected: Arc<RwLock<bool>>,
62}
63
64impl McpClient {
65 pub fn new(name: String, version: String) -> Self {
67 Self {
68 info: ClientInfo { name, version },
69 capabilities: ClientCapabilities::default(),
70 config: ClientConfig::default(),
71 transport: Arc::new(Mutex::new(None)),
72 server_capabilities: Arc::new(RwLock::new(None)),
73 server_info: Arc::new(RwLock::new(None)),
74 request_counter: Arc::new(Mutex::new(0)),
75 connected: Arc::new(RwLock::new(false)),
76 }
77 }
78
79 pub fn with_config(name: String, version: String, config: ClientConfig) -> Self {
81 let mut client = Self::new(name, version);
82 client.config = config;
83 client
84 }
85
86 pub fn set_capabilities(&mut self, capabilities: ClientCapabilities) {
88 self.capabilities = capabilities;
89 }
90
91 pub fn info(&self) -> &ClientInfo {
93 &self.info
94 }
95
96 pub fn capabilities(&self) -> &ClientCapabilities {
98 &self.capabilities
99 }
100
101 pub fn config(&self) -> &ClientConfig {
103 &self.config
104 }
105
106 pub async fn server_capabilities(&self) -> Option<ServerCapabilities> {
108 let capabilities = self.server_capabilities.read().await;
109 capabilities.clone()
110 }
111
112 pub async fn server_info(&self) -> Option<ServerInfo> {
114 let info = self.server_info.read().await;
115 info.clone()
116 }
117
118 pub async fn is_connected(&self) -> bool {
120 let connected = self.connected.read().await;
121 *connected
122 }
123
124 pub async fn connect<T>(&mut self, transport: T) -> McpResult<InitializeResult>
130 where
131 T: Transport + 'static,
132 {
133 {
135 let mut transport_guard = self.transport.lock().await;
136 *transport_guard = Some(Box::new(transport));
137 }
138
139 let init_result = self.initialize().await?;
141
142 {
144 let mut connected = self.connected.write().await;
145 *connected = true;
146 }
147
148 Ok(init_result)
149 }
150
151 pub async fn disconnect(&self) -> McpResult<()> {
153 {
155 let mut transport_guard = self.transport.lock().await;
156 if let Some(transport) = transport_guard.as_mut() {
157 transport.close().await?;
158 }
159 *transport_guard = None;
160 }
161
162 {
164 let mut server_capabilities = self.server_capabilities.write().await;
165 *server_capabilities = None;
166 }
167 {
168 let mut server_info = self.server_info.write().await;
169 *server_info = None;
170 }
171
172 {
174 let mut connected = self.connected.write().await;
175 *connected = false;
176 }
177
178 Ok(())
179 }
180
181 async fn initialize(&self) -> McpResult<InitializeResult> {
183 let params = InitializeParams::new(
184 self.info.clone(),
185 self.capabilities.clone(),
186 MCP_PROTOCOL_VERSION.to_string(),
187 );
188
189 let request = JsonRpcRequest::new(
190 Value::from(self.next_request_id().await),
191 methods::INITIALIZE.to_string(),
192 Some(params),
193 )?;
194
195 let response = self.send_request(request).await?;
196
197 if let Some(error) = response.error {
198 return Err(McpError::Protocol(format!(
199 "Initialize failed: {}",
200 error.message
201 )));
202 }
203
204 let result: InitializeResult = serde_json::from_value(
205 response
206 .result
207 .ok_or_else(|| McpError::Protocol("Missing initialize result".to_string()))?,
208 )?;
209
210 {
212 let mut server_capabilities = self.server_capabilities.write().await;
213 *server_capabilities = Some(result.capabilities.clone());
214 }
215 {
216 let mut server_info = self.server_info.write().await;
217 *server_info = Some(result.server_info.clone());
218 }
219
220 Ok(result)
221 }
222
223 pub async fn list_tools(&self, cursor: Option<String>) -> McpResult<ListToolsResult> {
229 self.ensure_connected().await?;
230
231 let params = ListToolsParams { cursor };
232 let request = JsonRpcRequest::new(
233 Value::from(self.next_request_id().await),
234 methods::TOOLS_LIST.to_string(),
235 Some(params),
236 )?;
237
238 let response = self.send_request(request).await?;
239 self.handle_response(response)
240 }
241
242 pub async fn call_tool(
244 &self,
245 name: String,
246 arguments: Option<HashMap<String, Value>>,
247 ) -> McpResult<CallToolResult> {
248 self.ensure_connected().await?;
249
250 let params = CallToolParams::new(name, arguments);
251
252 if self.config.validate_requests {
253 validate_call_tool_params(¶ms)?;
254 }
255
256 let request = JsonRpcRequest::new(
257 Value::from(self.next_request_id().await),
258 methods::TOOLS_CALL.to_string(),
259 Some(params),
260 )?;
261
262 let response = self.send_request(request).await?;
263 self.handle_response(response)
264 }
265
266 pub async fn list_resources(&self, cursor: Option<String>) -> McpResult<ListResourcesResult> {
272 self.ensure_connected().await?;
273
274 let params = ListResourcesParams { cursor };
275 let request = JsonRpcRequest::new(
276 Value::from(self.next_request_id().await),
277 methods::RESOURCES_LIST.to_string(),
278 Some(params),
279 )?;
280
281 let response = self.send_request(request).await?;
282 self.handle_response(response)
283 }
284
285 pub async fn read_resource(&self, uri: String) -> McpResult<ReadResourceResult> {
287 self.ensure_connected().await?;
288
289 let params = ReadResourceParams::new(uri);
290
291 if self.config.validate_requests {
292 validate_read_resource_params(¶ms)?;
293 }
294
295 let request = JsonRpcRequest::new(
296 Value::from(self.next_request_id().await),
297 methods::RESOURCES_READ.to_string(),
298 Some(params),
299 )?;
300
301 let response = self.send_request(request).await?;
302 self.handle_response(response)
303 }
304
305 pub async fn subscribe_resource(&self, uri: String) -> McpResult<SubscribeResourceResult> {
307 self.ensure_connected().await?;
308
309 let params = SubscribeResourceParams { uri };
310 let request = JsonRpcRequest::new(
311 Value::from(self.next_request_id().await),
312 methods::RESOURCES_SUBSCRIBE.to_string(),
313 Some(params),
314 )?;
315
316 let response = self.send_request(request).await?;
317 self.handle_response(response)
318 }
319
320 pub async fn unsubscribe_resource(&self, uri: String) -> McpResult<UnsubscribeResourceResult> {
322 self.ensure_connected().await?;
323
324 let params = UnsubscribeResourceParams { uri };
325 let request = JsonRpcRequest::new(
326 Value::from(self.next_request_id().await),
327 methods::RESOURCES_UNSUBSCRIBE.to_string(),
328 Some(params),
329 )?;
330
331 let response = self.send_request(request).await?;
332 self.handle_response(response)
333 }
334
335 pub async fn list_prompts(&self, cursor: Option<String>) -> McpResult<ListPromptsResult> {
341 self.ensure_connected().await?;
342
343 let params = ListPromptsParams { cursor };
344 let request = JsonRpcRequest::new(
345 Value::from(self.next_request_id().await),
346 methods::PROMPTS_LIST.to_string(),
347 Some(params),
348 )?;
349
350 let response = self.send_request(request).await?;
351 self.handle_response(response)
352 }
353
354 pub async fn get_prompt(
356 &self,
357 name: String,
358 arguments: Option<HashMap<String, Value>>,
359 ) -> McpResult<GetPromptResult> {
360 self.ensure_connected().await?;
361
362 let params = GetPromptParams::new(name, arguments);
363
364 if self.config.validate_requests {
365 validate_get_prompt_params(¶ms)?;
366 }
367
368 let request = JsonRpcRequest::new(
369 Value::from(self.next_request_id().await),
370 methods::PROMPTS_GET.to_string(),
371 Some(params),
372 )?;
373
374 let response = self.send_request(request).await?;
375 self.handle_response(response)
376 }
377
378 pub async fn create_message(
384 &self,
385 params: CreateMessageParams,
386 ) -> McpResult<CreateMessageResult> {
387 self.ensure_connected().await?;
388
389 {
391 let server_capabilities = self.server_capabilities.read().await;
392 if let Some(capabilities) = server_capabilities.as_ref() {
393 if capabilities.sampling.is_none() {
394 return Err(McpError::Protocol(
395 "Server does not support sampling".to_string(),
396 ));
397 }
398 } else {
399 return Err(McpError::Protocol("Not connected to server".to_string()));
400 }
401 }
402
403 if self.config.validate_requests {
404 validate_create_message_params(¶ms)?;
405 }
406
407 let request = JsonRpcRequest::new(
408 Value::from(self.next_request_id().await),
409 methods::SAMPLING_CREATE_MESSAGE.to_string(),
410 Some(params),
411 )?;
412
413 let response = self.send_request(request).await?;
414 self.handle_response(response)
415 }
416
417 pub async fn ping(&self) -> McpResult<PingResult> {
423 self.ensure_connected().await?;
424
425 let request = JsonRpcRequest::new(
426 Value::from(self.next_request_id().await),
427 methods::PING.to_string(),
428 Some(PingParams {}),
429 )?;
430
431 let response = self.send_request(request).await?;
432 self.handle_response(response)
433 }
434
435 pub async fn set_logging_level(&self, level: LoggingLevel) -> McpResult<SetLoggingLevelResult> {
437 self.ensure_connected().await?;
438
439 let params = SetLoggingLevelParams { level };
440 let request = JsonRpcRequest::new(
441 Value::from(self.next_request_id().await),
442 methods::LOGGING_SET_LEVEL.to_string(),
443 Some(params),
444 )?;
445
446 let response = self.send_request(request).await?;
447 self.handle_response(response)
448 }
449
450 pub async fn receive_notification(&self) -> McpResult<Option<JsonRpcNotification>> {
456 let mut transport_guard = self.transport.lock().await;
457 if let Some(transport) = transport_guard.as_mut() {
458 transport.receive_notification().await
459 } else {
460 Err(McpError::Transport("Not connected".to_string()))
461 }
462 }
463
464 async fn send_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
470 if self.config.validate_requests {
471 validate_jsonrpc_request(&request)?;
472 validate_mcp_request(&request.method, request.params.as_ref())?;
473 }
474
475 let mut transport_guard = self.transport.lock().await;
476 if let Some(transport) = transport_guard.as_mut() {
477 let response = transport.send_request(request).await?;
478
479 if self.config.validate_responses {
480 validate_jsonrpc_response(&response)?;
481 }
482
483 Ok(response)
484 } else {
485 Err(McpError::Transport("Not connected".to_string()))
486 }
487 }
488
489 fn handle_response<T>(&self, response: JsonRpcResponse) -> McpResult<T>
491 where
492 T: serde::de::DeserializeOwned,
493 {
494 if let Some(error) = response.error {
495 return Err(McpError::Protocol(format!(
496 "Server error: {}",
497 error.message
498 )));
499 }
500
501 let result = response
502 .result
503 .ok_or_else(|| McpError::Protocol("Missing result in response".to_string()))?;
504
505 serde_json::from_value(result).map_err(|e| McpError::Serialization(e))
506 }
507
508 async fn ensure_connected(&self) -> McpResult<()> {
510 if !self.is_connected().await {
511 return Err(McpError::Connection("Not connected to server".to_string()));
512 }
513 Ok(())
514 }
515
516 async fn next_request_id(&self) -> u64 {
518 let mut counter = self.request_counter.lock().await;
519 *counter += 1;
520 *counter
521 }
522}
523
524pub struct McpClientBuilder {
526 name: String,
527 version: String,
528 capabilities: ClientCapabilities,
529 config: ClientConfig,
530}
531
532impl McpClientBuilder {
533 pub fn new(name: String, version: String) -> Self {
535 Self {
536 name,
537 version,
538 capabilities: ClientCapabilities::default(),
539 config: ClientConfig::default(),
540 }
541 }
542
543 pub fn capabilities(mut self, capabilities: ClientCapabilities) -> Self {
545 self.capabilities = capabilities;
546 self
547 }
548
549 pub fn config(mut self, config: ClientConfig) -> Self {
551 self.config = config;
552 self
553 }
554
555 pub fn request_timeout(mut self, timeout_ms: u64) -> Self {
557 self.config.request_timeout_ms = timeout_ms;
558 self
559 }
560
561 pub fn max_retries(mut self, retries: u32) -> Self {
563 self.config.max_retries = retries;
564 self
565 }
566
567 pub fn validate_requests(mut self, validate: bool) -> Self {
569 self.config.validate_requests = validate;
570 self
571 }
572
573 pub fn validate_responses(mut self, validate: bool) -> Self {
575 self.config.validate_responses = validate;
576 self
577 }
578
579 pub fn build(self) -> McpClient {
581 let mut client = McpClient::new(self.name, self.version);
582 client.set_capabilities(self.capabilities);
583 client.config = self.config;
584 client
585 }
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Transport double that replays a fixed list of responses in order and
    /// fails once the list is exhausted.
    struct MockTransport {
        responses: Vec<JsonRpcResponse>,
        current: usize,
    }

    impl MockTransport {
        fn new(responses: Vec<JsonRpcResponse>) -> Self {
            Self {
                responses,
                current: 0,
            }
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            let response = self
                .responses
                .get(self.current)
                .cloned()
                .ok_or_else(|| McpError::Transport("No more responses".to_string()))?;
            self.current += 1;
            Ok(response)
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    /// Builds the successful initialize response shared by the connection
    /// tests.
    fn init_response() -> JsonRpcResponse {
        let init_result = InitializeResult::new(
            ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
            },
            ServerCapabilities::default(),
            MCP_PROTOCOL_VERSION.to_string(),
        );

        JsonRpcResponse::success(Value::from(1), init_result).unwrap()
    }

    /// Shorthand for the client used throughout these tests.
    fn test_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = test_client();
        assert_eq!(client.info().name, "test-client");
        assert_eq!(client.info().version, "1.0.0");
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_builder() {
        let client = McpClientBuilder::new("test-client".to_string(), "1.0.0".to_string())
            .request_timeout(5000)
            .max_retries(5)
            .validate_requests(false)
            .build();

        assert_eq!(client.config().request_timeout_ms, 5000);
        assert_eq!(client.config().max_retries, 5);
        assert!(!client.config().validate_requests);
    }

    #[tokio::test]
    async fn test_mock_connection() {
        let transport = MockTransport::new(vec![init_response()]);

        let mut client = test_client();
        let result = client.connect(transport).await.unwrap();

        assert_eq!(result.server_info.name, "test-server");
        assert!(client.is_connected().await);
    }

    #[tokio::test]
    async fn test_disconnect() {
        let transport = MockTransport::new(vec![init_response()]);

        let mut client = test_client();
        client.connect(transport).await.unwrap();

        assert!(client.is_connected().await);

        client.disconnect().await.unwrap();
        assert!(!client.is_connected().await);
        assert!(client.server_info().await.is_none());
        assert!(client.server_capabilities().await.is_none());
    }

    #[tokio::test]
    async fn test_request_before_connect_fails() {
        let client = test_client();
        assert!(client.ping().await.is_err());
    }

    #[tokio::test]
    async fn test_failed_initialize_leaves_client_disconnected() {
        // A transport with no queued responses makes the initialize exchange fail.
        let transport = MockTransport::new(Vec::new());

        let mut client = test_client();
        assert!(client.connect(transport).await.is_err());

        assert!(!client.is_connected().await);
        assert!(client.server_info().await.is_none());
        assert!(client.server_capabilities().await.is_none());
    }
}