mcp_streamable_proxy/
client.rs1use anyhow::{Context, Result};
7use mcp_common::McpClientConfig;
8use rmcp::{
9 RoleClient, ServiceExt,
10 model::{ClientCapabilities, ClientInfo, Implementation},
11 service::RunningService,
12 transport::{
13 common::client_side_sse::SseRetryPolicy,
14 streamable_http_client::{
15 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
16 },
17 },
18};
19use std::sync::Arc;
20use std::time::Duration;
21
22use crate::proxy_handler::ProxyHandler;
23use mcp_common::ToolFilter;
24
/// Retry policy whose delay doubles with every attempt up to a fixed ceiling.
///
/// The delay for attempt `n` is `base_duration * 2^n`, clamped to
/// `max_interval`.
#[derive(Clone, Debug)]
pub struct CappedExponentialBackoff {
    /// Attempt count at which retrying stops; `None` means never give up.
    pub max_times: Option<usize>,
    /// Delay used for attempt `0`; doubled for each later attempt.
    pub base_duration: Duration,
    /// Upper bound on the delay returned for any attempt.
    pub max_interval: Duration,
}
41
42impl CappedExponentialBackoff {
43 pub fn new(max_times: Option<usize>, base_duration: Duration, max_interval: Duration) -> Self {
50 Self {
51 max_times,
52 base_duration,
53 max_interval,
54 }
55 }
56}
57
58impl Default for CappedExponentialBackoff {
59 fn default() -> Self {
60 Self {
61 max_times: None,
62 base_duration: Duration::from_secs(1),
63 max_interval: Duration::from_secs(60),
64 }
65 }
66}
67
68impl SseRetryPolicy for CappedExponentialBackoff {
69 fn retry(&self, current_times: usize) -> Option<Duration> {
70 if let Some(max_times) = self.max_times
72 && current_times >= max_times
73 {
74 return None;
75 }
76
77 let exponential_delay = self.base_duration * (2u32.pow(current_times as u32));
79
80 Some(exponential_delay.min(self.max_interval))
82 }
83}
84
85pub struct StreamClientConnection {
108 inner: RunningService<RoleClient, ClientInfo>,
109}
110
111impl StreamClientConnection {
112 pub async fn connect(config: McpClientConfig) -> Result<Self> {
121 let http_client = build_http_client(&config)?;
122
123 let retry_policy = CappedExponentialBackoff::new(
125 None, Duration::from_secs(1), Duration::from_secs(60), );
129
130 let mut transport_config = StreamableHttpClientTransportConfig::with_uri(config.url.clone());
131 transport_config.retry_config = Arc::new(retry_policy);
132
133 let transport = StreamableHttpClientTransport::with_client(http_client, transport_config);
134
135 let client_info = create_default_client_info();
136 let running = client_info
137 .serve(transport)
138 .await
139 .context("Failed to initialize MCP client")?;
140
141 Ok(Self { inner: running })
142 }
143
144 pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
146 let result = self.inner.list_tools(None).await?;
147 Ok(result
148 .tools
149 .into_iter()
150 .map(|t| ToolInfo {
151 name: t.name.to_string(),
152 description: t.description.map(|d| d.to_string()),
153 })
154 .collect())
155 }
156
157 pub fn is_closed(&self) -> bool {
159 use std::ops::Deref;
160 self.inner.deref().is_transport_closed()
161 }
162
163 pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
165 self.inner.peer_info()
166 }
167
168 pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> ProxyHandler {
177 ProxyHandler::with_tool_filter(self.inner, mcp_id, tool_filter)
178 }
179
180 pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
184 self.inner
185 }
186}
187
/// Minimal description of a tool: just its name and optional description.
#[derive(Debug, Clone)]
pub struct ToolInfo {
    /// Tool name.
    pub name: String,
    /// Description text, when one is available.
    pub description: Option<String>,
}
196
197fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
199 let mut headers = reqwest::header::HeaderMap::new();
200 for (key, value) in &config.headers {
201 let header_name = key
202 .parse::<reqwest::header::HeaderName>()
203 .with_context(|| format!("Invalid header name: {}", key))?;
204 let header_value = value
205 .parse()
206 .with_context(|| format!("Invalid header value for {}: {}", key, value))?;
207 headers.insert(header_name, header_value);
208 }
209
210 let mut builder = reqwest::Client::builder().default_headers(headers);
211
212 if let Some(timeout) = config.connect_timeout {
213 builder = builder.connect_timeout(timeout);
214 }
215
216 if let Some(timeout) = config.read_timeout {
217 builder = builder.timeout(timeout);
218 }
219
220 builder.build().context("Failed to build HTTP client")
221}
222
223fn create_default_client_info() -> ClientInfo {
225 let capabilities = ClientCapabilities::builder()
226 .enable_experimental()
227 .enable_roots()
228 .enable_roots_list_changed()
229 .enable_sampling()
230 .build();
231 ClientInfo::new(
232 capabilities,
233 Implementation::new("mcp-streamable-proxy-client", env!("CARGO_PKG_VERSION")),
234 )
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Whole-second `Duration` shorthand for the expectations below.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    /// `ToolInfo` simply stores what it is given.
    #[test]
    fn test_tool_info() {
        let info = ToolInfo {
            name: String::from("test_tool"),
            description: Some(String::from("A test tool")),
        };

        assert_eq!(info.name, "test_tool");
        assert_eq!(info.description, Some("A test tool".to_string()));
    }

    /// Without an attempt limit the delay doubles from 1s and plateaus at 60s.
    #[test]
    fn test_capped_exponential_backoff() {
        let policy = CappedExponentialBackoff::new(None, secs(1), secs(60));

        // Index = attempt number, value = expected delay in seconds.
        let expected_secs: [u64; 8] = [1, 2, 4, 8, 16, 32, 60, 60];
        for (attempt, &want) in expected_secs.iter().enumerate() {
            assert_eq!(policy.retry(attempt), Some(secs(want)));
        }
    }

    /// With `max_times = 3`, attempts 0..=2 get a delay and attempt 3 stops.
    #[test]
    fn test_capped_exponential_backoff_with_max_times() {
        let policy = CappedExponentialBackoff::new(Some(3), secs(1), secs(60));

        let expected_secs: [u64; 3] = [1, 2, 4];
        for (attempt, &want) in expected_secs.iter().enumerate() {
            assert_eq!(policy.retry(attempt), Some(secs(want)));
        }

        assert_eq!(policy.retry(3), None);
    }

    /// The default policy is unlimited, 1s base, 60s cap.
    #[test]
    fn test_capped_exponential_backoff_default() {
        let policy = CappedExponentialBackoff::default();

        assert_eq!(policy.max_times, None);
        assert_eq!(policy.base_duration, secs(1));
        assert_eq!(policy.max_interval, secs(60));

        let samples: [(usize, u64); 3] = [(0, 1), (5, 32), (10, 60)];
        for &(attempt, want) in samples.iter() {
            assert_eq!(policy.retry(attempt), Some(secs(want)));
        }
    }
}