msgpacker_derive/
lib.rs

1#![crate_type = "proc-macro"]
2extern crate proc_macro;
3
4// This code is bad and should be refactored into something cleaner. Maybe some syn-based
5// framework?
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::punctuated::Punctuated;
10use syn::{
11    parse_macro_input, parse_quote, parse_str, Block, Data, DataEnum, DataStruct, DataUnion,
12    DeriveInput, Expr, ExprMatch, ExprTuple, Field, FieldPat, FieldValue, Fields, FieldsNamed,
13    FieldsUnnamed, GenericArgument, Ident, Member, Meta, Pat, PatIdent, PathArguments, Token, Type,
14    Variant,
15};
16
17fn contains_attribute(field: &Field, name: &str) -> bool {
18    let name = name.to_string();
19    if let Some(attr) = field.attrs.first() {
20        if let Meta::List(list) = &attr.meta {
21            if list.path.is_ident("msgpacker") {
22                if list
23                    .tokens
24                    .clone()
25                    .into_iter()
26                    .find(|a| a.to_string() == name)
27                    .is_some()
28                {
29                    return true;
30                }
31            }
32        }
33    }
34    false
35}
36
37fn impl_fields_named(name: Ident, f: FieldsNamed) -> impl Into<TokenStream> {
38    let mut values: Punctuated<FieldValue, Token![,]> = Punctuated::new();
39    let block_packable: Block = parse_quote! {
40        {
41            let mut n = 0;
42        }
43    };
44    let block_unpackable: Block = parse_quote! {
45        {
46            let mut n = 0;
47        }
48    };
49    let block_unpackable_iter: Block = parse_quote! {
50        {
51            let mut bytes = bytes.into_iter();
52            let mut n = 0;
53        }
54    };
55
56    let (
57        mut block_packable,
58        mut block_unpackable,
59        mut block_unpackable_iter
60    ) = f.named.into_pairs().map(|p| p.into_value()).fold(
61            (block_packable, block_unpackable, block_unpackable_iter),
62            |(mut block_packable, mut block_unpackable, mut block_unpackable_iter), field| {
63                let ident = field.ident.as_ref().cloned().unwrap();
64                let ty = field.ty.clone();
65
66                let mut is_vec = false;
67                let mut is_vec_u8 = false;
68
69                match &ty {
70                    Type::Path(p) if p.path.segments.last().filter(|p| p.ident.to_string() == "Vec").is_some() => {
71                        is_vec = true;
72                        match &p.path.segments.last().unwrap().arguments {
73                            PathArguments::AngleBracketed(a) if a.args.len() == 1 => {
74                                if let Some(GenericArgument::Type(Type::Path(p))) = a.args.first() {
75                                    if p.path.segments.last().filter(|p| p.ident.to_string() == "u8").is_some() {
76                                        is_vec_u8 = true;
77                                    }
78                                }
79                            }
80                            _ => (),
81                        }
82                    }
83
84                    _ => (),
85                }
86
87                if contains_attribute(&field, "map") {
88                    block_packable.stmts.push(parse_quote! {
89                        n += ::msgpacker::pack_map(buf, &self.#ident);
90                    });
91
92                    block_unpackable.stmts.push(parse_quote! {
93                        let #ident = ::msgpacker::unpack_map(buf).map(|(nv, t)| {
94                            n += nv;
95                            buf = &buf[nv..];
96                            t
97                        })?;
98                    });
99
100                    block_unpackable_iter.stmts.push(parse_quote! {
101                        let #ident = ::msgpacker::unpack_map_iter(bytes.by_ref()).map(|(nv, t)| {
102                            n += nv;
103                            t
104                        })?;
105                    });
106                } else if contains_attribute(&field, "array") || is_vec && !is_vec_u8 {
107                    block_packable.stmts.push(parse_quote! {
108                        n += ::msgpacker::pack_array(buf, &self.#ident);
109                    });
110
111                    block_unpackable.stmts.push(parse_quote! {
112                        let #ident = ::msgpacker::unpack_array(buf).map(|(nv, t)| {
113                            n += nv;
114                            buf = &buf[nv..];
115                            t
116                        })?;
117                    });
118
119                    block_unpackable_iter.stmts.push(parse_quote! {
120                        let #ident = ::msgpacker::unpack_array_iter(bytes.by_ref()).map(|(nv, t)| {
121                            n += nv;
122                            t
123                        })?;
124                    });
125                } else {
126                    block_packable.stmts.push(parse_quote! {
127                        n += <#ty as ::msgpacker::Packable>::pack(&self.#ident, buf);
128                    });
129
130                    block_unpackable.stmts.push(parse_quote! {
131                        let #ident = ::msgpacker::Unpackable::unpack(buf).map(|(nv, t)| {
132                            n += nv;
133                            buf = &buf[nv..];
134                            t
135                        })?;
136                    });
137
138                    block_unpackable_iter.stmts.push(parse_quote! {
139                        let #ident = ::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
140                            n += nv;
141                            t
142                        })?;
143                    });
144                }
145
146
147                values.push(FieldValue {
148                    attrs: vec![],
149                    member: Member::Named(ident.clone()),
150                    colon_token: Some(<Token![:]>::default()),
151                    expr: parse_quote! { #ident },
152                });
153
154                (block_packable, block_unpackable, block_unpackable_iter)
155            },
156        );
157
158    block_packable.stmts.push(parse_quote! {
159        return n;
160    });
161
162    block_unpackable.stmts.push(parse_quote! {
163        return Ok((
164            n,
165            Self {
166                #values
167            },
168        ));
169    });
170
171    block_unpackable_iter.stmts.push(parse_quote! {
172        return Ok((
173            n,
174            Self {
175                #values
176            },
177        ));
178    });
179
180    quote! {
181        impl ::msgpacker::Packable for #name {
182            fn pack<T>(&self, buf: &mut T) -> usize
183            where
184                T: Extend<u8>,
185                #block_packable
186        }
187
188        impl ::msgpacker::Unpackable for #name {
189            type Error = ::msgpacker::Error;
190
191            fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error>
192                #block_unpackable
193
194            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
195            where
196                I: IntoIterator<Item = u8>,
197                #block_unpackable_iter
198        }
199    }
200}
201
202fn impl_fields_unnamed(name: Ident, f: FieldsUnnamed) -> impl Into<TokenStream> {
203    let mut values: Punctuated<Expr, Token![,]> = Punctuated::new();
204    let block_packable: Block = parse_quote! {
205        {
206            let mut n = 0;
207        }
208    };
209    let block_unpackable: Block = parse_quote! {
210        {
211            let mut n = 0;
212        }
213    };
214    let block_unpackable_iter: Block = parse_quote! {
215        {
216            let mut bytes = bytes.into_iter();
217            let mut n = 0;
218        }
219    };
220
221    let (mut block_packable, mut block_unpackable, mut block_unpackable_iter) = f
222        .unnamed
223        .into_pairs()
224        .map(|p| p.into_value())
225        .enumerate()
226        .fold(
227            (block_packable, block_unpackable, block_unpackable_iter),
228            |(mut block_packable, mut block_unpackable, mut block_unpackable_iter), (i, field)| {
229                let ty = field.ty.clone();
230                let var: Expr = parse_str(format!("v{}", i).as_str()).unwrap();
231                let slf: Expr = parse_str(format!("self.{}", i).as_str()).unwrap();
232
233                if contains_attribute(&field, "map") {
234                    todo!("unnamed map is not implemented for derive macro; implement the traits manually")
235                } else if contains_attribute(&field, "array") {
236                    todo!("unnamed array is not implemented for derive macro; implement the traits manually")
237                } else {
238                    block_packable.stmts.push(parse_quote! {
239                        n += <#ty as ::msgpacker::Packable>::pack(&#slf, buf);
240                    });
241
242                    block_unpackable.stmts.push(parse_quote! {
243                        let #var = ::msgpacker::Unpackable::unpack(buf).map(|(nv, t)| {
244                            n += nv;
245                            buf = &buf[nv..];
246                            t
247                        })?;
248                    });
249
250                    block_unpackable_iter.stmts.push(parse_quote! {
251                        let #var = ::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
252                            n += nv;
253                            t
254                        })?;
255                    });
256                }
257
258                values.push(var);
259
260                (block_packable, block_unpackable, block_unpackable_iter)
261            },
262        );
263
264    block_packable.stmts.push(parse_quote! {
265        return n;
266    });
267
268    block_unpackable.stmts.push(parse_quote! {
269        return Ok((n, Self(#values)));
270    });
271
272    block_unpackable_iter.stmts.push(parse_quote! {
273        return Ok((n, Self(#values)));
274    });
275
276    quote! {
277        impl ::msgpacker::Packable for #name {
278            fn pack<T>(&self, buf: &mut T) -> usize
279            where
280                T: Extend<u8>,
281                #block_packable
282        }
283
284        impl ::msgpacker::Unpackable for #name {
285            type Error = ::msgpacker::Error;
286
287            fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error>
288                #block_unpackable
289
290            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
291            where
292                I: IntoIterator<Item = u8>,
293                #block_unpackable_iter
294        }
295    }
296}
297
298fn impl_fields_unit(name: Ident) -> impl Into<TokenStream> {
299    quote! {
300        impl ::msgpacker::Packable for #name {
301            fn pack<T>(&self, _buf: &mut T) -> usize
302            where
303                T: Extend<u8>,
304            {
305                0
306            }
307        }
308
309        impl ::msgpacker::Unpackable for #name {
310            type Error = ::msgpacker::Error;
311
312            fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> {
313                Ok((0, Self))
314            }
315
316            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
317            where
318                I: IntoIterator<Item = u8>,
319            {
320                Ok((0, Self))
321            }
322        }
323    }
324}
325
326fn impl_fields_enum(name: Ident, v: Punctuated<Variant, Token![,]>) -> impl Into<TokenStream> {
327    if v.is_empty() {
328        todo!("empty enum is not implemented for derive macro; implement the traits manually");
329    }
330
331    let mut block_packable: ExprMatch = parse_quote! {
332        match self {
333        }
334    };
335
336    let mut block_unpackable: ExprMatch = parse_quote! {
337        match discriminant {
338        }
339    };
340
341    let mut block_unpackable_iter: ExprMatch = parse_quote! {
342        match discriminant {
343        }
344    };
345
346    v.into_iter().enumerate().for_each(|(i, v)| {
347        let discriminant = v
348            .discriminant
349            .map(|(_, d)| d)
350            .unwrap_or_else(|| parse_str(format!("{}", i).as_str()).unwrap());
351
352        // TODO check attributes of the field
353        let ident = v.ident.clone();
354        match v.fields {
355            Fields::Named(f) => {
356                let mut blk: Block = parse_str("{}").unwrap();
357                let mut blk_unpack: Block = parse_str("{}").unwrap();
358                let mut blk_unpack_iter: Block = parse_str("{}").unwrap();
359                let mut blk_unpack_fields: Punctuated<FieldValue, Token![,]> = Punctuated::new();
360
361                blk.stmts.push(parse_quote! {
362                    n += (#discriminant as u32).pack(buf);
363                });
364
365                f.named
366                    .iter()
367                    .filter_map(|n| n.ident.as_ref())
368                    .for_each(|field| {
369                        blk.stmts.push(parse_quote! {
370                            n += #field.pack(buf);
371                        });
372
373                        blk_unpack_fields.push(parse_quote! { #field });
374
375                        blk_unpack.stmts.push(parse_quote! {
376                            let #field =::msgpacker::Unpackable::unpack(buf).map(|(nv, t)| {
377                                n += nv;
378                                buf = &buf[nv..];
379                                t
380                            })?;
381                        });
382
383                        blk_unpack_iter.stmts.push(parse_quote! {
384                            let #field =::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
385                                n += nv;
386                                t
387                            })?;
388                        });
389                    });
390
391                let mut arm: syn::Arm = parse_quote! {
392                    #name::#ident {} => #blk,
393                };
394
395                f.named
396                    .iter()
397                    .filter_map(|n| n.ident.as_ref())
398                    .for_each(|field| {
399                        match &mut arm.pat {
400                            Pat::Struct(s) => {
401                                s.fields.push(FieldPat {
402                                    attrs: vec![],
403                                    member: Member::Named(field.clone()),
404                                    colon_token: None,
405                                    pat: Box::new(Pat::Ident(PatIdent {
406                                        attrs: vec![],
407                                        by_ref: None,
408                                        mutability: None,
409                                        ident: field.clone(),
410                                        subpat: None,
411                                    })),
412                                });
413                            }
414                            _ => todo!(
415                                "enum variant is not implemented for derive macro; implement the traits manually"
416                            ),
417                        }
418                    });
419
420                block_packable.arms.push(arm);
421
422                blk_unpack.stmts.push(parse_quote! {
423                    slf = #name::#ident { #blk_unpack_fields };
424                });
425
426                blk_unpack_iter.stmts.push(parse_quote! {
427                    slf = #name::#ident { #blk_unpack_fields };
428                });
429
430                block_unpackable.arms.push(parse_quote! {
431                    #discriminant => #blk_unpack,
432                });
433
434                block_unpackable_iter.arms.push(parse_quote! {
435                    #discriminant => #blk_unpack_iter,
436                });
437            }
438
439            Fields::Unnamed(f) => {
440                let mut blk: Block = parse_str("{}").unwrap();
441                let mut blk_unpack: Block = parse_str("{}").unwrap();
442                let mut blk_unpack_iter: Block = parse_str("{}").unwrap();
443
444                blk.stmts.push(parse_quote! {
445                    n += (#discriminant as u32).pack(buf);
446                });
447
448                let mut tuple_arm: ExprTuple = parse_str("()").unwrap();
449                f.unnamed.iter().enumerate().for_each(|(ii, _field)| {
450                    let ti: Expr = parse_str(format!("t{}", ii).as_str()).unwrap();
451                    tuple_arm.elems.push(ti.clone());
452
453                    blk.stmts.push(parse_quote! {
454                        n += #ti.pack(buf);
455                    });
456
457                    blk_unpack.stmts.push(parse_quote! {
458                        let #ti =::msgpacker::Unpackable::unpack(buf).map(|(nv, t)| {
459                            n += nv;
460                            buf = &buf[nv..];
461                            t
462                        })?;
463                    });
464
465                    blk_unpack_iter.stmts.push(parse_quote! {
466                        let #ti =::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
467                            n += nv;
468                            t
469                        })?;
470                    });
471                });
472
473                blk_unpack.stmts.push(parse_quote! {
474                    slf = #name::#ident #tuple_arm;
475                });
476
477                blk_unpack_iter.stmts.push(parse_quote! {
478                    slf = #name::#ident #tuple_arm;
479                });
480
481                block_packable.arms.push(parse_quote! {
482                    #name::#ident #tuple_arm => #blk,
483                });
484
485                block_unpackable.arms.push(parse_quote! {
486                    #discriminant => #blk_unpack,
487                });
488
489                block_unpackable_iter.arms.push(parse_quote! {
490                    #discriminant => #blk_unpack_iter,
491                });
492            }
493
494            Fields::Unit => {
495                block_packable.arms.push(parse_quote! {
496                    #name::#ident => {
497                        n += (#discriminant as u32).pack(buf);
498                    }
499                });
500
501                block_unpackable.arms.push(parse_quote! {
502                    #discriminant => slf = #name::#ident,
503                });
504
505                block_unpackable_iter.arms.push(parse_quote! {
506                    #discriminant => slf = #name::#ident,
507                });
508            }
509        }
510    });
511
512    block_unpackable.arms.push(parse_quote! {
513        _ => {
514            return Err(::msgpacker::Error::InvalidEnumVariant);
515        }
516    });
517
518    block_unpackable_iter.arms.push(parse_quote! {
519        _ => {
520            return Err(::msgpacker::Error::InvalidEnumVariant);
521        }
522    });
523
524    quote! {
525        impl ::msgpacker::Packable for #name {
526            fn pack<T>(&self, buf: &mut T) -> usize
527            where
528                T: Extend<u8>,
529            {
530                let mut n = 0;
531
532                #block_packable;
533
534                return n;
535            }
536        }
537
538        impl ::msgpacker::Unpackable for #name {
539            type Error = ::msgpacker::Error;
540
541            #[allow(unused_mut)]
542            fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> {
543                let (mut n, discriminant) = u32::unpack(&mut buf)?;
544                buf = &buf[n..];
545                let slf;
546
547                #block_unpackable;
548
549                Ok((n, slf))
550            }
551
552            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
553            where
554                I: IntoIterator<Item = u8>,
555            {
556                let mut bytes = bytes.into_iter();
557                let (mut n, discriminant) = u32::unpack_iter(bytes.by_ref())?;
558                let slf;
559
560                #block_unpackable_iter;
561
562                Ok((n, slf))
563            }
564        }
565    }
566}
567
568#[proc_macro_derive(MsgPacker, attributes(msgpacker))]
569pub fn msg_packer(input: TokenStream) -> TokenStream {
570    let input = parse_macro_input!(input as DeriveInput);
571
572    let name = input.ident;
573    let data = input.data;
574    match data {
575        Data::Struct(DataStruct {
576            fields: Fields::Named(f),
577            ..
578        }) => impl_fields_named(name, f).into(),
579
580        Data::Struct(DataStruct {
581            fields: Fields::Unnamed(f),
582            ..
583        }) => impl_fields_unnamed(name, f).into(),
584
585        Data::Struct(DataStruct {
586            fields: Fields::Unit,
587            ..
588        }) => impl_fields_unit(name).into(),
589
590        Data::Enum(DataEnum { variants, .. }) => impl_fields_enum(name, variants).into(),
591
592        Data::Union(DataUnion { .. }) => {
593            todo!(
594                "union support is not implemented for derive macro; implement the traits manually"
595            )
596        }
597    }
598}