use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use http::StatusCode;
use http::header::CONTENT_LENGTH;
use http_body_util::BodyExt;
use crate::body::TakoBody;
use crate::middleware::IntoMiddleware;
use crate::middleware::Next;
use crate::responder::Responder;
use crate::types::Request;
use crate::types::Response;
/// Configuration for a middleware that caps the size of incoming request bodies.
///
/// The cap can be a fixed byte count, a closure evaluated against each request, or both.
/// When a closure is present, the middleware uses its result in preference to the fixed value.
pub struct BodyLimit<F: Fn(&Request) -> usize + Send + Sync + 'static> {
    // Fixed maximum body size in bytes, when one was configured.
    limit: Option<usize>,
    // Closure computing the maximum body size for a given request, when one was configured.
    dynamic_limit: Option<F>,
}
impl<F> BodyLimit<F>
where
    F: Fn(&Request) -> usize + Send + Sync + 'static,
{
    /// Builds a limiter that enforces the fixed maximum body size `limit` (in bytes).
    pub fn new(limit: usize) -> Self {
        let limit = Some(limit);
        Self {
            limit,
            dynamic_limit: None,
        }
    }

    /// Builds a limiter whose maximum body size is computed per request by `f`.
    pub fn with_dynamic_limit(f: F) -> Self {
        let dynamic_limit = Some(f);
        Self {
            limit: None,
            dynamic_limit,
        }
    }

    /// Builds a limiter carrying both a fixed `limit` and a per-request closure `f`.
    ///
    /// NOTE(review): the middleware consults `f` first, and since `f` always yields a value,
    /// the fixed `limit` is never the one actually enforced. Confirm this is intended.
    pub fn new_with_dynamic(limit: usize, f: F) -> Self {
        let (limit, dynamic_limit) = (Some(limit), Some(f));
        Self {
            limit,
            dynamic_limit,
        }
    }
}
impl<F> IntoMiddleware for BodyLimit<F>
where
F: Fn(&Request) -> usize + Send + Sync + 'static,
{
fn into_middleware(
self,
) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
+ Clone
+ Send
+ Sync
+ 'static {
let static_limit = self.limit;
let dynamic_limit = self.dynamic_limit.map(Arc::new);
move |req: Request, next: Next| {
let dynamic_limit = dynamic_limit.clone();
Box::pin(async move {
let limit = dynamic_limit
.as_ref()
.map(|f| f(&req))
.or(static_limit)
.unwrap_or(10 * 1024 * 1024);
if let Some(len) = req
.headers()
.get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<usize>().ok())
&& len > limit
{
return (StatusCode::PAYLOAD_TOO_LARGE, "Body exceeds allowed size").into_response();
}
let (parts, body) = req.into_parts();
let collected = match body.collect().await {
Ok(c) => c.to_bytes(),
Err(_) => {
return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
}
};
if collected.len() > limit {
return (StatusCode::PAYLOAD_TOO_LARGE, "Body exceeds allowed size").into_response();
}
let req = http::Request::from_parts(parts, TakoBody::from(Bytes::from(collected)));
next.run(req).await.into_response()
})
}
}
}