1use clap::Parser;
5use anyhow::{Result, bail};
6
7use rmcp::{
8 ServiceExt,
9 model::{ClientCapabilities, ClientInfo},
10 transport::{SseClientTransport, StreamableHttpClientTransport, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
11};
12use crate::proxy::{ProxyHandler, ToolFilter};
13
14#[derive(Parser, Debug)]
16#[command(name = "mcp-proxy")]
17#[command(version = env!("CARGO_PKG_VERSION"))]
18#[command(about = "MCP 协议转换代理工具", long_about = None)]
19pub struct Cli {
20 #[command(subcommand)]
21 pub command: Option<Commands>,
22
23 #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
25 pub url: Option<String>,
26
27 #[arg(short, long, global = true)]
29 pub verbose: bool,
30
31 #[arg(short, long, global = true)]
33 pub quiet: bool,
34}
35
36#[derive(clap::Subcommand, Debug)]
37pub enum Commands {
38 Convert(ConvertArgs),
40
41 Check(CheckArgs),
43
44 Detect(DetectArgs),
46
47 Proxy(super::proxy_server::ProxyArgs),
49}
50
// Arguments of the `convert` subcommand: connect to a remote MCP service and
// expose it over stdio. Also built by hand in `run_cli` for direct mode.
// Plain `//` comments are used on clap-derived items because clap turns `///`
// doc comments into `--help` text.
#[derive(Parser, Debug, Clone)]
pub struct ConvertArgs {
    // Remote MCP endpoint URL.
    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
    pub url: String,

    // Explicit remote protocol. When `None`, `run_convert_command` calls
    // `detect_mcp_protocol` on the URL instead.
    #[arg(long, value_enum, help = "指定远程服务协议类型(不指定则自动检测)")]
    pub protocol: Option<super::proxy_server::ProxyProtocol>,

    // Value sent as the `Authorization` header by `create_http_client`.
    #[arg(short, long, help = "认证 header")]
    pub auth: Option<String>,

    // Extra `-H KEY=VALUE` headers, split by `parse_key_val` on the first `=`.
    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
    pub header: Vec<(String, String)>,

    // Timeout in seconds handed to the HTTP client built by `create_http_client`.
    #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
    pub timeout: u64,

    // Retry count (`--retries`, default 3).
    #[arg(long, default_value = "3", help = "重试次数")]
    pub retries: u32,

    // Comma-separated tool allowlist. Mutually exclusive with `deny_tools`;
    // the exclusivity is enforced at runtime in `run_convert_command`, not by
    // clap.
    #[arg(long, value_delimiter = ',', help = "工具白名单(逗号分隔),只允许指定的工具")]
    pub allow_tools: Option<Vec<String>>,

    // Comma-separated tool denylist; see `allow_tools` for exclusivity.
    #[arg(long, value_delimiter = ',', help = "工具黑名单(逗号分隔),排除指定的工具")]
    pub deny_tools: Option<Vec<String>>,
}
86
// Arguments of the `check` subcommand.
// Plain `//` comments are used on clap-derived items because clap turns `///`
// doc comments into `--help` text.
#[derive(Parser, Debug)]
pub struct CheckArgs {
    // URL of the MCP service to check.
    #[arg(value_name = "URL")]
    pub url: String,

    // Optional auth header value.
    // NOTE(review): not read by `run_check_command`; confirm whether it should
    // be forwarded to protocol detection.
    #[arg(short, long)]
    pub auth: Option<String>,

    // Timeout in seconds (`--timeout`, default 10).
    #[arg(long, default_value = "10")]
    pub timeout: u64,
}
102
// Arguments of the `detect` subcommand.
// Plain `//` comments are used on clap-derived items because clap turns `///`
// doc comments into `--help` text.
#[derive(Parser, Debug)]
pub struct DetectArgs {
    // URL of the MCP service whose protocol should be detected.
    #[arg(value_name = "URL")]
    pub url: String,

    // Optional auth header value.
    // NOTE(review): not read by `run_detect_command`; confirm whether it
    // should be forwarded to protocol detection.
    #[arg(short, long)]
    pub auth: Option<String>,
}
114
115fn parse_key_val(s: &str) -> Result<(String, String)> {
117 let pos = s.find('=')
118 .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
119 Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
120}
121
122pub async fn run_cli(cli: Cli) -> Result<()> {
124 match cli.command {
125 Some(Commands::Convert(args)) => {
126 run_convert_command(args, cli.verbose, cli.quiet).await
127 }
128 Some(Commands::Check(args)) => {
129 run_check_command(args, cli.verbose, cli.quiet).await
130 }
131 Some(Commands::Detect(args)) => {
132 run_detect_command(args, cli.verbose, cli.quiet).await
133 }
134 Some(Commands::Proxy(args)) => {
135 super::proxy_server::run_proxy_command(args, cli.verbose, cli.quiet).await
136 }
137 None => {
138 if let Some(url) = cli.url {
140 let args = ConvertArgs {
141 url,
142 protocol: None,
143 auth: None,
144 header: vec![],
145 timeout: 30,
146 retries: 3,
147 allow_tools: None,
148 deny_tools: None,
149 };
150 run_convert_command(args, cli.verbose, cli.quiet).await
151 } else {
152 bail!("请提供 URL 或使用子命令")
153 }
154 }
155 }
156}
157
158async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
160 if args.allow_tools.is_some() && args.deny_tools.is_some() {
162 bail!("--allow-tools 和 --deny-tools 不能同时使用,请只选择其中一个");
163 }
164
165 let tool_filter = if let Some(allow_tools) = args.allow_tools.clone() {
167 ToolFilter::allow(allow_tools)
168 } else if let Some(deny_tools) = args.deny_tools.clone() {
169 ToolFilter::deny(deny_tools)
170 } else {
171 ToolFilter::default()
172 };
173
174 if !quiet {
175 eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", args.url);
176 if verbose {
177 eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
178 }
179 if let Some(ref allow_tools) = args.allow_tools {
181 eprintln!("🔧 工具白名单: {:?}", allow_tools);
182 }
183 if let Some(ref deny_tools) = args.deny_tools {
184 eprintln!("🔧 工具黑名单: {:?}", deny_tools);
185 }
186 }
187
188 let protocol = if let Some(ref proto) = args.protocol {
190 let detected = match proto {
192 super::proxy_server::ProxyProtocol::Sse => super::protocol::McpProtocol::Sse,
193 super::proxy_server::ProxyProtocol::Stream => super::protocol::McpProtocol::Stream,
194 };
195 if !quiet {
196 let proto_name = match detected {
197 super::protocol::McpProtocol::Sse => "SSE",
198 super::protocol::McpProtocol::Stream => "Streamable HTTP",
199 super::protocol::McpProtocol::Stdio => "Stdio",
200 };
201 eprintln!("🔧 使用指定协议: {}", proto_name);
202 }
203 detected
204 } else {
205 let detected = super::protocol::detect_mcp_protocol(&args.url).await?;
207 if !quiet {
208 let proto_name = match detected {
209 super::protocol::McpProtocol::Sse => "SSE",
210 super::protocol::McpProtocol::Stream => "Streamable HTTP",
211 super::protocol::McpProtocol::Stdio => "Stdio",
212 };
213 eprintln!("🔍 检测到 {} 协议", proto_name);
214 }
215 detected
216 };
217
218 if !quiet {
219 eprintln!("🔗 建立连接...");
220 }
221
222 let http_client = create_http_client(&args)?;
224
225 let client_info = ClientInfo {
227 protocol_version: Default::default(),
228 capabilities: ClientCapabilities::builder()
229 .enable_experimental()
230 .enable_roots()
231 .enable_roots_list_changed()
232 .enable_sampling()
233 .build(),
234 ..Default::default()
235 };
236
237 let running = match protocol {
239 super::protocol::McpProtocol::Sse => {
240 let cfg = SseClientConfig {
241 sse_endpoint: args.url.clone().into(),
242 ..Default::default()
243 };
244 let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
245 client_info.serve(transport).await?
246 }
247 super::protocol::McpProtocol::Stream => {
248 let cfg = StreamableHttpClientTransportConfig {
249 uri: args.url.clone().into(),
250 ..Default::default()
251 };
252 let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
253 client_info.serve(transport).await?
254 }
255 super::protocol::McpProtocol::Stdio => {
256 bail!("Stdio 协议不支持通过 URL 转换,请使用命令行模式")
257 }
258 };
259
260 if !quiet {
261 eprintln!("✅ 连接成功,开始代理转换...");
262 eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
263 }
264
265 let proxy_handler = ProxyHandler::with_tool_filter(running, "cli".to_string(), tool_filter);
267 let stdio_transport = stdio();
268 let server = proxy_handler.serve(stdio_transport).await?;
269 server.waiting().await?;
270
271 Ok(())
272}
273
274
275
276
277
278
279fn create_http_client(args: &ConvertArgs) -> Result<reqwest::Client> {
281 let mut headers = reqwest::header::HeaderMap::new();
282
283 if let Some(auth) = &args.auth {
285 headers.insert("Authorization", auth.parse()?);
286 }
287
288 for (key, value) in &args.header {
290 headers.insert(key.parse::<reqwest::header::HeaderName>()?, value.parse()?);
291 }
292
293 let client = reqwest::Client::builder()
294 .default_headers(headers)
295 .timeout(tokio::time::Duration::from_secs(args.timeout))
296 .build()?;
297
298 Ok(client)
299}
300
301
302
303async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
305 if !quiet {
306 eprintln!("🔍 检查服务: {}", args.url);
307 }
308
309 match super::protocol::detect_mcp_protocol(&args.url).await {
310 Ok(protocol) => {
311 if !quiet {
312 eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
313 }
314 Ok(())
315 }
316 Err(e) => {
317 if !quiet {
318 eprintln!("❌ 服务检查失败: {}", e);
319 }
320 Err(e)
321 }
322 }
323}
324
325async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
327 let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
328
329 if quiet {
330 println!("{}", protocol);
331 } else {
332 eprintln!("{}", protocol);
333 }
334
335 Ok(())
336}