1use std::time::Duration;
8
9use anyhow::{Context, Result};
10use rmcp::ServiceExt;
11use rmcp::model::CallToolRequestParams;
12use rmcp::service::RunningService;
13use rmcp::transport::StreamableHttpClientTransport;
14use rmcp::transport::child_process::TokioChildProcess;
15use serde_json::Value;
16use tokio::process::Command;
17
18use super::config::{McpServerConfig, McpTransport};
19use super::tool_bridge::McpToolAnnotations;
20use crate::providers::ToolDefinition;
21use crate::tools::web_fetch::is_safe_url;
22
/// Lifecycle state of a single MCP server connection managed by `McpClient`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum McpClientStatus {
    /// No session: the client was just created or was explicitly disconnected.
    Disconnected,
    /// A connection attempt (handshake plus tool discovery) is in progress.
    Connecting,
    /// The handshake and tool discovery both completed successfully.
    Connected,
    /// The most recent connection attempt returned an error or timed out.
    Failed,
}
35
36#[derive(Debug, Clone)]
38pub struct DiscoveredTool {
39 pub definition: ToolDefinition,
41 pub annotations: McpToolAnnotations,
43 pub original_name: String,
45}
46
47pub struct McpClient {
49 name: String,
51 config: McpServerConfig,
53 service: Option<RunningService<rmcp::service::RoleClient, ()>>,
55 tools: Vec<DiscoveredTool>,
57 status: McpClientStatus,
59 last_error: Option<String>,
61 instructions: Option<String>,
65}
66
67impl McpClient {
68 pub fn new(name: String, config: McpServerConfig) -> Self {
70 Self {
71 name,
72 config,
73 service: None,
74 tools: Vec::new(),
75 status: McpClientStatus::Disconnected,
76 last_error: None,
77 instructions: None,
78 }
79 }
80
81 pub fn name(&self) -> &str {
83 &self.name
84 }
85
86 pub fn status(&self) -> McpClientStatus {
88 self.status
89 }
90
91 pub fn last_error(&self) -> Option<&str> {
93 self.last_error.as_deref()
94 }
95
96 pub fn tools(&self) -> &[DiscoveredTool] {
98 &self.tools
99 }
100
101 pub fn instructions(&self) -> Option<&str> {
104 self.instructions.as_deref()
105 }
106
107 pub async fn connect(&mut self) -> Result<()> {
111 self.status = McpClientStatus::Connecting;
112 self.last_error = None;
113
114 let timeout = Duration::from_secs(self.config.startup_timeout_sec);
115
116 match tokio::time::timeout(timeout, self.connect_inner()).await {
117 Ok(Ok(())) => {
118 let transport_label = match &self.config.transport {
119 McpTransport::Stdio { .. } => "stdio",
120 McpTransport::Http { .. } => "http",
121 };
122 self.status = McpClientStatus::Connected;
123 tracing::info!(
124 server = %self.name,
125 transport = transport_label,
126 tools = self.tools.len(),
127 "MCP server connected"
128 );
129 Ok(())
130 }
131 Ok(Err(e)) => {
132 self.status = McpClientStatus::Failed;
133 self.last_error = Some(e.to_string());
134 tracing::warn!(
135 server = %self.name,
136 error = %e,
137 "MCP server connection failed"
138 );
139 Err(e)
140 }
141 Err(_) => {
142 self.status = McpClientStatus::Failed;
143 let msg = format!(
144 "MCP server '{}' startup timed out after {}s",
145 self.name, self.config.startup_timeout_sec
146 );
147 self.last_error = Some(msg.clone());
148 tracing::warn!(server = %self.name, "{msg}");
149 Err(anyhow::anyhow!(msg))
150 }
151 }
152 }
153
154 async fn connect_inner(&mut self) -> Result<()> {
156 let transport = self.config.transport.clone();
158 match transport {
159 McpTransport::Stdio {
160 ref command,
161 ref args,
162 ref env,
163 ref cwd,
164 } => self.connect_stdio(command, args, env, cwd.as_deref()).await,
165 McpTransport::Http {
166 ref url,
167 ref bearer_token,
168 ref headers,
169 } => {
170 self.connect_http(url, bearer_token.as_deref(), headers)
171 .await
172 }
173 }
174 }
175
176 async fn connect_stdio(
178 &mut self,
179 command: &str,
180 args: &[String],
181 env: &std::collections::HashMap<String, String>,
182 cwd: Option<&str>,
183 ) -> Result<()> {
184 let mut cmd = Command::new(command);
185 cmd.args(args);
186 for (key, val) in env {
187 cmd.env(key, val);
188 }
189 if let Some(cwd) = cwd {
190 cmd.current_dir(cwd);
191 }
192
193 let transport =
194 TokioChildProcess::new(cmd).context("failed to spawn MCP server process")?;
195 let service = ().serve(transport).await.context("MCP handshake failed")?;
196 self.finish_handshake(service).await
197 }
198
199 async fn connect_http(
201 &mut self,
202 url: &str,
203 bearer_token: Option<&str>,
204 headers: &std::collections::HashMap<String, String>,
205 ) -> Result<()> {
206 use http::{HeaderName, HeaderValue};
207 use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
208
209 if !is_safe_url(url) {
211 anyhow::bail!(
212 "MCP HTTP URL '{url}' is not allowed: private, loopback, or link-local \
213 addresses are blocked to prevent SSRF attacks"
214 );
215 }
216
217 if bearer_token.is_some() && url.starts_with("http://") {
219 tracing::warn!(
220 server = %self.name,
221 url = %url,
222 "MCP bearer token is being sent over plaintext HTTP — use HTTPS in production"
223 );
224 }
225
226 let mut config = StreamableHttpClientTransportConfig::with_uri(url);
227
228 if let Some(token) = bearer_token {
232 config.auth_header = Some(token.to_string());
233 }
234
235 if !headers.is_empty() {
237 let mut header_map = std::collections::HashMap::new();
238 for (k, v) in headers {
239 let name = HeaderName::try_from(k.as_str())
240 .with_context(|| format!("invalid HTTP header name: {k}"))?;
241 let value = HeaderValue::try_from(v.as_str())
242 .with_context(|| format!("invalid HTTP header value for {k}"))?;
243 header_map.insert(name, value);
244 }
245 config.custom_headers = header_map;
246 }
247
248 config.reinit_on_expired_session = true;
250
251 let transport = StreamableHttpClientTransport::from_config(config);
252 let service = ().serve(transport).await.context("MCP HTTP handshake failed")?;
253 self.finish_handshake(service).await
254 }
255
256 async fn finish_handshake(
259 &mut self,
260 service: RunningService<rmcp::service::RoleClient, ()>,
261 ) -> Result<()> {
262 self.instructions = service
265 .peer()
266 .peer_info()
267 .and_then(|info| info.instructions.clone())
268 .filter(|s| !s.trim().is_empty());
269 self.service = Some(service);
270 self.discover_tools().await
271 }
272
273 async fn discover_tools(&mut self) -> Result<()> {
275 let service = self.service.as_ref().context("not connected")?;
276
277 let result = service
278 .list_tools(Default::default())
279 .await
280 .context("failed to list MCP tools")?;
281
282 self.tools.clear();
283
284 for tool in result.tools {
285 let tool_name: &str = &tool.name;
286
287 if !self.config.is_tool_allowed(tool_name) {
289 tracing::debug!(
290 server = %self.name,
291 tool = %tool_name,
292 "MCP tool filtered out by config"
293 );
294 continue;
295 }
296
297 let (definition, annotations) =
298 super::tool_bridge::mcp_tool_to_definition(&self.name, &tool);
299
300 self.tools.push(DiscoveredTool {
301 definition,
302 annotations,
303 original_name: tool_name.to_string(),
304 });
305 }
306
307 Ok(())
308 }
309
310 pub async fn call_tool(
315 &self,
316 tool_name: &str,
317 arguments: Value,
318 ) -> Result<rmcp::model::CallToolResult> {
319 let service = self.service.as_ref().context("MCP server not connected")?;
320
321 let timeout = Duration::from_secs(self.config.tool_timeout_sec);
322
323 let mut params = CallToolRequestParams::new(tool_name.to_string());
324 if let Value::Object(map) = arguments {
325 params.arguments = Some(map);
326 }
327
328 let result = tokio::time::timeout(timeout, service.call_tool(params))
329 .await
330 .map_err(|_| {
331 anyhow::anyhow!(
332 "MCP tool call '{}' on server '{}' timed out after {}s",
333 tool_name,
334 self.name,
335 self.config.tool_timeout_sec
336 )
337 })?
338 .context("MCP tool call failed")?;
339
340 Ok(result)
341 }
342
343 pub async fn disconnect(&mut self) {
345 if let Some(service) = self.service.take() {
346 drop(service);
348 }
349 self.tools.clear();
350 self.status = McpClientStatus::Disconnected;
351 self.last_error = None;
352 tracing::info!(server = %self.name, "MCP server disconnected");
353 }
354}
355
356impl McpClient {
357 #[cfg(feature = "test-support")]
359 pub fn set_status_for_test(&mut self, status: McpClientStatus) {
360 self.status = status;
361 }
362
363 #[cfg(feature = "test-support")]
365 pub fn set_instructions_for_test(&mut self, instructions: Option<String>) {
366 self.instructions = instructions;
367 }
368
369 #[cfg(feature = "test-support")]
371 pub fn set_last_error_for_test(&mut self, err: Option<String>) {
372 self.last_error = err;
373 }
374
375 #[cfg(feature = "test-support")]
381 pub fn set_tools_for_test(&mut self, tools: Vec<DiscoveredTool>) {
382 self.tools = tools;
383 }
384}
385
386impl Drop for McpClient {
387 fn drop(&mut self) {
388 if self.service.is_some() {
390 tracing::debug!(server = %self.name, "McpClient dropped while still connected");
391 }
392 }
393}