use std::collections::HashMap;
use std::time::Duration;
use anyhow::Result;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use process_wrap::tokio::{KillOnDrop, TokioCommandWrap};
#[cfg(unix)]
use process_wrap::tokio::ProcessGroup;
#[cfg(windows)]
use process_wrap::tokio::JobObject;
use rmcp::{
ServiceExt,
model::{ClientCapabilities, ClientInfo, ProtocolVersion},
transport::{
SseClientTransport, TokioChildProcess,
sse_client::SseClientConfig,
sse_server::{SseServer, SseServerConfig},
streamable_http_client::{
StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
},
},
};
use crate::{SseHandler, ToolFilter};
/// Total connect time, in whole seconds, at or above which a stdio backend
/// start-up triggers the slow-connection warning in `log_connection_timing`.
const STDIO_SLOW_THRESHOLD_SECS: u64 = 30;
/// Total connect time, in whole seconds, at or above which an SSE or
/// Streamable HTTP backend connection triggers the slow-connection warning.
const HTTP_SLOW_THRESHOLD_SECS: u64 = 10;
/// The MCP backend the proxy connects to.
///
/// NOTE(review): the derived `Debug` prints `headers` and `env` values verbatim,
/// which may include credentials — avoid `{:?}`-logging this type.
#[derive(Clone, Debug)]
pub enum BackendConfig {
    /// Spawn `command` as a child process and talk MCP over its stdin/stdout.
    Stdio {
        /// Executable to launch.
        command: String,
        /// Arguments passed to `command`, if any.
        args: Option<Vec<String>>,
        /// Extra environment variables set on the child process.
        env: Option<HashMap<String, String>>,
    },
    /// Connect to a remote MCP server over the SSE transport.
    SseUrl {
        /// SSE endpoint URL.
        url: String,
        /// Extra HTTP headers attached to every request.
        headers: Option<HashMap<String, String>>,
    },
    /// Connect to a remote MCP server over the Streamable HTTP transport.
    StreamUrl {
        /// Streamable HTTP endpoint URL.
        url: String,
        /// Extra HTTP headers; an `Authorization` entry is given special handling
        /// by the builder.
        headers: Option<HashMap<String, String>>,
    },
}
/// Settings for the SSE server produced by [`SseServerBuilder`].
#[derive(Clone, Debug)]
pub struct SseServerBuilderConfig {
    /// HTTP path serving the SSE event stream.
    pub sse_path: String,
    /// HTTP path accepting client POST messages.
    pub post_path: String,
    /// Identifier used for the handler and in log lines; `None` lets the builder
    /// pick a fallback.
    pub mcp_id: Option<String>,
    /// Optional filter applied to the backend's tools by the handler.
    pub tool_filter: Option<ToolFilter>,
    /// SSE keep-alive interval, in seconds.
    pub keep_alive_secs: u64,
    /// `true` serves via `with_service`, `false` via `with_service_directly`.
    pub stateful: bool,
}
impl Default for SseServerBuilderConfig {
    /// Paths `/sse` and `/message`, no MCP id or tool filter, a 15-second
    /// keep-alive, and stateful sessions.
    fn default() -> Self {
        let sse_path = String::from("/sse");
        let post_path = String::from("/message");
        Self {
            sse_path,
            post_path,
            mcp_id: None,
            tool_filter: None,
            keep_alive_secs: 15,
            stateful: true,
        }
    }
}
/// Logs a successful backend connection with its per-phase timing breakdown.
///
/// Also emits a warning when the total reaches `warn_threshold_secs`.
///
/// * `mcp_id` / `backend_type` — labels included in both log lines.
/// * `total_duration` — end-to-end connect time.
/// * `breakdown` — ordered `(phase, elapsed)` pairs, rendered as `phase: elapsed`
///   and joined with `", "` (empty when no phases are given).
/// * `warn_threshold_secs` — whole-second threshold for the warning.
/// * `warn_message` — appended verbatim to the warning.
fn log_connection_timing(
    mcp_id: &str,
    backend_type: &str,
    total_duration: Duration,
    breakdown: &[(&str, Duration)],
    warn_threshold_secs: u64,
    warn_message: &str,
) {
    let phases = breakdown
        .iter()
        .map(|&(phase, elapsed)| format!("{}: {:?}", phase, elapsed))
        .collect::<Vec<String>>()
        .join(", ");
    info!(
        "[SseServerBuilder] {} backend connected successfully - MCP ID: {}, total: {:?} ({})",
        backend_type, mcp_id, total_duration, phases
    );
    // For an integer threshold this is the same as `total_duration.as_secs() >= threshold`.
    let is_slow = total_duration >= Duration::from_secs(warn_threshold_secs);
    if !is_slow {
        return;
    }
    warn!(
        "[SseServerBuilder] {} backend connection takes a long time - MCP ID: {}, time: {:?}, {}",
        backend_type, mcp_id, total_duration, warn_message
    );
}
/// Builder that connects to one MCP backend and exposes it through an SSE server.
///
/// The backend is described by a [`BackendConfig`] and the server by a
/// [`SseServerBuilderConfig`]. Configure it with the chained setters, then
/// call `build`.
pub struct SseServerBuilder {
    /// Which backend to connect to, and how.
    backend_config: BackendConfig,
    /// Server settings: paths, MCP id, tool filter, keep-alive and statefulness.
    server_config: SseServerBuilderConfig,
}
impl SseServerBuilder {
pub fn new(backend: BackendConfig) -> Self {
Self {
backend_config: backend,
server_config: SseServerBuilderConfig::default(),
}
}
pub fn sse_path(mut self, path: impl Into<String>) -> Self {
self.server_config.sse_path = path.into();
self
}
pub fn post_path(mut self, path: impl Into<String>) -> Self {
self.server_config.post_path = path.into();
self
}
pub fn mcp_id(mut self, id: impl Into<String>) -> Self {
self.server_config.mcp_id = Some(id.into());
self
}
pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
self.server_config.tool_filter = Some(filter);
self
}
pub fn keep_alive(mut self, secs: u64) -> Self {
self.server_config.keep_alive_secs = secs;
self
}
pub fn stateful(mut self, stateful: bool) -> Self {
self.server_config.stateful = stateful;
self
}
pub async fn build(self) -> Result<(axum::Router, CancellationToken, SseHandler)> {
let mcp_id = self
.server_config
.mcp_id
.clone()
.unwrap_or_else(|| "sse-proxy".into());
let client_info = ClientInfo {
protocol_version: ProtocolVersion::V_2024_11_05,
capabilities: ClientCapabilities::builder()
.enable_experimental()
.enable_roots()
.enable_roots_list_changed()
.enable_sampling()
.build(),
..Default::default()
};
let client = match &self.backend_config {
BackendConfig::Stdio { command, args, env } => {
self.connect_stdio(command, args, env, &client_info).await?
}
BackendConfig::SseUrl { url, headers } => {
self.connect_sse_url(url, headers, &client_info).await?
}
BackendConfig::StreamUrl { url, headers } => {
self.connect_stream_url(url, headers, &client_info).await?
}
};
let sse_handler = if let Some(ref tool_filter) = self.server_config.tool_filter {
SseHandler::with_tool_filter(client, mcp_id.clone(), tool_filter.clone())
} else {
SseHandler::with_mcp_id(client, mcp_id.clone())
};
let handler_for_return = sse_handler.clone();
let (router, ct) = self.create_server(sse_handler)?;
info!(
"[SseServerBuilder] Server created - mcp_id: {}, sse_path: {}, post_path: {}",
mcp_id, self.server_config.sse_path, self.server_config.post_path
);
Ok((router, ct, handler_for_return))
}
async fn connect_stdio(
&self,
command: &str,
args: &Option<Vec<String>>,
env: &Option<HashMap<String, String>>,
client_info: &ClientInfo,
) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
use std::time::Instant;
let start_time = Instant::now();
let mcp_id = self
.server_config
.mcp_id
.clone()
.unwrap_or_else(|| "unknown".into());
let mut wrapped_cmd = TokioCommandWrap::with_new(command, |cmd| {
if let Some(cmd_args) = args {
cmd.args(cmd_args);
}
if let Some(env_vars) = env {
for (k, v) in env_vars {
cmd.env(k, v);
}
}
});
#[cfg(unix)]
wrapped_cmd.wrap(ProcessGroup::leader());
#[cfg(windows)]
{
use process_wrap::tokio::CreationFlags;
use windows::Win32::System::Threading::{CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW};
wrapped_cmd.wrap(CreationFlags(CREATE_NO_WINDOW | CREATE_NEW_PROCESS_GROUP));
wrapped_cmd.wrap(JobObject);
}
wrapped_cmd.wrap(KillOnDrop);
info!(
"[SseServerBuilder] Starting child process - MCP ID: {}, command: {}, args: {:?}",
mcp_id,
command,
args.as_ref().unwrap_or(&vec![])
);
mcp_common::diagnostic::log_stdio_spawn_context("SseServerBuilder", &mcp_id, env);
let process_start = Instant::now();
let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| {
anyhow::anyhow!(
"{}",
mcp_common::diagnostic::format_spawn_error(&mcp_id, command, args, e)
)
})?;
if let Some(stderr_pipe) = child_stderr {
mcp_common::spawn_stderr_reader(stderr_pipe, mcp_id.clone());
}
let process_duration = process_start.elapsed();
debug!(
"[SseServerBuilder] Child process spawned - MCP ID: {}, spawn time: {:?}",
mcp_id, process_duration
);
let serve_start = Instant::now();
let client = client_info.clone().serve(tokio_process).await?;
let serve_duration = serve_start.elapsed();
let total_duration = start_time.elapsed();
let warn_msg = "建议的优化方案: \
1) 检查网络连接速度 (npm 包下载) \
2) 配置国内 npm 镜像 (如淘宝镜像: npm config set registry https://registry.npmmirror.com) \
3) 预热服务 (启动 mcp-proxy 时预先加载常用服务) \
4) 检查命令参数是否正确";
log_connection_timing(
&mcp_id,
"Stdio",
total_duration,
&[("spawn", process_duration), ("serve", serve_duration)],
STDIO_SLOW_THRESHOLD_SECS,
warn_msg,
);
Ok(client)
}
async fn connect_sse_url(
&self,
url: &str,
headers: &Option<HashMap<String, String>>,
client_info: &ClientInfo,
) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
use std::time::Instant;
let start_time = Instant::now();
let mcp_id = self
.server_config
.mcp_id
.clone()
.unwrap_or_else(|| "unknown".into());
info!(
"[SseServerBuilder] Connecting to SSE URL backend - MCP ID: {}, URL: {}",
mcp_id, url
);
let mut req_headers = reqwest::header::HeaderMap::new();
if let Some(config_headers) = headers {
for (key, value) in config_headers {
req_headers.insert(
reqwest::header::HeaderName::try_from(key)
.map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
value.parse().map_err(|e| {
anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
})?,
);
}
}
let http_client = reqwest::Client::builder()
.default_headers(req_headers)
.build()
.map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
let sse_config = SseClientConfig {
sse_endpoint: url.to_string().into(),
..Default::default()
};
let transport_start = Instant::now();
let sse_transport = SseClientTransport::start_with_client(http_client, sse_config).await?;
let transport_duration = transport_start.elapsed();
let serve_start = Instant::now();
let client = client_info.clone().serve(sse_transport).await?;
let serve_duration = serve_start.elapsed();
let total_duration = start_time.elapsed();
log_connection_timing(
&mcp_id,
"SSE",
total_duration,
&[("transport", transport_duration), ("serve", serve_duration)],
HTTP_SLOW_THRESHOLD_SECS,
"建议: 检查网络连接和后端服务状态",
);
Ok(client)
}
async fn connect_stream_url(
&self,
url: &str,
headers: &Option<HashMap<String, String>>,
client_info: &ClientInfo,
) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ClientInfo>> {
use std::time::Instant;
let start_time = Instant::now();
let mcp_id = self
.server_config
.mcp_id
.clone()
.unwrap_or_else(|| "unknown".into());
info!(
"[SseServerBuilder] Connecting to Streamable HTTP URL backend - MCP ID: {}, URL: {}",
mcp_id, url
);
let mut req_headers = reqwest::header::HeaderMap::new();
let mut auth_header: Option<String> = None;
if let Some(config_headers) = headers {
for (key, value) in config_headers {
if key.eq_ignore_ascii_case("Authorization") {
auth_header = Some(value.strip_prefix("Bearer ").unwrap_or(value).to_string());
continue;
}
req_headers.insert(
reqwest::header::HeaderName::try_from(key)
.map_err(|e| anyhow::anyhow!("Invalid header name '{}': {}", key, e))?,
value.parse().map_err(|e| {
anyhow::anyhow!("Invalid header value for '{}': {}", key, e)
})?,
);
}
}
let http_client = reqwest::Client::builder()
.default_headers(req_headers)
.build()
.map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
let config = StreamableHttpClientTransportConfig {
uri: url.to_string().into(),
auth_header,
..Default::default()
};
let serve_start = Instant::now();
let transport = StreamableHttpClientTransport::with_client(http_client, config);
let client = client_info.clone().serve(transport).await?;
let serve_duration = serve_start.elapsed();
let total_duration = start_time.elapsed();
log_connection_timing(
&mcp_id,
"Streamable HTTP",
total_duration,
&[("serve", serve_duration)],
HTTP_SLOW_THRESHOLD_SECS,
"建议: 检查网络连接和后端服务状态",
);
Ok(client)
}
fn create_server(&self, sse_handler: SseHandler) -> Result<(axum::Router, CancellationToken)> {
let config = SseServerConfig {
bind: "0.0.0.0:0".parse()?,
sse_path: self.server_config.sse_path.clone(),
post_path: self.server_config.post_path.clone(),
ct: CancellationToken::new(),
sse_keep_alive: Some(std::time::Duration::from_secs(
self.server_config.keep_alive_secs,
)),
};
let (sse_server, router) = SseServer::new(config);
let ct = if self.server_config.stateful {
sse_server.with_service(move || sse_handler.clone())
} else {
sse_server.with_service_directly(move || sse_handler.clone())
};
Ok((router, ct))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Stdio builder running `echo` with the given arguments and no extra env.
    fn echo_builder(args: Option<Vec<String>>) -> SseServerBuilder {
        SseServerBuilder::new(BackendConfig::Stdio {
            command: "echo".into(),
            args,
            env: None,
        })
    }

    #[test]
    fn test_builder_creation() {
        let builder = echo_builder(Some(vec!["hello".into()]))
            .mcp_id("test")
            .sse_path("/custom/sse")
            .post_path("/custom/message");
        let cfg = &builder.server_config;
        assert!(cfg.mcp_id.is_some());
        assert_eq!(cfg.mcp_id.as_deref(), Some("test"));
        assert_eq!(cfg.sse_path, "/custom/sse");
        assert_eq!(cfg.post_path, "/custom/message");
    }

    #[test]
    fn test_default_config() {
        let SseServerBuilderConfig {
            sse_path,
            post_path,
            keep_alive_secs,
            stateful,
            ..
        } = SseServerBuilderConfig::default();
        assert_eq!(sse_path, "/sse");
        assert_eq!(post_path, "/message");
        assert_eq!(keep_alive_secs, 15);
        assert!(
            stateful,
            "default stateful should be true for backward compatibility"
        );
    }

    #[test]
    fn test_stateful_flag_default() {
        let stateful = echo_builder(None).server_config.stateful;
        assert!(stateful, "stateful should default to true");
    }

    #[test]
    fn test_stateful_flag_disabled() {
        let stateful = echo_builder(None).stateful(false).server_config.stateful;
        assert!(!stateful, "stateful should be false when set");
    }

    #[test]
    fn test_stateful_flag_enabled() {
        let stateful = echo_builder(None).stateful(true).server_config.stateful;
        assert!(stateful, "stateful should be true when set");
    }

    #[test]
    fn test_timing_constants() {
        assert_eq!(STDIO_SLOW_THRESHOLD_SECS, 30);
        assert_eq!(HTTP_SLOW_THRESHOLD_SECS, 10);
    }

    #[test]
    fn test_log_connection_timing_format() {
        let phases = [
            ("step1", Duration::from_millis(500)),
            ("step2", Duration::from_millis(1000)),
        ];
        log_connection_timing(
            "test-mcp",
            "TestBackend",
            Duration::from_millis(1500),
            &phases,
            10,
            "Test warning message",
        );
    }

    #[test]
    fn test_log_connection_timing_no_breakdown() {
        log_connection_timing(
            "test-mcp",
            "TestBackend",
            Duration::from_millis(500),
            &[],
            10,
            "Test warning message",
        );
    }
}