modo/extractor/multipart.rs
1use std::collections::HashMap;
2
3use axum::extract::FromRequest;
4use http::Request;
5use serde::de::DeserializeOwned;
6
7use crate::error::Error;
8use crate::sanitize::Sanitize;
9
10/// A single uploaded file extracted from a multipart request.
11pub struct UploadedFile {
12 /// Original file name from the upload.
13 pub name: String,
14 /// MIME content type (defaults to `application/octet-stream`).
15 pub content_type: String,
16 /// Size in bytes.
17 pub size: usize,
18 /// Raw file bytes.
19 pub data: bytes::Bytes,
20}
21
22impl std::fmt::Debug for UploadedFile {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("UploadedFile")
25 .field("name", &self.name)
26 .field("content_type", &self.content_type)
27 .field("size", &self.size)
28 .finish()
29 }
30}
31
32impl UploadedFile {
33 /// Build an `UploadedFile` by consuming an axum multipart field.
34 ///
35 /// Reads the entire field body into memory. Prefer using [`MultipartRequest`]
36 /// rather than calling this directly; it is public for advanced use cases
37 /// that need to process fields individually.
38 ///
39 /// # Errors
40 ///
41 /// Returns a `400 Bad Request` error if the field body cannot be read.
42 pub async fn from_field(
43 field: axum_extra::extract::multipart::Field,
44 ) -> crate::error::Result<Self> {
45 let name = field.file_name().unwrap_or("unnamed").to_string();
46 let content_type = field
47 .content_type()
48 .unwrap_or("application/octet-stream")
49 .to_string();
50 let data = field
51 .bytes()
52 .await
53 .map_err(|e| Error::bad_request(format!("failed to read file field: {e}")))?;
54 let size = data.len();
55 Ok(Self {
56 name,
57 content_type,
58 size,
59 data,
60 })
61 }
62
63 /// Returns the file extension from the original filename in lowercase, without the leading dot.
64 ///
65 /// Returns `None` if the filename has no extension (e.g. `"readme"`) or is empty.
66 /// For compound extensions such as `"archive.tar.gz"`, only the last component (`"gz"`)
67 /// is returned.
68 pub fn extension(&self) -> Option<String> {
69 let ext = self.name.rsplit('.').next()?;
70 if ext == self.name {
71 None
72 } else {
73 Some(ext.to_ascii_lowercase())
74 }
75 }
76
77 /// Start building a fluent validation chain for this file.
78 ///
79 /// Returns an [`UploadValidator`](crate::extractor::UploadValidator) that can be used to
80 /// check size and content type. Call
81 /// [`UploadValidator::check`](crate::extractor::UploadValidator::check) to finalize and
82 /// collect any violations.
83 pub fn validate(&self) -> crate::extractor::UploadValidator<'_> {
84 crate::extractor::upload_validator::UploadValidator::new(self)
85 }
86}
87
88/// A map of field names to their uploaded files, produced by [`MultipartRequest`].
89///
90/// Files are stored by the multipart field name. Multiple files with the same field
91/// name are supported. Use [`Files::get`] for a shared reference to the first file,
92/// [`Files::file`] to take ownership of the first file, or [`Files::files`] to take
93/// all files for a given field name.
94pub struct Files(HashMap<String, Vec<UploadedFile>>);
95
96impl std::fmt::Debug for Files {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("Files")
99 .field("fields", &self.0.keys().collect::<Vec<_>>())
100 .finish()
101 }
102}
103
104impl Files {
105 /// Create a [`Files`] collection from a pre-built map.
106 pub fn from_map(map: HashMap<String, Vec<UploadedFile>>) -> Self {
107 Self(map)
108 }
109
110 /// Get a shared reference to the first file under `name`, if any.
111 pub fn get(&self, name: &str) -> Option<&UploadedFile> {
112 self.0.get(name).and_then(|v| v.first())
113 }
114
115 /// Take ownership of the first file under `name`.
116 ///
117 /// Removes the field entry entirely if no files remain after the take.
118 pub fn file(&mut self, name: &str) -> Option<UploadedFile> {
119 let files = self.0.get_mut(name)?;
120 if files.is_empty() {
121 None
122 } else {
123 let file = files.remove(0);
124 if files.is_empty() {
125 self.0.remove(name);
126 }
127 Some(file)
128 }
129 }
130
131 /// Take ownership of all files under `name`.
132 ///
133 /// Returns an empty `Vec` if `name` was not present.
134 pub fn files(&mut self, name: &str) -> Vec<UploadedFile> {
135 self.0.remove(name).unwrap_or_default()
136 }
137}
138
139/// Axum extractor for `multipart/form-data` requests.
140///
141/// Splits the multipart body into text fields (deserialized and sanitized into `T`) and
142/// file fields (collected into a [`Files`] map). The inner tuple is `(T, Files)`.
143///
144/// Text fields are re-encoded as `application/x-www-form-urlencoded` and deserialized
145/// via `serde_qs` (form-encoding mode) before [`Sanitize::sanitize`] is called on the
146/// result. Repeated text fields, nested structs (`address[city]=…`), and indexed-bracket
147/// `Vec<Struct>` rows (`contacts[0][kind]=…`) all deserialize naturally — same behavior
148/// as [`crate::extractor::FormRequest`]. File fields are fully buffered into memory as
149/// [`UploadedFile`] values.
150///
151/// # Errors
152///
153/// The [`FromRequest::Rejection`] is [`crate::Error`]. A `400 Bad Request` is returned
154/// if the request is not a valid `multipart/form-data` body, a field cannot be read,
155/// or the collected text fields cannot be deserialized into `T`. The error renders via its
156/// [`IntoResponse`](axum::response::IntoResponse) impl.
157///
158/// # Example
159///
160/// ```rust,no_run
161/// use modo::extractor::{MultipartRequest, Files};
162/// use modo::sanitize::Sanitize;
163/// use serde::Deserialize;
164///
165/// #[derive(Deserialize)]
166/// struct ProfileForm {
167/// display_name: String,
168/// }
169///
170/// impl Sanitize for ProfileForm {
171/// fn sanitize(&mut self) {
172/// self.display_name = self.display_name.trim().to_string();
173/// }
174/// }
175///
176/// async fn update_profile(
177/// MultipartRequest(form, mut files): MultipartRequest<ProfileForm>,
178/// ) {
179/// let avatar = files.file("avatar"); // Option<UploadedFile>
180/// }
181/// ```
182pub struct MultipartRequest<T>(pub T, pub Files);
183
184impl<S, T> FromRequest<S> for MultipartRequest<T>
185where
186 S: Send + Sync,
187 T: DeserializeOwned + Sanitize,
188{
189 type Rejection = Error;
190
191 async fn from_request(
192 req: Request<axum::body::Body>,
193 state: &S,
194 ) -> Result<Self, Self::Rejection> {
195 let mut multipart = axum_extra::extract::Multipart::from_request(req, state)
196 .await
197 .map_err(|e| Error::bad_request(format!("invalid multipart request: {e}")))?;
198
199 let mut text_fields: Vec<(String, String)> = Vec::new();
200 let mut file_fields: HashMap<String, Vec<UploadedFile>> = HashMap::new();
201
202 while let Some(field) = multipart
203 .next_field()
204 .await
205 .map_err(|e| Error::bad_request(format!("failed to read multipart field: {e}")))?
206 {
207 let field_name = field.name().unwrap_or("").to_string();
208
209 if field.file_name().is_some() {
210 let uploaded = UploadedFile::from_field(field).await?;
211 file_fields.entry(field_name).or_default().push(uploaded);
212 } else {
213 let text = field
214 .text()
215 .await
216 .map_err(|e| Error::bad_request(format!("failed to read text field: {e}")))?;
217 text_fields.push((field_name, text));
218 }
219 }
220
221 let encoded = serde_urlencoded::to_string(&text_fields).map_err(|e| {
222 Error::bad_request(format!("failed to encode multipart text fields: {e}"))
223 })?;
224 let mut value: T = serde_qs::Config::new()
225 .use_form_encoding(true)
226 .deserialize_str(&encoded)
227 .map_err(|e| {
228 Error::bad_request(format!("failed to deserialize multipart text fields: {e}"))
229 })?;
230 value.sanitize();
231
232 Ok(MultipartRequest(value, Files(file_fields)))
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an empty `UploadedFile` that carries only the given file name.
    fn file_with_name(name: &str) -> UploadedFile {
        UploadedFile {
            name: name.to_owned(),
            content_type: String::from("application/octet-stream"),
            size: 0,
            data: bytes::Bytes::new(),
        }
    }

    /// Shorthand: the extension reported for a file called `name`.
    fn ext_of(name: &str) -> Option<String> {
        file_with_name(name).extension()
    }

    #[test]
    fn extension_lowercase() {
        assert_eq!(ext_of("photo.JPG").as_deref(), Some("jpg"));
    }

    #[test]
    fn extension_compound() {
        assert_eq!(ext_of("archive.tar.gz").as_deref(), Some("gz"));
    }

    #[test]
    fn extension_none() {
        assert!(ext_of("noext").is_none());
    }

    #[test]
    fn extension_dotfile() {
        assert_eq!(ext_of(".gitignore").as_deref(), Some("gitignore"));
    }

    #[test]
    fn extension_empty_filename() {
        assert!(ext_of("").is_none());
    }
}