round_based_derive/
lib.rs1use proc_macro2::{Span, TokenStream};
2use quote::{quote, quote_spanned};
3use syn::ext::IdentExt;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::spanned::Spanned;
7use syn::{parse_macro_input, Data, DeriveInput, Fields, Generics, Ident, Token, Variant};
8
9#[proc_macro_derive(ProtocolMessage, attributes(protocol_message))]
10pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
11 let input = parse_macro_input!(input as DeriveInput);
12
13 let mut root = None;
14
15 for attr in input.attrs {
16 if !attr.path.is_ident("protocol_message") {
17 continue;
18 }
19 if root.is_some() {
20 return quote_spanned! { attr.path.span() => compile_error!("#[protocol_message] attribute appears more than once"); }.into();
21 }
22 let tokens = attr.tokens.into();
23 root = Some(parse_macro_input!(tokens as RootAttribute));
24 }
25
26 let root_path = root
27 .map(|root| root.path)
28 .unwrap_or_else(|| Punctuated::from_iter([Ident::new("round_based", Span::call_site())]));
29
30 let enum_data = match input.data {
31 Data::Enum(e) => e,
32 Data::Struct(s) => {
33 return quote_spanned! {s.struct_token.span => compile_error!("only enum may implement ProtocolMessage");}.into()
34 }
35 Data::Union(s) => {
36 return quote_spanned! {s.union_token.span => compile_error!("only enum may implement ProtocolMessage");}.into()
37 }
38 };
39
40 let name = input.ident;
41 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
42 let round_method_impl = if !enum_data.variants.is_empty() {
43 round_method(&name, enum_data.variants.iter())
44 } else {
45 quote! { match *self {} }
47 };
48
49 let impl_protocol_message = quote! {
50 impl #impl_generics #root_path::ProtocolMessage for #name #ty_generics #where_clause {
51 fn round(&self) -> u16 {
52 #round_method_impl
53 }
54 }
55 };
56
57 let impl_round_message = round_messages(
58 &root_path,
59 &name,
60 &input.generics,
61 enum_data.variants.iter(),
62 );
63
64 proc_macro::TokenStream::from(quote! {
65 #impl_protocol_message
66 #impl_round_message
67 })
68}
69
70fn round_method<'v>(enum_name: &Ident, variants: impl Iterator<Item = &'v Variant>) -> TokenStream {
71 let match_variants = (0u16..).zip(variants).map(|(i, variant)| {
72 let variant_name = &variant.ident;
73 match &variant.fields {
74 Fields::Unit => quote_spanned! {
75 variant.ident.span() =>
76 #enum_name::#variant_name => compile_error!("unit variants are not allowed in ProtocolMessage"),
77 },
78 Fields::Named(_) => quote_spanned! {
79 variant.ident.span() =>
80 #enum_name::#variant_name{..} => compile_error!("named variants are not allowed in ProtocolMessage"),
81 },
82 Fields::Unnamed(unnamed) => if unnamed.unnamed.len() == 1 {
83 quote_spanned! {
84 variant.ident.span() =>
85 #enum_name::#variant_name(_) => #i,
86 }
87 } else {
88 quote_spanned! {
89 variant.ident.span() =>
90 #enum_name::#variant_name(..) => compile_error!("this variant must contain exactly one field to be valid ProtocolMessage"),
91 }
92 },
93 }
94 });
95 quote! {
96 match self {
97 #(#match_variants)*
98 }
99 }
100}
101
102fn round_messages<'v>(
103 root_path: &RootPath,
104 enum_name: &Ident,
105 generics: &Generics,
106 variants: impl Iterator<Item = &'v Variant>,
107) -> TokenStream {
108 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
109 let impls = (0u16..).zip(variants).map(|(i, variant)| {
110 let variant_name = &variant.ident;
111 match &variant.fields {
112 Fields::Unnamed(unnamed) if unnamed.unnamed.len() == 1 => {
113 let msg_type = &unnamed.unnamed[0].ty;
114 quote_spanned! {
115 variant.ident.span() =>
116 impl #impl_generics #root_path::RoundMessage<#msg_type> for #enum_name #ty_generics #where_clause {
117 const ROUND: u16 = #i;
118 fn to_protocol_message(round_message: #msg_type) -> Self {
119 #enum_name::#variant_name(round_message)
120 }
121 fn from_protocol_message(protocol_message: Self) -> Result<#msg_type, Self> {
122 #[allow(unreachable_patterns)]
123 match protocol_message {
124 #enum_name::#variant_name(msg) => Ok(msg),
125 _ => Err(protocol_message),
126 }
127 }
128 }
129 }
130 }
131 _ => quote! {},
132 }
133 });
134 quote! {
135 #(#impls)*
136 }
137}
138
139type RootPath = Punctuated<Ident, Token![::]>;
140
141#[allow(dead_code)]
142struct RootAttribute {
143 paren: syn::token::Paren,
144 root: kw::root,
145 eq: Token![=],
146 path: RootPath,
147}
148
149impl Parse for RootAttribute {
150 fn parse(input: ParseStream) -> syn::Result<Self> {
151 let content;
152 let paren = syn::parenthesized!(content in input);
153 let root = content.parse::<kw::root>()?;
154 let eq = content.parse::<Token![=]>()?;
155 let path = RootPath::parse_separated_nonempty_with(&content, Ident::parse_any)?;
156 let _ = content.parse::<syn::parse::Nothing>()?;
157
158 Ok(Self {
159 paren,
160 root,
161 eq,
162 path,
163 })
164 }
165}
166
167mod kw {
168 syn::custom_keyword! { root }
169}