Skip to main content

mcp_utils/client/
connection.rs

1use super::{
2    McpClientEvent, McpError, OAuthHandlerFactory, Result,
3    config::{McpServer, McpTransport},
4    mcp_client::McpClient,
5};
6use crate::transport::create_in_memory_transport;
7use aether_auth::{OAuthCredentialStorage, create_auth_manager_from_store, perform_oauth_flow};
8use rmcp::{
9    RoleClient, RoleServer, ServiceExt,
10    model::{ClientInfo, Root, Tool as RmcpTool},
11    serve_client,
12    service::{DynService, RunningService},
13    transport::{
14        StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
15        streamable_http_client::StreamableHttpClientTransportConfig,
16    },
17};
18use serde_json::Value;
19use std::collections::HashMap;
20use std::process::Stdio;
21use std::sync::Arc;
22use tokio::{
23    io::{AsyncBufReadExt, BufReader},
24    process::{ChildStderr, Command},
25    sync::{RwLock, mpsc},
26    task::JoinHandle,
27};
28
/// Pairs an MCP server's name with the instructions text it provided.
#[derive(Clone, Debug)]
pub struct ServerInstructions {
    /// Name of the server the instructions came from.
    pub server_name: String,
    /// The instructions text itself.
    pub instructions: String,
}
34
35#[derive(Debug, Clone)]
36pub struct Tool {
37    pub description: String,
38    pub parameters: Value,
39}
40
41impl From<RmcpTool> for Tool {
42    fn from(tool: RmcpTool) -> Self {
43        Self {
44            description: tool.description.unwrap_or_default().to_string(),
45            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
46        }
47    }
48}
49
50impl From<&RmcpTool> for Tool {
51    fn from(tool: &RmcpTool) -> Self {
52        Self {
53            description: tool.description.clone().unwrap_or_default().to_string(),
54            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
55        }
56    }
57}
58
59pub(super) struct ConnectContext<'a> {
60    pub client_info: &'a ClientInfo,
61    pub event_sender: &'a mpsc::Sender<McpClientEvent>,
62    pub roots: &'a Arc<RwLock<Vec<Root>>>,
63    pub oauth_handler_factory: Option<&'a OAuthHandlerFactory>,
64    pub oauth_credential_store: Option<&'a Arc<dyn OAuthCredentialStorage>>,
65}
66
67/// The result of attempting to connect (or authenticate) to an MCP server.
68pub struct McpConnectAttempt {
69    pub name: String,
70    pub proxied: bool,
71    pub outcome: McpConnectOutcome,
72}
73
74pub enum McpConnectOutcome {
75    Connected { conn: McpServerConnection, reauth_config: Option<StreamableHttpClientTransportConfig> },
76    NeedsOAuth { config: StreamableHttpClientTransportConfig, error: McpError },
77    Failed { error: McpError },
78}
79
80impl McpConnectAttempt {
81    pub fn failed(name: impl Into<String>, error: McpError, proxied: bool) -> Self {
82        Self { name: name.into(), proxied, outcome: McpConnectOutcome::Failed { error } }
83    }
84}
85
86pub struct McpServerConnection {
87    pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
88    pub(super) server_task: Option<JoinHandle<()>>,
89    pub(super) instructions: Option<String>,
90}
91
92impl McpServerConnection {
93    pub(super) async fn reconnect_with_auth(
94        name: &str,
95        config: StreamableHttpClientTransportConfig,
96        auth_client: AuthClient<reqwest::Client>,
97        mcp_client: McpClient,
98    ) -> Result<Self> {
99        let transport = StreamableHttpClientTransport::with_client(auth_client, config);
100        let client = serve_client(mcp_client, transport)
101            .await
102            .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
103        Ok(Self::from_parts(client, None))
104    }
105
106    pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
107        let response = self
108            .client
109            .list_tools(None)
110            .await
111            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
112        Ok(response.tools)
113    }
114
115    fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
116        let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
117        Self { client: Arc::new(client), server_task, instructions }
118    }
119}
120
121pub(super) async fn connect_server(server: McpServer, ctx: &ConnectContext<'_>) -> McpConnectAttempt {
122    let McpServer { name, transport, proxy: proxied } = server;
123    let reauth_config = reauth_config_for(&transport, ctx.oauth_handler_factory);
124    let mcp_client =
125        McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(ctx.roots));
126
127    let outcome = match transport {
128        McpTransport::Stdio { command, args, env } => connect_stdio(&name, command, args, env, mcp_client).await,
129        McpTransport::InMemory { server } => connect_in_memory(&name, server, mcp_client).await,
130        McpTransport::Http { config } => {
131            connect_http(&name, config, mcp_client, ctx.oauth_handler_factory, ctx.oauth_credential_store).await
132        }
133    };
134
135    McpConnectAttempt { name, proxied, outcome: outcome.with_reauth(reauth_config) }
136}
137
138#[allow(clippy::too_many_arguments)]
139pub async fn authenticate_http(
140    name: String,
141    config: StreamableHttpClientTransportConfig,
142    client_info: ClientInfo,
143    event_sender: mpsc::Sender<McpClientEvent>,
144    roots: Arc<RwLock<Vec<Root>>>,
145    oauth_handler_factory: OAuthHandlerFactory,
146    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
147    proxied: bool,
148) -> McpConnectAttempt {
149    let outcome = match async {
150        let handler = oauth_handler_factory()?;
151        let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref(), oauth_credential_store)
152            .await
153            .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
154
155        let mcp_client = McpClient::new(client_info, name.clone(), event_sender, roots);
156        McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await
157    }
158    .await
159    {
160        Ok(conn) => McpConnectOutcome::Connected { conn, reauth_config: Some(config) },
161        Err(error) => McpConnectOutcome::Failed { error },
162    };
163
164    McpConnectAttempt { name, proxied, outcome }
165}
166
167impl McpConnectOutcome {
168    fn with_reauth(self, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
169        match self {
170            Self::Connected { conn, .. } => Self::Connected { conn, reauth_config },
171            other => other,
172        }
173    }
174}
175
176async fn connect_stdio(
177    server_name: &str,
178    command: String,
179    args: Vec<String>,
180    env: HashMap<String, String>,
181    mcp_client: McpClient,
182) -> McpConnectOutcome {
183    let mut cmd = Command::new(&command);
184    cmd.args(&args).envs(&env);
185
186    let (proc, stderr) = match TokioChildProcess::builder(cmd).stderr(Stdio::piped()).spawn() {
187        Ok(parts) => parts,
188        Err(e) => return McpConnectOutcome::Failed { error: McpError::SpawnFailed { command, reason: e.to_string() } },
189    };
190    let stderr_task = stderr.map(|stderr| spawn_stderr_logger(server_name.to_string(), stderr));
191
192    match serve_client(mcp_client, proc).await {
193        Ok(client) => McpConnectOutcome::Connected {
194            conn: McpServerConnection::from_parts(client, stderr_task),
195            reauth_config: None,
196        },
197        Err(e) => {
198            if let Some(task) = stderr_task {
199                task.abort();
200            }
201            McpConnectOutcome::Failed { error: McpError::from(e) }
202        }
203    }
204}
205
206fn spawn_stderr_logger(server_name: String, stderr: ChildStderr) -> JoinHandle<()> {
207    tokio::spawn(async move {
208        let mut lines = BufReader::new(stderr).lines();
209        loop {
210            match lines.next_line().await {
211                Ok(Some(line)) => tracing::info!(server = %server_name, stderr = %line, "MCP server stderr"),
212                Ok(None) => break,
213                Err(error) => {
214                    tracing::warn!(server = %server_name, %error, "failed to read MCP server stderr");
215                    break;
216                }
217            }
218        }
219    })
220}
221
222async fn connect_in_memory(
223    name: &str,
224    server: Box<dyn DynService<RoleServer>>,
225    mcp_client: McpClient,
226) -> McpConnectOutcome {
227    match serve_in_memory(server, mcp_client, name).await {
228        Ok((client, handle)) => McpConnectOutcome::Connected {
229            conn: McpServerConnection::from_parts(client, Some(handle)),
230            reauth_config: None,
231        },
232        Err(error) => McpConnectOutcome::Failed { error },
233    }
234}
235
236async fn connect_http(
237    name: &str,
238    config: StreamableHttpClientTransportConfig,
239    mcp_client: McpClient,
240    oauth_handler_factory: Option<&OAuthHandlerFactory>,
241    oauth_credential_store: Option<&Arc<dyn OAuthCredentialStorage>>,
242) -> McpConnectOutcome {
243    let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
244    let stored_auth_manager = if let Some(store) = oauth_credential_store
245        && config.auth_header.is_none()
246    {
247        match create_auth_manager_from_store(name, &config.uri, Arc::clone(store)).await {
248            Ok(manager) => manager,
249            Err(e) => {
250                tracing::warn!(
251                    server = %name,
252                    error = %e,
253                    "Failed to initialize auth manager from stored credentials, proceeding without auth"
254                );
255                None
256            }
257        }
258    } else {
259        None
260    };
261    let result = if let Some(auth_manager) = stored_auth_manager {
262        tracing::debug!("Using OAuth for server '{name}'");
263        let auth_client = AuthClient::new(reqwest::Client::default(), auth_manager);
264        let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
265        serve_client(mcp_client, transport).await.map_err(conn_err)
266    } else {
267        let transport = StreamableHttpClientTransport::from_config(config.clone());
268        serve_client(mcp_client, transport).await.map_err(conn_err)
269    };
270
271    match result {
272        Ok(client) => {
273            McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
274        }
275        Err(error) => {
276            tracing::warn!("Failed to connect to MCP server '{name}': {error}");
277            if oauth_handler_factory.is_some() && config.auth_header.is_none() {
278                McpConnectOutcome::NeedsOAuth { config, error }
279            } else {
280                McpConnectOutcome::Failed { error }
281            }
282        }
283    }
284}
285
286fn reauth_config_for(
287    transport: &McpTransport,
288    oauth_handler_factory: Option<&OAuthHandlerFactory>,
289) -> Option<StreamableHttpClientTransportConfig> {
290    match transport {
291        McpTransport::Http { config } if oauth_handler_factory.is_some() && config.auth_header.is_none() => {
292            Some(config.clone())
293        }
294        _ => None,
295    }
296}
297
298async fn serve_in_memory(
299    server: Box<dyn DynService<RoleServer>>,
300    mcp_client: McpClient,
301    label: &str,
302) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
303    let (client_transport, server_transport) = create_in_memory_transport();
304
305    let server_handle = tokio::spawn(async move {
306        match server.serve(server_transport).await {
307            Ok(_service) => {
308                std::future::pending::<()>().await;
309            }
310            Err(e) => {
311                eprintln!("MCP server error: {e}");
312            }
313        }
314    });
315
316    let client = serve_client(mcp_client, client_transport)
317        .await
318        .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
319
320    Ok((client, server_handle))
321}