Skip to main content

mcp_streamable_proxy/
server_builder.rs

1//! Streamable HTTP Server Builder
2//!
3//! This module provides a high-level Builder API for creating Streamable HTTP MCP servers.
4//! It encapsulates all rmcp-specific types and provides a simple interface for mcp-proxy.
5
6use std::collections::HashMap;
7use std::process::Stdio;
8use std::sync::Arc;
9
10use anyhow::Result;
11use process_wrap::tokio::{CommandWrap, KillOnDrop};
12use tokio_util::sync::CancellationToken;
13use tracing::info;
14
15use rmcp::{
16    ServiceExt,
17    model::{ClientCapabilities, ClientInfo},
18    transport::{
19        TokioChildProcess,
20        streamable_http_client::{
21            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
22        },
23        streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
24    },
25};
26
27// Unix 进程组支持
28#[cfg(unix)]
29use process_wrap::tokio::ProcessGroup;
30
31// Windows 静默运行支持
32#[cfg(windows)]
33use process_wrap::tokio::{CreationFlags, JobObject};
34
35use crate::{ProxyAwareSessionManager, ProxyHandler, ToolFilter};
36pub use mcp_common::ToolFilter as CommonToolFilter;
37
/// Describes the upstream MCP service that the proxy forwards to.
///
/// Exactly one transport is chosen per proxy: a local command reached over
/// stdio, or a remote service reached by URL.
#[derive(Clone, Debug)]
pub enum BackendConfig {
    /// The upstream is a local command spoken to over stdio.
    Stdio {
        /// Program to launch, for example `npx` or `python`.
        command: String,
        /// Command-line arguments handed to `command`, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables taken from the MCP JSON config.
        env: Option<HashMap<String, String>>,
    },
    /// The upstream is a remote service addressed by URL.
    Url {
        /// Address of the MCP service.
        url: String,
        /// Additional HTTP headers to send; may include `Authorization`.
        headers: Option<HashMap<String, String>>,
    },
}
60
61/// Configuration for the Streamable HTTP server
62#[derive(Debug, Clone, Default)]
63pub struct StreamServerConfig {
64    /// Enable stateful mode with session management
65    pub stateful_mode: bool,
66    /// MCP service identifier for logging
67    pub mcp_id: Option<String>,
68    /// Tool filter configuration
69    pub tool_filter: Option<ToolFilter>,
70}
71
72/// Builder for creating Streamable HTTP MCP servers
73///
74/// Provides a fluent API for configuring and building MCP proxy servers.
75///
76/// # Example
77///
78/// ```rust,ignore
79/// use mcp_streamable_proxy::server_builder::{StreamServerBuilder, BackendConfig};
80///
81/// // Create a server with stdio backend
82/// let (router, ct) = StreamServerBuilder::new(BackendConfig::Stdio {
83///     command: "npx".into(),
84///     args: Some(vec!["-y".into(), "@modelcontextprotocol/server-filesystem".into()]),
85///     env: None,
86/// })
87/// .mcp_id("my-server")
88/// .stateful(false)
89/// .build()
90/// .await?;
91/// ```
92pub struct StreamServerBuilder {
93    backend_config: BackendConfig,
94    server_config: StreamServerConfig,
95}
96
97impl StreamServerBuilder {
98    /// Create a new builder with the given backend configuration
99    pub fn new(backend: BackendConfig) -> Self {
100        Self {
101            backend_config: backend,
102            server_config: StreamServerConfig::default(),
103        }
104    }
105
106    /// Set whether to enable stateful mode
107    ///
108    /// Stateful mode enables session management and server-side push.
109    pub fn stateful(mut self, enabled: bool) -> Self {
110        self.server_config.stateful_mode = enabled;
111        self
112    }
113
114    /// Set the MCP service identifier
115    ///
116    /// Used for logging and service identification.
117    pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
118        self.server_config.mcp_id = Some(id.into());
119        self
120    }
121
122    /// Set the tool filter configuration
123    pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
124        self.server_config.tool_filter = Some(filter);
125        self
126    }
127
128    /// Build the server and return an axum Router, CancellationToken, and ProxyHandler
129    ///
130    /// The router can be merged with other axum routers or served directly.
131    /// The CancellationToken can be used to gracefully shut down the service.
132    /// The ProxyHandler can be used for status checks and management.
133    pub async fn build(self) -> Result<(axum::Router, CancellationToken, ProxyHandler)> {
134        let mcp_id = self
135            .server_config
136            .mcp_id
137            .clone()
138            .unwrap_or_else(|| "stream-proxy".into());
139
140        // Create client info for connecting to backend
141        let capabilities = ClientCapabilities::builder()
142            .enable_experimental()
143            .enable_roots()
144            .enable_roots_list_changed()
145            .enable_sampling()
146            .build();
147        let client_info = ClientInfo::new(
148            capabilities,
149            rmcp::model::Implementation::new("mcp-streamable-proxy", env!("CARGO_PKG_VERSION")),
150        );
151
152        // Connect to backend based on configuration
153        let client = match &self.backend_config {
154            BackendConfig::Stdio { command, args, env } => {
155                self.connect_stdio(command, args, env, &client_info).await?
156            }
157            BackendConfig::Url { url, headers } => {
158                self.connect_url(url, headers, &client_info).await?
159            }
160        };
161
162        // Create proxy handler
163        let proxy_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
164            ProxyHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
165        } else {
166            ProxyHandler::with_mcp_id(client, mcp_id.clone())
167        };
168
169        // Clone handler before creating server
170        let handler_for_return = proxy_handler.clone();
171
172        // Create server with configured stateful mode
173        let (router, ct) = self.create_server(proxy_handler).await?;
174
175        info!(
176            "[StreamServerBuilder] Server created - mcp_id: {}, stateful: {}",
177            mcp_id, self.server_config.stateful_mode
178        );
179
180        Ok((router, ct, handler_for_return))
181    }
182
183    /// Connect to a stdio backend (child process)
184    async fn connect_stdio(
185        &self,
186        command: &str,
187        args: &Option<Vec<String>>,
188        env: &Option<HashMap<String, String>>,
189        client_info: &ClientInfo,
190    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
191        let args = args.clone();
192
193        // 使用 process-wrap 创建子进程命令(跨平台进程清理)
194        // process-wrap 会自动处理进程组(Unix)或 Job Object(Windows)
195        // 并且在 Drop 时自动清理子进程树
196        // 子进程默认继承父进程的所有环境变量
197        let mut wrapped_cmd = CommandWrap::with_new(command, |cmd| {
198            if let Some(cmd_args) = &args {
199                cmd.args(cmd_args);
200            }
201            // 设置 MCP JSON 配置中的环境变量(会覆盖继承的同名变量)
202            if let Some(env_vars) = env {
203                for (k, v) in env_vars {
204                    cmd.env(k, v);
205                }
206            }
207        });
208
209        // Unix: 创建进程组,支持 killpg 清理整个进程树
210        #[cfg(unix)]
211        wrapped_cmd.wrap(ProcessGroup::leader());
212
213        // Windows: 使用 CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP 隐藏控制台窗口
214        #[cfg(windows)]
215        {
216            use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
217            info!(
218                "[StreamServerBuilder] Setting CreationFlags: CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP"
219            );
220            wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
221            wrapped_cmd.wrap(JobObject);
222        }
223
224        // 所有平台: Drop 时自动清理进程
225        wrapped_cmd.wrap(KillOnDrop);
226
227        info!(
228            "[StreamServerBuilder] Starting child process - command: {}, args: {:?}",
229            command,
230            args.as_ref().unwrap_or(&vec![])
231        );
232
233        let mcp_id = self.server_config.mcp_id.as_deref().unwrap_or("unknown");
234
235        // 诊断日志:子进程关键环境变量
236        mcp_common::diagnostic::log_stdio_spawn_context("StreamServerBuilder", mcp_id, env);
237
238        // MCP 服务通过 stdin/stdout 进行 JSON-RPC 通信,必须使用 piped(默认行为)
239        // 使用 builder 模式捕获 stderr,便于诊断子 MCP 服务初始化失败
240        let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
241            .stderr(Stdio::piped())
242            .spawn()
243            .map_err(|e| {
244                anyhow::anyhow!(
245                    "{}",
246                    mcp_common::diagnostic::format_spawn_error(mcp_id, command, &args, e)
247                )
248            })?;
249
250        // 启动 stderr 日志读取任务
251        if let Some(stderr_pipe) = child_stderr {
252            mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.to_string());
253        }
254
255        let client = client_info.clone().serve(tokio_process).await?;
256
257        info!("[StreamServerBuilder] Child process connected successfully");
258        Ok(client)
259    }
260
261    /// Connect to a URL backend (remote MCP service)
262    async fn connect_url(
263        &self,
264        url: &str,
265        headers: &Option<HashMap<String, String>>,
266        client_info: &ClientInfo,
267    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
268        info!("[StreamServerBuilder] Connecting to URL backend: {}", url);
269
270        // Build HTTP client with custom headers (excluding Authorization)
271        let mut req_headers = reqwest::header::HeaderMap::new();
272        let mut auth_header: Option<String> = None;
273
274        if let Some(config_headers) = headers {
275            for (key, value) in config_headers {
276                // Authorization header is handled separately by rmcp
277                if key.eq_ignore_ascii_case("Authorization") {
278                    auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
279                    continue;
280                }
281
282                req_headers.insert(
283                    reqwest::header::HeaderName::try_from(key)
284                        .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
285                    value.parse().map_err(|e| {
286                        anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
287                    })?,
288                );
289            }
290        }
291
292        let http_client = reqwest::Client::builder()
293            .default_headers(req_headers)
294            .build()
295            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
296
297        // Create transport configuration
298        let config = StreamableHttpClientTransportConfig {
299            uri: url.to_string().into(),
300            auth_header,
301            ..Default::default()
302        };
303
304        let transport = StreamableHttpClientTransport::with_client(http_client, config);
305        let client = client_info.clone().serve(transport).await?;
306
307        info!("[StreamServerBuilder] URL backend connected successfully");
308        Ok(client)
309    }
310
311    /// Create the Streamable HTTP server
312    async fn create_server(
313        &self,
314        proxy_handler: ProxyHandler,
315    ) -> Result<(axum::Router, CancellationToken)> {
316        let handler = Arc::new(proxy_handler);
317        let ct = CancellationToken::new();
318
319        if self.server_config.stateful_mode {
320            // Stateful mode with custom session manager
321            let session_manager = ProxyAwareSessionManager::new(handler.clone());
322            let handler_for_service = handler.clone();
323
324            let service = StreamableHttpService::new(
325                move || Ok((*handler_for_service).clone()),
326                session_manager.into(),
327                StreamableHttpServerConfig {
328                    stateful_mode: true,
329                    ..Default::default()
330                },
331            );
332
333            let router = axum::Router::new().fallback_service(service);
334            Ok((router, ct))
335        } else {
336            // Stateless mode with local session manager
337            use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
338
339            let handler_for_service = handler.clone();
340
341            let service = StreamableHttpService::new(
342                move || Ok((*handler_for_service).clone()),
343                LocalSessionManager::default().into(),
344                StreamableHttpServerConfig {
345                    stateful_mode: false,
346                    ..Default::default()
347                },
348            );
349
350            let router = axum::Router::new().fallback_service(service);
351            Ok((router, ct))
352        }
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// The fluent setters must be reflected in the builder's server config.
    #[test]
    fn test_builder_creation() {
        let backend = BackendConfig::Stdio {
            command: "echo".into(),
            args: Some(vec!["hello".into()]),
            env: None,
        };
        let builder = StreamServerBuilder::new(backend)
            .mcp_id("test")
            .stateful(true);

        let config = &builder.server_config;
        assert!(config.mcp_id.is_some());
        assert_eq!(config.mcp_id.as_deref(), Some("test"));
        assert!(config.stateful_mode);
    }

    /// A URL backend keeps the URL and headers it was constructed with.
    #[test]
    fn test_url_backend_config() {
        let headers: HashMap<String, String> = [
            ("Authorization", "Bearer token123"),
            ("X-Custom", "value"),
        ]
        .iter()
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect();

        let builder = StreamServerBuilder::new(BackendConfig::Url {
            url: "http://localhost:8080/mcp".into(),
            headers: Some(headers),
        });

        match &builder.backend_config {
            BackendConfig::Url { url, headers } => {
                assert_eq!(url, "http://localhost:8080/mcp");
                assert!(headers.is_some());
            }
            BackendConfig::Stdio { .. } => panic!("Expected URL backend"),
        }
    }
}