1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse::Parse;
4use syn::parse_macro_input;
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7
8struct MacroInput {
9 name: syn::Expr,
10 fields: Punctuated<syn::LitStr, Comma>,
11}
12
13fn lit_str_parser(input: syn::parse::ParseStream) -> syn::Result<syn::LitStr> {
14 let lit = input.parse::<syn::LitStr>()?;
15 Ok(lit)
16}
17
18impl Parse for MacroInput {
19 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
20 let name = input.parse()?;
21 input.parse::<Comma>()?;
23 let fields = input.parse_terminated(lit_str_parser)?;
24 Ok(MacroInput { name, fields })
25 }
26}
27
28#[proc_macro]
36pub fn require_fields(item: TokenStream) -> TokenStream {
37 let MacroInput { name, fields } = parse_macro_input!(item as MacroInput);
39
40 let mut if_elses = quote!();
42 for field in fields {
43 let field_as_ident = syn::Ident::new(&field.value(), field.span());
44 let if_else = quote! {
45 if s.#field_as_ident.is_none() {
46 missing_fields.push(#field);
47 }
48 };
49 if_elses.extend(if_else);
50 }
51
52 quote! {
53 {
54 let s = #name;
55 let mut missing_fields = Vec::new();
56
57 #if_elses
58
59 if missing_fields.is_empty() {
60 Ok(())
61 } else {
62 Err(missing_fields)
63 }
64 }
65 }
66 .into()
67}
68
69#[proc_macro]
75pub fn has_fields(item: TokenStream) -> TokenStream {
76 let MacroInput { name, fields } = parse_macro_input!(item as MacroInput);
78
79 let mut if_elses = quote!();
81 for field in fields {
82 let field_as_ident = syn::Ident::new(&field.value(), field.span());
83 let if_else = quote! {
84 if s.#field_as_ident.is_none() {
85 return false;
86 }
87 };
88 if_elses.extend(if_else);
89 }
90
91 quote! {
92 (|| {
93 let s = #name;
94 #if_elses
95 true
96 })()
97 }
98 .into()
99}
100
101#[proc_macro_derive(HasFields)]
102pub fn derive_has_fields(item: TokenStream) -> TokenStream {
103 let ast = syn::parse(item).unwrap();
104 impl_has_fields(&ast)
105}
106
107fn impl_has_fields(ast: &syn::DeriveInput) -> TokenStream {
108 let name = &ast.ident;
109 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
110
111 let mut fields = Vec::new();
112 let mut non_option_fields = Vec::new();
113 if let syn::Data::Struct(syn::DataStruct {
114 fields: syn::Fields::Named(syn::FieldsNamed { named, .. }),
115 ..
116 }) = &ast.data
117 {
118 for field in named {
119 if let syn::Type::Path(syn::TypePath { path, .. }) = &field.ty {
121 if let Some(syn::PathSegment { ident, .. }) = path.segments.last() {
122 if ident == "Option" {
123 fields.push(field.ident.as_ref().unwrap());
124 continue;
125 }
126 }
127 }
128 non_option_fields.push(field.ident.as_ref().unwrap());
129 }
130 }
131
132 let mut if_elses = quote!();
133 for field in fields {
134 let if_else = quote! {
135 if self.#field.is_some() {
136 count += 1;
137 }
138 };
139 if_elses.extend(if_else);
140 }
141
142 let non_option_fields = non_option_fields.len();
143
144 quote! {
145 impl #impl_generics HasFields for #name #ty_generics #where_clause {
146 fn num_fields(&self) -> usize {
147 let mut count = #non_option_fields;
148 #if_elses
149 count
150 }
151 }
152 }
153 .into()
154}