generic_bytes_derive/
lib.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6//! # Derive macros for opaque-ke
7
8use proc_macro2::{Span, TokenStream};
9use quote::{quote, quote_spanned};
10use syn::{
11    parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam, Generics, Index,
12};
13
14//////////////////////////
15// TryFromForSizedBytes //
16//////////////////////////
17
18/// Derive TryFrom<&[u8], Error = ErrorType> for any T: SizedBytes, assuming
19/// ErrorType: Default. This proc-macro is here to work around the lack of
20/// specialization, but there's nothing otherwise clever about it.
21#[proc_macro_derive(TryFromForSizedBytes, attributes(ErrorType))]
22pub fn try_from_for_sized_bytes(source: proc_macro::TokenStream) -> proc_macro::TokenStream {
23    let ast: DeriveInput = syn::parse(source).expect("Incorrect macro input");
24    let name = &ast.ident;
25
26    let error_type = get_type_from_attrs(&ast.attrs, "ErrorType").unwrap();
27
28    let generics = add_basic_bound(ast.generics);
29    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
30
31    let gen = quote! {
32        impl #impl_generics ::std::convert::TryFrom<&[u8]> for #name #ty_generics #where_clause {
33            type Error = #error_type;
34
35            fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
36                let expected_len = <<Self as ::generic_bytes::SizedBytes>::Len as generic_array::typenum::Unsigned>::to_usize();
37                if bytes.len() != expected_len {
38                    return Err(#error_type::default());
39                }
40                let arr = GenericArray::from_slice(bytes);
41                <Self as ::generic_bytes::SizedBytes>::from_arr(arr).map_err(|_| #error_type::default())
42            }
43        }
44    };
45    gen.into()
46}
47
48fn get_type_from_attrs(attrs: &[syn::Attribute], attr_name: &str) -> syn::Result<syn::Type> {
49    attrs
50        .iter()
51        .find(|attr| attr.path.is_ident(attr_name))
52        .map_or_else(
53            || {
54                Err(syn::Error::new(
55                    proc_macro2::Span::call_site(),
56                    format!("Could not find attribute {}", attr_name),
57                ))
58            },
59            |attr| match attr.parse_meta()? {
60                syn::Meta::NameValue(meta) => {
61                    if let syn::Lit::Str(lit) = &meta.lit {
62                        Ok(lit.clone())
63                    } else {
64                        Err(syn::Error::new_spanned(
65                            meta,
66                            &format!("Could not parse {} attribute", attr_name)[..],
67                        ))
68                    }
69                }
70                bad => Err(syn::Error::new_spanned(
71                    bad,
72                    &format!("Could not parse {} attribute", attr_name)[..],
73                )),
74            },
75        )
76        .and_then(|str| str.parse())
77}
78
79// add `T: SizedBytes` to each generic parameter
80fn add_basic_bound(mut generics: Generics) -> Generics {
81    for param in &mut generics.params {
82        if let GenericParam::Type(ref mut type_param) = *param {
83            type_param
84                .bounds
85                .push(parse_quote!(::generic_bytes::SizedBytes));
86        }
87    }
88    generics
89}
90
91////////////////
92// SizedBytes //
93////////////////
94
95// add where cause which reflects the bound propagation for generic SizedBytes clauses
96fn add_trait_bounds(
97    generics: &mut Generics,
98    data: &syn::Data,
99    bound: syn::Path,
100) -> Result<(), syn::Error> {
101    if generics.params.is_empty() {
102        return Ok(());
103    }
104
105    let types = collect_types(data)?;
106    if !types.is_empty() {
107        let where_clause = generics.make_where_clause();
108
109        types
110            .into_iter()
111            .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #bound)));
112        bounds_sum(data, where_clause)?;
113    }
114
115    Ok(())
116}
117
118fn collect_types(data: &syn::Data) -> Result<Vec<syn::Type>, syn::Error> {
119    use syn::*;
120
121    let types = match *data {
122        Data::Struct(ref data) => match &data.fields {
123            Fields::Named(FieldsNamed { named: fields, .. })
124            | Fields::Unnamed(FieldsUnnamed {
125                unnamed: fields, ..
126            }) => fields.iter().map(|f| f.ty.clone()).collect(),
127
128            Fields::Unit => Vec::new(),
129        },
130
131        Data::Enum(ref data) => data
132            .variants
133            .iter()
134            .flat_map(|variant| match &variant.fields {
135                Fields::Named(FieldsNamed { named: fields, .. })
136                | Fields::Unnamed(FieldsUnnamed {
137                    unnamed: fields, ..
138                }) => fields.iter().map(|f| f.ty.clone()).collect(),
139
140                Fields::Unit => Vec::new(),
141            })
142            .collect(),
143
144        Data::Union(_) => {
145            return Err(Error::new(
146                Span::call_site(),
147                "Union types are not supported.",
148            ))
149        }
150    };
151
152    Ok(types)
153}
154
155fn extract_size_type_from_generic_array(ty: &syn::Type) -> Option<&syn::Type> {
156    fn path_is_generic_array(path: &syn::Path) -> Option<&syn::GenericArgument> {
157        path.segments.iter().find_map(|pt| {
158            if pt.ident == "GenericArray" {
159                // It should have only on angle-bracketed param ("<Foo, Bar>"):
160                match &pt.arguments {
161                    syn::PathArguments::AngleBracketed(params) if params.args.len() == 2 => {
162                        params.args.last()
163                    }
164                    _ => None,
165                }
166            } else {
167                None
168            }
169        })
170    }
171
172    match ty {
173        syn::Type::Path(typepath)
174            if typepath.qself.is_none()
175                && typepath
176                    .path
177                    .segments
178                    .iter()
179                    .any(|pt| pt.ident == "GenericArray") =>
180        {
181            // Get the second parameter of the GenericArray
182            let type_param = path_is_generic_array(&typepath.path);
183            // This argument must be a type:
184            if let Some(syn::GenericArgument::Type(ty)) = type_param {
185                Some(ty)
186            } else {
187                None
188            }
189        }
190        _ => None,
191    }
192}
193
194fn bounds_sum(data: &Data, where_clause: &mut syn::WhereClause) -> Result<(), syn::Error> {
195    match *data {
196        Data::Struct(ref data) => {
197            match data.fields {
198                Fields::Named(ref fields) => {
199                    let mut quote = None;
200                    for f in fields.named.iter() {
201                        let ty = &f.ty;
202                        let res =
203                            if let Some(unsigned_ty) = extract_size_type_from_generic_array(ty) {
204                                quote_spanned! {f.span()=>
205                                                #unsigned_ty
206                                }
207                            } else {
208                                quote_spanned! {f.span()=>
209                                                <#ty as ::generic_bytes::SizedBytes>::Len
210                                }
211                            };
212                        if let Some(ih) = quote {
213                            quote = Some(quote! {
214                                ::generic_array::typenum::Sum<#ih, #res>
215                            });
216                            where_clause
217                                .predicates
218                                .push(parse_quote!(#ih: ::core::ops::Add<#res>));
219                            where_clause
220                                .predicates
221                                .push(parse_quote!(::generic_array::typenum::Sum<#ih, #res> : ::generic_array::ArrayLength<u8> + ::core::ops::Sub<#ih, Output = #res>));
222                            where_clause
223                                .predicates
224                                .push(parse_quote!(::generic_array::typenum::Diff<::generic_array::typenum::Sum<#ih, #res>, #ih> : ::generic_array::ArrayLength<u8>));
225                        } else {
226                            quote = Some(res);
227                        }
228                    }
229                    Ok(())
230                }
231                Fields::Unnamed(ref fields) => {
232                    let mut quote = None;
233                    for f in fields.unnamed.iter() {
234                        let ty = &f.ty;
235                        let res =
236                            if let Some(unsigned_ty) = extract_size_type_from_generic_array(ty) {
237                                quote_spanned! {f.span()=>
238                                                #unsigned_ty
239                                }
240                            } else {
241                                quote_spanned! {f.span()=>
242                                                <#ty as ::generic_bytes::SizedBytes>::Len
243                                }
244                            };
245                        if let Some(ih) = quote {
246                            quote = Some(quote! {
247                                ::generic_array::typenum::Sum<#ih, #res>
248                            });
249                            where_clause
250                                .predicates
251                                .push(parse_quote!(#ih : ::core::ops::Add<#res>));
252                            where_clause
253                                .predicates
254                                .push(parse_quote!(::generic_array::typenum::Sum<#ih, #res> : ::generic_array::ArrayLength<u8> + ::core::ops::Sub<#ih, Output = #res>));
255                            where_clause
256                                .predicates
257                                .push(parse_quote!(::generic_array::typenum::Diff<::generic_array::typenum::Sum<#ih, #res>, #ih> : ::generic_array::ArrayLength<u8>));
258                        } else {
259                            quote = Some(res);
260                        }
261                    }
262                    Ok(())
263                }
264                Fields::Unit => {
265                    // Unit structs cannot own more than 0 bytes of heap memory.
266                    unimplemented!()
267                }
268            }
269        }
270        Data::Enum(_) | Data::Union(_) => unimplemented!(),
271    }
272}
273
274// create a type expression summing up the ::Len of each field
275fn sum(data: &Data) -> TokenStream {
276    match *data {
277        Data::Struct(ref data) => {
278            match data.fields {
279                Fields::Named(ref fields) => {
280                    let mut quote = None;
281                    for f in fields.named.iter() {
282                        let ty = &f.ty;
283                        let res = quote_spanned! {f.span()=>
284                            <#ty as ::generic_bytes::SizedBytes>::Len
285                        };
286                        if let Some(ih) = quote {
287                            quote = Some(quote! {
288                                ::generic_array::typenum::Sum<#ih, #res>
289                            });
290                        } else {
291                            quote = Some(res);
292                        }
293                    }
294                    quote! {
295                        #quote
296                    }
297                }
298                Fields::Unnamed(ref fields) => {
299                    let mut quote = None;
300                    for f in fields.unnamed.iter() {
301                        let ty = &f.ty;
302                        let res = quote_spanned! {f.span()=>
303                            <#ty as ::generic_bytes::SizedBytes>::Len
304                        };
305                        if let Some(ih) = quote {
306                            quote = Some(quote! {
307                                ::generic_array::typenum::Sum<#ih, #res>
308                            });
309                        } else {
310                            quote = Some(res);
311                        }
312                    }
313                    quote! {
314                        #quote
315                    }
316                }
317                Fields::Unit => {
318                    // Unit structs cannot own more than 0 bytes of heap memory.
319                    unimplemented!()
320                }
321            }
322        }
323        Data::Enum(_) | Data::Union(_) => unimplemented!(),
324    }
325}
326
327// Generate an expression to concatenate the to_arr of each field
328fn byte_concatenation(data: &Data) -> TokenStream {
329    match *data {
330        Data::Struct(ref data) => {
331            match data.fields {
332                Fields::Named(ref fields) => {
333                    let mut quote = None;
334                    for f in fields.named.iter() {
335                        let name = &f.ident;
336                        let res = quote_spanned! {f.span()=>
337                            ::generic_bytes::SizedBytes::to_arr(&self.#name)
338                        };
339                        if let Some(ih) = quote {
340                            quote = Some(quote! {
341                                ::generic_array::sequence::Concat::concat(#ih, #res)
342                            });
343                        } else {
344                            quote = Some(res);
345                        }
346                    }
347                    quote! {
348                        #quote
349                    }
350                }
351                Fields::Unnamed(ref fields) => {
352                    let mut quote = None;
353                    for (i, f) in fields.unnamed.iter().enumerate() {
354                        let index = Index::from(i);
355                        let res = quote_spanned! {f.span()=>
356                            ::generic_bytes::SizedBytes::to_arr(&self.#index)
357                        };
358                        if let Some(ih) = quote {
359                            quote = Some(quote! {
360                                ::generic_array::sequence::Concat::concat(#ih, #res)
361                            });
362                        } else {
363                            quote = Some(res);
364                        }
365                    }
366                    quote! {
367                        #quote
368                    }
369                }
370                Fields::Unit => {
371                    // Unit structs cannot own more than 0 bytes of heap memory.
372                    quote!(0)
373                }
374            }
375        }
376        Data::Enum(_) | Data::Union(_) => unimplemented!(),
377    }
378}
379
380// Generate an expression to concatenate the to_arr of each field
381fn byte_splitting(constr: &proc_macro2::Ident, data: &Data) -> TokenStream {
382    match *data {
383        Data::Struct(ref data) => {
384            match data.fields {
385                Fields::Named(ref fields) => {
386                    let l = fields.named.len();
387                    let setup: TokenStream = fields
388                        .named
389                        .iter().enumerate()
390                        .map(|(i, f)| {
391                            let name = &f.ident;
392                            let ty = &f.ty;
393
394                            if i < (l-1) {
395                                quote_spanned! {f.span()=>
396                                    let (head, _tail): (&GenericArray<u8, <#ty as ::generic_bytes::SizedBytes>::Len>, &GenericArray<u8, _>) =
397                                                generic_array::sequence::Split::split(_tail);
398                                    let #name: #ty = ::generic_bytes::SizedBytes::from_arr(head)?;
399                                }
400                            } else {
401                                quote_spanned!{f.span() =>
402                                    let #name: #ty = ::generic_bytes::SizedBytes::from_arr(_tail)?;
403                                }
404                            }
405                        })
406                        .collect();
407
408                    let conclude: TokenStream = fields
409                        .named
410                        .iter()
411                        .map(|f| {
412                            let name = &f.ident;
413                            quote_spanned! {f.span()=>
414                                #name,
415                            }
416                        })
417                        .collect();
418                    quote! {
419                        let _tail = arr;
420                        #setup
421                        Ok(#constr {
422                            #conclude
423                        })
424                    }
425                }
426                Fields::Unnamed(ref fields) => {
427                    let l = fields.unnamed.len();
428                    let setup: TokenStream = fields
429                        .unnamed
430                        .iter()
431                        .enumerate()
432                        .map(|(i, f)| {
433                            let ty = &f.ty;
434                            if i < (l-1) {
435                                let field_name = format!("f_{}", i);
436                                let fname = syn::Ident::new(&field_name, f.span());
437                                quote_spanned! {f.span()=>
438                                                let (head, _tail) = generic_array::sequence::Split::split(_tail);
439                                                let #fname: #ty = ::generic_bytes::SizedBytes::from_arr(head)?;
440                                }
441                            } else {
442                                let field_name = format!("f_{}", i);
443                                let fname = syn::Ident::new(&field_name, f.span());
444                                quote_spanned! {f.span()=>
445                                                let #fname: #ty = ::generic_bytes::SizedBytes::from_arr(_tail)?;
446                                }
447                            }
448                        })
449                        .collect();
450
451                    let conclude: TokenStream = fields
452                        .unnamed
453                        .iter()
454                        .enumerate()
455                        .map(|(i, f)| {
456                            let field_name = format!("f_{}", i);
457                            let fname = syn::Ident::new(&field_name, f.span());
458                            quote_spanned! {f.span()=>
459                                #fname,
460                            }
461                        })
462                        .collect();
463                    quote! (
464                        let _tail = arr;
465                        #setup
466                        Ok(#constr (
467                            #conclude
468                        ))
469                    )
470                }
471                Fields::Unit => {
472                    // Unit structs cannot own more than 0 bytes of heap memory.
473                    quote!(0)
474                }
475            }
476        }
477        Data::Enum(_) | Data::Union(_) => unimplemented!(),
478    }
479}
480
481#[proc_macro_derive(SizedBytes)]
482pub fn derive_sized_bytes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
483    let mut input: DeriveInput = match syn::parse(input) {
484        Ok(input) => input,
485        Err(e) => return e.to_compile_error().into(),
486    };
487    let name = &input.ident;
488
489    // Add a bound `T::From : SizedBytes` to every type parameter occurrence `T::From`.
490    if let Err(e) = add_trait_bounds(
491        &mut input.generics,
492        &input.data,
493        parse_quote!(::generic_bytes::SizedBytes),
494    ) {
495        return e.to_compile_error().into();
496    };
497
498    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
499
500    // Generate an expression to sum the type lengths of each field.
501    let types_sum = sum(&input.data);
502
503    // Generate an expression to concatenate each field.
504    let to_arr_impl = byte_concatenation(&input.data);
505
506    // Generate an expression to ingest each field.
507    let from_arr_impl = byte_splitting(name, &input.data);
508
509    let res = quote! (
510        // The generated impl.
511        impl #impl_generics ::generic_bytes::SizedBytes for #name #ty_generics #where_clause {
512
513            type Len = #types_sum;
514
515            fn to_arr(&self) -> GenericArray<u8, Self::Len> {
516                #to_arr_impl
517            }
518
519            fn from_arr(arr: &GenericArray<u8, Self::Len>) -> Result<Self, ::generic_bytes::TryFromSizedBytesError> {
520                #from_arr_impl
521            }
522        }
523    );
524    res.into()
525}