Skip to main content

mcp_utils/client/
mcp_client.rs

1// Don't use custom Result type here as we need to return rmcp::ErrorData
2use rmcp::{
3    ClientHandler, RoleClient,
4    handler::client::progress::ProgressDispatcher,
5    model::{
6        ClientInfo, CreateElicitationRequestParams, CreateElicitationResult, ElicitationAction,
7        ElicitationResponseNotificationParam, ErrorData, ListRootsResult, ProgressNotificationParam,
8    },
9    service::{NotificationContext, RequestContext},
10};
11use std::result::Result;
12use std::sync::Arc;
13use tokio::sync::{RwLock, mpsc, oneshot};
14
15use crate::client::{ElicitationRequest, McpClientEvent};
16use rmcp::model::Root;
17
18pub struct McpClient {
19    client_info: ClientInfo,
20    server_name: String,
21    pub progress_dispatcher: ProgressDispatcher,
22    event_sender: mpsc::Sender<McpClientEvent>,
23    /// Roots advertised to MCP servers
24    roots: Arc<RwLock<Vec<Root>>>,
25}
26
27impl McpClient {
28    pub fn new(
29        client_info: ClientInfo,
30        server_name: String,
31        event_sender: mpsc::Sender<McpClientEvent>,
32        roots: Arc<RwLock<Vec<Root>>>,
33    ) -> Self {
34        Self { client_info, server_name, progress_dispatcher: ProgressDispatcher::new(), event_sender, roots }
35    }
36
37    pub fn server_name(&self) -> &str {
38        &self.server_name
39    }
40
41    /// Dispatch an elicitation request through the shared event channel.
42    ///
43    /// Used by both the `create_elicitation` handler and the `-32042`
44    /// `URL_ELICITATION_REQUIRED` error path to ensure the same user-facing flow.
45    pub async fn dispatch_elicitation(&self, request: CreateElicitationRequestParams) -> CreateElicitationResult {
46        let (response_tx, response_rx) = oneshot::channel();
47        let elicitation_request =
48            ElicitationRequest { server_name: self.server_name.clone(), request, response_sender: response_tx };
49
50        if self.event_sender.send(McpClientEvent::Elicitation(elicitation_request)).await.is_err() {
51            return cancel_result();
52        }
53        response_rx.await.unwrap_or_else(|_| cancel_result())
54    }
55
56    /// Forward a URL elicitation completion through the shared event channel.
57    ///
58    /// Split out from `on_url_elicitation_notification_complete` so it can be
59    /// tested without constructing a `NotificationContext`.
60    pub async fn forward_url_elicitation_complete(&self, elicitation_id: String) {
61        let event = McpClientEvent::UrlElicitationComplete(super::UrlElicitationCompleteParams {
62            server_name: self.server_name.clone(),
63            elicitation_id,
64        });
65        if self.event_sender.send(event).await.is_err() {
66            tracing::warn!("Failed to forward URL elicitation completion: receiver dropped");
67        }
68    }
69}
70
71pub fn cancel_result() -> CreateElicitationResult {
72    CreateElicitationResult { action: ElicitationAction::Cancel, content: None, meta: Option::default() }
73}
74
75impl ClientHandler for McpClient {
76    fn get_info(&self) -> ClientInfo {
77        self.client_info.clone()
78    }
79
80    async fn on_progress(&self, params: ProgressNotificationParam, _context: NotificationContext<RoleClient>) -> () {
81        self.progress_dispatcher.handle_notification(params).await;
82    }
83
84    async fn create_elicitation(
85        &self,
86        request: CreateElicitationRequestParams,
87        _context: RequestContext<RoleClient>,
88    ) -> Result<CreateElicitationResult, ErrorData> {
89        Ok(self.dispatch_elicitation(request).await)
90    }
91
92    async fn on_url_elicitation_notification_complete(
93        &self,
94        params: ElicitationResponseNotificationParam,
95        _context: NotificationContext<RoleClient>,
96    ) {
97        self.forward_url_elicitation_complete(params.elicitation_id).await;
98    }
99
100    async fn list_roots(&self, _context: RequestContext<RoleClient>) -> Result<ListRootsResult, ErrorData> {
101        let roots = self.roots.read().await;
102
103        Ok(ListRootsResult::new(roots.clone()))
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{
        ClientCapabilities, ElicitationSchema, FormElicitationCapability, Implementation, UrlElicitationCapability,
    };
    use std::collections::BTreeMap;

    /// Client info advertising elicitation (form + URL) and roots support.
    fn test_client_info() -> ClientInfo {
        let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
        if let Some(elicitation) = capabilities.elicitation.as_mut() {
            elicitation.form = Some(FormElicitationCapability::default());
            elicitation.url = Some(UrlElicitationCapability::default());
        }
        ClientInfo::new(capabilities, Implementation::new("test", "0.1.0"))
    }

    /// Client bound to `"test-server"` with an empty root list.
    fn make_client(event_sender: mpsc::Sender<McpClientEvent>) -> McpClient {
        McpClient::new(test_client_info(), "test-server".to_string(), event_sender, Arc::new(RwLock::new(Vec::new())))
    }

    /// Minimal form elicitation request shared by the cancellation tests.
    fn form_request() -> CreateElicitationRequestParams {
        CreateElicitationRequestParams::FormElicitationParams {
            meta: None,
            message: "test".to_string(),
            requested_schema: ElicitationSchema::new(BTreeMap::new()),
        }
    }

    fn unwrap_elicitation(event: McpClientEvent) -> ElicitationRequest {
        match event {
            McpClientEvent::Elicitation(req) => req,
            other @ McpClientEvent::UrlElicitationComplete(_) => panic!("expected Elicitation, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch_elicitation_closed_event_channel_returns_cancel() {
        let (event_tx, event_rx) = mpsc::channel(1);
        // Nobody listens for events, so the request can never be delivered.
        drop(event_rx);
        let client = make_client(event_tx);

        let result = client.dispatch_elicitation(form_request()).await;
        assert_eq!(result.action, ElicitationAction::Cancel, "closed event channel should return Cancel, not Decline");
        assert!(result.content.is_none());
    }

    #[tokio::test]
    async fn dispatch_elicitation_dropped_responder_returns_cancel() {
        let (event_tx, mut event_rx) = mpsc::channel(1);
        let client = make_client(event_tx);

        let handle = tokio::spawn(async move {
            let event = event_rx.recv().await.unwrap();
            let elicitation = unwrap_elicitation(event);
            // Receive the request but drop the responder without answering.
            drop(elicitation.response_sender);
        });

        let result = client.dispatch_elicitation(form_request()).await;
        handle.await.unwrap();

        assert_eq!(result.action, ElicitationAction::Cancel, "dropped response sender should return Cancel, not Decline");
        assert!(result.content.is_none());
    }

    #[tokio::test]
    async fn dispatch_elicitation_forwards_request_with_server_name() {
        let (event_tx, mut event_rx) = mpsc::channel(1);
        let client = make_client(event_tx);

        let request = CreateElicitationRequestParams::UrlElicitationParams {
            meta: None,
            message: "Auth".to_string(),
            url: "https://example.com/auth".to_string(),
            elicitation_id: "el-123".to_string(),
        };

        let handle = tokio::spawn(async move {
            let event = event_rx.recv().await.unwrap();
            let elicitation = unwrap_elicitation(event);
            assert_eq!(elicitation.server_name, "test-server");
            assert!(
                matches!(
                    &elicitation.request,
                    CreateElicitationRequestParams::UrlElicitationParams { url, elicitation_id, .. }
                        if url == "https://example.com/auth" && elicitation_id == "el-123"
                ),
                "request payload should be forwarded unchanged"
            );
            let _ = elicitation.response_sender.send(CreateElicitationResult {
                action: ElicitationAction::Accept,
                content: None,
                meta: Option::default(),
            });
        });

        let result = client.dispatch_elicitation(request).await;
        handle.await.unwrap();
        assert_eq!(result.action, ElicitationAction::Accept);
    }

    #[tokio::test]
    async fn forward_url_elicitation_complete_uses_server_name_and_id() {
        let (event_tx, mut event_rx) = mpsc::channel(1);
        let client = make_client(event_tx);

        client.forward_url_elicitation_complete("el-456".to_string()).await;

        let event = event_rx.recv().await.unwrap();
        match event {
            McpClientEvent::UrlElicitationComplete(params) => {
                assert_eq!(params.server_name, "test-server");
                assert_eq!(params.elicitation_id, "el-456");
            }
            other @ McpClientEvent::Elicitation(_) => panic!("expected UrlElicitationComplete, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn forward_url_elicitation_complete_swallows_dropped_receiver() {
        let (event_tx, event_rx) = mpsc::channel(1);
        drop(event_rx);
        let client = make_client(event_tx);

        // Should not panic even though the receiver is dropped.
        client.forward_url_elicitation_complete("el-gone".to_string()).await;
    }

    #[test]
    fn capabilities_include_form_and_url() {
        let info = test_client_info();
        let caps = &info.capabilities;
        let elicitation = caps.elicitation.as_ref().expect("elicitation capability should be set");
        assert!(elicitation.form.is_some(), "form capability should be advertised");
        assert!(elicitation.url.is_some(), "url capability should be advertised");
    }
}