bin_proto_derive/
lib.rs

1#![deny(
2    clippy::pedantic,
3    clippy::nursery,
4    clippy::cargo,
5    clippy::unwrap_used,
6    clippy::expect_used,
7    clippy::suspicious,
8    clippy::complexity,
9    clippy::perf,
10    clippy::style,
11    unsafe_code
12)]
13#![allow(clippy::module_name_repetitions, clippy::option_if_let_else)]
14
15#[macro_use]
16extern crate quote;
17
18mod attr;
19mod codegen;
20mod enums;
21
22use attr::{AttrKind, Attrs};
23use codegen::{
24    decode_pad, encode_pad,
25    trait_impl::{impl_trait_for, TraitImplType},
26};
27use proc_macro2::TokenStream;
28use syn::{parse_macro_input, spanned::Spanned, Error, Result};
29
30use crate::codegen::enums::{decode_discriminant, encode_discriminant, variant_discriminant};
31
/// Selects which half of the codec a derive invocation generates.
#[derive(Copy, Clone)]
enum Operation {
    /// Generate the reading side (requested by `#[derive(BitDecode)]`).
    Decode,
    /// Generate the writing side (requested by `#[derive(BitEncode)]`).
    Encode,
}
37
38#[proc_macro_derive(BitDecode, attributes(bin_proto))]
39pub fn decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
40    let ast: syn::DeriveInput = parse_macro_input!(input as syn::DeriveInput);
41    match impl_codec(&ast, Operation::Decode) {
42        Ok(tokens) => tokens,
43        Err(e) => e.to_compile_error(),
44    }
45    .into()
46}
47
48#[proc_macro_derive(BitEncode, attributes(bin_proto))]
49pub fn encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
50    let ast: syn::DeriveInput = parse_macro_input!(input as syn::DeriveInput);
51    match impl_codec(&ast, Operation::Encode) {
52        Ok(tokens) => tokens,
53        Err(e) => e.to_compile_error(),
54    }
55    .into()
56}
57
58fn impl_codec(ast: &syn::DeriveInput, codec_type: Operation) -> Result<TokenStream> {
59    match ast.data {
60        syn::Data::Struct(ref s) => impl_for_struct(ast, s, codec_type),
61        syn::Data::Enum(ref e) => impl_for_enum(ast, e, codec_type),
62        syn::Data::Union(..) => Err(Error::new(
63            ast.span(),
64            "bin-proto traits are not derivable on unions",
65        )),
66    }
67}
68
69fn impl_for_struct(
70    ast: &syn::DeriveInput,
71    strukt: &syn::DataStruct,
72    codec_type: Operation,
73) -> Result<TokenStream> {
74    let attrs = Attrs::parse(ast.attrs.as_slice(), Some(AttrKind::Struct), ast.span())?;
75
76    let ctx_ty = attrs.ctx_ty();
77
78    let (impl_body, trait_type) = match codec_type {
79        Operation::Decode => {
80            let (decodes, initializers) = codegen::decodes(&strukt.fields)?;
81            let pad_before = attrs.pad_before.as_ref().map(decode_pad);
82            let pad_after = attrs.pad_after.as_ref().map(decode_pad);
83            let magic = attrs.decode_magic();
84
85            (
86                quote!(
87                    fn decode<__R, __E>(
88                        __io_reader: &mut __R,
89                        __ctx: &mut #ctx_ty,
90                        __tag: (),
91                    ) -> ::bin_proto::Result<Self>
92                    where
93                        __R: ::bin_proto::BitRead,
94                        __E: ::bin_proto::Endianness,
95                    {
96                        #pad_before
97                        #magic
98                        #decodes
99                        #pad_after
100                        ::core::result::Result::Ok(Self #initializers)
101                    }
102                ),
103                TraitImplType::Decode,
104            )
105        }
106        Operation::Encode => {
107            let encodes = codegen::encodes(&strukt.fields, true)?;
108            let pad_before = attrs.pad_before.as_ref().map(encode_pad);
109            let pad_after = attrs.pad_after.as_ref().map(encode_pad);
110            let magic = attrs.encode_magic();
111
112            (
113                quote!(
114                    fn encode<__W, __E>(
115                        &self,
116                        __io_writer: &mut __W,
117                        __ctx: &mut #ctx_ty,
118                        (): (),
119                    ) -> ::bin_proto::Result<()>
120                    where
121                        __W: ::bin_proto::BitWrite,
122                        __E: ::bin_proto::Endianness,
123                    {
124                        #pad_before
125                        #magic
126                        #encodes
127                        #pad_after
128                        ::core::result::Result::Ok(())
129                    }
130                ),
131                TraitImplType::Encode,
132            )
133        }
134    };
135
136    impl_trait_for(ast, &impl_body, &trait_type)
137}
138
139#[allow(clippy::too_many_lines)]
140fn impl_for_enum(
141    ast: &syn::DeriveInput,
142    e: &syn::DataEnum,
143    codec_type: Operation,
144) -> Result<TokenStream> {
145    let plan = enums::Enum::try_new(ast, e)?;
146    let attrs = Attrs::parse(ast.attrs.as_slice(), Some(AttrKind::Enum), ast.span())?;
147    let discriminant_ty = &plan.discriminant_ty;
148    let ctx_ty = attrs.ctx_ty();
149
150    Ok(match codec_type {
151        Operation::Decode => {
152            let decode_variant = codegen::enums::decode_variant_fields(&plan)?;
153            let impl_body = quote!(
154                fn decode<__R, __E>(
155                    __io_reader: &mut __R,
156                    __ctx: &mut #ctx_ty,
157                    __tag: ::bin_proto::Tag<__Tag>,
158                ) -> ::bin_proto::Result<Self>
159                where
160                    __R: ::bin_proto::BitRead,
161                    __E: ::bin_proto::Endianness,
162                {
163                    ::core::result::Result::Ok(#decode_variant)
164                }
165            );
166            let tagged_decode_impl = impl_trait_for(
167                ast,
168                &impl_body,
169                &TraitImplType::TaggedDecode(discriminant_ty.clone()),
170            )?;
171
172            let decode_discriminant = decode_discriminant(&attrs);
173            let impl_body = quote!(
174                fn decode<__R, __E>(
175                    __io_reader: &mut __R,
176                    __ctx: &mut #ctx_ty,
177                    __tag: (),
178                ) -> ::bin_proto::Result<Self>
179                where
180                    __R: ::bin_proto::BitRead,
181                    __E: ::bin_proto::Endianness,
182                {
183                    let __tag: #discriminant_ty = #decode_discriminant?;
184                    <Self as ::bin_proto::BitDecode<_, ::bin_proto::Tag<#discriminant_ty>>>::decode::<_, __E>(
185                        __io_reader,
186                        __ctx,
187                        ::bin_proto::Tag(__tag)
188                    )
189                }
190            );
191            let decode_impl = impl_trait_for(ast, &impl_body, &TraitImplType::Decode)?;
192
193            quote!(
194                #tagged_decode_impl
195                #decode_impl
196            )
197        }
198        Operation::Encode => {
199            let encode_variant = codegen::enums::encode_variant_fields(&plan)?;
200            let pad_before = attrs.pad_before.as_ref().map(encode_pad);
201            let pad_after = attrs.pad_after.as_ref().map(encode_pad);
202            let impl_body = quote!(
203                fn encode<__W, __E>(
204                    &self,
205                    __io_writer: &mut __W,
206                    __ctx: &mut #ctx_ty,
207                    __tag: ::bin_proto::Untagged,
208                ) -> ::bin_proto::Result<()>
209                where
210                    __W: ::bin_proto::BitWrite,
211                    __E: ::bin_proto::Endianness,
212                {
213                    #pad_before
214                    #encode_variant
215                    #pad_after
216                    ::core::result::Result::Ok(())
217                }
218            );
219            let untagged_encode_impl =
220                impl_trait_for(ast, &impl_body, &TraitImplType::UntaggedEncode)?;
221
222            let variant_discriminant = variant_discriminant(&plan)?;
223            let impl_body = quote!(
224                type Discriminant = #discriminant_ty;
225
226                fn discriminant(&self) -> ::core::option::Option<Self::Discriminant> {
227                    #variant_discriminant
228                }
229            );
230            let discriminable_impl =
231                impl_trait_for(ast, &impl_body, &TraitImplType::Discriminable)?;
232
233            let encode_discriminant = encode_discriminant(&attrs);
234            let impl_body = quote!(
235                fn encode<__W, __E>(
236                    &self,
237                    __io_writer: &mut __W,
238                    __ctx: &mut #ctx_ty,
239                    (): (),
240                ) -> ::bin_proto::Result<()>
241                where
242                    __W: ::bin_proto::BitWrite,
243                    __E: ::bin_proto::Endianness,
244                {
245                    #pad_before
246                    #encode_discriminant
247                    let res = <Self as ::bin_proto::BitEncode<_, _>>::encode::<_, __E>(
248                        self,
249                        __io_writer,
250                        __ctx,
251                        ::bin_proto::Untagged
252                    )?;
253                    #pad_after
254                    ::core::result::Result::Ok(res)
255                }
256            );
257            let encode_impl = impl_trait_for(ast, &impl_body, &TraitImplType::Encode)?;
258
259            quote!(
260                #untagged_encode_impl
261                #discriminable_impl
262                #encode_impl
263            )
264        }
265    })
266}