1extern crate proc_macro;
2use quote::quote;
3
4#[proc_macro_derive(Encode)]
5pub fn encode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
6 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
7 let name = &ast.ident;
8 let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
9 let body = match ast.data {
10 syn::Data::Struct(ref data) => generate_encode_for_struct(data, &name),
11 syn::Data::Enum(ref data) => generate_encode_for_enum(data, &name),
12 syn::Data::Union(_) => unimplemented!("Unions are not supported"),
13 };
14 let expanded = quote! {
15 #[automatically_derived]
16 impl #impl_generics ::cerdito::Encode for #name #type_generics #where_clause {
17 #[_async] fn encode<__CerditoEncoderTypeParam: ::cerdito::Encoder>(
18 &self,
19 encoder: &mut __CerditoEncoderTypeParam
20 ) -> Result<(), __CerditoEncoderTypeParam::Error> {
21 #body
22 }
23 }
24 };
25 proc_macro::TokenStream::from(expanded)
26}
27
28#[proc_macro_derive(Decode)]
29pub fn decode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
31 let name = &ast.ident;
32 let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
33 let body = match ast.data {
34 syn::Data::Struct(ref data) => generate_decode_for_struct(data, &name),
35 syn::Data::Enum(ref data) => generate_decode_for_enum(data, &name),
36 syn::Data::Union(_) => unimplemented!("Unions are not supported"),
37 };
38 let expanded = quote! {
39 #[automatically_derived]
40 impl #impl_generics ::cerdito::Decode for #name #type_generics #where_clause {
41 #[_async] fn decode<__CerditoDecoderTypeParam: ::cerdito::Decoder>(
42 decoder: &mut __CerditoDecoderTypeParam
43 ) -> Result<Self, __CerditoDecoderTypeParam::Error> {
44 #body
45 }
46 }
47 };
48 proc_macro::TokenStream::from(expanded)
49}
50
51fn get_fields(fields: &syn::Fields) -> Vec<(usize, proc_macro2::Ident, String, syn::Type)> {
52 match fields {
53 syn::Fields::Named(fields) => fields
54 .named
55 .iter()
56 .enumerate()
57 .map(|(i, f)| {
58 let default_name = format!("field_{}", i);
59 let field_ident = f
60 .ident
61 .clone()
62 .or(Some(proc_macro2::Ident::new(
63 &default_name,
64 proc_macro2::Span::call_site(),
65 )))
66 .unwrap();
67 let field_name = field_ident.to_string();
68 (i, field_ident, field_name, f.ty.clone())
69 })
70 .collect(),
71 syn::Fields::Unnamed(fields) => fields
72 .unnamed
73 .iter()
74 .enumerate()
75 .map(|(i, f)| {
76 let default_name = format!("field_{}", i);
77 let field_ident = f
78 .ident
79 .clone()
80 .or(Some(proc_macro2::Ident::new(
81 &default_name,
82 proc_macro2::Span::call_site(),
83 )))
84 .unwrap();
85 let field_name = field_ident.to_string();
86 (i, field_ident, field_name, f.ty.clone())
87 })
88 .collect(),
89 syn::Fields::Unit => vec![],
90 }
91}
92
93fn generate_encode_for_struct(
94 data: &syn::DataStruct,
95 name: &proc_macro2::Ident,
96) -> proc_macro2::TokenStream {
97 let name_str = name.to_string();
98 let fields = get_fields(&data.fields);
99 let field_idents: Vec<_> = fields
100 .iter()
101 .map(|(_, ident, _, _)| ident.clone())
102 .collect();
103 let field_codes: Vec<_> = fields
104 .iter()
105 .map(|(i, field_ident, field_name, _)| {
106 quote! {
107 _await!(encoder.encode_elem_begin(#i, Some(#field_name)))?;
108 _await!(#field_ident.encode(encoder))?;
109 _await!(encoder.encode_elem_end())?;
110 }
111 })
112 .collect();
113 let fields_len = fields.len();
114 match &data.fields {
115 syn::Fields::Named(_) => quote! {
116 _await!(encoder.encode_struct_begin(#fields_len, Some(#name_str)))?;
117 let Self { #(#field_idents),* } = self;
118 #(#field_codes)*
119 _await!(encoder.encode_struct_end())?;
120 Ok(())
121 },
122 syn::Fields::Unnamed(_) => quote! {
123 _await!(encoder.encode_struct_begin(#fields_len, Some(#name_str)))?;
124 let Self( #(#field_idents),* ) = self;
125 #(#field_codes)*
126 _await!(encoder.encode_struct_end())?;
127 Ok(())
128 },
129 syn::Fields::Unit => quote! {
130 _await!(encoder.encode_struct_begin(#fields_len, Some(#name_str)))?;
131 _await!(encoder.encode_struct_end())?;
132 Ok(())
133 },
134 }
135}
136
137fn generate_decode_for_struct(
138 data: &syn::DataStruct,
139 name: &proc_macro2::Ident,
140) -> proc_macro2::TokenStream {
141 let name_str = name.to_string();
142 let fields = get_fields(&data.fields);
143 let field_idents: Vec<_> = fields
144 .iter()
145 .map(|(_, ident, _, _)| ident.clone())
146 .collect();
147 let field_codes: Vec<_> = fields
148 .iter()
149 .map(|(i, field_ident, field_name, field_type)| {
150 quote! {
151 _await!(decoder.decode_elem_begin(#i, Some(#field_name)))?;
152 let #field_ident = if #i < __cerdito_len {
153 _await!(<#field_type as ::cerdito::Decode>::decode(decoder))?
154 } else { <#field_type>::default() };
157 _await!(decoder.decode_elem_end())?;
158 }
159 })
160 .collect();
161 let fields_len = fields.len();
162
163 let compat = quote! {
164 if __cerdito_len > #fields_len {
166 _await!(decoder.decode_skip(__cerdito_len - #fields_len))?;
167 }
168 };
169
170 match &data.fields {
171 syn::Fields::Named(_) => quote! {
172 let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, Some(#name_str)))?;
173 #(#field_codes)*
174 #compat
175 _await!(decoder.decode_struct_end())?;
176 Ok(Self { #(#field_idents),* })
177 },
178 syn::Fields::Unnamed(_) => quote! {
179 let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, Some(#name_str)))?;
180 #(#field_codes)*
181 #compat
182 _await!(decoder.decode_struct_end())?;
183 Ok(Self( #(#field_idents),* ))
184 },
185 syn::Fields::Unit => quote! {
186 let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, Some(#name_str)))?;
187 #compat
188 _await!(decoder.decode_struct_end())?;
189 Ok(Self)
190 },
191 }
192}
193
194fn generate_tags(data: &syn::DataEnum) -> Vec<proc_macro2::TokenStream> {
195 let mut current_expr: Option<proc_macro2::TokenStream> = None;
196 let mut current_incr: u32 = 0;
197 data.variants
198 .iter()
199 .map(|v| match &v.discriminant {
200 Some((_, expr)) => {
201 let e = quote! { #expr };
202 current_expr = Some(e.clone());
203 current_incr = 1;
204 e
205 }
206 None => match ¤t_expr {
207 Some(expr) => {
208 let e = quote! { (#expr) + #current_incr };
209 current_incr += 1;
210 e
211 }
212 None => {
213 let e = quote! { #current_incr };
214 current_incr += 1;
215 e
216 }
217 },
218 })
219 .collect()
220}
221
222fn generate_encode_for_enum(
223 data: &syn::DataEnum,
224 name: &proc_macro2::Ident,
225) -> proc_macro2::TokenStream {
226 let name_str = name.to_string();
227 let tags = generate_tags(data);
228 let variant_codes: Vec<_> = data.variants.iter().zip(tags).enumerate().map(|(i, (v, t))| {
229 let variant_name = v.ident.clone();
230 let variant_name_str = v.ident.to_string();
231 let fields = get_fields(&v.fields);
232 let field_idents: Vec<_> = fields
233 .iter()
234 .map(|(_, ident, _, _)| ident.clone())
235 .collect();
236 let field_codes: Vec<_> = fields
237 .iter()
238 .map(|(_i, field_ident, field_name, _field_type)| {
239 quote! {
240 _await!(encoder.encode_elem_begin(#i, Some(#field_name)))?;
241 _await!(#field_ident.encode(encoder))?;
242 _await!(encoder.encode_elem_end())?;
243 }
244 })
245 .collect();
246 let fields_len = fields.len();
247 match &v.fields {
248 syn::Fields::Named(_) => quote! {
249 Self::#variant_name { #(#field_idents),* } => {
250 let __cerdito_enum_tag: u32 = (#t).try_into().unwrap(); _await!(encoder.encode_enum_begin(__cerdito_enum_tag, 1, #name_str, #variant_name_str))?;
252 _await!(encoder.encode_struct_begin(#fields_len, None))?;
253 #(#field_codes)*
254 _await!(encoder.encode_struct_end())?;
255 _await!(encoder.encode_enum_end())?;
256 }
257 },
258 syn::Fields::Unnamed(_) => quote! {
259 Self::#variant_name(#(#field_idents),*) => {
260 let __cerdito_enum_tag: u32 = (#t).try_into().unwrap(); _await!(encoder.encode_enum_begin(__cerdito_enum_tag, 1, #name_str, #variant_name_str))?;
262 _await!(encoder.encode_struct_begin(#fields_len, None))?;
263 #(#field_codes)*
264 _await!(encoder.encode_struct_end())?;
265 _await!(encoder.encode_enum_end())?;
266 }
267 },
268 syn::Fields::Unit => quote! {
269 Self::#variant_name => {
270 let __cerdito_enum_tag: u32 = (#t).try_into().unwrap(); _await!(encoder.encode_enum_begin(__cerdito_enum_tag, 0, #name_str, #variant_name_str))?;
272 _await!(encoder.encode_enum_end())?;
273 }
274 },
275 }
276 }).collect();
277
278 quote! {
279 match self {
280 #(#variant_codes)*
281 }
282 Ok(())
283 }
284}
285
286fn generate_decode_for_enum(
287 data: &syn::DataEnum,
288 name: &proc_macro2::Ident,
289) -> proc_macro2::TokenStream {
290 let name_str = name.to_string();
291 let tags = generate_tags(data);
292 let variant_codes: Vec<_> = data
293 .variants
294 .iter()
295 .zip(tags)
296 .enumerate()
297 .map(|(_i, (v, t))| {
298 let variant_name = v.ident.clone();
299 let fields = get_fields(&v.fields);
300 let field_idents: Vec<_> = fields
301 .iter()
302 .map(|(_, ident, _, _)| ident.clone())
303 .collect();
304 let field_codes: Vec<_> = fields
305 .iter()
306 .map(|(i, field_ident, field_name, field_type)| {
307 quote! {
308 _await!(decoder.decode_elem_begin(#i, Some(#field_name)))?;
309 let #field_ident = if #i < __cerdito_len {
310 _await!(<#field_type as ::cerdito::Decode>::decode(decoder))?
311 } else { <#field_type>::default()
313 };
314 _await!(decoder.decode_elem_end())?;
315 }
316 })
317 .collect();
318
319 let field_defaults: Vec<_> = fields
320 .iter()
321 .map(|(_i, field_ident, _field_name, field_type)| {
322 quote! {
323 let #field_ident = <#field_type>::default();
324 }
325 })
326 .collect();
327
328 let fields_len = fields.len();
329
330 let compat = quote! {
331 if __cerdito_len > #fields_len {
333 _await!(decoder.decode_skip(__cerdito_len - #fields_len))?;
334 }
335 };
336
337 match &v.fields {
338 syn::Fields::Named(_) => quote! {
339 #t => {
340 match __cerdito_enum_len {
341 0 => {
342 #(#field_defaults)*
343 Self::#variant_name { #(#field_idents),* }
344 }
345 1 => {
346 let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, None))?;
347 #(#field_codes)*
348 #compat
349 _await!(decoder.decode_struct_end())?;
350 Self::#variant_name { #(#field_idents),* }
351 }
352 _ => unreachable!(),
353 }
354 }
355 },
356 syn::Fields::Unnamed(_) => quote! {
357 #t => {
358 match __cerdito_enum_len {
359 0 => {
360 #(#field_defaults)*
361 Self::#variant_name(#(#field_idents),*)
362 }
363 1 => {
364 let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, None))?;
365 #(#field_codes)*
366 #compat
367 _await!(decoder.decode_struct_end())?;
368 Self::#variant_name(#(#field_idents),*)
369 }
370 _ => unreachable!(),
371 }
372 }
373 },
374 syn::Fields::Unit => quote! {
375 #t => {
376 match __cerdito_enum_len {
377 0 => {
378 Self::#variant_name
379 }
380 1 => {
381 let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, None))?;
382 #compat
383 _await!(decoder.decode_struct_end())?;
384 Self::#variant_name
385 }
386 _ => unreachable!(),
387 }
388 }
389 },
390 }
391 })
392 .collect();
393
394 quote! {
395 let (__cerdito_enum_tag, __cerdito_enum_len) = _await!(decoder.decode_enum_begin(#name_str))?;
396 let __cerdito_enum_value = match __cerdito_enum_tag.try_into().unwrap() { #(#variant_codes)*
398 _ => panic!("Enum {:?} doesn't support variant {}", #name_str, __cerdito_enum_tag),
399 };
400 _await!(decoder.decode_enum_end())?;
401 Ok(__cerdito_enum_value)
402 }
403}