poem/web/
multipart.rs

1use std::{
2    fmt::{self, Debug, Formatter},
3    str::FromStr,
4};
5
6use futures_util::TryStreamExt;
7use mime::Mime;
8#[cfg(feature = "tempfile")]
9use tokio::fs::File;
10use tokio::io::{AsyncRead, AsyncReadExt};
11#[cfg(feature = "tempfile")]
12use tokio::io::{AsyncSeekExt, SeekFrom};
13
14use crate::{FromRequest, Request, RequestBody, Result, error::ParseMultipartError, http::header};
15
16/// A single field in a multipart stream.
17#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
18pub struct Field(multer::Field<'static>);
19
20impl Debug for Field {
21    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
22        let mut d = f.debug_struct("Field");
23
24        if let Some(name) = self.name() {
25            d.field("name", &name);
26        }
27
28        if let Some(file_name) = self.file_name() {
29            d.field("file_name", &file_name);
30        }
31
32        if let Some(content_type) = self.content_type() {
33            d.field("content_type", &content_type);
34        }
35
36        d.finish()
37    }
38}
39
40impl Field {
41    /// Get the content type of the field.
42    #[inline]
43    pub fn content_type(&self) -> Option<&str> {
44        self.0.content_type().map(|mime| mime.essence_str())
45    }
46
47    /// The file name found in the `Content-Disposition` header.
48    #[inline]
49    pub fn file_name(&self) -> Option<&str> {
50        self.0.file_name()
51    }
52
53    /// The name found in the `Content-Disposition` header.
54    #[inline]
55    pub fn name(&self) -> Option<&str> {
56        self.0.name()
57    }
58
59    /// Get the full data of the field as bytes.
60    pub async fn bytes(self) -> Result<Vec<u8>, ParseMultipartError> {
61        let mut data = Vec::new();
62        let mut buf = [0; 2048];
63        let mut reader = self.into_async_read();
64        loop {
65            let sz = reader.read(&mut buf[..]).await?;
66            if sz > 0 {
67                data.extend_from_slice(&buf[..sz]);
68            } else {
69                break;
70            }
71        }
72
73        Ok(data)
74    }
75
76    /// Get the full field data as text.
77    #[inline]
78    pub async fn text(self) -> Result<String, ParseMultipartError> {
79        Ok(String::from_utf8(self.bytes().await?)?)
80    }
81
82    /// Write the full field data to a temporary file and return it.
83    #[cfg(feature = "tempfile")]
84    #[cfg_attr(docsrs, doc(cfg(feature = "tempfile")))]
85    pub async fn tempfile(self) -> Result<File, ParseMultipartError> {
86        let mut reader = self.into_async_read();
87        let mut file = tokio::fs::File::from_std(::libtempfile::tempfile()?);
88        tokio::io::copy(&mut reader, &mut file).await?;
89        file.seek(SeekFrom::Start(0)).await?;
90        Ok(file)
91    }
92
93    /// Consume this field to return a reader.
94    pub fn into_async_read(self) -> impl AsyncRead + Send {
95        tokio_util::io::StreamReader::new(
96            self.0.map_err(|err| std::io::Error::other(err.to_string())),
97        )
98    }
99}
100
101/// An extractor that parses `multipart/form-data` requests commonly used with
102/// file uploads.
103///
104/// # Errors
105///
106/// - [`ReadBodyError`](crate::error::ReadBodyError)
107/// - [`ParseMultipartError`]
108///
109/// # Example
110///
111/// ```
112/// use poem::{
113///     Result,
114///     error::{BadRequest, Error},
115///     web::Multipart,
116/// };
117///
118/// async fn upload(mut multipart: Multipart) -> Result<()> {
119///     while let Some(field) = multipart.next_field().await? {
120///         let data = field.bytes().await.map_err(BadRequest)?;
121///         println!("{} bytes", data.len());
122///     }
123///     Ok(())
124/// }
125/// ```
126#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
127pub struct Multipart {
128    inner: multer::Multipart<'static>,
129}
130
131impl<'a> FromRequest<'a> for Multipart {
132    async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result<Self> {
133        let content_type = req
134            .headers()
135            .get(header::CONTENT_TYPE)
136            .and_then(|err| err.to_str().ok())
137            .and_then(|value| Mime::from_str(value).ok())
138            .ok_or(ParseMultipartError::ContentTypeRequired)?;
139
140        if content_type.essence_str() != mime::MULTIPART_FORM_DATA {
141            return Err(ParseMultipartError::InvalidContentType(
142                content_type.essence_str().to_string(),
143            )
144            .into());
145        }
146
147        let boundary = multer::parse_boundary(content_type.as_ref())
148            .map_err(ParseMultipartError::Multipart)?;
149        Ok(Self {
150            inner: multer::Multipart::new(
151                tokio_util::io::ReaderStream::new(body.take()?.into_async_read()),
152                boundary,
153            ),
154        })
155    }
156}
157
158impl Multipart {
159    /// Yields the next [`Field`] if available.
160    pub async fn next_field(&mut self) -> Result<Option<Field>, ParseMultipartError> {
161        match self.inner.next_field().await? {
162            Some(field) => Ok(Some(Field(field))),
163            None => Ok(None),
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{handler, http::StatusCode, test::TestClient};

    /// A request whose content type is not `multipart/form-data` is answered
    /// with `415 Unsupported Media Type`.
    #[tokio::test]
    async fn test_multipart_extractor_content_type() {
        // The handler body is not expected to run: the test asserts an error
        // status, and reaching `todo!()` would panic instead.
        #[handler(internal)]
        async fn index(_multipart: Multipart) {
            todo!()
        }

        let client = TestClient::new(index);
        let response = client
            .post("/")
            .header("content-type", "multipart/json; boundary=X-BOUNDARY")
            .body(())
            .send()
            .await;

        response.assert_status(StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    /// A well-formed two-part body yields a text field followed by a file
    /// field, each with the expected metadata and content.
    #[tokio::test]
    async fn test_multipart_extractor() {
        #[handler(internal)]
        async fn index(mut multipart: Multipart) {
            // First part: a plain text field.
            let text_part = multipart.next_field().await.unwrap().unwrap();
            assert_eq!(text_part.name(), Some("my_text_field"));
            assert_eq!(text_part.text().await.unwrap(), "abcd");

            // Second part: a file upload whose content mixes line endings.
            let file_part = multipart.next_field().await.unwrap().unwrap();
            assert_eq!(file_part.name(), Some("my_file_field"));
            assert_eq!(file_part.file_name(), Some("a-text-file.txt"));
            assert_eq!(file_part.content_type(), Some("text/plain"));
            assert_eq!(
                file_part.text().await.unwrap(),
                "Hello world\nHello\r\nWorld\rAgain"
            );
        }

        let payload = "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_file_field\"; filename=\"a-text-file.txt\"\r\nContent-Type: text/plain\r\n\r\nHello world\nHello\r\nWorld\rAgain\r\n--X-BOUNDARY--\r\n";

        let client = TestClient::new(index);
        let response = client
            .post("/")
            .header("content-type", "multipart/form-data; boundary=X-BOUNDARY")
            .body(payload)
            .send()
            .await;

        response.assert_status_is_ok();
    }
}