mcp_stdio_proxy/client/
cli.rs

1// MCP-Proxy CLI 简化实现 - 修复版本
2// 直接使用 rmcp 库的功能,无需复杂的 trait 抽象
3
4use anyhow::{Result, bail};
5use clap::Parser;
6
7use crate::proxy::ProxyHandler;
8use rmcp::{
9    ServiceExt,
10    model::{ClientCapabilities, ClientInfo},
11    transport::{
12        StreamableHttpClientTransport, stdio,
13        streamable_http_client::StreamableHttpClientTransportConfig,
14    },
15};
16
17/// MCP-Proxy CLI 主命令结构
18#[derive(Parser, Debug)]
19#[command(name = "mcp-proxy")]
20#[command(version = env!("CARGO_PKG_VERSION"))]
21#[command(about = "MCP 协议转换代理工具", long_about = None)]
22pub struct Cli {
23    #[command(subcommand)]
24    pub command: Option<Commands>,
25
26    /// 直接URL模式(向后兼容)
27    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
28    pub url: Option<String>,
29
30    /// 全局详细输出
31    #[arg(short, long, global = true)]
32    pub verbose: bool,
33
34    /// 全局静默模式
35    #[arg(short, long, global = true)]
36    pub quiet: bool,
37}
38
39#[derive(clap::Subcommand, Debug)]
40pub enum Commands {
41    /// 协议转换模式 - 将 URL 转换为 stdio
42    Convert(ConvertArgs),
43
44    /// 检查服务状态
45    Check(CheckArgs),
46
47    /// 协议检测
48    Detect(DetectArgs),
49}
50
51/// 协议转换参数
52#[derive(Parser, Debug, Clone)]
53pub struct ConvertArgs {
54    /// MCP 服务的 URL 地址
55    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
56    pub url: String,
57
58    /// 认证 header (如: "Bearer token")
59    #[arg(short, long, help = "认证 header")]
60    pub auth: Option<String>,
61
62    /// 自定义 HTTP headers
63    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
64    pub header: Vec<(String, String)>,
65
66    /// 连接超时时间(秒)
67    #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
68    pub timeout: u64,
69
70    /// 重试次数
71    #[arg(long, default_value = "3", help = "重试次数")]
72    pub retries: u32,
73}
74
75/// 检查参数
76#[derive(Parser, Debug)]
77pub struct CheckArgs {
78    /// 要检查的 MCP 服务 URL
79    #[arg(value_name = "URL")]
80    pub url: String,
81
82    /// 认证 header
83    #[arg(short, long)]
84    pub auth: Option<String>,
85
86    /// 超时时间
87    #[arg(long, default_value = "10")]
88    pub timeout: u64,
89}
90
91/// 协议检测参数
92#[derive(Parser, Debug)]
93pub struct DetectArgs {
94    /// 要检测的 MCP 服务 URL
95    #[arg(value_name = "URL")]
96    pub url: String,
97
98    /// 认证 header
99    #[arg(short, long)]
100    pub auth: Option<String>,
101}
102
103/// 解析 KEY=VALUE 格式的辅助函数
104fn parse_key_val(s: &str) -> Result<(String, String)> {
105    let pos = s
106        .find('=')
107        .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
108    Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
109}
110
111/// 运行 CLI 主逻辑
112pub async fn run_cli(cli: Cli) -> Result<()> {
113    match cli.command {
114        Some(Commands::Convert(args)) => run_convert_command(args, cli.verbose, cli.quiet).await,
115        Some(Commands::Check(args)) => run_check_command(args, cli.verbose, cli.quiet).await,
116        Some(Commands::Detect(args)) => run_detect_command(args, cli.verbose, cli.quiet).await,
117        None => {
118            // 直接 URL 模式(向后兼容)
119            if let Some(url) = cli.url {
120                let args = ConvertArgs {
121                    url,
122                    auth: None,
123                    header: vec![],
124                    timeout: 30,
125                    retries: 3,
126                };
127                run_convert_command(args, cli.verbose, cli.quiet).await
128            } else {
129                bail!("请提供 URL 或使用子命令")
130            }
131        }
132    }
133}
134
135/// 运行转换命令 - 核心功能
136async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
137    if !quiet {
138        eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", args.url);
139        if verbose {
140            eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
141        }
142    }
143
144    // 使用统一的协议检测模块
145    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
146
147    if !quiet {
148        let proto_name = match protocol {
149            super::protocol::McpProtocol::Sse => "SSE",
150            super::protocol::McpProtocol::Stream => "Streamable HTTP",
151            super::protocol::McpProtocol::Stdio => "Stdio",
152        };
153        eprintln!("🔍 检测到 {} 协议", proto_name);
154        eprintln!("🔗 建立连接...");
155    }
156
157    // 构建带认证与自定义头的 HTTP 客户端
158    let http_client = create_http_client(&args)?;
159
160    // 配置客户端能力
161    let client_info = ClientInfo {
162        protocol_version: Default::default(),
163        capabilities: ClientCapabilities::builder()
164            .enable_experimental()
165            .enable_roots()
166            .enable_roots_list_changed()
167            .enable_sampling()
168            .build(),
169        ..Default::default()
170    };
171
172    // 为不同协议创建传输并启动 rmcp 客户端
173    // 注意: rmcp 0.12 移除了 SseClientTransport,SSE 和 Stream 协议都使用 StreamableHttpClientTransport
174    let running = match protocol {
175        super::protocol::McpProtocol::Sse | super::protocol::McpProtocol::Stream => {
176            let cfg = StreamableHttpClientTransportConfig {
177                uri: args.url.clone().into(),
178                ..Default::default()
179            };
180            let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
181            client_info.serve(transport).await?
182        }
183        super::protocol::McpProtocol::Stdio => {
184            bail!("Stdio 协议不支持通过 URL 转换,请使用命令行模式")
185        }
186    };
187
188    if !quiet {
189        eprintln!("✅ 连接成功,开始代理转换...");
190        eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
191    }
192
193    // 使用 ProxyHandler + stdio 将远程 MCP 服务透明暴露为本地 stdio
194    let proxy_handler = ProxyHandler::new(running);
195    let stdio_transport = stdio();
196    let server = proxy_handler.serve(stdio_transport).await?;
197    server.waiting().await?;
198
199    Ok(())
200}
201
202/// 创建 HTTP 客户端
203fn create_http_client(args: &ConvertArgs) -> Result<reqwest::Client> {
204    let mut headers = reqwest::header::HeaderMap::new();
205
206    // 添加认证 header
207    if let Some(auth) = &args.auth {
208        headers.insert("Authorization", auth.parse()?);
209    }
210
211    // 添加自定义 headers
212    for (key, value) in &args.header {
213        headers.insert(key.parse::<reqwest::header::HeaderName>()?, value.parse()?);
214    }
215
216    let client = reqwest::Client::builder()
217        .default_headers(headers)
218        .timeout(tokio::time::Duration::from_secs(args.timeout))
219        .build()?;
220
221    Ok(client)
222}
223
224/// 运行检查命令
225async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
226    if !quiet {
227        eprintln!("🔍 检查服务: {}", args.url);
228    }
229
230    match super::protocol::detect_mcp_protocol(&args.url).await {
231        Ok(protocol) => {
232            if !quiet {
233                eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
234            }
235            Ok(())
236        }
237        Err(e) => {
238            if !quiet {
239                eprintln!("❌ 服务检查失败: {}", e);
240            }
241            Err(e)
242        }
243    }
244}
245
246/// 运行协议检测命令
247async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
248    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
249
250    if quiet {
251        println!("{}", protocol);
252    } else {
253        eprintln!("{}", protocol);
254    }
255
256    Ok(())
257}