has_fields_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse::Parse;
4use syn::parse_macro_input;
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7
8struct MacroInput {
9    name: syn::Expr,
10    fields: Punctuated<syn::LitStr, Comma>,
11}
12
13fn lit_str_parser(input: syn::parse::ParseStream) -> syn::Result<syn::LitStr> {
14    let lit = input.parse::<syn::LitStr>()?;
15    Ok(lit)
16}
17
18impl Parse for MacroInput {
19    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
20        let name = input.parse()?;
21        // Skip the comma
22        input.parse::<Comma>()?;
23        let fields = input.parse_terminated(lit_str_parser)?;
24        Ok(MacroInput { name, fields })
25    }
26}
27
28/// Check if given fields are Some(...) or not
29///
30/// If has missing fields, return a Err with the missing fields
31///
32/// If all fields are Some(...), return Ok(())
33///
34/// `require_fields!(&form, "field1", "field2", "...")`
35#[proc_macro]
36pub fn require_fields(item: TokenStream) -> TokenStream {
37    // First: parse the input
38    let MacroInput { name, fields } = parse_macro_input!(item as MacroInput);
39
40    // Second: generate if-else statements
41    let mut if_elses = quote!();
42    for field in fields {
43        let field_as_ident = syn::Ident::new(&field.value(), field.span());
44        let if_else = quote! {
45            if s.#field_as_ident.is_none() {
46                missing_fields.push(#field);
47            }
48        };
49        if_elses.extend(if_else);
50    }
51
52    quote! {
53        {
54            let s = #name;
55            let mut missing_fields = Vec::new();
56
57            #if_elses
58
59            if missing_fields.is_empty() {
60                Ok(())
61            } else {
62                Err(missing_fields)
63            }
64        }
65    }
66    .into()
67}
68
69/// Check if given fields are Some(...) or not
70///
71/// Gives a boolean result
72///
73/// `has_fields!(&form, "field1", "field2", "...")`
74#[proc_macro]
75pub fn has_fields(item: TokenStream) -> TokenStream {
76    // First: parse the input
77    let MacroInput { name, fields } = parse_macro_input!(item as MacroInput);
78
79    // Second: generate if-else statements
80    let mut if_elses = quote!();
81    for field in fields {
82        let field_as_ident = syn::Ident::new(&field.value(), field.span());
83        let if_else = quote! {
84            if s.#field_as_ident.is_none() {
85                return false;
86            }
87        };
88        if_elses.extend(if_else);
89    }
90
91    quote! {
92        (|| {
93            let s = #name;
94            #if_elses
95            true
96        })()
97    }
98    .into()
99}
100
101#[proc_macro_derive(HasFields)]
102pub fn derive_has_fields(item: TokenStream) -> TokenStream {
103    let ast = syn::parse(item).unwrap();
104    impl_has_fields(&ast)
105}
106
107fn impl_has_fields(ast: &syn::DeriveInput) -> TokenStream {
108    let name = &ast.ident;
109    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
110
111    let mut fields = Vec::new();
112    let mut non_option_fields = Vec::new();
113    if let syn::Data::Struct(syn::DataStruct {
114        fields: syn::Fields::Named(syn::FieldsNamed { named, .. }),
115        ..
116    }) = &ast.data
117    {
118        for field in named {
119            // if type of field is Option<T>, then we can use it
120            if let syn::Type::Path(syn::TypePath { path, .. }) = &field.ty {
121                if let Some(syn::PathSegment { ident, .. }) = path.segments.last() {
122                    if ident == "Option" {
123                        fields.push(field.ident.as_ref().unwrap());
124                        continue;
125                    }
126                }
127            }
128            non_option_fields.push(field.ident.as_ref().unwrap());
129        }
130    }
131
132    let mut if_elses = quote!();
133    for field in fields {
134        let if_else = quote! {
135            if self.#field.is_some() {
136                count += 1;
137            }
138        };
139        if_elses.extend(if_else);
140    }
141
142    let non_option_fields = non_option_fields.len();
143
144    quote! {
145        impl #impl_generics HasFields for #name #ty_generics #where_clause {
146            fn num_fields(&self) -> usize {
147                let mut count = #non_option_fields;
148                #if_elses
149                count
150            }
151        }
152    }
153    .into()
154}