Skip to main content

zlayer_proxy/
service.rs

1//! Reverse proxy service implementation
2//!
3//! This module provides the core proxy service that handles request forwarding.
4//! It uses the `ServiceRegistry` for route resolution and backend selection.
5
6use crate::acme::CertManager;
7use crate::config::ProxyConfig;
8use crate::error::{ProxyError, Result};
9use crate::lb::LoadBalancer;
10use crate::network_policy::NetworkPolicyChecker;
11use crate::routes::{transform_path, ResolvedService, ServiceRegistry};
12use bytes::Bytes;
13use http::{header, Request, Response, Uri, Version};
14use http_body_util::{BodyExt, Full};
15use hyper::body::Incoming;
16use hyper::upgrade::OnUpgrade;
17use hyper_util::client::legacy::Client;
18use hyper_util::rt::{TokioExecutor, TokioIo};
19use std::net::{IpAddr, SocketAddr};
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use tokio::net::TcpStream;
23use tower::Service;
24use tracing::{debug, error, info, warn};
25use zlayer_spec::ExposeType;
26
/// The overlay network used for internal service communication, expressed as
/// the first two octets of its `/16` prefix.
/// Source IPs outside this range are rejected for internal-only routes.
const OVERLAY_NETWORK: (u8, u8) = (10, 200); // 10.200.0.0/16

/// Check whether an IP address belongs to the overlay network (10.200.0.0/16).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped first, because
/// that is how an IPv4 peer is reported on a dual-stack listening socket.
/// Every other IPv6 address is outside the overlay and yields `false`.
fn is_overlay_ip(ip: IpAddr) -> bool {
    let v4 = match ip {
        IpAddr::V4(v4) => v4,
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(mapped) => mapped,
            None => return false,
        },
    };

    // A /16 match only needs the two leading octets.
    let [first, second, ..] = v4.octets();
    (first, second) == OVERLAY_NETWORK
}
41
/// Body type for outgoing responses.
///
/// A type-erased body that yields [`Bytes`] chunks and fails with
/// [`hyper::Error`]. The same type is used for request bodies forwarded to
/// backends through the pooled client.
pub type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
44
45/// Empty body utility
46#[must_use]
47pub fn empty_body() -> BoxBody {
48    http_body_util::Empty::<Bytes>::new()
49        .map_err(|never| match never {})
50        .boxed()
51}
52
53/// Full body utility
54pub fn full_body(bytes: impl Into<Bytes>) -> BoxBody {
55    Full::new(bytes.into())
56        .map_err(|never| match never {})
57        .boxed()
58}
59
/// The reverse proxy service.
///
/// Holds the shared routing, load-balancing and configuration state plus the
/// per-connection context (`remote_addr`, `is_tls`). The shared state sits
/// behind `Arc`s; the `tower::Service` implementation clones the whole
/// service once per call.
#[derive(Clone)]
pub struct ReverseProxyService {
    /// Service registry for route resolution
    registry: Arc<ServiceRegistry>,
    /// Load balancer for backend selection
    load_balancer: Arc<LoadBalancer>,
    /// HTTP client for backend requests (pooled; not used for upgrades,
    /// which open a dedicated connection)
    client: Client<hyper_util::client::legacy::connect::HttpConnector, BoxBody>,
    /// Proxy configuration
    config: Arc<ProxyConfig>,
    /// Client remote address (set per-request); `None` until
    /// `with_remote_addr` is called
    remote_addr: Option<SocketAddr>,
    /// Whether the connection is over TLS
    is_tls: bool,
    /// Certificate manager for ACME challenge responses
    cert_manager: Option<Arc<CertManager>>,
    /// Optional network policy checker for access control enforcement
    network_policy_checker: Option<NetworkPolicyChecker>,
    /// Trusted upstream proxies. Requests whose TCP peer IP is in this list
    /// may set `CF-Connecting-IP` / `X-Forwarded-For` and be believed. When no
    /// explicit list is provided, defaults to `TrustedProxyList::localhost_only()`
    /// — a safe default for nodes that accidentally receive direct requests.
    trusted_proxies: Arc<crate::trust::TrustedProxyList>,
}
85
86impl ReverseProxyService {
87    /// Create a new reverse proxy service
88    pub fn new(
89        registry: Arc<ServiceRegistry>,
90        load_balancer: Arc<LoadBalancer>,
91        config: Arc<ProxyConfig>,
92    ) -> Self {
93        let client = Client::builder(TokioExecutor::new())
94            .pool_max_idle_per_host(config.pool.max_idle_per_backend)
95            .pool_idle_timeout(config.pool.idle_timeout)
96            .pool_timer(hyper_util::rt::TokioTimer::new())
97            .build_http();
98
99        Self {
100            registry,
101            load_balancer,
102            client,
103            config,
104            remote_addr: None,
105            is_tls: false,
106            cert_manager: None,
107            network_policy_checker: None,
108            trusted_proxies: Arc::new(crate::trust::TrustedProxyList::localhost_only()),
109        }
110    }
111
112    /// Set the remote client address for this request
113    #[must_use]
114    pub fn with_remote_addr(mut self, addr: SocketAddr) -> Self {
115        self.remote_addr = Some(addr);
116        self
117    }
118
119    /// Mark this connection as being over TLS
120    #[must_use]
121    pub fn with_tls(mut self, is_tls: bool) -> Self {
122        self.is_tls = is_tls;
123        self
124    }
125
126    /// Override the trusted-proxy list (default: `localhost_only`).
127    ///
128    /// Peers in this list are believed when they set `CF-Connecting-IP` or
129    /// `X-Forwarded-For` headers identifying the real client IP.
130    #[must_use]
131    pub fn with_trusted_proxies(mut self, trusted: Arc<crate::trust::TrustedProxyList>) -> Self {
132        self.trusted_proxies = trusted;
133        self
134    }
135
136    /// Set the certificate manager for ACME challenge interception
137    #[must_use]
138    pub fn with_cert_manager(mut self, cm: Arc<CertManager>) -> Self {
139        self.cert_manager = Some(cm);
140        self
141    }
142
143    /// Set the network policy checker for access control enforcement
144    #[must_use]
145    pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
146        self.network_policy_checker = Some(checker);
147        self
148    }
149
150    /// Check if this connection is over TLS
151    #[must_use]
152    pub fn is_tls(&self) -> bool {
153        self.is_tls
154    }
155
156    /// Handle an incoming HTTP request
157    ///
158    /// # Errors
159    ///
160    /// Returns an error if route resolution fails, no healthy backends are
161    /// available, or the backend request fails.
162    ///
163    /// # Panics
164    ///
165    /// Panics if building a well-formed HTTP response for an ACME challenge
166    /// or upgrade reply fails (indicates a bug, not a runtime condition).
167    #[allow(clippy::too_many_lines)]
168    pub async fn proxy_request(&self, mut req: Request<Incoming>) -> Result<Response<BoxBody>> {
169        let start = std::time::Instant::now();
170        let method = req.method().clone();
171        let uri = req.uri().clone();
172
173        let host = req
174            .headers()
175            .get(header::HOST)
176            .and_then(|h| h.to_str().ok())
177            .or_else(|| uri.host())
178            .map(std::string::ToString::to_string);
179
180        let path = uri.path().to_string();
181
182        // ACME HTTP-01 challenge interception
183        if path.starts_with("/.well-known/acme-challenge/") {
184            if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
185                if !token.is_empty() {
186                    if let Some(ref cm) = self.cert_manager {
187                        if let Some(auth) = cm.get_challenge_response(token) {
188                            return Ok(Response::builder()
189                                .status(200)
190                                .header("content-type", "text/plain")
191                                .body(full_body(auth))
192                                .unwrap());
193                        }
194                    }
195                }
196            }
197        }
198
199        // Check for WebSocket/HTTP upgrade
200        if crate::tunnel::is_upgrade_request(&req) {
201            // Resolve to get backend for upgrade
202            let resolved = self
203                .registry
204                .resolve(host.as_deref(), &path)
205                .await
206                .ok_or_else(|| ProxyError::RouteNotFound {
207                    host: host.as_deref().unwrap_or("<none>").to_string(),
208                    path: path.clone(),
209                })?;
210
211            // Enforce internal endpoints
212            if resolved.expose == ExposeType::Internal {
213                if let Some(addr) = self.remote_addr {
214                    if !is_overlay_ip(addr.ip()) {
215                        return Err(ProxyError::Forbidden(
216                            "endpoint is internal-only".to_string(),
217                        ));
218                    }
219                }
220            }
221
222            // Enforce network policy access rules
223            if let (Some(checker), Some(addr)) = (&self.network_policy_checker, self.remote_addr) {
224                if !checker
225                    .check_access(addr.ip(), &resolved.name, "*", resolved.target_port)
226                    .await
227                {
228                    return Err(ProxyError::Forbidden(format!(
229                        "network policy denied access to service '{}'",
230                        resolved.name
231                    )));
232                }
233            }
234
235            let backend = self.load_balancer.select(&resolved.name).ok_or_else(|| {
236                ProxyError::NoHealthyBackends {
237                    service: resolved.name.clone(),
238                }
239            })?;
240            let _guard = backend.track_connection();
241            let backend_addr = backend.addr;
242
243            info!(
244                method = %method,
245                host = ?host,
246                path = %path,
247                backend = %backend_addr,
248                service = %resolved.name,
249                "Forwarding upgrade request"
250            );
251
252            // Extract the client's OnUpgrade future BEFORE consuming the request
253            let client_upgrade: OnUpgrade = hyper::upgrade::on(&mut req);
254
255            // Build the backend URI
256            let original_path = req.uri().path();
257            let transformed_path =
258                transform_path(&resolved.path_prefix, original_path, resolved.strip_prefix);
259            let new_uri = format!(
260                "http://{}{}{}",
261                backend_addr,
262                transformed_path,
263                req.uri()
264                    .query()
265                    .map(|q| format!("?{q}"))
266                    .unwrap_or_default()
267            );
268
269            // Build backend request, preserving upgrade headers
270            let (orig_parts, _body) = req.into_parts();
271            let mut backend_parts = http::request::Builder::new()
272                .method(orig_parts.method.clone())
273                .uri(
274                    new_uri
275                        .parse::<Uri>()
276                        .map_err(|e| ProxyError::InvalidRequest(format!("Invalid URI: {e}")))?,
277                )
278                .body(())
279                .unwrap()
280                .into_parts()
281                .0;
282
283            // Copy all original headers first (preserving Host, etc.)
284            for (name, value) in &orig_parts.headers {
285                backend_parts.headers.insert(name.clone(), value.clone());
286            }
287
288            // Copy upgrade-specific headers (Connection, Upgrade, Sec-WebSocket-*)
289            crate::tunnel::copy_upgrade_headers(&orig_parts, &mut backend_parts);
290
291            // Add forwarding headers
292            self.add_forwarding_headers(&mut backend_parts);
293
294            // Connect directly to backend (bypass connection pool for long-lived upgrades)
295            let tcp_stream = TcpStream::connect(backend_addr).await.map_err(|e| {
296                error!(error = %e, backend = %backend_addr, "Backend upgrade connect failed");
297                ProxyError::BackendConnectionFailed {
298                    backend: backend_addr,
299                    reason: e.to_string(),
300                }
301            })?;
302            let io = TokioIo::new(tcp_stream);
303
304            // Perform HTTP/1.1 handshake preserving header case
305            let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
306                .preserve_header_case(true)
307                .handshake(io)
308                .await
309                .map_err(|e| {
310                    error!(error = %e, backend = %backend_addr, "Backend upgrade handshake failed");
311                    ProxyError::BackendRequestFailed(format!("Upgrade handshake failed: {e}"))
312                })?;
313
314            // Spawn the connection driver
315            tokio::spawn(async move {
316                if let Err(e) = conn.with_upgrades().await {
317                    error!(error = %e, "Backend upgrade connection driver error");
318                }
319            });
320
321            // Send the request to the backend
322            let backend_req =
323                Request::from_parts(backend_parts, http_body_util::Empty::<Bytes>::new());
324            let backend_response = sender.send_request(backend_req).await.map_err(|e| {
325                error!(error = %e, backend = %backend_addr, "Backend upgrade request failed");
326                ProxyError::BackendRequestFailed(e.to_string())
327            })?;
328
329            if backend_response.status() == http::StatusCode::SWITCHING_PROTOCOLS {
330                // Get the server's OnUpgrade future
331                let server_upgrade: OnUpgrade = hyper::upgrade::on(backend_response);
332
333                // Build 101 response to send back to the client
334                let mut resp_builder =
335                    Response::builder().status(http::StatusCode::SWITCHING_PROTOCOLS);
336                // Note: we need to construct the response manually since we consumed
337                // the backend response to get OnUpgrade. Copy relevant headers.
338                // The hyper::upgrade::on() for the response does NOT consume it —
339                // it was consumed. We need to return a 101 with appropriate headers.
340                // Actually, hyper::upgrade::on() takes the response by value, so we
341                // must build our own 101 response for the client.
342
343                // For the client response, set Connection: upgrade and Upgrade headers
344                if let Some(upgrade_val) = orig_parts.headers.get(header::UPGRADE) {
345                    resp_builder = resp_builder.header(header::UPGRADE, upgrade_val.clone());
346                }
347                resp_builder = resp_builder.header(header::CONNECTION, "upgrade");
348
349                let client_response = resp_builder.body(empty_body()).map_err(|e| {
350                    ProxyError::Internal(format!("Failed to build 101 response: {e}"))
351                })?;
352
353                // Spawn background task to bridge the upgraded connections
354                tokio::spawn(async move {
355                    if let Err(e) =
356                        crate::tunnel::proxy_upgrade(client_upgrade, server_upgrade).await
357                    {
358                        debug!(error = %e, "Upgrade tunnel ended");
359                    }
360                });
361
362                // Add timing header to the 101 response
363                let (mut parts, body) = client_response.into_parts();
364                if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
365                    parts.headers.insert("server-timing", hv);
366                }
367
368                return Ok(Response::from_parts(parts, body));
369            }
370
371            // Backend didn't upgrade — stream the response as-is
372            let (mut parts, body) = backend_response.into_parts();
373            let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
374
375            // Add HSTS header for TLS connections
376            if self.is_tls && self.config.headers.hsts {
377                let value = if self.config.headers.hsts_subdomains {
378                    format!(
379                        "max-age={}; includeSubDomains",
380                        self.config.headers.hsts_max_age
381                    )
382                } else {
383                    format!("max-age={}", self.config.headers.hsts_max_age)
384                };
385                if let Ok(hv) = value.parse() {
386                    parts.headers.insert("strict-transport-security", hv);
387                }
388            }
389
390            // Add Server-Timing header
391            if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
392                parts.headers.insert("server-timing", hv);
393            }
394
395            return Ok(Response::from_parts(parts, streaming_body));
396        }
397
398        debug!(method = %method, host = ?host, path = %path, "Routing request");
399
400        // Resolve route
401        let resolved = self
402            .registry
403            .resolve(host.as_deref(), &path)
404            .await
405            .ok_or_else(|| ProxyError::RouteNotFound {
406                host: host.as_deref().unwrap_or("<none>").to_string(),
407                path: path.clone(),
408            })?;
409
410        // Enforce internal endpoints
411        if resolved.expose == ExposeType::Internal {
412            match self.remote_addr {
413                Some(addr) if !is_overlay_ip(addr.ip()) => {
414                    warn!(
415                        source = %addr.ip(),
416                        service = %resolved.name,
417                        "Rejected non-overlay source for internal endpoint"
418                    );
419                    return Err(ProxyError::Forbidden(
420                        "endpoint is internal-only".to_string(),
421                    ));
422                }
423                None => {
424                    debug!(
425                        service = %resolved.name,
426                        "No remote_addr available; skipping overlay source check"
427                    );
428                }
429                _ => {}
430            }
431        }
432
433        // Enforce network policy access rules
434        if let (Some(checker), Some(addr)) = (&self.network_policy_checker, self.remote_addr) {
435            if !checker
436                .check_access(addr.ip(), &resolved.name, "*", resolved.target_port)
437                .await
438            {
439                return Err(ProxyError::Forbidden(format!(
440                    "network policy denied access to service '{}'",
441                    resolved.name
442                )));
443            }
444        }
445
446        // Select backend via load balancer
447        let backend = self.load_balancer.select(&resolved.name).ok_or_else(|| {
448            ProxyError::NoHealthyBackends {
449                service: resolved.name.clone(),
450            }
451        })?;
452        let _guard = backend.track_connection();
453        let backend_addr = backend.addr;
454
455        info!(
456            method = %method,
457            host = ?host,
458            path = %path,
459            backend = %backend_addr,
460            service = %resolved.name,
461            "Forwarding request"
462        );
463
464        // Build forwarded request
465        let forwarded_req = self.build_forwarded_request(req, &backend_addr, &resolved)?;
466
467        // Forward to backend
468        let response = self.client.request(forwarded_req).await.map_err(|e| {
469            error!(error = %e, backend = %backend_addr, "Backend request failed");
470            ProxyError::BackendRequestFailed(e.to_string())
471        })?;
472
473        let (mut parts, body) = response.into_parts();
474        let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
475
476        // Add HSTS header for TLS connections
477        if self.is_tls && self.config.headers.hsts {
478            let value = if self.config.headers.hsts_subdomains {
479                format!(
480                    "max-age={}; includeSubDomains",
481                    self.config.headers.hsts_max_age
482                )
483            } else {
484                format!("max-age={}", self.config.headers.hsts_max_age)
485            };
486            if let Ok(hv) = value.parse() {
487                parts.headers.insert("strict-transport-security", hv);
488            }
489        }
490
491        // Add Server-Timing header
492        if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
493            parts.headers.insert("server-timing", hv);
494        }
495
496        Ok(Response::from_parts(parts, streaming_body))
497    }
498
499    fn build_forwarded_request(
500        &self,
501        req: Request<Incoming>,
502        backend: &SocketAddr,
503        resolved: &ResolvedService,
504    ) -> Result<Request<BoxBody>> {
505        let (mut parts, body) = req.into_parts();
506
507        // Transform the path if needed
508        let original_path = parts.uri.path();
509        let transformed_path =
510            transform_path(&resolved.path_prefix, original_path, resolved.strip_prefix);
511
512        // Build new URI for backend
513        let new_uri = format!(
514            "http://{}{}{}",
515            backend,
516            transformed_path,
517            parts
518                .uri
519                .query()
520                .map(|q| format!("?{q}"))
521                .unwrap_or_default()
522        );
523
524        parts.uri = new_uri
525            .parse::<Uri>()
526            .map_err(|e| ProxyError::InvalidRequest(format!("Invalid URI: {e}")))?;
527
528        // Add forwarding headers
529        self.add_forwarding_headers(&mut parts);
530
531        // Remove hop-by-hop headers
532        Self::remove_hop_by_hop_headers(&mut parts);
533
534        let streaming_body: BoxBody = body.map_err(|e: hyper::Error| e).boxed();
535
536        let req = Request::from_parts(parts, streaming_body);
537        Ok(req)
538    }
539
540    fn add_forwarding_headers(&self, parts: &mut http::request::Parts) {
541        let config = &self.config.headers;
542
543        // Determine whether the immediate TCP peer is a trusted upstream proxy
544        // that may dictate the real client IP via CF-Connecting-IP or XFF.
545        let peer_is_trusted = self
546            .remote_addr
547            .is_some_and(|addr| self.trusted_proxies.is_trusted(addr.ip()));
548
549        // Compute the effective client IP:
550        //   - Trusted peer + CF-Connecting-IP (parseable) -> use CF header
551        //   - Trusted peer + leftmost X-Forwarded-For (parseable) -> use XFF
552        //   - Otherwise -> fall back to the TCP peer IP
553        let effective_client_ip: Option<IpAddr> = if peer_is_trusted {
554            let cf_ip = parts
555                .headers
556                .get("cf-connecting-ip")
557                .and_then(|h| h.to_str().ok())
558                .and_then(|s| s.trim().parse::<IpAddr>().ok());
559
560            let xff_leftmost = parts
561                .headers
562                .get("x-forwarded-for")
563                .and_then(|h| h.to_str().ok())
564                .and_then(|s| s.split(',').next())
565                .and_then(|s| s.trim().parse::<IpAddr>().ok());
566
567            cf_ip
568                .or(xff_leftmost)
569                .or_else(|| self.remote_addr.map(|a| a.ip()))
570        } else {
571            self.remote_addr.map(|a| a.ip())
572        };
573
574        // X-Forwarded-For
575        if config.x_forwarded_for {
576            if let Some(addr) = self.remote_addr {
577                let existing_xff = parts
578                    .headers
579                    .get("x-forwarded-for")
580                    .and_then(|h| h.to_str().ok())
581                    .map(std::string::ToString::to_string);
582
583                let new_value = if peer_is_trusted {
584                    // Trusted proxy: prepend the real client IP (from CF /
585                    // leftmost XFF / peer) to any existing chain so downstream
586                    // sees [real_client, ...upstream_chain].
587                    let real = effective_client_ip.unwrap_or_else(|| addr.ip()).to_string();
588                    match existing_xff {
589                        Some(chain) if !chain.trim().is_empty() => format!("{real}, {chain}"),
590                        _ => real,
591                    }
592                } else {
593                    // Untrusted peer: preserve existing behavior — append the
594                    // peer IP to any existing chain.
595                    match existing_xff {
596                        Some(chain) => format!("{}, {}", chain, addr.ip()),
597                        None => addr.ip().to_string(),
598                    }
599                };
600
601                if let Ok(value) = new_value.parse() {
602                    parts.headers.insert("x-forwarded-for", value);
603                }
604            }
605        }
606
607        // X-Forwarded-Proto
608        if config.x_forwarded_proto && parts.headers.get("x-forwarded-proto").is_none() {
609            let proto = if self.is_tls { "https" } else { "http" };
610            if let Ok(value) = proto.parse() {
611                parts.headers.insert("x-forwarded-proto", value);
612            }
613        }
614
615        // X-Forwarded-Host
616        if config.x_forwarded_host {
617            if let Some(host) = parts.headers.get(header::HOST).cloned() {
618                if parts.headers.get("x-forwarded-host").is_none() {
619                    parts.headers.insert("x-forwarded-host", host);
620                }
621            }
622        }
623
624        // X-Real-IP — set to the effective client IP only if the header is
625        // currently absent (conservative: do not overwrite a value set by an
626        // upstream component).
627        if config.x_real_ip {
628            if let Some(ip) = effective_client_ip {
629                if parts.headers.get("x-real-ip").is_none() {
630                    if let Ok(value) = ip.to_string().parse() {
631                        parts.headers.insert("x-real-ip", value);
632                    }
633                }
634            }
635        }
636
637        // Via header
638        if config.via {
639            let proto_version = match parts.version {
640                Version::HTTP_09 => "0.9",
641                Version::HTTP_10 => "1.0",
642                Version::HTTP_2 => "2.0",
643                Version::HTTP_3 => "3.0",
644                _ => "1.1",
645            };
646
647            let via_value = format!("{} {}", proto_version, config.server_name);
648            let existing = parts
649                .headers
650                .get(header::VIA)
651                .and_then(|h| h.to_str().ok())
652                .map(|s| format!("{s}, {via_value}"))
653                .unwrap_or(via_value);
654
655            if let Ok(value) = existing.parse() {
656                parts.headers.insert(header::VIA, value);
657            }
658        }
659    }
660
661    fn remove_hop_by_hop_headers(parts: &mut http::request::Parts) {
662        // Standard hop-by-hop headers that should not be forwarded
663        const HOP_BY_HOP: &[&str] = &[
664            "connection",
665            "keep-alive",
666            "proxy-authenticate",
667            "proxy-authorization",
668            "te",
669            "trailer",
670            "transfer-encoding",
671            "upgrade",
672        ];
673
674        // First, collect headers listed in the Connection header before we remove it
675        let connection_headers: Vec<String> = parts
676            .headers
677            .get(header::CONNECTION)
678            .and_then(|h| h.to_str().ok())
679            .map(|value| value.split(',').map(|s| s.trim().to_lowercase()).collect())
680            .unwrap_or_default();
681
682        for header_name in HOP_BY_HOP {
683            parts.headers.remove(*header_name);
684        }
685
686        // Also remove headers that were listed in the Connection header
687        for header_name in connection_headers {
688            parts.headers.remove(header_name.as_str());
689        }
690    }
691
692    /// Build a client-facing error response with a **generic** body.
693    ///
694    /// This is the default-deny safety boundary for the ingress proxy. The
695    /// proxy binds `0.0.0.0:80`/`:443`, so it MUST NOT leak internal details
696    /// (the requested Host/path, a backend address, or the internal
697    /// load-balancer group name) to an unauthenticated caller. The full,
698    /// detailed [`ProxyError`] is logged by the caller (`error!(error = %e)`
699    /// in `server.rs`); the body returned here carries only a minimal,
700    /// status-appropriate phrase so an unmatched / no-target request gets a
701    /// clean deny rather than an internal echo.
702    ///
703    /// # Panics
704    ///
705    /// Panics if building a valid HTTP response with a plain-text body fails,
706    /// which should never occur with well-formed status codes.
707    pub fn error_response(error: &ProxyError) -> Response<BoxBody> {
708        let status = error.status_code();
709        // Generic, non-leaking body keyed purely off the status code. We
710        // deliberately do NOT interpolate `error` (which can contain the Host,
711        // path, backend address, or LB group name) into the client-visible
712        // body.
713        let body = status.canonical_reason().map_or_else(
714            || status.as_str().to_string(),
715            |reason| format!("{} {reason}", status.as_u16()),
716        );
717
718        Response::builder()
719            .status(status)
720            .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
721            .body(full_body(body))
722            .unwrap()
723    }
724}
725
726impl Service<Request<Incoming>> for ReverseProxyService {
727    type Response = Response<BoxBody>;
728    type Error = ProxyError;
729    type Future = std::pin::Pin<
730        Box<
731            dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
732                + Send,
733        >,
734    >;
735
736    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
737        Poll::Ready(Ok(()))
738    }
739
740    fn call(&mut self, req: Request<Incoming>) -> Self::Future {
741        let this = self.clone();
742        Box::pin(async move { this.proxy_request(req).await })
743    }
744}
745
#[cfg(test)]
mod tests {
    //! Unit tests for error mapping, hop-by-hop header stripping, overlay IP
    //! classification, and client-IP forwarding headers.

    use super::*;
    use crate::trust::TrustedProxyList;

    // --- Shared fixtures ---------------------------------------------------

    /// CIDR used as the trusted proxy range in the forwarding-header tests.
    const TRUSTED_CIDR: &str = "203.0.113.0/24";

    /// Builds a proxy service with default registry, load balancer and config,
    /// bound to the given TCP peer and trusted-proxy list.
    fn build_svc(peer: SocketAddr, trusted: TrustedProxyList) -> ReverseProxyService {
        let registry = Arc::new(ServiceRegistry::new());
        let load_balancer = Arc::new(LoadBalancer::new());
        let config = Arc::new(ProxyConfig::default());
        ReverseProxyService::new(registry, load_balancer, config)
            .with_remote_addr(peer)
            .with_trusted_proxies(Arc::new(trusted))
    }

    /// Trusted-proxy list containing only [`TRUSTED_CIDR`].
    fn trusted_edge() -> TrustedProxyList {
        TrustedProxyList::new(vec![TRUSTED_CIDR.parse().unwrap()], None)
    }

    /// Builds the head of a `GET /` request carrying the given headers.
    fn parts_with_headers(headers: &[(&str, &str)]) -> http::request::Parts {
        let builder = headers.iter().fold(
            http::request::Builder::new().method("GET").uri("/"),
            |builder, &(name, value)| builder.header(name, value),
        );
        let (parts, ()) = builder.body(()).unwrap().into_parts();
        parts
    }

    /// Splits the `X-Forwarded-For` header into its trimmed, comma-separated
    /// entries.
    ///
    /// Comparing whole entries avoids the false positives of substring checks
    /// (`"198.51.100.70"` starts with `"198.51.100.7"`, `"18.8.8.8"` ends with
    /// `"8.8.8.8"`).
    ///
    /// Panics if the header is missing or is not valid visible ASCII.
    fn xff_entries(parts: &http::request::Parts) -> Vec<&str> {
        parts
            .headers
            .get("x-forwarded-for")
            .expect("x-forwarded-for header should be present")
            .to_str()
            .expect("x-forwarded-for header should be visible ASCII")
            .split(',')
            .map(str::trim)
            .collect()
    }

    // --- Error responses ---------------------------------------------------

    #[test]
    fn test_error_response() {
        let error = ProxyError::RouteNotFound {
            host: "example.com".to_string(),
            path: "/api".to_string(),
        };

        let response = ReverseProxyService::error_response(&error);
        assert_eq!(response.status(), http::StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_forbidden_error_response() {
        let error = ProxyError::Forbidden("endpoint 'ws' is internal-only".to_string());
        let response = ReverseProxyService::error_response(&error);
        assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
    }

    // --- Hop-by-hop header stripping ---------------------------------------

    #[test]
    fn test_hop_by_hop_headers() {
        let (mut parts, ()) = http::request::Builder::new()
            .method("GET")
            .uri("/test")
            .header("connection", "keep-alive, x-custom")
            .header("keep-alive", "timeout=5")
            .header("x-custom", "value")
            .header("x-other", "value")
            .body(())
            .unwrap()
            .into_parts();

        ReverseProxyService::remove_hop_by_hop_headers(&mut parts);

        // `x-custom` is expected to go because `Connection` names it.
        for stripped in ["connection", "keep-alive", "x-custom"] {
            assert!(
                !parts.headers.contains_key(stripped),
                "{stripped} should have been removed"
            );
        }
        // x-other should remain
        assert!(parts.headers.contains_key("x-other"));
    }

    // --- Overlay network classification ------------------------------------

    #[test]
    fn test_is_overlay_ip_accepts_overlay_range() {
        // 10.200.x.x should be recognized as overlay
        for addr in ["10.200.0.1", "10.200.255.254", "10.200.1.100"] {
            let ip: IpAddr = addr.parse().unwrap();
            assert!(is_overlay_ip(ip), "{addr} should be an overlay address");
        }
    }

    #[test]
    fn test_is_overlay_ip_rejects_non_overlay() {
        // Non-overlay addresses
        for addr in [
            "192.168.1.1",
            "10.0.0.1",
            "10.201.0.1",
            "172.16.0.1",
            "8.8.8.8",
        ] {
            let ip: IpAddr = addr.parse().unwrap();
            assert!(!is_overlay_ip(ip), "{addr} should not be an overlay address");
        }
    }

    #[test]
    fn test_is_overlay_ip_rejects_ipv6() {
        for addr in ["::1", "fe80::1"] {
            let ip: IpAddr = addr.parse().unwrap();
            assert!(!is_overlay_ip(ip), "{addr} should not be an overlay address");
        }
    }

    // --- Tests for CF-Connecting-IP / X-Forwarded-For trust handling ------

    #[test]
    fn trusted_peer_cf_connecting_ip_is_honored() {
        // Peer 203.0.113.50 is inside the trusted /24. Its CF-Connecting-IP
        // should become X-Real-IP and be prepended to X-Forwarded-For.
        let peer: SocketAddr = "203.0.113.50:443".parse().unwrap();
        let svc = build_svc(peer, trusted_edge());

        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(parts.headers.get("x-real-ip").unwrap(), "198.51.100.7");
        let chain = xff_entries(&parts);
        assert_eq!(
            chain.first().copied(),
            Some("198.51.100.7"),
            "XFF should start with real client IP, got {chain:?}"
        );
    }

    #[test]
    fn trusted_peer_xff_leftmost_is_honored_when_no_cf_header() {
        // Peer is trusted; no CF header but XFF chain is present. The leftmost
        // XFF entry is treated as the real client IP.
        let peer: SocketAddr = "203.0.113.50:443".parse().unwrap();
        let svc = build_svc(peer, trusted_edge());

        let mut parts = parts_with_headers(&[("x-forwarded-for", "198.51.100.9, 10.0.0.1")]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(parts.headers.get("x-real-ip").unwrap(), "198.51.100.9");
        let chain = xff_entries(&parts);
        // Real client prepended, original chain preserved after.
        assert_eq!(
            chain.first().copied(),
            Some("198.51.100.9"),
            "XFF should start with leftmost real client, got {chain:?}"
        );
        assert!(
            chain.contains(&"10.0.0.1"),
            "original chain should survive: {chain:?}"
        );
    }

    #[test]
    fn untrusted_peer_cf_connecting_ip_is_ignored() {
        // Peer 8.8.8.8 is NOT in the trusted list. The CF header must be
        // ignored and X-Real-IP must reflect the TCP peer.
        let peer: SocketAddr = "8.8.8.8:443".parse().unwrap();
        let svc = build_svc(peer, trusted_edge());

        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(parts.headers.get("x-real-ip").unwrap(), "8.8.8.8");
        let chain = xff_entries(&parts);
        // Untrusted peer: XFF should end with the peer IP (append behavior).
        assert_eq!(
            chain.last().copied(),
            Some("8.8.8.8"),
            "XFF for untrusted peer should end with peer IP, got {chain:?}"
        );
    }

    #[test]
    fn no_headers_uses_peer_ip() {
        // No CF, no XFF. Any peer (trusted or not) should yield X-Real-IP ==
        // peer IP.
        let peer: SocketAddr = "198.51.100.250:443".parse().unwrap();
        let svc = build_svc(peer, TrustedProxyList::localhost_only());

        let mut parts = parts_with_headers(&[]);
        svc.add_forwarding_headers(&mut parts);

        assert_eq!(parts.headers.get("x-real-ip").unwrap(), "198.51.100.250");
        assert_eq!(
            parts.headers.get("x-forwarded-for").unwrap(),
            "198.51.100.250"
        );
    }
}