bws_web_server/handlers/
proxy_handler.rs

1use crate::config::site::{ProxyConfig, ProxyRoute, SiteConfig, UpstreamConfig};
2use crate::handlers::websocket_proxy::WebSocketProxyHandler;
3use crate::middleware::compression::{CompressionMethod, CompressionMiddleware};
4use chrono;
5use log::{debug, error, info};
6use pingora::http::{RequestHeader, ResponseHeader};
7use pingora::prelude::*;
8use serde_json;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12use url::Url;
13
14pub struct ProxyHandler {
15    /// Proxy configuration for the site, including routes and upstreams
16    proxy_config: ProxyConfig,
17    /// Map of upstream name to list of upstream server configs
18    upstreams: HashMap<String, Vec<UpstreamConfig>>,
19    /// Round-robin counters for each upstream
20    round_robin_counters: HashMap<String, Arc<AtomicUsize>>,
21    /// Connection counts for each upstream server (for least-connections balancing)
22    connection_counts: HashMap<String, Arc<AtomicUsize>>,
23    /// Handler for WebSocket proxying
24    websocket_handler: WebSocketProxyHandler,
25}
26
27impl ProxyHandler {
28    /// Create a new ProxyHandler from the given proxy configuration
29    pub fn new(proxy_config: ProxyConfig) -> Self {
30        let mut upstreams = HashMap::new();
31        let mut round_robin_counters = HashMap::new();
32        let mut connection_counts = HashMap::new();
33
34        // Group upstreams by name
35        for upstream in &proxy_config.upstreams {
36            upstreams
37                .entry(upstream.name.clone())
38                .or_insert_with(Vec::new)
39                .push(upstream.clone());
40
41            round_robin_counters.insert(upstream.name.clone(), Arc::new(AtomicUsize::new(0)));
42
43            connection_counts.insert(upstream.url.clone(), Arc::new(AtomicUsize::new(0)));
44        }
45
46        Self {
47            proxy_config: proxy_config.clone(),
48            upstreams,
49            round_robin_counters,
50            connection_counts,
51            websocket_handler: WebSocketProxyHandler::new(proxy_config),
52        }
53    }
54
55    /// Find the appropriate proxy route for a given path
56    /// Find the most specific proxy route for a given request path
57    /// Returns None if proxying is disabled or no route matches.
58    pub fn find_proxy_route(&self, path: &str) -> Option<&ProxyRoute> {
59        if !self.proxy_config.enabled {
60            return None;
61        }
62
63        // Find matching route - most specific first
64        self.proxy_config
65            .routes
66            .iter()
67            .filter(|route| path.starts_with(&route.path))
68            .max_by_key(|route| route.path.len())
69    }
70
71    /// Select an upstream server for a given upstream name
72    /// Select an upstream server for a given upstream name using the configured load balancing method.
73    /// Returns an error if no upstreams are available.
74    pub fn select_upstream(&self, upstream_name: &str) -> Result<&UpstreamConfig> {
75        let upstream_servers = self
76            .upstreams
77            .get(upstream_name)
78            .ok_or_else(|| Error::new_str("Upstream not found"))?;
79
80        if upstream_servers.is_empty() {
81            return Err(Error::new_str("No servers available for upstream"));
82        }
83
84        let upstream = match self.proxy_config.load_balancing.method.as_str() {
85            "round_robin" => self.select_round_robin(upstream_name, upstream_servers)?,
86            "weighted" => self.select_weighted(upstream_servers)?,
87            "least_connections" => self.select_least_connections(upstream_servers)?,
88            _ => &upstream_servers[0], // Default to first server
89        };
90
91        Ok(upstream)
92    }
93
94    /// Round-robin load balancing
95    /// Select an upstream server using round-robin load balancing.
96    fn select_round_robin<'a>(
97        &self,
98        upstream_name: &str,
99        servers: &'a [UpstreamConfig],
100    ) -> Result<&'a UpstreamConfig> {
101        let counter = self
102            .round_robin_counters
103            .get(upstream_name)
104            .ok_or_else(|| Error::new_str("Round robin counter not found"))?;
105
106        let index = counter.fetch_add(1, Ordering::Relaxed) % servers.len();
107        Ok(&servers[index])
108    }
109
110    /// Weighted load balancing
111    /// Select an upstream server using weighted random selection.
112    fn select_weighted<'a>(&self, servers: &'a [UpstreamConfig]) -> Result<&'a UpstreamConfig> {
113        let total_weight: u32 = servers.iter().map(|s| s.weight).sum();
114        if total_weight == 0 {
115            return Ok(&servers[0]);
116        }
117
118        let random_weight = fastrand::u32(1..=total_weight);
119        let mut current_weight = 0;
120
121        for server in servers {
122            current_weight += server.weight;
123            if random_weight <= current_weight {
124                return Ok(server);
125            }
126        }
127
128        Ok(&servers[0])
129    }
130
131    /// Least connections load balancing (uses actual connection tracking)
132    /// Select an upstream server with the least number of active connections.
133    fn select_least_connections<'a>(
134        &self,
135        servers: &'a [UpstreamConfig],
136    ) -> Result<&'a UpstreamConfig> {
137        // Find the server with the least current connections
138        let mut min_connections = usize::MAX;
139        let mut selected_server = &servers[0];
140
141        for server in servers {
142            let connections = self
143                .connection_counts
144                .get(&server.url)
145                .map(|c| c.load(Ordering::Relaxed))
146                .unwrap_or(0);
147
148            if connections < min_connections {
149                min_connections = connections;
150                selected_server = server;
151            }
152        }
153
154        info!(
155            "Selected server '{}' with {} connections",
156            selected_server.url, min_connections
157        );
158        Ok(selected_server)
159    }
160
161    /// Increment connection count for a server
162    fn increment_connections(&self, server_url: &str) {
163        if let Some(counter) = self.connection_counts.get(server_url) {
164            counter.fetch_add(1, Ordering::Relaxed);
165        }
166    }
167
168    /// Decrement connection count for a server
169    fn decrement_connections(&self, server_url: &str) {
170        if let Some(counter) = self.connection_counts.get(server_url) {
171            counter.fetch_sub(1, Ordering::Relaxed);
172        }
173    }
174
175    /// Create HTTP peer for upstream server (simplified)
176    pub fn get_upstream_url(&self, upstream: &UpstreamConfig) -> Result<Url> {
177        Url::parse(&upstream.url).map_err(|_| Error::new_str("Invalid upstream URL"))
178    }
179
180    /// Transform request path according to route configuration
181    pub fn transform_path(&self, route: &ProxyRoute, original_path: &str) -> String {
182        let target_path = if route.strip_prefix {
183            original_path
184                .strip_prefix(&route.path)
185                .unwrap_or(original_path)
186        } else {
187            original_path
188        };
189
190        if let Some(rewrite_target) = &route.rewrite_target {
191            rewrite_target.clone()
192        } else {
193            format!("/{}", target_path.trim_start_matches('/'))
194        }
195    }
196
197    /// Add proxy headers to upstream request
198    pub fn add_proxy_headers(
199        &self,
200        req: &mut RequestHeader,
201        session: &Session,
202        original_host: &str,
203    ) {
204        if self.proxy_config.headers.add_x_forwarded {
205            if let Some(client_addr) = session.client_addr() {
206                req.insert_header("X-Forwarded-For", client_addr.to_string())
207                    .ok();
208            }
209        }
210
211        if self.proxy_config.headers.add_x_forwarded {
212            let proto = if session.req_header().uri.scheme().map(|s| s.as_str()) == Some("https") {
213                "https"
214            } else {
215                "http"
216            };
217            req.insert_header("X-Forwarded-Proto", proto).ok();
218        }
219
220        if self.proxy_config.headers.add_x_forwarded {
221            req.insert_header("X-Forwarded-Host", original_host).ok();
222        }
223
224        if self.proxy_config.headers.add_forwarded {
225            if let Some(client_addr) = session.client_addr() {
226                let proto =
227                    if session.req_header().uri.scheme().map(|s| s.as_str()) == Some("https") {
228                        "https"
229                    } else {
230                        "http"
231                    };
232                let forwarded =
233                    format!("for={};proto={};host={}", client_addr, proto, original_host);
234                req.insert_header("Forwarded", forwarded).ok();
235            }
236        }
237
238        // Add custom headers
239        for (key, value) in &self.proxy_config.headers.add {
240            req.insert_header(key.clone(), value.clone()).ok();
241        }
242    }
243
244    /// Handle a proxy request for a specific site and path
245    pub async fn handle_proxy_request(
246        &self,
247        session: &mut Session,
248        site: &SiteConfig,
249        path: &str,
250    ) -> Result<bool> {
251        // Check if this is a WebSocket upgrade request
252        if WebSocketProxyHandler::is_websocket_upgrade_request(session.req_header()) {
253            info!("Detected WebSocket upgrade request for path: {}", path);
254            return self
255                .websocket_handler
256                .handle_websocket_proxy(session, path)
257                .await;
258        }
259
260        // Find matching route for regular HTTP proxy
261        if let Some(route) = self.find_proxy_route(path) {
262            info!("Proxying request {} to upstream '{}'", path, route.upstream);
263
264            // Select upstream server
265            let upstream = match self.select_upstream(&route.upstream) {
266                Ok(upstream) => upstream,
267                Err(e) => {
268                    error!("Failed to select upstream: {}", e);
269                    self.send_error_response(session, 502, "Bad Gateway")
270                        .await?;
271                    return Ok(true);
272                }
273            };
274
275            // Get upstream URL
276            let upstream_url = match self.get_upstream_url(upstream) {
277                Ok(url) => url,
278                Err(e) => {
279                    error!("Failed to parse upstream URL: {}", e);
280                    self.send_error_response(session, 502, "Bad Gateway")
281                        .await?;
282                    return Ok(true);
283                }
284            };
285
286            // Transform the request path
287            let new_path = self.transform_path(route, path);
288
289            // Track connection for load balancing
290            self.increment_connections(&upstream.url);
291
292            // Perform the proxy request
293            let proxy_result = self
294                .proxy_to_upstream(session, &upstream_url, &new_path, route, site)
295                .await;
296
297            // Always decrement connection count when done
298            self.decrement_connections(&upstream.url);
299
300            match proxy_result {
301                Ok(()) => {
302                    info!("Successfully proxied request {} to {}", path, upstream.url);
303                    Ok(true)
304                }
305                Err(e) => {
306                    error!("Proxy request failed: {}", e);
307                    self.send_error_response(session, 502, "Bad Gateway")
308                        .await?;
309                    Ok(true)
310                }
311            }
312        } else {
313            // No matching proxy route
314            Ok(false)
315        }
316    }
317
318    /// Perform the actual proxy request to upstream
319    async fn proxy_to_upstream(
320        &self,
321        session: &mut Session,
322        upstream_url: &Url,
323        new_path: &str,
324        _route: &ProxyRoute,
325        site: &SiteConfig,
326    ) -> Result<()> {
327        // Create a new HTTP client for the upstream request
328        let client = reqwest::Client::builder()
329            .timeout(std::time::Duration::from_secs(
330                self.proxy_config.timeout.read,
331            ))
332            .build()
333            .map_err(|_| Error::new_str("Failed to create HTTP client"))?;
334
335        // Get original host header
336        let original_host = session
337            .req_header()
338            .headers
339            .get("Host")
340            .and_then(|h| h.to_str().ok())
341            .unwrap_or("localhost");
342
343        // Build upstream URL with new path
344        let full_upstream_url = format!(
345            "{}://{}{}{}",
346            upstream_url.scheme(),
347            upstream_url.host_str().unwrap_or("localhost"),
348            upstream_url
349                .port()
350                .map(|p| format!(":{}", p))
351                .unwrap_or_default(),
352            new_path
353        );
354
355        debug!("Proxying to upstream URL: {}", full_upstream_url);
356
357        // Create upstream request
358        let method = session.req_header().method.clone();
359        let mut req_builder = match method.as_str() {
360            "GET" => client.get(&full_upstream_url),
361            "POST" => client.post(&full_upstream_url),
362            "PUT" => client.put(&full_upstream_url),
363            "DELETE" => client.delete(&full_upstream_url),
364            "PATCH" => client.patch(&full_upstream_url),
365            "HEAD" => client.head(&full_upstream_url),
366            "OPTIONS" => client.request(reqwest::Method::OPTIONS, &full_upstream_url),
367            _ => client.get(&full_upstream_url), // Default to GET
368        };
369
370        // Copy headers from original request
371        for (name, value) in session.req_header().headers.iter() {
372            if let Ok(value_str) = value.to_str() {
373                let name_str = name.as_str();
374                // Skip host header as we'll set it appropriately
375                if name_str.to_lowercase() != "host" {
376                    req_builder = req_builder.header(name_str, value_str);
377                }
378            }
379        }
380
381        // Add proxy headers
382        let mut temp_header = session.req_header().clone();
383        self.add_proxy_headers(&mut temp_header, session, original_host);
384
385        // Copy proxy headers to request
386        for (name, value) in temp_header.headers.iter() {
387            if let Ok(value_str) = value.to_str() {
388                let name_str = name.as_str();
389                if name_str.starts_with("X-Forwarded") || name_str == "Forwarded" {
390                    req_builder = req_builder.header(name_str, value_str);
391                }
392            }
393        }
394
395        // Read request body if present
396        let body = if method.as_str() == "POST"
397            || method.as_str() == "PUT"
398            || method.as_str() == "PATCH"
399        {
400            // For now, we'll handle requests without body.
401            // Full body proxying would require reading from session.read_request_body()
402            Vec::new()
403        } else {
404            Vec::new()
405        };
406
407        if !body.is_empty() {
408            req_builder = req_builder.body(body);
409        }
410
411        // Send request to upstream
412        let response = req_builder
413            .send()
414            .await
415            .map_err(|_| Error::new_str("Upstream request failed"))?;
416
417        // Get response status
418        let status = response.status().as_u16();
419
420        // Collect headers before consuming response
421        let mut header_map = std::collections::HashMap::new();
422        for (name, value) in response.headers().iter() {
423            if let Ok(value_str) = value.to_str() {
424                header_map.insert(name.as_str().to_string(), value_str.to_string());
425            }
426        }
427
428        // Get response body (this consumes the response)
429        let body_bytes = response
430            .bytes()
431            .await
432            .map_err(|_| Error::new_str("Failed to read upstream response"))?;
433
434        // Check if response should be compressed
435        let content_type = header_map
436            .get("content-type")
437            .unwrap_or(&"application/octet-stream".to_string())
438            .clone();
439
440        let compression_middleware = CompressionMiddleware::new(site.compression.clone());
441
442        let (final_body, encoding) = if compression_middleware
443            .should_compress(&content_type, body_bytes.len())
444        {
445            // Get the best compression method based on Accept-Encoding header
446            let accept_encoding = session
447                .req_header()
448                .headers
449                .get("accept-encoding")
450                .and_then(|h| h.to_str().ok());
451
452            let compression_method = compression_middleware.get_best_compression(accept_encoding);
453
454            match compression_middleware.compress(&body_bytes, compression_method.clone()) {
455                Ok(compressed_content) => {
456                    debug!(
457                        "Compressed proxy response {} bytes to {} bytes using {:?} ({}% reduction)",
458                        body_bytes.len(),
459                        compressed_content.len(),
460                        compression_method,
461                        if !body_bytes.is_empty() {
462                            ((body_bytes.len() - compressed_content.len()) * 100) / body_bytes.len()
463                        } else {
464                            0
465                        }
466                    );
467                    (compressed_content, Some(compression_method))
468                }
469                Err(e) => {
470                    debug!("Compression failed: {}, serving uncompressed", e);
471                    (body_bytes.to_vec().into(), None)
472                }
473            }
474        } else {
475            (body_bytes.to_vec().into(), None)
476        };
477
478        // Build response header
479        let mut resp_header = ResponseHeader::build(status, Some(4))?;
480
481        // Add collected headers (except content-length which we'll update)
482        for (name, value) in header_map {
483            if name.to_lowercase() != "content-length" {
484                resp_header.insert_header(name, value)?;
485            }
486        }
487
488        // Update content length and add encoding header
489        resp_header.insert_header("Content-Length", final_body.len().to_string())?;
490
491        if let Some(method) = encoding {
492            if !matches!(method, CompressionMethod::None) {
493                resp_header.insert_header("Content-Encoding", method.as_str())?;
494                resp_header.insert_header("Vary", "Accept-Encoding")?;
495            }
496        }
497
498        // Send response back to client
499        session
500            .write_response_header(Box::new(resp_header), false)
501            .await?;
502        session.write_response_body(Some(final_body), true).await?;
503
504        Ok(())
505    }
506
507    /// Send an error response
508    async fn send_error_response(
509        &self,
510        session: &mut Session,
511        status_code: u16,
512        message: &str,
513    ) -> Result<()> {
514        let error_response = serde_json::json!({
515            "error": message,
516            "status": status_code,
517            "timestamp": chrono::Utc::now().to_rfc3339()
518        });
519
520        let response_body = error_response.to_string();
521        let response_bytes = response_body.into_bytes();
522        let mut header = ResponseHeader::build(status_code, Some(3))?;
523        header.insert_header("Content-Type", "application/json")?;
524        header.insert_header("Content-Length", response_bytes.len().to_string())?;
525
526        session
527            .write_response_header(Box::new(header), false)
528            .await?;
529        session
530            .write_response_body(Some(response_bytes.into()), true)
531            .await?;
532
533        Ok(())
534    }
535}