1use std::{
2 io,
3 pin::Pin,
4 str::{FromStr, Utf8Error},
5 task::{ready, Context, Poll},
6};
7
8use memchr::{memchr, memmem::find, memrchr};
9use rocket::{
10 data::{DataStream, FromData, Outcome, ToByteUnit},
11 futures::StreamExt,
12 http::{ContentType, Header, HeaderMap, Status},
13 tokio::io::{AsyncBufRead, AsyncRead, ReadBuf},
14 Data, Request,
15};
16use thiserror::Error;
17use tokio_util::{
18 bytes::{BufMut, Bytes, BytesMut},
19 codec::{Decoder, FramedRead},
20};
21
/// Result alias used throughout this file; the error type defaults to [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
23
/// Errors that can occur while reading a multipart stream.
#[derive(Debug, Error)]
pub enum Error {
    /// An I/O error, reported transparently.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// Bytes that were decoded with `std::str::from_utf8` (header names and
    /// values) were not valid UTF-8.
    #[error(transparent)]
    Encoding(#[from] Utf8Error),
    /// A `serde_json` error; only available with the `json` feature.
    #[cfg(feature = "json")]
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// The request's content type carried no `boundary` parameter.
    #[error("The content type of a multipart stream must specify a boundary")]
    BoundaryNotSpecified,
}
43
/// A single section of a multipart stream: the headers that were parsed for
/// it, plus a mutable borrow of the reader from which its body is pulled.
pub struct MultipartReadSection<'r, 'a> {
    /// Headers collected for this section.
    headers: HeaderMap<'static>,
    /// The parent reader; body frames come from its stream and its buffer.
    reader: &'a mut MultipartReader<'r>,
}
52
53impl<'a> MultipartReadSection<'_, 'a> {
54 pub fn headers(&self) -> &HeaderMap<'static> {
56 &self.headers
57 }
58
59 pub fn content_type(&self) -> Option<ContentType> {
61 let s = self.headers.get_one("Content-Type")?;
62 ContentType::from_str(s).ok()
63 }
64
65 pub async fn to_bytes(self) -> Result<Bytes> {
70 let mut raw_data = BytesMut::new();
71 while let MultipartFrame::Data(bytes) = &mut self.reader.buffer {
72 raw_data.unsplit(bytes.split());
73 match self.reader.stream.next().await {
74 Some(Ok(next)) => self.reader.buffer = next,
75 Some(Err(e)) => {
76 self.reader.buffer = MultipartFrame::End;
77 return Err(e);
78 }
79 None => self.reader.buffer = MultipartFrame::End,
80 }
81 }
82 Ok(raw_data.freeze())
83 }
84
85 #[cfg(feature = "json")]
89 pub async fn json<T: serde::de::DeserializeOwned>(self) -> Result<T> {
90 let bytes = self.to_bytes().await?;
91 Ok(serde_json::from_slice(&bytes)?)
92 }
93}
94
95impl AsyncRead for MultipartReadSection<'_, '_> {
96 fn poll_read(
97 self: Pin<&mut Self>,
98 cx: &mut Context<'_>,
99 buf: &mut ReadBuf<'_>,
100 ) -> Poll<io::Result<()>> {
101 let mut this = self.get_mut();
103 match Pin::new(&mut this).poll_fill_buf(cx) {
104 Poll::Ready(Ok(buffer)) => {
105 let write_buf = buf.initialize_unfilled();
106 let len = buffer.len().min(write_buf.len());
107 write_buf[..len].copy_from_slice(&buffer[..len]);
108 unsafe { buf.advance_mut(len) };
109 Pin::new(this).consume(len);
110 Poll::Ready(Ok(()))
111 }
112 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
113 Poll::Pending => return Poll::Pending,
114 }
115 }
116}
117
118impl AsyncBufRead for MultipartReadSection<'_, '_> {
119 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
120 let this = self.get_mut();
121 let buffer = &mut this.reader.buffer;
122 if buffer.is_empty() {
123 match ready!(this.reader.stream.poll_next_unpin(cx)) {
124 Some(Ok(by)) => *buffer = by,
125 None => *buffer = MultipartFrame::End,
126 Some(Err(e)) => return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
127 }
128 }
129 if let MultipartFrame::Data(data) = buffer {
131 return Poll::Ready(Ok(data));
132 } else {
133 return Poll::Ready(Ok(b""));
134 }
135 }
136
137 fn consume(self: Pin<&mut Self>, amt: usize) {
138 let this = self.get_mut();
139 if let MultipartFrame::Data(data) = &mut this.reader.buffer {
140 let _ = data.split_to(amt.min(data.len()));
141 }
142 }
143}
144
/// Parsing phase of [`MultipartDecoder`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
enum MultipartDecoderState {
    /// Scanning for the first boundary; bytes skipped here are not emitted.
    BeforeFirstBoundary,
    /// Reading a section's header lines, one CRLF-terminated line at a time.
    Headers,
    /// Reading a section's body until the next boundary.
    Data,
    /// The closing boundary was seen; any remaining input is discarded.
    End,
}
/// Decoder that splits a raw byte stream into [`MultipartFrame`]s.
struct MultipartDecoder<'r> {
    /// Current parsing phase.
    state: MultipartDecoderState,
    /// Boundary string searched for in the input; the decoder checks for the
    /// leading `--` itself, so it is not part of this value.
    boundary: &'r str,
}
156
/// One unit of decoded multipart input.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
enum MultipartFrame {
    /// A boundary line that opens a new section was consumed.
    Boundary,
    /// A single parsed header line of the current section.
    Header(Header<'static>),
    /// A chunk of the current section's body.
    Data(BytesMut),
    /// Marker stored in the reader's buffer once the stream is finished.
    End,
}
164
165impl MultipartFrame {
166 fn is_empty(&self) -> bool {
167 if let Self::Data(v) = self {
168 v.is_empty()
169 } else {
170 false
171 }
172 }
173}
174
/// Extra capacity, in bytes, that the decoder reserves in its input buffer
/// when it needs more input.
const CHUNK_SIZE: usize = 1024;
176
177impl MultipartDecoder<'_> {
178 fn parse_header(header: &[u8]) -> Result<Header<'static>> {
179 if let Some(middle) = memchr(b':', header) {
180 Ok(Header::new(
181 std::str::from_utf8(&header[..middle])?.to_owned(),
182 std::str::from_utf8(&header[middle + 1..])?
183 .trim()
184 .to_owned(),
185 ))
186 } else {
187 todo!()
189 }
190 }
191}
192
193impl Decoder for MultipartDecoder<'_> {
194 type Item = MultipartFrame;
195 type Error = Error;
196
197 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
198 loop {
199 match self.state {
200 MultipartDecoderState::BeforeFirstBoundary => {
201 if let Some(pos) = find(src, self.boundary.as_bytes()) {
202 if pos < 2 {
203 let _ = src.split_to(pos + 1);
204 continue;
205 }
206 if pos + self.boundary.len() + 2 > src.len() {
207 let _ = src.split_to(pos - 4);
208 src.reserve(CHUNK_SIZE);
209 return Ok(None);
210 }
211 if &src[pos - 2..pos] == b"--"
212 && &src[pos + self.boundary.len()..][..2] == b"--"
213 {
214 self.state = MultipartDecoderState::End;
215 continue;
216 }
217 if &src[pos - 2..pos] != b"--"
218 && &src[pos + self.boundary.len()..][..2] != b"\r\n"
219 {
220 let _ = src.split_to(pos + self.boundary.len());
221 } else {
222 let _ = src.split_to(pos + self.boundary.len() + "\r\n".len());
223 self.state = MultipartDecoderState::Headers;
224 return Ok(Some(MultipartFrame::Boundary));
225 }
226 } else {
227 src.reserve(CHUNK_SIZE);
228 return Ok(None);
229 }
230 }
231 MultipartDecoderState::Headers => {
232 if let Some(end) = find(src, b"\r\n") {
233 let header = src.split_to(end + "\r\n".len());
234 if end == 0 {
235 self.state = MultipartDecoderState::Data;
236 continue;
237 }
238 return Ok(Some(MultipartFrame::Header(Self::parse_header(
239 &header[..],
240 )?)));
241 } else {
242 src.reserve(CHUNK_SIZE);
243 return Ok(None);
244 }
245 }
246 MultipartDecoderState::Data => {
247 if let Some(pos) = find(src, self.boundary.as_bytes()) {
248 if pos < 4 {
249 let data = src.split_to(pos + 1);
250 return Ok(Some(MultipartFrame::Data(data)));
251 }
252 if pos + self.boundary.len() + 2 > src.len() {
253 let data = src.split_to(pos - 4);
254 src.reserve(CHUNK_SIZE);
255 if data.is_empty() {
256 return Ok(None);
257 } else {
258 return Ok(Some(MultipartFrame::Data(data)));
259 }
260 }
261 if &src[pos - 4..pos] == b"\r\n--"
262 && &src[pos + self.boundary.len()..][..2] == b"--"
263 {
264 self.state = MultipartDecoderState::End;
265 if pos - 4 > 0 {
266 let data = src.split_to(pos - 4);
267 return Ok(Some(MultipartFrame::Data(data)));
268 }
269 continue;
270 }
271 if &src[pos - 4..pos] != b"\r\n--"
272 && &src[pos + self.boundary.len()..][..2] != b"\r\n"
273 {
274 let data = src.split_to(pos + self.boundary.len());
275 return Ok(Some(MultipartFrame::Data(data)));
276 } else {
277 if pos > 4 {
278 let data = src.split_to(pos - 4);
279 return Ok(Some(MultipartFrame::Data(data)));
280 }
281 let _ = src.split_to(pos + self.boundary.len() + "\r\n".len());
282 self.state = MultipartDecoderState::Headers;
283 return Ok(Some(MultipartFrame::Boundary));
284 }
285 } else {
286 let end = src
287 .len()
288 .saturating_sub(self.boundary.len() + 4) .max(memrchr(b'\r', src).unwrap_or(0));
290 let data = src.split_to(end);
291 src.reserve(CHUNK_SIZE);
292 if data.is_empty() {
293 return Ok(None);
294 } else {
295 return Ok(Some(MultipartFrame::Data(data)));
296 }
297 }
298 }
299 MultipartDecoderState::End => {
300 let _ = src.split();
301 src.reserve(CHUNK_SIZE);
302 return Ok(None);
303 }
304 }
305 }
306 }
307}
308
/// Reader over a multipart request body that yields one
/// [`MultipartReadSection`] at a time.
pub struct MultipartReader<'r> {
    /// The request body, decoded into frames.
    stream: FramedRead<DataStream<'r>, MultipartDecoder<'r>>,
    /// The most recently read frame that has not been fully consumed yet.
    buffer: MultipartFrame,
    /// Content type of the request this reader was created from.
    content_type: &'r ContentType,
}
341
342impl<'r> MultipartReader<'r> {
343 pub async fn next(&mut self) -> Result<Option<MultipartReadSection<'r, '_>>> {
347 while self.buffer != MultipartFrame::Boundary {
348 match self.stream.next().await {
349 Some(Ok(MultipartFrame::End)) | None => {
350 self.buffer = MultipartFrame::End;
351 return Ok(None);
352 }
353 Some(Ok(val)) => self.buffer = val,
354 Some(Err(e)) => return Err(e),
355 }
356 }
357
358 let mut headers = HeaderMap::new();
359 loop {
360 match self.stream.next().await {
361 Some(Ok(MultipartFrame::End)) | None => {
362 self.buffer = MultipartFrame::End;
363 if headers.is_empty() {
364 return Ok(None);
365 } else {
366 break;
367 }
368 }
369 Some(Ok(MultipartFrame::Header(header))) => headers.add(header),
370 Some(Ok(MultipartFrame::Boundary)) => {
371 self.buffer = MultipartFrame::Boundary;
372 break;
373 }
374 Some(Ok(val @ MultipartFrame::Data(_))) => {
375 self.buffer = val;
376 break;
377 }
378 Some(Err(e)) => return Err(e),
380 }
381 }
382 Ok(Some(MultipartReadSection {
383 headers,
384 reader: self,
385 }))
386 }
387
388 pub fn content_type(&self) -> &'r ContentType {
390 self.content_type
391 }
392}
393
394#[rocket::async_trait]
395impl<'r> FromData<'r> for MultipartReader<'r> {
396 type Error = Error;
397
398 async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> {
399 let limit = req
400 .rocket()
401 .config()
402 .limits
403 .get("file/multipart")
404 .unwrap_or(1.mebibytes());
405 if let Some(content_type) = req.content_type().filter(|c| c.top() == "multipart") {
406 if let Some(boundary) = content_type.param("boundary") {
407 Outcome::Success(Self {
408 stream: FramedRead::new(
409 data.open(limit),
410 MultipartDecoder {
411 state: MultipartDecoderState::BeforeFirstBoundary,
412 boundary,
413 },
414 ),
415 buffer: MultipartFrame::Data(BytesMut::new()),
416 content_type,
417 })
418 } else {
419 Outcome::Error((Status::BadRequest, Error::BoundaryNotSpecified))
420 }
421 } else {
422 Outcome::Forward((data, Status::BadRequest))
423 }
424 }
425}