use crate::config::site::{ProxyConfig, ProxyRoute, SiteConfig, UpstreamConfig};
use crate::handlers::websocket_proxy::WebSocketProxyHandler;
use chrono;
use log::{debug, error, info};
use pingora::http::{RequestHeader, ResponseHeader};
use pingora::prelude::*;
use serde_json;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use url::Url;

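/// Reverse-proxy handler: matches request paths against the configured proxy
/// routes, selects an upstream server according to the configured
/// load-balancing method, and forwards the request (delegating WebSocket
/// upgrades to the dedicated WebSocket handler).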
pub struct ProxyHandler {
    proxy_config: ProxyConfig,
    upstreams: HashMap<String, Vec<UpstreamConfig>>,
    round_robin_counters: HashMap<String, Arc<AtomicUsize>>,
    connection_counts: HashMap<String, Arc<AtomicUsize>>,
    websocket_handler: WebSocketProxyHandler,
}

impl ProxyHandler {
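    /// Builds the handler from a [`ProxyConfig`], grouping upstream servers by
    /// name and creating the per-upstream round-robin counters and per-server
    /// connection counters used by the load-balancing strategies.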
    pub fn new(proxy_config: ProxyConfig) -> Self {
        let mut upstreams = HashMap::new();
        let mut round_robin_counters = HashMap::new();
        let mut connection_counts = HashMap::new();

        for upstream in &proxy_config.upstreams {
            upstreams
                .entry(upstream.name.clone())
                .or_insert_with(Vec::new)
                .push(upstream.clone());

            round_robin_counters.insert(upstream.name.clone(), Arc::new(AtomicUsize::new(0)));

            connection_counts.insert(upstream.url.clone(), Arc::new(AtomicUsize::new(0)));
        }

        Self {
            proxy_config: proxy_config.clone(),
            upstreams,
            round_robin_counters,
            connection_counts,
            websocket_handler: WebSocketProxyHandler::new(proxy_config),
        }
    }

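    /// Returns the configured route whose path prefix matches `path`, preferring
    /// the longest matching prefix. Returns `None` when proxying is disabled or
    /// no route matches.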
    pub fn find_proxy_route(&self, path: &str) -> Option<&ProxyRoute> {
        if !self.proxy_config.enabled {
            return None;
        }

        self.proxy_config
            .routes
            .iter()
            .filter(|route| path.starts_with(&route.path))
            .max_by_key(|route| route.path.len())
    }

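    /// Selects a server from the named upstream group using the configured
    /// load-balancing method (`round_robin`, `weighted`, or `least_connections`);
    /// any other value falls back to the first server.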
    pub fn select_upstream(&self, upstream_name: &str) -> Result<&UpstreamConfig> {
        let upstream_servers = self
            .upstreams
            .get(upstream_name)
            .ok_or_else(|| Error::new_str("Upstream not found"))?;

        if upstream_servers.is_empty() {
            return Err(Error::new_str("No servers available for upstream"));
        }

        let upstream = match self.proxy_config.load_balancing.method.as_str() {
            "round_robin" => self.select_round_robin(upstream_name, upstream_servers)?,
            "weighted" => self.select_weighted(upstream_servers)?,
            "least_connections" => self.select_least_connections(upstream_servers)?,
            _ => &upstream_servers[0],
        };

        Ok(upstream)
    }

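    /// Round-robin selection: each call advances the upstream's atomic counter
    /// and indexes into the server list modulo its length.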
    fn select_round_robin<'a>(
        &self,
        upstream_name: &str,
        servers: &'a [UpstreamConfig],
    ) -> Result<&'a UpstreamConfig> {
        let counter = self
            .round_robin_counters
            .get(upstream_name)
            .ok_or_else(|| Error::new_str("Round robin counter not found"))?;

        let index = counter.fetch_add(1, Ordering::Relaxed) % servers.len();
        Ok(&servers[index])
    }

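    /// Weighted random selection: draws a random value in `1..=total_weight` and
    /// walks the cumulative weights until the draw is covered.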
    fn select_weighted<'a>(&self, servers: &'a [UpstreamConfig]) -> Result<&'a UpstreamConfig> {
        let total_weight: u32 = servers.iter().map(|s| s.weight).sum();
        if total_weight == 0 {
            return Ok(&servers[0]);
        }

        let random_weight = fastrand::u32(1..=total_weight);
        let mut current_weight = 0;

        for server in servers {
            current_weight += server.weight;
            if random_weight <= current_weight {
                return Ok(server);
            }
        }

        Ok(&servers[0])
    }

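    /// Least-connections selection: scans the tracked per-server connection
    /// counts and returns the server with the fewest active connections.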
    fn select_least_connections<'a>(
        &self,
        servers: &'a [UpstreamConfig],
    ) -> Result<&'a UpstreamConfig> {
        let mut min_connections = usize::MAX;
        let mut selected_server = &servers[0];

        for server in servers {
            let connections = self
                .connection_counts
                .get(&server.url)
                .map(|c| c.load(Ordering::Relaxed))
                .unwrap_or(0);

            if connections < min_connections {
                min_connections = connections;
                selected_server = server;
            }
        }

        info!(
            "Selected server '{}' with {} connections",
            selected_server.url, min_connections
        );
        Ok(selected_server)
    }

    fn increment_connections(&self, server_url: &str) {
        if let Some(counter) = self.connection_counts.get(server_url) {
            counter.fetch_add(1, Ordering::Relaxed);
        }
    }

    fn decrement_connections(&self, server_url: &str) {
        if let Some(counter) = self.connection_counts.get(server_url) {
            counter.fetch_sub(1, Ordering::Relaxed);
        }
    }

    pub fn get_upstream_url(&self, upstream: &UpstreamConfig) -> Result<Url> {
        Url::parse(&upstream.url).map_err(|_| Error::new_str("Invalid upstream URL"))
    }

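    /// Computes the upstream request path: strips the route prefix when
    /// `strip_prefix` is set; a configured `rewrite_target` takes precedence and
    /// replaces the path entirely.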
    pub fn transform_path(&self, route: &ProxyRoute, original_path: &str) -> String {
        let target_path = if route.strip_prefix {
            original_path
                .strip_prefix(&route.path)
                .unwrap_or(original_path)
        } else {
            original_path
        };

        if let Some(rewrite_target) = &route.rewrite_target {
            rewrite_target.clone()
        } else {
            format!("/{}", target_path.trim_start_matches('/'))
        }
    }

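    /// Adds proxy headers to the outgoing request: `X-Forwarded-For`,
    /// `X-Forwarded-Proto`, and `X-Forwarded-Host` when `add_x_forwarded` is
    /// enabled, an RFC 7239 `Forwarded` header when `add_forwarded` is enabled,
    /// plus any custom headers from the configuration.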
    pub fn add_proxy_headers(
        &self,
        req: &mut RequestHeader,
        session: &Session,
        original_host: &str,
    ) {
        let proto = if session.req_header().uri.scheme().map(|s| s.as_str()) == Some("https") {
            "https"
        } else {
            "http"
        };

        if self.proxy_config.headers.add_x_forwarded {
            if let Some(client_addr) = session.client_addr() {
                req.insert_header("X-Forwarded-For", client_addr.to_string())
                    .ok();
            }
            req.insert_header("X-Forwarded-Proto", proto).ok();
            req.insert_header("X-Forwarded-Host", original_host).ok();
        }

        if self.proxy_config.headers.add_forwarded {
            if let Some(client_addr) = session.client_addr() {
                let forwarded =
                    format!("for={};proto={};host={}", client_addr, proto, original_host);
                req.insert_header("Forwarded", forwarded).ok();
            }
        }

        for (key, value) in &self.proxy_config.headers.add {
            req.insert_header(key.clone(), value.clone()).ok();
        }
    }

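    /// Entry point for proxying a request. WebSocket upgrade requests are
    /// delegated to the WebSocket handler; otherwise the path is matched against
    /// the proxy routes and forwarded upstream. Returns `Ok(true)` when the
    /// request was handled (including error responses) and `Ok(false)` when no
    /// route matched.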
    pub async fn handle_proxy_request(
        &self,
        session: &mut Session,
        _site: &SiteConfig,
        path: &str,
    ) -> Result<bool> {
        if WebSocketProxyHandler::is_websocket_upgrade_request(session.req_header()) {
            info!("Detected WebSocket upgrade request for path: {}", path);
            return self
                .websocket_handler
                .handle_websocket_proxy(session, path)
                .await;
        }

        if let Some(route) = self.find_proxy_route(path) {
            info!("Proxying request {} to upstream '{}'", path, route.upstream);

            let upstream = match self.select_upstream(&route.upstream) {
                Ok(upstream) => upstream,
                Err(e) => {
                    error!("Failed to select upstream: {}", e);
                    self.send_error_response(session, 502, "Bad Gateway")
                        .await?;
                    return Ok(true);
                }
            };

            let upstream_url = match self.get_upstream_url(upstream) {
                Ok(url) => url,
                Err(e) => {
                    error!("Failed to parse upstream URL: {}", e);
                    self.send_error_response(session, 502, "Bad Gateway")
                        .await?;
                    return Ok(true);
                }
            };

            let new_path = self.transform_path(route, path);

            self.increment_connections(&upstream.url);

            let proxy_result = self
                .proxy_to_upstream(session, &upstream_url, &new_path, route)
                .await;

            self.decrement_connections(&upstream.url);

            match proxy_result {
                Ok(()) => {
                    info!("Successfully proxied request {} to {}", path, upstream.url);
                    Ok(true)
                }
                Err(e) => {
                    error!("Proxy request failed: {}", e);
                    self.send_error_response(session, 502, "Bad Gateway")
                        .await?;
                    Ok(true)
                }
            }
        } else {
            Ok(false)
        }
    }

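    /// Forwards the request to the upstream using a per-request `reqwest` client
    /// built with the configured read timeout, then buffers the upstream
    /// response and replays its status, headers, and body to the downstream
    /// client.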
    async fn proxy_to_upstream(
        &self,
        session: &mut Session,
        upstream_url: &Url,
        new_path: &str,
        _route: &ProxyRoute,
    ) -> Result<()> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(
                self.proxy_config.timeout.read,
            ))
            .build()
            .map_err(|_| Error::new_str("Failed to create HTTP client"))?;

        let original_host = session
            .req_header()
            .headers
            .get("Host")
            .and_then(|h| h.to_str().ok())
            .unwrap_or("localhost");

        let full_upstream_url = format!(
            "{}://{}{}{}",
            upstream_url.scheme(),
            upstream_url.host_str().unwrap_or("localhost"),
            upstream_url
                .port()
                .map(|p| format!(":{}", p))
                .unwrap_or_default(),
            new_path
        );

        debug!("Proxying to upstream URL: {}", full_upstream_url);

        let method = session.req_header().method.clone();
        let mut req_builder = match method.as_str() {
            "GET" => client.get(&full_upstream_url),
            "POST" => client.post(&full_upstream_url),
            "PUT" => client.put(&full_upstream_url),
            "DELETE" => client.delete(&full_upstream_url),
            "PATCH" => client.patch(&full_upstream_url),
            "HEAD" => client.head(&full_upstream_url),
            "OPTIONS" => client.request(reqwest::Method::OPTIONS, &full_upstream_url),
            _ => client.get(&full_upstream_url),
        };

        for (name, value) in session.req_header().headers.iter() {
            if let Ok(value_str) = value.to_str() {
                let name_str = name.as_str();
                if name_str.to_lowercase() != "host" {
                    req_builder = req_builder.header(name_str, value_str);
                }
            }
        }

        let mut temp_header = session.req_header().clone();
        self.add_proxy_headers(&mut temp_header, session, original_host);

        for (name, value) in temp_header.headers.iter() {
            if let Ok(value_str) = value.to_str() {
                let name_str = name.as_str();
                // Header names iterated from the underlying HeaderMap are
                // lowercase, so match the forwarding headers case-insensitively.
                if name_str.to_ascii_lowercase().starts_with("x-forwarded")
                    || name_str.eq_ignore_ascii_case("forwarded")
                {
                    req_builder = req_builder.header(name_str, value_str);
                }
            }
        }

        // Forward the downstream request body for methods that typically carry
        // one. Assumes Pingora's server session exposes `read_request_body`,
        // which yields body chunks until the stream is exhausted.
        let mut body = Vec::new();
        if matches!(method.as_str(), "POST" | "PUT" | "PATCH") {
            while let Some(chunk) = session.read_request_body().await? {
                body.extend_from_slice(&chunk);
            }
        }

        if !body.is_empty() {
            req_builder = req_builder.body(body);
        }

        let response = req_builder
            .send()
            .await
            .map_err(|_| Error::new_str("Upstream request failed"))?;

        let status = response.status().as_u16();

        // Collect upstream response headers as ordered pairs so repeated headers
        // (e.g. multiple Set-Cookie values) are preserved instead of being
        // deduplicated by a map.
        let mut header_pairs = Vec::new();
        for (name, value) in response.headers().iter() {
            if let Ok(value_str) = value.to_str() {
                header_pairs.push((name.as_str().to_string(), value_str.to_string()));
            }
        }

        let body = response
            .bytes()
            .await
            .map_err(|_| Error::new_str("Failed to read upstream response"))?;

        let mut resp_header = ResponseHeader::build(status, Some(4))?;

        for (name, value) in header_pairs {
            resp_header.append_header(name, value)?;
        }

        session
            .write_response_header(Box::new(resp_header), false)
            .await?;
        session.write_response_body(Some(body), true).await?;

        Ok(())
    }

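    /// Writes a small JSON error body (`error`, `status`, `timestamp`) with the
    /// given status code to the downstream client.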
    async fn send_error_response(
        &self,
        session: &mut Session,
        status_code: u16,
        message: &str,
    ) -> Result<()> {
        let error_response = serde_json::json!({
            "error": message,
            "status": status_code,
            "timestamp": chrono::Utc::now().to_rfc3339()
        });

        let response_body = error_response.to_string();
        let response_bytes = response_body.into_bytes();
        let mut header = ResponseHeader::build(status_code, Some(3))?;
        header.insert_header("Content-Type", "application/json")?;
        header.insert_header("Content-Length", response_bytes.len().to_string())?;

        session
            .write_response_header(Box::new(header), false)
            .await?;
        session
            .write_response_body(Some(response_bytes.into()), true)
            .await?;

        Ok(())
    }
}