cerdito_derive/
lib.rs

1extern crate proc_macro;
2use quote::quote;
3
4#[proc_macro_derive(Encode)]
5pub fn encode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
6    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
7    let name = &ast.ident;
8    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
9    let body = match ast.data {
10        syn::Data::Struct(ref data) => generate_encode_for_struct(data, &name),
11        syn::Data::Enum(ref data) => generate_encode_for_enum(data, &name),
12        syn::Data::Union(_) => unimplemented!("Unions are not supported"),
13    };
14    let expanded = quote! {
15        #[automatically_derived]
16        impl #impl_generics ::cerdito::Encode for #name #type_generics #where_clause {
17            #[_async] fn encode<__CerditoEncoderTypeParam: ::cerdito::Encoder>(
18                &self,
19                encoder: &mut __CerditoEncoderTypeParam
20            ) -> Result<(), __CerditoEncoderTypeParam::Error> {
21                #body
22            }
23        }
24    };
25    proc_macro::TokenStream::from(expanded)
26}
27
28#[proc_macro_derive(Decode)]
29pub fn decode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
31    let name = &ast.ident;
32    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
33    let body = match ast.data {
34        syn::Data::Struct(ref data) => generate_decode_for_struct(data, &name),
35        syn::Data::Enum(ref data) => generate_decode_for_enum(data, &name),
36        syn::Data::Union(_) => unimplemented!("Unions are not supported"),
37    };
38    let expanded = quote! {
39        #[automatically_derived]
40        impl #impl_generics ::cerdito::Decode for #name #type_generics #where_clause {
41            #[_async] fn decode<__CerditoDecoderTypeParam: ::cerdito::Decoder>(
42                decoder: &mut __CerditoDecoderTypeParam
43            ) -> Result<Self, __CerditoDecoderTypeParam::Error> {
44                #body
45            }
46        }
47    };
48    proc_macro::TokenStream::from(expanded)
49}
50
51fn get_fields(fields: &syn::Fields) -> Vec<(usize, proc_macro2::Ident, String, syn::Type)> {
52    match fields {
53        syn::Fields::Named(fields) => fields
54            .named
55            .iter()
56            .enumerate()
57            .map(|(i, f)| {
58                let default_name = format!("field_{}", i);
59                let field_ident = f
60                    .ident
61                    .clone()
62                    .or(Some(proc_macro2::Ident::new(
63                        &default_name,
64                        proc_macro2::Span::call_site(),
65                    )))
66                    .unwrap();
67                let field_name = field_ident.to_string();
68                (i, field_ident, field_name, f.ty.clone())
69            })
70            .collect(),
71        syn::Fields::Unnamed(fields) => fields
72            .unnamed
73            .iter()
74            .enumerate()
75            .map(|(i, f)| {
76                let default_name = format!("field_{}", i);
77                let field_ident = f
78                    .ident
79                    .clone()
80                    .or(Some(proc_macro2::Ident::new(
81                        &default_name,
82                        proc_macro2::Span::call_site(),
83                    )))
84                    .unwrap();
85                let field_name = field_ident.to_string();
86                (i, field_ident, field_name, f.ty.clone())
87            })
88            .collect(),
89        syn::Fields::Unit => vec![],
90    }
91}
92
93fn generate_encode_for_struct(
94    data: &syn::DataStruct,
95    name: &proc_macro2::Ident,
96) -> proc_macro2::TokenStream {
97    let name_str = name.to_string();
98    let fields = get_fields(&data.fields);
99    let field_idents: Vec<_> = fields
100        .iter()
101        .map(|(_, ident, _, _)| ident.clone())
102        .collect();
103    let field_codes: Vec<_> = fields
104        .iter()
105        .map(|(i, field_ident, field_name, _)| {
106            quote! {
107                _await!(encoder.encode_elem_begin(#i, Some(#field_name)))?;
108                _await!(#field_ident.encode(encoder))?;
109                _await!(encoder.encode_elem_end())?;
110            }
111        })
112        .collect();
113    let fields_len = fields.len();
114    match &data.fields {
115        syn::Fields::Named(_) => quote! {
116            _await!(encoder.encode_struct_begin(#fields_len, Some(#name_str)))?;
117            let Self { #(#field_idents),* } = self;
118            #(#field_codes)*
119            _await!(encoder.encode_struct_end())?;
120            Ok(())
121        },
122        syn::Fields::Unnamed(_) => quote! {
123            _await!(encoder.encode_struct_begin(#fields_len, Some(#name_str)))?;
124            let Self( #(#field_idents),* ) = self;
125            #(#field_codes)*
126            _await!(encoder.encode_struct_end())?;
127            Ok(())
128        },
129        syn::Fields::Unit => quote! {
130            _await!(encoder.encode_struct_begin(#fields_len, Some(#name_str)))?;
131            _await!(encoder.encode_struct_end())?;
132            Ok(())
133        },
134    }
135}
136
137fn generate_decode_for_struct(
138    data: &syn::DataStruct,
139    name: &proc_macro2::Ident,
140) -> proc_macro2::TokenStream {
141    let name_str = name.to_string();
142    let fields = get_fields(&data.fields);
143    let field_idents: Vec<_> = fields
144        .iter()
145        .map(|(_, ident, _, _)| ident.clone())
146        .collect();
147    let field_codes: Vec<_> = fields
148        .iter()
149        .map(|(i, field_ident, field_name, field_type)| {
150            quote! {
151                _await!(decoder.decode_elem_begin(#i, Some(#field_name)))?;
152                let #field_ident = if #i < __cerdito_len {
153                    _await!(<#field_type as ::cerdito::Decode>::decode(decoder))?
154                } else { // new program, old data
155                    <#field_type>::default() // TODO: Or fail if Default isn't implemented?
156                };
157                _await!(decoder.decode_elem_end())?;
158            }
159        })
160        .collect();
161    let fields_len = fields.len();
162
163    let compat = quote! {
164        // old program, new data
165        if __cerdito_len > #fields_len {
166            _await!(decoder.decode_skip(__cerdito_len - #fields_len))?;
167        }
168    };
169
170    match &data.fields {
171        syn::Fields::Named(_) => quote! {
172            let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, Some(#name_str)))?;
173            #(#field_codes)*
174            #compat
175            _await!(decoder.decode_struct_end())?;
176            Ok(Self { #(#field_idents),* })
177        },
178        syn::Fields::Unnamed(_) => quote! {
179            let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, Some(#name_str)))?;
180            #(#field_codes)*
181            #compat
182            _await!(decoder.decode_struct_end())?;
183            Ok(Self( #(#field_idents),* ))
184        },
185        syn::Fields::Unit => quote! {
186            let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, Some(#name_str)))?;
187            #compat
188            _await!(decoder.decode_struct_end())?;
189            Ok(Self)
190        },
191    }
192}
193
194fn generate_tags(data: &syn::DataEnum) -> Vec<proc_macro2::TokenStream> {
195    let mut current_expr: Option<proc_macro2::TokenStream> = None;
196    let mut current_incr: u32 = 0;
197    data.variants
198        .iter()
199        .map(|v| match &v.discriminant {
200            Some((_, expr)) => {
201                let e = quote! { #expr };
202                current_expr = Some(e.clone());
203                current_incr = 1;
204                e
205            }
206            None => match &current_expr {
207                Some(expr) => {
208                    let e = quote! { (#expr) + #current_incr };
209                    current_incr += 1;
210                    e
211                }
212                None => {
213                    let e = quote! { #current_incr };
214                    current_incr += 1;
215                    e
216                }
217            },
218        })
219        .collect()
220}
221
222fn generate_encode_for_enum(
223    data: &syn::DataEnum,
224    name: &proc_macro2::Ident,
225) -> proc_macro2::TokenStream {
226    let name_str = name.to_string();
227    let tags = generate_tags(data);
228    let variant_codes: Vec<_> = data.variants.iter().zip(tags).enumerate().map(|(i, (v, t))| {
229        let variant_name = v.ident.clone();
230        let variant_name_str = v.ident.to_string();
231        let fields = get_fields(&v.fields);
232        let field_idents: Vec<_> = fields
233            .iter()
234            .map(|(_, ident, _, _)| ident.clone())
235            .collect();
236        let field_codes: Vec<_> = fields
237            .iter()
238            .map(|(_i, field_ident, field_name, _field_type)| {
239                quote! {
240                    _await!(encoder.encode_elem_begin(#i, Some(#field_name)))?;
241                    _await!(#field_ident.encode(encoder))?;
242                    _await!(encoder.encode_elem_end())?;
243                }
244            })
245            .collect();
246        let fields_len = fields.len();
247        match &v.fields {
248            syn::Fields::Named(_) => quote! {
249                Self::#variant_name { #(#field_idents),* } => {
250                    let __cerdito_enum_tag: u32 = (#t).try_into().unwrap(); //TODO: error
251                    _await!(encoder.encode_enum_begin(__cerdito_enum_tag, 1, #name_str, #variant_name_str))?;
252                    _await!(encoder.encode_struct_begin(#fields_len, None))?;
253                    #(#field_codes)*
254                    _await!(encoder.encode_struct_end())?;
255                    _await!(encoder.encode_enum_end())?;
256                }
257            },
258            syn::Fields::Unnamed(_) => quote! {
259                Self::#variant_name(#(#field_idents),*) => {
260                    let __cerdito_enum_tag: u32 = (#t).try_into().unwrap(); //TODO: error
261                    _await!(encoder.encode_enum_begin(__cerdito_enum_tag, 1, #name_str, #variant_name_str))?;
262                    _await!(encoder.encode_struct_begin(#fields_len, None))?;
263                    #(#field_codes)*
264                    _await!(encoder.encode_struct_end())?;
265                    _await!(encoder.encode_enum_end())?;
266                }
267            },
268            syn::Fields::Unit => quote! {
269                Self::#variant_name => {
270                    let __cerdito_enum_tag: u32 = (#t).try_into().unwrap(); //TODO: error
271                    _await!(encoder.encode_enum_begin(__cerdito_enum_tag, 0, #name_str, #variant_name_str))?;
272                    _await!(encoder.encode_enum_end())?;
273                }
274            },
275        }
276    }).collect();
277
278    quote! {
279        match self {
280            #(#variant_codes)*
281        }
282        Ok(())
283    }
284}
285
286fn generate_decode_for_enum(
287    data: &syn::DataEnum,
288    name: &proc_macro2::Ident,
289) -> proc_macro2::TokenStream {
290    let name_str = name.to_string();
291    let tags = generate_tags(data);
292    let variant_codes: Vec<_> = data
293        .variants
294        .iter()
295        .zip(tags)
296        .enumerate()
297        .map(|(_i, (v, t))| {
298            let variant_name = v.ident.clone();
299            let fields = get_fields(&v.fields);
300            let field_idents: Vec<_> = fields
301                .iter()
302                .map(|(_, ident, _, _)| ident.clone())
303                .collect();
304            let field_codes: Vec<_> = fields
305                .iter()
306                .map(|(i, field_ident, field_name, field_type)| {
307                    quote! {
308                        _await!(decoder.decode_elem_begin(#i, Some(#field_name)))?;
309                        let #field_ident = if #i < __cerdito_len {
310                            _await!(<#field_type as ::cerdito::Decode>::decode(decoder))?
311                        } else { // new program, old data
312                            <#field_type>::default()
313                        };
314                        _await!(decoder.decode_elem_end())?;
315                    }
316                })
317                .collect();
318
319            let field_defaults: Vec<_> = fields
320                .iter()
321                .map(|(_i, field_ident, _field_name, field_type)| {
322                    quote! {
323                        let #field_ident = <#field_type>::default();
324                    }
325                })
326                .collect();
327
328            let fields_len = fields.len();
329
330            let compat = quote! {
331                // old program, new data
332                if __cerdito_len > #fields_len {
333                    _await!(decoder.decode_skip(__cerdito_len - #fields_len))?;
334                }
335            };
336
337            match &v.fields {
338                syn::Fields::Named(_) => quote! {
339                    #t => {
340                        match __cerdito_enum_len {
341                            0 => {
342                                #(#field_defaults)*
343                                Self::#variant_name { #(#field_idents),* }
344                            }
345                            1 => {
346                                let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, None))?;
347                                #(#field_codes)*
348                                #compat
349                                _await!(decoder.decode_struct_end())?;
350                                Self::#variant_name { #(#field_idents),* }
351                            }
352                            _ => unreachable!(),
353                        }
354                    }
355                },
356                syn::Fields::Unnamed(_) => quote! {
357                    #t => {
358                        match __cerdito_enum_len {
359                            0 => {
360                                #(#field_defaults)*
361                                Self::#variant_name(#(#field_idents),*)
362                            }
363                            1 => {
364                                let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, None))?;
365                                #(#field_codes)*
366                                #compat
367                                _await!(decoder.decode_struct_end())?;
368                                Self::#variant_name(#(#field_idents),*)
369                            }
370                            _ => unreachable!(),
371                        }
372                    }
373                },
374                syn::Fields::Unit => quote! {
375                    #t => {
376                        match __cerdito_enum_len {
377                            0 => {
378                                Self::#variant_name
379                            }
380                            1 => {
381                                let __cerdito_len = _await!(decoder.decode_struct_begin(#fields_len, None))?;
382                                #compat
383                                _await!(decoder.decode_struct_end())?;
384                                Self::#variant_name
385                            }
386                            _ => unreachable!(),
387                        }
388                    }
389                },
390            }
391        })
392        .collect();
393
394    quote! {
395        let (__cerdito_enum_tag, __cerdito_enum_len) = _await!(decoder.decode_enum_begin(#name_str))?;
396        let __cerdito_enum_value = match __cerdito_enum_tag.try_into().unwrap() { // TODO: error
397                #(#variant_codes)*
398                _ => panic!("Enum {:?} doesn't support variant {}", #name_str, __cerdito_enum_tag),
399        };
400        _await!(decoder.decode_enum_end())?;
401        Ok(__cerdito_enum_value)
402    }
403}