use anyhow::{Context, Result};
use mcp_common::McpClientConfig;
use rmcp::{
RoleClient, ServiceExt,
model::{ClientCapabilities, ClientInfo, Implementation, ProtocolVersion},
service::RunningService,
transport::{
SseClientTransport, common::client_side_sse::SseRetryPolicy, sse_client::SseClientConfig,
},
};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use tracing::{debug, info};
use crate::sse_handler::SseHandler;
use mcp_common::ToolFilter;
/// Reconnect policy for the SSE transport: exponential backoff with an upper bound.
///
/// Retry number `n` (counting from 0) waits `base_duration * 2^n`, clamped to
/// `max_interval`.
#[derive(Clone, Debug)]
pub struct CappedExponentialBackoff {
    /// Number of retries allowed before giving up; `None` means retry indefinitely.
    pub max_times: Option<usize>,
    /// Delay before the first retry; each later retry doubles it.
    pub base_duration: Duration,
    /// Ceiling applied to every computed delay.
    pub max_interval: Duration,
}
impl CappedExponentialBackoff {
    /// Creates a backoff policy from its three tuning knobs.
    ///
    /// * `max_times` - retry limit, or `None` for unlimited retries.
    /// * `base_duration` - delay before the first retry.
    /// * `max_interval` - upper bound for any single delay.
    pub fn new(max_times: Option<usize>, base_duration: Duration, max_interval: Duration) -> Self {
        Self {
            base_duration,
            max_interval,
            max_times,
        }
    }
}
impl Default for CappedExponentialBackoff {
    /// Unlimited retries, starting at one second and capped at one minute.
    fn default() -> Self {
        let one_second = Duration::from_secs(1);
        Self {
            max_times: None,
            base_duration: one_second,
            max_interval: one_second * 60,
        }
    }
}
impl SseRetryPolicy for CappedExponentialBackoff {
    /// Returns the delay before retry number `current_times` (counting from 0), or `None`
    /// once `max_times` retries have been used.
    ///
    /// The delay is `base_duration * 2^current_times`, clamped to `max_interval`.
    /// Doubling is done step by step with saturating arithmetic. Arbitrarily large
    /// attempt counts therefore never overflow: the previous `2u32.pow(n)` panicked in
    /// debug builds from attempt 32 onward and wrapped to a zero delay in release builds.
    /// This matters because the default policy retries forever.
    fn retry(&self, current_times: usize) -> Option<Duration> {
        if self
            .max_times
            .is_some_and(|max_times| current_times >= max_times)
        {
            return None;
        }

        let mut delay = self.base_duration;
        // Stop doubling once the cap is reached (further doubling can only stay above it)
        // or when the delay is zero (it can never grow). This bounds the loop to roughly
        // a hundred iterations no matter how large `current_times` is.
        for _ in 0..current_times {
            if delay >= self.max_interval || delay.is_zero() {
                break;
            }
            delay = delay.saturating_mul(2);
        }
        Some(delay.min(self.max_interval))
    }
}
/// An established MCP client session over an SSE transport.
///
/// Construct it with `SseClientConnection::connect`.
pub struct SseClientConnection {
    /// The rmcp client service produced by the initialize handshake.
    inner: RunningService<RoleClient, ClientInfo>,
}
impl SseClientConnection {
    /// Opens an SSE transport to `config.url` and completes the MCP client handshake.
    ///
    /// The transport is configured to reconnect with [`CappedExponentialBackoff`]: no retry
    /// limit, a 1s initial delay, and a 60s cap.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// * the HTTP client cannot be built from `config` (e.g. an invalid header);
    /// * the SSE transport fails to start;
    /// * the MCP initialize handshake fails.
    pub async fn connect(config: McpClientConfig) -> Result<Self> {
        let start = Instant::now();
        info!("🔗 开始建立 SSE 连接: {}", config.url);

        debug!("构建 HTTP 客户端配置...");
        let http_client = build_http_client(&config)?;
        // Unlimited retries: delays go 1s, 2s, 4s, ... and are clamped at 60s.
        let retry_policy = CappedExponentialBackoff::new(
            None,
            Duration::from_secs(1),
            Duration::from_secs(60),
        );
        let sse_config = SseClientConfig {
            sse_endpoint: config.url.clone().into(),
            retry_policy: Arc::new(retry_policy),
            ..Default::default()
        };

        debug!("启动 SSE 传输层...");
        let transport: SseClientTransport<reqwest::Client> =
            SseClientTransport::start_with_client(http_client, sse_config)
                .await
                .context("Failed to start SSE transport")?;
        let transport_elapsed = start.elapsed();
        debug!("SSE 传输层启动完成,耗时: {:?}", transport_elapsed);

        debug!("初始化 MCP 客户端握手...");
        let client_info = create_default_client_info();
        let running = client_info
            .serve(transport)
            .await
            .context("Failed to initialize MCP client")?;
        let total_elapsed = start.elapsed();

        {
            use std::ops::Deref;
            let transport_closed = running.deref().is_transport_closed();
            let peer_info = running.peer_info();
            info!(
                "✅ SSE 连接建立成功 - 总耗时: {:?}, transport_closed: {}, peer_info: {:?}",
                total_elapsed, transport_closed, peer_info
            );
            if let Some(info) = peer_info {
                info!(
                    " 服务器信息: name={}, version={}, capabilities={:?}",
                    info.server_info.name, info.server_info.version, info.capabilities
                );
            }
        }

        Ok(Self { inner: running })
    }

    /// Lists every tool the server advertises, reduced to name and description.
    ///
    /// MCP `tools/list` responses are paginated. `list_all_tools` keeps requesting pages
    /// until the server stops returning a `next_cursor`. The previous single
    /// `list_tools(None)` call silently returned only the first page.
    ///
    /// # Errors
    ///
    /// Propagates any request error from the underlying MCP peer.
    pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
        let tools = self.inner.list_all_tools().await?;
        Ok(tools
            .into_iter()
            .map(|t| ToolInfo {
                name: t.name.to_string(),
                description: t.description.map(|d| d.to_string()),
            })
            .collect())
    }

    /// Reports whether the underlying transport is closed, via the peer's
    /// `is_transport_closed`.
    pub fn is_closed(&self) -> bool {
        use std::ops::Deref;
        self.inner.deref().is_transport_closed()
    }

    /// Returns the server information held by the peer, or `None` if none is available.
    pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
        self.inner.peer_info()
    }

    /// Consumes the connection and builds an [`SseHandler`] from the running service.
    ///
    /// The service is passed to `SseHandler::with_tool_filter` together with `mcp_id` and
    /// `tool_filter`.
    pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> SseHandler {
        SseHandler::with_tool_filter(self.inner, mcp_id, tool_filter)
    }

    /// Consumes the connection and returns the underlying rmcp running service.
    pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
        self.inner
    }
}
/// Name and description of a tool advertised by the remote MCP server.
///
/// Derives `PartialEq`/`Eq` so that whole values can be compared, e.g. in tests or when
/// diffing tool lists.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolInfo {
    /// Tool name as reported by the server.
    pub name: String,
    /// Optional human-readable description; `None` when the server provides none.
    pub description: Option<String>,
}
/// Builds the `reqwest` client used by the SSE transport from `config`.
///
/// Each entry of `config.headers` becomes a default header.
///
/// Timeouts:
/// * `connect_timeout` bounds connection establishment.
/// * `read_timeout` bounds each individual read. A long-lived SSE stream that keeps
///   delivering data is therefore not torn down merely because it has been open longer
///   than the timeout.
///
/// # Errors
///
/// Returns an error if a header name or value is invalid, or if the client fails to
/// build. Header values are deliberately not included in the error message because they
/// frequently carry credentials (e.g. `Authorization`).
fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
    let mut headers = reqwest::header::HeaderMap::new();
    for (key, value) in &config.headers {
        let header_name = key
            .parse::<reqwest::header::HeaderName>()
            .with_context(|| format!("Invalid header name: {}", key))?;
        let header_value = value
            .parse::<reqwest::header::HeaderValue>()
            .with_context(|| format!("Invalid header value for {}", key))?;
        headers.insert(header_name, header_value);
    }

    let mut builder = reqwest::Client::builder().default_headers(headers);
    if let Some(timeout) = config.connect_timeout {
        builder = builder.connect_timeout(timeout);
    }
    if let Some(timeout) = config.read_timeout {
        // Why not `ClientBuilder::timeout`? It is a total-request deadline that also covers
        // streaming the response body, so it would kill the SSE stream after a fixed time.
        // `read_timeout` only fires when no data arrives for the given duration.
        builder = builder.read_timeout(timeout);
    }
    builder.build().context("Failed to build HTTP client")
}
/// Client identity and capabilities sent when `connect` performs the MCP handshake.
///
/// Advertises protocol version 2024-11-05. It enables the experimental, roots (with
/// list-changed notifications), and sampling capabilities, and identifies itself as
/// `mcp-sse-proxy-client` at this crate's version.
fn create_default_client_info() -> ClientInfo {
    let capabilities = ClientCapabilities::builder()
        .enable_experimental()
        .enable_roots()
        .enable_roots_list_changed()
        .enable_sampling()
        .build();

    let implementation = Implementation {
        name: String::from("mcp-sse-proxy-client"),
        version: String::from(env!("CARGO_PKG_VERSION")),
        title: None,
        website_url: None,
        icons: None,
    };

    ClientInfo {
        protocol_version: ProtocolVersion::V_2024_11_05,
        capabilities,
        client_info: implementation,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for whole-second durations in the expectations below.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    /// Checks `policy.retry(attempt)` against each `(attempt, expected)` pair.
    fn assert_delays(policy: &CappedExponentialBackoff, cases: &[(usize, Option<Duration>)]) {
        for &(attempt, expected) in cases {
            assert_eq!(policy.retry(attempt), expected, "retry({attempt})");
        }
    }

    #[test]
    fn test_tool_info() {
        let tool = ToolInfo {
            name: String::from("test_tool"),
            description: Some(String::from("A test tool")),
        };
        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
    }

    #[test]
    fn test_capped_exponential_backoff() {
        let policy = CappedExponentialBackoff::new(None, secs(1), secs(60));
        assert_delays(
            &policy,
            &[
                (0, Some(secs(1))),
                (1, Some(secs(2))),
                (2, Some(secs(4))),
                (3, Some(secs(8))),
                (4, Some(secs(16))),
                (5, Some(secs(32))),
                (6, Some(secs(60))),
                (7, Some(secs(60))),
            ],
        );
    }

    #[test]
    fn test_capped_exponential_backoff_with_max_times() {
        let policy = CappedExponentialBackoff::new(Some(3), secs(1), secs(60));
        assert_delays(
            &policy,
            &[
                (0, Some(secs(1))),
                (1, Some(secs(2))),
                (2, Some(secs(4))),
                (3, None),
            ],
        );
    }

    #[test]
    fn test_capped_exponential_backoff_default() {
        let policy = CappedExponentialBackoff::default();
        assert_eq!(policy.max_times, None);
        assert_eq!(policy.base_duration, secs(1));
        assert_eq!(policy.max_interval, secs(60));
        assert_delays(
            &policy,
            &[(0, Some(secs(1))), (5, Some(secs(32))), (10, Some(secs(60)))],
        );
    }
}