bin_proto_derive/
lib.rs

1#![deny(
2    clippy::pedantic,
3    clippy::nursery,
4    clippy::cargo,
5    clippy::unwrap_used,
6    clippy::expect_used,
7    clippy::suspicious,
8    clippy::complexity,
9    clippy::perf,
10    clippy::style,
11    unsafe_code
12)]
13#![allow(clippy::module_name_repetitions, clippy::option_if_let_else)]
14
15#[macro_use]
16extern crate quote;
17
18mod attr;
19mod codegen;
20mod enums;
21
22use attr::{AttrKind, Attrs};
23use codegen::{
24    decode_pad, encode_pad,
25    trait_impl::{impl_trait_for, TraitImplType},
26};
27use proc_macro2::TokenStream;
28use syn::{parse_macro_input, spanned::Spanned};
29
30use crate::codegen::enums::{decode_discriminant, encode_discriminant, variant_discriminant};
31
/// Selects which half of the codec a derive invocation should generate.
///
/// The value is threaded from the two `proc_macro_derive` entry points down
/// into the struct/enum code generators.
#[derive(Copy, Clone)]
enum Operation {
    /// Generate the decoding (read) implementation.
    Decode,
    /// Generate the encoding (write) implementation.
    Encode,
}
37
38#[proc_macro_derive(BitDecode, attributes(codec))]
39pub fn decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
40    let ast: syn::DeriveInput = parse_macro_input!(input as syn::DeriveInput);
41    impl_codec(&ast, Operation::Decode).into()
42}
43
44#[proc_macro_derive(BitEncode, attributes(codec))]
45pub fn encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
46    let ast: syn::DeriveInput = parse_macro_input!(input as syn::DeriveInput);
47    impl_codec(&ast, Operation::Encode).into()
48}
49
50fn impl_codec(ast: &syn::DeriveInput, codec_type: Operation) -> TokenStream {
51    match ast.data {
52        syn::Data::Struct(ref s) => impl_for_struct(ast, s, codec_type),
53        syn::Data::Enum(ref e) => impl_for_enum(ast, e, codec_type),
54        syn::Data::Union(..) => unimplemented!("Codec is not derivable on Unions"),
55    }
56}
57
58fn impl_for_struct(
59    ast: &syn::DeriveInput,
60    strukt: &syn::DataStruct,
61    codec_type: Operation,
62) -> TokenStream {
63    let attrs = match Attrs::parse(ast.attrs.as_slice(), Some(AttrKind::Struct), ast.span()) {
64        Ok(attrs) => attrs,
65        Err(e) => return e.to_compile_error(),
66    };
67
68    let ctx_ty = attrs.ctx_ty();
69
70    let (impl_body, trait_type) = match codec_type {
71        Operation::Decode => {
72            let (decodes, initializers) = codegen::decodes(&strukt.fields);
73            let pad_before = attrs.pad_before.as_ref().map(decode_pad);
74            let pad_after = attrs.pad_after.as_ref().map(decode_pad);
75            let magic = attrs.decode_magic();
76
77            (
78                quote!(
79                    fn decode<__R, __E>(
80                        __io_reader: &mut __R,
81                        __ctx: &mut #ctx_ty,
82                        __tag: (),
83                    ) -> ::bin_proto::Result<Self>
84                    where
85                        __R: ::bin_proto::BitRead,
86                        __E: ::bin_proto::Endianness,
87                    {
88                        #pad_before
89                        #magic
90                        #decodes
91                        #pad_after
92                        ::core::result::Result::Ok(Self #initializers)
93                    }
94                ),
95                TraitImplType::Decode,
96            )
97        }
98        Operation::Encode => {
99            let encodes = codegen::encodes(&strukt.fields, true);
100            let pad_before = attrs.pad_before.as_ref().map(encode_pad);
101            let pad_after = attrs.pad_after.as_ref().map(encode_pad);
102            let magic = attrs.encode_magic();
103
104            (
105                quote!(
106                    fn encode<__W, __E>(
107                        &self,
108                        __io_writer: &mut __W,
109                        __ctx: &mut #ctx_ty,
110                        (): (),
111                    ) -> ::bin_proto::Result<()>
112                    where
113                        __W: ::bin_proto::BitWrite,
114                        __E: ::bin_proto::Endianness,
115                    {
116                        #pad_before
117                        #magic
118                        #encodes
119                        #pad_after
120                        ::core::result::Result::Ok(())
121                    }
122                ),
123                TraitImplType::Encode,
124            )
125        }
126    };
127
128    impl_trait_for(ast, &impl_body, &trait_type)
129}
130
131#[allow(clippy::too_many_lines)]
132fn impl_for_enum(ast: &syn::DeriveInput, e: &syn::DataEnum, codec_type: Operation) -> TokenStream {
133    let plan = match enums::Enum::try_new(ast, e) {
134        Ok(plan) => plan,
135        Err(e) => return e.to_compile_error(),
136    };
137    let attrs = match Attrs::parse(ast.attrs.as_slice(), Some(AttrKind::Enum), ast.span()) {
138        Ok(attrs) => attrs,
139        Err(e) => return e.to_compile_error(),
140    };
141    let discriminant_ty = &plan.discriminant_ty;
142    let ctx_ty = attrs.ctx_ty();
143
144    match codec_type {
145        Operation::Decode => {
146            let decode_variant = codegen::enums::decode_variant_fields(&plan);
147            let impl_body = quote!(
148                fn decode<__R, __E>(
149                    __io_reader: &mut __R,
150                    __ctx: &mut #ctx_ty,
151                    __tag: ::bin_proto::Tag<__Tag>,
152                ) -> ::bin_proto::Result<Self>
153                where
154                    __R: ::bin_proto::BitRead,
155                    __E: ::bin_proto::Endianness,
156                {
157                    ::core::result::Result::Ok(#decode_variant)
158                }
159            );
160            let tagged_decode_impl = impl_trait_for(
161                ast,
162                &impl_body,
163                &TraitImplType::TaggedDecode(discriminant_ty.clone()),
164            );
165
166            let decode_discriminant = decode_discriminant(&attrs);
167            let impl_body = quote!(
168                fn decode<__R, __E>(
169                    __io_reader: &mut __R,
170                    __ctx: &mut #ctx_ty,
171                    __tag: (),
172                ) -> ::bin_proto::Result<Self>
173                where
174                    __R: ::bin_proto::BitRead,
175                    __E: ::bin_proto::Endianness,
176                {
177                    let __tag: #discriminant_ty = #decode_discriminant?;
178                    <Self as ::bin_proto::BitDecode<_, ::bin_proto::Tag<#discriminant_ty>>>::decode::<_, __E>(
179                        __io_reader,
180                        __ctx,
181                        ::bin_proto::Tag(__tag)
182                    )
183                }
184            );
185            let decode_impl = impl_trait_for(ast, &impl_body, &TraitImplType::Decode);
186
187            quote!(
188                #tagged_decode_impl
189                #decode_impl
190            )
191        }
192        Operation::Encode => {
193            let encode_variant = codegen::enums::encode_variant_fields(&plan);
194            let pad_before = attrs.pad_before.as_ref().map(encode_pad);
195            let pad_after = attrs.pad_after.as_ref().map(encode_pad);
196            let impl_body = quote!(
197                fn encode<__W, __E>(
198                    &self,
199                    __io_writer: &mut __W,
200                    __ctx: &mut #ctx_ty,
201                    __tag: ::bin_proto::Untagged,
202                ) -> ::bin_proto::Result<()>
203                where
204                    __W: ::bin_proto::BitWrite,
205                    __E: ::bin_proto::Endianness,
206                {
207                    #pad_before
208                    #encode_variant
209                    #pad_after
210                    ::core::result::Result::Ok(())
211                }
212            );
213            let untagged_encode_impl =
214                impl_trait_for(ast, &impl_body, &TraitImplType::UntaggedEncode);
215
216            let variant_discriminant = variant_discriminant(&plan, &attrs);
217            let impl_body = quote!(
218                type Discriminant = #discriminant_ty;
219
220                fn discriminant(&self) -> Self::Discriminant {
221                    #variant_discriminant
222                }
223            );
224            let discriminable_impl = impl_trait_for(ast, &impl_body, &TraitImplType::Discriminable);
225
226            let encode_discriminant = encode_discriminant(&attrs);
227            let impl_body = quote!(
228                fn encode<__W, __E>(
229                    &self,
230                    __io_writer: &mut __W,
231                    __ctx: &mut #ctx_ty,
232                    (): (),
233                ) -> ::bin_proto::Result<()>
234                where
235                    __W: ::bin_proto::BitWrite,
236                    __E: ::bin_proto::Endianness,
237                {
238                    #pad_before
239                    #encode_discriminant
240                    let res = <Self as ::bin_proto::BitEncode<_, _>>::encode::<_, __E>(
241                        self,
242                        __io_writer,
243                        __ctx,
244                        ::bin_proto::Untagged
245                    )?;
246                    #pad_after
247                    ::core::result::Result::Ok(res)
248                }
249            );
250            let encode_impl = impl_trait_for(ast, &impl_body, &TraitImplType::Encode);
251
252            quote!(
253                #untagged_encode_impl
254                #discriminable_impl
255                #encode_impl
256            )
257        }
258    }
259}