Skip to main content

aether_core/mcp/
run_mcp_task.rs

1use mcp_utils::client::mcp_client::McpClient;
2use mcp_utils::client::{McpManager, McpServerStatusEntry};
3use mcp_utils::display_meta::ToolResultMeta;
4
5use futures::future::Either;
6use futures::stream::{self, StreamExt};
7use llm::{ToolCallError, ToolCallRequest, ToolCallResult, ToolDefinition};
8use rmcp::RoleClient;
9use rmcp::model::{
10    CallToolRequestParams, CreateElicitationRequestParams, ErrorCode, GetPromptResult, ProgressNotificationParam,
11    Prompt,
12};
13use rmcp::service::RunningService;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::mpsc;
17use tokio::sync::oneshot;
18
19/// Events emitted during tool execution lifecycle
20#[derive(Debug)]
21pub enum ToolExecutionEvent {
22    Started { tool_id: String, tool_name: String },
23    Progress { tool_id: String, progress: ProgressNotificationParam },
24    Complete { tool_id: String, result: Result<ToolCallResult, ToolCallError>, result_meta: Option<ToolResultMeta> },
25}
26
27type AuthResult = Result<(Vec<McpServerStatusEntry>, Vec<ToolDefinition>), String>;
28
29/// Commands that can be sent to the MCP manager task
30#[derive(Debug)]
31pub enum McpCommand {
32    ExecuteTool {
33        request: ToolCallRequest,
34        timeout: Duration,
35        tx: mpsc::Sender<ToolExecutionEvent>,
36    },
37    ListPrompts {
38        tx: oneshot::Sender<Result<Vec<Prompt>, String>>,
39    },
40    GetPrompt {
41        name: String,
42        arguments: Option<serde_json::Map<String, serde_json::Value>>,
43        tx: oneshot::Sender<Result<GetPromptResult, String>>,
44    },
45    GetServerStatuses {
46        tx: oneshot::Sender<Vec<McpServerStatusEntry>>,
47    },
48    AuthenticateServer {
49        name: String,
50        tx: oneshot::Sender<AuthResult>,
51    },
52}
53
54pub async fn run_mcp_task(mut mcp: McpManager, mut command_rx: mpsc::Receiver<McpCommand>) {
55    while let Some(command) = command_rx.recv().await {
56        on_command(command, &mut mcp).await;
57    }
58
59    mcp.shutdown().await;
60    tracing::debug!("MCP manager task ended");
61}
62
63async fn on_command(command: McpCommand, mcp: &mut McpManager) {
64    match command {
65        McpCommand::ExecuteTool { request, timeout, tx } => {
66            let tool_id = request.id.clone();
67            let tool_name = request.name.clone();
68
69            let _ =
70                tx.send(ToolExecutionEvent::Started { tool_id: tool_id.clone(), tool_name: tool_name.clone() }).await;
71
72            match mcp.get_client_for_tool(&request.name, &request.arguments) {
73                Ok((client, params)) => {
74                    tokio::spawn(async move {
75                        let outcome =
76                            execute_mcp_call(client, &request, params, timeout, tool_id.clone(), tx.clone()).await;
77                        let (result, result_meta) = match outcome {
78                            Ok((r, m)) => (Ok(r), m),
79                            Err(e) => (Err(e), None),
80                        };
81                        let _ = tx.send(ToolExecutionEvent::Complete { tool_id, result, result_meta }).await;
82                    });
83                }
84                Err(e) => {
85                    tracing::error!("Failed to get client for tool {}: {e}", request.name);
86                    let error = ToolCallError::from_request(&request, format!("Failed to get client: {e}"));
87                    let _ =
88                        tx.send(ToolExecutionEvent::Complete { tool_id, result: Err(error), result_meta: None }).await;
89                }
90            }
91        }
92
93        McpCommand::ListPrompts { tx } => {
94            let result = mcp.list_prompts().await.map_err(|e| format!("Failed to list prompts: {e}"));
95            let _ = tx.send(result);
96        }
97
98        McpCommand::GetPrompt { name: namespaced_name, arguments, tx } => {
99            let result =
100                mcp.get_prompt(&namespaced_name, arguments).await.map_err(|e| format!("Failed to get prompt: {e}"));
101            let _ = tx.send(result);
102        }
103
104        McpCommand::GetServerStatuses { tx } => {
105            let _ = tx.send(mcp.server_statuses().to_vec());
106        }
107
108        McpCommand::AuthenticateServer { name, tx } => {
109            let result = match mcp.authenticate_server(&name).await {
110                Ok(()) => Ok((mcp.server_statuses().to_vec(), mcp.tool_definitions())),
111                Err(e) => Err(format!("Authentication failed for '{name}': {e}")),
112            };
113            let _ = tx.send(result);
114        }
115    }
116}
117
118/// Shared logic for sending an MCP tool call, streaming progress events,
119/// and collecting the result.
120async fn execute_mcp_call(
121    client: Arc<RunningService<RoleClient, McpClient>>,
122    request: &ToolCallRequest,
123    params: CallToolRequestParams,
124    timeout: Duration,
125    tool_call_id: String,
126    event_tx: mpsc::Sender<ToolExecutionEvent>,
127) -> Result<(ToolCallResult, Option<ToolResultMeta>), ToolCallError> {
128    use super::tool_bridge::mcp_result_to_tool_call_result;
129    use rmcp::model::{ClientRequest::CallToolRequest, Request, ServerResult};
130    use rmcp::service::PeerRequestOptions;
131
132    let handle = client
133        .send_cancellable_request(CallToolRequest(Request::new(params)), {
134            let mut opts = PeerRequestOptions::default();
135            opts.timeout = Some(timeout);
136            opts
137        })
138        .await
139        .map_err(|e| ToolCallError::from_request(request, format!("Failed to send tool request: {e}")))?;
140
141    let progress_subscriber = client.service().progress_dispatcher.subscribe(handle.progress_token.clone()).await;
142
143    let progress_stream = progress_subscriber
144        .map(move |progress| Either::Left(ToolExecutionEvent::Progress { tool_id: tool_call_id.clone(), progress }));
145
146    let result_stream = stream::once(handle.await_response()).map(Either::Right);
147    let combined_stream = stream::select(progress_stream, result_stream);
148    tokio::pin!(combined_stream);
149
150    let server_result = loop {
151        match combined_stream.next().await {
152            Some(Either::Left(progress_event)) => {
153                let _ = event_tx.send(progress_event).await;
154            }
155            Some(Either::Right(result)) => {
156                break match result {
157                    Ok(server_result) => server_result,
158                    Err(e) => {
159                        if let rmcp::service::ServiceError::McpError(ref error_data) = e
160                            && error_data.code == ErrorCode::URL_ELICITATION_REQUIRED
161                        {
162                            return Err(handle_url_elicitation_required(&client, request, error_data).await);
163                        }
164                        return Err(ToolCallError::from_request(request, format!("Tool execution failed: {e}")));
165                    }
166                };
167            }
168            None => {
169                return Err(ToolCallError::from_request(request, "Stream ended without result"));
170            }
171        }
172    };
173
174    let ServerResult::CallToolResult(mcp_result) = server_result else {
175        return Err(ToolCallError::from_request(request, "Unexpected response type from MCP server"));
176    };
177
178    mcp_result_to_tool_call_result(request, mcp_result)
179}
180
181#[derive(serde::Deserialize)]
182struct UrlElicitationRequiredData {
183    elicitations: Vec<CreateElicitationRequestParams>,
184}
185
186#[derive(Debug)]
187enum UrlElicitationRequiredParseError {
188    MissingData,
189    InvalidData(serde_json::Error),
190    NoUrlRequests,
191}
192
193impl std::fmt::Display for UrlElicitationRequiredParseError {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        match self {
196            Self::MissingData => write!(f, "missing error data"),
197            Self::InvalidData(error) => write!(f, "malformed error data: {error}"),
198            Self::NoUrlRequests => write!(f, "provided no URL elicitation requests"),
199        }
200    }
201}
202
203fn parse_required_url_elicitations(
204    error_data: &rmcp::model::ErrorData,
205) -> Result<Vec<CreateElicitationRequestParams>, UrlElicitationRequiredParseError> {
206    let data = error_data.data.as_ref().ok_or(UrlElicitationRequiredParseError::MissingData)?;
207    let parsed: UrlElicitationRequiredData =
208        serde_json::from_value(data.clone()).map_err(UrlElicitationRequiredParseError::InvalidData)?;
209
210    let url_elicitations = parsed
211        .elicitations
212        .into_iter()
213        .filter(|elicitation| matches!(elicitation, CreateElicitationRequestParams::UrlElicitationParams { .. }))
214        .collect::<Vec<_>>();
215
216    if url_elicitations.is_empty() {
217        return Err(UrlElicitationRequiredParseError::NoUrlRequests);
218    }
219
220    Ok(url_elicitations)
221}
222
223/// Handle a `URL_ELICITATION_REQUIRED` (-32042) error by dispatching each
224/// URL elicitation through the same consent channel used by normal
225/// `create_elicitation` requests.
226async fn handle_url_elicitation_required(
227    client: &Arc<RunningService<RoleClient, McpClient>>,
228    request: &ToolCallRequest,
229    error_data: &rmcp::model::ErrorData,
230) -> ToolCallError {
231    let server_name = client.service().server_name().to_string();
232    let url_elicitations = match parse_required_url_elicitations(error_data) {
233        Ok(url_elicitations) => url_elicitations,
234        Err(UrlElicitationRequiredParseError::NoUrlRequests) => {
235            return ToolCallError::from_request(
236                request,
237                format!("Server '{server_name}' requires URL elicitation but provided no URL elicitation requests"),
238            );
239        }
240        Err(parse_error) => {
241            return ToolCallError::from_request(
242                request,
243                format!("Server '{server_name}' sent an invalid URL elicitation response: {parse_error}"),
244            );
245        }
246    };
247
248    tracing::info!("Server '{server_name}' requires {} URL elicitation(s)", url_elicitations.len());
249
250    for elicitation in url_elicitations {
251        let result = client.service().dispatch_elicitation(elicitation).await;
252        match result.action {
253            rmcp::model::ElicitationAction::Decline => {
254                return ToolCallError::from_request(
255                    request,
256                    format!("Required browser interaction for server '{server_name}' was declined"),
257                );
258            }
259            rmcp::model::ElicitationAction::Cancel => {
260                return ToolCallError::from_request(
261                    request,
262                    format!("Required browser interaction for server '{server_name}' was cancelled"),
263                );
264            }
265            rmcp::model::ElicitationAction::Accept => {
266                tracing::info!("User accepted URL elicitation for server '{server_name}'");
267            }
268        }
269    }
270
271    ToolCallError::from_request(
272        request,
273        format!(
274            "Server '{server_name}' requires a browser flow. The URL has been opened for your approval. Retry the previous request after completing the browser flow."
275        ),
276    )
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `URL_ELICITATION_REQUIRED` error carrying the given `data` payload.
    fn url_elicitation_error(data: Option<serde_json::Value>) -> rmcp::model::ErrorData {
        rmcp::model::ErrorData {
            code: rmcp::model::ErrorCode::URL_ELICITATION_REQUIRED,
            message: "URL elicitation required".into(),
            data,
        }
    }

    #[test]
    fn url_elicitation_required_data_parses_url_entries() {
        let data = serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth?elicitationId=el-1",
                    "elicitationId": "el-1"
                }
            ]
        });

        let parsed: UrlElicitationRequiredData = serde_json::from_value(data).unwrap();
        assert_eq!(parsed.elicitations.len(), 1);
        assert!(matches!(
            &parsed.elicitations[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_filters_to_url_only() {
        let error_data = url_elicitation_error(Some(serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth",
                    "elicitationId": "el-1"
                },
                {
                    "mode": "form",
                    "message": "Pick a color",
                    "requestedSchema": { "type": "object", "properties": {} }
                }
            ]
        })));

        let result = parse_required_url_elicitations(&error_data).unwrap();
        assert_eq!(result.len(), 1);
        assert!(matches!(
            &result[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_missing_data() {
        let error_data = url_elicitation_error(None);
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::MissingData)
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_malformed_data() {
        let error_data = url_elicitation_error(Some(serde_json::json!({ "elicitations": "not-a-list" })));
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::InvalidData(_))
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_form_only_payload() {
        let error_data = url_elicitation_error(Some(serde_json::json!({
            "elicitations": [
                {
                    "mode": "form",
                    "message": "Pick a color",
                    "requestedSchema": { "type": "object", "properties": {} }
                }
            ]
        })));
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::NoUrlRequests)
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_empty_list() {
        let error_data = url_elicitation_error(Some(serde_json::json!({ "elicitations": [] })));
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::NoUrlRequests)
        ));
    }
}