Skip to main content

lencode_macros/
lib.rs

1//! Derive macros for `lencode` encoding/decoding traits.
2//!
3//! - `#[derive(Encode)]` implements `lencode::Encode` by writing fields in declaration order
4//!   and encoding enum discriminants compactly.
5//! - `#[derive(Decode)]` implements `lencode::Decode` to read the same layout.
//! - `#[derive(Pack)]` implements `lencode::pack::Pack` by packing/unpacking fields in
//!   declaration order. For `#[repr(transparent)]` single‑field structs, it additionally
//!   generates bulk `pack_slice`/`unpack_vec` overrides that reinterpret the slice/vec as the
//!   inner type's and delegate to its bulk methods, enabling zero‑copy bulk I/O for newtypes
//!   over byte arrays.
10//!
11//! For C‑like enums with an explicit `#[repr(uN/iN)]`, the numeric value of the discriminant
12//! is preserved; otherwise, the variant index is used.
13use proc_macro::TokenStream;
14use proc_macro_crate::{FoundCrate, crate_name};
15use proc_macro2::{Span, TokenStream as TokenStream2};
16use quote::quote;
17use syn::{Attribute, DeriveInput, Ident, Result, Type, parse_quote, parse2};
18
19/// Returns `true` if `#[repr(transparent)]` is present on the item.
20fn has_repr_transparent(attrs: &[Attribute]) -> bool {
21    for attr in attrs {
22        if attr.path().is_ident("repr") {
23            let mut found = false;
24            let _ = attr.parse_nested_meta(|meta| {
25                if meta.path.is_ident("transparent") {
26                    found = true;
27                }
28                Ok(())
29            });
30            if found {
31                return true;
32            }
33        }
34    }
35    false
36}
37
38fn enum_repr_ty(attrs: &[Attribute]) -> Option<Type> {
39    let mut out: Option<Type> = None;
40    for attr in attrs {
41        if attr.path().is_ident("repr") {
42            let _ = attr.parse_nested_meta(|meta| {
43                if let Some(ident) = meta.path.get_ident() {
44                    match ident.to_string().as_str() {
45                        "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32" | "i64"
46                        | "isize" => {
47                            let ty_ident = Ident::new(&ident.to_string(), Span::call_site());
48                            out = Some(parse_quote!(#ty_ident));
49                        }
50                        _ => {}
51                    }
52                }
53                Ok(())
54            });
55        }
56    }
57    out
58}
59
60fn crate_path() -> TokenStream2 {
61    // Resolve the path to the main `lencode` crate from the macro crate, honoring any
62    // potential crate renames by the downstream user. In ambiguous contexts like doctests,
63    // prefer the absolute `::lencode` path.
64    let found = crate_name("lencode");
65    match found {
66        Ok(FoundCrate::Itself) => quote!(::lencode),
67        Ok(FoundCrate::Name(actual_name)) => {
68            let ident = Ident::new(&actual_name, Span::call_site());
69            quote!(::#ident)
70        }
71        Err(_) => quote!(::lencode),
72    }
73}
74
75/// Derives `lencode::Encode` for structs and enums.
76///
77/// - Structs: fields are encoded in declaration order.
78/// - Enums: a compact discriminant is written, then any fields as for structs. C‑like enums
79///   with `#[repr(uN/iN)]` preserve the numeric discriminant.
80#[proc_macro_derive(Encode)]
81pub fn derive_encode(input: TokenStream) -> TokenStream {
82    match derive_encode_impl(input) {
83        Ok(ts) => ts.into(),
84        Err(err) => err.to_compile_error().into(),
85    }
86}
87
88/// Derives `lencode::Decode` for structs and enums.
89///
90/// The layout matches what `#[derive(Encode)]` produces.
91#[proc_macro_derive(Decode)]
92pub fn derive_decode(input: TokenStream) -> TokenStream {
93    match derive_decode_impl(input) {
94        Ok(ts) => ts.into(),
95        Err(err) => err.to_compile_error().into(),
96    }
97}
98
99/// Derives `lencode::pack::Pack` for structs.
100///
101/// - Fields are packed/unpacked in declaration order using their own `Pack` impls.
102/// - For `#[repr(transparent)]` single‑field structs, bulk `pack_slice` and `unpack_vec`
103///   overrides are generated that transmute to/from the inner type's slice/vec, enabling
104///   zero‑copy bulk I/O for newtypes over byte arrays.
105///
106/// # Example
107///
108/// ```ignore
109/// #[repr(transparent)]
110/// #[derive(Pack)]
111/// struct MyPubkey([u8; 32]);
112/// ```
113#[proc_macro_derive(Pack)]
114pub fn derive_pack(input: TokenStream) -> TokenStream {
115    match derive_pack_impl(input) {
116        Ok(ts) => ts.into(),
117        Err(err) => err.to_compile_error().into(),
118    }
119}
120
121#[inline(always)]
122fn derive_encode_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
123    let derive_input = parse2::<DeriveInput>(input.into())?;
124    let krate = crate_path();
125    let name = derive_input.ident.clone();
126    // Prepare generics and add Encode bounds for all type parameters
127    let mut generics = derive_input.generics.clone();
128    {
129        // Collect type parameter idents first to avoid borrow conflicts
130        let type_idents: Vec<Ident> = generics.type_params().map(|tp| tp.ident.clone()).collect();
131        let where_clause = generics.make_where_clause();
132        for ident in type_idents {
133            // Add `T: Encode` bound for each type parameter `T`
134            where_clause
135                .predicates
136                .push(parse_quote!(#ident: #krate::prelude::Encode));
137        }
138    }
139    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
140    match derive_input.data {
141        syn::Data::Struct(data_struct) => {
142            let fields = data_struct.fields;
143            let encode_body = match fields {
144                syn::Fields::Named(ref named_fields) => {
145                    let field_encodes = named_fields.named.iter().map(|f| {
146                        let fname = &f.ident;
147                        let ftype = &f.ty;
148                        quote! {
149                            total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(&self.#fname, writer, ctx.as_deref_mut())?;
150                        }
151                    });
152                    quote! {
153                        #(#field_encodes)*
154                    }
155                }
156                syn::Fields::Unnamed(ref unnamed_fields) => {
157                    let field_encodes = unnamed_fields.unnamed.iter().enumerate().map(|(i, f)| {
158                        let index = syn::Index::from(i);
159                        let ftype = &f.ty;
160                        quote! {
161                            total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(&self.#index, writer, ctx.as_deref_mut())?;
162                        }
163                    });
164                    quote! {
165                        #(#field_encodes)*
166                    }
167                }
168                syn::Fields::Unit => quote! {},
169            };
170            Ok(quote! {
171                impl #impl_generics #krate::prelude::Encode for #name #ty_generics #where_clause {
172                    #[inline(always)]
173                    fn encode_ext(
174                        &self,
175                        writer: &mut impl #krate::io::Write,
176                        mut ctx: Option<&mut #krate::context::EncoderContext>,
177                    ) -> #krate::Result<usize> {
178                        let mut total_bytes = 0;
179                        #encode_body
180                        Ok(total_bytes)
181                    }
182                }
183            })
184        }
185        syn::Data::Enum(data_enum) => {
186            let is_c_like = data_enum
187                .variants
188                .iter()
189                .all(|v| matches!(v.fields, syn::Fields::Unit));
190            let repr_ty = enum_repr_ty(&derive_input.attrs);
191            let use_numeric_disc = is_c_like && repr_ty.is_some();
192            let repr_ty_ts = repr_ty.unwrap_or(parse_quote!(usize));
193            let variant_matches = data_enum.variants.iter().enumerate().map(|(idx, v)| {
194				let vname = &v.ident;
195				let idx_lit = syn::Index::from(idx);
196				match &v.fields {
197					syn::Fields::Named(named_fields) => {
198						let fields: Vec<_> = named_fields
199							.named
200							.iter()
201							.map(|f| (f.ident.as_ref().unwrap().clone(), f.ty.clone()))
202							.collect();
203
204						let field_names: Vec<_> = fields.iter().map(|(ident, _)| ident).collect();
205						let field_encodes = fields.iter().map(|(fname, ftype)| {
206							quote! {
207								total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(#fname, writer, ctx.as_deref_mut())?;
208							}
209						});
210						quote! {
211							#name::#vname { #(#field_names),* } => {
212								total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
213								#(#field_encodes)*
214							}
215						}
216					}
217					syn::Fields::Unnamed(unnamed_fields) => {
218						let fields: Vec<_> = unnamed_fields
219							.unnamed
220							.iter()
221							.enumerate()
222							.map(|(i, f)| (Ident::new(&format!("field{}", i), Span::call_site()), f.ty.clone()))
223							.collect();
224
225						let field_indices: Vec<_> = fields.iter().map(|(ident, _)| ident).collect();
226						let field_encodes = fields.iter().map(|(fname, ftype)| {
227							quote! {
228								total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(#fname, writer, ctx.as_deref_mut())?;
229							}
230						});
231						quote! {
232							#name::#vname( #(#field_indices),* ) => {
233								total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
234								#(#field_encodes)*
235							}
236						}
237					}
238					syn::Fields::Unit => {
239                        if use_numeric_disc {
240                            quote! {
241                                #name::#vname => {
242                                    let disc = (#name::#vname as #repr_ty_ts) as usize;
243                                    total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(disc, writer)?;
244                                }
245                            }
246                        } else {
247                            quote! {
248                                #name::#vname => {
249                                    total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
250                                }
251                            }
252                        }
253                    }
254				}
255			});
256            Ok(quote! {
257                impl #impl_generics #krate::prelude::Encode for #name #ty_generics #where_clause {
258                    #[inline(always)]
259                    fn encode_ext(
260                        &self,
261                        writer: &mut impl #krate::io::Write,
262                        mut ctx: Option<&mut #krate::context::EncoderContext>,
263                    ) -> #krate::Result<usize> {
264                        let mut total_bytes = 0;
265                        match self {
266                            #(#variant_matches)*
267                        }
268                        Ok(total_bytes)
269                    }
270                }
271            })
272        }
273        syn::Data::Union(_data_union) => {
274            // Unions are not supported
275            Err(syn::Error::new_spanned(
276                derive_input.ident,
277                "Encode cannot be derived for unions",
278            ))
279        }
280    }
281}
282
283#[inline(always)]
284fn derive_decode_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
285    let derive_input = parse2::<DeriveInput>(input.into())?;
286    let krate = crate_path();
287    let name = derive_input.ident.clone();
288    // Prepare generics and add Decode bounds for all type parameters
289    let mut generics = derive_input.generics.clone();
290    {
291        // Collect type parameter idents first to avoid borrow conflicts
292        let type_idents: Vec<Ident> = generics.type_params().map(|tp| tp.ident.clone()).collect();
293        let where_clause = generics.make_where_clause();
294        for ident in type_idents {
295            // Add `T: Decode` bound for each type parameter `T`
296            where_clause
297                .predicates
298                .push(parse_quote!(#ident: #krate::prelude::Decode));
299        }
300    }
301    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
302    match derive_input.data {
303        syn::Data::Struct(data_struct) => {
304            let fields = data_struct.fields;
305            let decode_body = match fields {
306                syn::Fields::Named(ref named_fields) => {
307                    let field_decodes = named_fields.named.iter().map(|f| {
308                        let fname = &f.ident;
309                        let ftype = &f.ty;
310                        quote! {
311                            #fname: <#ftype as #krate::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
312                        }
313                    });
314                    quote! {
315                        Ok(#name {
316                            #(#field_decodes)*
317                        })
318                    }
319                }
320                syn::Fields::Unnamed(ref unnamed_fields) => {
321                    let field_decodes = unnamed_fields.unnamed.iter().map(|f| {
322                        let ftype = &f.ty;
323                        quote! {
324                            <#ftype as #krate::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
325                        }
326                    });
327                    quote! {
328                        Ok(#name(
329                            #(#field_decodes)*
330                        ))
331                    }
332                }
333                syn::Fields::Unit => quote! { Ok(#name) },
334            };
335            Ok(quote! {
336                impl #impl_generics #krate::prelude::Decode for #name #ty_generics #where_clause {
337                    #[inline(always)]
338                    fn decode_ext(
339                        reader: &mut impl #krate::io::Read,
340                        mut ctx: Option<&mut #krate::context::DecoderContext>,
341                    ) -> #krate::Result<Self> {
342                        #decode_body
343                    }
344                }
345            })
346        }
347        syn::Data::Enum(data_enum) => {
348            let is_c_like = data_enum
349                .variants
350                .iter()
351                .all(|v| matches!(v.fields, syn::Fields::Unit));
352            let repr_ty = enum_repr_ty(&derive_input.attrs);
353            let use_numeric_disc = is_c_like && repr_ty.is_some();
354            let repr_ty_ts = repr_ty.unwrap_or(parse_quote!(usize));
355            let variant_matches = data_enum.variants.iter().enumerate().map(|(idx, v)| {
356                let vname = &v.ident;
357                let idx_lit = syn::Index::from(idx);
358                match &v.fields {
359                    syn::Fields::Named(named_fields) => {
360                        let field_decodes = named_fields.named.iter().map(|f| {
361                            let fname = &f.ident;
362                            let ftype = &f.ty;
363							quote! {
364								#fname: <#ftype as #krate::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
365							}
366						});
367                        quote! {
368                            #idx_lit => Ok(#name::#vname { #(#field_decodes)* }),
369                        }
370                    }
371                    syn::Fields::Unnamed(unnamed_fields) => {
372                        let field_decodes = unnamed_fields.unnamed.iter().map(|f| {
373                            let ftype = &f.ty;
374                            quote! {
375                                <#ftype as #krate::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
376                            }
377                        });
378                        quote! {
379                            #idx_lit => Ok(#name::#vname( #(#field_decodes)* )),
380                        }
381                    }
382                    syn::Fields::Unit => {
383                        if use_numeric_disc {
384                            quote! {
385                                disc if disc == ((#name::#vname as #repr_ty_ts) as usize) => Ok(#name::#vname),
386                            }
387                        } else {
388                            quote! {
389                                #idx_lit => Ok(#name::#vname),
390                            }
391                        }
392                    }
393                }
394            });
395            Ok(quote! {
396                impl #impl_generics #krate::prelude::Decode for #name #ty_generics #where_clause {
397                    #[inline(always)]
398                    fn decode_ext(
399                        reader: &mut impl #krate::io::Read,
400                        mut ctx: Option<&mut #krate::context::DecoderContext>,
401                    ) -> #krate::Result<Self> {
402                        let variant_idx = <usize as #krate::prelude::Decode>::decode_discriminant(reader)?;
403                        match variant_idx {
404                            #(#variant_matches)*
405                            _ => Err(#krate::io::Error::InvalidData),
406                        }
407                    }
408                }
409            })
410        }
411        syn::Data::Union(_data_union) => {
412            // Unions are not supported
413            Err(syn::Error::new_spanned(
414                derive_input.ident,
415                "Decode cannot be derived for unions",
416            ))
417        }
418    }
419}
420
421#[inline(always)]
422fn derive_pack_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
423    let derive_input = parse2::<DeriveInput>(input.into())?;
424    let krate = crate_path();
425    let name = derive_input.ident.clone();
426
427    let data_struct = match derive_input.data {
428        syn::Data::Struct(s) => s,
429        _ => {
430            return Err(syn::Error::new_spanned(
431                name,
432                "Pack can only be derived for structs",
433            ));
434        }
435    };
436
437    let is_transparent = has_repr_transparent(&derive_input.attrs);
438
439    // Collect fields info
440    let fields = &data_struct.fields;
441    let field_count = fields.len();
442
443    let (pack_body, unpack_body) = match fields {
444        syn::Fields::Named(named) => {
445            let pack_stmts = named.named.iter().map(|f| {
446                let fname = &f.ident;
447                let ftype = &f.ty;
448                quote! {
449                    total += <#ftype as #krate::pack::Pack>::pack(&self.#fname, writer)?;
450                }
451            });
452            let unpack_fields = named.named.iter().map(|f| {
453                let fname = &f.ident;
454                let ftype = &f.ty;
455                quote! {
456                    #fname: <#ftype as #krate::pack::Pack>::unpack(reader)?,
457                }
458            });
459            (
460                quote! {
461                    let mut total = 0usize;
462                    #(#pack_stmts)*
463                    Ok(total)
464                },
465                quote! {
466                    Ok(#name {
467                        #(#unpack_fields)*
468                    })
469                },
470            )
471        }
472        syn::Fields::Unnamed(unnamed) => {
473            let pack_stmts = unnamed.unnamed.iter().enumerate().map(|(i, f)| {
474                let index = syn::Index::from(i);
475                let ftype = &f.ty;
476                quote! {
477                    total += <#ftype as #krate::pack::Pack>::pack(&self.#index, writer)?;
478                }
479            });
480            let unpack_fields = unnamed.unnamed.iter().map(|f| {
481                let ftype = &f.ty;
482                quote! {
483                    <#ftype as #krate::pack::Pack>::unpack(reader)?,
484                }
485            });
486            (
487                quote! {
488                    let mut total = 0usize;
489                    #(#pack_stmts)*
490                    Ok(total)
491                },
492                quote! {
493                    Ok(#name(
494                        #(#unpack_fields)*
495                    ))
496                },
497            )
498        }
499        syn::Fields::Unit => (quote! { Ok(0) }, quote! { Ok(#name) }),
500    };
501
502    // For #[repr(transparent)] single-field structs, generate bulk pack_slice/unpack_vec
503    let bulk_methods = if is_transparent && field_count == 1 {
504        let inner_ty = match fields {
505            syn::Fields::Named(named) => &named.named[0].ty,
506            syn::Fields::Unnamed(unnamed) => &unnamed.unnamed[0].ty,
507            _ => unreachable!(),
508        };
509        quote! {
510            #[inline(always)]
511            fn pack_slice(items: &[Self], writer: &mut impl #krate::io::Write) -> #krate::Result<usize> {
512                // SAFETY: #[repr(transparent)] guarantees identical layout.
513                let inner: &[#inner_ty] = unsafe {
514                    core::slice::from_raw_parts(
515                        items.as_ptr() as *const #inner_ty,
516                        items.len(),
517                    )
518                };
519                <#inner_ty as #krate::pack::Pack>::pack_slice(inner, writer)
520            }
521
522            #[inline(always)]
523            fn unpack_vec(reader: &mut impl #krate::io::Read, count: usize) -> #krate::Result<Vec<Self>> {
524                let inner = <#inner_ty as #krate::pack::Pack>::unpack_vec(reader, count)?;
525                // SAFETY: #[repr(transparent)] guarantees identical layout.
526                Ok(unsafe { core::mem::transmute::<Vec<#inner_ty>, Vec<#name>>(inner) })
527            }
528        }
529    } else {
530        quote! {}
531    };
532
533    Ok(quote! {
534        impl #krate::pack::Pack for #name {
535            #[inline(always)]
536            fn pack(&self, writer: &mut impl #krate::io::Write) -> #krate::Result<usize> {
537                #pack_body
538            }
539
540            #[inline(always)]
541            fn unpack(reader: &mut impl #krate::io::Read) -> #krate::Result<Self> {
542                #unpack_body
543            }
544
545            #bulk_methods
546        }
547    })
548}
549
550#[test]
551fn test_derive_encode_struct_basic() {
552    let tokens = quote! {
553        struct TestStruct {
554            a: u32,
555            b: String,
556        }
557    };
558    let derived = derive_encode_impl(tokens).unwrap();
559    let expected = quote! {
560        impl ::lencode::prelude::Encode for TestStruct {
561            #[inline(always)]
562            fn encode_ext(
563                &self,
564                writer: &mut impl ::lencode::io::Write,
565                mut ctx: Option<&mut ::lencode::context::EncoderContext>,
566            ) -> ::lencode::Result<usize> {
567                let mut total_bytes = 0;
568                total_bytes += <u32 as ::lencode::prelude::Encode>::encode_ext(
569                    &self.a,
570                    writer,
571                    ctx.as_deref_mut()
572                )?;
573                total_bytes += <String as ::lencode::prelude::Encode>::encode_ext(
574                    &self.b,
575                    writer,
576                    ctx.as_deref_mut()
577                )?;
578                Ok(total_bytes)
579            }
580        }
581    };
582    assert_eq!(derived.to_string(), expected.to_string());
583}
584
585#[test]
586fn test_derive_decode_struct_basic() {
587    let tokens = quote! {
588        struct TestStruct {
589            a: u32,
590            b: String,
591        }
592    };
593    let derived = derive_decode_impl(tokens).unwrap();
594    let expected = quote! {
595        impl ::lencode::prelude::Decode for TestStruct {
596            #[inline(always)]
597            fn decode_ext(
598                reader: &mut impl ::lencode::io::Read,
599                mut ctx: Option<&mut ::lencode::context::DecoderContext>,
600            ) -> ::lencode::Result<Self> {
601                Ok(TestStruct {
602                    a: <u32 as ::lencode::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
603                    b: <String as ::lencode::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
604                })
605            }
606        }
607    };
608    assert_eq!(derived.to_string(), expected.to_string());
609}
610
611#[test]
612fn test_derive_pack_named_struct() {
613    let tokens = quote! {
614        struct Point {
615            x: u32,
616            y: u32,
617        }
618    };
619    let derived = derive_pack_impl(tokens).unwrap();
620    let expected = quote! {
621        impl ::lencode::pack::Pack for Point {
622            #[inline(always)]
623            fn pack(&self, writer: &mut impl ::lencode::io::Write) -> ::lencode::Result<usize> {
624                let mut total = 0usize;
625                total += <u32 as ::lencode::pack::Pack>::pack(&self.x, writer)?;
626                total += <u32 as ::lencode::pack::Pack>::pack(&self.y, writer)?;
627                Ok(total)
628            }
629
630            #[inline(always)]
631            fn unpack(reader: &mut impl ::lencode::io::Read) -> ::lencode::Result<Self> {
632                Ok(Point {
633                    x: <u32 as ::lencode::pack::Pack>::unpack(reader)?,
634                    y: <u32 as ::lencode::pack::Pack>::unpack(reader)?,
635                })
636            }
637        }
638    };
639    assert_eq!(derived.to_string(), expected.to_string());
640}
641
642#[test]
643fn test_derive_pack_transparent_tuple_struct() {
644    let tokens = quote! {
645        #[repr(transparent)]
646        struct MyKey([u8; 32]);
647    };
648    let derived = derive_pack_impl(tokens).unwrap();
649    // Just verify it parses and contains key signatures; exact whitespace around >> varies.
650    let s = derived.to_string();
651    assert!(
652        s.contains("pack_slice"),
653        "should contain pack_slice override"
654    );
655    assert!(
656        s.contains("unpack_vec"),
657        "should contain unpack_vec override"
658    );
659    assert!(
660        s.contains("transmute"),
661        "should contain transmute for bulk decode"
662    );
663    assert!(
664        s.contains("from_raw_parts"),
665        "should contain from_raw_parts for bulk encode"
666    );
667}