packbytes_derive/
lib.rs

1//! Derive macros for the `packbytes` crate.
2#![warn(missing_docs)]
3extern crate proc_macro;
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use syn::punctuated::Punctuated;
9use syn::token::Comma;
10use syn::{Fields, Ident, Item, ItemEnum, ItemStruct, LitInt, Meta, Type};
11
12type UnitFields = Punctuated<syn::Field, Comma>;
13
14/// Derive the `FromBytes` trait for structs where each field implements it.
15///
16/// Const generics in stable don't allow implementing `FromBytes` for arrays `[T; N]` where
17/// `T: FromBytes`. This macro circumvents that by deriving a different implementation for fields
18/// whose types are arrays, allowing the trait to be derived even for structs with such fields.
19///
20/// # Endianness
21/// By default, the `FromBytes` and `ToBytes` derive macros assume that the data is prefered to be stored
22/// in the little endian order.
23/// You can change this by setting the attribute `#[packbytes(be)]` for big endian or `#[packbytes(ne)]`
24/// for the platform native endian.
25#[proc_macro_derive(FromBytes, attributes(packbytes))]
26pub fn frombytes_derive(input: TokenStream) -> TokenStream {
27    let item: ItemStruct =
28        syn::parse(input).expect("#[derive(FromBytes)] can be only applied to structs");
29
30    let name = &item.ident;
31    let generics = &item.generics;
32    let unit_fields = UnitFields::new();
33
34    let mut prefers_le = quote!(true);
35    for attr in item.attrs.iter() {
36        if let Meta::List(ref list) = attr.meta {
37            if list.path.is_ident("packbytes") {
38                get_endianness(&list.tokens, &mut prefers_le);
39            }
40        }
41    }
42
43    let fields = match item.fields {
44        Fields::Named(fields) => fields.named.into_iter(),
45        Fields::Unnamed(fields) => fields.unnamed.into_iter(),
46        _ => unit_fields.into_iter(),
47    };
48
49    let fields = fields.enumerate().map(|(i, field)| {
50        let name = field.ident.map_or_else(
51            || {
52                let i = i.to_string();
53                let lit = LitInt::new(&i, Span::call_site());
54                quote!(#lit)
55            },
56            |n| quote!(#n),
57        );
58        (name, field.ty)
59    });
60
61    let field_sizes = fields.clone().map(|(_, ty)| {
62        if let Type::Array(arr) = ty {
63            let len = arr.len;
64            let aty = arr.elem;
65            quote! { (#len) * <<#aty as ::packbytes::FromBytes>::Bytes as ::packbytes::ByteArray>::SIZE }
66        } else {
67            quote! { <<#ty as ::packbytes::FromBytes>::Bytes as ::packbytes::ByteArray>::SIZE }
68        }
69    });
70
71    let from_fields = |method| {
72        fields.clone().map(move |(name, ty)| {
73        if let Type::Array(arr) = ty {
74            let len = arr.len;
75            let aty = arr.elem;
76            quote! {
77                #name: {
78                    let size = <<#aty as ::packbytes::FromBytes>::Bytes as ::packbytes::ByteArray>::SIZE;
79                    let val = ::core::array::from_fn(|j| {
80                        <#aty as ::packbytes::FromBytes>::#method(bytes[i+j*size..i+(j+1)*size].try_into().unwrap())
81                    });
82                    i += (#len) * size;
83                    val
84                }
85            }
86        } else {
87            quote! {
88                #name: {
89                    let size = <<#ty as ::packbytes::FromBytes>::Bytes as ::packbytes::ByteArray>::SIZE;
90                    let val = <#ty as ::packbytes::FromBytes>::#method(bytes[i..i+size].try_into().unwrap());
91                    i += size;
92                    val
93                }
94            }
95        }
96    })
97    };
98    let from_le_fields = from_fields(quote!(from_le_bytes));
99    let from_be_fields = from_fields(quote!(from_be_bytes));
100
101    let tokens = quote! {
102        impl #generics ::packbytes::FromBytes for #name #generics {
103            type Bytes = [u8; #( #field_sizes + )* 0];
104
105            const PREFERS_LE: bool = #prefers_le;
106
107            #[inline]
108            fn from_le_bytes(bytes: Self::Bytes) -> Self {
109                let mut i = 0;
110                Self { #( #from_le_fields , )* }
111            }
112
113            #[inline]
114            fn from_be_bytes(bytes: Self::Bytes) -> Self {
115                let mut i = 0;
116                Self { #( #from_be_fields , )* }
117            }
118        }
119    };
120    TokenStream::from(tokens)
121}
122
123/// Derive the `ToBytes` trait for structs where each field implements it and fieldless enums.
124///
125/// Const generics in stable don't allow implementing `ToBytes` for arrays `[T; N]` where
126/// `T: ToBytes`. This macro circumvents that by deriving a different implementation for fields
127/// whose types are arrays, allowing the trait to be derived even for structs with such fields.
128///
129/// # Endianness
130/// By default, the `FromBytes` and `ToBytes` derive macros assume that the data is prefered to be stored
131/// in the little endian order.
132/// You can change this by setting the attribute `#[packbytes(be)]` for big endian or `#[packbytes(ne)]`
133/// for the platform native endian.
134///
135/// # Fieldless enums
136/// The trait is implemented for fieldless enums by converting the numerical value (of the type
137/// set by the `repr` attribute on the enum) to bytes.
138#[proc_macro_derive(ToBytes, attributes(packbytes))]
139pub fn tobytes_derive(input: TokenStream) -> TokenStream {
140    match syn::parse::<Item>(input) {
141        Ok(Item::Struct(item)) => tobytes_struct_derive(item),
142        Ok(Item::Enum(item)) => tobytes_enum_derive(item),
143        _ => panic!("#[derive(ToBytes)] can be only applied to structs or enums"),
144    }
145}
146
147fn tobytes_struct_derive(item: ItemStruct) -> TokenStream {
148    let name = &item.ident;
149    let generics = &item.generics;
150    let unit_fields = UnitFields::new();
151
152    let mut prefers_le = quote!(true);
153    for attr in item.attrs.iter() {
154        if let Meta::List(ref list) = attr.meta {
155            if list.path.is_ident("packbytes") {
156                get_endianness(&list.tokens, &mut prefers_le);
157            }
158        }
159    }
160
161    let fields = match item.fields {
162        Fields::Named(fields) => fields.named.into_iter(),
163        Fields::Unnamed(fields) => fields.unnamed.into_iter(),
164        _ => unit_fields.into_iter(),
165    };
166
167    let fields = fields.enumerate().map(|(i, field)| {
168        let name = field.ident.map_or_else(
169            || {
170                let i = i.to_string();
171                let lit = LitInt::new(&i, Span::call_site());
172                quote!(#lit)
173            },
174            |n| quote!(#n),
175        );
176        (name, field.ty)
177    });
178
179    let field_sizes = fields.clone().map(|(_, ty)| {
180        if let Type::Array(arr) = ty {
181            let len = arr.len;
182            let aty = arr.elem;
183            quote! { (#len) * <<#aty as ::packbytes::ToBytes>::Bytes as ::packbytes::ByteArray>::SIZE }
184        } else {
185            quote! { <<#ty as ::packbytes::ToBytes>::Bytes as ::packbytes::ByteArray>::SIZE }
186        }
187    });
188
189    let to_fields = |method| {
190        fields.clone().map(move |(name, ty)| {
191        if let Type::Array(arr) = ty {
192            let len = arr.len;
193            let aty = arr.elem;
194            quote! {
195                let size = <<#aty as ::packbytes::ToBytes>::Bytes as ::packbytes::ByteArray>::SIZE;
196                for j in 0..(#len) {
197                    bytes[i+j*size..i+(j+1)*size].copy_from_slice(&<#aty as ::packbytes::ToBytes>::#method(self.#name[j]));
198                }
199                i += (#len)*size;
200            }
201        } else {
202            quote! {
203                let size = <<#ty as ::packbytes::ToBytes>::Bytes as ::packbytes::ByteArray>::SIZE;
204                bytes[i..i+size].copy_from_slice(&<#ty as ::packbytes::ToBytes>::#method(self.#name));
205                i += size;
206            }
207        }
208    })
209    };
210    let to_le_fields = to_fields(quote!(to_le_bytes));
211    let to_be_fields = to_fields(quote!(to_be_bytes));
212
213    let tokens = quote! {
214        impl #generics ::packbytes::ToBytes for #name #generics {
215            type Bytes = [u8; #( #field_sizes + )* 0];
216
217            const PREFERS_LE: bool = #prefers_le;
218
219            #[inline]
220            fn to_le_bytes(self) -> Self::Bytes {
221                let mut bytes = <Self::Bytes as ::packbytes::ByteArray>::zeroed();
222                let mut i = 0;
223                #( #to_le_fields )*
224                bytes
225            }
226
227            #[inline]
228            fn to_be_bytes(self) -> Self::Bytes {
229                let mut bytes = <Self::Bytes as ::packbytes::ByteArray>::zeroed();
230                let mut i = 0;
231                #( #to_be_fields )*
232                bytes
233            }
234        }
235    };
236    TokenStream::from(tokens)
237}
238
239fn tobytes_enum_derive(item: ItemEnum) -> TokenStream {
240    for variant in item.variants.iter() {
241        let Fields::Unit = variant.fields else {
242            panic!("#[derive(ToBytes)] can be only applied to fieldless enums");
243        };
244    }
245
246    let mut repr = quote!(u8);
247    let mut prefers_le = quote!(true);
248    for attr in item.attrs.iter() {
249        if let Meta::List(ref list) = attr.meta {
250            if list.path.is_ident("packbytes") {
251                get_endianness(&list.tokens, &mut prefers_le);
252            } else if list.path.is_ident("repr") {
253                get_numeric_type(&list.tokens, &mut repr);
254            }
255        }
256    }
257
258    let name = &item.ident;
259    let generics = &item.generics;
260
261    let tokens = quote! {
262        impl #generics ::packbytes::ToBytes for #name #generics {
263            type Bytes = [u8; #repr::BITS as usize / 8];
264
265            #[inline]
266            fn to_le_bytes(self) -> Self::Bytes {
267                (self as #repr).to_le_bytes()
268            }
269
270            #[inline]
271            fn to_be_bytes(self) -> Self::Bytes {
272                (self as #repr).to_be_bytes()
273            }
274        }
275    };
276    TokenStream::from(tokens)
277}
278
279/// Derive the `TryFromBytes` trait for structs where each field implements it and fieldless enums.
280///
281/// Const generics in stable don't allow implementing `TryFromBytes` for arrays `[T; N]` where
282/// `T: TryFromBytes`. This macro circumvents that by deriving a different implementation for fields
283/// whose types are arrays, allowing the trait to be derived even for structs with such fields.
284/// Note that before `core::array::try_from_fn` is stabilised, this is provided only for `T: FromBytes`.
285///
286/// # Endianness
287/// By default, the `FromBytes` and `ToBytes` derive macros assume that the data is prefered to be stored
288/// in the little endian order.
289/// You can change this by setting the attribute `#[packbytes(be)]` for big endian or `#[packbytes(ne)]`
290/// for the platform native endian.
291///
292/// # Fieldless enums
293/// The trait is implementing for fieldless enums by first converting bytes to a numerical value
294/// (of the type set by the `repr` attribute on the enum) and then comparing it to the values of all variants.
295///
296/// # Errors
297/// By default, the error type is `packbytes::errors::InvalidData`. You can provide a custom error
298/// type with the `packbytes_error` attribute.
299///
300/// For enums, in case the bytes don't represent a valid enum variant, a value of the erro provided by
301/// the `Default` trait (if implemented) is returned. You can overrride it by setting the
302/// `packbytes_error_exp` attribute for a custom error expression.
303///
304/// For structs, you need to make sure that your error implements `From<<T as TryFromBytes>::Error>` for every `T`
305/// which is a type of a field of the struct. In particular, the types that implement `FromBytes`
306/// have error type `std::convert::Infallible` (the conversion can never fail).
307/// Thus your error type should implement `From<std::convert::Infallible>`.
308/// (When the `!` type is stabilised, this will automatically be true for every type.
309///
310/// ```
311/// # use packbytes_derive::TryFromBytes;
312/// enum MyError {
313///     InvalidFoo,
314///     SomethingElse
315/// }
316///
317/// impl From<std::convert::Infallible> for MyError {
318///     fn from(impossible: std::convert::Infallible) -> Self {
319///         unreachable!()
320///     }
321/// }
322///
323/// #[derive(TryFromBytes)]
324/// #[packbytes_error(MyError)]
325/// #[packbytes_error_exp(MyError::InvalidFoo)]
326/// enum Foo {
327///     Bar,
328///     Baz
329/// }
330///
331/// #[derive(TryFromBytes)]
332/// #[packbytes_error(MyError)]
333/// struct MyStruct {
334///     val: u16,
335///     foo: Foo
336/// }
337/// ```
338#[proc_macro_derive(TryFromBytes, attributes(packbytes, packbytes_error, packbytes_error_exp))]
339pub fn tryfrombytes_derive(input: TokenStream) -> TokenStream {
340    match syn::parse::<Item>(input) {
341        Ok(Item::Struct(item)) => tryfrombytes_struct_derive(item),
342        Ok(Item::Enum(item)) => tryfrombytes_enum_derive(item),
343        _ => panic!("#[derive(TryFromBytes)] can be only applied to structs or enums"),
344    }
345}
346
347fn tryfrombytes_struct_derive(item: ItemStruct) -> TokenStream {
348    let name = &item.ident;
349    let generics = &item.generics;
350    let unit_fields = UnitFields::new();
351
352    let mut error = quote!(::packbytes::error::InvalidData);
353    let mut prefers_le = quote!(true);
354    for attr in item.attrs.iter() {
355        if let Meta::List(ref list) = attr.meta {
356            if list.path.is_ident("packbytes") {
357                get_endianness(&list.tokens, &mut prefers_le);
358            } else if list.path.is_ident("packbytes_error") {
359                error = list.tokens.clone();
360            }
361        }
362    }
363
364    let fields = match item.fields {
365        Fields::Named(fields) => fields.named.into_iter(),
366        Fields::Unnamed(fields) => fields.unnamed.into_iter(),
367        _ => unit_fields.into_iter(),
368    };
369
370    let fields = fields.enumerate().map(|(i, field)| {
371        let name = field.ident.map_or_else(
372            || {
373                let i = i.to_string();
374                let lit = LitInt::new(&i, Span::call_site());
375                quote!(#lit)
376            },
377            |n| quote!(#n),
378        );
379        (name, field.ty)
380    });
381
382    let field_sizes = fields.clone().map(|(_, ty)| {
383        if let Type::Array(arr) = ty {
384            let len = arr.len;
385            let aty = arr.elem;
386            quote! { (#len) * <<#aty as ::packbytes::FromBytes>::Bytes as ::packbytes::ByteArray>::SIZE }
387        } else {
388            quote! { <<#ty as ::packbytes::TryFromBytes>::Bytes as ::packbytes::ByteArray>::SIZE }
389        }
390    });
391
392    // TODO: switch to `core::array::try_from_fn` once stabilised
393    let from_fields = |method, regular_method| {
394        fields.clone().map(move |(name, ty)| {
395        if let Type::Array(arr) = ty {
396            let len = arr.len;
397            let aty = arr.elem;
398            quote! {
399                #name: {
400                    let size = <<#aty as ::packbytes::FromBytes>::Bytes as ::packbytes::ByteArray>::SIZE;
401                    let val = ::core::array::from_fn(|j| {
402                        <#aty as ::packbytes::FromBytes>::#regular_method(bytes[i+j*size..i+(j+1)*size].try_into().unwrap())
403                    });
404                    i += (#len) * size;
405                    val
406                }
407            }
408        } else {
409            quote! {
410                #name: {
411                    let size = <<#ty as ::packbytes::TryFromBytes>::Bytes as ::packbytes::ByteArray>::SIZE;
412                    let val = <#ty as ::packbytes::TryFromBytes>::#method(bytes[i..i+size].try_into().unwrap())?;
413                    i += size;
414                    val
415                }
416            }
417        }
418    })
419    };
420    let from_le_fields = from_fields(quote!(try_from_le_bytes), quote!(from_le_bytes));
421    let from_be_fields = from_fields(quote!(try_from_be_bytes), quote!(from_be_bytes));
422
423    let tokens = quote! {
424        impl #generics ::packbytes::TryFromBytes for #name #generics {
425            type Bytes = [u8; #( #field_sizes + )* 0];
426            type Error = #error;
427
428            const PREFERS_LE: bool = #prefers_le;
429
430            #[inline]
431            fn try_from_le_bytes(bytes: Self::Bytes) -> Result<Self, Self::Error> {
432                let mut i = 0;
433                Ok(Self { #( #from_le_fields , )* })
434            }
435
436            #[inline]
437            fn try_from_be_bytes(bytes: Self::Bytes) -> Result<Self, Self::Error> {
438                let mut i = 0;
439                Ok(Self { #( #from_be_fields , )* })
440            }
441        }
442    };
443    TokenStream::from(tokens)
444}
445
446fn tryfrombytes_enum_derive(item: ItemEnum) -> TokenStream {
447    for variant in item.variants.iter() {
448        let Fields::Unit = variant.fields else {
449            panic!("#[derive(TryFromBytes)] can be only applied to fieldless enums");
450        };
451    }
452
453    let mut repr = quote!(u8);
454    let mut error = quote!(::packbytes::error::InvalidData);
455    let mut error_exp = quote!(Default::default());
456    let mut prefers_le = quote!(true);
457    for attr in item.attrs.iter() {
458        if let Meta::List(ref list) = attr.meta {
459            if list.path.is_ident("packbytes") {
460                get_endianness(&list.tokens, &mut prefers_le);
461            } else if list.path.is_ident("repr") {
462                get_numeric_type(&list.tokens, &mut repr);
463            } else if list.path.is_ident("packbytes_error") {
464                error = list.tokens.clone();
465            } else if list.path.is_ident("packbytes_error_exp") {
466                error_exp = list.tokens.clone();
467            }
468        }
469    }
470
471    let name = &item.ident;
472    let generics = &item.generics;
473
474    let branches_le = item
475        .variants
476        .iter()
477        .map(|v| v.ident.clone())
478        .map(|variant| quote!(a if a == (Self :: #variant as #repr) => Ok(Self :: #variant) ,));
479    let branches_be = branches_le.clone();
480
481    let tokens = quote! {
482        impl #generics ::packbytes::TryFromBytes for #name #generics {
483            type Bytes = [u8; #repr::BITS as usize / 8];
484            type Error = #error;
485
486            const PREFERS_LE: bool = #prefers_le;
487
488            #[inline]
489            fn try_from_le_bytes(bytes: Self::Bytes) -> Result<Self, Self::Error> {
490                match #repr::from_le_bytes(bytes) {
491                    #( #branches_le )*
492                    _ => Err(#error_exp)
493                }
494            }
495
496            #[inline]
497            fn try_from_be_bytes(bytes: Self::Bytes) -> Result<Self, Self::Error> {
498                match #repr::from_be_bytes(bytes) {
499                    #( #branches_be )*
500                    _ => Err(#error_exp)
501                }
502            }
503        }
504    };
505    TokenStream::from(tokens)
506}
507
508fn get_endianness(ts: &proc_macro2::TokenStream, end: &mut proc_macro2::TokenStream) {
509    let ident = syn::parse2::<Ident>(ts.clone()).unwrap().to_string();
510    match ident.as_str() {
511        "be" => { *end = quote!(false); },
512        "ne" => { *end = quote!(cfg!(target_endian = "little")); }
513        "le" => {},
514        _ => { panic!("the valid values of the `packbytes` attribute are \"le\", \"be\" and \"ne\""); }
515    }
516}
517
518fn get_numeric_type(ts: &proc_macro2::TokenStream, repr: &mut proc_macro2::TokenStream) {
519    if let Ok(ident) = syn::parse2::<Ident>(ts.clone()) {
520        let ident = ident.to_string();
521        if matches!(ident.as_str(),
522            "u8"
523            | "u16"
524            | "u32"
525            | "u64"
526            | "u128"
527            | "usize"
528            | "i8"
529            | "i16"
530            | "i32"
531            | "i64"
532            | "i128"
533            | "isize"
534        ) {
535            *repr = ts.clone();
536        }
537    }
538}