enum-tree-derive 0.1.0

Derive macros for the enum-tree crate
Documentation
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, quote};
use syn::parse::{Parse, ParseStream, Parser as _};
use syn::punctuated::Punctuated;
use syn::{DeriveInput, Expr, ExprClosure, ExprPath, Field, Ident, Token, Type};

/// Fixed identifier names used inside the generated code.
mod variable_names
{
    /// Name of the single parameter of every generated `From::from`;
    /// `emit_impl` turns this string into an `Ident`.
    pub const FIELD_VALUE: &str = "field_value_";
}

/// Either a function path or a closure expression, used for both `via = ...`
/// and `default = ...` (callable forms).
#[derive(Clone)]
enum Callable
{
    /// A path expression naming the callable; emitted as `path(..)`.
    Func(ExprPath),
    /// A closure literal; emitted parenthesized as `(closure)(..)`.
    Closure(ExprClosure),
}

impl Callable
{
    /// Appends a call expression for this callable to `tokens`.
    ///
    /// With `arg = Some(a)` the emitted call receives `a` as its only
    /// argument; with `None` the argument list is empty. A closure is wrapped
    /// in parentheses so that the closure literal itself is the callee.
    fn call(&self, tokens: &mut TokenStream, arg: Option<&Ident>)
    {
        let callee = match self {
            Self::Func(path) => quote!(#path),
            Self::Closure(closure) => quote!((#closure)),
        };
        // `Option<&Ident>` interpolates as the ident when present and as
        // nothing otherwise, yielding `(a)` or `()` respectively.
        tokens.extend(quote!(#callee(#arg)));
    }
}

/// How to produce a field's value when that field is not the one being
/// converted from (the `default` option of `#[enum_from(...)]`).
enum Constructor
{
    /// `default` (word) -> `Default::default()`.
    UseDefault,
    /// `default = path` or `default = "path"` -> `path()`.
    /// `default = |...| ...` -> `(closure)()`.
    Callable(Callable),
    /// `default = <arbitrary expr>` (struct lit, block, etc.) -> `<expr>`.
    Value(Expr),
}

impl ToTokens for Constructor
{
    /// Emits the expression that builds the field's value: the stored
    /// expression verbatim, a zero-argument call of the stored callable, or
    /// `::core::default::Default::default()`.
    fn to_tokens(&self, tokens: &mut TokenStream)
    {
        match self {
            Self::Value(expr) => expr.to_tokens(tokens),
            Self::Callable(callable) => callable.call(tokens, None),
            Self::UseDefault => {
                quote!(::core::default::Default::default()).to_tokens(tokens);
            }
        }
    }
}

/// How the `From::from` input is turned into the value stored in the source
/// field.
enum Conversion
{
    /// No `via`; pass the input through directly.
    Identity,
    /// `via = func_or_closure` -> apply to input before storing.
    Via(Callable),
}

impl Conversion
{
    /// Appends the expression that yields the stored value, given the
    /// identifier `input` that names the `From::from` parameter.
    fn to_tokens_with_input(&self, tokens: &mut TokenStream, input: &Ident)
    {
        let Self::Via(callable) = self else {
            // Identity: the input identifier is the whole expression.
            input.to_tokens(tokens);
            return;
        };
        callable.call(tokens, Some(input));
    }
}

/// Options collected from all `#[enum_from ...]` attributes on one field.
#[derive(Default)]
struct FieldAttrs
{
    /// Set by bare `#[enum_from]`, `from`, or `from = Type`; a `From` impl is
    /// generated for every field that has it set.
    is_from:   bool,
    /// Type given as `from = Type`; when present it is used as the source
    /// type of the generated impl instead of the field's own type.
    from_type: Option<Type>,
    /// Callable given as `via = ...`, applied to the input before storing.
    via:       Option<Callable>,
    /// Constructor given as `default` / `default = ...`, used when another
    /// field of the same variant is the conversion source.
    default:   Option<Constructor>,
}

/// Items inside `#[enum_from(...)]`.
enum EnumFromItem
{
    /// Bare `from`.
    From,
    /// `from = Type`.
    FromType(Type),
    /// `via = <path | closure | "path">`.
    Via(Callable),
    /// Bare `default`.
    DefaultWord,
    /// `default = <expr>`.
    DefaultValue(Constructor),
}

fn parse_callable_from_expr(expr: Expr) -> syn::Result<Callable>
{
    match expr {
        Expr::Closure(c) => Ok(Callable::Closure(c)),
        Expr::Path(p) => Ok(Callable::Func(p)),
        Expr::Lit(syn::ExprLit {
            lit: syn::Lit::Str(s),
            ..
        }) => {
            let path: ExprPath = s.parse()?;
            Ok(Callable::Func(path))
        }
        other => Err(syn::Error::new_spanned(
            other,
            "expected a function path, closure, or string containing a function path",
        )),
    }
}

impl Parse for EnumFromItem
{
    fn parse(input: ParseStream) -> syn::Result<Self>
    {
        // Each item is `ident` or `ident = <expr>`.
        let ident: Ident = input.parse()?;
        let name = ident.to_string();
        if input.peek(Token![=]) {
            let _: Token![=] = input.parse()?;
            match name.as_str() {
                "via" => {
                    let expr: Expr = input.parse()?;
                    Ok(Self::Via(parse_callable_from_expr(expr)?))
                }
                "from" => {
                    let ty: Type = input.parse()?;
                    Ok(Self::FromType(ty))
                }
                "default" => {
                    let expr: Expr = input.parse()?;
                    let cons = match expr {
                        Expr::Closure(c) => Constructor::Callable(Callable::Closure(c)),
                        Expr::Path(p) => Constructor::Callable(Callable::Func(p)),
                        Expr::Lit(syn::ExprLit {
                            lit: syn::Lit::Str(s),
                            ..
                        }) => {
                            let path: ExprPath = s.parse()?;
                            Constructor::Callable(Callable::Func(path))
                        }
                        other => Constructor::Value(other),
                    };
                    Ok(Self::DefaultValue(cons))
                }
                other => Err(syn::Error::new(
                    ident.span(),
                    format!("unknown enum_from option `{other}` with value"),
                )),
            }
        } else {
            match name.as_str() {
                "from" => Ok(Self::From),
                "default" => Ok(Self::DefaultWord),
                other => Err(syn::Error::new(
                    ident.span(),
                    format!("unknown enum_from option `{other}`"),
                )),
            }
        }
    }
}

/// Collects every `#[enum_from]` / `#[enum_from(...)]` attribute on `field`
/// into a single [`FieldAttrs`].
///
/// Attributes are processed in order and later options overwrite earlier
/// ones. The `#[enum_from = ...]` form is rejected.
fn parse_field_attrs(field: &Field) -> syn::Result<FieldAttrs>
{
    let mut parsed = FieldAttrs::default();
    let relevant = field
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident("enum_from"));

    for attr in relevant {
        let list = match &attr.meta {
            // Bare `#[enum_from]`
            syn::Meta::Path(_) => {
                parsed.is_from = true;
                continue;
            }
            syn::Meta::NameValue(name_value) => {
                return Err(syn::Error::new_spanned(
                    name_value,
                    "expected `#[enum_from(...)]` or bare `#[enum_from]`",
                ));
            }
            syn::Meta::List(list) => list,
        };

        let item_parser = Punctuated::<EnumFromItem, Token![,]>::parse_terminated;
        for item in item_parser.parse2(list.tokens.clone())? {
            match item {
                EnumFromItem::From => parsed.is_from = true,
                EnumFromItem::FromType(ty) => {
                    // `from = Type` also marks the field as a source.
                    parsed.is_from = true;
                    parsed.from_type = Some(ty);
                }
                EnumFromItem::Via(callable) => parsed.via = Some(callable),
                EnumFromItem::DefaultWord => parsed.default = Some(Constructor::UseDefault),
                EnumFromItem::DefaultValue(constructor) => parsed.default = Some(constructor),
            }
        }
    }
    Ok(parsed)
}

/// Everything needed to generate `From` impls for one non-unit variant.
struct VariantProperties<'a>
{
    /// The variant's name.
    ident:       &'a Ident,
    /// `true` for a `{ .. }` variant, `false` for a tuple variant.
    is_named:    bool,
    /// Each field paired with its parsed `enum_from` options, in declaration
    /// order.
    field_attrs: Vec<(&'a Field, FieldAttrs)>,
}

/// Gathers the fields of `variant` together with their parsed attributes.
///
/// Returns `Ok(None)` for unit variants, which have nothing to convert from,
/// and the first attribute-parsing error encountered otherwise.
fn collect_variant(variant: &syn::Variant) -> syn::Result<Option<VariantProperties<'_>>>
{
    let (fields, is_named) = match &variant.fields {
        syn::Fields::Unit => return Ok(None),
        syn::Fields::Named(named) => (&named.named, true),
        syn::Fields::Unnamed(unnamed) => (&unnamed.unnamed, false),
    };

    // Collecting into `Result` stops at the first field that fails to parse.
    let field_attrs = fields
        .iter()
        .map(|field| parse_field_attrs(field).map(|attrs| (field, attrs)))
        .collect::<syn::Result<Vec<_>>>()?;

    Ok(Some(VariantProperties {
        ident: &variant.ident,
        is_named,
        field_attrs,
    }))
}

/// Emits the initializer (`name: value,` or `value,`) for a field that is
/// *not* the source of the `From` impl being generated.
///
/// The value is, in order of preference: the explicit `default` constructor;
/// `<Ty as Default>::default()` for fields marked `from`; otherwise the
/// field's type path itself, which only forms a value for unit structs.
fn emit_default(field: &Field, attrs: &FieldAttrs, tokens: &mut TokenStream)
{
    // Named fields need a `name:` prefix; tuple fields have no ident.
    let label = field.ident.as_ref().map(|name| quote!(#name:));
    let ty = &field.ty;
    let value = match &attrs.default {
        Some(constructor) => quote!(#constructor),
        None if attrs.is_from => quote!(<#ty as ::core::default::Default>::default()),
        None => quote!(#ty),
    };
    tokens.extend(quote!(#label #value,));
}

fn emit_source(field: &Field, conversion: &Conversion, input: &Ident, tokens: &mut TokenStream)
{
    if let Some(ident) = &field.ident {
        tokens.extend(quote! { #ident: });
    }
    conversion.to_tokens_with_input(tokens, input);
    tokens.extend(quote!(,));
}

/// Appends one `impl From<Src> for Enum` to `out`, converting into `variant`
/// through the field at `source_idx`.
///
/// * `enum_name` / `generics` - name and generic parameters of the enum; the
///   generics are repeated on the impl so that enums with lifetime, type or
///   const parameters (and where-clauses) get a well-formed impl header.
/// * `variant` - the target variant and its parsed field options.
/// * `source_idx` - index into `variant.field_attrs` of the field that
///   receives the input; every other field is filled via `emit_default`.
///
/// The source type is the field's `from = Type` override when present,
/// otherwise the field's declared type.
///
/// # Panics
///
/// Panics if `source_idx` is out of bounds for `variant.field_attrs`.
fn emit_impl(
    enum_name: &Ident,
    generics: &syn::Generics,
    variant: &VariantProperties<'_>,
    source_idx: usize,
    out: &mut TokenStream,
)
{
    let input_ident = Ident::new(variable_names::FIELD_VALUE, Span::mixed_site());
    let (src_field, src_attrs) = &variant.field_attrs[source_idx];
    let src_ty: &Type = src_attrs.from_type.as_ref().unwrap_or(&src_field.ty);
    let variant_ident = variant.ident;
    // All three pieces expand to nothing for a non-generic enum, so the
    // output for such enums is the plain `impl From<Src> for Enum`.
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let conversion = match &src_attrs.via {
        Some(callable) => Conversion::Via(callable.clone()),
        None => Conversion::Identity,
    };

    let mut inner = TokenStream::new();
    for (idx, (field, attrs)) in variant.field_attrs.iter().enumerate() {
        if idx == source_idx {
            emit_source(field, &conversion, &input_ident, &mut inner);
        } else {
            emit_default(field, attrs, &mut inner);
        }
    }
    // Named variants are built with braces, tuple variants with parentheses.
    let body = if variant.is_named {
        quote!({ #inner })
    } else {
        quote!(( #inner ))
    };

    out.extend(quote! {
        impl #impl_generics ::core::convert::From<#src_ty> for #enum_name #ty_generics #where_clause {
            fn from(#input_ident: #src_ty) -> Self {
                #enum_name::#variant_ident #body
            }
        }
    });
}

/// Generates every `From` impl requested on the enum described by `input`.
///
/// One impl is produced per field marked as a `from` source; unit variants
/// and variants without such a field contribute nothing.
///
/// # Errors
///
/// Returns an error if `input` is not an enum or if any `enum_from`
/// attribute fails to parse.
fn try_derive(input: &DeriveInput) -> syn::Result<TokenStream>
{
    let syn::Data::Enum(data) = &input.data else {
        return Err(syn::Error::new_spanned(
            input,
            "EnumFrom can only be derived for enums",
        ));
    };
    let type_name = &input.ident;
    let mut out = TokenStream::new();
    for variant in &data.variants {
        let Some(props) = collect_variant(variant)? else {
            continue;
        };
        for (idx, (_, attrs)) in props.field_attrs.iter().enumerate() {
            if attrs.is_from {
                emit_impl(type_name, &input.generics, &props, idx, &mut out);
            }
        }
    }
    Ok(out)
}

/// Entry point of the derive: returns the generated `From` impls, or a
/// `compile_error!` expansion describing why generation failed.
pub fn generate(input: &DeriveInput) -> TokenStream
{
    try_derive(input).unwrap_or_else(|err| err.to_compile_error())
}