use std::convert::Infallible;
use std::time::Duration;
use bytes::Bytes;
use tower::ServiceBuilder;
use crate::http::body::BoxError;
use crate::Body;
/// Settings for response compression, turned into a tower-http layer by
/// `build_compression_layer`.
#[derive(Debug, Clone)]
pub struct CompressionOptions {
    /// Size threshold (in bytes) handed to the `SizeAbove` predicate.
    /// Values larger than `u16::MAX` are clamped to `u16::MAX` when the layer is built.
    pub min_body_bytes: usize,
    /// Explicit encoder quality; `None` leaves the layer's default quality untouched.
    pub level: Option<u32>,
}

impl CompressionOptions {
    /// Default value of [`CompressionOptions::min_body_bytes`] (1 KiB).
    pub const DEFAULT_MIN_BODY_BYTES: usize = 1 << 10;
}

impl Default for CompressionOptions {
    /// Default threshold, default encoder quality.
    fn default() -> Self {
        Self {
            level: None,
            min_body_bytes: Self::DEFAULT_MIN_BODY_BYTES,
        }
    }
}
/// Type-erased, cloneable server-side service: `Request<Body>` in,
/// `Response<Body>` out. The error type is `Infallible`, so any layer that can
/// fail must have its errors converted before the stack is boxed into this type.
pub(crate) type ServeService =
tower::util::BoxCloneService<hyper::Request<Body>, hyper::Response<Body>, Infallible>;
/// Type-erased client-side service (not `Clone`, unlike `ServeService`).
/// Failures surface as `hyper::Error`; in this file it is produced by
/// `build_client_stack` from an HTTP/1 `SendRequest`.
pub(crate) type ClientService =
tower::util::BoxService<hyper::Request<Body>, hyper::Response<Body>, hyper::Error>;
/// Tunables for the middleware assembled by `build_stack` (server side) and,
/// for `decompression` only, by `build_client_stack` (client side).
///
/// The struct is `#[non_exhaustive]`: code outside this crate starts from
/// [`StackConfig::default`] and then edits fields or chains the `with_*` helpers.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct StackConfig {
    /// Deadline for the wrapped service to produce a response; `None` disables it.
    pub timeout: Option<Duration>,
    /// Cap on the request body as received, applied outside decompression;
    /// `None` means no cap.
    pub max_request_body_wire_bytes: Option<usize>,
    /// Cap on the request body after decompression, applied right around the
    /// handler; `None` means no cap.
    pub max_request_body_decoded_bytes: Option<usize>,
    /// Wrap the stack in tower's `LoadShedLayer`.
    pub load_shed: bool,
    /// Response compression settings; `None` disables response compression.
    pub compression: Option<CompressionOptions>,
    /// Decode compressed request bodies (server) and response bodies (client).
    pub decompression: bool,
}

impl Default for StackConfig {
    /// Every feature off except decompression.
    fn default() -> Self {
        Self {
            decompression: true,
            compression: None,
            load_shed: false,
            max_request_body_decoded_bytes: None,
            max_request_body_wire_bytes: None,
            timeout: None,
        }
    }
}

impl StackConfig {
    /// Returns this config with `timeout` replaced (`None` disables the timeout layer).
    #[must_use]
    pub fn with_timeout(self, timeout: Option<Duration>) -> Self {
        Self { timeout, ..self }
    }

    /// Returns this config with request/response decompression switched on or off.
    #[must_use]
    pub fn with_decompression(self, decompression: bool) -> Self {
        Self {
            decompression,
            ..self
        }
    }
}
/// Wraps `svc` in the server middleware selected by `cfg`.
///
/// Each `apply_*` call wraps everything built so far, so the last call is the
/// outermost layer. A request therefore travels:
/// wire body limit → load shed → timeout → response compression →
/// request decompression → decoded body limit → `svc`.
/// Limiting on both sides of decompression bounds the raw request body as well
/// as the inflated body the handler actually reads.
pub(crate) fn build_stack(svc: ServeService, cfg: &StackConfig) -> ServeService {
    let mut stack = svc;
    // Innermost: caps the body the handler sees once it has been decompressed.
    stack = apply_body_limit(stack, cfg.max_request_body_decoded_bytes);
    stack = apply_decompression(stack, cfg.decompression);
    stack = apply_compression(stack, cfg.compression.as_ref());
    stack = apply_timeout(stack, cfg.timeout);
    stack = apply_load_shed(stack, cfg.load_shed);
    // Outermost: caps the request body before any decompression happens.
    apply_body_limit(stack, cfg.max_request_body_wire_bytes)
}
fn renormalize_body<B>(body: B) -> Body
where
B: http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,
{
Body::new(body)
}
/// Caps request bodies reaching `svc` at `limit` bytes; returns `svc`
/// untouched when `limit` is `None`.
///
/// tower-http's `RequestBodyLimitLayer` changes both the request body type and
/// the response body type. The request side is therefore mapped back to `Body`
/// before reaching `svc`, and the response side is mapped back before boxing.
pub(crate) fn apply_body_limit(svc: ServeService, limit: Option<usize>) -> ServeService {
    use tower::ServiceExt;
    use tower_http::limit::RequestBodyLimitLayer;
    use tower_http::map_request_body::MapRequestBodyLayer;
    use tower_http::map_response_body::MapResponseBodyLayer;

    match limit {
        None => svc,
        Some(max_bytes) => ServiceBuilder::new()
            .layer(MapResponseBodyLayer::new(renormalize_body))
            .layer(RequestBodyLimitLayer::new(max_bytes))
            .layer(MapRequestBodyLayer::new(renormalize_body))
            .service(svc)
            .boxed_clone(),
    }
}
/// Adds tower's load-shedding layer around `svc` when `enabled`; otherwise
/// returns `svc` as is.
///
/// `LoadShedLayer` can fail calls, so `HandleLayerErrorLayer` (from
/// `crate::http::server`) sits outside it to map those errors. This keeps the
/// boxed service `Infallible`; see that layer for the exact response produced.
pub(crate) fn apply_load_shed(svc: ServeService, enabled: bool) -> ServeService {
    use crate::http::server::HandleLayerErrorLayer;
    use tower::load_shed::LoadShedLayer;
    use tower::ServiceExt;

    if enabled {
        ServiceBuilder::new()
            .layer(HandleLayerErrorLayer)
            .layer(LoadShedLayer::new())
            .service(svc)
            .boxed_clone()
    } else {
        svc
    }
}
/// Bounds how long `svc` may take to produce a response; returns `svc`
/// unchanged when `timeout` is `None`.
///
/// `TimeoutLayer` reports an expired deadline as an error, so
/// `HandleLayerErrorLayer` wraps it to keep the boxed service `Infallible`.
pub(crate) fn apply_timeout(svc: ServeService, timeout: Option<Duration>) -> ServeService {
    use crate::http::server::HandleLayerErrorLayer;
    use tower::timeout::TimeoutLayer;
    use tower::ServiceExt;

    match timeout {
        None => svc,
        Some(deadline) => ServiceBuilder::new()
            .layer(HandleLayerErrorLayer)
            .layer(TimeoutLayer::new(deadline))
            .service(svc)
            .boxed_clone(),
    }
}
/// Adds response compression configured by `comp`; returns `svc` unchanged
/// when `comp` is `None`.
///
/// The compression layer wraps response bodies in its own body type, which is
/// mapped back to `Body` so the result can be boxed as a `ServeService`.
pub(crate) fn apply_compression(
    svc: ServeService,
    comp: Option<&CompressionOptions>,
) -> ServeService {
    use tower::ServiceExt;
    use tower_http::map_response_body::MapResponseBodyLayer;

    if let Some(options) = comp {
        let compression = build_compression_layer(options);
        ServiceBuilder::new()
            .layer(MapResponseBodyLayer::new(renormalize_body))
            .layer(compression)
            .service(svc)
            .boxed_clone()
    } else {
        svc
    }
}
/// Decodes compressed request bodies before they reach `svc` when `enabled`;
/// otherwise returns `svc` unchanged.
///
/// `RequestDecompressionLayer` changes both body types, so requests are mapped
/// back to `Body` on the way in and responses are mapped back on the way out.
pub(crate) fn apply_decompression(svc: ServeService, enabled: bool) -> ServeService {
    use tower::ServiceExt;
    use tower_http::decompression::RequestDecompressionLayer;
    use tower_http::map_request_body::MapRequestBodyLayer;
    use tower_http::map_response_body::MapResponseBodyLayer;

    if enabled {
        ServiceBuilder::new()
            .layer(MapResponseBodyLayer::new(renormalize_body))
            .layer(RequestDecompressionLayer::new())
            .layer(MapRequestBodyLayer::new(renormalize_body))
            .service(svc)
            .boxed_clone()
    } else {
        svc
    }
}
/// Builds the response-compression layer described by `comp`.
///
/// zstd is switched on explicitly in addition to whatever `CompressionLayer::new`
/// enables. When `comp.level` is set it is applied as a precise quality to the
/// encoders; values above `i32::MAX` saturate instead of wrapping negative.
///
/// A response is compressed only if every one of these holds:
/// * it passes `SizeAbove` with threshold `comp.min_body_bytes`, clamped to `u16::MAX`;
/// * its content type is not excluded (`NotForContentType::IMAGES`, SSE,
///   `audio/`, `video/`, `application/zstd`, `application/octet-stream`);
/// * it carries no `Content-Encoding` header;
/// * none of its `Cache-Control` field lines contains the `no-transform` directive.
pub(crate) fn build_compression_layer(
    comp: &CompressionOptions,
) -> tower_http::compression::CompressionLayer<impl tower_http::compression::Predicate> {
    use http::{Extensions, HeaderMap, StatusCode, Version};
    use tower_http::compression::{
        predicate::{NotForContentType, Predicate, SizeAbove},
        CompressionLayer, CompressionLevel,
    };

    let mut layer = CompressionLayer::new().zstd(true);
    if let Some(level) = comp.level {
        // Saturate: a plain `as i32` turns u32 values above i32::MAX into negative
        // levels (for zstd these select the "fast" modes), which is not what a
        // caller asking for a very high level wants.
        let level = level.min(i32::MAX as u32) as i32;
        layer = layer.quality(CompressionLevel::Precise(level));
    }

    let not_pre_compressed = |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
        !h.contains_key(http::header::CONTENT_ENCODING)
    };
    // Cache-Control is a list-valued field and may be split across several
    // field lines, so inspect all of them, not just the first. Invalid UTF-8 is
    // replaced rather than causing the whole line (and any `no-transform` in it)
    // to be ignored.
    let not_no_transform = |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
        !h.get_all(http::header::CACHE_CONTROL).iter().any(|value| {
            String::from_utf8_lossy(value.as_bytes())
                .split(',')
                .any(|directive| directive.trim().eq_ignore_ascii_case("no-transform"))
        })
    };

    // `SizeAbove` only accepts a `u16`; larger thresholds saturate.
    let min_size = comp.min_body_bytes.min(u16::MAX as usize) as u16;
    let predicate = SizeAbove::new(min_size)
        .and(NotForContentType::IMAGES)
        .and(NotForContentType::SSE)
        .and(NotForContentType::const_new("audio/"))
        .and(NotForContentType::const_new("video/"))
        .and(NotForContentType::const_new("application/zstd"))
        .and(NotForContentType::const_new("application/octet-stream"))
        .and(not_pre_compressed)
        .and(not_no_transform);
    layer.compress_when(predicate)
}
/// Builds the client-side service around a single HTTP/1 connection handle.
///
/// hyper's `SendRequest` is adapted into a `tower::Service`, and its
/// `hyper::body::Incoming` responses are converted to the crate's `Body`.
/// Response decompression is added through `apply_client_decompression` when
/// `cfg.decompression` is set; no other field of `cfg` is consulted.
pub(crate) fn build_client_stack(
    sender: hyper::client::conn::http1::SendRequest<Body>,
    cfg: &StackConfig,
) -> ClientService {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tower::ServiceExt;
    use tower_http::map_response_body::MapResponseBodyLayer;

    /// Minimal `tower::Service` adapter over hyper's HTTP/1 sender.
    struct Http1Sender(hyper::client::conn::http1::SendRequest<Body>);

    impl tower::Service<hyper::Request<Body>> for Http1Sender {
        type Response = hyper::Response<hyper::body::Incoming>;
        type Error = hyper::Error;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            // Readiness is delegated to the underlying connection handle.
            self.0.poll_ready(cx)
        }

        fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
            let response = self.0.send_request(req);
            Box::pin(response)
        }
    }

    let base: ClientService = ServiceBuilder::new()
        .layer(MapResponseBodyLayer::new(renormalize_body))
        .service(Http1Sender(sender))
        .boxed();
    apply_client_decompression(base, cfg.decompression)
}
/// Adds tower-http's client-side `DecompressionLayer` around `svc` when
/// `enabled`; otherwise returns `svc` unchanged.
///
/// The decompressing response body is mapped back to `Body` so the result is
/// still a `ClientService`.
pub(crate) fn apply_client_decompression(svc: ClientService, enabled: bool) -> ClientService {
    use tower::ServiceExt;
    use tower_http::decompression::DecompressionLayer;
    use tower_http::map_response_body::MapResponseBodyLayer;

    if enabled {
        ServiceBuilder::new()
            .layer(MapResponseBodyLayer::new(renormalize_body))
            .layer(DecompressionLayer::new())
            .service(svc)
            .boxed()
    } else {
        svc
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::BodyExt;
    use std::convert::Infallible;
    use tower::ServiceExt;

    /// Test handler: replies `200 OK` with exactly the request body it received
    /// (or an empty body if reading the request body fails).
    #[derive(Clone)]
    struct EchoService;

    impl tower::Service<hyper::Request<Body>> for EchoService {
        type Response = hyper::Response<Body>;
        type Error = Infallible;
        type Future = std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
        >;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
            Box::pin(async move {
                let bytes = req
                    .into_body()
                    .collect()
                    .await
                    .map(|c| c.to_bytes())
                    .unwrap_or_default();
                Ok(hyper::Response::new(Body::full(bytes)))
            })
        }
    }

    /// Body limits, load shedding and decompression on; compression and timeout off.
    fn default_cfg() -> StackConfig {
        StackConfig {
            timeout: None,
            max_request_body_wire_bytes: Some(1024 * 1024),
            max_request_body_decoded_bytes: Some(1024 * 1024),
            load_shed: true,
            compression: None,
            decompression: true,
        }
    }

    /// Compression settings under which any response is eligible by size.
    fn compress_everything() -> Option<CompressionOptions> {
        Some(CompressionOptions {
            level: None,
            min_body_bytes: 0,
        })
    }

    fn boxed_echo() -> ServeService {
        ServiceBuilder::new().service(EchoService).boxed_clone()
    }

    /// Drains a response body into contiguous bytes, failing the test on body errors.
    async fn body_bytes(resp: hyper::Response<Body>) -> Bytes {
        resp.into_body()
            .collect()
            .await
            .expect("body collect")
            .to_bytes()
    }

    #[test]
    fn stack_config_builders_override_defaults() {
        let cfg = StackConfig::default();
        assert!(cfg.decompression);
        assert!(cfg.timeout.is_none());
        assert!(!cfg.load_shed);
        assert!(cfg.compression.is_none());

        let cfg = cfg
            .with_timeout(Some(Duration::from_secs(5)))
            .with_decompression(false);
        assert_eq!(cfg.timeout, Some(Duration::from_secs(5)));
        assert!(!cfg.decompression);
    }

    #[tokio::test]
    async fn real_chain_round_trips_a_request() {
        let stack = build_stack(boxed_echo(), &default_cfg());
        let req = hyper::Request::builder()
            .uri("/")
            .body(Body::full("ping"))
            .unwrap();
        let resp = stack.oneshot(req).await.expect("service infallible");
        assert_eq!(resp.status(), hyper::StatusCode::OK);
        assert_eq!(body_bytes(resp).await, Bytes::from_static(b"ping"));
    }

    #[tokio::test]
    async fn real_chain_with_compression_enabled_still_round_trips() {
        let mut cfg = default_cfg();
        cfg.compression = compress_everything();
        let stack = build_stack(boxed_echo(), &cfg);
        let req = hyper::Request::builder()
            .uri("/")
            .header("accept-encoding", "zstd")
            .body(Body::full("ping"))
            .unwrap();
        let resp = stack.oneshot(req).await.expect("service infallible");
        assert_eq!(resp.status(), hyper::StatusCode::OK);
        // The client accepts zstd and no predicate excludes this response,
        // so it must actually come back compressed.
        assert_eq!(
            resp.headers()
                .get(hyper::header::CONTENT_ENCODING)
                .map(|v| v.as_bytes()),
            Some(&b"zstd"[..])
        );
        let body = body_bytes(resp).await;
        // Every zstd frame begins with the magic number 0xFD2FB528 (little-endian).
        assert!(
            body.starts_with(&[0x28, 0xB5, 0x2F, 0xFD]),
            "not a zstd frame: {body:?}"
        );
    }

    #[tokio::test]
    async fn compression_enabled_without_accept_encoding_is_identity() {
        let mut cfg = default_cfg();
        cfg.compression = compress_everything();
        let stack = build_stack(boxed_echo(), &cfg);
        let req = hyper::Request::builder()
            .uri("/")
            .body(Body::full("ping"))
            .unwrap();
        let resp = stack.oneshot(req).await.expect("service infallible");
        assert_eq!(resp.status(), hyper::StatusCode::OK);
        assert!(resp
            .headers()
            .get(hyper::header::CONTENT_ENCODING)
            .is_none());
        assert_eq!(body_bytes(resp).await, Bytes::from_static(b"ping"));
    }

    #[tokio::test]
    async fn wire_limit_rejects_oversized_content_length() {
        let mut cfg = default_cfg();
        cfg.max_request_body_wire_bytes = Some(4);
        let stack = build_stack(boxed_echo(), &cfg);
        let req = hyper::Request::builder()
            .uri("/")
            .header(hyper::header::CONTENT_LENGTH, "5")
            .body(Body::full("pings"))
            .unwrap();
        let resp = stack.oneshot(req).await.expect("service infallible");
        assert_eq!(resp.status(), hyper::StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn build_stack_accepts_additional_outer_layer() {
        use tower_http::map_request_body::MapRequestBodyLayer;
        let inner = build_stack(boxed_echo(), &default_cfg());
        let stack = ServiceBuilder::new()
            .layer(MapRequestBodyLayer::new(|b: Body| b))
            .service(inner);
        let req = hyper::Request::builder()
            .uri("/")
            .body(Body::full("ping"))
            .unwrap();
        let resp = stack.oneshot(req).await.expect("service infallible");
        assert_eq!(resp.status(), hyper::StatusCode::OK);
        assert_eq!(body_bytes(resp).await, Bytes::from_static(b"ping"));
    }
}