bfieldcodec_derive/
lib.rs

//! This crate provides a derive macro for the `BFieldCodec` trait.
//!
//! The generated code refers to items through fully qualified paths starting with:
//!   `crate::twenty_first::`
//!
//! Crates that use this macro must therefore make `twenty_first` available at
//! their crate root, e.g. in their lib.rs:
//!   `use twenty_first;`
//!
//! or, if using twenty-first via a re-export in `dep_crate`:
//!   `use dep_crate::twenty_first;`
//!
//! Failure to do so will result in compile errors.
14
15extern crate proc_macro;
16
17use std::collections::HashMap;
18
19use proc_macro2::TokenStream;
20use quote::quote;
21use syn::parse_macro_input;
22use syn::punctuated::Punctuated;
23use syn::token::Comma;
24use syn::Attribute;
25use syn::DeriveInput;
26use syn::Field;
27use syn::Fields;
28use syn::Ident;
29use syn::Type;
30use syn::Variant;
31
32/// Derives `BFieldCodec` for structs and enums.
33///
34/// Fields that should not be serialized can be ignored by annotating them with
35/// `#[bfield_codec(ignore)]`.
36/// Ignored fields must implement [`Default`].
37///
38/// For enums, the discriminant used for serialization can be accessed through method
39/// `bfield_codec_discriminant`.
40///
41/// ### Example
42///
43/// ```ignore
44/// #[derive(BFieldCodec)]
45/// struct Foo {
46///    bar: u64,
47///    #[bfield_codec(ignore)]
48///    ignored: usize,
49/// }
50/// let foo = Foo { bar: 42, ignored: 7 };
51/// let encoded = foo.encode();
52/// let decoded = Foo::decode(&encoded).unwrap();
53/// assert_eq!(foo.bar, decoded.bar);
54/// ```
55///
56/// Accessing the discriminant of an enum's variant:
57///
58/// ```ignore
59/// #[derive(BFieldCodec)]
60/// enum Bar {
61///     Baz,
62///     Qux(u64),
63/// }
64/// let _discriminant = Bar::Baz.bfield_codec_discriminant();
65/// ```
66///
67/// ### Known limitations
68///
69/// - Enums whith variants that have named fields are currently not supported. Example:
70///      ```ignore
71///     #[derive(BFieldCodec)]  // Currently not supported.
72///     enum Foo {
73///        Bar { baz: u64 },
74///     }
75///     ```
76///
77/// - Enums with no variants are currently not supported. Consider using a unit struct instead.
78///     Example:
79///     ```ignore
80///     #[derive(BFieldCodec)]  // Currently not supported.
81///     enum Foo {}             // Consider `struct Foo;` instead.
82///     ```
83#[proc_macro_derive(BFieldCodec, attributes(bfield_codec))]
84pub fn bfieldcodec_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
85    let ast = parse_macro_input!(input as DeriveInput);
86    BFieldCodecDeriveBuilder::new(ast).build().into()
87}
88
/// The syntactic shape of the item `BFieldCodec` is derived for.
///
/// Selects which family of code generators produces the trait implementation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BFieldCodecDeriveType {
    /// A struct without any field list, e.g. `struct Foo;`.
    UnitStruct,
    /// A struct with named fields, e.g. `struct Foo { bar: u64 }`.
    StructWithNamedFields,
    /// A tuple struct, e.g. `struct Foo(u64);`.
    StructWithUnnamedFields,
    /// Any enum.
    Enum,
}
96
97struct BFieldCodecDeriveBuilder {
98    name: Ident,
99    derive_type: BFieldCodecDeriveType,
100    generics: syn::Generics,
101    attributes: Vec<Attribute>,
102
103    named_included_fields: Vec<Field>,
104    named_ignored_fields: Vec<Field>,
105
106    unnamed_fields: Vec<Field>,
107
108    variants: Option<Punctuated<Variant, syn::token::Comma>>,
109
110    encode_statements: Vec<TokenStream>,
111    decode_function_body: TokenStream,
112    static_length_body: TokenStream,
113    error_builder: BFieldCodecErrorEnumBuilder,
114}
115
116struct BFieldCodecErrorEnumBuilder {
117    name: Ident,
118    errors: HashMap<&'static str, BFieldCodecErrorEnumVariant>,
119}
120
121struct BFieldCodecErrorEnumVariant {
122    variant_name: Ident,
123    variant_type: TokenStream,
124    display_match_arm: TokenStream,
125}
126
127impl BFieldCodecDeriveBuilder {
128    fn new(ast: DeriveInput) -> Self {
129        let derive_type = Self::extract_derive_type(&ast);
130
131        let named_fields = Self::extract_named_fields(&ast);
132        let (ignored_fields, included_fields) = named_fields
133            .iter()
134            .cloned()
135            .partition::<Vec<_>, _>(Self::field_is_ignored);
136
137        let unnamed_fields = Self::extract_unnamed_fields(&ast);
138        let variants = Self::extract_variants(&ast);
139
140        let name = ast.ident;
141        let error_builder = BFieldCodecErrorEnumBuilder::new(name.clone());
142
143        Self {
144            name,
145            derive_type,
146            generics: ast.generics,
147            attributes: ast.attrs,
148
149            named_included_fields: included_fields,
150            named_ignored_fields: ignored_fields,
151            unnamed_fields,
152            variants,
153
154            encode_statements: vec![],
155            decode_function_body: quote! {},
156            static_length_body: quote! {},
157            error_builder,
158        }
159    }
160
161    fn extract_derive_type(ast: &DeriveInput) -> BFieldCodecDeriveType {
162        match &ast.data {
163            syn::Data::Struct(syn::DataStruct {
164                fields: Fields::Unit,
165                ..
166            }) => BFieldCodecDeriveType::UnitStruct,
167            syn::Data::Struct(syn::DataStruct {
168                fields: Fields::Named(_),
169                ..
170            }) => BFieldCodecDeriveType::StructWithNamedFields,
171            syn::Data::Struct(syn::DataStruct {
172                fields: Fields::Unnamed(_),
173                ..
174            }) => BFieldCodecDeriveType::StructWithUnnamedFields,
175            syn::Data::Enum(_) => BFieldCodecDeriveType::Enum,
176            _ => panic!("expected a struct or an enum"),
177        }
178    }
179
180    fn extract_named_fields(ast: &DeriveInput) -> Vec<Field> {
181        match &ast.data {
182            syn::Data::Struct(syn::DataStruct {
183                fields: Fields::Named(fields),
184                ..
185            }) => fields.named.iter().rev().cloned().collect::<Vec<_>>(),
186            _ => vec![],
187        }
188    }
189
190    fn extract_unnamed_fields(ast: &DeriveInput) -> Vec<Field> {
191        match &ast.data {
192            syn::Data::Struct(syn::DataStruct {
193                fields: Fields::Unnamed(fields),
194                ..
195            }) => fields.unnamed.iter().cloned().collect::<Vec<_>>(),
196            _ => vec![],
197        }
198    }
199
200    fn extract_variants(ast: &DeriveInput) -> Option<Punctuated<Variant, Comma>> {
201        match &ast.data {
202            syn::Data::Enum(data_enum) => Some(data_enum.variants.clone()),
203            _ => None,
204        }
205    }
206
207    fn field_is_ignored(field: &Field) -> bool {
208        let field_name = field.ident.as_ref().unwrap();
209        let mut relevant_attributes = field
210            .attrs
211            .iter()
212            .filter(|attr| attr.path().is_ident("bfield_codec"));
213        let attribute = match relevant_attributes.clone().count() {
214            0 => return false,
215            1 => relevant_attributes.next().unwrap(),
216            _ => panic!("field `{field_name}` must have at most 1 `bfield_codec` attribute"),
217        };
218        let parse_ignore = attribute.parse_nested_meta(|meta| match meta.path.get_ident() {
219            Some(ident) if ident == "ignore" => Ok(()),
220            Some(ident) => panic!("unknown identifier `{ident}` for field `{field_name}`"),
221            _ => unreachable!(),
222        });
223        parse_ignore.is_ok()
224    }
225
226    fn build(mut self) -> TokenStream {
227        self.error_builder.build(self.derive_type);
228        self.add_trait_bounds_to_generics();
229        self.build_methods();
230        self.into_tokens()
231    }
232
233    fn add_trait_bounds_to_generics(&mut self) {
234        let ignored_generics = self.extract_ignored_generics_list();
235        let ignored_generics = self.recursively_collect_all_ignored_generics(ignored_generics);
236
237        for param in &mut self.generics.params {
238            let syn::GenericParam::Type(type_param) = param else {
239                continue;
240            };
241            if ignored_generics.contains(&type_param.ident) {
242                continue;
243            }
244            type_param.bounds.push(syn::parse_quote!(BFieldCodec));
245        }
246    }
247
248    fn extract_ignored_generics_list(&self) -> Vec<syn::Ident> {
249        self.attributes
250            .iter()
251            .flat_map(Self::extract_ignored_generics)
252            .collect()
253    }
254
255    fn extract_ignored_generics(attr: &Attribute) -> Vec<Ident> {
256        if !attr.path().is_ident("bfield_codec") {
257            return vec![];
258        }
259
260        let mut ignored_generics = vec![];
261        attr.parse_nested_meta(|meta| match meta.path.get_ident() {
262            Some(ident) if ident == "ignore" => {
263                ignored_generics.push(ident.to_owned());
264                Ok(())
265            }
266            Some(ident) => Err(meta.error(format!("Unknown identifier \"{ident}\"."))),
267            _ => Err(meta.error("Expected an identifier.")),
268        })
269        .unwrap();
270        ignored_generics
271    }
272
273    /// For all ignored fields, add all type identifiers (including, recursively, the type
274    /// identifiers of generic type arguments) to the list of ignored type identifiers.
275    fn recursively_collect_all_ignored_generics(
276        &self,
277        mut ignored_generics: Vec<Ident>,
278    ) -> Vec<Ident> {
279        let mut ignored_types = self
280            .named_ignored_fields
281            .iter()
282            .map(|ignored_field| ignored_field.ty.clone())
283            .collect::<Vec<_>>();
284        while !ignored_types.is_empty() {
285            let ignored_type = ignored_types[0].clone();
286            ignored_types = ignored_types[1..].to_vec();
287            let Type::Path(type_path) = ignored_type else {
288                continue;
289            };
290            for segment in type_path.path.segments.into_iter() {
291                ignored_generics.push(segment.ident);
292                let syn::PathArguments::AngleBracketed(generic_arguments) = segment.arguments
293                else {
294                    continue;
295                };
296                for generic_argument in generic_arguments.args.into_iter() {
297                    let syn::GenericArgument::Type(t) = generic_argument else {
298                        continue;
299                    };
300                    ignored_types.push(t.clone());
301                }
302            }
303        }
304        ignored_generics
305    }
306
307    fn build_methods(&mut self) {
308        match self.derive_type {
309            BFieldCodecDeriveType::UnitStruct => self.build_methods_for_unit_struct(),
310            BFieldCodecDeriveType::StructWithNamedFields => {
311                self.build_methods_for_struct_with_named_fields()
312            }
313            BFieldCodecDeriveType::StructWithUnnamedFields => {
314                self.build_methods_for_struct_with_unnamed_fields()
315            }
316            BFieldCodecDeriveType::Enum => self.build_methods_for_enum(),
317        }
318    }
319
320    fn build_methods_for_unit_struct(&mut self) {
321        self.build_decode_function_body_for_unit_struct();
322        self.static_length_body = quote! {::core::option::Option::Some(0)};
323    }
324
325    fn build_methods_for_struct_with_named_fields(&mut self) {
326        self.build_encode_statements_for_struct_with_named_fields();
327        self.build_decode_function_body_for_struct_with_named_fields();
328        let included_fields = self.named_included_fields.clone();
329        self.build_static_length_body_for_struct(&included_fields);
330    }
331
332    fn build_methods_for_struct_with_unnamed_fields(&mut self) {
333        self.build_encode_statements_for_struct_with_unnamed_fields();
334        self.build_decode_function_body_for_struct_with_unnamed_fields();
335        let included_fields = self.unnamed_fields.clone();
336        self.build_static_length_body_for_struct(&included_fields);
337    }
338
339    fn build_methods_for_enum(&mut self) {
340        self.build_encode_statements_for_enum();
341        self.build_decode_function_body_for_enum();
342        self.build_static_length_body_for_enum();
343    }
344
345    fn build_encode_statements_for_struct_with_named_fields(&mut self) {
346        let included_field_names = self
347            .named_included_fields
348            .iter()
349            .map(|field| field.ident.as_ref().unwrap().to_owned());
350        let included_field_types = self
351            .named_included_fields
352            .iter()
353            .map(|field| field.ty.clone());
354        self.encode_statements = included_field_names
355            .clone()
356            .zip(included_field_types.clone())
357            .map(|(field_name, field_type)| {
358                quote! {
359                    let #field_name:
360                        ::std::vec::Vec<crate::twenty_first::prelude::BFieldElement>
361                            = self.#field_name.encode();
362                    if <#field_type as crate::twenty_first::prelude::BFieldCodec>
363                        ::static_length().is_none() {
364                        elements.push(
365                            crate::twenty_first::prelude::BFieldElement::new(
366                                #field_name.len() as u64
367                            )
368                        );
369                    }
370                    elements.extend(#field_name);
371                }
372            })
373            .collect();
374    }
375
376    fn build_encode_statements_for_struct_with_unnamed_fields(&mut self) {
377        let field_types = self.unnamed_fields.iter().map(|field| field.ty.clone());
378        let indices: Vec<_> = (0..self.unnamed_fields.len())
379            .map(syn::Index::from)
380            .collect();
381        let field_names: Vec<_> = indices
382            .iter()
383            .map(|i| quote::format_ident!("field_value_{}", i.index))
384            .collect();
385        self.encode_statements = indices
386            .iter()
387            .zip(field_types.clone())
388            .zip(field_names.clone())
389            .rev()
390            .map(|((idx, field_type), field_name)| {
391                quote! {
392                    let #field_name:
393                        ::std::vec::Vec<crate::twenty_first::prelude::BFieldElement>
394                            = self.#idx.encode();
395                    if <#field_type as crate::twenty_first::prelude::BFieldCodec>
396                        ::static_length().is_none() {
397                        elements.push(
398                            crate::twenty_first::prelude::BFieldElement::new(
399                                #field_name.len() as u64
400                            )
401                        );
402                    }
403                    elements.extend(#field_name);
404                }
405            })
406            .collect();
407    }
408
409    fn build_encode_statements_for_enum(&mut self) {
410        let encode_clauses = self
411            .enum_discriminants_and_variants()
412            .into_iter()
413            .map(|(d, v)| self.generate_encode_clause_for_variant(d, v));
414        let encode_match_statement = quote! {
415            match self {
416                #( #encode_clauses , )*
417            }
418        };
419        self.encode_statements = vec![encode_match_statement];
420    }
421
422    fn generate_encode_clause_for_variant(
423        &self,
424        discriminant: usize,
425        variant: &Variant,
426    ) -> TokenStream {
427        let variant_name = &variant.ident;
428        let associated_data = &variant.fields;
429
430        if associated_data.is_empty() {
431            return quote! {
432                Self::#variant_name => {
433                    elements.push(crate::twenty_first::prelude::BFieldElement::new(
434                        #discriminant as u64)
435                    );
436                }
437            };
438        }
439
440        let reversed_enumerated_associated_data = associated_data.iter().enumerate().rev();
441        let field_encoders = reversed_enumerated_associated_data.map(|(field_index, ad)| {
442            let field_name = self.enum_variant_field_name(discriminant, field_index);
443            let field_type = ad.ty.clone();
444            let field_encoding =
445                quote::format_ident!("variant_{}_field_{}_encoding", discriminant, field_index);
446            quote! {
447                let #field_encoding:
448                    ::std::vec::Vec<crate::twenty_first::prelude::BFieldElement> =
449                        #field_name.encode();
450                if <#field_type as crate::twenty_first::prelude::BFieldCodec>
451                    ::static_length().is_none() {
452                    elements.push(
453                        crate::twenty_first::prelude::BFieldElement::new(
454                            #field_encoding.len() as u64
455                        )
456                    );
457                }
458                elements.extend(#field_encoding);
459            }
460        });
461
462        let field_names = associated_data
463            .iter()
464            .enumerate()
465            .map(|(field_index, _field)| self.enum_variant_field_name(discriminant, field_index));
466
467        quote! {
468            Self::#variant_name ( #( #field_names , )* ) => {
469                elements.push(
470                    crate::twenty_first::prelude::BFieldElement::new(
471                        #discriminant as u64
472                    )
473                );
474                #( #field_encoders )*
475            }
476        }
477    }
478
479    fn build_decode_function_body_for_unit_struct(&mut self) {
480        let sequence_too_long_error = self.error_builder.sequence_too_long();
481
482        self.decode_function_body = quote! {
483            if !sequence.is_empty() {
484                return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
485            }
486            ::core::result::Result::Ok(::std::boxed::Box::new(Self))
487        };
488    }
489
490    fn build_decode_function_body_for_struct_with_named_fields(&mut self) {
491        let sequence_too_long_error = self.error_builder.sequence_too_long();
492
493        let decode_statements = self
494            .named_included_fields
495            .iter()
496            .map(|field| {
497                let field_name = field.ident.as_ref().unwrap();
498                self.generate_decode_statement_for_field(field_name, &field.ty)
499            })
500            .collect::<Vec<_>>();
501
502        let included_field_names = self.named_included_fields.iter().map(|field| {
503            let field_name = field.ident.as_ref().unwrap().to_owned();
504            quote! { #field_name }
505        });
506        let ignored_field_names = self.named_ignored_fields.iter().map(|field| {
507            let field_name = field.ident.as_ref().unwrap().to_owned();
508            quote! { #field_name }
509        });
510
511        self.decode_function_body = quote! {
512            #(#decode_statements)*
513            if !sequence.is_empty() {
514                return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
515            }
516            ::core::result::Result::Ok(::std::boxed::Box::new(Self {
517                #(#included_field_names,)*
518                #(#ignored_field_names: ::core::default::Default::default(),)*
519            }))
520        };
521    }
522
523    fn build_decode_function_body_for_struct_with_unnamed_fields(&mut self) {
524        let sequence_too_long_error = self.error_builder.sequence_too_long();
525
526        let field_names = (0..self.unnamed_fields.len())
527            .map(|i| quote::format_ident!("field_value_{}", i))
528            .collect::<Vec<_>>();
529        let decode_statements = field_names
530            .iter()
531            .zip(self.unnamed_fields.iter())
532            .rev()
533            .map(|(field_name, field)| {
534                self.generate_decode_statement_for_field(field_name, &field.ty)
535            })
536            .collect::<Vec<_>>();
537
538        self.decode_function_body = quote! {
539            #(#decode_statements)*
540            if !sequence.is_empty() {
541                return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
542            }
543            ::core::result::Result::Ok(::std::boxed::Box::new(Self ( #(#field_names,)* )))
544        };
545    }
546
547    fn generate_decode_statement_for_field(
548        &self,
549        field_name: &Ident,
550        field_type: &Type,
551    ) -> TokenStream {
552        let sequence_empty_for_field_error = self.error_builder.sequence_empty_for_field();
553        let sequence_too_short_for_field_error = self.error_builder.sequence_too_short_for_field();
554        let field_name_as_string_literal = field_name.to_string();
555        quote! {
556            let (#field_name, sequence) = {
557                let maybe_fields_static_length =
558                    <#field_type as crate::twenty_first::prelude::BFieldCodec>
559                        ::static_length();
560                let field_has_dynamic_length = maybe_fields_static_length.is_none();
561                if sequence.is_empty() && field_has_dynamic_length {
562                    return ::core::result::Result::Err(
563                        #sequence_empty_for_field_error(#field_name_as_string_literal.to_string())
564                    );
565                }
566                let (len, sequence) = match maybe_fields_static_length {
567                    ::core::option::Option::Some(len) => (len, sequence),
568                    ::core::option::Option::None => (sequence[0].value() as usize, &sequence[1..]),
569                };
570                if sequence.len() < len {
571                    return ::core::result::Result::Err(#sequence_too_short_for_field_error(
572                        #field_name_as_string_literal.to_string(),
573                    ));
574                }
575                let decoded =
576                    *<#field_type as crate::twenty_first::prelude::BFieldCodec>
577                        ::decode(&sequence[..len]).map_err(|err|
578                            -> ::std::boxed::Box<
579                                    dyn ::std::error::Error
580                                    + ::core::marker::Send
581                                    + ::core::marker::Sync
582                            > {
583                                err.into()
584                            }
585                        )?;
586                (decoded, &sequence[len..])
587            };
588        }
589    }
590
591    fn build_decode_function_body_for_enum(&mut self) {
592        let sequence_empty_error = self.error_builder.sequence_empty();
593        let invalid_variant_error = self.error_builder.invalid_discriminant();
594
595        let mut match_arms = vec![];
596        for (discriminant, variant) in self.enum_discriminants_and_variants() {
597            let decode_clause = self.generate_decode_clause_for_variant(discriminant, variant);
598            let match_arm = quote! { #discriminant => { #decode_clause } };
599            match_arms.push(match_arm);
600        }
601
602        self.decode_function_body = quote! {
603            if sequence.is_empty() {
604                return ::core::result::Result::Err(#sequence_empty_error);
605            }
606            let (discriminant, sequence) = (sequence[0].value() as usize, &sequence[1..]);
607            match discriminant {
608                #(#match_arms ,)*
609                other_index => ::core::result::Result::Err(#invalid_variant_error(other_index)),
610            }
611        };
612    }
613
614    fn generate_decode_clause_for_variant(
615        &self,
616        discriminant: usize,
617        variant: &Variant,
618    ) -> TokenStream {
619        let sequence_too_long_error = self.error_builder.sequence_too_long();
620        let sequence_empty_error = self.error_builder.sequence_empty_for_variant();
621        let sequence_too_short_error = self.error_builder.sequence_too_short_for_variant();
622
623        let variant_name = &variant.ident;
624        let associated_data = &variant.fields;
625        if associated_data.is_empty() {
626            return quote! {
627                if !sequence.is_empty() {
628                    return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
629                }
630                ::core::result::Result::Ok(::std::boxed::Box::new(Self::#variant_name))
631            };
632        }
633
634        let field_decoders = associated_data
635            .iter()
636            .enumerate()
637            .rev()
638            .map(|(field_index, field)| {
639                let field_type = field.ty.clone();
640                let field_name = self.enum_variant_field_name(discriminant, field_index);
641                let field_value =
642                    quote::format_ident!("variant_{}_field_{}_value", discriminant, field_index);
643                quote! {
644                    let (#field_value, sequence) = {
645                        let maybe_fields_static_length =
646                            <#field_type as crate::twenty_first::prelude::BFieldCodec>
647                                ::static_length();
648                        let field_has_dynamic_length = maybe_fields_static_length.is_none();
649                        if sequence.is_empty() && field_has_dynamic_length {
650                            return ::core::result::Result::Err(
651                                #sequence_empty_error(#discriminant, #field_index)
652                            );
653                        }
654                        let (len, sequence) = match maybe_fields_static_length {
655                            ::core::option::Option::Some(len) => (len, sequence),
656                            ::core::option::Option::None => {
657                                (sequence[0].value() as usize, &sequence[1..])
658                            },
659                        };
660                        if sequence.len() < len {
661                            return ::core::result::Result::Err(
662                                #sequence_too_short_error(#discriminant, #field_index)
663                            );
664                        }
665                        let decoded =
666                            *<#field_type as crate::twenty_first::prelude::BFieldCodec>
667                                ::decode(
668                                    &sequence[..len]
669                                ).map_err(|err|
670                                    -> ::std::boxed::Box<
671                                            dyn ::std::error::Error
672                                            + ::core::marker::Send
673                                            + ::core::marker::Sync
674                                    > {
675                                        err.into()
676                                    }
677                                )?;
678                        (decoded, &sequence[len..])
679                    };
680                    let #field_name = #field_value;
681                }
682            })
683            .fold(quote! {}, |l, r| quote! {#l #r});
684        let field_names = associated_data
685            .iter()
686            .enumerate()
687            .map(|(field_index, _field)| self.enum_variant_field_name(discriminant, field_index));
688        quote! {
689            #field_decoders
690            if !sequence.is_empty() {
691                return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
692            }
693            ::core::result::Result::Ok(
694                ::std::boxed::Box::new(Self::#variant_name ( #( #field_names , )* ))
695            )
696        }
697    }
698
699    fn enum_variant_field_name(&self, discriminant: usize, field_index: usize) -> syn::Ident {
700        quote::format_ident!("variant_{}_field_{}", discriminant, field_index)
701    }
702
703    fn build_static_length_body_for_struct(&mut self, fields: &[Field]) {
704        let field_types = fields
705            .iter()
706            .map(|field| field.ty.clone())
707            .collect::<Vec<_>>();
708        let num_fields = field_types.len();
709        self.static_length_body = quote! {
710            let field_lengths : [::core::option::Option<usize>; #num_fields] = [
711                #(
712                    <#field_types as
713                    crate::twenty_first::prelude::BFieldCodec>::static_length(),
714                )*
715            ];
716            if field_lengths.iter().all(|fl| fl.is_some() ) {
717                ::core::option::Option::Some(field_lengths.iter().map(|fl| fl.unwrap()).sum())
718            }
719            else {
720                ::core::option::Option::None
721            }
722        };
723    }
724
725    fn build_static_length_body_for_enum(&mut self) {
726        let variants = self.variants.as_ref().unwrap();
727        let no_variants_have_associated_data = variants.iter().all(|v| v.fields.is_empty());
728        if no_variants_have_associated_data {
729            self.static_length_body = quote! {::core::option::Option::Some(1)};
730            return;
731        }
732
733        let num_variants = variants.len();
734        if num_variants == 0 {
735            self.static_length_body = quote! {::core::option::Option::Some(0)};
736            return;
737        }
738
739        // some variants have associated data
740        // if all variants encode to the same length, the length is statically known anyway
741        let variant_lengths = variants
742            .iter()
743            .map(|variant| {
744                let fields = variant.fields.clone();
745                let field_lengths = fields.iter().map(|f| {
746                    quote! {
747                        <#f as crate::twenty_first::prelude::BFieldCodec>
748                            ::static_length()
749                    }
750                });
751                let num_fields = fields.len();
752                quote! {{
753                    let field_lengths: [::core::option::Option<usize>; #num_fields] =
754                        [ #( #field_lengths , )* ];
755                    if field_lengths.iter().all(|fl| fl.is_some()) {
756                        Some(field_lengths.iter().map(|fl|fl.unwrap()).sum())
757                    } else {
758                        None
759                    }
760                }}
761            })
762            .collect::<Vec<_>>();
763
764        self.static_length_body = quote! {
765                let variant_lengths : [::core::option::Option<usize>; #num_variants] =
766                    [ #( #variant_lengths , )* ];
767                if variant_lengths.iter().all(|field_len| field_len.is_some()) &&
768                    variant_lengths.iter().all(|x| x.unwrap() == variant_lengths[0].unwrap()) {
769                    // account for discriminant
770                    Some(variant_lengths[0].unwrap() + 1)
771                }
772                else {
773                    None
774                }
775
776        };
777    }
778
779    fn enum_discriminants_and_variants(&self) -> Vec<(usize, &Variant)> {
780        self.variants.as_ref().unwrap().iter().enumerate().collect()
781    }
782
783    fn maybe_impl_enum_discriminants(&self) -> TokenStream {
784        if self.derive_type != BFieldCodecDeriveType::Enum {
785            return quote! {};
786        }
787
788        let mut variant_match_arms = vec![];
789        for (discriminant, variant) in self.enum_discriminants_and_variants() {
790            let ident = &variant.ident;
791            let mut match_statement = quote! { Self::#ident };
792            if !variant.fields.is_empty() {
793                match_statement.extend(quote! { ( .. ) });
794            }
795            let match_arm = quote! { #match_statement => #discriminant };
796            variant_match_arms.push(match_arm);
797        }
798
799        let name = self.name.clone();
800        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
801        quote! {
802            impl #impl_generics #name #ty_generics #where_clause {
803                pub fn bfield_codec_discriminant(&self) -> usize {
804                    match self {
805                        #( #variant_match_arms , )*
806                    }
807                }
808            }
809        }
810    }
811
812    fn into_tokens(self) -> TokenStream {
813        let maybe_impl_enum_discriminants = self.maybe_impl_enum_discriminants();
814        let name = self.name;
815        let error_enum_name = self.error_builder.error_enum_name();
816        let errors = self.error_builder.into_tokens();
817        let decode_function_body = self.decode_function_body;
818        let encode_statements = self.encode_statements;
819        let static_length_body = self.static_length_body;
820        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
821
822        quote! {
823            #maybe_impl_enum_discriminants
824            #errors
825            impl #impl_generics crate::twenty_first::prelude::BFieldCodec
826            for #name #ty_generics #where_clause {
827                type Error = #error_enum_name;
828
829                fn decode(
830                    sequence: &[crate::twenty_first::prelude::BFieldElement],
831                ) -> ::core::result::Result<::std::boxed::Box<Self>, Self::Error> {
832                    #decode_function_body
833                }
834
835                fn encode(&self) -> ::std::vec::Vec<
836                    crate::twenty_first::prelude::BFieldElement
837                > {
838                    let mut elements = ::std::vec::Vec::new();
839                    #(#encode_statements)*
840                    elements
841                }
842
843                fn static_length() -> ::core::option::Option<usize> {
844                    #static_length_body
845                }
846            }
847        }
848    }
849}
850
851impl BFieldCodecErrorEnumBuilder {
852    fn new(name: syn::Ident) -> Self {
853        Self {
854            name,
855            errors: HashMap::new(),
856        }
857    }
858
859    fn build(&mut self, derive_type: BFieldCodecDeriveType) {
860        match derive_type {
861            BFieldCodecDeriveType::UnitStruct => self.set_up_unit_struct_errors(),
862            BFieldCodecDeriveType::StructWithNamedFields
863            | BFieldCodecDeriveType::StructWithUnnamedFields => self.set_up_struct_errors(),
864            BFieldCodecDeriveType::Enum => self.set_up_enum_errors(),
865        }
866    }
867
868    fn set_up_unit_struct_errors(&mut self) {
869        self.register_error_sequence_too_long();
870        self.register_error_inner_decoding_failure();
871    }
872
873    fn set_up_struct_errors(&mut self) {
874        self.register_error_sequence_empty();
875        self.register_error_sequence_empty_for_field();
876        self.register_error_sequence_too_short_for_field();
877        self.register_error_sequence_too_long();
878        self.register_error_inner_decoding_failure();
879    }
880
881    fn set_up_enum_errors(&mut self) {
882        self.register_error_sequence_empty();
883        self.register_error_sequence_empty_for_variant();
884        self.register_error_sequence_too_short_for_variant();
885        self.register_error_sequence_too_long();
886        self.register_error_invalid_discriminant();
887        self.register_error_inner_decoding_failure();
888    }
889
890    fn register_error(
891        &mut self,
892        error_id: &'static str,
893        variant_name: Ident,
894        variant_type: TokenStream,
895        display_match_arm: TokenStream,
896    ) {
897        self.errors.insert(
898            error_id,
899            BFieldCodecErrorEnumVariant {
900                variant_name,
901                variant_type,
902                display_match_arm,
903            },
904        );
905    }
906
907    fn global_identifier(&self, variant_name: &Ident) -> TokenStream {
908        let error_enum_name = self.error_enum_name();
909        quote! { #error_enum_name::#variant_name }
910    }
911
912    fn error_enum_name(&self) -> syn::Ident {
913        quote::format_ident!("{}BFieldDecodingError", self.name)
914    }
915
916    fn register_error_sequence_too_long(&mut self) {
917        let name = self.name.to_string();
918
919        let variant_name = quote::format_ident!("SequenceTooLong");
920        let variant_type = quote! { #variant_name(usize) };
921        let display_match_arm = quote! {
922            Self::#variant_name(num_remaining_elements) => ::core::write!(
923                f,
924                "cannot decode {}: sequence too long ({num_remaining_elements} elements remaining)",
925                #name
926            )
927        };
928
929        self.register_error(
930            "seq_too_long",
931            variant_name,
932            variant_type,
933            display_match_arm,
934        );
935    }
936
937    fn register_error_sequence_empty(&mut self) {
938        let name = self.name.to_string();
939
940        let variant_name = quote::format_ident!("SequenceEmpty");
941        let variant_type = quote! { #variant_name };
942        let display_match_arm = quote! {
943            Self::#variant_name => ::core::write!( f, "cannot decode {}: sequence is empty", #name )
944        };
945
946        self.register_error("seq_empty", variant_name, variant_type, display_match_arm);
947    }
948
949    fn register_error_sequence_empty_for_field(&mut self) {
950        let name = self.name.to_string();
951
952        let variant_name = quote::format_ident!("SequenceEmptyForField");
953        let variant_type = quote! { #variant_name(String) };
954        let display_match_arm = quote! {
955            Self::#variant_name(field_name) => ::core::write!(
956                f,
957                "cannot decode {}, field {field_name}: sequence is empty",
958                #name,
959            )
960        };
961
962        self.register_error(
963            "seq_empty_for_field",
964            variant_name,
965            variant_type,
966            display_match_arm,
967        );
968    }
969
970    fn register_error_sequence_too_short_for_field(&mut self) {
971        let name = self.name.to_string();
972
973        let variant_name = quote::format_ident!("SequenceTooShortForField");
974        let variant_type = quote! { #variant_name(String) };
975        let display_match_arm = quote! {
976            Self::#variant_name(field_name) => ::core::write!(
977                f,
978                "cannot decode {}, field {field_name}: sequence too short",
979                #name,
980            )
981        };
982
983        self.register_error(
984            "seq_too_short_for_field",
985            variant_name,
986            variant_type,
987            display_match_arm,
988        );
989    }
990
991    fn register_error_sequence_empty_for_variant(&mut self) {
992        let name = self.name.to_string();
993
994        let variant_name = quote::format_ident!("SequenceEmptyForVariant");
995        let variant_type = quote! { #variant_name(usize, usize) };
996        let display_match_arm = quote! {
997            Self::#variant_name(variant_id, field_id) => ::core::write!(
998                f,
999                "cannot decode {}, variant {variant_id}, field {field_id}: sequence is empty",
1000                #name,
1001            )
1002        };
1003
1004        self.register_error(
1005            "seq_empty_for_variant",
1006            variant_name,
1007            variant_type,
1008            display_match_arm,
1009        );
1010    }
1011
1012    fn register_error_sequence_too_short_for_variant(&mut self) {
1013        let name = self.name.to_string();
1014
1015        let variant_name = quote::format_ident!("SequenceTooShortForVariant");
1016        let variant_type = quote! { #variant_name(usize, usize) };
1017        let display_match_arm = quote! {
1018            Self::#variant_name(variant_id, field_id) => ::core::write!(
1019                f,
1020                "cannot decode {}, variant {variant_id}, field {field_id}: sequence too short",
1021                #name,
1022            )
1023        };
1024
1025        self.register_error(
1026            "seq_too_short_for_variant",
1027            variant_name,
1028            variant_type,
1029            display_match_arm,
1030        );
1031    }
1032
1033    fn register_error_invalid_discriminant(&mut self) {
1034        let name = self.name.to_string();
1035
1036        let variant_name = quote::format_ident!("InvalidVariantIndex");
1037        let variant_type = quote! { #variant_name(usize) };
1038        let display_match_arm = quote! {
1039            Self::#variant_name(discriminant) => ::core::write!(
1040                f,
1041                "cannot decode {}: invalid variant index {discriminant}",
1042                #name
1043            )
1044        };
1045
1046        self.register_error(
1047            "invalid_discriminant",
1048            variant_name,
1049            variant_type,
1050            display_match_arm,
1051        );
1052    }
1053
1054    fn register_error_inner_decoding_failure(&mut self) {
1055        let name = self.name.to_string();
1056
1057        let variant_name = quote::format_ident!("InnerDecodingFailure");
1058        let variant_type = quote! {
1059            #variant_name(::std::boxed::Box<
1060                    dyn ::std::error::Error + ::core::marker::Send + ::core::marker::Sync
1061                >
1062            )
1063        };
1064        let display_match_arm = quote! {
1065            Self::#variant_name(inner_error) => ::core::write!(
1066                f,
1067                "cannot decode {}: inner decoding failure: {}",
1068                #name,
1069                inner_error
1070            )
1071        };
1072
1073        self.register_error(
1074            "inner_decoding_failure",
1075            variant_name,
1076            variant_type,
1077            display_match_arm,
1078        );
1079    }
1080
1081    fn sequence_too_long(&self) -> TokenStream {
1082        let error = self.errors.get("seq_too_long").unwrap();
1083        self.global_identifier(&error.variant_name)
1084    }
1085
1086    fn sequence_empty(&self) -> TokenStream {
1087        let error = self.errors.get("seq_empty").unwrap();
1088        self.global_identifier(&error.variant_name)
1089    }
1090
1091    fn sequence_empty_for_field(&self) -> TokenStream {
1092        let error = self.errors.get("seq_empty_for_field").unwrap();
1093        self.global_identifier(&error.variant_name)
1094    }
1095
1096    fn sequence_too_short_for_field(&self) -> TokenStream {
1097        let error = self.errors.get("seq_too_short_for_field").unwrap();
1098        self.global_identifier(&error.variant_name)
1099    }
1100
1101    fn sequence_empty_for_variant(&self) -> TokenStream {
1102        let error = self.errors.get("seq_empty_for_variant").unwrap();
1103        self.global_identifier(&error.variant_name)
1104    }
1105
1106    fn sequence_too_short_for_variant(&self) -> TokenStream {
1107        let error = self.errors.get("seq_too_short_for_variant").unwrap();
1108        self.global_identifier(&error.variant_name)
1109    }
1110
1111    fn invalid_discriminant(&self) -> TokenStream {
1112        let error = self.errors.get("invalid_discriminant").unwrap();
1113        self.global_identifier(&error.variant_name)
1114    }
1115
1116    fn into_tokens(self) -> TokenStream {
1117        let error_enum_name = self.error_enum_name();
1118        let inner_decoding_failure_name = self
1119            .errors
1120            .get("inner_decoding_failure")
1121            .unwrap()
1122            .variant_name
1123            .clone();
1124
1125        let errors = self.errors.values();
1126        let variant_types = errors
1127            .clone()
1128            .map(|error| error.variant_type.clone())
1129            .collect::<Vec<_>>();
1130        let display_match_arms = errors
1131            .map(|error| error.display_match_arm.clone())
1132            .collect::<Vec<_>>();
1133
1134        quote! {
1135            #[derive(::core::fmt::Debug)]
1136            pub enum #error_enum_name {
1137                #( #variant_types , )*
1138            }
1139            impl ::std::error::Error for #error_enum_name {}
1140            impl ::std::fmt::Display for #error_enum_name {
1141                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
1142                    match self {
1143                        #( #display_match_arms , )*
1144                    }
1145                }
1146            }
1147            impl ::core::convert::From<::std::boxed::Box<
1148                dyn ::std::error::Error + ::core::marker::Send + ::core::marker::Sync
1149            >>
1150            for #error_enum_name
1151            {
1152                fn from(err: ::std::boxed::Box<
1153                    dyn ::std::error::Error + ::core::marker::Send + ::core::marker::Sync
1154                >)
1155                -> Self {
1156                    Self::#inner_decoding_failure_name(err)
1157                }
1158            }
1159        }
1160    }
1161}
1162
#[cfg(test)]
mod tests {
    use super::*;

    use syn::parse_quote;

    // Smoke tests. Each one feeds a representative type definition through the derive
    // builder and passes as long as code generation completes without panicking. The
    // generated tokens themselves are not inspected.

    #[test]
    fn unit_struct() {
        let input = parse_quote! {
            #[derive(BFieldCodec)]
            struct UnitStruct;
        };
        let _expanded = BFieldCodecDeriveBuilder::new(input).build();
    }

    #[test]
    fn tuple_struct() {
        let input = parse_quote! {
            #[derive(BFieldCodec)]
            struct TupleStruct(u64, u32);
        };
        let _expanded = BFieldCodecDeriveBuilder::new(input).build();
    }

    #[test]
    fn struct_with_named_fields() {
        let input = parse_quote! {
            #[derive(BFieldCodec)]
            struct StructWithNamedFields {
                field1: u64,
                field2: u32,
                #[bfield_codec(ignore)]
                ignored_field: bool,
            }
        };
        let _expanded = BFieldCodecDeriveBuilder::new(input).build();
    }

    #[test]
    fn enum_with_tuple_variants() {
        let input = parse_quote! {
            #[derive(BFieldCodec)]
            enum Enum {
                Variant1,
                Variant2(u64),
                Variant3(u64, u32),
                #[bfield_codec(ignore)]
                IgnoredVariant,
            }
        };
        let _expanded = BFieldCodecDeriveBuilder::new(input).build();
    }

    #[test]
    fn generic_tuple_struct() {
        let input = parse_quote! {
            #[derive(BFieldCodec)]
            struct TupleStruct<T>(T, (T, T));
        };
        let _expanded = BFieldCodecDeriveBuilder::new(input).build();
    }

    #[test]
    fn generic_struct_with_named_fields() {
        let input = parse_quote! {
            #[derive(BFieldCodec)]
            struct StructWithNamedFields<T> {
                field1: T,
                field2: (T, T),
                #[bfield_codec(ignore)]
                ignored_field: bool,
            }
        };
        let _expanded = BFieldCodecDeriveBuilder::new(input).build();
    }

    #[test]
    fn generic_enum() {
        let input = parse_quote! {
            #[derive(BFieldCodec)]
            enum Enum<T> {
                Variant1,
                Variant2(T),
                Variant3(T, T),
                #[bfield_codec(ignore)]
                IgnoredVariant,
            }
        };
        let _expanded = BFieldCodecDeriveBuilder::new(input).build();
    }
}