mcp_stdio_proxy/server/task/
mcp_start_task.rs

1use crate::{
2    AppError, DynamicRouterService, ProxyHandler, get_proxy_manager,
3    model::{
4        CheckMcpStatusResponseStatus, McpConfig, McpProtocol, McpProtocolPath, McpRouterPath,
5        McpServerCommandConfig, McpServerConfig, McpServiceStatus, McpType,
6    },
7};
8
9use anyhow::Result;
10use log::{debug, info};
11use rmcp::{
12    ServiceExt,
13    model::{ClientCapabilities, ClientInfo},
14    transport::streamable_http_server::{
15        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16    },
17    transport::{
18        TokioChildProcess,
19        streamable_http_client::{
20            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
21        },
22    },
23};
24use tokio::process::Command;
25
26///根据 mcp_id 和 mcp_json_config 启动mcp服务
27pub async fn mcp_start_task(
28    mcp_config: McpConfig,
29) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
30    let mcp_id = mcp_config.mcp_id.clone();
31    let client_protocol = mcp_config.client_protocol.clone();
32
33    // 使用客户端协议创建路由路径(决定暴露的API接口)
34    let mcp_router_path: McpRouterPath =
35        McpRouterPath::new(mcp_id, client_protocol).map_err(|e| AppError::McpServerError(e))?;
36
37    let mcp_json_config = mcp_config
38        .mcp_json_config
39        .clone()
40        .expect("mcp_json_config is required");
41
42    let mcp_server_config = McpServerConfig::try_from(mcp_json_config)?;
43
44    // 使用新的集成方法,后端协议在函数内部解析
45    integrate_sse_server_with_axum(
46        mcp_server_config.clone(),
47        mcp_router_path.clone(),
48        mcp_config.mcp_type,
49    )
50    .await
51}
52
53// 创建一个新函数,将 SseServer 与 axum 路由集成
54pub async fn integrate_sse_server_with_axum(
55    mcp_config: McpServerConfig,
56    mcp_router_path: McpRouterPath,
57    mcp_type: McpType,
58) -> Result<(axum::Router, tokio_util::sync::CancellationToken)> {
59    let base_path = mcp_router_path.base_path.clone();
60    let mcp_id = mcp_router_path.mcp_id.clone();
61
62    // 根据MCP服务器配置解析后端协议
63    let backend_protocol = match &mcp_config {
64        // 命令行配置:使用 stdio 协议
65        McpServerConfig::Command(_) => McpProtocol::Stdio,
66        // URL配置:解析 type 字段或自动检测
67        McpServerConfig::Url(url_config) => {
68            // 首先检查 type 字段
69            if let Some(type_str) = &url_config.r#type {
70                // 尝试解析 type 字段
71                match type_str.parse::<McpProtocol>() {
72                    Ok(protocol) => {
73                        debug!("使用配置中指定的协议类型: {} -> {:?}", type_str, protocol);
74                        protocol
75                    }
76                    Err(_) => {
77                        // 如果解析失败,自动检测协议
78                        debug!("协议类型 '{}' 无法识别,开始自动检测协议", type_str);
79                        let detected_protocol =
80                            crate::server::detect_mcp_protocol(url_config.get_url())
81                                .await
82                                .map_err(|e| {
83                                    anyhow::anyhow!(
84                                        "协议类型 '{}' 不可识别,且自动检测失败: {}",
85                                        type_str,
86                                        e
87                                    )
88                                })?;
89                        debug!(
90                            "自动检测到协议类型: {:?}(原始配置: '{}')",
91                            detected_protocol, type_str
92                        );
93                        detected_protocol
94                    }
95                }
96            } else {
97                // 没有指定 type 字段,自动检测协议
98                debug!("未指定 type 字段,自动检测协议");
99                let detected_protocol = crate::server::detect_mcp_protocol(url_config.get_url())
100                    .await
101                    .map_err(|e| anyhow::anyhow!("自动检测协议失败: {}", e))?;
102                detected_protocol
103            }
104        }
105    };
106
107    debug!(
108        "MCP ID: {}, 客户端协议: {:?}, 后端协议: {:?}",
109        mcp_id, mcp_router_path.mcp_protocol, backend_protocol
110    );
111
112    // 创建客户端信息
113    let client_info = ClientInfo {
114        protocol_version: Default::default(),
115        capabilities: ClientCapabilities::builder()
116            .enable_experimental()
117            .enable_roots()
118            .enable_roots_list_changed()
119            .enable_sampling()
120            .build(),
121        ..Default::default()
122    };
123
124    // 根据配置类型创建不同的客户端服务
125    let client = match &mcp_config {
126        McpServerConfig::Command(cmd_config) => {
127            // 创建子进程命令
128            let mut command = Command::new(&cmd_config.command);
129
130            // 正确处理Option<Vec<String>>
131            if let Some(args) = &cmd_config.args {
132                command.args(args);
133            }
134
135            // 正确处理Option<HashMap<String, String>>
136            if let Some(env_vars) = &cmd_config.env {
137                for (key, value) in env_vars {
138                    command.env(key, value);
139                }
140            }
141
142            // 记录命令执行信息,方便调试
143            log_command_details(cmd_config, &mcp_router_path);
144
145            info!(
146                "子进程已启动,MCP ID: {}, 类型: {:?}",
147                mcp_router_path.mcp_id,
148                mcp_type.clone()
149            );
150
151            // 创建子进程传输并创建客户端服务
152            let tokio_process = TokioChildProcess::new(command)?;
153            client_info.serve(tokio_process).await?
154        }
155        McpServerConfig::Url(url_config) => {
156            // 根据后端协议类型创建不同的客户端传输
157            info!(
158                "连接到远程MCP服务: {}, 后端协议: {:?}, 客户端协议: {:?}",
159                url_config.get_url(),
160                backend_protocol,
161                mcp_router_path.mcp_protocol
162            );
163
164            match backend_protocol {
165                McpProtocol::Stdio => {
166                    // URL 配置不应该出现 Stdio 协议
167                    return Err(anyhow::anyhow!("URL 配置的 MCP 服务不能使用 Stdio 协议"));
168                }
169                McpProtocol::Sse => {
170                    // SSE 协议 - rmcp 0.12 移除了 SseClientTransport,使用 StreamableHttpClientTransport
171                    info!(
172                        "使用Streamable HTTP协议连接到(SSE兼容模式): {}",
173                        url_config.get_url()
174                    );
175
176                    // 创建带有自定义 headers 的 reqwest client
177                    let mut headers = reqwest::header::HeaderMap::new();
178
179                    // 添加配置中的自定义 headers
180                    if let Some(config_headers) = &url_config.headers {
181                        for (key, value) in config_headers {
182                            headers.insert(
183                                reqwest::header::HeaderName::try_from(key).map_err(|e| {
184                                    anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
185                                })?,
186                                value.parse().map_err(|e| {
187                                    anyhow::anyhow!(
188                                        "Invalid header value for {}: {}, error: {}",
189                                        key,
190                                        value,
191                                        e
192                                    )
193                                })?,
194                            );
195                        }
196                        info!("添加了 {} 个自定义 headers", headers.len());
197                    } else {
198                        info!("没有配置自定义 headers");
199                    }
200
201                    let client = reqwest::Client::builder()
202                        .default_headers(headers)
203                        .build()
204                        .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
205
206                    // 创建 Streamable HTTP 客户端配置
207                    let config = StreamableHttpClientTransportConfig {
208                        uri: url_config.get_url().to_string().into(),
209                        ..Default::default()
210                    };
211
212                    let transport = StreamableHttpClientTransport::with_client(client, config);
213                    client_info.serve(transport).await?
214                }
215                McpProtocol::Stream => {
216                    // Streamable 协议 - 创建 Streamable HTTP 客户端传输
217                    info!("使用Streamable HTTP协议连接到: {}", url_config.get_url());
218
219                    // 创建自定义 client 和配置(支持 Authorization header)
220                    let mut headers = reqwest::header::HeaderMap::new();
221
222                    // 添加配置中的自定义 headers(排除 Authorization)
223                    if let Some(config_headers) = &url_config.headers {
224                        for (key, value) in config_headers {
225                            // 跳过 Authorization header,它会通过 auth_header 配置字段传递
226                            if key.eq_ignore_ascii_case("Authorization") {
227                                continue;
228                            }
229                            headers.insert(
230                                reqwest::header::HeaderName::try_from(key).map_err(|e| {
231                                    anyhow::anyhow!("Invalid header name: {}, error: {}", key, e)
232                                })?,
233                                value.parse().map_err(|e| {
234                                    anyhow::anyhow!(
235                                        "Invalid header value for {}: {}, error: {}",
236                                        key,
237                                        value,
238                                        e
239                                    )
240                                })?,
241                            );
242                        }
243                        info!("添加了 {} 个自定义 headers", headers.len());
244                    } else {
245                        info!("没有配置自定义 headers");
246                    }
247
248                    let client = reqwest::Client::builder()
249                        .default_headers(headers)
250                        .build()
251                        .map_err(|e| anyhow::anyhow!("创建 reqwest client 失败: {}", e))?;
252
253                    // 提取 Authorization header 用于配置(不区分大小写)
254                    let auth_header = url_config.headers.as_ref().and_then(|h| {
255                        // HTTP header 名称不区分大小写,查找 Authorization
256                        h.iter()
257                            .find_map(|(k, v)| {
258                                if k.eq_ignore_ascii_case("Authorization") {
259                                    Some(v)
260                                } else {
261                                    None
262                                }
263                            })
264                            .map(|s| s.strip_prefix("Bearer ").unwrap_or(s).to_string())
265                    });
266
267                    // 创建传输配置
268                    let config = rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
269                        uri: url_config.get_url().to_string().into(),
270                        auth_header,
271                        ..Default::default()
272                    };
273
274                    let transport = StreamableHttpClientTransport::with_client(client, config);
275
276                    info!(
277                        "Streamable HTTP传输已创建,开始建立连接,MCP ID: {}, 类型: {:?}",
278                        mcp_router_path.mcp_id,
279                        mcp_type.clone()
280                    );
281
282                    // serve 会建立连接并完成初始化握手
283                    let client = client_info.serve(transport).await?;
284
285                    info!(
286                        "Streamable HTTP客户端连接成功,MCP ID: {}",
287                        mcp_router_path.mcp_id
288                    );
289
290                    client
291                }
292            }
293        }
294    };
295
296    // 创建代理处理器
297    let proxy_handler = ProxyHandler::with_mcp_id(client, mcp_id.clone());
298
299    // 获取全局 ProxyHandlerManager
300    let proxy_manager = get_proxy_manager();
301
302    // 注册代理处理器(ProxyHandler 内部已使用 Arc,clone 非常轻量)
303    let proxy_handler_clone = proxy_handler.clone();
304
305    // 根据客户端协议和后端协议创建服务器(支持协议转换)
306    // 支持三种模式:
307    // 根据客户端协议(主导)创建路由:决定对外暴露的 API 接口类型
308    let (router, ct) = match mcp_router_path.mcp_protocol.clone() {
309        // ================ 客户端使用 SSE 协议 ================
310        // rmcp 0.12 移除了 SseServer,使用 StreamableHttpService 替代
311        McpProtocol::Sse => {
312            // 对外提供 Streamable HTTP 接口(SSE兼容模式)
313            // 协议转换由 proxy_handler_clone 自动处理
314            debug!(
315                "创建Streamable HTTP服务器(SSE兼容模式), mcp_id={}",
316                mcp_router_path.mcp_id
317            );
318
319            let ct = tokio_util::sync::CancellationToken::new();
320            let service = StreamableHttpService::new(
321                move || Ok(proxy_handler_clone.clone()),
322                LocalSessionManager::default().into(),
323                StreamableHttpServerConfig {
324                    cancellation_token: ct.clone(),
325                    ..Default::default()
326                },
327            );
328            let router = axum::Router::new().fallback_service(service);
329            (router, ct)
330        }
331
332        // ================ 客户端使用 Streamable HTTP 协议 ================
333        McpProtocol::Stream => {
334            // 对外提供 Streamable HTTP 接口
335            // 内部协议转换由 proxy_handler_clone 自动处理
336            let service = StreamableHttpService::new(
337                move || Ok(proxy_handler_clone.clone()),
338                LocalSessionManager::default().into(),
339                Default::default(),
340            );
341            let router = axum::Router::new().fallback_service(service);
342            let ct = tokio_util::sync::CancellationToken::new();
343            (router, ct)
344        }
345
346        // 不应该出现的情况
347        McpProtocol::Stdio => {
348            return Err(anyhow::anyhow!(
349                "客户端协议不能是 Stdio。McpRouterPath::new 不支持创建 Stdio 协议的路由路径"
350            ));
351        }
352    };
353
354    // 克隆一份取消令牌和 mcp_id 用于监控子进程
355    let ct_clone = ct.clone();
356    let mcp_id_clone = mcp_id.clone();
357
358    // 存储 MCP 服务状态
359    let mcp_service_status = McpServiceStatus::new(
360        mcp_id_clone.clone(),
361        mcp_type.clone(),
362        mcp_router_path.clone(),
363        ct_clone.clone(),
364        CheckMcpStatusResponseStatus::Ready,
365    );
366    // 添加 MCP 服务状态到全局管理器,以及 proxy_handler 的透明代理
367    proxy_manager.add_mcp_service_status_and_proxy(mcp_service_status, Some(proxy_handler));
368
369    // 为SSE和Stream协议添加基础路径处理
370    // 支持直接访问基础路径,自动重定向到正确的子路径
371    let router = if matches!(mcp_router_path.mcp_protocol, McpProtocol::Sse) {
372        // 使用fallback处理器来匹配基础路径
373        let modified_router = router.fallback(base_path_fallback_handler);
374        info!("SSE基础路径处理器已添加, 基础路径: {}", base_path);
375        modified_router
376    } else {
377        router
378    };
379
380    // 注册路由到全局路由表
381    info!("注册路由: base_path={}, mcp_id={}", base_path, mcp_id);
382    info!(
383        "SSE路径配置: sse_path={}, post_path={}",
384        match &mcp_router_path.mcp_protocol_path {
385            McpProtocolPath::SsePath(sse_path) => &sse_path.sse_path,
386            _ => "N/A",
387        },
388        match &mcp_router_path.mcp_protocol_path {
389            McpProtocolPath::SsePath(sse_path) => &sse_path.message_path,
390            _ => "N/A",
391        }
392    );
393    DynamicRouterService::register_route(&base_path, router.clone());
394    info!("路由注册完成: base_path={}", base_path);
395
396    // 返回路由和取消令牌
397    Ok((router, ct))
398}
399
400// 基础路径处理器 - 支持直接访问基础路径,自动重定向到正确的子路径
401#[axum::debug_handler]
402async fn base_path_fallback_handler(
403    method: axum::http::Method,
404    uri: axum::http::Uri,
405    headers: axum::http::HeaderMap,
406) -> impl axum::response::IntoResponse {
407    let path = uri.path();
408    info!("基础路径处理器: {} {}", method, path);
409
410    // 判断是SSE还是Stream协议
411    if path.contains("/sse/proxy/") {
412        // SSE协议处理
413        match method {
414            axum::http::Method::GET => {
415                // 从路径中提取 MCP ID
416                let mcp_id = path.split("/sse/proxy/").nth(1);
417
418                if let Some(mcp_id) = mcp_id {
419                    // 检查MCP服务是否存在
420                    let proxy_manager = get_proxy_manager();
421                    if proxy_manager.get_mcp_service_status(mcp_id).is_none() {
422                        // MCP服务不存在
423                        (
424                            axum::http::StatusCode::NOT_FOUND,
425                            [("Content-Type", "text/plain".to_string())],
426                            format!("MCP service '{}' not found", mcp_id).to_string(),
427                        )
428                    } else {
429                        // MCP服务存在,检查Accept头
430                        let accept_header = headers.get("accept");
431                        if let Some(accept) = accept_header {
432                            let accept_str = accept.to_str().unwrap_or("");
433                            if accept_str.contains("text/event-stream") {
434                                // 正确的Accept头,重定向到 /sse
435                                let redirect_uri = format!("{}/sse", path);
436                                info!("SSE重定向到: {}", redirect_uri);
437                                (
438                                    axum::http::StatusCode::FOUND,
439                                    [("Location", redirect_uri.to_string())],
440                                    "Redirecting to SSE endpoint".to_string(),
441                                )
442                            } else {
443                                // Accept头不正确
444                                (
445                                    axum::http::StatusCode::BAD_REQUEST,
446                                    [("Content-Type", "text/plain".to_string())],
447                                    "SSE error: Invalid Accept header, expected 'text/event-stream'".to_string(),
448                                )
449                            }
450                        } else {
451                            // 没有Accept头
452                            (
453                                axum::http::StatusCode::BAD_REQUEST,
454                                [("Content-Type", "text/plain".to_string())],
455                                "SSE error: Missing Accept header, expected 'text/event-stream'"
456                                    .to_string(),
457                            )
458                        }
459                    }
460                } else {
461                    // 无法从路径中提取MCP ID
462                    (
463                        axum::http::StatusCode::BAD_REQUEST,
464                        [("Content-Type", "text/plain".to_string())],
465                        "SSE error: Invalid SSE path".to_string(),
466                    )
467                }
468            }
469            axum::http::Method::POST => {
470                // POST请求重定向到 /message
471                let redirect_uri = format!("{}/message", path);
472                info!("SSE重定向到: {}", redirect_uri);
473                (
474                    axum::http::StatusCode::FOUND,
475                    [("Location", redirect_uri.to_string())],
476                    "Redirecting to message endpoint".to_string(),
477                )
478            }
479            _ => {
480                // 其他方法返回405 Method Not Allowed
481                (
482                    axum::http::StatusCode::METHOD_NOT_ALLOWED,
483                    [("Allow", "GET, POST".to_string())],
484                    "Only GET and POST methods are allowed".to_string(),
485                )
486            }
487        }
488    } else if path.contains("/stream/proxy/") {
489        // Stream协议处理 - 直接返回成功,不重定向
490        match method {
491            axum::http::Method::GET => {
492                // GET请求返回服务器信息
493                (
494                    axum::http::StatusCode::OK,
495                    [("Content-Type", "application/json".to_string())],
496                    r#"{"jsonrpc":"2.0","result":{"info":"Streamable MCP Server","version":"1.0"}}"#.to_string(),
497                )
498            }
499            axum::http::Method::POST => {
500                // POST请求返回成功,让StreamableHttpService处理
501                (
502                    axum::http::StatusCode::OK,
503                    [("Content-Type", "application/json".to_string())],
504                    r#"{"jsonrpc":"2.0","result":{"message":"Stream request received","protocol":"streamable-http"}}"#.to_string(),
505                )
506            }
507            _ => {
508                // 其他方法返回405 Method Not Allowed
509                (
510                    axum::http::StatusCode::METHOD_NOT_ALLOWED,
511                    [("Allow", "GET, POST".to_string())],
512                    "Only GET and POST methods are allowed".to_string(),
513                )
514            }
515        }
516    } else {
517        // 未知协议
518        (
519            axum::http::StatusCode::BAD_REQUEST,
520            [("Content-Type", "text/plain".to_string())],
521            "Unknown protocol or path".to_string(),
522        )
523    }
524}
525
526// 提取记录命令详情的函数
527fn log_command_details(mcp_config: &McpServerCommandConfig, mcp_router_path: &McpRouterPath) {
528    // 打印命令行参数
529    let args_str = mcp_config
530        .args
531        .as_ref()
532        .map_or(String::new(), |args| args.join(" "));
533    let cmd_str = format!("执行命令: {} {}", mcp_config.command, args_str);
534    debug!("{cmd_str}");
535
536    // 打印环境变量
537    if let Some(env_vars) = &mcp_config.env {
538        let env_vars: Vec<String> = env_vars.iter().map(|(k, v)| format!("{k}={v}")).collect();
539        if !env_vars.is_empty() {
540            debug!("环境变量: {}", env_vars.join(", "));
541        }
542    }
543
544    // 打印完整命令
545    debug!(
546        "完整命令,mcpId={}, command={:?}",
547        mcp_router_path.mcp_id, mcp_config.command
548    );
549
550    // 构建完整的命令字符串,用于直接复制运行
551    let args_str = mcp_config
552        .args
553        .as_ref()
554        .map_or(String::new(), |args| args.join(" "));
555    let env_str = mcp_config.env.as_ref().map_or(String::new(), |env| {
556        env.iter()
557            .map(|(k, v)| format!("{k}={v}"))
558            .collect::<Vec<String>>()
559            .join(" ")
560    });
561
562    let full_command = format!("{} {} {}", mcp_config.command, args_str, env_str);
563    info!(
564        "完整命令字符串,mcpId={},command={:?}",
565        mcp_router_path.mcp_id, full_command
566    );
567}