Skip to main content

aether_core/mcp/
run_mcp_task.rs

1use mcp_utils::client::{McpClient, McpConnectAttempt, McpError, McpManager, McpServerStatusEntry};
2use mcp_utils::display_meta::ToolResultMeta;
3
4use futures::future::Either;
5use futures::stream::{self, StreamExt};
6use llm::{ToolCallError, ToolCallRequest, ToolCallResult};
7use rmcp::RoleClient;
8use rmcp::model::{
9    CallToolRequestParams, CreateElicitationRequestParams, ErrorCode, GetPromptResult, ProgressNotificationParam,
10    Prompt,
11};
12use rmcp::service::RunningService;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::select;
16use tokio::sync::mpsc;
17use tokio::sync::oneshot;
18use tokio::task::JoinSet;
19
/// Events emitted during tool execution lifecycle
///
/// For a single tool call the manager task sends `Started` first, then zero
/// or more `Progress` events, and finally exactly one `Complete`.
#[derive(Debug)]
pub enum ToolExecutionEvent {
    /// Sent as soon as the `ExecuteTool` command is picked up, before the
    /// tool's client has been resolved.
    Started { tool_id: String, tool_name: String },
    /// A progress notification relayed while the call is still in flight.
    Progress { tool_id: String, progress: ProgressNotificationParam },
    /// Final outcome of the call. `result_meta` is `None` whenever `result`
    /// is an error.
    Complete { tool_id: String, result: Result<ToolCallResult, ToolCallError>, result_meta: Option<ToolResultMeta> },
}
27
/// Upper bound on how long a spawned authentication task may run before it
/// is reported as a failed connection attempt.
const MCP_AUTH_TIMEOUT: Duration = Duration::from_mins(3);
29
/// Commands that can be sent to the MCP manager task
///
/// Variants carrying a `tx` field receive their answer on that channel;
/// failures to deliver the answer are ignored by the manager task.
#[derive(Debug)]
pub enum McpCommand {
    /// Run a tool call. Lifecycle events are streamed to `tx`; `timeout` is
    /// applied as the request timeout of the underlying MCP call.
    ExecuteTool {
        request: ToolCallRequest,
        timeout: Duration,
        tx: mpsc::Sender<ToolExecutionEvent>,
    },
    /// List the available prompts; errors are reported as a display string.
    ListPrompts {
        tx: oneshot::Sender<Result<Vec<Prompt>, String>>,
    },
    /// Fetch one prompt by name with optional arguments; errors are reported
    /// as a display string.
    GetPrompt {
        name: String,
        arguments: Option<serde_json::Map<String, serde_json::Value>>,
        tx: oneshot::Sender<Result<GetPromptResult, String>>,
    },
    /// Request a copy of the current server status entries.
    GetServerStatuses {
        tx: oneshot::Sender<Vec<McpServerStatusEntry>>,
    },
    /// Start authentication for the named server. There is no reply channel:
    /// the resulting connection attempt is applied to the manager once the
    /// background task finishes.
    AuthenticateServer {
        name: String,
    },
}
53
54pub async fn run_mcp_task(mut mcp: McpManager, mut command_rx: mpsc::Receiver<McpCommand>) {
55    let mut auth_tasks = JoinSet::<McpConnectAttempt>::new();
56    loop {
57        select! {
58            command = command_rx.recv() => {
59                let Some(command) = command else { break; };
60                on_command(command, &mut mcp, &mut auth_tasks).await;
61            }
62
63            Some(joined) = auth_tasks.join_next(), if !auth_tasks.is_empty() => {
64                match joined {
65                    Ok(attempt) => mcp.apply_connection_attempt(attempt).await,
66                    Err(e) => tracing::error!("MCP auth task did not complete normally: {e:?}"),
67                }
68            }
69        }
70    }
71
72    auth_tasks.abort_all();
73    while auth_tasks.join_next().await.is_some() {}
74    mcp.shutdown().await;
75    tracing::debug!("MCP manager task ended");
76}
77
78async fn on_command(command: McpCommand, mcp: &mut McpManager, auth_tasks: &mut JoinSet<McpConnectAttempt>) {
79    match command {
80        McpCommand::ExecuteTool { request, timeout, tx } => {
81            let tool_id = request.id.clone();
82            let tool_name = request.name.clone();
83
84            let _ =
85                tx.send(ToolExecutionEvent::Started { tool_id: tool_id.clone(), tool_name: tool_name.clone() }).await;
86
87            match mcp.get_client_for_tool(&request.name, &request.arguments) {
88                Ok((client, params)) => {
89                    tokio::spawn(async move {
90                        let outcome =
91                            execute_mcp_call(client, &request, params, timeout, tool_id.clone(), tx.clone()).await;
92                        let (result, result_meta) = match outcome {
93                            Ok((r, m)) => (Ok(r), m),
94                            Err(e) => (Err(e), None),
95                        };
96                        let _ = tx.send(ToolExecutionEvent::Complete { tool_id, result, result_meta }).await;
97                    });
98                }
99                Err(e) => {
100                    tracing::error!("Failed to get client for tool {}: {e}", request.name);
101                    let error = ToolCallError::from_request(&request, format!("Failed to get client: {e}"));
102                    let _ =
103                        tx.send(ToolExecutionEvent::Complete { tool_id, result: Err(error), result_meta: None }).await;
104                }
105            }
106        }
107
108        McpCommand::ListPrompts { tx } => {
109            let result = mcp.list_prompts().await.map_err(|e| format!("Failed to list prompts: {e}"));
110            let _ = tx.send(result);
111        }
112
113        McpCommand::GetPrompt { name: namespaced_name, arguments, tx } => {
114            let result =
115                mcp.get_prompt(&namespaced_name, arguments).await.map_err(|e| format!("Failed to get prompt: {e}"));
116            let _ = tx.send(result);
117        }
118
119        McpCommand::GetServerStatuses { tx } => {
120            let _ = tx.send(mcp.server_statuses().to_vec());
121        }
122
123        McpCommand::AuthenticateServer { name } => match mcp.authenticate_server_task(&name).await {
124            Ok(task) => {
125                let server_name = name.clone();
126                auth_tasks.spawn(async move {
127                    match tokio::time::timeout(MCP_AUTH_TIMEOUT, task).await {
128                        Ok(attempt) => attempt,
129                        Err(_) => McpConnectAttempt::failed(
130                            server_name,
131                            McpError::ConnectionFailed("authentication timed out after 3 minutes".to_string()),
132                            false,
133                        ),
134                    }
135                });
136            }
137            Err(e) => tracing::warn!("Authentication failed for '{name}': {e}"),
138        },
139    }
140}
141
142/// Shared logic for sending an MCP tool call, streaming progress events,
143/// and collecting the result.
144async fn execute_mcp_call(
145    client: Arc<RunningService<RoleClient, McpClient>>,
146    request: &ToolCallRequest,
147    params: CallToolRequestParams,
148    timeout: Duration,
149    tool_call_id: String,
150    event_tx: mpsc::Sender<ToolExecutionEvent>,
151) -> Result<(ToolCallResult, Option<ToolResultMeta>), ToolCallError> {
152    use super::tool_bridge::mcp_result_to_tool_call_result;
153    use rmcp::model::{ClientRequest::CallToolRequest, Request, ServerResult};
154    use rmcp::service::PeerRequestOptions;
155
156    let handle = client
157        .send_cancellable_request(CallToolRequest(Request::new(params)), {
158            let mut opts = PeerRequestOptions::default();
159            opts.timeout = Some(timeout);
160            opts
161        })
162        .await
163        .map_err(|e| ToolCallError::from_request(request, format!("Failed to send tool request: {e}")))?;
164
165    let progress_subscriber = client.service().progress_dispatcher.subscribe(handle.progress_token.clone()).await;
166
167    let progress_stream = progress_subscriber
168        .map(move |progress| Either::Left(ToolExecutionEvent::Progress { tool_id: tool_call_id.clone(), progress }));
169
170    let result_stream = stream::once(handle.await_response()).map(Either::Right);
171    let combined_stream = stream::select(progress_stream, result_stream);
172    tokio::pin!(combined_stream);
173
174    let server_result = loop {
175        match combined_stream.next().await {
176            Some(Either::Left(progress_event)) => {
177                let _ = event_tx.send(progress_event).await;
178            }
179            Some(Either::Right(result)) => {
180                break match result {
181                    Ok(server_result) => server_result,
182                    Err(e) => {
183                        if let rmcp::service::ServiceError::McpError(ref error_data) = e
184                            && error_data.code == ErrorCode::URL_ELICITATION_REQUIRED
185                        {
186                            return Err(handle_url_elicitation_required(&client, request, error_data).await);
187                        }
188                        return Err(ToolCallError::from_request(request, format!("Tool execution failed: {e}")));
189                    }
190                };
191            }
192            None => {
193                return Err(ToolCallError::from_request(request, "Stream ended without result"));
194            }
195        }
196    };
197
198    let ServerResult::CallToolResult(mcp_result) = server_result else {
199        return Err(ToolCallError::from_request(request, "Unexpected response type from MCP server"));
200    };
201
202    mcp_result_to_tool_call_result(request, mcp_result)
203}
204
/// Shape of the `data` payload deserialized from a `URL_ELICITATION_REQUIRED`
/// error: the list of elicitation requests sent by the server.
#[derive(serde::Deserialize)]
struct UrlElicitationRequiredData {
    /// Elicitation requests as sent by the server; may contain non-URL
    /// entries, which are filtered out by `parse_required_url_elicitations`.
    elicitations: Vec<CreateElicitationRequestParams>,
}
209
/// Reasons the payload of a `URL_ELICITATION_REQUIRED` error could not be
/// turned into a list of URL elicitation requests.
#[derive(Debug)]
enum UrlElicitationRequiredParseError {
    /// The error carried no `data` field at all.
    MissingData,
    /// The `data` field did not deserialize into the expected shape.
    InvalidData(serde_json::Error),
    /// The payload parsed, but contained no URL-mode elicitation.
    NoUrlRequests,
}
216
217impl std::fmt::Display for UrlElicitationRequiredParseError {
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        match self {
220            Self::MissingData => write!(f, "missing error data"),
221            Self::InvalidData(error) => write!(f, "malformed error data: {error}"),
222            Self::NoUrlRequests => write!(f, "provided no URL elicitation requests"),
223        }
224    }
225}
226
227fn parse_required_url_elicitations(
228    error_data: &rmcp::model::ErrorData,
229) -> Result<Vec<CreateElicitationRequestParams>, UrlElicitationRequiredParseError> {
230    let data = error_data.data.as_ref().ok_or(UrlElicitationRequiredParseError::MissingData)?;
231    let parsed: UrlElicitationRequiredData =
232        serde_json::from_value(data.clone()).map_err(UrlElicitationRequiredParseError::InvalidData)?;
233
234    let url_elicitations = parsed
235        .elicitations
236        .into_iter()
237        .filter(|elicitation| matches!(elicitation, CreateElicitationRequestParams::UrlElicitationParams { .. }))
238        .collect::<Vec<_>>();
239
240    if url_elicitations.is_empty() {
241        return Err(UrlElicitationRequiredParseError::NoUrlRequests);
242    }
243
244    Ok(url_elicitations)
245}
246
247/// Handle a `URL_ELICITATION_REQUIRED` (-32042) error by dispatching each
248/// URL elicitation through the same consent channel used by normal
249/// `create_elicitation` requests.
250async fn handle_url_elicitation_required(
251    client: &Arc<RunningService<RoleClient, McpClient>>,
252    request: &ToolCallRequest,
253    error_data: &rmcp::model::ErrorData,
254) -> ToolCallError {
255    let server_name = client.service().server_name().to_string();
256    let url_elicitations = match parse_required_url_elicitations(error_data) {
257        Ok(url_elicitations) => url_elicitations,
258        Err(UrlElicitationRequiredParseError::NoUrlRequests) => {
259            return ToolCallError::from_request(
260                request,
261                format!("Server '{server_name}' requires URL elicitation but provided no URL elicitation requests"),
262            );
263        }
264        Err(parse_error) => {
265            return ToolCallError::from_request(
266                request,
267                format!("Server '{server_name}' sent an invalid URL elicitation response: {parse_error}"),
268            );
269        }
270    };
271
272    tracing::info!("Server '{server_name}' requires {} URL elicitation(s)", url_elicitations.len());
273
274    for elicitation in url_elicitations {
275        let result = client.service().dispatch_elicitation(elicitation).await;
276        match result.action {
277            rmcp::model::ElicitationAction::Decline => {
278                return ToolCallError::from_request(
279                    request,
280                    format!("Required browser interaction for server '{server_name}' was declined"),
281                );
282            }
283            rmcp::model::ElicitationAction::Cancel => {
284                return ToolCallError::from_request(
285                    request,
286                    format!("Required browser interaction for server '{server_name}' was cancelled"),
287                );
288            }
289            rmcp::model::ElicitationAction::Accept => {
290                tracing::info!("User accepted URL elicitation for server '{server_name}'");
291            }
292        }
293    }
294
295    ToolCallError::from_request(
296        request,
297        format!(
298            "Server '{server_name}' requires a browser flow. The URL has been opened for your approval. Retry the previous request after completing the browser flow."
299        ),
300    )
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload with a single URL-mode entry deserializes into exactly one
    /// URL elicitation carrying the expected id.
    #[test]
    fn url_elicitation_required_data_parses_url_entries() {
        let payload = serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth?elicitationId=el-1",
                    "elicitationId": "el-1"
                }
            ]
        });

        let entries = serde_json::from_value::<UrlElicitationRequiredData>(payload).unwrap().elicitations;

        assert_eq!(entries.len(), 1);
        assert!(matches!(
            entries.first(),
            Some(CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. })
                if elicitation_id == "el-1"
        ));
    }

    /// Form-mode entries are dropped; only the URL-mode entry survives.
    #[test]
    fn parse_required_url_elicitations_filters_to_url_only() {
        let payload = serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth",
                    "elicitationId": "el-1"
                },
                {
                    "mode": "form",
                    "message": "Pick a color",
                    "requestedSchema": { "type": "object", "properties": {} }
                }
            ]
        });
        let error_data = rmcp::model::ErrorData {
            code: rmcp::model::ErrorCode::URL_ELICITATION_REQUIRED,
            message: "URL elicitation required".into(),
            data: Some(payload),
        };

        let kept = parse_required_url_elicitations(&error_data).unwrap();

        assert_eq!(kept.len(), 1);
        assert!(matches!(
            kept.as_slice(),
            [CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. }]
                if elicitation_id == "el-1"
        ));
    }
}
357}