Skip to main content

msgpacker_derive/
lib.rs

1#![crate_type = "proc-macro"]
2extern crate proc_macro;
3
4// This code is bad and should be refactored into something cleaner. Maybe some syn-based
5// framework?
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::punctuated::Punctuated;
10use syn::{
11    parse_macro_input, parse_quote, parse_str, Block, Data, DataEnum, DataStruct, DataUnion,
12    DeriveInput, Expr, ExprMatch, ExprTuple, Field, FieldPat, FieldValue, Fields, FieldsNamed,
13    FieldsUnnamed, GenericArgument, Ident, Member, Meta, Pat, PatIdent, PathArguments, Token, Type,
14    Variant,
15};
16
17fn contains_attribute(field: &Field, name: &str) -> bool {
18    let name = name.to_string();
19    if let Some(attr) = field.attrs.first() {
20        if let Meta::List(list) = &attr.meta {
21            if list.path.is_ident("msgpacker")
22                && list
23                    .tokens
24                    .clone()
25                    .into_iter()
26                    .find(|a| a.to_string() == name)
27                    .is_some()
28            {
29                return true;
30            }
31        }
32    }
33    false
34}
35
36fn impl_fields_named(name: Ident, f: FieldsNamed) -> impl Into<TokenStream> {
37    let mut values: Punctuated<FieldValue, Token![,]> = Punctuated::new();
38    let block_packable: Block = parse_quote! {
39        {
40            let mut n = 0;
41        }
42    };
43    let block_unpackable: Block = parse_quote! {
44        {
45            let mut n = 0;
46        }
47    };
48    let block_unpackable_iter: Block = parse_quote! {
49        {
50            let mut bytes = bytes.into_iter();
51            let mut n = 0;
52        }
53    };
54
55    let (
56        mut block_packable,
57        mut block_unpackable,
58        mut block_unpackable_iter
59    ) = f.named.into_pairs().map(|p| p.into_value()).fold(
60            (block_packable, block_unpackable, block_unpackable_iter),
61            |(mut block_packable, mut block_unpackable, mut block_unpackable_iter), field| {
62                let ident = field.ident.as_ref().cloned().unwrap();
63                let ty = field.ty.clone();
64
65                let mut is_vec = false;
66                let mut is_vec_u8 = false;
67
68                match &ty {
69                    Type::Path(p) if p.path.segments.last().filter(|p| p.ident == "Vec").is_some() => {
70                        is_vec = true;
71                        match &p.path.segments.last().unwrap().arguments {
72                            PathArguments::AngleBracketed(a) if a.args.len() == 1 => {
73                                if let Some(GenericArgument::Type(Type::Path(p))) = a.args.first() {
74                                    if p.path.segments.last().filter(|p| p.ident == "u8").is_some() {
75                                        is_vec_u8 = true;
76                                    }
77                                }
78                            }
79                            _ => (),
80                        }
81                    }
82
83                    _ => (),
84                }
85
86                if contains_attribute(&field, "map") {
87                    block_packable.stmts.push(parse_quote! {
88                        {
89                        use ::msgpacker::wrap::*;
90                        n += (&&::msgpacker::wrap::ShadowWrap(&self.#ident)).pack_wrap(buf);
91                        }
92                        n += ::msgpacker::pack_map(buf, &self.#ident);
93                    });
94
95                    block_unpackable.stmts.push(parse_quote! {
96                        let #ident = ::msgpacker::unpack_map(buf).map(|(nv, t)| {
97                            n += nv;
98                            buf = &buf[nv..];
99                            t
100                        })?;
101                    });
102
103                    block_unpackable_iter.stmts.push(parse_quote! {
104                        let #ident = ::msgpacker::unpack_map_iter(bytes.by_ref()).map(|(nv, t)| {
105                            n += nv;
106                            t
107                        })?;
108                    });
109                } else if contains_attribute(&field, "array") || is_vec && !is_vec_u8 {
110                    block_packable.stmts.push(parse_quote! {
111                        n += ::msgpacker::pack_array(buf, &self.#ident);
112                    });
113
114                    block_unpackable.stmts.push(parse_quote! {
115                        let #ident = ::msgpacker::unpack_array(buf).map(|(nv, t)| {
116                            n += nv;
117                            buf = &buf[nv..];
118                            t
119                        })?;
120                    });
121
122                    block_unpackable_iter.stmts.push(parse_quote! {
123                        let #ident = ::msgpacker::unpack_array_iter(bytes.by_ref()).map(|(nv, t)| {
124                            n += nv;
125                            t
126                        })?;
127                    });
128                } else {
129                    block_packable.stmts.push(parse_quote! {
130                        {
131                        use ::msgpacker::wrap::*;
132                        n += (&&::msgpacker::wrap::ShadowWrap(&self.#ident)).pack_wrap(buf);
133                        }
134                    });
135
136                    block_unpackable.stmts.push(parse_quote! {
137                        let #ident = {
138                            use ::msgpacker::wrap::*;
139                            (&&::msgpacker::wrap::ShadowWrapUnpack::<#ty>(::core::marker::PhantomData)).unpack_wrap(buf)
140                        }.map(|(nv, t)| {
141                            n += nv;
142                            buf = &buf[nv..];
143                            t
144                        })?;
145                    });
146
147                    block_unpackable_iter.stmts.push(parse_quote! {
148                        let #ident = {
149                            use ::msgpacker::wrap::*;
150                            (&&::msgpacker::wrap::ShadowWrapUnpack::<#ty>(::core::marker::PhantomData)).unpack_iter_wrap(bytes.by_ref())
151                        }.map(|(nv, t)| {
152                            n += nv;
153                            t
154                        })?;
155                    });
156                }
157
158
159                values.push(FieldValue {
160                    attrs: vec![],
161                    member: Member::Named(ident.clone()),
162                    colon_token: Some(<Token![:]>::default()),
163                    expr: parse_quote! { #ident },
164                });
165
166                (block_packable, block_unpackable, block_unpackable_iter)
167            },
168        );
169
170    block_packable.stmts.push(parse_quote! {
171        return n;
172    });
173
174    block_unpackable.stmts.push(parse_quote! {
175        return Ok((
176            n,
177            Self {
178                #values
179            },
180        ));
181    });
182
183    block_unpackable_iter.stmts.push(parse_quote! {
184        return Ok((
185            n,
186            Self {
187                #values
188            },
189        ));
190    });
191
192    quote! {
193        impl ::msgpacker::Packable for #name {
194            fn pack<T>(&self, buf: &mut T) -> usize
195            where
196                T: Extend<u8>,
197                #block_packable
198        }
199
200        impl ::msgpacker::Unpackable for #name {
201            type Error = ::msgpacker::Error;
202
203            fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error>
204                #block_unpackable
205
206            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
207            where
208                I: IntoIterator<Item = u8>,
209                #block_unpackable_iter
210        }
211    }
212}
213
214fn impl_fields_unnamed(name: Ident, f: FieldsUnnamed) -> impl Into<TokenStream> {
215    let mut values: Punctuated<Expr, Token![,]> = Punctuated::new();
216    let block_packable: Block = parse_quote! {
217        {
218            let mut n = 0;
219        }
220    };
221    let block_unpackable: Block = parse_quote! {
222        {
223            let mut n = 0;
224        }
225    };
226    let block_unpackable_iter: Block = parse_quote! {
227        {
228            let mut bytes = bytes.into_iter();
229            let mut n = 0;
230        }
231    };
232
233    let (mut block_packable, mut block_unpackable, mut block_unpackable_iter) = f
234        .unnamed
235        .into_pairs()
236        .map(|p| p.into_value())
237        .enumerate()
238        .fold(
239            (block_packable, block_unpackable, block_unpackable_iter),
240            |(mut block_packable, mut block_unpackable, mut block_unpackable_iter), (i, field)| {
241                let ty = field.ty.clone();
242                let var: Expr = parse_str(format!("v{}", i).as_str()).unwrap();
243                let slf: Expr = parse_str(format!("self.{}", i).as_str()).unwrap();
244
245                if contains_attribute(&field, "map") {
246                    todo!("unnamed map is not implemented for derive macro; implement the traits manually")
247                } else if contains_attribute(&field, "array") {
248                    todo!("unnamed array is not implemented for derive macro; implement the traits manually")
249                } else {
250                    block_packable.stmts.push(parse_quote! {
251                        {
252                        use ::msgpacker::wrap::*;
253                        n += (&&::msgpacker::wrap::ShadowWrap(&#slf)).pack_wrap(buf);
254                        }
255                    });
256
257                    block_unpackable.stmts.push(parse_quote! {
258                        let #var = {
259                            use ::msgpacker::wrap::*;
260                            (&&::msgpacker::wrap::ShadowWrapUnpack::<#ty>(::core::marker::PhantomData)).unpack_wrap(buf)
261                        }.map(|(nv, t)| {
262                            n += nv;
263                            buf = &buf[nv..];
264                            t
265                        })?;
266                    });
267
268                    block_unpackable_iter.stmts.push(parse_quote! {
269                        let #var = {
270                            use ::msgpacker::wrap::*;
271                            (&&::msgpacker::wrap::ShadowWrapUnpack::<#ty>(::core::marker::PhantomData)).unpack_iter_wrap(bytes.by_ref())
272                        }.map(|(nv, t)| {
273                            n += nv;
274                            t
275                        })?;
276                    });
277                }
278
279                values.push(var);
280
281                (block_packable, block_unpackable, block_unpackable_iter)
282            },
283        );
284
285    block_packable.stmts.push(parse_quote! {
286        return n;
287    });
288
289    block_unpackable.stmts.push(parse_quote! {
290        return Ok((n, Self(#values)));
291    });
292
293    block_unpackable_iter.stmts.push(parse_quote! {
294        return Ok((n, Self(#values)));
295    });
296
297    quote! {
298        impl ::msgpacker::Packable for #name {
299            fn pack<T>(&self, buf: &mut T) -> usize
300            where
301                T: Extend<u8>,
302                #block_packable
303        }
304
305        impl ::msgpacker::Unpackable for #name {
306            type Error = ::msgpacker::Error;
307
308            fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error>
309                #block_unpackable
310
311            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
312            where
313                I: IntoIterator<Item = u8>,
314                #block_unpackable_iter
315        }
316    }
317}
318
319fn impl_fields_unit(name: Ident) -> impl Into<TokenStream> {
320    quote! {
321        impl ::msgpacker::Packable for #name {
322            fn pack<T>(&self, _buf: &mut T) -> usize
323            where
324                T: Extend<u8>,
325            {
326                0
327            }
328        }
329
330        impl ::msgpacker::Unpackable for #name {
331            type Error = ::msgpacker::Error;
332
333            fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> {
334                Ok((0, Self))
335            }
336
337            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
338            where
339                I: IntoIterator<Item = u8>,
340            {
341                Ok((0, Self))
342            }
343        }
344    }
345}
346
347fn impl_fields_enum(name: Ident, v: Punctuated<Variant, Token![,]>) -> impl Into<TokenStream> {
348    if v.is_empty() {
349        todo!("empty enum is not implemented for derive macro; implement the traits manually");
350    }
351
352    let mut block_packable: ExprMatch = parse_quote! {
353        match self {
354        }
355    };
356
357    let mut block_unpackable: ExprMatch = parse_quote! {
358        match discriminant {
359        }
360    };
361
362    let mut block_unpackable_iter: ExprMatch = parse_quote! {
363        match discriminant {
364        }
365    };
366
367    v.into_iter().enumerate().for_each(|(i, v)| {
368        let discriminant = v
369            .discriminant
370            .map(|(_, d)| d)
371            .unwrap_or_else(|| parse_str(format!("{}", i).as_str()).unwrap());
372
373        // TODO check attributes of the field
374        let ident = v.ident.clone();
375        match v.fields {
376            Fields::Named(f) => {
377                let mut blk: Block = parse_str("{}").unwrap();
378                let mut blk_unpack: Block = parse_str("{}").unwrap();
379                let mut blk_unpack_iter: Block = parse_str("{}").unwrap();
380                let mut blk_unpack_fields: Punctuated<FieldValue, Token![,]> = Punctuated::new();
381
382                blk.stmts.push(parse_quote! {
383                    {
384                    use ::msgpacker::wrap::*;
385                    n += (&&::msgpacker::wrap::ShadowWrap(&(#discriminant as u32))).pack_wrap(buf);
386                    }
387                });
388
389                f.named
390                    .iter()
391                    .filter_map(|n| n.ident.as_ref())
392                    .for_each(|field| {
393                        blk.stmts.push(parse_quote! {
394                            {
395                            use ::msgpacker::wrap::*;
396                            n += (&&::msgpacker::wrap::ShadowWrap(&(#field))).pack_wrap(buf);
397                            }
398                        });
399
400                        blk_unpack_fields.push(parse_quote! { #field });
401
402                        blk_unpack.stmts.push(parse_quote! {
403                            let #field = ::msgpacker::Unpackable::unpack(buf).map(|(nv, t)| {
404                                n += nv;
405                                buf = &buf[nv..];
406                                t
407                            })?;
408                        });
409
410                        blk_unpack_iter.stmts.push(parse_quote! {
411                            let #field =::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
412                                n += nv;
413                                t
414                            })?;
415                        });
416                    });
417
418                let mut arm: syn::Arm = parse_quote! {
419                    #name::#ident {} => #blk,
420                };
421
422                f.named
423                    .iter()
424                    .filter_map(|n| n.ident.as_ref())
425                    .for_each(|field| {
426                        match &mut arm.pat {
427                            Pat::Struct(s) => {
428                                s.fields.push(FieldPat {
429                                    attrs: vec![],
430                                    member: Member::Named(field.clone()),
431                                    colon_token: None,
432                                    pat: Box::new(Pat::Ident(PatIdent {
433                                        attrs: vec![],
434                                        by_ref: None,
435                                        mutability: None,
436                                        ident: field.clone(),
437                                        subpat: None,
438                                    })),
439                                });
440                            }
441                            _ => todo!(
442                                "enum variant is not implemented for derive macro; implement the traits manually"
443                            ),
444                        }
445                    });
446
447                block_packable.arms.push(arm);
448
449                blk_unpack.stmts.push(parse_quote! {
450                    slf = #name::#ident { #blk_unpack_fields };
451                });
452
453                blk_unpack_iter.stmts.push(parse_quote! {
454                    slf = #name::#ident { #blk_unpack_fields };
455                });
456
457                block_unpackable.arms.push(parse_quote! {
458                    #discriminant => #blk_unpack,
459                });
460
461                block_unpackable_iter.arms.push(parse_quote! {
462                    #discriminant => #blk_unpack_iter,
463                });
464            }
465
466            Fields::Unnamed(f) => {
467                let mut blk: Block = parse_str("{}").unwrap();
468                let mut blk_unpack: Block = parse_str("{}").unwrap();
469                let mut blk_unpack_iter: Block = parse_str("{}").unwrap();
470
471                blk.stmts.push(parse_quote! {
472                    {
473                    use ::msgpacker::wrap::*;
474                    n += (&&::msgpacker::wrap::ShadowWrap(&(#discriminant as u32))).pack_wrap(buf);
475                    }
476                });
477
478                let mut tuple_arm: ExprTuple = parse_str("()").unwrap();
479                f.unnamed.iter().enumerate().for_each(|(ii, field)| {
480                    let ty = field.ty.clone();
481                    let ti: Expr = parse_str(format!("t{}", ii).as_str()).unwrap();
482                    tuple_arm.elems.push(ti.clone());
483
484                    blk.stmts.push(parse_quote! {
485                        {
486                        use ::msgpacker::wrap::*;
487                        n += (&&::msgpacker::wrap::ShadowWrap(&#ti)).pack_wrap(buf);
488                        }
489                    });
490
491                    blk_unpack.stmts.push(parse_quote! {
492                        let #ti = {
493                                use ::msgpacker::wrap::*;
494                                (&&::msgpacker::wrap::ShadowWrapUnpack::<#ty>(::core::marker::PhantomData)).unpack_wrap(buf)
495                            }.map(|(nv, t)| {
496                            n += nv;
497                            buf = &buf[nv..];
498                            t
499                        })?;
500                    });
501
502                    blk_unpack_iter.stmts.push(parse_quote! {
503                        let #ti = {
504                                use ::msgpacker::wrap::*;
505                                (&&::msgpacker::wrap::ShadowWrapUnpack::<#ty>(::core::marker::PhantomData)).unpack_iter_wrap(bytes.by_ref())
506                            }.map(|(nv, t)| {
507                            n += nv;
508                            t
509                        })?;
510                    });
511                });
512
513                blk_unpack.stmts.push(parse_quote! {
514                    slf = #name::#ident #tuple_arm;
515                });
516
517                blk_unpack_iter.stmts.push(parse_quote! {
518                    slf = #name::#ident #tuple_arm;
519                });
520
521                block_packable.arms.push(parse_quote! {
522                    #name::#ident #tuple_arm => #blk,
523                });
524
525                block_unpackable.arms.push(parse_quote! {
526                    #discriminant => #blk_unpack,
527                });
528
529                block_unpackable_iter.arms.push(parse_quote! {
530                    #discriminant => #blk_unpack_iter,
531                });
532            }
533
534            Fields::Unit => {
535                block_packable.arms.push(parse_quote! {
536                    #name::#ident => {
537                        {
538                        use ::msgpacker::wrap::*;
539                        n += (&&::msgpacker::wrap::ShadowWrap(&(#discriminant as u32))).pack_wrap(buf);
540                        }
541                    }
542                });
543
544                block_unpackable.arms.push(parse_quote! {
545                    #discriminant => slf = #name::#ident,
546                });
547
548                block_unpackable_iter.arms.push(parse_quote! {
549                    #discriminant => slf = #name::#ident,
550                });
551            }
552        }
553    });
554
555    block_unpackable.arms.push(parse_quote! {
556        _ => {
557            return Err(::msgpacker::Error::InvalidEnumVariant);
558        }
559    });
560
561    block_unpackable_iter.arms.push(parse_quote! {
562        _ => {
563            return Err(::msgpacker::Error::InvalidEnumVariant);
564        }
565    });
566
567    quote! {
568        impl ::msgpacker::Packable for #name {
569            fn pack<T>(&self, buf: &mut T) -> usize
570            where
571                T: Extend<u8>,
572            {
573                let mut n = 0;
574
575                #block_packable;
576
577                return n;
578            }
579        }
580
581        impl ::msgpacker::Unpackable for #name {
582            type Error = ::msgpacker::Error;
583
584            #[allow(unused_mut)]
585            fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> {
586                let (mut n, discriminant) = u32::unpack(&mut buf)?;
587                buf = &buf[n..];
588                let slf;
589
590                #block_unpackable;
591
592                Ok((n, slf))
593            }
594
595            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
596            where
597                I: IntoIterator<Item = u8>,
598            {
599                let mut bytes = bytes.into_iter();
600                let (mut n, discriminant) = u32::unpack_iter(bytes.by_ref())?;
601                let slf;
602
603                #block_unpackable_iter;
604
605                Ok((n, slf))
606            }
607        }
608    }
609}
610
611#[proc_macro_derive(MsgPacker, attributes(msgpacker))]
612pub fn msg_packer(input: TokenStream) -> TokenStream {
613    let input = parse_macro_input!(input as DeriveInput);
614
615    let name = input.ident;
616    let data = input.data;
617    match data {
618        Data::Struct(DataStruct {
619            fields: Fields::Named(f),
620            ..
621        }) => impl_fields_named(name, f).into(),
622
623        Data::Struct(DataStruct {
624            fields: Fields::Unnamed(f),
625            ..
626        }) => impl_fields_unnamed(name, f).into(),
627
628        Data::Struct(DataStruct {
629            fields: Fields::Unit,
630            ..
631        }) => impl_fields_unit(name).into(),
632
633        Data::Enum(DataEnum { variants, .. }) => impl_fields_enum(name, variants).into(),
634
635        Data::Union(DataUnion { .. }) => {
636            todo!(
637                "union support is not implemented for derive macro; implement the traits manually"
638            )
639        }
640    }
641}