use crate::acme::CertManager;
use crate::config::ProxyConfig;
use crate::error::{ProxyError, Result};
use crate::lb::LoadBalancer;
use crate::network_policy::NetworkPolicyChecker;
use crate::routes::ServiceRegistry;
use crate::service::ReverseProxyService;
use crate::sni_resolver::SniCertResolver;
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::Request;
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, warn};
/// Reverse proxy server that accepts plain-HTTP and (when a TLS acceptor is configured)
/// HTTPS connections, serving each one with its own [`ReverseProxyService`] in a spawned task.
pub struct ProxyServer {
/// Proxy configuration; `server.http_addr` / `server.https_addr` are the listen
/// addresses used by `run`, `run_https` and `run_both`.
config: Arc<ProxyConfig>,
/// Route table handed to every per-connection service.
registry: Arc<ServiceRegistry>,
/// Load balancer handed to every per-connection service.
load_balancer: Arc<LoadBalancer>,
/// Sending `true` here makes the accept loops stop accepting new connections.
shutdown_tx: watch::Sender<bool>,
/// Template receiver; each accept loop works on its own clone of it.
shutdown_rx: watch::Receiver<bool>,
/// TLS acceptor, `Some` only when the server was built with `with_tls_resolver`.
/// The HTTPS entry points return a configuration error while this is `None`.
tls_acceptor: Option<TlsAcceptor>,
/// Optional certificate manager attached to every per-connection service.
cert_manager: Option<Arc<CertManager>>,
/// Optional network-policy checker cloned into every per-connection service.
network_policy_checker: Option<NetworkPolicyChecker>,
}
impl ProxyServer {
pub fn new(
config: ProxyConfig,
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
) -> Self {
let (shutdown_tx, shutdown_rx) = watch::channel(false);
Self {
config: Arc::new(config),
registry,
load_balancer,
shutdown_tx,
shutdown_rx,
tls_acceptor: None,
cert_manager: None,
network_policy_checker: None,
}
}
pub fn with_registry(
config: ProxyConfig,
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
) -> Self {
Self::new(config, registry, load_balancer)
}
pub fn with_tls_resolver(
config: ProxyConfig,
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
resolver: Arc<SniCertResolver>,
) -> Self {
let tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(resolver);
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let (shutdown_tx, shutdown_rx) = watch::channel(false);
Self {
config: Arc::new(config),
registry,
load_balancer,
shutdown_tx,
shutdown_rx,
tls_acceptor: Some(acceptor),
cert_manager: None,
network_policy_checker: None,
}
}
#[must_use]
pub fn with_cert_manager(mut self, cm: Arc<CertManager>) -> Self {
self.cert_manager = Some(cm);
self
}
#[must_use]
pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
self.network_policy_checker = Some(checker);
self
}
#[must_use]
pub fn has_tls(&self) -> bool {
self.tls_acceptor.is_some()
}
#[must_use]
pub fn tls_acceptor(&self) -> Option<&TlsAcceptor> {
self.tls_acceptor.as_ref()
}
#[must_use]
pub fn registry(&self) -> Arc<ServiceRegistry> {
self.registry.clone()
}
#[must_use]
pub fn config(&self) -> Arc<ProxyConfig> {
self.config.clone()
}
pub fn shutdown(&self) {
let _ = self.shutdown_tx.send(true);
}
pub async fn run(&self) -> Result<()> {
let addr = self.config.server.http_addr;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr,
reason: e.to_string(),
})?;
info!(addr = %addr, "HTTP proxy server listening");
self.accept_loop(listener).await
}
pub async fn run_on(&self, addr: SocketAddr) -> Result<()> {
let listener = TcpListener::bind(addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr,
reason: e.to_string(),
})?;
info!(addr = %addr, "HTTP proxy server listening");
self.accept_loop(listener).await
}
async fn accept_loop(&self, listener: TcpListener) -> Result<()> {
let mut shutdown_rx = self.shutdown_rx.clone();
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
info!("Shutting down proxy server");
break;
}
}
result = listener.accept() => {
match result {
Ok((stream, remote_addr)) => {
let registry = self.registry.clone();
let load_balancer = self.load_balancer.clone();
let config = self.config.clone();
let cert_manager = self.cert_manager.clone();
let npc = self.network_policy_checker.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(
stream,
remote_addr,
registry,
load_balancer,
config,
cert_manager,
npc,
).await {
debug!(
error = %e,
remote_addr = %remote_addr,
"Connection error"
);
}
});
}
Err(e) => {
warn!(error = %e, "Failed to accept connection");
}
}
}
}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn handle_connection(
stream: tokio::net::TcpStream,
remote_addr: SocketAddr,
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
config: Arc<ProxyConfig>,
cert_manager: Option<Arc<CertManager>>,
network_policy_checker: Option<NetworkPolicyChecker>,
) -> Result<()> {
let io = TokioIo::new(stream);
let mut service =
ReverseProxyService::new(registry, load_balancer, config).with_remote_addr(remote_addr);
if let Some(cm) = cert_manager {
service = service.with_cert_manager(cm);
}
if let Some(checker) = network_policy_checker {
service = service.with_network_policy_checker(checker);
}
let service = service_fn(move |req: Request<Incoming>| {
let svc = service.clone();
async move {
match svc.proxy_request(req).await {
Ok(response) => Ok::<_, hyper::Error>(response),
Err(e) => {
error!(error = %e, "Proxy error");
Ok(ReverseProxyService::error_response(&e))
}
}
}
});
http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(false)
.serve_connection(io, service)
.with_upgrades()
.await
.map_err(ProxyError::Hyper)?;
Ok(())
}
pub async fn run_https(&self) -> Result<()> {
let acceptor = self
.tls_acceptor
.as_ref()
.ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
let addr = self.config.server.https_addr;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr,
reason: e.to_string(),
})?;
info!(addr = %addr, "HTTPS proxy server listening");
self.accept_loop_tls(listener, acceptor.clone()).await
}
pub async fn run_https_on(&self, addr: SocketAddr) -> Result<()> {
let acceptor = self
.tls_acceptor
.as_ref()
.ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr,
reason: e.to_string(),
})?;
info!(addr = %addr, "HTTPS proxy server listening");
self.accept_loop_tls(listener, acceptor.clone()).await
}
#[allow(clippy::similar_names)]
pub async fn run_both(&self) -> Result<()> {
let http_addr = self.config.server.http_addr;
let https_addr = self.config.server.https_addr;
let acceptor = self
.tls_acceptor
.as_ref()
.ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
let http_listener =
TcpListener::bind(http_addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr: http_addr,
reason: e.to_string(),
})?;
let https_listener =
TcpListener::bind(https_addr)
.await
.map_err(|e| ProxyError::BindFailed {
addr: https_addr,
reason: e.to_string(),
})?;
info!(http = %http_addr, https = %https_addr, "Proxy server listening");
let http_future = self.accept_loop(http_listener);
let https_future = self.accept_loop_tls(https_listener, acceptor.clone());
tokio::select! {
result = http_future => result,
result = https_future => result,
}
}
async fn accept_loop_tls(&self, listener: TcpListener, acceptor: TlsAcceptor) -> Result<()> {
let mut shutdown_rx = self.shutdown_rx.clone();
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
info!("Shutting down HTTPS proxy server");
break;
}
}
result = listener.accept() => {
match result {
Ok((stream, remote_addr)) => {
let registry = self.registry.clone();
let load_balancer = self.load_balancer.clone();
let config = self.config.clone();
let acceptor = acceptor.clone();
let cert_manager = self.cert_manager.clone();
let npc = self.network_policy_checker.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_tls_connection(
stream,
remote_addr,
registry,
load_balancer,
config,
acceptor,
cert_manager,
npc,
).await {
debug!(
error = %e,
remote_addr = %remote_addr,
"TLS connection error"
);
}
});
}
Err(e) => {
warn!(error = %e, "Failed to accept TLS connection");
}
}
}
}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn handle_tls_connection(
stream: tokio::net::TcpStream,
remote_addr: SocketAddr,
registry: Arc<ServiceRegistry>,
load_balancer: Arc<LoadBalancer>,
config: Arc<ProxyConfig>,
acceptor: TlsAcceptor,
cert_manager: Option<Arc<CertManager>>,
network_policy_checker: Option<NetworkPolicyChecker>,
) -> Result<()> {
let tls_stream = acceptor
.accept(stream)
.await
.map_err(|e| ProxyError::Tls(format!("TLS handshake failed: {e}")))?;
let io = TokioIo::new(tls_stream);
let mut service = ReverseProxyService::new(registry, load_balancer, config)
.with_remote_addr(remote_addr)
.with_tls(true);
if let Some(cm) = cert_manager {
service = service.with_cert_manager(cm);
}
if let Some(checker) = network_policy_checker {
service = service.with_network_policy_checker(checker);
}
let service = service_fn(move |req: Request<Incoming>| {
let svc = service.clone();
async move {
match svc.proxy_request(req).await {
Ok(response) => Ok::<_, hyper::Error>(response),
Err(e) => {
error!(error = %e, "Proxy error");
Ok(ReverseProxyService::error_response(&e))
}
}
}
});
http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(false)
.serve_connection(io, service)
.with_upgrades()
.await
.map_err(ProxyError::Hyper)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lb::LoadBalancer;
    use crate::routes::{ResolvedService, RouteEntry};
    use zlayer_spec::{ExposeType, Protocol};

    /// Builds a public, plain-HTTP route entry for `service` targeting port 8080.
    fn make_entry(
        service: &str,
        host: Option<&str>,
        path: &str,
        backends: Vec<SocketAddr>,
    ) -> RouteEntry {
        RouteEntry {
            service_name: service.to_string(),
            endpoint_name: "http".to_string(),
            host: host.map(std::string::ToString::to_string),
            path_prefix: path.to_string(),
            resolved: ResolvedService {
                name: service.to_string(),
                backends,
                use_tls: false,
                sni_hostname: String::new(),
                expose: ExposeType::Public,
                protocol: Protocol::Http,
                strip_prefix: false,
                path_prefix: path.to_string(),
                target_port: 8080,
            },
        }
    }

    /// Builds a server without TLS over an empty registry.
    fn plain_server() -> ProxyServer {
        let registry = Arc::new(ServiceRegistry::new());
        let lb = Arc::new(LoadBalancer::new());
        ProxyServer::new(ProxyConfig::default(), registry, lb)
    }

    #[tokio::test]
    async fn test_server_shutdown() {
        let server = plain_server();
        server.shutdown();
        assert!(*server.shutdown_rx.borrow());

        // A shutdown requested before the accept loop starts must still stop it,
        // so this returns instead of serving forever.
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        assert!(server.run_on(addr).await.is_ok());
    }

    #[tokio::test]
    async fn test_https_requires_tls() {
        let server = plain_server();
        assert!(!server.has_tls());
        assert!(server.tls_acceptor().is_none());

        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        assert!(matches!(
            server.run_https_on(addr).await,
            Err(ProxyError::Config(_))
        ));
        assert!(matches!(server.run_both().await, Err(ProxyError::Config(_))));
    }

    #[tokio::test]
    async fn test_registry_integration() {
        let registry = Arc::new(ServiceRegistry::new());
        registry
            .register(make_entry(
                "test-service",
                None,
                "/api",
                vec!["127.0.0.1:8081".parse().unwrap()],
            ))
            .await;
        let lb = Arc::new(LoadBalancer::new());
        let server = ProxyServer::new(ProxyConfig::default(), registry, lb);
        let reg = server.registry();
        assert_eq!(reg.route_count().await, 1);
    }
}