1use http::{
2 header::CONTENT_LENGTH,
3 Request, Response, StatusCode,
4};
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tower::{Layer, Service};
9
10#[derive(Clone, Debug)]
26pub struct BodyLimit {
27 limit_bytes: u64,
28 status: StatusCode,
29}
30
31impl Default for BodyLimit {
32 fn default() -> Self {
33 Self {
34 limit_bytes: 1_048_576,
35 status: StatusCode::PAYLOAD_TOO_LARGE,
36 }
37 }
38}
39
40impl BodyLimit {
41 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub fn limit(mut self, bytes: u64) -> Self {
50 self.limit_bytes = bytes;
51 self
52 }
53
54 pub fn status(mut self, status: StatusCode) -> Self {
58 self.status = status;
59 self
60 }
61
62 fn exceeds_limit<B>(&self, req: &Request<B>) -> bool {
63 req.headers()
64 .get(CONTENT_LENGTH)
65 .and_then(|v| v.to_str().ok())
66 .and_then(|s| s.parse::<u64>().ok())
67 .is_some_and(|len| len > self.limit_bytes)
68 }
69}
70
71#[derive(Clone, Debug, Default)]
91pub struct BodyLimitLayer {
92 config: BodyLimit,
93}
94
95impl BodyLimitLayer {
96 pub fn new(config: BodyLimit) -> Self {
98 Self { config }
99 }
100}
101
102impl<S> Layer<S> for BodyLimitLayer {
103 type Service = BodyLimitService<S>;
104
105 fn layer(&self, inner: S) -> Self::Service {
106 BodyLimitService {
107 inner,
108 config: self.config.clone(),
109 }
110 }
111}
112
113#[derive(Clone, Debug)]
115pub struct BodyLimitService<S> {
116 inner: S,
117 config: BodyLimit,
118}
119
120impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for BodyLimitService<S>
121where
122 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
123 S::Future: Send + 'static,
124 S::Error: Send + 'static,
125 ResBody: Default + Send + 'static,
126{
127 type Response = Response<ResBody>;
128 type Error = S::Error;
129 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
130
131 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
132 self.inner.poll_ready(cx)
133 }
134
135 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
136 if self.config.exceeds_limit(&req) {
137 let status = self.config.status;
138 return Box::pin(async move {
139 Ok(Response::builder()
140 .status(status)
141 .body(ResBody::default())
142 .expect("error response is valid"))
143 });
144 }
145 Box::pin(self.inner.call(req))
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::post, Router};
    use http::StatusCode;
    use tower::ServiceExt;

    /// Router with a single `POST /upload` handler behind the body-limit layer.
    fn build_app(config: BodyLimit) -> Router {
        Router::new()
            .route("/upload", post(|| async { "ok" }))
            .layer(BodyLimitLayer::new(config))
    }

    /// Drives one request through `app` to completion.
    async fn send(app: Router, req: http::Request<Body>) -> http::Response<Body> {
        app.oneshot(req).await.unwrap()
    }

    /// `POST /upload` with an empty body but the given declared length.
    fn post_with_length(content_length: u64) -> http::Request<Body> {
        http::Request::builder()
            .method("POST")
            .uri("/upload")
            .header(CONTENT_LENGTH, content_length)
            .body(Body::empty())
            .unwrap()
    }

    /// `POST /upload` with no `Content-Length` header at all.
    fn post_without_length() -> http::Request<Body> {
        http::Request::builder()
            .method("POST")
            .uri("/upload")
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn rejects_when_content_length_exceeds_limit() {
        let response = send(
            build_app(BodyLimit::new().limit(1024)),
            post_with_length(2048),
        )
        .await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn passes_when_content_length_within_limit() {
        let response = send(
            build_app(BodyLimit::new().limit(1024)),
            post_with_length(512),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn passes_at_exact_limit() {
        let response = send(
            build_app(BodyLimit::new().limit(1024)),
            post_with_length(1024),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn rejects_one_byte_over_limit() {
        let response = send(
            build_app(BodyLimit::new().limit(1024)),
            post_with_length(1025),
        )
        .await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn passes_when_no_content_length_header() {
        let response = send(build_app(BodyLimit::new().limit(1024)), post_without_length()).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn passes_when_content_length_is_malformed() {
        // The layer only acts on a parseable declared length; a non-numeric
        // value is not treated as oversized. (Presumably the HTTP server's own
        // parser rejects such requests before they get here — not tested.)
        let request = http::Request::builder()
            .method("POST")
            .uri("/upload")
            .header(CONTENT_LENGTH, "not-a-number")
            .body(Body::empty())
            .unwrap();
        let response = send(build_app(BodyLimit::new().limit(1024)), request).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn rejected_request_never_reaches_handler() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let hits = Arc::new(AtomicUsize::new(0));
        let handler_hits = Arc::clone(&hits);
        let app = Router::new()
            .route(
                "/upload",
                post(move || {
                    let handler_hits = Arc::clone(&handler_hits);
                    async move {
                        handler_hits.fetch_add(1, Ordering::SeqCst);
                        "ok"
                    }
                }),
            )
            .layer(BodyLimitLayer::new(BodyLimit::new().limit(16)));

        let rejected = send(app.clone(), post_with_length(17)).await;
        assert_eq!(rejected.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(hits.load(Ordering::SeqCst), 0);

        let accepted = send(app, post_with_length(16)).await;
        assert_eq!(accepted.status(), StatusCode::OK);
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn custom_status_code() {
        let config = BodyLimit::new()
            .limit(512)
            .status(StatusCode::UNPROCESSABLE_ENTITY);
        let response = send(build_app(config), post_with_length(1024)).await;
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn default_layer_uses_413_and_one_mib_limit() {
        let app = Router::new()
            .route("/upload", post(|| async { "ok" }))
            .layer(BodyLimitLayer::default());

        let within = send(
            app.clone(),
            post_with_length(1_048_576),
        )
        .await;
        assert_eq!(within.status(), StatusCode::OK);

        let over = send(app, post_with_length(1_048_577)).await;
        assert_eq!(over.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn zero_limit_rejects_any_body() {
        let response = send(
            build_app(BodyLimit::new().limit(0)),
            post_with_length(1),
        )
        .await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}