mkit_derive/
lib.rs

1extern crate proc_macro;
2
3use lazy_static::lazy_static;
4use proc_macro2::TokenStream;
5use proc_macro_error::{abort_call_site, proc_macro_error};
6use quote::quote;
7use syn::{spanned::Spanned, *};
8
9mod ty;
10
11lazy_static! {
12    pub(crate) static ref UNNAMED_FIELDS: Vec<&'static str> =
13        vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
14}
15
16#[proc_macro_derive(Cborize, attributes(cbor))]
17#[proc_macro_error]
18pub fn cborize_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
19    let input: DeriveInput = syn::parse(input).unwrap();
20    let gen = match &input.data {
21        Data::Struct(_) => impl_cborize_struct(&input, false),
22        Data::Enum(_) => impl_cborize_enum(&input, false),
23        Data::Union(_) => abort_call_site!("cannot derive Cborize for union"),
24    };
25    gen.into()
26}
27
28#[proc_macro_derive(LocalCborize, attributes(cbor))]
29#[proc_macro_error]
30pub fn local_cborize_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31    let input: DeriveInput = syn::parse(input).unwrap();
32    let gen = match &input.data {
33        Data::Struct(_) => impl_cborize_struct(&input, true),
34        Data::Enum(_) => impl_cborize_enum(&input, true),
35        Data::Union(_) => {
36            abort_call_site!("cannot derive LocalCborize for union")
37        }
38    };
39    gen.into()
40}
41
42fn impl_cborize_struct(input: &DeriveInput, crate_local: bool) -> TokenStream {
43    let name = &input.ident;
44    let generics = no_default_generics(input);
45
46    let mut ts = TokenStream::new();
47    match &input.data {
48        Data::Struct(ast) => {
49            ts.extend(from_struct_to_cbor(
50                name,
51                &generics,
52                &ast.fields,
53                crate_local,
54            ));
55            ts.extend(from_cbor_to_struct(
56                name,
57                &generics,
58                &ast.fields,
59                crate_local,
60            ));
61            ts
62        }
63        _ => unreachable!(),
64    }
65}
66
67fn from_struct_to_cbor(
68    name: &Ident,
69    generics: &Generics,
70    fields: &Fields,
71    crate_local: bool,
72) -> TokenStream {
73    let id_declr = let_id(name, generics);
74    let croot = get_root_crate(crate_local);
75    let preamble = quote! {
76        let val: #croot::cbor::Cbor = {
77            #id_declr;
78            #croot::cbor::Tag::from_identifier(id).into()
79        };
80        items.push(val);
81    };
82
83    let token_fields = match fields {
84        Fields::Unit => quote! {},
85        Fields::Named(fields) => named_fields_to_cbor(fields, croot.clone()),
86        Fields::Unnamed(_) => {
87            abort_call_site!("unnamed struct not supported for Cborize {}", name)
88        }
89    };
90
91    let mut where_clause = match &generics.where_clause {
92        Some(where_clause) => quote! { #where_clause },
93        None => quote! { where },
94    };
95    for param in generics.params.iter() {
96        let type_var = match param {
97            GenericParam::Type(param) => &param.ident,
98            _ => abort_call_site!("only type parameter are supported"),
99        };
100        where_clause.extend(quote! { #type_var: #croot::cbor::IntoCbor, });
101    }
102
103    quote! {
104        impl#generics #croot::cbor::IntoCbor for #name#generics #where_clause {
105            fn into_cbor(self) -> #croot::Result<#croot::cbor::Cbor> {
106                let value = self;
107                let mut items: Vec<#croot::cbor::Cbor> = Vec::default();
108
109                #preamble
110                #token_fields;
111
112                Ok(items.into_cbor()?)
113            }
114        }
115    }
116}
117
118fn from_cbor_to_struct(
119    name: &Ident,
120    generics: &Generics,
121    fields: &Fields,
122    crate_local: bool,
123) -> TokenStream {
124    let name_lit = name.to_string();
125    let croot = get_root_crate(crate_local);
126    let n_fields = match fields {
127        Fields::Unit => 0,
128        Fields::Named(fields) => fields.named.len(),
129        Fields::Unnamed(_) => {
130            abort_call_site!("unnamed struct not supported for Cborize {}", name)
131        }
132    };
133
134    let id_declr = let_id(name, generics);
135    let preamble = quote! {
136        // validate the cbor msg for this type.
137        if items.len() == 0 {
138            #croot::err_at!(FailConvert, msg: "empty msg for {}", #name_lit)?;
139        }
140        let data_id = items.remove(0);
141        let type_id: #croot::cbor::Cbor = {
142            #id_declr;
143            #croot::cbor::Tag::from_identifier(id).into()
144        };
145        if data_id != type_id {
146            #croot::err_at!(FailConvert, msg: "bad id for {}", #name_lit)?;
147        }
148        if #n_fields != items.len() {
149            #croot::err_at!(FailConvert, msg: "bad arity {} {}", #n_fields, items.len())?;
150        }
151    };
152
153    let token_fields = match fields {
154        Fields::Unit => quote! {},
155        Fields::Named(fields) => {
156            let token_fields = cbor_to_named_fields(fields, croot.clone());
157            quote! { { #token_fields } }
158        }
159        Fields::Unnamed(_) => {
160            abort_call_site!("unnamed struct not supported for Cborize {}", name)
161        }
162    };
163
164    let mut where_clause = match &generics.where_clause {
165        Some(where_clause) => quote! { #where_clause },
166        None => quote! { where },
167    };
168    for param in generics.params.iter() {
169        let type_var = match param {
170            GenericParam::Type(param) => &param.ident,
171            _ => abort_call_site!("only type parameter are supported"),
172        };
173        where_clause.extend(quote! { #type_var: #croot::cbor::FromCbor, });
174    }
175
176    quote! {
177        impl#generics #croot::cbor::FromCbor for #name#generics #where_clause {
178            fn from_cbor(value: #croot::cbor::Cbor) -> #croot::Result<#name#generics> {
179                use #croot::{cbor::IntoCbor, Error};
180
181                let mut items = Vec::<#croot::cbor::Cbor>::from_cbor(value)?;
182
183                #preamble
184
185                Ok(#name #token_fields)
186            }
187        }
188    }
189}
190
191fn impl_cborize_enum(input: &DeriveInput, crate_local: bool) -> TokenStream {
192    let name = &input.ident;
193    let generics = no_default_generics(input);
194
195    let mut ts = TokenStream::new();
196    match &input.data {
197        Data::Enum(ast) => {
198            let variants: Vec<&Variant> = ast.variants.iter().collect();
199            ts.extend(from_enum_to_cbor(name, &generics, &variants, crate_local));
200            ts.extend(from_cbor_to_enum(name, &generics, &variants, crate_local));
201            ts
202        }
203        _ => unreachable!(),
204    }
205}
206
207fn from_enum_to_cbor(
208    name: &Ident,
209    generics: &Generics,
210    variants: &[&Variant],
211    crate_local: bool,
212) -> TokenStream {
213    let id_declr = let_id(name, generics);
214    let croot = get_root_crate(crate_local);
215    let preamble = quote! {
216        let val: #croot::cbor::Cbor = {
217            #id_declr;
218            #croot::cbor::Tag::from_identifier(id).into()
219        };
220        items.push(val);
221    };
222
223    let mut tok_variants: TokenStream = TokenStream::new();
224    for variant in variants.iter() {
225        let variant_name = &variant.ident;
226        let variant_lit = variant.ident.to_string();
227        let arm = match &variant.fields {
228            Fields::Unit => {
229                quote! { #name::#variant_name => #variant_lit.into_cbor()? }
230            }
231            Fields::Named(fields) => {
232                let (params, body) = named_var_fields_to_cbor(fields, croot.clone());
233                quote! {
234                    #name::#variant_name{#params} => {
235                        items.push(#variant_lit.into_cbor()?);
236                        #body
237                    },
238                }
239            }
240            Fields::Unnamed(fields) => {
241                let (params, body) = unnamed_fields_to_cbor(fields, croot.clone());
242                quote! {
243                    #name::#variant_name(#params) => {
244                        items.push(#variant_lit.into_cbor()?);
245                        #body
246                    },
247                }
248            }
249        };
250        tok_variants.extend(arm)
251    }
252
253    let mut where_clause = match &generics.where_clause {
254        Some(where_clause) => quote! { #where_clause },
255        None => quote! { where },
256    };
257    for param in generics.params.iter() {
258        let type_var = match param {
259            GenericParam::Type(param) => &param.ident,
260            _ => abort_call_site!("only type parameter are supported"),
261        };
262        where_clause.extend(quote! { #type_var: #croot::cbor::IntoCbor, });
263    }
264
265    quote! {
266        impl#generics #croot::cbor::IntoCbor for #name#generics #where_clause {
267            fn into_cbor(self) -> #croot::Result<#croot::cbor::Cbor> {
268                let value = self;
269
270                let mut items: Vec<#croot::cbor::Cbor> = Vec::default();
271
272                #preamble
273                match value {
274                    #tok_variants
275                }
276                Ok(items.into_cbor()?)
277            }
278        }
279    }
280}
281
282fn from_cbor_to_enum(
283    name: &Ident,
284    generics: &Generics,
285    variants: &[&Variant],
286    crate_local: bool,
287) -> TokenStream {
288    let name_lit = name.to_string();
289    let id_declr = let_id(name, generics);
290    let croot = get_root_crate(crate_local);
291    let preamble = quote! {
292        // validate the cbor msg for this type.
293        if items.len() < 2 {
294            #croot::err_at!(FailConvert, msg: "empty msg for {}", #name_lit)?;
295        }
296        let data_id = items.remove(0);
297        let type_id: #croot::cbor::Cbor= {
298            #id_declr;
299            #croot::cbor::Tag::from_identifier(id).into()
300        };
301        if data_id != type_id {
302            #croot::err_at!(FailConvert, msg: "bad {}", #name_lit)?
303        }
304
305        let variant_name = String::from_cbor(items.remove(0))?;
306    };
307
308    let mut check_variants: TokenStream = TokenStream::new();
309    for variant in variants.iter() {
310        let variant_lit = &variant.ident.to_string();
311        let arm = match &variant.fields {
312            Fields::Named(fields) => {
313                let n_fields = fields.named.len();
314                quote! {
315                   #variant_lit => {
316                        if #n_fields != items.len() {
317                            #croot::err_at!(
318                                FailConvert, msg: "bad arity {} {}",
319                                #n_fields, items.len()
320                            )?;
321                        }
322                    }
323                }
324            }
325            Fields::Unnamed(fields) => {
326                let n_fields = fields.unnamed.len();
327                quote! {
328                    #variant_lit => {
329                        if #n_fields != items.len() {
330                            #croot::err_at!(
331                                FailConvert, msg: "bad arity {} {}",
332                                #n_fields, items.len()
333                            )?;
334                        }
335                    }
336                }
337            }
338            Fields::Unit => {
339                quote! {
340                    #variant_lit => {
341                        if items.len() > 0 {
342                            #croot::err_at!(
343                                FailConvert, msg: "bad arity {}", items.len()
344                            )?;
345                        }
346                    }
347                }
348            }
349        };
350        check_variants.extend(arm)
351    }
352
353    let mut tok_variants: TokenStream = TokenStream::new();
354    for variant in variants.iter() {
355        let variant_name = &variant.ident;
356        let variant_lit = &variant.ident.to_string();
357        let arm = match &variant.fields {
358            Fields::Unit => quote! {
359                #variant_lit => #name::#variant_name
360            },
361            Fields::Named(fields) => {
362                let (_, body) = cbor_to_named_var_fields(fields, croot.clone());
363                quote! { #variant_lit => #name::#variant_name { #body }, }
364            }
365            Fields::Unnamed(fields) => {
366                let (_, body) = cbor_to_unnamed_fields(fields, croot.clone());
367                quote! { #variant_lit => #name::#variant_name(#body), }
368            }
369        };
370        tok_variants.extend(arm);
371    }
372
373    let mut where_clause = match &generics.where_clause {
374        Some(where_clause) => quote! { #where_clause },
375        None => quote! { where },
376    };
377    for param in generics.params.iter() {
378        let type_var = match param {
379            GenericParam::Type(param) => &param.ident,
380            _ => abort_call_site!("only type parameter are supported"),
381        };
382        where_clause.extend(quote! { #type_var: #croot::cbor::FromCbor, });
383    }
384    quote! {
385        impl#generics #croot::cbor::FromCbor for #name#generics #where_clause {
386            fn from_cbor(value: #croot::cbor::Cbor) -> #croot::Result<#name#generics> {
387                use #croot::{cbor::IntoCbor, Error};
388
389                let mut items =  Vec::<#croot::cbor::Cbor>::from_cbor(value)?;
390
391                #preamble
392
393                match variant_name.as_str() {
394                    #check_variants
395                    _ => #croot::err_at!(
396                        FailConvert, msg: "invalid variant_name {}", variant_name
397                    )?,
398                }
399
400                let val = match variant_name.as_str() {
401                    #tok_variants
402                    _ => #croot::err_at!(
403                        FailConvert, msg: "invalid variant_name {}", variant_name
404                    )?,
405                };
406                Ok(val)
407            }
408        }
409    }
410}
411
412fn named_fields_to_cbor(fields: &FieldsNamed, croot: TokenStream) -> TokenStream {
413    let mut tokens = TokenStream::new();
414    for field in fields.named.iter() {
415        let is_bytes = is_bytes_ty(&field.ty);
416
417        match &field.ident {
418            Some(field_name) if is_bytes => tokens.extend(quote! {
419                items.push(#croot::cbor::Cbor::bytes_into_cbor(value.#field_name)?);
420            }),
421            Some(field_name) => tokens.extend(quote! {
422                items.push(value.#field_name.into_cbor()?);
423            }),
424            None => (),
425        }
426    }
427    tokens
428}
429
430fn named_var_fields_to_cbor(
431    fields: &FieldsNamed,
432    croot: TokenStream,
433) -> (TokenStream, TokenStream) {
434    let mut params = TokenStream::new();
435    let mut body = TokenStream::new();
436    for field in fields.named.iter() {
437        let is_bytes = is_bytes_ty(&field.ty);
438
439        let field_name = field.ident.as_ref().unwrap();
440        params.extend(quote! { #field_name, });
441
442        match &field.ident {
443            Some(field_name) if is_bytes => body.extend(quote! {
444                items.push(#croot::cbor::Cbor::bytes_into_cbor(#field_name)?);
445            }),
446            Some(field_name) => body.extend(quote! {
447                items.push(#field_name.into_cbor()?);
448            }),
449            None => (),
450        }
451    }
452    (params, body)
453}
454
455fn unnamed_fields_to_cbor(
456    fields: &FieldsUnnamed,
457    croot: TokenStream,
458) -> (TokenStream, TokenStream) {
459    let mut params = TokenStream::new();
460    let mut body = TokenStream::new();
461    for (field_name, field) in UNNAMED_FIELDS.iter().zip(fields.unnamed.iter()) {
462        let field_name = Ident::new(field_name, field.span());
463        let is_bytes = is_bytes_ty(&field.ty);
464
465        params.extend(quote! { #field_name, });
466
467        if is_bytes {
468            body.extend(quote! {
469                items.push(#croot::cbor::Cbor::bytes_into_cbor(#field_name)?);
470            });
471        } else {
472            body.extend(quote! {
473                items.push(#field_name.into_cbor()?);
474            });
475        }
476    }
477    (params, body)
478}
479
480fn cbor_to_named_fields(fields: &FieldsNamed, croot: TokenStream) -> TokenStream {
481    let mut tokens = TokenStream::new();
482    for field in fields.named.iter() {
483        let is_bytes = is_bytes_ty(&field.ty);
484
485        let field_name = field.ident.as_ref().unwrap();
486        let ty = &field.ty;
487        let field_tokens = if is_bytes {
488            quote! {
489                #field_name: items.remove(0).into_bytes()?,
490            }
491        } else {
492            quote! {
493                #field_name: <#ty as #croot::cbor::FromCbor>::from_cbor(items.remove(0))?,
494            }
495        };
496        tokens.extend(field_tokens);
497    }
498    tokens
499}
500
501fn cbor_to_named_var_fields(
502    fields: &FieldsNamed,
503    croot: TokenStream,
504) -> (TokenStream, TokenStream) {
505    let mut params = TokenStream::new();
506    let mut body = TokenStream::new();
507    for field in fields.named.iter() {
508        let is_bytes = is_bytes_ty(&field.ty);
509
510        let field_name = field.ident.as_ref().unwrap();
511        params.extend(quote! { #field_name, });
512
513        let ty = &field.ty;
514        if is_bytes {
515            body.extend(quote! {
516                #field_name: items.remove(0).into_bytes()?,
517            });
518        } else {
519            body.extend(quote! {
520                #field_name: <#ty as #croot::cbor::FromCbor>::from_cbor(items.remove(0))?,
521            });
522        }
523    }
524    (params, body)
525}
526
527fn cbor_to_unnamed_fields(
528    fields: &FieldsUnnamed,
529    croot: TokenStream,
530) -> (TokenStream, TokenStream) {
531    let mut params = TokenStream::new();
532    let mut body = TokenStream::new();
533    for (field_name, field) in UNNAMED_FIELDS.iter().zip(fields.unnamed.iter()) {
534        let field_name = Ident::new(field_name, field.span());
535        let is_bytes = is_bytes_ty(&field.ty);
536
537        params.extend(quote! { #field_name, });
538
539        let ty = &field.ty;
540        if is_bytes {
541            body.extend(quote! { items.remove(0).into_bytes()?, });
542        } else {
543            body.extend(
544                quote! { <#ty as #croot::cbor::FromCbor>::from_cbor(items.remove(0))?, },
545            );
546        }
547    }
548    (params, body)
549}
550
551fn let_id(name: &Ident, generics: &Generics) -> TokenStream {
552    if generics.params.is_empty() {
553        quote! { let id = #name::ID.into_cbor()? }
554    } else {
555        quote! { let id = #name::#generics::ID.into_cbor()? }
556    }
557}
558
559fn get_root_crate(crate_local: bool) -> TokenStream {
560    if crate_local {
561        quote! { crate }
562    } else {
563        quote! { ::mkit }
564    }
565}
566
567fn no_default_generics(input: &DeriveInput) -> Generics {
568    let mut generics = input.generics.clone();
569    generics.params.iter_mut().for_each(|param| match param {
570        GenericParam::Type(param) => {
571            param.eq_token = None;
572            param.default = None;
573        }
574        _ => (),
575    });
576    generics
577}
578
579fn is_bytes_ty(ty: &syn::Type) -> bool {
580    match ty::subty_of_vec(ty) {
581        Some(subty) => ty::ty_u8(subty),
582        None => false,
583    }
584}