bws_web_server/handlers/
proxy_handler.rs

1use crate::config::site::{ProxyConfig, ProxyRoute, SiteConfig, UpstreamConfig};
2use crate::handlers::websocket_proxy::WebSocketProxyHandler;
3use chrono;
4use log::{debug, error, info};
5use pingora::http::{RequestHeader, ResponseHeader};
6use pingora::prelude::*;
7use serde_json;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11use url::Url;
12
13pub struct ProxyHandler {
14    proxy_config: ProxyConfig,
15    upstreams: HashMap<String, Vec<UpstreamConfig>>,
16    round_robin_counters: HashMap<String, Arc<AtomicUsize>>,
17    connection_counts: HashMap<String, Arc<AtomicUsize>>,
18    websocket_handler: WebSocketProxyHandler,
19}
20
21impl ProxyHandler {
22    pub fn new(proxy_config: ProxyConfig) -> Self {
23        let mut upstreams = HashMap::new();
24        let mut round_robin_counters = HashMap::new();
25        let mut connection_counts = HashMap::new();
26
27        // Group upstreams by name
28        for upstream in &proxy_config.upstreams {
29            upstreams
30                .entry(upstream.name.clone())
31                .or_insert_with(Vec::new)
32                .push(upstream.clone());
33
34            round_robin_counters.insert(upstream.name.clone(), Arc::new(AtomicUsize::new(0)));
35
36            connection_counts.insert(upstream.url.clone(), Arc::new(AtomicUsize::new(0)));
37        }
38
39        Self {
40            proxy_config: proxy_config.clone(),
41            upstreams,
42            round_robin_counters,
43            connection_counts,
44            websocket_handler: WebSocketProxyHandler::new(proxy_config),
45        }
46    }
47
48    /// Find the appropriate proxy route for a given path
49    pub fn find_proxy_route(&self, path: &str) -> Option<&ProxyRoute> {
50        if !self.proxy_config.enabled {
51            return None;
52        }
53
54        // Find matching route - most specific first
55        self.proxy_config
56            .routes
57            .iter()
58            .filter(|route| path.starts_with(&route.path))
59            .max_by_key(|route| route.path.len())
60    }
61
62    /// Select an upstream server for a given upstream name
63    pub fn select_upstream(&self, upstream_name: &str) -> Result<&UpstreamConfig> {
64        let upstream_servers = self
65            .upstreams
66            .get(upstream_name)
67            .ok_or_else(|| Error::new_str("Upstream not found"))?;
68
69        if upstream_servers.is_empty() {
70            return Err(Error::new_str("No servers available for upstream"));
71        }
72
73        let upstream = match self.proxy_config.load_balancing.method.as_str() {
74            "round_robin" => self.select_round_robin(upstream_name, upstream_servers)?,
75            "weighted" => self.select_weighted(upstream_servers)?,
76            "least_connections" => self.select_least_connections(upstream_servers)?,
77            _ => &upstream_servers[0], // Default to first server
78        };
79
80        Ok(upstream)
81    }
82
83    /// Round-robin load balancing
84    fn select_round_robin<'a>(
85        &self,
86        upstream_name: &str,
87        servers: &'a [UpstreamConfig],
88    ) -> Result<&'a UpstreamConfig> {
89        let counter = self
90            .round_robin_counters
91            .get(upstream_name)
92            .ok_or_else(|| Error::new_str("Round robin counter not found"))?;
93
94        let index = counter.fetch_add(1, Ordering::Relaxed) % servers.len();
95        Ok(&servers[index])
96    }
97
98    /// Weighted load balancing
99    fn select_weighted<'a>(&self, servers: &'a [UpstreamConfig]) -> Result<&'a UpstreamConfig> {
100        let total_weight: u32 = servers.iter().map(|s| s.weight).sum();
101        if total_weight == 0 {
102            return Ok(&servers[0]);
103        }
104
105        let random_weight = fastrand::u32(1..=total_weight);
106        let mut current_weight = 0;
107
108        for server in servers {
109            current_weight += server.weight;
110            if random_weight <= current_weight {
111                return Ok(server);
112            }
113        }
114
115        Ok(&servers[0])
116    }
117
118    /// Least connections load balancing (uses actual connection tracking)
119    fn select_least_connections<'a>(
120        &self,
121        servers: &'a [UpstreamConfig],
122    ) -> Result<&'a UpstreamConfig> {
123        // Find the server with the least current connections
124        let mut min_connections = usize::MAX;
125        let mut selected_server = &servers[0];
126
127        for server in servers {
128            let connections = self
129                .connection_counts
130                .get(&server.url)
131                .map(|c| c.load(Ordering::Relaxed))
132                .unwrap_or(0);
133
134            if connections < min_connections {
135                min_connections = connections;
136                selected_server = server;
137            }
138        }
139
140        info!(
141            "Selected server '{}' with {} connections",
142            selected_server.url, min_connections
143        );
144        Ok(selected_server)
145    }
146
147    /// Increment connection count for a server
148    fn increment_connections(&self, server_url: &str) {
149        if let Some(counter) = self.connection_counts.get(server_url) {
150            counter.fetch_add(1, Ordering::Relaxed);
151        }
152    }
153
154    /// Decrement connection count for a server
155    fn decrement_connections(&self, server_url: &str) {
156        if let Some(counter) = self.connection_counts.get(server_url) {
157            counter.fetch_sub(1, Ordering::Relaxed);
158        }
159    }
160
161    /// Create HTTP peer for upstream server (simplified)
162    pub fn get_upstream_url(&self, upstream: &UpstreamConfig) -> Result<Url> {
163        Url::parse(&upstream.url).map_err(|_| Error::new_str("Invalid upstream URL"))
164    }
165
166    /// Transform request path according to route configuration
167    pub fn transform_path(&self, route: &ProxyRoute, original_path: &str) -> String {
168        let target_path = if route.strip_prefix {
169            original_path
170                .strip_prefix(&route.path)
171                .unwrap_or(original_path)
172        } else {
173            original_path
174        };
175
176        if let Some(rewrite_target) = &route.rewrite_target {
177            rewrite_target.clone()
178        } else {
179            format!("/{}", target_path.trim_start_matches('/'))
180        }
181    }
182
183    /// Add proxy headers to upstream request
184    pub fn add_proxy_headers(
185        &self,
186        req: &mut RequestHeader,
187        session: &Session,
188        original_host: &str,
189    ) {
190        if self.proxy_config.headers.add_x_forwarded {
191            if let Some(client_addr) = session.client_addr() {
192                req.insert_header("X-Forwarded-For", client_addr.to_string())
193                    .ok();
194            }
195        }
196
197        if self.proxy_config.headers.add_x_forwarded {
198            let proto = if session.req_header().uri.scheme().map(|s| s.as_str()) == Some("https") {
199                "https"
200            } else {
201                "http"
202            };
203            req.insert_header("X-Forwarded-Proto", proto).ok();
204        }
205
206        if self.proxy_config.headers.add_x_forwarded {
207            req.insert_header("X-Forwarded-Host", original_host).ok();
208        }
209
210        if self.proxy_config.headers.add_forwarded {
211            if let Some(client_addr) = session.client_addr() {
212                let proto =
213                    if session.req_header().uri.scheme().map(|s| s.as_str()) == Some("https") {
214                        "https"
215                    } else {
216                        "http"
217                    };
218                let forwarded =
219                    format!("for={};proto={};host={}", client_addr, proto, original_host);
220                req.insert_header("Forwarded", forwarded).ok();
221            }
222        }
223
224        // Add custom headers
225        for (key, value) in &self.proxy_config.headers.add {
226            req.insert_header(key.clone(), value.clone()).ok();
227        }
228    }
229
230    /// Handle a proxy request for a specific site and path
231    pub async fn handle_proxy_request(
232        &self,
233        session: &mut Session,
234        _site: &SiteConfig,
235        path: &str,
236    ) -> Result<bool> {
237        // Check if this is a WebSocket upgrade request
238        if WebSocketProxyHandler::is_websocket_upgrade_request(session.req_header()) {
239            info!("Detected WebSocket upgrade request for path: {}", path);
240            return self
241                .websocket_handler
242                .handle_websocket_proxy(session, path)
243                .await;
244        }
245
246        // Find matching route for regular HTTP proxy
247        if let Some(route) = self.find_proxy_route(path) {
248            info!("Proxying request {} to upstream '{}'", path, route.upstream);
249
250            // Select upstream server
251            let upstream = match self.select_upstream(&route.upstream) {
252                Ok(upstream) => upstream,
253                Err(e) => {
254                    error!("Failed to select upstream: {}", e);
255                    self.send_error_response(session, 502, "Bad Gateway")
256                        .await?;
257                    return Ok(true);
258                }
259            };
260
261            // Get upstream URL
262            let upstream_url = match self.get_upstream_url(upstream) {
263                Ok(url) => url,
264                Err(e) => {
265                    error!("Failed to parse upstream URL: {}", e);
266                    self.send_error_response(session, 502, "Bad Gateway")
267                        .await?;
268                    return Ok(true);
269                }
270            };
271
272            // Transform the request path
273            let new_path = self.transform_path(route, path);
274
275            // Track connection for load balancing
276            self.increment_connections(&upstream.url);
277
278            // Perform the proxy request
279            let proxy_result = self
280                .proxy_to_upstream(session, &upstream_url, &new_path, route)
281                .await;
282
283            // Always decrement connection count when done
284            self.decrement_connections(&upstream.url);
285
286            match proxy_result {
287                Ok(()) => {
288                    info!("Successfully proxied request {} to {}", path, upstream.url);
289                    Ok(true)
290                }
291                Err(e) => {
292                    error!("Proxy request failed: {}", e);
293                    self.send_error_response(session, 502, "Bad Gateway")
294                        .await?;
295                    Ok(true)
296                }
297            }
298        } else {
299            // No matching proxy route
300            Ok(false)
301        }
302    }
303
304    /// Perform the actual proxy request to upstream
305    async fn proxy_to_upstream(
306        &self,
307        session: &mut Session,
308        upstream_url: &Url,
309        new_path: &str,
310        _route: &ProxyRoute,
311    ) -> Result<()> {
312        // Create a new HTTP client for the upstream request
313        let client = reqwest::Client::builder()
314            .timeout(std::time::Duration::from_secs(
315                self.proxy_config.timeout.read,
316            ))
317            .build()
318            .map_err(|_| Error::new_str("Failed to create HTTP client"))?;
319
320        // Get original host header
321        let original_host = session
322            .req_header()
323            .headers
324            .get("Host")
325            .and_then(|h| h.to_str().ok())
326            .unwrap_or("localhost");
327
328        // Build upstream URL with new path
329        let full_upstream_url = format!(
330            "{}://{}{}{}",
331            upstream_url.scheme(),
332            upstream_url.host_str().unwrap_or("localhost"),
333            upstream_url
334                .port()
335                .map(|p| format!(":{}", p))
336                .unwrap_or_default(),
337            new_path
338        );
339
340        debug!("Proxying to upstream URL: {}", full_upstream_url);
341
342        // Create upstream request
343        let method = session.req_header().method.clone();
344        let mut req_builder = match method.as_str() {
345            "GET" => client.get(&full_upstream_url),
346            "POST" => client.post(&full_upstream_url),
347            "PUT" => client.put(&full_upstream_url),
348            "DELETE" => client.delete(&full_upstream_url),
349            "PATCH" => client.patch(&full_upstream_url),
350            "HEAD" => client.head(&full_upstream_url),
351            "OPTIONS" => client.request(reqwest::Method::OPTIONS, &full_upstream_url),
352            _ => client.get(&full_upstream_url), // Default to GET
353        };
354
355        // Copy headers from original request
356        for (name, value) in session.req_header().headers.iter() {
357            if let Ok(value_str) = value.to_str() {
358                let name_str = name.as_str();
359                // Skip host header as we'll set it appropriately
360                if name_str.to_lowercase() != "host" {
361                    req_builder = req_builder.header(name_str, value_str);
362                }
363            }
364        }
365
366        // Add proxy headers
367        let mut temp_header = session.req_header().clone();
368        self.add_proxy_headers(&mut temp_header, session, original_host);
369
370        // Copy proxy headers to request
371        for (name, value) in temp_header.headers.iter() {
372            if let Ok(value_str) = value.to_str() {
373                let name_str = name.as_str();
374                if name_str.starts_with("X-Forwarded") || name_str == "Forwarded" {
375                    req_builder = req_builder.header(name_str, value_str);
376                }
377            }
378        }
379
380        // Read request body if present
381        let body = if method.as_str() == "POST"
382            || method.as_str() == "PUT"
383            || method.as_str() == "PATCH"
384        {
385            // For now, we'll handle requests without body.
386            // Full body proxying would require reading from session.read_request_body()
387            Vec::new()
388        } else {
389            Vec::new()
390        };
391
392        if !body.is_empty() {
393            req_builder = req_builder.body(body);
394        }
395
396        // Send request to upstream
397        let response = req_builder
398            .send()
399            .await
400            .map_err(|_| Error::new_str("Upstream request failed"))?;
401
402        // Get response status
403        let status = response.status().as_u16();
404
405        // Collect headers before consuming response
406        let mut header_map = std::collections::HashMap::new();
407        for (name, value) in response.headers().iter() {
408            if let Ok(value_str) = value.to_str() {
409                header_map.insert(name.as_str().to_string(), value_str.to_string());
410            }
411        }
412
413        // Get response body (this consumes the response)
414        let body = response
415            .bytes()
416            .await
417            .map_err(|_| Error::new_str("Failed to read upstream response"))?;
418
419        // Build response header
420        let mut resp_header = ResponseHeader::build(status, Some(4))?;
421
422        // Add collected headers
423        for (name, value) in header_map {
424            resp_header.insert_header(name, value)?;
425        }
426
427        // Send response back to client
428        session
429            .write_response_header(Box::new(resp_header), false)
430            .await?;
431        session.write_response_body(Some(body), true).await?;
432
433        Ok(())
434    }
435
436    /// Send an error response
437    async fn send_error_response(
438        &self,
439        session: &mut Session,
440        status_code: u16,
441        message: &str,
442    ) -> Result<()> {
443        let error_response = serde_json::json!({
444            "error": message,
445            "status": status_code,
446            "timestamp": chrono::Utc::now().to_rfc3339()
447        });
448
449        let response_body = error_response.to_string();
450        let response_bytes = response_body.into_bytes();
451        let mut header = ResponseHeader::build(status_code, Some(3))?;
452        header.insert_header("Content-Type", "application/json")?;
453        header.insert_header("Content-Length", response_bytes.len().to_string())?;
454
455        session
456            .write_response_header(Box::new(header), false)
457            .await?;
458        session
459            .write_response_body(Some(response_bytes.into()), true)
460            .await?;
461
462        Ok(())
463    }
464}