use super::body::{HyperResponseBody, StreamBody};
use super::dispatch::RouteClass;
#[cfg(feature = "profiling")]
use super::internal_routes::match_profiling_route;
use super::internal_routes::{
build_internal_handler, invoke_internal_route, match_internal_route_from_path,
};
use super::request::RequestHead;
use super::router::{DispatchResult, GateCheck, ServerDispatch, gate_result};
use super::sse::SseWriter;
#[cfg(feature = "ws")]
use super::ws_proxy::{self, WsUpgrade};
use super::{BufferConfig, Request, Response};
use crate::resource::HealthState;
use crate::runtime_state::RuntimeInner;
use std::sync::Arc;
/// Builds the lookup table of decimal status-code strings for codes `100..600`.
///
/// Entry `i` holds the text of status `100 + i`, so the table has 500 entries.
fn build_status_text_table() -> Box<[Box<str>]> {
    let mut labels: Vec<Box<str>> = Vec::with_capacity(500);
    for code in 100u16..600 {
        labels.push(code.to_string().into_boxed_str());
    }
    labels.into_boxed_slice()
}
#[cfg(feature = "grpc")]
use super::grpc_support::is_grpc_request;
/// Per-connection settings consulted by every request handled on that connection.
///
/// Constructed with `ConnCtx::from_runtime` or `ConnCtx::without_runtime`.
pub(super) struct ConnCtx {
    /// When true, `record_request` emits an `info`-level "request completed" log line.
    pub(super) tracing_enabled: bool,
    /// When `Some`, `record_request` updates the request counter and latency
    /// histogram; only the presence of the handle is checked there.
    pub(super) metrics_handle: Option<metrics_exporter_prometheus::PrometheusHandle>,
    /// Profiling flag copied from the runtime config.
    /// NOTE(review): not read directly in this code; presumably consulted by
    /// `match_profiling_route`, which receives this context — confirm.
    #[cfg(feature = "profiling")]
    pub(super) profiling_enabled: bool,
    /// Upper bound, in bytes, on a buffered request body; bodies over the limit
    /// are answered with `413`.
    pub(super) max_request_body: usize,
    /// Capacity of the channel between an SSE handler's writer and the response body.
    pub(super) sse_buffer_size: usize,
    /// Buffer size passed to `ws_proxy::handle_ws_upgrade`.
    #[cfg(feature = "ws")]
    pub(super) ws_buffer_size: usize,
    /// Health state taken from the runtime, if any.
    /// NOTE(review): not read directly in this code; presumably used by the
    /// internal routes matched through `match_internal_route_from_path` — confirm.
    pub(super) health_state: Option<HealthState>,
    /// Whether the connection is TLS. Propagated into every built request and
    /// used to choose the `https`/`http` scheme for streaming-proxied requests.
    pub(super) is_tls: bool,
}
impl ConnCtx {
    /// Builds a connection context from shared runtime state.
    ///
    /// Tracing/profiling flags come from `rt.config`, the metrics handle and
    /// health state are cloned from `rt`, and buffer sizes come from `buffers`.
    pub(super) fn from_runtime(
        rt: &Arc<RuntimeInner>,
        buffers: BufferConfig,
        is_tls: bool,
    ) -> Self {
        let config = &rt.config;
        Self {
            is_tls,
            health_state: rt.health_state.clone(),
            metrics_handle: rt.metrics_handle.clone(),
            tracing_enabled: config.tracing_enabled,
            #[cfg(feature = "profiling")]
            profiling_enabled: config.profiling_enabled,
            max_request_body: buffers.max_request_body,
            sse_buffer_size: buffers.sse_buffer_size,
            #[cfg(feature = "ws")]
            ws_buffer_size: buffers.ws_buffer_size,
        }
    }

    /// Builds a connection context with no runtime attached.
    ///
    /// Tracing and profiling are off, and there is no metrics handle or health
    /// state. Only the buffer limits and the TLS flag are taken from the arguments.
    pub(super) fn without_runtime(buffers: BufferConfig, is_tls: bool) -> Self {
        Self {
            is_tls,
            health_state: None,
            metrics_handle: None,
            tracing_enabled: false,
            #[cfg(feature = "profiling")]
            profiling_enabled: false,
            max_request_body: buffers.max_request_body,
            sse_buffer_size: buffers.sse_buffer_size,
            #[cfg(feature = "ws")]
            ws_buffer_size: buffers.ws_buffer_size,
        }
    }
}
/// Buffers the request body (bounded by `max_body`) and converts the hyper
/// request into a framework [`Request`] carrying the peer address (when known)
/// and the TLS flag.
///
/// # Errors
/// Returns a ready-to-send response: whatever `collect_body` produces when
/// the body cannot be read, or `405 method not allowed` when
/// `Request::from_hyper` returns `None`.
async fn collect_body_limited(
    hyper_req: hyper::Request<hyper::body::Incoming>,
    max_body: usize,
    remote_addr: Option<std::net::IpAddr>,
    is_tls: bool,
) -> Result<Request, hyper::Response<HyperResponseBody>> {
    let (head_parts, incoming) = hyper_req.into_parts();
    let payload = collect_body(incoming, max_body).await?;
    let Some(mut request) = Request::from_hyper(head_parts, payload) else {
        return Err(to_hyper_full(Response::text_raw(405, "method not allowed")));
    };
    if let Some(peer) = remote_addr {
        request.set_remote_addr(peer);
    }
    request.set_tls(is_tls);
    Ok(request)
}
async fn collect_body(
body: hyper::body::Incoming,
max_body: usize,
) -> Result<bytes::Bytes, hyper::Response<HyperResponseBody>> {
use http_body_util::BodyExt;
let limited = http_body_util::Limited::new(body, max_body);
match limited.collect().await {
Ok(collected) => Ok(collected.to_bytes()),
Err(_) => Err(to_hyper_full(Response::text_raw(
413,
"request body too large",
))),
}
}
/// Builds a body-less [`Request`] from the request head. With the `ws` feature,
/// it also returns the WebSocket upgrade extracted from the request.
///
/// # Errors
/// A boxed `405 method not allowed` response when
/// `RequestHead::from_hyper_request` returns `None`.
#[cfg(feature = "ws")]
fn build_head_only_request_ws(
    mut hyper_req: hyper::Request<hyper::body::Incoming>,
    remote_addr: Option<std::net::IpAddr>,
    is_tls: bool,
) -> Result<(Request, WsUpgrade), Box<hyper::Response<HyperResponseBody>>> {
    // Extracted before the head is parsed; extraction needs `&mut` access.
    let upgrade = ws_proxy::extract_ws_upgrade(&mut hyper_req);
    match RequestHead::from_hyper_request(&hyper_req, remote_addr, is_tls) {
        Some(head) => Ok((head.to_request(None), upgrade)),
        None => Err(Box::new(to_hyper_full(Response::text_raw(
            405,
            "method not allowed",
        )))),
    }
}
/// Builds a body-less [`Request`] from the request head (used when the `ws`
/// feature is disabled).
///
/// # Errors
/// A boxed `405 method not allowed` response when
/// `RequestHead::from_hyper_request` returns `None`.
#[cfg(not(feature = "ws"))]
fn build_head_only_request(
    hyper_req: hyper::Request<hyper::body::Incoming>,
    remote_addr: Option<std::net::IpAddr>,
    is_tls: bool,
) -> Result<Request, Box<hyper::Response<HyperResponseBody>>> {
    let Some(head) = RequestHead::from_hyper_request(&hyper_req, remote_addr, is_tls) else {
        let rejection = to_hyper_full(Response::text_raw(405, "method not allowed"));
        return Err(Box::new(rejection));
    };
    Ok(head.to_request(None))
}
/// Extracts the WebSocket upgrade handle, then buffers the body and converts
/// the request via `collect_body_limited` (used when the `ws` feature is enabled).
///
/// # Errors
/// Whatever response `collect_body_limited` returns; the upgrade handle is
/// dropped in that case.
#[cfg(feature = "ws")]
async fn collect_request(
    mut hyper_req: hyper::Request<hyper::body::Incoming>,
    max_body: usize,
    remote_addr: Option<std::net::IpAddr>,
    is_tls: bool,
) -> Result<(Request, WsUpgrade), hyper::Response<HyperResponseBody>> {
    let upgrade = ws_proxy::extract_ws_upgrade(&mut hyper_req);
    collect_body_limited(hyper_req, max_body, remote_addr, is_tls)
        .await
        .map(|request| (request, upgrade))
}
/// Buffers the body and converts the request via `collect_body_limited`
/// (used when the `ws` feature is disabled).
///
/// # Errors
/// Whatever response `collect_body_limited` returns.
#[cfg(not(feature = "ws"))]
async fn collect_request(
    hyper_req: hyper::Request<hyper::body::Incoming>,
    max_body: usize,
    remote_addr: Option<std::net::IpAddr>,
    is_tls: bool,
) -> Result<Request, hyper::Response<HyperResponseBody>> {
    let request = collect_body_limited(hyper_req, max_body, remote_addr, is_tls).await?;
    Ok(request)
}
/// Handles one HTTP request received on a connection and produces its response.
///
/// Processing order:
/// 1. With the `grpc` feature, `try_dispatch_grpc` looks at the request first.
///    It hands the request back unchanged when it does not handle it as gRPC.
/// 2. The request head is parsed once. It is used to classify the route
///    (`dispatch.classify_route`) and to look for an internal route (and, with
///    the `profiling` feature, a profiling route). A matched internal route is
///    served by `dispatch_internal_head_only` without reading the body.
/// 3. `StreamingProxy` routes that are not WebSocket upgrades are forwarded by
///    `dispatch_streaming_proxy` with the body streamed.
///    `StreamingProxyUnhealthy` routes get `503 service unavailable`.
/// 4. Otherwise a [`Request`] is built:
///    * head-only for `HeadOnly` routes and for WebSocket upgrades on
///      `StreamingProxy` routes;
///    * body-buffered (bounded by `ctx.max_request_body`) for everything else.
///
///    A build failure response is returned as-is.
/// 5. The request is dispatched. If the result needs it, the middleware gate
///    runs and may short-circuit. With the `ws` feature, WebSocket results that
///    fail `ws_proxy::check_ws_origin` are rejected.
/// 6. The dispatch result is converted into a hyper response.
///
/// Most outcomes are reported through `record_request`. The `503` path and
/// request-build failures (step 4) return without being recorded. For
/// dispatched requests, latency is measured from after the request was built,
/// so body-collection time is not included.
///
/// The error type is `Infallible`: every failure is expressed as an HTTP response.
pub(super) async fn handle_request(
    hyper_req: hyper::Request<hyper::body::Incoming>,
    dispatch: &ServerDispatch,
    ctx: &ConnCtx,
    remote_addr: Option<std::net::IpAddr>,
) -> Result<hyper::Response<HyperResponseBody>, std::convert::Infallible> {
    // gRPC requests are answered here; anything else comes back unchanged.
    #[cfg(feature = "grpc")]
    let hyper_req = match try_dispatch_grpc(hyper_req, dispatch, remote_addr, ctx.is_tls).await {
        Ok(resp) => return resp,
        Err(req) => req,
    };
    // Parse the head once for route classification and internal-route lookup.
    // An unparseable head falls back to the buffered path.
    // NOTE(review): that path presumably ends in 405 from `Request::from_hyper`
    // inside `collect_body_limited` — confirm.
    let (route_class, internal_route, pre_method) =
        match RequestHead::from_hyper_request(&hyper_req, remote_addr, ctx.is_tls) {
            Some(head) => {
                let rc = dispatch.classify_route(&head);
                let ir = match_internal_route_from_path(head.path(), ctx);
                #[cfg(feature = "profiling")]
                let ir = ir.or_else(|| match_profiling_route(head.path(), head.query(), ctx));
                let m = head.method();
                (rc, ir, Some(m))
            }
            None => (RouteClass::Buffered, None, None),
        };
    // Internal routes are served without reading the request body.
    if let Some(route) = internal_route {
        return dispatch_internal_head_only(
            &hyper_req,
            route,
            dispatch,
            ctx,
            remote_addr,
            // Always `Some` here: an internal route is only matched when the head parsed.
            pre_method.unwrap_or(super::method::Method::Get),
        )
        .await;
    }
    // Only the `Upgrade` header is inspected at this point.
    #[cfg(feature = "ws")]
    let is_ws_upgrade = hyper_req
        .headers()
        .get("upgrade")
        .and_then(|v| v.to_str().ok())
        .is_some_and(|v| v.eq_ignore_ascii_case("websocket"));
    #[cfg(not(feature = "ws"))]
    let is_ws_upgrade = false;
    // Head-only routes and WebSocket upgrades to a streaming proxy never read the body.
    let skip_body_collection = matches!(route_class, RouteClass::HeadOnly)
        || (matches!(&route_class, RouteClass::StreamingProxy { .. }) && is_ws_upgrade);
    match route_class {
        // Plain (non-upgrade) streaming-proxy traffic streams the body upstream.
        RouteClass::StreamingProxy {
            backend,
            prefix,
            params,
        } if !is_ws_upgrade => {
            return dispatch_streaming_proxy(
                hyper_req,
                dispatch,
                ctx,
                remote_addr,
                &backend,
                &prefix,
                params,
            )
            .await;
        }
        RouteClass::StreamingProxyUnhealthy => {
            return Ok(to_hyper_full(Response::text_raw(
                503,
                "service unavailable",
            )));
        }
        // Every other route class continues to the generic dispatch below.
        _ => {}
    }
    #[cfg(feature = "ws")]
    let build_result: Result<(Request, WsUpgrade), hyper::Response<HyperResponseBody>> =
        match skip_body_collection {
            true => build_head_only_request_ws(hyper_req, remote_addr, ctx.is_tls).map_err(|b| *b),
            false => {
                collect_request(hyper_req, ctx.max_request_body, remote_addr, ctx.is_tls).await
            }
        };
    #[cfg(feature = "ws")]
    let (req, ws_upgrade) = match build_result {
        Ok(pair) => pair,
        Err(resp) => return Ok(resp),
    };
    #[cfg(not(feature = "ws"))]
    let build_result: Result<Request, hyper::Response<HyperResponseBody>> =
        match skip_body_collection {
            true => build_head_only_request(hyper_req, remote_addr, ctx.is_tls).map_err(|b| *b),
            false => {
                collect_request(hyper_req, ctx.max_request_body, remote_addr, ctx.is_tls).await
            }
        };
    #[cfg(not(feature = "ws"))]
    let req = match build_result {
        Ok(r) => r,
        Err(resp) => return Ok(resp),
    };
    // Latency is measured from here, after the body (if any) was collected.
    let start = std::time::Instant::now();
    let result = dispatch.dispatch(req);
    let gate_check = match result.needs_middleware_gate() {
        true => dispatch.middleware_gate(result.request_ref()),
        false => None,
    };
    let gate_blocked = match gate_check {
        None => None,
        Some(GateCheck { reached, fut }) => gate_result(reached, fut.await),
    };
    if let Some(blocked) = gate_blocked {
        let req = result.request_ref();
        record_request(ctx, req.method(), req.path(), blocked.status(), start);
        return Ok(to_hyper_full(blocked));
    }
    // WebSocket results must pass the origin check before upgrading.
    #[cfg(feature = "ws")]
    if let Some(rejected) = result
        .is_websocket()
        .then(|| ws_proxy::check_ws_origin(result.request_ref()))
        .flatten()
    {
        let req = result.request_ref();
        record_request(ctx, req.method(), req.path(), rejected.status(), start);
        return Ok(to_hyper_full(rejected));
    }
    match result {
        DispatchResult::Async(fut, req) => {
            let resp = strip_body_if_head(req.is_head(), fut.await);
            record_request(ctx, req.method(), req.path(), resp.status(), start);
            Ok(to_hyper_full(resp))
        }
        DispatchResult::Stream(fut, req) => handle_stream_response(fut.await, req, ctx, start),
        // SSE and WebSocket responses are recorded with their nominal status up front.
        DispatchResult::Sse(handler, req) => {
            record_request(ctx, req.method(), req.path(), 200, start);
            handle_sse(handler, req, ctx.sse_buffer_size)
        }
        #[cfg(feature = "ws")]
        DispatchResult::WebSocket(handler, req) => {
            record_request(ctx, req.method(), req.path(), 101, start);
            ws_proxy::handle_ws_upgrade(ws_upgrade, handler, req, ctx.ws_buffer_size)
        }
        #[cfg(feature = "ws")]
        DispatchResult::ProxyWebSocket(req, backend, prefix) => {
            record_request(ctx, req.method(), req.path(), 101, start);
            ws_proxy::handle_proxy_ws(ws_upgrade, req, backend, prefix)
        }
        DispatchResult::ProxyStream(req, backend, prefix) => {
            handle_proxy_stream_response(req, &backend, &prefix, ctx, start).await
        }
    }
}
/// Serves a matched internal route without reading the request body.
///
/// When `dispatch.skip_middleware_for_internal()` is true, the route is invoked
/// directly, its body is stripped for `HEAD`, and the request is recorded.
/// Otherwise the route goes through the middleware chain via
/// `dispatch_internal_through_middleware`.
async fn dispatch_internal_head_only(
    hyper_req: &hyper::Request<hyper::body::Incoming>,
    route: super::internal_routes::InternalRoute,
    dispatch: &ServerDispatch,
    ctx: &ConnCtx,
    remote_addr: Option<std::net::IpAddr>,
    method: super::method::Method,
) -> Result<hyper::Response<HyperResponseBody>, std::convert::Infallible> {
    let started_at = std::time::Instant::now();
    if !dispatch.skip_middleware_for_internal() {
        return dispatch_internal_through_middleware(
            hyper_req,
            route,
            dispatch,
            ctx,
            remote_addr,
            started_at,
        )
        .await;
    }
    let head_request = matches!(method, super::method::Method::Head);
    let response = strip_body_if_head(head_request, invoke_internal_route(&route));
    let path = hyper_req.uri().path();
    record_request(ctx, method.as_str(), path, response.status(), started_at);
    Ok(to_hyper_full(response))
}
/// Runs an internal route through the regular middleware chain.
///
/// The route is wrapped by `build_internal_handler` and dispatched with a
/// body-less request built from the head. Only an `Async` dispatch result is
/// expected. It is answered (with the body stripped for `HEAD`) and recorded.
/// Any other result yields `500 internal dispatch error` without being
/// recorded. An unparseable head yields `405 method not allowed`.
async fn dispatch_internal_through_middleware(
    hyper_req: &hyper::Request<hyper::body::Incoming>,
    route: super::internal_routes::InternalRoute,
    dispatch: &ServerDispatch,
    ctx: &ConnCtx,
    remote_addr: Option<std::net::IpAddr>,
    start: std::time::Instant,
) -> Result<hyper::Response<HyperResponseBody>, std::convert::Infallible> {
    let Some(head) = RequestHead::from_hyper_request(hyper_req, remote_addr, ctx.is_tls) else {
        return Ok(to_hyper_full(Response::text_raw(405, "method not allowed")));
    };
    let request = head.to_request(None);
    let internal_handler = build_internal_handler(route);
    if let DispatchResult::Async(pending, dispatched) =
        dispatch.dispatch_with_handler(&internal_handler, request)
    {
        let response = strip_body_if_head(dispatched.is_head(), pending.await);
        record_request(
            ctx,
            dispatched.method(),
            dispatched.path(),
            response.status(),
            start,
        );
        return Ok(to_hyper_full(response));
    }
    Ok(to_hyper_full(Response::text_raw(
        500,
        "internal dispatch error",
    )))
}
/// Starts a server-sent-events response.
///
/// The handler runs in a task started by `crate::task::spawn`. It writes
/// through an [`SseWriter`] into a bounded channel of `buffer_size` entries,
/// and the receiving half becomes the streaming response body. A handler error
/// is logged at `warn` level.
///
/// NOTE(review): unlike the other streaming paths, `HEAD` requests are not
/// special-cased here.
fn handle_sse(
    handler: super::router::SseHandler,
    req: Request,
    buffer_size: usize,
) -> Result<hyper::Response<HyperResponseBody>, std::convert::Infallible> {
    let (sender, receiver) = tokio::sync::mpsc::channel::<bytes::Bytes>(buffer_size);
    // Keep the spawn result bound (not `_`) until the end of this function.
    let _sse_task = crate::task::spawn(move || {
        let mut event_writer = SseWriter::new(sender);
        let outcome = handler(&req, &mut event_writer);
        if let Err(e) = outcome {
            tracing::warn!(error = %e, "SSE handler returned error");
        }
    });
    let response_head = hyper::Response::builder()
        .status(200)
        .header("Content-Type", "text/event-stream")
        .header("Cache-Control", "no-cache");
    Ok(streaming_response_or_empty(
        response_head,
        StreamBody { rx: receiver },
    ))
}
/// Returns a streaming body that ends immediately.
///
/// The channel's only sender is dropped before the body is returned, so the
/// receiver yields no data.
fn empty_stream_body() -> StreamBody {
    let (sender, rx) = tokio::sync::mpsc::channel(1);
    drop(sender);
    StreamBody { rx }
}
/// Finishes `builder` with a streaming `body`.
///
/// The builder can be in an error state, for example when an invalid status
/// code or header was supplied earlier in the chain. In that case the error is
/// logged and an empty streaming response with status
/// `500 Internal Server Error` is returned instead. `500` is set explicitly,
/// rather than relying on `hyper::Response::new`'s default `200 OK`, so the
/// client does not see the failure as success.
fn streaming_response_or_empty(
    builder: hyper::http::response::Builder,
    body: StreamBody,
) -> hyper::Response<HyperResponseBody> {
    builder
        .body(HyperResponseBody::Streaming(body))
        .unwrap_or_else(|err| {
            tracing::error!("failed to build streaming response: {err}");
            let mut fallback =
                hyper::Response::new(HyperResponseBody::Streaming(empty_stream_body()));
            *fallback.status_mut() = hyper::http::StatusCode::INTERNAL_SERVER_ERROR;
            fallback
        })
}
/// Converts a framework [`Response`] into a hyper response with a fully
/// buffered body. The head produced by `Response::into_hyper` (status,
/// headers, extensions) is kept unchanged.
pub(super) fn to_hyper_full(resp: Response) -> hyper::Response<HyperResponseBody> {
    resp.into_hyper().map(HyperResponseBody::Full)
}
/// Returns `resp` with its body removed when answering a `HEAD` request, or
/// unchanged otherwise.
fn strip_body_if_head(is_head: bool, resp: Response) -> Response {
    if is_head {
        resp.strip_body()
    } else {
        resp
    }
}
/// Converts a handler's streaming response into a hyper streaming response.
///
/// Status and headers are copied over, and the request is recorded with the
/// handler's status. `HEAD` requests get an empty stream in place of the
/// handler's body channel.
fn handle_stream_response(
    stream_resp: super::stream::StreamResponse,
    req: Request,
    ctx: &ConnCtx,
    start: std::time::Instant,
) -> Result<hyper::Response<HyperResponseBody>, std::convert::Infallible> {
    let head_only = req.is_head();
    let pieces = stream_resp.into_parts();
    record_request(ctx, req.method(), req.path(), pieces.status, start);
    let mut response = hyper::Response::builder().status(pieces.status);
    for (header_name, header_value) in &pieces.headers {
        response = response.header(header_name.as_ref(), header_value.as_ref());
    }
    let body = if head_only {
        empty_stream_body()
    } else {
        StreamBody { rx: pieces.rx }
    };
    Ok(streaming_response_or_empty(response, body))
}
/// Forwards a dispatched [`Request`] to `backend` with
/// `async_proxy::forward_request_streaming` and streams the upstream response back.
///
/// If the upstream call fails, the error is logged and the client gets
/// `502 proxy upstream failed`. Otherwise the upstream status and headers are
/// copied, and `HEAD` requests get an empty body. Both outcomes are recorded.
async fn handle_proxy_stream_response(
    req: Request,
    backend: &str,
    prefix: &str,
    ctx: &ConnCtx,
    start: std::time::Instant,
) -> Result<hyper::Response<HyperResponseBody>, std::convert::Infallible> {
    let outbound = super::async_proxy::ProxyRequest::from_request(&req);
    let head_only = req.is_head();
    let forwarded =
        super::async_proxy::forward_request_streaming(outbound, backend, prefix).await;
    let upstream = match forwarded {
        Err(e) => {
            tracing::warn!(error = %e, "streaming proxy upstream failed");
            record_request(ctx, req.method(), req.path(), 502, start);
            return Ok(to_hyper_full(Response::text_raw(
                502,
                "proxy upstream failed",
            )));
        }
        Ok(up) => up,
    };
    let status = upstream.status;
    record_request(ctx, req.method(), req.path(), status, start);
    let mut response = hyper::Response::builder().status(status);
    for (header_name, header_value) in upstream.headers.iter() {
        response = response.header(header_name.as_ref(), header_value.as_ref());
    }
    let body = if head_only {
        empty_stream_body()
    } else {
        StreamBody { rx: upstream.rx }
    };
    Ok(streaming_response_or_empty(response, body))
}
/// Forwards a streaming-proxy route to `backend`, streaming the request body
/// upstream instead of buffering it.
///
/// Steps:
/// * The head-only middleware gate (`run_head_gate`, given the matched route
///   `params`) runs first and may short-circuit the request.
/// * Otherwise the request goes out via `async_proxy::forward_incoming_streaming`.
/// * An upstream failure is logged and answered with `502 proxy upstream failed`.
/// * For `HEAD` requests the upstream body is replaced with an empty stream.
///
/// Every outcome is recorded with `record_request`.
async fn dispatch_streaming_proxy(
    hyper_req: hyper::Request<hyper::body::Incoming>,
    dispatch: &ServerDispatch,
    ctx: &ConnCtx,
    remote_addr: Option<std::net::IpAddr>,
    backend: &str,
    prefix: &str,
    params: super::request::Params,
) -> Result<hyper::Response<HyperResponseBody>, std::convert::Infallible> {
    use super::method::Method;

    let start = std::time::Instant::now();
    // Methods `Method::from_hyper` does not recognise fall back to GET.
    let method = Method::from_hyper(hyper_req.method()).unwrap_or(Method::Get);
    let method_str = method.as_str();
    let head_only = matches!(method, Method::Head);

    if let Some(blocked) =
        run_head_gate(&hyper_req, dispatch, remote_addr, ctx.is_tls, Some(params)).await
    {
        let blocked_path = hyper_req.uri().path();
        record_request(ctx, method_str, blocked_path, blocked.status(), start);
        return Ok(to_hyper_full(blocked));
    }

    let (head, incoming) = hyper_req.into_parts();
    let path: Box<str> = head.uri.path().into();
    let proxy_parts = super::async_proxy::IncomingProxyParts {
        method,
        path_and_query: head
            .uri
            .path_and_query()
            .map_or("/", |pq| pq.as_str())
            .into(),
        headers: head.headers,
        remote_addr,
        scheme: if ctx.is_tls { "https" } else { "http" },
    };

    let forwarded =
        super::async_proxy::forward_incoming_streaming(proxy_parts, incoming, backend, prefix)
            .await;
    let upstream = match forwarded {
        Err(e) => {
            tracing::warn!(error = %e, "streaming proxy upstream failed");
            record_request(ctx, method_str, &path, 502, start);
            return Ok(to_hyper_full(Response::text_raw(
                502,
                "proxy upstream failed",
            )));
        }
        Ok(up) => up,
    };

    let status = upstream.status;
    record_request(ctx, method_str, &path, status, start);
    let mut response = hyper::Response::builder().status(status);
    for (header_name, header_value) in upstream.headers.iter() {
        response = response.header(header_name.as_ref(), header_value.as_ref());
    }
    let response_body = if head_only {
        empty_stream_body()
    } else {
        StreamBody { rx: upstream.rx }
    };
    Ok(streaming_response_or_empty(response, response_body))
}
/// Records a finished request as a log line and/or metrics.
///
/// * When `ctx.tracing_enabled`, logs `method`, `path`, `status` and the
///   latency in milliseconds at `info` level.
/// * When `ctx.metrics_handle` is set, increments `http_requests_total` and
///   records `http_request_duration_seconds`. Both are labelled by `method` and
///   `status`; the path is not used as a metric label.
///
/// The latency is `start.elapsed()`, read once so the log and the histogram agree.
fn record_request(
    ctx: &ConnCtx,
    method: &'static str,
    path: &str,
    status: u16,
    start: std::time::Instant,
) {
    let elapsed = start.elapsed();
    if ctx.tracing_enabled {
        let latency_ms = elapsed.as_millis();
        tracing::info!(method, path, status, latency_ms, "request completed");
    }
    if ctx.metrics_handle.is_none() {
        return;
    }
    let status_label = status_to_label(status);
    metrics::counter!(
        "http_requests_total",
        "method" => method,
        "status" => status_label,
    )
    .increment(1);
    metrics::histogram!(
        "http_request_duration_seconds",
        "method" => method,
        "status" => status_label,
    )
    .record(elapsed.as_secs_f64());
}
/// Maps an HTTP status code to a `'static` label string for metrics.
///
/// Common codes map to string literals. Any other code in `100..600` is read
/// from a table built lazily by `build_status_text_table`. Codes outside that
/// range become `"unknown"`.
fn status_to_label(status: u16) -> &'static str {
    static TABLE: std::sync::OnceLock<Box<[Box<str>]>> = std::sync::OnceLock::new();
    if !(100..600).contains(&status) {
        return "unknown";
    }
    match status {
        200 => "200",
        201 => "201",
        204 => "204",
        301 => "301",
        302 => "302",
        304 => "304",
        400 => "400",
        401 => "401",
        403 => "403",
        404 => "404",
        405 => "405",
        413 => "413",
        500 => "500",
        502 => "502",
        503 => "503",
        _ => {
            let labels = TABLE.get_or_init(build_status_text_table);
            &labels[usize::from(status - 100)]
        }
    }
}
/// Dispatches the request to the gRPC router when one is configured and the
/// request is recognised by `is_grpc_request`.
///
/// Returns `Ok(response)` when the request was handled as gRPC. Otherwise it
/// returns `Err(hyper_req)`, handing the untouched request back for normal
/// processing.
#[cfg(feature = "grpc")]
async fn try_dispatch_grpc(
    hyper_req: hyper::Request<hyper::body::Incoming>,
    dispatch: &ServerDispatch,
    remote_addr: Option<std::net::IpAddr>,
    is_tls: bool,
) -> Result<
    Result<hyper::Response<HyperResponseBody>, std::convert::Infallible>,
    hyper::Request<hyper::body::Incoming>,
> {
    if dispatch.grpc_router().is_none() || !is_grpc_request(&hyper_req) {
        return Err(hyper_req);
    }
    Ok(dispatch_grpc_inner(hyper_req, dispatch, remote_addr, is_tls).await)
}
/// Runs the head-only middleware gate for a gRPC request and, if it is not
/// blocked, hands the request to the gRPC router.
///
/// Responds `500 grpc router missing` if no router is configured.
#[cfg(feature = "grpc")]
async fn dispatch_grpc_inner(
    hyper_req: hyper::Request<hyper::body::Incoming>,
    dispatch: &ServerDispatch,
    remote_addr: Option<std::net::IpAddr>,
    is_tls: bool,
) -> Result<hyper::Response<HyperResponseBody>, std::convert::Infallible> {
    let Some(router) = dispatch.grpc_router() else {
        return Ok(to_hyper_full(Response::text_raw(
            500,
            "grpc router missing",
        )));
    };
    if let Some(rejection) = run_head_gate(&hyper_req, dispatch, remote_addr, is_tls, None).await
    {
        return Ok(to_hyper_full(rejection));
    }
    router.dispatch(hyper_req).await
}
async fn run_head_gate(
hyper_req: &hyper::Request<hyper::body::Incoming>,
dispatch: &ServerDispatch,
remote_addr: Option<std::net::IpAddr>,
is_tls: bool,
params: Option<super::request::Params>,
) -> Option<Response> {
let head = RequestHead::from_hyper_request(hyper_req, remote_addr, is_tls)?;
let GateCheck { reached, fut } = dispatch.middleware_gate_head(&head, params)?;
gate_result(reached, fut.await)
}