use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, quote};
use syn::parse::{Parse, ParseStream, Parser as _};
use syn::punctuated::Punctuated;
use syn::{DeriveInput, Expr, ExprClosure, ExprPath, Field, Ident, Token, Type};
/// Names of identifiers that the generated code introduces.
mod variable_names
{
/// Name of the parameter of each generated `from` function; the source
/// field's value is read from this identifier.
pub const FIELD_VALUE: &str = "field_value_";
}
/// Something that can be invoked in generated code: either a path naming a
/// function or an inline closure expression.
#[derive(Clone)]
enum Callable
{
/// A path expression; the call is emitted as the path followed by `(...)`.
Func(ExprPath),
/// A closure expression; it is wrapped in parentheses before being called.
Closure(ExprClosure),
}
impl Callable
{
    /// Appends an invocation of this callable to `tokens`.
    ///
    /// With `Some(arg)` the identifier is passed as the single argument; with
    /// `None` the call is emitted with an empty argument list. Closures are
    /// parenthesized so that the call applies to the whole closure expression.
    fn call(&self, tokens: &mut TokenStream, arg: Option<&Ident>)
    {
        // An `Option` interpolates to nothing when it is `None`, so a single
        // template per variant covers both the one-argument and the
        // zero-argument call shapes.
        let invocation = match self {
            Self::Func(path) => quote!(#path(#arg)),
            Self::Closure(closure) => quote!((#closure)(#arg)),
        };
        tokens.extend(invocation);
    }
}
/// How to produce a value for a field that is not the conversion source.
enum Constructor
{
/// Emit `::core::default::Default::default()`.
UseDefault,
/// Invoke the callable with no arguments.
Callable(Callable),
/// Emit the expression verbatim.
Value(Expr),
}
impl ToTokens for Constructor
{
    /// Writes the expression that constructs the field value: a zero-argument
    /// call, a verbatim expression, or a `Default::default()` call.
    fn to_tokens(&self, tokens: &mut TokenStream)
    {
        match self {
            Self::Callable(callable) => callable.call(tokens, None),
            Self::Value(expr) => expr.to_tokens(tokens),
            Self::UseDefault => {
                tokens.extend(quote!(::core::default::Default::default()));
            }
        }
    }
}
/// How the `from` function's input is turned into the source field's value.
enum Conversion
{
/// Use the input identifier as-is.
Identity,
/// Pass the input identifier to the callable and use its result.
Via(Callable),
}
impl Conversion
{
fn to_tokens_with_input(&self, tokens: &mut TokenStream, input: &Ident)
{
match self {
Self::Identity => input.to_tokens(tokens),
Self::Via(c) => c.call(tokens, Some(input)),
}
}
}
/// Options collected from all `#[enum_from...]` attributes on one field.
#[derive(Default)]
struct FieldAttrs
{
/// Set by a bare `#[enum_from]`, by `from`, or by `from = <type>`; a `From`
/// impl is generated for every field that has this flag.
is_from: bool,
/// Type given with `from = <type>`; when present it replaces the field's own
/// type as the `From` source type.
from_type: Option<Type>,
/// Callable given with `via = ...`, applied to the incoming value.
via: Option<Callable>,
/// Constructor given with `default` or `default = ...`, used when this field
/// is not the conversion source.
default: Option<Constructor>,
}
/// One comma-separated entry inside `#[enum_from(...)]`.
enum EnumFromItem
{
/// Bare `from`.
From,
/// `from = <type>`.
FromType(Type),
/// `via = <function path | closure | "path string">`.
Via(Callable),
/// Bare `default`.
DefaultWord,
/// `default = <expression>`.
DefaultValue(Constructor),
}
fn parse_callable_from_expr(expr: Expr) -> syn::Result<Callable>
{
match expr {
Expr::Closure(c) => Ok(Callable::Closure(c)),
Expr::Path(p) => Ok(Callable::Func(p)),
Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(s),
..
}) => {
let path: ExprPath = s.parse()?;
Ok(Callable::Func(path))
}
other => Err(syn::Error::new_spanned(
other,
"expected a function path, closure, or string containing a function path",
)),
}
}
impl Parse for EnumFromItem
{
    /// Parses a single option: a keyword, optionally followed by `= <value>`.
    ///
    /// Bare keywords: `from`, `default`.
    /// Keyed values: `via = <callable>`, `from = <type>`, `default = <expr>`.
    ///
    /// # Errors
    ///
    /// Fails on an unknown keyword (spanned on the keyword) or when the value
    /// after `=` does not parse.
    fn parse(input: ParseStream) -> syn::Result<Self>
    {
        let key: Ident = input.parse()?;
        let key_name = key.to_string();

        // Bare keyword form: nothing follows the identifier.
        if !input.peek(Token![=]) {
            return match key_name.as_str() {
                "from" => Ok(Self::From),
                "default" => Ok(Self::DefaultWord),
                other => Err(syn::Error::new(
                    key.span(),
                    format!("unknown enum_from option `{other}`"),
                )),
            };
        }

        input.parse::<Token![=]>()?;
        match key_name.as_str() {
            "via" => {
                let callable = parse_callable_from_expr(input.parse::<Expr>()?)?;
                Ok(Self::Via(callable))
            }
            "from" => input.parse::<Type>().map(Self::FromType),
            "default" => {
                // Paths, closures and string literals are treated as callables
                // to invoke; any other expression is used verbatim as the value.
                let constructor = match input.parse::<Expr>()? {
                    callable @ (Expr::Closure(_)
                    | Expr::Path(_)
                    | Expr::Lit(syn::ExprLit {
                        lit: syn::Lit::Str(_),
                        ..
                    })) => Constructor::Callable(parse_callable_from_expr(callable)?),
                    value => Constructor::Value(value),
                };
                Ok(Self::DefaultValue(constructor))
            }
            other => Err(syn::Error::new(
                key.span(),
                format!("unknown enum_from option `{other}` with value"),
            )),
        }
    }
}
/// Collects the options from every `enum_from` attribute on `field`.
///
/// A bare `#[enum_from]` marks the field as a conversion source. The list form
/// `#[enum_from(...)]` is parsed as comma-separated [`EnumFromItem`]s; when an
/// option is given more than once, the last occurrence wins. Attributes with
/// any other path are ignored.
///
/// # Errors
///
/// Fails on the name-value form `#[enum_from = ...]` and on any item that does
/// not parse.
fn parse_field_attrs(field: &Field) -> syn::Result<FieldAttrs>
{
    let mut parsed = FieldAttrs::default();
    let relevant = field
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident("enum_from"));

    for attr in relevant {
        let list = match &attr.meta {
            syn::Meta::List(list) => list,
            syn::Meta::Path(_) => {
                parsed.is_from = true;
                continue;
            }
            syn::Meta::NameValue(name_value) => {
                return Err(syn::Error::new_spanned(
                    name_value,
                    "expected `#[enum_from(...)]` or bare `#[enum_from]`",
                ));
            }
        };

        let parse_items = Punctuated::<EnumFromItem, Token![,]>::parse_terminated;
        for item in parse_items.parse2(list.tokens.clone())? {
            match item {
                EnumFromItem::From => parsed.is_from = true,
                EnumFromItem::FromType(ty) => {
                    parsed.is_from = true;
                    parsed.from_type = Some(ty);
                }
                EnumFromItem::Via(callable) => parsed.via = Some(callable),
                EnumFromItem::DefaultWord => parsed.default = Some(Constructor::UseDefault),
                EnumFromItem::DefaultValue(constructor) => parsed.default = Some(constructor),
            }
        }
    }

    Ok(parsed)
}
/// A non-unit enum variant together with the parsed options of its fields.
struct VariantProperties<'a>
{
/// The variant's name.
ident: &'a Ident,
/// `true` for a variant with named fields, `false` for a tuple variant.
is_named: bool,
/// Every field of the variant, in declaration order, paired with its options.
field_attrs: Vec<(&'a Field, FieldAttrs)>,
}
/// Gathers the fields of `variant` together with their parsed options.
///
/// Returns `Ok(None)` for unit variants, which have no fields.
///
/// # Errors
///
/// Propagates the first attribute parse error found on any field.
fn collect_variant(variant: &syn::Variant) -> syn::Result<Option<VariantProperties<'_>>>
{
    let (fields, is_named) = match &variant.fields {
        syn::Fields::Unit => return Ok(None),
        syn::Fields::Named(named) => (&named.named, true),
        syn::Fields::Unnamed(unnamed) => (&unnamed.unnamed, false),
    };

    let field_attrs = fields
        .iter()
        .map(|field| parse_field_attrs(field).map(|attrs| (field, attrs)))
        .collect::<syn::Result<Vec<_>>>()?;

    Ok(Some(VariantProperties {
        ident: &variant.ident,
        is_named,
        field_attrs,
    }))
}
/// Appends the initializer for a field that is *not* the conversion source.
///
/// Emits `name:` first for named fields, then the value expression, then a
/// trailing comma. The value is the field's explicit `default` constructor
/// when one was given; otherwise it is `<FieldType as Default>::default()`.
///
/// Fields without an explicit `default` always fall back to the `Default`
/// trait, whether or not they are themselves marked as a `from` source.
/// Emitting the bare field type here would put a type in expression position,
/// which only compiles when the type happens to be a unit struct.
fn emit_default(field: &Field, attrs: &FieldAttrs, tokens: &mut TokenStream)
{
    if let Some(ident) = &field.ident {
        tokens.extend(quote! { #ident: });
    }
    match &attrs.default {
        Some(cons) => cons.to_tokens(tokens),
        None => {
            let ty = &field.ty;
            tokens.extend(quote! { <#ty as ::core::default::Default>::default() });
        }
    }
    tokens.extend(quote!(,));
}
/// Appends the initializer for the conversion-source field: `name:` for named
/// fields, then the converted input expression, then a trailing comma.
fn emit_source(field: &Field, conversion: &Conversion, input: &Ident, tokens: &mut TokenStream)
{
    if let Some(name) = field.ident.as_ref() {
        name.to_tokens(tokens);
        tokens.extend(quote!(:));
    }
    conversion.to_tokens_with_input(tokens, input);
    tokens.extend(quote!(,));
}
/// Appends one `impl From<Source> for Enum` block to `out`.
///
/// The impl builds `variant`, filling the field at `source_idx` from the
/// function's input (optionally passed through that field's `via` callable)
/// and every other field via [`emit_default`]. The source type is the field's
/// `from = <type>` override when present, otherwise the field's own type.
///
/// `generics` are the enum's generics; they are split into the impl header,
/// the type arguments and the where-clause so that generic enums produce a
/// well-formed impl.
///
/// # Panics
///
/// Panics if `source_idx` is not a valid index into `variant.field_attrs`.
fn emit_impl(
    enum_name: &Ident,
    generics: &syn::Generics,
    variant: &VariantProperties<'_>,
    source_idx: usize,
    out: &mut TokenStream,
)
{
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let input_ident = Ident::new(variable_names::FIELD_VALUE, Span::mixed_site());
    let (src_field, src_attrs) = &variant.field_attrs[source_idx];
    let src_ty: &Type = src_attrs.from_type.as_ref().unwrap_or(&src_field.ty);
    let variant_ident = variant.ident;
    let conversion = match &src_attrs.via {
        Some(c) => Conversion::Via(c.clone()),
        None => Conversion::Identity,
    };

    // Initializers for all fields, in declaration order.
    let mut inner = TokenStream::new();
    for (i, (field, attrs)) in variant.field_attrs.iter().enumerate() {
        if i == source_idx {
            emit_source(field, &conversion, &input_ident, &mut inner);
        } else {
            emit_default(field, attrs, &mut inner);
        }
    }

    // Named variants use `{ .. }`, tuple variants use `( .. )`.
    let body = if variant.is_named {
        quote!({ #inner })
    } else {
        quote!(( #inner ))
    };

    out.extend(quote! {
        impl #impl_generics ::core::convert::From<#src_ty> for #enum_name #ty_generics #where_clause {
            fn from(#input_ident: #src_ty) -> Self {
                #enum_name::#variant_ident #body
            }
        }
    });
}
/// Generates every `From` impl requested by the fields of the enum in `input`.
///
/// One impl is emitted per field marked as a `from` source; unit variants are
/// skipped. The enum's generics and where-clause are carried into each impl.
///
/// # Errors
///
/// Fails when `input` is not an enum, or when any field attribute does not
/// parse.
fn try_derive(input: &DeriveInput) -> syn::Result<TokenStream>
{
    let syn::Data::Enum(data) = &input.data else {
        return Err(syn::Error::new_spanned(
            input,
            "EnumFrom can only be derived for enums",
        ));
    };
    let type_name = &input.ident;
    let mut out = TokenStream::new();
    for variant in &data.variants {
        let Some(props) = collect_variant(variant)? else {
            continue;
        };
        for (idx, (_, attrs)) in props.field_attrs.iter().enumerate() {
            if attrs.is_from {
                emit_impl(type_name, &input.generics, &props, idx, &mut out);
            }
        }
    }
    Ok(out)
}
/// Entry point of the derive: returns the generated `From` impls, or the
/// tokens of a compile error when derivation fails.
pub fn generate(input: &DeriveInput) -> TokenStream
{
    try_derive(input).unwrap_or_else(|err| err.to_compile_error())
}