1use clap::Parser;
5use anyhow::{Result, bail};
6
7use rmcp::{
8 ServiceExt,
9 model::{ClientCapabilities, ClientInfo},
10 transport::{SseClientTransport, StreamableHttpClientTransport, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
11};
12use crate::proxy::{ProxyHandler, ToolFilter};
13
14#[derive(Parser, Debug)]
16#[command(name = "mcp-proxy")]
17#[command(version = env!("CARGO_PKG_VERSION"))]
18#[command(about = "MCP 协议转换代理工具", long_about = None)]
19pub struct Cli {
20 #[command(subcommand)]
21 pub command: Option<Commands>,
22
23 #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
25 pub url: Option<String>,
26
27 #[arg(short, long, global = true)]
29 pub verbose: bool,
30
31 #[arg(short, long, global = true)]
33 pub quiet: bool,
34}
35
36#[derive(clap::Subcommand, Debug)]
37pub enum Commands {
38 Convert(ConvertArgs),
40
41 Check(CheckArgs),
43
44 Detect(DetectArgs),
46}
47
48#[derive(Parser, Debug, Clone)]
50pub struct ConvertArgs {
51 #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
53 pub url: String,
54
55 #[arg(short, long, help = "认证 header")]
57 pub auth: Option<String>,
58
59 #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
61 pub header: Vec<(String, String)>,
62
63 #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
65 pub timeout: u64,
66
67 #[arg(long, default_value = "3", help = "重试次数")]
69 pub retries: u32,
70
71 #[arg(long, value_delimiter = ',', help = "工具白名单(逗号分隔),只允许指定的工具")]
73 pub allow_tools: Option<Vec<String>>,
74
75 #[arg(long, value_delimiter = ',', help = "工具黑名单(逗号分隔),排除指定的工具")]
77 pub deny_tools: Option<Vec<String>>,
78}
79
80#[derive(Parser, Debug)]
82pub struct CheckArgs {
83 #[arg(value_name = "URL")]
85 pub url: String,
86
87 #[arg(short, long)]
89 pub auth: Option<String>,
90
91 #[arg(long, default_value = "10")]
93 pub timeout: u64,
94}
95
96#[derive(Parser, Debug)]
98pub struct DetectArgs {
99 #[arg(value_name = "URL")]
101 pub url: String,
102
103 #[arg(short, long)]
105 pub auth: Option<String>,
106}
107
108fn parse_key_val(s: &str) -> Result<(String, String)> {
110 let pos = s.find('=')
111 .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
112 Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
113}
114
115pub async fn run_cli(cli: Cli) -> Result<()> {
117 match cli.command {
118 Some(Commands::Convert(args)) => {
119 run_convert_command(args, cli.verbose, cli.quiet).await
120 }
121 Some(Commands::Check(args)) => {
122 run_check_command(args, cli.verbose, cli.quiet).await
123 }
124 Some(Commands::Detect(args)) => {
125 run_detect_command(args, cli.verbose, cli.quiet).await
126 }
127 None => {
128 if let Some(url) = cli.url {
130 let args = ConvertArgs {
131 url,
132 auth: None,
133 header: vec![],
134 timeout: 30,
135 retries: 3,
136 allow_tools: None,
137 deny_tools: None,
138 };
139 run_convert_command(args, cli.verbose, cli.quiet).await
140 } else {
141 bail!("请提供 URL 或使用子命令")
142 }
143 }
144 }
145}
146
147async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
149 if args.allow_tools.is_some() && args.deny_tools.is_some() {
151 bail!("--allow-tools 和 --deny-tools 不能同时使用,请只选择其中一个");
152 }
153
154 let tool_filter = if let Some(allow_tools) = args.allow_tools.clone() {
156 ToolFilter::allow(allow_tools)
157 } else if let Some(deny_tools) = args.deny_tools.clone() {
158 ToolFilter::deny(deny_tools)
159 } else {
160 ToolFilter::default()
161 };
162
163 if !quiet {
164 eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", args.url);
165 if verbose {
166 eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
167 }
168 if let Some(ref allow_tools) = args.allow_tools {
170 eprintln!("🔧 工具白名单: {:?}", allow_tools);
171 }
172 if let Some(ref deny_tools) = args.deny_tools {
173 eprintln!("🔧 工具黑名单: {:?}", deny_tools);
174 }
175 }
176
177 let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
179
180 if !quiet {
181 let proto_name = match protocol {
182 super::protocol::McpProtocol::Sse => "SSE",
183 super::protocol::McpProtocol::Stream => "Streamable HTTP",
184 super::protocol::McpProtocol::Stdio => "Stdio",
185 };
186 eprintln!("🔍 检测到 {} 协议", proto_name);
187 eprintln!("🔗 建立连接...");
188 }
189
190 let http_client = create_http_client(&args)?;
192
193 let client_info = ClientInfo {
195 protocol_version: Default::default(),
196 capabilities: ClientCapabilities::builder()
197 .enable_experimental()
198 .enable_roots()
199 .enable_roots_list_changed()
200 .enable_sampling()
201 .build(),
202 ..Default::default()
203 };
204
205 let running = match protocol {
207 super::protocol::McpProtocol::Sse => {
208 let cfg = SseClientConfig {
209 sse_endpoint: args.url.clone().into(),
210 ..Default::default()
211 };
212 let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
213 client_info.serve(transport).await?
214 }
215 super::protocol::McpProtocol::Stream => {
216 let cfg = StreamableHttpClientTransportConfig {
217 uri: args.url.clone().into(),
218 ..Default::default()
219 };
220 let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
221 client_info.serve(transport).await?
222 }
223 super::protocol::McpProtocol::Stdio => {
224 bail!("Stdio 协议不支持通过 URL 转换,请使用命令行模式")
225 }
226 };
227
228 if !quiet {
229 eprintln!("✅ 连接成功,开始代理转换...");
230 eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
231 }
232
233 let proxy_handler = ProxyHandler::with_tool_filter(running, "cli".to_string(), tool_filter);
235 let stdio_transport = stdio();
236 let server = proxy_handler.serve(stdio_transport).await?;
237 server.waiting().await?;
238
239 Ok(())
240}
241
242
243
244
245
246
247fn create_http_client(args: &ConvertArgs) -> Result<reqwest::Client> {
249 let mut headers = reqwest::header::HeaderMap::new();
250
251 if let Some(auth) = &args.auth {
253 headers.insert("Authorization", auth.parse()?);
254 }
255
256 for (key, value) in &args.header {
258 headers.insert(key.parse::<reqwest::header::HeaderName>()?, value.parse()?);
259 }
260
261 let client = reqwest::Client::builder()
262 .default_headers(headers)
263 .timeout(tokio::time::Duration::from_secs(args.timeout))
264 .build()?;
265
266 Ok(client)
267}
268
269
270
271async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
273 if !quiet {
274 eprintln!("🔍 检查服务: {}", args.url);
275 }
276
277 match super::protocol::detect_mcp_protocol(&args.url).await {
278 Ok(protocol) => {
279 if !quiet {
280 eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
281 }
282 Ok(())
283 }
284 Err(e) => {
285 if !quiet {
286 eprintln!("❌ 服务检查失败: {}", e);
287 }
288 Err(e)
289 }
290 }
291}
292
293async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
295 let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
296
297 if quiet {
298 println!("{}", protocol);
299 } else {
300 eprintln!("{}", protocol);
301 }
302
303 Ok(())
304}