mcp_stdio_proxy/client/
cli.rs

1// MCP-Proxy CLI 实现
2// 使用库提供的高层 API,分支处理 SSE 和 Stream 协议
3
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use clap::Parser;
9use anyhow::{Result, bail};
10use serde::Deserialize;
11use tokio::process::Command;
12use tracing::error;
13
14// 使用各自库的高层 API(从 proxy/mod.rs 导入)
15use crate::proxy::{
16    ProxyHandler, ToolFilter,
17    McpClientConfig, SseClientConnection, StreamClientConnection,
18};
19
20// SSE 模式需要的类型(rmcp 0.10)- 用于 command 模式和 SSE stdio 服务
21use mcp_sse_proxy::{
22    ServiceExt as SseServiceExt,
23    TokioChildProcess,
24    stdio as sse_stdio,
25    ClientInfo as SseClientInfo,
26    ClientCapabilities as SseClientCapabilities,
27    Implementation as SseImplementation,
28};
29
30// Stream 模式需要的类型(rmcp 0.12)
31use mcp_streamable_proxy::{
32    ServiceExt as StreamServiceExt,
33    stdio as stream_stdio,
34    ProxyHandler as StreamProxyHandler,
35};
36
37/// MCP-Proxy CLI 主命令结构
38#[derive(Parser, Debug)]
39#[command(name = "mcp-proxy")]
40#[command(version = env!("CARGO_PKG_VERSION"))]
41#[command(about = "MCP 协议转换代理工具", long_about = None)]
42pub struct Cli {
43    #[command(subcommand)]
44    pub command: Option<Commands>,
45    
46    /// 直接URL模式(向后兼容)
47    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
48    pub url: Option<String>,
49    
50    /// 全局详细输出
51    #[arg(short, long, global = true)]
52    pub verbose: bool,
53    
54    /// 全局静默模式
55    #[arg(short, long, global = true)]
56    pub quiet: bool,
57}
58
59#[derive(clap::Subcommand, Debug)]
60pub enum Commands {
61    /// 协议转换模式 - 将 URL 转换为 stdio
62    Convert(ConvertArgs),
63
64    /// 检查服务状态
65    Check(CheckArgs),
66
67    /// 协议检测
68    Detect(DetectArgs),
69
70    /// 代理模式 - 将 stdio MCP 服务代理为 HTTP/SSE 服务
71    Proxy(super::proxy_server::ProxyArgs),
72}
73
74/// 协议转换参数
75#[derive(Parser, Debug, Clone)]
76pub struct ConvertArgs {
77    /// MCP 服务的 URL 地址(可选,与 --config/--config-file 二选一)
78    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
79    pub url: Option<String>,
80
81    /// MCP 服务配置 JSON
82    #[arg(long, conflicts_with = "config_file", help = "MCP 服务配置 JSON")]
83    pub config: Option<String>,
84
85    /// MCP 服务配置文件路径
86    #[arg(long, conflicts_with = "config", help = "MCP 服务配置文件路径")]
87    pub config_file: Option<std::path::PathBuf>,
88
89    /// MCP 服务名称(多服务配置时必需)
90    #[arg(short, long, help = "MCP 服务名称(多服务配置时必需)")]
91    pub name: Option<String>,
92
93    /// 指定远程服务协议类型(不指定则自动检测)
94    #[arg(long, value_enum, help = "指定远程服务协议类型(不指定则自动检测)")]
95    pub protocol: Option<super::proxy_server::ProxyProtocol>,
96
97    /// 认证 header (如: "Bearer token")
98    #[arg(short, long, help = "认证 header")]
99    pub auth: Option<String>,
100
101    /// 自定义 HTTP headers
102    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
103    pub header: Vec<(String, String)>,
104
105    /// 重试次数
106    #[arg(long, default_value = "0", help = "重试次数,0 表示无限重试")]
107    pub retries: u32,
108
109    /// 工具白名单(逗号分隔),只允许指定的工具
110    #[arg(long, value_delimiter = ',', help = "工具白名单(逗号分隔),只允许指定的工具")]
111    pub allow_tools: Option<Vec<String>>,
112
113    /// 工具黑名单(逗号分隔),排除指定的工具
114    #[arg(long, value_delimiter = ',', help = "工具黑名单(逗号分隔),排除指定的工具")]
115    pub deny_tools: Option<Vec<String>>,
116
117    /// 客户端 ping 间隔(秒),0 表示禁用
118    #[arg(long, default_value = "30", help = "客户端 ping 间隔(秒),0 表示禁用")]
119    pub ping_interval: u64,
120
121    /// 客户端 ping 超时(秒)
122    #[arg(long, default_value = "10", help = "客户端 ping 超时(秒),超时则认为连接断开")]
123    pub ping_timeout: u64,
124}
125
126/// 检查参数
127#[derive(Parser, Debug)]
128pub struct CheckArgs {
129    /// 要检查的 MCP 服务 URL
130    #[arg(value_name = "URL")]
131    pub url: String,
132    
133    /// 认证 header
134    #[arg(short, long)]
135    pub auth: Option<String>,
136    
137    /// 超时时间
138    #[arg(long, default_value = "10")]
139    pub timeout: u64,
140}
141
142/// 协议检测参数
143#[derive(Parser, Debug)]
144pub struct DetectArgs {
145    /// 要检测的 MCP 服务 URL
146    #[arg(value_name = "URL")]
147    pub url: String,
148    
149    /// 认证 header
150    #[arg(short, long)]
151    pub auth: Option<String>,
152}
153
154/// 解析 KEY=VALUE 格式的辅助函数
155fn parse_key_val(s: &str) -> Result<(String, String)> {
156    let pos = s.find('=')
157        .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
158    Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
159}
160
161// ============== MCP 配置解析相关 ==============
162
163/// MCP 配置格式
164#[derive(Deserialize, Debug)]
165struct McpConfig {
166    #[serde(rename = "mcpServers")]
167    mcp_servers: HashMap<String, McpServerInnerConfig>,
168}
169
170/// MCP 服务配置(支持 Command 和 Url 两种类型)
171#[derive(Deserialize, Debug, Clone)]
172#[serde(untagged)]
173enum McpServerInnerConfig {
174    Command(StdioConfig),
175    Url(UrlConfig),
176}
177
178/// stdio 配置(本地命令)
179#[derive(Deserialize, Debug, Clone)]
180struct StdioConfig {
181    command: String,
182    args: Option<Vec<String>>,
183    env: Option<HashMap<String, String>>,
184}
185
186/// URL 配置(远程服务)
187#[derive(Deserialize, Debug, Clone)]
188struct UrlConfig {
189    #[serde(skip_serializing_if = "Option::is_none")]
190    url: Option<String>,
191    #[serde(
192        skip_serializing_if = "Option::is_none",
193        default,
194        rename = "baseUrl",
195        alias = "baseurl",
196        alias = "base_url"
197    )]
198    base_url: Option<String>,
199    #[serde(default, rename = "type", alias = "Type")]
200    r#type: Option<String>,
201    pub headers: Option<HashMap<String, String>>,
202    #[serde(default, alias = "authToken", alias = "auth_token")]
203    pub auth_token: Option<String>,
204    pub timeout: Option<u64>,
205}
206
207impl UrlConfig {
208    fn get_url(&self) -> Option<&str> {
209        self.url.as_deref().or(self.base_url.as_deref())
210    }
211}
212
213/// 解析后的配置源
214enum McpConfigSource {
215    /// 直接 URL 模式(命令行参数)
216    DirectUrl {
217        url: String,
218    },
219    /// 远程服务配置(JSON 配置)
220    RemoteService {
221        name: String,
222        url: String,
223        protocol: Option<super::protocol::McpProtocol>,
224        headers: HashMap<String, String>,
225        timeout: Option<u64>,
226    },
227    /// 本地命令配置(JSON 配置)
228    LocalCommand {
229        name: String,
230        command: String,
231        args: Vec<String>,
232        env: HashMap<String, String>,
233    },
234}
235
236/// 解析 convert 命令的配置
237fn parse_convert_config(args: &ConvertArgs) -> Result<McpConfigSource> {
238    // 优先级:url > config > config_file
239    if let Some(ref url) = args.url {
240        return Ok(McpConfigSource::DirectUrl { url: url.clone() });
241    }
242
243    // 读取 JSON 配置
244    let json_str = if let Some(ref config) = args.config {
245        config.clone()
246    } else if let Some(ref path) = args.config_file {
247        std::fs::read_to_string(path)
248            .map_err(|e| anyhow::anyhow!("读取配置文件失败: {}", e))?
249    } else {
250        bail!("必须提供 URL、--config 或 --config-file 参数之一");
251    };
252
253    // 解析 JSON 配置
254    let mcp_config: McpConfig = serde_json::from_str(&json_str)
255        .map_err(|e| anyhow::anyhow!(
256            "配置解析失败: {}。配置必须是标准 MCP 格式,包含 mcpServers 字段",
257            e
258        ))?;
259
260    let servers = mcp_config.mcp_servers;
261
262    if servers.is_empty() {
263        bail!("配置中没有找到任何 MCP 服务");
264    }
265
266    // 选择服务
267    let (name, inner_config) = if servers.len() == 1 {
268        servers.into_iter().next().unwrap()
269    } else if let Some(ref name) = args.name {
270        let config = servers.get(name)
271            .cloned()
272            .ok_or_else(|| anyhow::anyhow!(
273                "服务 '{}' 不存在。可用服务: {:?}",
274                name,
275                servers.keys().collect::<Vec<_>>()
276            ))?;
277        (name.clone(), config)
278    } else {
279        bail!(
280            "配置包含多个服务 {:?},请使用 --name 指定要使用的服务",
281            servers.keys().collect::<Vec<_>>()
282        );
283    };
284
285    // 根据配置类型返回
286    match inner_config {
287        McpServerInnerConfig::Command(stdio) => {
288            Ok(McpConfigSource::LocalCommand {
289                name,
290                command: stdio.command,
291                args: stdio.args.unwrap_or_default(),
292                env: stdio.env.unwrap_or_default(),
293            })
294        }
295        McpServerInnerConfig::Url(url_config) => {
296            let url = url_config.get_url()
297                .ok_or_else(|| anyhow::anyhow!("URL 配置缺少 url 或 baseUrl 字段"))?
298                .to_string();
299
300            // 解析协议类型
301            let protocol = url_config.r#type.as_ref().and_then(|t| {
302                match t.as_str() {
303                    "sse" => Some(super::protocol::McpProtocol::Sse),
304                    "http" | "stream" => Some(super::protocol::McpProtocol::Stream),
305                    _ => None,
306                }
307            });
308
309            // 合并 headers:JSON 配置中的 auth_token -> Authorization
310            let mut headers = url_config.headers.clone().unwrap_or_default();
311            if let Some(auth_token) = &url_config.auth_token {
312                headers.insert("Authorization".to_string(), auth_token.clone());
313            }
314
315            Ok(McpConfigSource::RemoteService {
316                name,
317                url,
318                protocol,
319                headers,
320                timeout: url_config.timeout,
321            })
322        }
323    }
324}
325
326/// 合并 headers:JSON 配置 + 命令行参数(命令行优先)
327fn merge_headers(
328    config_headers: HashMap<String, String>,
329    cli_headers: &[(String, String)],
330    cli_auth: Option<&String>,
331) -> HashMap<String, String> {
332    let mut merged = config_headers;
333
334    // 命令行 -H 参数覆盖配置
335    for (key, value) in cli_headers {
336        merged.insert(key.clone(), value.clone());
337    }
338
339    // 命令行 --auth 参数优先级最高
340    if let Some(auth_value) = cli_auth {
341        merged.insert("Authorization".to_string(), auth_value.clone());
342    }
343
344    merged
345}
346
347/// 运行 CLI 主逻辑
348pub async fn run_cli(cli: Cli) -> Result<()> {
349    match cli.command {
350        Some(Commands::Convert(args)) => {
351            run_convert_command(args, cli.verbose, cli.quiet).await
352        }
353        Some(Commands::Check(args)) => {
354            run_check_command(args, cli.verbose, cli.quiet).await
355        }
356        Some(Commands::Detect(args)) => {
357            run_detect_command(args, cli.verbose, cli.quiet).await
358        }
359        Some(Commands::Proxy(args)) => {
360            super::proxy_server::run_proxy_command(args, cli.verbose, cli.quiet).await
361        }
362        None => {
363            // 直接 URL 模式(向后兼容)
364            if let Some(url) = cli.url {
365                let args = ConvertArgs {
366                    url: Some(url),
367                    config: None,
368                    config_file: None,
369                    name: None,
370                    protocol: None,
371                    auth: None,
372                    header: vec![],
373                    retries: 0,    // 无限重试
374                    allow_tools: None,
375                    deny_tools: None,
376                    ping_interval: 30,  // 默认 30 秒 ping 一次
377                    ping_timeout: 10,   // 默认 10 秒超时
378                };
379                run_convert_command(args, cli.verbose, cli.quiet).await
380            } else {
381                bail!("请提供 URL 或使用子命令")
382            }
383        }
384    }
385}
386
387/// 运行转换命令 - 核心功能
388async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
389    // 检查 --allow-tools 和 --deny-tools 互斥
390    if args.allow_tools.is_some() && args.deny_tools.is_some() {
391        bail!("--allow-tools 和 --deny-tools 不能同时使用,请只选择其中一个");
392    }
393
394    // 创建工具过滤器
395    let tool_filter = if let Some(allow_tools) = args.allow_tools.clone() {
396        ToolFilter::allow(allow_tools)
397    } else if let Some(deny_tools) = args.deny_tools.clone() {
398        ToolFilter::deny(deny_tools)
399    } else {
400        ToolFilter::default()
401    };
402
403    // 解析配置
404    let config_source = parse_convert_config(&args)?;
405
406    // 根据配置源执行不同逻辑
407    match config_source {
408        McpConfigSource::DirectUrl { url } => {
409            // 直接 URL 模式(带自动重连)
410            run_url_mode_with_retry(&args, &url, HashMap::new(), None, tool_filter, verbose, quiet).await
411        }
412        McpConfigSource::RemoteService { name, url, protocol, headers, timeout } => {
413            // 远程服务配置模式
414            if !quiet {
415                eprintln!("🚀 MCP-Stdio-Proxy: {} ({}) → stdio", name, url);
416            }
417            // 合并 headers:配置 + 命令行
418            let merged_headers = merge_headers(headers, &args.header, args.auth.as_ref());
419            run_url_mode_with_retry(&args, &url, merged_headers, protocol.or(timeout.map(|_| super::protocol::McpProtocol::Stream)), tool_filter, verbose, quiet).await
420        }
421        McpConfigSource::LocalCommand { name, command, args: cmd_args, env } => {
422            // 本地命令模式(使用 SSE 库的 rmcp 0.10)
423            run_command_mode(&name, &command, cmd_args, env, tool_filter, verbose, quiet).await
424        }
425    }
426}
427
428/// URL 模式执行(带自动重连)
429/// 使用分支逻辑:根据协议类型调用不同的处理函数
430async fn run_url_mode_with_retry(
431    args: &ConvertArgs,
432    url: &str,
433    merged_headers: HashMap<String, String>,
434    config_protocol: Option<super::protocol::McpProtocol>,
435    tool_filter: ToolFilter,
436    verbose: bool,
437    quiet: bool,
438) -> Result<()> {
439    if !quiet && merged_headers.is_empty() {
440        eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", url);
441    }
442
443    // 显示过滤器配置
444    if !quiet {
445        if let Some(ref allow_tools) = args.allow_tools {
446            eprintln!("🔧 工具白名单: {:?}", allow_tools);
447        }
448        if let Some(ref deny_tools) = args.deny_tools {
449            eprintln!("🔧 工具黑名单: {:?}", deny_tools);
450        }
451    }
452
453    // 确定协议类型:命令行参数 > 配置文件 > 自动检测
454    let protocol = if let Some(ref proto) = args.protocol {
455        let detected = match proto {
456            super::proxy_server::ProxyProtocol::Sse => super::protocol::McpProtocol::Sse,
457            super::proxy_server::ProxyProtocol::Stream => super::protocol::McpProtocol::Stream,
458        };
459        if !quiet {
460            eprintln!("🔧 使用指定协议: {}", protocol_name(&detected));
461        }
462        detected
463    } else if let Some(proto) = config_protocol {
464        if !quiet {
465            eprintln!("🔧 使用配置协议: {}", protocol_name(&proto));
466        }
467        proto
468    } else {
469        if !quiet {
470            eprintln!("🔍 正在检测协议...");
471        }
472        let detected = super::protocol::detect_mcp_protocol(url).await?;
473        if !quiet {
474            eprintln!("🔍 检测到 {} 协议", protocol_name(&detected));
475        }
476        detected
477    };
478
479    // 构建 McpClientConfig
480    let config = build_mcp_config(url, &merged_headers, args.auth.as_ref());
481
482    // 根据协议类型分支处理
483    match protocol {
484        super::protocol::McpProtocol::Sse => {
485            run_sse_mode(config, args.clone(), tool_filter, verbose, quiet).await
486        }
487        super::protocol::McpProtocol::Stream => {
488            run_stream_mode(config, args.clone(), tool_filter, verbose, quiet).await
489        }
490        super::protocol::McpProtocol::Stdio => {
491            bail!("Stdio 协议不支持通过 URL 转换,请使用 --config 配置本地命令")
492        }
493    }
494}
495
496/// 构建 McpClientConfig
497fn build_mcp_config(
498    url: &str,
499    headers: &HashMap<String, String>,
500    auth: Option<&String>,
501) -> McpClientConfig {
502    let mut config = McpClientConfig::new(url);
503    for (k, v) in headers {
504        config = config.with_header(k, v);
505    }
506    if let Some(auth_value) = auth {
507        config = config.with_header("Authorization", auth_value);
508    }
509    config
510}
511
512/// SSE 模式处理(使用 mcp-sse-proxy,rmcp 0.10)
513async fn run_sse_mode(
514    config: McpClientConfig,
515    args: ConvertArgs,
516    tool_filter: ToolFilter,
517    verbose: bool,
518    quiet: bool,
519) -> Result<()> {
520    if !quiet {
521        eprintln!("🔗 正在连接到后端服务 (SSE)...");
522    }
523
524    // 1. 使用高层 API 连接
525    let connect_timeout = Duration::from_secs(30);
526    let conn = tokio::time::timeout(connect_timeout, SseClientConnection::connect(config.clone()))
527        .await
528        .map_err(|_| anyhow::anyhow!("连接后端超时 ({}秒)", connect_timeout.as_secs()))?
529        .map_err(|e| anyhow::anyhow!("连接后端失败: {}", e))?;
530
531    if !quiet {
532        eprintln!("✅ 后端连接成功");
533        // 打印工具列表
534        print_sse_tools(&conn, quiet).await;
535        if args.ping_interval > 0 {
536            eprintln!("💓 心跳检测: 每 {}s ping 一次(超时 {}s)", args.ping_interval, args.ping_timeout);
537        }
538    }
539
540    // 2. 创建 handler(消耗 conn)
541    let handler = Arc::new(conn.into_handler("cli".to_string(), tool_filter.clone()));
542
543    // 3. 启动 stdio server
544    let server = (*handler).clone().serve(sse_stdio()).await?;
545
546    if !quiet {
547        eprintln!("💡 stdio server 已启动,开始代理转换...");
548    }
549
550    // 4. 启动 watchdog 任务
551    let handler_for_watchdog = handler.clone();
552    let mut watchdog_handle = tokio::spawn(run_sse_watchdog(
553        handler_for_watchdog,
554        args,
555        config,
556        tool_filter,
557        verbose,
558        quiet,
559    ));
560
561    // 5. 等待 stdio server 退出
562    tokio::select! {
563        result = server.waiting() => {
564            watchdog_handle.abort();
565            result?;
566        }
567        watchdog_result = &mut watchdog_handle => {
568            if let Err(e) = watchdog_result {
569                if !e.is_cancelled() {
570                    error!("SSE Watchdog task failed: {:?}", e);
571                }
572            }
573        }
574    }
575
576    Ok(())
577}
578
579/// Stream 模式处理(使用 mcp-streamable-proxy,rmcp 0.12)
580async fn run_stream_mode(
581    config: McpClientConfig,
582    args: ConvertArgs,
583    tool_filter: ToolFilter,
584    verbose: bool,
585    quiet: bool,
586) -> Result<()> {
587    if !quiet {
588        eprintln!("🔗 正在连接到后端服务 (Stream)...");
589    }
590
591    // 1. 使用高层 API 连接
592    let connect_timeout = Duration::from_secs(30);
593    let conn = tokio::time::timeout(connect_timeout, StreamClientConnection::connect(config.clone()))
594        .await
595        .map_err(|_| anyhow::anyhow!("连接后端超时 ({}秒)", connect_timeout.as_secs()))?
596        .map_err(|e| anyhow::anyhow!("连接后端失败: {}", e))?;
597
598    if !quiet {
599        eprintln!("✅ 后端连接成功");
600        // 打印工具列表
601        print_stream_tools(&conn, quiet).await;
602        if args.ping_interval > 0 {
603            eprintln!("💓 心跳检测: 每 {}s ping 一次(超时 {}s)", args.ping_interval, args.ping_timeout);
604        }
605    }
606
607    // 2. 创建 handler(消耗 conn)
608    let handler = Arc::new(conn.into_handler("cli".to_string(), tool_filter.clone()));
609
610    // 3. 启动 stdio server(使用 stream_stdio,即 rmcp 0.12 的 stdio)
611    let server = (*handler).clone().serve(stream_stdio()).await?;
612
613    if !quiet {
614        eprintln!("💡 stdio server 已启动,开始代理转换...");
615    }
616
617    // 4. 启动 watchdog 任务
618    let handler_for_watchdog = handler.clone();
619    let mut watchdog_handle = tokio::spawn(run_stream_watchdog(
620        handler_for_watchdog,
621        args,
622        config,
623        tool_filter,
624        verbose,
625        quiet,
626    ));
627
628    // 5. 等待 stdio server 退出
629    tokio::select! {
630        result = server.waiting() => {
631            watchdog_handle.abort();
632            result?;
633        }
634        watchdog_result = &mut watchdog_handle => {
635            if let Err(e) = watchdog_result {
636                if !e.is_cancelled() {
637                    error!("Stream Watchdog task failed: {:?}", e);
638                }
639            }
640        }
641    }
642
643    Ok(())
644}
645
646/// 打印 SSE 连接的工具列表
647async fn print_sse_tools(conn: &SseClientConnection, quiet: bool) {
648    if quiet {
649        return;
650    }
651    match conn.list_tools().await {
652        Ok(tools) => {
653            if tools.is_empty() {
654                eprintln!("⚠️  工具列表为空 (tools/list 返回 0 个工具)");
655            } else {
656                eprintln!("🔧 可用工具 ({} 个):", tools.len());
657                for tool in &tools {
658                    let desc = tool.description.as_deref().unwrap_or("无描述");
659                    let desc_short = truncate_str(desc, 50);
660                    eprintln!("   - {} : {}", tool.name, desc_short);
661                }
662            }
663        }
664        Err(e) => {
665            eprintln!("⚠️  获取工具列表失败: {}", e);
666        }
667    }
668}
669
670/// 打印 Stream 连接的工具列表
671async fn print_stream_tools(conn: &StreamClientConnection, quiet: bool) {
672    if quiet {
673        return;
674    }
675    match conn.list_tools().await {
676        Ok(tools) => {
677            if tools.is_empty() {
678                eprintln!("⚠️  工具列表为空 (tools/list 返回 0 个工具)");
679            } else {
680                eprintln!("🔧 可用工具 ({} 个):", tools.len());
681                for tool in &tools {
682                    let desc = tool.description.as_deref().unwrap_or("无描述");
683                    let desc_short = truncate_str(desc, 50);
684                    eprintln!("   - {} : {}", tool.name, desc_short);
685                }
686            }
687        }
688        Err(e) => {
689            eprintln!("⚠️  获取工具列表失败: {}", e);
690        }
691    }
692}
693
694/// 截断字符串(UTF-8 安全)
695fn truncate_str(s: &str, max_len: usize) -> String {
696    if s.chars().count() > max_len {
697        format!("{}...", s.chars().take(max_len - 3).collect::<String>())
698    } else {
699        s.to_string()
700    }
701}
702
703/// SSE 模式的 watchdog:负责监控连接健康、断开时重连
704async fn run_sse_watchdog(
705    handler: Arc<ProxyHandler>,
706    args: ConvertArgs,
707    config: McpClientConfig,
708    _tool_filter: ToolFilter,
709    verbose: bool,
710    quiet: bool,
711) {
712    let max_retries = args.retries;
713    let mut attempt = 0u32;
714    let mut backoff_secs = 1u64;
715    const MAX_BACKOFF_SECS: u64 = 30;
716
717    // 首先监控现有连接的健康状态
718    let disconnect_reason = monitor_sse_connection(
719        &handler,
720        args.ping_interval,
721        args.ping_timeout,
722        quiet,
723    ).await;
724
725    // 连接断开,标记后端不可用
726    handler.swap_backend(None);
727
728    if !quiet {
729        eprintln!("⚠️  连接断开: {}", disconnect_reason);
730    }
731
732    // 进入重连循环
733    loop {
734        attempt += 1;
735
736        if !quiet {
737            eprintln!("🔗 正在重新连接 (第{}次尝试)...", attempt);
738        }
739
740        // 尝试建立连接
741        let connect_result = SseClientConnection::connect(config.clone()).await;
742
743        match connect_result {
744            Ok(conn) => {
745                // 连接成功,获取 RunningService 并热替换后端
746                let running = conn.into_running_service();
747                handler.swap_backend(Some(running));
748                backoff_secs = 1;
749
750                if !quiet {
751                    eprintln!("✅ 重连成功,恢复代理服务");
752                }
753
754                // 监控连接健康
755                let disconnect_reason = monitor_sse_connection(
756                    &handler,
757                    args.ping_interval,
758                    args.ping_timeout,
759                    quiet,
760                ).await;
761
762                // 连接断开,标记后端不可用
763                handler.swap_backend(None);
764
765                if !quiet {
766                    eprintln!("⚠️  连接断开: {}", disconnect_reason);
767                }
768            }
769            Err(e) => {
770                let error_type = classify_error(&e);
771
772                if max_retries > 0 && attempt >= max_retries {
773                    if !quiet {
774                        eprintln!("❌ 连接失败,已达最大重试次数 ({})", max_retries);
775                        eprintln!("   错误类型: {}", error_type);
776                        eprintln!("   错误详情: {}", e);
777                    }
778                    break;
779                }
780
781                if !quiet {
782                    if max_retries == 0 {
783                        eprintln!("⚠️  连接失败 [{}]: {},{}秒后重连 (第{}次)...",
784                            error_type, summarize_error(&e), backoff_secs, attempt);
785                    } else {
786                        eprintln!("⚠️  连接失败 [{}]: {},{}秒后重连 ({}/{})...",
787                            error_type, summarize_error(&e), backoff_secs, attempt, max_retries);
788                    }
789                }
790
791                if verbose && !quiet {
792                    eprintln!("   完整错误: {}", e);
793                }
794            }
795        }
796
797        tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
798        backoff_secs = (backoff_secs * 2).min(MAX_BACKOFF_SECS);
799    }
800}
801
802/// Stream 模式的 watchdog:负责监控连接健康、断开时重连
803async fn run_stream_watchdog(
804    handler: Arc<StreamProxyHandler>,
805    args: ConvertArgs,
806    config: McpClientConfig,
807    _tool_filter: ToolFilter,
808    verbose: bool,
809    quiet: bool,
810) {
811    let max_retries = args.retries;
812    let mut attempt = 0u32;
813    let mut backoff_secs = 1u64;
814    const MAX_BACKOFF_SECS: u64 = 30;
815
816    // 首先监控现有连接的健康状态
817    let disconnect_reason = monitor_stream_connection(
818        &handler,
819        args.ping_interval,
820        args.ping_timeout,
821        quiet,
822    ).await;
823
824    // 连接断开,标记后端不可用
825    handler.swap_backend(None);
826
827    if !quiet {
828        eprintln!("⚠️  连接断开: {}", disconnect_reason);
829    }
830
831    // 进入重连循环
832    loop {
833        attempt += 1;
834
835        if !quiet {
836            eprintln!("🔗 正在重新连接 (第{}次尝试)...", attempt);
837        }
838
839        // 尝试建立连接
840        let connect_result = StreamClientConnection::connect(config.clone()).await;
841
842        match connect_result {
843            Ok(conn) => {
844                // 连接成功,获取 RunningService 并热替换后端
845                let running = conn.into_running_service();
846                handler.swap_backend(Some(running));
847                backoff_secs = 1;
848
849                if !quiet {
850                    eprintln!("✅ 重连成功,恢复代理服务");
851                }
852
853                // 监控连接健康
854                let disconnect_reason = monitor_stream_connection(
855                    &handler,
856                    args.ping_interval,
857                    args.ping_timeout,
858                    quiet,
859                ).await;
860
861                // 连接断开,标记后端不可用
862                handler.swap_backend(None);
863
864                if !quiet {
865                    eprintln!("⚠️  连接断开: {}", disconnect_reason);
866                }
867            }
868            Err(e) => {
869                let error_type = classify_error(&e);
870
871                if max_retries > 0 && attempt >= max_retries {
872                    if !quiet {
873                        eprintln!("❌ 连接失败,已达最大重试次数 ({})", max_retries);
874                        eprintln!("   错误类型: {}", error_type);
875                        eprintln!("   错误详情: {}", e);
876                    }
877                    break;
878                }
879
880                if !quiet {
881                    if max_retries == 0 {
882                        eprintln!("⚠️  连接失败 [{}]: {},{}秒后重连 (第{}次)...",
883                            error_type, summarize_error(&e), backoff_secs, attempt);
884                    } else {
885                        eprintln!("⚠️  连接失败 [{}]: {},{}秒后重连 ({}/{})...",
886                            error_type, summarize_error(&e), backoff_secs, attempt, max_retries);
887                    }
888                }
889
890                if verbose && !quiet {
891                    eprintln!("   完整错误: {}", e);
892                }
893            }
894        }
895
896        tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
897        backoff_secs = (backoff_secs * 2).min(MAX_BACKOFF_SECS);
898    }
899}
900
901/// SSE 模式:监控连接健康,返回断开原因
902async fn monitor_sse_connection(
903    handler: &ProxyHandler,
904    ping_interval: u64,
905    ping_timeout: u64,
906    quiet: bool,
907) -> String {
908    if ping_interval == 0 {
909        loop {
910            tokio::time::sleep(Duration::from_secs(1)).await;
911            if !handler.is_backend_available() {
912                return "后端连接已关闭".to_string();
913            }
914        }
915    }
916
917    let mut interval = tokio::time::interval(Duration::from_secs(ping_interval));
918    interval.tick().await;
919
920    loop {
921        interval.tick().await;
922
923        if !handler.is_backend_available() {
924            return "后端连接已关闭".to_string();
925        }
926
927        let check_result = tokio::time::timeout(
928            Duration::from_secs(ping_timeout),
929            handler.is_terminated_async()
930        ).await;
931
932        match check_result {
933            Ok(true) => return "Ping 检测失败(服务错误)".to_string(),
934            Ok(false) => {}
935            Err(_) => {
936                if !quiet {
937                    eprintln!("❌ Ping 检测超时({}s)", ping_timeout);
938                }
939                return format!("Ping 检测超时({}s)", ping_timeout);
940            }
941        }
942    }
943}
944
945/// Stream 模式:监控连接健康,返回断开原因
946async fn monitor_stream_connection(
947    handler: &StreamProxyHandler,
948    ping_interval: u64,
949    ping_timeout: u64,
950    quiet: bool,
951) -> String {
952    if ping_interval == 0 {
953        loop {
954            tokio::time::sleep(Duration::from_secs(1)).await;
955            if !handler.is_backend_available() {
956                return "后端连接已关闭".to_string();
957            }
958        }
959    }
960
961    let mut interval = tokio::time::interval(Duration::from_secs(ping_interval));
962    interval.tick().await;
963
964    loop {
965        interval.tick().await;
966
967        if !handler.is_backend_available() {
968            return "后端连接已关闭".to_string();
969        }
970
971        let check_result = tokio::time::timeout(
972            Duration::from_secs(ping_timeout),
973            handler.is_terminated_async()
974        ).await;
975
976        match check_result {
977            Ok(true) => return "Ping 检测失败(服务错误)".to_string(),
978            Ok(false) => {}
979            Err(_) => {
980                if !quiet {
981                    eprintln!("❌ Ping 检测超时({}s)", ping_timeout);
982                }
983                return format!("Ping 检测超时({}s)", ping_timeout);
984            }
985        }
986    }
987}
988
989/// 错误分类
990fn classify_error(e: &anyhow::Error) -> &'static str {
991    let err_str = e.to_string().to_lowercase();
992
993    if err_str.contains("timeout") || err_str.contains("timed out") {
994        "超时"
995    } else if err_str.contains("connection refused") {
996        "连接被拒绝"
997    } else if err_str.contains("connection reset") {
998        "连接被重置"
999    } else if err_str.contains("dns") || err_str.contains("resolve") {
1000        "DNS解析失败"
1001    } else if err_str.contains("certificate") || err_str.contains("ssl") || err_str.contains("tls") {
1002        "SSL/TLS错误"
1003    } else if err_str.contains("session") {
1004        "会话错误"
1005    } else if err_str.contains("sending request") || err_str.contains("network") {
1006        "网络错误"
1007    } else if err_str.contains("eof") || err_str.contains("closed") || err_str.contains("shutdown") {
1008        "连接关闭"
1009    } else {
1010        "未知错误"
1011    }
1012}
1013
1014/// 简化错误信息(用于单行日志)
1015fn summarize_error(e: &anyhow::Error) -> String {
1016    let full = e.to_string();
1017    // 截取第一行或前80个字符
1018    let first_line = full.lines().next().unwrap_or(&full);
1019    // 使用 chars() 安全处理 UTF-8 字符,避免在多字节字符中间截断
1020    if first_line.chars().count() > 80 {
1021        format!("{}...", first_line.chars().take(77).collect::<String>())
1022    } else {
1023        first_line.to_string()
1024    }
1025}
1026
1027/// 命令模式执行(本地子进程)
1028/// 使用 SSE 库(rmcp 0.10)的类型
1029async fn run_command_mode(
1030    name: &str,
1031    command: &str,
1032    cmd_args: Vec<String>,
1033    env: HashMap<String, String>,
1034    tool_filter: ToolFilter,
1035    verbose: bool,
1036    quiet: bool,
1037) -> Result<()> {
1038    if !quiet {
1039        eprintln!("🚀 MCP-Stdio-Proxy: {} (command) → stdio", name);
1040        eprintln!("   命令: {} {:?}", command, cmd_args);
1041        if verbose && !env.is_empty() {
1042            eprintln!("   环境变量: {:?}", env);
1043        }
1044    }
1045
1046    // 显示过滤器配置
1047    if !quiet {
1048        if tool_filter.is_enabled() {
1049            eprintln!("🔧 工具过滤已启用");
1050        }
1051    }
1052
1053    // 创建子进程命令
1054    let mut cmd = Command::new(command);
1055    cmd.args(&cmd_args);
1056    for (k, v) in &env {
1057        cmd.env(k, v);
1058    }
1059
1060    // 启动子进程
1061    let tokio_process = TokioChildProcess::new(cmd)?;
1062
1063    if !quiet {
1064        eprintln!("🔗 启动子进程...");
1065    }
1066
1067    // 创建 ClientInfo(使用 SSE 库的类型,rmcp 0.10)
1068    let client_info = create_sse_client_info();
1069
1070    // 连接到子进程
1071    let running = client_info.serve(tokio_process).await?;
1072
1073    if !quiet {
1074        eprintln!("✅ 子进程已启动,开始代理转换...");
1075
1076        // 打印工具列表
1077        match running.list_tools(None).await {
1078            Ok(tools_result) => {
1079                let tools = &tools_result.tools;
1080                if tools.is_empty() {
1081                    eprintln!("⚠️  工具列表为空 (tools/list 返回 0 个工具)");
1082                } else {
1083                    eprintln!("🔧 可用工具 ({} 个):", tools.len());
1084                    for tool in tools {
1085                        let desc = tool.description.as_deref().unwrap_or("无描述");
1086                        let desc_short = truncate_str(desc, 50);
1087                        eprintln!("   - {} : {}", tool.name, desc_short);
1088                    }
1089                }
1090            }
1091            Err(e) => {
1092                eprintln!("⚠️  获取工具列表失败: {}", e);
1093            }
1094        }
1095
1096        eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
1097    }
1098
1099    // 使用 ProxyHandler + stdio 将本地 MCP 服务透明暴露为 stdio
1100    let proxy_handler = ProxyHandler::with_tool_filter(running, name.to_string(), tool_filter);
1101    let server = proxy_handler.serve(sse_stdio()).await?;
1102    server.waiting().await?;
1103
1104    Ok(())
1105}
1106
1107/// 创建 SSE 库的 ClientInfo(rmcp 0.10)
1108fn create_sse_client_info() -> SseClientInfo {
1109    SseClientInfo {
1110        protocol_version: Default::default(),
1111        capabilities: SseClientCapabilities::builder()
1112            .enable_experimental()
1113            .enable_roots()
1114            .enable_roots_list_changed()
1115            .enable_sampling()
1116            .build(),
1117        client_info: SseImplementation {
1118            name: "mcp-proxy-cli".to_string(),
1119            version: env!("CARGO_PKG_VERSION").to_string(),
1120            title: None,
1121            website_url: None,
1122            icons: None,
1123        },
1124    }
1125}
1126
1127/// 获取协议名称
1128fn protocol_name(protocol: &super::protocol::McpProtocol) -> &'static str {
1129    match protocol {
1130        super::protocol::McpProtocol::Sse => "SSE",
1131        super::protocol::McpProtocol::Stream => "Streamable HTTP",
1132        super::protocol::McpProtocol::Stdio => "Stdio",
1133    }
1134}
1135
1136/// 运行检查命令
1137async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
1138    if !quiet {
1139        eprintln!("🔍 检查服务: {}", args.url);
1140    }
1141
1142    match super::protocol::detect_mcp_protocol(&args.url).await {
1143        Ok(protocol) => {
1144            if !quiet {
1145                eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
1146            }
1147            Ok(())
1148        }
1149        Err(e) => {
1150            if !quiet {
1151                eprintln!("❌ 服务检查失败: {}", e);
1152            }
1153            Err(e)
1154        }
1155    }
1156}
1157
1158/// 运行协议检测命令
1159async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
1160    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
1161
1162    if quiet {
1163        println!("{}", protocol);
1164    } else {
1165        eprintln!("{}", protocol);
1166    }
1167
1168    Ok(())
1169}