1use mcp_utils::client::{
2 McpClient, McpConnectAttempt, McpConnectionAttemptManager, McpError, McpManager, McpServer, McpServerStatusEntry,
3};
4use mcp_utils::display_meta::ToolResultMeta;
5
6use futures::future::Either;
7use futures::stream::{self, StreamExt};
8use llm::{ToolCallError, ToolCallRequest, ToolCallResult};
9use rmcp::RoleClient;
10use rmcp::model::{
11 CallToolRequestParams, CreateElicitationRequestParams, ErrorCode, GetPromptResult, ProgressNotificationParam,
12 Prompt,
13};
14use rmcp::service::RunningService;
15use std::collections::HashSet;
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::select;
19use tokio::sync::mpsc;
20use tokio::sync::oneshot;
21
22#[derive(Debug)]
24pub enum ToolExecutionEvent {
25 Started { tool_id: String, tool_name: String },
26 Progress { tool_id: String, progress: ProgressNotificationParam },
27 Complete { tool_id: String, result: Result<ToolCallResult, ToolCallError>, result_meta: Option<ToolResultMeta> },
28}
29
/// Upper bound on an interactive server authentication attempt before it is reported as failed.
const MCP_AUTH_TIMEOUT: Duration = Duration::from_secs(3 * 60);
31
32#[derive(Debug)]
34pub enum McpCommand {
35 ExecuteTool {
36 request: ToolCallRequest,
37 timeout: Duration,
38 tx: mpsc::Sender<ToolExecutionEvent>,
39 },
40 ListPrompts {
41 tx: oneshot::Sender<Result<Vec<Prompt>, String>>,
42 },
43 GetPrompt {
44 name: String,
45 arguments: Option<serde_json::Map<String, serde_json::Value>>,
46 tx: oneshot::Sender<Result<GetPromptResult, String>>,
47 },
48 GetServerStatuses {
49 tx: oneshot::Sender<Vec<McpServerStatusEntry>>,
50 },
51 AuthenticateServer {
52 name: String,
53 },
54}
55
56pub async fn run_mcp_task(
57 mut mcp: McpManager,
58 mut command_rx: mpsc::Receiver<McpCommand>,
59 pending_servers: Vec<McpServer>,
60) {
61 let mut mcp_connection_attempts = McpConnectionAttemptManager::default();
62 let mut pending_connections: HashSet<String> = pending_servers.iter().map(|server| server.name.clone()).collect();
63 for server in pending_servers {
64 let name = server.name.clone();
65 let task = mcp.connect_pending_task(server);
66 mcp_connection_attempts.spawn(name, task);
67 }
68 if pending_connections.is_empty() {
69 mcp.emit_connection_ready().await;
70 }
71
72 loop {
73 select! {
74 command = command_rx.recv() => {
75 let Some(command) = command else { break; };
76 on_command(command, &mut mcp, &mut mcp_connection_attempts).await;
77 }
78
79 Some(joined) = mcp_connection_attempts.join_next(), if !mcp_connection_attempts.is_empty() => {
80 match joined {
81 Ok(attempt) => {
82 let was_bootstrap = pending_connections.remove(&attempt.name);
83 mcp.apply_connection_attempt(attempt).await;
84 if was_bootstrap && pending_connections.is_empty() {
85 mcp.emit_connection_ready().await;
86 }
87 }
88 Err(e) => tracing::error!("MCP auth task did not complete normally: {e:?}"),
89 }
90 }
91 }
92 }
93
94 mcp_connection_attempts.shutdown().await;
95 mcp.shutdown().await;
96 tracing::debug!("MCP manager task ended");
97}
98
99async fn on_command(command: McpCommand, mcp: &mut McpManager, auth_tasks: &mut McpConnectionAttemptManager) {
100 match command {
101 McpCommand::ExecuteTool { request, timeout, tx } => {
102 let tool_id = request.id.clone();
103 let tool_name = request.name.clone();
104
105 let _ =
106 tx.send(ToolExecutionEvent::Started { tool_id: tool_id.clone(), tool_name: tool_name.clone() }).await;
107
108 match mcp.get_client_for_tool(&request.name, &request.arguments) {
109 Ok((client, params)) => {
110 tokio::spawn(async move {
111 let outcome =
112 execute_mcp_call(client, &request, params, timeout, tool_id.clone(), tx.clone()).await;
113 let (result, result_meta) = match outcome {
114 Ok((r, m)) => (Ok(r), m),
115 Err(e) => (Err(e), None),
116 };
117 let _ = tx.send(ToolExecutionEvent::Complete { tool_id, result, result_meta }).await;
118 });
119 }
120 Err(e) => {
121 tracing::error!("Failed to get client for tool {}: {e}", request.name);
122 let error = ToolCallError::from_request(&request, format!("Failed to get client: {e}"));
123 let _ =
124 tx.send(ToolExecutionEvent::Complete { tool_id, result: Err(error), result_meta: None }).await;
125 }
126 }
127 }
128
129 McpCommand::ListPrompts { tx } => {
130 let result = mcp.list_prompts().await.map_err(|e| format!("Failed to list prompts: {e}"));
131 let _ = tx.send(result);
132 }
133
134 McpCommand::GetPrompt { name: namespaced_name, arguments, tx } => {
135 let result =
136 mcp.get_prompt(&namespaced_name, arguments).await.map_err(|e| format!("Failed to get prompt: {e}"));
137 let _ = tx.send(result);
138 }
139
140 McpCommand::GetServerStatuses { tx } => {
141 let _ = tx.send(mcp.server_statuses());
142 }
143
144 McpCommand::AuthenticateServer { name } => match mcp.authenticate_server_task(&name).await {
145 Ok(task) => {
146 let server_name = name.clone();
147 auth_tasks.spawn(name, async move {
148 match tokio::time::timeout(MCP_AUTH_TIMEOUT, task).await {
149 Ok(attempt) => attempt,
150 Err(_) => McpConnectAttempt::failed(
151 server_name,
152 McpError::ConnectionFailed("authentication timed out after 3 minutes".to_string()),
153 false,
154 ),
155 }
156 });
157 }
158 Err(e) => tracing::warn!("Authentication failed for '{name}': {e}"),
159 },
160 }
161}
162
163async fn execute_mcp_call(
166 client: Arc<RunningService<RoleClient, McpClient>>,
167 request: &ToolCallRequest,
168 params: CallToolRequestParams,
169 timeout: Duration,
170 tool_call_id: String,
171 event_tx: mpsc::Sender<ToolExecutionEvent>,
172) -> Result<(ToolCallResult, Option<ToolResultMeta>), ToolCallError> {
173 use super::tool_bridge::mcp_result_to_tool_call_result;
174 use rmcp::model::{ClientRequest::CallToolRequest, Request, ServerResult};
175 use rmcp::service::PeerRequestOptions;
176
177 let handle = client
178 .send_cancellable_request(CallToolRequest(Request::new(params)), {
179 let mut opts = PeerRequestOptions::default();
180 opts.timeout = Some(timeout);
181 opts
182 })
183 .await
184 .map_err(|e| ToolCallError::from_request(request, format!("Failed to send tool request: {e}")))?;
185
186 let progress_subscriber = client.service().progress_dispatcher.subscribe(handle.progress_token.clone()).await;
187
188 let progress_stream = progress_subscriber
189 .map(move |progress| Either::Left(ToolExecutionEvent::Progress { tool_id: tool_call_id.clone(), progress }));
190
191 let result_stream = stream::once(handle.await_response()).map(Either::Right);
192 let combined_stream = stream::select(progress_stream, result_stream);
193 tokio::pin!(combined_stream);
194
195 let server_result = loop {
196 match combined_stream.next().await {
197 Some(Either::Left(progress_event)) => {
198 let _ = event_tx.send(progress_event).await;
199 }
200 Some(Either::Right(result)) => {
201 break match result {
202 Ok(server_result) => server_result,
203 Err(e) => {
204 if let rmcp::service::ServiceError::McpError(ref error_data) = e
205 && error_data.code == ErrorCode::URL_ELICITATION_REQUIRED
206 {
207 return Err(handle_url_elicitation_required(&client, request, error_data).await);
208 }
209 return Err(ToolCallError::from_request(request, format!("Tool execution failed: {e}")));
210 }
211 };
212 }
213 None => {
214 return Err(ToolCallError::from_request(request, "Stream ended without result"));
215 }
216 }
217 };
218
219 let ServerResult::CallToolResult(mcp_result) = server_result else {
220 return Err(ToolCallError::from_request(request, "Unexpected response type from MCP server"));
221 };
222
223 mcp_result_to_tool_call_result(request, mcp_result)
224}
225
226#[derive(serde::Deserialize)]
227struct UrlElicitationRequiredData {
228 elicitations: Vec<CreateElicitationRequestParams>,
229}
230
231#[derive(Debug)]
232enum UrlElicitationRequiredParseError {
233 MissingData,
234 InvalidData(serde_json::Error),
235 NoUrlRequests,
236}
237
238impl std::fmt::Display for UrlElicitationRequiredParseError {
239 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 match self {
241 Self::MissingData => write!(f, "missing error data"),
242 Self::InvalidData(error) => write!(f, "malformed error data: {error}"),
243 Self::NoUrlRequests => write!(f, "provided no URL elicitation requests"),
244 }
245 }
246}
247
248fn parse_required_url_elicitations(
249 error_data: &rmcp::model::ErrorData,
250) -> Result<Vec<CreateElicitationRequestParams>, UrlElicitationRequiredParseError> {
251 let data = error_data.data.as_ref().ok_or(UrlElicitationRequiredParseError::MissingData)?;
252 let parsed: UrlElicitationRequiredData =
253 serde_json::from_value(data.clone()).map_err(UrlElicitationRequiredParseError::InvalidData)?;
254
255 let url_elicitations = parsed
256 .elicitations
257 .into_iter()
258 .filter(|elicitation| matches!(elicitation, CreateElicitationRequestParams::UrlElicitationParams { .. }))
259 .collect::<Vec<_>>();
260
261 if url_elicitations.is_empty() {
262 return Err(UrlElicitationRequiredParseError::NoUrlRequests);
263 }
264
265 Ok(url_elicitations)
266}
267
268async fn handle_url_elicitation_required(
272 client: &Arc<RunningService<RoleClient, McpClient>>,
273 request: &ToolCallRequest,
274 error_data: &rmcp::model::ErrorData,
275) -> ToolCallError {
276 let server_name = client.service().server_name().to_string();
277 let url_elicitations = match parse_required_url_elicitations(error_data) {
278 Ok(url_elicitations) => url_elicitations,
279 Err(UrlElicitationRequiredParseError::NoUrlRequests) => {
280 return ToolCallError::from_request(
281 request,
282 format!("Server '{server_name}' requires URL elicitation but provided no URL elicitation requests"),
283 );
284 }
285 Err(parse_error) => {
286 return ToolCallError::from_request(
287 request,
288 format!("Server '{server_name}' sent an invalid URL elicitation response: {parse_error}"),
289 );
290 }
291 };
292
293 tracing::info!("Server '{server_name}' requires {} URL elicitation(s)", url_elicitations.len());
294
295 for elicitation in url_elicitations {
296 let result = client.service().dispatch_elicitation(elicitation).await;
297 match result.action {
298 rmcp::model::ElicitationAction::Decline => {
299 return ToolCallError::from_request(
300 request,
301 format!("Required browser interaction for server '{server_name}' was declined"),
302 );
303 }
304 rmcp::model::ElicitationAction::Cancel => {
305 return ToolCallError::from_request(
306 request,
307 format!("Required browser interaction for server '{server_name}' was cancelled"),
308 );
309 }
310 rmcp::model::ElicitationAction::Accept => {
311 tracing::info!("User accepted URL elicitation for server '{server_name}'");
312 }
313 }
314 }
315
316 ToolCallError::from_request(
317 request,
318 format!(
319 "Server '{server_name}' requires a browser flow. The URL has been opened for your approval. Retry the previous request after completing the browser flow."
320 ),
321 )
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `URL_ELICITATION_REQUIRED` error carrying the given `data` payload.
    fn url_elicitation_error(data: Option<serde_json::Value>) -> rmcp::model::ErrorData {
        rmcp::model::ErrorData {
            code: rmcp::model::ErrorCode::URL_ELICITATION_REQUIRED,
            message: "URL elicitation required".into(),
            data,
        }
    }

    #[test]
    fn url_elicitation_required_data_parses_url_entries() {
        let data = serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth?elicitationId=el-1",
                    "elicitationId": "el-1"
                }
            ]
        });

        let parsed: UrlElicitationRequiredData = serde_json::from_value(data).unwrap();
        assert_eq!(parsed.elicitations.len(), 1);
        assert!(matches!(
            &parsed.elicitations[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_filters_to_url_only() {
        let error_data = url_elicitation_error(Some(serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth",
                    "elicitationId": "el-1"
                },
                {
                    "mode": "form",
                    "message": "Pick a color",
                    "requestedSchema": { "type": "object", "properties": {} }
                }
            ]
        })));

        let result = parse_required_url_elicitations(&error_data).unwrap();
        assert_eq!(result.len(), 1);
        assert!(matches!(
            &result[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_requires_data() {
        let error_data = url_elicitation_error(None);
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::MissingData)
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_malformed_data() {
        let error_data = url_elicitation_error(Some(serde_json::json!({ "elicitations": "not-a-list" })));
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::InvalidData(_))
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_form_only_payload() {
        let error_data = url_elicitation_error(Some(serde_json::json!({
            "elicitations": [
                {
                    "mode": "form",
                    "message": "Pick a color",
                    "requestedSchema": { "type": "object", "properties": {} }
                }
            ]
        })));
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::NoUrlRequests)
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_empty_list() {
        let error_data = url_elicitation_error(Some(serde_json::json!({ "elicitations": [] })));
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::NoUrlRequests)
        ));
    }
}