lafere-api-codegen 0.1.0

Codegen for lafere api
Documentation
use crate::util::lafere_api_crate;

use proc_macro2::TokenStream;
use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
	Attribute, Data, DeriveInput, Error, Expr, ExprLit, Fields, Ident, Lit,
	LitInt, Result, TypePath, Variant,
};

/// Expands a derive input into an `impl ...::message::Action` block for a
/// fieldless enum.
///
/// The generated impl contains two methods:
/// - `from_u16(num: u16) -> Option<Self>`: returns the variant whose
///   discriminant literal equals `num`, or `None` for any other value.
/// - `as_u16(&self) -> u16`: returns the discriminant literal of the variant.
///
/// The impl path is rooted at whatever `lafere_api_crate()` resolves to.
///
/// # Errors
///
/// Returns a spanned [`Error`] when:
/// - the input is not an enum,
/// - no `#[repr(u16)]` attribute is present (or a `repr` attribute names
///   something other than `u16`),
/// - the enum has generic parameters,
/// - a variant is rejected by `variants_no_fields`,
/// - `lafere_api_crate()` fails.
pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream> {
	let DeriveInput {
		attrs,
		ident,
		generics,
		data,
		..
	} = input;

	let d = match data {
		Data::Enum(e) => e,
		_ => return Err(Error::new_spanned(ident, "only enums supported")),
	};

	if !repr_as_u16(attrs)? {
		return Err(Error::new_spanned(ident, "#[repr(u16)] required"));
	}

	if !generics.params.is_empty() {
		return Err(Error::new_spanned(generics, "generics not supported"));
	}

	// (fieldnum, ident)
	let variants = variants_no_fields(d.variants)?;

	let fire = lafere_api_crate()?;

	let from_variants = variants
		.iter()
		.map(|(num, id)| quote!(#num => Some(Self::#id)));

	// Every generated arm carries its own trailing comma. With a separator
	// repetition followed by a literal `,`, an empty variant list would emit
	// a stray comma in front of the fallback arm and the match would not
	// parse.
	let from_u16 = quote!(
		fn from_u16(num: u16) -> Option<Self> {
			match num {
				#(#from_variants,)*
				_ => None
			}
		}
	);

	let to_variants =
		variants.iter().map(|(num, id)| quote!(Self::#id => #num));

	let as_u16 = quote!(
		fn as_u16(&self) -> u16 {
			match self {
				#(#to_variants),*
			}
		}
	);

	Ok(quote!(
		impl #fire::message::Action for #ident {
			#from_u16
			#as_u16
		}
	))
}

fn repr_as_u16(attrs: Vec<Attribute>) -> Result<bool> {
	let mut repr_as = None;

	for attr in attrs {
		if !attr.path().is_ident("repr") {
			continue;
		}

		let ty: TypePath = attr.parse_args()?;

		repr_as = Some(ty);
	}

	match repr_as {
		Some(path) => {
			if !path.path.is_ident("u16") {
				return Err(Error::new_spanned(path, "expected u16"));
			}

			Ok(true)
		}
		None => Ok(false),
	}
}

fn variants_no_fields(
	variants: Punctuated<Variant, Comma>,
) -> Result<Vec<(LitInt, Ident)>> {
	variants
		.into_iter()
		.map(|v| {
			let fieldnum_expr = v
				.discriminant
				.ok_or_else(|| {
					Error::new_spanned(
						&v.ident,
						"needs to have a field number `Ident = x`",
					)
				})?
				.1;
			let fieldnum = match fieldnum_expr {
				Expr::Lit(ExprLit {
					lit: Lit::Int(int), ..
				}) => int,
				e => return Err(Error::new_spanned(e, "expected = int")),
			};

			let fieldnum_zero = fieldnum.base10_digits() == "0";

			if fieldnum_zero {
				return Err(Error::new_spanned(
					fieldnum_zero,
					"zero not allowed",
				));
			}

			if !matches!(v.fields, Fields::Unit) {
				return Err(Error::new_spanned(v.fields, "no fields allowed"));
			}

			Ok((fieldnum, v.ident))
		})
		.collect()
}