cbordata_derive/
lib.rs

1extern crate lazy_static;
2extern crate proc_macro2;
3extern crate proc_macro_error;
4extern crate quote;
5extern crate syn;
6
7use lazy_static::lazy_static;
8use proc_macro2::TokenStream;
9use proc_macro_error::{abort_call_site, proc_macro_error};
10use quote::quote;
11use syn::{spanned::Spanned, *};
12
13mod ty;
14
15lazy_static! {
16    pub(crate) static ref UNNAMED_FIELDS: Vec<&'static str> =
17        vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
18}
19
20#[proc_macro_derive(Cborize, attributes(cbor))]
21#[proc_macro_error]
22pub fn cborize_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
23    let input: DeriveInput = syn::parse(input).unwrap();
24    let gen = match &input.data {
25        Data::Struct(_) => impl_cborize_struct(&input, false),
26        Data::Enum(_) => impl_cborize_enum(&input, false),
27        Data::Union(_) => abort_call_site!("cannot derive Cborize for union"),
28    };
29    gen.into()
30}
31
32#[proc_macro_derive(LocalCborize, attributes(cbor))]
33#[proc_macro_error]
34pub fn local_cborize_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
35    let input: DeriveInput = syn::parse(input).unwrap();
36    let gen = match &input.data {
37        Data::Struct(_) => impl_cborize_struct(&input, true),
38        Data::Enum(_) => impl_cborize_enum(&input, true),
39        Data::Union(_) => {
40            abort_call_site!("cannot derive LocalCborize for union")
41        }
42    };
43    gen.into()
44}
45
46fn impl_cborize_struct(input: &DeriveInput, crate_local: bool) -> TokenStream {
47    let name = &input.ident;
48    let generics = no_default_generics(input);
49
50    let mut ts = TokenStream::new();
51    match &input.data {
52        Data::Struct(ast) => {
53            ts.extend(from_struct_to_cbor(
54                name,
55                &generics,
56                &ast.fields,
57                crate_local,
58            ));
59            ts.extend(from_cbor_to_struct(
60                name,
61                &generics,
62                &ast.fields,
63                crate_local,
64            ));
65            ts
66        }
67        _ => unreachable!(),
68    }
69}
70
71fn from_struct_to_cbor(
72    name: &Ident,
73    generics: &Generics,
74    fields: &Fields,
75    crate_local: bool,
76) -> TokenStream {
77    let id_declr = let_id(name, generics);
78    let croot = get_root_crate(crate_local);
79    let preamble = quote! {
80        let val: #croot::Cbor = {
81            #id_declr;
82            #croot::Tag::from_identifier(id).into()
83        };
84        items.push(val);
85    };
86
87    let token_fields = match fields {
88        Fields::Unit => quote! {},
89        Fields::Named(fields) => named_fields_to_cbor(fields, croot.clone()),
90        Fields::Unnamed(_) => {
91            abort_call_site!("unnamed struct not supported for Cborize {}", name)
92        }
93    };
94
95    let mut where_clause = match &generics.where_clause {
96        Some(where_clause) => quote! { #where_clause },
97        None => quote! { where },
98    };
99    for param in generics.params.iter() {
100        let type_var = match param {
101            GenericParam::Type(param) => &param.ident,
102            _ => abort_call_site!("only type parameter are supported"),
103        };
104        where_clause.extend(quote! { #type_var: #croot::IntoCbor, });
105    }
106
107    quote! {
108        impl #generics #croot::IntoCbor for #name #generics #where_clause {
109            fn into_cbor(self) -> #croot::Result<#croot::Cbor> {
110                let value = self;
111                let mut items: Vec<#croot::Cbor> = Vec::default();
112
113                #preamble
114                #token_fields;
115
116                items.into_cbor()
117            }
118        }
119    }
120}
121
122fn from_cbor_to_struct(
123    name: &Ident,
124    generics: &Generics,
125    fields: &Fields,
126    crate_local: bool,
127) -> TokenStream {
128    let name_lit = name.to_string();
129    let croot = get_root_crate(crate_local);
130    let n_fields = match fields {
131        Fields::Unit => 0,
132        Fields::Named(fields) => fields.named.len(),
133        Fields::Unnamed(_) => {
134            abort_call_site!("unnamed struct not supported for Cborize {}", name)
135        }
136    };
137
138    let id_declr = let_id(name, generics);
139    let preamble = quote! {
140        // validate the cbor msg for this type.
141        if items.len() == 0 {
142            #croot::err_at!(FailConvert, msg: "empty msg for {}", #name_lit)?;
143        }
144        let data_id = items.remove(0);
145        let type_id: #croot::Cbor = {
146            #id_declr;
147            #croot::Tag::from_identifier(id).into()
148        };
149        if data_id != type_id {
150            #croot::err_at!(FailConvert, msg: "bad id for {}", #name_lit)?;
151        }
152        if #n_fields != items.len() {
153            #croot::err_at!(FailConvert, msg: "bad arity {} {}", #n_fields, items.len())?;
154        }
155    };
156
157    let token_fields = match fields {
158        Fields::Unit => quote! {},
159        Fields::Named(fields) => {
160            let token_fields = cbor_to_named_fields(fields, croot.clone());
161            quote! { { #token_fields } }
162        }
163        Fields::Unnamed(_) => {
164            abort_call_site!("unnamed struct not supported for Cborize {}", name)
165        }
166    };
167
168    let mut where_clause = match &generics.where_clause {
169        Some(where_clause) => quote! { #where_clause },
170        None => quote! { where },
171    };
172    for param in generics.params.iter() {
173        let type_var = match param {
174            GenericParam::Type(param) => &param.ident,
175            _ => abort_call_site!("only type parameter are supported"),
176        };
177        where_clause.extend(quote! { #type_var: #croot::FromCbor, });
178    }
179
180    quote! {
181        impl #generics #croot::FromCbor for #name #generics #where_clause {
182            fn from_cbor(value: #croot::Cbor) -> #croot::Result<#name #generics> {
183                use #croot::{IntoCbor, Error};
184
185                let mut items = Vec::<#croot::Cbor>::from_cbor(value)?;
186
187                #preamble
188
189                Ok(#name #token_fields)
190            }
191        }
192    }
193}
194
195fn impl_cborize_enum(input: &DeriveInput, crate_local: bool) -> TokenStream {
196    let name = &input.ident;
197    let generics = no_default_generics(input);
198
199    let mut ts = TokenStream::new();
200    match &input.data {
201        Data::Enum(ast) => {
202            let variants: Vec<&Variant> = ast.variants.iter().collect();
203            ts.extend(from_enum_to_cbor(name, &generics, &variants, crate_local));
204            ts.extend(from_cbor_to_enum(name, &generics, &variants, crate_local));
205            ts
206        }
207        _ => unreachable!(),
208    }
209}
210
211fn from_enum_to_cbor(
212    name: &Ident,
213    generics: &Generics,
214    variants: &[&Variant],
215    crate_local: bool,
216) -> TokenStream {
217    let id_declr = let_id(name, generics);
218    let croot = get_root_crate(crate_local);
219    let preamble = quote! {
220        let val: #croot::Cbor = {
221            #id_declr;
222            #croot::Tag::from_identifier(id).into()
223        };
224        items.push(val);
225    };
226
227    let mut tok_variants: TokenStream = TokenStream::new();
228    for variant in variants.iter() {
229        let variant_name = &variant.ident;
230        let variant_lit = variant.ident.to_string();
231        let arm = match &variant.fields {
232            Fields::Unit => {
233                quote! { #name::#variant_name => #variant_lit.into_cbor()? }
234            }
235            Fields::Named(fields) => {
236                let (params, body) = named_var_fields_to_cbor(fields, croot.clone());
237                quote! {
238                    #name::#variant_name{#params} => {
239                        items.push(#variant_lit.into_cbor()?);
240                        #body
241                    },
242                }
243            }
244            Fields::Unnamed(fields) => {
245                let (params, body) = unnamed_fields_to_cbor(fields, croot.clone());
246                quote! {
247                    #name::#variant_name(#params) => {
248                        items.push(#variant_lit.into_cbor()?);
249                        #body
250                    },
251                }
252            }
253        };
254        tok_variants.extend(arm)
255    }
256
257    let mut where_clause = match &generics.where_clause {
258        Some(where_clause) => quote! { #where_clause },
259        None => quote! { where },
260    };
261    for param in generics.params.iter() {
262        let type_var = match param {
263            GenericParam::Type(param) => &param.ident,
264            _ => abort_call_site!("only type parameter are supported"),
265        };
266        where_clause.extend(quote! { #type_var: #croot::IntoCbor, });
267    }
268
269    quote! {
270        impl #generics #croot::IntoCbor for #name #generics #where_clause {
271            fn into_cbor(self) -> #croot::Result<#croot::Cbor> {
272                let value = self;
273
274                let mut items: Vec<#croot::Cbor> = Vec::default();
275
276                #preamble
277                match value {
278                    #tok_variants
279                }
280                items.into_cbor()
281            }
282        }
283    }
284}
285
286fn from_cbor_to_enum(
287    name: &Ident,
288    generics: &Generics,
289    variants: &[&Variant],
290    crate_local: bool,
291) -> TokenStream {
292    let name_lit = name.to_string();
293    let id_declr = let_id(name, generics);
294    let croot = get_root_crate(crate_local);
295    let preamble = quote! {
296        // validate the cbor msg for this type.
297        if items.len() < 2 {
298            #croot::err_at!(FailConvert, msg: "empty msg for {}", #name_lit)?;
299        }
300        let data_id = items.remove(0);
301        let type_id: #croot::Cbor= {
302            #id_declr;
303            #croot::Tag::from_identifier(id).into()
304        };
305        if data_id != type_id {
306            #croot::err_at!(FailConvert, msg: "bad {}", #name_lit)?
307        }
308
309        let variant_name = String::from_cbor(items.remove(0))?;
310    };
311
312    let mut check_variants: TokenStream = TokenStream::new();
313    for variant in variants.iter() {
314        let variant_lit = &variant.ident.to_string();
315        let arm = match &variant.fields {
316            Fields::Named(fields) => {
317                let n_fields = fields.named.len();
318                quote! {
319                   #variant_lit => {
320                        if #n_fields != items.len() {
321                            #croot::err_at!(
322                                FailConvert, msg: "bad arity {} {}",
323                                #n_fields, items.len()
324                            )?;
325                        }
326                    }
327                }
328            }
329            Fields::Unnamed(fields) => {
330                let n_fields = fields.unnamed.len();
331                quote! {
332                    #variant_lit => {
333                        if #n_fields != items.len() {
334                            #croot::err_at!(
335                                FailConvert, msg: "bad arity {} {}",
336                                #n_fields, items.len()
337                            )?;
338                        }
339                    }
340                }
341            }
342            Fields::Unit => {
343                quote! {
344                    #variant_lit => {
345                        if items.len() > 0 {
346                            #croot::err_at!(
347                                FailConvert, msg: "bad arity {}", items.len()
348                            )?;
349                        }
350                    }
351                }
352            }
353        };
354        check_variants.extend(arm)
355    }
356
357    let mut tok_variants: TokenStream = TokenStream::new();
358    for variant in variants.iter() {
359        let variant_name = &variant.ident;
360        let variant_lit = &variant.ident.to_string();
361        let arm = match &variant.fields {
362            Fields::Unit => quote! {
363                #variant_lit => #name::#variant_name
364            },
365            Fields::Named(fields) => {
366                let (_, body) = cbor_to_named_var_fields(fields, croot.clone());
367                quote! { #variant_lit => #name::#variant_name { #body }, }
368            }
369            Fields::Unnamed(fields) => {
370                let (_, body) = cbor_to_unnamed_fields(fields, croot.clone());
371                quote! { #variant_lit => #name::#variant_name(#body), }
372            }
373        };
374        tok_variants.extend(arm);
375    }
376
377    let mut where_clause = match &generics.where_clause {
378        Some(where_clause) => quote! { #where_clause },
379        None => quote! { where },
380    };
381    for param in generics.params.iter() {
382        let type_var = match param {
383            GenericParam::Type(param) => &param.ident,
384            _ => abort_call_site!("only type parameter are supported"),
385        };
386        where_clause.extend(quote! { #type_var: #croot::FromCbor, });
387    }
388    quote! {
389        impl #generics #croot::FromCbor for #name #generics #where_clause {
390            fn from_cbor(value: #croot::Cbor) -> #croot::Result<#name #generics> {
391                use #croot::{IntoCbor, Error};
392
393                let mut items =  Vec::<#croot::Cbor>::from_cbor(value)?;
394
395                #preamble
396
397                match variant_name.as_str() {
398                    #check_variants
399                    _ => #croot::err_at!(
400                        FailConvert, msg: "invalid variant_name {}", variant_name
401                    )?,
402                }
403
404                let val = match variant_name.as_str() {
405                    #tok_variants
406                    _ => #croot::err_at!(
407                        FailConvert, msg: "invalid variant_name {}", variant_name
408                    )?,
409                };
410                Ok(val)
411            }
412        }
413    }
414}
415
416fn named_fields_to_cbor(fields: &FieldsNamed, croot: TokenStream) -> TokenStream {
417    let mut tokens = TokenStream::new();
418    for field in fields.named.iter() {
419        let is_bytes = is_bytes_ty(&field.ty);
420
421        match &field.ident {
422            Some(field_name) if is_bytes => tokens.extend(quote! {
423                items.push(#croot::Cbor::from_bytes(value.#field_name)?);
424            }),
425            Some(field_name) => tokens.extend(quote! {
426                items.push(value.#field_name.into_cbor()?);
427            }),
428            None => (),
429        }
430    }
431    tokens
432}
433
434fn named_var_fields_to_cbor(
435    fields: &FieldsNamed,
436    croot: TokenStream,
437) -> (TokenStream, TokenStream) {
438    let mut params = TokenStream::new();
439    let mut body = TokenStream::new();
440    for field in fields.named.iter() {
441        let is_bytes = is_bytes_ty(&field.ty);
442
443        let field_name = field.ident.as_ref().unwrap();
444        params.extend(quote! { #field_name, });
445
446        match &field.ident {
447            Some(field_name) if is_bytes => body.extend(quote! {
448                items.push(#croot::Cbor::from_bytes(#field_name)?);
449            }),
450            Some(field_name) => body.extend(quote! {
451                items.push(#field_name.into_cbor()?);
452            }),
453            None => (),
454        }
455    }
456    (params, body)
457}
458
459fn unnamed_fields_to_cbor(
460    fields: &FieldsUnnamed,
461    croot: TokenStream,
462) -> (TokenStream, TokenStream) {
463    let mut params = TokenStream::new();
464    let mut body = TokenStream::new();
465    for (field_name, field) in UNNAMED_FIELDS.iter().zip(fields.unnamed.iter()) {
466        let field_name = Ident::new(field_name, field.span());
467        let is_bytes = is_bytes_ty(&field.ty);
468
469        params.extend(quote! { #field_name, });
470
471        if is_bytes {
472            body.extend(quote! {
473                items.push(#croot::Cbor::from_bytes(#field_name)?);
474            });
475        } else {
476            body.extend(quote! {
477                items.push(#field_name.into_cbor()?);
478            });
479        }
480    }
481    (params, body)
482}
483
484fn cbor_to_named_fields(fields: &FieldsNamed, croot: TokenStream) -> TokenStream {
485    let mut tokens = TokenStream::new();
486    for field in fields.named.iter() {
487        let is_bytes = is_bytes_ty(&field.ty);
488
489        let field_name = field.ident.as_ref().unwrap();
490        let ty = &field.ty;
491        let field_tokens = if is_bytes {
492            quote! {
493                #field_name: items.remove(0).into_bytes()?,
494            }
495        } else {
496            quote! {
497                #field_name: <#ty as #croot::FromCbor>::from_cbor(items.remove(0))?,
498            }
499        };
500        tokens.extend(field_tokens);
501    }
502    tokens
503}
504
505fn cbor_to_named_var_fields(
506    fields: &FieldsNamed,
507    croot: TokenStream,
508) -> (TokenStream, TokenStream) {
509    let mut params = TokenStream::new();
510    let mut body = TokenStream::new();
511    for field in fields.named.iter() {
512        let is_bytes = is_bytes_ty(&field.ty);
513
514        let field_name = field.ident.as_ref().unwrap();
515        params.extend(quote! { #field_name, });
516
517        let ty = &field.ty;
518        if is_bytes {
519            body.extend(quote! {
520                #field_name: items.remove(0).into_bytes()?,
521            });
522        } else {
523            body.extend(quote! {
524                #field_name: <#ty as #croot::FromCbor>::from_cbor(items.remove(0))?,
525            });
526        }
527    }
528    (params, body)
529}
530
531fn cbor_to_unnamed_fields(
532    fields: &FieldsUnnamed,
533    croot: TokenStream,
534) -> (TokenStream, TokenStream) {
535    let mut params = TokenStream::new();
536    let mut body = TokenStream::new();
537    for (field_name, field) in UNNAMED_FIELDS.iter().zip(fields.unnamed.iter()) {
538        let field_name = Ident::new(field_name, field.span());
539        let is_bytes = is_bytes_ty(&field.ty);
540
541        params.extend(quote! { #field_name, });
542
543        let ty = &field.ty;
544        if is_bytes {
545            body.extend(quote! { items.remove(0).into_bytes()?, });
546        } else {
547            body.extend(
548                quote! { <#ty as #croot::FromCbor>::from_cbor(items.remove(0))?, },
549            );
550        }
551    }
552    (params, body)
553}
554
555fn let_id(name: &Ident, generics: &Generics) -> TokenStream {
556    if generics.params.is_empty() {
557        quote! { let id = #name::ID.into_cbor()? }
558    } else {
559        quote! { let id = #name::#generics::ID.into_cbor()? }
560    }
561}
562
563fn get_root_crate(crate_local: bool) -> TokenStream {
564    if crate_local {
565        quote! { crate }
566    } else {
567        quote! { ::cbordata }
568    }
569}
570
571fn no_default_generics(input: &DeriveInput) -> Generics {
572    let mut generics = input.generics.clone();
573    generics.params.iter_mut().for_each(|param| {
574        if let GenericParam::Type(param) = param {
575            param.eq_token = None;
576            param.default = None;
577        }
578    });
579    generics
580}
581
582fn is_bytes_ty(ty: &syn::Type) -> bool {
583    match ty::subty_of_vec(ty) {
584        Some(subty) => ty::ty_u8(subty),
585        None => false,
586    }
587}