1extern crate proc_macro;
16
17use std::collections::HashMap;
18
19use proc_macro2::TokenStream;
20use quote::quote;
21use syn::parse_macro_input;
22use syn::punctuated::Punctuated;
23use syn::token::Comma;
24use syn::Attribute;
25use syn::DeriveInput;
26use syn::Field;
27use syn::Fields;
28use syn::Ident;
29use syn::Type;
30use syn::Variant;
31
32#[proc_macro_derive(BFieldCodec, attributes(bfield_codec))]
84pub fn bfieldcodec_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
85 let ast = parse_macro_input!(input as DeriveInput);
86 BFieldCodecDeriveBuilder::new(ast).build().into()
87}
88
/// The shape of the type `BFieldCodec` is derived for; selects which code generators run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BFieldCodecDeriveType {
    /// `struct Foo;`
    UnitStruct,
    /// `struct Foo { a: A, b: B }`
    StructWithNamedFields,
    /// `struct Foo(A, B);`
    StructWithUnnamedFields,
    /// `enum Foo { ... }`
    Enum,
}
96
97struct BFieldCodecDeriveBuilder {
98 name: Ident,
99 derive_type: BFieldCodecDeriveType,
100 generics: syn::Generics,
101 attributes: Vec<Attribute>,
102
103 named_included_fields: Vec<Field>,
104 named_ignored_fields: Vec<Field>,
105
106 unnamed_fields: Vec<Field>,
107
108 variants: Option<Punctuated<Variant, syn::token::Comma>>,
109
110 encode_statements: Vec<TokenStream>,
111 decode_function_body: TokenStream,
112 static_length_body: TokenStream,
113 error_builder: BFieldCodecErrorEnumBuilder,
114}
115
116struct BFieldCodecErrorEnumBuilder {
117 name: Ident,
118 errors: HashMap<&'static str, BFieldCodecErrorEnumVariant>,
119}
120
121struct BFieldCodecErrorEnumVariant {
122 variant_name: Ident,
123 variant_type: TokenStream,
124 display_match_arm: TokenStream,
125}
126
127impl BFieldCodecDeriveBuilder {
128 fn new(ast: DeriveInput) -> Self {
129 let derive_type = Self::extract_derive_type(&ast);
130
131 let named_fields = Self::extract_named_fields(&ast);
132 let (ignored_fields, included_fields) = named_fields
133 .iter()
134 .cloned()
135 .partition::<Vec<_>, _>(Self::field_is_ignored);
136
137 let unnamed_fields = Self::extract_unnamed_fields(&ast);
138 let variants = Self::extract_variants(&ast);
139
140 let name = ast.ident;
141 let error_builder = BFieldCodecErrorEnumBuilder::new(name.clone());
142
143 Self {
144 name,
145 derive_type,
146 generics: ast.generics,
147 attributes: ast.attrs,
148
149 named_included_fields: included_fields,
150 named_ignored_fields: ignored_fields,
151 unnamed_fields,
152 variants,
153
154 encode_statements: vec![],
155 decode_function_body: quote! {},
156 static_length_body: quote! {},
157 error_builder,
158 }
159 }
160
161 fn extract_derive_type(ast: &DeriveInput) -> BFieldCodecDeriveType {
162 match &ast.data {
163 syn::Data::Struct(syn::DataStruct {
164 fields: Fields::Unit,
165 ..
166 }) => BFieldCodecDeriveType::UnitStruct,
167 syn::Data::Struct(syn::DataStruct {
168 fields: Fields::Named(_),
169 ..
170 }) => BFieldCodecDeriveType::StructWithNamedFields,
171 syn::Data::Struct(syn::DataStruct {
172 fields: Fields::Unnamed(_),
173 ..
174 }) => BFieldCodecDeriveType::StructWithUnnamedFields,
175 syn::Data::Enum(_) => BFieldCodecDeriveType::Enum,
176 _ => panic!("expected a struct or an enum"),
177 }
178 }
179
180 fn extract_named_fields(ast: &DeriveInput) -> Vec<Field> {
181 match &ast.data {
182 syn::Data::Struct(syn::DataStruct {
183 fields: Fields::Named(fields),
184 ..
185 }) => fields.named.iter().rev().cloned().collect::<Vec<_>>(),
186 _ => vec![],
187 }
188 }
189
190 fn extract_unnamed_fields(ast: &DeriveInput) -> Vec<Field> {
191 match &ast.data {
192 syn::Data::Struct(syn::DataStruct {
193 fields: Fields::Unnamed(fields),
194 ..
195 }) => fields.unnamed.iter().cloned().collect::<Vec<_>>(),
196 _ => vec![],
197 }
198 }
199
200 fn extract_variants(ast: &DeriveInput) -> Option<Punctuated<Variant, Comma>> {
201 match &ast.data {
202 syn::Data::Enum(data_enum) => Some(data_enum.variants.clone()),
203 _ => None,
204 }
205 }
206
207 fn field_is_ignored(field: &Field) -> bool {
208 let field_name = field.ident.as_ref().unwrap();
209 let mut relevant_attributes = field
210 .attrs
211 .iter()
212 .filter(|attr| attr.path().is_ident("bfield_codec"));
213 let attribute = match relevant_attributes.clone().count() {
214 0 => return false,
215 1 => relevant_attributes.next().unwrap(),
216 _ => panic!("field `{field_name}` must have at most 1 `bfield_codec` attribute"),
217 };
218 let parse_ignore = attribute.parse_nested_meta(|meta| match meta.path.get_ident() {
219 Some(ident) if ident == "ignore" => Ok(()),
220 Some(ident) => panic!("unknown identifier `{ident}` for field `{field_name}`"),
221 _ => unreachable!(),
222 });
223 parse_ignore.is_ok()
224 }
225
226 fn build(mut self) -> TokenStream {
227 self.error_builder.build(self.derive_type);
228 self.add_trait_bounds_to_generics();
229 self.build_methods();
230 self.into_tokens()
231 }
232
233 fn add_trait_bounds_to_generics(&mut self) {
234 let ignored_generics = self.extract_ignored_generics_list();
235 let ignored_generics = self.recursively_collect_all_ignored_generics(ignored_generics);
236
237 for param in &mut self.generics.params {
238 let syn::GenericParam::Type(type_param) = param else {
239 continue;
240 };
241 if ignored_generics.contains(&type_param.ident) {
242 continue;
243 }
244 type_param.bounds.push(syn::parse_quote!(BFieldCodec));
245 }
246 }
247
248 fn extract_ignored_generics_list(&self) -> Vec<syn::Ident> {
249 self.attributes
250 .iter()
251 .flat_map(Self::extract_ignored_generics)
252 .collect()
253 }
254
255 fn extract_ignored_generics(attr: &Attribute) -> Vec<Ident> {
256 if !attr.path().is_ident("bfield_codec") {
257 return vec![];
258 }
259
260 let mut ignored_generics = vec![];
261 attr.parse_nested_meta(|meta| match meta.path.get_ident() {
262 Some(ident) if ident == "ignore" => {
263 ignored_generics.push(ident.to_owned());
264 Ok(())
265 }
266 Some(ident) => Err(meta.error(format!("Unknown identifier \"{ident}\"."))),
267 _ => Err(meta.error("Expected an identifier.")),
268 })
269 .unwrap();
270 ignored_generics
271 }
272
273 fn recursively_collect_all_ignored_generics(
276 &self,
277 mut ignored_generics: Vec<Ident>,
278 ) -> Vec<Ident> {
279 let mut ignored_types = self
280 .named_ignored_fields
281 .iter()
282 .map(|ignored_field| ignored_field.ty.clone())
283 .collect::<Vec<_>>();
284 while !ignored_types.is_empty() {
285 let ignored_type = ignored_types[0].clone();
286 ignored_types = ignored_types[1..].to_vec();
287 let Type::Path(type_path) = ignored_type else {
288 continue;
289 };
290 for segment in type_path.path.segments.into_iter() {
291 ignored_generics.push(segment.ident);
292 let syn::PathArguments::AngleBracketed(generic_arguments) = segment.arguments
293 else {
294 continue;
295 };
296 for generic_argument in generic_arguments.args.into_iter() {
297 let syn::GenericArgument::Type(t) = generic_argument else {
298 continue;
299 };
300 ignored_types.push(t.clone());
301 }
302 }
303 }
304 ignored_generics
305 }
306
307 fn build_methods(&mut self) {
308 match self.derive_type {
309 BFieldCodecDeriveType::UnitStruct => self.build_methods_for_unit_struct(),
310 BFieldCodecDeriveType::StructWithNamedFields => {
311 self.build_methods_for_struct_with_named_fields()
312 }
313 BFieldCodecDeriveType::StructWithUnnamedFields => {
314 self.build_methods_for_struct_with_unnamed_fields()
315 }
316 BFieldCodecDeriveType::Enum => self.build_methods_for_enum(),
317 }
318 }
319
320 fn build_methods_for_unit_struct(&mut self) {
321 self.build_decode_function_body_for_unit_struct();
322 self.static_length_body = quote! {::core::option::Option::Some(0)};
323 }
324
325 fn build_methods_for_struct_with_named_fields(&mut self) {
326 self.build_encode_statements_for_struct_with_named_fields();
327 self.build_decode_function_body_for_struct_with_named_fields();
328 let included_fields = self.named_included_fields.clone();
329 self.build_static_length_body_for_struct(&included_fields);
330 }
331
332 fn build_methods_for_struct_with_unnamed_fields(&mut self) {
333 self.build_encode_statements_for_struct_with_unnamed_fields();
334 self.build_decode_function_body_for_struct_with_unnamed_fields();
335 let included_fields = self.unnamed_fields.clone();
336 self.build_static_length_body_for_struct(&included_fields);
337 }
338
339 fn build_methods_for_enum(&mut self) {
340 self.build_encode_statements_for_enum();
341 self.build_decode_function_body_for_enum();
342 self.build_static_length_body_for_enum();
343 }
344
345 fn build_encode_statements_for_struct_with_named_fields(&mut self) {
346 let included_field_names = self
347 .named_included_fields
348 .iter()
349 .map(|field| field.ident.as_ref().unwrap().to_owned());
350 let included_field_types = self
351 .named_included_fields
352 .iter()
353 .map(|field| field.ty.clone());
354 self.encode_statements = included_field_names
355 .clone()
356 .zip(included_field_types.clone())
357 .map(|(field_name, field_type)| {
358 quote! {
359 let #field_name:
360 ::std::vec::Vec<crate::twenty_first::prelude::BFieldElement>
361 = self.#field_name.encode();
362 if <#field_type as crate::twenty_first::prelude::BFieldCodec>
363 ::static_length().is_none() {
364 elements.push(
365 crate::twenty_first::prelude::BFieldElement::new(
366 #field_name.len() as u64
367 )
368 );
369 }
370 elements.extend(#field_name);
371 }
372 })
373 .collect();
374 }
375
376 fn build_encode_statements_for_struct_with_unnamed_fields(&mut self) {
377 let field_types = self.unnamed_fields.iter().map(|field| field.ty.clone());
378 let indices: Vec<_> = (0..self.unnamed_fields.len())
379 .map(syn::Index::from)
380 .collect();
381 let field_names: Vec<_> = indices
382 .iter()
383 .map(|i| quote::format_ident!("field_value_{}", i.index))
384 .collect();
385 self.encode_statements = indices
386 .iter()
387 .zip(field_types.clone())
388 .zip(field_names.clone())
389 .rev()
390 .map(|((idx, field_type), field_name)| {
391 quote! {
392 let #field_name:
393 ::std::vec::Vec<crate::twenty_first::prelude::BFieldElement>
394 = self.#idx.encode();
395 if <#field_type as crate::twenty_first::prelude::BFieldCodec>
396 ::static_length().is_none() {
397 elements.push(
398 crate::twenty_first::prelude::BFieldElement::new(
399 #field_name.len() as u64
400 )
401 );
402 }
403 elements.extend(#field_name);
404 }
405 })
406 .collect();
407 }
408
409 fn build_encode_statements_for_enum(&mut self) {
410 let encode_clauses = self
411 .enum_discriminants_and_variants()
412 .into_iter()
413 .map(|(d, v)| self.generate_encode_clause_for_variant(d, v));
414 let encode_match_statement = quote! {
415 match self {
416 #( #encode_clauses , )*
417 }
418 };
419 self.encode_statements = vec![encode_match_statement];
420 }
421
422 fn generate_encode_clause_for_variant(
423 &self,
424 discriminant: usize,
425 variant: &Variant,
426 ) -> TokenStream {
427 let variant_name = &variant.ident;
428 let associated_data = &variant.fields;
429
430 if associated_data.is_empty() {
431 return quote! {
432 Self::#variant_name => {
433 elements.push(crate::twenty_first::prelude::BFieldElement::new(
434 #discriminant as u64)
435 );
436 }
437 };
438 }
439
440 let reversed_enumerated_associated_data = associated_data.iter().enumerate().rev();
441 let field_encoders = reversed_enumerated_associated_data.map(|(field_index, ad)| {
442 let field_name = self.enum_variant_field_name(discriminant, field_index);
443 let field_type = ad.ty.clone();
444 let field_encoding =
445 quote::format_ident!("variant_{}_field_{}_encoding", discriminant, field_index);
446 quote! {
447 let #field_encoding:
448 ::std::vec::Vec<crate::twenty_first::prelude::BFieldElement> =
449 #field_name.encode();
450 if <#field_type as crate::twenty_first::prelude::BFieldCodec>
451 ::static_length().is_none() {
452 elements.push(
453 crate::twenty_first::prelude::BFieldElement::new(
454 #field_encoding.len() as u64
455 )
456 );
457 }
458 elements.extend(#field_encoding);
459 }
460 });
461
462 let field_names = associated_data
463 .iter()
464 .enumerate()
465 .map(|(field_index, _field)| self.enum_variant_field_name(discriminant, field_index));
466
467 quote! {
468 Self::#variant_name ( #( #field_names , )* ) => {
469 elements.push(
470 crate::twenty_first::prelude::BFieldElement::new(
471 #discriminant as u64
472 )
473 );
474 #( #field_encoders )*
475 }
476 }
477 }
478
479 fn build_decode_function_body_for_unit_struct(&mut self) {
480 let sequence_too_long_error = self.error_builder.sequence_too_long();
481
482 self.decode_function_body = quote! {
483 if !sequence.is_empty() {
484 return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
485 }
486 ::core::result::Result::Ok(::std::boxed::Box::new(Self))
487 };
488 }
489
490 fn build_decode_function_body_for_struct_with_named_fields(&mut self) {
491 let sequence_too_long_error = self.error_builder.sequence_too_long();
492
493 let decode_statements = self
494 .named_included_fields
495 .iter()
496 .map(|field| {
497 let field_name = field.ident.as_ref().unwrap();
498 self.generate_decode_statement_for_field(field_name, &field.ty)
499 })
500 .collect::<Vec<_>>();
501
502 let included_field_names = self.named_included_fields.iter().map(|field| {
503 let field_name = field.ident.as_ref().unwrap().to_owned();
504 quote! { #field_name }
505 });
506 let ignored_field_names = self.named_ignored_fields.iter().map(|field| {
507 let field_name = field.ident.as_ref().unwrap().to_owned();
508 quote! { #field_name }
509 });
510
511 self.decode_function_body = quote! {
512 #(#decode_statements)*
513 if !sequence.is_empty() {
514 return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
515 }
516 ::core::result::Result::Ok(::std::boxed::Box::new(Self {
517 #(#included_field_names,)*
518 #(#ignored_field_names: ::core::default::Default::default(),)*
519 }))
520 };
521 }
522
523 fn build_decode_function_body_for_struct_with_unnamed_fields(&mut self) {
524 let sequence_too_long_error = self.error_builder.sequence_too_long();
525
526 let field_names = (0..self.unnamed_fields.len())
527 .map(|i| quote::format_ident!("field_value_{}", i))
528 .collect::<Vec<_>>();
529 let decode_statements = field_names
530 .iter()
531 .zip(self.unnamed_fields.iter())
532 .rev()
533 .map(|(field_name, field)| {
534 self.generate_decode_statement_for_field(field_name, &field.ty)
535 })
536 .collect::<Vec<_>>();
537
538 self.decode_function_body = quote! {
539 #(#decode_statements)*
540 if !sequence.is_empty() {
541 return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
542 }
543 ::core::result::Result::Ok(::std::boxed::Box::new(Self ( #(#field_names,)* )))
544 };
545 }
546
547 fn generate_decode_statement_for_field(
548 &self,
549 field_name: &Ident,
550 field_type: &Type,
551 ) -> TokenStream {
552 let sequence_empty_for_field_error = self.error_builder.sequence_empty_for_field();
553 let sequence_too_short_for_field_error = self.error_builder.sequence_too_short_for_field();
554 let field_name_as_string_literal = field_name.to_string();
555 quote! {
556 let (#field_name, sequence) = {
557 let maybe_fields_static_length =
558 <#field_type as crate::twenty_first::prelude::BFieldCodec>
559 ::static_length();
560 let field_has_dynamic_length = maybe_fields_static_length.is_none();
561 if sequence.is_empty() && field_has_dynamic_length {
562 return ::core::result::Result::Err(
563 #sequence_empty_for_field_error(#field_name_as_string_literal.to_string())
564 );
565 }
566 let (len, sequence) = match maybe_fields_static_length {
567 ::core::option::Option::Some(len) => (len, sequence),
568 ::core::option::Option::None => (sequence[0].value() as usize, &sequence[1..]),
569 };
570 if sequence.len() < len {
571 return ::core::result::Result::Err(#sequence_too_short_for_field_error(
572 #field_name_as_string_literal.to_string(),
573 ));
574 }
575 let decoded =
576 *<#field_type as crate::twenty_first::prelude::BFieldCodec>
577 ::decode(&sequence[..len]).map_err(|err|
578 -> ::std::boxed::Box<
579 dyn ::std::error::Error
580 + ::core::marker::Send
581 + ::core::marker::Sync
582 > {
583 err.into()
584 }
585 )?;
586 (decoded, &sequence[len..])
587 };
588 }
589 }
590
591 fn build_decode_function_body_for_enum(&mut self) {
592 let sequence_empty_error = self.error_builder.sequence_empty();
593 let invalid_variant_error = self.error_builder.invalid_discriminant();
594
595 let mut match_arms = vec![];
596 for (discriminant, variant) in self.enum_discriminants_and_variants() {
597 let decode_clause = self.generate_decode_clause_for_variant(discriminant, variant);
598 let match_arm = quote! { #discriminant => { #decode_clause } };
599 match_arms.push(match_arm);
600 }
601
602 self.decode_function_body = quote! {
603 if sequence.is_empty() {
604 return ::core::result::Result::Err(#sequence_empty_error);
605 }
606 let (discriminant, sequence) = (sequence[0].value() as usize, &sequence[1..]);
607 match discriminant {
608 #(#match_arms ,)*
609 other_index => ::core::result::Result::Err(#invalid_variant_error(other_index)),
610 }
611 };
612 }
613
614 fn generate_decode_clause_for_variant(
615 &self,
616 discriminant: usize,
617 variant: &Variant,
618 ) -> TokenStream {
619 let sequence_too_long_error = self.error_builder.sequence_too_long();
620 let sequence_empty_error = self.error_builder.sequence_empty_for_variant();
621 let sequence_too_short_error = self.error_builder.sequence_too_short_for_variant();
622
623 let variant_name = &variant.ident;
624 let associated_data = &variant.fields;
625 if associated_data.is_empty() {
626 return quote! {
627 if !sequence.is_empty() {
628 return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
629 }
630 ::core::result::Result::Ok(::std::boxed::Box::new(Self::#variant_name))
631 };
632 }
633
634 let field_decoders = associated_data
635 .iter()
636 .enumerate()
637 .rev()
638 .map(|(field_index, field)| {
639 let field_type = field.ty.clone();
640 let field_name = self.enum_variant_field_name(discriminant, field_index);
641 let field_value =
642 quote::format_ident!("variant_{}_field_{}_value", discriminant, field_index);
643 quote! {
644 let (#field_value, sequence) = {
645 let maybe_fields_static_length =
646 <#field_type as crate::twenty_first::prelude::BFieldCodec>
647 ::static_length();
648 let field_has_dynamic_length = maybe_fields_static_length.is_none();
649 if sequence.is_empty() && field_has_dynamic_length {
650 return ::core::result::Result::Err(
651 #sequence_empty_error(#discriminant, #field_index)
652 );
653 }
654 let (len, sequence) = match maybe_fields_static_length {
655 ::core::option::Option::Some(len) => (len, sequence),
656 ::core::option::Option::None => {
657 (sequence[0].value() as usize, &sequence[1..])
658 },
659 };
660 if sequence.len() < len {
661 return ::core::result::Result::Err(
662 #sequence_too_short_error(#discriminant, #field_index)
663 );
664 }
665 let decoded =
666 *<#field_type as crate::twenty_first::prelude::BFieldCodec>
667 ::decode(
668 &sequence[..len]
669 ).map_err(|err|
670 -> ::std::boxed::Box<
671 dyn ::std::error::Error
672 + ::core::marker::Send
673 + ::core::marker::Sync
674 > {
675 err.into()
676 }
677 )?;
678 (decoded, &sequence[len..])
679 };
680 let #field_name = #field_value;
681 }
682 })
683 .fold(quote! {}, |l, r| quote! {#l #r});
684 let field_names = associated_data
685 .iter()
686 .enumerate()
687 .map(|(field_index, _field)| self.enum_variant_field_name(discriminant, field_index));
688 quote! {
689 #field_decoders
690 if !sequence.is_empty() {
691 return ::core::result::Result::Err(#sequence_too_long_error(sequence.len()));
692 }
693 ::core::result::Result::Ok(
694 ::std::boxed::Box::new(Self::#variant_name ( #( #field_names , )* ))
695 )
696 }
697 }
698
699 fn enum_variant_field_name(&self, discriminant: usize, field_index: usize) -> syn::Ident {
700 quote::format_ident!("variant_{}_field_{}", discriminant, field_index)
701 }
702
703 fn build_static_length_body_for_struct(&mut self, fields: &[Field]) {
704 let field_types = fields
705 .iter()
706 .map(|field| field.ty.clone())
707 .collect::<Vec<_>>();
708 let num_fields = field_types.len();
709 self.static_length_body = quote! {
710 let field_lengths : [::core::option::Option<usize>; #num_fields] = [
711 #(
712 <#field_types as
713 crate::twenty_first::prelude::BFieldCodec>::static_length(),
714 )*
715 ];
716 if field_lengths.iter().all(|fl| fl.is_some() ) {
717 ::core::option::Option::Some(field_lengths.iter().map(|fl| fl.unwrap()).sum())
718 }
719 else {
720 ::core::option::Option::None
721 }
722 };
723 }
724
725 fn build_static_length_body_for_enum(&mut self) {
726 let variants = self.variants.as_ref().unwrap();
727 let no_variants_have_associated_data = variants.iter().all(|v| v.fields.is_empty());
728 if no_variants_have_associated_data {
729 self.static_length_body = quote! {::core::option::Option::Some(1)};
730 return;
731 }
732
733 let num_variants = variants.len();
734 if num_variants == 0 {
735 self.static_length_body = quote! {::core::option::Option::Some(0)};
736 return;
737 }
738
739 let variant_lengths = variants
742 .iter()
743 .map(|variant| {
744 let fields = variant.fields.clone();
745 let field_lengths = fields.iter().map(|f| {
746 quote! {
747 <#f as crate::twenty_first::prelude::BFieldCodec>
748 ::static_length()
749 }
750 });
751 let num_fields = fields.len();
752 quote! {{
753 let field_lengths: [::core::option::Option<usize>; #num_fields] =
754 [ #( #field_lengths , )* ];
755 if field_lengths.iter().all(|fl| fl.is_some()) {
756 Some(field_lengths.iter().map(|fl|fl.unwrap()).sum())
757 } else {
758 None
759 }
760 }}
761 })
762 .collect::<Vec<_>>();
763
764 self.static_length_body = quote! {
765 let variant_lengths : [::core::option::Option<usize>; #num_variants] =
766 [ #( #variant_lengths , )* ];
767 if variant_lengths.iter().all(|field_len| field_len.is_some()) &&
768 variant_lengths.iter().all(|x| x.unwrap() == variant_lengths[0].unwrap()) {
769 Some(variant_lengths[0].unwrap() + 1)
771 }
772 else {
773 None
774 }
775
776 };
777 }
778
779 fn enum_discriminants_and_variants(&self) -> Vec<(usize, &Variant)> {
780 self.variants.as_ref().unwrap().iter().enumerate().collect()
781 }
782
783 fn maybe_impl_enum_discriminants(&self) -> TokenStream {
784 if self.derive_type != BFieldCodecDeriveType::Enum {
785 return quote! {};
786 }
787
788 let mut variant_match_arms = vec![];
789 for (discriminant, variant) in self.enum_discriminants_and_variants() {
790 let ident = &variant.ident;
791 let mut match_statement = quote! { Self::#ident };
792 if !variant.fields.is_empty() {
793 match_statement.extend(quote! { ( .. ) });
794 }
795 let match_arm = quote! { #match_statement => #discriminant };
796 variant_match_arms.push(match_arm);
797 }
798
799 let name = self.name.clone();
800 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
801 quote! {
802 impl #impl_generics #name #ty_generics #where_clause {
803 pub fn bfield_codec_discriminant(&self) -> usize {
804 match self {
805 #( #variant_match_arms , )*
806 }
807 }
808 }
809 }
810 }
811
812 fn into_tokens(self) -> TokenStream {
813 let maybe_impl_enum_discriminants = self.maybe_impl_enum_discriminants();
814 let name = self.name;
815 let error_enum_name = self.error_builder.error_enum_name();
816 let errors = self.error_builder.into_tokens();
817 let decode_function_body = self.decode_function_body;
818 let encode_statements = self.encode_statements;
819 let static_length_body = self.static_length_body;
820 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
821
822 quote! {
823 #maybe_impl_enum_discriminants
824 #errors
825 impl #impl_generics crate::twenty_first::prelude::BFieldCodec
826 for #name #ty_generics #where_clause {
827 type Error = #error_enum_name;
828
829 fn decode(
830 sequence: &[crate::twenty_first::prelude::BFieldElement],
831 ) -> ::core::result::Result<::std::boxed::Box<Self>, Self::Error> {
832 #decode_function_body
833 }
834
835 fn encode(&self) -> ::std::vec::Vec<
836 crate::twenty_first::prelude::BFieldElement
837 > {
838 let mut elements = ::std::vec::Vec::new();
839 #(#encode_statements)*
840 elements
841 }
842
843 fn static_length() -> ::core::option::Option<usize> {
844 #static_length_body
845 }
846 }
847 }
848 }
849}
850
851impl BFieldCodecErrorEnumBuilder {
852 fn new(name: syn::Ident) -> Self {
853 Self {
854 name,
855 errors: HashMap::new(),
856 }
857 }
858
859 fn build(&mut self, derive_type: BFieldCodecDeriveType) {
860 match derive_type {
861 BFieldCodecDeriveType::UnitStruct => self.set_up_unit_struct_errors(),
862 BFieldCodecDeriveType::StructWithNamedFields
863 | BFieldCodecDeriveType::StructWithUnnamedFields => self.set_up_struct_errors(),
864 BFieldCodecDeriveType::Enum => self.set_up_enum_errors(),
865 }
866 }
867
868 fn set_up_unit_struct_errors(&mut self) {
869 self.register_error_sequence_too_long();
870 self.register_error_inner_decoding_failure();
871 }
872
873 fn set_up_struct_errors(&mut self) {
874 self.register_error_sequence_empty();
875 self.register_error_sequence_empty_for_field();
876 self.register_error_sequence_too_short_for_field();
877 self.register_error_sequence_too_long();
878 self.register_error_inner_decoding_failure();
879 }
880
881 fn set_up_enum_errors(&mut self) {
882 self.register_error_sequence_empty();
883 self.register_error_sequence_empty_for_variant();
884 self.register_error_sequence_too_short_for_variant();
885 self.register_error_sequence_too_long();
886 self.register_error_invalid_discriminant();
887 self.register_error_inner_decoding_failure();
888 }
889
890 fn register_error(
891 &mut self,
892 error_id: &'static str,
893 variant_name: Ident,
894 variant_type: TokenStream,
895 display_match_arm: TokenStream,
896 ) {
897 self.errors.insert(
898 error_id,
899 BFieldCodecErrorEnumVariant {
900 variant_name,
901 variant_type,
902 display_match_arm,
903 },
904 );
905 }
906
907 fn global_identifier(&self, variant_name: &Ident) -> TokenStream {
908 let error_enum_name = self.error_enum_name();
909 quote! { #error_enum_name::#variant_name }
910 }
911
912 fn error_enum_name(&self) -> syn::Ident {
913 quote::format_ident!("{}BFieldDecodingError", self.name)
914 }
915
916 fn register_error_sequence_too_long(&mut self) {
917 let name = self.name.to_string();
918
919 let variant_name = quote::format_ident!("SequenceTooLong");
920 let variant_type = quote! { #variant_name(usize) };
921 let display_match_arm = quote! {
922 Self::#variant_name(num_remaining_elements) => ::core::write!(
923 f,
924 "cannot decode {}: sequence too long ({num_remaining_elements} elements remaining)",
925 #name
926 )
927 };
928
929 self.register_error(
930 "seq_too_long",
931 variant_name,
932 variant_type,
933 display_match_arm,
934 );
935 }
936
937 fn register_error_sequence_empty(&mut self) {
938 let name = self.name.to_string();
939
940 let variant_name = quote::format_ident!("SequenceEmpty");
941 let variant_type = quote! { #variant_name };
942 let display_match_arm = quote! {
943 Self::#variant_name => ::core::write!( f, "cannot decode {}: sequence is empty", #name )
944 };
945
946 self.register_error("seq_empty", variant_name, variant_type, display_match_arm);
947 }
948
949 fn register_error_sequence_empty_for_field(&mut self) {
950 let name = self.name.to_string();
951
952 let variant_name = quote::format_ident!("SequenceEmptyForField");
953 let variant_type = quote! { #variant_name(String) };
954 let display_match_arm = quote! {
955 Self::#variant_name(field_name) => ::core::write!(
956 f,
957 "cannot decode {}, field {field_name}: sequence is empty",
958 #name,
959 )
960 };
961
962 self.register_error(
963 "seq_empty_for_field",
964 variant_name,
965 variant_type,
966 display_match_arm,
967 );
968 }
969
970 fn register_error_sequence_too_short_for_field(&mut self) {
971 let name = self.name.to_string();
972
973 let variant_name = quote::format_ident!("SequenceTooShortForField");
974 let variant_type = quote! { #variant_name(String) };
975 let display_match_arm = quote! {
976 Self::#variant_name(field_name) => ::core::write!(
977 f,
978 "cannot decode {}, field {field_name}: sequence too short",
979 #name,
980 )
981 };
982
983 self.register_error(
984 "seq_too_short_for_field",
985 variant_name,
986 variant_type,
987 display_match_arm,
988 );
989 }
990
991 fn register_error_sequence_empty_for_variant(&mut self) {
992 let name = self.name.to_string();
993
994 let variant_name = quote::format_ident!("SequenceEmptyForVariant");
995 let variant_type = quote! { #variant_name(usize, usize) };
996 let display_match_arm = quote! {
997 Self::#variant_name(variant_id, field_id) => ::core::write!(
998 f,
999 "cannot decode {}, variant {variant_id}, field {field_id}: sequence is empty",
1000 #name,
1001 )
1002 };
1003
1004 self.register_error(
1005 "seq_empty_for_variant",
1006 variant_name,
1007 variant_type,
1008 display_match_arm,
1009 );
1010 }
1011
1012 fn register_error_sequence_too_short_for_variant(&mut self) {
1013 let name = self.name.to_string();
1014
1015 let variant_name = quote::format_ident!("SequenceTooShortForVariant");
1016 let variant_type = quote! { #variant_name(usize, usize) };
1017 let display_match_arm = quote! {
1018 Self::#variant_name(variant_id, field_id) => ::core::write!(
1019 f,
1020 "cannot decode {}, variant {variant_id}, field {field_id}: sequence too short",
1021 #name,
1022 )
1023 };
1024
1025 self.register_error(
1026 "seq_too_short_for_variant",
1027 variant_name,
1028 variant_type,
1029 display_match_arm,
1030 );
1031 }
1032
1033 fn register_error_invalid_discriminant(&mut self) {
1034 let name = self.name.to_string();
1035
1036 let variant_name = quote::format_ident!("InvalidVariantIndex");
1037 let variant_type = quote! { #variant_name(usize) };
1038 let display_match_arm = quote! {
1039 Self::#variant_name(discriminant) => ::core::write!(
1040 f,
1041 "cannot decode {}: invalid variant index {discriminant}",
1042 #name
1043 )
1044 };
1045
1046 self.register_error(
1047 "invalid_discriminant",
1048 variant_name,
1049 variant_type,
1050 display_match_arm,
1051 );
1052 }
1053
1054 fn register_error_inner_decoding_failure(&mut self) {
1055 let name = self.name.to_string();
1056
1057 let variant_name = quote::format_ident!("InnerDecodingFailure");
1058 let variant_type = quote! {
1059 #variant_name(::std::boxed::Box<
1060 dyn ::std::error::Error + ::core::marker::Send + ::core::marker::Sync
1061 >
1062 )
1063 };
1064 let display_match_arm = quote! {
1065 Self::#variant_name(inner_error) => ::core::write!(
1066 f,
1067 "cannot decode {}: inner decoding failure: {}",
1068 #name,
1069 inner_error
1070 )
1071 };
1072
1073 self.register_error(
1074 "inner_decoding_failure",
1075 variant_name,
1076 variant_type,
1077 display_match_arm,
1078 );
1079 }
1080
1081 fn sequence_too_long(&self) -> TokenStream {
1082 let error = self.errors.get("seq_too_long").unwrap();
1083 self.global_identifier(&error.variant_name)
1084 }
1085
1086 fn sequence_empty(&self) -> TokenStream {
1087 let error = self.errors.get("seq_empty").unwrap();
1088 self.global_identifier(&error.variant_name)
1089 }
1090
1091 fn sequence_empty_for_field(&self) -> TokenStream {
1092 let error = self.errors.get("seq_empty_for_field").unwrap();
1093 self.global_identifier(&error.variant_name)
1094 }
1095
1096 fn sequence_too_short_for_field(&self) -> TokenStream {
1097 let error = self.errors.get("seq_too_short_for_field").unwrap();
1098 self.global_identifier(&error.variant_name)
1099 }
1100
1101 fn sequence_empty_for_variant(&self) -> TokenStream {
1102 let error = self.errors.get("seq_empty_for_variant").unwrap();
1103 self.global_identifier(&error.variant_name)
1104 }
1105
1106 fn sequence_too_short_for_variant(&self) -> TokenStream {
1107 let error = self.errors.get("seq_too_short_for_variant").unwrap();
1108 self.global_identifier(&error.variant_name)
1109 }
1110
1111 fn invalid_discriminant(&self) -> TokenStream {
1112 let error = self.errors.get("invalid_discriminant").unwrap();
1113 self.global_identifier(&error.variant_name)
1114 }
1115
1116 fn into_tokens(self) -> TokenStream {
1117 let error_enum_name = self.error_enum_name();
1118 let inner_decoding_failure_name = self
1119 .errors
1120 .get("inner_decoding_failure")
1121 .unwrap()
1122 .variant_name
1123 .clone();
1124
1125 let errors = self.errors.values();
1126 let variant_types = errors
1127 .clone()
1128 .map(|error| error.variant_type.clone())
1129 .collect::<Vec<_>>();
1130 let display_match_arms = errors
1131 .map(|error| error.display_match_arm.clone())
1132 .collect::<Vec<_>>();
1133
1134 quote! {
1135 #[derive(::core::fmt::Debug)]
1136 pub enum #error_enum_name {
1137 #( #variant_types , )*
1138 }
1139 impl ::std::error::Error for #error_enum_name {}
1140 impl ::std::fmt::Display for #error_enum_name {
1141 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
1142 match self {
1143 #( #display_match_arms , )*
1144 }
1145 }
1146 }
1147 impl ::core::convert::From<::std::boxed::Box<
1148 dyn ::std::error::Error + ::core::marker::Send + ::core::marker::Sync
1149 >>
1150 for #error_enum_name
1151 {
1152 fn from(err: ::std::boxed::Box<
1153 dyn ::std::error::Error + ::core::marker::Send + ::core::marker::Sync
1154 >)
1155 -> Self {
1156 Self::#inner_decoding_failure_name(err)
1157 }
1158 }
1159 }
1160 }
1161}
1162
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;

    /// Runs the whole derive pipeline on `ast`. These tests only check that expansion
    /// completes without panicking; the generated tokens are not inspected.
    fn expand(ast: DeriveInput) -> TokenStream {
        BFieldCodecDeriveBuilder::new(ast).build()
    }

    #[test]
    fn unit_struct() {
        let _rust_code = expand(parse_quote! {
            #[derive(BFieldCodec)]
            struct UnitStruct;
        });
    }

    #[test]
    fn tuple_struct() {
        let _rust_code = expand(parse_quote! {
            #[derive(BFieldCodec)]
            struct TupleStruct(u64, u32);
        });
    }

    #[test]
    fn struct_with_named_fields() {
        let _rust_code = expand(parse_quote! {
            #[derive(BFieldCodec)]
            struct StructWithNamedFields {
                field1: u64,
                field2: u32,
                #[bfield_codec(ignore)]
                ignored_field: bool,
            }
        });
    }

    #[test]
    fn enum_with_tuple_variants() {
        let _rust_code = expand(parse_quote! {
            #[derive(BFieldCodec)]
            enum Enum {
                Variant1,
                Variant2(u64),
                Variant3(u64, u32),
                #[bfield_codec(ignore)]
                IgnoredVariant,
            }
        });
    }

    #[test]
    fn generic_tuple_struct() {
        let _rust_code = expand(parse_quote! {
            #[derive(BFieldCodec)]
            struct TupleStruct<T>(T, (T, T));
        });
    }

    #[test]
    fn generic_struct_with_named_fields() {
        let _rust_code = expand(parse_quote! {
            #[derive(BFieldCodec)]
            struct StructWithNamedFields<T> {
                field1: T,
                field2: (T, T),
                #[bfield_codec(ignore)]
                ignored_field: bool,
            }
        });
    }

    #[test]
    fn generic_enum() {
        let _rust_code = expand(parse_quote! {
            #[derive(BFieldCodec)]
            enum Enum<T> {
                Variant1,
                Variant2(T),
                Variant3(T, T),
                #[bfield_codec(ignore)]
                IgnoredVariant,
            }
        });
    }
}