mcp_sse_proxy/
client.rs

1//! SSE Client Connection Module
2//!
3//! Provides a high-level API for connecting to MCP servers via SSE protocol.
4//! This module encapsulates the rmcp transport details and exposes a simple interface.
5
6use anyhow::{Context, Result};
7use mcp_common::McpClientConfig;
8use rmcp::{
9    RoleClient, ServiceExt,
10    model::{ClientCapabilities, ClientInfo, Implementation},
11    service::RunningService,
12    transport::{SseClientTransport, sse_client::SseClientConfig},
13};
14
15use crate::sse_handler::SseHandler;
16use mcp_common::ToolFilter;
17
18/// Opaque wrapper for SSE client connection
19///
20/// This type encapsulates an active connection to an MCP server via SSE protocol.
21/// It hides the internal `RunningService` type and provides only the methods
22/// needed by consuming code.
23///
24/// Note: This type is not Clone because the underlying RunningService
25/// is designed for single-owner use. Use `into_handler()` or `into_running_service()`
26/// to consume the connection.
27///
28/// # Example
29///
30/// ```rust,ignore
31/// use mcp_sse_proxy::{SseClientConnection, McpClientConfig};
32///
33/// let config = McpClientConfig::new("http://localhost:8080/sse")
34///     .with_header("Authorization", "Bearer token");
35///
36/// let conn = SseClientConnection::connect(config).await?;
37/// let tools = conn.list_tools().await?;
38/// println!("Available tools: {:?}", tools);
39/// ```
40pub struct SseClientConnection {
41    inner: RunningService<RoleClient, ClientInfo>,
42}
43
44impl SseClientConnection {
45    /// Connect to an SSE MCP server
46    ///
47    /// # Arguments
48    /// * `config` - Client configuration including URL and headers
49    ///
50    /// # Returns
51    /// * `Ok(SseClientConnection)` - Successfully connected client
52    /// * `Err` - Connection failed
53    pub async fn connect(config: McpClientConfig) -> Result<Self> {
54        let http_client = build_http_client(&config)?;
55
56        let sse_config = SseClientConfig {
57            sse_endpoint: config.url.clone().into(),
58            ..Default::default()
59        };
60
61        let transport: SseClientTransport<reqwest::Client> =
62            SseClientTransport::start_with_client(http_client, sse_config)
63                .await
64                .context("Failed to start SSE transport")?;
65
66        let client_info = create_default_client_info();
67        let running = client_info
68            .serve(transport)
69            .await
70            .context("Failed to initialize MCP client")?;
71
72        Ok(Self { inner: running })
73    }
74
75    /// List available tools from the MCP server
76    pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
77        let result = self.inner.list_tools(None).await?;
78        Ok(result
79            .tools
80            .into_iter()
81            .map(|t| ToolInfo {
82                name: t.name.to_string(),
83                description: t.description.map(|d| d.to_string()),
84            })
85            .collect())
86    }
87
88    /// Check if the connection is closed
89    pub fn is_closed(&self) -> bool {
90        use std::ops::Deref;
91        self.inner.deref().is_transport_closed()
92    }
93
94    /// Get the peer info from the server
95    pub fn peer_info(&self) -> Option<&rmcp::model::ServerInfo> {
96        self.inner.peer_info()
97    }
98
99    /// Convert this connection into an SseHandler for serving
100    ///
101    /// This consumes the connection and creates an SseHandler that can
102    /// proxy requests to the backend MCP server.
103    ///
104    /// # Arguments
105    /// * `mcp_id` - Identifier for logging purposes
106    /// * `tool_filter` - Tool filtering configuration
107    pub fn into_handler(self, mcp_id: String, tool_filter: ToolFilter) -> SseHandler {
108        SseHandler::with_tool_filter(self.inner, mcp_id, tool_filter)
109    }
110
111    /// Extract the internal RunningService for use with swap_backend
112    ///
113    /// This is used internally to support backend hot-swapping.
114    pub fn into_running_service(self) -> RunningService<RoleClient, ClientInfo> {
115        self.inner
116    }
117}
118
/// Simplified, owned view of a tool advertised by an MCP server.
///
/// Equality and hashing compare both `name` and `description`, so tool lists can
/// be compared or deduplicated directly.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ToolInfo {
    /// Tool name
    pub name: String,
    /// Tool description (`None` when the server did not provide one)
    pub description: Option<String>,
}
127
128/// Build an HTTP client with the given configuration
129fn build_http_client(config: &McpClientConfig) -> Result<reqwest::Client> {
130    let mut headers = reqwest::header::HeaderMap::new();
131    for (key, value) in &config.headers {
132        let header_name = key
133            .parse::<reqwest::header::HeaderName>()
134            .with_context(|| format!("Invalid header name: {}", key))?;
135        let header_value = value
136            .parse()
137            .with_context(|| format!("Invalid header value for {}: {}", key, value))?;
138        headers.insert(header_name, header_value);
139    }
140
141    let mut builder = reqwest::Client::builder().default_headers(headers);
142
143    if let Some(timeout) = config.connect_timeout {
144        builder = builder.connect_timeout(timeout);
145    }
146
147    if let Some(timeout) = config.read_timeout {
148        builder = builder.timeout(timeout);
149    }
150
151    builder.build().context("Failed to build HTTP client")
152}
153
154/// Create default client info for MCP handshake
155fn create_default_client_info() -> ClientInfo {
156    ClientInfo {
157        protocol_version: Default::default(),
158        capabilities: ClientCapabilities::builder()
159            .enable_experimental()
160            .enable_roots()
161            .enable_roots_list_changed()
162            .enable_sampling()
163            .build(),
164        client_info: Implementation {
165            name: "mcp-sse-proxy-client".to_string(),
166            version: env!("CARGO_PKG_VERSION").to_string(),
167            title: None,
168            website_url: None,
169            icons: None,
170        },
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// `ToolInfo` stores the name and optional description exactly as given.
    #[test]
    fn test_tool_info() {
        let tool = ToolInfo {
            name: String::from("test_tool"),
            description: Some(String::from("A test tool")),
        };
        assert_eq!(tool.name.as_str(), "test_tool");
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
    }
}