Skip to main content

mcp_streamable_proxy/
server_builder.rs

1//! Streamable HTTP Server Builder
2//!
3//! This module provides a high-level Builder API for creating Streamable HTTP MCP servers.
4//! It encapsulates all rmcp-specific types and provides a simple interface for mcp-proxy.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::Result;
10use tokio::process::Command;
11use tokio_util::sync::CancellationToken;
12use tracing::info;
13
14use rmcp::{
15    ServiceExt,
16    model::{ClientCapabilities, ClientInfo},
17    transport::{
18        TokioChildProcess,
19        streamable_http_client::{
20            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
21        },
22        streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
23    },
24};
25
26use crate::{ProxyAwareSessionManager, ProxyHandler, ToolFilter};
27pub use mcp_common::ToolFilter as CommonToolFilter;
28
/// Backend configuration for the MCP server
///
/// Defines how the proxy connects to the upstream MCP service.
///
/// The `Debug` implementation prints header names and environment-variable names
/// but replaces their values with `<redacted>`, because they commonly carry
/// credentials (an `Authorization` header, API-key variables, ...).
#[derive(Clone)]
pub enum BackendConfig {
    /// Connect to a local command via stdio
    Stdio {
        /// Command to execute (e.g., "npx", "python", etc.)
        command: String,
        /// Arguments for the command
        args: Option<Vec<String>>,
        /// Extra environment variables for the child process
        env: Option<HashMap<String, String>>,
    },
    /// Connect to a remote URL
    Url {
        /// URL of the MCP service
        url: String,
        /// Custom HTTP headers (including Authorization)
        headers: Option<HashMap<String, String>>,
    },
}

impl std::fmt::Debug for BackendConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formats a string map as `{"KEY": "<redacted>", ...}` so only the keys are visible.
        struct Redacted<'a>(&'a HashMap<String, String>);

        impl std::fmt::Debug for Redacted<'_> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_map()
                    .entries(self.0.keys().map(|key| (key, "<redacted>")))
                    .finish()
            }
        }

        match self {
            Self::Stdio { command, args, env } => f
                .debug_struct("Stdio")
                .field("command", command)
                .field("args", args)
                .field("env", &env.as_ref().map(Redacted))
                .finish(),
            Self::Url { url, headers } => f
                .debug_struct("Url")
                .field("url", url)
                .field("headers", &headers.as_ref().map(Redacted))
                .finish(),
        }
    }
}
51
52/// Configuration for the Streamable HTTP server
53#[derive(Debug, Clone, Default)]
54pub struct StreamServerConfig {
55    /// Enable stateful mode with session management
56    pub stateful_mode: bool,
57    /// MCP service identifier for logging
58    pub mcp_id: Option<String>,
59    /// Tool filter configuration
60    pub tool_filter: Option<ToolFilter>,
61}
62
63/// Builder for creating Streamable HTTP MCP servers
64///
65/// Provides a fluent API for configuring and building MCP proxy servers.
66///
67/// # Example
68///
69/// ```rust,ignore
70/// use mcp_streamable_proxy::server_builder::{StreamServerBuilder, BackendConfig};
71///
72/// // Create a server with stdio backend
73/// let (router, ct) = StreamServerBuilder::new(BackendConfig::Stdio {
74///     command: "npx".into(),
75///     args: Some(vec!["-y".into(), "@modelcontextprotocol/server-filesystem".into()]),
76///     env: None,
77/// })
78/// .mcp_id("my-server")
79/// .stateful(false)
80/// .build()
81/// .await?;
82/// ```
83pub struct StreamServerBuilder {
84    backend_config: BackendConfig,
85    server_config: StreamServerConfig,
86}
87
88impl StreamServerBuilder {
89    /// Create a new builder with the given backend configuration
90    pub fn new(backend: BackendConfig) -> Self {
91        Self {
92            backend_config: backend,
93            server_config: StreamServerConfig::default(),
94        }
95    }
96
97    /// Set whether to enable stateful mode
98    ///
99    /// Stateful mode enables session management and server-side push.
100    pub fn stateful(mut self, enabled: bool) -> Self {
101        self.server_config.stateful_mode = enabled;
102        self
103    }
104
105    /// Set the MCP service identifier
106    ///
107    /// Used for logging and service identification.
108    pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
109        self.server_config.mcp_id = Some(id.into());
110        self
111    }
112
113    /// Set the tool filter configuration
114    pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
115        self.server_config.tool_filter = Some(filter);
116        self
117    }
118
119    /// Build the server and return an axum Router, CancellationToken, and ProxyHandler
120    ///
121    /// The router can be merged with other axum routers or served directly.
122    /// The CancellationToken can be used to gracefully shut down the service.
123    /// The ProxyHandler can be used for status checks and management.
124    pub async fn build(self) -> Result<(axum::Router, CancellationToken, ProxyHandler)> {
125        let mcp_id = self
126            .server_config
127            .mcp_id
128            .clone()
129            .unwrap_or_else(|| "stream-proxy".into());
130
131        // Create client info for connecting to backend
132        let client_info = ClientInfo {
133            protocol_version: Default::default(),
134            capabilities: ClientCapabilities::builder()
135                .enable_experimental()
136                .enable_roots()
137                .enable_roots_list_changed()
138                .enable_sampling()
139                .build(),
140            ..Default::default()
141        };
142
143        // Connect to backend based on configuration
144        let client = match &self.backend_config {
145            BackendConfig::Stdio { command, args, env } => {
146                self.connect_stdio(command, args, env, &client_info).await?
147            }
148            BackendConfig::Url { url, headers } => {
149                self.connect_url(url, headers, &client_info).await?
150            }
151        };
152
153        // Create proxy handler
154        let proxy_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
155            ProxyHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
156        } else {
157            ProxyHandler::with_mcp_id(client, mcp_id.clone())
158        };
159
160        // Clone handler before creating server
161        let handler_for_return = proxy_handler.clone();
162
163        // Create server with configured stateful mode
164        let (router, ct) = self.create_server(proxy_handler).await?;
165
166        info!(
167            "[StreamServerBuilder] Server created - mcp_id: {}, stateful: {}",
168            mcp_id, self.server_config.stateful_mode
169        );
170
171        Ok((router, ct, handler_for_return))
172    }
173
174    /// Connect to a stdio backend (child process)
175    async fn connect_stdio(
176        &self,
177        command: &str,
178        args: &Option<Vec<String>>,
179        env: &Option<HashMap<String, String>>,
180        client_info: &ClientInfo,
181    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
182        let mut cmd = Command::new(command);
183
184        if let Some(cmd_args) = args {
185            cmd.args(cmd_args);
186        }
187
188        if let Some(env_vars) = env {
189            for (k, v) in env_vars {
190                cmd.env(k, v);
191            }
192        }
193
194        info!(
195            "[StreamServerBuilder] Starting child process - command: {}, args: {:?}",
196            command,
197            args.as_ref().unwrap_or(&vec![])
198        );
199
200        let tokio_process = TokioChildProcess::new(cmd)?;
201        let client = client_info.clone().serve(tokio_process).await?;
202
203        info!("[StreamServerBuilder] Child process connected successfully");
204        Ok(client)
205    }
206
207    /// Connect to a URL backend (remote MCP service)
208    async fn connect_url(
209        &self,
210        url: &str,
211        headers: &Option<HashMap<String, String>>,
212        client_info: &ClientInfo,
213    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
214        info!("[StreamServerBuilder] Connecting to URL backend: {}", url);
215
216        // Build HTTP client with custom headers (excluding Authorization)
217        let mut req_headers = reqwest::header::HeaderMap::new();
218        let mut auth_header: Option<String> = None;
219
220        if let Some(config_headers) = headers {
221            for (key, value) in config_headers {
222                // Authorization header is handled separately by rmcp
223                if key.eq_ignore_ascii_case("Authorization") {
224                    auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
225                    continue;
226                }
227
228                req_headers.insert(
229                    reqwest::header::HeaderName::try_from(key)
230                        .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
231                    value.parse().map_err(|e| {
232                        anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
233                    })?,
234                );
235            }
236        }
237
238        let http_client = reqwest::Client::builder()
239            .default_headers(req_headers)
240            .build()
241            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
242
243        // Create transport configuration
244        let config = StreamableHttpClientTransportConfig {
245            uri: url.to_string().into(),
246            auth_header,
247            ..Default::default()
248        };
249
250        let transport = StreamableHttpClientTransport::with_client(http_client, config);
251        let client = client_info.clone().serve(transport).await?;
252
253        info!("[StreamServerBuilder] URL backend connected successfully");
254        Ok(client)
255    }
256
257    /// Create the Streamable HTTP server
258    async fn create_server(
259        &self,
260        proxy_handler: ProxyHandler,
261    ) -> Result<(axum::Router, CancellationToken)> {
262        let handler = Arc::new(proxy_handler);
263        let ct = CancellationToken::new();
264
265        if self.server_config.stateful_mode {
266            // Stateful mode with custom session manager
267            let session_manager = ProxyAwareSessionManager::new(handler.clone());
268            let handler_for_service = handler.clone();
269
270            let service = StreamableHttpService::new(
271                move || Ok((*handler_for_service).clone()),
272                session_manager.into(),
273                StreamableHttpServerConfig {
274                    stateful_mode: true,
275                    ..Default::default()
276                },
277            );
278
279            let router = axum::Router::new().fallback_service(service);
280            Ok((router, ct))
281        } else {
282            // Stateless mode with local session manager
283            use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
284
285            let handler_for_service = handler.clone();
286
287            let service = StreamableHttpService::new(
288                move || Ok((*handler_for_service).clone()),
289                LocalSessionManager::default().into(),
290                StreamableHttpServerConfig {
291                    stateful_mode: false,
292                    ..Default::default()
293                },
294            );
295
296            let router = axum::Router::new().fallback_service(service);
297            Ok((router, ct))
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained setters record the MCP id and the stateful flag.
    #[test]
    fn test_builder_creation() {
        let backend = BackendConfig::Stdio {
            command: "echo".into(),
            args: Some(vec!["hello".into()]),
            env: None,
        };
        let builder = StreamServerBuilder::new(backend).mcp_id("test").stateful(true);
        let config = &builder.server_config;

        assert!(config.mcp_id.is_some());
        assert_eq!(config.mcp_id.as_deref(), Some("test"));
        assert!(config.stateful_mode);
    }

    /// A URL backend is stored unchanged, headers included.
    #[test]
    fn test_url_backend_config() {
        let mut custom_headers: HashMap<String, String> = HashMap::new();
        custom_headers.insert("Authorization".into(), "Bearer token123".into());
        custom_headers.insert("X-Custom".into(), "value".into());

        let backend = BackendConfig::Url {
            url: "http://localhost:8080/mcp".into(),
            headers: Some(custom_headers),
        };
        let builder = StreamServerBuilder::new(backend);

        if let BackendConfig::Url { url, headers } = &builder.backend_config {
            assert_eq!(url, "http://localhost:8080/mcp");
            assert!(headers.is_some());
        } else {
            panic!("Expected URL backend");
        }
    }
}