1mod repr;
2
3use proc_macro::TokenStream;
4use repr::Repr;
5use std::usize;
6use syn::{
7 Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, Generics, Index, Path, TraitBound,
8 TraitBoundModifier, TypeParamBound,
9};
10
11fn add_trait_bounds(generics: &mut Generics, path: Path) {
12 for param in generics.type_params_mut() {
13 param.bounds.push(TypeParamBound::Trait(TraitBound {
14 paren_token: None,
15 modifier: TraitBoundModifier::None,
16 lifetimes: None,
17 path: path.clone(),
18 }));
19 }
20}
21
22fn encode_struct(data: DataStruct) -> proc_macro2::TokenStream {
23 match data.fields {
24 Fields::Unnamed(fields) => {
25 let tys = fields.unnamed.iter().map(|field| &field.ty);
26 let counter = (0..usize::MAX).map(Index::from);
27
28 quote::quote! {
29 #(<#tys as ::binser::Encode>::encode(self.#counter, writer));*
30 }
31 }
32 Fields::Named(fields) => {
33 let tys = fields.named.iter().map(|field| &field.ty);
34 let names = fields
35 .named
36 .iter()
37 .map(|field| field.ident.as_ref().unwrap());
38
39 quote::quote! {
40 #(<#tys as ::binser::Encode>::encode(self.#names, writer));*
41 }
42 }
43 Fields::Unit => quote::quote! {},
44 }
45}
46
47fn encode_enum(repr: Repr, data: DataEnum) -> proc_macro2::TokenStream {
48 for variant in &data.variants {
49 match variant.fields {
50 Fields::Unit => {}
51 _ => panic!("enum fields must not contain any data"),
52 }
53 }
54
55 let names = data.variants.iter().map(|variant| &variant.ident);
56 let discriminants = enum_discriminants(&data);
57
58 quote::quote! {
59 match self {
60 #(Self::#names => writer.write::<#repr>(#discriminants),)*
61 }
62 }
63}
64
65fn enum_discriminants(data: &DataEnum) -> impl Iterator<Item = &Expr> {
66 data.variants
67 .iter()
68 .map(|variant| match variant.discriminant.as_ref() {
69 Some(discriminant) => &discriminant.1,
70 None => panic!("enums must have explicit discriminants"),
71 })
72}
73
74#[proc_macro_derive(Encode)]
75pub fn derive_encode(input: TokenStream) -> TokenStream {
76 let mut input = syn::parse_macro_input!(input as DeriveInput);
77 let body = match input.data {
78 Data::Struct(data) => encode_struct(data),
79 Data::Enum(data) => {
80 let repr = match Repr::parse(&input.attrs) {
81 Ok(repr) => repr,
82 Err(err) => panic!("failed to parse repr: {}", err),
83 };
84
85 encode_enum(repr, data)
86 }
87 Data::Union(_) => panic!("only structs and enums can #[derive(Encode)]"),
88 };
89
90 add_trait_bounds(&mut input.generics, syn::parse_quote! { ::binser::Encode });
91
92 let name = input.ident;
93 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
94
95 (quote::quote! {
96 impl #impl_generics ::binser::Encode for #name #ty_generics #where_clause {
97 fn encode(self, writer: &mut ::binser::Writer) {
98 #body
99 }
100 }
101 })
102 .into()
103}
104
105fn decode_struct(data: DataStruct) -> proc_macro2::TokenStream {
106 match data.fields {
107 Fields::Unnamed(fields) => {
108 let tys = fields.unnamed.iter().map(|field| &field.ty);
109
110 quote::quote! {
111 Ok(Self(#(<#tys as ::binser::Decode>::decode(reader)?),*))
112 }
113 }
114 Fields::Named(fields) => {
115 let names = fields
116 .named
117 .iter()
118 .map(|field| field.ident.as_ref().unwrap());
119 let tys = fields.named.iter().map(|field| &field.ty);
120
121 quote::quote! {
122 Ok(Self {
123 #(#names: <#tys as ::binser::Decode>::decode(reader)?),*
124 })
125 }
126 }
127 Fields::Unit => quote::quote! { Ok(Self) },
128 }
129}
130
131fn decode_enum(repr: Repr, data: DataEnum) -> proc_macro2::TokenStream {
132 let names = data.variants.iter().map(|variant| &variant.ident);
133 let discriminants = enum_discriminants(&data);
134
135 quote::quote! {
136 let value = reader.read::<#repr>()?;
137
138 #(if value == #discriminants {
139 return Ok(Self::#names);
140 })*
141
142 Err(::binser::Error::InvalidVariant)
143 }
144}
145
146#[proc_macro_derive(Decode)]
147pub fn derive_decode(input: TokenStream) -> TokenStream {
148 let mut input = syn::parse_macro_input!(input as DeriveInput);
149 let body = match input.data {
150 Data::Struct(data) => decode_struct(data),
151 Data::Enum(data) => {
152 let repr = match Repr::parse(&input.attrs) {
153 Ok(repr) => repr,
154 Err(err) => panic!("failed to parse repr: {}", err),
155 };
156
157 decode_enum(repr, data)
158 }
159 Data::Union(_) => panic!("only structs and enums can #[derive(Encode)]"),
160 };
161
162 add_trait_bounds(&mut input.generics, syn::parse_quote! { ::binser::Decode });
163
164 let name = input.ident;
165 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
166
167 (quote::quote! {
168 impl #impl_generics ::binser::Decode for #name #ty_generics #where_clause {
169 fn decode(reader: &mut ::binser::Reader) -> Result<Self, ::binser::Error> {
170 #body
171 }
172 }
173 })
174 .into()
175}
176
177fn size_struct(data: DataStruct) -> proc_macro2::TokenStream {
178 let fields = match data.fields {
179 Fields::Unnamed(fields) => fields.unnamed,
180 Fields::Named(fields) => fields.named,
181 Fields::Unit => return quote::quote! { 0 },
182 };
183
184 let tys = fields.iter().map(|field| &field.ty);
185 quote::quote! {
186 0 #(+ <#tys as ::binser::Size>::SIZE)*
187 }
188}
189
190fn size_enum(repr: Repr) -> proc_macro2::TokenStream {
191 quote::quote! { <#repr as ::binser::Size>::SIZE }
192}
193
194#[proc_macro_derive(Size)]
195pub fn derive_size(input: TokenStream) -> TokenStream {
196 let mut input = syn::parse_macro_input!(input as DeriveInput);
197 let body = match input.data {
198 Data::Struct(data) => size_struct(data),
199 Data::Enum(_) => {
200 let repr = match Repr::parse(&input.attrs) {
201 Ok(repr) => repr,
202 Err(err) => panic!("failed to parse repr: {}", err),
203 };
204
205 size_enum(repr)
206 }
207 Data::Union(_) => panic!("only structs and enums can #[derive(Size)]"),
208 };
209
210 add_trait_bounds(&mut input.generics, syn::parse_quote! { ::binser::Size });
211
212 let name = input.ident;
213 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
214
215 (quote::quote! {
216 impl #impl_generics ::binser::Size for #name #ty_generics #where_clause {
217 const SIZE: usize = #body;
218 }
219 })
220 .into()
221}