use crate::Request;
use crate::dep::http_body_util::Limited;
use bytes::Bytes;
use rama_core::{Context, Layer, Service, error::BoxError};
use rama_http_types::Body;
use rama_utils::macros::define_inner_service_accessors;
use std::fmt;
/// [`Layer`] that caps the size of incoming request bodies.
///
/// Applying this layer to a service produces a [`BodyLimitService`], which
/// wraps each request body in an [`http_body_util::Limited`] adapter before
/// handing the request on.
///
/// A limit of `0` is special: it disables limiting altogether, so the body is
/// forwarded unwrapped.
#[derive(Debug, Clone)]
pub struct BodyLimitLayer {
    /// Maximum number of body bytes allowed; `0` means "no limit".
    size: usize,
}

impl BodyLimitLayer {
    /// Builds a layer that limits request bodies to at most `size` bytes.
    ///
    /// Passing `0` turns the limit off (bodies are passed through untouched).
    pub const fn new(size: usize) -> Self {
        BodyLimitLayer { size }
    }
}
/// Wraps `inner` in a [`BodyLimitService`] that carries this layer's limit.
impl<S> Layer<S> for BodyLimitLayer {
    type Service = BodyLimitService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        // The limit is a plain `usize`, so it is copied into the new service.
        let &Self { size } = self;
        BodyLimitService::new(inner, size)
    }
}
/// Service that wraps every request body in a size-limiting adapter before
/// forwarding the request to the inner service.
///
/// Usually created through [`BodyLimitLayer`]. A `size` of `0` disables the
/// limit.
#[derive(Clone)]
pub struct BodyLimitService<S> {
    /// The wrapped service that receives the (possibly limited) request.
    inner: S,
    /// Maximum number of body bytes allowed; `0` means "no limit".
    size: usize,
}

impl<S> BodyLimitService<S> {
    /// Wraps `service`, limiting request bodies to at most `size` bytes.
    ///
    /// Passing `0` turns the limit off.
    pub const fn new(service: S, size: usize) -> Self {
        Self {
            size,
            inner: service,
        }
    }

    // Expected to generate accessor methods for the wrapped `inner` service;
    // see the `define_inner_service_accessors!` macro definition.
    define_inner_service_accessors!();
}
/// Forwards each request to the inner service with its body wrapped in an
/// [`http_body_util::Limited`] adapter (unless the limit is `0`).
///
/// This service never reads the body itself. The limit is enforced lazily:
/// the limited body yields an error once more than `size` bytes have been
/// polled from it by whoever consumes the request downstream.
impl<S, State, ReqBody> Service<State, Request<ReqBody>> for BodyLimitService<S>
where
    S: Service<State, Request<Body>>,
    State: Clone + Send + Sync + 'static,
    ReqBody: rama_http_types::dep::http_body::Body<Data = Bytes, Error: Into<BoxError>>
        + Send
        + Sync
        + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    async fn serve(
        &self,
        ctx: Context<State>,
        req: Request<ReqBody>,
    ) -> Result<Self::Response, Self::Error> {
        let limit = self.size;
        let limited_req = req.map(move |body| match limit {
            // A zero limit means "unlimited": erase the body type but do not cap it.
            0 => Body::new(body),
            max => Body::new(Limited::new(body, max)),
        });
        self.inner.serve(ctx, limited_req).await
    }
}
/// Debug output lists the wrapped service and the configured limit.
impl<S: fmt::Debug> fmt::Debug for BodyLimitService<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { inner, size } = self;
        let mut builder = f.debug_struct("BodyLimitService");
        builder.field("inner", inner).field("size", size);
        builder.finish()
    }
}