round_based_derive/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, quote_spanned};
3use syn::ext::IdentExt;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::spanned::Spanned;
7use syn::{parse_macro_input, Data, DeriveInput, Fields, Generics, Ident, Token, Variant};
8
9#[proc_macro_derive(ProtocolMessage, attributes(protocol_message))]
10pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
11    let input = parse_macro_input!(input as DeriveInput);
12
13    let mut root = None;
14
15    for attr in input.attrs {
16        if !attr.path.is_ident("protocol_message") {
17            continue;
18        }
19        if root.is_some() {
20            return quote_spanned! { attr.path.span() => compile_error!("#[protocol_message] attribute appears more than once"); }.into();
21        }
22        let tokens = attr.tokens.into();
23        root = Some(parse_macro_input!(tokens as RootAttribute));
24    }
25
26    let root_path = root
27        .map(|root| root.path)
28        .unwrap_or_else(|| Punctuated::from_iter([Ident::new("round_based", Span::call_site())]));
29
30    let enum_data = match input.data {
31        Data::Enum(e) => e,
32        Data::Struct(s) => {
33            return quote_spanned! {s.struct_token.span => compile_error!("only enum may implement ProtocolMessage");}.into()
34        }
35        Data::Union(s) => {
36            return quote_spanned! {s.union_token.span => compile_error!("only enum may implement ProtocolMessage");}.into()
37        }
38    };
39
40    let name = input.ident;
41    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
42    let round_method_impl = if !enum_data.variants.is_empty() {
43        round_method(&name, enum_data.variants.iter())
44    } else {
45        // Special case for empty enum. Empty protocol message is useless, but let it be
46        quote! { match *self {} }
47    };
48
49    let impl_protocol_message = quote! {
50        impl #impl_generics #root_path::ProtocolMessage for #name #ty_generics #where_clause {
51            fn round(&self) -> u16 {
52                #round_method_impl
53            }
54        }
55    };
56
57    let impl_round_message = round_messages(
58        &root_path,
59        &name,
60        &input.generics,
61        enum_data.variants.iter(),
62    );
63
64    proc_macro::TokenStream::from(quote! {
65        #impl_protocol_message
66        #impl_round_message
67    })
68}
69
70fn round_method<'v>(enum_name: &Ident, variants: impl Iterator<Item = &'v Variant>) -> TokenStream {
71    let match_variants = (0u16..).zip(variants).map(|(i, variant)| {
72        let variant_name = &variant.ident;
73        match &variant.fields {
74            Fields::Unit => quote_spanned! {
75                variant.ident.span() =>
76                #enum_name::#variant_name => compile_error!("unit variants are not allowed in ProtocolMessage"),
77            },
78            Fields::Named(_) => quote_spanned! {
79                variant.ident.span() =>
80                #enum_name::#variant_name{..} => compile_error!("named variants are not allowed in ProtocolMessage"),
81            },
82            Fields::Unnamed(unnamed) => if unnamed.unnamed.len() == 1 {
83                quote_spanned! {
84                    variant.ident.span() =>
85                    #enum_name::#variant_name(_) => #i,
86                }
87            } else {
88                quote_spanned! {
89                    variant.ident.span() =>
90                    #enum_name::#variant_name(..) => compile_error!("this variant must contain exactly one field to be valid ProtocolMessage"),
91                }
92            },
93        }
94    });
95    quote! {
96        match self {
97            #(#match_variants)*
98        }
99    }
100}
101
102fn round_messages<'v>(
103    root_path: &RootPath,
104    enum_name: &Ident,
105    generics: &Generics,
106    variants: impl Iterator<Item = &'v Variant>,
107) -> TokenStream {
108    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
109    let impls = (0u16..).zip(variants).map(|(i, variant)| {
110        let variant_name = &variant.ident;
111        match &variant.fields {
112            Fields::Unnamed(unnamed) if unnamed.unnamed.len() == 1 => {
113                let msg_type = &unnamed.unnamed[0].ty;
114                quote_spanned! {
115                    variant.ident.span() =>
116                    impl #impl_generics #root_path::RoundMessage<#msg_type> for #enum_name #ty_generics #where_clause {
117                        const ROUND: u16 = #i;
118                        fn to_protocol_message(round_message: #msg_type) -> Self {
119                            #enum_name::#variant_name(round_message)
120                        }
121                        fn from_protocol_message(protocol_message: Self) -> Result<#msg_type, Self> {
122                            #[allow(unreachable_patterns)]
123                            match protocol_message {
124                                #enum_name::#variant_name(msg) => Ok(msg),
125                                _ => Err(protocol_message),
126                            }
127                        }
128                    }
129                }
130            }
131            _ => quote! {},
132        }
133    });
134    quote! {
135        #(#impls)*
136    }
137}
138
139type RootPath = Punctuated<Ident, Token![::]>;
140
141#[allow(dead_code)]
142struct RootAttribute {
143    paren: syn::token::Paren,
144    root: kw::root,
145    eq: Token![=],
146    path: RootPath,
147}
148
149impl Parse for RootAttribute {
150    fn parse(input: ParseStream) -> syn::Result<Self> {
151        let content;
152        let paren = syn::parenthesized!(content in input);
153        let root = content.parse::<kw::root>()?;
154        let eq = content.parse::<Token![=]>()?;
155        let path = RootPath::parse_separated_nonempty_with(&content, Ident::parse_any)?;
156        let _ = content.parse::<syn::parse::Nothing>()?;
157
158        Ok(Self {
159            paren,
160            root,
161            eq,
162            path,
163        })
164    }
165}
166
167mod kw {
168    syn::custom_keyword! { root }
169}