// File: snarkvm_derives/lib.rs

// Copyright (C) 2019-2021 Aleo Systems Inc.
// This file is part of the snarkVM library.

// The snarkVM library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The snarkVM library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the snarkVM library. If not, see <https://www.gnu.org/licenses/>.

17use proc_macro2::{Ident, Span, TokenStream};
18use proc_macro_crate::crate_name;
19use proc_macro_error::{abort_call_site, proc_macro_error};
20use syn::*;
21
22use quote::{quote, ToTokens};
23
24#[proc_macro_derive(CanonicalSerialize)]
25pub fn derive_canonical_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
26    let ast = parse_macro_input!(input as DeriveInput);
27    proc_macro::TokenStream::from(impl_canonical_serialize(&ast))
28}
29
30enum IdentOrIndex {
31    Ident(proc_macro2::Ident),
32    Index(Index),
33}
34
35impl ToTokens for IdentOrIndex {
36    fn to_tokens(&self, tokens: &mut TokenStream) {
37        match self {
38            Self::Ident(ident) => ident.to_tokens(tokens),
39            Self::Index(index) => index.to_tokens(tokens),
40        }
41    }
42}
43
44fn impl_serialize_field(
45    serialize_body: &mut Vec<TokenStream>,
46    serialized_size_body: &mut Vec<TokenStream>,
47    serialize_uncompressed_body: &mut Vec<TokenStream>,
48    uncompressed_size_body: &mut Vec<TokenStream>,
49    idents: &mut Vec<IdentOrIndex>,
50    ty: &Type,
51) {
52    // Check if type is a tuple.
53    match ty {
54        Type::Tuple(tuple) => {
55            for (i, elem_ty) in tuple.elems.iter().enumerate() {
56                let index = Index::from(i);
57                idents.push(IdentOrIndex::Index(index));
58                impl_serialize_field(
59                    serialize_body,
60                    serialized_size_body,
61                    serialize_uncompressed_body,
62                    uncompressed_size_body,
63                    idents,
64                    elem_ty,
65                );
66                idents.pop();
67            }
68        }
69        _ => {
70            serialize_body.push(quote! { CanonicalSerialize::serialize(&self.#(#idents).*, writer)?; });
71            serialized_size_body.push(quote! { size += CanonicalSerialize::serialized_size(&self.#(#idents).*); });
72            serialize_uncompressed_body
73                .push(quote! { CanonicalSerialize::serialize_uncompressed(&self.#(#idents).*, writer)?; });
74            uncompressed_size_body.push(quote! { size += CanonicalSerialize::uncompressed_size(&self.#(#idents).*); });
75        }
76    }
77}
78
79fn impl_canonical_serialize(ast: &syn::DeriveInput) -> TokenStream {
80    let name = &ast.ident;
81
82    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
83
84    let len = if let Data::Struct(ref data_struct) = ast.data {
85        data_struct.fields.len()
86    } else {
87        panic!("Serialize can only be derived for structs, {} is not a struct", name);
88    };
89
90    let mut serialize_body = Vec::<TokenStream>::with_capacity(len);
91    let mut serialized_size_body = Vec::<TokenStream>::with_capacity(len);
92    let mut serialize_uncompressed_body = Vec::<TokenStream>::with_capacity(len);
93    let mut uncompressed_size_body = Vec::<TokenStream>::with_capacity(len);
94
95    match ast.data {
96        Data::Struct(ref data_struct) => {
97            let mut idents = Vec::<IdentOrIndex>::new();
98
99            for (i, field) in data_struct.fields.iter().enumerate() {
100                match field.ident {
101                    None => {
102                        let index = Index::from(i);
103                        idents.push(IdentOrIndex::Index(index));
104                    }
105                    Some(ref ident) => {
106                        idents.push(IdentOrIndex::Ident(ident.clone()));
107                    }
108                }
109
110                impl_serialize_field(
111                    &mut serialize_body,
112                    &mut serialized_size_body,
113                    &mut serialize_uncompressed_body,
114                    &mut uncompressed_size_body,
115                    &mut idents,
116                    &field.ty,
117                );
118
119                idents.clear();
120            }
121        }
122        _ => panic!("Serialize can only be derived for structs, {} is not a struct", name),
123    };
124
125    let gen = quote! {
126        impl #impl_generics CanonicalSerialize for #name #ty_generics #where_clause {
127            #[allow(unused_mut, unused_variables)]
128            fn serialize<W: Write>(&self, writer: &mut W) -> Result<(), SerializationError> {
129                #(#serialize_body)*
130                Ok(())
131            }
132            #[allow(unused_mut, unused_variables)]
133            fn serialized_size(&self) -> usize {
134                let mut size = 0;
135                #(#serialized_size_body)*
136                size
137            }
138            #[allow(unused_mut, unused_variables)]
139            fn serialize_uncompressed<W: Write>(&self, writer: &mut W) -> Result<(), SerializationError> {
140                #(#serialize_uncompressed_body)*
141                Ok(())
142            }
143            #[allow(unused_mut, unused_variables)]
144            fn uncompressed_size(&self) -> usize {
145                let mut size = 0;
146                #(#uncompressed_size_body)*
147                size
148            }
149        }
150    };
151    gen
152}
153
154#[proc_macro_derive(CanonicalDeserialize)]
155pub fn derive_canonical_deserialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
156    let ast = parse_macro_input!(input as DeriveInput);
157    proc_macro::TokenStream::from(impl_canonical_deserialize(&ast))
158}
159
160/// Returns two TokenStreams, one for the compressed deserialize, one for the
161/// uncompressed.
162fn impl_deserialize_field(ty: &Type) -> (TokenStream, TokenStream) {
163    // Check if type is a tuple.
164    match ty {
165        Type::Tuple(tuple) => {
166            let (compressed_fields, uncompressed_fields): (Vec<_>, Vec<_>) =
167                tuple.elems.iter().map(impl_deserialize_field).unzip();
168            (
169                quote! { (#(#compressed_fields)*), },
170                quote! { (#(#uncompressed_fields)*), },
171            )
172        }
173        _ => (
174            quote! { CanonicalDeserialize::deserialize(reader)?, },
175            quote! { CanonicalDeserialize::deserialize_uncompressed(reader)?, },
176        ),
177    }
178}
179
180fn impl_canonical_deserialize(ast: &syn::DeriveInput) -> TokenStream {
181    let name = &ast.ident;
182
183    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
184
185    let deserialize_body;
186    let deserialize_uncompressed_body;
187
188    match ast.data {
189        Data::Struct(ref data_struct) => {
190            let mut tuple = false;
191            let mut compressed_field_cases = Vec::<TokenStream>::with_capacity(data_struct.fields.len());
192            let mut uncompressed_field_cases = Vec::<TokenStream>::with_capacity(data_struct.fields.len());
193            for field in data_struct.fields.iter() {
194                match &field.ident {
195                    None => {
196                        tuple = true;
197                        let (compressed, uncompressed) = impl_deserialize_field(&field.ty);
198                        compressed_field_cases.push(compressed);
199                        uncompressed_field_cases.push(uncompressed);
200                    }
201                    // struct field without len_type
202                    Some(ident) => {
203                        let (compressed_field, uncompressed_field) = impl_deserialize_field(&field.ty);
204                        compressed_field_cases.push(quote! { #ident: #compressed_field });
205                        uncompressed_field_cases.push(quote! { #ident: #uncompressed_field });
206                    }
207                }
208            }
209
210            if tuple {
211                deserialize_body = quote!({
212                    Ok(#name (
213                        #(#compressed_field_cases)*
214                    ))
215                });
216                deserialize_uncompressed_body = quote!({
217                    Ok(#name (
218                        #(#uncompressed_field_cases)*
219                    ))
220                });
221            } else {
222                deserialize_body = quote!({
223                    Ok(#name {
224                        #(#compressed_field_cases)*
225                    })
226                });
227                deserialize_uncompressed_body = quote!({
228                    Ok(#name {
229                        #(#uncompressed_field_cases)*
230                    })
231                });
232            }
233        }
234        _ => panic!("Deserialize can only be derived for structs, {} is not a Struct", name),
235    };
236
237    let gen = quote! {
238        impl #impl_generics CanonicalDeserialize for #name #ty_generics #where_clause {
239            #[allow(unused_mut,unused_variables)]
240            fn deserialize<R: Read>(reader: &mut R) -> Result<Self, SerializationError> {
241                #deserialize_body
242            }
243            #[allow(unused_mut,unused_variables)]
244            fn deserialize_uncompressed<R: Read>(reader: &mut R) -> Result<Self, SerializationError> {
245                #deserialize_uncompressed_body
246            }
247        }
248    };
249    gen
250}
251
252#[proc_macro_error]
253#[proc_macro_attribute]
254pub fn test_with_metrics(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
255    match parse::<ItemFn>(item.clone()) {
256        Ok(function) => {
257            fn generate_test_function(function: ItemFn, crate_name: Ident) -> proc_macro::TokenStream {
258                let name = &function.sig.ident;
259                let statements = function.block.stmts;
260                (quote! {
261                    // Generates a new test with Prometheus registry checks, and enforces
262                    // that the test runs serially with other tests that use metrics.
263                    #[test]
264                    #[serial]
265                    fn #name() {
266                        // Initialize Prometheus once in the test environment.
267                        #crate_name::testing::initialize_prometheus_for_testing();
268                        // Check that all metrics are 0 or empty.
269                        assert_eq!(0, #crate_name::Metrics::get_connected_peers());
270                        // Run the test logic.
271                        #(#statements)*
272                        // Check that all metrics are reset to 0 or empty (for next test).
273                        assert_eq!(0, Metrics::get_connected_peers());
274                    }
275                })
276                .into()
277            }
278            let name = crate_name("snarkos-metrics").unwrap_or_else(|_| "crate".to_string());
279            let crate_name = Ident::new(&name, Span::call_site());
280            generate_test_function(function, crate_name)
281        }
282        _ => abort_call_site!("test_with_metrics only works on functions"),
283    }
284}