1use super::{
2 McpClientEvent, McpError, OAuthHandlerFactory, Result,
3 config::{McpServer, McpTransport},
4 mcp_client::McpClient,
5 roots::primary_root_path,
6};
7use crate::{client::OAuthHandlerContext, transport::create_in_memory_transport};
8use aether_auth::{OAuthCredentialStorage, create_auth_manager_from_store, perform_oauth_flow};
9use llm::ToolAnnotations;
10use rmcp::{
11 RoleClient, RoleServer, ServiceExt,
12 model::{ClientInfo, Root, Tool as RmcpTool},
13 serve_client,
14 service::{DynService, RunningService},
15 transport::{
16 StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
17 streamable_http_client::StreamableHttpClientTransportConfig,
18 },
19};
20use serde_json::Value;
21use std::collections::HashMap;
22use std::path::PathBuf;
23use std::process::Stdio;
24use std::sync::Arc;
25use tokio::{
26 io::{AsyncBufReadExt, BufReader},
27 process::{ChildStderr, Command},
28 sync::{RwLock, mpsc},
29 task::JoinHandle,
30};
31
/// Crate-level view of a tool advertised by an MCP server.
///
/// Built from an rmcp tool definition via the `From<RmcpTool>` / `From<&RmcpTool>` impls.
#[derive(Debug, Clone)]
pub struct Tool {
    /// Tool name as reported by the server.
    pub name: String,
    /// Human-readable description; the `From` conversions use an empty string when the
    /// server supplied none.
    pub description: String,
    /// The tool's JSON input schema; the `From` conversions store it as a `Value::Object`.
    pub parameters: Value,
    /// Behavioral hints (read-only, destructive, ...) if the server sent annotations.
    pub annotations: Option<ToolAnnotations>,
}
39
40pub(crate) fn convert_tool_annotations(annotations: &rmcp::model::ToolAnnotations) -> ToolAnnotations {
41 ToolAnnotations {
42 title: annotations.title.clone(),
43 read_only_hint: annotations.read_only_hint,
44 destructive_hint: annotations.destructive_hint,
45 idempotent_hint: annotations.idempotent_hint,
46 open_world_hint: annotations.open_world_hint,
47 }
48}
49
50impl From<RmcpTool> for Tool {
51 fn from(tool: RmcpTool) -> Self {
52 Self::from(&tool)
53 }
54}
55
56impl From<&RmcpTool> for Tool {
57 fn from(tool: &RmcpTool) -> Self {
58 Self {
59 name: tool.name.to_string(),
60 description: tool.description.clone().unwrap_or_default().to_string(),
61 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
62 annotations: tool.annotations.as_ref().map(convert_tool_annotations),
63 }
64 }
65}
66
/// Shared inputs needed to connect to any configured MCP server.
pub(super) struct ConnectConfig {
    /// Client identity handed to every `McpClient` created for a server.
    pub client_info: ClientInfo,
    /// Channel for client events; also passed to OAuth handlers via `OAuthHandlerContext`.
    pub event_sender: mpsc::Sender<McpClientEvent>,
    /// Shared workspace roots. For stdio servers, `primary_root_path` of this list is used as
    /// the child process's working directory.
    pub roots: Arc<RwLock<Vec<Root>>>,
    /// When present, HTTP servers without a static `auth_header` can go through an interactive
    /// OAuth flow (see `NeedsOAuth` / `authenticate_http`).
    pub oauth_handler_factory: Option<OAuthHandlerFactory>,
    /// Persisted OAuth credentials, used to build an auth manager before connecting and passed
    /// to `perform_oauth_flow`.
    pub oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
}
74
/// Result of trying to connect to a single configured MCP server.
pub struct McpConnectAttempt {
    /// Name of the server this attempt was for.
    pub name: String,
    /// Copied from the server config's `proxy` flag.
    pub proxied: bool,
    /// Whether the server connected, needs OAuth, or failed.
    pub outcome: McpConnectOutcome,
}
81
/// Outcome of a connection attempt.
pub enum McpConnectOutcome {
    /// The connection is up. `reauth_config` is `Some` when the server is an HTTP server that
    /// can be re-authenticated via OAuth later (an OAuth factory exists and no static auth header
    /// is configured).
    Connected { conn: McpServerConnection, reauth_config: Option<StreamableHttpClientTransportConfig> },
    /// An HTTP connection failed while OAuth was available and no static auth header was set;
    /// `config` can be handed to `authenticate_http` to retry after an OAuth flow.
    NeedsOAuth { config: StreamableHttpClientTransportConfig, error: McpError },
    /// The connection failed and no OAuth retry is possible.
    Failed { error: McpError },
}
87
88impl McpConnectAttempt {
89 pub fn failed(name: impl Into<String>, error: McpError, proxied: bool) -> Self {
90 Self { name: name.into(), proxied, outcome: McpConnectOutcome::Failed { error } }
91 }
92}
93
/// A live connection to an MCP server.
pub struct McpServerConnection {
    /// The running rmcp client service, shared via `Arc`.
    pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
    /// Background task tied to this connection: the stderr logger for stdio servers, the
    /// server-hosting task for in-memory servers, and `None` for HTTP servers.
    pub(super) server_task: Option<JoinHandle<()>>,
    /// Non-empty `instructions` the server returned during initialization, if any.
    pub(super) instructions: Option<String>,
}
99
100impl McpServerConnection {
101 pub(super) async fn reconnect_with_auth(
102 name: &str,
103 config: StreamableHttpClientTransportConfig,
104 auth_client: AuthClient<reqwest::Client>,
105 mcp_client: McpClient,
106 ) -> Result<Self> {
107 let transport = StreamableHttpClientTransport::with_client(auth_client, config);
108 let client = serve_client(mcp_client, transport)
109 .await
110 .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
111 Ok(Self::from_parts(client, None))
112 }
113
114 pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
115 let response = self
116 .client
117 .list_tools(None)
118 .await
119 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
120 Ok(response.tools)
121 }
122
123 fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
124 let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
125 Self { client: Arc::new(client), server_task, instructions }
126 }
127}
128
129pub(super) async fn connect_server(server: McpServer, ctx: &ConnectConfig) -> McpConnectAttempt {
130 let McpServer { name, transport, proxy: proxied } = server;
131 let reauth_config = reauth_config_for(&transport, ctx.oauth_handler_factory.as_ref());
132 let mcp_client =
133 McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(&ctx.roots));
134
135 let outcome = match transport {
136 McpTransport::Stdio { command, args, env } => {
137 let roots = ctx.roots.read().await;
138 let cwd = primary_root_path(&roots);
139 connect_stdio(&name, command, args, env, mcp_client, cwd).await
140 }
141 McpTransport::InMemory { server } => connect_in_memory(&name, server, mcp_client).await,
142 McpTransport::Http { config } => {
143 connect_http(
144 &name,
145 config,
146 mcp_client,
147 ctx.oauth_handler_factory.as_ref(),
148 ctx.oauth_credential_store.as_ref(),
149 )
150 .await
151 }
152 };
153
154 McpConnectAttempt { name, proxied, outcome: outcome.with_reauth(reauth_config) }
155}
156
157pub async fn authenticate_http(
158 name: String,
159 config: StreamableHttpClientTransportConfig,
160 ctx: Arc<ConnectConfig>,
161 proxied: bool,
162) -> McpConnectAttempt {
163 let outcome = match async {
164 let factory = ctx
165 .oauth_handler_factory
166 .as_ref()
167 .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")))?;
168 let handler = factory(OAuthHandlerContext { server_name: name.clone(), tx: ctx.event_sender.clone() })?;
169
170 let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref(), ctx.oauth_credential_store.clone())
171 .await
172 .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
173
174 let mcp_client =
175 McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(&ctx.roots));
176 McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await
177 }
178 .await
179 {
180 Ok(conn) => McpConnectOutcome::Connected { conn, reauth_config: Some(config) },
181 Err(error) => McpConnectOutcome::Failed { error },
182 };
183
184 McpConnectAttempt { name, proxied, outcome }
185}
186
187impl McpConnectOutcome {
188 fn with_reauth(self, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
189 match self {
190 Self::Connected { conn, .. } => Self::Connected { conn, reauth_config },
191 other => other,
192 }
193 }
194}
195
196async fn connect_stdio(
197 server_name: &str,
198 command: String,
199 args: Vec<String>,
200 env: HashMap<String, String>,
201 mcp_client: McpClient,
202 cwd: Option<PathBuf>,
203) -> McpConnectOutcome {
204 let mut cmd = Command::new(&command);
205 cmd.args(&args).envs(&env);
206
207 if let Some(ref dir) = cwd {
208 cmd.current_dir(dir);
209 }
210
211 let (proc, stderr) = match TokioChildProcess::builder(cmd).stderr(Stdio::piped()).spawn() {
212 Ok(parts) => parts,
213 Err(e) => return McpConnectOutcome::Failed { error: McpError::SpawnFailed { command, reason: e.to_string() } },
214 };
215 let stderr_task = stderr.map(|stderr| spawn_stderr_logger(server_name.to_string(), stderr));
216
217 match serve_client(mcp_client, proc).await {
218 Ok(client) => McpConnectOutcome::Connected {
219 conn: McpServerConnection::from_parts(client, stderr_task),
220 reauth_config: None,
221 },
222 Err(e) => {
223 if let Some(task) = stderr_task {
224 task.abort();
225 }
226 McpConnectOutcome::Failed { error: McpError::from(e) }
227 }
228 }
229}
230
231fn spawn_stderr_logger(server_name: String, stderr: ChildStderr) -> JoinHandle<()> {
232 tokio::spawn(async move {
233 let mut lines = BufReader::new(stderr).lines();
234 loop {
235 match lines.next_line().await {
236 Ok(Some(line)) => tracing::info!(server = %server_name, stderr = %line, "MCP server stderr"),
237 Ok(None) => break,
238 Err(error) => {
239 tracing::warn!(server = %server_name, %error, "failed to read MCP server stderr");
240 break;
241 }
242 }
243 }
244 })
245}
246
247async fn connect_in_memory(
248 name: &str,
249 server: Box<dyn DynService<RoleServer>>,
250 mcp_client: McpClient,
251) -> McpConnectOutcome {
252 match serve_in_memory(server, mcp_client, name).await {
253 Ok((client, handle)) => McpConnectOutcome::Connected {
254 conn: McpServerConnection::from_parts(client, Some(handle)),
255 reauth_config: None,
256 },
257 Err(error) => McpConnectOutcome::Failed { error },
258 }
259}
260
261async fn connect_http(
262 name: &str,
263 config: StreamableHttpClientTransportConfig,
264 mcp_client: McpClient,
265 oauth_handler_factory: Option<&OAuthHandlerFactory>,
266 oauth_credential_store: Option<&Arc<dyn OAuthCredentialStorage>>,
267) -> McpConnectOutcome {
268 let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
269 let stored_auth_manager = if let Some(store) = oauth_credential_store
270 && config.auth_header.is_none()
271 {
272 match create_auth_manager_from_store(name, &config.uri, Arc::clone(store)).await {
273 Ok(manager) => manager,
274 Err(e) => {
275 tracing::warn!(
276 server = %name,
277 error = %e,
278 "Failed to initialize auth manager from stored credentials, proceeding without auth"
279 );
280 None
281 }
282 }
283 } else {
284 None
285 };
286 let result = if let Some(auth_manager) = stored_auth_manager {
287 tracing::debug!("Using OAuth for server '{name}'");
288 let auth_client = AuthClient::new(reqwest::Client::default(), auth_manager);
289 let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
290 serve_client(mcp_client, transport).await.map_err(conn_err)
291 } else {
292 let transport = StreamableHttpClientTransport::from_config(config.clone());
293 serve_client(mcp_client, transport).await.map_err(conn_err)
294 };
295
296 match result {
297 Ok(client) => {
298 McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
299 }
300 Err(error) => {
301 tracing::warn!("Failed to connect to MCP server '{name}': {error}");
302 if oauth_handler_factory.is_some() && config.auth_header.is_none() {
303 McpConnectOutcome::NeedsOAuth { config, error }
304 } else {
305 McpConnectOutcome::Failed { error }
306 }
307 }
308 }
309}
310
311fn reauth_config_for(
312 transport: &McpTransport,
313 oauth_handler_factory: Option<&OAuthHandlerFactory>,
314) -> Option<StreamableHttpClientTransportConfig> {
315 match transport {
316 McpTransport::Http { config } if oauth_handler_factory.is_some() && config.auth_header.is_none() => {
317 Some(config.clone())
318 }
319 _ => None,
320 }
321}
322
323async fn serve_in_memory(
324 server: Box<dyn DynService<RoleServer>>,
325 mcp_client: McpClient,
326 label: &str,
327) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
328 let (client_transport, server_transport) = create_in_memory_transport();
329
330 let server_handle = tokio::spawn(async move {
331 match server.serve(server_transport).await {
332 Ok(_service) => {
333 std::future::pending::<()>().await;
334 }
335 Err(e) => {
336 eprintln!("MCP server error: {e}");
337 }
338 }
339 });
340
341 let client = serve_client(mcp_client, client_transport)
342 .await
343 .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
344
345 Ok((client, server_handle))
346}