Skip to main content

mcp_streamable_proxy/
server_builder.rs

//! Streamable HTTP Server Builder
//!
//! This module provides a high-level builder API for creating Streamable HTTP MCP servers.
//! It encapsulates all rmcp-specific types and exposes a simple interface for mcp-proxy.
5
6use std::collections::HashMap;
7use std::process::Stdio;
8use std::sync::Arc;
9
10use anyhow::Result;
11use process_wrap::tokio::{CommandWrap, KillOnDrop};
12use tokio_util::sync::CancellationToken;
13use tracing::info;
14
15use rmcp::{
16    ServiceExt,
17    model::{ClientCapabilities, ClientInfo},
18    transport::{
19        TokioChildProcess,
20        streamable_http_client::{
21            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
22        },
23        streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
24    },
25};
26
27// Unix 进程组支持
28#[cfg(unix)]
29use process_wrap::tokio::ProcessGroup;
30
31// Windows 静默运行支持
32#[cfg(windows)]
33use process_wrap::tokio::{CreationFlags, JobObject};
34
35use crate::{ProxyAwareSessionManager, ProxyHandler, ToolFilter};
36pub use mcp_common::ToolFilter as CommonToolFilter;
37
/// Describes where the proxy's upstream MCP service lives and how to reach it.
///
/// NOTE(review): the derived `Debug` prints header values (including `Authorization`) and
/// environment variables verbatim; avoid logging this type with `{:?}` when they may hold secrets.
#[derive(Clone, Debug)]
pub enum BackendConfig {
    /// Launch a local command and communicate with it over stdio.
    Stdio {
        /// Program to launch, e.g. `"npx"` or `"python"`.
        command: String,
        /// Command-line arguments passed to the program, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables taken from the MCP JSON config.
        env: Option<HashMap<String, String>>,
    },
    /// Connect to a remote MCP service identified by a URL.
    Url {
        /// Endpoint URL of the MCP service.
        url: String,
        /// Custom HTTP headers to send; may include `Authorization`.
        headers: Option<HashMap<String, String>>,
    },
}
60
61/// Configuration for the Streamable HTTP server
62#[derive(Debug, Clone, Default)]
63pub struct StreamServerConfig {
64    /// Enable stateful mode with session management
65    pub stateful_mode: bool,
66    /// MCP service identifier for logging
67    pub mcp_id: Option<String>,
68    /// Tool filter configuration
69    pub tool_filter: Option<ToolFilter>,
70}
71
72/// Builder for creating Streamable HTTP MCP servers
73///
74/// Provides a fluent API for configuring and building MCP proxy servers.
75///
76/// # Example
77///
78/// ```rust,ignore
79/// use mcp_streamable_proxy::server_builder::{StreamServerBuilder, BackendConfig};
80///
81/// // Create a server with stdio backend
82/// let (router, ct) = StreamServerBuilder::new(BackendConfig::Stdio {
83///     command: "npx".into(),
84///     args: Some(vec!["-y".into(), "@modelcontextprotocol/server-filesystem".into()]),
85///     env: None,
86/// })
87/// .mcp_id("my-server")
88/// .stateful(false)
89/// .build()
90/// .await?;
91/// ```
92pub struct StreamServerBuilder {
93    backend_config: BackendConfig,
94    server_config: StreamServerConfig,
95}
96
97impl StreamServerBuilder {
98    /// Create a new builder with the given backend configuration
99    pub fn new(backend: BackendConfig) -> Self {
100        Self {
101            backend_config: backend,
102            server_config: StreamServerConfig::default(),
103        }
104    }
105
106    /// Set whether to enable stateful mode
107    ///
108    /// Stateful mode enables session management and server-side push.
109    pub fn stateful(mut self, enabled: bool) -> Self {
110        self.server_config.stateful_mode = enabled;
111        self
112    }
113
114    /// Set the MCP service identifier
115    ///
116    /// Used for logging and service identification.
117    pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
118        self.server_config.mcp_id = Some(id.into());
119        self
120    }
121
122    /// Set the tool filter configuration
123    pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
124        self.server_config.tool_filter = Some(filter);
125        self
126    }
127
128    /// Build the server and return an axum Router, CancellationToken, and ProxyHandler
129    ///
130    /// The router can be merged with other axum routers or served directly.
131    /// The CancellationToken can be used to gracefully shut down the service.
132    /// The ProxyHandler can be used for status checks and management.
133    pub async fn build(self) -> Result<(axum::Router, CancellationToken, ProxyHandler)> {
134        let mcp_id = self
135            .server_config
136            .mcp_id
137            .clone()
138            .unwrap_or_else(|| "stream-proxy".into());
139
140        // Create client info for connecting to backend
141        let capabilities = ClientCapabilities::builder()
142            .enable_experimental()
143            .enable_roots()
144            .enable_roots_list_changed()
145            .enable_sampling()
146            .build();
147        let client_info = ClientInfo::new(
148            capabilities,
149            rmcp::model::Implementation::new("mcp-streamable-proxy", env!("CARGO_PKG_VERSION")),
150        );
151
152        // Connect to backend based on configuration
153        let client = match &self.backend_config {
154            BackendConfig::Stdio { command, args, env } => {
155                self.connect_stdio(command, args, env, &client_info).await?
156            }
157            BackendConfig::Url { url, headers } => {
158                self.connect_url(url, headers, &client_info).await?
159            }
160        };
161
162        // Create proxy handler
163        let proxy_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
164            ProxyHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
165        } else {
166            ProxyHandler::with_mcp_id(client, mcp_id.clone())
167        };
168
169        // Clone handler before creating server
170        let handler_for_return = proxy_handler.clone();
171
172        // Create server with configured stateful mode
173        let (router, ct) = self.create_server(proxy_handler).await?;
174
175        info!(
176            "[StreamServerBuilder] Server created - mcp_id: {}, stateful: {}",
177            mcp_id, self.server_config.stateful_mode
178        );
179
180        Ok((router, ct, handler_for_return))
181    }
182
183    /// Connect to a stdio backend (child process)
184    async fn connect_stdio(
185        &self,
186        command: &str,
187        args: &Option<Vec<String>>,
188        env: &Option<HashMap<String, String>>,
189        client_info: &ClientInfo,
190    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
191        // Windows 上预处理 npx 命令,避免 .cmd 文件导致窗口闪烁
192        #[cfg(windows)]
193        let (command, args) = self.preprocess_npx_command_windows(command, args.clone());
194        #[cfg(not(windows))]
195        let args = args.clone();
196
197        // 使用 process-wrap 创建子进程命令(跨平台进程清理)
198        // process-wrap 会自动处理进程组(Unix)或 Job Object(Windows)
199        // 并且在 Drop 时自动清理子进程树
200        // 子进程默认继承父进程的所有环境变量
201        let mut wrapped_cmd = CommandWrap::with_new(command, |cmd| {
202            if let Some(cmd_args) = &args {
203                cmd.args(cmd_args);
204            }
205            // 设置 MCP JSON 配置中的环境变量(会覆盖继承的同名变量)
206            if let Some(env_vars) = env {
207                for (k, v) in env_vars {
208                    cmd.env(k, v);
209                }
210            }
211        });
212
213        // Unix: 创建进程组,支持 killpg 清理整个进程树
214        #[cfg(unix)]
215        wrapped_cmd.wrap(ProcessGroup::leader());
216
217        // Windows: 使用 CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP 隐藏控制台窗口
218        #[cfg(windows)]
219        {
220            use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
221            info!(
222                "[StreamServerBuilder] Setting CreationFlags: CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP"
223            );
224            wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
225            wrapped_cmd.wrap(JobObject);
226        }
227
228        // 所有平台: Drop 时自动清理进程
229        wrapped_cmd.wrap(KillOnDrop);
230
231        info!(
232            "[StreamServerBuilder] Starting child process - command: {}, args: {:?}",
233            command,
234            args.as_ref().unwrap_or(&vec![])
235        );
236
237        let mcp_id = self.server_config.mcp_id.as_deref().unwrap_or("unknown");
238
239        // 诊断日志:子进程关键环境变量
240        mcp_common::diagnostic::log_stdio_spawn_context("StreamServerBuilder", mcp_id, env);
241
242        // MCP 服务通过 stdin/stdout 进行 JSON-RPC 通信,必须使用 piped(默认行为)
243        // 使用 builder 模式捕获 stderr,便于诊断子 MCP 服务初始化失败
244        let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
245            .stderr(Stdio::piped())
246            .spawn()
247            .map_err(|e| {
248                anyhow::anyhow!(
249                    "{}",
250                    mcp_common::diagnostic::format_spawn_error(mcp_id, command, &args, e)
251                )
252            })?;
253
254        // 启动 stderr 日志读取任务
255        if let Some(stderr_pipe) = child_stderr {
256            mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.to_string());
257        }
258
259        let client = client_info.clone().serve(tokio_process).await?;
260
261        info!("[StreamServerBuilder] Child process connected successfully");
262        Ok(client)
263    }
264
265    /// Connect to a URL backend (remote MCP service)
266    async fn connect_url(
267        &self,
268        url: &str,
269        headers: &Option<HashMap<String, String>>,
270        client_info: &ClientInfo,
271    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
272        info!("[StreamServerBuilder] Connecting to URL backend: {}", url);
273
274        // Build HTTP client with custom headers (excluding Authorization)
275        let mut req_headers = reqwest::header::HeaderMap::new();
276        let mut auth_header: Option<String> = None;
277
278        if let Some(config_headers) = headers {
279            for (key, value) in config_headers {
280                // Authorization header is handled separately by rmcp
281                if key.eq_ignore_ascii_case("Authorization") {
282                    auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
283                    continue;
284                }
285
286                req_headers.insert(
287                    reqwest::header::HeaderName::try_from(key)
288                        .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
289                    value.parse().map_err(|e| {
290                        anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
291                    })?,
292                );
293            }
294        }
295
296        let http_client = reqwest::Client::builder()
297            .default_headers(req_headers)
298            .build()
299            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
300
301        // Create transport configuration
302        let config = StreamableHttpClientTransportConfig {
303            uri: url.to_string().into(),
304            auth_header,
305            ..Default::default()
306        };
307
308        let transport = StreamableHttpClientTransport::with_client(http_client, config);
309        let client = client_info.clone().serve(transport).await?;
310
311        info!("[StreamServerBuilder] URL backend connected successfully");
312        Ok(client)
313    }
314
315    /// Windows 上预处理 npx 命令
316    ///
317    /// 将 `npx -y package@version` 转换为直接的 `node` 命令,
318    /// 避免使用 .cmd 批处理文件导致窗口闪烁。
319    #[cfg(windows)]
320    fn preprocess_npx_command_windows(
321        &self,
322        command: &str,
323        args: Option<Vec<String>>,
324    ) -> (String, Option<Vec<String>>) {
325        // 检测 npx 命令
326        let is_npx = command == "npx"
327            || command == "npx.cmd"
328            || command.ends_with("/npx")
329            || command.ends_with("\\npx")
330            || command.ends_with("/npx.cmd")
331            || command.ends_with("\\npx.cmd");
332
333        if !is_npx {
334            return (command.to_string(), args);
335        }
336
337        let args = match args {
338            Some(a) => a,
339            None => return (command.to_string(), None),
340        };
341
342        // 提取包名(跳过 -y 标志)
343        let package_spec = args.iter().find(|s| !s.starts_with('-') && s.contains('@'));
344
345        let Some(pkg) = package_spec else {
346            return (command.to_string(), Some(args));
347        };
348
349        // 解析包名(去掉版本号)
350        let package_name = pkg.split('@').next().unwrap_or(pkg);
351
352        // 尝试找到已安装的包
353        if let Some((node_exe, js_entry)) = self.find_npx_package_entry_windows(package_name) {
354            info!(
355                "[StreamServerBuilder] Windows npx 转换: npx {} -> node {}",
356                pkg,
357                js_entry.display()
358            );
359
360            // 构建新参数
361            let mut new_args = vec![js_entry.to_string_lossy().to_string()];
362            for arg in &args {
363                if arg != "-y" && arg != pkg {
364                    new_args.push(arg.clone());
365                }
366            }
367
368            return (node_exe.to_string_lossy().to_string(), Some(new_args));
369        }
370
371        // 未找到已安装的包,保持原样
372        info!(
373            "[StreamServerBuilder] Windows npx 未找到已安装的包: {},保持原命令",
374            pkg
375        );
376        (command.to_string(), Some(args))
377    }
378
379    /// 查找 npx 包的 node 可执行文件和 JS 入口
380    #[cfg(windows)]
381    fn find_npx_package_entry_windows(
382        &self,
383        package_name: &str,
384    ) -> Option<(std::path::PathBuf, std::path::PathBuf)> {
385        // 查找 node.exe
386        let node_exe = self.find_node_exe_windows()?;
387
388        // 在多个可能的位置查找已安装的包
389        let search_paths = self.get_npx_cache_paths_windows();
390
391        for node_modules_dir in search_paths {
392            let package_dir = node_modules_dir.join(package_name);
393            if !package_dir.exists() {
394                continue;
395            }
396
397            // 读取 package.json 查找入口
398            let package_json_path = package_dir.join("package.json");
399            if let Ok(content) = std::fs::read_to_string(&package_json_path) {
400                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
401                    // 查找 bin 字段
402                    let bin_entry = json.get("bin").and_then(|b| {
403                        if let Some(s) = b.as_str() {
404                            Some(s.to_string())
405                        } else if let Some(obj) = b.as_object() {
406                            obj.get(package_name)
407                                .or_else(|| obj.values().next())
408                                .and_then(|v| v.as_str())
409                                .map(str::to_string)
410                        } else {
411                            None
412                        }
413                    });
414
415                    if let Some(bin_entry) = bin_entry {
416                        let js_entry = package_dir.join(bin_entry);
417                        if js_entry.exists() {
418                            info!(
419                                "[StreamServerBuilder] Windows 找到包入口: {} -> {}",
420                                package_name,
421                                js_entry.display()
422                            );
423                            return Some((node_exe.clone(), js_entry));
424                        }
425                    }
426                }
427            }
428        }
429
430        None
431    }
432
433    /// 查找 node.exe 路径
434    #[cfg(windows)]
435    fn find_node_exe_windows(&self) -> Option<std::path::PathBuf> {
436        use std::path::PathBuf;
437
438        // 1. 检查环境变量
439        if let Ok(node_from_env) = std::env::var("NUWAX_NODE_EXE") {
440            let path = PathBuf::from(node_from_env);
441            if path.exists() {
442                return Some(path);
443            }
444        }
445
446        // 2. 检查应用资源目录
447        if let Ok(exe_path) = std::env::current_exe() {
448            if let Some(exe_dir) = exe_path.parent() {
449                let resource_paths = [
450                    exe_dir
451                        .join("resources")
452                        .join("node")
453                        .join("bin")
454                        .join("node.exe"),
455                    exe_dir
456                        .parent()
457                        .unwrap_or(exe_dir)
458                        .join("resources")
459                        .join("node")
460                        .join("bin")
461                        .join("node.exe"),
462                ];
463
464                for path in resource_paths {
465                    if path.exists() {
466                        return Some(path);
467                    }
468                }
469            }
470        }
471
472        // 3. 在 PATH 中查找
473        which::which("node.exe").ok()
474    }
475
476    /// 获取 npx 缓存搜索路径
477    #[cfg(windows)]
478    fn get_npx_cache_paths_windows(&self) -> Vec<std::path::PathBuf> {
479        use std::path::PathBuf;
480
481        let mut paths = Vec::new();
482
483        // npm 全局 node_modules
484        if let Ok(appdata) = std::env::var("APPDATA") {
485            let appdata_path = PathBuf::from(&appdata);
486
487            // npm 全局目录
488            paths.push(appdata_path.join("npm").join("node_modules"));
489
490            // 应用私有目录
491            paths.push(
492                appdata_path
493                    .join("com.nuwax.agent-tauri-client")
494                    .join("node_modules"),
495            );
496
497            // npx 缓存目录(npm 8.16+)
498            paths.push(appdata_path.join("npm-cache").join("_npx"));
499        }
500
501        // 应用资源目录
502        if let Ok(exe_path) = std::env::current_exe() {
503            if let Some(exe_dir) = exe_path.parent() {
504                let resource_paths = [
505                    exe_dir.join("resources").join("node").join("node_modules"),
506                    exe_dir
507                        .parent()
508                        .unwrap_or(exe_dir)
509                        .join("resources")
510                        .join("node")
511                        .join("node_modules"),
512                ];
513
514                for path in resource_paths {
515                    if path.exists() {
516                        paths.push(path);
517                    }
518                }
519            }
520        }
521
522        paths
523    }
524
525    /// Create the Streamable HTTP server
526    async fn create_server(
527        &self,
528        proxy_handler: ProxyHandler,
529    ) -> Result<(axum::Router, CancellationToken)> {
530        let handler = Arc::new(proxy_handler);
531        let ct = CancellationToken::new();
532
533        if self.server_config.stateful_mode {
534            // Stateful mode with custom session manager
535            let session_manager = ProxyAwareSessionManager::new(handler.clone());
536            let handler_for_service = handler.clone();
537
538            let service = StreamableHttpService::new(
539                move || Ok((*handler_for_service).clone()),
540                session_manager.into(),
541                StreamableHttpServerConfig {
542                    stateful_mode: true,
543                    ..Default::default()
544                },
545            );
546
547            let router = axum::Router::new().fallback_service(service);
548            Ok((router, ct))
549        } else {
550            // Stateless mode with local session manager
551            use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
552
553            let handler_for_service = handler.clone();
554
555            let service = StreamableHttpService::new(
556                move || Ok((*handler_for_service).clone()),
557                LocalSessionManager::default().into(),
558                StreamableHttpServerConfig {
559                    stateful_mode: false,
560                    ..Default::default()
561                },
562            );
563
564            let router = axum::Router::new().fallback_service(service);
565            Ok((router, ct))
566        }
567    }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// The fluent setters record their values in the server configuration.
    #[test]
    fn test_builder_creation() {
        let backend = BackendConfig::Stdio {
            command: "echo".into(),
            args: Some(vec!["hello".into()]),
            env: None,
        };
        let builder = StreamServerBuilder::new(backend).mcp_id("test").stateful(true);
        let config = &builder.server_config;

        assert!(config.mcp_id.is_some());
        assert_eq!(config.mcp_id.as_deref(), Some("test"));
        assert!(config.stateful_mode);
    }

    /// A URL backend is stored as given, headers included.
    #[test]
    fn test_url_backend_config() {
        let headers = HashMap::from([
            ("Authorization".to_string(), "Bearer token123".to_string()),
            ("X-Custom".to_string(), "value".to_string()),
        ]);
        let builder = StreamServerBuilder::new(BackendConfig::Url {
            url: "http://localhost:8080/mcp".into(),
            headers: Some(headers),
        });

        let BackendConfig::Url { url, headers } = &builder.backend_config else {
            panic!("Expected URL backend");
        };
        assert_eq!(url, "http://localhost:8080/mcp");
        assert!(headers.is_some());
    }
}
604            }
605            _ => panic!("Expected URL backend"),
606        }
607    }
608}