Skip to main content

mcp_sse_proxy/
server_builder.rs

1//! SSE Server Builder
2//!
3//! This module provides a high-level Builder API for creating SSE MCP servers.
4//! It encapsulates all rmcp-specific types and provides a simple interface for mcp-proxy.
5
6use std::collections::HashMap;
7use std::time::Duration;
8
9use anyhow::Result;
10use tokio_util::sync::CancellationToken;
11use tracing::{debug, info, warn};
12
13// 进程组管理(跨平台子进程清理)
14use process_wrap::tokio::{KillOnDrop, TokioCommandWrap};
15
16#[cfg(unix)]
17use process_wrap::tokio::ProcessGroup;
18
19#[cfg(windows)]
20use process_wrap::tokio::JobObject;
21
22use rmcp::{
23    ServiceExt,
24    model::{ClientCapabilities, ClientInfo, ProtocolVersion},
25    transport::{
26        SseClientTransport, TokioChildProcess,
27        sse_client::SseClientConfig,
28        sse_server::{SseServer, SseServerConfig},
29        streamable_http_client::{
30            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
31        },
32    },
33};
34
35use crate::{SseHandler, ToolFilter};
36
/// Number of seconds after which establishing a stdio (child process) backend
/// connection is reported as slow.
const STDIO_SLOW_THRESHOLD_SECS: u64 = 30;

/// Number of seconds after which establishing an HTTP-based backend connection
/// (SSE or Streamable HTTP) is reported as slow.
const HTTP_SLOW_THRESHOLD_SECS: u64 = 10;
42
/// Describes the upstream MCP service the proxy connects to.
///
/// Each variant selects one transport between the proxy and the backend.
#[derive(Debug, Clone)]
pub enum BackendConfig {
    /// Run a local command and talk to it over stdio.
    Stdio {
        /// Program to execute (e.g., "npx", "python", etc.).
        command: String,
        /// Command-line arguments passed to the program, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables set on the child process, if any.
        env: Option<HashMap<String, String>>,
    },
    /// Connect to a remote MCP service that speaks the SSE protocol.
    SseUrl {
        /// URL of the MCP SSE service.
        url: String,
        /// Custom HTTP headers sent to the backend (including Authorization).
        headers: Option<HashMap<String, String>>,
    },
    /// Connect to a remote MCP service that speaks the Streamable HTTP
    /// protocol (protocol conversion: Stream backend -> SSE frontend).
    StreamUrl {
        /// URL of the MCP Streamable HTTP service.
        url: String,
        /// Custom HTTP headers sent to the backend (including Authorization).
        headers: Option<HashMap<String, String>>,
    },
}
73
74/// Configuration for the SSE server
75#[derive(Debug, Clone)]
76pub struct SseServerBuilderConfig {
77    /// SSE endpoint path (default: "/sse")
78    pub sse_path: String,
79    /// Message endpoint path (default: "/message")
80    pub post_path: String,
81    /// MCP service identifier for logging
82    pub mcp_id: Option<String>,
83    /// Tool filter configuration
84    pub tool_filter: Option<ToolFilter>,
85    /// Keep-alive interval in seconds (default: 15)
86    pub keep_alive_secs: u64,
87    /// Enable stateful mode with full MCP initialization (default: true)
88    /// When false, uses `with_service_directly` which skips initialization for faster responses
89    pub stateful: bool,
90}
91
92impl Default for SseServerBuilderConfig {
93    fn default() -> Self {
94        Self {
95            sse_path: "/sse".into(),
96            post_path: "/message".into(),
97            mcp_id: None,
98            tool_filter: None,
99            keep_alive_secs: 15,
100            stateful: true,
101        }
102    }
103}
104
105/// Log connection timing with optional performance warning
106///
107/// # Arguments
108///
109/// * `mcp_id` - MCP service identifier
110/// * `backend_type` - Type of backend (e.g., "stdio", "SSE", "Streamable HTTP")
111/// * `total_duration` - Total connection time
112/// * `breakdown` - Optional breakdown of timing components
113/// * `warn_threshold_secs` - Threshold for performance warning
114/// * `warn_message` - Message to show if threshold exceeded
115fn log_connection_timing(
116    mcp_id: &str,
117    backend_type: &str,
118    total_duration: Duration,
119    breakdown: &[(&str, Duration)],
120    warn_threshold_secs: u64,
121    warn_message: &str,
122) {
123    let breakdown_str: Vec<String> = breakdown
124        .iter()
125        .map(|(name, dur)| format!("{}: {:?}", name, dur))
126        .collect();
127
128    info!(
129        "[SseServerBuilder] {} backend connected successfully - MCP ID: {}, total: {:?} ({})",
130        backend_type,
131        mcp_id,
132        total_duration,
133        breakdown_str.join(", ")
134    );
135
136    if total_duration.as_secs() >= warn_threshold_secs {
137        warn!(
138            "[SseServerBuilder] {} 后端连接耗时较长 - MCP ID: {}, 耗时: {:?}, {}",
139            backend_type, mcp_id, total_duration, warn_message
140        );
141    }
142}
143
144/// Builder for creating SSE MCP servers
145///
146/// Provides a fluent API for configuring and building MCP proxy servers.
147///
148/// # Example
149///
150/// ```rust,ignore
151/// use mcp_sse_proxy::server_builder::{SseServerBuilder, BackendConfig};
152///
153/// // Create a server with stdio backend
154/// let (router, ct) = SseServerBuilder::new(BackendConfig::Stdio {
155///     command: "npx".into(),
156///     args: Some(vec!["-y".into(), "@modelcontextprotocol/server-filesystem".into()]),
157///     env: None,
158/// })
159/// .mcp_id("my-server")
160/// .sse_path("/custom/sse")
161/// .post_path("/custom/message")
162/// .stateful(false)  // Disable stateful mode for OneShot services (faster responses)
163/// .build()
164/// .await?;
165/// ```
166pub struct SseServerBuilder {
167    backend_config: BackendConfig,
168    server_config: SseServerBuilderConfig,
169}
170
171impl SseServerBuilder {
172    /// Create a new builder with the given backend configuration
173    pub fn new(backend: BackendConfig) -> Self {
174        Self {
175            backend_config: backend,
176            server_config: SseServerBuilderConfig::default(),
177        }
178    }
179
180    /// Set the SSE endpoint path
181    pub fn sse_path(mut self, path: impl Into<String>) -> Self {
182        self.server_config.sse_path = path.into();
183        self
184    }
185
186    /// Set the message endpoint path
187    pub fn post_path(mut self, path: impl Into<String>) -> Self {
188        self.server_config.post_path = path.into();
189        self
190    }
191
192    /// Set the MCP service identifier
193    ///
194    /// Used for logging and service identification.
195    pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
196        self.server_config.mcp_id = Some(id.into());
197        self
198    }
199
200    /// Set the tool filter configuration
201    pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
202        self.server_config.tool_filter = Some(filter);
203        self
204    }
205
206    /// Set the keep-alive interval in seconds
207    pub fn keep_alive(mut self, secs: u64) -> Self {
208        self.server_config.keep_alive_secs = secs;
209        self
210    }
211
212    /// Set stateful mode (default: true)
213    ///
214    /// When false, uses `with_service_directly` which skips MCP initialization
215    /// for faster responses. This is recommended for OneShot services.
216    pub fn stateful(mut self, stateful: bool) -> Self {
217        self.server_config.stateful = stateful;
218        self
219    }
220
221    /// Build the server and return an axum Router, CancellationToken, and SseHandler
222    ///
223    /// The router can be merged with other axum routers or served directly.
224    /// The CancellationToken can be used to gracefully shut down the service.
225    /// The SseHandler can be used for status checks and management.
226    pub async fn build(self) -> Result<(axum::Router, CancellationToken, SseHandler)> {
227        let mcp_id = self
228            .server_config
229            .mcp_id
230            .clone()
231            .unwrap_or_else(|| "sse-proxy".into());
232
233        // Create client info for connecting to backend
234        let client_info = ClientInfo {
235            protocol_version: ProtocolVersion::V_2024_11_05,
236            capabilities: ClientCapabilities::builder()
237                .enable_experimental()
238                .enable_roots()
239                .enable_roots_list_changed()
240                .enable_sampling()
241                .build(),
242            ..Default::default()
243        };
244
245        // Connect to backend based on configuration
246        let client = match &self.backend_config {
247            BackendConfig::Stdio { command, args, env } => {
248                self.connect_stdio(command, args, env, &client_info).await?
249            }
250            BackendConfig::SseUrl { url, headers } => {
251                self.connect_sse_url(url, headers, &client_info).await?
252            }
253            BackendConfig::StreamUrl { url, headers } => {
254                self.connect_stream_url(url, headers, &client_info).await?
255            }
256        };
257
258        // Create SSE handler
259        let sse_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
260            SseHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
261        } else {
262            SseHandler::with_mcp_id(client, mcp_id.clone())
263        };
264
265        // Clone handler before creating server (create_server uses sse_handler.clone() internally)
266        let handler_for_return = sse_handler.clone();
267
268        // Create SSE server
269        let (router, ct) = self.create_server(sse_handler)?;
270
271        info!(
272            "[SseServerBuilder] Server created - mcp_id: {}, sse_path: {}, post_path: {}",
273            mcp_id, self.server_config.sse_path, self.server_config.post_path
274        );
275
276        Ok((router, ct, handler_for_return))
277    }
278
279    /// Connect to a stdio backend (child process)
280    async fn connect_stdio(
281        &self,
282        command: &str,
283        args: &Option<Vec<String>>,
284        env: &Option<HashMap<String, String>>,
285        client_info: &ClientInfo,
286    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
287        use std::time::Instant;
288
289        let start_time = Instant::now();
290        let mcp_id = self
291            .server_config
292            .mcp_id
293            .clone()
294            .unwrap_or_else(|| "unknown".into());
295
296        // 使用 process-wrap 创建子进程命令(跨平台进程清理)
297        // process-wrap 会自动处理进程组(Unix)或 Job Object(Windows)
298        // 并且在 Drop 时自动清理子进程树
299        // 子进程默认继承父进程的所有环境变量
300        let mut wrapped_cmd = TokioCommandWrap::with_new(command, |cmd| {
301            if let Some(cmd_args) = args {
302                cmd.args(cmd_args);
303            }
304            // 设置 MCP JSON 配置中的环境变量(会覆盖继承的同名变量)
305            if let Some(env_vars) = env {
306                for (k, v) in env_vars {
307                    cmd.env(k, v);
308                }
309            }
310        });
311
312        // Unix: 创建进程组,支持 killpg 清理整个进程树
313        #[cfg(unix)]
314        wrapped_cmd.wrap(ProcessGroup::leader());
315        // Windows: 使用 CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP 隐藏控制台窗口
316        #[cfg(windows)]
317        {
318            use process_wrap::tokio::CreationFlags;
319            use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
320            wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
321            wrapped_cmd.wrap(JobObject);
322        }
323
324        // 所有平台: Drop 时自动清理进程
325        wrapped_cmd.wrap(KillOnDrop);
326
327        info!(
328            "[SseServerBuilder] Starting child process - MCP ID: {}, command: {}, args: {:?}",
329            mcp_id,
330            command,
331            args.as_ref().unwrap_or(&vec![])
332        );
333
334        // 诊断日志:子进程关键环境变量
335        mcp_common::diagnostic::log_stdio_spawn_context("SseServerBuilder", &mcp_id, env);
336
337        let process_start = Instant::now();
338        // MCP 服务通过 stdin/stdout 进行 JSON-RPC 通信,必须使用 piped(默认行为)
339        // 使用 builder 模式捕获 stderr,便于诊断子 MCP 服务初始化失败
340        let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
341            .stderr(std::process::Stdio::piped())
342            .spawn()
343            .map_err(|e| {
344                anyhow::anyhow!(
345                    "{}",
346                    mcp_common::diagnostic::format_spawn_error(&mcp_id, command, args, e)
347                )
348            })?;
349
350        // 启动 stderr 日志读取任务
351        if let Some(stderr_pipe) = child_stderr {
352            mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.clone());
353        }
354
355        let process_duration = process_start.elapsed();
356
357        debug!(
358            "[SseServerBuilder] Child process spawned - MCP ID: {}, spawn time: {:?}",
359            mcp_id, process_duration
360        );
361
362        let serve_start = Instant::now();
363        let client = client_info.clone().serve(tokio_process).await?;
364        let serve_duration = serve_start.elapsed();
365        let total_duration = start_time.elapsed();
366
367        let warn_msg = "建议的优化方案: \
368            1) 检查网络连接速度 (npm 包下载) \
369            2) 配置国内 npm 镜像 (如淘宝镜像: npm config set registry https://registry.npmmirror.com) \
370            3) 预热服务 (启动 mcp-proxy 时预先加载常用服务) \
371            4) 检查命令参数是否正确";
372
373        log_connection_timing(
374            &mcp_id,
375            "Stdio",
376            total_duration,
377            &[("spawn", process_duration), ("serve", serve_duration)],
378            STDIO_SLOW_THRESHOLD_SECS,
379            warn_msg,
380        );
381
382        Ok(client)
383    }
384
385    /// Connect to an SSE URL backend
386    async fn connect_sse_url(
387        &self,
388        url: &str,
389        headers: &Option<HashMap<String, String>>,
390        client_info: &ClientInfo,
391    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
392        use std::time::Instant;
393
394        let start_time = Instant::now();
395        let mcp_id = self
396            .server_config
397            .mcp_id
398            .clone()
399            .unwrap_or_else(|| "unknown".into());
400
401        info!(
402            "[SseServerBuilder] Connecting to SSE URL backend - MCP ID: {}, URL: {}",
403            mcp_id, url
404        );
405
406        // Build HTTP client with custom headers
407        let mut req_headers = reqwest::header::HeaderMap::new();
408
409        if let Some(config_headers) = headers {
410            for (key, value) in config_headers {
411                req_headers.insert(
412                    reqwest::header::HeaderName::try_from(key)
413                        .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
414                    value.parse().map_err(|e| {
415                        anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
416                    })?,
417                );
418            }
419        }
420
421        let http_client = reqwest::Client::builder()
422            .default_headers(req_headers)
423            .build()
424            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
425
426        // Create SSE client configuration
427        let sse_config = SseClientConfig {
428            sse_endpoint: url.to_string().into(),
429            ..Default::default()
430        };
431
432        let transport_start = Instant::now();
433        let sse_transport = SseClientTransport::start_with_client(http_client, sse_config).await?;
434        let transport_duration = transport_start.elapsed();
435
436        let serve_start = Instant::now();
437        let client = client_info.clone().serve(sse_transport).await?;
438        let serve_duration = serve_start.elapsed();
439        let total_duration = start_time.elapsed();
440
441        log_connection_timing(
442            &mcp_id,
443            "SSE",
444            total_duration,
445            &[("transport", transport_duration), ("serve", serve_duration)],
446            HTTP_SLOW_THRESHOLD_SECS,
447            "建议: 检查网络连接和后端服务状态",
448        );
449
450        Ok(client)
451    }
452
453    /// Connect to a Streamable HTTP URL backend
454    async fn connect_stream_url(
455        &self,
456        url: &str,
457        headers: &Option<HashMap<String, String>>,
458        client_info: &ClientInfo,
459    ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
460        use std::time::Instant;
461
462        let start_time = Instant::now();
463        let mcp_id = self
464            .server_config
465            .mcp_id
466            .clone()
467            .unwrap_or_else(|| "unknown".into());
468
469        info!(
470            "[SseServerBuilder] Connecting to Streamable HTTP URL backend - MCP ID: {}, URL: {}",
471            mcp_id, url
472        );
473
474        // Build HTTP client with custom headers (excluding Authorization)
475        let mut req_headers = reqwest::header::HeaderMap::new();
476        let mut auth_header: Option<String> = None;
477
478        if let Some(config_headers) = headers {
479            for (key, value) in config_headers {
480                // Authorization header is handled separately by rmcp
481                if key.eq_ignore_ascii_case("Authorization") {
482                    auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
483                    continue;
484                }
485
486                req_headers.insert(
487                    reqwest::header::HeaderName::try_from(key)
488                        .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
489                    value.parse().map_err(|e| {
490                        anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
491                    })?,
492                );
493            }
494        }
495
496        let http_client = reqwest::Client::builder()
497            .default_headers(req_headers)
498            .build()
499            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
500
501        // Create transport configuration
502        let config = StreamableHttpClientTransportConfig {
503            uri: url.to_string().into(),
504            auth_header,
505            ..Default::default()
506        };
507
508        let serve_start = Instant::now();
509        let transport = StreamableHttpClientTransport::with_client(http_client, config);
510        let client = client_info.clone().serve(transport).await?;
511        let serve_duration = serve_start.elapsed();
512        let total_duration = start_time.elapsed();
513
514        log_connection_timing(
515            &mcp_id,
516            "Streamable HTTP",
517            total_duration,
518            &[("serve", serve_duration)],
519            HTTP_SLOW_THRESHOLD_SECS,
520            "建议: 检查网络连接和后端服务状态",
521        );
522
523        Ok(client)
524    }
525
526    /// Create the SSE server
527    fn create_server(&self, sse_handler: SseHandler) -> Result<(axum::Router, CancellationToken)> {
528        // SSE server uses bind address 0.0.0.0:0 since we're returning a router
529        // The actual binding will be done by the caller
530        let config = SseServerConfig {
531            bind: "0.0.0.0:0".parse()?,
532            sse_path: self.server_config.sse_path.clone(),
533            post_path: self.server_config.post_path.clone(),
534            ct: CancellationToken::new(),
535            sse_keep_alive: Some(std::time::Duration::from_secs(
536                self.server_config.keep_alive_secs,
537            )),
538        };
539
540        let (sse_server, router) = SseServer::new(config);
541
542        // Use with_service_directly for non-stateful mode (OneShot services)
543        // This skips MCP initialization for faster responses
544        let ct = if self.server_config.stateful {
545            sse_server.with_service(move || sse_handler.clone())
546        } else {
547            sse_server.with_service_directly(move || sse_handler.clone())
548        };
549
550        Ok((router, ct))
551    }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder over a trivial `echo` stdio backend, shared by the builder tests.
    fn echo_builder(args: Option<Vec<String>>) -> SseServerBuilder {
        SseServerBuilder::new(BackendConfig::Stdio {
            command: "echo".into(),
            args,
            env: None,
        })
    }

    /// The fluent setters store the MCP id and both endpoint paths.
    #[test]
    fn test_builder_creation() {
        let builder = echo_builder(Some(vec!["hello".into()]))
            .mcp_id("test")
            .sse_path("/custom/sse")
            .post_path("/custom/message");
        let cfg = &builder.server_config;

        assert!(cfg.mcp_id.is_some());
        assert_eq!(cfg.mcp_id.as_deref(), Some("test"));
        assert_eq!(cfg.sse_path, "/custom/sse");
        assert_eq!(cfg.post_path, "/custom/message");
    }

    /// `Default` yields the documented endpoint paths, keep-alive and mode.
    #[test]
    fn test_default_config() {
        let SseServerBuilderConfig {
            sse_path,
            post_path,
            keep_alive_secs,
            stateful,
            ..
        } = SseServerBuilderConfig::default();

        assert_eq!(sse_path, "/sse");
        assert_eq!(post_path, "/message");
        assert_eq!(keep_alive_secs, 15);
        assert!(
            stateful,
            "default stateful should be true for backward compatibility"
        );
    }

    /// A fresh builder starts in stateful mode.
    #[test]
    fn test_stateful_flag_default() {
        let builder = echo_builder(None);
        assert!(
            builder.server_config.stateful,
            "stateful should default to true"
        );
    }

    /// `stateful(false)` switches stateful mode off.
    #[test]
    fn test_stateful_flag_disabled() {
        let builder = echo_builder(None).stateful(false);
        assert!(
            !builder.server_config.stateful,
            "stateful should be false when set"
        );
    }

    /// `stateful(true)` keeps stateful mode on.
    #[test]
    fn test_stateful_flag_enabled() {
        let builder = echo_builder(None).stateful(true);
        assert!(
            builder.server_config.stateful,
            "stateful should be true when set"
        );
    }

    /// Pins the slow-connection thresholds.
    #[test]
    fn test_timing_constants() {
        assert_eq!(STDIO_SLOW_THRESHOLD_SECS, 30);
        assert_eq!(HTTP_SLOW_THRESHOLD_SECS, 10);
    }

    /// Smoke test: logging with a two-entry breakdown must not panic.
    #[test]
    fn test_log_connection_timing_format() {
        let steps = [
            ("step1", Duration::from_millis(500)),
            ("step2", Duration::from_millis(1000)),
        ];
        log_connection_timing(
            "test-mcp",
            "TestBackend",
            Duration::from_millis(1500),
            &steps,
            10,
            "Test warning message",
        );
        // Reaching this point means the call completed without panicking.
    }

    /// Smoke test: logging with an empty breakdown must not panic.
    #[test]
    fn test_log_connection_timing_no_breakdown() {
        log_connection_timing(
            "test-mcp",
            "TestBackend",
            Duration::from_millis(500),
            &[],
            10,
            "Test warning message",
        );
    }
}