easy_http_proxy_server/
server.rs

1//! Main proxy server implementation
2
3use crate::{ConnectionPool, Result};
4use hyper::server::conn::http1;
5use hyper::service::service_fn;
6use hyper::{Request, Response, StatusCode};
7use hyper_util::rt::TokioIo;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use tokio::net::{TcpListener, TcpStream};
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use bytes::Bytes;
14use http_body_util::{Full, BodyExt};
15use log::{info, error, debug};
16
/// Configuration for the proxy server
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyConfig {
    /// Address to bind the server to
    pub addr: SocketAddr,
    /// Enable verbose logging
    pub verbose: bool,
}

impl ProxyConfig {
    /// Creates a configuration that binds the proxy to `addr`.
    pub fn new(addr: SocketAddr, verbose: bool) -> Self {
        Self { addr, verbose }
    }

    /// Creates a configuration that binds the proxy to the IPv4 loopback
    /// address (`127.0.0.1`) on `port`.
    pub fn localhost(port: u16, verbose: bool) -> Self {
        // Build the address directly rather than formatting and re-parsing a
        // string, which needed an `unwrap()` for a conversion that cannot fail.
        Self::new(
            SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, port)),
            verbose,
        )
    }
}
40
41/// HTTP/HTTPS proxy server
42pub struct ProxyServer {
43    config: ProxyConfig,
44    connection_pool: Arc<ConnectionPool>,
45    total_connections: Arc<AtomicU64>,
46    total_requests: Arc<AtomicU64>,
47}
48
49impl ProxyServer {
50    /// Create a new proxy server with the given configuration
51    pub fn new(config: ProxyConfig) -> Self {
52        Self {
53            config,
54            connection_pool: Arc::new(ConnectionPool::new()),
55            total_connections: Arc::new(AtomicU64::new(0)),
56            total_requests: Arc::new(AtomicU64::new(0)),
57        }
58    }
59
60    /// Create a new proxy server with custom connection pool
61    pub fn with_pool(config: ProxyConfig, pool: ConnectionPool) -> Self {
62        Self {
63            config,
64            connection_pool: Arc::new(pool),
65            total_connections: Arc::new(AtomicU64::new(0)),
66            total_requests: Arc::new(AtomicU64::new(0)),
67        }
68    }
69
70    /// Run the proxy server
71    pub async fn run(&self) -> Result<()> {
72        let listener = TcpListener::bind(self.config.addr).await?;
73        info!("HTTP Proxy Server listening on http://{}", self.config.addr);
74
75        loop {
76            let (stream, remote_addr) = listener.accept().await?;
77            let connections = Arc::clone(&self.total_connections);
78            let requests = Arc::clone(&self.total_requests);
79            let pool = Arc::clone(&self.connection_pool);
80
81            connections.fetch_add(1, Ordering::Relaxed);
82            info!("Accepted connection from {} (total: {})", 
83                  remote_addr, connections.load(Ordering::Relaxed));
84
85            let verbose_clone = self.config.verbose;
86            tokio::task::spawn(async move {
87                if let Err(err) = self::handle_connection(
88                    stream, 
89                    requests, 
90                    pool,
91                    verbose_clone
92                ).await {
93                    error!("Failed to handle connection: {:?}", err);
94                }
95                connections.fetch_sub(1, Ordering::Relaxed);
96            });
97        }
98    }
99
100    /// Get total connections handled
101    pub fn total_connections(&self) -> u64 {
102        self.total_connections.load(Ordering::Relaxed)
103    }
104
105    /// Get total requests handled
106    pub fn total_requests(&self) -> u64 {
107        self.total_requests.load(Ordering::Relaxed)
108    }
109
110    /// Get connection pool reference
111    pub fn connection_pool(&self) -> &Arc<ConnectionPool> {
112        &self.connection_pool
113    }
114}
115
116async fn handle_connection(
117    stream: TcpStream, 
118    requests: Arc<AtomicU64>,
119    pool: Arc<ConnectionPool>,
120    verbose: bool,
121) -> Result<()> {
122    let mut buffer = [0u8; 4096];
123    let n = stream.peek(&mut buffer).await?;
124    
125    if n == 0 {
126        return Ok(());
127    }
128    
129    let request_str = String::from_utf8_lossy(&buffer[..n]);
130    
131    if request_str.starts_with("CONNECT ") {
132        handle_https_tunnel(stream, requests, pool, verbose).await
133    } else {
134        requests.fetch_add(1, Ordering::Relaxed);
135        if verbose {
136            debug!("HTTP request (total requests: {})", requests.load(Ordering::Relaxed));
137        }
138        
139        let io = TokioIo::new(stream);
140        let service = service_fn(move |req| {
141            let pool_clone = Arc::clone(&pool);
142            handle_http_request(req, pool_clone, verbose)
143        });
144        
145        if let Err(err) = http1::Builder::new()
146            .serve_connection(io, service)
147            .await
148        {
149            error!("Failed to serve HTTP connection: {:?}", err);
150        }
151        Ok(())
152    }
153}
154
155async fn handle_https_tunnel(
156    mut client_stream: TcpStream, 
157    requests: Arc<AtomicU64>,
158    pool: Arc<ConnectionPool>,
159    verbose: bool,
160) -> Result<()> {
161    let mut buffer = [0u8; 4096];
162    let n = client_stream.read(&mut buffer).await?;
163    
164    if n == 0 {
165        return Ok(());
166    }
167    
168    let request_str = String::from_utf8_lossy(&buffer[..n]);
169    let lines: Vec<&str> = request_str.lines().collect();
170    
171    if lines.is_empty() {
172        return Ok(());
173    }
174    
175    let connect_line = lines[0];
176    if !connect_line.starts_with("CONNECT ") {
177        return Ok(());
178    }
179    
180    let parts: Vec<&str> = connect_line.split_whitespace().collect();
181    if parts.len() < 3 {
182        return Ok(());
183    }
184    
185    let authority = parts[1];
186    let host_port: Vec<&str> = authority.split(':').collect();
187    if host_port.len() != 2 {
188        let response = "HTTP/1.1 400 Bad Request\r\n\r\n";
189        client_stream.write_all(response.as_bytes()).await?;
190        return Ok(());
191    }
192    
193    // Increase request count
194    requests.fetch_add(1, Ordering::Relaxed);
195    if verbose {
196        debug!("HTTPS tunnel request to {} (total requests: {})", 
197               authority, requests.load(Ordering::Relaxed));
198    }
199    
200    let host = host_port[0];
201    let port: u16 = match host_port[1].parse() {
202        Ok(p) => p,
203        Err(_) => {
204            let response = "HTTP/1.1 400 Bad Request\r\n\r\n";
205            client_stream.write_all(response.as_bytes()).await?;
206            return Ok(());
207        }
208    };
209    
210    let target_addr = format!("{}:{}", host, port);
211    
212    // Try to get connection from pool
213    let target_stream = match pool.get_or_create(&target_addr).await {
214        Ok(stream) => stream,
215        Err(e) => {
216            error!("Failed to connect to target {}: {}", target_addr, e);
217            let response = "HTTP/1.1 502 Bad Gateway\r\n\r\n";
218            client_stream.write_all(response.as_bytes()).await?;
219            return Ok(());
220        }
221    };
222    
223    info!("HTTPS tunnel established to {} (requests: {})", 
224          target_addr, requests.load(Ordering::Relaxed));
225    
226    let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
227    client_stream.write_all(response.as_bytes()).await?;
228    client_stream.flush().await?;
229    
230    // Create bidirectional tunnel
231    let (mut client_read, mut client_write) = tokio::io::split(client_stream);
232    let (mut target_read, mut target_write) = tokio::io::split(target_stream);
233    
234    let client_to_target = tokio::spawn(async move {
235        let mut buffer = vec![0u8; 8192];
236        loop {
237            match client_read.read(&mut buffer).await {
238                Ok(0) => break,
239                Ok(n) => {
240                    if let Err(e) = target_write.write_all(&buffer[..n]).await {
241                        if verbose {
242                            debug!("Client to target write error: {}", e);
243                        }
244                        break;
245                    }
246                    if let Err(e) = target_write.flush().await {
247                        if verbose {
248                            debug!("Client to target flush error: {}", e);
249                        }
250                        break;
251                    }
252                }
253                Err(e) => {
254                    if verbose {
255                        debug!("Client to target read error: {}", e);
256                    }
257                    break;
258                }
259            }
260        }
261    });
262    
263    let target_to_client = tokio::spawn(async move {
264        let mut buffer = vec![0u8; 8192];
265        loop {
266            match target_read.read(&mut buffer).await {
267                Ok(0) => break,
268                Ok(n) => {
269                    if let Err(e) = client_write.write_all(&buffer[..n]).await {
270                        if verbose {
271                            debug!("Target to client write error: {}", e);
272                        }
273                        break;
274                    }
275                    if let Err(e) = client_write.flush().await {
276                        if verbose {
277                            debug!("Target to client flush error: {}", e);
278                        }
279                        break;
280                    }
281                }
282                Err(e) => {
283                    if verbose {
284                        debug!("Target to client read error: {}", e);
285                    }
286                    break;
287                }
288            }
289        }
290    });
291    
292    let _ = tokio::join!(client_to_target, target_to_client);
293    
294    if verbose {
295        debug!("HTTPS tunnel closed for {}", target_addr);
296    }
297    
298    Ok(())
299}
300
301async fn handle_http_request(
302    req: Request<hyper::body::Incoming>,
303    pool: Arc<ConnectionPool>,
304    verbose: bool,
305) -> Result<Response<Full<Bytes>>> {
306    if verbose {
307        debug!("HTTP proxy request to {}", req.uri());
308    }
309    
310    let (parts, body) = req.into_parts();
311    
312    let host = match parts.uri.host() {
313        Some(host) => host,
314        None => {
315            error!("Missing host in request URI");
316            return Ok(Response::builder()
317                .status(StatusCode::BAD_REQUEST)
318                .body(Full::new(Bytes::from("Bad Request: Missing host")))
319                .unwrap());
320        }
321    };
322    
323    let port = parts.uri.port_u16().unwrap_or(80);
324    let path = parts.uri.path();
325    let query = parts.uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
326    let method = parts.method.clone();
327    let headers = parts.headers.clone();
328    
329    let target_addr = format!("{}:{}", host, port);
330    
331    // Get or create connection
332    let mut target_stream = match pool.get_or_create(&target_addr).await {
333        Ok(stream) => stream,
334        Err(e) => {
335            error!("Failed to connect to HTTP target {}: {}", target_addr, e);
336            return Ok(Response::builder()
337                .status(StatusCode::BAD_GATEWAY)
338                .body(Full::new(Bytes::from("Bad Gateway")))
339                .unwrap());
340        }
341    };
342    
343    let body_bytes = match body.collect().await {
344        Ok(collected) => collected.to_bytes(),
345        Err(e) => {
346            error!("Failed to collect request body: {}", e);
347            return Ok(Response::builder()
348                .status(StatusCode::BAD_REQUEST)
349                .body(Full::new(Bytes::from("Bad Request")))
350                .unwrap());
351        }
352    };
353    
354    let request_line = format!("{} {}{} HTTP/1.1\r\n", method, path, query);
355    
356    if let Err(e) = target_stream.write_all(request_line.as_bytes()).await {
357        error!("Failed to write request line: {}", e);
358        return Ok(Response::builder()
359            .status(StatusCode::BAD_GATEWAY)
360            .body(Full::new(Bytes::from("Bad Gateway")))
361            .unwrap());
362    }
363    
364    for (key, value) in headers {
365        if let Some(key_str) = key {
366            let key_name = key_str.as_str();
367            if key_name.to_lowercase() != "proxy-connection" && key_name.to_lowercase() != "connection" {
368                let header_line = format!("{}: {}\r\n", key_name, value.to_str().unwrap_or(""));
369                if let Err(e) = target_stream.write_all(header_line.as_bytes()).await {
370                    error!("Failed to write header: {}", e);
371                    return Ok(Response::builder()
372                        .status(StatusCode::BAD_GATEWAY)
373                        .body(Full::new(Bytes::from("Bad Gateway")))
374                        .unwrap());
375                }
376            }
377        }
378    }
379    
380    if let Err(e) = target_stream.write_all(b"Connection: close\r\n\r\n").await {
381        error!("Failed to write header end: {}", e);
382        return Ok(Response::builder()
383            .status(StatusCode::BAD_GATEWAY)
384            .body(Full::new(Bytes::from("Bad Gateway")))
385            .unwrap());
386    }
387    
388    if !body_bytes.is_empty() {
389        if let Err(e) = target_stream.write_all(&body_bytes).await {
390            error!("Failed to write body: {}", e);
391            return Ok(Response::builder()
392                .status(StatusCode::BAD_GATEWAY)
393                .body(Full::new(Bytes::from("Bad Gateway")))
394                .unwrap());
395        }
396    }
397    
398    if let Err(e) = target_stream.flush().await {
399        error!("Failed to flush: {}", e);
400        return Ok(Response::builder()
401            .status(StatusCode::BAD_GATEWAY)
402            .body(Full::new(Bytes::from("Bad Gateway")))
403            .unwrap());
404    }
405    
406    let mut response_buffer = Vec::new();
407    if let Err(e) = target_stream.read_to_end(&mut response_buffer).await {
408        error!("Failed to read response: {}", e);
409        return Ok(Response::builder()
410            .status(StatusCode::BAD_GATEWAY)
411            .body(Full::new(Bytes::from("Bad Gateway")))
412            .unwrap());
413    }
414    
415    // Return connection to pool
416    // Note: This is a simplified implementation. In production, you'd want more
417    // sophisticated connection management, especially regarding the Connection: close header
418    pool.put(target_addr, target_stream).await;
419    
420    Ok(Response::builder()
421        .status(StatusCode::OK)
422        .body(Full::new(Bytes::from(response_buffer)))
423        .unwrap())
424}