1use mcp_utils::client::mcp_client::McpClient;
2use mcp_utils::client::{McpManager, McpServerStatusEntry};
3use mcp_utils::display_meta::ToolResultMeta;
4
5use futures::future::Either;
6use futures::stream::{self, StreamExt};
7use llm::{ToolCallError, ToolCallRequest, ToolCallResult, ToolDefinition};
8use rmcp::RoleClient;
9use rmcp::model::{CallToolRequestParams, GetPromptResult, ProgressNotificationParam, Prompt};
10use rmcp::service::RunningService;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::mpsc;
14use tokio::sync::oneshot;
15
16#[derive(Debug)]
18pub enum ToolExecutionEvent {
19 Started {
20 tool_id: String,
21 tool_name: String,
22 },
23 Progress {
24 tool_id: String,
25 progress: ProgressNotificationParam,
26 },
27 Complete {
28 tool_id: String,
29 result: Result<ToolCallResult, ToolCallError>,
30 result_meta: Option<ToolResultMeta>,
31 },
32}
33
34type AuthResult = Result<(Vec<McpServerStatusEntry>, Vec<ToolDefinition>), String>;
35
36#[derive(Debug)]
38pub enum McpCommand {
39 ExecuteTool {
40 request: ToolCallRequest,
41 timeout: Duration,
42 tx: mpsc::Sender<ToolExecutionEvent>,
43 },
44 ListPrompts {
45 tx: oneshot::Sender<Result<Vec<Prompt>, String>>,
46 },
47 GetPrompt {
48 name: String,
49 arguments: Option<serde_json::Map<String, serde_json::Value>>,
50 tx: oneshot::Sender<Result<GetPromptResult, String>>,
51 },
52 GetServerStatuses {
53 tx: oneshot::Sender<Vec<McpServerStatusEntry>>,
54 },
55 AuthenticateServer {
56 name: String,
57 tx: oneshot::Sender<AuthResult>,
58 },
59}
60
61pub async fn run_mcp_task(mut mcp: McpManager, mut command_rx: mpsc::Receiver<McpCommand>) {
62 while let Some(command) = command_rx.recv().await {
63 on_command(command, &mut mcp).await;
64 }
65
66 mcp.shutdown().await;
67 tracing::debug!("MCP manager task ended");
68}
69
70async fn on_command(command: McpCommand, mcp: &mut McpManager) {
71 match command {
72 McpCommand::ExecuteTool {
73 request,
74 timeout,
75 tx,
76 } => {
77 let tool_id = request.id.clone();
78 let tool_name = request.name.clone();
79
80 let _ = tx
81 .send(ToolExecutionEvent::Started {
82 tool_id: tool_id.clone(),
83 tool_name: tool_name.clone(),
84 })
85 .await;
86
87 match mcp.get_client_for_tool(&request.name, &request.arguments) {
88 Ok((client, params)) => {
89 tokio::spawn(async move {
90 let outcome = execute_mcp_call(
91 client,
92 &request,
93 params,
94 timeout,
95 tool_id.clone(),
96 tx.clone(),
97 )
98 .await;
99 let (result, result_meta) = match outcome {
100 Ok((r, m)) => (Ok(r), m),
101 Err(e) => (Err(e), None),
102 };
103 let _ = tx
104 .send(ToolExecutionEvent::Complete {
105 tool_id,
106 result,
107 result_meta,
108 })
109 .await;
110 });
111 }
112 Err(e) => {
113 tracing::error!("Failed to get client for tool {}: {e}", request.name);
114 let error =
115 ToolCallError::from_request(&request, format!("Failed to get client: {e}"));
116 let _ = tx
117 .send(ToolExecutionEvent::Complete {
118 tool_id,
119 result: Err(error),
120 result_meta: None,
121 })
122 .await;
123 }
124 }
125 }
126
127 McpCommand::ListPrompts { tx } => {
128 let result = mcp
129 .list_prompts()
130 .await
131 .map_err(|e| format!("Failed to list prompts: {e}"));
132 let _ = tx.send(result);
133 }
134
135 McpCommand::GetPrompt {
136 name: namespaced_name,
137 arguments,
138 tx,
139 } => {
140 let result = mcp
141 .get_prompt(&namespaced_name, arguments)
142 .await
143 .map_err(|e| format!("Failed to get prompt: {e}"));
144 let _ = tx.send(result);
145 }
146
147 McpCommand::GetServerStatuses { tx } => {
148 let _ = tx.send(mcp.server_statuses().to_vec());
149 }
150
151 McpCommand::AuthenticateServer { name, tx } => {
152 let result = match mcp.authenticate_server(&name).await {
153 Ok(()) => Ok((mcp.server_statuses().to_vec(), mcp.tool_definitions())),
154 Err(e) => Err(format!("Authentication failed for '{name}': {e}")),
155 };
156 let _ = tx.send(result);
157 }
158 }
159}
160
161async fn execute_mcp_call(
164 client: Arc<RunningService<RoleClient, McpClient>>,
165 request: &ToolCallRequest,
166 params: CallToolRequestParams,
167 timeout: Duration,
168 tool_call_id: String,
169 event_tx: mpsc::Sender<ToolExecutionEvent>,
170) -> Result<(ToolCallResult, Option<ToolResultMeta>), ToolCallError> {
171 use super::tool_bridge::mcp_result_to_tool_call_result;
172 use rmcp::model::{ClientRequest::CallToolRequest, Request, ServerResult};
173 use rmcp::service::PeerRequestOptions;
174
175 let handle = client
176 .send_cancellable_request(CallToolRequest(Request::new(params)), {
177 let mut opts = PeerRequestOptions::default();
178 opts.timeout = Some(timeout);
179 opts
180 })
181 .await
182 .map_err(|e| {
183 ToolCallError::from_request(request, format!("Failed to send tool request: {e}"))
184 })?;
185
186 let progress_subscriber = client
187 .service()
188 .progress_dispatcher
189 .subscribe(handle.progress_token.clone())
190 .await;
191
192 let progress_stream = progress_subscriber.map(move |progress| {
193 Either::Left(ToolExecutionEvent::Progress {
194 tool_id: tool_call_id.clone(),
195 progress,
196 })
197 });
198
199 let result_stream = stream::once(handle.await_response()).map(Either::Right);
200 let combined_stream = stream::select(progress_stream, result_stream);
201 tokio::pin!(combined_stream);
202
203 let server_result = loop {
204 match combined_stream.next().await {
205 Some(Either::Left(progress_event)) => {
206 let _ = event_tx.send(progress_event).await;
207 }
208 Some(Either::Right(result)) => {
209 break result.map_err(|e| {
210 ToolCallError::from_request(request, format!("Tool execution failed: {e}"))
211 })?;
212 }
213 None => {
214 return Err(ToolCallError::from_request(
215 request,
216 "Stream ended without result",
217 ));
218 }
219 }
220 };
221
222 let ServerResult::CallToolResult(mcp_result) = server_result else {
223 return Err(ToolCallError::from_request(
224 request,
225 "Unexpected response type from MCP server",
226 ));
227 };
228
229 mcp_result_to_tool_call_result(request, mcp_result)
230}