1#[doc(hidden)]
24use syn::{parse_macro_input, TypeParam, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Plus, Data, DataEnum, DeriveInput, Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Ident, Index, Lifetime, LifetimeParam, TypeParamBound};
25#[doc(hidden)]
26use quote::{quote, quote_spanned};
27#[doc(hidden)]
28use proc_macro::TokenStream;
29#[doc(hidden)]
30use proc_macro2::Span;
31
/// Name of the payload lifetime parameter that the derives introduce when the
/// annotated type declares no lifetime of its own.
const DEFAULT_LIFETIME: &str = "'__payload";
/// Name of the lifetime parameter placed on the generated `into_payload` /
/// `poll_into_payload` methods to parameterize the middleware bound.
const DEFAULT_SCOPE_LIFETIME: &str = "'__payload_scope";
/// Name of the context type parameter appended to every generated `impl`.
const DEFAULT_CONTEXT: &str = "__PayloadCtx";
/// Name of the middleware type parameter on the generated methods.
const DEFAULT_MIDDLEWARE: &str = "__PayloadMw";
36
37#[doc(hidden)]
38fn resolve_lifetime(generics: &Generics, lifetime_name: &str) -> (bool, Lifetime) {
39 if let Some(existing_lifetime) = generics.params.iter().find_map(|param| {
40 if let GenericParam::Lifetime(lifetime_param) = param {
41 Some(lifetime_param.lifetime.clone())
42 } else {
43 None
44 }
45 }) {
46 return (true, existing_lifetime);
47 }
48
49 (false, Lifetime::new(lifetime_name, Span::call_site()))
50}
51
52#[doc(hidden)]
53fn has_bound(bounds: &Punctuated<TypeParamBound, Plus>, bound_to_check: &str) -> bool {
54 bounds.iter().any(|bound| {
55 if let TypeParamBound::Trait(trait_bound) = bound {
56 trait_bound.path.segments.iter().any(|segment| segment.ident == bound_to_check)
57 } else {
58 false
59 }
60 })
61}
62
63#[doc(hidden)]
64fn schema_into_impl(generics: &mut Generics, internal: bool, context: &Ident) {
65 for param in generics.params.iter_mut() {
66 if let GenericParam::Type(type_param) = param {
67 if type_param.ident == DEFAULT_CONTEXT {
68 continue;
69 }
70
71 if !has_bound(&type_param.bounds, "IntoPayload") {
72 type_param.bounds.push(if internal {
73 parse_quote!(IntoPayload<#context>)
74 } else {
75 parse_quote!(npsd::IntoPayload<#context>)
76 });
77 }
78 }
79 }
80}
81
82#[doc(hidden)]
83fn schema_from_impl(generics: &mut Generics, internal: bool, lifetime: &Lifetime, context: &Ident) {
84 for param in generics.params.iter_mut() {
85 if let GenericParam::Type(type_param) = param {
86 if type_param.ident == DEFAULT_CONTEXT {
87 continue;
88 }
89
90 if !has_bound(&type_param.bounds, "FromPayload") {
91 type_param.bounds.push(if internal {
92 parse_quote!(FromPayload<#lifetime, #context>)
93 } else {
94 parse_quote!(npsd::FromPayload<#lifetime, #context>)
95 });
96 }
97 }
98 }
99}
100
101#[doc(hidden)]
102fn schema_payload_impl(generics: &mut Generics, internal: bool, lifetime: &Lifetime, context: &Ident) {
103 for param in generics.params.iter_mut() {
104 if let GenericParam::Type(type_param) = param {
105 if type_param.ident == DEFAULT_CONTEXT {
106 continue;
107 }
108
109 if !has_bound(&type_param.bounds, "Payload") {
110 type_param.bounds.push(if internal {
111 parse_quote!(Payload<#lifetime, #context>)
112 } else {
113 parse_quote!(npsd::Payload<#lifetime, #context>)
114 });
115 }
116 }
117 }
118}
119
120#[doc(hidden)]
121fn async_schema_into_impl(generics: &mut Generics, internal: bool, context: &Ident) {
122 for param in generics.params.iter_mut() {
123 if let GenericParam::Type(type_param) = param {
124 if type_param.ident == DEFAULT_CONTEXT {
125 continue;
126 }
127
128 if !has_bound(&type_param.bounds, "AsyncIntoPayload") {
129 type_param.bounds.push(if internal {
130 parse_quote!(AsyncIntoPayload<#context>)
131 } else {
132 parse_quote!(npsd::AsyncIntoPayload<#context>)
133 });
134 }
135 }
136 }
137}
138
139#[doc(hidden)]
140fn async_schema_from_impl(generics: &mut Generics, internal: bool, lifetime: &Lifetime, context: &Ident) {
141 for param in generics.params.iter_mut() {
142 if let GenericParam::Type(type_param) = param {
143 if type_param.ident == DEFAULT_CONTEXT {
144 continue;
145 }
146
147 if !has_bound(&type_param.bounds, "AsyncFromPayload") {
148 type_param.bounds.push(if internal {
149 parse_quote!(AsyncFromPayload<#lifetime, #context>)
150 } else {
151 parse_quote!(npsd::AsyncFromPayload<#lifetime, #context>)
152 });
153 }
154 }
155 }
156}
157
158#[doc(hidden)]
159fn async_schema_payload_impl(generics: &mut Generics, internal: bool, lifetime: &Lifetime, context: &Ident) {
160 for param in generics.params.iter_mut() {
161 if let GenericParam::Type(type_param) = param {
162 if type_param.ident == DEFAULT_CONTEXT {
163 continue;
164 }
165
166 if !has_bound(&type_param.bounds, "AsyncPayload") {
167 type_param.bounds.push(if internal {
168 parse_quote!(AsyncPayload<#lifetime, #context>)
169 } else {
170 parse_quote!(npsd::AsyncPayload<#lifetime, #context>)
171 });
172 }
173 }
174 }
175}
176
177#[doc(hidden)]
178#[proc_macro_derive(Info)]
179pub fn payload_info_public_impl(input: TokenStream) -> TokenStream {
180 let DeriveInput { ident, generics, .. } = parse_macro_input!(input);
181 let (generics_impl, ty_generics, where_clause) = generics.split_for_impl();
182
183 let gen = quote! {
184 impl #generics_impl npsd::PayloadInfo for #ident #ty_generics #where_clause {
185 const TYPE: &'static str = stringify!(#ident);
186 }
187 };
188
189 gen.into()
190}
191
192#[doc(hidden)]
193#[proc_macro_derive(InfoInternal)]
194pub fn payload_info_intenal_impl(input: TokenStream) -> TokenStream {
195 let DeriveInput { ident, generics, .. } = parse_macro_input!(input);
196 let (generics_impl, ty_generics, where_clause) = generics.split_for_impl();
197
198 let gen = quote! {
199 impl #generics_impl PayloadInfo for #ident #ty_generics #where_clause {
200 const TYPE: &'static str = stringify!(#ident);
201 }
202 };
203
204 gen.into()
205}
206
/// `#[derive(Schema)]`: generates `IntoPayload`, `FromPayload` and `Payload`
/// impls for a struct or enum.
///
/// Delegates to `schema_impl` with `internal = false`, so every generated
/// path is prefixed with `npsd::`.
#[proc_macro_derive(Schema)]
pub fn schema_public_impl(input: TokenStream) -> TokenStream {
    schema_impl(input, false)
}
211
// `#[derive(SchemaInternal)]`: same expansion as `#[derive(Schema)]`, but
// delegates to `schema_impl` with `internal = true`, so the generated code
// names the traits and the error type without the `npsd::` prefix.
#[doc(hidden)]
#[proc_macro_derive(SchemaInternal)]
pub fn schema_internal_impl(input: TokenStream) -> TokenStream {
    schema_impl(input, true)
}
217
218#[doc(hidden)]
219fn schema_impl(input: TokenStream, internal: bool) -> TokenStream {
220 let DeriveInput { ident, data, generics, .. } = parse_macro_input!(input);
221 let (_, ty_generics, where_clause) = generics.split_for_impl();
222
223 let (lifetime_exist, lifetime) = resolve_lifetime(&generics, DEFAULT_LIFETIME);
224 let context = Ident::new(DEFAULT_CONTEXT, Span::call_site());
225 let scope = Lifetime::new(DEFAULT_SCOPE_LIFETIME, Span::call_site());
226 let mw = Ident::new(DEFAULT_MIDDLEWARE, Span::call_site());
227 let mut context_generics = generics.clone();
228
229 let context_param: GenericParam = syn::parse_quote!(#context);
230 context_generics.params.push(context_param);
231
232 let mut into_generics = context_generics.clone();
233 let mut from_generics = context_generics.clone();
234 let mut payload_generics = context_generics.clone();
235
236 if !lifetime_exist {
237 let lifetime_param = LifetimeParam::new(lifetime.clone());
238 from_generics.params.insert(0, GenericParam::Lifetime(lifetime_param.clone()));
239 payload_generics.params.insert(0, GenericParam::Lifetime(lifetime_param.clone()));
240 }
241
242 schema_into_impl(&mut into_generics, internal, &context);
243 let (into_impl, _, _) = into_generics.split_for_impl();
244
245 schema_from_impl(&mut from_generics, internal, &lifetime, &context);
246 let (from_impl, _, _) = from_generics.split_for_impl();
247
248 schema_payload_impl(&mut payload_generics, internal, &lifetime, &context);
249 let (payload_impl, _, _) = payload_generics.split_for_impl();
250
251 let sender_block = match data.clone() {
252 Data::Struct(data_struct) => {
253 let fields = match data_struct.fields {
254 Fields::Named(FieldsNamed { named, .. }) => {
255 named.iter().map(|f| {
256 let name = &f.ident;
257 let span = f.span();
258
259 quote_spanned! { span =>
260 next.into_payload(&self.#name, ctx)?;
261 }
262 }).collect::<Vec<_>>()
263 },
264
265 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
266 unnamed.iter().enumerate().map(|(i, _)| {
267 let index = Index::from(i);
268 let span = index.span();
269
270 quote_spanned! { span =>
271 next.into_payload(&self.#index, ctx)?;
272 }
273 }).collect::<Vec<_>>()
274 },
275
276 Fields::Unit => Vec::new(),
277 };
278
279 quote! { #( #fields )* }
280 },
281 Data::Enum(DataEnum { variants, .. }) => {
282 let variant_cases = variants.iter().enumerate().map(|(index, variant)| {
283 let variant_ident = &variant.ident;
284 let variant_span = variant.span();
285
286 match &variant.fields {
287 Fields::Named(FieldsNamed { named, .. }) => {
288 let (field_patterns, field_serializations): (Vec<_>, Vec<_>) = named.iter()
289 .map(|f| {
290 let name = f.ident.as_ref().unwrap();
291 let span = name.span();
292 let pattern = quote_spanned! { span => #name };
293 let serialization = quote_spanned! { span => next.into_payload(&#name, ctx)?; };
294 (pattern, serialization)
295 }).unzip();
296
297 quote_spanned! { variant_span =>
298 #ident::#variant_ident { #(#field_patterns,)* } => {
299 next.into_payload(&#index, ctx)?;
300 #( #field_serializations )*
301 }
302 }
303 },
304 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
305 let (field_patterns, field_serializations): (Vec<_>, Vec<_>) = unnamed.iter().enumerate()
306 .map(|(i, _)| {
307 let field_name = Ident::new(&format!("__self_{}", i), Span::call_site());
308 let pattern = quote! { #field_name };
309 let serialization = quote! { next.into_payload(&#field_name, ctx)?; };
310 (pattern, serialization)
311 }).unzip();
312
313 quote_spanned! { variant_span =>
314 #ident::#variant_ident( #( #field_patterns, )* ) => {
315 next.into_payload(&#index, ctx)?;
316 #( #field_serializations )*
317 }
318 }
319 },
320 Fields::Unit => {
321 quote_spanned! { variant_span =>
322 #ident::#variant_ident => {
323 next.into_payload(&#index, ctx)?;
324 }
325 }
326 },
327 }
328 });
329
330 quote! {
331 match self {
332 #( #variant_cases, )*
333 }
334 }
335 },
336 Data::Union(_) => {
337 return quote! {
338 compile_error!("Union types are not supported by this macro.");
339 }.into();
340 },
341 };
342
343 let receiver_block = match data.clone() {
344 Data::Struct(data_struct) => {
345 match data_struct.fields {
346 Fields::Named(FieldsNamed { named, .. }) => {
347 let fields = named.iter().map(|f| {
348 let field = &f.ident;
349 let ty = &f.ty;
350 let span = f.span();
351
352 quote_spanned! { span =>
353 #field: next.from_payload::<#context, #ty>(ctx)? }
355 }).collect::<Vec<_>>();
356
357 quote! {
358 Ok(#ident {
359 #( #fields ),*
360 })
361 }
362 },
363 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
364 let fields = unnamed.iter().enumerate().map(|(_, f)| {
365 let ty = &f.ty;
366
367 quote! {
368 next.from_payload::<#context, #ty>(ctx)? }
370 }).collect::<Vec<_>>();
371
372 quote! {
373 Ok(#ident (
374 #( #fields ),*
375 ))
376 }
377 },
378 Fields::Unit => {
379 quote! {
380 Ok(#ident)
381 }
382 },
383 }
384 },
385 Data::Enum(DataEnum { variants, .. }) => {
386 let match_variants = variants.iter().enumerate().map(|(index, variant)| {
387 let variant_ident = &variant.ident;
388
389 match &variant.fields {
390 Fields::Named(FieldsNamed { named, .. }) => {
391 let deserializations = named.iter().map(|f| {
392 let name = &f.ident;
393 let ty = &f.ty;
394
395 quote! {
396 #name: next.from_payload::<#context, #ty>(ctx)? }
398 });
399
400 quote! {
401 #index => Ok(#ident::#variant_ident { #(#deserializations),* })
402 }
403 },
404 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
405 let deserializations = unnamed.iter().map(|f| {
406 let ty = &f.ty;
407
408 quote! {
409 next.from_payload::<#context, #ty>(ctx)? }
411 });
412
413 quote! {
414 #index => Ok(#ident::#variant_ident( #(#deserializations),* ))
415 }
416 },
417 Fields::Unit => {
418 quote! {
419 #index => Ok(#ident::#variant_ident)
420 }
421 },
422 }
423 }).collect::<Vec<_>>();
424
425 if internal {
426 quote! {
427 let variant_index: usize = next.from_payload(ctx)?;
428
429 match variant_index {
430 #(#match_variants,)*
431 _ => Err(Error::UnknownVariant("Index out of bounds for enum".to_string())),
432 }
433 }
434 } else {
435 quote! {
436 let variant_index: usize = next.from_payload(ctx)?;
437
438 match variant_index {
439 #(#match_variants,)*
440 _ => Err(npsd::Error::UnknownVariant("Index out of bounds for enum".to_string())),
441 }
442 }
443 }
444 },
445 Data::Union(_) => {
446 return quote! {
447 compile_error!("Union types are not supported by this macro.");
448 }.into();
449 },
450 };
451
452 let gen = if internal {
453 quote! {
454 impl #into_impl IntoPayload<#context> for #ident #ty_generics #where_clause {
455 fn into_payload<#scope, #mw: Middleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), Error> {
456 #sender_block
457 Ok(())
458 }
459 }
460
461 impl #from_impl FromPayload<#lifetime, #context> for #ident #ty_generics #where_clause {
462 fn from_payload<#mw: Middleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, Error> {
463 #receiver_block
464 }
465 }
466
467 impl #payload_impl Payload<#lifetime, #context> for #ident #ty_generics #where_clause {}
468 }
469 } else {
470 quote! {
471 impl #into_impl npsd::IntoPayload<#context> for #ident #ty_generics #where_clause {
472 fn into_payload<#scope, #mw: npsd::Middleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), npsd::Error> {
473 #sender_block
474 Ok(())
475 }
476 }
477
478 impl #from_impl npsd::FromPayload<#lifetime, #context> for #ident #ty_generics #where_clause {
479 fn from_payload<#mw: npsd::Middleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, npsd::Error> {
480 #receiver_block
481 }
482 }
483
484 impl #payload_impl npsd::Payload<#lifetime, #context> for #ident #ty_generics #where_clause {}
485 }
486 };
487
488 gen.into()
489}
490
/// `#[derive(Bitmap)]`: packs a struct of at most 8 fields into a single
/// `u8`, one bit per field in declaration order.
///
/// Delegates to `bitmap_impl` with `internal = false`.
#[proc_macro_derive(Bitmap)]
pub fn bitmap_derive(input: TokenStream) -> TokenStream {
    bitmap_impl(input, false)
}
495
// `#[derive(BitmapInternal)]`: same expansion as `#[derive(Bitmap)]`, but
// delegates to `bitmap_impl` with `internal = true`.
#[doc(hidden)]
#[proc_macro_derive(BitmapInternal)]
pub fn bitmap_internal_derive(input: TokenStream) -> TokenStream {
    bitmap_impl(input, true)
}
501
502#[doc(hidden)]
503fn bitmap_impl(input: TokenStream, internal: bool) -> TokenStream {
504 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
505
506 let fields = match data {
507 Data::Struct(ref data_struct) => &data_struct.fields,
508 _ => {
509 return quote! {
510 compile_error!("Bitmap can only be derived for structs with named or unnamed fields");
511 }.into();
512 }
513 };
514
515 let field_count = match fields {
516 Fields::Named(ref named_fields) => named_fields.named.len(),
517 Fields::Unnamed(ref unnamed_fields) => unnamed_fields.unnamed.len(),
518 Fields::Unit => 0,
519 };
520
521 if field_count > 8 {
522 return quote! {
523 compile_error!("Bitmap can only be derived for structs with no more than 8 fields");
524 }.into();
525 }
526
527 let lifetime = Lifetime::new(DEFAULT_LIFETIME, Span::call_site());
528 let scope = Lifetime::new(DEFAULT_SCOPE_LIFETIME, Span::call_site());
529
530 let context = Ident::new(DEFAULT_CONTEXT, Span::call_site());
531 let mw = Ident::new(DEFAULT_MIDDLEWARE, Span::call_site());
532
533 let into_payload_impl = generate_into_payload_impl(&ident, &fields, &scope, &context, &mw, internal);
534 let from_payload_impl = generate_from_payload_impl(&ident, &fields, &lifetime, &context, &mw, internal);
535 let payload_impl = generate_payload_impl(&ident, &lifetime,&context, internal);
536
537 let expanded = quote! {
538 #into_payload_impl
539 #from_payload_impl
540 #payload_impl
541 };
542
543 TokenStream::from(expanded)
544}
545
546#[doc(hidden)]
547fn generate_into_payload_impl(name: &Ident, fields: &Fields, scope: &Lifetime, context: &Ident, mw: &Ident, internal: bool) -> proc_macro2::TokenStream {
548 let field_conversions = match fields {
549 Fields::Named(FieldsNamed { named, .. }) => {
550 named.iter().enumerate().map(|(i, f)| {
551 let field_name = &f.ident;
552 let bit_position = i as u8;
553
554 quote! {
555 if self.#field_name {
556 byte |= 1 << #bit_position;
557 }
558 }
559 }).collect::<Vec<_>>()
560 },
561 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
562 unnamed.iter().enumerate().map(|(i, _)| {
563 let field_name = Index::from(i);
564 let bit_position = i as u8;
565
566 quote! {
567 if self.#field_name {
568 byte |= 1 << #bit_position;
569 }
570 }
571 }).collect::<Vec<_>>()
572 },
573 Fields::Unit => vec![],
574 };
575
576 if internal {
577 quote! {
578 impl<#context> IntoPayload<#context> for #name {
579 fn into_payload<#scope, #mw: npsd::Middleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), Error> {
580 let mut byte: u8 = 0;
581 #(#field_conversions)*
582 next.into_payload(&byte, ctx)
583 }
584 }
585 }
586 } else {
587 quote! {
588 impl<#context> npsd::IntoPayload<#context> for #name {
589 fn into_payload<#scope, #mw: npsd::Middleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), npsd::Error> {
590 let mut byte: u8 = 0;
591 #(#field_conversions)*
592 next.into_payload(&byte, ctx)
593 }
594 }
595 }
596 }
597}
598
599#[doc(hidden)]
600fn generate_from_payload_impl(name: &Ident, fields: &Fields, lifetime: &Lifetime, context: &Ident, mw: &Ident, internal: bool) -> proc_macro2::TokenStream {
601 let field_assignments = match fields {
602 Fields::Named(FieldsNamed { named, .. }) => {
603 named.iter().enumerate().map(|(i, f)| {
604 let field_name = &f.ident;
605 let bit_position = i as u8;
606
607 quote! {
608 #field_name: (byte & (1 << #bit_position)) != 0
609 }
610 }).collect::<Vec<_>>()
611 },
612 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
613 unnamed.iter().enumerate().map(|(i, _)| {
614 let field_name = Index::from(i);
615 let bit_position = i as u8;
616
617 quote! {
618 #field_name: (byte & (1 << #bit_position)) != 0
619 }
620 }).collect::<Vec<_>>()
621 },
622 Fields::Unit => vec![],
623 };
624
625 if internal {
626 quote! {
627 impl<#lifetime, #context> FromPayload<#lifetime, #context> for #name {
628 fn from_payload<#mw: Middleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, Error> {
629 let byte: u8 = next.from_payload(ctx)?;
630
631 Ok(#name {
632 #(#field_assignments),*
633 })
634 }
635 }
636 }
637 } else {
638 quote! {
639 impl<#lifetime, #context> npsd::FromPayload<#lifetime, #context> for #name {
640 fn from_payload<#mw: npsd::Middleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, npsd::Error> {
641 let byte: u8 = next.from_payload(ctx)?;
642
643 Ok(#name {
644 #(#field_assignments),*
645 })
646 }
647 }
648 }
649 }
650}
651
652#[doc(hidden)]
653fn generate_payload_impl(name: &Ident, lifetime: &Lifetime, context: &Ident, internal: bool) -> proc_macro2::TokenStream {
654 if internal {
655 quote! {
656 impl<#lifetime, #context> Payload<#lifetime, #context> for #name {}
657 }
658 } else {
659 quote! {
660 impl<#lifetime, #context> npsd::Payload<#lifetime, #context> for #name {}
661 }
662 }
663}
664
/// `#[derive(AsyncSchema)]`: generates `AsyncIntoPayload`, `AsyncFromPayload`
/// and `AsyncPayload` impls for a struct or enum.
///
/// Delegates to `async_schema_impl` with `internal = false`, so every
/// generated path is prefixed with `npsd::`.
#[proc_macro_derive(AsyncSchema)]
pub fn async_schema_public_impl(input: TokenStream) -> TokenStream {
    async_schema_impl(input, false)
}
669
// `#[derive(AsyncSchemaInternal)]`: same expansion as `#[derive(AsyncSchema)]`,
// but delegates to `async_schema_impl` with `internal = true`, so the
// generated code names the traits and the error type without the `npsd::`
// prefix.
#[doc(hidden)]
#[proc_macro_derive(AsyncSchemaInternal)]
pub fn async_schema_internal_impl(input: TokenStream) -> TokenStream {
    async_schema_impl(input, true)
}
675
676#[doc(hidden)]
677fn async_schema_impl(input: TokenStream, internal: bool) -> TokenStream {
678 let DeriveInput { ident, data, generics, .. } = parse_macro_input!(input);
679 let (_, ty_generics, where_clause) = generics.split_for_impl();
680
681 let (lifetime_exist, lifetime) = resolve_lifetime(&generics, DEFAULT_LIFETIME);
682 let context = Ident::new(DEFAULT_CONTEXT, Span::call_site());
683 let scope = Lifetime::new(DEFAULT_SCOPE_LIFETIME, Span::call_site());
684 let mw = Ident::new(DEFAULT_MIDDLEWARE, Span::call_site());
685 let mut context_generics = generics.clone();
686
687 let mut context_param: TypeParam = syn::parse_quote!(#context);
688
689 let send_bound: TypeParamBound = syn::parse_quote!(Send);
690 let sync_bound: TypeParamBound = syn::parse_quote!(Sync);
691
692 context_param.bounds.push(send_bound);
693 context_param.bounds.push(sync_bound);
694
695 context_generics.params.push(GenericParam::Type(context_param));
696
697 let mut into_generics = context_generics.clone();
698 let mut from_generics = context_generics.clone();
699 let mut payload_generics = context_generics.clone();
700
701 if !lifetime_exist {
702 let lifetime_param = LifetimeParam::new(lifetime.clone());
703 from_generics.params.insert(0, GenericParam::Lifetime(lifetime_param.clone()));
704 payload_generics.params.insert(0, GenericParam::Lifetime(lifetime_param.clone()));
705 }
706
707 async_schema_into_impl(&mut into_generics, internal, &context);
708 let (into_impl, _, _) = into_generics.split_for_impl();
709
710 async_schema_from_impl(&mut from_generics, internal, &lifetime, &context);
711 let (from_impl, _, _) = from_generics.split_for_impl();
712
713 async_schema_payload_impl(&mut payload_generics, internal, &lifetime, &context);
714 let (payload_impl, _, _) = payload_generics.split_for_impl();
715
716 let sender_block = match data.clone() {
717 Data::Struct(data_struct) => {
718 let fields = match data_struct.fields {
719 Fields::Named(FieldsNamed { named, .. }) => {
720 named.iter().map(|f| {
721 let name = &f.ident;
722 let span = f.span();
723
724 quote_spanned! { span =>
725 next.poll_into_payload(&self.#name, ctx).await?;
726 }
727 }).collect::<Vec<_>>()
728 },
729
730 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
731 unnamed.iter().enumerate().map(|(i, _)| {
732 let index = Index::from(i);
733 let span = index.span();
734
735 quote_spanned! { span =>
736 next.poll_into_payload(&self.#index, ctx).await?;
737 }
738 }).collect::<Vec<_>>()
739 },
740
741 Fields::Unit => Vec::new(),
742 };
743
744 quote! { #( #fields )* }
745 },
746 Data::Enum(DataEnum { variants, .. }) => {
747 let variant_cases = variants.iter().enumerate().map(|(index, variant)| {
748 let variant_ident = &variant.ident;
749 let variant_span = variant.span();
750
751 match &variant.fields {
752 Fields::Named(FieldsNamed { named, .. }) => {
753 let (field_patterns, field_serializations): (Vec<_>, Vec<_>) = named.iter()
754 .map(|f| {
755 let name = f.ident.as_ref().unwrap();
756 let span = name.span();
757 let pattern = quote_spanned! { span => #name };
758 let serialization = quote_spanned! { span => next.poll_into_payload(&#name, ctx).await?; };
759 (pattern, serialization)
760 }).unzip();
761
762 quote_spanned! { variant_span =>
763 #ident::#variant_ident { #(#field_patterns,)* } => {
764 next.poll_into_payload(&#index, ctx).await?;
765 #( #field_serializations )*
766 }
767 }
768 },
769 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
770 let (field_patterns, field_serializations): (Vec<_>, Vec<_>) = unnamed.iter().enumerate()
771 .map(|(i, _)| {
772 let field_name = Ident::new(&format!("__self_{}", i), Span::call_site());
773 let pattern = quote! { #field_name };
774 let serialization = quote! { next.poll_into_payload(&#field_name, ctx).await?; };
775 (pattern, serialization)
776 }).unzip();
777
778 quote_spanned! { variant_span =>
779 #ident::#variant_ident( #( #field_patterns, )* ) => {
780 next.poll_into_payload(&#index, ctx).await?;
781 #( #field_serializations )*
782 }
783 }
784 },
785 Fields::Unit => {
786 quote_spanned! { variant_span =>
787 #ident::#variant_ident => {
788 next.poll_into_payload(&#index, ctx).await?;
789 }
790 }
791 },
792 }
793 });
794
795 quote! {
796 match self {
797 #( #variant_cases, )*
798 }
799 }
800 },
801 Data::Union(_) => {
802 return quote! {
803 compile_error!("Union types are not supported by this macro.");
804 }.into();
805 },
806 };
807
808 let receiver_block = match data.clone() {
809 Data::Struct(data_struct) => {
810 match data_struct.fields {
811 Fields::Named(FieldsNamed { named, .. }) => {
812 let fields = named.iter().map(|f| {
813 let field = &f.ident;
814 let ty = &f.ty;
815 let span = f.span();
816
817 quote_spanned! { span =>
818 #field: next.poll_from_payload::<#context, #ty>(ctx).await? }
820 }).collect::<Vec<_>>();
821
822 quote! {
823 Ok(#ident {
824 #( #fields ),*
825 })
826 }
827 },
828 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
829 let fields = unnamed.iter().enumerate().map(|(_, f)| {
830 let ty = &f.ty;
831
832 quote! {
833 next.poll_from_payload::<#context, #ty>(ctx).await? }
835 }).collect::<Vec<_>>();
836
837 quote! {
838 Ok(#ident (
839 #( #fields ),*
840 ))
841 }
842 },
843 Fields::Unit => {
844 quote! {
845 Ok(#ident)
846 }
847 },
848 }
849 },
850 Data::Enum(DataEnum { variants, .. }) => {
851 let match_variants = variants.iter().enumerate().map(|(index, variant)| {
852 let variant_ident = &variant.ident;
853
854 match &variant.fields {
855 Fields::Named(FieldsNamed { named, .. }) => {
856 let deserializations = named.iter().map(|f| {
857 let name = &f.ident;
858 let ty = &f.ty;
859
860 quote! {
861 #name: next.poll_from_payload::<#context, #ty>(ctx).await? }
863 });
864
865 quote! {
866 #index => Ok(#ident::#variant_ident { #(#deserializations),* })
867 }
868 },
869 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
870 let deserializations = unnamed.iter().map(|f| {
871 let ty = &f.ty;
872
873 quote! {
874 next.poll_from_payload::<#context, #ty>(ctx).await? }
876 });
877
878 quote! {
879 #index => Ok(#ident::#variant_ident( #(#deserializations),* ))
880 }
881 },
882 Fields::Unit => {
883 quote! {
884 #index => Ok(#ident::#variant_ident)
885 }
886 },
887 }
888 }).collect::<Vec<_>>();
889
890 if internal {
891 quote! {
892 let variant_index: usize = next.poll_from_payload(ctx).await?;
893
894 match variant_index {
895 #(#match_variants,)*
896 _ => Err(Error::UnknownVariant("Index out of bounds for enum".to_string())),
897 }
898 }
899 } else {
900 quote! {
901 let variant_index: usize = next.poll_from_payload(ctx).await?;
902
903 match variant_index {
904 #(#match_variants,)*
905 _ => Err(npsd::Error::UnknownVariant("Index out of bounds for enum".to_string())),
906 }
907 }
908 }
909 },
910 Data::Union(_) => {
911 return quote! {
912 compile_error!("Union types are not supported by this macro.");
913 }.into();
914 },
915 };
916
917 let gen = if internal {
918 quote! {
919 impl #into_impl AsyncIntoPayload<#context> for #ident #ty_generics #where_clause {
920 async fn poll_into_payload<#scope, #mw: AsyncMiddleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), Error> {
921 #sender_block
922 Ok(())
923 }
924 }
925
926 impl #from_impl AsyncFromPayload<#lifetime, #context> for #ident #ty_generics #where_clause {
927 async fn poll_from_payload<#mw: AsyncMiddleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, Error> {
928 #receiver_block
929 }
930 }
931
932 impl #payload_impl AsyncPayload<#lifetime, #context> for #ident #ty_generics #where_clause {}
933 }
934 } else {
935 quote! {
936 impl #into_impl npsd::AsyncIntoPayload<#context> for #ident #ty_generics #where_clause {
937 async fn poll_into_payload<#scope, #mw: npsd::AsyncMiddleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), npsd::Error> {
938 #sender_block
939 Ok(())
940 }
941 }
942
943 impl #from_impl npsd::AsyncFromPayload<#lifetime, #context> for #ident #ty_generics #where_clause {
944 async fn poll_from_payload<#mw: npsd::AsyncMiddleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, npsd::Error> {
945 #receiver_block
946 }
947 }
948
949 impl #payload_impl npsd::AsyncPayload<#lifetime, #context> for #ident #ty_generics #where_clause {}
950 }
951 };
952
953 gen.into()
954}
955
956
/// `#[derive(AsyncBitmap)]`: async counterpart of `#[derive(Bitmap)]`; packs
/// a struct of at most 8 fields into a single `u8`, one bit per field in
/// declaration order.
///
/// Delegates to `async_bitmap_impl` with `internal = false`.
#[proc_macro_derive(AsyncBitmap)]
pub fn async_bitmap_derive(input: TokenStream) -> TokenStream {
    async_bitmap_impl(input, false)
}
961
// `#[derive(AsyncBitmapInternal)]`: same expansion as `#[derive(AsyncBitmap)]`,
// but delegates to `async_bitmap_impl` with `internal = true`.
#[doc(hidden)]
#[proc_macro_derive(AsyncBitmapInternal)]
pub fn async_bitmap_internal_derive(input: TokenStream) -> TokenStream {
    async_bitmap_impl(input, true)
}
967
968#[doc(hidden)]
969fn async_bitmap_impl(input: TokenStream, internal: bool) -> TokenStream {
970 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
971
972 let fields = match data {
973 Data::Struct(ref data_struct) => &data_struct.fields,
974 _ => {
975 return quote! {
976 compile_error!("Bitmap can only be derived for structs with named or unnamed fields");
977 }.into();
978 }
979 };
980
981 let field_count = match fields {
982 Fields::Named(ref named_fields) => named_fields.named.len(),
983 Fields::Unnamed(ref unnamed_fields) => unnamed_fields.unnamed.len(),
984 Fields::Unit => 0,
985 };
986
987 if field_count > 8 {
988 return quote! {
989 compile_error!("Bitmap can only be derived for structs with no more than 8 fields");
990 }.into();
991 }
992
993 let lifetime = Lifetime::new(DEFAULT_LIFETIME, Span::call_site());
994 let scope = Lifetime::new(DEFAULT_SCOPE_LIFETIME, Span::call_site());
995
996 let context = Ident::new(DEFAULT_CONTEXT, Span::call_site());
997 let mw = Ident::new(DEFAULT_MIDDLEWARE, Span::call_site());
998
999 let into_payload_impl = async_generate_into_payload_impl(&ident, &fields, &scope, &context, &mw, internal);
1000 let from_payload_impl = async_generate_from_payload_impl(&ident, &fields, &lifetime, &context, &mw, internal);
1001 let payload_impl = async_generate_payload_impl(&ident, &lifetime, &context, internal);
1002
1003 let expanded = quote! {
1004 #into_payload_impl
1005 #from_payload_impl
1006 #payload_impl
1007 };
1008
1009 TokenStream::from(expanded)
1010}
1011
1012#[doc(hidden)]
1013fn async_generate_into_payload_impl(name: &Ident, fields: &Fields, scope: &Lifetime, context: &Ident, mw: &Ident, internal: bool) -> proc_macro2::TokenStream {
1014 let field_conversions = match fields {
1015 Fields::Named(FieldsNamed { named, .. }) => {
1016 named.iter().enumerate().map(|(i, f)| {
1017 let field_name = &f.ident;
1018 let bit_position = i as u8;
1019
1020 quote! {
1021 if self.#field_name {
1022 byte |= 1 << #bit_position;
1023 }
1024 }
1025 }).collect::<Vec<_>>()
1026 },
1027 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
1028 unnamed.iter().enumerate().map(|(i, _)| {
1029 let field_name = Index::from(i);
1030 let bit_position = i as u8;
1031
1032 quote! {
1033 if self.#field_name {
1034 byte |= 1 << #bit_position;
1035 }
1036 }
1037 }).collect::<Vec<_>>()
1038 },
1039 Fields::Unit => vec![],
1040 };
1041
1042 if internal {
1043 quote! {
1044 impl<#context: Send + Sync> AsyncIntoPayload<#context> for #name {
1045 async fn poll_into_payload<#scope, #mw: AsyncMiddleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), Error> {
1046 let mut byte: u8 = 0;
1047 #(#field_conversions)*
1048 next.poll_into_payload(&byte, ctx).await
1049 }
1050 }
1051 }
1052 } else {
1053 quote! {
1054 impl<#context: Send + Sync> npsd::AsyncIntoPayload<#context> for #name {
1055 async fn poll_into_payload<#scope, #mw: npsd::AsyncMiddleware<#scope>>(&self, ctx: &mut #context, next: &mut #mw) -> Result<(), npsd::Error> {
1056 let mut byte: u8 = 0;
1057 #(#field_conversions)*
1058 next.poll_into_payload(&byte, ctx).await
1059 }
1060 }
1061 }
1062 }
1063}
1064
1065#[doc(hidden)]
1066fn async_generate_from_payload_impl(name: &Ident, fields: &Fields, lifetime: &Lifetime, context: &Ident, mw: &Ident, internal: bool) -> proc_macro2::TokenStream {
1067 let field_assignments = match fields {
1068 Fields::Named(FieldsNamed { named, .. }) => {
1069 named.iter().enumerate().map(|(i, f)| {
1070 let field_name = &f.ident;
1071 let bit_position = i as u8;
1072
1073 quote! {
1074 #field_name: (byte & (1 << #bit_position)) != 0
1075 }
1076 }).collect::<Vec<_>>()
1077 },
1078 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
1079 unnamed.iter().enumerate().map(|(i, _)| {
1080 let field_name = Index::from(i);
1081 let bit_position = i as u8;
1082
1083 quote! {
1084 #field_name: (byte & (1 << #bit_position)) != 0
1085 }
1086 }).collect::<Vec<_>>()
1087 },
1088 Fields::Unit => vec![],
1089 };
1090
1091 if internal {
1092 quote! {
1093 impl<#lifetime, #context: Send + Sync> AsyncFromPayload<#lifetime, #context> for #name {
1094 async fn poll_from_payload<#mw: AsyncMiddleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, Error> {
1095 let byte: u8 = next.poll_from_payload(ctx).await?;
1096
1097 Ok(#name {
1098 #(#field_assignments),*
1099 })
1100 }
1101 }
1102 }
1103 } else {
1104 quote! {
1105 impl<#lifetime, #context: Send + Sync> npsd::AsyncFromPayload<#lifetime, #context> for #name {
1106 async fn poll_from_payload<#mw: npsd::AsyncMiddleware<#lifetime>>(ctx: &mut #context, next: &mut #mw) -> Result<Self, npsd::Error> {
1107 let byte: u8 = next.poll_from_payload(ctx).await?;
1108
1109 Ok(#name {
1110 #(#field_assignments),*
1111 })
1112 }
1113 }
1114 }
1115 }
1116}
1117
1118#[doc(hidden)]
1119fn async_generate_payload_impl(name: &Ident, lifetime: &Lifetime, context: &Ident, internal: bool) -> proc_macro2::TokenStream {
1120 if internal {
1121 quote! {
1122 impl<#lifetime, #context: Send + Sync> AsyncPayload<#lifetime, #context> for #name {}
1123 }
1124 } else {
1125 quote! {
1126 impl<#lifetime, #context: Send + Sync> npsd::AsyncPayload<#lifetime, #context> for #name {}
1127 }
1128 }
1129}