scale_decode_derive/
lib.rs

1// Copyright (C) 2023 Parity Technologies (UK) Ltd. (admin@parity.io)
2// This file is a part of the scale-decode crate.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//         http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// https://github.com/rust-lang/rust-clippy/issues/12643.
17// related to `darling::default` attribute expansion
18#![allow(clippy::manual_unwrap_or_default)]
19
20extern crate alloc;
21
22use alloc::string::ToString;
23use darling::FromAttributes;
24use proc_macro2::{Span, TokenStream as TokenStream2};
25use quote::quote;
26use syn::{parse_macro_input, punctuated::Punctuated, DeriveInput};
27
28const ATTR_NAME: &str = "decode_as_type";
29
30// Macro docs in main crate; don't add any docs here.
31#[proc_macro_derive(DecodeAsType, attributes(decode_as_type, codec))]
32pub fn derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
33    let input = parse_macro_input!(input as DeriveInput);
34
35    // parse top level attrs.
36    let attrs = match TopLevelAttrs::parse(&input.attrs) {
37        Ok(attrs) => attrs,
38        Err(e) => return e.write_errors().into(),
39    };
40
41    derive_with_attrs(attrs, input).into()
42}
43
44fn derive_with_attrs(attrs: TopLevelAttrs, input: DeriveInput) -> TokenStream2 {
45    let visibility = &input.vis;
46    // what type is the derive macro declared on?
47    match &input.data {
48        syn::Data::Enum(details) => generate_enum_impl(attrs, visibility, &input, details),
49        syn::Data::Struct(details) => generate_struct_impl(attrs, visibility, &input, details),
50        syn::Data::Union(_) => syn::Error::new(
51            input.ident.span(),
52            "Unions are not supported by the DecodeAsType macro",
53        )
54        .into_compile_error(),
55    }
56}
57
58fn generate_enum_impl(
59    attrs: TopLevelAttrs,
60    visibility: &syn::Visibility,
61    input: &DeriveInput,
62    details: &syn::DataEnum,
63) -> TokenStream2 {
64    let path_to_scale_decode = &attrs.crate_path;
65    let path_to_type: syn::Path = input.ident.clone().into();
66    let variant_names = details.variants.iter().map(|v| v.ident.to_string());
67
68    let generic_types = handle_generics(&attrs, input.generics.clone());
69    let ty_generics = generic_types.ty_generics();
70    let impl_generics = generic_types.impl_generics();
71    let visitor_where_clause = generic_types.visitor_where_clause();
72    let visitor_ty_generics = generic_types.visitor_ty_generics();
73    let visitor_impl_generics = generic_types.visitor_impl_generics();
74    let visitor_phantomdata_type = generic_types.visitor_phantomdata_type();
75    let type_resolver_ident = generic_types.type_resolver_ident();
76
77    // determine what the body of our visitor functions will be based on the type of enum fields
78    // that we're trying to generate output for.
79    let variant_ifs = details.variants.iter().map(|variant| {
80        let variant_ident = &variant.ident;
81        let variant_name = variant_ident.to_string();
82
83        let visit_one_variant_body = match &variant.fields {
84            syn::Fields::Named(fields) => {
85                let (
86                    field_count,
87                    field_composite_keyvals,
88                    field_tuple_keyvals
89                ) = named_field_keyvals(path_to_scale_decode, fields);
90
91                quote!{
92                    let fields = value.fields();
93                    return if fields.has_unnamed_fields() {
94                        if fields.remaining() != #field_count {
95                            return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength {
96                                actual_len: fields.remaining(),
97                                expected_len: #field_count
98                            }));
99                        }
100                        let vals = fields;
101                        Ok(#path_to_type::#variant_ident { #(#field_tuple_keyvals),* })
102                    } else {
103                        let vals: #path_to_scale_decode::BTreeMap<Option<&str>, _> = fields
104                            .map(|res| res.map(|item| (item.name(), item)))
105                            .collect::<Result<_, _>>()?;
106                        Ok(#path_to_type::#variant_ident { #(#field_composite_keyvals),* })
107                    }
108                }
109            },
110            syn::Fields::Unnamed(fields) => {
111                let (
112                    field_count,
113                    field_vals
114                ) = unnamed_field_vals(path_to_scale_decode, fields);
115
116                quote!{
117                    let fields = value.fields();
118                    if fields.remaining() != #field_count {
119                        return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength {
120                            actual_len: fields.remaining(),
121                            expected_len: #field_count
122                        }));
123                    }
124                    let vals = fields;
125                    return Ok(#path_to_type::#variant_ident ( #(#field_vals),* ))
126                }
127            },
128            syn::Fields::Unit => {
129                quote!{
130                    return Ok(#path_to_type::#variant_ident)
131                }
132            },
133        };
134
135        quote!{
136            if value.name() == #variant_name {
137                #visit_one_variant_body
138            }
139        }
140    });
141
142    quote!(
143        const _: () = {
144            #visibility struct Visitor #visitor_impl_generics (
145                ::core::marker::PhantomData<#visitor_phantomdata_type>
146            );
147
148            use #path_to_scale_decode::vec;
149            use #path_to_scale_decode::ToString;
150
151            impl #impl_generics #path_to_scale_decode::IntoVisitor for #path_to_type #ty_generics #visitor_where_clause {
152                type AnyVisitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver> = Visitor #visitor_ty_generics;
153                fn into_visitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver>() -> Self::AnyVisitor<#type_resolver_ident> {
154                    Visitor(::core::marker::PhantomData)
155                }
156            }
157
158            impl #visitor_impl_generics #path_to_scale_decode::Visitor for Visitor #visitor_ty_generics #visitor_where_clause {
159                type Error = #path_to_scale_decode::Error;
160                type Value<'scale, 'info> = #path_to_type #ty_generics;
161                type TypeResolver = #type_resolver_ident;
162
163                fn visit_variant<'scale, 'info>(
164                    self,
165                    value: &mut #path_to_scale_decode::visitor::types::Variant<'scale, 'info, Self::TypeResolver>,
166                    type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
167                ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
168                    #(
169                        #variant_ifs
170                    )*
171                    Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::CannotFindVariant {
172                        got: value.name().to_string(),
173                        expected: vec![#(#variant_names),*]
174                    }))
175                }
176                // Allow an enum to be decoded through nested 1-field composites and tuples:
177                fn visit_composite<'scale, 'info>(
178                    self,
179                    value: &mut #path_to_scale_decode::visitor::types::Composite<'scale, 'info, Self::TypeResolver>,
180                    _type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
181                ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
182                    if value.remaining() != 1 {
183                        return self.visit_unexpected(#path_to_scale_decode::visitor::Unexpected::Composite);
184                    }
185                    value.decode_item(self).unwrap()
186                }
187                fn visit_tuple<'scale, 'info>(
188                    self,
189                    value: &mut #path_to_scale_decode::visitor::types::Tuple<'scale, 'info, Self::TypeResolver>,
190                    _type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
191                ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
192                    if value.remaining() != 1 {
193                        return self.visit_unexpected(#path_to_scale_decode::visitor::Unexpected::Tuple);
194                    }
195                    value.decode_item(self).unwrap()
196                }
197            }
198        };
199    )
200}
201
202fn generate_struct_impl(
203    attrs: TopLevelAttrs,
204    visibility: &syn::Visibility,
205    input: &DeriveInput,
206    details: &syn::DataStruct,
207) -> TokenStream2 {
208    let path_to_scale_decode = &attrs.crate_path;
209    let path_to_type: syn::Path = input.ident.clone().into();
210
211    let generic_types = handle_generics(&attrs, input.generics.clone());
212    let ty_generics = generic_types.ty_generics();
213    let impl_generics = generic_types.impl_generics();
214    let visitor_where_clause = generic_types.visitor_where_clause();
215    let visitor_ty_generics = generic_types.visitor_ty_generics();
216    let visitor_impl_generics = generic_types.visitor_impl_generics();
217    let visitor_phantomdata_type = generic_types.visitor_phantomdata_type();
218    let type_resolver_ident = generic_types.type_resolver_ident();
219
220    // determine what the body of our visitor functions will be based on the type of struct
221    // that we're trying to generate output for.
222    let (visit_composite_body, visit_tuple_body) = match &details.fields {
223        syn::Fields::Named(fields) => {
224            let (field_count, field_composite_keyvals, field_tuple_keyvals) =
225                named_field_keyvals(path_to_scale_decode, fields);
226
227            (
228                quote! {
229                    if value.has_unnamed_fields() {
230                       return self.visit_tuple(&mut value.as_tuple(), type_id)
231                    }
232
233                    let vals: #path_to_scale_decode::BTreeMap<Option<&str>, _> =
234                        value.map(|res| res.map(|item| (item.name(), item))).collect::<Result<_, _>>()?;
235
236                    Ok(#path_to_type { #(#field_composite_keyvals),* })
237                },
238                quote! {
239                    if value.remaining() != #field_count {
240                        return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength { actual_len: value.remaining(), expected_len: #field_count }));
241                    }
242
243                    let vals = value;
244
245                    Ok(#path_to_type { #(#field_tuple_keyvals),* })
246                },
247            )
248        }
249        syn::Fields::Unnamed(fields) => {
250            let (field_count, field_vals) = unnamed_field_vals(path_to_scale_decode, fields);
251
252            (
253                quote! {
254                    self.visit_tuple(&mut value.as_tuple(), type_id)
255                },
256                quote! {
257                    if value.remaining() != #field_count {
258                        return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength { actual_len: value.remaining(), expected_len: #field_count }));
259                    }
260
261                    let vals = value;
262
263                    Ok(#path_to_type ( #( #field_vals ),* ))
264                },
265            )
266        }
267        syn::Fields::Unit => (
268            quote! {
269                self.visit_tuple(&mut value.as_tuple(), type_id)
270            },
271            quote! {
272                if value.remaining() > 0 {
273                    return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength { actual_len: value.remaining(), expected_len: 0 }));
274                }
275                Ok(#path_to_type)
276            },
277        ),
278    };
279
280    quote!(
281        const _: () = {
282            #visibility struct Visitor #visitor_impl_generics (
283                ::core::marker::PhantomData<#visitor_phantomdata_type>
284            );
285
286            use #path_to_scale_decode::vec;
287            use #path_to_scale_decode::ToString;
288
289            impl #impl_generics #path_to_scale_decode::IntoVisitor for #path_to_type #ty_generics #visitor_where_clause {
290                type AnyVisitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver> = Visitor #visitor_ty_generics;
291                fn into_visitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver>() -> Self::AnyVisitor<#type_resolver_ident> {
292                    Visitor(::core::marker::PhantomData)
293                }
294            }
295
296            impl #visitor_impl_generics #path_to_scale_decode::Visitor for Visitor #visitor_ty_generics #visitor_where_clause {
297                type Error = #path_to_scale_decode::Error;
298                type Value<'scale, 'info> = #path_to_type #ty_generics;
299                type TypeResolver = #type_resolver_ident;
300
301                fn visit_composite<'scale, 'info>(
302                    self,
303                    value: &mut #path_to_scale_decode::visitor::types::Composite<'scale, 'info, Self::TypeResolver>,
304                    type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
305                ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
306                    #visit_composite_body
307                }
308                fn visit_tuple<'scale, 'info>(
309                    self,
310                    value: &mut #path_to_scale_decode::visitor::types::Tuple<'scale, 'info, Self::TypeResolver>,
311                    type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
312                ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
313                    #visit_tuple_body
314                }
315            }
316
317            impl #impl_generics #path_to_scale_decode::DecodeAsFields for #path_to_type #ty_generics #visitor_where_clause  {
318                fn decode_as_fields<'info, R: #path_to_scale_decode::TypeResolver>(
319                    input: &mut &[u8],
320                    fields: &mut dyn #path_to_scale_decode::FieldIter<'info, R::TypeId>,
321                    types: &'info R
322                ) -> Result<Self, #path_to_scale_decode::Error>
323                {
324                    let mut composite = #path_to_scale_decode::visitor::types::Composite::new(core::iter::empty(), input, fields, types, false);
325                    use #path_to_scale_decode::{ Visitor, IntoVisitor };
326                    let val = <#path_to_type #ty_generics>::into_visitor().visit_composite(&mut composite, Default::default());
327
328                    // Consume any remaining bytes and update input:
329                    composite.skip_decoding()?;
330                    *input = composite.bytes_from_undecoded();
331
332                    val.map_err(From::from)
333                }
334            }
335        };
336    )
337}
338
339// Given some named fields, generate impls like `field_name: get_field_value()` for each field. Do this for the composite and tuple impls.
340fn named_field_keyvals<'f>(
341    path_to_scale_decode: &'f syn::Path,
342    fields: &'f syn::FieldsNamed,
343) -> (usize, impl Iterator<Item = TokenStream2> + 'f, impl Iterator<Item = TokenStream2> + 'f) {
344    let field_keyval_impls = fields.named.iter().map(move |f| {
345        let field_attrs = FieldAttrs::from_attributes(&f.attrs).unwrap_or_default();
346        let field_ident = f.ident.as_ref().expect("named field has ident");
347        let field_name = field_ident.to_string();
348        let skip_field = field_attrs.skip;
349
350        // If a field is skipped, we expect it to have a Default impl to use to populate it instead.
351        if skip_field {
352            return (
353                false,
354                quote!(#field_ident: ::core::default::Default::default()),
355                quote!(#field_ident: ::core::default::Default::default())
356            )
357        }
358
359        (
360            // Should we use this field (false means we'll not count it):
361            true,
362            // For turning named fields in scale typeinfo into named fields on struct like type:
363            quote!(#field_ident: {
364                let val = vals
365                    .get(&Some(#field_name))
366                    .ok_or_else(|| #path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::CannotFindField { name: #field_name.to_string() }))?
367                    .clone();
368                val.decode_as_type().map_err(|e| e.at_field(#field_name))?
369            }),
370            // For turning named fields in scale typeinfo into unnamed fields on tuple like type:
371            quote!(#field_ident: {
372                let val = vals.next().expect("field count should have been checked already on tuple type; please file a bug report")?;
373                val.decode_as_type().map_err(|e| e.at_field(#field_name))?
374            })
375        )
376    });
377
378    // if we skip any fields, we won't expect that field to exist in some tuple that's being given back.
379    let field_count = field_keyval_impls.clone().filter(|f| f.0).count();
380    let field_composite_keyvals = field_keyval_impls.clone().map(|v| v.1);
381    let field_tuple_keyvals = field_keyval_impls.map(|v| v.2);
382
383    (field_count, field_composite_keyvals, field_tuple_keyvals)
384}
385
386// Given some unnamed fields, generate impls like `get_field_value()` for each field. Do this for a tuple style impl.
387fn unnamed_field_vals<'f>(
388    _path_to_scale_decode: &'f syn::Path,
389    fields: &'f syn::FieldsUnnamed,
390) -> (usize, impl Iterator<Item = TokenStream2> + 'f) {
391    let field_val_impls = fields.unnamed.iter().enumerate().map(|(idx, f)| {
392        let field_attrs = FieldAttrs::from_attributes(&f.attrs).unwrap_or_default();
393        let skip_field = field_attrs.skip;
394
395        // If a field is skipped, we expect it to have a Default impl to use to populate it instead.
396        if skip_field {
397            return (false, quote!(::core::default::Default::default()));
398        }
399
400        (
401            // Should we use this field (false means we'll not count it):
402            true,
403            // For turning unnamed fields in scale typeinfo into unnamed fields on tuple like type:
404            quote!({
405                let val = vals.next().expect("field count should have been checked already on tuple type; please file a bug report")?;
406                val.decode_as_type().map_err(|e| e.at_idx(#idx))?
407            }),
408        )
409    });
410
411    // if we skip any fields, we won't expect that field to exist in some tuple that's being given back.
412    let field_count = field_val_impls.clone().filter(|f| f.0).count();
413    let field_vals = field_val_impls.map(|v| v.1);
414
415    (field_count, field_vals)
416}
417
418fn handle_generics(attrs: &TopLevelAttrs, generics: syn::Generics) -> GenericTypes {
419    let path_to_crate = &attrs.crate_path;
420
421    let type_resolver_ident =
422        syn::Ident::new(GenericTypes::TYPE_RESOLVER_IDENT_STR, Span::call_site());
423
424    // Where clause to use on Visitor/IntoVisitor
425    let visitor_where_clause = {
426        let (_, _, where_clause) = generics.split_for_impl();
427        let mut where_clause = where_clause.cloned().unwrap_or(syn::parse_quote!(where));
428        if let Some(where_predicates) = &attrs.trait_bounds {
429            // if custom trait bounds are given, append those to the where clause.
430            where_clause.predicates.extend(where_predicates.clone());
431        } else {
432            // else, append our default bounds to each parameter to ensure that it all lines up with our generated impls and such:
433            for param in generics.type_params() {
434                let ty = &param.ident;
435                where_clause.predicates.push(syn::parse_quote!(#ty: #path_to_crate::IntoVisitor));
436            }
437        }
438        where_clause
439    };
440
441    // (A, B, C, ScaleDecodeTypeResolver) style PhantomData type to use in Visitor struct.
442    let visitor_phantomdata_type = {
443        let tys = generics.params.iter().filter_map::<syn::Type, _>(|p| match p {
444            syn::GenericParam::Type(ty) => {
445                let ty = &ty.ident;
446                Some(syn::parse_quote!(#ty))
447            }
448            syn::GenericParam::Lifetime(lt) => {
449                let lt = &lt.lifetime;
450                Some(syn::parse_quote!(& #lt ()))
451            }
452            // We don't need to mention const's in the PhantomData type.
453            syn::GenericParam::Const(_) => None,
454        });
455
456        // Add a param for the type resolver generic.
457        let tys = tys.chain(core::iter::once(syn::parse_quote!(#type_resolver_ident)));
458
459        syn::parse_quote!( (#( #tys, )*) )
460    };
461
462    // generics for our Visitor/IntoVisitor; we just add the type resolver param to the list.
463    let visitor_generics = {
464        let mut type_generics = generics.clone();
465        let type_resolver_generic_param: syn::GenericParam =
466            syn::parse_quote!(#type_resolver_ident: #path_to_crate::TypeResolver);
467
468        type_generics.params.push(type_resolver_generic_param);
469        type_generics
470    };
471
472    // generics for the type itself
473    let type_generics = generics;
474
475    GenericTypes {
476        type_generics,
477        type_resolver_ident,
478        visitor_generics,
479        visitor_phantomdata_type,
480        visitor_where_clause,
481    }
482}
483
484struct GenericTypes {
485    type_resolver_ident: syn::Ident,
486    type_generics: syn::Generics,
487    visitor_generics: syn::Generics,
488    visitor_where_clause: syn::WhereClause,
489    visitor_phantomdata_type: syn::Type,
490}
491
492impl GenericTypes {
493    const TYPE_RESOLVER_IDENT_STR: &'static str = "ScaleDecodeTypeResolver";
494
495    pub fn ty_generics(&self) -> syn::TypeGenerics<'_> {
496        let (_, ty_generics, _) = self.type_generics.split_for_impl();
497        ty_generics
498    }
499    pub fn impl_generics(&self) -> syn::ImplGenerics<'_> {
500        let (impl_generics, _, _) = self.type_generics.split_for_impl();
501        impl_generics
502    }
503    pub fn visitor_where_clause(&self) -> &syn::WhereClause {
504        &self.visitor_where_clause
505    }
506    pub fn visitor_ty_generics(&self) -> syn::TypeGenerics<'_> {
507        let (_, ty_generics, _) = self.visitor_generics.split_for_impl();
508        ty_generics
509    }
510    pub fn visitor_impl_generics(&self) -> syn::ImplGenerics<'_> {
511        let (impl_generics, _, _) = self.visitor_generics.split_for_impl();
512        impl_generics
513    }
514    pub fn visitor_phantomdata_type(&self) -> &syn::Type {
515        &self.visitor_phantomdata_type
516    }
517    pub fn type_resolver_ident(&self) -> &syn::Ident {
518        &self.type_resolver_ident
519    }
520}
521
522struct TopLevelAttrs {
523    // path to the scale_decode crate, in case it's not a top level dependency.
524    crate_path: syn::Path,
525    // allow custom trait bounds to be used instead of the defaults.
526    trait_bounds: Option<Punctuated<syn::WherePredicate, syn::Token!(,)>>,
527}
528
529impl TopLevelAttrs {
530    fn parse(attrs: &[syn::Attribute]) -> darling::Result<Self> {
531        use darling::FromMeta;
532
533        #[derive(FromMeta)]
534        struct TopLevelAttrsInner {
535            #[darling(default)]
536            crate_path: Option<syn::Path>,
537            #[darling(default)]
538            trait_bounds: Option<Punctuated<syn::WherePredicate, syn::Token!(,)>>,
539        }
540
541        let mut res =
542            TopLevelAttrs { crate_path: syn::parse_quote!(::scale_decode), trait_bounds: None };
543
544        // look at each top level attr. parse any for decode_as_type.
545        for attr in attrs {
546            if !attr.path().is_ident(ATTR_NAME) {
547                continue;
548            }
549            let meta = &attr.meta;
550            let parsed_attrs = TopLevelAttrsInner::from_meta(meta)?;
551
552            res.trait_bounds = parsed_attrs.trait_bounds;
553            if let Some(crate_path) = parsed_attrs.crate_path {
554                res.crate_path = crate_path;
555            }
556        }
557
558        Ok(res)
559    }
560}
561
562/// Parse the attributes attached to some field
563#[derive(Debug, FromAttributes, Default)]
564#[darling(attributes(decode_as_type, codec))]
565struct FieldAttrs {
566    #[darling(default)]
567    skip: bool,
568}