1use darling::{FromField, FromAttributes, FromVariant, FromMeta};
2use proc_macro2::{TokenStream, Span};
3use quote::quote;
4use syn::{parse_macro_input, DeriveInput, Type, Ident, parse_str, Path, Fields, Field, GenericParam, Lifetime};
5
6#[derive(Debug, FromAttributes)]
7#[darling(attributes(marshal))]
8struct StructOrEnumReceiver {
9 magic: Option<syn::LitByteStr>,
10 ctx: Option<Path>,
11 tag_type: Option<Path>,
12 tag: Option<String>,
13 tag_bits: Option<usize>,
14}
15
16#[derive(Debug, Clone, Copy, FromMeta)]
17enum CtxType {
18 Coerce,
19 Forward,
20 Construct
21}
22
23#[derive(Debug, FromMeta, PartialEq, Eq)]
24struct ContextMapping {
25 field: String,
26 member: String,
27}
28
29#[derive(Debug, FromField)]
30#[darling(attributes(marshal))]
31struct StructFieldReceiver {
32 ident: Option<Ident>,
33 ty: Type,
34
35 align: Option<usize>,
36 bits: Option<usize>,
37 ctx: Option<CtxType>,
38 ctx_type: Option<Path>,
39 #[darling(multiple, rename = "ctx_member")]
40 ctx_members: Vec<ContextMapping>
41}
42
43#[derive(Debug, FromVariant)]
44#[darling(attributes(marshal))]
45struct EnumVariantReceiver {
46 ident: Ident,
47
48 tag: String
49}
50
51struct ProcessedField {
52 var_name: Ident,
53
54 ty: Type,
55 receiver: StructFieldReceiver,
56
57 context_type: TokenStream,
58 context_body: TokenStream,
59
60 get_ref: TokenStream,
61 get_ref_mut: TokenStream,
62 construct: TokenStream,
63}
64
65struct MarshalProcessedField {
66 pf: ProcessedField,
67 write_body: TokenStream,
68}
69
70struct DemarshalProcessedField {
71 pf: ProcessedField,
72 read_body: TokenStream,
73}
74
75struct UpdateProcessedField {
76 pf: ProcessedField,
77 update_body: Option<TokenStream>
78}
79
80fn process_field(our_context_type: &TokenStream, i: usize, field: &Field) -> ProcessedField {
81 let idx = syn::Index::from(i);
82
83 let receiver = StructFieldReceiver::from_field(field).unwrap();
84
85 let (accessor, var_name) = match receiver.ident.clone() {
86 Some(ident) => (quote! { #ident }, ident),
87 None => (quote! { #idx }, syn::Ident::new(format!("_{}", i).as_str(), Span::call_site()))
88 };
89
90 let ty = receiver.ty.clone();
91
92 let (ctx_ty, ctx_val) = match (receiver.bits, receiver.ctx.as_ref(), receiver.ctx_type.as_ref()) {
93 (Some(bits), _, _) => ( quote!{ binmarshal::BitSpecification::<#bits> }, quote!{ binmarshal::BitSpecification::<#bits> } ),
94 (_, Some(CtxType::Forward), _) => ( quote!{ #our_context_type }, quote!{ ctx } ),
95 (_, Some(CtxType::Coerce), Some(ctx_type)) => ( quote!{ #ctx_type }, quote!{ ctx.into() } ),
96 (_, Some(CtxType::Construct), Some(ctx_type)) => {
97 let inner = receiver.ctx_members.iter().map(|x| {
98 let field: TokenStream = parse_str(&x.field).unwrap();
99 let member: TokenStream = parse_str(&x.member).unwrap();
100 quote!{ #field: #member.clone() }
101 });
102 (quote!{ #ctx_type }, quote!{ #ctx_type { #(#inner),* } })
103 },
104 (_, None, _) => {
105 (quote!{ () }, quote!{ () })
106 },
107 _ => panic!("Invalid Context Combination")
108 };
109
110 ProcessedField {
111 get_ref: quote!{ let #var_name = &self.#accessor },
112 get_ref_mut: quote!{ let #var_name = &mut self.#accessor },
113 construct: quote!{ #var_name },
114 var_name,
115 ty,
116 receiver,
117 context_type: ctx_ty,
118 context_body: ctx_val,
119 }
120}
121
122fn process_field_marshal(our_context_type: &TokenStream, i: usize, field: Field) -> MarshalProcessedField {
123 let pf = process_field(our_context_type, i, &field);
124 let ProcessedField { var_name, ty, receiver, context_type, context_body, .. } = &pf;
125
126 let align = match receiver.align {
127 Some(align) => quote!{ writer.align(#align) },
128 None => quote!{}
129 };
130
131 let write_body = quote! {
132 {
133 #align;
134 <#ty as binmarshal::Marshal<#context_type>>::write(#var_name, writer, #context_body)
135 }?
136 };
137
138 MarshalProcessedField { pf, write_body }
139}
140
141fn process_field_demarshal(our_context_type: &TokenStream, lifetime: &Lifetime, i: usize, field: Field) -> DemarshalProcessedField {
142 let pf = process_field(our_context_type, i, &field);
143 let ProcessedField { var_name, ty, receiver, context_type, context_body, .. } = &pf;
144
145 let align = match receiver.align {
146 Some(align) => quote!{ view.align(#align) },
147 None => quote!{}
148 };
149
150 let read_body = quote! {
151 let #var_name = {
152 #align;
153 <#ty as binmarshal::Demarshal<#lifetime, #context_type>>::read(view, #context_body)
154 }?;
155 };
156
157 DemarshalProcessedField { pf, read_body }
158}
159
160fn process_field_update(our_context_type: &TokenStream, i: usize, field: Field) -> UpdateProcessedField {
161 let pf = process_field(our_context_type, i, &field);
162 let ProcessedField { var_name, receiver, .. } = &pf;
163
164 let update_body = match receiver.ctx.as_ref() {
165 Some(CtxType::Forward) => Some(quote! { #var_name.update(ctx) }),
166 Some(CtxType::Coerce) => Some(quote! {
167 let mut new_context = ctx.clone().into();
168 #var_name.update(&mut new_context);
169 *ctx = new_context.into();
170 }),
171 Some(CtxType::Construct) => {
172 let inner = receiver.ctx_members.iter().map(|x| {
173 let field: TokenStream = parse_str(&x.field).unwrap();
174 let member: TokenStream = parse_str(&x.member).unwrap();
175
176 let deref = match x.member.contains(".") {
178 true => quote!{ #member },
179 false => quote!{ *#member },
180 };
181 (quote!{ #field: #deref }, quote!{ #deref = new_context.#field })
182 });
183 let create_new_items = inner.clone().map(|x| x.0);
184 let propagate_back = inner.clone().map(|x| x.1);
185
186 let ctx_type = receiver.ctx_type.as_ref().unwrap();
187
188 Some(quote!{
189 let mut new_context = #ctx_type { #(#create_new_items),* };
190 #var_name.update(&mut new_context);
191 #(#propagate_back);*
192 })
193 },
194 _ => None
195 };
196
197 UpdateProcessedField { pf, update_body }
198}
199
200fn strip_bounds<'a, I: Iterator<Item = &'a GenericParam>>(generics: I) -> TokenStream {
201 let g = generics.map(|x| match x {
202 syn::GenericParam::Lifetime(lt) => {
203 let i = <.lifetime;
204 quote!{ #i }
205 },
206 syn::GenericParam::Type(ty) => {
207 let i = &ty.ident;
208 quote!{ #i }
209 },
210 syn::GenericParam::Const(c) => {
211 let c2 = &c.ident;
212 quote!{ #c2 }
213 },
214 });
215 quote!{ #(#g),* }
216}
217
218#[proc_macro_derive(Marshal, attributes(marshal))]
221pub fn derive_marshal(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
222 let DeriveInput {
223 attrs, vis: _, ident, generics, data
224 } = parse_macro_input!(input as DeriveInput);
225
226 let generics_without_bounds = strip_bounds(generics.params.iter());
227
228 let attrs = StructOrEnumReceiver::from_attributes(&attrs).unwrap();
229 let ctx_ty = if let Some(ctx) = attrs.ctx {
230 quote!{ #ctx }
231 } else {
232 quote! { () }
233 };
234
235 let magic_write = match attrs.magic {
236 Some(lit) => {
237 Some(quote!{
238 writer.write_magic(#lit)?
239 })
240 },
241 None => None
242 };
243
244 match data {
245 syn::Data::Struct(st) => {
246 let it = st.fields.into_iter().enumerate().map(|(i, field)| process_field_marshal(&ctx_ty, i, field));
247
248 let to_write = it.clone().map(|x| x.write_body);
249 let refs = it.clone().map(|x| x.pf.get_ref);
250
251 let out = quote! {
252 impl #generics binmarshal::Marshal<#ctx_ty> for #ident<#generics_without_bounds> {
253 fn write<W: binmarshal::rw::BitWriter>(&self, writer: &mut W, ctx: #ctx_ty) -> core::result::Result<(), binmarshal::MarshalError> {
254 #magic_write;
255
256 #(#refs;)*
257 #(#to_write;)*
258 Ok(())
259 }
260 }
261 };
262
263 out.into()
264 },
265 syn::Data::Enum(en) => {
266 let write_tag = match &attrs.tag {
267 Some(_) => quote! { }, None => {
269 let tag_type = attrs.tag_type.clone().unwrap();
271
272 let ctx_val = match attrs.tag_bits {
273 Some(bits) => quote!{ binmarshal::BitSpecification::<#bits> {} },
274 None => quote! { () }
275 };
276
277 quote! {
278 <#tag_type as binmarshal::Marshal<_>>::write(&_tag, writer, #ctx_val)?
279 }
280 }
281 };
282
283 let it = en.variants.into_iter().map(|variant| {
284 let receiver = EnumVariantReceiver::from_variant(&variant).unwrap();
285 let name = receiver.ident;
286 let tag: TokenStream = parse_str(&receiver.tag).unwrap();
287
288 let (fields, is_paren) = match variant.fields {
289 Fields::Named(named) => (named.named.into_iter().collect(), false),
290 Fields::Unnamed(unnamed) => (unnamed.unnamed.into_iter().collect(), true),
291 Fields::Unit => (vec![], false),
292 };
293
294 let processed_fields = fields.into_iter().enumerate().map(|(i, field)| process_field_marshal(&ctx_ty, i, field));
295
296 let to_write = processed_fields.clone().map(|t| t.write_body);
297 let construct = processed_fields.clone().map(|t| t.pf.construct);
298
299 let write_tag = quote!{
300 Self::#name { .. } => #tag
301 };
302
303 let write = match is_paren {
304 true => quote!{
305 Self::#name ( #(#construct),* ) => {
306 #(#to_write;)*
307 }
308 },
309 false => quote!{
310 Self::#name { #(#construct),* } => {
311 #(#to_write;)*
312 }
313 },
314 };
315
316 (write_tag, write)
317 });
318
319 let write_tag_variants = it.clone().map(|v| v.0);
320 let write_variants = it.clone().map(|v| v.1);
321
322 let out = quote! {
323 impl #generics binmarshal::Marshal<#ctx_ty> for #ident<#generics_without_bounds> {
324 #[inline(always)]
325 #[allow(unused_variables)]
326 fn write<W: binmarshal::rw::BitWriter>(&self, writer: &mut W, ctx: #ctx_ty) -> core::result::Result<(), binmarshal::MarshalError> {
327 let _tag = match &self {
328 #(#write_tag_variants),*
329 };
330 #magic_write;
331 #write_tag;
332 match self {
333 #(#write_variants),*
334 };
335
336 Ok(())
337 }
338 }
339 };
340
341 out.into()
342 },
343 syn::Data::Union(_) => panic!("Don't know how to serialise unions!"),
344 }
345}
346
347#[proc_macro_derive(Demarshal, attributes(marshal))]
350pub fn derive_demarshal(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
351 let DeriveInput {
352 attrs, vis: _, ident, generics, data
353 } = parse_macro_input!(input as DeriveInput);
354
355 let generics_inner = generics.params.iter();
356 let generics_inner = quote!{ #(#generics_inner),* };
357 let generics_without_bounds = strip_bounds(generics.params.iter());
358
359 let lifetime = generics.params.iter().find_map(|x| match x {
360 GenericParam::Lifetime(lt) => Some(lt.lifetime.clone()),
361 _ => None
362 });
363
364 let lifetime_def = match lifetime {
365 None => quote!{ 'dm, },
366 Some(_) => quote!{},
367 };
368 let lifetime = lifetime.unwrap_or(Lifetime::new("'dm", ident.span()));
369
370 let attrs = StructOrEnumReceiver::from_attributes(&attrs).unwrap();
371 let ctx_ty = if let Some(ctx) = attrs.ctx {
372 quote!{ #ctx }
373 } else {
374 quote! { () }
375 };
376
377 let magic_read = match attrs.magic {
378 Some(lit) => {
379 Some(quote!{
380 view.check_magic(#lit)?
381 })
382 },
383 None => None
384 };
385
386 match data {
387 syn::Data::Struct(st) => {
388 let it = st.fields.into_iter().enumerate().map(|(i, field)| process_field_demarshal(&ctx_ty, &lifetime, i, field));
389
390 let to_read = it.clone().map(|x| x.read_body);
391 let construct = it.clone().map(|x| x.pf.construct);
392
393 let out = quote! {
394 impl <#lifetime_def #generics_inner> binmarshal::Demarshal<#lifetime, #ctx_ty> for #ident<#generics_without_bounds> {
395 fn read(view: &mut binmarshal::rw::BitView<#lifetime>, ctx: #ctx_ty) -> core::result::Result<Self, binmarshal::MarshalError> {
396 #magic_read;
397
398 #(#to_read)*
399 Ok(Self {
400 #(#construct),*
401 })
402 }
403 }
404 };
405
406 out.into()
407 },
408 syn::Data::Enum(en) => {
409 let read_tag = match &attrs.tag {
410 Some(tag) => {
411 let in_tag: TokenStream = parse_str(&tag).unwrap();
412 quote! { let _tag = #in_tag; }
413 },
414 None => {
415 let tag_type = attrs.tag_type.clone().unwrap();
417
418 let ctx_val = match attrs.tag_bits {
419 Some(bits) => quote!{ binmarshal::BitSpecification::<#bits> {} },
420 None => quote! { () }
421 };
422
423 quote! {
424 let _tag = <#tag_type as binmarshal::Demarshal<#lifetime, _>>::read(view, #ctx_val)?;
425 }
426 }
427 };
428
429 let read_variants = en.variants.into_iter().map(|variant| {
430 let receiver = EnumVariantReceiver::from_variant(&variant).unwrap();
431 let name = receiver.ident;
432 let tag: TokenStream = parse_str(&receiver.tag).unwrap();
433
434 let (fields, is_paren) = match variant.fields {
435 Fields::Named(named) => (named.named.into_iter().collect(), false),
436 Fields::Unnamed(unnamed) => (unnamed.unnamed.into_iter().collect(), true),
437 Fields::Unit => (vec![], false),
438 };
439
440 let processed_fields = fields.into_iter().enumerate().map(|(i, field)| process_field_demarshal(&ctx_ty, &lifetime, i, field));
441
442 let to_read = processed_fields.clone().map(|t| t.read_body);
443 let construct = processed_fields.clone().map(|t| t.pf.construct);
444
445 let read = match is_paren {
446 true => quote!{
447 (#tag) => {
448 #(#to_read)*
449 Ok(Self::#name(#(#construct),*))
450 }
451 },
452 false => quote!{
453 (#tag) => {
454 #(#to_read)*
455 Ok(Self::#name { #(#construct),*})
456 }
457 },
458 };
459
460 read
461 });
462
463 let out = quote! {
464 impl<#lifetime_def #generics_inner> binmarshal::Demarshal<#lifetime, #ctx_ty> for #ident<#generics_without_bounds> {
465 #[inline(always)]
466 #[allow(unused_variables)]
467 fn read(view: &mut binmarshal::rw::BitView<#lifetime>, ctx: #ctx_ty) -> core::result::Result<Self, binmarshal::MarshalError> {
468 #magic_read;
469 #read_tag;
470 match _tag {
471 #(#read_variants),*,
472 _ => Err(binmarshal::MarshalError::IllegalTag)
473 }
474 }
475 }
476 };
477
478 out.into()
479 },
480 syn::Data::Union(_) => panic!("Don't know how to serialise unions!"),
481 }
482}
483
484#[proc_macro_derive(MarshalUpdate, attributes(marshal))]
487pub fn derive_marshal_update(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
488 let DeriveInput {
489 attrs, vis: _, ident, generics, data
490 } = parse_macro_input!(input as DeriveInput);
491
492 let generics_without_bounds = strip_bounds(generics.params.iter());
493
494 let attrs = StructOrEnumReceiver::from_attributes(&attrs).unwrap();
495 let ctx_ty = if let Some(ctx) = attrs.ctx {
496 quote!{ #ctx }
497 } else {
498 quote! { () }
499 };
500
501 match data {
502 syn::Data::Struct(st) => {
503 let it = st.fields.into_iter().enumerate().map(|(i, field)| process_field_update(&ctx_ty, i, field));
504
505 let to_update = it.clone().map(|x| x.update_body);
506 let get_ref_mut = it.clone().map(|x| x.pf.get_ref_mut);
507
508 let out = quote! {
509 impl #generics binmarshal::MarshalUpdate<#ctx_ty> for #ident<#generics_without_bounds> {
510 fn update(&mut self, ctx: &mut #ctx_ty) {
511 #(#get_ref_mut;)*
512 #(#to_update;)*
513 }
514 }
515 };
516
517 out.into()
518 },
519 syn::Data::Enum(en) => {
520 let update_tag = match &attrs.tag {
521 Some(tag) => {
522 let in_tag: TokenStream = parse_str(&tag).unwrap();
523 Some(in_tag)
524 },
525 None => None
526 };
527
528 let update_variants = en.variants.into_iter().map(|variant| {
529 let receiver = EnumVariantReceiver::from_variant(&variant).unwrap();
530 let name = receiver.ident;
531 let tag: TokenStream = parse_str(&receiver.tag).unwrap();
532
533 let inner_update: TokenStream = update_tag.clone().map(|x| quote!{ #x = #tag }).unwrap_or(quote!{ });
534
535 let (fields, is_paren) = match variant.fields {
536 Fields::Named(named) => (named.named.into_iter().collect(), false),
537 Fields::Unnamed(unnamed) => (unnamed.unnamed.into_iter().collect(), true),
538 Fields::Unit => (vec![], false),
539 };
540
541 let processed_fields = fields.into_iter().enumerate().map(|(i, field)| process_field_update(&ctx_ty, i, field));
542
543 let to_update = processed_fields.clone().map(|t| t.update_body);
544 let construct = processed_fields.clone().map(|t| t.pf.construct);
545
546 let update = match is_paren {
547 true => quote! {
548 Self::#name(#(#construct),*) => {
549 #inner_update;
550 #(#to_update;)*
551 }
552 },
553 false => quote! {
554 Self::#name { #(#construct),* } => {
555 #inner_update;
556 #(#to_update;)*
557 }
558 },
559 };
560
561 update
562 });
563
564 let out = quote! {
565 impl #generics binmarshal::MarshalUpdate<#ctx_ty> for #ident<#generics_without_bounds> {
566 fn update(&mut self, ctx: &mut #ctx_ty) {
567 match self {
568 #(#update_variants),*
569 }
570 }
571 }
572 };
573
574 out.into()
575 },
576 syn::Data::Union(_) => panic!("Don't know how to serialise unions!"),
577 }
578}
579
580#[derive(Debug, FromAttributes)]
583#[darling(attributes(proxy))]
584struct ProxyReceiver {
585 no_clone: Option<bool>
586}
587
588#[proc_macro_derive(Proxy, attributes(proxy))]
589pub fn derive_proxy(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
590 let DeriveInput {
591 attrs, vis: _, ident, generics, data
592 } = parse_macro_input!(input as DeriveInput);
593
594 let attrs = ProxyReceiver::from_attributes(&attrs).unwrap();
595
596 let clone = !(attrs.no_clone == Some(true));
597
598 let _generics_inner = &generics.params;
599
600 match data {
601 syn::Data::Struct(st) => {
602 match st.fields {
603 Fields::Unnamed(fields) => {
604 let field = &fields.unnamed[0];
605 let mut extra_default = vec![];
606 let mut extra_clone = vec![];
607
608 for (_, _) in fields.unnamed.iter().enumerate().skip(1) {
609 extra_default.push(quote! { Default::default() });
610 }
611
612 for (i, _) in fields.unnamed.iter().enumerate().skip(1) {
613 let i = syn::Index::from(i);
614 extra_clone.push(quote! { self.#i.clone() });
615 }
616
617 let ft = &field.ty;
618
619 let ident_generics = generics.params.iter().map(|g| match g {
620 syn::GenericParam::Lifetime(lt) => {
621 let lt = <.lifetime;
622 quote!{ #lt }
623 },
624 syn::GenericParam::Type(ty) => {
625 let t = &ty.ident;
626 quote!{ #t }
627 },
628 syn::GenericParam::Const(c) => {
629 let t = &c.ident;
630 quote!{ #t }
631 },
632 });
633
634 let (lt, gt) = (generics.lt_token, generics.gt_token);
635 let ident_generics = quote! {
636 #lt #(#ident_generics),* #gt
637 };
638
639 #[allow(unused_mut)]
640 let mut out = quote! {
641 impl #generics From<#ft> for #ident #ident_generics {
642 fn from(inner: #ft) -> Self {
643 Self(inner, #(#extra_default),*)
644 }
645 }
646
647 impl #generics Deref for #ident #ident_generics {
648 type Target = #ft;
649
650 fn deref(&self) -> &Self::Target { &self.0 }
651 }
652
653 impl #generics DerefMut for #ident #ident_generics {
654 fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
655 }
656
657 impl #generics core::fmt::Debug for #ident #ident_generics where #ft: core::fmt::Debug {
658 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.0.fmt(f) }
659 }
660
661 impl #generics PartialEq for #ident #ident_generics where #ft: PartialEq {
662 #[inline]
663 fn eq(&self, other: &Self) -> bool {
664 PartialEq::eq(&self.0, &other.0)
665 }
666 #[inline]
667 fn ne(&self, other: &Self) -> bool {
668 PartialEq::ne(&self.0, &other.0)
669 }
670 }
671
672 impl #generics Eq for #ident #ident_generics where #ft: Eq { }
673 };
674
675 if clone {
676 out = quote!{
677 #out
678
679 impl #generics Clone for #ident #ident_generics where #ft: Clone {
680 fn clone(&self) -> Self { Self(self.0.clone(),#(#extra_clone),*) }
681 }
682 }
683 }
684
685 #[cfg(feature = "serde")]
686 {
687 out = quote!{
688 #out
689
690 impl #generics serde::Serialize for #ident #ident_generics where #ft: serde::Serialize {
691 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
692 where
693 S: serde::Serializer,
694 {
695 self.0.serialize(serializer)
696 }
697 }
698
699 impl<'de, #_generics_inner> serde::Deserialize<'de> for #ident #ident_generics where #ft: serde::Deserialize<'de> {
700 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
701 where
702 D: serde::Deserializer<'de>,
703 {
704 <#ft as serde::Deserialize<'de>>::deserialize::<D>(deserializer).map(|x| Self(x, #(#extra_default),*))
705 }
706 }
707 };
708 }
709
710 #[cfg(feature = "schema")]
711 {
712 out = quote!{
713 #out
714
715 impl #generics schemars::JsonSchema for #ident #ident_generics where #ft: schemars::JsonSchema {
716 fn schema_name() -> alloc::string::String {
717 <#ft as schemars::JsonSchema>::schema_name()
718 }
719
720 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
721 <#ft as schemars::JsonSchema>::json_schema(gen)
722 }
723
724 fn is_referenceable() -> bool {
725 <#ft as schemars::JsonSchema>::is_referenceable()
726 }
727 }
728 }
729 }
730
731 out.into()
732 },
733 _ => panic!("Proxy only supported on newtype structs"),
734 }
735 },
736 syn::Data::Enum(_) => panic!("Proxy not supported on Enum types"),
737 syn::Data::Union(_) => panic!("Proxy not supported on Union types"),
738 }
739}