//! tagname_derive 0.3.1
//!
//! `derive(TagName)` implementation.
use proc_macro2::TokenStream;
use syn::punctuated::Pair;
use syn::spanned::Spanned;
use syn::Attribute;
use syn::{Expr, ExprLit, Fields, Lit, Variant};

use super::{Case, Tag, TagData, TaggedUnion};

/// Failure raised while interpreting the `derive(TagName)` input.
///
/// The wrapped tokens are what should be reported to the user. Every `Error`
/// built in this module holds a `compile_error!(...)` invocation spanned at the
/// offending piece of input. NOTE(review): presumably the macro entry point
/// emits these tokens as its expansion — confirm at the call site.
#[derive(Debug)]
pub struct Error(pub TokenStream);

pub(crate) fn traverse_ast(ast: syn::DeriveInput) -> Result<TaggedUnion, Error> {
    let name = ast.ident.clone();
    let variants = traverse_enum(ast)?;

    if variants.is_empty() {
        return Err(Error(
            quote::quote_spanned! {name.span()=> compile_error!("cannot derive `TagName` for empty enum types"); },
        ));
    }

    let tags = traverse_variants(variants)?;

    Ok(TaggedUnion { name, tags })
}

/// Extracts the variants of an enum derive input, in declaration order.
///
/// # Errors
///
/// Returns an [`Error`] carrying a `compile_error!` spanned at the `struct` or
/// `union` keyword when the input is not an enum.
fn traverse_enum(ast: syn::DeriveInput) -> Result<Vec<Variant>, Error> {
    match ast.data {
        syn::Data::Enum(enum_data) => Ok(enum_data
            .variants
            // `into_pairs` already yields an iterator; drop the separating
            // commas and keep only the variants.
            .into_pairs()
            .map(Pair::into_value)
            .collect()),
        syn::Data::Struct(s) => Err(Error(
            quote::quote_spanned! {s.struct_token.span()=> compile_error!("cannot derive `TagName` for struct types"); },
        )),
        syn::Data::Union(u) => Err(Error(
            quote::quote_spanned! {u.union_token.span()=> compile_error!("cannot derive `TagName` for union types"); },
        )),
    }
}

fn traverse_variants(variants: Vec<Variant>) -> Result<Vec<Tag>, Error> {
    let mut tags = Vec::new();
    for v in variants {
        if !v.attrs.is_empty() {}
        tags.push(match v.fields {
            Fields::Unit => Tag::Unit(TagData {
                ident: v.ident.clone(),
                case: traverse_attribute(&v)?,
            }),
            Fields::Unnamed(_) => Tag::Unnamed(TagData {
                ident: v.ident.clone(),
                case: traverse_attribute(&v)?,
            }),
            Fields::Named(_) => Tag::Named(TagData {
                ident: v.ident.clone(),
                case: traverse_attribute(&v)?,
            }),
        });
    }
    Ok(tags)
}

/// Determines the [`Case`] requested for a variant through a `tag` attribute.
///
/// Only the first attribute whose path is exactly `tag` is considered. Any
/// further `tag` attributes on the same variant are not examined. A variant
/// without such an attribute yields [`Case::Unchanged`].
///
/// # Errors
///
/// Propagates the [`Error`] produced while parsing the attribute's arguments.
fn traverse_attribute(variant: &Variant) -> Result<Case, Error> {
    // `find` also yields `None` for an empty attribute list, so no separate
    // emptiness check is needed.
    match variant.attrs.iter().find(|attr| attr.path.is_ident("tag")) {
        Some(tag_attribute) => traverse_tag_attribute_argument(tag_attribute),
        None => Ok(Case::Unchanged),
    }
}

/// Parses the arguments of a `tag` attribute into a [`Case`].
///
/// The accepted forms are `#[tag(case = "lower")]` and `#[tag(case = "upper")]`.
///
/// # Errors
///
/// Returns an [`Error`] holding a spanned `compile_error!` when:
/// - the attribute arguments do not parse as an expression;
/// - the expression is not an assignment;
/// - the left-hand side is not exactly the identifier `case`;
/// - the right-hand side is not a string literal;
/// - the string is neither `"lower"` nor `"upper"`.
fn traverse_tag_attribute_argument(attribute: &Attribute) -> Result<Case, Error> {
    // The underlying parse error is discarded in favour of a message that shows
    // the expected syntax.
    let expr: Expr = attribute.parse_args().map_err(|_| {
        Error(quote::quote_spanned! {attribute.span()=> compile_error!(r#"`tag` attribute expects an assignment expression `[tag(case = "lower" | "upper")]`"#); })
    })?;

    let assign_expr = match expr {
        Expr::Assign(assign_expr) => assign_expr,
        _ => {
            return Err(Error(
                quote::quote_spanned! {expr.span()=> compile_error!(r#"`tag` attribute expects an assignment expression `[tag(case = "lower" | "upper")]`"#); },
            ));
        }
    };

    // `is_ident` requires a single segment, no leading `::` and no generic
    // arguments. Checking only the first segment (as before) wrongly accepted
    // paths such as `case::anything` or `::case`.
    match *assign_expr.left {
        Expr::Path(ref p) if p.path.is_ident("case") => {}
        _ => {
            return Err(Error(
                quote::quote_spanned! {assign_expr.span()=> compile_error!("left handside of the assignment in `tag` attribute must be `case`");},
            ))
        }
    }

    match *assign_expr.right {
        Expr::Lit(ExprLit {
            lit: Lit::Str(s), ..
        }) => match s.value().as_str() {
            "lower" => Ok(Case::Lower),
            "upper" => Ok(Case::Upper),
            _ => Err(Error(
                quote::quote_spanned! {s.span()=> compile_error!("`case` value must be either \"upper\" or \"lower\"");},
            )),
        },
        _ => Err(Error(
            quote::quote_spanned! {assign_expr.right.span()=> compile_error!(r#"right handside of the assignment expression is expected to be a string literal "lower" or "upper""#);},
        )),
    }
}