1use anyhow::{Result, bail};
5use clap::Parser;
6
7use crate::proxy::ProxyHandler;
8use rmcp::{
9 ServiceExt,
10 model::{ClientCapabilities, ClientInfo},
11 transport::{
12 StreamableHttpClientTransport, stdio,
13 streamable_http_client::StreamableHttpClientTransportConfig,
14 },
15};
16
17#[derive(Parser, Debug)]
19#[command(name = "mcp-proxy")]
20#[command(version = env!("CARGO_PKG_VERSION"))]
21#[command(about = "MCP 协议转换代理工具", long_about = None)]
22pub struct Cli {
23 #[command(subcommand)]
24 pub command: Option<Commands>,
25
26 #[arg(value_name = "URL", help = "MCP 服务的 URL 地址(直接模式)")]
28 pub url: Option<String>,
29
30 #[arg(short, long, global = true)]
32 pub verbose: bool,
33
34 #[arg(short, long, global = true)]
36 pub quiet: bool,
37}
38
39#[derive(clap::Subcommand, Debug)]
40pub enum Commands {
41 Convert(ConvertArgs),
43
44 Check(CheckArgs),
46
47 Detect(DetectArgs),
49}
50
51#[derive(Parser, Debug, Clone)]
53pub struct ConvertArgs {
54 #[arg(value_name = "URL", help = "MCP 服务的 URL 地址")]
56 pub url: String,
57
58 #[arg(short, long, help = "认证 header")]
60 pub auth: Option<String>,
61
62 #[arg(short = 'H', long, value_parser = parse_key_val, help = "自定义 HTTP headers (KEY=VALUE 格式)")]
64 pub header: Vec<(String, String)>,
65
66 #[arg(long, default_value = "30", help = "连接超时时间(秒)")]
68 pub timeout: u64,
69
70 #[arg(long, default_value = "3", help = "重试次数")]
72 pub retries: u32,
73}
74
75#[derive(Parser, Debug)]
77pub struct CheckArgs {
78 #[arg(value_name = "URL")]
80 pub url: String,
81
82 #[arg(short, long)]
84 pub auth: Option<String>,
85
86 #[arg(long, default_value = "10")]
88 pub timeout: u64,
89}
90
91#[derive(Parser, Debug)]
93pub struct DetectArgs {
94 #[arg(value_name = "URL")]
96 pub url: String,
97
98 #[arg(short, long)]
100 pub auth: Option<String>,
101}
102
103fn parse_key_val(s: &str) -> Result<(String, String)> {
105 let pos = s
106 .find('=')
107 .ok_or_else(|| anyhow::anyhow!("无效的 KEY=VALUE 格式: {}", s))?;
108 Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
109}
110
111pub async fn run_cli(cli: Cli) -> Result<()> {
113 match cli.command {
114 Some(Commands::Convert(args)) => run_convert_command(args, cli.verbose, cli.quiet).await,
115 Some(Commands::Check(args)) => run_check_command(args, cli.verbose, cli.quiet).await,
116 Some(Commands::Detect(args)) => run_detect_command(args, cli.verbose, cli.quiet).await,
117 None => {
118 if let Some(url) = cli.url {
120 let args = ConvertArgs {
121 url,
122 auth: None,
123 header: vec![],
124 timeout: 30,
125 retries: 3,
126 };
127 run_convert_command(args, cli.verbose, cli.quiet).await
128 } else {
129 bail!("请提供 URL 或使用子命令")
130 }
131 }
132 }
133}
134
135async fn run_convert_command(args: ConvertArgs, verbose: bool, quiet: bool) -> Result<()> {
137 if !quiet {
138 eprintln!("🚀 MCP-Stdio-Proxy: {} → stdio", args.url);
139 if verbose {
140 eprintln!("📡 超时: {}s, 重试: {}", args.timeout, args.retries);
141 }
142 }
143
144 let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
146
147 if !quiet {
148 let proto_name = match protocol {
149 super::protocol::McpProtocol::Sse => "SSE",
150 super::protocol::McpProtocol::Stream => "Streamable HTTP",
151 super::protocol::McpProtocol::Stdio => "Stdio",
152 };
153 eprintln!("🔍 检测到 {} 协议", proto_name);
154 eprintln!("🔗 建立连接...");
155 }
156
157 let http_client = create_http_client(&args)?;
159
160 let client_info = ClientInfo {
162 protocol_version: Default::default(),
163 capabilities: ClientCapabilities::builder()
164 .enable_experimental()
165 .enable_roots()
166 .enable_roots_list_changed()
167 .enable_sampling()
168 .build(),
169 ..Default::default()
170 };
171
172 let running = match protocol {
175 super::protocol::McpProtocol::Sse | super::protocol::McpProtocol::Stream => {
176 let cfg = StreamableHttpClientTransportConfig {
177 uri: args.url.clone().into(),
178 ..Default::default()
179 };
180 let transport = StreamableHttpClientTransport::with_client(http_client, cfg);
181 client_info.serve(transport).await?
182 }
183 super::protocol::McpProtocol::Stdio => {
184 bail!("Stdio 协议不支持通过 URL 转换,请使用命令行模式")
185 }
186 };
187
188 if !quiet {
189 eprintln!("✅ 连接成功,开始代理转换...");
190 eprintln!("💡 现在可以通过 stdin 发送 JSON-RPC 请求");
191 }
192
193 let proxy_handler = ProxyHandler::new(running);
195 let stdio_transport = stdio();
196 let server = proxy_handler.serve(stdio_transport).await?;
197 server.waiting().await?;
198
199 Ok(())
200}
201
202fn create_http_client(args: &ConvertArgs) -> Result<reqwest::Client> {
204 let mut headers = reqwest::header::HeaderMap::new();
205
206 if let Some(auth) = &args.auth {
208 headers.insert("Authorization", auth.parse()?);
209 }
210
211 for (key, value) in &args.header {
213 headers.insert(key.parse::<reqwest::header::HeaderName>()?, value.parse()?);
214 }
215
216 let client = reqwest::Client::builder()
217 .default_headers(headers)
218 .timeout(tokio::time::Duration::from_secs(args.timeout))
219 .build()?;
220
221 Ok(client)
222}
223
224async fn run_check_command(args: CheckArgs, _verbose: bool, quiet: bool) -> Result<()> {
226 if !quiet {
227 eprintln!("🔍 检查服务: {}", args.url);
228 }
229
230 match super::protocol::detect_mcp_protocol(&args.url).await {
231 Ok(protocol) => {
232 if !quiet {
233 eprintln!("✅ 服务正常,检测到 {} 协议", protocol);
234 }
235 Ok(())
236 }
237 Err(e) => {
238 if !quiet {
239 eprintln!("❌ 服务检查失败: {}", e);
240 }
241 Err(e)
242 }
243 }
244}
245
246async fn run_detect_command(args: DetectArgs, _verbose: bool, quiet: bool) -> Result<()> {
248 let protocol = super::protocol::detect_mcp_protocol(&args.url).await?;
249
250 if quiet {
251 println!("{}", protocol);
252 } else {
253 eprintln!("{}", protocol);
254 }
255
256 Ok(())
257}