1use super::{
2 McpClientEvent, McpError, OAuthHandlerFactory, Result,
3 config::{McpServer, McpTransport},
4 mcp_client::McpClient,
5 oauth::{create_auth_manager_from_store, perform_oauth_flow},
6};
7use crate::transport::create_in_memory_transport;
8use rmcp::{
9 RoleClient, RoleServer, ServiceExt,
10 model::{ClientInfo, Root, Tool as RmcpTool},
11 serve_client,
12 service::{DynService, RunningService},
13 transport::{
14 StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
15 streamable_http_client::StreamableHttpClientTransportConfig,
16 },
17};
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::{
22 process::Command,
23 sync::{RwLock, mpsc},
24 task::JoinHandle,
25};
26
/// Instruction text paired with the name of the MCP server it belongs to.
#[derive(Clone, Debug)]
pub struct ServerInstructions {
    /// Name of the server the instructions are associated with.
    pub server_name: String,
    /// The instruction text.
    pub instructions: String,
}
32
33#[derive(Debug, Clone)]
34pub struct Tool {
35 pub description: String,
36 pub parameters: Value,
37}
38
39impl From<RmcpTool> for Tool {
40 fn from(tool: RmcpTool) -> Self {
41 Self {
42 description: tool.description.unwrap_or_default().to_string(),
43 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
44 }
45 }
46}
47
48impl From<&RmcpTool> for Tool {
49 fn from(tool: &RmcpTool) -> Self {
50 Self {
51 description: tool.description.clone().unwrap_or_default().to_string(),
52 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
53 }
54 }
55}
56
57pub(super) struct ConnectContext<'a> {
58 pub client_info: &'a ClientInfo,
59 pub event_sender: &'a mpsc::Sender<McpClientEvent>,
60 pub roots: &'a Arc<RwLock<Vec<Root>>>,
61 pub oauth_handler_factory: Option<&'a OAuthHandlerFactory>,
62}
63
64pub struct McpConnectAttempt {
66 pub name: String,
67 pub proxied: bool,
68 pub outcome: McpConnectOutcome,
69}
70
71pub enum McpConnectOutcome {
72 Connected { conn: McpServerConnection, reauth_config: Option<StreamableHttpClientTransportConfig> },
73 NeedsOAuth { config: StreamableHttpClientTransportConfig, error: McpError },
74 Failed { error: McpError },
75}
76
77impl McpConnectAttempt {
78 pub fn failed(name: impl Into<String>, error: McpError, proxied: bool) -> Self {
79 Self { name: name.into(), proxied, outcome: McpConnectOutcome::Failed { error } }
80 }
81}
82
83pub struct McpServerConnection {
84 pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
85 pub(super) server_task: Option<JoinHandle<()>>,
86 pub(super) instructions: Option<String>,
87}
88
89impl McpServerConnection {
90 pub(super) async fn reconnect_with_auth(
91 name: &str,
92 config: StreamableHttpClientTransportConfig,
93 auth_client: AuthClient<reqwest::Client>,
94 mcp_client: McpClient,
95 ) -> Result<Self> {
96 let transport = StreamableHttpClientTransport::with_client(auth_client, config);
97 let client = serve_client(mcp_client, transport)
98 .await
99 .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
100 Ok(Self::from_parts(client, None))
101 }
102
103 pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
104 let response = self
105 .client
106 .list_tools(None)
107 .await
108 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
109 Ok(response.tools)
110 }
111
112 fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
113 let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
114 Self { client: Arc::new(client), server_task, instructions }
115 }
116}
117
118pub(super) async fn connect_server(server: McpServer, ctx: &ConnectContext<'_>) -> McpConnectAttempt {
119 let McpServer { name, transport, proxy: proxied } = server;
120 let reauth_config = reauth_config_for(&transport, ctx.oauth_handler_factory);
121 let mcp_client =
122 McpClient::new(ctx.client_info.clone(), name.clone(), ctx.event_sender.clone(), Arc::clone(ctx.roots));
123
124 let outcome = match transport {
125 McpTransport::Stdio { command, args, env } => connect_stdio(command, args, env, mcp_client).await,
126 McpTransport::InMemory { server } => connect_in_memory(&name, server, mcp_client).await,
127 McpTransport::Http { config } => connect_http(&name, config, mcp_client, ctx.oauth_handler_factory).await,
128 };
129
130 McpConnectAttempt { name, proxied, outcome: outcome.with_reauth(reauth_config) }
131}
132
133pub async fn authenticate_http(
134 name: String,
135 config: StreamableHttpClientTransportConfig,
136 client_info: ClientInfo,
137 event_sender: mpsc::Sender<McpClientEvent>,
138 roots: Arc<RwLock<Vec<Root>>>,
139 oauth_handler_factory: OAuthHandlerFactory,
140 proxied: bool,
141) -> McpConnectAttempt {
142 let outcome = match async {
143 let handler = oauth_handler_factory()?;
144 let auth_client = perform_oauth_flow(&name, &config.uri, handler.as_ref())
145 .await
146 .map_err(|e| McpError::ConnectionFailed(format!("OAuth failed for '{name}': {e}")))?;
147
148 let mcp_client = McpClient::new(client_info, name.clone(), event_sender, roots);
149 McpServerConnection::reconnect_with_auth(&name, config.clone(), auth_client, mcp_client).await
150 }
151 .await
152 {
153 Ok(conn) => McpConnectOutcome::Connected { conn, reauth_config: Some(config) },
154 Err(error) => McpConnectOutcome::Failed { error },
155 };
156
157 McpConnectAttempt { name, proxied, outcome }
158}
159
160impl McpConnectOutcome {
161 fn with_reauth(self, reauth_config: Option<StreamableHttpClientTransportConfig>) -> Self {
162 match self {
163 Self::Connected { conn, .. } => Self::Connected { conn, reauth_config },
164 other => other,
165 }
166 }
167}
168
169async fn connect_stdio(
170 command: String,
171 args: Vec<String>,
172 env: HashMap<String, String>,
173 mcp_client: McpClient,
174) -> McpConnectOutcome {
175 let cmd = {
176 let mut cmd = Command::new(&command);
177 cmd.args(&args);
178 cmd.envs(&env);
179 cmd
180 };
181
182 let child = match TokioChildProcess::new(cmd) {
183 Ok(child) => child,
184 Err(e) => return McpConnectOutcome::Failed { error: McpError::SpawnFailed { command, reason: e.to_string() } },
185 };
186
187 match mcp_client.serve(child).await {
188 Ok(client) => {
189 McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
190 }
191 Err(e) => McpConnectOutcome::Failed { error: McpError::from(e) },
192 }
193}
194
195async fn connect_in_memory(
196 name: &str,
197 server: Box<dyn DynService<RoleServer>>,
198 mcp_client: McpClient,
199) -> McpConnectOutcome {
200 match serve_in_memory(server, mcp_client, name).await {
201 Ok((client, handle)) => McpConnectOutcome::Connected {
202 conn: McpServerConnection::from_parts(client, Some(handle)),
203 reauth_config: None,
204 },
205 Err(error) => McpConnectOutcome::Failed { error },
206 }
207}
208
209async fn connect_http(
210 name: &str,
211 config: StreamableHttpClientTransportConfig,
212 mcp_client: McpClient,
213 oauth_handler_factory: Option<&OAuthHandlerFactory>,
214) -> McpConnectOutcome {
215 let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
216 let result = if config.auth_header.is_none()
217 && let Ok(Some(auth_manager)) = create_auth_manager_from_store(name, &config.uri).await
218 {
219 tracing::debug!("Using OAuth for server '{name}'");
220 let auth_client = AuthClient::new(reqwest::Client::default(), auth_manager);
221 let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
222 serve_client(mcp_client, transport).await.map_err(conn_err)
223 } else {
224 let transport = StreamableHttpClientTransport::from_config(config.clone());
225 serve_client(mcp_client, transport).await.map_err(conn_err)
226 };
227
228 match result {
229 Ok(client) => {
230 McpConnectOutcome::Connected { conn: McpServerConnection::from_parts(client, None), reauth_config: None }
231 }
232 Err(error) => {
233 tracing::warn!("Failed to connect to MCP server '{name}': {error}");
234 if oauth_handler_factory.is_some() && config.auth_header.is_none() {
235 McpConnectOutcome::NeedsOAuth { config, error }
236 } else {
237 McpConnectOutcome::Failed { error }
238 }
239 }
240 }
241}
242
243fn reauth_config_for(
244 transport: &McpTransport,
245 oauth_handler_factory: Option<&OAuthHandlerFactory>,
246) -> Option<StreamableHttpClientTransportConfig> {
247 match transport {
248 McpTransport::Http { config } if oauth_handler_factory.is_some() && config.auth_header.is_none() => {
249 Some(config.clone())
250 }
251 _ => None,
252 }
253}
254
255async fn serve_in_memory(
256 server: Box<dyn DynService<RoleServer>>,
257 mcp_client: McpClient,
258 label: &str,
259) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
260 let (client_transport, server_transport) = create_in_memory_transport();
261
262 let server_handle = tokio::spawn(async move {
263 match server.serve(server_transport).await {
264 Ok(_service) => {
265 std::future::pending::<()>().await;
266 }
267 Err(e) => {
268 eprintln!("MCP server error: {e}");
269 }
270 }
271 });
272
273 let client = serve_client(mcp_client, client_transport)
274 .await
275 .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
276
277 Ok((client, server_handle))
278}