1use crate::{
2 AppError, DynamicRouterService, ProxyHandler, get_proxy_manager,
3 model::{
4 CheckMcpStatusResponseStatus, McpConfig, McpProtocol, McpProtocolPath, McpRouterPath,
5 McpServerCommandConfig, McpServerConfig, McpServiceStatus, McpType,
6 },
7};
8
9use anyhow::Result;
10use log::{debug, info};
11use rmcp::{
12 ServiceExt,
13 model::{ClientCapabilities, ClientInfo},
14 transport::streamable_http_server::{
15 StreamableHttpService, session::local::LocalSessionManager,
16 },
17 transport::{
18 SseClientTransport, SseServer, TokioChildProcess, StreamableHttpServerConfig,
19 sse_server::SseServerConfig,
20 streamable_http_client::StreamableHttpClientTransport,
21 },
22};
23use tokio::process::Command;
24
25pub async fn mcp_start_task(
27 mcp_config: McpConfig,
28) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
29 let mcp_id = mcp_config.mcp_id.clone();
30 let client_protocol = mcp_config.client_protocol.clone();
31
32 let mcp_router_path: McpRouterPath =
34 McpRouterPath::new(mcp_id, client_protocol).map_err(|e| AppError::McpServerError(e))?;
35
36 let mcp_json_config = mcp_config
37 .mcp_json_config
38 .clone()
39 .expect("mcp_json_config is required");
40
41 let mcp_server_config = McpServerConfig::try_from(mcp_json_config)?;
42
43 integrate_sse_server_with_axum(
45 mcp_server_config.clone(),
46 mcp_router_path.clone(),
47 mcp_config.mcp_type,
48 )
49 .await
50}
51
52pub async fn integrate_sse_server_with_axum(
54 mcp_config: McpServerConfig,
55 mcp_router_path: McpRouterPath,
56 mcp_type: McpType,
57) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
58 let base_path = mcp_router_path.base_path.clone();
59 let mcp_id = mcp_router_path.mcp_id.clone();
60
61 let backend_protocol = match &mcp_config {
63 McpServerConfig::Command(_) => McpProtocol::Stdio,
65 McpServerConfig::Url(url_config) => {
67 if let Some(type_str) = &url_config.r#type {
69 match type_str.parse::<McpProtocol>() {
71 Ok(protocol) => {
72 debug!("使用配置中指定的协议类型: {} -> {:?}", type_str, protocol);
73 protocol
74 }
75 Err(_) => {
76 debug!("协议类型 '{}' 无法识别,开始自动检测协议", type_str);
78 let detected_protocol =
79 crate::server::detect_mcp_protocol(url_config.get_url())
80 .await
81 .map_err(|e| {
82 anyhow::anyhow!(
83 "协议类型 '{}' 不可识别,且自动检测失败: {}",
84 type_str,
85 e
86 )
87 })?;
88 debug!(
89 "自动检测到协议类型: {:?}(原始配置: '{}')",
90 detected_protocol, type_str
91 );
92 detected_protocol
93 }
94 }
95 } else {
96 debug!("未指定 type 字段,自动检测协议");
98 let detected_protocol = crate::server::detect_mcp_protocol(url_config.get_url())
99 .await
100 .map_err(|e| anyhow::anyhow!("自动检测协议失败: {}", e))?;
101 detected_protocol
102 }
103 }
104 };
105
106 debug!(
107 "MCP ID: {}, 客户端协议: {:?}, 后端协议: {:?}",
108 mcp_id, mcp_router_path.mcp_protocol, backend_protocol
109 );
110
111 let client_info = ClientInfo {
113 protocol_version: Default::default(),
114 capabilities: ClientCapabilities::builder()
115 .enable_experimental()
116 .enable_roots()
117 .enable_roots_list_changed()
118 .enable_sampling()
119 .build(),
120 ..Default::default()
121 };
122
123 let client = match &mcp_config {
125 McpServerConfig::Command(cmd_config) => {
126 let mut command = Command::new(&cmd_config.command);
128
129 if let Some(args) = &cmd_config.args {
131 command.args(args);
132 }
133
134 if let Some(env_vars) = &cmd_config.env {
136 for (key, value) in env_vars {
137 command.env(key, value);
138 }
139 }
140
141 log_command_details(cmd_config, &mcp_router_path);
143
144 info!(
145 "子进程已启动,MCP ID: {}, 类型: {:?}",
146 mcp_router_path.mcp_id,
147 mcp_type.clone()
148 );
149
150 let tokio_process = TokioChildProcess::new(command)?;
152 client_info.serve(tokio_process).await?
153 }
154 McpServerConfig::Url(url_config) => {
155 info!(
157 "连接到远程MCP服务: {}, 后端协议: {:?}, 客户端协议: {:?}",
158 url_config.get_url(),
159 backend_protocol,
160 mcp_router_path.mcp_protocol
161 );
162
163 match backend_protocol {
164 McpProtocol::Stdio => {
165 return Err(anyhow::anyhow!("URL 配置的 MCP 服务不能使用 Stdio 协议"));
167 }
168 McpProtocol::Sse => {
169 info!("使用SSE协议连接到: {}", url_config.get_url());
171
172 let mut headers = reqwest::header::HeaderMap::new();
174
175 if let Some(config_headers) = &url_config.headers {
177 for (key, value) in config_headers {
178 headers.insert(
181 reqwest::header::HeaderName::try_from(key).map_err(|e| {
182 anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
183 })?,
184 value.parse().map_err(|e| {
185 anyhow::anyhow!(
186 "Invalid header value for {}: {}, error: {}",
187 key,
188 value,
189 e
190 )
191 })?,
192 );
193 }
194 info!(
195 "添加了 {} 个自定义 headers(包含 Authorization)",
196 headers.len()
197 );
198 } else {
199 info!("没有配置自定义 headers");
200 }
201
202 let client = reqwest::Client::builder()
203 .default_headers(headers)
204 .build()
205 .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
206
207 let sse_config = rmcp::transport::sse_client::SseClientConfig {
209 sse_endpoint: url_config.get_url().to_string().into(),
210 ..Default::default()
211 };
212
213 let sse_transport =
214 SseClientTransport::start_with_client(client, sse_config).await?;
215 client_info.serve(sse_transport).await?
216 }
217 McpProtocol::Stream => {
218 info!("使用Streamable HTTP协议连接到: {}", url_config.get_url());
220
221 let mut headers = reqwest::header::HeaderMap::new();
223
224 if let Some(config_headers) = &url_config.headers {
226 for (key, value) in config_headers {
227 if key.eq_ignore_ascii_case("Authorization") {
229 continue;
230 }
231 headers.insert(
232 reqwest::header::HeaderName::try_from(key).map_err(|e| {
233 anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
234 })?,
235 value.parse().map_err(|e| {
236 anyhow::anyhow!(
237 "Invalid header value for {}: {}, error: {}",
238 key,
239 value,
240 e
241 )
242 })?,
243 );
244 }
245 info!("添加了 {} 个自定义 headers", headers.len());
246 } else {
247 info!("没有配置自定义 headers");
248 }
249
250 let client = reqwest::Client::builder()
251 .default_headers(headers)
252 .build()
253 .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
254
255 let auth_header = url_config.headers.as_ref().and_then(|h| {
257 h.iter()
259 .find_map(|(k, v)| {
260 if k.eq_ignore_ascii_case("Authorization") {
261 Some(v)
262 } else {
263 None
264 }
265 })
266 .map(|s| s.strip_prefix("Bearer ").unwrap_or(s).to_string())
267 });
268
269 let config = rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
271 uri: url_config.get_url().to_string().into(),
272 auth_header,
273 ..Default::default()
274 };
275
276 let transport = StreamableHttpClientTransport::with_client(client, config);
277
278 info!(
279 "Streamable HTTP传输已创建,开始建立连接,MCP ID: {}, 类型: {:?}",
280 mcp_router_path.mcp_id,
281 mcp_type.clone()
282 );
283
284 let client = client_info.serve(transport).await?;
286
287 info!(
288 "Streamable HTTP客户端连接成功,MCP ID: {}",
289 mcp_router_path.mcp_id
290 );
291
292 client
293 }
294 }
295 }
296 };
297
298 let proxy_handler = ProxyHandler::with_mcp_id(client, mcp_id.clone());
300
301 let proxy_manager = get_proxy_manager();
303
304 let proxy_handler_clone = proxy_handler.clone();
306
307 let (router, ct) = match mcp_router_path.mcp_protocol.clone() {
311 McpProtocol::Sse => {
313 let addr: String = "0.0.0.0:0".to_string();
316 let sse_path = match &mcp_router_path.mcp_protocol_path {
317 McpProtocolPath::SsePath(sse_path) => sse_path,
318 _ => unreachable!(),
319 };
320 let config = SseServerConfig {
321 bind: addr.parse()?,
322 sse_path: sse_path.sse_path.clone(),
323 post_path: sse_path.message_path.clone(),
324 ct: tokio_util::sync::CancellationToken::new(),
325 sse_keep_alive: Some(std::time::Duration::from_secs(15)), };
327
328 debug!(
329 "创建SSE服务器,配置: bind={}, sse_path={}, post_path={}",
330 config.bind, config.sse_path, config.post_path
331 );
332
333 let (sse_server, router) = SseServer::new(config);
334 let ct = sse_server.with_service(move || proxy_handler_clone.clone());
335 (router, ct)
336 }
337
338 McpProtocol::Stream => {
340 let service = StreamableHttpService::new(
343 move || Ok(proxy_handler_clone.clone()),
344 LocalSessionManager::default().into(),
345 StreamableHttpServerConfig {
346 stateful_mode: false,
347 ..Default::default()
348 },
349 );
350 let router = axum::Router::new().fallback_service(service);
351 let ct = tokio_util::sync::CancellationToken::new();
352 (router, ct)
353 }
354
355 McpProtocol::Stdio => {
357 return Err(anyhow::anyhow!(
358 "客户端协议不能是 Stdio。McpRouterPath::new 不支持创建 Stdio 协议的路由路径"
359 ));
360 }
361 };
362
363 let ct_clone = ct.clone();
365 let mcp_id_clone = mcp_id.clone();
366
367 let mcp_service_status = McpServiceStatus::new(
369 mcp_id_clone.clone(),
370 mcp_type.clone(),
371 mcp_router_path.clone(),
372 ct_clone.clone(),
373 CheckMcpStatusResponseStatus::Ready,
374 );
375 proxy_manager.add_mcp_service_status_and_proxy(mcp_service_status, Some(proxy_handler));
377
378 let router = if matches!(mcp_router_path.mcp_protocol, McpProtocol::Sse) {
381 let modified_router = router.fallback(base_path_fallback_handler);
383 info!("SSE基础路径处理器已添加, 基础路径: {}", base_path);
384 modified_router
385 } else {
386 router
387 };
388
389 info!("注册路由: base_path={}, mcp_id={}", base_path, mcp_id);
391 info!(
392 "SSE路径配置: sse_path={}, post_path={}",
393 match &mcp_router_path.mcp_protocol_path {
394 McpProtocolPath::SsePath(sse_path) => &sse_path.sse_path,
395 _ => "N/A",
396 },
397 match &mcp_router_path.mcp_protocol_path {
398 McpProtocolPath::SsePath(sse_path) => &sse_path.message_path,
399 _ => "N/A",
400 }
401 );
402 DynamicRouterService::register_route(&base_path, router.clone());
403 info!("路由注册完成: base_path={}", base_path);
404
405 Ok((router, ct))
407}
408
409#[axum::debug_handler]
411async fn base_path_fallback_handler(
412 method: axum::http::Method,
413 uri: axum::http::Uri,
414 headers: axum::http::HeaderMap,
415) -> impl axum::response::IntoResponse {
416 let path = uri.path();
417 info!("基础路径处理器: {} {}", method, path);
418
419 if path.contains("/sse/proxy/") {
421 match method {
423 axum::http::Method::GET => {
424 let mcp_id = path.split("/sse/proxy/").nth(1);
426
427 if let Some(mcp_id) = mcp_id {
428 let proxy_manager = get_proxy_manager();
430 if proxy_manager.get_mcp_service_status(mcp_id).is_none() {
431 (
433 axum::http::StatusCode::NOT_FOUND,
434 [("Content-Type", "text/plain".to_string())],
435 format!("MCP service '{}' not found", mcp_id).to_string(),
436 )
437 } else {
438 let accept_header = headers.get("accept");
440 if let Some(accept) = accept_header {
441 let accept_str = accept.to_str().unwrap_or("");
442 if accept_str.contains("text/event-stream") {
443 let redirect_uri = format!("{}/sse", path);
445 info!("SSE重定向到: {}", redirect_uri);
446 (
447 axum::http::StatusCode::FOUND,
448 [("Location", redirect_uri.to_string())],
449 "Redirecting to SSE endpoint".to_string(),
450 )
451 } else {
452 (
454 axum::http::StatusCode::BAD_REQUEST,
455 [("Content-Type", "text/plain".to_string())],
456 "SSE error: Invalid Accept header, expected 'text/event-stream'".to_string(),
457 )
458 }
459 } else {
460 (
462 axum::http::StatusCode::BAD_REQUEST,
463 [("Content-Type", "text/plain".to_string())],
464 "SSE error: Missing Accept header, expected 'text/event-stream'"
465 .to_string(),
466 )
467 }
468 }
469 } else {
470 (
472 axum::http::StatusCode::BAD_REQUEST,
473 [("Content-Type", "text/plain".to_string())],
474 "SSE error: Invalid SSE path".to_string(),
475 )
476 }
477 }
478 axum::http::Method::POST => {
479 let redirect_uri = format!("{}/message", path);
481 info!("SSE重定向到: {}", redirect_uri);
482 (
483 axum::http::StatusCode::FOUND,
484 [("Location", redirect_uri.to_string())],
485 "Redirecting to message endpoint".to_string(),
486 )
487 }
488 _ => {
489 (
491 axum::http::StatusCode::METHOD_NOT_ALLOWED,
492 [("Allow", "GET, POST".to_string())],
493 "Only GET and POST methods are allowed".to_string(),
494 )
495 }
496 }
497 } else if path.contains("/stream/proxy/") {
498 match method {
500 axum::http::Method::GET => {
501 (
503 axum::http::StatusCode::OK,
504 [("Content-Type", "application/json".to_string())],
505 r#"{"jsonrpc":"2.0","result":{"info":"Streamable MCP Server","version":"1.0"}}"#.to_string(),
506 )
507 }
508 axum::http::Method::POST => {
509 (
511 axum::http::StatusCode::OK,
512 [("Content-Type", "application/json".to_string())],
513 r#"{"jsonrpc":"2.0","result":{"message":"Stream request received","protocol":"streamable-http"}}"#.to_string(),
514 )
515 }
516 _ => {
517 (
519 axum::http::StatusCode::METHOD_NOT_ALLOWED,
520 [("Allow", "GET, POST".to_string())],
521 "Only GET and POST methods are allowed".to_string(),
522 )
523 }
524 }
525 } else {
526 (
528 axum::http::StatusCode::BAD_REQUEST,
529 [("Content-Type", "text/plain".to_string())],
530 "Unknown protocol or path".to_string(),
531 )
532 }
533}
534
535fn log_command_details(mcp_config: &McpServerCommandConfig, mcp_router_path: &McpRouterPath) {
537 let args_str = mcp_config
539 .args
540 .as_ref()
541 .map_or(String::new(), |args| args.join(" "));
542 let cmd_str = format!("执行命令: {} {}", mcp_config.command, args_str);
543 debug!("{cmd_str}");
544
545 if let Some(env_vars) = &mcp_config.env {
547 let env_vars: Vec<String> = env_vars.iter().map(|(k, v)| format!("{k}={v}")).collect();
548 if !env_vars.is_empty() {
549 debug!("环境变量: {}", env_vars.join(", "));
550 }
551 }
552
553 debug!(
555 "完整命令,mcpId={}, command={:?}",
556 mcp_router_path.mcp_id, mcp_config.command
557 );
558
559 let args_str = mcp_config
561 .args
562 .as_ref()
563 .map_or(String::new(), |args| args.join(" "));
564 let env_str = mcp_config.env.as_ref().map_or(String::new(), |env| {
565 env.iter()
566 .map(|(k, v)| format!("{k}={v}"))
567 .collect::<Vec<String>>()
568 .join(" ")
569 });
570
571 let full_command = format!("{} {} {}", mcp_config.command, args_str, env_str);
572 info!(
573 "完整命令字符串,mcpId={},command={:?}",
574 mcp_router_path.mcp_id, full_command
575 );
576}