use crate::{ConnectionPool, Result};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use bytes::Bytes;
use http_body_util::{Full, BodyExt};
use log::{info, error, debug};
/// Settings for a [`ProxyServer`]: where to listen and how chatty to be.
#[derive(Debug, Clone)]
pub struct ProxyConfig {
    /// Socket address the proxy listener binds to.
    pub addr: SocketAddr,
    /// When `true`, the connection handlers emit extra `debug!` log lines.
    pub verbose: bool,
}

impl ProxyConfig {
    /// Creates a configuration that listens on `addr`.
    pub fn new(addr: SocketAddr, verbose: bool) -> Self {
        ProxyConfig { addr, verbose }
    }

    /// Creates a configuration that listens on the IPv4 loopback address
    /// (`127.0.0.1`) at the given `port`.
    pub fn localhost(port: u16, verbose: bool) -> Self {
        // Built from typed parts, so there is no string to parse and no
        // failure path.
        let addr = SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, port));
        ProxyConfig { addr, verbose }
    }
}
/// A proxy server that accepts TCP connections and hands each one to a
/// spawned handler task (see `run`).
pub struct ProxyServer {
/// Listen address and verbosity flag used by `run`.
config: ProxyConfig,
/// Connection pool; a clone of this `Arc` is passed to every connection handler.
connection_pool: Arc<ConnectionPool>,
/// Incremented when a connection is accepted and decremented when its
/// handler task finishes.
total_connections: Arc<AtomicU64>,
/// Request counter; a clone of this `Arc` is passed to every connection
/// handler, which increments it.
total_requests: Arc<AtomicU64>,
}
impl ProxyServer {
pub fn new(config: ProxyConfig) -> Self {
Self {
config,
connection_pool: Arc::new(ConnectionPool::new()),
total_connections: Arc::new(AtomicU64::new(0)),
total_requests: Arc::new(AtomicU64::new(0)),
}
}
pub fn with_pool(config: ProxyConfig, pool: ConnectionPool) -> Self {
Self {
config,
connection_pool: Arc::new(pool),
total_connections: Arc::new(AtomicU64::new(0)),
total_requests: Arc::new(AtomicU64::new(0)),
}
}
pub async fn run(&self) -> Result<()> {
let listener = TcpListener::bind(self.config.addr).await?;
info!("HTTP Proxy Server listening on http://{}", self.config.addr);
loop {
let (stream, remote_addr) = listener.accept().await?;
let connections = Arc::clone(&self.total_connections);
let requests = Arc::clone(&self.total_requests);
let pool = Arc::clone(&self.connection_pool);
connections.fetch_add(1, Ordering::Relaxed);
info!("Accepted connection from {} (total: {})",
remote_addr, connections.load(Ordering::Relaxed));
let verbose_clone = self.config.verbose;
tokio::task::spawn(async move {
if let Err(err) = self::handle_connection(
stream,
requests,
pool,
verbose_clone
).await {
error!("Failed to handle connection: {:?}", err);
}
connections.fetch_sub(1, Ordering::Relaxed);
});
}
}
pub fn total_connections(&self) -> u64 {
self.total_connections.load(Ordering::Relaxed)
}
pub fn total_requests(&self) -> u64 {
self.total_requests.load(Ordering::Relaxed)
}
pub fn connection_pool(&self) -> &Arc<ConnectionPool> {
&self.connection_pool
}
}
/// Dispatches one accepted client connection.
///
/// The first bytes are peeked (not consumed). If they spell `CONNECT `, the
/// connection is handed to `handle_https_tunnel`; otherwise it is served as
/// HTTP/1.1, with each request forwarded through `handle_http_request`.
///
/// `requests` is incremented once per HTTP request served on the connection,
/// so keep-alive connections that carry several requests are counted
/// correctly and connections that never send a request are not counted.
///
/// # Errors
///
/// Returns an error if peeking at the socket fails, or whatever
/// `handle_https_tunnel` returns. Errors from serving the HTTP connection are
/// logged and not propagated.
async fn handle_connection(
    stream: TcpStream,
    requests: Arc<AtomicU64>,
    pool: Arc<ConnectionPool>,
    verbose: bool,
) -> Result<()> {
    const CONNECT_PREFIX: &[u8] = b"CONNECT ";

    // Only the method token matters here, so peek just enough bytes to see it.
    // NOTE: a single peek is used; if the client's first segment is shorter
    // than the prefix, the connection is treated as plain HTTP.
    let mut prefix = [0u8; 8];
    let n = stream.peek(&mut prefix).await?;
    if n == 0 {
        // Peer closed before sending anything.
        return Ok(());
    }

    if prefix[..n].starts_with(CONNECT_PREFIX) {
        return handle_https_tunnel(stream, requests, pool, verbose).await;
    }

    let io = TokioIo::new(stream);
    let service = service_fn(move |req| {
        // `fetch_add` returns the previous value, so `+ 1` is exactly this
        // request's ordinal even when other tasks increment concurrently.
        let served = requests.fetch_add(1, Ordering::Relaxed) + 1;
        if verbose {
            debug!("HTTP request (total requests: {})", served);
        }
        handle_http_request(req, Arc::clone(&pool), verbose)
    });

    if let Err(err) = http1::Builder::new()
        .serve_connection(io, service)
        .await
    {
        error!("Failed to serve HTTP connection: {:?}", err);
    }
    Ok(())
}
async fn handle_https_tunnel(
mut client_stream: TcpStream,
requests: Arc<AtomicU64>,
pool: Arc<ConnectionPool>,
verbose: bool,
) -> Result<()> {
let mut buffer = [0u8; 4096];
let n = client_stream.read(&mut buffer).await?;
if n == 0 {
return Ok(());
}
let request_str = String::from_utf8_lossy(&buffer[..n]);
let lines: Vec<&str> = request_str.lines().collect();
if lines.is_empty() {
return Ok(());
}
let connect_line = lines[0];
if !connect_line.starts_with("CONNECT ") {
return Ok(());
}
let parts: Vec<&str> = connect_line.split_whitespace().collect();
if parts.len() < 3 {
return Ok(());
}
let authority = parts[1];
let host_port: Vec<&str> = authority.split(':').collect();
if host_port.len() != 2 {
let response = "HTTP/1.1 400 Bad Request\r\n\r\n";
client_stream.write_all(response.as_bytes()).await?;
return Ok(());
}
requests.fetch_add(1, Ordering::Relaxed);
if verbose {
debug!("HTTPS tunnel request to {} (total requests: {})",
authority, requests.load(Ordering::Relaxed));
}
let host = host_port[0];
let port: u16 = match host_port[1].parse() {
Ok(p) => p,
Err(_) => {
let response = "HTTP/1.1 400 Bad Request\r\n\r\n";
client_stream.write_all(response.as_bytes()).await?;
return Ok(());
}
};
let target_addr = format!("{}:{}", host, port);
let target_stream = match pool.get_or_create(&target_addr).await {
Ok(stream) => stream,
Err(e) => {
error!("Failed to connect to target {}: {}", target_addr, e);
let response = "HTTP/1.1 502 Bad Gateway\r\n\r\n";
client_stream.write_all(response.as_bytes()).await?;
return Ok(());
}
};
info!("HTTPS tunnel established to {} (requests: {})",
target_addr, requests.load(Ordering::Relaxed));
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
client_stream.write_all(response.as_bytes()).await?;
client_stream.flush().await?;
let (mut client_read, mut client_write) = tokio::io::split(client_stream);
let (mut target_read, mut target_write) = tokio::io::split(target_stream);
let client_to_target = tokio::spawn(async move {
let mut buffer = vec![0u8; 8192];
loop {
match client_read.read(&mut buffer).await {
Ok(0) => break,
Ok(n) => {
if let Err(e) = target_write.write_all(&buffer[..n]).await {
if verbose {
debug!("Client to target write error: {}", e);
}
break;
}
if let Err(e) = target_write.flush().await {
if verbose {
debug!("Client to target flush error: {}", e);
}
break;
}
}
Err(e) => {
if verbose {
debug!("Client to target read error: {}", e);
}
break;
}
}
}
});
let target_to_client = tokio::spawn(async move {
let mut buffer = vec![0u8; 8192];
loop {
match target_read.read(&mut buffer).await {
Ok(0) => break,
Ok(n) => {
if let Err(e) = client_write.write_all(&buffer[..n]).await {
if verbose {
debug!("Target to client write error: {}", e);
}
break;
}
if let Err(e) = client_write.flush().await {
if verbose {
debug!("Target to client flush error: {}", e);
}
break;
}
}
Err(e) => {
if verbose {
debug!("Target to client read error: {}", e);
}
break;
}
}
}
});
let _ = tokio::join!(client_to_target, target_to_client);
if verbose {
debug!("HTTPS tunnel closed for {}", target_addr);
}
Ok(())
}
async fn handle_http_request(
req: Request<hyper::body::Incoming>,
pool: Arc<ConnectionPool>,
verbose: bool,
) -> Result<Response<Full<Bytes>>> {
if verbose {
debug!("HTTP proxy request to {}", req.uri());
}
let (parts, body) = req.into_parts();
let host = match parts.uri.host() {
Some(host) => host,
None => {
error!("Missing host in request URI");
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Bad Request: Missing host")))
.unwrap());
}
};
let port = parts.uri.port_u16().unwrap_or(80);
let path = parts.uri.path();
let query = parts.uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
let method = parts.method.clone();
let headers = parts.headers.clone();
let target_addr = format!("{}:{}", host, port);
let mut target_stream = match pool.get_or_create(&target_addr).await {
Ok(stream) => stream,
Err(e) => {
error!("Failed to connect to HTTP target {}: {}", target_addr, e);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::new(Bytes::from("Bad Gateway")))
.unwrap());
}
};
let body_bytes = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
error!("Failed to collect request body: {}", e);
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Bad Request")))
.unwrap());
}
};
let request_line = format!("{} {}{} HTTP/1.1\r\n", method, path, query);
if let Err(e) = target_stream.write_all(request_line.as_bytes()).await {
error!("Failed to write request line: {}", e);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::new(Bytes::from("Bad Gateway")))
.unwrap());
}
for (key, value) in headers {
if let Some(key_str) = key {
let key_name = key_str.as_str();
if key_name.to_lowercase() != "proxy-connection" && key_name.to_lowercase() != "connection" {
let header_line = format!("{}: {}\r\n", key_name, value.to_str().unwrap_or(""));
if let Err(e) = target_stream.write_all(header_line.as_bytes()).await {
error!("Failed to write header: {}", e);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::new(Bytes::from("Bad Gateway")))
.unwrap());
}
}
}
}
if let Err(e) = target_stream.write_all(b"Connection: close\r\n\r\n").await {
error!("Failed to write header end: {}", e);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::new(Bytes::from("Bad Gateway")))
.unwrap());
}
if !body_bytes.is_empty() {
if let Err(e) = target_stream.write_all(&body_bytes).await {
error!("Failed to write body: {}", e);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::new(Bytes::from("Bad Gateway")))
.unwrap());
}
}
if let Err(e) = target_stream.flush().await {
error!("Failed to flush: {}", e);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::new(Bytes::from("Bad Gateway")))
.unwrap());
}
let mut response_buffer = Vec::new();
if let Err(e) = target_stream.read_to_end(&mut response_buffer).await {
error!("Failed to read response: {}", e);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Full::new(Bytes::from("Bad Gateway")))
.unwrap());
}
pool.put(target_addr, target_stream).await;
Ok(Response::builder()
.status(StatusCode::OK)
.body(Full::new(Bytes::from(response_buffer)))
.unwrap())
}