Skip to main content

mcp_utils/client/
connection.rs

1use super::{McpError, Result, config::ServerConfig, mcp_client::McpClient};
2use crate::transport::create_in_memory_transport;
3use rmcp::{
4    RoleClient, RoleServer, ServiceExt,
5    model::Tool as RmcpTool,
6    serve_client,
7    service::{DynService, RunningService},
8    transport::{
9        StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
10        streamable_http_client::StreamableHttpClientTransportConfig,
11    },
12};
13use serde_json::Value;
14use std::sync::Arc;
15use tokio::{process::Command, task::JoinHandle};
16
17use super::oauth::{OAuthHandler, create_auth_manager_from_store};
18
/// Instructions text published by one particular MCP server.
#[derive(Debug, Clone)]
pub struct ServerInstructions {
    /// Name of the server the instructions belong to.
    pub server_name: String,
    /// The instruction text as provided by that server.
    pub instructions: String,
}
24
25#[derive(Debug, Clone)]
26pub struct Tool {
27    pub description: String,
28    pub parameters: Value,
29}
30
31impl From<RmcpTool> for Tool {
32    fn from(tool: RmcpTool) -> Self {
33        Self {
34            description: tool.description.unwrap_or_default().to_string(),
35            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
36        }
37    }
38}
39
40impl From<&RmcpTool> for Tool {
41    fn from(tool: &RmcpTool) -> Self {
42        Self {
43            description: tool.description.clone().unwrap_or_default().to_string(),
44            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
45        }
46    }
47}
48
49/// Everything the connection needs from the manager.
50pub(super) struct ConnectParams {
51    pub mcp_client: McpClient,
52    pub oauth_handler: Option<Arc<dyn OAuthHandler>>,
53}
54
55/// Result of attempting to connect to an MCP server.
56pub(super) enum ConnectResult {
57    /// Connection established successfully.
58    Connected(McpServerConnection),
59    /// HTTP server failed; may need OAuth. Carries the config for retry.
60    NeedsOAuth { name: String, config: StreamableHttpClientTransportConfig, error: McpError },
61    /// Hard failure (non-HTTP, or no OAuth handler available).
62    Failed(McpError),
63}
64
65pub(super) struct McpServerConnection {
66    pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
67    pub(super) server_task: Option<JoinHandle<()>>,
68    pub(super) instructions: Option<String>,
69}
70
71impl McpServerConnection {
72    /// Connect to an MCP server described by `config`.
73    ///
74    /// This is the single entry point for establishing a connection — handling
75    /// transport creation, OAuth credential lookup, `serve_client()`, and
76    /// returning a ready-to-use connection.
77    pub(super) async fn connect(config: ServerConfig, params: ConnectParams) -> ConnectResult {
78        match config {
79            ServerConfig::Stdio { command, args, .. } => {
80                let mut cmd = Command::new(&command);
81                cmd.args(&args);
82                let child = match TokioChildProcess::new(cmd) {
83                    Ok(child) => child,
84                    Err(e) => {
85                        return ConnectResult::Failed(McpError::SpawnFailed { command, reason: e.to_string() });
86                    }
87                };
88                match params.mcp_client.serve(child).await {
89                    Ok(client) => ConnectResult::Connected(Self::from_parts(client, None)),
90                    Err(e) => ConnectResult::Failed(McpError::from(e)),
91                }
92            }
93
94            ServerConfig::InMemory { name, server } => match serve_in_memory(server, params.mcp_client, &name).await {
95                Ok((client, handle)) => ConnectResult::Connected(Self::from_parts(client, Some(handle))),
96                Err(e) => ConnectResult::Failed(e),
97            },
98
99            ServerConfig::Http { name, config: cfg } => Self::connect_http(name, cfg, params).await,
100        }
101    }
102
103    /// Reconnect to an HTTP server using an already-obtained `AuthClient`.
104    ///
105    /// Used after a successful OAuth flow to establish the authenticated connection.
106    pub(super) async fn reconnect_with_auth(
107        name: &str,
108        config: StreamableHttpClientTransportConfig,
109        auth_client: AuthClient<reqwest::Client>,
110        mcp_client: McpClient,
111    ) -> Result<Self> {
112        let transport = StreamableHttpClientTransport::with_client(auth_client, config);
113        let client = serve_client(mcp_client, transport)
114            .await
115            .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
116        Ok(Self::from_parts(client, None))
117    }
118
119    /// List tools from the connected server.
120    pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
121        let response = self
122            .client
123            .list_tools(None)
124            .await
125            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
126        Ok(response.tools)
127    }
128
129    /// Build a connection from already-connected parts, extracting any
130    /// server-provided instructions from peer info.
131    fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
132        let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
133        Self { client: Arc::new(client), server_task, instructions }
134    }
135
136    /// Connect to an HTTP MCP server. Tries stored OAuth credentials first,
137    /// falls back to plain connection, and returns `NeedsOAuth` on failure
138    /// if an OAuth handler is available.
139    async fn connect_http(
140        name: String,
141        config: StreamableHttpClientTransportConfig,
142        params: ConnectParams,
143    ) -> ConnectResult {
144        let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
145
146        let result = match create_auth_client(&name, &config.uri).await {
147            Some(auth_client) if config.auth_header.is_none() => {
148                tracing::debug!("Using OAuth for server '{name}'");
149                let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
150                serve_client(params.mcp_client, transport).await.map_err(conn_err)
151            }
152            _ => {
153                let transport = StreamableHttpClientTransport::from_config(config.clone());
154                serve_client(params.mcp_client, transport).await.map_err(conn_err)
155            }
156        };
157
158        match result {
159            Ok(client) => ConnectResult::Connected(Self::from_parts(client, None)),
160            Err(err) => {
161                tracing::warn!("Failed to connect to MCP server '{name}': {err}");
162                if params.oauth_handler.is_some() {
163                    ConnectResult::NeedsOAuth { name, config, error: err }
164                } else {
165                    ConnectResult::Failed(err)
166                }
167            }
168        }
169    }
170}
171
172/// Try to build an `AuthClient` from stored OAuth credentials.
173async fn create_auth_client(server_id: &str, base_url: &str) -> Option<AuthClient<reqwest::Client>> {
174    let auth_manager = create_auth_manager_from_store(server_id, base_url).await.ok()??;
175    Some(AuthClient::new(reqwest::Client::default(), auth_manager))
176}
177
178/// Spawn an in-memory MCP server on a background task and connect a client to it.
179///
180/// Returns the running client service and the server's join handle.
181async fn serve_in_memory(
182    server: Box<dyn DynService<RoleServer>>,
183    mcp_client: McpClient,
184    label: &str,
185) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
186    let (client_transport, server_transport) = create_in_memory_transport();
187
188    let server_handle = tokio::spawn(async move {
189        match server.serve(server_transport).await {
190            Ok(_service) => {
191                std::future::pending::<()>().await;
192            }
193            Err(e) => {
194                eprintln!("MCP server error: {e}");
195            }
196        }
197    });
198
199    let client = serve_client(mcp_client, client_transport)
200        .await
201        .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
202
203    Ok((client, server_handle))
204}