1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote, ToTokens};
use syn::{
    parse::{Parse, ParseStream},
    parse_macro_input,
    punctuated::Punctuated,
    Ident, Item, ItemEnum, ItemStruct, Result, Token,
};

/// Arguments of the `must_implement_trait` attribute: the names of the traits to check.
struct Args {
    /// Trait identifiers in the order they were written in the attribute.
    /// The `Parse` impl for `Args` rejects an empty list, so this is never empty after parsing.
    traits: Vec<Ident>,
}

/// The item the attribute was applied to, together with its name.
struct TargetItem {
    /// Name of the annotated struct or enum.
    ident: Ident,
    /// The full parsed item, kept so the macro can emit it again in its output.
    item_impl: TargetItemImpl,
}

/// The kinds of item that `TargetItem` can wrap.
enum TargetItemImpl {
    /// A parsed `enum` definition.
    Enum(ItemEnum),
    /// A parsed `struct` definition.
    Struct(ItemStruct),
}

impl Parse for Args {
    fn parse(input: ParseStream) -> Result<Self> {
        let traits = Punctuated::<Ident, Token![,]>::parse_terminated(input)?;

        if traits.is_empty() {
            Err(syn::Error::new(
                input.span(),
                "expected at least one trait name as an argument; none provided.",
            ))
        } else {
            Ok(Args {
                traits: traits.into_iter().collect(),
            })
        }
    }
}

impl Parse for TargetItem {
    fn parse(input: ParseStream) -> Result<Self> {
        let item = Item::parse(input)?;
        match &item {
            Item::Struct(s) if s.generics.params.is_empty() => Ok(TargetItem {
                ident: s.ident.clone(),
                item_impl: TargetItemImpl::Struct(s.clone()),
            }),
            Item::Struct(_) => Err(syn::Error::new(
                input.span(),
                "must_implement_trait does not currently support types with generic parameters.",
            )),
            Item::Enum(e) if e.generics.params.is_empty() => Ok(TargetItem {
                ident: e.ident.clone(),
                item_impl: TargetItemImpl::Enum(e.clone()),
            }),
            Item::Enum(_) => Err(syn::Error::new(
                input.span(),
                "must_implement_trait does not currently support types with generic parameters.",
            )),
            _ => Err(input.error("must_implement_trait can only be used on structs and enums.")),
        }
    }
}

impl ToTokens for TargetItemImpl {
    fn to_tokens(&self, tokens: &mut quote::__private::TokenStream) {
        match &self {
            TargetItemImpl::Struct(s) => s.to_tokens(tokens),
            TargetItemImpl::Enum(e) => e.to_tokens(tokens),
        }
    }
}

#[proc_macro_attribute]
pub fn must_implement_trait(attr_tokens: TokenStream, item_tokens: TokenStream) -> TokenStream {
    let item = parse_macro_input!(item_tokens as TargetItem);
    let args = parse_macro_input!(attr_tokens as Args);

    let ident_str = &item.ident;
    let item_declaration = item.item_impl;
    let traits = &args.traits;

    let trait_names = args
        .traits
        .iter()
        .map(|t| t.to_string())
        .collect::<Vec<String>>();
    let shim_trait_id = syn::Ident::new(&trait_names.join(""), Span::call_site());
    let shim_ident = format_ident!(
        "_MustImplementTraitGadget{}For{}",
        item.ident,
        shim_trait_id
    );

    let updated_syntax = quote! {
        #item_declaration
        struct #shim_ident where #ident_str: #(#traits)+*;
    };

    updated_syntax.into()
}