1use anyhow::{Context, Result};
7use mcp_common::McpClientConfig;
8use rmcp::{
9 RoleClient, ServiceExt,
10 model::{ClientCapabilities, ClientInfo, Implementation},
11 service::RunningService,
12 transport::{SseClientTransport, sse_client::SseClientConfig},
13};
14use std::time::Instant;
15use tracing::{info, debug};
16
17use crate::sse_handler::SseHandler;
18use mcp_common::ToolFilter;
19
20pub struct SseClientConnection {
43 inner: RunningService<RoleClient, ClientInfo>,
44}
45
46impl SseClientConnection {
47 pub async fn connect(config: McpClientConfig) -> Result<Self> {
56 let start = Instant::now();
57 info!("🔗 开始建立 SSE 连接: {}", config.url);
58
59 debug!("构建 HTTP 客户端配置...");
60 let http_client = build_http_client(&config)?;
61
62 let sse_config = SseClientConfig {
63 sse_endpoint: config.url.clone().into(),
64 ..Default::default()
65 };
66
67 debug!("启动 SSE 传输层...");
68 let transport: SseClientTransport<reqwest::Client> =
69 SseClientTransport::start_with_client(http_client, sse_config)
70 .await
71 .context("Failed to start SSE transport")?;
72
73 let transport_elapsed = start.elapsed();
74 debug!("SSE 传输层启动完成,耗时: {:?}", transport_elapsed);
75
76 debug!("初始化 MCP 客户端握手...");
77 let client_info = create_default_client_info();
78 let running = client_info
79 .serve(transport)
80 .await
81 .context("Failed to initialize MCP client")?;
82
83 let total_elapsed = start.elapsed();
84 info!("✅ SSE 连接建立成功,总耗时: {:?} (传输层: {:?}, 握手: {:?})",
85 total_elapsed, transport_elapsed, total_elapsed - transport_elapsed);
86
87 Ok(Self { inner: running })
88 }
89
90 pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
92 let result = self.inner.list_tools(None).await?;
93 Ok(result
94 .tools
95 .into_iter()
96 .map(|t| ToolInfo {
97 name: t.name.to_string(),
98 description: t.description.map(|d| d.to_string()),
99 })
100 .collect())
101 }
102
103 pub fn is_closed(&self) -> bool {
105 use std::ops::Deref;
106 self.inner.deref().is_transport_closed()
107 }
108
109 pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
111 self.inner.peer_info()
112 }
113
114 pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> SseHandler {
123 SseHandler::with_tool_filter(self.inner, mcp_id, tool_filter)
124 }
125
126 pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
130 self.inner
131 }
132}
133
/// Minimal description of a tool exposed by an MCP server.
///
/// Equality compares both fields, so values can be used directly in
/// `assert_eq!` and in `==` comparisons.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolInfo {
    /// The tool's name as reported by the server.
    pub name: String,
    /// The tool's description, if the server provided one.
    pub description: Option<String>,
}
142
143fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
145 let mut headers = reqwest::header::HeaderMap::new();
146 for (key, value) in &config.headers {
147 let header_name = key
148 .parse::<reqwest::header::HeaderName>()
149 .with_context(|| format!("Invalid header name: {}", key))?;
150 let header_value = value
151 .parse()
152 .with_context(|| format!("Invalid header value for {}: {}", key, value))?;
153 headers.insert(header_name, header_value);
154 }
155
156 let mut builder = reqwest::Client::builder().default_headers(headers);
157
158 if let Some(timeout) = config.connect_timeout {
159 builder = builder.connect_timeout(timeout);
160 }
161
162 if let Some(timeout) = config.read_timeout {
163 builder = builder.timeout(timeout);
164 }
165
166 builder.build().context("Failed to build HTTP client")
167}
168
169fn create_default_client_info() -> ClientInfo {
171 ClientInfo {
172 protocol_version: Default::default(),
173 capabilities: ClientCapabilities::builder()
174 .enable_experimental()
175 .enable_roots()
176 .enable_roots_list_changed()
177 .enable_sampling()
178 .build(),
179 client_info: Implementation {
180 name: "mcp-sse-proxy-client".to_string(),
181 version: env!("CARGO_PKG_VERSION").to_string(),
182 title: None,
183 website_url: None,
184 icons: None,
185 },
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// A `ToolInfo` keeps the name and description it was built with.
    #[test]
    fn test_tool_info() {
        let ToolInfo { name, description } = ToolInfo {
            name: String::from("test_tool"),
            description: Some(String::from("A test tool")),
        };

        assert_eq!(name, "test_tool");
        assert_eq!(description.as_deref(), Some("A test tool"));
    }
}