1use mcp_utils::client::mcp_client::McpClient;
2use mcp_utils::client::{McpManager, McpServerStatusEntry};
3use mcp_utils::display_meta::ToolResultMeta;
4
5use futures::future::Either;
6use futures::stream::{self, StreamExt};
7use llm::{ToolCallError, ToolCallRequest, ToolCallResult, ToolDefinition};
8use rmcp::RoleClient;
9use rmcp::model::{
10 CallToolRequestParams, CreateElicitationRequestParams, ErrorCode, GetPromptResult, ProgressNotificationParam,
11 Prompt,
12};
13use rmcp::service::RunningService;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::mpsc;
17use tokio::sync::oneshot;
18
19#[derive(Debug)]
21pub enum ToolExecutionEvent {
22 Started { tool_id: String, tool_name: String },
23 Progress { tool_id: String, progress: ProgressNotificationParam },
24 Complete { tool_id: String, result: Result<ToolCallResult, ToolCallError>, result_meta: Option<ToolResultMeta> },
25}
26
27type AuthResult = Result<(Vec<McpServerStatusEntry>, Vec<ToolDefinition>), String>;
28
29#[derive(Debug)]
31pub enum McpCommand {
32 ExecuteTool {
33 request: ToolCallRequest,
34 timeout: Duration,
35 tx: mpsc::Sender<ToolExecutionEvent>,
36 },
37 ListPrompts {
38 tx: oneshot::Sender<Result<Vec<Prompt>, String>>,
39 },
40 GetPrompt {
41 name: String,
42 arguments: Option<serde_json::Map<String, serde_json::Value>>,
43 tx: oneshot::Sender<Result<GetPromptResult, String>>,
44 },
45 GetServerStatuses {
46 tx: oneshot::Sender<Vec<McpServerStatusEntry>>,
47 },
48 AuthenticateServer {
49 name: String,
50 tx: oneshot::Sender<AuthResult>,
51 },
52}
53
54pub async fn run_mcp_task(mut mcp: McpManager, mut command_rx: mpsc::Receiver<McpCommand>) {
55 while let Some(command) = command_rx.recv().await {
56 on_command(command, &mut mcp).await;
57 }
58
59 mcp.shutdown().await;
60 tracing::debug!("MCP manager task ended");
61}
62
63async fn on_command(command: McpCommand, mcp: &mut McpManager) {
64 match command {
65 McpCommand::ExecuteTool { request, timeout, tx } => {
66 let tool_id = request.id.clone();
67 let tool_name = request.name.clone();
68
69 let _ =
70 tx.send(ToolExecutionEvent::Started { tool_id: tool_id.clone(), tool_name: tool_name.clone() }).await;
71
72 match mcp.get_client_for_tool(&request.name, &request.arguments) {
73 Ok((client, params)) => {
74 tokio::spawn(async move {
75 let outcome =
76 execute_mcp_call(client, &request, params, timeout, tool_id.clone(), tx.clone()).await;
77 let (result, result_meta) = match outcome {
78 Ok((r, m)) => (Ok(r), m),
79 Err(e) => (Err(e), None),
80 };
81 let _ = tx.send(ToolExecutionEvent::Complete { tool_id, result, result_meta }).await;
82 });
83 }
84 Err(e) => {
85 tracing::error!("Failed to get client for tool {}: {e}", request.name);
86 let error = ToolCallError::from_request(&request, format!("Failed to get client: {e}"));
87 let _ =
88 tx.send(ToolExecutionEvent::Complete { tool_id, result: Err(error), result_meta: None }).await;
89 }
90 }
91 }
92
93 McpCommand::ListPrompts { tx } => {
94 let result = mcp.list_prompts().await.map_err(|e| format!("Failed to list prompts: {e}"));
95 let _ = tx.send(result);
96 }
97
98 McpCommand::GetPrompt { name: namespaced_name, arguments, tx } => {
99 let result =
100 mcp.get_prompt(&namespaced_name, arguments).await.map_err(|e| format!("Failed to get prompt: {e}"));
101 let _ = tx.send(result);
102 }
103
104 McpCommand::GetServerStatuses { tx } => {
105 let _ = tx.send(mcp.server_statuses().to_vec());
106 }
107
108 McpCommand::AuthenticateServer { name, tx } => {
109 let result = match mcp.authenticate_server(&name).await {
110 Ok(()) => Ok((mcp.server_statuses().to_vec(), mcp.tool_definitions())),
111 Err(e) => Err(format!("Authentication failed for '{name}': {e}")),
112 };
113 let _ = tx.send(result);
114 }
115 }
116}
117
118async fn execute_mcp_call(
121 client: Arc<RunningService<RoleClient, McpClient>>,
122 request: &ToolCallRequest,
123 params: CallToolRequestParams,
124 timeout: Duration,
125 tool_call_id: String,
126 event_tx: mpsc::Sender<ToolExecutionEvent>,
127) -> Result<(ToolCallResult, Option<ToolResultMeta>), ToolCallError> {
128 use super::tool_bridge::mcp_result_to_tool_call_result;
129 use rmcp::model::{ClientRequest::CallToolRequest, Request, ServerResult};
130 use rmcp::service::PeerRequestOptions;
131
132 let handle = client
133 .send_cancellable_request(CallToolRequest(Request::new(params)), {
134 let mut opts = PeerRequestOptions::default();
135 opts.timeout = Some(timeout);
136 opts
137 })
138 .await
139 .map_err(|e| ToolCallError::from_request(request, format!("Failed to send tool request: {e}")))?;
140
141 let progress_subscriber = client.service().progress_dispatcher.subscribe(handle.progress_token.clone()).await;
142
143 let progress_stream = progress_subscriber
144 .map(move |progress| Either::Left(ToolExecutionEvent::Progress { tool_id: tool_call_id.clone(), progress }));
145
146 let result_stream = stream::once(handle.await_response()).map(Either::Right);
147 let combined_stream = stream::select(progress_stream, result_stream);
148 tokio::pin!(combined_stream);
149
150 let server_result = loop {
151 match combined_stream.next().await {
152 Some(Either::Left(progress_event)) => {
153 let _ = event_tx.send(progress_event).await;
154 }
155 Some(Either::Right(result)) => {
156 break match result {
157 Ok(server_result) => server_result,
158 Err(e) => {
159 if let rmcp::service::ServiceError::McpError(ref error_data) = e
160 && error_data.code == ErrorCode::URL_ELICITATION_REQUIRED
161 {
162 return Err(handle_url_elicitation_required(&client, request, error_data).await);
163 }
164 return Err(ToolCallError::from_request(request, format!("Tool execution failed: {e}")));
165 }
166 };
167 }
168 None => {
169 return Err(ToolCallError::from_request(request, "Stream ended without result"));
170 }
171 }
172 };
173
174 let ServerResult::CallToolResult(mcp_result) = server_result else {
175 return Err(ToolCallError::from_request(request, "Unexpected response type from MCP server"));
176 };
177
178 mcp_result_to_tool_call_result(request, mcp_result)
179}
180
181#[derive(serde::Deserialize)]
182struct UrlElicitationRequiredData {
183 elicitations: Vec<CreateElicitationRequestParams>,
184}
185
186#[derive(Debug)]
187enum UrlElicitationRequiredParseError {
188 MissingData,
189 InvalidData(serde_json::Error),
190 NoUrlRequests,
191}
192
193impl std::fmt::Display for UrlElicitationRequiredParseError {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 match self {
196 Self::MissingData => write!(f, "missing error data"),
197 Self::InvalidData(error) => write!(f, "malformed error data: {error}"),
198 Self::NoUrlRequests => write!(f, "provided no URL elicitation requests"),
199 }
200 }
201}
202
203fn parse_required_url_elicitations(
204 error_data: &rmcp::model::ErrorData,
205) -> Result<Vec<CreateElicitationRequestParams>, UrlElicitationRequiredParseError> {
206 let data = error_data.data.as_ref().ok_or(UrlElicitationRequiredParseError::MissingData)?;
207 let parsed: UrlElicitationRequiredData =
208 serde_json::from_value(data.clone()).map_err(UrlElicitationRequiredParseError::InvalidData)?;
209
210 let url_elicitations = parsed
211 .elicitations
212 .into_iter()
213 .filter(|elicitation| matches!(elicitation, CreateElicitationRequestParams::UrlElicitationParams { .. }))
214 .collect::<Vec<_>>();
215
216 if url_elicitations.is_empty() {
217 return Err(UrlElicitationRequiredParseError::NoUrlRequests);
218 }
219
220 Ok(url_elicitations)
221}
222
223async fn handle_url_elicitation_required(
227 client: &Arc<RunningService<RoleClient, McpClient>>,
228 request: &ToolCallRequest,
229 error_data: &rmcp::model::ErrorData,
230) -> ToolCallError {
231 let server_name = client.service().server_name().to_string();
232 let url_elicitations = match parse_required_url_elicitations(error_data) {
233 Ok(url_elicitations) => url_elicitations,
234 Err(UrlElicitationRequiredParseError::NoUrlRequests) => {
235 return ToolCallError::from_request(
236 request,
237 format!("Server '{server_name}' requires URL elicitation but provided no URL elicitation requests"),
238 );
239 }
240 Err(parse_error) => {
241 return ToolCallError::from_request(
242 request,
243 format!("Server '{server_name}' sent an invalid URL elicitation response: {parse_error}"),
244 );
245 }
246 };
247
248 tracing::info!("Server '{server_name}' requires {} URL elicitation(s)", url_elicitations.len());
249
250 for elicitation in url_elicitations {
251 let result = client.service().dispatch_elicitation(elicitation).await;
252 match result.action {
253 rmcp::model::ElicitationAction::Decline => {
254 return ToolCallError::from_request(
255 request,
256 format!("Required browser interaction for server '{server_name}' was declined"),
257 );
258 }
259 rmcp::model::ElicitationAction::Cancel => {
260 return ToolCallError::from_request(
261 request,
262 format!("Required browser interaction for server '{server_name}' was cancelled"),
263 );
264 }
265 rmcp::model::ElicitationAction::Accept => {
266 tracing::info!("User accepted URL elicitation for server '{server_name}'");
267 }
268 }
269 }
270
271 ToolCallError::from_request(
272 request,
273 format!(
274 "Server '{server_name}' requires a browser flow. The URL has been opened for your approval. Retry the previous request after completing the browser flow."
275 ),
276 )
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `URL_ELICITATION_REQUIRED` error carrying `data`.
    fn url_elicitation_error(data: Option<serde_json::Value>) -> rmcp::model::ErrorData {
        rmcp::model::ErrorData {
            code: rmcp::model::ErrorCode::URL_ELICITATION_REQUIRED,
            message: "URL elicitation required".into(),
            data,
        }
    }

    #[test]
    fn url_elicitation_required_data_parses_url_entries() {
        let data = serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth?elicitationId=el-1",
                    "elicitationId": "el-1"
                }
            ]
        });

        let parsed: UrlElicitationRequiredData = serde_json::from_value(data).unwrap();
        assert_eq!(parsed.elicitations.len(), 1);
        assert!(matches!(
            &parsed.elicitations[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_filters_to_url_only() {
        let error_data = url_elicitation_error(Some(serde_json::json!({
            "elicitations": [
                {
                    "mode": "url",
                    "message": "Auth",
                    "url": "https://example.com/auth",
                    "elicitationId": "el-1"
                },
                {
                    "mode": "form",
                    "message": "Pick a color",
                    "requestedSchema": { "type": "object", "properties": {} }
                }
            ]
        })));

        let result = parse_required_url_elicitations(&error_data).unwrap();
        assert_eq!(result.len(), 1);
        assert!(matches!(
            &result[0],
            CreateElicitationRequestParams::UrlElicitationParams { elicitation_id, .. } if elicitation_id == "el-1"
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_missing_data() {
        let error_data = url_elicitation_error(None);
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::MissingData)
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_malformed_data() {
        let error_data = url_elicitation_error(Some(serde_json::json!({ "elicitations": "not-a-list" })));
        assert!(matches!(
            parse_required_url_elicitations(&error_data),
            Err(UrlElicitationRequiredParseError::InvalidData(_))
        ));
    }

    #[test]
    fn parse_required_url_elicitations_rejects_payload_without_url_entries() {
        let form_only = url_elicitation_error(Some(serde_json::json!({
            "elicitations": [
                {
                    "mode": "form",
                    "message": "Pick a color",
                    "requestedSchema": { "type": "object", "properties": {} }
                }
            ]
        })));
        assert!(matches!(
            parse_required_url_elicitations(&form_only),
            Err(UrlElicitationRequiredParseError::NoUrlRequests)
        ));

        let empty = url_elicitation_error(Some(serde_json::json!({ "elicitations": [] })));
        assert!(matches!(
            parse_required_url_elicitations(&empty),
            Err(UrlElicitationRequiredParseError::NoUrlRequests)
        ));
    }

    #[test]
    fn parse_error_display_messages() {
        assert_eq!(UrlElicitationRequiredParseError::MissingData.to_string(), "missing error data");
        assert_eq!(
            UrlElicitationRequiredParseError::NoUrlRequests.to_string(),
            "provided no URL elicitation requests"
        );
    }
}