use anyhow::{Context, Result};
use mcp_common::McpClientConfig;
use rmcp::{
RoleClient, ServiceExt,
model::{ClientCapabilities, ClientInfo, Implementation},
service::RunningService,
transport::{SseClientTransport, sse_client::SseClientConfig},
};
use std::time::Instant;
use tracing::{debug, info};
use crate::sse_handler::SseHandler;
use mcp_common::ToolFilter;
/// An established MCP client connection over an SSE transport.
///
/// Wraps the `RunningService` returned by the MCP handshake performed in
/// [`SseClientConnection::connect`].
pub struct SseClientConnection {
    /// The running client-role service obtained from `ClientInfo::serve`.
    inner: RunningService<RoleClient, ClientInfo>,
}
impl SseClientConnection {
    /// Opens an SSE transport to `config.url` and performs the MCP client handshake.
    ///
    /// The HTTP client is built from `config` (headers and timeouts), the SSE
    /// transport is started on top of it, and the default client info is then
    /// served over that transport. Elapsed time for the transport start and the
    /// handshake is logged.
    ///
    /// # Errors
    ///
    /// Returns an error if the HTTP client cannot be built, if the SSE transport
    /// fails to start, or if the MCP initialization fails.
    pub async fn connect(config: McpClientConfig) -> Result<Self> {
        let start = Instant::now();
        info!("🔗 开始建立 SSE 连接: {}", config.url);

        debug!("构建 HTTP 客户端配置...");
        let http_client = build_http_client(&config)?;
        let sse_config = SseClientConfig {
            sse_endpoint: config.url.clone().into(),
            ..Default::default()
        };

        debug!("启动 SSE 传输层...");
        let transport: SseClientTransport<reqwest::Client> =
            SseClientTransport::start_with_client(http_client, sse_config)
                .await
                .context("Failed to start SSE transport")?;
        let transport_elapsed = start.elapsed();
        debug!("SSE 传输层启动完成,耗时: {:?}", transport_elapsed);

        debug!("初始化 MCP 客户端握手...");
        let inner = create_default_client_info()
            .serve(transport)
            .await
            .context("Failed to initialize MCP client")?;

        // Both durations are measured from the same `start` instant, so the
        // handshake time is the difference between the two.
        let total_elapsed = start.elapsed();
        let handshake_elapsed = total_elapsed - transport_elapsed;
        info!(
            "✅ SSE 连接建立成功,总耗时: {:?} (传输层: {:?}, 握手: {:?})",
            total_elapsed, transport_elapsed, handshake_elapsed
        );

        Ok(Self { inner })
    }

    /// Returns the name and description of every tool exposed by the server.
    ///
    /// Uses rmcp's `list_all_tools`, which keeps requesting pages until the
    /// server stops returning a continuation cursor. A single
    /// `list_tools(None)` call would silently return only the first page when
    /// the server paginates its tool list.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tool-listing requests.
    pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
        let tools = self.inner.list_all_tools().await?;
        Ok(tools
            .into_iter()
            .map(|tool| ToolInfo {
                name: tool.name.to_string(),
                description: tool.description.map(|text| text.to_string()),
            })
            .collect())
    }

    /// Reports whether the underlying transport has been closed.
    pub fn is_closed(&self) -> bool {
        use std::ops::Deref;
        // Explicitly dereference the running service to reach the value that
        // exposes `is_transport_closed`.
        let peer = self.inner.deref();
        peer.is_transport_closed()
    }

    /// Returns the server info held by the running service, if any.
    pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
        self.inner.peer_info()
    }

    /// Consumes the connection and wraps its running service in an
    /// [`SseHandler`] configured with `mcp_id` and `tool_filter`.
    pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> SseHandler {
        let Self { inner } = self;
        SseHandler::with_tool_filter(inner, mcp_id, tool_filter)
    }

    /// Consumes the connection and returns the underlying running service.
    pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
        self.inner
    }
}
/// Owned summary of a tool: the fields `SseClientConnection::list_tools`
/// copies out of each tool reported by the server.
#[derive(Clone, Debug)]
pub struct ToolInfo {
    /// The tool's name.
    pub name: String,
    /// The tool's description, when one is provided.
    pub description: Option<String>,
}
/// Builds the `reqwest::Client` used for the SSE transport from `config`.
///
/// * Every `(name, value)` pair in `config.headers` becomes a default header.
/// * `config.connect_timeout`, when set, bounds connection establishment.
/// * `config.read_timeout`, when set, is applied as a per-read timeout.
///
/// The read timeout deliberately uses `ClientBuilder::read_timeout` rather than
/// `ClientBuilder::timeout`. The latter is a deadline for the whole request,
/// response body included; an SSE response body is an open-ended stream, so a
/// total deadline would cut the stream once it elapsed even while events are
/// still arriving. A per-read timeout only fires when the connection stalls.
///
/// # Errors
///
/// Returns an error if a header name or value cannot be parsed, or if the
/// client itself fails to build.
fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
    let mut headers = reqwest::header::HeaderMap::new();
    for (key, value) in &config.headers {
        let header_name = key
            .parse::<reqwest::header::HeaderName>()
            .with_context(|| format!("Invalid header name: {}", key))?;
        // NOTE(review): this message echoes the raw header value, which may be
        // a credential (e.g. an Authorization token) — confirm that is acceptable
        // wherever these errors are logged.
        let header_value = value
            .parse::<reqwest::header::HeaderValue>()
            .with_context(|| format!("Invalid header value for {}: {}", key, value))?;
        headers.insert(header_name, header_value);
    }

    let mut builder = reqwest::Client::builder().default_headers(headers);
    if let Some(timeout) = config.connect_timeout {
        builder = builder.connect_timeout(timeout);
    }
    if let Some(timeout) = config.read_timeout {
        // Per-read timeout: resets after each successful read, so a healthy
        // long-lived SSE stream is not terminated.
        builder = builder.read_timeout(timeout);
    }
    builder.build().context("Failed to build HTTP client")
}
/// Assembles the `ClientInfo` this proxy presents when serving the transport.
///
/// The capabilities enable experimental features, roots (including list-changed
/// notifications) and sampling. The implementation is reported as
/// `mcp-sse-proxy-client` at the crate's own package version, and the protocol
/// version is left at its default.
fn create_default_client_info() -> ClientInfo {
    let capabilities = ClientCapabilities::builder()
        .enable_experimental()
        .enable_roots()
        .enable_roots_list_changed()
        .enable_sampling()
        .build();

    let implementation = Implementation {
        name: String::from("mcp-sse-proxy-client"),
        version: String::from(env!("CARGO_PKG_VERSION")),
        title: None,
        website_url: None,
        icons: None,
    };

    ClientInfo {
        protocol_version: Default::default(),
        capabilities,
        client_info: implementation,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `ToolInfo` keeps the name and description it was constructed with.
    #[test]
    fn test_tool_info() {
        let tool = ToolInfo {
            name: String::from("test_tool"),
            description: Some(String::from("A test tool")),
        };

        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
    }
}