//! rust-proxy 0.1.0
//!
//! A simple HTTP/HTTPS proxy server written in Rust.
use clap::Parser;
use std::io;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};

// Command-line arguments for the proxy, parsed with clap's derive API.
// The `///` doc comments below are intentionally left untranslated: clap can
// use doc comments as `--help` text, so editing them would change program output.
// Translation of the struct doc: "Proxy server supporting HTTP and HTTPS".
/// 支持HTTP和HTTPS的代理服务器
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    // Translation: "Address the proxy server listens on".
    // `main` passes this to `run_proxy`, which binds a `TcpListener` to it;
    // defaults to "127.0.0.1:8080".
    /// 代理服务器监听的地址
    #[arg(short, long, default_value = "127.0.0.1:8080")]
    listen: String,
}

/// Bind `listen_addr` and serve proxy clients forever, one spawned task per connection.
///
/// # Errors
/// Returns an error only if binding the listening socket fails. Errors from an
/// individual `accept` call are logged and the loop keeps running, so a single
/// failed accept (e.g. a client that reset before being accepted, or a temporary
/// file-descriptor shortage) does not take the whole proxy down.
async fn run_proxy(listen_addr: &str) -> io::Result<()> {
    let listener = TcpListener::bind(listen_addr).await?;
    println!("HTTP/HTTPS代理服务器启动,监听地址: {}", listen_addr);

    loop {
        let (client_stream, client_addr) = match listener.accept().await {
            Ok(accepted) => accepted,
            Err(e) => {
                eprintln!("接受连接失败: {}", e);
                // Yield so a persistent accept error cannot monopolise the runtime
                // thread while other tasks get a chance to release resources.
                tokio::task::yield_now().await;
                continue;
            }
        };
        println!("收到来自 {} 的连接", client_addr);

        // Each client is handled independently; its errors are only logged.
        tokio::spawn(async move {
            if let Err(e) = handle_client(client_stream).await {
                eprintln!("处理客户端连接时出错: {}", e);
            }
        });
    }
}

/// Read the request head from a freshly accepted client and dispatch it.
///
/// `CONNECT` requests go to [`handle_connect_request`]. Everything else is
/// treated as a plain HTTP proxy request and goes to [`handle_http_request`],
/// together with the raw bytes already read. A client that closes before
/// sending anything is ignored.
///
/// # Errors
/// Propagates read errors from the client and any error returned by the
/// dispatched handler.
async fn handle_client(mut client_stream: TcpStream) -> io::Result<()> {
    let head = read_request_head(&mut client_stream).await?;
    if head.is_empty() {
        return Ok(());
    }

    let request = String::from_utf8_lossy(&head);

    // A request starting with `CONNECT ` asks for an HTTPS tunnel.
    if request.starts_with("CONNECT ") {
        println!("收到HTTPS CONNECT请求:\n{}", request);
        handle_connect_request(client_stream, &request).await
    } else {
        println!("收到HTTP请求:\n{}", request);
        handle_http_request(client_stream, &head, &request).await
    }
}

/// Read from `stream` until the end of the HTTP header section (`\r\n\r\n`)
/// has been seen, the peer stops sending (EOF), or `MAX_HEAD_BYTES` have been
/// buffered, whichever comes first.
///
/// One `read` may return only part of the headers because TCP does not preserve
/// message boundaries, so a single read cannot reliably see the `Host` header.
/// All bytes read are returned, including any body bytes that arrived in the
/// same segments, so callers can forward them unchanged.
async fn read_request_head(stream: &mut TcpStream) -> io::Result<Vec<u8>> {
    const MAX_HEAD_BYTES: usize = 64 * 1024;
    const HEAD_END: &[u8] = b"\r\n\r\n";

    let mut head = Vec::with_capacity(4096);
    let mut chunk = [0u8; 4096];
    loop {
        let n = stream.read(&mut chunk).await?;
        if n == 0 {
            break;
        }
        // Rescan only the new bytes plus a small overlap, in case the terminator
        // straddles two reads.
        let scan_from = head.len().saturating_sub(HEAD_END.len() - 1);
        head.extend_from_slice(&chunk[..n]);
        let head_complete = head[scan_from..]
            .windows(HEAD_END.len())
            .any(|window| window == HEAD_END);
        if head_complete || head.len() >= MAX_HEAD_BYTES {
            break;
        }
    }
    Ok(head)
}

/// Handle an HTTPS `CONNECT` request by opening a raw TCP tunnel.
///
/// Parses the `host:port` target from `request` and connects to it. It then
/// answers `200 Connection Established` and relays bytes in both directions
/// until both sides are done. It answers `400 Bad Request` when no target can
/// be parsed and `502 Bad Gateway` when the upstream connection fails.
///
/// NOTE(review): only the request *text* is passed in, so any bytes the client
/// sent after the CONNECT header in the same read are not forwarded upstream.
///
/// # Errors
/// Returns an error only if writing a status line back to the client fails.
/// Relay errors are logged and treated as the normal end of the tunnel.
async fn handle_connect_request(mut client_stream: TcpStream, request: &str) -> io::Result<()> {
    let target_addr = match parse_connect_target(request) {
        Some(addr) => addr,
        None => {
            println!("无法解析CONNECT目标地址");
            // Refuse rather than guess: falling back to a hard-coded host would
            // silently tunnel the client's traffic somewhere it never asked for.
            client_stream
                .write_all(b"HTTP/1.1 400 Bad Request\r\n\r\n")
                .await?;
            return Ok(());
        }
    };

    println!("HTTPS隧道目标: {}", target_addr);

    let mut target_stream = match TcpStream::connect(&target_addr).await {
        Ok(stream) => stream,
        Err(e) => {
            eprintln!("连接目标服务器失败: {}", e);
            let error_response = "HTTP/1.1 502 Bad Gateway\r\n\r\n";
            client_stream.write_all(error_response.as_bytes()).await?;
            return Ok(());
        }
    };

    // Tell the client the tunnel is ready; from here on the bytes are opaque (TLS).
    let success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
    client_stream.write_all(success_response.as_bytes()).await?;

    println!("HTTPS隧道已建立,开始转发数据");

    // `copy_bidirectional` shuts down one side's write half once the other side
    // reaches EOF, so a half-close is propagated. Plain `tokio::io::copy` never
    // shuts its writer down, which could leave the peer waiting forever.
    match tokio::io::copy_bidirectional(&mut client_stream, &mut target_stream).await {
        Ok((sent, received)) => {
            println!("HTTPS隧道转发完成: 发送 {} 字节, 接收 {} 字节", sent, received);
        }
        Err(e) => {
            eprintln!("HTTPS隧道转发时出错: {}", e);
        }
    }

    Ok(())
}

/// Handle a plain-HTTP proxy request.
///
/// Determines the upstream `host:port` from `request` (via `parse_http_target`)
/// and connects to it. It forwards `request_data` (the raw bytes already read
/// from the client) unchanged, then relays bytes in both directions until both
/// sides are done. It answers `502 Bad Gateway` if the upstream connection
/// cannot be opened.
///
/// NOTE: the whole client connection is pinned to the first request's target.
/// Later keep-alive requests on the same connection are forwarded there too.
///
/// # Errors
/// Returns an error if writing the 502 response or the initial request bytes
/// fails. Relay errors are logged and treated as the normal end of the exchange.
async fn handle_http_request(
    mut client_stream: TcpStream,
    request_data: &[u8],
    request: &str,
) -> io::Result<()> {
    let target_addr = parse_http_target(request);

    println!("HTTP请求目标: {}", target_addr);

    let mut target_stream = match TcpStream::connect(&target_addr).await {
        Ok(stream) => stream,
        Err(e) => {
            eprintln!("连接目标服务器失败: {}", e);
            let body = "无法连接到目标服务器";
            // The body is non-ASCII UTF-8, so the charset must be declared. The
            // explicit length and `Connection: close` let the client delimit it.
            let error_response = format!(
                "HTTP/1.1 502 Bad Gateway\r\n\
                 Content-Type: text/plain; charset=utf-8\r\n\
                 Content-Length: {}\r\n\
                 Connection: close\r\n\r\n{}",
                body.len(),
                body
            );
            client_stream.write_all(error_response.as_bytes()).await?;
            return Ok(());
        }
    };

    // Replay the bytes already consumed from the client.
    target_stream.write_all(request_data).await?;

    // `copy_bidirectional` propagates a half-close (EOF on one side shuts down
    // the other side's write half). Two independent `tokio::io::copy` calls
    // never do that.
    match tokio::io::copy_bidirectional(&mut client_stream, &mut target_stream).await {
        Ok((sent, received)) => {
            println!("HTTP转发完成: 发送 {} 字节, 接收 {} 字节", sent, received);
        }
        Err(e) => {
            eprintln!("HTTP转发时出错: {}", e);
        }
    }

    Ok(())
}

/// Extract the `host:port` tunnel target from a `CONNECT` request.
///
/// The request line has the form `CONNECT host:port HTTP/1.1`. When the
/// client omits the port, `:443` (the HTTPS default) is appended. A bracketed
/// IPv6 literal such as `[::1]` is recognised as having no port even though it
/// contains colons. `CONNECT ` lines without a target are skipped.
///
/// Returns `None` if no `CONNECT` line with a target is found.
fn parse_connect_target(request: &str) -> Option<String> {
    let target = request
        .lines()
        .filter(|line| line.starts_with("CONNECT "))
        .find_map(|line| line.split_whitespace().nth(1))?;

    // For `[v6addr]` / `[v6addr]:port` only a colon after `]` separates a port.
    let has_port = if target.starts_with('[') {
        target.contains("]:")
    } else {
        target.contains(':')
    };

    Some(if has_port {
        target.to_string()
    } else {
        format!("{}:443", target)
    })
}

/// Determine the upstream `host:port` for a plain-HTTP proxy request.
///
/// Resolution order (RFC 7230 §5.4):
/// 1. Start from `localhost:80`.
/// 2. If a `Host` header is present, use its host and, if given, its port.
///    Header names are compared case-insensitively and whitespace around the
///    value is ignored.
/// 3. If the request line carries an absolute-form target (`http://…` or
///    `https://…`, any method), its authority overrides the `Host` header.
///    The port then defaults to 80 for `http` and 443 for `https`.
///
/// Only the header section (up to the first empty line) is searched for
/// `Host`, so a body line that merely looks like a header is never used.
/// Bracketed IPv6 literals (`[::1]:8080`) keep their brackets so the result can
/// be passed straight to `TcpStream::connect`. An unparsable port is ignored
/// and the current default is kept.
fn parse_http_target(request: &str) -> String {
    // Tolerate stray empty lines before the request line (RFC 7230 §3.5).
    let mut lines = request.lines().skip_while(|line| line.is_empty());
    let request_line = lines.next().unwrap_or("");

    let mut target_host = "localhost";
    let mut target_port: u16 = 80;

    let host_header = lines
        .take_while(|line| !line.is_empty())
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.eq_ignore_ascii_case("host"))
        .map(|(_, value)| value.trim());

    if let Some(value) = host_header {
        let (host, port) = split_host_port(value);
        target_host = host;
        if let Some(port) = port {
            target_port = port;
        }
    }

    // Request line: `METHOD request-target HTTP-version`.
    if let Some(url) = request_line.split_whitespace().nth(1) {
        let absolute = strip_scheme(url, "https://")
            .map(|rest| (rest, 443))
            .or_else(|| strip_scheme(url, "http://").map(|rest| (rest, 80)));
        if let Some((rest, default_port)) = absolute {
            // The authority ends where the path, query or fragment begins.
            let authority = rest
                .split(|c: char| matches!(c, '/' | '?' | '#'))
                .next()
                .unwrap_or("");
            if !authority.is_empty() {
                let (host, port) = split_host_port(authority);
                target_host = host;
                target_port = port.unwrap_or(default_port);
            }
        }
    }

    format!("{}:{}", target_host, target_port)
}

/// Strip an ASCII-case-insensitive URL scheme prefix such as `"http://"`.
fn strip_scheme<'a>(url: &'a str, scheme: &str) -> Option<&'a str> {
    let head = url.get(..scheme.len())?;
    if head.eq_ignore_ascii_case(scheme) {
        url.get(scheme.len()..)
    } else {
        None
    }
}

/// Split an authority (`host`, `host:port`, `[v6]`, `[v6]:port`) into its host
/// and optional numeric port. A present-but-invalid port yields `None`.
fn split_host_port(authority: &str) -> (&str, Option<u16>) {
    if authority.starts_with('[') {
        if let Some(end) = authority.find(']') {
            let (host, rest) = authority.split_at(end + 1);
            return (host, rest.strip_prefix(':').and_then(|port| port.parse().ok()));
        }
    }
    match authority.split_once(':') {
        Some((host, port)) => (host, port.parse().ok()),
        None => (authority, None),
    }
}

#[tokio::main]
async fn main() -> io::Result<()> {
    // 解析命令行参数
    let args = Args::parse();

    println!("启动HTTP代理服务器...");
    run_proxy(&args.listen).await
}