1use std::collections::HashMap;
5use std::time::Duration;
6
7use clap::Parser;
8use anyhow::{Result, bail};
9use serde::Deserialize;
10use tokio::process::Command;
11
12use rmcp::{
13 ServiceExt,
14 model::{ClientCapabilities, ClientInfo},
15 transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
16};
17use crate::proxy::{ProxyHandler, ToolFilter};
18
// Top-level command-line interface of `mcp-proxy`.
//
// Either a subcommand is selected, or a bare URL is passed positionally; in
// the latter case `run_cli` treats it as a `convert` run with default options.
//
// NOTE: plain `//` comments are used on purpose in this struct. clap's derive
// macros turn `///` doc comments into help/about text, so doc comments here
// would change the CLI output.
#[derive(Parser, Debug)]
#[command(name = "mcp-proxy")]
#[command(version = env!("CARGO_PKG_VERSION"))]
#[command(about = "MCP 协议转换代理工具", long_about = None)]
pub struct Cli {
    // Optional subcommand; `None` selects direct-URL mode in `run_cli`.
    #[command(subcommand)]
    pub command: Option<Commands>,

    // Positional URL, only consulted by `run_cli` when no subcommand is given.
    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
    pub url: Option<String>,

    // Global flag; `run_cli` forwards it to every subcommand runner.
    #[arg(short, long, global = true)]
    pub verbose: bool,

    // Global flag; `run_cli` forwards it to every subcommand runner, where it
    // suppresses the informational `eprintln!` output.
    #[arg(short, long, global = true)]
    pub quiet: bool,
}
40
// Subcommands understood by the CLI; `run_cli` dispatches on this enum.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments on variants into the subcommand's about text.
#[derive(clap::Subcommand, Debug)]
pub enum Commands {
    // Handled by `run_convert_command`.
    Convert(ConvertArgs),

    // Handled by `run_check_command`.
    Check(CheckArgs),

    // Handled by `run_detect_command`.
    Detect(DetectArgs),

    // Handled by `super::proxy_server::run_proxy_command`.
    Proxy(super::proxy_server::ProxyArgs),
}
55
// Arguments of the `convert` subcommand (also built by hand in `run_cli` for
// direct-URL mode).
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
#[derive(Parser, Debug, Clone)]
pub struct ConvertArgs {
    // Positional URL. When present, `parse_convert_config` uses it directly
    // and does not read `config` / `config_file`.
    #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
    pub url: Option<String>,

    // Inline JSON configuration; mutually exclusive with `config_file`.
    #[arg(long, conflicts_with = "config_file", help = "MCP 服务配置 JSON")]
    pub config: Option<String>,

    // Path to a JSON configuration file; mutually exclusive with `config`.
    #[arg(long, conflicts_with = "config", help = "MCP 服务配置文件路径")]
    pub config_file: Option<std::path::PathBuf>,

    // Name of the service to pick when the configuration lists several.
    #[arg(short, long, help = "MCP 服务名称(多服务配置时必需)")]
    pub name: Option<String>,

    // Explicit remote protocol; takes precedence over the configured and the
    // auto-detected protocol in `run_url_mode`.
    #[arg(long, value_enum, help = "指定远程服务协议类型(不指定则自动检测)")]
    pub protocol: Option<super::proxy_server::ProxyProtocol>,

    // Value sent as the `Authorization` header.
    #[arg(short, long, help = "认证 header")]
    pub auth: Option<String>,

    // Extra HTTP headers, each parsed from `KEY=VALUE` by `parse_key_val`.
    #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
    pub header: Vec<(String, String)>,

    // Seconds; handed to `create_http_client_with_headers` as the HTTP client
    // timeout.
    #[arg(long, default_value = "300", help = "连接超时时间(秒),默认5分钟")]
    pub timeout: u64,

    // Attempt limit used by `run_url_mode_with_retry`; 0 means retry forever.
    #[arg(long, default_value = "0", help = "重试次数,0 表示无限重试")]
    pub retries: u32,

    // Comma-separated allow-list; `run_convert_command` rejects combining it
    // with `deny_tools`.
    #[arg(long, value_delimiter = ',', help = "工具白名单(逗号分隔),只允许指定的工具")]
    pub allow_tools: Option<Vec<String>>,

    // Comma-separated deny-list; `run_convert_command` rejects combining it
    // with `allow_tools`.
    #[arg(long, value_delimiter = ',', help = "工具黑名单(逗号分隔),排除指定的工具")]
    pub deny_tools: Option<Vec<String>>,
}
103
// Arguments of the `check` subcommand.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
#[derive(Parser, Debug)]
pub struct CheckArgs {
    // URL of the service to probe.
    #[arg(value_name = "URL")]
    pub url: String,

    // Optional authentication value (`-a/--auth`).
    #[arg(short, long)]
    pub auth: Option<String>,

    // Timeout, default 10 — presumably seconds, as for `ConvertArgs::timeout`;
    // TODO confirm.
    #[arg(long, default_value = "10")]
    pub timeout: u64,
}
119
// Arguments of the `detect` subcommand.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
#[derive(Parser, Debug)]
pub struct DetectArgs {
    // URL whose protocol is detected by `run_detect_command`.
    #[arg(value_name = "URL")]
    pub url: String,

    // Optional authentication value (`-a/--auth`).
    #[arg(short, long)]
    pub auth: Option<String>,
}
131
132fn parse_key_val(s: &str) -> Result<(String, String)> {
134 let pos = s.find('=')
135 .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
136 Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
137}
138
/// Root of the JSON configuration accepted by `--config` / `--config-file`.
///
/// The document must contain an `mcpServers` object mapping a service name to
/// its configuration.
#[derive(Deserialize, Debug)]
struct McpConfig {
    /// Service name → service configuration, read from the `mcpServers` key.
    #[serde(rename = "mcpServers")]
    mcp_servers: HashMap<String, McpServerInnerConfig>,
}
147
/// One entry of `mcpServers`: either a local command or a remote URL.
///
/// The enum is `untagged`, so serde picks the first variant whose shape
/// matches, in declaration order: an entry with a `command` field becomes
/// [`McpServerInnerConfig::Command`], anything else is tried as
/// [`McpServerInnerConfig::Url`].
#[derive(Deserialize, Debug, Clone)]
#[serde(untagged)]
enum McpServerInnerConfig {
    /// Service started as a child process.
    Command(StdioConfig),
    /// Service reached over HTTP.
    Url(UrlConfig),
}
155
/// Configuration of a service that is launched as a local child process.
#[derive(Deserialize, Debug, Clone)]
struct StdioConfig {
    /// Program to execute (required; its presence selects this variant).
    command: String,
    /// Command-line arguments; treated as empty when absent.
    args: Option<Vec<String>>,
    /// Extra environment variables; treated as empty when absent.
    env: Option<HashMap<String, String>>,
}
163
/// Configuration of a service reached over HTTP.
///
/// Every field is optional at the serde level; `parse_convert_config` rejects
/// entries where neither `url` nor `baseUrl` is set.
#[derive(Deserialize, Debug, Clone)]
struct UrlConfig {
    /// Preferred endpoint URL.
    // `skip_serializing_if` only affects `Serialize`, which is not derived here.
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
    /// Fallback endpoint URL, accepted as `baseUrl`, `baseurl` or `base_url`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        default,
        rename = "baseUrl",
        alias = "baseurl",
        alias = "base_url"
    )]
    base_url: Option<String>,
    /// Protocol hint, read from `type` (or `Type`). `parse_convert_config`
    /// maps the values it recognises to a protocol and ignores the rest.
    #[serde(default, rename = "type", alias = "Type")]
    r#type: Option<String>,
    /// Extra HTTP headers to send with every request.
    pub headers: Option<HashMap<String, String>>,
    /// Value placed verbatim into the `Authorization` header by
    /// `parse_convert_config`; also accepted as `authToken`.
    #[serde(default, alias = "authToken", alias = "auth_token")]
    pub auth_token: Option<String>,
    /// Timeout taken from the configuration — presumably seconds; TODO confirm.
    pub timeout: Option<u64>,
}
184
185impl UrlConfig {
186 fn get_url(&self) -> Option<&str> {
187 self.url.as_deref().or(self.base_url.as_deref())
188 }
189}
190
/// Resolved description of what `convert` should connect to, as produced by
/// `parse_convert_config`.
enum McpConfigSource {
    /// A URL given directly on the command line.
    DirectUrl {
        /// The URL exactly as passed by the user.
        url: String,
    },
    /// A remote service taken from the JSON configuration.
    RemoteService {
        /// Key of the service inside `mcpServers`.
        name: String,
        /// Endpoint taken from `url` / `baseUrl`.
        url: String,
        /// Protocol derived from the configuration's `type`, if recognised.
        protocol: Option<super::protocol::McpProtocol>,
        /// Configured headers, including `Authorization` when `auth_token` is set.
        headers: HashMap<String, String>,
        /// `timeout` value copied from the configuration entry.
        timeout: Option<u64>,
    },
    /// A local command taken from the JSON configuration.
    LocalCommand {
        /// Key of the service inside `mcpServers`.
        name: String,
        /// Program to execute.
        command: String,
        /// Arguments passed to the program.
        args: Vec<String>,
        /// Environment variables added to the child process.
        env: HashMap<String, String>,
    },
}
213
214fn parse_convert_config(args: &ConvertArgs) -> Result<McpConfigSource> {
216 if let Some(ref url) = args.url {
218 return Ok(McpConfigSource::DirectUrl { url: url.clone() });
219 }
220
221 let json_str = if let Some(ref config) = args.config {
223 config.clone()
224 } else if let Some(ref path) = args.config_file {
225 std::fs::read_to_string(path)
226 .map_err(|e| anyhow::anyhow!("读取配置文件失败: {}", e))?
227 } else {
228 bail!("必须提供 URL、--config 或 --config-file 参数之一");
229 };
230
231 let mcp_config: McpConfig = serde_json::from_str(&json_str)
233 .map_err(|e| anyhow::anyhow!(
234 "配置解析失败: {}。配置必须是标准 MCP 格式,包含 mcpServers 字段",
235 e
236 ))?;
237
238 let servers = mcp_config.mcp_servers;
239
240 if servers.is_empty() {
241 bail!("配置中没有找到任何 MCP 服务");
242 }
243
244 let (name, inner_config) = if servers.len() == 1 {
246 servers.into_iter().next().unwrap()
247 } else if let Some(ref name) = args.name {
248 let config = servers.get(name)
249 .cloned()
250 .ok_or_else(|| anyhow::anyhow!(
251 "服务 '{}' 不存在。可用服务: {:?}",
252 name,
253 servers.keys().collect::<Vec<_>>()
254 ))?;
255 (name.clone(), config)
256 } else {
257 bail!(
258 "配置包含多个服务 {:?},请使用 --name 指定要使用的服务",
259 servers.keys().collect::<Vec<_>>()
260 );
261 };
262
263 match inner_config {
265 McpServerInnerConfig::Command(stdio) => {
266 Ok(McpConfigSource::LocalCommand {
267 name,
268 command: stdio.command,
269 args: stdio.args.unwrap_or_default(),
270 env: stdio.env.unwrap_or_default(),
271 })
272 }
273 McpServerInnerConfig::Url(url_config) => {
274 let url = url_config.get_url()
275 .ok_or_else(|| anyhow::anyhow!("URL 配置缺少 url 或 baseUrl 字段"))?
276 .to_string();
277
278 let protocol = url_config.r#type.as_ref().and_then(|t| {
280 match t.as_str() {
281 "sse" => Some(super::protocol::McpProtocol::Sse),
282 "http" | "stream" => Some(super::protocol::McpProtocol::Stream),
283 _ => None,
284 }
285 });
286
287 let mut headers = url_config.headers.clone().unwrap_or_default();
289 if let Some(auth_token) = &url_config.auth_token {
290 headers.insert("Authorization".to_string(), auth_token.clone());
291 }
292
293 Ok(McpConfigSource::RemoteService {
294 name,
295 url,
296 protocol,
297 headers,
298 timeout: url_config.timeout,
299 })
300 }
301 }
302}
303
/// Merges configured headers with command-line headers.
///
/// Precedence, lowest to highest: `config_headers`, then `cli_headers` in the
/// order given, then `cli_auth` (stored as `Authorization`).
///
/// HTTP header names are case-insensitive, so an override replaces an
/// existing entry whose name differs only by ASCII case (for example a CLI
/// `X-Api-Key` replaces a configured `x-api-key`) instead of leaving both in
/// the map. The surviving entry keeps the overriding side's spelling.
fn merge_headers(
    config_headers: HashMap<String, String>,
    cli_headers: &[(String, String)],
    cli_auth: Option<&String>,
) -> HashMap<String, String> {
    /// Inserts `name`, first removing any entry with the same name in a
    /// different ASCII case.
    fn set_header(map: &mut HashMap<String, String>, name: &str, value: &str) {
        map.retain(|existing, _| !existing.eq_ignore_ascii_case(name));
        map.insert(name.to_owned(), value.to_owned());
    }

    let mut merged = config_headers;

    for (key, value) in cli_headers {
        set_header(&mut merged, key, value);
    }

    if let Some(auth_value) = cli_auth {
        set_header(&mut merged, "Authorization", auth_value);
    }

    merged
}
324
325pub async fn run_cli(cli: Cli) -> Result<()> {
327 match cli.command {
328 Some(Commands::Convert(args)) => {
329 run_convert_command(args, cli.verbose, cli.quiet).await
330 }
331 Some(Commands::Check(args)) => {
332 run_check_command(args, cli.verbose, cli.quiet).await
333 }
334 Some(Commands::Detect(args)) => {
335 run_detect_command(args, cli.verbose, cli.quiet).await
336 }
337 Some(Commands::Proxy(args)) => {
338 super::proxy_server::run_proxy_command(args, cli.verbose, cli.quiet).await
339 }
340 None => {
341 if let Some(url) = cli.url {
343 let args = ConvertArgs {
344 url: Some(url),
345 config: None,
346 config_file: None,
347 name: None,
348 protocol: None,
349 auth: None,
350 header: vec![],
351 timeout: 300, retries: 0, allow_tools: None,
354 deny_tools: None,
355 };
356 run_convert_command(args, cli.verbose, cli.quiet).await
357 } else {
358 bail!("请提供 URL 或使用子命令")
359 }
360 }
361 }
362}
363
364async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
366 if args.allow_tools.is_some() && args.deny_tools.is_some() {
368 bail!("--allow-tools 和 --deny-tools 不能同时使用,请只选择其中一个");
369 }
370
371 let tool_filter = if let Some(allow_tools) = args.allow_tools.clone() {
373 ToolFilter::allow(allow_tools)
374 } else if let Some(deny_tools) = args.deny_tools.clone() {
375 ToolFilter::deny(deny_tools)
376 } else {
377 ToolFilter::default()
378 };
379
380 let config_source = parse_convert_config(&args)?;
382
383 let client_info = ClientInfo {
385 protocol_version: Default::default(),
386 capabilities: ClientCapabilities::builder()
387 .enable_experimental()
388 .enable_roots()
389 .enable_roots_list_changed()
390 .enable_sampling()
391 .build(),
392 ..Default::default()
393 };
394
395 match config_source {
397 McpConfigSource::DirectUrl { url } => {
398 run_url_mode_with_retry(&args, &url, HashMap::new(), None, client_info, tool_filter, verbose, quiet).await
400 }
401 McpConfigSource::RemoteService { name, url, protocol, headers, timeout } => {
402 if !quiet {
404 eprintln!("🚀 MCP-Stdio-Proxy: {} ({}) → stdio", name, url);
405 }
406 let merged_headers = merge_headers(headers, &args.header, args.auth.as_ref());
408 run_url_mode_with_retry(&args, &url, merged_headers, protocol.or(timeout.map(|_| super::protocol::McpProtocol::Stream)), client_info, tool_filter, verbose, quiet).await
409 }
410 McpConfigSource::LocalCommand { name, command, args: cmd_args, env } => {
411 run_command_mode(&name, &command, cmd_args, env, client_info, tool_filter, verbose, quiet).await
413 }
414 }
415}
416
417async fn run_url_mode_with_retry(
420 args: &ConvertArgs,
421 url: &str,
422 merged_headers: HashMap<String, String>,
423 config_protocol: Option<super::protocol::McpProtocol>,
424 client_info: ClientInfo,
425 tool_filter: ToolFilter,
426 verbose: bool,
427 quiet: bool,
428) -> Result<()> {
429 let max_retries = args.retries;
430 let mut attempt = 0u32;
431 let mut backoff_secs = 1u64;
432 const MAX_BACKOFF_SECS: u64 = 30;
433
434 loop {
435 attempt += 1;
436 let is_retry = attempt > 1;
437
438 if is_retry && !quiet {
440 eprintln!("🔗 正在建立连接 (第{}次尝试)...", attempt);
441 }
442
443 let client_info_clone = ClientInfo {
445 protocol_version: client_info.protocol_version.clone(),
446 capabilities: client_info.capabilities.clone(),
447 client_info: client_info.client_info.clone(),
448 };
449
450 let result = run_url_mode(
451 args,
452 url,
453 merged_headers.clone(),
454 config_protocol.clone(),
455 client_info_clone,
456 tool_filter.clone(),
457 verbose,
458 quiet,
459 is_retry, ).await;
461
462 match result {
463 Ok(_) => {
464 break Ok(());
466 }
467 Err(e) => {
468 let error_type = classify_error(&e);
470
471 if max_retries > 0 && attempt >= max_retries {
473 if !quiet {
474 eprintln!("❌ 连接失败,已达最大重试次数 ({})", max_retries);
475 eprintln!(" 错误类型: {}", error_type);
476 eprintln!(" 错误详情: {}", e);
477 }
478 break Err(e);
479 }
480
481 if !quiet {
482 if max_retries == 0 {
483 eprintln!("⚠️ 连接断开 [{}]: {},{}秒后重连 (第{}次)...",
484 error_type, summarize_error(&e), backoff_secs, attempt);
485 } else {
486 eprintln!("⚠️ 连接断开 [{}]: {},{}秒后重连 ({}/{})...",
487 error_type, summarize_error(&e), backoff_secs, attempt, max_retries);
488 }
489 }
490
491 if verbose && !quiet {
493 eprintln!(" 完整错误: {}", e);
494 eprintln!(" 当前退避: {}s,下次退避: {}s",
495 backoff_secs, (backoff_secs * 2).min(MAX_BACKOFF_SECS));
496 }
497
498 tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
499
500 backoff_secs = (backoff_secs * 2).min(MAX_BACKOFF_SECS);
502 }
503 }
504 }
505}
506
507fn classify_error(e: &anyhow::Error) -> &'static str {
509 let err_str = e.to_string().to_lowercase();
510
511 if err_str.contains("timeout") || err_str.contains("timed out") {
512 "超时"
513 } else if err_str.contains("connection refused") {
514 "连接被拒绝"
515 } else if err_str.contains("connection reset") {
516 "连接被重置"
517 } else if err_str.contains("dns") || err_str.contains("resolve") {
518 "DNS解析失败"
519 } else if err_str.contains("certificate") || err_str.contains("ssl") || err_str.contains("tls") {
520 "SSL/TLS错误"
521 } else if err_str.contains("session") {
522 "会话错误"
523 } else if err_str.contains("sending request") || err_str.contains("network") {
524 "网络错误"
525 } else if err_str.contains("eof") || err_str.contains("closed") || err_str.contains("shutdown") {
526 "连接关闭"
527 } else {
528 "未知错误"
529 }
530}
531
532fn summarize_error(e: &anyhow::Error) -> String {
534 let full = e.to_string();
535 let first_line = full.lines().next().unwrap_or(&full);
537 if first_line.chars().count() > 80 {
539 format!("{}...", first_line.chars().take(77).collect::<String>())
540 } else {
541 first_line.to_string()
542 }
543}
544
545async fn run_url_mode(
547 args: &ConvertArgs,
548 url: &str,
549 merged_headers: HashMap<String, String>,
550 config_protocol: Option<super::protocol::McpProtocol>,
551 client_info: ClientInfo,
552 tool_filter: ToolFilter,
553 verbose: bool,
554 quiet: bool,
555 is_retry: bool, ) -> Result<()> {
557 if !quiet && merged_headers.is_empty() && !is_retry {
558 eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", url);
559 }
560
561 if verbose && !quiet && !is_retry {
562 eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
563 }
564
565 if !quiet && !is_retry {
567 if let Some(ref allow_tools) = args.allow_tools {
568 eprintln!("🔧 工具白名单: {:?}", allow_tools);
569 }
570 if let Some(ref deny_tools) = args.deny_tools {
571 eprintln!("🔧 工具黑名单: {:?}", deny_tools);
572 }
573 }
574
575 let protocol = if let Some(ref proto) = args.protocol {
577 let detected = match proto {
579 super::proxy_server::ProxyProtocol::Sse => super::protocol::McpProtocol::Sse,
580 super::proxy_server::ProxyProtocol::Stream => super::protocol::McpProtocol::Stream,
581 };
582 if !quiet && !is_retry {
583 eprintln!("🔧 使用指定协议: {}", protocol_name(&detected));
584 }
585 detected
586 } else if let Some(proto) = config_protocol {
587 if !quiet && !is_retry {
589 eprintln!("🔧 使用配置协议: {}", protocol_name(&proto));
590 }
591 proto
592 } else {
593 let detected = super::protocol::detect_mcp_protocol(url).await?;
595 if !quiet && !is_retry {
596 eprintln!("🔍 检测到 {} 协议", protocol_name(&detected));
597 }
598 detected
599 };
600
601 if !quiet && !is_retry {
602 eprintln!("🔗 建立连接...");
603 }
604
605 let http_client = create_http_client_with_headers(&merged_headers, &args.header, args.auth.as_ref(), args.timeout)?;
607
608 let running = match protocol {
610 super::protocol::McpProtocol::Sse => {
611 let cfg = SseClientConfig {
612 sse_endpoint: url.to_string().into(),
613 ..Default::default()
614 };
615 let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
616 client_info.serve(transport).await?
617 }
618 super::protocol::McpProtocol::Stream => {
619 let cfg = StreamableHttpClientTransportConfig {
620 uri: url.to_string().into(),
621 ..Default::default()
622 };
623 let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
624 client_info.serve(transport).await?
625 }
626 super::protocol::McpProtocol::Stdio => {
627 bail!("Stdio 协议不支持通过 URL 转换,请使用 --config 配置本地命令")
628 }
629 };
630
631 if !quiet {
632 if is_retry {
633 eprintln!("✅ 重连成功,恢复代理服务");
634 } else {
635 eprintln!("✅ 连接成功,开始代理转换...");
636 }
637
638 match running.list_tools(None).await {
640 Ok(tools_result) => {
641 let tools = &tools_result.tools;
642 if tools.is_empty() {
643 eprintln!("⚠️ 工具列表为空 (tools/list 返回 0 个工具)");
644 } else {
645 eprintln!("🔧 可用工具 ({} 个):", tools.len());
646 for tool in tools {
647 let desc = tool.description.as_deref().unwrap_or("无描述");
648 let desc_short = if desc.chars().count() > 50 {
649 format!("{}...", desc.chars().take(50).collect::<String>())
650 } else {
651 desc.to_string()
652 };
653 eprintln!(" - {} : {}", tool.name, desc_short);
654 }
655 }
656 }
657 Err(e) => {
658 eprintln!("⚠️ 获取工具列表失败: {}", e);
659 }
660 }
661
662 eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
663 }
664
665 let proxy_handler = ProxyHandler::with_tool_filter(running, "cli".to_string(), tool_filter);
667 let stdio_transport = stdio();
668 let server = proxy_handler.serve(stdio_transport).await?;
669 server.waiting().await?;
670
671 Ok(())
672}
673
674async fn run_command_mode(
676 name: &str,
677 command: &str,
678 cmd_args: Vec<String>,
679 env: HashMap<String, String>,
680 client_info: ClientInfo,
681 tool_filter: ToolFilter,
682 verbose: bool,
683 quiet: bool,
684) -> Result<()> {
685 if !quiet {
686 eprintln!("🚀 MCP-Stdio-Proxy: {} (command) → stdio", name);
687 eprintln!(" 命令: {} {:?}", command, cmd_args);
688 if verbose && !env.is_empty() {
689 eprintln!(" 环境变量: {:?}", env);
690 }
691 }
692
693 if !quiet {
695 if tool_filter.is_enabled() {
696 eprintln!("🔧 工具过滤已启用");
697 }
698 }
699
700 let mut cmd = Command::new(command);
702 cmd.args(&cmd_args);
703 for (k, v) in &env {
704 cmd.env(k, v);
705 }
706
707 let tokio_process = TokioChildProcess::new(cmd)?;
709
710 if !quiet {
711 eprintln!("🔗 启动子进程...");
712 }
713
714 let running = client_info.serve(tokio_process).await?;
716
717 if !quiet {
718 eprintln!("✅ 子进程已启动,开始代理转换...");
719
720 match running.list_tools(None).await {
722 Ok(tools_result) => {
723 let tools = &tools_result.tools;
724 if tools.is_empty() {
725 eprintln!("⚠️ 工具列表为空 (tools/list 返回 0 个工具)");
726 } else {
727 eprintln!("🔧 可用工具 ({} 个):", tools.len());
728 for tool in tools {
729 let desc = tool.description.as_deref().unwrap_or("无描述");
730 let desc_short = if desc.chars().count() > 50 {
731 format!("{}...", desc.chars().take(50).collect::<String>())
732 } else {
733 desc.to_string()
734 };
735 eprintln!(" - {} : {}", tool.name, desc_short);
736 }
737 }
738 }
739 Err(e) => {
740 eprintln!("⚠️ 获取工具列表失败: {}", e);
741 }
742 }
743
744 eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
745 }
746
747 let proxy_handler = ProxyHandler::with_tool_filter(running, name.to_string(), tool_filter);
749 let stdio_transport = stdio();
750 let server = proxy_handler.serve(stdio_transport).await?;
751 server.waiting().await?;
752
753 Ok(())
754}
755
756fn protocol_name(protocol: &super::protocol::McpProtocol) -> &'static str {
758 match protocol {
759 super::protocol::McpProtocol::Sse => "SSE",
760 super::protocol::McpProtocol::Stream => "Streamable HTTP",
761 super::protocol::McpProtocol::Stdio => "Stdio",
762 }
763}
764
765fn create_http_client_with_headers(
767 config_headers: &HashMap<String, String>,
768 cli_headers: &[(String, String)],
769 cli_auth: Option<&String>,
770 timeout: u64,
771) -> Result<reqwest::Client> {
772 let mut headers = reqwest::header::HeaderMap::new();
773
774 for (key, value) in config_headers {
776 headers.insert(
777 key.parse::<reqwest::header::HeaderName>()?,
778 value.parse()?,
779 );
780 }
781
782 for (key, value) in cli_headers {
784 headers.insert(
785 key.parse::<reqwest::header::HeaderName>()?,
786 value.parse()?,
787 );
788 }
789
790 if let Some(auth) = cli_auth {
792 headers.insert("Authorization", auth.parse()?);
793 }
794
795 let client = reqwest::Client::builder()
796 .default_headers(headers)
797 .timeout(tokio::time::Duration::from_secs(timeout))
798 .build()?;
799
800 Ok(client)
801}
802
803async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
805 if !quiet {
806 eprintln!("🔍 检查服务: {}", args.url);
807 }
808
809 match super::protocol::detect_mcp_protocol(&args.url).await {
810 Ok(protocol) => {
811 if !quiet {
812 eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
813 }
814 Ok(())
815 }
816 Err(e) => {
817 if !quiet {
818 eprintln!("❌ 服务检查失败: {}", e);
819 }
820 Err(e)
821 }
822 }
823}
824
825async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
827 let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
828
829 if quiet {
830 println!("{}", protocol);
831 } else {
832 eprintln!("{}", protocol);
833 }
834
835 Ok(())
836}