must_implement_trait/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote, ToTokens};
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input,
7    punctuated::Punctuated,
8    Ident, Item, ItemEnum, ItemStruct, Result, Token,
9};
10
11struct Args {
12    traits: Vec<Ident>,
13}
14
15struct TargetItem {
16    ident: Ident,
17    item_impl: TargetItemImpl,
18}
19
20enum TargetItemImpl {
21    Enum(ItemEnum),
22    Struct(ItemStruct),
23}
24
25impl Parse for Args {
26    fn parse(input: ParseStream) -> Result<Self> {
27        let traits = Punctuated::<Ident, Token![,]>::parse_terminated(input)?;
28
29        if traits.is_empty() {
30            Err(syn::Error::new(
31                input.span(),
32                "expected at least one trait name as an argument; none provided.",
33            ))
34        } else {
35            Ok(Args {
36                traits: traits.into_iter().collect(),
37            })
38        }
39    }
40}
41
42impl Parse for TargetItem {
43    fn parse(input: ParseStream) -> Result<Self> {
44        let item = Item::parse(input)?;
45        match &item {
46            Item::Struct(s) if s.generics.params.is_empty() => Ok(TargetItem {
47                ident: s.ident.clone(),
48                item_impl: TargetItemImpl::Struct(s.clone()),
49            }),
50            Item::Struct(_) => Err(syn::Error::new(
51                input.span(),
52                "must_implement_trait does not currently support types with generic parameters.",
53            )),
54            Item::Enum(e) if e.generics.params.is_empty() => Ok(TargetItem {
55                ident: e.ident.clone(),
56                item_impl: TargetItemImpl::Enum(e.clone()),
57            }),
58            Item::Enum(_) => Err(syn::Error::new(
59                input.span(),
60                "must_implement_trait does not currently support types with generic parameters.",
61            )),
62            _ => Err(input.error("must_implement_trait can only be used on structs and enums.")),
63        }
64    }
65}
66
67impl ToTokens for TargetItemImpl {
68    fn to_tokens(&self, tokens: &mut quote::__private::TokenStream) {
69        match &self {
70            TargetItemImpl::Struct(s) => s.to_tokens(tokens),
71            TargetItemImpl::Enum(e) => e.to_tokens(tokens),
72        }
73    }
74}
75
76#[proc_macro_attribute]
77pub fn must_implement_trait(attr_tokens: TokenStream, item_tokens: TokenStream) -> TokenStream {
78    let item = parse_macro_input!(item_tokens as TargetItem);
79    let args = parse_macro_input!(attr_tokens as Args);
80
81    let ident_str = &item.ident;
82    let item_declaration = item.item_impl;
83    let traits = &args.traits;
84
85    let trait_names = args
86        .traits
87        .iter()
88        .map(|t| t.to_string())
89        .collect::<Vec<String>>();
90    let shim_trait_id = syn::Ident::new(&trait_names.join(""), Span::call_site());
91    let shim_ident = format_ident!(
92        "_MustImplementTraitGadget{}For{}",
93        item.ident,
94        shim_trait_id
95    );
96
97    let updated_syntax = quote! {
98        #item_declaration
99        struct #shim_ident where #ident_str: #(#traits)+*;
100    };
101
102    updated_syntax.into()
103}