1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
use self::private::DefaultBodyLimitService;
use tower_layer::Layer;

/// Layer for configuring the default request body limit.
///
/// For security reasons, [`Bytes`] will, by default, not accept bodies larger than 2MB. This also
/// applies to extractors that use [`Bytes`] internally such as `String`, [`Json`], and [`Form`].
///
/// This middleware provides ways to configure that.
///
/// Note that if an extractor consumes the body directly with [`Body::poll_frame`], or similar, the
/// default limit is _not_ applied.
///
/// # Difference between `DefaultBodyLimit` and [`RequestBodyLimit`]
///
/// `DefaultBodyLimit` and [`RequestBodyLimit`] serve similar functions but in different ways.
///
/// `DefaultBodyLimit` is local in that it only applies to [`FromRequest`] implementations that
/// explicitly apply it (or call another extractor that does). You can apply the limit with
/// [`RequestExt::with_limited_body`] or [`RequestExt::into_limited_body`].
///
/// [`RequestBodyLimit`] is applied globally to all requests, regardless of which extractors are
/// used or how the body is consumed.
///
/// # Example
///
/// ```
/// use axum::{
///     Router,
///     routing::post,
///     extract::{Request, DefaultBodyLimit},
/// };
///
/// let app = Router::new()
///     .route("/", post(|request: Request| async {}))
///     // change the default limit
///     .layer(DefaultBodyLimit::max(1024));
/// # let _: Router = app;
/// ```
///
/// In general, using `DefaultBodyLimit` is recommended, but if you need to use third party
/// extractors and want to make sure a limit is also applied there, then [`RequestBodyLimit`]
/// should be used.
///
/// # Different limits for different routes
///
/// `DefaultBodyLimit` can also be selectively applied to have different limits for different
/// routes:
///
/// ```
/// use axum::{
///     Router,
///     routing::post,
///     extract::{Request, DefaultBodyLimit},
/// };
///
/// let app = Router::new()
///     // this route has a different limit
///     .route("/", post(|request: Request| async {}).layer(DefaultBodyLimit::max(1024)))
///     // this route still has the default limit
///     .route("/foo", post(|request: Request| async {}));
/// # let _: Router = app;
/// ```
///
/// [`Body::poll_frame`]: http_body::Body::poll_frame
/// [`Bytes`]: bytes::Bytes
/// [`Json`]: https://docs.rs/axum/0.7/axum/struct.Json.html
/// [`Form`]: https://docs.rs/axum/0.7/axum/struct.Form.html
/// [`FromRequest`]: crate::extract::FromRequest
/// [`RequestBodyLimit`]: tower_http::limit::RequestBodyLimit
/// [`RequestExt::with_limited_body`]: crate::RequestExt::with_limited_body
/// [`RequestExt::into_limited_body`]: crate::RequestExt::into_limited_body
#[derive(Debug, Clone)]
#[must_use]
pub struct DefaultBodyLimit {
    // The setting chosen through `disable()` or `max()`; the `Layer` impl copies it into every
    // service this layer wraps.
    kind: DefaultBodyLimitKind,
}

/// The setting carried by a [`DefaultBodyLimit`] layer.
///
/// `DefaultBodyLimitService` inserts a copy of this value into the extensions of each request it
/// forwards. The code that reads the extension back is not part of this file.
#[derive(Debug, Clone, Copy)]
pub(crate) enum DefaultBodyLimitKind {
    /// Produced by [`DefaultBodyLimit::disable`]: the default limit is switched off.
    Disable,
    /// Produced by [`DefaultBodyLimit::max`]: the limit to use instead of the default, in bytes.
    Limit(usize),
}

impl DefaultBodyLimit {
    /// Disable the default request body limit.
    ///
    /// This must be used to receive bodies larger than the default limit of 2MB using [`Bytes`] or
    /// an extractor built on it such as `String`, [`Json`], [`Form`].
    ///
    /// Note that if you're accepting data from untrusted remotes it is recommended to add your own
    /// limit such as [`tower_http::limit`].
    ///
    /// This is a `const fn`, so the layer can also be created in `const` and `static` items.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     Router,
    ///     routing::get,
    ///     body::Bytes,
    ///     extract::DefaultBodyLimit,
    /// };
    /// use tower_http::limit::RequestBodyLimitLayer;
    ///
    /// let app: Router<()> = Router::new()
    ///     .route("/", get(|body: Bytes| async {}))
    ///     // Disable the default limit
    ///     .layer(DefaultBodyLimit::disable())
    ///     // Set a different limit
    ///     .layer(RequestBodyLimitLayer::new(10 * 1000 * 1000));
    /// ```
    ///
    /// [`Bytes`]: bytes::Bytes
    /// [`Json`]: https://docs.rs/axum/0.7/axum/struct.Json.html
    /// [`Form`]: https://docs.rs/axum/0.7/axum/struct.Form.html
    pub const fn disable() -> Self {
        Self {
            kind: DefaultBodyLimitKind::Disable,
        }
    }

    /// Set the default request body limit.
    ///
    /// By default, [`Bytes::from_request`] (and other extractors built on top of it such as
    /// `String`, [`Json`], and [`Form`]) limits request bodies to 2MB. This method can be used to
    /// change that limit. `limit` is the new maximum body size in bytes.
    ///
    /// This is a `const fn`, so the layer can also be created in `const` and `static` items.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     Router,
    ///     routing::get,
    ///     body::Bytes,
    ///     extract::DefaultBodyLimit,
    /// };
    ///
    /// let app: Router<()> = Router::new()
    ///     .route("/", get(|body: Bytes| async {}))
    ///     // Replace the default of 2MB with 1024 bytes.
    ///     .layer(DefaultBodyLimit::max(1024));
    /// ```
    ///
    /// [`Bytes::from_request`]: bytes::Bytes
    /// [`Json`]: https://docs.rs/axum/0.7/axum/struct.Json.html
    /// [`Form`]: https://docs.rs/axum/0.7/axum/struct.Form.html
    pub const fn max(limit: usize) -> Self {
        Self {
            kind: DefaultBodyLimitKind::Limit(limit),
        }
    }
}

/// Wraps a service so that every request passing through it is tagged with this layer's setting.
impl<S> Layer<S> for DefaultBodyLimit {
    type Service = DefaultBodyLimitService<S>;

    /// Builds the wrapping service around `inner`.
    ///
    /// The setting is copied rather than borrowed, so the returned service does not depend on
    /// this layer after the call.
    fn layer(&self, inner: S) -> Self::Service {
        let kind = self.kind;
        DefaultBodyLimitService { inner, kind }
    }
}

mod private {
    use super::DefaultBodyLimitKind;
    use http::Request;
    use std::task::Context;
    use tower_service::Service;

    #[derive(Debug, Clone, Copy)]
    pub struct DefaultBodyLimitService<S> {
        pub(super) inner: S,
        pub(super) kind: DefaultBodyLimitKind,
    }

    impl<B, S> Service<Request<B>> for DefaultBodyLimitService<S>
    where
        S: Service<Request<B>>,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = S::Future;

        #[inline]
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        #[inline]
        fn call(&mut self, mut req: Request<B>) -> Self::Future {
            req.extensions_mut().insert(self.kind);
            self.inner.call(req)
        }
    }
}