validator_derive/
lib.rs

1use darling::ast::Data;
2use darling::util::{Override, WithOriginal};
3use darling::FromDeriveInput;
4use proc_macro_error2::{abort, proc_macro_error};
5use quote::{quote, ToTokens};
6use syn::{parse_macro_input, DeriveInput, Field, GenericParam, Path, PathArguments};
7
8use tokens::cards::credit_card_tokens;
9use tokens::contains::contains_tokens;
10use tokens::custom::custom_tokens;
11use tokens::does_not_contain::does_not_contain_tokens;
12use tokens::email::email_tokens;
13use tokens::ip::ip_tokens;
14use tokens::length::length_tokens;
15use tokens::must_match::must_match_tokens;
16use tokens::nested::nested_tokens;
17use tokens::non_control_character::non_control_char_tokens;
18use tokens::range::range_tokens;
19use tokens::regex::regex_tokens;
20use tokens::required::required_tokens;
21use tokens::schema::schema_tokens;
22use tokens::url::url_tokens;
23use types::*;
24use utils::{quote_use_stmts, CrateName};
25
26mod tokens;
27mod types;
28mod utils;
29
30impl ToTokens for ValidateField {
31    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
32        let field_name = self.ident.clone().unwrap();
33        let field_name_str = self.ident.clone().unwrap().to_string();
34
35        let type_name = self.ty.to_token_stream().to_string();
36        let is_number = NUMBER_TYPES.contains(&type_name);
37
38        let (actual_field, wrapper_closure) = self.if_let_option_wrapper(&field_name, is_number);
39
40        // Length validation
41        let length = if let Some(length) = self.length.clone() {
42            wrapper_closure(length_tokens(&self.crate_name, length, &actual_field, &field_name_str))
43        } else {
44            quote!()
45        };
46
47        // Email validation
48        let email = if let Some(email) = self.email.clone() {
49            wrapper_closure(email_tokens(
50                &self.crate_name,
51                match email {
52                    Override::Inherit => Email::default(),
53                    Override::Explicit(e) => e,
54                },
55                &actual_field,
56                &field_name_str,
57            ))
58        } else {
59            quote!()
60        };
61
62        // Credit card validation
63        let card = if let Some(credit_card) = self.credit_card.clone() {
64            wrapper_closure(credit_card_tokens(
65                &self.crate_name,
66                match credit_card {
67                    Override::Inherit => Card::default(),
68                    Override::Explicit(c) => c,
69                },
70                &actual_field,
71                &field_name_str,
72            ))
73        } else {
74            quote!()
75        };
76
77        // Url validation
78        let url = if let Some(url) = self.url.clone() {
79            wrapper_closure(url_tokens(
80                &self.crate_name,
81                match url {
82                    Override::Inherit => Url::default(),
83                    Override::Explicit(u) => u,
84                },
85                &actual_field,
86                &field_name_str,
87            ))
88        } else {
89            quote!()
90        };
91
92        // Ip address validation
93        let ip = if let Some(ip) = self.ip.clone() {
94            wrapper_closure(ip_tokens(
95                &self.crate_name,
96                match ip {
97                    Override::Inherit => Ip::default(),
98                    Override::Explicit(i) => i,
99                },
100                &actual_field,
101                &field_name_str,
102            ))
103        } else {
104            quote!()
105        };
106
107        // Non control character validation
108        let ncc = if let Some(ncc) = self.non_control_character.clone() {
109            wrapper_closure(non_control_char_tokens(
110                &self.crate_name,
111                match ncc {
112                    Override::Inherit => NonControlCharacter::default(),
113                    Override::Explicit(n) => n,
114                },
115                &actual_field,
116                &field_name_str,
117            ))
118        } else {
119            quote!()
120        };
121
122        // Range validation
123        let range = if let Some(range) = self.range.clone() {
124            wrapper_closure(range_tokens(&self.crate_name, range, &actual_field, &field_name_str))
125        } else {
126            quote!()
127        };
128
129        // Required validation
130        let required = if let Some(required) = self.required.clone() {
131            required_tokens(
132                &self.crate_name,
133                match required {
134                    Override::Inherit => Required::default(),
135                    Override::Explicit(r) => r,
136                },
137                &field_name,
138                &field_name_str,
139            )
140        } else {
141            quote!()
142        };
143
144        // Contains validation
145        let contains = if let Some(contains) = self.contains.clone() {
146            wrapper_closure(contains_tokens(
147                &self.crate_name,
148                contains,
149                &actual_field,
150                &field_name_str,
151            ))
152        } else {
153            quote!()
154        };
155
156        // Does not contain validation
157        let does_not_contain = if let Some(does_not_contain) = self.does_not_contain.clone() {
158            wrapper_closure(does_not_contain_tokens(
159                &self.crate_name,
160                does_not_contain,
161                &actual_field,
162                &field_name_str,
163            ))
164        } else {
165            quote!()
166        };
167
168        // Must match validation
169        let must_match = if let Some(must_match) = self.must_match.clone() {
170            // TODO: handle option for other
171            wrapper_closure(must_match_tokens(
172                &self.crate_name,
173                must_match,
174                &actual_field,
175                &field_name_str,
176            ))
177        } else {
178            quote!()
179        };
180
181        // Regex validation
182        let regex = if let Some(regex) = self.regex.clone() {
183            wrapper_closure(regex_tokens(&self.crate_name, regex, &actual_field, &field_name_str))
184        } else {
185            quote!()
186        };
187
188        // Custom validation
189        let mut custom = quote!();
190        // We try to be smart when passing arguments
191        let is_cow = type_name.contains("Cow <");
192        let custom_actual_field = if is_cow {
193            quote!(#actual_field.as_ref())
194        } else if is_number || type_name.starts_with("&") {
195            quote!(#actual_field)
196        } else {
197            quote!(&#actual_field)
198        };
199
200        for c in &self.custom {
201            let tokens = custom_tokens(c.clone(), &custom_actual_field, &field_name_str);
202            custom = quote!(
203                #custom
204
205                #tokens
206            );
207        }
208        if !self.custom.is_empty() {
209            custom = wrapper_closure(custom);
210        }
211
212        let nested = if let Some(n) = self.nested {
213            if n {
214                wrapper_closure(nested_tokens(&actual_field, &field_name_str))
215            } else {
216                quote!()
217            }
218        } else {
219            quote!()
220        };
221
222        tokens.extend(quote! {
223            #length
224            #email
225            #card
226            #url
227            #ip
228            #ncc
229            #range
230            #required
231            #contains
232            #does_not_contain
233            #must_match
234            #regex
235            #custom
236            #nested
237        });
238    }
239}
240
/// Container-level data parsed by darling from the item that derives `Validate`.
///
/// `supports(struct_named)` restricts the derive to structs with named fields; darling
/// rejects any other input. After parsing, darling runs `ValidationData::validate`
/// (configured via `and_then`) for additional checks.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(validate), supports(struct_named))]
#[darling(and_then = "ValidationData::validate")]
struct ValidationData {
    /// Identifier of the struct the derive is applied to.
    ident: syn::Ident,
    /// Generic parameters of the struct; reused when building the generated `impl` headers.
    generics: syn::Generics,
    /// The struct's fields, each parsed into a `ValidateField` with its original `syn::Field`
    /// kept alongside. The `()` is darling's enum-variant type, unused here because only
    /// named structs are accepted.
    data: Data<(), WithOriginal<ValidateField, syn::Field>>,
    /// Struct-level `schema(...)` validators; the attribute may appear multiple times.
    #[darling(multiple)]
    schema: Vec<Schema>,
    /// Optional context type. When set, `ValidateArgs::Args` becomes a reference to it, and
    /// lifetime arguments on its path segments must be `'v_a` (checked in `validate`).
    context: Option<Path>,
    /// When `Some(true)` and a `context` is set, the context is passed as `&'v_a mut`.
    mutable: Option<bool>,
    /// When `Some(true)`, every non-skipped field is validated as nested.
    nest_all_fields: Option<bool>,
    /// The name of the crate to use for the generated code,
    /// defaults to `validator`.
    #[darling(rename = "crate", default)]
    crate_name: CrateName,
}
260
261impl ValidationData {
262    fn validate(self) -> darling::Result<Self> {
263        if let Some(context) = &self.context {
264            // Check if context lifetime is not `'v_a`
265            for segment in &context.segments {
266                match &segment.arguments {
267                    PathArguments::AngleBracketed(args) => {
268                        for arg in &args.args {
269                            match arg {
270                                syn::GenericArgument::Lifetime(lt) => {
271                                    if lt.ident != "v_a" {
272                                        abort! {
273                                            lt.ident, "Invalid argument reference";
274                                            note = "The lifetime `'{}` is not supported.", lt.ident;
275                                            help = "Please use the validator lifetime `'v_a`";
276                                        }
277                                    }
278                                }
279                                _ => (),
280                            }
281                        }
282                    }
283                    _ => (),
284                }
285            }
286        }
287
288        match &self.data {
289            Data::Struct(fields) => {
290                let original_fields: Vec<&Field> =
291                    fields.fields.iter().map(|f| &f.original).collect();
292                for f in &fields.fields {
293                    f.parsed.validate(&self.ident, &original_fields, &f.original);
294                }
295            }
296            _ => (),
297        }
298
299        Ok(self)
300    }
301}
302
303#[proc_macro_error]
304#[proc_macro_derive(Validate, attributes(validate))]
305pub fn derive_validation(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
306    let input: DeriveInput = parse_macro_input!(input);
307
308    // parse the input to the ValidationData struct defined above
309    let validation_data = match ValidationData::from_derive_input(&input) {
310        Ok(data) => data,
311        Err(e) => return e.write_errors().into(),
312    };
313
314    let crate_name = validation_data.crate_name;
315
316    let custom_context = if let Some(context) = &validation_data.context {
317        if let Some(mutable) = validation_data.mutable {
318            if mutable {
319                quote!(&'v_a mut #context)
320            } else {
321                quote!(&'v_a #context)
322            }
323        } else {
324            quote!(&'v_a #context)
325        }
326    } else {
327        quote!(())
328    };
329
330    // get all the fields to quote them below
331    let mut validation_fields: Vec<ValidateField> = validation_data
332        .data
333        .take_struct()
334        .unwrap()
335        .fields
336        .into_iter()
337        .map(|f| f.parsed)
338        // skip fields with #[validate(skip)] attribute
339        .filter(|f| if let Some(s) = f.skip { !s } else { true })
340        .map(|f| ValidateField { crate_name: crate_name.clone(), ..f })
341        .collect();
342
343    if let Some(nest_all_fields) = validation_data.nest_all_fields {
344        if nest_all_fields {
345            validation_fields = validation_fields
346                .iter_mut()
347                .map(|f| {
348                    f.nested = Some(true);
349                    f.to_owned()
350                })
351                .collect();
352        }
353    }
354
355    // generate `use` statements for all used validator traits
356    let use_statements = quote_use_stmts(&crate_name, &validation_fields);
357
358    // Schema validation
359    let schema = validation_data.schema.iter().fold(quote!(), |acc, s| {
360        let st = schema_tokens(s.clone());
361        let acc = quote! {
362            #acc
363            #st
364        };
365        acc
366    });
367
368    let ident = validation_data.ident;
369    let (imp, ty, whr) = validation_data.generics.split_for_impl();
370
371    let struct_generics_quote =
372        validation_data.generics.params.iter().fold(quote!(), |mut q, g| {
373            if let GenericParam::Type(t) = g {
374                // Default types are not allowed in trait impl
375                if t.default.is_some() {
376                    let mut t2 = t.clone();
377                    t2.default = None;
378                    let g2 = GenericParam::Type(t2);
379                    q.extend(quote!(#g2, ));
380                } else {
381                    q.extend(quote!(#g, ));
382                }
383            } else {
384                q.extend(quote!(#g, ));
385            }
386            q
387        });
388
389    let imp_args = if struct_generics_quote.is_empty() {
390        quote!(<'v_a>)
391    } else {
392        quote!(<'v_a, #struct_generics_quote>)
393    };
394
395    let argless_validation = if validation_data.context.is_none() {
396        quote! {
397            impl #imp #crate_name::Validate for #ident #ty #whr {
398                fn validate(&self) -> ::std::result::Result<(), #crate_name::ValidationErrors> {
399                    use #crate_name::ValidateArgs;
400                    self.validate_with_args(())
401                }
402            }
403        }
404    } else {
405        quote!()
406    };
407
408    quote!(
409        #argless_validation
410
411        impl #imp_args #crate_name::ValidateArgs<'v_a> for #ident #ty #whr {
412            type Args = #custom_context;
413
414            fn validate_with_args(&self, args: Self::Args)
415            -> ::std::result::Result<(), #crate_name::ValidationErrors>
416             {
417                #use_statements
418
419                let mut errors = #crate_name::ValidationErrors::new();
420
421                #(#validation_fields)*
422
423                #schema
424
425                if errors.is_empty() {
426                    ::std::result::Result::Ok(())
427                } else {
428                    ::std::result::Result::Err(errors)
429                }
430            }
431        }
432    )
433    .into()
434}