modo/extractor/
multipart.rs1use std::collections::HashMap;
2
3use axum::extract::FromRequest;
4use http::Request;
5use serde::de::DeserializeOwned;
6
7use crate::error::Error;
8use crate::sanitize::Sanitize;
9
/// A single file part from a `multipart/form-data` request, with its contents held in memory.
pub struct UploadedFile {
    /// File name as sent by the client. Untrusted: never use it directly as a filesystem path.
    pub name: String,
    /// MIME type declared by the client for this part; it is not checked against the bytes.
    pub content_type: String,
    /// Size of the file in bytes (`from_field` sets this to `data.len()`).
    pub size: usize,
    /// Raw file contents.
    pub data: bytes::Bytes,
}
21
22impl std::fmt::Debug for UploadedFile {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("UploadedFile")
25 .field("name", &self.name)
26 .field("content_type", &self.content_type)
27 .field("size", &self.size)
28 .finish()
29 }
30}
31
32impl UploadedFile {
33 pub async fn from_field(
43 field: axum_extra::extract::multipart::Field,
44 ) -> crate::error::Result<Self> {
45 let name = field.file_name().unwrap_or("unnamed").to_string();
46 let content_type = field
47 .content_type()
48 .unwrap_or("application/octet-stream")
49 .to_string();
50 let data = field
51 .bytes()
52 .await
53 .map_err(|e| Error::bad_request(format!("failed to read file field: {e}")))?;
54 let size = data.len();
55 Ok(Self {
56 name,
57 content_type,
58 size,
59 data,
60 })
61 }
62
63 pub fn extension(&self) -> Option<String> {
69 let ext = self.name.rsplit('.').next()?;
70 if ext == self.name {
71 None
72 } else {
73 Some(ext.to_ascii_lowercase())
74 }
75 }
76
77 pub fn validate(&self) -> crate::extractor::UploadValidator<'_> {
84 crate::extractor::upload_validator::UploadValidator::new(self)
85 }
86}
87
/// Uploaded files from a multipart request, grouped by form field name.
///
/// One field name may carry several files (e.g. `<input type="file" multiple>`). Each
/// name maps to its files in the order they appeared in the request.
pub struct Files(HashMap<String, Vec<UploadedFile>>);
95
96impl std::fmt::Debug for Files {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("Files")
99 .field("fields", &self.0.keys().collect::<Vec<_>>())
100 .finish()
101 }
102}
103
104impl Files {
105 pub fn from_map(map: HashMap<String, Vec<UploadedFile>>) -> Self {
107 Self(map)
108 }
109
110 pub fn get(&self, name: &str) -> Option<&UploadedFile> {
112 self.0.get(name).and_then(|v| v.first())
113 }
114
115 pub fn file(&mut self, name: &str) -> Option<UploadedFile> {
119 let files = self.0.get_mut(name)?;
120 if files.is_empty() {
121 None
122 } else {
123 let file = files.remove(0);
124 if files.is_empty() {
125 self.0.remove(name);
126 }
127 Some(file)
128 }
129 }
130
131 pub fn files(&mut self, name: &str) -> Vec<UploadedFile> {
135 self.0.remove(name).unwrap_or_default()
136 }
137}
138
/// Extractor for `multipart/form-data` requests.
///
/// `.0` holds the non-file (text) parts deserialized into `T` and then sanitized.
/// `.1` holds the uploaded file parts grouped by field name.
pub struct MultipartRequest<T>(pub T, pub Files);
176
177impl<S, T> FromRequest<S> for MultipartRequest<T>
178where
179 S: Send + Sync,
180 T: DeserializeOwned + Sanitize,
181{
182 type Rejection = Error;
183
184 async fn from_request(
185 req: Request<axum::body::Body>,
186 state: &S,
187 ) -> Result<Self, Self::Rejection> {
188 let mut multipart = axum_extra::extract::Multipart::from_request(req, state)
189 .await
190 .map_err(|e| Error::bad_request(format!("invalid multipart request: {e}")))?;
191
192 let mut text_fields: Vec<(String, String)> = Vec::new();
193 let mut file_fields: HashMap<String, Vec<UploadedFile>> = HashMap::new();
194
195 while let Some(field) = multipart
196 .next_field()
197 .await
198 .map_err(|e| Error::bad_request(format!("failed to read multipart field: {e}")))?
199 {
200 let field_name = field.name().unwrap_or("").to_string();
201
202 if field.file_name().is_some() {
203 let uploaded = UploadedFile::from_field(field).await?;
204 file_fields.entry(field_name).or_default().push(uploaded);
205 } else {
206 let text = field
207 .text()
208 .await
209 .map_err(|e| Error::bad_request(format!("failed to read text field: {e}")))?;
210 text_fields.push((field_name, text));
211 }
212 }
213
214 let encoded = serde_urlencoded::to_string(&text_fields).map_err(|e| {
215 Error::bad_request(format!("failed to encode multipart text fields: {e}"))
216 })?;
217 let mut value: T = serde_urlencoded::from_str(&encoded).map_err(|e| {
218 Error::bad_request(format!("failed to deserialize multipart text fields: {e}"))
219 })?;
220 value.sanitize();
221
222 Ok(MultipartRequest(value, Files(file_fields)))
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty upload whose only meaningful property is its file name.
    fn file_with_name(name: &str) -> UploadedFile {
        UploadedFile {
            name: name.to_owned(),
            content_type: String::from("application/octet-stream"),
            size: 0,
            data: bytes::Bytes::new(),
        }
    }

    /// Shorthand for the extension reported for a file called `name`.
    fn ext(name: &str) -> Option<String> {
        file_with_name(name).extension()
    }

    #[test]
    fn extension_lowercase() {
        assert_eq!(ext("photo.JPG").as_deref(), Some("jpg"));
    }

    #[test]
    fn extension_compound() {
        assert_eq!(ext("archive.tar.gz").as_deref(), Some("gz"));
    }

    #[test]
    fn extension_none() {
        assert_eq!(ext("noext"), None);
    }

    #[test]
    fn extension_dotfile() {
        assert_eq!(ext(".gitignore").as_deref(), Some("gitignore"));
    }

    #[test]
    fn extension_empty_filename() {
        assert_eq!(ext(""), None);
    }
}