Skip to main content

rune_axum_size/
layer.rs

1use http::{
2    header::CONTENT_LENGTH,
3    Request, Response, StatusCode,
4};
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tower::{Layer, Service};
9
10/// Configuration for the request body size limit middleware.
11///
12/// Build with [`BodyLimit::new()`] and chain methods to set the byte limit or
13/// error status code, then pass to [`BodyLimitLayer::new()`].
14///
15/// # Examples
16///
17/// ```rust
18/// use http::StatusCode;
19/// use rune_axum_size::BodyLimit;
20///
21/// let config = BodyLimit::new()
22///     .limit(10 * 1024 * 1024) // 10 MiB
23///     .status(StatusCode::PAYLOAD_TOO_LARGE);
24/// ```
25#[derive(Clone, Debug)]
26pub struct BodyLimit {
27    limit_bytes: u64,
28    status: StatusCode,
29}
30
31impl Default for BodyLimit {
32    fn default() -> Self {
33        Self {
34            limit_bytes: 1_048_576,
35            status: StatusCode::PAYLOAD_TOO_LARGE,
36        }
37    }
38}
39
40impl BodyLimit {
41    /// Creates a `BodyLimit` with defaults: 1 MiB limit, `413 Payload Too Large`.
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    /// Sets the maximum allowed body size in bytes.
47    ///
48    /// Defaults to 1 MiB (1,048,576 bytes).
49    pub fn limit(mut self, bytes: u64) -> Self {
50        self.limit_bytes = bytes;
51        self
52    }
53
54    /// Sets the HTTP status code returned when the limit is exceeded.
55    ///
56    /// Defaults to `413 Payload Too Large`.
57    pub fn status(mut self, status: StatusCode) -> Self {
58        self.status = status;
59        self
60    }
61
62    fn exceeds_limit<B>(&self, req: &Request<B>) -> bool {
63        req.headers()
64            .get(CONTENT_LENGTH)
65            .and_then(|v| v.to_str().ok())
66            .and_then(|s| s.parse::<u64>().ok())
67            .is_some_and(|len| len > self.limit_bytes)
68    }
69}
70
71/// Tower [`Layer`] that rejects requests whose `Content-Length` exceeds the configured limit.
72///
73/// Apply with Axum's `.layer()` call. Use [`BodyLimitLayer::default()`] for a 1 MiB limit
74/// with a `413` response, or [`BodyLimitLayer::new()`] to supply a custom [`BodyLimit`].
75///
76/// > [!NOTE]
77/// > Only the `Content-Length` header is inspected. Requests that omit this header
78/// > pass through regardless of their actual body size.
79///
80/// # Examples
81///
82/// ```rust,no_run
83/// use axum::{routing::post, Router};
84/// use rune_axum_size::BodyLimitLayer;
85///
86/// let app: Router = Router::new()
87///     .route("/upload", post(|| async { "ok" }))
88///     .layer(BodyLimitLayer::default());
89/// ```
90#[derive(Clone, Debug, Default)]
91pub struct BodyLimitLayer {
92    config: BodyLimit,
93}
94
95impl BodyLimitLayer {
96    /// Creates a `BodyLimitLayer` from a custom [`BodyLimit`] configuration.
97    pub fn new(config: BodyLimit) -> Self {
98        Self { config }
99    }
100}
101
102impl<S> Layer<S> for BodyLimitLayer {
103    type Service = BodyLimitService<S>;
104
105    fn layer(&self, inner: S) -> Self::Service {
106        BodyLimitService {
107            inner,
108            config: self.config.clone(),
109        }
110    }
111}
112
113/// Tower [`Service`] produced by [`BodyLimitLayer`].
114#[derive(Clone, Debug)]
115pub struct BodyLimitService<S> {
116    inner: S,
117    config: BodyLimit,
118}
119
120impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for BodyLimitService<S>
121where
122    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
123    S::Future: Send + 'static,
124    S::Error: Send + 'static,
125    ResBody: Default + Send + 'static,
126{
127    type Response = Response<ResBody>;
128    type Error = S::Error;
129    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
130
131    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
132        self.inner.poll_ready(cx)
133    }
134
135    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
136        if self.config.exceeds_limit(&req) {
137            let status = self.config.status;
138            return Box::pin(async move {
139                Ok(Response::builder()
140                    .status(status)
141                    .body(ResBody::default())
142                    .expect("error response is valid"))
143            });
144        }
145        Box::pin(self.inner.call(req))
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, routing::post, Router};
    use http::StatusCode;
    use tower::ServiceExt;

    /// Router with a single `POST /upload` route (answering `"ok"`) behind `layer`.
    fn router_with(layer: BodyLimitLayer) -> Router {
        Router::new()
            .route("/upload", post(|| async { "ok" }))
            .layer(layer)
    }

    /// Empty-bodied `POST /upload`, with a `Content-Length` header only when one is given.
    fn upload_request(content_length: Option<u64>) -> http::Request<Body> {
        let mut builder = http::Request::builder().method("POST").uri("/upload");
        if let Some(len) = content_length {
            builder = builder.header(CONTENT_LENGTH, len);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Sends one upload request through `app` and reports the response status.
    async fn status_of(app: Router, content_length: Option<u64>) -> StatusCode {
        app.oneshot(upload_request(content_length))
            .await
            .unwrap()
            .status()
    }

    /// Shorthand for the common case: a layer configured only with a byte limit.
    async fn status_with_limit(limit: u64, content_length: Option<u64>) -> StatusCode {
        let layer = BodyLimitLayer::new(BodyLimit::new().limit(limit));
        status_of(router_with(layer), content_length).await
    }

    #[tokio::test]
    async fn rejects_when_content_length_exceeds_limit() {
        let status = status_with_limit(1024, Some(2048)).await;
        assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn passes_when_content_length_within_limit() {
        let status = status_with_limit(1024, Some(512)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn passes_at_exact_limit() {
        let status = status_with_limit(1024, Some(1024)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn rejects_one_byte_over_limit() {
        let status = status_with_limit(1024, Some(1025)).await;
        assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn passes_when_no_content_length_header() {
        let status = status_with_limit(1024, None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_status_code() {
        let config = BodyLimit::new()
            .limit(512)
            .status(StatusCode::UNPROCESSABLE_ENTITY);
        let status = status_of(router_with(BodyLimitLayer::new(config)), Some(1024)).await;
        assert_eq!(status, StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn default_layer_uses_413_and_one_mib_limit() {
        let app = router_with(BodyLimitLayer::default());

        let at_limit = status_of(app.clone(), Some(1_048_576)).await;
        assert_eq!(at_limit, StatusCode::OK);

        let over_limit = status_of(app, Some(1_048_577)).await;
        assert_eq!(over_limit, StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn zero_limit_rejects_any_body() {
        let status = status_with_limit(0, Some(1)).await;
        assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
    }
}