Skip to main content

eml_codec_derives/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, ItemFn};
4use syn::{punctuated::Punctuated, token::Comma, Attribute, LitStr, Variant};
5
6// derive(ToStringFromPrint) ---------------------------------------------------
7
8#[proc_macro_derive(ToStringFromPrint)]
9pub fn derive_to_string_from_print(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11    let name = input.ident;
12    let (impl_generics, ty_generics, where_clauses) = input.generics.split_for_impl();
13
14    let expanded = quote! {
15        impl #impl_generics ToString for #name #ty_generics #where_clauses {
16            fn to_string(&self) -> String {
17                String::from_utf8_lossy(
18                    &crate::print::print_to_vec(
19                        crate::print::FMT_NOFOLD,
20                        self,
21                    )
22                ).to_string()
23            }
24        }
25    };
26
27    expanded.into()
28}
29
30// derive(FuzzEq) --------------------------------------------------------------
31
32#[proc_macro_derive(FuzzEq, attributes(fuzz_eq))]
33pub fn derive_fuzz_eq(input: TokenStream) -> TokenStream {
34    let input = parse_macro_input!(input as DeriveInput);
35
36    let name = input.ident;
37    let generics = add_bounds(input.generics, quote! { FuzzEq });
38
39    let (impl_generics, ty_generics, where_clauses) = generics.split_for_impl();
40
41    let body = match input.data {
42        Data::Struct(data) => derive_fuzz_eq_struct(&data.fields),
43        Data::Enum(data) => derive_fuzz_eq_enum(&name, &data.variants),
44        Data::Union(_) => {
45            return syn::Error::new_spanned(name, "FuzzEq cannot be derived for unions")
46                .to_compile_error()
47                .into();
48        }
49    };
50
51    let expanded = quote! {
52        impl #impl_generics FuzzEq for #name #ty_generics #where_clauses {
53            fn fuzz_eq(&self, other: &Self) -> bool {
54                #body
55            }
56        }
57    };
58
59    expanded.into() // TokenStream::from(expanded)
60}
61
62fn derive_fuzz_eq_struct(fields: &Fields) -> proc_macro2::TokenStream {
63    match fields {
64        Fields::Named(fields) => {
65            let comparisons = fields
66                .named
67                .iter()
68                .filter(|f| !has_attr(&f.attrs, "fuzz_eq", "ignore"))
69                .map(|f| {
70                    let name = &f.ident;
71                    if has_attr(&f.attrs, "fuzz_eq", "use_eq") {
72                        quote! { &self.#name == &other.#name }
73                    } else {
74                        quote! { self.#name.fuzz_eq(&other.#name) }
75                    }
76                });
77
78            quote! {
79                true #(&& #comparisons)*
80            }
81        }
82        Fields::Unnamed(fields) => {
83            let indices = (0..fields.unnamed.len()).map(syn::Index::from);
84
85            let comparisons = indices.map(|i| {
86                quote! { self.#i.fuzz_eq(&other.#i) }
87            });
88
89            quote! {
90                true #(&& #comparisons)*
91            }
92        }
93        Fields::Unit => quote!(true),
94    }
95}
96
97fn derive_fuzz_eq_enum(
98    enum_name: &syn::Ident,
99    variants: &Punctuated<Variant, Comma>,
100) -> proc_macro2::TokenStream {
101    let arms = variants.iter().map(|variant| {
102        let vname = &variant.ident;
103
104        match &variant.fields {
105            Fields::Unit => {
106                quote! {
107                    (#enum_name::#vname, #enum_name::#vname) => true
108                }
109            }
110            Fields::Unnamed(fields) => {
111                let lhs: Vec<_> = (0..fields.unnamed.len())
112                    .map(|i| syn::Ident::new(&format!("a{i}"), vname.span()))
113                    .collect();
114                let rhs: Vec<_> = (0..fields.unnamed.len())
115                    .map(|i| syn::Ident::new(&format!("b{i}"), vname.span()))
116                    .collect();
117
118                let comparisons = lhs.iter().zip(rhs.iter()).map(|(a, b)| {
119                    if has_attr(&variant.attrs, "fuzz_eq", "use_eq") {
120                        quote! { #a == #b }
121                    } else {
122                        quote! { #a.fuzz_eq(&#b) }
123                    }
124                });
125
126                quote! {
127                    (
128                        #enum_name::#vname( #(#lhs),* ),
129                        #enum_name::#vname( #(#rhs),* )
130                    ) => {
131                        true #(&& #comparisons)*
132                    }
133                }
134            }
135            Fields::Named(fields) => {
136                let lhs: Vec<_> = fields
137                    .named
138                    .iter()
139                    .map(|f| {
140                        syn::Ident::new(&format!("a_{}", f.ident.as_ref().unwrap()), vname.span())
141                    })
142                    .collect();
143                let rhs: Vec<_> = fields
144                    .named
145                    .iter()
146                    .map(|f| {
147                        syn::Ident::new(&format!("b_{}", f.ident.as_ref().unwrap()), vname.span())
148                    })
149                    .collect();
150
151                let names: Vec<_> = fields
152                    .named
153                    .iter()
154                    .map(|f| f.ident.as_ref().unwrap())
155                    .collect();
156
157                let comparisons = lhs.iter().zip(rhs.iter()).map(|(a, b)| {
158                    if has_attr(&variant.attrs, "fuzz_eq", "use_eq") {
159                        quote! { #a == #b }
160                    } else {
161                        quote! { #a.fuzz_eq(&#b) }
162                    }
163                });
164
165                quote! {
166                    (
167                        #enum_name::#vname { #(#names: #lhs),* },
168                        #enum_name::#vname { #(#names: #rhs),* }
169                    ) => {
170                        true #(&& #comparisons)*
171                    }
172                }
173            }
174        }
175    });
176
177    quote! {
178        match (self, other) {
179            #(#arms),*,
180            _ => false
181        }
182    }
183}
184
185// derive(ContainsUtf8) --------------------------------------------------------
186
187#[proc_macro_derive(ContainsUtf8, attributes(contains_utf8))]
188pub fn derive_contains_utf8(input: TokenStream) -> TokenStream {
189    let input = parse_macro_input!(input as DeriveInput);
190
191    let name = input.ident;
192    let generics = add_bounds(input.generics, quote! { ContainsUtf8 });
193
194    let (impl_generics, ty_generics, where_clauses) = generics.split_for_impl();
195
196    let body = if let Some(b) = has_bool_attr(&input.attrs, "contains_utf8") {
197        quote! { #b }
198    } else {
199        match input.data {
200            Data::Struct(data) => derive_contains_utf8_struct(&data.fields),
201            Data::Enum(data) => derive_contains_utf8_enum(&name, &data.variants),
202            Data::Union(_) => {
203                return syn::Error::new_spanned(name, "ContainsUtf8 cannot be derived for unions")
204                    .to_compile_error()
205                    .into();
206            }
207        }
208    };
209
210    let expanded = quote! {
211        impl #impl_generics ContainsUtf8 for #name #ty_generics #where_clauses {
212            fn contains_utf8(&self) -> bool {
213                #body
214            }
215        }
216    };
217
218    expanded.into()
219}
220
221fn derive_contains_utf8_struct(fields: &Fields) -> proc_macro2::TokenStream {
222    match fields {
223        Fields::Named(fields) => {
224            let tests = fields
225                .named
226                .iter()
227                .filter(|f| !has_attr(&f.attrs, "contains_utf8", "ignore"))
228                .map(|f| {
229                    let name = &f.ident;
230                    quote! { self.#name.contains_utf8() }
231                });
232
233            quote! { false #(|| #tests)* }
234        }
235        Fields::Unnamed(fields) => {
236            let indices = (0..fields.unnamed.len()).map(syn::Index::from);
237
238            let comparisons = indices.map(|i| {
239                quote! { self.#i.contains_utf8() }
240            });
241
242            quote! { false #(|| #comparisons)* }
243        }
244        Fields::Unit => quote!(false),
245    }
246}
247
248fn derive_contains_utf8_enum(
249    enum_name: &syn::Ident,
250    variants: &Punctuated<Variant, Comma>,
251) -> proc_macro2::TokenStream {
252    let arms = variants.iter().map(|variant| {
253        let vname = &variant.ident;
254
255        match &variant.fields {
256            Fields::Unit => {
257                quote! {
258                    #enum_name::#vname => false
259                }
260            }
261            Fields::Unnamed(fields) => {
262                let ids: Vec<_> = (0..fields.unnamed.len())
263                    .map(|i| syn::Ident::new(&format!("a{i}"), vname.span()))
264                    .collect();
265
266                let tests = ids.iter().map(|a| quote! { #a.contains_utf8() });
267
268                quote! {
269                    #enum_name::#vname( #(#ids),* ) => false #(|| #tests)*
270                }
271            }
272            Fields::Named(fields) => {
273                let ids: Vec<_> = fields
274                    .named
275                    .iter()
276                    .map(|f| {
277                        syn::Ident::new(&format!("a_{}", f.ident.as_ref().unwrap()), vname.span())
278                    })
279                    .collect();
280
281                let names: Vec<_> = fields
282                    .named
283                    .iter()
284                    .map(|f| f.ident.as_ref().unwrap())
285                    .collect();
286
287                let tests = ids.iter().map(|a| quote! { #a.contains_utf8() });
288
289                quote! {
290                    #enum_name::#vname { #(#names: #ids),* } => false #(|| #tests)*
291                }
292            }
293        }
294    });
295
296    quote! {
297        match self {
298            #(#arms),*,
299        }
300    }
301}
302
303// instrument_input ------------------------------------------------------------
304
305// This macro is fairly ad-hoc (it is simply a wrapper over the
306// tracing::instrument macro), but saves us quite a bit of repeated
307// boilerplate...
308#[proc_macro_attribute]
309pub fn instrument_input(attr: TokenStream, input: TokenStream) -> TokenStream {
310    let mut input = parse_macro_input!(input as ItemFn);
311    let feat = parse_macro_input!(attr as LitStr);
312    let attr: Attribute = parse_quote! {
313        #[cfg_attr(
314            feature = #feat,
315            tracing::instrument(fields(input = %crate::utils::bytes_to_trace_string(input)))
316        )]
317    };
318    input.attrs.push(attr);
319    TokenStream::from(quote! { #input })
320}
321
322// helpers
323
324fn add_bounds(mut generics: syn::Generics, trait_id: impl quote::ToTokens) -> syn::Generics {
325    let params = generics.params.clone();
326    let where_clause = generics.make_where_clause();
327
328    for param in &params {
329        if let syn::GenericParam::Type(type_param) = param {
330            let ident = &type_param.ident;
331
332            where_clause.predicates.push(syn::parse_quote! {
333                #ident: #trait_id
334            });
335        }
336    }
337
338    generics
339}
340
341fn has_attr(attrs: &Vec<Attribute>, path: &str, name: &str) -> bool {
342    attrs.iter().any(|attr| {
343        attr.path().is_ident(path)
344            && attr
345                .parse_args::<syn::Ident>()
346                .map_or(false, |ident| ident == name)
347    })
348}
349
350fn has_bool_attr(attrs: &Vec<Attribute>, path: &str) -> Option<syn::LitBool> {
351    attrs.iter().find_map(|attr| {
352        if attr.path().is_ident(path) {
353            attr.parse_args::<syn::LitBool>().ok()
354        } else {
355            None
356        }
357    })
358}