1use anyhow::{Context, Result};
7use mcp_common::McpClientConfig;
8use rmcp::{
9 RoleClient, ServiceExt,
10 model::{ClientCapabilities, ClientInfo, Implementation},
11 service::RunningService,
12 transport::{SseClientTransport, sse_client::SseClientConfig},
13};
14
15use crate::sse_handler::SseHandler;
16use mcp_common::ToolFilter;
17
18pub struct SseClientConnection {
41 inner: RunningService<RoleClient, ClientInfo>,
42}
43
44impl SseClientConnection {
45 pub async fn connect(config: McpClientConfig) -> Result<Self> {
54 let http_client = build_http_client(&config)?;
55
56 let sse_config = SseClientConfig {
57 sse_endpoint: config.url.clone().into(),
58 ..Default::default()
59 };
60
61 let transport: SseClientTransport<reqwest::Client> =
62 SseClientTransport::start_with_client(http_client, sse_config)
63 .await
64 .context("Failed to start SSE transport")?;
65
66 let client_info = create_default_client_info();
67 let running = client_info
68 .serve(transport)
69 .await
70 .context("Failed to initialize MCP client")?;
71
72 Ok(Self { inner: running })
73 }
74
75 pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
77 let result = self.inner.list_tools(None).await?;
78 Ok(result
79 .tools
80 .into_iter()
81 .map(|t| ToolInfo {
82 name: t.name.to_string(),
83 description: t.description.map(|d| d.to_string()),
84 })
85 .collect())
86 }
87
88 pub fn is_closed(&self) -> bool {
90 use std::ops::Deref;
91 self.inner.deref().is_transport_closed()
92 }
93
94 pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
96 self.inner.peer_info()
97 }
98
99 pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> SseHandler {
108 SseHandler::with_tool_filter(self.inner, mcp_id, tool_filter)
109 }
110
111 pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
115 self.inner
116 }
117}
118
/// Name and optional description of a tool exposed by the remote MCP server,
/// as returned by [`SseClientConnection::list_tools`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ToolInfo {
    /// Tool name as reported by the server.
    pub name: String,
    /// Human-readable description, if the server supplied one.
    pub description: Option<String>,
}
127
128fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
130 let mut headers = reqwest::header::HeaderMap::new();
131 for (key, value) in &config.headers {
132 let header_name = key
133 .parse::<reqwest::header::HeaderName>()
134 .with_context(|| format!("Invalid header name: {}", key))?;
135 let header_value = value
136 .parse()
137 .with_context(|| format!("Invalid header value for {}: {}", key, value))?;
138 headers.insert(header_name, header_value);
139 }
140
141 let mut builder = reqwest::Client::builder().default_headers(headers);
142
143 if let Some(timeout) = config.connect_timeout {
144 builder = builder.connect_timeout(timeout);
145 }
146
147 if let Some(timeout) = config.read_timeout {
148 builder = builder.timeout(timeout);
149 }
150
151 builder.build().context("Failed to build HTTP client")
152}
153
154fn create_default_client_info() -> ClientInfo {
156 ClientInfo {
157 protocol_version: Default::default(),
158 capabilities: ClientCapabilities::builder()
159 .enable_experimental()
160 .enable_roots()
161 .enable_roots_list_changed()
162 .enable_sampling()
163 .build(),
164 client_info: Implementation {
165 name: "mcp-sse-proxy-client".to_string(),
166 version: env!("CARGO_PKG_VERSION").to_string(),
167 title: None,
168 website_url: None,
169 icons: None,
170 },
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// `ToolInfo` is a plain data carrier: its fields must come back exactly as stored.
    #[test]
    fn test_tool_info() {
        let ToolInfo { name, description } = ToolInfo {
            name: String::from("test_tool"),
            description: Some(String::from("A test tool")),
        };

        assert_eq!(name, "test_tool");
        assert_eq!(description.as_deref(), Some("A test tool"));
    }
}