pub(crate) mod protocol;
use protocol::{ProtocolContext, WsContext};
use crate::config::{GatewayConfig, Protocol};
use crate::error::{GatewayError, Result};
use crate::middleware::{Pipeline, RequestContext, TcpFilter};
use crate::proxy::tcp;
use crate::proxy::udp::{self, UdpProxyConfig};
use crate::proxy::HttpProxy;
use crate::router::RouterTable;
use crate::scaling::buffer::RequestBuffer;
use crate::scaling::concurrency::ConcurrencyLimiter;
use crate::scaling::revision::RevisionRouter;
use crate::service::passive_health::PassiveHealthCheck;
use crate::service::sticky::StickySessionManager;
use crate::service::ServiceRegistry;
use bytes::Bytes;
use http_body_util::combinators::UnsyncBoxBody;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
/// Boxed response body type shared by every HTTP handler in this module.
type ResponseBody = UnsyncBoxBody<Bytes, std::io::Error>;

/// Wrap a complete (non-streaming) payload in the module's boxed body type.
fn full_body(bytes: impl Into<Bytes>) -> ResponseBody {
    let full = http_body_util::Full::new(bytes.into());
    // `Full`'s error type is `Infallible`; the empty match converts it away.
    full.map_err(|never| match never {}).boxed_unsync()
}
/// Build a JSON error response `{"error":"<message>"}` with the given status.
///
/// The message is escaped for JSON string context so a quote or backslash in
/// `message` cannot produce a malformed body (current call sites pass static
/// strings, but this keeps the helper safe for any input).
fn error_response(status: u16, message: &str) -> hyper::Response<ResponseBody> {
    // Minimal JSON string escaping: backslash first, then double quote.
    let escaped = message.replace('\\', "\\\\").replace('"', "\\\"");
    hyper::Response::builder()
        .status(status)
        .header("Content-Type", "application/json")
        .body(full_body(Bytes::from(format!(
            r#"{{"error":"{}"}}"#,
            escaped
        ))))
        // Header name/value are static and valid; only an out-of-range status
        // code could make the builder fail, and call sites pass constants.
        .expect("static response parts are valid")
}
/// Per-service autoscaling state, keyed by service name.
pub struct ScalingState {
    /// Request buffers that hold requests while a service scales up
    /// (see the scale-from-zero path in `handle_http_request`).
    pub buffers: HashMap<String, Arc<RequestBuffer>>,
    /// Concurrency limiters used for capacity-aware backend selection.
    pub limiters: HashMap<String, Arc<ConcurrencyLimiter>>,
    /// Routers distributing traffic across service revisions.
    pub revision_routers: HashMap<String, Arc<RevisionRouter>>,
}
/// Shared, immutable gateway state handed (via `Arc`) to every listener task
/// and request handler.
pub struct GatewayState {
    /// Matches incoming requests (host/path/method/headers/entrypoint) to routes.
    pub router_table: Arc<RouterTable>,
    /// Service-name -> load balancer lookup.
    pub service_registry: Arc<ServiceRegistry>,
    /// Raw middleware configs, used to build pipelines not found in the cache.
    pub middleware_configs: Arc<HashMap<String, crate::config::MiddlewareConfig>>,
    /// Prebuilt middleware pipelines keyed by router name.
    pub pipeline_cache: Arc<HashMap<String, Arc<Pipeline>>>,
    pub http_proxy: Arc<HttpProxy>,
    pub grpc_proxy: Arc<crate::proxy::grpc::GrpcProxy>,
    /// Present only when autoscaling features are configured.
    pub scaling: Option<Arc<ScalingState>>,
    /// Traffic mirrors keyed by service name (fire-and-forget request copies).
    pub mirrors: HashMap<String, Arc<crate::service::TrafficMirror>>,
    /// Failover backend selectors keyed by service name.
    pub failovers: HashMap<String, Arc<crate::service::FailoverSelector>>,
    pub access_log: Arc<crate::observability::access_log::AccessLog>,
    // Kept alive so the access-log channel stays open; not read directly here.
    #[allow(dead_code)]
    pub log_tx:
        tokio::sync::mpsc::UnboundedSender<crate::observability::access_log::AccessLogEntry>,
    /// Sticky-session managers keyed by service name.
    pub sticky_managers: HashMap<String, Arc<StickySessionManager>>,
    /// Passive health checkers keyed by service name.
    pub passive_health: HashMap<String, Arc<PassiveHealthCheck>>,
    pub metrics: Arc<crate::observability::metrics::GatewayMetrics>,
}
/// Spawn one listener task per configured entrypoint.
///
/// Returns a `JoinHandle` for each spawned listener. Fails fast on the first
/// entrypoint with an unparseable address or a bind/startup error.
pub async fn start_entrypoints(
    config: &GatewayConfig,
    state: Arc<GatewayState>,
    shutdown_rx: tokio::sync::watch::Receiver<bool>,
) -> Result<Vec<tokio::task::JoinHandle<()>>> {
    let mut handles = Vec::with_capacity(config.entrypoints.len());
    for (name, ep) in &config.entrypoints {
        let addr = ep.address.parse::<SocketAddr>().map_err(|e| {
            GatewayError::Config(format!(
                "Invalid address '{}' for entrypoint '{}': {}",
                ep.address, name, e
            ))
        })?;
        // Each protocol gets its own listener implementation; all return a
        // handle for the spawned accept loop.
        let handle = match ep.protocol {
            Protocol::Http => {
                start_http_entrypoint(
                    name.clone(),
                    addr,
                    ep.tls.as_ref(),
                    state.clone(),
                    shutdown_rx.clone(),
                )
                .await?
            }
            Protocol::Tcp => {
                start_tcp_entrypoint(
                    name.clone(),
                    addr,
                    ep.max_connections,
                    &ep.tcp_allowed_ips,
                    state.clone(),
                )
                .await?
            }
            Protocol::Udp => {
                start_udp_entrypoint(
                    name.clone(),
                    addr,
                    ep.udp_session_timeout_secs,
                    ep.udp_max_sessions,
                    state.clone(),
                )
                .await?
            }
        };
        handles.push(handle);
    }
    Ok(handles)
}
/// Bind an HTTP(S) entrypoint and serve connections until shutdown.
///
/// Each accepted connection is served on its own task, with optional TLS
/// termination when `tls_config` is present. On shutdown the accept loop
/// stops and in-flight connections are drained for up to 30 seconds; any
/// connection still open past the deadline is aborted.
async fn start_http_entrypoint(
    name: String,
    addr: SocketAddr,
    tls_config: Option<&crate::config::TlsConfig>,
    state: Arc<GatewayState>,
    mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
) -> Result<tokio::task::JoinHandle<()>> {
    let listener = TcpListener::bind(addr)
        .await
        .map_err(|e| GatewayError::Other(format!("Failed to bind {}: {}", addr, e)))?;
    let tls_acceptor = if let Some(tls) = tls_config {
        Some(crate::proxy::tls::build_tls_acceptor(tls)?)
    } else {
        None
    };
    tracing::info!(
        entrypoint = name,
        address = %addr,
        tls = tls_acceptor.is_some(),
        "HTTP entrypoint listening"
    );
    let ep_name = name.clone();
    let handle = tokio::spawn(async move {
        let mut conn_handles: Vec<tokio::task::JoinHandle<()>> = Vec::new();
        loop {
            // Drop handles of connections that already finished so the vec
            // doesn't grow unboundedly on long-lived listeners.
            conn_handles.retain(|h| !h.is_finished());
            tokio::select! {
                result = listener.accept() => {
                    let (stream, remote_addr) = match result {
                        Ok(conn) => conn,
                        Err(e) => {
                            tracing::error!(error = %e, "Failed to accept connection");
                            continue;
                        }
                    };
                    let state = state.clone();
                    let ep_name = ep_name.clone();
                    let tls_acceptor = tls_acceptor.clone();
                    let conn_handle = tokio::spawn(async move {
                        state.metrics.inc_connections();
                        if let Some(acceptor) = tls_acceptor {
                            match acceptor.accept(stream).await {
                                Ok(tls_stream) => {
                                    let io = TokioIo::new(tls_stream);
                                    // `auto` negotiates HTTP/1.1 vs HTTP/2;
                                    // upgrades enabled for WebSocket support.
                                    let _ = auto::Builder::new(TokioExecutor::new())
                                        .serve_connection_with_upgrades(
                                            io,
                                            service_fn(|req| {
                                                handle_http_request(
                                                    req,
                                                    remote_addr,
                                                    ep_name.clone(),
                                                    state.clone(),
                                                )
                                            }),
                                        )
                                        .await;
                                }
                                Err(e) => {
                                    tracing::debug!(error = %e, "TLS handshake failed");
                                }
                            }
                        } else {
                            let io = TokioIo::new(stream);
                            let _ = auto::Builder::new(TokioExecutor::new())
                                .serve_connection_with_upgrades(
                                    io,
                                    service_fn(|req| {
                                        handle_http_request(
                                            req,
                                            remote_addr,
                                            ep_name.clone(),
                                            state.clone(),
                                        )
                                    }),
                                )
                                .await;
                        }
                        state.metrics.dec_connections();
                    });
                    conn_handles.push(conn_handle);
                }
                _ = shutdown_rx.changed() => {
                    tracing::info!(entrypoint = ep_name, "Shutdown signal received, draining connections");
                    break;
                }
            }
        }
        // Drain phase: give in-flight connections up to 30s total to finish,
        // then abort whatever remains. (Previously the timeout branch broke
        // out of the loop without aborting anything, leaving connections
        // running detached past the drain deadline.)
        let drain_timeout = Duration::from_secs(30);
        let drain_deadline = tokio::time::Instant::now() + drain_timeout;
        for mut handle in conn_handles {
            let remaining = drain_deadline.saturating_duration_since(tokio::time::Instant::now());
            if remaining.is_zero() {
                // Past the deadline: abort without waiting.
                handle.abort();
                continue;
            }
            tokio::select! {
                _ = &mut handle => {}
                _ = tokio::time::sleep(remaining) => {
                    tracing::warn!(entrypoint = ep_name, "Connection drain timeout, aborting remaining");
                    // Abort this connection now; subsequent iterations hit the
                    // zero-remaining branch above and abort immediately too.
                    handle.abort();
                }
            }
        }
    });
    Ok(handle)
}
/// Bind a raw TCP entrypoint and relay each accepted connection to a backend.
///
/// Connections pass an optional connection-count / allow-list filter, are
/// matched against the router table, and are then bidirectionally relayed to
/// the selected backend.
async fn start_tcp_entrypoint(
    name: String,
    addr: SocketAddr,
    max_connections: Option<u32>,
    tcp_allowed_ips: &[String],
    state: Arc<GatewayState>,
) -> Result<tokio::task::JoinHandle<()>> {
    let listener = TcpListener::bind(addr)
        .await
        .map_err(|e| GatewayError::Other(format!("Failed to bind TCP {}: {}", addr, e)))?;
    let tcp_filter = Arc::new(TcpFilter::new(max_connections, tcp_allowed_ips)?);
    tracing::info!(
        entrypoint = name,
        address = %addr,
        max_connections = ?max_connections,
        ip_filter = !tcp_allowed_ips.is_empty(),
        "TCP entrypoint listening"
    );
    let handle = tokio::spawn(async move {
        loop {
            let (client_stream, remote_addr) = match listener.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    tracing::error!(error = %e, "Failed to accept TCP connection");
                    continue;
                }
            };
            // Connection-limit and allow-list check; rejected peers are
            // simply dropped (the stream closes when it goes out of scope).
            let permit = match tcp_filter.check_connection(&remote_addr.ip().to_string()) {
                Ok(p) => p,
                Err(e) => {
                    tracing::debug!(
                        error = %e,
                        remote = %remote_addr,
                        "TCP connection rejected by filter"
                    );
                    continue;
                }
            };
            let task_state = state.clone();
            let entry = name.clone();
            tokio::spawn(async move {
                // Hold the filter permit for the lifetime of the relay.
                let _permit = permit;
                let empty_headers = http::HeaderMap::new();
                let Some(route) = task_state.router_table.match_request(
                    None,
                    "/",
                    "TCP",
                    &empty_headers,
                    &entry,
                ) else {
                    tracing::debug!(
                        remote = %remote_addr,
                        "No TCP route matched"
                    );
                    return;
                };
                // Missing service or no healthy backend: drop the connection
                // without logging, matching the router-level behavior.
                let Some(lb) = task_state.service_registry.get(&route.service_name) else {
                    return;
                };
                let Some(backend) = lb.next_backend() else {
                    return;
                };
                let address = tcp::extract_address(&backend.url);
                let upstream_stream = match tcp::connect_upstream(address).await {
                    Ok(s) => s,
                    Err(e) => {
                        tracing::warn!(
                            error = %e,
                            backend = backend.url,
                            "TCP upstream connection failed"
                        );
                        return;
                    }
                };
                backend.inc_connections();
                let relay_result = tcp::relay_tcp(client_stream, upstream_stream).await;
                backend.dec_connections();
                if let Err(e) = relay_result {
                    tracing::debug!(
                        error = %e,
                        remote = %remote_addr,
                        "TCP relay ended"
                    );
                }
            });
        }
    });
    Ok(handle)
}
/// Bind a UDP entrypoint and proxy datagrams to a single upstream backend.
///
/// NOTE(review): the upstream backend is resolved exactly once, here at
/// startup, and baked into the proxy config — later load-balancing rotation
/// or health-check state changes are not reflected for the lifetime of this
/// entrypoint. Confirm this pinning is intended.
///
/// Errors if the address cannot be bound or if no router/service/backend
/// chain matches this entrypoint at startup.
async fn start_udp_entrypoint(
    name: String,
    addr: SocketAddr,
    session_timeout_secs: Option<u64>,
    max_sessions: Option<usize>,
    state: Arc<GatewayState>,
) -> Result<tokio::task::JoinHandle<()>> {
    // UDP routing has no headers; match with an empty header map.
    let headers = http::HeaderMap::new();
    let upstream_addr = state
        .router_table
        .match_request(None, "/", "UDP", &headers, &name)
        .and_then(|route| state.service_registry.get(&route.service_name))
        .and_then(|lb| lb.next_backend())
        .map(|backend| crate::proxy::tcp::extract_address(&backend.url).to_string())
        .ok_or_else(|| {
            GatewayError::Config(format!(
                "UDP entrypoint '{}' has no matching router/service with a healthy backend",
                name
            ))
        })?;
    // Defaults: 30s idle session timeout, 10k concurrent sessions.
    let timeout = Duration::from_secs(session_timeout_secs.unwrap_or(30));
    let max_sess = max_sessions.unwrap_or(10000);
    let (socket, _) = udp::start_udp_listener(&addr.to_string(), &upstream_addr, timeout).await?;
    let proxy = udp::UdpProxy::new(UdpProxyConfig {
        session_timeout: timeout,
        max_sessions: max_sess,
        upstream_addr: upstream_addr.clone(),
    });
    let proxy = Arc::new(proxy);
    tracing::info!(
        entrypoint = name,
        address = %addr,
        upstream = upstream_addr,
        session_timeout_secs = timeout.as_secs(),
        max_sessions = max_sess,
        "UDP entrypoint listening"
    );
    let handle = tokio::spawn(async move {
        udp::run_udp_proxy(socket, proxy).await;
    });
    Ok(handle)
}
async fn handle_http_request(
req: hyper::Request<Incoming>,
remote_addr: SocketAddr,
entrypoint: String,
state: Arc<GatewayState>,
) -> std::result::Result<hyper::Response<ResponseBody>, hyper::Error> {
let host = req
.headers()
.get("Host")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let path = req.uri().path().to_string();
let method_str = req.method().as_str().to_string();
let _uri = req.uri().clone();
let is_ws = crate::proxy::websocket::is_websocket_upgrade(req.headers());
let is_grpc = crate::proxy::grpc::is_grpc_request(req.headers());
let is_sse = crate::proxy::streaming::is_streaming_request(req.headers());
let access_tracker = state.access_log.start_request();
let trace_ctx = crate::observability::tracing::extract_trace_context(req.headers())
.map(|ctx| ctx.child())
.unwrap_or_else(crate::observability::tracing::TraceContext::new_root);
let route = match state.router_table.match_request(
host.as_deref(),
&path,
&method_str,
req.headers(),
&entrypoint,
) {
Some(route) => route,
None => {
state.metrics.record_request(404, 0);
return Ok(error_response(404, "No route matched"));
}
};
state.metrics.record_router_request(&route.router_name);
state.metrics.record_service_request(&route.service_name);
let request_start = std::time::Instant::now();
let pipeline: Arc<Pipeline> = if let Some(cached) = state.pipeline_cache.get(&route.router_name)
{
cached.clone()
} else {
match Pipeline::from_config(&route.middlewares, &state.middleware_configs) {
Ok(p) => Arc::new(p),
Err(e) => {
tracing::error!(error = %e, "Failed to build middleware pipeline");
return Ok(error_response(500, "Internal server error"));
}
}
};
let ctx = RequestContext {
client_ip: remote_addr.ip().to_string(),
entrypoint: entrypoint.clone(),
router: route.router_name.clone(),
};
if is_ws {
let (mut temp_parts, _) = http::Request::builder()
.method(req.method())
.uri(req.uri())
.version(req.version())
.body(())
.unwrap()
.into_parts();
temp_parts.headers = req.headers().clone();
match pipeline.process_request(&mut temp_parts, &ctx).await {
Ok(Some(response)) => {
let (resp_parts, body) = response.into_parts();
return Ok(hyper::Response::from_parts(resp_parts, full_body(body)));
}
Ok(None) => {}
Err(e) => {
tracing::error!(error = %e, "Middleware error (WebSocket)");
return Ok(error_response(500, "Middleware error"));
}
}
let lb = match state.service_registry.get(&route.service_name) {
Some(lb) => lb,
None => return Ok(error_response(502, "Service not found")),
};
let backend = match lb.next_backend() {
Some(b) => b,
None => return Ok(error_response(503, "No healthy backends")),
};
let ws_ctx = WsContext {
route: route.clone(),
backend: backend.clone(),
pipeline: pipeline.clone(),
state: state.clone(),
remote_addr,
request_start,
};
let (ws_resp, relay_future) = protocol::handle_ws_upgrade(req, ws_ctx);
tokio::spawn(relay_future);
return Ok(ws_resp);
}
let (mut req_parts, body) = req.into_parts();
let body_bytes = match BodyExt::collect(body).await {
Ok(collected) => collected.to_bytes(),
Err(_) => Bytes::new(),
};
match pipeline.process_request(&mut req_parts, &ctx).await {
Ok(Some(response)) => {
let (resp_parts, body) = response.into_parts();
return Ok(hyper::Response::from_parts(resp_parts, full_body(body)));
}
Ok(None) => {}
Err(e) => {
tracing::error!(error = %e, "Middleware error");
return Ok(error_response(500, "Middleware error"));
}
}
let lb = match state.service_registry.get(&route.service_name) {
Some(lb) => lb,
None => {
return Ok(error_response(502, "Service not found"));
}
};
let scaling = state.scaling.as_ref();
let mut sticky_new_session: Option<String> = None;
let backend_from_sticky = state
.sticky_managers
.get(&route.service_name)
.and_then(|mgr| {
let session_id = req_parts
.headers
.get("cookie")
.and_then(|v| v.to_str().ok())
.and_then(|cookie| mgr.extract_session_id(cookie))
.map(|s| s.to_string());
match mgr.select_backend(session_id.as_deref(), lb.backends()) {
Some((backend, new_id)) => {
sticky_new_session = new_id;
Some(backend)
}
None => None,
}
});
let backend = if let Some(b) = backend_from_sticky {
Some(b)
} else if let Some(rev_router) =
scaling.and_then(|s| s.revision_routers.get(&route.service_name))
{
rev_router.next_backend().map(|(b, _rev_name)| b)
} else if let Some(limiter) = scaling.and_then(|s| s.limiters.get(&route.service_name)) {
limiter.select_with_capacity(lb.backends())
} else {
lb.next_backend()
};
let backend = match backend {
Some(b) => b,
None => {
if let Some(buffer) = scaling.and_then(|s| s.buffers.get(&route.service_name)) {
if buffer.needs_scale_up() {
tracing::info!(
service = route.service_name,
"Scale-from-zero triggered, buffering request"
);
}
match buffer.wait_for_backend().await {
crate::scaling::buffer::BufferResult::Ready => match lb.next_backend() {
Some(b) => b,
None => {
return Ok(error_response(503, "No healthy backends after scale-up"));
}
},
crate::scaling::buffer::BufferResult::Timeout => {
return Ok(error_response(504, "Backend scale-up timed out"));
}
crate::scaling::buffer::BufferResult::Overflow => {
return Ok(error_response(503, "Request buffer full"));
}
crate::scaling::buffer::BufferResult::Shutdown => {
return Ok(error_response(503, "Gateway shutting down"));
}
}
} else if let Some(failover) = state.failovers.get(&route.service_name) {
match failover.next_backend() {
Some((b, _is_failover)) => b,
None => {
return Ok(error_response(
503,
"No healthy backends (primary + failover)",
));
}
}
} else {
return Ok(error_response(503, "No healthy backends"));
}
}
};
state.metrics.record_backend_request(&backend.url);
if let Some(mirror) = state.mirrors.get(&route.service_name) {
mirror.mirror_request(
req_parts.method.clone(),
req_parts.uri.clone(),
req_parts.headers.clone(),
body_bytes.clone(),
);
}
let traceparent = trace_ctx.to_traceparent();
if let Ok(hval) = hyper::header::HeaderValue::from_str(&traceparent) {
req_parts
.headers
.insert(hyper::header::HeaderName::from_static("traceparent"), hval);
}
if is_grpc {
let ctx = ProtocolContext {
route,
backend,
req_parts,
body_bytes,
pipeline,
state: state.clone(),
remote_addr,
entrypoint,
trace_ctx,
access_tracker,
method_str,
path,
host,
sticky_new_session,
request_start,
};
return Ok(protocol::handle_grpc_dispatch(ctx, state.grpc_proxy.clone()).await);
}
if is_sse {
let ctx = ProtocolContext {
route,
backend,
req_parts,
body_bytes,
pipeline,
state: state.clone(),
remote_addr,
entrypoint,
trace_ctx,
access_tracker,
method_str,
path,
host,
sticky_new_session,
request_start,
};
return Ok(protocol::handle_sse_dispatch(ctx).await);
}
{
let ctx = ProtocolContext {
route,
backend,
req_parts,
body_bytes,
pipeline,
state: state.clone(),
remote_addr,
entrypoint,
trace_ctx,
access_tracker,
method_str,
path,
host,
sticky_new_session,
request_start,
};
Ok(protocol::handle_http_dispatch(ctx).await)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::EntrypointConfig;
    /// An unparseable entrypoint address must make `start_entrypoints` fail
    /// with a `Config` error mentioning "Invalid address".
    #[test]
    fn test_invalid_address() {
        let config = GatewayConfig {
            entrypoints: {
                let mut m = HashMap::new();
                m.insert(
                    "bad".to_string(),
                    EntrypointConfig {
                        // Deliberately not a valid `SocketAddr`.
                        address: "not-an-address".to_string(),
                        protocol: Protocol::Http,
                        tls: None,
                        max_connections: None,
                        tcp_allowed_ips: vec![],
                        udp_session_timeout_secs: None,
                        udp_max_sessions: None,
                    },
                );
                m
            },
            ..GatewayConfig::default()
        };
        // Minimal empty gateway state; no routers/services are needed because
        // parsing fails before any listener is bound.
        let state = Arc::new(GatewayState {
            router_table: Arc::new(RouterTable::from_config(&HashMap::new()).unwrap()),
            service_registry: Arc::new(ServiceRegistry::from_config(&HashMap::new()).unwrap()),
            middleware_configs: Arc::new(HashMap::new()),
            pipeline_cache: Arc::new(HashMap::new()),
            http_proxy: Arc::new(HttpProxy::new()),
            grpc_proxy: Arc::new(crate::proxy::grpc::GrpcProxy::new()),
            scaling: None,
            mirrors: HashMap::new(),
            failovers: HashMap::new(),
            access_log: Arc::new(crate::observability::access_log::AccessLog::new()),
            log_tx: tokio::sync::mpsc::unbounded_channel().0,
            sticky_managers: HashMap::new(),
            passive_health: HashMap::new(),
            metrics: Arc::new(crate::observability::metrics::GatewayMetrics::new()),
        });
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (_shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let result = rt.block_on(start_entrypoints(&config, state, shutdown_rx));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid address"));
    }
}