mcp_streamable_proxy/
client.rs1use anyhow::{Context, Result};
7use mcp_common::McpClientConfig;
8use rmcp::{
9 RoleClient, ServiceExt,
10 model::{ClientCapabilities, ClientInfo, Implementation},
11 service::RunningService,
12 transport::streamable_http_client::{
13 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
14 },
15};
16
17use crate::proxy_handler::ProxyHandler;
18use mcp_common::ToolFilter;
19
20pub struct StreamClientConnection {
43 inner: RunningService<RoleClient, ClientInfo>,
44}
45
46impl StreamClientConnection {
47 pub async fn connect(config: McpClientConfig) -> Result<Self> {
56 let http_client = build_http_client(&config)?;
57
58 let transport_config = StreamableHttpClientTransportConfig {
59 uri: config.url.clone().into(),
60 ..Default::default()
61 };
62
63 let transport = StreamableHttpClientTransport::with_client(http_client, transport_config);
64
65 let client_info = create_default_client_info();
66 let running = client_info
67 .serve(transport)
68 .await
69 .context("Failed to initialize MCP client")?;
70
71 Ok(Self { inner: running })
72 }
73
74 pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
76 let result = self.inner.list_tools(None).await?;
77 Ok(result
78 .tools
79 .into_iter()
80 .map(|t| ToolInfo {
81 name: t.name.to_string(),
82 description: t.description.map(|d| d.to_string()),
83 })
84 .collect())
85 }
86
87 pub fn is_closed(&self) -> bool {
89 use std::ops::Deref;
90 self.inner.deref().is_transport_closed()
91 }
92
93 pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
95 self.inner.peer_info()
96 }
97
98 pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> ProxyHandler {
107 ProxyHandler::with_tool_filter(self.inner, mcp_id, tool_filter)
108 }
109
110 pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
114 self.inner
115 }
116}
117
/// Name and optional description of a tool exposed by an upstream MCP server.
///
/// A plain owned value type. Equality derives are provided so callers and
/// tests can compare tool listings directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolInfo {
    /// The tool's name as reported by the server.
    pub name: String,
    /// The tool's human-readable description, if the server provided one.
    pub description: Option<String>,
}
126
127fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
129 let mut headers = reqwest::header::HeaderMap::new();
130 for (key, value) in &config.headers {
131 let header_name = key
132 .parse::<reqwest::header::HeaderName>()
133 .with_context(|| format!("Invalid header name: {}", key))?;
134 let header_value = value
135 .parse()
136 .with_context(|| format!("Invalid header value for {}: {}", key, value))?;
137 headers.insert(header_name, header_value);
138 }
139
140 let mut builder = reqwest::Client::builder().default_headers(headers);
141
142 if let Some(timeout) = config.connect_timeout {
143 builder = builder.connect_timeout(timeout);
144 }
145
146 if let Some(timeout) = config.read_timeout {
147 builder = builder.timeout(timeout);
148 }
149
150 builder.build().context("Failed to build HTTP client")
151}
152
153fn create_default_client_info() -> ClientInfo {
155 ClientInfo {
156 protocol_version: Default::default(),
157 capabilities: ClientCapabilities::builder()
158 .enable_experimental()
159 .enable_roots()
160 .enable_roots_list_changed()
161 .enable_sampling()
162 .build(),
163 client_info: Implementation {
164 name: "mcp-streamable-proxy-client".to_string(),
165 version: env!("CARGO_PKG_VERSION").to_string(),
166 title: None,
167 website_url: None,
168 icons: None,
169 },
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// `ToolInfo` stores exactly the name and description it was built with.
    #[test]
    fn test_tool_info() {
        let info = ToolInfo {
            name: String::from("test_tool"),
            description: Some(String::from("A test tool")),
        };

        assert_eq!(info.name.as_str(), "test_tool");
        assert_eq!(info.description.as_deref(), Some("A test tool"));
    }
}