Skip to main content

gatel_core/proxy/
mod.rs

1pub mod cgi;
2pub mod circuit_breaker;
3pub mod dns_upstream;
4pub mod fastcgi;
5pub mod forward_proxy;
6pub mod health;
7pub mod lb;
8pub mod scgi;
9pub mod srv_upstream;
10#[cfg(unix)]
11pub mod unix_upstream;
12pub mod upstream;
13pub mod websocket;
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use http::{Request, Response, Uri};
19use http_body_util::BodyExt;
20use regex::Regex;
21use tracing::{debug, warn};
22
23use self::health::{HealthChecker, PassiveHealthChecker};
24use self::lb::*;
25use self::upstream::UpstreamPool;
26use crate::config::{LbPolicy, ProxyConfig};
27use crate::{Body, ProxyError, goals};
28
/// Reverse proxy handler: forwards requests to upstream backends with
/// load balancing, health checking, and retry support.
///
/// Constructed from a [`ProxyConfig`] by [`ReverseProxy::new`] and used as a
/// Salvo handler.
pub struct ReverseProxy {
    /// Upstream backends; shared (via `Arc`) with the active health checker.
    pool: Arc<UpstreamPool>,
    /// Backend selection strategy, chosen from the config's `lb` policy.
    lb: Box<dyn LoadBalancer>,
    /// `header-up` directives as `(name, value)` pairs. A name starting with
    /// `-` removes that header from the upstream request; otherwise the header
    /// is set, with `{client_ip}` in the value replaced by the client IP.
    headers_up: Vec<(String, String)>,
    /// `header-down` directives applied to upstream responses; a name starting
    /// with `-` removes the header, otherwise it is set to the value.
    headers_down: Vec<(String, String)>,
    /// Number of additional attempts after the first one when the upstream
    /// request fails or returns a 5xx status.
    retries: u32,
    /// Active health checker (holds the background task; dropped with the proxy).
    _health_checker: Option<HealthChecker>,
    /// Passive health checker (tracks 5xx responses).
    passive_health: Option<Arc<PassiveHealthChecker>>,
    /// Custom response bodies substituted when upstream returns these status codes.
    error_pages: HashMap<u16, String>,
    /// Compiled regex replacement rules for upstream request headers.
    /// Each entry is `(header_name, compiled_regex, replacement)`.
    headers_up_replace: Vec<(String, Regex, String)>,
    /// When true, normalize the URI path before forwarding (collapse double
    /// slashes, resolve `.` and `..` segments).
    sanitize_uri: bool,
}
50
51impl ReverseProxy {
52    pub fn new(config: &ProxyConfig) -> Self {
53        let pool = Arc::new(UpstreamPool::from_config(config));
54
55        // Collect weights for weighted strategies.
56        let weights: Vec<u32> = config.upstreams.iter().map(|u| u.weight).collect();
57
58        let lb: Box<dyn LoadBalancer> = match config.lb {
59            LbPolicy::RoundRobin => Box::new(RoundRobinLb::new()),
60            LbPolicy::Random => Box::new(RandomLb::new()),
61            LbPolicy::WeightedRoundRobin => Box::new(WeightedRoundRobinLb::new(&weights)),
62            LbPolicy::IpHash => Box::new(IpHashLb::new()),
63            LbPolicy::LeastConn => Box::new(LeastConnLb::new()),
64            LbPolicy::UriHash => Box::new(UriHashLb::new()),
65            LbPolicy::HeaderHash => {
66                let name = config
67                    .lb_header
68                    .clone()
69                    .unwrap_or_else(|| "X-Forwarded-For".to_string());
70                Box::new(HeaderHashLb::new(name))
71            }
72            LbPolicy::CookieHash => {
73                let name = config
74                    .lb_cookie
75                    .clone()
76                    .unwrap_or_else(|| "session".to_string());
77                Box::new(CookieHashLb::new(name))
78            }
79            LbPolicy::First => Box::new(FirstLb::new()),
80            LbPolicy::TwoRandomChoices => Box::new(TwoRandomChoicesLb::new()),
81        };
82
83        let headers_up: Vec<(String, String)> = config
84            .headers_up
85            .iter()
86            .map(|(k, v)| (k.clone(), v.clone()))
87            .collect();
88        let headers_down: Vec<(String, String)> = config
89            .headers_down
90            .iter()
91            .map(|(k, v)| (k.clone(), v.clone()))
92            .collect();
93
94        // Start active health checker if configured.
95        let health_checker = config
96            .health_check
97            .as_ref()
98            .map(|hc| HealthChecker::start(Arc::clone(&pool), hc));
99
100        // Create passive health checker if configured.
101        let passive_health = config
102            .passive_health
103            .as_ref()
104            .map(|ph| Arc::new(PassiveHealthChecker::new(pool.len(), ph)));
105
106        // Compile regex replacement rules for upstream headers.
107        let headers_up_replace: Vec<(String, Regex, String)> = config
108            .headers_up_replace
109            .iter()
110            .filter_map(|(name, pattern, replacement)| match Regex::new(pattern) {
111                Ok(re) => Some((name.clone(), re, replacement.clone())),
112                Err(e) => {
113                    warn!(
114                        header = name.as_str(),
115                        pattern = pattern.as_str(),
116                        error = %e,
117                        "invalid regex in header-up-replace, skipping"
118                    );
119                    None
120                }
121            })
122            .collect();
123
124        Self {
125            pool,
126            lb,
127            headers_up,
128            headers_down,
129            retries: config.retries,
130            _health_checker: health_checker,
131            passive_health,
132            error_pages: config.error_pages.clone(),
133            headers_up_replace,
134            sanitize_uri: config.sanitize_uri,
135        }
136    }
137}
138
139#[salvo::async_trait]
140impl salvo::Handler for ReverseProxy {
141    async fn handle(
142        &self,
143        req: &mut salvo::Request,
144        _depot: &mut salvo::Depot,
145        res: &mut salvo::Response,
146        ctrl: &mut salvo::FlowCtrl,
147    ) {
148        let client_addr = crate::hoops::client_addr(req);
149        let request = match goals::strip_request(req) {
150            Ok(r) => r,
151            Err(e) => {
152                goals::merge_response(res, e.into_response());
153                ctrl.skip_rest();
154                return;
155            }
156        };
157        let response = self
158            .proxy(request, client_addr)
159            .await
160            .unwrap_or_else(|e| e.into_response());
161        goals::merge_response(res, response);
162        ctrl.skip_rest();
163    }
164}
165
166impl ReverseProxy {
167    async fn proxy(
168        &self,
169        request: Request<Body>,
170        client_addr: std::net::SocketAddr,
171    ) -> Result<Response<Body>, ProxyError> {
172        // --- WebSocket upgrade detection ---
173        // If the request is a WebSocket upgrade, use the dedicated
174        // WebSocket proxy path instead of the normal HTTP proxy.
175        if websocket::is_websocket_upgrade(&request) {
176            debug!(client = %client_addr, "detected WebSocket upgrade request");
177
178            // Select a backend for the WebSocket connection.
179            let ws_lb_ctx = LbContext {
180                client_addr,
181                uri: request
182                    .uri()
183                    .path_and_query()
184                    .map(|pq| pq.as_str().to_string())
185                    .unwrap_or_else(|| "/".to_string()),
186                headers: request.headers().clone(),
187            };
188
189            let backend_idx = self
190                .lb
191                .select(&self.pool, &ws_lb_ctx)
192                .ok_or(ProxyError::NoUpstream)?;
193            let backend = &self.pool.backends[backend_idx];
194            let _conn_guard = self.pool.acquire_conn(backend_idx);
195
196            return websocket::proxy_websocket(request, &backend.addr).await;
197        }
198
199        // Build LbContext from the incoming request.
200        let lb_ctx = LbContext {
201            client_addr,
202            uri: request
203                .uri()
204                .path_and_query()
205                .map(|pq| pq.as_str().to_string())
206                .unwrap_or_else(|| "/".to_string()),
207            headers: request.headers().clone(),
208        };
209
210        // --- Connection limit check ---
211        if let Some(max_conns) = self.pool.max_connections {
212            let total = self.pool.total_active_conns();
213            if total >= max_conns {
214                warn!(
215                    limit = max_conns,
216                    active = total,
217                    "connection limit exceeded, returning 503"
218                );
219                return Response::builder()
220                    .status(http::StatusCode::SERVICE_UNAVAILABLE)
221                    .body(crate::full_body(
222                        "Service Unavailable: connection limit exceeded",
223                    ))
224                    .map_err(|e| ProxyError::Internal(e.to_string()));
225            }
226        }
227
228        // Collect the body bytes so we can replay on retries.
229        let (parts, body) = request.into_parts();
230        let body_bytes = body
231            .collect()
232            .await
233            .map_err(|e| ProxyError::Internal(format!("failed to buffer body: {e}")))?
234            .to_bytes();
235
236        let max_attempts = 1 + self.retries;
237        let mut last_failed_idx: Option<usize> = None;
238        let mut last_error: Option<ProxyError> = None;
239
240        for attempt in 0..max_attempts {
241            // --- Select a backend ---
242            let backend_idx = {
243                let idx = self.lb.select(&self.pool, &lb_ctx);
244                match idx {
245                    Some(i) if last_failed_idx == Some(i) && self.pool.len() > 1 => {
246                        // On retry, try to skip the backend that just failed.
247                        self.lb.select(&self.pool, &lb_ctx)
248                    }
249                    other => other,
250                }
251            };
252
253            let backend_idx = match backend_idx {
254                Some(i) => i,
255                None => {
256                    return Err(last_error.unwrap_or(ProxyError::NoUpstream));
257                }
258            };
259
260            let backend = &self.pool.backends[backend_idx];
261
262            // --- Build the upstream request ---
263            let mut req_parts = parts.clone();
264
265            // Optionally sanitize the URI path before forwarding.
266            if self.sanitize_uri {
267                let raw_pq = req_parts
268                    .uri
269                    .path_and_query()
270                    .map(|pq| pq.as_str().to_string())
271                    .unwrap_or_else(|| "/".to_string());
272                let (raw_path, raw_query) = if let Some(pos) = raw_pq.find('?') {
273                    (&raw_pq[..pos], Some(&raw_pq[pos + 1..]))
274                } else {
275                    (raw_pq.as_str(), None)
276                };
277                let sanitized_path = sanitize_path(raw_path);
278                let sanitized_pq = match raw_query {
279                    Some(q) if !q.is_empty() => format!("{sanitized_path}?{q}"),
280                    _ => sanitized_path,
281                };
282                if let Ok(new_uri) = sanitized_pq.parse::<http::uri::PathAndQuery>() {
283                    // Rebuild the URI with sanitized path-and-query.
284                    let mut builder = Uri::builder();
285                    if let Some(scheme) = req_parts.uri.scheme() {
286                        builder = builder.scheme(scheme.clone());
287                    }
288                    if let Some(authority) = req_parts.uri.authority() {
289                        builder = builder.authority(authority.clone());
290                    }
291                    builder = builder.path_and_query(new_uri);
292                    if let Ok(u) = builder.build() {
293                        req_parts.uri = u;
294                    }
295                }
296            }
297
298            // Use the scheme embedded in the backend address if present,
299            // otherwise default to "http://".
300            let scheme =
301                if backend.addr.starts_with("https://") || backend.addr.starts_with("http://") {
302                    ""
303                } else {
304                    "http://"
305                };
306            let upstream_uri = format!(
307                "{}{}{}",
308                scheme,
309                backend.addr,
310                req_parts
311                    .uri
312                    .path_and_query()
313                    .map(|pq| pq.as_str())
314                    .unwrap_or("/")
315            );
316            req_parts.uri = match upstream_uri.parse::<Uri>() {
317                Ok(u) => u,
318                Err(e) => {
319                    return Err(ProxyError::Internal(format!("invalid upstream URI: {e}")));
320                }
321            };
322
323            // Set the Host header to the upstream.
324            if let Ok(hv) = backend.addr.parse() {
325                req_parts.headers.insert(http::header::HOST, hv);
326            }
327
328            // Apply header-up directives.
329            for (name, value) in &self.headers_up {
330                if let Some(hdr_name) = name.strip_prefix('-') {
331                    if let Ok(hn) = hdr_name.parse::<http::header::HeaderName>() {
332                        req_parts.headers.remove(hn);
333                    }
334                } else {
335                    let expanded = value.replace("{client_ip}", &client_addr.ip().to_string());
336                    if let (Ok(hn), Ok(hv)) = (
337                        name.parse::<http::header::HeaderName>(),
338                        expanded.parse::<http::header::HeaderValue>(),
339                    ) {
340                        req_parts.headers.insert(hn, hv);
341                    }
342                }
343            }
344
345            // Apply header-up-replace directives (regex substitution on existing values).
346            for (name, re, replacement) in &self.headers_up_replace {
347                if let Ok(hn) = name.parse::<http::header::HeaderName>()
348                    && let Some(existing) = req_parts.headers.get(&hn)
349                    && let Ok(existing_str) = existing.to_str()
350                {
351                    let new_value = re.replace_all(existing_str, replacement.as_str());
352                    if let Ok(hv) = new_value.as_ref().parse::<http::header::HeaderValue>() {
353                        req_parts.headers.insert(hn, hv);
354                    }
355                }
356            }
357
358            // Replay body from the buffered bytes.
359            let req_body = crate::full_body(body_bytes.clone());
360            let upstream_req = Request::from_parts(req_parts, req_body);
361
362            debug!(
363                upstream = %backend.addr,
364                attempt = attempt + 1,
365                "forwarding request"
366            );
367
368            // --- Track active connections ---
369            let _conn_guard = self.pool.acquire_conn(backend_idx);
370
371            // --- Send the request ---
372            // Unix socket backends bypass the shared HTTPS connector.
373            let result = if is_unix_socket(&backend.addr) {
374                #[cfg(unix)]
375                {
376                    let path = unix_socket_path(&backend.addr);
377                    send_via_unix(path, upstream_req).await.map(|r| {
378                        r.map(|b| {
379                            let b: Body = b.map_err(|e| -> crate::BoxError { Box::new(e) }).boxed();
380                            b
381                        })
382                    })
383                }
384                #[cfg(not(unix))]
385                {
386                    let _ = upstream_req;
387                    Err(ProxyError::Internal(
388                        "Unix domain socket upstreams are not supported on this platform".into(),
389                    ))
390                }
391            } else {
392                self.pool
393                    .client
394                    .request(upstream_req)
395                    .await
396                    .map_err(ProxyError::Client)
397                    .map(|r| {
398                        r.map(|b| {
399                            let b: Body = b.map_err(|e| -> crate::BoxError { Box::new(e) }).boxed();
400                            b
401                        })
402                    })
403            };
404
405            match result {
406                Ok(resp) => {
407                    let (mut resp_parts, resp_body) = resp.into_parts();
408
409                    // Passive health: record 5xx
410                    if resp_parts.status.is_server_error()
411                        && let Some(ref ph) = self.passive_health
412                    {
413                        ph.record_failure(backend_idx, &self.pool).await;
414                    }
415
416                    // If server error and we have retries left, retry.
417                    if resp_parts.status.is_server_error() && attempt + 1 < max_attempts {
418                        warn!(
419                            upstream = %backend.addr,
420                            status = %resp_parts.status,
421                            attempt = attempt + 1,
422                            "upstream returned server error, retrying"
423                        );
424                        last_failed_idx = Some(backend_idx);
425                        last_error = Some(ProxyError::Internal(format!(
426                            "upstream {} returned {}",
427                            backend.addr, resp_parts.status
428                        )));
429                        continue;
430                    }
431
432                    // resp_body is already typed as Body (mapped above).
433
434                    // Apply header-down directives.
435                    for (name, value) in &self.headers_down {
436                        if let Some(hdr_name) = name.strip_prefix('-') {
437                            if let Ok(hn) = hdr_name.parse::<http::header::HeaderName>() {
438                                resp_parts.headers.remove(hn);
439                            }
440                        } else if let (Ok(hn), Ok(hv)) = (
441                            name.parse::<http::header::HeaderName>(),
442                            value.parse::<http::header::HeaderValue>(),
443                        ) {
444                            resp_parts.headers.insert(hn, hv);
445                        }
446                    }
447
448                    // Passive health recovery check.
449                    if let Some(ref ph) = self.passive_health {
450                        ph.maybe_recover(&self.pool).await;
451                    }
452
453                    // Error page interception: if the upstream status code has a
454                    // configured error page, replace the response body with it.
455                    let status_code = resp_parts.status.as_u16();
456                    if let Some(error_body) = self.error_pages.get(&status_code) {
457                        debug!(
458                            status = status_code,
459                            "intercepting upstream error with configured error page"
460                        );
461                        return Ok(Response::from_parts(
462                            resp_parts,
463                            crate::full_body(error_body.clone()),
464                        ));
465                    }
466
467                    return Ok(Response::from_parts(resp_parts, resp_body));
468                }
469                Err(e) => {
470                    // Passive health: record connection failure as well.
471                    if let Some(ref ph) = self.passive_health {
472                        ph.record_failure(backend_idx, &self.pool).await;
473                    }
474
475                    if attempt + 1 < max_attempts {
476                        warn!(
477                            upstream = %backend.addr,
478                            error = %e,
479                            attempt = attempt + 1,
480                            "upstream request failed, retrying"
481                        );
482                        last_failed_idx = Some(backend_idx);
483                        last_error = Some(e);
484                        continue;
485                    }
486
487                    return Err(e);
488                }
489            }
490        }
491
492        // Should not be reached, but just in case:
493        Err(last_error.unwrap_or(ProxyError::NoUpstream))
494    }
495}
496
497// ---------------------------------------------------------------------------
498// URI sanitization helpers
499// ---------------------------------------------------------------------------
500
/// Sanitize a URI path by:
/// 1. Collapsing consecutive slashes (e.g. `//foo///bar` → `/foo/bar`).
/// 2. Resolving `.` (current-directory) segments.
/// 3. Resolving `..` (parent-directory) segments without escaping the root.
///
/// Dot segments are also recognised when percent-encoded (`%2e`, `%2E`,
/// `.%2e`, `%2e%2e`, …). Backends that decode the path would otherwise still
/// see `..` and could traverse above the root.
///
/// A trailing slash is preserved (`/dir/` stays `/dir/`). A path ending in a
/// dot segment resolves to a directory (`/a/b/..` → `/a/`), following the
/// RFC 3986 §5.2.4 remove-dot-segments rules. The result always starts
/// with `/`.
fn sanitize_path(path: &str) -> String {
    let mut segments: Vec<&str> = Vec::new();
    // Whether the last processed raw segment denotes a directory (an empty
    // segment from a trailing slash, or a dot segment).
    let mut ends_in_dir = false;

    for segment in path.split('/') {
        if segment.is_empty() || is_dot_segment(segment, 1) {
            // Empty segments come from consecutive/trailing slashes; `.` is a no-op.
            ends_in_dir = true;
        } else if is_dot_segment(segment, 2) {
            // Go up one level, but never above root.
            segments.pop();
            ends_in_dir = true;
        } else {
            segments.push(segment);
            ends_in_dir = false;
        }
    }

    let mut result = String::with_capacity(path.len() + 1);
    result.push('/');
    result.push_str(&segments.join("/"));
    if ends_in_dir && !segments.is_empty() {
        result.push('/');
    }
    result
}

/// Returns true when `segment` consists of exactly `dots` dots, each written
/// either literally (`.`) or percent-encoded (`%2e` / `%2E`).
fn is_dot_segment(segment: &str, dots: usize) -> bool {
    let mut rest = segment.as_bytes();
    for _ in 0..dots {
        rest = match rest {
            [b'.', tail @ ..] => tail,
            [b'%', b'2', e, tail @ ..] if e.eq_ignore_ascii_case(&b'e') => tail,
            _ => return false,
        };
    }
    rest.is_empty()
}
527
528// ---------------------------------------------------------------------------
529// Unix socket upstream helpers
530// ---------------------------------------------------------------------------
531
/// Reports whether `addr` names a Unix domain socket backend. That is either
/// an explicit `"unix:"`-prefixed address or an absolute filesystem path
/// (leading `/`).
fn is_unix_socket(addr: &str) -> bool {
    const UNIX_SCHEME: &str = "unix:";
    addr.starts_with(UNIX_SCHEME) || addr.as_bytes().first() == Some(&b'/')
}
537
/// Returns the filesystem path of a Unix socket backend address. A single
/// leading `"unix:"` scheme prefix is removed when present; otherwise the
/// address is returned unchanged.
#[cfg(unix)]
fn unix_socket_path(addr: &str) -> &str {
    match addr.strip_prefix("unix:") {
        Some(path) => path,
        None => addr,
    }
}
543
544/// Send an HTTP/1.1 request over a Unix domain socket and return the response.
545///
546/// This bypasses the shared HTTPS connector pool and instead opens a fresh
547/// `UnixStream`, performs an HTTP/1 handshake, and sends the request.
548#[cfg(unix)]
549async fn send_via_unix(
550    socket_path: &str,
551    request: http::Request<Body>,
552) -> Result<http::Response<hyper::body::Incoming>, ProxyError> {
553    let stream = tokio::net::UnixStream::connect(socket_path)
554        .await
555        .map_err(|e| ProxyError::Internal(format!("unix socket connect to {socket_path}: {e}")))?;
556    let io = hyper_util::rt::TokioIo::new(stream);
557    let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
558        .await
559        .map_err(ProxyError::Hyper)?;
560    tokio::spawn(async move {
561        let _ = conn.await;
562    });
563    sender
564        .send_request(request)
565        .await
566        .map_err(ProxyError::Hyper)
567}