modo/extractor/
multipart.rs1use std::collections::HashMap;
2
3use axum::extract::FromRequest;
4use http::Request;
5use serde::de::DeserializeOwned;
6
7use crate::error::Error;
8use crate::sanitize::Sanitize;
9
/// A single file part received in a `multipart/form-data` request.
///
/// The entire file body is buffered in memory (see `data`).
pub struct UploadedFile {
    /// File name taken from the part's headers, or `"unnamed"` when the part
    /// carried no file name. This value is supplied by the client, so sanitize
    /// it before using it as a filesystem path.
    pub name: String,
    /// MIME type from the part's `Content-Type`, or `application/octet-stream`
    /// when none was sent. It is copied as-is and not checked against the bytes.
    pub content_type: String,
    /// Length of `data` in bytes.
    pub size: usize,
    /// Raw file contents. This field is omitted from the `Debug` output.
    pub data: bytes::Bytes,
}
21
22impl std::fmt::Debug for UploadedFile {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("UploadedFile")
25 .field("name", &self.name)
26 .field("content_type", &self.content_type)
27 .field("size", &self.size)
28 .finish()
29 }
30}
31
32impl UploadedFile {
33 pub async fn from_field(
43 field: axum_extra::extract::multipart::Field,
44 ) -> crate::error::Result<Self> {
45 let name = field.file_name().unwrap_or("unnamed").to_string();
46 let content_type = field
47 .content_type()
48 .unwrap_or("application/octet-stream")
49 .to_string();
50 let data = field
51 .bytes()
52 .await
53 .map_err(|e| Error::bad_request(format!("failed to read file field: {e}")))?;
54 let size = data.len();
55 Ok(Self {
56 name,
57 content_type,
58 size,
59 data,
60 })
61 }
62
63 pub fn extension(&self) -> Option<String> {
69 let ext = self.name.rsplit('.').next()?;
70 if ext == self.name {
71 None
72 } else {
73 Some(ext.to_ascii_lowercase())
74 }
75 }
76
77 pub fn validate(&self) -> crate::extractor::UploadValidator<'_> {
84 crate::extractor::upload_validator::UploadValidator::new(self)
85 }
86}
87
/// Files uploaded in a multipart request, grouped by form field name.
///
/// A field name can occur several times in one request, so each name maps to
/// a list of files. When built by `MultipartRequest`, the list is in the order
/// the parts arrived.
pub struct Files(HashMap<String, Vec<UploadedFile>>);
95
96impl std::fmt::Debug for Files {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("Files")
99 .field("fields", &self.0.keys().collect::<Vec<_>>())
100 .finish()
101 }
102}
103
104impl Files {
105 pub fn from_map(map: HashMap<String, Vec<UploadedFile>>) -> Self {
107 Self(map)
108 }
109
110 pub fn get(&self, name: &str) -> Option<&UploadedFile> {
112 self.0.get(name).and_then(|v| v.first())
113 }
114
115 pub fn file(&mut self, name: &str) -> Option<UploadedFile> {
119 let files = self.0.get_mut(name)?;
120 if files.is_empty() {
121 None
122 } else {
123 let file = files.remove(0);
124 if files.is_empty() {
125 self.0.remove(name);
126 }
127 Some(file)
128 }
129 }
130
131 pub fn files(&mut self, name: &str) -> Vec<UploadedFile> {
135 self.0.remove(name).unwrap_or_default()
136 }
137}
138
/// Axum extractor for `multipart/form-data` requests.
///
/// - `.0`: the non-file (text) fields, deserialized into `T` through a
///   URL-encoded round trip and then sanitized.
/// - `.1`: the parts that carried a file name, as `Files` keyed by field name.
pub struct MultipartRequest<T>(pub T, pub Files);
180
181impl<S, T> FromRequest<S> for MultipartRequest<T>
182where
183 S: Send + Sync,
184 T: DeserializeOwned + Sanitize,
185{
186 type Rejection = Error;
187
188 async fn from_request(
189 req: Request<axum::body::Body>,
190 state: &S,
191 ) -> Result<Self, Self::Rejection> {
192 let mut multipart = axum_extra::extract::Multipart::from_request(req, state)
193 .await
194 .map_err(|e| Error::bad_request(format!("invalid multipart request: {e}")))?;
195
196 let mut text_fields: Vec<(String, String)> = Vec::new();
197 let mut file_fields: HashMap<String, Vec<UploadedFile>> = HashMap::new();
198
199 while let Some(field) = multipart
200 .next_field()
201 .await
202 .map_err(|e| Error::bad_request(format!("failed to read multipart field: {e}")))?
203 {
204 let field_name = field.name().unwrap_or("").to_string();
205
206 if field.file_name().is_some() {
207 let uploaded = UploadedFile::from_field(field).await?;
208 file_fields.entry(field_name).or_default().push(uploaded);
209 } else {
210 let text = field
211 .text()
212 .await
213 .map_err(|e| Error::bad_request(format!("failed to read text field: {e}")))?;
214 text_fields.push((field_name, text));
215 }
216 }
217
218 let encoded = serde_urlencoded::to_string(&text_fields).map_err(|e| {
219 Error::bad_request(format!("failed to encode multipart text fields: {e}"))
220 })?;
221 let mut value: T = serde_urlencoded::from_str(&encoded).map_err(|e| {
222 Error::bad_request(format!("failed to deserialize multipart text fields: {e}"))
223 })?;
224 value.sanitize();
225
226 Ok(MultipartRequest(value, Files(file_fields)))
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty upload whose only meaningful property is its name.
    fn named(name: &str) -> UploadedFile {
        UploadedFile {
            name: name.to_owned(),
            content_type: String::from("application/octet-stream"),
            size: 0,
            data: bytes::Bytes::new(),
        }
    }

    #[test]
    fn extension_lowercase() {
        assert_eq!(named("photo.JPG").extension().as_deref(), Some("jpg"));
    }

    #[test]
    fn extension_compound() {
        assert_eq!(named("archive.tar.gz").extension().as_deref(), Some("gz"));
    }

    #[test]
    fn extension_none() {
        assert!(named("noext").extension().is_none());
    }

    #[test]
    fn extension_dotfile() {
        assert_eq!(named(".gitignore").extension().as_deref(), Some("gitignore"));
    }

    #[test]
    fn extension_empty_filename() {
        assert!(named("").extension().is_none());
    }
}