Skip to main content

modo/extractor/
multipart.rs

1use std::collections::HashMap;
2
3use axum::extract::FromRequest;
4use http::Request;
5use serde::de::DeserializeOwned;
6
7use crate::error::Error;
8use crate::sanitize::Sanitize;
9
10/// A single uploaded file extracted from a multipart request.
11pub struct UploadedFile {
12    /// Original file name from the upload.
13    pub name: String,
14    /// MIME content type (defaults to `application/octet-stream`).
15    pub content_type: String,
16    /// Size in bytes.
17    pub size: usize,
18    /// Raw file bytes.
19    pub data: bytes::Bytes,
20}
21
22impl std::fmt::Debug for UploadedFile {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("UploadedFile")
25            .field("name", &self.name)
26            .field("content_type", &self.content_type)
27            .field("size", &self.size)
28            .finish()
29    }
30}
31
32impl UploadedFile {
33    /// Build an `UploadedFile` by consuming an axum multipart field.
34    ///
35    /// Reads the entire field body into memory. Prefer using [`MultipartRequest`]
36    /// rather than calling this directly; it is public for advanced use cases
37    /// that need to process fields individually.
38    ///
39    /// # Errors
40    ///
41    /// Returns a `400 Bad Request` error if the field body cannot be read.
42    pub async fn from_field(
43        field: axum_extra::extract::multipart::Field,
44    ) -> crate::error::Result<Self> {
45        let name = field.file_name().unwrap_or("unnamed").to_string();
46        let content_type = field
47            .content_type()
48            .unwrap_or("application/octet-stream")
49            .to_string();
50        let data = field
51            .bytes()
52            .await
53            .map_err(|e| Error::bad_request(format!("failed to read file field: {e}")))?;
54        let size = data.len();
55        Ok(Self {
56            name,
57            content_type,
58            size,
59            data,
60        })
61    }
62
63    /// Returns the file extension from the original filename in lowercase, without the leading dot.
64    ///
65    /// Returns `None` if the filename has no extension (e.g. `"readme"`) or is empty.
66    /// For compound extensions such as `"archive.tar.gz"`, only the last component (`"gz"`)
67    /// is returned.
68    pub fn extension(&self) -> Option<String> {
69        let ext = self.name.rsplit('.').next()?;
70        if ext == self.name {
71            None
72        } else {
73            Some(ext.to_ascii_lowercase())
74        }
75    }
76
77    /// Start building a fluent validation chain for this file.
78    ///
79    /// Returns an [`UploadValidator`](crate::extractor::UploadValidator) that can be used to
80    /// check size and content type. Call
81    /// [`UploadValidator::check`](crate::extractor::UploadValidator::check) to finalize and
82    /// collect any violations.
83    pub fn validate(&self) -> crate::extractor::UploadValidator<'_> {
84        crate::extractor::upload_validator::UploadValidator::new(self)
85    }
86}
87
88/// A map of field names to their uploaded files, produced by [`MultipartRequest`].
89///
90/// Files are stored by the multipart field name. Multiple files with the same field
91/// name are supported. Use [`Files::get`] for a shared reference to the first file,
92/// [`Files::file`] to take ownership of the first file, or [`Files::files`] to take
93/// all files for a given field name.
94pub struct Files(HashMap<String, Vec<UploadedFile>>);
95
96impl std::fmt::Debug for Files {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        f.debug_struct("Files")
99            .field("fields", &self.0.keys().collect::<Vec<_>>())
100            .finish()
101    }
102}
103
104impl Files {
105    /// Create a [`Files`] collection from a pre-built map.
106    pub fn from_map(map: HashMap<String, Vec<UploadedFile>>) -> Self {
107        Self(map)
108    }
109
110    /// Get a shared reference to the first file under `name`, if any.
111    pub fn get(&self, name: &str) -> Option<&UploadedFile> {
112        self.0.get(name).and_then(|v| v.first())
113    }
114
115    /// Take ownership of the first file under `name`.
116    ///
117    /// Removes the field entry entirely if no files remain after the take.
118    pub fn file(&mut self, name: &str) -> Option<UploadedFile> {
119        let files = self.0.get_mut(name)?;
120        if files.is_empty() {
121            None
122        } else {
123            let file = files.remove(0);
124            if files.is_empty() {
125                self.0.remove(name);
126            }
127            Some(file)
128        }
129    }
130
131    /// Take ownership of all files under `name`.
132    ///
133    /// Returns an empty `Vec` if `name` was not present.
134    pub fn files(&mut self, name: &str) -> Vec<UploadedFile> {
135        self.0.remove(name).unwrap_or_default()
136    }
137}
138
139/// Axum extractor for `multipart/form-data` requests.
140///
141/// Splits the multipart body into text fields (deserialized and sanitized into `T`) and
142/// file fields (collected into a [`Files`] map). The inner tuple is `(T, Files)`.
143///
144/// Text fields are URL-encoded and deserialized via `serde_urlencoded` before
145/// [`Sanitize::sanitize`] is called on the result. File fields are collected into memory
146/// as [`UploadedFile`] values.
147///
148/// Returns a 400 Bad Request error if the request is not valid multipart data or a field
149/// cannot be read.
150///
151/// # Example
152///
153/// ```
154/// use modo::extractor::{MultipartRequest, Files};
155/// use modo::Sanitize;
156/// use serde::Deserialize;
157///
158/// #[derive(Deserialize)]
159/// struct ProfileForm {
160///     display_name: String,
161/// }
162///
163/// impl Sanitize for ProfileForm {
164///     fn sanitize(&mut self) {
165///         self.display_name = self.display_name.trim().to_string();
166///     }
167/// }
168///
169/// async fn update_profile(
170///     MultipartRequest(form, mut files): MultipartRequest<ProfileForm>,
171/// ) {
172///     let avatar = files.file("avatar"); // Option<UploadedFile>
173/// }
174/// ```
175pub struct MultipartRequest<T>(pub T, pub Files);
176
177impl<S, T> FromRequest<S> for MultipartRequest<T>
178where
179    S: Send + Sync,
180    T: DeserializeOwned + Sanitize,
181{
182    type Rejection = Error;
183
184    async fn from_request(
185        req: Request<axum::body::Body>,
186        state: &S,
187    ) -> Result<Self, Self::Rejection> {
188        let mut multipart = axum_extra::extract::Multipart::from_request(req, state)
189            .await
190            .map_err(|e| Error::bad_request(format!("invalid multipart request: {e}")))?;
191
192        let mut text_fields: Vec<(String, String)> = Vec::new();
193        let mut file_fields: HashMap<String, Vec<UploadedFile>> = HashMap::new();
194
195        while let Some(field) = multipart
196            .next_field()
197            .await
198            .map_err(|e| Error::bad_request(format!("failed to read multipart field: {e}")))?
199        {
200            let field_name = field.name().unwrap_or("").to_string();
201
202            if field.file_name().is_some() {
203                let uploaded = UploadedFile::from_field(field).await?;
204                file_fields.entry(field_name).or_default().push(uploaded);
205            } else {
206                let text = field
207                    .text()
208                    .await
209                    .map_err(|e| Error::bad_request(format!("failed to read text field: {e}")))?;
210                text_fields.push((field_name, text));
211            }
212        }
213
214        let encoded = serde_urlencoded::to_string(&text_fields).map_err(|e| {
215            Error::bad_request(format!("failed to encode multipart text fields: {e}"))
216        })?;
217        let mut value: T = serde_urlencoded::from_str(&encoded).map_err(|e| {
218            Error::bad_request(format!("failed to deserialize multipart text fields: {e}"))
219        })?;
220        value.sanitize();
221
222        Ok(MultipartRequest(value, Files(file_fields)))
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty upload whose only meaningful property is its file name.
    fn upload_named(file_name: &str) -> UploadedFile {
        UploadedFile {
            name: file_name.to_owned(),
            content_type: String::from("application/octet-stream"),
            size: 0,
            data: bytes::Bytes::new(),
        }
    }

    #[test]
    fn extension_lowercase() {
        let ext = upload_named("photo.JPG").extension();
        assert_eq!(ext.as_deref(), Some("jpg"));
    }

    #[test]
    fn extension_compound() {
        let ext = upload_named("archive.tar.gz").extension();
        assert_eq!(ext.as_deref(), Some("gz"));
    }

    #[test]
    fn extension_none() {
        assert!(upload_named("noext").extension().is_none());
    }

    #[test]
    fn extension_dotfile() {
        let ext = upload_named(".gitignore").extension();
        assert_eq!(ext.as_deref(), Some("gitignore"));
    }

    #[test]
    fn extension_empty_filename() {
        assert!(upload_named("").extension().is_none());
    }
}