1use quote::quote;
2use syn::{Data, DeriveInput, parse_macro_input};
3
4mod extract;
5mod generate;
6
7use extract::*;
8use generate::*;
9
10#[proc_macro_derive(Packet, attributes(varint, varlong, vec_end, vec_varint, vec_varlong))]
29pub fn derive_packet(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
30 let input = parse_macro_input!(input as DeriveInput);
31
32 let name = &input.ident;
33 let read_impl = generate_read(&input);
34 let write_impl = generate_write(&input);
35
36 let expanded = quote! {
37 impl #name {
38 pub fn read(buffer: &mut std::io::Cursor<&[u8]>) -> Option<Self> {
39 #read_impl
40 }
41
42 pub fn write(&self, buffer: &mut impl std::io::Write) -> std::io::Result<()> {
43 #write_impl
44 Ok(())
45 }
46 }
47 };
48
49 expanded.into()
50}
51
52#[proc_macro_derive(PacketUnion, attributes(id))]
55pub fn derive_packet_union(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
56 let input = parse_macro_input!(input as DeriveInput);
57
58 let enum_name = &input.ident;
59
60 let variants = match &input.data {
61 Data::Enum(data) => &data.variants,
62 _ => {
63 return syn::Error::new_spanned(enum_name, "packet union can only be applied to enums").to_compile_error().into();
64 }
65 };
66
67 let enum_impl = quote! {
68 impl crate::IntoPacket<#enum_name> for #enum_name {
69 fn into_packet(self) -> #enum_name {
70 self
71 }
72 }
73 };
74
75 let variant_impls = variants.iter().map(|variant| {
76 let variant_name = &variant.ident;
77
78 match &variant.fields {
79 syn::Fields::Unnamed(fields) => {
80 if fields.unnamed.len() == 1 {
81 let field_type = &fields.unnamed[0].ty;
82
83 quote! {
84 impl crate::IntoPacket<#enum_name> for #field_type {
85 fn into_packet(self) -> #enum_name {
86 #enum_name::#variant_name(self)
87 }
88 }
89 }
90 } else {
91 quote! {
92 compile_error!("packet union variants must have exactly one field");
93 }
94 }
95 }
96 syn::Fields::Named(_) => {
97 quote! {
98 compile_error!("packet union variants must use unnamed fields");
99 }
100 }
101 syn::Fields::Unit => {
102 quote! {
103 compile_error!("packet union variants cannot be unit variants");
104 }
105 }
106 }
107 });
108
109 let has_packet_ids = variants.iter().any(|v| extract::extract_packet_id(v).is_some());
110
111 let packet_impl = if has_packet_ids {
112 let id_arms = variants.iter().filter_map(|variant| {
113 let variant_name = &variant.ident;
114 extract::extract_packet_id(variant).map(|id| {
115 quote! {
116 Self::#variant_name(_) => #id,
117 }
118 })
119 });
120
121 let read_arms = variants.iter().filter_map(|variant| {
122 let variant_name = &variant.ident;
123
124 match &variant.fields {
125 syn::Fields::Unnamed(fields) => {
126 if fields.unnamed.len() == 1 {
127 let field_type = &fields.unnamed[0].ty;
128 extract::extract_packet_id(variant).map(|id| {
129 quote! {
130 #id => Some(Self::#variant_name(<#field_type>::read(buf)?)),
131 }
132 })
133 } else {
134 None
135 }
136 }
137 _ => None,
138 }
139 });
140
141 let write_arms = variants.iter().map(|variant| {
142 let variant_name = &variant.ident;
143
144 quote! {
145 Self::#variant_name(p) => p.write(buf),
146 }
147 });
148
149 quote! {
150 impl crate::ProtocolPacket for #enum_name {
151 fn id(&self) -> u32 {
152 match self {
153 #(#id_arms)*
154 }
155 }
156
157 fn read(id: u32, buf: &mut std::io::Cursor<&[u8]>) -> Option<Self> {
158 match id {
159 #(#read_arms)*
160 _ => None,
161 }
162 }
163
164 fn write(&self, buf: &mut impl std::io::Write) -> std::io::Result<()> {
165 match self {
166 #(#write_arms)*
167 }
168 }
169 }
170 }
171 } else {
172 quote! {}
173 };
174
175 let expanded = quote! {
176 #enum_impl
177
178 #(#variant_impls)*
179
180 #packet_impl
181 };
182
183 expanded.into()
184}