1use proc_macro::TokenStream;
10use proc_macro_crate::{FoundCrate, crate_name};
11use proc_macro2::{Span, TokenStream as TokenStream2};
12use quote::quote;
13use syn::{Attribute, DeriveInput, Ident, Result, Type, parse_quote, parse2};
14
15fn enum_repr_ty(attrs: &[Attribute]) -> Option<Type> {
16 let mut out: Option<Type> = None;
17 for attr in attrs {
18 if attr.path().is_ident("repr") {
19 let _ = attr.parse_nested_meta(|meta| {
20 if let Some(ident) = meta.path.get_ident() {
21 match ident.to_string().as_str() {
22 "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32" | "i64"
23 | "isize" => {
24 let ty_ident = Ident::new(&ident.to_string(), Span::call_site());
25 out = Some(parse_quote!(#ty_ident));
26 }
27 _ => {}
28 }
29 }
30 Ok(())
31 });
32 }
33 }
34 out
35}
36
37fn crate_path() -> TokenStream2 {
38 let found = crate_name("lencode");
42 match found {
43 Ok(FoundCrate::Itself) => quote!(::lencode),
44 Ok(FoundCrate::Name(actual_name)) => {
45 let ident = Ident::new(&actual_name, Span::call_site());
46 quote!(::#ident)
47 }
48 Err(_) => quote!(::lencode),
49 }
50}
51
52#[proc_macro_derive(Encode)]
58pub fn derive_encode(input: TokenStream) -> TokenStream {
59 match derive_encode_impl(input) {
60 Ok(ts) => ts.into(),
61 Err(err) => err.to_compile_error().into(),
62 }
63}
64
65#[proc_macro_derive(Decode)]
69pub fn derive_decode(input: TokenStream) -> TokenStream {
70 match derive_decode_impl(input) {
71 Ok(ts) => ts.into(),
72 Err(err) => err.to_compile_error().into(),
73 }
74}
75
76#[inline(always)]
77fn derive_encode_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
78 let derive_input = parse2::<DeriveInput>(input.into())?;
79 let krate = crate_path();
80 let name = derive_input.ident.clone();
81 let mut generics = derive_input.generics.clone();
83 {
84 let type_idents: Vec<Ident> = generics.type_params().map(|tp| tp.ident.clone()).collect();
86 let where_clause = generics.make_where_clause();
87 for ident in type_idents {
88 where_clause
90 .predicates
91 .push(parse_quote!(#ident: #krate::prelude::Encode));
92 }
93 }
94 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
95 match derive_input.data {
96 syn::Data::Struct(data_struct) => {
97 let fields = data_struct.fields;
98 let encode_body = match fields {
99 syn::Fields::Named(ref named_fields) => {
100 let field_encodes = named_fields.named.iter().map(|f| {
101 let fname = &f.ident;
102 let ftype = &f.ty;
103 quote! {
104 total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(&self.#fname, writer, dedupe_encoder.as_deref_mut())?;
105 }
106 });
107 quote! {
108 #(#field_encodes)*
109 }
110 }
111 syn::Fields::Unnamed(ref unnamed_fields) => {
112 let field_encodes = unnamed_fields.unnamed.iter().enumerate().map(|(i, f)| {
113 let index = syn::Index::from(i);
114 let ftype = &f.ty;
115 quote! {
116 total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(&self.#index, writer, dedupe_encoder.as_deref_mut())?;
117 }
118 });
119 quote! {
120 #(#field_encodes)*
121 }
122 }
123 syn::Fields::Unit => quote! {},
124 };
125 Ok(quote! {
126 impl #impl_generics #krate::prelude::Encode for #name #ty_generics #where_clause {
127 #[inline(always)]
128 fn encode_ext(
129 &self,
130 writer: &mut impl #krate::io::Write,
131 mut dedupe_encoder: Option<&mut #krate::dedupe::DedupeEncoder>,
132 ) -> #krate::Result<usize> {
133 let mut total_bytes = 0;
134 #encode_body
135 Ok(total_bytes)
136 }
137 }
138 })
139 }
140 syn::Data::Enum(data_enum) => {
141 let is_c_like = data_enum
142 .variants
143 .iter()
144 .all(|v| matches!(v.fields, syn::Fields::Unit));
145 let repr_ty = enum_repr_ty(&derive_input.attrs);
146 let use_numeric_disc = is_c_like && repr_ty.is_some();
147 let repr_ty_ts = repr_ty.unwrap_or(parse_quote!(usize));
148 let variant_matches = data_enum.variants.iter().enumerate().map(|(idx, v)| {
149 let vname = &v.ident;
150 let idx_lit = syn::Index::from(idx);
151 match &v.fields {
152 syn::Fields::Named(named_fields) => {
153 let fields: Vec<_> = named_fields
154 .named
155 .iter()
156 .map(|f| (f.ident.as_ref().unwrap().clone(), f.ty.clone()))
157 .collect();
158
159 let field_names: Vec<_> = fields.iter().map(|(ident, _)| ident).collect();
160 let field_encodes = fields.iter().map(|(fname, ftype)| {
161 quote! {
162 total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(#fname, writer, dedupe_encoder.as_deref_mut())?;
163 }
164 });
165 quote! {
166 #name::#vname { #(#field_names),* } => {
167 total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
168 #(#field_encodes)*
169 }
170 }
171 }
172 syn::Fields::Unnamed(unnamed_fields) => {
173 let fields: Vec<_> = unnamed_fields
174 .unnamed
175 .iter()
176 .enumerate()
177 .map(|(i, f)| (Ident::new(&format!("field{}", i), Span::call_site()), f.ty.clone()))
178 .collect();
179
180 let field_indices: Vec<_> = fields.iter().map(|(ident, _)| ident).collect();
181 let field_encodes = fields.iter().map(|(fname, ftype)| {
182 quote! {
183 total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(#fname, writer, dedupe_encoder.as_deref_mut())?;
184 }
185 });
186 quote! {
187 #name::#vname( #(#field_indices),* ) => {
188 total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
189 #(#field_encodes)*
190 }
191 }
192 }
193 syn::Fields::Unit => {
194 if use_numeric_disc {
195 quote! {
196 #name::#vname => {
197 let disc = (#name::#vname as #repr_ty_ts) as usize;
198 total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(disc, writer)?;
199 }
200 }
201 } else {
202 quote! {
203 #name::#vname => {
204 total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
205 }
206 }
207 }
208 }
209 }
210 });
211 Ok(quote! {
212 impl #impl_generics #krate::prelude::Encode for #name #ty_generics #where_clause {
213 #[inline(always)]
214 fn encode_ext(
215 &self,
216 writer: &mut impl #krate::io::Write,
217 mut dedupe_encoder: Option<&mut #krate::dedupe::DedupeEncoder>,
218 ) -> #krate::Result<usize> {
219 let mut total_bytes = 0;
220 match self {
221 #(#variant_matches)*
222 }
223 Ok(total_bytes)
224 }
225 }
226 })
227 }
228 syn::Data::Union(_data_union) => {
229 Err(syn::Error::new_spanned(
231 derive_input.ident,
232 "Encode cannot be derived for unions",
233 ))
234 }
235 }
236}
237
238#[inline(always)]
239fn derive_decode_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
240 let derive_input = parse2::<DeriveInput>(input.into())?;
241 let krate = crate_path();
242 let name = derive_input.ident.clone();
243 let mut generics = derive_input.generics.clone();
245 {
246 let type_idents: Vec<Ident> = generics.type_params().map(|tp| tp.ident.clone()).collect();
248 let where_clause = generics.make_where_clause();
249 for ident in type_idents {
250 where_clause
252 .predicates
253 .push(parse_quote!(#ident: #krate::prelude::Decode));
254 }
255 }
256 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
257 match derive_input.data {
258 syn::Data::Struct(data_struct) => {
259 let fields = data_struct.fields;
260 let decode_body = match fields {
261 syn::Fields::Named(ref named_fields) => {
262 let field_decodes = named_fields.named.iter().map(|f| {
263 let fname = &f.ident;
264 let ftype = &f.ty;
265 quote! {
266 #fname: <#ftype as #krate::prelude::Decode>::decode_ext(reader, dedupe_decoder.as_deref_mut())?,
267 }
268 });
269 quote! {
270 Ok(#name {
271 #(#field_decodes)*
272 })
273 }
274 }
275 syn::Fields::Unnamed(ref unnamed_fields) => {
276 let field_decodes = unnamed_fields.unnamed.iter().map(|f| {
277 let ftype = &f.ty;
278 quote! {
279 <#ftype as #krate::prelude::Decode>::decode_ext(reader, dedupe_decoder.as_deref_mut())?,
280 }
281 });
282 quote! {
283 Ok(#name(
284 #(#field_decodes)*
285 ))
286 }
287 }
288 syn::Fields::Unit => quote! { Ok(#name) },
289 };
290 Ok(quote! {
291 impl #impl_generics #krate::prelude::Decode for #name #ty_generics #where_clause {
292 #[inline(always)]
293 fn decode_ext(
294 reader: &mut impl #krate::io::Read,
295 mut dedupe_decoder: Option<&mut #krate::dedupe::DedupeDecoder>,
296 ) -> #krate::Result<Self> {
297 #decode_body
298 }
299 }
300 })
301 }
302 syn::Data::Enum(data_enum) => {
303 let is_c_like = data_enum
304 .variants
305 .iter()
306 .all(|v| matches!(v.fields, syn::Fields::Unit));
307 let repr_ty = enum_repr_ty(&derive_input.attrs);
308 let use_numeric_disc = is_c_like && repr_ty.is_some();
309 let repr_ty_ts = repr_ty.unwrap_or(parse_quote!(usize));
310 let variant_matches = data_enum.variants.iter().enumerate().map(|(idx, v)| {
311 let vname = &v.ident;
312 let idx_lit = syn::Index::from(idx);
313 match &v.fields {
314 syn::Fields::Named(named_fields) => {
315 let field_decodes = named_fields.named.iter().map(|f| {
316 let fname = &f.ident;
317 let ftype = &f.ty;
318 quote! {
319 #fname: <#ftype as #krate::prelude::Decode>::decode_ext(reader, dedupe_decoder.as_deref_mut())?,
320 }
321 });
322 quote! {
323 #idx_lit => Ok(#name::#vname { #(#field_decodes)* }),
324 }
325 }
326 syn::Fields::Unnamed(unnamed_fields) => {
327 let field_decodes = unnamed_fields.unnamed.iter().map(|f| {
328 let ftype = &f.ty;
329 quote! {
330 <#ftype as #krate::prelude::Decode>::decode_ext(reader, dedupe_decoder.as_deref_mut())?,
331 }
332 });
333 quote! {
334 #idx_lit => Ok(#name::#vname( #(#field_decodes)* )),
335 }
336 }
337 syn::Fields::Unit => {
338 if use_numeric_disc {
339 quote! {
340 ((#name::#vname as #repr_ty_ts) as usize) => Ok(#name::#vname),
341 }
342 } else {
343 quote! {
344 #idx_lit => Ok(#name::#vname),
345 }
346 }
347 }
348 }
349 });
350 Ok(quote! {
351 impl #impl_generics #krate::prelude::Decode for #name #ty_generics #where_clause {
352 #[inline(always)]
353 fn decode_ext(
354 reader: &mut impl #krate::io::Read,
355 mut dedupe_decoder: Option<&mut #krate::dedupe::DedupeDecoder>,
356 ) -> #krate::Result<Self> {
357 let variant_idx = <usize as #krate::prelude::Decode>::decode_discriminant(reader)?;
358 match variant_idx {
359 #(#variant_matches)*
360 _ => Err(#krate::io::Error::InvalidData),
361 }
362 }
363 }
364 })
365 }
366 syn::Data::Union(_data_union) => {
367 Err(syn::Error::new_spanned(
369 derive_input.ident,
370 "Decode cannot be derived for unions",
371 ))
372 }
373 }
374}
375
/// Pins the exact token output of the `Encode` derive for a struct with two
/// named fields.
#[test]
fn test_derive_encode_struct_basic() {
    let input = quote! {
        struct TestStruct {
            a: u32,
            b: String,
        }
    };

    // The comparison is on stringified tokens, so whitespace here is irrelevant.
    let expected = quote! {
        impl ::lencode::prelude::Encode for TestStruct {
            #[inline(always)]
            fn encode_ext(
                &self,
                writer: &mut impl ::lencode::io::Write,
                mut dedupe_encoder: Option<&mut ::lencode::dedupe::DedupeEncoder>,
            ) -> ::lencode::Result<usize> {
                let mut total_bytes = 0;
                total_bytes += <u32 as ::lencode::prelude::Encode>::encode_ext(&self.a, writer, dedupe_encoder.as_deref_mut())?;
                total_bytes += <String as ::lencode::prelude::Encode>::encode_ext(&self.b, writer, dedupe_encoder.as_deref_mut())?;
                Ok(total_bytes)
            }
        }
    };

    let actual = derive_encode_impl(input).unwrap();
    assert_eq!(actual.to_string(), expected.to_string());
}
410
/// Pins the exact token output of the `Decode` derive for a struct with two
/// named fields.
#[test]
fn test_derive_decode_struct_basic() {
    let input = quote! {
        struct TestStruct {
            a: u32,
            b: String,
        }
    };

    // The comparison is on stringified tokens, so whitespace here is irrelevant.
    let expected = quote! {
        impl ::lencode::prelude::Decode for TestStruct {
            #[inline(always)]
            fn decode_ext(
                reader: &mut impl ::lencode::io::Read,
                mut dedupe_decoder: Option<&mut ::lencode::dedupe::DedupeDecoder>,
            ) -> ::lencode::Result<Self> {
                Ok(TestStruct {
                    a: <u32 as ::lencode::prelude::Decode>::decode_ext(
                        reader,
                        dedupe_decoder.as_deref_mut()
                    )?,
                    b: <String as ::lencode::prelude::Decode>::decode_ext(
                        reader,
                        dedupe_decoder.as_deref_mut()
                    )?,
                })
            }
        }
    };

    let actual = derive_decode_impl(input).unwrap();
    assert_eq!(actual.to_string(), expected.to_string());
}