1use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14
15use async_trait::async_trait;
16use tokio::sync::{mpsc, oneshot, RwLock};
17use tokio::time::{sleep, Instant};
18
19use crate::error::{McpError, McpResult, ProtocolError};
20use crate::messages::{
21 Capabilities, Implementation, InitializeRequest, InitializeResponse, InitializedNotification,
22 JsonRpcId, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
23 ProgressNotification, PromptListChangedNotification, ProtocolVersion,
24 ResourceListChangedNotification, ResourceUpdatedNotification, ToolListChangedNotification,
25};
26use crate::transport::{factory::TransportFactory, Transport, TransportConfig};
27
28use tracing::{debug, info, warn};
29
/// Settings for [`McpClient`]: timeouts, retry policy and notification handling.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Timeout used for a request when the caller does not pass an explicit one
    /// (30 seconds in [`ClientConfig::default`]).
    pub request_timeout: Duration,

    /// Timeout applied to the `initialize` request sent during [`McpClient::connect`]
    /// (10 seconds by default).
    pub init_timeout: Duration,

    /// Number of retries after the first failed attempt, i.e. up to `max_retries + 1`
    /// attempts in total (3 by default).
    pub max_retries: u32,

    /// Base of the exponential backoff: after failed attempt `k` (0-based) the client waits
    /// `retry_base_delay * 2^k` before retrying (1 second by default).
    pub retry_base_delay: Duration,

    /// Whether notifications should be handled automatically (`true` by default).
    /// NOTE(review): not read anywhere in this file — confirm it is honored elsewhere.
    pub auto_handle_notifications: bool,

    /// Intended buffer size for incoming messages (1000 by default).
    /// NOTE(review): not read in this file, and the internal channel is unbounded — confirm.
    pub message_buffer_size: usize,
}
51
52impl Default for ClientConfig {
53 fn default() -> Self {
54 Self {
55 request_timeout: Duration::from_secs(30),
56 init_timeout: Duration::from_secs(10),
57 max_retries: 3,
58 retry_base_delay: Duration::from_secs(1),
59 auto_handle_notifications: true,
60 message_buffer_size: 1000,
61 }
62 }
63}
64
/// Lifecycle state of an [`McpClient`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientState {
    /// No active session: the state of a freshly created client and after `disconnect`.
    Disconnected,
    /// The transport connection is being established.
    Connecting,
    /// The transport is connected and the MCP `initialize` handshake is in progress.
    Initializing,
    /// The handshake completed; requests and notifications may be sent.
    Ready,
    /// A connection attempt failed; the payload is a human-readable reason.
    Error(String),
}
79
/// Details about the connected server, taken from its `initialize` response.
#[derive(Debug, Clone)]
pub struct ServerInfo {
    /// The server's implementation details (e.g. its `name`) from the response.
    pub implementation: Implementation,
    /// Protocol version reported by the server in its response.
    pub protocol_version: ProtocolVersion,
    /// Capabilities advertised by the server.
    pub capabilities: Capabilities,
    /// When the handshake finished on the client side.
    pub connected_at: Instant,
}
92
/// Traffic counters for an [`McpClient`], returned as a snapshot by `McpClient::stats`.
///
/// All counters start at zero and `last_activity` starts as `None` (via `Default`).
#[derive(Debug, Clone, Default)]
pub struct ClientStats {
    /// Number of requests sent to the server.
    pub requests_sent: u64,
    /// Number of responses received from the server.
    pub responses_received: u64,
    /// Number of notifications sent to the server.
    pub notifications_sent: u64,
    /// Number of notifications received from the server.
    pub notifications_received: u64,
    /// Number of requests that still failed after all retries were exhausted.
    pub errors: u64,
    /// Number of request retries.
    pub retries: u64,
    /// Number of connection attempts.
    pub connection_attempts: u64,
    /// Time of the most recent recorded activity, if any.
    pub last_activity: Option<Instant>,
}
113
114#[async_trait]
116pub trait NotificationHandler: Send + Sync {
117 async fn handle_progress(&self, notification: ProgressNotification) -> McpResult<()> {
119 debug!("Received progress notification: {:?}", notification);
120 Ok(())
121 }
122
123 async fn handle_resource_updated(
125 &self,
126 notification: ResourceUpdatedNotification,
127 ) -> McpResult<()> {
128 debug!("Resource updated: {:?}", notification);
129 Ok(())
130 }
131
132 async fn handle_resource_list_changed(
134 &self,
135 notification: ResourceListChangedNotification,
136 ) -> McpResult<()> {
137 debug!("Resource list changed: {:?}", notification);
138 Ok(())
139 }
140
141 async fn handle_tool_list_changed(
143 &self,
144 notification: ToolListChangedNotification,
145 ) -> McpResult<()> {
146 debug!("Tool list changed: {:?}", notification);
147 Ok(())
148 }
149
150 async fn handle_prompt_list_changed(
152 &self,
153 notification: PromptListChangedNotification,
154 ) -> McpResult<()> {
155 debug!("Prompt list changed: {:?}", notification);
156 Ok(())
157 }
158}
159
/// A [`NotificationHandler`] that keeps every default method: each notification is logged
/// at `debug` level and acknowledged with `Ok(())`.
#[derive(Debug, Default)]
pub struct DefaultNotificationHandler;

// No overrides: all callbacks use the trait's logging-only defaults.
#[async_trait]
impl NotificationHandler for DefaultNotificationHandler {}
166
/// Asynchronous MCP client speaking JSON-RPC over a pluggable [`Transport`].
///
/// Create one with `McpClient::new`, `McpClient::with_defaults` or [`McpClientBuilder`],
/// then call `connect`; requests and notifications are rejected until the client is
/// [`ClientState::Ready`].
pub struct McpClient {
    /// Transport created by `TransportFactory` from the supplied `TransportConfig`.
    transport: Box<dyn Transport>,
    /// Timeouts, retry policy and related settings.
    config: ClientConfig,
    /// Current lifecycle state.
    state: RwLock<ClientState>,
    /// Server details from the last successful handshake; `None` when not connected.
    server_info: RwLock<Option<ServerInfo>>,
    /// Counters, shared (via `Arc`) with the background message-processing task.
    stats: Arc<RwLock<ClientStats>>,
    /// Source of the numeric suffix in `req_<n>` request IDs.
    request_counter: AtomicU64,
    /// Response waiters keyed by request-ID string; the processing task completes them.
    pending_requests: Arc<RwLock<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
    /// User callbacks for server notifications, shared with the processing task.
    notification_handler: Arc<dyn NotificationHandler>,
    /// Sending half of the processing task's channel; the task's receive loop ends once
    /// every sender (only this one is created here) has been dropped.
    _message_sender: Option<mpsc::UnboundedSender<JsonRpcMessage>>,
}
186
187impl McpClient {
188 pub async fn new(
216 transport_config: TransportConfig,
217 client_config: ClientConfig,
218 notification_handler: Box<dyn NotificationHandler>,
219 ) -> McpResult<Self> {
220 let transport = TransportFactory::create(transport_config).await?;
221
222 Ok(Self {
223 transport,
224 config: client_config,
225 state: RwLock::new(ClientState::Disconnected),
226 server_info: RwLock::new(None),
227 stats: Arc::new(RwLock::new(ClientStats::default())),
228 request_counter: AtomicU64::new(1),
229 pending_requests: Arc::new(RwLock::new(HashMap::new())),
230 notification_handler: notification_handler.into(),
231 _message_sender: None,
232 })
233 }
234
235 pub async fn with_defaults(transport_config: TransportConfig) -> McpResult<Self> {
245 Self::new(
246 transport_config,
247 ClientConfig::default(),
248 Box::new(DefaultNotificationHandler),
249 )
250 .await
251 }
252
253 pub async fn state(&self) -> ClientState {
255 self.state.read().await.clone()
256 }
257
258 pub async fn server_info(&self) -> Option<ServerInfo> {
260 self.server_info.read().await.clone()
261 }
262
263 pub async fn stats(&self) -> ClientStats {
265 self.stats.read().await.clone()
266 }
267
268 pub async fn is_ready(&self) -> bool {
270 matches!(self.state().await, ClientState::Ready)
271 }
272
273 pub fn transport_info(&self) -> crate::transport::TransportInfo {
275 self.transport.get_info()
276 }
277
278 pub async fn connect(&mut self, client_info: Implementation) -> McpResult<ServerInfo> {
313 info!("Connecting MCP client to server");
314
315 *self.state.write().await = ClientState::Connecting;
317
318 self.transport.connect().await.map_err(|e| {
320 let error = format!("Transport connection failed: {}", e);
321 self.set_error_state(error.clone());
322 McpError::Protocol(ProtocolError::InitializationFailed { reason: error })
323 })?;
324
325 self.start_message_processing().await?;
327
328 let server_info = self.perform_initialization(client_info).await?;
330
331 *self.state.write().await = ClientState::Ready;
333 *self.server_info.write().await = Some(server_info.clone());
334
335 info!(
336 "MCP client connected successfully to {}",
337 server_info.implementation.name
338 );
339 Ok(server_info)
340 }
341
342 pub async fn disconnect(&mut self) -> McpResult<()> {
344 info!("Disconnecting MCP client");
345
346 *self.state.write().await = ClientState::Disconnected;
348
349 *self.server_info.write().await = None;
351
352 self.pending_requests.write().await.clear();
354
355 self.transport.disconnect().await?;
357
358 info!("MCP client disconnected");
359 Ok(())
360 }
361
362 pub async fn send_notification<T>(&mut self, method: &str, params: T) -> McpResult<()>
364 where
365 T: serde::Serialize,
366 {
367 if !self.is_ready().await {
368 return Err(McpError::Protocol(ProtocolError::NotInitialized {
369 reason: "Client not ready for notifications".to_string(),
370 }));
371 }
372
373 let notification = JsonRpcNotification {
374 jsonrpc: "2.0".to_string(),
375 method: method.to_string(),
376 params: Some(serde_json::to_value(params)?),
377 };
378
379 self.transport.send_notification(notification).await?;
380 self.stats.write().await.notifications_sent += 1;
381 Ok(())
382 }
383
384 pub async fn send_request<T>(&mut self, method: &str, params: T) -> McpResult<JsonRpcResponse>
386 where
387 T: serde::Serialize,
388 {
389 if !self.is_ready().await {
390 return Err(McpError::Protocol(ProtocolError::NotInitialized {
391 reason: "Client not ready for requests".to_string(),
392 }));
393 }
394
395 self.send_request_with_timeout(method, params, None).await
396 }
397
398 fn set_error_state(&self, error: String) {
401 if let Ok(mut state) = self.state.try_write() {
402 *state = ClientState::Error(error);
403 }
404 }
405
406 fn generate_request_id(&self) -> String {
407 let counter = self.request_counter.fetch_add(1, Ordering::SeqCst);
408 format!("req_{}", counter)
409 }
410
411 async fn start_message_processing(&mut self) -> McpResult<()> {
412 tracing::info!("Starting message processing task");
413 let (sender, mut receiver) = mpsc::unbounded_channel();
414 self._message_sender = Some(sender);
415
416 let pending_requests = Arc::clone(&self.pending_requests);
418 let stats = Arc::clone(&self.stats);
419 let notification_handler = Arc::clone(&self.notification_handler);
420
421 tokio::spawn(async move {
423 tracing::debug!("Message processing task started, waiting for messages");
424 while let Some(message) = receiver.recv().await {
425 tracing::debug!("Received message in processing task: {:?}", message);
426 match message {
427 JsonRpcMessage::Response(response) => {
428 tracing::debug!("Processing response with ID: {}", response.id);
429 if let Some(sender) = pending_requests
431 .write()
432 .await
433 .remove(&response.id.to_string())
434 {
435 tracing::debug!(
436 "Found pending request for ID {}, sending response",
437 response.id
438 );
439 let _ = sender.send(response);
440 stats.write().await.responses_received += 1;
441 } else {
442 tracing::warn!(
443 "Received response for unknown request ID: {}",
444 response.id
445 );
446 }
447 }
448 JsonRpcMessage::Notification(notification) => {
449 tracing::debug!("Processing notification: {}", notification.method);
450 Self::handle_notification(&*notification_handler, notification).await;
452 stats.write().await.notifications_received += 1;
453 }
454 JsonRpcMessage::Request(_) => {
455 tracing::warn!("Received unexpected server-to-client request");
457 }
458 }
459 }
460 });
461
462 Ok(())
463 }
464
465 async fn handle_notification(
466 handler: &dyn NotificationHandler,
467 notification: JsonRpcNotification,
468 ) {
469 match notification.method.as_str() {
470 "notifications/progress" => {
471 if let Some(params) = notification.params {
472 if let Ok(progress) = serde_json::from_value::<ProgressNotification>(params) {
473 let _ = handler.handle_progress(progress).await;
474 }
475 }
476 }
477 "notifications/resources/updated" => {
478 if let Some(params) = notification.params {
479 if let Ok(resource_updated) =
480 serde_json::from_value::<ResourceUpdatedNotification>(params)
481 {
482 let _ = handler.handle_resource_updated(resource_updated).await;
483 }
484 }
485 }
486 "notifications/resources/list_changed" => {
487 if let Some(params) = notification.params {
488 if let Ok(list_changed) =
489 serde_json::from_value::<ResourceListChangedNotification>(params)
490 {
491 let _ = handler.handle_resource_list_changed(list_changed).await;
492 }
493 }
494 }
495 "notifications/tools/list_changed" => {
496 if let Some(params) = notification.params {
497 if let Ok(list_changed) =
498 serde_json::from_value::<ToolListChangedNotification>(params)
499 {
500 let _ = handler.handle_tool_list_changed(list_changed).await;
501 }
502 }
503 }
504 "notifications/prompts/list_changed" => {
505 if let Some(params) = notification.params {
506 if let Ok(list_changed) =
507 serde_json::from_value::<PromptListChangedNotification>(params)
508 {
509 let _ = handler.handle_prompt_list_changed(list_changed).await;
510 }
511 }
512 }
513 _ => {
514 warn!("Unknown notification method: {}", notification.method);
515 }
516 }
517 }
518
519 async fn perform_initialization(
520 &mut self,
521 client_info: Implementation,
522 ) -> McpResult<ServerInfo> {
523 *self.state.write().await = ClientState::Initializing;
524 tracing::info!("Starting MCP protocol initialization");
525
526 let capabilities = Capabilities {
528 standard: crate::messages::StandardCapabilities {
529 tools: Some(crate::messages::ToolCapabilities {
530 list_changed: Some(true),
531 }),
532 resources: Some(crate::messages::ResourceCapabilities {
533 subscribe: Some(true),
534 list_changed: Some(true),
535 }),
536 prompts: Some(crate::messages::PromptCapabilities {
537 list_changed: Some(true),
538 }),
539 ..Default::default()
540 },
541 ..Default::default()
542 };
543
544 let request = InitializeRequest {
545 protocol_version: ProtocolVersion::default(),
546 capabilities,
547 client_info,
548 };
549
550 tracing::debug!("Sending initialize request: {:?}", request);
551
552 let response = self
554 .send_initialization_request("initialize", request, Some(self.config.init_timeout))
555 .await?;
556
557 tracing::debug!("Received initialize response: {:?}", response);
559 let init_response: InitializeResponse = match response.result {
560 Some(result) => {
561 tracing::debug!("Parsing initialize response result: {:?}", result);
562 serde_json::from_value(result)?
563 }
564 None => {
565 tracing::error!("Initialize response missing result field");
566 return Err(McpError::Protocol(ProtocolError::InitializationFailed {
567 reason: "Missing result in initialize response".to_string(),
568 }));
569 }
570 };
571
572 tracing::info!(
573 "Successfully parsed initialize response from server: {}",
574 init_response.server_info.name
575 );
576
577 let initialized = InitializedNotification {
579 metadata: HashMap::new(), };
581 tracing::debug!("Sending initialized notification");
582 self.send_initialized_notification("initialized", initialized)
583 .await?;
584
585 let server_info = ServerInfo {
587 implementation: init_response.server_info,
588 protocol_version: init_response.protocol_version,
589 capabilities: init_response.capabilities,
590 connected_at: Instant::now(),
591 };
592
593 Ok(server_info)
594 }
595
596 async fn send_initialization_request<T>(
598 &mut self,
599 method: &str,
600 params: T,
601 timeout_duration: Option<Duration>,
602 ) -> McpResult<JsonRpcResponse>
603 where
604 T: serde::Serialize,
605 {
606 tracing::debug!("Sending initialization request: {}", method);
607 let request_id = self.generate_request_id();
608 let request = JsonRpcRequest {
609 jsonrpc: "2.0".to_string(),
610 id: JsonRpcId::String(request_id.clone()),
611 method: method.to_string(),
612 params: Some(serde_json::to_value(params)?),
613 };
614
615 let timeout_val = timeout_duration.unwrap_or(self.config.request_timeout);
616
617 self.send_request_with_retries(request, timeout_val).await
619 }
620
621 async fn send_initialized_notification<T>(&mut self, method: &str, params: T) -> McpResult<()>
623 where
624 T: serde::Serialize,
625 {
626 tracing::debug!("Sending initialization notification: {}", method);
627
628 let notification = JsonRpcNotification {
629 jsonrpc: "2.0".to_string(),
630 method: method.to_string(),
631 params: Some(serde_json::to_value(params)?),
632 };
633
634 self.transport.send_notification(notification).await?;
635 self.stats.write().await.notifications_sent += 1;
636 tracing::debug!("Initialization notification sent successfully");
637 Ok(())
638 }
639
640 async fn send_request_with_timeout<T>(
641 &mut self,
642 method: &str,
643 params: T,
644 timeout_duration: Option<Duration>,
645 ) -> McpResult<JsonRpcResponse>
646 where
647 T: serde::Serialize,
648 {
649 let request_id = self.generate_request_id();
650 let request = JsonRpcRequest {
651 jsonrpc: "2.0".to_string(),
652 id: JsonRpcId::String(request_id.clone()),
653 method: method.to_string(),
654 params: Some(serde_json::to_value(params)?),
655 };
656
657 let timeout_val = timeout_duration.unwrap_or(self.config.request_timeout);
658
659 self.send_request_with_retries(request, timeout_val).await
661 }
662
663 async fn send_request_with_retries(
664 &mut self,
665 request: JsonRpcRequest,
666 timeout_duration: Duration,
667 ) -> McpResult<JsonRpcResponse> {
668 let mut last_error = None;
669
670 for attempt in 0..=self.config.max_retries {
671 match self
672 .send_single_request(request.clone(), timeout_duration)
673 .await
674 {
675 Ok(response) => {
676 if attempt > 0 {
677 self.stats.write().await.retries += attempt as u64;
678 }
679 return Ok(response);
680 }
681 Err(e) => {
682 last_error = Some(e);
683
684 if attempt < self.config.max_retries {
685 let delay = self.config.retry_base_delay * 2_u32.pow(attempt);
686 debug!(
687 "Request failed, retrying in {:?} (attempt {} of {})",
688 delay,
689 attempt + 1,
690 self.config.max_retries + 1
691 );
692 sleep(delay).await;
693 }
694 }
695 }
696 }
697
698 self.stats.write().await.errors += 1;
699 Err(last_error.unwrap())
700 }
701
702 async fn send_single_request(
703 &mut self,
704 request: JsonRpcRequest,
705 timeout_duration: Duration,
706 ) -> McpResult<JsonRpcResponse> {
707 let request_id = request.id.to_string();
708 tracing::debug!("Sending single request with ID: {}", request_id);
709
710 let response = self
712 .transport
713 .send_request(request, Some(timeout_duration))
714 .await?;
715 self.stats.write().await.requests_sent += 1;
716
717 tracing::debug!("Received response for request ID: {}", response.id);
718 Ok(response)
719 }
720}
721
/// Fluent builder for [`McpClient`]. A transport configuration is mandatory; everything
/// else falls back to defaults.
pub struct McpClientBuilder {
    /// Transport to use; `build` fails with `InvalidConfig` if this was never set.
    transport_config: Option<TransportConfig>,
    /// Client settings, starting from [`ClientConfig::default`].
    client_config: ClientConfig,
    /// Custom notification handler; `build` uses [`DefaultNotificationHandler`] when unset.
    notification_handler: Option<Box<dyn NotificationHandler>>,
}
728
729impl McpClientBuilder {
730 pub fn new() -> Self {
732 Self {
733 transport_config: None,
734 client_config: ClientConfig::default(),
735 notification_handler: None,
736 }
737 }
738
739 pub fn transport(mut self, config: TransportConfig) -> Self {
741 self.transport_config = Some(config);
742 self
743 }
744
745 pub fn config(mut self, config: ClientConfig) -> Self {
747 self.client_config = config;
748 self
749 }
750
751 pub fn notification_handler(mut self, handler: Box<dyn NotificationHandler>) -> Self {
753 self.notification_handler = Some(handler);
754 self
755 }
756
757 pub fn request_timeout(mut self, timeout: Duration) -> Self {
759 self.client_config.request_timeout = timeout;
760 self
761 }
762
763 pub fn init_timeout(mut self, timeout: Duration) -> Self {
765 self.client_config.init_timeout = timeout;
766 self
767 }
768
769 pub fn max_retries(mut self, retries: u32) -> Self {
771 self.client_config.max_retries = retries;
772 self
773 }
774
775 pub async fn build(self) -> McpResult<McpClient> {
777 let transport_config = self.transport_config.ok_or_else(|| {
778 McpError::Protocol(ProtocolError::InvalidConfig {
779 reason: "Transport configuration is required".to_string(),
780 })
781 })?;
782
783 let notification_handler = self
784 .notification_handler
785 .unwrap_or_else(|| Box::new(DefaultNotificationHandler));
786
787 McpClient::new(transport_config, self.client_config, notification_handler).await
788 }
789}
790
791impl Default for McpClientBuilder {
792 fn default() -> Self {
793 Self::new()
794 }
795}
796
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::TransportConfig;

    /// Stdio transport config used throughout; these tests never call `connect` on it.
    fn echo_transport() -> TransportConfig {
        TransportConfig::stdio("echo", &[] as &[String])
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client_config = ClientConfig::default();
        let handler = Box::new(DefaultNotificationHandler);
        let client = McpClient::new(echo_transport(), client_config, handler)
            .await
            .unwrap();

        assert_eq!(client.state().await, ClientState::Disconnected);
    }

    #[tokio::test]
    async fn test_client_with_defaults() {
        let client = McpClient::with_defaults(echo_transport()).await.unwrap();

        assert_eq!(client.state().await, ClientState::Disconnected);
        assert!(!client.is_ready().await);
        assert!(client.server_info().await.is_none());
    }

    #[tokio::test]
    async fn test_requests_rejected_before_connect() {
        let mut client = McpClient::with_defaults(echo_transport()).await.unwrap();

        assert!(client
            .send_request("ping", serde_json::json!({}))
            .await
            .is_err());
        assert!(client
            .send_notification("notifications/test", serde_json::json!({}))
            .await
            .is_err());

        // The readiness check runs before anything is sent, so no traffic is recorded.
        let stats = client.stats().await;
        assert_eq!(stats.requests_sent, 0);
        assert_eq!(stats.notifications_sent, 0);
    }

    #[test]
    fn test_client_config_defaults() {
        let config = ClientConfig::default();
        assert_eq!(config.request_timeout, Duration::from_secs(30));
        assert_eq!(config.init_timeout, Duration::from_secs(10));
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.retry_base_delay, Duration::from_secs(1));
        assert!(config.auto_handle_notifications);
        assert_eq!(config.message_buffer_size, 1000);
    }

    #[test]
    fn test_builder_setters_update_config() {
        let builder = McpClientBuilder::new()
            .request_timeout(Duration::from_secs(5))
            .init_timeout(Duration::from_secs(2))
            .max_retries(7);

        assert_eq!(builder.client_config.request_timeout, Duration::from_secs(5));
        assert_eq!(builder.client_config.init_timeout, Duration::from_secs(2));
        assert_eq!(builder.client_config.max_retries, 7);
        assert!(builder.transport_config.is_none());
    }

    #[tokio::test]
    async fn test_builder_requires_transport() {
        let result = McpClientBuilder::new().build().await;
        assert!(result.is_err());
    }
}