ntex_multipart_derive/
lib.rs1use bytesize::ByteSize;
2use darling::{FromDeriveInput, FromField, FromMeta};
3use proc_macro::TokenStream;
4use proc_macro2::Ident;
5use quote::quote;
6use std::collections::HashSet;
7use syn::{Type, parse_macro_input};
8
9#[derive(Default, FromMeta)]
10enum DuplicateField {
11 #[default]
12 Ignore,
13 Deny,
14 Replace,
15}
16
17#[derive(FromDeriveInput, Default)]
18#[darling(attributes(multipart), default)]
19struct MultipartFormAttrs {
20 deny_unknown_fields: bool,
21 duplicate_field: DuplicateField,
22}
23
24#[allow(clippy::disallowed_names)] #[derive(FromField, Default)]
26#[darling(attributes(multipart), default)]
27struct FieldAttrs {
28 rename: Option<String>,
29 limit: Option<String>,
30}
31
32struct ParsedField<'t> {
33 serialization_name: String,
34 rust_name: &'t Ident,
35 limit: Option<usize>,
36 ty: &'t Type,
37}
38
39#[proc_macro_derive(MultipartForm, attributes(multipart))]
146pub fn impl_multipart_form(input: TokenStream) -> TokenStream {
147 let input: syn::DeriveInput = parse_macro_input!(input);
148
149 let name = &input.ident;
150
151 let data_struct = match &input.data {
152 syn::Data::Struct(data_struct) => data_struct,
153 _ => {
154 return compile_err(syn::Error::new(
155 input.ident.span(),
156 "`MultipartForm` can only be derived for structs",
157 ));
158 }
159 };
160
161 let fields = match &data_struct.fields {
162 syn::Fields::Named(fields_named) => fields_named,
163 _ => {
164 return compile_err(syn::Error::new(
165 input.ident.span(),
166 "`MultipartForm` can only be derived for a struct with named fields",
167 ));
168 }
169 };
170
171 let attrs = match MultipartFormAttrs::from_derive_input(&input) {
172 Ok(attrs) => attrs,
173 Err(err) => return err.write_errors().into(),
174 };
175
176 let parsed = match fields
178 .named
179 .iter()
180 .map(|field| {
181 let rust_name = field.ident.as_ref().unwrap();
182 let attrs = FieldAttrs::from_field(field).map_err(|err| err.write_errors())?;
183 let serialization_name = attrs.rename.unwrap_or_else(|| rust_name.to_string());
184
185 let limit = match attrs.limit.map(|limit| match limit.parse::<ByteSize>() {
186 Ok(ByteSize(size)) => Ok(usize::try_from(size).unwrap()),
187 Err(err) => Err(syn::Error::new(
188 field.ident.as_ref().unwrap().span(),
189 format!("Could not parse size limit `{}`: {}", limit, err),
190 )),
191 }) {
192 Some(Err(err)) => return Err(compile_err(err)),
193 limit => limit.map(Result::unwrap),
194 };
195
196 Ok(ParsedField { serialization_name, rust_name, limit, ty: &field.ty })
197 })
198 .collect::<Result<Vec<_>, TokenStream>>()
199 {
200 Ok(attrs) => attrs,
201 Err(err) => return err,
202 };
203
204 let mut set = HashSet::new();
206 for field in &parsed {
207 if !set.insert(field.serialization_name.clone()) {
208 return compile_err(syn::Error::new(
209 field.rust_name.span(),
210 format!("Multiple fields named: `{}`", field.serialization_name),
211 ));
212 }
213 }
214
215 let unknown_field_result = if attrs.deny_unknown_fields {
217 quote!(::std::result::Result::Err(::ntex_multipart::MultipartError::UnknownField(
218 field.name().unwrap().to_string()
219 )))
220 } else {
221 quote!(::std::result::Result::Ok(()))
222 };
223
224 let duplicate_field = match attrs.duplicate_field {
226 DuplicateField::Ignore => quote!(::ntex_multipart::form::DuplicateField::Ignore),
227 DuplicateField::Deny => quote!(::ntex_multipart::form::DuplicateField::Deny),
228 DuplicateField::Replace => quote!(::ntex_multipart::form::DuplicateField::Replace),
229 };
230
231 let mut limit_impl = quote!();
233 for field in &parsed {
234 let name = &field.serialization_name;
235 if let Some(value) = field.limit {
236 limit_impl.extend(quote!(
237 #name => ::std::option::Option::Some(#value),
238 ));
239 }
240 }
241
242 let mut handle_field_impl = quote!();
244 for field in &parsed {
245 let name = &field.serialization_name;
246 let ty = &field.ty;
247
248 handle_field_impl.extend(quote!(
249 #name => ::std::boxed::Box::pin(
250 <#ty as ::ntex_multipart::form::FieldGroupReader>::handle_field(req, field, limits, state, #duplicate_field)
251 ),
252 ));
253 }
254
255 let mut from_state_impl = quote!();
257 for field in &parsed {
258 let name = &field.serialization_name;
259 let rust_name = &field.rust_name;
260 let ty = &field.ty;
261 from_state_impl.extend(quote!(
262 #rust_name: <#ty as ::ntex_multipart::form::FieldGroupReader>::from_state(#name, &mut state)?,
263 ));
264 }
265
266 let generation = quote! {
267 impl ::ntex_multipart::MultipartCollect for #name {
268 fn limit(field_name: &str) -> ::std::option::Option<usize> {
269 match field_name {
270 #limit_impl
271 _ => None,
272 }
273 }
274
275 fn handle_field<'t>(
276 req: &'t ::ntex::web::HttpRequest,
277 field: ::ntex_multipart::Field,
278 limits: &'t mut ::ntex_multipart::form::Limits,
279 state: &'t mut ::ntex_multipart::form::State,
280 ) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::std::result::Result<(), ::ntex_multipart::MultipartError>> + 't>> {
281 match field.name().unwrap() {
282 #handle_field_impl
283 _ => return ::std::boxed::Box::pin(::std::future::ready(#unknown_field_result)),
284 }
285 }
286
287 fn from_state(mut state: ::ntex_multipart::form::State) -> ::std::result::Result<Self, ::ntex_multipart::MultipartError> {
288 Ok(Self {
289 #from_state_impl
290 })
291 }
292
293 }
294 };
295 generation.into()
296}
297
298fn compile_err(err: syn::Error) -> TokenStream {
300 TokenStream::from(err.to_compile_error())
301}