1use proc_macro::TokenStream;
14use proc_macro_crate::{FoundCrate, crate_name};
15use proc_macro2::{Span, TokenStream as TokenStream2};
16use quote::quote;
17use syn::{Attribute, DeriveInput, Ident, Result, Type, parse_quote, parse2};
18
19fn has_repr_transparent(attrs: &[Attribute]) -> bool {
21 for attr in attrs {
22 if attr.path().is_ident("repr") {
23 let mut found = false;
24 let _ = attr.parse_nested_meta(|meta| {
25 if meta.path.is_ident("transparent") {
26 found = true;
27 }
28 Ok(())
29 });
30 if found {
31 return true;
32 }
33 }
34 }
35 false
36}
37
38fn enum_repr_ty(attrs: &[Attribute]) -> Option<Type> {
39 let mut out: Option<Type> = None;
40 for attr in attrs {
41 if attr.path().is_ident("repr") {
42 let _ = attr.parse_nested_meta(|meta| {
43 if let Some(ident) = meta.path.get_ident() {
44 match ident.to_string().as_str() {
45 "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32" | "i64"
46 | "isize" => {
47 let ty_ident = Ident::new(&ident.to_string(), Span::call_site());
48 out = Some(parse_quote!(#ty_ident));
49 }
50 _ => {}
51 }
52 }
53 Ok(())
54 });
55 }
56 }
57 out
58}
59
60fn crate_path() -> TokenStream2 {
61 let found = crate_name("lencode");
65 match found {
66 Ok(FoundCrate::Itself) => quote!(::lencode),
67 Ok(FoundCrate::Name(actual_name)) => {
68 let ident = Ident::new(&actual_name, Span::call_site());
69 quote!(::#ident)
70 }
71 Err(_) => quote!(::lencode),
72 }
73}
74
75#[proc_macro_derive(Encode)]
81pub fn derive_encode(input: TokenStream) -> TokenStream {
82 match derive_encode_impl(input) {
83 Ok(ts) => ts.into(),
84 Err(err) => err.to_compile_error().into(),
85 }
86}
87
88#[proc_macro_derive(Decode)]
92pub fn derive_decode(input: TokenStream) -> TokenStream {
93 match derive_decode_impl(input) {
94 Ok(ts) => ts.into(),
95 Err(err) => err.to_compile_error().into(),
96 }
97}
98
99#[proc_macro_derive(Pack)]
114pub fn derive_pack(input: TokenStream) -> TokenStream {
115 match derive_pack_impl(input) {
116 Ok(ts) => ts.into(),
117 Err(err) => err.to_compile_error().into(),
118 }
119}
120
121#[inline(always)]
122fn derive_encode_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
123 let derive_input = parse2::<DeriveInput>(input.into())?;
124 let krate = crate_path();
125 let name = derive_input.ident.clone();
126 let mut generics = derive_input.generics.clone();
128 {
129 let type_idents: Vec<Ident> = generics.type_params().map(|tp| tp.ident.clone()).collect();
131 let where_clause = generics.make_where_clause();
132 for ident in type_idents {
133 where_clause
135 .predicates
136 .push(parse_quote!(#ident: #krate::prelude::Encode));
137 }
138 }
139 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
140 match derive_input.data {
141 syn::Data::Struct(data_struct) => {
142 let fields = data_struct.fields;
143 let encode_body = match fields {
144 syn::Fields::Named(ref named_fields) => {
145 let field_encodes = named_fields.named.iter().map(|f| {
146 let fname = &f.ident;
147 let ftype = &f.ty;
148 quote! {
149 total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(&self.#fname, writer, ctx.as_deref_mut())?;
150 }
151 });
152 quote! {
153 #(#field_encodes)*
154 }
155 }
156 syn::Fields::Unnamed(ref unnamed_fields) => {
157 let field_encodes = unnamed_fields.unnamed.iter().enumerate().map(|(i, f)| {
158 let index = syn::Index::from(i);
159 let ftype = &f.ty;
160 quote! {
161 total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(&self.#index, writer, ctx.as_deref_mut())?;
162 }
163 });
164 quote! {
165 #(#field_encodes)*
166 }
167 }
168 syn::Fields::Unit => quote! {},
169 };
170 Ok(quote! {
171 impl #impl_generics #krate::prelude::Encode for #name #ty_generics #where_clause {
172 #[inline(always)]
173 fn encode_ext(
174 &self,
175 writer: &mut impl #krate::io::Write,
176 mut ctx: Option<&mut #krate::context::EncoderContext>,
177 ) -> #krate::Result<usize> {
178 let mut total_bytes = 0;
179 #encode_body
180 Ok(total_bytes)
181 }
182 }
183 })
184 }
185 syn::Data::Enum(data_enum) => {
186 let is_c_like = data_enum
187 .variants
188 .iter()
189 .all(|v| matches!(v.fields, syn::Fields::Unit));
190 let repr_ty = enum_repr_ty(&derive_input.attrs);
191 let use_numeric_disc = is_c_like && repr_ty.is_some();
192 let repr_ty_ts = repr_ty.unwrap_or(parse_quote!(usize));
193 let variant_matches = data_enum.variants.iter().enumerate().map(|(idx, v)| {
194 let vname = &v.ident;
195 let idx_lit = syn::Index::from(idx);
196 match &v.fields {
197 syn::Fields::Named(named_fields) => {
198 let fields: Vec<_> = named_fields
199 .named
200 .iter()
201 .map(|f| (f.ident.as_ref().unwrap().clone(), f.ty.clone()))
202 .collect();
203
204 let field_names: Vec<_> = fields.iter().map(|(ident, _)| ident).collect();
205 let field_encodes = fields.iter().map(|(fname, ftype)| {
206 quote! {
207 total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(#fname, writer, ctx.as_deref_mut())?;
208 }
209 });
210 quote! {
211 #name::#vname { #(#field_names),* } => {
212 total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
213 #(#field_encodes)*
214 }
215 }
216 }
217 syn::Fields::Unnamed(unnamed_fields) => {
218 let fields: Vec<_> = unnamed_fields
219 .unnamed
220 .iter()
221 .enumerate()
222 .map(|(i, f)| (Ident::new(&format!("field{}", i), Span::call_site()), f.ty.clone()))
223 .collect();
224
225 let field_indices: Vec<_> = fields.iter().map(|(ident, _)| ident).collect();
226 let field_encodes = fields.iter().map(|(fname, ftype)| {
227 quote! {
228 total_bytes += <#ftype as #krate::prelude::Encode>::encode_ext(#fname, writer, ctx.as_deref_mut())?;
229 }
230 });
231 quote! {
232 #name::#vname( #(#field_indices),* ) => {
233 total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
234 #(#field_encodes)*
235 }
236 }
237 }
238 syn::Fields::Unit => {
239 if use_numeric_disc {
240 quote! {
241 #name::#vname => {
242 let disc = (#name::#vname as #repr_ty_ts) as usize;
243 total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(disc, writer)?;
244 }
245 }
246 } else {
247 quote! {
248 #name::#vname => {
249 total_bytes += <usize as #krate::prelude::Encode>::encode_discriminant(#idx_lit as usize, writer)?;
250 }
251 }
252 }
253 }
254 }
255 });
256 Ok(quote! {
257 impl #impl_generics #krate::prelude::Encode for #name #ty_generics #where_clause {
258 #[inline(always)]
259 fn encode_ext(
260 &self,
261 writer: &mut impl #krate::io::Write,
262 mut ctx: Option<&mut #krate::context::EncoderContext>,
263 ) -> #krate::Result<usize> {
264 let mut total_bytes = 0;
265 match self {
266 #(#variant_matches)*
267 }
268 Ok(total_bytes)
269 }
270 }
271 })
272 }
273 syn::Data::Union(_data_union) => {
274 Err(syn::Error::new_spanned(
276 derive_input.ident,
277 "Encode cannot be derived for unions",
278 ))
279 }
280 }
281}
282
283#[inline(always)]
284fn derive_decode_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
285 let derive_input = parse2::<DeriveInput>(input.into())?;
286 let krate = crate_path();
287 let name = derive_input.ident.clone();
288 let mut generics = derive_input.generics.clone();
290 {
291 let type_idents: Vec<Ident> = generics.type_params().map(|tp| tp.ident.clone()).collect();
293 let where_clause = generics.make_where_clause();
294 for ident in type_idents {
295 where_clause
297 .predicates
298 .push(parse_quote!(#ident: #krate::prelude::Decode));
299 }
300 }
301 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
302 match derive_input.data {
303 syn::Data::Struct(data_struct) => {
304 let fields = data_struct.fields;
305 let decode_body = match fields {
306 syn::Fields::Named(ref named_fields) => {
307 let field_decodes = named_fields.named.iter().map(|f| {
308 let fname = &f.ident;
309 let ftype = &f.ty;
310 quote! {
311 #fname: <#ftype as #krate::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
312 }
313 });
314 quote! {
315 Ok(#name {
316 #(#field_decodes)*
317 })
318 }
319 }
320 syn::Fields::Unnamed(ref unnamed_fields) => {
321 let field_decodes = unnamed_fields.unnamed.iter().map(|f| {
322 let ftype = &f.ty;
323 quote! {
324 <#ftype as #krate::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
325 }
326 });
327 quote! {
328 Ok(#name(
329 #(#field_decodes)*
330 ))
331 }
332 }
333 syn::Fields::Unit => quote! { Ok(#name) },
334 };
335 Ok(quote! {
336 impl #impl_generics #krate::prelude::Decode for #name #ty_generics #where_clause {
337 #[inline(always)]
338 fn decode_ext(
339 reader: &mut impl #krate::io::Read,
340 mut ctx: Option<&mut #krate::context::DecoderContext>,
341 ) -> #krate::Result<Self> {
342 #decode_body
343 }
344 }
345 })
346 }
347 syn::Data::Enum(data_enum) => {
348 let is_c_like = data_enum
349 .variants
350 .iter()
351 .all(|v| matches!(v.fields, syn::Fields::Unit));
352 let repr_ty = enum_repr_ty(&derive_input.attrs);
353 let use_numeric_disc = is_c_like && repr_ty.is_some();
354 let repr_ty_ts = repr_ty.unwrap_or(parse_quote!(usize));
355 let variant_matches = data_enum.variants.iter().enumerate().map(|(idx, v)| {
356 let vname = &v.ident;
357 let idx_lit = syn::Index::from(idx);
358 match &v.fields {
359 syn::Fields::Named(named_fields) => {
360 let field_decodes = named_fields.named.iter().map(|f| {
361 let fname = &f.ident;
362 let ftype = &f.ty;
363 quote! {
364 #fname: <#ftype as #krate::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
365 }
366 });
367 quote! {
368 #idx_lit => Ok(#name::#vname { #(#field_decodes)* }),
369 }
370 }
371 syn::Fields::Unnamed(unnamed_fields) => {
372 let field_decodes = unnamed_fields.unnamed.iter().map(|f| {
373 let ftype = &f.ty;
374 quote! {
375 <#ftype as #krate::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
376 }
377 });
378 quote! {
379 #idx_lit => Ok(#name::#vname( #(#field_decodes)* )),
380 }
381 }
382 syn::Fields::Unit => {
383 if use_numeric_disc {
384 quote! {
385 disc if disc == ((#name::#vname as #repr_ty_ts) as usize) => Ok(#name::#vname),
386 }
387 } else {
388 quote! {
389 #idx_lit => Ok(#name::#vname),
390 }
391 }
392 }
393 }
394 });
395 Ok(quote! {
396 impl #impl_generics #krate::prelude::Decode for #name #ty_generics #where_clause {
397 #[inline(always)]
398 fn decode_ext(
399 reader: &mut impl #krate::io::Read,
400 mut ctx: Option<&mut #krate::context::DecoderContext>,
401 ) -> #krate::Result<Self> {
402 let variant_idx = <usize as #krate::prelude::Decode>::decode_discriminant(reader)?;
403 match variant_idx {
404 #(#variant_matches)*
405 _ => Err(#krate::io::Error::InvalidData),
406 }
407 }
408 }
409 })
410 }
411 syn::Data::Union(_data_union) => {
412 Err(syn::Error::new_spanned(
414 derive_input.ident,
415 "Decode cannot be derived for unions",
416 ))
417 }
418 }
419}
420
421#[inline(always)]
422fn derive_pack_impl(input: impl Into<TokenStream2>) -> Result<TokenStream2> {
423 let derive_input = parse2::<DeriveInput>(input.into())?;
424 let krate = crate_path();
425 let name = derive_input.ident.clone();
426
427 let data_struct = match derive_input.data {
428 syn::Data::Struct(s) => s,
429 _ => {
430 return Err(syn::Error::new_spanned(
431 name,
432 "Pack can only be derived for structs",
433 ));
434 }
435 };
436
437 let is_transparent = has_repr_transparent(&derive_input.attrs);
438
439 let fields = &data_struct.fields;
441 let field_count = fields.len();
442
443 let (pack_body, unpack_body) = match fields {
444 syn::Fields::Named(named) => {
445 let pack_stmts = named.named.iter().map(|f| {
446 let fname = &f.ident;
447 let ftype = &f.ty;
448 quote! {
449 total += <#ftype as #krate::pack::Pack>::pack(&self.#fname, writer)?;
450 }
451 });
452 let unpack_fields = named.named.iter().map(|f| {
453 let fname = &f.ident;
454 let ftype = &f.ty;
455 quote! {
456 #fname: <#ftype as #krate::pack::Pack>::unpack(reader)?,
457 }
458 });
459 (
460 quote! {
461 let mut total = 0usize;
462 #(#pack_stmts)*
463 Ok(total)
464 },
465 quote! {
466 Ok(#name {
467 #(#unpack_fields)*
468 })
469 },
470 )
471 }
472 syn::Fields::Unnamed(unnamed) => {
473 let pack_stmts = unnamed.unnamed.iter().enumerate().map(|(i, f)| {
474 let index = syn::Index::from(i);
475 let ftype = &f.ty;
476 quote! {
477 total += <#ftype as #krate::pack::Pack>::pack(&self.#index, writer)?;
478 }
479 });
480 let unpack_fields = unnamed.unnamed.iter().map(|f| {
481 let ftype = &f.ty;
482 quote! {
483 <#ftype as #krate::pack::Pack>::unpack(reader)?,
484 }
485 });
486 (
487 quote! {
488 let mut total = 0usize;
489 #(#pack_stmts)*
490 Ok(total)
491 },
492 quote! {
493 Ok(#name(
494 #(#unpack_fields)*
495 ))
496 },
497 )
498 }
499 syn::Fields::Unit => (quote! { Ok(0) }, quote! { Ok(#name) }),
500 };
501
502 let bulk_methods = if is_transparent && field_count == 1 {
504 let inner_ty = match fields {
505 syn::Fields::Named(named) => &named.named[0].ty,
506 syn::Fields::Unnamed(unnamed) => &unnamed.unnamed[0].ty,
507 _ => unreachable!(),
508 };
509 quote! {
510 #[inline(always)]
511 fn pack_slice(items: &[Self], writer: &mut impl #krate::io::Write) -> #krate::Result<usize> {
512 let inner: &[#inner_ty] = unsafe {
514 core::slice::from_raw_parts(
515 items.as_ptr() as *const #inner_ty,
516 items.len(),
517 )
518 };
519 <#inner_ty as #krate::pack::Pack>::pack_slice(inner, writer)
520 }
521
522 #[inline(always)]
523 fn unpack_vec(reader: &mut impl #krate::io::Read, count: usize) -> #krate::Result<Vec<Self>> {
524 let inner = <#inner_ty as #krate::pack::Pack>::unpack_vec(reader, count)?;
525 Ok(unsafe { core::mem::transmute::<Vec<#inner_ty>, Vec<#name>>(inner) })
527 }
528 }
529 } else {
530 quote! {}
531 };
532
533 Ok(quote! {
534 impl #krate::pack::Pack for #name {
535 #[inline(always)]
536 fn pack(&self, writer: &mut impl #krate::io::Write) -> #krate::Result<usize> {
537 #pack_body
538 }
539
540 #[inline(always)]
541 fn unpack(reader: &mut impl #krate::io::Read) -> #krate::Result<Self> {
542 #unpack_body
543 }
544
545 #bulk_methods
546 }
547 })
548}
549
/// The `Encode` expansion for a plain named-field struct encodes each field in
/// declaration order and sums the reported byte counts.
#[test]
fn test_derive_encode_struct_basic() {
    let item = quote! {
        struct TestStruct {
            a: u32,
            b: String,
        }
    };

    let want = quote! {
        impl ::lencode::prelude::Encode for TestStruct {
            #[inline(always)]
            fn encode_ext(
                &self,
                writer: &mut impl ::lencode::io::Write,
                mut ctx: Option<&mut ::lencode::context::EncoderContext>,
            ) -> ::lencode::Result<usize> {
                let mut total_bytes = 0;
                total_bytes += <u32 as ::lencode::prelude::Encode>::encode_ext(&self.a, writer, ctx.as_deref_mut())?;
                total_bytes += <String as ::lencode::prelude::Encode>::encode_ext(&self.b, writer, ctx.as_deref_mut())?;
                Ok(total_bytes)
            }
        }
    };

    let got = derive_encode_impl(item).unwrap();
    // Compare the rendered token streams; whitespace in `quote!` is irrelevant.
    assert_eq!(got.to_string(), want.to_string());
}
584
/// The `Decode` expansion for a plain named-field struct decodes each field in
/// declaration order and builds the struct from the results.
#[test]
fn test_derive_decode_struct_basic() {
    let item = quote! {
        struct TestStruct {
            a: u32,
            b: String,
        }
    };

    let want = quote! {
        impl ::lencode::prelude::Decode for TestStruct {
            #[inline(always)]
            fn decode_ext(
                reader: &mut impl ::lencode::io::Read,
                mut ctx: Option<&mut ::lencode::context::DecoderContext>,
            ) -> ::lencode::Result<Self> {
                Ok(TestStruct {
                    a: <u32 as ::lencode::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
                    b: <String as ::lencode::prelude::Decode>::decode_ext(reader, ctx.as_deref_mut())?,
                })
            }
        }
    };

    let got = derive_decode_impl(item).unwrap();
    // Compare the rendered token streams; whitespace in `quote!` is irrelevant.
    assert_eq!(got.to_string(), want.to_string());
}
610
/// The `Pack` expansion for a named-field struct without `#[repr(transparent)]`
/// emits only `pack` and `unpack`, with no bulk overrides.
#[test]
fn test_derive_pack_named_struct() {
    let item = quote! {
        struct Point {
            x: u32,
            y: u32,
        }
    };

    let want = quote! {
        impl ::lencode::pack::Pack for Point {
            #[inline(always)]
            fn pack(&self, writer: &mut impl ::lencode::io::Write) -> ::lencode::Result<usize> {
                let mut total = 0usize;
                total += <u32 as ::lencode::pack::Pack>::pack(&self.x, writer)?;
                total += <u32 as ::lencode::pack::Pack>::pack(&self.y, writer)?;
                Ok(total)
            }

            #[inline(always)]
            fn unpack(reader: &mut impl ::lencode::io::Read) -> ::lencode::Result<Self> {
                Ok(Point {
                    x: <u32 as ::lencode::pack::Pack>::unpack(reader)?,
                    y: <u32 as ::lencode::pack::Pack>::unpack(reader)?,
                })
            }
        }
    };

    let got = derive_pack_impl(item).unwrap();
    // Compare the rendered token streams; whitespace in `quote!` is irrelevant.
    assert_eq!(got.to_string(), want.to_string());
}
641
/// A `#[repr(transparent)]` single-field tuple struct gets the bulk
/// `pack_slice` / `unpack_vec` overrides in its `Pack` expansion.
#[test]
fn test_derive_pack_transparent_tuple_struct() {
    let item = quote! {
        #[repr(transparent)]
        struct MyKey([u8; 32]);
    };
    let rendered = derive_pack_impl(item).unwrap().to_string();

    // Each entry pairs a fragment that must occur in the expansion with the
    // failure message reported when it is missing.
    let expectations = [
        ("pack_slice", "should contain pack_slice override"),
        ("unpack_vec", "should contain unpack_vec override"),
        ("transmute", "should contain transmute for bulk decode"),
        ("from_raw_parts", "should contain from_raw_parts for bulk encode"),
    ];
    for (fragment, message) in expectations {
        assert!(rendered.contains(fragment), "{}", message);
    }
}