Skip to main content

modo/extractor/
multipart.rs

1use std::collections::HashMap;
2
3use axum::extract::FromRequest;
4use http::Request;
5use serde::de::DeserializeOwned;
6
7use crate::error::Error;
8use crate::sanitize::Sanitize;
9
/// A single uploaded file extracted from a multipart request.
///
/// The whole file body is held in memory in [`UploadedFile::data`]. The `Debug`
/// implementation prints `name`, `content_type` and `size` but not `data`.
pub struct UploadedFile {
    /// Original file name from the upload.
    ///
    /// [`UploadedFile::from_field`] falls back to `"unnamed"` when the multipart
    /// part carries no file name.
    pub name: String,
    /// MIME content type (defaults to `application/octet-stream`).
    ///
    /// The default is applied by [`UploadedFile::from_field`] when the part does
    /// not declare a content type.
    pub content_type: String,
    /// Size in bytes.
    ///
    /// [`UploadedFile::from_field`] sets this to `data.len()`.
    pub size: usize,
    /// Raw file bytes.
    pub data: bytes::Bytes,
}
21
22impl std::fmt::Debug for UploadedFile {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("UploadedFile")
25            .field("name", &self.name)
26            .field("content_type", &self.content_type)
27            .field("size", &self.size)
28            .finish()
29    }
30}
31
32impl UploadedFile {
33    /// Build an `UploadedFile` by consuming an axum multipart field.
34    ///
35    /// Reads the entire field body into memory. Prefer using [`MultipartRequest`]
36    /// rather than calling this directly; it is public for advanced use cases
37    /// that need to process fields individually.
38    ///
39    /// # Errors
40    ///
41    /// Returns a `400 Bad Request` error if the field body cannot be read.
42    pub async fn from_field(
43        field: axum_extra::extract::multipart::Field,
44    ) -> crate::error::Result<Self> {
45        let name = field.file_name().unwrap_or("unnamed").to_string();
46        let content_type = field
47            .content_type()
48            .unwrap_or("application/octet-stream")
49            .to_string();
50        let data = field
51            .bytes()
52            .await
53            .map_err(|e| Error::bad_request(format!("failed to read file field: {e}")))?;
54        let size = data.len();
55        Ok(Self {
56            name,
57            content_type,
58            size,
59            data,
60        })
61    }
62
63    /// Returns the file extension from the original filename in lowercase, without the leading dot.
64    ///
65    /// Returns `None` if the filename has no extension (e.g. `"readme"`) or is empty.
66    /// For compound extensions such as `"archive.tar.gz"`, only the last component (`"gz"`)
67    /// is returned.
68    pub fn extension(&self) -> Option<String> {
69        let ext = self.name.rsplit('.').next()?;
70        if ext == self.name {
71            None
72        } else {
73            Some(ext.to_ascii_lowercase())
74        }
75    }
76
77    /// Start building a fluent validation chain for this file.
78    ///
79    /// Returns an [`UploadValidator`](crate::extractor::UploadValidator) that can be used to
80    /// check size and content type. Call
81    /// [`UploadValidator::check`](crate::extractor::UploadValidator::check) to finalize and
82    /// collect any violations.
83    pub fn validate(&self) -> crate::extractor::UploadValidator<'_> {
84        crate::extractor::upload_validator::UploadValidator::new(self)
85    }
86}
87
/// A map of field names to their uploaded files, produced by [`MultipartRequest`].
///
/// Files are stored by the multipart field name. Each name maps to a `Vec`, so
/// multiple files with the same field name are supported and keep the order in
/// which they were inserted. Use [`Files::get`] for a shared reference to the
/// first file, [`Files::file`] to take ownership of the first file, or
/// [`Files::files`] to take all files for a given field name.
///
/// The `Debug` implementation lists only the field names, not the files.
pub struct Files(HashMap<String, Vec<UploadedFile>>);
95
96impl std::fmt::Debug for Files {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        f.debug_struct("Files")
99            .field("fields", &self.0.keys().collect::<Vec<_>>())
100            .finish()
101    }
102}
103
104impl Files {
105    /// Create a [`Files`] collection from a pre-built map.
106    pub fn from_map(map: HashMap<String, Vec<UploadedFile>>) -> Self {
107        Self(map)
108    }
109
110    /// Get a shared reference to the first file under `name`, if any.
111    pub fn get(&self, name: &str) -> Option<&UploadedFile> {
112        self.0.get(name).and_then(|v| v.first())
113    }
114
115    /// Take ownership of the first file under `name`.
116    ///
117    /// Removes the field entry entirely if no files remain after the take.
118    pub fn file(&mut self, name: &str) -> Option<UploadedFile> {
119        let files = self.0.get_mut(name)?;
120        if files.is_empty() {
121            None
122        } else {
123            let file = files.remove(0);
124            if files.is_empty() {
125                self.0.remove(name);
126            }
127            Some(file)
128        }
129    }
130
131    /// Take ownership of all files under `name`.
132    ///
133    /// Returns an empty `Vec` if `name` was not present.
134    pub fn files(&mut self, name: &str) -> Vec<UploadedFile> {
135        self.0.remove(name).unwrap_or_default()
136    }
137}
138
/// Axum extractor for `multipart/form-data` requests.
///
/// Splits the multipart body into text fields (deserialized and sanitized into `T`) and
/// file fields (collected into a [`Files`] map). The inner tuple is `(T, Files)`.
///
/// A part is treated as a file when it carries a file name; every other part is read
/// as text. Parts without a field name are keyed by the empty string.
///
/// Text fields are re-encoded as `application/x-www-form-urlencoded` and deserialized
/// via `serde_urlencoded` before [`Sanitize::sanitize`] is called on the result. File
/// fields are fully buffered into memory as [`UploadedFile`] values.
///
/// # Errors
///
/// The [`FromRequest::Rejection`] is [`crate::Error`]. A `400 Bad Request` is returned
/// if the request is not a valid `multipart/form-data` body, a field cannot be read,
/// or the collected text fields cannot be deserialized into `T`. The error renders via
/// [`crate::Error::into_response`].
///
/// # Example
///
/// ```rust,no_run
/// use modo::extractor::{MultipartRequest, Files};
/// use modo::sanitize::Sanitize;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct ProfileForm {
///     display_name: String,
/// }
///
/// impl Sanitize for ProfileForm {
///     fn sanitize(&mut self) {
///         self.display_name = self.display_name.trim().to_string();
///     }
/// }
///
/// async fn update_profile(
///     MultipartRequest(form, mut files): MultipartRequest<ProfileForm>,
/// ) {
///     let avatar = files.file("avatar"); // Option<UploadedFile>
/// }
/// ```
pub struct MultipartRequest<T>(pub T, pub Files);
180
181impl<S, T> FromRequest<S> for MultipartRequest<T>
182where
183    S: Send + Sync,
184    T: DeserializeOwned + Sanitize,
185{
186    type Rejection = Error;
187
188    async fn from_request(
189        req: Request<axum::body::Body>,
190        state: &S,
191    ) -> Result<Self, Self::Rejection> {
192        let mut multipart = axum_extra::extract::Multipart::from_request(req, state)
193            .await
194            .map_err(|e| Error::bad_request(format!("invalid multipart request: {e}")))?;
195
196        let mut text_fields: Vec<(String, String)> = Vec::new();
197        let mut file_fields: HashMap<String, Vec<UploadedFile>> = HashMap::new();
198
199        while let Some(field) = multipart
200            .next_field()
201            .await
202            .map_err(|e| Error::bad_request(format!("failed to read multipart field: {e}")))?
203        {
204            let field_name = field.name().unwrap_or("").to_string();
205
206            if field.file_name().is_some() {
207                let uploaded = UploadedFile::from_field(field).await?;
208                file_fields.entry(field_name).or_default().push(uploaded);
209            } else {
210                let text = field
211                    .text()
212                    .await
213                    .map_err(|e| Error::bad_request(format!("failed to read text field: {e}")))?;
214                text_fields.push((field_name, text));
215            }
216        }
217
218        let encoded = serde_urlencoded::to_string(&text_fields).map_err(|e| {
219            Error::bad_request(format!("failed to encode multipart text fields: {e}"))
220        })?;
221        let mut value: T = serde_urlencoded::from_str(&encoded).map_err(|e| {
222            Error::bad_request(format!("failed to deserialize multipart text fields: {e}"))
223        })?;
224        value.sanitize();
225
226        Ok(MultipartRequest(value, Files(file_fields)))
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Extension reported for an otherwise empty upload whose file name is `name`.
    fn ext_of(name: &str) -> Option<String> {
        let upload = UploadedFile {
            name: String::from(name),
            content_type: String::from("application/octet-stream"),
            size: 0,
            data: bytes::Bytes::new(),
        };
        upload.extension()
    }

    // Upper-case extensions are folded to lower case.
    #[test]
    fn extension_lowercase() {
        assert_eq!(ext_of("photo.JPG").as_deref(), Some("jpg"));
    }

    // Only the last component of a compound extension is reported.
    #[test]
    fn extension_compound() {
        assert_eq!(ext_of("archive.tar.gz").as_deref(), Some("gz"));
    }

    // A name without any dot has no extension.
    #[test]
    fn extension_none() {
        assert!(ext_of("noext").is_none());
    }

    // A leading-dot name reports the text after the dot.
    #[test]
    fn extension_dotfile() {
        assert_eq!(ext_of(".gitignore").as_deref(), Some("gitignore"));
    }

    // The empty name has no extension.
    #[test]
    fn extension_empty_filename() {
        assert!(ext_of("").is_none());
    }
}
273}