mcp_streamable_proxy/
client.rs1use anyhow::{Context, Result};
7use mcp_common::McpClientConfig;
8use rmcp::{
9 RoleClient, ServiceExt,
10 model::{ClientCapabilities, ClientInfo, Implementation},
11 service::RunningService,
12 transport::{
13 common::client_side_sse::SseRetryPolicy,
14 streamable_http_client::{
15 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
16 },
17 },
18};
19use std::sync::Arc;
20use std::time::Duration;
21
22use crate::proxy_handler::ProxyHandler;
23use mcp_common::ToolFilter;
24
/// SSE reconnect policy whose delay grows exponentially but never exceeds a fixed ceiling.
///
/// Attempt `n` (0-based) waits `base_duration * 2^n`, clamped to `max_interval`. Once
/// `max_times` attempts have been made, the policy stops retrying.
#[derive(Debug, Clone)]
pub struct CappedExponentialBackoff {
    /// Maximum number of retry attempts; `None` means retry indefinitely.
    pub max_times: Option<usize>,
    /// Delay used for the first retry; each later attempt doubles it.
    pub base_duration: Duration,
    /// Upper bound applied to every computed retry delay.
    pub max_interval: Duration,
}

impl CappedExponentialBackoff {
    /// Builds a policy from an optional attempt limit, a base delay and a delay ceiling.
    pub fn new(max_times: Option<usize>, base_duration: Duration, max_interval: Duration) -> Self {
        Self {
            max_interval,
            base_duration,
            max_times,
        }
    }
}

impl Default for CappedExponentialBackoff {
    /// Unlimited attempts, starting at 1 second and capped at 60 seconds.
    fn default() -> Self {
        Self::new(None, Duration::from_secs(1), Duration::from_secs(60))
    }
}
67
68impl SseRetryPolicy for CappedExponentialBackoff {
69 fn retry(&self, current_times: usize) -> Option<Duration> {
70 if let Some(max_times) = self.max_times
72 && current_times >= max_times
73 {
74 return None;
75 }
76
77 let exponential_delay = self.base_duration * (2u32.pow(current_times as u32));
79
80 Some(exponential_delay.min(self.max_interval))
82 }
83}
84
85pub struct StreamClientConnection {
108 inner: RunningService<RoleClient, ClientInfo>,
109}
110
111impl StreamClientConnection {
112 pub async fn connect(config: McpClientConfig) -> Result<Self> {
121 let http_client = build_http_client(&config)?;
122
123 let retry_policy = CappedExponentialBackoff::new(
125 None, Duration::from_secs(1), Duration::from_secs(60), );
129
130 let transport_config = StreamableHttpClientTransportConfig {
131 uri: config.url.clone().into(),
132 retry_config: Arc::new(retry_policy),
133 ..Default::default()
134 };
135
136 let transport = StreamableHttpClientTransport::with_client(http_client, transport_config);
137
138 let client_info = create_default_client_info();
139 let running = client_info
140 .serve(transport)
141 .await
142 .context("Failed to initialize MCP client")?;
143
144 Ok(Self { inner: running })
145 }
146
147 pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
149 let result = self.inner.list_tools(None).await?;
150 Ok(result
151 .tools
152 .into_iter()
153 .map(|t| ToolInfo {
154 name: t.name.to_string(),
155 description: t.description.map(|d| d.to_string()),
156 })
157 .collect())
158 }
159
160 pub fn is_closed(&self) -> bool {
162 use std::ops::Deref;
163 self.inner.deref().is_transport_closed()
164 }
165
166 pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
168 self.inner.peer_info()
169 }
170
171 pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> ProxyHandler {
180 ProxyHandler::with_tool_filter(self.inner, mcp_id, tool_filter)
181 }
182
183 pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
187 self.inner
188 }
189}
190
/// Lightweight summary of one tool advertised by an MCP server.
#[derive(Clone, Debug)]
pub struct ToolInfo {
    /// The tool name exactly as reported by the server.
    pub name: String,
    /// Human-readable description, when the server supplied one.
    pub description: Option<String>,
}
199
200fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
202 let mut headers = reqwest::header::HeaderMap::new();
203 for (key, value) in &config.headers {
204 let header_name = key
205 .parse::<reqwest::header::HeaderName>()
206 .with_context(|| format!("Invalid header name: {}", key))?;
207 let header_value = value
208 .parse()
209 .with_context(|| format!("Invalid header value for {}: {}", key, value))?;
210 headers.insert(header_name, header_value);
211 }
212
213 let mut builder = reqwest::Client::builder().default_headers(headers);
214
215 if let Some(timeout) = config.connect_timeout {
216 builder = builder.connect_timeout(timeout);
217 }
218
219 if let Some(timeout) = config.read_timeout {
220 builder = builder.timeout(timeout);
221 }
222
223 builder.build().context("Failed to build HTTP client")
224}
225
226fn create_default_client_info() -> ClientInfo {
228 let capabilities = ClientCapabilities::builder()
229 .enable_experimental()
230 .enable_roots()
231 .enable_roots_list_changed()
232 .enable_sampling()
233 .build();
234 ClientInfo::new(
235 capabilities,
236 Implementation::new("mcp-streamable-proxy-client", env!("CARGO_PKG_VERSION")),
237 )
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_info() {
        let info = ToolInfo {
            description: Some("A test tool".to_string()),
            name: "test_tool".to_string(),
        };
        assert_eq!(info.name, "test_tool");
        assert_eq!(info.description.as_deref(), Some("A test tool"));
    }

    #[test]
    fn test_capped_exponential_backoff() {
        let policy =
            CappedExponentialBackoff::new(None, Duration::from_secs(1), Duration::from_secs(60));

        // Delays double from 1s until they hit the 60s ceiling (2^6 = 64s > 60s).
        let expected_secs: [u64; 8] = [1, 2, 4, 8, 16, 32, 60, 60];
        for (attempt, &secs) in expected_secs.iter().enumerate() {
            assert_eq!(policy.retry(attempt), Some(Duration::from_secs(secs)));
        }
    }

    #[test]
    fn test_capped_exponential_backoff_with_max_times() {
        let policy =
            CappedExponentialBackoff::new(Some(3), Duration::from_secs(1), Duration::from_secs(60));

        // Attempts below the limit still back off normally...
        for (attempt, secs) in [(0, 1), (1, 2), (2, 4)] {
            assert_eq!(policy.retry(attempt), Some(Duration::from_secs(secs)));
        }
        // ...and reaching `max_times` stops retrying.
        assert_eq!(policy.retry(3), None);
    }

    #[test]
    fn test_capped_exponential_backoff_default() {
        let policy = CappedExponentialBackoff::default();

        assert_eq!(policy.max_times, None);
        assert_eq!(policy.base_duration, Duration::from_secs(1));
        assert_eq!(policy.max_interval, Duration::from_secs(60));

        for (attempt, secs) in [(0, 1), (5, 32), (10, 60)] {
            assert_eq!(policy.retry(attempt), Some(Duration::from_secs(secs)));
        }
    }
}