Skip to main content

mcp_streamable_proxy/
server_builder.rs

//! Streamable HTTP server builder.
//!
//! This module provides a high-level builder API for creating Streamable HTTP MCP servers.
//! It hides the rmcp-specific types behind a simple interface for mcp-proxy.
5
6use std::collections::HashMap;
7use std::process::Stdio;
8use std::sync::Arc;
9
10use anyhow::Result;
11use process_wrap::tokio::{CommandWrap, KillOnDrop};
12use tokio_util::sync::CancellationToken;
13use tracing::{info, warn};
14
15use rmcp::{
16    ServiceExt,
17    model::{ClientCapabilities, ClientInfo},
18    transport::{
19        TokioChildProcess,
20        streamable_http_client::{
21            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
22        },
23        streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
24    },
25};
26
27// Unix 进程组支持
28#[cfg(unix)]
29use process_wrap::tokio::ProcessGroup;
30
31// Windows 静默运行支持
32#[cfg(windows)]
33use process_wrap::tokio::{CreationFlags, JobObject};
34
35use crate::{ProxyAwareSessionManager, ProxyHandler, ToolFilter};
36pub use mcp_common::ToolFilter as CommonToolFilter;
37
/// Upstream (backend) MCP service the proxy forwards to.
///
/// Selects whether the proxy spawns a local process and talks to it over
/// stdin/stdout, or connects to an already-running remote service over HTTP.
#[derive(Debug, Clone)]
pub enum BackendConfig {
    /// Spawn a local command and speak MCP over its stdio.
    Stdio {
        /// Program to launch, e.g. `"npx"` or `"python"`.
        command: String,
        /// Command-line arguments passed to the program, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables for the child process, if any.
        env: Option<HashMap<String, String>>,
    },
    /// Connect to a remote MCP endpoint.
    Url {
        /// Endpoint URL of the remote MCP service.
        url: String,
        /// Additional HTTP headers to send, including an optional `Authorization`.
        headers: Option<HashMap<String, String>>,
    },
}
60
61/// Configuration for the Streamable HTTP server
62#[derive(Debug, Clone, Default)]
63pub struct StreamServerConfig {
64    /// Enable stateful mode with session management
65    pub stateful_mode: bool,
66    /// MCP service identifier for logging
67    pub mcp_id: Option<String>,
68    /// Tool filter configuration
69    pub tool_filter: Option<ToolFilter>,
70}
71
72/// Builder for creating Streamable HTTP MCP servers
73///
74/// Provides a fluent API for configuring and building MCP proxy servers.
75///
76/// # Example
77///
78/// ```rust,ignore
79/// use mcp_streamable_proxy::server_builder::{StreamServerBuilder, BackendConfig};
80///
81/// // Create a server with stdio backend
82/// let (router, ct) = StreamServerBuilder::new(BackendConfig::Stdio {
83///     command: "npx".into(),
84///     args: Some(vec!["-y".into(), "@modelcontextprotocol/server-filesystem".into()]),
85///     env: None,
86/// })
87/// .mcp_id("my-server")
88/// .stateful(false)
89/// .build()
90/// .await?;
91/// ```
92pub struct StreamServerBuilder {
93    backend_config: BackendConfig,
94    server_config: StreamServerConfig,
95}
96
97impl StreamServerBuilder {
98    /// Create a new builder with the given backend configuration
99    pub fn new(backend: BackendConfig) -> Self {
100        Self {
101            backend_config: backend,
102            server_config: StreamServerConfig::default(),
103        }
104    }
105
106    /// Set whether to enable stateful mode
107    ///
108    /// Stateful mode enables session management and server-side push.
109    pub fn stateful(mut self, enabled: bool) -> Self {
110        self.server_config.stateful_mode = enabled;
111        self
112    }
113
114    /// Set the MCP service identifier
115    ///
116    /// Used for logging and service identification.
117    pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
118        self.server_config.mcp_id = Some(id.into());
119        self
120    }
121
122    /// Set the tool filter configuration
123    pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
124        self.server_config.tool_filter = Some(filter);
125        self
126    }
127
128    /// Build the server and return an axum Router, CancellationToken, and ProxyHandler
129    ///
130    /// The router can be merged with other axum routers or served directly.
131    /// The CancellationToken can be used to gracefully shut down the service.
132    /// The ProxyHandler can be used for status checks and management.
133    pub async fn build(self) -> Result<(axum::Router, CancellationToken, ProxyHandler)> {
134        let mcp_id = self
135            .server_config
136            .mcp_id
137            .clone()
138            .unwrap_or_else(|| "stream-proxy".into());
139
140        // Create client info for connecting to backend
141        let capabilities = ClientCapabilities::builder()
142            .enable_experimental()
143            .enable_roots()
144            .enable_roots_list_changed()
145            .enable_sampling()
146            .build();
147        let client_info = ClientInfo::new(
148            capabilities,
149            rmcp::model::Implementation::new("mcp-streamable-proxy", env!("CARGO_PKG_VERSION")),
150        );
151
152        // Connect to backend based on configuration
153        let client = match &self.backend_config {
154            BackendConfig::Stdio { command, args, env } => {
155                self.connect_stdio(command, args, env, &client_info).await?
156            }
157            BackendConfig::Url { url, headers } => {
158                self.connect_url(url, headers, &client_info).await?
159            }
160        };
161
162        // Create proxy handler
163        let proxy_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
164            ProxyHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
165        } else {
166            ProxyHandler::with_mcp_id(client, mcp_id.clone())
167        };
168
169        // Clone handler before creating server
170        let handler_for_return = proxy_handler.clone();
171
172        // Create server with configured stateful mode
173        let (router, ct) = self.create_server(proxy_handler).await?;
174
175        info!(
176            "[StreamServerBuilder] Server created - mcp_id: {}, stateful: {}",
177            mcp_id, self.server_config.stateful_mode
178        );
179
180        Ok((router, ct, handler_for_return))
181    }
182
183    /// Connect to a stdio backend (child process)
184    async fn connect_stdio(
185        &self,
186        command: &str,
187        args: &Option<Vec<String>>,
188        env: &Option<HashMap<String, String>>,
189        client_info: &ClientInfo,
190    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
191        // Windows 上预处理 npx 命令,避免 .cmd 文件导致窗口闪烁
192        #[cfg(windows)]
193        let (command, args) = self.preprocess_npx_command_windows(command, args.clone());
194        #[cfg(not(windows))]
195        let args = args.clone();
196
197        // 使用 process-wrap 创建子进程命令(跨平台进程清理)
198        // process-wrap 会自动处理进程组(Unix)或 Job Object(Windows)
199        // 并且在 Drop 时自动清理子进程树
200        let mut wrapped_cmd = CommandWrap::with_new(&command, |cmd| {
201            let (final_path, filtered_env) = mcp_common::prepare_stdio_env(env);
202            if let Some(path) = final_path {
203                cmd.env("PATH", path);
204            } else {
205                warn!("[StreamServerBuilder] PATH not available from parent process or config");
206            }
207
208            if let Some(cmd_args) = &args {
209                cmd.args(cmd_args);
210            }
211
212            if let Some(vars) = filtered_env {
213                for (k, v) in vars {
214                    cmd.env(k, v);
215                }
216            }
217        });
218
219        // Unix: 创建进程组,支持 killpg 清理整个进程树
220        #[cfg(unix)]
221        wrapped_cmd.wrap(ProcessGroup::leader());
222
223        // Windows: 使用 CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP 隐藏控制台窗口
224        #[cfg(windows)]
225        {
226            use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
227            info!(
228                "[StreamServerBuilder] Setting CreationFlags: CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP"
229            );
230            wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
231            wrapped_cmd.wrap(JobObject);
232        }
233
234        // 所有平台: Drop 时自动清理进程
235        wrapped_cmd.wrap(KillOnDrop);
236
237        info!(
238            "[StreamServerBuilder] Starting child process - command: {}, args: {:?}",
239            command,
240            args.as_ref().unwrap_or(&vec![])
241        );
242
243        let mcp_id = self.server_config.mcp_id.as_deref().unwrap_or("unknown");
244
245        // 诊断日志:子进程关键环境变量
246        mcp_common::diagnostic::log_stdio_spawn_context("StreamServerBuilder", mcp_id, env);
247
248        // MCP 服务通过 stdin/stdout 进行 JSON-RPC 通信,必须使用 piped(默认行为)
249        // 使用 builder 模式捕获 stderr,便于诊断子 MCP 服务初始化失败
250        let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
251            .stderr(Stdio::piped())
252            .spawn()
253            .map_err(|e| {
254                anyhow::anyhow!(
255                    "{}",
256                    mcp_common::diagnostic::format_spawn_error(mcp_id, &command, &args, e)
257                )
258            })?;
259
260        // 启动 stderr 日志读取任务
261        if let Some(stderr_pipe) = child_stderr {
262            mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.to_string());
263        }
264
265        let client = client_info.clone().serve(tokio_process).await?;
266
267        info!("[StreamServerBuilder] Child process connected successfully");
268        Ok(client)
269    }
270
271    /// Connect to a URL backend (remote MCP service)
272    async fn connect_url(
273        &self,
274        url: &str,
275        headers: &Option<HashMap<String, String>>,
276        client_info: &ClientInfo,
277    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
278        info!("[StreamServerBuilder] Connecting to URL backend: {}", url);
279
280        // Build HTTP client with custom headers (excluding Authorization)
281        let mut req_headers = reqwest::header::HeaderMap::new();
282        let mut auth_header: Option<String> = None;
283
284        if let Some(config_headers) = headers {
285            for (key, value) in config_headers {
286                // Authorization header is handled separately by rmcp
287                if key.eq_ignore_ascii_case("Authorization") {
288                    auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
289                    continue;
290                }
291
292                req_headers.insert(
293                    reqwest::header::HeaderName::try_from(key)
294                        .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
295                    value.parse().map_err(|e| {
296                        anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
297                    })?,
298                );
299            }
300        }
301
302        let http_client = reqwest::Client::builder()
303            .default_headers(req_headers)
304            .build()
305            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
306
307        // Create transport configuration
308        let config = StreamableHttpClientTransportConfig {
309            uri: url.to_string().into(),
310            auth_header,
311            ..Default::default()
312        };
313
314        let transport = StreamableHttpClientTransport::with_client(http_client, config);
315        let client = client_info.clone().serve(transport).await?;
316
317        info!("[StreamServerBuilder] URL backend connected successfully");
318        Ok(client)
319    }
320
321    /// Windows 上预处理 npx 命令
322    ///
323    /// 将 `npx -y package@version` 转换为直接的 `node` 命令,
324    /// 避免使用 .cmd 批处理文件导致窗口闪烁。
325    #[cfg(windows)]
326    fn preprocess_npx_command_windows(
327        &self,
328        command: &str,
329        args: Option<Vec<String>>,
330    ) -> (String, Option<Vec<String>>) {
331        // 检测 npx 命令
332        let is_npx = command == "npx"
333            || command == "npx.cmd"
334            || command.ends_with("/npx")
335            || command.ends_with("\\npx")
336            || command.ends_with("/npx.cmd")
337            || command.ends_with("\\npx.cmd");
338
339        if !is_npx {
340            return (command.to_string(), args);
341        }
342
343        let args = match args {
344            Some(a) => a,
345            None => return (command.to_string(), None),
346        };
347
348        // 提取包名(跳过 -y 标志)
349        let package_spec = args.iter().find(|s| !s.starts_with('-') && s.contains('@'));
350
351        let Some(pkg) = package_spec else {
352            return (command.to_string(), Some(args));
353        };
354
355        // 解析包名(去掉版本号)
356        let package_name = pkg.split('@').next().unwrap_or(pkg);
357
358        // 尝试找到已安装的包
359        if let Some((node_exe, js_entry)) = self.find_npx_package_entry_windows(package_name) {
360            info!(
361                "[StreamServerBuilder] Windows npx 转换: npx {} -> node {}",
362                pkg,
363                js_entry.display()
364            );
365
366            // 构建新参数
367            let mut new_args = vec![js_entry.to_string_lossy().to_string()];
368            for arg in &args {
369                if arg != "-y" && arg != pkg {
370                    new_args.push(arg.clone());
371                }
372            }
373
374            return (node_exe.to_string_lossy().to_string(), Some(new_args));
375        }
376
377        // 未找到已安装的包,保持原样
378        info!(
379            "[StreamServerBuilder] Windows npx 未找到已安装的包: {},保持原命令",
380            pkg
381        );
382        (command.to_string(), Some(args))
383    }
384
385    /// 查找 npx 包的 node 可执行文件和 JS 入口
386    #[cfg(windows)]
387    fn find_npx_package_entry_windows(
388        &self,
389        package_name: &str,
390    ) -> Option<(std::path::PathBuf, std::path::PathBuf)> {
391        // 查找 node.exe
392        let node_exe = self.find_node_exe_windows()?;
393
394        // 在多个可能的位置查找已安装的包
395        let search_paths = self.get_npx_cache_paths_windows();
396
397        for node_modules_dir in search_paths {
398            let package_dir = node_modules_dir.join(package_name);
399            if !package_dir.exists() {
400                continue;
401            }
402
403            // 读取 package.json 查找入口
404            let package_json_path = package_dir.join("package.json");
405            if let Ok(content) = std::fs::read_to_string(&package_json_path) {
406                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
407                    // 查找 bin 字段
408                    let bin_entry = json.get("bin").and_then(|b| {
409                        if let Some(s) = b.as_str() {
410                            Some(s.to_string())
411                        } else if let Some(obj) = b.as_object() {
412                            obj.get(package_name)
413                                .or_else(|| obj.values().next())
414                                .and_then(|v| v.as_str())
415                                .map(str::to_string)
416                        } else {
417                            None
418                        }
419                    });
420
421                    if let Some(bin_entry) = bin_entry {
422                        let js_entry = package_dir.join(bin_entry);
423                        if js_entry.exists() {
424                            info!(
425                                "[StreamServerBuilder] Windows 找到包入口: {} -> {}",
426                                package_name,
427                                js_entry.display()
428                            );
429                            return Some((node_exe.clone(), js_entry));
430                        }
431                    }
432                }
433            }
434        }
435
436        None
437    }
438
439    /// 查找 node.exe 路径
440    #[cfg(windows)]
441    fn find_node_exe_windows(&self) -> Option<std::path::PathBuf> {
442        use std::path::PathBuf;
443
444        // 1. 检查环境变量
445        if let Ok(node_from_env) = std::env::var("NUWAX_NODE_EXE") {
446            let path = PathBuf::from(node_from_env);
447            if path.exists() {
448                return Some(path);
449            }
450        }
451
452        // 2. 检查应用资源目录
453        if let Ok(exe_path) = std::env::current_exe() {
454            if let Some(exe_dir) = exe_path.parent() {
455                let resource_paths = [
456                    exe_dir
457                        .join("resources")
458                        .join("node")
459                        .join("bin")
460                        .join("node.exe"),
461                    exe_dir
462                        .parent()
463                        .unwrap_or(exe_dir)
464                        .join("resources")
465                        .join("node")
466                        .join("bin")
467                        .join("node.exe"),
468                ];
469
470                for path in resource_paths {
471                    if path.exists() {
472                        return Some(path);
473                    }
474                }
475            }
476        }
477
478        // 3. 在 PATH 中查找
479        which::which("node.exe").ok()
480    }
481
482    /// 获取 npx 缓存搜索路径
483    #[cfg(windows)]
484    fn get_npx_cache_paths_windows(&self) -> Vec<std::path::PathBuf> {
485        use std::path::PathBuf;
486
487        let mut paths = Vec::new();
488
489        // npm 全局 node_modules
490        if let Ok(appdata) = std::env::var("APPDATA") {
491            let appdata_path = PathBuf::from(&appdata);
492
493            // npm 全局目录
494            paths.push(appdata_path.join("npm").join("node_modules"));
495
496            // 应用私有目录
497            paths.push(
498                appdata_path
499                    .join("com.nuwax.agent-tauri-client")
500                    .join("node_modules"),
501            );
502
503            // npx 缓存目录(npm 8.16+)
504            paths.push(appdata_path.join("npm-cache").join("_npx"));
505        }
506
507        // 应用资源目录
508        if let Ok(exe_path) = std::env::current_exe() {
509            if let Some(exe_dir) = exe_path.parent() {
510                let resource_paths = [
511                    exe_dir.join("resources").join("node").join("node_modules"),
512                    exe_dir
513                        .parent()
514                        .unwrap_or(exe_dir)
515                        .join("resources")
516                        .join("node")
517                        .join("node_modules"),
518                ];
519
520                for path in resource_paths {
521                    if path.exists() {
522                        paths.push(path);
523                    }
524                }
525            }
526        }
527
528        paths
529    }
530
531    /// Create the Streamable HTTP server
532    async fn create_server(
533        &self,
534        proxy_handler: ProxyHandler,
535    ) -> Result<(axum::Router, CancellationToken)> {
536        let handler = Arc::new(proxy_handler);
537        let ct = CancellationToken::new();
538
539        if self.server_config.stateful_mode {
540            // Stateful mode with custom session manager
541            let session_manager = ProxyAwareSessionManager::new(handler.clone());
542            let handler_for_service = handler.clone();
543
544            let service = StreamableHttpService::new(
545                move || Ok((*handler_for_service).clone()),
546                session_manager.into(),
547                StreamableHttpServerConfig {
548                    stateful_mode: true,
549                    ..Default::default()
550                },
551            );
552
553            let router = axum::Router::new().fallback_service(service);
554            Ok((router, ct))
555        } else {
556            // Stateless mode with local session manager
557            use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
558
559            let handler_for_service = handler.clone();
560
561            let service = StreamableHttpService::new(
562                move || Ok((*handler_for_service).clone()),
563                LocalSessionManager::default().into(),
564                StreamableHttpServerConfig {
565                    stateful_mode: false,
566                    ..Default::default()
567                },
568            );
569
570            let router = axum::Router::new().fallback_service(service);
571            Ok((router, ct))
572        }
573    }
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// The fluent setters record the service identifier and the stateful flag.
    #[test]
    fn test_builder_creation() {
        let backend = BackendConfig::Stdio {
            command: "echo".into(),
            args: Some(vec!["hello".into()]),
            env: None,
        };
        let builder = StreamServerBuilder::new(backend).mcp_id("test").stateful(true);
        let config = &builder.server_config;

        assert!(config.mcp_id.is_some());
        assert_eq!(config.mcp_id.as_deref(), Some("test"));
        assert!(config.stateful_mode);
    }

    /// A URL backend keeps its URL and header map as given.
    #[test]
    fn test_url_backend_config() {
        let headers: HashMap<String, String> = [
            ("Authorization", "Bearer token123"),
            ("X-Custom", "value"),
        ]
        .into_iter()
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect();

        let builder = StreamServerBuilder::new(BackendConfig::Url {
            url: "http://localhost:8080/mcp".into(),
            headers: Some(headers),
        });

        let BackendConfig::Url { url, headers } = &builder.backend_config else {
            panic!("Expected URL backend");
        };
        assert_eq!(url, "http://localhost:8080/mcp");
        assert!(headers.is_some());
    }
}