Skip to main content

mcp_utils/client/
connection.rs

1use super::{
2    McpClientEvent, McpError, OAuthHandlerFactory, Result,
3    config::{McpServer, McpTransport},
4    mcp_client::McpClient,
5};
6use crate::{client::OAuthHandlerContext, transport::create_in_memory_transport};
7use aether_auth::{OAuthCredentialStorage, create_auth_manager_from_store, perform_oauth_flow};
8use rmcp::{
9    RoleClient, RoleServer, ServiceExt,
10    model::{ClientInfo, Root, Tool as RmcpTool},
11    serve_client,
12    service::{DynService, RunningService},
13    transport::{
14        StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
15        streamable_http_client::StreamableHttpClientTransportConfig,
16    },
17};
18use serde_json::Value;
19use std::collections::HashMap;
20use std::process::Stdio;
21use std::sync::Arc;
22use tokio::{
23    io::{AsyncBufReadExt, BufReader},
24    process::{ChildStderr, Command},
25    sync::{RwLock, mpsc},
26    task::JoinHandle,
27};
28
29#[derive(Debug, Clone)]
30pub struct Tool {
31    pub name: String,
32    pub description: String,
33    pub parameters: Value,
34}
35
36impl From<RmcpTool> for Tool {
37    fn from(tool: RmcpTool) -> Self {
38        Self::from(&tool)
39    }
40}
41
42impl From<&RmcpTool> for Tool {
43    fn from(tool: &RmcpTool) -> Self {
44        Self {
45            name: tool.name.to_string(),
46            description: tool.description.clone().unwrap_or_default().to_string(),
47            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
48        }
49    }
50}
51
52pub(super) struct ConnectConfig {
53    pub client_info: ClientInfo,
54    pub event_sender: mpsc::Sender<McpClientEvent>,
55    pub roots: Arc<RwLock<Vec<Root>>>,
56    pub oauth_handler_factory: Option<OAuthHandlerFactory>,
57    pub oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
58}
59
60/// The result of attempting to connect (or authenticate) to an MCP server.
61pub struct McpConnectAttempt {
62    pub name: String,
63    pub proxied: bool,
64    pub outcome: McpConnectOutcome,
65}
66
67pub enum McpConnectOutcome {
68    Connected { conn: McpServerConnection, reauth_config: Option<StreamableHttpClientTransportConfig> },
69    NeedsOAuth { config: StreamableHttpClientTransportConfig, error: McpError },
70    Failed { error: McpError },
71}
72
73impl McpConnectAttempt {
74    pub fn failed(name: impl Into<String>, error: McpError, proxied: bool) -> Self {
75        Self { name: name.into(), proxied, outcome: McpConnectOutcome::Failed { error } }
76    }
77}
78
79pub struct McpServerConnection {
80    pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
81    pub(super) server_task: Option<JoinHandle<()>>,
82    pub(super) instructions: Option<String>,
83}
84
85impl McpServerConnection {
86    pub(super) async fn reconnect_with_auth(
87        name: &str,
88        config: StreamableHttpClientTransportConfig,
89        auth_client: AuthClient<reqwest::Client>,
90        mcp_client: McpClient,
91    ) -> Result<Self> {
92        let transport = StreamableHttpClientTransport::with_client(auth_client, config);
93        let client = serve_client(mcp_client, transport)
94            .await
95            .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
96        Ok(Self::from_parts(client, None))
97    }
98
99    pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
100        let response = self
101            .client
102            .list_tools(None)
103            .await
104            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
105        Ok(response.tools)
106    }
107
108    fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
109        let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
110        Self { client: Arc::new(client), server_task, instructions }
111    }
112}
113
114pub(super) async fn connect_server(server: McpServer, ctx: &ConnectConfig) -> McpConnectAttempt {
115    let McpServer { name, transport, proxy: proxied } = server;
116    let reauth_config = reauth_config_for(&transport, ctx.oauth_handler_factory.as_ref());
117    let mcp_client =
118        McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(&ctx.roots));
119
120    let outcome = match transport {
121        McpTransport::Stdio { command, args, env } => connect_stdio(&name, command, args, env, mcp_client).await,
122        McpTransport::InMemory { server } => connect_in_memory(&name, server, mcp_client).await,
123        McpTransport::Http { config } => {
124            connect_http(
125                &name,
126                config,
127                mcp_client,
128                ctx.oauth_handler_factory.as_ref(),
129                ctx.oauth_credential_store.as_ref(),
130            )
131            .await
132        }
133    };
134
135    McpConnectAttempt { name, proxied, outcome: outcome.with_reauth(reauth_config) }
136}
137
138pub async fn authenticate_http(
139    name: String,
140    config: StreamableHttpClientTransportConfig,
141    ctx: Arc<ConnectConfig>,
142    proxied: bool,
143) -> McpConnectAttempt {
144    let outcome = match async {
145        let factory = ctx
146            .oauth_handler_factory
147            .as_ref()
148            .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")))?;
149        let handler = factory(OAuthHandlerContext { server_name: name.clone(), tx: ctx.event_sender.clone() })?;
150
151        let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref(), ctx.oauth_credential_store.clone())
152            .await
153            .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
154
155        let mcp_client =
156            McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(&ctx.roots));
157        McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await
158    }
159    .await
160    {
161        Ok(conn) => McpConnectOutcome::Connected { conn, reauth_config: Some(config) },
162        Err(error) => McpConnectOutcome::Failed { error },
163    };
164
165    McpConnectAttempt { name, proxied, outcome }
166}
167
168impl McpConnectOutcome {
169    fn with_reauth(self, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
170        match self {
171            Self::Connected { conn, .. } => Self::Connected { conn, reauth_config },
172            other => other,
173        }
174    }
175}
176
177async fn connect_stdio(
178    server_name: &str,
179    command: String,
180    args: Vec<String>,
181    env: HashMap<String, String>,
182    mcp_client: McpClient,
183) -> McpConnectOutcome {
184    let mut cmd = Command::new(&command);
185    cmd.args(&args).envs(&env);
186
187    let (proc, stderr) = match TokioChildProcess::builder(cmd).stderr(Stdio::piped()).spawn() {
188        Ok(parts) => parts,
189        Err(e) => return McpConnectOutcome::Failed { error: McpError::SpawnFailed { command, reason: e.to_string() } },
190    };
191    let stderr_task = stderr.map(|stderr| spawn_stderr_logger(server_name.to_string(), stderr));
192
193    match serve_client(mcp_client, proc).await {
194        Ok(client) => McpConnectOutcome::Connected {
195            conn: McpServerConnection::from_parts(client, stderr_task),
196            reauth_config: None,
197        },
198        Err(e) => {
199            if let Some(task) = stderr_task {
200                task.abort();
201            }
202            McpConnectOutcome::Failed { error: McpError::from(e) }
203        }
204    }
205}
206
207fn spawn_stderr_logger(server_name: String, stderr: ChildStderr) -> JoinHandle<()> {
208    tokio::spawn(async move {
209        let mut lines = BufReader::new(stderr).lines();
210        loop {
211            match lines.next_line().await {
212                Ok(Some(line)) => tracing::info!(server = %server_name, stderr = %line, "MCP server stderr"),
213                Ok(None) => break,
214                Err(error) => {
215                    tracing::warn!(server = %server_name, %error, "failed to read MCP server stderr");
216                    break;
217                }
218            }
219        }
220    })
221}
222
223async fn connect_in_memory(
224    name: &str,
225    server: Box<dyn DynService<RoleServer>>,
226    mcp_client: McpClient,
227) -> McpConnectOutcome {
228    match serve_in_memory(server, mcp_client, name).await {
229        Ok((client, handle)) => McpConnectOutcome::Connected {
230            conn: McpServerConnection::from_parts(client, Some(handle)),
231            reauth_config: None,
232        },
233        Err(error) => McpConnectOutcome::Failed { error },
234    }
235}
236
237async fn connect_http(
238    name: &str,
239    config: StreamableHttpClientTransportConfig,
240    mcp_client: McpClient,
241    oauth_handler_factory: Option<&OAuthHandlerFactory>,
242    oauth_credential_store: Option<&Arc<dyn OAuthCredentialStorage>>,
243) -> McpConnectOutcome {
244    let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
245    let stored_auth_manager = if let Some(store) = oauth_credential_store
246        && config.auth_header.is_none()
247    {
248        match create_auth_manager_from_store(name, &config.uri, Arc::clone(store)).await {
249            Ok(manager) => manager,
250            Err(e) => {
251                tracing::warn!(
252                    server = %name,
253                    error = %e,
254                    "Failed to initialize auth manager from stored credentials, proceeding without auth"
255                );
256                None
257            }
258        }
259    } else {
260        None
261    };
262    let result = if let Some(auth_manager) = stored_auth_manager {
263        tracing::debug!("Using OAuth for server '{name}'");
264        let auth_client = AuthClient::new(reqwest::Client::default(), auth_manager);
265        let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
266        serve_client(mcp_client, transport).await.map_err(conn_err)
267    } else {
268        let transport = StreamableHttpClientTransport::from_config(config.clone());
269        serve_client(mcp_client, transport).await.map_err(conn_err)
270    };
271
272    match result {
273        Ok(client) => {
274            McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
275        }
276        Err(error) => {
277            tracing::warn!("Failed to connect to MCP server '{name}': {error}");
278            if oauth_handler_factory.is_some() && config.auth_header.is_none() {
279                McpConnectOutcome::NeedsOAuth { config, error }
280            } else {
281                McpConnectOutcome::Failed { error }
282            }
283        }
284    }
285}
286
287fn reauth_config_for(
288    transport: &McpTransport,
289    oauth_handler_factory: Option<&OAuthHandlerFactory>,
290) -> Option<StreamableHttpClientTransportConfig> {
291    match transport {
292        McpTransport::Http { config } if oauth_handler_factory.is_some() && config.auth_header.is_none() => {
293            Some(config.clone())
294        }
295        _ => None,
296    }
297}
298
299async fn serve_in_memory(
300    server: Box<dyn DynService<RoleServer>>,
301    mcp_client: McpClient,
302    label: &str,
303) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
304    let (client_transport, server_transport) = create_in_memory_transport();
305
306    let server_handle = tokio::spawn(async move {
307        match server.serve(server_transport).await {
308            Ok(_service) => {
309                std::future::pending::<()>().await;
310            }
311            Err(e) => {
312                eprintln!("MCP server error: {e}");
313            }
314        }
315    });
316
317    let client = serve_client(mcp_client, client_transport)
318        .await
319        .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
320
321    Ok((client, server_handle))
322}