1use proc_macro::TokenStream;
2use proc_macro2::TokenTree;
3use quote::{format_ident, quote, ToTokens};
4use syn::{
5 parse_macro_input, parse_quote, Attribute, Ident, ItemEnum, Meta, PathArguments, Type,
6 TypePath, Variant, Visibility,
7};
8
/// Naming parameters for one direction (client-bound or server-bound) of a packet stream.
struct Bound {
    /// Suffix a packet type name must end with to belong to this direction (e.g. `"S2c"`).
    /// It is also inserted into the generated per-state enum names (`<State><suffix>Packets`).
    suffix: &'static str,
    /// Name of the generated enum that has one variant per state for this direction.
    bound_packet_ident: &'static str,
}
13
14const CLIENT_BOUND: Bound = Bound {
15 suffix: "S2c",
16 bound_packet_ident: "ClientBoundPacket",
17};
18
19const SERVER_BOUND: Bound = Bound {
20 suffix: "C2s",
21 bound_packet_ident: "ServerBoundPacket",
22};
23
24struct PacketStream<'a> {
25 ident: &'a Ident,
26 attrs: &'a Vec<Attribute>,
27 vis: &'a Visibility,
28 states: Vec<PacketStreamState<'a>>,
29}
30
31struct PacketStreamState<'a> {
32 attrs: &'a Vec<Attribute>,
33 ident: &'a Ident,
34 packets: Vec<Packet<'a>>,
35}
36
37#[derive(Clone)]
38struct Packet<'a> {
39 ident: &'a TypePath,
40 has_lifetime: bool,
41 changing_state: Option<proc_macro2::TokenStream>,
42 enforced_id: Option<proc_macro2::TokenStream>,
43}
44
45#[proc_macro_attribute]
46pub fn packet_stream(_attr: TokenStream, input: TokenStream) -> TokenStream {
47 let mut input = parse_macro_input!(input as ItemEnum);
48 let packet_stream = packet_stream_by_inputs(&mut input);
49 let client_bound_generated = generate_by_bound(&packet_stream, CLIENT_BOUND);
50 let server_bound_generated = generate_by_bound(&packet_stream, SERVER_BOUND);
51 let main_body_generated = generate_main_enum_body(&packet_stream);
52
53 quote! {
54 #main_body_generated
55 #client_bound_generated
56 #server_bound_generated
57 }
58 .into()
59}
60
61fn generate_main_enum_body(packet_stream: &PacketStream) -> proc_macro2::TokenStream {
62 let vis = packet_stream.vis;
63 let packet_stream_ident = packet_stream.ident;
64 let state_idents = idents_by_states(&packet_stream.states);
65 let attrs = packet_stream.attrs;
66 let state_attrs = attrs_by_states(&packet_stream.states);
67 quote! {
68 #(#attrs)*
69 #[allow(dead_code)]
70 #[derive(Debug)]
71 #vis enum #packet_stream_ident {
72 #(#(#state_attrs)* #state_idents,)*
73 }
74 }
75}
76
77fn generate_by_bound(packet_stream: &PacketStream, bound: Bound) -> proc_macro2::TokenStream {
78 let packet_stream_ident = packet_stream.ident;
79
80 let bound_packet_ident = format_ident!("{}", bound.bound_packet_ident);
81 let state_packet_names = packet_stream
82 .states
83 .iter()
84 .map(|state| format_ident!("{}{}Packets", state.ident, bound.suffix))
85 .collect::<Vec<_>>();
86 let state_names = packet_stream
87 .states
88 .iter()
89 .map(|state| state.ident)
90 .collect::<Vec<_>>();
91 let vis = packet_stream.vis;
92 let state_lifetimes = packet_stream
93 .states
94 .iter()
95 .map(|state| {
96 packets_filtered_with_suffix(&state.packets, bound.suffix)
97 .iter()
98 .any(|packet| packet.has_lifetime)
99 .then_some(quote! {<'a>})
100 })
101 .collect::<Vec<_>>();
102 let bound_packet_lifetime = state_lifetimes
103 .iter()
104 .any(|b| b.is_some())
105 .then_some(quote! {<'a>});
106 let bound_packet_lifetime_without_bracket = bound_packet_lifetime.clone().map(|_| quote! {'a});
107 let state_quotes: Vec<_> = packet_stream
108 .states
109 .iter()
110 .map(|state| {
111 let state_ident = state.ident;
112 let state_bound_packets = packets_filtered_with_suffix(&state.packets, bound.suffix);
113 let state_bound_packet_paths = paths_by_packets(&state_bound_packets);
114 let state = state.ident;
115 let state_packets_name = format_ident!("{state_ident}{}Packets", bound.suffix);
116 let vis = packet_stream.vis;
117 let bound_packets = format_ident!("{}", bound.bound_packet_ident);
118 let state_bound_packet_ids = ids_by_packets(&state_bound_packets);
119 let repr_attr = if state_bound_packet_paths.is_empty() { None } else {
120 Some(quote! { #[repr(u32)] })
121 };
122 let state_packet_lifetime = state_bound_packets.iter().any(|packet| packet.has_lifetime).then_some(quote! {<'a>});
123 let state_bound_packet_lifetimes = state_bound_packets.iter().map(|packet| packet.has_lifetime.then_some(quote! {<'a>})).collect::<Vec<_>>();
124
125 let serialization_attr = if cfg!(feature = "serialization") {
126 Some(quote! {#[derive(serialization::Serializable)]})
127 } else {
128 None
129 };
130 let packets_enum = quote! {
131 #serialization_attr
132 #[derive(Debug)]
133 #repr_attr
134 #vis enum #state_packets_name #state_packet_lifetime {
135 #(#state_bound_packet_paths(#state_bound_packet_paths #state_bound_packet_lifetimes) #state_bound_packet_ids,)*
136 }
137 };
138 let changing_state_stmt: Vec<_> = state_bound_packets
139 .iter()
140 .map(|field| {
141 if let Some(state) = &field.changing_state {
142 Some(quote! {Some(#packet_stream_ident::#state)})
143 } else {
144 Some(quote! {None})
145 }
146 })
147 .collect();
148
149 quote! {
150 #packets_enum
151
152 impl #bound_packet_lifetime From<#state_packets_name #state_packet_lifetime> for #bound_packets #bound_packet_lifetime {
153 fn from(value: #state_packets_name #state_packet_lifetime) -> Self {
154 #bound_packets::#state_packets_name(value)
155 }
156 }
157
158 impl #state_packet_lifetime packetize::Packet<#packet_stream_ident> for #state_packets_name #state_packet_lifetime {
159 fn get_id(&self, state: &#packet_stream_ident) -> Option<u32> {
160 match self {
161 #(
162 #state_packets_name::#state_bound_packet_paths(value) => {
163 packetize::Packet::<#packet_stream_ident>::get_id(value, state)
164 }
165 )*
166 _ => unreachable!()
167 }
168 }
169
170 fn is_changing_state(&self) -> Option<#packet_stream_ident> {
171 match self {
172 #(
173 #state_packets_name::#state_bound_packet_paths(value) => {
174 <#state_bound_packet_paths #state_bound_packet_lifetimes as packetize::Packet::<#packet_stream_ident>>::is_changing_state(value)
175 }
176 )*
177 _ => unreachable!()
178 }
179 }
180 }
181
182 impl #bound_packet_lifetime TryFrom<#bound_packets #bound_packet_lifetime> for #state_packets_name #state_packet_lifetime {
183 type Error = ();
184
185 fn try_from(value: #bound_packets #bound_packet_lifetime) -> Result<Self, Self::Error> {
186 match value {
187 #bound_packets::#state_packets_name(value) => Ok(value),
188 _ => Err(())?,
189 }
190 }
191 }
192
193 #(
194 impl #state_packet_lifetime From<#state_bound_packet_paths #state_bound_packet_lifetimes> for #state_packets_name #state_packet_lifetime {
195 fn from(value: #state_bound_packet_paths #state_bound_packet_lifetimes) -> Self {
196 #state_packets_name::#state_bound_packet_paths(value)
197 }
198 }
199
200 impl #bound_packet_lifetime From<#state_bound_packet_paths #state_bound_packet_lifetimes> for #bound_packets #bound_packet_lifetime {
201 fn from(value: #state_bound_packet_paths #state_bound_packet_lifetimes) -> Self {
202 #bound_packets::#state_packets_name(#state_packets_name::#state_bound_packet_paths(value))
203 }
204 }
205
206 impl #bound_packet_lifetime TryFrom<#bound_packets #bound_packet_lifetime> for #state_bound_packet_paths #state_bound_packet_lifetimes {
207 type Error = ();
208
209 fn try_from(value: #bound_packets #bound_packet_lifetime) -> Result<Self, Self::Error> {
210 match value {
211 #bound_packets::#state_packets_name(value) => Ok(value.try_into()?),
212 _ => Err(())?,
213 }
214 }
215 }
216
217 impl #state_packet_lifetime TryFrom<#state_packets_name #state_packet_lifetime> for #state_bound_packet_paths #state_bound_packet_lifetimes {
218 type Error = ();
219
220 fn try_from(value: #state_packets_name #state_packet_lifetime) -> Result<Self, Self::Error> {
221 match value {
222 #state_packets_name::#state_bound_packet_paths(value) => Ok(value),
223 _ => Err(())?,
224 }
225 }
226 }
227
228 impl #state_bound_packet_lifetimes packetize::Packet<#packet_stream_ident> for #state_bound_packet_paths #state_bound_packet_lifetimes {
229 fn get_id(&self, state: &#packet_stream_ident) -> Option<u32> {
230 match state {
231 #packet_stream_ident::#state => {
232 Some(#state_packets_name::#state_bound_packet_paths as u32)
233 },
234 _ => None,
235 }
236 }
237
238 fn is_changing_state(&self) -> Option<#packet_stream_ident> {
239 #changing_state_stmt
240 }
241 }
242 )*
243 }
244 })
245 .collect();
246 let serialization_attr = if cfg!(feature = "serialization") {
247 Some(quote! {#[derive(serialization::Serializable)]})
248 } else {
249 None
250 };
251 let part1 = quote! {
252 #(#state_quotes)*
253
254 #serialization_attr
255 #[derive(Debug)]
256 #vis enum #bound_packet_ident #bound_packet_lifetime {
257 #(#state_packet_names(#state_packet_names #state_lifetimes),)*
258 }
259
260 impl #bound_packet_lifetime packetize::Packet<#packet_stream_ident> for #bound_packet_ident #bound_packet_lifetime {
261 fn get_id(&self, state: &#packet_stream_ident) -> Option<u32> {
262 match self {
263 #(
264 #bound_packet_ident::#state_packet_names(value) => {
265 packetize::Packet::<#packet_stream_ident>::get_id(value, state)
266 }
267 )*
268 _ => unreachable!()
269 }
270 }
271
272 fn is_changing_state(&self) -> Option<#packet_stream_ident> {
273 match self {
274 #(
275 #bound_packet_ident::#state_packet_names(value) => {
276 <#state_packet_names #state_lifetimes as packetize::Packet::<#packet_stream_ident>>::is_changing_state(value)
277 }
278 )*
279 _ => unreachable!()
280 }
281 }
282 }
283 };
284
285 #[cfg(not(feature = "serialization"))]
286 let part2 = quote! {};
287 #[cfg(feature = "serialization")]
288 let part2 = quote! {
289 impl<'de: #bound_packet_lifetime_without_bracket, #bound_packet_lifetime_without_bracket>
290 packetize::DecodePacket<#packet_stream_ident> for #bound_packet_ident #bound_packet_lifetime {
291 fn decode_packet<D: serialization::Decoder>(
292 decoder: D,
293 state: &mut #packet_stream_ident,
294 ) -> Result<Self, D::Error> {
295 let result: Self = match state {
296 #(
297 #packet_stream_ident::#state_names =>
298 <#state_packet_names as serialization::Decode::>::decode_placed(decoder)?.into(),
299 )*
300 };
301 if let Some(new_state) = <Self as packetize::Packet::<#packet_stream_ident>>::is_changing_state(&result) {
302 *state = new_state;
303 }
304 Ok(result)
305 }
306 }
307
308 impl #bound_packet_lifetime packetize::EncodePacket<#packet_stream_ident> for #bound_packet_ident #bound_packet_lifetime {
309 fn encode_packet<E: serialization::Encoder>(
310 &self,
311 encoder: E,
312 state: &mut #packet_stream_ident,
313 ) -> Result<(), E::Error> {
314 if let Some(new_state) = <Self as packetize::Packet::<#packet_stream_ident>>::is_changing_state(self) {
315 *state = new_state;
316 }
317 match self {
318 #(
319 #bound_packet_ident::#state_packet_names(value) => serialization::Encode::encode(value, encoder)?,
320 )*
321 };
322 Ok(())
323 }
324 }
325 };
326 quote! {
327 #part1
328 #part2
329 }
330}
331
332fn packet_stream_by_inputs<'a>(item_enum: &'a mut ItemEnum) -> PacketStream<'a> {
333 let states: Vec<_> = item_enum
334 .variants
335 .iter_mut()
336 .map(|enum_variant| packet_stream_state_by_enum_variant(enum_variant))
337 .collect();
338 PacketStream {
339 ident: &item_enum.ident,
340 vis: &item_enum.vis,
341 states,
342 attrs: &item_enum.attrs,
343 }
344}
345
346fn idents_by_states<'a>(states: &Vec<PacketStreamState<'a>>) -> Vec<&'a Ident> {
347 states.iter().map(|state| state.ident).collect()
348}
349
350fn packet_stream_state_by_enum_variant(enum_variant: &mut Variant) -> PacketStreamState {
351 PacketStreamState {
352 ident: &enum_variant.ident,
353 packets: enum_variant
354 .fields
355 .iter_mut()
356 .map(|field| {
357 let mut has_lifetime = false;
358 Packet {
359 ident: match &mut field.ty {
360 Type::Path(path) => {
361 if path.path.get_ident().is_none() {
362 has_lifetime = true;
363 }
364 let ref mut value = path.path.segments;
365 for segment in value.iter_mut() {
366 segment.arguments = PathArguments::None;
367 }
368 path
369 }
370 _ => unimplemented!("type must path"),
371 },
372 changing_state: find_ident_in_attrs(&field.attrs, "change_state_to").map(
373 |attr| match attr.meta {
374 syn::Meta::List(list) => list.tokens,
375 _ => panic!("attribute needs single value input"),
376 },
377 ),
378 enforced_id: find_ident_in_attrs(&field.attrs, "id").map(|attr| {
379 match attr.meta {
380 syn::Meta::List(list) => {
381 let tokens = list.tokens;
382 quote! { = #tokens }
383 }
384 _ => panic!("attribute needs single value input"),
385 }
386 }),
387 has_lifetime,
388 }
389 })
390 .collect(),
391 attrs: &enum_variant.attrs,
392 }
393}
394
395fn find_ident_in_attrs<'a>(attrs: &'a Vec<Attribute>, ident: &'static str) -> Option<Attribute> {
396 attrs
397 .iter()
398 .find(|attr| {
399 let list = match &attr.meta {
400 Meta::List(list) => list,
401 _ => return false,
402 };
403 if !list.path.is_ident(ident) {
404 return false;
405 }
406 true
407 })
408 .map(|v| v.clone())
409}
410
411fn paths_by_packets<'a>(packets: &Vec<&Packet<'a>>) -> Vec<&'a TypePath> {
412 packets.iter().map(|packet| packet.ident).collect()
413}
414
415fn ids_by_packets<'a>(packets: &Vec<&Packet<'a>>) -> Vec<Option<proc_macro2::TokenStream>> {
416 packets
417 .iter()
418 .map(|packet| packet.enforced_id.clone())
419 .collect()
420}
421
422fn packets_filtered_with_suffix<'a>(
423 packets: &'a Vec<Packet<'a>>,
424 ends_with: &'static str,
425) -> Vec<&'a Packet<'a>> {
426 packets
427 .iter()
428 .filter(|packet| {
429 packet
430 .ident
431 .to_token_stream()
432 .to_string()
433 .ends_with(ends_with)
434 })
435 .collect::<Vec<_>>()
436}
437
438fn attrs_by_states<'a>(states: &Vec<PacketStreamState<'a>>) -> Vec<&'a Vec<Attribute>> {
439 states.iter().map(|state| state.attrs).collect()
440}