1use proc_macro2::{Ident, Span, TokenStream};
18use proc_macro_crate::crate_name;
19use proc_macro_error::{abort_call_site, proc_macro_error};
20use syn::*;
21
22use quote::{quote, ToTokens};
23
24#[proc_macro_derive(CanonicalSerialize)]
25pub fn derive_canonical_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
26 let ast = parse_macro_input!(input as DeriveInput);
27 proc_macro::TokenStream::from(impl_canonical_serialize(&ast))
28}
29
30enum IdentOrIndex {
31 Ident(proc_macro2::Ident),
32 Index(Index),
33}
34
35impl ToTokens for IdentOrIndex {
36 fn to_tokens(&self, tokens: &mut TokenStream) {
37 match self {
38 Self::Ident(ident) => ident.to_tokens(tokens),
39 Self::Index(index) => index.to_tokens(tokens),
40 }
41 }
42}
43
44fn impl_serialize_field(
45 serialize_body: &mut Vec<TokenStream>,
46 serialized_size_body: &mut Vec<TokenStream>,
47 serialize_uncompressed_body: &mut Vec<TokenStream>,
48 uncompressed_size_body: &mut Vec<TokenStream>,
49 idents: &mut Vec<IdentOrIndex>,
50 ty: &Type,
51) {
52 match ty {
54 Type::Tuple(tuple) => {
55 for (i, elem_ty) in tuple.elems.iter().enumerate() {
56 let index = Index::from(i);
57 idents.push(IdentOrIndex::Index(index));
58 impl_serialize_field(
59 serialize_body,
60 serialized_size_body,
61 serialize_uncompressed_body,
62 uncompressed_size_body,
63 idents,
64 elem_ty,
65 );
66 idents.pop();
67 }
68 }
69 _ => {
70 serialize_body.push(quote! { CanonicalSerialize::serialize(&self.#(#idents).*, writer)?; });
71 serialized_size_body.push(quote! { size += CanonicalSerialize::serialized_size(&self.#(#idents).*); });
72 serialize_uncompressed_body
73 .push(quote! { CanonicalSerialize::serialize_uncompressed(&self.#(#idents).*, writer)?; });
74 uncompressed_size_body.push(quote! { size += CanonicalSerialize::uncompressed_size(&self.#(#idents).*); });
75 }
76 }
77}
78
79fn impl_canonical_serialize(ast: &syn::DeriveInput) -> TokenStream {
80 let name = &ast.ident;
81
82 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
83
84 let len = if let Data::Struct(ref data_struct) = ast.data {
85 data_struct.fields.len()
86 } else {
87 panic!("Serialize can only be derived for structs, {} is not a struct", name);
88 };
89
90 let mut serialize_body = Vec::<TokenStream>::with_capacity(len);
91 let mut serialized_size_body = Vec::<TokenStream>::with_capacity(len);
92 let mut serialize_uncompressed_body = Vec::<TokenStream>::with_capacity(len);
93 let mut uncompressed_size_body = Vec::<TokenStream>::with_capacity(len);
94
95 match ast.data {
96 Data::Struct(ref data_struct) => {
97 let mut idents = Vec::<IdentOrIndex>::new();
98
99 for (i, field) in data_struct.fields.iter().enumerate() {
100 match field.ident {
101 None => {
102 let index = Index::from(i);
103 idents.push(IdentOrIndex::Index(index));
104 }
105 Some(ref ident) => {
106 idents.push(IdentOrIndex::Ident(ident.clone()));
107 }
108 }
109
110 impl_serialize_field(
111 &mut serialize_body,
112 &mut serialized_size_body,
113 &mut serialize_uncompressed_body,
114 &mut uncompressed_size_body,
115 &mut idents,
116 &field.ty,
117 );
118
119 idents.clear();
120 }
121 }
122 _ => panic!("Serialize can only be derived for structs, {} is not a struct", name),
123 };
124
125 let gen = quote! {
126 impl #impl_generics CanonicalSerialize for #name #ty_generics #where_clause {
127 #[allow(unused_mut, unused_variables)]
128 fn serialize<W: Write>(&self, writer: &mut W) -> Result<(), SerializationError> {
129 #(#serialize_body)*
130 Ok(())
131 }
132 #[allow(unused_mut, unused_variables)]
133 fn serialized_size(&self) -> usize {
134 let mut size = 0;
135 #(#serialized_size_body)*
136 size
137 }
138 #[allow(unused_mut, unused_variables)]
139 fn serialize_uncompressed<W: Write>(&self, writer: &mut W) -> Result<(), SerializationError> {
140 #(#serialize_uncompressed_body)*
141 Ok(())
142 }
143 #[allow(unused_mut, unused_variables)]
144 fn uncompressed_size(&self) -> usize {
145 let mut size = 0;
146 #(#uncompressed_size_body)*
147 size
148 }
149 }
150 };
151 gen
152}
153
154#[proc_macro_derive(CanonicalDeserialize)]
155pub fn derive_canonical_deserialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
156 let ast = parse_macro_input!(input as DeriveInput);
157 proc_macro::TokenStream::from(impl_canonical_deserialize(&ast))
158}
159
160fn impl_deserialize_field(ty: &Type) -> (TokenStream, TokenStream) {
163 match ty {
165 Type::Tuple(tuple) => {
166 let (compressed_fields, uncompressed_fields): (Vec<_>, Vec<_>) =
167 tuple.elems.iter().map(impl_deserialize_field).unzip();
168 (
169 quote! { (#(#compressed_fields)*), },
170 quote! { (#(#uncompressed_fields)*), },
171 )
172 }
173 _ => (
174 quote! { CanonicalDeserialize::deserialize(reader)?, },
175 quote! { CanonicalDeserialize::deserialize_uncompressed(reader)?, },
176 ),
177 }
178}
179
180fn impl_canonical_deserialize(ast: &syn::DeriveInput) -> TokenStream {
181 let name = &ast.ident;
182
183 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
184
185 let deserialize_body;
186 let deserialize_uncompressed_body;
187
188 match ast.data {
189 Data::Struct(ref data_struct) => {
190 let mut tuple = false;
191 let mut compressed_field_cases = Vec::<TokenStream>::with_capacity(data_struct.fields.len());
192 let mut uncompressed_field_cases = Vec::<TokenStream>::with_capacity(data_struct.fields.len());
193 for field in data_struct.fields.iter() {
194 match &field.ident {
195 None => {
196 tuple = true;
197 let (compressed, uncompressed) = impl_deserialize_field(&field.ty);
198 compressed_field_cases.push(compressed);
199 uncompressed_field_cases.push(uncompressed);
200 }
201 Some(ident) => {
203 let (compressed_field, uncompressed_field) = impl_deserialize_field(&field.ty);
204 compressed_field_cases.push(quote! { #ident: #compressed_field });
205 uncompressed_field_cases.push(quote! { #ident: #uncompressed_field });
206 }
207 }
208 }
209
210 if tuple {
211 deserialize_body = quote!({
212 Ok(#name (
213 #(#compressed_field_cases)*
214 ))
215 });
216 deserialize_uncompressed_body = quote!({
217 Ok(#name (
218 #(#uncompressed_field_cases)*
219 ))
220 });
221 } else {
222 deserialize_body = quote!({
223 Ok(#name {
224 #(#compressed_field_cases)*
225 })
226 });
227 deserialize_uncompressed_body = quote!({
228 Ok(#name {
229 #(#uncompressed_field_cases)*
230 })
231 });
232 }
233 }
234 _ => panic!("Deserialize can only be derived for structs, {} is not a Struct", name),
235 };
236
237 let gen = quote! {
238 impl #impl_generics CanonicalDeserialize for #name #ty_generics #where_clause {
239 #[allow(unused_mut,unused_variables)]
240 fn deserialize<R: Read>(reader: &mut R) -> Result<Self, SerializationError> {
241 #deserialize_body
242 }
243 #[allow(unused_mut,unused_variables)]
244 fn deserialize_uncompressed<R: Read>(reader: &mut R) -> Result<Self, SerializationError> {
245 #deserialize_uncompressed_body
246 }
247 }
248 };
249 gen
250}
251
252#[proc_macro_error]
253#[proc_macro_attribute]
254pub fn test_with_metrics(_: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
255 match parse::<ItemFn>(item.clone()) {
256 Ok(function) => {
257 fn generate_test_function(function: ItemFn, crate_name: Ident) -> proc_macro::TokenStream {
258 let name = &function.sig.ident;
259 let statements = function.block.stmts;
260 (quote! {
261 #[test]
264 #[serial]
265 fn #name() {
266 #crate_name::testing::initialize_prometheus_for_testing();
268 assert_eq!(0, #crate_name::Metrics::get_connected_peers());
270 #(#statements)*
272 assert_eq!(0, Metrics::get_connected_peers());
274 }
275 })
276 .into()
277 }
278 let name = crate_name("snarkos-metrics").unwrap_or_else(|_| "crate".to_string());
279 let crate_name = Ident::new(&name, Span::call_site());
280 generate_test_function(function, crate_name)
281 }
282 _ => abort_call_site!("test_with_metrics only works on functions"),
283 }
284}