1use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock};
11
12use crate::core::error::{McpError, McpResult};
13use crate::protocol::{messages::*, methods, types::*, validation::*};
14use crate::transport::traits::Transport;
15
/// Tunables for how an `McpClient` issues requests and checks responses.
///
/// NOTE(review): within the client module only `validate_requests` and
/// `validate_responses` are read by `McpClient`; `request_timeout_ms`,
/// `max_retries` and `retry_delay_ms` are stored (and settable through the
/// builder) but are not applied to requests — confirm whether another module
/// is expected to honor them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientConfig {
    /// Intended per-request timeout, in milliseconds (default `30_000`).
    pub request_timeout_ms: u64,
    /// Intended maximum number of retries for a failed request (default `3`).
    pub max_retries: u32,
    /// Intended delay between retries, in milliseconds (default `1_000`).
    pub retry_delay_ms: u64,
    /// When `true`, outgoing requests are validated before being sent (default `true`).
    pub validate_requests: bool,
    /// When `true`, incoming JSON-RPC responses are validated (default `true`).
    pub validate_responses: bool,
}

impl Default for ClientConfig {
    /// 30 s timeout, 3 retries 1 s apart, with request and response validation enabled.
    fn default() -> Self {
        Self {
            request_timeout_ms: 30_000,
            max_retries: 3,
            retry_delay_ms: 1_000,
            validate_requests: true,
            validate_responses: true,
        }
    }
}
42
43pub struct McpClient {
45 info: ClientInfo,
47 capabilities: ClientCapabilities,
49 config: ClientConfig,
51 transport: Arc<Mutex<Option<Box<dyn Transport>>>>,
53 server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
55 server_info: Arc<RwLock<Option<ServerInfo>>>,
57 request_counter: Arc<Mutex<u64>>,
59 connected: Arc<RwLock<bool>>,
61}
62
63impl McpClient {
64 pub fn new(name: String, version: String) -> Self {
66 Self {
67 info: ClientInfo { name, version },
68 capabilities: ClientCapabilities::default(),
69 config: ClientConfig::default(),
70 transport: Arc::new(Mutex::new(None)),
71 server_capabilities: Arc::new(RwLock::new(None)),
72 server_info: Arc::new(RwLock::new(None)),
73 request_counter: Arc::new(Mutex::new(0)),
74 connected: Arc::new(RwLock::new(false)),
75 }
76 }
77
78 pub fn with_config(name: String, version: String, config: ClientConfig) -> Self {
80 let mut client = Self::new(name, version);
81 client.config = config;
82 client
83 }
84
85 pub fn set_capabilities(&mut self, capabilities: ClientCapabilities) {
87 self.capabilities = capabilities;
88 }
89
90 pub fn info(&self) -> &ClientInfo {
92 &self.info
93 }
94
95 pub fn capabilities(&self) -> &ClientCapabilities {
97 &self.capabilities
98 }
99
100 pub fn config(&self) -> &ClientConfig {
102 &self.config
103 }
104
105 pub async fn server_capabilities(&self) -> Option<ServerCapabilities> {
107 let capabilities = self.server_capabilities.read().await;
108 capabilities.clone()
109 }
110
111 pub async fn server_info(&self) -> Option<ServerInfo> {
113 let info = self.server_info.read().await;
114 info.clone()
115 }
116
117 pub async fn is_connected(&self) -> bool {
119 let connected = self.connected.read().await;
120 *connected
121 }
122
123 pub async fn connect<T>(&mut self, transport: T) -> McpResult<InitializeResult>
129 where
130 T: Transport + 'static,
131 {
132 {
134 let mut transport_guard = self.transport.lock().await;
135 *transport_guard = Some(Box::new(transport));
136 }
137
138 let init_result = self.initialize().await?;
140
141 {
143 let mut connected = self.connected.write().await;
144 *connected = true;
145 }
146
147 Ok(init_result)
148 }
149
150 pub async fn disconnect(&self) -> McpResult<()> {
152 {
154 let mut transport_guard = self.transport.lock().await;
155 if let Some(transport) = transport_guard.as_mut() {
156 transport.close().await?;
157 }
158 *transport_guard = None;
159 }
160
161 {
163 let mut server_capabilities = self.server_capabilities.write().await;
164 *server_capabilities = None;
165 }
166 {
167 let mut server_info = self.server_info.write().await;
168 *server_info = None;
169 }
170
171 {
173 let mut connected = self.connected.write().await;
174 *connected = false;
175 }
176
177 Ok(())
178 }
179
180 async fn initialize(&self) -> McpResult<InitializeResult> {
182 let params = InitializeParams::new(
183 crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
184 self.capabilities.clone(),
185 self.info.clone(),
186 );
187
188 let request = JsonRpcRequest::new(
189 Value::from(self.next_request_id().await),
190 methods::INITIALIZE.to_string(),
191 Some(params),
192 )?;
193
194 let response = self.send_request(request).await?;
195
196 let result: InitializeResult = serde_json::from_value(
200 response
201 .result
202 .ok_or_else(|| McpError::Protocol("Missing initialize result".to_string()))?,
203 )?;
204
205 {
207 let mut server_capabilities = self.server_capabilities.write().await;
208 *server_capabilities = Some(result.capabilities.clone());
209 }
210 {
211 let mut server_info = self.server_info.write().await;
212 *server_info = Some(result.server_info.clone());
213 }
214
215 Ok(result)
216 }
217
218 pub async fn list_tools(&self, cursor: Option<String>) -> McpResult<ListToolsResult> {
224 self.ensure_connected().await?;
225
226 let params = ListToolsParams { cursor, meta: None };
227 let request = JsonRpcRequest::new(
228 Value::from(self.next_request_id().await),
229 methods::TOOLS_LIST.to_string(),
230 Some(params),
231 )?;
232
233 let response = self.send_request(request).await?;
234 self.handle_response(response)
235 }
236
237 pub async fn call_tool(
239 &self,
240 name: String,
241 arguments: Option<HashMap<String, Value>>,
242 ) -> McpResult<CallToolResult> {
243 self.ensure_connected().await?;
244
245 let params = if let Some(args) = arguments {
246 CallToolParams::new_with_arguments(name, args)
247 } else {
248 CallToolParams::new(name)
249 };
250
251 if self.config.validate_requests {
252 validate_call_tool_params(¶ms)?;
253 }
254
255 let request = JsonRpcRequest::new(
256 Value::from(self.next_request_id().await),
257 methods::TOOLS_CALL.to_string(),
258 Some(params),
259 )?;
260
261 let response = self.send_request(request).await?;
262 self.handle_response(response)
263 }
264
265 pub async fn list_resources(&self, cursor: Option<String>) -> McpResult<ListResourcesResult> {
271 self.ensure_connected().await?;
272
273 let params = ListResourcesParams { cursor, meta: None };
274 let request = JsonRpcRequest::new(
275 Value::from(self.next_request_id().await),
276 methods::RESOURCES_LIST.to_string(),
277 Some(params),
278 )?;
279
280 let response = self.send_request(request).await?;
281 self.handle_response(response)
282 }
283
284 pub async fn read_resource(&self, uri: String) -> McpResult<ReadResourceResult> {
286 self.ensure_connected().await?;
287
288 let params = ReadResourceParams::new(uri);
289
290 if self.config.validate_requests {
291 validate_read_resource_params(¶ms)?;
292 }
293
294 let request = JsonRpcRequest::new(
295 Value::from(self.next_request_id().await),
296 methods::RESOURCES_READ.to_string(),
297 Some(params),
298 )?;
299
300 let response = self.send_request(request).await?;
301 self.handle_response(response)
302 }
303
304 pub async fn subscribe_resource(&self, uri: String) -> McpResult<SubscribeResourceResult> {
306 self.ensure_connected().await?;
307
308 let params = SubscribeResourceParams { uri, meta: None };
309 let request = JsonRpcRequest::new(
310 Value::from(self.next_request_id().await),
311 methods::RESOURCES_SUBSCRIBE.to_string(),
312 Some(params),
313 )?;
314
315 let response = self.send_request(request).await?;
316 self.handle_response(response)
317 }
318
319 pub async fn unsubscribe_resource(&self, uri: String) -> McpResult<UnsubscribeResourceResult> {
321 self.ensure_connected().await?;
322
323 let params = UnsubscribeResourceParams { uri, meta: None };
324 let request = JsonRpcRequest::new(
325 Value::from(self.next_request_id().await),
326 methods::RESOURCES_UNSUBSCRIBE.to_string(),
327 Some(params),
328 )?;
329
330 let response = self.send_request(request).await?;
331 self.handle_response(response)
332 }
333
334 pub async fn list_prompts(&self, cursor: Option<String>) -> McpResult<ListPromptsResult> {
340 self.ensure_connected().await?;
341
342 let params = ListPromptsParams { cursor, meta: None };
343 let request = JsonRpcRequest::new(
344 Value::from(self.next_request_id().await),
345 methods::PROMPTS_LIST.to_string(),
346 Some(params),
347 )?;
348
349 let response = self.send_request(request).await?;
350 self.handle_response(response)
351 }
352
353 pub async fn get_prompt(
355 &self,
356 name: String,
357 arguments: Option<HashMap<String, String>>,
358 ) -> McpResult<GetPromptResult> {
359 self.ensure_connected().await?;
360
361 let params = if let Some(args) = arguments {
362 GetPromptParams::new_with_arguments(name, args)
363 } else {
364 GetPromptParams::new(name)
365 };
366
367 if self.config.validate_requests {
368 validate_get_prompt_params(¶ms)?;
369 }
370
371 let request = JsonRpcRequest::new(
372 Value::from(self.next_request_id().await),
373 methods::PROMPTS_GET.to_string(),
374 Some(params),
375 )?;
376
377 let response = self.send_request(request).await?;
378 self.handle_response(response)
379 }
380
381 pub async fn create_message(
387 &self,
388 params: CreateMessageParams,
389 ) -> McpResult<CreateMessageResult> {
390 self.ensure_connected().await?;
391
392 {
394 let server_capabilities = self.server_capabilities.read().await;
395 if let Some(capabilities) = server_capabilities.as_ref() {
396 if capabilities.sampling.is_none() {
397 return Err(McpError::Protocol(
398 "Server does not support sampling".to_string(),
399 ));
400 }
401 } else {
402 return Err(McpError::Protocol("Not connected to server".to_string()));
403 }
404 }
405
406 if self.config.validate_requests {
407 validate_create_message_params(¶ms)?;
408 }
409
410 let request = JsonRpcRequest::new(
411 Value::from(self.next_request_id().await),
412 methods::SAMPLING_CREATE_MESSAGE.to_string(),
413 Some(params),
414 )?;
415
416 let response = self.send_request(request).await?;
417 self.handle_response(response)
418 }
419
420 pub async fn ping(&self) -> McpResult<PingResult> {
426 self.ensure_connected().await?;
427
428 let request = JsonRpcRequest::new(
429 Value::from(self.next_request_id().await),
430 methods::PING.to_string(),
431 Some(PingParams { meta: None }),
432 )?;
433
434 let response = self.send_request(request).await?;
435 self.handle_response(response)
436 }
437
438 pub async fn set_logging_level(&self, level: LoggingLevel) -> McpResult<SetLoggingLevelResult> {
440 self.ensure_connected().await?;
441
442 let params = SetLoggingLevelParams { level, meta: None };
443 let request = JsonRpcRequest::new(
444 Value::from(self.next_request_id().await),
445 methods::LOGGING_SET_LEVEL.to_string(),
446 Some(params),
447 )?;
448
449 let response = self.send_request(request).await?;
450 self.handle_response(response)
451 }
452
453 pub async fn receive_notification(&self) -> McpResult<Option<JsonRpcNotification>> {
459 let mut transport_guard = self.transport.lock().await;
460 if let Some(transport) = transport_guard.as_mut() {
461 transport.receive_notification().await
462 } else {
463 Err(McpError::Transport("Not connected".to_string()))
464 }
465 }
466
467 async fn send_request(&self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
473 if self.config.validate_requests {
474 validate_jsonrpc_request(&request)?;
475 validate_mcp_request(&request.method, request.params.as_ref())?;
476 }
477
478 let mut transport_guard = self.transport.lock().await;
479 if let Some(transport) = transport_guard.as_mut() {
480 let response = transport.send_request(request).await?;
481
482 if self.config.validate_responses {
483 validate_jsonrpc_response(&response)?;
484 }
485
486 Ok(response)
487 } else {
488 Err(McpError::Transport("Not connected".to_string()))
489 }
490 }
491
492 fn handle_response<T>(&self, response: JsonRpcResponse) -> McpResult<T>
494 where
495 T: serde::de::DeserializeOwned,
496 {
497 let result = response
500 .result
501 .ok_or_else(|| McpError::Protocol("Missing result in response".to_string()))?;
502
503 serde_json::from_value(result).map_err(|e| McpError::Serialization(e.to_string()))
504 }
505
506 async fn ensure_connected(&self) -> McpResult<()> {
508 if !self.is_connected().await {
509 return Err(McpError::Connection("Not connected to server".to_string()));
510 }
511 Ok(())
512 }
513
514 async fn next_request_id(&self) -> u64 {
516 let mut counter = self.request_counter.lock().await;
517 *counter += 1;
518 *counter
519 }
520}
521
522pub struct McpClientBuilder {
524 name: String,
525 version: String,
526 capabilities: ClientCapabilities,
527 config: ClientConfig,
528}
529
530impl McpClientBuilder {
531 pub fn new(name: String, version: String) -> Self {
533 Self {
534 name,
535 version,
536 capabilities: ClientCapabilities::default(),
537 config: ClientConfig::default(),
538 }
539 }
540
541 pub fn capabilities(mut self, capabilities: ClientCapabilities) -> Self {
543 self.capabilities = capabilities;
544 self
545 }
546
547 pub fn config(mut self, config: ClientConfig) -> Self {
549 self.config = config;
550 self
551 }
552
553 pub fn request_timeout(mut self, timeout_ms: u64) -> Self {
555 self.config.request_timeout_ms = timeout_ms;
556 self
557 }
558
559 pub fn max_retries(mut self, retries: u32) -> Self {
561 self.config.max_retries = retries;
562 self
563 }
564
565 pub fn validate_requests(mut self, validate: bool) -> Self {
567 self.config.validate_requests = validate;
568 self
569 }
570
571 pub fn validate_responses(mut self, validate: bool) -> Self {
573 self.config.validate_responses = validate;
574 self
575 }
576
577 pub fn build(self) -> McpClient {
579 let mut client = McpClient::new(self.name, self.version);
580 client.set_capabilities(self.capabilities);
581 client.config = self.config;
582 client
583 }
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Transport double that replays a fixed queue of responses, one per
    /// request, and fails once the queue is exhausted.
    struct MockTransport {
        responses: Vec<JsonRpcResponse>,
        current: usize,
    }

    impl MockTransport {
        fn new(responses: Vec<JsonRpcResponse>) -> Self {
            Self {
                responses,
                current: 0,
            }
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn send_request(&mut self, _request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
            if self.current < self.responses.len() {
                let response = self.responses[self.current].clone();
                self.current += 1;
                Ok(response)
            } else {
                Err(McpError::Transport("No more responses".to_string()))
            }
        }

        async fn send_notification(&mut self, _notification: JsonRpcNotification) -> McpResult<()> {
            Ok(())
        }

        async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
            Ok(None)
        }

        async fn close(&mut self) -> McpResult<()> {
            Ok(())
        }
    }

    /// Successful `initialize` response for request id 1 from a server named `test-server`.
    fn init_response() -> JsonRpcResponse {
        let init_result = InitializeResult::new(
            crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
            ServerCapabilities::default(),
            ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
            },
        );
        JsonRpcResponse::success(Value::from(1), init_result).unwrap()
    }

    /// Fresh, disconnected client named `test-client` v1.0.0.
    fn test_client() -> McpClient {
        McpClient::new("test-client".to_string(), "1.0.0".to_string())
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = test_client();
        assert_eq!(client.info().name, "test-client");
        assert_eq!(client.info().version, "1.0.0");
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_builder() {
        let client = McpClientBuilder::new("test-client".to_string(), "1.0.0".to_string())
            .request_timeout(5000)
            .max_retries(5)
            .validate_requests(false)
            .build();

        assert_eq!(client.config().request_timeout_ms, 5000);
        assert_eq!(client.config().max_retries, 5);
        assert!(!client.config().validate_requests);
        assert!(client.config().validate_responses);
    }

    #[tokio::test]
    async fn test_mock_connection() {
        let mut client = test_client();
        let result = client
            .connect(MockTransport::new(vec![init_response()]))
            .await
            .unwrap();

        assert_eq!(result.server_info.name, "test-server");
        assert!(client.is_connected().await);
    }

    #[tokio::test]
    async fn test_disconnect() {
        let mut client = test_client();
        client
            .connect(MockTransport::new(vec![init_response()]))
            .await
            .unwrap();

        assert!(client.is_connected().await);

        client.disconnect().await.unwrap();
        assert!(!client.is_connected().await);
        assert!(client.server_info().await.is_none());
        assert!(client.server_capabilities().await.is_none());
    }

    #[tokio::test]
    async fn test_requests_require_connection() {
        let client = test_client();
        assert!(matches!(
            client.list_tools(None).await,
            Err(McpError::Connection(_))
        ));
    }

    #[tokio::test]
    async fn test_failed_initialize_leaves_client_disconnected() {
        let mut client = test_client();
        // No queued responses: the initialize request fails at the transport.
        let result = client.connect(MockTransport::new(Vec::new())).await;

        assert!(result.is_err());
        assert!(!client.is_connected().await);
        assert!(client.server_info().await.is_none());
    }
}