use crate::acme::CertManager;
use crate::config::ProxyConfig;
use crate::error::{ProxyError, Result};
use crate::lb::LoadBalancer;
use crate::network_policy::NetworkPolicyChecker;
use crate::routes::{transform_path, ResolvedService, ServiceRegistry};
use bytes::Bytes;
use http::{header, Request, Response, Uri, Version};
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::upgrade::OnUpgrade;
use hyper_util::client::legacy::Client;
use hyper_util::rt::{TokioExecutor, TokioIo};
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::net::TcpStream;
use tower::Service;
use tracing::{debug, error, info, warn};
use zlayer_spec::ExposeType;
/// First two octets of the IPv4 overlay network (`10.200.0.0/16`).
const OVERLAY_NETWORK: (u8, u8) = (10, 200);

/// Returns `true` when `ip` lies inside the overlay network (`10.200.0.0/16`).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped first. A dual-stack
/// (IPv6) listener reports IPv4 peers in that form, and treating them as foreign
/// would reject genuine overlay clients from internal-only endpoints. Every other
/// IPv6 address (including the deprecated IPv4-compatible `::a.b.c.d` form) is
/// considered outside the overlay.
fn is_overlay_ip(ip: IpAddr) -> bool {
    let v4 = match ip {
        IpAddr::V4(v4) => v4,
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(mapped) => mapped,
            None => return false,
        },
    };
    let [first, second, ..] = v4.octets();
    (first, second) == OVERLAY_NETWORK
}
/// Type-erased body used for both forwarded requests and proxied responses:
/// yields `Bytes` frames and reports failures as `hyper::Error`.
pub type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
#[must_use]
pub fn empty_body() -> BoxBody {
http_body_util::Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
}
/// Wraps `bytes` in a single-frame [`BoxBody`].
///
/// Accepts anything convertible into [`Bytes`] (for example `String`, `Vec<u8>`
/// or `&'static str`). `Full` cannot fail, so its uninhabited error type is mapped
/// into `hyper::Error` purely to match [`BoxBody`].
#[must_use]
pub fn full_body(bytes: impl Into<Bytes>) -> BoxBody {
    Full::new(bytes.into())
        .map_err(|never| match never {})
        .boxed()
}
/// Per-connection reverse-proxy handler: routes a request through the
/// [`ServiceRegistry`], applies access checks, picks a backend via the
/// [`LoadBalancer`] and forwards (or tunnels) the request to it.
#[derive(Clone)]
pub struct ReverseProxyService {
/// Route table used to resolve `(host, path)` to a service.
registry: Arc<ServiceRegistry>,
/// Chooses a backend address for a resolved service name.
load_balancer: Arc<LoadBalancer>,
/// Pooled HTTP client (built with `build_http` in `new`) used for non-upgrade requests.
client: Client<hyper_util::client::legacy::connect::HttpConnector, BoxBody>,
/// Proxy configuration: pool sizing and forwarding/HSTS header settings.
config: Arc<ProxyConfig>,
/// Socket address of the downstream peer, when known. Used for the internal-only
/// check, network policy and forwarding headers.
remote_addr: Option<SocketAddr>,
/// Whether the downstream connection is TLS. Drives `X-Forwarded-Proto` and HSTS.
is_tls: bool,
/// When set, ACME challenge tokens under `/.well-known/acme-challenge/` are answered from it.
cert_manager: Option<Arc<CertManager>>,
/// Optional network-policy gate consulted before forwarding.
network_policy_checker: Option<NetworkPolicyChecker>,
/// Peers whose forwarding headers (`cf-connecting-ip`, `x-forwarded-for`, ...) are honored.
trusted_proxies: Arc<crate::trust::TrustedProxyList>,
}
impl ReverseProxyService {
    /// Creates a proxy service for the given route registry, load balancer and config.
    ///
    /// The pooled HTTP client is sized from `config.pool`. The service starts with
    /// no peer address, non-TLS mode, no ACME cert manager and no network-policy
    /// checker. Its trusted-proxy list is `TrustedProxyList::localhost_only()`. Use
    /// the `with_*` builders to change any of these.
    #[must_use]
    pub fn new(
        registry: Arc<ServiceRegistry>,
        load_balancer: Arc<LoadBalancer>,
        config: Arc<ProxyConfig>,
    ) -> Self {
        let client = Client::builder(TokioExecutor::new())
            .pool_max_idle_per_host(config.pool.max_idle_per_backend)
            .pool_idle_timeout(config.pool.idle_timeout)
            .pool_timer(hyper_util::rt::TokioTimer::new())
            .build_http();
        Self {
            registry,
            load_balancer,
            client,
            config,
            remote_addr: None,
            is_tls: false,
            cert_manager: None,
            network_policy_checker: None,
            trusted_proxies: Arc::new(crate::trust::TrustedProxyList::localhost_only()),
        }
    }

    /// Records the downstream peer address.
    ///
    /// The address is used for the internal-only check, the network policy and the
    /// forwarding headers.
    #[must_use]
    pub fn with_remote_addr(mut self, addr: SocketAddr) -> Self {
        self.remote_addr = Some(addr);
        self
    }

    /// Marks whether the downstream connection is TLS.
    ///
    /// This controls `X-Forwarded-Proto` and whether HSTS may be emitted.
    #[must_use]
    pub fn with_tls(mut self, is_tls: bool) -> Self {
        self.is_tls = is_tls;
        self
    }

    /// Replaces the list of peers whose forwarding headers are honored.
    #[must_use]
    pub fn with_trusted_proxies(mut self, trusted: Arc<crate::trust::TrustedProxyList>) -> Self {
        self.trusted_proxies = trusted;
        self
    }

    /// Answers ACME challenge requests under `/.well-known/acme-challenge/` from `cm`.
    #[must_use]
    pub fn with_cert_manager(mut self, cm: Arc<CertManager>) -> Self {
        self.cert_manager = Some(cm);
        self
    }

    /// Runs every routed request through `checker` before forwarding it.
    #[must_use]
    pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
        self.network_policy_checker = Some(checker);
        self
    }

    /// Returns whether the downstream connection was marked as TLS.
    #[must_use]
    pub fn is_tls(&self) -> bool {
        self.is_tls
    }

    /// Proxies one downstream request to a backend of the service it routes to.
    ///
    /// The request is handled in these steps:
    ///
    /// 1. ACME challenges are answered directly when the cert manager has a response
    ///    for the token.
    /// 2. The route is resolved by host and path.
    /// 3. The internal-only exposure rule and the network policy are enforced.
    /// 4. A backend is picked via the load balancer.
    /// 5. A protocol upgrade (such as WebSocket) is tunnelled; any other request is
    ///    forwarded over the pooled client.
    ///
    /// # Errors
    ///
    /// - [`ProxyError::RouteNotFound`] when no route matches.
    /// - [`ProxyError::Forbidden`] when the exposure rule or network policy rejects the peer.
    /// - [`ProxyError::NoHealthyBackends`] when the load balancer has no backend.
    /// - [`ProxyError::InvalidRequest`] when the rewritten backend URI does not parse.
    /// - [`ProxyError::BackendConnectionFailed`] or [`ProxyError::BackendRequestFailed`]
    ///   when talking to the backend fails.
    pub async fn proxy_request(&self, req: Request<Incoming>) -> Result<Response<BoxBody>> {
        let start = std::time::Instant::now();
        let method = req.method().clone();
        let host = req
            .headers()
            .get(header::HOST)
            .and_then(|h| h.to_str().ok())
            .or_else(|| req.uri().host())
            .map(str::to_owned);
        let path = req.uri().path().to_string();

        if let Some(response) = self.acme_challenge_response(&path) {
            return Ok(response);
        }

        debug!(method = %method, host = ?host, path = %path, "Routing request");
        let resolved = self
            .registry
            .resolve(host.as_deref(), &path)
            .await
            .ok_or_else(|| ProxyError::RouteNotFound {
                host: host.as_deref().unwrap_or("<none>").to_string(),
                path: path.clone(),
            })?;

        self.check_internal_exposure(&resolved)?;
        if let (Some(checker), Some(addr)) = (&self.network_policy_checker, self.remote_addr) {
            if !checker
                .check_access(addr.ip(), &resolved.name, "*", resolved.target_port)
                .await
            {
                return Err(ProxyError::Forbidden(format!(
                    "network policy denied access to service '{}'",
                    resolved.name
                )));
            }
        }

        let backend = self.load_balancer.select(&resolved.name).ok_or_else(|| {
            ProxyError::NoHealthyBackends {
                service: resolved.name.clone(),
            }
        })?;
        // NOTE(review): the guard is dropped when this function returns, i.e. once
        // response headers are available. Streamed bodies and upgrade tunnels keep
        // running after that without being counted. Moving it into the spawned tunnel
        // task would need the guard type to be `Send + 'static`; confirm before changing.
        let _guard = backend.track_connection();
        let backend_addr = backend.addr;

        if crate::tunnel::is_upgrade_request(&req) {
            info!(
                method = %method,
                host = ?host,
                path = %path,
                backend = %backend_addr,
                service = %resolved.name,
                "Forwarding upgrade request"
            );
            let backend_uri = Self::backend_uri(req.uri(), backend_addr, &resolved)?;
            return self
                .forward_upgrade(req, backend_uri, backend_addr, start)
                .await;
        }

        info!(
            method = %method,
            host = ?host,
            path = %path,
            backend = %backend_addr,
            service = %resolved.name,
            "Forwarding request"
        );
        let forwarded_req = self.build_forwarded_request(req, &backend_addr, &resolved)?;
        let response = self.client.request(forwarded_req).await.map_err(|e| {
            error!(error = %e, backend = %backend_addr, "Backend request failed");
            ProxyError::BackendRequestFailed(e.to_string())
        })?;
        let (mut parts, body) = response.into_parts();
        self.apply_response_headers(&mut parts.headers, start);
        Ok(Response::from_parts(parts, body.boxed()))
    }

    /// Returns the ACME key authorization for `/.well-known/acme-challenge/<token>`.
    ///
    /// Returns `None` for a non-challenge path, an empty token, no cert manager, or
    /// an unknown token. The caller then routes the request normally.
    fn acme_challenge_response(&self, path: &str) -> Option<Response<BoxBody>> {
        let token = path
            .strip_prefix("/.well-known/acme-challenge/")
            .filter(|token| !token.is_empty())?;
        let key_authorization = self
            .cert_manager
            .as_ref()?
            .get_challenge_response(token)?;
        Some(
            Response::builder()
                .status(200)
                .header("content-type", "text/plain")
                .body(full_body(key_authorization))
                .expect("static status and header are always valid"),
        )
    }

    /// Rejects peers outside the overlay network when the route is internal-only.
    ///
    /// When the peer address is unknown, the check is skipped and logged at debug level.
    fn check_internal_exposure(&self, resolved: &ResolvedService) -> Result<()> {
        if resolved.expose != ExposeType::Internal {
            return Ok(());
        }
        match self.remote_addr {
            Some(addr) if !is_overlay_ip(addr.ip()) => {
                warn!(
                    source = %addr.ip(),
                    service = %resolved.name,
                    "Rejected non-overlay source for internal endpoint"
                );
                Err(ProxyError::Forbidden(
                    "endpoint is internal-only".to_string(),
                ))
            }
            Some(_) => Ok(()),
            None => {
                debug!(
                    service = %resolved.name,
                    "No remote_addr available; skipping overlay source check"
                );
                Ok(())
            }
        }
    }

    /// Forwards an upgrade request over a dedicated HTTP/1 connection.
    ///
    /// When the backend answers `101 Switching Protocols`, both upgraded streams are
    /// spliced in a background task. Any other response is relayed to the client.
    async fn forward_upgrade(
        &self,
        mut req: Request<Incoming>,
        backend_uri: Uri,
        backend_addr: SocketAddr,
        start: std::time::Instant,
    ) -> Result<Response<BoxBody>> {
        // Must be taken before the request is decomposed.
        let client_upgrade: OnUpgrade = hyper::upgrade::on(&mut req);
        let (orig_parts, _body) = req.into_parts();

        let mut backend_parts = http::request::Builder::new()
            .method(orig_parts.method.clone())
            .uri(backend_uri)
            .body(())
            .expect("a request built from a valid Method and Uri cannot fail")
            .into_parts()
            .0;
        // Clone the whole map: re-inserting value by value would keep only the last
        // value of repeated headers.
        backend_parts.headers = orig_parts.headers.clone();
        crate::tunnel::copy_upgrade_headers(&orig_parts, &mut backend_parts);
        self.add_forwarding_headers(&mut backend_parts);

        let tcp_stream = TcpStream::connect(backend_addr).await.map_err(|e| {
            error!(error = %e, backend = %backend_addr, "Backend upgrade connect failed");
            ProxyError::BackendConnectionFailed {
                backend: backend_addr,
                reason: e.to_string(),
            }
        })?;
        let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
            .preserve_header_case(true)
            .handshake(TokioIo::new(tcp_stream))
            .await
            .map_err(|e| {
                error!(error = %e, backend = %backend_addr, "Backend upgrade handshake failed");
                ProxyError::BackendRequestFailed(format!("Upgrade handshake failed: {e}"))
            })?;
        tokio::spawn(async move {
            if let Err(e) = conn.with_upgrades().await {
                error!(error = %e, "Backend upgrade connection driver error");
            }
        });

        let backend_req =
            Request::from_parts(backend_parts, http_body_util::Empty::<Bytes>::new());
        let backend_response = sender.send_request(backend_req).await.map_err(|e| {
            error!(error = %e, backend = %backend_addr, "Backend upgrade request failed");
            ProxyError::BackendRequestFailed(e.to_string())
        })?;

        if backend_response.status() != http::StatusCode::SWITCHING_PROTOCOLS {
            // The backend declined the upgrade: relay its response as-is.
            let (mut parts, body) = backend_response.into_parts();
            self.apply_response_headers(&mut parts.headers, start);
            return Ok(Response::from_parts(parts, body.boxed()));
        }

        // Keep the backend's handshake headers (e.g. `Sec-WebSocket-Accept`, the
        // negotiated `Upgrade` / `Sec-WebSocket-Protocol`). Without them the client
        // rejects the 101. They must be captured before `on` consumes the response.
        let backend_headers = backend_response.headers().clone();
        let server_upgrade: OnUpgrade = hyper::upgrade::on(backend_response);

        let mut client_response = Response::new(empty_body());
        *client_response.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
        let headers = client_response.headers_mut();
        *headers = backend_headers;
        if !headers.contains_key(header::UPGRADE) {
            if let Some(upgrade) = orig_parts.headers.get(header::UPGRADE) {
                headers.insert(header::UPGRADE, upgrade.clone());
            }
        }
        headers.insert(header::CONNECTION, http::HeaderValue::from_static("upgrade"));
        Self::insert_server_timing(headers, start);

        tokio::spawn(async move {
            if let Err(e) = crate::tunnel::proxy_upgrade(client_upgrade, server_upgrade).await {
                debug!(error = %e, "Upgrade tunnel ended");
            }
        });

        Ok(client_response)
    }

    /// Adds HSTS (TLS connections only, when enabled in config) and `server-timing`.
    fn apply_response_headers(&self, headers: &mut http::HeaderMap, start: std::time::Instant) {
        let cfg = &self.config.headers;
        if self.is_tls && cfg.hsts {
            let value = if cfg.hsts_subdomains {
                format!("max-age={}; includeSubDomains", cfg.hsts_max_age)
            } else {
                format!("max-age={}", cfg.hsts_max_age)
            };
            if let Ok(hv) = value.parse() {
                headers.insert("strict-transport-security", hv);
            }
        }
        Self::insert_server_timing(headers, start);
    }

    /// Sets `server-timing: proxy;dur=<ms since start>`.
    fn insert_server_timing(headers: &mut http::HeaderMap, start: std::time::Instant) {
        if let Ok(hv) = format!("proxy;dur={}", start.elapsed().as_millis()).parse() {
            headers.insert("server-timing", hv);
        }
    }

    /// Builds the absolute backend URI.
    ///
    /// The route's prefix transform is applied to the path and the original query
    /// string is kept.
    fn backend_uri(original: &Uri, backend: SocketAddr, resolved: &ResolvedService) -> Result<Uri> {
        let path = transform_path(&resolved.path_prefix, original.path(), resolved.strip_prefix);
        let query = original
            .query()
            .map(|q| format!("?{q}"))
            .unwrap_or_default();
        format!("http://{backend}{path}{query}")
            .parse::<Uri>()
            .map_err(|e| ProxyError::InvalidRequest(format!("Invalid URI: {e}")))
    }

    /// Rewrites a downstream request for `backend`.
    ///
    /// The URI is rewritten, hop-by-hop headers are stripped, forwarding headers are
    /// added, and the body is streamed through unchanged.
    fn build_forwarded_request(
        &self,
        req: Request<Incoming>,
        backend: &SocketAddr,
        resolved: &ResolvedService,
    ) -> Result<Request<BoxBody>> {
        let (mut parts, body) = req.into_parts();
        parts.uri = Self::backend_uri(&parts.uri, *backend, resolved)?;
        // Strip hop-by-hop headers BEFORE adding our own. Otherwise a client could
        // list e.g. `x-forwarded-for` in `Connection` and have the values we set
        // removed.
        Self::remove_hop_by_hop_headers(&mut parts);
        self.add_forwarding_headers(&mut parts);
        Ok(Request::from_parts(parts, body.boxed()))
    }

    /// Adds the configured forwarding headers to an outgoing backend request.
    ///
    /// The headers are `X-Forwarded-For`, `X-Forwarded-Proto`, `X-Forwarded-Host`,
    /// `X-Real-IP` and `Via`. Client-identity headers are only honored when the
    /// direct peer is in `trusted_proxies`.
    ///
    /// For a trusted peer, the real client IP is taken from `cf-connecting-ip`, then
    /// the leftmost `X-Forwarded-For` entry, then the peer address. Existing
    /// `X-Forwarded-Proto`, `X-Forwarded-Host` and `X-Real-IP` values are kept.
    ///
    /// For an untrusted peer, those three headers are overwritten (or removed when no
    /// value is available) so the client cannot spoof them. The peer address is
    /// appended to whatever `X-Forwarded-For` chain the client sent.
    fn add_forwarding_headers(&self, parts: &mut http::request::Parts) {
        let config = &self.config.headers;
        let peer_ip = self.remote_addr.map(|addr| addr.ip());
        let peer_is_trusted = peer_ip.is_some_and(|ip| self.trusted_proxies.is_trusted(ip));
        let existing_xff = Self::joined_header_values(&parts.headers, "x-forwarded-for");

        let effective_client_ip: Option<IpAddr> = if peer_is_trusted {
            let cf_ip = parts
                .headers
                .get("cf-connecting-ip")
                .and_then(|h| h.to_str().ok())
                .and_then(|s| s.trim().parse::<IpAddr>().ok());
            let xff_leftmost = existing_xff.as_deref().and_then(Self::leftmost_ip);
            cf_ip.or(xff_leftmost).or(peer_ip)
        } else {
            peer_ip
        };

        if config.x_forwarded_for {
            if let Some(peer) = peer_ip {
                let new_value = match existing_xff {
                    // Lead with the real client, without repeating it when the chain
                    // already starts with that address.
                    Some(chain) if peer_is_trusted => {
                        let real = effective_client_ip.unwrap_or(peer);
                        if Self::leftmost_ip(&chain) == Some(real) {
                            chain
                        } else {
                            format!("{real}, {chain}")
                        }
                    }
                    Some(chain) => format!("{chain}, {peer}"),
                    None => effective_client_ip.unwrap_or(peer).to_string(),
                };
                if let Ok(value) = new_value.parse() {
                    parts.headers.insert("x-forwarded-for", value);
                }
            }
        }

        if config.x_forwarded_proto {
            let keep_existing = peer_is_trusted && parts.headers.contains_key("x-forwarded-proto");
            if !keep_existing {
                let proto = if self.is_tls { "https" } else { "http" };
                parts
                    .headers
                    .insert("x-forwarded-proto", http::HeaderValue::from_static(proto));
            }
        }

        if config.x_forwarded_host {
            let keep_existing = peer_is_trusted && parts.headers.contains_key("x-forwarded-host");
            if !keep_existing {
                match parts.headers.get(header::HOST).cloned() {
                    Some(host) => {
                        parts.headers.insert("x-forwarded-host", host);
                    }
                    None => {
                        parts.headers.remove("x-forwarded-host");
                    }
                }
            }
        }

        if config.x_real_ip {
            let keep_existing = peer_is_trusted && parts.headers.contains_key("x-real-ip");
            if !keep_existing {
                match effective_client_ip
                    .and_then(|ip| ip.to_string().parse::<http::HeaderValue>().ok())
                {
                    Some(value) => {
                        parts.headers.insert("x-real-ip", value);
                    }
                    None => {
                        parts.headers.remove("x-real-ip");
                    }
                }
            }
        }

        if config.via {
            let proto_version = match parts.version {
                Version::HTTP_09 => "0.9",
                Version::HTTP_10 => "1.0",
                Version::HTTP_2 => "2.0",
                Version::HTTP_3 => "3.0",
                _ => "1.1",
            };
            let hop = format!("{proto_version} {}", config.server_name);
            let via = match Self::joined_header_values(&parts.headers, header::VIA.as_str()) {
                Some(existing) => format!("{existing}, {hop}"),
                None => hop,
            };
            if let Ok(value) = via.parse() {
                parts.headers.insert(header::VIA, value);
            }
        }
    }

    /// Joins every value of a (possibly repeated) header into one `", "`-separated list.
    ///
    /// Values that are not visible ASCII, or are blank, are skipped. Returns `None`
    /// when nothing remains.
    fn joined_header_values(headers: &http::HeaderMap, name: &str) -> Option<String> {
        let values: Vec<&str> = headers
            .get_all(name)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .map(str::trim)
            .filter(|value| !value.is_empty())
            .collect();
        (!values.is_empty()).then(|| values.join(", "))
    }

    /// Parses the first comma-separated entry of an `X-Forwarded-For` chain as an IP.
    fn leftmost_ip(chain: &str) -> Option<IpAddr> {
        chain.split(',').next()?.trim().parse().ok()
    }

    /// Removes hop-by-hop headers from a request that will be forwarded.
    ///
    /// This covers the fixed RFC list plus every header named in any `Connection`
    /// header line.
    fn remove_hop_by_hop_headers(parts: &mut http::request::Parts) {
        const HOP_BY_HOP: &[&str] = &[
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailer",
            "transfer-encoding",
            "upgrade",
        ];
        let connection_listed: Vec<String> = parts
            .headers
            .get_all(header::CONNECTION)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .map(|token| token.trim().to_ascii_lowercase())
            .filter(|token| !token.is_empty())
            .collect();
        for header_name in HOP_BY_HOP {
            parts.headers.remove(*header_name);
        }
        for header_name in &connection_listed {
            parts.headers.remove(header_name.as_str());
        }
    }

    /// Renders `error` as a JSON `{"error": "..."}` response with its status code.
    ///
    /// The message is JSON-escaped, so quotes or backslashes in it (for example
    /// from a client-supplied Host) cannot break the document.
    #[must_use]
    pub fn error_response(error: &ProxyError) -> Response<BoxBody> {
        let body = format!(
            "{{\"error\": \"{}\"}}",
            Self::escape_json(&error.to_string())
        );
        Response::builder()
            .status(error.status_code())
            .header(header::CONTENT_TYPE, "application/json")
            .body(full_body(body))
            .expect("status comes from ProxyError and the header is static")
    }

    /// Escapes `raw` for embedding inside a JSON string literal.
    fn escape_json(raw: &str) -> String {
        use std::fmt::Write as _;
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if u32::from(c) < 0x20 => {
                    let _ = write!(out, "\\u{:04x}", u32::from(c));
                }
                c => out.push(c),
            }
        }
        out
    }
}
impl Service<Request<Incoming>> for ReverseProxyService {
    type Response = Response<BoxBody>;
    type Error = ProxyError;
    type Future = std::pin::Pin<
        Box<
            dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
                + Send,
        >,
    >;

    /// Always ready: the service keeps no per-call capacity, and each request is
    /// handled by its own clone of `self`.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Clones the service so the boxed future owns everything it needs, then
    /// delegates to [`ReverseProxyService::proxy_request`].
    fn call(&mut self, req: Request<Incoming>) -> Self::Future {
        let svc = Self::clone(self);
        let fut = async move { svc.proxy_request(req).await };
        Box::pin(fut)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trust::TrustedProxyList;

    /// Trusted-proxy list covering the 203.0.113.0/24 documentation range.
    fn trusted_test_net() -> TrustedProxyList {
        TrustedProxyList::new(vec!["203.0.113.0/24".parse().unwrap()], None)
    }

    /// Builds a service with an empty registry and load balancer, default config,
    /// the given peer address and trusted-proxy list.
    fn build_svc(peer: SocketAddr, trusted: TrustedProxyList) -> ReverseProxyService {
        ReverseProxyService::new(
            Arc::new(ServiceRegistry::new()),
            Arc::new(LoadBalancer::new()),
            Arc::new(ProxyConfig::default()),
        )
        .with_remote_addr(peer)
        .with_trusted_proxies(Arc::new(trusted))
    }

    /// Builds `GET /` request parts carrying the given headers.
    fn parts_with_headers(headers: &[(&str, &str)]) -> http::request::Parts {
        headers
            .iter()
            .fold(
                http::request::Builder::new().method("GET").uri("/"),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(())
            .unwrap()
            .into_parts()
            .0
    }

    /// Reads a header that the test expects to be present as a string.
    fn header_str<'a>(parts: &'a http::request::Parts, name: &str) -> &'a str {
        parts.headers.get(name).unwrap().to_str().unwrap()
    }

    #[test]
    fn test_error_response() {
        let not_found = ProxyError::RouteNotFound {
            host: "example.com".to_string(),
            path: "/api".to_string(),
        };
        assert_eq!(
            ReverseProxyService::error_response(&not_found).status(),
            http::StatusCode::NOT_FOUND
        );
    }

    #[test]
    fn test_hop_by_hop_headers() {
        let mut parts = parts_with_headers(&[
            ("connection", "keep-alive, x-custom"),
            ("keep-alive", "timeout=5"),
            ("x-custom", "value"),
            ("x-other", "value"),
        ]);
        ReverseProxyService::remove_hop_by_hop_headers(&mut parts);
        for stripped in ["connection", "keep-alive", "x-custom"] {
            assert!(!parts.headers.contains_key(stripped), "{stripped} should be removed");
        }
        assert!(parts.headers.contains_key("x-other"));
    }

    #[test]
    fn test_is_overlay_ip_accepts_overlay_range() {
        for ip in ["10.200.0.1", "10.200.255.254", "10.200.1.100"] {
            assert!(is_overlay_ip(ip.parse().unwrap()), "{ip} should be overlay");
        }
    }

    #[test]
    fn test_is_overlay_ip_rejects_non_overlay() {
        for ip in ["192.168.1.1", "10.0.0.1", "10.201.0.1", "172.16.0.1", "8.8.8.8"] {
            assert!(!is_overlay_ip(ip.parse().unwrap()), "{ip} should not be overlay");
        }
    }

    #[test]
    fn test_is_overlay_ip_rejects_ipv6() {
        for ip in ["::1", "fe80::1"] {
            assert!(!is_overlay_ip(ip.parse().unwrap()), "{ip} should not be overlay");
        }
    }

    #[test]
    fn test_forbidden_error_response() {
        let forbidden = ProxyError::Forbidden("endpoint 'ws' is internal-only".to_string());
        assert_eq!(
            ReverseProxyService::error_response(&forbidden).status(),
            http::StatusCode::FORBIDDEN
        );
    }

    #[test]
    fn trusted_peer_cf_connecting_ip_is_honored() {
        let svc = build_svc("203.0.113.50:443".parse().unwrap(), trusted_test_net());
        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);
        svc.add_forwarding_headers(&mut parts);
        assert_eq!(header_str(&parts, "x-real-ip"), "198.51.100.7");
        let xff = header_str(&parts, "x-forwarded-for");
        assert!(
            xff.starts_with("198.51.100.7"),
            "XFF should start with real client IP, got {xff}"
        );
    }

    #[test]
    fn trusted_peer_xff_leftmost_is_honored_when_no_cf_header() {
        let svc = build_svc("203.0.113.50:443".parse().unwrap(), trusted_test_net());
        let mut parts = parts_with_headers(&[("x-forwarded-for", "198.51.100.9, 10.0.0.1")]);
        svc.add_forwarding_headers(&mut parts);
        assert_eq!(header_str(&parts, "x-real-ip"), "198.51.100.9");
        let xff = header_str(&parts, "x-forwarded-for");
        assert!(
            xff.starts_with("198.51.100.9"),
            "XFF should start with leftmost real client, got {xff}"
        );
        assert!(
            xff.contains("10.0.0.1"),
            "original chain should survive: {xff}"
        );
    }

    #[test]
    fn untrusted_peer_cf_connecting_ip_is_ignored() {
        let svc = build_svc("8.8.8.8:443".parse().unwrap(), trusted_test_net());
        let mut parts = parts_with_headers(&[("cf-connecting-ip", "198.51.100.7")]);
        svc.add_forwarding_headers(&mut parts);
        assert_eq!(header_str(&parts, "x-real-ip"), "8.8.8.8");
        let xff = header_str(&parts, "x-forwarded-for");
        assert!(
            xff.ends_with("8.8.8.8"),
            "XFF for untrusted peer should end with peer IP, got {xff}"
        );
    }

    #[test]
    fn no_headers_uses_peer_ip() {
        let svc = build_svc(
            "198.51.100.250:443".parse().unwrap(),
            TrustedProxyList::localhost_only(),
        );
        let mut parts = parts_with_headers(&[]);
        svc.add_forwarding_headers(&mut parts);
        assert_eq!(header_str(&parts, "x-real-ip"), "198.51.100.250");
        assert_eq!(header_str(&parts, "x-forwarded-for"), "198.51.100.250");
    }
}