mcp_stdio_proxy/client/
cli.rs

1// MCP-Proxy CLI 简化实现 - 修复版本
2// 直接使用 rmcp 库的功能,无需复杂的 trait 抽象
3
4use clap::Parser;
5use anyhow::{Result, bail};
6
7use rmcp::{
8    ServiceExt,
9    model::{ClientCapabilities, ClientInfo},
10    transport::{SseClientTransport, StreamableHttpClientTransport, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
11};
12use crate::proxy::ProxyHandler;
13
// Top-level command-line interface, parsed by clap.
// The `///` doc comments on fields that have no explicit `help = ...`
// (e.g. `verbose`, `quiet`) are rendered by clap as help text, so they are
// runtime strings — edit with care. Field order also determines help order.
/// MCP-Proxy CLI 主命令结构
#[derive(Parser, Debug)]
#[command(name = "mcp-proxy")]
#[command(version = env!("CARGO_PKG_VERSION"))]
#[command(about = "MCP 协议转换代理工具", long_about = None)]
pub struct Cli {
    // Optional subcommand; when absent, `run_cli` falls back to the bare `url`.
    #[command(subcommand)]
    pub command: Option<Commands>,

    // Legacy "direct URL" mode (backward compatible): `mcp-proxy <URL>` is run
    // like `convert <URL>` with auth=None, no headers, timeout=30, retries=3
    // (see `run_cli`).
    /// 直接URL模式(向后兼容)
    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
    pub url: Option<String>,

    // Global verbose flag (`global = true` makes it usable with any subcommand).
    // Only takes effect when `quiet` is not set.
    /// 全局详细输出
    #[arg(short, long, global = true)]
    pub verbose: bool,

    // Global quiet flag: suppresses the status messages the commands print
    // to stderr.
    /// 全局静默模式
    #[arg(short, long, global = true)]
    pub quiet: bool,
}
35
// Available subcommands; dispatched in `run_cli`.
// Each variant's `///` line is used by clap as the subcommand description,
// so it is a runtime string — keep it as-is.
#[derive(clap::Subcommand, Debug)]
pub enum Commands {
    // Bridge a remote MCP URL to local stdio (`run_convert_command`).
    /// 协议转换模式 - 将 URL 转换为 stdio
    Convert(ConvertArgs),

    // Check a service by running protocol detection on it (`run_check_command`).
    /// 检查服务状态
    Check(CheckArgs),

    // Detect and print the MCP protocol of a URL (`run_detect_command`).
    /// 协议检测
    Detect(DetectArgs),
}
47
// Options for the `convert` subcommand.
// `run_cli` also builds this struct by hand for the legacy bare-URL mode
// (timeout 30, retries 3), so keep those values in sync with the
// `default_value`s below.
/// 协议转换参数
#[derive(Parser, Debug, Clone)]
pub struct ConvertArgs {
    // Remote MCP endpoint (positional argument).
    /// MCP 服务的 URL 地址
    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
    pub url: String,

    // Full `Authorization` header value, e.g. "Bearer <token>"; inserted into
    // the HTTP client's default headers by `create_http_client`.
    /// 认证 header (如: "Bearer token")
    #[arg(short, long, help = "认证 header")]
    pub auth: Option<String>,

    // Extra HTTP headers; repeatable `-H KEY=VALUE`, each parsed by
    // `parse_key_val`.
    /// 自定义 HTTP headers
    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
    pub header: Vec<(String, String)>,

    // Timeout in seconds (default 30); applied to the HTTP client in
    // `create_http_client`.
    /// 连接超时时间(秒)
    #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
    pub timeout: u64,

    // Retry count (default 3); read by `run_convert_command`.
    /// 重试次数
    #[arg(long, default_value = "3", help = "重试次数")]
    pub retries: u32,
}
71
// Options for the `check` subcommand (consumed by `run_check_command`).
// Fields without an explicit `help = ...` use their `///` line as clap help text.
/// 检查参数
#[derive(Parser, Debug)]
pub struct CheckArgs {
    // Service URL to check (positional argument).
    /// 要检查的 MCP 服务 URL
    #[arg(value_name = "URL")]
    pub url: String,

    // Authorization header value.
    // NOTE(review): `run_check_command` does not forward this, because
    // `detect_mcp_protocol` is called with the URL only — TODO confirm intent.
    /// 认证 header
    #[arg(short, long)]
    pub auth: Option<String>,

    // Timeout in seconds (default 10); see `run_check_command` for how it is used.
    /// 超时时间
    #[arg(long, default_value = "10")]
    pub timeout: u64,
}
87
// Options for the `detect` subcommand (consumed by `run_detect_command`).
// Fields without an explicit `help = ...` use their `///` line as clap help text.
/// 协议检测参数
#[derive(Parser, Debug)]
pub struct DetectArgs {
    // Service URL whose protocol should be detected (positional argument).
    /// 要检测的 MCP 服务 URL
    #[arg(value_name = "URL")]
    pub url: String,

    // Authorization header value.
    // NOTE(review): `run_detect_command` does not forward this, because
    // `detect_mcp_protocol` is called with the URL only — TODO confirm intent.
    /// 认证 header
    #[arg(short, long)]
    pub auth: Option<String>,
}
99
100/// 解析 KEY=VALUE 格式的辅助函数
101fn parse_key_val(s: &str) -> Result<(String, String)> {
102    let pos = s.find('=')
103        .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
104    Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
105}
106
107/// 运行 CLI 主逻辑
108pub async fn run_cli(cli: Cli) -> Result<()> {
109    match cli.command {
110        Some(Commands::Convert(args)) => {
111            run_convert_command(args, cli.verbose, cli.quiet).await
112        }
113        Some(Commands::Check(args)) => {
114            run_check_command(args, cli.verbose, cli.quiet).await
115        }
116        Some(Commands::Detect(args)) => {
117            run_detect_command(args, cli.verbose, cli.quiet).await
118        }
119        None => {
120            // 直接 URL 模式(向后兼容)
121            if let Some(url) = cli.url {
122                let args = ConvertArgs {
123                    url,
124                    auth: None,
125                    header: vec![],
126                    timeout: 30,
127                    retries: 3,
128                };
129                run_convert_command(args, cli.verbose, cli.quiet).await
130            } else {
131                bail!("请提供 URL 或使用子命令")
132            }
133        }
134    }
135}
136
137/// 运行转换命令 - 核心功能
138async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
139    if !quiet {
140        eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", args.url);
141        if verbose {
142            eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
143        }
144    }
145
146    // 使用统一的协议检测模块
147    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
148
149    if !quiet {
150        let proto_name = match protocol {
151            super::protocol::McpProtocol::Sse => "SSE",
152            super::protocol::McpProtocol::Stream => "Streamable HTTP",
153            super::protocol::McpProtocol::Stdio => "Stdio",
154        };
155        eprintln!("🔍 检测到 {} 协议", proto_name);
156        eprintln!("🔗 建立连接...");
157    }
158
159    // 构建带认证与自定义头的 HTTP 客户端
160    let http_client = create_http_client(&args)?;
161
162    // 配置客户端能力
163    let client_info = ClientInfo {
164        protocol_version: Default::default(),
165        capabilities: ClientCapabilities::builder()
166            .enable_experimental()
167            .enable_roots()
168            .enable_roots_list_changed()
169            .enable_sampling()
170            .build(),
171        ..Default::default()
172    };
173
174    // 为不同协议创建传输并启动 rmcp 客户端
175    let running = match protocol {
176        super::protocol::McpProtocol::Sse => {
177            let cfg = SseClientConfig {
178                sse_endpoint: args.url.clone().into(),
179                ..Default::default()
180            };
181            let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
182            client_info.serve(transport).await?
183        }
184        super::protocol::McpProtocol::Stream => {
185            let cfg = StreamableHttpClientTransportConfig {
186                uri: args.url.clone().into(),
187                ..Default::default()
188            };
189            let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
190            client_info.serve(transport).await?
191        }
192        super::protocol::McpProtocol::Stdio => {
193            bail!("Stdio 协议不支持通过 URL 转换,请使用命令行模式")
194        }
195    };
196
197    if !quiet {
198        eprintln!("✅ 连接成功,开始代理转换...");
199        eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
200    }
201
202    // 使用 ProxyHandler + stdio 将远程 MCP 服务透明暴露为本地 stdio
203    let proxy_handler = ProxyHandler::new(running);
204    let stdio_transport = stdio();
205    let server = proxy_handler.serve(stdio_transport).await?;
206    server.waiting().await?;
207
208    Ok(())
209}
210
211
212
213
214
215
216/// 创建 HTTP 客户端
217fn create_http_client(args: &ConvertArgs) -> Result<reqwest::Client> {
218    let mut headers = reqwest::header::HeaderMap::new();
219    
220    // 添加认证 header
221    if let Some(auth) = &args.auth {
222        headers.insert("Authorization", auth.parse()?);
223    }
224    
225    // 添加自定义 headers
226    for (key, value) in &args.header {
227        headers.insert(key.parse::<reqwest::header::HeaderName>()?, value.parse()?);
228    }
229    
230    let client = reqwest::Client::builder()
231        .default_headers(headers)
232        .timeout(tokio::time::Duration::from_secs(args.timeout))
233        .build()?;
234    
235    Ok(client)
236}
237
238
239
240/// 运行检查命令
241async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
242    if !quiet {
243        eprintln!("🔍 检查服务: {}", args.url);
244    }
245
246    match super::protocol::detect_mcp_protocol(&args.url).await {
247        Ok(protocol) => {
248            if !quiet {
249                eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
250            }
251            Ok(())
252        }
253        Err(e) => {
254            if !quiet {
255                eprintln!("❌ 服务检查失败: {}", e);
256            }
257            Err(e)
258        }
259    }
260}
261
262/// 运行协议检测命令
263async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
264    let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
265
266    if quiet {
267        println!("{}", protocol);
268    } else {
269        eprintln!("{}", protocol);
270    }
271
272    Ok(())
273}