Skip to main content

msgpacker_derive/
lib.rs

1#![crate_type = "proc-macro"]
2extern crate proc_macro;
3
4// This code is bad and should be refactored into something cleaner. Maybe some syn-based
5// framework?
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::punctuated::Punctuated;
10use syn::{
11    parse_macro_input, parse_quote, parse_str, Block, Data, DataEnum, DataStruct, DataUnion,
12    DeriveInput, Expr, ExprMatch, ExprTuple, Field, FieldPat, FieldValue, Fields, FieldsNamed,
13    FieldsUnnamed, GenericArgument, Generics, Ident, Member, Meta, Pat, PatIdent, PathArguments,
14    Token, Type, Variant,
15};
16
17fn contains_attribute(field: &Field, name: &str) -> bool {
18    let name = name.to_string();
19    if let Some(attr) = field.attrs.first() {
20        if let Meta::List(list) = &attr.meta {
21            if list.path.is_ident("msgpacker")
22                && list
23                    .tokens
24                    .clone()
25                    .into_iter()
26                    .find(|a| a.to_string() == name)
27                    .is_some()
28            {
29                return true;
30            }
31        }
32    }
33    false
34}
35
36fn impl_fields_named(name: Ident, f: FieldsNamed, generics: &Generics) -> impl Into<TokenStream> {
37    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
38    let mut values: Punctuated<FieldValue, Token![,]> = Punctuated::new();
39    let block_packable: Block = parse_quote! {
40        {
41            let mut n = 0;
42        }
43    };
44    let block_unpackable: Block = parse_quote! {
45        {
46            let mut n = 0;
47        }
48    };
49    let block_unpackable_iter: Block = parse_quote! {
50        {
51            let mut bytes = bytes.into_iter();
52            let mut n = 0;
53        }
54    };
55
56    let (
57        mut block_packable,
58        mut block_unpackable,
59        mut block_unpackable_iter
60    ) = f.named.into_pairs().map(|p| p.into_value()).fold(
61            (block_packable, block_unpackable, block_unpackable_iter),
62            |(mut block_packable, mut block_unpackable, mut block_unpackable_iter), field| {
63                let ident = field.ident.as_ref().cloned().unwrap();
64                let ty = field.ty.clone();
65
66                let mut is_vec = false;
67                let mut is_vec_u8 = false;
68
69                match &ty {
70                    Type::Path(p) if p.path.segments.last().filter(|p| p.ident == "Vec").is_some() => {
71                        is_vec = true;
72                        match &p.path.segments.last().unwrap().arguments {
73                            PathArguments::AngleBracketed(a) if a.args.len() == 1 => {
74                                if let Some(GenericArgument::Type(Type::Path(p))) = a.args.first() {
75                                    if p.path.segments.last().filter(|p| p.ident == "u8").is_some() {
76                                        is_vec_u8 = true;
77                                    }
78                                }
79                            }
80                            _ => (),
81                        }
82                    }
83
84                    _ => (),
85                }
86
87                if contains_attribute(&field, "map") {
88                    block_packable.stmts.push(parse_quote! {
89                        n += ::msgpacker::pack_map(buf, &self.#ident);
90                    });
91
92                    block_unpackable.stmts.push(parse_quote! {
93                        let #ident = ::msgpacker::unpack_map(buf).map(|(nv, t)| {
94                            n += nv;
95                            buf = &buf[nv..];
96                            t
97                        })?;
98                    });
99
100                    block_unpackable_iter.stmts.push(parse_quote! {
101                        let #ident = ::msgpacker::unpack_map_iter(bytes.by_ref()).map(|(nv, t)| {
102                            n += nv;
103                            t
104                        })?;
105                    });
106                } else if contains_attribute(&field, "array") || is_vec && !is_vec_u8 {
107                    block_packable.stmts.push(parse_quote! {
108                        n += ::msgpacker::pack_array(buf, &self.#ident);
109                    });
110
111                    block_unpackable.stmts.push(parse_quote! {
112                        let #ident = ::msgpacker::unpack_array(buf).map(|(nv, t)| {
113                            n += nv;
114                            buf = &buf[nv..];
115                            t
116                        })?;
117                    });
118
119                    block_unpackable_iter.stmts.push(parse_quote! {
120                        let #ident = ::msgpacker::unpack_array_iter(bytes.by_ref()).map(|(nv, t)| {
121                            n += nv;
122                            t
123                        })?;
124                    });
125                } else {
126                    block_packable.stmts.push(parse_quote! {
127                        n += <#ty as ::msgpacker::Packable>::pack(&self.#ident, buf);
128                    });
129
130                    block_unpackable.stmts.push(parse_quote! {
131                        let #ident = ::msgpacker::Unpackable::unpack_with_ofs(buf).map(|(nv, t)| {
132                            n += nv;
133                            buf = &buf[nv..];
134                            t
135                        })?;
136                    });
137
138                    block_unpackable_iter.stmts.push(parse_quote! {
139                        let #ident = ::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
140                            n += nv;
141                            t
142                        })?;
143                    });
144                }
145
146
147                values.push(FieldValue {
148                    attrs: vec![],
149                    member: Member::Named(ident.clone()),
150                    colon_token: Some(<Token![:]>::default()),
151                    expr: parse_quote! { #ident },
152                });
153
154                (block_packable, block_unpackable, block_unpackable_iter)
155            },
156        );
157
158    block_packable.stmts.push(parse_quote! {
159        return n;
160    });
161
162    block_unpackable.stmts.push(parse_quote! {
163        return Ok((
164            n,
165            Self {
166                #values
167            },
168        ));
169    });
170
171    block_unpackable_iter.stmts.push(parse_quote! {
172        return Ok((
173            n,
174            Self {
175                #values
176            },
177        ));
178    });
179
180    quote! {
181        impl #impl_generics ::msgpacker::Packable for #name #ty_generics #where_clause {
182            fn pack<T>(&self, buf: &mut T) -> usize
183            where
184                T: Extend<u8>,
185                #block_packable
186        }
187
188        impl #impl_generics ::msgpacker::Unpackable for #name #ty_generics #where_clause {
189            type Error = ::msgpacker::Error;
190
191            fn unpack_with_ofs(mut buf: &[u8]) -> Result<(usize, Self), Self::Error>
192                #block_unpackable
193
194            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
195            where
196                I: IntoIterator<Item = u8>,
197                #block_unpackable_iter
198        }
199    }
200}
201
202fn impl_fields_unnamed(
203    name: Ident,
204    f: FieldsUnnamed,
205    generics: &Generics,
206) -> impl Into<TokenStream> {
207    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
208    let mut values: Punctuated<Expr, Token![,]> = Punctuated::new();
209    let block_packable: Block = parse_quote! {
210        {
211            let mut n = 0;
212        }
213    };
214    let block_unpackable: Block = parse_quote! {
215        {
216            let mut n = 0;
217        }
218    };
219    let block_unpackable_iter: Block = parse_quote! {
220        {
221            let mut bytes = bytes.into_iter();
222            let mut n = 0;
223        }
224    };
225
226    let (mut block_packable, mut block_unpackable, mut block_unpackable_iter) = f
227        .unnamed
228        .into_pairs()
229        .map(|p| p.into_value())
230        .enumerate()
231        .fold(
232            (block_packable, block_unpackable, block_unpackable_iter),
233            |(mut block_packable, mut block_unpackable, mut block_unpackable_iter), (i, field)| {
234                let ty = field.ty.clone();
235                let var: Expr = parse_str(format!("v{}", i).as_str()).unwrap();
236                let slf: Expr = parse_str(format!("self.{}", i).as_str()).unwrap();
237
238                if contains_attribute(&field, "map") {
239                    todo!("unnamed map is not implemented for derive macro; implement the traits manually")
240                } else if contains_attribute(&field, "array") {
241                    todo!("unnamed array is not implemented for derive macro; implement the traits manually")
242                } else {
243                    block_packable.stmts.push(parse_quote! {
244                        n += <#ty as ::msgpacker::Packable>::pack(&#slf, buf);
245                    });
246
247                    block_unpackable.stmts.push(parse_quote! {
248                        let #var = ::msgpacker::Unpackable::unpack_with_ofs(buf).map(|(nv, t)| {
249                            n += nv;
250                            buf = &buf[nv..];
251                            t
252                        })?;
253                    });
254
255                    block_unpackable_iter.stmts.push(parse_quote! {
256                        let #var = ::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
257                            n += nv;
258                            t
259                        })?;
260                    });
261                }
262
263                values.push(var);
264
265                (block_packable, block_unpackable, block_unpackable_iter)
266            },
267        );
268
269    block_packable.stmts.push(parse_quote! {
270        return n;
271    });
272
273    block_unpackable.stmts.push(parse_quote! {
274        return Ok((n, Self(#values)));
275    });
276
277    block_unpackable_iter.stmts.push(parse_quote! {
278        return Ok((n, Self(#values)));
279    });
280
281    quote! {
282        impl #impl_generics ::msgpacker::Packable for #name #ty_generics #where_clause {
283            fn pack<T>(&self, buf: &mut T) -> usize
284            where
285                T: Extend<u8>,
286                #block_packable
287        }
288
289        impl #impl_generics ::msgpacker::Unpackable for #name #ty_generics #where_clause {
290            type Error = ::msgpacker::Error;
291
292            fn unpack_with_ofs(mut buf: &[u8]) -> Result<(usize, Self), Self::Error>
293                #block_unpackable
294
295            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
296            where
297                I: IntoIterator<Item = u8>,
298                #block_unpackable_iter
299        }
300    }
301}
302
303fn impl_fields_unit(name: Ident, generics: &Generics) -> impl Into<TokenStream> {
304    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
305    quote! {
306        impl #impl_generics ::msgpacker::Packable for #name #ty_generics #where_clause {
307            fn pack<T>(&self, _buf: &mut T) -> usize
308            where
309                T: Extend<u8>,
310            {
311                0
312            }
313        }
314
315        impl #impl_generics ::msgpacker::Unpackable for #name #ty_generics #where_clause {
316            type Error = ::msgpacker::Error;
317
318            fn unpack_with_ofs(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> {
319                Ok((0, Self))
320            }
321
322            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
323            where
324                I: IntoIterator<Item = u8>,
325            {
326                Ok((0, Self))
327            }
328        }
329    }
330}
331
332fn impl_fields_enum(
333    name: Ident,
334    v: Punctuated<Variant, Token![,]>,
335    generics: &Generics,
336) -> impl Into<TokenStream> {
337    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
338    if v.is_empty() {
339        todo!("empty enum is not implemented for derive macro; implement the traits manually");
340    }
341
342    let mut block_packable: ExprMatch = parse_quote! {
343        match self {
344        }
345    };
346
347    let mut block_unpackable: ExprMatch = parse_quote! {
348        match discriminant {
349        }
350    };
351
352    let mut block_unpackable_iter: ExprMatch = parse_quote! {
353        match discriminant {
354        }
355    };
356
357    v.into_iter().enumerate().for_each(|(i, v)| {
358        let discriminant = v
359            .discriminant
360            .map(|(_, d)| d)
361            .unwrap_or_else(|| parse_str(format!("{}", i).as_str()).unwrap());
362
363        // TODO check attributes of the field
364        let ident = v.ident.clone();
365        match v.fields {
366            Fields::Named(f) => {
367                let mut blk: Block = parse_str("{}").unwrap();
368                let mut blk_unpack: Block = parse_str("{}").unwrap();
369                let mut blk_unpack_iter: Block = parse_str("{}").unwrap();
370                let mut blk_unpack_fields: Punctuated<FieldValue, Token![,]> = Punctuated::new();
371
372                blk.stmts.push(parse_quote! {
373                    n += (#discriminant as u32).pack(buf);
374                });
375
376                f.named
377                    .iter()
378                    .filter_map(|n| n.ident.as_ref())
379                    .for_each(|field| {
380                        blk.stmts.push(parse_quote! {
381                            n += #field.pack(buf);
382                        });
383
384                        blk_unpack_fields.push(parse_quote! { #field });
385
386                        blk_unpack.stmts.push(parse_quote! {
387                            let #field =::msgpacker::Unpackable::unpack_with_ofs(buf).map(|(nv, t)| {
388                                n += nv;
389                                buf = &buf[nv..];
390                                t
391                            })?;
392                        });
393
394                        blk_unpack_iter.stmts.push(parse_quote! {
395                            let #field =::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
396                                n += nv;
397                                t
398                            })?;
399                        });
400                    });
401
402                let mut arm: syn::Arm = parse_quote! {
403                    #name::#ident {} => #blk,
404                };
405
406                f.named
407                    .iter()
408                    .filter_map(|n| n.ident.as_ref())
409                    .for_each(|field| {
410                        match &mut arm.pat {
411                            Pat::Struct(s) => {
412                                s.fields.push(FieldPat {
413                                    attrs: vec![],
414                                    member: Member::Named(field.clone()),
415                                    colon_token: None,
416                                    pat: Box::new(Pat::Ident(PatIdent {
417                                        attrs: vec![],
418                                        by_ref: None,
419                                        mutability: None,
420                                        ident: field.clone(),
421                                        subpat: None,
422                                    })),
423                                });
424                            }
425                            _ => todo!(
426                                "enum variant is not implemented for derive macro; implement the traits manually"
427                            ),
428                        }
429                    });
430
431                block_packable.arms.push(arm);
432
433                blk_unpack.stmts.push(parse_quote! {
434                    slf = #name::#ident { #blk_unpack_fields };
435                });
436
437                blk_unpack_iter.stmts.push(parse_quote! {
438                    slf = #name::#ident { #blk_unpack_fields };
439                });
440
441                block_unpackable.arms.push(parse_quote! {
442                    #discriminant => #blk_unpack,
443                });
444
445                block_unpackable_iter.arms.push(parse_quote! {
446                    #discriminant => #blk_unpack_iter,
447                });
448            }
449
450            Fields::Unnamed(f) => {
451                let mut blk: Block = parse_str("{}").unwrap();
452                let mut blk_unpack: Block = parse_str("{}").unwrap();
453                let mut blk_unpack_iter: Block = parse_str("{}").unwrap();
454
455                blk.stmts.push(parse_quote! {
456                    n += (#discriminant as u32).pack(buf);
457                });
458
459                let mut tuple_arm: ExprTuple = parse_str("()").unwrap();
460                f.unnamed.iter().enumerate().for_each(|(ii, _field)| {
461                    let ti: Expr = parse_str(format!("t{}", ii).as_str()).unwrap();
462                    tuple_arm.elems.push(ti.clone());
463
464                    blk.stmts.push(parse_quote! {
465                        n += #ti.pack(buf);
466                    });
467
468                    blk_unpack.stmts.push(parse_quote! {
469                        let #ti =::msgpacker::Unpackable::unpack_with_ofs(buf).map(|(nv, t)| {
470                            n += nv;
471                            buf = &buf[nv..];
472                            t
473                        })?;
474                    });
475
476                    blk_unpack_iter.stmts.push(parse_quote! {
477                        let #ti =::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
478                            n += nv;
479                            t
480                        })?;
481                    });
482                });
483
484                blk_unpack.stmts.push(parse_quote! {
485                    slf = #name::#ident #tuple_arm;
486                });
487
488                blk_unpack_iter.stmts.push(parse_quote! {
489                    slf = #name::#ident #tuple_arm;
490                });
491
492                block_packable.arms.push(parse_quote! {
493                    #name::#ident #tuple_arm => #blk,
494                });
495
496                block_unpackable.arms.push(parse_quote! {
497                    #discriminant => #blk_unpack,
498                });
499
500                block_unpackable_iter.arms.push(parse_quote! {
501                    #discriminant => #blk_unpack_iter,
502                });
503            }
504
505            Fields::Unit => {
506                block_packable.arms.push(parse_quote! {
507                    #name::#ident => {
508                        n += (#discriminant as u32).pack(buf);
509                    }
510                });
511
512                block_unpackable.arms.push(parse_quote! {
513                    #discriminant => slf = #name::#ident,
514                });
515
516                block_unpackable_iter.arms.push(parse_quote! {
517                    #discriminant => slf = #name::#ident,
518                });
519            }
520        }
521    });
522
523    block_unpackable.arms.push(parse_quote! {
524        _ => {
525            return Err(::msgpacker::Error::InvalidEnumVariant);
526        }
527    });
528
529    block_unpackable_iter.arms.push(parse_quote! {
530        _ => {
531            return Err(::msgpacker::Error::InvalidEnumVariant);
532        }
533    });
534
535    quote! {
536        impl #impl_generics ::msgpacker::Packable for #name #ty_generics #where_clause {
537            fn pack<T>(&self, buf: &mut T) -> usize
538            where
539                T: Extend<u8>,
540            {
541                let mut n = 0;
542
543                #block_packable;
544
545                return n;
546            }
547        }
548
549        impl #impl_generics ::msgpacker::Unpackable for #name #ty_generics #where_clause {
550            type Error = ::msgpacker::Error;
551
552            #[allow(unused_mut)]
553            fn unpack_with_ofs(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> {
554                let (mut n, discriminant) = u32::unpack_with_ofs(&mut buf)?;
555                buf = &buf[n..];
556                let slf;
557
558                #block_unpackable;
559
560                Ok((n, slf))
561            }
562
563            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
564            where
565                I: IntoIterator<Item = u8>,
566            {
567                let mut bytes = bytes.into_iter();
568                let (mut n, discriminant) = u32::unpack_iter(bytes.by_ref())?;
569                let slf;
570
571                #block_unpackable_iter;
572
573                Ok((n, slf))
574            }
575        }
576    }
577}
578
579#[proc_macro_derive(MsgPacker, attributes(msgpacker))]
580pub fn msg_packer(input: TokenStream) -> TokenStream {
581    let input = parse_macro_input!(input as DeriveInput);
582
583    let name = input.ident;
584    let generics = input.generics;
585    let data = input.data;
586    match data {
587        Data::Struct(DataStruct {
588            fields: Fields::Named(f),
589            ..
590        }) => impl_fields_named(name, f, &generics).into(),
591
592        Data::Struct(DataStruct {
593            fields: Fields::Unnamed(f),
594            ..
595        }) => impl_fields_unnamed(name, f, &generics).into(),
596
597        Data::Struct(DataStruct {
598            fields: Fields::Unit,
599            ..
600        }) => impl_fields_unit(name, &generics).into(),
601
602        Data::Enum(DataEnum { variants, .. }) => impl_fields_enum(name, variants, &generics).into(),
603
604        Data::Union(DataUnion { .. }) => {
605            todo!(
606                "union support is not implemented for derive macro; implement the traits manually"
607            )
608        }
609    }
610}