// aether_core/mcp/run_mcp_task.rs

1use mcp_utils::client::{
2    McpClient, McpConnectAttempt, McpConnectionAttemptManager, McpError, McpManager, McpServer, McpServerStatusEntry,
3};
4use mcp_utils::display_meta::ToolResultMeta;
5
6use futures::future::Either;
7use futures::stream::{self, StreamExt};
8use llm::{ToolCallError, ToolCallRequest, ToolCallResult};
9use rmcp::RoleClient;
10use rmcp::model::{
11    CallToolRequestParams, CreateElicitationRequestParams, ErrorCode, GetPromptResult, ProgressNotificationParam,
12    Prompt,
13};
14use rmcp::service::RunningService;
15use std::collections::HashSet;
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::select;
19use tokio::sync::mpsc;
20use tokio::sync::oneshot;
21
22/// Events emitted during tool execution lifecycle
23#[derive(Debug)]
24pub enum ToolExecutionEvent {
25    Started { tool_id: String, tool_name: String },
26    Progress { tool_id: String, progress: ProgressNotificationParam },
27    Complete { tool_id: String, result: Result<ToolCallResult, ToolCallError>, result_meta: Option<ToolResultMeta> },
28}
29
/// Upper bound on an interactive server authentication flow; when it elapses
/// the flow is reported as a failed connection attempt.
const MCP_AUTH_TIMEOUT: Duration = Duration::from_secs(3 * 60);
31
32/// Commands that can be sent to the MCP manager task
33#[derive(Debug)]
34pub enum McpCommand {
35    ExecuteTool {
36        request: ToolCallRequest,
37        timeout: Duration,
38        tx: mpsc::Sender<ToolExecutionEvent>,
39    },
40    ListPrompts {
41        tx: oneshot::Sender<Result<Vec<Prompt>, String>>,
42    },
43    GetPrompt {
44        name: String,
45        arguments: Option<serde_json::Map<String, serde_json::Value>>,
46        tx: oneshot::Sender<Result<GetPromptResult, String>>,
47    },
48    GetServerStatuses {
49        tx: oneshot::Sender<Vec<McpServerStatusEntry>>,
50    },
51    AuthenticateServer {
52        name: String,
53    },
54}
55
56pub async fn run_mcp_task(
57    mut mcp: McpManager,
58    mut command_rx: mpsc::Receiver<McpCommand>,
59    pending_servers: Vec<McpServer>,
60) {
61    let mut mcp_connection_attempts = McpConnectionAttemptManager::default();
62    let mut pending_connections: HashSet<String> = pending_servers.iter().map(|server| server.name.clone()).collect();
63    for server in pending_servers {
64        let name = server.name.clone();
65        let task = mcp.connect_pending_task(server);
66        mcp_connection_attempts.spawn(name, task);
67    }
68    if pending_connections.is_empty() {
69        mcp.emit_connection_ready().await;
70    }
71
72    loop {
73        select! {
74            command = command_rx.recv() => {
75                let Some(command) = command else { break; };
76                on_command(command, &mut mcp, &mut mcp_connection_attempts).await;
77            }
78
79            Some(joined) = mcp_connection_attempts.join_next(), if !mcp_connection_attempts.is_empty() => {
80                match joined {
81                    Ok(attempt) => {
82                        let was_bootstrap = pending_connections.remove(&attempt.name);
83                        mcp.apply_connection_attempt(attempt).await;
84                        if was_bootstrap && pending_connections.is_empty() {
85                            mcp.emit_connection_ready().await;
86                        }
87                    }
88                    Err(e) => tracing::error!("MCP auth task did not complete normally: {e:?}"),
89                }
90            }
91        }
92    }
93
94    mcp_connection_attempts.shutdown().await;
95    mcp.shutdown().await;
96    tracing::debug!("MCP manager task ended");
97}
98
99async fn on_command(command: McpCommand, mcp: &mut McpManager, auth_tasks: &mut McpConnectionAttemptManager) {
100    match command {
101        McpCommand::ExecuteTool { request, timeout, tx } => {
102            let tool_id = request.id.clone();
103            let tool_name = request.name.clone();
104
105            let _ =
106                tx.send(ToolExecutionEvent::Started { tool_id: tool_id.clone(), tool_name: tool_name.clone() }).await;
107
108            match mcp.get_client_for_tool(&request.name, &request.arguments) {
109                Ok((client, params)) => {
110                    tokio::spawn(async move {
111                        let outcome =
112                            execute_mcp_call(client, &request, params, timeout, tool_id.clone(), tx.clone()).await;
113                        let (result, result_meta) = match outcome {
114                            Ok((r, m)) => (Ok(r), m),
115                            Err(e) => (Err(e), None),
116                        };
117                        let _ = tx.send(ToolExecutionEvent::Complete { tool_id, result, result_meta }).await;
118                    });
119                }
120                Err(e) => {
121                    tracing::error!("Failed to get client for tool {}: {e}", request.name);
122                    let error = ToolCallError::from_request(&request, format!("Failed to get client: {e}"));
123                    let _ =
124                        tx.send(ToolExecutionEvent::Complete { tool_id, result: Err(error), result_meta: None }).await;
125                }
126            }
127        }
128
129        McpCommand::ListPrompts { tx } => {
130            let result = mcp.list_prompts().await.map_err(|e| format!("Failed to list prompts: {e}"));
131            let _ = tx.send(result);
132        }
133
134        McpCommand::GetPrompt { name: namespaced_name, arguments, tx } => {
135            let result =
136                mcp.get_prompt(&namespaced_name, arguments).await.map_err(|e| format!("Failed to get prompt: {e}"));
137            let _ = tx.send(result);
138        }
139
140        McpCommand::GetServerStatuses { tx } => {
141            let _ = tx.send(mcp.server_statuses());
142        }
143
144        McpCommand::AuthenticateServer { name } => match mcp.authenticate_server_task(&name).await {
145            Ok(task) => {
146                let server_name = name.clone();
147                auth_tasks.spawn(name, async move {
148                    match tokio::time::timeout(MCP_AUTH_TIMEOUT, task).await {
149                        Ok(attempt) => attempt,
150                        Err(_) => McpConnectAttempt::failed(
151                            server_name,
152                            McpError::ConnectionFailed("authentication timed out after 3 minutes".to_string()),
153                            false,
154                        ),
155                    }
156                });
157            }
158            Err(e) => tracing::warn!("Authentication failed for '{name}': {e}"),
159        },
160    }
161}
162
163/// Shared logic for sending an MCP tool call, streaming progress events,
164/// and collecting the result.
165async fn execute_mcp_call(
166    client: Arc<RunningService<RoleClient, McpClient>>,
167    request: &ToolCallRequest,
168    params: CallToolRequestParams,
169    timeout: Duration,
170    tool_call_id: String,
171    event_tx: mpsc::Sender<ToolExecutionEvent>,
172) -> Result<(ToolCallResult, Option<ToolResultMeta>), ToolCallError> {
173    use super::tool_bridge::mcp_result_to_tool_call_result;
174    use rmcp::model::{ClientRequest::CallToolRequest, Request, ServerResult};
175    use rmcp::service::PeerRequestOptions;
176
177    let handle = client
178        .send_cancellable_request(CallToolRequest(Request::new(params)), {
179            let mut opts = PeerRequestOptions::default();
180            opts.timeout = Some(timeout);
181            opts
182        })
183        .await
184        .map_err(|e| ToolCallError::from_request(request, format!("Failed to send tool request: {e}")))?;
185
186    let progress_subscriber = client.service().progress_dispatcher.subscribe(handle.progress_token.clone()).await;
187
188    let progress_stream = progress_subscriber
189        .map(move |progress| Either::Left(ToolExecutionEvent::Progress { tool_id: tool_call_id.clone(), progress }));
190
191    let result_stream = stream::once(handle.await_response()).map(Either::Right);
192    let combined_stream = stream::select(progress_stream, result_stream);
193    tokio::pin!(combined_stream);
194
195    let server_result = loop {
196        match combined_stream.next().await {
197            Some(Either::Left(progress_event)) => {
198                let _ = event_tx.send(progress_event).await;
199            }
200            Some(Either::Right(result)) => {
201                break match result {
202                    Ok(server_result) => server_result,
203                    Err(e) => {
204                        if let rmcp::service::ServiceError::McpError(ref error_data) = e
205                            && error_data.code == ErrorCode::URL_ELICITATION_REQUIRED
206                        {
207                            return Err(handle_url_elicitation_required(&client, request, error_data).await);
208                        }
209                        return Err(ToolCallError::from_request(request, format!("Tool execution failed: {e}")));
210                    }
211                };
212            }
213            None => {
214                return Err(ToolCallError::from_request(request, "Stream ended without result"));
215            }
216        }
217    };
218
219    let ServerResult::CallToolResult(mcp_result) = server_result else {
220        return Err(ToolCallError::from_request(request, "Unexpected response type from MCP server"));
221    };
222
223    mcp_result_to_tool_call_result(request, mcp_result)
224}
225
226#[derive(serde::Deserialize)]
227struct UrlElicitationRequiredData {
228    elicitations: Vec<CreateElicitationRequestParams>,
229}
230
231#[derive(Debug)]
232enum UrlElicitationRequiredParseError {
233    MissingData,
234    InvalidData(serde_json::Error),
235    NoUrlRequests,
236}
237
238impl std::fmt::Display for UrlElicitationRequiredParseError {
239    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        match self {
241            Self::MissingData => write!(f, "missing error data"),
242            Self::InvalidData(error) => write!(f, "malformed error data: {error}"),
243            Self::NoUrlRequests => write!(f, "provided no URL elicitation requests"),
244        }
245    }
246}
247
248fn parse_required_url_elicitations(
249    error_data: &rmcp::model::ErrorData,
250) -> Result<Vec<CreateElicitationRequestParams>, UrlElicitationRequiredParseError> {
251    let data = error_data.data.as_ref().ok_or(UrlElicitationRequiredParseError::MissingData)?;
252    let parsed: UrlElicitationRequiredData =
253        serde_json::from_value(data.clone()).map_err(UrlElicitationRequiredParseError::InvalidData)?;
254
255    let url_elicitations = parsed
256        .elicitations
257        .into_iter()
258        .filter(|elicitation| matches!(elicitation, CreateElicitationRequestParams::UrlElicitationParams { .. }))
259        .collect::<Vec<_>>();
260
261    if url_elicitations.is_empty() {
262        return Err(UrlElicitationRequiredParseError::NoUrlRequests);
263    }
264
265    Ok(url_elicitations)
266}
267
268/// Handle a `URL_ELICITATION_REQUIRED` (-32042) error by dispatching each
269/// URL elicitation through the same consent channel used by normal
270/// `create_elicitation` requests.
271async fn handle_url_elicitation_required(
272    client: &Arc<RunningService<RoleClient, McpClient>>,
273    request: &ToolCallRequest,
274    error_data: &rmcp::model::ErrorData,
275) -> ToolCallError {
276    let server_name = client.service().server_name().to_string();
277    let url_elicitations = match parse_required_url_elicitations(error_data) {
278        Ok(url_elicitations) => url_elicitations,
279        Err(UrlElicitationRequiredParseError::NoUrlRequests) => {
280            return ToolCallError::from_request(
281                request,
282                format!("Server '{server_name}' requires URL elicitation but provided no URL elicitation requests"),
283            );
284        }
285        Err(parse_error) => {
286            return ToolCallError::from_request(
287                request,
288                format!("Server '{server_name}' sent an invalid URL elicitation response: {parse_error}"),
289            );
290        }
291    };
292
293    tracing::info!("Server '{server_name}' requires {} URL elicitation(s)", url_elicitations.len());
294
295    for elicitation in url_elicitations {
296        let result = client.service().dispatch_elicitation(elicitation).await;
297        match result.action {
298            rmcp::model::ElicitationAction::Decline => {
299                return ToolCallError::from_request(
300                    request,
301                    format!("Required browser interaction for server '{server_name}' was declined"),
302                );
303            }
304            rmcp::model::ElicitationAction::Cancel => {
305                return ToolCallError::from_request(
306                    request,
307                    format!("Required browser interaction for server '{server_name}' was cancelled"),
308                );
309            }
310            rmcp::model::ElicitationAction::Accept => {
311                tracing::info!("User accepted URL elicitation for server '{server_name}'");
312            }
313        }
314    }
315
316    ToolCallError::from_request(
317        request,
318        format!(
319            "Server '{server_name}' requires a browser flow. The URL has been opened for your approval. Retry the previous request after completing the browser flow."
320        ),
321    )
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `URL_ELICITATION_REQUIRED` error carrying the given `data` payload.
    fn url_elicitation_error(data: Option<serde_json::Value>) -> rmcp::model::ErrorData {
        rmcp::model::ErrorData {
            code: rmcp::model::ErrorCode::URL_ELICITATION_REQUIRED,
            message: "URL elicitation required".into(),
            data,
        }
    }

    #[test]
    fn url_elicitation_required_data_parses_url_entries() {
        let data = serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth?elicitationId=el-1",
                    "elicitationId": "el-1"
                }
            ]
        });

        let parsed: UrlElicitationRequiredData = serde_json::from_value(data).unwrap();
        assert_eq!(parsed.elicitations.len(), 1);
        assert!(matches!(
            &parsed.elicitations[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_filters_to_url_only() {
        let error_data = url_elicitation_error(Some(serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth",
                    "elicitationId": "el-1"
                },
                {
                    "mode": "form",
                    "message": "Pick a color",
                    "requestedSchema": { "type": "object", "properties": {} }
                }
            ]
        })));

        let result = parse_required_url_elicitations(&error_data).unwrap();
        assert_eq!(result.len(), 1);
        assert!(matches!(
            &result[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_missing_data() {
        let error_data = url_elicitation_error(None);
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::MissingData)
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_malformed_data() {
        let error_data = url_elicitation_error(Some(serde_json::json!({ "elicitations": "not-a-list" })));
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::InvalidData(_))
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_payload_without_url_entries() {
        let form_only = url_elicitation_error(Some(serde_json::json!({
            "elicitations": [
                {
                    "mode": "form",
                    "message": "Pick a color",
                    "requestedSchema": { "type": "object", "properties": {} }
                }
            ]
        })));
        assert!(matches!(
            parse_required_url_elicitations(&form_only),
            Err(UrlElicitationRequiredParseError::NoUrlRequests)
        ));

        let empty = url_elicitation_error(Some(serde_json::json!({ "elicitations": [] })));
        assert!(matches!(
            parse_required_url_elicitations(&empty),
            Err(UrlElicitationRequiredParseError::NoUrlRequests)
        ));
    }

    #[test]
    fn parse_error_display_messages() {
        assert_eq!(UrlElicitationRequiredParseError::MissingData.to_string(), "missing error data");
        assert_eq!(UrlElicitationRequiredParseError::NoUrlRequests.to_string(), "provided no URL elicitation requests");
    }
}