1use super::{McpError, Result, config::ServerConfig, mcp_client::McpClient};
2use crate::transport::create_in_memory_transport;
3use rmcp::{
4 RoleClient, RoleServer, ServiceExt,
5 model::Tool as RmcpTool,
6 serve_client,
7 service::{DynService, RunningService},
8 transport::{
9 StreamableHttpClientTransport, TokioChildProcess, auth::AuthClient,
10 streamable_http_client::StreamableHttpClientTransportConfig,
11 },
12};
13use serde_json::Value;
14use std::sync::Arc;
15use tokio::{process::Command, task::JoinHandle};
16
17use super::oauth::{OAuthHandler, create_auth_manager_from_store};
18
/// Instruction text paired with the name of the server it belongs to.
///
/// NOTE(review): presumably filled from the `instructions` an MCP server
/// reports during initialization — confirm against the code that builds it.
#[derive(Clone, Debug)]
pub struct ServerInstructions {
    /// Name identifying the server.
    pub server_name: String,
    /// The instruction text itself.
    pub instructions: String,
}
24
25#[derive(Debug, Clone)]
26pub struct Tool {
27 pub description: String,
28 pub parameters: Value,
29}
30
31impl From<RmcpTool> for Tool {
32 fn from(tool: RmcpTool) -> Self {
33 Self {
34 description: tool.description.unwrap_or_default().to_string(),
35 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
36 }
37 }
38}
39
40impl From<&RmcpTool> for Tool {
41 fn from(tool: &RmcpTool) -> Self {
42 Self {
43 description: tool.description.clone().unwrap_or_default().to_string(),
44 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
45 }
46 }
47}
48
49pub(super) struct ConnectParams {
51 pub mcp_client: McpClient,
52 pub oauth_handler: Option<Arc<dyn OAuthHandler>>,
53}
54
55pub(super) enum ConnectResult {
57 Connected(McpServerConnection),
59 NeedsOAuth { name: String, config: StreamableHttpClientTransportConfig, error: McpError },
61 Failed(McpError),
63}
64
65pub(super) struct McpServerConnection {
66 pub(super) client: Arc<RunningService<RoleClient, McpClient>>,
67 pub(super) server_task: Option<JoinHandle<()>>,
68 pub(super) instructions: Option<String>,
69}
70
71impl McpServerConnection {
72 pub(super) async fn connect(config: ServerConfig, params: ConnectParams) -> ConnectResult {
78 match config {
79 ServerConfig::Stdio { command, args, .. } => {
80 let mut cmd = Command::new(&command);
81 cmd.args(&args);
82 let child = match TokioChildProcess::new(cmd) {
83 Ok(child) => child,
84 Err(e) => {
85 return ConnectResult::Failed(McpError::SpawnFailed { command, reason: e.to_string() });
86 }
87 };
88 match params.mcp_client.serve(child).await {
89 Ok(client) => ConnectResult::Connected(Self::from_parts(client, None)),
90 Err(e) => ConnectResult::Failed(McpError::from(e)),
91 }
92 }
93
94 ServerConfig::InMemory { name, server } => match serve_in_memory(server, params.mcp_client, &name).await {
95 Ok((client, handle)) => ConnectResult::Connected(Self::from_parts(client, Some(handle))),
96 Err(e) => ConnectResult::Failed(e),
97 },
98
99 ServerConfig::Http { name, config: cfg } => Self::connect_http(name, cfg, params).await,
100 }
101 }
102
103 pub(super) async fn reconnect_with_auth(
107 name: &str,
108 config: StreamableHttpClientTransportConfig,
109 auth_client: AuthClient<reqwest::Client>,
110 mcp_client: McpClient,
111 ) -> Result<Self> {
112 let transport = StreamableHttpClientTransport::with_client(auth_client, config);
113 let client = serve_client(mcp_client, transport)
114 .await
115 .map_err(|e| McpError::ConnectionFailed(format!("reconnect failed for '{name}': {e}")))?;
116 Ok(Self::from_parts(client, None))
117 }
118
119 pub(super) async fn list_tools(&self) -> Result<Vec<RmcpTool>> {
121 let response = self
122 .client
123 .list_tools(None)
124 .await
125 .map_err(|e| McpError::ToolDiscoveryFailed(format!("Failed to list tools: {e}")))?;
126 Ok(response.tools)
127 }
128
129 fn from_parts(client: RunningService<RoleClient, McpClient>, server_task: Option<JoinHandle<()>>) -> Self {
132 let instructions = client.peer_info().and_then(|info| info.instructions.clone()).filter(|s| !s.is_empty());
133 Self { client: Arc::new(client), server_task, instructions }
134 }
135
136 async fn connect_http(
140 name: String,
141 config: StreamableHttpClientTransportConfig,
142 params: ConnectParams,
143 ) -> ConnectResult {
144 let conn_err = |e| McpError::ConnectionFailed(format!("HTTP MCP server {name}: {e}"));
145
146 let result = match create_auth_client(&name, &config.uri).await {
147 Some(auth_client) if config.auth_header.is_none() => {
148 tracing::debug!("Using OAuth for server '{name}'");
149 let transport = StreamableHttpClientTransport::with_client(auth_client, config.clone());
150 serve_client(params.mcp_client, transport).await.map_err(conn_err)
151 }
152 _ => {
153 let transport = StreamableHttpClientTransport::from_config(config.clone());
154 serve_client(params.mcp_client, transport).await.map_err(conn_err)
155 }
156 };
157
158 match result {
159 Ok(client) => ConnectResult::Connected(Self::from_parts(client, None)),
160 Err(err) => {
161 tracing::warn!("Failed to connect to MCP server '{name}': {err}");
162 if params.oauth_handler.is_some() {
163 ConnectResult::NeedsOAuth { name, config, error: err }
164 } else {
165 ConnectResult::Failed(err)
166 }
167 }
168 }
169 }
170}
171
172async fn create_auth_client(server_id: &str, base_url: &str) -> Option<AuthClient<reqwest::Client>> {
174 let auth_manager = create_auth_manager_from_store(server_id, base_url).await.ok()??;
175 Some(AuthClient::new(reqwest::Client::default(), auth_manager))
176}
177
178async fn serve_in_memory(
182 server: Box<dyn DynService<RoleServer>>,
183 mcp_client: McpClient,
184 label: &str,
185) -> Result<(RunningService<RoleClient, McpClient>, JoinHandle<()>)> {
186 let (client_transport, server_transport) = create_in_memory_transport();
187
188 let server_handle = tokio::spawn(async move {
189 match server.serve(server_transport).await {
190 Ok(_service) => {
191 std::future::pending::<()>().await;
192 }
193 Err(e) => {
194 eprintln!("MCP server error: {e}");
195 }
196 }
197 });
198
199 let client = serve_client(mcp_client, client_transport)
200 .await
201 .map_err(|e| McpError::ConnectionFailed(format!("Failed to connect to in-memory server '{label}': {e}")))?;
202
203 Ok((client, server_handle))
204}