mcp_stdio_proxy/server/task/
mcp_start_task.rs

1use crate::{
2    AppError, DynamicRouterService, ProxyHandler, get_proxy_manager,
3    model::{
4        CheckMcpStatusResponseStatus, McpConfig, McpProtocol, McpProtocolPath, McpRouterPath,
5        McpServerCommandConfig, McpServerConfig, McpServiceStatus, McpType,
6    },
7};
8
9use anyhow::Result;
10use log::{debug, info};
11use rmcp::{
12    ServiceExt,
13    model::{ClientCapabilities, ClientInfo},
14    transport::streamable_http_server::{
15        StreamableHttpService, session::local::LocalSessionManager,
16    },
17    transport::{
18        SseClientTransport, SseServer, TokioChildProcess, StreamableHttpServerConfig,
19        sse_server::SseServerConfig,
20        streamable_http_client::StreamableHttpClientTransport,
21    },
22};
23use tokio::process::Command;
24
25///根据 mcp_id 和 mcp_json_config 启动mcp服务
26pub async fn mcp_start_task(
27    mcp_config: McpConfig,
28) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
29    let mcp_id = mcp_config.mcp_id.clone();
30    let client_protocol = mcp_config.client_protocol.clone();
31
32    // 使用客户端协议创建路由路径(决定暴露的API接口)
33    let mcp_router_path: McpRouterPath =
34        McpRouterPath::new(mcp_id, client_protocol).map_err(|e| AppError::McpServerError(e))?;
35
36    let mcp_json_config = mcp_config
37        .mcp_json_config
38        .clone()
39        .expect("mcp_json_config is required");
40
41    let mcp_server_config = McpServerConfig::try_from(mcp_json_config)?;
42
43    // 使用新的集成方法,后端协议在函数内部解析
44    integrate_sse_server_with_axum(
45        mcp_server_config.clone(),
46        mcp_router_path.clone(),
47        mcp_config.mcp_type,
48    )
49    .await
50}
51
52// 创建一个新函数,将 SseServer 与 axum 路由集成
53pub async fn integrate_sse_server_with_axum(
54    mcp_config: McpServerConfig,
55    mcp_router_path: McpRouterPath,
56    mcp_type: McpType,
57) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
58    let base_path = mcp_router_path.base_path.clone();
59    let mcp_id = mcp_router_path.mcp_id.clone();
60
61    // 根据MCP服务器配置解析后端协议
62    let backend_protocol = match &mcp_config {
63        // 命令行配置:使用 stdio 协议
64        McpServerConfig::Command(_) => McpProtocol::Stdio,
65        // URL配置:解析 type 字段或自动检测
66        McpServerConfig::Url(url_config) => {
67            // 首先检查 type 字段
68            if let Some(type_str) = &url_config.r#type {
69                // 尝试解析 type 字段
70                match type_str.parse::<McpProtocol>() {
71                    Ok(protocol) => {
72                        debug!("使用配置中指定的协议类型: {} -> {:?}", type_str, protocol);
73                        protocol
74                    }
75                    Err(_) => {
76                        // 如果解析失败,自动检测协议
77                        debug!("协议类型 '{}' 无法识别,开始自动检测协议", type_str);
78                        let detected_protocol =
79                            crate::server::detect_mcp_protocol(url_config.get_url())
80                                .await
81                                .map_err(|e| {
82                                    anyhow::anyhow!(
83                                        "协议类型 '{}' 不可识别,且自动检测失败: {}",
84                                        type_str,
85                                        e
86                                    )
87                                })?;
88                        debug!(
89                            "自动检测到协议类型: {:?}(原始配置: '{}')",
90                            detected_protocol, type_str
91                        );
92                        detected_protocol
93                    }
94                }
95            } else {
96                // 没有指定 type 字段,自动检测协议
97                debug!("未指定 type 字段,自动检测协议");
98                let detected_protocol = crate::server::detect_mcp_protocol(url_config.get_url())
99                    .await
100                    .map_err(|e| anyhow::anyhow!("自动检测协议失败: {}", e))?;
101                detected_protocol
102            }
103        }
104    };
105
106    debug!(
107        "MCP ID: {}, 客户端协议: {:?}, 后端协议: {:?}",
108        mcp_id, mcp_router_path.mcp_protocol, backend_protocol
109    );
110
111    // 创建客户端信息
112    let client_info = ClientInfo {
113        protocol_version: Default::default(),
114        capabilities: ClientCapabilities::builder()
115            .enable_experimental()
116            .enable_roots()
117            .enable_roots_list_changed()
118            .enable_sampling()
119            .build(),
120        ..Default::default()
121    };
122
123    // 根据配置类型创建不同的客户端服务
124    let client = match &mcp_config {
125        McpServerConfig::Command(cmd_config) => {
126            // 创建子进程命令
127            let mut command = Command::new(&cmd_config.command);
128
129            // 正确处理Option<Vec<String>>
130            if let Some(args) = &cmd_config.args {
131                command.args(args);
132            }
133
134            // 正确处理Option<HashMap<String, String>>
135            if let Some(env_vars) = &cmd_config.env {
136                for (key, value) in env_vars {
137                    command.env(key, value);
138                }
139            }
140
141            // 记录命令执行信息,方便调试
142            log_command_details(cmd_config, &mcp_router_path);
143
144            info!(
145                "子进程已启动,MCP ID: {}, 类型: {:?}",
146                mcp_router_path.mcp_id,
147                mcp_type.clone()
148            );
149
150            // 创建子进程传输并创建客户端服务
151            let tokio_process = TokioChildProcess::new(command)?;
152            client_info.serve(tokio_process).await?
153        }
154        McpServerConfig::Url(url_config) => {
155            // 根据后端协议类型创建不同的客户端传输
156            info!(
157                "连接到远程MCP服务: {}, 后端协议: {:?}, 客户端协议: {:?}",
158                url_config.get_url(),
159                backend_protocol,
160                mcp_router_path.mcp_protocol
161            );
162
163            match backend_protocol {
164                McpProtocol::Stdio => {
165                    // URL 配置不应该出现 Stdio 协议
166                    return Err(anyhow::anyhow!("URL 配置的 MCP 服务不能使用 Stdio 协议"));
167                }
168                McpProtocol::Sse => {
169                    // SSE 协议 - 创建 SSE 客户端传输
170                    info!("使用SSE协议连接到: {}", url_config.get_url());
171
172                    // 创建带有自定义 headers 的 reqwest client
173                    let mut headers = reqwest::header::HeaderMap::new();
174
175                    // 添加配置中的自定义 headers
176                    if let Some(config_headers) = &url_config.headers {
177                        for (key, value) in config_headers {
178                            // SSE 协议:直接添加所有 headers(不跳过 Authorization)
179                            // 原因:SSE 协议没有官方的 auth_header 字段配置
180                            headers.insert(
181                                reqwest::header::HeaderName::try_from(key).map_err(|e| {
182                                    anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
183                                })?,
184                                value.parse().map_err(|e| {
185                                    anyhow::anyhow!(
186                                        "Invalid header value for {}: {}, error: {}",
187                                        key,
188                                        value,
189                                        e
190                                    )
191                                })?,
192                            );
193                        }
194                        info!(
195                            "添加了 {} 个自定义 headers(包含 Authorization)",
196                            headers.len()
197                        );
198                    } else {
199                        info!("没有配置自定义 headers");
200                    }
201
202                    let client = reqwest::Client::builder()
203                        .default_headers(headers)
204                        .build()
205                        .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
206
207                    // 创建 SSE 客户端配置
208                    let sse_config = rmcp::transport::sse_client::SseClientConfig {
209                        sse_endpoint: url_config.get_url().to_string().into(),
210                        ..Default::default()
211                    };
212
213                    let sse_transport =
214                        SseClientTransport::start_with_client(client, sse_config).await?;
215                    client_info.serve(sse_transport).await?
216                }
217                McpProtocol::Stream => {
218                    // Streamable 协议 - 创建 Streamable HTTP 客户端传输
219                    info!("使用Streamable HTTP协议连接到: {}", url_config.get_url());
220
221                    // 创建自定义 client 和配置(支持 Authorization header)
222                    let mut headers = reqwest::header::HeaderMap::new();
223
224                    // 添加配置中的自定义 headers(排除 Authorization)
225                    if let Some(config_headers) = &url_config.headers {
226                        for (key, value) in config_headers {
227                            // 跳过 Authorization header,它会通过 auth_header 配置字段传递
228                            if key.eq_ignore_ascii_case("Authorization") {
229                                continue;
230                            }
231                            headers.insert(
232                                reqwest::header::HeaderName::try_from(key).map_err(|e| {
233                                    anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
234                                })?,
235                                value.parse().map_err(|e| {
236                                    anyhow::anyhow!(
237                                        "Invalid header value for {}: {}, error: {}",
238                                        key,
239                                        value,
240                                        e
241                                    )
242                                })?,
243                            );
244                        }
245                        info!("添加了 {} 个自定义 headers", headers.len());
246                    } else {
247                        info!("没有配置自定义 headers");
248                    }
249
250                    let client = reqwest::Client::builder()
251                        .default_headers(headers)
252                        .build()
253                        .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
254
255                    // 提取 Authorization header 用于配置(不区分大小写)
256                    let auth_header = url_config.headers.as_ref().and_then(|h| {
257                        // HTTP header 名称不区分大小写,查找 Authorization
258                        h.iter()
259                            .find_map(|(k, v)| {
260                                if k.eq_ignore_ascii_case("Authorization") {
261                                    Some(v)
262                                } else {
263                                    None
264                                }
265                            })
266                            .map(|s| s.strip_prefix("Bearer ").unwrap_or(s).to_string())
267                    });
268
269                    // 创建传输配置
270                    let config = rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
271                        uri: url_config.get_url().to_string().into(),
272                        auth_header,
273                        ..Default::default()
274                    };
275
276                    let transport = StreamableHttpClientTransport::with_client(client, config);
277
278                    info!(
279                        "Streamable HTTP传输已创建,开始建立连接,MCP ID: {}, 类型: {:?}",
280                        mcp_router_path.mcp_id,
281                        mcp_type.clone()
282                    );
283
284                    // serve 会建立连接并完成初始化握手
285                    let client = client_info.serve(transport).await?;
286
287                    info!(
288                        "Streamable HTTP客户端连接成功,MCP ID: {}",
289                        mcp_router_path.mcp_id
290                    );
291
292                    client
293                }
294            }
295        }
296    };
297
298    // 创建代理处理器
299    let proxy_handler = ProxyHandler::with_mcp_id(client, mcp_id.clone());
300
301    // 获取全局 ProxyHandlerManager
302    let proxy_manager = get_proxy_manager();
303
304    // 注册代理处理器(ProxyHandler 内部已使用 Arc,clone 非常轻量)
305    let proxy_handler_clone = proxy_handler.clone();
306
307    // 根据客户端协议和后端协议创建服务器(支持协议转换)
308    // 支持三种模式:
309    // 根据客户端协议(主导)创建路由:决定对外暴露的 API 接口类型
310    let (router, ct) = match mcp_router_path.mcp_protocol.clone() {
311        // ================ 客户端使用 SSE 协议 ================
312        McpProtocol::Sse => {
313            // 对外提供 SSE 接口
314            // 协议转换由 proxy_handler_clone 自动处理
315            let addr: String = "0.0.0.0:0".to_string();
316            let sse_path = match &mcp_router_path.mcp_protocol_path {
317                McpProtocolPath::SsePath(sse_path) => sse_path,
318                _ => unreachable!(),
319            };
320            let config = SseServerConfig {
321                bind: addr.parse()?,
322                sse_path: sse_path.sse_path.clone(),
323                post_path: sse_path.message_path.clone(),
324                ct: tokio_util::sync::CancellationToken::new(),
325                sse_keep_alive: Some(std::time::Duration::from_secs(15)), // 15秒心跳,防止连接被中间代理关闭
326            };
327
328            debug!(
329                "创建SSE服务器,配置: bind={}, sse_path={}, post_path={}",
330                config.bind, config.sse_path, config.post_path
331            );
332
333            let (sse_server, router) = SseServer::new(config);
334            let ct = sse_server.with_service(move || proxy_handler_clone.clone());
335            (router, ct)
336        }
337
338        // ================ 客户端使用 Streamable HTTP 协议 ================
339        McpProtocol::Stream => {
340            // 对外提供 Streamable HTTP 接口
341            // 内部协议转换由 proxy_handler_clone 自动处理
342            let service = StreamableHttpService::new(
343                move || Ok(proxy_handler_clone.clone()),
344                LocalSessionManager::default().into(),
345                StreamableHttpServerConfig {
346                    stateful_mode: false,
347                    ..Default::default()
348                },
349            );
350            let router = axum::Router::new().fallback_service(service);
351            let ct = tokio_util::sync::CancellationToken::new();
352            (router, ct)
353        }
354
355        // 不应该出现的情况
356        McpProtocol::Stdio => {
357            return Err(anyhow::anyhow!(
358                "客户端协议不能是 Stdio。McpRouterPath::new 不支持创建 Stdio 协议的路由路径"
359            ));
360        }
361    };
362
363    // 克隆一份取消令牌和 mcp_id 用于监控子进程
364    let ct_clone = ct.clone();
365    let mcp_id_clone = mcp_id.clone();
366
367    // 存储 MCP 服务状态
368    let mcp_service_status = McpServiceStatus::new(
369        mcp_id_clone.clone(),
370        mcp_type.clone(),
371        mcp_router_path.clone(),
372        ct_clone.clone(),
373        CheckMcpStatusResponseStatus::Ready,
374    );
375    // 添加 MCP 服务状态到全局管理器,以及 proxy_handler 的透明代理
376    proxy_manager.add_mcp_service_status_and_proxy(mcp_service_status, Some(proxy_handler));
377
378    // 为SSE和Stream协议添加基础路径处理
379    // 支持直接访问基础路径,自动重定向到正确的子路径
380    let router = if matches!(mcp_router_path.mcp_protocol, McpProtocol::Sse) {
381        // 使用fallback处理器来匹配基础路径
382        let modified_router = router.fallback(base_path_fallback_handler);
383        info!("SSE基础路径处理器已添加, 基础路径: {}", base_path);
384        modified_router
385    } else {
386        router
387    };
388
389    // 注册路由到全局路由表
390    info!("注册路由: base_path={}, mcp_id={}", base_path, mcp_id);
391    info!(
392        "SSE路径配置: sse_path={}, post_path={}",
393        match &mcp_router_path.mcp_protocol_path {
394            McpProtocolPath::SsePath(sse_path) => &sse_path.sse_path,
395            _ => "N/A",
396        },
397        match &mcp_router_path.mcp_protocol_path {
398            McpProtocolPath::SsePath(sse_path) => &sse_path.message_path,
399            _ => "N/A",
400        }
401    );
402    DynamicRouterService::register_route(&base_path, router.clone());
403    info!("路由注册完成: base_path={}", base_path);
404
405    // 返回路由和取消令牌
406    Ok((router, ct))
407}
408
409// 基础路径处理器 - 支持直接访问基础路径,自动重定向到正确的子路径
410#[axum::debug_handler]
411async fn base_path_fallback_handler(
412    method: axum::http::Method,
413    uri: axum::http::Uri,
414    headers: axum::http::HeaderMap,
415) -> impl axum::response::IntoResponse {
416    let path = uri.path();
417    info!("基础路径处理器: {} {}", method, path);
418
419    // 判断是SSE还是Stream协议
420    if path.contains("/sse/proxy/") {
421        // SSE协议处理
422        match method {
423            axum::http::Method::GET => {
424                // 从路径中提取 MCP ID
425                let mcp_id = path.split("/sse/proxy/").nth(1);
426
427                if let Some(mcp_id) = mcp_id {
428                    // 检查MCP服务是否存在
429                    let proxy_manager = get_proxy_manager();
430                    if proxy_manager.get_mcp_service_status(mcp_id).is_none() {
431                        // MCP服务不存在
432                        (
433                            axum::http::StatusCode::NOT_FOUND,
434                            [("Content-Type", "text/plain".to_string())],
435                            format!("MCP service '{}' not found", mcp_id).to_string(),
436                        )
437                    } else {
438                        // MCP服务存在,检查Accept头
439                        let accept_header = headers.get("accept");
440                        if let Some(accept) = accept_header {
441                            let accept_str = accept.to_str().unwrap_or("");
442                            if accept_str.contains("text/event-stream") {
443                                // 正确的Accept头,重定向到 /sse
444                                let redirect_uri = format!("{}/sse", path);
445                                info!("SSE重定向到: {}", redirect_uri);
446                                (
447                                    axum::http::StatusCode::FOUND,
448                                    [("Location", redirect_uri.to_string())],
449                                    "Redirecting to SSE endpoint".to_string(),
450                                )
451                            } else {
452                                // Accept头不正确
453                                (
454                                    axum::http::StatusCode::BAD_REQUEST,
455                                    [("Content-Type", "text/plain".to_string())],
456                                    "SSE error: Invalid Accept header, expected 'text/event-stream'".to_string(),
457                                )
458                            }
459                        } else {
460                            // 没有Accept头
461                            (
462                                axum::http::StatusCode::BAD_REQUEST,
463                                [("Content-Type", "text/plain".to_string())],
464                                "SSE error: Missing Accept header, expected 'text/event-stream'"
465                                    .to_string(),
466                            )
467                        }
468                    }
469                } else {
470                    // 无法从路径中提取MCP ID
471                    (
472                        axum::http::StatusCode::BAD_REQUEST,
473                        [("Content-Type", "text/plain".to_string())],
474                        "SSE error: Invalid SSE path".to_string(),
475                    )
476                }
477            }
478            axum::http::Method::POST => {
479                // POST请求重定向到 /message
480                let redirect_uri = format!("{}/message", path);
481                info!("SSE重定向到: {}", redirect_uri);
482                (
483                    axum::http::StatusCode::FOUND,
484                    [("Location", redirect_uri.to_string())],
485                    "Redirecting to message endpoint".to_string(),
486                )
487            }
488            _ => {
489                // 其他方法返回405 Method Not Allowed
490                (
491                    axum::http::StatusCode::METHOD_NOT_ALLOWED,
492                    [("Allow", "GET, POST".to_string())],
493                    "Only GET and POST methods are allowed".to_string(),
494                )
495            }
496        }
497    } else if path.contains("/stream/proxy/") {
498        // Stream协议处理 - 直接返回成功,不重定向
499        match method {
500            axum::http::Method::GET => {
501                // GET请求返回服务器信息
502                (
503                    axum::http::StatusCode::OK,
504                    [("Content-Type", "application/json".to_string())],
505                    r#"{"jsonrpc":"2.0","result":{"info":"Streamable MCP Server","version":"1.0"}}"#.to_string(),
506                )
507            }
508            axum::http::Method::POST => {
509                // POST请求返回成功,让StreamableHttpService处理
510                (
511                    axum::http::StatusCode::OK,
512                    [("Content-Type", "application/json".to_string())],
513                    r#"{"jsonrpc":"2.0","result":{"message":"Stream request received","protocol":"streamable-http"}}"#.to_string(),
514                )
515            }
516            _ => {
517                // 其他方法返回405 Method Not Allowed
518                (
519                    axum::http::StatusCode::METHOD_NOT_ALLOWED,
520                    [("Allow", "GET, POST".to_string())],
521                    "Only GET and POST methods are allowed".to_string(),
522                )
523            }
524        }
525    } else {
526        // 未知协议
527        (
528            axum::http::StatusCode::BAD_REQUEST,
529            [("Content-Type", "text/plain".to_string())],
530            "Unknown protocol or path".to_string(),
531        )
532    }
533}
534
535// 提取记录命令详情的函数
536fn log_command_details(mcp_config: &McpServerCommandConfig, mcp_router_path: &McpRouterPath) {
537    // 打印命令行参数
538    let args_str = mcp_config
539        .args
540        .as_ref()
541        .map_or(String::new(), |args| args.join(" "));
542    let cmd_str = format!("执行命令: {} {}", mcp_config.command, args_str);
543    debug!("{cmd_str}");
544
545    // 打印环境变量
546    if let Some(env_vars) = &mcp_config.env {
547        let env_vars: Vec<String> = env_vars.iter().map(|(k, v)| format!("{k}={v}")).collect();
548        if !env_vars.is_empty() {
549            debug!("环境变量: {}", env_vars.join(", "));
550        }
551    }
552
553    // 打印完整命令
554    debug!(
555        "完整命令,mcpId={}, command={:?}",
556        mcp_router_path.mcp_id, mcp_config.command
557    );
558
559    // 构建完整的命令字符串,用于直接复制运行
560    let args_str = mcp_config
561        .args
562        .as_ref()
563        .map_or(String::new(), |args| args.join(" "));
564    let env_str = mcp_config.env.as_ref().map_or(String::new(), |env| {
565        env.iter()
566            .map(|(k, v)| format!("{k}={v}"))
567            .collect::<Vec<String>>()
568            .join(" ")
569    });
570
571    let full_command = format!("{} {} {}", mcp_config.command, args_str, env_str);
572    info!(
573        "完整命令字符串,mcpId={},command={:?}",
574        mcp_router_path.mcp_id, full_command
575    );
576}