1use clap::Parser;
5use anyhow::{Result, bail};
6
7use rmcp::{
8 ServiceExt,
9 model::{ClientCapabilities, ClientInfo},
10 transport::{SseClientTransport, StreamableHttpClientTransport, sse_client::SseClientConfig, streamable_http_client::StreamableHttpClientTransportConfig, stdio},
11};
12use crate::proxy::ProxyHandler;
13
14#[derive(Parser, Debug)]
16#[command(name = "mcp-proxy")]
17#[command(version = env!("CARGO_PKG_VERSION"))]
18#[command(about = "MCP 协议转换代理工具", long_about = None)]
19pub struct Cli {
20 #[command(subcommand)]
21 pub command: Option<Commands>,
22
23 #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
25 pub url: Option<String>,
26
27 #[arg(short, long, global = true)]
29 pub verbose: bool,
30
31 #[arg(short, long, global = true)]
33 pub quiet: bool,
34}
35
36#[derive(clap::Subcommand, Debug)]
37pub enum Commands {
38 Convert(ConvertArgs),
40
41 Check(CheckArgs),
43
44 Detect(DetectArgs),
46}
47
48#[derive(Parser, Debug, Clone)]
50pub struct ConvertArgs {
51 #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
53 pub url: String,
54
55 #[arg(short, long, help = "认证 header")]
57 pub auth: Option<String>,
58
59 #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
61 pub header: Vec<(String, String)>,
62
63 #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
65 pub timeout: u64,
66
67 #[arg(long, default_value = "3", help = "重试次数")]
69 pub retries: u32,
70}
71
72#[derive(Parser, Debug)]
74pub struct CheckArgs {
75 #[arg(value_name = "URL")]
77 pub url: String,
78
79 #[arg(short, long)]
81 pub auth: Option<String>,
82
83 #[arg(long, default_value = "10")]
85 pub timeout: u64,
86}
87
88#[derive(Parser, Debug)]
90pub struct DetectArgs {
91 #[arg(value_name = "URL")]
93 pub url: String,
94
95 #[arg(short, long)]
97 pub auth: Option<String>,
98}
99
100fn parse_key_val(s: &str) -> Result<(String, String)> {
102 let pos = s.find('=')
103 .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
104 Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
105}
106
107pub async fn run_cli(cli: Cli) -> Result<()> {
109 match cli.command {
110 Some(Commands::Convert(args)) => {
111 run_convert_command(args, cli.verbose, cli.quiet).await
112 }
113 Some(Commands::Check(args)) => {
114 run_check_command(args, cli.verbose, cli.quiet).await
115 }
116 Some(Commands::Detect(args)) => {
117 run_detect_command(args, cli.verbose, cli.quiet).await
118 }
119 None => {
120 if let Some(url) = cli.url {
122 let args = ConvertArgs {
123 url,
124 auth: None,
125 header: vec![],
126 timeout: 30,
127 retries: 3,
128 };
129 run_convert_command(args, cli.verbose, cli.quiet).await
130 } else {
131 bail!("请提供 URL 或使用子命令")
132 }
133 }
134 }
135}
136
137async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
139 if !quiet {
140 eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", args.url);
141 if verbose {
142 eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
143 }
144 }
145
146 let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
148
149 if !quiet {
150 let proto_name = match protocol {
151 super::protocol::McpProtocol::Sse => "SSE",
152 super::protocol::McpProtocol::Stream => "Streamable HTTP",
153 super::protocol::McpProtocol::Stdio => "Stdio",
154 };
155 eprintln!("🔍 检测到 {} 协议", proto_name);
156 eprintln!("🔗 建立连接...");
157 }
158
159 let http_client = create_http_client(&args)?;
161
162 let client_info = ClientInfo {
164 protocol_version: Default::default(),
165 capabilities: ClientCapabilities::builder()
166 .enable_experimental()
167 .enable_roots()
168 .enable_roots_list_changed()
169 .enable_sampling()
170 .build(),
171 ..Default::default()
172 };
173
174 let running = match protocol {
176 super::protocol::McpProtocol::Sse => {
177 let cfg = SseClientConfig {
178 sse_endpoint: args.url.clone().into(),
179 ..Default::default()
180 };
181 let transport = SseClientTransport::start_with_client(http_client, cfg).await?;
182 client_info.serve(transport).await?
183 }
184 super::protocol::McpProtocol::Stream => {
185 let cfg = StreamableHttpClientTransportConfig {
186 uri: args.url.clone().into(),
187 ..Default::default()
188 };
189 let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
190 client_info.serve(transport).await?
191 }
192 super::protocol::McpProtocol::Stdio => {
193 bail!("Stdio 协议不支持通过 URL 转换,请使用命令行模式")
194 }
195 };
196
197 if !quiet {
198 eprintln!("✅ 连接成功,开始代理转换...");
199 eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
200 }
201
202 let proxy_handler = ProxyHandler::new(running);
204 let stdio_transport = stdio();
205 let server = proxy_handler.serve(stdio_transport).await?;
206 server.waiting().await?;
207
208 Ok(())
209}
210
211
212
213
214
215
216fn create_http_client(args: &ConvertArgs) -> Result<reqwest::Client> {
218 let mut headers = reqwest::header::HeaderMap::new();
219
220 if let Some(auth) = &args.auth {
222 headers.insert("Authorization", auth.parse()?);
223 }
224
225 for (key, value) in &args.header {
227 headers.insert(key.parse::<reqwest::header::HeaderName>()?, value.parse()?);
228 }
229
230 let client = reqwest::Client::builder()
231 .default_headers(headers)
232 .timeout(tokio::time::Duration::from_secs(args.timeout))
233 .build()?;
234
235 Ok(client)
236}
237
238
239
240async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
242 if !quiet {
243 eprintln!("🔍 检查服务: {}", args.url);
244 }
245
246 match super::protocol::detect_mcp_protocol(&args.url).await {
247 Ok(protocol) => {
248 if !quiet {
249 eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
250 }
251 Ok(())
252 }
253 Err(e) => {
254 if !quiet {
255 eprintln!("❌ 服务检查失败: {}", e);
256 }
257 Err(e)
258 }
259 }
260}
261
262async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
264 let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
265
266 if quiet {
267 println!("{}", protocol);
268 } else {
269 eprintln!("{}", protocol);
270 }
271
272 Ok(())
273}