flat_bytes_derive/
lib.rs

1#![deny(clippy::pedantic)]
2#![allow(clippy::missing_panics_doc)]
3
4use proc_macro::TokenStream;
5use quote::format_ident;
6use quote::quote;
7use quote::ToTokens;
8use syn::parse_macro_input;
9use syn::Field;
10use syn::Fields;
11use syn::ItemEnum;
12use syn::ItemStruct;
13
14#[proc_macro_derive(Flat)]
15pub fn derive_flat(input: TokenStream) -> TokenStream {
16    #![allow(clippy::similar_names)]
17
18    let input = parse_macro_input!(input as ItemStruct);
19
20    let ident = &input.ident;
21
22    let fields: Vec<Field> = match input.fields {
23        Fields::Named(ref n) => n.named.iter().cloned().collect(),
24        Fields::Unnamed(ref un) => un.unnamed.iter().cloned().collect(),
25        Fields::Unit => vec![],
26    };
27
28    let fields_ser = fields.iter().enumerate().map(|(idx, f)| {
29        let ty = &f.ty;
30        if let Some(i) = &f.ident {
31            quote! {
32                res.append(&mut <#ty as Flat>::serialize(&self.#i));
33            }
34        } else {
35            let idx = syn::Index::from(idx);
36            quote! {
37                res.append(&mut <#ty as Flat>::serialize(&self.#idx));
38            }
39        }
40    });
41
42    let fields_der = fields.iter().enumerate().map(|(idx, f)| {
43        let ty = &f.ty;
44        if let Some(i) = &f.ident {
45            quote! {
46                let #i = <#ty as flat_bytes::Flat>::deserialize_with_size(data)?;
47                total += #i.1;
48                let data = &data[#i.1..];
49                let #i = #i.0;
50            }
51        } else {
52            let i = format_ident!("field{}", idx);
53            quote! {
54                let #i = <#ty as flat_bytes::Flat>::deserialize_with_size(data)?;
55                total += #i.1;
56                let data = &data[#i.1..];
57                let #i = #i.0;
58            }
59        }
60    });
61
62    let alloc = match input.fields {
63        Fields::Named(ref n) => {
64            let names = n.named.iter().map(|f| f.ident.as_ref().unwrap());
65            quote! {
66                #ident{#(#names),*}
67            }
68        }
69        Fields::Unnamed(ref un) => {
70            let names = (0..un.unnamed.len()).map(|i| format_ident!("field{}", i));
71            quote! {
72                #ident(#(#names),*)
73            }
74        }
75        Fields::Unit => ident.to_token_stream(),
76    };
77
78    let output = quote! {
79      impl flat_bytes::Flat for #ident {
80        fn deserialize_with_size(data: &[u8]) -> Option<(Self, usize)> {
81            let mut total = 0;
82            #(#fields_der)*
83            Some((#alloc, total))
84        }
85
86        fn serialize(&self) -> Vec<u8> {
87            use flat_bytes::Flat;
88            let mut res = vec![];
89            #(#fields_ser;)*
90            res
91        }
92      }
93    };
94    output.into()
95}
96
97fn derive_serialize(input: &ItemEnum, dtype: &syn::Path) -> proc_macro2::TokenStream {
98    let mut last_idx = 0;
99    let match_arms = input.variants.iter().map(|v| {
100        let i = v.ident.clone();
101        let d = v
102            .discriminant
103            .as_ref()
104            .and_then(|(_, e)| match e {
105                syn::Expr::Lit(syn::ExprLit {
106                    lit: syn::Lit::Int(i),
107                    ..
108                }) => i.base10_parse::<u64>().ok(),
109                _ => None,
110            })
111            .unwrap_or(last_idx + 1);
112        last_idx = d;
113        match &v.fields {
114            syn::Fields::Unit => quote! {
115              Self::#i => {
116                let i = #d as #dtype;
117                res.extend_from_slice(&i.to_le_bytes());
118              }
119            },
120            syn::Fields::Unnamed(fu) => {
121                let fields = fu
122                    .unnamed
123                    .iter()
124                    .enumerate()
125                    .map(|(i, f)| {
126                        let ty = &f.ty;
127                        let i = format_ident!("field{}", i);
128                        let t = quote! {
129                            &mut <#ty as Flat>::serialize(#i)
130                        };
131                        (i, t)
132                    })
133                    .collect::<Vec<_>>();
134                let (names, fields): (Vec<_>, Vec<_>) = fields.iter().cloned().unzip();
135                quote! {
136                  Self::#i(#(#names),*) => {
137                    let i = #d as #dtype;
138                    res.extend_from_slice(&i.to_le_bytes());
139                    #(
140                      res.append(#fields);
141                    )*
142                  }
143                }
144            }
145            syn::Fields::Named(fs) => {
146                let fields = fs
147                    .named
148                    .iter()
149                    .map(|f| {
150                        let ty = &f.ty;
151                        let i = f.ident.as_ref().unwrap();
152                        (
153                            i,
154                            quote! {
155                                &mut <#ty as Flat>::serialize(#i)
156                            },
157                        )
158                    })
159                    .collect::<Vec<_>>();
160                let (names, fields): (Vec<_>, Vec<_>) = fields.iter().cloned().unzip();
161                quote! {
162                  Self::#i{#(#names),*} => {
163                    let i = #d as #dtype;
164                    res.extend_from_slice(&i.to_le_bytes());
165                    #(
166                      res.append(#fields);
167                    )*
168                  }
169                }
170            }
171        }
172    });
173
174    quote! {
175      let mut res: Vec<u8> = vec![];
176      match self {
177        #(#match_arms),*
178      }
179      res
180    }
181}
182
183fn derive_deserialize(input: &ItemEnum, dtype: &syn::Path) -> proc_macro2::TokenStream {
184    let ident = &input.ident;
185    let mut last_idx = 0;
186    let match_arms = input.variants.iter().map(|v| {
187        let i = v.ident.clone();
188        let d = v
189            .discriminant
190            .as_ref()
191            .and_then(|(_, e)| match e {
192                syn::Expr::Lit(syn::ExprLit {
193                    lit: syn::Lit::Int(i),
194                    ..
195                }) => i.base10_parse::<u64>().ok(),
196                _ => None,
197            })
198            .unwrap_or(last_idx + 1);
199        last_idx = d;
200        match &v.fields {
201            syn::Fields::Unit => quote! {
202              #d => {
203                Some((#ident::#i, total))
204              }
205            },
206            syn::Fields::Unnamed(fu) => {
207                let fields = fu
208                    .unnamed
209                    .iter()
210                    .enumerate()
211                    .map(|(i, f)| {
212                        let name = quote::format_ident!("field{}", i);
213                        let ty = &f.ty;
214                        quote! {
215                          let #name = #ty::deserialize_with_size(data)?;
216                          let data = &data[#name.1..];
217                          total += #name.1;
218                          let #name = #name.0;
219                        }
220                    })
221                    .collect::<Vec<_>>();
222                let field_names = fu
223                    .unnamed
224                    .iter()
225                    .enumerate()
226                    .map(|(i, _f)| quote::format_ident!("field{}", i))
227                    .collect::<Vec<_>>();
228                quote! {
229                  #d => {
230                    #(
231                      #fields
232                    )*
233                    Some((#ident::#i(#(#field_names),*), total))
234                  }
235                }
236            }
237            syn::Fields::Named(fs) => {
238                let fields = fs
239                    .named
240                    .iter()
241                    .map(|f| {
242                        let name = f.ident.clone().unwrap();
243                        let ty = &f.ty;
244                        quote! {
245                          let #name = #ty::deserialize_with_size(data)?;
246                          let data = &data[#name.1..];
247                          total += #name.1;
248                          let #name = #name.0;
249                        }
250                    })
251                    .collect::<Vec<_>>();
252                let field_names = fs
253                    .named
254                    .iter()
255                    .map(|f| f.ident.clone().unwrap())
256                    .collect::<Vec<_>>();
257                quote! {
258                  #d => {
259                    #(
260                      #fields
261                    )*
262                    Some((#ident::#i{#(#field_names),*}, total))
263                  }
264                }
265            }
266        }
267    });
268
269    quote! {
270      if data.len() < ::std::mem::size_of::<#dtype>() {
271        return None
272      }
273      let idx = {
274        let mut tmp = [0u8; ::std::mem::size_of::<#dtype>()];
275        tmp.copy_from_slice(&data[..::std::mem::size_of::<#dtype>()]);
276        #dtype::from_le_bytes(tmp) as u64
277      };
278      let data = &data[::std::mem::size_of::<#dtype>()..];
279      let mut total = ::std::mem::size_of::<#dtype>();
280
281      match idx {
282        #(#match_arms,)*
283        _ => None,
284      }
285    }
286}
287
288#[proc_macro]
289pub fn flat_enum(input: TokenStream) -> TokenStream {
290    let input = parse_macro_input!(input as ItemEnum);
291    let mut enum_output = input.clone();
292    for v in enum_output.variants.iter_mut() {
293        v.discriminant = None;
294    }
295
296    let ident = &input.ident;
297    let dtype = input
298        .attrs
299        .iter()
300        .flat_map(syn::Attribute::parse_meta)
301        .find_map(|m| {
302            if !m.path().is_ident("repr") {
303                return None;
304            }
305            match m {
306                syn::Meta::List(l) => match l.nested.first() {
307                    Some(syn::NestedMeta::Meta(m)) => Some(m.path().clone()),
308                    _ => None,
309                },
310                _ => None,
311            }
312        })
313        .unwrap();
314
315    let serialize = derive_serialize(&input, &dtype);
316    let deserialize = derive_deserialize(&input, &dtype);
317
318    (quote! {
319      #enum_output
320
321      impl flat_bytes::Flat for #ident {
322        fn deserialize_with_size(data: &[u8]) -> Option<(Self, usize)> {
323          #deserialize
324        }
325
326        fn serialize(&self) -> Vec<u8> {
327          use flat_bytes::Flat;
328          #serialize
329        }
330      }
331    })
332    .into()
333}