1mod channel;
34mod handler;
35#[cfg(feature = "http-client")]
36mod http;
37#[cfg(feature = "oauth-client")]
38mod oauth;
39#[cfg(feature = "oauth-client")]
40mod oauth_authcode;
41mod stdio;
42mod transport;
43
44pub use channel::ChannelTransport;
45pub use handler::{ClientHandler, NotificationHandler, ServerNotification};
46#[cfg(feature = "http-client")]
47pub use http::{HttpClientConfig, HttpClientTransport};
48#[cfg(feature = "oauth-client")]
49pub use oauth::{
50 OAuthClientCredentials, OAuthClientCredentialsBuilder, OAuthClientError, TokenProvider,
51};
52#[cfg(feature = "oauth-client")]
53pub use oauth_authcode::{OAuthAuthCodeConfig, OAuthAuthorizationCode};
54pub use stdio::StdioClientTransport;
55pub use transport::ClientTransport;
56
57use std::collections::HashMap;
58use std::sync::Arc;
59use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
60
61use tokio::sync::{Mutex, RwLock, mpsc, oneshot};
62use tokio::task::JoinHandle;
63
64use crate::error::{Error, Result};
65use crate::protocol::{
66 CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
67 CompletionArgument, CompletionReference, ElicitationCapability, GetPromptParams,
68 GetPromptResult, Implementation, InitializeParams, InitializeResult, JsonRpcNotification,
69 JsonRpcRequest, ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams,
70 ListResourceTemplatesResult, ListResourcesParams, ListResourcesResult, ListRootsResult,
71 ListToolsParams, ListToolsResult, PromptDefinition, ReadResourceParams, ReadResourceResult,
72 RequestId, ResourceDefinition, ResourceTemplateDefinition, Root, RootsCapability,
73 SamplingCapability, ToolDefinition, notifications,
74};
75use tower_mcp_types::JsonRpcError;
76
/// Commands sent from [`McpClient`] methods to the background message loop.
enum LoopCommand {
    /// Send a JSON-RPC request and deliver its outcome on `response_tx`.
    Request {
        /// JSON-RPC method name.
        method: String,
        /// Request parameters, already converted to a JSON value.
        params: serde_json::Value,
        /// Receives the raw `result` value, or the error for this request.
        response_tx: oneshot::Sender<Result<serde_json::Value>>,
    },
    /// Send a JSON-RPC notification; no response is tracked.
    Notify {
        /// JSON-RPC method name.
        method: String,
        /// Notification parameters, already converted to a JSON value.
        params: serde_json::Value,
    },
    /// Reset the transport session; `done_tx` is signalled once the loop has
    /// processed the reset.
    ResetSession { done_tx: oneshot::Sender<()> },
    /// Stop the message loop.
    Shutdown,
}
95
/// Handle to an MCP connection that is driven by a spawned message-loop task.
pub struct McpClient {
    /// Sends commands to the background message loop.
    command_tx: mpsc::Sender<LoopCommand>,
    /// Handle of the spawned message-loop task; taken by `shutdown` or on drop.
    task: Option<JoinHandle<()>>,
    /// Set once the initialize handshake has completed.
    initialized: AtomicBool,
    /// Result of the most recent successful `initialize` request, if any.
    server_info: RwLock<Option<InitializeResult>>,
    /// Capabilities sent to the server in the `initialize` request.
    capabilities: ClientCapabilities,
    /// Root list, shared with the message loop which serves `roots/list` from it.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Starts as `true`; cleared by the message loop when it exits.
    connected: Arc<AtomicBool>,
    /// Value reported by the transport's `supports_session_recovery()` at connect time.
    supports_session_recovery: bool,
    /// `(client_name, client_version)` from the last `initialize` call, reused
    /// when re-initializing after a session expiry.
    init_params: RwLock<Option<(String, String)>>,
    /// Held for the duration of a session recovery so recoveries run one at a time.
    recovery_lock: Mutex<()>,
}
147
/// Builder that collects client capabilities and initial roots before connecting.
pub struct McpClientBuilder {
    /// Capabilities the resulting client will advertise during initialization.
    capabilities: ClientCapabilities,
    /// Roots the resulting client starts with.
    roots: Vec<Root>,
}
171
172impl McpClientBuilder {
173 pub fn new() -> Self {
175 Self {
176 capabilities: ClientCapabilities::default(),
177 roots: Vec::new(),
178 }
179 }
180
181 pub fn with_roots(mut self, roots: Vec<Root>) -> Self {
186 self.roots = roots;
187 self.capabilities.roots = Some(RootsCapability { list_changed: true });
188 self
189 }
190
191 pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
193 self.capabilities = capabilities;
194 self
195 }
196
197 pub fn with_sampling(mut self) -> Self {
204 self.capabilities.sampling = Some(SamplingCapability::default());
205 self
206 }
207
208 pub fn with_elicitation(mut self) -> Self {
215 self.capabilities.elicitation = Some(ElicitationCapability::default());
216 self
217 }
218
219 pub async fn connect<T, H>(self, transport: T, handler: H) -> Result<McpClient>
224 where
225 T: ClientTransport,
226 H: ClientHandler,
227 {
228 McpClient::connect_inner(transport, handler, self.capabilities, self.roots).await
229 }
230
231 pub async fn connect_simple<T: ClientTransport>(self, transport: T) -> Result<McpClient> {
235 self.connect(transport, ()).await
236 }
237}
238
239impl Default for McpClientBuilder {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
245impl McpClient {
246 pub async fn connect<T: ClientTransport>(transport: T) -> Result<Self> {
250 McpClientBuilder::new().connect_simple(transport).await
251 }
252
253 pub async fn connect_with_handler<T, H>(transport: T, handler: H) -> Result<Self>
255 where
256 T: ClientTransport,
257 H: ClientHandler,
258 {
259 McpClientBuilder::new().connect(transport, handler).await
260 }
261
262 pub fn builder() -> McpClientBuilder {
264 McpClientBuilder::new()
265 }
266
267 async fn connect_inner<T, H>(
269 transport: T,
270 handler: H,
271 capabilities: ClientCapabilities,
272 roots: Vec<Root>,
273 ) -> Result<Self>
274 where
275 T: ClientTransport,
276 H: ClientHandler,
277 {
278 let supports_session_recovery = transport.supports_session_recovery();
279 let (command_tx, command_rx) = mpsc::channel::<LoopCommand>(64);
280 let connected = Arc::new(AtomicBool::new(true));
281 let roots = Arc::new(RwLock::new(roots));
282
283 let loop_connected = connected.clone();
284 let loop_roots = roots.clone();
285
286 let task = tokio::spawn(async move {
287 message_loop(transport, handler, command_rx, loop_connected, loop_roots).await;
288 });
289
290 Ok(Self {
291 command_tx,
292 task: Some(task),
293 initialized: AtomicBool::new(false),
294 server_info: RwLock::new(None),
295 capabilities,
296 roots,
297 connected,
298 supports_session_recovery,
299 init_params: RwLock::new(None),
300 recovery_lock: Mutex::new(()),
301 })
302 }
303
304 pub fn is_initialized(&self) -> bool {
306 self.initialized.load(Ordering::Acquire)
307 }
308
309 pub fn is_connected(&self) -> bool {
311 self.connected.load(Ordering::Acquire)
312 }
313
314 pub async fn server_info(&self) -> Option<InitializeResult> {
316 self.server_info.read().await.clone()
317 }
318
319 pub fn server_info_blocking(&self) -> Option<InitializeResult> {
325 self.server_info.try_read().ok()?.clone()
326 }
327
328 pub async fn initialize(
333 &self,
334 client_name: &str,
335 client_version: &str,
336 ) -> Result<InitializeResult> {
337 let params = InitializeParams {
338 protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
339 capabilities: self.capabilities.clone(),
340 client_info: Implementation {
341 name: client_name.to_string(),
342 version: client_version.to_string(),
343 ..Default::default()
344 },
345 meta: None,
346 };
347
348 let result: InitializeResult = self.send_request("initialize", ¶ms).await?;
349 *self.server_info.write().await = Some(result.clone());
350
351 *self.init_params.write().await =
353 Some((client_name.to_string(), client_version.to_string()));
354
355 self.send_notification("notifications/initialized", &serde_json::json!({}))
357 .await?;
358 self.initialized.store(true, Ordering::Release);
359
360 Ok(result)
361 }
362
363 pub async fn list_tools(&self) -> Result<ListToolsResult> {
365 self.ensure_initialized()?;
366 self.send_request(
367 "tools/list",
368 &ListToolsParams {
369 cursor: None,
370 meta: None,
371 },
372 )
373 .await
374 }
375
376 pub async fn call_tool(
378 &self,
379 name: &str,
380 arguments: serde_json::Value,
381 ) -> Result<CallToolResult> {
382 self.ensure_initialized()?;
383 let params = CallToolParams {
384 name: name.to_string(),
385 arguments,
386 meta: None,
387 task: None,
388 };
389 self.send_request("tools/call", ¶ms).await
390 }
391
392 pub async fn list_resources(&self) -> Result<ListResourcesResult> {
394 self.ensure_initialized()?;
395 self.send_request(
396 "resources/list",
397 &ListResourcesParams {
398 cursor: None,
399 meta: None,
400 },
401 )
402 .await
403 }
404
405 pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
407 self.ensure_initialized()?;
408 let params = ReadResourceParams {
409 uri: uri.to_string(),
410 meta: None,
411 };
412 self.send_request("resources/read", ¶ms).await
413 }
414
415 pub async fn list_prompts(&self) -> Result<ListPromptsResult> {
417 self.ensure_initialized()?;
418 self.send_request(
419 "prompts/list",
420 &ListPromptsParams {
421 cursor: None,
422 meta: None,
423 },
424 )
425 .await
426 }
427
428 pub async fn list_tools_with_cursor(&self, cursor: Option<String>) -> Result<ListToolsResult> {
430 self.ensure_initialized()?;
431 self.send_request("tools/list", &ListToolsParams { cursor, meta: None })
432 .await
433 }
434
435 pub async fn list_resources_with_cursor(
437 &self,
438 cursor: Option<String>,
439 ) -> Result<ListResourcesResult> {
440 self.ensure_initialized()?;
441 self.send_request(
442 "resources/list",
443 &ListResourcesParams { cursor, meta: None },
444 )
445 .await
446 }
447
448 pub async fn list_resource_templates(&self) -> Result<ListResourceTemplatesResult> {
450 self.ensure_initialized()?;
451 self.send_request(
452 "resources/templates/list",
453 &ListResourceTemplatesParams {
454 cursor: None,
455 meta: None,
456 },
457 )
458 .await
459 }
460
461 pub async fn list_resource_templates_with_cursor(
463 &self,
464 cursor: Option<String>,
465 ) -> Result<ListResourceTemplatesResult> {
466 self.ensure_initialized()?;
467 self.send_request(
468 "resources/templates/list",
469 &ListResourceTemplatesParams { cursor, meta: None },
470 )
471 .await
472 }
473
474 pub async fn list_prompts_with_cursor(
476 &self,
477 cursor: Option<String>,
478 ) -> Result<ListPromptsResult> {
479 self.ensure_initialized()?;
480 self.send_request("prompts/list", &ListPromptsParams { cursor, meta: None })
481 .await
482 }
483
484 pub async fn list_all_tools(&self) -> Result<Vec<ToolDefinition>> {
486 let mut all = Vec::new();
487 let mut cursor = None;
488 loop {
489 let result = self.list_tools_with_cursor(cursor).await?;
490 all.extend(result.tools);
491 match result.next_cursor {
492 Some(c) => cursor = Some(c),
493 None => break,
494 }
495 }
496 Ok(all)
497 }
498
499 pub async fn list_all_resources(&self) -> Result<Vec<ResourceDefinition>> {
501 let mut all = Vec::new();
502 let mut cursor = None;
503 loop {
504 let result = self.list_resources_with_cursor(cursor).await?;
505 all.extend(result.resources);
506 match result.next_cursor {
507 Some(c) => cursor = Some(c),
508 None => break,
509 }
510 }
511 Ok(all)
512 }
513
514 pub async fn list_all_resource_templates(&self) -> Result<Vec<ResourceTemplateDefinition>> {
516 let mut all = Vec::new();
517 let mut cursor = None;
518 loop {
519 let result = self.list_resource_templates_with_cursor(cursor).await?;
520 all.extend(result.resource_templates);
521 match result.next_cursor {
522 Some(c) => cursor = Some(c),
523 None => break,
524 }
525 }
526 Ok(all)
527 }
528
529 pub async fn list_all_prompts(&self) -> Result<Vec<PromptDefinition>> {
531 let mut all = Vec::new();
532 let mut cursor = None;
533 loop {
534 let result = self.list_prompts_with_cursor(cursor).await?;
535 all.extend(result.prompts);
536 match result.next_cursor {
537 Some(c) => cursor = Some(c),
538 None => break,
539 }
540 }
541 Ok(all)
542 }
543
544 pub async fn call_tool_text(&self, name: &str, arguments: serde_json::Value) -> Result<String> {
552 let result = self.call_tool(name, arguments).await?;
553 if result.is_error {
554 return Err(Error::Internal(result.all_text()));
555 }
556 Ok(result.all_text())
557 }
558
559 pub async fn get_prompt(
561 &self,
562 name: &str,
563 arguments: Option<std::collections::HashMap<String, String>>,
564 ) -> Result<GetPromptResult> {
565 self.ensure_initialized()?;
566 let params = GetPromptParams {
567 name: name.to_string(),
568 arguments: arguments.unwrap_or_default(),
569 meta: None,
570 };
571 self.send_request("prompts/get", ¶ms).await
572 }
573
574 pub async fn ping(&self) -> Result<()> {
576 let _: serde_json::Value = self.send_request("ping", &serde_json::json!({})).await?;
577 Ok(())
578 }
579
580 pub async fn complete(
582 &self,
583 reference: CompletionReference,
584 argument_name: &str,
585 argument_value: &str,
586 ) -> Result<CompleteResult> {
587 self.ensure_initialized()?;
588 let params = CompleteParams {
589 reference,
590 argument: CompletionArgument::new(argument_name, argument_value),
591 context: None,
592 meta: None,
593 };
594 self.send_request("completion/complete", ¶ms).await
595 }
596
597 pub async fn complete_prompt_arg(
599 &self,
600 prompt_name: &str,
601 argument_name: &str,
602 argument_value: &str,
603 ) -> Result<CompleteResult> {
604 self.complete(
605 CompletionReference::prompt(prompt_name),
606 argument_name,
607 argument_value,
608 )
609 .await
610 }
611
612 pub async fn complete_resource_uri(
614 &self,
615 resource_uri: &str,
616 argument_name: &str,
617 argument_value: &str,
618 ) -> Result<CompleteResult> {
619 self.complete(
620 CompletionReference::resource(resource_uri),
621 argument_name,
622 argument_value,
623 )
624 .await
625 }
626
627 pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
629 &self,
630 method: &str,
631 params: &P,
632 ) -> Result<R> {
633 self.send_request(method, params).await
634 }
635
636 pub async fn notify<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
638 self.send_notification(method, params).await
639 }
640
641 pub async fn roots(&self) -> Vec<Root> {
643 self.roots.read().await.clone()
644 }
645
646 pub async fn set_roots(&self, roots: Vec<Root>) -> Result<()> {
648 *self.roots.write().await = roots;
649 if self.is_initialized() {
650 self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
651 .await?;
652 }
653 Ok(())
654 }
655
656 pub async fn add_root(&self, root: Root) -> Result<()> {
658 self.roots.write().await.push(root);
659 if self.is_initialized() {
660 self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
661 .await?;
662 }
663 Ok(())
664 }
665
666 pub async fn remove_root(&self, uri: &str) -> Result<bool> {
668 let mut roots = self.roots.write().await;
669 let initial_len = roots.len();
670 roots.retain(|r| r.uri != uri);
671 let removed = roots.len() < initial_len;
672 drop(roots);
673
674 if removed && self.is_initialized() {
675 self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
676 .await?;
677 }
678 Ok(removed)
679 }
680
681 pub async fn list_roots(&self) -> ListRootsResult {
683 ListRootsResult {
684 roots: self.roots.read().await.clone(),
685 meta: None,
686 }
687 }
688
689 pub async fn shutdown(mut self) -> Result<()> {
691 let _ = self.command_tx.send(LoopCommand::Shutdown).await;
692 if let Some(task) = self.task.take() {
693 let _ = task.await;
694 }
695 Ok(())
696 }
697
698 async fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
701 &self,
702 method: &str,
703 params: &P,
704 ) -> Result<R> {
705 match self.send_request_once(method, params).await {
706 Err(Error::SessionExpired)
707 if self.supports_session_recovery && method != "initialize" =>
708 {
709 tracing::info!(method = %method, "Session expired, attempting recovery");
710 self.recover_session().await?;
711 self.send_request_once(method, params).await
712 }
713 other => other,
714 }
715 }
716
717 async fn send_request_once<P: serde::Serialize, R: serde::de::DeserializeOwned>(
718 &self,
719 method: &str,
720 params: &P,
721 ) -> Result<R> {
722 self.ensure_connected()?;
723 let params_value = serde_json::to_value(params)
724 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
725
726 let (response_tx, response_rx) = oneshot::channel();
727 self.command_tx
728 .send(LoopCommand::Request {
729 method: method.to_string(),
730 params: params_value,
731 response_tx,
732 })
733 .await
734 .map_err(|_| Error::Transport("Connection closed".to_string()))?;
735
736 let result = response_rx
737 .await
738 .map_err(|_| Error::Transport("Connection closed".to_string()))??;
739
740 serde_json::from_value(result)
741 .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
742 }
743
744 async fn recover_session(&self) -> Result<()> {
746 let _guard = self.recovery_lock.lock().await;
748
749 let init_params = self.init_params.read().await.clone();
752 let (client_name, client_version) = match init_params {
753 Some(params) => params,
754 None => {
755 return Err(Error::Transport(
756 "Cannot recover: never initialized".to_string(),
757 ));
758 }
759 };
760
761 let (done_tx, done_rx) = oneshot::channel();
763 self.command_tx
764 .send(LoopCommand::ResetSession { done_tx })
765 .await
766 .map_err(|_| Error::Transport("Connection closed".to_string()))?;
767 done_rx
768 .await
769 .map_err(|_| Error::Transport("Connection closed during recovery".to_string()))?;
770
771 self.initialized.store(false, Ordering::Release);
773 *self.server_info.write().await = None;
774
775 tracing::info!("Re-initializing session after expiry");
777 let params = InitializeParams {
778 protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
779 capabilities: self.capabilities.clone(),
780 client_info: Implementation {
781 name: client_name,
782 version: client_version,
783 ..Default::default()
784 },
785 meta: None,
786 };
787
788 let result: InitializeResult = self.send_request_once("initialize", ¶ms).await?;
789 *self.server_info.write().await = Some(result);
790
791 self.send_notification("notifications/initialized", &serde_json::json!({}))
792 .await?;
793 self.initialized.store(true, Ordering::Release);
794
795 Ok(())
796 }
797
798 async fn send_notification<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
799 self.ensure_connected()?;
800 let params_value = serde_json::to_value(params)
801 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
802
803 self.command_tx
804 .send(LoopCommand::Notify {
805 method: method.to_string(),
806 params: params_value,
807 })
808 .await
809 .map_err(|_| Error::Transport("Connection closed".to_string()))?;
810
811 Ok(())
812 }
813
814 fn ensure_connected(&self) -> Result<()> {
815 if !self.connected.load(Ordering::Acquire) {
816 return Err(Error::Transport("Connection closed".to_string()));
817 }
818 Ok(())
819 }
820
821 fn ensure_initialized(&self) -> Result<()> {
822 if !self.initialized.load(Ordering::Acquire) {
823 return Err(Error::Transport("Client not initialized".to_string()));
824 }
825 Ok(())
826 }
827}
828
829impl Drop for McpClient {
830 fn drop(&mut self) {
831 if let Some(task) = self.task.take() {
832 task.abort();
833 }
834 }
835}
836
/// Bookkeeping for a request that has been sent but not yet answered.
struct PendingRequest {
    /// Channel on which the caller waits for the result or error.
    response_tx: oneshot::Sender<Result<serde_json::Value>>,
}
845
846async fn message_loop<T: ClientTransport, H: ClientHandler>(
848 mut transport: T,
849 handler: H,
850 mut command_rx: mpsc::Receiver<LoopCommand>,
851 connected: Arc<AtomicBool>,
852 roots: Arc<RwLock<Vec<Root>>>,
853) {
854 let handler = Arc::new(handler);
855 let mut pending_requests: HashMap<RequestId, PendingRequest> = HashMap::new();
856 let next_id = AtomicI64::new(1);
857
858 loop {
859 tokio::select! {
860 command = command_rx.recv() => {
862 match command {
863 Some(LoopCommand::Request { method, params, response_tx }) => {
864 let id = RequestId::Number(next_id.fetch_add(1, Ordering::Relaxed));
865
866 let request = JsonRpcRequest::new(id.clone(), &method)
867 .with_params(params);
868 let json = match serde_json::to_string(&request) {
869 Ok(j) => j,
870 Err(e) => {
871 let _ = response_tx.send(Err(Error::Transport(
872 format!("Serialization failed: {}", e)
873 )));
874 continue;
875 }
876 };
877
878 tracing::debug!(method = %method, id = ?id, "Sending request");
879 pending_requests.insert(id, PendingRequest { response_tx });
880
881 if let Err(e) = transport.send(&json).await {
882 tracing::error!(error = %e, "Transport send error");
883 fail_all_pending(&mut pending_requests, &format!("Transport error: {}", e));
884 break;
885 }
886 }
887 Some(LoopCommand::Notify { method, params }) => {
888 let notification = JsonRpcNotification::new(&method)
889 .with_params(params);
890 if let Ok(json) = serde_json::to_string(¬ification) {
891 tracing::debug!(method = %method, "Sending notification");
892 let _ = transport.send(&json).await;
893 }
894 }
895 Some(LoopCommand::ResetSession { done_tx }) => {
896 tracing::info!("Resetting transport session for re-initialization");
897 transport.reset_session().await;
898 for (_, pending) in pending_requests.drain() {
900 let _ = pending.response_tx.send(Err(Error::SessionExpired));
901 }
902 let _ = done_tx.send(());
903 }
904 Some(LoopCommand::Shutdown) | None => {
905 tracing::debug!("Message loop shutting down");
906 break;
907 }
908 }
909 }
910
911 result = transport.recv() => {
913 match result {
914 Ok(Some(line)) => {
915 handle_incoming(
916 &line,
917 &mut pending_requests,
918 &handler,
919 &roots,
920 &mut transport,
921 ).await;
922 }
923 Ok(None) => {
924 tracing::info!("Transport closed (EOF)");
925 break;
926 }
927 Err(e) => {
928 tracing::error!(error = %e, "Transport receive error");
929 break;
930 }
931 }
932 }
933 }
934 }
935
936 connected.store(false, Ordering::Release);
938 fail_all_pending(&mut pending_requests, "Connection closed");
939 let _ = transport.close().await;
940}
941
942async fn handle_incoming<T: ClientTransport, H: ClientHandler>(
944 line: &str,
945 pending_requests: &mut HashMap<RequestId, PendingRequest>,
946 handler: &Arc<H>,
947 roots: &Arc<RwLock<Vec<Root>>>,
948 transport: &mut T,
949) {
950 let parsed: serde_json::Value = match serde_json::from_str(line) {
951 Ok(v) => v,
952 Err(e) => {
953 tracing::warn!(error = %e, "Failed to parse incoming message");
954 return;
955 }
956 };
957
958 if parsed.get("method").is_none()
960 && (parsed.get("result").is_some() || parsed.get("error").is_some())
961 {
962 if let Some(error) = parsed.get("error") {
965 let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(0) as i32;
966 let id_missing_or_null = parsed.get("id").is_none_or(|id| id.is_null());
967 if code == -32005 && id_missing_or_null {
968 tracing::warn!(
969 "Session expired (-32005 with null id), failing all pending requests"
970 );
971 for (_, pending) in pending_requests.drain() {
972 let _ = pending.response_tx.send(Err(Error::SessionExpired));
973 }
974 return;
975 }
976 }
977
978 handle_response(&parsed, pending_requests);
979 return;
980 }
981
982 if parsed.get("id").is_some() && parsed.get("method").is_some() {
984 let id = parse_request_id(&parsed);
985 let method = parsed["method"].as_str().unwrap_or("");
986 let params = parsed.get("params").cloned();
987
988 let result = dispatch_server_request(handler, roots, method, params).await;
989
990 let response = match result {
992 Ok(value) => {
993 if let Some(id) = id {
994 serde_json::json!({
995 "jsonrpc": "2.0",
996 "id": id,
997 "result": value
998 })
999 } else {
1000 return;
1001 }
1002 }
1003 Err(error) => {
1004 serde_json::json!({
1005 "jsonrpc": "2.0",
1006 "id": id,
1007 "error": {
1008 "code": error.code,
1009 "message": error.message
1010 }
1011 })
1012 }
1013 };
1014
1015 if let Ok(json) = serde_json::to_string(&response) {
1016 let _ = transport.send(&json).await;
1017 }
1018 return;
1019 }
1020
1021 if parsed.get("method").is_some() && parsed.get("id").is_none() {
1023 let method = parsed["method"].as_str().unwrap_or("");
1024 let params = parsed.get("params").cloned();
1025 let notification = parse_server_notification(method, params);
1026 handler.on_notification(notification).await;
1027 }
1028}
1029
1030fn handle_response(
1032 parsed: &serde_json::Value,
1033 pending_requests: &mut HashMap<RequestId, PendingRequest>,
1034) {
1035 let id = match parse_request_id(parsed) {
1036 Some(id) => id,
1037 None => {
1038 tracing::warn!("Response without id");
1039 return;
1040 }
1041 };
1042
1043 let pending = match pending_requests.remove(&id) {
1044 Some(p) => p,
1045 None => {
1046 tracing::warn!(id = ?id, "Response for unknown request");
1047 return;
1048 }
1049 };
1050
1051 tracing::debug!(id = ?id, "Received response");
1052
1053 if let Some(error) = parsed.get("error") {
1054 let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
1055 let message = error
1056 .get("message")
1057 .and_then(|m| m.as_str())
1058 .unwrap_or("Unknown error")
1059 .to_string();
1060 let data = error.get("data").cloned();
1061
1062 if code == -32005 {
1064 let _ = pending.response_tx.send(Err(Error::SessionExpired));
1065 return;
1066 }
1067
1068 let json_rpc_error = JsonRpcError {
1069 code,
1070 message,
1071 data,
1072 };
1073 let _ = pending
1074 .response_tx
1075 .send(Err(Error::JsonRpc(json_rpc_error)));
1076 } else if let Some(result) = parsed.get("result") {
1077 let _ = pending.response_tx.send(Ok(result.clone()));
1078 } else {
1079 let _ = pending
1080 .response_tx
1081 .send(Err(Error::Transport("Invalid response".to_string())));
1082 }
1083}
1084
1085async fn dispatch_server_request<H: ClientHandler>(
1087 handler: &Arc<H>,
1088 roots: &Arc<RwLock<Vec<Root>>>,
1089 method: &str,
1090 params: Option<serde_json::Value>,
1091) -> std::result::Result<serde_json::Value, JsonRpcError> {
1092 match method {
1093 "sampling/createMessage" => {
1094 let p = serde_json::from_value(params.unwrap_or_default())
1095 .map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
1096 let result = handler.handle_create_message(p).await?;
1097 serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
1098 }
1099 "elicitation/create" => {
1100 let p = serde_json::from_value(params.unwrap_or_default())
1101 .map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
1102 let result = handler.handle_elicit(p).await?;
1103 serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
1104 }
1105 "roots/list" => {
1106 let roots_list = roots.read().await;
1108 if !roots_list.is_empty() {
1109 let result = ListRootsResult {
1110 roots: roots_list.clone(),
1111 meta: None,
1112 };
1113 return serde_json::to_value(result)
1114 .map_err(|e| JsonRpcError::internal_error(e.to_string()));
1115 }
1116 drop(roots_list);
1117
1118 let result = handler.handle_list_roots().await?;
1119 serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
1120 }
1121 "ping" => Ok(serde_json::json!({})),
1122 _ => Err(JsonRpcError::method_not_found(method)),
1123 }
1124}
1125
1126fn parse_request_id(parsed: &serde_json::Value) -> Option<RequestId> {
1128 parsed.get("id").and_then(|id| {
1129 if let Some(n) = id.as_i64() {
1130 Some(RequestId::Number(n))
1131 } else {
1132 id.as_str().map(|s| RequestId::String(s.to_string()))
1133 }
1134 })
1135}
1136
1137fn parse_server_notification(
1139 method: &str,
1140 params: Option<serde_json::Value>,
1141) -> ServerNotification {
1142 match method {
1143 notifications::PROGRESS => {
1144 if let Some(params) = params
1145 && let Ok(p) = serde_json::from_value(params)
1146 {
1147 return ServerNotification::Progress(p);
1148 }
1149 ServerNotification::Unknown {
1150 method: method.to_string(),
1151 params: None,
1152 }
1153 }
1154 notifications::MESSAGE => {
1155 if let Some(params) = params
1156 && let Ok(p) = serde_json::from_value(params)
1157 {
1158 return ServerNotification::LogMessage(p);
1159 }
1160 ServerNotification::Unknown {
1161 method: method.to_string(),
1162 params: None,
1163 }
1164 }
1165 notifications::RESOURCE_UPDATED => {
1166 if let Some(params) = ¶ms
1167 && let Some(uri) = params.get("uri").and_then(|u| u.as_str())
1168 {
1169 return ServerNotification::ResourceUpdated {
1170 uri: uri.to_string(),
1171 };
1172 }
1173 ServerNotification::Unknown {
1174 method: method.to_string(),
1175 params,
1176 }
1177 }
1178 notifications::RESOURCES_LIST_CHANGED => ServerNotification::ResourcesListChanged,
1179 notifications::TOOLS_LIST_CHANGED => ServerNotification::ToolsListChanged,
1180 notifications::PROMPTS_LIST_CHANGED => ServerNotification::PromptsListChanged,
1181 _ => ServerNotification::Unknown {
1182 method: method.to_string(),
1183 params,
1184 },
1185 }
1186}
1187
1188fn fail_all_pending(pending: &mut HashMap<RequestId, PendingRequest>, reason: &str) {
1190 for (_, req) in pending.drain() {
1191 let _ = req
1192 .response_tx
1193 .send(Err(Error::Transport(reason.to_string())));
1194 }
1195}
1196
1197#[cfg(test)]
1198mod tests {
1199 use super::*;
1200 use async_trait::async_trait;
1201 use std::sync::Mutex;
1202
    /// In-memory transport for tests that answers requests from a scripted list.
    struct MockTransport {
        /// Scripted `result` payloads, consumed in order by messages that carry an `id`.
        responses: Arc<Mutex<Vec<serde_json::Value>>>,
        /// Index of the next scripted response to hand out.
        response_idx: Arc<std::sync::atomic::AtomicUsize>,
        /// Producer side of the incoming queue; `send` pushes responses here.
        incoming_tx: mpsc::Sender<String>,
        /// Consumer side of the incoming queue; read by `recv`.
        incoming_rx: mpsc::Receiver<String>,
        /// Every message passed to `send`, in order.
        outgoing: Arc<Mutex<Vec<String>>>,
        /// Starts as `true`; cleared by `close`.
        connected: Arc<AtomicBool>,
    }
1223
1224 #[allow(dead_code)]
1225 impl MockTransport {
1226 fn new() -> Self {
1227 let (tx, rx) = mpsc::channel(32);
1228 Self {
1229 responses: Arc::new(Mutex::new(Vec::new())),
1230 response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1231 incoming_tx: tx,
1232 incoming_rx: rx,
1233 outgoing: Arc::new(Mutex::new(Vec::new())),
1234 connected: Arc::new(AtomicBool::new(true)),
1235 }
1236 }
1237
1238 fn with_responses(responses: Vec<serde_json::Value>) -> Self {
1244 let (tx, rx) = mpsc::channel(32);
1245 Self {
1246 responses: Arc::new(Mutex::new(responses)),
1247 response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1248 incoming_tx: tx,
1249 incoming_rx: rx,
1250 outgoing: Arc::new(Mutex::new(Vec::new())),
1251 connected: Arc::new(AtomicBool::new(true)),
1252 }
1253 }
1254 }
1255
1256 #[async_trait]
1257 impl ClientTransport for MockTransport {
1258 async fn send(&mut self, message: &str) -> Result<()> {
1259 self.outgoing.lock().unwrap().push(message.to_string());
1260
1261 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(message) {
1263 if let Some(id) = parsed.get("id") {
1265 let idx = self
1266 .response_idx
1267 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1268 let responses = self.responses.lock().unwrap();
1269 if let Some(result) = responses.get(idx) {
1270 let response = serde_json::json!({
1271 "jsonrpc": "2.0",
1272 "id": id,
1273 "result": result
1274 });
1275 let _ = self.incoming_tx.try_send(response.to_string());
1276 }
1277 }
1278 }
1279
1280 Ok(())
1281 }
1282
1283 async fn recv(&mut self) -> Result<Option<String>> {
1284 match self.incoming_rx.recv().await {
1287 Some(msg) => Ok(Some(msg)),
1288 None => Ok(None),
1289 }
1290 }
1291
1292 fn is_connected(&self) -> bool {
1293 self.connected.load(Ordering::Relaxed)
1294 }
1295
1296 async fn close(&mut self) -> Result<()> {
1297 self.connected.store(false, Ordering::Relaxed);
1298 Ok(())
1299 }
1300 }
1301
1302 fn mock_initialize_response() -> serde_json::Value {
1303 serde_json::json!({
1304 "protocolVersion": "2025-11-25",
1305 "serverInfo": {
1306 "name": "test-server",
1307 "version": "1.0.0"
1308 },
1309 "capabilities": {
1310 "tools": {}
1311 }
1312 })
1313 }
1314
1315 #[tokio::test]
1316 async fn test_client_not_initialized() {
1317 let client = McpClient::connect(MockTransport::with_responses(vec![]))
1318 .await
1319 .unwrap();
1320
1321 let result = client.list_tools().await;
1322 assert!(result.is_err());
1323 assert!(result.unwrap_err().to_string().contains("not initialized"));
1324 }
1325
1326 #[tokio::test]
1327 async fn test_client_initialize() {
1328 let client = McpClient::connect(MockTransport::with_responses(vec![
1329 mock_initialize_response(),
1330 ]))
1331 .await
1332 .unwrap();
1333
1334 assert!(!client.is_initialized());
1335
1336 let result = client.initialize("test-client", "1.0.0").await;
1337 assert!(result.is_ok());
1338 assert!(client.is_initialized());
1339
1340 let server_info = client.server_info().await.unwrap();
1341 assert_eq!(server_info.server_info.name, "test-server");
1342 }
1343
1344 #[tokio::test]
1345 async fn test_list_tools() {
1346 let client = McpClient::connect(MockTransport::with_responses(vec![
1347 mock_initialize_response(),
1348 serde_json::json!({
1349 "tools": [
1350 {
1351 "name": "test_tool",
1352 "description": "A test tool",
1353 "inputSchema": {
1354 "type": "object",
1355 "properties": {}
1356 }
1357 }
1358 ]
1359 }),
1360 ]))
1361 .await
1362 .unwrap();
1363
1364 client.initialize("test-client", "1.0.0").await.unwrap();
1365 let tools = client.list_tools().await.unwrap();
1366
1367 assert_eq!(tools.tools.len(), 1);
1368 assert_eq!(tools.tools[0].name, "test_tool");
1369 }
1370
1371 #[tokio::test]
1372 async fn test_call_tool() {
1373 let client = McpClient::connect(MockTransport::with_responses(vec![
1374 mock_initialize_response(),
1375 serde_json::json!({
1376 "content": [
1377 {
1378 "type": "text",
1379 "text": "Tool result"
1380 }
1381 ]
1382 }),
1383 ]))
1384 .await
1385 .unwrap();
1386
1387 client.initialize("test-client", "1.0.0").await.unwrap();
1388 let result = client
1389 .call_tool("test_tool", serde_json::json!({"arg": "value"}))
1390 .await
1391 .unwrap();
1392
1393 assert!(!result.content.is_empty());
1394 }
1395
1396 #[tokio::test]
1397 async fn test_list_resources() {
1398 let client = McpClient::connect(MockTransport::with_responses(vec![
1399 mock_initialize_response(),
1400 serde_json::json!({
1401 "resources": [
1402 {
1403 "uri": "file://test.txt",
1404 "name": "Test File"
1405 }
1406 ]
1407 }),
1408 ]))
1409 .await
1410 .unwrap();
1411
1412 client.initialize("test-client", "1.0.0").await.unwrap();
1413 let resources = client.list_resources().await.unwrap();
1414
1415 assert_eq!(resources.resources.len(), 1);
1416 assert_eq!(resources.resources[0].uri, "file://test.txt");
1417 }
1418
1419 #[tokio::test]
1420 async fn test_read_resource() {
1421 let client = McpClient::connect(MockTransport::with_responses(vec![
1422 mock_initialize_response(),
1423 serde_json::json!({
1424 "contents": [
1425 {
1426 "uri": "file://test.txt",
1427 "text": "File contents"
1428 }
1429 ]
1430 }),
1431 ]))
1432 .await
1433 .unwrap();
1434
1435 client.initialize("test-client", "1.0.0").await.unwrap();
1436 let result = client.read_resource("file://test.txt").await.unwrap();
1437
1438 assert_eq!(result.contents.len(), 1);
1439 assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
1440 }
1441
1442 #[tokio::test]
1443 async fn test_list_prompts() {
1444 let client = McpClient::connect(MockTransport::with_responses(vec![
1445 mock_initialize_response(),
1446 serde_json::json!({
1447 "prompts": [
1448 {
1449 "name": "test_prompt",
1450 "description": "A test prompt"
1451 }
1452 ]
1453 }),
1454 ]))
1455 .await
1456 .unwrap();
1457
1458 client.initialize("test-client", "1.0.0").await.unwrap();
1459 let prompts = client.list_prompts().await.unwrap();
1460
1461 assert_eq!(prompts.prompts.len(), 1);
1462 assert_eq!(prompts.prompts[0].name, "test_prompt");
1463 }
1464
1465 #[tokio::test]
1466 async fn test_get_prompt() {
1467 let client = McpClient::connect(MockTransport::with_responses(vec![
1468 mock_initialize_response(),
1469 serde_json::json!({
1470 "messages": [
1471 {
1472 "role": "user",
1473 "content": {
1474 "type": "text",
1475 "text": "Prompt message"
1476 }
1477 }
1478 ]
1479 }),
1480 ]))
1481 .await
1482 .unwrap();
1483
1484 client.initialize("test-client", "1.0.0").await.unwrap();
1485 let result = client.get_prompt("test_prompt", None).await.unwrap();
1486
1487 assert_eq!(result.messages.len(), 1);
1488 }
1489
1490 #[tokio::test]
1491 async fn test_ping() {
1492 let client = McpClient::connect(MockTransport::with_responses(vec![
1493 mock_initialize_response(),
1494 serde_json::json!({}),
1495 ]))
1496 .await
1497 .unwrap();
1498
1499 client.initialize("test-client", "1.0.0").await.unwrap();
1500 let result = client.ping().await;
1501 assert!(result.is_ok());
1502 }
1503
1504 #[tokio::test]
1505 async fn test_with_roots() {
1506 let roots = vec![Root::new("file:///test")];
1507 let client = McpClient::builder()
1508 .with_roots(roots)
1509 .connect_simple(MockTransport::with_responses(vec![]))
1510 .await
1511 .unwrap();
1512
1513 let current_roots = client.roots().await;
1514 assert_eq!(current_roots.len(), 1);
1515 }
1516
1517 #[tokio::test]
1518 async fn test_roots_management() {
1519 let client = McpClient::connect(MockTransport::with_responses(vec![
1520 mock_initialize_response(),
1521 ]))
1522 .await
1523 .unwrap();
1524
1525 assert!(client.roots().await.is_empty());
1527
1528 client.add_root(Root::new("file:///project")).await.unwrap();
1530 assert_eq!(client.roots().await.len(), 1);
1531
1532 client.initialize("test-client", "1.0.0").await.unwrap();
1534
1535 let removed = client.remove_root("file:///project").await.unwrap();
1537 assert!(removed);
1538 assert!(client.roots().await.is_empty());
1539
1540 let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
1542 assert!(!not_removed);
1543 }
1544
1545 #[tokio::test]
1546 async fn test_list_roots() {
1547 let roots = vec![
1548 Root::new("file:///project1"),
1549 Root::with_name("file:///project2", "Project 2"),
1550 ];
1551 let client = McpClient::builder()
1552 .with_roots(roots)
1553 .connect_simple(MockTransport::with_responses(vec![]))
1554 .await
1555 .unwrap();
1556
1557 let result = client.list_roots().await;
1558 assert_eq!(result.roots.len(), 2);
1559 assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
1560 }
1561
1562 #[test]
1563 fn test_builder_with_sampling() {
1564 let builder = McpClientBuilder::new().with_sampling();
1565 assert!(builder.capabilities.sampling.is_some());
1566 }
1567
1568 #[test]
1569 fn test_builder_with_elicitation() {
1570 let builder = McpClientBuilder::new().with_elicitation();
1571 assert!(builder.capabilities.elicitation.is_some());
1572 }
1573
1574 #[test]
1575 fn test_builder_chaining() {
1576 let builder = McpClientBuilder::new()
1577 .with_sampling()
1578 .with_elicitation()
1579 .with_roots(vec![Root::new("file:///project")]);
1580 assert!(builder.capabilities.sampling.is_some());
1581 assert!(builder.capabilities.elicitation.is_some());
1582 assert!(builder.capabilities.roots.is_some());
1583 }
1584
1585 #[tokio::test]
1586 async fn test_bidirectional_sampling_round_trip() {
1587 use crate::protocol::{
1588 ContentRole, CreateMessageParams, CreateMessageResult, SamplingContent,
1589 SamplingContentOrArray,
1590 };
1591
1592 struct RecordingHandler {
1594 called: Arc<AtomicBool>,
1595 }
1596
1597 #[async_trait]
1598 impl ClientHandler for RecordingHandler {
1599 async fn handle_create_message(
1600 &self,
1601 _params: CreateMessageParams,
1602 ) -> std::result::Result<CreateMessageResult, tower_mcp_types::JsonRpcError>
1603 {
1604 self.called.store(true, Ordering::SeqCst);
1605 Ok(CreateMessageResult {
1606 content: SamplingContentOrArray::Single(SamplingContent::Text {
1607 text: "test response".to_string(),
1608 annotations: None,
1609 meta: None,
1610 }),
1611 model: "test-model".to_string(),
1612 role: ContentRole::Assistant,
1613 stop_reason: Some("end_turn".to_string()),
1614 meta: None,
1615 })
1616 }
1617 }
1618
1619 let called = Arc::new(AtomicBool::new(false));
1620 let handler = RecordingHandler {
1621 called: called.clone(),
1622 };
1623
1624 let (inject_tx, rx) = mpsc::channel::<String>(32);
1627 let responses = vec![mock_initialize_response()];
1628 let inject_tx_clone = inject_tx.clone();
1629
1630 let transport = MockTransport {
1631 responses: Arc::new(Mutex::new(responses)),
1632 response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1633 incoming_tx: inject_tx,
1634 incoming_rx: rx,
1635 outgoing: Arc::new(Mutex::new(Vec::new())),
1636 connected: Arc::new(AtomicBool::new(true)),
1637 };
1638
1639 let client = McpClient::builder()
1640 .with_sampling()
1641 .connect(transport, handler)
1642 .await
1643 .unwrap();
1644
1645 client.initialize("test-client", "1.0.0").await.unwrap();
1647
1648 let sampling_request = serde_json::json!({
1650 "jsonrpc": "2.0",
1651 "id": 100,
1652 "method": "sampling/createMessage",
1653 "params": {
1654 "messages": [
1655 {
1656 "role": "user",
1657 "content": {
1658 "type": "text",
1659 "text": "Hello"
1660 }
1661 }
1662 ],
1663 "maxTokens": 100
1664 }
1665 });
1666 inject_tx_clone
1667 .send(sampling_request.to_string())
1668 .await
1669 .unwrap();
1670
1671 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1673
1674 assert!(
1676 called.load(Ordering::SeqCst),
1677 "handle_create_message should have been called"
1678 );
1679 }
1680
1681 #[tokio::test]
1682 async fn test_list_resource_templates() {
1683 let client = McpClient::connect(MockTransport::with_responses(vec![
1684 mock_initialize_response(),
1685 serde_json::json!({
1686 "resourceTemplates": [
1687 {
1688 "uriTemplate": "file:///{path}",
1689 "name": "File Template",
1690 "description": "A file template"
1691 }
1692 ]
1693 }),
1694 ]))
1695 .await
1696 .unwrap();
1697
1698 client.initialize("test-client", "1.0.0").await.unwrap();
1699 let result = client.list_resource_templates().await.unwrap();
1700
1701 assert_eq!(result.resource_templates.len(), 1);
1702 assert_eq!(result.resource_templates[0].name, "File Template");
1703 }
1704
1705 #[tokio::test]
1706 async fn test_list_all_tools_single_page() {
1707 let client = McpClient::connect(MockTransport::with_responses(vec![
1708 mock_initialize_response(),
1709 serde_json::json!({
1710 "tools": [
1711 {
1712 "name": "tool_a",
1713 "description": "Tool A",
1714 "inputSchema": { "type": "object", "properties": {} }
1715 },
1716 {
1717 "name": "tool_b",
1718 "description": "Tool B",
1719 "inputSchema": { "type": "object", "properties": {} }
1720 }
1721 ]
1722 }),
1723 ]))
1724 .await
1725 .unwrap();
1726
1727 client.initialize("test-client", "1.0.0").await.unwrap();
1728 let tools = client.list_all_tools().await.unwrap();
1729
1730 assert_eq!(tools.len(), 2);
1731 assert_eq!(tools[0].name, "tool_a");
1732 assert_eq!(tools[1].name, "tool_b");
1733 }
1734
1735 #[tokio::test]
1736 async fn test_list_all_tools_paginated() {
1737 let client = McpClient::connect(MockTransport::with_responses(vec![
1738 mock_initialize_response(),
1739 serde_json::json!({
1741 "tools": [
1742 {
1743 "name": "tool_a",
1744 "description": "Tool A",
1745 "inputSchema": { "type": "object", "properties": {} }
1746 }
1747 ],
1748 "nextCursor": "page2"
1749 }),
1750 serde_json::json!({
1752 "tools": [
1753 {
1754 "name": "tool_b",
1755 "description": "Tool B",
1756 "inputSchema": { "type": "object", "properties": {} }
1757 }
1758 ]
1759 }),
1760 ]))
1761 .await
1762 .unwrap();
1763
1764 client.initialize("test-client", "1.0.0").await.unwrap();
1765 let tools = client.list_all_tools().await.unwrap();
1766
1767 assert_eq!(tools.len(), 2);
1768 assert_eq!(tools[0].name, "tool_a");
1769 assert_eq!(tools[1].name, "tool_b");
1770 }
1771
1772 #[tokio::test]
1773 async fn test_call_tool_text_success() {
1774 let client = McpClient::connect(MockTransport::with_responses(vec![
1775 mock_initialize_response(),
1776 serde_json::json!({
1777 "content": [
1778 { "type": "text", "text": "Hello " },
1779 { "type": "text", "text": "World" }
1780 ]
1781 }),
1782 ]))
1783 .await
1784 .unwrap();
1785
1786 client.initialize("test-client", "1.0.0").await.unwrap();
1787 let text = client
1788 .call_tool_text("test_tool", serde_json::json!({}))
1789 .await
1790 .unwrap();
1791
1792 assert_eq!(text, "Hello World");
1793 }
1794
1795 #[tokio::test]
1796 async fn test_call_tool_text_error() {
1797 let client = McpClient::connect(MockTransport::with_responses(vec![
1798 mock_initialize_response(),
1799 serde_json::json!({
1800 "content": [
1801 { "type": "text", "text": "something went wrong" }
1802 ],
1803 "isError": true
1804 }),
1805 ]))
1806 .await
1807 .unwrap();
1808
1809 client.initialize("test-client", "1.0.0").await.unwrap();
1810 let result = client
1811 .call_tool_text("test_tool", serde_json::json!({}))
1812 .await;
1813
1814 assert!(result.is_err());
1815 let err = result.unwrap_err();
1816 assert!(
1817 err.to_string().contains("something went wrong"),
1818 "Error message should contain tool error text, got: {}",
1819 err
1820 );
1821 }
1822
1823 #[tokio::test]
1824 async fn test_server_notification_parsing() {
1825 let notification = parse_server_notification("notifications/tools/list_changed", None);
1826 assert!(matches!(notification, ServerNotification::ToolsListChanged));
1827
1828 let notification = parse_server_notification("notifications/resources/list_changed", None);
1829 assert!(matches!(
1830 notification,
1831 ServerNotification::ResourcesListChanged
1832 ));
1833
1834 let notification = parse_server_notification(
1835 "notifications/resources/updated",
1836 Some(serde_json::json!({"uri": "file:///test"})),
1837 );
1838 match notification {
1839 ServerNotification::ResourceUpdated { uri } => {
1840 assert_eq!(uri, "file:///test");
1841 }
1842 _ => panic!("Expected ResourceUpdated"),
1843 }
1844
1845 let notification =
1846 parse_server_notification("custom/notification", Some(serde_json::json!({"data": 42})));
1847 match notification {
1848 ServerNotification::Unknown { method, params } => {
1849 assert_eq!(method, "custom/notification");
1850 assert!(params.is_some());
1851 }
1852 _ => panic!("Expected Unknown"),
1853 }
1854 }
1855}