Skip to main content

msgpacker_derive/
lib.rs

1#![crate_type = "proc-macro"]
2extern crate proc_macro;
3
4// This code is bad and should be refactored into something cleaner. Maybe some syn-based
5// framework?
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::punctuated::Punctuated;
10use syn::{
11    parse_macro_input, parse_quote, parse_str, Block, Data, DataEnum, DataStruct, DataUnion,
12    DeriveInput, Expr, ExprMatch, ExprTuple, Field, FieldPat, FieldValue, Fields, FieldsNamed,
13    FieldsUnnamed, GenericArgument, Ident, Member, Meta, Pat, PatIdent, PathArguments, Token, Type,
14    Variant,
15};
16
17fn contains_attribute(field: &Field, name: &str) -> bool {
18    let name = name.to_string();
19    if let Some(attr) = field.attrs.first() {
20        if let Meta::List(list) = &attr.meta {
21            if list.path.is_ident("msgpacker")
22                && list
23                    .tokens
24                    .clone()
25                    .into_iter()
26                    .find(|a| a.to_string() == name)
27                    .is_some()
28            {
29                return true;
30            }
31        }
32    }
33    false
34}
35
36fn impl_fields_named(name: Ident, f: FieldsNamed) -> impl Into<TokenStream> {
37    let mut values: Punctuated<FieldValue, Token![,]> = Punctuated::new();
38    let block_packable: Block = parse_quote! {
39        {
40            let mut n = 0;
41        }
42    };
43    let block_unpackable: Block = parse_quote! {
44        {
45            let mut n = 0;
46        }
47    };
48    let block_unpackable_iter: Block = parse_quote! {
49        {
50            let mut bytes = bytes.into_iter();
51            let mut n = 0;
52        }
53    };
54
55    let (
56        mut block_packable,
57        mut block_unpackable,
58        mut block_unpackable_iter
59    ) = f.named.into_pairs().map(|p| p.into_value()).fold(
60            (block_packable, block_unpackable, block_unpackable_iter),
61            |(mut block_packable, mut block_unpackable, mut block_unpackable_iter), field| {
62                let ident = field.ident.as_ref().cloned().unwrap();
63                let ty = field.ty.clone();
64
65                let mut is_vec = false;
66                let mut is_vec_u8 = false;
67
68                match &ty {
69                    Type::Path(p) if p.path.segments.last().filter(|p| p.ident == "Vec").is_some() => {
70                        is_vec = true;
71                        match &p.path.segments.last().unwrap().arguments {
72                            PathArguments::AngleBracketed(a) if a.args.len() == 1 => {
73                                if let Some(GenericArgument::Type(Type::Path(p))) = a.args.first() {
74                                    if p.path.segments.last().filter(|p| p.ident == "u8").is_some() {
75                                        is_vec_u8 = true;
76                                    }
77                                }
78                            }
79                            _ => (),
80                        }
81                    }
82
83                    _ => (),
84                }
85
86                if contains_attribute(&field, "map") {
87                    block_packable.stmts.push(parse_quote! {
88                        n += ::msgpacker::pack_map(buf, &self.#ident);
89                    });
90
91                    block_unpackable.stmts.push(parse_quote! {
92                        let #ident = ::msgpacker::unpack_map(buf).map(|(nv, t)| {
93                            n += nv;
94                            buf = &buf[nv..];
95                            t
96                        })?;
97                    });
98
99                    block_unpackable_iter.stmts.push(parse_quote! {
100                        let #ident = ::msgpacker::unpack_map_iter(bytes.by_ref()).map(|(nv, t)| {
101                            n += nv;
102                            t
103                        })?;
104                    });
105                } else if contains_attribute(&field, "array") || is_vec && !is_vec_u8 {
106                    block_packable.stmts.push(parse_quote! {
107                        n += ::msgpacker::pack_array(buf, &self.#ident);
108                    });
109
110                    block_unpackable.stmts.push(parse_quote! {
111                        let #ident = ::msgpacker::unpack_array(buf).map(|(nv, t)| {
112                            n += nv;
113                            buf = &buf[nv..];
114                            t
115                        })?;
116                    });
117
118                    block_unpackable_iter.stmts.push(parse_quote! {
119                        let #ident = ::msgpacker::unpack_array_iter(bytes.by_ref()).map(|(nv, t)| {
120                            n += nv;
121                            t
122                        })?;
123                    });
124                } else {
125                    block_packable.stmts.push(parse_quote! {
126                        n += <#ty as ::msgpacker::Packable>::pack(&self.#ident, buf);
127                    });
128
129                    block_unpackable.stmts.push(parse_quote! {
130                        let #ident = ::msgpacker::Unpackable::unpack_with_ofs(buf).map(|(nv, t)| {
131                            n += nv;
132                            buf = &buf[nv..];
133                            t
134                        })?;
135                    });
136
137                    block_unpackable_iter.stmts.push(parse_quote! {
138                        let #ident = ::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
139                            n += nv;
140                            t
141                        })?;
142                    });
143                }
144
145
146                values.push(FieldValue {
147                    attrs: vec![],
148                    member: Member::Named(ident.clone()),
149                    colon_token: Some(<Token![:]>::default()),
150                    expr: parse_quote! { #ident },
151                });
152
153                (block_packable, block_unpackable, block_unpackable_iter)
154            },
155        );
156
157    block_packable.stmts.push(parse_quote! {
158        return n;
159    });
160
161    block_unpackable.stmts.push(parse_quote! {
162        return Ok((
163            n,
164            Self {
165                #values
166            },
167        ));
168    });
169
170    block_unpackable_iter.stmts.push(parse_quote! {
171        return Ok((
172            n,
173            Self {
174                #values
175            },
176        ));
177    });
178
179    quote! {
180        impl ::msgpacker::Packable for #name {
181            fn pack<T>(&self, buf: &mut T) -> usize
182            where
183                T: Extend<u8>,
184                #block_packable
185        }
186
187        impl ::msgpacker::Unpackable for #name {
188            type Error = ::msgpacker::Error;
189
190            fn unpack_with_ofs(mut buf: &[u8]) -> Result<(usize, Self), Self::Error>
191                #block_unpackable
192
193            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
194            where
195                I: IntoIterator<Item = u8>,
196                #block_unpackable_iter
197        }
198    }
199}
200
201fn impl_fields_unnamed(name: Ident, f: FieldsUnnamed) -> impl Into<TokenStream> {
202    let mut values: Punctuated<Expr, Token![,]> = Punctuated::new();
203    let block_packable: Block = parse_quote! {
204        {
205            let mut n = 0;
206        }
207    };
208    let block_unpackable: Block = parse_quote! {
209        {
210            let mut n = 0;
211        }
212    };
213    let block_unpackable_iter: Block = parse_quote! {
214        {
215            let mut bytes = bytes.into_iter();
216            let mut n = 0;
217        }
218    };
219
220    let (mut block_packable, mut block_unpackable, mut block_unpackable_iter) = f
221        .unnamed
222        .into_pairs()
223        .map(|p| p.into_value())
224        .enumerate()
225        .fold(
226            (block_packable, block_unpackable, block_unpackable_iter),
227            |(mut block_packable, mut block_unpackable, mut block_unpackable_iter), (i, field)| {
228                let ty = field.ty.clone();
229                let var: Expr = parse_str(format!("v{}", i).as_str()).unwrap();
230                let slf: Expr = parse_str(format!("self.{}", i).as_str()).unwrap();
231
232                if contains_attribute(&field, "map") {
233                    todo!("unnamed map is not implemented for derive macro; implement the traits manually")
234                } else if contains_attribute(&field, "array") {
235                    todo!("unnamed array is not implemented for derive macro; implement the traits manually")
236                } else {
237                    block_packable.stmts.push(parse_quote! {
238                        n += <#ty as ::msgpacker::Packable>::pack(&#slf, buf);
239                    });
240
241                    block_unpackable.stmts.push(parse_quote! {
242                        let #var = ::msgpacker::Unpackable::unpack_with_ofs(buf).map(|(nv, t)| {
243                            n += nv;
244                            buf = &buf[nv..];
245                            t
246                        })?;
247                    });
248
249                    block_unpackable_iter.stmts.push(parse_quote! {
250                        let #var = ::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
251                            n += nv;
252                            t
253                        })?;
254                    });
255                }
256
257                values.push(var);
258
259                (block_packable, block_unpackable, block_unpackable_iter)
260            },
261        );
262
263    block_packable.stmts.push(parse_quote! {
264        return n;
265    });
266
267    block_unpackable.stmts.push(parse_quote! {
268        return Ok((n, Self(#values)));
269    });
270
271    block_unpackable_iter.stmts.push(parse_quote! {
272        return Ok((n, Self(#values)));
273    });
274
275    quote! {
276        impl ::msgpacker::Packable for #name {
277            fn pack<T>(&self, buf: &mut T) -> usize
278            where
279                T: Extend<u8>,
280                #block_packable
281        }
282
283        impl ::msgpacker::Unpackable for #name {
284            type Error = ::msgpacker::Error;
285
286            fn unpack_with_ofs(mut buf: &[u8]) -> Result<(usize, Self), Self::Error>
287                #block_unpackable
288
289            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
290            where
291                I: IntoIterator<Item = u8>,
292                #block_unpackable_iter
293        }
294    }
295}
296
297fn impl_fields_unit(name: Ident) -> impl Into<TokenStream> {
298    quote! {
299        impl ::msgpacker::Packable for #name {
300            fn pack<T>(&self, _buf: &mut T) -> usize
301            where
302                T: Extend<u8>,
303            {
304                0
305            }
306        }
307
308        impl ::msgpacker::Unpackable for #name {
309            type Error = ::msgpacker::Error;
310
311            fn unpack_with_ofs(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> {
312                Ok((0, Self))
313            }
314
315            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
316            where
317                I: IntoIterator<Item = u8>,
318            {
319                Ok((0, Self))
320            }
321        }
322    }
323}
324
325fn impl_fields_enum(name: Ident, v: Punctuated<Variant, Token![,]>) -> impl Into<TokenStream> {
326    if v.is_empty() {
327        todo!("empty enum is not implemented for derive macro; implement the traits manually");
328    }
329
330    let mut block_packable: ExprMatch = parse_quote! {
331        match self {
332        }
333    };
334
335    let mut block_unpackable: ExprMatch = parse_quote! {
336        match discriminant {
337        }
338    };
339
340    let mut block_unpackable_iter: ExprMatch = parse_quote! {
341        match discriminant {
342        }
343    };
344
345    v.into_iter().enumerate().for_each(|(i, v)| {
346        let discriminant = v
347            .discriminant
348            .map(|(_, d)| d)
349            .unwrap_or_else(|| parse_str(format!("{}", i).as_str()).unwrap());
350
351        // TODO check attributes of the field
352        let ident = v.ident.clone();
353        match v.fields {
354            Fields::Named(f) => {
355                let mut blk: Block = parse_str("{}").unwrap();
356                let mut blk_unpack: Block = parse_str("{}").unwrap();
357                let mut blk_unpack_iter: Block = parse_str("{}").unwrap();
358                let mut blk_unpack_fields: Punctuated<FieldValue, Token![,]> = Punctuated::new();
359
360                blk.stmts.push(parse_quote! {
361                    n += (#discriminant as u32).pack(buf);
362                });
363
364                f.named
365                    .iter()
366                    .filter_map(|n| n.ident.as_ref())
367                    .for_each(|field| {
368                        blk.stmts.push(parse_quote! {
369                            n += #field.pack(buf);
370                        });
371
372                        blk_unpack_fields.push(parse_quote! { #field });
373
374                        blk_unpack.stmts.push(parse_quote! {
375                            let #field =::msgpacker::Unpackable::unpack_with_ofs(buf).map(|(nv, t)| {
376                                n += nv;
377                                buf = &buf[nv..];
378                                t
379                            })?;
380                        });
381
382                        blk_unpack_iter.stmts.push(parse_quote! {
383                            let #field =::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
384                                n += nv;
385                                t
386                            })?;
387                        });
388                    });
389
390                let mut arm: syn::Arm = parse_quote! {
391                    #name::#ident {} => #blk,
392                };
393
394                f.named
395                    .iter()
396                    .filter_map(|n| n.ident.as_ref())
397                    .for_each(|field| {
398                        match &mut arm.pat {
399                            Pat::Struct(s) => {
400                                s.fields.push(FieldPat {
401                                    attrs: vec![],
402                                    member: Member::Named(field.clone()),
403                                    colon_token: None,
404                                    pat: Box::new(Pat::Ident(PatIdent {
405                                        attrs: vec![],
406                                        by_ref: None,
407                                        mutability: None,
408                                        ident: field.clone(),
409                                        subpat: None,
410                                    })),
411                                });
412                            }
413                            _ => todo!(
414                                "enum variant is not implemented for derive macro; implement the traits manually"
415                            ),
416                        }
417                    });
418
419                block_packable.arms.push(arm);
420
421                blk_unpack.stmts.push(parse_quote! {
422                    slf = #name::#ident { #blk_unpack_fields };
423                });
424
425                blk_unpack_iter.stmts.push(parse_quote! {
426                    slf = #name::#ident { #blk_unpack_fields };
427                });
428
429                block_unpackable.arms.push(parse_quote! {
430                    #discriminant => #blk_unpack,
431                });
432
433                block_unpackable_iter.arms.push(parse_quote! {
434                    #discriminant => #blk_unpack_iter,
435                });
436            }
437
438            Fields::Unnamed(f) => {
439                let mut blk: Block = parse_str("{}").unwrap();
440                let mut blk_unpack: Block = parse_str("{}").unwrap();
441                let mut blk_unpack_iter: Block = parse_str("{}").unwrap();
442
443                blk.stmts.push(parse_quote! {
444                    n += (#discriminant as u32).pack(buf);
445                });
446
447                let mut tuple_arm: ExprTuple = parse_str("()").unwrap();
448                f.unnamed.iter().enumerate().for_each(|(ii, _field)| {
449                    let ti: Expr = parse_str(format!("t{}", ii).as_str()).unwrap();
450                    tuple_arm.elems.push(ti.clone());
451
452                    blk.stmts.push(parse_quote! {
453                        n += #ti.pack(buf);
454                    });
455
456                    blk_unpack.stmts.push(parse_quote! {
457                        let #ti =::msgpacker::Unpackable::unpack_with_ofs(buf).map(|(nv, t)| {
458                            n += nv;
459                            buf = &buf[nv..];
460                            t
461                        })?;
462                    });
463
464                    blk_unpack_iter.stmts.push(parse_quote! {
465                        let #ti =::msgpacker::Unpackable::unpack_iter(bytes.by_ref()).map(|(nv, t)| {
466                            n += nv;
467                            t
468                        })?;
469                    });
470                });
471
472                blk_unpack.stmts.push(parse_quote! {
473                    slf = #name::#ident #tuple_arm;
474                });
475
476                blk_unpack_iter.stmts.push(parse_quote! {
477                    slf = #name::#ident #tuple_arm;
478                });
479
480                block_packable.arms.push(parse_quote! {
481                    #name::#ident #tuple_arm => #blk,
482                });
483
484                block_unpackable.arms.push(parse_quote! {
485                    #discriminant => #blk_unpack,
486                });
487
488                block_unpackable_iter.arms.push(parse_quote! {
489                    #discriminant => #blk_unpack_iter,
490                });
491            }
492
493            Fields::Unit => {
494                block_packable.arms.push(parse_quote! {
495                    #name::#ident => {
496                        n += (#discriminant as u32).pack(buf);
497                    }
498                });
499
500                block_unpackable.arms.push(parse_quote! {
501                    #discriminant => slf = #name::#ident,
502                });
503
504                block_unpackable_iter.arms.push(parse_quote! {
505                    #discriminant => slf = #name::#ident,
506                });
507            }
508        }
509    });
510
511    block_unpackable.arms.push(parse_quote! {
512        _ => {
513            return Err(::msgpacker::Error::InvalidEnumVariant);
514        }
515    });
516
517    block_unpackable_iter.arms.push(parse_quote! {
518        _ => {
519            return Err(::msgpacker::Error::InvalidEnumVariant);
520        }
521    });
522
523    quote! {
524        impl ::msgpacker::Packable for #name {
525            fn pack<T>(&self, buf: &mut T) -> usize
526            where
527                T: Extend<u8>,
528            {
529                let mut n = 0;
530
531                #block_packable;
532
533                return n;
534            }
535        }
536
537        impl ::msgpacker::Unpackable for #name {
538            type Error = ::msgpacker::Error;
539
540            #[allow(unused_mut)]
541            fn unpack_with_ofs(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> {
542                let (mut n, discriminant) = u32::unpack_with_ofs(&mut buf)?;
543                buf = &buf[n..];
544                let slf;
545
546                #block_unpackable;
547
548                Ok((n, slf))
549            }
550
551            fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
552            where
553                I: IntoIterator<Item = u8>,
554            {
555                let mut bytes = bytes.into_iter();
556                let (mut n, discriminant) = u32::unpack_iter(bytes.by_ref())?;
557                let slf;
558
559                #block_unpackable_iter;
560
561                Ok((n, slf))
562            }
563        }
564    }
565}
566
567#[proc_macro_derive(MsgPacker, attributes(msgpacker))]
568pub fn msg_packer(input: TokenStream) -> TokenStream {
569    let input = parse_macro_input!(input as DeriveInput);
570
571    let name = input.ident;
572    let data = input.data;
573    match data {
574        Data::Struct(DataStruct {
575            fields: Fields::Named(f),
576            ..
577        }) => impl_fields_named(name, f).into(),
578
579        Data::Struct(DataStruct {
580            fields: Fields::Unnamed(f),
581            ..
582        }) => impl_fields_unnamed(name, f).into(),
583
584        Data::Struct(DataStruct {
585            fields: Fields::Unit,
586            ..
587        }) => impl_fields_unit(name).into(),
588
589        Data::Enum(DataEnum { variants, .. }) => impl_fields_enum(name, variants).into(),
590
591        Data::Union(DataUnion { .. }) => {
592            todo!(
593                "union support is not implemented for derive macro; implement the traits manually"
594            )
595        }
596    }
597}