1use std::{
2 fmt::{self, Debug, Formatter},
3 str::FromStr,
4};
5
6use futures_util::TryStreamExt;
7use mime::Mime;
8#[cfg(feature = "tempfile")]
9use tokio::fs::File;
10use tokio::io::{AsyncRead, AsyncReadExt};
11#[cfg(feature = "tempfile")]
12use tokio::io::{AsyncSeekExt, SeekFrom};
13
14use crate::{FromRequest, Request, RequestBody, Result, error::ParseMultipartError, http::header};
15
/// A single field of a `multipart/form-data` body, as yielded by
/// [`Multipart::next_field`].
///
/// Wraps a `multer::Field`. Consuming accessors ([`Field::bytes`],
/// [`Field::text`], [`Field::into_async_read`]) take `self`, so a field's
/// content can be read only once.
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
pub struct Field(multer::Field<'static>);
19
20impl Debug for Field {
21 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
22 let mut d = f.debug_struct("Field");
23
24 if let Some(name) = self.name() {
25 d.field("name", &name);
26 }
27
28 if let Some(file_name) = self.file_name() {
29 d.field("file_name", &file_name);
30 }
31
32 if let Some(content_type) = self.content_type() {
33 d.field("content_type", &content_type);
34 }
35
36 d.finish()
37 }
38}
39
40impl Field {
41 #[inline]
43 pub fn content_type(&self) -> Option<&str> {
44 self.0.content_type().map(|mime| mime.essence_str())
45 }
46
47 #[inline]
49 pub fn file_name(&self) -> Option<&str> {
50 self.0.file_name()
51 }
52
53 #[inline]
55 pub fn name(&self) -> Option<&str> {
56 self.0.name()
57 }
58
59 pub async fn bytes(self) -> Result<Vec<u8>, ParseMultipartError> {
61 let mut data = Vec::new();
62 let mut buf = [0; 2048];
63 let mut reader = self.into_async_read();
64 loop {
65 let sz = reader.read(&mut buf[..]).await?;
66 if sz > 0 {
67 data.extend_from_slice(&buf[..sz]);
68 } else {
69 break;
70 }
71 }
72
73 Ok(data)
74 }
75
76 #[inline]
78 pub async fn text(self) -> Result<String, ParseMultipartError> {
79 Ok(String::from_utf8(self.bytes().await?)?)
80 }
81
82 #[cfg(feature = "tempfile")]
84 #[cfg_attr(docsrs, doc(cfg(feature = "tempfile")))]
85 pub async fn tempfile(self) -> Result<File, ParseMultipartError> {
86 let mut reader = self.into_async_read();
87 let mut file = tokio::fs::File::from_std(::libtempfile::tempfile()?);
88 tokio::io::copy(&mut reader, &mut file).await?;
89 file.seek(SeekFrom::Start(0)).await?;
90 Ok(file)
91 }
92
93 pub fn into_async_read(self) -> impl AsyncRead + Send {
95 tokio_util::io::StreamReader::new(
96 self.0.map_err(|err| std::io::Error::other(err.to_string())),
97 )
98 }
99}
100
/// Extractor for `multipart/form-data` request bodies.
///
/// Extraction fails if the request has no parseable `Content-Type` header,
/// if its essence is not `multipart/form-data`, or if no boundary can be
/// parsed from it. Fields are then pulled one at a time with
/// [`Multipart::next_field`].
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
pub struct Multipart {
    // Streaming parser over the (taken) request body.
    inner: multer::Multipart<'static>,
}
130
131impl<'a> FromRequest<'a> for Multipart {
132 async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result<Self> {
133 let content_type = req
134 .headers()
135 .get(header::CONTENT_TYPE)
136 .and_then(|err| err.to_str().ok())
137 .and_then(|value| Mime::from_str(value).ok())
138 .ok_or(ParseMultipartError::ContentTypeRequired)?;
139
140 if content_type.essence_str() != mime::MULTIPART_FORM_DATA {
141 return Err(ParseMultipartError::InvalidContentType(
142 content_type.essence_str().to_string(),
143 )
144 .into());
145 }
146
147 let boundary = multer::parse_boundary(content_type.as_ref())
148 .map_err(ParseMultipartError::Multipart)?;
149 Ok(Self {
150 inner: multer::Multipart::new(
151 tokio_util::io::ReaderStream::new(body.take()?.into_async_read()),
152 boundary,
153 ),
154 })
155 }
156}
157
158impl Multipart {
159 pub async fn next_field(&mut self) -> Result<Option<Field>, ParseMultipartError> {
161 match self.inner.next_field().await? {
162 Some(field) => Ok(Some(Field(field))),
163 None => Ok(None),
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{handler, http::StatusCode, test::TestClient};

    /// A `multipart/*` type other than `form-data` must be rejected with 415.
    #[tokio::test]
    async fn test_multipart_extractor_content_type() {
        #[handler(internal)]
        async fn index(_multipart: Multipart) {
            todo!()
        }

        let client = TestClient::new(index);
        client
            .post("/")
            .header("content-type", "multipart/json; boundary=X-BOUNDARY")
            .body(())
            .send()
            .await
            .assert_status(StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    /// A text field followed by a file field are both parsed with their
    /// metadata and content intact (including embedded CR/LF bytes).
    #[tokio::test]
    async fn test_multipart_extractor() {
        const FORM_BODY: &str = concat!(
            "--X-BOUNDARY\r\n",
            "Content-Disposition: form-data; name=\"my_text_field\"\r\n",
            "\r\n",
            "abcd\r\n",
            "--X-BOUNDARY\r\n",
            "Content-Disposition: form-data; name=\"my_file_field\"; filename=\"a-text-file.txt\"\r\n",
            "Content-Type: text/plain\r\n",
            "\r\n",
            "Hello world\nHello\r\nWorld\rAgain\r\n",
            "--X-BOUNDARY--\r\n",
        );

        #[handler(internal)]
        async fn index(mut multipart: Multipart) {
            let text_field = multipart.next_field().await.unwrap().unwrap();
            assert_eq!(text_field.name(), Some("my_text_field"));
            assert_eq!(text_field.text().await.unwrap(), "abcd");

            let file_field = multipart.next_field().await.unwrap().unwrap();
            assert_eq!(file_field.name(), Some("my_file_field"));
            assert_eq!(file_field.file_name(), Some("a-text-file.txt"));
            assert_eq!(file_field.content_type(), Some("text/plain"));
            assert_eq!(
                file_field.text().await.unwrap(),
                "Hello world\nHello\r\nWorld\rAgain"
            );
        }

        let client = TestClient::new(index);
        client
            .post("/")
            .header("content-type", "multipart/form-data; boundary=X-BOUNDARY")
            .body(FORM_BODY)
            .send()
            .await
            .assert_status_is_ok();
    }
}