1use futures::channel::oneshot;
14use mcpkit_core::capability::{
15 is_version_supported, ClientCapabilities, ClientInfo, InitializeRequest, InitializeResult,
16 ServerCapabilities, ServerInfo, PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS,
17};
18use mcpkit_core::error::{HandshakeDetails, JsonRpcError, McpError, TransportContext, TransportDetails, TransportErrorKind};
19use mcpkit_core::protocol::{Message, Notification, Request, RequestId, Response};
20use mcpkit_core::types::{
21 CallToolRequest, CallToolResult, CreateMessageRequest, ElicitRequest,
22 GetPromptRequest, GetPromptResult, ListPromptsResult, ListResourcesResult,
23 ListResourceTemplatesResult, ListToolsResult, Prompt, ReadResourceRequest,
24 ReadResourceResult, Resource, ResourceContents, ResourceTemplate, Tool,
25};
26use mcpkit_transport::Transport;
27use std::collections::HashMap;
28use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
29use std::sync::Arc;
30use tracing::{debug, error, info, trace, warn};
31
32use async_lock::RwLock;
34
35#[cfg(feature = "tokio-runtime")]
37use tokio::sync::mpsc;
38
39use crate::handler::ClientHandler;
40
41pub struct Client<T: Transport, H: ClientHandler = crate::handler::NoOpHandler> {
72 transport: Arc<T>,
74 server_info: ServerInfo,
76 server_caps: ServerCapabilities,
78 client_info: ClientInfo,
80 client_caps: ClientCapabilities,
82 next_id: AtomicU64,
84 pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
86 instructions: Option<String>,
88 handler: Arc<H>,
90 outgoing_tx: mpsc::Sender<Message>,
92 running: Arc<AtomicBool>,
94 _background_handle: Option<tokio::task::JoinHandle<()>>,
96}
97
98impl<T: Transport + 'static> Client<T, crate::handler::NoOpHandler> {
99 pub(crate) fn new(
101 transport: T,
102 init_result: InitializeResult,
103 client_info: ClientInfo,
104 client_caps: ClientCapabilities,
105 ) -> Self {
106 Self::with_handler(transport, init_result, client_info, client_caps, crate::handler::NoOpHandler)
107 }
108}
109
110impl<T: Transport + 'static, H: ClientHandler + 'static> Client<T, H> {
111 pub(crate) fn with_handler(
113 transport: T,
114 init_result: InitializeResult,
115 client_info: ClientInfo,
116 client_caps: ClientCapabilities,
117 handler: H,
118 ) -> Self {
119 let transport = Arc::new(transport);
120 let pending = Arc::new(RwLock::new(HashMap::new()));
121 let handler = Arc::new(handler);
122 let running = Arc::new(AtomicBool::new(true));
123
124 let (outgoing_tx, outgoing_rx) = mpsc::channel::<Message>(256);
126
127 let background_handle = Self::spawn_message_router(
129 Arc::clone(&transport),
130 Arc::clone(&pending),
131 Arc::clone(&handler),
132 Arc::clone(&running),
133 outgoing_rx,
134 );
135
136 let handler_clone = Arc::clone(&handler);
138 tokio::spawn(async move {
139 handler_clone.on_connected().await;
140 });
141
142 Self {
143 transport,
144 server_info: init_result.server_info,
145 server_caps: init_result.capabilities,
146 client_info,
147 client_caps,
148 next_id: AtomicU64::new(1),
149 pending,
150 instructions: init_result.instructions,
151 handler,
152 outgoing_tx,
153 running,
154 _background_handle: Some(background_handle),
155 }
156 }
157
158 fn spawn_message_router(
166 transport: Arc<T>,
167 pending: Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
168 handler: Arc<H>,
169 running: Arc<AtomicBool>,
170 mut outgoing_rx: mpsc::Receiver<Message>,
171 ) -> tokio::task::JoinHandle<()> {
172 tokio::spawn(async move {
173 debug!("Starting client message router");
174
175 loop {
176 if !running.load(Ordering::SeqCst) {
177 debug!("Message router stopping (client closed)");
178 break;
179 }
180
181 tokio::select! {
182 Some(msg) = outgoing_rx.recv() => {
184 if let Err(e) = transport.send(msg).await {
185 error!(?e, "Failed to send message");
186 }
187 }
188
189 result = transport.recv() => {
191 match result {
192 Ok(Some(message)) => {
193 Self::handle_incoming_message(
194 message,
195 &pending,
196 &handler,
197 &transport,
198 ).await;
199 }
200 Ok(None) => {
201 info!("Connection closed by server");
202 running.store(false, Ordering::SeqCst);
203 handler.on_disconnected().await;
204 break;
205 }
206 Err(e) => {
207 error!(?e, "Transport error in message router");
208 running.store(false, Ordering::SeqCst);
209 handler.on_disconnected().await;
210 break;
211 }
212 }
213 }
214 }
215 }
216
217 debug!("Message router stopped");
218 })
219 }
220
221 async fn handle_incoming_message(
223 message: Message,
224 pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
225 handler: &Arc<H>,
226 transport: &Arc<T>,
227 ) {
228 match message {
229 Message::Response(response) => {
230 Self::route_response(response, pending).await;
231 }
232 Message::Request(request) => {
233 Self::handle_server_request(request, handler, transport).await;
234 }
235 Message::Notification(notification) => {
236 Self::handle_notification(notification, handler).await;
237 }
238 }
239 }
240
241 async fn route_response(
243 response: Response,
244 pending: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<Response>>>>,
245 ) {
246 let sender = {
247 let mut pending_guard = pending.write().await;
248 pending_guard.remove(&response.id)
249 };
250
251 if let Some(sender) = sender {
252 trace!(?response.id, "Routing response to pending request");
253 if sender.send(response).is_err() {
254 warn!("Pending request receiver dropped");
255 }
256 } else {
257 warn!(?response.id, "Received response for unknown request");
258 }
259 }
260
261 async fn handle_server_request(
263 request: Request,
264 handler: &Arc<H>,
265 transport: &Arc<T>,
266 ) {
267 trace!(method = %request.method, "Handling server request");
268
269 let response = match request.method.as_ref() {
270 "sampling/createMessage" => {
271 Self::handle_sampling_request(&request, handler).await
272 }
273 "elicitation/elicit" => {
274 Self::handle_elicitation_request(&request, handler).await
275 }
276 "roots/list" => {
277 Self::handle_roots_request(&request, handler).await
278 }
279 "ping" => {
280 Response::success(request.id.clone(), serde_json::json!({}))
282 }
283 _ => {
284 warn!(method = %request.method, "Unknown server request method");
285 Response::error(
286 request.id.clone(),
287 JsonRpcError::method_not_found(&format!("Unknown method: {}", request.method)),
288 )
289 }
290 };
291
292 if let Err(e) = transport.send(Message::Response(response)).await {
294 error!(?e, "Failed to send response to server request");
295 }
296 }
297
298 async fn handle_sampling_request(request: &Request, handler: &Arc<H>) -> Response {
300 let params = match &request.params {
301 Some(p) => match serde_json::from_value::<CreateMessageRequest>(p.clone()) {
302 Ok(req) => req,
303 Err(e) => {
304 return Response::error(
305 request.id.clone(),
306 JsonRpcError::invalid_params(format!("Invalid params: {e}")),
307 );
308 }
309 },
310 None => {
311 return Response::error(
312 request.id.clone(),
313 JsonRpcError::invalid_params("Missing params for sampling/createMessage"),
314 );
315 }
316 };
317
318 match handler.create_message(params).await {
319 Ok(result) => {
320 match serde_json::to_value(result) {
321 Ok(value) => Response::success(request.id.clone(), value),
322 Err(e) => Response::error(
323 request.id.clone(),
324 JsonRpcError::internal_error(format!("Serialization error: {e}")),
325 ),
326 }
327 }
328 Err(e) => Response::error(
329 request.id.clone(),
330 JsonRpcError::internal_error(e.to_string()),
331 ),
332 }
333 }
334
335 async fn handle_elicitation_request(request: &Request, handler: &Arc<H>) -> Response {
337 let params = match &request.params {
338 Some(p) => match serde_json::from_value::<ElicitRequest>(p.clone()) {
339 Ok(req) => req,
340 Err(e) => {
341 return Response::error(
342 request.id.clone(),
343 JsonRpcError::invalid_params(format!("Invalid params: {e}")),
344 );
345 }
346 },
347 None => {
348 return Response::error(
349 request.id.clone(),
350 JsonRpcError::invalid_params("Missing params for elicitation/elicit"),
351 );
352 }
353 };
354
355 match handler.elicit(params).await {
356 Ok(result) => {
357 match serde_json::to_value(result) {
358 Ok(value) => Response::success(request.id.clone(), value),
359 Err(e) => Response::error(
360 request.id.clone(),
361 JsonRpcError::internal_error(format!("Serialization error: {e}")),
362 ),
363 }
364 }
365 Err(e) => Response::error(
366 request.id.clone(),
367 JsonRpcError::internal_error(e.to_string()),
368 ),
369 }
370 }
371
372 async fn handle_roots_request(request: &Request, handler: &Arc<H>) -> Response {
374 match handler.list_roots().await {
375 Ok(roots) => {
376 let roots_json: Vec<serde_json::Value> = roots
377 .into_iter()
378 .map(|r| {
379 serde_json::json!({
380 "uri": r.uri,
381 "name": r.name
382 })
383 })
384 .collect();
385 Response::success(request.id.clone(), serde_json::json!({ "roots": roots_json }))
386 }
387 Err(e) => Response::error(
388 request.id.clone(),
389 JsonRpcError::internal_error(&e.to_string()),
390 ),
391 }
392 }
393
394 async fn handle_notification(notification: Notification, _handler: &Arc<H>) {
396 trace!(method = %notification.method, "Received server notification");
397
398 match notification.method.as_ref() {
399 "notifications/cancelled" => {
400 if let Some(params) = notification.params {
402 if let Some(request_id) = params.get("requestId") {
403 debug!(?request_id, "Server cancelled request");
404 }
405 }
406 }
407 "notifications/progress" => {
408 trace!("Received progress notification");
410 }
411 "notifications/resources/updated" => {
412 trace!("Resources updated notification");
413 }
414 "notifications/tools/list_changed" => {
415 trace!("Tools list changed notification");
416 }
417 "notifications/prompts/list_changed" => {
418 trace!("Prompts list changed notification");
419 }
420 _ => {
421 trace!(method = %notification.method, "Unhandled notification");
422 }
423 }
424 }
425
426 pub fn server_info(&self) -> &ServerInfo {
428 &self.server_info
429 }
430
431 pub fn server_capabilities(&self) -> &ServerCapabilities {
433 &self.server_caps
434 }
435
436 pub fn client_info(&self) -> &ClientInfo {
438 &self.client_info
439 }
440
441 pub fn client_capabilities(&self) -> &ClientCapabilities {
443 &self.client_caps
444 }
445
446 pub fn instructions(&self) -> Option<&str> {
448 self.instructions.as_deref()
449 }
450
451 pub fn has_tools(&self) -> bool {
453 self.server_caps.has_tools()
454 }
455
456 pub fn has_resources(&self) -> bool {
458 self.server_caps.has_resources()
459 }
460
461 pub fn has_prompts(&self) -> bool {
463 self.server_caps.has_prompts()
464 }
465
466 pub fn has_tasks(&self) -> bool {
468 self.server_caps.has_tasks()
469 }
470
471 pub fn is_connected(&self) -> bool {
473 self.running.load(Ordering::SeqCst)
474 }
475
476 pub async fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
486 self.ensure_capability("tools", self.has_tools())?;
487
488 let result: ListToolsResult = self.request("tools/list", None).await?;
489 Ok(result.tools)
490 }
491
492 pub async fn list_tools_paginated(
498 &self,
499 cursor: Option<&str>,
500 ) -> Result<ListToolsResult, McpError> {
501 self.ensure_capability("tools", self.has_tools())?;
502
503 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
504 self.request("tools/list", params).await
505 }
506
507 pub async fn call_tool(
518 &self,
519 name: impl Into<String>,
520 arguments: serde_json::Value,
521 ) -> Result<CallToolResult, McpError> {
522 self.ensure_capability("tools", self.has_tools())?;
523
524 let request = CallToolRequest {
525 name: name.into(),
526 arguments: Some(arguments),
527 };
528 self.request("tools/call", Some(serde_json::to_value(request)?))
529 .await
530 }
531
532 pub async fn list_resources(&self) -> Result<Vec<Resource>, McpError> {
542 self.ensure_capability("resources", self.has_resources())?;
543
544 let result: ListResourcesResult = self.request("resources/list", None).await?;
545 Ok(result.resources)
546 }
547
548 pub async fn list_resources_paginated(
554 &self,
555 cursor: Option<&str>,
556 ) -> Result<ListResourcesResult, McpError> {
557 self.ensure_capability("resources", self.has_resources())?;
558
559 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
560 self.request("resources/list", params).await
561 }
562
563 pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, McpError> {
569 self.ensure_capability("resources", self.has_resources())?;
570
571 let result: ListResourceTemplatesResult =
572 self.request("resources/templates/list", None).await?;
573 Ok(result.resource_templates)
574 }
575
576 pub async fn read_resource(&self, uri: impl Into<String>) -> Result<Vec<ResourceContents>, McpError> {
582 self.ensure_capability("resources", self.has_resources())?;
583
584 let request = ReadResourceRequest { uri: uri.into() };
585 let result: ReadResourceResult =
586 self.request("resources/read", Some(serde_json::to_value(request)?))
587 .await?;
588 Ok(result.contents)
589 }
590
591 pub async fn list_prompts(&self) -> Result<Vec<Prompt>, McpError> {
601 self.ensure_capability("prompts", self.has_prompts())?;
602
603 let result: ListPromptsResult = self.request("prompts/list", None).await?;
604 Ok(result.prompts)
605 }
606
607 pub async fn list_prompts_paginated(
613 &self,
614 cursor: Option<&str>,
615 ) -> Result<ListPromptsResult, McpError> {
616 self.ensure_capability("prompts", self.has_prompts())?;
617
618 let params = cursor.map(|c| serde_json::json!({ "cursor": c }));
619 self.request("prompts/list", params).await
620 }
621
622 pub async fn get_prompt(
628 &self,
629 name: impl Into<String>,
630 arguments: Option<serde_json::Map<String, serde_json::Value>>,
631 ) -> Result<GetPromptResult, McpError> {
632 self.ensure_capability("prompts", self.has_prompts())?;
633
634 let request = GetPromptRequest {
635 name: name.into(),
636 arguments,
637 };
638 self.request("prompts/get", Some(serde_json::to_value(request)?))
639 .await
640 }
641
642 pub async fn ping(&self) -> Result<(), McpError> {
652 let _: serde_json::Value = self.request("ping", None).await?;
653 Ok(())
654 }
655
656 pub async fn close(self) -> Result<(), McpError> {
662 debug!("Closing client connection");
663
664 self.running.store(false, Ordering::SeqCst);
666
667 self.handler.on_disconnected().await;
669
670 self.transport.close().await.map_err(|e| {
672 McpError::Transport(Box::new(TransportDetails {
673 kind: TransportErrorKind::ConnectionClosed,
674 message: e.to_string(),
675 context: TransportContext::default(),
676 source: None,
677 }))
678 })
679 }
680
681 fn next_request_id(&self) -> RequestId {
687 RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst))
688 }
689
690 async fn request<R: serde::de::DeserializeOwned>(
692 &self,
693 method: &str,
694 params: Option<serde_json::Value>,
695 ) -> Result<R, McpError> {
696 if !self.is_connected() {
697 return Err(McpError::Transport(Box::new(TransportDetails {
698 kind: TransportErrorKind::ConnectionClosed,
699 message: "Client is not connected".to_string(),
700 context: TransportContext::default(),
701 source: None,
702 })));
703 }
704
705 let id = self.next_request_id();
706 let request = if let Some(params) = params {
707 Request::with_params(method.to_string(), id.clone(), params)
708 } else {
709 Request::new(method.to_string(), id.clone())
710 };
711
712 trace!(?id, method, "Sending request");
713
714 let (tx, rx) = oneshot::channel();
716 {
717 let mut pending = self.pending.write().await;
718 pending.insert(id.clone(), tx);
719 }
720
721 self.outgoing_tx
723 .send(Message::Request(request))
724 .await
725 .map_err(|_| McpError::Transport(Box::new(TransportDetails {
726 kind: TransportErrorKind::WriteFailed,
727 message: "Failed to send request (channel closed)".to_string(),
728 context: TransportContext::default(),
729 source: None,
730 })))?;
731
732 let response = rx.await.map_err(|_| McpError::Transport(Box::new(TransportDetails {
734 kind: TransportErrorKind::ConnectionClosed,
735 message: "Response channel closed (server may have disconnected)".to_string(),
736 context: TransportContext::default(),
737 source: None,
738 })))?;
739
740 if let Some(error) = response.error {
742 return Err(McpError::Internal {
743 message: error.message,
744 source: None,
745 });
746 }
747
748 let result = response.result.ok_or_else(|| McpError::Internal {
749 message: "Response contained neither result nor error".to_string(),
750 source: None,
751 })?;
752
753 serde_json::from_value(result).map_err(McpError::from)
754 }
755
756 fn ensure_capability(&self, name: &str, supported: bool) -> Result<(), McpError> {
758 if supported {
759 Ok(())
760 } else {
761 Err(McpError::CapabilityNotSupported {
762 capability: name.to_string(),
763 available: self.available_capabilities().into_boxed_slice(),
764 })
765 }
766 }
767
768 fn available_capabilities(&self) -> Vec<String> {
770 let mut caps = Vec::new();
771 if self.has_tools() {
772 caps.push("tools".to_string());
773 }
774 if self.has_resources() {
775 caps.push("resources".to_string());
776 }
777 if self.has_prompts() {
778 caps.push("prompts".to_string());
779 }
780 if self.has_tasks() {
781 caps.push("tasks".to_string());
782 }
783 caps
784 }
785}
786
787pub(crate) async fn initialize<T: Transport>(
804 transport: &T,
805 client_info: &ClientInfo,
806 capabilities: &ClientCapabilities,
807) -> Result<InitializeResult, McpError> {
808 debug!(
809 protocol_version = %PROTOCOL_VERSION,
810 supported_versions = ?SUPPORTED_PROTOCOL_VERSIONS,
811 "Initializing MCP connection"
812 );
813
814 let request = InitializeRequest::new(client_info.clone(), capabilities.clone());
816 let init_request = Request::with_params(
817 "initialize".to_string(),
818 RequestId::Number(0),
819 serde_json::to_value(&request)?,
820 );
821
822 transport
824 .send(Message::Request(init_request))
825 .await
826 .map_err(|e| McpError::Transport(Box::new(TransportDetails {
827 kind: TransportErrorKind::WriteFailed,
828 message: format!("Failed to send initialize: {e}"),
829 context: TransportContext::default(),
830 source: None,
831 })))?;
832
833 let response = loop {
835 match transport.recv().await {
836 Ok(Some(Message::Response(r))) if r.id == RequestId::Number(0) => break r,
837 Ok(Some(_)) => continue,
838 Ok(None) => {
839 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
840 message: "Connection closed during initialization".to_string(),
841 client_version: Some(PROTOCOL_VERSION.to_string()),
842 server_version: None,
843 source: None,
844 })));
845 }
846 Err(e) => {
847 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
848 message: format!("Transport error during initialization: {e}"),
849 client_version: Some(PROTOCOL_VERSION.to_string()),
850 server_version: None,
851 source: None,
852 })));
853 }
854 }
855 };
856
857 if let Some(error) = response.error {
859 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
860 message: error.message,
861 client_version: Some(PROTOCOL_VERSION.to_string()),
862 server_version: None,
863 source: None,
864 })));
865 }
866
867 let result: InitializeResult = response
868 .result
869 .map(serde_json::from_value)
870 .transpose()?
871 .ok_or_else(|| McpError::HandshakeFailed(Box::new(HandshakeDetails {
872 message: "Empty initialize result".to_string(),
873 client_version: Some(PROTOCOL_VERSION.to_string()),
874 server_version: None,
875 source: None,
876 })))?;
877
878 let server_version = &result.protocol_version;
880 if !is_version_supported(server_version) {
881 warn!(
882 server_version = %server_version,
883 supported = ?SUPPORTED_PROTOCOL_VERSIONS,
884 "Server returned unsupported protocol version"
885 );
886 return Err(McpError::HandshakeFailed(Box::new(HandshakeDetails {
887 message: format!(
888 "Unsupported protocol version: server returned '{}', but client only supports {:?}",
889 server_version, SUPPORTED_PROTOCOL_VERSIONS
890 ),
891 client_version: Some(PROTOCOL_VERSION.to_string()),
892 server_version: Some(server_version.clone()),
893 source: None,
894 })));
895 }
896
897 debug!(
898 server = %result.server_info.name,
899 server_version = %result.server_info.version,
900 protocol_version = %result.protocol_version,
901 "Received initialize result with compatible protocol version"
902 );
903
904 let notification = Notification::new("notifications/initialized");
906 transport
907 .send(Message::Notification(notification))
908 .await
909 .map_err(|e| McpError::Transport(Box::new(TransportDetails {
910 kind: TransportErrorKind::WriteFailed,
911 message: format!("Failed to send initialized: {e}"),
912 context: TransportContext::default(),
913 source: None,
914 })))?;
915
916 debug!("MCP initialization complete");
917 Ok(result)
918}
919
#[cfg(test)]
mod tests {
    use super::*;

    /// Request ids come from `fetch_add`, which yields the pre-increment value,
    /// so a counter seeded at 1 hands out 1 and then 2.
    #[test]
    fn test_request_id_generation() {
        let counter = AtomicU64::new(1);
        let issued: Vec<u64> = (0..2)
            .map(|_| counter.fetch_add(1, Ordering::SeqCst))
            .collect();
        assert_eq!(issued, [1, 2]);
    }
}