actix_multipart_derive_impl/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Ident,
5    Meta, MetaList, MetaNameValue, NestedMeta, Path,
6};
7
8#[proc_macro_derive(MultipartForm, attributes(multipart))]
9pub fn derive(input: TokenStream) -> TokenStream {
10    let ast = parse_macro_input!(input as DeriveInput);
11
12    let name = ast.ident;
13
14    let b_name = format!("{}MultipartBuilder", name);
15    let b_ident = Ident::new(&b_name, name.span());
16
17    let fields = if let Data::Struct(DataStruct {
18        fields: Fields::Named(FieldsNamed { ref named, .. }),
19        ..
20    }) = ast.data
21    {
22        named
23    } else {
24        unimplemented!()
25    };
26
27    let optioned = fields.iter().map(|f| {
28        let Field { ident, ty, .. } = f;
29        quote! { #ident: ::std::option::Option<#ty> }
30    });
31
32    let field_max_sizes = fields.iter().map(|f| {
33        let Field { ident, attrs, .. } = f;
34
35        for attr in attrs {
36            // TODO: use something like https://github.com/TedDriggs/darling ??
37
38            if let Ok(m) = attr.parse_meta() {
39                if let Meta::List(MetaList { path, nested, .. }) = m {
40                    if path.get_ident().unwrap()
41                        != &Ident::new("multipart", proc_macro2::Span::call_site())
42                    {
43                        continue;
44                    }
45
46                    // it's our meta list, marked by multipart
47
48                    if let Some(NestedMeta::Meta(Meta::NameValue(MetaNameValue {
49                        path: Path { segments, .. },
50                        lit,
51                        ..
52                    }))) = nested.first()
53                    {
54                        for seg in segments {
55                            // if there's a max_size attr in the list, extract the lit
56                            if &seg.ident
57                                == &Ident::new(
58                                    "max_size",
59                                    proc_macro2::Span::call_site(),
60                                )
61                            {
62                                // TODO: ensure literal is numeric
63                                return quote! { stringify!(#ident) => Some(#lit) };
64                            }
65                        }
66                    }
67                }
68            }
69        }
70
71        quote! { stringify!(#ident) => None }
72    });
73
74    let build_fields = fields.iter().map(|f| {
75        let Field { ident, .. } = f;
76        quote! { #ident: self.#ident.unwrap() }
77    });
78
79    let bytes_appending = fields.iter().map(|f| {
80        let Field { ident, .. } = f;
81
82        quote! {
83           stringify!(#ident) => field_bytes.put(chunk),
84        }
85    });
86
87    let fields_from_bytes = fields.iter().map(|f| {
88        let Field { ident, ty, .. } = f;
89
90        quote! {
91           stringify!(#ident) => {
92                builder.#ident.replace(#ty::from_bytes(field_bytes));
93            }
94        }
95    });
96
97    let expanded = quote! {
98        #[derive(Debug, Clone, Default)]
99        struct #b_ident {
100            #(#optioned,)*
101        }
102
103        impl #b_ident {
104            fn max_size(field: &str) -> Option<usize> {
105                match field {
106                    #(#field_max_sizes,)*
107                    _ => None,
108                }
109            }
110
111            fn build(self) -> Result<#name, ::actix_web::Error> {
112                Ok(Form {
113                    #(#build_fields,)*
114                })
115            }
116        }
117
118        impl ::actix_web::FromRequest for #name {
119            type Error = ::actix_web::Error;
120            type Future = ::futures_util::future::LocalBoxFuture<'static, Result<Self, Self::Error>>;
121            type Config = ();
122
123            fn from_request(req: &::actix_web::HttpRequest, payload: &mut ::actix_web::dev::Payload) -> Self::Future {
124                use ::futures_util::future::FutureExt;
125                use ::futures_util::stream::StreamExt;
126                use ::actix_web::{error, web::{BufMut, BytesMut}};
127                use ::actix_multipart::Multipart;
128                use ::actix_multipart_derive::FromBytes;
129
130                let pl = payload.take();
131                let req2 = req.clone();
132
133                async move {
134                    let mut mp = ::actix_multipart::Multipart::new(req2.headers(), pl);
135
136                    let mut builder = #b_ident::default();
137
138                    while let Some(item) = mp.next().await {
139                        let mut field = item?;
140
141                        let headers = field.headers();
142
143                        let cd = field.content_disposition().unwrap();
144                        let name = cd.get_name().unwrap();
145
146                        let mut size = 0;
147                        let mut field_bytes = BytesMut::new();
148
149                        while let Some(chunk) = field.next().await {
150                            let chunk = chunk?;
151                            size += chunk.len();
152
153                            if (size > #b_ident::max_size(&name).unwrap_or(std::usize::MAX)) {
154                                return Err(error::ErrorPayloadTooLarge("field is too large"));
155                            }
156
157                            match name {
158                                #(#bytes_appending)*
159
160                                _ => {
161                                    // unknown field
162                                },
163                            }
164                        }
165
166                        let field_bytes = field_bytes.freeze();
167
168                        match name {
169                            #(#fields_from_bytes)*
170
171                            _ => {
172                                // unknown field
173                            },
174                        }
175                    }
176
177                    builder.build()
178                }
179                .boxed_local()
180            }
181        }
182    };
183
184    expanded.into()
185}