lencode_macros/
lib.rs

1//! Derive macros for `lencode` encoding/decoding traits.
2//!
3//! - `#[derive(Encode)]` implements `lencode::Encode` by writing fields in declaration order
4//!   and encoding enum discriminants compactly.
5//! - `#[derive(Decode)]` implements `lencode::Decode` to read the same layout.
6//!
7//! For C‑like enums with an explicit `#[repr(uN/iN)]`, the numeric value of the discriminant
8//! is preserved; otherwise, the variant index is used.
9use proc_macro::TokenStream;
10use proc_macro_crate::{FoundCrate, crate_name};
11use proc_macro2::{Span, TokenStream as TokenStream2};
12use quote::quote;
13use syn::{Attribute, DeriveInput, Ident, Result, Type, parse_quote, parse2};
14
15fn enum_repr_ty(attrs: &[Attribute]) -> Option<Type> {
16    let mut out: Option<Type> = None;
17    for attr in attrs {
18        if attr.path().is_ident("repr") {
19            let _ = attr.parse_nested_meta(|meta| {
20                if let Some(ident) = meta.path.get_ident() {
21                    match ident.to_string().as_str() {
22                        "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32" | "i64"
23                        | "isize" => {
24                            let ty_ident = Ident::new(&ident.to_string(), Span::call_site());
25                            out = Some(parse_quote!(#ty_ident));
26                        }
27                        _ => {}
28                    }
29                }
30                Ok(())
31            });
32        }
33    }
34    out
35}
36
37fn crate_path() -> TokenStream2 {
38    // Resolve the path to the main `lencode` crate from the macro crate, honoring any
39    // potential crate renames by the downstream user. In ambiguous contexts like doctests,
40    // prefer the absolute `::lencode` path.
41    let found = crate_name("lencode");
42    match found {
43        Ok(FoundCrate::Itself) => quote!(::lencode),
44        Ok(FoundCrate::Name(actual_name)) => {
45            let ident = Ident::new(&actual_name, Span::call_site());
46            quote!(::#ident)
47        }
48        Err(_) => quote!(::lencode),
49    }
50}
51
52/// Derives `lencode::Encode` for structs and enums.
53///
54/// - Structs: fields are encoded in declaration order.
55/// - Enums: a compact discriminant is written, then any fields as for structs. C‑like enums
56///   with `#[repr(uN/iN)]` preserve the numeric discriminant.
57#[proc_macro_derive(Encode)]
58pub fn derive_encode(input: TokenStream) -> TokenStream {
59    match derive_encode_impl(input) {
60        Ok(ts) => ts.into(),
61        Err(err) => err.to_compile_error().into(),
62    }
63}
64
65/// Derives `lencode::Decode` for structs and enums.
66///
67/// The layout matches what `#[derive(Encode)]` produces.
68#[proc_macro_derive(Decode)]
69pub fn derive_decode(input: TokenStream) -> TokenStream {
70    match derive_decode_impl(input) {
71        Ok(ts) => ts.into(),
72        Err(err) => err.to_compile_error().into(),
73    }
74}
75
76#[inline(always)]
77fn derive_encode_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
78    let derive_input = parse2::<DeriveInput>(input.into())?;
79    let krate = crate_path();
80    let name = derive_input.ident.clone();
81    // Prepare generics and add Encode bounds for all type parameters
82    let mut generics = derive_input.generics.clone();
83    {
84        // Collect type parameter idents first to avoid borrow conflicts
85        let type_idents: Vec<Ident> = generics.type_params().map(|tp| tp.ident.clone()).collect();
86        let where_clause = generics.make_where_clause();
87        for ident in type_idents {
88            // Add `T: Encode` bound for each type parameter `T`
89            where_clause
90                .predicates
91                .push(parse_quote!(#ident: #krate::prelude::Encode));
92        }
93    }
94    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
95    match derive_input.data {
96        syn::Data::Struct(data_struct) => {
97            let fields = data_struct.fields;
98            let encode_body = match fields {
99                syn::Fields::Named(ref named_fields) => {
100                    let field_encodes = named_fields.named.iter().map(|f| {
101                        let fname = &f.ident;
102                        let ftype = &f.ty;
103                        quote! {
104                            total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(&self.#fname, writer, dedupe_encoder.as_deref_mut())?;
105                        }
106                    });
107                    quote! {
108                        #(#field_encodes)*
109                    }
110                }
111                syn::Fields::Unnamed(ref unnamed_fields) => {
112                    let field_encodes = unnamed_fields.unnamed.iter().enumerate().map(|(i, f)| {
113                        let index = syn::Index::from(i);
114                        let ftype = &f.ty;
115                        quote! {
116                            total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(&self.#index, writer, dedupe_encoder.as_deref_mut())?;
117                        }
118                    });
119                    quote! {
120                        #(#field_encodes)*
121                    }
122                }
123                syn::Fields::Unit => quote! {},
124            };
125            Ok(quote! {
126                impl #impl_generics #krate::prelude::Encode for #name #ty_generics #where_clause {
127                    #[inline(always)]
128                    fn encode_ext(
129                        &self,
130                        writer: &mut impl #krate::io::Write,
131                        mut dedupe_encoder: Option<&mut #krate::dedupe::DedupeEncoder>,
132                    ) -> #krate::Result<usize> {
133                        let mut total_bytes = 0;
134                        #encode_body
135                        Ok(total_bytes)
136                    }
137                }
138            })
139        }
140        syn::Data::Enum(data_enum) => {
141            let is_c_like = data_enum
142                .variants
143                .iter()
144                .all(|v| matches!(v.fields, syn::Fields::Unit));
145            let repr_ty = enum_repr_ty(&derive_input.attrs);
146            let use_numeric_disc = is_c_like && repr_ty.is_some();
147            let repr_ty_ts = repr_ty.unwrap_or(parse_quote!(usize));
148            let variant_matches = data_enum.variants.iter().enumerate().map(|(idx, v)| {
149				let vname = &v.ident;
150				let idx_lit = syn::Index::from(idx);
151				match &v.fields {
152					syn::Fields::Named(named_fields) => {
153						let fields: Vec<_> = named_fields
154							.named
155							.iter()
156							.map(|f| (f.ident.as_ref().unwrap().clone(), f.ty.clone()))
157							.collect();
158
159						let field_names: Vec<_> = fields.iter().map(|(ident, _)| ident).collect();
160						let field_encodes = fields.iter().map(|(fname, ftype)| {
161							quote! {
162								total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(#fname, writer, dedupe_encoder.as_deref_mut())?;
163							}
164						});
165						quote! {
166							#name::#vname { #(#field_names),* } => {
167								total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
168								#(#field_encodes)*
169							}
170						}
171					}
172					syn::Fields::Unnamed(unnamed_fields) => {
173						let fields: Vec<_> = unnamed_fields
174							.unnamed
175							.iter()
176							.enumerate()
177							.map(|(i, f)| (Ident::new(&format!("field{}", i), Span::call_site()), f.ty.clone()))
178							.collect();
179
180						let field_indices: Vec<_> = fields.iter().map(|(ident, _)| ident).collect();
181						let field_encodes = fields.iter().map(|(fname, ftype)| {
182							quote! {
183								total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(#fname, writer, dedupe_encoder.as_deref_mut())?;
184							}
185						});
186						quote! {
187							#name::#vname( #(#field_indices),* ) => {
188								total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
189								#(#field_encodes)*
190							}
191						}
192					}
193					syn::Fields::Unit => {
194                        if use_numeric_disc {
195                            quote! {
196                                #name::#vname => {
197                                    let disc = (#name::#vname as #repr_ty_ts) as usize;
198                                    total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(disc, writer)?;
199                                }
200                            }
201                        } else {
202                            quote! {
203                                #name::#vname => {
204                                    total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
205                                }
206                            }
207                        }
208                    }
209				}
210			});
211            Ok(quote! {
212                impl #impl_generics #krate::prelude::Encode for #name #ty_generics #where_clause {
213                    #[inline(always)]
214                    fn encode_ext(
215                        &self,
216                        writer: &mut impl #krate::io::Write,
217                        mut dedupe_encoder: Option<&mut #krate::dedupe::DedupeEncoder>,
218                    ) -> #krate::Result<usize> {
219                        let mut total_bytes = 0;
220                        match self {
221                            #(#variant_matches)*
222                        }
223                        Ok(total_bytes)
224                    }
225                }
226            })
227        }
228        syn::Data::Union(_data_union) => {
229            // Unions are not supported
230            Err(syn::Error::new_spanned(
231                derive_input.ident,
232                "Encode cannot be derived for unions",
233            ))
234        }
235    }
236}
237
238#[inline(always)]
239fn derive_decode_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
240    let derive_input = parse2::<DeriveInput>(input.into())?;
241    let krate = crate_path();
242    let name = derive_input.ident.clone();
243    // Prepare generics and add Decode bounds for all type parameters
244    let mut generics = derive_input.generics.clone();
245    {
246        // Collect type parameter idents first to avoid borrow conflicts
247        let type_idents: Vec<Ident> = generics.type_params().map(|tp| tp.ident.clone()).collect();
248        let where_clause = generics.make_where_clause();
249        for ident in type_idents {
250            // Add `T: Decode` bound for each type parameter `T`
251            where_clause
252                .predicates
253                .push(parse_quote!(#ident: #krate::prelude::Decode));
254        }
255    }
256    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
257    match derive_input.data {
258        syn::Data::Struct(data_struct) => {
259            let fields = data_struct.fields;
260            let decode_body = match fields {
261                syn::Fields::Named(ref named_fields) => {
262                    let field_decodes = named_fields.named.iter().map(|f| {
263                        let fname = &f.ident;
264                        let ftype = &f.ty;
265                        quote! {
266                            #fname: <#ftype as #krate::prelude::Decode>::decode_ext(reader, dedupe_decoder.as_deref_mut())?,
267                        }
268                    });
269                    quote! {
270                        Ok(#name {
271                            #(#field_decodes)*
272                        })
273                    }
274                }
275                syn::Fields::Unnamed(ref unnamed_fields) => {
276                    let field_decodes = unnamed_fields.unnamed.iter().map(|f| {
277                        let ftype = &f.ty;
278                        quote! {
279                            <#ftype as #krate::prelude::Decode>::decode_ext(reader, dedupe_decoder.as_deref_mut())?,
280                        }
281                    });
282                    quote! {
283                        Ok(#name(
284                            #(#field_decodes)*
285                        ))
286                    }
287                }
288                syn::Fields::Unit => quote! { Ok(#name) },
289            };
290            Ok(quote! {
291                impl #impl_generics #krate::prelude::Decode for #name #ty_generics #where_clause {
292                    #[inline(always)]
293                    fn decode_ext(
294                        reader: &mut impl #krate::io::Read,
295                        mut dedupe_decoder: Option<&mut #krate::dedupe::DedupeDecoder>,
296                    ) -> #krate::Result<Self> {
297                        #decode_body
298                    }
299                }
300            })
301        }
302        syn::Data::Enum(data_enum) => {
303            let is_c_like = data_enum
304                .variants
305                .iter()
306                .all(|v| matches!(v.fields, syn::Fields::Unit));
307            let repr_ty = enum_repr_ty(&derive_input.attrs);
308            let use_numeric_disc = is_c_like && repr_ty.is_some();
309            let repr_ty_ts = repr_ty.unwrap_or(parse_quote!(usize));
310            let variant_matches = data_enum.variants.iter().enumerate().map(|(idx, v)| {
311                let vname = &v.ident;
312                let idx_lit = syn::Index::from(idx);
313                match &v.fields {
314                    syn::Fields::Named(named_fields) => {
315                        let field_decodes = named_fields.named.iter().map(|f| {
316                            let fname = &f.ident;
317                            let ftype = &f.ty;
318							quote! {
319								#fname: <#ftype as #krate::prelude::Decode>::decode_ext(reader, dedupe_decoder.as_deref_mut())?,
320							}
321						});
322                        quote! {
323                            #idx_lit => Ok(#name::#vname { #(#field_decodes)* }),
324                        }
325                    }
326                    syn::Fields::Unnamed(unnamed_fields) => {
327                        let field_decodes = unnamed_fields.unnamed.iter().map(|f| {
328                            let ftype = &f.ty;
329                            quote! {
330                                <#ftype as #krate::prelude::Decode>::decode_ext(reader, dedupe_decoder.as_deref_mut())?,
331                            }
332                        });
333                        quote! {
334                            #idx_lit => Ok(#name::#vname( #(#field_decodes)* )),
335                        }
336                    }
337                    syn::Fields::Unit => {
338                        if use_numeric_disc {
339                            quote! {
340                                ((#name::#vname as #repr_ty_ts) as usize) => Ok(#name::#vname),
341                            }
342                        } else {
343                            quote! {
344                                #idx_lit => Ok(#name::#vname),
345                            }
346                        }
347                    }
348                }
349            });
350            Ok(quote! {
351                impl #impl_generics #krate::prelude::Decode for #name #ty_generics #where_clause {
352                    #[inline(always)]
353                    fn decode_ext(
354                        reader: &mut impl #krate::io::Read,
355                        mut dedupe_decoder: Option<&mut #krate::dedupe::DedupeDecoder>,
356                    ) -> #krate::Result<Self> {
357                        let variant_idx = <usize as #krate::prelude::Decode>::decode_discriminant(reader)?;
358                        match variant_idx {
359                            #(#variant_matches)*
360                            _ => Err(#krate::io::Error::InvalidData),
361                        }
362                    }
363                }
364            })
365        }
366        syn::Data::Union(_data_union) => {
367            // Unions are not supported
368            Err(syn::Error::new_spanned(
369                derive_input.ident,
370                "Decode cannot be derived for unions",
371            ))
372        }
373    }
374}
375
/// Checks the exact token output of `derive_encode_impl` for a struct with two named fields.
#[test]
fn test_derive_encode_struct_basic() {
    let input = quote! {
        struct TestStruct {
            a: u32,
            b: String,
        }
    };
    let actual = derive_encode_impl(input).unwrap().to_string();

    // Compared as strings, so only the token sequence matters, not the layout below.
    let wanted = quote! {
        impl ::lencode::prelude::Encode for TestStruct {
            #[inline(always)]
            fn encode_ext(
                &self,
                writer: &mut impl ::lencode::io::Write,
                mut dedupe_encoder: Option<&mut ::lencode::dedupe::DedupeEncoder>,
            ) -> ::lencode::Result<usize> {
                let mut total_bytes = 0;
                total_bytes += <u32 as ::lencode::prelude::Encode>::encode_ext(&self.a, writer, dedupe_encoder.as_deref_mut())?;
                total_bytes += <String as ::lencode::prelude::Encode>::encode_ext(&self.b, writer, dedupe_encoder.as_deref_mut())?;
                Ok(total_bytes)
            }
        }
    }
    .to_string();

    assert_eq!(actual, wanted);
}
410
/// Checks the exact token output of `derive_decode_impl` for a struct with two named fields.
#[test]
fn test_derive_decode_struct_basic() {
    let input = quote! {
        struct TestStruct {
            a: u32,
            b: String,
        }
    };
    let actual = derive_decode_impl(input).unwrap().to_string();

    // Compared as strings, so only the token sequence matters, not the layout below.
    let wanted = quote! {
        impl ::lencode::prelude::Decode for TestStruct {
            #[inline(always)]
            fn decode_ext(
                reader: &mut impl ::lencode::io::Read,
                mut dedupe_decoder: Option<&mut ::lencode::dedupe::DedupeDecoder>,
            ) -> ::lencode::Result<Self> {
                Ok(TestStruct {
                    a: <u32 as ::lencode::prelude::Decode>::decode_ext(
                        reader,
                        dedupe_decoder.as_deref_mut()
                    )?,
                    b: <String as ::lencode::prelude::Decode>::decode_ext(
                        reader,
                        dedupe_decoder.as_deref_mut()
                    )?,
                })
            }
        }
    }
    .to_string();

    assert_eq!(actual, wanted);
}
435}