packetize_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenTree;
3use quote::{format_ident, quote, ToTokens};
4use syn::{
5    parse_macro_input, parse_quote, Attribute, Ident, ItemEnum, Meta, PathArguments, Type,
6    TypePath, Variant, Visibility,
7};
8
/// One traffic direction of a packet stream.
struct Bound {
    /// Type-name suffix that selects the packets of this direction (e.g. `S2c`); it is also
    /// appended to each state name to form the per-state enum name `{State}{suffix}Packets`.
    suffix: &'static str,
    /// Name of the generated enum that wraps every per-state enum of this direction.
    bound_packet_ident: &'static str,
}

/// Client-bound direction: packet types ending in `S2c`, wrapped by `ClientBoundPacket`.
const CLIENT_BOUND: Bound = Bound {
    bound_packet_ident: "ClientBoundPacket",
    suffix: "S2c",
};

/// Server-bound direction: packet types ending in `C2s`, wrapped by `ServerBoundPacket`.
const SERVER_BOUND: Bound = Bound {
    bound_packet_ident: "ServerBoundPacket",
    suffix: "C2s",
};
23
24struct PacketStream<'a> {
25    ident: &'a Ident,
26    attrs: &'a Vec<Attribute>,
27    vis: &'a Visibility,
28    states: Vec<PacketStreamState<'a>>,
29}
30
31struct PacketStreamState<'a> {
32    attrs: &'a Vec<Attribute>,
33    ident: &'a Ident,
34    packets: Vec<Packet<'a>>,
35}
36
37#[derive(Clone)]
38struct Packet<'a> {
39    ident: &'a TypePath,
40    has_lifetime: bool,
41    changing_state: Option<proc_macro2::TokenStream>,
42    enforced_id: Option<proc_macro2::TokenStream>,
43}
44
45#[proc_macro_attribute]
46pub fn packet_stream(_attr: TokenStream, input: TokenStream) -> TokenStream {
47    let mut input = parse_macro_input!(input as ItemEnum);
48    let packet_stream = packet_stream_by_inputs(&mut input);
49    let client_bound_generated = generate_by_bound(&packet_stream, CLIENT_BOUND);
50    let server_bound_generated = generate_by_bound(&packet_stream, SERVER_BOUND);
51    let main_body_generated = generate_main_enum_body(&packet_stream);
52
53    quote! {
54        #main_body_generated
55        #client_bound_generated
56        #server_bound_generated
57    }
58    .into()
59}
60
61fn generate_main_enum_body(packet_stream: &PacketStream) -> proc_macro2::TokenStream {
62    let vis = packet_stream.vis;
63    let packet_stream_ident = packet_stream.ident;
64    let state_idents = idents_by_states(&packet_stream.states);
65    let attrs = packet_stream.attrs;
66    let state_attrs = attrs_by_states(&packet_stream.states);
67    quote! {
68        #(#attrs)*
69        #[allow(dead_code)]
70        #[derive(Debug)]
71        #vis enum #packet_stream_ident {
72            #(#(#state_attrs)* #state_idents,)*
73        }
74    }
75}
76
77fn generate_by_bound(packet_stream: &PacketStream, bound: Bound) -> proc_macro2::TokenStream {
78    let packet_stream_ident = packet_stream.ident;
79
80    let bound_packet_ident = format_ident!("{}", bound.bound_packet_ident);
81    let state_packet_names = packet_stream
82        .states
83        .iter()
84        .map(|state| format_ident!("{}{}Packets", state.ident, bound.suffix))
85        .collect::<Vec<_>>();
86    let state_names = packet_stream
87        .states
88        .iter()
89        .map(|state| state.ident)
90        .collect::<Vec<_>>();
91    let vis = packet_stream.vis;
92    let state_lifetimes = packet_stream
93        .states
94        .iter()
95        .map(|state| {
96            packets_filtered_with_suffix(&state.packets, bound.suffix)
97                .iter()
98                .any(|packet| packet.has_lifetime)
99                .then_some(quote! {<'a>})
100        })
101        .collect::<Vec<_>>();
102    let bound_packet_lifetime = state_lifetimes
103        .iter()
104        .any(|b| b.is_some())
105        .then_some(quote! {<'a>});
106    let bound_packet_lifetime_without_bracket = bound_packet_lifetime.clone().map(|_| quote! {'a});
107    let state_quotes: Vec<_> = packet_stream
108        .states
109        .iter()
110        .map(|state| {
111            let state_ident = state.ident;
112            let state_bound_packets = packets_filtered_with_suffix(&state.packets, bound.suffix);
113            let state_bound_packet_paths = paths_by_packets(&state_bound_packets);
114            let state = state.ident;
115            let state_packets_name = format_ident!("{state_ident}{}Packets", bound.suffix);
116            let vis = packet_stream.vis;
117            let bound_packets = format_ident!("{}", bound.bound_packet_ident);
118            let state_bound_packet_ids = ids_by_packets(&state_bound_packets);
119            let repr_attr = if state_bound_packet_paths.is_empty() { None } else {
120                Some(quote! { #[repr(u32)] })
121            };
122 let state_packet_lifetime = state_bound_packets.iter().any(|packet| packet.has_lifetime).then_some(quote! {<'a>});
123 let state_bound_packet_lifetimes = state_bound_packets.iter().map(|packet| packet.has_lifetime.then_some(quote! {<'a>})).collect::<Vec<_>>();
124
125            let serialization_attr = if cfg!(feature = "serialization") {
126                Some(quote! {#[derive(serialization::Serializable)]})
127            } else {
128                None
129            };
130            let packets_enum = quote! {
131                #serialization_attr
132                #[derive(Debug)]
133                #repr_attr
134                #vis enum #state_packets_name #state_packet_lifetime {
135                    #(#state_bound_packet_paths(#state_bound_packet_paths #state_bound_packet_lifetimes) #state_bound_packet_ids,)*
136                }
137            };
138            let changing_state_stmt: Vec<_> = state_bound_packets
139                .iter()
140                .map(|field| {
141                    if let Some(state) = &field.changing_state {
142                        Some(quote! {Some(#packet_stream_ident::#state)})
143                    } else {
144                        Some(quote! {None})
145                    }
146                })
147                .collect();
148
149            quote! {
150                #packets_enum
151
152                impl #bound_packet_lifetime From<#state_packets_name #state_packet_lifetime> for #bound_packets #bound_packet_lifetime {
153                    fn from(value: #state_packets_name #state_packet_lifetime) -> Self {
154                        #bound_packets::#state_packets_name(value)
155                    }
156                }
157
158                impl #state_packet_lifetime packetize::Packet<#packet_stream_ident> for #state_packets_name #state_packet_lifetime {
159                    fn get_id(&self, state: &#packet_stream_ident) -> Option<u32> {
160                        match self {
161                            #(
162                                #state_packets_name::#state_bound_packet_paths(value) => {
163                                    packetize::Packet::<#packet_stream_ident>::get_id(value, state)
164                                }
165                            )*
166                            _ => unreachable!()
167                        }
168                    }
169
170                    fn is_changing_state(&self) -> Option<#packet_stream_ident> {
171                        match self {
172                            #(
173                                #state_packets_name::#state_bound_packet_paths(value) => {
174                                    <#state_bound_packet_paths #state_bound_packet_lifetimes as packetize::Packet::<#packet_stream_ident>>::is_changing_state(value)
175                                }
176                            )*
177                            _ => unreachable!()
178                        }
179                    }
180                }
181
182                impl #bound_packet_lifetime TryFrom<#bound_packets #bound_packet_lifetime> for #state_packets_name #state_packet_lifetime {
183                    type Error = ();
184
185                    fn try_from(value: #bound_packets #bound_packet_lifetime) -> Result<Self, Self::Error> {
186                        match value {
187                            #bound_packets::#state_packets_name(value) => Ok(value),
188                            _ => Err(())?,
189                        }
190                    }
191                }
192
193                #(
194                impl #state_packet_lifetime From<#state_bound_packet_paths #state_bound_packet_lifetimes> for #state_packets_name #state_packet_lifetime {
195                    fn from(value: #state_bound_packet_paths #state_bound_packet_lifetimes) -> Self {
196                        #state_packets_name::#state_bound_packet_paths(value)
197                    }
198                }
199
200                impl #bound_packet_lifetime From<#state_bound_packet_paths #state_bound_packet_lifetimes> for #bound_packets #bound_packet_lifetime {
201                    fn from(value: #state_bound_packet_paths #state_bound_packet_lifetimes) -> Self {
202                        #bound_packets::#state_packets_name(#state_packets_name::#state_bound_packet_paths(value))
203                    }
204                }
205
206                impl #bound_packet_lifetime TryFrom<#bound_packets #bound_packet_lifetime> for #state_bound_packet_paths #state_bound_packet_lifetimes {
207                    type Error = ();
208
209                    fn try_from(value: #bound_packets #bound_packet_lifetime) -> Result<Self, Self::Error> {
210                        match value {
211                            #bound_packets::#state_packets_name(value) => Ok(value.try_into()?),
212                            _ => Err(())?,
213                        }
214                    }
215                }
216
217                impl #state_packet_lifetime TryFrom<#state_packets_name #state_packet_lifetime> for #state_bound_packet_paths #state_bound_packet_lifetimes {
218                    type Error = ();
219
220                    fn try_from(value: #state_packets_name #state_packet_lifetime) -> Result<Self, Self::Error> {
221                        match value {
222                            #state_packets_name::#state_bound_packet_paths(value) => Ok(value),
223                            _ => Err(())?,
224                        }
225                    }
226                }
227
228                impl #state_bound_packet_lifetimes packetize::Packet<#packet_stream_ident> for #state_bound_packet_paths #state_bound_packet_lifetimes {
229                    fn get_id(&self, state: &#packet_stream_ident) -> Option<u32> {
230                        match state {
231                            #packet_stream_ident::#state => {
232                                Some(#state_packets_name::#state_bound_packet_paths as u32)
233                            },
234                            _ => None,
235                        }
236                    }
237
238                    fn is_changing_state(&self) -> Option<#packet_stream_ident> {
239                        #changing_state_stmt
240                    }
241                }
242                )*
243            }
244        })
245        .collect();
246    let serialization_attr = if cfg!(feature = "serialization") {
247        Some(quote! {#[derive(serialization::Serializable)]})
248    } else {
249        None
250    };
251    let part1 = quote! {
252            #(#state_quotes)*
253
254            #serialization_attr
255            #[derive(Debug)]
256            #vis enum #bound_packet_ident #bound_packet_lifetime {
257                #(#state_packet_names(#state_packet_names #state_lifetimes),)*
258            }
259
260            impl #bound_packet_lifetime packetize::Packet<#packet_stream_ident> for #bound_packet_ident #bound_packet_lifetime {
261                fn get_id(&self, state: &#packet_stream_ident) -> Option<u32> {
262                    match self {
263                        #(
264                            #bound_packet_ident::#state_packet_names(value) => {
265                                packetize::Packet::<#packet_stream_ident>::get_id(value, state)
266                            }
267                        )*
268                        _ => unreachable!()
269                    }
270                }
271
272                fn is_changing_state(&self) -> Option<#packet_stream_ident> {
273                    match self {
274                        #(
275                            #bound_packet_ident::#state_packet_names(value) => {
276                                <#state_packet_names #state_lifetimes as packetize::Packet::<#packet_stream_ident>>::is_changing_state(value)
277                            }
278                        )*
279                        _ => unreachable!()
280                    }
281                }
282            }
283    };
284
285    #[cfg(not(feature = "serialization"))]
286    let part2 = quote! {};
287    #[cfg(feature = "serialization")]
288    let part2 = quote! {
289    impl<'de: #bound_packet_lifetime_without_bracket, #bound_packet_lifetime_without_bracket>
290        packetize::DecodePacket<#packet_stream_ident> for #bound_packet_ident #bound_packet_lifetime {
291        fn decode_packet<D: serialization::Decoder>(
292            decoder: D,
293            state: &mut #packet_stream_ident,
294        ) -> Result<Self, D::Error> {
295            let result: Self = match state {
296                #(
297                #packet_stream_ident::#state_names =>
298                    <#state_packet_names as serialization::Decode::>::decode_placed(decoder)?.into(),
299                )*
300            };
301            if let Some(new_state) = <Self as packetize::Packet::<#packet_stream_ident>>::is_changing_state(&result) {
302                *state = new_state;
303            }
304            Ok(result)
305        }
306    }
307
308    impl #bound_packet_lifetime packetize::EncodePacket<#packet_stream_ident> for #bound_packet_ident #bound_packet_lifetime {
309        fn encode_packet<E: serialization::Encoder>(
310            &self,
311            encoder: E,
312            state: &mut #packet_stream_ident,
313        ) -> Result<(), E::Error> {
314            if let Some(new_state) = <Self as packetize::Packet::<#packet_stream_ident>>::is_changing_state(self) {
315                *state = new_state;
316            }
317            match self {
318                #(
319                #bound_packet_ident::#state_packet_names(value) => serialization::Encode::encode(value, encoder)?,
320                )*
321            };
322            Ok(())
323        }
324    }
325        };
326    quote! {
327        #part1
328        #part2
329    }
330}
331
332fn packet_stream_by_inputs<'a>(item_enum: &'a mut ItemEnum) -> PacketStream<'a> {
333    let states: Vec<_> = item_enum
334        .variants
335        .iter_mut()
336        .map(|enum_variant| packet_stream_state_by_enum_variant(enum_variant))
337        .collect();
338    PacketStream {
339        ident: &item_enum.ident,
340        vis: &item_enum.vis,
341        states,
342        attrs: &item_enum.attrs,
343    }
344}
345
346fn idents_by_states<'a>(states: &Vec<PacketStreamState<'a>>) -> Vec<&'a Ident> {
347    states.iter().map(|state| state.ident).collect()
348}
349
350fn packet_stream_state_by_enum_variant(enum_variant: &mut Variant) -> PacketStreamState {
351    PacketStreamState {
352        ident: &enum_variant.ident,
353        packets: enum_variant
354            .fields
355            .iter_mut()
356            .map(|field| {
357                let mut has_lifetime = false;
358                Packet {
359                    ident: match &mut field.ty {
360                        Type::Path(path) => {
361                            if path.path.get_ident().is_none() {
362                                has_lifetime = true;
363                            }
364                            let ref mut value = path.path.segments;
365                            for segment in value.iter_mut() {
366                                segment.arguments = PathArguments::None;
367                            }
368                            path
369                        }
370                        _ => unimplemented!("type must path"),
371                    },
372                    changing_state: find_ident_in_attrs(&field.attrs, "change_state_to").map(
373                        |attr| match attr.meta {
374                            syn::Meta::List(list) => list.tokens,
375                            _ => panic!("attribute needs single value input"),
376                        },
377                    ),
378                    enforced_id: find_ident_in_attrs(&field.attrs, "id").map(|attr| {
379                        match attr.meta {
380                            syn::Meta::List(list) => {
381                                let tokens = list.tokens;
382                                quote! { = #tokens }
383                            }
384                            _ => panic!("attribute needs single value input"),
385                        }
386                    }),
387                    has_lifetime,
388                }
389            })
390            .collect(),
391        attrs: &enum_variant.attrs,
392    }
393}
394
395fn find_ident_in_attrs<'a>(attrs: &'a Vec<Attribute>, ident: &'static str) -> Option<Attribute> {
396    attrs
397        .iter()
398        .find(|attr| {
399            let list = match &attr.meta {
400                Meta::List(list) => list,
401                _ => return false,
402            };
403            if !list.path.is_ident(ident) {
404                return false;
405            }
406            true
407        })
408        .map(|v| v.clone())
409}
410
411fn paths_by_packets<'a>(packets: &Vec<&Packet<'a>>) -> Vec<&'a TypePath> {
412    packets.iter().map(|packet| packet.ident).collect()
413}
414
415fn ids_by_packets<'a>(packets: &Vec<&Packet<'a>>) -> Vec<Option<proc_macro2::TokenStream>> {
416    packets
417        .iter()
418        .map(|packet| packet.enforced_id.clone())
419        .collect()
420}
421
422fn packets_filtered_with_suffix<'a>(
423    packets: &'a Vec<Packet<'a>>,
424    ends_with: &'static str,
425) -> Vec<&'a Packet<'a>> {
426    packets
427        .iter()
428        .filter(|packet| {
429            packet
430                .ident
431                .to_token_stream()
432                .to_string()
433                .ends_with(ends_with)
434        })
435        .collect::<Vec<_>>()
436}
437
438fn attrs_by_states<'a>(states: &Vec<PacketStreamState<'a>>) -> Vec<&'a Vec<Attribute>> {
439    states.iter().map(|state| state.attrs).collect()
440}