1use crate::{ConnectionPool, Result};
4use hyper::server::conn::http1;
5use hyper::service::service_fn;
6use hyper::{Request, Response, StatusCode};
7use hyper_util::rt::TokioIo;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use tokio::net::{TcpListener, TcpStream};
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use bytes::Bytes;
14use http_body_util::{Full, BodyExt};
15use log::{info, error, debug};
16
/// Listener settings for the proxy server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyConfig {
    /// Socket address the proxy binds to.
    pub addr: SocketAddr,
    /// Enables the extra `debug!` logging of individual requests and tunnels.
    pub verbose: bool,
}

impl ProxyConfig {
    /// Creates a configuration for an arbitrary bind address.
    pub fn new(addr: SocketAddr, verbose: bool) -> Self {
        Self { addr, verbose }
    }

    /// Creates a configuration bound to the IPv4 loopback address `127.0.0.1:port`.
    pub fn localhost(port: u16, verbose: bool) -> Self {
        Self {
            // Built directly instead of formatting and parsing a string:
            // infallible, so no `unwrap`, and no allocation.
            addr: SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, port)),
            verbose,
        }
    }
}
40
41pub struct ProxyServer {
43 config: ProxyConfig,
44 connection_pool: Arc<ConnectionPool>,
45 total_connections: Arc<AtomicU64>,
46 total_requests: Arc<AtomicU64>,
47}
48
49impl ProxyServer {
50 pub fn new(config: ProxyConfig) -> Self {
52 Self {
53 config,
54 connection_pool: Arc::new(ConnectionPool::new()),
55 total_connections: Arc::new(AtomicU64::new(0)),
56 total_requests: Arc::new(AtomicU64::new(0)),
57 }
58 }
59
60 pub fn with_pool(config: ProxyConfig, pool: ConnectionPool) -> Self {
62 Self {
63 config,
64 connection_pool: Arc::new(pool),
65 total_connections: Arc::new(AtomicU64::new(0)),
66 total_requests: Arc::new(AtomicU64::new(0)),
67 }
68 }
69
70 pub async fn run(&self) -> Result<()> {
72 let listener = TcpListener::bind(self.config.addr).await?;
73 info!("HTTP Proxy Server listening on http://{}", self.config.addr);
74
75 loop {
76 let (stream, remote_addr) = listener.accept().await?;
77 let connections = Arc::clone(&self.total_connections);
78 let requests = Arc::clone(&self.total_requests);
79 let pool = Arc::clone(&self.connection_pool);
80
81 connections.fetch_add(1, Ordering::Relaxed);
82 info!("Accepted connection from {} (total: {})",
83 remote_addr, connections.load(Ordering::Relaxed));
84
85 let verbose_clone = self.config.verbose;
86 tokio::task::spawn(async move {
87 if let Err(err) = self::handle_connection(
88 stream,
89 requests,
90 pool,
91 verbose_clone
92 ).await {
93 error!("Failed to handle connection: {:?}", err);
94 }
95 connections.fetch_sub(1, Ordering::Relaxed);
96 });
97 }
98 }
99
100 pub fn total_connections(&self) -> u64 {
102 self.total_connections.load(Ordering::Relaxed)
103 }
104
105 pub fn total_requests(&self) -> u64 {
107 self.total_requests.load(Ordering::Relaxed)
108 }
109
110 pub fn connection_pool(&self) -> &Arc<ConnectionPool> {
112 &self.connection_pool
113 }
114}
115
116async fn handle_connection(
117 stream: TcpStream,
118 requests: Arc<AtomicU64>,
119 pool: Arc<ConnectionPool>,
120 verbose: bool,
121) -> Result<()> {
122 let mut buffer = [0u8; 4096];
123 let n = stream.peek(&mut buffer).await?;
124
125 if n == 0 {
126 return Ok(());
127 }
128
129 let request_str = String::from_utf8_lossy(&buffer[..n]);
130
131 if request_str.starts_with("CONNECT ") {
132 handle_https_tunnel(stream, requests, pool, verbose).await
133 } else {
134 requests.fetch_add(1, Ordering::Relaxed);
135 if verbose {
136 debug!("HTTP request (total requests: {})", requests.load(Ordering::Relaxed));
137 }
138
139 let io = TokioIo::new(stream);
140 let service = service_fn(move |req| {
141 let pool_clone = Arc::clone(&pool);
142 handle_http_request(req, pool_clone, verbose)
143 });
144
145 if let Err(err) = http1::Builder::new()
146 .serve_connection(io, service)
147 .await
148 {
149 error!("Failed to serve HTTP connection: {:?}", err);
150 }
151 Ok(())
152 }
153}
154
155async fn handle_https_tunnel(
156 mut client_stream: TcpStream,
157 requests: Arc<AtomicU64>,
158 pool: Arc<ConnectionPool>,
159 verbose: bool,
160) -> Result<()> {
161 let mut buffer = [0u8; 4096];
162 let n = client_stream.read(&mut buffer).await?;
163
164 if n == 0 {
165 return Ok(());
166 }
167
168 let request_str = String::from_utf8_lossy(&buffer[..n]);
169 let lines: Vec<&str> = request_str.lines().collect();
170
171 if lines.is_empty() {
172 return Ok(());
173 }
174
175 let connect_line = lines[0];
176 if !connect_line.starts_with("CONNECT ") {
177 return Ok(());
178 }
179
180 let parts: Vec<&str> = connect_line.split_whitespace().collect();
181 if parts.len() < 3 {
182 return Ok(());
183 }
184
185 let authority = parts[1];
186 let host_port: Vec<&str> = authority.split(':').collect();
187 if host_port.len() != 2 {
188 let response = "HTTP/1.1 400 Bad Request\r\n\r\n";
189 client_stream.write_all(response.as_bytes()).await?;
190 return Ok(());
191 }
192
193 requests.fetch_add(1, Ordering::Relaxed);
195 if verbose {
196 debug!("HTTPS tunnel request to {} (total requests: {})",
197 authority, requests.load(Ordering::Relaxed));
198 }
199
200 let host = host_port[0];
201 let port: u16 = match host_port[1].parse() {
202 Ok(p) => p,
203 Err(_) => {
204 let response = "HTTP/1.1 400 Bad Request\r\n\r\n";
205 client_stream.write_all(response.as_bytes()).await?;
206 return Ok(());
207 }
208 };
209
210 let target_addr = format!("{}:{}", host, port);
211
212 let target_stream = match pool.get_or_create(&target_addr).await {
214 Ok(stream) => stream,
215 Err(e) => {
216 error!("Failed to connect to target {}: {}", target_addr, e);
217 let response = "HTTP/1.1 502 Bad Gateway\r\n\r\n";
218 client_stream.write_all(response.as_bytes()).await?;
219 return Ok(());
220 }
221 };
222
223 info!("HTTPS tunnel established to {} (requests: {})",
224 target_addr, requests.load(Ordering::Relaxed));
225
226 let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
227 client_stream.write_all(response.as_bytes()).await?;
228 client_stream.flush().await?;
229
230 let (mut client_read, mut client_write) = tokio::io::split(client_stream);
232 let (mut target_read, mut target_write) = tokio::io::split(target_stream);
233
234 let client_to_target = tokio::spawn(async move {
235 let mut buffer = vec![0u8; 8192];
236 loop {
237 match client_read.read(&mut buffer).await {
238 Ok(0) => break,
239 Ok(n) => {
240 if let Err(e) = target_write.write_all(&buffer[..n]).await {
241 if verbose {
242 debug!("Client to target write error: {}", e);
243 }
244 break;
245 }
246 if let Err(e) = target_write.flush().await {
247 if verbose {
248 debug!("Client to target flush error: {}", e);
249 }
250 break;
251 }
252 }
253 Err(e) => {
254 if verbose {
255 debug!("Client to target read error: {}", e);
256 }
257 break;
258 }
259 }
260 }
261 });
262
263 let target_to_client = tokio::spawn(async move {
264 let mut buffer = vec![0u8; 8192];
265 loop {
266 match target_read.read(&mut buffer).await {
267 Ok(0) => break,
268 Ok(n) => {
269 if let Err(e) = client_write.write_all(&buffer[..n]).await {
270 if verbose {
271 debug!("Target to client write error: {}", e);
272 }
273 break;
274 }
275 if let Err(e) = client_write.flush().await {
276 if verbose {
277 debug!("Target to client flush error: {}", e);
278 }
279 break;
280 }
281 }
282 Err(e) => {
283 if verbose {
284 debug!("Target to client read error: {}", e);
285 }
286 break;
287 }
288 }
289 }
290 });
291
292 let _ = tokio::join!(client_to_target, target_to_client);
293
294 if verbose {
295 debug!("HTTPS tunnel closed for {}", target_addr);
296 }
297
298 Ok(())
299}
300
301async fn handle_http_request(
302 req: Request<hyper::body::Incoming>,
303 pool: Arc<ConnectionPool>,
304 verbose: bool,
305) -> Result<Response<Full<Bytes>>> {
306 if verbose {
307 debug!("HTTP proxy request to {}", req.uri());
308 }
309
310 let (parts, body) = req.into_parts();
311
312 let host = match parts.uri.host() {
313 Some(host) => host,
314 None => {
315 error!("Missing host in request URI");
316 return Ok(Response::builder()
317 .status(StatusCode::BAD_REQUEST)
318 .body(Full::new(Bytes::from("Bad Request: Missing host")))
319 .unwrap());
320 }
321 };
322
323 let port = parts.uri.port_u16().unwrap_or(80);
324 let path = parts.uri.path();
325 let query = parts.uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
326 let method = parts.method.clone();
327 let headers = parts.headers.clone();
328
329 let target_addr = format!("{}:{}", host, port);
330
331 let mut target_stream = match pool.get_or_create(&target_addr).await {
333 Ok(stream) => stream,
334 Err(e) => {
335 error!("Failed to connect to HTTP target {}: {}", target_addr, e);
336 return Ok(Response::builder()
337 .status(StatusCode::BAD_GATEWAY)
338 .body(Full::new(Bytes::from("Bad Gateway")))
339 .unwrap());
340 }
341 };
342
343 let body_bytes = match body.collect().await {
344 Ok(collected) => collected.to_bytes(),
345 Err(e) => {
346 error!("Failed to collect request body: {}", e);
347 return Ok(Response::builder()
348 .status(StatusCode::BAD_REQUEST)
349 .body(Full::new(Bytes::from("Bad Request")))
350 .unwrap());
351 }
352 };
353
354 let request_line = format!("{} {}{} HTTP/1.1\r\n", method, path, query);
355
356 if let Err(e) = target_stream.write_all(request_line.as_bytes()).await {
357 error!("Failed to write request line: {}", e);
358 return Ok(Response::builder()
359 .status(StatusCode::BAD_GATEWAY)
360 .body(Full::new(Bytes::from("Bad Gateway")))
361 .unwrap());
362 }
363
364 for (key, value) in headers {
365 if let Some(key_str) = key {
366 let key_name = key_str.as_str();
367 if key_name.to_lowercase() != "proxy-connection" && key_name.to_lowercase() != "connection" {
368 let header_line = format!("{}: {}\r\n", key_name, value.to_str().unwrap_or(""));
369 if let Err(e) = target_stream.write_all(header_line.as_bytes()).await {
370 error!("Failed to write header: {}", e);
371 return Ok(Response::builder()
372 .status(StatusCode::BAD_GATEWAY)
373 .body(Full::new(Bytes::from("Bad Gateway")))
374 .unwrap());
375 }
376 }
377 }
378 }
379
380 if let Err(e) = target_stream.write_all(b"Connection: close\r\n\r\n").await {
381 error!("Failed to write header end: {}", e);
382 return Ok(Response::builder()
383 .status(StatusCode::BAD_GATEWAY)
384 .body(Full::new(Bytes::from("Bad Gateway")))
385 .unwrap());
386 }
387
388 if !body_bytes.is_empty() {
389 if let Err(e) = target_stream.write_all(&body_bytes).await {
390 error!("Failed to write body: {}", e);
391 return Ok(Response::builder()
392 .status(StatusCode::BAD_GATEWAY)
393 .body(Full::new(Bytes::from("Bad Gateway")))
394 .unwrap());
395 }
396 }
397
398 if let Err(e) = target_stream.flush().await {
399 error!("Failed to flush: {}", e);
400 return Ok(Response::builder()
401 .status(StatusCode::BAD_GATEWAY)
402 .body(Full::new(Bytes::from("Bad Gateway")))
403 .unwrap());
404 }
405
406 let mut response_buffer = Vec::new();
407 if let Err(e) = target_stream.read_to_end(&mut response_buffer).await {
408 error!("Failed to read response: {}", e);
409 return Ok(Response::builder()
410 .status(StatusCode::BAD_GATEWAY)
411 .body(Full::new(Bytes::from("Bad Gateway")))
412 .unwrap());
413 }
414
415 pool.put(target_addr, target_stream).await;
419
420 Ok(Response::builder()
421 .status(StatusCode::OK)
422 .body(Full::new(Bytes::from(response_buffer)))
423 .unwrap())
424}