mcp_stdio_proxy/client/
cli.rs

1// MCP-Proxy CLI 简化实现 - 修复版本
2// 直接使用 rmcp 库的功能,无需复杂的 trait 抽象
3
4use clap::Parser;
5use anyhow::{Result, bail};
6
7use rmcp::{
8    ServiceExt,
9    model::{ClientCapabilities, ClientInfo},
10    transport::{SseClientTransport, StreamableHttpClientTransport, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
11};
12use crate::proxy::{ProxyHandler, ToolFilter};
13
14/// MCP-Proxy CLI 主命令结构
15#[derive(Parser, Debug)]
16#[command(name = "mcp-proxy")]
17#[command(version = env!("CARGO_PKG_VERSION"))]
18#[command(about = "MCP 协议转换代理工具", long_about = None)]
19pub struct Cli {
20    #[command(subcommand)]
21    pub command: Option<Commands>,
22    
23    /// 直接URL模式(向后兼容)
24    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
25    pub url: Option<String>,
26    
27    /// 全局详细输出
28    #[arg(short, long, global = true)]
29    pub verbose: bool,
30    
31    /// 全局静默模式
32    #[arg(short, long, global = true)]
33    pub quiet: bool,
34}
35
36#[derive(clap::Subcommand, Debug)]
37pub enum Commands {
38    /// 协议转换模式 - 将 URL 转换为 stdio
39    Convert(ConvertArgs),
40    
41    /// 检查服务状态
42    Check(CheckArgs),
43    
44    /// 协议检测
45    Detect(DetectArgs),
46}
47
48/// 协议转换参数
49#[derive(Parser, Debug, Clone)]
50pub struct ConvertArgs {
51    /// MCP 服务的 URL 地址
52    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
53    pub url: String,
54
55    /// 认证 header (如: "Bearer token")
56    #[arg(short, long, help = "认证 header")]
57    pub auth: Option<String>,
58
59    /// 自定义 HTTP headers
60    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
61    pub header: Vec<(String, String)>,
62
63    /// 连接超时时间(秒)
64    #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
65    pub timeout: u64,
66
67    /// 重试次数
68    #[arg(long, default_value = "3", help = "重试次数")]
69    pub retries: u32,
70
71    /// 工具白名单(逗号分隔),只允许指定的工具
72    #[arg(long, value_delimiter = ',', help = "工具白名单(逗号分隔),只允许指定的工具")]
73    pub allow_tools: Option<Vec<String>>,
74
75    /// 工具黑名单(逗号分隔),排除指定的工具
76    #[arg(long, value_delimiter = ',', help = "工具黑名单(逗号分隔),排除指定的工具")]
77    pub deny_tools: Option<Vec<String>>,
78}
79
80/// 检查参数
81#[derive(Parser, Debug)]
82pub struct CheckArgs {
83    /// 要检查的 MCP 服务 URL
84    #[arg(value_name = "URL")]
85    pub url: String,
86    
87    /// 认证 header
88    #[arg(short, long)]
89    pub auth: Option<String>,
90    
91    /// 超时时间
92    #[arg(long, default_value = "10")]
93    pub timeout: u64,
94}
95
96/// 协议检测参数
97#[derive(Parser, Debug)]
98pub struct DetectArgs {
99    /// 要检测的 MCP 服务 URL
100    #[arg(value_name = "URL")]
101    pub url: String,
102    
103    /// 认证 header
104    #[arg(short, long)]
105    pub auth: Option<String>,
106}
107
108/// 解析 KEY=VALUE 格式的辅助函数
109fn parse_key_val(s: &str) -> Result<(String, String)> {
110    let pos = s.find('=')
111        .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
112    Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
113}
114
115/// 运行 CLI 主逻辑
116pub async fn run_cli(cli: Cli) -> Result<()> {
117    match cli.command {
118        Some(Commands::Convert(args)) => {
119            run_convert_command(args, cli.verbose, cli.quiet).await
120        }
121        Some(Commands::Check(args)) => {
122            run_check_command(args, cli.verbose, cli.quiet).await
123        }
124        Some(Commands::Detect(args)) => {
125            run_detect_command(args, cli.verbose, cli.quiet).await
126        }
127        None => {
128            // 直接 URL 模式(向后兼容)
129            if let Some(url) = cli.url {
130                let args = ConvertArgs {
131                    url,
132                    auth: None,
133                    header: vec![],
134                    timeout: 30,
135                    retries: 3,
136                    allow_tools: None,
137                    deny_tools: None,
138                };
139                run_convert_command(args, cli.verbose, cli.quiet).await
140            } else {
141                bail!("请提供 URL 或使用子命令")
142            }
143        }
144    }
145}
146
147/// 运行转换命令 - 核心功能
148async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
149    // 检查 --allow-tools 和 --deny-tools 互斥
150    if args.allow_tools.is_some() && args.deny_tools.is_some() {
151        bail!("--allow-tools 和 --deny-tools 不能同时使用,请只选择其中一个");
152    }
153
154    // 创建工具过滤器
155    let tool_filter = if let Some(allow_tools) = args.allow_tools.clone() {
156        ToolFilter::allow(allow_tools)
157    } else if let Some(deny_tools) = args.deny_tools.clone() {
158        ToolFilter::deny(deny_tools)
159    } else {
160        ToolFilter::default()
161    };
162
163    if !quiet {
164        eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", args.url);
165        if verbose {
166            eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
167        }
168        // 显示过滤器配置
169        if let Some(ref allow_tools) = args.allow_tools {
170            eprintln!("🔧 工具白名单: {:?}", allow_tools);
171        }
172        if let Some(ref deny_tools) = args.deny_tools {
173            eprintln!("🔧 工具黑名单: {:?}", deny_tools);
174        }
175    }
176
177    // 使用统一的协议检测模块
178    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
179
180    if !quiet {
181        let proto_name = match protocol {
182            super::protocol::McpProtocol::Sse => "SSE",
183            super::protocol::McpProtocol::Stream => "Streamable HTTP",
184            super::protocol::McpProtocol::Stdio => "Stdio",
185        };
186        eprintln!("🔍 检测到 {} 协议", proto_name);
187        eprintln!("🔗 建立连接...");
188    }
189
190    // 构建带认证与自定义头的 HTTP 客户端
191    let http_client = create_http_client(&args)?;
192
193    // 配置客户端能力
194    let client_info = ClientInfo {
195        protocol_version: Default::default(),
196        capabilities: ClientCapabilities::builder()
197            .enable_experimental()
198            .enable_roots()
199            .enable_roots_list_changed()
200            .enable_sampling()
201            .build(),
202        ..Default::default()
203    };
204
205    // 为不同协议创建传输并启动 rmcp 客户端
206    let running = match protocol {
207        super::protocol::McpProtocol::Sse => {
208            let cfg = SseClientConfig {
209                sse_endpoint: args.url.clone().into(),
210                ..Default::default()
211            };
212            let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
213            client_info.serve(transport).await?
214        }
215        super::protocol::McpProtocol::Stream => {
216            let cfg = StreamableHttpClientTransportConfig {
217                uri: args.url.clone().into(),
218                ..Default::default()
219            };
220            let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
221            client_info.serve(transport).await?
222        }
223        super::protocol::McpProtocol::Stdio => {
224            bail!("Stdio 协议不支持通过 URL 转换,请使用命令行模式")
225        }
226    };
227
228    if !quiet {
229        eprintln!("✅ 连接成功,开始代理转换...");
230        eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
231    }
232
233    // 使用 ProxyHandler + stdio 将远程 MCP 服务透明暴露为本地 stdio
234    let proxy_handler = ProxyHandler::with_tool_filter(running, "cli".to_string(), tool_filter);
235    let stdio_transport = stdio();
236    let server = proxy_handler.serve(stdio_transport).await?;
237    server.waiting().await?;
238
239    Ok(())
240}
241
242
243
244
245
246
247/// 创建 HTTP 客户端
248fn create_http_client(args: &ConvertArgs) -> Result<reqwest::Client> {
249    let mut headers = reqwest::header::HeaderMap::new();
250    
251    // 添加认证 header
252    if let Some(auth) = &args.auth {
253        headers.insert("Authorization", auth.parse()?);
254    }
255    
256    // 添加自定义 headers
257    for (key, value) in &args.header {
258        headers.insert(key.parse::<reqwest::header::HeaderName>()?, value.parse()?);
259    }
260    
261    let client = reqwest::Client::builder()
262        .default_headers(headers)
263        .timeout(tokio::time::Duration::from_secs(args.timeout))
264        .build()?;
265    
266    Ok(client)
267}
268
269
270
271/// 运行检查命令
272async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
273    if !quiet {
274        eprintln!("🔍 检查服务: {}", args.url);
275    }
276
277    match super::protocol::detect_mcp_protocol(&args.url).await {
278        Ok(protocol) => {
279            if !quiet {
280                eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
281            }
282            Ok(())
283        }
284        Err(e) => {
285            if !quiet {
286                eprintln!("❌ 服务检查失败: {}", e);
287            }
288            Err(e)
289        }
290    }
291}
292
293/// 运行协议检测命令
294async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
295    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
296
297    if quiet {
298        println!("{}", protocol);
299    } else {
300        eprintln!("{}", protocol);
301    }
302
303    Ok(())
304}