endecode_derive/
lib.rs

1use proc_macro2::{Literal, Span, TokenTree};
2use quote::quote;
3use syn::{
4    parse_macro_input, parse_quote_spanned, Data, DeriveInput, Fields, Ident, WherePredicate,
5};
6
7#[proc_macro_derive(Encode)]
8pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
9    let DeriveInput {
10        ident,
11        data,
12        mut generics,
13        ..
14    } = parse_macro_input!(input as DeriveInput);
15
16    match data {
17        Data::Struct(st) => {
18            generics
19                .make_where_clause()
20                .predicates
21                .extend(st.fields.iter().map(|field| -> WherePredicate {
22                    let ty = &field.ty;
23                    parse_quote_spanned! {
24                        Span::call_site() => #ty: endecode::encode::Encode
25                    }
26                }));
27            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
28
29            let fields = match st.fields {
30                Fields::Named(fields) => fields
31                    .named
32                    .into_iter()
33                    .map(|field| TokenTree::Ident(field.ident.clone().unwrap()))
34                    .collect(),
35                Fields::Unnamed(fields) => (0..fields.unnamed.len())
36                    .map(|field| TokenTree::Literal(Literal::usize_unsuffixed(field)))
37                    .collect(),
38                Fields::Unit => vec![],
39            };
40
41            quote! {
42                #[automatically_derived]
43                impl #impl_generics endecode::encode::Encode for #ident #ty_generics #where_clause {
44                    fn encode_internal(&self, vec: &mut Vec<u8>) {
45                        #(
46                            self.#fields.encode_internal(vec);
47                        )*
48                    }
49                }
50            }
51            .into()
52        }
53        Data::Enum(en) => {
54            let variants = en.variants.into_iter().enumerate().map(|(num, variant)| {
55                let num = Literal::usize_unsuffixed(num);
56                let ident = variant.ident;
57                match variant.fields {
58                    Fields::Named(fields) => {
59                        let fields: Vec<_> = fields
60                            .named
61                            .into_iter()
62                            .map(|field| field.ident.unwrap())
63                            .collect();
64
65                        quote! {
66                            #ident { #(#fields),* } => {
67                                vec.push(#num);
68                                #(
69                                    #fields.encode_internal(vec);
70                                )*
71                            }
72                        }
73                    }
74                    Fields::Unnamed(fields) => {
75                        let fields: Vec<_> = (0..fields.unnamed.len())
76                            .map(|field| Ident::new(&format!("_{field}"), Span::call_site()))
77                            .collect();
78
79                        quote! {
80                            #ident(#(#fields),*) => {
81                                vec.push(#num);
82                                #(
83                                    #fields.encode_internal(vec);
84                                )*
85                            }
86                        }
87                    }
88                    Fields::Unit => {
89                        quote! {
90                            #ident => vec.push(#num),
91                        }
92                    }
93                }
94            });
95
96            quote! {
97                #[automatically_derived]
98                impl endecode::encode::Encode for #ident {
99                    fn encode_internal(&self, vec: &mut Vec<u8>) {
100                        match self {
101                            #(Self::#variants)*
102                        }
103                    }
104                }
105            }
106            .into()
107        }
108        Data::Union(_) => {
109            panic!("Union... go implement it yourself");
110        }
111    }
112}
113
114#[proc_macro_derive(Decode)]
115pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
116    let DeriveInput {
117        ident,
118        data,
119        mut generics,
120        ..
121    } = parse_macro_input!(input as DeriveInput);
122
123    match data {
124        Data::Struct(st) => {
125            generics
126                .make_where_clause()
127                .predicates
128                .extend(st.fields.iter().map(|field| -> WherePredicate {
129                    let ty = &field.ty;
130                    parse_quote_spanned! {
131                        Span::call_site() => #ty: endecode::encode::Encode
132                    }
133                }));
134            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
135
136            let fields = match st.fields {
137                Fields::Named(fields) => {
138                    let mut types = vec![];
139                    let fields: Vec<_> = fields
140                        .named
141                        .into_iter()
142                        .map(|field| {
143                            types.push(field.ty);
144                            field.ident.unwrap()
145                        })
146                        .collect();
147
148                    quote! {
149                        {
150                            #(
151                                #fields: <#types>::decode(iter)
152                            ),*
153                        }
154                    }
155                }
156                Fields::Unnamed(fields) => {
157                    let mut types = vec![];
158                    for field in fields.unnamed {
159                        types.push(field.ty);
160                    }
161
162                    quote! {
163                        (#(<#types>::decode(iter)),*)
164                    }
165                }
166                Fields::Unit => quote! {},
167            };
168
169            quote! {
170                #[automatically_derived]
171                impl #impl_generics endecode::decode::Decode for #ident #ty_generics #where_clause {
172                    fn decode(iter: &mut impl Iterator<Item = u8>) -> Self {
173                        Self #fields
174                    }
175                }
176            }
177            .into()
178        }
179        Data::Enum(en) => {
180            let nums = (0..en.variants.len()).map(|num| Literal::usize_unsuffixed(num));
181            let variants = en.variants.into_iter().map(|variant| {
182                let ident = variant.ident;
183                match variant.fields {
184                    Fields::Named(fields) => {
185                        let mut types = vec![];
186                        let fields: Vec<_> = fields
187                            .named
188                            .into_iter()
189                            .map(|field| {
190                                types.push(field.ty);
191                                field.ident.unwrap()
192                            })
193                            .collect();
194
195                        quote! {
196                            #ident { #(#fields: <#types>::decode(iter)),* }
197                        }
198                    }
199                    Fields::Unnamed(fields) => {
200                        let mut types = vec![];
201                        for field in fields.unnamed {
202                            types.push(field.ty);
203                        }
204
205                        quote! {
206                            #ident(#(<#types>::decode(iter)),*)
207                        }
208                    }
209                    Fields::Unit => quote! { #ident },
210                }
211            });
212
213            quote! {
214                #[automatically_derived]
215                impl endecode::decode::Decode for #ident {
216                    fn decode(iter: &mut impl Iterator<Item = u8>) -> Self {
217                        match iter.next().unwrap() {
218                            #(#nums => Self::#variants,)*
219                            i => panic!(concat!("index {} out of bounds for enum ", stringify!(#ident)), i)
220                        }
221                    }
222                }
223            }
224            .into()
225        }
226        Data::Union(_) => {
227            panic!("Now decoding?! You have to be kidding me");
228        }
229    }
230}