Skip to main content

mcp_utils/client/
connection.rs

1use super::{
2    McpClientEvent, McpError, OAuthHandlerFactory, Result,
3    config::{McpServer, McpTransport},
4    mcp_client::McpClient,
5    oauth::{create_auth_manager_from_store, perform_oauth_flow},
6};
7use crate::transport::create_in_memory_transport;
8use rmcp::{
9    RoleClient, RoleServer, ServiceExt,
10    model::{ClientInfo, Root, Tool as RmcpTool},
11    serve_client,
12    service::{DynService, RunningService},
13    transport::{
14        StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
15        streamable_http_client::StreamableHttpClientTransportConfig,
16    },
17};
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::{
22    process::Command,
23    sync::{RwLock, mpsc},
24    task::JoinHandle,
25};
26
/// Pairs a server name with an instructions string.
#[derive(Clone, Debug)]
pub struct ServerInstructions {
    /// Name of the server the instructions are associated with.
    pub server_name: String,
    /// The instruction text.
    pub instructions: String,
}
32
33#[derive(Debug, Clone)]
34pub struct Tool {
35    pub description: String,
36    pub parameters: Value,
37}
38
39impl From<RmcpTool> for Tool {
40    fn from(tool: RmcpTool) -> Self {
41        Self {
42            description: tool.description.unwrap_or_default().to_string(),
43            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
44        }
45    }
46}
47
48impl From<&RmcpTool> for Tool {
49    fn from(tool: &RmcpTool) -> Self {
50        Self {
51            description: tool.description.clone().unwrap_or_default().to_string(),
52            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
53        }
54    }
55}
56
57pub(super) struct ConnectContext<'a> {
58    pub client_info: &'a ClientInfo,
59    pub event_sender: &'a mpsc::Sender<McpClientEvent>,
60    pub roots: &'a Arc<RwLock<Vec<Root>>>,
61    pub oauth_handler_factory: Option<&'a OAuthHandlerFactory>,
62}
63
64/// The result of attempting to connect (or authenticate) to an MCP server.
65pub struct McpConnectAttempt {
66    pub name: String,
67    pub proxied: bool,
68    pub outcome: McpConnectOutcome,
69}
70
71pub enum McpConnectOutcome {
72    Connected { conn: McpServerConnection, reauth_config: Option<StreamableHttpClientTransportConfig> },
73    NeedsOAuth { config: StreamableHttpClientTransportConfig, error: McpError },
74    Failed { error: McpError },
75}
76
77impl McpConnectAttempt {
78    pub fn failed(name: impl Into<String>, error: McpError, proxied: bool) -> Self {
79        Self { name: name.into(), proxied, outcome: McpConnectOutcome::Failed { error } }
80    }
81}
82
83pub struct McpServerConnection {
84    pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
85    pub(super) server_task: Option<JoinHandle<()>>,
86    pub(super) instructions: Option<String>,
87}
88
89impl McpServerConnection {
90    pub(super) async fn reconnect_with_auth(
91        name: &str,
92        config: StreamableHttpClientTransportConfig,
93        auth_client: AuthClient<reqwest::Client>,
94        mcp_client: McpClient,
95    ) -> Result<Self> {
96        let transport = StreamableHttpClientTransport::with_client(auth_client, config);
97        let client = serve_client(mcp_client, transport)
98            .await
99            .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
100        Ok(Self::from_parts(client, None))
101    }
102
103    pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
104        let response = self
105            .client
106            .list_tools(None)
107            .await
108            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
109        Ok(response.tools)
110    }
111
112    fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
113        let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
114        Self { client: Arc::new(client), server_task, instructions }
115    }
116}
117
118pub(super) async fn connect_server(server: McpServer, ctx: &ConnectContext<'_>) -> McpConnectAttempt {
119    let McpServer { name, transport, proxy: proxied } = server;
120    let reauth_config = reauth_config_for(&transport, ctx.oauth_handler_factory);
121    let mcp_client =
122        McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(ctx.roots));
123
124    let outcome = match transport {
125        McpTransport::Stdio { command, args, env } => connect_stdio(command, args, env, mcp_client).await,
126        McpTransport::InMemory { server } => connect_in_memory(&name, server, mcp_client).await,
127        McpTransport::Http { config } => connect_http(&name, config, mcp_client, ctx.oauth_handler_factory).await,
128    };
129
130    McpConnectAttempt { name, proxied, outcome: outcome.with_reauth(reauth_config) }
131}
132
133pub async fn authenticate_http(
134    name: String,
135    config: StreamableHttpClientTransportConfig,
136    client_info: ClientInfo,
137    event_sender: mpsc::Sender<McpClientEvent>,
138    roots: Arc<RwLock<Vec<Root>>>,
139    oauth_handler_factory: OAuthHandlerFactory,
140    proxied: bool,
141) -> McpConnectAttempt {
142    let outcome = match async {
143        let handler = oauth_handler_factory()?;
144        let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref())
145            .await
146            .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
147
148        let mcp_client = McpClient::new(client_info, name.clone(), event_sender, roots);
149        McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await
150    }
151    .await
152    {
153        Ok(conn) => McpConnectOutcome::Connected { conn, reauth_config: Some(config) },
154        Err(error) => McpConnectOutcome::Failed { error },
155    };
156
157    McpConnectAttempt { name, proxied, outcome }
158}
159
160impl McpConnectOutcome {
161    fn with_reauth(self, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
162        match self {
163            Self::Connected { conn, .. } => Self::Connected { conn, reauth_config },
164            other => other,
165        }
166    }
167}
168
169async fn connect_stdio(
170    command: String,
171    args: Vec<String>,
172    env: HashMap<String, String>,
173    mcp_client: McpClient,
174) -> McpConnectOutcome {
175    let cmd = {
176        let mut cmd = Command::new(&command);
177        cmd.args(&args);
178        cmd.envs(&env);
179        cmd
180    };
181
182    let child = match TokioChildProcess::new(cmd) {
183        Ok(child) => child,
184        Err(e) => return McpConnectOutcome::Failed { error: McpError::SpawnFailed { command, reason: e.to_string() } },
185    };
186
187    match mcp_client.serve(child).await {
188        Ok(client) => {
189            McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
190        }
191        Err(e) => McpConnectOutcome::Failed { error: McpError::from(e) },
192    }
193}
194
195async fn connect_in_memory(
196    name: &str,
197    server: Box<dyn DynService<RoleServer>>,
198    mcp_client: McpClient,
199) -> McpConnectOutcome {
200    match serve_in_memory(server, mcp_client, name).await {
201        Ok((client, handle)) => McpConnectOutcome::Connected {
202            conn: McpServerConnection::from_parts(client, Some(handle)),
203            reauth_config: None,
204        },
205        Err(error) => McpConnectOutcome::Failed { error },
206    }
207}
208
209async fn connect_http(
210    name: &str,
211    config: StreamableHttpClientTransportConfig,
212    mcp_client: McpClient,
213    oauth_handler_factory: Option<&OAuthHandlerFactory>,
214) -> McpConnectOutcome {
215    let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
216    let result = if config.auth_header.is_none()
217        && let Ok(Some(auth_manager)) = create_auth_manager_from_store(name, &config.uri).await
218    {
219        tracing::debug!("Using OAuth for server '{name}'");
220        let auth_client = AuthClient::new(reqwest::Client::default(), auth_manager);
221        let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
222        serve_client(mcp_client, transport).await.map_err(conn_err)
223    } else {
224        let transport = StreamableHttpClientTransport::from_config(config.clone());
225        serve_client(mcp_client, transport).await.map_err(conn_err)
226    };
227
228    match result {
229        Ok(client) => {
230            McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
231        }
232        Err(error) => {
233            tracing::warn!("Failed to connect to MCP server '{name}': {error}");
234            if oauth_handler_factory.is_some() && config.auth_header.is_none() {
235                McpConnectOutcome::NeedsOAuth { config, error }
236            } else {
237                McpConnectOutcome::Failed { error }
238            }
239        }
240    }
241}
242
243fn reauth_config_for(
244    transport: &McpTransport,
245    oauth_handler_factory: Option<&OAuthHandlerFactory>,
246) -> Option<StreamableHttpClientTransportConfig> {
247    match transport {
248        McpTransport::Http { config } if oauth_handler_factory.is_some() && config.auth_header.is_none() => {
249            Some(config.clone())
250        }
251        _ => None,
252    }
253}
254
255async fn serve_in_memory(
256    server: Box<dyn DynService<RoleServer>>,
257    mcp_client: McpClient,
258    label: &str,
259) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
260    let (client_transport, server_transport) = create_in_memory_transport();
261
262    let server_handle = tokio::spawn(async move {
263        match server.serve(server_transport).await {
264            Ok(_service) => {
265                std::future::pending::<()>().await;
266            }
267            Err(e) => {
268                eprintln!("MCP server error: {e}");
269            }
270        }
271    });
272
273    let client = serve_client(mcp_client, client_transport)
274        .await
275        .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
276
277    Ok((client, server_handle))
278}