mcp_stdio_proxy/server/task/
mcp_start_task.rs

1use crate::{
2    AppError, DynamicRouterService, ProxyHandler, get_proxy_manager,
3    model::{
4        CheckMcpStatusResponseStatus, McpConfig, McpProtocol, McpProtocolPath, McpRouterPath,
5        McpServerCommandConfig, McpServerConfig, McpServiceStatus, McpType,
6    },
7};
8
9use anyhow::Result;
10use log::{debug, info};
11use rmcp::{
12    ServiceExt,
13    model::{ClientCapabilities, ClientInfo},
14    transport::streamable_http_server::{
15        StreamableHttpService, session::local::LocalSessionManager,
16    },
17    transport::{
18        SseClientTransport, SseServer, TokioChildProcess, sse_server::SseServerConfig,
19        streamable_http_client::StreamableHttpClientTransport,
20    },
21};
22use tokio::process::Command;
23
24///根据 mcp_id 和 mcp_json_config 启动mcp服务
25pub async fn mcp_start_task(
26    mcp_config: McpConfig,
27) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
28    let mcp_id = mcp_config.mcp_id.clone();
29    let client_protocol = mcp_config.client_protocol.clone();
30
31    // 使用客户端协议创建路由路径(决定暴露的API接口)
32    let mcp_router_path: McpRouterPath =
33        McpRouterPath::new(mcp_id, client_protocol).map_err(|e| AppError::McpServerError(e))?;
34
35    let mcp_json_config = mcp_config
36        .mcp_json_config
37        .clone()
38        .expect("mcp_json_config is required");
39
40    let mcp_server_config = McpServerConfig::try_from(mcp_json_config)?;
41
42    // 使用新的集成方法,后端协议在函数内部解析
43    integrate_sse_server_with_axum(
44        mcp_server_config.clone(),
45        mcp_router_path.clone(),
46        mcp_config.mcp_type,
47    )
48    .await
49}
50
51// 创建一个新函数,将 SseServer 与 axum 路由集成
52pub async fn integrate_sse_server_with_axum(
53    mcp_config: McpServerConfig,
54    mcp_router_path: McpRouterPath,
55    mcp_type: McpType,
56) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
57    let base_path = mcp_router_path.base_path.clone();
58    let mcp_id = mcp_router_path.mcp_id.clone();
59
60    // 根据MCP服务器配置解析后端协议
61    let backend_protocol = match &mcp_config {
62        // 命令行配置:使用 stdio 协议
63        McpServerConfig::Command(_) => McpProtocol::Stdio,
64        // URL配置:解析 type 字段或自动检测
65        McpServerConfig::Url(url_config) => {
66            // 首先检查 type 字段
67            if let Some(type_str) = &url_config.r#type {
68                // 尝试解析 type 字段
69                match type_str.parse::<McpProtocol>() {
70                    Ok(protocol) => {
71                        debug!("使用配置中指定的协议类型: {} -> {:?}", type_str, protocol);
72                        protocol
73                    }
74                    Err(_) => {
75                        // 如果解析失败,自动检测协议
76                        debug!("协议类型 '{}' 无法识别,开始自动检测协议", type_str);
77                        let detected_protocol =
78                            crate::server::detect_mcp_protocol(url_config.get_url())
79                                .await
80                                .map_err(|e| {
81                                    anyhow::anyhow!(
82                                        "协议类型 '{}' 不可识别,且自动检测失败: {}",
83                                        type_str,
84                                        e
85                                    )
86                                })?;
87                        debug!(
88                            "自动检测到协议类型: {:?}(原始配置: '{}')",
89                            detected_protocol, type_str
90                        );
91                        detected_protocol
92                    }
93                }
94            } else {
95                // 没有指定 type 字段,自动检测协议
96                debug!("未指定 type 字段,自动检测协议");
97                let detected_protocol = crate::server::detect_mcp_protocol(url_config.get_url())
98                    .await
99                    .map_err(|e| anyhow::anyhow!("自动检测协议失败: {}", e))?;
100                detected_protocol
101            }
102        }
103    };
104
105    debug!(
106        "MCP ID: {}, 客户端协议: {:?}, 后端协议: {:?}",
107        mcp_id, mcp_router_path.mcp_protocol, backend_protocol
108    );
109
110    // 创建客户端信息
111    let client_info = ClientInfo {
112        protocol_version: Default::default(),
113        capabilities: ClientCapabilities::builder()
114            .enable_experimental()
115            .enable_roots()
116            .enable_roots_list_changed()
117            .enable_sampling()
118            .build(),
119        ..Default::default()
120    };
121
122    // 根据配置类型创建不同的客户端服务
123    let client = match &mcp_config {
124        McpServerConfig::Command(cmd_config) => {
125            // 创建子进程命令
126            let mut command = Command::new(&cmd_config.command);
127
128            // 正确处理Option<Vec<String>>
129            if let Some(args) = &cmd_config.args {
130                command.args(args);
131            }
132
133            // 正确处理Option<HashMap<String, String>>
134            if let Some(env_vars) = &cmd_config.env {
135                for (key, value) in env_vars {
136                    command.env(key, value);
137                }
138            }
139
140            // 记录命令执行信息,方便调试
141            log_command_details(cmd_config, &mcp_router_path);
142
143            info!(
144                "子进程已启动,MCP ID: {}, 类型: {:?}",
145                mcp_router_path.mcp_id,
146                mcp_type.clone()
147            );
148
149            // 创建子进程传输并创建客户端服务
150            let tokio_process = TokioChildProcess::new(command)?;
151            client_info.serve(tokio_process).await?
152        }
153        McpServerConfig::Url(url_config) => {
154            // 根据后端协议类型创建不同的客户端传输
155            info!(
156                "连接到远程MCP服务: {}, 后端协议: {:?}, 客户端协议: {:?}",
157                url_config.get_url(),
158                backend_protocol,
159                mcp_router_path.mcp_protocol
160            );
161
162            match backend_protocol {
163                McpProtocol::Stdio => {
164                    // URL 配置不应该出现 Stdio 协议
165                    return Err(anyhow::anyhow!("URL 配置的 MCP 服务不能使用 Stdio 协议"));
166                }
167                McpProtocol::Sse => {
168                    // SSE 协议 - 创建 SSE 客户端传输
169                    info!("使用SSE协议连接到: {}", url_config.get_url());
170
171                    // 创建带有自定义 headers 的 reqwest client
172                    let mut headers = reqwest::header::HeaderMap::new();
173
174                    // 添加配置中的自定义 headers
175                    if let Some(config_headers) = &url_config.headers {
176                        for (key, value) in config_headers {
177                            // SSE 协议:直接添加所有 headers(不跳过 Authorization)
178                            // 原因:SSE 协议没有官方的 auth_header 字段配置
179                            headers.insert(
180                                reqwest::header::HeaderName::try_from(key).map_err(|e| {
181                                    anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
182                                })?,
183                                value.parse().map_err(|e| {
184                                    anyhow::anyhow!(
185                                        "Invalid header value for {}: {}, error: {}",
186                                        key,
187                                        value,
188                                        e
189                                    )
190                                })?,
191                            );
192                        }
193                        info!(
194                            "添加了 {} 个自定义 headers(包含 Authorization)",
195                            headers.len()
196                        );
197                    } else {
198                        info!("没有配置自定义 headers");
199                    }
200
201                    let client = reqwest::Client::builder()
202                        .default_headers(headers)
203                        .build()
204                        .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
205
206                    // 创建 SSE 客户端配置
207                    let sse_config = rmcp::transport::sse_client::SseClientConfig {
208                        sse_endpoint: url_config.get_url().to_string().into(),
209                        ..Default::default()
210                    };
211
212                    let sse_transport =
213                        SseClientTransport::start_with_client(client, sse_config).await?;
214                    client_info.serve(sse_transport).await?
215                }
216                McpProtocol::Stream => {
217                    // Streamable 协议 - 创建 Streamable HTTP 客户端传输
218                    info!("使用Streamable HTTP协议连接到: {}", url_config.get_url());
219
220                    // 创建自定义 client 和配置(支持 Authorization header)
221                    let mut headers = reqwest::header::HeaderMap::new();
222
223                    // 添加配置中的自定义 headers(排除 Authorization)
224                    if let Some(config_headers) = &url_config.headers {
225                        for (key, value) in config_headers {
226                            // 跳过 Authorization header,它会通过 auth_header 配置字段传递
227                            if key.eq_ignore_ascii_case("Authorization") {
228                                continue;
229                            }
230                            headers.insert(
231                                reqwest::header::HeaderName::try_from(key).map_err(|e| {
232                                    anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
233                                })?,
234                                value.parse().map_err(|e| {
235                                    anyhow::anyhow!(
236                                        "Invalid header value for {}: {}, error: {}",
237                                        key,
238                                        value,
239                                        e
240                                    )
241                                })?,
242                            );
243                        }
244                        info!("添加了 {} 个自定义 headers", headers.len());
245                    } else {
246                        info!("没有配置自定义 headers");
247                    }
248
249                    let client = reqwest::Client::builder()
250                        .default_headers(headers)
251                        .build()
252                        .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
253
254                    // 提取 Authorization header 用于配置(不区分大小写)
255                    let auth_header = url_config.headers.as_ref().and_then(|h| {
256                        // HTTP header 名称不区分大小写,查找 Authorization
257                        h.iter()
258                            .find_map(|(k, v)| {
259                                if k.eq_ignore_ascii_case("Authorization") {
260                                    Some(v)
261                                } else {
262                                    None
263                                }
264                            })
265                            .map(|s| s.strip_prefix("Bearer ").unwrap_or(s).to_string())
266                    });
267
268                    // 创建传输配置
269                    let config = rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
270                        uri: url_config.get_url().to_string().into(),
271                        auth_header,
272                        ..Default::default()
273                    };
274
275                    let transport = StreamableHttpClientTransport::with_client(client, config);
276
277                    info!(
278                        "Streamable HTTP传输已创建,开始建立连接,MCP ID: {}, 类型: {:?}",
279                        mcp_router_path.mcp_id,
280                        mcp_type.clone()
281                    );
282
283                    // serve 会建立连接并完成初始化握手
284                    let client = client_info.serve(transport).await?;
285
286                    info!(
287                        "Streamable HTTP客户端连接成功,MCP ID: {}",
288                        mcp_router_path.mcp_id
289                    );
290
291                    client
292                }
293            }
294        }
295    };
296
297    // 创建代理处理器
298    let proxy_handler = ProxyHandler::with_mcp_id(client, mcp_id.clone());
299
300    // 获取全局 ProxyHandlerManager
301    let proxy_manager = get_proxy_manager();
302
303    // 注册代理处理器(ProxyHandler 内部已使用 Arc,clone 非常轻量)
304    let proxy_handler_clone = proxy_handler.clone();
305
306    // 根据客户端协议和后端协议创建服务器(支持协议转换)
307    // 支持三种模式:
308    // 根据客户端协议(主导)创建路由:决定对外暴露的 API 接口类型
309    let (router, ct) = match mcp_router_path.mcp_protocol.clone() {
310        // ================ 客户端使用 SSE 协议 ================
311        McpProtocol::Sse => {
312            // 对外提供 SSE 接口
313            // 协议转换由 proxy_handler_clone 自动处理
314            let addr: String = "0.0.0.0:0".to_string();
315            let sse_path = match &mcp_router_path.mcp_protocol_path {
316                McpProtocolPath::SsePath(sse_path) => sse_path,
317                _ => unreachable!(),
318            };
319            let config = SseServerConfig {
320                bind: addr.parse()?,
321                sse_path: sse_path.sse_path.clone(),
322                post_path: sse_path.message_path.clone(),
323                ct: tokio_util::sync::CancellationToken::new(),
324                sse_keep_alive: None,
325            };
326
327            debug!(
328                "创建SSE服务器,配置: bind={}, sse_path={}, post_path={}",
329                config.bind, config.sse_path, config.post_path
330            );
331
332            let (sse_server, router) = SseServer::new(config);
333            let ct = sse_server.with_service(move || proxy_handler_clone.clone());
334            (router, ct)
335        }
336
337        // ================ 客户端使用 Streamable HTTP 协议 ================
338        McpProtocol::Stream => {
339            // 对外提供 Streamable HTTP 接口
340            // 内部协议转换由 proxy_handler_clone 自动处理
341            let service = StreamableHttpService::new(
342                move || Ok(proxy_handler_clone.clone()),
343                LocalSessionManager::default().into(),
344                Default::default(),
345            );
346            let router = axum::Router::new().fallback_service(service);
347            let ct = tokio_util::sync::CancellationToken::new();
348            (router, ct)
349        }
350
351        // 不应该出现的情况
352        McpProtocol::Stdio => {
353            return Err(anyhow::anyhow!(
354                "客户端协议不能是 Stdio。McpRouterPath::new 不支持创建 Stdio 协议的路由路径"
355            ));
356        }
357    };
358
359    // 克隆一份取消令牌和 mcp_id 用于监控子进程
360    let ct_clone = ct.clone();
361    let mcp_id_clone = mcp_id.clone();
362
363    // 存储 MCP 服务状态
364    let mcp_service_status = McpServiceStatus::new(
365        mcp_id_clone.clone(),
366        mcp_type.clone(),
367        mcp_router_path.clone(),
368        ct_clone.clone(),
369        CheckMcpStatusResponseStatus::Ready,
370    );
371    // 添加 MCP 服务状态到全局管理器,以及 proxy_handler 的透明代理
372    proxy_manager.add_mcp_service_status_and_proxy(mcp_service_status, Some(proxy_handler));
373
374    // 为SSE和Stream协议添加基础路径处理
375    // 支持直接访问基础路径,自动重定向到正确的子路径
376    let router = if matches!(mcp_router_path.mcp_protocol, McpProtocol::Sse) {
377        // 使用fallback处理器来匹配基础路径
378        let modified_router = router.fallback(base_path_fallback_handler);
379        info!("SSE基础路径处理器已添加, 基础路径: {}", base_path);
380        modified_router
381    } else {
382        router
383    };
384
385    // 注册路由到全局路由表
386    info!("注册路由: base_path={}, mcp_id={}", base_path, mcp_id);
387    info!(
388        "SSE路径配置: sse_path={}, post_path={}",
389        match &mcp_router_path.mcp_protocol_path {
390            McpProtocolPath::SsePath(sse_path) => &sse_path.sse_path,
391            _ => "N/A",
392        },
393        match &mcp_router_path.mcp_protocol_path {
394            McpProtocolPath::SsePath(sse_path) => &sse_path.message_path,
395            _ => "N/A",
396        }
397    );
398    DynamicRouterService::register_route(&base_path, router.clone());
399    info!("路由注册完成: base_path={}", base_path);
400
401    // 返回路由和取消令牌
402    Ok((router, ct))
403}
404
405// 基础路径处理器 - 支持直接访问基础路径,自动重定向到正确的子路径
406#[axum::debug_handler]
407async fn base_path_fallback_handler(
408    method: axum::http::Method,
409    uri: axum::http::Uri,
410    headers: axum::http::HeaderMap,
411) -> impl axum::response::IntoResponse {
412    let path = uri.path();
413    info!("基础路径处理器: {} {}", method, path);
414
415    // 判断是SSE还是Stream协议
416    if path.contains("/sse/proxy/") {
417        // SSE协议处理
418        match method {
419            axum::http::Method::GET => {
420                // 从路径中提取 MCP ID
421                let mcp_id = path.split("/sse/proxy/").nth(1);
422
423                if let Some(mcp_id) = mcp_id {
424                    // 检查MCP服务是否存在
425                    let proxy_manager = get_proxy_manager();
426                    if proxy_manager.get_mcp_service_status(mcp_id).is_none() {
427                        // MCP服务不存在
428                        (
429                            axum::http::StatusCode::NOT_FOUND,
430                            [("Content-Type", "text/plain".to_string())],
431                            format!("MCP service '{}' not found", mcp_id).to_string(),
432                        )
433                    } else {
434                        // MCP服务存在,检查Accept头
435                        let accept_header = headers.get("accept");
436                        if let Some(accept) = accept_header {
437                            let accept_str = accept.to_str().unwrap_or("");
438                            if accept_str.contains("text/event-stream") {
439                                // 正确的Accept头,重定向到 /sse
440                                let redirect_uri = format!("{}/sse", path);
441                                info!("SSE重定向到: {}", redirect_uri);
442                                (
443                                    axum::http::StatusCode::FOUND,
444                                    [("Location", redirect_uri.to_string())],
445                                    "Redirecting to SSE endpoint".to_string(),
446                                )
447                            } else {
448                                // Accept头不正确
449                                (
450                                    axum::http::StatusCode::BAD_REQUEST,
451                                    [("Content-Type", "text/plain".to_string())],
452                                    "SSE error: Invalid Accept header, expected 'text/event-stream'".to_string(),
453                                )
454                            }
455                        } else {
456                            // 没有Accept头
457                            (
458                                axum::http::StatusCode::BAD_REQUEST,
459                                [("Content-Type", "text/plain".to_string())],
460                                "SSE error: Missing Accept header, expected 'text/event-stream'"
461                                    .to_string(),
462                            )
463                        }
464                    }
465                } else {
466                    // 无法从路径中提取MCP ID
467                    (
468                        axum::http::StatusCode::BAD_REQUEST,
469                        [("Content-Type", "text/plain".to_string())],
470                        "SSE error: Invalid SSE path".to_string(),
471                    )
472                }
473            }
474            axum::http::Method::POST => {
475                // POST请求重定向到 /message
476                let redirect_uri = format!("{}/message", path);
477                info!("SSE重定向到: {}", redirect_uri);
478                (
479                    axum::http::StatusCode::FOUND,
480                    [("Location", redirect_uri.to_string())],
481                    "Redirecting to message endpoint".to_string(),
482                )
483            }
484            _ => {
485                // 其他方法返回405 Method Not Allowed
486                (
487                    axum::http::StatusCode::METHOD_NOT_ALLOWED,
488                    [("Allow", "GET, POST".to_string())],
489                    "Only GET and POST methods are allowed".to_string(),
490                )
491            }
492        }
493    } else if path.contains("/stream/proxy/") {
494        // Stream协议处理 - 直接返回成功,不重定向
495        match method {
496            axum::http::Method::GET => {
497                // GET请求返回服务器信息
498                (
499                    axum::http::StatusCode::OK,
500                    [("Content-Type", "application/json".to_string())],
501                    r#"{"jsonrpc":"2.0","result":{"info":"Streamable MCP Server","version":"1.0"}}"#.to_string(),
502                )
503            }
504            axum::http::Method::POST => {
505                // POST请求返回成功,让StreamableHttpService处理
506                (
507                    axum::http::StatusCode::OK,
508                    [("Content-Type", "application/json".to_string())],
509                    r#"{"jsonrpc":"2.0","result":{"message":"Stream request received","protocol":"streamable-http"}}"#.to_string(),
510                )
511            }
512            _ => {
513                // 其他方法返回405 Method Not Allowed
514                (
515                    axum::http::StatusCode::METHOD_NOT_ALLOWED,
516                    [("Allow", "GET, POST".to_string())],
517                    "Only GET and POST methods are allowed".to_string(),
518                )
519            }
520        }
521    } else {
522        // 未知协议
523        (
524            axum::http::StatusCode::BAD_REQUEST,
525            [("Content-Type", "text/plain".to_string())],
526            "Unknown protocol or path".to_string(),
527        )
528    }
529}
530
531// 提取记录命令详情的函数
532fn log_command_details(mcp_config: &McpServerCommandConfig, mcp_router_path: &McpRouterPath) {
533    // 打印命令行参数
534    let args_str = mcp_config
535        .args
536        .as_ref()
537        .map_or(String::new(), |args| args.join(" "));
538    let cmd_str = format!("执行命令: {} {}", mcp_config.command, args_str);
539    debug!("{cmd_str}");
540
541    // 打印环境变量
542    if let Some(env_vars) = &mcp_config.env {
543        let env_vars: Vec<String> = env_vars.iter().map(|(k, v)| format!("{k}={v}")).collect();
544        if !env_vars.is_empty() {
545            debug!("环境变量: {}", env_vars.join(", "));
546        }
547    }
548
549    // 打印完整命令
550    debug!(
551        "完整命令,mcpId={}, command={:?}",
552        mcp_router_path.mcp_id, mcp_config.command
553    );
554
555    // 构建完整的命令字符串,用于直接复制运行
556    let args_str = mcp_config
557        .args
558        .as_ref()
559        .map_or(String::new(), |args| args.join(" "));
560    let env_str = mcp_config.env.as_ref().map_or(String::new(), |env| {
561        env.iter()
562            .map(|(k, v)| format!("{k}={v}"))
563            .collect::<Vec<String>>()
564            .join(" ")
565    });
566
567    let full_command = format!("{} {} {}", mcp_config.command, args_str, env_str);
568    info!(
569        "完整命令字符串,mcpId={},command={:?}",
570        mcp_router_path.mcp_id, full_command
571    );
572}