Skip to main content

aether_core/mcp/
run_mcp_task.rs

1use mcp_utils::client::mcp_client::McpClient;
2use mcp_utils::client::{McpManager, McpServerStatusEntry};
3use mcp_utils::display_meta::ToolResultMeta;
4
5use futures::future::Either;
6use futures::stream::{self, StreamExt};
7use llm::{ToolCallError, ToolCallRequest, ToolCallResult, ToolDefinition};
8use rmcp::RoleClient;
9use rmcp::model::{CallToolRequestParams, GetPromptResult, ProgressNotificationParam, Prompt};
10use rmcp::service::RunningService;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::mpsc;
14use tokio::sync::oneshot;
15
16/// Events emitted during tool execution lifecycle
17#[derive(Debug)]
18pub enum ToolExecutionEvent {
19    Started {
20        tool_id: String,
21        tool_name: String,
22    },
23    Progress {
24        tool_id: String,
25        progress: ProgressNotificationParam,
26    },
27    Complete {
28        tool_id: String,
29        result: Result<ToolCallResult, ToolCallError>,
30        result_meta: Option<ToolResultMeta>,
31    },
32}
33
34type AuthResult = Result<(Vec<McpServerStatusEntry>, Vec<ToolDefinition>), String>;
35
36/// Commands that can be sent to the MCP manager task
37#[derive(Debug)]
38pub enum McpCommand {
39    ExecuteTool {
40        request: ToolCallRequest,
41        timeout: Duration,
42        tx: mpsc::Sender<ToolExecutionEvent>,
43    },
44    ListPrompts {
45        tx: oneshot::Sender<Result<Vec<Prompt>, String>>,
46    },
47    GetPrompt {
48        name: String,
49        arguments: Option<serde_json::Map<String, serde_json::Value>>,
50        tx: oneshot::Sender<Result<GetPromptResult, String>>,
51    },
52    GetServerStatuses {
53        tx: oneshot::Sender<Vec<McpServerStatusEntry>>,
54    },
55    AuthenticateServer {
56        name: String,
57        tx: oneshot::Sender<AuthResult>,
58    },
59}
60
61pub async fn run_mcp_task(mut mcp: McpManager, mut command_rx: mpsc::Receiver<McpCommand>) {
62    while let Some(command) = command_rx.recv().await {
63        on_command(command, &mut mcp).await;
64    }
65
66    mcp.shutdown().await;
67    tracing::debug!("MCP manager task ended");
68}
69
70async fn on_command(command: McpCommand, mcp: &mut McpManager) {
71    match command {
72        McpCommand::ExecuteTool {
73            request,
74            timeout,
75            tx,
76        } => {
77            let tool_id = request.id.clone();
78            let tool_name = request.name.clone();
79
80            let _ = tx
81                .send(ToolExecutionEvent::Started {
82                    tool_id: tool_id.clone(),
83                    tool_name: tool_name.clone(),
84                })
85                .await;
86
87            match mcp.get_client_for_tool(&request.name, &request.arguments) {
88                Ok((client, params)) => {
89                    tokio::spawn(async move {
90                        let outcome = execute_mcp_call(
91                            client,
92                            &request,
93                            params,
94                            timeout,
95                            tool_id.clone(),
96                            tx.clone(),
97                        )
98                        .await;
99                        let (result, result_meta) = match outcome {
100                            Ok((r, m)) => (Ok(r), m),
101                            Err(e) => (Err(e), None),
102                        };
103                        let _ = tx
104                            .send(ToolExecutionEvent::Complete {
105                                tool_id,
106                                result,
107                                result_meta,
108                            })
109                            .await;
110                    });
111                }
112                Err(e) => {
113                    tracing::error!("Failed to get client for tool {}: {e}", request.name);
114                    let error =
115                        ToolCallError::from_request(&request, format!("Failed to get client: {e}"));
116                    let _ = tx
117                        .send(ToolExecutionEvent::Complete {
118                            tool_id,
119                            result: Err(error),
120                            result_meta: None,
121                        })
122                        .await;
123                }
124            }
125        }
126
127        McpCommand::ListPrompts { tx } => {
128            let result = mcp
129                .list_prompts()
130                .await
131                .map_err(|e| format!("Failed to list prompts: {e}"));
132            let _ = tx.send(result);
133        }
134
135        McpCommand::GetPrompt {
136            name: namespaced_name,
137            arguments,
138            tx,
139        } => {
140            let result = mcp
141                .get_prompt(&namespaced_name, arguments)
142                .await
143                .map_err(|e| format!("Failed to get prompt: {e}"));
144            let _ = tx.send(result);
145        }
146
147        McpCommand::GetServerStatuses { tx } => {
148            let _ = tx.send(mcp.server_statuses().to_vec());
149        }
150
151        McpCommand::AuthenticateServer { name, tx } => {
152            let result = match mcp.authenticate_server(&name).await {
153                Ok(()) => Ok((mcp.server_statuses().to_vec(), mcp.tool_definitions())),
154                Err(e) => Err(format!("Authentication failed for '{name}': {e}")),
155            };
156            let _ = tx.send(result);
157        }
158    }
159}
160
161/// Shared logic for sending an MCP tool call, streaming progress events,
162/// and collecting the result.
163async fn execute_mcp_call(
164    client: Arc<RunningService<RoleClient, McpClient>>,
165    request: &ToolCallRequest,
166    params: CallToolRequestParams,
167    timeout: Duration,
168    tool_call_id: String,
169    event_tx: mpsc::Sender<ToolExecutionEvent>,
170) -> Result<(ToolCallResult, Option<ToolResultMeta>), ToolCallError> {
171    use super::tool_bridge::mcp_result_to_tool_call_result;
172    use rmcp::model::{ClientRequest::CallToolRequest, Request, ServerResult};
173    use rmcp::service::PeerRequestOptions;
174
175    let handle = client
176        .send_cancellable_request(CallToolRequest(Request::new(params)), {
177            let mut opts = PeerRequestOptions::default();
178            opts.timeout = Some(timeout);
179            opts
180        })
181        .await
182        .map_err(|e| {
183            ToolCallError::from_request(request, format!("Failed to send tool request: {e}"))
184        })?;
185
186    let progress_subscriber = client
187        .service()
188        .progress_dispatcher
189        .subscribe(handle.progress_token.clone())
190        .await;
191
192    let progress_stream = progress_subscriber.map(move |progress| {
193        Either::Left(ToolExecutionEvent::Progress {
194            tool_id: tool_call_id.clone(),
195            progress,
196        })
197    });
198
199    let result_stream = stream::once(handle.await_response()).map(Either::Right);
200    let combined_stream = stream::select(progress_stream, result_stream);
201    tokio::pin!(combined_stream);
202
203    let server_result = loop {
204        match combined_stream.next().await {
205            Some(Either::Left(progress_event)) => {
206                let _ = event_tx.send(progress_event).await;
207            }
208            Some(Either::Right(result)) => {
209                break result.map_err(|e| {
210                    ToolCallError::from_request(request, format!("Tool execution failed: {e}"))
211                })?;
212            }
213            None => {
214                return Err(ToolCallError::from_request(
215                    request,
216                    "Stream ended without result",
217                ));
218            }
219        }
220    };
221
222    let ServerResult::CallToolResult(mcp_result) = server_result else {
223        return Err(ToolCallError::from_request(
224            request,
225            "Unexpected response type from MCP server",
226        ));
227    };
228
229    mcp_result_to_tool_call_result(request, mcp_result)
230}