Skip to main content

mcp_utils/client/
connection.rs

1use super::{
2    McpClientEvent, McpError, OAuthHandlerFactory, Result,
3    config::{McpServer, McpTransport},
4    mcp_client::McpClient,
5};
6use crate::transport::create_in_memory_transport;
7use aether_auth::{OAuthCredentialStorage, create_auth_manager_from_store, perform_oauth_flow};
8use rmcp::{
9    RoleClient, RoleServer, ServiceExt,
10    model::{ClientInfo, Root, Tool as RmcpTool},
11    serve_client,
12    service::{DynService, RunningService},
13    transport::{
14        StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
15        streamable_http_client::StreamableHttpClientTransportConfig,
16    },
17};
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::{
22    process::Command,
23    sync::{RwLock, mpsc},
24    task::JoinHandle,
25};
26
/// Instruction text associated with a named MCP server.
#[derive(Clone, Debug)]
pub struct ServerInstructions {
    /// Name of the server the instructions belong to.
    pub server_name: String,
    /// The instruction text itself.
    pub instructions: String,
}
32
33#[derive(Debug, Clone)]
34pub struct Tool {
35    pub description: String,
36    pub parameters: Value,
37}
38
39impl From<RmcpTool> for Tool {
40    fn from(tool: RmcpTool) -> Self {
41        Self {
42            description: tool.description.unwrap_or_default().to_string(),
43            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
44        }
45    }
46}
47
48impl From<&RmcpTool> for Tool {
49    fn from(tool: &RmcpTool) -> Self {
50        Self {
51            description: tool.description.clone().unwrap_or_default().to_string(),
52            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
53        }
54    }
55}
56
57pub(super) struct ConnectContext<'a> {
58    pub client_info: &'a ClientInfo,
59    pub event_sender: &'a mpsc::Sender<McpClientEvent>,
60    pub roots: &'a Arc<RwLock<Vec<Root>>>,
61    pub oauth_handler_factory: Option<&'a OAuthHandlerFactory>,
62    pub oauth_credential_store: Option<&'a Arc<dyn OAuthCredentialStorage>>,
63}
64
65/// The result of attempting to connect (or authenticate) to an MCP server.
66pub struct McpConnectAttempt {
67    pub name: String,
68    pub proxied: bool,
69    pub outcome: McpConnectOutcome,
70}
71
72pub enum McpConnectOutcome {
73    Connected { conn: McpServerConnection, reauth_config: Option<StreamableHttpClientTransportConfig> },
74    NeedsOAuth { config: StreamableHttpClientTransportConfig, error: McpError },
75    Failed { error: McpError },
76}
77
78impl McpConnectAttempt {
79    pub fn failed(name: impl Into<String>, error: McpError, proxied: bool) -> Self {
80        Self { name: name.into(), proxied, outcome: McpConnectOutcome::Failed { error } }
81    }
82}
83
84pub struct McpServerConnection {
85    pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
86    pub(super) server_task: Option<JoinHandle<()>>,
87    pub(super) instructions: Option<String>,
88}
89
90impl McpServerConnection {
91    pub(super) async fn reconnect_with_auth(
92        name: &str,
93        config: StreamableHttpClientTransportConfig,
94        auth_client: AuthClient<reqwest::Client>,
95        mcp_client: McpClient,
96    ) -> Result<Self> {
97        let transport = StreamableHttpClientTransport::with_client(auth_client, config);
98        let client = serve_client(mcp_client, transport)
99            .await
100            .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
101        Ok(Self::from_parts(client, None))
102    }
103
104    pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
105        let response = self
106            .client
107            .list_tools(None)
108            .await
109            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
110        Ok(response.tools)
111    }
112
113    fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
114        let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
115        Self { client: Arc::new(client), server_task, instructions }
116    }
117}
118
119pub(super) async fn connect_server(server: McpServer, ctx: &ConnectContext<'_>) -> McpConnectAttempt {
120    let McpServer { name, transport, proxy: proxied } = server;
121    let reauth_config = reauth_config_for(&transport, ctx.oauth_handler_factory);
122    let mcp_client =
123        McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(ctx.roots));
124
125    let outcome = match transport {
126        McpTransport::Stdio { command, args, env } => connect_stdio(command, args, env, mcp_client).await,
127        McpTransport::InMemory { server } => connect_in_memory(&name, server, mcp_client).await,
128        McpTransport::Http { config } => {
129            connect_http(&name, config, mcp_client, ctx.oauth_handler_factory, ctx.oauth_credential_store).await
130        }
131    };
132
133    McpConnectAttempt { name, proxied, outcome: outcome.with_reauth(reauth_config) }
134}
135
136#[allow(clippy::too_many_arguments)]
137pub async fn authenticate_http(
138    name: String,
139    config: StreamableHttpClientTransportConfig,
140    client_info: ClientInfo,
141    event_sender: mpsc::Sender<McpClientEvent>,
142    roots: Arc<RwLock<Vec<Root>>>,
143    oauth_handler_factory: OAuthHandlerFactory,
144    oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
145    proxied: bool,
146) -> McpConnectAttempt {
147    let outcome = match async {
148        let handler = oauth_handler_factory()?;
149        let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref(), oauth_credential_store)
150            .await
151            .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
152
153        let mcp_client = McpClient::new(client_info, name.clone(), event_sender, roots);
154        McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await
155    }
156    .await
157    {
158        Ok(conn) => McpConnectOutcome::Connected { conn, reauth_config: Some(config) },
159        Err(error) => McpConnectOutcome::Failed { error },
160    };
161
162    McpConnectAttempt { name, proxied, outcome }
163}
164
165impl McpConnectOutcome {
166    fn with_reauth(self, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
167        match self {
168            Self::Connected { conn, .. } => Self::Connected { conn, reauth_config },
169            other => other,
170        }
171    }
172}
173
174async fn connect_stdio(
175    command: String,
176    args: Vec<String>,
177    env: HashMap<String, String>,
178    mcp_client: McpClient,
179) -> McpConnectOutcome {
180    let cmd = {
181        let mut cmd = Command::new(&command);
182        cmd.args(&args);
183        cmd.envs(&env);
184        cmd
185    };
186
187    let child = match TokioChildProcess::new(cmd) {
188        Ok(child) => child,
189        Err(e) => return McpConnectOutcome::Failed { error: McpError::SpawnFailed { command, reason: e.to_string() } },
190    };
191
192    match mcp_client.serve(child).await {
193        Ok(client) => {
194            McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
195        }
196        Err(e) => McpConnectOutcome::Failed { error: McpError::from(e) },
197    }
198}
199
200async fn connect_in_memory(
201    name: &str,
202    server: Box<dyn DynService<RoleServer>>,
203    mcp_client: McpClient,
204) -> McpConnectOutcome {
205    match serve_in_memory(server, mcp_client, name).await {
206        Ok((client, handle)) => McpConnectOutcome::Connected {
207            conn: McpServerConnection::from_parts(client, Some(handle)),
208            reauth_config: None,
209        },
210        Err(error) => McpConnectOutcome::Failed { error },
211    }
212}
213
214async fn connect_http(
215    name: &str,
216    config: StreamableHttpClientTransportConfig,
217    mcp_client: McpClient,
218    oauth_handler_factory: Option<&OAuthHandlerFactory>,
219    oauth_credential_store: Option<&Arc<dyn OAuthCredentialStorage>>,
220) -> McpConnectOutcome {
221    let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
222    let stored_auth_manager = if let Some(store) = oauth_credential_store
223        && config.auth_header.is_none()
224    {
225        create_auth_manager_from_store(name, &config.uri, Arc::clone(store)).await.ok().flatten()
226    } else {
227        None
228    };
229    let result = if let Some(auth_manager) = stored_auth_manager {
230        tracing::debug!("Using OAuth for server '{name}'");
231        let auth_client = AuthClient::new(reqwest::Client::default(), auth_manager);
232        let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
233        serve_client(mcp_client, transport).await.map_err(conn_err)
234    } else {
235        let transport = StreamableHttpClientTransport::from_config(config.clone());
236        serve_client(mcp_client, transport).await.map_err(conn_err)
237    };
238
239    match result {
240        Ok(client) => {
241            McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
242        }
243        Err(error) => {
244            tracing::warn!("Failed to connect to MCP server '{name}': {error}");
245            if oauth_handler_factory.is_some() && config.auth_header.is_none() {
246                McpConnectOutcome::NeedsOAuth { config, error }
247            } else {
248                McpConnectOutcome::Failed { error }
249            }
250        }
251    }
252}
253
254fn reauth_config_for(
255    transport: &McpTransport,
256    oauth_handler_factory: Option<&OAuthHandlerFactory>,
257) -> Option<StreamableHttpClientTransportConfig> {
258    match transport {
259        McpTransport::Http { config } if oauth_handler_factory.is_some() && config.auth_header.is_none() => {
260            Some(config.clone())
261        }
262        _ => None,
263    }
264}
265
266async fn serve_in_memory(
267    server: Box<dyn DynService<RoleServer>>,
268    mcp_client: McpClient,
269    label: &str,
270) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
271    let (client_transport, server_transport) = create_in_memory_transport();
272
273    let server_handle = tokio::spawn(async move {
274        match server.serve(server_transport).await {
275            Ok(_service) => {
276                std::future::pending::<()>().await;
277            }
278            Err(e) => {
279                eprintln!("MCP server error: {e}");
280            }
281        }
282    });
283
284    let client = serve_client(mcp_client, client_transport)
285        .await
286        .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
287
288    Ok((client, server_handle))
289}