mcp_stdio_proxy/client/
cli.rs

1// MCP-Proxy CLI 简化实现 - 修复版本
2// 直接使用 rmcp 库的功能,无需复杂的 trait 抽象
3
4use clap::Parser;
5use anyhow::{Result, bail};
6
7use rmcp::{
8    ServiceExt,
9    model::{ClientCapabilities, ClientInfo},
10    transport::{SseClientTransport, StreamableHttpClientTransport, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
11};
12use crate::proxy::{ProxyHandler, ToolFilter};
13
14/// MCP-Proxy CLI 主命令结构
15#[derive(Parser, Debug)]
16#[command(name = "mcp-proxy")]
17#[command(version = env!("CARGO_PKG_VERSION"))]
18#[command(about = "MCP 协议转换代理工具", long_about = None)]
19pub struct Cli {
20    #[command(subcommand)]
21    pub command: Option<Commands>,
22    
23    /// 直接URL模式(向后兼容)
24    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
25    pub url: Option<String>,
26    
27    /// 全局详细输出
28    #[arg(short, long, global = true)]
29    pub verbose: bool,
30    
31    /// 全局静默模式
32    #[arg(short, long, global = true)]
33    pub quiet: bool,
34}
35
36#[derive(clap::Subcommand, Debug)]
37pub enum Commands {
38    /// 协议转换模式 - 将 URL 转换为 stdio
39    Convert(ConvertArgs),
40
41    /// 检查服务状态
42    Check(CheckArgs),
43
44    /// 协议检测
45    Detect(DetectArgs),
46
47    /// 代理模式 - 将 stdio MCP 服务代理为 HTTP/SSE 服务
48    Proxy(super::proxy_server::ProxyArgs),
49}
50
51/// 协议转换参数
52#[derive(Parser, Debug, Clone)]
53pub struct ConvertArgs {
54    /// MCP 服务的 URL 地址
55    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
56    pub url: String,
57
58    /// 指定远程服务协议类型(不指定则自动检测)
59    #[arg(long, value_enum, help = "指定远程服务协议类型(不指定则自动检测)")]
60    pub protocol: Option<super::proxy_server::ProxyProtocol>,
61
62    /// 认证 header (如: "Bearer token")
63    #[arg(short, long, help = "认证 header")]
64    pub auth: Option<String>,
65
66    /// 自定义 HTTP headers
67    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
68    pub header: Vec<(String, String)>,
69
70    /// 连接超时时间(秒)
71    #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
72    pub timeout: u64,
73
74    /// 重试次数
75    #[arg(long, default_value = "3", help = "重试次数")]
76    pub retries: u32,
77
78    /// 工具白名单(逗号分隔),只允许指定的工具
79    #[arg(long, value_delimiter = ',', help = "工具白名单(逗号分隔),只允许指定的工具")]
80    pub allow_tools: Option<Vec<String>>,
81
82    /// 工具黑名单(逗号分隔),排除指定的工具
83    #[arg(long, value_delimiter = ',', help = "工具黑名单(逗号分隔),排除指定的工具")]
84    pub deny_tools: Option<Vec<String>>,
85}
86
87/// 检查参数
88#[derive(Parser, Debug)]
89pub struct CheckArgs {
90    /// 要检查的 MCP 服务 URL
91    #[arg(value_name = "URL")]
92    pub url: String,
93    
94    /// 认证 header
95    #[arg(short, long)]
96    pub auth: Option<String>,
97    
98    /// 超时时间
99    #[arg(long, default_value = "10")]
100    pub timeout: u64,
101}
102
103/// 协议检测参数
104#[derive(Parser, Debug)]
105pub struct DetectArgs {
106    /// 要检测的 MCP 服务 URL
107    #[arg(value_name = "URL")]
108    pub url: String,
109    
110    /// 认证 header
111    #[arg(short, long)]
112    pub auth: Option<String>,
113}
114
115/// 解析 KEY=VALUE 格式的辅助函数
116fn parse_key_val(s: &str) -> Result<(String, String)> {
117    let pos = s.find('=')
118        .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
119    Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
120}
121
122/// 运行 CLI 主逻辑
123pub async fn run_cli(cli: Cli) -> Result<()> {
124    match cli.command {
125        Some(Commands::Convert(args)) => {
126            run_convert_command(args, cli.verbose, cli.quiet).await
127        }
128        Some(Commands::Check(args)) => {
129            run_check_command(args, cli.verbose, cli.quiet).await
130        }
131        Some(Commands::Detect(args)) => {
132            run_detect_command(args, cli.verbose, cli.quiet).await
133        }
134        Some(Commands::Proxy(args)) => {
135            super::proxy_server::run_proxy_command(args, cli.verbose, cli.quiet).await
136        }
137        None => {
138            // 直接 URL 模式(向后兼容)
139            if let Some(url) = cli.url {
140                let args = ConvertArgs {
141                    url,
142                    protocol: None,
143                    auth: None,
144                    header: vec![],
145                    timeout: 30,
146                    retries: 3,
147                    allow_tools: None,
148                    deny_tools: None,
149                };
150                run_convert_command(args, cli.verbose, cli.quiet).await
151            } else {
152                bail!("请提供 URL 或使用子命令")
153            }
154        }
155    }
156}
157
158/// 运行转换命令 - 核心功能
159async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
160    // 检查 --allow-tools 和 --deny-tools 互斥
161    if args.allow_tools.is_some() && args.deny_tools.is_some() {
162        bail!("--allow-tools 和 --deny-tools 不能同时使用,请只选择其中一个");
163    }
164
165    // 创建工具过滤器
166    let tool_filter = if let Some(allow_tools) = args.allow_tools.clone() {
167        ToolFilter::allow(allow_tools)
168    } else if let Some(deny_tools) = args.deny_tools.clone() {
169        ToolFilter::deny(deny_tools)
170    } else {
171        ToolFilter::default()
172    };
173
174    if !quiet {
175        eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", args.url);
176        if verbose {
177            eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
178        }
179        // 显示过滤器配置
180        if let Some(ref allow_tools) = args.allow_tools {
181            eprintln!("🔧 工具白名单: {:?}", allow_tools);
182        }
183        if let Some(ref deny_tools) = args.deny_tools {
184            eprintln!("🔧 工具黑名单: {:?}", deny_tools);
185        }
186    }
187
188    // 确定协议类型:手动指定或自动检测
189    let protocol = if let Some(ref proto) = args.protocol {
190        // 手动指定协议
191        let detected = match proto {
192            super::proxy_server::ProxyProtocol::Sse => super::protocol::McpProtocol::Sse,
193            super::proxy_server::ProxyProtocol::Stream => super::protocol::McpProtocol::Stream,
194        };
195        if !quiet {
196            let proto_name = match detected {
197                super::protocol::McpProtocol::Sse => "SSE",
198                super::protocol::McpProtocol::Stream => "Streamable HTTP",
199                super::protocol::McpProtocol::Stdio => "Stdio",
200            };
201            eprintln!("🔧 使用指定协议: {}", proto_name);
202        }
203        detected
204    } else {
205        // 自动检测协议
206        let detected = super::protocol::detect_mcp_protocol(&args.url).await?;
207        if !quiet {
208            let proto_name = match detected {
209                super::protocol::McpProtocol::Sse => "SSE",
210                super::protocol::McpProtocol::Stream => "Streamable HTTP",
211                super::protocol::McpProtocol::Stdio => "Stdio",
212            };
213            eprintln!("🔍 检测到 {} 协议", proto_name);
214        }
215        detected
216    };
217
218    if !quiet {
219        eprintln!("🔗 建立连接...");
220    }
221
222    // 构建带认证与自定义头的 HTTP 客户端
223    let http_client = create_http_client(&args)?;
224
225    // 配置客户端能力
226    let client_info = ClientInfo {
227        protocol_version: Default::default(),
228        capabilities: ClientCapabilities::builder()
229            .enable_experimental()
230            .enable_roots()
231            .enable_roots_list_changed()
232            .enable_sampling()
233            .build(),
234        ..Default::default()
235    };
236
237    // 为不同协议创建传输并启动 rmcp 客户端
238    let running = match protocol {
239        super::protocol::McpProtocol::Sse => {
240            let cfg = SseClientConfig {
241                sse_endpoint: args.url.clone().into(),
242                ..Default::default()
243            };
244            let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
245            client_info.serve(transport).await?
246        }
247        super::protocol::McpProtocol::Stream => {
248            let cfg = StreamableHttpClientTransportConfig {
249                uri: args.url.clone().into(),
250                ..Default::default()
251            };
252            let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
253            client_info.serve(transport).await?
254        }
255        super::protocol::McpProtocol::Stdio => {
256            bail!("Stdio 协议不支持通过 URL 转换,请使用命令行模式")
257        }
258    };
259
260    if !quiet {
261        eprintln!("✅ 连接成功,开始代理转换...");
262        eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
263    }
264
265    // 使用 ProxyHandler + stdio 将远程 MCP 服务透明暴露为本地 stdio
266    let proxy_handler = ProxyHandler::with_tool_filter(running, "cli".to_string(), tool_filter);
267    let stdio_transport = stdio();
268    let server = proxy_handler.serve(stdio_transport).await?;
269    server.waiting().await?;
270
271    Ok(())
272}
273
274
275
276
277
278
279/// 创建 HTTP 客户端
280fn create_http_client(args: &ConvertArgs) -> Result<reqwest::Client> {
281    let mut headers = reqwest::header::HeaderMap::new();
282    
283    // 添加认证 header
284    if let Some(auth) = &args.auth {
285        headers.insert("Authorization", auth.parse()?);
286    }
287    
288    // 添加自定义 headers
289    for (key, value) in &args.header {
290        headers.insert(key.parse::<reqwest::header::HeaderName>()?, value.parse()?);
291    }
292    
293    let client = reqwest::Client::builder()
294        .default_headers(headers)
295        .timeout(tokio::time::Duration::from_secs(args.timeout))
296        .build()?;
297    
298    Ok(client)
299}
300
301
302
303/// 运行检查命令
304async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
305    if !quiet {
306        eprintln!("🔍 检查服务: {}", args.url);
307    }
308
309    match super::protocol::detect_mcp_protocol(&args.url).await {
310        Ok(protocol) => {
311            if !quiet {
312                eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
313            }
314            Ok(())
315        }
316        Err(e) => {
317            if !quiet {
318                eprintln!("❌ 服务检查失败: {}", e);
319            }
320            Err(e)
321        }
322    }
323}
324
325/// 运行协议检测命令
326async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
327    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
328
329    if quiet {
330        println!("{}", protocol);
331    } else {
332        eprintln!("{}", protocol);
333    }
334
335    Ok(())
336}