Skip to main content

mcp_streamable_proxy/
client.rs

1//! Streamable HTTP Client Connection Module
2//!
3//! Provides a high-level API for connecting to MCP servers via Streamable HTTP protocol.
4//! This module encapsulates the rmcp 0.12 transport details and exposes a simple interface.
5
6use anyhow::{Context, Result};
7use mcp_common::McpClientConfig;
8use rmcp::{
9    RoleClient, ServiceExt,
10    model::{ClientCapabilities, ClientInfo, Implementation},
11    service::RunningService,
12    transport::{
13        common::client_side_sse::SseRetryPolicy,
14        streamable_http_client::{
15            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
16        },
17    },
18};
19use std::sync::Arc;
20use std::time::Duration;
21
22use crate::proxy_handler::ProxyHandler;
23use mcp_common::ToolFilter;
24
25/// 自定义的指数退避重试策略,支持最大间隔限制
26///
27/// 重试间隔按照指数增长,但不会超过 max_interval
28/// - 第 1 次重试:base_duration × 2^0
29/// - 第 2 次重试:base_duration × 2^1
30/// - ...
31/// - 第 n 次重试:min(base_duration × 2^(n-1), max_interval)
32#[derive(Debug, Clone)]
33pub struct CappedExponentialBackoff {
34    /// 最大重试次数,None 表示无限制
35    pub max_times: Option<usize>,
36    /// 基础延迟时间(第一次重试前的等待时间)
37    pub base_duration: Duration,
38    /// 最大延迟间隔(重试间隔不会超过这个值)
39    pub max_interval: Duration,
40}
41
42impl CappedExponentialBackoff {
43    /// 创建一个新的带上限的指数退避策略
44    ///
45    /// # Arguments
46    /// * `max_times` - 最大重试次数,None 表示无限制
47    /// * `base_duration` - 基础延迟时间
48    /// * `max_interval` - 最大延迟间隔
49    pub fn new(max_times: Option<usize>, base_duration: Duration, max_interval: Duration) -> Self {
50        Self {
51            max_times,
52            base_duration,
53            max_interval,
54        }
55    }
56}
57
58impl Default for CappedExponentialBackoff {
59    fn default() -> Self {
60        Self {
61            max_times: None,
62            base_duration: Duration::from_secs(1),
63            max_interval: Duration::from_secs(60),
64        }
65    }
66}
67
68impl SseRetryPolicy for CappedExponentialBackoff {
69    fn retry(&self, current_times: usize) -> Option<Duration> {
70        // 检查是否超过最大重试次数
71        if let Some(max_times) = self.max_times
72            && current_times >= max_times
73        {
74            return None;
75        }
76
77        // 计算指数退避时间
78        let exponential_delay = self.base_duration * (2u32.pow(current_times as u32));
79
80        // 限制最大间隔
81        Some(exponential_delay.min(self.max_interval))
82    }
83}
84
85/// Opaque wrapper for Streamable HTTP client connection
86///
87/// This type encapsulates an active connection to an MCP server via Streamable HTTP protocol.
88/// It hides the internal `RunningService` type and provides only the methods
89/// needed by consuming code.
90///
91/// Note: This type is not Clone because the underlying RunningService
92/// is designed for single-owner use. Use `into_handler()` or `into_running_service()`
93/// to consume the connection.
94///
95/// # Example
96///
97/// ```rust,ignore
98/// use mcp_streamable_proxy::{StreamClientConnection, McpClientConfig};
99///
100/// let config = McpClientConfig::new("http://localhost:8080/mcp")
101///     .with_header("Authorization", "Bearer token");
102///
103/// let conn = StreamClientConnection::connect(config).await?;
104/// let tools = conn.list_tools().await?;
105/// println!("Available tools: {:?}", tools);
106/// ```
107pub struct StreamClientConnection {
108    inner: RunningService<RoleClient, ClientInfo>,
109}
110
111impl StreamClientConnection {
112    /// Connect to a Streamable HTTP MCP server
113    ///
114    /// # Arguments
115    /// * `config` - Client configuration including URL and headers
116    ///
117    /// # Returns
118    /// * `Ok(StreamClientConnection)` - Successfully connected client
119    /// * `Err` - Connection failed
120    pub async fn connect(config: McpClientConfig) -> Result<Self> {
121        let http_client = build_http_client(&config)?;
122
123        // 配置指数退避重试策略,最大间隔 1 分钟,不限制重试次数
124        let retry_policy = CappedExponentialBackoff::new(
125            None,                    // 不限制重试次数
126            Duration::from_secs(1),  // 基础延迟 1 秒
127            Duration::from_secs(60), // 最大间隔 60 秒
128        );
129
130        let transport_config = StreamableHttpClientTransportConfig {
131            uri: config.url.clone().into(),
132            retry_config: Arc::new(retry_policy),
133            ..Default::default()
134        };
135
136        let transport = StreamableHttpClientTransport::with_client(http_client, transport_config);
137
138        let client_info = create_default_client_info();
139        let running = client_info
140            .serve(transport)
141            .await
142            .context("Failed to initialize MCP client")?;
143
144        Ok(Self { inner: running })
145    }
146
147    /// List available tools from the MCP server
148    pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
149        let result = self.inner.list_tools(None).await?;
150        Ok(result
151            .tools
152            .into_iter()
153            .map(|t| ToolInfo {
154                name: t.name.to_string(),
155                description: t.description.map(|d| d.to_string()),
156            })
157            .collect())
158    }
159
160    /// Check if the connection is closed
161    pub fn is_closed(&self) -> bool {
162        use std::ops::Deref;
163        self.inner.deref().is_transport_closed()
164    }
165
166    /// Get the peer info from the server
167    pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
168        self.inner.peer_info()
169    }
170
171    /// Convert this connection into a ProxyHandler for serving
172    ///
173    /// This consumes the connection and creates a ProxyHandler that can
174    /// proxy requests to the backend MCP server.
175    ///
176    /// # Arguments
177    /// * `mcp_id` - Identifier for logging purposes
178    /// * `tool_filter` - Tool filtering configuration
179    pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> ProxyHandler {
180        ProxyHandler::with_tool_filter(self.inner, mcp_id, tool_filter)
181    }
182
183    /// Extract the internal RunningService for use with swap_backend
184    ///
185    /// This is used internally to support backend hot-swapping.
186    pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
187        self.inner
188    }
189}
190
/// Simplified, owned view of a tool advertised by an MCP server.
///
/// Derives `PartialEq`/`Eq` so tool lists can be compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolInfo {
    /// Tool name.
    pub name: String,
    /// Tool description, if the server provided one.
    pub description: Option<String>,
}
199
200/// Build an HTTP client with the given configuration
201fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
202    let mut headers = reqwest::header::HeaderMap::new();
203    for (key, value) in &config.headers {
204        let header_name = key
205            .parse::<reqwest::header::HeaderName>()
206            .with_context(|| format!("Invalid header name: {}", key))?;
207        let header_value = value
208            .parse()
209            .with_context(|| format!("Invalid header value for {}: {}", key, value))?;
210        headers.insert(header_name, header_value);
211    }
212
213    let mut builder = reqwest::Client::builder().default_headers(headers);
214
215    if let Some(timeout) = config.connect_timeout {
216        builder = builder.connect_timeout(timeout);
217    }
218
219    if let Some(timeout) = config.read_timeout {
220        builder = builder.timeout(timeout);
221    }
222
223    builder.build().context("Failed to build HTTP client")
224}
225
226/// Create default client info for MCP handshake
227fn create_default_client_info() -> ClientInfo {
228    let capabilities = ClientCapabilities::builder()
229        .enable_experimental()
230        .enable_roots()
231        .enable_roots_list_changed()
232        .enable_sampling()
233        .build();
234    ClientInfo::new(
235        capabilities,
236        Implementation::new("mcp-streamable-proxy-client", env!("CARGO_PKG_VERSION")),
237    )
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a whole-second `Duration`.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    #[test]
    fn test_tool_info() {
        let expected_description = Some("A test tool".to_string());
        let info = ToolInfo {
            name: "test_tool".to_string(),
            description: expected_description.clone(),
        };

        assert_eq!(info.name, "test_tool");
        assert_eq!(info.description, expected_description);
    }

    #[test]
    fn test_capped_exponential_backoff() {
        // Unlimited retries, 1 s base delay, 60 s cap.
        let policy = CappedExponentialBackoff::new(None, secs(1), secs(60));

        // Zero-based attempt index -> expected delay in seconds.
        // 64 s and 128 s are clamped to the 60 s cap.
        let schedule = [1, 2, 4, 8, 16, 32, 60, 60];
        for (attempt, &expected) in schedule.iter().enumerate() {
            assert_eq!(
                policy.retry(attempt),
                Some(secs(expected)),
                "attempt {attempt}"
            );
        }
    }

    #[test]
    fn test_capped_exponential_backoff_with_max_times() {
        // At most 3 retries; the 4th request (index 3) must be refused.
        let policy = CappedExponentialBackoff::new(Some(3), secs(1), secs(60));

        let delays: Vec<Option<Duration>> = (0..4).map(|attempt| policy.retry(attempt)).collect();
        assert_eq!(
            delays,
            vec![Some(secs(1)), Some(secs(2)), Some(secs(4)), None]
        );
    }

    #[test]
    fn test_capped_exponential_backoff_default() {
        let policy = CappedExponentialBackoff::default();

        // Default configuration values.
        assert_eq!(policy.max_times, None);
        assert_eq!(policy.base_duration, secs(1));
        assert_eq!(policy.max_interval, secs(60));

        // Sampled points along the default schedule.
        for (attempt, expected) in [(0, 1), (5, 32), (10, 60)] {
            assert_eq!(policy.retry(attempt), Some(secs(expected)));
        }
    }
}