// mcp_utils/client/connection.rs

1use super::{
2    McpClientEvent, McpError, OAuthHandlerFactory, Result,
3    config::{McpServer, McpTransport},
4    mcp_client::McpClient,
5    roots::primary_root_path,
6};
7use crate::{client::OAuthHandlerContext, transport::create_in_memory_transport};
8use aether_auth::{OAuthCredentialStorage, create_auth_manager_from_store, perform_oauth_flow};
9use llm::ToolAnnotations;
10use rmcp::{
11    RoleClient, RoleServer, ServiceExt,
12    model::{ClientInfo, Root, Tool as RmcpTool},
13    serve_client,
14    service::{DynService, RunningService},
15    transport::{
16        StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
17        streamable_http_client::StreamableHttpClientTransportConfig,
18    },
19};
20use serde_json::Value;
21use std::collections::HashMap;
22use std::path::PathBuf;
23use std::process::Stdio;
24use std::sync::Arc;
25use tokio::{
26    io::{AsyncBufReadExt, BufReader},
27    process::{ChildStderr, Command},
28    sync::{RwLock, mpsc},
29    task::JoinHandle,
30};
31
32#[derive(Debug, Clone)]
33pub struct Tool {
34    pub name: String,
35    pub description: String,
36    pub parameters: Value,
37    pub annotations: Option<ToolAnnotations>,
38}
39
40pub(crate) fn convert_tool_annotations(annotations: &rmcp::model::ToolAnnotations) -> ToolAnnotations {
41    ToolAnnotations {
42        title: annotations.title.clone(),
43        read_only_hint: annotations.read_only_hint,
44        destructive_hint: annotations.destructive_hint,
45        idempotent_hint: annotations.idempotent_hint,
46        open_world_hint: annotations.open_world_hint,
47    }
48}
49
50impl From<RmcpTool> for Tool {
51    fn from(tool: RmcpTool) -> Self {
52        Self::from(&tool)
53    }
54}
55
56impl From<&RmcpTool> for Tool {
57    fn from(tool: &RmcpTool) -> Self {
58        Self {
59            name: tool.name.to_string(),
60            description: tool.description.clone().unwrap_or_default().to_string(),
61            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
62            annotations: tool.annotations.as_ref().map(convert_tool_annotations),
63        }
64    }
65}
66
67pub(super) struct ConnectConfig {
68    pub client_info: ClientInfo,
69    pub event_sender: mpsc::Sender<McpClientEvent>,
70    pub roots: Arc<RwLock<Vec<Root>>>,
71    pub oauth_handler_factory: Option<OAuthHandlerFactory>,
72    pub oauth_credential_store: Option<Arc<dyn OAuthCredentialStorage>>,
73}
74
75/// The result of attempting to connect (or authenticate) to an MCP server.
76pub struct McpConnectAttempt {
77    pub name: String,
78    pub proxied: bool,
79    pub outcome: McpConnectOutcome,
80}
81
82pub enum McpConnectOutcome {
83    Connected { conn: McpServerConnection, reauth_config: Option<StreamableHttpClientTransportConfig> },
84    NeedsOAuth { config: StreamableHttpClientTransportConfig, error: McpError },
85    Failed { error: McpError },
86}
87
88impl McpConnectAttempt {
89    pub fn failed(name: impl Into<String>, error: McpError, proxied: bool) -> Self {
90        Self { name: name.into(), proxied, outcome: McpConnectOutcome::Failed { error } }
91    }
92}
93
94pub struct McpServerConnection {
95    pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
96    pub(super) server_task: Option<JoinHandle<()>>,
97    pub(super) instructions: Option<String>,
98}
99
100impl McpServerConnection {
101    pub(super) async fn reconnect_with_auth(
102        name: &str,
103        config: StreamableHttpClientTransportConfig,
104        auth_client: AuthClient<reqwest::Client>,
105        mcp_client: McpClient,
106    ) -> Result<Self> {
107        let transport = StreamableHttpClientTransport::with_client(auth_client, config);
108        let client = serve_client(mcp_client, transport)
109            .await
110            .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
111        Ok(Self::from_parts(client, None))
112    }
113
114    pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
115        let response = self
116            .client
117            .list_tools(None)
118            .await
119            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
120        Ok(response.tools)
121    }
122
123    fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
124        let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
125        Self { client: Arc::new(client), server_task, instructions }
126    }
127}
128
129pub(super) async fn connect_server(server: McpServer, ctx: &ConnectConfig) -> McpConnectAttempt {
130    let McpServer { name, transport, proxy: proxied } = server;
131    let reauth_config = reauth_config_for(&transport, ctx.oauth_handler_factory.as_ref());
132    let mcp_client =
133        McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(&ctx.roots));
134
135    let outcome = match transport {
136        McpTransport::Stdio { command, args, env } => {
137            let roots = ctx.roots.read().await;
138            let cwd = primary_root_path(&roots);
139            connect_stdio(&name, command, args, env, mcp_client, cwd).await
140        }
141        McpTransport::InMemory { server } => connect_in_memory(&name, server, mcp_client).await,
142        McpTransport::Http { config } => {
143            connect_http(
144                &name,
145                config,
146                mcp_client,
147                ctx.oauth_handler_factory.as_ref(),
148                ctx.oauth_credential_store.as_ref(),
149            )
150            .await
151        }
152    };
153
154    McpConnectAttempt { name, proxied, outcome: outcome.with_reauth(reauth_config) }
155}
156
157pub async fn authenticate_http(
158    name: String,
159    config: StreamableHttpClientTransportConfig,
160    ctx: Arc<ConnectConfig>,
161    proxied: bool,
162) -> McpConnectAttempt {
163    let outcome = match async {
164        let factory = ctx
165            .oauth_handler_factory
166            .as_ref()
167            .ok_or_else(|| McpError::ConnectionFailed(format!("No OAuth handler factory available for '{name}'")))?;
168        let handler = factory(OAuthHandlerContext { server_name: name.clone(), tx: ctx.event_sender.clone() })?;
169
170        let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref(), ctx.oauth_credential_store.clone())
171            .await
172            .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
173
174        let mcp_client =
175            McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(&ctx.roots));
176        McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await
177    }
178    .await
179    {
180        Ok(conn) => McpConnectOutcome::Connected { conn, reauth_config: Some(config) },
181        Err(error) => McpConnectOutcome::Failed { error },
182    };
183
184    McpConnectAttempt { name, proxied, outcome }
185}
186
187impl McpConnectOutcome {
188    fn with_reauth(self, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
189        match self {
190            Self::Connected { conn, .. } => Self::Connected { conn, reauth_config },
191            other => other,
192        }
193    }
194}
195
196async fn connect_stdio(
197    server_name: &str,
198    command: String,
199    args: Vec<String>,
200    env: HashMap<String, String>,
201    mcp_client: McpClient,
202    cwd: Option<PathBuf>,
203) -> McpConnectOutcome {
204    let mut cmd = Command::new(&command);
205    cmd.args(&args).envs(&env);
206
207    if let Some(ref dir) = cwd {
208        cmd.current_dir(dir);
209    }
210
211    let (proc, stderr) = match TokioChildProcess::builder(cmd).stderr(Stdio::piped()).spawn() {
212        Ok(parts) => parts,
213        Err(e) => return McpConnectOutcome::Failed { error: McpError::SpawnFailed { command, reason: e.to_string() } },
214    };
215    let stderr_task = stderr.map(|stderr| spawn_stderr_logger(server_name.to_string(), stderr));
216
217    match serve_client(mcp_client, proc).await {
218        Ok(client) => McpConnectOutcome::Connected {
219            conn: McpServerConnection::from_parts(client, stderr_task),
220            reauth_config: None,
221        },
222        Err(e) => {
223            if let Some(task) = stderr_task {
224                task.abort();
225            }
226            McpConnectOutcome::Failed { error: McpError::from(e) }
227        }
228    }
229}
230
231fn spawn_stderr_logger(server_name: String, stderr: ChildStderr) -> JoinHandle<()> {
232    tokio::spawn(async move {
233        let mut lines = BufReader::new(stderr).lines();
234        loop {
235            match lines.next_line().await {
236                Ok(Some(line)) => tracing::info!(server = %server_name, stderr = %line, "MCP server stderr"),
237                Ok(None) => break,
238                Err(error) => {
239                    tracing::warn!(server = %server_name, %error, "failed to read MCP server stderr");
240                    break;
241                }
242            }
243        }
244    })
245}
246
247async fn connect_in_memory(
248    name: &str,
249    server: Box<dyn DynService<RoleServer>>,
250    mcp_client: McpClient,
251) -> McpConnectOutcome {
252    match serve_in_memory(server, mcp_client, name).await {
253        Ok((client, handle)) => McpConnectOutcome::Connected {
254            conn: McpServerConnection::from_parts(client, Some(handle)),
255            reauth_config: None,
256        },
257        Err(error) => McpConnectOutcome::Failed { error },
258    }
259}
260
261async fn connect_http(
262    name: &str,
263    config: StreamableHttpClientTransportConfig,
264    mcp_client: McpClient,
265    oauth_handler_factory: Option<&OAuthHandlerFactory>,
266    oauth_credential_store: Option<&Arc<dyn OAuthCredentialStorage>>,
267) -> McpConnectOutcome {
268    let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
269    let stored_auth_manager = if let Some(store) = oauth_credential_store
270        && config.auth_header.is_none()
271    {
272        match create_auth_manager_from_store(name, &config.uri, Arc::clone(store)).await {
273            Ok(manager) => manager,
274            Err(e) => {
275                tracing::warn!(
276                    server = %name,
277                    error = %e,
278                    "Failed to initialize auth manager from stored credentials, proceeding without auth"
279                );
280                None
281            }
282        }
283    } else {
284        None
285    };
286    let result = if let Some(auth_manager) = stored_auth_manager {
287        tracing::debug!("Using OAuth for server '{name}'");
288        let auth_client = AuthClient::new(reqwest::Client::default(), auth_manager);
289        let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
290        serve_client(mcp_client, transport).await.map_err(conn_err)
291    } else {
292        let transport = StreamableHttpClientTransport::from_config(config.clone());
293        serve_client(mcp_client, transport).await.map_err(conn_err)
294    };
295
296    match result {
297        Ok(client) => {
298            McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
299        }
300        Err(error) => {
301            tracing::warn!("Failed to connect to MCP server '{name}': {error}");
302            if oauth_handler_factory.is_some() && config.auth_header.is_none() {
303                McpConnectOutcome::NeedsOAuth { config, error }
304            } else {
305                McpConnectOutcome::Failed { error }
306            }
307        }
308    }
309}
310
311fn reauth_config_for(
312    transport: &McpTransport,
313    oauth_handler_factory: Option<&OAuthHandlerFactory>,
314) -> Option<StreamableHttpClientTransportConfig> {
315    match transport {
316        McpTransport::Http { config } if oauth_handler_factory.is_some() && config.auth_header.is_none() => {
317            Some(config.clone())
318        }
319        _ => None,
320    }
321}
322
323async fn serve_in_memory(
324    server: Box<dyn DynService<RoleServer>>,
325    mcp_client: McpClient,
326    label: &str,
327) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
328    let (client_transport, server_transport) = create_in_memory_transport();
329
330    let server_handle = tokio::spawn(async move {
331        match server.serve(server_transport).await {
332            Ok(_service) => {
333                std::future::pending::<()>().await;
334            }
335            Err(e) => {
336                eprintln!("MCP server error: {e}");
337            }
338        }
339    });
340
341    let client = serve_client(mcp_client, client_transport)
342        .await
343        .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
344
345    Ok((client, server_handle))
346}