Skip to main content

bin_proto_derive/
lib.rs

1#![deny(
2    clippy::pedantic,
3    clippy::nursery,
4    clippy::cargo,
5    clippy::unwrap_used,
6    clippy::expect_used,
7    clippy::suspicious,
8    clippy::complexity,
9    clippy::perf,
10    clippy::style,
11    unsafe_code
12)]
13#![allow(clippy::module_name_repetitions, clippy::option_if_let_else)]
14
15#[macro_use]
16extern crate quote;
17
18mod attr;
19mod codegen;
20mod enums;
21
22use attr::{AttrKind, Attrs};
23use codegen::{
24    decode_pad, encode_pad,
25    trait_impl::{impl_trait_for, TraitImplType},
26};
27use proc_macro2::TokenStream;
28use syn::{parse_macro_input, spanned::Spanned, Error, Result};
29
30use crate::codegen::enums::{decode_discriminant, encode_discriminant, variant_discriminant};
31
/// Selects which half of the codec a derive invocation should generate.
#[derive(Copy, Clone)]
enum Operation {
    /// Generate the decoding implementation (used by the `BitDecode` derive).
    Decode,
    /// Generate the encoding implementation (used by the `BitEncode` derive).
    Encode,
}
37
38#[proc_macro_derive(BitDecode, attributes(bin_proto))]
39pub fn decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
40    let ast: syn::DeriveInput = parse_macro_input!(input as syn::DeriveInput);
41    match impl_codec(&ast, Operation::Decode) {
42        Ok(tokens) => tokens,
43        Err(e) => e.to_compile_error(),
44    }
45    .into()
46}
47
48#[proc_macro_derive(BitEncode, attributes(bin_proto))]
49pub fn encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
50    let ast: syn::DeriveInput = parse_macro_input!(input as syn::DeriveInput);
51    match impl_codec(&ast, Operation::Encode) {
52        Ok(tokens) => tokens,
53        Err(e) => e.to_compile_error(),
54    }
55    .into()
56}
57
58fn impl_codec(ast: &syn::DeriveInput, codec_type: Operation) -> Result<TokenStream> {
59    match ast.data {
60        syn::Data::Struct(ref s) => impl_for_struct(ast, s, codec_type),
61        syn::Data::Enum(ref e) => impl_for_enum(ast, e, codec_type),
62        syn::Data::Union(..) => Err(Error::new(
63            ast.span(),
64            "bin-proto traits are not derivable on unions",
65        )),
66    }
67}
68
69fn impl_for_struct(
70    ast: &syn::DeriveInput,
71    strukt: &syn::DataStruct,
72    codec_type: Operation,
73) -> Result<TokenStream> {
74    let attrs = Attrs::parse(ast.attrs.as_slice(), Some(AttrKind::Struct), ast.span())?;
75    let crate_path = attrs.crate_path();
76    let ctx_ty = attrs.ctx_ty();
77
78    let (impl_body, trait_type) = match codec_type {
79        Operation::Decode => {
80            let (decodes, initializers) = codegen::decodes(&crate_path, &strukt.fields)?;
81            let pad_before = attrs
82                .pad_before
83                .as_ref()
84                .map(|pad| decode_pad(&crate_path, pad));
85            let pad_after = attrs
86                .pad_after
87                .as_ref()
88                .map(|pad| decode_pad(&crate_path, pad));
89            let magic = attrs.decode_magic();
90
91            (
92                quote!(
93                    fn decode<__R, __E>(
94                        __io_reader: &mut __R,
95                        __ctx: &mut #ctx_ty,
96                        __tag: (),
97                    ) -> #crate_path::Result<Self>
98                    where
99                        __R: #crate_path::BitRead,
100                        __E: #crate_path::Endianness,
101                    {
102                        #pad_before
103                        #magic
104                        #decodes
105                        #pad_after
106                        ::core::result::Result::Ok(Self #initializers)
107                    }
108                ),
109                TraitImplType::Decode,
110            )
111        }
112        Operation::Encode => {
113            let encodes = codegen::encodes(&crate_path, &strukt.fields, true)?;
114            let pad_before = attrs
115                .pad_before
116                .as_ref()
117                .map(|pad| encode_pad(&crate_path, pad));
118            let pad_after = attrs
119                .pad_after
120                .as_ref()
121                .map(|pad| encode_pad(&crate_path, pad));
122            let magic = attrs.encode_magic();
123
124            (
125                quote!(
126                    fn encode<__W, __E>(
127                        &self,
128                        __io_writer: &mut __W,
129                        __ctx: &mut #ctx_ty,
130                        (): (),
131                    ) -> #crate_path::Result<()>
132                    where
133                        __W: #crate_path::BitWrite,
134                        __E: #crate_path::Endianness,
135                    {
136                        #pad_before
137                        #magic
138                        #encodes
139                        #pad_after
140                        ::core::result::Result::Ok(())
141                    }
142                ),
143                TraitImplType::Encode,
144            )
145        }
146    };
147
148    impl_trait_for(ast, &impl_body, &trait_type)
149}
150
151#[allow(clippy::too_many_lines)]
152fn impl_for_enum(
153    ast: &syn::DeriveInput,
154    e: &syn::DataEnum,
155    codec_type: Operation,
156) -> Result<TokenStream> {
157    let plan = enums::Enum::try_new(ast, e)?;
158    let attrs = Attrs::parse(ast.attrs.as_slice(), Some(AttrKind::Enum), ast.span())?;
159    let crate_path = attrs.crate_path();
160    let discriminant_ty = &plan.discriminant_ty;
161    let ctx_ty = attrs.ctx_ty();
162
163    Ok(match codec_type {
164        Operation::Decode => {
165            let decode_variant = codegen::enums::decode_variant_fields(&plan)?;
166            let impl_body = quote!(
167                fn decode<__R, __E>(
168                    __io_reader: &mut __R,
169                    __ctx: &mut #ctx_ty,
170                    __tag: #crate_path::Tag<__Tag>,
171                ) -> #crate_path::Result<Self>
172                where
173                    __R: #crate_path::BitRead,
174                    __E: #crate_path::Endianness,
175                {
176                    ::core::result::Result::Ok(#decode_variant)
177                }
178            );
179            let tagged_decode_impl = impl_trait_for(
180                ast,
181                &impl_body,
182                &TraitImplType::TaggedDecode(discriminant_ty.clone()),
183            )?;
184
185            let decode_discriminant = decode_discriminant(&attrs);
186            let impl_body = quote!(
187                fn decode<__R, __E>(
188                    __io_reader: &mut __R,
189                    __ctx: &mut #ctx_ty,
190                    __tag: (),
191                ) -> #crate_path::Result<Self>
192                where
193                    __R: #crate_path::BitRead,
194                    __E: #crate_path::Endianness,
195                {
196                    let __tag: #discriminant_ty = #decode_discriminant?;
197                    <Self as #crate_path::BitDecode<_, #crate_path::Tag<#discriminant_ty>>>::decode::<_, __E>(
198                        __io_reader,
199                        __ctx,
200                        #crate_path::Tag(__tag)
201                    )
202                }
203            );
204            let decode_impl = impl_trait_for(ast, &impl_body, &TraitImplType::Decode)?;
205
206            quote!(
207                #tagged_decode_impl
208                #decode_impl
209            )
210        }
211        Operation::Encode => {
212            let encode_variant = codegen::enums::encode_variant_fields(&plan)?;
213            let pad_before = attrs
214                .pad_before
215                .as_ref()
216                .map(|pad| encode_pad(&crate_path, pad));
217            let pad_after = attrs
218                .pad_after
219                .as_ref()
220                .map(|pad| encode_pad(&crate_path, pad));
221            let impl_body = quote!(
222                fn encode<__W, __E>(
223                    &self,
224                    __io_writer: &mut __W,
225                    __ctx: &mut #ctx_ty,
226                    __tag: #crate_path::Untagged,
227                ) -> #crate_path::Result<()>
228                where
229                    __W: #crate_path::BitWrite,
230                    __E: #crate_path::Endianness,
231                {
232                    #pad_before
233                    #encode_variant
234                    #pad_after
235                    ::core::result::Result::Ok(())
236                }
237            );
238            let untagged_encode_impl =
239                impl_trait_for(ast, &impl_body, &TraitImplType::UntaggedEncode)?;
240
241            let variant_discriminant = variant_discriminant(&plan)?;
242            let impl_body = quote!(
243                type Discriminant = #discriminant_ty;
244
245                fn discriminant(&self) -> ::core::option::Option<Self::Discriminant> {
246                    #variant_discriminant
247                }
248            );
249            let discriminable_impl =
250                impl_trait_for(ast, &impl_body, &TraitImplType::Discriminable)?;
251
252            let encode_discriminant = encode_discriminant(&attrs);
253            let impl_body = quote!(
254                fn encode<__W, __E>(
255                    &self,
256                    __io_writer: &mut __W,
257                    __ctx: &mut #ctx_ty,
258                    (): (),
259                ) -> #crate_path::Result<()>
260                where
261                    __W: #crate_path::BitWrite,
262                    __E: #crate_path::Endianness,
263                {
264                    #pad_before
265                    #encode_discriminant
266                    let res = <Self as #crate_path::BitEncode<_, _>>::encode::<_, __E>(
267                        self,
268                        __io_writer,
269                        __ctx,
270                        #crate_path::Untagged
271                    )?;
272                    #pad_after
273                    ::core::result::Result::Ok(res)
274                }
275            );
276            let encode_impl = impl_trait_for(ast, &impl_body, &TraitImplType::Encode)?;
277
278            quote!(
279                #untagged_encode_impl
280                #discriminable_impl
281                #encode_impl
282            )
283        }
284    })
285}