Skip to main content

frozone_derive/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2extern crate proc_macro2;
3use proc_macro::TokenStream;
4use quote::quote;
5
6#[proc_macro_derive(Freezable, attributes(assume_frozen))]
7pub fn derive_freezable(input: TokenStream) -> TokenStream {
8    let ast: syn::DeriveInput = syn::parse_macro_input!(input);
9    let name = &ast.ident;
10    let generics = ast.generics.split_for_impl();
11
12    match ast.data {
13        syn::Data::Struct(data) => derive_freezable_struct(data, name, generics),
14        syn::Data::Enum(data) => derive_freezable_enum(data, name, generics),
15        _ => unimplemented!(),
16    }
17}
18
19/// generate Freezable impl for the enum
20/// (that recursively call `freeze()` on all non-excluded
21/// variant and their fields' types)
22fn derive_freezable_enum(
23    data: syn::DataEnum,
24    name: &syn::Ident,
25    generics: (
26        syn::ImplGenerics,
27        syn::TypeGenerics,
28        Option<&syn::WhereClause>,
29    ),
30) -> TokenStream {
31    let variants_names_and_freezes = data.variants.iter().map(|f| {
32        let name = &f.ident;
33        if let Some(af) = f.attrs.iter().find(|a| a.path().is_ident("assume_frozen")) {
34            if attr_helper_freeze_generics(&af) {
35                // the variant's field types still freezes their generic arguments
36                // (but not themselves)
37                let variant_fields = f
38                    .fields
39                    .iter()
40                    .map(|g| freeze_field_only_generics(&g.ty, name));
41
42                quote! {
43                    (stringify!(#name), {
44                        let mut hasher = core::hash::SipHasher::new();
45
46                        [#(#variant_fields,)*].iter().for_each(|(_name, x): &(&str, u64)| {
47                            x.hash(&mut hasher);
48                        });
49                        hasher.finish()
50                    })
51                }
52            } else {
53                // completely ignore the variant
54                quote! {
55                    (stringify!(#name), 0)
56                }
57            }
58        } else {
59            // handle simple cases such as `enum M {A = 1}`
60            let discriminant = f
61                .discriminant
62                .as_ref()
63                .map(|eq_d| eq_d.1.clone())
64                .map(|d| {
65                    quote! {
66                    use core::hash::Hasher;
67                    let mut hasher = core::hash::SipHasher::new();
68                        (#d).hash(&mut hasher);
69                    hasher.finish()
70                    }
71                })
72                .unwrap_or(quote! {0});
73
74            // freeze all fields of a variant `enum M { A(u8, OtherType, etc...) }`
75            let variant_fields = f.fields.iter().map(|g| {
76                let g_ty = &g.ty;
77                quote! {
78                    <#g_ty as frozone::Freezable>::freeze()
79                }
80            });
81
82            // combine all into the enum's final freeze
83            quote! {
84                (stringify!(#name), {
85                    let mut hasher = core::hash::SipHasher::new();
86
87                    #discriminant.hash(&mut hasher);
88                    [#(#variant_fields,)*].iter().for_each(|x: &u64| {
89                        x.hash(&mut hasher);
90                    });
91                    hasher.finish()
92                })
93            }
94        }
95    });
96
97    let (impl_generics, type_generics, where_clause) = generics;
98    quote! {
99        impl #impl_generics frozone::Freezable for #name #type_generics #where_clause {
100            fn freeze() -> u64 {
101                use core::hash::{Hash, Hasher};
102
103                [#(#variants_names_and_freezes,)*].iter().fold(0u64, |acc, x| {
104                    let mut hasher = core::hash::SipHasher::new();
105                    x.0.hash(&mut hasher);
106                    x.1.hash(&mut hasher);
107                    acc.overflowing_add(hasher.finish()).0
108                })
109            }
110        }
111    }
112    .into()
113}
114
115/// generate Freezable impl for the struct (that recursively
116/// call `freeze()` on all non-excluded fields' types
117fn derive_freezable_struct(
118    data: syn::DataStruct,
119    name: &syn::Ident,
120    generics: (
121        syn::ImplGenerics,
122        syn::TypeGenerics,
123        Option<&syn::WhereClause>,
124    ),
125) -> TokenStream {
126    let fields = data.fields.iter().map(|f| {
127        let name = &f.ident.as_ref().unwrap();
128        let ty = &f.ty;
129        if let Some(af) = f.attrs.iter().find(|a| a.path().is_ident("assume_frozen")) {
130            if attr_helper_freeze_generics(&af) {
131                // field type still freezes the generic arguments of its type
132                // (but not the type itself)
133                freeze_field_only_generics(ty, name)
134            } else {
135                quote! {
136                    (stringify!(#name), 0)
137                }
138            }
139        } else {
140            quote! {
141                (stringify!(#name), <#ty as frozone::Freezable>::freeze())
142            }
143        }
144    });
145
146    let (impl_generics, type_generics, where_clause) = generics;
147    let generated = quote! {
148        impl #impl_generics frozone::Freezable for #name #type_generics #where_clause {
149            fn freeze() -> u64 {
150                use core::hash::{Hash, Hasher};
151
152                [#(#fields,)*].iter().fold(0u64, |acc, x| {
153                    let mut hasher = core::hash::SipHasher::new();
154                    x.0.hash(&mut hasher);
155                    x.1.hash(&mut hasher);
156                    acc.overflowing_add(hasher.finish()).0
157                })
158            }
159        }
160    };
161
162    generated.into()
163}
164
165/// generate a quote! that freezes a type but only over its generic
166/// arguments (they must impl Freezable)
167fn freeze_field_only_generics(ty: &syn::Type, name: &syn::Ident) -> proc_macro2::TokenStream {
168    match ty {
169        syn::Type::Path(p) => {
170            let type_segments = p.path.segments.iter().map(|ps| {
171                match &ps.arguments {
172                    syn::PathArguments::None => quote! {
173                        {
174                            let mut hasher = core::hash::SipHasher::new();
175                            hasher.finish()
176                        }
177                    },
178                    syn::PathArguments::AngleBracketed(bracketed) => {
179                        let generics = bracketed
180                            .args
181                            .iter()
182                            .filter_map(|g| match g {
183                                syn::GenericArgument::Type(t) => Some(t),
184                                _ => None,
185                            })
186                            .map(|t| {
187                                quote! {
188                                <#t as frozone::Freezable>::freeze()
189                                }
190                            });
191                        quote! {{
192                            let mut hasher = core::hash::SipHasher::new();
193                            "GenericType".hash(&mut hasher); // prevent collisions with parenthesized generics
194                            [#(#generics,)*].iter().for_each(|x: &u64| {
195                                x.hash(&mut hasher);
196                            });
197                            hasher.finish()
198                        }}
199                    }
200                    // not sure how those would be constructed though
201                    syn::PathArguments::Parenthesized(parenthesized) => {
202                        let generic_output = match &parenthesized.output {
203                            syn::ReturnType::Default => quote! {
204                                <() as frozone::Freezable>::freeze()
205                            },
206                            syn::ReturnType::Type(_, box_of_t) => {
207                                let inner_type = *box_of_t.clone();
208                                quote! {
209                                    <#inner_type as frozone::Freezable>::freeze()
210                                }
211                            }
212                        };
213                        let generic_input = parenthesized.inputs.iter().map(|t| {
214                            quote! {
215                            <#t as frozone::Freezable>::freeze()
216                            }
217                        });
218
219                        quote! {{
220
221                            let mut hasher = core::hash::SipHasher::new();
222                            "GenericFunc".hash(&mut hasher); // prevent collisions with bracketed generics
223                            [#(#generic_input,)*].iter().for_each(|x: &u64| {
224                                x.hash(&mut hasher);
225                            });
226                            #generic_output.hash(&mut hasher);
227                            hasher.finish()
228                        }}
229                    }
230                }
231            });
232            quote! {
233                (stringify!(#name), {
234                    let mut hasher = core::hash::SipHasher::new();
235
236                    [#(#type_segments,)*].iter().for_each(|x: &u64| {
237                        x.hash(&mut hasher);
238                    });
239                    hasher.finish()
240                })
241            }
242        }
243        _ => {
244            panic!("type not a path");
245        }
246    }
247}
248fn attr_helper_freeze_generics(attr: &syn::Attribute) -> bool {
249    let mut found_freeze_generic = false;
250    let _ = attr.parse_nested_meta(|meta| {
251        if meta.path.is_ident("freeze_generics") {
252            found_freeze_generic = true;
253        }
254        Ok(())
255    }); // may be err if there is no parenthesis inside the #[assume_frozen] attr
256    // println!("found generic {found_freeze_generic:?}");
257    return found_freeze_generic;
258}