1use std::collections::HashMap;
7use std::time::Duration;
8
9use anyhow::Result;
10use tokio_util::sync::CancellationToken;
11use tracing::{debug, info, warn};
12
13use process_wrap::tokio::{KillOnDrop, TokioCommandWrap};
15
16#[cfg(unix)]
17use process_wrap::tokio::ProcessGroup;
18
19#[cfg(windows)]
20use process_wrap::tokio::JobObject;
21
22use rmcp::{
23 ServiceExt,
24 model::{ClientCapabilities, ClientInfo, ProtocolVersion},
25 transport::{
26 SseClientTransport, TokioChildProcess,
27 sse_client::SseClientConfig,
28 sse_server::{SseServer, SseServerConfig},
29 streamable_http_client::{
30 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
31 },
32 },
33};
34
35use crate::{SseHandler, ToolFilter};
36
/// Stdio connect time (spawn + MCP handshake), in whole seconds, at or above which
/// `log_connection_timing` emits a slow-connection warning.
const STDIO_SLOW_THRESHOLD_SECS: u64 = 30;

/// Slow-connection warning threshold, in whole seconds, for the SSE and
/// Streamable HTTP backends.
const HTTP_SLOW_THRESHOLD_SECS: u64 = 10;
42
/// Describes how the proxy reaches the backend MCP server it fronts.
#[derive(Debug, Clone)]
pub enum BackendConfig {
    /// Launch a local child process and talk to it through `TokioChildProcess`.
    Stdio {
        /// Program to execute.
        command: String,
        /// Command-line arguments for `command`, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables set on the child process, if any.
        env: Option<HashMap<String, String>>,
    },
    /// Connect to a remote backend with the SSE client transport.
    SseUrl {
        /// SSE endpoint of the backend.
        url: String,
        /// HTTP headers applied to the backend HTTP client, if any.
        headers: Option<HashMap<String, String>>,
    },
    /// Connect to a remote backend with the Streamable HTTP client transport.
    StreamUrl {
        /// Streamable HTTP endpoint of the backend.
        url: String,
        /// HTTP headers applied to the backend HTTP client, if any.
        headers: Option<HashMap<String, String>>,
    },
}
73
74#[derive(Debug, Clone)]
76pub struct SseServerBuilderConfig {
77 pub sse_path: String,
79 pub post_path: String,
81 pub mcp_id: Option<String>,
83 pub tool_filter: Option<ToolFilter>,
85 pub keep_alive_secs: u64,
87 pub stateful: bool,
90}
91
92impl Default for SseServerBuilderConfig {
93 fn default() -> Self {
94 Self {
95 sse_path: "/sse".into(),
96 post_path: "/message".into(),
97 mcp_id: None,
98 tool_filter: None,
99 keep_alive_secs: 15,
100 stateful: true,
101 }
102 }
103}
104
105fn log_connection_timing(
116 mcp_id: &str,
117 backend_type: &str,
118 total_duration: Duration,
119 breakdown: &[(&str, Duration)],
120 warn_threshold_secs: u64,
121 warn_message: &str,
122) {
123 let breakdown_str: Vec<String> = breakdown
124 .iter()
125 .map(|(name, dur)| format!("{}: {:?}", name, dur))
126 .collect();
127
128 info!(
129 "[SseServerBuilder] {} backend connected successfully - MCP ID: {}, total: {:?} ({})",
130 backend_type,
131 mcp_id,
132 total_duration,
133 breakdown_str.join(", ")
134 );
135
136 if total_duration.as_secs() >= warn_threshold_secs {
137 warn!(
138 "[SseServerBuilder] {} 后端连接耗时较长 - MCP ID: {}, 耗时: {:?}, {}",
139 backend_type, mcp_id, total_duration, warn_message
140 );
141 }
142}
143
144pub struct SseServerBuilder {
167 backend_config: BackendConfig,
168 server_config: SseServerBuilderConfig,
169}
170
171impl SseServerBuilder {
172 pub fn new(backend: BackendConfig) -> Self {
174 Self {
175 backend_config: backend,
176 server_config: SseServerBuilderConfig::default(),
177 }
178 }
179
180 pub fn sse_path(mut self, path: impl Into<String>) -> Self {
182 self.server_config.sse_path = path.into();
183 self
184 }
185
186 pub fn post_path(mut self, path: impl Into<String>) -> Self {
188 self.server_config.post_path = path.into();
189 self
190 }
191
192 pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
196 self.server_config.mcp_id = Some(id.into());
197 self
198 }
199
200 pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
202 self.server_config.tool_filter = Some(filter);
203 self
204 }
205
206 pub fn keep_alive(mut self, secs: u64) -> Self {
208 self.server_config.keep_alive_secs = secs;
209 self
210 }
211
212 pub fn stateful(mut self, stateful: bool) -> Self {
217 self.server_config.stateful = stateful;
218 self
219 }
220
221 pub async fn build(self) -> Result<(axum::Router, CancellationToken, SseHandler)> {
227 let mcp_id = self
228 .server_config
229 .mcp_id
230 .clone()
231 .unwrap_or_else(|| "sse-proxy".into());
232
233 let client_info = ClientInfo {
235 protocol_version: ProtocolVersion::V_2024_11_05,
236 capabilities: ClientCapabilities::builder()
237 .enable_experimental()
238 .enable_roots()
239 .enable_roots_list_changed()
240 .enable_sampling()
241 .build(),
242 ..Default::default()
243 };
244
245 let client = match &self.backend_config {
247 BackendConfig::Stdio { command, args, env } => {
248 self.connect_stdio(command, args, env, &client_info).await?
249 }
250 BackendConfig::SseUrl { url, headers } => {
251 self.connect_sse_url(url, headers, &client_info).await?
252 }
253 BackendConfig::StreamUrl { url, headers } => {
254 self.connect_stream_url(url, headers, &client_info).await?
255 }
256 };
257
258 let sse_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
260 SseHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
261 } else {
262 SseHandler::with_mcp_id(client, mcp_id.clone())
263 };
264
265 let handler_for_return = sse_handler.clone();
267
268 let (router, ct) = self.create_server(sse_handler)?;
270
271 info!(
272 "[SseServerBuilder] Server created - mcp_id: {}, sse_path: {}, post_path: {}",
273 mcp_id, self.server_config.sse_path, self.server_config.post_path
274 );
275
276 Ok((router, ct, handler_for_return))
277 }
278
279 async fn connect_stdio(
281 &self,
282 command: &str,
283 args: &Option<Vec<String>>,
284 env: &Option<HashMap<String, String>>,
285 client_info: &ClientInfo,
286 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
287 use std::time::Instant;
288
289 let start_time = Instant::now();
290 let mcp_id = self
291 .server_config
292 .mcp_id
293 .clone()
294 .unwrap_or_else(|| "unknown".into());
295
296 let mut wrapped_cmd = TokioCommandWrap::with_new(command, |cmd| {
301 if let Some(cmd_args) = args {
302 cmd.args(cmd_args);
303 }
304 if let Some(env_vars) = env {
306 for (k, v) in env_vars {
307 cmd.env(k, v);
308 }
309 }
310 });
311
312 #[cfg(unix)]
314 wrapped_cmd.wrap(ProcessGroup::leader());
315 #[cfg(windows)]
317 {
318 use process_wrap::tokio::CreationFlags;
319 use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
320 wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
321 wrapped_cmd.wrap(JobObject);
322 }
323
324 wrapped_cmd.wrap(KillOnDrop);
326
327 info!(
328 "[SseServerBuilder] Starting child process - MCP ID: {}, command: {}, args: {:?}",
329 mcp_id,
330 command,
331 args.as_ref().unwrap_or(&vec![])
332 );
333
334 mcp_common::diagnostic::log_stdio_spawn_context("SseServerBuilder", &mcp_id, env);
336
337 let process_start = Instant::now();
338 let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
341 .stderr(std::process::Stdio::piped())
342 .spawn()
343 .map_err(|e| {
344 anyhow::anyhow!(
345 "{}",
346 mcp_common::diagnostic::format_spawn_error(&mcp_id, command, args, e)
347 )
348 })?;
349
350 if let Some(stderr_pipe) = child_stderr {
352 mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.clone());
353 }
354
355 let process_duration = process_start.elapsed();
356
357 debug!(
358 "[SseServerBuilder] Child process spawned - MCP ID: {}, spawn time: {:?}",
359 mcp_id, process_duration
360 );
361
362 let serve_start = Instant::now();
363 let client = client_info.clone().serve(tokio_process).await?;
364 let serve_duration = serve_start.elapsed();
365 let total_duration = start_time.elapsed();
366
367 let warn_msg = "建议的优化方案: \
368 1) 检查网络连接速度 (npm 包下载) \
369 2) 配置国内 npm 镜像 (如淘宝镜像: npm config set registry https://registry.npmmirror.com) \
370 3) 预热服务 (启动 mcp-proxy 时预先加载常用服务) \
371 4) 检查命令参数是否正确";
372
373 log_connection_timing(
374 &mcp_id,
375 "Stdio",
376 total_duration,
377 &[("spawn", process_duration), ("serve", serve_duration)],
378 STDIO_SLOW_THRESHOLD_SECS,
379 warn_msg,
380 );
381
382 Ok(client)
383 }
384
385 async fn connect_sse_url(
387 &self,
388 url: &str,
389 headers: &Option<HashMap<String, String>>,
390 client_info: &ClientInfo,
391 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
392 use std::time::Instant;
393
394 let start_time = Instant::now();
395 let mcp_id = self
396 .server_config
397 .mcp_id
398 .clone()
399 .unwrap_or_else(|| "unknown".into());
400
401 info!(
402 "[SseServerBuilder] Connecting to SSE URL backend - MCP ID: {}, URL: {}",
403 mcp_id, url
404 );
405
406 let mut req_headers = reqwest::header::HeaderMap::new();
408
409 if let Some(config_headers) = headers {
410 for (key, value) in config_headers {
411 req_headers.insert(
412 reqwest::header::HeaderName::try_from(key)
413 .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
414 value.parse().map_err(|e| {
415 anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
416 })?,
417 );
418 }
419 }
420
421 let http_client = reqwest::Client::builder()
422 .default_headers(req_headers)
423 .build()
424 .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
425
426 let sse_config = SseClientConfig {
428 sse_endpoint: url.to_string().into(),
429 ..Default::default()
430 };
431
432 let transport_start = Instant::now();
433 let sse_transport = SseClientTransport::start_with_client(http_client, sse_config).await?;
434 let transport_duration = transport_start.elapsed();
435
436 let serve_start = Instant::now();
437 let client = client_info.clone().serve(sse_transport).await?;
438 let serve_duration = serve_start.elapsed();
439 let total_duration = start_time.elapsed();
440
441 log_connection_timing(
442 &mcp_id,
443 "SSE",
444 total_duration,
445 &[("transport", transport_duration), ("serve", serve_duration)],
446 HTTP_SLOW_THRESHOLD_SECS,
447 "建议: 检查网络连接和后端服务状态",
448 );
449
450 Ok(client)
451 }
452
453 async fn connect_stream_url(
455 &self,
456 url: &str,
457 headers: &Option<HashMap<String, String>>,
458 client_info: &ClientInfo,
459 ) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
460 use std::time::Instant;
461
462 let start_time = Instant::now();
463 let mcp_id = self
464 .server_config
465 .mcp_id
466 .clone()
467 .unwrap_or_else(|| "unknown".into());
468
469 info!(
470 "[SseServerBuilder] Connecting to Streamable HTTP URL backend - MCP ID: {}, URL: {}",
471 mcp_id, url
472 );
473
474 let mut req_headers = reqwest::header::HeaderMap::new();
476 let mut auth_header: Option<String> = None;
477
478 if let Some(config_headers) = headers {
479 for (key, value) in config_headers {
480 if key.eq_ignore_ascii_case("Authorization") {
482 auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
483 continue;
484 }
485
486 req_headers.insert(
487 reqwest::header::HeaderName::try_from(key)
488 .map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
489 value.parse().map_err(|e| {
490 anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
491 })?,
492 );
493 }
494 }
495
496 let http_client = reqwest::Client::builder()
497 .default_headers(req_headers)
498 .build()
499 .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
500
501 let config = StreamableHttpClientTransportConfig {
503 uri: url.to_string().into(),
504 auth_header,
505 ..Default::default()
506 };
507
508 let serve_start = Instant::now();
509 let transport = StreamableHttpClientTransport::with_client(http_client, config);
510 let client = client_info.clone().serve(transport).await?;
511 let serve_duration = serve_start.elapsed();
512 let total_duration = start_time.elapsed();
513
514 log_connection_timing(
515 &mcp_id,
516 "Streamable HTTP",
517 total_duration,
518 &[("serve", serve_duration)],
519 HTTP_SLOW_THRESHOLD_SECS,
520 "建议: 检查网络连接和后端服务状态",
521 );
522
523 Ok(client)
524 }
525
526 fn create_server(&self, sse_handler: SseHandler) -> Result<(axum::Router, CancellationToken)> {
528 let config = SseServerConfig {
531 bind: "0.0.0.0:0".parse()?,
532 sse_path: self.server_config.sse_path.clone(),
533 post_path: self.server_config.post_path.clone(),
534 ct: CancellationToken::new(),
535 sse_keep_alive: Some(std::time::Duration::from_secs(
536 self.server_config.keep_alive_secs,
537 )),
538 };
539
540 let (sse_server, router) = SseServer::new(config);
541
542 let ct = if self.server_config.stateful {
545 sse_server.with_service(move || sse_handler.clone())
546 } else {
547 sse_server.with_service_directly(move || sse_handler.clone())
548 };
549
550 Ok((router, ct))
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Stdio backend running `echo` with the given arguments and no extra env.
    fn echo_backend(args: Option<Vec<String>>) -> BackendConfig {
        BackendConfig::Stdio {
            command: "echo".into(),
            args,
            env: None,
        }
    }

    #[test]
    fn test_builder_creation() {
        let builder = SseServerBuilder::new(echo_backend(Some(vec!["hello".into()])))
            .mcp_id("test")
            .sse_path("/custom/sse")
            .post_path("/custom/message");

        let cfg = &builder.server_config;
        assert!(cfg.mcp_id.is_some());
        assert_eq!(cfg.mcp_id.as_deref(), Some("test"));
        assert_eq!(cfg.sse_path, "/custom/sse");
        assert_eq!(cfg.post_path, "/custom/message");
    }

    #[test]
    fn test_default_config() {
        let SseServerBuilderConfig {
            sse_path,
            post_path,
            keep_alive_secs,
            stateful,
            ..
        } = SseServerBuilderConfig::default();
        assert_eq!(sse_path, "/sse");
        assert_eq!(post_path, "/message");
        assert_eq!(keep_alive_secs, 15);
        assert!(
            stateful,
            "default stateful should be true for backward compatibility"
        );
    }

    #[test]
    fn test_stateful_flag_default() {
        let builder = SseServerBuilder::new(echo_backend(None));
        assert!(
            builder.server_config.stateful,
            "stateful should default to true"
        );
    }

    #[test]
    fn test_stateful_flag_disabled() {
        let builder = SseServerBuilder::new(echo_backend(None)).stateful(false);
        assert!(
            !builder.server_config.stateful,
            "stateful should be false when set"
        );
    }

    #[test]
    fn test_stateful_flag_enabled() {
        let builder = SseServerBuilder::new(echo_backend(None)).stateful(true);
        assert!(
            builder.server_config.stateful,
            "stateful should be true when set"
        );
    }

    #[test]
    fn test_timing_constants() {
        assert_eq!(STDIO_SLOW_THRESHOLD_SECS, 30);
        assert_eq!(HTTP_SLOW_THRESHOLD_SECS, 10);
    }

    #[test]
    fn test_log_connection_timing_format() {
        let phases = [
            ("step1", Duration::from_millis(500)),
            ("step2", Duration::from_millis(1000)),
        ];
        log_connection_timing(
            "test-mcp",
            "TestBackend",
            Duration::from_millis(1500),
            &phases,
            10,
            "Test warning message",
        );
    }

    #[test]
    fn test_log_connection_timing_no_breakdown() {
        log_connection_timing(
            "test-mcp",
            "TestBackend",
            Duration::from_millis(500),
            &[],
            10,
            "Test warning message",
        );
    }
}