1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
use proc_macro::TokenStream; use proc_macro2::Span; use quote::{format_ident, quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, parse_macro_input, punctuated::Punctuated, Ident, Item, ItemEnum, ItemStruct, Result, Token, }; struct Args { traits: Vec<Ident>, } struct TargetItem { ident: Ident, item_impl: TargetItemImpl, } enum TargetItemImpl { Enum(ItemEnum), Struct(ItemStruct), } impl Parse for Args { fn parse(input: ParseStream) -> Result<Self> { let traits = Punctuated::<Ident, Token![,]>::parse_terminated(input)?; if traits.is_empty() { Err(syn::Error::new( input.span(), "expected at least one trait name as an argument; none provided.", )) } else { Ok(Args { traits: traits.into_iter().collect(), }) } } } impl Parse for TargetItem { fn parse(input: ParseStream) -> Result<Self> { let item = Item::parse(input)?; match &item { Item::Struct(s) if s.generics.params.is_empty() => Ok(TargetItem { ident: s.ident.clone(), item_impl: TargetItemImpl::Struct(s.clone()), }), Item::Struct(_) => Err(syn::Error::new( input.span(), "must_implement_trait does not currently support types with generic parameters.", )), Item::Enum(e) if e.generics.params.is_empty() => Ok(TargetItem { ident: e.ident.clone(), item_impl: TargetItemImpl::Enum(e.clone()), }), Item::Enum(_) => Err(syn::Error::new( input.span(), "must_implement_trait does not currently support types with generic parameters.", )), _ => Err(input.error("must_implement_trait can only be used on structs and enums.")), } } } impl ToTokens for TargetItemImpl { fn to_tokens(&self, tokens: &mut quote::__private::TokenStream) { match &self { TargetItemImpl::Struct(s) => s.to_tokens(tokens), TargetItemImpl::Enum(e) => e.to_tokens(tokens), } } } #[proc_macro_attribute] pub fn must_implement_trait(attr_tokens: TokenStream, item_tokens: TokenStream) -> TokenStream { let item = parse_macro_input!(item_tokens as TargetItem); let args = parse_macro_input!(attr_tokens as Args); let ident_str = &item.ident; let item_declaration = item.item_impl; let traits = &args.traits; let trait_names = args .traits .iter() .map(|t| t.to_string()) .collect::<Vec<String>>(); let shim_trait_id = syn::Ident::new(&trait_names.join(""), Span::call_site()); let shim_ident = format_ident!( "_MustImplementTraitGadget{}For{}", item.ident, shim_trait_id ); let updated_syntax = quote! { #item_declaration struct #shim_ident where #ident_str: #(#traits)+*; }; updated_syntax.into() }