Skip to main content

ntex_multipart_derive/
lib.rs

1use bytesize::ByteSize;
2use darling::{FromDeriveInput, FromField, FromMeta};
3use proc_macro::TokenStream;
4use proc_macro2::Ident;
5use quote::quote;
6use std::collections::HashSet;
7use syn::{Type, parse_macro_input};
8
/// Struct-level `duplicate_field` option, parsed from `#[multipart(duplicate_field = "...")]`.
///
/// Each variant is translated one-to-one into `::ntex_multipart::form::DuplicateField` in the
/// generated code; see the `MultipartForm` derive documentation for what each behavior does.
#[derive(Default, FromMeta)]
enum DuplicateField {
    /// Used when the attribute is omitted.
    #[default]
    Ignore,
    Deny,
    Replace,
}
16
/// Struct-level options read from `#[multipart(...)]` on the deriving struct.
///
/// The `default` darling flag makes every option optional: `deny_unknown_fields` falls back to
/// `false` and `duplicate_field` to `DuplicateField::Ignore`.
#[derive(FromDeriveInput, Default)]
#[darling(attributes(multipart), default)]
struct MultipartFormAttrs {
    /// When set, form fields whose name matches no struct field produce an `UnknownField` error
    /// instead of being skipped.
    deny_unknown_fields: bool,
    /// How repeated (non-`Vec`) fields with the same name are handled.
    duplicate_field: DuplicateField,
}
23
/// Field-level options read from `#[multipart(...)]` on each struct field.
#[allow(clippy::disallowed_names)] // false positive in macro expansion
#[derive(FromField, Default)]
#[darling(attributes(multipart), default)]
struct FieldAttrs {
    /// Multipart field name to match instead of the Rust field name.
    rename: Option<String>,
    /// Per-field size limit as a human-readable string (e.g. `"2 KiB"`), parsed with `bytesize`.
    limit: Option<String>,
}
31
/// A struct field together with its resolved multipart settings, borrowed from the derive input.
struct ParsedField<'t> {
    /// Name the field is matched by in the multipart payload (`rename`, else the Rust name).
    serialization_name: String,
    /// The struct field's identifier, used as the initializer key in generated `from_state`.
    rust_name: &'t Ident,
    /// Limit in bytes from `#[multipart(limit = "...")]`, if one was given.
    limit: Option<usize>,
    /// The field's declared type; generated code calls `FieldGroupReader` methods on it.
    ty: &'t Type,
}
38
39/// Implements `MultipartCollect` for a struct so that it can be used with the `MultipartForm`
40/// extractor.
41///
42/// # Basic Use
43///
44/// Each field type should implement the `FieldReader` trait:
45///
46/// ```
47/// use ntex_multipart::MultipartForm;
48/// use ntex_multipart::form::{temp_file::TempFile, text::Text};
49///
50/// #[derive(MultipartForm)]
51/// struct ImageUpload {
52///     description: Text<String>,
53///     timestamp: Text<i64>,
54///     image: TempFile,
55/// }
56/// ```
57///
58/// # Optional and List Fields
59///
60/// You can also use `Vec<T>` and `Option<T>` provided that `T: FieldReader`.
61///
62/// A [`Vec`] field corresponds to an upload with multiple parts under the [same field
63/// name](https://www.rfc-editor.org/rfc/rfc7578#section-4.3).
64///
65/// ```
66/// use ntex_multipart::MultipartForm;
67/// use ntex_multipart::form::{temp_file::TempFile, text::Text};
68///
69/// #[derive(MultipartForm)]
70/// struct Form {
71///     category: Option<Text<String>>,
72///     files: Vec<TempFile>,
73/// }
74/// ```
75///
76/// # Field Renaming
77///
78/// You can use the `#[multipart(rename = "foo")]` attribute to receive a field by a different name.
79///
80/// ```
81/// use ntex_multipart::MultipartForm;
82/// use ntex_multipart::form::temp_file::TempFile;
83///
84/// #[derive(MultipartForm)]
85/// struct Form {
86///     #[multipart(rename = "files[]")]
87///     files: Vec<TempFile>,
88/// }
89/// ```
90///
91/// # Field Limits
92///
93/// You can use the `#[multipart(limit = "<size>")]` attribute to set field level limits. The limit
94/// string is parsed using [`bytesize`].
95///
96/// Note: the form is also subject to the global limits configured using `MultipartFormConfig`.
97///
98/// ```
99/// use ntex_multipart::MultipartForm;
100/// use ntex_multipart::form::{temp_file::TempFile, text::Text};
101///
102/// #[derive(MultipartForm)]
103/// struct Form {
104///     #[multipart(limit = "2 KiB")]
105///     description: Text<String>,
106///
107///     #[multipart(limit = "512 MiB")]
108///     files: Vec<TempFile>,
109/// }
110/// ```
111///
112/// # Unknown Fields
113///
114/// By default fields with an unknown name are ignored. They can be rejected using the
115/// `#[multipart(deny_unknown_fields)]` attribute:
116///
117/// ```
118/// use ntex_multipart::MultipartForm;
119///
120/// #[derive(MultipartForm)]
121/// #[multipart(deny_unknown_fields)]
122/// struct Form { }
123/// ```
124///
125/// # Duplicate Fields
126///
127/// The behaviour for when multiple fields with the same name are received can be changed using the
128/// `#[multipart(duplicate_field = "<behavior>")]` attribute:
129///
130/// - "ignore": (default) Extra fields are ignored. I.e., the first one is persisted.
131/// - "deny": A `MultipartError::UnknownField` error response is returned.
132/// - "replace": Each field is processed, but only the last one is persisted.
133///
134/// Note that `Vec` fields will ignore this option.
135///
136/// ```
137/// use ntex_multipart::MultipartForm;
138///
139/// #[derive(MultipartForm)]
140/// #[multipart(duplicate_field = "deny")]
141/// struct Form { }
142/// ```
143///
144/// [`bytesize`]: https://docs.rs/bytesize/2
145#[proc_macro_derive(MultipartForm, attributes(multipart))]
146pub fn impl_multipart_form(input: TokenStream) -> TokenStream {
147    let input: syn::DeriveInput = parse_macro_input!(input);
148
149    let name = &input.ident;
150
151    let data_struct = match &input.data {
152        syn::Data::Struct(data_struct) => data_struct,
153        _ => {
154            return compile_err(syn::Error::new(
155                input.ident.span(),
156                "`MultipartForm` can only be derived for structs",
157            ));
158        }
159    };
160
161    let fields = match &data_struct.fields {
162        syn::Fields::Named(fields_named) => fields_named,
163        _ => {
164            return compile_err(syn::Error::new(
165                input.ident.span(),
166                "`MultipartForm` can only be derived for a struct with named fields",
167            ));
168        }
169    };
170
171    let attrs = match MultipartFormAttrs::from_derive_input(&input) {
172        Ok(attrs) => attrs,
173        Err(err) => return err.write_errors().into(),
174    };
175
176    // Parse the field attributes
177    let parsed = match fields
178        .named
179        .iter()
180        .map(|field| {
181            let rust_name = field.ident.as_ref().unwrap();
182            let attrs = FieldAttrs::from_field(field).map_err(|err| err.write_errors())?;
183            let serialization_name = attrs.rename.unwrap_or_else(|| rust_name.to_string());
184
185            let limit = match attrs.limit.map(|limit| match limit.parse::<ByteSize>() {
186                Ok(ByteSize(size)) => Ok(usize::try_from(size).unwrap()),
187                Err(err) => Err(syn::Error::new(
188                    field.ident.as_ref().unwrap().span(),
189                    format!("Could not parse size limit `{}`: {}", limit, err),
190                )),
191            }) {
192                Some(Err(err)) => return Err(compile_err(err)),
193                limit => limit.map(Result::unwrap),
194            };
195
196            Ok(ParsedField { serialization_name, rust_name, limit, ty: &field.ty })
197        })
198        .collect::<Result<Vec<_>, TokenStream>>()
199    {
200        Ok(attrs) => attrs,
201        Err(err) => return err,
202    };
203
204    // Check that field names are unique
205    let mut set = HashSet::new();
206    for field in &parsed {
207        if !set.insert(field.serialization_name.clone()) {
208            return compile_err(syn::Error::new(
209                field.rust_name.span(),
210                format!("Multiple fields named: `{}`", field.serialization_name),
211            ));
212        }
213    }
214
215    // Return value when a field name is not supported by the form
216    let unknown_field_result = if attrs.deny_unknown_fields {
217        quote!(::std::result::Result::Err(::ntex_multipart::MultipartError::UnknownField(
218            field.name().unwrap().to_string()
219        )))
220    } else {
221        quote!(::std::result::Result::Ok(()))
222    };
223
224    // Value for duplicate action
225    let duplicate_field = match attrs.duplicate_field {
226        DuplicateField::Ignore => quote!(::ntex_multipart::form::DuplicateField::Ignore),
227        DuplicateField::Deny => quote!(::ntex_multipart::form::DuplicateField::Deny),
228        DuplicateField::Replace => quote!(::ntex_multipart::form::DuplicateField::Replace),
229    };
230
231    // limit() implementation
232    let mut limit_impl = quote!();
233    for field in &parsed {
234        let name = &field.serialization_name;
235        if let Some(value) = field.limit {
236            limit_impl.extend(quote!(
237                #name => ::std::option::Option::Some(#value),
238            ));
239        }
240    }
241
242    // handle_field() implementation
243    let mut handle_field_impl = quote!();
244    for field in &parsed {
245        let name = &field.serialization_name;
246        let ty = &field.ty;
247
248        handle_field_impl.extend(quote!(
249            #name => ::std::boxed::Box::pin(
250                <#ty as ::ntex_multipart::form::FieldGroupReader>::handle_field(req, field, limits, state, #duplicate_field)
251            ),
252        ));
253    }
254
255    // from_state() implementation
256    let mut from_state_impl = quote!();
257    for field in &parsed {
258        let name = &field.serialization_name;
259        let rust_name = &field.rust_name;
260        let ty = &field.ty;
261        from_state_impl.extend(quote!(
262            #rust_name: <#ty as ::ntex_multipart::form::FieldGroupReader>::from_state(#name, &mut state)?,
263        ));
264    }
265
266    let generation = quote! {
267        impl ::ntex_multipart::MultipartCollect for #name {
268            fn limit(field_name: &str) -> ::std::option::Option<usize> {
269                match field_name {
270                    #limit_impl
271                    _ => None,
272                }
273            }
274
275            fn handle_field<'t>(
276                req: &'t ::ntex::web::HttpRequest,
277                field: ::ntex_multipart::Field,
278                limits: &'t mut ::ntex_multipart::form::Limits,
279                state: &'t mut ::ntex_multipart::form::State,
280            ) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::std::result::Result<(), ::ntex_multipart::MultipartError>> + 't>> {
281                match field.name().unwrap() {
282                    #handle_field_impl
283                    _ => return ::std::boxed::Box::pin(::std::future::ready(#unknown_field_result)),
284                }
285            }
286
287            fn from_state(mut state: ::ntex_multipart::form::State) -> ::std::result::Result<Self, ::ntex_multipart::MultipartError> {
288                Ok(Self {
289                    #from_state_impl
290                })
291            }
292
293        }
294    };
295    generation.into()
296}
297
298/// Transform a syn error into a token stream for returning.
299fn compile_err(err: syn::Error) -> TokenStream {
300    TokenStream::from(err.to_compile_error())
301}