Skip to main content

mcp_utils/client/
connection.rs

1use super::{McpError, Result, config::ServerConfig, mcp_client::McpClient};
2use crate::transport::create_in_memory_transport;
3use rmcp::{
4    RoleClient, RoleServer, ServiceExt,
5    model::Tool as RmcpTool,
6    serve_client,
7    service::{DynService, RunningService},
8    transport::{
9        StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
10        streamable_http_client::StreamableHttpClientTransportConfig,
11    },
12};
13use serde_json::Value;
14use std::sync::Arc;
15use tokio::{process::Command, task::JoinHandle};
16
17use super::oauth::{OAuthHandler, create_auth_manager_from_store};
18
/// Instructions text associated with a named MCP server.
///
/// Besides `Debug` and `Clone`, the type derives `PartialEq`/`Eq` so values
/// can be compared directly (for example in tests or when de-duplicating).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerInstructions {
    /// Name of the server the instructions belong to.
    pub server_name: String,
    /// The instructions text itself.
    pub instructions: String,
}
24
25#[derive(Debug, Clone)]
26pub struct Tool {
27    pub description: String,
28    pub parameters: Value,
29}
30
31impl From<RmcpTool> for Tool {
32    fn from(tool: RmcpTool) -> Self {
33        Self {
34            description: tool.description.unwrap_or_default().to_string(),
35            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
36        }
37    }
38}
39
40impl From<&RmcpTool> for Tool {
41    fn from(tool: &RmcpTool) -> Self {
42        Self {
43            description: tool.description.clone().unwrap_or_default().to_string(),
44            parameters: serde_json::Value::Object((*tool.input_schema).clone()),
45        }
46    }
47}
48
49/// Everything the connection needs from the manager.
50pub(super) struct ConnectParams {
51    pub mcp_client: McpClient,
52    pub oauth_handler: Option<Arc<dyn OAuthHandler>>,
53}
54
55/// Result of attempting to connect to an MCP server.
56pub(super) enum ConnectResult {
57    /// Connection established successfully.
58    Connected(McpServerConnection),
59    /// HTTP server failed; may need OAuth. Carries the config for retry.
60    NeedsOAuth {
61        name: String,
62        config: StreamableHttpClientTransportConfig,
63        error: McpError,
64    },
65    /// Hard failure (non-HTTP, or no OAuth handler available).
66    Failed(McpError),
67}
68
69pub(super) struct McpServerConnection {
70    pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
71    pub(super) server_task: Option<JoinHandle<()>>,
72    pub(super) instructions: Option<String>,
73}
74
75impl McpServerConnection {
76    /// Connect to an MCP server described by `config`.
77    ///
78    /// This is the single entry point for establishing a connection — handling
79    /// transport creation, OAuth credential lookup, `serve_client()`, and
80    /// returning a ready-to-use connection.
81    pub(super) async fn connect(config: ServerConfig, params: ConnectParams) -> ConnectResult {
82        match config {
83            ServerConfig::Stdio { command, args, .. } => {
84                let mut cmd = Command::new(&command);
85                cmd.args(&args);
86                let child = match TokioChildProcess::new(cmd) {
87                    Ok(child) => child,
88                    Err(e) => {
89                        return ConnectResult::Failed(McpError::SpawnFailed {
90                            command,
91                            reason: e.to_string(),
92                        });
93                    }
94                };
95                match params.mcp_client.serve(child).await {
96                    Ok(client) => ConnectResult::Connected(Self::from_parts(client, None)),
97                    Err(e) => ConnectResult::Failed(McpError::from(e)),
98                }
99            }
100
101            ServerConfig::InMemory { name, server } => {
102                match serve_in_memory(server, params.mcp_client, &name).await {
103                    Ok((client, handle)) => {
104                        ConnectResult::Connected(Self::from_parts(client, Some(handle)))
105                    }
106                    Err(e) => ConnectResult::Failed(e),
107                }
108            }
109
110            ServerConfig::Http { name, config: cfg } => Self::connect_http(name, cfg, params).await,
111        }
112    }
113
114    /// Reconnect to an HTTP server using an already-obtained `AuthClient`.
115    ///
116    /// Used after a successful OAuth flow to establish the authenticated connection.
117    pub(super) async fn reconnect_with_auth(
118        name: &str,
119        config: StreamableHttpClientTransportConfig,
120        auth_client: AuthClient<reqwest::Client>,
121        mcp_client: McpClient,
122    ) -> Result<Self> {
123        let transport = StreamableHttpClientTransport::with_client(auth_client, config);
124        let client = serve_client(mcp_client, transport).await.map_err(|e| {
125            McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}"))
126        })?;
127        Ok(Self::from_parts(client, None))
128    }
129
130    /// List tools from the connected server.
131    pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
132        let response = self
133            .client
134            .list_tools(None)
135            .await
136            .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
137        Ok(response.tools)
138    }
139
140    /// Build a connection from already-connected parts, extracting any
141    /// server-provided instructions from peer info.
142    fn from_parts(
143        client: RunningService<RoleClient, McpClient>,
144        server_task: Option<JoinHandle<()>>,
145    ) -> Self {
146        let instructions = client
147            .peer_info()
148            .and_then(|info| info.instructions.clone())
149            .filter(|s| !s.is_empty());
150        Self {
151            client: Arc::new(client),
152            server_task,
153            instructions,
154        }
155    }
156
157    /// Connect to an HTTP MCP server. Tries stored OAuth credentials first,
158    /// falls back to plain connection, and returns `NeedsOAuth` on failure
159    /// if an OAuth handler is available.
160    async fn connect_http(
161        name: String,
162        config: StreamableHttpClientTransportConfig,
163        params: ConnectParams,
164    ) -> ConnectResult {
165        let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
166
167        let result = match create_auth_client(&name, &config.uri).await {
168            Some(auth_client) if config.auth_header.is_none() => {
169                tracing::debug!("Using OAuth for server '{name}'");
170                let transport =
171                    StreamableHttpClientTransport::with_client(auth_client, config.clone());
172                serve_client(params.mcp_client, transport)
173                    .await
174                    .map_err(conn_err)
175            }
176            _ => {
177                let transport = StreamableHttpClientTransport::from_config(config.clone());
178                serve_client(params.mcp_client, transport)
179                    .await
180                    .map_err(conn_err)
181            }
182        };
183
184        match result {
185            Ok(client) => ConnectResult::Connected(Self::from_parts(client, None)),
186            Err(err) => {
187                tracing::warn!("Failed to connect to MCP server '{name}': {err}");
188                if params.oauth_handler.is_some() {
189                    ConnectResult::NeedsOAuth {
190                        name,
191                        config,
192                        error: err,
193                    }
194                } else {
195                    ConnectResult::Failed(err)
196                }
197            }
198        }
199    }
200}
201
202/// Try to build an `AuthClient` from stored OAuth credentials.
203async fn create_auth_client(
204    server_id: &str,
205    base_url: &str,
206) -> Option<AuthClient<reqwest::Client>> {
207    let auth_manager = create_auth_manager_from_store(server_id, base_url)
208        .await
209        .ok()??;
210    Some(AuthClient::new(reqwest::Client::default(), auth_manager))
211}
212
213/// Spawn an in-memory MCP server on a background task and connect a client to it.
214///
215/// Returns the running client service and the server's join handle.
216async fn serve_in_memory(
217    server: Box<dyn DynService<RoleServer>>,
218    mcp_client: McpClient,
219    label: &str,
220) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
221    let (client_transport, server_transport) = create_in_memory_transport();
222
223    let server_handle = tokio::spawn(async move {
224        match server.serve(server_transport).await {
225            Ok(_service) => {
226                std::future::pending::<()>().await;
227            }
228            Err(e) => {
229                eprintln!("MCP server error: {e}");
230            }
231        }
232    });
233
234    let client = serve_client(mcp_client, client_transport)
235        .await
236        .map_err(|e| {
237            McpError::ConnectionFailed(format!(
238                "Failed to connect to in-memory server '{label}': {e}"
239            ))
240        })?;
241
242    Ok((client, server_handle))
243}