1use mcp_utils::client::{McpClient, McpConnectAttempt, McpError, McpManager, McpServerStatusEntry};
2use mcp_utils::display_meta::ToolResultMeta;
3
4use futures::future::Either;
5use futures::stream::{self, StreamExt};
6use llm::{ToolCallError, ToolCallRequest, ToolCallResult};
7use rmcp::RoleClient;
8use rmcp::model::{
9 CallToolRequestParams, CreateElicitationRequestParams, ErrorCode, GetPromptResult, ProgressNotificationParam,
10 Prompt,
11};
12use rmcp::service::RunningService;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::select;
16use tokio::sync::mpsc;
17use tokio::sync::oneshot;
18use tokio::task::JoinSet;
19
/// Lifecycle events for a single MCP tool call, sent on the call's event channel.
///
/// As produced by `on_command` / `execute_mcp_call`, a call emits `Started`
/// first, then zero or more `Progress` events, and finally `Complete`.
#[derive(Debug)]
pub enum ToolExecutionEvent {
    /// Emitted before the call is routed to a server.
    Started { tool_id: String, tool_name: String },
    /// A progress notification the server reported for this call.
    Progress { tool_id: String, progress: ProgressNotificationParam },
    /// Final outcome of the call; `result_meta` is always `None` when `result` is an error.
    Complete { tool_id: String, result: Result<ToolCallResult, ToolCallError>, result_meta: Option<ToolResultMeta> },
}
27
/// Upper bound on how long a background server authentication may run before
/// it is reported as a failed connection attempt.
const MCP_AUTH_TIMEOUT: Duration = Duration::from_mins(3);
29
/// Requests served by the MCP manager task (`run_mcp_task`).
#[derive(Debug)]
pub enum McpCommand {
    /// Run a tool call; lifecycle events are streamed back over `tx`.
    ExecuteTool {
        request: ToolCallRequest,
        /// Request timeout handed to the MCP peer for this call.
        timeout: Duration,
        tx: mpsc::Sender<ToolExecutionEvent>,
    },
    /// List prompts known to the manager; failures are reported as strings.
    ListPrompts {
        tx: oneshot::Sender<Result<Vec<Prompt>, String>>,
    },
    /// Fetch a prompt by its namespaced name, with optional arguments.
    GetPrompt {
        name: String,
        arguments: Option<serde_json::Map<String, serde_json::Value>>,
        tx: oneshot::Sender<Result<GetPromptResult, String>>,
    },
    /// Snapshot the manager's current server statuses.
    GetServerStatuses {
        tx: oneshot::Sender<Vec<McpServerStatusEntry>>,
    },
    /// Start authenticating the named server in the background. There is no
    /// reply channel: the resulting attempt is applied to the manager when the
    /// auth task finishes.
    AuthenticateServer {
        name: String,
    },
}
53
54pub async fn run_mcp_task(mut mcp: McpManager, mut command_rx: mpsc::Receiver<McpCommand>) {
55 let mut auth_tasks = JoinSet::<McpConnectAttempt>::new();
56 loop {
57 select! {
58 command = command_rx.recv() => {
59 let Some(command) = command else { break; };
60 on_command(command, &mut mcp, &mut auth_tasks).await;
61 }
62
63 Some(joined) = auth_tasks.join_next(), if !auth_tasks.is_empty() => {
64 match joined {
65 Ok(attempt) => mcp.apply_connection_attempt(attempt).await,
66 Err(e) => tracing::error!("MCP auth task did not complete normally: {e:?}"),
67 }
68 }
69 }
70 }
71
72 auth_tasks.abort_all();
73 while auth_tasks.join_next().await.is_some() {}
74 mcp.shutdown().await;
75 tracing::debug!("MCP manager task ended");
76}
77
78async fn on_command(command: McpCommand, mcp: &mut McpManager, auth_tasks: &mut JoinSet<McpConnectAttempt>) {
79 match command {
80 McpCommand::ExecuteTool { request, timeout, tx } => {
81 let tool_id = request.id.clone();
82 let tool_name = request.name.clone();
83
84 let _ =
85 tx.send(ToolExecutionEvent::Started { tool_id: tool_id.clone(), tool_name: tool_name.clone() }).await;
86
87 match mcp.get_client_for_tool(&request.name, &request.arguments) {
88 Ok((client, params)) => {
89 tokio::spawn(async move {
90 let outcome =
91 execute_mcp_call(client, &request, params, timeout, tool_id.clone(), tx.clone()).await;
92 let (result, result_meta) = match outcome {
93 Ok((r, m)) => (Ok(r), m),
94 Err(e) => (Err(e), None),
95 };
96 let _ = tx.send(ToolExecutionEvent::Complete { tool_id, result, result_meta }).await;
97 });
98 }
99 Err(e) => {
100 tracing::error!("Failed to get client for tool {}: {e}", request.name);
101 let error = ToolCallError::from_request(&request, format!("Failed to get client: {e}"));
102 let _ =
103 tx.send(ToolExecutionEvent::Complete { tool_id, result: Err(error), result_meta: None }).await;
104 }
105 }
106 }
107
108 McpCommand::ListPrompts { tx } => {
109 let result = mcp.list_prompts().await.map_err(|e| format!("Failed to list prompts: {e}"));
110 let _ = tx.send(result);
111 }
112
113 McpCommand::GetPrompt { name: namespaced_name, arguments, tx } => {
114 let result =
115 mcp.get_prompt(&namespaced_name, arguments).await.map_err(|e| format!("Failed to get prompt: {e}"));
116 let _ = tx.send(result);
117 }
118
119 McpCommand::GetServerStatuses { tx } => {
120 let _ = tx.send(mcp.server_statuses().to_vec());
121 }
122
123 McpCommand::AuthenticateServer { name } => match mcp.authenticate_server_task(&name).await {
124 Ok(task) => {
125 let server_name = name.clone();
126 auth_tasks.spawn(async move {
127 match tokio::time::timeout(MCP_AUTH_TIMEOUT, task).await {
128 Ok(attempt) => attempt,
129 Err(_) => McpConnectAttempt::failed(
130 server_name,
131 McpError::ConnectionFailed("authentication timed out after 3 minutes".to_string()),
132 false,
133 ),
134 }
135 });
136 }
137 Err(e) => tracing::warn!("Authentication failed for '{name}': {e}"),
138 },
139 }
140}
141
/// Sends a `CallToolRequest` with `params` to `client` and waits for the response.
///
/// `timeout` is applied to the request through `PeerRequestOptions`. While the
/// request is in flight, progress notifications for its progress token are
/// forwarded to `event_tx` as [`ToolExecutionEvent::Progress`], tagged with
/// `tool_call_id`. Forwarding is best-effort: a dropped receiver does not abort
/// the call.
///
/// # Errors
/// Returns a [`ToolCallError`] when:
/// - the request cannot be sent;
/// - the server answers with an error (a `URL_ELICITATION_REQUIRED` error is first
///   routed through [`handle_url_elicitation_required`]);
/// - the response is not a `CallToolResult`;
/// - `mcp_result_to_tool_call_result` reports a conversion error.
async fn execute_mcp_call(
    client: Arc<RunningService<RoleClient, McpClient>>,
    request: &ToolCallRequest,
    params: CallToolRequestParams,
    timeout: Duration,
    tool_call_id: String,
    event_tx: mpsc::Sender<ToolExecutionEvent>,
) -> Result<(ToolCallResult, Option<ToolResultMeta>), ToolCallError> {
    use super::tool_bridge::mcp_result_to_tool_call_result;
    use rmcp::model::{ClientRequest::CallToolRequest, Request, ServerResult};
    use rmcp::service::PeerRequestOptions;

    let handle = client
        .send_cancellable_request(CallToolRequest(Request::new(params)), {
            let mut opts = PeerRequestOptions::default();
            opts.timeout = Some(timeout);
            opts
        })
        .await
        .map_err(|e| ToolCallError::from_request(request, format!("Failed to send tool request: {e}")))?;

    // Subscribe with this request's own progress token so only its notifications are forwarded.
    // NOTE(review): the subscription is created after the request has been sent; confirm the
    // dispatcher cannot drop notifications that arrive in that window.
    let progress_subscriber = client.service().progress_dispatcher.subscribe(handle.progress_token.clone()).await;

    let progress_stream = progress_subscriber
        .map(move |progress| Either::Left(ToolExecutionEvent::Progress { tool_id: tool_call_id.clone(), progress }));

    // Merge progress (Left) with the single response (Right) so both are handled as they arrive.
    let result_stream = stream::once(handle.await_response()).map(Either::Right);
    let combined_stream = stream::select(progress_stream, result_stream);
    tokio::pin!(combined_stream);

    let server_result = loop {
        match combined_stream.next().await {
            Some(Either::Left(progress_event)) => {
                // Best-effort: ignore a closed event channel and keep waiting for the response.
                let _ = event_tx.send(progress_event).await;
            }
            Some(Either::Right(result)) => {
                break match result {
                    Ok(server_result) => server_result,
                    Err(e) => {
                        if let rmcp::service::ServiceError::McpError(ref error_data) = e
                            && error_data.code == ErrorCode::URL_ELICITATION_REQUIRED
                        {
                            return Err(handle_url_elicitation_required(&client, request, error_data).await);
                        }
                        return Err(ToolCallError::from_request(request, format!("Tool execution failed: {e}")));
                    }
                };
            }
            None => {
                // Defensive: the merged stream only ends once both inputs end, and the
                // loop breaks as soon as the single response item arrives.
                return Err(ToolCallError::from_request(request, "Stream ended without result"));
            }
        }
    };

    let ServerResult::CallToolResult(mcp_result) = server_result else {
        return Err(ToolCallError::from_request(request, "Unexpected response type from MCP server"));
    };

    mcp_result_to_tool_call_result(request, mcp_result)
}
204
/// Expected shape of the `data` payload on a `URL_ELICITATION_REQUIRED` error.
#[derive(serde::Deserialize)]
struct UrlElicitationRequiredData {
    /// Elicitation requests from the server. The list may mix modes (e.g. URL and form);
    /// callers filter it down to URL-mode entries.
    elicitations: Vec<CreateElicitationRequestParams>,
}
209
/// Reasons a `URL_ELICITATION_REQUIRED` error payload could not be turned into URL elicitations.
#[derive(Debug)]
enum UrlElicitationRequiredParseError {
    /// The error carried no `data` field at all.
    MissingData,
    /// `data` was present but did not deserialize as `UrlElicitationRequiredData`.
    InvalidData(serde_json::Error),
    /// `data` parsed, but none of its elicitations were URL-mode requests.
    NoUrlRequests,
}
216
217impl std::fmt::Display for UrlElicitationRequiredParseError {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 match self {
220 Self::MissingData => write!(f, "missing error data"),
221 Self::InvalidData(error) => write!(f, "malformed error data: {error}"),
222 Self::NoUrlRequests => write!(f, "provided no URL elicitation requests"),
223 }
224 }
225}
226
227fn parse_required_url_elicitations(
228 error_data: &rmcp::model::ErrorData,
229) -> Result<Vec<CreateElicitationRequestParams>, UrlElicitationRequiredParseError> {
230 let data = error_data.data.as_ref().ok_or(UrlElicitationRequiredParseError::MissingData)?;
231 let parsed: UrlElicitationRequiredData =
232 serde_json::from_value(data.clone()).map_err(UrlElicitationRequiredParseError::InvalidData)?;
233
234 let url_elicitations = parsed
235 .elicitations
236 .into_iter()
237 .filter(|elicitation| matches!(elicitation, CreateElicitationRequestParams::UrlElicitationParams { .. }))
238 .collect::<Vec<_>>();
239
240 if url_elicitations.is_empty() {
241 return Err(UrlElicitationRequiredParseError::NoUrlRequests);
242 }
243
244 Ok(url_elicitations)
245}
246
247async fn handle_url_elicitation_required(
251 client: &Arc<RunningService<RoleClient, McpClient>>,
252 request: &ToolCallRequest,
253 error_data: &rmcp::model::ErrorData,
254) -> ToolCallError {
255 let server_name = client.service().server_name().to_string();
256 let url_elicitations = match parse_required_url_elicitations(error_data) {
257 Ok(url_elicitations) => url_elicitations,
258 Err(UrlElicitationRequiredParseError::NoUrlRequests) => {
259 return ToolCallError::from_request(
260 request,
261 format!("Server '{server_name}' requires URL elicitation but provided no URL elicitation requests"),
262 );
263 }
264 Err(parse_error) => {
265 return ToolCallError::from_request(
266 request,
267 format!("Server '{server_name}' sent an invalid URL elicitation response: {parse_error}"),
268 );
269 }
270 };
271
272 tracing::info!("Server '{server_name}' requires {} URL elicitation(s)", url_elicitations.len());
273
274 for elicitation in url_elicitations {
275 let result = client.service().dispatch_elicitation(elicitation).await;
276 match result.action {
277 rmcp::model::ElicitationAction::Decline => {
278 return ToolCallError::from_request(
279 request,
280 format!("Required browser interaction for server '{server_name}' was declined"),
281 );
282 }
283 rmcp::model::ElicitationAction::Cancel => {
284 return ToolCallError::from_request(
285 request,
286 format!("Required browser interaction for server '{server_name}' was cancelled"),
287 );
288 }
289 rmcp::model::ElicitationAction::Accept => {
290 tracing::info!("User accepted URL elicitation for server '{server_name}'");
291 }
292 }
293 }
294
295 ToolCallError::from_request(
296 request,
297 format!(
298 "Server '{server_name}' requires a browser flow. The URL has been opened for your approval. Retry the previous request after completing the browser flow."
299 ),
300 )
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `URL_ELICITATION_REQUIRED` error carrying the given `data` payload.
    fn url_elicitation_error(data: Option<serde_json::Value>) -> rmcp::model::ErrorData {
        rmcp::model::ErrorData {
            code: rmcp::model::ErrorCode::URL_ELICITATION_REQUIRED,
            message: "URL elicitation required".into(),
            data,
        }
    }

    /// A form-mode elicitation entry, which URL-only parsing must discard.
    fn form_elicitation() -> serde_json::Value {
        serde_json::json!({
            "mode": "form",
            "message": "Pick a color",
            "requestedSchema": { "type": "object", "properties": {} }
        })
    }

    #[test]
    fn url_elicitation_required_data_parses_url_entries() {
        let data = serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth?elicitationId=el-1",
                    "elicitationId": "el-1"
                }
            ]
        });

        let parsed: UrlElicitationRequiredData = serde_json::from_value(data).unwrap();
        assert_eq!(parsed.elicitations.len(), 1);
        assert!(matches!(
            &parsed.elicitations[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_filters_to_url_only() {
        let error_data = url_elicitation_error(Some(serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth",
                    "elicitationId": "el-1"
                },
                form_elicitation()
            ]
        })));

        let result = parse_required_url_elicitations(&error_data).unwrap();
        assert_eq!(result.len(), 1);
        assert!(matches!(
            &result[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_missing_data() {
        let result = parse_required_url_elicitations(&url_elicitation_error(None));
        assert!(matches!(result, Err(UrlElicitationRequiredParseError::MissingData)));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_malformed_data() {
        let error_data = url_elicitation_error(Some(serde_json::json!({ "elicitations": "not-a-list" })));
        let result = parse_required_url_elicitations(&error_data);
        assert!(matches!(result, Err(UrlElicitationRequiredParseError::InvalidData(_))));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_payload_without_url_requests() {
        let payloads = [
            serde_json::json!({ "elicitations": [] }),
            serde_json::json!({ "elicitations": [form_elicitation()] }),
        ];
        for payload in payloads {
            let result = parse_required_url_elicitations(&url_elicitation_error(Some(payload)));
            assert!(matches!(result, Err(UrlElicitationRequiredParseError::NoUrlRequests)));
        }
    }

    #[test]
    fn parse_error_display_is_descriptive() {
        assert_eq!(UrlElicitationRequiredParseError::MissingData.to_string(), "missing error data");
        assert_eq!(
            UrlElicitationRequiredParseError::NoUrlRequests.to_string(),
            "provided no URL elicitation requests"
        );
    }
}