1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
use std::convert::TryFrom; use quote::quote; use syn::export::{Span, TokenStream2}; use syn::{Fields, Ident, ItemEnum}; use crate::attribute_helpers::contains_skip; pub fn enum_ser(input: &ItemEnum) -> syn::Result<TokenStream2> { let name = &input.ident; let generics = &input.generics; let mut body = TokenStream2::new(); let mut serializable_field_types = TokenStream2::new(); for (variant_idx, variant) in input.variants.iter().enumerate() { let variant_idx = u8::try_from(variant_idx).expect("up to 256 enum variants are supported"); let variant_ident = &variant.ident; let mut variant_header = TokenStream2::new(); let mut variant_body = TokenStream2::new(); match &variant.fields { Fields::Named(fields) => { for field in &fields.named { let field_name = field.ident.as_ref().unwrap(); if contains_skip(&field.attrs) { variant_header.extend(quote! { _#field_name, }); continue; } else { let field_type = &field.ty; serializable_field_types.extend(quote! { #field_type: borsh::ser::BorshSerialize, }); variant_header.extend(quote! { #field_name, }); } variant_body.extend(quote! { borsh::BorshSerialize::serialize(#field_name, writer)?; }) } variant_header = quote! { { #variant_header }}; } Fields::Unnamed(fields) => { for (field_idx, field) in fields.unnamed.iter().enumerate() { let field_idx = u32::try_from(field_idx).expect("up to 2^32 fields are supported"); if contains_skip(&field.attrs) { let field_ident = Ident::new(format!("_id{}", field_idx).as_str(), Span::call_site()); variant_header.extend(quote! { #field_ident, }); continue; } else { let field_type = &field.ty; serializable_field_types.extend(quote! { #field_type: borsh::ser::BorshSerialize, }); let field_ident = Ident::new(format!("id{}", field_idx).as_str(), Span::call_site()); variant_header.extend(quote! { #field_ident, }); variant_body.extend(quote! { borsh::BorshSerialize::serialize(#field_ident, writer)?; }) } } variant_header = quote! { ( #variant_header )}; } Fields::Unit => {} } body.extend(quote!( #name::#variant_ident #variant_header => { let variant_idx: u8 = #variant_idx; writer.write_all(&variant_idx.to_le_bytes())?; #variant_body } )) } Ok(quote! { impl #generics borsh::ser::BorshSerialize for #name #generics where #serializable_field_types { fn serialize<W: std::io::Write>(&self, writer: &mut W) -> std::result::Result<(), std::io::Error> { match self { #body } Ok(()) } } }) }