1use rmcp::{
3 ClientHandler, RoleClient,
4 handler::client::progress::ProgressDispatcher,
5 model::{
6 ClientInfo, CreateElicitationRequestParams, CreateElicitationResult, ElicitationAction,
7 ElicitationResponseNotificationParam, ErrorData, ListRootsResult, ProgressNotificationParam,
8 },
9 service::{NotificationContext, RequestContext},
10};
11use std::result::Result;
12use std::sync::Arc;
13use tokio::sync::{RwLock, mpsc, oneshot};
14
15use crate::client::{ElicitationRequest, McpClientEvent};
16use rmcp::model::Root;
17
18pub struct McpClient {
19 client_info: ClientInfo,
20 server_name: String,
21 pub progress_dispatcher: ProgressDispatcher,
22 event_sender: mpsc::Sender<McpClientEvent>,
23 roots: Arc<RwLock<Vec<Root>>>,
25}
26
27impl McpClient {
28 pub fn new(
29 client_info: ClientInfo,
30 server_name: String,
31 event_sender: mpsc::Sender<McpClientEvent>,
32 roots: Arc<RwLock<Vec<Root>>>,
33 ) -> Self {
34 Self { client_info, server_name, progress_dispatcher: ProgressDispatcher::new(), event_sender, roots }
35 }
36
37 pub fn server_name(&self) -> &str {
38 &self.server_name
39 }
40
41 pub async fn dispatch_elicitation(&self, request: CreateElicitationRequestParams) -> CreateElicitationResult {
46 let (response_tx, response_rx) = oneshot::channel();
47 let elicitation_request =
48 ElicitationRequest { server_name: self.server_name.clone(), request, response_sender: response_tx };
49
50 if self.event_sender.send(McpClientEvent::Elicitation(elicitation_request)).await.is_err() {
51 return cancel_result();
52 }
53 response_rx.await.unwrap_or_else(|_| cancel_result())
54 }
55
56 pub async fn forward_url_elicitation_complete(&self, elicitation_id: String) {
61 let event = McpClientEvent::UrlElicitationComplete(super::UrlElicitationCompleteParams {
62 server_name: self.server_name.clone(),
63 elicitation_id,
64 });
65 if self.event_sender.send(event).await.is_err() {
66 tracing::warn!("Failed to forward URL elicitation completion: receiver dropped");
67 }
68 }
69}
70
71pub fn cancel_result() -> CreateElicitationResult {
72 CreateElicitationResult { action: ElicitationAction::Cancel, content: None, meta: Option::default() }
73}
74
75impl ClientHandler for McpClient {
76 fn get_info(&self) -> ClientInfo {
77 self.client_info.clone()
78 }
79
80 async fn on_progress(&self, params: ProgressNotificationParam, _context: NotificationContext<RoleClient>) -> () {
81 self.progress_dispatcher.handle_notification(params).await;
82 }
83
84 async fn create_elicitation(
85 &self,
86 request: CreateElicitationRequestParams,
87 _context: RequestContext<RoleClient>,
88 ) -> Result<CreateElicitationResult, ErrorData> {
89 Ok(self.dispatch_elicitation(request).await)
90 }
91
92 async fn on_url_elicitation_notification_complete(
93 &self,
94 params: ElicitationResponseNotificationParam,
95 _context: NotificationContext<RoleClient>,
96 ) {
97 self.forward_url_elicitation_complete(params.elicitation_id).await;
98 }
99
100 async fn list_roots(&self, _context: RequestContext<RoleClient>) -> Result<ListRootsResult, ErrorData> {
101 let roots = self.roots.read().await;
102
103 Ok(ListRootsResult::new(roots.clone()))
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{
        ClientCapabilities, ElicitationSchema, FormElicitationCapability, Implementation, UrlElicitationCapability,
    };
    use std::collections::BTreeMap;

    /// Client info advertising roots plus both form and URL elicitation.
    fn test_client_info() -> ClientInfo {
        let mut capabilities = ClientCapabilities::builder().enable_elicitation().enable_roots().build();
        if let Some(elicitation) = capabilities.elicitation.as_mut() {
            elicitation.form = Some(FormElicitationCapability::default());
            elicitation.url = Some(UrlElicitationCapability::default());
        }
        ClientInfo::new(capabilities, Implementation::new("test", "0.1.0"))
    }

    /// Builds a client named "test-server" with an empty roots list.
    fn make_client(event_sender: mpsc::Sender<McpClientEvent>) -> McpClient {
        McpClient::new(test_client_info(), "test-server".to_string(), event_sender, Arc::new(RwLock::new(Vec::new())))
    }

    /// Minimal form elicitation request with an empty schema.
    fn form_request() -> CreateElicitationRequestParams {
        CreateElicitationRequestParams::FormElicitationParams {
            meta: None,
            message: "test".to_string(),
            requested_schema: ElicitationSchema::new(BTreeMap::new()),
        }
    }

    /// Extracts the elicitation request from an event, panicking on any other variant.
    fn unwrap_elicitation(event: McpClientEvent) -> ElicitationRequest {
        match event {
            McpClientEvent::Elicitation(req) => req,
            other @ McpClientEvent::UrlElicitationComplete(_) => panic!("expected Elicitation, got {other:?}"),
        }
    }

    /// The event channel's receiver is gone, so the request can never be delivered.
    #[tokio::test]
    async fn dispatch_elicitation_closed_event_channel_returns_cancel() {
        let (event_tx, event_rx) = mpsc::channel(1);
        drop(event_rx);
        let client = make_client(event_tx);

        let result = client.dispatch_elicitation(form_request()).await;
        assert_eq!(result.action, ElicitationAction::Cancel, "closed event channel should return Cancel, not Decline");
        assert!(result.content.is_none());
    }

    /// The request is delivered, but the consumer drops the reply channel without answering.
    #[tokio::test]
    async fn dispatch_elicitation_dropped_response_sender_returns_cancel() {
        let (event_tx, mut event_rx) = mpsc::channel(1);
        let client = make_client(event_tx);

        let handle = tokio::spawn(async move {
            let event = event_rx.recv().await.unwrap();
            let elicitation = unwrap_elicitation(event);
            drop(elicitation.response_sender);
        });

        let result = client.dispatch_elicitation(form_request()).await;
        handle.await.unwrap();

        assert_eq!(
            result.action,
            ElicitationAction::Cancel,
            "dropped response sender should return Cancel, not Decline"
        );
        assert!(result.content.is_none());
    }

    /// The forwarded event carries the client's server name and the reply is passed through.
    #[tokio::test]
    async fn dispatch_elicitation_forwards_request_with_server_name() {
        let (event_tx, mut event_rx) = mpsc::channel(1);
        let client = make_client(event_tx);

        let request = CreateElicitationRequestParams::UrlElicitationParams {
            meta: None,
            message: "Auth".to_string(),
            url: "https://example.com/auth".to_string(),
            elicitation_id: "el-123".to_string(),
        };

        let handle = tokio::spawn(async move {
            let event = event_rx.recv().await.unwrap();
            let elicitation = unwrap_elicitation(event);
            assert_eq!(elicitation.server_name, "test-server");
            let _ = elicitation.response_sender.send(CreateElicitationResult {
                action: ElicitationAction::Accept,
                content: None,
                meta: Option::default(),
            });
        });

        let result = client.dispatch_elicitation(request).await;
        handle.await.unwrap();
        assert_eq!(result.action, ElicitationAction::Accept);
    }

    /// The completion event carries both the server name and the elicitation id.
    #[tokio::test]
    async fn forward_url_elicitation_complete_uses_server_name_and_id() {
        let (event_tx, mut event_rx) = mpsc::channel(1);
        let client = make_client(event_tx);

        client.forward_url_elicitation_complete("el-456".to_string()).await;

        let event = event_rx.recv().await.unwrap();
        match event {
            McpClientEvent::UrlElicitationComplete(params) => {
                assert_eq!(params.server_name, "test-server");
                assert_eq!(params.elicitation_id, "el-456");
            }
            other @ McpClientEvent::Elicitation(_) => panic!("expected UrlElicitationComplete, got {other:?}"),
        }
    }

    /// With the event receiver gone, forwarding must return normally instead of panicking.
    #[tokio::test]
    async fn forward_url_elicitation_complete_swallows_dropped_receiver() {
        let (event_tx, event_rx) = mpsc::channel(1);
        drop(event_rx);
        let client = make_client(event_tx);

        // No assertion needed: the test passes as long as this call completes.
        client.forward_url_elicitation_complete("el-gone".to_string()).await;
    }

    /// Guards the test fixture itself: both elicitation modes must be advertised.
    #[test]
    fn capabilities_include_form_and_url() {
        let info = test_client_info();
        let caps = &info.capabilities;
        let elicitation = caps.elicitation.as_ref().expect("elicitation capability should be set");
        assert!(elicitation.form.is_some(), "form capability should be advertised");
        assert!(elicitation.url.is_some(), "url capability should be advertised");
    }
}