// portable_hash_macros/lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream as TokenStream1;
4use proc_macro2::{Ident, Literal, Span, TokenStream};
5use quote::{format_ident, quote, ToTokens};
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    punctuated::Punctuated,
10    ConstParam, Data, DeriveInput, Error, Fields, GenericArgument, Generics, Index, Lifetime,
11    LifetimeParam, Token, TypeParam, WhereClause,
12};
13
/// Path of the runtime crate that all generated code refers to.
///
/// Fully qualified (`::portable_hash`) so the expansion resolves regardless
/// of what the caller has imported or renamed.
fn crate_root() -> TokenStream {
    quote!(::portable_hash)
}
17
18#[proc_macro_derive(PortableHash)]
19#[allow(non_snake_case)]
20pub fn derive_portable_hash(input: TokenStream1) -> TokenStream1 {
21    let root = crate_root();
22    let hash = quote!(#root::PortableHash);
23    let hasher_write = quote!(#root::PortableHasher);
24
25    let mut input = parse_macro_input!(input as DeriveInput);
26    let ident = input.ident;
27
28    let mut tokens = TokenStream::new();
29    let mut types = Vec::new();
30
31    match input.data {
32        // Stability: structs are hashed in the order of their fields.
33        Data::Struct(x) => match x.fields {
34            Fields::Named(x) => {
35                let fields = x.named.iter().map(|x| {
36                    types.push(x.ty.clone());
37                    x.ident.as_ref().unwrap()
38                });
39                quote! {
40                    #( #hash::portable_hash(&self.#fields, state); )*
41                }
42                    .to_tokens(&mut tokens)
43            }
44
45            Fields::Unnamed(x) => {
46                let fields = x.unnamed.iter().enumerate().map(|(i, x)| {
47                    types.push(x.ty.clone());
48                    Index::from(i)
49                });
50                quote! {
51                    #( #hash::portable_hash(&self.#fields, state); )*
52                }
53                    .to_tokens(&mut tokens)
54            }
55
56            Fields::Unit => (),
57        },
58
59        // Stability: (TODO) enums must be keyed with a discriminant for DoS resistance.
60        Data::Enum(x) => {
61            let mut variant_tokens = TokenStream::new();
62
63            for (discriminant, x) in x.variants.iter().enumerate() {
64                let var = &x.ident;
65
66                // TODO(stabilisation): we do this for forward-compatibility, but does it cause other issues?
67                // Use write_u8 until discriminant > u8::MAX, then write_u16, u32, u64.
68                let span = Span::call_site(); // TODO: improve Span location?
69                let (discriminant_method, discriminant_value) = match discriminant {
70                    discriminant if discriminant > u32::MAX as usize => {
71                        (Ident::new("write_u64", span), Literal::u64_suffixed(discriminant as u64))
72                    }
73                    discriminant if discriminant > u16::MAX as usize => {
74                        (Ident::new("write_u32", span), Literal::u32_suffixed(discriminant as u32))
75                    }
76                    discriminant if discriminant > u8::MAX as usize => {
77                        (Ident::new("write_u16", span), Literal::u16_suffixed(discriminant as u16))
78                    }
79                    _ => {
80                        (Ident::new("write_u8", span), Literal::u8_suffixed(discriminant as u8))
81                    }
82                };
83
84                match &x.fields {
85                    Fields::Named(x) => {
86                        let fields: Vec<_> = x
87                            .named
88                            .iter()
89                            .map(|x| {
90                                types.push(x.ty.clone());
91                                x.ident.as_ref().unwrap()
92                            })
93                            .collect();
94                        // TODO(stabilisation): should we use the enum Name and write_str?
95                        //   It would allow re-ordering of named variants without changing
96                        //   the hash.
97                        quote! {
98                            Self::#var { #(#fields),* } => {
99                                state.#discriminant_method(#discriminant_value);
100                                #( #hash::portable_hash(#fields, state); )*
101                            }
102                        }
103                            .to_tokens(&mut variant_tokens);
104                    }
105
106                    Fields::Unnamed(x) => {
107                        let fields: Vec<_> = x
108                            .unnamed
109                            .iter()
110                            .enumerate()
111                            .map(|(i, x)| {
112                                types.push(x.ty.clone());
113                                format_ident!("_{}", i)
114                            })
115                            .collect();
116                        quote! {
117                            Self::#var(#(#fields),*) => {
118                                state.#discriminant_method(#discriminant_value);
119                                #( #hash::portable_hash(#fields, state); )*
120                            }
121                        }
122                            .to_tokens(&mut variant_tokens);
123                    }
124
125                    Fields::Unit => quote! {
126                        Self::#var => {
127                            state.#discriminant_method(#discriminant_value);
128                        },
129                    }
130                        .to_tokens(&mut variant_tokens),
131                }
132            }
133
134            // TODO(stability): use a portable discriminant for hashing
135            //   named -> use hash(str(variant_name)) + hash(data)
136            //   unnamed -> use hash(index) + hash(data)
137            //   unit -> use hash(index)
138            // Old: #hash::hash(&core::mem::discriminant(self), state);
139            quote! {
140                match self {
141                    #variant_tokens
142                }
143            }
144                .to_tokens(&mut tokens);
145        }
146
147        Data::Union(_) => {
148            return Error::new(ident.span(), "can't derive `Hash` for union")
149                .to_compile_error()
150                .into()
151        }
152    }
153
154    input.generics.make_where_clause();
155    let wc = input.generics.where_clause.as_mut().unwrap();
156    let where_ = fix_where(Some(wc));
157    let SplitGenerics {
158        lti,
159        ltt,
160        tpi,
161        tpt,
162        cpi,
163        cpt,
164        wc,
165    } = split_generics(&input.generics);
166    quote! {
167        impl<#(#lti,)* #(#tpi,)* #(#cpi,)*> #hash for #ident<#(#ltt,)* #(#tpt,)* #(#cpt),*> #where_ #wc
168            #( #types: #hash ),*
169        {
170            #[inline]
171            fn portable_hash<H: #hasher_write>(&self, state: &mut H) {
172                #tokens
173            }
174        }
175    }
176        .into()
177}
178
179// #[proc_macro]
180// pub fn impl_core_hash(input: TokenStream1) -> TokenStream1 {
181//     let root = crate_root();
182//     let hash = quote!(#root::Hash);
183//
184//     let input = parse_macro_input!(input as IdentsWithGenerics);
185//     let mut output = TokenStream::new();
186//
187//     for IdentWithGenerics {
188//         impl_generics,
189//         ident,
190//         use_generics,
191//         mut where_clause,
192//     } in input.punctuated
193//     {
194//         let where_ = fix_where(where_clause.as_mut());
195//         quote! {
196//             impl #impl_generics ::core::hash::Hash for #ident #use_generics #where_ #where_clause
197//                 Self: #hash,
198//             {
199//                 #[inline]
200//                 fn hash<H: ::core::hash::Hasher>(&self, state: &mut H) {
201//                     <Self as #hash>::hash(
202//                         self, &mut #root::internal::WrapCoreForHasherU64::new(state)
203//                     )
204//                 }
205//             }
206//         }
207//             .to_tokens(&mut output);
208//     }
209//     output.into()
210// }
211//
212// #[proc_macro]
213// pub fn impl_core_hasher(input: TokenStream1) -> TokenStream1 {
214//     let root = crate_root();
215//     let hasher_t = quote!(#root::Hasher);
216//     let hasher_write = quote!(#root::HasherWrite);
217//
218//     let input = parse_macro_input!(input as IdentsWithGenerics);
219//     let mut output = TokenStream::new();
220//
221//     for IdentWithGenerics {
222//         impl_generics,
223//         ident,
224//         use_generics,
225//         mut where_clause,
226//     } in input.punctuated
227//     {
228//         let mut body = quote! {
229//             #[inline(always)]
230//             fn finish(&self) -> u64 {
231//                 <Self as #hasher_t::<u64>>::finish(self)
232//             }
233//
234//             #[inline(always)]
235//             fn write(&mut self, bytes: &[u8]) {
236//                 <Self as #hasher_write>::write(self, bytes)
237//             }
238//         };
239//
240//         for t in [
241//             quote!(u8),
242//             quote!(u16),
243//             quote!(u32),
244//             quote!(u64),
245//             quote!(u128),
246//             quote!(usize),
247//             quote!(i8),
248//             quote!(i16),
249//             quote!(i32),
250//             quote!(i64),
251//             quote!(i128),
252//             quote!(isize),
253//         ] {
254//             let wid = format_ident!("write_{t}");
255//             quote! {
256//                 #[inline(always)]
257//                 fn #wid(&mut self, i: #t) {
258//                     <Self as #hasher_write>::#wid(self, i);
259//                 }
260//             }
261//                 .to_tokens(&mut body);
262//         }
263//
264//         let where_ = fix_where(where_clause.as_mut());
265//         quote! {
266//             impl #impl_generics ::core::hash::Hasher for #ident #use_generics #where_ #where_clause
267//                 Self: #hasher_t<u64>,
268//             {
269//                 #body
270//             }
271//         }
272//             .to_tokens(&mut output);
273//     }
274//     output.into()
275// }
276//
277// #[proc_macro]
278// pub fn impl_core_build_hasher(input: TokenStream1) -> TokenStream1 {
279//     let root = crate_root();
280//     let build_hasher_t = quote!(#root::BuildHasher);
281//
282//     let input = parse_macro_input!(input as IdentsWithGenerics);
283//     let mut output = TokenStream::new();
284//
285//     for IdentWithGenerics {
286//         impl_generics,
287//         ident,
288//         use_generics,
289//         mut where_clause,
290//     } in input.punctuated
291//     {
292//         let where_ = fix_where(where_clause.as_mut());
293//         quote! {
294//             impl #impl_generics ::core::hash::BuildHasher for #ident #use_generics #where_ #where_clause
295//                 Self: #build_hasher_t<u64>,
296//             {
297//                 type Hasher = #root::internal::WrapHasherU64ForCore<<Self as #build_hasher_t::<u64>>::Hasher>;
298//
299//                 #[inline]
300//                 fn build_hasher(&self) -> Self::Hasher {
301//                     Self::Hasher::new(<Self as #build_hasher_t::<u64>>::build_hasher(self))
302//                 }
303//             }
304//         }
305//             .to_tokens(&mut output);
306//     }
307//     output.into()
308// }
309//
310// #[proc_macro]
311// #[allow(non_snake_case)]
312// pub fn impl_hash(input: TokenStream1) -> TokenStream1 {
313//     let root = crate_root();
314//     let hash = quote!(#root::Hash);
315//     let hasher_write = quote!(#root::HasherWrite);
316//
317//     let input = parse_macro_input!(input as IdentsWithGenerics);
318//     let mut output = TokenStream::new();
319//
320//     for IdentWithGenerics {
321//         impl_generics,
322//         ident,
323//         use_generics,
324//         mut where_clause,
325//     } in input.punctuated
326//     {
327//         let SplitGenerics {
328//             lti,
329//             ltt: _,
330//             tpi,
331//             tpt: _,
332//             cpi,
333//             cpt: _,
334//             wc: _,
335//         } = split_generics(&impl_generics);
336//         let where_ = fix_where(where_clause.as_mut());
337//
338//         quote! {
339//             impl<#(#lti,)* #(#tpi,)* #(#cpi,)*> #hash for #ident #use_generics #where_ #where_clause {
340//                 #[inline]
341//                 fn hash<H: #hasher_write>(&self, state: &mut H) {
342//                     <Self as ::core::hash::Hash>::hash(
343//                         self, &mut #root::internal::WrapHasherWriteForCore::new(state)
344//                     )
345//                 }
346//             }
347//         }
348//             .to_tokens(&mut output);
349//     }
350//     output.into()
351// }
352
353fn fix_where(wc: Option<&mut WhereClause>) -> Option<Token![where]> {
354    if let Some(wc) = wc {
355        if wc.predicates.is_empty() {
356            Some(wc.where_token)
357        } else {
358            if !wc.predicates.trailing_punct() {
359                wc.predicates.push_punct(<Token![,]>::default());
360            }
361            None
362        }
363    } else {
364        Some(<Token![where]>::default())
365    }
366}
367
/// The pieces of a `Generics`, pre-split for interpolation into `quote!`.
///
/// Each generic-parameter kind is available twice: as the full declaration
/// (`lti`/`tpi`/`cpi`, with bounds — for the `impl<...>` list) and as the
/// bare name (`ltt`/`tpt`/`cpt` — for the `Type<...>` use site).
struct SplitGenerics<
    'a,
    LTI: Iterator<Item = &'a LifetimeParam>,
    LTT: Iterator<Item = &'a Lifetime>,
    TPI: Iterator<Item = &'a TypeParam>,
    TPT: Iterator<Item = &'a Ident>,
    CPI: Iterator<Item = &'a ConstParam>,
    CPT: Iterator<Item = &'a Ident>,
> {
    // Lifetime params with bounds, e.g. `'a: 'b`.
    lti: LTI,
    // Bare lifetimes, e.g. `'a`.
    ltt: LTT,
    // Type params with bounds, e.g. `T: Clone`.
    tpi: TPI,
    // Bare type-parameter names, e.g. `T`.
    tpt: TPT,
    // Const params with types, e.g. `const N: usize`.
    cpi: CPI,
    // Bare const-parameter names, e.g. `N`.
    cpt: CPT,
    // The original `where` clause, if any.
    wc: &'a Option<WhereClause>,
}
385
/// Splits `generics` into declaration-site and use-site iterators plus the
/// `where` clause; see [`SplitGenerics`] for what each field holds.
fn split_generics(
    generics: &Generics,
) -> SplitGenerics<
    impl Iterator<Item = &LifetimeParam>,
    impl Iterator<Item = &Lifetime>,
    impl Iterator<Item = &TypeParam>,
    impl Iterator<Item = &Ident>,
    impl Iterator<Item = &ConstParam>,
    impl Iterator<Item = &Ident>,
> {
    SplitGenerics {
        lti: generics.lifetimes(),
        ltt: generics.lifetimes().map(|l| &l.lifetime),
        tpi: generics.type_params(),
        tpt: generics.type_params().map(|t| &t.ident),
        cpi: generics.const_params(),
        cpt: generics.const_params().map(|c| &c.ident),
        wc: &generics.where_clause,
    }
}
406
407// struct IdentsWithGenerics {
408//     punctuated: Punctuated<IdentWithGenerics, Token![;]>,
409// }
410//
411// impl Parse for IdentsWithGenerics {
412//     fn parse(input: ParseStream) -> syn::Result<Self> {
413//         let punctuated = Punctuated::parse_terminated(input)?;
414//         Ok(Self { punctuated })
415//     }
416// }
417//
418// struct IdentWithGenerics {
419//     impl_generics: Generics,
420//     ident: Ident,
421//     use_generics: Option<GenericArguments>,
422//     where_clause: Option<WhereClause>,
423// }
424//
425// impl Parse for IdentWithGenerics {
426//     fn parse(input: ParseStream) -> syn::Result<Self> {
427//         let impl_generics = if Option::<Token![impl]>::parse(input)?.is_some() {
428//             Generics::parse(input)?
429//         } else {
430//             Generics::default()
431//         };
432//         let ident = Ident::parse(input)?;
433//         let use_generics = if input.peek(Token![<]) {
434//             Some(GenericArguments::parse(input)?)
435//         } else {
436//             None
437//         };
438//         let where_clause = Option::<WhereClause>::parse(input)?;
439//
440//         Ok(Self {
441//             impl_generics,
442//             ident,
443//             use_generics,
444//             where_clause,
445//         })
446//     }
447// }
448
/// An angle-bracketed generic-argument list, e.g. `<'a, T, 4>`.
///
/// NOTE(review): within this file it is only referenced by the commented-out
/// `impl_*` macros above — presumably kept for when those are re-enabled;
/// confirm before removing.
struct GenericArguments {
    lt_token: Token![<],
    args: Punctuated<GenericArgument, Token![,]>,
    rt_token: Token![>],
}
454
455impl Parse for GenericArguments {
456    fn parse(input: ParseStream) -> syn::Result<Self> {
457        let lt_token = <Token![<]>::parse(input)?;
458
459        let mut args = Punctuated::new();
460        while let Ok(arg) = GenericArgument::parse(input) {
461            args.push(arg);
462            if let Ok(comma) = <Token![,]>::parse(input) {
463                args.push_punct(comma);
464            } else {
465                break;
466            }
467        }
468
469        let rt_token = <Token![>]>::parse(input)?;
470
471        Ok(Self {
472            lt_token,
473            args,
474            rt_token,
475        })
476    }
477}
478
479impl ToTokens for GenericArguments {
480    fn to_tokens(&self, tokens: &mut TokenStream) {
481        self.lt_token.to_tokens(tokens);
482        self.args.to_tokens(tokens);
483        self.rt_token.to_tokens(tokens);
484    }
485}