1use std::collections::HashMap;
5
6use clap::Parser;
7use anyhow::{Result, bail};
8use serde::Deserialize;
9use tokio::process::Command;
10
11use rmcp::{
12 ServiceExt,
13 model::{ClientCapabilities, ClientInfo},
14 transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
15};
16use crate::proxy::{ProxyHandler, ToolFilter};
17
// Top-level command-line interface of `mcp-proxy`.
//
// Either a subcommand is given, or a bare `URL` is accepted as a shortcut for
// `convert <URL>` (see `run_cli`).
//
// NOTE: plain `//` comments are used deliberately in the clap-derived types:
// clap turns `///` doc comments into `--help` text, so adding them here would
// change the program's output.
#[derive(Parser, Debug)]
#[command(name = "mcp-proxy")]
#[command(version = env!("CARGO_PKG_VERSION"))]
#[command(about = "MCP 协议转换代理工具", long_about = None)]
pub struct Cli {
    // Optional subcommand; `None` falls back to direct mode using `url`.
    #[command(subcommand)]
    pub command: Option<Commands>,

    // Service URL for direct mode (only consulted when no subcommand is given).
    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
    pub url: Option<String>,

    // Verbose flag, accepted on every subcommand (`global = true`).
    #[arg(short, long, global = true)]
    pub verbose: bool,

    // Quiet flag, accepted on every subcommand. The handlers in this file gate
    // verbose output on `verbose && !quiet`, so quiet wins when both are set.
    #[arg(short, long, global = true)]
    pub quiet: bool,
}
39
// Subcommands of `mcp-proxy`; each variant is dispatched by `run_cli`.
// (`//` rather than `///` so clap does not pick these up as help text.)
#[derive(clap::Subcommand, Debug)]
pub enum Commands {
    // Bridge a remote URL or a configured MCP server to this process's stdio.
    Convert(ConvertArgs),

    // Probe a service URL and report the detected protocol (or the failure).
    Check(CheckArgs),

    // Print the detected protocol of a service URL.
    Detect(DetectArgs),

    // Delegated to `super::proxy_server::run_proxy_command`.
    Proxy(super::proxy_server::ProxyArgs),
}
54
// Arguments of the `convert` subcommand (also used by `run_cli` for direct mode).
//
// One source is used: the positional URL, `--config` (inline JSON) or
// `--config-file`; `parse_convert_config` applies that precedence.
// (`//` rather than `///` so clap does not pick these up as help text.)
#[derive(Parser, Debug, Clone)]
pub struct ConvertArgs {
    // Remote service URL; when set, the config options are ignored.
    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
    pub url: Option<String>,

    // Inline MCP config JSON (must contain `mcpServers`).
    #[arg(long, conflicts_with = "config_file", help = "MCP 服务配置 JSON")]
    pub config: Option<String>,

    // Path of a file holding the same JSON as `--config`.
    #[arg(long, conflicts_with = "config", help = "MCP 服务配置文件路径")]
    pub config_file: Option<std::path::PathBuf>,

    // Selects one entry when the config defines several servers.
    #[arg(short, long, help = "MCP 服务名称(多服务配置时必需)")]
    pub name: Option<String>,

    // Forces the remote protocol; when absent the config `type` or
    // auto-detection decides.
    #[arg(long, value_enum, help = "指定远程服务协议类型(不指定则自动检测)")]
    pub protocol: Option<super::proxy_server::ProxyProtocol>,

    // Value sent verbatim as the `Authorization` header; applied after the
    // config headers, so it overrides them.
    #[arg(short, long, help = "认证 header")]
    pub auth: Option<String>,

    // Extra headers (`-H KEY=VALUE`, repeatable), parsed by `parse_key_val`;
    // applied after the config headers, so they override them.
    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
    pub header: Vec<(String, String)>,

    // Timeout in seconds handed to the HTTP client builder.
    #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
    pub timeout: u64,

    // Retry count. NOTE(review): only echoed in verbose output — confirm
    // whether any connection path actually retries.
    #[arg(long, default_value = "3", help = "重试次数")]
    pub retries: u32,

    // Tool allow-list; mutually exclusive with `deny_tools` (enforced at run
    // time in `run_convert_command`, not by clap).
    #[arg(long, value_delimiter = ',', help = "工具白名单(逗号分隔),只允许指定的工具")]
    pub allow_tools: Option<Vec<String>>,

    // Tool deny-list; see `allow_tools`.
    #[arg(long, value_delimiter = ',', help = "工具黑名单(逗号分隔),排除指定的工具")]
    pub deny_tools: Option<Vec<String>>,
}
102
// Arguments of the `check` subcommand.
// (`//` rather than `///` so clap does not pick these up as help text.)
#[derive(Parser, Debug)]
pub struct CheckArgs {
    // Service URL to probe.
    #[arg(value_name = "URL")]
    pub url: String,

    // Authorization value. NOTE(review): `detect_mcp_protocol` is called with
    // the URL only — confirm whether this value is used anywhere.
    #[arg(short, long)]
    pub auth: Option<String>,

    // Check timeout in seconds (default 10). NOTE(review): confirm that
    // `run_check_command` applies it.
    #[arg(long, default_value = "10")]
    pub timeout: u64,
}
118
// Arguments of the `detect` subcommand.
// (`//` rather than `///` so clap does not pick these up as help text.)
#[derive(Parser, Debug)]
pub struct DetectArgs {
    // Service URL whose protocol is detected.
    #[arg(value_name = "URL")]
    pub url: String,

    // Authorization value. NOTE(review): `run_detect_command` calls
    // `detect_mcp_protocol` with the URL only — this value is not forwarded.
    #[arg(short, long)]
    pub auth: Option<String>,
}
130
131fn parse_key_val(s: &str) -> Result<(String, String)> {
133 let pos = s.find('=')
134 .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
135 Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
136}
137
/// Standard MCP client configuration document: `{"mcpServers": {"<name>": {...}}}`.
#[derive(Deserialize, Debug)]
struct McpConfig {
    /// Server entries keyed by service name (JSON key `mcpServers`).
    #[serde(rename = "mcpServers")]
    mcp_servers: HashMap<String, McpServerInnerConfig>,
}
146
/// One entry of `mcpServers`: either a local command or a remote URL.
///
/// With `untagged`, serde tries the variants in declaration order. An object
/// that deserializes as `StdioConfig` (it has a string `command`) becomes
/// `Command`; otherwise serde tries `Url`, whose fields are all optional.
/// Keep this order.
#[derive(Deserialize, Debug, Clone)]
#[serde(untagged)]
enum McpServerInnerConfig {
    Command(StdioConfig),
    Url(UrlConfig),
}
154
/// Command-based entry of `mcpServers`: the server is launched as a local child process.
#[derive(Deserialize, Debug, Clone)]
struct StdioConfig {
    /// Program to execute.
    command: String,
    /// Arguments passed to `command` (treated as empty when absent).
    args: Option<Vec<String>>,
    /// Environment variables set for the child (treated as empty when absent).
    env: Option<HashMap<String, String>>,
}
162
/// Remote (URL-based) entry of `mcpServers`.
///
/// Every field is optional; `parse_convert_config` rejects an entry that has
/// neither `url` nor `baseUrl`.
///
/// NOTE(review): the `skip_serializing_if` attributes have no effect, because
/// this type only derives `Deserialize`.
#[derive(Deserialize, Debug, Clone)]
struct UrlConfig {
    /// Service endpoint; takes precedence over `base_url`.
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
    /// Alternate endpoint key, accepted as `baseUrl`, `baseurl` or `base_url`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        default,
        rename = "baseUrl",
        alias = "baseurl",
        alias = "base_url"
    )]
    base_url: Option<String>,
    /// Protocol hint from the JSON `type` (or `Type`) key.
    #[serde(default, rename = "type", alias = "Type")]
    r#type: Option<String>,
    /// Extra HTTP headers for this service.
    pub headers: Option<HashMap<String, String>>,
    /// Token stored verbatim as the `Authorization` header (no `Bearer ` prefix is added).
    #[serde(default, alias = "authToken", alias = "auth_token")]
    pub auth_token: Option<String>,
    /// Optional per-service timeout value; see how `run_convert_command` consumes it.
    pub timeout: Option<u64>,
}
183
184impl UrlConfig {
185 fn get_url(&self) -> Option<&str> {
186 self.url.as_deref().or(self.base_url.as_deref())
187 }
188}
189
/// Resolved target of the `convert` subcommand, produced by `parse_convert_config`.
enum McpConfigSource {
    /// A bare URL given on the command line; no config was read.
    DirectUrl {
        url: String,
    },
    /// A URL-based entry from an MCP config.
    RemoteService {
        /// Key of the entry in `mcpServers`.
        name: String,
        /// Endpoint taken from `url` or `baseUrl`.
        url: String,
        /// Protocol derived from the entry's `type`; `None` when absent or unrecognized.
        protocol: Option<super::protocol::McpProtocol>,
        /// Entry headers, including `Authorization` when `authToken` is set.
        headers: HashMap<String, String>,
        /// The entry's `timeout` value, if any.
        timeout: Option<u64>,
    },
    /// A command-based entry from an MCP config, run as a child process.
    LocalCommand {
        /// Key of the entry in `mcpServers`.
        name: String,
        /// Program to execute.
        command: String,
        /// Command-line arguments (empty when the entry has no `args`).
        args: Vec<String>,
        /// Extra environment variables (empty when the entry has no `env`).
        env: HashMap<String, String>,
    },
}
212
213fn parse_convert_config(args: &ConvertArgs) -> Result<McpConfigSource> {
215 if let Some(ref url) = args.url {
217 return Ok(McpConfigSource::DirectUrl { url: url.clone() });
218 }
219
220 let json_str = if let Some(ref config) = args.config {
222 config.clone()
223 } else if let Some(ref path) = args.config_file {
224 std::fs::read_to_string(path)
225 .map_err(|e| anyhow::anyhow!("读取配置文件失败: {}", e))?
226 } else {
227 bail!("必须提供 URL、--config 或 --config-file 参数之一");
228 };
229
230 let mcp_config: McpConfig = serde_json::from_str(&json_str)
232 .map_err(|e| anyhow::anyhow!(
233 "配置解析失败: {}。配置必须是标准 MCP 格式,包含 mcpServers 字段",
234 e
235 ))?;
236
237 let servers = mcp_config.mcp_servers;
238
239 if servers.is_empty() {
240 bail!("配置中没有找到任何 MCP 服务");
241 }
242
243 let (name, inner_config) = if servers.len() == 1 {
245 servers.into_iter().next().unwrap()
246 } else if let Some(ref name) = args.name {
247 let config = servers.get(name)
248 .cloned()
249 .ok_or_else(|| anyhow::anyhow!(
250 "服务 '{}' 不存在。可用服务: {:?}",
251 name,
252 servers.keys().collect::<Vec<_>>()
253 ))?;
254 (name.clone(), config)
255 } else {
256 bail!(
257 "配置包含多个服务 {:?},请使用 --name 指定要使用的服务",
258 servers.keys().collect::<Vec<_>>()
259 );
260 };
261
262 match inner_config {
264 McpServerInnerConfig::Command(stdio) => {
265 Ok(McpConfigSource::LocalCommand {
266 name,
267 command: stdio.command,
268 args: stdio.args.unwrap_or_default(),
269 env: stdio.env.unwrap_or_default(),
270 })
271 }
272 McpServerInnerConfig::Url(url_config) => {
273 let url = url_config.get_url()
274 .ok_or_else(|| anyhow::anyhow!("URL 配置缺少 url 或 baseUrl 字段"))?
275 .to_string();
276
277 let protocol = url_config.r#type.as_ref().and_then(|t| {
279 match t.as_str() {
280 "sse" => Some(super::protocol::McpProtocol::Sse),
281 "http" | "stream" => Some(super::protocol::McpProtocol::Stream),
282 _ => None,
283 }
284 });
285
286 let mut headers = url_config.headers.clone().unwrap_or_default();
288 if let Some(auth_token) = &url_config.auth_token {
289 headers.insert("Authorization".to_string(), auth_token.clone());
290 }
291
292 Ok(McpConfigSource::RemoteService {
293 name,
294 url,
295 protocol,
296 headers,
297 timeout: url_config.timeout,
298 })
299 }
300 }
301}
302
/// Combines per-service config headers with command-line overrides.
///
/// Later sources win on an exact key match: the config headers come first,
/// then each `-H` pair in order, then `--auth` stored as `Authorization`.
fn merge_headers(
    config_headers: HashMap<String, String>,
    cli_headers: &[(String, String)],
    cli_auth: Option<&String>,
) -> HashMap<String, String> {
    let auth_entry = cli_auth.map(|token| ("Authorization".to_string(), token.clone()));

    // Collecting into a HashMap inserts sequentially, so later pairs overwrite earlier ones.
    config_headers
        .into_iter()
        .chain(cli_headers.iter().cloned())
        .chain(auth_entry)
        .collect()
}
323
324pub async fn run_cli(cli: Cli) -> Result<()> {
326 match cli.command {
327 Some(Commands::Convert(args)) => {
328 run_convert_command(args, cli.verbose, cli.quiet).await
329 }
330 Some(Commands::Check(args)) => {
331 run_check_command(args, cli.verbose, cli.quiet).await
332 }
333 Some(Commands::Detect(args)) => {
334 run_detect_command(args, cli.verbose, cli.quiet).await
335 }
336 Some(Commands::Proxy(args)) => {
337 super::proxy_server::run_proxy_command(args, cli.verbose, cli.quiet).await
338 }
339 None => {
340 if let Some(url) = cli.url {
342 let args = ConvertArgs {
343 url: Some(url),
344 config: None,
345 config_file: None,
346 name: None,
347 protocol: None,
348 auth: None,
349 header: vec![],
350 timeout: 30,
351 retries: 3,
352 allow_tools: None,
353 deny_tools: None,
354 };
355 run_convert_command(args, cli.verbose, cli.quiet).await
356 } else {
357 bail!("请提供 URL 或使用子命令")
358 }
359 }
360 }
361}
362
363async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
365 if args.allow_tools.is_some() && args.deny_tools.is_some() {
367 bail!("--allow-tools 和 --deny-tools 不能同时使用,请只选择其中一个");
368 }
369
370 let tool_filter = if let Some(allow_tools) = args.allow_tools.clone() {
372 ToolFilter::allow(allow_tools)
373 } else if let Some(deny_tools) = args.deny_tools.clone() {
374 ToolFilter::deny(deny_tools)
375 } else {
376 ToolFilter::default()
377 };
378
379 let config_source = parse_convert_config(&args)?;
381
382 let client_info = ClientInfo {
384 protocol_version: Default::default(),
385 capabilities: ClientCapabilities::builder()
386 .enable_experimental()
387 .enable_roots()
388 .enable_roots_list_changed()
389 .enable_sampling()
390 .build(),
391 ..Default::default()
392 };
393
394 match config_source {
396 McpConfigSource::DirectUrl { url } => {
397 run_url_mode(&args, &url, HashMap::new(), None, client_info, tool_filter, verbose, quiet).await
399 }
400 McpConfigSource::RemoteService { name, url, protocol, headers, timeout } => {
401 if !quiet {
403 eprintln!("🚀 MCP-Stdio-Proxy: {} ({}) → stdio", name, url);
404 }
405 let merged_headers = merge_headers(headers, &args.header, args.auth.as_ref());
407 run_url_mode(&args, &url, merged_headers, protocol.or(timeout.map(|_| super::protocol::McpProtocol::Stream)), client_info, tool_filter, verbose, quiet).await
408 }
409 McpConfigSource::LocalCommand { name, command, args: cmd_args, env } => {
410 run_command_mode(&name, &command, cmd_args, env, client_info, tool_filter, verbose, quiet).await
412 }
413 }
414}
415
416async fn run_url_mode(
418 args: &ConvertArgs,
419 url: &str,
420 merged_headers: HashMap<String, String>,
421 config_protocol: Option<super::protocol::McpProtocol>,
422 client_info: ClientInfo,
423 tool_filter: ToolFilter,
424 verbose: bool,
425 quiet: bool,
426) -> Result<()> {
427 if !quiet && merged_headers.is_empty() {
428 eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", url);
429 }
430
431 if verbose && !quiet {
432 eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
433 }
434
435 if !quiet {
437 if let Some(ref allow_tools) = args.allow_tools {
438 eprintln!("🔧 工具白名单: {:?}", allow_tools);
439 }
440 if let Some(ref deny_tools) = args.deny_tools {
441 eprintln!("🔧 工具黑名单: {:?}", deny_tools);
442 }
443 }
444
445 let protocol = if let Some(ref proto) = args.protocol {
447 let detected = match proto {
449 super::proxy_server::ProxyProtocol::Sse => super::protocol::McpProtocol::Sse,
450 super::proxy_server::ProxyProtocol::Stream => super::protocol::McpProtocol::Stream,
451 };
452 if !quiet {
453 eprintln!("🔧 使用指定协议: {}", protocol_name(&detected));
454 }
455 detected
456 } else if let Some(proto) = config_protocol {
457 if !quiet {
459 eprintln!("🔧 使用配置协议: {}", protocol_name(&proto));
460 }
461 proto
462 } else {
463 let detected = super::protocol::detect_mcp_protocol(url).await?;
465 if !quiet {
466 eprintln!("🔍 检测到 {} 协议", protocol_name(&detected));
467 }
468 detected
469 };
470
471 if !quiet {
472 eprintln!("🔗 建立连接...");
473 }
474
475 let http_client = create_http_client_with_headers(&merged_headers, &args.header, args.auth.as_ref(), args.timeout)?;
477
478 let running = match protocol {
480 super::protocol::McpProtocol::Sse => {
481 let cfg = SseClientConfig {
482 sse_endpoint: url.to_string().into(),
483 ..Default::default()
484 };
485 let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
486 client_info.serve(transport).await?
487 }
488 super::protocol::McpProtocol::Stream => {
489 let cfg = StreamableHttpClientTransportConfig {
490 uri: url.to_string().into(),
491 ..Default::default()
492 };
493 let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
494 client_info.serve(transport).await?
495 }
496 super::protocol::McpProtocol::Stdio => {
497 bail!("Stdio 协议不支持通过 URL 转换,请使用 --config 配置本地命令")
498 }
499 };
500
501 if !quiet {
502 eprintln!("✅ 连接成功,开始代理转换...");
503 eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
504 }
505
506 let proxy_handler = ProxyHandler::with_tool_filter(running, "cli".to_string(), tool_filter);
508 let stdio_transport = stdio();
509 let server = proxy_handler.serve(stdio_transport).await?;
510 server.waiting().await?;
511
512 Ok(())
513}
514
515async fn run_command_mode(
517 name: &str,
518 command: &str,
519 cmd_args: Vec<String>,
520 env: HashMap<String, String>,
521 client_info: ClientInfo,
522 tool_filter: ToolFilter,
523 verbose: bool,
524 quiet: bool,
525) -> Result<()> {
526 if !quiet {
527 eprintln!("🚀 MCP-Stdio-Proxy: {} (command) → stdio", name);
528 eprintln!(" 命令: {} {:?}", command, cmd_args);
529 if verbose && !env.is_empty() {
530 eprintln!(" 环境变量: {:?}", env);
531 }
532 }
533
534 if !quiet {
536 if tool_filter.is_enabled() {
537 eprintln!("🔧 工具过滤已启用");
538 }
539 }
540
541 let mut cmd = Command::new(command);
543 cmd.args(&cmd_args);
544 for (k, v) in &env {
545 cmd.env(k, v);
546 }
547
548 let tokio_process = TokioChildProcess::new(cmd)?;
550
551 if !quiet {
552 eprintln!("🔗 启动子进程...");
553 }
554
555 let running = client_info.serve(tokio_process).await?;
557
558 if !quiet {
559 eprintln!("✅ 子进程已启动,开始代理转换...");
560 eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
561 }
562
563 let proxy_handler = ProxyHandler::with_tool_filter(running, name.to_string(), tool_filter);
565 let stdio_transport = stdio();
566 let server = proxy_handler.serve(stdio_transport).await?;
567 server.waiting().await?;
568
569 Ok(())
570}
571
572fn protocol_name(protocol: &super::protocol::McpProtocol) -> &'static str {
574 match protocol {
575 super::protocol::McpProtocol::Sse => "SSE",
576 super::protocol::McpProtocol::Stream => "Streamable HTTP",
577 super::protocol::McpProtocol::Stdio => "Stdio",
578 }
579}
580
581fn create_http_client_with_headers(
583 config_headers: &HashMap<String, String>,
584 cli_headers: &[(String, String)],
585 cli_auth: Option<&String>,
586 timeout: u64,
587) -> Result<reqwest::Client> {
588 let mut headers = reqwest::header::HeaderMap::new();
589
590 for (key, value) in config_headers {
592 headers.insert(
593 key.parse::<reqwest::header::HeaderName>()?,
594 value.parse()?,
595 );
596 }
597
598 for (key, value) in cli_headers {
600 headers.insert(
601 key.parse::<reqwest::header::HeaderName>()?,
602 value.parse()?,
603 );
604 }
605
606 if let Some(auth) = cli_auth {
608 headers.insert("Authorization", auth.parse()?);
609 }
610
611 let client = reqwest::Client::builder()
612 .default_headers(headers)
613 .timeout(tokio::time::Duration::from_secs(timeout))
614 .build()?;
615
616 Ok(client)
617}
618
619async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
621 if !quiet {
622 eprintln!("🔍 检查服务: {}", args.url);
623 }
624
625 match super::protocol::detect_mcp_protocol(&args.url).await {
626 Ok(protocol) => {
627 if !quiet {
628 eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
629 }
630 Ok(())
631 }
632 Err(e) => {
633 if !quiet {
634 eprintln!("❌ 服务检查失败: {}", e);
635 }
636 Err(e)
637 }
638 }
639}
640
641async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
643 let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
644
645 if quiet {
646 println!("{}", protocol);
647 } else {
648 eprintln!("{}", protocol);
649 }
650
651 Ok(())
652}