use std::time::Duration;
use anyhow::{Context, Result};
use rmcp::ServiceExt;
use rmcp::model::CallToolRequestParams;
use rmcp::service::RunningService;
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::child_process::TokioChildProcess;
use serde_json::Value;
use tokio::process::Command;
use super::config::{McpServerConfig, McpTransport};
use super::tool_bridge::McpToolAnnotations;
use crate::providers::ToolDefinition;
use crate::tools::web_fetch::is_safe_url;
/// Lifecycle state of an [`McpClient`] connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum McpClientStatus {
    /// Initial state, and the state after `disconnect`.
    Disconnected,
    /// A `connect` call is in progress.
    Connecting,
    /// The last `connect` call completed the handshake and tool discovery.
    Connected,
    /// The last `connect` call returned an error or timed out.
    Failed,
}
/// A tool reported by an MCP server that passed the server's tool filter.
#[derive(Clone, Debug)]
pub struct DiscoveredTool {
    /// Provider-facing definition produced by `tool_bridge::mcp_tool_to_definition`.
    pub definition: ToolDefinition,
    /// Annotations produced alongside `definition` by the tool bridge.
    pub annotations: McpToolAnnotations,
    /// The tool name exactly as the server reported it.
    pub original_name: String,
}
/// Client for a single configured MCP server, reached over either a stdio
/// child process or a streamable-HTTP transport.
pub struct McpClient {
// Server name from configuration; used in log fields and passed to the tool bridge.
name: String,
// Transport settings, timeouts, and the tool allow-list for this server.
config: McpServerConfig,
// rmcp session handle; set once a handshake succeeds and taken by `disconnect`.
service: Option<RunningService<rmcp::service::RoleClient, ()>>,
// Tools collected by the most recent tool discovery, after config filtering.
tools: Vec<DiscoveredTool>,
// Current lifecycle state, see `McpClientStatus`.
status: McpClientStatus,
// Message from the last failed connect attempt, if any.
last_error: Option<String>,
}
impl McpClient {
    /// Creates a disconnected client for the server called `name`.
    ///
    /// Nothing is spawned or contacted until [`McpClient::connect`] is awaited.
    pub fn new(name: String, config: McpServerConfig) -> Self {
        Self {
            name,
            config,
            service: None,
            tools: Vec::new(),
            status: McpClientStatus::Disconnected,
            last_error: None,
        }
    }

    /// Returns the configured server name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the current connection status.
    pub fn status(&self) -> McpClientStatus {
        self.status
    }

    /// Returns the message from the most recent failed connection attempt.
    ///
    /// Cleared when a new attempt starts and on [`McpClient::disconnect`].
    pub fn last_error(&self) -> Option<&str> {
        self.last_error.as_deref()
    }

    /// Returns the tools found by the last successful connection, already
    /// filtered through the config's `is_tool_allowed`.
    pub fn tools(&self) -> &[DiscoveredTool] {
        &self.tools
    }

    /// Connects to the configured server and discovers its tools.
    ///
    /// The handshake and tool discovery together are bounded by
    /// `startup_timeout_sec`. On success the status becomes
    /// [`McpClientStatus::Connected`]. On error or timeout the status becomes
    /// [`McpClientStatus::Failed`], `last_error` is populated, and any
    /// partially established session is torn down. This ensures a failed
    /// client never keeps a live service or stale tools.
    ///
    /// # Errors
    /// Returns the spawn, URL-validation, handshake, or tool-listing error, or
    /// a timeout error if startup exceeds `startup_timeout_sec`.
    pub async fn connect(&mut self) -> Result<()> {
        self.status = McpClientStatus::Connecting;
        self.last_error = None;
        let timeout = Duration::from_secs(self.config.startup_timeout_sec);
        match tokio::time::timeout(timeout, self.connect_inner()).await {
            Ok(Ok(())) => {
                self.status = McpClientStatus::Connected;
                tracing::info!(
                    server = %self.name,
                    transport = self.transport_label(),
                    tools = self.tools.len(),
                    "MCP server connected"
                );
                Ok(())
            }
            Ok(Err(e)) => {
                // `{:#}` renders the full anyhow context chain (e.g.
                // "MCP handshake failed: <cause>"), not just the outer context.
                let detail = format!("{e:#}");
                tracing::warn!(
                    server = %self.name,
                    error = %detail,
                    "MCP server connection failed"
                );
                self.mark_failed(detail);
                Err(e)
            }
            Err(_) => {
                let msg = format!(
                    "MCP server '{}' startup timed out after {}s",
                    self.name, self.config.startup_timeout_sec
                );
                tracing::warn!(server = %self.name, "{msg}");
                self.mark_failed(msg.clone());
                Err(anyhow::anyhow!(msg))
            }
        }
    }

    /// Short transport name used in log fields.
    fn transport_label(&self) -> &'static str {
        match &self.config.transport {
            McpTransport::Stdio { .. } => "stdio",
            McpTransport::Http { .. } => "http",
        }
    }

    /// Records a failed connection attempt and discards any partial session.
    fn mark_failed(&mut self, error: String) {
        self.status = McpClientStatus::Failed;
        self.last_error = Some(error);
        // The handshake may have succeeded (storing `service`) before tool
        // discovery failed or the startup timeout fired. Drop that session so
        // a `Failed` client cannot still serve `call_tool`, and so no stale
        // tools remain from a previous connection.
        self.service = None;
        self.tools.clear();
    }

    /// Dispatches to the transport-specific connect routine.
    async fn connect_inner(&mut self) -> Result<()> {
        // Cloned so the settings can be borrowed while `self` is borrowed mutably.
        let transport = self.config.transport.clone();
        match &transport {
            McpTransport::Stdio {
                command,
                args,
                env,
                cwd,
            } => {
                self.connect_stdio(command, args, env, cwd.as_deref())
                    .await
            }
            McpTransport::Http {
                url,
                bearer_token,
                headers,
            } => {
                self.connect_http(url, bearer_token.as_deref(), headers)
                    .await
            }
        }
    }

    /// Spawns `command` with `args`, the extra `env` variables and an
    /// optional working directory, performs the MCP handshake over the
    /// child's stdio, then discovers tools.
    async fn connect_stdio(
        &mut self,
        command: &str,
        args: &[String],
        env: &std::collections::HashMap<String, String>,
        cwd: Option<&str>,
    ) -> Result<()> {
        let mut cmd = Command::new(command);
        cmd.args(args);
        for (key, val) in env {
            cmd.env(key, val);
        }
        if let Some(cwd) = cwd {
            cmd.current_dir(cwd);
        }
        let transport =
            TokioChildProcess::new(cmd).context("failed to spawn MCP server process")?;
        let service = ().serve(transport).await.context("MCP handshake failed")?;
        self.service = Some(service);
        self.discover_tools().await
    }

    /// Checks `url` with `is_safe_url` (SSRF guard) and builds a
    /// streamable-HTTP transport with optional bearer auth and custom headers.
    /// It then performs the MCP handshake and discovers tools.
    ///
    /// # Errors
    /// Fails if the URL is rejected, a header name or value is invalid, or
    /// the handshake or tool listing fails.
    async fn connect_http(
        &mut self,
        url: &str,
        bearer_token: Option<&str>,
        headers: &std::collections::HashMap<String, String>,
    ) -> Result<()> {
        use http::{HeaderName, HeaderValue};
        use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;

        if !is_safe_url(url) {
            anyhow::bail!(
                "MCP HTTP URL '{url}' is not allowed: private, loopback, or link-local \
                 addresses are blocked to prevent SSRF attacks"
            );
        }

        // URI schemes are case-insensitive (RFC 3986 §3.1), so "HTTP://" is
        // just as unencrypted as "http://". `get` avoids panicking on a
        // non-char-boundary slice.
        let plaintext = matches!(
            url.get(..7),
            Some(scheme) if scheme.eq_ignore_ascii_case("http://")
        );
        if bearer_token.is_some() && plaintext {
            tracing::warn!(
                server = %self.name,
                url = %url,
                "MCP bearer token is being sent over plaintext HTTP — use HTTPS in production"
            );
        }

        let mut config = StreamableHttpClientTransportConfig::with_uri(url);
        if let Some(token) = bearer_token {
            config.auth_header = Some(token.to_string());
        }
        if !headers.is_empty() {
            let mut header_map = std::collections::HashMap::with_capacity(headers.len());
            for (k, v) in headers {
                let name = HeaderName::try_from(k.as_str())
                    .with_context(|| format!("invalid HTTP header name: {k}"))?;
                let value = HeaderValue::try_from(v.as_str())
                    .with_context(|| format!("invalid HTTP header value for {k}"))?;
                header_map.insert(name, value);
            }
            config.custom_headers = header_map;
        }
        // NOTE(review): presumably lets rmcp redo initialization when the
        // server reports the session expired — confirm against rmcp docs.
        config.reinit_on_expired_session = true;

        let transport = StreamableHttpClientTransport::from_config(config);
        let service = ().serve(transport)
            .await
            .context("MCP HTTP handshake failed")?;
        self.service = Some(service);
        self.discover_tools().await
    }

    /// Lists the server's tools and replaces `self.tools` with the ones the
    /// config allows, converting each through the tool bridge.
    async fn discover_tools(&mut self) -> Result<()> {
        let service = self.service.as_ref().context("not connected")?;
        let result = service
            .list_tools(Default::default())
            .await
            .context("failed to list MCP tools")?;
        self.tools.clear();
        for tool in result.tools {
            let tool_name: &str = &tool.name;
            if !self.config.is_tool_allowed(tool_name) {
                tracing::debug!(
                    server = %self.name,
                    tool = %tool_name,
                    "MCP tool filtered out by config"
                );
                continue;
            }
            let (definition, annotations) =
                super::tool_bridge::mcp_tool_to_definition(&self.name, &tool);
            self.tools.push(DiscoveredTool {
                definition,
                annotations,
                original_name: tool_name.to_string(),
            });
        }
        Ok(())
    }

    /// Invokes `tool_name` on the connected server, bounded by `tool_timeout_sec`.
    ///
    /// `arguments` must be a JSON object, or `null` for "no arguments".
    ///
    /// # Errors
    /// Fails if the client is not connected, `arguments` is any other JSON
    /// type, the call times out, or the server call itself fails.
    pub async fn call_tool(
        &self,
        tool_name: &str,
        arguments: Value,
    ) -> Result<rmcp::model::CallToolResult> {
        let service = self.service.as_ref().context("MCP server not connected")?;
        let timeout = Duration::from_secs(self.config.tool_timeout_sec);

        // Previously non-object arguments were dropped silently and the tool
        // was invoked with no arguments at all; reject them explicitly.
        let arguments = match arguments {
            Value::Object(map) => Some(map),
            Value::Null => None,
            _ => anyhow::bail!(
                "arguments for MCP tool '{}' on server '{}' must be a JSON object",
                tool_name,
                self.name
            ),
        };

        let mut params = CallToolRequestParams::new(tool_name.to_string());
        if let Some(map) = arguments {
            params.arguments = Some(map);
        }

        let result = tokio::time::timeout(timeout, service.call_tool(params))
            .await
            .map_err(|_| {
                anyhow::anyhow!(
                    "MCP tool call '{}' on server '{}' timed out after {}s",
                    tool_name,
                    self.name,
                    self.config.tool_timeout_sec
                )
            })?
            .context("MCP tool call failed")?;
        Ok(result)
    }

    /// Shuts down the session (if any) and resets the client to
    /// [`McpClientStatus::Disconnected`].
    ///
    /// Awaits rmcp's graceful cancellation for up to a few seconds. If that
    /// does not finish in time, the cancellation has already been signalled
    /// and the remaining cleanup continues in the background.
    pub async fn disconnect(&mut self) {
        const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
        if let Some(service) = self.service.take() {
            match tokio::time::timeout(SHUTDOWN_TIMEOUT, service.cancel()).await {
                Ok(Ok(_)) => {}
                Ok(Err(e)) => {
                    tracing::warn!(
                        server = %self.name,
                        error = %e,
                        "MCP service task failed during shutdown"
                    );
                }
                Err(_) => {
                    tracing::warn!(
                        server = %self.name,
                        "MCP server did not shut down within {}s",
                        SHUTDOWN_TIMEOUT.as_secs()
                    );
                }
            }
        }
        self.tools.clear();
        self.status = McpClientStatus::Disconnected;
        self.last_error = None;
        tracing::info!(server = %self.name, "MCP server disconnected");
    }
}
/// Hooks that let tests put a client into arbitrary states without a server.
#[cfg(feature = "test-support")]
impl McpClient {
    /// Overwrites the connection status directly.
    pub fn set_status_for_test(&mut self, status: McpClientStatus) {
        self.status = status;
    }

    /// Overwrites, or clears with `None`, the recorded connection error.
    pub fn set_last_error_for_test(&mut self, err: Option<String>) {
        self.last_error = err;
    }
}
impl Drop for McpClient {
    /// Emits a debug trace when the client is dropped while it still holds
    /// a service session.
    fn drop(&mut self) {
        if self.service.is_none() {
            return;
        }
        tracing::debug!(server = %self.name, "McpClient dropped while still connected");
    }
}