ssz_derive/
lib.rs

1#![recursion_limit = "256"]
2//! Provides procedural derive macros for the `Encode` and `Decode` traits of the `eth2_ssz` crate.
3//!
4//! Supports field attributes, see each derive macro for more information.
5
6use darling::{FromDeriveInput, FromMeta};
7use proc_macro::TokenStream;
8use quote::quote;
9use std::convert::TryInto;
10use syn::{parse_macro_input, DataEnum, DataStruct, DeriveInput, Ident};
11
/// Largest selector value a derived SSZ union may use.
///
/// Selectors above this value are reserved for backwards compatible extensions, so a union
/// with more variants than this limit allows is rejected when the derive macro runs.
const MAX_UNION_SELECTOR: u8 = 127;
15
16#[derive(Debug, FromDeriveInput)]
17#[darling(attributes(ssz))]
18struct StructOpts {
19    #[darling(default)]
20    enum_behaviour: Option<String>,
21}
22
23/// Field-level configuration.
24#[derive(Debug, Default, FromMeta)]
25struct FieldOpts {
26    #[darling(default)]
27    with: Option<Ident>,
28    #[darling(default)]
29    skip_serializing: bool,
30    #[darling(default)]
31    skip_deserializing: bool,
32}
33
/// `enum_behaviour` value selecting the "transparent" enum encoding.
const ENUM_TRANSPARENT: &'static str = "transparent";
/// `enum_behaviour` value selecting the SSZ "union" enum encoding.
const ENUM_UNION: &'static str = "union";
/// Every accepted `enum_behaviour` value; quoted in the error for an unknown value.
const ENUM_VARIANTS: &'static [&'static str] = &[ENUM_TRANSPARENT, ENUM_UNION];
/// Error emitted when an `enum` is derived without an `enum_behaviour` attribute.
const NO_ENUM_BEHAVIOUR_ERROR: &'static str =
    r#"enums require an "enum_behaviour" attribute, e.g., #[ssz(enum_behaviour = "transparent")]"#;
39
40enum EnumBehaviour {
41    Transparent,
42    Union,
43}
44
45impl EnumBehaviour {
46    pub fn new(s: Option<String>) -> Option<Self> {
47        s.map(|s| match s.as_ref() {
48            ENUM_TRANSPARENT => EnumBehaviour::Transparent,
49            ENUM_UNION => EnumBehaviour::Union,
50            other => panic!(
51                "{} is an invalid enum_behaviour, use either {:?}",
52                other, ENUM_VARIANTS
53            ),
54        })
55    }
56}
57
58fn parse_ssz_fields(struct_data: &syn::DataStruct) -> Vec<(&syn::Type, &syn::Ident, FieldOpts)> {
59    struct_data
60        .fields
61        .iter()
62        .map(|field| {
63            let ty = &field.ty;
64            let ident = match &field.ident {
65                Some(ref ident) => ident,
66                _ => panic!("ssz_derive only supports named struct fields."),
67            };
68
69            let field_opts_candidates = field
70                .attrs
71                .iter()
72                .filter(|attr| attr.path.get_ident().map_or(false, |ident| *ident == "ssz"))
73                .collect::<Vec<_>>();
74
75            if field_opts_candidates.len() > 1 {
76                panic!("more than one field-level \"ssz\" attribute provided")
77            }
78
79            let field_opts = field_opts_candidates
80                .first()
81                .map(|attr| {
82                    let meta = attr.parse_meta().unwrap();
83                    FieldOpts::from_meta(&meta).unwrap()
84                })
85                .unwrap_or_default();
86
87            (ty, ident, field_opts)
88        })
89        .collect()
90}
91
92/// Implements `ssz::Encode` for some `struct` or `enum`.
93#[proc_macro_derive(Encode, attributes(ssz))]
94pub fn ssz_encode_derive(input: TokenStream) -> TokenStream {
95    let item = parse_macro_input!(input as DeriveInput);
96    let opts = StructOpts::from_derive_input(&item).unwrap();
97    let enum_opt = EnumBehaviour::new(opts.enum_behaviour);
98
99    match &item.data {
100        syn::Data::Struct(s) => {
101            if enum_opt.is_some() {
102                panic!("enum_behaviour is invalid for structs");
103            }
104            ssz_encode_derive_struct(&item, s)
105        }
106        syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) {
107            EnumBehaviour::Transparent => ssz_encode_derive_enum_transparent(&item, s),
108            EnumBehaviour::Union => ssz_encode_derive_enum_union(&item, s),
109        },
110        _ => panic!("ssz_derive only supports structs and enums"),
111    }
112}
113
114/// Derive `ssz::Encode` for a struct.
115///
116/// Fields are encoded in the order they are defined.
117///
118/// ## Field attributes
119///
120/// - `#[ssz(skip_serializing)]`: the field will not be serialized.
121fn ssz_encode_derive_struct(derive_input: &DeriveInput, struct_data: &DataStruct) -> TokenStream {
122    let name = &derive_input.ident;
123    let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();
124
125    let field_is_ssz_fixed_len = &mut vec![];
126    let field_fixed_len = &mut vec![];
127    let field_ssz_bytes_len = &mut vec![];
128    let field_encoder_append = &mut vec![];
129
130    for (ty, ident, field_opts) in parse_ssz_fields(struct_data) {
131        if field_opts.skip_serializing {
132            continue;
133        }
134
135        if let Some(module) = field_opts.with {
136            let module = quote! { #module::encode };
137            field_is_ssz_fixed_len.push(quote! { #module::is_ssz_fixed_len() });
138            field_fixed_len.push(quote! { #module::ssz_fixed_len() });
139            field_ssz_bytes_len.push(quote! { #module::ssz_bytes_len(&self.#ident) });
140            field_encoder_append.push(quote! {
141                encoder.append_parameterized(
142                    #module::is_ssz_fixed_len(),
143                    |buf| #module::ssz_append(&self.#ident, buf)
144                )
145            });
146        } else {
147            field_is_ssz_fixed_len.push(quote! { <#ty as ssz::Encode>::is_ssz_fixed_len() });
148            field_fixed_len.push(quote! { <#ty as ssz::Encode>::ssz_fixed_len() });
149            field_ssz_bytes_len.push(quote! { self.#ident.ssz_bytes_len() });
150            field_encoder_append.push(quote! { encoder.append(&self.#ident) });
151        }
152    }
153
154    let output = quote! {
155        impl #impl_generics ssz::Encode for #name #ty_generics #where_clause {
156            fn is_ssz_fixed_len() -> bool {
157                #(
158                    #field_is_ssz_fixed_len &&
159                )*
160                    true
161            }
162
163            fn ssz_fixed_len() -> usize {
164                if <Self as ssz::Encode>::is_ssz_fixed_len() {
165                    let mut len: usize = 0;
166                    #(
167                        len = len
168                            .checked_add(#field_fixed_len)
169                            .expect("encode ssz_fixed_len length overflow");
170                    )*
171                    len
172                } else {
173                    ssz::BYTES_PER_LENGTH_OFFSET
174                }
175            }
176
177            fn ssz_bytes_len(&self) -> usize {
178                if <Self as ssz::Encode>::is_ssz_fixed_len() {
179                    <Self as ssz::Encode>::ssz_fixed_len()
180                } else {
181                    let mut len: usize = 0;
182                    #(
183                        if #field_is_ssz_fixed_len {
184                            len = len
185                                .checked_add(#field_fixed_len)
186                                .expect("encode ssz_bytes_len length overflow");
187                        } else {
188                            len = len
189                                .checked_add(ssz::BYTES_PER_LENGTH_OFFSET)
190                                .expect("encode ssz_bytes_len length overflow for offset");
191                            len = len
192                                .checked_add(#field_ssz_bytes_len)
193                                .expect("encode ssz_bytes_len length overflow for bytes");
194                        }
195                    )*
196
197                    len
198                }
199            }
200
201            fn ssz_append(&self, buf: &mut Vec<u8>) {
202                let mut offset: usize = 0;
203                #(
204                    offset = offset
205                        .checked_add(#field_fixed_len)
206                        .expect("encode ssz_append offset overflow");
207                )*
208
209                let mut encoder = ssz::SszEncoder::container(buf, offset);
210
211                #(
212                    #field_encoder_append;
213                )*
214
215                encoder.finalize();
216            }
217        }
218    };
219    output.into()
220}
221
222/// Derive `ssz::Encode` for an enum in the "transparent" method.
223///
224/// The "transparent" method is distinct from the "union" method specified in the SSZ specification.
225/// When using "transparent", the enum will be ignored and the contained field will be serialized as
226/// if the enum does not exist. Since an union variant "selector" is not serialized, it is not
227/// possible to reliably decode an enum that is serialized transparently.
228///
229/// ## Limitations
230///
231/// Only supports:
232/// - Enums with a single field per variant, where
233///     - All fields are variably sized from an SSZ-perspective (not fixed size).
234///
235/// ## Panics
236///
237/// Will panic at compile-time if the single field requirement isn't met, but will panic *at run
238/// time* if the variable-size requirement isn't met.
239fn ssz_encode_derive_enum_transparent(
240    derive_input: &DeriveInput,
241    enum_data: &DataEnum,
242) -> TokenStream {
243    let name = &derive_input.ident;
244    let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();
245
246    let (patterns, assert_exprs): (Vec<_>, Vec<_>) = enum_data
247        .variants
248        .iter()
249        .map(|variant| {
250            let variant_name = &variant.ident;
251
252            if variant.fields.len() != 1 {
253                panic!("ssz::Encode can only be derived for enums with 1 field per variant");
254            }
255
256            let pattern = quote! {
257                #name::#variant_name(ref inner)
258            };
259
260            let ty = &(&variant.fields).into_iter().next().unwrap().ty;
261            let type_assert = quote! {
262                !<#ty as ssz::Encode>::is_ssz_fixed_len()
263            };
264            (pattern, type_assert)
265        })
266        .unzip();
267
268    let output = quote! {
269        impl #impl_generics ssz::Encode for #name #ty_generics #where_clause {
270            fn is_ssz_fixed_len() -> bool {
271                assert!(
272                    #(
273                        #assert_exprs &&
274                    )* true,
275                    "not all enum variants are variably-sized"
276                );
277                false
278            }
279
280            fn ssz_bytes_len(&self) -> usize {
281                match self {
282                    #(
283                        #patterns => inner.ssz_bytes_len(),
284                    )*
285                }
286            }
287
288            fn ssz_append(&self, buf: &mut Vec<u8>) {
289                match self {
290                    #(
291                        #patterns => inner.ssz_append(buf),
292                    )*
293                }
294            }
295        }
296    };
297    output.into()
298}
299
300/// Derive `ssz::Encode` for an `enum` following the "union" SSZ spec.
301///
302/// The union selector will be determined based upon the order in which the enum variants are
303/// defined. E.g., the top-most variant in the enum will have a selector of `0`, the variant
304/// beneath it will have a selector of `1` and so on.
305///
306/// # Limitations
307///
308/// Only supports enums where each variant has a single field.
309fn ssz_encode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum) -> TokenStream {
310    let name = &derive_input.ident;
311    let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();
312
313    let patterns: Vec<_> = enum_data
314        .variants
315        .iter()
316        .map(|variant| {
317            let variant_name = &variant.ident;
318
319            if variant.fields.len() != 1 {
320                panic!("ssz::Encode can only be derived for enums with 1 field per variant");
321            }
322
323            let pattern = quote! {
324                #name::#variant_name(ref inner)
325            };
326            pattern
327        })
328        .collect();
329
330    let union_selectors = compute_union_selectors(patterns.len());
331
332    let output = quote! {
333        impl #impl_generics ssz::Encode for #name #ty_generics #where_clause {
334            fn is_ssz_fixed_len() -> bool {
335                false
336            }
337
338            fn ssz_bytes_len(&self) -> usize {
339                match self {
340                    #(
341                        #patterns => inner
342                            .ssz_bytes_len()
343                            .checked_add(1)
344                            .expect("encoded length must be less than usize::max_value"),
345                    )*
346                }
347            }
348
349            fn ssz_append(&self, buf: &mut Vec<u8>) {
350                match self {
351                    #(
352                        #patterns => {
353                            let union_selector: u8 = #union_selectors;
354                            debug_assert!(union_selector <= ssz::MAX_UNION_SELECTOR);
355                            buf.push(union_selector);
356                            inner.ssz_append(buf)
357                        },
358                    )*
359                }
360            }
361        }
362    };
363    output.into()
364}
365
366/// Derive `ssz::Decode` for a struct or enum.
367#[proc_macro_derive(Decode, attributes(ssz))]
368pub fn ssz_decode_derive(input: TokenStream) -> TokenStream {
369    let item = parse_macro_input!(input as DeriveInput);
370    let opts = StructOpts::from_derive_input(&item).unwrap();
371    let enum_opt = EnumBehaviour::new(opts.enum_behaviour);
372
373    match &item.data {
374        syn::Data::Struct(s) => {
375            if enum_opt.is_some() {
376                panic!("enum_behaviour is invalid for structs");
377            }
378            ssz_decode_derive_struct(&item, s)
379        }
380        syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) {
381            EnumBehaviour::Transparent => panic!(
382                "Decode cannot be derived for enum_behaviour \"{}\", only \"{}\" is valid.",
383                ENUM_TRANSPARENT, ENUM_UNION
384            ),
385            EnumBehaviour::Union => ssz_decode_derive_enum_union(&item, s),
386        },
387        _ => panic!("ssz_derive only supports structs and enums"),
388    }
389}
390
391/// Implements `ssz::Decode` for some `struct`.
392///
393/// Fields are decoded in the order they are defined.
394///
395/// ## Field attributes
396///
397/// - `#[ssz(skip_deserializing)]`: during de-serialization the field will be instantiated from a
398/// `Default` implementation. The decoder will assume that the field was not serialized at all
399/// (e.g., if it has been serialized, an error will be raised instead of `Default` overriding it).
400fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> TokenStream {
401    let name = &item.ident;
402    let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl();
403
404    let mut register_types = vec![];
405    let mut field_names = vec![];
406    let mut fixed_decodes = vec![];
407    let mut decodes = vec![];
408    let mut is_fixed_lens = vec![];
409    let mut fixed_lens = vec![];
410
411    for (ty, ident, field_opts) in parse_ssz_fields(struct_data) {
412        field_names.push(quote! {
413            #ident
414        });
415
416        // Field should not be deserialized; use a `Default` impl to instantiate.
417        if field_opts.skip_deserializing {
418            decodes.push(quote! {
419                let #ident = <_>::default();
420            });
421
422            fixed_decodes.push(quote! {
423                let #ident = <_>::default();
424            });
425
426            continue;
427        }
428
429        let is_ssz_fixed_len;
430        let ssz_fixed_len;
431        let from_ssz_bytes;
432        if let Some(module) = field_opts.with {
433            let module = quote! { #module::decode };
434
435            is_ssz_fixed_len = quote! { #module::is_ssz_fixed_len() };
436            ssz_fixed_len = quote! { #module::ssz_fixed_len() };
437            from_ssz_bytes = quote! { #module::from_ssz_bytes(slice) };
438
439            register_types.push(quote! {
440                builder.register_type_parameterized(#is_ssz_fixed_len, #ssz_fixed_len)?;
441            });
442            decodes.push(quote! {
443                let #ident = decoder.decode_next_with(|slice| #module::from_ssz_bytes(slice))?;
444            });
445        } else {
446            is_ssz_fixed_len = quote! { <#ty as ssz::Decode>::is_ssz_fixed_len() };
447            ssz_fixed_len = quote! { <#ty as ssz::Decode>::ssz_fixed_len() };
448            from_ssz_bytes = quote! { <#ty as ssz::Decode>::from_ssz_bytes(slice) };
449
450            register_types.push(quote! {
451                builder.register_type::<#ty>()?;
452            });
453            decodes.push(quote! {
454                let #ident = decoder.decode_next()?;
455            });
456        }
457
458        fixed_decodes.push(quote! {
459            let #ident = {
460                start = end;
461                end = end
462                    .checked_add(#ssz_fixed_len)
463                    .ok_or_else(|| ssz::DecodeError::OutOfBoundsByte {
464                        i: usize::max_value()
465                    })?;
466                let slice = bytes.get(start..end)
467                    .ok_or_else(|| ssz::DecodeError::InvalidByteLength {
468                        len: bytes.len(),
469                        expected: end
470                    })?;
471                #from_ssz_bytes?
472            };
473        });
474        is_fixed_lens.push(is_ssz_fixed_len);
475        fixed_lens.push(ssz_fixed_len);
476    }
477
478    let output = quote! {
479        impl #impl_generics ssz::Decode for #name #ty_generics #where_clause {
480            fn is_ssz_fixed_len() -> bool {
481                #(
482                    #is_fixed_lens &&
483                )*
484                    true
485            }
486
487            fn ssz_fixed_len() -> usize {
488                if <Self as ssz::Decode>::is_ssz_fixed_len() {
489                    let mut len: usize = 0;
490                    #(
491                        len = len
492                            .checked_add(#fixed_lens)
493                            .expect("decode ssz_fixed_len overflow");
494                    )*
495                    len
496                } else {
497                    ssz::BYTES_PER_LENGTH_OFFSET
498                }
499            }
500
501            fn from_ssz_bytes(bytes: &[u8]) -> std::result::Result<Self, ssz::DecodeError> {
502                if <Self as ssz::Decode>::is_ssz_fixed_len() {
503                    if bytes.len() != <Self as ssz::Decode>::ssz_fixed_len() {
504                        return Err(ssz::DecodeError::InvalidByteLength {
505                            len: bytes.len(),
506                            expected: <Self as ssz::Decode>::ssz_fixed_len(),
507                        });
508                    }
509
510                    let mut start: usize = 0;
511                    let mut end = start;
512
513                    #(
514                        #fixed_decodes
515                    )*
516
517                    Ok(Self {
518                        #(
519                            #field_names,
520                        )*
521                    })
522                } else {
523                    let mut builder = ssz::SszDecoderBuilder::new(bytes);
524
525                    #(
526                        #register_types
527                    )*
528
529                    let mut decoder = builder.build()?;
530
531                    #(
532                        #decodes
533                    )*
534
535
536                    Ok(Self {
537                        #(
538                            #field_names,
539                        )*
540                    })
541                }
542            }
543        }
544    };
545    output.into()
546}
547
548/// Derive `ssz::Decode` for an `enum` following the "union" SSZ spec.
549fn ssz_decode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum) -> TokenStream {
550    let name = &derive_input.ident;
551    let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();
552
553    let (constructors, var_types): (Vec<_>, Vec<_>) = enum_data
554        .variants
555        .iter()
556        .map(|variant| {
557            let variant_name = &variant.ident;
558
559            if variant.fields.len() != 1 {
560                panic!("ssz::Encode can only be derived for enums with 1 field per variant");
561            }
562
563            let constructor = quote! {
564                #name::#variant_name
565            };
566
567            let ty = &(&variant.fields).into_iter().next().unwrap().ty;
568            (constructor, ty)
569        })
570        .unzip();
571
572    let union_selectors = compute_union_selectors(constructors.len());
573
574    let output = quote! {
575        impl #impl_generics ssz::Decode for #name #ty_generics #where_clause {
576            fn is_ssz_fixed_len() -> bool {
577                false
578            }
579
580            fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, ssz::DecodeError> {
581                // Sanity check to ensure the definition here does not drift from the one defined in
582                // `ssz`.
583                debug_assert_eq!(#MAX_UNION_SELECTOR, ssz::MAX_UNION_SELECTOR);
584
585                let (selector, body) = ssz::split_union_bytes(bytes)?;
586
587                match selector.into() {
588                    #(
589                        #union_selectors => {
590                            <#var_types as ssz::Decode>::from_ssz_bytes(body).map(#constructors)
591                        },
592                    )*
593                    other => Err(ssz::DecodeError::UnionSelectorInvalid(other))
594                }
595            }
596        }
597    };
598    output.into()
599}
600
601fn compute_union_selectors(num_variants: usize) -> Vec<u8> {
602    let union_selectors = (0..num_variants)
603        .map(|i| {
604            i.try_into()
605                .expect("union selector exceeds u8::max_value, union has too many variants")
606        })
607        .collect::<Vec<u8>>();
608
609    let highest_selector = union_selectors
610        .last()
611        .copied()
612        .expect("0-variant union is not permitted");
613
614    assert!(
615        highest_selector <= MAX_UNION_SELECTOR,
616        "union selector {} exceeds limit of {}, enum has too many variants",
617        highest_selector,
618        MAX_UNION_SELECTOR
619    );
620
621    union_selectors
622}