bufferfish_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::TokenStream;
4use proc_macro_error::{abort, proc_macro_error};
5use quote::quote;
6use syn::{
7    parse_macro_input, spanned::Spanned, Data, DeriveInput, Expr, Fields, Index, Type, TypePath,
8};
9
10#[proc_macro_derive(Encode, attributes(bufferfish))]
11#[proc_macro_error]
12pub fn bufferfish_impl_encodable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13    let ast = parse_macro_input!(input as DeriveInput);
14    let name = &ast.ident;
15
16    let packet_id = get_packet_id(&ast);
17    let packet_id_snippet = {
18        if let Some(packet_id) = packet_id {
19            quote! { bf.write_u16(u16::from(#packet_id))?; }
20        } else {
21            quote! {}
22        }
23    };
24
25    let mut encoded_snippets = Vec::new();
26
27    match &ast.data {
28        Data::Struct(data) => match &data.fields {
29            Fields::Named(fields) => {
30                for field in &fields.named {
31                    let Some(ident) = field.ident.as_ref() else {
32                        abort!(field.span(), "named fields are required");
33                    };
34
35                    encode_type(quote! { self.#ident }, &field.ty, &mut encoded_snippets)
36                }
37            }
38            Fields::Unnamed(fields) => {
39                for (i, field) in fields.unnamed.iter().enumerate() {
40                    let index = Index::from(i);
41                    encode_type(quote! { self.#index }, &field.ty, &mut encoded_snippets)
42                }
43            }
44            Fields::Unit => {}
45        },
46        Data::Enum(_) => {
47            // Enums are just encoded as a u8.
48            // TODO: Support any size.
49            encoded_snippets.push(quote! {
50                bf.write_u8(*self as u8)?;
51            });
52        }
53        Data::Union(_) => abort!(ast.span(), "decoding union types is not supported"),
54    };
55
56    let gen = quote! {
57        impl bufferfish::Encodable for #name {
58            fn encode(&self, bf: &mut bufferfish::Bufferfish) -> Result<(), bufferfish::BufferfishError> {
59                #(#encoded_snippets)*
60                Ok(())
61            }
62
63            fn to_bufferfish(&self) -> Result<bufferfish::Bufferfish, bufferfish::BufferfishError> {
64                let mut bf = bufferfish::Bufferfish::new();
65                #packet_id_snippet
66                self.encode(&mut bf)?;
67
68                Ok(bf)
69            }
70        }
71    };
72
73    gen.into()
74}
75
76#[proc_macro_derive(Decode, attributes(bufferfish))]
77#[proc_macro_error]
78pub fn bufferfish_impl_decodable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
79    let ast = parse_macro_input!(input as DeriveInput);
80    let name = &ast.ident;
81
82    let packet_id = get_packet_id(&ast);
83    let packet_id_snippet = {
84        if let Some(packet_id) = packet_id {
85            quote! {
86                let packet_id = bf.read_u16()?;
87                if packet_id != u16::from(#packet_id) {
88                    return Err(bufferfish::BufferfishError::InvalidPacketId);
89                }
90            }
91        } else {
92            quote! {}
93        }
94    };
95
96    let decoded_snippets = match &ast.data {
97        Data::Struct(data) => match &data.fields {
98            Fields::Named(fields) => fields
99                .named
100                .iter()
101                .map(|field| {
102                    let ident = field.ident.as_ref().expect("named fields required");
103                    let ty = &field.ty;
104                    quote! {
105                        #ident: <#ty as bufferfish::Decodable>::decode(bf)?,
106                    }
107                })
108                .collect::<Vec<_>>(),
109            Fields::Unnamed(fields) => fields
110                .unnamed
111                .iter()
112                .map(|field| {
113                    let ty = &field.ty;
114                    quote! {
115                        <#ty as bufferfish::Decodable>::decode(bf)?,
116                    }
117                })
118                .collect::<Vec<_>>(),
119            Fields::Unit => Vec::new(),
120        },
121        Data::Enum(data_enum) => data_enum
122            .variants
123            .iter()
124            .enumerate()
125            .map(|(i, variant)| {
126                let ident = &variant.ident;
127                let idx = Index::from(i);
128                quote! {
129                    #idx => Self::#ident,
130                }
131            })
132            .collect::<Vec<_>>(),
133        Data::Union(_) => abort!(ast.span(), "unions are not supported"),
134    };
135
136    let gen = match &ast.data {
137        Data::Struct(data) => match &data.fields {
138            Fields::Named(_) => {
139                quote! {
140                    impl bufferfish::Decodable for #name {
141                        fn decode(bf: &mut bufferfish::Bufferfish) -> Result<Self, bufferfish::BufferfishError> {
142                            #packet_id_snippet
143                            Ok(Self {
144                                #(#decoded_snippets)*
145                            })
146                        }
147                    }
148                }
149            }
150            Fields::Unnamed(_) => {
151                quote! {
152                    impl bufferfish::Decodable for #name {
153                        fn decode(bf: &mut bufferfish::Bufferfish) -> Result<Self, bufferfish::BufferfishError> {
154                            #packet_id_snippet
155                            Ok(Self(
156                                #(#decoded_snippets)*
157                            ))
158                        }
159                    }
160                }
161            }
162            Fields::Unit => {
163                quote! {
164                    impl bufferfish::Decodable for #name {
165                        fn decode(bf: &mut bufferfish::Bufferfish) -> Result<Self, bufferfish::BufferfishError> {
166                            #packet_id_snippet
167                            Ok(Self)
168                        }
169                    }
170                }
171            }
172        },
173        Data::Enum(_) => {
174            quote! {
175                impl bufferfish::Decodable for #name {
176                    fn decode(bf: &mut bufferfish::Bufferfish) -> Result<Self, bufferfish::BufferfishError> {
177                        #packet_id_snippet
178                        let variant_idx = bf.read_u8()?;
179                        Ok(match variant_idx {
180                            #(#decoded_snippets)*
181                            _ => return Err(bufferfish::BufferfishError::InvalidEnumVariant),
182                        })
183                    }
184                }
185            }
186        }
187        _ => abort!(ast.span(), "only structs and enums are supported"),
188    };
189
190    gen.into()
191}
192
193fn get_packet_id(ast: &DeriveInput) -> Option<Expr> {
194    for attr in &ast.attrs {
195        if attr.path().is_ident("bufferfish") {
196            if let Ok(expr) = attr.parse_args::<syn::Expr>() {
197                return Some(expr);
198            } else {
199                abort!(attr.span(), "expected a single expression");
200            }
201        }
202    }
203
204    None
205}
206
207fn encode_type(accessor: TokenStream, ty: &Type, dst: &mut Vec<TokenStream>) {
208    match ty {
209        // Handle primitive types
210        Type::Path(TypePath { path, .. })
211            if path.is_ident("u8")
212                || path.is_ident("u16")
213                || path.is_ident("u32")
214                || path.is_ident("i8")
215                || path.is_ident("i16")
216                || path.is_ident("i32")
217                || path.is_ident("bool")
218                || path.is_ident("String") =>
219        {
220            dst.push(quote! {
221                bufferfish::Encodable::encode(&#accessor, bf)?;
222            });
223        }
224        // Handle arrays where elements impl Encodable
225        Type::Path(TypePath { path, .. })
226            if path.segments.len() == 1 && path.segments[0].ident == "Vec" =>
227        {
228            dst.push(quote! {
229                bf.write_array(&#accessor)?;
230            });
231        }
232        // Handle nested structs where fields impl Encodable
233        Type::Path(TypePath { .. }) => {
234            dst.push(quote! {
235                bufferfish::Encodable::encode(&#accessor, bf)?;
236            });
237        }
238        _ => abort!(ty.span(), "type cannot be encoded into a bufferfish"),
239    }
240}