use crate::acme::CertManager;
use crate::config::ProxyConfig;
use crate::error::{ProxyError, Result};
use crate::lb::LoadBalancer;
use crate::network_policy::NetworkPolicyChecker;
use crate::routes::ServiceRegistry;
use crate::service::ReverseProxyService;
use crate::sni_resolver::SniCertResolver;
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::Request;
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, warn};
/// Delay before the first re-bind attempt made by `bind_with_retry`.
const INGRESS_BIND_BACKOFF_INITIAL: Duration = Duration::from_secs(2);
/// Upper bound that `next_ingress_backoff` clamps the retry delay to.
const INGRESS_BIND_BACKOFF_MAX: Duration = Duration::from_secs(30);
/// A failed bind is logged at `warn` level when its attempt number is a
/// multiple of this value (see `should_warn_on_attempt`); other attempts are
/// logged at `debug` level.
const INGRESS_BIND_WARN_EVERY: u64 = 30;
/// Returns the retry delay that follows `current`: twice as long, but never
/// more than [`INGRESS_BIND_BACKOFF_MAX`].
///
/// The doubling saturates, so even an absurdly large `current` yields the cap
/// instead of panicking on `Duration` overflow.
#[must_use]
fn next_ingress_backoff(current: Duration) -> Duration {
    current.saturating_mul(2).min(INGRESS_BIND_BACKOFF_MAX)
}
/// Decides whether a failed bind on attempt number `attempt` (0-based) is
/// logged at `warn` level rather than `debug`.
///
/// Returns `true` for every attempt that is a multiple of
/// [`INGRESS_BIND_WARN_EVERY`]; zero is such a multiple, so the very first
/// failure always warns.
#[must_use]
fn should_warn_on_attempt(attempt: u64) -> bool {
    let attempts_past_last_warning = attempt % INGRESS_BIND_WARN_EVERY;
    attempts_past_last_warning == 0
}
/// Binds a TCP listener on `addr`, retrying with exponential backoff until the
/// bind succeeds or shutdown is signalled.
///
/// * `addr` - socket address to bind.
/// * `label` - tag attached to every log line emitted here.
/// * `shutdown_rx` - watch channel carrying the shutdown flag.
///
/// Returns `Some(listener)` once the bind succeeds, or `None` if the shutdown
/// flag is observed as `true` first. A bind failure never aborts: it is logged
/// and retried after a delay that starts at [`INGRESS_BIND_BACKOFF_INITIAL`]
/// and grows via [`next_ingress_backoff`].
///
/// If the shutdown sender has been dropped, no further signal can arrive; the
/// function then keeps retrying and still waits out the backoff between
/// attempts.
async fn bind_with_retry(
    addr: SocketAddr,
    label: &str,
    shutdown_rx: &mut watch::Receiver<bool>,
) -> Option<TcpListener> {
    let mut attempt: u64 = 0;
    let mut backoff = INGRESS_BIND_BACKOFF_INITIAL;
    // The permission hint is only worth printing once per call.
    let mut warned_eacces = false;
    loop {
        if *shutdown_rx.borrow() {
            return None;
        }
        match TcpListener::bind(addr).await {
            Ok(listener) => {
                if attempt > 0 {
                    info!(addr = %addr, label, attempt, "Ingress bound after retrying");
                }
                return Some(listener);
            }
            Err(e) => {
                let is_eacces = e.kind() == std::io::ErrorKind::PermissionDenied;
                if is_eacces && !warned_eacces {
                    warned_eacces = true;
                    warn!(
                        addr = %addr,
                        label,
                        error = %e,
                        "Ingress bind denied: binding 80/443 needs root or CAP_NET_BIND_SERVICE; \
                         will keep retrying without aborting startup"
                    );
                } else if should_warn_on_attempt(attempt) {
                    warn!(
                        addr = %addr,
                        label,
                        attempt,
                        error = %e,
                        "Ingress bind failed; will keep retrying (port may be held by another process)"
                    );
                } else {
                    debug!(addr = %addr, label, attempt, error = %e, "Ingress bind retry");
                }
                tokio::select! {
                    () = tokio::time::sleep(backoff) => {}
                    changed = shutdown_rx.changed() => {
                        if *shutdown_rx.borrow() {
                            return None;
                        }
                        if changed.is_err() {
                            // Sender dropped: `changed()` now fails immediately on
                            // every call, so without this sleep the loop would
                            // re-bind back-to-back with no delay at all.
                            tokio::time::sleep(backoff).await;
                        }
                    }
                }
                attempt = attempt.saturating_add(1);
                backoff = next_ingress_backoff(backoff);
            }
        }
    }
}
/// Reverse-proxy front end: binds the HTTP/HTTPS listeners and hands every
/// accepted connection to a `ReverseProxyService`.
pub struct ProxyServer {
/// Proxy configuration, shared with every connection task.
config: Arc<ProxyConfig>,
/// Service registry passed to each per-connection service.
registry: Arc<ServiceRegistry>,
/// Load balancer passed to each per-connection service.
load_balancer: Arc<LoadBalancer>,
/// Sending half of the shutdown flag; `shutdown()` sends `true` on it.
shutdown_tx: watch::Sender<bool>,
/// Receiving half of the shutdown flag; cloned by the bind and accept loops.
shutdown_rx: watch::Receiver<bool>,
/// TLS acceptor; `Some` only when the server was built by `with_tls_resolver`.
tls_acceptor: Option<TlsAcceptor>,
/// Optional certificate manager forwarded to each per-connection service.
cert_manager: Option<Arc<CertManager>>,
/// Optional network-policy checker forwarded to each per-connection service.
network_policy_checker: Option<NetworkPolicyChecker>,
}
impl ProxyServer {
/// Creates a server without TLS, certificate manager or network-policy
/// checker.
///
/// The shutdown flag starts out as `false`.
pub fn new(
    config: ProxyConfig,
    registry: Arc<ServiceRegistry>,
    load_balancer: Arc<LoadBalancer>,
) -> Self {
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let config = Arc::new(config);
    Self {
        config,
        registry,
        load_balancer,
        shutdown_tx,
        shutdown_rx,
        tls_acceptor: None,
        cert_manager: None,
        network_policy_checker: None,
    }
}
/// Alias for [`ProxyServer::new`]: forwards all three arguments unchanged.
pub fn with_registry(
config: ProxyConfig,
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
) -> Self {
Self::new(config, registry, load_balancer)
}
pub fn with_tls_resolver(
config: ProxyConfig,
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
resolver: Arc<SniCertResolver>,
) -> Self {
let tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(resolver);
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let (shutdown_tx, shutdown_rx) = watch::channel(false);
Self {
config: Arc::new(config),
registry,
load_balancer,
shutdown_tx,
shutdown_rx,
tls_acceptor: Some(acceptor),
cert_manager: None,
network_policy_checker: None,
}
}
/// Builder-style setter: stores `cm` as the certificate manager and returns
/// the updated server.
#[must_use]
pub fn with_cert_manager(mut self, cm: Arc<CertManager>) -> Self {
self.cert_manager = Some(cm);
self
}
/// Builder-style setter: stores `checker` as the network-policy checker and
/// returns the updated server.
#[must_use]
pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
self.network_policy_checker = Some(checker);
self
}
/// Returns `true` when a TLS acceptor is configured.
#[must_use]
pub fn has_tls(&self) -> bool {
self.tls_acceptor.is_some()
}
/// Borrows the configured TLS acceptor, or `None` if TLS is not set up.
#[must_use]
pub fn tls_acceptor(&self) -> Option<&TlsAcceptor> {
self.tls_acceptor.as_ref()
}
/// Returns another handle to the shared service registry.
#[must_use]
pub fn registry(&self) -> Arc<ServiceRegistry> {
    Arc::clone(&self.registry)
}
/// Returns another handle to the shared proxy configuration.
#[must_use]
pub fn config(&self) -> Arc<ProxyConfig> {
    Arc::clone(&self.config)
}
/// Signals shutdown by sending `true` on the shutdown channel.
///
/// The result of the send is deliberately discarded.
pub fn shutdown(&self) {
// NOTE(review): a watch send presumably fails only when no receiver exists,
// and `self.shutdown_rx` is one — confirm against the tokio docs.
let _ = self.shutdown_tx.send(true);
}
pub async fn run(&self) -> Result<()> {
let addr = self.config.server.http_addr;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr,
reason: e.to_string(),
})?;
info!(addr = %addr, "HTTP proxy server listening");
self.accept_loop(listener).await
}
/// Serves plain HTTP on `addr` until shutdown.
///
/// The bind is attempted exactly once; use [`Self::run_with_retry`] to keep
/// retrying instead.
///
/// # Errors
/// Returns `ProxyError::BindFailed` (carrying `addr` and the OS error text)
/// if the address cannot be bound.
pub async fn run_on(&self, addr: SocketAddr) -> Result<()> {
    let bound = TcpListener::bind(addr).await;
    let listener = bound.map_err(|source| ProxyError::BindFailed {
        addr,
        reason: source.to_string(),
    })?;
    info!(addr = %addr, "HTTP proxy server listening");
    self.accept_loop(listener).await
}
/// Serves plain HTTP on `addr`, retrying the bind until it succeeds.
///
/// Returns `Ok(())` without serving anything if shutdown is signalled while
/// the bind is still being retried; otherwise returns whatever the accept
/// loop returns.
pub async fn run_with_retry(&self, addr: SocketAddr) -> Result<()> {
    let mut shutdown_rx = self.shutdown_rx.clone();
    match bind_with_retry(addr, "http", &mut shutdown_rx).await {
        None => Ok(()),
        Some(listener) => {
            info!(addr = %addr, "HTTP ingress server listening");
            self.accept_loop(listener).await
        }
    }
}
async fn accept_loop(&self, listener: TcpListener) -> Result<()> {
let mut shutdown_rx = self.shutdown_rx.clone();
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
info!("Shutting down proxy server");
break;
}
}
result = listener.accept() => {
match result {
Ok((stream, remote_addr)) => {
let registry = self.registry.clone();
let load_balancer = self.load_balancer.clone();
let config = self.config.clone();
let cert_manager = self.cert_manager.clone();
let npc = self.network_policy_checker.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(
stream,
remote_addr,
registry,
load_balancer,
config,
cert_manager,
npc,
).await {
debug!(
error = %e,
remote_addr = %remote_addr,
"Connection error"
);
}
});
}
Err(e) => {
warn!(error = %e, "Failed to accept connection");
}
}
}
}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn handle_connection(
stream: tokio::net::TcpStream,
remote_addr: SocketAddr,
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
config: Arc<ProxyConfig>,
cert_manager: Option<Arc<CertManager>>,
network_policy_checker: Option<NetworkPolicyChecker>,
) -> Result<()> {
let io = TokioIo::new(stream);
let mut service =
ReverseProxyService::new(registry, load_balancer, config).with_remote_addr(remote_addr);
if let Some(cm) = cert_manager {
service = service.with_cert_manager(cm);
}
if let Some(checker) = network_policy_checker {
service = service.with_network_policy_checker(checker);
}
let service = service_fn(move |req: Request<Incoming>| {
let svc = service.clone();
async move {
match svc.proxy_request(req).await {
Ok(response) => Ok::<_, hyper::Error>(response),
Err(e) => {
error!(error = %e, "Proxy error");
Ok(ReverseProxyService::error_response(&e))
}
}
}
});
http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(false)
.serve_connection(io, service)
.with_upgrades()
.await
.map_err(ProxyError::Hyper)?;
Ok(())
}
pub async fn run_https(&self) -> Result<()> {
let acceptor = self
.tls_acceptor
.as_ref()
.ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
let addr = self.config.server.https_addr;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr,
reason: e.to_string(),
})?;
info!(addr = %addr, "HTTPS proxy server listening");
self.accept_loop_tls(listener, acceptor.clone()).await
}
/// Serves HTTPS on `addr` until shutdown.
///
/// The TLS check happens before the bind, and the bind is attempted exactly
/// once.
///
/// # Errors
/// Returns `ProxyError::Config` if no TLS acceptor is configured, or
/// `ProxyError::BindFailed` if the address cannot be bound.
pub async fn run_https_on(&self, addr: SocketAddr) -> Result<()> {
    let acceptor = self
        .tls_acceptor
        .clone()
        .ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
    let bound = TcpListener::bind(addr).await;
    let listener = bound.map_err(|source| ProxyError::BindFailed {
        addr,
        reason: source.to_string(),
    })?;
    info!(addr = %addr, "HTTPS proxy server listening");
    self.accept_loop_tls(listener, acceptor).await
}
/// Serves HTTPS on `addr`, retrying the bind until it succeeds.
///
/// Returns `Ok(())` without serving anything if shutdown is signalled while
/// the bind is still being retried.
///
/// # Errors
/// Returns `ProxyError::Config` if no TLS acceptor is configured; this is
/// checked before the first bind attempt.
pub async fn run_https_with_retry(&self, addr: SocketAddr) -> Result<()> {
    let acceptor = self
        .tls_acceptor
        .clone()
        .ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
    let mut shutdown_rx = self.shutdown_rx.clone();
    match bind_with_retry(addr, "https", &mut shutdown_rx).await {
        None => Ok(()),
        Some(listener) => {
            info!(addr = %addr, "HTTPS ingress server listening");
            self.accept_loop_tls(listener, acceptor).await
        }
    }
}
#[allow(clippy::similar_names)]
pub async fn run_both(&self) -> Result<()> {
let http_addr = self.config.server.http_addr;
let https_addr = self.config.server.https_addr;
let acceptor = self
.tls_acceptor
.as_ref()
.ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
let http_listener =
TcpListener::bind(http_addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr: http_addr,
reason: e.to_string(),
})?;
let https_listener =
TcpListener::bind(https_addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr: https_addr,
reason: e.to_string(),
})?;
info!(http = %http_addr, https = %https_addr, "Proxy server listening");
let http_future = self.accept_loop(http_listener);
let https_future = self.accept_loop_tls(https_listener, acceptor.clone());
tokio::select! {
result = http_future => result,
result = https_future => result,
}
}
async fn accept_loop_tls(&self, listener: TcpListener, acceptor: TlsAcceptor) -> Result<()> {
let mut shutdown_rx = self.shutdown_rx.clone();
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
info!("Shutting down HTTPS proxy server");
break;
}
}
result = listener.accept() => {
match result {
Ok((stream, remote_addr)) => {
let registry = self.registry.clone();
let load_balancer = self.load_balancer.clone();
let config = self.config.clone();
let acceptor = acceptor.clone();
let cert_manager = self.cert_manager.clone();
let npc = self.network_policy_checker.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_tls_connection(
stream,
remote_addr,
registry,
load_balancer,
config,
acceptor,
cert_manager,
npc,
).await {
debug!(
error = %e,
remote_addr = %remote_addr,
"TLS connection error"
);
}
});
}
Err(e) => {
warn!(error = %e, "Failed to accept TLS connection");
}
}
}
}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn handle_tls_connection(
stream: tokio::net::TcpStream,
remote_addr: SocketAddr,
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
config: Arc<ProxyConfig>,
acceptor: TlsAcceptor,
cert_manager: Option<Arc<CertManager>>,
network_policy_checker: Option<NetworkPolicyChecker>,
) -> Result<()> {
let tls_stream = acceptor
.accept(stream)
.await
.map_err(|e| ProxyError::Tls(format!("TLS handshake failed: {e}")))?;
let io = TokioIo::new(tls_stream);
let mut service = ReverseProxyService::new(registry, load_balancer, config)
.with_remote_addr(remote_addr)
.with_tls(true);
if let Some(cm) = cert_manager {
service = service.with_cert_manager(cm);
}
if let Some(checker) = network_policy_checker {
service = service.with_network_policy_checker(checker);
}
let service = service_fn(move |req: Request<Incoming>| {
let svc = service.clone();
async move {
match svc.proxy_request(req).await {
Ok(response) => Ok::<_, hyper::Error>(response),
Err(e) => {
error!(error = %e, "Proxy error");
Ok(ReverseProxyService::error_response(&e))
}
}
}
});
http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(false)
.serve_connection(io, service)
.with_upgrades()
.await
.map_err(ProxyError::Hyper)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::lb::LoadBalancer;
use crate::routes::{ResolvedService, RouteEntry};
use zlayer_spec::{ExposeType, Protocol};
/// Builds a public, plain-HTTP `RouteEntry` fixture.
///
/// `service` is used for both the route's service name and the resolved
/// service name, and `path` for both path prefixes. The endpoint name is
/// always `"http"` and the target port `8080`.
fn make_entry(
    service: &str,
    host: Option<&str>,
    path: &str,
    backends: Vec<SocketAddr>,
) -> RouteEntry {
    let resolved = ResolvedService {
        name: service.to_owned(),
        backends,
        use_tls: false,
        sni_hostname: String::new(),
        expose: ExposeType::Public,
        protocol: Protocol::Http,
        strip_prefix: false,
        path_prefix: path.to_owned(),
        target_port: 8080,
    };
    RouteEntry {
        service_name: service.to_owned(),
        endpoint_name: "http".to_owned(),
        host: host.map(str::to_owned),
        path_prefix: path.to_owned(),
        resolved,
    }
}
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Sends `raw_request` to a one-shot proxy connection and returns everything
/// the proxy wrote back, lossily decoded as UTF-8.
///
/// A listener is bound on an ephemeral loopback port; a spawned task accepts
/// exactly one connection and serves it with `ProxyServer::handle_connection`
/// using a default `ProxyConfig` and no network-policy checker.
async fn roundtrip(
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
cert_manager: Option<Arc<CertManager>>,
raw_request: &str,
) -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server = tokio::spawn(async move {
let (stream, remote_addr) = listener.accept().await.unwrap();
// The connection result is irrelevant here; only the bytes on the wire matter.
let _ = ProxyServer::handle_connection(
stream,
remote_addr,
registry,
load_balancer,
Arc::new(ProxyConfig::default()),
cert_manager,
None,
)
.await;
});
let mut client = tokio::net::TcpStream::connect(addr).await.unwrap();
client.write_all(raw_request.as_bytes()).await.unwrap();
client.flush().await.unwrap();
let mut buf = Vec::new();
// Reads until the server closes the connection; a read error is ignored and
// whatever arrived before it is returned.
let _ = client.read_to_end(&mut buf).await;
server.abort();
String::from_utf8_lossy(&buf).into_owned()
}
/// A request whose `Host` matches no registered route must get a 404 whose
/// response echoes neither the requested host nor the requested path.
#[tokio::test]
async fn test_unmatched_host_denied_404_generic_body() {
let registry = Arc::new(ServiceRegistry::new());
// Register a route for a different host, so the registry is not empty.
registry
.register(make_entry(
"known",
Some("known.example.com"),
"/",
vec!["127.0.0.1:9".parse().unwrap()],
))
.await;
let lb = Arc::new(LoadBalancer::new());
let resp = roundtrip(
registry,
lb,
None,
"GET /secret/path HTTP/1.1\r\nHost: attacker.unregistered.test\r\nConnection: close\r\n\r\n",
)
.await;
assert!(
resp.starts_with("HTTP/1.1 404"),
"unmatched host must be denied with 404, got: {resp}"
);
assert!(
!resp.contains("attacker.unregistered.test"),
"response must not echo the requested host: {resp}"
);
assert!(
!resp.contains("/secret/path"),
"response must not echo the requested path: {resp}"
);
assert!(
resp.contains("404 Not Found"),
"response should carry the generic 404 body: {resp}"
);
}
/// With a certificate manager attached and no routes registered, a stored
/// ACME HTTP-01 challenge must be answered with 200 and its key
/// authorization.
#[tokio::test]
async fn test_acme_challenge_served_not_denied() {
let tmp = tempfile::tempdir().unwrap();
let cm = Arc::new(
CertManager::new(tmp.path().to_string_lossy().into_owned(), None)
.await
.unwrap(),
);
let token = "test-token-abc";
cm.store_challenge(token, "example.com", "key-auth-payload-123");
// The registry stays empty: the challenge must be served without any route.
let registry = Arc::new(ServiceRegistry::new());
let lb = Arc::new(LoadBalancer::new());
let resp = roundtrip(
registry,
lb,
Some(cm),
&format!(
"GET /.well-known/acme-challenge/{token} HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n"
),
)
.await;
assert!(
resp.starts_with("HTTP/1.1 200"),
"ACME challenge must be served with 200, got: {resp}"
);
assert!(
resp.contains("key-auth-payload-123"),
"ACME challenge must return the key authorization: {resp}"
);
}
/// A route that matches but has no backends must yield a 503 whose response
/// does not contain the resolved service name.
#[tokio::test]
async fn test_matched_no_backends_503_generic_body() {
let registry = Arc::new(ServiceRegistry::new());
// A distinctive resolved name makes a leak into the response easy to detect.
let lb_group = "prod/api#http-secret-group";
let mut entry = make_entry("api", Some("api.example.com"), "/", vec![]);
entry.resolved.name = lb_group.to_string();
registry.register(entry).await;
let lb = Arc::new(LoadBalancer::new());
let resp = roundtrip(
registry,
lb,
None,
"GET / HTTP/1.1\r\nHost: api.example.com\r\nConnection: close\r\n\r\n",
)
.await;
assert!(
resp.starts_with("HTTP/1.1 503"),
"matched route with no healthy backends must return 503, got: {resp}"
);
assert!(
!resp.contains(lb_group),
"503 body must not leak the internal LB group name: {resp}"
);
assert!(
resp.contains("503 Service Unavailable"),
"response should carry the generic 503 body: {resp}"
);
}
/// `shutdown()` must flip the shared flag from `false` to `true` and wake a
/// receiver cloned from the server, as the accept loops do.
#[tokio::test]
async fn test_server_shutdown() {
    let registry = Arc::new(ServiceRegistry::new());
    let lb = Arc::new(LoadBalancer::new());
    let server = ProxyServer::new(ProxyConfig::default(), registry, lb);

    let mut rx = server.shutdown_rx.clone();
    assert!(!*rx.borrow(), "shutdown flag must start out cleared");

    server.shutdown();

    rx.changed()
        .await
        .expect("shutdown sender must still be alive");
    assert!(*rx.borrow(), "shutdown() must set the flag to true");
}
/// The registry handle returned by `ProxyServer::registry` must see the route
/// that was registered before the server was built.
#[tokio::test]
async fn test_registry_integration() {
let registry = Arc::new(ServiceRegistry::new());
registry
.register(make_entry(
"test-service",
None,
"/api",
vec!["127.0.0.1:8081".parse().unwrap()],
))
.await;
let lb = Arc::new(LoadBalancer::new());
let server = ProxyServer::new(ProxyConfig::default(), registry, lb);
let reg = server.registry();
assert_eq!(reg.route_count().await, 1);
}
/// The backoff doubles from its initial value, reaches the cap after repeated
/// steps, and then stays at the cap.
#[test]
fn test_next_ingress_backoff_doubles_then_caps() {
    let first_step = next_ingress_backoff(INGRESS_BIND_BACKOFF_INITIAL);
    assert_eq!(first_step, INGRESS_BIND_BACKOFF_INITIAL * 2);

    let settled = (0..20).fold(INGRESS_BIND_BACKOFF_INITIAL, |delay, _| {
        next_ingress_backoff(delay)
    });
    assert_eq!(settled, INGRESS_BIND_BACKOFF_MAX);

    let after_cap = next_ingress_backoff(INGRESS_BIND_BACKOFF_MAX);
    assert_eq!(after_cap, INGRESS_BIND_BACKOFF_MAX);
}
/// Warnings fire on attempt 0 and on every multiple of
/// `INGRESS_BIND_WARN_EVERY`, and on no attempt in between.
#[test]
fn test_should_warn_cadence() {
    let expectations = [
        (0, true),
        (1, false),
        (INGRESS_BIND_WARN_EVERY - 1, false),
        (INGRESS_BIND_WARN_EVERY, true),
        (INGRESS_BIND_WARN_EVERY * 2, true),
    ];
    for (attempt, expected) in expectations {
        if expected {
            assert!(should_warn_on_attempt(attempt));
        } else {
            assert!(!should_warn_on_attempt(attempt));
        }
    }
}
/// `bind_with_retry` must obtain the port once the listener that was holding
/// it has been dropped.
#[tokio::test]
async fn test_bind_with_retry_succeeds_after_initial_conflict() {
// Occupy an ephemeral port so that the first bind attempt fails.
let held = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = held.local_addr().unwrap();
// `_tx` keeps the shutdown sender alive for the whole test.
let (_tx, mut rx) = watch::channel(false);
let handle = tokio::spawn(async move { bind_with_retry(addr, "test", &mut rx).await });
// Give the spawned task time to hit the conflict before the port is released.
tokio::time::sleep(Duration::from_millis(50)).await;
drop(held);
// The next attempt only happens after the initial backoff, hence the generous timeout.
let bound = tokio::time::timeout(Duration::from_secs(10), handle)
.await
.expect("bind_with_retry did not finish")
.expect("task panicked");
let listener = bound.expect("expected a bound listener, got None");
assert_eq!(listener.local_addr().unwrap().port(), addr.port());
}
/// A shutdown signal sent while the bind is being retried must make
/// `bind_with_retry` return `None`.
#[tokio::test]
async fn test_bind_with_retry_returns_none_on_shutdown() {
// `held` stays alive for the whole test, so every bind attempt fails.
let held = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = held.local_addr().unwrap();
let (tx, mut rx) = watch::channel(false);
let handle = tokio::spawn(async move { bind_with_retry(addr, "test", &mut rx).await });
// Give the spawned task time to fail its first bind before signalling.
tokio::time::sleep(Duration::from_millis(50)).await;
tx.send(true).unwrap();
let result = tokio::time::timeout(Duration::from_secs(10), handle)
.await
.expect("bind_with_retry did not respond to shutdown")
.expect("task panicked");
assert!(result.is_none(), "shutdown should yield None");
}
}