use anyhow::{Result, bail};
use mcp_common::{
McpServiceConfig, check_windows_command, convert_path_to_windows_format,
preprocess_npx_command_windows, wrap_process_v8,
};
use rmcp::{
ServiceExt,
model::{ClientCapabilities, ClientInfo, ProtocolVersion},
transport::{
TokioChildProcess,
sse_server::{SseServer, SseServerConfig},
},
};
use std::process::Stdio;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use process_wrap::tokio::{KillOnDrop, TokioCommandWrap};
use crate::SseHandler;
/// Spawns the MCP child process described by `config` and exposes it over SSE.
///
/// The steps are:
/// 1. Pre-process the configured command (Windows checks / npx rewriting via
///    the `mcp_common` helpers) and log the effective environment.
/// 2. Spawn the child with stderr piped, wrapped in [`KillOnDrop`].
/// 3. Connect to the child as an MCP client and query its tool list
///    (logged always, echoed to stderr unless `quiet`).
/// 4. Wrap the client in an [`SseHandler`] (with the configured tool filter,
///    if any) and hand over to [`run_sse_server`] on a clone of `std_listener`.
///
/// # Errors
///
/// Returns an error if the child cannot be spawned, the MCP client handshake
/// fails, the listener cannot be cloned / registered with tokio, or
/// [`run_sse_server`] itself fails. A failing `list_tools` call is only
/// logged and does not abort startup.
pub async fn run_sse_server_from_config(
    config: McpServiceConfig,
    std_listener: &std::net::TcpListener,
    quiet: bool,
) -> Result<()> {
    // Limits for the human-readable tool summary printed in non-quiet mode.
    const MAX_LISTED_TOOLS: usize = 10;
    const MAX_DESC_CHARS: usize = 50;

    // These three values are used for diagnostics only; the actual child
    // environment is assembled inside the command closure below.
    let inherited_path = std::env::var("PATH").unwrap_or_default();
    let user_env_path = config.env.as_ref().and_then(|e| e.get("PATH").cloned());
    let effective_path = user_env_path.as_deref().unwrap_or(&inherited_path);

    check_windows_command(&config.command);
    let (processed_command, processed_args) =
        preprocess_npx_command_windows(&config.command, config.args.as_deref());

    info!(
        "[子进程环境][{}] 命令: {} {:?}",
        config.name,
        processed_command,
        processed_args.as_ref().unwrap_or(&vec![])
    );
    debug!(
        "[子进程环境][{}] 继承 PATH: {}",
        config.name, inherited_path
    );
    if let Some(ref user_path) = user_env_path {
        info!("[子进程环境][{}] 用户覆盖 PATH: {}", config.name, user_path);
    }
    info!(
        "[子进程环境][{}] 生效 PATH: {}",
        config.name, effective_path
    );
    if let Some(ref env_vars) = config.env {
        // Only the variable names are logged, never their values.
        let non_path_keys: Vec<&String> = env_vars.keys().filter(|k| *k != "PATH").collect();
        if !non_path_keys.is_empty() {
            info!(
                "[子进程环境][{}] 用户自定义环境变量: {:?}",
                config.name, non_path_keys
            );
        }
    }

    let mut wrapped_cmd = TokioCommandWrap::with_new(&processed_command, |command| {
        if let Some(ref cmd_args) = processed_args {
            command.args(cmd_args);
        }
        // On Windows, override the inherited PATH only when the conversion
        // helper actually changed it.
        #[cfg(target_os = "windows")]
        if let Ok(path_value) = std::env::var("PATH") {
            let converted_path = convert_path_to_windows_format(&path_value);
            if converted_path != path_value {
                command.env("PATH", converted_path);
            }
        }
        // User-supplied variables are applied last so they win over the
        // inherited / converted values (including PATH).
        if let Some(ref env_vars) = config.env {
            for (k, v) in env_vars {
                command.env(k, v);
            }
        }
    });
    wrap_process_v8!(wrapped_cmd);
    wrapped_cmd.wrap(KillOnDrop);

    let (tokio_process, child_stderr) = TokioChildProcess::builder(wrapped_cmd)
        .stderr(Stdio::piped())
        .spawn()?;
    if let Some(stderr_pipe) = child_stderr {
        mcp_common::spawn_stderr_reader(stderr_pipe, config.name.clone());
    }

    let client_info = ClientInfo {
        protocol_version: ProtocolVersion::V_2024_11_05,
        capabilities: ClientCapabilities::builder()
            .enable_experimental()
            .enable_roots()
            .enable_roots_list_changed()
            .enable_sampling()
            .build(),
        ..Default::default()
    };
    let client = client_info.serve(tokio_process).await?;
    info!(
        "[子进程启动] SSE - 服务名: {}, 命令: {} {:?}",
        config.name,
        config.command,
        config.args.as_ref().unwrap_or(&vec![])
    );

    if !quiet {
        eprintln!("✅ 子进程已启动");
    }
    // The tool list is fetched once; `quiet` only controls the stderr echo
    // (and whether an empty list is reported as a warning).
    match client.list_tools(None).await {
        Ok(tools_result) => {
            let tools = &tools_result.tools;
            if tools.is_empty() && !quiet {
                warn!("[工具列表] 工具列表为空 - 服务名: {}", config.name);
                eprintln!("⚠️ 工具列表为空");
            } else {
                info!(
                    "[工具列表] 服务名: {}, 工具数量: {}",
                    config.name,
                    tools.len()
                );
                if !quiet {
                    eprintln!("🔧 可用工具 ({} 个):", tools.len());
                    for tool in tools.iter().take(MAX_LISTED_TOOLS) {
                        let desc = tool.description.as_deref().unwrap_or("无描述");
                        let desc_short = truncate_tool_description(desc, MAX_DESC_CHARS);
                        eprintln!(" - {} : {}", tool.name, desc_short);
                    }
                    if tools.len() > MAX_LISTED_TOOLS {
                        eprintln!(" ... 和 {} 个其他工具", tools.len() - MAX_LISTED_TOOLS);
                    }
                }
            }
        }
        Err(e) => {
            error!(
                "[工具列表] 获取工具列表失败 - 服务名: {}, 错误: {}",
                config.name, e
            );
            if !quiet {
                eprintln!("⚠️ 获取工具列表失败: {}", e);
            }
        }
    }

    let sse_handler = if let Some(tool_filter) = config.tool_filter {
        SseHandler::with_tool_filter(client, config.name.clone(), tool_filter)
    } else {
        SseHandler::with_mcp_id(client, config.name.clone())
    };

    let cloned_listener = std_listener.try_clone()?;
    // `tokio::net::TcpListener::from_std` requires the socket to already be
    // in non-blocking mode; make sure of it instead of relying on the caller.
    cloned_listener.set_nonblocking(true)?;
    let listener = tokio::net::TcpListener::from_std(cloned_listener)?;
    run_sse_server(sse_handler, listener, quiet).await
}

/// Shortens `desc` to at most `max_chars` characters, appending `...` when
/// anything was cut off.
///
/// The cut is made on a character boundary, so multi-byte UTF-8 text (e.g.
/// Chinese descriptions) can never trigger a slicing panic.
fn truncate_tool_description(desc: &str, max_chars: usize) -> String {
    // `nth(max_chars)` yields the byte offset of the first character that
    // does not fit; `None` means the whole string fits.
    match desc.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &desc[..cut]),
        None => desc.to_string(),
    }
}
/// Serves `sse_handler` over HTTP/SSE on the already-bound `listener`.
///
/// The SSE stream is exposed at `/sse` and the message endpoint at
/// `/message`. Any other path is answered by a fallback handler:
/// * `GET` with an `Accept` header containing `text/event-stream` is
///   redirected (307) to the SSE endpoint;
/// * any other `GET` receives a small JSON status document;
/// * `POST` is redirected (307) to the message endpoint;
/// * every other method gets `405 Method Not Allowed`.
///
/// The function returns once the HTTP server stops or Ctrl+C is received;
/// in the latter case the cancellation token obtained from the SSE server
/// is cancelled. Startup hints are printed to stderr unless `quiet` is set.
///
/// # Errors
///
/// Fails if the listener's local address cannot be read or if the HTTP
/// server terminates with an error.
pub async fn run_sse_server(
    sse_handler: SseHandler,
    listener: tokio::net::TcpListener,
    quiet: bool,
) -> Result<()> {
    let bind_addr = listener.local_addr()?;
    let addr_text = bind_addr.to_string();
    let sse_path = String::from("/sse");
    let message_path = String::from("/message");
    let mcp_id = sse_handler.mcp_id().to_string();

    info!(
        "[HTTP服务启动] SSE 服务启动 - 地址: {}, MCP ID: {}, SSE端点: {}, 消息端点: {}",
        addr_text, mcp_id, sse_path, message_path
    );
    if !quiet {
        eprintln!("📡 SSE 服务启动: http://{}", addr_text);
        eprintln!(" SSE 端点: http://{}{}", addr_text, sse_path);
        eprintln!(" 消息端点: http://{}{}", addr_text, message_path);
        eprintln!(
            "💡 MCP 客户端可直接使用: http://{} (自动重定向)",
            addr_text
        );
        eprintln!("🔄 后端热替换: 启用");
        eprintln!("💡 按 Ctrl+C 停止服务");
    }

    let (sse_server, sse_router) = SseServer::new(SseServerConfig {
        bind: bind_addr,
        sse_path: sse_path.clone(),
        post_path: message_path.clone(),
        ct: CancellationToken::new(),
        sse_keep_alive: Some(std::time::Duration::from_secs(15)),
    });
    let ct = sse_server.with_service(move || sse_handler.clone());

    // The path strings are no longer needed below, so the handler takes
    // ownership of them and clones one pair per request.
    let fallback_handler = move |method: axum::http::Method, headers: axum::http::HeaderMap| {
        let sse_target = sse_path.clone();
        let message_target = message_path.clone();
        async move {
            if method == axum::http::Method::POST {
                return (
                    axum::http::StatusCode::TEMPORARY_REDIRECT,
                    [("Location", message_target)],
                    "Redirecting to message endpoint".to_string(),
                );
            }
            if method != axum::http::Method::GET {
                return (
                    axum::http::StatusCode::METHOD_NOT_ALLOWED,
                    [("Allow", "GET, POST".to_string())],
                    "Method not allowed".to_string(),
                );
            }

            // GET: event-stream clients are redirected, everyone else gets
            // a JSON description of the service.
            let wants_event_stream = headers
                .get("accept")
                .and_then(|value| value.to_str().ok())
                .map_or(false, |accept| accept.contains("text/event-stream"));
            if wants_event_stream {
                return (
                    axum::http::StatusCode::TEMPORARY_REDIRECT,
                    [("Location", sse_target)],
                    "Redirecting to SSE endpoint".to_string(),
                );
            }

            let status_body = serde_json::json!({
                "status": "running",
                "protocol": "SSE",
                "endpoints": {
                    "sse": sse_target,
                    "message": message_target
                },
                "usage": "Connect your MCP client to this URL or the SSE endpoint directly"
            })
            .to_string();
            (
                axum::http::StatusCode::OK,
                [("Content-Type", "application/json".to_string())],
                status_body,
            )
        }
    };
    let router = sse_router.fallback(fallback_handler);

    tokio::select! {
        served = axum::serve(listener, router) => {
            if let Err(e) = served {
                error!(
                    "[HTTP服务错误] SSE 服务器错误 - MCP ID: {}, 错误: {}",
                    mcp_id, e
                );
                bail!("服务器错误: {}", e);
            }
        }
        _ = tokio::signal::ctrl_c() => {
            info!(
                "[HTTP服务关闭] 收到退出信号,正在关闭 SSE 服务 - MCP ID: {}",
                mcp_id
            );
            if !quiet {
                eprintln!("\n🛑 收到退出信号,正在关闭...");
            }
            ct.cancel();
        }
    }
    Ok(())
}