1use axum::{
7 extract::Request,
8 response::{Response, IntoResponse},
9 body::Body,
10 http::{StatusCode, HeaderValue},
11};
12use tracing::{warn, error};
13
14use crate::{
15 middleware::{Middleware, BoxFuture},
16 HttpError,
17};
18
/// Settings controlling how [`BodyLimitMiddleware`] rejects oversized requests.
///
/// Build one with [`BodyLimitConfig::new`] / `Default` and the `with_*` builders.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BodyLimitConfig {
    /// Maximum accepted request body size in bytes. A declared `Content-Length`
    /// strictly greater than this value is rejected with 413.
    pub max_size: usize,
    /// When `true`, oversized requests (and 413 responses) are logged at `warn` level.
    pub log_oversized: bool,
    /// Message handed to `HttpError::payload_too_large` when a request is rejected.
    pub error_message: String,
    /// When `true`, the 413 response carries a size detail and an `X-Max-Body-Size` header.
    pub include_headers: bool,
}
31
32impl Default for BodyLimitConfig {
33 fn default() -> Self {
34 Self {
35 max_size: 2 * 1024 * 1024, log_oversized: true,
37 error_message: "Request body too large".to_string(),
38 include_headers: true,
39 }
40 }
41}
42
43impl BodyLimitConfig {
44 pub fn new(max_size: usize) -> Self {
46 Self {
47 max_size,
48 ..Default::default()
49 }
50 }
51
52 pub fn with_max_size(mut self, max_size: usize) -> Self {
54 self.max_size = max_size;
55 self
56 }
57
58 pub fn with_logging(mut self, log_oversized: bool) -> Self {
60 self.log_oversized = log_oversized;
61 self
62 }
63
64 pub fn with_message<S: Into<String>>(mut self, message: S) -> Self {
66 self.error_message = message.into();
67 self
68 }
69
70 pub fn with_headers(mut self, include_headers: bool) -> Self {
72 self.include_headers = include_headers;
73 self
74 }
75}
76
77pub struct BodyLimitMiddleware {
79 config: BodyLimitConfig,
80}
81
82impl BodyLimitMiddleware {
83 pub fn new() -> Self {
85 Self {
86 config: BodyLimitConfig::default(),
87 }
88 }
89
90 pub fn with_limit(max_size: usize) -> Self {
92 Self {
93 config: BodyLimitConfig::new(max_size),
94 }
95 }
96
97 pub fn with_config(config: BodyLimitConfig) -> Self {
99 Self { config }
100 }
101
102 pub fn max_size(mut self, size: usize) -> Self {
104 self.config = self.config.with_max_size(size);
105 self
106 }
107
108 pub fn logging(mut self, enabled: bool) -> Self {
110 self.config = self.config.with_logging(enabled);
111 self
112 }
113
114 pub fn message<S: Into<String>>(mut self, message: S) -> Self {
116 self.config = self.config.with_message(message);
117 self
118 }
119
120 pub fn limit(&self) -> usize {
122 self.config.max_size
123 }
124
125 fn create_error_response(&self, content_length: Option<usize>) -> Response {
127 let mut error = HttpError::payload_too_large(&self.config.error_message);
128
129 if self.config.include_headers {
130 if let Some(size) = content_length {
131 error = error.with_detail(&format!(
132 "Request body size {} bytes exceeds limit of {} bytes",
133 size,
134 self.config.max_size
135 ));
136 } else {
137 error = error.with_detail(&format!(
138 "Request body exceeds limit of {} bytes",
139 self.config.max_size
140 ));
141 }
142 }
143
144 let mut response = error.into_response();
145
146 if self.config.include_headers {
147 if let Ok(max_size_header) = HeaderValue::from_str(&self.config.max_size.to_string()) {
148 response.headers_mut().insert("X-Max-Body-Size", max_size_header);
149 }
150 }
151
152 response
153 }
154
155 fn check_content_length(&self, request: &Request) -> Result<Option<usize>, Response> {
157 if let Some(content_length) = request.headers().get("content-length") {
158 if let Ok(content_length_str) = content_length.to_str() {
159 if let Ok(content_length) = content_length_str.parse::<usize>() {
160 if content_length > self.config.max_size {
161 if self.config.log_oversized {
162 warn!(
163 "Request body size {} bytes exceeds limit of {} bytes (Content-Length check)",
164 content_length,
165 self.config.max_size
166 );
167 }
168 return Err(self.create_error_response(Some(content_length)));
169 }
170 return Ok(Some(content_length));
171 }
172 }
173 }
174 Ok(None)
175 }
176}
177
178impl Default for BodyLimitMiddleware {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
184impl Middleware for BodyLimitMiddleware {
185 fn process_request<'a>(
186 &'a self,
187 request: Request
188 ) -> BoxFuture<'a, Result<Request, Response>> {
189 Box::pin(async move {
190 let content_length = match self.check_content_length(&request) {
192 Ok(length) => length,
193 Err(response) => return Err(response),
194 };
195
196 let mut request = request;
198 request.extensions_mut().insert(BodyLimitInfo {
199 max_size: self.config.max_size,
200 content_length,
201 error_message: self.config.error_message.clone(),
202 });
203
204 Ok(request)
210 })
211 }
212
213 fn process_response<'a>(
214 &'a self,
215 response: Response
216 ) -> BoxFuture<'a, Response> {
217 Box::pin(async move {
218 if response.status() == StatusCode::PAYLOAD_TOO_LARGE && self.config.log_oversized {
220 warn!("Returned 413 Payload Too Large response due to body size limit");
221 }
222
223 response
224 })
225 }
226
227 fn name(&self) -> &'static str {
228 "BodyLimitMiddleware"
229 }
230}
231
/// Request extension inserted by `BodyLimitMiddleware::process_request` for
/// requests that pass the `Content-Length` check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BodyLimitInfo {
    /// The configured byte limit in effect for this request.
    pub max_size: usize,
    /// The parsed `Content-Length`, or `None` when the header was absent or not a number.
    pub content_length: Option<usize>,
    /// The configured rejection message.
    pub error_message: String,
}
239
240pub fn limit_body_size(body: Body, max_size: usize) -> LimitedBody {
242 LimitedBody {
243 body,
244 max_size,
245 consumed: 0,
246 }
247}
248
249pub struct LimitedBody {
251 body: Body,
252 max_size: usize,
253 consumed: usize,
254}
255
256impl LimitedBody {
257 pub fn new(body: Body, max_size: usize) -> Self {
259 Self {
260 body,
261 max_size,
262 consumed: 0,
263 }
264 }
265
266 pub fn remaining(&self) -> usize {
268 self.max_size.saturating_sub(self.consumed)
269 }
270
271 pub fn consumed(&self) -> usize {
273 self.consumed
274 }
275
276 pub fn is_exceeded(&self) -> bool {
278 self.consumed > self.max_size
279 }
280}
281
282pub mod limits {
288 pub const KB: usize = 1024;
290
291 pub const MB: usize = 1024 * 1024;
293
294 pub const MB_10: usize = 10 * MB;
296
297 pub const MB_100: usize = 100 * MB;
299
300 pub const GB: usize = 1024 * MB;
302
303 pub mod presets {
305 use super::super::BodyLimitMiddleware;
306 use super::*;
307
308 pub fn small_api() -> BodyLimitMiddleware {
310 BodyLimitMiddleware::with_limit(MB)
311 .message("API request body too large (1MB limit)")
312 }
313
314 pub fn file_upload() -> BodyLimitMiddleware {
316 BodyLimitMiddleware::with_limit(MB_10)
317 .message("File upload too large (10MB limit)")
318 }
319
320 pub fn large_upload() -> BodyLimitMiddleware {
322 BodyLimitMiddleware::with_limit(MB_100)
323 .message("Large file upload too large (100MB limit)")
324 }
325
326 pub fn tiny() -> BodyLimitMiddleware {
328 BodyLimitMiddleware::with_limit(64 * KB)
329 .message("Request body too large (64KB limit)")
330 }
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Method;

    /// Builds a `POST /test` request with an empty body and, when given, a
    /// `Content-Length` header carrying `content_length` verbatim.
    fn post_request(content_length: Option<&str>) -> Request {
        let builder = Request::builder().method(Method::POST).uri("/test");
        let builder = match content_length {
            Some(value) => builder.header("content-length", value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_body_limit_middleware_basic() {
        let middleware = BodyLimitMiddleware::new();

        let result = middleware.process_request(post_request(None)).await;
        assert!(result.is_ok());

        let processed_request = result.unwrap();
        let body_limit_info = processed_request
            .extensions()
            .get::<BodyLimitInfo>()
            .expect("BodyLimitInfo extension must be attached");

        assert_eq!(body_limit_info.max_size, 2 * 1024 * 1024);
        assert!(body_limit_info.content_length.is_none());
    }

    #[test]
    fn test_body_limit_middleware_custom_limit() {
        let middleware = BodyLimitMiddleware::with_limit(1024);
        assert_eq!(middleware.limit(), 1024);
    }

    #[test]
    fn test_body_limit_middleware_builder() {
        let middleware = BodyLimitMiddleware::new()
            .max_size(512)
            .logging(false)
            .message("Too big!");

        assert_eq!(middleware.config.max_size, 512);
        assert!(!middleware.config.log_oversized);
        assert_eq!(middleware.config.error_message, "Too big!");
    }

    #[tokio::test]
    async fn test_content_length_check_within_limit() {
        let middleware = BodyLimitMiddleware::with_limit(1000);

        let result = middleware.process_request(post_request(Some("500"))).await;
        assert!(result.is_ok());

        let processed_request = result.unwrap();
        let body_limit_info = processed_request.extensions().get::<BodyLimitInfo>().unwrap();
        assert_eq!(body_limit_info.content_length, Some(500));
    }

    #[tokio::test]
    async fn test_content_length_check_exceeds_limit() {
        let middleware = BodyLimitMiddleware::with_limit(100);

        let result = middleware.process_request(post_request(Some("200"))).await;
        assert!(result.is_err());

        let error_response = result.unwrap_err();
        assert_eq!(error_response.status(), StatusCode::PAYLOAD_TOO_LARGE);

        assert!(error_response.headers().contains_key("X-Max-Body-Size"));
        assert_eq!(
            error_response.headers().get("X-Max-Body-Size").unwrap(),
            "100"
        );
    }

    #[test]
    fn test_body_limit_config() {
        let config = BodyLimitConfig::new(512)
            .with_logging(false)
            .with_message("Custom message")
            .with_headers(false);

        let middleware = BodyLimitMiddleware::with_config(config);

        assert_eq!(middleware.config.max_size, 512);
        assert!(!middleware.config.log_oversized);
        assert_eq!(middleware.config.error_message, "Custom message");
        assert!(!middleware.config.include_headers);
    }

    #[test]
    fn test_body_limit_middleware_name() {
        let middleware = BodyLimitMiddleware::new();
        assert_eq!(middleware.name(), "BodyLimitMiddleware");
    }

    #[test]
    fn test_limited_body_creation() {
        let limited = limit_body_size(Body::empty(), 1024);

        assert_eq!(limited.remaining(), 1024);
        assert_eq!(limited.consumed(), 0);
        assert!(!limited.is_exceeded());
    }

    #[test]
    fn test_body_limit_presets() {
        assert_eq!(limits::presets::small_api().limit(), limits::MB);
        assert_eq!(limits::presets::file_upload().limit(), limits::MB_10);
        assert_eq!(limits::presets::large_upload().limit(), limits::MB_100);
        assert_eq!(limits::presets::tiny().limit(), 64 * limits::KB);
    }

    #[test]
    fn test_body_limit_constants() {
        assert_eq!(limits::KB, 1024);
        assert_eq!(limits::MB, 1024 * 1024);
        assert_eq!(limits::MB_10, 10 * 1024 * 1024);
        assert_eq!(limits::MB_100, 100 * 1024 * 1024);
        assert_eq!(limits::GB, 1024 * 1024 * 1024);
    }

    #[tokio::test]
    async fn test_invalid_content_length_header() {
        let middleware = BodyLimitMiddleware::with_limit(1000);

        // A non-numeric header is not treated as a rejection.
        let result = middleware
            .process_request(post_request(Some("not-a-number")))
            .await;
        assert!(result.is_ok());

        let processed_request = result.unwrap();
        let body_limit_info = processed_request.extensions().get::<BodyLimitInfo>().unwrap();
        assert!(body_limit_info.content_length.is_none());
    }
}