1use crate::{
2 AppError, DynamicRouterService, ProxyHandler, get_proxy_manager,
3 model::{
4 CheckMcpStatusResponseStatus, McpConfig, McpProtocol, McpProtocolPath, McpRouterPath,
5 McpServerCommandConfig, McpServerConfig, McpServiceStatus, McpType,
6 },
7};
8
9use anyhow::Result;
10use log::{debug, info};
11use rmcp::{
12 ServiceExt,
13 model::{ClientCapabilities, ClientInfo},
14 transport::streamable_http_server::{
15 StreamableHttpService, session::local::LocalSessionManager,
16 },
17 transport::{
18 SseClientTransport, SseServer, TokioChildProcess, sse_server::SseServerConfig,
19 streamable_http_client::StreamableHttpClientTransport,
20 },
21};
22use tokio::process::Command;
23
24pub async fn mcp_start_task(
26 mcp_config: McpConfig,
27) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
28 let mcp_id = mcp_config.mcp_id.clone();
29 let client_protocol = mcp_config.client_protocol.clone();
30
31 let mcp_router_path: McpRouterPath =
33 McpRouterPath::new(mcp_id, client_protocol).map_err(|e| AppError::McpServerError(e))?;
34
35 let mcp_json_config = mcp_config
36 .mcp_json_config
37 .clone()
38 .expect("mcp_json_config is required");
39
40 let mcp_server_config = McpServerConfig::try_from(mcp_json_config)?;
41
42 integrate_sse_server_with_axum(
44 mcp_server_config.clone(),
45 mcp_router_path.clone(),
46 mcp_config.mcp_type,
47 )
48 .await
49}
50
51pub async fn integrate_sse_server_with_axum(
53 mcp_config: McpServerConfig,
54 mcp_router_path: McpRouterPath,
55 mcp_type: McpType,
56) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
57 let base_path = mcp_router_path.base_path.clone();
58 let mcp_id = mcp_router_path.mcp_id.clone();
59
60 let backend_protocol = match &mcp_config {
62 McpServerConfig::Command(_) => McpProtocol::Stdio,
64 McpServerConfig::Url(url_config) => {
66 if let Some(type_str) = &url_config.r#type {
68 match type_str.parse::<McpProtocol>() {
70 Ok(protocol) => {
71 debug!("使用配置中指定的协议类型: {} -> {:?}", type_str, protocol);
72 protocol
73 }
74 Err(_) => {
75 debug!("协议类型 '{}' 无法识别,开始自动检测协议", type_str);
77 let detected_protocol =
78 crate::server::detect_mcp_protocol(url_config.get_url())
79 .await
80 .map_err(|e| {
81 anyhow::anyhow!(
82 "协议类型 '{}' 不可识别,且自动检测失败: {}",
83 type_str,
84 e
85 )
86 })?;
87 debug!(
88 "自动检测到协议类型: {:?}(原始配置: '{}')",
89 detected_protocol, type_str
90 );
91 detected_protocol
92 }
93 }
94 } else {
95 debug!("未指定 type 字段,自动检测协议");
97 let detected_protocol = crate::server::detect_mcp_protocol(url_config.get_url())
98 .await
99 .map_err(|e| anyhow::anyhow!("自动检测协议失败: {}", e))?;
100 detected_protocol
101 }
102 }
103 };
104
105 debug!(
106 "MCP ID: {}, 客户端协议: {:?}, 后端协议: {:?}",
107 mcp_id, mcp_router_path.mcp_protocol, backend_protocol
108 );
109
110 let client_info = ClientInfo {
112 protocol_version: Default::default(),
113 capabilities: ClientCapabilities::builder()
114 .enable_experimental()
115 .enable_roots()
116 .enable_roots_list_changed()
117 .enable_sampling()
118 .build(),
119 ..Default::default()
120 };
121
122 let client = match &mcp_config {
124 McpServerConfig::Command(cmd_config) => {
125 let mut command = Command::new(&cmd_config.command);
127
128 if let Some(args) = &cmd_config.args {
130 command.args(args);
131 }
132
133 if let Some(env_vars) = &cmd_config.env {
135 for (key, value) in env_vars {
136 command.env(key, value);
137 }
138 }
139
140 log_command_details(cmd_config, &mcp_router_path);
142
143 info!(
144 "子进程已启动,MCP ID: {}, 类型: {:?}",
145 mcp_router_path.mcp_id,
146 mcp_type.clone()
147 );
148
149 let tokio_process = TokioChildProcess::new(command)?;
151 client_info.serve(tokio_process).await?
152 }
153 McpServerConfig::Url(url_config) => {
154 info!(
156 "连接到远程MCP服务: {}, 后端协议: {:?}, 客户端协议: {:?}",
157 url_config.get_url(),
158 backend_protocol,
159 mcp_router_path.mcp_protocol
160 );
161
162 match backend_protocol {
163 McpProtocol::Stdio => {
164 return Err(anyhow::anyhow!("URL 配置的 MCP 服务不能使用 Stdio 协议"));
166 }
167 McpProtocol::Sse => {
168 info!("使用SSE协议连接到: {}", url_config.get_url());
170
171 let mut headers = reqwest::header::HeaderMap::new();
173
174 if let Some(config_headers) = &url_config.headers {
176 for (key, value) in config_headers {
177 headers.insert(
180 reqwest::header::HeaderName::try_from(key).map_err(|e| {
181 anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
182 })?,
183 value.parse().map_err(|e| {
184 anyhow::anyhow!(
185 "Invalid header value for {}: {}, error: {}",
186 key,
187 value,
188 e
189 )
190 })?,
191 );
192 }
193 info!(
194 "添加了 {} 个自定义 headers(包含 Authorization)",
195 headers.len()
196 );
197 } else {
198 info!("没有配置自定义 headers");
199 }
200
201 let client = reqwest::Client::builder()
202 .default_headers(headers)
203 .build()
204 .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
205
206 let sse_config = rmcp::transport::sse_client::SseClientConfig {
208 sse_endpoint: url_config.get_url().to_string().into(),
209 ..Default::default()
210 };
211
212 let sse_transport =
213 SseClientTransport::start_with_client(client, sse_config).await?;
214 client_info.serve(sse_transport).await?
215 }
216 McpProtocol::Stream => {
217 info!("使用Streamable HTTP协议连接到: {}", url_config.get_url());
219
220 let mut headers = reqwest::header::HeaderMap::new();
222
223 if let Some(config_headers) = &url_config.headers {
225 for (key, value) in config_headers {
226 if key.eq_ignore_ascii_case("Authorization") {
228 continue;
229 }
230 headers.insert(
231 reqwest::header::HeaderName::try_from(key).map_err(|e| {
232 anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
233 })?,
234 value.parse().map_err(|e| {
235 anyhow::anyhow!(
236 "Invalid header value for {}: {}, error: {}",
237 key,
238 value,
239 e
240 )
241 })?,
242 );
243 }
244 info!("添加了 {} 个自定义 headers", headers.len());
245 } else {
246 info!("没有配置自定义 headers");
247 }
248
249 let client = reqwest::Client::builder()
250 .default_headers(headers)
251 .build()
252 .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
253
254 let auth_header = url_config.headers.as_ref().and_then(|h| {
256 h.iter()
258 .find_map(|(k, v)| {
259 if k.eq_ignore_ascii_case("Authorization") {
260 Some(v)
261 } else {
262 None
263 }
264 })
265 .map(|s| s.strip_prefix("Bearer ").unwrap_or(s).to_string())
266 });
267
268 let config = rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
270 uri: url_config.get_url().to_string().into(),
271 auth_header,
272 ..Default::default()
273 };
274
275 let transport = StreamableHttpClientTransport::with_client(client, config);
276
277 info!(
278 "Streamable HTTP传输已创建,开始建立连接,MCP ID: {}, 类型: {:?}",
279 mcp_router_path.mcp_id,
280 mcp_type.clone()
281 );
282
283 let client = client_info.serve(transport).await?;
285
286 info!(
287 "Streamable HTTP客户端连接成功,MCP ID: {}",
288 mcp_router_path.mcp_id
289 );
290
291 client
292 }
293 }
294 }
295 };
296
297 let proxy_handler = ProxyHandler::with_mcp_id(client, mcp_id.clone());
299
300 let proxy_manager = get_proxy_manager();
302
303 let proxy_handler_clone = proxy_handler.clone();
305
306 let (router, ct) = match mcp_router_path.mcp_protocol.clone() {
310 McpProtocol::Sse => {
312 let addr: String = "0.0.0.0:0".to_string();
315 let sse_path = match &mcp_router_path.mcp_protocol_path {
316 McpProtocolPath::SsePath(sse_path) => sse_path,
317 _ => unreachable!(),
318 };
319 let config = SseServerConfig {
320 bind: addr.parse()?,
321 sse_path: sse_path.sse_path.clone(),
322 post_path: sse_path.message_path.clone(),
323 ct: tokio_util::sync::CancellationToken::new(),
324 sse_keep_alive: None,
325 };
326
327 debug!(
328 "创建SSE服务器,配置: bind={}, sse_path={}, post_path={}",
329 config.bind, config.sse_path, config.post_path
330 );
331
332 let (sse_server, router) = SseServer::new(config);
333 let ct = sse_server.with_service(move || proxy_handler_clone.clone());
334 (router, ct)
335 }
336
337 McpProtocol::Stream => {
339 let service = StreamableHttpService::new(
342 move || Ok(proxy_handler_clone.clone()),
343 LocalSessionManager::default().into(),
344 Default::default(),
345 );
346 let router = axum::Router::new().fallback_service(service);
347 let ct = tokio_util::sync::CancellationToken::new();
348 (router, ct)
349 }
350
351 McpProtocol::Stdio => {
353 return Err(anyhow::anyhow!(
354 "客户端协议不能是 Stdio。McpRouterPath::new 不支持创建 Stdio 协议的路由路径"
355 ));
356 }
357 };
358
359 let ct_clone = ct.clone();
361 let mcp_id_clone = mcp_id.clone();
362
363 let mcp_service_status = McpServiceStatus::new(
365 mcp_id_clone.clone(),
366 mcp_type.clone(),
367 mcp_router_path.clone(),
368 ct_clone.clone(),
369 CheckMcpStatusResponseStatus::Ready,
370 );
371 proxy_manager.add_mcp_service_status_and_proxy(mcp_service_status, Some(proxy_handler));
373
374 let router = if matches!(mcp_router_path.mcp_protocol, McpProtocol::Sse) {
377 let modified_router = router.fallback(base_path_fallback_handler);
379 info!("SSE基础路径处理器已添加, 基础路径: {}", base_path);
380 modified_router
381 } else {
382 router
383 };
384
385 info!("注册路由: base_path={}, mcp_id={}", base_path, mcp_id);
387 info!(
388 "SSE路径配置: sse_path={}, post_path={}",
389 match &mcp_router_path.mcp_protocol_path {
390 McpProtocolPath::SsePath(sse_path) => &sse_path.sse_path,
391 _ => "N/A",
392 },
393 match &mcp_router_path.mcp_protocol_path {
394 McpProtocolPath::SsePath(sse_path) => &sse_path.message_path,
395 _ => "N/A",
396 }
397 );
398 DynamicRouterService::register_route(&base_path, router.clone());
399 info!("路由注册完成: base_path={}", base_path);
400
401 Ok((router, ct))
403}
404
405#[axum::debug_handler]
407async fn base_path_fallback_handler(
408 method: axum::http::Method,
409 uri: axum::http::Uri,
410 headers: axum::http::HeaderMap,
411) -> impl axum::response::IntoResponse {
412 let path = uri.path();
413 info!("基础路径处理器: {} {}", method, path);
414
415 if path.contains("/sse/proxy/") {
417 match method {
419 axum::http::Method::GET => {
420 let mcp_id = path.split("/sse/proxy/").nth(1);
422
423 if let Some(mcp_id) = mcp_id {
424 let proxy_manager = get_proxy_manager();
426 if proxy_manager.get_mcp_service_status(mcp_id).is_none() {
427 (
429 axum::http::StatusCode::NOT_FOUND,
430 [("Content-Type", "text/plain".to_string())],
431 format!("MCP service '{}' not found", mcp_id).to_string(),
432 )
433 } else {
434 let accept_header = headers.get("accept");
436 if let Some(accept) = accept_header {
437 let accept_str = accept.to_str().unwrap_or("");
438 if accept_str.contains("text/event-stream") {
439 let redirect_uri = format!("{}/sse", path);
441 info!("SSE重定向到: {}", redirect_uri);
442 (
443 axum::http::StatusCode::FOUND,
444 [("Location", redirect_uri.to_string())],
445 "Redirecting to SSE endpoint".to_string(),
446 )
447 } else {
448 (
450 axum::http::StatusCode::BAD_REQUEST,
451 [("Content-Type", "text/plain".to_string())],
452 "SSE error: Invalid Accept header, expected 'text/event-stream'".to_string(),
453 )
454 }
455 } else {
456 (
458 axum::http::StatusCode::BAD_REQUEST,
459 [("Content-Type", "text/plain".to_string())],
460 "SSE error: Missing Accept header, expected 'text/event-stream'"
461 .to_string(),
462 )
463 }
464 }
465 } else {
466 (
468 axum::http::StatusCode::BAD_REQUEST,
469 [("Content-Type", "text/plain".to_string())],
470 "SSE error: Invalid SSE path".to_string(),
471 )
472 }
473 }
474 axum::http::Method::POST => {
475 let redirect_uri = format!("{}/message", path);
477 info!("SSE重定向到: {}", redirect_uri);
478 (
479 axum::http::StatusCode::FOUND,
480 [("Location", redirect_uri.to_string())],
481 "Redirecting to message endpoint".to_string(),
482 )
483 }
484 _ => {
485 (
487 axum::http::StatusCode::METHOD_NOT_ALLOWED,
488 [("Allow", "GET, POST".to_string())],
489 "Only GET and POST methods are allowed".to_string(),
490 )
491 }
492 }
493 } else if path.contains("/stream/proxy/") {
494 match method {
496 axum::http::Method::GET => {
497 (
499 axum::http::StatusCode::OK,
500 [("Content-Type", "application/json".to_string())],
501 r#"{"jsonrpc":"2.0","result":{"info":"Streamable MCP Server","version":"1.0"}}"#.to_string(),
502 )
503 }
504 axum::http::Method::POST => {
505 (
507 axum::http::StatusCode::OK,
508 [("Content-Type", "application/json".to_string())],
509 r#"{"jsonrpc":"2.0","result":{"message":"Stream request received","protocol":"streamable-http"}}"#.to_string(),
510 )
511 }
512 _ => {
513 (
515 axum::http::StatusCode::METHOD_NOT_ALLOWED,
516 [("Allow", "GET, POST".to_string())],
517 "Only GET and POST methods are allowed".to_string(),
518 )
519 }
520 }
521 } else {
522 (
524 axum::http::StatusCode::BAD_REQUEST,
525 [("Content-Type", "text/plain".to_string())],
526 "Unknown protocol or path".to_string(),
527 )
528 }
529}
530
531fn log_command_details(mcp_config: &McpServerCommandConfig, mcp_router_path: &McpRouterPath) {
533 let args_str = mcp_config
535 .args
536 .as_ref()
537 .map_or(String::new(), |args| args.join(" "));
538 let cmd_str = format!("执行命令: {} {}", mcp_config.command, args_str);
539 debug!("{cmd_str}");
540
541 if let Some(env_vars) = &mcp_config.env {
543 let env_vars: Vec<String> = env_vars.iter().map(|(k, v)| format!("{k}={v}")).collect();
544 if !env_vars.is_empty() {
545 debug!("环境变量: {}", env_vars.join(", "));
546 }
547 }
548
549 debug!(
551 "完整命令,mcpId={}, command={:?}",
552 mcp_router_path.mcp_id, mcp_config.command
553 );
554
555 let args_str = mcp_config
557 .args
558 .as_ref()
559 .map_or(String::new(), |args| args.join(" "));
560 let env_str = mcp_config.env.as_ref().map_or(String::new(), |env| {
561 env.iter()
562 .map(|(k, v)| format!("{k}={v}"))
563 .collect::<Vec<String>>()
564 .join(" ")
565 });
566
567 let full_command = format!("{} {} {}", mcp_config.command, args_str, env_str);
568 info!(
569 "完整命令字符串,mcpId={},command={:?}",
570 mcp_router_path.mcp_id, full_command
571 );
572}