1use super::{
2 McpClientEvent, McpError, OAuthHandlerFactory, Result,
3 config::{McpServer, McpTransport},
4 mcp_client::McpClient,
5};
6use crate::{client::OAuthHandlerContext, transport::create_in_memory_transport};
7use aether_auth::{OAuthCredentialStorage, create_auth_manager_from_store, perform_oauth_flow};
8use rmcp::{
9 RoleClient, RoleServer, ServiceExt,
10 model::{ClientInfo, Root, Tool as RmcpTool},
11 serve_client,
12 service::{DynService, RunningService},
13 transport::{
14 StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
15 streamable_http_client::StreamableHttpClientTransportConfig,
16 },
17};
18use serde_json::Value;
19use std::collections::HashMap;
20use std::process::Stdio;
21use std::sync::Arc;
22use tokio::{
23 io::{AsyncBufReadExt, BufReader},
24 process::{ChildStderr, Command},
25 sync::{RwLock, mpsc},
26 task::JoinHandle,
27};
28
/// Instruction text paired with the name of the MCP server it belongs to.
///
/// NOTE(review): presumably filled from `McpServerConnection::instructions`;
/// the construction site is not visible in this part of the file — confirm.
#[derive(Debug, Clone)]
pub struct ServerInstructions {
    /// Name of the server the instructions are associated with.
    pub server_name: String,
    /// The instruction text itself.
    pub instructions: String,
}
34
/// Local, simplified view of an MCP tool definition.
///
/// Built from an rmcp tool via the `From` impls below, which keep only the
/// description and the input schema.
#[derive(Debug, Clone)]
pub struct Tool {
    /// Tool description; the `From` impls use an empty string when the source tool has none.
    pub description: String,
    /// The tool's input schema, stored as a JSON object value.
    pub parameters: Value,
}
40
41impl From<RmcpTool> for Tool {
42 fn from(tool: RmcpTool) -> Self {
43 Self {
44 description: tool.description.unwrap_or_default().to_string(),
45 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
46 }
47 }
48}
49
50impl From<&RmcpTool> for Tool {
51 fn from(tool: &RmcpTool) -> Self {
52 Self {
53 description: tool.description.clone().unwrap_or_default().to_string(),
54 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
55 }
56 }
57}
58
/// Borrowed inputs shared by every `connect_server` call.
pub(super) struct ConnectContext<'a> {
    /// Client info cloned into each `McpClient` that `connect_server` creates.
    pub client_info: &'a ClientInfo,
    /// Event channel sender; a clone is handed to each new `McpClient`.
    pub event_sender: &'a mpsc::Sender<McpClientEvent>,
    /// Shared list of roots; the `Arc` is cloned into each new `McpClient`.
    pub roots: &'a Arc<RwLock<Vec<Root>>>,
    /// Optional OAuth handler factory. Its presence decides whether a failed,
    /// header-less HTTP connection is reported as `NeedsOAuth`.
    pub oauth_handler_factory: Option<&'a OAuthHandlerFactory>,
    /// Optional credential store consulted by `connect_http` before connecting.
    pub oauth_credential_store: Option<&'a Arc<dyn OAuthCredentialStorage>>,
}
66
/// Result of one attempt to connect to a configured MCP server.
pub struct McpConnectAttempt {
    /// Name of the server the attempt was made for.
    pub name: String,
    /// Copied from the server config's `proxy` field in `connect_server`.
    pub proxied: bool,
    /// What happened: connected, needs OAuth, or failed.
    pub outcome: McpConnectOutcome,
}
73
/// Outcome of a connection attempt.
pub enum McpConnectOutcome {
    /// The connection was established. `reauth_config` is `Some` for HTTP servers
    /// without a static auth header when an OAuth handler factory is configured
    /// (see `reauth_config_for`), and after a successful `authenticate_http`.
    Connected { conn: McpServerConnection, reauth_config: Option<StreamableHttpClientTransportConfig> },
    /// An HTTP connection failed while an OAuth handler factory was available and
    /// no static auth header was configured; `config` is the transport config that failed.
    NeedsOAuth { config: StreamableHttpClientTransportConfig, error: McpError },
    /// The attempt failed with `error`.
    Failed { error: McpError },
}
79
80impl McpConnectAttempt {
81 pub fn failed(name: impl Into<String>, error: McpError, proxied: bool) -> Self {
82 Self { name: name.into(), proxied, outcome: McpConnectOutcome::Failed { error } }
83 }
84}
85
/// A live connection to one MCP server.
pub struct McpServerConnection {
    /// The running rmcp client service for this server.
    pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
    /// Background task tied to the connection: the stderr logger for stdio servers,
    /// the server task for in-memory servers, and `None` for HTTP connections.
    pub(super) server_task: Option<JoinHandle<()>>,
    /// Non-empty instructions taken from the peer info at connect time, if any.
    pub(super) instructions: Option<String>,
}
91
92impl McpServerConnection {
93 pub(super) async fn reconnect_with_auth(
94 name: &str,
95 config: StreamableHttpClientTransportConfig,
96 auth_client: AuthClient<reqwest::Client>,
97 mcp_client: McpClient,
98 ) -> Result<Self> {
99 let transport = StreamableHttpClientTransport::with_client(auth_client, config);
100 let client = serve_client(mcp_client, transport)
101 .await
102 .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
103 Ok(Self::from_parts(client, None))
104 }
105
106 pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
107 let response = self
108 .client
109 .list_tools(None)
110 .await
111 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
112 Ok(response.tools)
113 }
114
115 fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
116 let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
117 Self { client: Arc::new(client), server_task, instructions }
118 }
119}
120
121pub(super) async fn connect_server(server: McpServer, ctx: &ConnectContext<'_>) -> McpConnectAttempt {
122 let McpServer { name, transport, proxy: proxied } = server;
123 let reauth_config = reauth_config_for(&transport, ctx.oauth_handler_factory);
124 let mcp_client =
125 McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(ctx.roots));
126
127 let outcome = match transport {
128 McpTransport::Stdio { command, args, env } => connect_stdio(&name, command, args, env, mcp_client).await,
129 McpTransport::InMemory { server } => connect_in_memory(&name, server, mcp_client).await,
130 McpTransport::Http { config } => {
131 connect_http(&name, config, mcp_client, ctx.oauth_handler_factory, ctx.oauth_credential_store).await
132 }
133 };
134
135 McpConnectAttempt { name, proxied, outcome: outcome.with_reauth(reauth_config) }
136}
137
138#[allow(clippy::too_many_arguments)]
139pub async fn authenticate_http(
140 name: String,
141 config: StreamableHttpClientTransportConfig,
142 client_info: ClientInfo,
143 event_sender: mpsc::Sender<McpClientEvent>,
144 roots: Arc<RwLock<Vec<Root>>>,
145 oauth_handler_factory: OAuthHandlerFactory,
146 oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
147 proxied: bool,
148) -> McpConnectAttempt {
149 let outcome = match async {
150 let handler =
151 oauth_handler_factory(OAuthHandlerContext { server_name: name.clone(), tx: event_sender.clone() })?;
152
153 let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref(), oauth_credential_store)
154 .await
155 .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
156
157 let mcp_client = McpClient::new(client_info, name.clone(), event_sender, roots);
158 McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await
159 }
160 .await
161 {
162 Ok(conn) => McpConnectOutcome::Connected { conn, reauth_config: Some(config) },
163 Err(error) => McpConnectOutcome::Failed { error },
164 };
165
166 McpConnectAttempt { name, proxied, outcome }
167}
168
169impl McpConnectOutcome {
170 fn with_reauth(self, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
171 match self {
172 Self::Connected { conn, .. } => Self::Connected { conn, reauth_config },
173 other => other,
174 }
175 }
176}
177
178async fn connect_stdio(
179 server_name: &str,
180 command: String,
181 args: Vec<String>,
182 env: HashMap<String, String>,
183 mcp_client: McpClient,
184) -> McpConnectOutcome {
185 let mut cmd = Command::new(&command);
186 cmd.args(&args).envs(&env);
187
188 let (proc, stderr) = match TokioChildProcess::builder(cmd).stderr(Stdio::piped()).spawn() {
189 Ok(parts) => parts,
190 Err(e) => return McpConnectOutcome::Failed { error: McpError::SpawnFailed { command, reason: e.to_string() } },
191 };
192 let stderr_task = stderr.map(|stderr| spawn_stderr_logger(server_name.to_string(), stderr));
193
194 match serve_client(mcp_client, proc).await {
195 Ok(client) => McpConnectOutcome::Connected {
196 conn: McpServerConnection::from_parts(client, stderr_task),
197 reauth_config: None,
198 },
199 Err(e) => {
200 if let Some(task) = stderr_task {
201 task.abort();
202 }
203 McpConnectOutcome::Failed { error: McpError::from(e) }
204 }
205 }
206}
207
208fn spawn_stderr_logger(server_name: String, stderr: ChildStderr) -> JoinHandle<()> {
209 tokio::spawn(async move {
210 let mut lines = BufReader::new(stderr).lines();
211 loop {
212 match lines.next_line().await {
213 Ok(Some(line)) => tracing::info!(server = %server_name, stderr = %line, "MCP server stderr"),
214 Ok(None) => break,
215 Err(error) => {
216 tracing::warn!(server = %server_name, %error, "failed to read MCP server stderr");
217 break;
218 }
219 }
220 }
221 })
222}
223
224async fn connect_in_memory(
225 name: &str,
226 server: Box<dyn DynService<RoleServer>>,
227 mcp_client: McpClient,
228) -> McpConnectOutcome {
229 match serve_in_memory(server, mcp_client, name).await {
230 Ok((client, handle)) => McpConnectOutcome::Connected {
231 conn: McpServerConnection::from_parts(client, Some(handle)),
232 reauth_config: None,
233 },
234 Err(error) => McpConnectOutcome::Failed { error },
235 }
236}
237
238async fn connect_http(
239 name: &str,
240 config: StreamableHttpClientTransportConfig,
241 mcp_client: McpClient,
242 oauth_handler_factory: Option<&OAuthHandlerFactory>,
243 oauth_credential_store: Option<&Arc<dyn OAuthCredentialStorage>>,
244) -> McpConnectOutcome {
245 let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
246 let stored_auth_manager = if let Some(store) = oauth_credential_store
247 && config.auth_header.is_none()
248 {
249 match create_auth_manager_from_store(name, &config.uri, Arc::clone(store)).await {
250 Ok(manager) => manager,
251 Err(e) => {
252 tracing::warn!(
253 server = %name,
254 error = %e,
255 "Failed to initialize auth manager from stored credentials, proceeding without auth"
256 );
257 None
258 }
259 }
260 } else {
261 None
262 };
263 let result = if let Some(auth_manager) = stored_auth_manager {
264 tracing::debug!("Using OAuth for server '{name}'");
265 let auth_client = AuthClient::new(reqwest::Client::default(), auth_manager);
266 let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
267 serve_client(mcp_client, transport).await.map_err(conn_err)
268 } else {
269 let transport = StreamableHttpClientTransport::from_config(config.clone());
270 serve_client(mcp_client, transport).await.map_err(conn_err)
271 };
272
273 match result {
274 Ok(client) => {
275 McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
276 }
277 Err(error) => {
278 tracing::warn!("Failed to connect to MCP server '{name}': {error}");
279 if oauth_handler_factory.is_some() && config.auth_header.is_none() {
280 McpConnectOutcome::NeedsOAuth { config, error }
281 } else {
282 McpConnectOutcome::Failed { error }
283 }
284 }
285 }
286}
287
288fn reauth_config_for(
289 transport: &McpTransport,
290 oauth_handler_factory: Option<&OAuthHandlerFactory>,
291) -> Option<StreamableHttpClientTransportConfig> {
292 match transport {
293 McpTransport::Http { config } if oauth_handler_factory.is_some() && config.auth_header.is_none() => {
294 Some(config.clone())
295 }
296 _ => None,
297 }
298}
299
300async fn serve_in_memory(
301 server: Box<dyn DynService<RoleServer>>,
302 mcp_client: McpClient,
303 label: &str,
304) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
305 let (client_transport, server_transport) = create_in_memory_transport();
306
307 let server_handle = tokio::spawn(async move {
308 match server.serve(server_transport).await {
309 Ok(_service) => {
310 std::future::pending::<()>().await;
311 }
312 Err(e) => {
313 eprintln!("MCP server error: {e}");
314 }
315 }
316 });
317
318 let client = serve_client(mcp_client, client_transport)
319 .await
320 .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
321
322 Ok((client, server_handle))
323}