numbat-codec-derive 0.0.9

Macro implementations of numbat-codec #[derive(NestedEncode, NestedDecode, TopEncode, TopDecode)]
Documentation
use proc_macro::TokenStream;
use quote::quote;

use crate::nested_de_derive::*;
use crate::util::*;

fn fieldless_enum_match_arm_result_ok(
	name: &syn::Ident,
	data_enum: &syn::DataEnum,
) -> Vec<proc_macro2::TokenStream> {
	data_enum
		.variants
		.iter()
		.enumerate()
		.map(|(variant_index, variant)| {
			let variant_index_u8 = variant_index as u8;
			let variant_ident = &variant.ident;
			quote! {
				#variant_index_u8 => core::result::Result::Ok( #name::#variant_ident ),
			}
		})
		.collect()
}

fn fieldless_enum_match_arm(
	name: &syn::Ident,
	data_enum: &syn::DataEnum,
) -> Vec<proc_macro2::TokenStream> {
	data_enum
		.variants
		.iter()
		.enumerate()
		.map(|(variant_index, variant)| {
			let variant_index_u8 = variant_index as u8;
			let variant_ident = &variant.ident;
			quote! {
				#variant_index_u8 => #name::#variant_ident ,
			}
		})
		.collect()
}

/// Generates a default-value deserializer snippet automatically.
/// Currently only does so for enums whose first variant is fieldless.
/// Also generates the snippet for `top_decode_or_exit`.
/// Not called for TopDecodeOrDefault, since that one already provides an explicit default.
fn auto_default(ast: &syn::DeriveInput) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
	let name = &ast.ident;
	if let syn::Data::Enum(data_enum) = &ast.data {
		if data_enum.variants.is_empty() {
			panic!("cannot deserialize enums without variants");
		}
		let first_variant = &data_enum.variants[0];
		if first_variant.fields.is_empty() {
			let first_variant_ident = &first_variant.ident;
			let auto_default = quote! {
				if top_input.byte_len() == 0 {
					return core::result::Result::Ok(#name::#first_variant_ident);
				}
			};
			let auto_default_or_exit = quote! {
				if top_input.byte_len() == 0 {
					return #name::#first_variant_ident;
				}
			};
			return (auto_default, auto_default_or_exit);
		}
	}

	// returns nothing by default
	(quote! {}, quote! {})
}

/// Only returns the trait implementation method bodies, without the impl or method definitions.
fn top_decode_method_bodies(
	ast: &syn::DeriveInput,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
	let name = &ast.ident;
	match &ast.data {
		syn::Data::Struct(data_struct) => {
			let field_dep_decode_snippets =
				fields_decl_syntax(&data_struct.fields, |index, field| {
					dep_decode_snippet(index, field)
				});
			let field_dep_encode_or_exit_snippets =
				fields_decl_syntax(&data_struct.fields, |index, field| {
					dep_decode_or_exit_snippet(index, field)
				});

			let top_decode_body = quote! {
				let bytes = top_input.into_boxed_slice_u8();
				let input = &mut &*bytes;
				let result = #name #field_dep_decode_snippets ;
				if !input.is_empty() {
					return core::result::Result::Err(numbat_codec::DecodeError::INPUT_TOO_LONG);
				}
				core::result::Result::Ok(result)
			};
			let top_decode_or_exit_body = quote! {
				let bytes = top_input.into_boxed_slice_u8();
				let input = &mut &*bytes;
				let result = #name #field_dep_encode_or_exit_snippets ;
				if !input.is_empty() {
					exit(c, numbat_codec::DecodeError::INPUT_TOO_LONG);
				}
				result
			};
			(top_decode_body, top_decode_or_exit_body)
		},
		syn::Data::Enum(data_enum) => {
			assert!(
				data_enum.variants.len() < 256,
				"enums with more than 256 variants not supported"
			);
			if is_fieldless_enum(&data_enum) {
				// fieldless enums are special, they can be top-decoded as u8 directly
				let top_decode_arms = fieldless_enum_match_arm_result_ok(&name, &data_enum);
				let top_decode_or_exit_arms = fieldless_enum_match_arm(&name, &data_enum);

				let top_decode_body = quote! {
					match <u8 as numbat_codec::TopDecode>::top_decode(top_input)? {
						#(#top_decode_arms)*
						_ => core::result::Result::Err(numbat_codec::DecodeError::INVALID_VALUE),
					}
				};
				let top_decode_or_exit_body = quote! {
					match <u8 as numbat_codec::TopDecode>::top_decode_or_exit(top_input, c.clone(), exit) {
						#(#top_decode_or_exit_arms)*
						_ => exit(c, numbat_codec::DecodeError::INVALID_VALUE),
					}
				};
				(top_decode_body, top_decode_or_exit_body)
			} else {
				let variant_dep_decode_snippets = variant_dep_decode_snippets(&name, &data_enum);
				let variant_dep_decode_or_exit_snippets =
					variant_dep_decode_or_exit_snippets(&name, &data_enum);

				let top_decode_body = quote! {
					let bytes = top_input.into_boxed_slice_u8();
					let input = &mut &*bytes;
					let result = match <u8 as numbat_codec::NestedDecode>::dep_decode(input)? {
						#(#variant_dep_decode_snippets)*
						_ => core::result::Result::Err(numbat_codec::DecodeError::INVALID_VALUE),
					};
					if !input.is_empty() {
						return core::result::Result::Err(numbat_codec::DecodeError::INPUT_TOO_LONG);
					}
					result
				};
				let top_decode_or_exit_body = quote! {
					let bytes = top_input.into_boxed_slice_u8();
					let input = &mut &*bytes;
					let result = match <u8 as numbat_codec::NestedDecode>::dep_decode_or_exit(input, c.clone(), exit) {
						#(#variant_dep_decode_or_exit_snippets)*
						_ => exit(c, numbat_codec::DecodeError::INVALID_VALUE),
					};
					if !input.is_empty() {
						exit(c, numbat_codec::DecodeError::INPUT_TOO_LONG);
					}
					result
				};
				(top_decode_body, top_decode_or_exit_body)
			}
		},
		syn::Data::Union(_) => panic!("Union not supported"),
	}
}

pub fn top_decode_impl(ast: &syn::DeriveInput) -> TokenStream {
	let name = &ast.ident;
	let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
	let (top_decode_body, top_decode_or_exit_body) = top_decode_method_bodies(ast);
	let (auto_default, auto_default_or_exit) = auto_default(&ast);

	let gen = quote! {
		impl #impl_generics numbat_codec::TopDecode for #name #ty_generics #where_clause {
			fn top_decode<I: numbat_codec::TopDecodeInput>(top_input: I) -> core::result::Result<Self, numbat_codec::DecodeError> {
				#auto_default
				#top_decode_body
			}

			fn top_decode_or_exit<I: numbat_codec::TopDecodeInput, ExitCtx: Clone>(
				top_input: I,
				c: ExitCtx,
				exit: fn(ExitCtx, numbat_codec::DecodeError) -> !,
			) -> Self {
				#auto_default_or_exit
				#top_decode_or_exit_body
			}
		}
	};

	gen.into()
}

pub fn top_decode_or_default_impl(ast: &syn::DeriveInput) -> TokenStream {
	let name = &ast.ident;
	let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
	let (top_decode_body, top_decode_or_exit_body) = top_decode_method_bodies(ast);

	let gen = quote! {
		impl #impl_generics numbat_codec::TopDecode for #name #ty_generics #where_clause {
			fn top_decode<I: numbat_codec::TopDecodeInput>(top_input: I) -> core::result::Result<Self, numbat_codec::DecodeError> {
				if top_input.byte_len() == 0 {
					Ok(<#name #ty_generics as numbat_codec::DecodeDefault>::default())
				} else {
					#top_decode_body
				}
			}

			fn top_decode_or_exit<I: numbat_codec::TopDecodeInput, ExitCtx: Clone>(
				top_input: I,
				c: ExitCtx,
				exit: fn(ExitCtx, numbat_codec::DecodeError) -> !,
			) -> Self {
				if top_input.byte_len() == 0 {
					<#name #ty_generics as numbat_codec::DecodeDefault>::default()
				} else {
					#top_decode_or_exit_body
				}
			}
		}
	};

	gen.into()
}