1use super::{McpError, Result, config::ServerConfig, mcp_client::McpClient};
2use crate::transport::create_in_memory_transport;
3use rmcp::{
4 RoleClient, RoleServer, ServiceExt,
5 model::Tool as RmcpTool,
6 serve_client,
7 service::{DynService, RunningService},
8 transport::{
9 StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
10 streamable_http_client::StreamableHttpClientTransportConfig,
11 },
12};
13use serde_json::Value;
14use std::sync::Arc;
15use tokio::{process::Command, task::JoinHandle};
16
17use super::oauth::{OAuthHandler, create_auth_manager_from_store};
18
/// Instructions text associated with one named MCP server.
#[derive(Clone, Debug)]
pub struct ServerInstructions {
    /// Name of the server these instructions belong to.
    pub server_name: String,
    /// The instructions text itself.
    pub instructions: String,
}
24
25#[derive(Debug, Clone)]
26pub struct Tool {
27 pub description: String,
28 pub parameters: Value,
29}
30
31impl From<RmcpTool> for Tool {
32 fn from(tool: RmcpTool) -> Self {
33 Self {
34 description: tool.description.unwrap_or_default().to_string(),
35 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
36 }
37 }
38}
39
40impl From<&RmcpTool> for Tool {
41 fn from(tool: &RmcpTool) -> Self {
42 Self {
43 description: tool.description.clone().unwrap_or_default().to_string(),
44 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
45 }
46 }
47}
48
49pub(super) struct ConnectParams {
51 pub mcp_client: McpClient,
52 pub oauth_handler: Option<Arc<dyn OAuthHandler>>,
53}
54
55pub(super) enum ConnectResult {
57 Connected(McpServerConnection),
59 NeedsOAuth {
61 name: String,
62 config: StreamableHttpClientTransportConfig,
63 error: McpError,
64 },
65 Failed(McpError),
67}
68
69pub(super) struct McpServerConnection {
70 pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
71 pub(super) server_task: Option<JoinHandle<()>>,
72 pub(super) instructions: Option<String>,
73}
74
75impl McpServerConnection {
76 pub(super) async fn connect(config: ServerConfig, params: ConnectParams) -> ConnectResult {
82 match config {
83 ServerConfig::Stdio { command, args, .. } => {
84 let mut cmd = Command::new(&command);
85 cmd.args(&args);
86 let child = match TokioChildProcess::new(cmd) {
87 Ok(child) => child,
88 Err(e) => {
89 return ConnectResult::Failed(McpError::SpawnFailed {
90 command,
91 reason: e.to_string(),
92 });
93 }
94 };
95 match params.mcp_client.serve(child).await {
96 Ok(client) => ConnectResult::Connected(Self::from_parts(client, None)),
97 Err(e) => ConnectResult::Failed(McpError::from(e)),
98 }
99 }
100
101 ServerConfig::InMemory { name, server } => {
102 match serve_in_memory(server, params.mcp_client, &name).await {
103 Ok((client, handle)) => {
104 ConnectResult::Connected(Self::from_parts(client, Some(handle)))
105 }
106 Err(e) => ConnectResult::Failed(e),
107 }
108 }
109
110 ServerConfig::Http { name, config: cfg } => Self::connect_http(name, cfg, params).await,
111 }
112 }
113
114 pub(super) async fn reconnect_with_auth(
118 name: &str,
119 config: StreamableHttpClientTransportConfig,
120 auth_client: AuthClient<reqwest::Client>,
121 mcp_client: McpClient,
122 ) -> Result<Self> {
123 let transport = StreamableHttpClientTransport::with_client(auth_client, config);
124 let client = serve_client(mcp_client, transport).await.map_err(|e| {
125 McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}"))
126 })?;
127 Ok(Self::from_parts(client, None))
128 }
129
130 pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
132 let response = self
133 .client
134 .list_tools(None)
135 .await
136 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
137 Ok(response.tools)
138 }
139
140 fn from_parts(
143 client: RunningService<RoleClient, McpClient>,
144 server_task: Option<JoinHandle<()>>,
145 ) -> Self {
146 let instructions = client
147 .peer_info()
148 .and_then(|info| info.instructions.clone())
149 .filter(|s| !s.is_empty());
150 Self {
151 client: Arc::new(client),
152 server_task,
153 instructions,
154 }
155 }
156
157 async fn connect_http(
161 name: String,
162 config: StreamableHttpClientTransportConfig,
163 params: ConnectParams,
164 ) -> ConnectResult {
165 let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
166
167 let result = match create_auth_client(&name, &config.uri).await {
168 Some(auth_client) if config.auth_header.is_none() => {
169 tracing::debug!("Using OAuth for server '{name}'");
170 let transport =
171 StreamableHttpClientTransport::with_client(auth_client, config.clone());
172 serve_client(params.mcp_client, transport)
173 .await
174 .map_err(conn_err)
175 }
176 _ => {
177 let transport = StreamableHttpClientTransport::from_config(config.clone());
178 serve_client(params.mcp_client, transport)
179 .await
180 .map_err(conn_err)
181 }
182 };
183
184 match result {
185 Ok(client) => ConnectResult::Connected(Self::from_parts(client, None)),
186 Err(err) => {
187 tracing::warn!("Failed to connect to MCP server '{name}': {err}");
188 if params.oauth_handler.is_some() {
189 ConnectResult::NeedsOAuth {
190 name,
191 config,
192 error: err,
193 }
194 } else {
195 ConnectResult::Failed(err)
196 }
197 }
198 }
199 }
200}
201
202async fn create_auth_client(
204 server_id: &str,
205 base_url: &str,
206) -> Option<AuthClient<reqwest::Client>> {
207 let auth_manager = create_auth_manager_from_store(server_id, base_url)
208 .await
209 .ok()??;
210 Some(AuthClient::new(reqwest::Client::default(), auth_manager))
211}
212
213async fn serve_in_memory(
217 server: Box<dyn DynService<RoleServer>>,
218 mcp_client: McpClient,
219 label: &str,
220) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
221 let (client_transport, server_transport) = create_in_memory_transport();
222
223 let server_handle = tokio::spawn(async move {
224 match server.serve(server_transport).await {
225 Ok(_service) => {
226 std::future::pending::<()>().await;
227 }
228 Err(e) => {
229 eprintln!("MCP server error: {e}");
230 }
231 }
232 });
233
234 let client = serve_client(mcp_client, client_transport)
235 .await
236 .map_err(|e| {
237 McpError::ConnectionFailed(format!(
238 "Failed to connect to in-memory server '{label}': {e}"
239 ))
240 })?;
241
242 Ok((client, server_handle))
243}