mcp_stdio_proxy/client/
cli.rs

1// MCP-Proxy CLI 简化实现 - 修复版本
2// 直接使用 rmcp 库的功能,无需复杂的 trait 抽象
3
4use std::collections::HashMap;
5
6use clap::Parser;
7use anyhow::{Result, bail};
8use serde::Deserialize;
9use tokio::process::Command;
10
11use rmcp::{
12    ServiceExt,
13    model::{ClientCapabilities, ClientInfo},
14    transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
15};
16use crate::proxy::{ProxyHandler, ToolFilter};
17
18/// MCP-Proxy CLI 主命令结构
19#[derive(Parser, Debug)]
20#[command(name = "mcp-proxy")]
21#[command(version = env!("CARGO_PKG_VERSION"))]
22#[command(about = "MCP 协议转换代理工具", long_about = None)]
23pub struct Cli {
24    #[command(subcommand)]
25    pub command: Option<Commands>,
26    
27    /// 直接URL模式(向后兼容)
28    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
29    pub url: Option<String>,
30    
31    /// 全局详细输出
32    #[arg(short, long, global = true)]
33    pub verbose: bool,
34    
35    /// 全局静默模式
36    #[arg(short, long, global = true)]
37    pub quiet: bool,
38}
39
40#[derive(clap::Subcommand, Debug)]
41pub enum Commands {
42    /// 协议转换模式 - 将 URL 转换为 stdio
43    Convert(ConvertArgs),
44
45    /// 检查服务状态
46    Check(CheckArgs),
47
48    /// 协议检测
49    Detect(DetectArgs),
50
51    /// 代理模式 - 将 stdio MCP 服务代理为 HTTP/SSE 服务
52    Proxy(super::proxy_server::ProxyArgs),
53}
54
55/// 协议转换参数
56#[derive(Parser, Debug, Clone)]
57pub struct ConvertArgs {
58    /// MCP 服务的 URL 地址(可选,与 --config/--config-file 二选一)
59    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
60    pub url: Option<String>,
61
62    /// MCP 服务配置 JSON
63    #[arg(long, conflicts_with = "config_file", help = "MCP 服务配置 JSON")]
64    pub config: Option<String>,
65
66    /// MCP 服务配置文件路径
67    #[arg(long, conflicts_with = "config", help = "MCP 服务配置文件路径")]
68    pub config_file: Option<std::path::PathBuf>,
69
70    /// MCP 服务名称(多服务配置时必需)
71    #[arg(short, long, help = "MCP 服务名称(多服务配置时必需)")]
72    pub name: Option<String>,
73
74    /// 指定远程服务协议类型(不指定则自动检测)
75    #[arg(long, value_enum, help = "指定远程服务协议类型(不指定则自动检测)")]
76    pub protocol: Option<super::proxy_server::ProxyProtocol>,
77
78    /// 认证 header (如: "Bearer token")
79    #[arg(short, long, help = "认证 header")]
80    pub auth: Option<String>,
81
82    /// 自定义 HTTP headers
83    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
84    pub header: Vec<(String, String)>,
85
86    /// 连接超时时间(秒)
87    #[arg(long, default_value = "300", help = "连接超时时间(秒),默认5分钟")]
88    pub timeout: u64,
89
90    /// 重试次数
91    #[arg(long, default_value = "3", help = "重试次数")]
92    pub retries: u32,
93
94    /// 工具白名单(逗号分隔),只允许指定的工具
95    #[arg(long, value_delimiter = ',', help = "工具白名单(逗号分隔),只允许指定的工具")]
96    pub allow_tools: Option<Vec<String>>,
97
98    /// 工具黑名单(逗号分隔),排除指定的工具
99    #[arg(long, value_delimiter = ',', help = "工具黑名单(逗号分隔),排除指定的工具")]
100    pub deny_tools: Option<Vec<String>>,
101}
102
103/// 检查参数
104#[derive(Parser, Debug)]
105pub struct CheckArgs {
106    /// 要检查的 MCP 服务 URL
107    #[arg(value_name = "URL")]
108    pub url: String,
109    
110    /// 认证 header
111    #[arg(short, long)]
112    pub auth: Option<String>,
113    
114    /// 超时时间
115    #[arg(long, default_value = "10")]
116    pub timeout: u64,
117}
118
119/// 协议检测参数
120#[derive(Parser, Debug)]
121pub struct DetectArgs {
122    /// 要检测的 MCP 服务 URL
123    #[arg(value_name = "URL")]
124    pub url: String,
125    
126    /// 认证 header
127    #[arg(short, long)]
128    pub auth: Option<String>,
129}
130
131/// 解析 KEY=VALUE 格式的辅助函数
132fn parse_key_val(s: &str) -> Result<(String, String)> {
133    let pos = s.find('=')
134        .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
135    Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
136}
137
138// ============== MCP 配置解析相关 ==============
139
140/// MCP 配置格式
141#[derive(Deserialize, Debug)]
142struct McpConfig {
143    #[serde(rename = "mcpServers")]
144    mcp_servers: HashMap<String, McpServerInnerConfig>,
145}
146
147/// MCP 服务配置(支持 Command 和 Url 两种类型)
148#[derive(Deserialize, Debug, Clone)]
149#[serde(untagged)]
150enum McpServerInnerConfig {
151    Command(StdioConfig),
152    Url(UrlConfig),
153}
154
155/// stdio 配置(本地命令)
156#[derive(Deserialize, Debug, Clone)]
157struct StdioConfig {
158    command: String,
159    args: Option<Vec<String>>,
160    env: Option<HashMap<String, String>>,
161}
162
163/// URL 配置(远程服务)
164#[derive(Deserialize, Debug, Clone)]
165struct UrlConfig {
166    #[serde(skip_serializing_if = "Option::is_none")]
167    url: Option<String>,
168    #[serde(
169        skip_serializing_if = "Option::is_none",
170        default,
171        rename = "baseUrl",
172        alias = "baseurl",
173        alias = "base_url"
174    )]
175    base_url: Option<String>,
176    #[serde(default, rename = "type", alias = "Type")]
177    r#type: Option<String>,
178    pub headers: Option<HashMap<String, String>>,
179    #[serde(default, alias = "authToken", alias = "auth_token")]
180    pub auth_token: Option<String>,
181    pub timeout: Option<u64>,
182}
183
184impl UrlConfig {
185    fn get_url(&self) -> Option<&str> {
186        self.url.as_deref().or(self.base_url.as_deref())
187    }
188}
189
190/// 解析后的配置源
191enum McpConfigSource {
192    /// 直接 URL 模式(命令行参数)
193    DirectUrl {
194        url: String,
195    },
196    /// 远程服务配置(JSON 配置)
197    RemoteService {
198        name: String,
199        url: String,
200        protocol: Option<super::protocol::McpProtocol>,
201        headers: HashMap<String, String>,
202        timeout: Option<u64>,
203    },
204    /// 本地命令配置(JSON 配置)
205    LocalCommand {
206        name: String,
207        command: String,
208        args: Vec<String>,
209        env: HashMap<String, String>,
210    },
211}
212
213/// 解析 convert 命令的配置
214fn parse_convert_config(args: &ConvertArgs) -> Result<McpConfigSource> {
215    // 优先级:url > config > config_file
216    if let Some(ref url) = args.url {
217        return Ok(McpConfigSource::DirectUrl { url: url.clone() });
218    }
219
220    // 读取 JSON 配置
221    let json_str = if let Some(ref config) = args.config {
222        config.clone()
223    } else if let Some(ref path) = args.config_file {
224        std::fs::read_to_string(path)
225            .map_err(|e| anyhow::anyhow!("读取配置文件失败: {}", e))?
226    } else {
227        bail!("必须提供 URL、--config 或 --config-file 参数之一");
228    };
229
230    // 解析 JSON 配置
231    let mcp_config: McpConfig = serde_json::from_str(&json_str)
232        .map_err(|e| anyhow::anyhow!(
233            "配置解析失败: {}。配置必须是标准 MCP 格式,包含 mcpServers 字段",
234            e
235        ))?;
236
237    let servers = mcp_config.mcp_servers;
238
239    if servers.is_empty() {
240        bail!("配置中没有找到任何 MCP 服务");
241    }
242
243    // 选择服务
244    let (name, inner_config) = if servers.len() == 1 {
245        servers.into_iter().next().unwrap()
246    } else if let Some(ref name) = args.name {
247        let config = servers.get(name)
248            .cloned()
249            .ok_or_else(|| anyhow::anyhow!(
250                "服务 '{}' 不存在。可用服务: {:?}",
251                name,
252                servers.keys().collect::<Vec<_>>()
253            ))?;
254        (name.clone(), config)
255    } else {
256        bail!(
257            "配置包含多个服务 {:?},请使用 --name 指定要使用的服务",
258            servers.keys().collect::<Vec<_>>()
259        );
260    };
261
262    // 根据配置类型返回
263    match inner_config {
264        McpServerInnerConfig::Command(stdio) => {
265            Ok(McpConfigSource::LocalCommand {
266                name,
267                command: stdio.command,
268                args: stdio.args.unwrap_or_default(),
269                env: stdio.env.unwrap_or_default(),
270            })
271        }
272        McpServerInnerConfig::Url(url_config) => {
273            let url = url_config.get_url()
274                .ok_or_else(|| anyhow::anyhow!("URL 配置缺少 url 或 baseUrl 字段"))?
275                .to_string();
276
277            // 解析协议类型
278            let protocol = url_config.r#type.as_ref().and_then(|t| {
279                match t.as_str() {
280                    "sse" => Some(super::protocol::McpProtocol::Sse),
281                    "http" | "stream" => Some(super::protocol::McpProtocol::Stream),
282                    _ => None,
283                }
284            });
285
286            // 合并 headers:JSON 配置中的 auth_token -> Authorization
287            let mut headers = url_config.headers.clone().unwrap_or_default();
288            if let Some(auth_token) = &url_config.auth_token {
289                headers.insert("Authorization".to_string(), auth_token.clone());
290            }
291
292            Ok(McpConfigSource::RemoteService {
293                name,
294                url,
295                protocol,
296                headers,
297                timeout: url_config.timeout,
298            })
299        }
300    }
301}
302
303/// 合并 headers:JSON 配置 + 命令行参数(命令行优先)
304fn merge_headers(
305    config_headers: HashMap<String, String>,
306    cli_headers: &[(String, String)],
307    cli_auth: Option<&String>,
308) -> HashMap<String, String> {
309    let mut merged = config_headers;
310
311    // 命令行 -H 参数覆盖配置
312    for (key, value) in cli_headers {
313        merged.insert(key.clone(), value.clone());
314    }
315
316    // 命令行 --auth 参数优先级最高
317    if let Some(auth_value) = cli_auth {
318        merged.insert("Authorization".to_string(), auth_value.clone());
319    }
320
321    merged
322}
323
324/// 运行 CLI 主逻辑
325pub async fn run_cli(cli: Cli) -> Result<()> {
326    match cli.command {
327        Some(Commands::Convert(args)) => {
328            run_convert_command(args, cli.verbose, cli.quiet).await
329        }
330        Some(Commands::Check(args)) => {
331            run_check_command(args, cli.verbose, cli.quiet).await
332        }
333        Some(Commands::Detect(args)) => {
334            run_detect_command(args, cli.verbose, cli.quiet).await
335        }
336        Some(Commands::Proxy(args)) => {
337            super::proxy_server::run_proxy_command(args, cli.verbose, cli.quiet).await
338        }
339        None => {
340            // 直接 URL 模式(向后兼容)
341            if let Some(url) = cli.url {
342                let args = ConvertArgs {
343                    url: Some(url),
344                    config: None,
345                    config_file: None,
346                    name: None,
347                    protocol: None,
348                    auth: None,
349                    header: vec![],
350                    timeout: 300,  // 5分钟,匹配 ProxyHandler 的工具调用超时
351                    retries: 3,
352                    allow_tools: None,
353                    deny_tools: None,
354                };
355                run_convert_command(args, cli.verbose, cli.quiet).await
356            } else {
357                bail!("请提供 URL 或使用子命令")
358            }
359        }
360    }
361}
362
363/// 运行转换命令 - 核心功能
364async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
365    // 检查 --allow-tools 和 --deny-tools 互斥
366    if args.allow_tools.is_some() && args.deny_tools.is_some() {
367        bail!("--allow-tools 和 --deny-tools 不能同时使用,请只选择其中一个");
368    }
369
370    // 创建工具过滤器
371    let tool_filter = if let Some(allow_tools) = args.allow_tools.clone() {
372        ToolFilter::allow(allow_tools)
373    } else if let Some(deny_tools) = args.deny_tools.clone() {
374        ToolFilter::deny(deny_tools)
375    } else {
376        ToolFilter::default()
377    };
378
379    // 解析配置
380    let config_source = parse_convert_config(&args)?;
381
382    // 配置客户端能力
383    let client_info = ClientInfo {
384        protocol_version: Default::default(),
385        capabilities: ClientCapabilities::builder()
386            .enable_experimental()
387            .enable_roots()
388            .enable_roots_list_changed()
389            .enable_sampling()
390            .build(),
391        ..Default::default()
392    };
393
394    // 根据配置源执行不同逻辑
395    match config_source {
396        McpConfigSource::DirectUrl { url } => {
397            // 直接 URL 模式(原有逻辑)
398            run_url_mode(&args, &url, HashMap::new(), None, client_info, tool_filter, verbose, quiet).await
399        }
400        McpConfigSource::RemoteService { name, url, protocol, headers, timeout } => {
401            // 远程服务配置模式
402            if !quiet {
403                eprintln!("🚀 MCP-Stdio-Proxy: {} ({}) → stdio", name, url);
404            }
405            // 合并 headers:配置 + 命令行
406            let merged_headers = merge_headers(headers, &args.header, args.auth.as_ref());
407            run_url_mode(&args, &url, merged_headers, protocol.or(timeout.map(|_| super::protocol::McpProtocol::Stream)), client_info, tool_filter, verbose, quiet).await
408        }
409        McpConfigSource::LocalCommand { name, command, args: cmd_args, env } => {
410            // 本地命令模式
411            run_command_mode(&name, &command, cmd_args, env, client_info, tool_filter, verbose, quiet).await
412        }
413    }
414}
415
416/// URL 模式执行(远程 HTTP/SSE 服务)
417async fn run_url_mode(
418    args: &ConvertArgs,
419    url: &str,
420    merged_headers: HashMap<String, String>,
421    config_protocol: Option<super::protocol::McpProtocol>,
422    client_info: ClientInfo,
423    tool_filter: ToolFilter,
424    verbose: bool,
425    quiet: bool,
426) -> Result<()> {
427    if !quiet && merged_headers.is_empty() {
428        eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", url);
429    }
430
431    if verbose && !quiet {
432        eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
433    }
434
435    // 显示过滤器配置
436    if !quiet {
437        if let Some(ref allow_tools) = args.allow_tools {
438            eprintln!("🔧 工具白名单: {:?}", allow_tools);
439        }
440        if let Some(ref deny_tools) = args.deny_tools {
441            eprintln!("🔧 工具黑名单: {:?}", deny_tools);
442        }
443    }
444
445    // 确定协议类型:命令行参数 > 配置文件 > 自动检测
446    let protocol = if let Some(ref proto) = args.protocol {
447        // 命令行指定协议
448        let detected = match proto {
449            super::proxy_server::ProxyProtocol::Sse => super::protocol::McpProtocol::Sse,
450            super::proxy_server::ProxyProtocol::Stream => super::protocol::McpProtocol::Stream,
451        };
452        if !quiet {
453            eprintln!("🔧 使用指定协议: {}", protocol_name(&detected));
454        }
455        detected
456    } else if let Some(proto) = config_protocol {
457        // 配置文件指定协议
458        if !quiet {
459            eprintln!("🔧 使用配置协议: {}", protocol_name(&proto));
460        }
461        proto
462    } else {
463        // 自动检测协议
464        let detected = super::protocol::detect_mcp_protocol(url).await?;
465        if !quiet {
466            eprintln!("🔍 检测到 {} 协议", protocol_name(&detected));
467        }
468        detected
469    };
470
471    if !quiet {
472        eprintln!("🔗 建立连接...");
473    }
474
475    // 构建带认证与自定义头的 HTTP 客户端
476    let http_client = create_http_client_with_headers(&merged_headers, &args.header, args.auth.as_ref(), args.timeout)?;
477
478    // 为不同协议创建传输并启动 rmcp 客户端
479    let running = match protocol {
480        super::protocol::McpProtocol::Sse => {
481            let cfg = SseClientConfig {
482                sse_endpoint: url.to_string().into(),
483                ..Default::default()
484            };
485            let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
486            client_info.serve(transport).await?
487        }
488        super::protocol::McpProtocol::Stream => {
489            let cfg = StreamableHttpClientTransportConfig {
490                uri: url.to_string().into(),
491                ..Default::default()
492            };
493            let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
494            client_info.serve(transport).await?
495        }
496        super::protocol::McpProtocol::Stdio => {
497            bail!("Stdio 协议不支持通过 URL 转换,请使用 --config 配置本地命令")
498        }
499    };
500
501    if !quiet {
502        eprintln!("✅ 连接成功,开始代理转换...");
503
504        // 打印工具列表
505        match running.list_tools(None).await {
506            Ok(tools_result) => {
507                let tools = &tools_result.tools;
508                if tools.is_empty() {
509                    eprintln!("⚠️  工具列表为空 (tools/list 返回 0 个工具)");
510                } else {
511                    eprintln!("🔧 可用工具 ({} 个):", tools.len());
512                    for tool in tools {
513                        let desc = tool.description.as_deref().unwrap_or("无描述");
514                        let desc_short = if desc.chars().count() > 50 {
515                            format!("{}...", desc.chars().take(50).collect::<String>())
516                        } else {
517                            desc.to_string()
518                        };
519                        eprintln!("   - {} : {}", tool.name, desc_short);
520                    }
521                }
522            }
523            Err(e) => {
524                eprintln!("⚠️  获取工具列表失败: {}", e);
525            }
526        }
527
528        eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
529    }
530
531    // 使用 ProxyHandler + stdio 将远程 MCP 服务透明暴露为本地 stdio
532    let proxy_handler = ProxyHandler::with_tool_filter(running, "cli".to_string(), tool_filter);
533    let stdio_transport = stdio();
534    let server = proxy_handler.serve(stdio_transport).await?;
535    server.waiting().await?;
536
537    Ok(())
538}
539
540/// 命令模式执行(本地子进程)
541async fn run_command_mode(
542    name: &str,
543    command: &str,
544    cmd_args: Vec<String>,
545    env: HashMap<String, String>,
546    client_info: ClientInfo,
547    tool_filter: ToolFilter,
548    verbose: bool,
549    quiet: bool,
550) -> Result<()> {
551    if !quiet {
552        eprintln!("🚀 MCP-Stdio-Proxy: {} (command) → stdio", name);
553        eprintln!("   命令: {} {:?}", command, cmd_args);
554        if verbose && !env.is_empty() {
555            eprintln!("   环境变量: {:?}", env);
556        }
557    }
558
559    // 显示过滤器配置
560    if !quiet {
561        if tool_filter.is_enabled() {
562            eprintln!("🔧 工具过滤已启用");
563        }
564    }
565
566    // 创建子进程命令
567    let mut cmd = Command::new(command);
568    cmd.args(&cmd_args);
569    for (k, v) in &env {
570        cmd.env(k, v);
571    }
572
573    // 启动子进程
574    let tokio_process = TokioChildProcess::new(cmd)?;
575
576    if !quiet {
577        eprintln!("🔗 启动子进程...");
578    }
579
580    // 连接到子进程
581    let running = client_info.serve(tokio_process).await?;
582
583    if !quiet {
584        eprintln!("✅ 子进程已启动,开始代理转换...");
585
586        // 打印工具列表
587        match running.list_tools(None).await {
588            Ok(tools_result) => {
589                let tools = &tools_result.tools;
590                if tools.is_empty() {
591                    eprintln!("⚠️  工具列表为空 (tools/list 返回 0 个工具)");
592                } else {
593                    eprintln!("🔧 可用工具 ({} 个):", tools.len());
594                    for tool in tools {
595                        let desc = tool.description.as_deref().unwrap_or("无描述");
596                        let desc_short = if desc.chars().count() > 50 {
597                            format!("{}...", desc.chars().take(50).collect::<String>())
598                        } else {
599                            desc.to_string()
600                        };
601                        eprintln!("   - {} : {}", tool.name, desc_short);
602                    }
603                }
604            }
605            Err(e) => {
606                eprintln!("⚠️  获取工具列表失败: {}", e);
607            }
608        }
609
610        eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
611    }
612
613    // 使用 ProxyHandler + stdio 将本地 MCP 服务透明暴露为 stdio
614    let proxy_handler = ProxyHandler::with_tool_filter(running, name.to_string(), tool_filter);
615    let stdio_transport = stdio();
616    let server = proxy_handler.serve(stdio_transport).await?;
617    server.waiting().await?;
618
619    Ok(())
620}
621
622/// 获取协议名称
623fn protocol_name(protocol: &super::protocol::McpProtocol) -> &'static str {
624    match protocol {
625        super::protocol::McpProtocol::Sse => "SSE",
626        super::protocol::McpProtocol::Stream => "Streamable HTTP",
627        super::protocol::McpProtocol::Stdio => "Stdio",
628    }
629}
630
631/// 创建 HTTP 客户端(使用合并后的 headers)
632fn create_http_client_with_headers(
633    config_headers: &HashMap<String, String>,
634    cli_headers: &[(String, String)],
635    cli_auth: Option<&String>,
636    timeout: u64,
637) -> Result<reqwest::Client> {
638    let mut headers = reqwest::header::HeaderMap::new();
639
640    // 1. 先添加配置中的 headers
641    for (key, value) in config_headers {
642        headers.insert(
643            key.parse::<reqwest::header::HeaderName>()?,
644            value.parse()?,
645        );
646    }
647
648    // 2. 命令行 -H 参数覆盖
649    for (key, value) in cli_headers {
650        headers.insert(
651            key.parse::<reqwest::header::HeaderName>()?,
652            value.parse()?,
653        );
654    }
655
656    // 3. 命令行 --auth 参数优先级最高
657    if let Some(auth) = cli_auth {
658        headers.insert("Authorization", auth.parse()?);
659    }
660
661    let client = reqwest::Client::builder()
662        .default_headers(headers)
663        .timeout(tokio::time::Duration::from_secs(timeout))
664        .build()?;
665
666    Ok(client)
667}
668
669/// 运行检查命令
670async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
671    if !quiet {
672        eprintln!("🔍 检查服务: {}", args.url);
673    }
674
675    match super::protocol::detect_mcp_protocol(&args.url).await {
676        Ok(protocol) => {
677            if !quiet {
678                eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
679            }
680            Ok(())
681        }
682        Err(e) => {
683            if !quiet {
684                eprintln!("❌ 服务检查失败: {}", e);
685            }
686            Err(e)
687        }
688    }
689}
690
691/// 运行协议检测命令
692async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
693    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
694
695    if quiet {
696        println!("{}", protocol);
697    } else {
698        eprintln!("{}", protocol);
699    }
700
701    Ok(())
702}