use clap::Parser;
use std::io;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tracing::{error, info, warn};
// Command-line options for the proxy binary, parsed with clap's derive API.
// NOTE: plain `//` comments are used deliberately — clap turns `///` doc comments
// into `--help` text, so adding them would change the CLI's output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
// Address the proxy listens on (`-l` / `--listen`); `main` passes it to `run_proxy`,
// which hands it to `TcpListener::bind`. Defaults to 127.0.0.1:8080.
#[arg(short, long, default_value = "127.0.0.1:8080")]
listen: String,
}
async fn run_proxy(listen_addr: &str) -> io::Result<()> {
let listener = TcpListener::bind(listen_addr).await?;
info!("HTTP/HTTPS代理服务器启动,监听地址: {}", listen_addr);
loop {
let (client_stream, client_addr) = listener.accept().await?;
info!("收到来自 {} 的连接", client_addr);
tokio::spawn(async move {
if let Err(e) = handle_client(client_stream, client_addr).await {
error!("处理客户端连接时出错: {}", e);
}
});
}
}
/// Reads the request head from a freshly accepted client and dispatches it to the
/// CONNECT (HTTPS tunnel) path or the plain-HTTP forwarding path.
///
/// The head is accumulated across as many reads as needed (see `read_request_head`), so
/// headers that arrive in several TCP segments are still parsed in full.
///
/// NOTE: on the CONNECT path only the decoded head text is handed on, so any extra bytes
/// that arrived together with the head are not forwarded into the tunnel.
///
/// # Errors
///
/// Propagates I/O errors from reading the client and any error returned by the selected
/// handler.
async fn handle_client(
    mut client_stream: TcpStream,
    client_addr: std::net::SocketAddr,
) -> io::Result<()> {
    let request_data = read_request_head(&mut client_stream).await?;
    // The peer closed the connection without sending anything.
    if request_data.is_empty() {
        return Ok(());
    }
    let request_str = String::from_utf8_lossy(&request_data);
    if request_str.starts_with("CONNECT ") {
        info!("来自 {} 的HTTPS CONNECT请求", client_addr);
        handle_connect_request(client_stream, request_str).await
    } else {
        info!("来自 {} 的HTTP请求", client_addr);
        handle_http_request(client_stream, &request_data).await
    }
}

/// Upper bound on the bytes buffered while waiting for the end of the request head, so a
/// client that never sends the terminating blank line cannot grow the buffer unboundedly.
const MAX_REQUEST_HEAD_BYTES: usize = 64 * 1024;

/// Reads from `stream` until the blank line ending the HTTP request head has been
/// received, the peer closes the connection, or `MAX_REQUEST_HEAD_BYTES` are buffered.
///
/// Returns every byte read, which may include the start of a body if it arrived in the
/// same segment. An empty vector means the peer closed before sending anything.
async fn read_request_head(stream: &mut TcpStream) -> io::Result<Vec<u8>> {
    let mut data = Vec::with_capacity(4096);
    let mut chunk = [0u8; 4096];
    loop {
        let n = stream.read(&mut chunk).await?;
        if n == 0 {
            break;
        }
        // Rescan only the new bytes plus a small overlap, because the terminator
        // (at most 3 bytes: "\n\r\n") may straddle two reads.
        let scan_from = data.len().saturating_sub(2);
        data.extend_from_slice(&chunk[..n]);
        if contains_head_terminator(&data[scan_from..]) || data.len() >= MAX_REQUEST_HEAD_BYTES
        {
            break;
        }
    }
    Ok(data)
}

/// Returns `true` if `bytes` contains the blank line that terminates an HTTP head.
///
/// Bare-LF line endings (`\n\n`) are accepted alongside CRLF, as RFC 7230 §3.5 permits.
fn contains_head_terminator(bytes: &[u8]) -> bool {
    bytes.windows(2).any(|w| w == b"\n\n") || bytes.windows(3).any(|w| w == b"\n\r\n")
}
/// Handles an HTTPS `CONNECT` request by opening a raw TCP tunnel to the requested
/// target and relaying bytes in both directions until both sides are finished.
///
/// Replies `400 Bad Request` when no target can be parsed, `502 Bad Gateway` when the
/// upstream connection fails, and `200 Connection Established` once the tunnel is up.
/// When one side reaches EOF, its peer's write half is shut down so the half-close is
/// propagated through the tunnel.
///
/// # Errors
///
/// Returns `InvalidInput` when the request has no usable `CONNECT` target, the connect
/// error when the target is unreachable, or the error from writing the `200` reply.
/// Errors while relaying tunnel data are only logged.
async fn handle_connect_request(
    mut client_stream: TcpStream,
    request_str: std::borrow::Cow<'_, str>,
) -> io::Result<()> {
    let target_addr = match parse_connect_target(&request_str) {
        Some(addr) => addr,
        None => {
            // Best effort: tell the client why the connection is closing; the parse
            // failure (not a write failure) is what we report to the caller.
            let _ = client_stream
                .write_all(b"HTTP/1.1 400 Bad Request\r\n\r\n")
                .await;
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "无法解析CONNECT目标地址",
            ));
        }
    };
    info!("HTTPS隧道目标: {}", target_addr);
    let mut target_stream = match TcpStream::connect(&target_addr).await {
        Ok(stream) => stream,
        Err(e) => {
            error!("连接目标服务器失败: {}", e);
            // Best effort, so the connect error is the one returned.
            let error_response = "HTTP/1.1 502 Bad Gateway\r\n\r\n";
            let _ = client_stream.write_all(error_response.as_bytes()).await;
            return Err(e);
        }
    };
    let success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
    client_stream.write_all(success_response.as_bytes()).await?;
    info!("HTTPS隧道已建立,开始转发数据");
    // Unlike two independent `copy` calls, `copy_bidirectional` shuts down the opposite
    // write half on EOF, so half-closed TLS connections terminate cleanly.
    match tokio::io::copy_bidirectional(&mut client_stream, &mut target_stream).await {
        Ok((sent, received)) => {
            info!("HTTPS隧道转发完成: 发送 {} 字节, 接收 {} 字节", sent, received);
        }
        Err(e) => {
            warn!("HTTPS隧道转发时出错: {}", e);
        }
    }
    Ok(())
}
/// Forwards a plain-HTTP proxy request: connects to the target derived from the request
/// head, replays the already-read bytes `request_data`, then relays traffic in both
/// directions until both sides are finished.
///
/// Replies `502 Bad Gateway` when the upstream connection fails. When one side reaches
/// EOF, its peer's write half is shut down so the half-close is propagated.
///
/// NOTE: the whole client connection is piped to this one target, so later requests on
/// the same keep-alive connection go to the same upstream even if they name another host.
///
/// # Errors
///
/// Returns the connect error when the target is unreachable, or the error from replaying
/// `request_data`. Errors while relaying are only logged.
async fn handle_http_request(
    mut client_stream: TcpStream,
    request_data: &[u8],
) -> io::Result<()> {
    let request_str = String::from_utf8_lossy(request_data);
    let target_addr = parse_http_target(&request_str);
    info!("HTTP请求目标: {}", target_addr);
    let mut target_stream = match TcpStream::connect(&target_addr).await {
        Ok(stream) => stream,
        Err(e) => {
            error!("连接目标服务器失败: {}", e);
            // Best effort, so the connect error is the one returned.
            let error_response = "HTTP/1.1 502 Bad Gateway\r\n\r\n";
            let _ = client_stream.write_all(error_response.as_bytes()).await;
            return Err(e);
        }
    };
    target_stream.write_all(request_data).await?;
    match tokio::io::copy_bidirectional(&mut client_stream, &mut target_stream).await {
        Ok((copied, received)) => {
            // Count the replayed request head too; it was sent before the relay started.
            // usize -> u64 is lossless on every platform Rust supports.
            let sent = copied + request_data.len() as u64;
            info!("HTTP转发完成: 发送 {} 字节, 接收 {} 字节", sent, received);
        }
        Err(e) => {
            warn!("HTTP转发时出错: {}", e);
        }
    }
    Ok(())
}
/// Extracts the tunnel target (`host:port`) from an HTTP `CONNECT` request.
///
/// The target is the second whitespace-separated token of the first line that starts
/// with `"CONNECT "` and has such a token. When the target names no port, `:443` is
/// appended. Bracketed IPv6 literals (`[::1]`) are recognised, so the colons inside the
/// brackets are not mistaken for a port separator.
///
/// Returns `None` when no such line exists.
fn parse_connect_target(request: &str) -> Option<String> {
    let target = request
        .lines()
        .filter(|line| line.starts_with("CONNECT "))
        .find_map(|line| line.split_whitespace().nth(1))?;
    if connect_authority_has_port(target) {
        Some(target.to_string())
    } else {
        Some(format!("{}:443", target))
    }
}

/// Returns `true` when `authority` carries an explicit `:port` component.
fn connect_authority_has_port(authority: &str) -> bool {
    match authority.rfind(']') {
        // IPv6 literal: only a colon after the closing bracket introduces a port.
        Some(end) => authority[end + 1..].starts_with(':'),
        None => authority.contains(':'),
    }
}
/// Determines the upstream `host:port` for a plain-HTTP proxy request.
///
/// Resolution order:
/// 1. The `Host` header supplies the initial host and optional port. The header name is
///    matched case-insensitively, whitespace after the colon is optional, and only the
///    header section (up to the first empty line) is searched.
/// 2. If the request line's target is absolute-form (`http://…` or `https://…`, with the
///    scheme matched case-insensitively), its authority overrides the header, as
///    RFC 7230 §5.4 requires of proxies. The scheme selects the default port (80 or 443)
///    unless the authority names one. Any `user:password@` prefix is dropped.
///
/// Falls back to `localhost:80` when neither source names a host. Bracketed IPv6
/// literals such as `[::1]:8080` keep their brackets, so the result can be passed
/// straight to `TcpStream::connect`.
fn parse_http_target(request: &str) -> String {
    let mut target_host = "localhost";
    let mut target_port: u16 = 80;

    let mut lines = request.lines();
    let request_line = lines.next().unwrap_or("");

    // Headers end at the first empty line; anything after it is body data.
    let host_header = lines
        .take_while(|line| !line.is_empty())
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.eq_ignore_ascii_case("host"))
        .map(|(_, value)| value.trim());
    if let Some(authority) = host_header {
        let (host, port) = split_host_port(authority);
        target_host = host;
        if let Some(port) = port {
            target_port = port;
        }
    }

    let absolute_target = request_line
        .split_whitespace()
        .nth(1)
        .and_then(strip_http_scheme);
    if let Some((default_port, rest)) = absolute_target {
        // The authority ends at the first path, query or fragment delimiter.
        let authority_end = rest
            .find(|c: char| matches!(c, '/' | '?' | '#'))
            .unwrap_or(rest.len());
        let authority = &rest[..authority_end];
        let authority = authority
            .rsplit_once('@')
            .map_or(authority, |(_, host_port)| host_port);
        let (host, port) = split_host_port(authority);
        target_host = host;
        target_port = port.unwrap_or(default_port);
    }

    format!("{}:{}", target_host, target_port)
}

/// If `url` starts with `http://` or `https://` (ASCII case-insensitive), returns that
/// scheme's default port and the remainder of `url` after the `://`.
fn strip_http_scheme(url: &str) -> Option<(u16, &str)> {
    [("http://", 80u16), ("https://", 443u16)]
        .iter()
        .find_map(|&(scheme, port)| {
            let prefix = url.get(..scheme.len())?;
            if prefix.eq_ignore_ascii_case(scheme) {
                Some((port, &url[scheme.len()..]))
            } else {
                None
            }
        })
}

/// Splits an authority such as `example.com:8080` or `[::1]:8080` into host and port.
///
/// For non-bracketed hosts the first `:` is the separator. The port is `None` when it is
/// absent or not a valid `u16`; the host is still everything before the separator.
fn split_host_port(authority: &str) -> (&str, Option<u16>) {
    let separator = match authority.rfind(']') {
        // IPv6 literal: only a colon after the closing bracket introduces a port.
        Some(end) => authority[end + 1..].find(':').map(|i| end + 1 + i),
        None => authority.find(':'),
    };
    match separator {
        Some(i) => (&authority[..i], authority[i + 1..].parse().ok()),
        None => (authority, None),
    }
}
#[tokio::main]
async fn main() -> io::Result<()> {
tracing_subscriber::fmt::init();
let args = Args::parse();
info!("启动HTTP/HTTPS代理服务器...");
run_proxy(&args.listen).await
}