mcp_stdio_proxy/client/
cli.rs

1// MCP-Proxy CLI 简化实现 - 修复版本
2// 直接使用 rmcp 库的功能,无需复杂的 trait 抽象
3
4use std::collections::HashMap;
5
6use clap::Parser;
7use anyhow::{Result, bail};
8use serde::Deserialize;
9use tokio::process::Command;
10
11use rmcp::{
12    ServiceExt,
13    model::{ClientCapabilities, ClientInfo},
14    transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
15};
16use crate::proxy::{ProxyHandler, ToolFilter};
17
18/// MCP-Proxy CLI 主命令结构
19#[derive(Parser, Debug)]
20#[command(name = "mcp-proxy")]
21#[command(version = env!("CARGO_PKG_VERSION"))]
22#[command(about = "MCP 协议转换代理工具", long_about = None)]
23pub struct Cli {
24    #[command(subcommand)]
25    pub command: Option<Commands>,
26    
27    /// 直接URL模式(向后兼容)
28    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
29    pub url: Option<String>,
30    
31    /// 全局详细输出
32    #[arg(short, long, global = true)]
33    pub verbose: bool,
34    
35    /// 全局静默模式
36    #[arg(short, long, global = true)]
37    pub quiet: bool,
38}
39
40#[derive(clap::Subcommand, Debug)]
41pub enum Commands {
42    /// 协议转换模式 - 将 URL 转换为 stdio
43    Convert(ConvertArgs),
44
45    /// 检查服务状态
46    Check(CheckArgs),
47
48    /// 协议检测
49    Detect(DetectArgs),
50
51    /// 代理模式 - 将 stdio MCP 服务代理为 HTTP/SSE 服务
52    Proxy(super::proxy_server::ProxyArgs),
53}
54
55/// 协议转换参数
56#[derive(Parser, Debug, Clone)]
57pub struct ConvertArgs {
58    /// MCP 服务的 URL 地址(可选,与 --config/--config-file 二选一)
59    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
60    pub url: Option<String>,
61
62    /// MCP 服务配置 JSON
63    #[arg(long, conflicts_with = "config_file", help = "MCP 服务配置 JSON")]
64    pub config: Option<String>,
65
66    /// MCP 服务配置文件路径
67    #[arg(long, conflicts_with = "config", help = "MCP 服务配置文件路径")]
68    pub config_file: Option<std::path::PathBuf>,
69
70    /// MCP 服务名称(多服务配置时必需)
71    #[arg(short, long, help = "MCP 服务名称(多服务配置时必需)")]
72    pub name: Option<String>,
73
74    /// 指定远程服务协议类型(不指定则自动检测)
75    #[arg(long, value_enum, help = "指定远程服务协议类型(不指定则自动检测)")]
76    pub protocol: Option<super::proxy_server::ProxyProtocol>,
77
78    /// 认证 header (如: "Bearer token")
79    #[arg(short, long, help = "认证 header")]
80    pub auth: Option<String>,
81
82    /// 自定义 HTTP headers
83    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
84    pub header: Vec<(String, String)>,
85
86    /// 连接超时时间(秒)
87    #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
88    pub timeout: u64,
89
90    /// 重试次数
91    #[arg(long, default_value = "3", help = "重试次数")]
92    pub retries: u32,
93
94    /// 工具白名单(逗号分隔),只允许指定的工具
95    #[arg(long, value_delimiter = ',', help = "工具白名单(逗号分隔),只允许指定的工具")]
96    pub allow_tools: Option<Vec<String>>,
97
98    /// 工具黑名单(逗号分隔),排除指定的工具
99    #[arg(long, value_delimiter = ',', help = "工具黑名单(逗号分隔),排除指定的工具")]
100    pub deny_tools: Option<Vec<String>>,
101}
102
103/// 检查参数
104#[derive(Parser, Debug)]
105pub struct CheckArgs {
106    /// 要检查的 MCP 服务 URL
107    #[arg(value_name = "URL")]
108    pub url: String,
109    
110    /// 认证 header
111    #[arg(short, long)]
112    pub auth: Option<String>,
113    
114    /// 超时时间
115    #[arg(long, default_value = "10")]
116    pub timeout: u64,
117}
118
119/// 协议检测参数
120#[derive(Parser, Debug)]
121pub struct DetectArgs {
122    /// 要检测的 MCP 服务 URL
123    #[arg(value_name = "URL")]
124    pub url: String,
125    
126    /// 认证 header
127    #[arg(short, long)]
128    pub auth: Option<String>,
129}
130
131/// 解析 KEY=VALUE 格式的辅助函数
132fn parse_key_val(s: &str) -> Result<(String, String)> {
133    let pos = s.find('=')
134        .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
135    Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
136}
137
138// ============== MCP 配置解析相关 ==============
139
140/// MCP 配置格式
141#[derive(Deserialize, Debug)]
142struct McpConfig {
143    #[serde(rename = "mcpServers")]
144    mcp_servers: HashMap<String, McpServerInnerConfig>,
145}
146
147/// MCP 服务配置(支持 Command 和 Url 两种类型)
148#[derive(Deserialize, Debug, Clone)]
149#[serde(untagged)]
150enum McpServerInnerConfig {
151    Command(StdioConfig),
152    Url(UrlConfig),
153}
154
155/// stdio 配置(本地命令)
156#[derive(Deserialize, Debug, Clone)]
157struct StdioConfig {
158    command: String,
159    args: Option<Vec<String>>,
160    env: Option<HashMap<String, String>>,
161}
162
163/// URL 配置(远程服务)
164#[derive(Deserialize, Debug, Clone)]
165struct UrlConfig {
166    #[serde(skip_serializing_if = "Option::is_none")]
167    url: Option<String>,
168    #[serde(
169        skip_serializing_if = "Option::is_none",
170        default,
171        rename = "baseUrl",
172        alias = "baseurl",
173        alias = "base_url"
174    )]
175    base_url: Option<String>,
176    #[serde(default, rename = "type", alias = "Type")]
177    r#type: Option<String>,
178    pub headers: Option<HashMap<String, String>>,
179    #[serde(default, alias = "authToken", alias = "auth_token")]
180    pub auth_token: Option<String>,
181    pub timeout: Option<u64>,
182}
183
184impl UrlConfig {
185    fn get_url(&self) -> Option<&str> {
186        self.url.as_deref().or(self.base_url.as_deref())
187    }
188}
189
190/// 解析后的配置源
191enum McpConfigSource {
192    /// 直接 URL 模式(命令行参数)
193    DirectUrl {
194        url: String,
195    },
196    /// 远程服务配置(JSON 配置)
197    RemoteService {
198        name: String,
199        url: String,
200        protocol: Option<super::protocol::McpProtocol>,
201        headers: HashMap<String, String>,
202        timeout: Option<u64>,
203    },
204    /// 本地命令配置(JSON 配置)
205    LocalCommand {
206        name: String,
207        command: String,
208        args: Vec<String>,
209        env: HashMap<String, String>,
210    },
211}
212
213/// 解析 convert 命令的配置
214fn parse_convert_config(args: &ConvertArgs) -> Result<McpConfigSource> {
215    // 优先级:url > config > config_file
216    if let Some(ref url) = args.url {
217        return Ok(McpConfigSource::DirectUrl { url: url.clone() });
218    }
219
220    // 读取 JSON 配置
221    let json_str = if let Some(ref config) = args.config {
222        config.clone()
223    } else if let Some(ref path) = args.config_file {
224        std::fs::read_to_string(path)
225            .map_err(|e| anyhow::anyhow!("读取配置文件失败: {}", e))?
226    } else {
227        bail!("必须提供 URL、--config 或 --config-file 参数之一");
228    };
229
230    // 解析 JSON 配置
231    let mcp_config: McpConfig = serde_json::from_str(&json_str)
232        .map_err(|e| anyhow::anyhow!(
233            "配置解析失败: {}。配置必须是标准 MCP 格式,包含 mcpServers 字段",
234            e
235        ))?;
236
237    let servers = mcp_config.mcp_servers;
238
239    if servers.is_empty() {
240        bail!("配置中没有找到任何 MCP 服务");
241    }
242
243    // 选择服务
244    let (name, inner_config) = if servers.len() == 1 {
245        servers.into_iter().next().unwrap()
246    } else if let Some(ref name) = args.name {
247        let config = servers.get(name)
248            .cloned()
249            .ok_or_else(|| anyhow::anyhow!(
250                "服务 '{}' 不存在。可用服务: {:?}",
251                name,
252                servers.keys().collect::<Vec<_>>()
253            ))?;
254        (name.clone(), config)
255    } else {
256        bail!(
257            "配置包含多个服务 {:?},请使用 --name 指定要使用的服务",
258            servers.keys().collect::<Vec<_>>()
259        );
260    };
261
262    // 根据配置类型返回
263    match inner_config {
264        McpServerInnerConfig::Command(stdio) => {
265            Ok(McpConfigSource::LocalCommand {
266                name,
267                command: stdio.command,
268                args: stdio.args.unwrap_or_default(),
269                env: stdio.env.unwrap_or_default(),
270            })
271        }
272        McpServerInnerConfig::Url(url_config) => {
273            let url = url_config.get_url()
274                .ok_or_else(|| anyhow::anyhow!("URL 配置缺少 url 或 baseUrl 字段"))?
275                .to_string();
276
277            // 解析协议类型
278            let protocol = url_config.r#type.as_ref().and_then(|t| {
279                match t.as_str() {
280                    "sse" => Some(super::protocol::McpProtocol::Sse),
281                    "http" | "stream" => Some(super::protocol::McpProtocol::Stream),
282                    _ => None,
283                }
284            });
285
286            // 合并 headers:JSON 配置中的 auth_token -> Authorization
287            let mut headers = url_config.headers.clone().unwrap_or_default();
288            if let Some(auth_token) = &url_config.auth_token {
289                headers.insert("Authorization".to_string(), auth_token.clone());
290            }
291
292            Ok(McpConfigSource::RemoteService {
293                name,
294                url,
295                protocol,
296                headers,
297                timeout: url_config.timeout,
298            })
299        }
300    }
301}
302
303/// 合并 headers:JSON 配置 + 命令行参数(命令行优先)
304fn merge_headers(
305    config_headers: HashMap<String, String>,
306    cli_headers: &[(String, String)],
307    cli_auth: Option<&String>,
308) -> HashMap<String, String> {
309    let mut merged = config_headers;
310
311    // 命令行 -H 参数覆盖配置
312    for (key, value) in cli_headers {
313        merged.insert(key.clone(), value.clone());
314    }
315
316    // 命令行 --auth 参数优先级最高
317    if let Some(auth_value) = cli_auth {
318        merged.insert("Authorization".to_string(), auth_value.clone());
319    }
320
321    merged
322}
323
324/// 运行 CLI 主逻辑
325pub async fn run_cli(cli: Cli) -> Result<()> {
326    match cli.command {
327        Some(Commands::Convert(args)) => {
328            run_convert_command(args, cli.verbose, cli.quiet).await
329        }
330        Some(Commands::Check(args)) => {
331            run_check_command(args, cli.verbose, cli.quiet).await
332        }
333        Some(Commands::Detect(args)) => {
334            run_detect_command(args, cli.verbose, cli.quiet).await
335        }
336        Some(Commands::Proxy(args)) => {
337            super::proxy_server::run_proxy_command(args, cli.verbose, cli.quiet).await
338        }
339        None => {
340            // 直接 URL 模式(向后兼容)
341            if let Some(url) = cli.url {
342                let args = ConvertArgs {
343                    url: Some(url),
344                    config: None,
345                    config_file: None,
346                    name: None,
347                    protocol: None,
348                    auth: None,
349                    header: vec![],
350                    timeout: 30,
351                    retries: 3,
352                    allow_tools: None,
353                    deny_tools: None,
354                };
355                run_convert_command(args, cli.verbose, cli.quiet).await
356            } else {
357                bail!("请提供 URL 或使用子命令")
358            }
359        }
360    }
361}
362
363/// 运行转换命令 - 核心功能
364async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
365    // 检查 --allow-tools 和 --deny-tools 互斥
366    if args.allow_tools.is_some() && args.deny_tools.is_some() {
367        bail!("--allow-tools 和 --deny-tools 不能同时使用,请只选择其中一个");
368    }
369
370    // 创建工具过滤器
371    let tool_filter = if let Some(allow_tools) = args.allow_tools.clone() {
372        ToolFilter::allow(allow_tools)
373    } else if let Some(deny_tools) = args.deny_tools.clone() {
374        ToolFilter::deny(deny_tools)
375    } else {
376        ToolFilter::default()
377    };
378
379    // 解析配置
380    let config_source = parse_convert_config(&args)?;
381
382    // 配置客户端能力
383    let client_info = ClientInfo {
384        protocol_version: Default::default(),
385        capabilities: ClientCapabilities::builder()
386            .enable_experimental()
387            .enable_roots()
388            .enable_roots_list_changed()
389            .enable_sampling()
390            .build(),
391        ..Default::default()
392    };
393
394    // 根据配置源执行不同逻辑
395    match config_source {
396        McpConfigSource::DirectUrl { url } => {
397            // 直接 URL 模式(原有逻辑)
398            run_url_mode(&args, &url, HashMap::new(), None, client_info, tool_filter, verbose, quiet).await
399        }
400        McpConfigSource::RemoteService { name, url, protocol, headers, timeout } => {
401            // 远程服务配置模式
402            if !quiet {
403                eprintln!("🚀 MCP-Stdio-Proxy: {} ({}) → stdio", name, url);
404            }
405            // 合并 headers:配置 + 命令行
406            let merged_headers = merge_headers(headers, &args.header, args.auth.as_ref());
407            run_url_mode(&args, &url, merged_headers, protocol.or(timeout.map(|_| super::protocol::McpProtocol::Stream)), client_info, tool_filter, verbose, quiet).await
408        }
409        McpConfigSource::LocalCommand { name, command, args: cmd_args, env } => {
410            // 本地命令模式
411            run_command_mode(&name, &command, cmd_args, env, client_info, tool_filter, verbose, quiet).await
412        }
413    }
414}
415
416/// URL 模式执行(远程 HTTP/SSE 服务)
417async fn run_url_mode(
418    args: &ConvertArgs,
419    url: &str,
420    merged_headers: HashMap<String, String>,
421    config_protocol: Option<super::protocol::McpProtocol>,
422    client_info: ClientInfo,
423    tool_filter: ToolFilter,
424    verbose: bool,
425    quiet: bool,
426) -> Result<()> {
427    if !quiet && merged_headers.is_empty() {
428        eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", url);
429    }
430
431    if verbose && !quiet {
432        eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
433    }
434
435    // 显示过滤器配置
436    if !quiet {
437        if let Some(ref allow_tools) = args.allow_tools {
438            eprintln!("🔧 工具白名单: {:?}", allow_tools);
439        }
440        if let Some(ref deny_tools) = args.deny_tools {
441            eprintln!("🔧 工具黑名单: {:?}", deny_tools);
442        }
443    }
444
445    // 确定协议类型:命令行参数 > 配置文件 > 自动检测
446    let protocol = if let Some(ref proto) = args.protocol {
447        // 命令行指定协议
448        let detected = match proto {
449            super::proxy_server::ProxyProtocol::Sse => super::protocol::McpProtocol::Sse,
450            super::proxy_server::ProxyProtocol::Stream => super::protocol::McpProtocol::Stream,
451        };
452        if !quiet {
453            eprintln!("🔧 使用指定协议: {}", protocol_name(&detected));
454        }
455        detected
456    } else if let Some(proto) = config_protocol {
457        // 配置文件指定协议
458        if !quiet {
459            eprintln!("🔧 使用配置协议: {}", protocol_name(&proto));
460        }
461        proto
462    } else {
463        // 自动检测协议
464        let detected = super::protocol::detect_mcp_protocol(url).await?;
465        if !quiet {
466            eprintln!("🔍 检测到 {} 协议", protocol_name(&detected));
467        }
468        detected
469    };
470
471    if !quiet {
472        eprintln!("🔗 建立连接...");
473    }
474
475    // 构建带认证与自定义头的 HTTP 客户端
476    let http_client = create_http_client_with_headers(&merged_headers, &args.header, args.auth.as_ref(), args.timeout)?;
477
478    // 为不同协议创建传输并启动 rmcp 客户端
479    let running = match protocol {
480        super::protocol::McpProtocol::Sse => {
481            let cfg = SseClientConfig {
482                sse_endpoint: url.to_string().into(),
483                ..Default::default()
484            };
485            let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
486            client_info.serve(transport).await?
487        }
488        super::protocol::McpProtocol::Stream => {
489            let cfg = StreamableHttpClientTransportConfig {
490                uri: url.to_string().into(),
491                ..Default::default()
492            };
493            let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
494            client_info.serve(transport).await?
495        }
496        super::protocol::McpProtocol::Stdio => {
497            bail!("Stdio 协议不支持通过 URL 转换,请使用 --config 配置本地命令")
498        }
499    };
500
501    if !quiet {
502        eprintln!("✅ 连接成功,开始代理转换...");
503        eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
504    }
505
506    // 使用 ProxyHandler + stdio 将远程 MCP 服务透明暴露为本地 stdio
507    let proxy_handler = ProxyHandler::with_tool_filter(running, "cli".to_string(), tool_filter);
508    let stdio_transport = stdio();
509    let server = proxy_handler.serve(stdio_transport).await?;
510    server.waiting().await?;
511
512    Ok(())
513}
514
515/// 命令模式执行(本地子进程)
516async fn run_command_mode(
517    name: &str,
518    command: &str,
519    cmd_args: Vec<String>,
520    env: HashMap<String, String>,
521    client_info: ClientInfo,
522    tool_filter: ToolFilter,
523    verbose: bool,
524    quiet: bool,
525) -> Result<()> {
526    if !quiet {
527        eprintln!("🚀 MCP-Stdio-Proxy: {} (command) → stdio", name);
528        eprintln!("   命令: {} {:?}", command, cmd_args);
529        if verbose && !env.is_empty() {
530            eprintln!("   环境变量: {:?}", env);
531        }
532    }
533
534    // 显示过滤器配置
535    if !quiet {
536        if tool_filter.is_enabled() {
537            eprintln!("🔧 工具过滤已启用");
538        }
539    }
540
541    // 创建子进程命令
542    let mut cmd = Command::new(command);
543    cmd.args(&cmd_args);
544    for (k, v) in &env {
545        cmd.env(k, v);
546    }
547
548    // 启动子进程
549    let tokio_process = TokioChildProcess::new(cmd)?;
550
551    if !quiet {
552        eprintln!("🔗 启动子进程...");
553    }
554
555    // 连接到子进程
556    let running = client_info.serve(tokio_process).await?;
557
558    if !quiet {
559        eprintln!("✅ 子进程已启动,开始代理转换...");
560        eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
561    }
562
563    // 使用 ProxyHandler + stdio 将本地 MCP 服务透明暴露为 stdio
564    let proxy_handler = ProxyHandler::with_tool_filter(running, name.to_string(), tool_filter);
565    let stdio_transport = stdio();
566    let server = proxy_handler.serve(stdio_transport).await?;
567    server.waiting().await?;
568
569    Ok(())
570}
571
572/// 获取协议名称
573fn protocol_name(protocol: &super::protocol::McpProtocol) -> &'static str {
574    match protocol {
575        super::protocol::McpProtocol::Sse => "SSE",
576        super::protocol::McpProtocol::Stream => "Streamable HTTP",
577        super::protocol::McpProtocol::Stdio => "Stdio",
578    }
579}
580
581/// 创建 HTTP 客户端(使用合并后的 headers)
582fn create_http_client_with_headers(
583    config_headers: &HashMap<String, String>,
584    cli_headers: &[(String, String)],
585    cli_auth: Option<&String>,
586    timeout: u64,
587) -> Result<reqwest::Client> {
588    let mut headers = reqwest::header::HeaderMap::new();
589
590    // 1. 先添加配置中的 headers
591    for (key, value) in config_headers {
592        headers.insert(
593            key.parse::<reqwest::header::HeaderName>()?,
594            value.parse()?,
595        );
596    }
597
598    // 2. 命令行 -H 参数覆盖
599    for (key, value) in cli_headers {
600        headers.insert(
601            key.parse::<reqwest::header::HeaderName>()?,
602            value.parse()?,
603        );
604    }
605
606    // 3. 命令行 --auth 参数优先级最高
607    if let Some(auth) = cli_auth {
608        headers.insert("Authorization", auth.parse()?);
609    }
610
611    let client = reqwest::Client::builder()
612        .default_headers(headers)
613        .timeout(tokio::time::Duration::from_secs(timeout))
614        .build()?;
615
616    Ok(client)
617}
618
619/// 运行检查命令
620async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
621    if !quiet {
622        eprintln!("🔍 检查服务: {}", args.url);
623    }
624
625    match super::protocol::detect_mcp_protocol(&args.url).await {
626        Ok(protocol) => {
627            if !quiet {
628                eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
629            }
630            Ok(())
631        }
632        Err(e) => {
633            if !quiet {
634                eprintln!("❌ 服务检查失败: {}", e);
635            }
636            Err(e)
637        }
638    }
639}
640
641/// 运行协议检测命令
642async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
643    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
644
645    if quiet {
646        println!("{}", protocol);
647    } else {
648        eprintln!("{}", protocol);
649    }
650
651    Ok(())
652}