1use std::ops::Deref;
17
18use bytes::Bytes;
19use poem::error::ReadBodyError;
20use poem::http::StatusCode;
21use poem::{Error, FromRequest, Request, RequestBody, Result};
22
/// Per-request override for the maximum number of body bytes [`RawBody`]
/// will buffer.
///
/// When a value of this type is present in the request extensions, the
/// `FromRequest` impl for [`RawBody`] uses it instead of
/// [`RawBody::DEFAULT_LIMIT`]. A body longer than the limit is rejected with
/// `413 Payload Too Large`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RawBodyLimit(pub usize);
29
30#[derive(Debug, Clone)]
33pub struct RawBody(pub Bytes);
34
35impl RawBody {
36 pub const DEFAULT_LIMIT: usize = 2 * 1024 * 1024;
39
40 pub fn into_inner(self) -> Bytes {
42 self.0
43 }
44
45 pub async fn extract_with_limit(body: &mut RequestBody, limit: usize) -> Result<Self> {
48 let raw = body.take()?;
49 match raw.into_bytes_limit(limit).await {
50 Ok(bytes) => Ok(Self(bytes)),
51 Err(ReadBodyError::PayloadTooLarge) => {
52 Err(Error::from_status(StatusCode::PAYLOAD_TOO_LARGE))
53 }
54 Err(err) => Err(err.into()),
55 }
56 }
57}
58
59impl Deref for RawBody {
60 type Target = Bytes;
61 fn deref(&self) -> &Bytes {
62 &self.0
63 }
64}
65
66impl<'a> FromRequest<'a> for RawBody {
67 async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result<Self> {
68 let limit = req
69 .extensions()
70 .get::<RawBodyLimit>()
71 .map(|l| l.0)
72 .unwrap_or(Self::DEFAULT_LIMIT);
73 Self::extract_with_limit(body, limit).await
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    use poem::Body;

    /// Builds a request carrying `payload` and splits it into the parts an
    /// extractor receives.
    fn request_with_body(payload: impl Into<Body>) -> (Request, RequestBody) {
        Request::builder().body(payload).split()
    }

    /// Like `request_with_body`, but also attaches a `RawBodyLimit(limit)`
    /// extension before splitting.
    fn request_with_limit(payload: impl Into<Body>, limit: usize) -> (Request, RequestBody) {
        let mut req = Request::builder().body(payload);
        req.extensions_mut().insert(RawBodyLimit(limit));
        req.split()
    }

    #[tokio::test]
    async fn happy_path_reads_the_full_payload() {
        let (req, mut body) = request_with_body("hello world");
        let raw = RawBody::from_request(&req, &mut body).await.expect("read");
        assert_eq!(&raw.0[..], b"hello world");
        assert_eq!(raw.len(), 11);
    }

    #[tokio::test]
    async fn empty_body_yields_empty_bytes() {
        let (req, mut body) = request_with_body(Body::empty());
        let raw = RawBody::from_request(&req, &mut body).await.expect("read");
        assert!(raw.0.is_empty());
    }

    #[tokio::test]
    async fn oversize_body_returns_413_payload_too_large() {
        let payload = vec![b'x'; RawBody::DEFAULT_LIMIT + 1];
        let (req, mut body) = request_with_body(payload);
        let err = RawBody::from_request(&req, &mut body)
            .await
            .expect_err("over the cap");
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn extract_with_limit_enforces_the_caller_cap() {
        let payload = vec![b'x'; 64];
        let (_req, mut body) = request_with_body(payload);
        let err = RawBody::extract_with_limit(&mut body, 32)
            .await
            .expect_err("over the tight cap");
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn extract_with_limit_passes_when_payload_fits() {
        let payload = vec![b'x'; 32];
        let (_req, mut body) = request_with_body(payload);
        let raw = RawBody::extract_with_limit(&mut body, 32)
            .await
            .expect("fits");
        assert_eq!(raw.0.len(), 32);
    }

    #[tokio::test]
    async fn request_extension_limit_overrides_the_default() {
        let (req, mut body) = request_with_limit(vec![b'x'; 64], 32);
        let err = RawBody::from_request(&req, &mut body)
            .await
            .expect_err("over the extension cap");
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn request_extension_limit_passes_when_payload_fits() {
        let (req, mut body) = request_with_limit(vec![b'x'; 32], 32);
        let raw = RawBody::from_request(&req, &mut body).await.expect("fits");
        assert_eq!(raw.0.len(), 32);
    }

    #[tokio::test]
    async fn request_extension_limit_can_raise_the_cap_above_the_default() {
        let size = RawBody::DEFAULT_LIMIT + 1;
        let (req, mut body) = request_with_limit(vec![b'x'; size], size);
        let raw = RawBody::from_request(&req, &mut body)
            .await
            .expect("fits the raised cap");
        assert_eq!(raw.len(), size);
    }

    #[tokio::test]
    async fn missing_extension_falls_back_to_default_limit() {
        // Exactly DEFAULT_LIMIT bytes must be accepted without an extension.
        // Together with the oversize test, this pins the fallback cap to
        // DEFAULT_LIMIT rather than some other value.
        let (req, mut body) = request_with_body(vec![b'x'; RawBody::DEFAULT_LIMIT]);
        let raw = RawBody::from_request(&req, &mut body).await.expect("fits");
        assert_eq!(raw.len(), RawBody::DEFAULT_LIMIT);
    }

    #[tokio::test]
    async fn second_extraction_fails_once_the_body_is_taken() {
        let (req, mut body) = request_with_body("once");
        RawBody::from_request(&req, &mut body)
            .await
            .expect("first read");
        assert!(RawBody::from_request(&req, &mut body).await.is_err());
    }
}