mcp_sse_proxy/
client.rs

//! SSE client connection module.
//!
//! Provides a high-level API for connecting to MCP servers over the SSE transport.
//! The rmcp transport details are encapsulated here behind a small, simple interface.
5
6use anyhow::{Context, Result};
7use mcp_common::McpClientConfig;
8use rmcp::{
9    RoleClient, ServiceExt,
10    model::{ClientCapabilities, ClientInfo, Implementation},
11    service::RunningService,
12    transport::{SseClientTransport, sse_client::SseClientConfig},
13};
14use std::time::Instant;
15use tracing::{info, debug};
16
17use crate::sse_handler::SseHandler;
18use mcp_common::ToolFilter;
19
20/// Opaque wrapper for SSE client connection
21///
22/// This type encapsulates an active connection to an MCP server via SSE protocol.
23/// It hides the internal `RunningService` type and provides only the methods
24/// needed by consuming code.
25///
26/// Note: This type is not Clone because the underlying RunningService
27/// is designed for single-owner use. Use `into_handler()` or `into_running_service()`
28/// to consume the connection.
29///
30/// # Example
31///
32/// ```rust,ignore
33/// use mcp_sse_proxy::{SseClientConnection, McpClientConfig};
34///
35/// let config = McpClientConfig::new("http://localhost:8080/sse")
36///     .with_header("Authorization", "Bearer token");
37///
38/// let conn = SseClientConnection::connect(config).await?;
39/// let tools = conn.list_tools().await?;
40/// println!("Available tools: {:?}", tools);
41/// ```
42pub struct SseClientConnection {
43    inner: RunningService<RoleClient, ClientInfo>,
44}
45
46impl SseClientConnection {
47    /// Connect to an SSE MCP server
48    ///
49    /// # Arguments
50    /// * `config` - Client configuration including URL and headers
51    ///
52    /// # Returns
53    /// * `Ok(SseClientConnection)` - Successfully connected client
54    /// * `Err` - Connection failed
55    pub async fn connect(config: McpClientConfig) -> Result<Self> {
56        let start = Instant::now();
57        info!("🔗 开始建立 SSE 连接: {}", config.url);
58        
59        debug!("构建 HTTP 客户端配置...");
60        let http_client = build_http_client(&config)?;
61
62        let sse_config = SseClientConfig {
63            sse_endpoint: config.url.clone().into(),
64            ..Default::default()
65        };
66
67        debug!("启动 SSE 传输层...");
68        let transport: SseClientTransport<reqwest::Client> =
69            SseClientTransport::start_with_client(http_client, sse_config)
70                .await
71                .context("Failed to start SSE transport")?;
72        
73        let transport_elapsed = start.elapsed();
74        debug!("SSE 传输层启动完成,耗时: {:?}", transport_elapsed);
75
76        debug!("初始化 MCP 客户端握手...");
77        let client_info = create_default_client_info();
78        let running = client_info
79            .serve(transport)
80            .await
81            .context("Failed to initialize MCP client")?;
82
83        let total_elapsed = start.elapsed();
84        info!("✅ SSE 连接建立成功,总耗时: {:?} (传输层: {:?}, 握手: {:?})", 
85              total_elapsed, transport_elapsed, total_elapsed - transport_elapsed);
86
87        Ok(Self { inner: running })
88    }
89
90    /// List available tools from the MCP server
91    pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
92        let result = self.inner.list_tools(None).await?;
93        Ok(result
94            .tools
95            .into_iter()
96            .map(|t| ToolInfo {
97                name: t.name.to_string(),
98                description: t.description.map(|d| d.to_string()),
99            })
100            .collect())
101    }
102
103    /// Check if the connection is closed
104    pub fn is_closed(&self) -> bool {
105        use std::ops::Deref;
106        self.inner.deref().is_transport_closed()
107    }
108
109    /// Get the peer info from the server
110    pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
111        self.inner.peer_info()
112    }
113
114    /// Convert this connection into an SseHandler for serving
115    ///
116    /// This consumes the connection and creates an SseHandler that can
117    /// proxy requests to the backend MCP server.
118    ///
119    /// # Arguments
120    /// * `mcp_id` - Identifier for logging purposes
121    /// * `tool_filter` - Tool filtering configuration
122    pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> SseHandler {
123        SseHandler::with_tool_filter(self.inner, mcp_id, tool_filter)
124    }
125
126    /// Extract the internal RunningService for use with swap_backend
127    ///
128    /// This is used internally to support backend hot-swapping.
129    pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
130        self.inner
131    }
132}
133
/// Simplified, owned view of a tool advertised by an MCP server.
///
/// Only the tool's name and its optional description are kept.
/// Values can be cloned, debug-printed and compared for equality.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolInfo {
    /// Tool name as reported by the server.
    pub name: String,
    /// Human-readable tool description, if the server provided one.
    pub description: Option<String>,
}
142
143/// Build an HTTP client with the given configuration
144fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
145    let mut headers = reqwest::header::HeaderMap::new();
146    for (key, value) in &config.headers {
147        let header_name = key
148            .parse::<reqwest::header::HeaderName>()
149            .with_context(|| format!("Invalid header name: {}", key))?;
150        let header_value = value
151            .parse()
152            .with_context(|| format!("Invalid header value for {}: {}", key, value))?;
153        headers.insert(header_name, header_value);
154    }
155
156    let mut builder = reqwest::Client::builder().default_headers(headers);
157
158    if let Some(timeout) = config.connect_timeout {
159        builder = builder.connect_timeout(timeout);
160    }
161
162    if let Some(timeout) = config.read_timeout {
163        builder = builder.timeout(timeout);
164    }
165
166    builder.build().context("Failed to build HTTP client")
167}
168
169/// Create default client info for MCP handshake
170fn create_default_client_info() -> ClientInfo {
171    ClientInfo {
172        protocol_version: Default::default(),
173        capabilities: ClientCapabilities::builder()
174            .enable_experimental()
175            .enable_roots()
176            .enable_roots_list_changed()
177            .enable_sampling()
178            .build(),
179        client_info: Implementation {
180            name: "mcp-sse-proxy-client".to_string(),
181            version: env!("CARGO_PKG_VERSION").to_string(),
182            title: None,
183            website_url: None,
184            icons: None,
185        },
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// `ToolInfo` is a plain data carrier: its fields hold exactly what was set.
    #[test]
    fn test_tool_info() {
        let expected_name = "test_tool";
        let expected_description = "A test tool";

        let info = ToolInfo {
            name: String::from(expected_name),
            description: Some(String::from(expected_description)),
        };

        assert_eq!(info.name, expected_name);
        assert_eq!(info.description.as_deref(), Some(expected_description));
    }
}