Skip to main content

better_auth_core/middleware/
body_limit.rs

1use super::Middleware;
2use crate::error::AuthResult;
3use crate::types::{AuthRequest, AuthResponse};
4use async_trait::async_trait;
5
/// Configuration for [`BodyLimitMiddleware`].
///
/// Construct it with [`BodyLimitConfig::new`] / [`Default::default`] and adjust
/// it with the chained builder methods, or use struct-literal syntax directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BodyLimitConfig {
    /// Maximum accepted body size in bytes. A body of exactly this many bytes
    /// is still accepted; only strictly larger bodies are rejected.
    /// Defaults to [`BodyLimitConfig::DEFAULT_MAX_BYTES`] (1 MiB = 1,048,576 bytes).
    pub max_bytes: usize,

    /// Whether the middleware is enabled. When `false`, every request passes
    /// through unchecked. Defaults to `true`.
    pub enabled: bool,
}

impl Default for BodyLimitConfig {
    fn default() -> Self {
        Self {
            max_bytes: Self::DEFAULT_MAX_BYTES,
            enabled: true,
        }
    }
}

impl BodyLimitConfig {
    /// Default value of [`BodyLimitConfig::max_bytes`]: 1 MiB (1,048,576 bytes).
    pub const DEFAULT_MAX_BYTES: usize = 1024 * 1024;

    /// Creates a configuration with default settings; equivalent to
    /// [`BodyLimitConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum accepted body size in bytes and returns the updated
    /// configuration (builder style, consumes `self`).
    #[must_use]
    pub fn max_bytes(mut self, max: usize) -> Self {
        self.max_bytes = max;
        self
    }

    /// Enables or disables the size check and returns the updated
    /// configuration (builder style, consumes `self`).
    #[must_use]
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }
}
40
41/// Body size limit middleware.
42///
43/// Rejects requests whose body exceeds the configured maximum size.
44pub struct BodyLimitMiddleware {
45    config: BodyLimitConfig,
46}
47
48impl BodyLimitMiddleware {
49    pub fn new(config: BodyLimitConfig) -> Self {
50        Self { config }
51    }
52}
53
54#[async_trait]
55impl Middleware for BodyLimitMiddleware {
56    fn name(&self) -> &'static str {
57        "body-limit"
58    }
59
60    async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
61        if !self.config.enabled {
62            return Ok(None);
63        }
64
65        if let Some(body) = &req.body
66            && body.len() > self.config.max_bytes
67        {
68            return Ok(Some(AuthResponse::json(
69                413,
70                &crate::types::CodeMessageResponse {
71                    code: "BODY_TOO_LARGE",
72                    message: format!(
73                        "Request body exceeds maximum size of {} bytes",
74                        self.config.max_bytes
75                    ),
76                },
77            )?));
78        }
79
80        Ok(None)
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::HttpMethod;
    use std::collections::HashMap;

    /// Builds a request with the given method, path and optional body, and no
    /// headers, query parameters or virtual user.
    fn make_request(method: HttpMethod, path: &str, body: Option<Vec<u8>>) -> AuthRequest {
        AuthRequest {
            method,
            path: path.to_string(),
            headers: HashMap::new(),
            body,
            query: HashMap::new(),
            virtual_user_id: None,
        }
    }

    /// Builds a `POST /sign-up/email` request whose body is `body_size` zero bytes.
    fn make_request_with_body(body_size: usize) -> AuthRequest {
        make_request(
            HttpMethod::Post,
            "/sign-up/email",
            Some(vec![0u8; body_size]),
        )
    }

    /// Middleware with a 1024-byte limit, shared by most tests.
    fn limited_1k() -> BodyLimitMiddleware {
        BodyLimitMiddleware::new(BodyLimitConfig::new().max_bytes(1024))
    }

    #[test]
    fn test_default_config() {
        let config = BodyLimitConfig::default();
        assert_eq!(config.max_bytes, 1_048_576);
        assert!(config.enabled);
    }

    #[tokio::test]
    async fn test_body_limit_allows_within_limit() {
        let mw = limited_1k();
        let req = make_request_with_body(512);
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_body_limit_allows_exact_limit() {
        let mw = limited_1k();
        let req = make_request_with_body(1024);
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_body_limit_rejects_one_byte_over_limit() {
        let mw = limited_1k();
        let req = make_request_with_body(1025);
        let resp = mw.before_request(&req).await.unwrap();
        assert_eq!(resp.map(|r| r.status), Some(413));
    }

    #[tokio::test]
    async fn test_body_limit_rejects_over_limit() {
        let mw = limited_1k();
        let req = make_request_with_body(2048);
        let resp = mw.before_request(&req).await.unwrap();
        assert!(resp.is_some());
        assert_eq!(resp.unwrap().status, 413);
    }

    #[tokio::test]
    async fn test_body_limit_allows_empty_body() {
        let mw = limited_1k();
        let req = make_request_with_body(0);
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_body_limit_allows_no_body() {
        let mw = limited_1k();
        let req = make_request(HttpMethod::Get, "/get-session", None);
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_body_limit_zero_limit_rejects_any_nonempty_body() {
        let mw = BodyLimitMiddleware::new(BodyLimitConfig::new().max_bytes(0));
        let empty = make_request_with_body(0);
        assert!(mw.before_request(&empty).await.unwrap().is_none());

        let one_byte = make_request_with_body(1);
        let resp = mw.before_request(&one_byte).await.unwrap();
        assert_eq!(resp.map(|r| r.status), Some(413));
    }

    #[tokio::test]
    async fn test_body_limit_disabled() {
        let config = BodyLimitConfig::new().max_bytes(10).enabled(false);
        let mw = BodyLimitMiddleware::new(config);
        let req = make_request_with_body(1000);
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }
}