1use crate::{
2 AppError, DynamicRouterService, ProxyHandler, get_proxy_manager,
3 model::{
4 CheckMcpStatusResponseStatus, McpConfig, McpProtocol, McpProtocolPath, McpRouterPath,
5 McpServerCommandConfig, McpServerConfig, McpServiceStatus, McpType,
6 },
7};
8
9use anyhow::Result;
10use log::{debug, info};
11use rmcp::{
12 ServiceExt,
13 model::{ClientCapabilities, ClientInfo},
14 transport::streamable_http_server::{
15 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16 },
17 transport::{
18 TokioChildProcess,
19 streamable_http_client::{
20 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
21 },
22 },
23};
24use tokio::process::Command;
25
26pub async fn mcp_start_task(
28 mcp_config: McpConfig,
29) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
30 let mcp_id = mcp_config.mcp_id.clone();
31 let client_protocol = mcp_config.client_protocol.clone();
32
33 let mcp_router_path: McpRouterPath =
35 McpRouterPath::new(mcp_id, client_protocol).map_err(|e| AppError::McpServerError(e))?;
36
37 let mcp_json_config = mcp_config
38 .mcp_json_config
39 .clone()
40 .expect("mcp_json_config is required");
41
42 let mcp_server_config = McpServerConfig::try_from(mcp_json_config)?;
43
44 integrate_sse_server_with_axum(
46 mcp_server_config.clone(),
47 mcp_router_path.clone(),
48 mcp_config.mcp_type,
49 )
50 .await
51}
52
53pub async fn integrate_sse_server_with_axum(
55 mcp_config: McpServerConfig,
56 mcp_router_path: McpRouterPath,
57 mcp_type: McpType,
58) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
59 let base_path = mcp_router_path.base_path.clone();
60 let mcp_id = mcp_router_path.mcp_id.clone();
61
62 let backend_protocol = match &mcp_config {
64 McpServerConfig::Command(_) => McpProtocol::Stdio,
66 McpServerConfig::Url(url_config) => {
68 if let Some(type_str) = &url_config.r#type {
70 match type_str.parse::<McpProtocol>() {
72 Ok(protocol) => {
73 debug!("使用配置中指定的协议类型: {} -> {:?}", type_str, protocol);
74 protocol
75 }
76 Err(_) => {
77 debug!("协议类型 '{}' 无法识别,开始自动检测协议", type_str);
79 let detected_protocol =
80 crate::server::detect_mcp_protocol(url_config.get_url())
81 .await
82 .map_err(|e| {
83 anyhow::anyhow!(
84 "协议类型 '{}' 不可识别,且自动检测失败: {}",
85 type_str,
86 e
87 )
88 })?;
89 debug!(
90 "自动检测到协议类型: {:?}(原始配置: '{}')",
91 detected_protocol, type_str
92 );
93 detected_protocol
94 }
95 }
96 } else {
97 debug!("未指定 type 字段,自动检测协议");
99 let detected_protocol = crate::server::detect_mcp_protocol(url_config.get_url())
100 .await
101 .map_err(|e| anyhow::anyhow!("自动检测协议失败: {}", e))?;
102 detected_protocol
103 }
104 }
105 };
106
107 debug!(
108 "MCP ID: {}, 客户端协议: {:?}, 后端协议: {:?}",
109 mcp_id, mcp_router_path.mcp_protocol, backend_protocol
110 );
111
112 let client_info = ClientInfo {
114 protocol_version: Default::default(),
115 capabilities: ClientCapabilities::builder()
116 .enable_experimental()
117 .enable_roots()
118 .enable_roots_list_changed()
119 .enable_sampling()
120 .build(),
121 ..Default::default()
122 };
123
124 let client = match &mcp_config {
126 McpServerConfig::Command(cmd_config) => {
127 let mut command = Command::new(&cmd_config.command);
129
130 if let Some(args) = &cmd_config.args {
132 command.args(args);
133 }
134
135 if let Some(env_vars) = &cmd_config.env {
137 for (key, value) in env_vars {
138 command.env(key, value);
139 }
140 }
141
142 log_command_details(cmd_config, &mcp_router_path);
144
145 info!(
146 "子进程已启动,MCP ID: {}, 类型: {:?}",
147 mcp_router_path.mcp_id,
148 mcp_type.clone()
149 );
150
151 let tokio_process = TokioChildProcess::new(command)?;
153 client_info.serve(tokio_process).await?
154 }
155 McpServerConfig::Url(url_config) => {
156 info!(
158 "连接到远程MCP服务: {}, 后端协议: {:?}, 客户端协议: {:?}",
159 url_config.get_url(),
160 backend_protocol,
161 mcp_router_path.mcp_protocol
162 );
163
164 match backend_protocol {
165 McpProtocol::Stdio => {
166 return Err(anyhow::anyhow!("URL 配置的 MCP 服务不能使用 Stdio 协议"));
168 }
169 McpProtocol::Sse => {
170 info!(
172 "使用Streamable HTTP协议连接到(SSE兼容模式): {}",
173 url_config.get_url()
174 );
175
176 let mut headers = reqwest::header::HeaderMap::new();
178
179 if let Some(config_headers) = &url_config.headers {
181 for (key, value) in config_headers {
182 headers.insert(
183 reqwest::header::HeaderName::try_from(key).map_err(|e| {
184 anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
185 })?,
186 value.parse().map_err(|e| {
187 anyhow::anyhow!(
188 "Invalid header value for {}: {}, error: {}",
189 key,
190 value,
191 e
192 )
193 })?,
194 );
195 }
196 info!("添加了 {} 个自定义 headers", headers.len());
197 } else {
198 info!("没有配置自定义 headers");
199 }
200
201 let client = reqwest::Client::builder()
202 .default_headers(headers)
203 .build()
204 .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
205
206 let config = StreamableHttpClientTransportConfig {
208 uri: url_config.get_url().to_string().into(),
209 ..Default::default()
210 };
211
212 let transport = StreamableHttpClientTransport::with_client(client, config);
213 client_info.serve(transport).await?
214 }
215 McpProtocol::Stream => {
216 info!("使用Streamable HTTP协议连接到: {}", url_config.get_url());
218
219 let mut headers = reqwest::header::HeaderMap::new();
221
222 if let Some(config_headers) = &url_config.headers {
224 for (key, value) in config_headers {
225 if key.eq_ignore_ascii_case("Authorization") {
227 continue;
228 }
229 headers.insert(
230 reqwest::header::HeaderName::try_from(key).map_err(|e| {
231 anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
232 })?,
233 value.parse().map_err(|e| {
234 anyhow::anyhow!(
235 "Invalid header value for {}: {}, error: {}",
236 key,
237 value,
238 e
239 )
240 })?,
241 );
242 }
243 info!("添加了 {} 个自定义 headers", headers.len());
244 } else {
245 info!("没有配置自定义 headers");
246 }
247
248 let client = reqwest::Client::builder()
249 .default_headers(headers)
250 .build()
251 .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
252
253 let auth_header = url_config.headers.as_ref().and_then(|h| {
255 h.iter()
257 .find_map(|(k, v)| {
258 if k.eq_ignore_ascii_case("Authorization") {
259 Some(v)
260 } else {
261 None
262 }
263 })
264 .map(|s| s.strip_prefix("Bearer ").unwrap_or(s).to_string())
265 });
266
267 let config = rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
269 uri: url_config.get_url().to_string().into(),
270 auth_header,
271 ..Default::default()
272 };
273
274 let transport = StreamableHttpClientTransport::with_client(client, config);
275
276 info!(
277 "Streamable HTTP传输已创建,开始建立连接,MCP ID: {}, 类型: {:?}",
278 mcp_router_path.mcp_id,
279 mcp_type.clone()
280 );
281
282 let client = client_info.serve(transport).await?;
284
285 info!(
286 "Streamable HTTP客户端连接成功,MCP ID: {}",
287 mcp_router_path.mcp_id
288 );
289
290 client
291 }
292 }
293 }
294 };
295
296 let proxy_handler = ProxyHandler::with_mcp_id(client, mcp_id.clone());
298
299 let proxy_manager = get_proxy_manager();
301
302 let proxy_handler_clone = proxy_handler.clone();
304
305 let (router, ct) = match mcp_router_path.mcp_protocol.clone() {
309 McpProtocol::Sse => {
312 debug!(
315 "创建Streamable HTTP服务器(SSE兼容模式), mcp_id={}",
316 mcp_router_path.mcp_id
317 );
318
319 let ct = tokio_util::sync::CancellationToken::new();
320 let service = StreamableHttpService::new(
321 move || Ok(proxy_handler_clone.clone()),
322 LocalSessionManager::default().into(),
323 StreamableHttpServerConfig {
324 cancellation_token: ct.clone(),
325 ..Default::default()
326 },
327 );
328 let router = axum::Router::new().fallback_service(service);
329 (router, ct)
330 }
331
332 McpProtocol::Stream => {
334 let service = StreamableHttpService::new(
337 move || Ok(proxy_handler_clone.clone()),
338 LocalSessionManager::default().into(),
339 Default::default(),
340 );
341 let router = axum::Router::new().fallback_service(service);
342 let ct = tokio_util::sync::CancellationToken::new();
343 (router, ct)
344 }
345
346 McpProtocol::Stdio => {
348 return Err(anyhow::anyhow!(
349 "客户端协议不能是 Stdio。McpRouterPath::new 不支持创建 Stdio 协议的路由路径"
350 ));
351 }
352 };
353
354 let ct_clone = ct.clone();
356 let mcp_id_clone = mcp_id.clone();
357
358 let mcp_service_status = McpServiceStatus::new(
360 mcp_id_clone.clone(),
361 mcp_type.clone(),
362 mcp_router_path.clone(),
363 ct_clone.clone(),
364 CheckMcpStatusResponseStatus::Ready,
365 );
366 proxy_manager.add_mcp_service_status_and_proxy(mcp_service_status, Some(proxy_handler));
368
369 let router = if matches!(mcp_router_path.mcp_protocol, McpProtocol::Sse) {
372 let modified_router = router.fallback(base_path_fallback_handler);
374 info!("SSE基础路径处理器已添加, 基础路径: {}", base_path);
375 modified_router
376 } else {
377 router
378 };
379
380 info!("注册路由: base_path={}, mcp_id={}", base_path, mcp_id);
382 info!(
383 "SSE路径配置: sse_path={}, post_path={}",
384 match &mcp_router_path.mcp_protocol_path {
385 McpProtocolPath::SsePath(sse_path) => &sse_path.sse_path,
386 _ => "N/A",
387 },
388 match &mcp_router_path.mcp_protocol_path {
389 McpProtocolPath::SsePath(sse_path) => &sse_path.message_path,
390 _ => "N/A",
391 }
392 );
393 DynamicRouterService::register_route(&base_path, router.clone());
394 info!("路由注册完成: base_path={}", base_path);
395
396 Ok((router, ct))
398}
399
400#[axum::debug_handler]
402async fn base_path_fallback_handler(
403 method: axum::http::Method,
404 uri: axum::http::Uri,
405 headers: axum::http::HeaderMap,
406) -> impl axum::response::IntoResponse {
407 let path = uri.path();
408 info!("基础路径处理器: {} {}", method, path);
409
410 if path.contains("/sse/proxy/") {
412 match method {
414 axum::http::Method::GET => {
415 let mcp_id = path.split("/sse/proxy/").nth(1);
417
418 if let Some(mcp_id) = mcp_id {
419 let proxy_manager = get_proxy_manager();
421 if proxy_manager.get_mcp_service_status(mcp_id).is_none() {
422 (
424 axum::http::StatusCode::NOT_FOUND,
425 [("Content-Type", "text/plain".to_string())],
426 format!("MCP service '{}' not found", mcp_id).to_string(),
427 )
428 } else {
429 let accept_header = headers.get("accept");
431 if let Some(accept) = accept_header {
432 let accept_str = accept.to_str().unwrap_or("");
433 if accept_str.contains("text/event-stream") {
434 let redirect_uri = format!("{}/sse", path);
436 info!("SSE重定向到: {}", redirect_uri);
437 (
438 axum::http::StatusCode::FOUND,
439 [("Location", redirect_uri.to_string())],
440 "Redirecting to SSE endpoint".to_string(),
441 )
442 } else {
443 (
445 axum::http::StatusCode::BAD_REQUEST,
446 [("Content-Type", "text/plain".to_string())],
447 "SSE error: Invalid Accept header, expected 'text/event-stream'".to_string(),
448 )
449 }
450 } else {
451 (
453 axum::http::StatusCode::BAD_REQUEST,
454 [("Content-Type", "text/plain".to_string())],
455 "SSE error: Missing Accept header, expected 'text/event-stream'"
456 .to_string(),
457 )
458 }
459 }
460 } else {
461 (
463 axum::http::StatusCode::BAD_REQUEST,
464 [("Content-Type", "text/plain".to_string())],
465 "SSE error: Invalid SSE path".to_string(),
466 )
467 }
468 }
469 axum::http::Method::POST => {
470 let redirect_uri = format!("{}/message", path);
472 info!("SSE重定向到: {}", redirect_uri);
473 (
474 axum::http::StatusCode::FOUND,
475 [("Location", redirect_uri.to_string())],
476 "Redirecting to message endpoint".to_string(),
477 )
478 }
479 _ => {
480 (
482 axum::http::StatusCode::METHOD_NOT_ALLOWED,
483 [("Allow", "GET, POST".to_string())],
484 "Only GET and POST methods are allowed".to_string(),
485 )
486 }
487 }
488 } else if path.contains("/stream/proxy/") {
489 match method {
491 axum::http::Method::GET => {
492 (
494 axum::http::StatusCode::OK,
495 [("Content-Type", "application/json".to_string())],
496 r#"{"jsonrpc":"2.0","result":{"info":"Streamable MCP Server","version":"1.0"}}"#.to_string(),
497 )
498 }
499 axum::http::Method::POST => {
500 (
502 axum::http::StatusCode::OK,
503 [("Content-Type", "application/json".to_string())],
504 r#"{"jsonrpc":"2.0","result":{"message":"Stream request received","protocol":"streamable-http"}}"#.to_string(),
505 )
506 }
507 _ => {
508 (
510 axum::http::StatusCode::METHOD_NOT_ALLOWED,
511 [("Allow", "GET, POST".to_string())],
512 "Only GET and POST methods are allowed".to_string(),
513 )
514 }
515 }
516 } else {
517 (
519 axum::http::StatusCode::BAD_REQUEST,
520 [("Content-Type", "text/plain".to_string())],
521 "Unknown protocol or path".to_string(),
522 )
523 }
524}
525
526fn log_command_details(mcp_config: &McpServerCommandConfig, mcp_router_path: &McpRouterPath) {
528 let args_str = mcp_config
530 .args
531 .as_ref()
532 .map_or(String::new(), |args| args.join(" "));
533 let cmd_str = format!("执行命令: {} {}", mcp_config.command, args_str);
534 debug!("{cmd_str}");
535
536 if let Some(env_vars) = &mcp_config.env {
538 let env_vars: Vec<String> = env_vars.iter().map(|(k, v)| format!("{k}={v}")).collect();
539 if !env_vars.is_empty() {
540 debug!("环境变量: {}", env_vars.join(", "));
541 }
542 }
543
544 debug!(
546 "完整命令,mcpId={}, command={:?}",
547 mcp_router_path.mcp_id, mcp_config.command
548 );
549
550 let args_str = mcp_config
552 .args
553 .as_ref()
554 .map_or(String::new(), |args| args.join(" "));
555 let env_str = mcp_config.env.as_ref().map_or(String::new(), |env| {
556 env.iter()
557 .map(|(k, v)| format!("{k}={v}"))
558 .collect::<Vec<String>>()
559 .join(" ")
560 });
561
562 let full_command = format!("{} {} {}", mcp_config.command, args_str, env_str);
563 info!(
564 "完整命令字符串,mcpId={},command={:?}",
565 mcp_router_path.mcp_id, full_command
566 );
567}