npsd_schema/
lib.rs

1//! # `npsd` Derive Macros
2//! 
3//! This module provides a set of custom derive macros to simplify the implementation of traits required by the `npsd` framework.
4//! These macros automate the generation of boilerplate code for various payload processing tasks, including serialization, deserialization, and payload conversion.
5//!
6//! ## Available Macros
7//!
8//! ### `#[derive(Info)]`
9//! Generates an implementation of the `PayloadInfo` trait, which provides metadata about the payload type.
10//!
11//! ### `#[derive(Schema)]`
12//! Generates implementations for payload processing traits such as `IntoPayload`, `FromPayload`, and `Payload` for public use.
13//!
14//! ### `#[derive(Bitmap)]`
//! Generates `IntoPayload`, `FromPayload`, and `Payload` implementations for structs with up to 8 `bool` fields, packed into a single byte (field *i* maps to bit *i*).
16//!
17//! ### `#[derive(AsyncSchema)]`
18//! Generates asynchronous implementations for payload processing traits such as `AsyncIntoPayload`, `AsyncFromPayload`, and `AsyncPayload` for public use.
19//!
20//! ### `#[derive(AsyncBitmap)]`
21//! Generates asynchronous implementations for payload processing traits for bitmap structures with up to 8 fields.
22
23#[doc(hidden)]
24use syn::{parse_macro_input, TypeParam, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Plus, Data, DataEnum, DeriveInput, Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Ident, Index, Lifetime, LifetimeParam, TypeParamBound};
25#[doc(hidden)]
26use quote::{quote, quote_spanned};
27#[doc(hidden)]
28use proc_macro::TokenStream;
29#[doc(hidden)]
30use proc_macro2::Span;
31
/// Lifetime used by the generated `FromPayload`/`Payload` impls. `Schema`
/// derives fall back to it only when the deriving type declares no lifetime.
const DEFAULT_LIFETIME: &str = "'__payload";
/// Method-level lifetime that bounds the middleware in generated `into_payload`.
const DEFAULT_SCOPE_LIFETIME: &str = "'__payload_scope";
/// Name of the context type parameter added to every generated impl.
const DEFAULT_CONTEXT: &str = "__PayloadCtx";
/// Name of the middleware type parameter on the generated trait methods.
const DEFAULT_MIDDLEWARE: &str = "__PayloadMw";
36
37#[doc(hidden)]
38fn resolve_lifetime(generics: &Generics, lifetime_name: &str) -> (bool, Lifetime) {
39    if let Some(existing_lifetime) = generics.params.iter().find_map(|param| {
40        if let GenericParam::Lifetime(lifetime_param) = param {
41            Some(lifetime_param.lifetime.clone())
42        } else {
43            None
44        }
45    }) {
46        return (true, existing_lifetime);
47    }
48
49    (false, Lifetime::new(lifetime_name, Span::call_site()))
50}
51
52#[doc(hidden)]
53fn has_bound(bounds: &Punctuated<TypeParamBound, Plus>, bound_to_check: &str) -> bool {
54    bounds.iter().any(|bound| {
55        if let TypeParamBound::Trait(trait_bound) = bound {
56            trait_bound.path.segments.iter().any(|segment| segment.ident == bound_to_check)
57        } else {
58            false
59        }
60    })
61}
62
63#[doc(hidden)]
64fn schema_into_impl(generics: &mut Generics, internal: bool, context: &Ident) {
65    for param in generics.params.iter_mut() {
66        if let GenericParam::Type(type_param) = param {
67            if type_param.ident == DEFAULT_CONTEXT {
68                continue;
69            }
70
71            if !has_bound(&type_param.bounds, "IntoPayload") {
72                type_param.bounds.push(if internal {
73                    parse_quote!(IntoPayload<#context>)
74                } else {
75                    parse_quote!(npsd::IntoPayload<#context>)
76                });
77            }
78        }
79    }
80}
81
82#[doc(hidden)]
83fn schema_from_impl(generics: &mut Generics, internal: bool, lifetime: &Lifetime, context: &Ident) {
84    for param in generics.params.iter_mut() {
85        if let GenericParam::Type(type_param) = param {
86            if type_param.ident == DEFAULT_CONTEXT {
87                continue;
88            }
89
90            if !has_bound(&type_param.bounds, "FromPayload") {
91                type_param.bounds.push(if internal {
92                    parse_quote!(FromPayload<#lifetime, #context>)
93                } else {
94                    parse_quote!(npsd::FromPayload<#lifetime, #context>)
95                });
96            }
97        }
98    }
99}
100
101#[doc(hidden)]
102fn schema_payload_impl(generics: &mut Generics, internal: bool, lifetime: &Lifetime, context: &Ident) {
103    for param in generics.params.iter_mut() {
104        if let GenericParam::Type(type_param) = param {
105            if type_param.ident == DEFAULT_CONTEXT {
106                continue;
107            }
108
109            if !has_bound(&type_param.bounds, "Payload") {
110                type_param.bounds.push(if internal {
111                    parse_quote!(Payload<#lifetime, #context>)
112                } else {
113                    parse_quote!(npsd::Payload<#lifetime, #context>)
114                });
115            }
116        }
117    }
118}
119
120#[doc(hidden)]
121fn async_schema_into_impl(generics: &mut Generics, internal: bool, context: &Ident) {
122    for param in generics.params.iter_mut() {
123        if let GenericParam::Type(type_param) = param {
124            if type_param.ident == DEFAULT_CONTEXT {
125                continue;
126            }
127
128            if !has_bound(&type_param.bounds, "AsyncIntoPayload") {
129                type_param.bounds.push(if internal {
130                    parse_quote!(AsyncIntoPayload<#context>)
131                } else {
132                    parse_quote!(npsd::AsyncIntoPayload<#context>)
133                });
134            }
135        }
136    }
137}
138
139#[doc(hidden)]
140fn async_schema_from_impl(generics: &mut Generics, internal: bool, lifetime: &Lifetime, context: &Ident) {
141    for param in generics.params.iter_mut() {
142        if let GenericParam::Type(type_param) = param {
143            if type_param.ident == DEFAULT_CONTEXT {
144                continue;
145            }
146
147            if !has_bound(&type_param.bounds, "AsyncFromPayload") {
148                type_param.bounds.push(if internal {
149                    parse_quote!(AsyncFromPayload<#lifetime, #context>)
150                } else {
151                    parse_quote!(npsd::AsyncFromPayload<#lifetime, #context>)
152                });
153            }
154        }
155    }
156}
157
158#[doc(hidden)]
159fn async_schema_payload_impl(generics: &mut Generics, internal: bool, lifetime: &Lifetime, context: &Ident) {
160    for param in generics.params.iter_mut() {
161        if let GenericParam::Type(type_param) = param {
162            if type_param.ident == DEFAULT_CONTEXT {
163                continue;
164            }
165
166            if !has_bound(&type_param.bounds, "AsyncPayload") {
167                type_param.bounds.push(if internal {
168                    parse_quote!(AsyncPayload<#lifetime, #context>)
169                } else {
170                    parse_quote!(npsd::AsyncPayload<#lifetime, #context>)
171                });
172            }
173        }
174    }
175}
176
177#[doc(hidden)]
178#[proc_macro_derive(Info)]
179pub fn payload_info_public_impl(input: TokenStream) -> TokenStream {
180    let DeriveInput { ident, generics, .. } = parse_macro_input!(input);
181    let (generics_impl, ty_generics, where_clause) = generics.split_for_impl();
182
183    let gen = quote! {
184        impl #generics_impl npsd::PayloadInfo for #ident #ty_generics #where_clause {
185            const TYPE: &'static str = stringify!(#ident);
186        }
187    };
188
189    gen.into()
190}
191
192#[doc(hidden)]
193#[proc_macro_derive(InfoInternal)]
194pub fn payload_info_intenal_impl(input: TokenStream) -> TokenStream {
195    let DeriveInput { ident, generics, .. } = parse_macro_input!(input);
196    let (generics_impl, ty_generics, where_clause) = generics.split_for_impl();
197
198    let gen = quote! {
199        impl #generics_impl PayloadInfo for #ident #ty_generics #where_clause {
200            const TYPE: &'static str = stringify!(#ident);
201        }
202    };
203
204    gen.into()
205}
206
207#[proc_macro_derive(Schema)]
208pub fn schema_public_impl(input: TokenStream) -> TokenStream {
209    schema_impl(input, false)
210}
211
212#[doc(hidden)]
213#[proc_macro_derive(SchemaInternal)]
214pub fn schema_internal_impl(input: TokenStream) -> TokenStream {
215    schema_impl(input, true)
216}
217
218#[doc(hidden)]
219fn schema_impl(input: TokenStream, internal: bool) -> TokenStream {
220    let DeriveInput { ident, data, generics, .. } = parse_macro_input!(input);
221    let (_, ty_generics, where_clause) = generics.split_for_impl();
222
223    let (lifetime_exist, lifetime) = resolve_lifetime(&generics, DEFAULT_LIFETIME);
224    let context = Ident::new(DEFAULT_CONTEXT, Span::call_site());
225    let scope = Lifetime::new(DEFAULT_SCOPE_LIFETIME, Span::call_site());
226    let mw = Ident::new(DEFAULT_MIDDLEWARE, Span::call_site());
227    let mut context_generics = generics.clone();
228
229    let context_param: GenericParam = syn::parse_quote!(#context);
230    context_generics.params.push(context_param);
231
232    let mut into_generics = context_generics.clone();
233    let mut from_generics = context_generics.clone();
234    let mut payload_generics = context_generics.clone();
235
236    if !lifetime_exist {
237        let lifetime_param = LifetimeParam::new(lifetime.clone());
238        from_generics.params.insert(0, GenericParam::Lifetime(lifetime_param.clone()));
239        payload_generics.params.insert(0, GenericParam::Lifetime(lifetime_param.clone()));
240    }
241
242    schema_into_impl(&mut into_generics, internal, &context);
243    let (into_impl, _, _) = into_generics.split_for_impl();
244
245    schema_from_impl(&mut from_generics, internal, &lifetime, &context);
246    let (from_impl, _, _) = from_generics.split_for_impl();
247
248    schema_payload_impl(&mut payload_generics, internal, &lifetime, &context);
249    let (payload_impl, _, _) = payload_generics.split_for_impl();
250
251    let sender_block = match data.clone() {
252        Data::Struct(data_struct) => {
253            let fields = match data_struct.fields {
254                Fields::Named(FieldsNamed { named, .. }) => {
255                    named.iter().map(|f| {
256                        let name = &f.ident;
257                        let span = f.span();
258
259                        quote_spanned! { span =>
260                            next.into_payload(&self.#name, ctx)?;
261                        }
262                    }).collect::<Vec<_>>()
263                },
264
265                Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
266                    unnamed.iter().enumerate().map(|(i, _)| {
267                        let index = Index::from(i);
268                        let span = index.span();
269
270                        quote_spanned! { span =>
271                            next.into_payload(&self.#index, ctx)?;
272                        }
273                    }).collect::<Vec<_>>()
274                },
275
276                Fields::Unit => Vec::new(),
277            };
278
279            quote! { #( #fields )* }
280        },
281        Data::Enum(DataEnum { variants, .. }) => {
282            let variant_cases = variants.iter().enumerate().map(|(index, variant)| {
283                let variant_ident = &variant.ident;
284                let variant_span = variant.span(); 
285
286                match &variant.fields {
287                    Fields::Named(FieldsNamed { named, .. }) => {
288                        let (field_patterns, field_serializations): (Vec<_>, Vec<_>) = named.iter()
289                            .map(|f| {
290                                let name = f.ident.as_ref().unwrap();
291                                let span = name.span();
292                                let pattern = quote_spanned! { span => #name };
293                                let serialization = quote_spanned! { span => next.into_payload(&#name, ctx)?; };
294                                (pattern, serialization)
295                            }).unzip();
296
297                        quote_spanned! { variant_span => 
298                            #ident::#variant_ident { #(#field_patterns,)* } => {
299                                next.into_payload(&#index, ctx)?;
300                                #( #field_serializations )*
301                            }
302                        }
303                    },
304                    Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
305                        let (field_patterns, field_serializations): (Vec<_>, Vec<_>) = unnamed.iter().enumerate()
306                            .map(|(i, _)| {
307                                let field_name = Ident::new(&format!("__self_{}", i), Span::call_site());
308                                let pattern = quote! { #field_name };
309                                let serialization = quote! { next.into_payload(&#field_name, ctx)?; };
310                                (pattern, serialization)
311                            }).unzip();
312                    
313                        quote_spanned! { variant_span => 
314                            #ident::#variant_ident( #( #field_patterns, )* ) => {
315                                next.into_payload(&#index, ctx)?;
316                                #( #field_serializations )*
317                            }
318                        }
319                    },
320                    Fields::Unit => {
321                        quote_spanned! { variant_span => 
322                            #ident::#variant_ident => {
323                                next.into_payload(&#index, ctx)?;
324                            }
325                        }
326                    },
327                }
328            });
329
330            quote! {
331                match self {
332                    #( #variant_cases, )*
333                }
334            }
335        },
336        Data::Union(_) => {
337            return quote! {
338                compile_error!("Union types are not supported by this macro.");
339            }.into();
340        },
341    };
342
343    let receiver_block = match data.clone() {
344        Data::Struct(data_struct) => {
345            match data_struct.fields {
346                Fields::Named(FieldsNamed { named, .. }) => {
347                    let fields = named.iter().map(|f| {
348                        let field = &f.ident;
349                        let ty = &f.ty;
350                        let span = f.span();
351
352                        quote_spanned! { span =>
353                            #field: next.from_payload::<#context, #ty>(ctx)? // as #ty
354                        }
355                    }).collect::<Vec<_>>();
356
357                    quote! {
358                        Ok(#ident {
359                            #( #fields ),*
360                        })
361                    }
362                },
363                Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
364                    let fields = unnamed.iter().enumerate().map(|(_, f)| {
365                        let ty = &f.ty;
366
367                        quote! {
368                            next.from_payload::<#context, #ty>(ctx)? // as #ty
369                        }
370                    }).collect::<Vec<_>>();
371
372                    quote! {
373                        Ok(#ident (
374                            #( #fields ),*
375                        ))
376                    }
377                },
378                Fields::Unit => {
379                    quote! {
380                        Ok(#ident)
381                    }
382                },
383            }
384        },
385        Data::Enum(DataEnum { variants, .. }) => {
386            let match_variants = variants.iter().enumerate().map(|(index, variant)| {
387                let variant_ident = &variant.ident;
388                
389                match &variant.fields {
390                    Fields::Named(FieldsNamed { named, .. }) => {
391                        let deserializations = named.iter().map(|f| {
392                            let name = &f.ident;
393                            let ty = &f.ty;
394        
395                            quote! {
396                                #name: next.from_payload::<#context, #ty>(ctx)? // as #ty
397                            }
398                        });
399                        
400                        quote! {
401                            #index => Ok(#ident::#variant_ident { #(#deserializations),* })
402                        }
403                    },
404                    Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
405                        let deserializations = unnamed.iter().map(|f| {
406                            let ty = &f.ty;
407        
408                            quote! {
409                                next.from_payload::<#context, #ty>(ctx)? // as #ty
410                            }
411                        });
412                        
413                        quote! {
414                            #index => Ok(#ident::#variant_ident( #(#deserializations),* ))
415                        }
416                    },
417                    Fields::Unit => {
418                        quote! {
419                            #index => Ok(#ident::#variant_ident)
420                        }
421                    },
422                }
423            }).collect::<Vec<_>>();
424        
425            if internal {
426                quote! {
427                    let variant_index: usize = next.from_payload(ctx)?;
428            
429                    match variant_index {
430                        #(#match_variants,)*
431                        _ => Err(Error::UnknownVariant("Index out of bounds for enum".to_string())),
432                    }
433                }
434            } else {
435                quote! {
436                    let variant_index: usize = next.from_payload(ctx)?;
437            
438                    match variant_index {
439                        #(#match_variants,)*
440                        _ => Err(npsd::Error::UnknownVariant("Index out of bounds for enum".to_string())),
441                    }
442                }
443            }
444        },
445        Data::Union(_) => {
446            return quote! {
447                compile_error!("Union types are not supported by this macro.");
448            }.into();
449        },
450    };
451
452    let gen = if internal {
453        quote! {
454            impl #into_impl IntoPayload<#context> for #ident #ty_generics #where_clause {
455                fn into_payload<#scope, #mw: Middleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), Error> {
456                    #sender_block
457                    Ok(())
458                }
459            }
460
461            impl #from_impl FromPayload<#lifetime, #context> for #ident #ty_generics #where_clause {
462                fn from_payload<#mw: Middleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, Error> {
463                    #receiver_block
464                }
465            }
466
467            impl #payload_impl Payload<#lifetime, #context> for #ident #ty_generics #where_clause {}
468        }
469    } else {
470        quote! {
471            impl #into_impl npsd::IntoPayload<#context> for #ident #ty_generics #where_clause {
472                fn into_payload<#scope, #mw: npsd::Middleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), npsd::Error> {
473                    #sender_block
474                    Ok(())
475                }
476            }
477
478            impl #from_impl npsd::FromPayload<#lifetime, #context> for #ident #ty_generics #where_clause {
479                fn from_payload<#mw: npsd::Middleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, npsd::Error> {
480                    #receiver_block
481                }
482            }
483
484            impl #payload_impl npsd::Payload<#lifetime, #context> for #ident #ty_generics #where_clause {}
485        }
486    };
487
488    gen.into()
489}
490
491#[proc_macro_derive(Bitmap)]
492pub fn bitmap_derive(input: TokenStream) -> TokenStream {
493    bitmap_impl(input, false)
494}
495
496#[doc(hidden)]
497#[proc_macro_derive(BitmapInternal)]
498pub fn bitmap_internal_derive(input: TokenStream) -> TokenStream {
499    bitmap_impl(input, true)
500}
501
502#[doc(hidden)]
503fn bitmap_impl(input: TokenStream, internal: bool) -> TokenStream {
504    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
505
506    let fields = match data {
507        Data::Struct(ref data_struct) => &data_struct.fields,
508        _ => {
509            return quote! {
510                compile_error!("Bitmap can only be derived for structs with named or unnamed fields");
511            }.into();
512        } 
513    };
514
515    let field_count = match fields {
516        Fields::Named(ref named_fields) => named_fields.named.len(),
517        Fields::Unnamed(ref unnamed_fields) => unnamed_fields.unnamed.len(),
518        Fields::Unit => 0,
519    };
520
521    if field_count > 8 {
522        return quote! {
523            compile_error!("Bitmap can only be derived for structs with no more than 8 fields");
524        }.into();
525    }
526
527    let lifetime = Lifetime::new(DEFAULT_LIFETIME, Span::call_site());
528    let scope = Lifetime::new(DEFAULT_SCOPE_LIFETIME, Span::call_site());
529
530    let context = Ident::new(DEFAULT_CONTEXT, Span::call_site());
531    let mw = Ident::new(DEFAULT_MIDDLEWARE, Span::call_site());
532
533    let into_payload_impl = generate_into_payload_impl(&ident, &fields, &scope, &context, &mw, internal);
534    let from_payload_impl = generate_from_payload_impl(&ident, &fields, &lifetime, &context, &mw, internal);
535    let payload_impl = generate_payload_impl(&ident, &lifetime,&context, internal);
536
537    let expanded = quote! {
538        #into_payload_impl
539        #from_payload_impl
540        #payload_impl
541    };
542
543    TokenStream::from(expanded)
544}
545
546#[doc(hidden)]
547fn generate_into_payload_impl(name: &Ident, fields: &Fields, scope: &Lifetime, context: &Ident, mw: &Ident, internal: bool) -> proc_macro2::TokenStream {
548    let field_conversions = match fields {
549        Fields::Named(FieldsNamed { named, .. }) => {
550            named.iter().enumerate().map(|(i, f)| {
551                let field_name = &f.ident;
552                let bit_position = i as u8;
553
554                quote! {
555                    if self.#field_name {
556                        byte |= 1 << #bit_position;
557                    }
558                }
559            }).collect::<Vec<_>>()
560        },
561        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
562            unnamed.iter().enumerate().map(|(i, _)| {
563                let field_name = Index::from(i);
564                let bit_position = i as u8;
565
566                quote! {
567                    if self.#field_name {
568                        byte |= 1 << #bit_position;
569                    }
570                }
571            }).collect::<Vec<_>>()
572        },
573        Fields::Unit => vec![],
574    };
575
576    if internal {
577        quote! {
578            impl<#context> IntoPayload<#context> for #name {
579                fn into_payload<#scope, #mw: npsd::Middleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), Error> {
580                    let mut byte: u8 = 0;
581                    #(#field_conversions)*
582                    next.into_payload(&byte, ctx)
583                }
584            }
585        }
586    } else {
587        quote! {
588            impl<#context> npsd::IntoPayload<#context> for #name {
589                fn into_payload<#scope, #mw: npsd::Middleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), npsd::Error> {
590                    let mut byte: u8 = 0;
591                    #(#field_conversions)*
592                    next.into_payload(&byte, ctx)
593                }
594            }
595        }
596    }
597}
598
599#[doc(hidden)]
600fn generate_from_payload_impl(name: &Ident, fields: &Fields, lifetime: &Lifetime, context: &Ident, mw: &Ident, internal: bool) -> proc_macro2::TokenStream {
601    let field_assignments = match fields {
602        Fields::Named(FieldsNamed { named, .. }) => {
603            named.iter().enumerate().map(|(i, f)| {
604                let field_name = &f.ident;
605                let bit_position = i as u8;
606
607                quote! {
608                    #field_name: (byte & (1 << #bit_position)) != 0
609                }
610            }).collect::<Vec<_>>()
611        },
612        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
613            unnamed.iter().enumerate().map(|(i, _)| {
614                let field_name = Index::from(i);
615                let bit_position = i as u8;
616
617                quote! {
618                    #field_name: (byte & (1 << #bit_position)) != 0
619                }
620            }).collect::<Vec<_>>()
621        },
622        Fields::Unit => vec![],
623    };
624
625    if internal {
626        quote! {
627            impl<#lifetime, #context> FromPayload<#lifetime, #context> for #name {
628                fn from_payload<#mw: Middleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, Error> {
629                    let byte: u8 = next.from_payload(ctx)?;
630
631                    Ok(#name {
632                        #(#field_assignments),*
633                    })
634                }
635            }
636        }
637    } else {
638        quote! {
639            impl<#lifetime, #context> npsd::FromPayload<#lifetime, #context> for #name {
640                fn from_payload<#mw: npsd::Middleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, npsd::Error> {
641                    let byte: u8 = next.from_payload(ctx)?;
642
643                    Ok(#name {
644                        #(#field_assignments),*
645                    })
646                }
647            }
648        }
649    }
650}
651
652#[doc(hidden)]
653fn generate_payload_impl(name: &Ident, lifetime: &Lifetime, context: &Ident, internal: bool) -> proc_macro2::TokenStream {
654    if internal {
655        quote! {
656            impl<#lifetime, #context> Payload<#lifetime, #context> for #name {}
657        }
658    } else {
659        quote! {
660            impl<#lifetime, #context> npsd::Payload<#lifetime, #context> for #name {}
661        }
662    }
663}
664
665#[proc_macro_derive(AsyncSchema)]
666pub fn async_schema_public_impl(input: TokenStream) -> TokenStream {
667    async_schema_impl(input, false)
668}
669
670#[doc(hidden)]
671#[proc_macro_derive(AsyncSchemaInternal)]
672pub fn async_schema_internal_impl(input: TokenStream) -> TokenStream {
673    async_schema_impl(input, true)
674}
675
676#[doc(hidden)]
677fn async_schema_impl(input: TokenStream, internal: bool) -> TokenStream {
678    let DeriveInput { ident, data, generics, .. } = parse_macro_input!(input);
679    let (_, ty_generics, where_clause) = generics.split_for_impl();
680
681    let (lifetime_exist, lifetime) = resolve_lifetime(&generics, DEFAULT_LIFETIME);
682    let context = Ident::new(DEFAULT_CONTEXT, Span::call_site());
683    let scope = Lifetime::new(DEFAULT_SCOPE_LIFETIME, Span::call_site());
684    let mw = Ident::new(DEFAULT_MIDDLEWARE, Span::call_site());
685    let mut context_generics = generics.clone();
686
687    let mut context_param: TypeParam = syn::parse_quote!(#context);
688
689    let send_bound: TypeParamBound = syn::parse_quote!(Send);
690    let sync_bound: TypeParamBound = syn::parse_quote!(Sync);
691
692    context_param.bounds.push(send_bound);
693    context_param.bounds.push(sync_bound);
694
695    context_generics.params.push(GenericParam::Type(context_param));
696
697    let mut into_generics = context_generics.clone();
698    let mut from_generics = context_generics.clone();
699    let mut payload_generics = context_generics.clone();
700
701    if !lifetime_exist {
702        let lifetime_param = LifetimeParam::new(lifetime.clone());
703        from_generics.params.insert(0, GenericParam::Lifetime(lifetime_param.clone()));
704        payload_generics.params.insert(0, GenericParam::Lifetime(lifetime_param.clone()));
705    }
706    
707    async_schema_into_impl(&mut into_generics, internal, &context);
708    let (into_impl, _, _) = into_generics.split_for_impl();
709
710    async_schema_from_impl(&mut from_generics, internal, &lifetime, &context);
711    let (from_impl, _, _) = from_generics.split_for_impl();
712
713    async_schema_payload_impl(&mut payload_generics, internal, &lifetime, &context);
714    let (payload_impl, _, _) = payload_generics.split_for_impl();
715
716    let sender_block = match data.clone() {
717        Data::Struct(data_struct) => {
718            let fields = match data_struct.fields {
719                Fields::Named(FieldsNamed { named, .. }) => {
720                    named.iter().map(|f| {
721                        let name = &f.ident;
722                        let span = f.span();
723
724                        quote_spanned! { span =>
725                            next.poll_into_payload(&self.#name, ctx).await?;
726                        }
727                    }).collect::<Vec<_>>()
728                },
729
730                Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
731                    unnamed.iter().enumerate().map(|(i, _)| {
732                        let index = Index::from(i);
733                        let span = index.span();
734
735                        quote_spanned! { span =>
736                            next.poll_into_payload(&self.#index, ctx).await?;
737                        }
738                    }).collect::<Vec<_>>()
739                },
740
741                Fields::Unit => Vec::new(),
742            };
743
744            quote! { #( #fields )* }
745        },
746        Data::Enum(DataEnum { variants, .. }) => {
747            let variant_cases = variants.iter().enumerate().map(|(index, variant)| {
748                let variant_ident = &variant.ident;
749                let variant_span = variant.span(); 
750
751                match &variant.fields {
752                    Fields::Named(FieldsNamed { named, .. }) => {
753                        let (field_patterns, field_serializations): (Vec<_>, Vec<_>) = named.iter()
754                            .map(|f| {
755                                let name = f.ident.as_ref().unwrap();
756                                let span = name.span();
757                                let pattern = quote_spanned! { span => #name };
758                                let serialization = quote_spanned! { span => next.poll_into_payload(&#name, ctx).await?; };
759                                (pattern, serialization)
760                            }).unzip();
761
762                        quote_spanned! { variant_span => 
763                            #ident::#variant_ident { #(#field_patterns,)* } => {
764                                next.poll_into_payload(&#index, ctx).await?;
765                                #( #field_serializations )*
766                            }
767                        }
768                    },
769                    Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
770                        let (field_patterns, field_serializations): (Vec<_>, Vec<_>) = unnamed.iter().enumerate()
771                            .map(|(i, _)| {
772                                let field_name = Ident::new(&format!("__self_{}", i), Span::call_site());
773                                let pattern = quote! { #field_name };
774                                let serialization = quote! { next.poll_into_payload(&#field_name, ctx).await?; };
775                                (pattern, serialization)
776                            }).unzip();
777                    
778                        quote_spanned! { variant_span => 
779                            #ident::#variant_ident( #( #field_patterns, )* ) => {
780                                next.poll_into_payload(&#index, ctx).await?;
781                                #( #field_serializations )*
782                            }
783                        }
784                    },
785                    Fields::Unit => {
786                        quote_spanned! { variant_span => 
787                            #ident::#variant_ident => {
788                                next.poll_into_payload(&#index, ctx).await?;
789                            }
790                        }
791                    },
792                }
793            });
794
795            quote! {
796                match self {
797                    #( #variant_cases, )*
798                }
799            }
800        },
801        Data::Union(_) => {
802            return quote! {
803                compile_error!("Union types are not supported by this macro.");
804            }.into();
805        },
806    };
807
808    let receiver_block = match data.clone() {
809        Data::Struct(data_struct) => {
810            match data_struct.fields {
811                Fields::Named(FieldsNamed { named, .. }) => {
812                    let fields = named.iter().map(|f| {
813                        let field = &f.ident;
814                        let ty = &f.ty;
815                        let span = f.span();
816
817                        quote_spanned! { span =>
818                            #field: next.poll_from_payload::<#context, #ty>(ctx).await? // as #ty
819                        }
820                    }).collect::<Vec<_>>();
821
822                    quote! {
823                        Ok(#ident {
824                            #( #fields ),*
825                        })
826                    }
827                },
828                Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
829                    let fields = unnamed.iter().enumerate().map(|(_, f)| {
830                        let ty = &f.ty;
831
832                        quote! {
833                            next.poll_from_payload::<#context, #ty>(ctx).await? // as #ty
834                        }
835                    }).collect::<Vec<_>>();
836
837                    quote! {
838                        Ok(#ident (
839                            #( #fields ),*
840                        ))
841                    }
842                },
843                Fields::Unit => {
844                    quote! {
845                        Ok(#ident)
846                    }
847                },
848            }
849        },
850        Data::Enum(DataEnum { variants, .. }) => {
851            let match_variants = variants.iter().enumerate().map(|(index, variant)| {
852                let variant_ident = &variant.ident;
853                
854                match &variant.fields {
855                    Fields::Named(FieldsNamed { named, .. }) => {
856                        let deserializations = named.iter().map(|f| {
857                            let name = &f.ident;
858                            let ty = &f.ty;
859        
860                            quote! {
861                                #name: next.poll_from_payload::<#context, #ty>(ctx).await? // as #ty
862                            }
863                        });
864                        
865                        quote! {
866                            #index => Ok(#ident::#variant_ident { #(#deserializations),* })
867                        }
868                    },
869                    Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
870                        let deserializations = unnamed.iter().map(|f| {
871                            let ty = &f.ty;
872        
873                            quote! {
874                                next.poll_from_payload::<#context, #ty>(ctx).await? // as #ty
875                            }
876                        });
877                        
878                        quote! {
879                            #index => Ok(#ident::#variant_ident( #(#deserializations),* ))
880                        }
881                    },
882                    Fields::Unit => {
883                        quote! {
884                            #index => Ok(#ident::#variant_ident)
885                        }
886                    },
887                }
888            }).collect::<Vec<_>>();
889        
890            if internal {
891                quote! {
892                    let variant_index: usize = next.poll_from_payload(ctx).await?;
893            
894                    match variant_index {
895                        #(#match_variants,)*
896                        _ => Err(Error::UnknownVariant("Index out of bounds for enum".to_string())),
897                    }
898                }
899            } else {
900                quote! {
901                    let variant_index: usize = next.poll_from_payload(ctx).await?;
902            
903                    match variant_index {
904                        #(#match_variants,)*
905                        _ => Err(npsd::Error::UnknownVariant("Index out of bounds for enum".to_string())),
906                    }
907                }
908            }
909        },
910        Data::Union(_) => {
911            return quote! {
912                compile_error!("Union types are not supported by this macro.");
913            }.into();
914        },
915    };
916
917    let gen = if internal {
918        quote! {
919            impl #into_impl AsyncIntoPayload<#context> for #ident #ty_generics #where_clause {
920                async fn poll_into_payload<#scope, #mw: AsyncMiddleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), Error> {
921                    #sender_block
922                    Ok(())
923                }
924            }
925
926            impl #from_impl AsyncFromPayload<#lifetime, #context> for #ident #ty_generics #where_clause {
927                async fn poll_from_payload<#mw: AsyncMiddleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, Error> {
928                    #receiver_block
929                }
930            }
931
932            impl #payload_impl AsyncPayload<#lifetime, #context> for #ident #ty_generics #where_clause {}
933        }
934    } else {
935        quote! {
936            impl #into_impl npsd::AsyncIntoPayload<#context> for #ident #ty_generics #where_clause {
937                async fn poll_into_payload<#scope, #mw: npsd::AsyncMiddleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), npsd::Error> {
938                    #sender_block
939                    Ok(())
940                }
941            }
942
943            impl #from_impl npsd::AsyncFromPayload<#lifetime, #context> for #ident #ty_generics #where_clause {
944                async fn poll_from_payload<#mw: npsd::AsyncMiddleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, npsd::Error> {
945                    #receiver_block
946                }
947            }
948
949            impl #payload_impl npsd::AsyncPayload<#lifetime, #context> for #ident #ty_generics #where_clause {}
950        }
951    };
952
953    gen.into()
954}
955
956
957#[proc_macro_derive(AsyncBitmap)]
958pub fn async_bitmap_derive(input: TokenStream) -> TokenStream {
959    async_bitmap_impl(input, false)
960}
961
962#[doc(hidden)]
963#[proc_macro_derive(AsyncBitmapInternal)]
964pub fn async_bitmap_internal_derive(input: TokenStream) -> TokenStream {
965    async_bitmap_impl(input, true)
966}
967
968#[doc(hidden)]
969fn async_bitmap_impl(input: TokenStream, internal: bool) -> TokenStream {
970    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
971
972    let fields = match data {
973        Data::Struct(ref data_struct) => &data_struct.fields,
974        _ => {
975            return quote! {
976                compile_error!("Bitmap can only be derived for structs with named or unnamed fields");
977            }.into();
978        } 
979    };
980
981    let field_count = match fields {
982        Fields::Named(ref named_fields) => named_fields.named.len(),
983        Fields::Unnamed(ref unnamed_fields) => unnamed_fields.unnamed.len(),
984        Fields::Unit => 0,
985    };
986
987    if field_count > 8 {
988        return quote! {
989            compile_error!("Bitmap can only be derived for structs with no more than 8 fields");
990        }.into();
991    }
992
993    let lifetime = Lifetime::new(DEFAULT_LIFETIME, Span::call_site());
994    let scope = Lifetime::new(DEFAULT_SCOPE_LIFETIME, Span::call_site());
995
996    let context = Ident::new(DEFAULT_CONTEXT, Span::call_site());
997    let mw = Ident::new(DEFAULT_MIDDLEWARE, Span::call_site());
998
999    let into_payload_impl = async_generate_into_payload_impl(&ident, &fields, &scope, &context, &mw, internal);
1000    let from_payload_impl = async_generate_from_payload_impl(&ident, &fields, &lifetime, &context, &mw, internal);
1001    let payload_impl = async_generate_payload_impl(&ident, &lifetime, &context, internal);
1002
1003    let expanded = quote! {
1004        #into_payload_impl
1005        #from_payload_impl
1006        #payload_impl
1007    };
1008
1009    TokenStream::from(expanded)
1010}
1011
1012#[doc(hidden)]
1013fn async_generate_into_payload_impl(name: &Ident, fields: &Fields, scope: &Lifetime, context: &Ident, mw: &Ident, internal: bool) -> proc_macro2::TokenStream {
1014    let field_conversions = match fields {
1015        Fields::Named(FieldsNamed { named, .. }) => {
1016            named.iter().enumerate().map(|(i, f)| {
1017                let field_name = &f.ident;
1018                let bit_position = i as u8;
1019
1020                quote! {
1021                    if self.#field_name {
1022                        byte |= 1 << #bit_position;
1023                    }
1024                }
1025            }).collect::<Vec<_>>()
1026        },
1027        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
1028            unnamed.iter().enumerate().map(|(i, _)| {
1029                let field_name = Index::from(i);
1030                let bit_position = i as u8;
1031
1032                quote! {
1033                    if self.#field_name {
1034                        byte |= 1 << #bit_position;
1035                    }
1036                }
1037            }).collect::<Vec<_>>()
1038        },
1039        Fields::Unit => vec![],
1040    };
1041
1042    if internal {
1043        quote! {
1044            impl<#context: Send + Sync> AsyncIntoPayload<#context> for #name {
1045                async fn poll_into_payload<#scope, #mw: AsyncMiddleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), Error> {
1046                    let mut byte: u8 = 0;
1047                    #(#field_conversions)*
1048                    next.poll_into_payload(&byte, ctx).await
1049                }
1050            }
1051        }
1052    } else {
1053        quote! {
1054            impl<#context: Send + Sync> npsd::AsyncIntoPayload<#context> for #name {
1055                async fn poll_into_payload<#scope, #mw: npsd::AsyncMiddleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), npsd::Error> {
1056                    let mut byte: u8 = 0;
1057                    #(#field_conversions)*
1058                    next.poll_into_payload(&byte, ctx).await
1059                }
1060            }
1061        }
1062    }
1063}
1064
1065#[doc(hidden)]
1066fn async_generate_from_payload_impl(name: &Ident, fields: &Fields, lifetime: &Lifetime, context: &Ident, mw: &Ident, internal: bool) -> proc_macro2::TokenStream {
1067    let field_assignments = match fields {
1068        Fields::Named(FieldsNamed { named, .. }) => {
1069            named.iter().enumerate().map(|(i, f)| {
1070                let field_name = &f.ident;
1071                let bit_position = i as u8;
1072
1073                quote! {
1074                    #field_name: (byte & (1 << #bit_position)) != 0
1075                }
1076            }).collect::<Vec<_>>()
1077        },
1078        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
1079            unnamed.iter().enumerate().map(|(i, _)| {
1080                let field_name = Index::from(i);
1081                let bit_position = i as u8;
1082
1083                quote! {
1084                    #field_name: (byte & (1 << #bit_position)) != 0
1085                }
1086            }).collect::<Vec<_>>()
1087        },
1088        Fields::Unit => vec![],
1089    };
1090
1091    if internal {
1092        quote! {
1093            impl<#lifetime, #context: Send + Sync> AsyncFromPayload<#lifetime, #context> for #name {
1094                async fn poll_from_payload<#mw: AsyncMiddleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, Error> {
1095                    let byte: u8 = next.poll_from_payload(ctx).await?;
1096
1097                    Ok(#name {
1098                        #(#field_assignments),*
1099                    })
1100                }
1101            }
1102        }
1103    } else {
1104        quote! {
1105            impl<#lifetime, #context: Send + Sync> npsd::AsyncFromPayload<#lifetime, #context> for #name {
1106                async fn poll_from_payload<#mw: npsd::AsyncMiddleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, npsd::Error> {
1107                    let byte: u8 = next.poll_from_payload(ctx).await?;
1108
1109                    Ok(#name {
1110                        #(#field_assignments),*
1111                    })
1112                }
1113            }
1114        }
1115    }
1116}
1117
1118#[doc(hidden)]
1119fn async_generate_payload_impl(name: &Ident, lifetime: &Lifetime, context: &Ident, internal: bool) -> proc_macro2::TokenStream {
1120    if internal {
1121        quote! {
1122            impl<#lifetime, #context: Send + Sync> AsyncPayload<#lifetime, #context> for #name {}
1123        }
1124    } else {
1125        quote! {
1126            impl<#lifetime, #context: Send + Sync> npsd::AsyncPayload<#lifetime, #context> for #name {}
1127        }
1128    }
1129}