1use crate::config::site::{ProxyConfig, ProxyRoute, SiteConfig, UpstreamConfig};
2use crate::handlers::websocket_proxy::WebSocketProxyHandler;
3use crate::middleware::compression::{CompressionMethod, CompressionMiddleware};
4use chrono;
5use log::{debug, error, info};
6use pingora::http::{RequestHeader, ResponseHeader};
7use pingora::prelude::*;
8use serde_json;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12use url::Url;
13
/// Reverse-proxy request handler.
///
/// Matches request paths against the configured proxy routes, picks an upstream server
/// using the configured load-balancing method, and forwards the request. WebSocket
/// upgrade requests are delegated to a dedicated [`WebSocketProxyHandler`].
pub struct ProxyHandler {
    /// Proxy settings: enable flag, routes, upstreams, load balancing, headers, timeouts.
    proxy_config: ProxyConfig,
    /// Upstream servers grouped by upstream name. Several `UpstreamConfig` entries
    /// sharing a name form one pool, kept in configuration order.
    upstreams: HashMap<String, Vec<UpstreamConfig>>,
    /// Per-upstream-name rotation counters used by round-robin selection.
    round_robin_counters: HashMap<String, Arc<AtomicUsize>>,
    /// In-flight request counts keyed by upstream server URL, used by
    /// least-connections selection.
    connection_counts: HashMap<String, Arc<AtomicUsize>>,
    /// Handler that proxies WebSocket upgrade requests.
    websocket_handler: WebSocketProxyHandler,
}
26
27impl ProxyHandler {
28 pub fn new(proxy_config: ProxyConfig) -> Self {
30 let mut upstreams = HashMap::new();
31 let mut round_robin_counters = HashMap::new();
32 let mut connection_counts = HashMap::new();
33
34 for upstream in &proxy_config.upstreams {
36 upstreams
37 .entry(upstream.name.clone())
38 .or_insert_with(Vec::new)
39 .push(upstream.clone());
40
41 round_robin_counters.insert(upstream.name.clone(), Arc::new(AtomicUsize::new(0)));
42
43 connection_counts.insert(upstream.url.clone(), Arc::new(AtomicUsize::new(0)));
44 }
45
46 Self {
47 proxy_config: proxy_config.clone(),
48 upstreams,
49 round_robin_counters,
50 connection_counts,
51 websocket_handler: WebSocketProxyHandler::new(proxy_config),
52 }
53 }
54
55 pub fn find_proxy_route(&self, path: &str) -> Option<&ProxyRoute> {
59 if !self.proxy_config.enabled {
60 return None;
61 }
62
63 self.proxy_config
65 .routes
66 .iter()
67 .filter(|route| path.starts_with(&route.path))
68 .max_by_key(|route| route.path.len())
69 }
70
71 pub fn select_upstream(&self, upstream_name: &str) -> Result<&UpstreamConfig> {
75 let upstream_servers = self
76 .upstreams
77 .get(upstream_name)
78 .ok_or_else(|| Error::new_str("Upstream not found"))?;
79
80 if upstream_servers.is_empty() {
81 return Err(Error::new_str("No servers available for upstream"));
82 }
83
84 let upstream = match self.proxy_config.load_balancing.method.as_str() {
85 "round_robin" => self.select_round_robin(upstream_name, upstream_servers)?,
86 "weighted" => self.select_weighted(upstream_servers)?,
87 "least_connections" => self.select_least_connections(upstream_servers)?,
88 _ => &upstream_servers[0], };
90
91 Ok(upstream)
92 }
93
94 fn select_round_robin<'a>(
97 &self,
98 upstream_name: &str,
99 servers: &'a [UpstreamConfig],
100 ) -> Result<&'a UpstreamConfig> {
101 let counter = self
102 .round_robin_counters
103 .get(upstream_name)
104 .ok_or_else(|| Error::new_str("Round robin counter not found"))?;
105
106 let index = counter.fetch_add(1, Ordering::Relaxed) % servers.len();
107 Ok(&servers[index])
108 }
109
110 fn select_weighted<'a>(&self, servers: &'a [UpstreamConfig]) -> Result<&'a UpstreamConfig> {
113 let total_weight: u32 = servers.iter().map(|s| s.weight).sum();
114 if total_weight == 0 {
115 return Ok(&servers[0]);
116 }
117
118 let random_weight = fastrand::u32(1..=total_weight);
119 let mut current_weight = 0;
120
121 for server in servers {
122 current_weight += server.weight;
123 if random_weight <= current_weight {
124 return Ok(server);
125 }
126 }
127
128 Ok(&servers[0])
129 }
130
131 fn select_least_connections<'a>(
134 &self,
135 servers: &'a [UpstreamConfig],
136 ) -> Result<&'a UpstreamConfig> {
137 let mut min_connections = usize::MAX;
139 let mut selected_server = &servers[0];
140
141 for server in servers {
142 let connections = self
143 .connection_counts
144 .get(&server.url)
145 .map(|c| c.load(Ordering::Relaxed))
146 .unwrap_or(0);
147
148 if connections < min_connections {
149 min_connections = connections;
150 selected_server = server;
151 }
152 }
153
154 info!(
155 "Selected server '{}' with {} connections",
156 selected_server.url, min_connections
157 );
158 Ok(selected_server)
159 }
160
161 fn increment_connections(&self, server_url: &str) {
163 if let Some(counter) = self.connection_counts.get(server_url) {
164 counter.fetch_add(1, Ordering::Relaxed);
165 }
166 }
167
168 fn decrement_connections(&self, server_url: &str) {
170 if let Some(counter) = self.connection_counts.get(server_url) {
171 counter.fetch_sub(1, Ordering::Relaxed);
172 }
173 }
174
175 pub fn get_upstream_url(&self, upstream: &UpstreamConfig) -> Result<Url> {
177 Url::parse(&upstream.url).map_err(|_| Error::new_str("Invalid upstream URL"))
178 }
179
180 pub fn transform_path(&self, route: &ProxyRoute, original_path: &str) -> String {
182 let target_path = if route.strip_prefix {
183 original_path
184 .strip_prefix(&route.path)
185 .unwrap_or(original_path)
186 } else {
187 original_path
188 };
189
190 if let Some(rewrite_target) = &route.rewrite_target {
191 rewrite_target.clone()
192 } else {
193 format!("/{}", target_path.trim_start_matches('/'))
194 }
195 }
196
197 pub fn add_proxy_headers(
199 &self,
200 req: &mut RequestHeader,
201 session: &Session,
202 original_host: &str,
203 ) {
204 if self.proxy_config.headers.add_x_forwarded {
205 if let Some(client_addr) = session.client_addr() {
206 req.insert_header("X-Forwarded-For", client_addr.to_string())
207 .ok();
208 }
209 }
210
211 if self.proxy_config.headers.add_x_forwarded {
212 let proto = if session.req_header().uri.scheme().map(|s| s.as_str()) == Some("https") {
213 "https"
214 } else {
215 "http"
216 };
217 req.insert_header("X-Forwarded-Proto", proto).ok();
218 }
219
220 if self.proxy_config.headers.add_x_forwarded {
221 req.insert_header("X-Forwarded-Host", original_host).ok();
222 }
223
224 if self.proxy_config.headers.add_forwarded {
225 if let Some(client_addr) = session.client_addr() {
226 let proto =
227 if session.req_header().uri.scheme().map(|s| s.as_str()) == Some("https") {
228 "https"
229 } else {
230 "http"
231 };
232 let forwarded =
233 format!("for={};proto={};host={}", client_addr, proto, original_host);
234 req.insert_header("Forwarded", forwarded).ok();
235 }
236 }
237
238 for (key, value) in &self.proxy_config.headers.add {
240 req.insert_header(key.clone(), value.clone()).ok();
241 }
242 }
243
244 pub async fn handle_proxy_request(
246 &self,
247 session: &mut Session,
248 site: &SiteConfig,
249 path: &str,
250 ) -> Result<bool> {
251 if WebSocketProxyHandler::is_websocket_upgrade_request(session.req_header()) {
253 info!("Detected WebSocket upgrade request for path: {}", path);
254 return self
255 .websocket_handler
256 .handle_websocket_proxy(session, path)
257 .await;
258 }
259
260 if let Some(route) = self.find_proxy_route(path) {
262 info!("Proxying request {} to upstream '{}'", path, route.upstream);
263
264 let upstream = match self.select_upstream(&route.upstream) {
266 Ok(upstream) => upstream,
267 Err(e) => {
268 error!("Failed to select upstream: {}", e);
269 self.send_error_response(session, 502, "Bad Gateway")
270 .await?;
271 return Ok(true);
272 }
273 };
274
275 let upstream_url = match self.get_upstream_url(upstream) {
277 Ok(url) => url,
278 Err(e) => {
279 error!("Failed to parse upstream URL: {}", e);
280 self.send_error_response(session, 502, "Bad Gateway")
281 .await?;
282 return Ok(true);
283 }
284 };
285
286 let new_path = self.transform_path(route, path);
288
289 self.increment_connections(&upstream.url);
291
292 let proxy_result = self
294 .proxy_to_upstream(session, &upstream_url, &new_path, route, site)
295 .await;
296
297 self.decrement_connections(&upstream.url);
299
300 match proxy_result {
301 Ok(()) => {
302 info!("Successfully proxied request {} to {}", path, upstream.url);
303 Ok(true)
304 }
305 Err(e) => {
306 error!("Proxy request failed: {}", e);
307 self.send_error_response(session, 502, "Bad Gateway")
308 .await?;
309 Ok(true)
310 }
311 }
312 } else {
313 Ok(false)
315 }
316 }
317
318 async fn proxy_to_upstream(
320 &self,
321 session: &mut Session,
322 upstream_url: &Url,
323 new_path: &str,
324 _route: &ProxyRoute,
325 site: &SiteConfig,
326 ) -> Result<()> {
327 let client = reqwest::Client::builder()
329 .timeout(std::time::Duration::from_secs(
330 self.proxy_config.timeout.read,
331 ))
332 .build()
333 .map_err(|_| Error::new_str("Failed to create HTTP client"))?;
334
335 let original_host = session
337 .req_header()
338 .headers
339 .get("Host")
340 .and_then(|h| h.to_str().ok())
341 .unwrap_or("localhost");
342
343 let full_upstream_url = format!(
345 "{}://{}{}{}",
346 upstream_url.scheme(),
347 upstream_url.host_str().unwrap_or("localhost"),
348 upstream_url
349 .port()
350 .map(|p| format!(":{}", p))
351 .unwrap_or_default(),
352 new_path
353 );
354
355 debug!("Proxying to upstream URL: {}", full_upstream_url);
356
357 let method = session.req_header().method.clone();
359 let mut req_builder = match method.as_str() {
360 "GET" => client.get(&full_upstream_url),
361 "POST" => client.post(&full_upstream_url),
362 "PUT" => client.put(&full_upstream_url),
363 "DELETE" => client.delete(&full_upstream_url),
364 "PATCH" => client.patch(&full_upstream_url),
365 "HEAD" => client.head(&full_upstream_url),
366 "OPTIONS" => client.request(reqwest::Method::OPTIONS, &full_upstream_url),
367 _ => client.get(&full_upstream_url), };
369
370 for (name, value) in session.req_header().headers.iter() {
372 if let Ok(value_str) = value.to_str() {
373 let name_str = name.as_str();
374 if name_str.to_lowercase() != "host" {
376 req_builder = req_builder.header(name_str, value_str);
377 }
378 }
379 }
380
381 let mut temp_header = session.req_header().clone();
383 self.add_proxy_headers(&mut temp_header, session, original_host);
384
385 for (name, value) in temp_header.headers.iter() {
387 if let Ok(value_str) = value.to_str() {
388 let name_str = name.as_str();
389 if name_str.starts_with("X-Forwarded") || name_str == "Forwarded" {
390 req_builder = req_builder.header(name_str, value_str);
391 }
392 }
393 }
394
395 let body = if method.as_str() == "POST"
397 || method.as_str() == "PUT"
398 || method.as_str() == "PATCH"
399 {
400 Vec::new()
403 } else {
404 Vec::new()
405 };
406
407 if !body.is_empty() {
408 req_builder = req_builder.body(body);
409 }
410
411 let response = req_builder
413 .send()
414 .await
415 .map_err(|_| Error::new_str("Upstream request failed"))?;
416
417 let status = response.status().as_u16();
419
420 let mut header_map = std::collections::HashMap::new();
422 for (name, value) in response.headers().iter() {
423 if let Ok(value_str) = value.to_str() {
424 header_map.insert(name.as_str().to_string(), value_str.to_string());
425 }
426 }
427
428 let body_bytes = response
430 .bytes()
431 .await
432 .map_err(|_| Error::new_str("Failed to read upstream response"))?;
433
434 let content_type = header_map
436 .get("content-type")
437 .unwrap_or(&"application/octet-stream".to_string())
438 .clone();
439
440 let compression_middleware = CompressionMiddleware::new(site.compression.clone());
441
442 let (final_body, encoding) = if compression_middleware
443 .should_compress(&content_type, body_bytes.len())
444 {
445 let accept_encoding = session
447 .req_header()
448 .headers
449 .get("accept-encoding")
450 .and_then(|h| h.to_str().ok());
451
452 let compression_method = compression_middleware.get_best_compression(accept_encoding);
453
454 match compression_middleware.compress(&body_bytes, compression_method.clone()) {
455 Ok(compressed_content) => {
456 debug!(
457 "Compressed proxy response {} bytes to {} bytes using {:?} ({}% reduction)",
458 body_bytes.len(),
459 compressed_content.len(),
460 compression_method,
461 if !body_bytes.is_empty() {
462 ((body_bytes.len() - compressed_content.len()) * 100) / body_bytes.len()
463 } else {
464 0
465 }
466 );
467 (compressed_content, Some(compression_method))
468 }
469 Err(e) => {
470 debug!("Compression failed: {}, serving uncompressed", e);
471 (body_bytes.to_vec().into(), None)
472 }
473 }
474 } else {
475 (body_bytes.to_vec().into(), None)
476 };
477
478 let mut resp_header = ResponseHeader::build(status, Some(4))?;
480
481 for (name, value) in header_map {
483 if name.to_lowercase() != "content-length" {
484 resp_header.insert_header(name, value)?;
485 }
486 }
487
488 resp_header.insert_header("Content-Length", final_body.len().to_string())?;
490
491 if let Some(method) = encoding {
492 if !matches!(method, CompressionMethod::None) {
493 resp_header.insert_header("Content-Encoding", method.as_str())?;
494 resp_header.insert_header("Vary", "Accept-Encoding")?;
495 }
496 }
497
498 session
500 .write_response_header(Box::new(resp_header), false)
501 .await?;
502 session.write_response_body(Some(final_body), true).await?;
503
504 Ok(())
505 }
506
507 async fn send_error_response(
509 &self,
510 session: &mut Session,
511 status_code: u16,
512 message: &str,
513 ) -> Result<()> {
514 let error_response = serde_json::json!({
515 "error": message,
516 "status": status_code,
517 "timestamp": chrono::Utc::now().to_rfc3339()
518 });
519
520 let response_body = error_response.to_string();
521 let response_bytes = response_body.into_bytes();
522 let mut header = ResponseHeader::build(status_code, Some(3))?;
523 header.insert_header("Content-Type", "application/json")?;
524 header.insert_header("Content-Length", response_bytes.len().to_string())?;
525
526 session
527 .write_response_header(Box::new(header), false)
528 .await?;
529 session
530 .write_response_body(Some(response_bytes.into()), true)
531 .await?;
532
533 Ok(())
534 }
535}